Skip to content

Commit

Permalink
Add typehints to untyped module & remove suppression from .toml
Browse files Browse the repository at this point in the history
  • Loading branch information
ankona committed Aug 8, 2023
1 parent 1d2ec9f commit ccfba19
Show file tree
Hide file tree
Showing 12 changed files with 65 additions and 32 deletions.
12 changes: 9 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,12 @@ namespace_packages = true
files = [
"smartsim"
]
plugins = []
ignore_errors=false
plugins = ["numpy.typing.mypy_plugin"]
ignore_errors = false

# Dynamic typing
disallow_any_generics = true
warn_return_any = true

# Strict fn defs
disallow_untyped_calls = true
Expand All @@ -84,7 +88,9 @@ disallow_untyped_decorators = true

# Safety/Upgrading Mypy
warn_unused_ignores = true
# warn_redundant_casts = true # not a per-module setting?
warn_redundant_casts = true
warn_unused_configs = true
show_error_codes = true

[[tool.mypy.overrides]]
# Ignore packages that are not used or not typed
Expand Down
20 changes: 15 additions & 5 deletions smartsim/_core/_install/buildenv.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,31 +105,41 @@ def patch(self) -> str:

def __gt__(self, cmp: t.Any) -> bool:
try:
return Version(self).__gt__(self._convert_to_version(cmp))
if Version(self).__gt__(self._convert_to_version(cmp)):
return True
return False
except InvalidVersion:
return super().__gt__(cmp)

def __lt__(self, cmp: t.Any) -> bool:
try:
return Version(self).__lt__(self._convert_to_version(cmp))
if Version(self).__lt__(self._convert_to_version(cmp)):
return True
return False
except InvalidVersion:
return super().__lt__(cmp)

def __eq__(self, cmp: t.Any) -> bool:
try:
return Version(self).__eq__(self._convert_to_version(cmp))
if Version(self).__eq__(self._convert_to_version(cmp)):
return True
return False
except InvalidVersion:
return super().__eq__(cmp)

def __ge__(self, cmp: t.Any) -> bool:
try:
return Version(self).__ge__(self._convert_to_version(cmp))
if Version(self).__ge__(self._convert_to_version(cmp)):
return True
return False
except InvalidVersion:
return super().__ge__(cmp)

def __le__(self, cmp: t.Any) -> bool:
try:
return Version(self).__le__(self._convert_to_version(cmp))
if Version(self).__le__(self._convert_to_version(cmp)):
return True
return False
except InvalidVersion:
return super().__le__(cmp)

Expand Down
2 changes: 1 addition & 1 deletion smartsim/_core/control/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,7 +618,7 @@ def reload_saved_db(self, checkpoint_file: str) -> Orchestrator:
if not self._jobs.actively_monitoring:
self._jobs.start()

return orc
return orc # type: ignore # noqa: no-any-return

def _set_dbobjects(self, manifest: Manifest) -> None:
if not manifest.has_db_objects:
Expand Down
13 changes: 7 additions & 6 deletions smartsim/_core/entrypoints/colocated.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,12 @@ def launch_db_model(client: Client, db_model: t.List[str]) -> str:
if args.outputs:
outputs = list(args.outputs)

# devices_per_node being greater than one only applies
# to GPU devices
name = str(args.name)

# devices_per_node being greater than one only applies to GPU devices
if args.devices_per_node > 1 and args.device.lower() == "gpu":
client.set_model_from_file_multigpu(
args.name,
name,
args.file,
args.backend,
0,
Expand All @@ -111,7 +112,7 @@ def launch_db_model(client: Client, db_model: t.List[str]) -> str:
)
else:
client.set_model_from_file(
args.name,
name,
args.file,
args.backend,
args.device,
Expand All @@ -122,7 +123,7 @@ def launch_db_model(client: Client, db_model: t.List[str]) -> str:
outputs,
)

return args.name
return name


def launch_db_script(client: Client, db_script: t.List[str]) -> str:
Expand Down Expand Up @@ -163,7 +164,7 @@ def launch_db_script(client: Client, db_script: t.List[str]) -> str:
else:
raise ValueError("No file or func provided.")

return args.name
return str(args.name)


def main(
Expand Down
2 changes: 1 addition & 1 deletion smartsim/_core/launcher/pbs/pbsParser.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def parse_step_id_from_qstat(output: str, step_name: str) -> t.Optional[str]:
:return: the step_id
:rtype: str
"""
step_id = None
step_id: t.Optional[str] = None
out_json = load_and_clean_json(output)

if "Jobs" not in out_json:
Expand Down
9 changes: 7 additions & 2 deletions smartsim/_core/launcher/taskManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,10 @@ def check_status(self) -> t.Optional[int]:
:rtype: int
"""
if self.owned and isinstance(self.process, psutil.Popen):
return self.process.poll()
poll_result = self.process.poll()
if poll_result is not None:
return int(poll_result)
return None
# we can't manage Processed we don't own
# have to rely on .kill() to stop.
return self.returncode
Expand Down Expand Up @@ -363,7 +366,9 @@ def wait(self) -> None:
@property
def returncode(self) -> t.Optional[int]:
if self.owned and isinstance(self.process, psutil.Popen):
return self.process.returncode
if self.process.returncode is not None:
return int(self.process.returncode)
return None
if self.is_alive:
return None
return 0
Expand Down
2 changes: 1 addition & 1 deletion smartsim/_core/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def get_base_36_repr(positive_int: int) -> str:
def init_default(
default: t.Any,
init_value: t.Any,
expected_type: t.Optional[t.Union[t.Type, t.Tuple]] = None,
expected_type: t.Union[t.Type[t.Any], t.Tuple[t.Type[t.Any], ...], None] = None,
) -> t.Any:
if init_value is None:
return default
Expand Down
4 changes: 3 additions & 1 deletion smartsim/_core/utils/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,9 @@ def check_cluster_status(
# wait for cluster to spin up
time.sleep(5)
try:
redis_tester: RedisCluster = RedisCluster(startup_nodes=cluster_nodes)
redis_tester: RedisCluster[t.Any] = RedisCluster(
startup_nodes=cluster_nodes
)
redis_tester.set("__test__", "__test__")
redis_tester.delete("__test__") # type: ignore
logger.debug("Cluster status verified")
Expand Down
2 changes: 1 addition & 1 deletion smartsim/database/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,7 @@ def _build_run_settings_lsf(
return erf_rs

def _initialize_entities(self, **kwargs: t.Any) -> None:
self.db_nodes = kwargs.get("db_nodes", 1)
self.db_nodes = int(kwargs.get("db_nodes", 1))
single_cmd = kwargs.get("single_cmd", True)

if int(self.db_nodes) == 2:
Expand Down
21 changes: 14 additions & 7 deletions smartsim/ml/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from ..error import SSInternalError
from ..log import get_logger


logger = get_logger(__name__)


Expand Down Expand Up @@ -205,7 +206,9 @@ def publish_info(self) -> None:
self._info.publish(self.client)

def put_batch(
self, samples: np.ndarray, targets: t.Optional[np.ndarray] = None
self,
samples: np.ndarray, # type: ignore[type-arg]
targets: t.Optional[np.ndarray] = None, # type: ignore[type-arg]
) -> None:
batch_ds_name = form_name("training_samples", self.rank, self.batch_idx)
batch_ds = Dataset(batch_ds_name)
Expand Down Expand Up @@ -381,10 +384,12 @@ def __len__(self) -> int:
length = int(np.floor(self.num_samples / self.batch_size))
return length

def _calc_indices(self, index: int) -> np.ndarray:
def _calc_indices(self, index: int) -> np.ndarray: # type: ignore[type-arg]
return self.indices[index * self.batch_size : (index + 1) * self.batch_size]

def __iter__(self) -> t.Iterator[t.Tuple[np.ndarray, np.ndarray]]:
def __iter__(
self,
) -> t.Iterator[t.Tuple[np.ndarray, np.ndarray]]: # type: ignore[type-arg]
self.update_data()
# Generate data
if len(self) < 1:
Expand Down Expand Up @@ -426,11 +431,11 @@ def init_samples(self, init_trials: int = -1) -> None:

def _data_exists(self, batch_name: str, target_name: str) -> bool:
if self.need_targets:
return self.client.tensor_exists(batch_name) and self.client.tensor_exists(
target_name
return all(
self.client.tensor_exists(datum) for datum in [batch_name, target_name]
)

return self.client.tensor_exists(batch_name)
return bool(self.client.tensor_exists(batch_name))

def _add_samples(self, indices: t.List[int]) -> None:
datasets: t.List[Dataset] = []
Expand Down Expand Up @@ -491,7 +496,9 @@ def update_data(self) -> None:
if self.shuffle:
np.random.shuffle(self.indices)

def _data_generation(self, indices: np.ndarray) -> t.Tuple[np.ndarray, np.ndarray]:
def _data_generation(
self, indices: np.ndarray # type: ignore[type-arg]# type: ignore
) -> t.Tuple[np.ndarray, np.ndarray]: # type: ignore[type-arg]
# Initialization
if self.samples is None:
raise ValueError("Samples have not been initialized")
Expand Down
6 changes: 4 additions & 2 deletions smartsim/ml/tf/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@


class _TFDataGenerationCommon(DataDownloader, keras.utils.Sequence):
def __getitem__(self, index: int) -> t.Tuple[np.ndarray, np.ndarray]:
def __getitem__(
self, index: int
) -> t.Tuple[np.ndarray, np.ndarray]: # type: ignore[type-arg]
if len(self) < 1:
raise ValueError(
"Not enough samples in generator for one batch. Please "
Expand All @@ -57,7 +59,7 @@ def on_epoch_end(self) -> None:
if self.shuffle:
np.random.shuffle(self.indices)

def _data_generation(self, indices: np.ndarray) -> t.Tuple[np.ndarray, np.ndarray]:
def _data_generation(self, indices: np.ndarray) -> t.Tuple[np.ndarray, np.ndarray]: # type: ignore[type-arg]
# Initialization
if self.samples is None:
raise ValueError("No samples loaded for data generation")
Expand Down
4 changes: 2 additions & 2 deletions smartsim/settings/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def create_batch_settings(
:raises SmartSimError: if batch creation fails
"""
# all supported batch class implementations
by_launcher = {
by_launcher: t.Dict[str, t.Callable[..., base.BatchSettings]] = {
"cobalt": CobaltBatchSettings,
"pbs": QsubBatchSettings,
"slurm": SbatchSettings,
Expand Down Expand Up @@ -144,7 +144,7 @@ def create_run_settings(
:raises SmartSimError: if run_command=="auto" and detection fails
"""
# all supported RunSettings child classes
supported = {
supported: t.Dict[str, t.Callable[..., RunSettings]] = {
"aprun": AprunSettings,
"srun": SrunSettings,
"mpirun": MpirunSettings,
Expand Down

0 comments on commit ccfba19

Please sign in to comment.