/******************************************************************************
* Copyright (c) 2019 - present Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*******************************************************************************/

#include <algorithm>
#include <complex>
#include <cstdlib>
#include <hip/hip_runtime_api.h>
#include <hipfft.h>
#include <iostream>
#include <vector>

int main()
{
    std::cout
        << "hipfft 2D double-precision complex-to-complex transform using advanced interface\n";

    int rank    = 2;
    int n[2]    = {4, 5};
    int howmany = 3;

    // array is contiguous in memory
    int istride = 1;
    // in-place transforms require istride=ostride
    int ostride = istride;

    // we choose to have no padding around our data:
    int inembed[2] = {istride * n[0], istride * n[1]};
    // in-place transforms require inembed=oneembed:
    int onembed[2] = {inembed[0], inembed[1]};

    int idist = inembed[0] * inembed[1];
    int odist = onembed[0] * onembed[1];

    std::cout << "n: " << n[0] << " " << n[1] << "\n"
              << "howmany: " << howmany << "\n"
              << "istride: " << istride << "\tostride: " << ostride << "\n"
              << "inembed: " << inembed[0] << " " << inembed[1] << "\n"
              << "onembed: " << onembed[0] << " " << onembed[1] << "\n"
              << "idist: " << idist << "\todist: " << odist << "\n"
              << std::endl;

    std::vector<std::complex<double>> data(howmany * idist);
    const auto total_bytes = data.size() * sizeof(decltype(data)::value_type);

    std::cout << "input:\n";
    std::fill(data.begin(), data.end(), 0.0);
    for(int ibatch = 0; ibatch < howmany; ++ibatch)
    {
        for(int i = 0; i < n[0]; i++)
        {
            for(int j = 0; j < n[1]; j++)
            {
                const auto pos = ibatch * idist + istride * (i * inembed[1] + j);
                data[pos]      = std::complex<double>(i + ibatch, j);
            }
        }
    }
    for(int ibatch = 0; ibatch < howmany; ++ibatch)
    {
        std::cout << "batch: " << ibatch << "\n";
        for(int i = 0; i < inembed[0]; i++)
        {
            for(int j = 0; j < inembed[1]; j++)
            {
                const auto pos = ibatch * idist + i * inembed[1] + j;
                std::cout << data[pos] << " ";
            }
            std::cout << "\n";
        }
        std::cout << "\n";
    }
    std::cout << std::endl;

    hipfftHandle hipPlan;
    hipfftResult result;
    result = hipfftPlanMany(
        &hipPlan, rank, n, inembed, istride, idist, onembed, ostride, odist, HIPFFT_Z2Z, howmany);

    hipfftDoubleComplex* d_in_out;
    hipMalloc((void**)&d_in_out, total_bytes);
    hipMemcpy(d_in_out, (void*)data.data(), total_bytes, hipMemcpyHostToDevice);

    result = hipfftExecZ2Z(hipPlan, d_in_out, d_in_out, HIPFFT_FORWARD);

    hipMemcpy((void*)data.data(), d_in_out, total_bytes, hipMemcpyDeviceToHost);

    std::cout << "output:\n";
    for(int ibatch = 0; ibatch < howmany; ++ibatch)
    {
        std::cout << "batch: " << ibatch << "\n";
        for(int i = 0; i < onembed[0]; i++)
        {
            for(int j = 0; j < onembed[1]; j++)
            {
                const auto pos = ibatch * odist + i * onembed[1] + j;
                std::cout << data[pos] << " ";
            }
            std::cout << "\n";
        }
        std::cout << "\n";
    }
    std::cout << std::endl;

    hipFree(d_in_out);
}
