Skip to content

Commit

Permalink
Merge pull request #4 from kabkabm/get_celeba_fix
Browse files Browse the repository at this point in the history
  • Loading branch information
po0ya authored Nov 13, 2018
2 parents ed9eb6e + e230b7c commit 7e3feae
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 7 deletions.
4 changes: 2 additions & 2 deletions blackbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,13 +236,13 @@ def get_celeba(data_path, test_on_dev=True, orig_data=False):
images: Images of the dataset.
labels: Labels of the loaded images.
"""
dev_name = 'dev'
dev_name = 'val'
if not test_on_dev:
dev_name = 'test'
ds = CelebA(attribute=FLAGS.attribute)
ds.load()
ds_test = CelebA(attribute=FLAGS.attribute)
ds_test.load(split=dev_name, transform_type=1)
ds_test.load(split=dev_name)
train_labels = ds.labels
test_labels = ds_test.labels

Expand Down
21 changes: 18 additions & 3 deletions datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def __getitem__(self, index):
if isinstance(index, int):
return self._get_image(self.filepaths[index])
# Case of a slice or array of indices.
elif isinstance(index, slice) or isinstance(index, np.ndarray):
elif isinstance(index, slice):
if isinstance(index, slice):
if index.start is None:
index = range(index.stop)
Expand All @@ -146,9 +146,16 @@ def __getitem__(self, index):
else:
index = range(index.start, index.stop, index.step)
return np.array(
[self._get_image(self.filepaths[i]) for i in index])
[self._get_image(self.filepaths[i]) for i in index]
)
else:
raise TypeError("Index must be an integer or a slice.")
try:
inds = [int(i) for i in index]
return np.array(
[self._get_image(self.filepaths[i]) for i in inds]
)
except TypeError:
raise TypeError("Index must be an integer, a slice, a container or an integer generator.")

def get_subset(self, indices):
"""Gets a subset of the images
Expand All @@ -167,6 +174,14 @@ def get_subset(self, indices):
else:
raise TypeError("Index must be an integer or a slice.")

@property
def shape(self):
return tuple([None] + list(self._get_image(self.filepaths[0]).shape))

@property
def dtype(self):
return self._get_image(self.filepaths[0]).dtype


class PickleLazyDataset(LazyDataset):
"""This dataset is a lazy dataset for working with saved pickle files
Expand Down
5 changes: 3 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
numpy==1.14.2
scipy==1.0.1
tensorflow-gpu==1.7.0
requests==2.20.0
keras==2.1.5
opencv-python==3.4.0.12
scikit-image==0.13.1
matplotlib==2.1.2
matplotlib==2.1.2
tqdm=4.28.1

0 comments on commit 7e3feae

Please sign in to comment.