//===----- CGAMPRuntime.cpp - Interface to C++ AMP Runtime ----------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This provides the base class for C++ AMP code generation. Subclasses may
// specialize it to implement code generation for specific C++ AMP runtime
// libraries.
//
//===----------------------------------------------------------------------===//

#include "CodeGenFunction.h"
#include "CGAMPRuntime.h"
#include "clang/AST/Decl.h"
#include "clang/AST/ExprCXX.h"
#include "CGCall.h"
#include "TargetInfo.h"

namespace clang {
namespace CodeGen {

/// Out-of-line destructor definition; the class owns no resources that need
/// explicit cleanup here.
CGAMPRuntime::~CGAMPRuntime() = default;

/// Creates an instance of a C++ AMP runtime class.
CGAMPRuntime *CreateAMPRuntime(CodeGenModule &CGM) {
  return new CGAMPRuntime(CGM);
}
/// Looks up the index initializer of \p IndexTy. This is a method of the class
/// type whose first `annotate` attribute is "__cxxamp_opencl_index".
///
/// \returns the matching method, or null when \p IndexTy is not a C++ record
///          type or has no such method. If several methods match, the last
///          one visited (in member iteration order) is returned.
static CXXMethodDecl *findValidIndexType(QualType IndexTy) {
  const CXXRecordDecl *IndexClass = IndexTy->getAsCXXRecordDecl();
  if (!IndexClass)
    return nullptr;

  CXXMethodDecl *Found = nullptr;
  for (CXXMethodDecl *Method : IndexClass->methods()) {
    // getAttr<> yields null when the attribute is absent.
    const AnnotateAttr *Annotation = Method->getAttr<AnnotateAttr>();
    if (Annotation && Annotation->getAnnotation() == "__cxxamp_opencl_index")
      Found = Method; // keep scanning: the last match wins
  }
  return Found;
}


/// Emits a call to the C++ AMP deserialization constructor of the class that
/// declares \p Trampoline. The object is constructed in place at \p ai.
///
/// Trampoline arguments from \p Args are consumed in order and matched against
/// the deserialization constructor's parameters. Iteration stops as soon as
/// either list is exhausted.
///  - A non-reference parameter consumes one trampoline argument, which is
///    forwarded as-is.
///  - A reference-to-class parameter is rebuilt in a stack temporary. The
///    referenced class's own deserialization constructor is called, consuming
///    one trampoline argument per parameter of that constructor, and the
///    temporary's address is passed. Without HSAExtension this is the only
///    supported reference form. With HSAExtension it is used only for GPU
///    array types (hc::array).
///  - With HSAExtension, any other reference parameter consumes one trampoline
///    argument, passed as a pointer to the referenced type (capture by
///    reference).
///
/// \param CGF        the function being emitted (the trampoline body).
/// \param Trampoline the trampoline; must be a CXXMethodDecl.
/// \param Args       the trampoline's parameters, in order.
/// \param ai         storage for the object; passed as "this".
void CGAMPRuntime::EmitCXXAMPDeserializer(CodeGenFunction &CGF,
  const FunctionDecl *Trampoline, FunctionArgList& Args,
  Address& ai) {
  ASTContext &Ctx = CGM.getContext();

  // Trampoline is required to be a method of the class being deserialized.
  const CXXRecordDecl *ClassDecl = cast<CXXMethodDecl>(Trampoline)->getParent();

  CXXConstructorDecl *DeserializeConstructor =
    dyn_cast<CXXConstructorDecl>(ClassDecl->getCXXAMPDeserializationConstructor());
  assert(DeserializeConstructor);

  CallArgList DeserializerArgs;

  // this
  DeserializerArgs.add(RValue::get(ai.getPointer()),
                       DeserializeConstructor->getThisType());

  // Cursor over the trampoline arguments. It is shared by the main loop and
  // by the member-deserialization helper, which may consume several at once.
  FunctionArgList::iterator I = Args.begin();
  const FunctionArgList::iterator E = Args.end();

  // Builds an rvalue DeclRefExpr naming trampoline parameter Param, typed Ty.
  auto RefTo = [&Ctx](const VarDecl *Param, QualType Ty) -> Expr * {
    return DeclRefExpr::Create(Ctx, NestedNameSpecifierLoc(), SourceLocation(),
                               const_cast<VarDecl *>(Param), false,
                               SourceLocation(), Ty, VK_RValue);
  };

  // Rebuilds a by-reference class member in a stack temporary. It calls the
  // member class's deserialization constructor with the next trampoline
  // arguments (advancing I), then returns the temporary's address.
  auto EmitDeserializedTemp = [&](QualType MemberType) -> Address {
    CXXRecordDecl *MemberClass = MemberType.getTypePtr()->getAsCXXRecordDecl();

    CXXConstructorDecl *MemberDeserializer = dyn_cast<CXXConstructorDecl>(
                          MemberClass->getCXXAMPDeserializationConstructor());
    assert(MemberDeserializer);

    std::vector<Expr *> MemberArgDeclRefs;
    for (const ParmVarDecl *MemberParam : MemberDeserializer->parameters()) {
      assert(I != E &&
             "too few trampoline arguments for member deserializer");
      MemberArgDeclRefs.push_back(RefTo(*I, MemberParam->getType()));
      ++I;
    }

    // Allocate "this" for the member-referenced object.
    Address Temp = CGF.CreateMemTemp(MemberType);

    // Emit code to call the deserializing constructor of the temp object.
    CXXConstructExpr *CXXCE = CXXConstructExpr::Create(Ctx,
                                                       MemberType,
                                                       SourceLocation(),
                                                       MemberDeserializer,
                                                       false,
                                                       MemberArgDeclRefs,
                                                       false, false, false, false,
                                                       CXXConstructExpr::CK_Complete,
                                                       SourceLocation());

    auto Slot = AggValueSlot::forAddr(Temp, MemberType.getQualifiers(),
                                      AggValueSlot::IsNotDestructed,
                                      AggValueSlot::DoesNotNeedGCBarriers,
                                      AggValueSlot::IsNotAliased,
                                      AggValueSlot::DoesNotOverlap);
    CGF.EmitCXXConstructorCall(MemberDeserializer, Ctor_Complete, false, false,
                               Slot, CXXCE);
    return Temp;
  };

  const bool HSAExtension = CGM.getLangOpts().HSAExtension;

  for (CXXConstructorDecl::param_iterator CPI = DeserializeConstructor->param_begin(),
       CPE = DeserializeConstructor->param_end();
       I != E && CPI != CPE; ++CPI) {
    const QualType ParamType = (*CPI)->getType();
    // Reference types are only allowed to have one level; i.e. no
    // class base {&int}; class foo { bar &base; };
    const QualType MemberType = ParamType.getNonReferenceType();

    if (MemberType == ParamType) {
      // By-value parameter: forward the trampoline argument unchanged.
      RValue ArgRV = CGF.EmitAnyExpr(RefTo(*I, (*I)->getType()));
      DeserializerArgs.add(ArgRV, ParamType);
      ++I;
      continue;
    }

    const Type *MemberTy = MemberType.getTypePtr();
    if (!HSAExtension)
      assert(MemberTy->isClassType() &&
             "Only supporting taking reference of classes");

    // Without HSA every reference member is deserialized into a temporary.
    // With HSA only hc::array is, because it must still be serialized as a
    // traditional C++AMP object.
    if (!HSAExtension ||
        (MemberTy->isClassType() && MemberTy->isGPUArrayType())) {
      Address Temp = EmitDeserializedTemp(MemberType);
      DeserializerArgs.add(RValue::get(Temp.getPointer()), ParamType);
    } else {
      // Capture by reference for HSA: the argument is passed as a pointer.
      RValue ArgRV = CGF.EmitAnyExpr(RefTo(*I, (*I)->getType()));
      DeserializerArgs.add(ArgRV, Ctx.getPointerType(MemberType));
      ++I;
    }
  }

  // Emit code to call the deserializing constructor
  llvm::Constant *Callee =
    CGM.getAddrOfCXXStructor(GlobalDecl(DeserializeConstructor, Ctor_Complete));

  const FunctionProtoType *FPT =
      DeserializeConstructor->getType()->castAs<FunctionProtoType>();

  const CGFunctionInfo &DesFnInfo =
    CGM.getTypes().arrangeCXXStructorDeclaration(
      GlobalDecl(DeserializeConstructor, Ctor_Complete));

  // Argument slot 0 is "this"; slot Idx holds constructor parameter Idx-1.
  for (unsigned Idx = 1, N = DeserializerArgs.size(); Idx != N; ++Idx) {
    QualType T = FPT->getParamType(Idx - 1);
    // EmitFromMemory is necessary in case function has bool parameter.
    if (T->isBooleanType()) {
      DeserializerArgs[Idx] =
          CallArg(RValue::get(CGF.EmitFromMemory(
                      DeserializerArgs[Idx].getKnownRValue().getScalarVal(), T)),
                  T);
    }
  }
  CGF.EmitCall(DesFnInfo, CGCallee::forDirect(Callee), ReturnValueSlot(),
               DeserializerArgs);
}

/// Emits the body of a C++ AMP kernel trampoline.
///
/// Operations:
///  1. Allocate the object of the class declaring \p Trampoline on the stack.
///     If that class has a deserialization constructor, rebuild the object from
///     \p Args via EmitCXXAMPDeserializer.
///  2. Select the operator() to call, considering only restrict(amp)
///     operator() methods. Prefer the first one taking a single parameter
///     whose (non-reference) type has a method annotated
///     "__cxxamp_opencl_index". Otherwise fall back to one taking no
///     parameters. If neither exists, report err_amp_ill_formed_functor and
///     emit nothing further.
///  3. For the one-parameter form, allocate the index object. Unless compiling
///     for AMPCPU, initialize it through the annotated method, which may be a
///     constructor or an ordinary member function.
///  4. Call the selected operator(), then apply its target attributes to the
///     function being emitted (CGF.CurFn).
void CGAMPRuntime::EmitTrampolineBody(CodeGenFunction &CGF,
  const FunctionDecl *Trampoline, FunctionArgList& Args) {
  // Trampoline is required to be a method of the functor class.
  const CXXRecordDecl *ClassDecl = cast<CXXMethodDecl>(Trampoline)->getParent();

  // Allocate "this"
  Address ai = CGF.CreateMemTemp(QualType(ClassDecl->getTypeForDecl(), 0));
  // Locate the constructor to call
  if (ClassDecl->getCXXAMPDeserializationConstructor())
    EmitCXXAMPDeserializer(CGF, Trampoline, Args, ai);

  // Locate the operator() to call. For the indexed form, also record the index
  // type and its annotated initializer.
  CXXMethodDecl *KernelDecl = nullptr;
  CXXMethodDecl *KernelDeclNoArg = nullptr;
  CXXMethodDecl *IndexConstructor = nullptr;
  QualType IndexTy;
  for (CXXMethodDecl *MethodDecl : ClassDecl->methods()) {
    if (!MethodDecl->isOverloadedOperator() ||
        MethodDecl->getOverloadedOperator() != OO_Call ||
        !MethodDecl->hasAttr<CXXAMPRestrictAMPAttr>())
      continue;

    const unsigned NumParams = MethodDecl->getNumParams();
    if (NumParams == 0) {
      KernelDeclNoArg = MethodDecl;
    } else if (NumParams == 1) {
      QualType CandidateTy =
          MethodDecl->getParamDecl(0)->getType().getNonReferenceType();
      if (CXXMethodDecl *Init = findValidIndexType(CandidateTy)) {
        KernelDecl = MethodDecl;
        IndexTy = CandidateTy;
        IndexConstructor = Init;
        break; // the indexed form is preferred; stop at the first one
      }
    }
    // Operators with more than one parameter are never kernels.
  }

  // in case we couldn't find any kernel declarator
  // raise error
  if (!KernelDecl && !KernelDeclNoArg) {
    CGF.CGM.getDiags().Report(ClassDecl->getLocation(), diag::err_amp_ill_formed_functor);
    return;
  }

  CXXMethodDecl *Kernel = KernelDecl ? KernelDecl : KernelDeclNoArg;
  // castAs<> looks through type sugar, unlike dyn_cast on the raw type pointer.
  const FunctionType *MT = Kernel->getType()->castAs<FunctionType>();

  // Invoke this->operator(index)
  // Prepare the operator() to call
  llvm::FunctionType *fnType =
    CGM.getTypes().GetFunctionType(CGM.getTypes().arrangeCXXMethodDeclaration(Kernel));
  llvm::Constant *fnAddr = CGM.GetAddrOfFunction(Kernel, fnType);

  // Prepare argument
  CallArgList KArgs;

  // this
  KArgs.add(RValue::get(ai.getPointer()), Kernel->getThisType());

  if (KernelDecl) {
    // Allocate Index
    Address index = CGF.CreateMemTemp(IndexTy);
    assert(IndexConstructor);

    // Emit code to call the Concurrency::index<1>::__cxxamp_opencl_index()
    if (!CGF.getLangOpts().AMPCPU) {
      if (CXXConstructorDecl *Constructor =
              dyn_cast<CXXConstructorDecl>(IndexConstructor)) {
        CXXConstructExpr *CXXCE = CXXConstructExpr::Create(CGM.getContext(),
                                                           IndexTy,
                                                           SourceLocation(),
                                                           Constructor,
                                                           false,
                                                           ArrayRef<Expr*>(),
                                                           false, false, false, false,
                                                           CXXConstructExpr::CK_Complete,
                                                           SourceLocation());

        auto currAVS = AggValueSlot::forAddr(index, IndexTy.getQualifiers(),
                                             AggValueSlot::IsNotDestructed,
                                             AggValueSlot::DoesNotNeedGCBarriers,
                                             AggValueSlot::IsNotAliased,
                                             AggValueSlot::DoesNotOverlap);
        CGF.EmitCXXConstructorCall(Constructor, Ctor_Complete, false, false,
                                   currAVS, CXXCE);
      } else {
        // The annotated initializer is an ordinary member function: call it on
        // the index object.
        llvm::FunctionType *indexInitType = CGM.getTypes().GetFunctionType(
            CGM.getTypes().arrangeCXXMethodDeclaration(IndexConstructor));

        llvm::Constant *indexInitAddr =
            CGM.GetAddrOfFunction(IndexConstructor, indexInitType);

        CGF.EmitCXXMemberOrOperatorCall(IndexConstructor, CGCallee::forDirect(indexInitAddr),
                                        ReturnValueSlot(), index.getPointer(),
                                        /*ImplicitParam=*/nullptr, QualType(),
                                        /*CallExpr=*/nullptr, /*RtlArgs=*/nullptr);
      }
    }

    // *index
    // index is of reference type of IndexTy.
    KArgs.add(RValue::get(index.getPointer()),
              CGF.getContext().getLValueReferenceType(IndexTy));
  }

  const CGFunctionInfo &FnInfo = CGM.getTypes().arrangeFreeFunctionCall(KArgs, MT, false);
  CGF.EmitCall(FnInfo, CGCallee::forDirect(fnAddr), ReturnValueSlot(), KArgs);
  CGM.getTargetCodeGenInfo().setTargetAttributes(Kernel, CGF.CurFn, CGM);
}

/// Emits a body that stores the trampoline's mangled name into
/// CGF.ReturnValue.
///
/// Looks up the first method of the class declaring \p Trampoline whose first
/// `annotate` attribute is "__cxxamp_trampoline". It is an assertion failure
/// if none exists. That method's mangled name is placed in a private, constant,
/// unnamed_addr global string named "__cxxamp_trampoline.kernelname", and a
/// pointer to the string's first character is stored into CGF.ReturnValue.
///
/// \param Args unused.
void CGAMPRuntime::EmitTrampolineNameBody(CodeGenFunction &CGF,
  const FunctionDecl *Trampoline, FunctionArgList& Args) {
  // Trampoline is required to be a method of the functor class.
  const CXXRecordDecl *ClassDecl = cast<CXXMethodDecl>(Trampoline)->getParent();

  // Locate the annotated trampoline method.
  CXXMethodDecl *TrampolineDecl = nullptr;
  for (CXXMethodDecl *MethodDecl : ClassDecl->methods()) {
    // getAttr<> yields null when the attribute is absent.
    const AnnotateAttr *Annotation = MethodDecl->getAttr<AnnotateAttr>();
    if (Annotation && Annotation->getAnnotation() == "__cxxamp_trampoline") {
      TrampolineDecl = MethodDecl;
      break;
    }
  }
  assert(TrampolineDecl && "Trampoline not declared!");

  GlobalDecl GD(TrampolineDecl);
  llvm::Constant *S = llvm::ConstantDataArray::getString(CGM.getLLVMContext(),
    CGM.getMangledName(GD));
  llvm::GlobalVariable *GV = new llvm::GlobalVariable(CGM.getModule(), S->getType(),
    true, llvm::GlobalValue::PrivateLinkage, S, "__cxxamp_trampoline.kernelname");
  GV->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);

  // GEP(0, 0): decay the string array to a pointer to its first element.
  llvm::Constant *Zero = CGF.Builder.getInt32(0);
  llvm::Constant *Indices[] = {Zero, Zero};
  llvm::Constant *const_ptr =
      llvm::ConstantExpr::getGetElementPtr(GV->getValueType(), GV, Indices);
  CGF.Builder.CreateStore(const_ptr, CGF.ReturnValue);
}
} // namespace CodeGen
} // namespace clang
