Skip to content

Commit c3560d8

Browse files
authored
Add files via upload
1 parent 7b62670 commit c3560d8

File tree

2 files changed

+316
-0
lines changed

2 files changed

+316
-0
lines changed

ldm/data/personalized.py

+173
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
import os
2+
import numpy as np
3+
import PIL
4+
from PIL import Image
5+
from torch.utils.data import Dataset
6+
from torchvision import transforms
7+
8+
import random
9+
10+
# Minimal caption template set: a single prompt with one `{}` slot.
imagenet_templates_smallest = [
    "a photo of a {}",
]
13+
14+
# Single-slot caption templates; each string has exactly one `{}` to be
# filled through str.format().
imagenet_templates_small = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]
43+
44+
# Two-slot caption templates; each string has two `{}` fields joined by
# " with ", to be filled through str.format().
imagenet_dual_templates_small = [
    "a photo of a {} with {}",
    "a rendering of a {} with {}",
    "a cropped photo of the {} with {}",
    "the photo of a {} with {}",
    "a photo of a clean {} with {}",
    "a photo of a dirty {} with {}",
    "a dark photo of the {} with {}",
    "a photo of my {} with {}",
    "a photo of the cool {} with {}",
    "a close-up photo of a {} with {}",
    "a bright photo of the {} with {}",
    "a cropped photo of a {} with {}",
    "a photo of the {} with {}",
    "a good photo of the {} with {}",
    "a photo of one {} with {}",
    "a close-up photo of the {} with {}",
    "a rendition of the {} with {}",
    "a photo of the clean {} with {}",
    "a rendition of a {} with {}",
    "a photo of a nice {} with {}",
    "a good photo of a {} with {}",
    "a photo of the nice {} with {}",
    "a photo of the small {} with {}",
    "a photo of the weird {} with {}",
    "a photo of the large {} with {}",
    "a photo of a cool {} with {}",
    "a photo of a small {} with {}",
]
73+
74+
# Pool of single-character tokens (Hebrew letters). PersonalizedBase compares
# the number of images against the size of this list when `per_image_tokens`
# is enabled.
per_img_token_list = [
    "א",
    "ב",
    "ג",
    "ד",
    "ה",
    "ו",
    "ז",
    "ח",
    "ט",
    "י",
    "כ",
    "ל",
    "מ",
    "נ",
    "ס",
    "ע",
    "פ",
    "צ",
    "ק",
    "ר",
    "ש",
    "ת",
]
77+
78+
class PersonalizedBase(Dataset):
    """Image-folder dataset whose captions are derived from the file names.

    Every regular file directly inside ``data_root`` is treated as an image.
    Indexing returns a dict with two entries:

    * ``"caption"``: a randomly chosen template filled with the alphabetic
      words of the file name and the placeholder token (prefixed with
      ``coarse_class_text`` when that is set).
    * ``"image"``: the RGB image, flattened onto a white background, as a
      float32 array of shape ``(H, W, 3)`` scaled to ``[-1, 1]``.

    Args:
        data_root: Directory containing the images.
        size: If not ``None``, every image is resized to ``size x size``.
        repeats: When ``set == "train"`` the reported length is
            ``num_images * repeats``; indices wrap around the image list.
        interpolation: One of ``"linear"``, ``"bilinear"``, ``"bicubic"``,
            ``"lanczos"``. Any other value raises ``KeyError``.
        flip_p: Probability of a random horizontal flip.
        set: Split name; only ``"train"`` changes behaviour (see ``repeats``).
        placeholder_token: Text substituted into the second template slot.
        per_image_tokens: When true, the constructor checks that there are
            fewer images than entries in ``per_img_token_list``. It is stored
            but not otherwise used by this class.
        center_crop: When true, crop the largest centred square before
            resizing.
        mixing_prob: Stored on the instance; not used by this class.
        coarse_class_text: Optional text placed in front of the placeholder
            token in the caption.
    """

    # Caption templates: slot 1 receives the words recovered from the file
    # name, slot 2 the placeholder string. Defined once instead of being
    # rebuilt on every __getitem__ call.
    _CAPTION_TEMPLATES = (
        'a {} portrait of {}',
        'an {} image of {}',
        'a {} pretty picture of {}',
        'a {} clip art picture of {}',
        'an {} illustration of {}',
        'a {} 3D render of {}',
        'a {} {}',
    )

    def __init__(self,
                 data_root,
                 size=None,
                 repeats=100,
                 interpolation="bicubic",
                 flip_p=0.5,
                 set="train",
                 placeholder_token="*",
                 per_image_tokens=False,
                 center_crop=False,
                 mixing_prob=0.25,
                 coarse_class_text=None,
                 ):
        self.data_root = data_root

        # os.listdir() returns entries in arbitrary order, so sort to make the
        # index -> image mapping reproducible. Sub-directories are skipped
        # because Image.open() cannot read them.
        candidates = (
            os.path.join(self.data_root, name)
            for name in sorted(os.listdir(self.data_root))
        )
        self.image_paths = [path for path in candidates if os.path.isfile(path)]

        self.num_images = len(self.image_paths)
        self._length = self.num_images

        self.placeholder_token = placeholder_token

        self.per_image_tokens = per_image_tokens
        self.center_crop = center_crop
        self.mixing_prob = mixing_prob

        self.coarse_class_text = coarse_class_text

        if per_image_tokens:
            assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."

        if set == "train":
            self._length = self.num_images * repeats

        self.size = size
        # Image.LINEAR was a legacy alias of Image.BILINEAR and was removed in
        # Pillow 10. Referencing it while building this dict made the
        # constructor fail for every interpolation choice on current Pillow.
        resampling_filters = {
            "linear": PIL.Image.BILINEAR,
            "bilinear": PIL.Image.BILINEAR,
            "bicubic": PIL.Image.BICUBIC,
            "lanczos": PIL.Image.LANCZOS,
        }
        self.interpolation = resampling_filters[interpolation]
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)

    def __len__(self):
        """Return the number of items (image count, times ``repeats`` for train)."""
        return self._length

    def _caption_for(self, path):
        """Build a caption for the image at ``path`` from its file name.

        The file stem is split on spaces, underscores and hyphens; only purely
        alphabetic pieces are kept (so numeric suffixes are dropped).
        """
        stem = os.path.splitext(os.path.basename(path))[0]
        pieces = stem.replace(' ', '-').replace('_', '-').split('-')
        words = [piece for piece in pieces if piece.isalpha()]

        # Honour coarse_class_text: previously the combined string was computed
        # but the bare placeholder token was formatted into the caption.
        placeholder_string = self.placeholder_token
        if self.coarse_class_text:
            placeholder_string = f"{self.coarse_class_text} {placeholder_string}"

        template = random.choice(self._CAPTION_TEMPLATES)
        return template.format(' '.join(words), placeholder_string)

    def __getitem__(self, i):
        """Return ``{"caption": str, "image": float32 array}`` for index ``i``."""
        path = self.image_paths[i % self.num_images]
        example = {"caption": self._caption_for(path)}

        # convert() yields an independent, fully loaded image, so the file
        # handle can be released as soon as the conversion is done.
        with Image.open(path) as source:
            rgba = source.convert('RGBA')

        # Composite onto white so transparent regions do not turn black when
        # the alpha channel is dropped.
        canvas = Image.new('RGBA', rgba.size, 'WHITE')
        canvas.paste(rgba, (0, 0), rgba)
        image = canvas.convert('RGB')

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)

        if self.center_crop:
            h, w = img.shape[0], img.shape[1]
            crop = min(h, w)
            img = img[(h - crop) // 2:(h + crop) // 2,
                      (w - crop) // 2:(w + crop) // 2]

        image = Image.fromarray(img)
        if self.size is not None:
            image = image.resize((self.size, self.size), resample=self.interpolation)

        image = self.flip(image)
        image = np.array(image).astype(np.uint8)
        # Map uint8 [0, 255] to float32 [-1, 1].
        example["image"] = (image / 127.5 - 1.0).astype(np.float32)
        return example

ldm/data/personalized_style.py

+143
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
import os
2+
import numpy as np
3+
import PIL
4+
from PIL import Image
5+
from torch.utils.data import Dataset
6+
from torchvision import transforms
7+
8+
import random
9+
10+
# Single-slot style caption templates; each string has exactly one `{}` to be
# filled through str.format(). Two entries appear twice; the repeats are kept
# because they change how often random.choice() would pick them.
imagenet_templates_small = [
    "a painting in the style of {}",
    "a rendering in the style of {}",
    "a cropped painting in the style of {}",
    "the painting in the style of {}",
    "a clean painting in the style of {}",
    "a dirty painting in the style of {}",
    "a dark painting in the style of {}",
    "a picture in the style of {}",
    "a cool painting in the style of {}",
    "a close-up painting in the style of {}",
    "a bright painting in the style of {}",
    "a cropped painting in the style of {}",
    "a good painting in the style of {}",
    "a close-up painting in the style of {}",
    "a rendition in the style of {}",
    "a nice painting in the style of {}",
    "a small painting in the style of {}",
    "a weird painting in the style of {}",
    "a large painting in the style of {}",
]
31+
32+
# Two-slot style caption templates, each with two `{}` fields for
# str.format(). NOTE(review): the "painting of one {} in the style of {}"
# entry orders its slots differently from the "... style of {} with {}"
# entries - confirm that is intended before using positional formatting.
imagenet_dual_templates_small = [
    "a painting in the style of {} with {}",
    "a rendering in the style of {} with {}",
    "a cropped painting in the style of {} with {}",
    "the painting in the style of {} with {}",
    "a clean painting in the style of {} with {}",
    "a dirty painting in the style of {} with {}",
    "a dark painting in the style of {} with {}",
    "a cool painting in the style of {} with {}",
    "a close-up painting in the style of {} with {}",
    "a bright painting in the style of {} with {}",
    "a cropped painting in the style of {} with {}",
    "a good painting in the style of {} with {}",
    "a painting of one {} in the style of {}",
    "a nice painting in the style of {} with {}",
    "a small painting in the style of {} with {}",
    "a weird painting in the style of {} with {}",
    "a large painting in the style of {} with {}",
]
51+
52+
# Pool of single-character tokens (Hebrew letters). PersonalizedBase compares
# the number of images against the size of this list when `per_image_tokens`
# is enabled.
per_img_token_list = [
    "א",
    "ב",
    "ג",
    "ד",
    "ה",
    "ו",
    "ז",
    "ח",
    "ט",
    "י",
    "כ",
    "ל",
    "מ",
    "נ",
    "ס",
    "ע",
    "פ",
    "צ",
    "ק",
    "ר",
    "ש",
    "ת",
]
55+
56+
class PersonalizedBase(Dataset):
    """Image-folder dataset whose captions are derived from the file names.

    Every regular file directly inside ``data_root`` is treated as an image.
    Indexing returns a dict with two entries:

    * ``"caption"``: a randomly chosen template filled with the alphabetic
      words of the file name and the placeholder token.
    * ``"image"``: the RGB image, flattened onto a white background, as a
      float32 array of shape ``(H, W, 3)`` scaled to ``[-1, 1]``.

    Args:
        data_root: Directory containing the images.
        size: If not ``None``, every image is resized to ``size x size``.
        repeats: When ``set == "train"`` the reported length is
            ``num_images * repeats``; indices wrap around the image list.
        interpolation: One of ``"linear"``, ``"bilinear"``, ``"bicubic"``,
            ``"lanczos"``. Any other value raises ``KeyError``.
        flip_p: Probability of a random horizontal flip.
        set: Split name; only ``"train"`` changes behaviour (see ``repeats``).
        placeholder_token: Text substituted into the second template slot.
        per_image_tokens: When true, the constructor checks that there are
            fewer images than entries in ``per_img_token_list``. It is stored
            but not otherwise used by this class.
        center_crop: When true, crop the largest centred square before
            resizing.
    """

    # Caption templates: slot 1 receives the words recovered from the file
    # name, slot 2 the placeholder token. Defined once instead of being
    # rebuilt on every __getitem__ call.
    _CAPTION_TEMPLATES = (
        'a {} portrait of {}',
        'an {} image of {}',
        'a {} pretty picture of {}',
        'a {} clip art picture of {}',
        'an {} illustration of {}',
        'a {} 3D render of {}',
        'a {} {}',
    )

    def __init__(self,
                 data_root,
                 size=None,
                 repeats=100,
                 interpolation="bicubic",
                 flip_p=0.5,
                 set="train",
                 placeholder_token="*",
                 per_image_tokens=False,
                 center_crop=False,
                 ):
        self.data_root = data_root

        # os.listdir() returns entries in arbitrary order, so sort to make the
        # index -> image mapping reproducible. Sub-directories are skipped
        # because Image.open() cannot read them.
        candidates = (
            os.path.join(self.data_root, name)
            for name in sorted(os.listdir(self.data_root))
        )
        self.image_paths = [path for path in candidates if os.path.isfile(path)]

        self.num_images = len(self.image_paths)
        self._length = self.num_images

        self.placeholder_token = placeholder_token

        self.per_image_tokens = per_image_tokens
        self.center_crop = center_crop

        if per_image_tokens:
            assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."

        if set == "train":
            self._length = self.num_images * repeats

        self.size = size
        # Image.LINEAR was a legacy alias of Image.BILINEAR and was removed in
        # Pillow 10. Referencing it while building this dict made the
        # constructor fail for every interpolation choice on current Pillow.
        resampling_filters = {
            "linear": PIL.Image.BILINEAR,
            "bilinear": PIL.Image.BILINEAR,
            "bicubic": PIL.Image.BICUBIC,
            "lanczos": PIL.Image.LANCZOS,
        }
        self.interpolation = resampling_filters[interpolation]
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)

    def __len__(self):
        """Return the number of items (image count, times ``repeats`` for train)."""
        return self._length

    def _caption_for(self, path):
        """Build a caption for the image at ``path`` from its file name.

        The file stem is split on spaces, underscores and hyphens; only purely
        alphabetic pieces are kept (so numeric suffixes are dropped). Spaces
        are treated as separators because a piece containing a space fails
        ``str.isalpha()`` and would otherwise be discarded whole.
        """
        stem = os.path.splitext(os.path.basename(path))[0]
        pieces = stem.replace(' ', '-').replace('_', '-').split('-')
        words = [piece for piece in pieces if piece.isalpha()]

        template = random.choice(self._CAPTION_TEMPLATES)
        return template.format(' '.join(words), self.placeholder_token)

    def __getitem__(self, i):
        """Return ``{"caption": str, "image": float32 array}`` for index ``i``."""
        path = self.image_paths[i % self.num_images]
        # The per-item debug print of the caption was removed: it ran on every
        # sample fetched during training.
        example = {"caption": self._caption_for(path)}

        # convert() yields an independent, fully loaded image, so the file
        # handle can be released as soon as the conversion is done.
        with Image.open(path) as source:
            rgba = source.convert('RGBA')

        # Composite onto white so transparent regions do not turn black when
        # the alpha channel is dropped.
        canvas = Image.new('RGBA', rgba.size, 'WHITE')
        canvas.paste(rgba, (0, 0), rgba)
        image = canvas.convert('RGB')

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)

        if self.center_crop:
            h, w = img.shape[0], img.shape[1]
            crop = min(h, w)
            img = img[(h - crop) // 2:(h + crop) // 2,
                      (w - crop) // 2:(w + crop) // 2]

        image = Image.fromarray(img)
        if self.size is not None:
            image = image.resize((self.size, self.size), resample=self.interpolation)

        image = self.flip(image)
        image = np.array(image).astype(np.uint8)
        # Map uint8 [0, 255] to float32 [-1, 1].
        example["image"] = (image / 127.5 - 1.0).astype(np.float32)
        return example

0 commit comments

Comments
 (0)