// nnvm_test.cc

#include "tvm/tvm.h"
#include "nnvm/nnvm.h"
#include "nnvm/compiler/op_attr_types.h"
#include "topi/generic/default.h"

#include <gtest/gtest.h>

// Parsed-parameter payload stored in the `parsed` slot of the variable nodes
// built by this test.
struct VariableParam {
  // Version field; default-initialized to zero.
  uint32_t version = 0;
};

// Creates a fresh node named `name` with no operator attached and a
// VariableParam as its parsed payload, and returns NodeEntry{node, 0, 0} for
// it (the two zeros are presumably output index and version, per nnvm's
// NodeEntry field order — confirm).
nnvm::NodeEntry CreateVariableNode(const std::string& name) {
  nnvm::NodePtr node = nnvm::Node::Create();
  node->attrs.name = name;
  // No operator: presumably this is what makes nnvm treat the node as a
  // variable (input) rather than an op node — confirm against is_variable().
  node->attrs.op = nullptr;
  node->attrs.parsed = VariableParam();
  nnvm::NodeEntry entry{node, 0, 0};
  return entry;
}

using namespace nnvm;
using namespace nnvm::compiler;

// FTVMSchedule implementation used by this test.
// It builds a target from the `target_` string and applies TOPI's generic
// auto-inline default schedule to `outs`. The node attributes are not
// consulted.
Schedule sched(const NodeAttrs& attrs,
               const Array<Tensor>& outs,
               const std::string& target_) {
  // `attrs` is intentionally unused: the schedule depends only on the
  // output tensors and the target.
  auto tvm_target = topi::Target::create(target_);
  return topi::generic::default_schedule_auto_inline(tvm_target, outs);
}

// Attach `sched` as the "FTVMSchedule" attribute of the elemwise_add operator.
// This runs at static-initialization time, before any test body executes, so
// the fuse/compile passes run by TEST(NNVM, Basic) can schedule elemwise_add.
// NOTE(review): assumes nothing else registers FTVMSchedule for elemwise_add
// in this C++-only build (schedules are often registered from Python). A
// duplicate registration at the same priority may be rejected — confirm.
NNVM_REGISTER_OP(elemwise_add)
.set_attr<FTVMSchedule>("FTVMSchedule", sched);

// Smoke test for the nnvm compile pipeline.
// Builds the graph X + Y with elemwise_add, then runs shape and type
// inference, fusible-group discovery, fusion, and compilation for the "llvm"
// target.
TEST(NNVM, Basic) {
  nnvm::Graph g;
  g.attrs["target"] = std::make_shared<dmlc::any>(std::string("llvm"));

  auto x = CreateVariableNode("X");
  auto y = CreateVariableNode("Y");
  // Per-variable hints are read from attrs.dict. "dtype" is the key
  // InferType uses by default. "shape" is assumed to be InferShape's default
  // key, by the same convention (confirm). Without a concrete shape the input
  // entries stay unknown, which the assertion below now catches instead of
  // failing later inside fusion or compilation.
  // dtype "0" is presumably kFloat32 in nnvm's TypeFlag enum — confirm.
  x.node->attrs.dict["shape"] = "(4,)";
  x.node->attrs.dict["dtype"] = "0";
  y.node->attrs.dict["shape"] = "(4,)";
  y.node->attrs.dict["dtype"] = "0";

  auto add = nnvm::MakeNode("elemwise_add", "add1", {x, y});
  g.outputs = {add};

  // Progress messages are informational, not errors.
  LOG(INFO) << "inferring shape";
  g = ApplyPass(g, "InferShape");
  ASSERT_EQ(0U, g.GetAttr<size_t>("shape_num_unknown_nodes"));

  LOG(INFO) << "inferring type";
  g = ApplyPass(g, "InferType");
  ASSERT_EQ(0U, g.GetAttr<size_t>("dtype_num_unknown_nodes"));

  LOG(INFO) << "inferring fusible";
  g = ApplyPass(g, "GraphFindFusibleGroups");
  LOG(INFO) << "fuse";
  g = ApplyPass(g, "GraphFuse");
  LOG(INFO) << "compile";
  g = ApplyPass(g, "GraphCompile");
}