/*************************************************************************
 * Copyright (c) 2015-2021, NVIDIA CORPORATION. All rights reserved.
 * Modifications Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved.
 *
 * See LICENSE.txt for license information
 ************************************************************************/


#ifndef NCCL_REDUCE_KERNEL_H_
#define NCCL_REDUCE_KERNEL_H_

#include "common_kernel.h"
#include <cstring>
#include <limits>
#include <type_traits>

// Reduction functor that discards both operands and always returns 0.
// The constructor argument is accepted but never stored.
template<typename T>
struct FuncNull {
  __device__ FuncNull(uint64_t /*opArg*/ = 0) {}
  __device__ T operator()(const T /*a*/, const T /*b*/) const {
    return 0;
  }
};

// Generic addition functor: combines two values of type T with operator+.
// opArg is unused here (presumably kept so all reduction functors share one
// constructor signature).
template<typename T>
struct FuncSum {
  __device__ FuncSum(uint64_t /*opArg*/ = 0) {}
  __device__ T operator()(const T lhs, const T rhs) const {
    const T sum = lhs + rhs;
    return sum;
  }
};

// Generic multiplication functor: combines two values of type T with
// operator*. opArg is accepted but unused.
template<typename T>
struct FuncProd {
  __device__ FuncProd(uint64_t /*opArg*/ = 0) {}
  __device__ T operator()(const T lhs, const T rhs) const {
    const T product = lhs * rhs;
    return product;
  }
};

// Generic maximum functor. T only needs operator<. When neither operand is
// less than the other (ties or unordered values), the first operand wins.
template<typename T>
struct FuncMax {
  __device__ FuncMax(uint64_t /*opArg*/ = 0) {}
  __device__ T operator()(const T a, const T b) const {
    if (a < b) return b;
    return a;
  }
};

// Generic minimum functor. T only needs operator<. When the first operand is
// not strictly less than the second (ties or unordered values), the second
// operand is returned.
template<typename T>
struct FuncMin {
  __device__ FuncMin(uint64_t /*opArg*/ = 0) {}
  __device__ T operator()(const T a, const T b) const {
    if (a < b) return a;
    return b;
  }
};

// Default traits for functors whose pre- and post-reduction transforms are
// both the identity (the original note lists FuncSum, Prod, Min, Max).
template<typename Fn>
struct FuncTraits { // generic implementation for FuncSum,Prod,Min,Max
  static constexpr bool IsPreOpIdentity = true;
  static constexpr bool IsPostOpIdentity = true;

  // Both hooks ignore the functor and hand the value back unchanged.
  template<typename T>
  __device__ static T preOp(Fn /*fn*/, T value) { return value; }

  template<typename T>
  __device__ static T postOp(Fn /*fn*/, T value) { return value; }
};

// Byte-lane masks: NCCL_MASK0 keeps bytes 0 and 2, NCCL_MASK1 keeps bytes 1 and 3.
#define NCCL_MASK0 0x00ff00ff
#define NCCL_MASK1 0xff00ff00
// Lane-wise sum of four packed 8-bit values, each lane wrapping modulo 256.
// Even and odd lanes are summed separately. A carry out of one lane therefore
// lands in a zeroed neighbour byte (or off the top of the word) and is masked
// away. Two's-complement addition is sign-agnostic, so the same routine
// serves signed and unsigned 8-bit data.
static __device__ uint32_t addChar4(const uint32_t x, const uint32_t y) {
  const uint32_t evenLanes = (x & NCCL_MASK0) + (y & NCCL_MASK0);
  const uint32_t oddLanes  = (x & NCCL_MASK1) + (y & NCCL_MASK1);
  return (evenLanes & NCCL_MASK0) | (oddLanes & NCCL_MASK1);
}

// int8_t addition. Besides the scalar overload, a uint32_t overload adds four
// int8_t values packed into one 32-bit word, lane by lane.
template<>
struct FuncSum<int8_t> {
  __device__ FuncSum(uint64_t /*opArg*/ = 0) {}

  // Packed form: x and y each hold four int8_t lanes.
  __device__ uint32_t operator()(const uint32_t x, const uint32_t y) const {
#if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500)
    // sm_30 .. sm_4x: PTX SIMD video add over signed lanes; the third
    // instruction operand is fed a zero.
    int32_t lanes;
    int32_t zero = 0;
    asm("vadd4.s32.s32.s32 %0, %1, %2, %3;" : "=r"(lanes) : "r"(x), "r"(y), "r"(zero));
    return lanes;
#else
    // All other targets: portable mask-and-add helper.
    return addChar4(x, y);
#endif
  }

  // Scalar form.
  __device__ int8_t operator()(const int8_t x, const int8_t y) const {
    const int8_t total = x + y;
    return total;
  }
};
// uint8_t addition. Besides the scalar overload, a uint32_t overload adds four
// uint8_t values packed into one 32-bit word, lane by lane.
template<>
struct FuncSum<uint8_t> {
  __device__ FuncSum(uint64_t /*opArg*/ = 0) {}

  // Packed form: x and y each hold four uint8_t lanes.
  __device__ uint32_t operator()(const uint32_t x, const uint32_t y) const {
#if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500)
    // sm_30 .. sm_4x: PTX SIMD video add over unsigned lanes; the third
    // instruction operand is fed a zero.
    int32_t lanes;
    int32_t zero = 0;
    asm("vadd4.u32.u32.u32 %0, %1, %2, %3;" : "=r"(lanes) : "r"(x), "r"(y), "r"(zero));
    return lanes;
#else
    // All other targets: portable mask-and-add helper.
    return addChar4(x, y);
#endif
  }

  // Scalar form.
  __device__ uint8_t operator()(const uint8_t x, const uint8_t y) const {
    const uint8_t total = x + y;
    return total;
  }
};

// Lane-wise product of four packed 8-bit values. Each product is narrowed back
// into an 8-bit lane. The low byte of a product is the same under signed and
// unsigned interpretation, so one helper serves both int8_t and uint8_t.
static __device__ uint32_t mulChar4(const uint32_t x, const uint32_t y) {
  union converter { uint32_t storage; char4 a; };
  converter lhs, rhs, out;
  lhs.storage = x;
  rhs.storage = y;
  // Lanes are independent; fill from the most significant one down.
  out.a.w = lhs.a.w * rhs.a.w;
  out.a.z = lhs.a.z * rhs.a.z;
  out.a.y = lhs.a.y * rhs.a.y;
  out.a.x = lhs.a.x * rhs.a.x;
  return out.storage;
}

// int8_t multiplication: a packed overload (four lanes per uint32_t, via
// mulChar4) and a scalar overload.
template<>
struct FuncProd<int8_t> {
  __device__ FuncProd(uint64_t /*opArg*/ = 0) {}
  // Packed form.
  __device__ uint32_t operator()(const uint32_t packedA, const uint32_t packedB) const {
    return mulChar4(packedA, packedB);
  }
  // Scalar form.
  __device__ int8_t operator()(const int8_t a, const int8_t b) const {
    const int8_t product = a * b;
    return product;
  }
};
// uint8_t multiplication: a packed overload (four lanes per uint32_t, via
// mulChar4) and a scalar overload.
template<>
struct FuncProd<uint8_t> {
  __device__ FuncProd(uint64_t /*opArg*/ = 0) {}
  // Packed form.
  __device__ uint32_t operator()(const uint32_t packedA, const uint32_t packedB) const {
    return mulChar4(packedA, packedB);
  }
  // Scalar form.
  __device__ uint8_t operator()(const uint8_t a, const uint8_t b) const {
    const uint8_t product = a * b;
    return product;
  }
};

// int8_t maximum: a scalar overload plus a packed overload that takes the
// lane-wise maximum of four signed bytes held in a uint32_t.
template<>
struct FuncMax<int8_t> {
  __device__ FuncMax(uint64_t /*opArg*/ = 0) {}
  // Views a 32-bit word as four signed 8-bit lanes.
  union converter { uint32_t storage; char4 a; };

  __device__ uint32_t operator()(const uint32_t x, const uint32_t y) const {
#if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500)
    // sm_30 .. sm_4x: PTX SIMD video max over signed lanes.
    int32_t lanes;
    int32_t zero = 0;
    asm("vmax4.s32.s32.s32 %0, %1, %2, %3;" : "=r"(lanes) : "r"(x), "r"(y), "r"(zero));
    return lanes;
#else
    // Portable path: unpack both words, compare lane by lane, repack.
    converter lhs, rhs, res;
    lhs.storage = x;
    rhs.storage = y;
    res.a.w = (lhs.a.w < rhs.a.w) ? rhs.a.w : lhs.a.w;
    res.a.z = (lhs.a.z < rhs.a.z) ? rhs.a.z : lhs.a.z;
    res.a.y = (lhs.a.y < rhs.a.y) ? rhs.a.y : lhs.a.y;
    res.a.x = (lhs.a.x < rhs.a.x) ? rhs.a.x : lhs.a.x;
    return res.storage;
#endif
  }

  // Scalar form.
  __device__ int8_t operator()(const int8_t x, const int8_t y) const {
    if (x > y) return x;
    return y;
  }
};
// uint8_t maximum: a scalar overload plus a packed overload that takes the
// lane-wise maximum of four unsigned bytes held in a uint32_t.
template<>
struct FuncMax<uint8_t> {
  __device__ FuncMax(uint64_t /*opArg*/ = 0) {}
  // Views a 32-bit word as four unsigned 8-bit lanes.
  union converter { uint32_t storage; uchar4 a; };

  __device__ uint32_t operator()(const uint32_t x, const uint32_t y) const {
#if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500)
    // sm_30 .. sm_4x: PTX SIMD video max over unsigned lanes.
    int32_t lanes;
    int32_t zero = 0;
    asm("vmax4.u32.u32.u32 %0, %1, %2, %3;" : "=r"(lanes) : "r"(x), "r"(y), "r"(zero));
    return lanes;
#else
    // Portable path: unpack both words, compare lane by lane, repack.
    converter lhs, rhs, res;
    lhs.storage = x;
    rhs.storage = y;
    res.a.w = (lhs.a.w < rhs.a.w) ? rhs.a.w : lhs.a.w;
    res.a.z = (lhs.a.z < rhs.a.z) ? rhs.a.z : lhs.a.z;
    res.a.y = (lhs.a.y < rhs.a.y) ? rhs.a.y : lhs.a.y;
    res.a.x = (lhs.a.x < rhs.a.x) ? rhs.a.x : lhs.a.x;
    return res.storage;
#endif
  }

  // Scalar form.
  __device__ uint8_t operator()(const uint8_t x, const uint8_t y) const {
    if (x > y) return x;
    return y;
  }
};

// int8_t minimum: a scalar overload plus a packed overload that takes the
// lane-wise minimum of four signed bytes held in a uint32_t.
template<>
struct FuncMin<int8_t> {
  __device__ FuncMin(uint64_t /*opArg*/ = 0) {}
  // Views a 32-bit word as four signed 8-bit lanes.
  union converter { uint32_t storage; char4 a; };

  __device__ uint32_t operator()(const uint32_t x, const uint32_t y) const {
#if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500)
    // sm_30 .. sm_4x: PTX SIMD video min over signed lanes.
    int32_t lanes;
    int32_t zero = 0;
    asm("vmin4.s32.s32.s32 %0, %1, %2, %3;" : "=r"(lanes) : "r"(x), "r"(y), "r"(zero));
    return lanes;
#else
    // Portable path: unpack both words, compare lane by lane, repack.
    converter lhs, rhs, res;
    lhs.storage = x;
    rhs.storage = y;
    res.a.w = (lhs.a.w < rhs.a.w) ? lhs.a.w : rhs.a.w;
    res.a.z = (lhs.a.z < rhs.a.z) ? lhs.a.z : rhs.a.z;
    res.a.y = (lhs.a.y < rhs.a.y) ? lhs.a.y : rhs.a.y;
    res.a.x = (lhs.a.x < rhs.a.x) ? lhs.a.x : rhs.a.x;
    return res.storage;
#endif
  }

  // Scalar form.
  __device__ int8_t operator()(const int8_t x, const int8_t y) const {
    if (x < y) return x;
    return y;
  }
};
// uint8_t minimum: a scalar overload plus a packed overload that takes the
// lane-wise minimum of four unsigned bytes held in a uint32_t.
template<>
struct FuncMin<uint8_t> {
  __device__ FuncMin(uint64_t /*opArg*/ = 0) {}
  // Views a 32-bit word as four unsigned 8-bit lanes.
  union converter { uint32_t storage; uchar4 a; };

  __device__ uint32_t operator()(const uint32_t x, const uint32_t y) const {
#if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500)
    // sm_30 .. sm_4x: PTX SIMD video min over unsigned lanes.
    int32_t lanes;
    int32_t zero = 0;
    asm("vmin4.u32.u32.u32 %0, %1, %2, %3;" : "=r"(lanes) : "r"(x), "r"(y), "r"(zero));
    return lanes;
#else
    // Portable path: unpack both words, compare lane by lane, repack.
    converter lhs, rhs, res;
    lhs.storage = x;
    rhs.storage = y;
    res.a.w = (lhs.a.w < rhs.a.w) ? lhs.a.w : rhs.a.w;
    res.a.z = (lhs.a.z < rhs.a.z) ? lhs.a.z : rhs.a.z;
    res.a.y = (lhs.a.y < rhs.a.y) ? lhs.a.y : rhs.a.y;
    res.a.x = (lhs.a.x < rhs.a.x) ? lhs.a.x : rhs.a.x;
    return res.storage;
#endif
  }

  // Scalar form.
  __device__ uint8_t operator()(const uint8_t x, const uint8_t y) const {
    if (x < y) return x;
    return y;
  }
};

// half-precision addition for both half2 pairs and scalar half. Native half
// intrinsics are used on __CUDA_ARCH__ >= 530 (except 610). Other targets
// widen to float, add, and convert back.
template<>
struct FuncSum<half> {
  __device__ FuncSum(uint64_t /*opArg*/ = 0) {}

  __device__ half2 operator()(const half2 x, const half2 y) const {
#if __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610
    return __hadd2(x, y);
#else
    const float2 a = __half22float2(x);
    const float2 b = __half22float2(y);
    float2 sum;
    sum.x = a.x + b.x;
    sum.y = a.y + b.y;
    return __float22half2_rn(sum);
#endif
  }

  __device__ half operator()(const half x, const half y) const {
#if __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610
    return __hadd(x, y);
#else
    const float sum = __half2float(x) + __half2float(y);
    return __float2half(sum);
#endif
  }
};

#if defined(RCCL_BFLOAT16)
// bfloat16 addition: both operands are converted to float, added, and the
// result is converted back to rccl_bfloat16.
template<>
struct FuncSum<rccl_bfloat16> {
  __device__ FuncSum(uint64_t /*opArg*/ = 0) {}
  __device__ rccl_bfloat16 operator()(const rccl_bfloat16 x, const rccl_bfloat16 y) const {
    const float sum = static_cast<float>(x) + static_cast<float>(y);
    return static_cast<rccl_bfloat16>(sum);
  }
};
#endif

// half-precision multiplication for both half2 pairs and scalar half. Native
// half intrinsics are used on __CUDA_ARCH__ >= 530 (except 610). Other targets
// widen to float, multiply, and convert back.
template<>
struct FuncProd<half> {
  __device__ FuncProd(uint64_t /*opArg*/ = 0) {}

  __device__ half2 operator()(const half2 x, const half2 y) const {
#if __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610
    return __hmul2(x, y);
#else
    const float2 a = __half22float2(x);
    const float2 b = __half22float2(y);
    float2 product;
    product.x = a.x * b.x;
    product.y = a.y * b.y;
    return __float22half2_rn(product);
#endif
  }

  __device__ half operator()(const half x, const half y) const {
#if __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610
    return __hmul(x, y);
#else
    const float product = __half2float(x) * __half2float(y);
    return __float2half(product);
#endif
  }
};

#if defined(RCCL_BFLOAT16)
// bfloat16 multiplication: both operands are converted to float, multiplied,
// and the result is converted back to rccl_bfloat16.
template<>
struct FuncProd<rccl_bfloat16> {
  __device__ FuncProd(uint64_t /*opArg*/ = 0) {}
  __device__ rccl_bfloat16 operator()(const rccl_bfloat16 x, const rccl_bfloat16 y) const {
    const float product = static_cast<float>(x) * static_cast<float>(y);
    return static_cast<rccl_bfloat16>(product);
  }
};
#endif

// half-precision maximum. Operands are always widened to float and compared
// with fmaxf, then converted back to half.
template<>
struct FuncMax<half> {
  __device__ FuncMax(uint64_t /*opArg*/ = 0) {}

  __device__ half2 operator()(const half2 x, const half2 y) const {
    const float2 a = __half22float2(x);
    const float2 b = __half22float2(y);
    float2 larger;
    larger.x = fmaxf(a.x, b.x);
    larger.y = fmaxf(a.y, b.y);
    return __float22half2_rn(larger);
  }

  __device__ half operator()(const half x, const half y) const {
    return __float2half(fmaxf(__half2float(x), __half2float(y)));
  }
};

#if defined(RCCL_BFLOAT16)
// bfloat16 maximum: compares in float and returns one of the original
// operands unchanged (x when x is not less than y).
template<>
struct FuncMax<rccl_bfloat16> {
  __device__ FuncMax(uint64_t /*opArg*/ = 0) {}
  __device__ rccl_bfloat16 operator()(const rccl_bfloat16 x, const rccl_bfloat16 y) const {
    const float fx = static_cast<float>(x);
    const float fy = static_cast<float>(y);
    return (fx < fy) ? y : x;
  }
};
#endif

// half-precision minimum. Operands are always widened to float and compared
// with fminf, then converted back to half.
template<>
struct FuncMin<half> {
  __device__ FuncMin(uint64_t /*opArg*/ = 0) {}

  __device__ half2 operator()(const half2 x, const half2 y) const {
    const float2 a = __half22float2(x);
    const float2 b = __half22float2(y);
    float2 smaller;
    smaller.x = fminf(a.x, b.x);
    smaller.y = fminf(a.y, b.y);
    return __float22half2_rn(smaller);
  }

  __device__ half operator()(const half x, const half y) const {
    return __float2half(fminf(__half2float(x), __half2float(y)));
  }
};

#if defined(RCCL_BFLOAT16)
// bfloat16 minimum: compares in float and returns one of the original
// operands unchanged (y unless x is strictly less).
template<>
struct FuncMin<rccl_bfloat16> {
  __device__ FuncMin(uint64_t /*opArg*/ = 0) {}
  __device__ rccl_bfloat16 operator()(const rccl_bfloat16 x, const rccl_bfloat16 y) const {
    const float fx = static_cast<float>(x);
    const float fy = static_cast<float>(y);
    return (fx < fy) ? x : y;
  }
};
#endif

// float maximum/minimum, delegating to fmaxf/fminf (so NaN handling follows
// those functions rather than operator<).
template<>
struct FuncMax<float> {
  __device__ FuncMax(uint64_t /*opArg*/ = 0) {}
  __device__ float operator()(float a, float b) const { return fmaxf(a, b); }
};
template<>
struct FuncMin<float> {
  __device__ FuncMin(uint64_t /*opArg*/ = 0) {}
  __device__ float operator()(float a, float b) const { return fminf(a, b); }
};

// double maximum/minimum, delegating to fmax/fmin (so NaN handling follows
// those functions rather than operator<).
template<>
struct FuncMax<double> {
  __device__ FuncMax(uint64_t /*opArg*/ = 0) {}
  __device__ double operator()(double a, double b) const { return fmax(a, b); }
};
template<>
struct FuncMin<double> {
  __device__ FuncMin(uint64_t /*opArg*/ = 0) {}
  __device__ double operator()(double a, double b) const { return fmin(a, b); }
};

// Compile-time flag that is true for the floating-point element types
// handled here (double, float, optionally rccl_bfloat16, half) and false for
// every other type. FuncSumPostDiv keys its specialization on it.
template<typename T>
struct IsFloatingPoint: std::integral_constant<bool, false> {};

template<>
struct IsFloatingPoint<double>: std::integral_constant<bool, true> {};
template<>
struct IsFloatingPoint<float>: std::integral_constant<bool, true> {};
#if defined(RCCL_BFLOAT16)
template<>
struct IsFloatingPoint<rccl_bfloat16>: std::integral_constant<bool, true> {};
#endif
template<>
struct IsFloatingPoint<half>: std::integral_constant<bool, true> {};

// Sum, then divide the reduced result by n. Per the static_assert message this
// implements ncclAvg for integral element types. Instantiating it for a
// floating-point T is a compile-time error.
template<typename T, bool IsFloating=IsFloatingPoint<T>::value>
struct FuncSumPostDiv;

template<typename T>
struct FuncSumPostDiv<T, /*IsFloating=*/false>: FuncSum<T> {
  static constexpr bool IsPreOpIdentity = true;   // inputs enter the sum untouched
  static constexpr bool IsPostOpIdentity = false; // the sum is divided by n afterwards
  int n; // divisor, taken from opArg (narrowed to int)
  __device__ FuncSumPostDiv(uint64_t opArg): n(opArg) {}
  // operator() is inherited from FuncSum<T>.
  __device__ T preOp(T v) const { return v; }
  __device__ T postOp(T v) const {
    const auto quotient = v / n;
    return T(quotient);
  }
};

template<typename T>
struct FuncSumPostDiv<T, /*IsFloating=*/true> {
  // Condition depends on T so it only fires when this specialization is used.
  static_assert(sizeof(T)!=sizeof(T), "FuncSumPostDiv is only for implementing ncclAvg on integral types.");
};

// Pre-multiply-then-sum functor for integral T (floating-point element types
// have explicit specializations). Each input is multiplied by `scale` before
// the reduction; the reduced result is passed through unchanged.
template<typename T>
struct FuncPreMulSum: FuncSum<T> { // integral T since all floats are specialized below
  static constexpr bool IsPreOpIdentity = false;
  static constexpr bool IsPostOpIdentity = true;
  T scale;
  // opArg carries the scale's bit pattern in its first sizeof(T) bytes (the
  // low-order bytes on a little-endian target). Copy them out with memcpy
  // rather than dereferencing a reinterpreted pointer, which would break C++
  // strict-aliasing rules.
  __device__ FuncPreMulSum(uint64_t opArg) {
    static_assert(sizeof(T) <= sizeof(uint64_t), "FuncPreMulSum scale must fit in the 64-bit opArg");
    memcpy(&scale, &opArg, sizeof(T));
  }
  // inherits FuncSum::operator()
  __device__ T preOp(T x) const { return x*scale; }
  __device__ T postOp(T x) const { return x; }
};

// Pre-multiply-then-sum for double. With the flags below, inputs are scaled
// by `scale` before the reduction and the result is passed through unchanged.
template<>
struct FuncPreMulSum<double>: FuncSum<double> {
  static constexpr bool IsPreOpIdentity = false;
  static constexpr bool IsPostOpIdentity = true;
  double scale;
  __device__ FuncPreMulSum(uint64_t opArg) {
    // opArg holds the double's bit pattern. Copy the bytes instead of
    // dereferencing a reinterpreted pointer (strict-aliasing violation).
    memcpy(&scale, &opArg, sizeof(scale));
  }
  // inherits FuncSum::operator()
  __device__ double preOp(double x) const {
    return IsPreOpIdentity ? x : x*scale;
  }
  __device__ double postOp(double x) const {
    return IsPostOpIdentity ? x : x*scale;
  }
};

// Pre-multiply-then-sum for float. With the flags below, inputs are scaled by
// `scale` before the reduction and the result is passed through unchanged.
template<>
struct FuncPreMulSum<float>: FuncSum<float> {
  static constexpr bool IsPreOpIdentity = false;
  static constexpr bool IsPostOpIdentity = true;
  float scale;
  __device__ FuncPreMulSum(uint64_t opArg) {
    // The float's bit pattern occupies the first 4 bytes of opArg. Copy them
    // instead of dereferencing a reinterpreted pointer (strict-aliasing
    // violation).
    memcpy(&scale, &opArg, sizeof(scale));
  }
  // inherits FuncSum::operator()
  __device__ float preOp(float x) const {
    return IsPreOpIdentity ? x : x*scale;
  }
  __device__ float postOp(float x) const {
    return IsPostOpIdentity ? x : x*scale;
  }
};

// Pre-multiply-then-sum for half. opArg carries a half-precision scale factor
// in its first two bytes.
template<>
struct FuncPreMulSum<half>: FuncSum<half> {
  // These flags choose where the scale is applied: before the reduction
  // (IsPreOpIdentity = false), after it (IsPostOpIdentity = false), or both.
  // If both are true, the scale is never applied. If both are false, preOp and
  // postOp each multiply by the same `scale`, so opArg would have to carry the
  // square root of the intended factor. An improvement would be to make this a
  // build-time setting passed in through preprocessor definitions.
  static constexpr bool IsPreOpIdentity = false;
  static constexpr bool IsPostOpIdentity = true;

#if __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610
  // Native half arithmetic: keep the scale duplicated in both half2 lanes.
  half2 scale;
  __device__ FuncPreMulSum(uint64_t opArg) {
    // Copy the bytes instead of dereferencing a reinterpreted pointer
    // (strict-aliasing violation).
    half s;
    memcpy(&s, &opArg, sizeof(s));
    scale.x = s;
    scale.y = s;
  }
  // inherits FuncSum::operator()
  __device__ half preOp(half x) const {
    return IsPreOpIdentity ? x : __hmul(x, scale.x);
  }
  __device__ half2 preOp(half2 x) const {
    return IsPreOpIdentity ? x : __hmul2(x, scale);
  }
  __device__ half postOp(half x) const {
    return IsPostOpIdentity ? x : __hmul(x, scale.x);
  }
  __device__ half2 postOp(half2 x) const {
    return IsPostOpIdentity ? x : __hmul2(x, scale);
  }
#else
  // Fallback: keep the scale in float and do the multiply in float.
  float scale;
  __device__ FuncPreMulSum(uint64_t opArg) {
    // Copy the bytes instead of dereferencing a reinterpreted pointer
    // (strict-aliasing violation).
    half s;
    memcpy(&s, &opArg, sizeof(s));
    scale = __half2float(s);
  }
  // inherits FuncSum::operator()
  __device__ half preOp(half x) const {
    return IsPreOpIdentity ? x : __float2half(__half2float(x)*scale);
  }
  __device__ half2 preOp(half2 x) const {
    if (IsPreOpIdentity)
      return x;
    else {
      float2 a = __half22float2(x);
      a.x *= scale;
      a.y *= scale;
      return __float22half2_rn(a);
    }
  }
  __device__ half postOp(half x) const {
    return IsPostOpIdentity ? x : __float2half(__half2float(x)*scale);
  }
  __device__ half2 postOp(half2 x) const {
    if (IsPostOpIdentity)
      return x;
    else {
      float2 a = __half22float2(x);
      a.x *= scale;
      a.y *= scale;
      return __float22half2_rn(a);
    }
  }
#endif
};

#if defined(RCCL_BFLOAT16)
// Pre-multiply-then-sum for bfloat16. opArg carries an rccl_bfloat16 scale
// factor, which is widened to float once at construction.
template<>
struct FuncPreMulSum<rccl_bfloat16>: FuncSum<rccl_bfloat16> {
  // These flags choose where the scale is applied: before the reduction
  // (IsPreOpIdentity = false), after it (IsPostOpIdentity = false), or both.
  // If both are true, the scale is never applied. If both are false, preOp and
  // postOp each multiply by the same `scale`, so opArg would have to carry the
  // square root of the intended factor. An improvement would be to make this a
  // build-time setting passed in through preprocessor definitions.
  static constexpr bool IsPreOpIdentity = false;
  static constexpr bool IsPostOpIdentity = true;

  float scale;
  __device__ FuncPreMulSum(uint64_t opArg) {
    // Copy the bytes instead of dereferencing a reinterpreted pointer
    // (strict-aliasing violation), then widen to float.
    rccl_bfloat16 s;
    memcpy(&s, &opArg, sizeof(s));
    scale = (float)s;
  }
  // inherits FuncSum::operator()
  __device__ rccl_bfloat16 preOp(rccl_bfloat16 x) const {
    return IsPreOpIdentity ? x : (rccl_bfloat16)((float)x*scale);
  }
  __device__ rccl_bfloat16 postOp(rccl_bfloat16 x) const {
    return IsPostOpIdentity ? x : (rccl_bfloat16)((float)x*scale);
  }
};
#endif

// Traits adapter for FuncPreMulSum: re-exports its identity flags and
// forwards preOp/postOp to the functor instance, which holds the scale.
template<typename T>
struct FuncTraits<FuncPreMulSum<T>> {
  static constexpr bool IsPreOpIdentity = FuncPreMulSum<T>::IsPreOpIdentity;
  static constexpr bool IsPostOpIdentity = FuncPreMulSum<T>::IsPostOpIdentity;

  template<typename U>
  __device__ static U preOp(FuncPreMulSum<T> op, U value) { return op.preOp(value); }

  template<typename U>
  __device__ static U postOp(FuncPreMulSum<T> op, U value) { return op.postOp(value); }
};
// Traits adapter for FuncSumPostDiv: re-exports its identity flags and
// forwards preOp/postOp to the functor instance, which holds the divisor.
template<typename T>
struct FuncTraits<FuncSumPostDiv<T>> {
  static constexpr bool IsPreOpIdentity = FuncSumPostDiv<T>::IsPreOpIdentity;
  static constexpr bool IsPostOpIdentity = FuncSumPostDiv<T>::IsPostOpIdentity;

  template<typename U>
  __device__ static U preOp(FuncSumPostDiv<T> op, U value) { return op.preOp(value); }

  template<typename U>
  __device__ static U postOp(FuncSumPostDiv<T> op, U value) { return op.postOp(value); }
};
#endif // NCCL_REDUCE_KERNEL_H_
