{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Super Resolution Inference with AMD MIGraphX\n",
    "This notebook is inspired by: https://github.com/onnx/models/blob/master/vision/super_resolution/sub_pixel_cnn_2016/dependencies/Run_Super_Resolution_Model.ipynb"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Install Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Upgrade pip first: needed for the opencv-python installation.\n",
    "# %pip (rather than !pip3) installs into the environment of the running kernel.\n",
    "%pip install --upgrade pip\n",
    "%pip install -r requirements.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image, ImageDraw, ImageFont\n",
    "from resizeimage import resizeimage\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Download ONNX Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the super-resolution ONNX model; --no-clobber skips the download when the file already exists.\n",
    "!wget --no-clobber https://github.com/onnx/models/raw/master/vision/super_resolution/sub_pixel_cnn_2016/model/super-resolution-10.onnx"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import MIGraphX Python Module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import migraphx"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preprocessing Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the source image and report its original size.\n",
    "orig_img = Image.open(\"./cat.jpg\")\n",
    "print(orig_img.size)\n",
    "\n",
    "# Fit the image to 224x224 with resize_cover, then switch to YCbCr:\n",
    "# only the Y channel is fed to the model; Cb/Cr are reused during postprocessing.\n",
    "img = resizeimage.resize_cover(orig_img, [224, 224], validate=False)\n",
    "img_ycbcr = img.convert('YCbCr')\n",
    "img_y_0, img_cb, img_cr = img_ycbcr.split()\n",
    "img_ndarray = np.asarray(img_y_0)\n",
    "\n",
    "# Prepend two singleton axes, then convert to float32 and divide by 255.\n",
    "img_4 = img_ndarray[np.newaxis, np.newaxis, ...]\n",
    "img_5 = img_4.astype(np.float32) / 255.0\n",
    "img_5"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parse the ONNX file and compile it for the GPU target.\n",
    "model = migraphx.parse_onnx(\"super-resolution-10.onnx\")\n",
    "gpu_target = migraphx.get_target(\"gpu\")\n",
    "model.compile(gpu_target)\n",
    "# model.print()  # optional: print the program for inspection\n",
    "\n",
    "# Show the parameter names, parameter shapes and output shapes of the model.\n",
    "for describe in (model.get_parameter_names,\n",
    "                 model.get_parameter_shapes,\n",
    "                 model.get_output_shapes):\n",
    "    print(describe())\n",
    "\n",
    "# Run inference with the preprocessed tensor bound to the parameter named \"input\".\n",
    "result = model.run({\"input\": img_5})\n",
    "\n",
    "# Convert the first output to a NumPy array and drop its leading axis.\n",
    "data = np.array(result[0])[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Postprocessing Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scale the model output by 255, clamp to [0, 255], take the first plane\n",
    "# and build an 8-bit grayscale (mode 'L') image from it.\n",
    "y_plane = (data * 255.0).clip(0, 255)[0]\n",
    "img_out_y = Image.fromarray(np.uint8(y_plane), mode='L')\n",
    "\n",
    "# Post-processing follows the PyTorch implementation: resize the chroma\n",
    "# channels bicubically to the output size and merge them with the new Y channel.\n",
    "out_size = img_out_y.size\n",
    "chroma = [band.resize(out_size, Image.BICUBIC) for band in (img_cb, img_cr)]\n",
    "final_img = Image.merge(\"YCbCr\", [img_out_y] + chroma).convert(\"RGB\")\n",
    "\n",
    "final_img.save(\"output.jpg\")\n",
    "print(final_img.size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## PSNR Comparison Output vs Input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2\n",
    "\n",
    "# Read the original image and the saved model output from disk.\n",
    "imgIN = cv2.imread('cat.jpg')\n",
    "imgOUT = cv2.imread('output.jpg')\n",
    "imgIN = cv2.cvtColor(imgIN, cv2.COLOR_BGR2RGB) #BGR to RGB\n",
    "imgOUT = cv2.cvtColor(imgOUT, cv2.COLOR_BGR2RGB)\n",
    "\n",
    "# cv2.PSNR needs both images to have identical dimensions, so resize the input to\n",
    "# whatever size the model produced instead of hard-coding 672x672.\n",
    "# cv2.resize takes dsize as (width, height) while .shape is (height, width, channels).\n",
    "# NOTE(review): preprocessing fits the image with resize_cover, whereas this resizes the\n",
    "# whole original; for a non-square cat.jpg the two images may not be aligned - confirm.\n",
    "out_height, out_width = imgOUT.shape[:2]\n",
    "imgIN_resized = cv2.resize(imgIN, (out_width, out_height))\n",
    "\n",
    "psnr = cv2.PSNR(imgIN_resized, imgOUT)\n",
    "print(\"PSNR Value = %.3f db\"%psnr)\n",
    "\n",
    "# Show the super-resolved output next to the original input.\n",
    "fig = plt.figure(figsize=(16, 16))\n",
    "sp1 = fig.add_subplot(1, 2, 1)\n",
    "sp1.title.set_text('Output Super Resolution Image (%sx%s)'%(imgOUT.shape[0], imgOUT.shape[1]))\n",
    "plt.imshow(imgOUT)\n",
    "\n",
    "sp2 = fig.add_subplot(1, 2, 2)\n",
    "sp2.title.set_text('Input Image (%sx%s)'%(imgIN.shape[0], imgIN.shape[1]))\n",
    "plt.imshow(imgIN)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
