diff --git a/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc b/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc index cf95a4d9b55e0e..b4484c7ad14c37 100644 --- a/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc +++ b/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc @@ -26,12 +26,29 @@ class SkipLayerNormOpConverter : public OpConverter { void operator()(const framework::proto::OpDesc& op, const framework::Scope& scope, bool test_mode) override { -#if IS_TRT_VERSION_GE(6000) VLOG(4) << "convert fused skip layernorm op to tensorrt layer"; framework::OpDesc op_desc(op, nullptr); // Declare inputs auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]); auto* input2 = engine_->GetITensor(op_desc.Input("Y")[0]); + + // We should expand input1 to input2's nbDims + if (input1->getDimensions().nbDims < input2->getDimensions().nbDims) { + nvinfer1::Dims reshape_input1_dim; + reshape_input1_dim.nbDims = input2->getDimensions().nbDims; + for (int i = 0; i < reshape_input1_dim.nbDims; i++) { + if (i < input1->getDimensions().nbDims) { + reshape_input1_dim.d[i] = 0; + } else { + reshape_input1_dim.d[i] = 1; + } + } + auto* reshape_input1_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input1); + reshape_input1_layer->setReshapeDimensions(reshape_input1_dim); + input1 = reshape_input1_layer->getOutput(0); + } + std::vector inputs; inputs.push_back(input1); inputs.push_back(input2); @@ -74,10 +91,10 @@ class SkipLayerNormOpConverter : public OpConverter { bias_weight.values, GetPluginFieldType(bias_weight.type), static_cast(bias_weight.count)}, - { "gamma", - scale_weight.values, - GetPluginFieldType(scale_weight.type), - static_cast(scale_weight.count) }}; + {"gamma", + scale_weight.values, + GetPluginFieldType(scale_weight.type), + static_cast(scale_weight.count)}}; nvinfer1::PluginFieldCollection* pluginPtr = static_cast( malloc(sizeof(*pluginPtr) + @@ -178,11 +195,6 @@ class SkipLayerNormOpConverter : public OpConverter { auto output_name = op_desc.Output("Out")[0]; RreplenishLayerAndOutput(layer, "skip_layernorm", {output_name}, test_mode); -#else - PADDLE_THROW(platform::errors::Fatal( - "You are running the TRT Dynamic Shape mode, need to confirm that " - "your TRT version is no less than 6.0")); -#endif } };