# Copyright 2016 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

'''Generates a test suite from NIST PKITS test descriptions.

The output is a set of type-parameterized tests, which are included by
pkits_unittest.h. See pkits_unittest.h for information on using the tests.
GoogleTest has a limit of 50 tests per type-parameterized test case, so the
tests are split up by section number (this also makes it easy to skip
sections that pertain to unimplemented features).

Usage:
  generate_tests.py <PKITS.pdf> <output.h>
'''

import os
import re
import subprocess
import sys
import tempfile


def sanitize_name(s):
  """Returns |s| with every space and hyphen removed.

  Used to turn human-readable PKITS names (e.g. "Valid Signatures Test1") into
  the identifiers emitted in the generated tests, and to normalize the
  certificate/CRL names listed in each certification path.

  str.replace is used instead of the Python 2-only str.translate(None, chars)
  form. That form raises TypeError on Python 3 str and on Python 2 unicode.
  """
  return s.replace(' ', '').replace('-', '')


def finalize_test_case(test_case_name, sanitized_test_names, output):
  """Writes the registration macro that closes out a test case.

  Emits WRAPPED_REGISTER_TYPED_TEST_CASE_P(<test_case_name>, <name>, ...),
  with each test name on its own indented line.

  Args:
    test_case_name: Name of the type-parameterized test case.
    sanitized_test_names: Names of the tests emitted for this case.
    output: Writable file-like object receiving the generated code.
  """
  pieces = ['\nWRAPPED_REGISTER_TYPED_TEST_CASE_P(%s' % test_case_name]
  pieces.extend(',\n    %s' % test for test in sanitized_test_names)
  pieces.append(');\n')
  output.write(''.join(pieces))


def generate_test(test_case_name, test_number, raw_test_name, certs, crls,
                  should_validate, output):
  """Writes one WRAPPED_TYPED_TEST_P test definition to |output|.

  Args:
    test_case_name: Name of the enclosing type-parameterized test case.
    test_number: Dotted PKITS test number, e.g. "4.1.1".
    raw_test_name: Human-readable test name, e.g. "Valid Signatures Test1".
    certs: Certificate names in the path, emitted in the given order.
    crls: CRL names in the path, emitted in the given order.
    should_validate: True if the path is expected to verify successfully.
    output: Writable file-like object receiving the generated code.

  Returns:
    The generated test name: "Section" + the second component of
    |test_number| + the sanitized |raw_test_name|, e.g.
    "Section1ValidSignaturesTest1" for test "4.1.1".
  """
  section_component = test_number.split('.')[1]
  generated_name = 'Section%s%s' % (section_component,
                                    sanitize_name(raw_test_name))

  def as_c_string_list(names):
    # Renders |names| as a comma-separated list of C string literals.
    return ', '.join(['"%s"' % name for name in names])

  if should_validate:
    assertion = 'ASSERT_TRUE'
  else:
    assertion = 'ASSERT_FALSE'

  template = '''
// %(test_number)s %(raw_test_name)s
WRAPPED_TYPED_TEST_P(%(test_case_name)s, %(sanitized_test_name)s) {
  const char* const certs[] = {
    %(certs_formatted)s
  };
  const char* const crls[] = {
    %(crls_formatted)s
  };
  %(assert_function)s(this->Verify(certs, crls));
}
'''
  # An explicit mapping (rather than vars()) keeps the template's inputs
  # visible and safe from local-variable renames.
  substitutions = {
      'test_number': test_number,
      'raw_test_name': raw_test_name,
      'test_case_name': test_case_name,
      'sanitized_test_name': generated_name,
      'certs_formatted': as_c_string_list(certs),
      'crls_formatted': as_c_string_list(crls),
      'assert_function': assertion,
  }
  output.write(template % substitutions)

  return generated_name


# Matches a section header, ex: "4.1 Signature Verification". The title group
# is non-greedy so that trailing whitespace is not captured.
SECTION_MATCHER = re.compile(r'^\s*(\d+\.\d+)\s+(.+?)\s*$')
# Matches a test header, ex: "4.1.1 Valid Signatures Test1". Every dot in the
# test number is escaped so only literal dots separate its components.
TEST_MATCHER = re.compile(r'^\s*(\d+\.\d+\.\d+)\s+(.+?)\s*$')
# Matches an expected test result. Note that some results in the PDF have a
# typo: "path not should validate" instead of "path should not validate".
TEST_RESULT_MATCHER = re.compile(
    r'^\s*Expected Result:.*path (should validate|'
    r'should not validate|not should validate)')
# Matches the header line that introduces a test's certification path.
PATH_HEADER_MATCHER = re.compile(r'^\s*Certification Path:')
# Matches a line in the certification path, ex: "\u2022 Good CA Cert, Good CA CRL".
# The bullet is spelled as the byte sequence E2 80 A2 (UTF-8 for U+2022), so it
# only matches undecoded text, such as the bytes main() reads from the
# pdftotext output file.
PATH_MATCHER = re.compile(r'^\s*\xe2\x80\xa2\s*(.+)\s*$')
# Matches a page number. These may appear in the middle of multi-line fields and
# thus need to be ignored.
PAGE_NUMBER_MATCHER = re.compile(r'^\s*\d+\s*$')
# Matches if an entry in a certification path refers to a CRL, ex:
# "onlySomeReasons CA2 CRL1".
CRL_MATCHER = re.compile(r'^.*CRL\d*$')
def parse_test(lines, i, test_case_name, test_number, test_name, output):
  """Parses one PKITS test description and writes the corresponding test.

  Starting at lines[i], this scans forward for the "Expected Result" line and
  then the "Certification Path:" header. It then collects the bulleted path
  entries until the next test or section header. Each comma-separated entry is
  sanitized and classified as a CRL (if it matches CRL_MATCHER) or as a
  certificate.

  Args:
    lines: All lines of the extracted test descriptions.
    i: Index of the first line after the test header.
    test_case_name: Name of the enclosing type-parameterized test case.
    test_number: Dotted PKITS test number, e.g. "4.1.1".
    test_name: Human-readable test name from the header.
    output: Writable file-like object receiving the generated code.

  Returns:
    A (next_index, sanitized_test_name) tuple. next_index is the index of the
    next test/section header line, or len(lines) if input ran out.

  Raises:
    ValueError: If the description has text before the first bulleted path
      entry, or lacks an expected result, certificates, or CRLs.
  """
  description = '%s %s' % (test_number, test_name)

  # NOTE(review): the next two scans are not bounded by the following test
  # header, so a description missing either line would consume the next
  # test's lines.
  expected_result = None
  while i < len(lines):
    result_match = TEST_RESULT_MATCHER.match(lines[i])
    i += 1
    if result_match:
      expected_result = result_match.group(1) == 'should validate'
      break

  # Advance to just past the "Certification Path:" header.
  while i < len(lines):
    header_match = PATH_HEADER_MATCHER.match(lines[i])
    i += 1
    if header_match:
      break

  # Collect bulleted path entries up to (not including) the next header.
  # Non-bullet lines are continuations of the previous entry; blank lines and
  # page numbers are skipped.
  path_lines = []
  while i < len(lines):
    line = lines[i].strip()
    if TEST_MATCHER.match(line) or SECTION_MATCHER.match(line):
      break
    i += 1
    if not line or PAGE_NUMBER_MATCHER.match(line):
      continue
    bullet_match = PATH_MATCHER.match(line)
    if bullet_match:
      path_lines.append(bullet_match.group(1))
    elif path_lines:
      path_lines[-1] += ' ' + line
    else:
      raise ValueError('%s: unexpected text before the first certification '
                       'path entry: %r' % (description, line))

  certs = []
  crls = []
  for path_line in path_lines:
    for entry in path_line.split(','):
      entry = sanitize_name(entry.strip())
      if CRL_MATCHER.match(entry):
        crls.append(entry)
      else:
        certs.append(entry)

  # Explicit checks (rather than assert) so they survive `python -O` and say
  # which test was malformed.
  if expected_result is None:
    raise ValueError('%s: no expected result found' % description)
  if not certs:
    raise ValueError('%s: no certificates in certification path' % description)
  if not crls:
    raise ValueError('%s: no CRLs in certification path' % description)

  sanitized_test_name = generate_test(test_case_name, test_number, test_name,
                                      certs, crls, expected_result, output)

  return i, sanitized_test_name


def _extract_test_descriptions(pkits_pdf_path):
  """Returns the section-4 text of the PKITS PDF, which contains the tests.

  Converts the PDF to text with pdftotext. Keeps only the text after the last
  "4 Certification Path Validation Tests" heading and before the first
  "5 Relationship to Previous Test Suite" heading.
  """
  with tempfile.NamedTemporaryFile() as pkits_txt_file:
    subprocess.check_call(['pdftotext', '-layout', '-nopgbrk', '-eol', 'unix',
                           pkits_pdf_path, pkits_txt_file.name])
    test_descriptions = pkits_txt_file.read()

  # Splitting on the last occurrence skips earlier mentions of the heading
  # (presumably the table of contents).
  test_descriptions = test_descriptions.split(
      '4 Certification Path Validation Tests')[-1]
  return test_descriptions.split(
      '5 Relationship to Previous Test Suite', 1)[0]


def _write_header(output):
  """Writes the autogenerated-file banner and the WRAPPED_* helper macros."""
  output.write('// Autogenerated by %s, do not edit\n\n' % sys.argv[0])
  output.write('// Hack to allow disabling type parameterized test cases.\n'
               '// See https://github.com/google/googletest/issues/389\n')
  output.write('#define WRAPPED_TYPED_TEST_P(CaseName, TestName) '
               'TYPED_TEST_P(CaseName, TestName)\n')
  output.write('#define WRAPPED_REGISTER_TYPED_TEST_CASE_P(CaseName, ...) '
               'REGISTER_TYPED_TEST_CASE_P(CaseName, __VA_ARGS__)\n\n')


def main():
  """Generates the PKITS typed-test header.

  Usage: generate_tests.py <PKITS.pdf> <output.h>

  Each section 4.N becomes a test case named PkitsTest<NN><Title>, and each
  test within it becomes a WRAPPED_TYPED_TEST_P. Sections 4.8-4.12 are
  skipped, and the generated file records which tests were skipped.
  """
  if len(sys.argv) != 3:
    sys.exit('Usage: %s <PKITS.pdf> <output.h>' % sys.argv[0])
  pkits_pdf_path, output_path = sys.argv[1:]

  lines = _extract_test_descriptions(pkits_pdf_path).splitlines()

  # |with| guarantees the output is flushed and closed instead of relying on
  # interpreter shutdown.
  with open(output_path, 'w') as output:
    _write_header(output)

    test_case_name = None
    sanitized_test_names = []

    i = 0
    while i < len(lines):
      section_match = SECTION_MATCHER.match(lines[i])
      test_match = TEST_MATCHER.match(lines[i])
      i += 1

      if section_match:
        if test_case_name:
          finalize_test_case(test_case_name, sanitized_test_names, output)
          sanitized_test_names = []

        section_number = section_match.group(1)
        # TODO(mattm): Handle certificate policies tests.
        if section_number in ('4.8', '4.9', '4.10', '4.11', '4.12'):
          test_case_name = None
          output.write('\n// Skipping section %s\n' % section_number)
          continue

        test_case_name = 'PkitsTest%02d%s' % (
            int(section_number.split('.')[-1]),
            sanitize_name(section_match.group(2)))
        output.write('\ntemplate <typename PkitsTestDelegate>\n')
        output.write('class %s : public PkitsTest<PkitsTestDelegate> {};\n' %
                     test_case_name)
        output.write('TYPED_TEST_CASE_P(%s);\n' % test_case_name)

      if test_match:
        test_number = test_match.group(1)
        test_name = test_match.group(2)
        if not test_case_name:
          output.write('// Skipped %s %s\n' % (test_number, test_name))
          continue
        i, sanitized_test_name = parse_test(lines, i, test_case_name,
                                            test_number, test_name, output)
        if sanitized_test_name:
          sanitized_test_names.append(sanitized_test_name)

    if test_case_name:
      finalize_test_case(test_case_name, sanitized_test_names, output)


if __name__ == '__main__':
  # Run the generator only when executed as a script, not when imported.
  main()
