Skip to content

Commit 7cead40

Browse files
authored
Fix environment variable leak for unused formatters (jupyterlab-contrib#338)
* Add tests against environment pollution * Delay `rpy2` import until the formatter is requested
1 parent a7a9f10 commit 7cead40

File tree

2 files changed

+65
-32
lines changed

2 files changed

+65
-32
lines changed

jupyterlab_code_formatter/formatters.py

+27-32
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,11 @@
99
from functools import wraps
1010
from typing import List, Type
1111

12-
try:
13-
import rpy2
14-
import rpy2.robjects
15-
except ImportError:
16-
pass
1712
if sys.version_info >= (3, 9):
1813
from functools import cache
1914
else:
2015
from functools import lru_cache
16+
2117
cache = lru_cache(maxsize=None)
2218

2319
from packaging import version
@@ -357,56 +353,55 @@ def format_code(self, code: str, notebook: bool, **options) -> str:
357353
return isort.code(code=code, **options)
358354

359355

360-
class FormatRFormatter(BaseFormatter):
361-
label = "Apply FormatR Formatter"
362-
package_name = "formatR"
356+
class RFormatter(BaseFormatter):
357+
@property
358+
@abc.abstractmethod
359+
def package_name(self) -> str:
360+
pass
363361

364362
@property
365363
def importable(self) -> bool:
366-
try:
367-
import rpy2.robjects.packages as rpackages
364+
package_location = subprocess.run(
365+
["Rscript", "-e", f"cat(system.file(package='{self.package_name}'))"],
366+
capture_output=True,
367+
text=True,
368+
)
369+
return package_location != ""
368370

369-
rpackages.importr(self.package_name, robject_translations={".env": "env"})
370371

371-
return True
372-
except Exception:
373-
return False
372+
class FormatRFormatter(RFormatter):
373+
label = "Apply FormatR Formatter"
374+
package_name = "formatR"
374375

375376
@handle_line_ending_and_magic
376377
def format_code(self, code: str, notebook: bool, **options) -> str:
377378
import rpy2.robjects.packages as rpackages
379+
from rpy2.robjects import conversion, default_converter
378380

379-
format_r = rpackages.importr(self.package_name, robject_translations={".env": "env"})
380-
formatted_code = format_r.tidy_source(text=code, output=False, **options)
381-
return "\n".join(formatted_code[0])
381+
with conversion.localconverter(default_converter):
382+
format_r = rpackages.importr(self.package_name, robject_translations={".env": "env"})
383+
formatted_code = format_r.tidy_source(text=code, output=False, **options)
384+
return "\n".join(formatted_code[0])
382385

383386

384-
class StylerFormatter(BaseFormatter):
387+
class StylerFormatter(RFormatter):
385388
label = "Apply Styler Formatter"
386389
package_name = "styler"
387390

388-
@property
389-
def importable(self) -> bool:
390-
try:
391-
import rpy2.robjects.packages as rpackages
392-
393-
rpackages.importr(self.package_name)
394-
395-
return True
396-
except Exception:
397-
return False
398-
399391
@handle_line_ending_and_magic
400392
def format_code(self, code: str, notebook: bool, **options) -> str:
401393
import rpy2.robjects.packages as rpackages
394+
from rpy2.robjects import conversion, default_converter
402395

403-
styler_r = rpackages.importr(self.package_name)
404-
formatted_code = styler_r.style_text(code, **self._transform_options(styler_r, options))
405-
return "\n".join(formatted_code)
396+
with conversion.localconverter(default_converter):
397+
styler_r = rpackages.importr(self.package_name)
398+
formatted_code = styler_r.style_text(code, **self._transform_options(styler_r, options))
399+
return "\n".join(formatted_code)
406400

407401
@staticmethod
408402
def _transform_options(styler_r, options):
409403
transformed_options = copy.deepcopy(options)
404+
import rpy2.robjects
410405

411406
if "math_token_spacing" in transformed_options:
412407
if isinstance(options["math_token_spacing"], dict):
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import json
2+
import os
3+
import sys
4+
from subprocess import run
5+
from unittest import mock
6+
7+
import pytest
8+
9+
from jupyterlab_code_formatter.formatters import SERVER_FORMATTERS
10+
11+
12+
def test_env_pollution_on_import():
13+
# should not pollute environment on import
14+
code = "; ".join(
15+
[
16+
"from jupyterlab_code_formatter import formatters",
17+
"import json",
18+
"import os",
19+
"assert formatters",
20+
"print(json.dumps(os.environ.copy()))",
21+
]
22+
)
23+
result = run([sys.executable, "-c", f"{code}"], capture_output=True, text=True, check=True, env={})
24+
environ = json.loads(result.stdout)
25+
assert set(environ.keys()) - {"LC_CTYPE"} == set()
26+
27+
28+
@pytest.mark.parametrize("name", SERVER_FORMATTERS)
29+
def test_env_pollution_on_importable_check(name):
30+
formatter = SERVER_FORMATTERS[name]
31+
# should not pollute environment on `importable` check
32+
with mock.patch.dict(os.environ, {}, clear=True):
33+
# invoke the property getter
34+
is_importable = formatter.importable
35+
# the environment should have no extra keys
36+
assert set(os.environ.keys()) == set()
37+
if not is_importable:
38+
pytest.skip(f"{name} formatter was not importable, the test may yield false negatives")

0 commit comments

Comments
 (0)