From e8c52198fcb1f459465494b0b58d403e4e2be3cf Mon Sep 17 00:00:00 2001 From: Superjomn Date: Tue, 3 Mar 2020 20:40:31 +0800 Subject: [PATCH] support output header and source to files --- cinn/backends/codegen_c.cc | 54 +++++++++++++++++++++++++---- cinn/backends/codegen_c.h | 29 +++++----------- cinn/backends/codegen_c_test.cc | 60 +++++++-------------------------- cinn/backends/outputs.cc | 25 ++++++++++++++ cinn/backends/outputs.h | 27 ++++++--------- 5 files changed, 104 insertions(+), 91 deletions(-) diff --git a/cinn/backends/codegen_c.cc b/cinn/backends/codegen_c.cc index 1699e581d11eff..7b7bb768d31c77 100644 --- a/cinn/backends/codegen_c.cc +++ b/cinn/backends/codegen_c.cc @@ -1,5 +1,6 @@ #include "cinn/backends/codegen_c.h" +#include #include "cinn/ir/lowered_func.h" #include "cinn/optim/remove_nested_block.h" #include "cinn/runtime/intrinsic.h" @@ -9,13 +10,33 @@ namespace cinn { namespace backends { using namespace utils; -CodeGenC::CodeGenC(std::ostream &os, Target target, OutputKind output_kind) - : ir::IrPrinter(os), target_(target), output_kind_(output_kind) {} +void CodeGenC::Compile(const lang::Module &module, const Outputs &outputs) { + if (!outputs.c_header_name.empty()) { + LOG(WARNING) << "Output C source to file " << outputs.c_header_name; + auto source = Compile(module, OutputKind::CHeader); + std::ofstream file(outputs.c_header_name); + CHECK(file.is_open()) << "failed to open file " << outputs.c_header_name; + file << source; + file.close(); + } + + if (!outputs.c_source_name.empty()) { + LOG(WARNING) << "Output C source to file " << outputs.c_source_name; + auto source = Compile(module, OutputKind::CImpl); + std::ofstream file(outputs.c_source_name); + CHECK(file.is_open()) << "failed to open file " << outputs.c_source_name; + file << source; + file.close(); + } +} + +CodeGenC::CodeGenC(Target target) : ir::IrPrinter(ss_) {} -void CodeGenC::Compile(const lang::Module &module) { - if (output_kind_ == OutputKind::CHeader) { +std::string CodeGenC::Compile(const lang::Module &module, OutputKind output_kind) { + ss_.str(); + if (output_kind == OutputKind::CHeader) { GenerateHeaderFile(module); - } else if (output_kind_ == OutputKind::CImpl) { + } else if (output_kind == OutputKind::CImpl) { PrintIncludes(); PrintBufferCreation(module->buffers); @@ -26,15 +47,18 @@ void CodeGenC::Compile(const lang::Module &module) { } else { LOG(FATAL) << "Not supported OutputKind"; } + return ss_.str(); } -void CodeGenC::Compile(const ir::LoweredFunc &function) { +std::string CodeGenC::Compile(const ir::LoweredFunc &function) { Print(function); os() << "\n\n"; + return ss_.str(); } -void CodeGenC::Compile(const ir::Buffer &buffer) { +std::string CodeGenC::Compile(const ir::Buffer &buffer) { Print(runtime::BufferCreate(buffer)); os() << "\n"; os() << "\n"; + return ss_.str(); } std::string CodeGenC::PrintType(Type type) { @@ -267,5 +291,21 @@ void CodeGenC::GenerateHeaderFile(const lang::Module &module) { PrintFileGuardClose(module.name()); } +void CodeGenC::PrintFuncArg(const ir::Argument &arg) { + if (arg.is_buffer()) { + if (arg.is_input()) { + os() << "const struct cinn_buffer_t *"; + } else { + os() << "struct cinn_buffer_t *"; + } + } else if (arg.is_scalar()) { + os() << PrintType(arg.type) << " "; + os() << arg.name; + } else { + NOT_IMPLEMENTED + } + os() << arg.name; +} + } // namespace backends } // namespace cinn diff --git a/cinn/backends/codegen_c.h b/cinn/backends/codegen_c.h index 1032641a86fdd9..eb400bb641b540 100644 --- a/cinn/backends/codegen_c.h +++ b/cinn/backends/codegen_c.h @@ -25,13 +25,16 @@ class CodeGenC : public ir::IrPrinter { CImpl, //! output the C implementation file. }; - CodeGenC(std::ostream& os, Target target, OutputKind output_kind); + CodeGenC(Target target); - void Compile(const lang::Module& module); - void Compile(const ir::LoweredFunc& function); - void Compile(const ir::Buffer& buffer); + void Compile(const lang::Module& module, const Outputs& outputs); + + std::string Compile(const lang::Module& module, OutputKind output_kind); protected: + std::string Compile(const ir::LoweredFunc& function); + std::string Compile(const ir::Buffer& buffer); + void GenerateHeaderFile(const lang::Module& module); std::string PrintType(Type type); @@ -48,25 +51,11 @@ class CodeGenC : public ir::IrPrinter { NODETY_FORALL(__DEFINE_VISIT) #undef __DEFINE_VISIT - void PrintFuncArg(const ir::Argument& arg) { - if (arg.is_buffer()) { - if (arg.is_input()) { - os() << "const struct cinn_buffer_t *"; - } else { - os() << "struct cinn_buffer_t *"; - } - } else if (arg.is_scalar()) { - os() << PrintType(arg.type) << " "; - os() << arg.name; - } else { - NOT_IMPLEMENTED - } - os() << arg.name; - } + void PrintFuncArg(const ir::Argument& arg); private: Target target_; - OutputKind output_kind_; + std::stringstream ss_; }; } // namespace backends diff --git a/cinn/backends/codegen_c_test.cc b/cinn/backends/codegen_c_test.cc index 2e06d36700784d..894a188471bbf3 100644 --- a/cinn/backends/codegen_c_test.cc +++ b/cinn/backends/codegen_c_test.cc @@ -23,40 +23,6 @@ std::tuple CreateTensor1() { return std::make_tuple(A, B, C, C_buf); } -TEST(CodeGenC, basic) { - std::stringstream ss; - Target target; - CodeGenC codegen(ss, target, CodeGenC::OutputKind::CImpl); - - ir::Tensor A, B, C; - lang::Buffer C_buf; - std::tie(A, B, C, C_buf) = CreateTensor1(); - CHECK(!C->inlined()); - - auto funcs = lang::Lower("func_C", {A, B, C}); - ASSERT_EQ(funcs.size(), 1UL); - - codegen.Compile(funcs.front()); - - auto out = ss.str(); - - std::cout << "codegen C:" << std::endl << out << std::endl; - - EXPECT_EQ(utils::Trim(out), - utils::Trim( - R"ROC( -void func_C(const struct cinn_buffer_t *A, const struct cinn_buffer_t *B, struct cinn_buffer_t *C) -{ - cinn_buffer_malloc(C); - for (int32_t i = 0; (i <= 99); i += 1){ - for (int32_t j = 0; (j <= 19); j += 1){ - C[((i * 20) + j)] = (A[((i * 20) + j)] + B[((i * 20) + j)]); - }; - }; -} -)ROC")); -} - TEST(CodeGenC, module) { ir::Tensor A, B, C; lang::Buffer C_buf; @@ -76,10 +42,8 @@ TEST(CodeGenC, module) { { std::stringstream ss; - CodeGenC codegen(ss, target, CodeGenC::OutputKind::CImpl); - codegen.Compile(module); - - auto out = ss.str(); + CodeGenC codegen(target); + auto out = codegen.Compile(module, CodeGenC::OutputKind::CImpl); std::cout << "codegen C:" << std::endl << out << std::endl; std::string target_str = R"ROC( @@ -101,10 +65,8 @@ void add1(const struct cinn_buffer_t *A, const struct cinn_buffer_t *B, struct c } { - std::stringstream ss; - CodeGenC header_compiler(ss, target, CodeGenC::OutputKind::CHeader); - header_compiler.Compile(module); - auto out = ss.str(); + CodeGenC compiler(target); + auto out = compiler.Compile(module, CodeGenC::OutputKind::CHeader); std::cout << "header:\n" << out << std::endl; auto target_str = R"ROC( #ifndef _MODULE1_CINN_H_ @@ -121,6 +83,13 @@ void add1(const struct cinn_buffer_t *A, const struct cinn_buffer_t *B, struct c EXPECT_EQ(utils::Trim(out), utils::Trim(target_str)); } + + { + CodeGenC compiler(target); + Outputs outputs; + outputs = outputs.c_header("./generated_module1.h").c_source("./generated_module1.cc"); + compiler.Compile(module, outputs); + } } TEST(CodeGenC, module_with_transform) { @@ -158,11 +127,8 @@ TEST(CodeGenC, module_with_transform) { module.Append(funcs.front()); module.Append(C_buf); - std::stringstream ss; - CodeGenC codegen(ss, target, CodeGenC::OutputKind::CImpl); - codegen.Compile(module); - - auto out = ss.str(); + CodeGenC codegen(target); + auto out = codegen.Compile(module, CodeGenC::OutputKind::CImpl); std::cout << "codegen C:" << std::endl << out << std::endl; auto tgt = R"ROC( diff --git a/cinn/backends/outputs.cc b/cinn/backends/outputs.cc index e1c23fd72e421c..7d8985e296a4a3 100644 --- a/cinn/backends/outputs.cc +++ b/cinn/backends/outputs.cc @@ -2,4 +2,29 @@ namespace cinn { namespace lang {} // namespace lang + +backends::Outputs backends::Outputs::object(const std::string &name) const { + Outputs updated = *this; + updated.object_name = name; + return updated; +} + +backends::Outputs backends::Outputs::bitcode(const std::string &name) const { + Outputs updated = *this; + updated.bitcode_name = name; + return updated; +} + +backends::Outputs backends::Outputs::c_header(const std::string &name) const { + Outputs updated = *this; + updated.c_header_name = name; + return updated; +} + +backends::Outputs backends::Outputs::c_source(const std::string &name) const { + Outputs updated = *this; + updated.c_source_name = name; + return updated; +} + } // namespace cinn diff --git a/cinn/backends/outputs.h b/cinn/backends/outputs.h index db0a8cf859c785..1941ce115c923b 100644 --- a/cinn/backends/outputs.h +++ b/cinn/backends/outputs.h @@ -17,23 +17,16 @@ struct Outputs { //! The name of the emitted C header file. std::string c_header_name; - Outputs object(const std::string& name) const { - Outputs updated = *this; - updated.object_name = name; - return updated; - } - - Outputs bitcode(const std::string& name) const { - Outputs updated = *this; - updated.bitcode_name = name; - return updated; - } - - Outputs c_header(const std::string& name) const { - Outputs updated = *this; - updated.c_header_name = name; - return updated; - } + //! The name of the emitted C source file. + std::string c_source_name; + + Outputs object(const std::string& name) const; + + Outputs bitcode(const std::string& name) const; + + Outputs c_header(const std::string& name) const; + + Outputs c_source(const std::string& name) const; }; } // namespace backends