Skip to content

Commit f7aa88b

Browse files
simpkinsfacebook-github-bot
authored andcommitted
[caffe2] Explicitly define all DataTypes in python/core.py (pytorch#51768)
Summary: Pull Request resolved: pytorch#51768 This updates python/core.py to explicitly define all of the `DataType` values rather than dynamically defining them at runtime from the `caffe2_pb2` values. This allows type checkers like Pyre and Mypy to see the members of the `DataType` class. Otherwise the type checkers report errors such as `"core.DataType" has no attribute "INT64"`. This code does keep a run-time check that all of the data types defined by `caffe2_pb2.proto` are defined correctly in this file. This way if someone does add a new type to `caffe2_pb2.proto` it should be very quickly apparent that this file needs to be updated and kept in sync. ghstack-source-id: 121936201 Test Plan: Confirmed that various caffe2/python tests still pass. Verified that this allows many `pyre-fixme` comments to be removed in downstream projects, and that Pyre is still clean for these projects. Reviewed By: jeffdunn Differential Revision: D26271725 Pulled By: simpkins fbshipit-source-id: f9e95795de60aba67d7d3872d0c141ed82ba8e39
1 parent 27d8905 commit f7aa88b

File tree

1 file changed

+27
-6
lines changed

1 file changed

+27
-6
lines changed

caffe2/python/core.py

+27-6
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,36 @@
4040

4141
# Bring datatype enums to the main namespace
4242
class DataType:
43-
pass
44-
45-
46-
def _InitDataType():
43+
UNDEFINED = 0
44+
FLOAT = 1
45+
INT32 = 2
46+
BYTE = 3
47+
STRING = 4
48+
BOOL = 5
49+
UINT8 = 6
50+
INT8 = 7
51+
UINT16 = 8
52+
INT16 = 9
53+
INT64 = 10
54+
FLOAT16 = 12
55+
DOUBLE = 13
56+
ZERO_COLLISION_HASH = 14
57+
REBATCHING_BUFFER = 15
58+
59+
60+
def _CheckDataType():
61+
# Verify that the DataType values defined above match the ones defined in
62+
# the caffe2.proto file
4763
for name, value in caffe2_pb2.TensorProto.DataType.items():
48-
setattr(DataType, name, value)
64+
py_value = getattr(DataType, name, None)
65+
if py_value != value:
66+
raise AssertionError(
67+
f"DataType {name} does not match the value defined in "
68+
f"caffe2.proto: {py_value} vs {value}"
69+
)
4970

5071

51-
_InitDataType()
72+
_CheckDataType()
5273

5374

5475
def _GetRegisteredOperators():

0 commit comments

Comments
 (0)