Skip to content

Commit d0bea63

Browse files
authored
Merge pull request #64 from firedrakeproject/rckirby/feature/macro
Implement C0/C1 macroelements
2 parents dbc1c5d + 23ad19a commit d0bea63

25 files changed

+1712
-457
lines changed

FIAT/P0.py

+11-17
Original file line numberDiff line numberDiff line change
@@ -17,34 +17,28 @@
1717
class P0Dual(dual_set.DualSet):
1818
def __init__(self, ref_el):
1919
entity_ids = {}
20-
nodes = []
2120
entity_permutations = {}
22-
vs = numpy.array(ref_el.get_vertices())
23-
if ref_el.get_dimension() == 0:
24-
bary = ()
25-
else:
26-
bary = tuple(numpy.average(vs, 0))
27-
28-
nodes = [functional.PointEvaluation(ref_el, bary)]
29-
entity_ids = {}
21+
sd = ref_el.get_dimension()
3022
top = ref_el.get_topology()
23+
if sd == 0:
24+
pts = [tuple() for entity in sorted(top[sd])]
25+
else:
26+
pts = [tuple(numpy.average(ref_el.get_vertices_of_subcomplex(top[sd][entity]), 0))
27+
for entity in sorted(top[sd])]
28+
nodes = [functional.PointEvaluation(ref_el, pt) for pt in pts]
3129
for dim in sorted(top):
3230
entity_ids[dim] = {}
3331
entity_permutations[dim] = {}
3432
sym_size = ref_el.symmetry_group_size(dim)
33+
num_points = 1 if dim == sd else 0
3534
if isinstance(dim, tuple):
3635
assert isinstance(sym_size, tuple)
37-
perms = {o: [] for o in numpy.ndindex(sym_size)}
36+
perms = {o: list(range(num_points)) for o in numpy.ndindex(sym_size)}
3837
else:
39-
perms = {o: [] for o in range(sym_size)}
38+
perms = {o: list(range(num_points)) for o in range(sym_size)}
4039
for entity in sorted(top[dim]):
41-
entity_ids[dim][entity] = []
40+
entity_ids[dim][entity] = [entity] if dim == sd else []
4241
entity_permutations[dim][entity] = perms
43-
entity_ids[dim] = {0: [0]}
44-
if isinstance(dim, tuple):
45-
entity_permutations[dim][0] = {o: [0] for o in numpy.ndindex(sym_size)}
46-
else:
47-
entity_permutations[dim][0] = {o: [0] for o in range(sym_size)}
4842

4943
super(P0Dual, self).__init__(nodes, ref_el, entity_ids, entity_permutations)
5044

FIAT/__init__.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
# Import finite element classes
88
from FIAT.finite_element import FiniteElement, CiarletElement # noqa: F401
99
from FIAT.argyris import Argyris
10+
from FIAT.hct import HsiehCloughTocher
1011
from FIAT.bernstein import Bernstein
1112
from FIAT.bell import Bell
1213
from FIAT.argyris import QuinticArgyris
@@ -30,6 +31,7 @@
3031
from FIAT.morley import Morley
3132
from FIAT.nedelec import Nedelec
3233
from FIAT.nedelec_second_kind import NedelecSecondKind
34+
from FIAT.hierarchical import Legendre, IntegratedLegendre
3335
from FIAT.P0 import P0
3436
from FIAT.raviart_thomas import RaviartThomas
3537
from FIAT.crouzeix_raviart import CrouzeixRaviart
@@ -48,7 +50,6 @@
4850
from FIAT.restricted import RestrictedElement # noqa: F401
4951
from FIAT.quadrature_element import QuadratureElement # noqa: F401
5052
from FIAT.kong_mulder_veldhuizen import KongMulderVeldhuizen # noqa: F401
51-
from FIAT.hierarchical import Legendre, IntegratedLegendre # noqa: F401
5253
from FIAT.fdm_element import FDMLagrange, FDMDiscontinuousLagrange, FDMQuadrature, FDMBrokenH1, FDMBrokenL2, FDMHermite # noqa: F401
5354

5455
# Important functionality
@@ -61,6 +62,7 @@
6162

6263
# List of supported elements and mapping to element classes
6364
supported_elements = {"Argyris": Argyris,
65+
"HsiehCloughTocher": HsiehCloughTocher,
6466
"Bell": Bell,
6567
"Bernstein": Bernstein,
6668
"Brezzi-Douglas-Marini": BrezziDouglasMarini,
@@ -81,6 +83,8 @@
8183
"Gauss-Lobatto-Legendre": GaussLobattoLegendre,
8284
"Gauss-Legendre": GaussLegendre,
8385
"Gauss-Radau": GaussRadau,
86+
"Legendre": Legendre,
87+
"Integrated Legendre": IntegratedLegendre,
8488
"Morley": Morley,
8589
"Nedelec 1st kind H(curl)": Nedelec,
8690
"Nedelec 2nd kind H(curl)": NedelecSecondKind,

FIAT/barycentric_interpolation.py

+65-46
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,40 @@
1111
from FIAT.functional import index_iterator
1212

1313

14+
def get_lagrange_points(nodes):
15+
"""Extract singleton point for each node."""
16+
points = []
17+
for node in nodes:
18+
pt, = node.get_point_dict()
19+
points.append(pt)
20+
return points
21+
22+
23+
def barycentric_interpolation(nodes, wts, dmat, pts, order=0):
24+
"""Evaluates a Lagrange basis on a line reference element
25+
via the second barycentric interpolation formula. See Berrut and Trefethen (2004)
26+
https://doi.org/10.1137/S0036144502417715 Eq. (4.2) & (9.4)
27+
"""
28+
if pts.dtype == object:
29+
from sympy import simplify
30+
sp_simplify = numpy.vectorize(simplify)
31+
else:
32+
sp_simplify = lambda x: x
33+
phi = numpy.add.outer(-nodes, pts.flatten())
34+
with numpy.errstate(divide='ignore', invalid='ignore'):
35+
numpy.reciprocal(phi, out=phi)
36+
numpy.multiply(phi, wts[:, None], out=phi)
37+
numpy.multiply(1.0 / numpy.sum(phi, axis=0), phi, out=phi)
38+
phi[phi != phi] = 1.0
39+
40+
phi = sp_simplify(phi)
41+
results = {(0,): phi}
42+
for r in range(1, order+1):
43+
phi = sp_simplify(numpy.dot(dmat, phi))
44+
results[(r,)] = phi
45+
return results
46+
47+
1448
def make_dmat(x):
1549
"""Returns Lagrange differentiation matrix and barycentric weights
1650
associated with x[j]."""
@@ -24,83 +58,68 @@ def make_dmat(x):
2458

2559

2660
class LagrangeLineExpansionSet(expansions.LineExpansionSet):
27-
"""Evaluates a Lagrange basis on a line reference element
28-
via the second barycentric interpolation formula. See Berrut and Trefethen (2004)
29-
https://doi.org/10.1137/S0036144502417715 Eq. (4.2) & (9.4)
30-
"""
61+
"""Lagrange polynomial expansion set for given points the line."""
3162
def __init__(self, ref_el, pts):
3263
self.points = pts
33-
self.x = numpy.array(pts).flatten()
34-
self.dmat, self.weights = make_dmat(self.x)
64+
self.x = numpy.array(pts, dtype="d").flatten()
65+
self.cell_node_map = expansions.compute_cell_point_map(ref_el, pts, unique=False)
66+
self.dmats = []
67+
self.weights = []
68+
self.nodes = []
69+
for ibfs in self.cell_node_map:
70+
nodes = self.x[ibfs]
71+
dmat, wts = make_dmat(nodes)
72+
self.dmats.append(dmat)
73+
self.weights.append(wts)
74+
self.nodes.append(nodes)
75+
76+
self.degree = max(len(wts) for wts in self.weights)-1
77+
self.recurrence_order = self.degree + 1
3578
super(LagrangeLineExpansionSet, self).__init__(ref_el)
3679

3780
def get_num_members(self, n):
3881
return len(self.points)
3982

83+
def get_cell_node_map(self, n):
84+
return self.cell_node_map
85+
4086
def get_points(self):
4187
return self.points
4288

43-
def get_dmats(self, degree):
44-
return [self.dmat.T]
45-
46-
def tabulate(self, n, pts):
47-
assert n == len(self.points)-1
48-
results = numpy.add.outer(-self.x, numpy.array(pts).flatten())
49-
with numpy.errstate(divide='ignore', invalid='ignore'):
50-
numpy.reciprocal(results, out=results)
51-
numpy.multiply(results, self.weights[:, None], out=results)
52-
numpy.multiply(1.0 / numpy.sum(results, axis=0), results, out=results)
53-
54-
results[results != results] = 1.0
55-
if results.dtype == object:
56-
from sympy import simplify
57-
results = numpy.vectorize(simplify)(results)
58-
return results
59-
60-
def _tabulate(self, n, pts, order=0):
61-
vals = self.tabulate(n, pts)
62-
results = [vals]
63-
for r in range(order):
64-
vals = numpy.dot(self.dmat, vals)
65-
if vals.dtype == object:
66-
from sympy import simplify
67-
vals = numpy.vectorize(simplify)(vals)
68-
results.append(vals)
69-
for r in range(order+1):
70-
shape = results[r].shape
71-
shape = shape[:1] + (1,)*r + shape[1:]
72-
results[r] = numpy.reshape(results[r], shape)
73-
return results
89+
def get_dmats(self, degree, cell=0):
90+
return [self.dmats[cell].T]
91+
92+
def _tabulate_on_cell(self, n, pts, order=0, cell=0, direction=None):
93+
return barycentric_interpolation(self.nodes[cell], self.weights[cell], self.dmats[cell], pts, order=order)
7494

7595

7696
class LagrangePolynomialSet(polynomial_set.PolynomialSet):
7797

7898
def __init__(self, ref_el, pts, shape=tuple()):
79-
degree = len(pts) - 1
99+
if ref_el.get_shape() != reference_element.LINE:
100+
raise ValueError("Invalid reference element type.")
101+
102+
expansion_set = LagrangeLineExpansionSet(ref_el, pts)
103+
degree = expansion_set.degree
80104
if shape == tuple():
81105
num_components = 1
82106
else:
83107
flat_shape = numpy.ravel(shape)
84108
num_components = numpy.prod(flat_shape)
85-
num_exp_functions = expansions.polynomial_dimension(ref_el, degree)
109+
num_exp_functions = expansion_set.get_num_members(degree)
86110
num_members = num_components * num_exp_functions
87111
embedded_degree = degree
88-
if ref_el.get_shape() == reference_element.LINE:
89-
expansion_set = LagrangeLineExpansionSet(ref_el, pts)
90-
else:
91-
raise ValueError("Invalid reference element type.")
92112

93113
# set up coefficients
94114
if shape == tuple():
95-
coeffs = numpy.eye(num_members)
115+
coeffs = numpy.eye(num_members, dtype="d")
96116
else:
97117
coeffs_shape = (num_members, *shape, num_exp_functions)
98118
coeffs = numpy.zeros(coeffs_shape, "d")
99119
# use functional's index_iterator function
100120
cur_bf = 0
101121
for idx in index_iterator(shape):
102-
n = expansions.polynomial_dimension(ref_el, embedded_degree)
103-
for exp_bf in range(n):
122+
for exp_bf in range(num_exp_functions):
104123
cur_idx = (cur_bf, *idx, exp_bf)
105124
coeffs[cur_idx] = 1.0
106125
cur_bf += 1

FIAT/check_format_variant.py

+65
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,22 @@
11
import re
22

3+
from FIAT.macro import AlfeldSplit, IsoSplit
4+
5+
# dicts mapping Lagrange variant names to recursivenodes family names
6+
supported_cg_variants = {
7+
"spectral": "gll",
8+
"chebyshev": "lgc",
9+
"equispaced": "equispaced",
10+
"gll": "gll"}
11+
12+
supported_dg_variants = {
13+
"spectral": "gl",
14+
"chebyshev": "gc",
15+
"equispaced": "equispaced",
16+
"equispaced_interior": "equispaced_interior",
17+
"gll": "gll",
18+
"gl": "gl"}
19+
320

421
def check_format_variant(variant, degree):
522
if variant is None:
@@ -20,3 +37,51 @@ def check_format_variant(variant, degree):
2037
'or variant="integral(q)"')
2138

2239
return variant, interpolant_degree
40+
41+
42+
def parse_lagrange_variant(variant, discontinuous=False, integral=False):
43+
"""Parses variant options for Lagrange elements.
44+
45+
variant may be a single option or comma-separated pair
46+
indicating the dof type (integral, equispaced, spectral, etc)
47+
and the type of splitting to give a macro-element (Alfeld, iso)
48+
"""
49+
if variant is None:
50+
variant = "integral" if integral else "equispaced"
51+
options = variant.replace(" ", "").split(",")
52+
assert len(options) <= 2
53+
54+
default = "integral" if integral else "spectral"
55+
if integral:
56+
supported_point_variants = {"integral": None}
57+
elif discontinuous:
58+
supported_point_variants = supported_dg_variants
59+
else:
60+
supported_point_variants = supported_cg_variants
61+
62+
# defaults
63+
splitting = None
64+
splitting_args = tuple()
65+
point_variant = supported_point_variants[default]
66+
67+
for pre_opt in options:
68+
opt = pre_opt.lower()
69+
if opt == "alfeld":
70+
splitting = AlfeldSplit
71+
elif opt == "iso":
72+
splitting = IsoSplit
73+
elif opt.startswith("iso"):
74+
match = re.match(r"^iso(?:\((\d+)\))?$", opt)
75+
k, = match.groups()
76+
call_split = IsoSplit
77+
splitting_args = (int(k),)
78+
elif opt in supported_point_variants:
79+
point_variant = supported_point_variants[opt]
80+
else:
81+
raise ValueError("Illegal variant option")
82+
83+
if discontinuous and splitting is not None and point_variant in supported_cg_variants.values():
84+
raise ValueError("Illegal variant. DG macroelements with DOFs on subcell boundaries are not unisolvent.")
85+
if len(splitting_args) > 0:
86+
splitting = lambda T: call_split(T, *splitting_args, point_variant or "gll")
87+
return splitting, point_variant

0 commit comments

Comments
 (0)