{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Object Detection with YOLOv4\n",
    "This notebook demonstrates how to use MIGraphX to perform object detection. The model used below is a pre-trained YOLOv4 from the ONNX Model Zoo."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Download dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os.path\n",
    "\n",
    "# Base URL of the YOLOv4 directory that the class names, anchors and model are fetched from\n",
    "zoo_url = \"https://github.com/onnx/models/raw/master/vision/object_detection_segmentation/yolov4\"\n",
    "\n",
    "# Download each file into ./utilities/ unless a copy is already there\n",
    "for remote_path in (\"dependencies/coco.names\", \"dependencies/yolov4_anchors.txt\", \"model/yolov4.onnx\"):\n",
    "    local_path = os.path.join(\"./utilities\", os.path.basename(remote_path))\n",
    "    if not os.path.exists(local_path):\n",
    "        !wget {zoo_url}/{remote_path} -P ./utilities/\n",
    "\n",
    "if not os.path.exists(\"./utilities/input.jpg\"):\n",
    "    # The image used is from the COCO dataset (https://cocodataset.org/#explore)\n",
    "    # Other images can be tested by replacing the link below\n",
    "    image_link = \"https://farm3.staticflickr.com/2009/2306189268_88cc86b30f_z.jpg\"\n",
    "    !wget -O ./utilities/input.jpg $image_link"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Serialize model using MIGraphX Driver\n",
    "Please refer to the [MIGraphX Driver example](../../migraphx/migraphx_driver) if you would like more information about this tool."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compile the downloaded ONNX model with migraphx-driver and write the result to .mxr files.\n",
    "# Each compile is skipped when its output file already exists in the working directory.\n",
    "# This cell relies on `os` having been imported by the download cell above.\n",
    "# Half-precision build.\n",
    "# NOTE(review): verify `--fp16ref` against `migraphx-driver compile --help` for the installed ROCm release.\n",
    "if not os.path.exists(\"yolov4_fp16.mxr\"):\n",
    "    !/opt/rocm/bin/migraphx-driver compile ./utilities/yolov4.onnx --gpu --enable-offload-copy --fp16ref --binary -o yolov4_fp16.mxr\n",
    "# Build without the reduced-precision flag.\n",
    "if not os.path.exists(\"yolov4.mxr\"):\n",
    "    !/opt/rocm/bin/migraphx-driver compile ./utilities/yolov4.onnx --gpu --enable-offload-copy --binary -o yolov4.mxr"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Import libraries \n",
    "Please refer to [this section](https://github.com/ROCmSoftwarePlatform/AMDMIGraphX#using-migraphx-python-module) of the main README if the migraphx module is not found. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import migraphx\n",
    "import cv2\n",
    "import time\n",
    "import numpy as np\n",
    "import image_processing as ip\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Read and pre-process image data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_size = 416\n",
    "image_path = \"./utilities/input.jpg\"\n",
    "\n",
    "# cv2.imread returns None (it does not raise) when the file is missing or cannot be decoded,\n",
    "# e.g. an empty file left behind by a failed download, so fail here with a clear message.\n",
    "original_image = cv2.imread(image_path)\n",
    "if original_image is None:\n",
    "    raise FileNotFoundError(f\"Could not read image '{image_path}'; delete it and re-run the download cell.\")\n",
    "\n",
    "# OpenCV loads colour images in BGR channel order; convert to RGB\n",
    "original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)\n",
    "# First two dimensions of the image array (height, width)\n",
    "original_image_size = original_image.shape[:2]\n",
    "\n",
    "# Pre-process a copy so original_image is not passed to image_preprocess directly,\n",
    "# then add a leading axis of length 1 and cast to float32\n",
    "image_data = ip.image_preprocess(np.copy(original_image), [input_size, input_size])\n",
    "image_data = image_data[np.newaxis, ...].astype(np.float32)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load and run model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load serialized model (either single- or half-precision)\n",
    "model = migraphx.load(\"yolov4.mxr\", format=\"msgpack\")\n",
    "#model = migraphx.load(\"yolov4_fp16.mxr\", format=\"msgpack\")\n",
    "\n",
    "# Get the name of the input parameter and convert image data to an MIGraphX argument\n",
    "input_name = next(iter(model.get_parameter_shapes()))\n",
    "input_argument = migraphx.argument(image_data)\n",
    "\n",
    "# Evaluate the model\n",
    "outputs = model.run({input_name: input_argument})\n",
    "\n",
    "# Convert each output to a float NumPy array with the output's own dimensions for post-processing.\n",
    "# np.array(...).reshape(...) checks that the element count matches the shape, whereas the low-level\n",
    "# np.ndarray(buffer=...) constructor only reinterprets raw memory.\n",
    "detections = [\n",
    "    np.array(out.tolist(), dtype=float).reshape(out.get_shape().lens())\n",
    "    for out in outputs\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Post-process the model outputs and display image with detection bounding boxes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ANCHORS_PATH = \"./utilities/yolov4_anchors.txt\"\n",
    "XYSCALE = [1.2, 1.1, 1.05]\n",
    "\n",
    "# Keep the file path and the loaded anchors under separate names so one does not overwrite the other\n",
    "ANCHORS = ip.get_anchors(ANCHORS_PATH)\n",
    "STRIDES = np.array([8, 16, 32])\n",
    "\n",
    "# Post-process the raw detections, then draw the remaining boxes on the original image\n",
    "pred_bbox = ip.postprocess_bbbox(detections, ANCHORS, STRIDES, XYSCALE)\n",
    "bboxes = ip.postprocess_boxes(pred_bbox, original_image_size, input_size, 0.25)\n",
    "bboxes = ip.nms(bboxes, 0.213, method='nms')\n",
    "annotated_image = ip.draw_bbox(original_image, bboxes)\n",
    "\n",
    "# Leave the PIL image as the cell's last expression so Jupyter renders it inline;\n",
    "# Image.show() launches an external viewer instead, which is unavailable on headless or remote kernels.\n",
    "image = Image.fromarray(annotated_image)\n",
    "image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3.8.3 64-bit ('base': conda)"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  },
  "metadata": {
   "interpreter": {
    "hash": "d7283edef085bb46d38a3069bce96b3de1793019cb5bd7b1e86bf9785b67f304"
   }
  },
  "interpreter": {
   "hash": "d7283edef085bb46d38a3069bce96b3de1793019cb5bd7b1e86bf9785b67f304"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
