forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_backend.cpp
124 lines (113 loc) · 4.46 KB
/
test_backend.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#include <gtest/gtest.h>
#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/backends/backend_detail.h>
#include <torch/torch.h>
// Tests go in torch::jit
namespace torch {
namespace jit {
// Lower a module with multiple exported methods to the no-op "test_backend"
// and check the lowered module produces the same outputs as eager execution.
TEST(BackendTest, ToBackend) {
  Module mod("m");
  mod.define(R"(
    def forward(self, x, h):
        return self.accum(x, h), self.sub_accum(x, h)

    def accum(self, x, h):
        return x + h

    def sub_accum(self, x, h):
        return x - h
  )");

  // Scalar tensor inputs: x = 2, h = 1.
  std::vector<IValue> args;
  args.emplace_back(2.0 * torch::ones({}));
  args.emplace_back(1.0 * torch::ones({}));

  // Reference results from running the original module directly.
  auto expected = mod.forward(args).toTuple()->elements();

  // Minimal compile spec: one (placeholder) entry for "forward".
  c10::Dict<IValue, IValue> spec(StringType::get(), AnyType::get());
  c10::Dict<IValue, IValue> method_spec(StringType::get(), AnyType::get());
  method_spec.insert("", "");
  spec.insert("forward", method_spec);
  auto dict_ty = DictType::create(StringType::get(), AnyType::get());

  // Generate the lowered wrapper module. Its synthesized code stores the
  // processed module and compile spec, (re)creates the backend in
  // __setstate__, compiles to per-method handles, and has a forward that
  // dispatches through backend.execute and type-checks each returned value
  // (raising AssertionError if a result is not a Tensor).
  auto lowered = torch::jit::detail::codegen_backend_module(
      "test_backend", mod, spec, dict_ty);

  // The lowered module must match the eager reference element-wise.
  auto actual = lowered.forward(args).toTuple()->elements();
  AT_ASSERT(actual[0].toTensor().equal(expected[0].toTensor()));
  AT_ASSERT(actual[1].toTensor().equal(expected[1].toTensor()));
}
// Lower a single-op module through the demo compiler backend and verify the
// lowered module's forward matches the eager result.
TEST(BackendTest, TestCompiler) {
  Module mod("m");
  mod.define(R"(
    def forward(self, x, h):
        return x + h
  )");

  // Scalar tensor inputs: x = 2, h = 1.
  std::vector<IValue> args;
  args.emplace_back(2.0 * torch::ones({}));
  args.emplace_back(1.0 * torch::ones({}));

  // Reference result from eager execution.
  auto expected = mod.forward(args);

  // Minimal compile spec: one (placeholder) entry for "forward".
  c10::Dict<IValue, IValue> spec(StringType::get(), AnyType::get());
  c10::Dict<IValue, IValue> method_spec(StringType::get(), AnyType::get());
  method_spec.insert("", "");
  spec.insert("forward", method_spec);
  auto dict_ty = DictType::create(StringType::get(), AnyType::get());

  // Generate the lowered wrapper module against the demo compiler backend.
  auto lowered = torch::jit::detail::codegen_backend_module(
      "backend_with_compiler_demo", mod, spec, dict_ty);

  auto actual = lowered.forward(args);
  AT_ASSERT(actual.toTensor().equal(expected.toTensor()));
}
// The demo compiler backend does not support aten::mul; lowering a module
// that uses it must fail at codegen time with a descriptive error.
TEST(BackendTest, TestCompilerNotSupport) {
  Module mod("m");
  mod.define(R"(
    def forward(self, x, h):
        return x * h
  )");

  // Minimal compile spec: one (placeholder) entry for "forward".
  c10::Dict<IValue, IValue> spec(StringType::get(), AnyType::get());
  c10::Dict<IValue, IValue> method_spec(StringType::get(), AnyType::get());
  method_spec.insert("", "");
  spec.insert("forward", method_spec);
  auto dict_ty = DictType::create(StringType::get(), AnyType::get());

  // Lowering should throw because the graph contains an unsupported op.
  ASSERT_THROWS_WITH_MESSAGE(
      torch::jit::detail::codegen_backend_module(
          "backend_with_compiler_demo", mod, spec, dict_ty),
      "The node of aten::mul is not supported in this compiler. Source code:");
}
} // namespace jit
} // namespace torch