+ import os
+ import random
+
+ import numpy as np
+ import PIL
+ from PIL import Image
+ from torch.utils.data import Dataset
+ from torchvision import transforms
+
+ imagenet_templates_smallest = [
+     'a photo of a {}',
+ ]
+
+ imagenet_templates_small = [
+     'a photo of a {}',
+     'a rendering of a {}',
+     'a cropped photo of the {}',
+     'the photo of a {}',
+     'a photo of a clean {}',
+     'a photo of a dirty {}',
+     'a dark photo of the {}',
+     'a photo of my {}',
+     'a photo of the cool {}',
+     'a close-up photo of a {}',
+     'a bright photo of the {}',
+     'a cropped photo of a {}',
+     'a photo of the {}',
+     'a good photo of the {}',
+     'a photo of one {}',
+     'a close-up photo of the {}',
+     'a rendition of the {}',
+     'a photo of the clean {}',
+     'a rendition of a {}',
+     'a photo of a nice {}',
+     'a good photo of a {}',
+     'a photo of the nice {}',
+     'a photo of the small {}',
+     'a photo of the weird {}',
+     'a photo of the large {}',
+     'a photo of a cool {}',
+     'a photo of a small {}',
+ ]
+
+ imagenet_dual_templates_small = [
+     'a photo of a {} with {}',
+     'a rendering of a {} with {}',
+     'a cropped photo of the {} with {}',
+     'the photo of a {} with {}',
+     'a photo of a clean {} with {}',
+     'a photo of a dirty {} with {}',
+     'a dark photo of the {} with {}',
+     'a photo of my {} with {}',
+     'a photo of the cool {} with {}',
+     'a close-up photo of a {} with {}',
+     'a bright photo of the {} with {}',
+     'a cropped photo of a {} with {}',
+     'a photo of the {} with {}',
+     'a good photo of the {} with {}',
+     'a photo of one {} with {}',
+     'a close-up photo of the {} with {}',
+     'a rendition of the {} with {}',
+     'a photo of the clean {} with {}',
+     'a rendition of a {} with {}',
+     'a photo of a nice {} with {}',
+     'a good photo of a {} with {}',
+     'a photo of the nice {} with {}',
+     'a photo of the small {} with {}',
+     'a photo of the weird {} with {}',
+     'a photo of the large {} with {}',
+     'a photo of a cool {} with {}',
+     'a photo of a small {} with {}',
+ ]
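+ # NOTE: the template lists above are not referenced elsewhere in this file;
+ # captions are built from image filenames in __getitem__ below.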
+
+ per_img_token_list = [
+     'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת',
+ ]
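+ # The single-character tokens above are candidates for per-image-token mode;
+ # __init__ below only checks that the training set is small enough to give
+ # each image its own token.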
+
+ class PersonalizedBase(Dataset):
+     def __init__(self,
+                  data_root,
+                  size=None,
+                  repeats=100,
+                  interpolation="bicubic",
+                  flip_p=0.5,
+                  set="train",
+                  placeholder_token="*",
+                  per_image_tokens=False,
+                  center_crop=False,
+                  mixing_prob=0.25,
+                  coarse_class_text=None,
+                  ):
+
+         self.data_root = data_root
+
+         self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
+
+         # self._length = len(self.image_paths)
+         self.num_images = len(self.image_paths)
+         self._length = self.num_images
+
+         self.placeholder_token = placeholder_token
+
+         self.per_image_tokens = per_image_tokens
+         self.center_crop = center_crop
+         self.mixing_prob = mixing_prob
+
+         self.coarse_class_text = coarse_class_text
+
+         if per_image_tokens:
+             assert self.num_images <= len(per_img_token_list), f"Can't use per-image tokens when the training set contains more images than the {len(per_img_token_list)} available tokens. To enable larger sets, add more tokens to 'per_img_token_list'."
+
+         if set == "train":
+             self._length = self.num_images * repeats
+
+         self.size = size
+         self.interpolation = {"linear": PIL.Image.LINEAR,
+                               "bilinear": PIL.Image.BILINEAR,
+                               "bicubic": PIL.Image.BICUBIC,
+                               "lanczos": PIL.Image.LANCZOS,
+                               }[interpolation]
+         self.flip = transforms.RandomHorizontalFlip(p=flip_p)
+
+     def __len__(self):
+         return self._length
+
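+     # Each example is a dict with "caption" (a filename-derived prompt ending in
+     # the placeholder token) and "image" (an HxWx3 float32 array scaled to [-1, 1]).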
+     def __getitem__(self, i):
+         example = {}
+         image = Image.open(self.image_paths[i % self.num_images])
+
+         placeholder_string = self.placeholder_token
+         if self.coarse_class_text:
+             placeholder_string = f"{self.coarse_class_text} {placeholder_string}"
+
+         # Flatten any alpha channel onto a white background before converting to RGB.
+         image = image.convert('RGBA')
+         new_image = Image.new('RGBA', image.size, 'WHITE')
+         new_image.paste(image, (0, 0), image)
+         image = new_image.convert('RGB')
+
+         templates = [
+             'a {} portrait of {}',
+             'an {} image of {}',
+             'a {} pretty picture of {}',
+             'a {} clip art picture of {}',
+             'an {} illustration of {}',
+             'a {} 3D render of {}',
+             'a {} {}',
+         ]
+
+         # Build the caption from words in the filename plus the placeholder string
+         # (which includes the optional coarse class text).
+         filename = os.path.basename(self.image_paths[i % self.num_images])
+         filename_tokens = os.path.splitext(filename)[0].replace(' ', '-').replace('_', '-').split('-')
+         filename_tokens = [token for token in filename_tokens if token.isalpha()]
+
+         text = random.choice(templates).format(' '.join(filename_tokens), placeholder_string)
+
+         example["caption"] = text
+
+         # default to score-sde preprocessing
+         img = np.array(image).astype(np.uint8)
+
+         if self.center_crop:
+             crop = min(img.shape[0], img.shape[1])
+             h, w = img.shape[0], img.shape[1]
+             img = img[(h - crop) // 2:(h + crop) // 2,
+                       (w - crop) // 2:(w + crop) // 2]
+
+         image = Image.fromarray(img)
+         if self.size is not None:
+             image = image.resize((self.size, self.size), resample=self.interpolation)
+
+         image = self.flip(image)
+         image = np.array(image).astype(np.uint8)
+         # Scale uint8 pixel values from [0, 255] to [-1, 1].
+         example["image"] = (image / 127.5 - 1.0).astype(np.float32)
+         return example
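+ # Usage sketch (illustrative values; the directory path is hypothetical):
+ #     dataset = PersonalizedBase("/path/to/images", size=512, set="train", repeats=100)
+ #     sample = dataset[0]
+ #     sample["caption"]  # e.g. "a sunset portrait of *" for an image named "sunset.png"
+ #     sample["image"]    # (512, 512, 3) float32 array in [-1, 1]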