//===- DynamicPass.cpp - Implementation of a dynamic configurable pass ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a configurable pass that can apply patterns liberally
// and be plugged in a pass pipeline.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/Utils.h"

using namespace mlir;
using namespace mlir::vector;
using namespace linalg;

namespace {

/// Configurable pass to apply pattern-based tiling and fusion.
struct LinalgStrategyTileAndFusePass
    : public LinalgStrategyTileAndFusePassBase<LinalgStrategyTileAndFusePass> {

  LinalgStrategyTileAndFusePass() = default;

  LinalgStrategyTileAndFusePass(StringRef opName,
                                LinalgTilingAndFusionOptions opt,
                                LinalgTransformationFilter filt)
      : options(opt), filter(filt) {
    this->anchorOpName.setValue(opName.str());
  }

  void runOnFunction() override {
    auto funcOp = getFunction();
    if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
      return;

    RewritePatternSet tilingAndFusionPattern(funcOp.getContext());
    if (!anchorOpName.empty()) {
      tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>(
          anchorOpName, funcOp.getContext(), options, filter);
    } else {
      tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>(
          funcOp.getContext(), options, filter);
    }
    // Search the root operation using bottom up traversal.
    GreedyRewriteConfig config;
    config.useTopDownTraversal = false;
    (void)applyPatternsAndFoldGreedily(
        funcOp, std::move(tilingAndFusionPattern), config);
  }

  LinalgTilingAndFusionOptions options;
  LinalgTransformationFilter filter;
};

/// Configurable pass to apply pattern-based linalg tiling.
struct LinalgStrategyTilePass
    : public LinalgStrategyTilePassBase<LinalgStrategyTilePass> {

  LinalgStrategyTilePass() = default;

  LinalgStrategyTilePass(StringRef opName, LinalgTilingOptions opt,
                         LinalgTransformationFilter filt)
      : options(opt), filter(filt) {
    this->anchorOpName.setValue(opName.str());
  }

  void runOnFunction() override {
    auto funcOp = getFunction();
    if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
      return;

    RewritePatternSet tilingPattern(funcOp.getContext());
    if (!anchorOpName.empty()) {
      tilingPattern.add<LinalgGenericTilingPattern>(
          anchorOpName, funcOp.getContext(), options, filter);
    } else {
      tilingPattern.add<LinalgGenericTilingPattern>(funcOp.getContext(), filter,
                                                    options);
    }
    (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
  }

  LinalgTilingOptions options;
  LinalgTransformationFilter filter;
};

/// Configurable pass to apply hoisting and padding.
struct LinalgStrategyPadPass
    : public LinalgStrategyPadPassBase<LinalgStrategyPadPass> {

  LinalgStrategyPadPass() = default;

  LinalgStrategyPadPass(StringRef opName, LinalgPaddingOptions opt,
                        LinalgTransformationFilter filt)
      : options(opt), filter(filt) {
    this->anchorOpName.setValue(opName.str());
  }

  void runOnFunction() override {
    auto funcOp = getFunction();
    if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
      return;

    RewritePatternSet paddingPattern(funcOp.getContext());
    if (!anchorOpName.empty()) {
      paddingPattern.add<LinalgPaddingPattern>(
          anchorOpName, funcOp.getContext(), options, filter);
    } else {
      paddingPattern.add<LinalgPaddingPattern>(funcOp.getContext(), options,
                                               filter);
    }
    (void)applyPatternsAndFoldGreedily(funcOp, std::move(paddingPattern));
  }

  LinalgPaddingOptions options;
  LinalgTransformationFilter filter;
};

/// Configurable pass to apply pattern-based linalg generalization.
struct LinalgStrategyGeneralizePass
    : public LinalgStrategyGeneralizePassBase<LinalgStrategyGeneralizePass> {

  LinalgStrategyGeneralizePass() = default;

  LinalgStrategyGeneralizePass(StringRef opName,
                               LinalgTransformationFilter filter)
      : filter(filter) {
    this->anchorOpName.setValue(opName.str());
  }

  void runOnFunction() override {
    auto funcOp = getFunction();
    if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
      return;

    RewritePatternSet generalizationPattern(funcOp.getContext());
    if (!anchorOpName.empty()) {
      generalizationPattern.add<LinalgGeneralizationPattern>(
          anchorOpName, funcOp.getContext(), filter);
    } else {
      generalizationPattern.add<LinalgGeneralizationPattern>(
          funcOp.getContext(), filter);
    }
    if (failed(applyPatternsAndFoldGreedily(funcOp,
                                            std::move(generalizationPattern))))
      signalPassFailure();
  }

  LinalgTransformationFilter filter;
};

/// Configurable pass to apply lowering of coarser-grained named linalg ops into
/// finer-grained named versions.
struct LinalgStrategyDecomposePass
    : public LinalgStrategyDecomposePassBase<LinalgStrategyDecomposePass> {

  LinalgStrategyDecomposePass() = default;

  LinalgStrategyDecomposePass(LinalgTransformationFilter filter)
      : filter(filter) {}

  void runOnFunction() override {
    auto funcOp = getFunction();
    if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
      return;
    RewritePatternSet decompositionPattern(funcOp.getContext());
    populateDecomposeConvolutionPatterns(decompositionPattern, filter);
    if (failed(applyPatternsAndFoldGreedily(funcOp,
                                            std::move(decompositionPattern))))
      signalPassFailure();
  }

  LinalgTransformationFilter filter;
};

/// Configurable pass to apply pattern-based linalg generalization.
struct LinalgStrategyInterchangePass
    : public LinalgStrategyInterchangePassBase<LinalgStrategyInterchangePass> {

  LinalgStrategyInterchangePass() = default;

  LinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange,
                                LinalgTransformationFilter filter)
      : iteratorInterchange(iteratorInterchange.begin(),
                            iteratorInterchange.end()),
        filter(filter) {}

  void runOnFunction() override {
    auto funcOp = getFunction();
    if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
      return;

    SmallVector<unsigned> interchangeVector(iteratorInterchange.begin(),
                                            iteratorInterchange.end());
    RewritePatternSet interchangePattern(funcOp.getContext());
    interchangePattern.add<GenericOpInterchangePattern>(
        funcOp.getContext(), interchangeVector, filter);
    if (failed(applyPatternsAndFoldGreedily(funcOp,
                                            std::move(interchangePattern))))
      signalPassFailure();
  }

  SmallVector<int64_t> iteratorInterchange;
  LinalgTransformationFilter filter;
};

/// Configurable pass to apply pattern-based linalg promotion.
struct LinalgStrategyPromotePass
    : public LinalgStrategyPromotePassBase<LinalgStrategyPromotePass> {

  LinalgStrategyPromotePass() = default;

  LinalgStrategyPromotePass(StringRef opName, LinalgPromotionOptions opt,
                            LinalgTransformationFilter filt)
      : options(opt), filter(filt) {
    this->anchorOpName.setValue(opName.str());
  }

  void runOnFunction() override {
    auto funcOp = getFunction();
    if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
      return;

    RewritePatternSet promotionPattern(funcOp.getContext());
    if (!anchorOpName.empty()) {
      promotionPattern.add<LinalgBasePromotionPattern>(
          anchorOpName, funcOp.getContext(), options, filter);
    } else {
      promotionPattern.add<LinalgBasePromotionPattern>(funcOp.getContext(),
                                                       filter, options);
    }
    (void)applyPatternsAndFoldGreedily(funcOp, std::move(promotionPattern));
  }

  LinalgPromotionOptions options;
  LinalgTransformationFilter filter;
};

/// Configurable pass to apply pattern-based linalg vectorization.
struct LinalgStrategyVectorizePass
    : public LinalgStrategyVectorizePassBase<LinalgStrategyVectorizePass> {

  LinalgStrategyVectorizePass() = default;

  LinalgStrategyVectorizePass(StringRef opName, LinalgVectorizationOptions opt,
                              LinalgTransformationFilter filt,
                              bool padVectorize = false)
      : options(opt), filter(filt) {
    this->anchorOpName.setValue(opName.str());
    this->vectorizePadding.setValue(padVectorize);
  }

  void runOnFunction() override {
    auto funcOp = getFunction();
    if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
      return;

    RewritePatternSet vectorizationPatterns(funcOp.getContext());
    if (!anchorOpName.empty()) {
      vectorizationPatterns.add<LinalgVectorizationPattern>(
          anchorOpName, funcOp.getContext(), options, filter);
    } else {
      vectorizationPatterns.add<LinalgVectorizationPattern>(funcOp.getContext(),
                                                            filter, options);
    }
    vector::populateVectorTransferPermutationMapLoweringPatterns(
        vectorizationPatterns);
    vector::populateVectorReductionToContractPatterns(vectorizationPatterns);
    vectorizationPatterns.add<linalg::LinalgCopyVTRForwardingPattern,
                              linalg::LinalgCopyVTWForwardingPattern>(
        funcOp.getContext(), /*benefit=*/2);
    TransferReadOp::getCanonicalizationPatterns(vectorizationPatterns,
                                                funcOp.getContext());
    TransferWriteOp::getCanonicalizationPatterns(vectorizationPatterns,
                                                 funcOp.getContext());
    (void)applyPatternsAndFoldGreedily(funcOp,
                                       std::move(vectorizationPatterns));

    // Apply the pad tensor op vectorization separately to avoid running the
    // GenericPadTensorOpVectorizationPattern too early.
    // TODO: Improve once we have better infrastructure to control pattern
    // application.
    if (vectorizePadding) {
      RewritePatternSet patterns(funcOp.getContext());
      linalg::populatePadTensorOpVectorizationPatterns(patterns);
      (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
    }
  }

  LinalgVectorizationOptions options;
  LinalgTransformationFilter filter;
};

/// Configurable pass to enable the application of other pattern-based linalg
/// passes.
struct LinalgStrategyEnablePass
    : public LinalgStrategyEnablePassBase<LinalgStrategyEnablePass> {

  LinalgStrategyEnablePass(LinalgEnablingOptions opt,
                           LinalgTransformationFilter filt)
      : options(opt), filter(filt) {}

  void runOnFunction() override {
    auto funcOp = getFunction();
    if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
      return;

    MLIRContext *context = funcOp.getContext();
    RewritePatternSet patterns =
        linalg::getLinalgTilingCanonicalizationPatterns(context);
    scf::populateSCFForLoopCanonicalizationPatterns(patterns);
    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
      return signalPassFailure();

    if (options.licm) {
      if (funcOp
              ->walk([&](LoopLikeOpInterface loopLike) {
                if (failed(moveLoopInvariantCode(loopLike)))
                  return WalkResult::interrupt();
                return WalkResult::advance();
              })
              .wasInterrupted())
        return signalPassFailure();
    }

    promoteSingleIterationLoops(funcOp);
    if (options.hoistRedundantVectorTransfers)
      hoistRedundantVectorTransfers(funcOp);

    if (options.hoistRedundantVectorTransfersOnTensor)
      hoistRedundantVectorTransfersOnTensor(funcOp);

    // Run CSE to cleanup after canonicalization.
    OpPassManager dynamicPM("builtin.func");
    dynamicPM.addPass(createCSEPass());
    if (failed(runPipeline(dynamicPM, funcOp)))
      return signalPassFailure();
  }

  LinalgEnablingOptions options;
  LinalgTransformationFilter filter;
};

/// Configurable pass to lower vector operations.
struct LinalgStrategyLowerVectorsPass
    : public LinalgStrategyLowerVectorsPassBase<
          LinalgStrategyLowerVectorsPass> {

  LinalgStrategyLowerVectorsPass(LinalgVectorLoweringOptions opt,
                                 LinalgTransformationFilter filt)
      : options(opt), filter(filt) {}

  void runOnFunction() override {
    auto funcOp = getFunction();
    if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
      return;

    MLIRContext *context = funcOp.getContext();
    RewritePatternSet patterns(context);
    vector::populateVectorToVectorCanonicalizationPatterns(patterns);
    // In a progressive lowering of vectors, this would be the 1st step.
    if (options.contractionLowering) {
      patterns.add<ContractionOpToOuterProductOpLowering,
                   ContractionOpToMatmulOpLowering, ContractionOpLowering>(
          options.vectorTransformOptions, context);
      vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
    }
    // In a progressive lowering of vectors, this would be the 2nd step.
    if (options.multiReductionLowering) {
      vector::populateVectorMultiReductionLoweringPatterns(
          patterns,
          options.vectorTransformOptions.vectorMultiReductionLowering);
    }
    // In a progressive lowering of vectors, this would be the 3rd step.
    if (options.transferPartialRewrite) {
      patterns.add<vector::VectorTransferFullPartialRewriter>(
          context, options.vectorTransformOptions);
    }
    // In a progressive lowering of vectors, this would be the 4th step.
    if (options.transferLowering) {
      vector::populateVectorTransferLoweringPatterns(patterns,
                                                     options.maxTransferRank);
    }
    // In a progressive lowering of vectors, this would be the 5th step.
    if (options.transferToSCFConversion) {
      populateVectorToSCFConversionPatterns(
          patterns, options.vectorTransferToSCFOptions.setTargetRank(
                        options.maxTransferRank));
    }
    // In a progressive lowering of vectors, this would be the 6th step.
    if (options.shapeCastLowering) {
      vector::populateVectorShapeCastLoweringPatterns(patterns);
    }
    // In a progressive lowering of vectors, this would be the 7th step.
    if (options.transposeLowering) {
      vector::populateVectorTransposeLoweringPatterns(
          patterns, options.vectorTransformOptions);
      if (options.avx2Lowering)
        x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
            patterns, options.avx2LoweringOptions, /*benefit=*/10);
    }
    (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
  }

  LinalgVectorLoweringOptions options;
  LinalgTransformationFilter filter;
};

/// Configurable pass to lower vector operations.
struct LinalgStrategyRemoveMarkersPass
    : public LinalgStrategyRemoveMarkersPassBase<
          LinalgStrategyRemoveMarkersPass> {

  void runOnFunction() override {
    auto funcOp = getFunction();
    if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
      return;
    funcOp.walk([](LinalgOp op) {
      op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
    });
  }
};
} // namespace

/// Create a LinalgStrategyTileAndFusePass.
std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgStrategyTileAndFusePass(StringRef opName,
                                          LinalgTilingAndFusionOptions options,
                                          LinalgTransformationFilter filter) {
  return std::make_unique<LinalgStrategyTileAndFusePass>(opName, options,
                                                         filter);
}

/// Create a LinalgStrategyTilePass.
std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgStrategyTilePass(StringRef opName, LinalgTilingOptions opt,
                                   LinalgTransformationFilter filter) {
  return std::make_unique<LinalgStrategyTilePass>(opName, opt, filter);
}

/// Create a LinalgStrategyPadPass.
std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgStrategyPadPass(StringRef opName, LinalgPaddingOptions opt,
                                  LinalgTransformationFilter filter) {
  return std::make_unique<LinalgStrategyPadPass>(opName, opt, filter);
}

/// Create a LinalgStrategyPromotePass.
std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgStrategyPromotePass(StringRef opName,
                                      LinalgPromotionOptions opt,
                                      LinalgTransformationFilter filter) {
  return std::make_unique<LinalgStrategyPromotePass>(opName, opt, filter);
}

/// Create a LinalgStrategyGeneralizePass.
std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgStrategyGeneralizePass(StringRef opName,
                                         LinalgTransformationFilter filter) {
  return std::make_unique<LinalgStrategyGeneralizePass>(opName, filter);
}

/// Create a LinalgStrategyDecomposePass.
// TODO: if/when we need finer control add an `opName` parameter.
std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgStrategyDecomposePass(LinalgTransformationFilter filter) {
  return std::make_unique<LinalgStrategyDecomposePass>(filter);
}

/// Create a LinalgStrategyInterchangePass.
std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange,
                                          LinalgTransformationFilter filter) {
  return std::make_unique<LinalgStrategyInterchangePass>(iteratorInterchange,
                                                         filter);
}

/// Create a LinalgStrategyVectorizePass.
std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgStrategyVectorizePass(
    StringRef opName, LinalgVectorizationOptions opt,
    LinalgTransformationFilter filter, bool padVectorize) {
  return std::make_unique<LinalgStrategyVectorizePass>(opName, opt, filter,
                                                       padVectorize);
}

/// Create a LinalgStrategyEnablePass.
std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgStrategyEnablePass(LinalgEnablingOptions opt,
                                     LinalgTransformationFilter filter) {
  return std::make_unique<LinalgStrategyEnablePass>(opt, filter);
}

/// Create a LinalgStrategyLowerVectorsPass.
std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgStrategyLowerVectorsPass(LinalgVectorLoweringOptions opt,
                                           LinalgTransformationFilter filter) {
  return std::make_unique<LinalgStrategyLowerVectorsPass>(opt, filter);
}

/// Create a LinalgStrategyRemoveMarkersPass.
std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgStrategyRemoveMarkersPass() {
  return std::make_unique<LinalgStrategyRemoveMarkersPass>();
}
