-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathclassifier.py
288 lines (238 loc) · 11.7 KB
/
classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from utils import *
from sklearn.model_selection import train_test_split, cross_val_score
import PIL
PIL.Image.MAX_IMAGE_PIXELS = 886402639
from PIL import Image, ImageEnhance, ImageFilter
from transformers import AutoImageProcessor, AutoModelForImageClassification, TrainingArguments, Trainer, default_data_collator
from huggingface_hub import whoami
def custom_collate(batch: list) -> "dict | tuple[None, None]":
    """
    Collate dataset items into a batch, dropping items that failed to load.

    ``ImageDataset.__getitem__`` returns ``None`` for images it could not open
    or transform; those entries are removed before the remaining items are
    handed to ``default_data_collator``.

    Args:
        batch: List of dataset items (dicts with ``pixel_values`` and ``label``)
            or ``None`` placeholders.

    Returns:
        The mapping produced by ``default_data_collator`` for the surviving
        items, or the sentinel ``(None, None)`` when every item in the batch
        was ``None``. Callers must check for the sentinel before indexing.
    """
    valid_items = [item for item in batch if item is not None]
    if not valid_items:
        return None, None
    return default_data_collator(valid_items)
class ImageClassifier:
    """
    Fine-tune and run a Hugging Face image-classification model.

    Bundles an ``AutoImageProcessor`` / ``AutoModelForImageClassification``
    pair with torchvision train/eval transforms derived from the processor,
    and provides helpers for building DataLoaders, training with ``Trainer``,
    single-image and batched inference, and saving/loading/pushing checkpoints.
    """

    def __init__(self, checkpoint: str, num_labels: int, store_dir: str = "/lnet/work/people/lutsai/pythonProject/OCR/ltp-ocr/trans/chekcpoint"):
        """
        Initialize the image classifier with the specified checkpoint.

        Args:
            checkpoint: Model id or local path accepted by ``from_pretrained``.
            num_labels: Number of target classes. The model is loaded with
                ``ignore_mismatched_sizes=True`` so a checkpoint whose head has a
                different size can still be used.
            store_dir: ``cache_dir`` for the model and processor downloads.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Cache the processor in the same directory as the model weights.
        self.processor = AutoImageProcessor.from_pretrained(checkpoint, cache_dir=store_dir)
        self.model = AutoModelForImageClassification.from_pretrained(
            checkpoint,
            num_labels=num_labels,
            cache_dir=store_dir,
            ignore_mismatched_sizes=True,
        ).to(self.device)
        self._build_transforms()

    def _target_size(self) -> tuple:
        """
        Return the ``(height, width)`` that images are resized to.

        Reads ``self.processor.size``. An explicit ``height``/``width`` pair is
        used as-is. A ``shortest_edge`` spec, or a bare int, becomes a square so
        every sample in a batch has the same shape.

        Raises:
            ValueError: If the size specification matches none of those forms.
        """
        size = self.processor.size
        if isinstance(size, int):
            return size, size

        def lookup(key):
            # EAFP lookup that works for plain dicts and dict-like size objects.
            try:
                return size[key]
            except KeyError:
                return None

        height, width = lookup('height'), lookup('width')
        if height is not None and width is not None:
            return height, width
        edge = lookup('shortest_edge')
        if edge is not None:
            return edge, edge
        raise ValueError(f"Unsupported image processor size specification: {size!r}")

    def _build_transforms(self):
        """(Re)build ``train_transforms``/``eval_transforms`` from the current processor."""
        height, width = self._target_size()
        normalize = transforms.Normalize(mean=self.processor.image_mean, std=self.processor.image_std)
        self.train_transforms = transforms.Compose([
            # With probability 0.5 the *whole* list below is applied in order;
            # otherwise none of it is. It is not a random pick of one op.
            transforms.RandomApply([
                transforms.ColorJitter(brightness=0.5),
                transforms.ColorJitter(contrast=0.5),
                transforms.ColorJitter(saturation=0.5),
                transforms.ColorJitter(hue=0.5),
                transforms.Lambda(lambda img: ImageEnhance.Sharpness(img).enhance(random.uniform(0.5, 1.5))),
                transforms.Lambda(lambda img: img.filter(ImageFilter.GaussianBlur(radius=random.uniform(0, 2)))),
            ], p=0.5),
            transforms.Resize((height, width)),
            transforms.ToTensor(),
            normalize,
        ])
        self.eval_transforms = transforms.Compose([
            transforms.Resize((height, width)),
            transforms.ToTensor(),
            normalize,
        ])

    def process_images(self, image_paths: list, image_labels: list, batch_size: int, train: bool = True) -> DataLoader:
        """
        Wrap labeled image paths in a DataLoader.

        Args:
            image_paths: Image file paths.
            image_labels: Labels aligned by position with ``image_paths``.
            batch_size: Samples per batch.
            train: If True, use the augmenting train transforms and shuffle;
                otherwise use the eval transforms and keep the input order.

        Returns:
            A DataLoader that collates with ``custom_collate``.
        """
        transform = self.train_transforms if train else self.eval_transforms
        dataset = ImageDataset(image_paths, image_labels, transform)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=train, collate_fn=custom_collate)
        print(f"Dataloader of {'train' if train else 'eval'} dataset is ready:\t{len(image_paths)} images split into {len(dataloader)} batches of size {batch_size}")
        return dataloader

    def preprocess_image(self, image_path: str, train: bool = True) -> torch.Tensor:
        """
        Load one image and turn it into a model-ready batch of size 1.

        Args:
            image_path: Path to the image file.
            train: Select the train (augmenting) or eval transforms.

        Returns:
            Tensor with a leading batch dimension of 1, on ``self.device``.
        """
        # The context manager closes the file handle deterministically.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        transform = self.train_transforms if train else self.eval_transforms
        return transform(image).unsqueeze(0).to(self.device)

    def train_model(self, train_dataloader, eval_dataloader, output_dir: str, num_epochs: int = 3, learning_rate: float = 5e-5, logging_steps: int = 10):
        """
        Fine-tune ``self.model`` with ``transformers.Trainer``.

        Only the datasets and batch sizes of the given DataLoaders are passed to
        Trainer. Evaluation and checkpointing run every epoch. The checkpoint
        with the best ``accuracy`` is restored at the end
        (``load_best_model_at_end=True``).

        Args:
            train_dataloader: Loader whose ``dataset``/``batch_size`` are used for training.
            eval_dataloader: Loader whose ``dataset``/``batch_size`` are used for evaluation.
            output_dir: Trainer output directory for checkpoints and logs.
            num_epochs: Number of training epochs.
            learning_rate: Peak learning rate (10% linear warmup).
            logging_steps: Log every this many optimization steps.
        """
        print(f"Training for {num_epochs} epochs on {len(train_dataloader.dataset)} train samples and evaluation on {len(eval_dataloader.dataset)} samples")
        training_args = TrainingArguments(
            output_dir=output_dir,
            eval_strategy="epoch",
            save_strategy="epoch",
            learning_rate=learning_rate,
            per_device_train_batch_size=train_dataloader.batch_size,
            per_device_eval_batch_size=eval_dataloader.batch_size,
            num_train_epochs=num_epochs,
            warmup_ratio=0.1,
            logging_steps=logging_steps,
            load_best_model_at_end=True,
            metric_for_best_model="accuracy",
            push_to_hub=False,
        )
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataloader.dataset,
            eval_dataset=eval_dataloader.dataset,
            data_collator=custom_collate,
            compute_metrics=self.compute_metrics,
        )
        trainer.train()
        # NOTE: this path is relative to the working directory, does not use
        # output_dir, and is keyed by the number of train *batches*.
        self.save_model(f"model/model_{len(train_dataloader)}_{num_epochs}")

    @staticmethod
    def _normalized_top_n(probabilities: torch.Tensor, top_n: int) -> tuple:
        """
        Select the ``top_n`` classes per row and renormalize their probabilities.

        ``top_n`` is clamped to the number of classes so ``torch.topk`` cannot fail.

        Returns:
            ``(indices, probs)`` tensors of shape ``(rows, k)``. ``probs`` sums to 1 per row.
        """
        k = min(top_n, probabilities.shape[-1])
        top_probs, top_indices = torch.topk(probabilities, k, dim=-1)
        return top_indices, top_probs / top_probs.sum(dim=-1, keepdim=True)

    def infer(self, image_path: str) -> int:
        """
        Predict the class index of a single image.

        Args:
            image_path: Path to the image file.

        Returns:
            Index of the highest-scoring class.
        """
        self.model.eval()
        with torch.no_grad():
            pixel_values = self.preprocess_image(image_path, train=False)
            logits = self.model(pixel_values=pixel_values).logits
        return logits.argmax(-1).item()

    def top_n_predictions(self, image_path: str, top_n: int = 1) -> list:
        """
        Return the ``top_n`` most likely classes for one image.

        Probabilities are renormalized over the returned classes, so they sum
        to 1. A ``top_n`` larger than the number of classes is clamped.

        Args:
            image_path: Path to the image file.
            top_n: Number of classes to return.

        Returns:
            List of ``(class_index, probability)`` tuples, most likely first.
        """
        self.model.eval()
        with torch.no_grad():
            pixel_values = self.preprocess_image(image_path, train=False)
            logits = self.model(pixel_values=pixel_values).logits
            probabilities = torch.nn.functional.softmax(logits, dim=-1)
            top_indices, top_probs = self._normalized_top_n(probabilities, top_n)
        # Take row 0 of the size-1 batch instead of squeeze(). With top_n == 1,
        # squeeze() yields 0-d tensors whose tolist() is a bare number, and zip() fails.
        return list(zip(top_indices[0].tolist(), top_probs[0].tolist()))

    def create_dataloader(self, image_paths: list, batch_size: int) -> DataLoader:
        """
        Turn an input list of image paths into an unlabeled, ordered DataLoader.

        Args:
            image_paths: Image file paths.
            batch_size: Samples per batch.

        Returns:
            A non-shuffled DataLoader using the eval transforms and ``custom_collate``.
        """
        dataset = ImageDataset(image_paths, transform=self.eval_transforms)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate)
        print(f"Dataloader of directory dataset is ready:\t{len(image_paths)} images split into {len(dataloader)} batches of size {batch_size}")
        return dataloader

    def infer_dataloader(self, dataloader, top_n: int, raw: bool = False) -> tuple:
        """
        Run inference over every batch of a DataLoader.

        Images that failed to load are dropped by ``custom_collate``. Batches in
        which every image failed are skipped. Consequently ``predictions`` may be
        shorter than the dataset and not aligned with its paths.

        Args:
            dataloader: Iterable of batches produced by ``custom_collate``.
            top_n: If > 1, return per-image lists of ``(class_index, probability)``
                (renormalized over the top classes); otherwise one class index per image.
            raw: Also return the full softmax probability vector of each image.

        Returns:
            ``(predictions, raw_scores)``. ``raw_scores`` is None unless ``raw`` is True.
        """
        self.model.eval()
        predictions = []
        raw_scores = []
        with torch.no_grad():
            for batch in dataloader:
                # custom_collate returns the (None, None) sentinel for all-failed batches.
                if batch is None or isinstance(batch, tuple):
                    continue
                logits = self.model(pixel_values=batch['pixel_values'].to(self.device)).logits
                probabilities = torch.nn.functional.softmax(logits, dim=-1)
                if raw:
                    raw_scores.extend(probabilities.tolist())
                if top_n > 1:
                    top_indices, top_probs = self._normalized_top_n(probabilities, top_n)
                    for indices, probs in zip(top_indices.tolist(), top_probs.tolist()):
                        predictions.append(list(zip(indices, probs)))
                else:
                    predictions.extend(logits.argmax(-1).tolist())
        print(f"Processed {len(predictions)} images")
        return predictions, (raw_scores if raw else None)

    def save_model(self, save_directory: str):
        """
        Save the fine-tuned model and processor to ``save_directory``.

        The directory, including missing parents, is created if needed.
        """
        os.makedirs(save_directory, exist_ok=True)
        self.model.save_pretrained(save_directory)
        self.processor.save_pretrained(save_directory)
        print(f"Model and processor saved to {save_directory}")

    def load_model(self, load_directory: str):
        """
        Load a fine-tuned model and processor from ``load_directory``.

        The model is moved to ``self.device``. The train/eval transforms are
        rebuilt so resizing and normalization match the loaded processor.
        """
        self.processor = AutoImageProcessor.from_pretrained(load_directory)
        self.model = AutoModelForImageClassification.from_pretrained(load_directory).to(self.device)
        self._build_transforms()
        print(f"Model and processor loaded from {load_directory}")

    def push_to_hub(self, load_directory: str, repo_id: str, private: bool = False,
                    token: str = None):
        """
        Save the model and processor locally, then upload them to the Hugging Face Hub.

        Args:
            load_directory: Local directory the model and processor are written
                to with ``save_pretrained`` before uploading.
            repo_id: Hub repository id (e.g. ``"<user>/<repo>"``) to create or update.
            private: Forwarded as ``private`` to ``push_to_hub``.
            token: Optional Hub token forwarded to ``push_to_hub``.
        """
        self.model.save_pretrained(load_directory)
        self.processor.save_pretrained(load_directory)
        self.model.push_to_hub(repo_id, private=private, token=token)
        self.processor.push_to_hub(repo_id, private=private, token=token)
        print(f"Model and processor pushed to the Hugging Face Hub: {repo_id}")

    def load_from_hub(self, repo_id: str):
        """
        Replace the model and processor with ones from the Hugging Face Hub.

        The model is moved to ``self.device``, matching ``load_model``, so that
        it lives on the same device as the inputs from ``preprocess_image``.
        The transforms are rebuilt from the new processor. Results are stored
        on ``self.model`` / ``self.processor``; nothing is returned.

        Args:
            repo_id: Repository id on the Hugging Face Hub.
        """
        model = AutoModelForImageClassification.from_pretrained(repo_id).to(self.device)
        processor = AutoImageProcessor.from_pretrained(repo_id)
        self.model, self.processor = model, processor
        self._build_transforms()
        print(f"Model and processor loaded from the Hugging Face Hub: {repo_id}")

    @staticmethod
    def compute_metrics(eval_pred) -> dict:
        """
        Compute accuracy for ``Trainer`` evaluation.

        Args:
            eval_pred: ``(logits, labels)`` pair, e.g. a Trainer ``EvalPrediction``.
                ``labels`` may be class indices with shape ``(N,)``, or
                one-hot/probability targets with shape ``(N, C)``; the latter
                are reduced with argmax.

        Returns:
            ``{"accuracy": float}``, the fraction of correct predictions.
        """
        import numpy as np
        logits, labels = eval_pred
        if isinstance(logits, tuple):
            # Trainer may pass a tuple when the model returns several outputs;
            # the logits come first.
            logits = logits[0]
        predictions = np.argmax(logits, axis=-1)
        labels = np.asarray(labels)
        if labels.ndim > 1:
            # Only 2-D targets need reducing; argmax over 1-D indices would collapse them to a scalar.
            labels = np.argmax(labels, axis=-1)
        return {"accuracy": float(np.mean(predictions == labels))}
class ImageDataset(Dataset):
    """
    Map-style dataset that loads images from disk for ``ImageClassifier``.

    Each item is a dict ``{'pixel_values': ..., 'label': ...}``. ``label`` is
    ``None`` when the dataset was built without labels. Any exception while
    loading or transforming an item is printed and the item is returned as
    ``None``, so ``custom_collate`` can drop it instead of aborting the batch.
    """

    def __init__(self, image_paths: list, image_labels: list = None, transform=None):
        """
        Args:
            image_paths: Image file paths, indexed by position.
            image_labels: Optional labels aligned by position with ``image_paths``;
                ``None`` for unlabeled data.
            transform: Optional callable applied to the RGB PIL image.
        """
        self.image_paths = image_paths
        self.image_labels = image_labels
        self.transform = transform
        # True when labels were supplied; unlabeled items carry label=None.
        self.known = image_labels is not None

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        try:
            # Close the file handle deterministically instead of leaving it to Pillow/GC.
            with Image.open(image_path) as img:
                image = img.convert('RGB')
            if self.transform:
                image = self.transform(image)
            return {'pixel_values': image, 'label': self.image_labels[idx] if self.known else None}
        except Exception as e:
            # Deliberate best-effort: report the bad file and let the collator skip it.
            print(image_path, e)
            return None