//===- ReshapeOpsUtils.cpp - Utilities used by structured ops -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"

#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"

#include <numeric>

using namespace mlir;

/// Computes the reassociation indices relating two shaped types where one is
/// obtained from the other by collapsing groups of contiguous dimensions.
///
/// The type with the larger rank is treated as the expanded ("source") shape
/// and the other as the collapsed ("target") shape; the arguments are swapped
/// internally if needed. On success the result holds one entry per target
/// dimension, listing the consecutive source dimensions folded into it. If the
/// target has rank 0, the result is empty.
///
/// Returns llvm::None if the ranks are equal or if no grouping of contiguous
/// source dimensions reproduces the target shape.
Optional<SmallVector<ReassociationIndices>>
mlir::getReassociationIndicesForReshape(ShapedType sourceType,
                                        ShapedType targetType) {
  // Make the sourceType greater rank than the targetType. If they are same
  // rank, then it's an unsupported reshape op.
  if (sourceType.getRank() == targetType.getRank())
    return llvm::None;
  if (sourceType.getRank() < targetType.getRank())
    std::swap(sourceType, targetType);

  ArrayRef<int64_t> sourceShape = sourceType.getShape();
  ArrayRef<int64_t> targetShape = targetType.getShape();
  unsigned sourceDim = 0;
  SmallVector<ReassociationIndices> reassociationMap;
  reassociationMap.reserve(targetType.getRank());

  ReassociationIndices currIndices;
  int64_t prodOfCollapsedDims = 1;
  while (sourceDim < sourceShape.size()) {
    // The number of groups emitted so far is the target dim being matched.
    unsigned targetDim = reassociationMap.size();

    // If all the dimensions of the targetShape are exhausted, then the
    // remaining dims in the source shape must be all 1s. So for such cases, set
    // 1 as the target shape. The actual reassociation indices will be handled
    // later.
    int64_t currTargetShape =
        (targetDim < targetShape.size() ? targetShape[targetDim] : 1);

    // Greedily fold static source dims while their running product is still
    // smaller than the target extent. The bounds check must come first so that
    // `sourceShape` is never indexed past its end.
    while (sourceDim < sourceShape.size() &&
           sourceShape[sourceDim] != ShapedType::kDynamicSize &&
           prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
      prodOfCollapsedDims *= sourceShape[sourceDim];
      currIndices.push_back(sourceDim++);
    }

    // All source dims were consumed but their product is still smaller than
    // the target extent: the shapes cannot be related by a reassociation.
    if (sourceDim == sourceShape.size())
      return llvm::None;

    // If the current expanded dimension is dynamic, then the collapsed
    // dimensions should also be dynamic and product of all previous unprocessed
    // dimensions of the expanded shape should be 1.
    if (sourceShape[sourceDim] == ShapedType::kDynamicSize &&
        (currTargetShape != ShapedType::kDynamicSize ||
         prodOfCollapsedDims != 1))
      return llvm::None;

    // If the collapsed dim is dynamic, the current expanded dim should also
    // be dynamic.
    if (currTargetShape == ShapedType::kDynamicSize &&
        sourceShape[sourceDim] != ShapedType::kDynamicSize)
      return llvm::None;

    // For static shapes, the product of the dimensions of the expanded shape
    // should match the collapsed dimension shape.
    if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
      return llvm::None;

    currIndices.push_back(sourceDim++);
    // If the reassociation is empty but the currIndices is not, this by
    // definition is folding unit-dimensions with the result being scalar type.
    // So only append the `currIndices` if reassociation map is not empty.
    if (targetDim == targetShape.size()) {
      // Every remaining source dim joins the trailing group.
      while (sourceDim < sourceShape.size())
        currIndices.push_back(sourceDim++);
      if (!reassociationMap.empty() && !currIndices.empty())
        reassociationMap.back().append(currIndices.begin(), currIndices.end());
      // Break out of the loops. We should be done here.
      break;
    }
    // Commit the finished group and reset the accumulators for the next one.
    reassociationMap.emplace_back(ReassociationIndices{});
    std::swap(reassociationMap.back(), currIndices);
    prodOfCollapsedDims = 1;
  }
  // All the dimensions in the two shapes must have been processed.
  if (reassociationMap.size() != targetShape.size() ||
      sourceDim != sourceShape.size())
    return llvm::None;
  return reassociationMap;
}

/// Parses the custom assembly of a reshape-like operation into `result`.
///
/// The accepted form, as consumed by the calls below, is:
///   operand `[` (`[` int (`,` int)* `]` (`,` ...)*)? `]` attr-dict?
///       `:` source-type `into` result-type
///
/// The nested integer lists are stored as an array attribute of I64 array
/// attributes under the name returned by `getReassociationAttrName()`. The
/// operand is resolved against the parsed source type and the parsed result
/// type is added as the single result type.
///
/// Returns failure() if any piece fails to parse, success() otherwise.
ParseResult mlir::parseReshapeLikeOp(OpAsmParser &parser,
                                     OperationState &result) {
  // Parse the operand.
  OpAsmParser::OperandType src;
  if (parser.parseOperand(src))
    return failure();

  // Parse reassociation indices.
  Builder &b = parser.getBuilder();
  SmallVector<Attribute, 4> reassociation;
  if (parser.parseLSquare())
    return failure();

  while (true) {
    // A closing bracket here ends the (possibly empty) outer list.
    if (succeeded(parser.parseOptionalRSquare()))
      break;
    if (parser.parseLSquare())
      return failure();
    // Parse one non-empty, comma-separated group of indices.
    SmallVector<int64_t> indices;
    while (true) {
      int64_t index;
      if (parser.parseInteger(index))
        return failure();
      indices.push_back(index);

      if (succeeded(parser.parseOptionalComma()))
        continue;
      if (failed(parser.parseRSquare()))
        return failure();
      break;
    }
    reassociation.push_back(b.getI64ArrayAttr(indices));
    if (succeeded(parser.parseOptionalComma()))
      continue;
    if (failed(parser.parseRSquare()))
      return failure();
    break;
  }

  result.addAttribute(getReassociationAttrName(),
                      b.getArrayAttr(reassociation));

  // Parse optional attributes. A malformed dictionary must abort parsing
  // instead of being silently ignored.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // Parse types.
  Type srcType;
  Type resultType;
  if (parser.parseColon() || parser.parseType(srcType) ||
      parser.resolveOperand(src, srcType, result.operands) ||
      parser.parseKeyword("into") || parser.parseType(resultType))
    return failure();
  result.addTypes(resultType);
  return success();
}

/// Composes the reassociations of two back-to-back reshapes into one.
///
/// The list with more groups is treated as the producer (the arguments are
/// swapped internally if needed). Each consumer group is replaced by the
/// concatenation of the producer groups it refers to. `context` is not used by
/// this implementation.
///
/// Returns:
///  - llvm::None if both lists have the same number of groups, if the total
///    number of consumer indices differs from the number of producer groups,
///    or if a consumer index does not name a producer group;
///  - an empty list if the smaller reassociation is empty (rank-0 result);
///  - the composed reassociation otherwise.
Optional<SmallVector<ReassociationIndices>> mlir::composeReassociationIndices(
    ArrayRef<ReassociationIndices> producerReassociations,
    ArrayRef<ReassociationIndices> consumerReassociations,
    MLIRContext *context) {
  SmallVector<ReassociationIndices> composedIndices;
  // Make the producer the larger sized vector. If they are of same size, the
  // resulting reshape is not a supported reshape op.
  if (producerReassociations.size() == consumerReassociations.size())
    return llvm::None;
  if (producerReassociations.size() < consumerReassociations.size())
    std::swap(producerReassociations, consumerReassociations);

  // Handle the corner case of the result being a rank 0 shaped type. Return an
  // empty reassociation.
  if (consumerReassociations.empty())
    return composedIndices;

  // Seed the fold with a size_t so the running sum is not narrowed to int.
  size_t consumerDims = std::accumulate(
      consumerReassociations.begin(), consumerReassociations.end(),
      static_cast<size_t>(0),
      [](size_t all, ReassociationIndicesRef indices) {
        return all + indices.size();
      });
  if (producerReassociations.size() != consumerDims)
    return llvm::None;

  composedIndices.reserve(consumerReassociations.size());
  for (ReassociationIndicesRef consumerIndices : consumerReassociations) {
    ReassociationIndices reassociations;
    for (int64_t consumerIndex : consumerIndices) {
      // Reject indices that do not name a producer group rather than indexing
      // out of bounds.
      if (consumerIndex < 0 || static_cast<size_t>(consumerIndex) >=
                                   producerReassociations.size())
        return llvm::None;
      const ReassociationIndices &producerGroup =
          producerReassociations[consumerIndex];
      reassociations.append(producerGroup.begin(), producerGroup.end());
    }
    composedIndices.push_back(std::move(reassociations));
  }
  return composedIndices;
}

/// Converts each group of reassociation indices into a group of affine
/// dimension expressions created in `context`, preserving the order of both
/// the groups and the indices within each group.
SmallVector<SmallVector<AffineExpr, 2>, 2>
mlir::convertReassociationIndicesToExprs(
    MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices) {
  SmallVector<SmallVector<AffineExpr, 2>, 2> exprGroups;
  for (const auto &group : reassociationIndices) {
    // Append an empty group first, then fill it in place.
    exprGroups.emplace_back();
    SmallVector<AffineExpr, 2> &dimExprs = exprGroups.back();
    dimExprs.reserve(group.size());
    for (int64_t dim : group)
      dimExprs.push_back(mlir::getAffineDimExpr(dim, context));
  }
  return exprGroups;
}

/// Returns the largest `getPosition()` among all sub-expressions of kind
/// `AffineExprTy` found while walking every expression in `exprArrays`.
/// Returns 0 when no such sub-expression exists, so a result of 0 does not
/// distinguish "none found" from "only position 0 found".
template <typename AffineExprTy>
unsigned getMaxPosOfType(ArrayRef<ReassociationExprs> exprArrays) {
  unsigned maxPos = 0;
  // Visitor that raises `maxPos` whenever a matching sub-expression is seen.
  auto recordPosition = [&maxPos](AffineExpr subExpr) {
    if (auto typedExpr = subExpr.dyn_cast<AffineExprTy>())
      maxPos = std::max(maxPos, typedExpr.getPosition());
  };
  for (const auto &group : exprArrays)
    for (auto expr : group)
      expr.walk(recordPosition);
  return maxPos;
}

/// Builds an array attribute with one element per reassociation group, each
/// element being the I64 array attribute of that group's indices, using
/// builder `b`.
ArrayAttr mlir::getReassociationIndicesAttribute(
    OpBuilder &b, ArrayRef<ReassociationIndices> reassociation) {
  SmallVector<Attribute, 4> reassociationAttr;
  reassociationAttr.reserve(reassociation.size());
  // Iterate by const reference so each group's vector is not copied.
  for (const ReassociationIndices &indices : reassociation)
    reassociationAttr.push_back(b.getI64ArrayAttr(indices));
  return b.getArrayAttr(reassociationAttr);
}

/// Converts each group of affine expressions into the group of dimension
/// positions they refer to, preserving order.
///
/// Every expression is converted with `cast<AffineDimExpr>()` (not
/// `dyn_cast`), so callers must only pass plain dimension expressions. The
/// builder `b` is not used by this implementation.
SmallVector<ReassociationIndices, 2> mlir::convertReassociationMapsToIndices(
    OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs) {
  SmallVector<ReassociationIndices, 2> reassociationIndices;
  reassociationIndices.reserve(reassociationExprs.size());
  for (const auto &exprs : reassociationExprs) {
    ReassociationIndices indices;
    indices.reserve(exprs.size());
    for (const auto &expr : exprs)
      indices.push_back(expr.cast<AffineDimExpr>().getPosition());
    // Move the finished group into the result instead of copying it.
    reassociationIndices.push_back(std::move(indices));
  }
  return reassociationIndices;
}

/// Builds one affine map per group of reassociation expressions.
///
/// All maps share the same number of dimensions, namely one more than the
/// largest dimension position appearing anywhere in `reassociation`, and have
/// zero symbols. Each group must be non-empty, and no expression may contain a
/// symbol; both conditions are only checked by assertions.
SmallVector<AffineMap, 4>
mlir::getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation) {
  unsigned maxDim = getMaxPosOfType<AffineDimExpr>(reassociation);
#ifndef NDEBUG
  // getMaxPosOfType returns 0 both when no symbol is present and when only
  // symbol position 0 is present, so it cannot be used to prove the absence of
  // symbols. Walk the expressions and check each sub-expression directly.
  for (const auto &exprs : reassociation)
    for (auto expr : exprs)
      expr.walk([](AffineExpr e) {
        assert(!e.isa<AffineSymbolExpr>() &&
               "Expected symbol-less expressions");
      });
#endif
  SmallVector<AffineMap, 4> maps;
  maps.reserve(reassociation.size());
  for (const auto &exprs : reassociation) {
    // The context is taken from the first expression of the group.
    assert(!exprs.empty());
    maps.push_back(AffineMap::get(maxDim + 1, 0, exprs, exprs[0].getContext()));
  }
  return maps;
}
/// Checks that `reassociation` is a valid list of reassociation maps.
///
/// The list is valid when it is empty, or when all maps have the same number
/// of dimensions and no symbols, every result is a plain dimension expression,
/// and the results read in order across all maps are exactly d0, d1, ...,
/// d(N-1), where N is the number of dimensions of the first map.
///
/// On failure, if `invalidIndex` is non-null it receives the index of the
/// first offending map, or the index of the last map when the results stop
/// short of covering all N dimensions.
bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation,
                                int *invalidIndex) {
  if (reassociation.empty())
    return true;

  // Records the offending map index (when requested) and yields `false`.
  auto reportInvalid = [invalidIndex](int mapIndex) {
    if (invalidIndex)
      *invalidIndex = mapIndex;
    return false;
  };

  const unsigned numDims = reassociation[0].getNumDims();
  unsigned expectedPos = 0;
  for (auto indexedMap : llvm::enumerate(reassociation)) {
    auto map = indexedMap.value();
    if (map.getNumDims() != numDims || map.getNumSymbols() != 0)
      return reportInvalid(indexedMap.index());
    for (auto resultExpr : map.getResults()) {
      auto dimExpr = resultExpr.dyn_cast<AffineDimExpr>();
      // Results must be dimension expressions in strictly consecutive order.
      if (!dimExpr || dimExpr.getPosition() != expectedPos++)
        return reportInvalid(indexedMap.index());
    }
  }
  // Every dimension must have been covered by some map.
  if (expectedPos != numDims)
    return reportInvalid(reassociation.size() - 1);
  return true;
}
