Skip to content

Commit

Permalink
job/jobmanager typehints 2
Browse files Browse the repository at this point in the history
  • Loading branch information
ankona committed Jun 8, 2023
1 parent 8d0644b commit 8c4bf47
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 15 deletions.
2 changes: 1 addition & 1 deletion smartsim/_core/control/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def set_status(
self,
new_status: str,
raw_status: str,
returncode: int,
returncode: t.Optional[int],
error: t.Optional[str] = None,
output: t.Optional[str] = None,
) -> None:
Expand Down
34 changes: 20 additions & 14 deletions smartsim/_core/control/jobmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,21 +216,24 @@ def check_jobs(self) -> None:
self._lock.acquire()
try:
jobs = self().values()
job_name_map = dict([(job.name, job.ename) for job in jobs])
job_name_map = {job.name: job.ename for job in jobs}

# returns (job step name, StepInfo) tuples
if self._launcher:
statuses = self._launcher.get_step_update(job_name_map.keys())
step_names = list(job_name_map.keys())
statuses = self._launcher.get_step_update(step_names)
for job_name, status in statuses:
job = self[job_name_map[job_name]]
# uses abstract step interface
job.set_status(
status.status,
status.launcher_status,
status.returncode,
error=status.error,
output=status.output,
)

if status:
# uses abstract step interface
job.set_status(
status.status,
status.launcher_status,
status.returncode,
error=status.error,
output=status.output,
)
finally:
self._lock.release()

Expand Down Expand Up @@ -310,9 +313,12 @@ def get_db_host_addresses(self) -> t.List[str]:
"""
addresses = []
for db_job in self.db_jobs.values():
for combine in itertools.product(db_job.hosts, db_job.entity.ports):
ip_addr = get_ip_from_host(combine[0])
addresses.append(":".join((ip_addr, str(combine[1]))))
if isinstance(db_job.entity, (DBNode, Orchestrator)):
db_entity: t.Union[DBNode, Orchestrator] = db_job.entity

for combine in itertools.product(db_job.hosts, db_entity.ports):
ip_addr = get_ip_from_host(combine[0])
addresses.append(":".join((ip_addr, str(combine[1]))))
return addresses

def set_db_hosts(self, orchestrator: Orchestrator) -> None:
Expand All @@ -327,7 +333,7 @@ def set_db_hosts(self, orchestrator: Orchestrator) -> None:
if orchestrator.batch:
self.db_jobs[orchestrator.name].hosts = orchestrator.hosts
else:
for dbnode in orchestrator:
for dbnode in orchestrator.dbnodes:
if not dbnode._mpmd:
self.db_jobs[dbnode.name].hosts = [dbnode.host]
else:
Expand Down

0 comments on commit 8c4bf47

Please sign in to comment.