// nnvm_test.cc
#include "tvm/tvm.h"
#include "nnvm/nnvm.h"
#include "nnvm/compiler/op_attr_types.h"
#include "topi/generic/default.h"
#include <gtest/gtest.h>
/*!
 * \brief Payload stored in `attrs.parsed` of the variable nodes built by
 *  CreateVariableNode. It only carries a version counter.
 */
struct VariableParam {
  uint32_t version = 0;  // freshly created variables start at version zero
};
/*!
 * \brief Build a stand-alone node with no operator attached and wrap it in a
 *  NodeEntry.
 * \param name Name given to the node (`attrs.name`).
 * \return Entry referring to output index 0, version 0, of the new node.
 */
nnvm::NodeEntry CreateVariableNode(const std::string& name) {
  nnvm::NodePtr var = nnvm::Node::Create();
  auto& attrs = var->attrs;
  attrs.name = name;
  // No operator: presumably this is what makes nnvm treat the node as a
  // graph input variable -- confirm against nnvm::Node::is_variable().
  attrs.op = nullptr;
  attrs.parsed = VariableParam();
  return nnvm::NodeEntry{var, 0, 0};
}
using namespace nnvm;
using namespace nnvm::compiler;
/*!
 * \brief FTVMSchedule callback: returns TOPI's generic auto-inline schedule
 *  for `outs` on the target described by `target_`.
 * \param attrs Node attributes; not consulted by this schedule.
 * \param outs Output tensors to schedule.
 * \param target_ Target description string (the test below uses "llvm").
 * \return The schedule produced by default_schedule_auto_inline.
 */
Schedule sched(const NodeAttrs& attrs,
               const Array<Tensor>& outs,
               const std::string& target_) {
  (void)attrs;  // the default schedule is independent of op attributes
  auto tgt = topi::Target::create(target_);
  return topi::generic::default_schedule_auto_inline(tgt, outs);
}
// Register `sched` as elemwise_add's FTVMSchedule attribute; presumably the
// GraphCompile pass used by the test below looks this attribute up to lower
// the op.
NNVM_REGISTER_OP(elemwise_add)
    .set_attr<FTVMSchedule>("FTVMSchedule", FTVMSchedule(sched));
// End-to-end smoke test: build X + Y with elemwise_add and run the graph
// through shape/type inference, fusion and compilation for the "llvm" target.
TEST(NNVM, Basic) {
  nnvm::Graph g;
  g.attrs["target"] = std::make_shared<dmlc::any>(std::string("llvm"));
  // InferShape / InferType read per-variable hints from attrs.dict only when
  // the graph names the dict key to use. Without these two attributes the
  // "dtype" entries were silently ignored and no variable had a shape.
  g.attrs["shape_attr_key"] =
      std::make_shared<dmlc::any>(std::string("shape"));
  g.attrs["dtype_attr_key"] =
      std::make_shared<dmlc::any>(std::string("dtype"));

  auto x = CreateVariableNode("X");
  auto y = CreateVariableNode("Y");
  // Inputs to an elementwise op: give both the same shape.
  x.node->attrs.dict["shape"] = "(2, 3)";
  y.node->attrs.dict["shape"] = "(2, 3)";
  // Type code 0 (presumably float32 in nnvm's type enum -- confirm).
  x.node->attrs.dict["dtype"] = "0";
  y.node->attrs.dict["dtype"] = "0";

  auto add = nnvm::MakeNode("elemwise_add", "add1", {x, y});
  g.outputs = {add};

  // Progress messages are informational, not errors.
  LOG(INFO) << "inferring shape";
  g = ApplyPass(std::move(g), "InferShape");
  // Later passes need fully known shapes/dtypes; stop here if not.
  ASSERT_EQ(g.GetAttr<size_t>("shape_num_unknown_nodes"), 0U);

  LOG(INFO) << "inferring type";
  g = ApplyPass(std::move(g), "InferType");
  ASSERT_EQ(g.GetAttr<size_t>("dtype_num_unknown_nodes"), 0U);

  LOG(INFO) << "inferring fusible";
  g = ApplyPass(std::move(g), "GraphFindFusibleGroups");
  LOG(INFO) << "fuse";
  g = ApplyPass(std::move(g), "GraphFuse");
  LOG(INFO) << "compile";
  g = ApplyPass(std::move(g), "GraphCompile");

  // The single elemwise_add output must survive fusion and compilation.
  ASSERT_EQ(g.outputs.size(), 1U);
}