/*
 * Copyright (c) 2012-2018, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

/*
 * Two-level token pasting. CONCAT(l,r) lets its arguments be macro-expanded
 * before CONCAT_ pastes them, so CONCAT(__rsd_exp_,TARGET_VEX_OR_FMA) pastes
 * the replacement of TARGET_VEX_OR_FMA (when it is a macro) rather than the
 * literal name. CONCAT_ alone would paste its arguments unexpanded.
 */
#define	CONCAT_(l,r)	l##r
#define	CONCAT(l,r)	CONCAT_(l,r)

#ifdef TABLE_TARGET

/*  #include <math.h>  */

/* Table elements/constants for exp... */

/*
 * table[j] = 2^(j/32) for j = 0..32 (33 entries; table[16] == sqrt(2),
 * table[32] == 2.0). The exp kernel uses table[j] for the 2^(j/32) factor.
 */
static const double table[] = {1.0,
                               1.0218971486541166,
                               1.0442737824274138,
                               1.0671404006768237,
                               1.0905077326652577,
                               1.1143867425958924,
                               1.1387886347566916,
                               1.1637248587775775,
                               1.1892071150027210,
                               1.2152473599804690,
                               1.2418578120734840,
                               1.2690509571917332,
                               1.2968395546510096,
                               1.3252366431597413,
                               1.3542555469368927,
                               1.3839098819638320,
                               1.4142135623730951,
                               1.4451808069770467,
                               1.4768261459394993,
                               1.5091644275934228,
                               1.5422108254079407,
                               1.5759808451078865,
                               1.6104903319492543,
                               1.6457554781539650,
                               1.6817928305074290,
                               1.7186192981224780,
                               1.7562521603732995,
                               1.7947090750031072,
                               1.8340080864093424,
                               1.8741676341103000,
                               1.9152065613971474,
                               1.9571441241754002,
                               2.0};

static const double invl2x32 = 46.166241308446829;     // 32.0/ln(2.0)
static const double ln2xinv32 = 0.0216608493924982909; // ln(2.0)/32.0
static const double inv32 = 3.125e-2;                  // 1.0/32.0
static const double ln2 = 0.693147180559945309;        // ln(2.0)
static const double invln2 = 1.44269504088896341;      // 1.0/ln(2.0)
static const double a1 = 5.0000000000000000e-1;        // coefficient for r^2 (1/2)
static const double a2 = 1.6666666666526087e-1;        // coefficient for r^3 (~1/6)
static const double a3 = 0.041666666666666667;         // coefficient for r^4 (1/24); not used by the cubic exp kernel below
/*
 * Integer bit-pattern helpers.
 *   two   = 0x3FF00000: the upper 32 bits of the IEEE-754 double 1.0
 *           (despite its name, this is not the pattern of 2.0).
 *   inc   = 0x00100000: one unit of the exponent field within that upper
 *           word; adding or subtracting multiples of inc scales the value
 *           by powers of 2.
 *   ten23 = 1023: the IEEE-754 double exponent bias (used by the exp kernel
 *           to build 2^m).
 *   isign = 0x100000000000000 (2^56).
 *           NOTE(review): this is not the double sign bit
 *           (0x8000000000000000); confirm intended use.
 * two, inc and isign are not referenced by the functions in this file.
 */

static const int two = 1072693248, inc = 1048576;
static const int ten23 = 1023;
static const long long isign = 0x100000000000000;

/* Table elements/constants for sincos... */

static const int ts = 32; /* Maximum index into scArray */

/*
 * The literals below carry more digits than a double can hold; the compiler
 * rounds each one to the nearest double. The digits were generated by the
 * Casio High Accuracy calculator at:
 * http://keisan.casio.com/has10/Free.cgi
 */

static const double pi = 3.1415926535897932384626433832795028841971693;
static const double pihalves =
    3.1415926535897932384626433832795028841971693 / 2.0;
static const double invpihalves =
    2.0 / 3.1415926535897932384626433832795028841971693;

/*
 * Angle step between consecutive scArray entries (pi/64).
 */

static const double angle_inc =
    3.1415926535897932384626433832795028841971693 / 64.0;

/*
 * Inverse of the angle step (64/pi); multiplying an angle by it yields a
 * (fractional) scArray index.
 */

static const double invangle_inc =
    64.0 / 3.1415926535897932384626433832795028841971693;

static const double scArray[] = {
    // Lookup table for approximating sincos:
    // scArray[i] = sin(i * pi/64) for i = 0..32, hence
    // scArray[32 - i] = cos(i * pi/64).
    0.000000000000000000, 0.049067674327418015, 0.098017140329560604,
    0.146730474455361748, 0.195090322016128248, 0.242980179903263871,
    0.290284677254462331, 0.336889853392220051, 0.382683432365089782,
    0.427555093430282085, 0.471396736825997642, 0.514102744193221661,
    0.555570233019602178, 0.595699304492433357, 0.634393284163645488,
    0.671558954847018330, 0.707106781186547462, 0.740951125354959106,
    0.773010453362736993, 0.803207531480644832, 0.831469612302545236,
    0.857728610000272118, 0.881921264348354939, 0.903989293123443227,
    0.923879532511286738, 0.941544065183020806, 0.956940335732208935,
    0.970031253194543974, 0.980785280403230431, 0.989176509964780903,
    0.995184726672196929, 0.998795456205172405, 1.000000000000000000,
};
#else

/*
 * Map fabs/cos/sin onto the compiler builtins instead of including
 * <math.h>.
 * NOTE(review): explicitly prototyping __builtin_* names relies on
 * compiler-specific (GCC/Clang) behavior; confirm with the supported
 * toolchains.
 */
#define fabs(x) __builtin_fabs(x)
double __builtin_fabs(double);

#define cos(x) __builtin_cos(x)
double __builtin_cos(double);

#define sin(x) __builtin_sin(x)
double __builtin_sin(double);

/* Fallback exp used by __rsd_exp_* for arguments outside its fast range. */
extern	double CONCAT(__fsd_exp_,TARGET_VEX_OR_FMA)(double);

double	CONCAT(__rsd_exp_,TARGET_VEX_OR_FMA)(double x)
{

  /*
   * This algorithm was orignally developed by Peter Tang. It is based on
   * the work described in:
   * P. T. P. Tang. Table driven implementation of the exponential function
   * in IEEE floating-point arithmetic. ACM Transactions on Mathematical
   * Software, 15(2):144-157, June 1989
   *
   * I have short circuited some of the things Tang did for higher accuracy in
   * the interest
   * of performance.
   * Testing indicates that for degree-2 polynomial, accuracy ~ 2.0x10-7 and for
   * degree-3 polynomial, accuracy ~4.0x10-10.
   */

  long long itemp;
  int m, j;
  double r, q, s, result;
  double xprime, temp, tempx;
  double a2r, a1r, rsq, rcube;
  union {
    unsigned long long i;
    double d;
  } convert;
  double half = 0.5;
  int n, n1, n2;
  /*
   * for now we will introduce ntemp;
   */
  int ntemp;
  int addmod = 0;
  int n1pj;
  double n1pjln2inv32;
  tempx = x * invl2x32; // x * 32.0/ln(2) -- scaled x
  if (fabs(x) > 200.0)  // guard against over/underflow
    return (CONCAT(__fsd_exp_,TARGET_VEX_OR_FMA)(x));
  if (x < 0.0) {                // We need to know if x < 0.0
    half = -0.5;                // For rounding to nearest
    addmod = 32;                // Add to computer % for correct mod
  }
  tempx += half;      // Assures rounding to nearest
  n = (int)(tempx);   // integer part of tempx rounded to nearest
                      /*
                       * Create our own modulo function
                       *
                       */
  ntemp = n >> 5;     // Divide by 32
  ntemp = ntemp << 5; // And then multiply by 32
  j = n - ntemp;      // this is n%32 modulo a sign issue
  n1 = n - j;         // this is almost m ...
  m = n1 >> 5;        // ... once we divide by 32
  n1pj = n1 + j;      // add j to n1
  n1pjln2inv32 = ln2xinv32 * (double)n1pj; // then multiply it by ln(2)/32.0
  r = x - n1pjln2inv32;                    // get the remainder
  convert.i = ten23 + m;                   // 1023 << 52 = 2.0
  convert.i = convert.i << 52;             // convert.d now has 2^m
  s = table[j];                            // get 2^(j/32)
  s *= convert.d;                          // * 2^m
                                           /*
                                            * We're going to use just a few terms of the expansion.
                                            *
                                            * p(t) = t + t^2/2.0 + t^3/6.0. We can add 3rd term if needed.
                                            *
                                            */
  /*
   * for degree 3 polynomial,
   *
   * q =  r * ( 1 + r * ( a1 + a2 * r ) )
   *
   */

  q = r * a2;
  q += a1;
  q *= r;
  q += 1.0;
  q *= r;

  /*
   * For degree 2 polynomial
   *
   * q = r * ( 1 + a1 * r )
   */
  /*
    q = a1 * r;
    q += 1.0;
    q *= r;
  */
  result = 1.0 + q; // ~exp(r)
  result *= s;      // ~exp(x)
  return (result);
}

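/*
 * Compute the complex exponential.
 * We have that exp( x + iy ) = exp( x ) * ( cos( y ) + i sin( y ) )
 *
 * For purposes of this code, we have
 *
 * exp( z ) = exp( real( z ) ) * sincos( imag( z ) ),
 *
 * i.e. the real part is exp( real( z ) ) * cos( imag( z ) ) and the
 * imaginary part is exp( real( z ) ) * sin( imag( z ) ).
 */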


extern	double	CONCAT(__rsd_exp_,TARGET_VEX_OR_FMA)(double);
extern	void	CONCAT(__rsd_sincos_c_,TARGET_VEX_OR_FMA)(double, double *, double *);

/*
 * Complex exponential of z = xr + i*xi:
 *   exp(z) = exp(xr) * (cos(xi) + i*sin(xi))
 *
 * y  - output; y[0] receives the real part, y[1] the imaginary part
 *      (y must point to at least two doubles).
 * xr - real part of z.
 * xi - imaginary part of z.
 *
 * Built on the __rsd_exp_* and __rsd_sincos_c_* kernels defined in this file.
 */
void
CONCAT(__rsz_exp_,TARGET_VEX_OR_FMA)(double *y, double xr, double xi)
{
  double expvalue;

  expvalue = CONCAT(__rsd_exp_,TARGET_VEX_OR_FMA)(xr); // exp( xr )

  /* sincos stores sin(xi) straight into y[1] and cos(xi) into y[0];
   * both are then scaled in place, avoiding temporaries. */
  CONCAT(__rsd_sincos_c_,TARGET_VEX_OR_FMA)(xi, &y[1], &y[0]);
  y[0] *= expvalue;                // exp( xr ) * cos( xi )
  y[1] *= expvalue;                // exp( xr ) * sin( xi )
}

/*
 * Because the range of input values for sines and cosines that this code has
 * to handle is limited to angles between +/- 100 (larger arguments are passed
 * to sin()/cos()), and because the accuracy constraints are fairly loose, this
 * program does the reduction phase using the most obvious method, namely
 * dividing the input angle by pi/2.0. Higher accuracy is not required.
 * Once the reduced angle is determined, we need six more steps.
 * 1. Approximate the sincos for the reduced angle. We do this with a lookup
 *    table.
 * 2. Subtract the angle for which we have entries in the table from the
 *    reduced angle. This will leave a small residual angle for which the sine
 *    and cosine need to be computed.
 * 3. Use a minimax polynomial to compute the sine and cosine of the residual
 *    angle. (See Muller, pages 58, 59)
 * 4. Determine in which quadrant the angle resides, in order to assign the
 *    signs properly and to decide how to combine the results of step 5. A
 *    modulo operation determines this.
 * 5. Use the trig identities for computing sines and cosines for sums of
 *    angles:
 *    sin( a + b ) = sin( a ) * cos( b ) + cos( a ) * sin( b )
 *    cos( a + b ) = cos( a ) * cos( b ) - sin( a ) * sin( b )
 * 6. Correct the sign for the sine depending on whether x >= 0 or x < 0.
 */

/*
 * The next two coefficients are used for the low-order sincos approximation.
 */

/*
 * For a small residual angle t:
 *   sin(t) ~= t + S33 * t^3   (S33 is close to -1/6)
 *   cos(t) ~= 1 + S22 * t^2   (S22 is close to -1/2)
 * NOTE(review): the negative literals are left unparenthesized on purpose.
 * If this file is included more than once, every redefinition must be
 * token-identical.
 */
#define S33 -0.1666596  /* Multiplies x^3 for sin */
#define S22 -0.49996629 /* Multiplies x^2 for cos */

void
CONCAT(__rsd_sincos_c_,TARGET_VEX_OR_FMA)(double x, double *s, double *c)
{
  double kpihalves;                      // pihalves * k
  int k;                                 // x/pihalves
  int sindex, cindex;                    // indices into table
  double tsin, tcos;                     // sine and cosine from table
  double tsinss, tsinsc, tcosss, tcossc; // products of step 5 above
  double mysin, mycos;                   // final sine and cosine
  double x2, x3;                         // remains**2 and remains * x2
  double smallsin, smallcos;             // sine and cosine of remainder
  double x2s22, x3s33;                   // x2 * s22 and x3 * s33
  double remainder, remains;             // used to compute x - k * pi/2
  double sign;                           // sign( x )
  double fabsx;                          // absolute value of x
  double r;                              // temporary storage for remainder
  int mod4;                              // This is used to compute k%4

  fabsx = fabs(x);                     // Need |x|
  if (fabsx < 100.0) {                 // We'll work only on "small" angles
    sign = 1.0 - 2.0 * (x < 0.0);      // Faster than test and branch
    k = fabsx * invpihalves;           // How many 1/2 rotations in |x|
    kpihalves = (double)k * pihalves;  // now we should be at k * pi/2
    remainder = fabsx - kpihalves;     // x - k * pi/2
    sindex = remainder * invangle_inc; // generate index into sine table
    cindex = ts - sindex;              // generate index inot cosine table
    tsin = scArray[sindex];            // fetch sine from table
    tcos = scArray[cindex];            // fetch cosine from table
    r = (double)sindex * angle_inc;    // Need x - k * pi/2 - small-angle
    remains = remainder - r;           // sin( remains ) has to be computed
    x2 = remains * remains;            // x'^2 starts polynomial expansion
    x3 = x2 * remains;                 // x'^3
    x3s33 = x3 * S33;                  // x'^3 * S33
    x2s22 = x2 * S22;                  // x'^2 * S22
    smallsin = x3s33 + remains;        // ~sin( x' )
    smallcos = x2s22 + 1.0;            // ~cos( x' )
    tsinss = tsin * smallsin;          // part of sin( a + b )
    tsinsc = tsin * smallcos;          // "                  "
    tcosss = tcos * smallsin;          // "                  "
    tcossc = tcos * smallcos;          // "                  "
    mod4 = k >> 2;                     // Computing mod with shifts
    mod4 = mod4 << 2;                  // for k%4 is faster than k%4
    mod4 = k - mod4;                   //

    switch (mod4) { // mod4 selects the quadrant to ..
    case 0:         // .. properly compute sincos()
      mysin = tsinsc + tcosss;
      mysin = sign * mysin;
      mycos = tcossc - tsinss;
      *s = mysin;
      *c = mycos;
      break;
    case 1:
      mysin = tcossc - tsinss;
      mysin = sign * mysin;
      mycos = -(tsinsc + tcosss);
      *s = mysin;
      *c = mycos;
      break;
    case 2:
      mysin = -(tsinsc + tcosss);
      mysin = sign * mysin;
      mycos = tsinss - tcossc;
      *s = mysin;
      *c = mycos;
      break;
    case 3:
      mysin = tsinss - tcossc;
      mysin = sign * mysin;
      mycos = tsinsc + tcosss;
      *s = mysin;
      *c = mycos;
    }
  } else {
    *s = sin(x); // We go here if |x| > 100.0
    *c = cos(x);
  }
}

#endif
