Skip to content

Commit e7b2909

Browse files
authored
Merge pull request #61 from firedrakeproject/pbrubeck/cleanup-dualset
Tidy up DualSet
2 parents a306799 + 28935d9 commit e7b2909

9 files changed

+299
-482
lines changed

FIAT/brezzi_douglas_marini.py

+37-57
Original file line numberDiff line numberDiff line change
@@ -9,81 +9,61 @@
99
polynomial_set, nedelec)
1010
from FIAT.check_format_variant import check_format_variant
1111
from FIAT.quadrature_schemes import create_quadrature
12+
from FIAT.quadrature import FacetQuadratureRule
1213

1314

1415
class BDMDualSet(dual_set.DualSet):
1516
def __init__(self, ref_el, degree, variant, interpolant_deg):
16-
17-
# Initialize containers for map: mesh_entity -> dof number and
18-
# dual basis
19-
entity_ids = {}
2017
nodes = []
21-
2218
sd = ref_el.get_spatial_dimension()
23-
t = ref_el.get_topology()
19+
top = ref_el.get_topology()
20+
21+
entity_ids = {}
22+
# set to empty
23+
for dim in top:
24+
entity_ids[dim] = {}
25+
for entity in top[dim]:
26+
entity_ids[dim][entity] = []
2427

2528
if variant == "integral":
2629
facet = ref_el.get_facet_element()
27-
# Facet nodes are \int_F v\cdot n p ds where p \in P_{q-1}
28-
# degree is q - 1
29-
Q = create_quadrature(facet, interpolant_deg + degree)
30+
# Facet nodes are \int_F v\cdot n p ds where p \in P_{q}
31+
# degree is q
32+
Q_ref = create_quadrature(facet, interpolant_deg + degree)
3033
Pq = polynomial_set.ONPolynomialSet(facet, degree)
31-
Pq_at_qpts = Pq.tabulate(Q.get_points())[(0,)*(sd - 1)]
32-
nodes.extend(functional.IntegralMomentOfScaledNormalEvaluation(ref_el, Q, phi, f)
33-
for f in range(len(t[sd - 1]))
34-
for phi in Pq_at_qpts)
35-
36-
# internal nodes
37-
if degree > 1:
38-
Q = create_quadrature(ref_el, interpolant_deg + degree - 1)
39-
qpts = Q.get_points()
40-
Nedel = nedelec.Nedelec(ref_el, degree - 1, variant)
41-
Nedfs = Nedel.get_nodal_basis()
42-
Ned_at_qpts = Nedfs.tabulate(qpts)[(0,) * sd]
34+
Pq_at_qpts = Pq.tabulate(Q_ref.get_points())[(0,)*(sd - 1)]
35+
for f in top[sd - 1]:
36+
cur = len(nodes)
37+
Q = FacetQuadratureRule(ref_el, sd - 1, f, Q_ref)
38+
Jdet = Q.jacobian_determinant()
39+
n = ref_el.compute_scaled_normal(f) / Jdet
40+
phis = n[None, :, None] * Pq_at_qpts[:, None, :]
4341
nodes.extend(functional.FrobeniusIntegralMoment(ref_el, Q, phi)
44-
for phi in Ned_at_qpts)
42+
for phi in phis)
43+
entity_ids[sd - 1][f] = list(range(cur, len(nodes)))
4544

4645
elif variant == "point":
4746
# Define each functional for the dual set
4847
# codimension 1 facets
49-
for i in range(len(t[sd - 1])):
50-
pts_cur = ref_el.make_points(sd - 1, i, sd + degree)
51-
nodes.extend(functional.PointScaledNormalEvaluation(ref_el, i, pt)
48+
for f in top[sd - 1]:
49+
cur = len(nodes)
50+
pts_cur = ref_el.make_points(sd - 1, f, sd + degree)
51+
nodes.extend(functional.PointScaledNormalEvaluation(ref_el, f, pt)
5252
for pt in pts_cur)
53+
entity_ids[sd - 1][f] = list(range(cur, len(nodes)))
5354

54-
# internal nodes
55-
if degree > 1:
56-
Q = create_quadrature(ref_el, 2 * degree - 1)
57-
qpts = Q.get_points()
58-
Nedel = nedelec.Nedelec(ref_el, degree - 1, variant)
59-
Nedfs = Nedel.get_nodal_basis()
60-
Ned_at_qpts = Nedfs.tabulate(qpts)[(0,) * sd]
61-
nodes.extend(functional.FrobeniusIntegralMoment(ref_el, Q, phi)
62-
for phi in Ned_at_qpts)
63-
64-
# sets vertices (and in 3d, edges) to have no nodes
65-
for i in range(sd - 1):
66-
entity_ids[i] = {}
67-
for j in range(len(t[i])):
68-
entity_ids[i][j] = []
69-
70-
cur = 0
71-
72-
# set codimension 1 (edges 2d, faces 3d) dof
73-
pts_facet_0 = ref_el.make_points(sd - 1, 0, sd + degree)
74-
pts_per_facet = len(pts_facet_0)
75-
76-
entity_ids[sd - 1] = {}
77-
for i in range(len(t[sd - 1])):
78-
entity_ids[sd - 1][i] = list(range(cur, cur + pts_per_facet))
79-
cur += pts_per_facet
80-
81-
# internal nodes, if applicable
82-
entity_ids[sd] = {0: []}
83-
55+
# internal nodes
8456
if degree > 1:
85-
num_internal_nodes = len(Ned_at_qpts)
86-
entity_ids[sd][0] = list(range(cur, cur + num_internal_nodes))
57+
if interpolant_deg is None:
58+
interpolant_deg = degree
59+
cur = len(nodes)
60+
Q = create_quadrature(ref_el, interpolant_deg + degree - 1)
61+
Nedel = nedelec.Nedelec(ref_el, degree - 1, variant)
62+
Nedfs = Nedel.get_nodal_basis()
63+
Ned_at_qpts = Nedfs.tabulate(Q.get_points())[(0,) * sd]
64+
nodes.extend(functional.FrobeniusIntegralMoment(ref_el, Q, phi)
65+
for phi in Ned_at_qpts)
66+
entity_ids[sd][0] = list(range(cur, len(nodes)))
8767

8868
super(BDMDualSet, self).__init__(nodes, ref_el, entity_ids)
8969

FIAT/dual_set.py

+60-46
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,8 @@
77
# SPDX-License-Identifier: LGPL-3.0-or-later
88

99
import numpy
10-
import collections
1110

12-
from FIAT import polynomial_set
11+
from FIAT import polynomial_set, functional
1312

1413

1514
class DualSet(object):
@@ -105,63 +104,78 @@ def to_riesz(self, poly_set):
105104
ed = poly_set.get_embedded_degree()
106105
num_exp = es.get_num_members(poly_set.get_embedded_degree())
107106

108-
riesz_shape = tuple([num_nodes] + list(tshape) + [num_exp])
109-
107+
riesz_shape = (num_nodes, *tshape, num_exp)
110108
mat = numpy.zeros(riesz_shape, "d")
111109

112-
# Dictionaries mapping pts to which functionals they come from
113-
pts_to_ells = collections.OrderedDict()
114-
dpts_to_ells = collections.OrderedDict()
115-
110+
pts = set()
111+
dpts = set()
112+
Qs_to_ells = dict()
116113
for i, ell in enumerate(self.nodes):
117-
for pt in ell.pt_dict:
118-
if pt in pts_to_ells:
119-
pts_to_ells[pt].append(i)
120-
else:
121-
pts_to_ells[pt] = [i]
122-
123-
for pt in ell.deriv_dict:
124-
if pt in dpts_to_ells:
125-
dpts_to_ells[pt].append(i)
126-
else:
127-
dpts_to_ells[pt] = [i]
114+
if isinstance(ell, functional.IntegralMoment):
115+
Q = ell.Q
116+
else:
117+
Q = None
118+
pts.update(ell.pt_dict.keys())
119+
dpts.update(ell.deriv_dict.keys())
120+
if Q in Qs_to_ells:
121+
Qs_to_ells[Q].append(i)
122+
else:
123+
Qs_to_ells[Q] = [i]
124+
125+
Qs_to_pts = {None: tuple(sorted(pts))}
126+
for Q in Qs_to_ells:
127+
if Q is not None:
128+
cur_pts = tuple(map(tuple, Q.pts))
129+
Qs_to_pts[Q] = cur_pts
130+
pts.update(cur_pts)
128131

129132
# Now tabulate the function values
130-
pts = list(pts_to_ells.keys())
131-
expansion_values = es.tabulate(ed, pts)
132-
133-
for j, pt in enumerate(pts):
134-
which_ells = pts_to_ells[pt]
135-
136-
for k in which_ells:
137-
pt_dict = self.nodes[k].pt_dict
138-
wc_list = pt_dict[pt]
139-
140-
for i in range(num_exp):
141-
for (w, c) in wc_list:
142-
mat[k][c][i] += w*expansion_values[i, j]
133+
pts = list(sorted(pts))
134+
expansion_values = numpy.transpose(es.tabulate(ed, pts))
135+
136+
for Q in Qs_to_ells:
137+
ells = Qs_to_ells[Q]
138+
cur_pts = Qs_to_pts[Q]
139+
indices = list(map(pts.index, cur_pts))
140+
wshape = (len(ells), *tshape, len(cur_pts))
141+
wts = numpy.zeros(wshape, "d")
142+
if Q is None:
143+
for i, k in enumerate(ells):
144+
ell = self.nodes[k]
145+
for pt, wc_list in ell.pt_dict.items():
146+
j = cur_pts.index(pt)
147+
for (w, c) in wc_list:
148+
wts[i][c][j] = w
149+
else:
150+
for i, k in enumerate(ells):
151+
ell = self.nodes[k]
152+
wts[i][ell.comp][:] = ell.f_at_qpts
153+
qwts = Q.get_weights()
154+
wts = numpy.multiply(wts, qwts, out=wts)
155+
mat[ells] += numpy.dot(wts, expansion_values[indices])
143156

144157
# Tabulate the derivative values that are needed
145-
max_deriv_order = max([ell.max_deriv_order for ell in self.nodes])
158+
max_deriv_order = max(ell.max_deriv_order for ell in self.nodes)
146159
if max_deriv_order > 0:
147-
dpts = list(dpts_to_ells.keys())
160+
dpts = list(sorted(dpts))
148161
# It's easiest/most efficient to get derivatives of the
149162
# expansion set through the polynomial set interface.
150163
# This is creating a short-lived set to do just this.
151-
expansion = polynomial_set.ONPolynomialSet(self.ref_el, ed)
164+
coeffs = numpy.eye(num_exp)
165+
expansion = polynomial_set.PolynomialSet(self.ref_el, ed, ed, es, coeffs)
152166
dexpansion_values = expansion.tabulate(dpts, max_deriv_order)
153167

154-
for j, pt in enumerate(dpts):
155-
which_ells = dpts_to_ells[pt]
156-
157-
for k in which_ells:
158-
dpt_dict = self.nodes[k].deriv_dict
159-
wac_list = dpt_dict[pt]
160-
161-
for i in range(num_exp):
162-
for (w, alpha, c) in wac_list:
163-
mat[k][c][i] += w*dexpansion_values[alpha][i, j]
164-
168+
ells = [k for k, ell in enumerate(self.nodes) if len(ell.deriv_dict) > 0]
169+
wshape = (len(ells), *tshape, len(dpts))
170+
dwts = {alpha: numpy.zeros(wshape, "d") for alpha in dexpansion_values if sum(alpha) > 0}
171+
for i, k in enumerate(ells):
172+
ell = self.nodes[k]
173+
for pt, wac_list in ell.deriv_dict.items():
174+
j = dpts.index(pt)
175+
for (w, alpha, c) in wac_list:
176+
dwts[alpha][i][c][j] = w
177+
for alpha in dwts:
178+
mat[ells] += numpy.dot(dwts[alpha], dexpansion_values[alpha].T)
165179
return mat
166180

167181

FIAT/functional.py

+14-25
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,7 @@ def index_iterator(shp):
2626
"""Constructs a generator iterating over all indices in
2727
shp in generalized column-major order So if shp = (2,2), then we
2828
construct the sequence (0,0),(0,1),(1,0),(1,1)"""
29-
if len(shp) == 0:
30-
return
31-
elif len(shp) == 1:
32-
for i in range(shp[0]):
33-
yield [i]
34-
else:
35-
shp_foo = shp[1:]
36-
for i in range(shp[0]):
37-
for foo in index_iterator(shp_foo):
38-
yield [i] + foo
29+
return numpy.ndindex(shp)
3930

4031

4132
class Functional(object):
@@ -292,12 +283,11 @@ class IntegralMoment(Functional):
292283

293284
def __init__(self, ref_el, Q, f_at_qpts, comp=tuple(), shp=tuple()):
294285
self.Q = Q
286+
self.f_at_qpts = f_at_qpts
295287
qpts, qwts = Q.get_points(), Q.get_weights()
296-
pt_dict = OrderedDict()
297288
self.comp = comp
298-
for i in range(len(qpts)):
299-
pt_cur = tuple(qpts[i])
300-
pt_dict[pt_cur] = [(qwts[i] * f_at_qpts[i], comp)]
289+
weights = numpy.multiply(f_at_qpts, qwts)
290+
pt_dict = {tuple(pt): [(wt, comp)] for pt, wt in zip(qpts, weights)}
301291
Functional.__init__(self, ref_el, shp, pt_dict, {}, "IntegralMoment")
302292

303293
def __call__(self, fn):
@@ -331,7 +321,7 @@ def __init__(self, ref_el, facet_no, Q, f_at_qpts):
331321

332322
dpt_dict = OrderedDict()
333323

334-
alphas = [tuple([1 if j == i else 0 for j in range(sd)]) for i in range(sd)]
324+
alphas = [tuple(1 if j == i else 0 for j in range(sd)) for i in range(sd)]
335325
for j, pt in enumerate(dpts):
336326
dpt_dict[tuple(pt)] = [(qwts[j]*n[i]*f_at_qpts[j], alphas[i], tuple()) for i in range(sd)]
337327

@@ -484,24 +474,23 @@ def __init__(self, ref_el, Q, f_at_qpts):
484474
"IntegralMomentOfDivergence")
485475

486476

487-
class FrobeniusIntegralMoment(Functional):
477+
class FrobeniusIntegralMoment(IntegralMoment):
488478

489479
def __init__(self, ref_el, Q, f_at_qpts):
490480
# f_at_qpts is (some shape) x num_qpts
491481
shp = tuple(f_at_qpts.shape[:-1])
492-
if len(Q.get_points()) != f_at_qpts.shape[-1]:
482+
if len(Q.pts) != f_at_qpts.shape[-1]:
493483
raise Exception("Mismatch in number of quadrature points and values")
494484

485+
self.Q = Q
486+
self.comp = slice(None, None)
487+
self.f_at_qpts = f_at_qpts
495488
qpts, qwts = Q.get_points(), Q.get_weights()
496-
pt_dict = {}
497-
498-
for i, (pt_cur, wt_cur) in enumerate(zip(map(tuple, qpts), qwts)):
499-
pt_dict[pt_cur] = []
500-
for alfa in index_iterator(shp):
501-
qpidx = tuple(alfa + [i])
502-
pt_dict[pt_cur].append((wt_cur * f_at_qpts[qpidx], tuple(alfa)))
489+
weights = numpy.transpose(numpy.multiply(f_at_qpts, qwts), (-1,) + tuple(range(len(shp))))
490+
alphas = list(index_iterator(shp))
503491

504-
super().__init__(ref_el, shp, pt_dict, {}, "FrobeniusIntegralMoment")
492+
pt_dict = {tuple(pt): [(wt[alpha], alpha) for alpha in alphas] for pt, wt in zip(qpts, weights)}
493+
Functional.__init__(self, ref_el, shp, pt_dict, {}, "FrobeniusIntegralMoment")
505494

506495

507496
class PointNormalEvaluation(Functional):

0 commit comments

Comments
 (0)