From ccfba19c3d3e74f1afd71379a3e6fd42a1a7bcee Mon Sep 17 00:00:00 2001 From: Christopher McBride Date: Mon, 7 Aug 2023 12:20:04 -0400 Subject: [PATCH 1/6] Add typehints to untyped module & remove suppression from .toml --- pyproject.toml | 12 +++++++++--- smartsim/_core/_install/buildenv.py | 20 +++++++++++++++----- smartsim/_core/control/controller.py | 2 +- smartsim/_core/entrypoints/colocated.py | 13 +++++++------ smartsim/_core/launcher/pbs/pbsParser.py | 2 +- smartsim/_core/launcher/taskManager.py | 9 +++++++-- smartsim/_core/utils/helpers.py | 2 +- smartsim/_core/utils/redis.py | 4 +++- smartsim/database/orchestrator.py | 2 +- smartsim/ml/data.py | 21 ++++++++++++++------- smartsim/ml/tf/data.py | 6 ++++-- smartsim/settings/settings.py | 4 ++-- 12 files changed, 65 insertions(+), 32 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index def4063e0..5c61b9c9a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,8 +73,12 @@ namespace_packages = true files = [ "smartsim" ] -plugins = [] -ignore_errors=false +plugins = ["numpy.typing.mypy_plugin"] +ignore_errors = false + +# Dynamic typing +disallow_any_generics = true +warn_return_any = true # Strict fn defs disallow_untyped_calls = true @@ -84,7 +88,9 @@ disallow_untyped_decorators = true # Safety/Upgrading Mypy warn_unused_ignores = true -# warn_redundant_casts = true # not a per-module setting? +warn_redundant_casts = true +warn_unused_configs = true +show_error_codes = true [[tool.mypy.overrides]] # Ignore packages that are not used or not typed diff --git a/smartsim/_core/_install/buildenv.py b/smartsim/_core/_install/buildenv.py index 2ab529425..d9328f52e 100644 --- a/smartsim/_core/_install/buildenv.py +++ b/smartsim/_core/_install/buildenv.py @@ -105,31 +105,41 @@ def patch(self) -> str: def __gt__(self, cmp: t.Any) -> bool: try: - return Version(self).__gt__(self._convert_to_version(cmp)) + if Version(self).__gt__(self._convert_to_version(cmp)): + return True + return False except InvalidVersion: return super().__gt__(cmp) def __lt__(self, cmp: t.Any) -> bool: try: - return Version(self).__lt__(self._convert_to_version(cmp)) + if Version(self).__lt__(self._convert_to_version(cmp)): + return True + return False except InvalidVersion: return super().__lt__(cmp) def __eq__(self, cmp: t.Any) -> bool: try: - return Version(self).__eq__(self._convert_to_version(cmp)) + if Version(self).__eq__(self._convert_to_version(cmp)): + return True + return False except InvalidVersion: return super().__eq__(cmp) def __ge__(self, cmp: t.Any) -> bool: try: - return Version(self).__ge__(self._convert_to_version(cmp)) + if Version(self).__ge__(self._convert_to_version(cmp)): + return True + return False except InvalidVersion: return super().__ge__(cmp) def __le__(self, cmp: t.Any) -> bool: try: - return Version(self).__le__(self._convert_to_version(cmp)) + if Version(self).__le__(self._convert_to_version(cmp)): + return True + return False except InvalidVersion: return super().__le__(cmp) diff --git a/smartsim/_core/control/controller.py b/smartsim/_core/control/controller.py index 36c221268..2fc870c12 100644 --- a/smartsim/_core/control/controller.py +++ b/smartsim/_core/control/controller.py @@ -618,7 +618,7 @@ def reload_saved_db(self, checkpoint_file: str) -> Orchestrator: if not self._jobs.actively_monitoring: self._jobs.start() - return orc + return orc # type: ignore # noqa: no-any-return def _set_dbobjects(self, manifest: Manifest) -> None: if not manifest.has_db_objects: diff --git a/smartsim/_core/entrypoints/colocated.py b/smartsim/_core/entrypoints/colocated.py index c672f1500..ba82d355f 100644 --- a/smartsim/_core/entrypoints/colocated.py +++ b/smartsim/_core/entrypoints/colocated.py @@ -94,11 +94,12 @@ def launch_db_model(client: Client, db_model: t.List[str]) -> str: if args.outputs: outputs = list(args.outputs) - # devices_per_node being greater than one only applies - # to GPU devices + name = str(args.name) + + # devices_per_node being greater than one only applies to GPU devices if args.devices_per_node > 1 and args.device.lower() == "gpu": client.set_model_from_file_multigpu( - args.name, + name, args.file, args.backend, 0, @@ -111,7 +112,7 @@ def launch_db_model(client: Client, db_model: t.List[str]) -> str: ) else: client.set_model_from_file( - args.name, + name, args.file, args.backend, args.device, @@ -122,7 +123,7 @@ def launch_db_model(client: Client, db_model: t.List[str]) -> str: outputs, ) - return args.name + return name def launch_db_script(client: Client, db_script: t.List[str]) -> str: @@ -163,7 +164,7 @@ def launch_db_script(client: Client, db_script: t.List[str]) -> str: else: raise ValueError("No file or func provided.") - return args.name + return str(args.name) def main( diff --git a/smartsim/_core/launcher/pbs/pbsParser.py b/smartsim/_core/launcher/pbs/pbsParser.py index 724659188..426166342 100644 --- a/smartsim/_core/launcher/pbs/pbsParser.py +++ b/smartsim/_core/launcher/pbs/pbsParser.py @@ -122,7 +122,7 @@ def parse_step_id_from_qstat(output: str, step_name: str) -> t.Optional[str]: :return: the step_id :rtype: str """ - step_id = None + step_id: t.Optional[str] = None out_json = load_and_clean_json(output) if "Jobs" not in out_json: diff --git a/smartsim/_core/launcher/taskManager.py b/smartsim/_core/launcher/taskManager.py index 7068a49b8..d244db304 100644 --- a/smartsim/_core/launcher/taskManager.py +++ b/smartsim/_core/launcher/taskManager.py @@ -290,7 +290,10 @@ def check_status(self) -> t.Optional[int]: :rtype: int """ if self.owned and isinstance(self.process, psutil.Popen): - return self.process.poll() + poll_result = self.process.poll() + if poll_result is not None: + return int(poll_result) + return None # we can't manage Processed we don't own # have to rely on .kill() to stop. return self.returncode @@ -363,7 +366,9 @@ def wait(self) -> None: @property def returncode(self) -> t.Optional[int]: if self.owned and isinstance(self.process, psutil.Popen): - return self.process.returncode + if self.process.returncode is not None: + return int(self.process.returncode) + return None if self.is_alive: return None return 0 diff --git a/smartsim/_core/utils/helpers.py b/smartsim/_core/utils/helpers.py index d5327c1d1..f37dddb35 100644 --- a/smartsim/_core/utils/helpers.py +++ b/smartsim/_core/utils/helpers.py @@ -81,7 +81,7 @@ def get_base_36_repr(positive_int: int) -> str: def init_default( default: t.Any, init_value: t.Any, - expected_type: t.Optional[t.Union[t.Type, t.Tuple]] = None, + expected_type: t.Union[t.Type[t.Any], t.Tuple[t.Type[t.Any], ...], None] = None, ) -> t.Any: if init_value is None: return default diff --git a/smartsim/_core/utils/redis.py b/smartsim/_core/utils/redis.py index ddeaf7ed4..55bf4ca6a 100644 --- a/smartsim/_core/utils/redis.py +++ b/smartsim/_core/utils/redis.py @@ -108,7 +108,9 @@ def check_cluster_status( # wait for cluster to spin up time.sleep(5) try: - redis_tester: RedisCluster = RedisCluster(startup_nodes=cluster_nodes) + redis_tester: RedisCluster[t.Any] = RedisCluster( + startup_nodes=cluster_nodes + ) redis_tester.set("__test__", "__test__") redis_tester.delete("__test__") # type: ignore logger.debug("Cluster status verified") diff --git a/smartsim/database/orchestrator.py b/smartsim/database/orchestrator.py index 4bf4687be..aa29bb955 100644 --- a/smartsim/database/orchestrator.py +++ b/smartsim/database/orchestrator.py @@ -671,7 +671,7 @@ def _build_run_settings_lsf( return erf_rs def _initialize_entities(self, **kwargs: t.Any) -> None: - self.db_nodes = kwargs.get("db_nodes", 1) + self.db_nodes = int(kwargs.get("db_nodes", 1)) single_cmd = kwargs.get("single_cmd", True) if int(self.db_nodes) == 2: diff --git a/smartsim/ml/data.py b/smartsim/ml/data.py index 4b8f62fa4..10d365aed 100644 --- a/smartsim/ml/data.py +++ b/smartsim/ml/data.py @@ -35,6 +35,7 @@ from ..error import SSInternalError from ..log import get_logger + logger = get_logger(__name__) @@ -205,7 +206,9 @@ def publish_info(self) -> None: self._info.publish(self.client) def put_batch( - self, samples: np.ndarray, targets: t.Optional[np.ndarray] = None + self, + samples: np.ndarray, # type: ignore[type-arg] + targets: t.Optional[np.ndarray] = None, # type: ignore[type-arg] ) -> None: batch_ds_name = form_name("training_samples", self.rank, self.batch_idx) batch_ds = Dataset(batch_ds_name) @@ -381,10 +384,12 @@ def __len__(self) -> int: length = int(np.floor(self.num_samples / self.batch_size)) return length - def _calc_indices(self, index: int) -> np.ndarray: + def _calc_indices(self, index: int) -> np.ndarray: # type: ignore[type-arg] return self.indices[index * self.batch_size : (index + 1) * self.batch_size] - def __iter__(self) -> t.Iterator[t.Tuple[np.ndarray, np.ndarray]]: + def __iter__( + self, + ) -> t.Iterator[t.Tuple[np.ndarray, np.ndarray]]: # type: ignore[type-arg] self.update_data() # Generate data if len(self) < 1: @@ -426,11 +431,11 @@ def init_samples(self, init_trials: int = -1) -> None: def _data_exists(self, batch_name: str, target_name: str) -> bool: if self.need_targets: - return self.client.tensor_exists(batch_name) and self.client.tensor_exists( - target_name + return all( + self.client.tensor_exists(datum) for datum in [batch_name, target_name] ) - return self.client.tensor_exists(batch_name) + return bool(self.client.tensor_exists(batch_name)) def _add_samples(self, indices: t.List[int]) -> None: datasets: t.List[Dataset] = [] @@ -491,7 +496,9 @@ def update_data(self) -> None: if self.shuffle: np.random.shuffle(self.indices) - def _data_generation(self, indices: np.ndarray) -> t.Tuple[np.ndarray, np.ndarray]: + def _data_generation( + self, indices: np.ndarray # type: ignore[type-arg]# type: ignore + ) -> t.Tuple[np.ndarray, np.ndarray]: # type: ignore[type-arg] # Initialization if self.samples is None: raise ValueError("Samples have not been initialized") diff --git a/smartsim/ml/tf/data.py b/smartsim/ml/tf/data.py index ab4ae18c0..823553786 100644 --- a/smartsim/ml/tf/data.py +++ b/smartsim/ml/tf/data.py @@ -32,7 +32,9 @@ class _TFDataGenerationCommon(DataDownloader, keras.utils.Sequence): - def __getitem__(self, index: int) -> t.Tuple[np.ndarray, np.ndarray]: + def __getitem__( + self, index: int + ) -> t.Tuple[np.ndarray, np.ndarray]: # type: ignore[type-arg] if len(self) < 1: raise ValueError( "Not enough samples in generator for one batch. Please " @@ -57,7 +59,7 @@ def on_epoch_end(self) -> None: if self.shuffle: np.random.shuffle(self.indices) - def _data_generation(self, indices: np.ndarray) -> t.Tuple[np.ndarray, np.ndarray]: + def _data_generation(self, indices: np.ndarray) -> t.Tuple[np.ndarray, np.ndarray]: # type: ignore[type-arg] # Initialization if self.samples is None: raise ValueError("No samples loaded for data generation") diff --git a/smartsim/settings/settings.py b/smartsim/settings/settings.py index 1189cd505..ef95019ab 100644 --- a/smartsim/settings/settings.py +++ b/smartsim/settings/settings.py @@ -77,7 +77,7 @@ def create_batch_settings( :raises SmartSimError: if batch creation fails """ # all supported batch class implementations - by_launcher = { + by_launcher: t.Dict[str, t.Callable[..., base.BatchSettings]] = { "cobalt": CobaltBatchSettings, "pbs": QsubBatchSettings, "slurm": SbatchSettings, @@ -144,7 +144,7 @@ def create_run_settings( :raises SmartSimError: if run_command=="auto" and detection fails """ # all supported RunSettings child classes - supported = { + supported: t.Dict[str, t.Callable[..., RunSettings]] = { "aprun": AprunSettings, "srun": SrunSettings, "mpirun": MpirunSettings, From cb541f68a1ae4024629ddb6a4f3ebcc72d8b89a1 Mon Sep 17 00:00:00 2001 From: Christopher McBride Date: Tue, 8 Aug 2023 12:00:07 -0400 Subject: [PATCH 2/6] remove unused numpy mypy plugin --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 5c61b9c9a..c4387be8b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,7 +73,7 @@ namespace_packages = true files = [ "smartsim" ] -plugins = ["numpy.typing.mypy_plugin"] +plugins = [] ignore_errors = false # Dynamic typing From be11c4949e9fc219f84cd95c8138c921867f2328 Mon Sep 17 00:00:00 2001 From: Christopher McBride Date: Tue, 8 Aug 2023 18:13:09 -0400 Subject: [PATCH 3/6] cast orchestrator instead of suppressing --- smartsim/_core/control/controller.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/smartsim/_core/control/controller.py b/smartsim/_core/control/controller.py index 2fc870c12..bc38370b9 100644 --- a/smartsim/_core/control/controller.py +++ b/smartsim/_core/control/controller.py @@ -600,7 +600,7 @@ def reload_saved_db(self, checkpoint_file: str) -> Orchestrator: raise SmartSimError( err_message + "Could not find database job objects." ) - orc = db_config["db"] + orc: Orchestrator = db_config["db"] # TODO check that each db_object is running @@ -618,7 +618,7 @@ def reload_saved_db(self, checkpoint_file: str) -> Orchestrator: if not self._jobs.actively_monitoring: self._jobs.start() - return orc # type: ignore # noqa: no-any-return + return orc def _set_dbobjects(self, manifest: Manifest) -> None: if not manifest.has_db_objects: From aeb87ccc5aa377ebe9cf45a4e9cb24e8e143f499 Mon Sep 17 00:00:00 2001 From: Christopher McBride Date: Tue, 8 Aug 2023 18:13:23 -0400 Subject: [PATCH 4/6] fix dupe ignore --- smartsim/ml/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/smartsim/ml/data.py b/smartsim/ml/data.py index 10d365aed..8e617f914 100644 --- a/smartsim/ml/data.py +++ b/smartsim/ml/data.py @@ -497,7 +497,7 @@ def update_data(self) -> None: np.random.shuffle(self.indices) def _data_generation( - self, indices: np.ndarray # type: ignore[type-arg]# type: ignore + self, indices: np.ndarray # type: ignore[type-arg] ) -> t.Tuple[np.ndarray, np.ndarray]: # type: ignore[type-arg] # Initialization if self.samples is None: From c433ae7dbd5b75d60237080fc4686689006071ba Mon Sep 17 00:00:00 2001 From: Christopher McBride Date: Tue, 8 Aug 2023 18:17:54 -0400 Subject: [PATCH 5/6] undo unnecessary boolean handling --- smartsim/_core/_install/buildenv.py | 20 +++++--------------- 1 file changed, 5 insertions(+), 15 deletions(-) diff --git a/smartsim/_core/_install/buildenv.py b/smartsim/_core/_install/buildenv.py index d9328f52e..7e29f9e86 100644 --- a/smartsim/_core/_install/buildenv.py +++ b/smartsim/_core/_install/buildenv.py @@ -105,41 +105,31 @@ def patch(self) -> str: def __gt__(self, cmp: t.Any) -> bool: try: - if Version(self).__gt__(self._convert_to_version(cmp)): - return True - return False + return bool(Version(self).__gt__(self._convert_to_version(cmp))) except InvalidVersion: return super().__gt__(cmp) def __lt__(self, cmp: t.Any) -> bool: try: - if Version(self).__lt__(self._convert_to_version(cmp)): - return True - return False + return bool(Version(self).__lt__(self._convert_to_version(cmp))) except InvalidVersion: return super().__lt__(cmp) def __eq__(self, cmp: t.Any) -> bool: try: - if Version(self).__eq__(self._convert_to_version(cmp)): - return True - return False + return bool(Version(self).__eq__(self._convert_to_version(cmp))) except InvalidVersion: return super().__eq__(cmp) def __ge__(self, cmp: t.Any) -> bool: try: - if Version(self).__ge__(self._convert_to_version(cmp)): - return True - return False + return bool(Version(self).__ge__(self._convert_to_version(cmp))) except InvalidVersion: return super().__ge__(cmp) def __le__(self, cmp: t.Any) -> bool: try: - if Version(self).__le__(self._convert_to_version(cmp)): - return True - return False + return bool(Version(self).__le__(self._convert_to_version(cmp))) except InvalidVersion: return super().__le__(cmp) From ea3d5a0bd4d80ebed10fd51fd9589195b7f8bbd4 Mon Sep 17 00:00:00 2001 From: Christopher McBride Date: Wed, 9 Aug 2023 12:19:15 -0400 Subject: [PATCH 6/6] avoid type arg due to generic type stub --- smartsim/_core/utils/redis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/smartsim/_core/utils/redis.py b/smartsim/_core/utils/redis.py index 55bf4ca6a..9645a367e 100644 --- a/smartsim/_core/utils/redis.py +++ b/smartsim/_core/utils/redis.py @@ -108,7 +108,7 @@ def check_cluster_status( # wait for cluster to spin up time.sleep(5) try: - redis_tester: RedisCluster[t.Any] = RedisCluster( + redis_tester: "RedisCluster[t.Any]" = RedisCluster( startup_nodes=cluster_nodes ) redis_tester.set("__test__", "__test__")