{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Exporting Frozen Graphs in TensorFlow 2\n",
    "To use a trained model as input to MIGraphX, the model must first be saved in a frozen graph format. In TensorFlow 1 this was accomplished by launching a graph in a `tf.Session` and then saving the session. TensorFlow 2, however, deprecates Sessions in favor of functions and the SavedModel format, so a different procedure is needed."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After importing the necessary libraries, the next step is to instantiate a model. For simplicity, this example uses a ResNet50 architecture with pre-trained ImageNet weights. The model may also be trained or fine-tuned before freezing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "# Eager execution is the default in TensorFlow 2, so no explicit call is needed.\n",
    "# On late TF 1.x releases, call tf.compat.v1.enable_eager_execution() here instead.\n",
    "from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "\n",
    "MODEL_NAME = \"resnet50\"\n",
    "model = tf.keras.applications.ResNet50(weights=\"imagenet\")\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SavedModel format\n",
    "The simplest way to save a model is through `tf.saved_model.save()`.\n",
    "\n",
    "This creates an equivalent TensorFlow program that can later be loaded for fine-tuning or inference, although it is not directly compatible with MIGraphX."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf.saved_model.save(model, \"./Saved_Models/{}\".format(MODEL_NAME))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Convert to ConcreteFunction\n",
    "To begin, we wrap the model in a `tf.function` and then obtain a ConcreteFunction for a fixed input signature, so that the graph is traced once and is not retraced on later calls."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "full_model = tf.function(lambda x: model(x))\n",
    "full_model = full_model.get_concrete_function(\n",
    "    x=tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Freeze ConcreteFunction and Serialize\n",
    "Since we are saving the graph for the purpose of inference, all variables can be converted to constants (i.e. \"frozen\").\n",
    "\n",
    "Next, we obtain the GraphDef representation of the frozen graph, which is the structure that gets serialized.\n",
    "\n",
    "Optionally, the operation names can be printed one by one, followed by the graph's inputs and outputs."
   ]
  },
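  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace every variable in the ConcreteFunction with a constant\n",
    "frozen_func = convert_variables_to_constants_v2(full_model)\n",
    "# Keep the GraphDef instead of discarding it, so it can be inspected\n",
    "graph_def = frozen_func.graph.as_graph_def()\n",
    "print(\"Number of nodes in frozen graph:\", len(graph_def.node))\n",
    "\n",
    "# These are graph operations, not Keras layers\n",
    "op_names = [op.name for op in frozen_func.graph.get_operations()]\n",
    "print(\"-\" * 50)\n",
    "print(\"Frozen model operations: \")\n",
    "for op_name in op_names:\n",
    "    print(op_name)\n",
    "\n",
    "print(\"-\" * 50)\n",
    "print(\"Frozen model inputs: \")\n",
    "print(frozen_func.inputs)\n",
    "print(\"Frozen model outputs: \")\n",
    "print(frozen_func.outputs)"
   ]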
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save Frozen Graph as Protobuf\n",
    "Finally, we write the frozen graph to disk as a binary protobuf. It will be stored as `./frozen_models/<MODEL_NAME>_frozen_graph.pb`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf.io.write_graph(graph_or_graph_def=frozen_func.graph,\n",
    "                  logdir=\"./frozen_models\",\n",
    "                  name=\"{}_frozen_graph.pb\".format(MODEL_NAME),\n",
    "                  as_text=False)"
   ]
  },
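  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before handing the file to MIGraphX, it can be worth confirming that the protobuf actually landed where expected. The cell below is a minimal sketch using only the standard library. `frozen_graph_path` is a hypothetical helper introduced here for illustration, and it assumes the same directory and naming pattern as the `tf.io.write_graph` call above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "def frozen_graph_path(model_name, logdir=\"./frozen_models\"):\n",
    "    # Hypothetical helper: mirrors the logdir/name arguments passed to tf.io.write_graph\n",
    "    return Path(logdir) / \"{}_frozen_graph.pb\".format(model_name)\n",
    "\n",
    "pb_path = frozen_graph_path(\"resnet50\")  # same value as MODEL_NAME\n",
    "if pb_path.is_file():\n",
    "    # Report the size in MiB as a rough plausibility check\n",
    "    print(\"{} ({:.1f} MiB)\".format(pb_path, pb_path.stat().st_size / 2**20))\n",
    "else:\n",
    "    print(\"{} not found; re-run the write_graph cell\".format(pb_path))"
   ]
  },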
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Assuming MIGraphX has already been built and installed on your system, the `migraphx-driver` executable can be used to verify that the frozen graph was exported correctly."
   ]
  },
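  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import subprocess\n",
    "\n",
    "# Adjust this path if MIGraphX is installed somewhere other than /opt/rocm\n",
    "driver = \"/opt/rocm/bin/migraphx-driver\"\n",
    "command = \"read\"\n",
    "model_path = \"./frozen_models/{}_frozen_graph.pb\".format(MODEL_NAME)\n",
    "process = subprocess.run([driver, command, model_path],\n",
    "                         stdout=subprocess.PIPE,\n",
    "                         stderr=subprocess.STDOUT,  # show driver errors alongside normal output\n",
    "                         universal_newlines=True)\n",
    "\n",
    "print(process.stdout)\n",
    "print(\"Driver exit code:\", process.returncode)"
   ]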
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tensorflow",
   "language": "python",
   "name": "tensorflow"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
