/************************************************************************
 * Derived from the BSD3-licensed
 * LAPACK routine (version 3.1) --
 *     Univ. of Tennessee, Univ. of California Berkeley,
 *     Univ. of Colorado Denver and NAG Ltd..
 *     December 2016
 * Copyright (c) 2019-2021 Advanced Micro Devices, Inc.
 * ***********************************************************************/

#pragma once

#include "rocblas.hpp"
#include "roclapack_trtri.hpp"
#include "rocsolver.h"

/** Workspace query for POTRI.
    Reports the buffer sizes that rocsolver_potri_template needs.
    work1..work4 and workArr are sized entirely by the TRTRI requirements. tmpcopy takes
    the TRTRI requirement, grown if necessary to hold one n-by-n copy of A per batch
    instance (sizeof(T) * n * n * batch_count bytes).
    For an empty problem (n == 0 or batch_count == 0) every size is zero and optim_mem
    is set to true. */
template <bool BATCHED, bool STRIDED, typename T>
void rocsolver_potri_getMemorySize(const rocblas_int n,
                                   const rocblas_int batch_count,
                                   size_t* size_work1,
                                   size_t* size_work2,
                                   size_t* size_work3,
                                   size_t* size_work4,
                                   size_t* size_tmpcopy,
                                   size_t* size_workArr,
                                   bool* optim_mem)
{
    const bool empty_problem = (batch_count == 0) || (n == 0);
    if(empty_problem)
    {
        // nothing will be computed, so no workspace is requested
        *size_work1 = *size_work2 = *size_work3 = *size_work4 = 0;
        *size_tmpcopy = *size_workArr = 0;
        *optim_mem = true;
        return;
    }

    // start from the TRTRI requirements: the same buffers are handed to TRTRI later
    rocsolver_trtri_getMemorySize<BATCHED, STRIDED, T>(rocblas_diagonal_non_unit, n, batch_count,
                                                       size_work1, size_work2, size_work3, size_work4,
                                                       size_tmpcopy, size_workArr, optim_mem);

    // tmpcopy must also fit an n-by-n copy of A for every instance; the product is
    // evaluated in size_t because sizeof(T) is the leftmost operand
    const size_t copy_bytes = sizeof(T) * n * n * batch_count;
    if(*size_tmpcopy < copy_bytes)
        *size_tmpcopy = copy_bytes;
}

/** Argument validation for POTRI.
    Returns rocblas_status_continue when the arguments are acceptable (or when the call
    is only a device-memory-size query). Otherwise returns the first error found.
    The checks run in a fixed order (value, then size, then pointers) that the unit
    tests depend on. */
template <typename T>
rocblas_status rocsolver_potri_argCheck(rocblas_handle handle,
                                        const rocblas_fill uplo,
                                        const rocblas_int n,
                                        const rocblas_int lda,
                                        T A,
                                        rocblas_int* info,
                                        const rocblas_int batch_count = 1)
{
    // 1. only upper or lower triangular storage is supported
    const bool uplo_ok = (uplo == rocblas_fill_upper) || (uplo == rocblas_fill_lower);
    if(!uplo_ok)
        return rocblas_status_invalid_value;

    // 2. dimensions must be non-negative and the leading dimension must cover n
    const bool sizes_ok = (n >= 0) && (lda >= n) && (batch_count >= 0);
    if(!sizes_ok)
        return rocblas_status_invalid_size;

    // a memory-size query never touches the data, so pointers are not inspected
    if(rocblas_is_device_memory_size_query(handle))
        return rocblas_status_continue;

    // 3. A is only needed for a non-empty matrix, info only for a non-empty batch
    const bool A_missing = (n != 0) && !A;
    const bool info_missing = (batch_count != 0) && !info;
    return (A_missing || info_missing) ? rocblas_status_invalid_pointer
                                       : rocblas_status_continue;
}

/** POTRI: inverse of a Hermitian positive-definite matrix from its Cholesky factor.
    On entry, the triangle of A selected by uplo holds the factor U (upper) or L (lower).
    On exit, that triangle is overwritten with the corresponding triangle of
    inv(U) * inv(U)' (upper) or inv(L)' * inv(L) (lower). A is overwritten only for
    batch instances whose info is zero after TRTRI.
    info is reset to zero when n == 0. Otherwise it is written by TRTRI (singularity check).
    work1..work4, tmpcopy and workArr must be sized as reported by
    rocsolver_potri_getMemorySize. tmpcopy must hold at least n * n elements per instance.
    The handle's pointer mode is temporarily set to host and restored before returning. */
template <bool BATCHED, bool STRIDED, typename T, typename U>
rocblas_status rocsolver_potri_template(rocblas_handle handle,
                                        const rocblas_fill uplo,
                                        const rocblas_int n,
                                        U A,
                                        const rocblas_int shiftA,
                                        const rocblas_int lda,
                                        const rocblas_stride strideA,
                                        rocblas_int* info,
                                        const rocblas_int batch_count,
                                        void* work1,
                                        void* work2,
                                        void* work3,
                                        void* work4,
                                        T* tmpcopy,
                                        T** workArr,
                                        const bool optim_mem)
{
    ROCSOLVER_ENTER("potri", "uplo:", uplo, "n:", n, "shiftA:", shiftA, "lda:", lda,
                    "bc:", batch_count);

    // quick return if zero instances in batch
    if(batch_count == 0)
        return rocblas_status_success;

    hipStream_t stream;
    rocblas_get_stream(handle, &stream);

    // quick return if no dimensions (info is still reset for every instance)
    if(n == 0)
    {
        rocblas_int blocks = (batch_count - 1) / 32 + 1;
        ROCSOLVER_LAUNCH_KERNEL(reset_info, dim3(blocks, 1, 1), dim3(32, 1, 1), 0, stream, info,
                                batch_count, 0);
        return rocblas_status_success;
    }

    // compute inverse of U or L (also check singularity and update info)
    rocsolver_trtri_template<BATCHED, STRIDED, T>(handle, uplo, rocblas_diagonal_non_unit, n, A,
                                                  shiftA, lda, strideA, info, batch_count, work1,
                                                  work2, work3, work4, tmpcopy, workArr, optim_mem);

    // everything must be executed with scalars on the host
    rocblas_pointer_mode old_mode;
    rocblas_get_pointer_mode(handle, &old_mode);
    rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host);

    // constants in host memory
    const rocblas_int copyblocks = (n - 1) / 32 + 1;
    T one = 1;

    // tmpcopy stores the batch as contiguous n-by-n matrices (leading dimension n).
    // The stride between them is computed in 64-bit rocblas_stride arithmetic:
    // n * n in rocblas_int overflows (undefined behavior) for n >= 46341.
    const rocblas_stride strideT = rocblas_stride(n) * n;

    // copy elements of A to serve as B matrix for TRMM
    ROCSOLVER_LAUNCH_KERNEL(copy_mat<T>, dim3(copyblocks, copyblocks, batch_count), dim3(32, 32), 0,
                            stream, copymat_to_buffer, n, n, A, shiftA, lda, strideA, tmpcopy,
                            no_mask{}, uplo);

    // compute inv(U) * inv(U)' or inv(L)' * inv(L) and store in tmpcopy:
    // upper -> B := B * inv(U)' (right side); lower -> B := inv(L)' * B (left side)
    rocblas_side side = (uplo == rocblas_fill_upper ? rocblas_side_right : rocblas_side_left);
    rocblasCall_trmm<BATCHED, STRIDED, T>(handle, side, uplo, rocblas_operation_conjugate_transpose,
                                          rocblas_diagonal_non_unit, n, n, &one, 0, A, shiftA, lda,
                                          strideA, tmpcopy, 0, n, strideT, batch_count, workArr);

    // copy elements of tmpcopy into A in cases where info is zero
    ROCSOLVER_LAUNCH_KERNEL(copy_mat<T>, dim3(copyblocks, copyblocks, batch_count), dim3(32, 32), 0,
                            stream, copymat_from_buffer, n, n, A, shiftA, lda, strideA, tmpcopy,
                            info_mask(info, info_mask::negate), uplo, rocblas_diagonal_non_unit);

    rocblas_set_pointer_mode(handle, old_mode);
    return rocblas_status_success;
}
