///
///  hostrpc.cpp: definitions of device stubs and host fallback functions
///

#include "hostrpc.h"
// -----------------------------------------------------------------------------
//
// hostrpc/src/hostrpc.cpp:  device stubs
//
// GPUs typically do not support varargs-style functions. So, to implement
// printf or any other varargs function as a hostrpc service, the compiler
// must generate code to allocate a buffer, fill the buffer with the value of
// each argument, and then call a stub to execute the service with a pointer to
// the buffer. The clang compiler does this in the CGGPUBuiltin.cpp source.
// Here we define the printf_allocate and printf_execute device functions whose
// calls are generated by the clang compiler when it encounters a printf
// statement. printf_allocate is implemented as a hostrpc stub. We assume that
// the host routine for printf_execute will free the buffer that was allocated
// by printf_allocate.

#pragma omp declare target

// Reply of a hostrpc call as returned by hostrpc_invoke: eight 64-bit slots.
// Each stub in this file reads only the slot(s) its service uses (arg0 or
// arg1 below).
typedef struct hostrpc_result_s {
  uint64_t arg0;
  uint64_t arg1;
  uint64_t arg2;
  uint64_t arg3;
  uint64_t arg4;
  uint64_t arg5;
  uint64_t arg6;
  uint64_t arg7;
} hostrpc_result_t;
// hostrpc_invoke is declared here rather than in the header because every
// stub that calls it is defined in this file.
// id:        service selector; callers in this file pass PACK_VERS(<service>).
// arg0-arg7: raw 64-bit payload slots forwarded to the host service.
// Returns the host service's eight-slot reply.
EXTERN hostrpc_result_t hostrpc_invoke(uint32_t id,
                                       uint64_t arg0, uint64_t arg1,
                                       uint64_t arg2, uint64_t arg3,
                                       uint64_t arg4, uint64_t arg5,
                                       uint64_t arg6, uint64_t arg7);

#ifdef __AMDGCN__

// Convenience wrapper around hostrpc_invoke: any trailing payload slots the
// caller omits are sent as zero, so each stub passes only the arguments its
// service needs.
static hostrpc_result_t hostrpc_invoke_zeros(uint32_t id, uint64_t arg0 = 0,
                                             uint64_t arg1 = 0,
                                             uint64_t arg2 = 0,
                                             uint64_t arg3 = 0,
                                             uint64_t arg4 = 0,
                                             uint64_t arg5 = 0,
                                             uint64_t arg6 = 0,
                                             uint64_t arg7 = 0) {
  const hostrpc_result_t reply =
      hostrpc_invoke(id, arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7);
  return reply;
}

// This definition of __ockl_devmem_request needs to override the weak
// symbol for __ockl_devmem_request in ockl.bc because by default ockl
// uses hostcall. But OpenMP uses hostrpc.
//
// size == 0: free request. HOSTRPC_SERVICE_FREE is invoked with addr and
//            result slot arg0 is returned.
// size != 0: allocation request. HOSTRPC_SERVICE_MALLOC is invoked with size
//            and result slot arg1 is returned (presumably the address of the
//            new allocation; confirm against the host handler).
EXTERN uint64_t __ockl_devmem_request(uint64_t addr, uint64_t size) {
  if (size == 0) {
    const hostrpc_result_t freed =
        hostrpc_invoke_zeros(PACK_VERS(HOSTRPC_SERVICE_FREE), addr);
    return freed.arg0;
  }
  const hostrpc_result_t allocated =
      hostrpc_invoke_zeros(PACK_VERS(HOSTRPC_SERVICE_MALLOC), size);
  return allocated.arg1;
}
// Debug print helpers. The trailing underscore suggests they are meant to be
// called from Fortran (TODO confirm). Each prints the label s, then (except
// f90print_) the value pointed to by its second argument, on one line.
// NOTE(review): s is passed to "%s", so it must be NUL-terminated; Fortran
// CHARACTER arguments usually are not. Confirm callers terminate it.
EXTERN void f90print_(char *s) {
  printf("%s\n", s);
}
EXTERN void f90printi_(char *s, int *i) {
  const int value = *i;
  printf("%s %d\n", s, value);
}
EXTERN void f90printl_(char *s, long *i) {
  const long value = *i;
  printf("%s %ld\n", s, value);
}
EXTERN void f90printf_(char *s, float *f) {
  const float value = *f;
  printf("%s %f\n", s, value);
}
EXTERN void f90printd_(char *s, double *d) {
  const double value = *d;
  printf("%s %g\n", s, value);
}

EXTERN char *printf_allocate(uint32_t bufsz) {
  uint64_t arg0 = (uint64_t)bufsz;
  hostrpc_result_t result =
      hostrpc_invoke_zeros(PACK_VERS(HOSTRPC_SERVICE_MALLOC_PRINTF), arg0);
  return (char *)result.arg1;
}
EXTERN char *hostrpc_varfn_uint_allocate(uint32_t bufsz) {
  uint64_t arg0 = (uint64_t)bufsz;
  hostrpc_result_t result =
      hostrpc_invoke_zeros(PACK_VERS(HOSTRPC_SERVICE_MALLOC_PRINTF), arg0);
  return (char *)result.arg1;
}
EXTERN char *hostrpc_varfn_uint64_allocate(uint32_t bufsz) {
  uint64_t arg0 = (uint64_t)bufsz;
  hostrpc_result_t result =
      hostrpc_invoke_zeros(PACK_VERS(HOSTRPC_SERVICE_MALLOC_PRINTF), arg0);
  return (char *)result.arg1;
}
EXTERN char *hostrpc_varfn_double_allocate(uint32_t bufsz) {
  uint64_t arg0 = (uint64_t)bufsz;
  hostrpc_result_t result =
      hostrpc_invoke_zeros(PACK_VERS(HOSTRPC_SERVICE_MALLOC_PRINTF), arg0);
  return (char *)result.arg1;
}

// Send the function address fptr to the host's HOSTRPC_SERVICE_FUNCTIONCALL
// handler. The "0" suffix presumably means the function takes no arguments
// (confirm against the host handler). The host's reply carries nothing this
// stub uses, so it is discarded explicitly rather than stored in an unused
// local.
EXTERN void hostrpc_fptr0(void *fptr) {
  (void)hostrpc_invoke_zeros(PACK_VERS(HOSTRPC_SERVICE_FUNCTIONCALL),
                             (uint64_t)fptr);
}

// Pass a filled printf argument buffer (and its size) to the host's
// HOSTRPC_SERVICE_PRINTF handler. This file assumes the host routine frees
// the buffer. Returns result slot arg0 narrowed to int.
EXTERN int printf_execute(char *print_buffer, uint32_t bufsz) {
  const uint64_t size_slot = bufsz;
  const uint64_t buffer_slot = (uint64_t)print_buffer;
  const hostrpc_result_t reply = hostrpc_invoke_zeros(
      PACK_VERS(HOSTRPC_SERVICE_PRINTF), size_slot, buffer_slot);
  return (int)reply.arg0;
}

// Allocate the bufsz-byte argument buffer for an fprintf call. It uses the
// same HOSTRPC_SERVICE_MALLOC_PRINTF service as printf_allocate. The buffer
// address is returned in result slot arg1.
EXTERN char *fprintf_allocate(uint32_t bufsz) {
  const hostrpc_result_t reply = hostrpc_invoke_zeros(
      PACK_VERS(HOSTRPC_SERVICE_MALLOC_PRINTF), static_cast<uint64_t>(bufsz));
  return reinterpret_cast<char *>(reply.arg1);
}
// Pass a filled fprintf argument buffer (and its size) to the host's
// HOSTRPC_SERVICE_FPRINTF handler. Returns result slot arg0 narrowed to int.
EXTERN int fprintf_execute(char *print_buffer, uint32_t bufsz) {
  const uint64_t size_slot = bufsz;
  const uint64_t buffer_slot = (uint64_t)print_buffer;
  const hostrpc_result_t reply = hostrpc_invoke_zeros(
      PACK_VERS(HOSTRPC_SERVICE_FPRINTF), size_slot, buffer_slot);
  return (int)reply.arg0;
}

// Forward five pointer arguments, as raw 64-bit values, to the host's
// HOSTRPC_SERVICE_FTNASSIGN handler and return its result slot arg0.
// The name suggests a Fortran pointer-assignment helper (confirm against the
// host handler for the meaning of each argument).
EXTERN uint64_t __tgt_fort_ptr_assn_i8(void *varg0, void *varg1, void *varg2,
                                       void *varg3, void *varg4) {
  const hostrpc_result_t reply = hostrpc_invoke_zeros(
      PACK_VERS(HOSTRPC_SERVICE_FTNASSIGN), (uint64_t)varg0, (uint64_t)varg1,
      (uint64_t)varg2, (uint64_t)varg3, (uint64_t)varg4);
  return reply.arg0;
}

// Pass a filled argument buffer (and its size) to the host's
// HOSTRPC_SERVICE_VARFNUINT handler and return result slot arg0 as a 32-bit
// unsigned value.
EXTERN uint32_t hostrpc_varfn_uint_execute(char *print_buffer, uint32_t bufsz) {
  hostrpc_result_t result = hostrpc_invoke_zeros(
      PACK_VERS(HOSTRPC_SERVICE_VARFNUINT), (uint64_t)bufsz,
      (uint64_t)print_buffer);
  // Narrow directly to the declared unsigned return type. Going through a
  // signed int would be implementation-defined for values above INT_MAX
  // (before C++20) and misstates the intended type.
  return (uint32_t)result.arg0;
}
// Pass a filled argument buffer (and its size) to the host's
// HOSTRPC_SERVICE_VARFNUINT64 handler and return result slot arg0 unchanged.
EXTERN uint64_t hostrpc_varfn_uint64_execute(char *print_buffer,
                                             uint32_t bufsz) {
  const uint64_t size_slot = bufsz;
  const uint64_t buffer_slot = (uint64_t)print_buffer;
  return hostrpc_invoke_zeros(PACK_VERS(HOSTRPC_SERVICE_VARFNUINT64),
                              size_slot, buffer_slot)
      .arg0;
}
// Pass a filled argument buffer (and its size) to the host's
// HOSTRPC_SERVICE_VARFNDOUBLE handler. The handler's double result is returned
// bit-for-bit in the 64-bit result slot arg0.
EXTERN double hostrpc_varfn_double_execute(char *print_buffer, uint32_t bufsz) {
  hostrpc_result_t result = hostrpc_invoke_zeros(
      PACK_VERS(HOSTRPC_SERVICE_VARFNDOUBLE), (uint64_t)bufsz,
      (uint64_t)print_buffer);
  // Reinterpret the slot's bits as a double with memcpy. Reading an inactive
  // union member is undefined behavior in C++. A same-size memcpy between
  // trivially copyable objects is well defined, and compilers typically reduce
  // it to a register move.
  static_assert(sizeof(double) == sizeof(uint64_t),
                "double must be 64 bits to round-trip through a hostrpc slot");
  double value;
  __builtin_memcpy(&value, &result.arg0, sizeof value);
  return value;
}

// -----------------------------------------------------------------------------
//
// vector_product_zeros: Example stub to demonstrate hostrpc services
//
// This is an example hostrpc stub for a service called vector_product_zeros.
// This function calculates C = A*B and returns the number of zeros.
// Naturally, one would typically do this type of operation on a GPU.
// But this is a demo to illustrate the use of hostrpc to run a service
// on the host.  The host service for HOSTRPC_SERVICE_DEMO is in
// llvm-project/openmp/libomptarget/plugins/hsa/impl/hostrpc_handlers.c
// After copying the vectors from the GPU the service handler calls this
// routine on the host.
//

// Device stub for the HOSTRPC_SERVICE_DEMO service. N and the three vector
// addresses are sent as raw 64-bit slots. The demo service reports its zero
// count in result slot arg1, not arg0.
EXTERN int vector_product_zeros(int N, int *A, int *B, int *C) {
  // Pass these pointers to the host for memcpy.
  const hostrpc_result_t reply = hostrpc_invoke_zeros(
      PACK_VERS(HOSTRPC_SERVICE_DEMO), (uint64_t)N, (uint64_t)A, (uint64_t)B,
      (uint64_t)C);
  return (int)reply.arg1;
}

// This utility is used for printf arguments that are variable-length strings.
// The clang compiler will generate calls to this only when a string length is
// not a compile-time constant.
//
// Returns the number of bytes up to and including the first NUL found among
// the first maxstrlen bytes of instr. If there is no NUL in that range, it
// returns maxstrlen.
EXTERN uint32_t __strlen_max(char *instr, uint32_t maxstrlen) {
  uint32_t len = 0;
  while (len < maxstrlen) {
    if (instr[len] == '\0')
      return len + 1;
    ++len;
  }
  return maxstrlen;
}
// NOTE: if you add a new interface above, also add it to
// libomptarget/src/exports and to libomptarget/src/slib_hostrpc.cpp

// ---------------------------------------------------
#else
// ---------------------------------------------------
// This stub is needed to satisfy omp pragma syntax (presumably so that the
// declare target region is not empty when __AMDGCN__ is undefined).
// Nothing calls it, so it is marked unused to avoid -Wunused-function.
// There is deliberately no ';' after the body: that would be an empty
// declaration, which -Wextra-semi flags.
static __attribute__((unused)) int stub() { return 0; }
#endif
#pragma omp end declare target
