Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TRT] Support RealESRGAN #441

Merged
merged 2 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions examples/lite/cv/test_lite_realesrgan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,30 @@ static void test_default()
}


static void test_tensorrt()
{
#ifdef ENABLE_TENSORRT
std::string engine_path = "../../../examples/hub/trt/RealESRGAN_x4plus_fp16.engine";
std::string test_img_path = "../../../examples/lite/resources/test_lite_realesrgan.jpg";
std::string save_img_path = "../../../examples/logs/test_lite_realesrgan_trt.jpg";

lite::trt::cv::upscale::RealESRGAN *realesrgan = new lite::trt::cv::upscale::RealESRGAN (engine_path);

cv::Mat test_image = cv::imread(test_img_path);

realesrgan->detect(test_image,save_img_path);

std::cout<<"trt upscale enhance done!"<<std::endl;

delete realesrgan;
#endif
}


static void test_lite()
{
test_default();
test_tensorrt();
}

int main(__unused int argc, __unused char *argv[])
Expand Down
6 changes: 6 additions & 0 deletions lite/models.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@
#include "lite/trt/cv/trt_yolov6.h"
#include "lite/trt/cv/trt_yolov5_blazeface.h"
#include "lite/trt/cv/trt_lightenhance.h"
#include "lite/trt/cv/trt_realesrgan.h"
#include "lite/trt/sd/trt_clip.h"
#include "lite/trt/sd/trt_vae.h"
#include "lite/trt/sd/trt_unet.h"
Expand Down Expand Up @@ -729,6 +730,7 @@ namespace lite{
typedef trtcv::TRTYoloV6 _TRT_YOLOv6;
typedef trtcv::TRTYOLO5Face _TRT_YOLO5Face;
typedef trtcv::TRTLightEnhance _TRT_LightEnhance;
typedef trtcv::TRTRealESRGAN _TRT_RealESRGAN;
namespace classification
{

Expand All @@ -752,6 +754,10 @@ namespace lite{
{
typedef _TRT_LightEnhance LightEnhance;
}
namespace upscale
{
typedef _TRT_RealESRGAN RealESRGAN;
}
}

namespace sd
Expand Down
109 changes: 109 additions & 0 deletions lite/trt/cv/trt_realesrgan.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
//
// Created by wangzijian on 10/25/24.
//

#include "trt_realesrgan.h"
using trtcv::TRTRealESRGAN;

void TRTRealESRGAN::preprocess(const cv::Mat &frame, cv::Mat &output_mat) {
cv::cvtColor(frame,output_mat,cv::COLOR_BGR2RGB);
output_mat.convertTo(output_mat,CV_32FC3,1 / 255.f);
}


void TRTRealESRGAN::detect(const cv::Mat &input_mat, const std::string &output_path) {
if (input_mat.empty()) return;

ori_input_width = input_mat.cols;
ori_input_height = input_mat.rows;

cv::Mat preprocessed_mat;
preprocess(input_mat, preprocessed_mat);

const int batch_size = 1;
const int channels = 3;
const int input_h = preprocessed_mat.rows;
const int input_w = preprocessed_mat.cols;
const size_t input_size = batch_size * channels * input_h * input_w * sizeof(float);
const size_t output_size = batch_size * channels * input_h * 4 * input_w * 4 * sizeof(float);

for (auto& buffer : buffers) {
if (buffer) {
cudaFree(buffer);
buffer = nullptr;
}
}

cudaMalloc(&buffers[0], input_size);
cudaMalloc(&buffers[1], output_size);
if (!buffers[0] || !buffers[1]) {
std::cerr << "Failed to allocate CUDA memory" << std::endl;
return;
}

input_node_dims = {batch_size, channels, input_h, input_w};

std::vector<float> input;
trtcv::utils::transform::create_tensor(preprocessed_mat, input, input_node_dims, trtcv::utils::transform::CHW);

cudaError_t status = cudaMemcpyAsync(buffers[0], input.data(), input_size,
cudaMemcpyHostToDevice, stream);
if (status != cudaSuccess) {
std::cerr << "Input copy failed: " << cudaGetErrorString(status) << std::endl;
return;
}
cudaStreamSynchronize(stream);

nvinfer1::Dims ESRGANDims;
ESRGANDims.nbDims = 4;
ESRGANDims.d[0] = batch_size;
ESRGANDims.d[1] = channels;
ESRGANDims.d[2] = input_h;
ESRGANDims.d[3] = input_w;

auto input_tensor_name = trt_engine->getIOTensorName(0);
auto output_tensor_name = trt_engine->getIOTensorName(1);
trt_context->setTensorAddress(input_tensor_name, buffers[0]);
trt_context->setTensorAddress(output_tensor_name, buffers[1]);

trt_context->setInputShape(input_tensor_name, ESRGANDims);

bool infer_status = trt_context->enqueueV3(stream);
if (!infer_status) {
std::cerr << "TensorRT inference failed!" << std::endl;
return;
}
cudaStreamSynchronize(stream);

const size_t total_output_elements = batch_size * channels * input_h * 4 * input_w * 4;
std::vector<float> output(total_output_elements);

status = cudaMemcpyAsync(output.data(), buffers[1], output_size,
cudaMemcpyDeviceToHost, stream);
if (status != cudaSuccess) {
std::cerr << "Output copy failed: " << cudaGetErrorString(status) << std::endl;
return;
}
cudaStreamSynchronize(stream);

postprocess(output.data(), output_path);
}

void TRTRealESRGAN::postprocess(float *trt_outputs, const std::string &output_path) {
const int out_h = ori_input_height * 4;
const int out_w = ori_input_width * 4;
const int channel_step = out_h * out_w;
cv::Mat bmat(out_h, out_w, CV_32FC1, trt_outputs);
cv::Mat gmat(out_h, out_w, CV_32FC1, trt_outputs + channel_step);
cv::Mat rmat(out_h, out_w, CV_32FC1, trt_outputs + 2 * channel_step);
bmat *= 255.f;
gmat *= 255.f;
rmat *= 255.f;
std::vector<cv::Mat> channel_mats = {rmat, gmat, bmat};
cv::Mat dstimg;
cv::merge(channel_mats,dstimg);
dstimg.convertTo(dstimg, CV_8UC3);
cv::imwrite(output_path,dstimg);
}


27 changes: 27 additions & 0 deletions lite/trt/cv/trt_realesrgan.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@

//
// Created by wangzijian on 10/25/24.
//

#ifndef LITE_AI_TOOLKIT_TRT_REALESRGAN_H
#define LITE_AI_TOOLKIT_TRT_REALESRGAN_H
#include "lite/trt/core/trt_core.h"
#include "lite/trt/core/trt_utils.h"

namespace trtcv{
class LITE_EXPORTS TRTRealESRGAN : public BasicTRTHandler{
public:
explicit TRTRealESRGAN(const std::string& _trt_model_path,unsigned int _num_threads = 1):
BasicTRTHandler(_trt_model_path, _num_threads){};

private:
int ori_input_width;
int ori_input_height;
void preprocess(const cv::Mat& frame,cv::Mat &output_mat);
void postprocess(float *trt_outputs,const std::string &output_path);
public:
void detect(const cv::Mat &input_mat,const std::string &output_path);
};
}

#endif //LITE_AI_TOOLKIT_TRT_REALESRGAN_H