Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Websocket refactor #8719

Draft
wants to merge 14 commits into
base: master
Choose a base branch
from
31 changes: 14 additions & 17 deletions src/inmanta/agent/agent_new.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,22 +52,26 @@ def __init__(
"""
:param environment: environment id
"""
super().__init__(name="agent", timeout=cfg.server_timeout.get(), reconnect_delay=cfg.agent_reconnect_delay.get())

self.thread_pool = ThreadPoolExecutor(1, thread_name_prefix="mainpool")
self._storage = self.check_storage()

if environment is None:
environment = cfg.environment.get()
if environment is None:
raise Exception("The agent requires an environment to be set.")
self.set_environment(environment)

assert self._env_id is not None
if environment is None:
raise Exception("The agent requires an environment to be set.")

super().__init__(
name="agent",
environment=environment,
timeout=cfg.server_timeout.get(),
reconnect_delay=cfg.agent_reconnect_delay.get(),
)

self.thread_pool = ThreadPoolExecutor(1, thread_name_prefix="mainpool")
self._storage = self.check_storage()

self.executor_manager: executor.ExecutorManager[executor.Executor] = self.create_executor_manager()
self.scheduler = scheduler.ResourceScheduler(self._env_id, self.executor_manager, self._client)
self.scheduler = scheduler.ResourceScheduler(self._env_id, self.executor_manager, self.session.get_client())
self.working = False
self._client = self.session.get_client()

async def start(self) -> None:
# Make mypy happy
Expand All @@ -85,7 +89,6 @@ def create_executor_manager(self) -> executor.ExecutorManager[executor.Executor]
assert self._env_id is not None
return forking_executor.MPManager(
self.thread_pool,
self.sessionid,
self._env_id,
config.log_dir.get(),
self._storage["executors"],
Expand All @@ -104,12 +107,6 @@ async def stop(self) -> None:
await join_threadpools(threadpools_to_join)
await super().stop()

async def start_connected(self) -> None:
    """
    Set up our single endpoint.

    Registers the ``AGENT_SCHEDULER_ID`` endpoint name through ``add_end_point_name``.
    """
    await self.add_end_point_name(AGENT_SCHEDULER_ID)

async def start_working(self) -> None:
"""Start working, once we have a session"""

Expand Down
8 changes: 2 additions & 6 deletions src/inmanta/agent/forking_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@
class ExecutorContext:
"""The context object used by the executor to expose state to the incoming calls"""

client: typing.Optional[inmanta.protocol.SessionClient]
client: typing.Optional[inmanta.protocol.Client]
venv: typing.Optional[inmanta.env.VirtualEnv]
environment: uuid.UUID
executors: dict[str, "inmanta.agent.in_process_executor.InProcessExecutor"] = {}
Expand Down Expand Up @@ -834,7 +834,6 @@ class MPPool(resourcepool.PoolManager[executor.ExecutorBlueprint, executor.Execu
def __init__(
self,
thread_pool: concurrent.futures.thread.ThreadPoolExecutor,
session_gid: uuid.UUID,
environment: uuid.UUID,
log_folder: str,
storage_folder: str,
Expand All @@ -860,7 +859,6 @@ def __init__(
self.thread_pool = thread_pool

self.environment = environment
self.session_gid = session_gid

# on disk
self.log_folder = log_folder
Expand Down Expand Up @@ -998,7 +996,6 @@ class MPManager(
def __init__(
self,
thread_pool: concurrent.futures.thread.ThreadPoolExecutor,
session_gid: uuid.UUID,
environment: uuid.UUID,
log_folder: str,
storage_folder: str,
Expand All @@ -1007,7 +1004,6 @@ def __init__(
) -> None:
"""
:param thread_pool: threadpool to perform work on
:param session_gid: agent session id, used to connect to the server, the agent should keep this alive
:param environment: the inmanta environment we are deploying for
:param log_folder: folder to place log files for the executors
:param storage_folder: folder to place code files and venvs
Expand All @@ -1019,7 +1015,7 @@ def __init__(
retention_time=inmanta.agent.config.agent_executor_retention_time.get(),
)

self.process_pool = MPPool(thread_pool, session_gid, environment, log_folder, storage_folder, log_level, cli_log)
self.process_pool = MPPool(thread_pool, environment, log_folder, storage_folder, log_level, cli_log)

self.environment = environment

Expand Down
5 changes: 3 additions & 2 deletions src/inmanta/agent/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,14 +743,15 @@ def run() -> None:

return f.result()

def get_client(self) -> protocol.SessionClient:
def get_client(self) -> protocol.Client:
"""
Get the client instance that identifies itself with the agent session.

:return: A client that is associated with the session of the agent that executes this handler.
"""
if self._client is None:
self._client = protocol.SessionClient("agent", self._agent.sessionid)
# TODO: use the correct client
self._client = protocol.Client("agent", self._agent.sessionid)
return self._client

def get_file(self, hash_id: str) -> Optional[bytes]:
Expand Down
4 changes: 2 additions & 2 deletions src/inmanta/agent/in_process_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(
agent_name: str,
agent_uri: str,
environment: uuid.UUID,
client: inmanta.protocol.SessionClient,
client: inmanta.protocol.Client,
eventloop: asyncio.AbstractEventLoop,
parent_logger: logging.Logger,
):
Expand Down Expand Up @@ -462,7 +462,7 @@ class InProcessExecutorManager(executor.ExecutorManager[InProcessExecutor]):
def __init__(
self,
environment: uuid.UUID,
client: inmanta.protocol.SessionClient,
client: inmanta.protocol.Client,
eventloop: asyncio.AbstractEventLoop,
parent_logger: logging.Logger,
thread_pool: ThreadPoolExecutor,
Expand Down
141 changes: 12 additions & 129 deletions src/inmanta/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2700,7 +2700,7 @@ async def clear(self, connection: Optional[asyncpg.connection.Connection] = None
async with self.get_connection(connection=connection) as con:
await Agent.delete_all(environment=self.id, connection=con)
await AgentInstance.delete_all(tid=self.id, connection=con)
await AgentProcess.delete_all(environment=self.id, connection=con)
await SchedulerSession.delete_all(environment=self.id, connection=con)
await Compile.delete_all(environment=self.id, connection=con) # Triggers cascading delete on report table
await Parameter.delete_all(environment=self.id, connection=con)
await Notification.delete_all(environment=self.id, connection=con)
Expand Down Expand Up @@ -3063,7 +3063,7 @@ async def get_unknowns_to_copy_in_partial_compile(
return [cls(from_postgres=True, **uk) for uk in result]


class AgentProcess(BaseDocument):
class SchedulerSession(BaseDocument):
"""
A process in the infrastructure that has (had) a session as an agent.

Expand All @@ -3078,11 +3078,10 @@ class AgentProcess(BaseDocument):
hostname: str
environment: uuid.UUID
first_seen: Optional[datetime.datetime] = None
last_seen: Optional[datetime.datetime] = None
expired: Optional[datetime.datetime] = None

@classmethod
async def get_live(cls, environment: Optional[uuid.UUID] = None) -> list["AgentProcess"]:
async def get_live(cls, environment: Optional[uuid.UUID] = None) -> list["SchedulerSession"]:
if environment is not None:
result = await cls.get_list(
limit=DBLIMIT, environment=environment, expired=None, order_by_column="last_seen", order="ASC NULLS LAST"
Expand All @@ -3094,7 +3093,7 @@ async def get_live(cls, environment: Optional[uuid.UUID] = None) -> list["AgentP
@classmethod
async def get_by_sid(
cls, sid: uuid.UUID, connection: Optional[asyncpg.connection.Connection] = None
) -> Optional["AgentProcess"]:
) -> Optional["SchedulerSession"]:
objects = await cls.get_list(limit=DBLIMIT, connection=connection, expired=None, sid=sid)
if len(objects) == 0:
return None
Expand All @@ -3105,23 +3104,19 @@ async def get_by_sid(
return objects[0]

@classmethod
async def seen(
async def register(
cls,
env: uuid.UUID,
nodename: str,
hostname: str,
sid: uuid.UUID,
now: datetime.datetime,
connection: Optional[asyncpg.connection.Connection] = None,
) -> None:
"""
Register a new session for this process: insert a record for the given session id with first_seen set to now.
"""
proc = await cls.get_one(connection=connection, sid=sid)
if proc is None:
proc = cls(hostname=nodename, environment=env, first_seen=now, last_seen=now, sid=sid)
await proc.insert(connection=connection)
else:
await proc.update_fields(connection=connection, last_seen=now, expired=None)
proc = cls(hostname=hostname, environment=env, first_seen=now, sid=sid)
await proc.insert(connection=connection)

@classmethod
async def update_last_seen(
Expand Down Expand Up @@ -3150,6 +3145,7 @@ async def expire_all(cls, now: datetime.datetime, connection: Optional[asyncpg.c

@classmethod
async def cleanup(cls, nr_expired_records_to_keep: int) -> None:
# TODO
query = f"""
WITH halted_env AS (
SELECT id FROM environment WHERE halted = true
Expand Down Expand Up @@ -3190,105 +3186,10 @@ def to_dto(self) -> m.AgentProcess:
hostname=self.hostname,
environment=self.environment,
first_seen=self.first_seen,
last_seen=self.last_seen,
expired=self.expired,
)


TAgentInstance = TypeVar("TAgentInstance", bound="AgentInstance")


class AgentInstance(BaseDocument):
    """
    A named agent endpoint (instance) that is registered for a session in an environment.

    :param id: Unique id of this instance record (primary key).
    :param process: Id of the session (process) this instance belongs to.
    :param name: The endpoint name of this instance.
    :param expired: When this instance expired, None as long as it is active.
    :param tid: The id of the environment this instance lives in.
    """

    __primary_key__ = ("id",)

    # TODO: add env to speed up cleanup
    id: uuid.UUID
    process: uuid.UUID
    name: str
    expired: Optional[datetime.datetime] = None
    tid: uuid.UUID

    @classmethod
    async def active_for(
        cls: type[TAgentInstance],
        tid: uuid.UUID,
        endpoint: str,
        process: Optional[uuid.UUID] = None,
        connection: Optional[asyncpg.connection.Connection] = None,
    ) -> list[TAgentInstance]:
        """
        Return the non-expired instances with the given endpoint name in the given environment.

        :param tid: The id of the environment.
        :param endpoint: The endpoint name to look for.
        :param process: When set, only return instances that belong to this session/process.
        :param connection: Optional database connection to run the query on.
        """
        criteria: dict[str, object] = {"expired": None, "tid": tid, "name": endpoint}
        if process is not None:
            # Previously both branches issued the same query, so this argument was silently ignored.
            criteria["process"] = process
        return await cls.get_list(connection=connection, **criteria)

    @classmethod
    async def active(cls: type[TAgentInstance]) -> list[TAgentInstance]:
        """
        Return all instances that are not expired.
        """
        return await cls.get_list(expired=None)

    @classmethod
    async def log_instance_creation(
        cls: type[TAgentInstance],
        tid: uuid.UUID,
        process: uuid.UUID,
        endpoints: set[str],
        connection: Optional[asyncpg.connection.Connection] = None,
    ) -> None:
        """
        Create new agent instances for a given session.

        One row is inserted per endpoint name. When the insert conflicts with the table's
        ``<table>_unique`` constraint, the existing row is marked as not expired instead.

        :param tid: The id of the environment.
        :param process: The id of the session the instances belong to.
        :param endpoints: The endpoint names to create instances for. Nothing happens when empty.
        :param connection: Optional database connection to run the statement on.
        """
        if not endpoints:
            return
        async with cls.get_connection(connection) as con:
            await con.executemany(
                f"""
INSERT INTO
{cls.table_name()}
(id, tid, process, name, expired)
VALUES ($1, $2, $3, $4, null)
ON CONFLICT ON CONSTRAINT {cls.table_name()}_unique DO UPDATE
SET expired = null
;
""",
                [tuple(map(cls._get_value, (cls._new_id(), tid, process, name))) for name in endpoints],
            )

    @classmethod
    async def log_instance_expiry(
        cls: type[TAgentInstance],
        sid: uuid.UUID,
        endpoints: set[str],
        now: datetime.datetime,
        connection: Optional[asyncpg.connection.Connection] = None,
    ) -> None:
        """
        Expire specific instances for a given session id.

        :param sid: The id of the session whose instances should be expired.
        :param endpoints: The endpoint names to expire. Nothing happens when empty.
        :param now: The timestamp to store as expiry time.
        :param connection: Optional database connection to run the queries on.
        """
        if not endpoints:
            return
        instances: list[TAgentInstance] = await cls.get_list(connection=connection, process=sid)
        for ai in instances:
            # Only the requested endpoints are expired; other instances of the session are left alone.
            if ai.name in endpoints:
                await ai.update_fields(connection=connection, expired=now)

    @classmethod
    async def expire_all(cls, now: datetime.datetime, connection: Optional[asyncpg.connection.Connection] = None) -> None:
        """
        Mark every instance that is not yet expired as expired at the given time.

        :param now: The timestamp to store as expiry time.
        :param connection: Optional database connection to run the statement on.
        """
        query = f"""
UPDATE {cls.table_name()}
SET expired=$1
WHERE expired IS NULL
"""
        await cls._execute_query(query, cls._get_value(now), connection=connection)


class Agent(BaseDocument):
"""
An inmanta agent
Expand All @@ -3308,13 +3209,8 @@ class Agent(BaseDocument):
name: str
last_failover: Optional[datetime.datetime] = None
paused: bool = False
id_primary: Optional[uuid.UUID] = None
unpause_on_resume: Optional[bool] = None

@property
def primary(self) -> Optional[uuid.UUID]:
    """Return the value of ``id_primary`` (None when it is not set)."""
    return self.id_primary

@classmethod
def get_valid_field_names(cls) -> list[str]:
# Allow the computed fields
Expand Down Expand Up @@ -3345,11 +3241,8 @@ def to_dict(self) -> JsonType:
if self.last_failover is None:
base["last_failover"] = ""

if self.primary is None:
base["primary"] = ""
else:
base["primary"] = base["id_primary"]
del base["id_primary"]
# Field kept for backward compatibility
base["primary"] = ""

base["state"] = self.get_status().value

Expand Down Expand Up @@ -3488,15 +3381,6 @@ async def update_primary(
else:
await agent.update_fields(last_failover=now, id_primary=None, connection=connection)

@classmethod
async def mark_all_as_non_primary(cls, connection: Optional[asyncpg.connection.Connection] = None) -> None:
    """
    Reset ``id_primary`` to NULL on every row of this document's table that currently has it set.

    :param connection: Optional database connection to run the statement on.
    """
    # Single bulk UPDATE; the WHERE clause restricts it to rows where id_primary is not already NULL.
    query = f"""
UPDATE {cls.table_name()}
SET id_primary=NULL
WHERE id_primary IS NOT NULL
"""
    await cls._execute_query(query, connection=connection)

@classmethod
async def clean_up(cls, connection: Optional[asyncpg.connection.Connection] = None) -> None:
query = """
Expand Down Expand Up @@ -6613,8 +6497,7 @@ async def set_last_processed_model_version(
Project,
Environment,
UnknownParameter,
AgentProcess,
AgentInstance,
SchedulerSession,
Agent,
Resource,
ResourceAction,
Expand Down
Loading