Skip to content

Commit e1afdcc

Browse files
authored
Merge pull request #12 from truefoundry/np-add-mnist-eg
added mnist example
2 parents 2372083 + 1a5ebc6 commit e1afdcc

File tree

4 files changed

+477
-0
lines changed

4 files changed

+477
-0
lines changed
+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import logging, os, argparse
2+
from servicefoundry import Build, Job, PythonBuild, Param, Port, LocalSource, Resources
3+
4+
# parsing the arguments
5+
parser = argparse.ArgumentParser()
6+
parser.add_argument(
7+
"--workspace_fqn", type=str, required=True, help="fqn of the workspace to deploy to"
8+
)
9+
args = parser.parse_args()
10+
11+
# defining the job specifications
12+
job = Job(
13+
name="mnist-train-job",
14+
image=Build(
15+
build_spec=PythonBuild(
16+
command="python train.py --num_epochs {{num_epochs}} --ml_repo {{ml_repo}}",
17+
requirements_path="requirements.txt",
18+
),
19+
build_source=LocalSource(local_build=False)
20+
),
21+
params=[
22+
Param(name="num_epochs", default='4'),
23+
Param(name="ml_repo", param_type="ml_repo"),
24+
],
25+
resources=Resources(
26+
cpu_request=0.5,
27+
cpu_limit=0.5,
28+
memory_request=1500,
29+
memory_limit=2000
30+
)
31+
32+
)
33+
deployment = job.deploy(workspace_fqn=args.workspace_fqn)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
matplotlib==3.8.2
2+
tensorflow==2.15.0
3+
mlfoundry==0.10.4
+88
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import mlfoundry
2+
import tensorflow as tf
3+
from tensorflow.keras.datasets import mnist
4+
import matplotlib.pyplot as plt
5+
from tensorflow.keras.datasets import mnist
6+
import os
7+
import argparse
8+
9+
# parsing the arguments
10+
parser = argparse.ArgumentParser()
11+
parser.add_argument(
12+
"--num_epochs", type=int, default=4
13+
)
14+
parser.add_argument(
15+
"--ml_repo", type=str, required=True
16+
)
17+
args = parser.parse_args()
18+
19+
20+
# Load the MNIST dataset
21+
(x_train, y_train), (x_test, y_test) = mnist.load_data()
22+
23+
# Normalize the pixel values between 0 and 1
24+
x_train = x_train / 255.0
25+
x_test = x_test / 255.0
26+
27+
print(f"The number of train images: {len(x_train)}")
28+
print(f"The number of test images: {len(x_test)}")
29+
30+
# Creating client for logging the metadata
31+
client = mlfoundry.get_client()
32+
33+
client.create_ml_repo(args.ml_repo)
34+
run = client.create_run(ml_repo=args.ml_repo, run_name="train-model")
35+
36+
# Plot some sample images
37+
plt.figure(figsize=(10, 5))
38+
for i in range(10):
39+
plt.subplot(2, 5, i+1)
40+
plt.imshow(x_train[i], cmap='gray')
41+
plt.title(f"Label: {y_train[i]}")
42+
plt.axis('off')
43+
run.log_plots({"images": plt})
44+
plt.tight_layout()
45+
46+
47+
# Define the model architecture
48+
model = tf.keras.Sequential([
49+
tf.keras.layers.Flatten(input_shape=(28, 28)),
50+
tf.keras.layers.Dense(128, activation='relu'),
51+
tf.keras.layers.Dense(10, activation='softmax')
52+
])
53+
54+
# Compile the model
55+
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
56+
57+
#logging the parameters
58+
run.log_params({"optimizer": "adam", "loss": "sparse_categorical_crossentropy", "metric": ["accuracy"]})
59+
60+
# Train the model
61+
epochs = args.num_epochs
62+
model.fit(x_train, y_train, epochs=epochs, validation_data=(x_test, y_test))
63+
64+
# Evaluate the model
65+
loss, accuracy = model.evaluate(x_test, y_test)
66+
print(f'Test loss: {loss}')
67+
print(f'Test accuracy: {accuracy}')
68+
69+
70+
# Log Metrics and Model
71+
72+
# Logging the metrics of the model
73+
run.log_metrics(metric_dict={"accuracy": accuracy, "loss": loss})
74+
75+
# Save the trained model
76+
model.save('mnist_model.h5')
77+
78+
# Logging the model
79+
run.log_model(
80+
name="handwritten-digits-recognition",
81+
model_file_or_folder='mnist_model.h5',
82+
framework="tensorflow",
83+
description="sample model to recognize the handwritten digits",
84+
metadata={"accuracy": accuracy, "loss": loss},
85+
step=1, # step number, useful when using iterative algorithms like SGD
86+
)
87+
88+

0 commit comments

Comments
 (0)