-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathplotScript.py
72 lines (62 loc) · 1.92 KB
/
plotScript.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from matplotlib import pyplot as plt
import csv
# Run identifier (a Unix timestamp, per the variable name) selecting which
# training run's exported scalar CSVs to plot under Models/<unixTime>/tb_data/.
unixTime = 1679036461
acc_path = f"Models/{unixTime}/tb_data/epoch_accuracy"
loss_path = f"Models/{unixTime}/tb_data/epoch_loss"


def read_scalar_csv(path, digits=4):
    """Read one scalar series (e.g. ``train.csv``) from an exported CSV file.

    The first row is treated as a header and skipped; blank rows are ignored.
    Every other row contributes ``int(row[1])`` as the x value (plotted as the
    epoch) and ``float(row[2])`` rounded to *digits* places as the y value.
    The column layout presumably matches TensorBoard's "Wall time, Step,
    Value" CSV export -- confirm if the export format changes.

    Args:
        path: Path of the CSV file to read.
        digits: Decimal places kept for each value (default 4, as before).

    Returns:
        A ``(steps, values)`` tuple of two equal-length lists.

    Raises:
        OSError: If *path* cannot be opened.
        IndexError: If a non-blank data row has fewer than three columns.
        ValueError: If a step or value field is not numeric.
    """
    steps = []
    values = []
    # The csv module documentation requires files given to csv.reader to be
    # opened with newline="" so quoted fields and line endings parse correctly.
    with open(path, newline="", encoding="utf-8") as csvfile:
        reader = csv.reader(csvfile)
        # Skip the header row; the None default avoids StopIteration on an
        # empty file.
        next(reader, None)
        for row in reader:
            if not row:  # blank line, e.g. a trailing newline
                continue
            steps.append(int(row[1]))
            values.append(round(float(row[2]), digits))
    return steps, values


def plot_metric(figure_num, ylabel, train, validation):
    """Plot the train and validation curves of one metric against epoch.

    Args:
        figure_num: matplotlib figure number to draw into.
        ylabel: Label for the y axis (the metric name).
        train: ``(steps, values)`` pair for the training curve.
        validation: ``(steps, values)`` pair for the validation curve.
    """
    train_x, train_y = train
    val_x, val_y = validation
    plt.figure(figure_num)
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.plot(train_x, train_y, label="train", color="tab:orange")
    plt.plot(val_x, val_y, label="validation", color="c")
    plt.legend()


def main():
    """Load the run's accuracy and loss curves and display them."""
    # Read everything first so a missing file fails before any figure is made.
    train_acc = read_scalar_csv(f"{acc_path}/train.csv")
    val_acc = read_scalar_csv(f"{acc_path}/validation.csv")
    train_loss = read_scalar_csv(f"{loss_path}/train.csv")
    val_loss = read_scalar_csv(f"{loss_path}/validation.csv")

    plot_metric(0, "Accuracy", train_acc, val_acc)
    plot_metric(1, "Loss", train_loss, val_loss)
    plt.show()


if __name__ == "__main__":
    main()