# Modifications Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Inference for squad/bert using onnx.

This does the same as 'python run_squad.py --do_predict=True ...' using a squad/bert model
that was converted to onnx. Lots of code was taken from run_squad.py.
You run it with:


python onnx_squad.py --model $SQUAD_MODEL/squad.onnx \
                     --vocab_file $BERT_BASE_DIR/uncased_L-12_H-768_A-12/vocab.txt \
                     --predict_file $SQUAD_DATA/dev-v1.1.json \
                     --bert_config_file $BERT_BASE_DIR/uncased_L-12_H-768_A-12/bert_config.json \
                     --output_dir /tmp/
"""

import argparse
import collections
import json
import math
import os
import sys
from timeit import default_timer as timer

import numpy as np
import onnxruntime as onnxrt
import six
from tokenizers import BertWordPieceTokenizer
from tokenizers import pre_tokenizers

# Model output for one feature: the feature's id and its start/end logits.
RawResult = collections.namedtuple(
    "RawResult", "unique_id start_logits end_logits")

# Per-feature bookkeeping kept alongside the model inputs so that predicted
# token positions can later be mapped back to the originating example.
Feature = collections.namedtuple(
    "Feature",
    "unique_id tokens example_index token_to_orig_map token_is_max_context")


class SquadExample(object):
    """A single SQuAD question together with its whitespace-split paragraph.

    The constructor stores its arguments unchanged as attributes of the same
    name. `orig_answer_text`, `start_position` and `end_position` are
    optional and default to None.
    """
    def __init__(self,
                 qas_id,
                 question_text,
                 doc_tokens,
                 orig_answer_text=None,
                 start_position=None,
                 end_position=None):
        self.qas_id = qas_id
        self.question_text = question_text
        self.doc_tokens = doc_tokens
        self.orig_answer_text = orig_answer_text
        self.start_position = start_position
        self.end_position = end_position

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        """Return a comma-separated summary of the populated fields."""
        parts = [
            "qas_id: %s" % (self.qas_id,),
            "question_text: %s" % (self.question_text,),
            "doc_tokens: [%s]" % (" ".join(self.doc_tokens)),
        ]
        # Compare against None explicitly: position 0 is a valid value, and
        # each field must be gated on its own attribute.
        if self.start_position is not None:
            parts.append("start_position: %d" % (self.start_position))
        if self.end_position is not None:
            parts.append("end_position: %d" % (self.end_position))
        return ", ".join(parts)


def _check_is_max_context(doc_spans, cur_span_index, position):
    """Tell whether span `cur_span_index` is the 'max context' span for a token.

    With sliding windows a token can fall inside several spans, e.g.
      Doc:    the man went to the store and bought a gallon of milk
      Span A: the man went to the
      Span B: to the store and bought
      Span C: and bought a gallon of
    'bought' is scored in both B and C. Only the span giving it the most
    context should count, where context is the *minimum* of the number of
    tokens to its left and to its right inside the span. For 'bought' that is
    span C (1 left, 3 right) rather than span B (4 left, 0 right).

    Returns True when `cur_span_index` equals the index of the best span
    containing `position`.
    """
    winner = None
    winner_score = None
    for idx, span in enumerate(doc_spans):
        last = span.start + span.length - 1
        # Spans that do not cover the token cannot be its best context.
        if not span.start <= position <= last:
            continue
        left = position - span.start
        right = last - position
        # The small length bonus favours longer spans among close candidates.
        candidate = min(left, right) + 0.01 * span.length
        if winner_score is None or candidate > winner_score:
            winner_score, winner = candidate, idx

    return cur_span_index == winner


def convert_examples_to_features(examples, tokenizer, max_seq_length,
                                 doc_stride, max_query_length):
    """Convert SQuAD examples into padded model inputs plus feature metadata.

    Every feature is laid out as
    ``[CLS] question [SEP] document-window [SEP]`` and zero-padded to
    `max_seq_length`. Documents that do not fit are covered by several
    overlapping windows, each producing its own feature.

    Args:
        examples: iterable of objects exposing `question_text` and
            `doc_tokens` (a list of whitespace-separated words).
        tokenizer: object providing `encode(text, add_special_tokens=...)`
            whose result has a `tokens` list, and `token_to_id(token)`.
        max_seq_length: length every input row is padded to.
        doc_stride: how far the document window advances between features.
        max_query_length: maximum number of question tokens kept.

    Returns:
        A tuple `(input_ids, input_mask, segment_ids, extra)` where the first
        three are numpy arrays built from one int64 row per feature and
        `extra` is the list of matching `Feature` records.

    Raises:
        ValueError: if the question leaves no room for any document token,
            which would otherwise make the windowing loop run forever.
    """
    doc_span_type = collections.namedtuple("DocSpan", ["start", "length"])

    res_input_ids = []
    res_input_mask = []
    res_segment_ids = []
    extra = []
    unique_id = 0

    for example_index, example in enumerate(examples):
        # Encode without special tokens: [CLS]/[SEP] are inserted explicitly
        # below, so letting the tokenizer add them would duplicate them. Work
        # on the plain token list so that truncation is a simple slice.
        query_tokens = tokenizer.encode(example.question_text,
                                        add_special_tokens=False).tokens
        query_tokens = query_tokens[:max_query_length]

        # Split every word into sub-tokens, remembering the word each
        # sub-token came from.
        tok_to_orig_index = []
        all_doc_tokens = []
        for word_index, word in enumerate(example.doc_tokens):
            pieces = tokenizer.encode(word, add_special_tokens=False).tokens
            tok_to_orig_index.extend([word_index] * len(pieces))
            all_doc_tokens.extend(pieces)

        # The -3 accounts for [CLS], [SEP] and [SEP]
        max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
        if all_doc_tokens and max_tokens_for_doc <= 0:
            raise ValueError(
                "max_seq_length={} leaves no room for document tokens after "
                "a {}-token question".format(max_seq_length,
                                             len(query_tokens)))

        # We can have documents that are longer than the maximum sequence
        # length. To deal with this we use a sliding window approach, where
        # we take chunks of up to our max length with a stride of
        # `doc_stride`.
        doc_spans = []
        start_offset = 0
        while start_offset < len(all_doc_tokens):
            length = min(len(all_doc_tokens) - start_offset,
                         max_tokens_for_doc)
            doc_spans.append(doc_span_type(start=start_offset, length=length))
            if start_offset + length == len(all_doc_tokens):
                break
            start_offset += min(length, doc_stride)

        for doc_span_index, doc_span in enumerate(doc_spans):
            # Question part: segment id 0.
            tokens = ["[CLS]"] + list(query_tokens) + ["[SEP]"]
            segment_ids = [0] * len(tokens)
            token_to_orig_map = {}
            token_is_max_context = {}

            # Document window: segment id 1.
            for i in range(doc_span.length):
                split_token_index = doc_span.start + i
                position = len(tokens)
                token_to_orig_map[position] = \
                    tok_to_orig_index[split_token_index]
                token_is_max_context[position] = _check_is_max_context(
                    doc_spans, doc_span_index, split_token_index)
                tokens.append(all_doc_tokens[split_token_index])
                segment_ids.append(1)
            tokens.append("[SEP]")
            segment_ids.append(1)

            input_ids = [tokenizer.token_to_id(token) for token in tokens]

            # The mask has 1 for real tokens and 0 for padding tokens. Only
            # real tokens are attended to.
            input_mask = [1] * len(input_ids)

            # Zero-pad up to the sequence length.
            padding = [0] * (max_seq_length - len(input_ids))
            input_ids.extend(padding)
            input_mask.extend(padding)
            segment_ids.extend(padding)

            res_input_ids.append(np.array(input_ids, dtype=np.int64))
            res_input_mask.append(np.array(input_mask, dtype=np.int64))
            res_segment_ids.append(np.array(segment_ids, dtype=np.int64))
            extra.append(
                Feature(unique_id=unique_id,
                        tokens=tokens,
                        example_index=example_index,
                        token_to_orig_map=token_to_orig_map,
                        token_is_max_context=token_is_max_context))
            unique_id += 1
    return np.array(res_input_ids), np.array(res_input_mask), np.array(
        res_segment_ids), extra


def read_squad_examples(input_file):
    """Read a SQuAD json file into a list of SquadExample.

    Each paragraph's context is split into words on space, tab, carriage
    return, newline and U+202F. One example is created per question; all
    questions of a paragraph share the same `doc_tokens` list. No answer
    information is read, so the answer fields of every example are None.

    Args:
        input_file: path of a JSON file with a top-level "data" list.

    Returns:
        A list of SquadExample, in file order.
    """
    # Read as UTF-8 explicitly so the result does not depend on the
    # platform's locale encoding.
    with open(input_file, "r", encoding="utf-8") as f:
        input_data = json.load(f)["data"]

    def is_whitespace(c):
        return c in (" ", "\t", "\r", "\n") or ord(c) == 0x202F

    examples = []
    for entry in input_data:
        for paragraph in entry["paragraphs"]:
            doc_tokens = []
            prev_is_whitespace = True
            for c in paragraph["context"]:
                if is_whitespace(c):
                    prev_is_whitespace = True
                    continue
                # A non-whitespace character either starts a new word or
                # extends the current one.
                if prev_is_whitespace:
                    doc_tokens.append(c)
                else:
                    doc_tokens[-1] += c
                prev_is_whitespace = False

            for qa in paragraph["qas"]:
                examples.append(
                    SquadExample(qas_id=qa["id"],
                                 question_text=qa["question"],
                                 doc_tokens=doc_tokens,
                                 orig_answer_text=None,
                                 start_position=None,
                                 end_position=None))
    return examples


def write_predictions(all_examples, all_features, all_results, n_best_size,
                      max_answer_length, do_lower_case, output_prediction_file,
                      output_nbest_file):
    """Rank answer spans per example and dump them as two JSON files.

    For every example the `n_best_size` highest start and end logits of each
    of its features are combined into candidate spans, invalid spans are
    dropped, and the rest are ranked by `start_logit + end_logit`. The best
    distinct answer texts are converted to probabilities with a softmax.

    `output_prediction_file` receives a mapping of question id to the top
    answer text; `output_nbest_file` receives a mapping of question id to the
    full ranked list. `do_lower_case` is forwarded to `get_final_text`.
    """
    # pylint: disable=invalid-name
    prelim_type = collections.namedtuple("PrelimPrediction", [
        "feature_index", "start_index", "end_index", "start_logit",
        "end_logit"
    ])
    nbest_type = collections.namedtuple(
        "NbestPrediction", ["text", "start_logit", "end_logit"])

    features_by_example = collections.defaultdict(list)
    for feat in all_features:
        features_by_example[feat.example_index].append(feat)

    results_by_id = {res.unique_id: res for res in all_results}

    predictions = collections.OrderedDict()
    nbest_by_question = collections.OrderedDict()

    for example_index, example in enumerate(all_examples):
        features = features_by_example[example_index]

        # Stage 1: collect every admissible (start, end) pair.
        candidates = []
        for feature_index, feature in enumerate(features):
            if feature.unique_id not in results_by_id:
                print("feature not in unique_Id", feature.unique_id)
                continue
            result = results_by_id[feature.unique_id]
            n_tokens = len(feature.tokens)
            orig_map = feature.token_to_orig_map

            start_candidates = _get_best_indexes(result.start_logits,
                                                 n_best_size)
            end_candidates = _get_best_indexes(result.end_logits, n_best_size)

            for start_index in start_candidates:
                # The model may point outside the document window (e.g. into
                # the question); such starts are discarded, as are starts for
                # which another window offers more context.
                start_ok = (start_index < n_tokens
                            and start_index in orig_map
                            and feature.token_is_max_context.get(
                                start_index, False))
                if not start_ok:
                    continue
                for end_index in end_candidates:
                    if end_index >= n_tokens or end_index not in orig_map:
                        continue
                    if end_index < start_index:
                        continue
                    if end_index - start_index + 1 > max_answer_length:
                        continue
                    candidates.append(
                        prelim_type(
                            feature_index=feature_index,
                            start_index=start_index,
                            end_index=end_index,
                            start_logit=result.start_logits[start_index],
                            end_logit=result.end_logits[end_index]))

        candidates.sort(key=lambda p: p.start_logit + p.end_logit,
                        reverse=True)

        # Stage 2: turn the best candidates into distinct answer strings.
        seen_texts = set()
        nbest = []
        for pred in candidates:
            if len(nbest) >= n_best_size:
                break
            feature = features[pred.feature_index]

            span_tokens = feature.tokens[pred.start_index:pred.end_index + 1]
            first_word = feature.token_to_orig_map[pred.start_index]
            last_word = feature.token_to_orig_map[pred.end_index]
            orig_text = " ".join(example.doc_tokens[first_word:last_word + 1])

            # Glue WordPieces back together, then normalise whitespace.
            tok_text = " ".join(span_tokens)
            tok_text = tok_text.replace(" ##", "").replace("##", "")
            tok_text = " ".join(tok_text.strip().split())

            final_text = get_final_text(tok_text, orig_text, do_lower_case)
            if final_text in seen_texts:
                continue
            seen_texts.add(final_text)
            nbest.append(
                nbest_type(text=final_text,
                           start_logit=pred.start_logit,
                           end_logit=pred.end_logit))

        # With no valid candidate at all, emit a placeholder so that the
        # output still has an entry for this question.
        if not nbest:
            nbest.append(
                nbest_type(text="empty", start_logit=0.0, end_logit=0.0))

        assert len(nbest) >= 1

        # Stage 3: attach probabilities and build the JSON records.
        probs = _compute_softmax(
            [entry.start_logit + entry.end_logit for entry in nbest])

        nbest_json = []
        for prob, entry in zip(probs, nbest):
            nbest_json.append(
                collections.OrderedDict([
                    ("text", entry.text),
                    ("probability", prob),
                    ("start_logit", float(entry.start_logit)),
                    ("end_logit", float(entry.end_logit)),
                ]))

        predictions[example.qas_id] = nbest_json[0]["text"]
        nbest_by_question[example.qas_id] = nbest_json

    for path, payload in ((output_prediction_file, predictions),
                          (output_nbest_file, nbest_by_question)):
        with open(path, "w") as writer:
            writer.write(json.dumps(payload, indent=4) + "\n")


def get_final_text(pred_text, orig_text, do_lower_case):
    """Project the tokenized prediction back to the original text.

    Args:
        pred_text: the predicted answer rebuilt from WordPiece tokens.
        orig_text: the span of the original document covering the prediction.
        do_lower_case: when true, `orig_text` is lower-cased and stripped of
            accents before `pred_text` is searched in it, so that text
            predicted by an uncased tokenizer can be located.

    Returns:
        The part of `orig_text` aligned with `pred_text`, or `orig_text`
        unchanged whenever the alignment heuristic fails.
    """
    # Imported here because this function is the only user.
    import unicodedata

    # When we created the data, we kept track of the alignment between
    # original (whitespace tokenized) tokens and our WordPiece tokenized
    # tokens. So now `orig_text` contains the span of our original text
    # corresponding to the span that we predicted.
    #
    # However, `orig_text` may contain extra characters that we don't want in
    # our prediction.
    #
    # For example, let's say:
    #   pred_text = steve smith
    #   orig_text = Steve Smith's
    #
    # We don't want to return `orig_text` because it contains the extra "'s".
    #
    # We don't want to return `pred_text` because it's already been normalized
    # (the SQuAD eval script also does punctuation stripping/lower casing but
    # our tokenizer does additional normalization like stripping accent
    # characters).
    #
    # What we really want to return is "Steve Smith".
    #
    # Therefore, we have to apply a semi-complicated alignment heuristic
    # between `pred_text` and `orig_text` to get a character-to-character
    # alignment. This can fail in certain cases in which case we just return
    # `orig_text`.

    def _strip_spaces(text):
        """Drop spaces; also map each kept char's new index to its old one."""
        ns_chars = []
        ns_to_s_map = collections.OrderedDict()
        for (i, c) in enumerate(text):
            if c == " ":
                continue
            ns_to_s_map[len(ns_chars)] = i
            ns_chars.append(c)
        return ("".join(ns_chars), ns_to_s_map)

    # We first tokenize `orig_text`, strip whitespace from the result
    # and `orig_text` itself, and check if they are the same length. If they
    # are NOT the same length, the heuristic has failed. If they are the same
    # length, we assume the characters are one-to-one aligned.
    tokenizer = pre_tokenizers.Sequence(
        [pre_tokenizers.Whitespace(),
         pre_tokenizers.Punctuation()])
    tok_text = " ".join(
        item[0] for item in tokenizer.pre_tokenize_str(orig_text))

    if do_lower_case:
        # Apply the same normalization the prediction went through; without
        # it a lower-cased `pred_text` is never found in cased text. Should
        # this change the number of characters, the length check below falls
        # back to `orig_text`.
        tok_text = "".join(
            ch for ch in unicodedata.normalize("NFD", tok_text.lower())
            if unicodedata.category(ch) != "Mn")

    start_position = tok_text.find(pred_text)
    if start_position == -1:
        return orig_text
    end_position = start_position + len(pred_text) - 1

    (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
    (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)

    if len(orig_ns_text) != len(tok_ns_text):
        return orig_text

    # We then project the characters in `pred_text` back to `orig_text` using
    # the character-to-character alignment.
    tok_s_to_ns_map = {
        tok_index: ns_index
        for ns_index, tok_index in tok_ns_to_s_map.items()
    }

    orig_start_position = None
    if start_position in tok_s_to_ns_map:
        orig_start_position = orig_ns_to_s_map.get(
            tok_s_to_ns_map[start_position])
    if orig_start_position is None:
        return orig_text

    orig_end_position = None
    if end_position in tok_s_to_ns_map:
        orig_end_position = orig_ns_to_s_map.get(
            tok_s_to_ns_map[end_position])
    if orig_end_position is None:
        return orig_text

    return orig_text[orig_start_position:(orig_end_position + 1)]


def _get_best_indexes(logits, n_best_size):
    """Return the positions of the `n_best_size` largest logits, best first."""
    ranked = sorted(enumerate(logits),
                    key=lambda pair: pair[1],
                    reverse=True)
    # Clamp at zero so a non-positive size yields an empty list.
    return [index for index, _ in ranked[:max(n_best_size, 0)]]


def _compute_softmax(scores):
    """Turn raw logits into softmax probabilities.

    An empty input yields an empty list. The largest score is subtracted
    before exponentiating so that `math.exp` cannot overflow.
    """
    if not scores:
        return []

    peak = max(scores)
    weights = [math.exp(score - peak) for score in scores]

    # Accumulate left to right in plain float arithmetic.
    normaliser = 0.0
    for weight in weights:
        normaliser += weight

    return [weight / normaliser for weight in weights]


def main():
    """Run SQuAD inference with an ONNX model from the command line.

    Parses the arguments, converts the examples in `--predict_file` into
    features, feeds them to an ONNX Runtime session in slices of
    `--batch_size`, and, when `--output_dir` is given, writes
    `predictions.json` and `nbest_predictions.json` into that directory.

    Returns:
        0, which the script entry point uses as the process exit status.
    """
    parser = argparse.ArgumentParser(description='onnx squad')
    parser.add_argument('--model', required=True, help='model')
    parser.add_argument('--vocab_file', required=True, help='vocab_file')
    # Accepted so existing command lines keep working; not read below.
    parser.add_argument('--bert_config_file', help='bert_config_file')
    parser.add_argument('--predict_file', required=True, help='predict_file')
    parser.add_argument('--output_dir', help='output dir')
    parser.add_argument('--max_seq_length',
                        type=int,
                        default=256,
                        help='max_seq_length')
    parser.add_argument('--max_query_length',
                        type=int,
                        default=64,
                        help='max_query_length')
    parser.add_argument('--max_answer_length',
                        type=int,
                        default=30,
                        help='max_answer_length')
    parser.add_argument('--n_best_size',
                        type=int,
                        default=20,
                        help='n_best_size')
    parser.add_argument('--doc_stride',
                        type=int,
                        default=128,
                        help='doc_stride')
    parser.add_argument('--batch_size', type=int, default=1, help='batch_size')
    parser.add_argument('--profile',
                        action='store_true',
                        help='enable chrome timeline trace profiling.')
    parser.add_argument('--log', type=int, help='log level.')
    args = parser.parse_args()

    # Build one options object shared by both switches, so that passing
    # --log together with --profile does not discard the profiling setup.
    sess_options = None
    if args.profile or args.log is not None:
        sess_options = onnxrt.SessionOptions()
    if args.profile:
        sess_options.enable_profiling = True
        sess_options.profile_file_prefix = os.path.basename(args.model)
    if args.log is not None:
        sess_options.session_log_verbosity_level = args.log

    tokenizer = BertWordPieceTokenizer(args.vocab_file)

    eval_examples = read_squad_examples(input_file=args.predict_file)
    input_ids, input_mask, segment_ids, extra_data = \
        convert_examples_to_features(eval_examples, tokenizer, args.max_seq_length,
                                     args.doc_stride, args.max_query_length)

    sess = onnxrt.InferenceSession(args.model, sess_options)
    for input_meta in sess.get_inputs():
        print(input_meta)
    n = len(input_ids)
    bs = args.batch_size
    all_results = []
    start = timer()
    for idx in range(0, n, bs):
        data = {
            "input_ids:0": input_ids[idx:idx + bs],
            "input_mask:0": input_mask[idx:idx + bs],
            "segment_ids:0": segment_ids[idx:idx + bs]
        }
        result = sess.run(["unstack:0", "unstack:1"], data)
        # NOTE(review): this indexing treats axis 1 of each output as the
        # batch axis and reads only index 0 of axis 0 -- confirm against the
        # exported model's output shapes.
        in_batch = result[0].shape[1]
        for i in range(0, in_batch):
            # Results are numbered in arrival order, which is how they are
            # later matched to the features' unique ids.
            unique_id = len(all_results)
            all_results.append(
                RawResult(unique_id=unique_id,
                          start_logits=result[0][0][i],
                          end_logits=result[1][0][i]))
            if unique_id > 0 and unique_id % 100 == 0:
                print("at {} {}sec per item".format(
                    unique_id, (timer() - start) / unique_id))
    end = timer()

    # Guard the average: an input without any feature yields no results.
    elapsed = end - start
    per_item = elapsed / len(all_results) if all_results else 0.0
    print("total time: {}sec, {}sec per item".format(elapsed, per_item))

    if args.output_dir:
        output_prediction_file = os.path.join(args.output_dir,
                                              "predictions.json")
        output_nbest_file = os.path.join(args.output_dir,
                                         "nbest_predictions.json")
        write_predictions(eval_examples, extra_data, all_results,
                          args.n_best_size, args.max_answer_length, True,
                          output_prediction_file, output_nbest_file)
    if args.profile:
        trace_file = sess.end_profiling()
        print("trace file written to: {}".format(trace_file))
    return 0


# Script entry point: hand main()'s return value to the interpreter as the
# process exit status.
if __name__ == "__main__":
    exit_status = main()
    sys.exit(exit_status)
