//===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the operations in the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"

#include "mlir/Dialect/SPIRV/IR/ParserUtils.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/bit.h"

using namespace mlir;

// TODO: generate these strings using ODS.
static constexpr const char kMemoryAccessAttrName[] = "memory_access";
static constexpr const char kSourceMemoryAccessAttrName[] =
    "source_memory_access";
static constexpr const char kAlignmentAttrName[] = "alignment";
static constexpr const char kSourceAlignmentAttrName[] = "source_alignment";
static constexpr const char kBranchWeightAttrName[] = "branch_weights";
static constexpr const char kCallee[] = "callee";
static constexpr const char kClusterSize[] = "cluster_size";
static constexpr const char kControl[] = "control";
static constexpr const char kDefaultValueAttrName[] = "default_value";
static constexpr const char kExecutionScopeAttrName[] = "execution_scope";
static constexpr const char kEqualSemanticsAttrName[] = "equal_semantics";
static constexpr const char kFnNameAttrName[] = "fn";
static constexpr const char kGroupOperationAttrName[] = "group_operation";
static constexpr const char kIndicesAttrName[] = "indices";
static constexpr const char kInitializerAttrName[] = "initializer";
static constexpr const char kInterfaceAttrName[] = "interface";
static constexpr const char kMemoryScopeAttrName[] = "memory_scope";
static constexpr const char kSemanticsAttrName[] = "semantics";
static constexpr const char kSpecIdAttrName[] = "spec_id";
static constexpr const char kTypeAttrName[] = "type";
static constexpr const char kUnequalSemanticsAttrName[] = "unequal_semantics";
static constexpr const char kValueAttrName[] = "value";
static constexpr const char kValuesAttrName[] = "values";
static constexpr const char kCompositeSpecConstituentsName[] = "constituents";

//===----------------------------------------------------------------------===//
// Common utility functions
//===----------------------------------------------------------------------===//

/// Returns true if the given op is a function-like op or nested in a
/// function-like op without a module-like op in the middle.
static bool isNestedInFunctionLikeOp(Operation *op) {
  if (!op)
    return false;
  if (op->hasTrait<OpTrait::SymbolTable>())
    return false;
  if (op->hasTrait<OpTrait::FunctionLike>())
    return true;
  return isNestedInFunctionLikeOp(op->getParentOp());
}

/// Returns true if the given op is an module-like op that maintains a symbol
/// table.
static bool isDirectInModuleLikeOp(Operation *op) {
  return op && op->hasTrait<OpTrait::SymbolTable>();
}

static LogicalResult extractValueFromConstOp(Operation *op, int32_t &value) {
  auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
  if (!constOp) {
    return failure();
  }
  auto valueAttr = constOp.value();
  auto integerValueAttr = valueAttr.dyn_cast<IntegerAttr>();
  if (!integerValueAttr) {
    return failure();
  }

  if (integerValueAttr.getType().isSignlessInteger())
    value = integerValueAttr.getInt();
  else
    value = integerValueAttr.getSInt();

  return success();
}

template <typename Ty>
static ArrayAttr
getStrArrayAttrForEnumList(Builder &builder, ArrayRef<Ty> enumValues,
                           function_ref<StringRef(Ty)> stringifyFn) {
  if (enumValues.empty()) {
    return nullptr;
  }
  SmallVector<StringRef, 1> enumValStrs;
  enumValStrs.reserve(enumValues.size());
  for (auto val : enumValues) {
    enumValStrs.emplace_back(stringifyFn(val));
  }
  return builder.getStrArrayAttr(enumValStrs);
}

/// Parses the next string attribute in `parser` as an enumerant of the given
/// `EnumClass`.
template <typename EnumClass>
static ParseResult
parseEnumStrAttr(EnumClass &value, OpAsmParser &parser,
                 StringRef attrName = spirv::attributeName<EnumClass>()) {
  Attribute attrVal;
  NamedAttrList attr;
  auto loc = parser.getCurrentLocation();
  if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(),
                            attrName, attr)) {
    return failure();
  }
  if (!attrVal.isa<StringAttr>()) {
    return parser.emitError(loc, "expected ")
           << attrName << " attribute specified as string";
  }
  auto attrOptional =
      spirv::symbolizeEnum<EnumClass>(attrVal.cast<StringAttr>().getValue());
  if (!attrOptional) {
    return parser.emitError(loc, "invalid ")
           << attrName << " attribute specification: " << attrVal;
  }
  value = attrOptional.getValue();
  return success();
}

/// Parses the next string attribute in `parser` as an enumerant of the given
/// `EnumClass` and inserts the enumerant into `state` as an 32-bit integer
/// attribute with the enum class's name as attribute name.
template <typename EnumClass>
static ParseResult
parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, OperationState &state,
                 StringRef attrName = spirv::attributeName<EnumClass>()) {
  if (parseEnumStrAttr(value, parser)) {
    return failure();
  }
  state.addAttribute(attrName, parser.getBuilder().getI32IntegerAttr(
                                   llvm::bit_cast<int32_t>(value)));
  return success();
}

/// Parses the next keyword in `parser` as an enumerant of the given `EnumClass`
/// and inserts the enumerant into `state` as an 32-bit integer attribute with
/// the enum class's name as attribute name.
template <typename EnumClass>
static ParseResult
parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser,
                     OperationState &state,
                     StringRef attrName = spirv::attributeName<EnumClass>()) {
  if (parseEnumKeywordAttr(value, parser)) {
    return failure();
  }
  state.addAttribute(attrName, parser.getBuilder().getI32IntegerAttr(
                                   llvm::bit_cast<int32_t>(value)));
  return success();
}

/// Parses Function, Selection and Loop control attributes. If no control is
/// specified, "None" is used as a default.
template <typename EnumClass>
static ParseResult
parseControlAttribute(OpAsmParser &parser, OperationState &state,
                      StringRef attrName = spirv::attributeName<EnumClass>()) {
  if (succeeded(parser.parseOptionalKeyword(kControl))) {
    EnumClass control;
    if (parser.parseLParen() || parseEnumKeywordAttr(control, parser, state) ||
        parser.parseRParen())
      return failure();
    return success();
  }
  // Set control to "None" otherwise.
  Builder builder = parser.getBuilder();
  state.addAttribute(attrName, builder.getI32IntegerAttr(0));
  return success();
}

/// Parses optional memory access attributes attached to a memory access
/// operand/pointer. Specifically, parses the following syntax:
///     (`[` memory-access `]`)?
/// where:
///     memory-access ::= `"None"` | `"Volatile"` | `"Aligned", `
///         integer-literal | `"NonTemporal"`
static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
                                               OperationState &state) {
  // Parse an optional list of attributes staring with '['
  if (parser.parseOptionalLSquare()) {
    // Nothing to do
    return success();
  }

  spirv::MemoryAccess memoryAccessAttr;
  if (parseEnumStrAttr(memoryAccessAttr, parser, state,
                       kMemoryAccessAttrName)) {
    return failure();
  }

  if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) {
    // Parse integer attribute for alignment.
    Attribute alignmentAttr;
    Type i32Type = parser.getBuilder().getIntegerType(32);
    if (parser.parseComma() ||
        parser.parseAttribute(alignmentAttr, i32Type, kAlignmentAttrName,
                              state.attributes)) {
      return failure();
    }
  }
  return parser.parseRSquare();
}

// TODO Make sure to merge this and the previous function into one template
// parameterized by memory access attribute name and alignment. Doing so now
// results in VS2017 in producing an internal error (at the call site) that's
// not detailed enough to understand what is happening.
static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
                                                     OperationState &state) {
  // Parse an optional list of attributes staring with '['
  if (parser.parseOptionalLSquare()) {
    // Nothing to do
    return success();
  }

  spirv::MemoryAccess memoryAccessAttr;
  if (parseEnumStrAttr(memoryAccessAttr, parser, state,
                       kSourceMemoryAccessAttrName)) {
    return failure();
  }

  if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) {
    // Parse integer attribute for alignment.
    Attribute alignmentAttr;
    Type i32Type = parser.getBuilder().getIntegerType(32);
    if (parser.parseComma() ||
        parser.parseAttribute(alignmentAttr, i32Type, kSourceAlignmentAttrName,
                              state.attributes)) {
      return failure();
    }
  }
  return parser.parseRSquare();
}

template <typename MemoryOpTy>
static void printMemoryAccessAttribute(
    MemoryOpTy memoryOp, OpAsmPrinter &printer,
    SmallVectorImpl<StringRef> &elidedAttrs,
    Optional<spirv::MemoryAccess> memoryAccessAtrrValue = None,
    Optional<uint32_t> alignmentAttrValue = None) {
  // Print optional memory access attribute.
  if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
                                              : memoryOp.memory_access())) {
    elidedAttrs.push_back(kMemoryAccessAttrName);

    printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";

    if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
      // Print integer alignment attribute.
      if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
                                               : memoryOp.alignment())) {
        elidedAttrs.push_back(kAlignmentAttrName);
        printer << ", " << alignment;
      }
    }
    printer << "]";
  }
  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
}

// TODO Make sure to merge this and the previous function into one template
// parameterized by memory access attribute name and alignment. Doing so now
// results in VS2017 in producing an internal error (at the call site) that's
// not detailed enough to understand what is happening.
template <typename MemoryOpTy>
static void printSourceMemoryAccessAttribute(
    MemoryOpTy memoryOp, OpAsmPrinter &printer,
    SmallVectorImpl<StringRef> &elidedAttrs,
    Optional<spirv::MemoryAccess> memoryAccessAtrrValue = None,
    Optional<uint32_t> alignmentAttrValue = None) {

  printer << ", ";

  // Print optional memory access attribute.
  if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
                                              : memoryOp.memory_access())) {
    elidedAttrs.push_back(kSourceMemoryAccessAttrName);

    printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";

    if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
      // Print integer alignment attribute.
      if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
                                               : memoryOp.alignment())) {
        elidedAttrs.push_back(kSourceAlignmentAttrName);
        printer << ", " << alignment;
      }
    }
    printer << "]";
  }
  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
}

static ParseResult parseImageOperands(OpAsmParser &parser,
                                      spirv::ImageOperandsAttr &attr) {
  // Expect image operands
  if (parser.parseOptionalLSquare())
    return success();

  spirv::ImageOperands imageOperands;
  if (parseEnumStrAttr(imageOperands, parser))
    return failure();

  attr = spirv::ImageOperandsAttr::get(parser.getContext(), imageOperands);

  return parser.parseRSquare();
}

static void printImageOperands(OpAsmPrinter &printer, Operation *imageOp,
                               spirv::ImageOperandsAttr attr) {
  if (attr) {
    auto strImageOperands = stringifyImageOperands(attr.getValue());
    printer << "[\"" << strImageOperands << "\"]";
  }
}

template <typename Op>
static LogicalResult verifyImageOperands(Op imageOp,
                                         spirv::ImageOperandsAttr attr,
                                         Operation::operand_range operands) {
  if (!attr) {
    if (operands.empty())
      return success();

    return imageOp.emitError("the Image Operands should encode what operands "
                             "follow, as per Image Operands");
  }

  // TODO: Add the validation rules for the following Image Operands.
  spirv::ImageOperands noSupportOperands =
      spirv::ImageOperands::Bias | spirv::ImageOperands::Lod |
      spirv::ImageOperands::Grad | spirv::ImageOperands::ConstOffset |
      spirv::ImageOperands::Offset | spirv::ImageOperands::ConstOffsets |
      spirv::ImageOperands::Sample | spirv::ImageOperands::MinLod |
      spirv::ImageOperands::MakeTexelAvailable |
      spirv::ImageOperands::MakeTexelVisible |
      spirv::ImageOperands::SignExtend | spirv::ImageOperands::ZeroExtend;

  if (spirv::bitEnumContains(attr.getValue(), noSupportOperands))
    llvm_unreachable("unimplemented operands of Image Operands");

  return success();
}

static LogicalResult verifyCastOp(Operation *op,
                                  bool requireSameBitWidth = true,
                                  bool skipBitWidthCheck = false) {
  // Some CastOps have no limit on bit widths for result and operand type.
  if (skipBitWidthCheck)
    return success();

  Type operandType = op->getOperand(0).getType();
  Type resultType = op->getResult(0).getType();

  // ODS checks that result type and operand type have the same shape.
  if (auto vectorType = operandType.dyn_cast<VectorType>()) {
    operandType = vectorType.getElementType();
    resultType = resultType.cast<VectorType>().getElementType();
  }

  if (auto coopMatrixType =
          operandType.dyn_cast<spirv::CooperativeMatrixNVType>()) {
    operandType = coopMatrixType.getElementType();
    resultType =
        resultType.cast<spirv::CooperativeMatrixNVType>().getElementType();
  }

  auto operandTypeBitWidth = operandType.getIntOrFloatBitWidth();
  auto resultTypeBitWidth = resultType.getIntOrFloatBitWidth();
  auto isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth;

  if (requireSameBitWidth) {
    if (!isSameBitWidth) {
      return op->emitOpError(
                 "expected the same bit widths for operand type and result "
                 "type, but provided ")
             << operandType << " and " << resultType;
    }
    return success();
  }

  if (isSameBitWidth) {
    return op->emitOpError(
               "expected the different bit widths for operand type and result "
               "type, but provided ")
           << operandType << " and " << resultType;
  }
  return success();
}

template <typename MemoryOpTy>
static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
  // ODS checks for attributes values. Just need to verify that if the
  // memory-access attribute is Aligned, then the alignment attribute must be
  // present.
  auto *op = memoryOp.getOperation();
  auto memAccessAttr = op->getAttr(kMemoryAccessAttrName);
  if (!memAccessAttr) {
    // Alignment attribute shouldn't be present if memory access attribute is
    // not present.
    if (op->getAttr(kAlignmentAttrName)) {
      return memoryOp.emitOpError(
          "invalid alignment specification without aligned memory access "
          "specification");
    }
    return success();
  }

  auto memAccessVal = memAccessAttr.template cast<IntegerAttr>();
  auto memAccess = spirv::symbolizeMemoryAccess(memAccessVal.getInt());

  if (!memAccess) {
    return memoryOp.emitOpError("invalid memory access specifier: ")
           << memAccessVal;
  }

  if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
    if (!op->getAttr(kAlignmentAttrName)) {
      return memoryOp.emitOpError("missing alignment value");
    }
  } else {
    if (op->getAttr(kAlignmentAttrName)) {
      return memoryOp.emitOpError(
          "invalid alignment specification with non-aligned memory access "
          "specification");
    }
  }
  return success();
}

// TODO Make sure to merge this and the previous function into one template
// parameterized by memory access attribute name and alignment. Doing so now
// results in VS2017 in producing an internal error (at the call site) that's
// not detailed enough to understand what is happening.
template <typename MemoryOpTy>
static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
  // ODS checks for attributes values. Just need to verify that if the
  // memory-access attribute is Aligned, then the alignment attribute must be
  // present.
  auto *op = memoryOp.getOperation();
  auto memAccessAttr = op->getAttr(kSourceMemoryAccessAttrName);
  if (!memAccessAttr) {
    // Alignment attribute shouldn't be present if memory access attribute is
    // not present.
    if (op->getAttr(kSourceAlignmentAttrName)) {
      return memoryOp.emitOpError(
          "invalid alignment specification without aligned memory access "
          "specification");
    }
    return success();
  }

  auto memAccessVal = memAccessAttr.template cast<IntegerAttr>();
  auto memAccess = spirv::symbolizeMemoryAccess(memAccessVal.getInt());

  if (!memAccess) {
    return memoryOp.emitOpError("invalid memory access specifier: ")
           << memAccessVal;
  }

  if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
    if (!op->getAttr(kSourceAlignmentAttrName)) {
      return memoryOp.emitOpError("missing alignment value");
    }
  } else {
    if (op->getAttr(kSourceAlignmentAttrName)) {
      return memoryOp.emitOpError(
          "invalid alignment specification with non-aligned memory access "
          "specification");
    }
  }
  return success();
}

static LogicalResult
verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics) {
  // According to the SPIR-V specification:
  // "Despite being a mask and allowing multiple bits to be combined, it is
  // invalid for more than one of these four bits to be set: Acquire, Release,
  // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
  // Release semantics is done by setting the AcquireRelease bit, not by setting
  // two bits."
  auto atMostOneInSet = spirv::MemorySemantics::Acquire |
                        spirv::MemorySemantics::Release |
                        spirv::MemorySemantics::AcquireRelease |
                        spirv::MemorySemantics::SequentiallyConsistent;

  auto bitCount = llvm::countPopulation(
      static_cast<uint32_t>(memorySemantics & atMostOneInSet));
  if (bitCount > 1) {
    return op->emitError(
        "expected at most one of these four memory constraints "
        "to be set: `Acquire`, `Release`,"
        "`AcquireRelease` or `SequentiallyConsistent`");
  }
  return success();
}

template <typename LoadStoreOpTy>
static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
                                                   Value val) {
  // ODS already checks ptr is spirv::PointerType. Just check that the pointee
  // type of the pointer and the type of the value are the same
  //
  // TODO: Check that the value type satisfies restrictions of
  // SPIR-V OpLoad/OpStore operations
  if (val.getType() !=
      ptr.getType().cast<spirv::PointerType>().getPointeeType()) {
    return op.emitOpError("mismatch in result type and pointer type");
  }
  return success();
}

template <typename BlockReadWriteOpTy>
static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
                                                        Value ptr, Value val) {
  auto valType = val.getType();
  if (auto valVecTy = valType.dyn_cast<VectorType>())
    valType = valVecTy.getElementType();

  if (valType != ptr.getType().cast<spirv::PointerType>().getPointeeType()) {
    return op.emitOpError("mismatch in result type and pointer type");
  }
  return success();
}

static ParseResult parseVariableDecorations(OpAsmParser &parser,
                                            OperationState &state) {
  auto builtInName = llvm::convertToSnakeFromCamelCase(
      stringifyDecoration(spirv::Decoration::BuiltIn));
  if (succeeded(parser.parseOptionalKeyword("bind"))) {
    Attribute set, binding;
    // Parse optional descriptor binding
    auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
        stringifyDecoration(spirv::Decoration::DescriptorSet));
    auto bindingName = llvm::convertToSnakeFromCamelCase(
        stringifyDecoration(spirv::Decoration::Binding));
    Type i32Type = parser.getBuilder().getIntegerType(32);
    if (parser.parseLParen() ||
        parser.parseAttribute(set, i32Type, descriptorSetName,
                              state.attributes) ||
        parser.parseComma() ||
        parser.parseAttribute(binding, i32Type, bindingName,
                              state.attributes) ||
        parser.parseRParen()) {
      return failure();
    }
  } else if (succeeded(parser.parseOptionalKeyword(builtInName))) {
    StringAttr builtIn;
    if (parser.parseLParen() ||
        parser.parseAttribute(builtIn, builtInName, state.attributes) ||
        parser.parseRParen()) {
      return failure();
    }
  }

  // Parse other attributes
  if (parser.parseOptionalAttrDict(state.attributes))
    return failure();

  return success();
}

static void printVariableDecorations(Operation *op, OpAsmPrinter &printer,
                                     SmallVectorImpl<StringRef> &elidedAttrs) {
  // Print optional descriptor binding
  auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
      stringifyDecoration(spirv::Decoration::DescriptorSet));
  auto bindingName = llvm::convertToSnakeFromCamelCase(
      stringifyDecoration(spirv::Decoration::Binding));
  auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
  auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
  if (descriptorSet && binding) {
    elidedAttrs.push_back(descriptorSetName);
    elidedAttrs.push_back(bindingName);
    printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
            << ")";
  }

  // Print BuiltIn attribute if present
  auto builtInName = llvm::convertToSnakeFromCamelCase(
      stringifyDecoration(spirv::Decoration::BuiltIn));
  if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
    printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
    elidedAttrs.push_back(builtInName);
  }

  printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
}

// Get bit width of types.
static unsigned getBitWidth(Type type) {
  if (type.isa<spirv::PointerType>()) {
    // Just return 64 bits for pointer types for now.
    // TODO: Make sure not caller relies on the actual pointer width value.
    return 64;
  }

  if (type.isIntOrFloat())
    return type.getIntOrFloatBitWidth();

  if (auto vectorType = type.dyn_cast<VectorType>()) {
    assert(vectorType.getElementType().isIntOrFloat());
    return vectorType.getNumElements() *
           vectorType.getElementType().getIntOrFloatBitWidth();
  }
  llvm_unreachable("unhandled bit width computation for type");
}

/// Walks the given type hierarchy with the given indices, potentially down
/// to component granularity, to select an element type. Returns null type and
/// emits errors with the given loc on failure.
static Type
getElementType(Type type, ArrayRef<int32_t> indices,
               function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
  if (indices.empty()) {
    emitErrorFn("expected at least one index for spv.CompositeExtract");
    return nullptr;
  }

  for (auto index : indices) {
    if (auto cType = type.dyn_cast<spirv::CompositeType>()) {
      if (cType.hasCompileTimeKnownNumElements() &&
          (index < 0 ||
           static_cast<uint64_t>(index) >= cType.getNumElements())) {
        emitErrorFn("index ") << index << " out of bounds for " << type;
        return nullptr;
      }
      type = cType.getElementType(index);
    } else {
      emitErrorFn("cannot extract from non-composite type ")
          << type << " with index " << index;
      return nullptr;
    }
  }
  return type;
}

static Type
getElementType(Type type, Attribute indices,
               function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
  auto indicesArrayAttr = indices.dyn_cast<ArrayAttr>();
  if (!indicesArrayAttr) {
    emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
    return nullptr;
  }
  if (!indicesArrayAttr.size()) {
    emitErrorFn("expected at least one index for spv.CompositeExtract");
    return nullptr;
  }

  SmallVector<int32_t, 2> indexVals;
  for (auto indexAttr : indicesArrayAttr) {
    auto indexIntAttr = indexAttr.dyn_cast<IntegerAttr>();
    if (!indexIntAttr) {
      emitErrorFn("expected an 32-bit integer for index, but found '")
          << indexAttr << "'";
      return nullptr;
    }
    indexVals.push_back(indexIntAttr.getInt());
  }
  return getElementType(type, indexVals, emitErrorFn);
}

static Type getElementType(Type type, Attribute indices, Location loc) {
  auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
    return ::mlir::emitError(loc, err);
  };
  return getElementType(type, indices, errorFn);
}

static Type getElementType(Type type, Attribute indices, OpAsmParser &parser,
                           llvm::SMLoc loc) {
  auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
    return parser.emitError(loc, err);
  };
  return getElementType(type, indices, errorFn);
}

/// Returns true if the given `block` only contains one `spv.mlir.merge` op.
static inline bool isMergeBlock(Block &block) {
  return !block.empty() && std::next(block.begin()) == block.end() &&
         isa<spirv::MergeOp>(block.front());
}

//===----------------------------------------------------------------------===//
// Common parsers and printers
//===----------------------------------------------------------------------===//

// Parses an atomic update op. If the update op does not take a value (like
// AtomicIIncrement) `hasValue` must be false.
static ParseResult parseAtomicUpdateOp(OpAsmParser &parser,
                                       OperationState &state, bool hasValue) {
  spirv::Scope scope;
  spirv::MemorySemantics memoryScope;
  SmallVector<OpAsmParser::OperandType, 2> operandInfo;
  OpAsmParser::OperandType ptrInfo, valueInfo;
  Type type;
  llvm::SMLoc loc;
  if (parseEnumStrAttr(scope, parser, state, kMemoryScopeAttrName) ||
      parseEnumStrAttr(memoryScope, parser, state, kSemanticsAttrName) ||
      parser.parseOperandList(operandInfo, (hasValue ? 2 : 1)) ||
      parser.getCurrentLocation(&loc) || parser.parseColonType(type))
    return failure();

  auto ptrType = type.dyn_cast<spirv::PointerType>();
  if (!ptrType)
    return parser.emitError(loc, "expected pointer type");

  SmallVector<Type, 2> operandTypes;
  operandTypes.push_back(ptrType);
  if (hasValue)
    operandTypes.push_back(ptrType.getPointeeType());
  if (parser.resolveOperands(operandInfo, operandTypes, parser.getNameLoc(),
                             state.operands))
    return failure();
  return parser.addTypeToList(ptrType.getPointeeType(), state.types);
}

// Prints an atomic update op.
static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer) {
  printer << " \"";
  auto scopeAttr = op->getAttrOfType<IntegerAttr>(kMemoryScopeAttrName);
  printer << spirv::stringifyScope(
                 static_cast<spirv::Scope>(scopeAttr.getInt()))
          << "\" \"";
  auto memorySemanticsAttr = op->getAttrOfType<IntegerAttr>(kSemanticsAttrName);
  printer << spirv::stringifyMemorySemantics(
                 static_cast<spirv::MemorySemantics>(
                     memorySemanticsAttr.getInt()))
          << "\" " << op->getOperands() << " : " << op->getOperand(0).getType();
}

template <typename T>
static StringRef stringifyTypeName();

template <>
StringRef stringifyTypeName<IntegerType>() {
  return "integer";
}

template <>
StringRef stringifyTypeName<FloatType>() {
  return "float";
}

// Verifies an atomic update op.
template <typename ExpectedElementType>
static LogicalResult verifyAtomicUpdateOp(Operation *op) {
  auto ptrType = op->getOperand(0).getType().cast<spirv::PointerType>();
  auto elementType = ptrType.getPointeeType();
  if (!elementType.isa<ExpectedElementType>())
    return op->emitOpError() << "pointer operand must point to an "
                             << stringifyTypeName<ExpectedElementType>()
                             << " value, found " << elementType;

  if (op->getNumOperands() > 1) {
    auto valueType = op->getOperand(1).getType();
    if (valueType != elementType)
      return op->emitOpError("expected value to have the same type as the "
                             "pointer operand's pointee type ")
             << elementType << ", but found " << valueType;
  }
  auto memorySemantics = static_cast<spirv::MemorySemantics>(
      op->getAttrOfType<IntegerAttr>(kSemanticsAttrName).getInt());
  if (failed(verifyMemorySemantics(op, memorySemantics))) {
    return failure();
  }
  return success();
}

static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
                                                    OperationState &state) {
  spirv::Scope executionScope;
  spirv::GroupOperation groupOperation;
  OpAsmParser::OperandType valueInfo;
  if (parseEnumStrAttr(executionScope, parser, state,
                       kExecutionScopeAttrName) ||
      parseEnumStrAttr(groupOperation, parser, state,
                       kGroupOperationAttrName) ||
      parser.parseOperand(valueInfo))
    return failure();

  Optional<OpAsmParser::OperandType> clusterSizeInfo;
  if (succeeded(parser.parseOptionalKeyword(kClusterSize))) {
    clusterSizeInfo = OpAsmParser::OperandType();
    if (parser.parseLParen() || parser.parseOperand(*clusterSizeInfo) ||
        parser.parseRParen())
      return failure();
  }

  Type resultType;
  if (parser.parseColonType(resultType))
    return failure();

  if (parser.resolveOperand(valueInfo, resultType, state.operands))
    return failure();

  if (clusterSizeInfo.hasValue()) {
    Type i32Type = parser.getBuilder().getIntegerType(32);
    if (parser.resolveOperand(*clusterSizeInfo, i32Type, state.operands))
      return failure();
  }

  return parser.addTypeToList(resultType, state.types);
}

static void printGroupNonUniformArithmeticOp(Operation *groupOp,
                                             OpAsmPrinter &printer) {
  printer << " \""
          << stringifyScope(static_cast<spirv::Scope>(
                 groupOp->getAttrOfType<IntegerAttr>(kExecutionScopeAttrName)
                     .getInt()))
          << "\" \""
          << stringifyGroupOperation(static_cast<spirv::GroupOperation>(
                 groupOp->getAttrOfType<IntegerAttr>(kGroupOperationAttrName)
                     .getInt()))
          << "\" " << groupOp->getOperand(0);

  if (groupOp->getNumOperands() > 1)
    printer << " " << kClusterSize << '(' << groupOp->getOperand(1) << ')';
  printer << " : " << groupOp->getResult(0).getType();
}

static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
  spirv::Scope scope = static_cast<spirv::Scope>(
      groupOp->getAttrOfType<IntegerAttr>(kExecutionScopeAttrName).getInt());
  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
    return groupOp->emitOpError(
        "execution scope must be 'Workgroup' or 'Subgroup'");

  spirv::GroupOperation operation = static_cast<spirv::GroupOperation>(
      groupOp->getAttrOfType<IntegerAttr>(kGroupOperationAttrName).getInt());
  if (operation == spirv::GroupOperation::ClusteredReduce &&
      groupOp->getNumOperands() == 1)
    return groupOp->emitOpError("cluster size operand must be provided for "
                                "'ClusteredReduce' group operation");
  if (groupOp->getNumOperands() > 1) {
    Operation *sizeOp = groupOp->getOperand(1).getDefiningOp();
    int32_t clusterSize = 0;

    // TODO: support specialization constant here.
    if (failed(extractValueFromConstOp(sizeOp, clusterSize)))
      return groupOp->emitOpError(
          "cluster size operand must come from a constant op");

    if (!llvm::isPowerOf2_32(clusterSize))
      return groupOp->emitOpError(
          "cluster size operand must be a power of two");
  }
  return success();
}

static ParseResult parseUnaryOp(OpAsmParser &parser, OperationState &state) {
  OpAsmParser::OperandType operandInfo;
  Type type;
  if (parser.parseOperand(operandInfo) || parser.parseColonType(type) ||
      parser.resolveOperands(operandInfo, type, state.operands)) {
    return failure();
  }
  state.addTypes(type);
  return success();
}

static void printUnaryOp(Operation *unaryOp, OpAsmPrinter &printer) {
  printer << ' ' << unaryOp->getOperand(0) << " : "
          << unaryOp->getOperand(0).getType();
}

/// Result of a logical op must be a scalar or vector of boolean type.
static Type getUnaryOpResultType(Builder &builder, Type operandType) {
  Type resultType = builder.getIntegerType(1);
  if (auto vecType = operandType.dyn_cast<VectorType>()) {
    return VectorType::get(vecType.getNumElements(), resultType);
  }
  return resultType;
}

static ParseResult parseLogicalUnaryOp(OpAsmParser &parser,
                                       OperationState &state) {
  OpAsmParser::OperandType operandInfo;
  Type type;
  if (parser.parseOperand(operandInfo) || parser.parseColonType(type) ||
      parser.resolveOperand(operandInfo, type, state.operands)) {
    return failure();
  }
  state.addTypes(getUnaryOpResultType(parser.getBuilder(), type));
  return success();
}

static ParseResult parseLogicalBinaryOp(OpAsmParser &parser,
                                        OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 2> ops;
  Type type;
  if (parser.parseOperandList(ops, 2) || parser.parseColonType(type) ||
      parser.resolveOperands(ops, type, result.operands)) {
    return failure();
  }
  result.addTypes(getUnaryOpResultType(parser.getBuilder(), type));
  return success();
}

static void printLogicalOp(Operation *logicalOp, OpAsmPrinter &printer) {
  printer << ' ' << logicalOp->getOperands() << " : "
          << logicalOp->getOperand(0).getType();
}

static ParseResult parseShiftOp(OpAsmParser &parser, OperationState &state) {
  SmallVector<OpAsmParser::OperandType, 2> operandInfo;
  Type baseType;
  Type shiftType;
  auto loc = parser.getCurrentLocation();

  if (parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
      parser.parseType(baseType) || parser.parseComma() ||
      parser.parseType(shiftType) ||
      parser.resolveOperands(operandInfo, {baseType, shiftType}, loc,
                             state.operands)) {
    return failure();
  }
  state.addTypes(baseType);
  return success();
}

static void printShiftOp(Operation *op, OpAsmPrinter &printer) {
  Value base = op->getOperand(0);
  Value shift = op->getOperand(1);
  printer << ' ' << base << ", " << shift << " : " << base.getType() << ", "
          << shift.getType();
}

static LogicalResult verifyShiftOp(Operation *op) {
  if (op->getOperand(0).getType() != op->getResult(0).getType()) {
    return op->emitError("expected the same type for the first operand and "
                         "result, but provided ")
           << op->getOperand(0).getType() << " and "
           << op->getResult(0).getType();
  }
  return success();
}

static void buildLogicalBinaryOp(OpBuilder &builder, OperationState &state,
                                 Value lhs, Value rhs) {
  assert(lhs.getType() == rhs.getType());

  Type boolType = builder.getI1Type();
  if (auto vecType = lhs.getType().dyn_cast<VectorType>())
    boolType = VectorType::get(vecType.getShape(), boolType);
  state.addTypes(boolType);

  state.addOperands({lhs, rhs});
}

static void buildLogicalUnaryOp(OpBuilder &builder, OperationState &state,
                                Value value) {
  Type boolType = builder.getI1Type();
  if (auto vecType = value.getType().dyn_cast<VectorType>())
    boolType = VectorType::get(vecType.getShape(), boolType);
  state.addTypes(boolType);

  state.addOperands(value);
}

//===----------------------------------------------------------------------===//
// spv.AccessChainOp
//===----------------------------------------------------------------------===//

static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
  auto ptrType = type.dyn_cast<spirv::PointerType>();
  if (!ptrType) {
    emitError(baseLoc, "'spv.AccessChain' op expected a pointer "
                       "to composite type, but provided ")
        << type;
    return nullptr;
  }

  auto resultType = ptrType.getPointeeType();
  auto resultStorageClass = ptrType.getStorageClass();
  int32_t index = 0;

  for (auto indexSSA : indices) {
    auto cType = resultType.dyn_cast<spirv::CompositeType>();
    if (!cType) {
      emitError(baseLoc,
                "'spv.AccessChain' op cannot extract from non-composite type ")
          << resultType << " with index " << index;
      return nullptr;
    }
    index = 0;
    if (resultType.isa<spirv::StructType>()) {
      Operation *op = indexSSA.getDefiningOp();
      if (!op) {
        emitError(baseLoc, "'spv.AccessChain' op index must be an "
                           "integer spv.Constant to access "
                           "element of spv.struct");
        return nullptr;
      }

      // TODO: this should be relaxed to allow
      // integer literals of other bitwidths.
      if (failed(extractValueFromConstOp(op, index))) {
        emitError(baseLoc,
                  "'spv.AccessChain' index must be an integer spv.Constant to "
                  "access element of spv.struct, but provided ")
            << op->getName();
        return nullptr;
      }
      if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
        emitError(baseLoc, "'spv.AccessChain' op index ")
            << index << " out of bounds for " << resultType;
        return nullptr;
      }
    }
    resultType = cType.getElementType(index);
  }
  return spirv::PointerType::get(resultType, resultStorageClass);
}

void spirv::AccessChainOp::build(OpBuilder &builder, OperationState &state,
                                 Value basePtr, ValueRange indices) {
  auto type = getElementPtrType(basePtr.getType(), indices, state.location);
  assert(type && "Unable to deduce return type based on basePtr and indices");
  build(builder, state, type, basePtr, indices);
}

static ParseResult parseAccessChainOp(OpAsmParser &parser,
                                      OperationState &state) {
  OpAsmParser::OperandType ptrInfo;
  SmallVector<OpAsmParser::OperandType, 4> indicesInfo;
  Type type;
  auto loc = parser.getCurrentLocation();
  SmallVector<Type, 4> indicesTypes;

  if (parser.parseOperand(ptrInfo) ||
      parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(ptrInfo, type, state.operands)) {
    return failure();
  }

  // Check that the provided indices list is not empty before parsing their
  // type list.
  if (indicesInfo.empty()) {
    return emitError(state.location, "'spv.AccessChain' op expected at "
                                     "least one index ");
  }

  if (parser.parseComma() || parser.parseTypeList(indicesTypes))
    return failure();

  // Check that the indices types list is not empty and that it has a one-to-one
  // mapping to the provided indices.
  if (indicesTypes.size() != indicesInfo.size()) {
    return emitError(state.location, "'spv.AccessChain' op indices "
                                     "types' count must be equal to indices "
                                     "info count");
  }

  if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands))
    return failure();

  auto resultType = getElementPtrType(
      type, llvm::makeArrayRef(state.operands).drop_front(), state.location);
  if (!resultType) {
    return failure();
  }

  state.addTypes(resultType);
  return success();
}

template <typename Op>
static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) {
  printer << ' ' << op.base_ptr() << '[' << indices
          << "] : " << op.base_ptr().getType() << ", " << indices.getTypes();
}

static void print(spirv::AccessChainOp op, OpAsmPrinter &printer) {
  printAccessChain(op, op.indices(), printer);
}

template <typename Op>
static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) {
  auto resultType = getElementPtrType(accessChainOp.base_ptr().getType(),
                                      indices, accessChainOp.getLoc());
  if (!resultType)
    return failure();

  auto providedResultType =
      accessChainOp.getType().template dyn_cast<spirv::PointerType>();
  if (!providedResultType)
    return accessChainOp.emitOpError(
               "result type must be a pointer, but provided")
           << providedResultType;

  if (resultType != providedResultType)
    return accessChainOp.emitOpError("invalid result type: expected ")
           << resultType << ", but provided " << providedResultType;

  return success();
}

static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
  return verifyAccessChain(accessChainOp, accessChainOp.indices());
}

//===----------------------------------------------------------------------===//
// spv.mlir.addressof
//===----------------------------------------------------------------------===//

void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
                               spirv::GlobalVariableOp var) {
  build(builder, state, var.type(), SymbolRefAttr::get(var));
}

static LogicalResult verify(spirv::AddressOfOp addressOfOp) {
  auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
      SymbolTable::lookupNearestSymbolFrom(addressOfOp->getParentOp(),
                                           addressOfOp.variableAttr()));
  if (!varOp) {
    return addressOfOp.emitOpError("expected spv.GlobalVariable symbol");
  }
  if (addressOfOp.pointer().getType() != varOp.type()) {
    return addressOfOp.emitOpError(
        "result type mismatch with the referenced global variable's type");
  }
  return success();
}

template <typename T>
static void printAtomicCompareExchangeImpl(T atomOp, OpAsmPrinter &printer) {
  printer << " \"" << stringifyScope(atomOp.memory_scope()) << "\" \""
          << stringifyMemorySemantics(atomOp.equal_semantics()) << "\" \""
          << stringifyMemorySemantics(atomOp.unequal_semantics()) << "\" "
          << atomOp.getOperands() << " : " << atomOp.pointer().getType();
}

static ParseResult parseAtomicCompareExchangeImpl(OpAsmParser &parser,
                                                  OperationState &state) {
  spirv::Scope memoryScope;
  spirv::MemorySemantics equalSemantics, unequalSemantics;
  SmallVector<OpAsmParser::OperandType, 3> operandInfo;
  Type type;
  if (parseEnumStrAttr(memoryScope, parser, state, kMemoryScopeAttrName) ||
      parseEnumStrAttr(equalSemantics, parser, state,
                       kEqualSemanticsAttrName) ||
      parseEnumStrAttr(unequalSemantics, parser, state,
                       kUnequalSemanticsAttrName) ||
      parser.parseOperandList(operandInfo, 3))
    return failure();

  auto loc = parser.getCurrentLocation();
  if (parser.parseColonType(type))
    return failure();

  auto ptrType = type.dyn_cast<spirv::PointerType>();
  if (!ptrType)
    return parser.emitError(loc, "expected pointer type");

  if (parser.resolveOperands(
          operandInfo,
          {ptrType, ptrType.getPointeeType(), ptrType.getPointeeType()},
          parser.getNameLoc(), state.operands))
    return failure();

  return parser.addTypeToList(ptrType.getPointeeType(), state.types);
}

template <typename T>
static LogicalResult verifyAtomicCompareExchangeImpl(T atomOp) {
  // According to the spec:
  // "The type of Value must be the same as Result Type. The type of the value
  // pointed to by Pointer must be the same as Result Type. This type must also
  // match the type of Comparator."
  if (atomOp.getType() != atomOp.value().getType())
    return atomOp.emitOpError("value operand must have the same type as the op "
                              "result, but found ")
           << atomOp.value().getType() << " vs " << atomOp.getType();

  if (atomOp.getType() != atomOp.comparator().getType())
    return atomOp.emitOpError(
               "comparator operand must have the same type as the op "
               "result, but found ")
           << atomOp.comparator().getType() << " vs " << atomOp.getType();

  Type pointeeType = atomOp.pointer()
                         .getType()
                         .template cast<spirv::PointerType>()
                         .getPointeeType();
  if (atomOp.getType() != pointeeType)
    return atomOp.emitOpError(
               "pointer operand's pointee type must have the same "
               "as the op result type, but found ")
           << pointeeType << " vs " << atomOp.getType();

  // TODO: Unequal cannot be set to Release or Acquire and Release.
  // In addition, Unequal cannot be set to a stronger memory-order then Equal.

  return success();
}

//===----------------------------------------------------------------------===//
// spv.AtomicExchange
//===----------------------------------------------------------------------===//

static void print(spirv::AtomicExchangeOp atomOp, OpAsmPrinter &printer) {
  printer << " \"" << stringifyScope(atomOp.memory_scope()) << "\" \""
          << stringifyMemorySemantics(atomOp.semantics()) << "\" "
          << atomOp.getOperands() << " : " << atomOp.pointer().getType();
}

static ParseResult parseAtomicExchangeOp(OpAsmParser &parser,
                                         OperationState &state) {
  spirv::Scope memoryScope;
  spirv::MemorySemantics semantics;
  SmallVector<OpAsmParser::OperandType, 2> operandInfo;
  Type type;
  if (parseEnumStrAttr(memoryScope, parser, state, kMemoryScopeAttrName) ||
      parseEnumStrAttr(semantics, parser, state, kSemanticsAttrName) ||
      parser.parseOperandList(operandInfo, 2))
    return failure();

  auto loc = parser.getCurrentLocation();
  if (parser.parseColonType(type))
    return failure();

  auto ptrType = type.dyn_cast<spirv::PointerType>();
  if (!ptrType)
    return parser.emitError(loc, "expected pointer type");

  if (parser.resolveOperands(operandInfo, {ptrType, ptrType.getPointeeType()},
                             parser.getNameLoc(), state.operands))
    return failure();

  return parser.addTypeToList(ptrType.getPointeeType(), state.types);
}

static LogicalResult verify(spirv::AtomicExchangeOp atomOp) {
  if (atomOp.getType() != atomOp.value().getType())
    return atomOp.emitOpError("value operand must have the same type as the op "
                              "result, but found ")
           << atomOp.value().getType() << " vs " << atomOp.getType();

  Type pointeeType =
      atomOp.pointer().getType().cast<spirv::PointerType>().getPointeeType();
  if (atomOp.getType() != pointeeType)
    return atomOp.emitOpError(
               "pointer operand's pointee type must have the same "
               "as the op result type, but found ")
           << pointeeType << " vs " << atomOp.getType();

  return success();
}

//===----------------------------------------------------------------------===//
// spv.BitcastOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(spirv::BitcastOp bitcastOp) {
  // TODO: The SPIR-V spec validation rules are different for different
  // versions.
  auto operandType = bitcastOp.operand().getType();
  auto resultType = bitcastOp.result().getType();
  if (operandType == resultType) {
    return bitcastOp.emitError(
        "result type must be different from operand type");
  }
  if (operandType.isa<spirv::PointerType>() &&
      !resultType.isa<spirv::PointerType>()) {
    return bitcastOp.emitError(
        "unhandled bit cast conversion from pointer type to non-pointer type");
  }
  if (!operandType.isa<spirv::PointerType>() &&
      resultType.isa<spirv::PointerType>()) {
    return bitcastOp.emitError(
        "unhandled bit cast conversion from non-pointer type to pointer type");
  }
  auto operandBitWidth = getBitWidth(operandType);
  auto resultBitWidth = getBitWidth(resultType);
  if (operandBitWidth != resultBitWidth) {
    return bitcastOp.emitOpError("mismatch in result type bitwidth ")
           << resultBitWidth << " and operand type bitwidth "
           << operandBitWidth;
  }
  return success();
}

//===----------------------------------------------------------------------===//
// spv.BranchOp
//===----------------------------------------------------------------------===//

Optional<MutableOperandRange>
spirv::BranchOp::getMutableSuccessorOperands(unsigned index) {
  assert(index == 0 && "invalid successor index");
  return targetOperandsMutable();
}

//===----------------------------------------------------------------------===//
// spv.BranchConditionalOp
//===----------------------------------------------------------------------===//

Optional<MutableOperandRange>
spirv::BranchConditionalOp::getMutableSuccessorOperands(unsigned index) {
  assert(index < 2 && "invalid successor index");
  return index == kTrueIndex ? trueTargetOperandsMutable()
                             : falseTargetOperandsMutable();
}

static ParseResult parseBranchConditionalOp(OpAsmParser &parser,
                                            OperationState &state) {
  auto &builder = parser.getBuilder();
  OpAsmParser::OperandType condInfo;
  Block *dest;

  // Parse the condition.
  Type boolTy = builder.getI1Type();
  if (parser.parseOperand(condInfo) ||
      parser.resolveOperand(condInfo, boolTy, state.operands))
    return failure();

  // Parse the optional branch weights.
  if (succeeded(parser.parseOptionalLSquare())) {
    IntegerAttr trueWeight, falseWeight;
    NamedAttrList weights;

    auto i32Type = builder.getIntegerType(32);
    if (parser.parseAttribute(trueWeight, i32Type, "weight", weights) ||
        parser.parseComma() ||
        parser.parseAttribute(falseWeight, i32Type, "weight", weights) ||
        parser.parseRSquare())
      return failure();

    state.addAttribute(kBranchWeightAttrName,
                       builder.getArrayAttr({trueWeight, falseWeight}));
  }

  // Parse the true branch.
  SmallVector<Value, 4> trueOperands;
  if (parser.parseComma() ||
      parser.parseSuccessorAndUseList(dest, trueOperands))
    return failure();
  state.addSuccessors(dest);
  state.addOperands(trueOperands);

  // Parse the false branch.
  SmallVector<Value, 4> falseOperands;
  if (parser.parseComma() ||
      parser.parseSuccessorAndUseList(dest, falseOperands))
    return failure();
  state.addSuccessors(dest);
  state.addOperands(falseOperands);
  state.addAttribute(
      spirv::BranchConditionalOp::getOperandSegmentSizeAttr(),
      builder.getI32VectorAttr({1, static_cast<int32_t>(trueOperands.size()),
                                static_cast<int32_t>(falseOperands.size())}));

  return success();
}

static void print(spirv::BranchConditionalOp branchOp, OpAsmPrinter &printer) {
  printer << ' ' << branchOp.condition();

  if (auto weights = branchOp.branch_weights()) {
    printer << " [";
    llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) {
      printer << a.cast<IntegerAttr>().getInt();
    });
    printer << "]";
  }

  printer << ", ";
  printer.printSuccessorAndUseList(branchOp.getTrueBlock(),
                                   branchOp.getTrueBlockArguments());
  printer << ", ";
  printer.printSuccessorAndUseList(branchOp.getFalseBlock(),
                                   branchOp.getFalseBlockArguments());
}

static LogicalResult verify(spirv::BranchConditionalOp branchOp) {
  if (auto weights = branchOp.branch_weights()) {
    if (weights->getValue().size() != 2) {
      return branchOp.emitOpError("must have exactly two branch weights");
    }
    if (llvm::all_of(*weights, [](Attribute attr) {
          return attr.cast<IntegerAttr>().getValue().isNullValue();
        }))
      return branchOp.emitOpError("branch weights cannot both be zero");
  }

  return success();
}

//===----------------------------------------------------------------------===//
// spv.CompositeConstruct
//===----------------------------------------------------------------------===//

static ParseResult parseCompositeConstructOp(OpAsmParser &parser,
                                             OperationState &state) {
  SmallVector<OpAsmParser::OperandType, 4> operands;
  Type type;
  auto loc = parser.getCurrentLocation();

  if (parser.parseOperandList(operands) || parser.parseColonType(type)) {
    return failure();
  }
  auto cType = type.dyn_cast<spirv::CompositeType>();
  if (!cType) {
    return parser.emitError(
               loc, "result type must be a composite type, but provided ")
           << type;
  }

  if (cType.hasCompileTimeKnownNumElements() &&
      operands.size() != cType.getNumElements()) {
    return parser.emitError(loc, "has incorrect number of operands: expected ")
           << cType.getNumElements() << ", but provided " << operands.size();
  }
  // TODO: Add support for constructing a vector type from the vector operands.
  // According to the spec: "for constructing a vector, the operands may
  // also be vectors with the same component type as the Result Type component
  // type".
  SmallVector<Type, 4> elementTypes;
  elementTypes.reserve(operands.size());
  for (auto index : llvm::seq<uint32_t>(0, operands.size())) {
    elementTypes.push_back(cType.getElementType(index));
  }
  state.addTypes(type);
  return parser.resolveOperands(operands, elementTypes, loc, state.operands);
}

static void print(spirv::CompositeConstructOp compositeConstructOp,
                  OpAsmPrinter &printer) {
  printer << " " << compositeConstructOp.constituents() << " : "
          << compositeConstructOp.getResult().getType();
}

static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) {
  auto cType = compositeConstructOp.getType().cast<spirv::CompositeType>();
  SmallVector<Value, 4> constituents(compositeConstructOp.constituents());

  if (cType.isa<spirv::CooperativeMatrixNVType>()) {
    if (constituents.size() != 1)
      return compositeConstructOp.emitError(
                 "has incorrect number of operands: expected ")
             << "1, but provided " << constituents.size();
  } else if (constituents.size() != cType.getNumElements()) {
    return compositeConstructOp.emitError(
               "has incorrect number of operands: expected ")
           << cType.getNumElements() << ", but provided "
           << constituents.size();
  }

  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
    if (constituents[index].getType() != cType.getElementType(index)) {
      return compositeConstructOp.emitError(
                 "operand type mismatch: expected operand type ")
             << cType.getElementType(index) << ", but provided "
             << constituents[index].getType();
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// spv.CompositeExtractOp
//===----------------------------------------------------------------------===//

void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state,
                                      Value composite,
                                      ArrayRef<int32_t> indices) {
  auto indexAttr = builder.getI32ArrayAttr(indices);
  auto elementType =
      getElementType(composite.getType(), indexAttr, state.location);
  if (!elementType) {
    return;
  }
  build(builder, state, elementType, composite, indexAttr);
}

static ParseResult parseCompositeExtractOp(OpAsmParser &parser,
                                           OperationState &state) {
  OpAsmParser::OperandType compositeInfo;
  Attribute indicesAttr;
  Type compositeType;
  llvm::SMLoc attrLocation;

  if (parser.parseOperand(compositeInfo) ||
      parser.getCurrentLocation(&attrLocation) ||
      parser.parseAttribute(indicesAttr, kIndicesAttrName, state.attributes) ||
      parser.parseColonType(compositeType) ||
      parser.resolveOperand(compositeInfo, compositeType, state.operands)) {
    return failure();
  }

  Type resultType =
      getElementType(compositeType, indicesAttr, parser, attrLocation);
  if (!resultType) {
    return failure();
  }
  state.addTypes(resultType);
  return success();
}

static void print(spirv::CompositeExtractOp compositeExtractOp,
                  OpAsmPrinter &printer) {
  printer << ' ' << compositeExtractOp.composite()
          << compositeExtractOp.indices() << " : "
          << compositeExtractOp.composite().getType();
}

static LogicalResult verify(spirv::CompositeExtractOp compExOp) {
  auto indicesArrayAttr = compExOp.indices().dyn_cast<ArrayAttr>();
  auto resultType = getElementType(compExOp.composite().getType(),
                                   indicesArrayAttr, compExOp.getLoc());
  if (!resultType)
    return failure();

  if (resultType != compExOp.getType()) {
    return compExOp.emitOpError("invalid result type: expected ")
           << resultType << " but provided " << compExOp.getType();
  }

  return success();
}

//===----------------------------------------------------------------------===//
// spv.CompositeInsert
//===----------------------------------------------------------------------===//

void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state,
                                     Value object, Value composite,
                                     ArrayRef<int32_t> indices) {
  auto indexAttr = builder.getI32ArrayAttr(indices);
  build(builder, state, composite.getType(), object, composite, indexAttr);
}

static ParseResult parseCompositeInsertOp(OpAsmParser &parser,
                                          OperationState &state) {
  SmallVector<OpAsmParser::OperandType, 2> operands;
  Type objectType, compositeType;
  Attribute indicesAttr;
  auto loc = parser.getCurrentLocation();

  return failure(
      parser.parseOperandList(operands, 2) ||
      parser.parseAttribute(indicesAttr, kIndicesAttrName, state.attributes) ||
      parser.parseColonType(objectType) ||
      parser.parseKeywordType("into", compositeType) ||
      parser.resolveOperands(operands, {objectType, compositeType}, loc,
                             state.operands) ||
      parser.addTypesToList(compositeType, state.types));
}

static LogicalResult verify(spirv::CompositeInsertOp compositeInsertOp) {
  auto indicesArrayAttr = compositeInsertOp.indices().dyn_cast<ArrayAttr>();
  auto objectType =
      getElementType(compositeInsertOp.composite().getType(), indicesArrayAttr,
                     compositeInsertOp.getLoc());
  if (!objectType)
    return failure();

  if (objectType != compositeInsertOp.object().getType()) {
    return compositeInsertOp.emitOpError("object operand type should be ")
           << objectType << ", but found "
           << compositeInsertOp.object().getType();
  }

  if (compositeInsertOp.composite().getType() != compositeInsertOp.getType()) {
    return compositeInsertOp.emitOpError("result type should be the same as "
                                         "the composite type, but found ")
           << compositeInsertOp.composite().getType() << " vs "
           << compositeInsertOp.getType();
  }

  return success();
}

static void print(spirv::CompositeInsertOp compositeInsertOp,
                  OpAsmPrinter &printer) {
  printer << " " << compositeInsertOp.object() << ", "
          << compositeInsertOp.composite() << compositeInsertOp.indices()
          << " : " << compositeInsertOp.object().getType() << " into "
          << compositeInsertOp.composite().getType();
}

//===----------------------------------------------------------------------===//
// spv.Constant
//===----------------------------------------------------------------------===//

static ParseResult parseConstantOp(OpAsmParser &parser, OperationState &state) {
  Attribute value;
  if (parser.parseAttribute(value, kValueAttrName, state.attributes))
    return failure();

  Type type = value.getType();
  if (type.isa<NoneType, TensorType>()) {
    if (parser.parseColonType(type))
      return failure();
  }

  return parser.addTypeToList(type, state.types);
}

static void print(spirv::ConstantOp constOp, OpAsmPrinter &printer) {
  printer << ' ' << constOp.value();
  if (constOp.getType().isa<spirv::ArrayType>())
    printer << " : " << constOp.getType();
}

static LogicalResult verify(spirv::ConstantOp constOp) {
  auto opType = constOp.getType();
  auto value = constOp.value();
  auto valueType = value.getType();

  // ODS already generates checks to make sure the result type is valid. We just
  // need to additionally check that the value's attribute type is consistent
  // with the result type.
  if (value.isa<IntegerAttr, FloatAttr>()) {
    if (valueType != opType)
      return constOp.emitOpError("result type (")
             << opType << ") does not match value type (" << valueType << ")";
    return success();
  }
  if (value.isa<DenseIntOrFPElementsAttr, SparseElementsAttr>()) {
    if (valueType == opType)
      return success();
    auto arrayType = opType.dyn_cast<spirv::ArrayType>();
    auto shapedType = valueType.dyn_cast<ShapedType>();
    if (!arrayType) {
      return constOp.emitOpError(
          "must have spv.array result type for array value");
    }

    int numElements = arrayType.getNumElements();
    auto opElemType = arrayType.getElementType();
    while (auto t = opElemType.dyn_cast<spirv::ArrayType>()) {
      numElements *= t.getNumElements();
      opElemType = t.getElementType();
    }
    if (!opElemType.isIntOrFloat())
      return constOp.emitOpError("only support nested array result type");

    auto valueElemType = shapedType.getElementType();
    if (valueElemType != opElemType) {
      return constOp.emitOpError("result element type (")
             << opElemType << ") does not match value element type ("
             << valueElemType << ")";
    }

    if (numElements != shapedType.getNumElements()) {
      return constOp.emitOpError("result number of elements (")
             << numElements << ") does not match value number of elements ("
             << shapedType.getNumElements() << ")";
    }
    return success();
  }
  if (auto attayAttr = value.dyn_cast<ArrayAttr>()) {
    auto arrayType = opType.dyn_cast<spirv::ArrayType>();
    if (!arrayType)
      return constOp.emitOpError(
          "must have spv.array result type for array value");
    Type elemType = arrayType.getElementType();
    for (Attribute element : attayAttr.getValue()) {
      if (element.getType() != elemType)
        return constOp.emitOpError("has array element whose type (")
               << element.getType()
               << ") does not match the result element type (" << elemType
               << ')';
    }
    return success();
  }
  return constOp.emitOpError("cannot have value of type ") << valueType;
}

bool spirv::ConstantOp::isBuildableWith(Type type) {
  // Must be valid SPIR-V type first.
  if (!type.isa<spirv::SPIRVType>())
    return false;

  if (isa<SPIRVDialect>(type.getDialect())) {
    // TODO: support constant struct
    return type.isa<spirv::ArrayType>();
  }

  return true;
}

spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
                                             OpBuilder &builder) {
  if (auto intType = type.dyn_cast<IntegerType>()) {
    unsigned width = intType.getWidth();
    if (width == 1)
      return builder.create<spirv::ConstantOp>(loc, type,
                                               builder.getBoolAttr(false));
    return builder.create<spirv::ConstantOp>(
        loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
  }
  if (auto floatType = type.dyn_cast<FloatType>()) {
    return builder.create<spirv::ConstantOp>(
        loc, type, builder.getFloatAttr(floatType, 0.0));
  }
  if (auto vectorType = type.dyn_cast<VectorType>()) {
    Type elemType = vectorType.getElementType();
    if (elemType.isa<IntegerType>()) {
      return builder.create<spirv::ConstantOp>(
          loc, type,
          DenseElementsAttr::get(vectorType,
                                 IntegerAttr::get(elemType, 0.0).getValue()));
    }
    if (elemType.isa<FloatType>()) {
      return builder.create<spirv::ConstantOp>(
          loc, type,
          DenseFPElementsAttr::get(vectorType,
                                   FloatAttr::get(elemType, 0.0).getValue()));
    }
  }

  llvm_unreachable("unimplemented types for ConstantOp::getZero()");
}

spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
                                            OpBuilder &builder) {
  if (auto intType = type.dyn_cast<IntegerType>()) {
    unsigned width = intType.getWidth();
    if (width == 1)
      return builder.create<spirv::ConstantOp>(loc, type,
                                               builder.getBoolAttr(true));
    return builder.create<spirv::ConstantOp>(
        loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
  }
  if (auto floatType = type.dyn_cast<FloatType>()) {
    return builder.create<spirv::ConstantOp>(
        loc, type, builder.getFloatAttr(floatType, 1.0));
  }
  if (auto vectorType = type.dyn_cast<VectorType>()) {
    Type elemType = vectorType.getElementType();
    if (elemType.isa<IntegerType>()) {
      return builder.create<spirv::ConstantOp>(
          loc, type,
          DenseElementsAttr::get(vectorType,
                                 IntegerAttr::get(elemType, 1.0).getValue()));
    }
    if (elemType.isa<FloatType>()) {
      return builder.create<spirv::ConstantOp>(
          loc, type,
          DenseFPElementsAttr::get(vectorType,
                                   FloatAttr::get(elemType, 1.0).getValue()));
    }
  }

  llvm_unreachable("unimplemented types for ConstantOp::getOne()");
}

void mlir::spirv::ConstantOp::getAsmResultNames(
    llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
  Type type = getType();

  SmallString<32> specialNameBuffer;
  llvm::raw_svector_ostream specialName(specialNameBuffer);
  specialName << "cst";

  IntegerType intTy = type.dyn_cast<IntegerType>();

  if (IntegerAttr intCst = value().dyn_cast<IntegerAttr>()) {
    if (intTy && intTy.getWidth() == 1) {
      return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
    }

    if (intTy.isSignless()) {
      specialName << intCst.getInt();
    } else {
      specialName << intCst.getSInt();
    }
  }

  if (intTy || type.isa<FloatType>()) {
    specialName << '_' << type;
  }

  if (auto vecType = type.dyn_cast<VectorType>()) {
    specialName << "_vec_";
    specialName << vecType.getDimSize(0);

    Type elementType = vecType.getElementType();

    if (elementType.isa<IntegerType>() || elementType.isa<FloatType>()) {
      specialName << "x" << elementType;
    }
  }

  setNameFn(getResult(), specialName.str());
}

void mlir::spirv::AddressOfOp::getAsmResultNames(
    llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
  SmallString<32> specialNameBuffer;
  llvm::raw_svector_ostream specialName(specialNameBuffer);
  specialName << variable() << "_addr";
  setNameFn(getResult(), specialName.str());
}

//===----------------------------------------------------------------------===//
// spv.EntryPoint
//===----------------------------------------------------------------------===//

void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
                                spirv::ExecutionModel executionModel,
                                spirv::FuncOp function,
                                ArrayRef<Attribute> interfaceVars) {
  build(builder, state,
        spirv::ExecutionModelAttr::get(builder.getContext(), executionModel),
        SymbolRefAttr::get(function), builder.getArrayAttr(interfaceVars));
}

static ParseResult parseEntryPointOp(OpAsmParser &parser,
                                     OperationState &state) {
  spirv::ExecutionModel execModel;
  SmallVector<OpAsmParser::OperandType, 0> identifiers;
  SmallVector<Type, 0> idTypes;
  SmallVector<Attribute, 4> interfaceVars;

  FlatSymbolRefAttr fn;
  if (parseEnumStrAttr(execModel, parser, state) ||
      parser.parseAttribute(fn, Type(), kFnNameAttrName, state.attributes)) {
    return failure();
  }

  if (!parser.parseOptionalComma()) {
    // Parse the interface variables
    if (parser.parseCommaSeparatedList([&]() -> ParseResult {
          // The name of the interface variable attribute isnt important
          FlatSymbolRefAttr var;
          NamedAttrList attrs;
          if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
            return failure();
          interfaceVars.push_back(var);
          return success();
        }))
      return failure();
  }
  state.addAttribute(kInterfaceAttrName,
                     parser.getBuilder().getArrayAttr(interfaceVars));
  return success();
}

static void print(spirv::EntryPointOp entryPointOp, OpAsmPrinter &printer) {
  printer << " \"" << stringifyExecutionModel(entryPointOp.execution_model())
          << "\" ";
  printer.printSymbolName(entryPointOp.fn());
  auto interfaceVars = entryPointOp.interface().getValue();
  if (!interfaceVars.empty()) {
    printer << ", ";
    llvm::interleaveComma(interfaceVars, printer);
  }
}

static LogicalResult verify(spirv::EntryPointOp entryPointOp) {
  // Checks for fn and interface symbol reference are done in spirv::ModuleOp
  // verification.
  return success();
}

//===----------------------------------------------------------------------===//
// spv.ExecutionMode
//===----------------------------------------------------------------------===//

void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
                                   spirv::FuncOp function,
                                   spirv::ExecutionMode executionMode,
                                   ArrayRef<int32_t> params) {
  build(builder, state, SymbolRefAttr::get(function),
        spirv::ExecutionModeAttr::get(builder.getContext(), executionMode),
        builder.getI32ArrayAttr(params));
}

static ParseResult parseExecutionModeOp(OpAsmParser &parser,
                                        OperationState &state) {
  spirv::ExecutionMode execMode;
  Attribute fn;
  if (parser.parseAttribute(fn, kFnNameAttrName, state.attributes) ||
      parseEnumStrAttr(execMode, parser, state)) {
    return failure();
  }

  SmallVector<int32_t, 4> values;
  Type i32Type = parser.getBuilder().getIntegerType(32);
  while (!parser.parseOptionalComma()) {
    NamedAttrList attr;
    Attribute value;
    if (parser.parseAttribute(value, i32Type, "value", attr)) {
      return failure();
    }
    values.push_back(value.cast<IntegerAttr>().getInt());
  }
  state.addAttribute(kValuesAttrName,
                     parser.getBuilder().getI32ArrayAttr(values));
  return success();
}

static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter &printer) {
  printer << " ";
  printer.printSymbolName(execModeOp.fn());
  printer << " \"" << stringifyExecutionMode(execModeOp.execution_mode())
          << "\"";
  auto values = execModeOp.values();
  if (!values.size())
    return;
  printer << ", ";
  llvm::interleaveComma(values, printer, [&](Attribute a) {
    printer << a.cast<IntegerAttr>().getInt();
  });
}

//===----------------------------------------------------------------------===//
// spv.func
//===----------------------------------------------------------------------===//

static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &state) {
  SmallVector<OpAsmParser::OperandType, 4> entryArgs;
  SmallVector<NamedAttrList, 4> argAttrs;
  SmallVector<NamedAttrList, 4> resultAttrs;
  SmallVector<Type, 4> argTypes;
  SmallVector<Type, 4> resultTypes;
  auto &builder = parser.getBuilder();

  // Parse the name as a symbol.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                             state.attributes))
    return failure();

  // Parse the function signature.
  bool isVariadic = false;
  if (function_like_impl::parseFunctionSignature(
          parser, /*allowVariadic=*/false, entryArgs, argTypes, argAttrs,
          isVariadic, resultTypes, resultAttrs))
    return failure();

  auto fnType = builder.getFunctionType(argTypes, resultTypes);
  state.addAttribute(function_like_impl::getTypeAttrName(),
                     TypeAttr::get(fnType));

  // Parse the optional function control keyword.
  spirv::FunctionControl fnControl;
  if (parseEnumStrAttr(fnControl, parser, state))
    return failure();

  // If additional attributes are present, parse them.
  if (parser.parseOptionalAttrDictWithKeyword(state.attributes))
    return failure();

  // Add the attributes to the function arguments.
  assert(argAttrs.size() == argTypes.size());
  assert(resultAttrs.size() == resultTypes.size());
  function_like_impl::addArgAndResultAttrs(builder, state, argAttrs,
                                           resultAttrs);

  // Parse the optional function body.
  auto *body = state.addRegion();
  OptionalParseResult result = parser.parseOptionalRegion(
      *body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes);
  return failure(result.hasValue() && failed(*result));
}

static void print(spirv::FuncOp fnOp, OpAsmPrinter &printer) {
  // Print function name, signature, and control.
  printer << " ";
  printer.printSymbolName(fnOp.sym_name());
  auto fnType = fnOp.getType();
  function_like_impl::printFunctionSignature(printer, fnOp, fnType.getInputs(),
                                             /*isVariadic=*/false,
                                             fnType.getResults());
  printer << " \"" << spirv::stringifyFunctionControl(fnOp.function_control())
          << "\"";
  function_like_impl::printFunctionAttributes(
      printer, fnOp, fnType.getNumInputs(), fnType.getNumResults(),
      {spirv::attributeName<spirv::FunctionControl>()});

  // Print the body if this is not an external function.
  Region &body = fnOp.body();
  if (!body.empty())
    printer.printRegion(body, /*printEntryBlockArgs=*/false,
                        /*printBlockTerminators=*/true);
}

LogicalResult spirv::FuncOp::verifyType() {
  auto type = getTypeAttr().getValue();
  if (!type.isa<FunctionType>())
    return emitOpError("requires '" + getTypeAttrName() +
                       "' attribute of function type");
  if (getType().getNumResults() > 1)
    return emitOpError("cannot have more than one result");
  return success();
}

LogicalResult spirv::FuncOp::verifyBody() {
  FunctionType fnType = getType();

  auto walkResult = walk([fnType](Operation *op) -> WalkResult {
    if (auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
      if (fnType.getNumResults() != 0)
        return retOp.emitOpError("cannot be used in functions returning value");
    } else if (auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
      if (fnType.getNumResults() != 1)
        return retOp.emitOpError(
                   "returns 1 value but enclosing function requires ")
               << fnType.getNumResults() << " results";

      auto retOperandType = retOp.value().getType();
      auto fnResultType = fnType.getResult(0);
      if (retOperandType != fnResultType)
        return retOp.emitOpError(" return value's type (")
               << retOperandType << ") mismatch with function's result type ("
               << fnResultType << ")";
    }
    return WalkResult::advance();
  });

  // TODO: verify other bits like linkage type.

  return failure(walkResult.wasInterrupted());
}

void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
                          StringRef name, FunctionType type,
                          spirv::FunctionControl control,
                          ArrayRef<NamedAttribute> attrs) {
  state.addAttribute(SymbolTable::getSymbolAttrName(),
                     builder.getStringAttr(name));
  state.addAttribute(getTypeAttrName(), TypeAttr::get(type));
  state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
                     builder.getI32IntegerAttr(static_cast<uint32_t>(control)));
  state.attributes.append(attrs.begin(), attrs.end());
  state.addRegion();
}

// CallableOpInterface
Region *spirv::FuncOp::getCallableRegion() {
  return isExternal() ? nullptr : &body();
}

// CallableOpInterface
ArrayRef<Type> spirv::FuncOp::getCallableResults() {
  return getType().getResults();
}

//===----------------------------------------------------------------------===//
// spv.FunctionCall
//===----------------------------------------------------------------------===//

static LogicalResult verify(spirv::FunctionCallOp functionCallOp) {
  auto fnName = functionCallOp.calleeAttr();

  auto funcOp =
      dyn_cast_or_null<spirv::FuncOp>(SymbolTable::lookupNearestSymbolFrom(
          functionCallOp->getParentOp(), fnName));
  if (!funcOp) {
    return functionCallOp.emitOpError("callee function '")
           << fnName.getValue() << "' not found in nearest symbol table";
  }

  auto functionType = funcOp.getType();

  if (functionCallOp.getNumResults() > 1) {
    return functionCallOp.emitOpError(
               "expected callee function to have 0 or 1 result, but provided ")
           << functionCallOp.getNumResults();
  }

  if (functionType.getNumInputs() != functionCallOp.getNumOperands()) {
    return functionCallOp.emitOpError(
               "has incorrect number of operands for callee: expected ")
           << functionType.getNumInputs() << ", but provided "
           << functionCallOp.getNumOperands();
  }

  for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
    if (functionCallOp.getOperand(i).getType() != functionType.getInput(i)) {
      return functionCallOp.emitOpError(
                 "operand type mismatch: expected operand type ")
             << functionType.getInput(i) << ", but provided "
             << functionCallOp.getOperand(i).getType() << " for operand number "
             << i;
    }
  }

  if (functionType.getNumResults() != functionCallOp.getNumResults()) {
    return functionCallOp.emitOpError(
               "has incorrect number of results has for callee: expected ")
           << functionType.getNumResults() << ", but provided "
           << functionCallOp.getNumResults();
  }

  if (functionCallOp.getNumResults() &&
      (functionCallOp.getResult(0).getType() != functionType.getResult(0))) {
    return functionCallOp.emitOpError("result type mismatch: expected ")
           << functionType.getResult(0) << ", but provided "
           << functionCallOp.getResult(0).getType();
  }

  return success();
}

CallInterfaceCallable spirv::FunctionCallOp::getCallableForCallee() {
  return (*this)->getAttrOfType<SymbolRefAttr>(kCallee);
}

Operation::operand_range spirv::FunctionCallOp::getArgOperands() {
  return arguments();
}

//===----------------------------------------------------------------------===//
// spv.GlobalVariable
//===----------------------------------------------------------------------===//

void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
                                    Type type, StringRef name,
                                    unsigned descriptorSet, unsigned binding) {
  build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
  state.addAttribute(
      spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
      builder.getI32IntegerAttr(descriptorSet));
  state.addAttribute(
      spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
      builder.getI32IntegerAttr(binding));
}

void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
                                    Type type, StringRef name,
                                    spirv::BuiltIn builtin) {
  build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
  state.addAttribute(
      spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
      builder.getStringAttr(spirv::stringifyBuiltIn(builtin)));
}

static ParseResult parseGlobalVariableOp(OpAsmParser &parser,
                                         OperationState &state) {
  // Parse variable name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                             state.attributes)) {
    return failure();
  }

  // Parse optional initializer
  if (succeeded(parser.parseOptionalKeyword(kInitializerAttrName))) {
    FlatSymbolRefAttr initSymbol;
    if (parser.parseLParen() ||
        parser.parseAttribute(initSymbol, Type(), kInitializerAttrName,
                              state.attributes) ||
        parser.parseRParen())
      return failure();
  }

  if (parseVariableDecorations(parser, state)) {
    return failure();
  }

  Type type;
  auto loc = parser.getCurrentLocation();
  if (parser.parseColonType(type)) {
    return failure();
  }
  if (!type.isa<spirv::PointerType>()) {
    return parser.emitError(loc, "expected spv.ptr type");
  }
  state.addAttribute(kTypeAttrName, TypeAttr::get(type));

  return success();
}

static void print(spirv::GlobalVariableOp varOp, OpAsmPrinter &printer) {
  auto *op = varOp.getOperation();
  SmallVector<StringRef, 4> elidedAttrs{
      spirv::attributeName<spirv::StorageClass>()};

  // Print variable name.
  printer << ' ';
  printer.printSymbolName(varOp.sym_name());
  elidedAttrs.push_back(SymbolTable::getSymbolAttrName());

  // Print optional initializer
  if (auto initializer = varOp.initializer()) {
    printer << " " << kInitializerAttrName << '(';
    printer.printSymbolName(initializer.getValue());
    printer << ')';
    elidedAttrs.push_back(kInitializerAttrName);
  }

  elidedAttrs.push_back(kTypeAttrName);
  printVariableDecorations(op, printer, elidedAttrs);
  printer << " : " << varOp.type();
}

static LogicalResult verify(spirv::GlobalVariableOp varOp) {
  // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
  // object. It cannot be Generic. It must be the same as the Storage Class
  // operand of the Result Type."
  // Also, Function storage class is reserved by spv.Variable.
  auto storageClass = varOp.storageClass();
  if (storageClass == spirv::StorageClass::Generic ||
      storageClass == spirv::StorageClass::Function) {
    return varOp.emitOpError("storage class cannot be '")
           << stringifyStorageClass(storageClass) << "'";
  }

  if (auto init =
          varOp->getAttrOfType<FlatSymbolRefAttr>(kInitializerAttrName)) {
    Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
        varOp->getParentOp(), init.getAttr());
    // TODO: Currently only variable initialization with specialization
    // constants and other variables is supported. They could be normal
    // constants in the module scope as well.
    if (!initOp ||
        !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp>(initOp)) {
      return varOp.emitOpError("initializer must be result of a "
                               "spv.SpecConstant or spv.GlobalVariable op");
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// spv.GroupBroadcast
//===----------------------------------------------------------------------===//

static LogicalResult verify(spirv::GroupBroadcastOp broadcastOp) {
  spirv::Scope scope = broadcastOp.execution_scope();
  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
    return broadcastOp.emitOpError(
        "execution scope must be 'Workgroup' or 'Subgroup'");

  if (auto localIdTy = broadcastOp.localid().getType().dyn_cast<VectorType>())
    if (!(localIdTy.getNumElements() == 2 || localIdTy.getNumElements() == 3))
      return broadcastOp.emitOpError("localid is a vector and can be with only "
                                     " 2 or 3 components, actual number is ")
             << localIdTy.getNumElements();

  return success();
}

//===----------------------------------------------------------------------===//
// spv.GroupNonUniformBallotOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(spirv::GroupNonUniformBallotOp ballotOp) {
  spirv::Scope scope = ballotOp.execution_scope();
  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
    return ballotOp.emitOpError(
        "execution scope must be 'Workgroup' or 'Subgroup'");

  return success();
}

//===----------------------------------------------------------------------===//
// spv.GroupNonUniformBroadcast
//===----------------------------------------------------------------------===//

static LogicalResult verify(spirv::GroupNonUniformBroadcastOp broadcastOp) {
  spirv::Scope scope = broadcastOp.execution_scope();
  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
    return broadcastOp.emitOpError(
        "execution scope must be 'Workgroup' or 'Subgroup'");

  // SPIR-V spec: "Before version 1.5, Id must come from a
  // constant instruction.
  auto targetEnv = spirv::getDefaultTargetEnv(broadcastOp.getContext());
  if (auto spirvModule = broadcastOp->getParentOfType<spirv::ModuleOp>())
    targetEnv = spirv::lookupTargetEnvOrDefault(spirvModule);

  if (targetEnv.getVersion() < spirv::Version::V_1_5) {
    auto *idOp = broadcastOp.id().getDefiningOp();
    if (!idOp || !isa<spirv::ConstantOp,           // for normal constant
                      spirv::ReferenceOfOp>(idOp)) // for spec constant
      return broadcastOp.emitOpError("id must be the result of a constant op");
  }

  return success();
}

//===----------------------------------------------------------------------===//
// spv.SubgroupBlockReadINTEL
//===----------------------------------------------------------------------===//

static ParseResult parseSubgroupBlockReadINTELOp(OpAsmParser &parser,
                                                 OperationState &state) {
  // Parse the storage class specification
  spirv::StorageClass storageClass;
  OpAsmParser::OperandType ptrInfo;
  Type elementType;
  if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
      parser.parseColon() || parser.parseType(elementType)) {
    return failure();
  }

  auto ptrType = spirv::PointerType::get(elementType, storageClass);
  if (auto valVecTy = elementType.dyn_cast<VectorType>())
    ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);

  if (parser.resolveOperand(ptrInfo, ptrType, state.operands)) {
    return failure();
  }

  state.addTypes(elementType);
  return success();
}

static void print(spirv::SubgroupBlockReadINTELOp blockReadOp,
                  OpAsmPrinter &printer) {
  SmallVector<StringRef, 4> elidedAttrs;
  printer << " " << blockReadOp.ptr();
  printer << " : " << blockReadOp.getType();
}

static LogicalResult verify(spirv::SubgroupBlockReadINTELOp blockReadOp) {
  if (failed(verifyBlockReadWritePtrAndValTypes(blockReadOp, blockReadOp.ptr(),
                                                blockReadOp.value())))
    return failure();

  return success();
}

//===----------------------------------------------------------------------===//
// spv.SubgroupBlockWriteINTEL
//===----------------------------------------------------------------------===//

static ParseResult parseSubgroupBlockWriteINTELOp(OpAsmParser &parser,
                                                  OperationState &state) {
  // Parse the storage class specification
  spirv::StorageClass storageClass;
  SmallVector<OpAsmParser::OperandType, 2> operandInfo;
  auto loc = parser.getCurrentLocation();
  Type elementType;
  if (parseEnumStrAttr(storageClass, parser) ||
      parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
      parser.parseType(elementType)) {
    return failure();
  }

  auto ptrType = spirv::PointerType::get(elementType, storageClass);
  if (auto valVecTy = elementType.dyn_cast<VectorType>())
    ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);

  if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
                             state.operands)) {
    return failure();
  }
  return success();
}

static void print(spirv::SubgroupBlockWriteINTELOp blockWriteOp,
                  OpAsmPrinter &printer) {
  SmallVector<StringRef, 4> elidedAttrs;
  printer << " " << blockWriteOp.ptr() << ", " << blockWriteOp.value();
  printer << " : " << blockWriteOp.value().getType();
}

static LogicalResult verify(spirv::SubgroupBlockWriteINTELOp blockWriteOp) {
  if (failed(verifyBlockReadWritePtrAndValTypes(
          blockWriteOp, blockWriteOp.ptr(), blockWriteOp.value())))
    return failure();

  return success();
}

//===----------------------------------------------------------------------===//
// spv.GroupNonUniformElectOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(spirv::GroupNonUniformElectOp groupOp) {
  spirv::Scope scope = groupOp.execution_scope();
  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
    return groupOp.emitOpError(
        "execution scope must be 'Workgroup' or 'Subgroup'");

  return success();
}

//===----------------------------------------------------------------------===//
// spv.LoadOp
//===----------------------------------------------------------------------===//

void spirv::LoadOp::build(OpBuilder &builder, OperationState &state,
                          Value basePtr, MemoryAccessAttr memoryAccess,
                          IntegerAttr alignment) {
  auto ptrType = basePtr.getType().cast<spirv::PointerType>();
  build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess,
        alignment);
}

static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &state) {
  // Parse the storage class specification
  spirv::StorageClass storageClass;
  OpAsmParser::OperandType ptrInfo;
  Type elementType;
  if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
      parseMemoryAccessAttributes(parser, state) ||
      parser.parseOptionalAttrDict(state.attributes) || parser.parseColon() ||
      parser.parseType(elementType)) {
    return failure();
  }

  auto ptrType = spirv::PointerType::get(elementType, storageClass);
  if (parser.resolveOperand(ptrInfo, ptrType, state.operands)) {
    return failure();
  }

  state.addTypes(elementType);
  return success();
}

static void print(spirv::LoadOp loadOp, OpAsmPrinter &printer) {
  auto *op = loadOp.getOperation();
  SmallVector<StringRef, 4> elidedAttrs;
  StringRef sc = stringifyStorageClass(
      loadOp.ptr().getType().cast<spirv::PointerType>().getStorageClass());
  printer << " \"" << sc << "\" " << loadOp.ptr();

  printMemoryAccessAttribute(loadOp, printer, elidedAttrs);

  printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
  printer << " : " << loadOp.getType();
}

static LogicalResult verify(spirv::LoadOp loadOp) {
  // SPIR-V spec : "Result Type is the type of the loaded object. It must be a
  // type with fixed size; i.e., it cannot be, nor include, any
  // OpTypeRuntimeArray types."
  if (failed(verifyLoadStorePtrAndValTypes(loadOp, loadOp.ptr(),
                                           loadOp.value()))) {
    return failure();
  }
  return verifyMemoryAccessAttribute(loadOp);
}

//===----------------------------------------------------------------------===//
// spv.mlir.loop
//===----------------------------------------------------------------------===//

void spirv::LoopOp::build(OpBuilder &builder, OperationState &state) {
  state.addAttribute("loop_control",
                     builder.getI32IntegerAttr(
                         static_cast<uint32_t>(spirv::LoopControl::None)));
  state.addRegion();
}

static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &state) {
  if (parseControlAttribute<spirv::LoopControl>(parser, state))
    return failure();
  return parser.parseRegion(*state.addRegion(), /*arguments=*/{},
                            /*argTypes=*/{});
}

static void print(spirv::LoopOp loopOp, OpAsmPrinter &printer) {
  auto *op = loopOp.getOperation();

  auto control = loopOp.loop_control();
  if (control != spirv::LoopControl::None)
    printer << " control(" << spirv::stringifyLoopControl(control) << ")";
  printer.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
                      /*printBlockTerminators=*/true);
}

/// Returns true if the given `srcBlock` contains only one `spv.Branch` to the
/// given `dstBlock`.
static inline bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) {
  // Check that there is only one op in the `srcBlock`.
  if (!llvm::hasSingleElement(srcBlock))
    return false;

  auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.back());
  return branchOp && branchOp.getSuccessor() == &dstBlock;
}

static LogicalResult verify(spirv::LoopOp loopOp) {
  auto *op = loopOp.getOperation();

  // We need to verify that the blocks follow the following layout:
  //
  //                     +-------------+
  //                     | entry block |
  //                     +-------------+
  //                            |
  //                            v
  //                     +-------------+
  //                     | loop header | <-----+
  //                     +-------------+       |
  //                                           |
  //                           ...             |
  //                          \ | /            |
  //                            v              |
  //                    +---------------+      |
  //                    | loop continue | -----+
  //                    +---------------+
  //
  //                           ...
  //                          \ | /
  //                            v
  //                     +-------------+
  //                     | merge block |
  //                     +-------------+

  auto &region = op->getRegion(0);
  // Allow empty region as a degenerated case, which can come from
  // optimizations.
  if (region.empty())
    return success();

  // The last block is the merge block.
  Block &merge = region.back();
  if (!isMergeBlock(merge))
    return loopOp.emitOpError(
        "last block must be the merge block with only one 'spv.mlir.merge' op");

  if (std::next(region.begin()) == region.end())
    return loopOp.emitOpError(
        "must have an entry block branching to the loop header block");
  // The first block is the entry block.
  Block &entry = region.front();

  if (std::next(region.begin(), 2) == region.end())
    return loopOp.emitOpError(
        "must have a loop header block branched from the entry block");
  // The second block is the loop header block.
  Block &header = *std::next(region.begin(), 1);

  if (!hasOneBranchOpTo(entry, header))
    return loopOp.emitOpError(
        "entry block must only have one 'spv.Branch' op to the second block");

  if (std::next(region.begin(), 3) == region.end())
    return loopOp.emitOpError(
        "requires a loop continue block branching to the loop header block");
  // The second to last block is the loop continue block.
  Block &cont = *std::prev(region.end(), 2);

  // Make sure that we have a branch from the loop continue block to the loop
  // header block.
  if (llvm::none_of(
          llvm::seq<unsigned>(0, cont.getNumSuccessors()),
          [&](unsigned index) { return cont.getSuccessor(index) == &header; }))
    return loopOp.emitOpError("second to last block must be the loop continue "
                              "block that branches to the loop header block");

  // Make sure that no other blocks (except the entry and loop continue block)
  // branches to the loop header block.
  for (auto &block : llvm::make_range(std::next(region.begin(), 2),
                                      std::prev(region.end(), 2))) {
    for (auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) {
      if (block.getSuccessor(i) == &header) {
        return loopOp.emitOpError("can only have the entry and loop continue "
                                  "block branching to the loop header block");
      }
    }
  }

  return success();
}

Block *spirv::LoopOp::getEntryBlock() {
  assert(!body().empty() && "op region should not be empty!");
  return &body().front();
}

Block *spirv::LoopOp::getHeaderBlock() {
  assert(!body().empty() && "op region should not be empty!");
  // The second block is the loop header block.
  return &*std::next(body().begin());
}

Block *spirv::LoopOp::getContinueBlock() {
  assert(!body().empty() && "op region should not be empty!");
  // The second to last block is the loop continue block.
  return &*std::prev(body().end(), 2);
}

Block *spirv::LoopOp::getMergeBlock() {
  assert(!body().empty() && "op region should not be empty!");
  // The last block is the loop merge block.
  return &body().back();
}

void spirv::LoopOp::addEntryAndMergeBlock() {
  assert(body().empty() && "entry and merge block already exist");
  body().push_back(new Block());
  auto *mergeBlock = new Block();
  body().push_back(mergeBlock);
  OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);

  // Add a spv.mlir.merge op into the merge block.
  builder.create<spirv::MergeOp>(getLoc());
}

//===----------------------------------------------------------------------===//
// spv.mlir.merge
//===----------------------------------------------------------------------===//

static LogicalResult verify(spirv::MergeOp mergeOp) {
  auto *parentOp = mergeOp->getParentOp();
  if (!parentOp || !isa<spirv::SelectionOp, spirv::LoopOp>(parentOp))
    return mergeOp.emitOpError(
        "expected parent op to be 'spv.mlir.selection' or 'spv.mlir.loop'");

  Block &parentLastBlock = mergeOp->getParentRegion()->back();
  if (mergeOp.getOperation() != parentLastBlock.getTerminator())
    return mergeOp.emitOpError("can only be used in the last block of "
                               "'spv.mlir.selection' or 'spv.mlir.loop'");
  return success();
}

//===----------------------------------------------------------------------===//
// spv.module
//===----------------------------------------------------------------------===//

void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
                            Optional<StringRef> name) {
  OpBuilder::InsertionGuard guard(builder);
  builder.createBlock(state.addRegion());
  if (name) {
    state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
                            builder.getStringAttr(*name));
  }
}

void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
                            spirv::AddressingModel addressingModel,
                            spirv::MemoryModel memoryModel,
                            Optional<VerCapExtAttr> vceTriple,
                            Optional<StringRef> name) {
  state.addAttribute(
      "addressing_model",
      builder.getI32IntegerAttr(static_cast<int32_t>(addressingModel)));
  state.addAttribute("memory_model", builder.getI32IntegerAttr(
                                         static_cast<int32_t>(memoryModel)));
  OpBuilder::InsertionGuard guard(builder);
  builder.createBlock(state.addRegion());
  if (vceTriple)
    state.addAttribute(getVCETripleAttrName(), *vceTriple);
  if (name)
    state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
                       builder.getStringAttr(*name));
}

static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {
  Region *body = state.addRegion();

  // If the name is present, parse it.
  StringAttr nameAttr;
  parser.parseOptionalSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                                 state.attributes);

  // Parse attributes
  spirv::AddressingModel addrModel;
  spirv::MemoryModel memoryModel;
  if (parseEnumKeywordAttr(addrModel, parser, state) ||
      parseEnumKeywordAttr(memoryModel, parser, state))
    return failure();

  if (succeeded(parser.parseOptionalKeyword("requires"))) {
    spirv::VerCapExtAttr vceTriple;
    if (parser.parseAttribute(vceTriple,
                              spirv::ModuleOp::getVCETripleAttrName(),
                              state.attributes))
      return failure();
  }

  if (parser.parseOptionalAttrDictWithKeyword(state.attributes))
    return failure();

  if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();

  // Make sure we have at least one block.
  if (body->empty())
    body->push_back(new Block());

  return success();
}

static void print(spirv::ModuleOp moduleOp, OpAsmPrinter &printer) {
  if (Optional<StringRef> name = moduleOp.getName()) {
    printer << ' ';
    printer.printSymbolName(*name);
  }

  SmallVector<StringRef, 2> elidedAttrs;

  printer << " " << spirv::stringifyAddressingModel(moduleOp.addressing_model())
          << " " << spirv::stringifyMemoryModel(moduleOp.memory_model());
  auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
  auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
  elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
                      SymbolTable::getSymbolAttrName()});

  if (Optional<spirv::VerCapExtAttr> triple = moduleOp.vce_triple()) {
    printer << " requires " << *triple;
    elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
  }

  printer.printOptionalAttrDictWithKeyword(moduleOp->getAttrs(), elidedAttrs);
  printer.printRegion(moduleOp.getRegion());
}

static LogicalResult verify(spirv::ModuleOp moduleOp) {
  auto &op = *moduleOp.getOperation();
  auto *dialect = op.getDialect();
  DenseMap<std::pair<spirv::FuncOp, spirv::ExecutionModel>, spirv::EntryPointOp>
      entryPoints;
  SymbolTable table(moduleOp);

  for (auto &op : *moduleOp.getBody()) {
    if (op.getDialect() != dialect)
      return op.emitError("'spv.module' can only contain spv.* ops");

    // For EntryPoint op, check that the function and execution model is not
    // duplicated in EntryPointOps. Also verify that the interface specified
    // comes from globalVariables here to make this check cheaper.
    if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
      auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.fn());
      if (!funcOp) {
        return entryPointOp.emitError("function '")
               << entryPointOp.fn() << "' not found in 'spv.module'";
      }
      if (auto interface = entryPointOp.interface()) {
        for (Attribute varRef : interface) {
          auto varSymRef = varRef.dyn_cast<FlatSymbolRefAttr>();
          if (!varSymRef) {
            return entryPointOp.emitError(
                       "expected symbol reference for interface "
                       "specification instead of '")
                   << varRef;
          }
          auto variableOp =
              table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
          if (!variableOp) {
            return entryPointOp.emitError("expected spv.GlobalVariable "
                                          "symbol reference instead of'")
                   << varSymRef << "'";
          }
        }
      }

      auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
          funcOp, entryPointOp.execution_model());
      auto entryPtIt = entryPoints.find(key);
      if (entryPtIt != entryPoints.end()) {
        return entryPointOp.emitError("duplicate of a previous EntryPointOp");
      }
      entryPoints[key] = entryPointOp;
    } else if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
      if (funcOp.isExternal())
        return op.emitError("'spv.module' cannot contain external functions");

      // TODO: move this check to spv.func.
      for (auto &block : funcOp)
        for (auto &op : block) {
          if (op.getDialect() != dialect)
            return op.emitError(
                "functions in 'spv.module' can only contain spv.* ops");
        }
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// spv.mlir.referenceof
//===----------------------------------------------------------------------===//

static LogicalResult verify(spirv::ReferenceOfOp referenceOfOp) {
  auto *specConstSym = SymbolTable::lookupNearestSymbolFrom(
      referenceOfOp->getParentOp(), referenceOfOp.spec_constAttr());
  Type constType;

  auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
  if (specConstOp)
    constType = specConstOp.default_value().getType();

  auto specConstCompositeOp =
      dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
  if (specConstCompositeOp)
    constType = specConstCompositeOp.type();

  if (!specConstOp && !specConstCompositeOp)
    return referenceOfOp.emitOpError(
        "expected spv.SpecConstant or spv.SpecConstantComposite symbol");

  if (referenceOfOp.reference().getType() != constType)
    return referenceOfOp.emitOpError("result type mismatch with the referenced "
                                     "specialization constant's type");

  return success();
}

//===----------------------------------------------------------------------===//
// spv.Return
//===----------------------------------------------------------------------===//

static LogicalResult verify(spirv::ReturnOp returnOp) {
  // Verification is performed in spv.func op.
  return success();
}

//===----------------------------------------------------------------------===//
// spv.ReturnValue
//===----------------------------------------------------------------------===//

static LogicalResult verify(spirv::ReturnValueOp retValOp) {
  // Verification is performed in spv.func op.
  return success();
}

//===----------------------------------------------------------------------===//
// spv.Select
//===----------------------------------------------------------------------===//

static LogicalResult verify(spirv::SelectOp op) {
  if (auto conditionTy = op.condition().getType().dyn_cast<VectorType>()) {
    auto resultVectorTy = op.result().getType().dyn_cast<VectorType>();
    if (!resultVectorTy) {
      return op.emitOpError("result expected to be of vector type when "
                            "condition is of vector type");
    }
    if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) {
      return op.emitOpError("result should have the same number of elements as "
                            "the condition when condition is of vector type");
    }
  }
  return success();
}

//===----------------------------------------------------------------------===//
// spv.mlir.selection
//===----------------------------------------------------------------------===//

static ParseResult parseSelectionOp(OpAsmParser &parser,
                                    OperationState &state) {
  if (parseControlAttribute<spirv::SelectionControl>(parser, state))
    return failure();
  return parser.parseRegion(*state.addRegion(), /*arguments=*/{},
                            /*argTypes=*/{});
}

static void print(spirv::SelectionOp selectionOp, OpAsmPrinter &printer) {
  auto *op = selectionOp.getOperation();
  auto control = selectionOp.selection_control();
  if (control != spirv::SelectionControl::None)
    printer << " control(" << spirv::stringifySelectionControl(control) << ")";
  printer.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
                      /*printBlockTerminators=*/true);
}

static LogicalResult verify(spirv::SelectionOp selectionOp) {
  auto *op = selectionOp.getOperation();

  // We need to verify that the blocks follow the following layout:
  //
  //                     +--------------+
  //                     | header block |
  //                     +--------------+
  //                          / | \
  //                           ...
  //
  //
  //         +---------+   +---------+   +---------+
  //         | case #0 |   | case #1 |   | case #2 |  ...
  //         +---------+   +---------+   +---------+
  //
  //
  //                           ...
  //                          \ | /
  //                            v
  //                     +-------------+
  //                     | merge block |
  //                     +-------------+

  auto &region = op->getRegion(0);
  // Allow empty region as a degenerated case, which can come from
  // optimizations.
  if (region.empty())
    return success();

  // The last block is the merge block.
  if (!isMergeBlock(region.back()))
    return selectionOp.emitOpError(
        "last block must be the merge block with only one 'spv.mlir.merge' op");

  if (std::next(region.begin()) == region.end())
    return selectionOp.emitOpError("must have a selection header block");

  return success();
}

Block *spirv::SelectionOp::getHeaderBlock() {
  assert(!body().empty() && "op region should not be empty!");
  // The first block is the loop header block.
  return &body().front();
}

Block *spirv::SelectionOp::getMergeBlock() {
  assert(!body().empty() && "op region should not be empty!");
  // The last block is the loop merge block.
  return &body().back();
}

void spirv::SelectionOp::addMergeBlock() {
  assert(body().empty() && "entry and merge block already exist");
  auto *mergeBlock = new Block();
  body().push_back(mergeBlock);
  OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);

  // Add a spv.mlir.merge op into the merge block.
  builder.create<spirv::MergeOp>(getLoc());
}

spirv::SelectionOp spirv::SelectionOp::createIfThen(
    Location loc, Value condition,
    function_ref<void(OpBuilder &builder)> thenBody, OpBuilder &builder) {
  auto selectionOp =
      builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);

  selectionOp.addMergeBlock();
  Block *mergeBlock = selectionOp.getMergeBlock();
  Block *thenBlock = nullptr;

  // Build the "then" block.
  {
    OpBuilder::InsertionGuard guard(builder);
    thenBlock = builder.createBlock(mergeBlock);
    thenBody(builder);
    builder.create<spirv::BranchOp>(loc, mergeBlock);
  }

  // Build the header block.
  {
    OpBuilder::InsertionGuard guard(builder);
    builder.createBlock(thenBlock);
    builder.create<spirv::BranchConditionalOp>(
        loc, condition, thenBlock,
        /*trueArguments=*/ArrayRef<Value>(), mergeBlock,
        /*falseArguments=*/ArrayRef<Value>());
  }

  return selectionOp;
}

//===----------------------------------------------------------------------===//
// spv.SpecConstant
//===----------------------------------------------------------------------===//

static ParseResult parseSpecConstantOp(OpAsmParser &parser,
                                       OperationState &state) {
  StringAttr nameAttr;
  Attribute valueAttr;

  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                             state.attributes))
    return failure();

  // Parse optional spec_id.
  if (succeeded(parser.parseOptionalKeyword(kSpecIdAttrName))) {
    IntegerAttr specIdAttr;
    if (parser.parseLParen() ||
        parser.parseAttribute(specIdAttr, kSpecIdAttrName, state.attributes) ||
        parser.parseRParen())
      return failure();
  }

  if (parser.parseEqual() ||
      parser.parseAttribute(valueAttr, kDefaultValueAttrName, state.attributes))
    return failure();

  return success();
}

static void print(spirv::SpecConstantOp constOp, OpAsmPrinter &printer) {
  printer << ' ';
  printer.printSymbolName(constOp.sym_name());
  if (auto specID = constOp->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
    printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
  printer << " = " << constOp.default_value();
}

static LogicalResult verify(spirv::SpecConstantOp constOp) {
  if (auto specID = constOp->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
    if (specID.getValue().isNegative())
      return constOp.emitOpError("SpecId cannot be negative");

  auto value = constOp.default_value();
  if (value.isa<IntegerAttr, FloatAttr>()) {
    // Make sure bitwidth is allowed.
    if (!value.getType().isa<spirv::SPIRVType>())
      return constOp.emitOpError("default value bitwidth disallowed");
    return success();
  }
  return constOp.emitOpError(
      "default value can only be a bool, integer, or float scalar");
}

//===----------------------------------------------------------------------===//
// spv.StoreOp
//===----------------------------------------------------------------------===//

static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &state) {
  // Parse the storage class specification
  spirv::StorageClass storageClass;
  SmallVector<OpAsmParser::OperandType, 2> operandInfo;
  auto loc = parser.getCurrentLocation();
  Type elementType;
  if (parseEnumStrAttr(storageClass, parser) ||
      parser.parseOperandList(operandInfo, 2) ||
      parseMemoryAccessAttributes(parser, state) || parser.parseColon() ||
      parser.parseType(elementType)) {
    return failure();
  }

  auto ptrType = spirv::PointerType::get(elementType, storageClass);
  if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
                             state.operands)) {
    return failure();
  }
  return success();
}

static void print(spirv::StoreOp storeOp, OpAsmPrinter &printer) {
  auto *op = storeOp.getOperation();
  SmallVector<StringRef, 4> elidedAttrs;
  StringRef sc = stringifyStorageClass(
      storeOp.ptr().getType().cast<spirv::PointerType>().getStorageClass());
  printer << " \"" << sc << "\" " << storeOp.ptr() << ", " << storeOp.value();

  printMemoryAccessAttribute(storeOp, printer, elidedAttrs);

  printer << " : " << storeOp.value().getType();
  printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
}

static LogicalResult verify(spirv::StoreOp storeOp) {
  // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an
  // OpTypePointer whose Type operand is the same as the type of Object."
  if (failed(verifyLoadStorePtrAndValTypes(storeOp, storeOp.ptr(),
                                           storeOp.value()))) {
    return failure();
  }
  return verifyMemoryAccessAttribute(storeOp);
}

//===----------------------------------------------------------------------===//
// spv.Unreachable
//===----------------------------------------------------------------------===//

static LogicalResult verify(spirv::UnreachableOp unreachableOp) {
  auto *op = unreachableOp.getOperation();
  auto *block = op->getBlock();
  // Fast track: if this is in entry block, its invalid. Otherwise, if no
  // predecessors, it's valid.
  if (block->isEntryBlock())
    return unreachableOp.emitOpError("cannot be used in reachable block");
  if (block->hasNoPredecessors())
    return success();

  // TODO: further verification needs to analyze reachability from
  // the entry block.

  return success();
}

//===----------------------------------------------------------------------===//
// spv.Variable
//===----------------------------------------------------------------------===//

static ParseResult parseVariableOp(OpAsmParser &parser, OperationState &state) {
  // Parse optional initializer
  Optional<OpAsmParser::OperandType> initInfo;
  if (succeeded(parser.parseOptionalKeyword("init"))) {
    initInfo = OpAsmParser::OperandType();
    if (parser.parseLParen() || parser.parseOperand(*initInfo) ||
        parser.parseRParen())
      return failure();
  }

  if (parseVariableDecorations(parser, state)) {
    return failure();
  }

  // Parse result pointer type
  Type type;
  if (parser.parseColon())
    return failure();
  auto loc = parser.getCurrentLocation();
  if (parser.parseType(type))
    return failure();

  auto ptrType = type.dyn_cast<spirv::PointerType>();
  if (!ptrType)
    return parser.emitError(loc, "expected spv.ptr type");
  state.addTypes(ptrType);

  // Resolve the initializer operand
  if (initInfo) {
    if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(),
                              state.operands))
      return failure();
  }

  auto attr = parser.getBuilder().getI32IntegerAttr(
      llvm::bit_cast<int32_t>(ptrType.getStorageClass()));
  state.addAttribute(spirv::attributeName<spirv::StorageClass>(), attr);

  return success();
}

static void print(spirv::VariableOp varOp, OpAsmPrinter &printer) {
  SmallVector<StringRef, 4> elidedAttrs{
      spirv::attributeName<spirv::StorageClass>()};
  // Print optional initializer
  if (varOp.getNumOperands() != 0)
    printer << " init(" << varOp.initializer() << ")";

  printVariableDecorations(varOp, printer, elidedAttrs);
  printer << " : " << varOp.getType();
}

static LogicalResult verify(spirv::VariableOp varOp) {
  // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
  // object. It cannot be Generic. It must be the same as the Storage Class
  // operand of the Result Type."
  if (varOp.storage_class() != spirv::StorageClass::Function) {
    return varOp.emitOpError(
        "can only be used to model function-level variables. Use "
        "spv.GlobalVariable for module-level variables.");
  }

  auto pointerType = varOp.pointer().getType().cast<spirv::PointerType>();
  if (varOp.storage_class() != pointerType.getStorageClass())
    return varOp.emitOpError(
        "storage class must match result pointer's storage class");

  if (varOp.getNumOperands() != 0) {
    // SPIR-V spec: "Initializer must be an <id> from a constant instruction or
    // a global (module scope) OpVariable instruction".
    auto *initOp = varOp.getOperand(0).getDefiningOp();
    if (!initOp || !isa<spirv::ConstantOp,    // for normal constant
                        spirv::ReferenceOfOp, // for spec constant
                        spirv::AddressOfOp>(initOp))
      return varOp.emitOpError("initializer must be the result of a "
                               "constant or spv.GlobalVariable op");
  }

  // TODO: generate these strings using ODS.
  auto *op = varOp.getOperation();
  auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
      stringifyDecoration(spirv::Decoration::DescriptorSet));
  auto bindingName = llvm::convertToSnakeFromCamelCase(
      stringifyDecoration(spirv::Decoration::Binding));
  auto builtInName = llvm::convertToSnakeFromCamelCase(
      stringifyDecoration(spirv::Decoration::BuiltIn));

  for (const auto &attr : {descriptorSetName, bindingName, builtInName}) {
    if (op->getAttr(attr))
      return varOp.emitOpError("cannot have '")
             << attr << "' attribute (only allowed in spv.GlobalVariable)";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// spv.VectorShuffle
//===----------------------------------------------------------------------===//

static LogicalResult verify(spirv::VectorShuffleOp shuffleOp) {
  VectorType resultType = shuffleOp.getType().cast<VectorType>();

  size_t numResultElements = resultType.getNumElements();
  if (numResultElements != shuffleOp.components().size())
    return shuffleOp.emitOpError("result type element count (")
           << numResultElements
           << ") mismatch with the number of component selectors ("
           << shuffleOp.components().size() << ")";

  size_t totalSrcElements =
      shuffleOp.vector1().getType().cast<VectorType>().getNumElements() +
      shuffleOp.vector2().getType().cast<VectorType>().getNumElements();

  for (const auto &selector :
       shuffleOp.components().getAsValueRange<IntegerAttr>()) {
    uint32_t index = selector.getZExtValue();
    if (index >= totalSrcElements &&
        index != std::numeric_limits<uint32_t>().max())
      return shuffleOp.emitOpError("component selector ")
             << index << " out of range: expected to be in [0, "
             << totalSrcElements << ") or 0xffffffff";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// spv.CooperativeMatrixLoadNV
//===----------------------------------------------------------------------===//

static ParseResult parseCooperativeMatrixLoadNVOp(OpAsmParser &parser,
                                                  OperationState &state) {
  SmallVector<OpAsmParser::OperandType, 3> operandInfo;
  Type strideType = parser.getBuilder().getIntegerType(32);
  Type columnMajorType = parser.getBuilder().getIntegerType(1);
  Type ptrType;
  Type elementType;
  if (parser.parseOperandList(operandInfo, 3) ||
      parseMemoryAccessAttributes(parser, state) || parser.parseColon() ||
      parser.parseType(ptrType) || parser.parseKeywordType("as", elementType)) {
    return failure();
  }
  if (parser.resolveOperands(operandInfo,
                             {ptrType, strideType, columnMajorType},
                             parser.getNameLoc(), state.operands)) {
    return failure();
  }

  state.addTypes(elementType);
  return success();
}

static void print(spirv::CooperativeMatrixLoadNVOp m, OpAsmPrinter &printer) {
  printer << " " << m.pointer() << ", " << m.stride() << ", "
          << m.columnmajor();
  // Print optional memory access attribute.
  if (auto memAccess = m.memory_access())
    printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
  printer << " : " << m.pointer().getType() << " as " << m.getType();
}

static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
                                                    Type coopMatrix) {
  Type pointeeType = pointer.cast<spirv::PointerType>().getPointeeType();
  if (!pointeeType.isa<spirv::ScalarType>() && !pointeeType.isa<VectorType>())
    return op->emitError(
               "Pointer must point to a scalar or vector type but provided ")
           << pointeeType;
  spirv::StorageClass storage =
      pointer.cast<spirv::PointerType>().getStorageClass();
  if (storage != spirv::StorageClass::Workgroup &&
      storage != spirv::StorageClass::StorageBuffer &&
      storage != spirv::StorageClass::PhysicalStorageBuffer)
    return op->emitError(
               "Pointer storage class must be Workgroup, StorageBuffer or "
               "PhysicalStorageBufferEXT but provided ")
           << stringifyStorageClass(storage);
  return success();
}

//===----------------------------------------------------------------------===//
// spv.CooperativeMatrixStoreNV
//===----------------------------------------------------------------------===//

static ParseResult parseCooperativeMatrixStoreNVOp(OpAsmParser &parser,
                                                   OperationState &state) {
  SmallVector<OpAsmParser::OperandType, 4> operandInfo;
  Type strideType = parser.getBuilder().getIntegerType(32);
  Type columnMajorType = parser.getBuilder().getIntegerType(1);
  Type ptrType;
  Type elementType;
  if (parser.parseOperandList(operandInfo, 4) ||
      parseMemoryAccessAttributes(parser, state) || parser.parseColon() ||
      parser.parseType(ptrType) || parser.parseComma() ||
      parser.parseType(elementType)) {
    return failure();
  }
  if (parser.resolveOperands(
          operandInfo, {ptrType, elementType, strideType, columnMajorType},
          parser.getNameLoc(), state.operands)) {
    return failure();
  }

  return success();
}

static void print(spirv::CooperativeMatrixStoreNVOp coopMatrix,
                  OpAsmPrinter &printer) {
  printer << " " << coopMatrix.pointer() << ", " << coopMatrix.object() << ", "
          << coopMatrix.stride() << ", " << coopMatrix.columnmajor();
  // Print optional memory access attribute.
  if (auto memAccess = coopMatrix.memory_access())
    printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
  printer << " : " << coopMatrix.pointer().getType() << ", "
          << coopMatrix.getOperand(1).getType();
}

//===----------------------------------------------------------------------===//
// spv.CooperativeMatrixMulAddNV
//===----------------------------------------------------------------------===//

static LogicalResult
verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
  if (op.c().getType() != op.result().getType())
    return op.emitOpError("result and third operand must have the same type");
  auto typeA = op.a().getType().cast<spirv::CooperativeMatrixNVType>();
  auto typeB = op.b().getType().cast<spirv::CooperativeMatrixNVType>();
  auto typeC = op.c().getType().cast<spirv::CooperativeMatrixNVType>();
  auto typeR = op.result().getType().cast<spirv::CooperativeMatrixNVType>();
  if (typeA.getRows() != typeR.getRows() ||
      typeA.getColumns() != typeB.getRows() ||
      typeB.getColumns() != typeR.getColumns())
    return op.emitOpError("matrix size must match");
  if (typeR.getScope() != typeA.getScope() ||
      typeR.getScope() != typeB.getScope() ||
      typeR.getScope() != typeC.getScope())
    return op.emitOpError("matrix scope must match");
  if (typeA.getElementType() != typeB.getElementType() ||
      typeR.getElementType() != typeC.getElementType())
    return op.emitOpError("matrix element type must match");
  return success();
}

//===----------------------------------------------------------------------===//
// spv.MatrixTimesScalar
//===----------------------------------------------------------------------===//

static LogicalResult verifyMatrixTimesScalar(spirv::MatrixTimesScalarOp op) {
  // We already checked that result and matrix are both of matrix type in the
  // auto-generated verify method.

  auto inputMatrix = op.matrix().getType().cast<spirv::MatrixType>();
  auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();

  // Check that the scalar type is the same as the matrix element type.
  if (op.scalar().getType() != inputMatrix.getElementType())
    return op.emitError("input matrix components' type and scaling value must "
                        "have the same type");

  // Note that the next three checks could be done using the AllTypesMatch
  // trait in the Op definition file but it generates a vague error message.

  // Check that the input and result matrices have the same columns' count
  if (inputMatrix.getNumColumns() != resultMatrix.getNumColumns())
    return op.emitError("input and result matrices must have the same "
                        "number of columns");

  // Check that the input and result matrices' have the same rows count
  if (inputMatrix.getNumRows() != resultMatrix.getNumRows())
    return op.emitError("input and result matrices' columns must have "
                        "the same size");

  // Check that the input and result matrices' have the same component type
  if (inputMatrix.getElementType() != resultMatrix.getElementType())
    return op.emitError("input and result matrices' columns must have "
                        "the same component type");

  return success();
}

//===----------------------------------------------------------------------===//
// spv.CopyMemory
//===----------------------------------------------------------------------===//

static void print(spirv::CopyMemoryOp copyMemory, OpAsmPrinter &printer) {
  auto *op = copyMemory.getOperation();
  printer << ' ';

  StringRef targetStorageClass =
      stringifyStorageClass(copyMemory.target()
                                .getType()
                                .cast<spirv::PointerType>()
                                .getStorageClass());
  printer << " \"" << targetStorageClass << "\" " << copyMemory.target()
          << ", ";

  StringRef sourceStorageClass =
      stringifyStorageClass(copyMemory.source()
                                .getType()
                                .cast<spirv::PointerType>()
                                .getStorageClass());
  printer << " \"" << sourceStorageClass << "\" " << copyMemory.source();

  SmallVector<StringRef, 4> elidedAttrs;
  printMemoryAccessAttribute(copyMemory, printer, elidedAttrs);
  printSourceMemoryAccessAttribute(copyMemory, printer, elidedAttrs,
                                   copyMemory.source_memory_access(),
                                   copyMemory.source_alignment());

  printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);

  Type pointeeType =
      copyMemory.target().getType().cast<spirv::PointerType>().getPointeeType();
  printer << " : " << pointeeType;
}

static ParseResult parseCopyMemoryOp(OpAsmParser &parser,
                                     OperationState &state) {
  spirv::StorageClass targetStorageClass;
  OpAsmParser::OperandType targetPtrInfo;

  spirv::StorageClass sourceStorageClass;
  OpAsmParser::OperandType sourcePtrInfo;

  Type elementType;

  if (parseEnumStrAttr(targetStorageClass, parser) ||
      parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
      parseEnumStrAttr(sourceStorageClass, parser) ||
      parser.parseOperand(sourcePtrInfo) ||
      parseMemoryAccessAttributes(parser, state)) {
    return failure();
  }

  if (!parser.parseOptionalComma()) {
    // Parse 2nd memory access attributes.
    if (parseSourceMemoryAccessAttributes(parser, state)) {
      return failure();
    }
  }

  if (parser.parseColon() || parser.parseType(elementType))
    return failure();

  if (parser.parseOptionalAttrDict(state.attributes))
    return failure();

  auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass);
  auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass);

  if (parser.resolveOperand(targetPtrInfo, targetPtrType, state.operands) ||
      parser.resolveOperand(sourcePtrInfo, sourcePtrType, state.operands)) {
    return failure();
  }

  return success();
}

static LogicalResult verifyCopyMemory(spirv::CopyMemoryOp copyMemory) {
  Type targetType =
      copyMemory.target().getType().cast<spirv::PointerType>().getPointeeType();

  Type sourceType =
      copyMemory.source().getType().cast<spirv::PointerType>().getPointeeType();

  if (targetType != sourceType) {
    return copyMemory.emitOpError(
        "both operands must be pointers to the same type");
  }

  if (failed(verifyMemoryAccessAttribute(copyMemory))) {
    return failure();
  }

  // TODO - According to the spec:
  //
  // If two masks are present, the first applies to Target and cannot include
  // MakePointerVisible, and the second applies to Source and cannot include
  // MakePointerAvailable.
  //
  // Add such verification here.

  return verifySourceMemoryAccessAttribute(copyMemory);
}

//===----------------------------------------------------------------------===//
// spv.Transpose
//===----------------------------------------------------------------------===//

static LogicalResult verifyTranspose(spirv::TransposeOp op) {
  auto inputMatrix = op.matrix().getType().cast<spirv::MatrixType>();
  auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();

  // Verify that the input and output matrices have correct shapes.
  if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
    return op.emitError("input matrix rows count must be equal to "
                        "output matrix columns count");

  if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
    return op.emitError("input matrix columns count must be equal to "
                        "output matrix rows count");

  // Verify that the input and output matrices have the same component type
  if (inputMatrix.getElementType() != resultMatrix.getElementType())
    return op.emitError("input and output matrices must have the same "
                        "component type");

  return success();
}

//===----------------------------------------------------------------------===//
// spv.MatrixTimesMatrix
//===----------------------------------------------------------------------===//

static LogicalResult verifyMatrixTimesMatrix(spirv::MatrixTimesMatrixOp op) {
  auto leftMatrix = op.leftmatrix().getType().cast<spirv::MatrixType>();
  auto rightMatrix = op.rightmatrix().getType().cast<spirv::MatrixType>();
  auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();

  // left matrix columns' count and right matrix rows' count must be equal
  if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
    return op.emitError("left matrix columns' count must be equal to "
                        "the right matrix rows' count");

  // right and result matrices columns' count must be the same
  if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
    return op.emitError(
        "right and result matrices must have equal columns' count");

  // right and result matrices component type must be the same
  if (rightMatrix.getElementType() != resultMatrix.getElementType())
    return op.emitError("right and result matrices' component type must"
                        " be the same");

  // left and result matrices component type must be the same
  if (leftMatrix.getElementType() != resultMatrix.getElementType())
    return op.emitError("left and result matrices' component type"
                        " must be the same");

  // left and result matrices rows count must be the same
  if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
    return op.emitError("left and result matrices must have equal rows'"
                        " count");

  return success();
}

//===----------------------------------------------------------------------===//
// spv.SpecConstantComposite
//===----------------------------------------------------------------------===//

static ParseResult parseSpecConstantCompositeOp(OpAsmParser &parser,
                                                OperationState &state) {

  StringAttr compositeName;
  if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
                             state.attributes))
    return failure();

  if (parser.parseLParen())
    return failure();

  SmallVector<Attribute, 4> constituents;

  do {
    // The name of the constituent attribute isn't important
    const char *attrName = "spec_const";
    FlatSymbolRefAttr specConstRef;
    NamedAttrList attrs;

    if (parser.parseAttribute(specConstRef, Type(), attrName, attrs))
      return failure();

    constituents.push_back(specConstRef);
  } while (!parser.parseOptionalComma());

  if (parser.parseRParen())
    return failure();

  state.addAttribute(kCompositeSpecConstituentsName,
                     parser.getBuilder().getArrayAttr(constituents));

  Type type;
  if (parser.parseColonType(type))
    return failure();

  state.addAttribute(kTypeAttrName, TypeAttr::get(type));

  return success();
}

static void print(spirv::SpecConstantCompositeOp op, OpAsmPrinter &printer) {
  printer << " ";
  printer.printSymbolName(op.sym_name());
  printer << " (";
  auto constituents = op.constituents().getValue();

  if (!constituents.empty())
    llvm::interleaveComma(constituents, printer);

  printer << ") : " << op.type();
}

static LogicalResult verify(spirv::SpecConstantCompositeOp constOp) {
  auto cType = constOp.type().dyn_cast<spirv::CompositeType>();
  auto constituents = constOp.constituents().getValue();

  if (!cType)
    return constOp.emitError(
               "result type must be a composite type, but provided ")
           << constOp.type();

  if (cType.isa<spirv::CooperativeMatrixNVType>())
    return constOp.emitError("unsupported composite type  ") << cType;
  if (constituents.size() != cType.getNumElements())
    return constOp.emitError("has incorrect number of operands: expected ")
           << cType.getNumElements() << ", but provided "
           << constituents.size();

  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
    auto constituent = constituents[index].cast<FlatSymbolRefAttr>();

    auto constituentSpecConstOp =
        dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
            constOp->getParentOp(), constituent.getAttr()));

    if (constituentSpecConstOp.default_value().getType() !=
        cType.getElementType(index))
      return constOp.emitError("has incorrect types of operands: expected ")
             << cType.getElementType(index) << ", but provided "
             << constituentSpecConstOp.default_value().getType();
  }

  return success();
}

//===----------------------------------------------------------------------===//
// spv.SpecConstantOperation
//===----------------------------------------------------------------------===//

static ParseResult parseSpecConstantOperationOp(OpAsmParser &parser,
                                                OperationState &state) {
  Region *body = state.addRegion();

  if (parser.parseKeyword("wraps"))
    return failure();

  body->push_back(new Block);
  Block &block = body->back();
  Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());

  if (!wrappedOp)
    return failure();

  OpBuilder builder(parser.getContext());
  builder.setInsertionPointToEnd(&block);
  builder.create<spirv::YieldOp>(wrappedOp->getLoc(), wrappedOp->getResult(0));
  state.location = wrappedOp->getLoc();

  state.addTypes(wrappedOp->getResult(0).getType());

  if (parser.parseOptionalAttrDict(state.attributes))
    return failure();

  return success();
}

static void print(spirv::SpecConstantOperationOp op, OpAsmPrinter &printer) {
  printer << " wraps ";
  printer.printGenericOp(&op.body().front().front());
}

static LogicalResult verify(spirv::SpecConstantOperationOp constOp) {
  Block &block = constOp.getRegion().getBlocks().front();

  if (block.getOperations().size() != 2)
    return constOp.emitOpError("expected exactly 2 nested ops");

  Operation &enclosedOp = block.getOperations().front();

  if (!enclosedOp.hasTrait<OpTrait::spirv::UsableInSpecConstantOp>())
    return constOp.emitOpError("invalid enclosed op");

  for (auto operand : enclosedOp.getOperands())
    if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
             spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
      return constOp.emitOpError(
          "invalid operand, must be defined by a constant operation");

  return success();
}

//===----------------------------------------------------------------------===//
// spv.GLSL.FrexpStruct
//===----------------------------------------------------------------------===//
static LogicalResult
verifyGLSLFrexpStructOp(spirv::GLSLFrexpStructOp frexpStructOp) {
  spirv::StructType structTy =
      frexpStructOp.result().getType().dyn_cast<spirv::StructType>();

  if (structTy.getNumElements() != 2)
    return frexpStructOp.emitError("result type must be a struct  type "
                                   "with two memebers");

  Type significandTy = structTy.getElementType(0);
  Type exponentTy = structTy.getElementType(1);
  VectorType exponentVecTy = exponentTy.dyn_cast<VectorType>();
  IntegerType exponentIntTy = exponentTy.dyn_cast<IntegerType>();

  Type operandTy = frexpStructOp.operand().getType();
  VectorType operandVecTy = operandTy.dyn_cast<VectorType>();
  FloatType operandFTy = operandTy.dyn_cast<FloatType>();

  if (significandTy != operandTy)
    return frexpStructOp.emitError("member zero of the resulting struct type "
                                   "must be the same type as the operand");

  if (exponentVecTy) {
    IntegerType componentIntTy =
        exponentVecTy.getElementType().dyn_cast<IntegerType>();
    if (!(componentIntTy && componentIntTy.getWidth() == 32))
      return frexpStructOp.emitError(
          "member one of the resulting struct type must"
          "be a scalar or vector of 32 bit integer type");
  } else if (!(exponentIntTy && exponentIntTy.getWidth() == 32)) {
    return frexpStructOp.emitError(
        "member one of the resulting struct type "
        "must be a scalar or vector of 32 bit integer type");
  }

  // Check that the two member types have the same number of components
  if (operandVecTy && exponentVecTy &&
      (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
    return success();

  if (operandFTy && exponentIntTy)
    return success();

  return frexpStructOp.emitError(
      "member one of the resulting struct type "
      "must have the same number of components as the operand type");
}

//===----------------------------------------------------------------------===//
// spv.GLSL.Ldexp
//===----------------------------------------------------------------------===//

static LogicalResult verify(spirv::GLSLLdexpOp ldexpOp) {
  Type significandType = ldexpOp.x().getType();
  Type exponentType = ldexpOp.exp().getType();

  if (significandType.isa<FloatType>() != exponentType.isa<IntegerType>())
    return ldexpOp.emitOpError("operands must both be scalars or vectors");

  auto getNumElements = [](Type type) -> unsigned {
    if (auto vectorType = type.dyn_cast<VectorType>())
      return vectorType.getNumElements();
    return 1;
  };

  if (getNumElements(significandType) != getNumElements(exponentType))
    return ldexpOp.emitOpError(
        "operands must have the same number of elements");

  return success();
}

//===----------------------------------------------------------------------===//
// spv.ImageDrefGather
//===----------------------------------------------------------------------===//

static LogicalResult verify(spirv::ImageDrefGatherOp imageDrefGatherOp) {
  VectorType resultType =
      imageDrefGatherOp.result().getType().cast<VectorType>();
  auto sampledImageType = imageDrefGatherOp.sampledimage()
                              .getType()
                              .cast<spirv::SampledImageType>();
  auto imageType = sampledImageType.getImageType().cast<spirv::ImageType>();

  if (resultType.getNumElements() != 4)
    return imageDrefGatherOp.emitOpError(
        "result type must be a vector of four components");

  Type elementType = resultType.getElementType();
  Type sampledElementType = imageType.getElementType();
  if (!sampledElementType.isa<NoneType>() && elementType != sampledElementType)
    return imageDrefGatherOp.emitOpError(
        "the component type of result must be the same as sampled type of the "
        "underlying image type");

  spirv::Dim imageDim = imageType.getDim();
  spirv::ImageSamplingInfo imageMS = imageType.getSamplingInfo();

  if (imageDim != spirv::Dim::Dim2D && imageDim != spirv::Dim::Cube &&
      imageDim != spirv::Dim::Rect)
    return imageDrefGatherOp.emitOpError(
        "the Dim operand of the underlying image type must be 2D, Cube, or "
        "Rect");

  if (imageMS != spirv::ImageSamplingInfo::SingleSampled)
    return imageDrefGatherOp.emitOpError(
        "the MS operand of the underlying image type must be 0");

  spirv::ImageOperandsAttr attr = imageDrefGatherOp.imageoperandsAttr();
  auto operandArguments = imageDrefGatherOp.operand_arguments();

  return verifyImageOperands(imageDrefGatherOp, attr, operandArguments);
}

//===----------------------------------------------------------------------===//
// spv.ImageQuerySize
//===----------------------------------------------------------------------===//

static LogicalResult verify(spirv::ImageQuerySizeOp imageQuerySizeOp) {
  spirv::ImageType imageType =
      imageQuerySizeOp.image().getType().cast<spirv::ImageType>();
  Type resultType = imageQuerySizeOp.result().getType();

  spirv::Dim dim = imageType.getDim();
  spirv::ImageSamplingInfo samplingInfo = imageType.getSamplingInfo();
  spirv::ImageSamplerUseInfo samplerInfo = imageType.getSamplerUseInfo();
  switch (dim) {
  case spirv::Dim::Dim1D:
  case spirv::Dim::Dim2D:
  case spirv::Dim::Dim3D:
  case spirv::Dim::Cube:
    if (!(samplingInfo == spirv::ImageSamplingInfo::MultiSampled ||
          samplerInfo == spirv::ImageSamplerUseInfo::SamplerUnknown ||
          samplerInfo == spirv::ImageSamplerUseInfo::NoSampler))
      return imageQuerySizeOp.emitError(
          "if Dim is 1D, 2D, 3D, or Cube, "
          "it must also have either an MS of 1 or a Sampled of 0 or 2");
    break;
  case spirv::Dim::Buffer:
  case spirv::Dim::Rect:
    break;
  default:
    return imageQuerySizeOp.emitError("the Dim operand of the image type must "
                                      "be 1D, 2D, 3D, Buffer, Cube, or Rect");
  }

  unsigned componentNumber = 0;
  switch (dim) {
  case spirv::Dim::Dim1D:
  case spirv::Dim::Buffer:
    componentNumber = 1;
    break;
  case spirv::Dim::Dim2D:
  case spirv::Dim::Cube:
  case spirv::Dim::Rect:
    componentNumber = 2;
    break;
  case spirv::Dim::Dim3D:
    componentNumber = 3;
    break;
  default:
    break;
  }

  if (imageType.getArrayedInfo() == spirv::ImageArrayedInfo::Arrayed)
    componentNumber += 1;

  unsigned resultComponentNumber = 1;
  if (auto resultVectorType = resultType.dyn_cast<VectorType>())
    resultComponentNumber = resultVectorType.getNumElements();

  if (componentNumber != resultComponentNumber)
    return imageQuerySizeOp.emitError("expected the result to have ")
           << componentNumber << " component(s), but found "
           << resultComponentNumber << " component(s)";

  return success();
}

static ParseResult parsePtrAccessChainOpImpl(StringRef opName,
                                             OpAsmParser &parser,
                                             OperationState &state) {
  OpAsmParser::OperandType ptrInfo;
  SmallVector<OpAsmParser::OperandType, 4> indicesInfo;
  Type type;
  auto loc = parser.getCurrentLocation();
  SmallVector<Type, 4> indicesTypes;

  if (parser.parseOperand(ptrInfo) ||
      parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(ptrInfo, type, state.operands))
    return failure();

  // Check that the provided indices list is not empty before parsing their
  // type list.
  if (indicesInfo.empty())
    return emitError(state.location) << opName << " expected element";

  if (parser.parseComma() || parser.parseTypeList(indicesTypes))
    return failure();

  // Check that the indices types list is not empty and that it has a one-to-one
  // mapping to the provided indices.
  if (indicesTypes.size() != indicesInfo.size())
    return emitError(state.location)
           << opName
           << " indices types' count must be equal to indices info count";

  if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands))
    return failure();

  auto resultType = getElementPtrType(
      type, llvm::makeArrayRef(state.operands).drop_front(2), state.location);
  if (!resultType)
    return failure();

  state.addTypes(resultType);
  return success();
}

template <typename Op>
static auto concatElemAndIndices(Op op) {
  SmallVector<Value> ret(op.indices().size() + 1);
  ret[0] = op.element();
  llvm::copy(op.indices(), ret.begin() + 1);
  return ret;
}

//===----------------------------------------------------------------------===//
// spv.InBoundsPtrAccessChainOp
//===----------------------------------------------------------------------===//

void spirv::InBoundsPtrAccessChainOp::build(OpBuilder &builder,
                                            OperationState &state,
                                            Value basePtr, Value element,
                                            ValueRange indices) {
  auto type = getElementPtrType(basePtr.getType(), indices, state.location);
  assert(type && "Unable to deduce return type based on basePtr and indices");
  build(builder, state, type, basePtr, element, indices);
}

static ParseResult parseInBoundsPtrAccessChainOp(OpAsmParser &parser,
                                                 OperationState &state) {
  return parsePtrAccessChainOpImpl(
      spirv::InBoundsPtrAccessChainOp::getOperationName(), parser, state);
}

static void print(spirv::InBoundsPtrAccessChainOp op, OpAsmPrinter &printer) {
  printAccessChain(op, concatElemAndIndices(op), printer);
}

static LogicalResult verify(spirv::InBoundsPtrAccessChainOp accessChainOp) {
  return verifyAccessChain(accessChainOp, accessChainOp.indices());
}

//===----------------------------------------------------------------------===//
// spv.PtrAccessChainOp
//===----------------------------------------------------------------------===//

void spirv::PtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
                                    Value basePtr, Value element,
                                    ValueRange indices) {
  auto type = getElementPtrType(basePtr.getType(), indices, state.location);
  assert(type && "Unable to deduce return type based on basePtr and indices");
  build(builder, state, type, basePtr, element, indices);
}

static ParseResult parsePtrAccessChainOp(OpAsmParser &parser,
                                         OperationState &state) {
  return parsePtrAccessChainOpImpl(spirv::PtrAccessChainOp::getOperationName(),
                                   parser, state);
}

static void print(spirv::PtrAccessChainOp op, OpAsmPrinter &printer) {
  printAccessChain(op, concatElemAndIndices(op), printer);
}

static LogicalResult verify(spirv::PtrAccessChainOp accessChainOp) {
  return verifyAccessChain(accessChainOp, accessChainOp.indices());
}

// TableGen'erated operation interfaces for querying versions, extensions, and
// capabilities.
#include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.cpp.inc"

// TablenGen'erated operation definitions.
#define GET_OP_CLASSES
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"

namespace mlir {
namespace spirv {
// TableGen'erated operation availability interface implementations.
#include "mlir/Dialect/SPIRV/IR/SPIRVOpAvailabilityImpl.inc"
} // namespace spirv
} // namespace mlir
