Skip to content

Commit 150222e

Browse files
authored
Merge pull request #99 from d0c-s4vage/hotfix/98-typedefd_structs_with_params
Adds tests for typedef'd parameterized structs
2 parents 4ec0ae2 + 68fd244 commit 150222e

File tree

3 files changed

+72
-17
lines changed

3 files changed

+72
-17
lines changed

pfp/interp.py

+45-14
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@ def _pfp__init(self, stream):
5252
self._pfp__node.args, scope, self, None
5353
)
5454
param_list = params.instantiate(scope, struct_args, self._pfp__interp)
55-
super(self.__class__, self)._pfp__init(stream)
55+
56+
if hasattr(super(self.__class__, self), "_pfp__init"):
57+
super(self.__class__, self)._pfp__init(stream)
5658

5759
new_class = type(
5860
struct_cls.__name__ + "_", (struct_cls,), {"_pfp__init": _pfp__init}
@@ -84,13 +86,28 @@ def StructUnionTypeRef(curr_scope, typedef_name, refd_name, interp, node):
8486
elif isinstance(node, AST.Union):
8587
cls = fields.Union
8688

87-
def __new__(self, *args, **kwargs):
89+
def __new__(cls_, *args, **kwargs):
8890
refd_type = curr_scope.get_type(refd_name)
8991
if refd_type is None:
9092
refd_node = node
9193
else:
9294
refd_node = refd_type._pfp__node
93-
return StructUnionDef(typedef_name, interp, refd_node)(*args, **kwargs)
95+
96+
def merged_init(self, stream):
97+
if six.PY3:
98+
cls_._pfp__init(self, stream)
99+
else:
100+
cls_._pfp__init.__func__(self, stream)
101+
self._pfp__init_orig(stream)
102+
103+
overrides = {}
104+
if hasattr(cls_, "_pfp__init"):
105+
overrides["_pfp__init"] = merged_init
106+
107+
res = base_cls = StructUnionDef(
108+
typedef_name, interp, refd_node, overrides=overrides,
109+
)
110+
return res(*args, **kwargs)
94111

95112
new_class = type(
96113
typedef_name,
@@ -102,13 +119,16 @@ def __new__(self, *args, **kwargs):
102119
return new_class
103120

104121

105-
106-
def StructUnionDef(typedef_name, interp, node):
122+
def StructUnionDef(typedef_name, interp, node, overrides=None, cls=None):
123+
if overrides is None:
124+
overrides = {}
107125
if isinstance(node, AST.Struct):
108-
cls = fields.Struct
126+
if cls is None:
127+
cls = fields.Struct
109128
decls = StructDecls(node.decls, node.coord)
110129
elif isinstance(node, AST.Union):
111-
cls = fields.Union
130+
if cls is None:
131+
cls = fields.Union
112132
decls = UnionDecls(node.decls, node.coord)
113133

114134
# this is so that we can have all nested structs added to
@@ -117,23 +137,34 @@ def StructUnionDef(typedef_name, interp, node):
117137
# the new struct to not be added to its parent, and the user would
118138
# not be able to see how far the script got
119139
def __init__(self, stream=None, metadata_processor=None, do_init=True):
120-
cls.__init__(self, stream, metadata_processor=metadata_processor)
140+
cls.__init__(
141+
self,
142+
stream,
143+
metadata_processor=metadata_processor,
144+
)
121145

122146
if do_init:
123147
self._pfp__init(stream)
124148

125149
def _pfp__init(self, stream):
126150
self._pfp__interp._handle_node(decls, ctxt=self, stream=stream)
127151

152+
cls_members = {
153+
"__init__": __init__,
154+
"_pfp__init": _pfp__init,
155+
"_pfp__node": node,
156+
"_pfp__interp": interp,
157+
}
158+
159+
for k, v in six.iteritems(overrides or {}):
160+
if k in cls_members:
161+
cls_members[k + "_orig"] = cls_members[k]
162+
cls_members[k] = v
163+
128164
new_class = type(
129165
typedef_name,
130166
(cls,),
131-
{
132-
"__init__": __init__,
133-
"_pfp__init": _pfp__init,
134-
"_pfp__node": node,
135-
"_pfp__interp": interp,
136-
},
167+
cls_members,
137168
)
138169
return new_class
139170

requirements.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
py010parser>=0.1.15
1+
py010parser>=0.1.17
22
six>=1.10.0,<2.0.0
3-
intervaltree>=3.0.2,<4.0.0
3+
intervaltree>=3.0.2,<4.0.0

tests/test_struct_union.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def test_struct_vit9696_5(self):
9494
LittleEndian();
9595
ME s;
9696
""",
97-
debug=True,
97+
debug=False,
9898
)
9999
assert dom.s.magic == "\x00\x01\x02\x03"
100100
assert dom.s.filesize == 0x03020100
@@ -239,6 +239,30 @@ def test_struct_with_parameters3(self):
239239
self.assertEqual(dom.l.c[1], 2)
240240
self.assertEqual(dom.l.c[2], 3)
241241

242+
def test_typedefd_struct_with_parameters(self):
243+
dom = self._test_parse_build(
244+
"\x01\x02\x03\x04\x01\x02\x03",
245+
"""
246+
struct TEST_STRUCT(int arraySize, int arraySize2)
247+
{
248+
uchar b[arraySize];
249+
uchar c[arraySize2];
250+
};
251+
local int bytes = 4;
252+
typedef struct TEST_STRUCT NEW_STRUCT;
253+
NEW_STRUCT l(bytes, 3);
254+
""",
255+
)
256+
self.assertEqual(len(dom.l.b), 4)
257+
self.assertEqual(dom.l.b[0], 1)
258+
self.assertEqual(dom.l.b[1], 2)
259+
self.assertEqual(dom.l.b[2], 3)
260+
self.assertEqual(dom.l.b[3], 4)
261+
self.assertEqual(len(dom.l.c), 3)
262+
self.assertEqual(dom.l.c[0], 1)
263+
self.assertEqual(dom.l.c[1], 2)
264+
self.assertEqual(dom.l.c[2], 3)
265+
242266
def test_struct_decl_with_struct_keyword(self):
243267
dom = self._test_parse_build(
244268
"ABCD",

0 commit comments

Comments
 (0)