diff --git a/cads_adaptors/adaptors/mars.py b/cads_adaptors/adaptors/mars.py index 05acd32c..d3b51550 100644 --- a/cads_adaptors/adaptors/mars.py +++ b/cads_adaptors/adaptors/mars.py @@ -46,15 +46,37 @@ def convert_format( return paths +def get_mars_server_list(config) -> list[str]: + if config.get("mars_servers") is not None: + return ensure_list(config["mars_servers"]) + + # TODO: Refactor when we have a more stable set of mars-servers + if os.getenv("MARS_API_SERVER_LIST") is not None: + default_mars_server_list = os.getenv("MARS_API_SERVER_LIST") + else: + for default_mars_server_list in [ + "/etc/mars/mars-api-server-legacy.list", + "/etc/mars/mars-api-server.list", + ]: + if os.path.exists(default_mars_server_list): + break + + mars_server_list: str = config.get("mars_server_list", default_mars_server_list) + if os.path.exists(mars_server_list): + with open(mars_server_list) as f: + mars_servers = f.read().splitlines() + else: + raise SystemError( + "MARS servers cannot be found, this is an error at the system level." + ) + return mars_servers + + def execute_mars( request: Union[Request, list], context: Context, config: dict[str, Any] = dict(), target: str = "data.grib", - # mars_cmd: tuple[str, ...] = ("/usr/local/bin/mars", "r"), - mars_server_list: str = os.getenv( - "MARS_API_SERVER_LIST", "/etc/mars/mars-api-server.list" - ), ) -> str: from cads_mars_server import client as mars_client @@ -63,13 +85,7 @@ def execute_mars( requests, _cacheable = implement_embargo(requests, config["embargo"]) context.add_stdout(f"Request (after embargo implemented): {requests}") - if os.path.exists(mars_server_list): - with open(mars_server_list) as f: - mars_servers = f.read().splitlines() - else: - raise SystemError( - "MARS servers cannot be found, this is an error at the system level." - ) + mars_servers = get_mars_server_list(config) cluster = mars_client.RemoteMarsClientCluster(urls=mars_servers, log=context) diff --git a/tests/data/mars_servers.list b/tests/data/mars_servers.list new file mode 100644 index 00000000..b723a39d --- /dev/null +++ b/tests/data/mars_servers.list @@ -0,0 +1 @@ +http://a-test-server.url diff --git a/tests/test_15_mars.py b/tests/test_15_mars.py index b90e824a..4ab2abfc 100644 --- a/tests/test_15_mars.py +++ b/tests/test_15_mars.py @@ -1,21 +1,26 @@ -# import pytest +import os -# from cads_adaptors.adaptors import mars +from cads_adaptors.adaptors import mars -# @pytest.mark.parametrize( -# "cmd,error_msg", -# ( -# ("cat r; echo error 1>&2; exit 1", "MARS has crashed."), -# ("cat r; touch data.grib; echo error 1>&2", "MARS returned no data."), -# ), -# ) -# def test_execute_mars_errors(tmp_path, monkeypatch, cmd, error_msg): -# monkeypatch.chdir(tmp_path) # execute_mars generates files in the working dir -# context = mars.Context() -# with pytest.raises(RuntimeError, match=error_msg): -# mars.execute_mars( -# {}, -# context=context, -# mars_cmd=("bash", "-c", cmd), -# ) +def test_get_mars_servers(): + mars_servers = mars.get_mars_server_list( + {"mars_servers": "http://b-test-server.url"} + ) + assert len(mars_servers) == 1 + assert mars_servers[0] == "http://b-test-server.url" + + +def test_get_mars_servers_list_file(): + mars_servers = mars.get_mars_server_list( + {"mars_server_list": "tests/data/mars_servers.list"} + ) + assert len(mars_servers) == 1 + assert mars_servers[0] == "http://a-test-server.url" + + +def test_get_mars_servers_envvar(): + os.environ["MARS_API_SERVER_LIST"] = "tests/data/mars_servers.list" + mars_servers = mars.get_mars_server_list({}) + assert len(mars_servers) == 1 + assert mars_servers[0] == "http://a-test-server.url"