//==============================================================================
// Copyright (c) 2015-2018 Advanced Micro Devices, Inc. All rights reserved.
/// \author AMD Developer Tools Team
/// \file
/// \brief  This class interfaces with GPA to retrieve PMC and write the output file
//==============================================================================

#include <iostream>
#include <algorithm>

#include <hsa_ext_amd.h>
#include <amd_hsa_kernel_code.h>

#include <AMDTOSWrappers/Include/osThread.h>

#include "DeviceInfoUtils.h"
#include "GPUPerfAPI-ROCm.h"

#include "HSAGPAProfiler.h"
#include "HSAModule.h"
#include "HSARTModuleLoader.h"
#include "HSAAgentIterateReplacer.h"
#include "AutoGenerated/HSAPMCInterception.h"

#include "FinalizerInfoManager.h"
#include "HSAKernelDemangler.h"

#include "../Common/Logger.h"
#include "../Common/FileUtils.h"
#include "../Common/KernelProfileResultManager.h"

#include "../CLOccupancyAgent/CLOccupancyInfoManager.h"
#include <ProfilerOutputFileDefs.h>

using namespace std;

/// Constructs the profiler in its not-yet-initialized state: all counts are zero, the kernel
/// limit is DEFAULT_MAX_KERNELS, profiling is enabled, the delay/duration features are off,
/// and no timers or GPA command list exist yet.
HSAGPAProfiler::HSAGPAProfiler(void) :
    m_uiCurKernelCount(0),
    m_uiMaxKernelCount(DEFAULT_MAX_KERNELS),
    m_uiOutputLineCount(0),
    m_isProfilingEnabled(true),
    m_isProfilerInitialized(false),   // set to true at the end of Init()
    m_bDelayStartEnabled(false),
    m_bProfilerDurationEnabled(false),
    m_delayInMilliseconds(0ul),
    m_durationInMilliseconds(0ul),
    m_pDelayTimer(nullptr),
    m_pDurationTimer(nullptr),
    m_commandListId(nullptr)          // non-null only between Begin() and End()
{
}

/// Callback invoked when one of the profiler timers finishes.
///
/// When the delay timer finishes, profiling is switched on and, if a profiling duration is
/// enabled, a duration timer is created, given this same function as its finish handler, and
/// started. When the duration timer finishes, profiling is switched off. Any other timer type
/// is ignored.
/// \param timerType the type of the timer that has finished
void HSAGPAProfilerTimerEndResponse(ProfilerTimerType timerType)
{
    if (PROFILEDELAYTIMER == timerType)
    {
        HSAGPAProfiler* pProfiler = HSAGPAProfiler::Instance();
        pProfiler->EnableProfiling(true);

        // handed to IsProfilerDurationEnabled and consumed only when that call returns true
        unsigned long durationValue;

        if (pProfiler->IsProfilerDurationEnabled(durationValue))
        {
            pProfiler->CreateTimer(PROFILEDURATIONTIMER, durationValue);
            pProfiler->SetTimerFinishHandler(PROFILEDURATIONTIMER, HSAGPAProfilerTimerEndResponse);
            pProfiler->StartTimer(PROFILEDURATIONTIMER);
        }
    }
    else if (PROFILEDURATIONTIMER == timerType)
    {
        HSAGPAProfiler::Instance()->EnableProfiling(false);
    }
}

/// Destructor. Stops and releases the delay and duration timers if they were created.
HSAGPAProfiler::~HSAGPAProfiler(void)
{
    // Each timer is stopped before it is released via SAFE_DELETE.
    if (nullptr != m_pDelayTimer)
    {
        m_pDelayTimer->stopTimer();
        SAFE_DELETE(m_pDelayTimer);
    }

    if (nullptr != m_pDurationTimer)
    {
        m_pDurationTimer->stopTimer();
        SAFE_DELETE(m_pDurationTimer);
    }
}

/// Polls until the active session for the given queue has completed and been removed from
/// m_activeSessionMap, or until the timeout expires.
///
/// The session is polled via CheckForCompletedSession(); between unsuccessful polls the thread
/// sleeps for SLEEP_TIME_MILLISECONDS. The timeout is enforced by counting polls rather than by
/// measuring elapsed time.
/// \param queueId key of the session in m_activeSessionMap
/// \param timeoutSeconds upper bound on the wait, expressed in seconds
/// \return true if the session was removed from the map before the poll budget ran out;
///         false if the queue is unknown or the session is still active after the timeout
bool HSAGPAProfiler::WaitForCompletedSession(uint64_t queueId, uint32_t timeoutSeconds)
{
    if (m_activeSessionMap.find(queueId) == m_activeSessionMap.end())
    {
        Log(logERROR, "Unknown queue specified\n");
        return false;
    }

    // to avoid an infinite loop, bail after spinning for the specified number of seconds
    static const uint32_t SLEEP_TIME_MILLISECONDS = 1;

    // Widen before multiplying: a 32-bit product would wrap for very large timeouts.
    const uint64_t waitCount = (static_cast<uint64_t>(timeoutSeconds) * 1000) / SLEEP_TIME_MILLISECONDS;

    bool queueFound = true;

    for (uint64_t safetyNet = 0; queueFound && (safetyNet < waitCount); ++safetyNet)
    {
        if (!CheckForCompletedSession(queueId))
        {
            OSUtils::Instance()->SleepMillisecond(SLEEP_TIME_MILLISECONDS);
        }

        // CheckForCompletedSession erases the entry once the session result has been written
        queueFound = m_activeSessionMap.find(queueId) != m_activeSessionMap.end();
    }

    if (queueFound)
    {
        // previous session never completed
        Log(logERROR, "Session never completed after waiting %u seconds\n", timeoutSeconds);
        return false;
    }

    return true;
}

/// Waits for every active session to complete, one queue at a time.
///
/// Stops at the first session that fails to complete within the timeout, leaving it (and any
/// remaining sessions) in m_activeSessionMap.
/// \param timeoutSeconds per-session timeout forwarded to WaitForCompletedSession()
void HSAGPAProfiler::WaitForCompletedSessions(uint32_t timeoutSeconds)
{
    // Always restart from the first entry: a successful wait removes that entry from the map,
    // so an iterator cannot be carried across iterations.
    while (!m_activeSessionMap.empty())
    {
        const auto firstQueueId = m_activeSessionMap.begin()->first;

        if (!WaitForCompletedSession(firstQueueId, timeoutSeconds))
        {
            break;
        }
    }
}

bool HSAGPAProfiler::CheckForCompletedSession(uint64_t queueId)
{
    bool retVal = false;
    QueueSessionMap::iterator it = m_activeSessionMap.find(queueId);

    if (m_activeSessionMap.end() != it && it->second.m_sessionEnded)
    {
        GPA_Status sessionStatus = m_gpaUtils.StatusCheck(m_gpaUtils.GetGPAFuncTable()->GPA_IsSessionComplete(it->second.m_sessionID));

        if (GPA_STATUS_OK == sessionStatus)
        {
            WriteSessionResult(it->second);
            m_gpaUtils.Close();
            retVal = true;
        }
    }

    if (retVal)
    {
        m_activeSessionMap.erase(queueId);
    }

    return retVal;
}

/// Determines whether the specified agent is a GPU device.
/// \param agent the agent to query
/// \return true if the device-type query succeeds and reports HSA_DEVICE_TYPE_GPU, false otherwise
bool HSAGPAProfiler::IsGPUDevice(hsa_agent_t agent)
{
    hsa_device_type_t agentType;
    const hsa_status_t queryStatus = g_pRealCoreFunctions->hsa_agent_get_info_fn(agent, HSA_AGENT_INFO_DEVICE, &agentType);

    // agentType is only inspected when the query succeeded
    return (HSA_STATUS_SUCCESS == queryStatus) && (HSA_DEVICE_TYPE_GPU == agentType);
}

/// Agent-iteration callback that collects the chip id of each GPU agent.
/// \param agent the agent being visited
/// \param pData pointer to a std::vector<uint32_t> that receives the chip ids
/// \return HSA_STATUS_ERROR_INVALID_ARGUMENT if pData is null, HSA_STATUS_SUCCESS for non-GPU
///         agents, otherwise the status of the chip id query
hsa_status_t HSAGPAProfiler::GetGPUDeviceIDs(hsa_agent_t agent, void* pData)
{
    if (NULL == pData)
    {
        return HSA_STATUS_ERROR_INVALID_ARGUMENT;
    }

    if (!IsGPUDevice(agent))
    {
        // non-GPU agents are skipped without error
        return HSA_STATUS_SUCCESS;
    }

    uint32_t chipId;
    const hsa_status_t queryStatus = g_pRealCoreFunctions->hsa_agent_get_info_fn(agent, static_cast<hsa_agent_info_t>(HSA_AMD_AGENT_INFO_CHIP_ID), &chipId);

    if (HSA_STATUS_SUCCESS == queryStatus)
    {
        std::vector<uint32_t>* pDeviceIds = static_cast<std::vector<uint32_t>*>(pData);
        pDeviceIds->push_back(chipId);
    }

    return queryStatus;
}

bool HSAGPAProfiler::Init(const Parameters& params, std::string& strErrorOut)
{
#if defined (_LINUX) || defined (LINUX)
    // Set kernel output flags.
    m_hsaKernelAssembly.SetOutputIsaFlag(params.m_bOutputISA);
    m_hsaKernelAssembly.SetOutputHsailFlag(params.m_bOutputHSAIL);
#endif

    if (!m_isProfilerInitialized)
    {
        Log(logMESSAGE, "Initializing HSAGPAProfiler\n");
        const size_t nMaxPass = 1;

        if (!params.m_bStartDisabled)
        {
            m_bDelayStartEnabled = params.m_bDelayStartEnabled;
            m_bProfilerDurationEnabled = params.m_bProfilerDurationEnabled;
            m_delayInMilliseconds = params.m_delayInMilliseconds;
            m_durationInMilliseconds = params.m_durationInMilliseconds;
            m_isProfilingEnabled = m_delayInMilliseconds > 0 ? false : true;

            if (m_bDelayStartEnabled)
            {
                CreateTimer(PROFILEDELAYTIMER, m_delayInMilliseconds);

                if (nullptr != m_pDelayTimer)
                {
                    m_pDelayTimer->SetTimerFinishHandler(HSAGPAProfilerTimerEndResponse);
                    m_pDelayTimer->startTimer(true);
                }
            }
            else if (m_bProfilerDurationEnabled)
            {
                CreateTimer(PROFILEDURATIONTIMER, m_durationInMilliseconds);

                if (nullptr != m_pDurationTimer)
                {
                    m_pDurationTimer->SetTimerFinishHandler(HSAGPAProfilerTimerEndResponse);
                    m_pDurationTimer->startTimer(true);
                }
            }
        }
        else
        {
            m_isProfilingEnabled = !params.m_bStartDisabled;
        }

        m_uiMaxKernelCount = params.m_uiMaxKernels;

        if (params.m_strCounterFile.empty())
        {
            cout << "No counter file specified. Only counters that will fit into a single pass will be enabled.\n";
        }

        // Set output file
        SetOutputFile(params.m_strOutputFile);

        // Set list separator
        KernelProfileResultManager::Instance()->SetListSeparator(params.m_cOutputSeparator);

        // Init CSV file header and column row
        InitHeader();

        CounterList enabledCounters;
        m_gpaUtils.InitGPA(GPA_API_ROCM,
                           params.m_strDLLPath,
                           strErrorOut,
                           params.m_strCounterFile.empty() ? NULL : params.m_strCounterFile.c_str(),
                           &enabledCounters,
                           nMaxPass); // only allow a single pass

        SP_TODO("support enforcing single pass when there are multiple HSA devices");

        // Enable all counters if no counter file is specified or counter file is empty.
        if (enabledCounters.empty())
        {
#ifdef AMDT_INTERNAL
            // Internal mode must have a counter file specified.
            cout << "Please specify a counter file using -c. No counter is enabled." << endl;
            return false;
#else
            vector<uint32_t> gpuDeviceIds;
            hsa_status_t status = g_pRealCoreFunctions->hsa_iterate_agents_fn(HSAAgentIterateReplacer::Instance()->GetAgentIterator(GetGPUDeviceIDs, g_pRealCoreFunctions), &gpuDeviceIds);

            if (HSA_STATUS_SUCCESS == status)
            {
                set<string> counterSet;
                CounterList tempCounters;

                for (auto deviceId : gpuDeviceIds)
                {
                    CounterList counterNames;
                    // TODO: need to get revision id from HSA runtime (SWDEV-79571)
                    m_gpaUtils.GetAvailableCountersForDevice(deviceId, REVISION_ID_ANY, nMaxPass, counterNames);
                    tempCounters.insert(tempCounters.end(), counterNames.begin(), counterNames.end());
                }

                // remove duplicated counter
                for (CounterList::iterator it = tempCounters.begin(); it != tempCounters.end(); ++it)
                {
                    if (counterSet.find(*it) == counterSet.end())
                    {
                        counterSet.insert(*it);
                        enabledCounters.push_back(*it);
                    }
                }

                m_gpaUtils.SetEnabledCounters(enabledCounters);
            }

#endif // AMDT_INTERNAL
        }

        for (CounterList::iterator it = enabledCounters.begin(); it != enabledCounters.end(); ++it)
        {
            KernelProfileResultManager::Instance()->AddProfileResultItem(*it);
        }

        m_isProfilerInitialized = true;
    }

    return true;
}

void HSAGPAProfiler::InitHeader()
{
    KernelProfileResultManager::Instance()->AddHeader(StringUtils::FormatString("%s=%d.%d", FILE_HEADER_PROFILE_FILE_VERSION, RCP_MAJOR_VERSION, RCP_MINOR_VERSION));
    KernelProfileResultManager::Instance()->AddHeader(StringUtils::FormatString("%s=%d.%d.%d", FILE_HEADER_PROFILER_VERSION, RCP_MAJOR_VERSION, RCP_MINOR_VERSION, RCP_BUILD_NUMBER));
    KernelProfileResultManager::Instance()->AddHeader(StringUtils::FormatString("%s=HSA", FILE_HEADER_API));
    KernelProfileResultManager::Instance()->AddHeader(StringUtils::FormatString("%s=%s", FILE_HEADER_APPLICATION, FileUtils::GetExeFullPath().c_str()));
    KernelProfileResultManager::Instance()->AddHeader(StringUtils::FormatString("%s=%s", FILE_HEADER_APPLICATION_ARGS, GlobalSettings::GetInstance()->m_params.m_strCmdArgs.asUTF8CharArray()));
    KernelProfileResultManager::Instance()->AddHeader(StringUtils::FormatString("%s=%s", FILE_HEADER_WORKING_DIRECTORY, GlobalSettings::GetInstance()->m_params.m_strWorkingDir.asUTF8CharArray()));

    EnvVarMap envVarMap = GlobalSettings::GetInstance()->m_params.m_mapEnvVars;

    if (envVarMap.size() > 0)
    {
        KernelProfileResultManager::Instance()->AddHeader(StringUtils::FormatString("%s=%d", FILE_HEADER_FULL_ENVIRONMENT, GlobalSettings::GetInstance()->m_params.m_bFullEnvBlock));

        for (EnvVarMap::const_iterator it = envVarMap.begin(); it != envVarMap.end(); ++it)
        {
            KernelProfileResultManager::Instance()->AddHeader(StringUtils::FormatString("%s=%s=%s", FILE_HEADER_ENV_VAR, it->first.asUTF8CharArray(), it->second.asUTF8CharArray()));
        }
    }

    // TODO: determine what device info to include in file header for HSA
    //for (CLPlatformSet::iterator idxPlatform = m_platformList.begin(); idxPlatform != m_platformList.end(); idxPlatform++)
    //{
    //   KernelProfileResultManager::Instance()->AddHeader(StringUtils::FormatString("Device %s Platform Vendor=%s", idxPlatform->strDeviceName.c_str(), idxPlatform->strPlatformVendor.c_str()));
    //   KernelProfileResultManager::Instance()->AddHeader(StringUtils::FormatString("Device %s Platform Name=%s", idxPlatform->strDeviceName.c_str(), idxPlatform->strPlatformName.c_str()));
    //   KernelProfileResultManager::Instance()->AddHeader(StringUtils::FormatString("Device %s Platform Version=%s", idxPlatform->strDeviceName.c_str(), idxPlatform->strPlatformVersion.c_str()));
    //   KernelProfileResultManager::Instance()->AddHeader(StringUtils::FormatString("Device %s CLDriver Version=%s", idxPlatform->strDeviceName.c_str(), idxPlatform->strDriverVersion.c_str()));
    //   KernelProfileResultManager::Instance()->AddHeader(StringUtils::FormatString("Device %s CLRuntime Version=%s", idxPlatform->strDeviceName.c_str(), idxPlatform->strCLRuntime.c_str()));
    //   KernelProfileResultManager::Instance()->AddHeader(StringUtils::FormatString("Device %s NumberAppAddressBits=%d", idxPlatform->strDeviceName.c_str(), idxPlatform->uiNbrAddressBits));
    //}

    std::string strOSVersion = OSUtils::Instance()->GetOSInfo();
    KernelProfileResultManager::Instance()->AddHeader(StringUtils::FormatString("%s=%s", FILE_HEADER_OS_VERSION, OSUtils::Instance()->GetOSInfo().c_str()));
    KernelProfileResultManager::Instance()->AddHeader(StringUtils::FormatString("%s=%s", FILE_HEADER_DISPLAY_NAME, GlobalSettings::GetInstance()->m_params.m_strSessionName.c_str()));
    KernelProfileResultManager::Instance()->AddHeader(StringUtils::FormatString("%s=%c", FILE_HEADER_LIST_SEPARATOR, GlobalSettings::GetInstance()->m_params.m_cOutputSeparator));

    if (m_uiMaxKernelCount != DEFAULT_MAX_KERNELS)
    {
        KernelProfileResultManager::Instance()->AddHeader(StringUtils::FormatString("%s=%d", FILE_HEADER_MAX_NUMBER_OF_KERNELS_TO_PROFILE, m_uiMaxKernelCount));
    }

    // TODO: add back in columns when kernel stats are added
    KernelProfileResultManager::Instance()->AddProfileResultItem(CSV_COMMON_COLUMN_METHOD);
    KernelProfileResultManager::Instance()->AddProfileResultItem(CSV_COMMON_COLUMN_EXECUTION_ORDER);
    KernelProfileResultManager::Instance()->AddProfileResultItem(CSV_COMMON_COLUMN_THREAD_ID);
    KernelProfileResultManager::Instance()->AddProfileResultItem(CSV_COMMON_COLUMN_GLOBAL_WORK_SIZE);
    KernelProfileResultManager::Instance()->AddProfileResultItem(CSV_COMMON_COLUMN_WORK_GROUP_SIZE);
    //KernelProfileResultManager::Instance()->AddProfileResultItem(CSV_COMMON_COLUMN_TIME);
    KernelProfileResultManager::Instance()->AddProfileResultItem(CSV_COMMON_COLUMN_LOCAL_MEM_SIZE);
    KernelProfileResultManager::Instance()->AddProfileResultItem(CSV_COMMON_COLUMN_VGPRs);
    KernelProfileResultManager::Instance()->AddProfileResultItem(CSV_COMMON_COLUMN_SGPRs);
    //KernelProfileResultManager::Instance()->AddProfileResultItem(CSV_COMMON_COLUMN_SCRATCH_REGS);
}

/// Derives the performance counter output file name and passes it to the
/// KernelProfileResultManager.
///
/// - empty input: default output path + executable name + ".hsa." + PERF_COUNTER_EXT
/// - input already ending in PERF_COUNTER_EXT: used unchanged
/// - input ending in TRACE_EXT or OCCUPANCY_EXT: base file name + ".hsa." + PERF_COUNTER_EXT
/// - anything else: input + ".hsa." + PERF_COUNTER_EXT
/// \param strOutputFile the requested output file; may be empty
void HSAGPAProfiler::SetOutputFile(const std::string& strOutputFile)
{
    if (strOutputFile.empty())
    {
        // If output file is not set, we use exe name as file name
        m_strOutputFile = FileUtils::GetDefaultOutputPath() + FileUtils::GetExeName() + ".hsa." + PERF_COUNTER_EXT;
    }
    else
    {
        const std::string fileExt = FileUtils::GetFileExtension(strOutputFile);

        if (fileExt == PERF_COUNTER_EXT)
        {
            m_strOutputFile = strOutputFile;
        }
        else if ((fileExt == TRACE_EXT) || (fileExt == OCCUPANCY_EXT))
        {
            // swap the extension of another profile type for the perf counter one
            m_strOutputFile = FileUtils::GetBaseFileName(strOutputFile) + ".hsa." + PERF_COUNTER_EXT;
        }
        else
        {
            m_strOutputFile = strOutputFile + ".hsa." + PERF_COUNTER_EXT;
        }
    }

    KernelProfileResultManager::Instance()->SetOutputFile(m_strOutputFile);
}

/// Fills in a KernelStats structure from an AQL kernel dispatch packet.
///
/// The kernel name is resolved through the FinalizerInfoManager maps (code handle -> symbol
/// handle -> name), demangled, and suffixed with "_" + strAgentName. Register and LDS usage are
/// read from the amd_kernel_code_t located at the packet's kernel_object (or at the host address
/// reported by the loader, when that query succeeds). Work sizes come from the packet itself.
/// On Linux builds the loaded code object is also handed to m_hsaKernelAssembly.Generate.
/// \param pAqlPacket the dispatch packet; must be non-null with a non-zero kernel_object
/// \param strAgentName name of the agent, appended to the kernel name
/// \param[out] kernelStats receives the kernel name, register/LDS usage, work sizes and thread id
/// \param agent the agent the kernel is dispatched to (only used on Linux builds)
/// \return false if the packet is invalid, the HSA RT module cannot be loaded, or the kernel
///         code pointer is null; true otherwise
bool HSAGPAProfiler::PopulateKernelStatsFromDispatchPacket(const hsa_kernel_dispatch_packet_t* pAqlPacket,
                                                           const std::string& strAgentName,
                                                           KernelStats& kernelStats,
                                                           hsa_agent_t agent)
{
    bool retVal = true;

    if (nullptr == pAqlPacket || 0 == pAqlPacket->kernel_object)
    {
        Log(logERROR, "Unable to get Kernel Stats from dispatch packet.\n");
        retVal = false;
    }
    else
    {
        FinalizerInfoManager* pFinalizerInfoMan = FinalizerInfoManager::Instance();

#ifdef _DEBUG
        // Debug builds dump the whole code-handle map to the log to help diagnose failed lookups.
        Log(logMESSAGE, "Lookup %llu\n", pAqlPacket->kernel_object);

        Log(logMESSAGE, "Dump m_codeHandleToSymbolHandleMap\n");

        for (auto mapItem : pFinalizerInfoMan->m_codeHandleToSymbolHandleMap)
        {
            Log(logMESSAGE, "  Item: %llu == %llu\n", mapItem.first, mapItem.second);

            if (pAqlPacket->kernel_object == mapItem.first)
            {
                Log(logMESSAGE, "  Match found!\n");
            }
        }

        Log(logMESSAGE, "End Dump m_codeHandleToSymbolHandleMap\n");
#endif

        // Resolve the kernel name: code handle -> symbol handle -> symbol name.
        // symName stays empty if either lookup misses.
        std::string symName;

        if (pFinalizerInfoMan->m_codeHandleToSymbolHandleMap.count(pAqlPacket->kernel_object) > 0)
        {
            uint64_t symHandle = pFinalizerInfoMan->m_codeHandleToSymbolHandleMap[pAqlPacket->kernel_object];

            if (pFinalizerInfoMan->m_symbolHandleToNameMap.count(symHandle) > 0)
            {
                symName = pFinalizerInfoMan->m_symbolHandleToNameMap[symHandle];
                Log(logMESSAGE, "Lookup: CodeHandle: %llu, SymHandle: %llu, symName: %s\n", pAqlPacket->kernel_object, symHandle, symName.c_str());
            }
        }

        if (symName.empty())
        {
            symName = "<UnknownKernelName>";
        }
        else
        {
            symName = DemangleKernelName(symName);
        }

        kernelStats.m_strName = symName + "_" + strAgentName;

        // Start by treating kernel_object as the address of the kernel code header; this is
        // replaced below by the host address if the loader extension can provide one.
        const amd_kernel_code_t* pKernelCode = reinterpret_cast<const amd_kernel_code_t*>(pAqlPacket->kernel_object);

        HSAModule* pHsaModule = HSARTModuleLoader<HSAModule>::Instance()->GetHSARTModule();

        if (nullptr == pHsaModule)
        {
            // NOTE(review): retVal becomes false here, but pKernelCode is still non-null, so the
            // stats below are still read from the raw kernel_object address -- confirm this is intended.
            Log(logERROR, "Unable to load HSA RT Module\n");
            retVal = false;
        }
        else
        {
            const void* pKernelHostAddress = nullptr;

            if (nullptr != pHsaModule->ven_amd_loader_query_host_address)
            {
                hsa_status_t status = pHsaModule->ven_amd_loader_query_host_address(reinterpret_cast<const void*>(pAqlPacket->kernel_object), &pKernelHostAddress);

                if (HSA_STATUS_SUCCESS == status)
                {
                    pKernelCode = reinterpret_cast<const amd_kernel_code_t*>(pKernelHostAddress);
                }
            }

#if defined (_LINUX) || defined (LINUX)
            // Extract the disassembly
            // Lookup chain: kernel object -> executable handle -> (executable, agent) -> loaded code object.
            // A miss in either map silently skips the disassembly step.
            if (nullptr != pHsaModule->ven_amd_loader_loaded_code_object_get_info)
            {
                if (pFinalizerInfoMan->m_kernelObjHandleToExeHandleMap.count(pAqlPacket->kernel_object) > 0)
                {
                    uint64_t exeHandle = pFinalizerInfoMan->m_kernelObjHandleToExeHandleMap[pAqlPacket->kernel_object];
                    auto exeAndAgentHandle = std::make_pair(exeHandle, agent.handle);

                    if (pFinalizerInfoMan->m_exeAndAgentHandleToLoadedCodeObjHandleMap.count(exeAndAgentHandle) > 0)
                    {
                        hsa_loaded_code_object_t loadedCodeObject;

                        loadedCodeObject.handle = pFinalizerInfoMan->m_exeAndAgentHandleToLoadedCodeObjHandleMap[exeAndAgentHandle];

                        uint64_t loadedCodeObjectSize = 0;
                        uint64_t loadedCodeObjectAddress = 0;

                        if (HSA_STATUS_SUCCESS == pHsaModule->ven_amd_loader_loaded_code_object_get_info(loadedCodeObject, HSA_VEN_AMD_LOADER_LOADED_CODE_OBJECT_INFO_CODE_OBJECT_STORAGE_MEMORY_SIZE, &loadedCodeObjectSize))
                        {
                            if (HSA_STATUS_SUCCESS == pHsaModule->ven_amd_loader_loaded_code_object_get_info(loadedCodeObject, HSA_VEN_AMD_LOADER_LOADED_CODE_OBJECT_INFO_CODE_OBJECT_STORAGE_MEMORY_BASE, &loadedCodeObjectAddress))
                            {
                                // Copy the code object bytes out of the reported storage range.
                                char* buffer = reinterpret_cast<char*>(loadedCodeObjectAddress);
                                std::vector<char> codeObjectBinary(buffer, buffer + loadedCodeObjectSize);

                                char deviceNameBuffer[64] = {0};
                                uint32_t deviceId = 0;

                                if (HSA_STATUS_SUCCESS == pHsaModule->agent_get_info(agent, HSA_AGENT_INFO_NAME, deviceNameBuffer))
                                {
                                    // deviceId is queried but only the success of the query is used below.
                                    if(HSA_STATUS_SUCCESS == pHsaModule->agent_get_info(agent, static_cast<hsa_agent_info_t>(HSA_AMD_AGENT_INFO_CHIP_ID), &deviceId))
                                    {
                                        std::string deviceName(deviceNameBuffer);
                                        std::string outputDir("");
                                        bool isGPU = false;
                                        uint32_t deviceType = 0;

                                        // isGPU stays false if the device type query fails
                                        if (HSA_STATUS_SUCCESS == pHsaModule->agent_get_info(agent, HSA_AGENT_INFO_DEVICE, &deviceType))
                                        {
                                            isGPU = (HSA_DEVICE_TYPE_GPU == deviceType);
                                        }

                                        if (FileUtils::GetWorkingDirectory(m_strOutputFile, outputDir))
                                        {
                                            m_hsaKernelAssembly.Generate(codeObjectBinary, deviceName, symName, strAgentName, outputDir, isGPU);
                                        }
                                    }
                                    else
                                    {
                                        Log(logERROR, "Unable to get device ID.\n");
                                    }
                                }
                                else
                                {
                                    Log(logERROR, "Unable to get device name.\n");
                                }
                            }
                            else
                            {
                                Log(logERROR, "Unable to extract code object address info.\n");
                            }
                        }
                        else
                        {
                            Log(logERROR, "Unable to extract code object size info.\n");
                        }
                    }
                }
            }
#else
            SP_UNREFERENCED_PARAMETER(agent);
#endif
        }

        // pKernelCode can only be null here if the host-address query succeeded but returned null.
        if (nullptr == pKernelCode)
        {
            Log(logERROR, "Unable to get Kernel Stats from dispatch packet: kernel code object is null.\n");
            retVal = false;
        }
        else
        {
            kernelStats.m_kernelInfo.m_nUsedGPRs = pKernelCode->workitem_vgpr_count;
            kernelStats.m_kernelInfo.m_nUsedScalarGPRs = pKernelCode->wavefront_sgpr_count;
            kernelStats.m_kernelInfo.m_nUsedLDSSize = pKernelCode->workgroup_group_segment_byte_size;

            kernelStats.m_kernelInfo.m_nAvailableGPRs = 256; // TODO: get value from runtime when available
            kernelStats.m_kernelInfo.m_nAvailableScalarGPRs = 104; // TODO: get value from runtime when available
            kernelStats.m_kernelInfo.m_nAvailableLDSSize = 65536; // TODO: get value from runtime when available

            // extract the number of dimensions from the setup field in the packet
            kernelStats.m_uWorkDim = pAqlPacket->setup & ((1 << HSA_KERNEL_DISPATCH_PACKET_SETUP_WIDTH_DIMENSIONS) - 1);

            kernelStats.m_workGroupSize[0] = pAqlPacket->workgroup_size_x;
            kernelStats.m_workGroupSize[1] = pAqlPacket->workgroup_size_y;
            kernelStats.m_workGroupSize[2] = pAqlPacket->workgroup_size_z;

            kernelStats.m_globalWorkSize[0] = pAqlPacket->grid_size_x;
            kernelStats.m_globalWorkSize[1] = pAqlPacket->grid_size_y;
            kernelStats.m_globalWorkSize[2] = pAqlPacket->grid_size_z;

            // per the HSA RT team, the thread we get the pre-dispatch callback from is
            // the same thread that dispatched the kernel.
            kernelStats.m_threadId = osGetUniqueCurrentThreadId();
        }
    }

    return retVal;
}

/// Starts a GPA profiling session for the kernel dispatch described by the callback data.
///
/// Kernel statistics are gathered for every dispatch. For GPU agents the function then waits
/// for previously started sessions to finish, locks m_mtx, opens a GPA context, creates a
/// session with the enabled counters, and begins the session, a command list and a sample.
/// On success the session is recorded in m_activeSessionMap under the queue's id.
/// m_mtx is locked here and unlocked in End().
/// \param pRocProfilerData callback data supplying the agent, queue and dispatch packet
/// \return true if the session was started (or the agent is not a GPU); false on failure
bool HSAGPAProfiler::Begin(const rocprofiler_callback_data_t* pRocProfilerData)
{
    // SpAssertRet is defined elsewhere; presumably it asserts and, on failure, returns the
    // value written after it -- confirm against the macro definition.
    SpAssertRet(NULL != pRocProfilerData) false;

    const hsa_agent_t agent = pRocProfilerData->agent;
    const hsa_queue_t* pQueue = pRocProfilerData->queue;
    const hsa_kernel_dispatch_packet_t* pAqlPacket = pRocProfilerData->packet;

    SpAssertRet(NULL != pQueue) false;
    SpAssertRet(NULL != pAqlPacket) false;

    bool retVal = true;

    char agentName[64] = {0};
    hsa_status_t status = g_pRealCoreFunctions->hsa_agent_get_info_fn(agent, HSA_AGENT_INFO_NAME, agentName);

    // placeholder name used when the agent name query fails
    std::string strAgentName = "<UnknownDeviceName>";

    if (HSA_STATUS_SUCCESS == status)
    {
        strAgentName = std::string(agentName);
    }

    // The return value is not checked; kernelStats is used below regardless of the outcome.
    KernelStats kernelStats;
    PopulateKernelStatsFromDispatchPacket(pAqlPacket, strAgentName, kernelStats, agent);

    // Non-GPU agents skip all GPA work and fall through to return true.
    if (IsGPUDevice(agent))
    {
        if (!m_gpaUtils.IsInitialized())
        {
            return false;
        }

        ++m_uiCurKernelCount;

        // make sure any previous sessions are completed before starting a new one.
        WaitForCompletedSessions();
        SpAssertRet(m_activeSessionMap.empty()) false; // there should be no active sessions at this point

        // NOTE(review): the only unlock of m_mtx visible in this file is in End(). The
        // SpAssertRet statements below appear to return with the mutex still held -- confirm
        // how callers handle a false return from Begin().
        m_mtx.lock();
        SpAssertRet(pQueue != NULL) false;

        //TODO: If we ever want to support opening more than one context at a
        //      time, we will need to create GPA_ROCm_Context instances on the
        //      heap and manage their lifetime. Using a local can cause
        //      GPA_OpenContext to think that we are opening the same context
        //      more than once, and it will throw an error
        GPA_ROCm_Context gpaContext;
        gpaContext.m_pAgent = &agent;
        gpaContext.m_pQueue = pQueue;
        bool bRet = m_gpaUtils.Open(&gpaContext);
        SpAssertRet(bRet) false;
        GPA_SessionId currentSessionId = nullptr;
        bRet = m_gpaUtils.CreateSession(currentSessionId);
        SpAssertRet(bRet) false;
        bRet = m_gpaUtils.EnableCounters(currentSessionId);
        SpAssertRet(bRet) false;

        // Status codes are accumulated by addition and compared against GPA_STATUS_OK at the end.
        int stat = m_gpaUtils.StatusCheck(m_gpaUtils.GetGPAFuncTable()->GPA_BeginSession(currentSessionId));
        gpa_uint32 gpaPassCount = 0;
        stat += m_gpaUtils.StatusCheck(m_gpaUtils.GetGPAFuncTable()->GPA_GetPassCount(currentSessionId, &gpaPassCount));

        // only single-pass profiling is supported, and no command list may be outstanding
        SpAssertRet(1 == gpaPassCount) false;
        SpAssertRet(nullptr == m_commandListId) false;

        stat += m_gpaUtils.StatusCheck(m_gpaUtils.GetGPAFuncTable()->GPA_BeginCommandList(currentSessionId, 0, GPA_NULL_COMMAND_LIST, GPA_COMMAND_LIST_NONE, &m_commandListId));
        stat += m_gpaUtils.StatusCheck(m_gpaUtils.GetGPAFuncTable()->GPA_BeginSample(0, m_commandListId));

        retVal = stat == static_cast<int>(GPA_STATUS_OK);

        if (retVal)
        {
            // the final 'false' is the session-ended flag, which End() sets to true
            SessionInfo sessionInfo = { currentSessionId, kernelStats, strAgentName, false };
            m_activeSessionMap[pQueue->id] = sessionInfo;

            if (GlobalSettings::GetInstance()->m_params.m_bKernelOccupancy)
            {
                // failure to record occupancy is logged but does not fail Begin()
                if (!AddOccupancyEntry(kernelStats, strAgentName, agent))
                {
                    Log(logERROR, "Unable to add Occupancy data\n");
                }
            }
        }
    }

    return retVal;
}

bool HSAGPAProfiler::End()
{
    bool retVal = true;

    if (!m_gpaUtils.IsInitialized())
    {
        retVal = false;
    }

    if (retVal)
    {
        int stat = m_gpaUtils.StatusCheck(m_gpaUtils.GetGPAFuncTable()->GPA_EndSample(m_commandListId));
        stat += m_gpaUtils.StatusCheck(m_gpaUtils.GetGPAFuncTable()->GPA_EndCommandList(m_commandListId));

        //TODO: confirm that there will only be a single active session
        //      or reinstate the code that looks up a session by the queue here
        assert(1 == m_activeSessionMap.size());
        QueueSessionMap::iterator it = m_activeSessionMap.begin();

        if (m_activeSessionMap.end() != it)
        {
            stat += m_gpaUtils.StatusCheck(m_gpaUtils.GetGPAFuncTable()->GPA_EndSession(it->second.m_sessionID));

            if (GPA_STATUS_OK == stat)
            {
                it->second.m_sessionEnded = true;
            }

            m_commandListId = nullptr;
        }
        else
        {
            stat = GPA_STATUS_ERROR_SESSION_NOT_FOUND;
        }

        m_mtx.unlock();
        retVal = (stat == static_cast<int>(GPA_STATUS_OK));
    }

    return retVal;
}

bool HSAGPAProfiler::WriteSessionResult(const SessionInfo& sessionInfo)
{
    if (!m_gpaUtils.IsInitialized())
    {
        return false;
    }

    ++m_uiOutputLineCount;

    KernelProfileResultManager::Instance()->BeginKernelInfo();

    KernelProfileResultManager::Instance()->WriteKernelInfo(CSV_COMMON_COLUMN_METHOD, sessionInfo.m_kernelStats.m_strName);
    KernelProfileResultManager::Instance()->WriteKernelInfo(CSV_COMMON_COLUMN_EXECUTION_ORDER, m_uiOutputLineCount);
    KernelProfileResultManager::Instance()->WriteKernelInfo(CSV_COMMON_COLUMN_THREAD_ID, sessionInfo.m_kernelStats.m_threadId);

    KernelProfileResultManager::Instance()->WriteKernelInfo(CSV_COMMON_COLUMN_GLOBAL_WORK_SIZE,
                                                            StringUtils::FormatString("{%7lu %7lu %7lu}", sessionInfo.m_kernelStats.m_globalWorkSize[0], sessionInfo.m_kernelStats.m_globalWorkSize[1], sessionInfo.m_kernelStats.m_globalWorkSize[2]));
    KernelProfileResultManager::Instance()->WriteKernelInfo(CSV_COMMON_COLUMN_WORK_GROUP_SIZE,
                                                            StringUtils::FormatString("{%5lu %5lu %5lu}", sessionInfo.m_kernelStats.m_workGroupSize[0], sessionInfo.m_kernelStats.m_workGroupSize[1], sessionInfo.m_kernelStats.m_workGroupSize[2]));

    KernelProfileResultManager::Instance()->WriteKernelInfo(CSV_COMMON_COLUMN_LOCAL_MEM_SIZE, sessionInfo.m_kernelStats.m_kernelInfo.m_nUsedLDSSize == KERNELINFO_NONE ? "NA" : StringUtils::ToString(sessionInfo.m_kernelStats.m_kernelInfo.m_nUsedLDSSize));
    KernelProfileResultManager::Instance()->WriteKernelInfo(CSV_COMMON_COLUMN_VGPRs, sessionInfo.m_kernelStats.m_kernelInfo.m_nUsedGPRs == KERNELINFO_NONE ? "NA" : StringUtils::ToString(sessionInfo.m_kernelStats.m_kernelInfo.m_nUsedGPRs));
    KernelProfileResultManager::Instance()->WriteKernelInfo(CSV_COMMON_COLUMN_SGPRs, sessionInfo.m_kernelStats.m_kernelInfo.m_nUsedScalarGPRs == KERNELINFO_NONE ? "NA" : StringUtils::ToString(sessionInfo.m_kernelStats.m_kernelInfo.m_nUsedScalarGPRs));

    gpa_uint32 sampleCount;
    m_gpaUtils.StatusCheck(m_gpaUtils.GetGPAFuncTable()->GPA_GetSampleCount(sessionInfo.m_sessionID, &sampleCount));

    if (0 == sampleCount)
    {
        SpAssertRet(!"No samples found") false;
    }

    size_t sampleResultSizeInBytes = 0;
    m_gpaUtils.StatusCheck(m_gpaUtils.GetGPAFuncTable()->GPA_GetSampleResultSize(sessionInfo.m_sessionID, 0, &sampleResultSizeInBytes));

    gpa_uint64* pResultsBuffer = reinterpret_cast<gpa_uint64*>(malloc(sampleResultSizeInBytes));
    m_gpaUtils.StatusCheck(m_gpaUtils.GetGPAFuncTable()->GPA_GetSampleResult(sessionInfo.m_sessionID, 0, sampleResultSizeInBytes, pResultsBuffer));

    gpa_uint32 nEnabledCounters = 0;
    m_gpaUtils.GetGPAFuncTable()->GPA_GetNumEnabledCounters(sessionInfo.m_sessionID, &nEnabledCounters);

    for (gpa_uint32 sample = 0; sample < sampleCount; sample++)
    {
        for (gpa_uint32 counter = 0; counter < nEnabledCounters; counter++)
        {
            gpa_uint32 enabledCounterIndex;

            if (GPA_STATUS_OK != m_gpaUtils.StatusCheck(m_gpaUtils.GetGPAFuncTable()->GPA_GetEnabledIndex(sessionInfo.m_sessionID, counter, &enabledCounterIndex)))
            {
                SpBreak("Failed to retrieve counter index.");
                continue;
            }

            GPA_Data_Type dataType;

            if (!m_gpaUtils.GetCounterDataType(enabledCounterIndex, dataType))
            {
                SpBreak("Failed to retrieve counter data type.");
                continue;
            }

            string strName;

            if (!m_gpaUtils.GetCounterName(enabledCounterIndex, strName))
            {
                SpBreak("Failed to retrieve counter name.");
                continue;
            }

            if (GPA_DATA_TYPE_UINT64 == dataType)
            {
#ifdef _WIN32
                KernelProfileResultManager::Instance()->WriteKernelInfo(strName, StringUtils::FormatString("%8I64u", pResultsBuffer[counter]));
#else
                KernelProfileResultManager::Instance()->WriteKernelInfo(strName, StringUtils::FormatString("%lu", pResultsBuffer[counter]));
#endif
            }
            else if (GPA_DATA_TYPE_FLOAT64 == dataType)
            {
                KernelProfileResultManager::Instance()->WriteKernelInfo(strName, StringUtils::FormatString("%12.2f", reinterpret_cast<gpa_float64*>(pResultsBuffer)[counter]));
            }
            else
            {
                SpAssertRet(!"Unrecognized data type") false;
            }
        }
    }

    free(pResultsBuffer);
    KernelProfileResultManager::Instance()->EndKernelInfo();
    m_gpaUtils.StatusCheck(m_gpaUtils.GetGPAFuncTable()->GPA_DeleteSession(sessionInfo.m_sessionID));

    return true;
}

/// Creates an occupancy entry for a kernel dispatch, fills it with the kernel's work sizes and
/// the properties queried from the agent, computes the CU occupancy and hands the entry over
/// to the OccupancyInfoManager.
/// \param kernelStats the statistics of the dispatched kernel
/// \param deviceName the name of the device the kernel was dispatched on
/// \param agent the HSA agent that is queried for the device properties
/// \return true on success; false if the entry could not be allocated, the hardware generation
///         is unknown or unsupported, the CU info object could not be allocated, or any of the
///         mandatory agent queries failed (in the last case the entry is still added)
bool HSAGPAProfiler::AddOccupancyEntry(const KernelStats& kernelStats, const std::string& deviceName, hsa_agent_t agent)
{
    bool retVal = true;
    OccupancyInfoEntry* pEntry = new(std::nothrow) OccupancyInfoEntry();
    SpAssertRet(pEntry != nullptr) false;

    pEntry->m_tid = kernelStats.m_threadId;
    pEntry->m_strKernelName = kernelStats.m_strName;
    pEntry->m_strDeviceName = deviceName;

    // the number of work items in a work group is the product of the sizes of the used dimensions
    pEntry->m_nWorkGroupItemCount = kernelStats.m_workGroupSize[0];

    for (unsigned int i = 1; i < kernelStats.m_uWorkDim; i++)
    {
        pEntry->m_nWorkGroupItemCount *= kernelStats.m_workGroupSize[i];
    }

    uint32_t maxWorkGroupSize;
    hsa_status_t status = g_pRealCoreFunctions->hsa_agent_get_info_fn(agent, HSA_AGENT_INFO_WORKGROUP_MAX_SIZE, &maxWorkGroupSize);

    if (HSA_STATUS_SUCCESS == status)
    {
        pEntry->m_nWorkGroupItemCountMax = maxWorkGroupSize;
    }
    else
    {
        Log(logERROR, "Unable to get Max Workgroup Size from hsa_agent_get_info\n");
        retVal = false;
    }

    // the global item count is the product of the global sizes of the used dimensions
    pEntry->m_nGlobalItemCount = kernelStats.m_globalWorkSize[0];

    for (unsigned int i = 1; i < kernelStats.m_uWorkDim; i++)
    {
        pEntry->m_nGlobalItemCount *= kernelStats.m_globalWorkSize[i];
    }

    uint32_t maxGridSize;
    status = g_pRealCoreFunctions->hsa_agent_get_info_fn(agent, HSA_AGENT_INFO_GRID_MAX_SIZE, &maxGridSize);

    if (HSA_STATUS_SUCCESS == status)
    {
        pEntry->m_nGlobalItemCountMax = maxGridSize;
    }
    else
    {
        Log(logERROR, "Unable to get Max Grid Size from hsa_agent_get_info\n");
        retVal = false;
    }

    uint32_t numComputeUnit;
    status = g_pRealCoreFunctions->hsa_agent_get_info_fn(agent, static_cast<hsa_agent_info_t>(HSA_AMD_AGENT_INFO_COMPUTE_UNIT_COUNT), &numComputeUnit);

    if (HSA_STATUS_SUCCESS == status)
    {
        pEntry->m_nNumberOfComputeUnits = numComputeUnit;
    }
    else
    {
        Log(logERROR, "Unable to get Compute Unit Count from hsa_agent_get_info\n");
        retVal = false;
    }

    uint32_t deviceId;
    status = g_pRealCoreFunctions->hsa_agent_get_info_fn(agent, static_cast<hsa_agent_info_t>(HSA_AMD_AGENT_INFO_CHIP_ID), &deviceId);

    GDT_HW_GENERATION gen = GDT_HW_GENERATION_NONE;

    // first try to determine the hardware generation from the chip id; if that fails the
    // generation is reset to NONE so that the lookup by device name below is attempted
    if (HSA_STATUS_SUCCESS == status)
    {
        if (AMDTDeviceInfoUtils::Instance()->GetHardwareGeneration(deviceId, gen))
        {
            if (!AMDTDeviceInfoUtils::Instance()->HwGenerationToGfxIPVer(gen, pEntry->m_nDeviceGfxIpVer))
            {
                gen = GDT_HW_GENERATION_NONE;
            }
        }
    }

    uint32_t maxWavesPerCU;
    status = g_pRealCoreFunctions->hsa_agent_get_info_fn(agent, static_cast<hsa_agent_info_t>(HSA_AMD_AGENT_INFO_MAX_WAVES_PER_CU), &maxWavesPerCU);

    if (HSA_STATUS_SUCCESS == status)
    {
        pEntry->m_nMaxWavefrontsPerCU = maxWavesPerCU;
    }
    else
    {
        Log(logERROR, "Unable to get max waves per CU from hsa_agent_get_info\n");
        retVal = false;
    }

    uint32_t wavefrontSize;
    status = g_pRealCoreFunctions->hsa_agent_get_info_fn(agent, HSA_AGENT_INFO_WAVEFRONT_SIZE, &wavefrontSize);

    if (HSA_STATUS_SUCCESS == status)
    {
        pEntry->m_nWavefrontSize = wavefrontSize;
    }
    else
    {
        Log(logERROR, "Unable to get wavefront size from hsa_agent_get_info\n");
        retVal = false;
    }

    uint32_t simdsPerCU;
    status = g_pRealCoreFunctions->hsa_agent_get_info_fn(agent, static_cast<hsa_agent_info_t>(HSA_AMD_AGENT_INFO_NUM_SIMDS_PER_CU), &simdsPerCU);

    if (HSA_STATUS_SUCCESS != status)
    {
        // log an error but don't treat this as an error.
        // If we are running against a pre-ROCm1.5 build, we won't be able to query SIMDs-per-CU.
        // In that case, just use 4 which is the default for all current devices anyway
        simdsPerCU = 4;
        Log(logERROR, "Unable to get number of simds per CU from hsa_agent_get_info\n");
    }

    pEntry->m_nSimdsPerCU = simdsPerCU;

    // fall back to a lookup by device name when the chip id did not yield a generation
    if (GDT_HW_GENERATION_NONE == gen && !AMDTDeviceInfoUtils::Instance()->GetHardwareGeneration(pEntry->m_strDeviceName.c_str(), gen))
    {
        Log(logERROR, "Unable to query the hw generation\n");
        SAFE_DELETE(pEntry);
        return false;
    }

    if (gen >= GDT_HW_GENERATION_VOLCANICISLAND && gen < GDT_HW_GENERATION_LAST)
    {
        pEntry->m_pCLCUInfo = new(std::nothrow) CLCUInfoVI();
    }
    else if (gen == GDT_HW_GENERATION_SEAISLAND)
    {
        pEntry->m_pCLCUInfo = new(std::nothrow) CLCUInfoSI();
    }
    else
    {
        // don't use the EG/NI/SI occupancy calculation since HSA is supported on CI and newer
        Log(logERROR, "Unsupported hw generation\n");
        SAFE_DELETE(pEntry);
        return false;
    }

    if (nullptr == pEntry->m_pCLCUInfo)
    {
        // release the entry before bailing out so that it is not leaked,
        // consistent with the other failure paths above
        SAFE_DELETE(pEntry);
        SpAssertRet(!"Unable to allocate the CU info object") false;
    }

    pEntry->m_pCLCUInfo->SetCUParam(CU_PARAMS_VECTOR_GPRS_MAX, kernelStats.m_kernelInfo.m_nAvailableGPRs);
    pEntry->m_pCLCUInfo->SetCUParam(CU_PARAMS_SCALAR_GPRS_MAX, kernelStats.m_kernelInfo.m_nAvailableScalarGPRs);
    pEntry->m_pCLCUInfo->SetCUParam(CU_PARAMS_LDS_MAX, kernelStats.m_kernelInfo.m_nAvailableLDSSize);

    pEntry->m_pCLCUInfo->SetCUParam(CU_PARAMS_VECTOR_GPRS_USED, kernelStats.m_kernelInfo.m_nUsedGPRs);
    pEntry->m_pCLCUInfo->SetCUParam(CU_PARAMS_SCALAR_GPRS_USED, kernelStats.m_kernelInfo.m_nUsedScalarGPRs);
    pEntry->m_pCLCUInfo->SetCUParam(CU_PARAMS_LDS_USED, kernelStats.m_kernelInfo.m_nUsedLDSSize);

    pEntry->m_pCLCUInfo->SetCUParam(CU_PARAMS_KERNEL_WG_SIZE, pEntry->m_nWorkGroupItemCount);
    pEntry->m_pCLCUInfo->SetCUParam(CU_PARAMS_WG_SIZE_MAX, pEntry->m_nWorkGroupItemCountMax);
    pEntry->m_pCLCUInfo->SetCUParam(CU_PARAMS_KERNEL_GLOBAL_SIZE, pEntry->m_nGlobalItemCount);
    pEntry->m_pCLCUInfo->SetCUParam(CU_PARAMS_GLOBAL_SIZE_MAX, pEntry->m_nGlobalItemCountMax);

    pEntry->m_pCLCUInfo->SetCUParam(CU_PARAMS_NBR_COMPUTE_UNITS, pEntry->m_nNumberOfComputeUnits);

    pEntry->m_pCLCUInfo->SetCUParam(CU_PARAMS_DEVICE_NAME, pEntry->m_strDeviceName);
    pEntry->m_pCLCUInfo->SetCUParam(CU_PARAMS_DEVICE_GFXIP_VER, pEntry->m_nDeviceGfxIpVer);

    pEntry->m_pCLCUInfo->SetCUParam(CU_PARAMS_WAVEFRONT_PER_COMPUTE_UNIT, pEntry->m_nMaxWavefrontsPerCU);
    pEntry->m_pCLCUInfo->SetCUParam(CU_PARAMS_WAVEFRONT_SIZE, pEntry->m_nWavefrontSize);
    pEntry->m_pCLCUInfo->SetCUParam(CU_PARAMS_SIMDS_PER_CU, pEntry->m_nSimdsPerCU);

    pEntry->m_pCLCUInfo->ComputeCUOccupancy(static_cast<unsigned int>(pEntry->m_nWorkGroupItemCount));

    // the entry is handed over to the occupancy info manager
    OccupancyInfoManager::Instance()->AddTraceInfoEntry(pEntry);

    return retVal;
}

/// Reports whether delayed start of profiling is enabled.
/// \param[out] delayInMilliseconds receives the stored delay in milliseconds
/// \return the value of the delay-start flag
bool HSAGPAProfiler::IsProfilerDelayEnabled(unsigned long& delayInMilliseconds)
{
    const bool isDelayEnabled = m_bDelayStartEnabled;

    // the out parameter is written regardless of the flag's value
    delayInMilliseconds = m_delayInMilliseconds;

    return isDelayEnabled;
}


/// Reports whether a profiling duration is enabled.
/// \param[out] durationInMilliseconds receives the stored duration in milliseconds
/// \return the value of the profiler-duration flag
bool HSAGPAProfiler::IsProfilerDurationEnabled(unsigned long& durationInMilliseconds)
{
    const bool isDurationEnabled = m_bProfilerDurationEnabled;

    // the out parameter is written regardless of the flag's value
    durationInMilliseconds = m_durationInMilliseconds;

    return isDurationEnabled;
}


/// Registers the handler to be invoked when the specified timer finishes.
/// Nothing happens if the timer of the requested type has not been created
/// or if the timer type is not one of the two handled types.
/// \param timerType the timer whose finish handler should be set
/// \param timerEndHandler the handler to register
void HSAGPAProfiler::SetTimerFinishHandler(ProfilerTimerType timerType, TimerEndHandler timerEndHandler)
{
    if (PROFILEDELAYTIMER == timerType)
    {
        if (m_pDelayTimer != nullptr)
        {
            m_pDelayTimer->SetTimerFinishHandler(timerEndHandler);
        }
    }
    else if (PROFILEDURATIONTIMER == timerType)
    {
        if (m_pDurationTimer != nullptr)
        {
            m_pDurationTimer->SetTimerFinishHandler(timerEndHandler);
        }
    }

    // any other timer type is ignored
}

/// Creates the timer of the specified type, unless it already exists or the interval is zero.
/// On successful creation the matching "enabled" flag and the stored interval are updated;
/// an allocation failure is logged and leaves the flag and interval untouched.
/// \param timerType the type of timer to create
/// \param timeIntervalInMilliseconds the interval of the timer in milliseconds
void HSAGPAProfiler::CreateTimer(ProfilerTimerType timerType, unsigned int timeIntervalInMilliseconds)
{
    if (PROFILEDELAYTIMER == timerType)
    {
        const bool shouldCreate = (nullptr == m_pDelayTimer) && (0 < timeIntervalInMilliseconds);

        if (shouldCreate)
        {
            m_pDelayTimer = new(std::nothrow) ProfilerTimer(timeIntervalInMilliseconds);

            if (nullptr != m_pDelayTimer)
            {
                m_pDelayTimer->SetTimerType(PROFILEDELAYTIMER);
                m_bDelayStartEnabled = true;
                m_delayInMilliseconds = timeIntervalInMilliseconds;
            }
            else
            {
                Log(logERROR, "CreateTimer: unable to allocate memory for delay timer\n");
            }
        }
    }
    else if (PROFILEDURATIONTIMER == timerType)
    {
        const bool shouldCreate = (nullptr == m_pDurationTimer) && (0 < timeIntervalInMilliseconds);

        if (shouldCreate)
        {
            m_pDurationTimer = new(std::nothrow) ProfilerTimer(timeIntervalInMilliseconds);

            if (nullptr != m_pDurationTimer)
            {
                m_pDurationTimer->SetTimerType(PROFILEDURATIONTIMER);
                m_bProfilerDurationEnabled = true;
                m_durationInMilliseconds = timeIntervalInMilliseconds;
            }
            else
            {
                Log(logERROR, "CreateTimer: unable to allocate memory for duration timer\n");
            }
        }
    }

    // any other timer type is ignored
}

/// Starts the timer of the specified type.
/// Nothing happens if that timer has not been created
/// or if the timer type is not one of the two handled types.
/// \param timerType the type of timer to start
void HSAGPAProfiler::StartTimer(ProfilerTimerType timerType)
{
    if (PROFILEDELAYTIMER == timerType)
    {
        if (m_pDelayTimer != nullptr)
        {
            m_pDelayTimer->startTimer(true);
        }
    }
    else if (PROFILEDURATIONTIMER == timerType)
    {
        if (m_pDurationTimer != nullptr)
        {
            m_pDurationTimer->startTimer(true);
        }
    }

    // any other timer type is ignored
}
