|
| 1 | +import mlfoundry |
| 2 | +import tensorflow as tf |
| 3 | +from tensorflow.keras.datasets import mnist |
| 4 | +import matplotlib.pyplot as plt |
| 5 | +from tensorflow.keras.datasets import mnist |
| 6 | +import os |
| 7 | +import argparse |
| 8 | + |
| 9 | +# parsing the arguments |
| 10 | +parser = argparse.ArgumentParser() |
| 11 | +parser.add_argument( |
| 12 | + "--num_epochs", type=int, default=4 |
| 13 | +) |
| 14 | +parser.add_argument( |
| 15 | + "--ml_repo", type=str, required=True |
| 16 | +) |
| 17 | +args = parser.parse_args() |
| 18 | + |
| 19 | +ML_REPO_NAME=args.ml_repo |
| 20 | + |
| 21 | +# Load the MNIST dataset |
| 22 | +(x_train, y_train), (x_test, y_test) = mnist.load_data() |
| 23 | + |
| 24 | +print(f"The number of train images: {len(x_train)}") |
| 25 | +print(f"The number of test images: {len(x_test)}") |
| 26 | + |
| 27 | +# Plot some sample images |
| 28 | +plt.figure(figsize=(10, 5)) |
| 29 | +for i in range(10): |
| 30 | + plt.subplot(2, 5, i+1) |
| 31 | + plt.imshow(x_train[i], cmap='gray') |
| 32 | + plt.title(f"Label: {y_train[i]}") |
| 33 | + plt.axis('off') |
| 34 | +plt.tight_layout() |
| 35 | +plt.show() |
| 36 | + |
| 37 | + |
| 38 | +# Load the MNIST dataset |
| 39 | +(x_train, y_train), (x_test, y_test) = mnist.load_data() |
| 40 | + |
| 41 | +# Normalize the pixel values between 0 and 1 |
| 42 | +x_train = x_train / 255.0 |
| 43 | +x_test = x_test / 255.0 |
| 44 | + |
| 45 | + |
| 46 | +# Define the model architecture |
| 47 | +model = tf.keras.Sequential([ |
| 48 | + tf.keras.layers.Flatten(input_shape=(28, 28)), |
| 49 | + tf.keras.layers.Dense(128, activation='relu'), |
| 50 | + tf.keras.layers.Dense(10, activation='softmax') |
| 51 | +]) |
| 52 | + |
| 53 | +# Compile the model |
| 54 | +model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) |
| 55 | + |
| 56 | + |
| 57 | +# Creating client for logging the metadata |
| 58 | +client = mlfoundry.get_client() |
| 59 | + |
| 60 | +client.create_ml_repo(ML_REPO_NAME) |
| 61 | +run = client.create_run(ml_repo=ML_REPO_NAME) |
| 62 | + |
| 63 | + |
| 64 | +#logging the parameters |
| 65 | +run.log_params({"optimizer": "adam", "loss": "sparse_categorical_crossentropy", "metric": ["accuracy"]}) |
| 66 | + |
| 67 | + |
| 68 | + |
| 69 | +# Train the model |
| 70 | +epochs = args.num_epochs |
| 71 | +model.fit(x_train, y_train, epochs=epochs, validation_data=(x_test, y_test)) |
| 72 | + |
| 73 | +# Evaluate the model |
| 74 | +loss, accuracy = model.evaluate(x_test, y_test) |
| 75 | +print(f'Test loss: {loss}') |
| 76 | +print(f'Test accuracy: {accuracy}') |
| 77 | + |
| 78 | + |
| 79 | +# Log Metrics and Model |
| 80 | + |
| 81 | +# Logging the metrics of the model |
| 82 | +run.log_metrics(metric_dict={"accuracy": accuracy, "loss": loss}) |
| 83 | + |
| 84 | +# Save the trained model |
| 85 | +model.save('mnist_model.h5') |
| 86 | + |
| 87 | +# Logging the model |
| 88 | +run.log_model( |
| 89 | + name="handwritten-digits-recognition", |
| 90 | + model_file_or_folder='mnist_model.h5', |
| 91 | + framework="tensorflow", |
| 92 | + description="sample model to recognize the handwritten digits", |
| 93 | + metadata={"accuracy": accuracy, "loss": loss}, |
| 94 | + step=1, # step number, useful when using iterative algorithms like SGD |
| 95 | +) |
| 96 | + |
| 97 | + |
0 commit comments