////////////////////////////////////////////////////////////////////////////////
//
// The University of Illinois/NCSA
// Open Source License (NCSA)
//
// Copyright (c) 2018, Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to
// deal with the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
//  - Redistributions of source code must retain the above copyright notice,
//    this list of conditions and the following disclaimers.
//  - Redistributions in binary form must reproduce the above copyright
//    notice, this list of conditions and the following disclaimers in
//    the documentation and/or other materials provided with the distribution.
//  - Neither the names of Advanced Micro Devices, Inc,
//    nor the names of its contributors may be used to endorse or promote
//    products derived from this Software without specific prior written
//    permission.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
// THE CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
// OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS WITH THE SOFTWARE.
//
////////////////////////////////////////////////////////////////////////////////

#include <cstdlib>
#include <iostream>
#include <stdio.h>
#include <signal.h>

// Debug Agent Headers
#include "AgentLogging.h"
#include "AgentUtils.h"
#include "HSADebugAgentGDBInterface.h"
#include "HSADebugAgent.h"
#include "HSADebugInfo.h"
#include "HSAHandleQueueError.h"

static void INThandler(int sig);
static std::map<uint64_t, std::pair<uint64_t, WaveStateInfo *>> FindWavesAllQueues();

// Optionally installs INThandler as the handler for SIGINT and SIGTERM.
//
// Behavior is selected by the ROCM_DEBUG_ENABLE_LINUX_SIGNALS environment
// variable:
//   - unset or "0"  : nothing is installed;
//   - "1"           : INThandler is installed for SIGINT and SIGTERM;
//   - anything else : a warning is logged and nothing is installed.
void InitialLinuxSignalsHandler()
{
    const char *pSignalsEnvVar = std::getenv("ROCM_DEBUG_ENABLE_LINUX_SIGNALS");

    if (pSignalsEnvVar == nullptr)
    {
        return;
    }

    const std::string envValue(pSignalsEnvVar);
    if (envValue == "0")
    {
        return;
    }
    else if (envValue == "1")
    {
        signal(SIGINT, INThandler);
        signal(SIGTERM, INThandler);
    }
    else
    {
        AGENT_WARNING("Invalid value for ROCM_DEBUG_ENABLE_LINUX_SIGNALS, signal handling disabled by default. ");
    }
}

void INThandler(int sig)
{
    {
        std::lock_guard<std::mutex> lock(debugAgentAccessLock);

        if (sig == SIGINT)
        {
            signal(sig, SIG_IGN);
            printf("\nDumping wave state due to SIGINT\n\n");
        }
        else if (sig == SIGTERM)
        {
            signal(sig, SIG_IGN);
            printf("\nDumping wave state due to SIGTERM\n\n");
        }

        GPUAgentInfo *pAgent = _r_rocm_debug_info.pAgentList;
        while (pAgent != nullptr)
        {
            DebugAgentStatus status = DEBUG_AGENT_STATUS_SUCCESS;
            status = PreemptAgentQueues(pAgent);
            if (status != DEBUG_AGENT_STATUS_SUCCESS)
            {
                AGENT_ERROR("Cannot get queue preemption.");
            }
            std::map<uint64_t, std::pair<uint64_t, WaveStateInfo *>> waves = FindWavesAllQueues();
            PrintWaves(pAgent, waves);
            pAgent = pAgent->pNext;
        }
        allQueueWaves.clear();
        std::abort();
    }
}

static std::map<uint64_t, std::pair<uint64_t, WaveStateInfo *>> FindWavesAllQueues()
{
    std::map<uint64_t, std::pair<uint64_t, WaveStateInfo *>> waves;

    for (auto &queueWaves : allQueueWaves)
    {
        for (auto &wave : queueWaves.second)
        {
            auto it = waves.find(wave.regs.pc);
            if (it != waves.end())
            {
                it->second.first ++;
            }
            else
            {
                waves.insert(std::make_pair(wave.regs.pc,
                                            std::make_pair(1, &wave)));
            }
        }
    }
    return waves;
}
