fix local buffer resize #62856

Merged: 4 commits, Mar 21, 2024
22 changes: 16 additions & 6 deletions paddle/cinn/ir/group_schedule/config/group_tile_config.cc
@@ -220,17 +220,27 @@ BuildStaticReduceConfig(
/* tree_reduce_num = */ 1,
/* spatial_inner_num = */ 1,
/* reduce_method = */ NoneReduceMethod()};
BucketInfo bucket_info__1024_INF{/* sp_lower_bound = */ 1024,
/* sp_upper_bound = */ kMaxNumel,
/* rb_lower_bound = */ 1,
/* rb_upper_bound = */ 1};
ScheduleConfig::TileConfig tile_config__1024_INF{
BucketInfo bucket_info__1024_1M{/* sp_lower_bound = */ 1024,
/* sp_upper_bound = */ 1024 * 1024 - 1,
/* rb_lower_bound = */ 1,
/* rb_upper_bound = */ 1};
ScheduleConfig::TileConfig tile_config__1024_1M{
/* warp_num = */ 32,
/* tree_reduce_num = */ 1,
/* spatial_inner_num = */ 1,
/* reduce_method = */ NoneReduceMethod()};
BucketInfo bucket_info__1M_INF{/* sp_lower_bound = */ 1024 * 1024,
/* sp_upper_bound = */ kMaxNumel,
/* rb_lower_bound = */ 1,
/* rb_upper_bound = */ 1};
ScheduleConfig::TileConfig tile_config__1M_INF{
/* warp_num = */ 32,
/* tree_reduce_num = */ 1,
/* spatial_inner_num = */ 16,
/* reduce_method = */ NoneReduceMethod()};
return {{bucket_info__1_1023, tile_config__1_1023},
{bucket_info__1024_INF, tile_config__1024_INF}};
{bucket_info__1024_1M, tile_config__1024_1M},
{bucket_info__1M_INF, tile_config__1M_INF}};
} else if (base_info->reduce_numel <= 256) {
BucketInfo bucket_info{/* sp_lower_bound = */ 1,
/* sp_upper_bound = */ kMaxNumel,
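Note on the tile-config change above: the single old spatial bucket [1024, kMaxNumel] is split into [1024, 1024*1024 - 1] (keeping spatial_inner_num = 1) and [1024*1024, kMaxNumel] (spatial_inner_num = 16), so very large spatial extents get more work per thread. Below is a minimal, self-contained C++ sketch of how a spatial extent might be matched against these buckets; the Bucket struct and the inclusive-bound matching rule are illustrative assumptions, not CINN's actual lookup code.

// Illustrative sketch only (not CINN code): mapping a spatial extent onto the
// three buckets defined above, assuming inclusive lower/upper bounds.
#include <cstdint>
#include <iostream>
#include <limits>

struct Bucket {
  int64_t sp_lower_bound;
  int64_t sp_upper_bound;
  int spatial_inner_num;  // taken from the paired TileConfig
};

int main() {
  // Assumption: kMaxNumel acts as an "infinity" upper bound.
  const int64_t kMaxNumel = std::numeric_limits<int64_t>::max();
  const Bucket buckets[] = {
      {1, 1023, 1},
      {1024, 1024 * 1024 - 1, 1},
      {1024 * 1024, kMaxNumel, 16},
  };
  const int64_t spatial_numel = 2048 * 1024;  // example: 2,097,152 spatial elements
  for (const Bucket& b : buckets) {
    if (spatial_numel >= b.sp_lower_bound && spatial_numel <= b.sp_upper_bound) {
      std::cout << "bucket [" << b.sp_lower_bound << ", " << b.sp_upper_bound
                << "] -> spatial_inner_num = " << b.spatial_inner_num << "\n";
    }
  }
  return 0;
}

With the example value used here, only the new [1024*1024, kMaxNumel] bucket matches, so spatial_inner_num = 16 would be selected.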
83 changes: 68 additions & 15 deletions paddle/cinn/optim/resize_buffer.cc
@@ -20,11 +20,13 @@
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/optim/replace_mod_to_max.h"
#include "paddle/cinn/optim/replace_var_with_expr.h"
#include "paddle/cinn/utils/string.h"

PD_DECLARE_bool(group_schedule_tiling_first);
namespace cinn {
namespace optim {

@@ -71,6 +73,7 @@ class AnalyzeLoopVarRange : public ir::IRMutator<> {
ir::Store* store = expr->As<ir::Store>();
ir::Tensor tensor = store->tensor.as_tensor_ref();
AnalyzeTensorRange(store->indices, tensor);
AnalyzeBufferSize(store->indices, tensor);
ir::IRMutator<>::Visit(op, expr);
}

@@ -103,10 +106,8 @@ class AnalyzeLoopVarRange : public ir::IRMutator<> {
private:
void AnalyzeTensorRange(const std::vector<Expr>& indices,
const ir::Tensor& tensor) {
if (!tensor->buffer.defined() ||
tensor->buffer->memory_type == ir::MemoryType::Heap) {
return;
}
if (!tensor->buffer.defined()) return;
if (tensor->buffer->memory_type == ir::MemoryType::Heap) return;

std::vector<ir::Expr> indice_extent;
for (int i = 0; i < indices.size(); ++i) {
@@ -144,6 +145,45 @@ class AnalyzeLoopVarRange : public ir::IRMutator<> {
<< buffer_name_to_indice_extent[buffer_name];
}

void AnalyzeBufferSize(const std::vector<Expr>& indices,
const ir::Tensor& tensor) {
if (!tensor->buffer.defined()) return;
if (tensor->buffer->memory_type == ir::MemoryType::Heap) return;

const std::string& buffer_name = tensor->buffer->name;
buffer_name_to_size[buffer_name] = AnalyzeBufferSize(indices);
VLOG(6) << "buffer_name = " << buffer_name
<< ", size = " << buffer_name_to_size[buffer_name];
}

ir::Expr AnalyzeBufferSize(const std::vector<ir::Expr>& indices) {
const auto GetIterVarNames =
[](const std::vector<ir::Expr>& indices) -> std::set<std::string> {
std::set<std::string> iter_var_names;
for (const ir::Expr& e : indices) {
ir::ir_utils::CollectIRNodes(e, [&](const ir::Expr* x) {
if (x->as_var() && !x->as_var()->is_symbolic_constant) {
iter_var_names.insert(x->as_var()->name);
}
return false;
});
}
return iter_var_names;
};

std::set<std::string> iter_var_names = GetIterVarNames(indices);
ir::Expr size(1);
for (const std::string& var_name : iter_var_names) {
PADDLE_ENFORCE_GT(var_name_to_extent_.count(var_name),
0,
::common::errors::PreconditionNotMet(
"Cannot find the extent of var %s", var_name));
size = common::AutoSimplify(size * var_name_to_extent_.at(var_name));
}

return size;
}

// A recursion function to calculate the max index range
// The index may contain some vars like index = 8 * i / j, where we know the
// range of i, j, we search all values to get the max index range
@@ -188,6 +228,7 @@ class AnalyzeLoopVarRange : public ir::IRMutator<> {
public:
std::unordered_map<std::string, std::vector<ir::Expr>>
buffer_name_to_indice_extent;
std::unordered_map<std::string, ir::Expr> buffer_name_to_size;

private:
std::unordered_map<std::string, ir::Expr> var_name_to_extent_;
@@ -197,8 +238,10 @@ class ResizeBufferFromAnalyzedRange : public ir::IRMutator<> {
public:
ResizeBufferFromAnalyzedRange(
const std::unordered_map<std::string, std::vector<ir::Expr>>&
buffer_name_to_shape)
: buffer_name_to_shape_(buffer_name_to_shape) {}
buffer_name_to_shape,
const std::unordered_map<std::string, ir::Expr>& buffer_name_to_size)
: buffer_name_to_shape_(buffer_name_to_shape),
buffer_name_to_size_(buffer_name_to_size) {}

void operator()(ir::Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }

Expand All @@ -221,8 +264,11 @@ class ResizeBufferFromAnalyzedRange : public ir::IRMutator<> {
return;
}

load->tensor.as_tensor_ref()->shape =
load->tensor.as_tensor_ref()->buffer->shape;
const std::string& buffer_name = load->tensor.as_tensor_ref()->buffer->name;
if (buffer_name_to_shape_.count(buffer_name) > 0) {
load->tensor.as_tensor_ref()->shape =
buffer_name_to_shape_.at(buffer_name);
}

// For the moment, align the load tensor indices with the tensor shape using
// the trick method. A better way would be to modify the FlattenLoop
@@ -237,33 +283,40 @@ class ResizeBufferFromAnalyzedRange : public ir::IRMutator<> {
private:
void ResizeTensor(ir::Tensor* tensor_ptr) {
ir::Buffer buffer = (*tensor_ptr)->buffer;
if (!buffer.defined() || buffer->memory_type == ir::MemoryType::Heap) {
return;
}
if (!buffer.defined()) return;
if (buffer->memory_type == ir::MemoryType::Heap) return;

const std::string& buffer_name = buffer->name;
if (buffer_name_to_shape_.count(buffer_name)) {
const std::vector<ir::Expr>& analyzed_shape =
buffer_name_to_shape_.at(buffer_name);
VLOG(6) << "Replacing shape of tensor " << (*tensor_ptr)->name
<< ", buffer " << buffer->name << ", with shape "
<< analyzed_shape;

<< " with shape " << analyzed_shape;
(*tensor_ptr)->shape = analyzed_shape;
buffer->shape = analyzed_shape;
}
if (FLAGS_group_schedule_tiling_first &&
buffer_name_to_size_.count(buffer_name) > 0) {
const ir::Expr& analyzed_size = buffer_name_to_size_.at(buffer_name);
VLOG(6) << "Replacing shape of buffer " << buffer->name << " with shape "
<< analyzed_size;
buffer->shape = {analyzed_size};
}
}

private:
const std::unordered_map<std::string, std::vector<ir::Expr>>&
buffer_name_to_shape_;
const std::unordered_map<std::string, ir::Expr>& buffer_name_to_size_;
};

void ResizeBufferToMaxVarRange(ir::Expr* expr) {
VLOG(6) << "Before ResizeBufferToMaxVarRange, Expr = \n" << *expr;
AnalyzeLoopVarRange analyze_functor;
analyze_functor(expr);
ResizeBufferFromAnalyzedRange resize_functor(
analyze_functor.buffer_name_to_indice_extent);
analyze_functor.buffer_name_to_indice_extent,
analyze_functor.buffer_name_to_size);
resize_functor(expr);
VLOG(6) << "After ResizeBufferToMaxVarRange, Expr = \n" << *expr;
}
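The core of the resize_buffer.cc change: for each store into a non-heap buffer, AnalyzeBufferSize takes the product of the extents of the (non-symbolic) iteration variables that appear in the store indices, and ResizeTensor, when FLAGS_group_schedule_tiling_first is set, replaces the buffer's shape with that single flattened extent. A minimal stand-alone sketch of that idea follows; the container types and names are stand-ins, not CINN's ir::Expr machinery.

// Illustrative sketch only (not CINN code): the local buffer size is the product
// of the extents of the loop variables referenced by the store indices, and the
// buffer shape is then collapsed to that single flattened extent.
#include <cstdint>
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

// index_vars[d] lists the loop variables referenced by the d-th store index,
// e.g. A[i * 4 + k, j] -> {{"i", "k"}, {"j"}}.
int64_t AnalyzeBufferSize(const std::vector<std::set<std::string>>& index_vars,
                          const std::map<std::string, int64_t>& var_extent) {
  std::set<std::string> used;
  for (const auto& vars : index_vars) used.insert(vars.begin(), vars.end());
  int64_t size = 1;
  for (const std::string& v : used) size *= var_extent.at(v);
  return size;
}

int main() {
  const std::vector<std::set<std::string>> index_vars = {{"i", "k"}, {"j"}};
  const std::map<std::string, int64_t> var_extent = {{"i", 8}, {"j", 32}, {"k", 4}};
  const int64_t size = AnalyzeBufferSize(index_vars, var_extent);
  // Mirrors ResizeTensor's new branch: buffer->shape = {analyzed_size}.
  std::cout << "flattened buffer shape = {" << size << "}\n";  // prints {1024}
  return 0;
}

Sizing the buffer from the loop-variable extents themselves, rather than only from the per-dimension index extents, appears to be what the PR title "fix local buffer resize" refers to.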
14 changes: 13 additions & 1 deletion test/ir/pir/cinn/symbolic/CMakeLists.txt
@@ -22,7 +22,8 @@ if(WITH_GPU)
test_llama_mlp_st.py
test_llama_mlp_dy.py
test_while_st.py
test_infer_sym_shape_utils.py)
test_infer_sym_shape_utils.py
test_dyshape_cast.py)

foreach(cinn_pir_test_name ${CINN_PIR_SYMBOLIC_TEST})
string(REGEX REPLACE ".py" "" cinn_pir_test_name ${cinn_pir_test_name})
@@ -221,4 +222,15 @@ if(WITH_GPU)
WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
set_tests_properties(test_while_st PROPERTIES LABELS "RUN_TYPE=CINN")

add_test(
NAME test_dyshape_cast
COMMAND
${CMAKE_COMMAND} -E env
PYTHONPATH=${CMAKE_BINARY_DIR}:${CMAKE_BINARY_DIR}/python/:$ENV{PYTHONPATH}
FLAGS_prim_all=true FLAGS_cinn_bucket_compile=True
FLAGS_group_schedule_tiling_first=1 FLAGS_enable_pir_api=1
${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/test_dyshape_cast.py
WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
set_tests_properties(test_dyshape_cast PROPERTIES LABELS "RUN_TYPE=CINN")

endif()
74 changes: 74 additions & 0 deletions test/ir/pir/cinn/symbolic/test_dyshape_cast.py
@@ -0,0 +1,74 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import unittest
from os.path import dirname

import numpy as np

import paddle
from paddle import nn
from paddle.static import InputSpec

sys.path.append(dirname(dirname(__file__)))

import utils


class CastLayer(nn.Layer):
def __init__(self):
super().__init__()

def forward(self, x):
x = paddle.cast(x, dtype="float16")
return paddle.cast(x, dtype="float32")


class TestCast(unittest.TestCase):
def setUp(self):
paddle.seed(2022)
self.prepare_data()

def prepare_data(self):
self.shape = [1024, 32, 1024, 17]
self.x = paddle.randn(self.shape, dtype="float32")
self.x.stop_gradient = True

def check_jit_kernel_info(self, static_fn):
utils.check_jit_kernel_number(static_fn, 1)
utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 1})

def eval(self, use_cinn):
net = CastLayer()
input_spec = [
InputSpec(shape=[None, 32, None, None], dtype='float32'),
]
net = utils.apply_to_static(net, use_cinn, input_spec)
net.eval()
out = net(self.x)
if use_cinn:
self.check_jit_kernel_info(net.forward)
return out

def test_eval(self):
cinn_out = self.eval(use_cinn=True)
dy_out = self.eval(use_cinn=False)
np.testing.assert_allclose(
cinn_out.numpy(), dy_out.numpy(), atol=1e-6, rtol=1e-6
)


if __name__ == '__main__':
unittest.main()