Skip to content

Commit fa86ed3

Browse files
authored
Fix RaviartThomas on the interval (#71)
1 parent 8e3c7c7 commit fa86ed3

File tree

4 files changed

+36
-17
lines changed

4 files changed

+36
-17
lines changed

FIAT/expansions.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,8 @@ def get_cell_node_map(self, n):
309309
return self._cell_node_map_cache.setdefault(n, cell_node_map)
310310

311311
def _tabulate_on_cell(self, n, pts, order=0, cell=0, direction=None):
312+
"""Returns a dict of tabulations such that
313+
tabulations[alpha][i, j] = D^alpha phi_i(pts[j])."""
312314
from FIAT.polynomial_set import mis
313315
lorder = min(order, self.recurrence_order)
314316
A, b = self.affine_mappings[cell]
@@ -468,10 +470,11 @@ def __init__(self, ref_el, **kwargs):
468470
raise ValueError("Must have a point")
469471
super(PointExpansionSet, self).__init__(ref_el, **kwargs)
470472

471-
def tabulate(self, n, pts):
472-
"""Returns a numpy array A[i,j] = phi_i(pts[j]) = 1.0."""
473-
assert n == 0
474-
return numpy.ones((1, len(pts)))
473+
def _tabulate_on_cell(self, n, pts, order=0, cell=0, direction=None):
474+
"""Returns a dict of tabulations such that
475+
tabulations[alpha][i, j] = D^alpha phi_i(pts[j])."""
476+
assert n == 0 and order == 0
477+
return {(): numpy.ones((1, len(pts)))}
475478

476479

477480
class LineExpansionSet(ExpansionSet):
@@ -482,8 +485,8 @@ def __init__(self, ref_el, **kwargs):
482485
super(LineExpansionSet, self).__init__(ref_el, **kwargs)
483486

484487
def _tabulate_on_cell(self, n, pts, order=0, cell=0, direction=None):
485-
"""Returns a tuple of (vals, derivs) such that
486-
vals[i,j] = phi_i(pts[j]), derivs[i,j] = D vals[i,j]."""
488+
"""Returns a dict of tabulations such that
489+
tabulations[alpha][i, j] = D^alpha phi_i(pts[j])."""
487490
if self.variant is not None:
488491
return super(LineExpansionSet, self)._tabulate_on_cell(n, pts, order=order, cell=cell, direction=direction)
489492

FIAT/quadrature.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,12 @@ def map_quadrature(pts_ref, wts_ref, source_cell, target_cell, jacobian=False):
2424
A, b = reference_element.make_affine_mapping(source_cell.get_vertices(),
2525
target_cell.get_vertices())
2626
scale = pseudo_determinant(A)
27-
pts = numpy.dot(pts_ref.reshape((-1, A.shape[1])), A.T) + b[None, :]
2827
wts = scale * wts_ref
28+
if pts_ref.size == 0:
29+
pts = b[None, :]
30+
else:
31+
pts = numpy.dot(pts_ref.reshape((-1, A.shape[1])), A.T) + b[None, :]
32+
2933
# return immutable types
3034
pts = tuple(map(tuple, pts))
3135
wts = tuple(wts.flat)

FIAT/raviart_thomas.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,10 @@ def __init__(self, ref_el, degree, variant, interpolant_deg):
7272

7373
if variant == "integral":
7474
facet = ref_el.get_facet_element()
75-
# Facet nodes are \int_F v\cdot n p ds where p \in P_{q-1}
76-
# degree is q - 1
77-
Q_ref = create_quadrature(facet, interpolant_deg + degree - 1)
78-
Pq = polynomial_set.ONPolynomialSet(facet, degree - 1)
75+
# Facet nodes are \int_F v\cdot n p ds where p \in P_q
76+
q = degree - 1
77+
Q_ref = create_quadrature(facet, interpolant_deg + q)
78+
Pq = polynomial_set.ONPolynomialSet(facet, q if sd > 1 else 0)
7979
Pq_at_qpts = Pq.tabulate(Q_ref.get_points())[(0,)*(sd - 1)]
8080
for f in top[sd - 1]:
8181
cur = len(nodes)
@@ -87,15 +87,15 @@ def __init__(self, ref_el, degree, variant, interpolant_deg):
8787
for phi in phis)
8888
entity_ids[sd - 1][f] = list(range(cur, len(nodes)))
8989

90-
# internal nodes. These are \int_T v \cdot p dx where p \in P_{q-2}^d
91-
if degree > 1:
90+
# internal nodes. These are \int_T v \cdot p dx where p \in P_{q-1}^d
91+
if q > 0:
9292
cur = len(nodes)
93-
Q = create_quadrature(ref_el, interpolant_deg + degree - 2)
94-
Pkm1 = polynomial_set.ONPolynomialSet(ref_el, degree - 2)
95-
Pkm1_at_qpts = Pkm1.tabulate(Q.get_points())[(0,) * sd]
93+
Q = create_quadrature(ref_el, interpolant_deg + q - 1)
94+
Pqm1 = polynomial_set.ONPolynomialSet(ref_el, q - 1)
95+
Pqm1_at_qpts = Pqm1.tabulate(Q.get_points())[(0,) * sd]
9696
nodes.extend(functional.IntegralMoment(ref_el, Q, phi, (d,), (sd,))
9797
for d in range(sd)
98-
for phi in Pkm1_at_qpts)
98+
for phi in Pqm1_at_qpts)
9999
entity_ids[sd][0] = list(range(cur, len(nodes)))
100100

101101
elif variant == "point":

test/unit/test_fiat.py

+12
Original file line numberDiff line numberDiff line change
@@ -132,24 +132,36 @@ def __init__(self, a, b):
132132
"CrouzeixRaviart(I, 1)",
133133
"CrouzeixRaviart(T, 1)",
134134
"CrouzeixRaviart(S, 1)",
135+
"RaviartThomas(I, 1)",
136+
"RaviartThomas(I, 2)",
137+
"RaviartThomas(I, 3)",
135138
"RaviartThomas(T, 1)",
136139
"RaviartThomas(T, 2)",
137140
"RaviartThomas(T, 3)",
138141
"RaviartThomas(S, 1)",
139142
"RaviartThomas(S, 2)",
140143
"RaviartThomas(S, 3)",
144+
'RaviartThomas(I, 1, variant="integral")',
145+
'RaviartThomas(I, 2, variant="integral")',
146+
'RaviartThomas(I, 3, variant="integral")',
141147
'RaviartThomas(T, 1, variant="integral")',
142148
'RaviartThomas(T, 2, variant="integral")',
143149
'RaviartThomas(T, 3, variant="integral")',
144150
'RaviartThomas(S, 1, variant="integral")',
145151
'RaviartThomas(S, 2, variant="integral")',
146152
'RaviartThomas(S, 3, variant="integral")',
153+
'RaviartThomas(I, 1, variant="integral(1)")',
154+
'RaviartThomas(I, 2, variant="integral(1)")',
155+
'RaviartThomas(I, 3, variant="integral(1)")',
147156
'RaviartThomas(T, 1, variant="integral(1)")',
148157
'RaviartThomas(T, 2, variant="integral(1)")',
149158
'RaviartThomas(T, 3, variant="integral(1)")',
150159
'RaviartThomas(S, 1, variant="integral(1)")',
151160
'RaviartThomas(S, 2, variant="integral(1)")',
152161
'RaviartThomas(S, 3, variant="integral(1)")',
162+
'RaviartThomas(I, 1, variant="point")',
163+
'RaviartThomas(I, 2, variant="point")',
164+
'RaviartThomas(I, 3, variant="point")',
153165
'RaviartThomas(T, 1, variant="point")',
154166
'RaviartThomas(T, 2, variant="point")',
155167
'RaviartThomas(T, 3, variant="point")',

0 commit comments

Comments
 (0)