#
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#

# Write the banner that opens the generated C file: the NVIDIA copyright
# notice, a "do not edit" warning, and the #include lines the generated
# wrappers use.
#
# The text is accumulated line by line in the local `hdr` (declared as an
# extra parameter, the usual awk idiom for a function-local variable) and
# written with a single print, so the output record separator is emitted
# exactly once, after the final blank line.
function print_hdrs(    hdr)
{
  hdr = "/*\n"
  hdr = hdr " *     Copyright (c) 2018-2019, NVIDIA CORPORATION.  All rights reserved.\n"
  hdr = hdr " *\n"
  hdr = hdr " * NVIDIA CORPORATION and its licensors retain all intellectual property\n"
  hdr = hdr " * and proprietary rights in and to this software, related documentation\n"
  hdr = hdr " * and any modifications thereto.  Any use, reproduction, disclosure or\n"
  hdr = hdr " * distribution of this software and related documentation without an express\n"
  hdr = hdr " * license agreement from NVIDIA CORPORATION is strictly prohibited.\n"
  hdr = hdr " *\n"
  hdr = hdr " */\n"
  hdr = hdr "\n\n"
  hdr = hdr "/*\n"
  hdr = hdr " *\n"
  hdr = hdr " * WARNING - this file is automatically generated. DO NOT EDIT.\n"
  hdr = hdr " *\n"
  hdr = hdr " */\n"
  hdr = hdr "\n"
  hdr = hdr "#include \"mth_intrinsics.h\"\n"
  hdr = hdr "#include \"mth_tbldefs.h\"\n"
  hdr = hdr "#include <complex.h>\n"
  hdr = hdr "\n"
  print hdr
}

# Populate the global constants and lookup tables used by the generators.
#
# Reads:  MAX_VREG_SIZE (vector register width in bits) and TARGET.
# Sets:   VLS / VLD   - lanes per vector for single / double precision
#                       (4/2 at 128, 8/4 at 256, 16/8 for anything else);
#         VL_XYZ      - "x", "y" or "z", spliced into __ZGV<VL_XYZ>N... names;
#         frps, sds, iks - key sets iterated by the do_all_* drivers;
#         vts, sts, vls, vs - per-type-letter vector typedef, scalar C type,
#                       lane count, and __mth_i_ name fragment;
#         one_arg / two_args and the is_<target> flags.
# The parameters are function locals (awk idiom); callers pass nothing.
function init_target(    keys, nkeys, i)
{
  # Start from the widest configuration and narrow it down.  BEGIN has
  # already rejected widths other than 128/256/512.
  VLS = 16
  VLD = 8
  VL_XYZ = "z"
  if (MAX_VREG_SIZE == 256) {
    VLS = 8
    VLD = 4
    VL_XYZ = "y"
  }
  if (MAX_VREG_SIZE == 128) {
    VLS = 4
    VLD = 2
    VL_XYZ = "x"
  }

  # Key sets; only the keys matter (values are "").  Insertion order is
  # kept as f,r,p / s,d,c,z / i,k.
  # frp_<f|r|p> indexes MTH_DISPATCH_TBL (presumably the fast / relaxed /
  # precise variants -- confirm against mth_tbldefs.h).
  nkeys = split("f r p", keys, " ")
  for (i = 1; i <= nkeys; i++)
    frps[keys[i]] = ""
  # Floating types: single, double, single complex, double complex.
  nkeys = split("s d c z", keys, " ")
  for (i = 1; i <= nkeys; i++)
    sds[keys[i]] = ""
  # Integer exponent types: int32_t and long long.
  nkeys = split("i k", keys, " ")
  for (i = 1; i <= nkeys; i++)
    iks[keys[i]] = ""

  # Per-type tables, grouped by type letter.
  vts["s"] = "vrs" VLS "_t"
  sts["s"] = "float"
  vls["s"] = VLS
  vs["s"]  = "vr4"

  vts["d"] = "vrd" VLD "_t"
  sts["d"] = "double"
  vls["d"] = VLD
  vs["d"]  = "vr8"

  vts["c"] = "vcs" (VLS / 2) "_t"
  sts["c"] = "float _Complex"
  vls["c"] = VLS / 2
  vs["c"]  = "vc4"

  # At 128 bits this is vcd1_t; the generators name those entry points
  # with "1v" instead of the lane count.
  vts["z"] = "vcd" (VLD / 2) "_t"
  sts["z"] = "double _Complex"
  vls["z"] = VLD / 2
  vs["z"]  = "vc8"

  vts["i"]    = "vis" VLS "_t"
  vts["iby2"] = "vis" (VLS / 2) "_t"
  sts["i"]    = "int32_t"
  vls["i"]    = VLS
  vs["i"]     = "vi4"
  vs["i1"]    = "si4"

  vts["k"] = "vid" VLD "_t"
  sts["k"] = "long long"
  vls["k"] = VLD
  vs["k"]  = "vi8"
  vs["k1"] = "si8"

  # yarg values accepted by func_rr_def.
  one_arg = 0
  two_args = 1

  is_power   = (TARGET == "POWER")
  is_x8664   = (TARGET == "X8664")
  is_arm64   = (TARGET == "ARM64")
  is_generic = (TARGET == "GENERIC")
}

# Emit one dispatching wrapper for math function `name`, frp key `frp`
# and type letter `sd`.  The wrapper loads the scalar routine from
# MTH_DISPATCH_TBL[func_<name>][sv_<sd>s][frp_<frp>] and passes it, along
# with the vector argument(s), to the __ZGV<VL_XYZ>N<lanes>...__mth_i_
# helper.  yarg != 0 selects the two-operand (x, y) form; otherwise only x.
# Parameters after yarg are locals.
function func_rr_def(name, frp, sd, yarg,    vt, st, lanes, proto, ctype, vsig, mth, actuals)
{
  vt = vts[sd]
  st = sts[sd]

  # Double complex at 128 bits is named with "1v" rather than a lane count.
  if (MAX_VREG_SIZE == 128 && sd == "z")
    lanes = "1v"
  else
    lanes = vls[sd] ""

  if (yarg != 0) {
    proto   = vt " x, " vt " y"
    ctype   = st ", " st
    vsig    = "vv"
    mth     = vs[sd] vs[sd]
    actuals = "( x, y, fptr);"
  } else {
    proto   = vt " x"
    ctype   = st
    vsig    = "v"
    mth     = vs[sd]
    actuals = "( x, fptr);"
  }

  print "\n" vt
  print "__g" sd "_" name "_" lanes "_" frp "(" proto ")"
  print "{"
  print "  " st " (*fptr)(" ctype ");"
  print "  fptr = (" st "(*)(" ctype "))MTH_DISPATCH_TBL[func_" name "][sv_" sd "s][frp_" frp "];"
  print "  return __ZGV" VL_XYZ "N" vls[sd] vsig "__mth_i_" mth actuals
  print "}"
}


# Emit func_rr_def() wrappers for the single math function `name` across
# every frp key and every type letter in sds.  yarg is one_arg or two_args.
#
# frp and sd are declared as extra parameters so they are local.  awk
# variables are otherwise global, which would leak the loop state to any
# generator that reads a global of the same name.  Callers still pass
# only (name, yarg).
function old_do_all_rr(name, yarg,    frp, sd)
{
  for (frp in frps) {
    for (sd in sds) {
      func_rr_def(name, frp, sd, yarg)
    }
  }
}

# Emit func_rr_def() wrappers for every math function listed below, for
# every frp key and every type letter in sds (s, d, c, z).
#
# frp and sd are declared as extra parameters so they are local.  awk
# variables are otherwise global, which would leak the loop state to any
# generator that reads a global of the same name.  Callers still invoke
# do_all_rr() with no arguments.
function do_all_rr(    frp, sd)
{
  for (frp in frps) {
    for (sd in sds) {
      func_rr_def("acos", frp, sd, one_arg)
      func_rr_def("asin", frp, sd, one_arg)
      func_rr_def("atan", frp, sd, one_arg)
      func_rr_def("atan2", frp, sd, two_args)
      func_rr_def("cos", frp, sd, one_arg)
      func_rr_def("sin", frp, sd, one_arg)
      func_rr_def("tan", frp, sd, one_arg)
      # sincos is deliberately not generated (flagged "incorrect").
      # NOTE(review): presumably because it produces two results and does
      # not fit the single-result wrapper func_rr_def emits -- confirm
      # before re-enabling.
      # func_rr_def("sincos", frp, sd, one_arg)
      func_rr_def("cosh", frp, sd, one_arg)
      func_rr_def("sinh", frp, sd, one_arg)
      func_rr_def("tanh", frp, sd, one_arg)
      func_rr_def("exp", frp, sd, one_arg)
      func_rr_def("log", frp, sd, one_arg)
      func_rr_def("log10", frp, sd, one_arg)
      func_rr_def("pow", frp, sd, two_args)
      func_rr_def("div", frp, sd, two_args)
      func_rr_def("sqrt", frp, sd, one_arg)
      func_rr_def("mod", frp, sd, two_args)
      func_rr_def("aint", frp, sd, one_arg)
      func_rr_def("ceil", frp, sd, one_arg)
      func_rr_def("floor", frp, sd, one_arg)
      func_rr_def("cotan", frp, sd, one_arg)
    }
  }
}

# Emit the wrapper for vector x raised to a scalar integer exponent iy
# (R(:)**I).  sd is the floating type letter and ik the integer type
# letter ("i" or "k").  The scalar routine comes from
# MTH_DISPATCH_TBL[func_<name><ik>1][sv_<sd>s][frp_<frp>].
# Parameters after ik are locals.
function func_pow_decl_scalar(name, frp, sd, ik,    vt, st, it, entry, lanes)
{
  vt = vts[sd]
  st = sts[sd]
  it = sts[ik]

  # Naming inconsistency: every type except double complex is named
  # __g<SD>_pow<IK>1_<VL>_<frp>.  Double complex drops the "1" after <IK>
  # and, at 128 bits, uses "1v" for <VL>: __gz_pow<IK>_1v_<frp>.
  entry = "__g" sd "_" name ik
  if (sd != "z")
    entry = entry "1"
  if (MAX_VREG_SIZE == 128 && sd == "z")
    lanes = "1v"
  else
    lanes = vls[sd] ""

  print "\n" vt
  print entry "_" lanes "_" frp "(" vt " x, " it " iy)"
  print "{"
  print "  " st " (*fptr)(" st ", " it ");"
  print "  fptr = (" st "(*)(" st ", " it "))MTH_DISPATCH_TBL[func_" name ik "1][sv_" sd "s][frp_" frp "];"
  print "  return __ZGV" VL_XYZ "N" vls[sd] "v__mth_i_" vs[sd] vs[ik "1"] "(x, iy, fptr);"
  print "}"
}

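# Reference only (never emitted): hand-written double-precision complex
# R**I wrappers for the 128-bit case.  func_pow_decl_scalar generates the
# dispatch-table equivalents, named __gz_pow<IK>_1v_<frp>.
#vcd1_t
#__gz_powi_1v(vcd1_t x, int iy)
#{
#  return(__ZGVxN1v__mth_i_vc8si4(x, iy, __mth_i_cdpowi_c99));
#}
#
#vcd1_t
#__gz_powk_1v(vcd1_t x, long long iy)
#{
#  return(__ZGVxN1v__mth_i_vc8si8(x, iy, __mth_i_cdpowk_c99));
#}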


# Emit the wrapper for vector x raised to a vector of integer exponents iy
# (R(:)**I(:)).  It handles every (sd, ik) pair except sd == "s" with
# ik == "k", which has its own generator.  The scalar routine comes from
# MTH_DISPATCH_TBL[func_<name><ik>][sv_<sd>s][frp_<frp>].
# Parameters after ik are locals.
function func_pow_decl_vect(name, frp, sd, ik,    vt, st, it, iyvt)
{
  vt = vts[sd]
  st = sts[sd]
  it = sts[ik]

  # The exponent vector type is normally the ik vector type.  Double x
  # with int32 exponents needs only VLD (= VLS/2) ints, so it uses the
  # half-width vis type.  The exception is VLS == 4, where the full
  # vis<VLS> type is kept (POWER has no VLS/2 type).
  iyvt = vts[ik]
  if (sd == "d" && ik == "i" && VLS != 4)
    iyvt = vts["iby2"]

  print "\n" vt
  print "__g" sd "_" name ik "_" vls[sd] "_" frp "(" vt " x, " iyvt " iy)"
  print "{"
  print "  " st " (*fptr)(" st ", " it ");"
  print "  fptr = (" st "(*)(" st ", " it "))MTH_DISPATCH_TBL[func_" name ik "][sv_" sd "s][frp_" frp "];"
  print "  return __ZGV" VL_XYZ "N" vls[sd] "vv__mth_i_" vs[sd] vs[ik] "(x, iy, fptr);"
  print "}"
}

# Emit the R(:)**I(:) wrapper for single-precision x with long long
# exponents (sd == "s", ik == "k").  A vector of VLS floats needs VLS
# exponents, and each vid<VLD>_t holds VLD = VLS/2 of them.  The exponents
# therefore arrive as two vectors, iyu and iyl (presumably the upper and
# lower halves -- confirm against the __mth_i_vr4vi8 helpers).
#
# The type letters are written out as "s" and "k" rather than read from
# the global variables sd and ik.  Those globals are not parameters here,
# and their values depend on whichever caller loop last assigned them.
function func_pow_decl_vect_sk(name, frp)
{
  print "\n" vts["s"]
  print "__gs_" name "k_" vls["s"] \
        "_" frp "(" vts["s"] " x, " vts["k"] " iyu, " vts["k"] " iyl)"
  print "{"
  print "  " sts["s"] " (*fptr)(" sts["s"] ", " sts["k"] ");"
  print "  fptr = (" sts["s"] "(*)(" sts["s"] ", " sts["k"] \
        "))MTH_DISPATCH_TBL[func_" name "k][sv_ss][frp_" frp "];"
  print "  return __ZGV" VL_XYZ "N" vls["s"] "vv" \
        "__mth_i_" vs["s"] vs["k"] "(x, iyu, iyl, fptr);"
  print "}"
}

# Emit one R**I wrapper, choosing the generator by exponent shape and types.
#   is_scalar true  -> scalar exponent:               func_pow_decl_scalar
#   sd "s", ik "k"  -> VLS floats, VLS long longs
#                      split over two vectors:         func_pow_decl_vect_sk
#   anything else   -> one exponent vector:            func_pow_decl_vect
#                      (d/k and s/i match lane counts; d/i uses a narrower
#                      int vector, or the full VLS one where POWER lacks a
#                      VLS/2 type)
function func_pow_def(name, frp, sd, is_scalar, ik)
{
  if (is_scalar)
    func_pow_decl_scalar(name, frp, sd, ik)
  else if (sd == "s" && ik == "k")
    func_pow_decl_vect_sk(name, frp)
  else
    func_pow_decl_vect(name, frp, sd, ik)
}

# Emit every "pow" R**I wrapper: for each frp key, floating type and
# integer exponent type, the scalar-exponent form and, except for double
# complex, the vector-exponent form.
#   - Single complex ("c") gets no R**I wrappers.
#   - Double complex ("z") is emitted only when VLS == 4 (128-bit
#     registers), the only width its hard-coded "1v" name covers.
#
# frp, sd and ik stay global (not declared as extra parameters) on
# purpose.  A generator reached from here may read them from global
# scope, so keep them global unless that has been ruled out.
function do_all_pow_r2i()
{
  for (frp in frps)
    for (sd in sds) {
      if (sd != "c" && (sd != "z" || VLS == 4)) {
        for (ik in iks) {
          func_pow_def("pow", frp, sd, 1, ik)
          # Double complex has no vector-exponent version.
          if (sd != "z")
            func_pow_def("pow", frp, sd, 0, ik)
        }
      }
    }
}

# Report a configuration error and stop with exit status 1.  The message
# goes to standard error because standard output carries the generated
# C source.
function gen_config_error(msg)
{
  print msg > "/dev/stderr"
  exit 1
}

# Driver: validate TARGET / MAX_VREG_SIZE, set up the tables, then write
# the header followed by all generated wrappers to standard output.
BEGIN {
  # Quick sanity checks on the command-line configuration.
  if (TARGET == "POWER") {
    if (MAX_VREG_SIZE != 128)
      gen_config_error("TARGET == POWER, MAX_VREG_SIZE must be 128")
  } else if (TARGET == "ARM64") {
    if (MAX_VREG_SIZE != 128)
      gen_config_error("TARGET == ARM64, MAX_VREG_SIZE must be 128")
  } else if (MAX_VREG_SIZE != 128 && MAX_VREG_SIZE != 256 && MAX_VREG_SIZE != 512) {
    # Every other target (X8664, GENERIC, ...) shares this check, so the
    # message names the actual TARGET.
    gen_config_error("TARGET == " TARGET ", MAX_VREG_SIZE must be either 128, 256, or 512")
  }

  # Initialize the lookup tables and constants.
  init_target()

  print_hdrs()

  # Disabled: the per-function driver, kept for reference.  Unlike
  # do_all_rr it still lists "sincos".
  if (0) {
    old_do_all_rr("acos", one_arg)
    old_do_all_rr("asin", one_arg)
    old_do_all_rr("atan", one_arg)
    old_do_all_rr("atan2", two_args)
    old_do_all_rr("cos", one_arg)
    old_do_all_rr("sin", one_arg)
    old_do_all_rr("tan", one_arg)
    old_do_all_rr("sincos", one_arg)
    old_do_all_rr("cosh", one_arg)
    old_do_all_rr("sinh", one_arg)
    old_do_all_rr("tanh", one_arg)
    old_do_all_rr("exp", one_arg)
    old_do_all_rr("log", one_arg)
    old_do_all_rr("log10", one_arg)
    old_do_all_rr("pow", two_args)
    old_do_all_rr("div", two_args)
    old_do_all_rr("sqrt", one_arg)
    old_do_all_rr("mod", two_args)
    old_do_all_rr("cotan", one_arg)
  }

  # Generated for every supported MAX_VREG_SIZE (no 128-bit-only guard).
  do_all_rr()
  do_all_pow_r2i()
}
