//===------- MicrosoftCXXABI.cpp - AST support for the Microsoft C++ ABI --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This provides C++ AST support targeting the Microsoft Visual C++
// ABI.
//
//===----------------------------------------------------------------------===//

#include "CXXABI.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/Attr.h"
#include "clang/AST/DeclCXX.h"
#include "clang/AST/MangleNumberingContext.h"
#include "clang/AST/RecordLayout.h"
#include "clang/AST/Type.h"
#include "clang/Basic/TargetInfo.h"

using namespace clang;

// Until the interface is revised, this is a clone of `ItaniumNumberingContext`
// from `lib/AST/ItaniumCXXABI.cpp`.
// {{{ BEGIN CLONE
namespace {

/// Finds the name under which an anonymous union variable is numbered.
///
/// According to Itanium C++ ABI 5.1.2, the name of an anonymous union is
/// considered to be the name of the first named data member found by a
/// pre-order, depth-first, declaration-order walk of the data members of the
/// anonymous union. If there is no such data member (i.e., if all of the data
/// members in the union are unnamed), then there is no way for a program to
/// refer to the anonymous union, and there is therefore no need to mangle its
/// name.
///
/// \param VD a variable whose type must be a union record type (asserted).
/// \returns the identifier of the first named data member of the union, or
///          nullptr if the union has no named data member.
static const IdentifierInfo *findAnonymousUnionVarDeclName(const VarDecl& VD) {
  const auto *UnionTy = VD.getType()->getAs<RecordType>();
  assert(UnionTy && "type of VarDecl is expected to be RecordType.");
  const RecordDecl *Union = UnionTy->getDecl();
  assert(Union->isUnion() && "RecordType is expected to be a union.");

  const FieldDecl *FirstNamed = Union->findFirstNamedDataMember();
  return FirstNamed ? FirstNamed->getIdentifier() : nullptr;
}

/// The name of a decomposition declaration.
struct DecompositionDeclName {
  using BindingArray = ArrayRef<const BindingDecl*>;

  /// Representative example of a set of bindings with these names.
  BindingArray Bindings;

  /// Iterators over the sequence of identifiers in the name.
  struct Iterator
      : llvm::iterator_adaptor_base<Iterator, BindingArray::const_iterator,
                                    std::random_access_iterator_tag,
                                    const IdentifierInfo *> {
    Iterator(BindingArray::const_iterator It) : iterator_adaptor_base(It) {}
    const IdentifierInfo *operator*() const {
      return (*this->I)->getIdentifier();
    }
  };
  Iterator begin() const { return Iterator(Bindings.begin()); }
  Iterator end() const { return Iterator(Bindings.end()); }
};
}

namespace llvm {
/// Lets DecompositionDeclName be used as a DenseMap key.
///
/// Two names are equal when their bindings carry the same identifiers in the
/// same order. The empty and tombstone sentinels are borrowed from the
/// DenseMapInfo of the underlying ArrayRef.
template <>
struct DenseMapInfo<DecompositionDeclName> {
  using ArrayInfo = llvm::DenseMapInfo<ArrayRef<const BindingDecl*>>;
  using IdentInfo = llvm::DenseMapInfo<const IdentifierInfo*>;

  static DecompositionDeclName getEmptyKey() {
    DecompositionDeclName Empty{ArrayInfo::getEmptyKey()};
    return Empty;
  }

  static DecompositionDeclName getTombstoneKey() {
    DecompositionDeclName Tombstone{ArrayInfo::getTombstoneKey()};
    return Tombstone;
  }

  /// Hashes the identifier sequence; sentinels must never be hashed.
  static unsigned getHashValue(DecompositionDeclName Key) {
    assert(!isEqual(Key, getEmptyKey()) && !isEqual(Key, getTombstoneKey()));
    return llvm::hash_combine_range(Key.begin(), Key.end());
  }

  static bool isEqual(DecompositionDeclName LHS, DecompositionDeclName RHS) {
    const auto EmptyBindings = ArrayInfo::getEmptyKey();
    const auto TombstoneBindings = ArrayInfo::getTombstoneKey();

    // A sentinel on the left matches only the same sentinel on the right.
    if (ArrayInfo::isEqual(LHS.Bindings, EmptyBindings))
      return ArrayInfo::isEqual(RHS.Bindings, EmptyBindings);
    if (ArrayInfo::isEqual(LHS.Bindings, TombstoneBindings))
      return ArrayInfo::isEqual(RHS.Bindings, TombstoneBindings);

    // Otherwise compare the identifier sequences element by element.
    if (LHS.Bindings.size() != RHS.Bindings.size())
      return false;
    return std::equal(LHS.begin(), LHS.end(), RHS.begin());
  }
};
} // namespace llvm

namespace {

/// Keeps track of the mangled names of lambda expressions and block
/// literals within a particular context.
class ItaniumNumberingContext : public MangleNumberingContext {
  llvm::DenseMap<const Type *, unsigned> ManglingNumbers;
  llvm::DenseMap<const IdentifierInfo *, unsigned> VarManglingNumbers;
  llvm::DenseMap<const IdentifierInfo *, unsigned> TagManglingNumbers;
  llvm::DenseMap<DecompositionDeclName, unsigned>
      DecompsitionDeclManglingNumbers;

public:
  unsigned getManglingNumber(const CXXMethodDecl *CallOperator) override {
    const FunctionProtoType *Proto =
        CallOperator->getType()->getAs<FunctionProtoType>();
    ASTContext &Context = CallOperator->getASTContext();

    FunctionProtoType::ExtProtoInfo EPI;
    EPI.Variadic = Proto->isVariadic();
    QualType Key =
        Context.getFunctionType(Context.VoidTy, Proto->getParamTypes(), EPI);
    Key = Context.getCanonicalType(Key);
    return ++ManglingNumbers[Key->castAs<FunctionProtoType>()];
  }

  unsigned getManglingNumber(const BlockDecl *BD) override {
    const Type *Ty = nullptr;
    return ++ManglingNumbers[Ty];
  }

  unsigned getStaticLocalNumber(const VarDecl *VD) override {
    return 0;
  }

  /// Variable decls are numbered by identifier.
  unsigned getManglingNumber(const VarDecl *VD, unsigned) override {
    if (auto *DD = dyn_cast<DecompositionDecl>(VD)) {
      DecompositionDeclName Name{DD->bindings()};
      return ++DecompsitionDeclManglingNumbers[Name];
    }

    const IdentifierInfo *Identifier = VD->getIdentifier();
    if (!Identifier) {
      // VarDecl without an identifier represents an anonymous union
      // declaration.
      Identifier = findAnonymousUnionVarDeclName(*VD);
    }
    return ++VarManglingNumbers[Identifier];
  }

  unsigned getManglingNumber(const TagDecl *TD, unsigned) override {
    return ++TagManglingNumbers[TD->getIdentifier()];
  }
};

} // End anonymous namesapce
// END CLONE }}}

namespace {

/// Numbers things which need to correspond across multiple TUs.
/// Typically these are things like static locals, lambdas, or blocks.
class MicrosoftNumberingContext : public MangleNumberingContext {
  llvm::DenseMap<const Type *, unsigned> ManglingNumbers;
  unsigned LambdaManglingNumber;
  unsigned StaticLocalNumber;
  unsigned StaticThreadlocalNumber;

public:
  MicrosoftNumberingContext()
      : MangleNumberingContext(), LambdaManglingNumber(0),
        StaticLocalNumber(0), StaticThreadlocalNumber(0) {}

  unsigned getManglingNumber(const CXXMethodDecl *CallOperator) override {
    return ++LambdaManglingNumber;
  }

  unsigned getManglingNumber(const BlockDecl *BD) override {
    const Type *Ty = nullptr;
    return ++ManglingNumbers[Ty];
  }

  unsigned getStaticLocalNumber(const VarDecl *VD) override {
    if (VD->getTLSKind())
      return ++StaticThreadlocalNumber;
    return ++StaticLocalNumber;
  }

  unsigned getManglingNumber(const VarDecl *VD,
                             unsigned MSLocalManglingNumber) override {
    return MSLocalManglingNumber;
  }

  unsigned getManglingNumber(const TagDecl *TD,
                             unsigned MSLocalManglingNumber) override {
    return MSLocalManglingNumber;
  }
};

/// Numbering context used with the Microsoft ABI when compiling CUDA/HIP.
///
/// Every host-side query is forwarded to a MicrosoftNumberingContext, while
/// device-side lambda numbers come from a separate Itanium-style context.
class MSHIPNumberingContext : public MangleNumberingContext {
  /// Answers all host-side queries.
  MicrosoftNumberingContext HostCtx;
  /// Answers device-side lambda queries only.
  ItaniumNumberingContext DeviceCtx;

public:
  bool hasDeviceMangleNumberingContext() override { return true; }

  unsigned getDeviceManglingNumber(const CXXMethodDecl *CallOperator) override {
    return DeviceCtx.getManglingNumber(CallOperator);
  }

  // Host-side numbering: pure delegation to HostCtx.

  unsigned getManglingNumber(const CXXMethodDecl *CallOperator) override {
    return HostCtx.getManglingNumber(CallOperator);
  }

  unsigned getManglingNumber(const BlockDecl *BD) override {
    return HostCtx.getManglingNumber(BD);
  }

  unsigned getStaticLocalNumber(const VarDecl *VD) override {
    return HostCtx.getStaticLocalNumber(VD);
  }

  unsigned getManglingNumber(const VarDecl *VD,
                             unsigned MSLocalManglingNumber) override {
    return HostCtx.getManglingNumber(VD, MSLocalManglingNumber);
  }

  unsigned getManglingNumber(const TagDecl *TD,
                             unsigned MSLocalManglingNumber) override {
    return HostCtx.getManglingNumber(TD, MSLocalManglingNumber);
  }
};

class MicrosoftCXXABI : public CXXABI {
  ASTContext &Context;
  llvm::SmallDenseMap<CXXRecordDecl *, CXXConstructorDecl *> RecordToCopyCtor;

  llvm::SmallDenseMap<TagDecl *, DeclaratorDecl *>
      UnnamedTagDeclToDeclaratorDecl;
  llvm::SmallDenseMap<TagDecl *, TypedefNameDecl *>
      UnnamedTagDeclToTypedefNameDecl;

public:
  MicrosoftCXXABI(ASTContext &Ctx) : Context(Ctx) { }

  MemberPointerInfo
  getMemberPointerInfo(const MemberPointerType *MPT) const override;

  CallingConv getDefaultMethodCallConv(bool isVariadic) const override {
    if (!isVariadic &&
        Context.getTargetInfo().getTriple().getArch() == llvm::Triple::x86)
      return CC_X86ThisCall;
    return Context.getTargetInfo().getDefaultCallingConv();
  }

  bool isNearlyEmpty(const CXXRecordDecl *RD) const override {
    llvm_unreachable("unapplicable to the MS ABI");
  }

  const CXXConstructorDecl *
  getCopyConstructorForExceptionObject(CXXRecordDecl *RD) override {
    return RecordToCopyCtor[RD];
  }

  void
  addCopyConstructorForExceptionObject(CXXRecordDecl *RD,
                                       CXXConstructorDecl *CD) override {
    assert(CD != nullptr);
    assert(RecordToCopyCtor[RD] == nullptr || RecordToCopyCtor[RD] == CD);
    RecordToCopyCtor[RD] = CD;
  }

  void addTypedefNameForUnnamedTagDecl(TagDecl *TD,
                                       TypedefNameDecl *DD) override {
    TD = TD->getCanonicalDecl();
    DD = DD->getCanonicalDecl();
    TypedefNameDecl *&I = UnnamedTagDeclToTypedefNameDecl[TD];
    if (!I)
      I = DD;
  }

  TypedefNameDecl *getTypedefNameForUnnamedTagDecl(const TagDecl *TD) override {
    return UnnamedTagDeclToTypedefNameDecl.lookup(
        const_cast<TagDecl *>(TD->getCanonicalDecl()));
  }

  void addDeclaratorForUnnamedTagDecl(TagDecl *TD,
                                      DeclaratorDecl *DD) override {
    TD = TD->getCanonicalDecl();
    DD = cast<DeclaratorDecl>(DD->getCanonicalDecl());
    DeclaratorDecl *&I = UnnamedTagDeclToDeclaratorDecl[TD];
    if (!I)
      I = DD;
  }

  DeclaratorDecl *getDeclaratorForUnnamedTagDecl(const TagDecl *TD) override {
    return UnnamedTagDeclToDeclaratorDecl.lookup(
        const_cast<TagDecl *>(TD->getCanonicalDecl()));
  }

  std::unique_ptr<MangleNumberingContext>
  createMangleNumberingContext() const override {
    if (Context.getLangOpts().CUDA)
      return std::make_unique<MSHIPNumberingContext>();
    return std::make_unique<MicrosoftNumberingContext>();
  }
};
}

// getNumBases() seems to only give us the number of direct bases, and not the
// total.  This function tells us if we inherit from anybody that uses MI, or if
// we have a non-primary base class, which uses the multiple inheritance model.
static bool usesMultipleInheritanceModel(const CXXRecordDecl *RD) {
  while (RD->getNumBases() > 0) {
    if (RD->getNumBases() > 1)
      return true;
    assert(RD->getNumBases() == 1);
    const CXXRecordDecl *Base =
        RD->bases_begin()->getType()->getAsCXXRecordDecl();
    if (RD->isPolymorphic() && !Base->isPolymorphic())
      return true;
    RD = Base;
  }
  return false;
}

/// Derives the MS inheritance model from the shape of this class hierarchy.
MSInheritanceAttr::Spelling CXXRecordDecl::calculateInheritanceModel() const {
  // With no definition yet, or while the base list is still being parsed,
  // the hierarchy is not known.
  if (!hasDefinition() || isParsingBaseSpecifiers())
    return MSInheritanceAttr::Keyword_unspecified_inheritance;
  if (getNumVBases() != 0)
    return MSInheritanceAttr::Keyword_virtual_inheritance;
  return usesMultipleInheritanceModel(this)
             ? MSInheritanceAttr::Keyword_multiple_inheritance
             : MSInheritanceAttr::Keyword_single_inheritance;
}

/// Returns the inheritance model stored in this class's MSInheritanceAttr.
/// The attribute must already be attached (asserted).
MSInheritanceAttr::Spelling
CXXRecordDecl::getMSInheritanceModel() const {
  assert(hasAttr<MSInheritanceAttr>() &&
         "Expected MSInheritanceAttr on the CXXRecordDecl!");
  return getAttr<MSInheritanceAttr>()->getSemanticSpelling();
}

/// Returns the vtordisp mode from an explicit MSVtorDispAttr on this class,
/// falling back to the VtorDispMode language option when there is none.
MSVtorDispAttr::Mode CXXRecordDecl::getMSVtorDispMode() const {
  MSVtorDispAttr *Explicit = getAttr<MSVtorDispAttr>();
  if (!Explicit)
    return MSVtorDispAttr::Mode(getASTContext().getLangOpts().VtorDispMode);
  return Explicit->getVtorDispMode();
}

// Returns the number of pointer and integer slots used to represent a member
// pointer in the MS C++ ABI.
//
// Member function pointers have the following general form;  however, fields
// are dropped as permitted (under the MSVC interpretation) by the inheritance
// model of the actual class.
//
//   struct {
//     // A pointer to the member function to call.  If the member function is
//     // virtual, this will be a thunk that forwards to the appropriate vftable
//     // slot.
//     void *FunctionPointerOrVirtualThunk;
//
//     // An offset to add to the address of the vbtable pointer after
//     // (possibly) selecting the virtual base but before resolving and calling
//     // the function.
//     // Only needed if the class has any virtual bases or bases at a non-zero
//     // offset.
//     int NonVirtualBaseAdjustment;
//
//     // The offset of the vb-table pointer within the object.  Only needed for
//     // incomplete types.
//     int VBPtrOffset;
//
//     // An offset within the vb-table that selects the virtual base containing
//     // the member.  Loading from this offset produces a new offset that is
//     // added to the address of the vb-table pointer to produce the base.
//     int VirtualBaseAdjustmentOffset;
//   };
static std::pair<unsigned, unsigned>
getMSMemberPointerSlots(const MemberPointerType *MPT) {
  const CXXRecordDecl *RD = MPT->getMostRecentCXXRecordDecl();
  MSInheritanceAttr::Spelling Inheritance = RD->getMSInheritanceModel();
  unsigned Ptrs = 0;
  unsigned Ints = 0;
  if (MPT->isMemberFunctionPointer())
    Ptrs = 1;
  else
    Ints = 1;
  if (MSInheritanceAttr::hasNVOffsetField(MPT->isMemberFunctionPointer(),
                                          Inheritance))
    Ints++;
  if (MSInheritanceAttr::hasVBPtrOffsetField(Inheritance))
    Ints++;
  if (MSInheritanceAttr::hasVBTableOffsetField(Inheritance))
    Ints++;
  return std::make_pair(Ptrs, Ints);
}

/// Computes size, alignment and padding of an MS ABI member pointer.
///
/// The nominal struct is laid out with pointers followed by ints and aligned
/// to a pointer width if any are present and an int width otherwise.
CXXABI::MemberPointerInfo MicrosoftCXXABI::getMemberPointerInfo(
    const MemberPointerType *MPT) const {
  const TargetInfo &TI = Context.getTargetInfo();
  const llvm::Triple &Triple = TI.getTriple();
  const unsigned PtrBits = TI.getPointerWidth(0);
  const unsigned IntBits = TI.getIntWidth();

  const std::pair<unsigned, unsigned> Slots = getMSMemberPointerSlots(MPT);
  const unsigned NumPtrs = Slots.first;
  const unsigned NumInts = Slots.second;
  const unsigned PackedWidth = NumPtrs * PtrBits + NumInts * IntBits;

  MemberPointerInfo MPI;
  MPI.Width = PackedWidth;
  MPI.HasPadding = false;

  // When MSVC does x86_32 record layout, it aligns aggregate member pointers to
  // 8 bytes.  However, __alignof usually returns 4 for data memptrs and 8 for
  // function memptrs.
  if (NumPtrs + NumInts > 1 && Triple.isArch32Bit())
    MPI.Align = 64;
  else
    MPI.Align = NumPtrs ? TI.getPointerAlign(0) : TI.getIntAlign();

  // On 64-bit targets the size is rounded up to the alignment; record whether
  // that introduced padding.
  if (Triple.isArch64Bit()) {
    MPI.Width = llvm::alignTo(PackedWidth, MPI.Align);
    MPI.HasPadding = MPI.Width != PackedWidth;
  }
  return MPI;
}

/// Creates the AST-level Microsoft C++ ABI object bound to \p Ctx.
/// The object is heap-allocated with `new`; ownership passes to the caller.
CXXABI *clang::CreateMicrosoftCXXABI(ASTContext &Ctx) {
  auto *ABI = new MicrosoftCXXABI(Ctx);
  return ABI;
}

