Fix sot eval and test len (#59408)
---------

Co-authored-by: SigureMo <sigure.qaq@gmail.com>
Co-authored-by: zhangbo9674 <zhangbo54@baidu.com>
3 people authored Dec 6, 2023
1 parent b5ebcae commit 549b33a
Showing 9 changed files with 160 additions and 7 deletions.
18 changes: 18 additions & 0 deletions paddle/fluid/pir/dialect/operator/utils/utils.cc
@@ -19,6 +19,7 @@
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/dialect/operator/utils/utils.h"
#include "paddle/phi/core/kernel_factory.h"
@@ -281,6 +282,23 @@ void DoValueCheck(const pir::Value& value,
value_type,
joined.str()));
}
} else if (value.type().isa<paddle::dialect::SelectedRowsType>()) {
std::string value_type = phi::DataTypeToString(dialect::TransToPhiDataType(
value.type().dyn_cast<paddle::dialect::SelectedRowsType>().dtype()));
if (expected_dtype.find(value_type) == expected_dtype.end()) {
std::ostringstream joined;
std::copy(expected_dtype.begin(),
expected_dtype.end(),
std::ostream_iterator<std::string>(joined, ","));
PADDLE_THROW(phi::errors::InvalidArgument(
"Check data type error for op: %s, input: %s, %s.dtype is %s, and "
"expected_dtype is %s",
op_name,
input_name,
input_name,
value_type,
joined.str()));
}
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"Currently, we can only get dtype for dense "
2 changes: 2 additions & 0 deletions paddle/fluid/pybind/pir.cc
@@ -528,6 +528,8 @@ phi::DataType GetValueDtype(Value value) {
const phi::DDim &GetValueDims(Value value) {
if (value.type().isa<DenseTensorType>()) {
return value.type().dyn_cast<DenseTensorType>().dims();
} else if (value.type().isa<SelectedRowsType>()) {
return value.type().dyn_cast<SelectedRowsType>().dims();
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"Currently, we can only get shape for dense "
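The `GetValueDims` branch above is what the PIR `len()` test below relies on: querying the shape of a SelectedRows-typed value now follows the same path as a dense tensor. A minimal sketch of the behavior this enables, assuming `paddle.static.data` under `IrGuard` yields a PIR value (illustrative, not from the commit):

import paddle

paddle.enable_static()
with paddle.pir_utils.IrGuard():
    x = paddle.static.data(name="x", shape=[5, 20], dtype="float32")
    # Reading .shape dispatches to GetValueDims in pir.cc; after this
    # fix the same call also succeeds for SelectedRowsType values.
    print(x.shape)  # [5, 20]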
2 changes: 1 addition & 1 deletion python/paddle/jit/dy2static/convert_operators.py
@@ -594,8 +594,8 @@ def convert_len(var):
% type(var)
)
elif isinstance(var, OpResult):
assert var.ndim > 0, "len() of a 0-D tensor is wrong"
if var.is_dense_tensor_type() or var.is_selected_row_type():
assert var.ndim > 0, "len() of a 0-D tensor is wrong"
# Note: Length of var may be known ahead of time in dygraph,
# but it probably represents batch size, which can vary,
# so we return a variable dynamically inferred from var.shape.
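The note above carries the key idea: for dense and selected-rows values, `len()` is lowered to the symbolic first dimension instead of a Python int, so a variable batch size stays dynamic. A hedged sketch of that branch (illustrative, not the literal implementation):

import paddle

def _dynamic_len(var):
    # Mirrors the guarded branch in convert_len: reject 0-D tensors,
    # then return the (possibly variable) first dimension as a Tensor.
    assert var.ndim > 0, "len() of a 0-D tensor is wrong"
    return paddle.shape(var)[0]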
1 change: 1 addition & 0 deletions python/paddle/jit/dy2static/program_translator.py
@@ -730,6 +730,7 @@ def _perform_call(self, *args, **kwargs):
traced_fun = symbolic_translate(
self._dygraph_function,
build_strategy=build_strategy,
training=self._training,
backend=backend,
)
if self._class_instance is not None:
7 changes: 5 additions & 2 deletions python/paddle/jit/sot/symbolic/compile_cache.py
@@ -56,11 +56,13 @@ class FallbackWrapper:
Used to store and call static graph methods generated by paddle.jit.to_static
"""

def __init__(self, compiled_fn, SIR):
def __init__(self, compiled_fn, SIR, is_training: bool):
self.compiled_fn = compiled_fn
self.partial_program = None
self.concrete_program = None
self.SIR = SIR # for debug
self.is_training = is_training
self.compiled_fn.eval() if not is_training else self.compiled_fn.train()

def amp_cast_inputs(self, args, kwargs):
"""Prepare inputs for amp, cast float16 into float32 if needed."""
@@ -149,7 +151,7 @@ def key_fn(self, context: SymbolicTraceContext, sir_name: str, **kwargs):
"""
sir = context.get_sir(sir_name)
# NOTE(dev): Is str(sir) a heavy operation?
hash_key = hash(str(sir))
hash_key = hash((str(sir), kwargs['training']))
return hash_key

def value_fn(self, context: SymbolicTraceContext, sir_name: str, **kwargs):
@@ -174,4 +176,5 @@ def value_fn(self, context: SymbolicTraceContext, sir_name: str, **kwargs):
full_graph=True,
),
context.get_sir(sir_name),
is_training=kwargs['training'],
)
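Folding `training` into the cache key is what keeps train and eval programs apart: ops such as dropout produce different compiled programs per mode, so the two must not share a cache slot. An illustrative sketch of the collision the tuple key prevents (the SIR text is a placeholder):

sir_text = "SIR_0: ..."  # placeholder for str(sir)

# Old key: identical in both modes, so a program compiled (and
# .train()-ed) for training could be reused by an eval() call.
assert hash(sir_text) == hash(sir_text)

# New key: the training flag splits the cache into per-mode entries.
assert hash((sir_text, True)) != hash((sir_text, False))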
2 changes: 2 additions & 0 deletions python/paddle/jit/sot/translate.py
@@ -85,6 +85,8 @@ def symbolic_translate(fn: Callable[P, R], **kwargs) -> Callable[P, R]:
"""

kwargs.setdefault('training', True)

def callback(frame):
return eval_frame_callback(frame, **kwargs)

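With this default in place, every translation carries an explicit mode that flows through `eval_frame_callback` into `FallbackWrapper`. A hedged usage sketch, assuming `symbolic_translate` is imported from this module:

import paddle
from paddle.jit.sot.translate import symbolic_translate

net = paddle.nn.Linear(4, 4)

# training=True is filled in by setdefault above; passing False makes
# the cached FallbackWrapper call compiled_fn.eval() instead.
train_fn = symbolic_translate(net.forward)
eval_fn = symbolic_translate(net.forward, training=False)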
5 changes: 3 additions & 2 deletions test/dygraph_to_static/dygraph_to_static_utils.py
@@ -206,9 +206,10 @@ def to_pir_test(fn):
@wraps(fn)
def impl(*args, **kwargs):
logger.info("[PIR] running pir")
ir_outs = None
in_dygraph_mode = paddle.in_dynamic_mode()
with paddle.pir_utils.IrGuard():
paddle.disable_static()
if in_dygraph_mode:
paddle.disable_static()
ir_outs = fn(*args, **kwargs)
return ir_outs

65 changes: 63 additions & 2 deletions test/dygraph_to_static/test_len.py
@@ -15,7 +15,13 @@
import unittest

import numpy as np
from dygraph_to_static_utils import Dy2StTestBase
from dygraph_to_static_utils import (
Dy2StTestBase,
test_ast_only,
test_legacy_and_pir,
test_legacy_only,
test_pir_only,
)

import paddle
from paddle import base
@@ -67,6 +73,8 @@ def _run(self, to_static):
out = out.numpy()
return out

@test_ast_only
@test_legacy_and_pir
def test_len(self):
dygraph_res = self._run(to_static=False)
static_res = self._run(to_static=True)
@@ -81,7 +89,49 @@ def init_func(self):
# Note: Variable(SelectedRows) is not exposed directly in dygraph.
# The unittest is used to test coverage by fake transformed code.
def len_with_selected_rows(place):
block = base.default_main_program().global_block()
# create selected_rows variable
paddle.enable_static()
non_used_initializer = paddle.nn.initializer.Constant(0.0)
var = paddle.static.create_parameter(
name="X",
dtype="float32",
shape=[5, 20],
)
selected_var = (
paddle.base.libpaddle.pir.create_selected_rows_type_by_dense_tensor(
var.type()
)
)
var.set_type(selected_var)
# y is Variable(SelectedRows)
y = clip.merge_selected_rows(var)
y_len = Call(len)(y)

# z is inner tensor with shape [4, 2]
z = clip.get_tensor_from_selected_rows(y)
z_len = paddle.shape(z)[0]

# set data for selected_rows
x_rows = [0, 2, 2, 4, 19]
row_numel = 2
np_array = np.ones((len(x_rows), row_numel)).astype("float32")

x_var = base.global_scope().var("X").get_selected_rows()
x_var.set_rows(x_rows)
x_var.set_height(20)
x_tensor = x_var.get_tensor()
x_tensor.set(np_array, place)

exe = paddle.static.Executor(place=place)
result = exe.run(
paddle.static.default_main_program(), fetch_list=[y_len, z_len]
)
return result


def legacy_len_with_selected_rows(place):
paddle.enable_static()
block = paddle.static.default_main_program().global_block()
# create selected_rows variable
var = block.create_var(
name="X",
@@ -122,6 +172,17 @@ def setUp(self):
else base.CPUPlace()
)

@test_legacy_only
@test_ast_only
def test_len_legacy(self):
selected_rows_var_len, var_tensor_len = legacy_len_with_selected_rows(
self.place
)
self.assertEqual(selected_rows_var_len, var_tensor_len)

@test_ast_only
@test_pir_only
def test_len(self):
selected_rows_var_len, var_tensor_len = len_with_selected_rows(
self.place
65 changes: 65 additions & 0 deletions test/sot/test_model_switch_training.py
@@ -0,0 +1,65 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle


class SimpleNet(paddle.nn.Layer):
def __init__(self):
super().__init__()

def forward(self, x):
if self.training:
out1 = paddle.nn.functional.dropout(x, p=0.5, training=True)
else:
out1 = paddle.nn.functional.dropout(x, p=0.5, training=False)
return out1


class TestModelSwitchTraining(unittest.TestCase):
def setUp(self):
self.seed = 1127
self.net = SimpleNet()

def get_dygraph_out(self, input):
paddle.seed(self.seed)
self.net.eval()
eval_result = self.net(input)
self.net.train()
train_result = self.net(input)
return eval_result, train_result

def get_static_out(self, input):
paddle.seed(self.seed)
static_net = paddle.jit.to_static(self.net)
static_net.eval()
eval_result = static_net(input)
static_net.train()
train_result = static_net(input)
return eval_result, train_result

def test_model_switch_training(self):
input = paddle.rand((10, 10))
dygraph_eval, dygraph_train = self.get_dygraph_out(input)
static_eval, static_train = self.get_static_out(input)
np.testing.assert_allclose(dygraph_eval.numpy(), static_eval.numpy())
np.testing.assert_allclose(dygraph_train.numpy(), static_train.numpy())


if __name__ == "__main__":
unittest.main()
