Skip to content

Commit f02d856

Browse files
committed
added mnist example
1 parent 49dc1c5 commit f02d856

File tree

4 files changed

+510
-0
lines changed

4 files changed

+510
-0
lines changed
+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import logging, os, argparse
2+
from servicefoundry import Build, Job, PythonBuild, Param, Port, LocalSource, Resources
3+
4+
# parsing the arguments
5+
parser = argparse.ArgumentParser()
6+
parser.add_argument(
7+
"--workspace_fqn", type=str, required=True, help="fqn of the workspace to deploy to"
8+
)
9+
args = parser.parse_args()
10+
11+
# defining the job specifications
12+
job = Job(
13+
name="mnist-train-job",
14+
image=Build(
15+
build_spec=PythonBuild(
16+
command="python train.py --num_epochs {{num_epochs}} --ml_repo {{ml_repo}}",
17+
requirements_path="requirements.txt",
18+
),
19+
build_source=LocalSource(local_build=False)
20+
),
21+
params=[
22+
Param(name="num_epochs", default='4'),
23+
Param(name="ml_repo", param_type="ml_repo"),
24+
],
25+
resources=Resources(
26+
cpu_request=0.5,
27+
cpu_limit=0.5,
28+
memory_request=1000,
29+
memory_limit=1500
30+
)
31+
32+
)
33+
deployment = job.deploy(workspace_fqn=args.workspace_fqn)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
matplotlib==3.8.2
2+
tensorflow==2.15.0
3+
mlfoundry==0.10.4
+97
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import mlfoundry
2+
import tensorflow as tf
3+
from tensorflow.keras.datasets import mnist
4+
import matplotlib.pyplot as plt
5+
from tensorflow.keras.datasets import mnist
6+
import os
7+
import argparse
8+
9+
# parsing the arguments
10+
parser = argparse.ArgumentParser()
11+
parser.add_argument(
12+
"--num_epochs", type=int, default=4
13+
)
14+
parser.add_argument(
15+
"--ml_repo", type=str, required=True
16+
)
17+
args = parser.parse_args()
18+
19+
ML_REPO_NAME=args.ml_repo
20+
21+
# Load the MNIST dataset
22+
(x_train, y_train), (x_test, y_test) = mnist.load_data()
23+
24+
print(f"The number of train images: {len(x_train)}")
25+
print(f"The number of test images: {len(x_test)}")
26+
27+
# Plot some sample images
28+
plt.figure(figsize=(10, 5))
29+
for i in range(10):
30+
plt.subplot(2, 5, i+1)
31+
plt.imshow(x_train[i], cmap='gray')
32+
plt.title(f"Label: {y_train[i]}")
33+
plt.axis('off')
34+
plt.tight_layout()
35+
plt.show()
36+
37+
38+
# Load the MNIST dataset
39+
(x_train, y_train), (x_test, y_test) = mnist.load_data()
40+
41+
# Normalize the pixel values between 0 and 1
42+
x_train = x_train / 255.0
43+
x_test = x_test / 255.0
44+
45+
46+
# Define the model architecture
47+
model = tf.keras.Sequential([
48+
tf.keras.layers.Flatten(input_shape=(28, 28)),
49+
tf.keras.layers.Dense(128, activation='relu'),
50+
tf.keras.layers.Dense(10, activation='softmax')
51+
])
52+
53+
# Compile the model
54+
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
55+
56+
57+
# Creating client for logging the metadata
58+
client = mlfoundry.get_client()
59+
60+
client.create_ml_repo(ML_REPO_NAME)
61+
run = client.create_run(ml_repo=ML_REPO_NAME)
62+
63+
64+
#logging the parameters
65+
run.log_params({"optimizer": "adam", "loss": "sparse_categorical_crossentropy", "metric": ["accuracy"]})
66+
67+
68+
69+
# Train the model
70+
epochs = args.num_epochs
71+
model.fit(x_train, y_train, epochs=epochs, validation_data=(x_test, y_test))
72+
73+
# Evaluate the model
74+
loss, accuracy = model.evaluate(x_test, y_test)
75+
print(f'Test loss: {loss}')
76+
print(f'Test accuracy: {accuracy}')
77+
78+
79+
# Log Metrics and Model
80+
81+
# Logging the metrics of the model
82+
run.log_metrics(metric_dict={"accuracy": accuracy, "loss": loss})
83+
84+
# Save the trained model
85+
model.save('mnist_model.h5')
86+
87+
# Logging the model
88+
run.log_model(
89+
name="handwritten-digits-recognition",
90+
model_file_or_folder='mnist_model.h5',
91+
framework="tensorflow",
92+
description="sample model to recognize the handwritten digits",
93+
metadata={"accuracy": accuracy, "loss": loss},
94+
step=1, # step number, useful when using iterative algorithms like SGD
95+
)
96+
97+

0 commit comments

Comments
 (0)