/* ************************************************************************
 * Copyright (c) 2018-2020 Advanced Micro Devices, Inc.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 *
 * ************************************************************************ */

#pragma once
#ifndef TESTING_GMRES_HPP
#define TESTING_GMRES_HPP

#include "utility.hpp"

#include <rocalution.hpp>

using namespace rocalution;

// Runs one GMRES solve on a generated test matrix and reports whether the
// computed solution is acceptably close to the known exact solution.
//
// The right-hand side is built as b = A * e with e being the all-ones vector,
// so the exact solution of A x = b is e. After the solve, the norm of (x - e)
// is compared against a fixed bound.
//
// Template parameter:
//   T                  Value type used for the matrix and the vectors.
//
// Parameters:
//   argus              Test configuration. The fields read here are:
//                        size    - dimension forwarded to the matrix generator
//                        index   - forwarded to GMRES::SetBasisSize()
//                        matrix  - "laplacian" or "permuted_identity"
//                        precond - "None" or one of the preconditioner names
//                                  handled below
//                        format  - matrix format forwarded to ConvertTo()
//   expectConvergence  When false, the error norm is not checked and the test
//                      succeeds as soon as the solve has run.
//
// Returns:
//   false if the matrix or preconditioner name is not recognised, if the
//   matrix generator did not produce a row pointer array, or if convergence
//   was expected and the error norm is not below the bound; true otherwise.
template <typename T>
bool testing_gmres(Arguments argus, bool expectConvergence = true)
{
    int          ndim    = argus.size;
    int          basis   = argus.index;
    std::string  matrix  = argus.matrix;
    std::string  precond = argus.precond;
    unsigned int format  = argus.format;

    // Initialize rocALUTION platform
    set_device_rocalution(device);
    init_rocalution();

    // rocALUTION structures
    LocalMatrix<T> A;
    LocalVector<T> x;
    LocalVector<T> b;
    LocalVector<T> e;

    // Generate A
    int* csr_ptr = NULL;
    int* csr_col = NULL;
    T*   csr_val = NULL;

    int nrow = 0;
    if(matrix == "laplacian")
    {
        nrow = gen_2d_laplacian(ndim, &csr_ptr, &csr_col, &csr_val);
    }
    else if(matrix == "permuted_identity")
    {
        nrow = gen_permuted_identity(ndim, &csr_ptr, &csr_col, &csr_val);
    }
    else
    {
        // Unknown matrix name: pair the init_rocalution() above with a
        // stop_rocalution(), exactly as the regular exit path does, so the
        // platform is not left initialised for the next test.
        stop_rocalution();
        return false;
    }

    // Without a row pointer array the number of non-zeros cannot be read.
    if(csr_ptr == NULL)
    {
        stop_rocalution();
        return false;
    }

    int nnz = csr_ptr[nrow];

    // Hand the generated CSR arrays over to A
    A.SetDataPtrCSR(&csr_ptr, &csr_col, &csr_val, "A", nnz, nrow, nrow);

    // Move data to accelerator
    A.MoveToAccelerator();
    x.MoveToAccelerator();
    b.MoveToAccelerator();
    e.MoveToAccelerator();

    // Allocate x, b and e
    x.Allocate("x", A.GetN());
    b.Allocate("b", A.GetM());
    e.Allocate("e", A.GetN());

    // b = A * 1
    e.Ones();
    A.Apply(e, &b);

    // Random initial guess
    x.SetRandomUniform(12345ULL, -4.0, 6.0);

    // Solver
    GMRES<LocalMatrix<T>, LocalVector<T>, T> ls;

    // Preconditioner (stays NULL when no preconditioning is requested)
    Preconditioner<LocalMatrix<T>, LocalVector<T>, T>* p = NULL;

    if(precond == "None")
    {
        p = NULL;
    }
    else if(precond == "Chebyshev")
    {
        // Chebyshev preconditioner

        // Determine min and max eigenvalues
        T lambda_min;
        T lambda_max;

        A.Gershgorin(lambda_min, lambda_max);

        AIChebyshev<LocalMatrix<T>, LocalVector<T>, T>* cheb
            = new AIChebyshev<LocalMatrix<T>, LocalVector<T>, T>;
        // Only lambda_max is used; the lower bound is taken as lambda_max / 7
        cheb->Set(3, lambda_max / 7.0, lambda_max);

        p = cheb;
    }
    else if(precond == "FSAI")
        p = new FSAI<LocalMatrix<T>, LocalVector<T>, T>;
    else if(precond == "SPAI")
        p = new SPAI<LocalMatrix<T>, LocalVector<T>, T>;
    else if(precond == "TNS")
        p = new TNS<LocalMatrix<T>, LocalVector<T>, T>;
    else if(precond == "Jacobi")
        p = new Jacobi<LocalMatrix<T>, LocalVector<T>, T>;
    else if(precond == "GS")
        p = new GS<LocalMatrix<T>, LocalVector<T>, T>;
    else if(precond == "SGS")
        p = new SGS<LocalMatrix<T>, LocalVector<T>, T>;
    else if(precond == "ILU")
        p = new ILU<LocalMatrix<T>, LocalVector<T>, T>;
    else if(precond == "ILUT")
        p = new ILUT<LocalMatrix<T>, LocalVector<T>, T>;
    else if(precond == "IC")
        p = new IC<LocalMatrix<T>, LocalVector<T>, T>;
    else if(precond == "MCGS")
        p = new MultiColoredGS<LocalMatrix<T>, LocalVector<T>, T>;
    else if(precond == "MCSGS")
        p = new MultiColoredSGS<LocalMatrix<T>, LocalVector<T>, T>;
    else if(precond == "MCILU")
        p = new MultiColoredILU<LocalMatrix<T>, LocalVector<T>, T>;
    else
    {
        // Unknown preconditioner name: nothing has been allocated for p yet,
        // but the platform still has to be stopped before returning.
        stop_rocalution();
        return false;
    }

    ls.Verbose(0);
    ls.SetOperator(A);

    // Set preconditioner
    if(p != NULL)
    {
        ls.SetPreconditioner(*p);
    }

    ls.Init(1e-6, 0.0, 1e+8, 10000);
    ls.SetBasisSize(basis);

    ls.Build();

    // Matrix format
    A.ConvertTo(format, 3);

    ls.Solve(b, &x);

    // Verify solution: x <- e - x, so its norm is the distance to the exact
    // solution.
    x.ScaleAdd(-1.0, e);
    T nrm2 = x.Norm();

    // NOTE(review): the 1e3 bound is many orders of magnitude looser than the
    // 1e-6 passed to Init() above - confirm that this threshold is intended.
    bool success = expectConvergence ? (nrm2 < 1e3) : true;

    // Clean up: clear the solver before the preconditioner it refers to is
    // destroyed (deleting a NULL pointer is a no-op).
    ls.Clear();
    delete p;

    // Stop rocALUTION platform
    stop_rocalution();

    return success;
}

#endif // TESTING_GMRES_HPP
