Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds support for key-based SSH connections #534

Merged
merged 3 commits into from
Feb 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions api/infections.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
# Infections are injected into the application via the environment variable
# 'INFECTIONS', a comma-separated list of infection names.

import os
from typing import Dict, Set

from django.conf import settings

from api.utils import deployment_mode_is_production

# The built-in set of infections.
Expand All @@ -20,9 +21,6 @@
INFECTION_STRUCTURE_DOWNLOAD: 'An error in the DownloadStructures view'
}

# What infection have been set?
_INFECTIONS: str = os.environ.get('INFECTIONS', '').lower()


def have_infection(name: str) -> bool:
"""Returns True if we've been given the named infection.
Expand All @@ -31,9 +29,11 @@ def have_infection(name: str) -> bool:


def _get_infections() -> Set[str]:
if _INFECTIONS == '':
if settings.INFECTIONS == '':
return set()
infections: set[str] = {
infection for infection in _INFECTIONS.split(',') if infection in _CATALOGUE
infection
for infection in settings.INFECTIONS.split(',')
if infection in _CATALOGUE
}
return infections
54 changes: 44 additions & 10 deletions api/remote_ispyb_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(
remote=False,
ssh_user=None,
ssh_password=None,
ssh_private_key_filename=None,
ssh_host=None,
conn_inactivity=360,
):
Expand All @@ -45,6 +46,7 @@ def __init__(
'ssh_host': ssh_host,
'ssh_user': ssh_user,
'ssh_pass': ssh_password,
'ssh_pkey': ssh_private_key_filename,
'db_host': host,
'db_port': int(port),
'db_user': user,
Expand All @@ -53,12 +55,11 @@ def __init__(
}
self.remote_connect(**creds)
logger.debug(
"Started host=%s username=%s local_bind_port=%s",
"Started remote ssh_host=%s ssh_user=%s local_bind_port=%s",
ssh_host,
ssh_user,
self.server.local_bind_port,
)

else:
self.connect(
user=user,
Expand All @@ -68,29 +69,60 @@ def __init__(
port=port,
conn_inactivity=conn_inactivity,
)
logger.debug("Started host=%s user=%s port=%s", host, user, port)
logger.debug("Started direct host=%s user=%s port=%s", host, user, port)

def remote_connect(
self, ssh_host, ssh_user, ssh_pass, db_host, db_port, db_user, db_pass, db_name
self,
ssh_host,
ssh_user,
ssh_pass,
ssh_pkey,
db_host,
db_port,
db_user,
db_pass,
db_name,
):
sshtunnel.SSH_TIMEOUT = 10.0
sshtunnel.TUNNEL_TIMEOUT = 10.0
sshtunnel.DEFAULT_LOGLEVEL = logging.CRITICAL
self.conn_inactivity = int(self.conn_inactivity)

self.server = sshtunnel.SSHTunnelForwarder(
(ssh_host),
ssh_username=ssh_user,
ssh_password=ssh_pass,
remote_bind_address=(db_host, db_port),
)
if ssh_pkey:
logger.debug(
'Creating SSHTunnelForwarder (with SSH Key) host=%s user=%s',
ssh_host,
ssh_user,
)
self.server = sshtunnel.SSHTunnelForwarder(
(ssh_host),
ssh_username=ssh_user,
ssh_pkey=ssh_pkey,
remote_bind_address=(db_host, db_port),
)
else:
logger.debug(
'Creating SSHTunnelForwarder (with password) host=%s user=%s',
ssh_host,
ssh_user,
)
self.server = sshtunnel.SSHTunnelForwarder(
(ssh_host),
ssh_username=ssh_user,
ssh_password=ssh_pass,
remote_bind_address=(db_host, db_port),
)
logger.debug('Created SSHTunnelForwarder')

# stops hanging connections in transport
self.server.daemon_forward_servers = True
self.server.daemon_transport = True

logger.debug('Starting SSH server...')
self.server.start()
logger.debug('Started SSH server')

logger.debug('Connecting to ISPyB (db_user=%s db_name=%s)...', db_user, db_name)
self.conn = pymysql.connect(
user=db_user,
password=db_pass,
Expand All @@ -100,8 +132,10 @@ def remote_connect(
)

if self.conn is not None:
logger.debug('Connected')
self.conn.autocommit = True
else:
logger.debug('Failed to connect')
self.server.stop()
raise ISPyBConnectionException
self.last_activity_ts = time.time()
Expand Down
39 changes: 20 additions & 19 deletions api/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,51 +48,52 @@


def get_remote_conn() -> Optional[SSHConnector]:
ispyb_credentials: Dict[str, Any] = {
"user": os.environ.get("ISPYB_USER"),
"pw": os.environ.get("ISPYB_PASSWORD"),
"host": os.environ.get("ISPYB_HOST"),
"port": os.environ.get("ISPYB_PORT"),
credentials: Dict[str, Any] = {
"user": settings.ISPYB_USER,
"pw": settings.ISPYB_PASSWORD,
"host": settings.ISPYB_HOST,
"port": settings.ISPYB_PORT,
"db": "ispyb",
"conn_inactivity": 360,
}

ssh_credentials: Dict[str, Any] = {
'ssh_host': os.environ.get("SSH_HOST"),
'ssh_user': os.environ.get("SSH_USER"),
'ssh_password': os.environ.get("SSH_PASSWORD"),
'ssh_host': settings.SSH_HOST,
'ssh_user': settings.SSH_USER,
'ssh_password': settings.SSH_PASSWORD,
"ssh_private_key_filename": settings.SSH_PRIVATE_KEY_FILENAME,
'remote': True,
}

ispyb_credentials.update(**ssh_credentials)
credentials.update(**ssh_credentials)

# Caution: Credentials may not be set in the environment.
# Assume the credentials are invalid if there is no host.
# If a host is not defined other properties are useless.
if not ispyb_credentials["host"]:
if not credentials["host"]:
logger.debug("No ISPyB host - cannot return a connector")
return None

# Try to get an SSH connection (aware that it might fail)
conn: Optional[SSHConnector] = None
try:
conn = SSHConnector(**ispyb_credentials)
conn = SSHConnector(**credentials)
except Exception:
# Log the exception if DEBUG level or lower/finer?
# The following wil not log if the level is set to INFO for example.
# The following will not log if the level is set to INFO for example.
if logging.DEBUG >= logger.level:
logger.info("ispyb_credentials=%s", ispyb_credentials)
logger.info("credentials=%s", credentials)
logger.exception("Got the following exception creating SSHConnector...")

return conn


def get_conn() -> Optional[Connector]:
credentials: Dict[str, Any] = {
"user": os.environ.get("ISPYB_USER"),
"pw": os.environ.get("ISPYB_PASSWORD"),
"host": os.environ.get("ISPYB_HOST"),
"port": os.environ.get("ISPYB_PORT"),
"user": settings.ISPYB_USER,
"pw": settings.ISPYB_PASSWORD,
"host": settings.ISPYB_HOST,
"port": settings.ISPYB_PORT,
"db": "ispyb",
"conn_inactivity": 360,
}
Expand All @@ -108,7 +109,7 @@ def get_conn() -> Optional[Connector]:
conn = Connector(**credentials)
except Exception:
# Log the exception if DEBUG level or lower/finer?
# The following wil not log if the level is set to INFO for example.
# The following will not log if the level is set to INFO for example.
if logging.DEBUG >= logger.level:
logger.info("credentials=%s", credentials)
logger.exception("Got the following exception creating Connector...")
Expand Down Expand Up @@ -349,7 +350,7 @@ def get_proposals_for_user(self, user, restrict_to_membership=False):
assert user

proposals = set()
ispyb_user = os.environ.get("ISPYB_USER")
ispyb_user = settings.ISPYB_USER
logger.debug(
"ispyb_user=%s restrict_to_membership=%s",
ispyb_user,
Expand Down
Loading