//===- SPIRVDialect.cpp - MLIR SPIR-V dialect -----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the SPIR-V dialect in MLIR.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/ParserUtils.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Parser.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;
using namespace mlir::spirv;

#include "mlir/Dialect/SPIRV/IR/SPIRVOpsDialect.cpp.inc"

//===----------------------------------------------------------------------===//
// InlinerInterface
//===----------------------------------------------------------------------===//

/// Returns true if some block of `region` ends in an spv.Return or
/// spv.ReturnValue terminator. Every block is expected to have a terminator.
static inline bool containsReturn(Region &region) {
  for (Block &block : region) {
    Operation *terminator = block.getTerminator();
    if (isa<spirv::ReturnOp, spirv::ReturnValueOp>(terminator))
      return true;
  }
  return false;
}

namespace {
/// This class defines the interface for inlining within the SPIR-V dialect.
struct SPIRVInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// All call operations within SPIRV can be inlined.
  bool isLegalToInline(Operation *call, Operation *callable,
                       bool wouldBeCloned) const final {
    return true;
  }

  /// Returns true if the given region 'src' can be inlined into the region
  /// 'dest' that is attached to an operation registered to the current dialect.
  /// Only regions owned by spv.func, spv.mlir.selection, and spv.mlir.loop are
  /// valid inlining destinations.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    auto *op = dest->getParentOp();
    return isa<spirv::FuncOp, spirv::SelectionOp, spirv::LoopOp>(op);
  }

  /// Returns true if the given operation 'op', that is registered to this
  /// dialect, can be inlined into the region 'dest' that is attached to an
  /// operation registered to the current dialect.
  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    // TODO: Enable inlining structured control flows with return.
    if (isa<spirv::SelectionOp, spirv::LoopOp>(op) &&
        containsReturn(op->getRegion(0)))
      return false;
    // TODO: we need to filter OpKill here to avoid inlining it to
    // a loop continue construct:
    // https://github.com/KhronosGroup/SPIRV-Headers/issues/86
    // However OpKill is fragment shader specific and we don't support it yet.
    return true;
  }

  /// Handles an inlined terminator when control must continue at `newDest`:
  /// an spv.Return is replaced by an spv.Branch to `newDest`. spv.ReturnValue
  /// is not supported on this path yet. Other terminators are left untouched.
  void handleTerminator(Operation *op, Block *newDest) const final {
    if (isa<spirv::ReturnOp>(op)) {
      OpBuilder(op).create<spirv::BranchOp>(op->getLoc(), newDest);
      op->erase();
    } else if (isa<spirv::ReturnValueOp>(op)) {
      // TODO: Support spv.ReturnValue here.
      llvm_unreachable("unimplemented spv.ReturnValue in inliner");
    }
  }

  /// Handles an inlined terminator whose results are forwarded directly: an
  /// spv.ReturnValue's operand replaces all uses of the single value in
  /// `valuesToRepl`. Other terminators are ignored.
  void handleTerminator(Operation *op,
                        ArrayRef<Value> valuesToRepl) const final {
    // Only spv.ReturnValue needs to be handled here.
    auto retValOp = dyn_cast<spirv::ReturnValueOp>(op);
    if (!retValOp)
      return;

    // Replace the values directly with the return operands.
    assert(valuesToRepl.size() == 1 &&
           "spv.ReturnValue expected to only handle one result");
    valuesToRepl.front().replaceAllUsesWith(retValOp.value());
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// SPIR-V Dialect
//===----------------------------------------------------------------------===//

/// Sets up the SPIR-V dialect. It registers the dialect's attributes, types,
/// and operations, attaches the SPIR-V inliner interface, and permits unknown
/// operations.
void SPIRVDialect::initialize() {
  registerAttributes();
  registerTypes();

  // Add SPIR-V ops. The op class list is expanded from the GET_OP_LIST section
  // of the generated SPIRVOps.cpp.inc.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
      >();

  // Attach the inlining legality/terminator hooks (SPIRVInlinerInterface).
  addInterfaces<SPIRVInlinerInterface>();

  // Allow unknown operations because SPIR-V is extensible.
  allowUnknownOperations();
}

/// Maps a decoration to its attribute name. The decoration's stringified
/// CamelCase spelling is converted to snake_case.
std::string SPIRVDialect::getAttributeName(Decoration decoration) {
  auto camelCaseName = stringifyDecoration(decoration);
  return llvm::convertToSnakeFromCamelCase(camelCaseName);
}

//===----------------------------------------------------------------------===//
// Type Parsing
//===----------------------------------------------------------------------===//

// Forward declarations. The generic parseAndVerify (which parses an enum
// keyword) and its Type/unsigned specializations are defined further below.
// They are declared here because earlier code, such as
// parseOptionalArrayStride, already calls parseAndVerify<unsigned>.
template <typename ValTy>
static Optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
                                      DialectAsmParser &parser);
template <>
Optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
                                    DialectAsmParser &parser);

template <>
Optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
                                            DialectAsmParser &parser);

/// Parses a type usable as a component of a composite SPIR-V type. Accepted
/// types are:
///   - any SPIR-V dialect type;
///   - any float type except bf16;
///   - integer types accepted by ScalarType::isValid;
///   - 1-D vectors with at most 4 elements.
/// Emits a diagnostic and returns a null type for anything else.
static Type parseAndVerifyType(SPIRVDialect const &dialect,
                               DialectAsmParser &parser) {
  llvm::SMLoc loc = parser.getCurrentLocation();
  Type parsed;
  if (parser.parseType(parsed))
    return Type();

  // SPIR-V dialect types are always acceptable.
  if (&parsed.getDialect() == &dialect)
    return parsed;

  if (parsed.isa<FloatType>()) {
    if (!parsed.isBF16())
      return parsed;
    parser.emitError(loc, "cannot use 'bf16' to compose SPIR-V types");
    return Type();
  }

  if (auto intTy = parsed.dyn_cast<IntegerType>()) {
    if (ScalarType::isValid(intTy))
      return parsed;
    parser.emitError(loc,
                     "only 1/8/16/32/64-bit integer type allowed but found ")
        << parsed;
    return Type();
  }

  if (auto vecTy = parsed.dyn_cast<VectorType>()) {
    if (vecTy.getRank() != 1) {
      parser.emitError(loc, "only 1-D vector allowed but found ") << vecTy;
      return Type();
    }
    if (vecTy.getNumElements() > 4) {
      parser.emitError(
          loc, "vector length has to be less than or equal to 4 but found ")
          << vecTy.getNumElements();
      return Type();
    }
    return parsed;
  }

  parser.emitError(loc, "cannot use ") << parsed << " to compose SPIR-V types";
  return Type();
}

/// Parses the column type of a spv.matrix. The column must be a 1-D vector
/// with 2 to 4 floating-point elements. Emits a diagnostic and returns a null
/// type on failure.
static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect,
                                     DialectAsmParser &parser) {
  llvm::SMLoc loc = parser.getCurrentLocation();
  Type parsed;
  if (parser.parseType(parsed))
    return Type();

  auto columnTy = parsed.dyn_cast<VectorType>();
  if (!columnTy) {
    parser.emitError(loc, "matrix must be composed using vector "
                          "type, got ")
        << parsed;
    return Type();
  }

  if (columnTy.getRank() != 1) {
    parser.emitError(loc, "only 1-D vector allowed but found ") << columnTy;
    return Type();
  }

  bool validLength =
      columnTy.getNumElements() >= 2 && columnTy.getNumElements() <= 4;
  if (!validLength) {
    parser.emitError(loc, "matrix columns size has to be less than or equal "
                          "to 4 and greater than or equal 2, but found ")
        << columnTy.getNumElements();
    return Type();
  }

  Type elementTy = columnTy.getElementType();
  if (!elementTy.isa<FloatType>()) {
    parser.emitError(loc, "matrix columns' elements must be of "
                          "Float type, got ")
        << elementTy;
    return Type();
  }

  return parsed;
}

/// Parses the operand of a spv.sampled_image type. Only an ImageType is
/// accepted; anything else is diagnosed and yields a null type.
static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect,
                                           DialectAsmParser &parser) {
  llvm::SMLoc loc = parser.getCurrentLocation();
  Type parsed;
  if (parser.parseType(parsed))
    return Type();

  if (parsed.isa<ImageType>())
    return parsed;

  parser.emitError(loc,
                   "sampled image must be composed using image type, got ")
      << parsed;
  return Type();
}

/// Parses the optional `, stride = N` suffix of an array type.
/// - If there is no leading comma, `stride` is set to 0 and parsing succeeds.
/// - Otherwise `N` is written to `stride` and must be non-zero; a zero stride
///   is diagnosed as an error.
static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect,
                                              DialectAsmParser &parser,
                                              unsigned &stride) {
  // Without a comma there is no stride clause at all.
  if (failed(parser.parseOptionalComma())) {
    stride = 0;
    return success();
  }

  if (parser.parseKeyword("stride") || parser.parseEqual())
    return failure();

  llvm::SMLoc strideLoc = parser.getCurrentLocation();
  Optional<unsigned> parsedStride = parseAndVerify<unsigned>(dialect, parser);
  if (!parsedStride)
    return failure();

  stride = parsedStride.getValue();
  if (stride == 0) {
    parser.emitError(strideLoc, "ArrayStride must be greater than zero");
    return failure();
  }
  return success();
}

// element-type ::= integer-type
//                | floating-point-type
//                | vector-type
//                | spirv-type
//
// array-type ::= `!spv.array` `<` integer-literal `x` element-type
//                (`,` `stride` `=` integer-literal)? `>`
//
// Parses the body of an array type (after the `array` keyword). Emits a
// diagnostic and returns a null type on any syntax or validity error.
static Type parseArrayType(SPIRVDialect const &dialect,
                           DialectAsmParser &parser) {
  if (parser.parseLess())
    return Type();

  SmallVector<int64_t, 1> countDims;
  llvm::SMLoc countLoc = parser.getCurrentLocation();
  if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
    return Type();
  if (countDims.size() != 1) {
    parser.emitError(countLoc,
                     "expected single integer for array element count");
    return Type();
  }

  // According to the SPIR-V spec:
  // "Length is the number of elements in the array. It must be at least 1."
  int64_t count = countDims[0];
  if (count <= 0) {
    parser.emitError(countLoc, "expected array length greater than 0");
    return Type();
  }
  // ArrayType::get takes the element count as `unsigned` (see
  // SPIRVTypes.h). Without this check a larger literal would be silently
  // truncated to a different length.
  if (count > static_cast<int64_t>(std::numeric_limits<unsigned>::max())) {
    parser.emitError(countLoc, "array length is too large: ") << count;
    return Type();
  }

  Type elementType = parseAndVerifyType(dialect, parser);
  if (!elementType)
    return Type();

  unsigned stride = 0;
  if (failed(parseOptionalArrayStride(dialect, parser, stride)))
    return Type();

  if (parser.parseGreater())
    return Type();
  return ArrayType::get(elementType, static_cast<unsigned>(count), stride);
}

// cooperative-matrix-type ::= `!spv.coopmatrix` `<` rows `x` columns `x`
//                             element-type `,` scope `>`
//
// Parses the body of a cooperative matrix type (after the `coopmatrix`
// keyword). This is the same layout the printer emits: `RxCxT, Scope`.
static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
                                       DialectAsmParser &parser) {
  if (parser.parseLess())
    return Type();

  // `rows x columns x` is parsed as a two-element dimension list.
  SmallVector<int64_t, 2> dims;
  llvm::SMLoc countLoc = parser.getCurrentLocation();
  if (parser.parseDimensionList(dims, /*allowDynamic=*/false))
    return Type();

  if (dims.size() != 2) {
    parser.emitError(countLoc, "expected rows and columns size");
    return Type();
  }

  // CooperativeMatrixNVType::get takes rows/columns as `unsigned`. Reject
  // sizes that would otherwise be silently truncated.
  for (int64_t dim : dims) {
    if (dim > static_cast<int64_t>(std::numeric_limits<unsigned>::max())) {
      parser.emitError(countLoc, "cooperative matrix dimension is too large: ")
          << dim;
      return Type();
    }
  }

  auto elementTy = parseAndVerifyType(dialect, parser);
  if (!elementTy)
    return Type();

  Scope scope;
  if (parser.parseComma() || parseEnumKeywordAttr(scope, parser, "scope <id>"))
    return Type();

  if (parser.parseGreater())
    return Type();
  return CooperativeMatrixNVType::get(elementTy, scope,
                                      static_cast<unsigned>(dims[0]),
                                      static_cast<unsigned>(dims[1]));
}

// TODO: Reorder methods to be utilities first and parse*Type
// methods in alphabetical order
//
// storage-class ::= `UniformConstant`
//                 | `Uniform`
//                 | `Workgroup`
//                 | <and other storage classes...>
//
// pointer-type ::= `!spv.ptr<` element-type `,` storage-class `>`
//
// Parses the body of a pointer type (after the `ptr` keyword). An unknown
// storage-class keyword is diagnosed and yields a null type.
static Type parsePointerType(SPIRVDialect const &dialect,
                             DialectAsmParser &parser) {
  if (parser.parseLess())
    return Type();

  Type pointee = parseAndVerifyType(dialect, parser);
  if (!pointee)
    return Type();

  // The diagnostic location is taken before the comma is consumed.
  llvm::SMLoc classLoc = parser.getCurrentLocation();
  StringRef classKeyword;
  if (parser.parseComma() || parser.parseKeyword(&classKeyword))
    return Type();

  auto storageClass = symbolizeStorageClass(classKeyword);
  if (!storageClass) {
    parser.emitError(classLoc, "unknown storage class: ") << classKeyword;
    return Type();
  }

  if (parser.parseGreater())
    return Type();

  return PointerType::get(pointee, *storageClass);
}

// runtime-array-type ::= `!spv.rtarray` `<` element-type
//                        (`,` `stride` `=` integer-literal)? `>`
//
// Parses the body of a runtime array type (after the `rtarray` keyword).
// A missing stride clause yields a stride of 0.
static Type parseRuntimeArrayType(SPIRVDialect const &dialect,
                                  DialectAsmParser &parser) {
  if (parser.parseLess())
    return Type();

  Type elementTy = parseAndVerifyType(dialect, parser);
  if (!elementTy)
    return Type();

  unsigned arrayStride = 0;
  if (failed(parseOptionalArrayStride(dialect, parser, arrayStride)) ||
      parser.parseGreater())
    return Type();

  return RuntimeArrayType::get(elementTy, arrayStride);
}

// matrix-type ::= `!spv.matrix` `<` integer-literal `x` column-type `>`
//
// The integer literal is the column count and must be 2, 3, or 4. The
// column-type must be a 1-D vector of 2 to 4 floating-point elements (see
// parseAndVerifyMatrixType).
static Type parseMatrixType(SPIRVDialect const &dialect,
                            DialectAsmParser &parser) {
  if (parser.parseLess())
    return Type();

  llvm::SMLoc countLoc = parser.getCurrentLocation();
  SmallVector<int64_t, 1> countDims;
  if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
    return Type();
  if (countDims.size() != 1) {
    parser.emitError(countLoc, "expected single unsigned "
                               "integer for number of columns");
    return Type();
  }

  // The SPIR-V spec only allows matrices with 2, 3, or 4 columns.
  int64_t numColumns = countDims.front();
  bool validColumnCount = numColumns >= 2 && numColumns <= 4;
  if (!validColumnCount) {
    parser.emitError(countLoc, "matrix is expected to have 2, 3, or 4 "
                               "columns");
    return Type();
  }

  Type columnTy = parseAndVerifyMatrixType(dialect, parser);
  if (!columnTy || parser.parseGreater())
    return Type();

  return MatrixType::get(columnTy, numColumns);
}

// Parses a keyword naming an enumerant of the enum type ValTy. Examples are
// the enum-valued ImageType parameters and member decorations. Non-enum
// parameters are handled by the explicit specializations below. An unknown
// keyword is diagnosed and yields llvm::None.
template <typename ValTy>
static Optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
                                      DialectAsmParser &parser) {
  llvm::SMLoc keywordLoc = parser.getCurrentLocation();
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return llvm::None;

  auto symbol = spirv::symbolizeEnum<ValTy>(keyword);
  if (!symbol)
    parser.emitError(keywordLoc, "unknown attribute: '") << keyword << "'";
  return symbol;
}

template <>
Optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
                                    DialectAsmParser &parser) {
  // TODO: Further verify that the element type can be sampled
  if (Type parsed = parseAndVerifyType(dialect, parser))
    return parsed;
  return llvm::None;
}

/// Parses an integer literal of type IntTy. Returns llvm::None if the parser
/// reports a failure. `dialect` is currently unused.
template <typename IntTy>
static Optional<IntTy> parseAndVerifyInteger(SPIRVDialect const &dialect,
                                             DialectAsmParser &parser) {
  IntTy result = std::numeric_limits<IntTy>::max();
  if (failed(parser.parseInteger(result)))
    return llvm::None;
  return result;
}

template <>
Optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
                                            DialectAsmParser &parser) {
  // Unsigned parameters (e.g. array strides) are plain integer literals.
  Optional<unsigned> parsed = parseAndVerifyInteger<unsigned>(dialect, parser);
  return parsed;
}

namespace {
// Functor that parses a comma-separated list whose elements have the types
// ParseType, Args... in that order. Each element is parsed and verified with
// parseAndVerify<T>. It is a class template rather than a function so that
// the single-element base case can be written as a partial specialization.
template <typename ParseType, typename... Args> struct ParseCommaSeparatedList {
  Optional<std::tuple<ParseType, Args...>>
  operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const {
    // The one-element case is handled by the partial specialization below, so
    // at least one more element always follows this one.
    static_assert(sizeof...(Args) > 0, "expected trailing list elements");

    Optional<ParseType> head = parseAndVerify<ParseType>(dialect, parser);
    if (!head || failed(parser.parseComma()))
      return llvm::None;

    auto tail = ParseCommaSeparatedList<Args...>{}(dialect, parser);
    if (!tail)
      return llvm::None;

    return std::tuple_cat(std::make_tuple(head.getValue()), tail.getValue());
  }
};

// Base case: the final list element, which is not followed by a comma.
template <typename ParseType> struct ParseCommaSeparatedList<ParseType> {
  Optional<std::tuple<ParseType>> operator()(SPIRVDialect const &dialect,
                                             DialectAsmParser &parser) const {
    Optional<ParseType> last = parseAndVerify<ParseType>(dialect, parser);
    if (!last)
      return llvm::None;
    return std::make_tuple(last.getValue());
  }
};
} // namespace

// dim ::= `1D` | `2D` | `3D` | `Cube` | <and other SPIR-V Dim specifiers...>
//
// depth-info ::= `NoDepth` | `IsDepth` | `DepthUnknown`
//
// arrayed-info ::= `NonArrayed` | `Arrayed`
//
// sampling-info ::= `SingleSampled` | `MultiSampled`
//
// sampler-use-info ::= `SamplerUnknown` | `NeedSampler` | `NoSampler`
//
// format ::= `Unknown` | `Rgba32f` | <and other SPIR-V Image formats...>
//
// image-type ::= `!spv.image<` element-type `,` dim `,` depth-info `,`
//                              arrayed-info `,` sampling-info `,`
//                              sampler-use-info `,` format `>`
//
// Parses the body of an image type (after the `image` keyword).
static Type parseImageType(SPIRVDialect const &dialect,
                           DialectAsmParser &parser) {
  if (parser.parseLess())
    return Type();

  using ImageParamsParser =
      ParseCommaSeparatedList<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
                              ImageSamplingInfo, ImageSamplerUseInfo,
                              ImageFormat>;
  auto params = ImageParamsParser{}(dialect, parser);
  if (!params || parser.parseGreater())
    return Type();

  return ImageType::get(params.getValue());
}

// sampled-image-type ::= `!spv.sampled_image<` image-type `>`
//
// Parses the body of a sampled image type (after the `sampled_image` keyword).
static Type parseSampledImageType(SPIRVDialect const &dialect,
                                  DialectAsmParser &parser) {
  if (parser.parseLess())
    return Type();

  Type imageTy = parseAndVerifySampledImageType(dialect, parser);
  if (!imageTy || parser.parseGreater())
    return Type();

  return SampledImageType::get(imageTy);
}

// Parses the bracketed decoration list of the struct member most recently
// appended to `memberTypes`. The opening `[` has already been consumed by the
// caller. The list is an optional offset integer followed by zero or more
// `Decoration` or `Decoration=value` entries, then the closing `]`.
// Offsets are appended to `offsetInfo`. Decorations are appended to
// `memberDecorationInfo`, tagged with the member's index.
static ParseResult parseStructMemberDecorations(
    SPIRVDialect const &dialect, DialectAsmParser &parser,
    ArrayRef<Type> memberTypes,
    SmallVectorImpl<StructType::OffsetInfo> &offsetInfo,
    SmallVectorImpl<StructType::MemberDecorationInfo> &memberDecorationInfo) {
  // The member being decorated is the last one the caller appended.
  auto memberIndex = static_cast<uint32_t>(memberTypes.size() - 1);

  // An optional leading integer is the member's offset.
  llvm::SMLoc offsetLoc = parser.getCurrentLocation();
  StructType::OffsetInfo offset = 0;
  OptionalParseResult offsetResult = parser.parseOptionalInteger(offset);
  bool hasOffset = offsetResult.hasValue();
  if (hasOffset) {
    if (failed(*offsetResult))
      return failure();

    // Every preceding member must already have an offset as well.
    if (offsetInfo.size() != memberTypes.size() - 1)
      return parser.emitError(
          offsetLoc, "offset specification must be given for all members");
    offsetInfo.push_back(offset);
  }

  // A `]` right away means there are no decorations.
  if (succeeded(parser.parseOptionalRSquare()))
    return success();

  // A comma separates an offset from the decorations that follow it.
  if (hasOffset && parser.parseComma())
    return failure();

  do {
    auto decoration = parseAndVerify<spirv::Decoration>(dialect, parser);
    if (!decoration)
      return failure();

    // A decoration may carry an integer value: `Decoration=value`.
    uint32_t hasValue = 0;
    uint32_t value = 0;
    if (succeeded(parser.parseOptionalEqual())) {
      auto parsedValue = parseAndVerifyInteger<uint32_t>(dialect, parser);
      if (!parsedValue)
        return failure();
      hasValue = 1;
      value = parsedValue.getValue();
    }

    memberDecorationInfo.emplace_back(memberIndex, hasValue,
                                      decoration.getValue(), value);
  } while (succeeded(parser.parseOptionalComma()));

  return parser.parseRSquare();
}

// struct-member-decoration ::= integer-literal? spirv-decoration*
// struct-type ::=
//             `!spv.struct<` (id `,`)?
//                          `(`
//                            (spirv-type (`[` struct-member-decoration `]`)?)*
//                          `)>`
//
// A bare `!spv.struct<id>` is a recursive reference. It is only valid while
// the struct named `id` is itself being parsed.
static Type parseStructType(SPIRVDialect const &dialect,
                            DialectAsmParser &parser) {
  // TODO: This function is quite lengthy. Break it down into smaller chunks.

  // To properly resolve recursive references while parsing recursive struct
  // types, we need to maintain a list of enclosing struct type names. This set
  // maintains the names of struct types in which the type we are about to parse
  // is nested. The stored StringRefs point into the parser's input buffer, so
  // every entry inserted here must be removed again on every exit path.
  //
  // Note: This has to be thread_local to enable multiple threads to safely
  // parse concurrently.
  thread_local SetVector<StringRef> structContext;

  // Removes `identifier` (if non-empty) from `structContext` and returns a null
  // type. Only call this after *this* invocation has inserted `identifier`;
  // otherwise it would drop an enclosing struct's entry.
  static auto removeIdentifierAndFail = [](SetVector<StringRef> &structContext,
                                           StringRef identifier) {
    if (!identifier.empty())
      structContext.remove(identifier);

    return Type();
  };

  if (parser.parseLess())
    return Type();

  StringRef identifier;

  // Check if this is an identified struct type.
  if (succeeded(parser.parseOptionalKeyword(&identifier))) {
    // Check if this is a possible recursive reference.
    if (succeeded(parser.parseOptionalGreater())) {
      if (structContext.count(identifier) == 0) {
        parser.emitError(
            parser.getNameLoc(),
            "recursive struct reference not nested in struct definition");

        return Type();
      }

      return StructType::getIdentified(dialect.getContext(), identifier);
    }

    if (failed(parser.parseComma()))
      return Type();

    if (structContext.count(identifier) != 0) {
      parser.emitError(parser.getNameLoc(),
                       "identifier already used for an enclosing struct");
      // The entry belongs to the enclosing struct's still-running parse, so it
      // must not be removed here.
      return Type();
    }

    structContext.insert(identifier);
  }

  // From here on, every exit path must remove `identifier` from the context.
  if (failed(parser.parseLParen()))
    return removeIdentifierAndFail(structContext, identifier);

  // Empty struct: `()` must be followed directly by `>`.
  if (succeeded(parser.parseOptionalRParen())) {
    if (failed(parser.parseGreater()))
      return removeIdentifierAndFail(structContext, identifier);

    if (!identifier.empty())
      structContext.remove(identifier);

    return StructType::getEmpty(dialect.getContext(), identifier);
  }

  StructType idStructTy;

  if (!identifier.empty())
    idStructTy = StructType::getIdentified(dialect.getContext(), identifier);

  SmallVector<Type, 4> memberTypes;
  SmallVector<StructType::OffsetInfo, 4> offsetInfo;
  SmallVector<StructType::MemberDecorationInfo, 4> memberDecorationInfo;

  do {
    Type memberType;
    if (parser.parseType(memberType))
      return removeIdentifierAndFail(structContext, identifier);
    memberTypes.push_back(memberType);

    if (succeeded(parser.parseOptionalLSquare()))
      if (parseStructMemberDecorations(dialect, parser, memberTypes, offsetInfo,
                                       memberDecorationInfo))
        return removeIdentifierAndFail(structContext, identifier);
  } while (succeeded(parser.parseOptionalComma()));

  // Offsets are all-or-nothing across members.
  if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) {
    parser.emitError(parser.getNameLoc(),
                     "offset specification must be given for all members");
    return removeIdentifierAndFail(structContext, identifier);
  }

  if (failed(parser.parseRParen()) || failed(parser.parseGreater()))
    return removeIdentifierAndFail(structContext, identifier);

  if (!identifier.empty()) {
    // trySetBody can fail for an identified struct, presumably when the same
    // identifier already carries a different body. Report it and keep
    // `structContext` clean instead of leaking the entry.
    if (failed(idStructTy.trySetBody(memberTypes, offsetInfo,
                                     memberDecorationInfo))) {
      parser.emitError(parser.getNameLoc(),
                       "failed to set the body of identified struct: ")
          << identifier;
      return removeIdentifierAndFail(structContext, identifier);
    }

    structContext.remove(identifier);
    return idStructTy;
  }

  return StructType::get(memberTypes, offsetInfo, memberDecorationInfo);
}

// spirv-type ::= array-type
//              | cooperative-matrix-type
//              | element-type
//              | image-type
//              | matrix-type
//              | pointer-type
//              | runtime-array-type
//              | sampled-image-type
//              | struct-type
//
// Reads the type keyword and dispatches to the matching parse*Type helper.
// An unrecognized keyword is diagnosed and yields a null type.
Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();

  using TypeParserFn = Type (*)(SPIRVDialect const &, DialectAsmParser &);
  TypeParserFn parseFn = llvm::StringSwitch<TypeParserFn>(keyword)
                             .Case("array", parseArrayType)
                             .Case("coopmatrix", parseCooperativeMatrixType)
                             .Case("image", parseImageType)
                             .Case("ptr", parsePointerType)
                             .Case("rtarray", parseRuntimeArrayType)
                             .Case("sampled_image", parseSampledImageType)
                             .Case("struct", parseStructType)
                             .Case("matrix", parseMatrixType)
                             .Default(nullptr);
  if (parseFn)
    return parseFn(*this, parser);

  parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword;
  return Type();
}

//===----------------------------------------------------------------------===//
// Type Printing
//===----------------------------------------------------------------------===//

/// Prints `array<N x T>`. A `, stride=S` suffix is added when the array
/// stride is non-zero.
static void print(ArrayType type, DialectAsmPrinter &os) {
  os << "array<" << type.getNumElements() << " x " << type.getElementType();
  unsigned stride = type.getArrayStride();
  if (stride != 0)
    os << ", stride=" << stride;
  os << ">";
}

/// Prints `rtarray<T>`. A `, stride=S` suffix is added when the array stride
/// is non-zero.
static void print(RuntimeArrayType type, DialectAsmPrinter &os) {
  os << "rtarray<" << type.getElementType();
  unsigned stride = type.getArrayStride();
  if (stride != 0)
    os << ", stride=" << stride;
  os << ">";
}

/// Prints `ptr<PointeeType, StorageClass>`.
static void print(PointerType type, DialectAsmPrinter &os) {
  os << "ptr<" << type.getPointeeType() << ", ";
  os << stringifyStorageClass(type.getStorageClass());
  os << ">";
}

/// Prints `image<T, Dim, Depth, Arrayed, Sampling, SamplerUse, Format>`.
static void print(ImageType type, DialectAsmPrinter &os) {
  os << "image<" << type.getElementType();
  os << ", " << stringifyDim(type.getDim());
  os << ", " << stringifyImageDepthInfo(type.getDepthInfo());
  os << ", " << stringifyImageArrayedInfo(type.getArrayedInfo());
  os << ", " << stringifyImageSamplingInfo(type.getSamplingInfo());
  os << ", " << stringifyImageSamplerUseInfo(type.getSamplerUseInfo());
  os << ", " << stringifyImageFormat(type.getImageFormat());
  os << ">";
}

/// Prints `sampled_image<ImageType>`.
static void print(SampledImageType type, DialectAsmPrinter &os) {
  auto imageTy = type.getImageType();
  os << "sampled_image<" << imageTy << ">";
}

/// Prints `struct<(members)>`, or `struct<id, (members)>` for identified
/// structs. A nested reference to an identified struct that is already being
/// printed on this thread is emitted as just `struct<id>`, which breaks the
/// recursion. Each member may be followed by `[offset, Decoration=value, ...]`.
static void print(StructType type, DialectAsmPrinter &os) {
  // Identifiers of structs whose bodies are currently being printed.
  thread_local SetVector<StringRef> structContext;

  os << "struct<";

  bool identified = type.isIdentified();
  if (identified) {
    StringRef id = type.getIdentifier();
    os << id;

    // Recursive reference: print only the identifier.
    if (structContext.count(id)) {
      os << ">";
      return;
    }

    os << ", ";
    structContext.insert(id);
  }

  os << "(";

  auto printMember = [&](unsigned index) {
    os << type.getElementType(index);

    SmallVector<spirv::StructType::MemberDecorationInfo, 0> decorations;
    type.getMemberDecorations(index, decorations);
    bool hasOffset = type.hasOffset();
    if (!hasOffset && decorations.empty())
      return;

    os << " [";
    if (hasOffset) {
      os << type.getMemberOffset(index);
      if (!decorations.empty())
        os << ", ";
    }
    llvm::interleaveComma(
        decorations, os,
        [&os](spirv::StructType::MemberDecorationInfo decoration) {
          os << stringifyDecoration(decoration.decoration);
          if (decoration.hasValue)
            os << "=" << decoration.decorationValue;
        });
    os << "]";
  };

  llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os,
                        printMember);
  os << ")>";

  if (identified)
    structContext.remove(type.getIdentifier());
}

/// Prints `coopmatrix<RowsxColumnsxT, Scope>`.
static void print(CooperativeMatrixNVType type, DialectAsmPrinter &os) {
  os << "coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x"
     << type.getElementType() << ", " << stringifyScope(type.getScope())
     << ">";
}

/// Prints `matrix<N x ColumnType>`.
static void print(MatrixType type, DialectAsmPrinter &os) {
  os << "matrix<" << type.getNumColumns() << " x " << type.getColumnType()
     << ">";
}

/// Dispatches to the `print` overload for the concrete SPIR-V type. Any other
/// type is a programming error.
void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
  TypeSwitch<Type>(type)
      .Case<ArrayType, CooperativeMatrixNVType, PointerType, RuntimeArrayType,
            ImageType, SampledImageType, StructType, MatrixType>(
          [&os](auto concreteType) { print(concreteType, os); })
      .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
}

//===----------------------------------------------------------------------===//
// Attribute Parsing
//===----------------------------------------------------------------------===//

/// Parses a `[ ... ]` list of comma-separated keywords; the list may be empty.
/// `processKeyword` is called with each keyword and its location. Fails on a
/// syntax error or as soon as `processKeyword` fails.
static ParseResult parseKeywordList(
    DialectAsmParser &parser,
    function_ref<LogicalResult(llvm::SMLoc, StringRef)> processKeyword) {
  if (parser.parseLSquare())
    return failure();

  // `[]` is an empty list.
  if (succeeded(parser.parseOptionalRSquare()))
    return success();

  do {
    llvm::SMLoc keywordLoc = parser.getCurrentLocation();
    StringRef keyword;
    if (parser.parseKeyword(&keyword))
      return failure();
    if (failed(processKeyword(keywordLoc, keyword)))
      return failure();
  } while (succeeded(parser.parseOptionalComma()));

  return parser.parseRSquare();
}

/// Parses a spirv::InterfaceVarABIAttr. Accepted forms are
/// `<(descriptor-set, binding)>` and
/// `<(descriptor-set, binding), StorageClass>`.
/// Returns a null attribute after emitting a diagnostic on failure.
static Attribute parseInterfaceVarABIAttr(DialectAsmParser &parser) {
  if (parser.parseLess())
    return {};

  Builder &builder = parser.getBuilder();

  if (parser.parseLParen())
    return {};

  // Parses a mandatory unsigned integer into an i32 attribute. If the integer
  // is absent or malformed, reports `missingMsg` at its location and returns a
  // null attribute.
  auto parseRequiredI32 = [&](const char *missingMsg) -> IntegerAttr {
    auto loc = parser.getCurrentLocation();
    uint32_t value = 0;
    OptionalParseResult result = parser.parseOptionalInteger(value);
    if (!result.hasValue() || failed(*result)) {
      parser.emitError(loc, missingMsg);
      return {};
    }
    return builder.getI32IntegerAttr(value);
  };

  IntegerAttr descriptorSetAttr = parseRequiredI32("missing descriptor set");
  if (!descriptorSetAttr || parser.parseComma())
    return {};

  IntegerAttr bindingAttr = parseRequiredI32("missing binding");
  if (!bindingAttr || parser.parseRParen())
    return {};

  // The storage class is optional and introduced by a comma.
  IntegerAttr storageClassAttr;
  if (succeeded(parser.parseOptionalComma())) {
    auto loc = parser.getCurrentLocation();
    StringRef storageClass;
    if (parser.parseKeyword(&storageClass))
      return {};

    auto storageClassSymbol = spirv::symbolizeStorageClass(storageClass);
    if (!storageClassSymbol) {
      parser.emitError(loc, "unknown storage class: ") << storageClass;
      return {};
    }
    storageClassAttr = builder.getI32IntegerAttr(
        static_cast<uint32_t>(*storageClassSymbol));
  }

  if (parser.parseGreater())
    return {};

  return spirv::InterfaceVarABIAttr::get(descriptorSetAttr, bindingAttr,
                                         storageClassAttr);
}

/// Parses a spirv::VerCapExtAttr with the syntax
///   `<` version-keyword `,` capability-list `,` extension-list `>`
/// where both lists are parsed by parseKeywordList. Unknown versions,
/// capabilities, or extensions are diagnosed and yield a null attribute.
static Attribute parseVerCapExtAttr(DialectAsmParser &parser) {
  if (parser.parseLess())
    return {};

  Builder &builder = parser.getBuilder();

  IntegerAttr versionAttr;
  {
    auto versionLoc = parser.getCurrentLocation();
    StringRef versionStr;
    if (parser.parseKeyword(&versionStr) || parser.parseComma())
      return {};

    auto versionSymbol = spirv::symbolizeVersion(versionStr);
    if (!versionSymbol) {
      parser.emitError(versionLoc, "unknown version: ") << versionStr;
      return {};
    }
    versionAttr =
        builder.getI32IntegerAttr(static_cast<uint32_t>(*versionSymbol));
  }

  ArrayAttr capabilitiesAttr;
  {
    SmallVector<Attribute, 4> capabilities;
    // Records the rejected keyword (and where it was) so it can be reported
    // after parseKeywordList returns.
    llvm::SMLoc badLoc;
    StringRef badKeyword;

    auto onCapability = [&](llvm::SMLoc loc, StringRef capability) {
      auto capSymbol = spirv::symbolizeCapability(capability);
      if (!capSymbol) {
        badLoc = loc;
        badKeyword = capability;
        return failure();
      }
      capabilities.push_back(
          builder.getI32IntegerAttr(static_cast<uint32_t>(*capSymbol)));
      return success();
    };

    if (parseKeywordList(parser, onCapability) || parser.parseComma()) {
      // An empty keyword means the failure was syntactic and the parser has
      // already reported it.
      if (!badKeyword.empty())
        parser.emitError(badLoc, "unknown capability: ") << badKeyword;
      return {};
    }

    capabilitiesAttr = builder.getArrayAttr(capabilities);
  }

  ArrayAttr extensionsAttr;
  {
    SmallVector<Attribute, 1> extensions;
    llvm::SMLoc badLoc;
    StringRef badKeyword;

    // Extensions are validated against the enum but stored by name.
    auto onExtension = [&](llvm::SMLoc loc, StringRef extension) {
      if (!spirv::symbolizeExtension(extension)) {
        badLoc = loc;
        badKeyword = extension;
        return failure();
      }
      extensions.push_back(builder.getStringAttr(extension));
      return success();
    };

    if (parseKeywordList(parser, onExtension)) {
      if (!badKeyword.empty())
        parser.emitError(badLoc, "unknown extension: ") << badKeyword;
      return {};
    }

    extensionsAttr = builder.getArrayAttr(extensions);
  }

  if (parser.parseGreater())
    return {};

  return spirv::VerCapExtAttr::get(versionAttr, capabilitiesAttr,
                                   extensionsAttr);
}

/// Parses a spirv::TargetEnvAttr with the syntax
///   `<` vce-attr `,` (vendor (`:` device-type (`:` device-id)?)? `,`)?
///   limits-dictionary `>`
/// Components of the optional vendor/device triple that are not spelled out
/// keep their "unknown" defaults. On any error a diagnostic is emitted and a
/// null attribute is returned.
static Attribute parseTargetEnvAttr(DialectAsmParser &parser) {
  if (parser.parseLess())
    return {};

  spirv::VerCapExtAttr tripleAttr;
  if (parser.parseAttribute(tripleAttr) || parser.parseComma())
    return {};

  // Parse [vendor[:device-type[:device-id]]]
  Vendor vendorID = Vendor::Unknown;
  DeviceType deviceType = DeviceType::Unknown;
  uint32_t deviceID = spirv::TargetEnvAttr::kUnknownDeviceID;
  {
    auto loc = parser.getCurrentLocation();
    StringRef vendorStr;
    if (succeeded(parser.parseOptionalKeyword(&vendorStr))) {
      auto vendorSymbol = spirv::symbolizeVendor(vendorStr);
      if (!vendorSymbol) {
        // Bail out like every other error path in this parser, so that a
        // diagnosed input never produces a (silently defaulted) attribute.
        parser.emitError(loc, "unknown vendor: ") << vendorStr;
        return {};
      }
      vendorID = *vendorSymbol;

      if (succeeded(parser.parseOptionalColon())) {
        loc = parser.getCurrentLocation();
        StringRef deviceTypeStr;
        if (parser.parseKeyword(&deviceTypeStr))
          return {};

        auto deviceTypeSymbol = spirv::symbolizeDeviceType(deviceTypeStr);
        if (!deviceTypeSymbol) {
          parser.emitError(loc, "unknown device type: ") << deviceTypeStr;
          return {};
        }
        deviceType = *deviceTypeSymbol;

        // The device ID is only accepted after a device type.
        if (succeeded(parser.parseOptionalColon()) &&
            parser.parseInteger(deviceID))
          return {};
      }
      // The vendor triple, when present, is followed by the limits.
      if (parser.parseComma())
        return {};
    }
  }

  DictionaryAttr limitsAttr;
  {
    auto loc = parser.getCurrentLocation();
    if (parser.parseAttribute(limitsAttr))
      return {};

    if (!limitsAttr.isa<spirv::ResourceLimitsAttr>()) {
      parser.emitError(
          loc,
          "limits must be a dictionary attribute containing two 32-bit integer "
          "attributes 'max_compute_workgroup_invocations' and "
          "'max_compute_workgroup_size'");
      return {};
    }
  }

  if (parser.parseGreater())
    return {};

  return spirv::TargetEnvAttr::get(tripleAttr, vendorID, deviceType, deviceID,
                                   limitsAttr);
}

/// Entry point for parsing `#spv.<kind><...>` attributes. Reads the kind
/// keyword and forwards to the matching kind-specific parser; unknown kinds,
/// and any explicit type, are rejected with a diagnostic.
Attribute SPIRVDialect::parseAttribute(DialectAsmParser &parser,
                                       Type type) const {
  // These attributes carry no type of their own, so an explicit type is an
  // error.
  if (type) {
    parser.emitError(parser.getNameLoc(), "unexpected type");
    return {};
  }

  StringRef attrKind;
  if (parser.parseKeyword(&attrKind))
    return {};

  using KindParserFn = Attribute (*)(DialectAsmParser &);
  KindParserFn kindParser = nullptr;
  if (attrKind == spirv::TargetEnvAttr::getKindName())
    kindParser = parseTargetEnvAttr;
  else if (attrKind == spirv::VerCapExtAttr::getKindName())
    kindParser = parseVerCapExtAttr;
  else if (attrKind == spirv::InterfaceVarABIAttr::getKindName())
    kindParser = parseInterfaceVarABIAttr;

  if (kindParser)
    return kindParser(parser);

  parser.emitError(parser.getNameLoc(), "unknown SPIR-V attribute kind: ")
      << attrKind;
  return {};
}

//===----------------------------------------------------------------------===//
// Attribute Printing
//===----------------------------------------------------------------------===//

/// Prints a spirv::VerCapExtAttr as
///   <kind-name><version, [capability, ...], [extension, ...]>
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer) {
  auto &stream = printer.getStream();

  stream << spirv::VerCapExtAttr::getKindName() << "<"
         << spirv::stringifyVersion(triple.getVersion()) << ", [";
  llvm::interleaveComma(triple.getCapabilities(), stream,
                        [&stream](spirv::Capability capability) {
                          stream << spirv::stringifyCapability(capability);
                        });

  stream << "], [";
  // Extensions are stored as string attributes; print their raw names.
  llvm::interleaveComma(triple.getExtensionsAttr(), stream,
                        [&stream](Attribute extension) {
                          stream << extension.cast<StringAttr>().getValue();
                        });
  stream << "]>";
}

/// Prints a spirv::TargetEnvAttr as
///   <kind-name><#spv.<vce>[, vendor[:device-type[:device-id]]], limits>
/// Each vendor/device component is printed only if it and every component
/// before it are known. Once one is Unknown, the rest are omitted.
static void print(spirv::TargetEnvAttr targetEnv, DialectAsmPrinter &printer) {
  printer << spirv::TargetEnvAttr::getKindName() << "<#spv.";
  print(targetEnv.getTripleAttr(), printer);

  const spirv::Vendor vendor = targetEnv.getVendorID();
  const spirv::DeviceType devType = targetEnv.getDeviceType();
  const uint32_t devID = targetEnv.getDeviceID();

  const bool hasVendor = vendor != spirv::Vendor::Unknown;
  if (hasVendor) {
    printer << ", " << spirv::stringifyVendor(vendor);
    const bool hasDeviceType = devType != spirv::DeviceType::Unknown;
    if (hasDeviceType) {
      printer << ":" << spirv::stringifyDeviceType(devType);
      if (devID != spirv::TargetEnvAttr::kUnknownDeviceID)
        printer << ":" << devID;
    }
  }

  printer << ", " << targetEnv.getResourceLimits() << ">";
}

/// Prints a spirv::InterfaceVarABIAttr as
///   <kind-name><(descriptor-set, binding)[, storage-class]>
/// The storage class is printed only when the attribute has one.
static void print(spirv::InterfaceVarABIAttr interfaceVarABIAttr,
                  DialectAsmPrinter &printer) {
  printer << spirv::InterfaceVarABIAttr::getKindName() << "<(";
  printer << interfaceVarABIAttr.getDescriptorSet() << ", ";
  printer << interfaceVarABIAttr.getBinding() << ")";

  if (auto storageClass = interfaceVarABIAttr.getStorageClass())
    printer << ", " << spirv::stringifyStorageClass(*storageClass);

  printer << ">";
}

/// Prints any SPIR-V dialect attribute by dispatching to the kind-specific
/// `print` overload. Any other attribute kind is a programming error.
void SPIRVDialect::printAttribute(Attribute attr,
                                  DialectAsmPrinter &printer) const {
  llvm::TypeSwitch<Attribute>(attr)
      .Case<TargetEnvAttr, VerCapExtAttr, InterfaceVarABIAttr>(
          [&](auto concreteAttr) { print(concreteAttr, printer); })
      .Default([](Attribute) {
        llvm_unreachable("unhandled SPIR-V attribute kind");
      });
}

//===----------------------------------------------------------------------===//
// Constant
//===----------------------------------------------------------------------===//

/// Materializes `value` as a spv.Constant of `type` at `loc`. Returns null
/// when spirv::ConstantOp cannot be built for `type`.
Operation *SPIRVDialect::materializeConstant(OpBuilder &builder,
                                             Attribute value, Type type,
                                             Location loc) {
  if (spirv::ConstantOp::isBuildableWith(type))
    return builder.create<spirv::ConstantOp>(loc, type, value);
  return nullptr;
}

//===----------------------------------------------------------------------===//
// Shader Interface ABI
//===----------------------------------------------------------------------===//

/// Verifies a SPIR-V dialect attribute attached to `op`. Only the entry point
/// ABI attribute and the target environment attribute are accepted, and each
/// must hold the expected attribute kind.
LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op,
                                                     NamedAttribute attribute) {
  StringRef attrName = attribute.getName().strref();
  Attribute attrValue = attribute.getValue();

  // TODO: figure out a way to generate the description from the
  // StructAttr definition.
  if (attrName == spirv::getEntryPointABIAttrName()) {
    if (attrValue.isa<spirv::EntryPointABIAttr>())
      return success();
    return op->emitError("'")
           << attrName
           << "' attribute must be a dictionary attribute containing one "
              "32-bit integer elements attribute: 'local_size'";
  }

  if (attrName == spirv::getTargetEnvAttrName()) {
    if (attrValue.isa<spirv::TargetEnvAttr>())
      return success();
    return op->emitError("'") << attrName
                              << "' must be a spirv::TargetEnvAttr";
  }

  return op->emitError("found unsupported '")
         << attrName << "' attribute on operation";
}

/// Checks that `attribute`, attached to a region value of type `valueType`,
/// is a valid SPIR-V interface variable ABI attribute. Diagnostics are
/// reported at `loc`.
static LogicalResult verifyRegionAttribute(Location loc, Type valueType,
                                           NamedAttribute attribute) {
  StringRef attrName = attribute.getName().strref();

  // The interface variable ABI attribute is the only one allowed here.
  if (attrName != spirv::getInterfaceVarABIAttrName())
    return emitError(loc, "found unsupported '")
           << attrName << "' attribute on region argument";

  auto abiAttr = attribute.getValue().dyn_cast<spirv::InterfaceVarABIAttr>();
  if (!abiAttr)
    return emitError(loc, "'")
           << attrName << "' must be a spirv::InterfaceVarABIAttr";

  // A storage class may be given only for integer, index, or float values.
  const bool specifiesStorageClass =
      static_cast<bool>(abiAttr.getStorageClass());
  if (specifiesStorageClass && !valueType.isIntOrIndexOrFloat())
    return emitError(loc, "'") << attrName
                               << "' attribute cannot specify storage class "
                                  "when attaching to a non-scalar value";

  return success();
}

/// Verifies a SPIR-V attribute attached to argument `argIndex` of region
/// `regionIndex` of `op`, using the argument's type. Errors are reported at
/// `op`'s location.
LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
                                                     unsigned regionIndex,
                                                     unsigned argIndex,
                                                     NamedAttribute attribute) {
  Region &region = op->getRegion(regionIndex);
  Type argType = region.getArgument(argIndex).getType();
  return verifyRegionAttribute(op->getLoc(), argType, attribute);
}

/// SPIR-V attributes are never valid on region results, so this always
/// reports an error on `op`, whichever attribute is given.
LogicalResult SPIRVDialect::verifyRegionResultAttribute(
    Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
    NamedAttribute attribute) {
  return op->emitError()
         << "cannot attach SPIR-V attributes to region result";
}
