//===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
//
//===----------------------------------------------------------------------===//

#include "mlir/Support/IndentedOstream.h"
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/CodeGenHelpers.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "mlir/TableGen/Pattern.h"
#include "mlir/TableGen/Predicate.h"
#include "mlir/TableGen/Type.h"
#include "llvm/ADT/FunctionExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatAdapters.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Main.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"

using namespace mlir;
using namespace mlir::tblgen;

using llvm::formatv;
using llvm::Record;
using llvm::RecordKeeper;

#define DEBUG_TYPE "mlir-tblgen-rewritergen"

namespace llvm {
template <>
struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
  static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
                     raw_ostream &os, StringRef style) {
    os << v.first << ":" << v.second;
  }
};
} // namespace llvm

//===----------------------------------------------------------------------===//
// PatternEmitter
//===----------------------------------------------------------------------===//

namespace {

class StaticMatcherHelper;

class PatternEmitter {
public:
  PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os,
                 StaticMatcherHelper &helper);

  // Emits the mlir::RewritePattern struct named `rewriteName`.
  void emit(StringRef rewriteName);

  // Emits the static function of DAG matcher.
  void emitStaticMatcher(DagNode tree, std::string funcName);

private:
  // Emits the code for matching ops.
  void emitMatchLogic(DagNode tree, StringRef opName);

  // Emits the code for rewriting ops.
  void emitRewriteLogic();

  //===--------------------------------------------------------------------===//
  // Match utilities
  //===--------------------------------------------------------------------===//

  // Emits C++ statements for matching the DAG structure.
  void emitMatch(DagNode tree, StringRef name, int depth);

  // Emit C++ function call to static DAG matcher.
  void emitStaticMatchCall(DagNode tree, StringRef name);

  // Emit C++ function call to static type/attribute constraint function.
  void emitStaticVerifierCall(StringRef funcName, StringRef opName,
                              StringRef arg, StringRef failureStr);

  // Emits C++ statements for matching using a native code call.
  void emitNativeCodeMatch(DagNode tree, StringRef name, int depth);

  // Emits C++ statements for matching the op constrained by the given DAG
  // `tree` returning the op's variable name.
  void emitOpMatch(DagNode tree, StringRef opName, int depth);

  // Emits C++ statements for matching the `argIndex`-th argument of the given
  // DAG `tree` as an operand. `operandName` and `operandMatcher` indicate the
  // bound name and the constraint of the operand respectively.
  void emitOperandMatch(DagNode tree, StringRef opName, StringRef operandName,
                        DagLeaf operandMatcher, StringRef argName,
                        int argIndex);

  // Emits C++ statements for matching the operands which can be matched in
  // either order.
  void emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree,
                              StringRef opName, int argIndex, int &operandIndex,
                              int depth);

  // Emits C++ statements for matching the `argIndex`-th argument of the given
  // DAG `tree` as an attribute.
  void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex,
                          int depth);

  // Emits C++ for checking a match with a corresponding match failure
  // diagnostic.
  void emitMatchCheck(StringRef opName, const FmtObjectBase &matchFmt,
                      const llvm::formatv_object_base &failureFmt);

  // Emits C++ for checking a match with a corresponding match failure
  // diagnostics.
  void emitMatchCheck(StringRef opName, const std::string &matchStr,
                      const std::string &failureStr);

  //===--------------------------------------------------------------------===//
  // Rewrite utilities
  //===--------------------------------------------------------------------===//

  // The entry point for handling a result pattern rooted at `resultTree`. This
  // method dispatches to concrete handlers according to `resultTree`'s kind and
  // returns a symbol representing the whole value pack. Callers are expected to
  // further resolve the symbol according to the specific use case.
  //
  // `depth` is the nesting level of `resultTree`; 0 means top-level result
  // pattern. For top-level result pattern, `resultIndex` indicates which result
  // of the matched root op this pattern is intended to replace, which can be
  // used to deduce the result type of the op generated from this result
  // pattern.
  std::string handleResultPattern(DagNode resultTree, int resultIndex,
                                  int depth);

  // Emits the C++ statement to replace the matched DAG with a value built via
  // calling native C++ code.
  std::string handleReplaceWithNativeCodeCall(DagNode resultTree, int depth);

  // Returns the symbol of the old value serving as the replacement.
  StringRef handleReplaceWithValue(DagNode tree);

  // Trailing directives are used at the end of DAG node argument lists to
  // specify additional behaviour for op matchers and creators, etc.
  struct TrailingDirectives {
    // DAG node containing the `location` directive. Null if there is none.
    DagNode location;

    // DAG node containing the `returnType` directive. Null if there is none.
    DagNode returnType;

    // Number of found trailing directives.
    int numDirectives;
  };

  // Collect any trailing directives.
  TrailingDirectives getTrailingDirectives(DagNode tree);

  // Returns the location value to use.
  std::string getLocation(TrailingDirectives &tail);

  // Returns the location value to use.
  std::string handleLocationDirective(DagNode tree);

  // Emit return type argument.
  std::string handleReturnTypeArg(DagNode returnType, int i, int depth);

  // Emits the C++ statement to build a new op out of the given DAG `tree` and
  // returns the variable name that this op is assigned to. If the root op in
  // DAG `tree` has a specified name, the created op will be assigned to a
  // variable of the given name. Otherwise, a unique name will be used as the
  // result value name.
  std::string handleOpCreation(DagNode tree, int resultIndex, int depth);

  using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>;

  // Emits a local variable for each value and attribute to be used for creating
  // an op.
  void createSeparateLocalVarsForOpArgs(DagNode node,
                                        ChildNodeIndexNameMap &childNodeNames);

  // Emits the concrete arguments used to call an op's builder.
  void supplyValuesForOpArgs(DagNode node,
                             const ChildNodeIndexNameMap &childNodeNames,
                             int depth);

  // Emits the local variables for holding all values as a whole and all named
  // attributes as a whole to be used for creating an op.
  void createAggregateLocalVarsForOpArgs(
      DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth);

  // Returns the C++ expression to construct a constant attribute of the given
  // `value` for the given attribute kind `attr`.
  std::string handleConstantAttr(Attribute attr, const Twine &value);

  // Returns the C++ expression to build an argument from the given DAG `leaf`.
  // `patArgName` is used to bound the argument to the source pattern.
  std::string handleOpArgument(DagLeaf leaf, StringRef patArgName);

  //===--------------------------------------------------------------------===//
  // General utilities
  //===--------------------------------------------------------------------===//

  // Collects all of the operations within the given dag tree.
  void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);

  // Returns a unique symbol for a local variable of the given `op`.
  std::string getUniqueSymbol(const Operator *op);

  //===--------------------------------------------------------------------===//
  // Symbol utilities
  //===--------------------------------------------------------------------===//

  // Returns how many static values the given DAG `node` correspond to.
  int getNodeValueCount(DagNode node);

private:
  // Pattern instantiation location followed by the location of multiclass
  // prototypes used. This is intended to be used as a whole to
  // PrintFatalError() on errors.
  ArrayRef<SMLoc> loc;

  // Op's TableGen Record to wrapper object.
  RecordOperatorMap *opMap;

  // Handy wrapper for pattern being emitted.
  Pattern pattern;

  // Map for all bound symbols' info.
  SymbolInfoMap symbolInfoMap;

  StaticMatcherHelper &staticMatcherHelper;

  // The next unused ID for newly created values.
  unsigned nextValueId = 0;

  raw_indented_ostream os;

  // Format contexts containing placeholder substitutions.
  FmtContext fmtCtx;
};

// Tracks DagNode's reference multiple times across patterns. Enables generating
// static matcher functions for DagNode's referenced multiple times rather than
// inlining them.
class StaticMatcherHelper {
public:
  StaticMatcherHelper(raw_ostream &os, const RecordKeeper &recordKeeper,
                      RecordOperatorMap &mapper);

  // Determine if we should inline the match logic or delegate to a static
  // function.
  bool useStaticMatcher(DagNode node) {
    return refStats[node] > kStaticMatcherThreshold;
  }

  // Get the name of the static DAG matcher function corresponding to the node.
  std::string getMatcherName(DagNode node) {
    assert(useStaticMatcher(node));
    return matcherNames[node];
  }

  // Get the name of static type/attribute verification function.
  StringRef getVerifierName(DagLeaf leaf);

  // Collect the `Record`s, i.e., the DRR, so that we can get the information of
  // the duplicated DAGs.
  void addPattern(Record *record);

  // Emit all static functions of DAG Matcher.
  void populateStaticMatchers(raw_ostream &os);

  // Emit all static functions for Constraints.
  void populateStaticConstraintFunctions(raw_ostream &os);

private:
  static constexpr unsigned kStaticMatcherThreshold = 1;

  // Consider two patterns as down below,
  //   DagNode_Root_A    DagNode_Root_B
  //       \                 \
  //     DagNode_C         DagNode_C
  //         \                 \
  //       DagNode_D         DagNode_D
  //
  // DagNode_Root_A and DagNode_Root_B share the same subtree which consists of
  // DagNode_C and DagNode_D. Both DagNode_C and DagNode_D are referenced
  // multiple times so we'll have static matchers for both of them. When we're
  // emitting the match logic for DagNode_C, we will check if DagNode_D has the
  // static matcher generated. If so, then we'll generate a call to the
  // function, inline otherwise. In this case, inlining is not what we want. As
  // a result, generate the static matcher in topological order to ensure all
  // the dependent static matchers are generated and we can avoid accidentally
  // inlining.
  //
  // The topological order of all the DagNodes among all patterns.
  SmallVector<std::pair<DagNode, Record *>> topologicalOrder;

  RecordOperatorMap &opMap;

  // Records of the static function name of each DagNode
  DenseMap<DagNode, std::string> matcherNames;

  // After collecting all the DagNode in each pattern, `refStats` records the
  // number of users for each DagNode. We will generate the static matcher for a
  // DagNode while the number of users exceeds a certain threshold.
  DenseMap<DagNode, unsigned> refStats;

  // Number of static matcher generated. This is used to generate a unique name
  // for each DagNode.
  int staticMatcherCounter = 0;

  // The DagLeaf which contains type or attr constraint.
  SetVector<DagLeaf> constraints;

  // Static type/attribute verification function emitter.
  StaticVerifierFunctionEmitter staticVerifierEmitter;
};

} // namespace

PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
                               raw_ostream &os, StaticMatcherHelper &helper)
    : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
      symbolInfoMap(pat->getLoc()), staticMatcherHelper(helper), os(os) {
  fmtCtx.withBuilder("rewriter");
}

std::string PatternEmitter::handleConstantAttr(Attribute attr,
                                               const Twine &value) {
  if (!attr.isConstBuildable())
    PrintFatalError(loc, "Attribute " + attr.getAttrDefName() +
                             " does not have the 'constBuilderCall' field");

  // TODO: Verify the constants here
  return std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value));
}

void PatternEmitter::emitStaticMatcher(DagNode tree, std::string funcName) {
  os << formatv(
      "static ::mlir::LogicalResult {0}(::mlir::PatternRewriter &rewriter, "
      "::mlir::Operation *op0, ::llvm::SmallVector<::mlir::Operation "
      "*, 4> &tblgen_ops",
      funcName);

  // We pass the reference of the variables that need to be captured. Hence we
  // need to collect all the symbols in the tree first.
  pattern.collectBoundSymbols(tree, symbolInfoMap, /*isSrcPattern=*/true);
  symbolInfoMap.assignUniqueAlternativeNames();
  for (const auto &info : symbolInfoMap)
    os << formatv(", {0}", info.second.getArgDecl(info.first));

  os << ") {\n";
  os.indent();
  os << "(void)tblgen_ops;\n";

  // Note that a static matcher is considered at least one step from the match
  // entry.
  emitMatch(tree, "op0", /*depth=*/1);

  os << "return ::mlir::success();\n";
  os.unindent();
  os << "}\n\n";
}

// Helper function to match patterns.
void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) {
  if (tree.isNativeCodeCall()) {
    emitNativeCodeMatch(tree, name, depth);
    return;
  }

  if (tree.isOperation()) {
    emitOpMatch(tree, name, depth);
    return;
  }

  PrintFatalError(loc, "encountered non-op, non-NativeCodeCall match.");
}

void PatternEmitter::emitStaticMatchCall(DagNode tree, StringRef opName) {
  std::string funcName = staticMatcherHelper.getMatcherName(tree);
  os << formatv("if(::mlir::failed({0}(rewriter, {1}, tblgen_ops", funcName,
                opName);

  // TODO(chiahungduan): Add a lookupBoundSymbols() to do the subtree lookup in
  // one pass.

  // In general, bound symbol should have the unique name in the pattern but
  // for the operand, binding same symbol to multiple operands imply a
  // constraint at the same time. In this case, we will rename those operands
  // with different names. As a result, we need to collect all the symbolInfos
  // from the DagNode then get the updated name of the local variables from the
  // global symbolInfoMap.

  // Collect all the bound symbols in the Dag
  SymbolInfoMap localSymbolMap(loc);
  pattern.collectBoundSymbols(tree, localSymbolMap, /*isSrcPattern=*/true);

  for (const auto &info : localSymbolMap) {
    auto name = info.first;
    auto symboInfo = info.second;
    auto ret = symbolInfoMap.findBoundSymbol(name, symboInfo);
    os << formatv(", {0}", ret->second.getVarName(name));
  }

  os << "))) {\n";
  os.scope().os << "return ::mlir::failure();\n";
  os << "}\n";
}

void PatternEmitter::emitStaticVerifierCall(StringRef funcName,
                                            StringRef opName, StringRef arg,
                                            StringRef failureStr) {
  os << formatv("if(::mlir::failed({0}(rewriter, {1}, {2}, {3}))) {{\n",
                funcName, opName, arg, failureStr);
  os.scope().os << "return ::mlir::failure();\n";
  os << "}\n";
}

// Helper function to match patterns.
void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
                                         int depth) {
  LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall matcher pattern: ");
  LLVM_DEBUG(tree.print(llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << '\n');

  // The order of generating static matcher follows the topological order so
  // that for every dependent DagNode already have their static matcher
  // generated if needed. The reason we check if `getMatcherName(tree).empty()`
  // is when we are generating the static matcher for a DagNode itself. In this
  // case, we need to emit the function body rather than a function call.
  if (staticMatcherHelper.useStaticMatcher(tree) &&
      !staticMatcherHelper.getMatcherName(tree).empty()) {
    emitStaticMatchCall(tree, opName);

    // NativeCodeCall will never be at depth 0 so that we don't need to catch
    // the root operation as emitOpMatch();

    return;
  }

  // TODO(suderman): iterate through arguments, determine their types, output
  // names.
  SmallVector<std::string, 8> capture;

  raw_indented_ostream::DelimitedScope scope(os);

  for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
    std::string argName = formatv("arg{0}_{1}", depth, i);
    if (DagNode argTree = tree.getArgAsNestedDag(i)) {
      if (argTree.isEither())
        PrintFatalError(loc, "NativeCodeCall cannot have `either` operands");

      os << "::mlir::Value " << argName << ";\n";
    } else {
      auto leaf = tree.getArgAsLeaf(i);
      if (leaf.isAttrMatcher() || leaf.isConstantAttr()) {
        os << "::mlir::Attribute " << argName << ";\n";
      } else {
        os << "::mlir::Value " << argName << ";\n";
      }
    }

    capture.push_back(std::move(argName));
  }

  auto tail = getTrailingDirectives(tree);
  if (tail.returnType)
    PrintFatalError(loc, "`NativeCodeCall` cannot have return type specifier");
  auto locToUse = getLocation(tail);

  auto fmt = tree.getNativeCodeTemplate();
  if (fmt.count("$_self") != 1)
    PrintFatalError(loc, "NativeCodeCall must have $_self as argument for "
                         "passing the defining Operation");

  auto nativeCodeCall = std::string(
      tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse).withSelf(opName.str()),
            static_cast<ArrayRef<std::string>>(capture)));

  emitMatchCheck(opName, formatv("!::mlir::failed({0})", nativeCodeCall),
                 formatv("\"{0} return ::mlir::failure\"", nativeCodeCall));

  for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
    auto name = tree.getArgName(i);
    if (!name.empty() && name != "_") {
      os << formatv("{0} = {1};\n", name, capture[i]);
    }
  }

  for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
    std::string argName = capture[i];

    // Handle nested DAG construct first
    if (DagNode argTree = tree.getArgAsNestedDag(i)) {
      PrintFatalError(
          loc, formatv("Matching nested tree in NativeCodecall not support for "
                       "{0} as arg {1}",
                       argName, i));
    }

    DagLeaf leaf = tree.getArgAsLeaf(i);

    // The parameter for native function doesn't bind any constraints.
    if (leaf.isUnspecified())
      continue;

    auto constraint = leaf.getAsConstraint();

    std::string self;
    if (leaf.isAttrMatcher() || leaf.isConstantAttr())
      self = argName;
    else
      self = formatv("{0}.getType()", argName);
    StringRef verifier = staticMatcherHelper.getVerifierName(leaf);
    emitStaticVerifierCall(
        verifier, opName, self,
        formatv("\"operand {0} of native code call '{1}' failed to satisfy "
                "constraint: "
                "'{2}'\"",
                i, tree.getNativeCodeTemplate(),
                escapeString(constraint.getSummary()))
            .str());
  }

  LLVM_DEBUG(llvm::dbgs() << "done emitting match for native code call\n");
}

// Helper function to match patterns.
void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
  Operator &op = tree.getDialectOp(opMap);
  LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '"
                          << op.getOperationName() << "' at depth " << depth
                          << '\n');

  auto getCastedName = [depth]() -> std::string {
    return formatv("castedOp{0}", depth);
  };

  // The order of generating static matcher follows the topological order so
  // that for every dependent DagNode already have their static matcher
  // generated if needed. The reason we check if `getMatcherName(tree).empty()`
  // is when we are generating the static matcher for a DagNode itself. In this
  // case, we need to emit the function body rather than a function call.
  if (staticMatcherHelper.useStaticMatcher(tree) &&
      !staticMatcherHelper.getMatcherName(tree).empty()) {
    emitStaticMatchCall(tree, opName);
    // In the codegen of rewriter, we suppose that castedOp0 will capture the
    // root operation. Manually add it if the root DagNode is a static matcher.
    if (depth == 0)
      os << formatv("auto {2} = ::llvm::dyn_cast_or_null<{1}>({0}); "
                    "(void){2};\n",
                    opName, op.getQualCppClassName(), getCastedName());
    return;
  }

  std::string castedName = getCastedName();
  os << formatv("auto {0} = ::llvm::dyn_cast<{2}>({1}); "
                "(void){0};\n",
                castedName, opName, op.getQualCppClassName());

  // Skip the operand matching at depth 0 as the pattern rewriter already does.
  if (depth != 0)
    emitMatchCheck(opName, /*matchStr=*/castedName,
                   formatv("\"{0} is not {1} type\"", castedName,
                           op.getQualCppClassName()));

  // If the operand's name is set, set to that variable.
  auto name = tree.getSymbol();
  if (!name.empty())
    os << formatv("{0} = {1};\n", name, castedName);

  for (int i = 0, e = tree.getNumArgs(), nextOperand = 0; i != e; ++i) {
    auto opArg = op.getArg(i);
    std::string argName = formatv("op{0}", depth + 1);

    // Handle nested DAG construct first
    if (DagNode argTree = tree.getArgAsNestedDag(i)) {
      if (argTree.isEither()) {
        emitEitherOperandMatch(tree, argTree, castedName, i, nextOperand,
                               depth);
        continue;
      }
      if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
        if (operand->isVariableLength()) {
          auto error = formatv("use nested DAG construct to match op {0}'s "
                               "variadic operand #{1} unsupported now",
                               op.getOperationName(), i);
          PrintFatalError(loc, error);
        }
      }

      os << "{\n";

      // Attributes don't count for getODSOperands.
      // TODO: Operand is a Value, check if we should remove `getDefiningOp()`.
      os.indent() << formatv(
          "auto *{0} = "
          "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
          argName, castedName, nextOperand);
      // Null check of operand's definingOp
      emitMatchCheck(
          castedName, /*matchStr=*/argName,
          formatv("\"There's no operation that defines operand {0} of {1}\"",
                  nextOperand++, castedName));
      emitMatch(argTree, argName, depth + 1);
      os << formatv("tblgen_ops.push_back({0});\n", argName);
      os.unindent() << "}\n";
      continue;
    }

    // Next handle DAG leaf: operand or attribute
    if (opArg.is<NamedTypeConstraint *>()) {
      auto operandName =
          formatv("{0}.getODSOperands({1})", castedName, nextOperand);
      emitOperandMatch(tree, castedName, operandName.str(),
                       /*operandMatcher=*/tree.getArgAsLeaf(i),
                       /*argName=*/tree.getArgName(i),
                       /*argIndex=*/i);
      ++nextOperand;
    } else if (opArg.is<NamedAttribute *>()) {
      emitAttributeMatch(tree, opName, i, depth);
    } else {
      PrintFatalError(loc, "unhandled case when matching op");
    }
  }
  LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '"
                          << op.getOperationName() << "' at depth " << depth
                          << '\n');
}

void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
                                      StringRef operandName,
                                      DagLeaf operandMatcher, StringRef argName,
                                      int argIndex) {
  Operator &op = tree.getDialectOp(opMap);
  auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>();

  // If a constraint is specified, we need to generate C++ statements to
  // check the constraint.
  if (!operandMatcher.isUnspecified()) {
    if (!operandMatcher.isOperandMatcher())
      PrintFatalError(
          loc, formatv("the {1}-th argument of op '{0}' should be an operand",
                       op.getOperationName(), argIndex + 1));

    // Only need to verify if the matcher's type is different from the one
    // of op definition.
    Constraint constraint = operandMatcher.getAsConstraint();
    if (operand->constraint != constraint) {
      if (operand->isVariableLength()) {
        auto error = formatv(
            "further constrain op {0}'s variadic operand #{1} unsupported now",
            op.getOperationName(), argIndex);
        PrintFatalError(loc, error);
      }
      auto self = formatv("(*{0}.begin()).getType()", operandName);
      StringRef verifier = staticMatcherHelper.getVerifierName(operandMatcher);
      emitStaticVerifierCall(
          verifier, opName, self.str(),
          formatv(
              "\"operand {0} of op '{1}' failed to satisfy constraint: '{2}'\"",
              operand - op.operand_begin(), op.getOperationName(),
              escapeString(constraint.getSummary()))
              .str());
    }
  }

  // Capture the value
  // `$_` is a special symbol to ignore op argument matching.
  if (!argName.empty() && argName != "_") {
    auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, argIndex);
    os << formatv("{0} = {1};\n", res->second.getVarName(argName), operandName);
  }
}

void PatternEmitter::emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree,
                                            StringRef opName, int argIndex,
                                            int &operandIndex, int depth) {
  constexpr int numEitherArgs = 2;
  if (eitherArgTree.getNumArgs() != numEitherArgs)
    PrintFatalError(loc, "`either` only supports grouping two operands");

  Operator &op = tree.getDialectOp(opMap);

  std::string codeBuffer;
  llvm::raw_string_ostream tblgenOps(codeBuffer);

  std::string lambda = formatv("eitherLambda{0}", depth);
  os << formatv(
      "auto {0} = [&](::mlir::OperandRange v0, ::mlir::OperandRange v1) {{\n",
      lambda);

  os.indent();

  for (int i = 0; i < numEitherArgs; ++i, ++argIndex) {
    if (DagNode argTree = eitherArgTree.getArgAsNestedDag(i)) {
      if (argTree.isEither())
        PrintFatalError(loc, "either cannot be nested");

      std::string argName = formatv("local_op_{0}", i).str();

      os << formatv("auto {0} = (*v{1}.begin()).getDefiningOp();\n", argName,
                    i);
      emitMatchCheck(
          opName, /*matchStr=*/argName,
          formatv("\"There's no operation that defines operand {0} of {1}\"",
                  operandIndex++, opName));
      emitMatch(argTree, argName, depth + 1);
      // `tblgen_ops` is used to collect the matched operations. In either, we
      // need to queue the operation only if the matching success. Thus we emit
      // the code at the end.
      tblgenOps << formatv("tblgen_ops.push_back({0});\n", argName);
    } else if (op.getArg(argIndex).is<NamedTypeConstraint *>()) {
      emitOperandMatch(tree, opName, /*operandName=*/formatv("v{0}", i).str(),
                       /*operandMatcher=*/eitherArgTree.getArgAsLeaf(i),
                       /*argName=*/eitherArgTree.getArgName(i), argIndex);
      ++operandIndex;
    } else {
      PrintFatalError(loc, "either can only be applied on operand");
    }
  }

  os << tblgenOps.str();
  os << "return ::mlir::success();\n";
  os.unindent() << "};\n";

  os << "{\n";
  os.indent();

  os << formatv("auto eitherOperand0 = {0}.getODSOperands({1});\n", opName,
                operandIndex - 2);
  os << formatv("auto eitherOperand1 = {0}.getODSOperands({1});\n", opName,
                operandIndex - 1);

  os << formatv("if(::mlir::failed({0}(eitherOperand0, eitherOperand1)) && "
                "::mlir::failed({0}(eitherOperand1, "
                "eitherOperand0)))\n",
                lambda);
  os.indent() << "return ::mlir::failure();\n";

  os.unindent().unindent() << "}\n";
}

void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
                                        int argIndex, int depth) {
  Operator &op = tree.getDialectOp(opMap);
  auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
  const auto &attr = namedAttr->attr;

  os << "{\n";
  os.indent() << formatv("auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");"
                         "(void)tblgen_attr;\n",
                         opName, attr.getStorageType(), namedAttr->name);

  // TODO: This should use getter method to avoid duplication.
  if (attr.hasDefaultValue()) {
    os << "if (!tblgen_attr) tblgen_attr = "
       << std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx,
                            attr.getDefaultValue()))
       << ";\n";
  } else if (attr.isOptional()) {
    // For a missing attribute that is optional according to definition, we
    // should just capture a mlir::Attribute() to signal the missing state.
    // That is precisely what getAttr() returns on missing attributes.
  } else {
    emitMatchCheck(opName, tgfmt("tblgen_attr", &fmtCtx),
                   formatv("\"expected op '{0}' to have attribute '{1}' "
                           "of type '{2}'\"",
                           op.getOperationName(), namedAttr->name,
                           attr.getStorageType()));
  }

  auto matcher = tree.getArgAsLeaf(argIndex);
  if (!matcher.isUnspecified()) {
    if (!matcher.isAttrMatcher()) {
      PrintFatalError(
          loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
                       op.getOperationName(), argIndex + 1));
    }

    // If a constraint is specified, we need to generate function call to its
    // static verifier.
    StringRef verifier = staticMatcherHelper.getVerifierName(matcher);
    if (attr.isOptional()) {
      // Avoid dereferencing null attribute. This is using a simple heuristic to
      // avoid common cases of attempting to dereference null attribute. This
      // will return where there is no check if attribute is null unless the
      // attribute's value is not used.
      // FIXME: This could be improved as some null dereferences could slip
      // through.
      if (!StringRef(matcher.getConditionTemplate()).contains("!$_self") &&
          StringRef(matcher.getConditionTemplate()).contains("$_self")) {
        os << "if (!tblgen_attr) return ::mlir::failure();\n";
      }
    }
    emitStaticVerifierCall(
        verifier, opName, "tblgen_attr",
        formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: "
                "'{2}'\"",
                op.getOperationName(), namedAttr->name,
                escapeString(matcher.getAsConstraint().getSummary()))
            .str());
  }

  // Capture the value
  auto name = tree.getArgName(argIndex);
  // `$_` is a special symbol to ignore op argument matching.
  if (!name.empty() && name != "_") {
    os << formatv("{0} = tblgen_attr;\n", name);
  }

  os.unindent() << "}\n";
}

void PatternEmitter::emitMatchCheck(
    StringRef opName, const FmtObjectBase &matchFmt,
    const llvm::formatv_object_base &failureFmt) {
  emitMatchCheck(opName, matchFmt.str(), failureFmt.str());
}

void PatternEmitter::emitMatchCheck(StringRef opName,
                                    const std::string &matchStr,
                                    const std::string &failureStr) {

  os << "if (!(" << matchStr << "))";
  os.scope("{\n", "\n}\n").os << "return rewriter.notifyMatchFailure(" << opName
                              << ", [&](::mlir::Diagnostic &diag) {\n  diag << "
                              << failureStr << ";\n});";
}

void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
  LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n");
  int depth = 0;
  emitMatch(tree, opName, depth);

  for (auto &appliedConstraint : pattern.getConstraints()) {
    auto &constraint = appliedConstraint.constraint;
    auto &entities = appliedConstraint.entities;

    auto condition = constraint.getConditionTemplate();
    if (isa<TypeConstraint>(constraint)) {
      if (entities.size() != 1)
        PrintFatalError(loc, "type constraint requires exactly one argument");

      auto self = formatv("({0}.getType())",
                          symbolInfoMap.getValueAndRangeUse(entities.front()));
      emitMatchCheck(
          opName, tgfmt(condition, &fmtCtx.withSelf(self.str())),
          formatv("\"value entity '{0}' failed to satisfy constraint: '{1}'\"",
                  entities.front(), escapeString(constraint.getSummary())));

    } else if (isa<AttrConstraint>(constraint)) {
      PrintFatalError(
          loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
    } else {
      // TODO: replace formatv arguments with the exact specified
      // args.
      if (entities.size() > 4) {
        PrintFatalError(loc, "only support up to 4-entity constraints now");
      }
      SmallVector<std::string, 4> names;
      int i = 0;
      for (int e = entities.size(); i < e; ++i)
        names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i]));
      std::string self = appliedConstraint.self;
      if (!self.empty())
        self = symbolInfoMap.getValueAndRangeUse(self);
      for (; i < 4; ++i)
        names.push_back("<unused>");
      emitMatchCheck(opName,
                     tgfmt(condition, &fmtCtx.withSelf(self), names[0],
                           names[1], names[2], names[3]),
                     formatv("\"entities '{0}' failed to satisfy constraint: "
                             "'{1}'\"",
                             llvm::join(entities, ", "),
                             escapeString(constraint.getSummary())));
    }
  }

  // Some of the operands could be bound to the same symbol name, we need
  // to enforce equality constraint on those.
  // TODO: we should be able to emit equality checks early
  // and short circuit unnecessary work if vars are not equal.
  for (auto symbolInfoIt = symbolInfoMap.begin();
       symbolInfoIt != symbolInfoMap.end();) {
    auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first);
    auto startRange = range.first;
    auto endRange = range.second;

    auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first);
    for (++startRange; startRange != endRange; ++startRange) {
      auto secondOperand = startRange->second.getVarName(symbolInfoIt->first);
      emitMatchCheck(
          opName,
          formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand),
          formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand,
                  secondOperand));
    }

    symbolInfoIt = endRange;
  }

  LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
}

void PatternEmitter::collectOps(DagNode tree,
                                llvm::SmallPtrSetImpl<const Operator *> &ops) {
  // Check if this tree is an operation.
  if (tree.isOperation()) {
    const Operator &op = tree.getDialectOp(opMap);
    LLVM_DEBUG(llvm::dbgs()
               << "found operation " << op.getOperationName() << '\n');
    ops.insert(&op);
  }

  // Recurse the arguments of the tree.
  for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
    if (auto child = tree.getArgAsNestedDag(i))
      collectOps(child, ops);
}

void PatternEmitter::emit(StringRef rewriteName) {
  // Get the DAG tree for the source pattern.
  DagNode sourceTree = pattern.getSourcePattern();

  const Operator &rootOp = pattern.getSourceRootOp();
  auto rootName = rootOp.getOperationName();

  // Collect the set of result operations.
  llvm::SmallPtrSet<const Operator *, 4> resultOps;
  LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n");
  for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) {
    collectOps(pattern.getResultPattern(i), resultOps);
  }
  LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n");

  // Emit RewritePattern for Pattern.
  auto locs = pattern.getLocation();
  os << formatv("/* Generated from:\n    {0:$[ instantiating\n    ]}\n*/\n",
                make_range(locs.rbegin(), locs.rend()));
  os << formatv(R"(struct {0} : public ::mlir::RewritePattern {
  {0}(::mlir::MLIRContext *context)
      : ::mlir::RewritePattern("{1}", {2}, context, {{)",
                rewriteName, rootName, pattern.getBenefit());
  // Sort result operators by name.
  llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(),
                                                         resultOps.end());
  llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) {
    return lhs->getOperationName() < rhs->getOperationName();
  });
  llvm::interleaveComma(sortedResultOps, os, [&](const Operator *op) {
    os << '"' << op->getOperationName() << '"';
  });
  os << "}) {}\n";

  // Emit matchAndRewrite() function.
  {
    auto classScope = os.scope();
    os.printReindented(R"(
    ::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0,
        ::mlir::PatternRewriter &rewriter) const override {)")
        << '\n';
    {
      auto functionScope = os.scope();

      // Register all symbols bound in the source pattern.
      pattern.collectSourcePatternBoundSymbols(symbolInfoMap);

      LLVM_DEBUG(llvm::dbgs()
                 << "start creating local variables for capturing matches\n");
      os << "// Variables for capturing values and attributes used while "
            "creating ops\n";
      // Create local variables for storing the arguments and results bound
      // to symbols.
      for (const auto &symbolInfoPair : symbolInfoMap) {
        const auto &symbol = symbolInfoPair.first;
        const auto &info = symbolInfoPair.second;

        os << info.getVarDecl(symbol);
      }
      // TODO: capture ops with consistent numbering so that it can be
      // reused for fused loc.
      os << "::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops;\n\n";
      LLVM_DEBUG(llvm::dbgs()
                 << "done creating local variables for capturing matches\n");

      os << "// Match\n";
      os << "tblgen_ops.push_back(op0);\n";
      emitMatchLogic(sourceTree, "op0");

      os << "\n// Rewrite\n";
      emitRewriteLogic();

      os << "return ::mlir::success();\n";
    }
    os << "};\n";
  }
  os << "};\n\n";
}

void PatternEmitter::emitRewriteLogic() {
  LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n");
  const Operator &rootOp = pattern.getSourceRootOp();
  int numExpectedResults = rootOp.getNumResults();
  int numResultPatterns = pattern.getNumResultPatterns();

  // First register all symbols bound to ops generated in result patterns.
  pattern.collectResultPatternBoundSymbols(symbolInfoMap);

  // Only the last N static values generated are used to replace the matched
  // root N-result op. We need to calculate the starting index (of the results
  // of the matched op) each result pattern is to replace.
  SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults);
  // If we don't need to replace any value at all, set the replacement starting
  // index as the number of result patterns so we skip all of them when trying
  // to replace the matched op's results.
  int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1;
  for (int i = numResultPatterns - 1; i >= 0; --i) {
    auto numValues = getNodeValueCount(pattern.getResultPattern(i));
    offsets[i] = offsets[i + 1] - numValues;
    if (offsets[i] == 0) {
      if (replStartIndex == -1)
        replStartIndex = i;
    } else if (offsets[i] < 0 && offsets[i + 1] > 0) {
      auto error = formatv(
          "cannot use the same multi-result op '{0}' to generate both "
          "auxiliary values and values to be used for replacing the matched op",
          pattern.getResultPattern(i).getSymbol());
      PrintFatalError(loc, error);
    }
  }

  if (offsets.front() > 0) {
    const char error[] = "no enough values generated to replace the matched op";
    PrintFatalError(loc, error);
  }

  os << "auto odsLoc = rewriter.getFusedLoc({";
  for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
    os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
  }
  os << "}); (void)odsLoc;\n";

  // Process auxiliary result patterns.
  for (int i = 0; i < replStartIndex; ++i) {
    DagNode resultTree = pattern.getResultPattern(i);
    auto val = handleResultPattern(resultTree, offsets[i], 0);
    // Normal op creation will be streamed to `os` by the above call; but
    // NativeCodeCall will only be materialized to `os` if it is used. Here
    // we are handling auxiliary patterns so we want the side effect even if
    // NativeCodeCall is not replacing matched root op's results.
    if (resultTree.isNativeCodeCall() &&
        resultTree.getNumReturnsOfNativeCode() == 0)
      os << val << ";\n";
  }

  if (numExpectedResults == 0) {
    assert(replStartIndex >= numResultPatterns &&
           "invalid auxiliary vs. replacement pattern division!");
    // No result to replace. Just erase the op.
    os << "rewriter.eraseOp(op0);\n";
  } else {
    // Process replacement result patterns.
    os << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n";
    for (int i = replStartIndex; i < numResultPatterns; ++i) {
      DagNode resultTree = pattern.getResultPattern(i);
      auto val = handleResultPattern(resultTree, offsets[i], 0);
      os << "\n";
      // Resolve each symbol for all range use so that we can loop over them.
      // We need an explicit cast to `SmallVector` to capture the cases where
      // `{0}` resolves to an `Operation::result_range` as well as cases that
      // are not iterable (e.g. vector that gets wrapped in additional braces by
      // RewriterGen).
      // TODO: Revisit the need for materializing a vector.
      os << symbolInfoMap.getAllRangeUse(
          val,
          "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n"
          "  tblgen_repl_values.push_back(v);\n}\n",
          "\n");
    }
    os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n";
  }

  LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
}

std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
  return std::string(
      formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++));
}

std::string PatternEmitter::handleResultPattern(DagNode resultTree,
                                                int resultIndex, int depth) {
  LLVM_DEBUG(llvm::dbgs() << "handle result pattern: ");
  LLVM_DEBUG(resultTree.print(llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << '\n');

  if (resultTree.isLocationDirective()) {
    PrintFatalError(loc,
                    "location directive can only be used with op creation");
  }

  if (resultTree.isNativeCodeCall())
    return handleReplaceWithNativeCodeCall(resultTree, depth);

  if (resultTree.isReplaceWithValue())
    return handleReplaceWithValue(resultTree).str();

  // Normal op creation.
  auto symbol = handleOpCreation(resultTree, resultIndex, depth);
  if (resultTree.getSymbol().empty()) {
    // This is an op not explicitly bound to a symbol in the rewrite rule.
    // Register the auto-generated symbol for it.
    symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree));
  }
  return symbol;
}

StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) {
  assert(tree.isReplaceWithValue());

  if (tree.getNumArgs() != 1) {
    PrintFatalError(
        loc, "replaceWithValue directive must take exactly one argument");
  }

  if (!tree.getSymbol().empty()) {
    PrintFatalError(loc, "cannot bind symbol to replaceWithValue");
  }

  return tree.getArgName(0);
}

std::string PatternEmitter::handleLocationDirective(DagNode tree) {
  assert(tree.isLocationDirective());
  auto lookUpArgLoc = [this, &tree](int idx) {
    const auto *const lookupFmt = "{0}.getLoc()";
    return symbolInfoMap.getValueAndRangeUse(tree.getArgName(idx), lookupFmt);
  };

  if (tree.getNumArgs() == 0)
    llvm::PrintFatalError(
        "At least one argument to location directive required");

  if (!tree.getSymbol().empty())
    PrintFatalError(loc, "cannot bind symbol to location");

  if (tree.getNumArgs() == 1) {
    DagLeaf leaf = tree.getArgAsLeaf(0);
    if (leaf.isStringAttr())
      return formatv("::mlir::NameLoc::get(rewriter.getStringAttr(\"{0}\"))",
                     leaf.getStringAttr())
          .str();
    return lookUpArgLoc(0);
  }

  std::string ret;
  llvm::raw_string_ostream os(ret);
  std::string strAttr;
  os << "rewriter.getFusedLoc({";
  bool first = true;
  for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
    DagLeaf leaf = tree.getArgAsLeaf(i);
    // Handle the optional string value.
    if (leaf.isStringAttr()) {
      if (!strAttr.empty())
        llvm::PrintFatalError("Only one string attribute may be specified");
      strAttr = leaf.getStringAttr();
      continue;
    }
    os << (first ? "" : ", ") << lookUpArgLoc(i);
    first = false;
  }
  os << "}";
  if (!strAttr.empty()) {
    os << ", rewriter.getStringAttr(\"" << strAttr << "\")";
  }
  os << ")";
  return os.str();
}

std::string PatternEmitter::handleReturnTypeArg(DagNode returnType, int i,
                                                int depth) {
  // Nested NativeCodeCall.
  if (auto dagNode = returnType.getArgAsNestedDag(i)) {
    if (!dagNode.isNativeCodeCall())
      PrintFatalError(loc, "nested DAG in `returnType` must be a native code "
                           "call");
    return handleReplaceWithNativeCodeCall(dagNode, depth);
  }
  // String literal.
  auto dagLeaf = returnType.getArgAsLeaf(i);
  if (dagLeaf.isStringAttr())
    return tgfmt(dagLeaf.getStringAttr(), &fmtCtx);
  return tgfmt(
      "$0.getType()", &fmtCtx,
      handleOpArgument(returnType.getArgAsLeaf(i), returnType.getArgName(i)));
}

std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
                                             StringRef patArgName) {
  if (leaf.isStringAttr())
    PrintFatalError(loc, "raw string not supported as argument");
  if (leaf.isConstantAttr()) {
    auto constAttr = leaf.getAsConstantAttr();
    return handleConstantAttr(constAttr.getAttribute(),
                              constAttr.getConstantValue());
  }
  if (leaf.isEnumAttrCase()) {
    auto enumCase = leaf.getAsEnumAttrCase();
    // This is an enum case backed by an IntegerAttr. We need to get its value
    // to build the constant.
    std::string val = std::to_string(enumCase.getValue());
    return handleConstantAttr(enumCase, val);
  }

  LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n");
  auto argName = symbolInfoMap.getValueAndRangeUse(patArgName);
  if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
    LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName
                            << "' (via symbol ref)\n");
    return argName;
  }
  if (leaf.isNativeCodeCall()) {
    auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName));
    LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl
                            << "' (via NativeCodeCall)\n");
    return std::string(repl);
  }
  PrintFatalError(loc, "unhandled case when rewriting op");
}

std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree,
                                                            int depth) {
  LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: ");
  LLVM_DEBUG(tree.print(llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << '\n');

  auto fmt = tree.getNativeCodeTemplate();

  SmallVector<std::string, 16> attrs;

  auto tail = getTrailingDirectives(tree);
  if (tail.returnType)
    PrintFatalError(loc, "`NativeCodeCall` cannot have return type specifier");
  auto locToUse = getLocation(tail);

  for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
    if (tree.isNestedDagArg(i)) {
      attrs.push_back(
          handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1));
    } else {
      attrs.push_back(
          handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)));
    }
    LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i
                            << " replacement: " << attrs[i] << "\n");
  }

  std::string symbol = tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse),
                             static_cast<ArrayRef<std::string>>(attrs));

  // In general, NativeCodeCall without naming binding don't need this. To
  // ensure void helper function has been correctly labeled, i.e., use
  // NativeCodeCallVoid, we cache the result to a local variable so that we will
  // get a compilation error in the auto-generated file.
  // Example.
  //   // In the td file
  //   Pat<(...), (NativeCodeCall<Foo> ...)>
  //
  //   ---
  //
  //   // In the auto-generated .cpp
  //   ...
  //   // Causes compilation error if Foo() returns void.
  //   auto nativeVar = Foo();
  //   ...
  if (tree.getNumReturnsOfNativeCode() != 0) {
    // Determine the local variable name for return value.
    std::string varName =
        SymbolInfoMap::getValuePackName(tree.getSymbol()).str();
    if (varName.empty()) {
      varName = formatv("nativeVar_{0}", nextValueId++);
      // Register the local variable for later uses.
      symbolInfoMap.bindValues(varName, tree.getNumReturnsOfNativeCode());
    }

    // Catch the return value of helper function.
    os << formatv("auto {0} = {1}; (void){0};\n", varName, symbol);

    if (!tree.getSymbol().empty())
      symbol = tree.getSymbol().str();
    else
      symbol = varName;
  }

  return symbol;
}

int PatternEmitter::getNodeValueCount(DagNode node) {
  if (node.isOperation()) {
    // If the op is bound to a symbol in the rewrite rule, query its result
    // count from the symbol info map.
    auto symbol = node.getSymbol();
    if (!symbol.empty()) {
      return symbolInfoMap.getStaticValueCount(symbol);
    }
    // Otherwise this is an unbound op; we will use all its results.
    return pattern.getDialectOp(node).getNumResults();
  }

  if (node.isNativeCodeCall())
    return node.getNumReturnsOfNativeCode();

  return 1;
}

PatternEmitter::TrailingDirectives
PatternEmitter::getTrailingDirectives(DagNode tree) {
  TrailingDirectives tail = {DagNode(nullptr), DagNode(nullptr), 0};

  // Look backwards through the arguments.
  auto numPatArgs = tree.getNumArgs();
  for (int i = numPatArgs - 1; i >= 0; --i) {
    auto dagArg = tree.getArgAsNestedDag(i);
    // A leaf is not a directive. Stop looking.
    if (!dagArg)
      break;

    auto isLocation = dagArg.isLocationDirective();
    auto isReturnType = dagArg.isReturnTypeDirective();
    // If encountered a DAG node that isn't a trailing directive, stop looking.
    if (!(isLocation || isReturnType))
      break;
    // Save the directive, but error if one of the same type was already
    // found.
    ++tail.numDirectives;
    if (isLocation) {
      if (tail.location)
        PrintFatalError(loc, "`location` directive can only be specified "
                             "once");
      tail.location = dagArg;
    } else if (isReturnType) {
      if (tail.returnType)
        PrintFatalError(loc, "`returnType` directive can only be specified "
                             "once");
      tail.returnType = dagArg;
    }
  }

  return tail;
}

std::string
PatternEmitter::getLocation(PatternEmitter::TrailingDirectives &tail) {
  if (tail.location)
    return handleLocationDirective(tail.location);

  // If no explicit location is given, use the default, all fused, location.
  return "odsLoc";
}

std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
                                             int depth) {
  LLVM_DEBUG(llvm::dbgs() << "create op for pattern: ");
  LLVM_DEBUG(tree.print(llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << '\n');

  Operator &resultOp = tree.getDialectOp(opMap);
  auto numOpArgs = resultOp.getNumArgs();
  auto numPatArgs = tree.getNumArgs();

  auto tail = getTrailingDirectives(tree);
  auto locToUse = getLocation(tail);

  auto inPattern = numPatArgs - tail.numDirectives;
  if (numOpArgs != inPattern) {
    PrintFatalError(loc,
                    formatv("resultant op '{0}' argument number mismatch: "
                            "{1} in pattern vs. {2} in definition",
                            resultOp.getOperationName(), inPattern, numOpArgs));
  }

  // A map to collect all nested DAG child nodes' names, with operand index as
  // the key. This includes both bound and unbound child nodes.
  ChildNodeIndexNameMap childNodeNames;

  // First go through all the child nodes who are nested DAG constructs to
  // create ops for them and remember the symbol names for them, so that we can
  // use the results in the current node. This happens in a recursive manner.
  for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
    if (auto child = tree.getArgAsNestedDag(i))
      childNodeNames[i] = handleResultPattern(child, i, depth + 1);
  }

  // The name of the local variable holding this op.
  std::string valuePackName;
  // The symbol for holding the result of this pattern. Note that the result of
  // this pattern is not necessarily the same as the variable created by this
  // pattern because we can use `__N` suffix to refer only a specific result if
  // the generated op is a multi-result op.
  std::string resultValue;
  if (tree.getSymbol().empty()) {
    // No symbol is explicitly bound to this op in the pattern. Generate a
    // unique name.
    valuePackName = resultValue = getUniqueSymbol(&resultOp);
  } else {
    resultValue = std::string(tree.getSymbol());
    // Strip the index to get the name for the value pack and use it to name the
    // local variable for the op.
    valuePackName = std::string(SymbolInfoMap::getValuePackName(resultValue));
  }

  // Create the local variable for this op.
  os << formatv("{0} {1};\n{{\n", resultOp.getQualCppClassName(),
                valuePackName);

  // Right now ODS don't have general type inference support. Except a few
  // special cases listed below, DRR needs to supply types for all results
  // when building an op.
  bool isSameOperandsAndResultType =
      resultOp.getTrait("::mlir::OpTrait::SameOperandsAndResultType");
  bool useFirstAttr =
      resultOp.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType");

  if (!tail.returnType && (isSameOperandsAndResultType || useFirstAttr)) {
    // We know how to deduce the result type for ops with these traits and we've
    // generated builders taking aggregate parameters. Use those builders to
    // create the ops.

    // First prepare local variables for op arguments used in builder call.
    createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);

    // Then create the op.
    os.scope("", "\n}\n").os << formatv(
        "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);",
        valuePackName, resultOp.getQualCppClassName(), locToUse);
    return resultValue;
  }

  bool usePartialResults = valuePackName != resultValue;

  if (!tail.returnType && (usePartialResults || depth > 0 || resultIndex < 0)) {
    // For these cases (broadcastable ops, op results used both as auxiliary
    // values and replacement values, ops in nested patterns, auxiliary ops), we
    // still need to supply the result types when building the op. But because
    // we don't generate a builder automatically with ODS for them, it's the
    // developer's responsibility to make sure such a builder (with result type
    // deduction ability) exists. We go through the separate-parameter builder
    // here given that it's easier for developers to write compared to
    // aggregate-parameter builders.
    createSeparateLocalVarsForOpArgs(tree, childNodeNames);

    os.scope().os << formatv("{0} = rewriter.create<{1}>({2}", valuePackName,
                             resultOp.getQualCppClassName(), locToUse);
    supplyValuesForOpArgs(tree, childNodeNames, depth);
    os << "\n  );\n}\n";
    return resultValue;
  }

  // If we are provided explicit return types, use them to build the op.
  // However, if depth == 0 and resultIndex >= 0, it means we are replacing
  // the values generated from the source pattern root op. Then we must use the
  // source pattern's value types to determine the value type of the generated
  // op here.
  if (depth == 0 && resultIndex >= 0 && tail.returnType)
    PrintFatalError(loc, "Cannot specify explicit return types in an op whose "
                         "return values replace the source pattern's root op");

  // First prepare local variables for op arguments used in builder call.
  createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);

  // Then prepare the result types. We need to specify the types for all
  // results.
  os.indent() << formatv("::mlir::SmallVector<::mlir::Type, 4> tblgen_types; "
                         "(void)tblgen_types;\n");
  int numResults = resultOp.getNumResults();
  if (tail.returnType) {
    auto numRetTys = tail.returnType.getNumArgs();
    for (int i = 0; i < numRetTys; ++i) {
      auto varName = handleReturnTypeArg(tail.returnType, i, depth + 1);
      os << "tblgen_types.push_back(" << varName << ");\n";
    }
  } else {
    if (numResults != 0) {
      // Copy the result types from the source pattern.
      for (int i = 0; i < numResults; ++i)
        os << formatv("for (auto v: castedOp0.getODSResults({0})) {{\n"
                      "  tblgen_types.push_back(v.getType());\n}\n",
                      resultIndex + i);
    }
  }
  os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, "
                "tblgen_values, tblgen_attrs);\n",
                valuePackName, resultOp.getQualCppClassName(), locToUse);
  os.unindent() << "}\n";
  return resultValue;
}

void PatternEmitter::createSeparateLocalVarsForOpArgs(
    DagNode node, ChildNodeIndexNameMap &childNodeNames) {
  Operator &resultOp = node.getDialectOp(opMap);

  // Now prepare operands used for building this op:
  // * If the operand is non-variadic, we create a `Value` local variable.
  // * If the operand is variadic, we create a `SmallVector<Value>` local
  //   variable.

  int valueIndex = 0; // An index for uniquing local variable names.
  for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
    const auto *operand =
        resultOp.getArg(argIndex).dyn_cast<NamedTypeConstraint *>();
    // We do not need special handling for attributes.
    if (!operand)
      continue;

    raw_indented_ostream::DelimitedScope scope(os);
    std::string varName;
    if (operand->isVariadic()) {
      varName = std::string(formatv("tblgen_values_{0}", valueIndex++));
      os << formatv("::mlir::SmallVector<::mlir::Value, 4> {0};\n", varName);
      std::string range;
      if (node.isNestedDagArg(argIndex)) {
        range = childNodeNames[argIndex];
      } else {
        range = std::string(node.getArgName(argIndex));
      }
      // Resolve the symbol for all range use so that we have a uniform way of
      // capturing the values.
      range = symbolInfoMap.getValueAndRangeUse(range);
      os << formatv("for (auto v: {0}) {{\n  {1}.push_back(v);\n}\n", range,
                    varName);
    } else {
      varName = std::string(formatv("tblgen_value_{0}", valueIndex++));
      os << formatv("::mlir::Value {0} = ", varName);
      if (node.isNestedDagArg(argIndex)) {
        os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]);
      } else {
        DagLeaf leaf = node.getArgAsLeaf(argIndex);
        auto symbol =
            symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
        if (leaf.isNativeCodeCall()) {
          os << std::string(
              tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)));
        } else {
          os << symbol;
        }
      }
      os << ";\n";
    }

    // Update to use the newly created local variable for building the op later.
    childNodeNames[argIndex] = varName;
  }
}

void PatternEmitter::supplyValuesForOpArgs(
    DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
  Operator &resultOp = node.getDialectOp(opMap);
  for (int argIndex = 0, numOpArgs = resultOp.getNumArgs();
       argIndex != numOpArgs; ++argIndex) {
    // Start each argument on its own line.
    os << ",\n    ";

    Argument opArg = resultOp.getArg(argIndex);
    // Handle the case of operand first.
    if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
      if (!operand->name.empty())
        os << "/*" << operand->name << "=*/";
      os << childNodeNames.lookup(argIndex);
      continue;
    }

    // The argument in the op definition.
    auto opArgName = resultOp.getArgName(argIndex);
    if (auto subTree = node.getArgAsNestedDag(argIndex)) {
      if (!subTree.isNativeCodeCall())
        PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
                             "for creating attribute");
      os << formatv("/*{0}=*/{1}", opArgName, childNodeNames.lookup(argIndex));
    } else {
      auto leaf = node.getArgAsLeaf(argIndex);
      // The argument in the result DAG pattern.
      auto patArgName = node.getArgName(argIndex);
      if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
        // TODO: Refactor out into map to avoid recomputing these.
        if (!opArg.is<NamedAttribute *>())
          PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
        if (!patArgName.empty())
          os << "/*" << patArgName << "=*/";
      } else {
        os << "/*" << opArgName << "=*/";
      }
      os << handleOpArgument(leaf, patArgName);
    }
  }
}

void PatternEmitter::createAggregateLocalVarsForOpArgs(
    DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
  Operator &resultOp = node.getDialectOp(opMap);

  auto scope = os.scope();
  os << formatv("::mlir::SmallVector<::mlir::Value, 4> "
                "tblgen_values; (void)tblgen_values;\n");
  os << formatv("::mlir::SmallVector<::mlir::NamedAttribute, 4> "
                "tblgen_attrs; (void)tblgen_attrs;\n");

  const char *addAttrCmd =
      "if (auto tmpAttr = {1}) {\n"
      "  tblgen_attrs.emplace_back(rewriter.getStringAttr(\"{0}\"), "
      "tmpAttr);\n}\n";
  for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
    if (resultOp.getArg(argIndex).is<NamedAttribute *>()) {
      // The argument in the op definition.
      auto opArgName = resultOp.getArgName(argIndex);
      if (auto subTree = node.getArgAsNestedDag(argIndex)) {
        if (!subTree.isNativeCodeCall())
          PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
                               "for creating attribute");
        os << formatv(addAttrCmd, opArgName, childNodeNames.lookup(argIndex));
      } else {
        auto leaf = node.getArgAsLeaf(argIndex);
        // The argument in the result DAG pattern.
        auto patArgName = node.getArgName(argIndex);
        os << formatv(addAttrCmd, opArgName,
                      handleOpArgument(leaf, patArgName));
      }
      continue;
    }

    const auto *operand =
        resultOp.getArg(argIndex).get<NamedTypeConstraint *>();
    std::string varName;
    if (operand->isVariadic()) {
      std::string range;
      if (node.isNestedDagArg(argIndex)) {
        range = childNodeNames.lookup(argIndex);
      } else {
        range = std::string(node.getArgName(argIndex));
      }
      // Resolve the symbol for all range use so that we have a uniform way of
      // capturing the values.
      range = symbolInfoMap.getValueAndRangeUse(range);
      os << formatv("for (auto v: {0}) {{\n  tblgen_values.push_back(v);\n}\n",
                    range);
    } else {
      os << formatv("tblgen_values.push_back(");
      if (node.isNestedDagArg(argIndex)) {
        os << symbolInfoMap.getValueAndRangeUse(
            childNodeNames.lookup(argIndex));
      } else {
        DagLeaf leaf = node.getArgAsLeaf(argIndex);
        if (leaf.isConstantAttr())
          // TODO: Use better location
          PrintFatalError(
              loc,
              "attribute found where value was expected, if attempting to use "
              "constant value, construct a constant op with given attribute "
              "instead");

        auto symbol =
            symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
        if (leaf.isNativeCodeCall()) {
          os << std::string(
              tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)));
        } else {
          os << symbol;
        }
      }
      os << ");\n";
    }
  }
}

StaticMatcherHelper::StaticMatcherHelper(raw_ostream &os,
                                         const RecordKeeper &recordKeeper,
                                         RecordOperatorMap &mapper)
    : opMap(mapper), staticVerifierEmitter(os, recordKeeper) {}

void StaticMatcherHelper::populateStaticMatchers(raw_ostream &os) {
  // PatternEmitter will use the static matcher if there's one generated. To
  // ensure that all the dependent static matchers are generated before emitting
  // the matching logic of the DagNode, we use topological order to achieve it.
  for (auto &dagInfo : topologicalOrder) {
    DagNode node = dagInfo.first;
    if (!useStaticMatcher(node))
      continue;

    std::string funcName =
        formatv("static_dag_matcher_{0}", staticMatcherCounter++);
    assert(matcherNames.find(node) == matcherNames.end());
    PatternEmitter(dagInfo.second, &opMap, os, *this)
        .emitStaticMatcher(node, funcName);
    matcherNames[node] = funcName;
  }
}

void StaticMatcherHelper::populateStaticConstraintFunctions(raw_ostream &os) {
  staticVerifierEmitter.emitPatternConstraints(constraints.getArrayRef());
}

void StaticMatcherHelper::addPattern(Record *record) {
  Pattern pat(record, &opMap);

  // While generating the function body of the DAG matcher, it may depends on
  // other DAG matchers. To ensure the dependent matchers are ready, we compute
  // the topological order for all the DAGs and emit the DAG matchers in this
  // order.
  llvm::unique_function<void(DagNode)> dfs = [&](DagNode node) {
    ++refStats[node];

    if (refStats[node] != 1)
      return;

    for (unsigned i = 0, e = node.getNumArgs(); i < e; ++i)
      if (DagNode sibling = node.getArgAsNestedDag(i))
        dfs(sibling);
      else {
        DagLeaf leaf = node.getArgAsLeaf(i);
        if (!leaf.isUnspecified())
          constraints.insert(leaf);
      }

    topologicalOrder.push_back(std::make_pair(node, record));
  };

  dfs(pat.getSourcePattern());
}

StringRef StaticMatcherHelper::getVerifierName(DagLeaf leaf) {
  if (leaf.isAttrMatcher()) {
    Optional<StringRef> constraint =
        staticVerifierEmitter.getAttrConstraintFn(leaf.getAsConstraint());
    assert(constraint.hasValue() && "attribute constraint was not uniqued");
    return *constraint;
  }
  assert(leaf.isOperandMatcher());
  return staticVerifierEmitter.getTypeConstraintFn(leaf.getAsConstraint());
}

static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
  emitSourceFileHeader("Rewriters", os);

  const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");

  // We put the map here because it can be shared among multiple patterns.
  RecordOperatorMap recordOpMap;

  // Exam all the patterns and generate static matcher for the duplicated
  // DagNode.
  StaticMatcherHelper staticMatcher(os, recordKeeper, recordOpMap);
  for (Record *p : patterns)
    staticMatcher.addPattern(p);
  staticMatcher.populateStaticConstraintFunctions(os);
  staticMatcher.populateStaticMatchers(os);

  std::vector<std::string> rewriterNames;
  rewriterNames.reserve(patterns.size());

  std::string baseRewriterName = "GeneratedConvert";
  int rewriterIndex = 0;

  for (Record *p : patterns) {
    std::string name;
    if (p->isAnonymous()) {
      // If no name is provided, ensure unique rewriter names simply by
      // appending unique suffix.
      name = baseRewriterName + llvm::utostr(rewriterIndex++);
    } else {
      name = std::string(p->getName());
    }
    LLVM_DEBUG(llvm::dbgs()
               << "=== start generating pattern '" << name << "' ===\n");
    PatternEmitter(p, &recordOpMap, os, staticMatcher).emit(name);
    LLVM_DEBUG(llvm::dbgs()
               << "=== done generating pattern '" << name << "' ===\n");
    rewriterNames.push_back(std::move(name));
  }

  // Emit function to add the generated matchers to the pattern list.
  os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated("
        "::mlir::RewritePatternSet &patterns) {\n";
  for (const auto &name : rewriterNames) {
    os << "  patterns.add<" << name << ">(patterns.getContext());\n";
  }
  os << "}\n";
}

static mlir::GenRegistration
    genRewriters("gen-rewriters", "Generate pattern rewriters",
                 [](const RecordKeeper &records, raw_ostream &os) {
                   emitRewriters(records, os);
                   return false;
                 });
