//===- AMXDialect.cpp - MLIR AMX ops implementation -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the AMX dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/AMX/AMXDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TypeUtilities.h"

using namespace mlir;

#include "mlir/Dialect/AMX/AMXDialect.cpp.inc"

void amx::AMXDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/AMX/AMX.cpp.inc"
      >();
}

/// Verify that AMX supports the implied tile shape.
static LogicalResult verifyTileSize(Operation *op, VectorType tp) {
  const unsigned kMaxRows = 16;
  const unsigned kBitsPerRow = 64 * 8;
  unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth();
  if (tp.getDimSize(0) > kMaxRows)
    return op->emitOpError("bad row height: ") << tp.getDimSize(0);
  if (col > kBitsPerRow || col & 0x1f)
    return op->emitOpError("bad column width: ") << (col >> 3);
  return success();
}

/// Verify that AMX supports the multiplication.
static LogicalResult verifyMultShape(Operation *op, VectorType atp,
                                     VectorType btp, VectorType ctp,
                                     unsigned scale) {
  unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale;
  unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale;
  unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1);
  if (cm != am || cn != bn || ak != bk)
    return op->emitOpError("bad mult shape: ")
           << cm << " x " << cn << " x " << ak;
  return success();
}

static LogicalResult verify(amx::TileZeroOp op) {
  return verifyTileSize(op, op.getVectorType());
}

static LogicalResult verify(amx::TileLoadOp op) {
  unsigned rank = op.getMemRefType().getRank();
  if (llvm::size(op.indices()) != rank)
    return op.emitOpError("requires ") << rank << " indices";
  return verifyTileSize(op, op.getVectorType());
}

static LogicalResult verify(amx::TileStoreOp op) {
  unsigned rank = op.getMemRefType().getRank();
  if (llvm::size(op.indices()) != rank)
    return op.emitOpError("requires ") << rank << " indices";
  return verifyTileSize(op, op.getVectorType());
}

static LogicalResult verify(amx::TileMulFOp op) {
  VectorType aType = op.getLhsVectorType();
  VectorType bType = op.getRhsVectorType();
  VectorType cType = op.getVectorType();
  if (failed(verifyTileSize(op, aType)) || failed(verifyTileSize(op, bType)) ||
      failed(verifyTileSize(op, cType)) ||
      failed(verifyMultShape(op, aType, bType, cType, 1)))
    return failure();
  Type ta = aType.getElementType();
  Type tb = bType.getElementType();
  Type tc = cType.getElementType();
  if (!ta.isBF16() || !tb.isBF16() || !tc.isF32())
    return op.emitOpError("unsupported type combination");
  return success();
}

static LogicalResult verify(amx::TileMulIOp op) {
  VectorType aType = op.getLhsVectorType();
  VectorType bType = op.getRhsVectorType();
  VectorType cType = op.getVectorType();
  if (failed(verifyTileSize(op, aType)) || failed(verifyTileSize(op, bType)) ||
      failed(verifyTileSize(op, cType)) ||
      failed(verifyMultShape(op, aType, bType, cType, 2)))
    return failure();
  Type ta = aType.getElementType();
  Type tb = bType.getElementType();
  Type tc = cType.getElementType();
  if (!ta.isInteger(8) || !tb.isInteger(8) || !tc.isInteger(32))
    return op.emitOpError("unsupported type combination");
  return success();
}

#define GET_OP_CLASSES
#include "mlir/Dialect/AMX/AMX.cpp.inc"
