//===- PDLInterp.cpp - PDL Interpreter Dialect ------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/FunctionImplementation.h"

using namespace mlir;
using namespace mlir::pdl_interp;

#include "mlir/Dialect/PDLInterp/IR/PDLInterpOpsDialect.cpp.inc"

//===----------------------------------------------------------------------===//
// PDLInterp Dialect
//===----------------------------------------------------------------------===//

void PDLInterpDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc"
      >();
}

template <typename OpT>
static LogicalResult verifySwitchOp(OpT op) {
  // Verify that the number of case destinations matches the number of case
  // values.
  size_t numDests = op.getCases().size();
  size_t numValues = op.getCaseValues().size();
  if (numDests != numValues) {
    return op.emitOpError(
               "expected number of cases to match the number of case "
               "values, got ")
           << numDests << " but expected " << numValues;
  }
  return success();
}

//===----------------------------------------------------------------------===//
// pdl_interp::CreateOperationOp
//===----------------------------------------------------------------------===//

LogicalResult CreateOperationOp::verify() {
  if (!getInferredResultTypes())
    return success();
  if (!getInputResultTypes().empty()) {
    return emitOpError("with inferred results cannot also have "
                       "explicit result types");
  }
  OperationName opName(getName(), getContext());
  if (!opName.hasInterface<InferTypeOpInterface>()) {
    return emitOpError()
           << "has inferred results, but the created operation '" << opName
           << "' does not support result type inference (or is not "
              "registered)";
  }
  return success();
}

static ParseResult parseCreateOperationOpAttributes(
    OpAsmParser &p,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &attrOperands,
    ArrayAttr &attrNamesAttr) {
  Builder &builder = p.getBuilder();
  SmallVector<Attribute, 4> attrNames;
  if (succeeded(p.parseOptionalLBrace())) {
    do {
      StringAttr nameAttr;
      OpAsmParser::UnresolvedOperand operand;
      if (p.parseAttribute(nameAttr) || p.parseEqual() ||
          p.parseOperand(operand))
        return failure();
      attrNames.push_back(nameAttr);
      attrOperands.push_back(operand);
    } while (succeeded(p.parseOptionalComma()));
    if (p.parseRBrace())
      return failure();
  }
  attrNamesAttr = builder.getArrayAttr(attrNames);
  return success();
}

static void printCreateOperationOpAttributes(OpAsmPrinter &p,
                                             CreateOperationOp op,
                                             OperandRange attrArgs,
                                             ArrayAttr attrNames) {
  if (attrNames.empty())
    return;
  p << " {";
  interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
                  [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
  p << '}';
}

static ParseResult parseCreateOperationOpResults(
    OpAsmParser &p,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &resultOperands,
    SmallVectorImpl<Type> &resultTypes, UnitAttr &inferredResultTypes) {
  if (failed(p.parseOptionalArrow()))
    return success();

  // Handle the case of inferred results.
  if (succeeded(p.parseOptionalLess())) {
    if (p.parseKeyword("inferred") || p.parseGreater())
      return failure();
    inferredResultTypes = p.getBuilder().getUnitAttr();
    return success();
  }

  // Otherwise, parse the explicit results.
  return failure(p.parseLParen() || p.parseOperandList(resultOperands) ||
                 p.parseColonTypeList(resultTypes) || p.parseRParen());
}

static void printCreateOperationOpResults(OpAsmPrinter &p, CreateOperationOp op,
                                          OperandRange resultOperands,
                                          TypeRange resultTypes,
                                          UnitAttr inferredResultTypes) {
  // Handle the case of inferred results.
  if (inferredResultTypes) {
    p << " -> <inferred>";
    return;
  }

  // Otherwise, handle the explicit results.
  if (!resultTypes.empty())
    p << " -> (" << resultOperands << " : " << resultTypes << ")";
}

//===----------------------------------------------------------------------===//
// pdl_interp::ForEachOp
//===----------------------------------------------------------------------===//

void ForEachOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
                      Value range, Block *successor, bool initLoop) {
  build(builder, state, range, successor);
  if (initLoop) {
    // Create the block and the loop variable.
    // FIXME: Allow passing in a proper location for the loop variable.
    auto rangeType = range.getType().cast<pdl::RangeType>();
    state.regions.front()->emplaceBlock();
    state.regions.front()->addArgument(rangeType.getElementType(),
                                       state.location);
  }
}

ParseResult ForEachOp::parse(OpAsmParser &parser, OperationState &result) {
  // Parse the loop variable followed by type.
  OpAsmParser::Argument loopVariable;
  OpAsmParser::UnresolvedOperand operandInfo;
  if (parser.parseArgument(loopVariable, /*allowType=*/true) ||
      parser.parseKeyword("in", " after loop variable") ||
      // Parse the operand (value range).
      parser.parseOperand(operandInfo))
    return failure();

  // Resolve the operand.
  Type rangeType = pdl::RangeType::get(loopVariable.type);
  if (parser.resolveOperand(operandInfo, rangeType, result.operands))
    return failure();

  // Parse the body region.
  Region *body = result.addRegion();
  Block *successor;
  if (parser.parseRegion(*body, loopVariable) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      // Parse the successor.
      parser.parseArrow() || parser.parseSuccessor(successor))
    return failure();

  result.addSuccessors(successor);
  return success();
}

void ForEachOp::print(OpAsmPrinter &p) {
  BlockArgument arg = getLoopVariable();
  p << ' ' << arg << " : " << arg.getType() << " in " << getValues() << ' ';
  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict((*this)->getAttrs());
  p << " -> ";
  p.printSuccessor(getSuccessor());
}

LogicalResult ForEachOp::verify() {
  // Verify that the operation has exactly one argument.
  if (getRegion().getNumArguments() != 1)
    return emitOpError("requires exactly one argument");

  // Verify that the loop variable and the operand (value range)
  // have compatible types.
  BlockArgument arg = getLoopVariable();
  Type rangeType = pdl::RangeType::get(arg.getType());
  if (rangeType != getValues().getType())
    return emitOpError("operand must be a range of loop variable type");

  return success();
}

//===----------------------------------------------------------------------===//
// pdl_interp::FuncOp
//===----------------------------------------------------------------------===//

void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
                   FunctionType type, ArrayRef<NamedAttribute> attrs) {
  buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
}

ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
  auto buildFuncType =
      [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
         function_interface_impl::VariadicFlag,
         std::string &) { return builder.getFunctionType(argTypes, results); };

  return function_interface_impl::parseFunctionOp(
      parser, result, /*allowVariadic=*/false, buildFuncType);
}

void FuncOp::print(OpAsmPrinter &p) {
  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
}

//===----------------------------------------------------------------------===//
// pdl_interp::GetValueTypeOp
//===----------------------------------------------------------------------===//

/// Given the result type of a `GetValueTypeOp`, return the expected input type.
static Type getGetValueTypeOpValueType(Type type) {
  Type valueTy = pdl::ValueType::get(type.getContext());
  return type.isa<pdl::RangeType>() ? pdl::RangeType::get(valueTy) : valueTy;
}

//===----------------------------------------------------------------------===//
// pdl_interp::SwitchAttributeOp
//===----------------------------------------------------------------------===//

LogicalResult SwitchAttributeOp::verify() { return verifySwitchOp(*this); }

//===----------------------------------------------------------------------===//
// pdl_interp::SwitchOperandCountOp
//===----------------------------------------------------------------------===//

LogicalResult SwitchOperandCountOp::verify() { return verifySwitchOp(*this); }

//===----------------------------------------------------------------------===//
// pdl_interp::SwitchOperationNameOp
//===----------------------------------------------------------------------===//

LogicalResult SwitchOperationNameOp::verify() { return verifySwitchOp(*this); }

//===----------------------------------------------------------------------===//
// pdl_interp::SwitchResultCountOp
//===----------------------------------------------------------------------===//

LogicalResult SwitchResultCountOp::verify() { return verifySwitchOp(*this); }

//===----------------------------------------------------------------------===//
// pdl_interp::SwitchTypeOp
//===----------------------------------------------------------------------===//

LogicalResult SwitchTypeOp::verify() { return verifySwitchOp(*this); }

//===----------------------------------------------------------------------===//
// pdl_interp::SwitchTypesOp
//===----------------------------------------------------------------------===//

LogicalResult SwitchTypesOp::verify() { return verifySwitchOp(*this); }

//===----------------------------------------------------------------------===//
// TableGen Auto-Generated Op and Interface Definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc"
