# Copyright 2015 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

"""Adds unittest-esque functionality to Legion."""

import argparse
import logging
import sys
import unittest

# pylint: disable=relative-import
# Import common_lib first so we can set up the environment
from lib import common_lib
common_lib.SetupEnvironment()

from legion.lib import task_controller
from legion.lib import task_registration_server
from legion.lib.comm_server import comm_server

# Width, in characters, of the banner lines logged by
# TestCase._LogInfoBanner (separator rows and centered text).
BANNER_WIDTH = 80


class TestCase(unittest.TestCase):
  """Test case class with added Legion support.

  The first time a class is instantiated its setUpClass/tearDownClass hooks
  are replaced with wrappers that run the framework-level setup/teardown
  (registration server and comm server) around the test case's own hooks.
  Each test method is additionally wrapped so that a banner is logged before
  it runs and any exception it raises is logged before propagating.
  """

  # Task registration server used by CreateTask; created in _SetUpFramework.
  _registration_server = None
  # Set to True once _InitializeClass has installed the class-level wrappers.
  _initialized = False

  def __new__(cls, *args, **kwargs):
    """Initialize the class and return a new instance.

    __new__ is implicitly a static method that already receives the class as
    its first argument, so it must not be decorated with @classmethod. The
    constructor arguments are consumed by __init__ and are deliberately not
    forwarded to object.__new__, which does not accept extra arguments.
    """
    cls._InitializeClass()
    return super(TestCase, cls).__new__(cls)

  def __init__(self, test_name='runTest'):
    """Creates the test case and wraps the selected test method.

    Args:
      test_name: Name of the test method this instance will run.
    """
    super(TestCase, self).__init__(test_name)
    method = getattr(self, test_name, None)
    if method:
      # Install the _RunTest method in place of the test method on this
      # instance, keeping the real method so _RunTest can call it.
      self._TestMethod = method
      setattr(self, test_name, self._RunTest)
    self._output_dir = None

  @property
  def output_dir(self):
    """Returns the output directory, fetching and caching it on first use."""
    if not self._output_dir:
      # NOTE(review): self.rpc is not defined anywhere in this class;
      # presumably a subclass or caller provides it - confirm.
      self._output_dir = self.rpc.GetOutputDir()
    return self._output_dir

  def _RunTest(self):
    """Runs the test method and provides banner info and error reporting.

    Returns:
      Whatever the wrapped test method returns.

    Raises:
      Any exception raised by the test method, after it has been logged.
    """
    self._LogInfoBanner(self._testMethodName, self.shortDescription())
    try:
      return self._TestMethod()
    except:  # pylint: disable=bare-except
      # Catching everything is deliberate: the exception is only logged here
      # and then re-raised unchanged with its original traceback.
      logging.exception('')
      raise

  @classmethod
  def _InitializeClass(cls):
    """Handles class level initialization.

    There are 2 types of setup/teardown methods that always need to be run:
    1) Framework level setup/teardown
    2) Test case level setup/teardown

    This method installs handlers in place of setUpClass and tearDownClass that
    will ensure both types of setup/teardown methods are called correctly.
    It is a no-op once the class has been initialized.
    """
    # NOTE(review): _initialized is looked up through inheritance, so a
    # subclass of an already-initialized class is treated as initialized and
    # its own setUpClass/tearDownClass overrides are not wrapped - confirm
    # whether subclassing concrete Legion test cases needs to be supported.
    if cls._initialized:
      return
    cls._OriginalSetUpClassMethod = cls.setUpClass
    cls.setUpClass = cls._HandleSetUpClass
    cls._OriginalTearDownClassMethod = cls.tearDownClass
    cls.tearDownClass = cls._HandleTearDownClass
    cls._initialized = True

  @classmethod
  def _LogInfoBanner(cls, method_name, method_doc=None):
    """Formats and logs test case information.

    Args:
      method_name: Name to show centered in the banner.
      method_doc: Optional description; each of its lines is logged centered
          below the name.
    """
    logging.info('*' * BANNER_WIDTH)
    logging.info(method_name.center(BANNER_WIDTH))
    if method_doc:
      for line in method_doc.split('\n'):
        logging.info(line.center(BANNER_WIDTH))
    logging.info('*' * BANNER_WIDTH)

  @classmethod
  def CreateTask(cls, *args, **kwargs):
    """Convenience method to create a new task.

    Creates a TaskController pointed at this class's registration server port
    and registers the task's OnConnect callback under the task's otp.

    Args:
      *args: Positional arguments forwarded to TaskController.
      **kwargs: Keyword arguments forwarded to TaskController.

    Returns:
      The newly created TaskController.
    """
    task = task_controller.TaskController(
        reg_server_port=cls._registration_server.port, *args, **kwargs)
    cls._registration_server.RegisterTaskCallback(
        task.otp, task.OnConnect)
    return task

  @classmethod
  def _SetUpFramework(cls):
    """Perform the framework-specific setup operations."""
    # Setup the registration server. Its Shutdown is registered with
    # common_lib.OnShutdown before the server is started.
    cls._registration_server = (
        task_registration_server.TaskRegistrationServer())
    common_lib.OnShutdown += cls._registration_server.Shutdown
    cls._registration_server.Start()

    # Setup the event server, registering its shutdown the same way.
    cls.comm_server = comm_server.CommServer()
    common_lib.OnShutdown += cls.comm_server.shutdown
    cls.comm_server.start()

  @classmethod
  def _TearDownFramework(cls):
    """Perform the framework-specific teardown operations."""
    common_lib.Shutdown()

  @classmethod
  def _HandleSetUpClass(cls):
    """Performs common class-level setup operations.

    This method performs test-wide setup such as starting the registration
    server and then calls the original setUpClass method.

    Raises:
      Any exception raised during setup, after it has been logged and the
      framework has been torn down.
    """
    try:
      cls._LogInfoBanner('setUpClass', 'Performs class level setup.')
      cls._SetUpFramework()
      cls._OriginalSetUpClassMethod()
    except:  # pylint: disable=bare-except
      # Log the setup failure first so it is recorded even if the teardown
      # below raises as well.
      logging.exception('')
      # Make sure we tear down in case of any exceptions. setup_failed=True
      # skips the test case's own tearDownClass and only tears down the
      # framework.
      cls._HandleTearDownClass(setup_failed=True)
      raise

  @classmethod
  def _HandleTearDownClass(cls, setup_failed=False):
    """Performs common class-level tear down operations.

    This method calls the original tearDownClass then performs test-wide
    tear down such as stopping the registration server.

    Args:
      setup_failed: When True the original tearDownClass is skipped and only
          the framework tear down is performed.
    """
    cls._LogInfoBanner('tearDownClass', 'Performs class level tear down.')
    try:
      if not setup_failed:
        cls._OriginalTearDownClassMethod()
    finally:
      # Always tear down the framework, even if tearDownClass raised.
      cls._TearDownFramework()

def main():
  """Runs unittest's entry point with only the program name as argv.

  Everything after the first element of sys.argv is withheld from
  unittest.main.
  """
  program_only_argv = sys.argv[:1]
  unittest.main(argv=program_only_argv)
