//===- QuantUtils.cpp -----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains TOSA numerical support functions and quantization
// attribute builders.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"

using namespace mlir;
using namespace mlir::tosa;

/// Decomposes `scale` into a fixed-point multiplier and a right shift for
/// 16-bit scaling, such that scale ~= multiplier * 2^(-shift).
///
/// The mantissa returned by std::frexp has a magnitude in [0.5, 1.0); it is
/// scaled by 2^15 and rounded to form the multiplier.
///
/// \param scale       floating-point scale to decompose.
/// \param multiplier  receives the rounded, 2^15-scaled mantissa.
/// \param shift       receives the right shift, capped at 63.
static void computeMultiplierAndShiftTosaScale16(double scale,
                                                 int32_t &multiplier,
                                                 int32_t &shift) {

  const double mantissa = std::frexp(scale, &shift);
  auto shiftedM = std::round(mantissa * (int64_t(1) << 15));

  // Can't be greater than 1.0.
  assert(shiftedM <= (int64_t(1) << 15) &&
         "Shifted mantissa exceeds 16 signed bits");

  // Rounding may have pushed the mantissa up to exactly 1.0; renormalize it
  // back to 0.5 and compensate in the exponent.
  if (shiftedM == (int64_t(1) << 15)) {
    shiftedM /= 2;
    shift++;
  }

  // TOSA expects right shift to be positive and embed (1 << 15) into right
  // shift bits.
  shift = (-shift) + 15;

  assert(shiftedM <= std::numeric_limits<int32_t>::max() &&
         "Shifted mantissa exceeds 32-bit signed output type");

  multiplier = static_cast<int32_t>(shiftedM);

  // Shifting tops out at 63 bits. Right shift to make 63 bits the max.
  if (shift > 63) {
    // Fold the excess shift into the multiplier. The amount is clamped to 31
    // because shifting an int32_t by 32 or more is undefined behaviour; a
    // shift of 31 already reduces any positive multiplier to zero.
    multiplier = multiplier >> std::min<int32_t>(31, shift - 63);
    shift = 63;
  }
}

/// Decomposes `scale` into a fixed-point multiplier and a right shift for
/// 32-bit scaling, such that scale ~= multiplier * 2^(-shift).
///
/// The mantissa returned by std::frexp has a magnitude in [0.5, 1.0); it is
/// scaled by 2^31 and rounded to form the multiplier.
///
/// \param scale       floating-point scale to decompose.
/// \param multiplier  receives the rounded, 2^31-scaled mantissa.
/// \param shift       receives the right shift, capped at 63.
static void computeMultiplierAndShiftTosaScale32(double scale,
                                                 int32_t &multiplier,
                                                 int32_t &shift) {

  const double mantissa = std::frexp(scale, &shift);
  auto shiftedM = std::round(mantissa * (int64_t(1) << 31));

  // Can't be greater than 1.0.
  assert(shiftedM <= (int64_t(1) << 31) &&
         "Shifted mantissa exceeds 32 signed bits");

  // Rounding may have pushed the mantissa up to exactly 1.0; renormalize it
  // back to 0.5 and compensate in the exponent.
  if (shiftedM == (int64_t(1) << 31)) {
    shiftedM /= 2;
    shift++;
  }

  // TOSA expects right shift to be positive, and embed (1 << 31) into right
  // shift bits.
  shift = (-shift) + 31;

  assert(shiftedM <= std::numeric_limits<int32_t>::max() &&
         "Shifted mantissa exceeds 32-bit signed output type");

  multiplier = static_cast<int32_t>(shiftedM);

  // Shifting tops out at 63 bits. Right shift to make 63 bits the max.
  if (shift > 63) {
    // Fold the excess shift into the multiplier. The amount is clamped to 31
    // because shifting an int32_t by 32 or more is undefined behaviour; a
    // shift of 31 already reduces any positive multiplier to zero.
    multiplier = multiplier >> std::min<int32_t>(31, shift - 63);
    shift = 63;
  }
}

/// Converts a floating-point scale into a quantized multiplier/shift pair,
/// dispatching on the requested scale width (16 or 32 bits). Any other width
/// trips an assertion.
void mlir::tosa::computeMultiplierAndShift(double scale, int32_t &multiplier,
                                           int32_t &shift, int32_t scaleWidth) {

  if (scaleWidth == 16) {
    computeMultiplierAndShiftTosaScale16(scale, multiplier, shift);
    return;
  }

  if (scaleWidth == 32) {
    computeMultiplierAndShiftTosaScale32(scale, multiplier, shift);
    return;
  }

  // Only the 16-bit and 32-bit regimes are handled.
  assert(0 && "Unsupported Tosa quantized_scale regime specified!");
}

/// Helper macros that fetch the element type of a shaped type and `dyn_cast`
/// it to a quantized type. Callers in this file test the result before using
/// it, since the cast does not succeed for non-quantized element types.
///
/// GET_UQTYPE: casts the element type to quant::UniformQuantizedType.
#define GET_UQTYPE(input_type)                                                 \
  ((input_type).getElementType().dyn_cast<quant::UniformQuantizedType>())
/// GET_QTYPE: casts the element type to the more general quant::QuantizedType.
#define GET_QTYPE(input_type)                                                  \
  ((input_type).getElementType().dyn_cast<quant::QuantizedType>())

/// Builds a ConvOpQuantizationAttr holding the input and weight zero points.
/// Called from ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilder.
///
/// Returns a null attribute when either value is not a shaped type or when
/// the input is not uniformly quantized. For per-axis quantized weights the
/// zero point of the first axis entry is used.
ConvOpQuantizationAttr
mlir::tosa::buildConvOpQuantizationAttr(OpBuilder &builder, Value input,
                                        Value weight) {

  auto inputShaped = input.getType().dyn_cast<ShapedType>();
  auto weightShaped = weight.getType().dyn_cast<ShapedType>();
  if (!(inputShaped && weightShaped))
    return nullptr;

  auto inputUniformQ =
      inputShaped.getElementType().dyn_cast<quant::UniformQuantizedType>();
  auto weightUniformQ =
      weightShaped.getElementType().dyn_cast<quant::UniformQuantizedType>();
  auto weightAxisQ = weightShaped.getElementType()
                         .dyn_cast<quant::UniformQuantizedPerAxisType>();

  const bool inputIsQuantized = static_cast<bool>(inputUniformQ);
  const bool weightIsPerTensor = static_cast<bool>(weightUniformQ);
  const bool weightIsPerAxis = static_cast<bool>(weightAxisQ);

  // The two weight quantization schemes are mutually exclusive.
  assert(!(weightIsPerTensor && weightIsPerAxis) &&
         "Weights must be either per-tensor or per-axis quantized");

  // Quantization must be applied to both operands or to neither.
  assert(inputIsQuantized == (weightIsPerTensor || weightIsPerAxis) &&
         "Inputs and weights must be all quantized or all not quantized");

  if (!inputIsQuantized)
    return nullptr;

  const int64_t inputZeroPoint = inputUniformQ.getZeroPoint();
  const int64_t weightZeroPoint =
      weightIsPerTensor ? weightUniformQ.getZeroPoint()
      : weightIsPerAxis ? weightAxisQ.getZeroPoints().front()
                        : 0;

  return tosa::ConvOpQuantizationAttr::get(
      builder.getI32IntegerAttr(inputZeroPoint),
      builder.getI32IntegerAttr(weightZeroPoint), builder.getContext());
}

/// Builds a MatMulOpQuantizationAttr holding the zero points of operands
/// `a` and `b`. Called from MatMulOpQuantInfoBuilder.
///
/// Returns a null attribute when either operand is not a shaped type or when
/// operand `a` is not uniformly quantized.
MatMulOpQuantizationAttr
mlir::tosa::buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a,
                                          Value b) {

  auto lhsShaped = a.getType().dyn_cast<ShapedType>();
  auto rhsShaped = b.getType().dyn_cast<ShapedType>();
  if (!(lhsShaped && rhsShaped))
    return nullptr;

  auto lhsUniformQ =
      lhsShaped.getElementType().dyn_cast<quant::UniformQuantizedType>();
  auto rhsUniformQ =
      rhsShaped.getElementType().dyn_cast<quant::UniformQuantizedType>();

  // Quantization must be applied to both operands or to neither.
  assert(static_cast<bool>(lhsUniformQ) == static_cast<bool>(rhsUniformQ) &&
         "Matmul operands must be all quantized or all not quantized");

  if (!lhsUniformQ)
    return nullptr;

  const int64_t lhsZeroPoint = lhsUniformQ.getZeroPoint();
  const int64_t rhsZeroPoint = rhsUniformQ.getZeroPoint();

  return tosa::MatMulOpQuantizationAttr::get(
      builder.getI32IntegerAttr(lhsZeroPoint),
      builder.getI32IntegerAttr(rhsZeroPoint), builder.getContext());
}

/// Builds a UnaryOpQuantizationAttr holding the input and output zero points.
/// Called from UnaryOpQuantInfoBuilder.
///
/// Returns a null attribute when the input or output type is not a shaped
/// type or when the input is not uniformly quantized.
UnaryOpQuantizationAttr
mlir::tosa::buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input,
                                         Type outputRawType) {

  auto inputShaped = input.getType().dyn_cast<ShapedType>();
  auto outputShaped = outputRawType.dyn_cast<ShapedType>();
  if (!(inputShaped && outputShaped))
    return nullptr;

  auto inputUniformQ =
      inputShaped.getElementType().dyn_cast<quant::UniformQuantizedType>();
  auto outputUniformQ =
      outputShaped.getElementType().dyn_cast<quant::UniformQuantizedType>();

  // Quantization must be applied to both sides or to neither.
  assert(static_cast<bool>(inputUniformQ) ==
             static_cast<bool>(outputUniformQ) &&
         "Unary inputs/outputs must be all quantized or all not quantized");

  if (!inputUniformQ)
    return nullptr;

  const int64_t inputZeroPoint = inputUniformQ.getZeroPoint();
  const int64_t outputZeroPoint = outputUniformQ.getZeroPoint();

  return tosa::UnaryOpQuantizationAttr::get(
      builder.getI32IntegerAttr(inputZeroPoint),
      builder.getI32IntegerAttr(outputZeroPoint), builder.getContext());
}

/// Builds a PadOpQuantizationAttr holding the input zero point. Called from
/// PadOpQuantInfoBuilder.
///
/// Returns a null attribute when the input is not a shaped type or is not
/// uniformly quantized.
PadOpQuantizationAttr mlir::tosa::buildPadOpQuantizationAttr(OpBuilder &builder,
                                                             Value input) {

  auto inputShaped = input.getType().dyn_cast<ShapedType>();
  if (!inputShaped)
    return nullptr;

  auto inputUniformQ =
      inputShaped.getElementType().dyn_cast<quant::UniformQuantizedType>();
  if (!inputUniformQ)
    return nullptr;

  const int64_t inputZeroPoint = inputUniformQ.getZeroPoint();

  return tosa::PadOpQuantizationAttr::get(
      builder.getI32IntegerAttr(inputZeroPoint), builder.getContext());
}

/// Computes the result type of a quantized ConvOp: the shape of `outputType`
/// combined with an accumulator element type chosen from the operand storage
/// widths (48-bit integer for 16-bit input with 8-bit weights, otherwise
/// 32-bit integer). Used by the builder when dealing with quantized content.
Type mlir::tosa::buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType,
                                           Value input, Value weight) {

  auto inputShaped = input.getType().dyn_cast<ShapedType>();
  auto weightShaped = weight.getType().dyn_cast<ShapedType>();

  assert(inputShaped && weightShaped &&
         "Could not extract input or weight tensors from Conv op");

  auto inputQuant =
      inputShaped.getElementType().dyn_cast<quant::QuantizedType>();
  auto weightQuant =
      weightShaped.getElementType().dyn_cast<quant::QuantizedType>();

  assert(inputQuant && weightQuant &&
         "Could not extract input or weight tensor types from Conv op");

  const unsigned inputWidth = inputQuant.getStorageTypeIntegralWidth();
  const unsigned weightWidth = weightQuant.getStorageTypeIntegralWidth();

  auto resultShaped = outputType.dyn_cast<ShapedType>();
  assert(resultShaped &&
         "Could not extract output shape type from Conv op");

  // 16-bit activations with 8-bit weights accumulate into 48 bits; every
  // other combination accumulates into 32 bits.
  const bool needsWideAccumulator = inputWidth == 16 && weightWidth == 8;
  IntegerType accumulatorType = needsWideAccumulator
                                    ? builder.getIntegerType(48)
                                    : builder.getI32Type();

  return resultShaped.clone(accumulatorType);
}

/// Builds a TOSA quantized type from min/max values.
///
/// `minAttr`/`maxAttr` are either both DenseFPElementsAttr (one min/max pair
/// per entry) or both FloatAttr (a single pair). A single pair produces a
/// per-tensor quantized type; several pairs produce a per-axis quantized type
/// along `filterQuantDim`, which must be a valid dimension of `inputDType`.
///
/// Returns a null Type when the attributes are of mismatched or unsupported
/// kinds, when their element counts differ or are zero, when `inputDType` is
/// not shaped or `filterQuantDim` is out of range in the per-axis case, or
/// when no quantized type could be constructed.
Type mlir::tosa::buildQTypeFromMinMax(OpBuilder builder, Type inputDType,
                                      Attribute minAttr, Attribute maxAttr,
                                      IntegerAttr quantBits, int filterQuantDim,
                                      bool isSigned, BoolAttr narrowRange) {

  auto convfunc =
      quant::ExpressedToQuantizedConverter::forInputType(inputDType);

  auto minElems = minAttr.dyn_cast<DenseFPElementsAttr>();
  auto maxElems = maxAttr.dyn_cast<DenseFPElementsAttr>();

  SmallVector<double, 2> min, max;

  if (minElems || maxElems) {
    // Both bounds must be elements attributes: dereferencing the one that
    // failed the cast would be a null access.
    if (!minElems || !maxElems)
      return {};
    // Must have the same number of elements.
    if (minElems.getNumElements() != maxElems.getNumElements())
      return {};
    min.reserve(minElems.getNumElements());
    max.reserve(maxElems.getNumElements());
    for (auto i : minElems)
      min.push_back(FloatAttr::getValueAsDouble(i));
    for (auto i : maxElems)
      max.push_back(FloatAttr::getValueAsDouble(i));
  } else { // Just a single FP value.
    auto minVal = minAttr.dyn_cast<FloatAttr>();
    auto maxVal = maxAttr.dyn_cast<FloatAttr>();
    if (!minVal || !maxVal)
      return {};
    min.push_back(minVal.getValueAsDouble());
    max.push_back(maxVal.getValueAsDouble());
  }

  // Both vectors have the same length here; nothing to build from zero pairs.
  if (min.empty())
    return {};

  quant::QuantizedType retType;

  if (min.size() == 1) { // Per-tensor quantization with one min/max pair.
    retType = quant::fakeQuantAttrsToType(
        builder.getUnknownLoc(), quantBits.getInt(), min[0], max[0],
        narrowRange.getValue(), convfunc.expressedType, isSigned);
  } else { // Per-axis quant on filterQuantDim.
    auto shape = inputDType.dyn_cast<ShapedType>();
    if (!shape)
      return {};
    if ((filterQuantDim) >= 0 && (shape.getRank() > filterQuantDim)) {
      // Pass every min/max pair: handing over only min[0]/max[0] would
      // silently build a per-axis type with a single entry.
      retType = quant::fakeQuantAttrsToType(
          builder.getUnknownLoc(), quantBits.getInt(), filterQuantDim, min,
          max, narrowRange.getValue(), convfunc.expressedType, isSigned);
    }
  }

  if (!retType)
    return {};

  return convfunc.convert(retType);
}

/// Builds a TypeAttr wrapping the TOSA quantized type derived from min/max
/// values. All arguments are forwarded unchanged to buildQTypeFromMinMax.
///
/// Returns a null TypeAttr when buildQTypeFromMinMax yields a null type.
TypeAttr
mlir::tosa::buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDtype,
                                     Attribute minAttr, Attribute maxAttr,
                                     IntegerAttr quantBits, int filterQuantDim,
                                     bool isSigned, BoolAttr narrowRange) {

  Type quantizedType =
      buildQTypeFromMinMax(builder, inputDtype, minAttr, maxAttr, quantBits,
                           filterQuantDim, isSigned, narrowRange);

  // buildQTypeFromMinMax signals failure with a null type; do not hand that
  // to TypeAttr::get, propagate the failure as a null attribute instead.
  if (!quantizedType)
    return {};

  return TypeAttr::get(quantizedType);
}
