Skip to content

Commit b59dbb2

Browse files
authored
Merge pull request #3 from ChujieChen/master
added an example python file for polarity LSTM
2 parents c5dce27 + d88b096 commit b59dbb2

File tree

1 file changed

+115
-0
lines changed

1 file changed

+115
-0
lines changed
+115
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import datetime
2+
import torch
3+
from torch.nn import CrossEntropyLoss
4+
from torch.utils.data import DataLoader
5+
from torch.utils.data import random_split
6+
7+
import yews.datasets as dsets
8+
import yews.transforms as transforms
9+
from yews.train import Trainer
10+
11+
#from yews.models import cpic
12+
#from yews.models import cpic_v1
13+
#from yews.models import cpic_v2
14+
#cpic = cpic_v1
15+
16+
from yews.models import polarity_v1
17+
from yews.models import polarity_v2
18+
from yews.models import polarity_lstm
19+
polarity=polarity_lstm
20+
21+
22+
if __name__ == '__main__':
23+
24+
print("Now: start : " + str(datetime.datetime.now()))
25+
26+
# Preprocessing
27+
waveform_transform = transforms.Compose([
28+
transforms.ZeroMean(),
29+
#transforms.SoftClip(1e-4),
30+
transforms.ToTensor(),
31+
])
32+
33+
# Prepare dataset
34+
dsets.set_memory_limit(10 * 1024 ** 3) # first number is GB
35+
# dset = dsets.Wenchuan(path='/home/qszhai/temp_project/deep_learning_course_project/cpic', download=False,sample_transform=waveform_transform)
36+
dset = dsets.Wenchuan(path='/home/qszhai/temp_project/deep_learning_course_project/first_motion_polarity/scsn_data/train_npy', download=False, sample_transform=waveform_transform)
37+
38+
# Split datasets into training and validation
39+
train_length = int(len(dset) * 0.8)
40+
val_length = len(dset) - train_length
41+
train_set, val_set = random_split(dset, [train_length, val_length])
42+
43+
# Prepare dataloaders
44+
train_loader = DataLoader(train_set, batch_size=5000, shuffle=True, num_workers=4)
45+
val_loader = DataLoader(val_set, batch_size=10000, shuffle=False, num_workers=4)
46+
47+
# Prepare trainer
48+
#trainer = Trainer(cpic(), CrossEntropyLoss(), lr=0.1)
49+
model_conf = {"hidden_size": 64}
50+
plt = polarity(**model_conf)
51+
trainer = Trainer(plt, CrossEntropyLoss(), lr=0.001)
52+
53+
# Train model over training dataset
54+
trainer.train(train_loader, val_loader, epochs=50, print_freq=100)
55+
#resume='checkpoint_best.pth.tar')
56+
57+
# Save training results to disk
58+
trainer.results(path='scsn_polarity_results.pth.tar')
59+
60+
# Validate saved model
61+
results = torch.load('scsn_polarity_results.pth.tar')
62+
#model = cpic()
63+
model = plt
64+
model.load_state_dict(results['model'])
65+
trainer = Trainer(model, CrossEntropyLoss(), lr=0.001)
66+
trainer.validate(val_loader, print_freq=100)
67+
68+
print("Now: end : " + str(datetime.datetime.now()))
69+
70+
import matplotlib.pyplot as plt
71+
import numpy as np
72+
73+
myfontsize1=14
74+
myfontsize2=18
75+
myfontsize3=24
76+
77+
results = torch.load('scsn_polarity_results.pth.tar')
78+
79+
fig, axes = plt.subplots(2, 1, num=0, figsize=(6, 4), sharex=True)
80+
axes[0].plot(results['val_acc'], label='Validation')
81+
axes[0].plot(results['train_acc'], label='Training')
82+
83+
#axes[1].set_xlabel("Epochs",fontsize=myfontsize2)
84+
axes[0].set_xscale('log')
85+
axes[0].set_xlim([1, 100])
86+
axes[0].xaxis.set_tick_params(labelsize=myfontsize1)
87+
88+
axes[0].set_ylabel("Accuracies (%)",fontsize=myfontsize2)
89+
axes[0].set_ylim([0, 100])
90+
axes[0].set_yticks(np.arange(0, 101, 10))
91+
axes[0].yaxis.set_tick_params(labelsize=myfontsize1)
92+
93+
axes[0].grid(True, 'both')
94+
axes[0].legend(loc=4)
95+
96+
#axes[1].semilogx(results['val_loss'], label='Validation')
97+
#axes[1].semilogx(results['train_loss'], label='Training')
98+
axes[1].plot(results['val_loss'], label='Validation')
99+
axes[1].plot(results['train_loss'], label='Training')
100+
101+
axes[1].set_xlabel("Epochs",fontsize=myfontsize2)
102+
axes[1].set_xscale('log')
103+
axes[1].set_xlim([1, 100])
104+
axes[1].xaxis.set_tick_params(labelsize=myfontsize1)
105+
106+
axes[1].set_ylabel("Losses",fontsize=myfontsize2)
107+
axes[1].set_ylim([0.0, 1.0])
108+
axes[1].set_yticks(np.arange(0.0,1.01,0.2))
109+
axes[1].yaxis.set_tick_params(labelsize=myfontsize1)
110+
111+
axes[1].grid(True, 'both')
112+
axes[1].legend(loc=1)
113+
114+
fig.tight_layout()
115+
plt.savefig('Accuracies_train_val.pdf')

0 commit comments

Comments
 (0)