Skip to content

Commit 49e6bdc

Browse files
authored
more robust tests for HCT/HCT-red (#94)
1 parent 1088bfd commit 49e6bdc

File tree

2 files changed

+46
-7
lines changed

2 files changed

+46
-7
lines changed

test/unit/test_hct.py

+44-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import pytest
22
import numpy
33

4-
from FIAT import HsiehCloughTocher as HCT
5-
from FIAT.reference_element import ufc_simplex, make_lattice
4+
from FIAT import RestrictedElement, HsiehCloughTocher as HCT
5+
from FIAT.reference_element import ufc_simplex
66
from FIAT.functional import PointEvaluation
7+
from FIAT.macro import CkPolynomialSet
78

89

910
@pytest.fixture
@@ -13,12 +14,28 @@ def cell():
1314
return K
1415

1516

17+
def span_greater_equal(A, B):
18+
# span(A) >= span(B)
19+
_, residual, *_ = numpy.linalg.lstsq(A.reshape(A.shape[0], -1).T,
20+
B.reshape(B.shape[0], -1).T)
21+
return numpy.allclose(residual, 0)
22+
23+
24+
def make_points(K, degree):
25+
top = K.get_topology()
26+
pts = []
27+
for dim in top:
28+
for entity in top[dim]:
29+
pts.extend(K.make_points(dim, entity, degree))
30+
return pts
31+
32+
1633
@pytest.mark.parametrize("reduced", (False, True))
1734
def test_hct_constant(cell, reduced):
1835
# Test that bfs associated with point evaluation sum up to 1
1936
fe = HCT(cell, reduced=reduced)
2037

21-
pts = make_lattice(cell.get_vertices(), 3)
38+
pts = make_points(cell, 4)
2239
tab = fe.tabulate(2, pts)
2340

2441
coefs = numpy.zeros((fe.space_dimension(),))
@@ -33,3 +50,27 @@ def test_hct_constant(cell, reduced):
3350
expected = 1 if sum(alpha) == 0 else 0
3451
vals = numpy.dot(coefs, tab[alpha])
3552
assert numpy.allclose(vals, expected)
53+
54+
55+
@pytest.mark.parametrize("reduced", (False, True))
56+
def test_full_polynomials(cell, reduced):
57+
# Test that HCT/HCT-red contains all cubics/quadratics
58+
fe = HCT(cell, reduced=reduced)
59+
if reduced:
60+
fe = RestrictedElement(fe, restriction_domain="vertex")
61+
62+
ref_complex = fe.get_reference_complex()
63+
pts = make_points(ref_complex, 4)
64+
tab = fe.tabulate(0, pts)[(0, 0)]
65+
66+
degree = fe.degree()
67+
if reduced:
68+
degree -= 1
69+
70+
P = CkPolynomialSet(cell, degree, variant="bubble")
71+
P_tab = P.tabulate(pts)[(0, 0)]
72+
assert span_greater_equal(tab, P_tab)
73+
74+
C1 = CkPolynomialSet(ref_complex, degree, order=1, variant="bubble")
75+
C1_tab = C1.tabulate(pts)[(0, 0)]
76+
assert span_greater_equal(tab, C1_tab)

test/unit/test_stokes_complex.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,8 @@ def rAQ(cell):
3535

3636
def span_greater_equal(A, B):
3737
# span(A) >= span(B)
38-
dimA = A.shape[0]
39-
dimB = B.shape[0]
40-
_, residual, *_ = numpy.linalg.lstsq(A.reshape(dimA, -1).T,
41-
B.reshape(dimB, -1).T)
38+
_, residual, *_ = numpy.linalg.lstsq(A.reshape(A.shape[0], -1).T,
39+
B.reshape(B.shape[0], -1).T)
4240
return numpy.allclose(residual, 0)
4341

4442

0 commit comments

Comments
 (0)