11
11
class RestrictedElement (CiarletElement ):
12
12
"""Restrict given element to specified list of dofs."""
13
13
14
- def __init__ (self , element , indices = None , restriction_domain = None ):
14
+ def __init__ (self , element , indices = None , restriction_domain = None , take_closure = True ):
15
15
'''For sake of argument, indices overrides restriction_domain'''
16
16
17
17
if not (indices or restriction_domain ):
18
18
raise RuntimeError ("Either indices or restriction_domain must be passed in" )
19
19
20
20
if not indices :
21
- indices = _get_indices (element , restriction_domain )
21
+ indices = _get_indices (element , restriction_domain , take_closure )
22
22
23
23
if isinstance (indices , str ):
24
24
raise RuntimeError ("variable 'indices' was a string; did you forget to use a keyword?" )
@@ -70,7 +70,7 @@ def _key(x):
70
70
return sorted (mapping .items (), key = _key )
71
71
72
72
73
- def _get_indices (element , restriction_domain ):
73
+ def _get_indices (element , restriction_domain , take_closure ):
74
74
"Restriction domain can be 'interior', 'vertex', 'edge', 'face' or 'facet'"
75
75
76
76
if restriction_domain == "interior" :
@@ -91,9 +91,10 @@ def _get_indices(element, restriction_domain):
91
91
92
92
is_prodcell = isinstance (max (element .entity_dofs ().keys ()), tuple )
93
93
94
+ ldim = 0 if take_closure else dim
94
95
entity_dofs = element .entity_dofs ()
95
96
indices = []
96
- for d in range (dim + 1 ):
97
+ for d in range (ldim , dim + 1 ):
97
98
if is_prodcell :
98
99
for a in range (d + 1 ):
99
100
b = d - a
0 commit comments