Skip to content

Commit

Permalink
Typehints (#295)
Browse files Browse the repository at this point in the history
Add mypy typehints to public SmartSim entities
  • Loading branch information
ankona authored Jun 8, 2023
1 parent d4f7df8 commit 395ffb0
Show file tree
Hide file tree
Showing 53 changed files with 1,295 additions and 815 deletions.
7 changes: 7 additions & 0 deletions .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,10 @@ jobs:
with:
fail_ci_if_error: true
files: ./coverage.xml

- name: Run mypy
# TF 2.6.2 has a dep conflict with new mypy versions
if: (matrix.rai != '1.2.5')
run: |
python -m pip install .[mypy]
make check-mypy
6 changes: 6 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,12 @@ check-lint:
@pylint --rcfile=.pylintrc ./smartsim


# help: check-mypy - run static type check
.PHONY: check-mypy
check-mypy:
@mypy --config-file=./pyproject.toml


# help:
# help: Documentation
# help: -------
Expand Down
3 changes: 3 additions & 0 deletions doc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ former began deprecation in May 2022 and was finally removed in May 2023. (PR285
codes. These have now all been updated. (PR284_)
- Orchestrator and Colocated DB now accept a list of interfaces to bind to. The
argument name is still `interface` for backward compatibility reasons. (PR281_)
- Typehints have been added to public APIs. A makefile target to execute static
analysis with mypy is available `make check-mypy`. (PR295_)

.. _PR298: https://github.com/CrayLabs/SmartSim/pull/298
.. _PR293: https://github.com/CrayLabs/SmartSim/pull/293
Expand All @@ -63,6 +65,7 @@ argument name is still `interface` for backward compatibility reasons. (PR281_)
.. _PR285: https://github.com/CrayLabs/SmartSim/pull/285
.. _PR284: https://github.com/CrayLabs/SmartSim/pull/284
.. _PR281: https://github.com/CrayLabs/SmartSim/pull/281
.. _PR295: https://github.com/CrayLabs/SmartSim/pull/295

0.4.2
-----
Expand Down
37 changes: 37 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,40 @@ ignore_errors = true

[tool.coverage.html]
directory = "htmlcov"

[tool.mypy]
namespace_packages = true
files = [
"smartsim"
]
plugins = []
ignore_errors = true

[[tool.mypy.overrides]]
# Ignore packages that are not used or not typed
module = [
"coloredlogs",
"smartredis",
"smartredis.error"
]
ignore_missing_imports = true
ignore_errors = true

[[tool.mypy.overrides]]
module = [
"smartsim.database.*",
"smartsim.entity.*",
"smartsim.experiment"
]

ignore_errors=false

# Strict fn defs
disallow_untyped_calls = true
disallow_untyped_defs = true
disallow_incomplete_defs = true
disallow_untyped_decorators = true

# Safety/Upgrading Mypy
warn_unused_ignores = true
# warn_redundant_casts = true # not a per-module setting?
11 changes: 11 additions & 0 deletions requirements-mypy.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
mypy>=1.3.0

# From typeshed
types-psutil
types-redis
types-tabulate
types-tqdm
types-tensorflow
types-setuptools

# Not from typeshed
9 changes: 9 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,15 @@ def has_ext_modules(_placeholder):
"pytest-cov>=2.10.1",
"click==8.0.2",
],
"mypy": [
"mypy>=1.3.0",
"types-psutil",
"types-redis",
"types-tabulate",
"types-tqdm",
"types-tensorflow",
"types-setuptools",
],
# see smartsim/_core/_install/buildenv.py for more details
"ml": versions.ml_extras_required(),
}
Expand Down
61 changes: 35 additions & 26 deletions smartsim/_core/control/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,24 @@
import signal
import threading
import time
import typing as t

from smartredis import Client
from smartredis.error import RedisConnectionError, RedisReplyError

from ..._core.utils.redis import db_is_active, set_ml_model, set_script
from ...database import Orchestrator
from ...entity import DBModel, DBNode, DBObject, DBScript, EntityList, SmartSimEntity
from ...entity import DBNode, EntityList, SmartSimEntity
from ...error import LauncherError, SmartSimError, SSInternalError, SSUnsupportedError
from ...log import get_logger
from ...status import STATUS_RUNNING, TERMINAL_STATUSES
from ..config import CONFIG
from ..launcher import *
from ..utils import check_cluster_status, create_cluster
from .jobmanager import JobManager
from .manifest import Manifest
from .job import Job
from ...settings.base import BatchSettings


logger = get_logger(__name__)

Expand All @@ -56,7 +60,7 @@ class Controller:
underlying workload manager or run framework.
"""

def __init__(self, launcher="local"):
def __init__(self, launcher: str = "local") -> None:
"""Initialize a Controller
:param launcher: the type of launcher being used
Expand All @@ -65,7 +69,9 @@ def __init__(self, launcher="local"):
self._jobs = JobManager(JM_LOCK)
self.init_launcher(launcher)

def start(self, manifest, block=True, kill_on_interrupt=True):
def start(
self, manifest: Manifest, block: bool = True, kill_on_interrupt: bool = True
) -> None:
"""Start the passed SmartSim entities
This function should not be called directly, but rather
Expand All @@ -90,7 +96,7 @@ def start(self, manifest, block=True, kill_on_interrupt=True):
self.poll(5, True, kill_on_interrupt=kill_on_interrupt)

@property
def orchestrator_active(self):
def orchestrator_active(self) -> bool:
JM_LOCK.acquire()
try:
if len(self._jobs.db_jobs) > 0:
Expand All @@ -99,7 +105,9 @@ def orchestrator_active(self):
finally:
JM_LOCK.release()

def poll(self, interval, verbose, kill_on_interrupt=True):
def poll(
self, interval: int, verbose: bool, kill_on_interrupt: bool = True
) -> None:
"""Poll running jobs and receive logging output of job status
:param interval: number of seconds to wait before polling again
Expand All @@ -124,7 +132,7 @@ def poll(self, interval, verbose, kill_on_interrupt=True):
finally:
JM_LOCK.release()

def finished(self, entity):
def finished(self, entity: SmartSimEntity) -> bool:
"""Return a boolean indicating wether a job has finished or not
:param entity: object launched by SmartSim.
Expand All @@ -149,7 +157,7 @@ def finished(self, entity):
f"Entity {entity.name} has not been launched in this experiment"
) from None

def stop_entity(self, entity):
def stop_entity(self, entity: SmartSimEntity) -> None:
"""Stop an instance of an entity
This function will also update the status of the job in
Expand Down Expand Up @@ -180,7 +188,7 @@ def stop_entity(self, entity):
finally:
JM_LOCK.release()

def stop_entity_list(self, entity_list):
def stop_entity_list(self, entity_list: EntityList) -> None:
"""Stop an instance of an entity list
:param entity_list: entity list to be stopped
Expand All @@ -192,7 +200,7 @@ def stop_entity_list(self, entity_list):
for entity in entity_list.entities:
self.stop_entity(entity)

def get_jobs(self):
def get_jobs(self) -> t.Dict[str, Job]:
"""Return a dictionary of completed job data
:returns: dict[str, Job]
Expand All @@ -203,7 +211,7 @@ def get_jobs(self):
finally:
JM_LOCK.release()

def get_entity_status(self, entity):
def get_entity_status(self, entity: SmartSimEntity) -> str:
"""Get the status of an entity
:param entity: entity to get status of
Expand All @@ -218,7 +226,7 @@ def get_entity_status(self, entity):
)
return self._jobs.get_status(entity)

def get_entity_list_status(self, entity_list):
def get_entity_list_status(self, entity_list: EntityList) -> t.List[str]:
"""Get the statuses of an entity list
:param entity_list: entity list containing entities to
Expand All @@ -237,7 +245,7 @@ def get_entity_list_status(self, entity_list):
statuses.append(self.get_entity_status(entity))
return statuses

def init_launcher(self, launcher):
def init_launcher(self, launcher: str) -> None:
"""Initialize the controller with a specific type of launcher.
SmartSim currently supports slurm, pbs(pro), cobalt, lsf,
and local launching
Expand Down Expand Up @@ -267,7 +275,7 @@ def init_launcher(self, launcher):
else:
raise TypeError("Must provide a 'launcher' argument")

def _launch(self, manifest):
def _launch(self, manifest: Manifest) -> None:
"""Main launching function of the controller
Orchestrators are always launched first so that the
Expand Down Expand Up @@ -323,7 +331,7 @@ def _launch(self, manifest):
for step, entity in steps:
self._launch_step(step, entity)

def _launch_orchestrator(self, orchestrator):
def _launch_orchestrator(self, orchestrator: Orchestrator) -> None:
"""Launch an Orchestrator instance
This function will launch the Orchestrator instance and
Expand Down Expand Up @@ -378,7 +386,7 @@ def _launch_orchestrator(self, orchestrator):
self._save_orchestrator(orchestrator)
logger.debug(f"Orchestrator launched on nodes: {orchestrator.hosts}")

def _launch_step(self, job_step, entity):
def _launch_step(self, job_step, entity: SmartSimEntity) -> None:
"""Use the launcher to launch a job stop
:param job_step: a job step instance
Expand Down Expand Up @@ -408,7 +416,7 @@ def _launch_step(self, job_step, entity):
logger.debug(f"Launching {entity.name}")
self._jobs.add_job(job_step.name, job_id, entity, is_task)

def _create_batch_job_step(self, entity_list):
def _create_batch_job_step(self, entity_list: EntityList) -> t.Any:
"""Use launcher to create batch job step
:param entity_list: EntityList to launch as batch
Expand All @@ -426,7 +434,7 @@ def _create_batch_job_step(self, entity_list):
batch_step.add_to_batch(step)
return batch_step

def _create_job_step(self, entity):
def _create_job_step(self, entity: SmartSimEntity) -> t.Any:
"""Create job steps for all entities with the launcher
:param entities: list of all entities to create steps for
Expand All @@ -441,7 +449,7 @@ def _create_job_step(self, entity):
step = self._launcher.create_step(entity.name, entity.path, entity.run_settings)
return step

def _prep_entity_client_env(self, entity):
def _prep_entity_client_env(self, entity: SmartSimEntity) -> None:
"""Retrieve all connections registered to this entity
:param entity: The entity to retrieve connections from
Expand Down Expand Up @@ -482,7 +490,7 @@ def _prep_entity_client_env(self, entity):
)
entity.run_settings.update_env(client_env)

def _save_orchestrator(self, orchestrator):
def _save_orchestrator(self, orchestrator: Orchestrator) -> None:
"""Save the orchestrator object via pickle
This function saves the orchestrator information to a pickle
Expand All @@ -503,7 +511,7 @@ def _save_orchestrator(self, orchestrator):
with open(dat_file, "wb") as pickle_file:
pickle.dump(orc_data, pickle_file)

def _orchestrator_launch_wait(self, orchestrator):
def _orchestrator_launch_wait(self, orchestrator: Orchestrator) -> None:
"""Wait for the orchestrator instances to run
In the case where the orchestrator is launched as a batch
Expand Down Expand Up @@ -542,7 +550,6 @@ def _orchestrator_launch_wait(self, orchestrator):
else:
logger.debug("Waiting for orchestrator instances to spin up...")
except KeyboardInterrupt:

logger.info("Orchestrator launch cancelled - requesting to stop")
self.stop_entity_list(orchestrator)

Expand All @@ -552,7 +559,7 @@ def _orchestrator_launch_wait(self, orchestrator):
# launch explicitly
raise

def reload_saved_db(self, checkpoint_file):
def reload_saved_db(self, checkpoint_file: str) -> Orchestrator:
JM_LOCK.acquire()
try:
if self.orchestrator_active:
Expand Down Expand Up @@ -609,7 +616,7 @@ def reload_saved_db(self, checkpoint_file):
finally:
JM_LOCK.release()

def _set_dbobjects(self, manifest):
def _set_dbobjects(self, manifest: Manifest) -> None:
if not manifest.has_db_objects:
return

Expand Down Expand Up @@ -649,9 +656,11 @@ def _set_dbobjects(self, manifest):


class _AnonymousBatchJob(EntityList):
def __init__(self, name, path, batch_settings, **kwargs):
def __init__(
self, name: str, path: str, batch_settings: BatchSettings, **kwargs: t.Any
) -> None:
super().__init__(name, path)
self.batch_settings = batch_settings

def _initialize_entities(self, **kwargs):
def _initialize_entities(self, **kwargs: t.Any) -> None:
...
Loading

0 comments on commit 395ffb0

Please sign in to comment.