Creating Neural Networks with Scikit-Learn and Keras#

CSC/DSC 340 Week 7 Slides (Part 2)

Author: Dr. Julie Butler

Date Created: September 28, 2023

Last Modified: September 28, 2023

The Data Set#

In this notebook, we will use neural networks to classify the iris data set with both the Scikit-Learn and Keras libraries.

##############################
##          IMPORTS         ##
##############################
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
# Load the iris dataset from sklearn
iris = load_iris()

# Convert the iris dataset to a pandas dataframe
iris_data = pd.DataFrame(iris.data, columns=iris.feature_names)

# Add the target variable to the dataframe
iris_data['target'] = iris.target
sns.pairplot(iris_data, hue='target')
<seaborn.axisgrid.PairGrid at 0x137983880>

[Seaborn pairplot of the four iris features, colored by target class]

Neural Networks in Scikit-Learn with Hyperparameter Tuning#

  • Scikit-Learn does have neural network implementations but they are called MLPClassifier (for classification problems) and MLPRegressor (for regression problems).

  • MLP stands for multi-layer perceptron and is another name for a simple feedforward neural network (a brief, illustrative MLPRegressor sketch follows this list; the rest of the notebook uses MLPClassifier)
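  • Below is a brief, hedged sketch of the regression counterpart, MLPRegressor, on made-up synthetic data; the data and settings are illustrative only and are not part of the iris workflow

# Illustrative sketch: MLPRegressor on hypothetical synthetic data
import numpy as np
from sklearn.neural_network import MLPRegressor

# Hypothetical noisy sine data, just for demonstration
rng = np.random.default_rng(0)
X_reg = rng.uniform(-3, 3, size=(200, 1))
y_reg = np.sin(X_reg).ravel() + 0.1 * rng.normal(size=200)

# Same multi-layer perceptron idea, but with a continuous output and squared-error loss
reg = MLPRegressor(hidden_layer_sizes=(10, 10), max_iter=2000)
reg.fit(X_reg, y_reg)
print(reg.predict([[0.5]]))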

##############################
##          IMPORTS         ##
##############################
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import accuracy_score
# Load the Iris dataset
X,y = load_iris(return_X_y=True)
# Split the dataset into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
  • Scaling/standardizing the data is optional with neural networks, but it's a good thing to test when optimizing performance

  • It's also a good idea to explore PCA and feature engineering at this point in the notebook (an illustrative PCA sketch follows the scaling cell below)

# Standardize the feature values (mean=0, std=1)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
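  • As mentioned above, here is a minimal sketch of what applying PCA could look like at this point; whether it actually helps is something to test, and the choice of 2 components is arbitrary

# Optional, illustrative sketch: reduce the scaled features with PCA
from sklearn.decomposition import PCA

pca = PCA(n_components=2)
X_train_pca = pca.fit_transform(X_train)
X_test_pca = pca.transform(X_test)
print("Explained variance ratio:", pca.explained_variance_ratio_)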
  • Define the neural network with two hidden layers, the first having 3 neurons and the second having 2 neurons.

  • The network will train for a maximum of 1,000 iterations or until the change between iterations is less than \(1\times10^{-4}\) (whichever happens first)

  • Note that Scikit-Learn only allows the activation function to be set for the entire network, not per layer

    • Rectified linear unit (ReLU) by default

# Create an MLP classifier
mlp = MLPClassifier(hidden_layer_sizes=(3, 2), max_iter=1000)
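  • The defaults described above can also be written out explicitly; a minimal sketch of an equivalent call with the stopping tolerance and activation function shown at their default values

# Equivalent construction with tol and activation written out explicitly
mlp_explicit = MLPClassifier(hidden_layer_sizes=(3, 2), max_iter=1000,
                             tol=1e-4, activation='relu')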
  • Fit/train the neural network

# Train the classifier on the training data
mlp.fit(X_train, y_train)
MLPClassifier(hidden_layer_sizes=(3, 2), max_iter=1000)
  • Predict the classes with the trained model and test the accuracy of the trained model

# Predict the labels for the test set
y_pred = mlp.predict(X_test)
# Calculate accuracy
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy:", accuracy)
Accuracy: 0.9666666666666667
  • We can also extract the probabilities for each category. The predict_proba method returns an array with one row per point in the test set and one column per class.

  • The column with the highest value for a point corresponds to the most likely class for that point

# Predict the probabilities for the test set
probabilities = mlp.predict_proba(X_test)
# You can print the probabilities for the first few samples as an example
print("Probabilities for the first 5 samples:")
print(probabilities[:5])
Probabilities for the first 5 samples:
[[1.94704534e-02 4.91394714e-02 9.31390075e-01]
 [9.99782614e-01 1.01828774e-04 1.15556766e-04]
 [5.76032033e-03 8.05495943e-03 9.86184720e-01]
 [4.16668392e-02 1.60707746e-01 7.97625415e-01]
 [5.08256250e-02 9.24079385e-01 2.50949901e-02]]
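  • As a quick, illustrative check of the statement above, taking the column with the highest probability in each row should reproduce the classes returned by predict

# The most probable class per row should match the output of mlp.predict
predicted_from_proba = np.argmax(probabilities, axis=1)
print(np.array_equal(predicted_from_proba, mlp.predict(X_test)))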
  • Since MLPClassifier is a standard Scikit-Learn estimator, its hyperparameters can be tuned using GridSearchCV or RandomizedSearchCV (a GridSearchCV sketch is shown after the RandomizedSearchCV results below)

from sklearn.model_selection import RandomizedSearchCV
# Create an MLP classifier
mlp = MLPClassifier(max_iter=100000)
# Define a hyperparameter grid to search over
param_dist = {
    'hidden_layer_sizes': [(3,2), (3,3,3), (5,5,5), (2,2,2,2)],
    'activation': ['identity', 'logistic', 'tanh', 'relu'],
    'alpha': np.logspace(-15,4,500),
    'learning_rate_init': np.logspace(-15,4,500),
}
# Create RandomizedSearchCV object
random_search = RandomizedSearchCV(mlp, param_distributions=param_dist, n_iter=100, cv=5)

# Fit the RandomizedSearchCV to the training data
random_search.fit(X_train, y_train)
RandomizedSearchCV(cv=5, estimator=MLPClassifier(max_iter=100000), n_iter=100,
                   param_distributions={'activation': ['identity', 'logistic', 'tanh', 'relu'],
                                        'alpha': array([1.00000000e-15, ..., 1.00000000e+04]),
                                        ...})
# Print the best hyperparameters
print("Best Hyperparameters:", random_search.best_params_)
Best Hyperparameters: {'learning_rate_init': 0.001988827856988812, 'hidden_layer_sizes': (3, 2), 'alpha': 5.964184945842484e-05, 'activation': 'identity'}
# Get the best model
best_mlp = random_search.best_estimator_

# Predict the labels for the test set
y_pred = best_mlp.predict(X_test)

# Calculate accuracy
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy:", accuracy)
Accuracy: 1.0
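  • As referenced above, GridSearchCV works the same way; a hedged sketch of what the grid-search version could look like over a smaller, discrete grid (the grid values here are illustrative, not tuned)

# Illustrative GridSearchCV version over a small, discrete hyperparameter grid
from sklearn.model_selection import GridSearchCV

param_grid = {
    'hidden_layer_sizes': [(3, 2), (5, 5)],
    'activation': ['relu', 'tanh'],
    'alpha': [1e-4, 1e-2],
}
grid_search = GridSearchCV(MLPClassifier(max_iter=100000), param_grid, cv=5)
grid_search.fit(X_train, y_train)
print("Best Hyperparameters:", grid_search.best_params_)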

Neural Networks in Keras#

  • Keras is a machine learning library that focuses on various types of neural networks and provides greater flexibility in constructing them than Scikit-Learn

  • Keras is built on top of another library called TensorFlow, which we will learn about next week

from keras.models import Sequential
from keras.layers import Dense
from keras.utils import to_categorical
# Load the Iris dataset
iris = load_iris()
X = iris.data
y = iris.target
# Split the dataset into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
# Standardize the feature values (mean=0, std=1)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
  • As an extra preprocessing step, we need to convert the labels from integers to binary vectors using the one-hot encoding process and the Keras function to_categorical

# Convert labels to one-hot encoding
y_train = to_categorical(y_train, num_classes=3)
y_test = to_categorical(y_test, num_classes=3)
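  • To see what this step does, a minimal sketch: to_categorical turns the integer labels into one-hot rows

# Illustrative example: integer labels 0, 1, 2 become one-hot encoded rows
print(to_categorical([0, 1, 2], num_classes=3))
# [[1. 0. 0.]
#  [0. 1. 0.]
#  [0. 0. 1.]]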
  • We are going to create a Sequential model, which lets us add layers in order from the first layer to the last

# Create a Sequential model
model = Sequential()
  • Add the first hidden layer to the model with 8 neurons and a ReLU activation function

  • We also have to set the input dimension (4 in this case) when we add the first hidden layer to the model

# Add an input layer with 4 input nodes (features)
model.add(Dense(8, input_dim=4, activation='relu'))
  • Add a second hidden layer with 8 neurons and a relu activation function

# Add a hidden layer with 8 nodes and ReLU activation
model.add(Dense(8, activation='relu'))
  • Add a final layer (the output layer) which has the appropriate dimension (the number of classes with one-hot encoding) and a softmax activation function

  • Softmax is a popular activation function for the output layer when doing classification

    • Gives probabilities and is useful in multiclass problems

# Add an output layer with 3 nodes (one for each class) and softmax activation
model.add(Dense(3, activation='softmax'))
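  • For intuition about why softmax is used here, a minimal NumPy sketch of what the output layer computes (the raw scores below are made up)

# Softmax turns arbitrary scores into probabilities that sum to 1
scores = np.array([2.0, 1.0, 0.1])  # hypothetical raw outputs for one sample
probs = np.exp(scores) / np.sum(np.exp(scores))
print(probs, probs.sum())  # roughly [0.659 0.242 0.099], summing to 1.0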
  • Compile the model using a categorical cross-entropy loss function, the Adam optimizer, and accuracy as the training and evaluation metric, since this is a classification problem

# Compile the model
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
  • Fit/train the model for 100 epochs (training iterations)

  • verbose=1 prints the loss and accuracy per epoch

# Train the model
model.fit(X_train, y_train, epochs=100,verbose=1)
Epoch 1/100
4/4 [==============================] - 0s 2ms/step - loss: 1.3405 - accuracy: 0.2583
Epoch 2/100
4/4 [==============================] - 0s 1ms/step - loss: 1.3144 - accuracy: 0.2917
...
Epoch 99/100
4/4 [==============================] - 0s 891us/step - loss: 0.2629 - accuracy: 0.9417
Epoch 100/100
4/4 [==============================] - 0s 917us/step - loss: 0.2598 - accuracy: 0.9417
<keras.src.callbacks.History at 0x28a221f40>
  • We can then evaluate the performance of the model, both in terms of the loss (not particularly helpful unless you understand the scale of the loss function) and in terms of the accuracy

# Evaluate the model on the test set
loss, accuracy = model.evaluate(X_test, y_test)
print("Test Loss:", loss)
print("Test Accuracy:", accuracy)
1/1 [==============================] - 0s 83ms/step - loss: 0.1953 - accuracy: 0.9333
Test Loss: 0.19525906443595886
Test Accuracy: 0.9333333373069763
  • It is relatively common to build the Keras model inside a function that returns it, so the model is easy to edit and reuse

def iris_classification_model():
    # Create a Sequential model
    model = Sequential()

    # Add an input layer with 4 input nodes (features)
    model.add(Dense(8, input_dim=4, activation='relu'))

    # Add a hidden layer with 8 nodes and ReLU activation
    model.add(Dense(8, activation='relu'))

    # Add an output layer with 3 nodes (one for each class) and softmax activation
    model.add(Dense(3, activation='softmax'))

    # Compile the model
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    
    return model

model = iris_classification_model()

# Train the model
model.fit(X_train, y_train, epochs=100,verbose=1)

# Evaluate the model on the test set
loss, accuracy = model.evaluate(X_test, y_test)
print("Test Loss:", loss)
print("Test Accuracy:", accuracy)
Epoch 1/100
4/4 [==============================] - 0s 1ms/step - loss: 1.0699 - accuracy: 0.2667
Epoch 2/100
4/4 [==============================] - 0s 1ms/step - loss: 1.0572 - accuracy: 0.2833
...
Epoch 99/100
4/4 [==============================] - 0s 892us/step - loss: 0.2730 - accuracy: 0.9500
Epoch 100/100
4/4 [==============================] - 0s 891us/step - loss: 0.2693 - accuracy: 0.9500
1/1 [==============================] - 0s 63ms/step - loss: 0.1979 - accuracy: 0.9667
Test Loss: 0.19785042107105255
Test Accuracy: 0.9666666388511658
  • Having the model as a function allows hyperparameters to be passed as arguments, which is useful for hyperparameter tuning

def iris_classification_model(neurons1, neurons2):
    # Create a Sequential model
    model = Sequential()

    # Add the first hidden layer with neurons1 nodes and ReLU activation
    # (the input dimension of 4 is also set here)
    model.add(Dense(neurons1, input_dim=4, activation='relu'))

    # Add a second hidden layer with neurons2 nodes and ReLU activation
    model.add(Dense(neurons2, activation='relu'))

    # Add an output layer with 3 nodes (one for each class) and softmax activation
    model.add(Dense(3, activation='softmax'))

    # Compile the model
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

    return model
  • You can tune a Keras model using nested for loops (a sketch that also tracks the best combination follows the output below)

for neurons1 in range(1,6):
    for neurons2 in range(1,6):
        model = iris_classification_model(neurons1, neurons2)

        # Train the model
        model.fit(X_train, y_train, epochs=100,verbose=0)
        
        # Evaluate the model on the test set
        loss, accuracy = model.evaluate(X_test, y_test)
        print("Neurons:", neurons1, neurons2)
        print("Test Accuracy:", accuracy)
        print()
Neurons: 1 1
Test Accuracy: 0.1666666716337204

Neurons: 1 2
Test Accuracy: 0.1666666716337204

Neurons: 1 3
Test Accuracy: 0.6000000238418579

Neurons: 1 4
Test Accuracy: 0.4333333373069763

Neurons: 1 5
Test Accuracy: 0.8999999761581421

Neurons: 2 1
Test Accuracy: 0.6333333253860474

Neurons: 2 2
Test Accuracy: 0.8999999761581421

Neurons: 2 3
Test Accuracy: 0.8999999761581421

Neurons: 2 4
Test Accuracy: 0.8666666746139526

Neurons: 2 5
Test Accuracy: 0.9333333373069763

Neurons: 3 1
Test Accuracy: 0.5

Neurons: 3 2
Test Accuracy: 0.5333333611488342

Neurons: 3 3
Test Accuracy: 0.7333333492279053

Neurons: 3 4
Test Accuracy: 0.6333333253860474

Neurons: 3 5
Test Accuracy: 0.7666666507720947

Neurons: 4 1
Test Accuracy: 0.5333333611488342

Neurons: 4 2
Test Accuracy: 0.699999988079071

Neurons: 4 3
Test Accuracy: 0.8999999761581421

Neurons: 4 4
Test Accuracy: 0.699999988079071

Neurons: 4 5
Test Accuracy: 0.8999999761581421

Neurons: 5 1
Test Accuracy: 0.1666666716337204

Neurons: 5 2
Test Accuracy: 0.9666666388511658

Neurons: 5 3
Test Accuracy: 0.8666666746139526

Neurons: 5 4
Test Accuracy: 0.6333333253860474

Neurons: 5 5
Test Accuracy: 0.9333333373069763
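  • As noted above, the loop only prints results; a hedged sketch of the same loop extended to remember the best combination (variable names are illustrative)

# Illustrative: track the best (neurons1, neurons2) combination while looping
best_accuracy = 0.0
best_config = None
for neurons1 in range(1, 6):
    for neurons2 in range(1, 6):
        model = iris_classification_model(neurons1, neurons2)
        model.fit(X_train, y_train, epochs=100, verbose=0)
        loss, accuracy = model.evaluate(X_test, y_test, verbose=0)
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            best_config = (neurons1, neurons2)
print("Best configuration:", best_config, "with accuracy", best_accuracy)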
  • There are many ways to write Keras model-building functions for hyperparameter tuning, some of which are more general than others

# Adam is imported so the learning rate can be set explicitly
# (depending on the install, from tensorflow.keras.optimizers import Adam may be needed instead)
from keras.optimizers import Adam

# Define a function to create a Keras model
def create_model(hidden_layers=1, neurons=8, learning_rate=0.001):
    model = Sequential()
    model.add(Dense(neurons, input_dim=4, activation='relu'))

    for _ in range(hidden_layers - 1):
        model.add(Dense(neurons, activation='relu'))

    model.add(Dense(3, activation='softmax'))

    optimizer = Adam(learning_rate=learning_rate)
    model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
    return model
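  • A brief, illustrative usage of this more general builder (the hyperparameter values below are arbitrary)

# Illustrative usage of create_model with arbitrary hyperparameter values
custom_model = create_model(hidden_layers=2, neurons=16, learning_rate=0.01)
custom_model.fit(X_train, y_train, epochs=100, verbose=0)
loss, accuracy = custom_model.evaluate(X_test, y_test, verbose=0)
print("Test Accuracy:", accuracy)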
  • Finally, we can also extract the class probabilities from a Keras model; since the output layer uses softmax, model.predict returns the probabilities directly (there is no separate predict_proba as in Scikit-Learn)

# Predict the probabilities for the test set
probabilities = model.predict(X_test)

# probabilities is a 2D array where each row corresponds to a sample in X_test
# and each column corresponds to the probability of that sample belonging to a specific class

# You can print the probabilities for the first few samples as an example
print("Probabilities for the first 5 samples:")
print(probabilities[:5])
1/1 [==============================] - 0s 46ms/step
Probabilities for the first 5 samples:
[[4.1971911e-02 2.3005164e-01 7.2797644e-01]
 [2.0571986e-02 1.7209764e-01 8.0733037e-01]
 [9.6769130e-01 3.2196801e-02 1.1195116e-04]
 [9.6624315e-01 3.3624936e-02 1.3193331e-04]
 [9.7459555e-01 2.5339015e-02 6.5407788e-05]]
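  • Just as with Scikit-Learn, the most probable class for each sample can then be recovered with argmax; a minimal sketch

# Convert the softmax probabilities back to predicted class labels
predicted_classes = np.argmax(probabilities, axis=1)
print("Predicted classes for the first 5 samples:", predicted_classes[:5])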