Skip to content

Commit 463c02d

Browse files
authored
Fix quadrature rule hash (#132)
* Make hashing of quadrature rules safe * Also remove some uses of deprecated `@abstractproperty` and gives cells and point sets `repr()`s. * Add safe_repr function to handle floating point types
1 parent 8839f87 commit 463c02d

File tree

6 files changed

+147
-12
lines changed

6 files changed

+147
-12
lines changed

FIAT/reference_element.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from math import factorial
2727

2828
import numpy
29+
from gem.utils import safe_repr
2930
from recursivenodes.nodes import _decode_family, _recursive
3031

3132
from FIAT.orientation_utils import (
@@ -126,7 +127,7 @@ def linalg_subspace_intersection(A, B):
126127
return U[:, :rank_c]
127128

128129

129-
class Cell(object):
130+
class Cell:
130131
"""Abstract class for a reference cell. Provides accessors for
131132
geometry (vertex coordinates) as well as topology (orderings of
132133
vertices that make up edges, faces, etc."""
@@ -184,6 +185,9 @@ def __init__(self, shape, vertices, topology):
184185
# Dictionary with derived cells
185186
self._split_cache = {}
186187

188+
def __repr__(self):
189+
return f"{type(self).__name__}({self.shape!r}, {safe_repr(self.vertices)}, {self.topology!r})"
190+
187191
def _key(self):
188192
"""Hashable object key data (excluding type)."""
189193
# Default: only type matters
@@ -1130,6 +1134,9 @@ def __init__(self, *cells):
11301134
super().__init__(TENSORPRODUCT, vertices, topology)
11311135
self.cells = tuple(cells)
11321136

1137+
def __repr__(self):
1138+
return f"{type(self).__name__}({self.cells!r})"
1139+
11331140
def _key(self):
11341141
return self.cells
11351142

finat/point_set.py

+29-6
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,31 @@
1-
from abc import ABCMeta, abstractproperty
1+
import abc
2+
import hashlib
3+
from functools import cached_property
24
from itertools import chain, product
35

46
import numpy
57

68
import gem
7-
from gem.utils import cached_property
9+
from gem.utils import safe_repr
810

911

10-
class AbstractPointSet(metaclass=ABCMeta):
12+
class AbstractPointSet(abc.ABC):
1113
"""A way of specifying a known set of points, perhaps with some
1214
(tensor) structure.
1315
1416
Points, when stored, have shape point_set_shape + (point_dimension,)
1517
where point_set_shape is () for scalar, (N,) for N element vector,
1618
(N, M) for N x M matrix etc.
1719
"""
20+
def __hash__(self):
21+
return int.from_bytes(hashlib.md5(repr(self).encode()).digest(), byteorder="big")
1822

19-
@abstractproperty
23+
@abc.abstractmethod
24+
def __repr__(self):
25+
pass
26+
27+
@property
28+
@abc.abstractmethod
2029
def points(self):
2130
"""A flattened numpy array of points or ``UnknownPointsArray``
2231
object with shape (# of points, point dimension)."""
@@ -27,12 +36,14 @@ def dimension(self):
2736
_, dim = self.points.shape
2837
return dim
2938

30-
@abstractproperty
39+
@property
40+
@abc.abstractmethod
3141
def indices(self):
3242
"""GEM indices with matching shape and extent to the structure of the
3343
point set."""
3444

35-
@abstractproperty
45+
@property
46+
@abc.abstractmethod
3647
def expression(self):
3748
"""GEM expression describing the points, with free indices
3849
``self.indices`` and shape (point dimension,)."""
@@ -53,6 +64,9 @@ def __init__(self, point):
5364
assert len(point.shape) == 1
5465
self.point = point
5566

67+
def __repr__(self):
68+
return f"{type(self).__name__}({safe_repr(self.point)})"
69+
5670
@cached_property
5771
def points(self):
5872
# Make sure we conform to the expected (# of points, point dimension)
@@ -106,6 +120,9 @@ def __init__(self, points_expr):
106120
assert len(points_expr.shape) == 2
107121
self._points_expr = points_expr
108122

123+
def __repr__(self):
124+
return f"{type(self).__name__}({self._points_expr!r})"
125+
109126
@cached_property
110127
def points(self):
111128
return UnknownPointsArray(self._points_expr.shape)
@@ -133,6 +150,9 @@ def __init__(self, points):
133150
assert len(points.shape) == 2
134151
self.points = points
135152

153+
def __repr__(self):
154+
return f"{type(self).__name__}({self.points!r})"
155+
136156
@cached_property
137157
def points(self):
138158
pass # set at initialisation
@@ -177,6 +197,9 @@ class TensorPointSet(AbstractPointSet):
177197
def __init__(self, factors):
178198
self.factors = tuple(factors)
179199

200+
def __repr__(self):
201+
return f"{type(self).__name__}({self.factors!r})"
202+
180203
@cached_property
181204
def points(self):
182205
return numpy.array([list(chain(*pt_tuple))

finat/quadrature.py

+31-5
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1-
from abc import ABCMeta, abstractproperty
2-
from functools import reduce
1+
import hashlib
2+
from abc import ABCMeta, abstractmethod
3+
from functools import cached_property, reduce
34

45
import gem
56
import numpy
67
from FIAT.quadrature import GaussLegendreQuadratureLineRule
78
from FIAT.quadrature_schemes import create_quadrature as fiat_scheme
89
from FIAT.reference_element import LINE, QUADRILATERAL, TENSORPRODUCT
9-
from gem.utils import cached_property
10+
from gem.utils import safe_repr
1011

1112
from finat.point_set import GaussLegendrePointSet, PointSet, TensorPointSet
1213

@@ -60,11 +61,23 @@ class AbstractQuadratureRule(metaclass=ABCMeta):
6061
"""Abstract class representing a quadrature rule as point set and a
6162
corresponding set of weights."""
6263

63-
@abstractproperty
64+
def __hash__(self):
65+
return int.from_bytes(hashlib.md5(repr(self).encode()).digest(), byteorder="big")
66+
67+
def __eq__(self, other):
68+
return type(other) is type(self) and repr(other) == repr(self)
69+
70+
@abstractmethod
71+
def __repr__(self):
72+
pass
73+
74+
@property
75+
@abstractmethod
6476
def point_set(self):
6577
"""Point set object representing the quadrature points."""
6678

67-
@abstractproperty
79+
@property
80+
@abstractmethod
6881
def weight_expression(self):
6982
"""GEM expression describing the weights, with the same free indices
7083
as the point set."""
@@ -110,6 +123,16 @@ def __init__(self, point_set, weights, ref_el=None, io_ornt_map_tuple=(None, )):
110123
self.weights = numpy.asarray(weights)
111124
self._intrinsic_orientation_permutation_map_tuple = io_ornt_map_tuple
112125

126+
def __repr__(self):
127+
return (
128+
f"{type(self).__name__}("
129+
f"{self.point_set!r}, "
130+
f"{safe_repr(self.weights)}, "
131+
f"{self.ref_el!r}, "
132+
f"{self._intrinsic_orientation_permutation_map_tuple!r}"
133+
")"
134+
)
135+
113136
@cached_property
114137
def point_set(self):
115138
pass # set at initialisation
@@ -131,6 +154,9 @@ def __init__(self, factors, ref_el=None):
131154
for m in factor._intrinsic_orientation_permutation_map_tuple
132155
)
133156

157+
def __repr__(self):
158+
return f"{type(self).__name__}({self.factors!r}, {self.ref_el!r})"
159+
134160
@cached_property
135161
def point_set(self):
136162
return TensorPointSet(q.point_set for q in self.factors)

gem/utils.py

+55
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
import collections
2+
import functools
3+
import numbers
24
from functools import cached_property # noqa: F401
5+
from typing import Any
6+
7+
import numpy as np
38

49

510
def groupby(iterable, key=None):
@@ -88,3 +93,53 @@ def __exit__(self, exc_type, exc_value, traceback):
8893
assert self.state is variable._head
8994
value, variable._head = variable._head
9095
self.state = None
96+
97+
98+
@functools.singledispatch
99+
def safe_repr(obj: Any) -> str:
100+
"""Return a 'safe' repr for an object, accounting for floating point error.
101+
102+
Parameters
103+
----------
104+
obj :
105+
The object to produce a repr for.
106+
107+
Returns
108+
-------
109+
str :
110+
A repr for the object.
111+
112+
"""
113+
raise TypeError(f"Cannot provide a safe repr for {type(obj).__name__}")
114+
115+
116+
@safe_repr.register(str)
117+
def _(text: str) -> str:
118+
return text
119+
120+
121+
@safe_repr.register(numbers.Integral)
122+
def _(num: numbers.Integral) -> str:
123+
return repr(num)
124+
125+
126+
@safe_repr.register(numbers.Real)
127+
def _(num: numbers.Real) -> str:
128+
# set roundoff to close-to-but-not-exactly machine epsilon
129+
precision = np.finfo(num).precision - 2
130+
return "{:.{prec}}".format(num, prec=precision)
131+
132+
133+
@safe_repr.register(np.ndarray)
134+
def _(array: np.ndarray) -> str:
135+
return f"{type(array).__name__}([{', '.join(map(safe_repr, array))}])"
136+
137+
138+
@safe_repr.register(list)
139+
def _(list_: list) -> str:
140+
return f"[{', '.join(map(safe_repr, list_))}]"
141+
142+
143+
@safe_repr.register(tuple)
144+
def _(tuple_: tuple) -> str:
145+
return f"({', '.join(map(safe_repr, tuple_))})"

test/finat/test_create_fiat_element.py

+5
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@ def tensor_name(request):
5959
ids=lambda x: x.cellname(),
6060
scope="module")
6161
def ufl_A(request, tensor_name):
62+
if request.param == ufl.quadrilateral:
63+
if tensor_name == "DG":
64+
tensor_name = "DQ"
65+
elif tensor_name == "DG L2":
66+
tensor_name = "DQ L2"
6267
return finat.ufl.FiniteElement(tensor_name, request.param, 1)
6368

6469

test/finat/test_quadrature.py

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import pytest
2+
3+
from FIAT import ufc_cell
4+
from finat.quadrature import make_quadrature
5+
6+
7+
@pytest.mark.parametrize(
8+
"cell_name",
9+
["interval", "triangle", "interval * interval", "triangle * interval"]
10+
)
11+
def test_quadrature_rules_are_hashable(cell_name):
12+
ref_cell = ufc_cell(cell_name)
13+
quadrature1 = make_quadrature(ref_cell, 3)
14+
quadrature2 = make_quadrature(ref_cell, 3)
15+
16+
assert quadrature1 is not quadrature2
17+
assert hash(quadrature1) == hash(quadrature2)
18+
assert repr(quadrature1) == repr(quadrature2)
19+
assert quadrature1 == quadrature2

0 commit comments

Comments
 (0)