From f38fb291e070ab0481ad9a587aff1682655150c6 Mon Sep 17 00:00:00 2001 From: Daniel Johnson Date: Mon, 17 Jun 2024 18:13:46 -0700 Subject: [PATCH] Fixes to development environment and test runner. Due to a bug in how tests were discovered, not all tests were run by default. This fixes test discovery to use pytest instead, ensuring all unit tests actually run. Additionally, updates the VSCode configuration to properly configure the pyink formatter, adds a field specifier for dataclasses.field to get correct type inference in pyright, and makes the canonical alias public-API-walking logic more robust. PiperOrigin-RevId: 644195253 --- .github/workflows/unittests.yml | 2 +- .pylintrc | 20 ++++++-------------- .vscode/settings.json | 5 ++--- penzai/core/struct.py | 1 + penzai/treescope/canonical_aliases.py | 20 +++++++++++++------- pyproject.toml | 1 + run_tests.py | 11 ++++++----- 7 files changed, 30 insertions(+), 30 deletions(-) diff --git a/.github/workflows/unittests.yml b/.github/workflows/unittests.yml index 6b557f7..0e1d6cf 100644 --- a/.github/workflows/unittests.yml +++ b/.github/workflows/unittests.yml @@ -24,7 +24,7 @@ jobs: # cache-dependency-path: '**/pyproject.toml' - run: pip --version - - run: pip install -e .[extras] + - run: pip install -e .[dev,extras] - run: pip freeze # Run tests diff --git a/.pylintrc b/.pylintrc index ead0ab6..8e695e4 100644 --- a/.pylintrc +++ b/.pylintrc @@ -2,8 +2,10 @@ # best-practices and style described in the Google Python style guide: # https://google.github.io/styleguide/pyguide.html # -# Its canonical open-source location is: +# Its original canonical open-source location is: # https://google.github.io/styleguide/pylintrc +# +# Also includes some modifications specific to this repository. [MASTER] @@ -90,6 +92,7 @@ disable=abstract-method, input-builtin, intern-builtin, invalid-str-codec, + invalid-field-call, locally-disabled, long-builtin, long-suffix, @@ -109,6 +112,7 @@ disable=abstract-method, no-name-in-module, no-self-use, nonzero-method, + not-callable, # false positives for jax.jit oct-method, old-division, old-ne-operator, @@ -161,12 +165,6 @@ disable=abstract-method, # mypackage.mymodule.MyReporterClass. output-format=text -# Put messages in a separate file for each module / package specified on the -# command line instead of printing them on stdout. Reports (if any) will be -# written in a file name "pylint_global.[txt|html]". This option is deprecated -# and it will be removed in Pylint 2.0. -files-output=no - # Tells whether to display a full report or only the messages reports=no @@ -285,12 +283,6 @@ ignore-long-lines=(?x)( # else. single-line-if-stmt=yes -# List of optional constructs for which whitespace checking is disabled. `dict- -# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. -# `trailing-comma` allows a space between comma and closing bracket: (a, ). -# `empty-line` allows space-only lines. -no-space-check= - # Maximum number of lines in a module max-module-lines=99999 @@ -444,4 +436,4 @@ valid-metaclass-classmethod-first-arg=mcs # "Exception" overgeneral-exceptions=StandardError, Exception, - BaseException \ No newline at end of file + BaseException diff --git a/.vscode/settings.json b/.vscode/settings.json index 2510af6..e8dd68c 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -12,12 +12,11 @@ "[python]": { "editor.rulers": [80], "editor.tabSize": 2, - "editor.defaultFormatter": "ms-python.black-formatter", "editor.formatOnSave": true, "editor.detectIndentation": false }, - "python.formatting.provider": "none", - "black-formatter.path": ["pyink"], + "python.formatting.provider": "black", + "python.formatting.blackPath": "pyink", "files.watcherExclude": { "**/.git/**": true }, diff --git a/penzai/core/struct.py b/penzai/core/struct.py index 5f1c12b..b37ca86 100644 --- a/penzai/core/struct.py +++ b/penzai/core/struct.py @@ -69,6 +69,7 @@ class PyTreeDataclassSafetyError(Exception): @dataclass_transform( frozen_default=True, # pylint: disable=unexpected-keyword-arg # pytype: disable=not-supported-yet + field_specifiers=(dataclasses.field,), ) def pytree_dataclass( cls: type[Any] | None = None, diff --git a/penzai/treescope/canonical_aliases.py b/penzai/treescope/canonical_aliases.py index 7168c79..6c5ebd6 100644 --- a/penzai/treescope/canonical_aliases.py +++ b/penzai/treescope/canonical_aliases.py @@ -538,7 +538,11 @@ def populate_from_public_api( ] for name in public_names: - value = getattr(module, name) + try: + value = getattr(module, name) + except AttributeError: + # Possibly a misspecified __all__? + continue path = ModuleAttributePath(module.__name__, (name,)) if isinstance(value, types.ModuleType): if ( @@ -552,21 +556,23 @@ def populate_from_public_api( add_alias(value, path, on_conflict="ignore") -def prefix_filter(prefix: str): +def prefix_filter(include: str, excludes: tuple[str, ...] = ()): """Builds a filter that only defines aliases within a given prefix.""" - def module_is_under_prefix(module_name: str): - return module_name == prefix or module_name.startswith(prefix + ".") + def is_under_prefix(path: str, prefix: str): + return path == prefix or path.startswith(prefix + ".") def predicate(the_object: Any, path: ModuleAttributePath) -> bool: if not default_well_known_filter(the_object, path): return False - if not module_is_under_prefix(path.module_name): + if not is_under_prefix(str(path), include): + return False + if any(is_under_prefix(str(path), exclude) for exclude in excludes): return False if ( hasattr(the_object, "__module__") and the_object.__module__ - and not module_is_under_prefix(the_object.__module__) + and not is_under_prefix(the_object.__module__, include) ): return False return True @@ -578,7 +584,7 @@ def predicate(the_object: Any, path: ModuleAttributePath) -> bool: # they are likely to be used in penzai code. _alias_environment.get().lazy_populate_if_imported.extend([ # Third-party libraries with useful APIs: - ("numpy", prefix_filter("numpy")), + ("numpy", prefix_filter("numpy", excludes=("numpy.core",))), ("jax.lax", prefix_filter("jax")), ("jax.numpy", prefix_filter("jax")), ("jax.scipy", prefix_filter("jax")), diff --git a/pyproject.toml b/pyproject.toml index e70a0a6..21992b8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,7 @@ notebook = [ dev = [ "pylint>=2.6.0", "pyink>=24.3.0", + "pytest>=8.2.2", "ipython", "jupyter", ] diff --git a/run_tests.py b/run_tests.py index 7b6fdc5..0d29602 100644 --- a/run_tests.py +++ b/run_tests.py @@ -13,11 +13,12 @@ # limitations under the License. """Entry point executable to run all tests.""" -import sys -from absl.testing import absltest +import subprocess if __name__ == "__main__": - absltest.main( - module=None, - argv=[sys.argv[0], "discover", "-s", "tests", "-p", "*_test.py"], + subprocess.check_call( + ["python", "-m", "pytest", "tests", "-k", "not ShardingUtilTest"] + ) + subprocess.check_call( + ["python", "-m", "pytest", "tests", "-k", "ShardingUtilTest"] )