# Copyright (c) 2018 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

import os, sys, struct
import datetime, pytz
from nnir import *

# Lookup from NNIR tensor type codes (the ``tensor.type`` strings) to the
# OpenVX data-type enum names that are emitted verbatim into generated C++.
tensor_type_nnir2openvx = dict(
    F032='VX_TYPE_FLOAT32',
    F016='VX_TYPE_FLOAT16',
    U016='VX_TYPE_UINT16',
    I016='VX_TYPE_INT16',
    U008='VX_TYPE_UINT8',
    I064='VX_TYPE_INT64',
)

def generateLicenseForCPP(f):
    """Write the MIT license banner plus a generation timestamp as a C/C++ block comment.

    Args:
        f: writable text stream; the whole banner is written with one ``write`` call.
    """
    # Timestamp is rendered in the America/Los_Angeles time zone, ISO-8601 format.
    stamp = datetime.datetime.now(tz=pytz.timezone('America/Los_Angeles')).isoformat()
    banner = """/*
MIT License

Copyright (c) 2018 Advanced Micro Devices, Inc. All rights reserved.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
*/

/* This file is generated by nnir2openvx.py on %s */
"""
    f.write(banner % (stamp,))

def generateLicenseForScript(f):
    """Write the MIT license banner plus a generation timestamp as ``#`` comments.

    Used for CMake/script outputs where ``#`` starts a comment.

    Args:
        f: writable text stream; the whole banner is written with one ``write`` call.
    """
    # Timestamp is rendered in the America/Los_Angeles time zone, ISO-8601 format.
    stamp = datetime.datetime.now(tz=pytz.timezone('America/Los_Angeles')).isoformat()
    banner = """################################################################################
#
# MIT License
#
# Copyright (c) 2018 Advanced Micro Devices, Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
################################################################################

# This file is generated by nnir2openvx.py on %s
"""
    f.write(banner % (stamp,))

def generateCMakeFiles(graph,outputFolder):
    """Generate the CMake build files for the exported OpenVX module.

    Writes ``<outputFolder>/CMakeLists.txt`` (builds the ``annmodule`` shared
    library, the ``anntest`` executable and the ``annpython`` shared library)
    and ``<outputFolder>/cmake/FindOpenCL.cmake``. The ``cmake`` sub-folder is
    created when it does not exist yet. Existing files are overwritten.

    Args:
        graph: unused; kept so all generator functions share a signature.
        outputFolder: existing directory that receives the generated files.
    """
    fileName = outputFolder + '/CMakeLists.txt'
    print('creating ' + fileName + ' ...')
    with open(fileName, 'w') as f:
        generateLicenseForScript(f)
        # CMAKE_CXX_FLAGS is directory-scoped, so it is set exactly once; the
        # flags then apply to every target in this CMakeLists (a second
        # append would only duplicate them on each compile line).
        f.write( \
"""
cmake_minimum_required (VERSION 2.8)
project (annmodule)
set (CMAKE_CXX_STANDARD 11)
list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake)
find_package(OpenCL REQUIRED)
find_package(OpenCV QUIET)
include_directories (${OpenCL_INCLUDE_DIRS} ${OpenCL_INCLUDE_DIRS}/Headers )
include_directories (/opt/rocm/mivisionx/include)
link_directories    (/opt/rocm/mivisionx/lib)
list(APPEND SOURCES annmodule.cpp)
add_library(${PROJECT_NAME} SHARED ${SOURCES})
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse4.2 -mf16c -std=c++11")
target_link_libraries(${PROJECT_NAME} openvx vx_nn pthread)

add_executable(anntest anntest.cpp)
if (OpenCV_FOUND)
  target_compile_definitions(anntest PUBLIC ENABLE_OPENCV=1)
  include_directories(${OpenCV_INCLUDE_DIRS})
  target_link_libraries(anntest ${OpenCV_LIBRARIES})
else(OpenCV_FOUND)
  target_compile_definitions(anntest PUBLIC ENABLE_OPENCV=0)
endif(OpenCV_FOUND)
target_link_libraries(anntest openvx vx_nn pthread ${PROJECT_NAME})

add_library(annpython SHARED annpython.cpp)
target_link_libraries(annpython ${PROJECT_NAME} openvx vx_nn pthread)
""")
    # The CMakeLists above adds ${PROJECT_SOURCE_DIR}/cmake to the module path,
    # so the custom FindOpenCL module must live there.
    if not os.path.isdir(outputFolder + '/cmake'):
        os.mkdir(outputFolder + '/cmake')
    fileName = outputFolder + '/cmake/FindOpenCL.cmake'
    print('creating ' + fileName + ' ...')
    with open(fileName, 'w') as f:
        generateLicenseForScript(f)
        f.write( \
"""
find_path(OPENCL_INCLUDE_DIRS
    NAMES OpenCL/cl.h CL/cl.h
    HINTS
    ${OPENCL_ROOT}/include
    $ENV{AMDAPPSDKROOT}/include
    PATHS
    /usr/include
    /usr/local/include
    /opt/rocm/opencl/include
    DOC "OpenCL header file path"
    )
mark_as_advanced( OPENCL_INCLUDE_DIRS )

if("${CMAKE_SIZEOF_VOID_P}" EQUAL "8")
    find_library( OPENCL_LIBRARIES
        NAMES OpenCL
        HINTS
        ${OPENCL_ROOT}/lib
        $ENV{AMDAPPSDKROOT}/lib
        DOC "OpenCL dynamic library path"
        PATH_SUFFIXES x86_64 x64 x86_64/sdk
        PATHS
        /usr/lib
        /opt/rocm/opencl/lib
        )
else( )
    find_library( OPENCL_LIBRARIES
        NAMES OpenCL
        HINTS
        ${OPENCL_ROOT}/lib
        $ENV{AMDAPPSDKROOT}/lib
        DOC "OpenCL dynamic library path"
        PATH_SUFFIXES x86 Win32

        PATHS
        /usr/lib
        )
endif( )
mark_as_advanced( OPENCL_LIBRARIES )

include( FindPackageHandleStandardArgs )
find_package_handle_standard_args( OPENCL DEFAULT_MSG OPENCL_LIBRARIES OPENCL_INCLUDE_DIRS )

set(OpenCL_FOUND ${OPENCL_FOUND} CACHE INTERNAL "")
set(OpenCL_LIBRARIES ${OPENCL_LIBRARIES} CACHE INTERNAL "")
set(OpenCL_INCLUDE_DIRS ${OPENCL_INCLUDE_DIRS} CACHE INTERNAL "")

if( NOT OPENCL_FOUND )
    message( STATUS "FindOpenCL looked for libraries named: OpenCL" )
endif()
""")

def generateModuleH(graph,fileName):
    """Generate ``annmodule.h`` declaring the ``annAddToGraph`` entry point.

    The header documents the dims of every graph input and output (listed in
    reverse of the NNIR shape order) and declares
    ``annAddToGraph(vx_graph graph, <inputs...>, <outputs...>, const char * binaryFilename)``.

    Args:
        graph: NNIR graph; reads ``inputs`` and ``outputs`` (``name``/``shape``).
        fileName: path of the header to create (overwritten).
    """
    print('creating ' + fileName + ' ...')
    # Build the parameter list from parts so that an empty input or output
    # group cannot leave a dangling ", ," in the generated prototype.
    tensorParams = ['vx_tensor ' + tensor.name for tensor in list(graph.inputs) + list(graph.outputs)]
    signature = ', '.join(['vx_graph graph'] + tensorParams + ['const char * binaryFilename'])
    with open(fileName, 'w') as f:
        generateLicenseForCPP(f)
        f.write( \
"""
#ifndef included_file_annmodule_h
#define included_file_annmodule_h

#include <VX/vx.h>

////
// initialize graph neural network for inference
""")
        for tensor in graph.inputs:
            f.write( \
"""//   %s -- dims[] = { %s } (input)
""" % (tensor.name, ', '.join([str(v) for v in reversed(tensor.shape)])))
        for tensor in graph.outputs:
            f.write( \
"""//   %s -- dims[] = { %s } (output)
""" % (tensor.name, ', '.join([str(v) for v in reversed(tensor.shape)])))
        f.write( \
"""//
extern "C" VX_API_ENTRY vx_status VX_API_CALL annAddToGraph(%s);

#endif
""" % (signature,))

def generateModuleCPP(graph,fileName):
    """Generate ``annmodule.cpp`` implementing ``annAddToGraph`` for *graph*.

    The generated C++ function:
      1. loads the ``vx_nn`` kernels into the graph's context;
      2. creates one tensor per initializer and fills it from the binary
         weights file (file magic 0xf00dd1e0, per-tensor header magic
         0xf00dd1e1 + byte size, end-of-file magic 0xf00dd1e2);
      3. creates a virtual tensor for each distinct graph local that is not
         a graph output;
      4. emits one OpenVX node per NNIR node;
      5. releases the local and initializer tensors.

    Args:
        graph: NNIR graph; reads ``inputs``, ``outputs``, ``initializers``,
            ``locals``, ``nodes``, ``tensor_shapes`` and ``tensor_dict``.
        fileName: path of the C++ file to create (overwritten).

    Raises:
        ValueError: a node type, input arity, gemm configuration or
            permutation cannot be expressed with the OpenVX calls emitted here.
        KeyError: a tensor type is missing from ``tensor_type_nnir2openvx``.
    """
    print('creating ' + fileName + ' ...')
    # Parameter list of the generated annAddToGraph(); built from parts so an
    # empty input/output group cannot produce a dangling ", ,". It has the
    # same layout as the prototype emitted for annmodule.h.
    tensorParams = ['vx_tensor ' + tensor.name for tensor in list(graph.inputs) + list(graph.outputs)]
    signature = ', '.join(['vx_graph graph'] + tensorParams + ['const char * binaryFilename'])
    # Locals that need a virtual tensor: the first occurrence of each name,
    # skipping graph outputs (those tensors are passed in by the caller).
    # Computed once and reused for both the create and the release loops.
    seenNames = set(tensor.name for tensor in graph.outputs)
    virtualTensors = []
    for tensor in graph.locals:
        if tensor.name not in seenNames:
            seenNames.add(tensor.name)
            virtualTensors.append(tensor)
    # 4-D permutations supported by vxPermuteLayer (order 0 is a plain copy).
    permuteOrders = {(0, 2, 3, 1): 1, (0, 3, 1, 2): 2}
    with open(fileName, 'w') as f:
        generateLicenseForCPP(f)
        # NOTE: this template is %-formatted, so literal '%' must be '%%'.
        f.write( \
"""
#include "annmodule.h"
#include <VX/vx_khr_nn.h>
#include <vx_amd_nn.h>
#include <vx_ext_amd.h>
#include <stdio.h>

#define ERROR_CHECK_OBJECT(obj) { vx_status status = vxGetStatus((vx_reference)(obj)); if(status != VX_SUCCESS) { vxAddLogEntry((vx_reference)context, status     , "ERROR: failed with status = (%%d) at " __FILE__ "#%%d\\n", status, __LINE__); return status; } }
#define ERROR_CHECK_STATUS(call) { vx_status status = (call); if(status != VX_SUCCESS) { vxAddLogEntry((vx_reference)context, status, "ERROR: failed with status = (%%d) at " __FILE__ "#%%d\\n", status, __LINE__); return status; } }

static vx_status initializeTensor(vx_context context, vx_tensor tensor, FILE * fp, const char * binaryFilename)
{
    vx_enum data_type = VX_TYPE_FLOAT32;
    vx_size num_of_dims = 4, dims[4] = { 1, 1, 1, 1 }, stride[4];
    ERROR_CHECK_STATUS(vxQueryTensor(tensor, VX_TENSOR_DATA_TYPE, &data_type, sizeof(vx_enum)));
    ERROR_CHECK_STATUS(vxQueryTensor(tensor, VX_TENSOR_NUMBER_OF_DIMS, &num_of_dims, sizeof(vx_size)));
    if(num_of_dims > 4) {
        vxAddLogEntry((vx_reference)tensor, VX_ERROR_NOT_SUPPORTED, "ERROR: initializer tensors with more than 4 dimensions are not supported (got %%ld) in %%s\\n", (long)num_of_dims, binaryFilename);
        return VX_ERROR_NOT_SUPPORTED;
    }
    ERROR_CHECK_STATUS(vxQueryTensor(tensor, VX_TENSOR_DIMS, &dims, num_of_dims * sizeof(vx_size)));
    vx_size itemsize = sizeof(float);
    if(data_type == VX_TYPE_UINT8 || data_type == VX_TYPE_INT8) {
        itemsize = sizeof(vx_uint8);
    }
    else if(data_type == VX_TYPE_UINT16 || data_type == VX_TYPE_INT16 || data_type == VX_TYPE_FLOAT16) {
        itemsize = sizeof(vx_uint16);
    }
    else if(data_type == VX_TYPE_INT64) {
        itemsize = sizeof(vx_int64);
    }
    vx_size count = dims[0] * dims[1] * dims[2] * dims[3];

    vx_uint32 h[2] = { 0 };
    fread(h, 1, sizeof(h), fp);
    if(h[0] != 0xf00dd1e1 || (vx_size)h[1] != (count*itemsize)) {
      vxAddLogEntry((vx_reference)tensor, VX_FAILURE, "ERROR: invalid data (magic,size)=(0x%%x,%%u) in %%s at byte position %%ld -- expected size is %%ld\\n", h[0], h[1], binaryFilename, (long)ftell(fp) - (long)sizeof(h), (long)(count*itemsize));
      return VX_FAILURE;
    }

    vx_map_id map_id;
    void * ptr;
    ERROR_CHECK_STATUS(vxMapTensorPatch(tensor, num_of_dims, nullptr, nullptr, &map_id, stride, (void **)&ptr, VX_WRITE_ONLY, VX_MEMORY_TYPE_HOST, 0));
    vx_size n = fread(ptr, itemsize, count, fp);
    if(n != count) {
        vxAddLogEntry((vx_reference)tensor, VX_FAILURE, "ERROR: expected char[%%ld], but got char[%%ld] in %%s\\n", (long)(count*itemsize), (long)(n*itemsize), binaryFilename);
        vxUnmapTensorPatch(tensor, map_id);
        return VX_FAILURE;
    }
    ERROR_CHECK_STATUS(vxUnmapTensorPatch(tensor, map_id));

    return VX_SUCCESS;
}

VX_API_ENTRY vx_status VX_API_CALL annAddToGraph(%s)
{
    vx_context context = vxGetContext((vx_reference)graph);
    ERROR_CHECK_OBJECT(context);
    ERROR_CHECK_STATUS(vxLoadKernels(context, "vx_nn"));

    // create variables
""" % (signature,))
        for tensor in graph.initializers:
            f.write( \
"""    vx_size dims_%s[%d] = { %s };
    vx_tensor %s = vxCreateTensor(context, %d, dims_%s, %s, 0);
    ERROR_CHECK_OBJECT(%s);
""" %(tensor.name, len(tensor.shape), ', '.join([str(v) for v in reversed(tensor.shape)]), \
      tensor.name, len(tensor.shape), tensor.name, tensor_type_nnir2openvx[tensor.type], tensor.name))
        # NOTE: this template is NOT %-formatted; '%s' below is a C format spec.
        # Every early return closes fp__variables so the FILE is not leaked.
        f.write( \
"""
    // initialize variables
    FILE * fp__variables = fopen(binaryFilename, "rb");
    if(!fp__variables) {
        vxAddLogEntry((vx_reference)context, VX_FAILURE, "ERROR: unable to open: %s\\n", binaryFilename);
        return VX_FAILURE;
    }
    { vx_uint32 magic = 0;
      fread(&magic, 1, sizeof(magic), fp__variables);
      if(magic != 0xf00dd1e0) {
        vxAddLogEntry((vx_reference)context, VX_FAILURE, "ERROR: invalid file magic in %s\\n", binaryFilename);
        fclose(fp__variables);
        return VX_FAILURE;
      }
    }
""")
        for tensor in graph.initializers:
            # status__init avoids shadowing the 'status' local declared by ERROR_CHECK_STATUS.
            f.write( \
"""    { vx_status status__init = initializeTensor(context, %s, fp__variables, binaryFilename);
      if(status__init != VX_SUCCESS) fclose(fp__variables);
      ERROR_CHECK_STATUS(status__init);
    }
""" %(tensor.name))
        f.write( \
"""    { vx_uint32 magic = 0;
      fread(&magic, 1, sizeof(magic), fp__variables);
      fclose(fp__variables);
      if(magic != 0xf00dd1e2) {
        vxAddLogEntry((vx_reference)context, VX_FAILURE, "ERROR: invalid eoff magic in %s\\n", binaryFilename);
        return VX_FAILURE;
      }
    }

    // create local tensors used in graph
""")
        for tensor in virtualTensors:
            f.write( \
"""    vx_size dims_%s[%d] = { %s };
    vx_tensor %s = vxCreateVirtualTensor(graph, %d, dims_%s, %s, 0);
    ERROR_CHECK_OBJECT(%s);
""" %(tensor.name, len(tensor.shape), ', '.join([str(v) for v in reversed(tensor.shape)]), \
      tensor.name, len(tensor.shape), tensor.name, tensor_type_nnir2openvx[tensor.type], tensor.name))
        f.write( \
"""
    // create nodes in graph
""")
        for node in graph.nodes:
            if node.type == 'conv':
                pads = node.attr.get('pads')
                dilations = node.attr.get('dilations')
                # OpenVX dilation is "extra spacing", i.e. NNIR dilation - 1.
                f.write( \
"""
    { vx_nn_convolution_params_t conv_params = { 0 };
      conv_params.padding_x = %d;
      conv_params.padding_y = %d;
      conv_params.overflow_policy = VX_CONVERT_POLICY_SATURATE;
      conv_params.rounding_policy = VX_ROUND_POLICY_TO_NEAREST_EVEN;
      conv_params.down_scale_size_rounding = VX_NN_DS_SIZE_ROUNDING_FLOOR;
      conv_params.dilation_x = %d;
      conv_params.dilation_y = %d;
      vx_node node = vxConvolutionLayer(graph, %s, %s, %s, &conv_params, sizeof(conv_params), %s);
      ERROR_CHECK_OBJECT(node);
""" % (pads[0], pads[1], dilations[0] - 1, dilations[1] - 1, \
      node.inputs[0], node.inputs[1], node.inputs[2] if len(node.inputs) == 3 else 'NULL', node.outputs[0]))
                if (node.attr.get('mode') != 0):
                    f.write( \
"""      vx_float32 alpha = 0;
      vx_scalar s_alpha = vxCreateScalarWithSize(context, VX_TYPE_FLOAT32, &alpha, sizeof(alpha));
      ERROR_CHECK_STATUS(vxSetParameterByIndex(node, 5, (vx_reference) s_alpha));
      ERROR_CHECK_STATUS(vxReleaseScalar(&s_alpha));
""")
                if (node.attr.get('group') > 1):
                    group = node.attr.get('group');
                    f.write( \
"""      vx_int32 groupCount = %d;
      vx_scalar s_groupCount = vxCreateScalarWithSize(context, VX_TYPE_INT32, &groupCount, sizeof(groupCount));
      ERROR_CHECK_STATUS(vxSetParameterByIndex(node, 6, (vx_reference) s_groupCount));
      ERROR_CHECK_STATUS(vxReleaseScalar(&s_groupCount));
""" % (group))
                f.write( \
"""      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""")
            elif node.type == 'conv_transpose':
                pads = node.attr.get('pads')
                dilations = node.attr.get('dilations')
                kernel_shape = node.attr.get('kernel_shape')
                output_pads = [(dilations[0] - 1) * (kernel_shape[0] - 1), \
                                (dilations[1] - 1) * (kernel_shape[1] - 1)]
                f.write( \
"""
    { vx_nn_deconvolution_params_t conv_params = { 0 };
      conv_params.padding_x = %d;
      conv_params.padding_y = %d;
      conv_params.overflow_policy = VX_CONVERT_POLICY_SATURATE;
      conv_params.rounding_policy = VX_ROUND_POLICY_TO_NEAREST_EVEN;
      conv_params.a_x = %d;
      conv_params.a_y = %d;
      vx_node node = vxDeconvolutionLayer(graph, %s, %s, %s, &conv_params, sizeof(conv_params), %s);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""" % (pads[0], pads[1], output_pads[0] , output_pads[1] , \
      node.inputs[0], node.inputs[1], node.inputs[2] if len(node.inputs) == 3 else 'NULL', node.outputs[0]))
            elif node.type == 'gemm':
                alpha = node.attr.get('alpha')
                beta = node.attr.get('beta')
                transA = node.attr.get('transA')
                transB = node.attr.get('transB')
                # Only Y = A * B^T (+ bias) maps onto vxFullyConnectedLayer.
                hasBias = False
                if beta == 1.0 and len(node.inputs) == 3 and len(graph.tensor_shapes[node.inputs[2]]) <= 2:
                    hasBias = True
                if alpha == 1.0 and transA == 0 and transB == 1 and (beta == 0.0 or hasBias):
                    f.write( \
"""
    { vx_node node = vxFullyConnectedLayer(graph, %s, %s, %s, VX_CONVERT_POLICY_SATURATE, VX_ROUND_POLICY_TO_NEAREST_EVEN, %s);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""" % ( \
        node.inputs[0], node.inputs[1], node.inputs[2] if hasBias else 'NULL', node.outputs[0]))
                else:
                    raise ValueError("Unsupported gemm configuration by OpenVX: alpha={} beta={} transA={} transB={}".format(alpha, beta, transA, transB))
            elif node.type == 'max_pool' or node.type == 'avg_pool':
                f.write( \
"""
    { vx_node node = vxPoolingLayer(graph, %s, %s, %d, %d, %d, %d, VX_ROUND_POLICY_TO_NEAREST_EVEN, %s);
      ERROR_CHECK_OBJECT(node);
      vx_enum border_mode = %d;
      vx_scalar s_border_mode = vxCreateScalarWithSize(context, VX_TYPE_ENUM, &border_mode, sizeof(border_mode));
      ERROR_CHECK_OBJECT(s_border_mode);
      ERROR_CHECK_STATUS(vxSetParameterByIndex(node, 8, (vx_reference) s_border_mode));
      ERROR_CHECK_STATUS(vxReleaseScalar(&s_border_mode));
""" % (node.inputs[0], 'VX_NN_POOLING_AVG' if node.type == 'avg_pool' else 'VX_NN_POOLING_MAX', \
       node.attr.get('kernel_shape')[0], node.attr.get('kernel_shape')[1], \
       node.attr.get('pads')[0], node.attr.get('pads')[1], node.outputs[0], \
       (1 if node.attr.get('border_mode') == 'discard' else 0)))
                if (node.attr.get('mode') != 0):
                    f.write( \
"""      vx_int32 mode = %s;
      vx_scalar s_mode = vxCreateScalarWithSize(context, VX_TYPE_INT32, &mode, sizeof(mode));
      ERROR_CHECK_STATUS(vxSetParameterByIndex(node, 9, (vx_reference) s_mode));
      ERROR_CHECK_STATUS(vxReleaseScalar(&s_mode));
""" % (node.attr.get('mode')))
                f.write( \
"""      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""")
            elif node.type == 'global_avg_pool':
                # Kernel covers the full spatial extent: shape indices 2 and 3 of the input.
                f.write( \
"""
    { vx_node node = vxPoolingLayer(graph, %s, VX_NN_POOLING_AVG, %d, %d, %d, %d, VX_ROUND_POLICY_TO_NEAREST_EVEN, %s);
      ERROR_CHECK_OBJECT(node);
""" % (node.inputs[0], graph.tensor_shapes[node.inputs[0]][2], graph.tensor_shapes[node.inputs[0]][3], \
       node.attr.get('pads')[0], node.attr.get('pads')[1], node.outputs[0]))
                if (node.attr.get('mode') != 0):
                    f.write( \
"""      vx_int32 mode = %s;
      vx_scalar s_mode = vxCreateScalarWithSize(context, VX_TYPE_INT32, &mode, sizeof(mode));
      ERROR_CHECK_STATUS(vxSetParameterByIndex(node, 9, (vx_reference) s_mode));
      ERROR_CHECK_STATUS(vxReleaseScalar(&s_mode));
""" % (node.attr.get('mode')))
                f.write( \
"""      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""")
            elif node.type == 'relu':
                f.write( \
"""
    { vx_node node = vxActivationLayer(graph, %s, VX_NN_ACTIVATION_RELU, 0.0f, 0.0f, %s);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""" % (node.inputs[0], node.outputs[0]))
            elif node.type == 'leaky_relu':
                f.write( \
"""
    {  vx_node node = vxActivationLayer(graph, %s, VX_NN_ACTIVATION_LEAKY_RELU, %f, 0.0f, %s);
       ERROR_CHECK_OBJECT(node);
       ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""" % (node.inputs[0], node.attr.get('alpha'), node.outputs[0]))
            elif node.type == 'add' or node.type == 'sum':
                if len(node.inputs) == 2:
                    f.write( \
"""
    { vx_node node = vxTensorAddNode(graph, %s, %s, VX_CONVERT_POLICY_SATURATE, %s);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""" % (node.inputs[0], node.inputs[1], node.outputs[0]))
                else:
                    raise ValueError("Unsupported number of input arguments by OpenVX: {}".format(node.type))
            elif node.type == 'sub':
                if len(node.inputs) == 2:
                    f.write( \
"""
    { vx_node node = vxTensorSubtractNode(graph, %s, %s, VX_CONVERT_POLICY_SATURATE, %s);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""" % (node.inputs[0], node.inputs[1], node.outputs[0]))
                else:
                    raise ValueError("Unsupported number of input arguments by OpenVX: {}".format(node.type))
            elif node.type == 'mul':
                if len(node.inputs) == 2:
                    f.write( \
"""
    { vx_float32 value = 1.0f;
      vx_scalar scale = vxCreateScalar(context, VX_TYPE_FLOAT32, &value);
      ERROR_CHECK_OBJECT(scale);
      vx_node node = vxTensorMultiplyNode(graph, %s, %s, scale, VX_CONVERT_POLICY_SATURATE, VX_ROUND_POLICY_TO_NEAREST_EVEN, %s);
      ERROR_CHECK_STATUS(vxReleaseScalar(&scale));
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""" % (node.inputs[0], node.inputs[1], node.outputs[0]))
                else:
                    raise ValueError("Unsupported number of input arguments by OpenVX: {}".format(node.type))
            elif node.type == 'muladd':
                # Emitted as (in0 * in1) into a temporary virtual tensor shaped
                # like in0, then tmp + in2; the temporary is released once both
                # nodes hold their own references to it.
                tensor = graph.tensor_dict[node.inputs[0]]
                f.write( \
"""
    { vx_float32 value = 1.0f;
      vx_scalar scale = vxCreateScalar(context, VX_TYPE_FLOAT32, &value);
      ERROR_CHECK_OBJECT(scale);
      vx_size dims[%d] = { %s };
      vx_tensor tmp__tensor = vxCreateVirtualTensor(graph, %d, dims, %s, 0);
      ERROR_CHECK_OBJECT(tmp__tensor);
      vx_node node = vxTensorMultiplyNode(graph, %s, %s, scale, VX_CONVERT_POLICY_SATURATE, VX_ROUND_POLICY_TO_NEAREST_EVEN, tmp__tensor);
      ERROR_CHECK_STATUS(vxReleaseScalar(&scale));
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
      node = vxTensorAddNode(graph, tmp__tensor, %s, VX_CONVERT_POLICY_SATURATE, %s);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
      ERROR_CHECK_STATUS(vxReleaseTensor(&tmp__tensor));
    }
""" % (len(tensor.shape), ', '.join([str(v) for v in reversed(tensor.shape)]), len(tensor.shape), \
       tensor_type_nnir2openvx[tensor.type], node.inputs[0], node.inputs[1], node.inputs[2], node.outputs[0]))
            elif node.type == 'batch_norm':
                # NNIR input order is (x, scale, bias, mean, var); the OpenVX call takes mean/var first.
                f.write( \
"""
    { vx_node node = vxBatchNormalizationLayer(graph, %s, %s, %s, %s, %s, %ef, %s);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""" % (node.inputs[0], node.inputs[3], node.inputs[4], node.inputs[1], node.inputs[2], node.attr.get('epsilon'), node.outputs[0]))
            elif node.type == 'lrn':
                # Validate the node before attaching the optional bias parameter.
                f.write( \
"""
    { vx_node node = vxNormalizationLayer(graph, %s, %s , %d, %ef, %ef, %s);
      ERROR_CHECK_OBJECT(node);
""" % (node.inputs[0], "VX_NN_NORMALIZATION_SAME_MAP" if node.attr.get('mode') == 0 else "VX_NN_NORMALIZATION_ACROSS_MAPS" , \
       node.attr.get('size'), node.attr.get('alpha'), node.attr.get('beta'), node.outputs[0]))
                if (node.attr.get('bias') != 1.0):
                    f.write( \
"""      vx_float32 bias = %s;
      vx_scalar s_bias = vxCreateScalarWithSize(context, VX_TYPE_FLOAT32, &bias, sizeof(bias));
      ERROR_CHECK_STATUS(vxSetParameterByIndex(node, 6, (vx_reference) s_bias));
      ERROR_CHECK_STATUS(vxReleaseScalar(&s_bias));
""" % (node.attr.get('bias')))
                f.write( \
"""      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""")
            elif node.type == 'slice':
                # vxSliceLayer takes 8 output slots; unused ones are NULL.
                f.write( \
"""
    { vx_node node = vxSliceLayer(graph, %s, %s, %s);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""" % (node.inputs[0], ', '.join([name for name in node.outputs]), \
       ', '.join(['NULL' for i in range(8 - len(node.outputs))])))
            elif node.type == 'concat':
                # vxConcatLayer takes 8 input slots; unused ones are NULL.
                f.write( \
"""
    { vx_node node = vxConcatLayer(graph, %s, %s, %s);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""" % (node.outputs[0], ', '.join([name for name in node.inputs]), \
       ', '.join(['NULL' for i in range(8 - len(node.inputs))])))
            elif node.type == 'softmax':
                f.write( \
"""
    { vx_node node = vxSoftmaxLayer(graph, %s, %s);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""" % (node.inputs[0], node.outputs[0]))
            elif node.type == 'reshape' or node.type == 'flatten':
                f.write( \
"""
    { vx_node node = vxReshapeLayer(graph, %s, %s);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""" % (node.inputs[0], node.outputs[0]))
            elif node.type in ('copy', 'transpose', 'permute'):
                if node.type == 'copy':
                    order = 0
                else:
                    # transpose stores the permutation in 'axes', permute in 'order'.
                    perm = node.attr.get('axes' if node.type == 'transpose' else 'order')
                    order = permuteOrders.get(tuple(perm)) if perm is not None else None
                    if order is None:
                        # Previously an unknown permutation left 'order' unbound or
                        # silently reused the value from an earlier node.
                        raise ValueError("Unsupported {} permutation by OpenVX: {}".format(node.type, perm))
                f.write( \
"""
    { vx_node node = vxPermuteLayer(graph, %s, %d, %s);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""" % (node.inputs[0], order, node.outputs[0]))
            elif node.type == 'prior_box':
                aspect_ratio = node.attr.get('aspect_ratio')
                aspect_ratio_str = ','.join(str(e) for e in aspect_ratio)
                variance = node.attr.get('variance')
                variance_str = ','.join(str(e) for e in variance)
                # The scalars are released right after node creation (the node
                # keeps its own references), matching the 'mul' branch.
                f.write( \
"""
    {
      vx_float32 min_size = %f;
      vx_float32 max_size = %f;
      vx_int32 flip = %d;
      vx_int32 clip = %d;
      vx_float32 offset = %f;

      vx_scalar s_min_size = vxCreateScalarWithSize(context, VX_TYPE_FLOAT32, &min_size, sizeof(min_size));
      vx_scalar s_max_size = vxCreateScalarWithSize(context, VX_TYPE_FLOAT32, &max_size, sizeof(max_size));
      vx_scalar s_flip = vxCreateScalarWithSize(context, VX_TYPE_INT32, &flip, sizeof(flip));
      vx_scalar s_clip = vxCreateScalarWithSize(context, VX_TYPE_INT32, &clip, sizeof(clip));
      vx_scalar s_offset = vxCreateScalarWithSize(context, VX_TYPE_FLOAT32, &offset, sizeof(offset));
      vx_node node = vxPriorBoxLayer(graph, %s, %s, s_min_size, %s , s_flip, s_clip, s_offset, %s , s_max_size, %s );
      ERROR_CHECK_STATUS(vxReleaseScalar(&s_min_size));
      ERROR_CHECK_STATUS(vxReleaseScalar(&s_max_size));
      ERROR_CHECK_STATUS(vxReleaseScalar(&s_flip));
      ERROR_CHECK_STATUS(vxReleaseScalar(&s_clip));
      ERROR_CHECK_STATUS(vxReleaseScalar(&s_offset));
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""" % (node.attr.get('min_size'), node.attr.get('max_size'), node.attr.get('flip'), node.attr.get('clip'), \
       node.attr.get('prior_offset'), node.inputs[0], node.inputs[1], aspect_ratio_str, node.outputs[0], variance_str))
            elif node.type == 'upsample':
                f.write( \
"""
    { vx_node node = vxUpsampleNearestLayer(graph, %s, %s);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""" % (node.inputs[0], node.outputs[0]))
            elif node.type == 'crop':
                offset = node.attr.get('offset')
                # The scalars are released right after node creation (the node
                # keeps its own references), matching the 'mul' branch.
                f.write( \
"""
    {
      vx_int32 axis = %d;
      vx_int32 offset1 = %d;
      vx_int32 offset2 = %d;
      vx_int32 offset3 = %d;
      vx_int32 offset4 = %d;
      vx_scalar s_axis = vxCreateScalarWithSize(context, VX_TYPE_INT32, &axis, sizeof(axis));
      vx_scalar s_offset1 = vxCreateScalarWithSize(context, VX_TYPE_INT32, &offset1, sizeof(offset1));
      vx_scalar s_offset2 = vxCreateScalarWithSize(context, VX_TYPE_INT32, &offset2, sizeof(offset2));
      vx_scalar s_offset3 = vxCreateScalarWithSize(context, VX_TYPE_INT32, &offset3, sizeof(offset3));
      vx_scalar s_offset4 = vxCreateScalarWithSize(context, VX_TYPE_INT32, &offset4, sizeof(offset4));
      vx_node node = vxCropLayer(graph, %s, %s, %s, s_axis, s_offset1, s_offset2, s_offset3, s_offset4);
      ERROR_CHECK_STATUS(vxReleaseScalar(&s_axis));
      ERROR_CHECK_STATUS(vxReleaseScalar(&s_offset1));
      ERROR_CHECK_STATUS(vxReleaseScalar(&s_offset2));
      ERROR_CHECK_STATUS(vxReleaseScalar(&s_offset3));
      ERROR_CHECK_STATUS(vxReleaseScalar(&s_offset4));
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""" % (node.attr.get('axis'), offset[0], offset[1], offset[2], offset[3], node.inputs[0], node.inputs[1], node.outputs[0]))
            elif node.type == 'crop_and_resize':
                f.write( \
"""
    {
      vx_node node = vxCropAndResizeLayer(graph, %s, %s, %d, %d, %d, %d, %d, %d);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""" % (node.inputs[0], node.outputs[0], node.attr.get('coord')[0], node.attr.get('coord')[1], node.attr.get('shape')[0], node.attr.get('shape')[1], node.attr.get('scale'), node.attr.get('mode')))
            else:
                raise ValueError("Unsupported node by OpenVX: {}".format(node.type))
        f.write( \
"""
    // release local tensors
""")
        for tensor in virtualTensors:
            f.write( \
"""    ERROR_CHECK_STATUS(vxReleaseTensor(&%s));
""" %(tensor.name))
        f.write( \
"""
    // release initializer tensors
""")
        for tensor in graph.initializers:
            f.write( \
"""    ERROR_CHECK_STATUS(vxReleaseTensor(&%s));
""" %(tensor.name))
        f.write( \
"""
    return VX_SUCCESS;
}
""")

def generatePythonH(graph, fileName):
    """Generate ``annpython.h``, the C interface consumed by the Python bindings.

    Declares the ``pyif_ann_handle_t`` struct and the ``ann*`` entry points.
    Output 0 uses the unsuffixed ``annCopyFromInferenceOutput``; every further
    graph output gets its own ``annCopyFromInferenceOutput_<i>`` declaration.

    Args:
        graph: NNIR graph; only ``len(graph.outputs)`` is read.
        fileName: path of the header to create (overwritten).
    """
    print('creating ' + fileName + ' ...')
    prologue = """
#ifndef included_file_annpython_h
#define included_file_annpython_h

#include <VX/vx.h>

////
// python interface handle: upto 8 outputs
//
typedef struct pyif_ann_handle_t {
    vx_context  context;
    vx_graph    graph;
    vx_tensor   input;
    vx_tensor   output[8];
    int         num_output;
} * pyif_ann_handle;

////
// python interface functions
//
extern "C" VX_API_ENTRY const char *    VX_API_CALL annQueryInference();
extern "C" VX_API_ENTRY pyif_ann_handle VX_API_CALL annCreateInference(const char * binaryFilename);
extern "C" VX_API_ENTRY int             VX_API_CALL annReleaseInference(pyif_ann_handle handle);
extern "C" VX_API_ENTRY int             VX_API_CALL annCopyToInferenceInput(pyif_ann_handle handle, float * inp_ptr, size_t inp_size, bool is_nhwc);
extern "C" VX_API_ENTRY int             VX_API_CALL annCopyFromInferenceOutput(pyif_ann_handle handle, float * out_ptr, size_t out_size);
"""
    extraOutputTemplate = """extern "C" VX_API_ENTRY int             VX_API_CALL annCopyFromInferenceOutput_%d(pyif_ann_handle handle, float * out_ptr, size_t out_size);
"""
    epilogue = """extern "C" VX_API_ENTRY int             VX_API_CALL annRunInference(pyif_ann_handle handle, int num_iterations);

#endif
"""
    with open(fileName, 'w') as f:
        generateLicenseForCPP(f)
        extraOutputs = [extraOutputTemplate % (outIdx) for outIdx in range(1, len(graph.outputs))]
        f.write(prologue + ''.join(extraOutputs) + epilogue)

# Number of output tensors pyif_ann_handle_t can hold; must agree with the
# `vx_tensor output[8]` member that generatePythonH() writes into annpython.h.
_PYIF_MAX_OUTPUTS = 8

def _tensorElementCount(shape):
    """Return the number of elements of a dense tensor with the given shape."""
    count = 1
    for v in shape:
        count *= int(v)
    return count

def _float32TensorStrides(shape, min_dims=4):
    """Return byte strides of a dense float32 tensor in OpenVX dimension order.

    OpenVX lists dimensions innermost first (the NNIR shape reversed).  Missing
    dimensions are treated as 1, so the result always has at least ``min_dims``
    entries and can initialize a fixed-size ``vx_size stride[]`` array.
    """
    dims = [int(v) for v in reversed(shape)]
    dims += [1] * (min_dims - len(dims))
    strides = [4]
    for d in dims[:-1]:
        strides.append(strides[-1] * d)
    return strides

def _writePythonCPPStubs(f, graph):
    """Write failing stubs for every entry point declared by generatePythonH().

    Used for graphs the python interface cannot handle.  All declared symbols
    are still defined with their header signatures, so the shared library
    links and exports everything the ctypes wrapper binds.
    """
    f.write(
"""
#include "annpython.h"

VX_API_ENTRY const char * VX_API_CALL annQueryInference()
{
    return "Unsupported";
}

VX_API_ENTRY pyif_ann_handle VX_API_CALL annCreateInference(const char * binaryFilename)
{
    return NULL;
}

VX_API_ENTRY int VX_API_CALL annReleaseInference(pyif_ann_handle handle)
{
    return -1;
}

VX_API_ENTRY int VX_API_CALL annCopyToInferenceInput(pyif_ann_handle handle, float * inp_ptr, size_t inp_size, bool is_nhwc)
{
    return -1;
}

VX_API_ENTRY int VX_API_CALL annCopyFromInferenceOutput(pyif_ann_handle handle, float * out_ptr, size_t out_size)
{
    return -1;
}
""")
    # the header declares an accessor for each output after the first
    for i in range(1, len(graph.outputs)):
        f.write(
"""
VX_API_ENTRY int VX_API_CALL annCopyFromInferenceOutput_%d(pyif_ann_handle handle, float * out_ptr, size_t out_size)
{
    return -1;
}
""" % (i))
    f.write(
"""
VX_API_ENTRY int VX_API_CALL annRunInference(pyif_ann_handle handle, int num_iterations)
{
    return -1;
}
""")

def generatePythonCPP(graph,fileName):
    """Generate annpython.cpp, the implementation of the python interface.

    Defines every function declared by generatePythonH().  The interface
    supports graphs with exactly one 4-D float32 input and between 1 and
    _PYIF_MAX_OUTPUTS float32 outputs.  For any other graph, stubs that report
    failure are written instead.

    Args:
        graph: NNIR graph; reads ``graph.inputs`` / ``graph.outputs`` (each with
            ``.name`` and ``.shape``).
        fileName: path of the .cpp file to create (overwritten).
    """
    print('creating ' + fileName + ' ...')
    with open(fileName, 'w') as f:
        generateLicenseForCPP(f)
        supported = len(graph.inputs) == 1 and len(graph.inputs[0].shape) == 4 \
            and 1 <= len(graph.outputs) <= _PYIF_MAX_OUTPUTS
        if not supported:
            _writePythonCPPStubs(f, graph)
            return

        input_shape = graph.inputs[0].shape
        input_buf_size = _tensorElementCount(input_shape) * 4
        output_shapes = [output.shape for output in graph.outputs]
        output_buf_sizes = [_tensorElementCount(shape) * 4 for shape in output_shapes]
        # configuration string returned by annQueryInference():
        #   "input,<name>,<dims>;output0,<name>,<dims>;output1,...;"
        config = 'input,' + graph.inputs[0].name + ',' + ','.join(str(v) for v in input_shape) + ';'
        for i, output in enumerate(graph.outputs):
            config += 'output' + str(i) + ',' + output.name + ',' + ','.join(str(v) for v in output.shape) + ';'
        output_args = ', '.join('handle->output[%d]' % i for i in range(len(graph.outputs)))

        f.write(
"""
#include "annpython.h"
#include "annmodule.h"
#include <stdio.h>
#include <string.h>
#include <string>
#include <vx_ext_amd.h>

static void VX_CALLBACK log_callback(vx_context context, vx_reference ref, vx_status status, const vx_char string[])
{
    size_t len = strlen(string);
    if (len > 0) {
        printf("%%s", string);
        if (string[len - 1] != '\\n')
            printf("\\n");
        fflush(stdout);
    }
}

VX_API_ENTRY const char * VX_API_CALL annQueryInference()
{
    return "%s";
}
""" % (config))

        f.write(
"""
VX_API_ENTRY pyif_ann_handle VX_API_CALL annCreateInference(const char * binaryFilename)
{
    bool successful = false;

    pyif_ann_handle handle = new pyif_ann_handle_t();
    if(!handle) {
        printf("ERROR: new pyif_ann_handle: failed (nullptr)\\n");
    }
    else {
        vx_status status;
        vxRegisterLogCallback(NULL, log_callback, vx_false_e);
        handle->context = vxCreateContext();
        if((status = vxGetStatus((vx_reference)handle->context)) != VX_SUCCESS) {
            printf("ERROR: vxCreateContext: failed (%%d)\\n", status);
        }
        else {
            handle->graph = vxCreateGraph(handle->context);
            if((status = vxGetStatus((vx_reference)handle->graph)) != VX_SUCCESS) {
                printf("ERROR: vxCreateGraph: failed (%%d)\\n", status);
            }
            else {
                vx_size inp_dim[4] = { %s };
                handle->input = vxCreateTensor(handle->context, 4, inp_dim, VX_TYPE_FLOAT32, 0);
                if((status = vxGetStatus((vx_reference)handle->input)) != VX_SUCCESS) {
                    printf("ERROR: vxCreateTensor(input:[%s]): failed (%%d)\\n", status);
                }
                else {
                    handle->num_output = %d;
""" % (', '.join(str(v) for v in reversed(input_shape)), 'x'.join(str(v) for v in input_shape), len(graph.outputs)))
        # each output is only created while everything before it succeeded, so
        # a failure on ANY output (not just the last one) stops graph creation
        for i, shape in enumerate(output_shapes):
            f.write(
"""                    if(status == VX_SUCCESS) {
                        vx_size out_dim_%d[%d] = { %s };
                        handle->output[%d] = vxCreateTensor(handle->context, %d, out_dim_%d, VX_TYPE_FLOAT32, 0);
                        if((status = vxGetStatus((vx_reference)handle->output[%d])) != VX_SUCCESS) {
                            printf("ERROR: vxCreateTensor(output:[%s]): failed (%%d)\\n", status);
                        }
                    }
""" % (i, len(shape), ', '.join(str(v) for v in reversed(shape)), i, \
       len(shape), i, i, 'x'.join(str(v) for v in shape)))
        f.write(
"""                    if(status == VX_SUCCESS) {
                        if((status = annAddToGraph(handle->graph, handle->input, %s, binaryFilename)) != VX_SUCCESS) {
                            printf("ERROR: annAddToGraph: failed (%%d)\\n", status);
                        }
                        else if((status = vxVerifyGraph(handle->graph)) != VX_SUCCESS) {
                            printf("ERROR: vxVerifyGraph: failed (%%d)\\n", status);
                        }
                        else {
                            printf("OK: annCreateInference: successful\\n");
                            successful = true;
                        }
                    }
                }
            }
        }
    }

    if(!successful) {
        if(handle) {
            if(handle->graph)
                vxReleaseGraph(&handle->graph);
            if(handle->input)
                vxReleaseTensor(&handle->input);
""" % (output_args))
        for i in range(len(graph.outputs)):
            f.write(
"""            if(handle->output[%d])
                vxReleaseTensor(&handle->output[%d]);
""" % (i, i))
        # annReleaseInference: guard against a null handle before touching it,
        # release every object even if one release fails, and report the first
        # failure (the context-release status must not mask earlier errors)
        f.write(
"""            if(handle->context)
                vxReleaseContext(&handle->context);
            delete handle;
            handle = nullptr;
        }
    }
    return handle;
}

VX_API_ENTRY int VX_API_CALL annReleaseInference(pyif_ann_handle handle)
{
    if(!handle) {
        printf("ERROR: annReleaseInference: invalid handle\\n");
        return VX_FAILURE;
    }
    vx_status status = VX_SUCCESS, err;
    if(handle->graph && (err = vxReleaseGraph(&handle->graph)) != VX_SUCCESS) {
        printf("ERROR: annReleaseInference: vxReleaseGraph: failed (%d)\\n", err);
        status = err;
    }
    if(handle->input && (err = vxReleaseTensor(&handle->input)) != VX_SUCCESS) {
        printf("ERROR: annReleaseInference: vxReleaseTensor(input): failed (%d)\\n", err);
        if(status == VX_SUCCESS) status = err;
    }
    for (int i=0; i<handle->num_output; i++) {
        if(handle->output[i] && (err = vxReleaseTensor(&handle->output[i])) != VX_SUCCESS) {
            printf("ERROR: annReleaseInference: vxReleaseTensor(output<%d>): failed (%d)\\n", i, err);
            if(status == VX_SUCCESS) status = err;
        }
    }
    if(handle->context && (err = vxReleaseContext(&handle->context)) != VX_SUCCESS) {
        printf("ERROR: annReleaseInference: vxReleaseContext: failed (%d)\\n", err);
        if(status == VX_SUCCESS) status = err;
    }
    else {
        delete handle;
    }
    return status;
}
""")

        # input copy: NCHW data is copied directly; NHWC data is transposed
        # element by element into the mapped NCHW tensor
        f.write(
"""
VX_API_ENTRY int VX_API_CALL annCopyToInferenceInput(pyif_ann_handle handle, float * inp_ptr, size_t inp_size, bool is_nhwc)
{
    vx_status status = VX_SUCCESS;
    vx_size stride[4] = { %s };
    vx_map_id map_id;
    float * ptr = nullptr;
    if(!handle) {
        status = VX_FAILURE;
        printf("ERROR: annCopyToInferenceInput: invalid handle\\n");
    }
    else if(inp_size != %d) {
        status = VX_FAILURE;
        printf("ERROR: annCopyToInferenceInput: invalid input buffer size (must be %d) -- got %%d\\n", (int)inp_size);
    }
    else if(handle->input == nullptr) {
        status = VX_FAILURE;
        printf("ERROR: annCopyToInferenceInput: input is not valid\\n");
    }
    else if(!is_nhwc) {
        if((status = vxCopyTensorPatch(handle->input, 4, nullptr, nullptr, stride, inp_ptr, VX_WRITE_ONLY, VX_MEMORY_TYPE_HOST)) != VX_SUCCESS) {
            printf("ERROR: annCopyToInferenceInput: vxCopyTensorPatch: failed (%%d)\\n", status);
        }
    }
    else if((status = vxMapTensorPatch(handle->input, 4, nullptr, nullptr, &map_id, stride, (void **)&ptr, VX_WRITE_ONLY, VX_MEMORY_TYPE_HOST, 0)) != VX_SUCCESS) {
        printf("ERROR: annCopyToInferenceInput: vxMapTensorPatch: failed (%%d)\\n", status);
    }
    else {
        size_t N = %d, C = %d, H = %d, W = %d;
        for(size_t n = 0; n < N; n++) {
            for(size_t c = 0; c < C; c++) {
                for(size_t y = 0; y < H; y++) {
                    size_t tpos = n * C*H*W + c * H*W + y * W;
                    size_t ipos = n * H*W*C + y * W*C + c;
                    for(size_t x = 0; x < W; x++, tpos++, ipos += C) {
                        ptr[tpos] = inp_ptr[ipos];
                    }
                }
            }
        }
        if ((status = vxUnmapTensorPatch(handle->input, map_id)) != VX_SUCCESS) {
            printf("ERROR: annCopyToInferenceInput: vxUnmapTensorPatch: failed (%%d)\\n", status);
        }
    }
    return status;
}
""" % (', '.join(str(s) for s in _float32TensorStrides(input_shape)), \
       input_buf_size, input_buf_size, input_shape[0], input_shape[1], input_shape[2], input_shape[3]))

        # one copy-out accessor per output, matching the header declarations:
        # output 0 -> annCopyFromInferenceOutput, output i -> ..._i
        for i, shape in enumerate(output_shapes):
            func_name = 'annCopyFromInferenceOutput' if i == 0 else 'annCopyFromInferenceOutput_%d' % i
            strides = _float32TensorStrides(shape)
            f.write(
"""
VX_API_ENTRY int VX_API_CALL %s(pyif_ann_handle handle, float * out_ptr, size_t out_size)
{
    vx_status status = VX_SUCCESS;
    vx_size stride[%d] = { %s };
    if(!handle) {
        status = VX_FAILURE;
        printf("ERROR: %s: invalid handle\\n");
    }
    else if(out_size != %d) {
        status = VX_FAILURE;
        printf("ERROR: %s: invalid output buffer size (must be %d) -- got %%d\\n", (int)out_size);
    }
    else if(!handle->output[%d]) {
        status = VX_FAILURE;
        printf("ERROR: %s: output is not valid\\n");
    }
    else if((status = vxCopyTensorPatch(handle->output[%d], %d, nullptr, nullptr, stride, out_ptr, VX_READ_ONLY, VX_MEMORY_TYPE_HOST)) != VX_SUCCESS) {
        printf("ERROR: %s: vxCopyTensorPatch: failed (%%d)\\n", status);
    }
    return status;
}
""" % (func_name, len(strides), ', '.join(str(s) for s in strides), func_name, \
       output_buf_sizes[i], func_name, output_buf_sizes[i], i, func_name, i, len(shape), func_name))

        f.write(
"""

VX_API_ENTRY int VX_API_CALL annRunInference(pyif_ann_handle handle, int num_iterations)
{
    vx_status status = VX_SUCCESS;
    if(!handle) {
        status = VX_FAILURE;
        printf("ERROR: annRunInference: invalid handle\\n");
    }
    else {
        for(int i = 0; i < num_iterations; i++) {
            if((status = vxProcessGraph(handle->graph)) != VX_SUCCESS)
                break;
        }
    }
    return status;
}
""")

def generatePythonScriptSample(graph,fileName):
    """Write a sample Python script that drives the generated annpython library.

    The script loads the shared library through ctypes, reads a raw float32
    input tensor file, runs one inference and saves the first output tensor as
    raw float32 data.  Tensor sizes are taken from the configuration string
    returned by annQueryInference(), so inputs/outputs of any rank and models
    with several outputs are handled.

    Args:
        graph: NNIR graph (not read; kept for a uniform generator signature).
        fileName: path of the script to create (overwritten).
    """
    print('creating ' + fileName + ' ...')
    with open(fileName, 'w') as f:
        generateLicenseForScript(f)
        f.write(
"""
import sys,os,ctypes
import numpy as np
from numpy.ctypeslib import ndpointer

class AnnAPI:
    def __init__(self,library):
        self.lib = ctypes.cdll.LoadLibrary(library)
        self.annQueryInference = self.lib.annQueryInference
        self.annQueryInference.restype = ctypes.c_char_p
        self.annQueryInference.argtypes = []
        self.annCreateInference = self.lib.annCreateInference
        self.annCreateInference.restype = ctypes.c_void_p
        self.annCreateInference.argtypes = [ctypes.c_char_p]
        self.annReleaseInference = self.lib.annReleaseInference
        self.annReleaseInference.restype = ctypes.c_int
        self.annReleaseInference.argtypes = [ctypes.c_void_p]
        self.annCopyToInferenceInput = self.lib.annCopyToInferenceInput
        self.annCopyToInferenceInput.restype = ctypes.c_int
        self.annCopyToInferenceInput.argtypes = [ctypes.c_void_p, ndpointer(ctypes.c_float, flags="C_CONTIGUOUS"), ctypes.c_size_t, ctypes.c_bool]
        self.annCopyFromInferenceOutput = self.lib.annCopyFromInferenceOutput
        self.annCopyFromInferenceOutput.restype = ctypes.c_int
        self.annCopyFromInferenceOutput.argtypes = [ctypes.c_void_p, ndpointer(ctypes.c_float, flags="C_CONTIGUOUS"), ctypes.c_size_t]
        self.annRunInference = self.lib.annRunInference
        self.annRunInference.restype = ctypes.c_int
        self.annRunInference.argtypes = [ctypes.c_void_p, ctypes.c_int]
        print('OK: AnnAPI found "' + self.annQueryInference().decode("utf-8") + '" as configuration in ' + library)

def shapeSize(dims):
    # number of elements of a tensor given its dimensions (strings or ints)
    size = 1
    for d in dims:
        size *= int(d)
    return size

if __name__ == '__main__':
    if len(sys.argv) < 5:
        print ("Usage : python anntest.py <libannpython> <weightsfile> <input_tensor_file> <output_tensor_file>")
        sys.exit(1)
    annlibPythonName = sys.argv[1]
    weightsFile = sys.argv[2]
    inputTensorFile = sys.argv[3]
    outputTensorFile = sys.argv[4]
    api = AnnAPI(annlibPythonName)
    # configuration: "input,<name>,<dims>;output0,<name>,<dims>;[output1,<name>,<dims>;...]"
    config = api.annQueryInference().decode("utf-8").split(';')
    if len(config) < 2 or not config[0].startswith('input,'):
        print('ERROR: unsupported configuration: ' + ';'.join(config))
        sys.exit(1)
    input_info, output_info = config[0], config[1]
    inp_size = shapeSize(input_info.split(',')[2:]) * 4
    out_size = shapeSize(output_info.split(',')[2:]) * 4
    # ctypes.c_char_p arguments must be bytes, not str
    hdl = api.annCreateInference(weightsFile.encode('utf-8'))
    if not hdl:
        print('ERROR: annCreateInference failed for ' + weightsFile)
        sys.exit(1)
    im = np.fromfile(inputTensorFile, dtype=np.float32)
    status = api.annCopyToInferenceInput(hdl, np.ascontiguousarray(im, dtype=np.float32), inp_size, 0)
    print('INFO: annCopyToInferenceInput status %d'  %(status))
    status = api.annRunInference(hdl, 1)
    print('INFO: annRunInference status %d ' %(status))
    out = np.zeros(out_size // 4, dtype=np.float32)
    status = api.annCopyFromInferenceOutput(hdl, out, out_size)
    print('INFO: annCopyFromInferenceOutput status %d' %(status))
    with open(outputTensorFile, 'wb') as fid:
        fid.write(out.tobytes())
    api.annReleaseInference(hdl)
""")

def generateTestCPP(graph,argmaxOutput,fileName):
    print('creating ' + fileName + ' ...')
    with open(fileName, 'w') as f:
        generateLicenseForCPP(f)
        f.write( \
"""
#include "annmodule.h"
#include <vx_ext_amd.h>
#include <vx_amd_nn.h>
#include <iostream>
#include <sstream>
#include <vector>
#include <stdio.h>
#include <string.h>
#include <string>
#include <inttypes.h>
#include <chrono>
#include <unistd.h>
#include <math.h>
#include <half.hpp>
#include <immintrin.h>
using half_float::half;

#if ENABLE_OPENCV
#include <opencv2/opencv.hpp>
#include <opencv/cv.h>
#include <opencv/highgui.h>
using namespace cv;
#endif

#define ERROR_CHECK_OBJECT(obj) { vx_status status = vxGetStatus((vx_reference)(obj)); if(status != VX_SUCCESS) { vxAddLogEntry((vx_reference)context, status     , "ERROR: failed with status = (%d) at " __FILE__ "#%d\\n", status, __LINE__); return status; } }
#define ERROR_CHECK_STATUS(call) { vx_status status = (call); if(status != VX_SUCCESS) { printf("ERROR: failed with status = (%d) at " __FILE__ "#%d\\n", status, __LINE__); return -1; } }
""")
        if type(argmaxOutput) is np.ndarray:
            f.write( \
"""
// LUT (R)
vx_uint8 g_lut__r[] = {
    %s
};
// LUT (G)
vx_uint8 g_lut__g[] = {
    %s
};
// LUT (B)
vx_uint8 g_lut__b[] = {
    %s
};
""" % (','.join([str(v) for v in argmaxOutput[0]]), ','.join([str(v) for v in argmaxOutput[1]]), ','.join([str(v) for v in argmaxOutput[2]])))
            if len(argmaxOutput) == 4:
                f.write( \
"""// LUT (A)
vx_uint8 g_lut__a[] = {
    %s
};
""" % (','.join([str(v) for v in argmaxOutput[3]])))
        f.write( \
"""
static void VX_CALLBACK log_callback(vx_context context, vx_reference ref, vx_status status, const vx_char string[])
{
    size_t len = strlen(string);
    if (len > 0) {
        printf("%s", string);
        if (string[len - 1] != '\\n')
            printf("\\n");
        fflush(stdout);
    }
}

inline int64_t clockCounter()
{
    return std::chrono::high_resolution_clock::now().time_since_epoch().count();
}

inline int64_t clockFrequency()
{
    return std::chrono::high_resolution_clock::period::den / std::chrono::high_resolution_clock::period::num;
}

static vx_status copyTensor(std::string tensorName, vx_tensor tensor, std::string args, vx_enum usage = VX_WRITE_ONLY)
{
    // split the args into fileName and other parameters
    std::vector<std::string> argList;
    std::istringstream sf(args);
    for(std::string s; std::getline(sf, s, ','); ) {
        argList.push_back(s);
    }
    std::string fileName = argList[0];
    // access the tensor object
    vx_enum data_type = VX_TYPE_FLOAT32;
    vx_size num_of_dims = 4, dims[4] = { 1, 1, 1, 1 }, stride[4];
    vxQueryTensor(tensor, VX_TENSOR_DATA_TYPE, &data_type, sizeof(data_type));
    vxQueryTensor(tensor, VX_TENSOR_NUMBER_OF_DIMS, &num_of_dims, sizeof(num_of_dims));
    vxQueryTensor(tensor, VX_TENSOR_DIMS, &dims, sizeof(dims[0])*num_of_dims);
    if((data_type != VX_TYPE_FLOAT32) && (data_type != VX_TYPE_FLOAT16)) {
        std::cerr << "ERROR: copyTensor() supports only VX_TYPE_FLOAT32 or VX_TYPE_FLOAT16: invalid for " << fileName << std::endl;
        return -1;
    }
    vx_size count = dims[0] * dims[1] * dims[2] * dims[3];
    vx_map_id map_id;
    void * ptr;
    vx_status status = vxMapTensorPatch(tensor, num_of_dims, nullptr, nullptr, &map_id, stride, (void **)&ptr, usage, VX_MEMORY_TYPE_HOST, 0);
    if(status) {
        std::cerr << "ERROR: vxMapTensorPatch() failed for " << fileName << std::endl;
        return -1;
    }
    if(usage == VX_WRITE_ONLY) {
#if ENABLE_OPENCV
        if(dims[2] == 3 && fileName.size() > 4 && (fileName.substr(fileName.size()-4, 4) == ".png" || fileName.substr(fileName.size()-4, 4) == ".jpg"))
        {
            for(size_t n = 0; n < dims[3]; n++) {
                char imgFileName[1024];
                sprintf(imgFileName, fileName.c_str(), (int)n);
                Mat img = imread(imgFileName, CV_LOAD_IMAGE_COLOR);
                if(!img.data || img.rows != dims[1] || img.cols != dims[0]) {
                    printf("ERROR: invalid image or dimensions: %s\\n", imgFileName);
                    return -1;
                }
                for(vx_size y = 0; y < dims[1]; y++) {
                    unsigned char * src = img.data + y*dims[0]*3;
                    if(data_type == VX_TYPE_FLOAT32) {
                        float * dstR = (float *)ptr + ((n * stride[3] + y * stride[1]) >> 2);
                        float * dstG = dstR + (stride[2] >> 2);
                        float * dstB = dstG + (stride[2] >> 2);
                        for(vx_size x = 0; x < dims[0]; x++, src += 3) {
                            *dstR++ = src[2];
                            *dstG++ = src[1];
                            *dstB++ = src[0];
                        }
                    } else
                    {
                        short * dstR = (short *)ptr + ((n * stride[3] + y * stride[1]) >> 1);
                        short * dstG = dstR + (stride[2] >> 2);
                        short * dstB = dstG + (stride[2] >> 2);                    
                        for(vx_size x = 0; x < dims[0]; x++, src += 3) {
                            *dstR++ = src[2];
                            *dstG++ = src[1];
                            *dstB++ = src[0];
                        }
                    }
                }
            }
        }
        else
#endif
        {
            FILE * fp = fopen(fileName.c_str(), "rb");
            if(!fp) {
                std::cerr << "ERROR: unable to open: " << fileName << std::endl;
                return -1;
            }
            for(size_t n = 0; n < dims[3]; n++) {
                for(size_t c = 0; c < dims[2]; c++) {
                    for(size_t y = 0; y < dims[1]; y++) {
                        if(data_type == VX_TYPE_FLOAT32) {
                            float * ptrY = (float *)ptr + ((n * stride[3] + c * stride[2] + y * stride[1]) >> 2);
                            vx_size n = fread(ptrY, sizeof(float), dims[0], fp);
                            if(n != dims[0]) {
                                std::cerr << "ERROR: expected char[" << count*sizeof(float) << "], but got less in " << fileName << std::endl;
                                return -1;
                            }
                        } else {
                            short * ptrY = (short *)ptr + ((n * stride[3] + c * stride[2] + y * stride[1]) >> 1);
                            vx_size n = fread(ptrY, sizeof(short), dims[0], fp);
                            if(n != dims[0]) {
                                std::cerr << "ERROR: expected char[" << count*sizeof(short) << "], but got less in " << fileName << std::endl;
                                return -1;
                            }
                        }
                    }
                }
            }
            fclose(fp);
        }
    }
    else {""")
        if type(argmaxOutput) is np.ndarray:
            f.write( \
"""
        vx_size W = 1, H =1, C = 1, N = 1;
        if (num_of_dims == 2) {
            C = dims[0], N = dims[1];
        }
        else if(num_of_dims == 4) {
            W = dims[0], H = dims[1], C = dims[2], N = dims[3];
        }
        vx_size HW = H * W;
        vx_size NHW = N * HW;
        vx_size CHW = C * HW;
        if(C < sizeof(g_lut__r) || C < sizeof(g_lut__g) || C < sizeof(g_lut__b)) {
            std::cerr << "ERROR: LUT doesn't have enough entries for all channels with " << fileName << std::endl;
            return -1;
        }
        vx_uint8 * buf = new vx_uint8[NHW*%d], * pb = buf;
        for(vx_size n = 0; n < N; n++) {
            for(vx_size y = 0; y < H; y++) {
                for(vx_size x = 0; x < W; x++) {
                    vx_size best_c = 0;
                    float * pc = (float *)ptr + n * CHW + y * W + x;
                    if (data_type == VX_TYPE_FLOAT32) {
                        float best_v = *pc;
                        for(vx_size c = 1; c < C; c++, pc += HW) {
                            if(*pc > best_v) {
                                best_v = *pc;
                                best_c = c;
                            }
                        }
                    }
                    else if (data_type == VX_TYPE_FLOAT16) {
                        half * pc = (half *)((short *)ptr + n * CHW + y * W + x);
                        half best_v = *pc;
                        for(vx_size c = 1; c < C; c++, pc += HW) {
                            if(*pc > best_v) {
                                best_v = *pc;
                                best_c = c;
                            }
                        }
                    }
                    *pb++ = g_lut__r[best_c];
                    *pb++ = g_lut__g[best_c];
                    *pb++ = g_lut__b[best_c];
                    %s
                }
            }
        }
#if ENABLE_OPENCV
        if(fileName.size() > 4 && (fileName.substr(fileName.size()-4, 4) == ".png"))
        {
            for(size_t n = 0; n < N; n++) {
                char imgFileName[1024];
                sprintf(imgFileName, fileName.c_str(), (int)n);
                Mat img(dims[1], dims[0], %s);
                if(!img.data || img.rows != H || img.cols != W) {
                    printf("ERROR: invalid image or dimensions: %%s\\n", imgFileName);
                    return -1;
                }
                for(vx_size y = 0; y < H; y++) {
                    size_t outC = %d;
                    vx_uint8 * dst = img.data + y * W * outC;
                    vx_uint8 * src = (vx_uint8 *)(buf + (n * H + y) * W * outC);
                    for(vx_size x = 0; x < W; x++, src += outC, dst += outC) {
                        dst[2] = src[0];
                        dst[1] = src[1];
                        dst[0] = src[2];
                        %s
                    }
                }
                std::vector<int> compression_params;
                compression_params.push_back(IMWRITE_PNG_COMPRESSION);
                compression_params.push_back(9);
                try {
                    imwrite(imgFileName, img, compression_params);
                }
                catch (cv::Exception& ex) {
                    printf("ERROR: exception converting image to PNG format: %%s\\n", ex.what());
                    return -1;
                }
            }
        }
        else
#endif
        if(fileName != "-") {
            FILE * fp = fopen(fileName.c_str(), "wb");
            if(!fp) {
                std::cerr << "ERROR: unable to open: " << fileName << std::endl;
                return -1;
            }
            fwrite(buf, %d, NHW, fp);
            fclose(fp);
        }
        if(argList.size() >= 2) {
            size_t outC = %d;
            float errPercentLimit = 0;
            if(argList.size() >= 3) {
                errPercentLimit = (float)atof(argList[2].c_str());
            }
            vx_uint8 * gold = new vx_uint8 [NHW * outC];
            FILE * fp = fopen(argList[1].c_str(), "rb");
            if(!fp) {
                std::cerr << "ERROR: unable to open: " << argList[1] << std::endl;
                return -1;
            }
            if(fread(gold, outC, NHW, fp) != NHW) {
                std::cerr << "ERROR: not enought data (" << NHW * outC << " bytes) in " << argList[1] << std::endl;
                return -1;
            }
            fclose(fp);
            size_t errCount = 0;
            for(size_t i = 0; i < NHW * outC; i += outC) {
                bool isErr = false;
                for(size_t j = 0; j < outC; j++) {
                    if(!(buf[i+j] == gold[i+j]))
                        isErr = true;
                }
                if(isErr)
                    errCount++;
            }
            delete[] gold;
            float errPercent = 100.0f * (float)errCount / (float)count;
            bool isError = errPercent > errPercentLimit;
            printf("%%s: [Percent-Error %%e%%%%] for %%s with %%s\\n",
                isError ? "ERROR" : "OK", errPercent, tensorName.c_str(), argList[1].c_str());
            if(isError) {
                return -1;
            }
        }
        delete[] buf;
""" % (len(argmaxOutput), \
            '*pb++ = g_lut__a[best_c];' if len(argmaxOutput) == 4 else '', \
            'CV_8UC4' if len(argmaxOutput) == 4 else 'CV_8UC3', len(argmaxOutput), \
            'dst[3] = src[3];' if len(argmaxOutput) == 4 else '', len(argmaxOutput), \
            len(argmaxOutput)))
        elif type(argmaxOutput) is str:
            f.write( \
"""
        vx_size W = 1, H = 1, C = 1, N = 1;
        if(num_of_dims == 2) {
            C = dims[0], N = dims[1];
        }
        else if(num_of_dims == 4) {
            W = dims[0], H = dims[1], C = dims[2], N = dims[3];
        }
        vx_size HW = H * W;
        vx_size NHW = N * HW;
        vx_size CHW = C * HW;
        %s * buf = new %s[NHW], * pb = buf;
        for(vx_size n = 0; n < N; n++) {
            for(vx_size y = 0; y < H; y++) {
                for(vx_size x = 0; x < W; x++) {
                    vx_size best_c = 0;
                    if (data_type == VX_TYPE_FLOAT32) {
                        float * pc = (float *)ptr + n * CHW + y * W + x;
                        float best_v = *pc;
                        for(vx_size c = 1; c < C; c++, pc += HW) {
                            if(*pc > best_v) {
                                best_v = *pc;
                                best_c = c;
                            }
                        }
                    }else {
                        half * pc = (half *)ptr + n * CHW + y * W + x;
                        half best_v = *pc;
                        for(vx_size c = 1; c < C; c++, pc += HW) {
                            if(*pc > best_v) {
                                best_v = *pc;
                                best_c = c;
                            }
                        }
                    }
                    *pb++ = (%s)best_c;
                }
            }
        }
        if(fileName != "-") {
            FILE * fp = fopen(fileName.c_str(), "wb");
            if(!fp) {
                std::cerr << "ERROR: unable to open: " << fileName << std::endl;
                return -1;
            }
            fwrite(buf, sizeof(%s), NHW, fp);
            fclose(fp);
        }
        if(argList.size() >= 2) {
            float errPercentLimit = 0;
            if(argList.size() >= 3) {
                errPercentLimit = (float)atof(argList[2].c_str());
            }
            %s * gold = new %s [NHW];
            FILE * fp = fopen(argList[1].c_str(), "rb");
            if(!fp) {
                std::cerr << "ERROR: unable to open: " << argList[1] << std::endl;
                return -1;
            }
            if(fread(gold, sizeof(%s), NHW, fp) != NHW) {
                std::cerr << "ERROR: not enought data (" << count << " %s needed) in " << argList[1] << std::endl;
                return -1;
            }
            fclose(fp);
            size_t errCount = 0;
            for(size_t i = 0; i < NHW; i++) {
                if(!(buf[i] == gold[i]))
                    errCount++;
            }
            delete[] gold;
            float errPercent = 100.0f * (float)errCount / (float)count;
            bool isError = errPercent > errPercentLimit;
            printf("%%s: [Percent-Error %%e%%%%] for %%s with %%s\\n",
                isError ? "ERROR" : "OK", errPercent, tensorName.c_str(), argList[1].c_str());
            if(isError) {
                return -1;
            }
        }
        delete[] buf;
""" % (argmaxOutput, argmaxOutput, argmaxOutput, argmaxOutput, argmaxOutput, argmaxOutput, argmaxOutput, argmaxOutput))
        else:
            f.write( \
"""
        if(fileName != "-") {
            FILE * fp = fopen(fileName.c_str(), "wb");
            if(!fp) {
                std::cerr << "ERROR: unable to open: " << fileName << std::endl;
                return -1;
            }
            if (data_type == VX_TYPE_FLOAT32)
                fwrite(ptr, sizeof(float), count, fp);
            else
                fwrite(ptr, sizeof(short), count, fp);                
            fclose(fp);
        }
        if(argList.size() >= 2) {
            float rmsErrorLimit = 0, maxErrorLimit = 0;
            if(argList.size() >= 3) {
                rmsErrorLimit = maxErrorLimit = (float)atof(argList[2].c_str());
                if(argList.size() >= 4) {
                    maxErrorLimit = (float)atof(argList[3].c_str());
                }
            }
            float * gold = new float [count];
            FILE * fp = fopen(argList[1].c_str(), "rb");
            if(!fp) {
                std::cerr << "ERROR: unable to open: " << argList[1] << std::endl;
                return -1;
            }
            if(fread(gold, sizeof(float), count, fp) != count) {
                std::cerr << "ERROR: not enought data (" << count << " floats needed) in " << argList[1] << std::endl;
                return -1;
            }
            fclose(fp);
            double sqrError = 0;
            float maxError = 0;
            if(data_type == VX_TYPE_FLOAT32) {
                for(size_t i = 0; i < count; i++) {
                    float err = ((float *)ptr)[i] - gold[i];
                    if(err < 0) err = -err;
                    sqrError += err * err;
                    if(!(err < maxError)) maxError = err;
                }
            }
            else
            {
                for(size_t i = 0; i < count; i++) {
                    float src = _cvtsh_ss(((unsigned short*)ptr)[i]);
                    float err = src - gold[i];
                    if(err < 0) err = -err;
                    sqrError += err * err;
                    if(!(err < maxError)) maxError = err;
                }
            }
            delete[] gold;
            float rmsError = (float)sqrt(sqrError/count);
            bool isError = !(rmsError <= rmsErrorLimit) || !(maxError <= maxErrorLimit);
            printf("%s: [RMS-Error %e] [MAX-Error %e] for %s with %s\\n",
                isError ? "ERROR" : "OK", rmsError, maxError, tensorName.c_str(), argList[1].c_str());
            if(isError) {
                return -1;
            }
        }
""")
        f.write( \
"""
    }
    status = vxUnmapTensorPatch(tensor, map_id);
    if(status) {
        std::cerr << "ERROR: vxUnmapTensorPatch() failed for " << fileName << std::endl;
        return -1;
    }
    return 0;
}

int main(int argc, const char ** argv)
{
    // check command-line usage
    if(argc < 2) {
        printf(
            "\\n"
            "Usage: anntest <weights.bin> [<input-data-file(s)> [<output-data-file(s)>]]]\\n"
            "\\n"
            "   <input-data-file>: is filename to initialize tensor\\n"
""")
        f.write( \
"""#if ENABLE_OPENCV
            "     .jpg or .png: decode and initialize for 3 channel tensors\\n"
            "         (use %%04d in fileName when batch-size > 1: batch index starts from 0)\\n"
#endif
""")
        f.write( \
"""            "     other: initialize tensor with raw data from the file\\n"
            "\\n"
""")
        if type(argmaxOutput) is np.ndarray:
            f.write( \
"""            "   <output-data-file>[,<reference-for-compare>,<percentMismatchLimit>]:\\n"
            "     <referece-to-compare> is raw tensor data of LUT output for comparision\\n"
            "     <percentMismatchLimit> is max mismatches (percent) allowed\\n"
            "     <output-data-file> is filename for saving output tensor data\\n"
#if ENABLE_OPENCV
            "       .png: save LUT output as PNG file(s)\\n"
            "         (use %%04d in fileName to when batch-size > 1: batch index starts from 0)\\n"
#endif
""")
        elif type(argmaxOutput) is str:
            f.write( \
"""            "   <output-data-file>[,<reference-for-compare>,<percentMismatchLimit>]:\\n"
            "     <referece-to-compare> is raw tensor data of argmax output for comparision\\n"
            "     <percentMismatchLimit> is max mismatches (percent) allowed\\n"
            "     <output-data-file> is filename for saving output tensor data\\n"
""")
        else:
            f.write( \
"""            "   <output-data-file>[,<reference-for-compare>,<maxErrorLimit>,<rmsErrorLimit>]:\\n"
            "     <referece-to-compare> is raw tensor data for comparision\\n"
            "     <maxErrorLimit> is max absolute error allowed\\n"
            "     <rmsErrorLimit> is max RMS error allowed\\n"
            "     <output-data-file> is filename for saving output tensor data\\n"
""")
        f.write( \
"""            "       '-' to ignore\\n"
            "       other: save raw tensor into the file\\n"
            "\\n"
        );
        return -1;
    }
    const char * binaryFilename = argv[1];
    argc -= 2;
    argv += 2;

    // create context, input, output, and graph
    vxRegisterLogCallback(NULL, log_callback, vx_false_e);
    vx_context context = vxCreateContext();
    vx_status status = vxGetStatus((vx_reference)context);
    if(status) {
        printf("ERROR: vxCreateContext() failed\\n");
        return -1;
    }
    vxRegisterLogCallback(context, log_callback, vx_false_e);
    vx_graph graph = vxCreateGraph(context);
    status = vxGetStatus((vx_reference)graph);
    if(status) {
        printf("ERROR: vxCreateGraph(...) failed (%d)\\n", status);
        return -1;
    }
""")
        for tensor in graph.inputs:
            f.write( \
"""
    // create and initialize input tensor %s
    vx_size dims_%s[%d] = { %s };
    vx_tensor %s = vxCreateTensor(context, %d, dims_%s, %s, 0);
    if(vxGetStatus((vx_reference)%s)) {
        printf("ERROR: vxCreateTensor() failed for %s\\n");
        return -1;
    }
    if(*argv) {
        if(strcmp(*argv, "-") != 0) {
            if(copyTensor("%s", %s, *argv, VX_WRITE_ONLY) < 0) {
                return -1;
            }
            printf("OK: initialized tensor '%s' from %%s\\n", *argv);
        }
        argv++;
    }
""" % (tensor.name, tensor.name, len(tensor.shape), ', '.join([str(v) for v in reversed(tensor.shape)]), \
       tensor.name, len(tensor.shape), tensor.name, tensor_type_nnir2openvx[tensor.type], \
       tensor.name, tensor.name, tensor.name, tensor.name, tensor.name))
        for tensor in graph.outputs:
            f.write( \
"""
    // create output tensor %s
    vx_size dims_%s[%d] = { %s };
    vx_tensor %s = vxCreateTensor(context, %d, dims_%s, %s, 0);
    if(vxGetStatus((vx_reference)%s)) {
        printf("ERROR: vxCreateTensor() failed for %s\\n");
        return -1;
    }
""" % (tensor.name, tensor.name, len(tensor.shape), ', '.join([str(v) for v in reversed(tensor.shape)]), \
       tensor.name, len(tensor.shape), tensor.name, tensor_type_nnir2openvx[tensor.type], \
       tensor.name, tensor.name))
        f.write( \
"""
    // build graph using annmodule
    int64_t freq = clockFrequency(), t0, t1;
    t0 = clockCounter();
    status = annAddToGraph(graph, %s, %s, binaryFilename);
    if(status) {
        printf("ERROR: annAddToGraph() failed (%%d)\\n", status);
        return -1;
    }
    status = vxVerifyGraph(graph);
    if(status) {
        printf("ERROR: vxVerifyGraph(...) failed (%%d)\\n", status);
        return -1;
    }
    t1 = clockCounter();
    printf("OK: graph initialization with annAddToGraph() took %%.3f msec\\n", (float)(t1-t0)*1000.0f/(float)freq);

    t0 = clockCounter();
    status = vxProcessGraph(graph);
    t1 = clockCounter();
    if(status != VX_SUCCESS) {
        printf("ERROR: vxProcessGraph() failed (%%d)\\n", status);
        return -1;
    }
    printf("OK: vxProcessGraph() took %%.3f msec (1st iteration)\\n", (float)(t1-t0)*1000.0f/(float)freq);
""" % (', '.join([tensor.name for tensor in graph.inputs]), \
       ', '.join([tensor.name for tensor in graph.outputs])))
        for tensor in graph.outputs:
            f.write( \
"""
    // save tensor %s
    if(*argv) {
        if(strcmp(*argv, "-") != 0) {
            if(copyTensor("%s", %s, *argv, VX_READ_ONLY) < 0) {
                return -1;
            }
            printf("OK: wrote tensor '%s' into %%s\\n", *argv);
        }
        argv++;
    }
""" % (tensor.name, tensor.name, tensor.name, tensor.name))
        f.write( \
"""
    t0 = clockCounter();
    int N = 100;
    for(int i = 0; i < N; i++) {
        status = vxProcessGraph(graph);
        if(status != VX_SUCCESS)
            break;
    }
    t1 = clockCounter();
    printf("OK: vxProcessGraph() took %.3f msec (average over %d iterations)\\n", (float)(t1-t0)*1000.0f/(float)freq/(float)N, N);

    // release resources
    ERROR_CHECK_STATUS(vxReleaseGraph(&graph));
""")
        for tensor in graph.inputs:
            f.write( \
"""    ERROR_CHECK_STATUS(vxReleaseTensor(&%s));
""" % (tensor.name))
        for tensor in graph.outputs:
            f.write( \
"""    ERROR_CHECK_STATUS(vxReleaseTensor(&%s));
""" % (tensor.name))
        f.write( \
"""    ERROR_CHECK_STATUS(vxReleaseContext(&context));
    printf("OK: successful\\n");

    return 0;
}
""")

def generateBinary(graph, fileName):
    """Serialize every initializer tensor of *graph* into a weights file.

    File layout, each marker packed with the native ``struct`` format ``'I'``:
      * VARIABLES_FILE_MAGIC once at the start;
      * for each tensor in ``graph.initializers`` (in order): VARIABLES_DATA_MAGIC,
        the byte length of ``graph.binaries[tensor.name]``, then those raw bytes;
      * VARIABLES_EOFF_MAGIC once at the end.

    Args:
        graph: object exposing ``initializers`` (items with ``.name``) and a
            ``binaries`` mapping from tensor name to a bytes-like payload.
        fileName: path of the file to create (overwritten if present).
    """
    VARIABLES_FILE_MAGIC = 0xF00DD1E0
    VARIABLES_DATA_MAGIC = 0xF00DD1E1
    VARIABLES_EOFF_MAGIC = 0xF00DD1E2
    print('creating ' + fileName + ' ...')
    with open(fileName, 'wb') as out:
        out.write(struct.pack('I', VARIABLES_FILE_MAGIC))
        # payloads are looked up lazily, one per initializer, in graph order
        payloads = (graph.binaries[init.name] for init in graph.initializers)
        for payload in payloads:
            header = struct.pack('II', VARIABLES_DATA_MAGIC, len(payload))
            out.write(header)
            out.write(payload)
        out.write(struct.pack('I', VARIABLES_EOFF_MAGIC))

def generateCode(graph,argmaxOutput,outputFolder):
    """Generate the complete OpenVX project for *graph* into *outputFolder*.

    Creates *outputFolder* if needed, then writes the CMake files,
    ``annmodule.h``/``annmodule.cpp``, ``weights.bin``, the ``anntest.cpp``
    test program, ``annpython.h``/``annpython.cpp`` and the ``anntest.py``
    sample script.

    Args:
        graph: the IrGraph being converted.
        argmaxOutput: argmax option forwarded to generateTestCPP; main() sets
            it to None, a vx type name string ('vx_uint8'/'vx_uint16'), or a
            LUT array read from a text file.
        outputFolder: destination directory.
    """
    if not os.path.isdir(outputFolder):
        # makedirs (unlike mkdir) also creates missing parent directories,
        # so nested output paths such as 'build/models/net' work.
        os.makedirs(outputFolder)
    generateCMakeFiles(graph,outputFolder)
    generateModuleH(graph,os.path.join(outputFolder, 'annmodule.h'))
    generateModuleCPP(graph,os.path.join(outputFolder, 'annmodule.cpp'))
    generateBinary(graph,os.path.join(outputFolder, 'weights.bin'))
    generateTestCPP(graph,argmaxOutput,os.path.join(outputFolder, 'anntest.cpp'))
    generatePythonH(graph,os.path.join(outputFolder, 'annpython.h'))
    generatePythonCPP(graph,os.path.join(outputFolder, 'annpython.cpp'))
    generatePythonScriptSample(graph,os.path.join(outputFolder, 'anntest.py'))

def main():
    usage = """
Usage: python nnir2openvx.py [OPTIONS] <nnirInputFolder> <outputFolder>

  OPTIONS:
    --argmax UINT8                    -- argmax at the end with 8-bit output
    --argmax UINT16                   -- argmax at the end with 16-bit output
    --argmax <fileNamePrefix>rgb.txt  -- argmax at the end with RGB color mapping using LUT
    --argmax <fileNamePrefix>rgba.txt -- argmax at the end with RGBA color mapping using LUT
    --help                            -- show this help message

  LUT File Format (RGB): 8-bit R G B values one per each label in text format
    R0 G0 B0
    R1 G1 B1
    ...

  LUT File Format (RGBA): 8-bit R G B A values one per each label in text format
    R0 G0 B0 A0
    R1 G1 B1 A1
    ...

"""
    pos = 1;
    argmaxOutput = None
    while len(sys.argv[pos:]) >= 2 and sys.argv[pos][:2] == '--':
        if sys.argv[pos] == '--argmax':
            argmaxOutput = sys.argv[pos+1]
            if argmaxOutput == 'UINT8':
                argmaxOutput = 'vx_uint8'
            elif argmaxOutput == 'UINT16':
                argmaxOutput = 'vx_uint16'
            else:
                if not os.path.isfile(argmaxOutput):
                    print('ERROR: unable to open: %s' % (argmaxOutput))
                    sys.exit(1)
                with open(argmaxOutput,'r') as f:
                    if argmaxOutput[-8:] == 'rgba.txt':
                        argmaxOutput = np.reshape(np.array([int(v) for v in f.read().split()]), [-1, 4]).transpose()
                    else:
                        argmaxOutput = np.reshape(np.array([int(v) for v in f.read().split()]), [-1, 3]).transpose()
        else:
            if sys.argv[pos] != '--help':
                print('ERROR: invalid option: %s' % (sys.argv[pos]))
            print(usage)
            sys.exit(1)
        pos = pos + 2
    if len(sys.argv[pos:]) < 2:
        print(usage)
        sys.exit(1)
    inputFolder = sys.argv[pos]
    outputFolder = sys.argv[pos+1]
    print('reading IR model from ' + inputFolder + ' ...')
    graph = IrGraph()
    graph.fromFile(inputFolder)
    for tensor in graph.outputs:
        if len(tensor.shape) == 1:
            print('#OUTPUT-TENSOR: %s %d %d %d %d ' %(tensor.name, tensor.shape[0], 1, 1, 1));    
        elif len(tensor.shape) == 2:
            print('#OUTPUT-TENSOR: %s %d %d %d %d ' %(tensor.name, tensor.shape[0], tensor.shape[1], 1, 1));
        elif len(tensor.shape) == 4:
            print('#OUTPUT-TENSOR: %s %d %d %d %d ' %(tensor.name, tensor.shape[0], tensor.shape[1], tensor.shape[2], tensor.shape[3]));
    print('creating C code in ' + outputFolder + ' ...')
    generateCode(graph,argmaxOutput,outputFolder)

# Script entry point: run the converter only when this file is executed
# directly (e.g. `python nnir2openvx.py ...`), not when it is imported.
if __name__ == '__main__':
    main()
