// Copyright (c) 2017 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

/*
 * Copyright (c) 2009, 2010 Mutsuo Saito, Makoto Matsumoto and Hiroshima
 * University.  All rights reserved.
 * Copyright (c) 2011 Mutsuo Saito, Makoto Matsumoto, Hiroshima
 * University and University of Tokyo.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are
 * met:
 *
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above
 *       copyright notice, this list of conditions and the following
 *       disclaimer in the documentation and/or other materials provided
 *       with the distribution.
 *     * Neither the name of the Hiroshima University nor the names of
 *       its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

#ifndef ROCRAND_MTGP32_H_
#define ROCRAND_MTGP32_H_

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#ifndef FQUALIFIERS
#define FQUALIFIERS __forceinline__ __device__
#endif // FQUALIFIERS

#include "rocrand.h"
#include "rocrand_common.h"

// MTGP32 configuration for the Mersenne exponent 11213.
#define MTGP_MEXP 11213
// Number of 32-bit status words: MTGP_MEXP / 32 + 1.
#define MTGP_N 351
// Largest power of two not greater than MTGP_N.
#define MTGP_FLOOR_2P 256
// Smallest power of two not less than MTGP_N.
#define MTGP_CEIL_2P 512
// Thread count the MTGP layout is designed for.
// NOTE(review): not enforced anywhere in this header.
#define MTGP_TN MTGP_FLOOR_2P
// NOTE(review): three words per thread; not referenced in this header.
#define MTGP_LS (MTGP_TN * 3)
// Maximum number of parameter sets (one per block) held by mtgp32_params.
#define MTGP_BN_MAX 512
// Entries in each recursion/tempering lookup table (indexed by a 4-bit value).
#define MTGP_TS 16
// Size of the status ring buffer; a power of two so indices wrap with MTGP_MASK.
#define MTGP_STATE 1024
// MTGP_STATE - 1: wraps ring-buffer indices via bitwise AND.
#define MTGP_MASK 1023

// Source: https://github.com/MersenneTwister-Lab/MTGP/blob/master/mtgp32-fast.h
/**
 * \struct mtgp32_params_fast_t
 * MTGP32 parameters.
 * Some elements are redundant, to keep the structure simple.
 *
 * \b pos is a pick-up position which is selected to give good
 * performance on graphics processors.  3 < \b pos < Q, where Q is the
 * maximum number such that (size of the status array) - Q is a power of 2.
 * For example, when \b mexp is 44497, the size of the 32-bit status array
 * is 696 and Q is 184, so \b pos is between 4 and 183. This means that
 * 512 parallel calculations are allowed when \b mexp is 44497.
 *
 * \b poly_sha1 is the SHA1 digest of the characteristic polynomial of the
 * state transition function. The SHA1 is calculated from the printed
 * form of the polynomial. This is important when using parameters
 * generated by the dynamic creator, as it identifies the characteristic
 * polynomial that each parameter set corresponds to.
 *
 * \b mask is a mask that makes the dimension of the state space exactly
 * the Mersenne exponent \b mexp. This is redundant.
 */
struct mtgp32_params_fast_t {
    // Host-side parameter set in the MTGP reference layout. rocrand_make_state_mtgp32
    // copies it into each engine; rocrand_make_constant copies it into mtgp32_params.
    int mexp;            /**< Mersenne exponent. This is redundant. rocrand_mtgp32_init_state initializes mexp / 32 + 1 status words from it. */
    int pos;            /**< Pick-up position used by the recursion (becomes pos_tbl). */
    int sh1;            /**< Shift value 1. 0 < sh1 < 32. Left shift in the recursion. */
    int sh2;            /**< Shift value 2. 0 < sh2 < 32. Right shift in the recursion. */
    uint32_t tbl[16];        /**< A small matrix used by the recursion (becomes param_tbl). */
    uint32_t tmp_tbl[16];    /**< A small matrix for tempering (becomes temper_tbl). */
    uint32_t flt_tmp_tbl[16];    /**< A small matrix for tempering and converting to float (becomes single_temper_tbl). */
    uint32_t mask;        /**< This is a mask for state space; ANDed with the first recursion input. */
    unsigned char poly_sha1[21]; /**< SHA1 digest. */
};

namespace rocrand_device {

struct mtgp32_params
{
    // Parameter tables for up to MTGP_BN_MAX parameter sets, indexed by
    // block/parameter-set id (see mtgp32_engine's constructor and set_params).
    unsigned int pos_tbl[MTGP_BN_MAX];                      // pick-up position per set
    unsigned int param_tbl[MTGP_BN_MAX][MTGP_TS];           // recursion table per set
    unsigned int temper_tbl[MTGP_BN_MAX][MTGP_TS];          // 32-bit tempering table per set
    unsigned int single_temper_tbl[MTGP_BN_MAX][MTGP_TS];   // float-tempering table per set
    unsigned int sh1_tbl[MTGP_BN_MAX];                      // left shift per set
    unsigned int sh2_tbl[MTGP_BN_MAX];                      // right shift per set
    // A single mask shared by every parameter set (only mask[0] is read).
    unsigned int mask[1];

    FQUALIFIERS
    ~mtgp32_params() { }
};

// Namespace-local alias for the MTGP reference parameter-set struct.
using mtgp32_fast_params = mtgp32_params_fast_t;

struct mtgp32_state
{
    // Ring buffer of status words; indices wrap with MTGP_MASK.
    unsigned int status[MTGP_STATE];
    // Start of the current window into status; advanced by the block size
    // after every generation step.
    int offset;
    // Index of the parameter set this state uses (read by set_params).
    int id;

    FQUALIFIERS
    ~mtgp32_state() { }
};

/**
 * Seeds an MTGP32 status array on the host, following the MTGP reference
 * initialization.
 *
 * Writes para->mexp / 32 + 1 words of \p array (and always array[1]); the
 * caller must provide at least that many words.
 *
 * \param array - Status array to fill
 * \param para - Parameter set; tbl[4] and tbl[8] derive the hidden seed
 * \param seed - User seed, stored in array[0]
 */
inline
void rocrand_mtgp32_init_state(unsigned int array[],
                               const mtgp32_fast_params *para, unsigned int seed)
{
    const int words = para->mexp / 32 + 1;
    const unsigned int hidden_seed = para->tbl[4] ^ (para->tbl[8] << 16);

    // Fold the hidden seed down to one byte and replicate it into every byte
    // of every word (the word-level equivalent of a byte memset).
    unsigned int folded = hidden_seed;
    folded += folded >> 16;
    folded += folded >> 8;
    const unsigned int fill = (folded & 0xffu) * 0x01010101u;
    for (int k = 0; k < words; ++k)
        array[k] = fill;

    array[0] = seed;
    array[1] = hidden_seed;

    // Mix each word with its predecessor using the MT19937 init_genrand
    // recurrence (multiplier 1812433253), XOR-ed into the prefilled value.
    for (int k = 1; k < words; ++k)
    {
        const unsigned int prev = array[k - 1];
        array[k] ^= 1812433253u * (prev ^ (prev >> 30)) + (unsigned int)k;
    }
}

/**
 * MTGP32 engine: a status ring buffer (m_state) plus the recursion and
 * tempering tables of the parameter set it was created with.
 *
 * On the device, next(), next_single() and copy() are block-cooperative:
 * every thread of the block calls them on the same engine and they execute
 * __syncthreads(), so all threads of a block must make the same calls the
 * same number of times. In a host compilation pass next()/next_single()
 * are stubs that return 0.
 *
 * NOTE(review): a step writes status[offset + t + MTGP_N] while threads read
 * status[offset + t + pos] for t < hipBlockDim_x, so the block size presumably
 * must not exceed MTGP_N - pos (the MTGP reference uses MTGP_TN = 256
 * threads) — TODO confirm.
 */
class mtgp32_engine
{
public:
    FQUALIFIERS
    mtgp32_engine()
    {

    }

    /**
     * Builds an engine from \p state and the tables of parameter set \p bid
     * in \p params. \p bid is not range-checked and must be < MTGP_BN_MAX.
     */
    FQUALIFIERS
    mtgp32_engine(const mtgp32_state& state,
                  const mtgp32_params * params,
                  int bid)
    {
        // Taken by const reference: the state holds MTGP_STATE words, and
        // passing it by value copied the whole ring buffer one extra time.
        m_state = state;
        pos_tbl = params->pos_tbl[bid];
        sh1_tbl = params->sh1_tbl[bid];
        sh2_tbl = params->sh2_tbl[bid];
        mask = params->mask[0];
        for (int j = 0; j < MTGP_TS; j++) {
            param_tbl[j] = params->param_tbl[bid][j];
            temper_tbl[j] = params->temper_tbl[bid][j];
            single_temper_tbl[j] = params->single_temper_tbl[bid][j];
        }
    }

    FQUALIFIERS
    ~mtgp32_engine() { }

    /**
     * Copies \p m_engine into this engine.
     *
     * Device: every thread of the block must call it; the work is split
     * across threads with a block-size stride and the call ends with
     * __syncthreads(). Host: member-wise copy.
     */
    FQUALIFIERS
    void copy(const mtgp32_engine * m_engine)
    {
        #if defined(__HIP_DEVICE_COMPILE__)
        const unsigned int thread_id = hipThreadIdx_x;
        const unsigned int block_size = hipBlockDim_x;

        for (unsigned int i = thread_id; i < MTGP_STATE; i += block_size)
            m_state.status[i] = m_engine->m_state.status[i];

        // Strided like the status copy so that blocks with fewer than
        // MTGP_TS threads still copy every table entry (a plain
        // `thread_id < MTGP_TS` test would leave the tail uncopied).
        for (unsigned int j = thread_id; j < MTGP_TS; j += block_size)
        {
            param_tbl[j] = m_engine->param_tbl[j];
            temper_tbl[j] = m_engine->temper_tbl[j];
            single_temper_tbl[j] = m_engine->single_temper_tbl[j];
        }

        if (thread_id == 0)
        {
            m_state.offset = m_engine->m_state.offset;
            m_state.id = m_engine->m_state.id;
            pos_tbl = m_engine->pos_tbl;
            sh1_tbl = m_engine->sh1_tbl;
            sh2_tbl = m_engine->sh2_tbl;
            mask = m_engine->mask;
        }
        __syncthreads();
        #else
        this->m_state = m_engine->m_state;
        pos_tbl = m_engine->pos_tbl;
        sh1_tbl = m_engine->sh1_tbl;
        sh2_tbl = m_engine->sh2_tbl;
        mask = m_engine->mask;
        for (int j = 0; j < MTGP_TS; j++) {
            param_tbl[j] = m_engine->param_tbl[j];
            temper_tbl[j] = m_engine->temper_tbl[j];
            single_temper_tbl[j] = m_engine->single_temper_tbl[j];
        }
        #endif
    }

    /**
     * Reloads the tables from the parameter set selected by m_state.id in
     * \p params (the id is not range-checked). \p params is only read.
     */
    FQUALIFIERS
    void set_params(const mtgp32_params * params)
    {
        const int id = m_state.id;
        pos_tbl = params->pos_tbl[id];
        sh1_tbl = params->sh1_tbl[id];
        sh2_tbl = params->sh2_tbl[id];
        mask = params->mask[0];
        for (int j = 0; j < MTGP_TS; j++) {
            param_tbl[j] = params->param_tbl[id][j];
            temper_tbl[j] = params->temper_tbl[id][j];
            single_temper_tbl[j] = params->single_temper_tbl[id][j];
        }
    }

    FQUALIFIERS
    unsigned int operator()()
    {
        return this->next();
    }

    /**
     * Generates one tempered 32-bit value per thread and advances the state
     * offset by hipBlockDim_x. Device only (returns 0 in a host pass); all
     * threads of the block must call it because it uses __syncthreads().
     */
    FQUALIFIERS
    unsigned int next()
    {
        #if defined(__HIP_DEVICE_COMPILE__)
        const unsigned int t = hipThreadIdx_x;
        const unsigned int d = hipBlockDim_x;
        const int pos = pos_tbl;

        const unsigned int r = para_rec(m_state.status[(t + m_state.offset) & MTGP_MASK],
                                        m_state.status[(t + m_state.offset + 1) & MTGP_MASK],
                                        m_state.status[(t + m_state.offset + pos) & MTGP_MASK]);
        m_state.status[(t + m_state.offset + MTGP_N) & MTGP_MASK] = r;

        const unsigned int o = temper(r, m_state.status[(t + m_state.offset + pos - 1) & MTGP_MASK]);
        // All reads/writes of this step must finish before the offset moves.
        __syncthreads();
        if (t == 0)
            m_state.offset = (m_state.offset + d) & MTGP_MASK;
        // Make the new offset visible to every thread before the next call.
        __syncthreads();
        return o;
        #else
        return 0;
        #endif
    }

    /**
     * Like next(), but tempers with single_temper_tbl and returns
     * (r >> 9) ^ table entry — presumably the bit pattern of a float in
     * [1, 2), TODO confirm. Device only (returns 0 in a host pass).
     */
    FQUALIFIERS
    unsigned int next_single()
    {
        #if defined(__HIP_DEVICE_COMPILE__)
        const unsigned int t = hipThreadIdx_x;
        const unsigned int d = hipBlockDim_x;
        const int pos = pos_tbl;

        const unsigned int r = para_rec(m_state.status[(t + m_state.offset) & MTGP_MASK],
                                        m_state.status[(t + m_state.offset + 1) & MTGP_MASK],
                                        m_state.status[(t + m_state.offset + pos) & MTGP_MASK]);
        m_state.status[(t + m_state.offset + MTGP_N) & MTGP_MASK] = r;

        const unsigned int o = temper_single(r, m_state.status[(t + m_state.offset + pos - 1) & MTGP_MASK]);
        __syncthreads();
        if (t == 0)
            m_state.offset = (m_state.offset + d) & MTGP_MASK;
        __syncthreads();
        return o;
        #else
        return 0;
        #endif
    }

private:
    /// One step of the MTGP recursion: combines two consecutive words
    /// (\p X1 masked) with the pick-up word \p Y, then applies the 4-bit
    /// indexed recursion table.
    FQUALIFIERS
    unsigned int para_rec(unsigned int X1, unsigned int X2, unsigned int Y)
    {
        unsigned int X = (X1 & mask) ^ X2;

        X ^= X << sh1_tbl;
        Y = X ^ (Y >> sh2_tbl);
        return Y ^ param_tbl[Y & 0x0f];
    }

    /// Tempers \p V using a table entry selected by folding \p T to 4 bits.
    FQUALIFIERS
    unsigned int temper(unsigned int V, unsigned int T)
    {
        T ^= T >> 16;
        T ^= T >> 8;
        return V ^ temper_tbl[T & 0x0f];
    }

    /// Like temper(), but shifts \p V right by 9 and uses single_temper_tbl.
    FQUALIFIERS
    unsigned int temper_single(unsigned int V, unsigned int T)
    {
        T ^= T >> 16;
        T ^= T >> 8;
        return (V >> 9) ^ single_temper_tbl[T & 0x0f];
    }

public:
    // State (filled on the host and copied to the device by
    // rocrand_make_state_mtgp32).
    mtgp32_state m_state;
    // Parameters of the selected parameter set
    unsigned int pos_tbl;
    unsigned int param_tbl[MTGP_TS];
    unsigned int temper_tbl[MTGP_TS];
    unsigned int single_temper_tbl[MTGP_TS];
    unsigned int sh1_tbl;
    unsigned int sh2_tbl;
    unsigned int mask;

}; // mtgp32_engine class

} // end namespace rocrand_device

/** \rocrand_internal \addtogroup rocranddevice
 *
 *  @{
 */

/// \cond ROCRAND_KERNEL_DOCS_TYPEDEFS
using rocrand_state_mtgp32 = rocrand_device::mtgp32_engine;
using mtgp32_state = rocrand_device::mtgp32_state;
using mtgp32_fast_params = rocrand_device::mtgp32_fast_params;
using mtgp32_params = rocrand_device::mtgp32_params;
/// \endcond

/**
 * \brief Initializes MTGP32 states
 *
 * Initializes MTGP32 states on the host side: allocates a temporary state
 * array in host memory, fills it from \p params and \p seed, copies it to
 * device memory and frees the host copy.
 *
 * State \p i is seeded with <tt>(unsigned int)(seed ^ (seed >> 32)) + i + 1</tt>
 * and takes its tables from <tt>params[i]</tt>; the mask always comes from
 * <tt>params[0]</tt>.
 *
 * \param d_state - Pointer to an array of at least \p n states in device memory
 * \param params - Pointer to an array of at least \p n mtgp32_fast_params in host memory
 * \param n - Number of states to initialize
 * \param seed - Seed value
 *
 * \return
 * - ROCRAND_STATUS_ALLOCATION_FAILED if states could not be initialized
 *   (negative \p n, host allocation failure, or failed copy to the device)
 * - ROCRAND_STATUS_SUCCESS if states are initialized (trivially when \p n is 0)
 */
__host__ inline
rocrand_status rocrand_make_state_mtgp32(rocrand_state_mtgp32 * d_state,
                                         mtgp32_fast_params params[],
                                         int n,
                                         unsigned long long seed)
{
    // A negative count would wrap to an enormous allocation size; report it
    // as an allocation failure, which is what malloc would have produced.
    if (n < 0)
        return ROCRAND_STATUS_ALLOCATION_FAILED;
    // Nothing to do; don't depend on the implementation-defined malloc(0).
    if (n == 0)
        return ROCRAND_STATUS_SUCCESS;

    const size_t bytes = sizeof(rocrand_state_mtgp32) * (size_t)n;
    rocrand_state_mtgp32 * h_state = (rocrand_state_mtgp32 *) malloc(bytes);
    if (h_state == NULL)
        return ROCRAND_STATUS_ALLOCATION_FAILED;

    // Fold the upper 32 bits into the lower 32 so the whole 64-bit seed
    // influences the 32-bit per-state seeds.
    seed = seed ^ (seed >> 32);

    for (int i = 0; i < n; i++) {
        rocrand_device::rocrand_mtgp32_init_state(&(h_state[i].m_state.status[0]), &params[i], (unsigned int)seed + i + 1);
        h_state[i].m_state.offset = 0;
        h_state[i].m_state.id = i;
        h_state[i].pos_tbl = params[i].pos;
        h_state[i].sh1_tbl = params[i].sh1;
        h_state[i].sh2_tbl = params[i].sh2;
        // A single mask is shared by all parameter sets (mtgp32_params keeps
        // only mask[0]), so it is always taken from the first set.
        h_state[i].mask = params[0].mask;
        for (int j = 0; j < MTGP_TS; j++) {
            h_state[i].param_tbl[j] = params[i].tbl[j];
            h_state[i].temper_tbl[j] = params[i].tmp_tbl[j];
            h_state[i].single_temper_tbl[j] = params[i].flt_tmp_tbl[j];
        }
    }

    // Check the copy's own result rather than hipPeekAtLastError(), which can
    // also report an error left behind by an unrelated earlier HIP call.
    const bool copied = hipMemcpy(d_state, h_state, bytes, hipMemcpyHostToDevice) == hipSuccess;
    free(h_state);

    return copied ? ROCRAND_STATUS_SUCCESS : ROCRAND_STATUS_ALLOCATION_FAILED;
}

/**
 * \brief Loads parameters for MTGP32
 *
 * Builds the MTGP32 parameter tables on the host from \p params, then copies
 * each table into the matching field of \p p in device memory.
 *
 * Exactly MTGP_BN_MAX entries of \p params are read (only entry 0 supplies
 * the shared mask). NOTE(review): callers must pass at least MTGP_BN_MAX
 * parameter sets.
 *
 * NOTE: Not used as rocrand_make_state_mtgp32 handles loading parameters into
 * state.
 *
 * \param params - Pointer to an array of type mtgp32_fast_params in host memory
 * \param p - Pointer to a mtgp32_params structure allocated in device memory
 *
 * \return
 * - ROCRAND_STATUS_ALLOCATION_FAILED if a host buffer could not be allocated
 *   or any copy to the device failed
 * - ROCRAND_STATUS_SUCCESS if parameters are loaded
 */
__host__ inline
rocrand_status rocrand_make_constant(const mtgp32_fast_params params[], mtgp32_params * p)
{
    const int block_num = MTGP_BN_MAX;
    const int size1 = sizeof(uint32_t) * block_num;            // one word per set
    const int size2 = sizeof(uint32_t) * block_num * MTGP_TS;  // MTGP_TS words per set

    // Host staging buffers, one per device table.
    uint32_t * h_pos_tbl = (uint32_t *)malloc(size1);
    uint32_t * h_sh1_tbl = (uint32_t *)malloc(size1);
    uint32_t * h_sh2_tbl = (uint32_t *)malloc(size1);
    uint32_t * h_param_tbl = (uint32_t *)malloc(size2);
    uint32_t * h_temper_tbl = (uint32_t *)malloc(size2);
    uint32_t * h_single_temper_tbl = (uint32_t *)malloc(size2);
    uint32_t * h_mask = (uint32_t *)malloc(sizeof(uint32_t));

    rocrand_status status = ROCRAND_STATUS_SUCCESS;

    const bool all_allocated = h_pos_tbl && h_sh1_tbl && h_sh2_tbl
        && h_param_tbl && h_temper_tbl && h_single_temper_tbl && h_mask;

    if (!all_allocated)
    {
        printf("failure in allocating host memory for constant table.\n");
        status = ROCRAND_STATUS_ALLOCATION_FAILED;
    }
    else
    {
        h_mask[0] = params[0].mask;
        for (int block = 0; block < block_num; block++)
        {
            const mtgp32_fast_params & entry = params[block];
            h_pos_tbl[block] = entry.pos;
            h_sh1_tbl[block] = entry.sh1;
            h_sh2_tbl[block] = entry.sh2;

            // Tables are flattened row-major: [block][k].
            const int row = block * MTGP_TS;
            for (int k = 0; k < MTGP_TS; k++)
            {
                h_param_tbl[row + k] = entry.tbl[k];
                h_temper_tbl[row + k] = entry.tmp_tbl[k];
                h_single_temper_tbl[row + k] = entry.flt_tmp_tbl[k];
            }
        }

        // Every copy is attempted; any failure marks the whole load as failed.
        const struct
        {
            void * dst;
            const void * src;
            size_t bytes;
        } copies[] = {
            { p->pos_tbl,           h_pos_tbl,           (size_t)size1 },
            { p->sh1_tbl,           h_sh1_tbl,           (size_t)size1 },
            { p->sh2_tbl,           h_sh2_tbl,           (size_t)size1 },
            { p->param_tbl,         h_param_tbl,         (size_t)size2 },
            { p->temper_tbl,        h_temper_tbl,        (size_t)size2 },
            { p->single_temper_tbl, h_single_temper_tbl, (size_t)size2 },
            { p->mask,              h_mask,              sizeof(unsigned int) },
        };
        const size_t copy_count = sizeof(copies) / sizeof(copies[0]);
        for (size_t c = 0; c < copy_count; c++)
        {
            if (hipMemcpy(copies[c].dst, copies[c].src, copies[c].bytes, hipMemcpyHostToDevice) != hipSuccess)
                status = ROCRAND_STATUS_ALLOCATION_FAILED;
        }
    }

    free(h_pos_tbl);
    free(h_sh1_tbl);
    free(h_sh2_tbl);
    free(h_param_tbl);
    free(h_temper_tbl);
    free(h_single_temper_tbl);
    free(h_mask);

    return status;
}

/**
 * \brief Returns a uniformly distributed random <tt>unsigned int</tt> value
 * from [0; 2^32 - 1] range.
 *
 * Draws the next 32-bit value from the MTGP32 generator held in \p state and
 * advances the state by one position. On the device, generation synchronizes
 * the whole block, so every thread of the block must make this call the same
 * number of times.
 *
 * \param state - Pointer to a state to use
 *
 * \return Pseudorandom value (32-bit) as an <tt>unsigned int</tt>
 */
FQUALIFIERS
unsigned int rocrand(rocrand_state_mtgp32 * state)
{
    rocrand_state_mtgp32 & engine = *state;
    return engine();
}

/**
 * \brief Copies MTGP32 state to another state using block of threads
 *
 * Copies a MTGP32 state \p src to \p dest cooperatively. On the device every
 * thread of the block must call this function; the copy is split across the
 * block's threads and ends with __syncthreads(), so \p dest is complete when
 * it returns. \p src is only read. Example usage:
 *
 * \code
 * __global__
 * void generate_kernel(rocrand_state_mtgp32 * states, unsigned int * output, const size_t size)
 * {
 *      const unsigned int state_id = hipBlockIdx_x;
 *      const size_t stride = (size_t)hipGridDim_x * hipBlockDim_x;
 *
 *      __shared__ rocrand_state_mtgp32 state;
 *      rocrand_mtgp32_block_copy(&states[state_id], &state);
 *
 *      // rocrand() synchronizes the whole block, so every thread must call it
 *      // the same number of times: iterate over block-uniform bases and only
 *      // guard the store.
 *      for (size_t base = (size_t)hipBlockIdx_x * hipBlockDim_x; base < size; base += stride)
 *      {
 *          const unsigned int value = rocrand(&state);
 *          const size_t index = base + hipThreadIdx_x;
 *          if (index < size)
 *              output[index] = value;
 *      }
 *
 *      rocrand_mtgp32_block_copy(&state, &states[state_id]);
 * }
 * \endcode
 *
 * \param src - Pointer to a state to copy from
 * \param dest - Pointer to a state to copy to
 *
 */
FQUALIFIERS
void rocrand_mtgp32_block_copy(const rocrand_state_mtgp32 * src, rocrand_state_mtgp32 * dest)
{
    dest->copy(src);
}

/**
 * \brief Changes parameters of a MTGP32 state.
 *
 * Reloads the recursion and tempering tables of \p state from the parameter
 * set in \p params selected by the state's own id (<tt>m_state.id</tt>).
 *
 * \param state - Pointer to a MTGP32 state
 * \param params - Pointer to new parameters
 */
FQUALIFIERS
void rocrand_mtgp32_set_params(rocrand_state_mtgp32 * state, mtgp32_params * params)
{
    rocrand_state_mtgp32 & engine = *state;
    engine.set_params(params);
}

/** @} */ // end of group rocranddevice

#endif // ROCRAND_MTGP32_H_
