Skip to content

Commit bf8cef6

Browse files
committed
Minor Multiple class refactor
1 parent 82f70b0 commit bf8cef6

File tree

5 files changed

+15
-27
lines changed

5 files changed

+15
-27
lines changed

abraia/multiple.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,7 @@ def load_metadata(self, path):
4444

4545
def load_envi(self, path):
4646
dest = self.cache_file(path)
47-
raw = f"{dest.split('.')[0]}.raw"
48-
if not os.path.exists(raw):
49-
self.download_file(f"{path.split('.')[0]}.raw", raw)
47+
raw = self.cache_file(f"{path.split('.')[0]}.raw")
5048
return np.array(spectral.io.envi.open(dest, raw)[:, :, :])
5149

5250
def load_mat(self, path):

abraia/torch.py

+11-21
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,13 @@
2222
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
2323

2424

25+
# TODO: Remove with next version
2526
def download_file(path):
26-
dest = os.path.join(tempdir, path)
27-
if not os.path.exists(dest):
28-
os.makedirs(os.path.dirname(dest), exist_ok=True)
29-
multiple.download_file(path, dest)
30-
return dest
27+
return multiple.cache_file(path)
3128

3229

3330
def read_image(path):
34-
dest = download_file(path)
31+
dest = multiple.cache_file(path)
3532
return Image.open(dest).convert('RGB')
3633

3734

@@ -81,7 +78,7 @@ def save_model(path, model, device='cpu'):
8178

8279

8380
def load_model(path, class_names):
84-
dest = download_file(path)
81+
dest = multiple.cache_file(path)
8582
model = create_model(class_names, pretrained=False)
8683
model.load_state_dict(torch.load(dest))
8784
return model
@@ -98,12 +95,14 @@ def export_onnx(path, model, device='cpu'):
9895
multiple.upload_file(src, path)
9996

10097

98+
# TODO: Remove with next version
10199
def save_json(path, values):
102-
multiple.save_file(path, json.dumps(values))
100+
multiple.save_json(path, values)
103101

104102

103+
# TODO: Remove with next version
105104
def load_json(path):
106-
return json.loads(multiple.load_file(path))
105+
return multiple.load_json(path)
107106

108107

109108
transform = transforms.Compose([
@@ -112,8 +111,7 @@ def load_json(path):
112111
transforms.ToTensor(),
113112
transforms.Normalize(
114113
mean=[0.485, 0.456, 0.406],
115-
std=[0.229, 0.224, 0.225]
116-
)
114+
std=[0.229, 0.224, 0.225])
117115
])
118116

119117

@@ -145,27 +143,22 @@ def train_model(model, dataloaders, criterion=None, optimizer=None, scheduler=No
145143

146144
running_loss = 0.0
147145
running_corrects = 0
148-
149-
# Iterate over data.
146+
# Iterate over data
150147
for inputs, labels in dataloaders[phase]:
151148
inputs = inputs.to(device)
152149
labels = labels.to(device)
153-
154150
# zero the parameter gradients
155151
optimizer.zero_grad()
156-
157152
# forward
158153
# track history if only in train
159154
with torch.set_grad_enabled(phase == 'train'):
160155
outputs = model(inputs)
161156
_, preds = torch.max(outputs, 1)
162157
loss = criterion(outputs, labels)
163-
164158
# backward + optimize only if in training phase
165159
if phase == 'train':
166160
loss.backward()
167161
optimizer.step()
168-
169162
# statistics
170163
running_loss += loss.item() * inputs.size(0)
171164
running_corrects += torch.sum(preds == labels.data)
@@ -174,20 +167,17 @@ def train_model(model, dataloaders, criterion=None, optimizer=None, scheduler=No
174167

175168
epoch_loss = running_loss / len(dataloaders[phase].dataset)
176169
epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
177-
178170
print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
179171

180172
# deep copy the model
181173
if phase == 'val' and epoch_acc > best_acc:
182174
best_acc = epoch_acc
183175
best_model_wts = copy.deepcopy(model.state_dict())
184-
176+
185177
print()
186-
187178
time_elapsed = time.time() - since
188179
print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
189180
print(f'Best val Acc: {best_acc:4f}')
190-
191181
# load best model weights
192182
model.load_state_dict(best_model_wts)
193183
return model

images/screenshot.jpg

-3.04 KB
Loading

scripts/abraia

+2-2
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def input_files(src):
4242

4343

4444
@click.group('abraia')
45-
@click.version_option('0.13.2')
45+
@click.version_option('0.13.3')
4646
def cli():
4747
"""Abraia CLI tool"""
4848
pass
@@ -64,7 +64,7 @@ def configure():
6464
@cli.command()
6565
def info():
6666
"""Show user account information"""
67-
click.echo('abraia, version 0.13.2\n')
67+
click.echo('abraia, version 0.13.3\n')
6868
click.echo('Go to [' + click.style('https://abraia.me/console/', fg='green') + '] to see your account information\n')
6969

7070

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
setup(
1414
name='abraia',
15-
version='0.13.2',
15+
version='0.13.3',
1616
description='Abraia Multiple SDK',
1717
long_description=long_description,
1818
long_description_content_type='text/markdown',

0 commit comments

Comments
 (0)