//===-- ompt_buffer_mgr.h - Target independent OpenMP target RTL -- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Interface to be used for generating and flushing OMPT device trace records.
//
//===----------------------------------------------------------------------===//

#ifndef _OMPT_BUFFER_MGR_H_
#define _OMPT_BUFFER_MGR_H_

#include <cassert>
#include <condition_variable>
#include <cstdint>
#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include <omp-tools.h>

// TODO Start with 1 helper thread and add dynamically if required
// Number of helper threads must not exceed 32 since the
// thread-wait-tracker is 32 bits in length.
#define OMPT_NUM_HELPER_THREADS 1

// Enforce the limit documented above at compile time: each helper thread is
// identified by one bit of the 32-bit flush/shutdown trackers.
static_assert(OMPT_NUM_HELPER_THREADS <= 32,
              "OMPT_NUM_HELPER_THREADS must not exceed 32");

// OMPT_TRACING_IF_ENABLED(stmts...) expands to the given statements wrapped in
// a do { ... } while (0) when OMPT_SUPPORT is defined, and to nothing
// otherwise. The macro is variadic so that statements containing
// unparenthesized commas (e.g. template argument lists such as `f<A, B>(x);`)
// are accepted as one statement list instead of being rejected by the
// preprocessor as extra macro arguments.
#ifdef OMPT_SUPPORT
#define OMPT_TRACING_IF_ENABLED(...)                                           \
  do {                                                                         \
    __VA_ARGS__                                                                \
  } while (0)
#else
#define OMPT_TRACING_IF_ENABLED(...)
#endif

/*
 * Buffer manager for trace records generated by OpenMP master and
 * worker threads. During device init, a tool may register a
 * buffer-request and a buffer-completion callback. The buffer-request
 * callback should be used to allocate new buffers as required. The
 * buffer-complete callback should be used to return trace records to
 * the tool.
 *
 * In addition to trace records, this class manages the helper threads
 * for dispatching a range of trace records to the tool.
 */
class OmptTracingBufferMgr {
public:
  /*
   * A trace record (TR) holds the trace data. Its type
   * can be ompt or native. Currently, only ompt type is implemented.
   */

  /*
   * A TR can be in the following states:
   * TR_init: initial state
   * TR_ready: An OpenMP thread marks a TR ready when it is done
   * populating the TR
   * TR_released: A helper thread marks a TR released after it has
   * completed returning the TR to the tool
   */
  enum TRStatus { TR_init, TR_ready, TR_released };

private:
  // Internal bit mask tracking helper threads to wait for flush. Bit N belongs
  // to the helper thread whose HelperThreadIdMap value is N. Zero-initialized
  // here so the mask has a defined value even if a constructor does not set
  // it explicitly.
  uint32_t ThreadFlushTracker{0};

  // Internal bit mask tracking helper threads shutting down, using the same
  // bit assignment as ThreadFlushTracker.
  uint32_t ThreadShutdownTracker{0};

  /*
   * Metadata capturing the state of a buffer of trace records. Once a
   * buffer is allocated, trace records are carved out by the OpenMP
   * threads.
   *
   * Id, Start, and TotalBytes are not changed once set. But Cursor,
   * RemainingBytes, and isFull can be read/written more than
   * once. Hence, accesses of the 2nd set of locations need to be
   * synchronized.
   */
  struct Buffer {
    uint64_t Id;           // Unique identifier of the buffer
    void *Start;           // Start of allocated space for trace records
    void *Cursor;          // Address of the last trace record carved out
    size_t TotalBytes;     // Total number of bytes in the allocated space
    size_t RemainingBytes; // Total number of unused bytes
                           // corresponding to Cursor
    bool isFull;           // true if no more trace records can be accommodated,
                           // otherwise false
    Buffer(uint64_t id, void *st, void *cr, size_t bytes, size_t rem,
           bool is_full)
        : Id{id}, Start{st}, Cursor{cr}, TotalBytes{bytes},
          RemainingBytes{rem}, isFull{is_full} {}
    Buffer() = delete;
    Buffer(const Buffer &) = delete;
    Buffer &operator=(const Buffer &) = delete;
  };

  using BufPtr = std::shared_ptr<Buffer>;
  using MapId2Buf = std::map<uint64_t, BufPtr>;

  // Map from id to corresponding buffer. The ids are assigned in
  // increasing order of creation.
  MapId2Buf Id2BufferMap;

  // Trace record metadata
  struct TraceRecordMd {
    BufPtr BufAddr;   // Enclosing buffer
    TRStatus TRState; // Status of a trace record
    // Takes the buffer pointer by value and moves it in to avoid an extra
    // atomic reference-count increment/decrement pair.
    TraceRecordMd(BufPtr buf) : BufAddr{std::move(buf)}, TRState{TR_init} {}
    TraceRecordMd() = delete;
    TraceRecordMd(const TraceRecordMd &) = delete;
    TraceRecordMd &operator=(const TraceRecordMd &) = delete;
  };

  using BufMdPtr = std::shared_ptr<TraceRecordMd>;
  using UMapPtr2BufState = std::unordered_map<void *, BufMdPtr>;

  // A hashmap from cursor -> metadata containing the trace record
  // status and the containing buffer
  UMapPtr2BufState Cursor2BufMdMap;

  /*
   * A buffer is flushed when it fills up or when the tool invokes
   * flush_trace. So it's possible that the same buffer may be flushed
   * more than once. When a buffer is flushed the first time, a unique
   * id (flush-id) is generated and assigned to that buffer. Even if
   * it is flushed again, the previously assigned id is maintained for
   * that buffer. This id is loosely used to determine the order in
   * which the buffers are processed and the corresponding trace
   * records released to the tool.
   */

  // Snapshot of one flushed buffer: its flush-id, a cursor and the buffer.
  // A default-constructed FlushInfo has FlushId 0, a null cursor and a null
  // buffer (the scalar members are explicitly initialized so that value is
  // well-defined).
  struct FlushInfo {
    uint64_t FlushId{0};
    void *FlushCursor{nullptr};
    BufPtr FlushBuf;
    FlushInfo() = default;
    FlushInfo(uint64_t id, void *cr, BufPtr buf)
        : FlushId{id}, FlushCursor{cr}, FlushBuf{std::move(buf)} {}
  };

  /*
   * A buffer may be in the following states:
   * Flush_waiting: when a buffer is flushed, either because it is
   * full or because the tool invokes ompt_flush_trace
   * Flush_processing: when a helper thread claims the waiting buffer
   * and is in the process of dispatching buffer-completion callbacks
   * on an associated range of trace records. If all trace records are
   * not released, the state may be reset to Flush_waiting after the
   * buffer-completion callbacks return
   */
  enum BufferFlushStatus { Flush_waiting, Flush_processing };
  struct FlushMd {
    void *FlushCursor;
    BufPtr FlushBuf;
    BufferFlushStatus FlushStatus;
    FlushMd(void *cr, BufPtr buf, BufferFlushStatus status)
        : FlushCursor{cr}, FlushBuf{std::move(buf)}, FlushStatus{status} {}
    FlushMd() = delete;
  };

  using MapId2Md = std::map<uint64_t, FlushMd>;

  /*
   * A map from a flush-id to metadata containing the current
   * cursor, the corresponding buffer, and its flushed status. If a
   * buffer is flushed multiple times, the cursor is updated to the
   * furthest one
   */
  MapId2Md Id2FlushMdMap;

  using UMapBufPtr2Id = std::unordered_map<BufPtr, uint64_t>;

  // A hash map from a buffer address to the corresponding flush-id
  UMapBufPtr2Id FlushBufPtr2IdMap;

  using USetCursor = std::unordered_set<void *>;

  // Last cursors of buffer-completion callbacks; guarded by LastCursorMutex
  USetCursor LastCursors;

  using UMapThd2Id = std::unordered_map<std::thread::id, uint32_t>;

  // A hash map from a helper thread id to an integer
  UMapThd2Id HelperThreadIdMap;

  // Mutex to protect Id2BufferMap and Cursor2BufMdMap
  std::mutex BufferMgrMutex;

  // Mutex to protect FlushBufPtr2IdMap and Id2FlushMdMap
  std::mutex FlushMutex;

  // Mutex to protect metadata tracking last cursors of buffer-completion
  // callbacks
  std::mutex LastCursorMutex;

  // Condition variable used by helper thread to signal that flush is requested
  std::condition_variable FlushCv;

  // Condition variable used while waiting for flushing to complete
  std::condition_variable ThreadFlushCv;

  // Condition variable used while waiting for threads to shutdown
  std::condition_variable ThreadShutdownCv;

  // TODO Separate out the helper thread into its own class
  std::vector<std::thread> CompletionThreads;

  // Called when a buffer may be flushed. setComplete should be called without
  // holding any lock
  void setComplete(void *cursor);

  // Called to dispatch buffer-completion callbacks for the trace records in
  // this buffer
  void flushBuffer(FlushInfo);

  // Dispatch a buffer-completion callback with a range of trace records
  void dispatchCallback(void *buffer, void *first_cursor, void *last_cursor);

  // Add a last cursor
  void addLastCursor(void *cursor) {
    std::lock_guard<std::mutex> lck(LastCursorMutex);
    LastCursors.emplace(cursor);
  }

  // Remove a last cursor. The cursor is expected to have been added before.
  void removeLastCursor(void *cursor) {
    std::lock_guard<std::mutex> lck(LastCursorMutex);
    assert(LastCursors.find(cursor) != LastCursors.end());
    LastCursors.erase(cursor);
  }

  // Reserve a candidate buffer for flushing, preventing other helper threads
  // from accessing it
  FlushInfo findAndReserveFlushedBuf(uint64_t id);

  // Unreserve a buffer so that other helper threads can process it
  void unreserveFlushedBuf(const FlushInfo &);

  // All done with this buffer, so the buffer and its metadata can be removed
  void destroyFlushedBuf(const FlushInfo &);

  // Add a new buffer by an OpenMP thread so that a helper thread can process it
  uint64_t addNewFlushEntry(BufPtr buf, void *cursor);

  // Get the next trace record
  void *getNextTR(void *tr);

  // Get the size of a trace record
  // We support only ompt records today
  size_t getTRSize() { return sizeof(ompt_record_ompt_t); }

  // Given a buffer, return the latest cursor
  void *getBufferCursor(BufPtr);

  // Is no more space remaining for trace records in this buffer?
  bool isBufferFull(const FlushInfo &);

  // Have all trace records in this buffer been returned to the tool?
  bool isBufferOwned(const FlushInfo &);

  // Dispatch a buffer-completion callback and indicate that the buffer can be
  // deallocated
  void dispatchBufferOwnedCallback(const FlushInfo &);

  // Main entry point for a helper thread
  void driveCompletion();

  // Examine the flushed buffers and dispatch buffer-completion callbacks
  void invokeCallbacks();

  // The caller does not hold a lock while calling this method
  void waitForFlushCompletion();

  // Return the single-bit tracker mask for helper thread number thd_num.
  // The trackers are 32 bits wide, so thd_num must be < 32; an out-of-range
  // number asserts in debug builds and yields an empty mask (instead of an
  // undefined shift) otherwise. The shift is done on an unsigned operand so
  // bit 31 is handled without relying on signed-shift rules.
  static uint32_t getThreadBit(uint32_t thd_num) {
    assert(thd_num < 32 && "thread tracker holds at most 32 threads");
    return thd_num < 32 ? (uint32_t{1} << thd_num) : uint32_t{0};
  }

  // Return the tracker mask of the calling helper thread. Uses find() rather
  // than operator[] so a lookup from a thread that is not registered in
  // HelperThreadIdMap never inserts a bogus entry (which would later make
  // amIHelperThread() report true for it); such a lookup asserts in debug
  // builds and yields an empty mask otherwise. The caller must hold the flush
  // lock.
  uint32_t getThisThreadBit() {
    auto It = HelperThreadIdMap.find(std::this_thread::get_id());
    assert(It != HelperThreadIdMap.end());
    if (It == HelperThreadIdMap.end())
      return 0;
    return getThreadBit(It->second);
  }

  // Given a thread number, set the corresponding bit in the flush
  // tracker. The caller must hold the flush lock.
  void setThreadFlush(uint32_t thd_num) {
    ThreadFlushTracker |= getThreadBit(thd_num);
  }

  // Reset this thread's flush bit. The caller must hold the flush lock
  void resetThisThreadFlush() { ThreadFlushTracker &= ~getThisThreadBit(); }

  // Given a thread number, set the corresponding bit in the shutdown
  // tracker. The caller must hold the flush lock.
  void setThreadShutdown(uint32_t thd_num) {
    ThreadShutdownTracker |= getThreadBit(thd_num);
  }

  // Reset this thread's shutdown bit. The caller must hold the flush
  // lock
  void resetThisThreadShutdown() {
    ThreadShutdownTracker &= ~getThisThreadBit();
  }

  // Return true if this thread's flush bit is set. The caller must
  // hold the flush lock
  bool isThisThreadFlushWaitedUpon() {
    return (ThreadFlushTracker & getThisThreadBit()) != 0;
  }

  // Return true if this thread's shutdown bit is set. The caller must
  // hold the flush lock
  bool isThisThreadShutdownWaitedUpon() {
    return (ThreadShutdownTracker & getThisThreadBit()) != 0;
  }

  // Return true if the calling thread is registered as a helper thread.
  // The caller must not hold the flush lock
  bool amIHelperThread() {
    std::lock_guard<std::mutex> flush_lock(FlushMutex);
    return HelperThreadIdMap.count(std::this_thread::get_id()) != 0;
  }

  // The caller must hold the appropriate lock
  void init();

  // The caller must hold the flush lock
  void createHelperThreads();

  // The caller must hold the flush lock
  void destroyHelperThreads();

public:
  OmptTracingBufferMgr();
  ~OmptTracingBufferMgr();

  // The caller must not hold the flush lock
  void startHelperThreads();

  // The caller must not hold the flush lock
  void shutdownHelperThreads();

  // Assign a cursor for a new trace record
  void *assignCursor(ompt_callbacks_t type);

  // Get the status of a trace record
  TRStatus getTRStatus(void *tr);

  // Set the status of a trace record
  void setTRStatus(void *tr, TRStatus);

  // Is this a last cursor of a buffer completion callback?
  bool isLastCursor(void *cursor) {
    std::lock_guard<std::mutex> lck(LastCursorMutex);
    return LastCursors.find(cursor) != LastCursors.end();
  }

  // Called for flushing outstanding buffers
  int flushAllBuffers(ompt_device_t *);
};

#endif // _OMPT_BUFFER_MGR_H_
