Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add yolox implement #418

Merged
merged 6 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions examples/lite/cv/test_lite_yolox.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,13 +129,40 @@ static void test_tnn()
#endif
}

// Runs the TensorRT YoloX detector on one sample image: loads the serialized
// engine, detects boxes, draws them onto the image and writes the result to
// disk. The body is compiled only when ENABLE_TENSORRT is defined; otherwise
// this function does nothing.
static void test_tensorrt()
{
#ifdef ENABLE_TENSORRT
    const std::string engine_path = "../../..//examples/hub/trt/yolox_s_fp32.engine";
    const std::string test_img_path = "../../..//examples/lite/resources/test_lite_yolox_2.jpg";
    const std::string save_img_path = "../../..//examples/logs/test_lite_yolox_trt_4.jpg";

    // Automatic storage instead of new/delete: the detector is destroyed on
    // every exit path, including when detect() or an OpenCV call throws.
    lite::trt::cv::detection::YoloX yolox(engine_path);

    std::vector<lite::types::Boxf> detected_boxes;
    // NOTE(review): cv::imread returns an empty Mat when the file cannot be
    // read; the image is passed to detect() unchecked -- confirm detect()
    // tolerates an empty input or add a guard here.
    cv::Mat img_bgr = cv::imread(test_img_path);
    yolox.detect(img_bgr, detected_boxes);

    // Draw the detections directly onto the loaded image before saving it.
    lite::utils::draw_boxes_inplace(img_bgr, detected_boxes);

    cv::imwrite(save_img_path, img_bgr);
#endif
}



// Calls every test_* entry point of this example in a fixed order.
// test_tensorrt() has an empty body unless ENABLE_TENSORRT is defined;
// presumably the other backend tests are guarded by their own ENABLE_* macros
// in the same way -- verify against their definitions.
static void test_lite()
{
test_default();
test_onnxruntime();
test_mnn();
test_ncnn();
test_tnn();
test_tensorrt();
}

int main(__unused int argc, __unused char *argv[])
Expand Down
4 changes: 3 additions & 1 deletion lite/models.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@
#include "lite/trt/core/trt_core.h"
#include "lite/trt/cv/trt_yolofacev8.h"
#include "lite/trt/cv/trt_yolov5.h"
#include "lite/trt/cv/trt_yolox.h"
#endif

// ENABLE_MNN
Expand Down Expand Up @@ -677,14 +678,15 @@ namespace lite{
{
typedef trtcv::TRTYoloFaceV8 _TRT_YOLOFaceNet;
typedef trtcv::TRTYoloV5 _TRT_YOLOv5;
typedef trtcv::TRTYoloX _TRT_YoloX;
namespace classification
{

}
namespace detection
{
typedef _TRT_YOLOv5 YOLOV5;

typedef _TRT_YoloX YoloX;
}
namespace face
{
Expand Down
1 change: 1 addition & 0 deletions lite/trt/core/trt_core.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// Forward declarations of the model classes that live in namespace trtcv.
// The trailing comment on each line gives the upstream project listed as that
// model's reference.
namespace trtcv{
class LITE_EXPORTS TRTYoloFaceV8; // [1] * reference: https://github.com/derronqi/yolov8-face
class LITE_EXPORTS TRTYoloV5; // [2] * reference: https://github.com/ultralytics/yolov5
class LITE_EXPORTS TRTYoloX; // [3] * reference: https://github.com/Megvii-BaseDetection/YOLOX
}

namespace trtcv{
Expand Down
11 changes: 7 additions & 4 deletions lite/trt/core/trt_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include "trt_utils.h"


float* trtcv::utils::transform::create_tensor(const cv::Mat &mat,std::vector<int64_t> input_node_dims,unsigned int data_format){
void trtcv::utils::transform::create_tensor(const cv::Mat &mat,std::vector<float> &input_vector,std::vector<int64_t> input_node_dims,unsigned int data_format){
// convert the mat's pixel data into a vector of floats

const unsigned int rows = mat.rows;
Expand All @@ -19,12 +19,11 @@ float* trtcv::utils::transform::create_tensor(const cv::Mat &mat,std::vector<int
if (input_node_dims.size() != 4) throw std::runtime_error("dims mismatch.");
if (input_node_dims.at(0) != 1) throw std::runtime_error("batch != 1");


if (data_format == transform::CHW)
{
const unsigned int target_tensor_size = rows * cols * channels;
// input vector's size
float* input_vector = new float [target_tensor_size];
input_vector.resize(target_tensor_size);

for (int c = 0; c < channels; ++c)
{
Expand All @@ -36,8 +35,12 @@ float* trtcv::utils::transform::create_tensor(const cv::Mat &mat,std::vector<int
}
}
}
return input_vector;

}else
{
throw std::runtime_error("data_format must be transform::CHW!");
}

}


Expand Down
2 changes: 1 addition & 1 deletion lite/trt/core/trt_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace trtcv
{
CHW = 0, HWC =1
};
LITE_EXPORTS float* create_tensor(const cv::Mat &mat,std::vector<int64_t> input_node_dims,unsigned int data_format = CHW);
LITE_EXPORTS void create_tensor(const cv::Mat &mat,std::vector<float> &input_vector,std::vector<int64_t> input_node_dims,unsigned int data_format = CHW);


}
Expand Down
16 changes: 6 additions & 10 deletions lite/trt/cv/trt_yolofacev8.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,30 +140,26 @@ void TRTYoloFaceV8::detect(const cv::Mat &mat, std::vector<lite::types::Boxf> &b
cv::Mat normalized_image = normalize(mat);

// 2.trans to input vector
auto input = trtcv::utils::transform::create_tensor(normalized_image,input_node_dims,trtcv::utils::transform::CHW);
std::vector<float> input;
trtcv::utils::transform::create_tensor(normalized_image,input,input_node_dims,trtcv::utils::transform::CHW);

// 3. infer
cudaMemcpyAsync(buffers[0], input, input_node_dims[0] * input_node_dims[1] * input_node_dims[2] * input_node_dims[3] * sizeof(float),
cudaMemcpyAsync(buffers[0], input.data(), input_node_dims[0] * input_node_dims[1] * input_node_dims[2] * input_node_dims[3] * sizeof(float),
cudaMemcpyHostToDevice, stream);
bool status = trt_context->enqueueV3(stream);

delete[] input;
input = nullptr;

if (!status){
std::cerr << "Failed to infer by TensorRT." << std::endl;
return;
}

float* output = new float[output_node_dims[0][0] * output_node_dims[0][1] * output_node_dims[0][2]];
std::vector<float> output(output_node_dims[0][0] * output_node_dims[0][1] * output_node_dims[0][2]);

cudaMemcpyAsync(output, buffers[1], output_node_dims[0][0] * output_node_dims[0][1] * output_node_dims[0][2] * sizeof(float),
cudaMemcpyAsync(output.data(), buffers[1], output_node_dims[0][0] * output_node_dims[0][1] * output_node_dims[0][2] * sizeof(float),
cudaMemcpyDeviceToHost, stream);
// 4. generate box
generate_box(output,boxes,0.45f,0.5f);
generate_box(output.data(),boxes,0.45f,0.5f);

// free pointer
delete[] output;
output = nullptr;

}
16 changes: 7 additions & 9 deletions lite/trt/cv/trt_yolov5.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,14 +143,14 @@ void TRTYoloV5::detect(const cv::Mat &mat, std::vector<types::Boxf> &detected_bo
cv::Mat normalized_image = normalized(mat_rs);

//1. make the input
auto input = trtcv::utils::transform::create_tensor(normalized_image,input_node_dims,trtcv::utils::transform::CHW);
std::vector<float> input;
trtcv::utils::transform::create_tensor(normalized_image,input,input_node_dims,trtcv::utils::transform::CHW);

//2. infer
cudaMemcpyAsync(buffers[0], input, input_node_dims[0] * input_node_dims[1] * input_node_dims[2] * input_node_dims[3] * sizeof(float),
cudaMemcpyAsync(buffers[0], input.data(), input_node_dims[0] * input_node_dims[1] * input_node_dims[2] * input_node_dims[3] * sizeof(float),
cudaMemcpyHostToDevice, stream);
cudaStreamSynchronize(stream);
delete[] input;
input = nullptr;


bool status = trt_context->enqueueV3(stream);
cudaStreamSynchronize(stream);
Expand All @@ -164,18 +164,16 @@ void TRTYoloV5::detect(const cv::Mat &mat, std::vector<types::Boxf> &detected_bo
// get the first output dim
auto pred_dims = output_node_dims[0];

float* output = new float[pred_dims[0] * pred_dims[1] * pred_dims[2]];
std::vector<float> output(pred_dims[0] * pred_dims[1] * pred_dims[2]);

cudaMemcpyAsync(output, buffers[1], pred_dims[0] * pred_dims[1] * pred_dims[2] * sizeof(float),
cudaMemcpyAsync(output.data(), buffers[1], pred_dims[0] * pred_dims[1] * pred_dims[2] * sizeof(float),
cudaMemcpyDeviceToHost, stream);
cudaStreamSynchronize(stream);

//3. generate the boxes
std::vector<types::Boxf> bbox_collection;
generate_bboxes(scale_params, bbox_collection, output, score_threshold, img_height, img_width);
generate_bboxes(scale_params, bbox_collection, output.data(), score_threshold, img_height, img_width);
nms(bbox_collection, detected_boxes, iou_threshold, topk, nms_type);
delete[] output;
output = nullptr;
}


4 changes: 2 additions & 2 deletions lite/trt/cv/trt_yolov5.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ namespace trtcv
class LITE_EXPORTS TRTYoloV5 : public BasicTRTHandler
{
public:
explicit TRTYoloV5(const std::string &_onnx_path, unsigned int _num_threads = 1) :
BasicTRTHandler(_onnx_path, _num_threads)
explicit TRTYoloV5(const std::string &_trt_model_path, unsigned int _num_threads = 1) :
BasicTRTHandler(_trt_model_path, _num_threads)
{};

~TRTYoloV5() override = default;
Expand Down
Loading