Skip to content

Commit 14dab06

Browse files
Provide automatic sanitization of field names in namedview (#19196)
Related to #19164 -- MbP uses names which are not immediately useable as `namedview` fields. After this PR (and now that #19174 has landed), we can do e.g. ``` PositionView = namedview("PositionView", plant.GetPositionNames()) StateView = namedview("StateView", plant.GetStateNames()) ``` and all of the MbP introspection variants. Co-Authored-By: Jeremy Nimmer <jeremy.nimmer@tri.global>
1 parent c79058c commit 14dab06

File tree

2 files changed

+52
-9
lines changed

2 files changed

+52
-9
lines changed

bindings/pydrake/common/containers.py

+30-9
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Provides extensions for containers of Drake-related objects."""
22

33
import numpy as np
4+
import re
45

56

67
class _EqualityProxyBase:
@@ -71,7 +72,6 @@ class EqualToDict(_DictKeyWrap):
7172
`lhs.EqualTo(rhs)`.
7273
"""
7374
def __init__(self, *args, **kwargs):
74-
7575
class Proxy(_EqualityProxyBase):
7676
def __eq__(self, other):
7777
T = type(self.value)
@@ -118,8 +118,7 @@ def __setattr__(self, name, value):
118118
if not hasattr(self, name):
119119
raise AttributeError(
120120
"Cannot add attributes! The fields in this named view are"
121-
f"{self.get_fields()}, but you tried to set '{name}'."
122-
)
121+
f"{self.get_fields()}, but you tried to set '{name}'.")
123122
object.__setattr__(self, name, value)
124123

125124
def __len__(self):
@@ -142,22 +141,41 @@ def __repr__(self):
142141
@staticmethod
143142
def _item_property(i):
144143
# Maps an item (at a given index) to a property.
145-
return property(
146-
fget=lambda self: self[i],
147-
fset=lambda self, value: self.__setitem__(i, value))
144+
return property(fget=lambda self: self[i],
145+
fset=lambda self, value: self.__setitem__(i, value))
148146

149147
@classmethod
150148
def Zero(cls):
151149
"""Constructs a view onto values set to all zeros."""
152-
return cls([0]*len(cls._fields))
150+
return cls([0] * len(cls._fields))
151+
153152

153+
def _sanitize_field_name(name: str):
154+
result = name
155+
# Ensure the first character is a valid opener (e.g., no numbers allowed).
156+
if not result[0].isidentifier():
157+
result = "_" + result
158+
# Ensure that each additional character is valid in turn, avoiding the
159+
# special case for opening characters by prepending "_" during the check.
160+
for i in range(1, len(result)):
161+
if not ("_" + result[i]).isidentifier():
162+
result = result[:i] + "_" + result[i+1:]
163+
result = re.sub("__+", "_", result)
164+
assert result.isidentifier(), f"Sanitization failed on {name} => {result}"
165+
return result
154166

155-
def namedview(name, fields):
167+
168+
def namedview(name, fields, *, sanitize_field_names=True):
156169
"""
157170
Creates a class that is a named view with given ``fields``. When the class
158171
is instantiated, it must be given the object that it will be a proxy for.
159172
Similar to ``namedtuple``.
160173
174+
If ``sanitize_field_names`` is True (the default), then any characters in
175+
``fields`` which are not valid in Python identifiers will be automatically
176+
replaced with `_`. Leading numbers will have `_` inserted, and duplicate
177+
`_` will be replaced by a single `_`.
178+
161179
Example:
162180
::
163181
@@ -186,7 +204,10 @@ def namedview(name, fields):
186204
187205
For more details, see ``NamedViewBase``.
188206
"""
189-
base_cls = (NamedViewBase,)
207+
base_cls = (NamedViewBase, )
208+
if sanitize_field_names:
209+
fields = [_sanitize_field_name(f) for f in fields]
210+
assert len(set(fields)) == len(fields), "Field names must be unique"
190211
type_dict = dict(_fields=tuple(fields))
191212
for i, field in enumerate(fields):
192213
type_dict[field] = NamedViewBase._item_property(i)

bindings/pydrake/common/test/containers_test.py

+22
Original file line numberDiff line numberDiff line change
@@ -130,3 +130,25 @@ def test_Zero(self):
130130
self.assertEqual(len(view), 2)
131131
self.assertEqual(view.a, 0)
132132
self.assertEqual(view.b, 0)
133+
134+
def test_name_sanitation(self):
135+
MyView = namedview("MyView",
136+
["$world_base", "iiwa::iiwa", "no spaces", "2vär"])
137+
self.assertEqual(MyView.get_fields(),
138+
("_world_base", "iiwa_iiwa", "no_spaces", "_2vär"))
139+
view = MyView.Zero()
140+
view._world_base = 3
141+
view.iiwa_iiwa = 4
142+
view.no_spaces = 5
143+
view._2vär = 6
144+
np.testing.assert_equal(view[:], [3, 4, 5, 6])
145+
146+
MyView = namedview("MyView", ["$world_base", "iiwa::iiwa"],
147+
sanitize_field_names=False)
148+
self.assertEqual(MyView.get_fields(), ("$world_base", "iiwa::iiwa"))
149+
150+
def test_uniqueness(self):
151+
with self.assertRaisesRegex(AssertionError, ".*must be unique.*"):
152+
namedview("MyView", ['a', 'a'])
153+
with self.assertRaisesRegex(AssertionError, ".*must be unique.*"):
154+
namedview("MyView", ['a_a', 'a__a'])

0 commit comments

Comments
 (0)