Skip to content

Commit

Permalink
swap to type strings for py38
Browse files Browse the repository at this point in the history
  • Loading branch information
ankona committed Aug 8, 2023
1 parent 14f0d98 commit f4ad66c
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
12 changes: 6 additions & 6 deletions smartsim/ml/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,8 @@ def publish_info(self) -> None:

def put_batch(
self,
samples: npt.NDArray[t.Any],
targets: t.Optional[npt.NDArray[t.Any]] = None,
samples: "npt.NDArray[t.Any]",
targets: "t.Optional[npt.NDArray[t.Any]]" = None,
) -> None:
batch_ds_name = form_name("training_samples", self.rank, self.batch_idx)
batch_ds = Dataset(batch_ds_name)
Expand Down Expand Up @@ -385,12 +385,12 @@ def __len__(self) -> int:
length = int(np.floor(self.num_samples / self.batch_size))
return length

def _calc_indices(self, index: int) -> npt.NDArray[t.Any]:
def _calc_indices(self, index: int) -> "npt.NDArray[t.Any]":
return self.indices[index * self.batch_size : (index + 1) * self.batch_size]

def __iter__(
self,
) -> t.Iterator[t.Tuple[npt.NDArray[t.Any], npt.NDArray[t.Any]]]:
) -> "t.Iterator[t.Tuple[npt.NDArray[t.Any], npt.NDArray[t.Any]]]":
self.update_data()
# Generate data
if len(self) < 1:
Expand Down Expand Up @@ -498,8 +498,8 @@ def update_data(self) -> None:
np.random.shuffle(self.indices)

def _data_generation(
self, indices: npt.NDArray[t.Any]
) -> t.Tuple[npt.NDArray[t.Any], npt.NDArray[t.Any]]:
self, indices: "npt.NDArray[t.Any]"
) -> "t.Tuple[npt.NDArray[t.Any], npt.NDArray[t.Any]]":
# Initialization
if self.samples is None:
raise ValueError("Samples have not been initialized")
Expand Down
4 changes: 2 additions & 2 deletions smartsim/ml/tf/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@


class _TFDataGenerationCommon(DataDownloader, keras.utils.Sequence):
def __getitem__(self, index: int) -> t.Tuple[npt.NDArray[t.Any], npt.NDArray[t.Any]]:
def __getitem__(self, index: int) -> "t.Tuple[npt.NDArray[t.Any], npt.NDArray[t.Any]]":
if len(self) < 1:
raise ValueError(
"Not enough samples in generator for one batch. Please "
Expand All @@ -58,7 +58,7 @@ def on_epoch_end(self) -> None:
if self.shuffle:
np.random.shuffle(self.indices)

def _data_generation(self, indices: npt.NDArray[t.Any]) -> t.Tuple[npt.NDArray[t.Any], npt.NDArray[t.Any]]:
def _data_generation(self, indices: "npt.NDArray[t.Any]") -> "t.Tuple[npt.NDArray[t.Any], npt.NDArray[t.Any]]":
# Initialization
if self.samples is None:
raise ValueError("No samples loaded for data generation")
Expand Down

0 comments on commit f4ad66c

Please sign in to comment.