///
///  hostrpc.cpp: definitions of device stubs and host fallback functions
///

#include "hostrpc.h"
#include <omp.h>
// -----------------------------------------------------------------------------
//
// printf: stubs to support printf
//
// GPUs typically do not support varargs-style functions. So to implement
// printf, or any varargs function, as a hostrpc service, the compiler must
// generate code to allocate a buffer, fill the buffer with the value of
// each argument, and then call a stub to execute the service with a pointer to
// the buffer. The clang compiler does this in the CGGPUBuiltin.cpp source.
// Here we define the printf_allocate and printf_execute device functions whose
// calls are generated by the clang compiler when it encounters a printf
// statement. printf_allocate is implemented as a hostrpc stub. We assume that
// the host routine for printf_execute will free the buffer that was allocated
// by printf_allocate.

#pragma omp declare target

#ifdef __AMDGCN__
// printf_allocate: request a print buffer of bufsz bytes from the host.
//
// Invokes HOSTRPC_SERVICE_MALLOC_PRINTF with bufsz in arg0 and returns the
// pointer the host places in result.arg1. Calls to this function are emitted
// by clang for device-side printf; the design assumes the host handler for
// printf_execute frees the buffer.
EXTERN char *printf_allocate(uint32_t bufsz) {
  uint64_t arg0 = (uint64_t)bufsz;
  // This stub sets only arg0. The other slots are zeroed rather than left
  // uninitialized, because passing indeterminate values is undefined behavior.
  uint64_t arg1 = 0, arg2 = 0, arg3 = 0, arg4 = 0, arg5 = 0, arg6 = 0, arg7 = 0;
  hostrpc_result_t result =
      hostrpc_invoke(PACK_VERS(HOSTRPC_SERVICE_MALLOC_PRINTF), arg0, arg1, arg2,
                     arg3, arg4, arg5, arg6, arg7);
  return (char *)result.arg1;
}
// hostrpc_varfn_uint_allocate: request an argument buffer of bufsz bytes for a
// variadic host function returning uint32_t.
//
// Uses the same HOSTRPC_SERVICE_MALLOC_PRINTF service as printf_allocate
// (bufsz in arg0) and returns the pointer the host places in result.arg1.
EXTERN char *hostrpc_varfn_uint_allocate(uint32_t bufsz) {
  uint64_t arg0 = (uint64_t)bufsz;
  // This stub sets only arg0. The other slots are zeroed rather than left
  // uninitialized, because passing indeterminate values is undefined behavior.
  uint64_t arg1 = 0, arg2 = 0, arg3 = 0, arg4 = 0, arg5 = 0, arg6 = 0, arg7 = 0;
  hostrpc_result_t result =
      hostrpc_invoke(PACK_VERS(HOSTRPC_SERVICE_MALLOC_PRINTF), arg0, arg1, arg2,
                     arg3, arg4, arg5, arg6, arg7);
  return (char *)result.arg1;
}
// hostrpc_varfn_uint64_allocate: request an argument buffer of bufsz bytes for
// a variadic host function returning uint64_t.
//
// Uses the same HOSTRPC_SERVICE_MALLOC_PRINTF service as printf_allocate
// (bufsz in arg0) and returns the pointer the host places in result.arg1.
EXTERN char *hostrpc_varfn_uint64_allocate(uint32_t bufsz) {
  uint64_t arg0 = (uint64_t)bufsz;
  // This stub sets only arg0. The other slots are zeroed rather than left
  // uninitialized, because passing indeterminate values is undefined behavior.
  uint64_t arg1 = 0, arg2 = 0, arg3 = 0, arg4 = 0, arg5 = 0, arg6 = 0, arg7 = 0;
  hostrpc_result_t result =
      hostrpc_invoke(PACK_VERS(HOSTRPC_SERVICE_MALLOC_PRINTF), arg0, arg1, arg2,
                     arg3, arg4, arg5, arg6, arg7);
  return (char *)result.arg1;
}
// hostrpc_varfn_double_allocate: request an argument buffer of bufsz bytes for
// a variadic host function returning double.
//
// Uses the same HOSTRPC_SERVICE_MALLOC_PRINTF service as printf_allocate
// (bufsz in arg0) and returns the pointer the host places in result.arg1.
EXTERN char *hostrpc_varfn_double_allocate(uint32_t bufsz) {
  uint64_t arg0 = (uint64_t)bufsz;
  // This stub sets only arg0. The other slots are zeroed rather than left
  // uninitialized, because passing indeterminate values is undefined behavior.
  uint64_t arg1 = 0, arg2 = 0, arg3 = 0, arg4 = 0, arg5 = 0, arg6 = 0, arg7 = 0;
  hostrpc_result_t result =
      hostrpc_invoke(PACK_VERS(HOSTRPC_SERVICE_MALLOC_PRINTF), arg0, arg1, arg2,
                     arg3, arg4, arg5, arg6, arg7);
  return (char *)result.arg1;
}

// global_allocate: ask the host to allocate bufsz bytes.
//
// Invokes HOSTRPC_SERVICE_MALLOC with bufsz in arg0 and returns the pointer
// the host places in result.arg1. Whether the host ever returns NULL is not
// visible here (NOTE(review): confirm against the host handler).
EXTERN char *global_allocate(uint32_t bufsz) {
  uint64_t arg0 = (uint64_t)bufsz;
  // This stub sets only arg0. The other slots are zeroed rather than left
  // uninitialized, because passing indeterminate values is undefined behavior.
  uint64_t arg1 = 0, arg2 = 0, arg3 = 0, arg4 = 0, arg5 = 0, arg6 = 0, arg7 = 0;
  hostrpc_result_t result =
      hostrpc_invoke(PACK_VERS(HOSTRPC_SERVICE_MALLOC), arg0, arg1, arg2, arg3,
                     arg4, arg5, arg6, arg7);
  return (char *)result.arg1;
}
// global_free: ask the host to release ptr.
//
// Invokes HOSTRPC_SERVICE_FREE with ptr in arg0 and returns result.arg0
// narrowed to int. The meaning of that status value is defined by the host
// handler, which is not visible here.
EXTERN int global_free(char *ptr) {
  uint64_t arg0 = (uint64_t)ptr;
  // This stub sets only arg0. The other slots are zeroed rather than left
  // uninitialized, because passing indeterminate values is undefined behavior.
  uint64_t arg1 = 0, arg2 = 0, arg3 = 0, arg4 = 0, arg5 = 0, arg6 = 0, arg7 = 0;
  hostrpc_result_t result =
      hostrpc_invoke(PACK_VERS(HOSTRPC_SERVICE_FREE), arg0, arg1, arg2, arg3,
                     arg4, arg5, arg6, arg7);
  return (int)result.arg0;
}

// hostrpc_fptr0: ask the host to call the function at fptr.
//
// Invokes HOSTRPC_SERVICE_FUNCTIONCALL with fptr in arg0. Any result from the
// host is discarded.
EXTERN void hostrpc_fptr0(void *fptr) {
  uint64_t arg0 = (uint64_t)fptr;
  // This stub sets only arg0. The other slots are zeroed rather than left
  // uninitialized, because passing indeterminate values is undefined behavior.
  uint64_t arg1 = 0, arg2 = 0, arg3 = 0, arg4 = 0, arg5 = 0, arg6 = 0, arg7 = 0;
  (void)hostrpc_invoke(PACK_VERS(HOSTRPC_SERVICE_FUNCTIONCALL), arg0, arg1,
                       arg2, arg3, arg4, arg5, arg6, arg7);
}

// printf_execute: have the host perform printf using a filled print buffer.
//
// Invokes HOSTRPC_SERVICE_PRINTF with bufsz in arg0 and print_buffer in arg1,
// and returns result.arg0 narrowed to int. The file's design assumes the host
// handler frees print_buffer, which came from printf_allocate.
EXTERN int printf_execute(char *print_buffer, uint32_t bufsz) {
  uint64_t arg0 = (uint64_t)bufsz;
  uint64_t arg1 = (uint64_t)print_buffer;
  // Unused slots are zeroed: passing indeterminate values is undefined behavior.
  uint64_t arg2 = 0, arg3 = 0, arg4 = 0, arg5 = 0, arg6 = 0, arg7 = 0;
  hostrpc_result_t result =
      hostrpc_invoke(PACK_VERS(HOSTRPC_SERVICE_PRINTF), arg0, arg1, arg2, arg3,
                     arg4, arg5, arg6, arg7);
  return (int)result.arg0;
}

// fprintf_allocate: request an fprintf argument buffer of bufsz bytes.
//
// Uses the same HOSTRPC_SERVICE_MALLOC_PRINTF service as printf_allocate
// (bufsz in arg0) and returns the pointer the host places in result.arg1.
EXTERN char *fprintf_allocate(uint32_t bufsz) {
  uint64_t arg0 = (uint64_t)bufsz;
  // This stub sets only arg0. The other slots are zeroed rather than left
  // uninitialized, because passing indeterminate values is undefined behavior.
  uint64_t arg1 = 0, arg2 = 0, arg3 = 0, arg4 = 0, arg5 = 0, arg6 = 0, arg7 = 0;
  hostrpc_result_t result =
      hostrpc_invoke(PACK_VERS(HOSTRPC_SERVICE_MALLOC_PRINTF), arg0, arg1, arg2,
                     arg3, arg4, arg5, arg6, arg7);
  return (char *)result.arg1;
}
// fprintf_execute: have the host perform fprintf using a filled buffer.
//
// Invokes HOSTRPC_SERVICE_FPRINTF with bufsz in arg0 and print_buffer in arg1,
// and returns result.arg0 narrowed to int.
EXTERN int fprintf_execute(char *print_buffer, uint32_t bufsz) {
  uint64_t arg0 = (uint64_t)bufsz;
  uint64_t arg1 = (uint64_t)print_buffer;
  // Unused slots are zeroed: passing indeterminate values is undefined behavior.
  uint64_t arg2 = 0, arg3 = 0, arg4 = 0, arg5 = 0, arg6 = 0, arg7 = 0;
  hostrpc_result_t result =
      hostrpc_invoke(PACK_VERS(HOSTRPC_SERVICE_FPRINTF), arg0, arg1, arg2, arg3,
                     arg4, arg5, arg6, arg7);
  return (int)result.arg0;
}

// __tgt_fort_ptr_assn_i8: forward five pointer arguments to the host's
// HOSTRPC_SERVICE_FTNASSIGN service and return its result.arg0.
//
// NOTE(review): the service name suggests a Fortran pointer assignment
// performed on the host; the argument roles are defined by the host handler,
// which is not visible here.
EXTERN uint64_t __tgt_fort_ptr_assn_i8(void *varg0, void *varg1, void *varg2,
                                       void *varg3, void *varg4) {
  uint64_t arg0 = (uint64_t)varg0;
  uint64_t arg1 = (uint64_t)varg1;
  uint64_t arg2 = (uint64_t)varg2;
  uint64_t arg3 = (uint64_t)varg3;
  uint64_t arg4 = (uint64_t)varg4;
  // Unused slots are zeroed: passing indeterminate values is undefined behavior.
  uint64_t arg5 = 0, arg6 = 0, arg7 = 0;
  hostrpc_result_t result =
      hostrpc_invoke(PACK_VERS(HOSTRPC_SERVICE_FTNASSIGN), arg0, arg1, arg2, arg3,
                     arg4, arg5, arg6, arg7);
  return (uint64_t)result.arg0;
}

// hostrpc_varfn_uint_execute: run a variadic host function returning uint32_t.
//
// Invokes HOSTRPC_SERVICE_VARFNUINT with bufsz in arg0 and the argument buffer
// in arg1, and returns the low 32 bits of result.arg0.
EXTERN uint32_t hostrpc_varfn_uint_execute(char *print_buffer, uint32_t bufsz) {
  uint64_t arg0 = (uint64_t)bufsz;
  uint64_t arg1 = (uint64_t)print_buffer;
  // Unused slots are zeroed: passing indeterminate values is undefined behavior.
  uint64_t arg2 = 0, arg3 = 0, arg4 = 0, arg5 = 0, arg6 = 0, arg7 = 0;
  hostrpc_result_t result =
      hostrpc_invoke(PACK_VERS(HOSTRPC_SERVICE_VARFNUINT), arg0, arg1, arg2,
                     arg3, arg4, arg5, arg6, arg7);
  // Narrow straight to the unsigned return type. Going through int would make
  // the result implementation-defined for values above INT_MAX.
  return (uint32_t)result.arg0;
}
// hostrpc_varfn_uint64_execute: run a variadic host function returning
// uint64_t.
//
// Invokes HOSTRPC_SERVICE_VARFNUINT64 with bufsz in arg0 and the argument
// buffer in arg1, and returns result.arg0.
EXTERN uint64_t hostrpc_varfn_uint64_execute(char *print_buffer,
                                             uint32_t bufsz) {
  uint64_t arg0 = (uint64_t)bufsz;
  uint64_t arg1 = (uint64_t)print_buffer;
  // Unused slots are zeroed: passing indeterminate values is undefined behavior.
  uint64_t arg2 = 0, arg3 = 0, arg4 = 0, arg5 = 0, arg6 = 0, arg7 = 0;
  hostrpc_result_t result =
      hostrpc_invoke(PACK_VERS(HOSTRPC_SERVICE_VARFNUINT64), arg0, arg1, arg2,
                     arg3, arg4, arg5, arg6, arg7);
  return (uint64_t)result.arg0;
}
// hostrpc_varfn_double_execute: run a variadic host function returning double.
//
// Invokes HOSTRPC_SERVICE_VARFNDOUBLE with bufsz in arg0 and the argument
// buffer in arg1. The host returns the double's bit pattern in the 64-bit
// result.arg0, which is reinterpreted as a double here.
EXTERN double hostrpc_varfn_double_execute(char *print_buffer, uint32_t bufsz) {
  uint64_t arg0 = (uint64_t)bufsz;
  uint64_t arg1 = (uint64_t)print_buffer;
  // Unused slots are zeroed: passing indeterminate values is undefined behavior.
  uint64_t arg2 = 0, arg3 = 0, arg4 = 0, arg5 = 0, arg6 = 0, arg7 = 0;
  hostrpc_result_t result =
      hostrpc_invoke(PACK_VERS(HOSTRPC_SERVICE_VARFNDOUBLE), arg0, arg1, arg2,
                     arg3, arg4, arg5, arg6, arg7);
  // Reinterpret the bits with a byte copy. Reading an inactive union member is
  // undefined behavior in C++, and this file is compiled as C++.
  uint64_t bits = result.arg0;
  double dval;
  __builtin_memcpy(&dval, &bits, sizeof dval);
  return dval;
}

// -----------------------------------------------------------------------------
//
// vector_product_zeros: Example stub to demonstrate hostrpc services
//
// This is an example hostrpc stub for a service called vector_product_zeros.
// This function calculates C = A*B and returns the number of zeros.
// Naturally, one would typically do this type of operation on a GPU.
// But this is a demo to illustrate the use of hostrpc to run a service
// on the host.  The host service for HOSTRPC_SERVICE_DEMO is in
// llvm-project/openmp/libomptarget/plugins/hsa/impl/hostrpc_handlers.c
// After copying the vectors from the GPU the service handler calls this
// routine on the host.
//

// vector_product_zeros: demo stub that forwards N and the vectors A, B, C to
// the host's HOSTRPC_SERVICE_DEMO service.
//
// Returns the value the host places in result.arg1, narrowed to int. Per the
// comment above, that value is the number of zeros in C = A*B.
EXTERN int vector_product_zeros(int N, int *A, int *B, int *C) {
  // Sign-extend N to 64 bits. The conversion to uint64_t keeps the same value
  // as the previous int64_t detour.
  uint64_t arg0 = (uint64_t)N;
  // Pass these pointers to the host for memcpy.
  uint64_t arg1 = (uint64_t)A;
  uint64_t arg2 = (uint64_t)B;
  uint64_t arg3 = (uint64_t)C;
  // Unused slots are zeroed: passing indeterminate values is undefined behavior.
  uint64_t arg4 = 0, arg5 = 0, arg6 = 0, arg7 = 0;
  hostrpc_result_t result =
      hostrpc_invoke(PACK_VERS(HOSTRPC_SERVICE_DEMO), arg0, arg1, arg2, arg3,
                     arg4, arg5, arg6, arg7);
  int num_zeros = (int)result.arg1;
  return num_zeros;
}

// __strlen_max: bounded size of a printf string argument.
//
// Used for printf arguments that are variable-length strings. The clang
// compiler generates calls to this only when a string length is not a
// compile-time constant.
//
// Scans at most maxstrlen bytes of instr. When a NUL byte is found at index k,
// the result is k + 1, so the terminator is counted. When no NUL byte occurs
// in the first maxstrlen bytes, the result is maxstrlen. When maxstrlen is 0,
// instr is not read and 0 is returned.
EXTERN uint32_t __strlen_max(char *instr, uint32_t maxstrlen) {
  uint32_t scanned = 0;
  while (scanned < maxstrlen) {
    if (instr[scanned] == '\0')
      return scanned + 1;
    ++scanned;
  }
  return maxstrlen;
}
// NOTE: if you add a new interface above, also add it to
// libomptarget/src/exports and to libomptarget/src/slib_hostrpc.cpp

// ---------------------------------------------------
#else
// ---------------------------------------------------
// This stub is needed to satisfy omp pragma syntax on non-AMDGCN builds.
// It is static and not referenced in this file, so it is marked unused to
// avoid -Wunused-function. It returns a value so that flowing off the end of a
// non-void function cannot occur.
static __attribute__((unused)) int stub(void) { return 0; }
#endif
#pragma omp end declare target
