{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# BERT-SQuAD Inference Example with AMD MIGraphX"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This tutorial shows how to run the BERT-SQuAD ONNX model with AMD MIGraphX: the model is parsed from its ONNX file, compiled for the GPU, and executed through the MIGraphX Python API."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Requirements "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the %pip magic so the packages are installed into the environment of\n",
    "# the running kernel rather than whichever pip3 happens to be first on PATH.\n",
    "%pip install -r requirements_bertsquad.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import json\n",
    "import time\n",
    "import os.path\n",
    "from os import path\n",
    "import sys\n",
    "\n",
    "import tokenizers\n",
    "from run_onnx_squad import *\n",
    "\n",
    "import migraphx"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Download BERT ONNX file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# -nc (no-clobber) skips the download when bertsquad-10.onnx already exists.\n",
    "# NOTE(review): this URL points at the master branch of onnx/models; confirm\n",
    "# that the link still resolves, since the repository layout may have changed.\n",
    "!wget -nc https://github.com/onnx/models/raw/master/text/machine_comprehension/bert-squad/model/bertsquad-10.onnx"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Download the uncased BERT archive (contains the vocabulary file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# -y answers apt-get's confirmation prompt automatically; a notebook cell has\n",
    "# no interactive stdin, so a prompt would otherwise abort the installation.\n",
    "!apt-get install -y unzip\n",
    "# -nc / -n keep an already downloaded archive and already extracted files.\n",
    "!wget -q -nc https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip\n",
    "!unzip -n uncased_L-12_H-768_A-12.zip"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Input data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the questions/contexts used for this demo and show them.\n",
    "input_file = 'inputs.json'\n",
    "with open(input_file) as input_fh:\n",
    "    test_data = json.load(input_fh)\n",
    "\n",
    "# Pretty-print the parsed JSON so the reader can see what is being asked.\n",
    "print(json.dumps(test_data, indent=2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration for inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Passed to convert_examples_to_features when the model inputs are built.\n",
    "max_seq_length = 256\n",
    "doc_stride = 128\n",
    "max_query_length = 64\n",
    "# Number of features sliced off per model.run call in the inference loop.\n",
    "batch_size = 1\n",
    "# Passed to write_predictions when the answers are extracted.\n",
    "n_best_size = 20\n",
    "max_answer_length = 30"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Read the vocabulary file and create the tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# vocab.txt lives in the directory extracted from the downloaded archive.\n",
    "vocab_dir = 'uncased_L-12_H-768_A-12'\n",
    "vocab_file = os.path.join(vocab_dir, 'vocab.txt')\n",
    "\n",
    "# Build the WordPiece tokenizer from that vocabulary.\n",
    "tokenizer = tokenizers.BertWordPieceTokenizer(vocab_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Convert the examples into model input features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the examples from the JSON input with the run_onnx_squad helper.\n",
    "predict_file = 'inputs.json'\n",
    "eval_examples = read_squad_examples(input_file=predict_file)\n",
    "\n",
    "# The run_onnx_squad helper returns the three model inputs plus extra\n",
    "# per-feature data that is handed to write_predictions later on.\n",
    "features = convert_examples_to_features(eval_examples, tokenizer,\n",
    "                                        max_seq_length, doc_stride,\n",
    "                                        max_query_length)\n",
    "input_ids, input_mask, segment_ids, extra_data = features"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compile with MIGraphX for GPU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parse the downloaded ONNX file and compile it for the GPU target.\n",
    "model = migraphx.parse_onnx(\"bertsquad-10.onnx\")\n",
    "model.compile(migraphx.get_target(\"gpu\"))\n",
    "# model.print()  # uncomment to dump the compiled program\n",
    "\n",
    "# Only the last expression of a cell is displayed, so print the names\n",
    "# explicitly; otherwise their value would be silently discarded.\n",
    "print(model.get_parameter_names())\n",
    "model.get_parameter_shapes()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the input through the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n = len(input_ids)\n",
    "bs = batch_size\n",
    "all_results = []\n",
    "\n",
    "# Iterate over the features (rows of input_ids) in steps of bs so that\n",
    "# consecutive batches never overlap. Features and examples are separate\n",
    "# lists: one example can presumably be split into several features (doc_stride\n",
    "# sliding window -- TODO confirm in run_onnx_squad), so eval_examples must not\n",
    "# be indexed with the feature index.\n",
    "# NOTE(review): batch_size > 1 presumably also requires the model to be parsed\n",
    "# with a matching batch dimension -- confirm before raising it.\n",
    "for idx in range(0, n, bs):\n",
    "    batch_input_ids = input_ids[idx:idx + bs]\n",
    "    num_fed = len(batch_input_ids)\n",
    "\n",
    "    # Show the example each feature of this batch belongs to. Assumes the\n",
    "    # feature objects in extra_data expose example_index -- verify against\n",
    "    # run_onnx_squad.\n",
    "    for feature in extra_data[idx:idx + bs]:\n",
    "        print(eval_examples[feature.example_index])\n",
    "\n",
    "    result = model.run({\n",
    "        # Feed the running feature index as unique id: it matches the\n",
    "        # unique_id stored in RawResult below and, unlike qas_id, is always\n",
    "        # convertible to int64.\n",
    "        \"unique_ids_raw_output___9:0\":\n",
    "        np.arange(idx, idx + num_fed, dtype=np.int64),\n",
    "        \"input_ids:0\":\n",
    "        batch_input_ids,\n",
    "        \"input_mask:0\":\n",
    "        input_mask[idx:idx + bs],\n",
    "        \"segment_ids:0\":\n",
    "        segment_ids[idx:idx + bs]\n",
    "    })\n",
    "\n",
    "    # tolist() appears to return the logits of the whole batch as one flat\n",
    "    # list, so cut it into one equally sized slice per batch item instead of\n",
    "    # giving every item the same full list.\n",
    "    in_batch = result[1].get_shape().lens()[0]\n",
    "    start_flat = [float(x) for x in result[1].tolist()]\n",
    "    end_flat = [float(x) for x in result[0].tolist()]\n",
    "    start_len = len(start_flat) // in_batch\n",
    "    end_len = len(end_flat) // in_batch\n",
    "    for i in range(in_batch):\n",
    "        all_results.append(\n",
    "            RawResult(unique_id=len(all_results),\n",
    "                      start_logits=start_flat[i * start_len:(i + 1) * start_len],\n",
    "                      end_logits=end_flat[i * end_len:(i + 1) * end_len]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get the predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make sure the folder that receives the prediction files exists.\n",
    "output_dir = 'predictions'\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "output_prediction_file = os.path.join(output_dir, \"predictions.json\")\n",
    "output_nbest_file = os.path.join(output_dir, \"nbest_predictions.json\")\n",
    "\n",
    "# Hand the examples, the per-feature data and the raw model outputs to the\n",
    "# run_onnx_squad helper together with the two output paths.\n",
    "# The literal True is presumably the do_lower_case flag -- TODO confirm.\n",
    "write_predictions(\n",
    "    eval_examples, extra_data, all_results,\n",
    "    n_best_size, max_answer_length, True,\n",
    "    output_prediction_file, output_nbest_file)\n",
    "\n",
    "# Read the predictions file back and pretty-print it.\n",
    "with open(output_prediction_file) as prediction_fh:\n",
    "    test_data = json.load(prediction_fh)\n",
    "print(json.dumps(test_data, indent=2))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
