Skip to content

Commit

Permalink
support output header and source to files
Browse files Browse the repository at this point in the history
  • Loading branch information
Superjomn committed Mar 3, 2020
1 parent 54a1d24 commit e8c5219
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 91 deletions.
54 changes: 47 additions & 7 deletions cinn/backends/codegen_c.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "cinn/backends/codegen_c.h"

#include <fstream>
#include "cinn/ir/lowered_func.h"
#include "cinn/optim/remove_nested_block.h"
#include "cinn/runtime/intrinsic.h"
Expand All @@ -9,13 +10,33 @@ namespace cinn {
namespace backends {
using namespace utils;

CodeGenC::CodeGenC(std::ostream &os, Target target, OutputKind output_kind)
: ir::IrPrinter(os), target_(target), output_kind_(output_kind) {}
void CodeGenC::Compile(const lang::Module &module, const Outputs &outputs) {
if (!outputs.c_header_name.empty()) {
LOG(WARNING) << "Output C source to file " << outputs.c_header_name;
auto source = Compile(module, OutputKind::CHeader);
std::ofstream file(outputs.c_header_name);
CHECK(file.is_open()) << "failed to open file " << outputs.c_header_name;
file << source;
file.close();
}

if (!outputs.c_source_name.empty()) {
LOG(WARNING) << "Output C source to file " << outputs.c_source_name;
auto source = Compile(module, OutputKind::CImpl);
std::ofstream file(outputs.c_source_name);
CHECK(file.is_open()) << "failed to open file " << outputs.c_source_name;
file << source;
file.close();
}
}

CodeGenC::CodeGenC(Target target) : ir::IrPrinter(ss_) {}

void CodeGenC::Compile(const lang::Module &module) {
if (output_kind_ == OutputKind::CHeader) {
std::string CodeGenC::Compile(const lang::Module &module, OutputKind output_kind) {
ss_.str();
if (output_kind == OutputKind::CHeader) {
GenerateHeaderFile(module);
} else if (output_kind_ == OutputKind::CImpl) {
} else if (output_kind == OutputKind::CImpl) {
PrintIncludes();

PrintBufferCreation(module->buffers);
Expand All @@ -26,15 +47,18 @@ void CodeGenC::Compile(const lang::Module &module) {
} else {
LOG(FATAL) << "Not supported OutputKind";
}
return ss_.str();
}
void CodeGenC::Compile(const ir::LoweredFunc &function) {
std::string CodeGenC::Compile(const ir::LoweredFunc &function) {
Print(function);
os() << "\n\n";
return ss_.str();
}
void CodeGenC::Compile(const ir::Buffer &buffer) {
std::string CodeGenC::Compile(const ir::Buffer &buffer) {
Print(runtime::BufferCreate(buffer));
os() << "\n";
os() << "\n";
return ss_.str();
}

std::string CodeGenC::PrintType(Type type) {
Expand Down Expand Up @@ -267,5 +291,21 @@ void CodeGenC::GenerateHeaderFile(const lang::Module &module) {
PrintFileGuardClose(module.name());
}

void CodeGenC::PrintFuncArg(const ir::Argument &arg) {
if (arg.is_buffer()) {
if (arg.is_input()) {
os() << "const struct cinn_buffer_t *";
} else {
os() << "struct cinn_buffer_t *";
}
} else if (arg.is_scalar()) {
os() << PrintType(arg.type) << " ";
os() << arg.name;
} else {
NOT_IMPLEMENTED
}
os() << arg.name;
}

} // namespace backends
} // namespace cinn
29 changes: 9 additions & 20 deletions cinn/backends/codegen_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,16 @@ class CodeGenC : public ir::IrPrinter {
CImpl, //! output the C implementation file.
};

CodeGenC(std::ostream& os, Target target, OutputKind output_kind);
CodeGenC(Target target);

void Compile(const lang::Module& module);
void Compile(const ir::LoweredFunc& function);
void Compile(const ir::Buffer& buffer);
void Compile(const lang::Module& module, const Outputs& outputs);

std::string Compile(const lang::Module& module, OutputKind output_kind);

protected:
std::string Compile(const ir::LoweredFunc& function);
std::string Compile(const ir::Buffer& buffer);

void GenerateHeaderFile(const lang::Module& module);

std::string PrintType(Type type);
Expand All @@ -48,25 +51,11 @@ class CodeGenC : public ir::IrPrinter {
NODETY_FORALL(__DEFINE_VISIT)
#undef __DEFINE_VISIT

void PrintFuncArg(const ir::Argument& arg) {
if (arg.is_buffer()) {
if (arg.is_input()) {
os() << "const struct cinn_buffer_t *";
} else {
os() << "struct cinn_buffer_t *";
}
} else if (arg.is_scalar()) {
os() << PrintType(arg.type) << " ";
os() << arg.name;
} else {
NOT_IMPLEMENTED
}
os() << arg.name;
}
void PrintFuncArg(const ir::Argument& arg);

private:
Target target_;
OutputKind output_kind_;
std::stringstream ss_;
};

} // namespace backends
Expand Down
60 changes: 13 additions & 47 deletions cinn/backends/codegen_c_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,40 +23,6 @@ std::tuple<ir::Tensor, ir::Tensor, ir::Tensor, lang::Buffer> CreateTensor1() {
return std::make_tuple(A, B, C, C_buf);
}

TEST(CodeGenC, basic) {
std::stringstream ss;
Target target;
CodeGenC codegen(ss, target, CodeGenC::OutputKind::CImpl);

ir::Tensor A, B, C;
lang::Buffer C_buf;
std::tie(A, B, C, C_buf) = CreateTensor1();
CHECK(!C->inlined());

auto funcs = lang::Lower("func_C", {A, B, C});
ASSERT_EQ(funcs.size(), 1UL);

codegen.Compile(funcs.front());

auto out = ss.str();

std::cout << "codegen C:" << std::endl << out << std::endl;

EXPECT_EQ(utils::Trim(out),
utils::Trim(
R"ROC(
void func_C(const struct cinn_buffer_t *A, const struct cinn_buffer_t *B, struct cinn_buffer_t *C)
{
cinn_buffer_malloc(C);
for (int32_t i = 0; (i <= 99); i += 1){
for (int32_t j = 0; (j <= 19); j += 1){
C[((i * 20) + j)] = (A[((i * 20) + j)] + B[((i * 20) + j)]);
};
};
}
)ROC"));
}

TEST(CodeGenC, module) {
ir::Tensor A, B, C;
lang::Buffer C_buf;
Expand All @@ -76,10 +42,8 @@ TEST(CodeGenC, module) {

{
std::stringstream ss;
CodeGenC codegen(ss, target, CodeGenC::OutputKind::CImpl);
codegen.Compile(module);

auto out = ss.str();
CodeGenC codegen(target);
auto out = codegen.Compile(module, CodeGenC::OutputKind::CImpl);
std::cout << "codegen C:" << std::endl << out << std::endl;

std::string target_str = R"ROC(
Expand All @@ -101,10 +65,8 @@ void add1(const struct cinn_buffer_t *A, const struct cinn_buffer_t *B, struct c
}

{
std::stringstream ss;
CodeGenC header_compiler(ss, target, CodeGenC::OutputKind::CHeader);
header_compiler.Compile(module);
auto out = ss.str();
CodeGenC compiler(target);
auto out = compiler.Compile(module, CodeGenC::OutputKind::CHeader);
std::cout << "header:\n" << out << std::endl;
auto target_str = R"ROC(
#ifndef _MODULE1_CINN_H_
Expand All @@ -121,6 +83,13 @@ void add1(const struct cinn_buffer_t *A, const struct cinn_buffer_t *B, struct c

EXPECT_EQ(utils::Trim(out), utils::Trim(target_str));
}

{
CodeGenC compiler(target);
Outputs outputs;
outputs = outputs.c_header("./generated_module1.h").c_source("./generated_module1.cc");
compiler.Compile(module, outputs);
}
}

TEST(CodeGenC, module_with_transform) {
Expand Down Expand Up @@ -158,11 +127,8 @@ TEST(CodeGenC, module_with_transform) {
module.Append(funcs.front());
module.Append(C_buf);

std::stringstream ss;
CodeGenC codegen(ss, target, CodeGenC::OutputKind::CImpl);
codegen.Compile(module);

auto out = ss.str();
CodeGenC codegen(target);
auto out = codegen.Compile(module, CodeGenC::OutputKind::CImpl);
std::cout << "codegen C:" << std::endl << out << std::endl;

auto tgt = R"ROC(
Expand Down
25 changes: 25 additions & 0 deletions cinn/backends/outputs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,29 @@

namespace cinn {
namespace lang {} // namespace lang

backends::Outputs backends::Outputs::object(const std::string &name) const {
Outputs updated = *this;
updated.object_name = name;
return updated;
}

backends::Outputs backends::Outputs::bitcode(const std::string &name) const {
Outputs updated = *this;
updated.bitcode_name = name;
return updated;
}

backends::Outputs backends::Outputs::c_header(const std::string &name) const {
Outputs updated = *this;
updated.c_header_name = name;
return updated;
}

backends::Outputs backends::Outputs::c_source(const std::string &name) const {
Outputs updated = *this;
updated.c_source_name = name;
return updated;
}

} // namespace cinn
27 changes: 10 additions & 17 deletions cinn/backends/outputs.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,16 @@ struct Outputs {
//! The name of the emitted C header file.
std::string c_header_name;

Outputs object(const std::string& name) const {
Outputs updated = *this;
updated.object_name = name;
return updated;
}

Outputs bitcode(const std::string& name) const {
Outputs updated = *this;
updated.bitcode_name = name;
return updated;
}

Outputs c_header(const std::string& name) const {
Outputs updated = *this;
updated.c_header_name = name;
return updated;
}
//! The name of the emitted C source file.
std::string c_source_name;

Outputs object(const std::string& name) const;

Outputs bitcode(const std::string& name) const;

Outputs c_header(const std::string& name) const;

Outputs c_source(const std::string& name) const;
};

} // namespace backends
Expand Down

0 comments on commit e8c5219

Please sign in to comment.