{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# NFNet Inference with AMD MIGraphX\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Normalizer-Free ResNets (NFNets) are a family of residual convolutional networks that set a new state of the art at publication, with 86.5% top-1 accuracy on the ImageNet dataset. Their defining feature is the removal of batch normalization: instead, they use adaptive gradient clipping to keep training stable without BatchNorm. <br> Details of this network: https://arxiv.org/abs/2102.06171\n",
    "\n",
    "This notebook shows: <br>\n",
    "- How to optimize an NFNet ONNX model with AMD MIGraphX.\n",
    "- How to run inference on an AMD GPU with the optimized model.\n",
    "\n",
    "The NFNet used in this example is the smallest variant, F0, with 71.5M parameters (83.6% top-1 accuracy on ImageNet).\n",
    "\n",
    "Please make sure the MIGraphX Python API is installed by following the instructions on the MIGraphX GitHub page."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Requirements"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!apt-get update\n",
    "!apt-get install ffmpeg libsm6 libxext6  -y "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip3 install --upgrade pip\n",
    "!pip3 install -r requirements_nfnet.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import cv2\n",
    "import json\n",
    "from PIL import Image\n",
    "import time\n",
    "from os import path "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Importing AMD MIGraphX Python Module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import migraphx"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create NFNet ONNX file\n",
    "The cell below downloads a ready-made NFNet-F0 ONNX file that was exported from the PyTorch model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!wget -nc https://www.dropbox.com/s/u4ga8zyxtppfzxc/dm_nfnet_f0.onnx"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load ImageNet labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('../python_resnet50/imagenet_simple_labels.json') as json_data:\n",
    "    labels = json.load(json_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load ONNX model using MIGraphX"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parse the ONNX file and compile the program for the GPU target\n",
    "model = migraphx.parse_onnx(\"dm_nfnet_f0.onnx\")\n",
    "model.compile(migraphx.get_target(\"gpu\"))\n",
    "\n",
    "# Inspect the input names, input shapes, and output shapes\n",
    "print(model.get_parameter_names())\n",
    "print(model.get_parameter_shapes())\n",
    "print(model.get_output_shapes())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Functions for image processing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_nxn(image, n):\n",
    "    # Center-crop the longer side to get a square, then resize to n x n\n",
    "    height, width = image.shape[:2]\n",
    "    if height > width:\n",
    "        dif = height - width\n",
    "        bar = dif // 2\n",
    "        square = image[(bar + (dif % 2)):(height - bar),:]\n",
    "        return cv2.resize(square, (n, n))\n",
    "    elif width > height:\n",
    "        dif = width - height\n",
    "        bar = dif // 2\n",
    "        square = image[:,(bar + (dif % 2)):(width - bar)]\n",
    "        return cv2.resize(square, (n, n))\n",
    "    else:\n",
    "        return cv2.resize(image, (n, n))\n",
    "\n",
    "def preprocess(img_data):\n",
    "    # Per-channel ImageNet mean and standard deviation, in RGB order\n",
    "    mean_vec = np.array([0.485, 0.456, 0.406])\n",
    "    stddev_vec = np.array([0.229, 0.224, 0.225])\n",
    "    norm_img_data = np.zeros(img_data.shape, dtype='float32')\n",
    "    for i in range(img_data.shape[0]):\n",
    "        norm_img_data[i,:,:] = (img_data[i,:,:]/255 - mean_vec[i]) / stddev_vec[i]\n",
    "    return norm_img_data\n",
    "\n",
    "def input_process(frame, dim):\n",
    "    # Crop and resize original image\n",
    "    cropped = make_nxn(frame, dim)\n",
    "    # cv2.imread returns BGR; convert to RGB to match the mean/std order in preprocess()\n",
    "    rgb = cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB)\n",
    "    # Convert from HWC to CHW\n",
    "    chw = rgb.transpose(2,0,1)\n",
    "    # Apply normalization\n",
    "    pp = preprocess(chw)\n",
    "    # Add singleton dimension (CHW to NCHW)\n",
    "    data = np.expand_dims(pp.astype('float32'),0)\n",
    "    return data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Download example image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch example image: traffic light\n",
    "!wget -nc http://farm5.static.flickr.com/4072/4462811418_8bc2bd42ca_z_d.jpg -O traffic_light.jpg\n",
    "# Read the image\n",
    "im = cv2.imread('traffic_light.jpg')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preprocess the loaded image to match the model's input requirements\n",
    "data_input = input_process(im, 192)\n",
    "\n",
    "# Run the model; the first inference takes longer than subsequent ones\n",
    "start = time.time()\n",
    "results = model.run({'inputs':data_input})\n",
    "print(f\"Inference time: {1000*(time.time() - start):.2f}ms\")\n",
    "# Extract the index of the top prediction\n",
    "res_npa = np.array(results[0])\n",
    "print(f\"\\nResult: {labels[np.argmax(res_npa)]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the model again; this run should be faster than the first one\n",
    "start = time.time()\n",
    "results = model.run({'inputs':data_input})\n",
    "print(f\"Inference time: {1000*(time.time() - start):.2f}ms\")\n",
    "# Extract the index of the top prediction\n",
    "res_npa = np.array(results[0])\n",
    "print(f\"\\nResult: {labels[np.argmax(res_npa)]}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
