Skip to content

Commit

Permalink
Fixes to development environment and test runner.
Browse files Browse the repository at this point in the history
Due to a bug in how tests were discovered, not all tests were run by
default. This fixes test discovery to use pytest instead, ensuring
all unit tests actually run.

Additionally, updates the VSCode configuration to properly configure
the pyink formatter, adds a field specifier for dataclasses.field
to get correct type inference in pyright, and makes the canonical
alias public-API-walking logic more robust.

PiperOrigin-RevId: 644195253
  • Loading branch information
danieldjohnson authored and Penzai Developers committed Jun 18, 2024
1 parent 2ba29c6 commit f38fb29
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 30 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/unittests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
# cache-dependency-path: '**/pyproject.toml'

- run: pip --version
- run: pip install -e .[extras]
- run: pip install -e .[dev,extras]
- run: pip freeze

# Run tests
Expand Down
20 changes: 6 additions & 14 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
# best-practices and style described in the Google Python style guide:
# https://google.github.io/styleguide/pyguide.html
#
# Its canonical open-source location is:
# Its original canonical open-source location is:
# https://google.github.io/styleguide/pylintrc
#
# Also includes some modifications specific to this repository.

[MASTER]

Expand Down Expand Up @@ -90,6 +92,7 @@ disable=abstract-method,
input-builtin,
intern-builtin,
invalid-str-codec,
invalid-field-call,
locally-disabled,
long-builtin,
long-suffix,
Expand All @@ -109,6 +112,7 @@ disable=abstract-method,
no-name-in-module,
no-self-use,
nonzero-method,
not-callable, # false positives for jax.jit
oct-method,
old-division,
old-ne-operator,
Expand Down Expand Up @@ -161,12 +165,6 @@ disable=abstract-method,
# mypackage.mymodule.MyReporterClass.
output-format=text

# Put messages in a separate file for each module / package specified on the
# command line instead of printing them on stdout. Reports (if any) will be
# written in a file name "pylint_global.[txt|html]". This option is deprecated
# and it will be removed in Pylint 2.0.
files-output=no

# Tells whether to display a full report or only the messages
reports=no

Expand Down Expand Up @@ -285,12 +283,6 @@ ignore-long-lines=(?x)(
# else.
single-line-if-stmt=yes

# List of optional constructs for which whitespace checking is disabled. `dict-
# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
# `trailing-comma` allows a space between comma and closing bracket: (a, ).
# `empty-line` allows space-only lines.
no-space-check=

# Maximum number of lines in a module
max-module-lines=99999

Expand Down Expand Up @@ -444,4 +436,4 @@ valid-metaclass-classmethod-first-arg=mcs
# "Exception"
overgeneral-exceptions=StandardError,
Exception,
BaseException
BaseException
5 changes: 2 additions & 3 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,11 @@
"[python]": {
"editor.rulers": [80],
"editor.tabSize": 2,
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.formatOnSave": true,
"editor.detectIndentation": false
},
"python.formatting.provider": "none",
"black-formatter.path": ["pyink"],
"python.formatting.provider": "black",
"python.formatting.blackPath": "pyink",
"files.watcherExclude": {
"**/.git/**": true
},
Expand Down
1 change: 1 addition & 0 deletions penzai/core/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class PyTreeDataclassSafetyError(Exception):

@dataclass_transform(
frozen_default=True, # pylint: disable=unexpected-keyword-arg # pytype: disable=not-supported-yet
field_specifiers=(dataclasses.field,),
)
def pytree_dataclass(
cls: type[Any] | None = None,
Expand Down
20 changes: 13 additions & 7 deletions penzai/treescope/canonical_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,11 @@ def populate_from_public_api(
]

for name in public_names:
value = getattr(module, name)
try:
value = getattr(module, name)
except AttributeError:
# Possibly a misspecified __all__?
continue
path = ModuleAttributePath(module.__name__, (name,))
if isinstance(value, types.ModuleType):
if (
Expand All @@ -552,21 +556,23 @@ def populate_from_public_api(
add_alias(value, path, on_conflict="ignore")


def prefix_filter(prefix: str):
def prefix_filter(include: str, excludes: tuple[str, ...] = ()):
"""Builds a filter that only defines aliases within a given prefix."""

def module_is_under_prefix(module_name: str):
return module_name == prefix or module_name.startswith(prefix + ".")
def is_under_prefix(path: str, prefix: str):
return path == prefix or path.startswith(prefix + ".")

def predicate(the_object: Any, path: ModuleAttributePath) -> bool:
if not default_well_known_filter(the_object, path):
return False
if not module_is_under_prefix(path.module_name):
if not is_under_prefix(str(path), include):
return False
if any(is_under_prefix(str(path), exclude) for exclude in excludes):
return False
if (
hasattr(the_object, "__module__")
and the_object.__module__
and not module_is_under_prefix(the_object.__module__)
and not is_under_prefix(the_object.__module__, include)
):
return False
return True
Expand All @@ -578,7 +584,7 @@ def predicate(the_object: Any, path: ModuleAttributePath) -> bool:
# they are likely to be used in penzai code.
_alias_environment.get().lazy_populate_if_imported.extend([
# Third-party libraries with useful APIs:
("numpy", prefix_filter("numpy")),
("numpy", prefix_filter("numpy", excludes=("numpy.core",))),
("jax.lax", prefix_filter("jax")),
("jax.numpy", prefix_filter("jax")),
("jax.scipy", prefix_filter("jax")),
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ notebook = [
dev = [
"pylint>=2.6.0",
"pyink>=24.3.0",
"pytest>=8.2.2",
"ipython",
"jupyter",
]
Expand Down
11 changes: 6 additions & 5 deletions run_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
# limitations under the License.
"""Entry point executable to run all tests."""

import sys
from absl.testing import absltest
import subprocess

if __name__ == "__main__":
absltest.main(
module=None,
argv=[sys.argv[0], "discover", "-s", "tests", "-p", "*_test.py"],
subprocess.check_call(
["python", "-m", "pytest", "tests", "-k", "not ShardingUtilTest"]
)
subprocess.check_call(
["python", "-m", "pytest", "tests", "-k", "ShardingUtilTest"]
)

0 comments on commit f38fb29

Please sign in to comment.