/************************************************************************
 * Derived from the BSD3-licensed
 * LAPACK routine (version 3.7.0) --
 *     Univ. of Tennessee, Univ. of California Berkeley,
 *     Univ. of Colorado Denver and NAG Ltd..
 *     December 2016
 * Copyright (c) 2021 Advanced Micro Devices, Inc.
 * ***********************************************************************/

#pragma once

#include "rocblas.hpp"
#include "roclapack_potrf.hpp"
#include "roclapack_syevd_heevd.hpp"
#include "roclapack_sygst_hegst.hpp"
#include "roclapack_sygv_hegv.hpp"
#include "rocsolver.h"

template <bool BATCHED, typename T, typename S>
void rocsolver_sygvd_hegvd_getMemorySize(const rocblas_eform itype,
                                         const rocblas_evect evect,
                                         const rocblas_fill uplo,
                                         const rocblas_int n,
                                         const rocblas_int batch_count,
                                         size_t* size_scalars,
                                         size_t* size_work1,
                                         size_t* size_work2,
                                         size_t* size_work3,
                                         size_t* size_work4,
                                         size_t* size_tau,
                                         size_t* size_pivots_workArr,
                                         size_t* size_iinfo,
                                         bool* optim_mem)
{
    // if quick return no need of workspace
    if(n == 0 || batch_count == 0)
    {
        *size_scalars = 0;
        *size_work1 = 0;
        *size_work2 = 0;
        *size_work3 = 0;
        *size_work4 = 0;
        *size_tau = 0;
        *size_pivots_workArr = 0;
        *size_iinfo = 0;
        *optim_mem = true;
        return;
    }

    bool opt1, opt2, opt3 = true;
    size_t unused, temp1, temp2, temp3, temp4, temp5;

    // requirements for calling POTRF
    rocsolver_potrf_getMemorySize<BATCHED, T>(n, uplo, batch_count, size_scalars, size_work1,
                                              size_work2, size_work3, size_work4,
                                              size_pivots_workArr, size_iinfo, &opt1);
    *size_iinfo = max(*size_iinfo, sizeof(rocblas_int) * batch_count);

    // requirements for calling SYGST/HEGST
    rocsolver_sygst_hegst_getMemorySize<BATCHED, T>(uplo, itype, n, batch_count, &unused, &temp1,
                                                    &temp2, &temp3, &temp4, &opt2);
    *size_work1 = max(*size_work1, temp1);
    *size_work2 = max(*size_work2, temp2);
    *size_work3 = max(*size_work3, temp3);
    *size_work4 = max(*size_work4, temp4);

    // requirements for calling SYEV/HEEV
    rocsolver_syevd_heevd_getMemorySize<BATCHED, T, S>(evect, uplo, n, batch_count, &unused, &temp1,
                                                       &temp2, &temp3, &temp4, size_tau, &temp5);
    *size_work1 = max(*size_work1, temp1);
    *size_work2 = max(*size_work2, temp2);
    *size_work3 = max(*size_work3, temp3);
    *size_work4 = max(*size_work4, temp4);
    *size_pivots_workArr = max(*size_pivots_workArr, temp5);

    if(evect == rocblas_evect_original)
    {
        if(itype == rocblas_eform_ax || itype == rocblas_eform_abx)
        {
            // requirements for calling TRSM
            rocblas_operation trans
                = (uplo == rocblas_fill_upper ? rocblas_operation_none
                                              : rocblas_operation_conjugate_transpose);
            rocblasCall_trsm_mem<BATCHED, T>(rocblas_side_left, trans, n, n, batch_count, &temp1,
                                             &temp2, &temp3, &temp4);
            *size_work1 = max(*size_work1, temp1);
            *size_work2 = max(*size_work2, temp2);
            *size_work3 = max(*size_work3, temp3);
            *size_work4 = max(*size_work4, temp4);

            // always allocate all required memory for TRSM optimal performance
            opt3 = true;
        }
    }

    *optim_mem = opt1 && opt2 && opt3;
}

template <bool BATCHED, bool STRIDED, typename T, typename S, typename U, bool COMPLEX = is_complex<T>>
/** Solves a batch of generalized symmetric/Hermitian-definite eigenproblems.

    The work proceeds in these steps:
      1. Reset info to 0 for every instance.
      2. Cholesky-factorize B in place (POTRF, status written to info).
      3. Reduce the problem to standard form in A (SYGST/HEGST).
      4. Solve it with SYEVD/HEEVD; eigenvalues go to D and the status goes to iinfo.
      5. Merge iinfo into info (sygv_update_info).
      6. If evect == rocblas_evect_original, back-transform the eigenvectors stored in A,
         using TRSM for itype ax/abx and TRMM for bax.

    The handle's pointer mode is switched to host for the duration of the call and then
    restored. Workspace pointers must be sized by rocsolver_sygvd_hegvd_getMemorySize. **/
rocblas_status rocsolver_sygvd_hegvd_template(rocblas_handle handle,
                                              const rocblas_eform itype,
                                              const rocblas_evect evect,
                                              const rocblas_fill uplo,
                                              const rocblas_int n,
                                              U A,
                                              const rocblas_int shiftA,
                                              const rocblas_int lda,
                                              const rocblas_stride strideA,
                                              U B,
                                              const rocblas_int shiftB,
                                              const rocblas_int ldb,
                                              const rocblas_stride strideB,
                                              S* D,
                                              const rocblas_stride strideD,
                                              S* E,
                                              const rocblas_stride strideE,
                                              rocblas_int* info,
                                              const rocblas_int batch_count,
                                              T* scalars,
                                              void* work1,
                                              void* work2,
                                              void* work3,
                                              void* work4,
                                              T* tau,
                                              void* pivots_workArr,
                                              rocblas_int* iinfo,
                                              bool optim_mem)
{
    ROCSOLVER_ENTER("sygvd_hegvd", "itype:", itype, "evect:", evect, "uplo:", uplo, "n:", n,
                    "shiftA:", shiftA, "lda:", lda, "shiftB:", shiftB, "ldb:", ldb,
                    "bc:", batch_count);

    // nothing to do for an empty batch
    if(batch_count == 0)
        return rocblas_status_success;

    hipStream_t stream;
    rocblas_get_stream(handle, &stream);

    // launch shape for the per-instance info kernels: ceil(batch_count / BS1) blocks of BS1 threads
    rocblas_int nBatchBlocks = (batch_count - 1) / BS1 + 1;
    dim3 batchGrid(nBatchBlocks, 1, 1);
    dim3 threadsPerBlock(BS1, 1, 1);

    // start every instance with info = 0 (no errors)
    ROCSOLVER_LAUNCH_KERNEL(reset_info, batchGrid, threadsPerBlock, 0, stream, info, batch_count, 0);

    // empty matrices: info already cleared, nothing else to compute
    if(n == 0)
        return rocblas_status_success;

    // every rocBLAS scalar below lives on the host; save the caller's mode to restore it at the end
    rocblas_pointer_mode callerPointerMode;
    rocblas_get_pointer_mode(handle, &callerPointerMode);
    rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host);

    // alpha for the TRSM/TRMM back-transformation
    T one = 1;

    // Cholesky factorization of B (in place); POTRF reports its status in info
    rocsolver_potrf_template<BATCHED, T, S>(handle, uplo, n, B, shiftB, ldb, strideB, info,
                                            batch_count, scalars, work1, work2, work3, work4,
                                            (T*)pivots_workArr, iinfo, optim_mem);

    /** (TODO: Strictly speaking, computations should stop here if B is not positive definite.
        A should not be modified in this case as no eigenvalues or eigenvectors can be computed.
        Need to find a way to do this efficiently; for now A will be destroyed in the non
        positive-definite case) **/

    // reduce to a standard eigenvalue problem stored in A
    rocsolver_sygst_hegst_template<BATCHED, STRIDED, T, S>(
        handle, itype, uplo, n, A, shiftA, lda, strideA, B, shiftB, ldb, strideB, batch_count,
        scalars, work1, work2, work3, work4, optim_mem);

    // solve the standard problem with SYEVD/HEEVD; its status is written to iinfo, not info
    rocsolver_syevd_heevd_template<BATCHED, STRIDED, T>(
        handle, evect, uplo, n, A, shiftA, lda, strideA, D, strideD, E, strideE, iinfo, batch_count,
        scalars, work1, work2, work3, (T*)work4, tau, (T**)pivots_workArr);

    // merge POTRF's status (info) with SYEVD/HEEVD's status (iinfo) into info
    ROCSOLVER_LAUNCH_KERNEL(sygv_update_info, batchGrid, threadsPerBlock, 0, stream, info, iinfo, n,
                            batch_count);

    // map the eigenvectors in A back to the original generalized problem
    if(evect == rocblas_evect_original)
    {
        const bool backSolve = (itype == rocblas_eform_ax || itype == rocblas_eform_abx);
        if(backSolve)
        {
            // A := inv(op(B)) * A, with op = none for upper and conjugate transpose for lower
            rocblas_operation trans = rocblas_operation_conjugate_transpose;
            if(uplo == rocblas_fill_upper)
                trans = rocblas_operation_none;

            rocblasCall_trsm<BATCHED, T>(handle, rocblas_side_left, uplo, trans,
                                         rocblas_diagonal_non_unit, n, n, &one, B, shiftB, ldb,
                                         strideB, A, shiftA, lda, strideA, batch_count, optim_mem,
                                         work1, work2, work3, work4);
        }
        else
        {
            // A := op(B) * A, with op = conjugate transpose for upper and none for lower
            rocblas_operation trans = rocblas_operation_none;
            if(uplo == rocblas_fill_upper)
                trans = rocblas_operation_conjugate_transpose;

            rocblasCall_trmm<BATCHED, STRIDED, T>(
                handle, rocblas_side_left, uplo, trans, rocblas_diagonal_non_unit, n, n, &one, 0, B,
                shiftB, ldb, strideB, A, shiftA, lda, strideA, batch_count, (T**)pivots_workArr);
        }
    }

    rocblas_set_pointer_mode(handle, callerPointerMode);
    return rocblas_status_success;
}
