//===- AttributeParser.cpp - MLIR Attribute Parser Implementation ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the parser for the MLIR Attributes.
//
//===----------------------------------------------------------------------===//

#include "Parser.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/Parser/AsmParserState.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Endian.h"

using namespace mlir;
using namespace mlir::detail;

/// Parse an arbitrary attribute.
///
///  attribute-value ::= `unit`
///                    | bool-literal
///                    | integer-literal (`:` (index-type | integer-type))?
///                    | float-literal (`:` float-type)?
///                    | string-literal (`:` type)?
///                    | type
///                    | `[` (attribute-value (`,` attribute-value)*)? `]`
///                    | `{` (attribute-entry (`,` attribute-entry)*)? `}`
///                    | symbol-ref-id (`::` symbol-ref-id)*
///                    | `dense` `<` attribute-value `>` `:`
///                      (tensor-type | vector-type)
///                    | `sparse` `<` attribute-value `,` attribute-value `>`
///                      `:` (tensor-type | vector-type)
///                    | `opaque` `<` dialect-namespace  `,` hex-string-literal
///                      `>` `:` (tensor-type | vector-type)
///                    | extended-attribute
///
/// `type`, when non-null, is forwarded to the sub-parsers that accept an
/// explicit type (numeric, string, dense/sparse/opaque, extended attributes).
/// Returns a null attribute after emitting a diagnostic on failure.
Attribute Parser::parseAttribute(Type type) {
  switch (getToken().getKind()) {
  // Parse an AffineMap or IntegerSet attribute.
  case Token::kw_affine_map: {
    consumeToken(Token::kw_affine_map);

    AffineMap map;
    if (parseToken(Token::less, "expected '<' in affine map") ||
        parseAffineMapReference(map) ||
        parseToken(Token::greater, "expected '>' in affine map"))
      return Attribute();
    return AffineMapAttr::get(map);
  }
  case Token::kw_affine_set: {
    consumeToken(Token::kw_affine_set);

    IntegerSet set;
    if (parseToken(Token::less, "expected '<' in integer set") ||
        parseIntegerSetReference(set) ||
        parseToken(Token::greater, "expected '>' in integer set"))
      return Attribute();
    return IntegerSetAttr::get(set);
  }

  // Parse an array attribute.
  case Token::l_square: {
    SmallVector<Attribute, 4> elements;
    auto parseElt = [&]() -> ParseResult {
      elements.push_back(parseAttribute());
      return elements.back() ? success() : failure();
    };

    if (parseCommaSeparatedList(Delimiter::Square, parseElt))
      return nullptr;
    return builder.getArrayAttr(elements);
  }

  // Parse a boolean attribute.
  case Token::kw_false:
    consumeToken(Token::kw_false);
    return builder.getBoolAttr(false);
  case Token::kw_true:
    consumeToken(Token::kw_true);
    return builder.getBoolAttr(true);

  // Parse a dense elements attribute.
  case Token::kw_dense:
    return parseDenseElementsAttr(type);

  // Parse a dictionary attribute.
  case Token::l_brace: {
    NamedAttrList elements;
    if (parseAttributeDict(elements))
      return nullptr;
    return elements.getDictionary(getContext());
  }

  // Parse an extended attribute, i.e. alias or dialect attribute.
  case Token::hash_identifier:
    return parseExtendedAttr(type);

  // Parse floating point and integer attributes.
  case Token::floatliteral:
    return parseFloatAttr(type, /*isNegative=*/false);
  case Token::integer:
    return parseDecOrHexAttr(type, /*isNegative=*/false);
  case Token::minus: {
    consumeToken(Token::minus);
    if (getToken().is(Token::integer))
      return parseDecOrHexAttr(type, /*isNegative=*/true);
    if (getToken().is(Token::floatliteral))
      return parseFloatAttr(type, /*isNegative=*/true);

    return (emitError("expected constant integer or floating point value"),
            nullptr);
  }

  // Parse a location attribute.
  case Token::kw_loc: {
    consumeToken(Token::kw_loc);

    LocationAttr locAttr;
    if (parseToken(Token::l_paren, "expected '(' in inline location") ||
        parseLocationInstance(locAttr) ||
        parseToken(Token::r_paren, "expected ')' in inline location"))
      return Attribute();
    return locAttr;
  }

  // Parse an opaque elements attribute.
  case Token::kw_opaque:
    return parseOpaqueElementsAttr(type);

  // Parse a sparse elements attribute.
  case Token::kw_sparse:
    return parseSparseElementsAttr(type);

  // Parse a string attribute.
  case Token::string: {
    auto val = getToken().getStringValue();
    consumeToken(Token::string);
    // Parse the optional trailing colon type if one wasn't explicitly provided.
    if (!type && consumeIf(Token::colon) && !(type = parseType()))
      return Attribute();

    return type ? StringAttr::get(val, type)
                : StringAttr::get(getContext(), val);
  }

  // Parse a symbol reference attribute.
  case Token::at_identifier: {
    // When populating the parser state, this is a list of locations for all of
    // the nested references.
    SmallVector<llvm::SMRange> referenceLocations;
    if (state.asmState)
      referenceLocations.push_back(getToken().getLocRange());

    // Parse the top-level reference.
    std::string rootName = getToken().getSymbolReference();
    consumeToken(Token::at_identifier);

    // Parse any nested references.
    std::vector<FlatSymbolRefAttr> nestedRefs;
    while (getToken().is(Token::colon)) {
      // Check for the '::' prefix.
      const char *curPointer = getToken().getLoc().getPointer();
      consumeToken(Token::colon);
      if (!consumeIf(Token::colon)) {
        // A lone ':' is not a nested-reference separator: rewind the lexer to
        // it so that ':' becomes the current token again for the caller.
        state.lex.resetPointer(curPointer);
        consumeToken();
        break;
      }
      // Parse the reference itself.
      auto curLoc = getToken().getLoc();
      if (getToken().isNot(Token::at_identifier)) {
        emitError(curLoc, "expected nested symbol reference identifier");
        return Attribute();
      }

      // If we are populating the assembly state, add the location for this
      // reference.
      if (state.asmState)
        referenceLocations.push_back(getToken().getLocRange());

      std::string nestedName = getToken().getSymbolReference();
      consumeToken(Token::at_identifier);
      nestedRefs.push_back(SymbolRefAttr::get(getContext(), nestedName));
    }
    SymbolRefAttr symbolRefAttr =
        SymbolRefAttr::get(getContext(), rootName, nestedRefs);

    // If we are populating the assembly state, record this symbol reference.
    if (state.asmState)
      state.asmState->addUses(symbolRefAttr, referenceLocations);
    return symbolRefAttr;
  }

  // Parse a 'unit' attribute.
  case Token::kw_unit:
    consumeToken(Token::kw_unit);
    return builder.getUnitAttr();

  default: {
    // Parse a type attribute. The parsed type is deliberately named apart from
    // the `type` parameter, which does not apply to type attributes.
    if (Type parsedType = parseType())
      return TypeAttr::get(parsedType);
    return nullptr;
  }
  }
}

/// Parse an optional attribute with the provided type.
///
/// If the current token can begin one of the non-type attribute forms handled
/// by parseAttribute, the attribute is parsed (and must succeed). Otherwise an
/// optional type is attempted and, if present, wrapped in a TypeAttr. The
/// returned OptionalParseResult has no value when nothing was parsed.
OptionalParseResult Parser::parseOptionalAttribute(Attribute &attribute,
                                                   Type type) {
  switch (getToken().getKind()) {
  case Token::at_identifier:
  case Token::floatliteral:
  case Token::integer:
  case Token::hash_identifier:
  case Token::kw_affine_map:
  case Token::kw_affine_set:
  case Token::kw_dense:
  case Token::kw_false:
  case Token::kw_loc:
  case Token::kw_opaque:
  case Token::kw_sparse:
  case Token::kw_true:
  case Token::kw_unit:
  case Token::l_brace:
  case Token::l_square:
  case Token::minus:
  case Token::string:
    attribute = parseAttribute(type);
    return success(attribute != nullptr);

  default: {
    // Parse an optional type attribute. Use a distinct name so the `type`
    // parameter is not shadowed.
    Type parsedType;
    OptionalParseResult result = parseOptionalType(parsedType);
    if (result.hasValue() && succeeded(*result))
      attribute = TypeAttr::get(parsedType);
    return result;
  }
  }
}
/// Parse an optional array attribute. Delegates to
/// parseOptionalAttributeWithToken, keyed on the `[` token that begins an
/// array attribute.
OptionalParseResult Parser::parseOptionalAttribute(ArrayAttr &attribute,
                                                   Type type) {
  Token::Kind leadingToken = Token::l_square;
  return parseOptionalAttributeWithToken(leadingToken, attribute, type);
}
/// Parse an optional string attribute. Delegates to
/// parseOptionalAttributeWithToken, keyed on the string-literal token that
/// begins a string attribute.
OptionalParseResult Parser::parseOptionalAttribute(StringAttr &attribute,
                                                   Type type) {
  Token::Kind leadingToken = Token::string;
  return parseOptionalAttributeWithToken(leadingToken, attribute, type);
}

/// Parse a brace-delimited attribute dictionary, appending each entry to
/// `attributes`.
///
///   attribute-dict ::= `{` `}`
///                    | `{` attribute-entry (`,` attribute-entry)* `}`
///   attribute-entry ::= (bare-id | string-literal) `=` attribute-value
///
/// An entry written without `=` is recorded as a unit attribute. A key that
/// appears twice is diagnosed as an error.
ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
  llvm::SmallDenseSet<StringAttr> keysSoFar;

  auto parseEntry = [&]() -> ParseResult {
    // Entry names may be string literals, bare identifiers, integer type
    // spellings, or keywords.
    StringAttr key;
    if (getToken().is(Token::string))
      key = builder.getStringAttr(getToken().getStringValue());
    else if (getToken().isAny(Token::bare_identifier, Token::inttype) ||
             getToken().isKeyword())
      key = builder.getStringAttr(getTokenSpelling());
    else
      return emitError("expected attribute name");

    if (!keysSoFar.insert(key).second)
      return emitError("duplicate key '")
             << key.getValue() << "' in dictionary attribute";
    consumeToken();

    // A dotted name may carry a dialect namespace prefix; make sure that
    // dialect is loaded before the value is parsed.
    auto nameParts = key.strref().split('.');
    if (!nameParts.second.empty())
      getContext()->getOrLoadDialect(nameParts.first);

    // Without `=`, the entry is shorthand for a unit attribute.
    if (!consumeIf(Token::equal)) {
      attributes.push_back({key, builder.getUnitAttr()});
      return success();
    }

    Attribute value = parseAttribute();
    if (!value)
      return failure();
    attributes.push_back({key, value});
    return success();
  };

  return parseCommaSeparatedList(Delimiter::Braces, parseEntry,
                                 " in attribute dictionary");
}

/// Parse a floating-point literal token as a FloatAttr.
///
/// When `type` is null, an optional trailing `: type` is parsed, and f64 is
/// used when it is absent. The resulting type must be a FloatType.
/// `isNegative` indicates that a leading `-` was already consumed by the
/// caller.
Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
  auto magnitude = getToken().getFloatingPointValue();
  if (!magnitude.hasValue()) {
    emitError("floating point value too large for attribute");
    return nullptr;
  }
  consumeToken(Token::floatliteral);

  if (!type) {
    if (consumeIf(Token::colon)) {
      type = parseType();
      if (!type)
        return nullptr;
    } else {
      type = builder.getF64Type();
    }
  }

  if (!type.isa<FloatType>()) {
    emitError("floating point value not valid for specified type");
    return nullptr;
  }
  double value = magnitude.getValue();
  return FloatAttr::get(type, isNegative ? -value : value);
}

/// Construct an APInt from the spelling of an integer literal, the integer or
/// index type it is being parsed as, and whether a leading `-` was present.
///
/// Returns llvm::None if the spelling is not a valid integer, if the type has
/// zero bit width, or if the value does not fit in the type's bit width
/// (taking signedness into account for signed and index types).
static Optional<APInt> buildAttributeAPInt(Type type, bool isNegative,
                                           StringRef spelling) {
  // Parse the integer value into an APInt that is big enough to hold the value.
  APInt result;
  bool isHex = spelling.size() > 1 && spelling[1] == 'x';
  if (spelling.getAsInteger(isHex ? 0 : 10, result))
    return llvm::None;

  // Extend or truncate the bitwidth to the right size.
  unsigned width = type.isIndex() ? IndexType::kInternalStorageBitWidth
                                  : type.getIntOrFloatBitWidth();

  // APInt cannot hold a zero bit value.
  if (width == 0)
    return llvm::None;

  if (width > result.getBitWidth()) {
    result = result.zext(width);
  } else if (width < result.getBitWidth()) {
    // The parser can return an unnecessarily wide result with leading zeros.
    // This isn't a problem, but truncating off bits is bad.
    if (result.countLeadingZeros() < result.getBitWidth() - width)
      return llvm::None;

    result = result.trunc(width);
  }

  if (isNegative) {
    // The value is negative, we have an overflow if the sign bit is not set
    // in the negated apInt. Zero is the exception: `-0` is just zero, and
    // negating it leaves the sign bit clear without any overflow.
    result.negate();
    if (!result.isSignBitSet() && result != 0)
      return llvm::None;
  } else if ((type.isSignedInteger() || type.isIndex()) &&
             result.isSignBitSet()) {
    // The value is a positive signed integer or index,
    // we have an overflow if the sign bit is set.
    return llvm::None;
  }

  return result;
}

/// Parse a decimal or hexadecimal integer literal token as an attribute.
///
/// When `type` is null, an optional trailing `: type` is parsed, and i64 is
/// used when no type is given. For a float type the literal is handed to
/// parseFloatFromIntegerLiteral and a FloatAttr is produced. Integer and index
/// types produce an IntegerAttr after sign and range checks. `isNegative`
/// indicates that a leading `-` was already consumed by the caller.
Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
  Token literal = getToken();
  llvm::SMLoc literalLoc = literal.getLoc();
  consumeToken(Token::integer);

  if (!type) {
    if (consumeIf(Token::colon)) {
      type = parseType();
      if (!type)
        return nullptr;
    } else {
      type = builder.getIntegerType(64);
    }
  }

  if (FloatType floatTy = type.dyn_cast<FloatType>()) {
    Optional<APFloat> floatValue;
    if (failed(parseFloatFromIntegerLiteral(floatValue, literal, isNegative,
                                            floatTy.getFloatSemantics(),
                                            floatTy.getWidth())))
      return Attribute();
    return FloatAttr::get(floatTy, *floatValue);
  }

  if (!type.isa<IntegerType, IndexType>()) {
    emitError(literalLoc, "integer literal not valid for specified type");
    return nullptr;
  }

  if (isNegative && type.isUnsignedInteger()) {
    emitError(literalLoc,
              "negative integer literal not valid for unsigned integer type");
    return nullptr;
  }

  Optional<APInt> intValue =
      buildAttributeAPInt(type, isNegative, literal.getSpelling());
  if (!intValue) {
    emitError(literalLoc, "integer constant out of range for attribute");
    return nullptr;
  }
  return builder.getIntegerAttr(type, *intValue);
}

//===----------------------------------------------------------------------===//
// TensorLiteralParser
//===----------------------------------------------------------------------===//

/// Decode the hex-string token `tok` using Token::getHexStringValue and store
/// the decoded value in `result`. Emits a diagnostic at the token and returns
/// failure if the token does not hold a valid hex string.
static ParseResult parseElementAttrHexValues(Parser &parser, Token tok,
                                             std::string &result) {
  Optional<std::string> decoded = tok.getHexStringValue();
  if (!decoded)
    return parser.emitError(
        tok.getLoc(), "expected string containing hex digits starting with `0x`");
  result = std::move(*decoded);
  return success();
}

namespace {
/// This class implements a parser for TensorLiterals. A tensor literal is
/// either a single element (e.g., 5) or a multi-dimensional list of elements
/// (e.g., [[5, 5]]). Parsed element tokens are buffered and only converted into
/// an attribute once the literal's type is known (see getAttr).
class TensorLiteralParser {
public:
  /// Create a literal parser that reads tokens from `p`. Explicit so that a
  /// Parser is never implicitly converted into a TensorLiteralParser.
  explicit TensorLiteralParser(Parser &p) : p(p) {}

  /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser
  /// may also parse a tensor literal that is stored as a hex string.
  ParseResult parse(bool allowHex);

  /// Build a dense attribute instance with the parsed elements and the given
  /// shaped type.
  DenseElementsAttr getAttr(llvm::SMLoc loc, ShapedType type);

  /// Return the shape inferred from the parsed literal; empty if the literal
  /// was a single element (or hex string) rather than a list.
  ArrayRef<int64_t> getShape() const { return shape; }

private:
  /// Get the parsed elements for an integer attribute.
  ParseResult getIntAttrElements(llvm::SMLoc loc, Type eltTy,
                                 std::vector<APInt> &intValues);

  /// Get the parsed elements for a float attribute.
  ParseResult getFloatAttrElements(llvm::SMLoc loc, FloatType eltTy,
                                   std::vector<APFloat> &floatValues);

  /// Build a Dense String attribute for the given type.
  DenseElementsAttr getStringAttr(llvm::SMLoc loc, ShapedType type, Type eltTy);

  /// Build a Dense attribute with hex data for the given type.
  DenseElementsAttr getHexAttr(llvm::SMLoc loc, ShapedType type);

  /// Parse a single element, returning failure if it isn't a valid element
  /// literal. For example:
  /// parseElement(1) -> Success, 1
  /// parseElement([1]) -> Failure
  ParseResult parseElement();

  /// Parse a list of either lists or elements, returning the dimensions of the
  /// parsed sub-tensors in dims. For example:
  ///   parseList([1, 2, 3]) -> Success, [3]
  ///   parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
  ///   parseList([[1, 2], 3]) -> Failure
  ///   parseList([[1, [2, 3]], [4, [5]]]) -> Failure
  ParseResult parseList(SmallVectorImpl<int64_t> &dims);

  /// Parse a literal that was printed as a hex string.
  ParseResult parseHexElements();

  /// The parser that tokens are read from.
  Parser &p;

  /// The shape inferred from the parsed elements.
  SmallVector<int64_t, 4> shape;

  /// Storage used when parsing elements, this is a pair of <is_negated, token>.
  std::vector<std::pair<bool, Token>> storage;

  /// Storage used when parsing elements that were stored as hex values.
  Optional<Token> hexStorage;
};
} // namespace

/// Parse the elements of a tensor literal. The literal is one of the
/// following:
///   - a bracketed list, whose inferred shape is recorded in `shape`;
///   - a single element;
///   - when `allowHex` is true, a string literal, which is kept unparsed in
///     `hexStorage` and decoded later once the element type is known.
ParseResult TensorLiteralParser::parse(bool allowHex) {
  switch (p.getToken().getKind()) {
  case Token::string:
    if (!allowHex)
      return parseElement();
    hexStorage = p.getToken();
    p.consumeToken(Token::string);
    return success();
  case Token::l_square:
    return parseList(shape);
  default:
    return parseElement();
  }
}

/// Build a dense attribute instance with the parsed elements and the given
/// shaped type. Emits a diagnostic at `loc` and returns null if the parsed
/// literal is incompatible with `type`.
DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc,
                                               ShapedType type) {
  Type eltType = type.getElementType();

  // Check to see if we parse the literal from a hex string.
  if (hexStorage.hasValue() &&
      (eltType.isIntOrIndexOrFloat() || eltType.isa<ComplexType>()))
    return getHexAttr(loc, type);

  // Check that the parsed storage size has the same number of elements to the
  // type, or is a known splat.
  if (!shape.empty() && getShape() != type.getShape()) {
    p.emitError(loc) << "inferred shape of elements literal ([" << getShape()
                     << "]) does not match type ([" << type.getShape() << "])";
    return nullptr;
  }

  // Handle the case where no elements were parsed.
  if (!hexStorage.hasValue() && storage.empty() && type.getNumElements()) {
    p.emitError(loc) << "parsed zero elements, but type (" << type
                     << ") expected at least 1";
    return nullptr;
  }

  // Handle complex types in the specific element type cases below.
  bool isComplex = false;
  if (ComplexType complexTy = eltType.dyn_cast<ComplexType>()) {
    eltType = complexTy.getElementType();
    isComplex = true;
  }

  // Every element of `type` consumes one parsed value, or two (real and
  // imaginary part) for complex element types. A literal without an inferred
  // shape is a splat and supplies exactly one element's worth of values.
  // Validate the count here: otherwise a mismatched number of values (e.g. a
  // flat list for a complex type, or a `(a, b)` pair for a scalar type) would
  // reach DenseElementsAttr::get and the pairwise complex reinterpretation
  // below unchecked.
  if (!storage.empty()) {
    int64_t valuesPerElement = isComplex ? 2 : 1;
    int64_t expectedValues = shape.empty()
                                 ? valuesPerElement
                                 : type.getNumElements() * valuesPerElement;
    if (static_cast<int64_t>(storage.size()) != expectedValues) {
      p.emitError(loc) << "parsed " << storage.size()
                       << " values, but type (" << type << ") expected "
                       << expectedValues;
      return nullptr;
    }
  }

  // Handle integer and index types.
  if (eltType.isIntOrIndex()) {
    std::vector<APInt> intValues;
    if (failed(getIntAttrElements(loc, eltType, intValues)))
      return nullptr;
    if (isComplex) {
      // If this is a complex, treat the parsed values as complex values.
      auto complexData = llvm::makeArrayRef(
          reinterpret_cast<std::complex<APInt> *>(intValues.data()),
          intValues.size() / 2);
      return DenseElementsAttr::get(type, complexData);
    }
    return DenseElementsAttr::get(type, intValues);
  }
  // Handle floating point types.
  if (FloatType floatTy = eltType.dyn_cast<FloatType>()) {
    std::vector<APFloat> floatValues;
    if (failed(getFloatAttrElements(loc, floatTy, floatValues)))
      return nullptr;
    if (isComplex) {
      // If this is a complex, treat the parsed values as complex values.
      auto complexData = llvm::makeArrayRef(
          reinterpret_cast<std::complex<APFloat> *>(floatValues.data()),
          floatValues.size() / 2);
      return DenseElementsAttr::get(type, complexData);
    }
    return DenseElementsAttr::get(type, floatValues);
  }

  // Other types are assumed to be string representations.
  return getStringAttr(loc, type, type.getElementType());
}

/// Convert the buffered element tokens into APInt values for the integer or
/// index element type `eltTy`, appending them to `intValues`.
///
/// Emits a diagnostic at the offending token and returns failure for:
/// negative values with an unsigned type, float or string literals,
/// `true`/`false` with a non-i1 type, and values that do not fit in `eltTy`.
ParseResult
TensorLiteralParser::getIntAttrElements(llvm::SMLoc loc, Type eltTy,
                                        std::vector<APInt> &intValues) {
  intValues.reserve(storage.size());
  bool isUintType = eltTy.isUnsignedInteger();
  for (const auto &signAndToken : storage) {
    bool isNegative = signAndToken.first;
    const Token &token = signAndToken.second;
    auto tokenLoc = token.getLoc();

    if (isNegative && isUintType) {
      return p.emitError(tokenLoc)
             << "expected unsigned integer elements, but parsed negative value";
    }

    // Check to see if floating point values were parsed.
    if (token.is(Token::floatliteral)) {
      return p.emitError(tokenLoc)
             << "expected integer elements, but parsed floating-point";
    }

    // String elements can be written by the user (e.g. `dense<["a"]>` with an
    // integer element type), so they must be diagnosed rather than asserted.
    if (token.is(Token::string)) {
      return p.emitError(tokenLoc)
             << "expected integer elements, but parsed string";
    }

    assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) &&
           "unexpected token type");
    if (token.isAny(Token::kw_true, Token::kw_false)) {
      if (!eltTy.isInteger(1)) {
        return p.emitError(tokenLoc)
               << "expected i1 type for 'true' or 'false' values";
      }
      APInt apInt(1, token.is(Token::kw_true), /*isSigned=*/false);
      intValues.push_back(apInt);
      continue;
    }

    // Create APInt values for each element with the correct bitwidth.
    Optional<APInt> apInt =
        buildAttributeAPInt(eltTy, isNegative, token.getSpelling());
    if (!apInt)
      return p.emitError(tokenLoc, "integer constant out of range for type");
    intValues.push_back(*apInt);
  }
  return success();
}

/// Convert the buffered element tokens into APFloat values of type `eltTy`,
/// appending them to `floatValues`.
///
/// Hexadecimal integer literals (`0x...`) are handed to
/// parseFloatFromIntegerLiteral. Every other token must be a float literal.
/// Its double value is negated if needed and, for non-f64 types, converted to
/// `eltTy`'s semantics with round-to-nearest-ties-to-even.
ParseResult
TensorLiteralParser::getFloatAttrElements(llvm::SMLoc loc, FloatType eltTy,
                                          std::vector<APFloat> &floatValues) {
  floatValues.reserve(storage.size());
  for (const std::pair<bool, Token> &entry : storage) {
    const bool negated = entry.first;
    const Token &tok = entry.second;

    bool isHexInteger =
        tok.is(Token::integer) && tok.getSpelling().startswith("0x");
    if (isHexInteger) {
      Optional<APFloat> parsed;
      if (failed(p.parseFloatFromIntegerLiteral(parsed, tok, negated,
                                                eltTy.getFloatSemantics(),
                                                eltTy.getWidth())))
        return failure();
      floatValues.push_back(*parsed);
      continue;
    }

    // Decimal integers and booleans are not accepted as float elements.
    if (!tok.is(Token::floatliteral))
      return p.emitError()
             << "expected floating-point elements, but parsed integer";

    auto magnitude = tok.getFloatingPointValue();
    if (!magnitude.hasValue())
      return p.emitError("floating point value too large for attribute");

    APFloat element(negated ? -*magnitude : *magnitude);
    if (!eltTy.isF64()) {
      bool losesInfo;
      element.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven,
                      &losesInfo);
    }
    floatValues.push_back(element);
  }
  return success();
}

/// Build a Dense String attribute for the given type.
///
/// A hex-string literal (stored in `hexStorage`) becomes a single string
/// value. Otherwise every buffered element token must be a string literal; any
/// other token is diagnosed at its location and null is returned.
DenseElementsAttr TensorLiteralParser::getStringAttr(llvm::SMLoc loc,
                                                     ShapedType type,
                                                     Type eltTy) {
  if (hexStorage.hasValue()) {
    auto stringValue = hexStorage.getValue().getStringValue();
    return DenseStringElementsAttr::get(type, {stringValue});
  }

  std::vector<std::string> stringValues;
  stringValues.reserve(storage.size());
  for (const auto &signAndToken : storage) {
    const Token &token = signAndToken.second;
    // Token::getStringValue only applies to string tokens; elements such as
    // `1` or `true` reach here when the element type is not numeric.
    if (!token.is(Token::string)) {
      p.emitError(token.getLoc())
          << "expected string token, got " << token.getSpelling();
      return nullptr;
    }
    stringValues.push_back(token.getStringValue());
  }

  // Create the StringRef views only once `stringValues` is fully populated, so
  // no growth of that vector can invalidate them.
  std::vector<StringRef> stringRefValues(stringValues.begin(),
                                         stringValues.end());
  return DenseStringElementsAttr::get(type, stringRefValues);
}

/// Build a dense attribute from the hex string held in `hexStorage`.
///
/// The element type must be integer, index, float, or complex. The decoded
/// bytes must form a valid raw buffer for `type`. Hex data is always
/// little-endian, so it is byte-swapped on big-endian hosts before the
/// attribute is built.
DenseElementsAttr TensorLiteralParser::getHexAttr(llvm::SMLoc loc,
                                                  ShapedType type) {
  Type elementType = type.getElementType();
  bool isNumeric =
      elementType.isIntOrIndexOrFloat() || elementType.isa<ComplexType>();
  if (!isNumeric) {
    p.emitError(loc)
        << "expected floating-point, integer, or complex element type, got "
        << elementType;
    return nullptr;
  }

  std::string bytes;
  if (parseElementAttrHexValues(p, hexStorage.getValue(), bytes))
    return nullptr;

  ArrayRef<char> rawBuffer(bytes.data(), bytes.size());
  bool isSplat = false;
  if (!DenseElementsAttr::isValidRawBuffer(type, rawBuffer, isSplat)) {
    p.emitError(loc) << "elements hex data size is invalid for provided type: "
                     << type;
    return nullptr;
  }

  // Little-endian hosts can use the decoded bytes as-is.
  if (llvm::support::endian::system_endianness() !=
      llvm::support::endianness::big)
    return DenseElementsAttr::getFromRawBuffer(type, rawBuffer, isSplat);

  // Big-endian hosts: convert the little-endian data into a scratch buffer.
  SmallVector<char, 64> swapped(rawBuffer.size());
  MutableArrayRef<char> swappedRef(swapped);
  DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
      rawBuffer, swappedRef, type);
  return DenseElementsAttr::getFromRawBuffer(type, swappedRef, isSplat);
}

/// Parse one leaf element of a tensor literal into `storage`.
///
/// Accepted forms:
///   - a boolean, integer, float, or string literal;
///   - `-` followed by an integer or float literal (recorded as negated);
///   - `(` element `,` element `)` for complex values, whose two components
///     are parsed recursively.
ParseResult TensorLiteralParser::parseElement() {
  bool negated = false;
  switch (p.getToken().getKind()) {
  case Token::l_paren:
    p.consumeToken(Token::l_paren);
    if (parseElement() ||
        p.parseToken(Token::comma, "expected ',' between complex elements") ||
        parseElement() ||
        p.parseToken(Token::r_paren, "expected ')' after complex elements"))
      return failure();
    return success();

  case Token::minus:
    p.consumeToken(Token::minus);
    if (!p.getToken().isAny(Token::floatliteral, Token::integer))
      return p.emitError("expected integer or floating point literal");
    negated = true;
    break;

  case Token::kw_true:
  case Token::kw_false:
  case Token::floatliteral:
  case Token::integer:
  case Token::string:
    break;

  default:
    return p.emitError("expected element literal of primitive type");
  }

  // Record the current literal token and move past it.
  storage.emplace_back(negated, p.getToken());
  p.consumeToken();
  return success();
}

/// Parse a bracketed, comma-separated list whose items are nested lists or
/// single elements, and report the literal's shape in `dims`. The shape is
/// the item count followed by the common shape of the items. Every item must
/// have the same shape as the first. For example:
///   parseList([1, 2, 3]) -> Success, [3]
///   parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
///   parseList([[1, 2], 3]) -> Failure
///   parseList([[1, [2, 3]], [4, [5]]]) -> Failure
ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) {
  SmallVector<int64_t, 4> itemDims;
  bool haveItemDims = false;
  unsigned numItems = 0;

  auto parseItem = [&]() -> ParseResult {
    SmallVector<int64_t, 4> dimsOfThisItem;
    ParseResult parsed = p.getToken().is(Token::l_square)
                             ? parseList(dimsOfThisItem)
                             : parseElement();
    if (failed(parsed))
      return failure();
    ++numItems;

    // The first item fixes the expected sub-shape; later items must match it.
    if (!haveItemDims) {
      itemDims = dimsOfThisItem;
      haveItemDims = true;
      return success();
    }
    if (itemDims != dimsOfThisItem)
      return p.emitError("tensor literal is invalid; ranks are not consistent "
                         "between elements");
    return success();
  };

  if (p.parseCommaSeparatedList(Parser::Delimiter::Square, parseItem))
    return failure();

  dims.clear();
  dims.push_back(numItems);
  dims.append(itemDims.begin(), itemDims.end());
  return success();
}

//===----------------------------------------------------------------------===//
// ElementsAttr Parser
//===----------------------------------------------------------------------===//

/// Parse a dense elements attribute:
///   `dense` `<` tensor-literal? `>` (`:` elements-literal-type)?
/// The trailing type is only parsed when `attrType` is null. The literal may
/// be a hex string.
Attribute Parser::parseDenseElementsAttr(Type attrType) {
  llvm::SMLoc keywordLoc = getToken().getLoc();
  consumeToken(Token::kw_dense);
  if (parseToken(Token::less, "expected '<' after 'dense'"))
    return nullptr;

  // `dense<>` carries no literal data; otherwise parse the literal and then
  // the closing `>`.
  TensorLiteralParser literalParser(*this);
  bool isEmptyLiteral = consumeIf(Token::greater);
  if (!isEmptyLiteral && (literalParser.parse(/*allowHex=*/true) ||
                          parseToken(Token::greater, "expected '>'")))
    return nullptr;

  // When the caller supplied the type, no type tokens follow, so report
  // literal errors at the `dense` keyword. Otherwise report them at the
  // current token, where the trailing type begins.
  llvm::SMLoc diagLoc = attrType ? keywordLoc : getToken().getLoc();
  ShapedType literalType = parseElementsLiteralType(attrType);
  if (!literalType)
    return nullptr;
  return literalParser.getAttr(diagLoc, literalType);
}

/// Parse an opaque elements attribute:
///   `opaque` `<` dialect-namespace-string `,` hex-string `>` (`:` type)?
/// The hex string is decoded only after the elements type has been parsed.
Attribute Parser::parseOpaqueElementsAttr(Type attrType) {
  consumeToken(Token::kw_opaque);
  if (parseToken(Token::less, "expected '<' after 'opaque'"))
    return nullptr;

  if (getToken().isNot(Token::string)) {
    emitError("expected dialect namespace");
    return nullptr;
  }
  std::string dialectName = getToken().getStringValue();
  consumeToken(Token::string);

  if (parseToken(Token::comma, "expected ','"))
    return nullptr;

  // Keep the hex token aside; it is decoded after the type is known.
  Token hexToken = getToken();
  if (parseToken(Token::string, "elements hex string should start with '0x'") ||
      parseToken(Token::greater, "expected '>'"))
    return nullptr;

  ShapedType elementsType = parseElementsLiteralType(attrType);
  if (!elementsType)
    return nullptr;

  std::string hexData;
  if (parseElementAttrHexValues(*this, hexToken, hexData))
    return nullptr;
  return OpaqueElementsAttr::get(builder.getStringAttr(dialectName),
                                 elementsType, hexData);
}

/// Shaped type for elements attribute.
///
///   elements-literal-type ::= vector-type | ranked-tensor-type
///
/// If `type` is null, a `:` followed by a type is parsed first. The type must
/// be a ranked tensor or vector with a static shape; otherwise a diagnostic is
/// emitted and null is returned.
ShapedType Parser::parseElementsLiteralType(Type type) {
  if (!type) {
    if (parseToken(Token::colon, "expected ':'"))
      return nullptr;
    type = parseType();
    if (!type)
      return nullptr;
  }

  if (!type.isa<RankedTensorType, VectorType>()) {
    emitError("elements literal must be a ranked tensor or vector type");
    return nullptr;
  }

  ShapedType shaped = type.cast<ShapedType>();
  if (shaped.hasStaticShape())
    return shaped;
  emitError("elements literal type must have static shape");
  return nullptr;
}

/// Parse a sparse elements attribute:
///   `sparse` `<` (indices-literal `,` values-literal)? `>` (`:` type)?
/// `sparse<>` produces an attribute with zero stored indices and values.
/// Indices are built as i64 elements. Values use the element type of the
/// parsed (or supplied) shaped type.
Attribute Parser::parseSparseElementsAttr(Type attrType) {
  llvm::SMLoc loc = getToken().getLoc();
  consumeToken(Token::kw_sparse);
  if (parseToken(Token::less, "Expected '<' after 'sparse'"))
    return nullptr;

  // Check for the case where all elements are sparse. The indices are
  // represented by a 2-dimensional shape where the second dimension is the rank
  // of the type.
  Type indiceEltType = builder.getIntegerType(64);
  if (consumeIf(Token::greater)) {
    ShapedType type = parseElementsLiteralType(attrType);
    if (!type)
      return nullptr;

    // Construct the sparse elements attr using zero element indice/value
    // attributes.
    ShapedType indicesType =
        RankedTensorType::get({0, type.getRank()}, indiceEltType);
    ShapedType valuesType = RankedTensorType::get({0}, type.getElementType());
    return getChecked<SparseElementsAttr>(
        loc, type, DenseElementsAttr::get(indicesType, ArrayRef<Attribute>()),
        DenseElementsAttr::get(valuesType, ArrayRef<Attribute>()));
  }

  // Parse the indices. We don't allow hex values here as we may need to use
  // the inferred shape.
  auto indicesLoc = getToken().getLoc();
  TensorLiteralParser indiceParser(*this);
  if (indiceParser.parse(/*allowHex=*/false))
    return nullptr;

  if (parseToken(Token::comma, "expected ','"))
    return nullptr;

  // Parse the values.
  auto valuesLoc = getToken().getLoc();
  TensorLiteralParser valuesParser(*this);
  if (valuesParser.parse(/*allowHex=*/true))
    return nullptr;

  if (parseToken(Token::greater, "expected '>'"))
    return nullptr;

  auto type = parseElementsLiteralType(attrType);
  if (!type)
    return nullptr;

  // If the indices are a splat, i.e. the literal parser parsed an element and
  // not a list, we set the shape explicitly. The indices are represented by a
  // 2-dimensional shape where the second dimension is the rank of the type.
  // Given that the parsed indices is a splat, we know that we only have one
  // indice and thus one for the first dimension.
  ShapedType indicesType;
  if (indiceParser.getShape().empty()) {
    indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType);
  } else {
    // Otherwise, set the shape to the one parsed by the literal parser.
    indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType);
  }
  // getAttr returns null after emitting a diagnostic (e.g. non-integer or
  // out-of-range indices); bail out instead of building a SparseElementsAttr
  // from a null component.
  auto indices = indiceParser.getAttr(indicesLoc, indicesType);
  if (!indices)
    return nullptr;

  // If the values are a splat, set the shape explicitly based on the number of
  // indices. The number of indices is encoded in the first dimension of the
  // indice shape type.
  auto valuesEltType = type.getElementType();
  ShapedType valuesType =
      valuesParser.getShape().empty()
          ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType)
          : RankedTensorType::get(valuesParser.getShape(), valuesEltType);
  auto values = valuesParser.getAttr(valuesLoc, valuesType);
  if (!values)
    return nullptr;

  // Build the sparse elements attribute by the indices and values.
  return getChecked<SparseElementsAttr>(loc, type, indices, values);
}
