// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

#ifndef ROCRAND_RNG_MRG32K3A_H_
#define ROCRAND_RNG_MRG32K3A_H_

#include <algorithm>
#include <hip/hip_runtime.h>

#include <rocrand/rocrand.h>
#include <rocrand/rocrand_mrg32k3a_precomputed.h>

#include "common.hpp"
#include "generator_type.hpp"
#include "device_engines.hpp"
#include "distributions.hpp"

namespace rocrand_host {
namespace detail {

    // Short alias for the device-side MRG32k3a engine type used by the
    // kernels in this namespace.
    using mrg32k3a_device_engine = ::rocrand_device::mrg32k3a_engine;

    // Creates one engine per thread: the thread with global index i writes
    // engines[i], passing i as the engine's second constructor argument
    // (presumably the subsequence index — confirm against
    // rocrand_device::mrg32k3a_engine).
    // Engines whose index is below start_engine_id receive offset + 1 instead
    // of offset, so they begin one position further along than the others.
    // NOTE(review): there is no bounds check; the launch grid is assumed to
    // contain exactly as many threads as `engines` has elements.
    ROCRAND_KERNEL
    __launch_bounds__(ROCRAND_DEFAULT_MAX_BLOCK_SIZE)
    void init_engines_kernel(mrg32k3a_device_engine * engines,
                             const unsigned int start_engine_id,
                             unsigned long long seed,
                             unsigned long long offset)
    {
        const unsigned int engine_id = hipThreadIdx_x + hipBlockIdx_x * hipBlockDim_x;
        const bool skip_one_more = engine_id < start_engine_id;
        const unsigned long long engine_offset = skip_one_more ? offset + 1 : offset;
        engines[engine_id] = mrg32k3a_device_engine(seed, engine_id, engine_offset);
    }

    // Fills data[0, n) with values produced by `distribution` from the engines.
    //
    // The output is split into three parts:
    //   * head: the first `misalignment` elements (at most n) that precede the
    //     first element whose index (address / sizeof(T)) is a multiple of
    //     output_width;
    //   * vec_n full vec_type slots starting at data + misalignment, written
    //     by a grid-stride loop (thread `id` writes slots id, id + stride, ...);
    //   * tail: the (n - head_size) % output_width elements left at the end.
    // Each distribution call consumes input_width raw engine draws and yields
    // output_width values of type T.
    //
    // engines:          engine state array; read at the start, written back at the end.
    // start_engine_id:  rotation applied to the thread -> engine mapping.
    // data, n:          destination buffer and number of elements to write.
    // distribution:     functor with input_width/output_width constants and
    //                   operator()(unsigned int* input, T* output).
    template<class T, class Distribution>
    ROCRAND_KERNEL
    __launch_bounds__(ROCRAND_DEFAULT_MAX_BLOCK_SIZE)
    void generate_kernel(mrg32k3a_device_engine * engines,
                         const unsigned int start_engine_id,
                         T * data, const size_t n,
                         Distribution distribution)
    {
        constexpr unsigned int input_width = Distribution::input_width;
        constexpr unsigned int output_width = Distribution::output_width;

        using vec_type = aligned_vec_type<T, output_width>;

        const unsigned int id = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x;
        const unsigned int stride = hipGridDim_x * hipBlockDim_x;

        // Stride must be a power of two: the mask computes
        // (id + start_engine_id) mod stride, i.e. the thread -> engine mapping
        // rotated by start_engine_id.
        // NOTE(review): assumes `engines` has exactly `stride` elements.
        const unsigned int engine_id = (id + start_engine_id) & (stride - 1);
        mrg32k3a_device_engine engine = engines[engine_id];

        // Scratch buffers: raw engine draws and the distribution's output.
        unsigned int input[input_width];
        T output[output_width];

        // Number of leading elements to skip so that data + misalignment
        // sits at an element index that is a multiple of output_width
        // (0 when data is already aligned that way).
        const uintptr_t uintptr = reinterpret_cast<uintptr_t>(data);
        const size_t misalignment =
            (
                output_width - uintptr / sizeof(T) % output_width
            ) % output_width;
        // Both head_size and tail_size are always < output_width.
        const unsigned int head_size = min(n, misalignment);
        const unsigned int tail_size = (n - head_size) % output_width;
        const size_t vec_n = (n - head_size) / output_width;

        vec_type * vec_data = reinterpret_cast<vec_type *>(data + misalignment);
        size_t index = id;
        // Grid-stride loop over the full vec_type slots.
        while(index < vec_n)
        {
            for(unsigned int i = 0; i < input_width; i++)
            {
                input[i] = engine();
            }
            distribution(input, output);

            // NOTE(review): reads `output` through a vec_type pointer; this
            // assumes the local array is suitably aligned for vec_type — confirm.
            vec_data[index] = *reinterpret_cast<vec_type *>(output);
            // Next position
            index += stride;
        }

        // Check if we need to save head and tail.
        // Exactly one thread leaves the loop with index == vec_n (the thread
        // whose id equals vec_n % stride), i.e. the thread that would store
        // the next vec_type. That thread draws the head first and then the
        // tail, both from its own engine. When output_width == 1 there is
        // neither a head nor a tail, so this block is skipped.
        if(output_width > 1 && index == vec_n)
        {
            // If data is not aligned by sizeof(vec_type)
            if(head_size > 0)
            {
                for(unsigned int i = 0; i < input_width; i++)
                {
                    input[i] = engine();
                }
                distribution(input, output);

                // Constant trip count with a guard (rather than looping to
                // head_size) keeps the bound a compile-time constant.
                for(unsigned int o = 0; o < output_width; o++)
                {
                    if(o < head_size)
                    {
                        data[o] = output[o];
                    }
                }
            }

            // Remaining elements after the last full vec_type slot.
            if(tail_size > 0)
            {
                for(unsigned int i = 0; i < input_width; i++)
                {
                    input[i] = engine();
                }
                distribution(input, output);

                for(unsigned int o = 0; o < output_width; o++)
                {
                    if(o < tail_size)
                    {
                        data[n - tail_size + o] = output[o];
                    }
                }
            }
        }

        // Save engine with its state
        engines[engine_id] = engine;
    }

} // end namespace detail
} // end namespace rocrand_host

/// Host-side MRG32k3a pseudo-random number generator.
///
/// Owns a device buffer of m_engines_size engine states (hipMalloc in the
/// constructor, hipFree in the destructor) and launches the kernels from
/// rocrand_host::detail on m_stream. Engine states are (re)initialized lazily
/// by init(), which every generate*() call invokes; set_seed(), set_offset()
/// and reset() only mark the current states as stale.
class rocrand_mrg32k3a : public rocrand_generator_type<ROCRAND_RNG_PSEUDO_MRG32K3A>
{
public:
    using base_type = rocrand_generator_type<ROCRAND_RNG_PSEUDO_MRG32K3A>;
    using engine_type = ::rocrand_host::detail::mrg32k3a_device_engine;

    /// Allocates device memory for the engine states.
    ///
    /// A \p seed of zero is replaced by ROCRAND_MRG32K3A_DEFAULT_SEED.
    /// The engines themselves are initialized later, by init().
    ///
    /// \throws rocrand_status ROCRAND_STATUS_ALLOCATION_FAILED when hipMalloc fails.
    rocrand_mrg32k3a(unsigned long long seed = 0,
                     unsigned long long offset = 0,
                     hipStream_t stream = 0)
        : base_type(seed, offset, stream),
          m_engines_initialized(false), m_engines(nullptr), m_engines_size(s_threads * s_blocks)
    {
        // Allocate device random number engines
        auto error = hipMalloc(&m_engines, sizeof(engine_type) * m_engines_size);
        if(error != hipSuccess)
        {
            throw ROCRAND_STATUS_ALLOCATION_FAILED;
        }
        if(m_seed == 0)
        {
            m_seed = ROCRAND_MRG32K3A_DEFAULT_SEED;
        }
    }

    // This object owns m_engines and frees it in the destructor, so an
    // implicit member-wise copy would make two objects hipFree the same
    // pointer. Copying is disabled to turn that into a compile-time error.
    rocrand_mrg32k3a(const rocrand_mrg32k3a&) = delete;
    rocrand_mrg32k3a& operator=(const rocrand_mrg32k3a&) = delete;

    /// Releases the device engine buffer.
    ~rocrand_mrg32k3a()
    {
        hipFree(m_engines);
    }

    /// Marks the engines as stale so the next init() re-creates them from
    /// the current seed and offset.
    void reset()
    {
        m_engines_initialized = false;
    }

    /// Changes seed to \p seed and resets generator state.
    ///
    /// New seed value should not be zero. If \p seed is equal to
    /// zero, value \p ROCRAND_MRG32K3A_DEFAULT_SEED is used instead.
    void set_seed(unsigned long long seed)
    {
        if(seed == 0)
        {
            seed = ROCRAND_MRG32K3A_DEFAULT_SEED;
        }
        m_seed = seed;
        m_engines_initialized = false;
    }

    /// Changes offset to \p offset and resets generator state.
    void set_offset(unsigned long long offset)
    {
        m_offset = offset;
        m_engines_initialized = false;
    }

    /// Initializes the device engines unless they are already initialized.
    ///
    /// m_offset is split into a per-engine offset (m_offset / m_engines_size)
    /// and a starting engine (m_offset % m_engines_size), which are passed
    /// to init_engines_kernel.
    ///
    /// \return ROCRAND_STATUS_LAUNCH_FAILURE if hipGetLastError() reports an
    ///         error after the launch, otherwise ROCRAND_STATUS_SUCCESS.
    rocrand_status init()
    {
        if (m_engines_initialized)
            return ROCRAND_STATUS_SUCCESS;

        m_start_engine_id = m_offset % m_engines_size;

        hipLaunchKernelGGL(
            HIP_KERNEL_NAME(rocrand_host::detail::init_engines_kernel),
            dim3(s_blocks), dim3(s_threads), 0, m_stream,
            m_engines, m_start_engine_id, m_seed, m_offset / m_engines_size
        );
        // Check kernel status
        if(hipGetLastError() != hipSuccess)
            return ROCRAND_STATUS_LAUNCH_FAILURE;

        m_engines_initialized = true;

        return ROCRAND_STATUS_SUCCESS;
    }

    /// Fills data[0, data_size) with values produced by \p distribution.
    ///
    /// Calls init() first, launches generate_kernel on m_stream, then
    /// advances m_start_engine_id (modulo m_engines_size) by the number of
    /// engines the call is expected to use.
    ///
    /// \return the error returned by init(), ROCRAND_STATUS_LAUNCH_FAILURE if
    ///         the launch reports an error, otherwise ROCRAND_STATUS_SUCCESS.
    template<class T, class Distribution = mrg_uniform_distribution<T> >
    rocrand_status generate(T * data, size_t data_size,
                            Distribution distribution = Distribution())
    {
        rocrand_status status = init();
        if (status != ROCRAND_STATUS_SUCCESS)
            return status;

        hipLaunchKernelGGL(
            HIP_KERNEL_NAME(rocrand_host::detail::generate_kernel),
            dim3(s_blocks), dim3(s_threads), 0, m_stream,
            m_engines, m_start_engine_id, data, data_size, distribution
        );
        // Check kernel status
        if(hipGetLastError() != hipSuccess)
            return ROCRAND_STATUS_LAUNCH_FAILURE;

        // Generating data_size values will use this many distributions.
        // NOTE(review): this equals the number of engines generate_kernel
        // advances when data is aligned to output_width elements. For a
        // misaligned pointer with output_width > 2, the kernel draws head and
        // tail from a single engine, so when head + tail exceeds output_width
        // this count appears to be one higher than the engines actually
        // used. Confirm whether that is intended.
        const auto touched_engines =
            (data_size + Distribution::output_width - 1) /
            Distribution::output_width;

        m_start_engine_id = (m_start_engine_id + touched_engines) % m_engines_size;

        return ROCRAND_STATUS_SUCCESS;
    }

    /// Uniformly distributed values (mrg_uniform_distribution<T>).
    template<class T>
    rocrand_status generate_uniform(T * data, size_t data_size)
    {
        mrg_uniform_distribution<T> distribution;
        return generate(data, data_size, distribution);
    }

    /// Normally distributed values with the given \p mean and \p stddev.
    template<class T>
    rocrand_status generate_normal(T * data, size_t data_size, T mean, T stddev)
    {
        mrg_normal_distribution<T> distribution(mean, stddev);
        return generate(data, data_size, distribution);
    }

    /// Log-normally distributed values with the given \p mean and \p stddev.
    template<class T>
    rocrand_status generate_log_normal(T * data, size_t data_size, T mean, T stddev)
    {
        mrg_log_normal_distribution<T> distribution(mean, stddev);
        return generate(data, data_size, distribution);
    }

    /// Poisson-distributed values with parameter \p lambda.
    ///
    /// \return the status thrown by m_poisson.set_lambda() if it rejects
    ///         \p lambda, otherwise the result of generate().
    rocrand_status generate_poisson(unsigned int * data, size_t data_size, double lambda)
    {
        try
        {
            m_poisson.set_lambda(lambda);
        }
        catch(rocrand_status status)
        {
            return status;
        }
        mrg_poisson_distribution distribution(m_poisson.dis);
        return generate(data, data_size, distribution);
    }

private:
    bool m_engines_initialized;
    engine_type * m_engines;
    size_t m_engines_size;

    // Launch configuration for both kernels. generate_kernel maps threads to
    // engines with a bitmask, so the total thread count must be a power of two.
    static const uint32_t s_threads = 256;
    static const uint32_t s_blocks = 512;
    static_assert(((s_threads * s_blocks) & (s_threads * s_blocks - 1)) == 0,
                  "s_threads * s_blocks must be a power of two");

    // For caching of Poisson for consecutive generations with the same lambda
    poisson_distribution_manager<> m_poisson;

    // m_seed from base_type
    // m_offset from base_type

    // Engine used by thread 0 of the next generate_kernel launch; set by
    // init() and advanced by generate().
    unsigned int m_start_engine_id = 0;
};

#endif // ROCRAND_RNG_MRG32K3A_H_
