-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathdatasets_util.py
32 lines (25 loc) · 1.1 KB
/
datasets_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from PIL import Image, ImageFile
# Global Pillow switch: decode truncated (incomplete) image files instead of
# raising an error when their pixel data is loaded. This is a module attribute
# of PIL.ImageFile, so it affects every image opened in this process, not only
# those opened by the helpers in this file.
ImageFile.LOAD_TRUNCATED_IMAGES = True
from torchvision import transforms
# Output spatial size (height, width) shared by the eval and train pipelines.
_IMAGE_SIZE = (384, 384)

# RandomResizedCrop may shrink the crop area down to (1 - this) of the image.
_MAX_CROP_SHRINK = 0.3


def _imagenet_normalize():
    """Return a new Normalize transform using the ImageNet RGB mean and std."""
    return transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )


# Evaluation pipeline: deterministic resize straight to the target size.
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(_IMAGE_SIZE),
        _imagenet_normalize(),
    ]
)

# Training pipeline: random crop (area scale in [0.7, 1]) resized to the
# target size, used as data augmentation.
transform_train = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.RandomResizedCrop(size=_IMAGE_SIZE, scale=(1 - _MAX_CROP_SHRINK, 1)),
        _imagenet_normalize(),
    ]
)
def open_image_and_apply_transform(image_path):
    """Open an image file and return it as a normalized evaluation tensor.

    The image is converted to 3-channel RGB before ``transform_test`` is
    applied. That pipeline ends in a ``Normalize`` with 3-element mean/std, so
    grayscale, palette, RGBA or CMYK files would otherwise yield a tensor whose
    channel count does not match, and normalization would fail.

    Args:
        image_path: Filesystem path of the image, passed to ``PIL.Image.open``.

    Returns:
        The result of ``transform_test`` (ToTensor -> Resize to 384x384 ->
        ImageNet normalization) applied to the RGB image.
    """
    # A context manager guarantees the file handle opened by Image.open is
    # released, even for multi-frame formats or when a transform raises.
    with Image.open(image_path) as pil_image:
        # convert() fully decodes and returns a new in-memory image, so it
        # stays usable after the source file is closed.
        rgb_image = pil_image.convert("RGB")
    return transform_test(rgb_image)
def open_image_and_apply_transform_train(image_path):
    """Open an image file and return a randomly augmented, normalized tensor.

    Unlike the evaluation loader, this applies ``transform_train``. Its
    ``RandomResizedCrop`` picks a random crop whose area is between 70% and
    100% of the image (``scale=(1 - 0.3, 1)``) and resizes it to 384x384, so
    repeated calls on the same file return different tensors. The image is
    first converted to 3-channel RGB to match the 3-element mean/std of the
    pipeline's ``Normalize`` step.

    Args:
        image_path: Filesystem path of the image, passed to ``PIL.Image.open``.

    Returns:
        The result of ``transform_train`` applied to the RGB image.
    """
    # Release the file handle deterministically (see Image.open / close docs).
    with Image.open(image_path) as pil_image:
        # Decode into a new RGB image that outlives the closed file.
        rgb_image = pil_image.convert("RGB")
    return transform_train(rgb_image)