from __future__ import print_function
import ompdModule
from ompd_handles import ompd_thread, ompd_task, ompd_parallel
import gdb
import sys
import traceback
from enum import Enum

class ompd_scope(Enum):
        """Scope of an OMPD handle or ICV.

        The numeric values presumably mirror the OMPD ``ompd_scope_t``
        constants (1 = global ... 6 = task) -- confirm against omp-tools.h.
        """
        # whole-program scope
        ompd_scope_global = 1
        # one debugged address space (process)
        ompd_scope_address_space = 2
        # a single OpenMP thread
        ompd_scope_thread = 3
        # a parallel region
        ompd_scope_parallel = 4
        # an implicit task of a parallel region
        ompd_scope_implicit_task = 5
        # an explicit or implicit task
        ompd_scope_task = 6

class ompd_address_space(object):
        """OMPD view of the debugged program's address space inside GDB.

        Holds the address-space handle returned by
        ``ompdModule.call_ompd_initialize()``. It also holds lazily filled caches:
        OpenMP thread objects, OMPD state names and ICV ids. It installs internal
        breakpoints on ``ompd_bp_thread_begin`` and, when the symbol exists,
        on ``ompd_tool_break``.
        """

        def __init__(self):
                """Initializes an ompd_address_space object by calling ompd_initialize
                in ompdModule.c.

                Also registers handle_stop_event for GDB stop events and creates
                the internal breakpoints used to track new threads and to trigger
                the OMPT/OMPD comparison.
                """
                self.addr_space = ompdModule.call_ompd_initialize()
                # maps thread_num (thread number given by gdb) to ompd_thread object with thread handle
                self.threads = {}
                # filled by enumerate_states(): OMPD state value -> state name
                self.states = None
                # filled by get_icv_map(): ICV name -> (icv id, scope value, scope name)
                self.icv_map = None
                self.ompd_tool_test_bp = None
                # scope value -> readable scope name (same numbering as ompd_scope)
                self.scope_map = {1:'global', 2:'address_space', 3:'thread', 4:'parallel', 5:'implicit_task', 6:'task'}
                gdb.events.stop.connect(self.handle_stop_event)
                self.new_thread_breakpoint = gdb.Breakpoint("ompd_bp_thread_begin", internal=True)
                tool_break_symbol = gdb.lookup_global_symbol("ompd_tool_break")
                if tool_break_symbol is not None:
                        self.ompd_tool_test_bp = gdb.Breakpoint("ompd_tool_break", internal=True)

        def handle_stop_event(self, event):
                """Reacts to GDB stop events caused by the internal OMPD breakpoints.

                On ``ompd_bp_thread_begin``, the newly started thread is recorded and
                execution is resumed. On ``ompd_tool_break``, OMPT tool data is
                compared with what OMPD reports. Any error from that comparison is
                printed with its traceback. Other stop events are ignored.
                """
                if isinstance(event, gdb.BreakpointEvent):
                        # check if breakpoint has already been hit
                        if self.new_thread_breakpoint in event.breakpoints:
                                self.add_thread()
                                # this breakpoint only tracks thread creation; resume at once
                                gdb.execute('continue')
                                return
                        elif self.ompd_tool_test_bp is not None and self.ompd_tool_test_bp in event.breakpoints:
                                try:
                                        self.compare_ompt_data()
                                except Exception:
                                        # report the failure instead of letting it escape the GDB event handler
                                        traceback.print_exc()
                elif isinstance(event, gdb.SignalEvent):
                        # TODO: what do we need to do on SIGNALS?
                        pass
                else:
                        # TODO: probably not possible?
                        pass

        def get_icv_map(self):
                """Fills the ICV map by enumerating all ICVs known to OMPD.

                Each entry maps the ICV name to ``(icv id, scope value, scope name)``.
                """
                self.icv_map = {}
                current = 0
                more = 1
                while more > 0:
                        tup = ompdModule.call_ompd_enumerate_icvs(self.addr_space, current)
                        # the first element is the id of the ICV just enumerated; it seeds the next call
                        (current, next_icv, next_scope, more) = tup
                        self.icv_map[next_icv] = (current, next_scope, self.scope_map[next_scope])
                print('Initialized ICV map successfully for checking OMP API values.')

        def _get_icv_value(self, handle, icv_name):
                """Returns the value of ICV `icv_name` for `handle`, as reported by
                ompdModule.call_ompd_get_icv_from_scope.

                Expects self.icv_map to be filled and to contain `icv_name`.
                """
                icv_id, icv_scope = self.icv_map[icv_name][0], self.icv_map[icv_name][1]
                return ompdModule.call_ompd_get_icv_from_scope(handle, icv_scope, icv_id)

        def compare_ompt_data(self):
                """Compares OMPT tool data about the current thread, parallel region and
                task to the data returned by OMPD functions, and prints every mismatch.

                The OMPT side is read from the inferior's ``thread_data`` variable.
                A value is compared only when the field exists in that variable's
                type and, for ICV-based checks, the ICV is present in the ICV map.
                """
                # make sure all threads and states are set
                self.list_threads(False)

                thread_id = gdb.selected_thread().ptid[1]
                curr_thread = self.get_curr_thread()

                # check if current thread is LWP thread; return if "ompd_rc_unavailable"
                thread_handle = ompdModule.get_thread_handle(thread_id, self.addr_space)
                if thread_handle == -1:
                        print("Skipping OMPT-OMPD checks for non-LWP thread.")
                        return

                print('Comparing OMPT data to OMPD data...')
                thread_data = gdb.parse_and_eval('thread_data')
                field_names = [field.name for field in thread_data.type.fields()]

                if self.icv_map is None:
                        self.get_icv_map()

                # compare state values
                if 'ompt_state' in field_names:
                        if self.states is None:
                                self.enumerate_states()
                        ompt_state = str(thread_data['ompt_state'])
                        ompd_state = str(self.states[curr_thread.get_state()[0]])
                        if ompt_state != ompd_state:
                                print('OMPT-OMPD mismatch: ompt_state (%s) does not match OMPD state (%s)!' % (ompt_state, ompd_state))

                # compare wait_id values
                if 'ompt_wait_id' in field_names:
                        ompt_wait_id = thread_data['ompt_wait_id']
                        ompd_wait_id = curr_thread.get_state()[1]
                        if ompt_wait_id != ompd_wait_id:
                                print('OMPT-OMPD mismatch: ompt_wait_id (%d) does not match OMPD wait id (%d)!' % (ompt_wait_id, ompd_wait_id))

                # compare thread id
                if 'omp_thread_num' in field_names and 'ompd-thread-num-var' in self.icv_map:
                        ompt_thread_num = thread_data['omp_thread_num']
                        icv_value = self._get_icv_value(curr_thread.thread_handle, 'ompd-thread-num-var')
                        if ompt_thread_num != icv_value:
                                print('OMPT-OMPD mismatch: omp_thread_num (%d) does not match OMPD thread num according to ICVs (%d)!' % (ompt_thread_num, icv_value))

                # compare thread data (3 == thread scope)
                if 'ompt_thread_data' in field_names:
                        ompt_thread_data = thread_data['ompt_thread_data'].dereference()['value']
                        ompd_value = ompdModule.call_ompd_get_tool_data(3, curr_thread.thread_handle)[0]
                        if ompt_thread_data != ompd_value:
                                print('OMPT-OMPD mismatch: value of ompt_thread_data (%d) does not match that of OMPD data union (%d)!' % (ompt_thread_data, ompd_value))

                # compare number of threads
                if 'omp_num_threads' in field_names and 'ompd-team-size-var' in self.icv_map:
                        ompt_num_threads = thread_data['omp_num_threads']
                        icv_value = self._get_icv_value(curr_thread.get_current_parallel_handle(), 'ompd-team-size-var')
                        if ompt_num_threads != icv_value:
                                print('OMPT-OMPD mismatch: omp_num_threads (%d) does not match OMPD num threads according to ICVs (%d)!' % (ompt_num_threads, icv_value))

                # compare omp level
                if 'omp_level' in field_names and 'levels-var' in self.icv_map:
                        ompt_levels = thread_data['omp_level']
                        icv_value = self._get_icv_value(curr_thread.get_current_parallel_handle(), 'levels-var')
                        if ompt_levels != icv_value:
                                print('OMPT-OMPD mismatch: omp_level (%d) does not match OMPD levels according to ICVs (%d)!' % (ompt_levels, icv_value))

                # compare active level
                if 'omp_active_level' in field_names and 'active-levels-var' in self.icv_map:
                        ompt_active_levels = thread_data['omp_active_level']
                        icv_value = self._get_icv_value(curr_thread.get_current_parallel_handle(), 'active-levels-var')
                        if ompt_active_levels != icv_value:
                                print('OMPT-OMPD mismatch: active levels (%d) do not match active levels according to ICVs (%d)!' % (ompt_active_levels, icv_value))

                # compare parallel data (4 == parallel scope)
                if 'ompt_parallel_data' in field_names:
                        ompt_parallel_data = thread_data['ompt_parallel_data'].dereference()['value']
                        current_parallel_handle = curr_thread.get_current_parallel_handle()
                        ompd_value = ompdModule.call_ompd_get_tool_data(4, current_parallel_handle)[0]
                        if ompt_parallel_data != ompd_value:
                                print('OMPT-OMPD mismatch: value of ompt_parallel_data (%d) does not match that of OMPD data union (%d)!' % (ompt_parallel_data, ompd_value))

                # compare max threads; NOTE: may be missing from the ICV map, then skipped
                if 'omp_max_threads' in field_names and 'nthreads-var' in self.icv_map:
                        ompt_max_threads = thread_data['omp_max_threads']
                        icv_value = self._get_icv_value(curr_thread.get_current_task_handle(), 'nthreads-var')
                        if ompt_max_threads != icv_value:
                                print('OMPT-OMPD mismatch: omp_max_threads (%d) does not match OMPD thread limit according to ICVs (%d)!' % (ompt_max_threads, icv_value))

                # compare omp_parallel
                # NOTE: omp_parallel = true if active-levels-var > 0
                if 'omp_parallel' in field_names and 'active-levels-var' in self.icv_map:
                        ompt_parallel = thread_data['omp_parallel']
                        icv_value = self._get_icv_value(curr_thread.get_current_parallel_handle(), 'active-levels-var')
                        if (ompt_parallel == 1 and icv_value <= 0) or (ompt_parallel == 0 and icv_value > 0):
                                print('OMPT-OMPD mismatch: ompt_parallel (%d) does not match OMPD parallel according to ICVs (%d)!' % (ompt_parallel, icv_value))

                # compare omp_final
                if 'omp_final' in field_names and 'ompd-final-var' in self.icv_map:
                        ompt_final = thread_data['omp_final']
                        current_task_handle = curr_thread.get_current_task_handle()
                        icv_value = self._get_icv_value(current_task_handle, 'ompd-final-var')
                        if icv_value != ompt_final:
                                print('OMPT-OMPD mismatch: omp_final (%d) does not match OMPD final according to ICVs (%d)!' % (ompt_final, icv_value))

                # compare omp_dynamic; TODO: test; not in ICV map
                if 'omp_dynamic' in field_names and 'dyn-var' in self.icv_map:
                        ompt_dynamic = thread_data['omp_dynamic']
                        icv_value = self._get_icv_value(curr_thread.get_current_task_handle(), 'dyn-var')
                        if icv_value != ompt_dynamic:
                                print('OMPT-OMPD mismatch: omp_dynamic (%d) does not match OMPD dynamic according to ICVs (%d)!' % (ompt_dynamic, icv_value))

                # compare omp_nested; TODO: test; not in ICV map
                if 'omp_nested' in field_names and 'nest-var' in self.icv_map:
                        ompt_nested = thread_data['omp_nested']
                        icv_value = self._get_icv_value(self.addr_space, 'nest-var')
                        if ompt_nested != icv_value:
                                print('OMPT-OMPD mismatch: omp_nested (%d) does not match OMPD nested according to ICVs (%d)!' % (ompt_nested, icv_value))

                # compare omp_max_active_levels
                if 'omp_max_active_levels' in field_names and 'max-active-levels-var' in self.icv_map:
                        ompt_max_active_levels = thread_data['omp_max_active_levels']
                        icv_value = self._get_icv_value(curr_thread.get_current_task_handle(), 'max-active-levels-var')
                        if ompt_max_active_levels != icv_value:
                                print('OMPT-OMPD mismatch: omp_max_active_levels (%d) does not match OMPD max active levels (%d)!' % (ompt_max_active_levels, icv_value))

                # compare omp_kind; TODO: test; not in ICV map
                if 'omp_kind' in field_names and 'run-sched-var' in self.icv_map:
                        ompt_sched_kind = thread_data['omp_kind']
                        icv_value = ompdModule.call_ompd_get_icv_string_from_scope(curr_thread.get_current_task_handle(), self.icv_map['run-sched-var'][1], self.icv_map['run-sched-var'][0])
                        # `x in <str>` needs a str on the left; a gdb.Value would raise TypeError.
                        # NOTE(review): assumes str(omp_kind) appears textually in the ICV string -- confirm.
                        if str(ompt_sched_kind) not in str(icv_value):
                                print('OMPT-OMPD mismatch: omp_kind kind (%s) does not match OMPD schedule kind according to ICVs (%s)!' % (str(ompt_sched_kind), str(icv_value)))

                # compare omp_modifier; TODO: test; not in ICV map
                if 'omp_modifier' in field_names and 'run-sched-var' in self.icv_map:
                        ompt_sched_mod = thread_data['omp_modifier']
                        # NOTE(review): indexing [1] assumes the non-string ICV call returns a sequence for run-sched-var -- confirm
                        ompd_sched = self._get_icv_value(curr_thread.get_current_task_handle(), 'run-sched-var')
                        if ompt_sched_mod != ompd_sched[1]:
                                print('OMPT-OMPD mismatch: omp_kind modifier does not match OMPD schedule modifier according to ICVs!')

                # compare omp_proc_bind
                if 'omp_proc_bind' in field_names and 'bind-var' in self.icv_map:
                        ompt_proc_bind = thread_data['omp_proc_bind']
                        icv_value = self._get_icv_value(curr_thread.get_current_task_handle(), 'bind-var')
                        if icv_value != ompt_proc_bind:
                                print('OMPT-OMPD mismatch: omp_proc_bind (%d) does not match OMPD proc bind according to ICVs (%d)!' % (ompt_proc_bind, icv_value))

                # compare enter and exit frames
                if 'ompt_frame_list' in field_names:
                        ompt_task_frame_dict = thread_data['ompt_frame_list'].dereference()
                        ompt_task_frames = (int(ompt_task_frame_dict['enter_frame']), int(ompt_task_frame_dict['exit_frame']))
                        current_task = curr_thread.get_current_task()
                        ompd_task_frames = current_task.get_task_frame()
                        if ompt_task_frames != ompd_task_frames:
                                print('OMPT-OMPD mismatch: ompt_task_frames (%s) do not match OMPD task frames (%s)!' % (ompt_task_frames, ompd_task_frames))

                # compare task data (6 == task scope)
                if 'ompt_task_data' in field_names:
                        ompt_task_data = thread_data['ompt_task_data'].dereference()['value']
                        current_task_handle = curr_thread.get_current_task_handle()
                        ompd_value = ompdModule.call_ompd_get_tool_data(6, current_task_handle)[0]
                        if ompt_task_data != ompd_value:
                                print('OMPT-OMPD mismatch: value of ompt_task_data (%d) does not match that of OMPD data union (%d)!' % (ompt_task_data, ompd_value))

        def save_thread_object(self, thread_num, thread_id, addr_space):
                """Saves a thread object for thread_num inside the threads dictionary.

                thread_num is GDB's thread number (used as the key).
                thread_id is the id passed to ompdModule.get_thread_handle.
                """
                thread_handle = ompdModule.get_thread_handle(thread_id, addr_space)
                self.threads[int(thread_num)] = ompd_thread(thread_handle)

        def get_thread(self, thread_num):
                """Gets the thread object for GDB thread number thread_num from the map.

                Raises KeyError if the thread has not been recorded.
                """
                return self.threads[int(thread_num)]

        def get_curr_thread(self):
                """Gets the object for GDB's selected thread, adding it to the map first
                if it is missing.
                """
                thread_num = int(gdb.selected_thread().num)
                if thread_num not in self.threads:
                        self.add_thread()
                return self.threads[thread_num]

        def add_thread(self):
                """Adds the currently selected (*) thread to the threads dictionary.

                Errors are printed with their traceback rather than raised.
                """
                inf_thread = gdb.selected_thread()
                try:
                        self.save_thread_object(inf_thread.num, inf_thread.ptid[1], self.addr_space)
                except Exception:
                        traceback.print_exc()

        def list_threads(self, verbose):
                """Prints whether each thread of the selected inferior is an OpenMP thread.

                A thread counts as an OpenMP thread if it is tracked in the "threads"
                dictionary; see handle_stop_event and add_thread. Nothing is printed
                unless verbose is true.
                """
                list_tids = []
                curr_inferior = gdb.selected_inferior()

                for inf_thread in curr_inferior.threads():
                        list_tids.append((inf_thread.num, inf_thread.ptid))
                if verbose:
                        if self.states is None:
                                self.enumerate_states()
                        for (thread_num, thread_ptid) in sorted(list_tids):
                                if thread_num in self.threads:
                                        try:
                                                print('Thread %i (%i) is an OpenMP thread; state: %s' % (thread_num, thread_ptid[1], self.states[self.threads[thread_num].get_state()[0]]))
                                        except Exception:
                                                traceback.print_exc()
                                else:
                                        print('Thread %i (%i) is no OpenMP thread' % (thread_num, thread_ptid[1]))

        def enumerate_states(self):
                """Helper function for list_threads: initializes the map of OMPD states
                (state value -> state name) used for the output of 'ompd threads'.

                Does nothing if the map has already been built.
                """
                if self.states is None:
                        self.states = {}
                        # 0x102 is presumably the "undefined" state value used to start the
                        # enumeration -- confirm against omp-tools.h
                        current = 0x102
                        more = 1

                        while more > 0:
                                tup = ompdModule.call_ompd_enumerate_states(self.addr_space, current)
                                (next_state, next_state_name, more) = tup

                                self.states[next_state] = next_state_name
                                current = next_state