# Copyright (c) 2018 - 2022 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

from __future__ import print_function
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from future import standard_library
standard_library.install_aliases()
from builtins import *
from builtins import str
from builtins import range
import os, sys, struct
import datetime, pytz
from nnir import *

# Maps an NNIR tensor element-type code to the name of the OpenVX type
# enumerator that is substituted into the generated C++ source wherever a
# tensor is created.
tensor_type_nnir2openvx = {
    "F032": "VX_TYPE_FLOAT32",
    "F016": "VX_TYPE_FLOAT16",
    "U016": "VX_TYPE_UINT16",
    "I016": "VX_TYPE_INT16",
    "U008": "VX_TYPE_UINT8",
    "I064": "VX_TYPE_INT64",
    "I032": "VX_TYPE_INT32",
}

def generateLicenseForCPP(f):
    """Write the MIT license banner for a generated C/C++ file to ``f``.

    The banner is a C block comment holding the license text, followed by a
    second comment that records when the file was generated.  The timestamp
    is the current time in the America/Los_Angeles zone, in ISO format.

    ``f`` only needs to provide a ``write`` method accepting text.
    """
    pacific = pytz.timezone('America/Los_Angeles')
    generated_on = datetime.datetime.now(tz=pacific).isoformat()
    banner = """/*
MIT License

Copyright (c) 2018 - 2022 Advanced Micro Devices, Inc. All rights reserved.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
*/

/* This file is generated by nnir2openvx.py on %s */
""" % generated_on
    f.write(banner)

def generateLicenseForScript(f):
    """Write the MIT license banner for a generated script file to ``f``.

    Every banner line is a ``#`` comment, so the text suits files whose
    comment marker is ``#`` (it is used for the generated CMake files).  The
    last line records the generation time: the current time in the
    America/Los_Angeles zone, in ISO format.

    ``f`` only needs to provide a ``write`` method accepting text.
    """
    pacific = pytz.timezone('America/Los_Angeles')
    generated_on = datetime.datetime.now(tz=pacific).isoformat()
    banner = """################################################################################
#
# MIT License
#
# Copyright (c) 2018 - 2022 Advanced Micro Devices, Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
################################################################################

# This file is generated by nnir2openvx.py on %s
""" % generated_on
    f.write(banner)

def generateCMakeFiles(graph,outputFolder):
    """Write the CMake build files for the generated inference module.

    Two files are produced, each starting with the script license banner:

    * ``<outputFolder>/CMakeLists.txt`` -- builds the ``annmodule`` and
      ``annpython`` shared libraries and the ``anntest`` executable.
    * ``<outputFolder>/cmake/FindOpenCL.cmake`` -- the OpenCL find module that
      the CMakeLists.txt picks up through ``CMAKE_MODULE_PATH``.

    Args:
        graph: not used by this function; kept so existing callers that pass
            it keep working.
        outputFolder: directory receiving the files.  Only the ``cmake``
            sub-directory is created here; ``outputFolder`` itself is opened
            into directly, so it has to exist already.
    """
    fileName = outputFolder + '/CMakeLists.txt'
    print('creating ' + fileName + ' ...')
    with open(fileName, 'w') as f:
        generateLicenseForScript(f)
        f.write(
"""
cmake_minimum_required (VERSION 3.0)

project (annmodule)

set(CMAKE_CXX_STANDARD 11)

set(ROCM_PATH /opt/rocm CACHE PATH "ROCm Installation Path")

list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake)

#find the OPENVX backend type
set(OPENVX_BACKEND_OPENCL_FOUND 0)
set(OPENVX_BACKEND_HIP_FOUND 0)
if(EXISTS ${ROCM_PATH}/mivisionx/include/openvx_backend.h)
    file(READ ${ROCM_PATH}/mivisionx/include/openvx_backend.h OPENVX_BACKEND_FILE)
    string(REGEX MATCH "ENABLE_OPENCL ([0-9]*)" _ ${OPENVX_BACKEND_FILE})
    set(OPENVX_BACKEND_OPENCL_FOUND ${CMAKE_MATCH_1})
    string(REGEX MATCH "ENABLE_HIP ([0-9]*)" _ ${OPENVX_BACKEND_FILE})
    set(OPENVX_BACKEND_HIP_FOUND ${CMAKE_MATCH_1})
else()
    message("-- ${Red}WARNING: ${ROCM_PATH}/mivisionx/include/openvx_backend.h file Not Found. please install the latest mivisionx! ${ColourReset}")
endif()

find_package(OpenCV QUIET)

if (OPENVX_BACKEND_OPENCL_FOUND)
    find_package(OpenCL QUIET)

    if(OpenCL_FOUND)
        message("-- Using OpenCL Library -- ${OpenCL_LIBRARIES}")
    else()
        message(FATAL_ERROR "OpenCL Required for NN Flow")
    endif()

    include_directories(${OpenCL_INCLUDE_DIRS} ${OpenCL_INCLUDE_DIRS}/Headers )
endif()

include_directories(${ROCM_PATH}/mivisionx/include)

link_directories(${ROCM_PATH}/mivisionx/lib)

list(APPEND SOURCES annmodule.cpp)
add_library(${PROJECT_NAME} SHARED ${SOURCES})
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse4.2 -mf16c -std=gnu++14")

if (OPENVX_BACKEND_OPENCL_FOUND)
    target_link_libraries(${PROJECT_NAME} openvx vx_nn pthread ${OpenCL_LIBRARIES})
else()
    target_link_libraries(${PROJECT_NAME} openvx vx_nn pthread)
endif()

add_executable(anntest anntest.cpp)
if(OpenCV_FOUND)
    if(${OpenCV_VERSION_MAJOR} EQUAL 3 OR ${OpenCV_VERSION_MAJOR} EQUAL 4)
        target_compile_definitions(anntest PUBLIC ENABLE_OPENCV=1)
        include_directories(${OpenCV_INCLUDE_DIRS})
        target_link_libraries(anntest ${OpenCV_LIBRARIES})
        if(${OpenCV_VERSION_MAJOR} EQUAL 4)
	        target_compile_definitions(anntest PUBLIC USE_OPENCV_4=1)
        else()
	        target_compile_definitions(anntest PUBLIC USE_OPENCV_4=0)
        endif()
    else()
        target_compile_definitions(anntest PUBLIC ENABLE_OPENCV=0)
        message("-- NOTE: anntest -- OpenCV Version-${OpenCV_VERSION_MAJOR}.${OpenCV_VERSION_MINOR}.X Not Supported")
    endif()
else(OpenCV_FOUND)
  target_compile_definitions(anntest PUBLIC ENABLE_OPENCV=0)
endif(OpenCV_FOUND)

if (OPENVX_BACKEND_OPENCL_FOUND)
    target_link_libraries(anntest ${PROJECT_NAME} openvx vx_nn pthread ${OpenCL_LIBRARIES})
else()
    target_link_libraries(anntest ${PROJECT_NAME} openvx vx_nn pthread)
endif()

add_library(annpython SHARED annpython.cpp)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse4.2 -mf16c -std=gnu++14")

if (OPENVX_BACKEND_OPENCL_FOUND)
    target_link_libraries(annpython ${PROJECT_NAME} openvx vx_nn pthread ${OpenCL_LIBRARIES})
else()
    target_link_libraries(annpython ${PROJECT_NAME} openvx vx_nn pthread)
endif()
""")
    # The find module lives in a 'cmake' sub-directory because the generated
    # CMakeLists.txt appends ${PROJECT_SOURCE_DIR}/cmake to CMAKE_MODULE_PATH.
    if not os.path.isdir(outputFolder + '/cmake'):
        os.mkdir(outputFolder + '/cmake')
    fileName = outputFolder + '/cmake/FindOpenCL.cmake'
    print('creating ' + fileName + ' ...')
    with open(fileName, 'w') as f:
        generateLicenseForScript(f)
        # The "ROCm OpenCL Not Found" message below used to end in a stray
        # '}' that showed up verbatim in the CMake configure output.
        f.write(
"""
include( FindPackageHandleStandardArgs )
find_package_handle_standard_args(
    OpenCL
    FOUND_VAR OpenCL_FOUND
    REQUIRED_VARS
        OpenCL_LIBRARIES
        OpenCL_INCLUDE_DIRS
        CL_TARGET_OpenCL_VERSION
    VERSION_VAR OpenCL_VERSION
)

if(OpenCL_LIBRARIES AND OpenCL_INCLUDE_DIRS)
    set(OpenCL_FOUND TRUE)
    add_definitions(-DCL_TARGET_OPENCL_VERSION=${CL_TARGET_OpenCL_VERSION})
else()
    find_path(OPENCL_INCLUDE_DIRS
        NAMES OpenCL/cl.h CL/cl.h
        HINTS
        ${OPENCL_ROOT}/include
        $ENV{AMDAPPSDKROOT}/include
        $ENV{CUDA_PATH}/include
        PATHS
        ${ROCM_PATH}/opencl/include
        /usr/include
        /usr/local/include
        /usr/local/cuda/include
        /opt/cuda/include
        DOC "OpenCL header file path"
    )
    mark_as_advanced( OPENCL_INCLUDE_DIRS )

    if("${CMAKE_SIZEOF_VOID_P}" EQUAL "8")
        find_library( OPENCL_LIBRARIES
            NAMES OpenCL
            HINTS
            ${OPENCL_ROOT}/lib
            $ENV{AMDAPPSDKROOT}/lib
            $ENV{CUDA_PATH}/lib
            DOC "OpenCL dynamic library path"
            PATH_SUFFIXES x86_64 x64 x86_64/sdk
            PATHS
            ${ROCM_PATH}/opencl/lib/
            /usr/lib
            /usr/local/cuda/lib
            /opt/cuda/lib
        )
    else( )
        find_library( OPENCL_LIBRARIES
            NAMES OpenCL
            HINTS
            ${OPENCL_ROOT}/lib
            $ENV{AMDAPPSDKROOT}/lib
            $ENV{CUDA_PATH}/lib
            DOC "OpenCL dynamic library path"
            PATH_SUFFIXES x86 Win32
            PATHS
            ${ROCM_PATH}/opencl/lib/
            /usr/lib
            /usr/local/cuda/lib
            /opt/cuda/lib
        )
    endif( )
    mark_as_advanced( OPENCL_LIBRARIES )

    if(OPENCL_LIBRARIES AND OPENCL_INCLUDE_DIRS)
        set(OPENCL_FOUND TRUE)
    endif( )

    set(OpenCL_FOUND ${OPENCL_FOUND} CACHE INTERNAL "")
    set(OpenCL_LIBRARIES ${OPENCL_LIBRARIES} CACHE INTERNAL "")
    set(OpenCL_INCLUDE_DIRS ${OPENCL_INCLUDE_DIRS} CACHE INTERNAL "")

    if(EXISTS "${ROCM_PATH}/opencl/lib/libOpenCL.so")
        if(NOT "${OPENCL_LIBRARIES}" STREQUAL "${ROCM_PATH}/opencl/lib/libOpenCL.so")
            message("-- OpenCL Found - ${OPENCL_LIBRARIES}")
            message("-- ROCm OpenCL Found - Force OpenCL_LIBRARIES & OpenCL_INCLUDE_DIRS to use ROCm OpenCL")
            set(OpenCL_LIBRARIES ${ROCM_PATH}/opencl/lib/libOpenCL.so CACHE INTERNAL "")
            set(OpenCL_INCLUDE_DIRS ${ROCM_PATH}/opencl/include CACHE INTERNAL "")
        endif()
    else()
        message("-- ROCm OpenCL Not Found")
    endif()

    if(OpenCL_FOUND)
        execute_process(
            COMMAND bash -c "nm -gDC ${OpenCL_LIBRARIES} | grep OPENCL_2.2"
            OUTPUT_VARIABLE outVar
        )
        if(NOT ${outVar} STREQUAL "")
            set(CL_TARGET_OpenCL_VERSION 220 CACHE INTERNAL "")
        else()
            message( "-- FindOpenCL failed to find: OpenCL 2.2" )
            set(CL_TARGET_OpenCL_VERSION 120 CACHE INTERNAL "")
        endif()
        add_definitions(-DCL_TARGET_OPENCL_VERSION=${CL_TARGET_OpenCL_VERSION})
        message("-- OpenCL - Setting CL_TARGET_OPENCL_VERSION=${CL_TARGET_OpenCL_VERSION}")
    endif()

    if( NOT OpenCL_FOUND )
        message( "-- FindOpenCL failed to find: OpenCL" )
    endif()
endif()
""")

def generateModuleH(graph,fileName,virtual_tensor_flag):
    """Write the C++ header that declares ``annAddToGraph`` for ``graph``.

    The header consists of the C/C++ license banner, an include guard, one
    comment line per graph input and output listing the tensor dimensions
    (the shape is written in reversed order), and the ``annAddToGraph``
    prototype taking one ``vx_tensor`` parameter per input and per output.

    Args:
        graph: object providing ``inputs`` and ``outputs``; each tensor in
            them must have ``name`` and ``shape`` attributes.
        fileName: path of the header to create; an existing file is
            overwritten.
        virtual_tensor_flag: when equal to 0 the prototype carries an extra
            ``std::map<std::string, vx_tensor> &tensorMap`` parameter before
            ``binaryFilename``; for any other value that parameter is left
            out.
    """
    print('creating ' + fileName + ' ...')
    with open(fileName, 'w') as f:
        generateLicenseForCPP(f)
        # <string> is included explicitly: the prototype below may name
        # std::string, and relying on <map> to declare it is not portable.
        f.write(
"""
#ifndef included_file_annmodule_h
#define included_file_annmodule_h

#include <VX/vx.h>
#include <map>
#include <string>

////
// initialize graph neural network for inference
""")
        for tensor in graph.inputs:
            f.write(
"""//   %s -- dims[] = { %s } (input)
""" % (tensor.name, ', '.join([str(v) for v in reversed(tensor.shape)])))
        for tensor in graph.outputs:
            f.write(
"""//   %s -- dims[] = { %s, } (output)
""" % (tensor.name, ', '.join([str(v) for v in reversed(tensor.shape)])))

        # Both prototype variants share everything except the optional
        # tensorMap parameter, so only that fragment depends on the flag.
        input_params = ', '.join(['vx_tensor ' + tensor.name for tensor in graph.inputs])
        output_params = ', '.join(['vx_tensor ' + tensor.name for tensor in graph.outputs])
        if virtual_tensor_flag == 0:
            tensor_map_param = 'std::map<std::string, vx_tensor> &tensorMap, '
        else:
            tensor_map_param = ''
        f.write(
"""//
extern "C" VX_API_ENTRY vx_status VX_API_CALL annAddToGraph(vx_graph graph, %s, %s, %sconst char * binaryFilename);

#endif
""" % (input_params, output_params, tensor_map_param))

def generateModuleCPP(graph,fileName,virtual_tensor_flag):
    print('creating ' + fileName + ' ...')
    with open(fileName, 'w') as f:
        generateLicenseForCPP(f)
        f.write( \
"""
#include "annmodule.h"
#include <VX/vx_khr_nn.h>
#include <VX/vx_compatibility.h>
#include <vx_amd_nn.h>
#include <vx_ext_amd.h>
#include <stdio.h>

#define ERROR_CHECK_OBJECT(obj) { vx_status status = vxGetStatus((vx_reference)(obj)); if(status != VX_SUCCESS) { vxAddLogEntry((vx_reference)context, status     , "ERROR: failed with status = (%d) at " __FILE__ "#%d\\n", status, __LINE__); return status; } }
#define ERROR_CHECK_STATUS(call) { vx_status status = (call); if(status != VX_SUCCESS) { vxAddLogEntry((vx_reference)context, status, "ERROR: failed with status = (%d) at " __FILE__ "#%d\\n", status, __LINE__); return status; } }

static vx_status initializeTensor(vx_context context, vx_tensor tensor, FILE * fp, const char * binaryFilename)
{
    vx_enum data_type = VX_TYPE_FLOAT32;
    vx_size num_of_dims = 4, dims[4] = { 1, 1, 1, 1 }, stride[4];
    ERROR_CHECK_STATUS(vxQueryTensor(tensor, VX_TENSOR_DATA_TYPE, &data_type, sizeof(vx_enum)));
    ERROR_CHECK_STATUS(vxQueryTensor(tensor, VX_TENSOR_NUMBER_OF_DIMS, &num_of_dims, sizeof(vx_size)));
    ERROR_CHECK_STATUS(vxQueryTensor(tensor, VX_TENSOR_DIMS, &dims, num_of_dims * sizeof(vx_size)));
    vx_size itemsize = sizeof(float);
    if(data_type == VX_TYPE_UINT8 || data_type == VX_TYPE_INT8) {
        itemsize = sizeof(vx_uint8);
    }
    else if(data_type == VX_TYPE_UINT16 || data_type == VX_TYPE_INT16 || data_type == VX_TYPE_FLOAT16) {
        itemsize = sizeof(vx_uint16);
    }
    else if(data_type == VX_TYPE_INT64) {
        itemsize = sizeof(vx_int64);
    }
    vx_size count = dims[0] * dims[1] * dims[2] * dims[3];

    vx_uint32 h[2] = { 0 };
    fread(h, 1, sizeof(h), fp);
    if(h[0] != 0xf00dd1e1 || (vx_size)h[1] != (count*itemsize)) {
      vxAddLogEntry((vx_reference)tensor, VX_FAILURE, "ERROR: invalid data (magic,size)=(0x%x,%d) in %%s at byte position %d -- expected size is %ld\\n", h[0], h[1], binaryFilename, ftell(fp)-sizeof(h), count*itemsize);
      return VX_FAILURE;
    }

    vx_map_id map_id;
    void * ptr;
    ERROR_CHECK_STATUS(vxMapTensorPatch(tensor, num_of_dims, nullptr, nullptr, &map_id, stride, (void **)&ptr, VX_WRITE_ONLY, VX_MEMORY_TYPE_HOST));
    vx_size n = fread(ptr, itemsize, count, fp);
    if(n != count) {
        vxAddLogEntry((vx_reference)tensor, VX_FAILURE, "ERROR: expected char[%ld], but got char[%ld] in %s\\n", count*itemsize, n*itemsize, binaryFilename);
        return VX_FAILURE;
    }
    ERROR_CHECK_STATUS(vxUnmapTensorPatch(tensor, map_id));

    return VX_SUCCESS;
}
""" )
        if virtual_tensor_flag == 0:
            f.write( \
"""VX_API_ENTRY vx_status VX_API_CALL annAddToGraph(vx_graph graph, %s, %s, std::map<std::string, vx_tensor> &tensorMap, const char * binaryFilename)
{
    vx_context context = vxGetContext((vx_reference)graph);
    ERROR_CHECK_OBJECT(context);
    ERROR_CHECK_STATUS(vxLoadKernels(context, "vx_nn"));

    // create variables
""" % (', '.join(['vx_tensor ' + tensor.name for tensor in graph.inputs]), \
       ', '.join(['vx_tensor ' + tensor.name for tensor in graph.outputs])))
        else:
            f.write( \
"""VX_API_ENTRY vx_status VX_API_CALL annAddToGraph(vx_graph graph, %s, %s, const char * binaryFilename)
{
    vx_context context = vxGetContext((vx_reference)graph);
    ERROR_CHECK_OBJECT(context);
    ERROR_CHECK_STATUS(vxLoadKernels(context, "vx_nn"));

    // create variables
""" % (', '.join(['vx_tensor ' + tensor.name for tensor in graph.inputs]), \
       ', '.join(['vx_tensor ' + tensor.name for tensor in graph.outputs])))
        for tensor in graph.initializers:
            f.write( \
"""    vx_size dims_%s[%d] = { %s };
    vx_tensor %s = vxCreateTensor(context, %d, dims_%s, %s, 0);
    ERROR_CHECK_OBJECT(%s);
""" %(tensor.name, len(tensor.shape), ', '.join([str(v) for v in reversed(tensor.shape)]), \
      tensor.name, len(tensor.shape), tensor.name, tensor_type_nnir2openvx[tensor.type], tensor.name))
        f.write( \
"""
    // initialize variables
    FILE * fp__variables = fopen(binaryFilename, "rb");
    if(!fp__variables) {
        vxAddLogEntry((vx_reference)context, VX_FAILURE, "ERROR: unable to open: %s\\n", binaryFilename);
        return VX_FAILURE;
    }
    { vx_uint32 magic = 0;
      fread(&magic, 1, sizeof(magic), fp__variables);
      if(magic != 0xf00dd1e0) {
        vxAddLogEntry((vx_reference)context, VX_FAILURE, "ERROR: invalid file magic in %s\\n", binaryFilename);
        return VX_FAILURE;
      }
    }
""")
        for tensor in graph.initializers:
            f.write( \
"""    ERROR_CHECK_STATUS(initializeTensor(context, %s, fp__variables, binaryFilename));
""" %(tensor.name))
        f.write( \
"""    { vx_uint32 magic = 0;
      fread(&magic, 1, sizeof(magic), fp__variables);
      if(magic != 0xf00dd1e2) {
        vxAddLogEntry((vx_reference)context, VX_FAILURE, "ERROR: invalid eoff magic in %s\\n", binaryFilename);
        return VX_FAILURE;
      }
      fclose(fp__variables);
    }

    // create local tensors used in graph
""")
        localList = []
        for tensor in graph.locals:
            localList.append(tensor.name)
        outputList = []
        for tensor in graph.outputs:
            outputList.append(tensor.name)
        for idx, tensor in enumerate(graph.locals):
            if (not tensor.name in outputList) and (not tensor.name in localList[:idx]):
                tensor.shape = [int(v) for v in tensor.shape]
                f.write( \
"""    vx_size dims_%s[%d] = { %s };
"""%(tensor.name, len(tensor.shape), ', '.join([str(v) for v in reversed(tensor.shape)])))
                if virtual_tensor_flag == 0:
                    f.write( \
"""    vx_tensor %s = vxCreateTensor(context, %d, dims_%s, %s, 0);
    tensorMap.insert(std::pair<std::string, vx_tensor>("%s", %s));        
"""%(tensor.name, len(tensor.shape), tensor.name, tensor_type_nnir2openvx[tensor.type], tensor.name, tensor.name))
                else:
                    f.write( \
"""    vx_tensor %s = vxCreateVirtualTensor(graph, %d, dims_%s, %s, 0);
"""%(tensor.name, len(tensor.shape), tensor.name, tensor_type_nnir2openvx[tensor.type]))
                f.write( \
"""    ERROR_CHECK_OBJECT(%s);
"""%(tensor.name))
      
        f.write( \
"""
    // create nodes in graph
""")
        for node in graph.nodes:
            if node.type == 'conv':
                pads = node.attr.get('pads')
                dilations = node.attr.get('dilations')
                alpha = node.attr.get('alpha')
                f.write( \
"""
    { vx_nn_convolution_params_t conv_params = { 0 };
      conv_params.padding_x = %d;
      conv_params.padding_y = %d;
      conv_params.overflow_policy = VX_CONVERT_POLICY_SATURATE;
      conv_params.rounding_policy = VX_ROUND_POLICY_TO_NEAREST_EVEN;
      conv_params.down_scale_size_rounding = VX_NN_DS_SIZE_ROUNDING_FLOOR;
      conv_params.dilation_x = %d;
      conv_params.dilation_y = %d;
      vx_node node = vxConvolutionLayer(graph, %s, %s, %s, &conv_params, sizeof(conv_params), %s);
      ERROR_CHECK_OBJECT(node);
""" % (pads[0], pads[1], dilations[0] - 1, dilations[1] - 1, \
      node.inputs[0], node.inputs[1], node.inputs[2] if len(node.inputs) == 3 else 'NULL', node.outputs[0]))
                if (node.attr.get('mode') != 0):
                    f.write( \
"""      vx_float32 alpha = %f;
      vx_scalar s_alpha = vxCreateScalarWithSize(context, VX_TYPE_FLOAT32, &alpha, sizeof(alpha));
      ERROR_CHECK_STATUS(vxSetParameterByIndex(node, 5, (vx_reference) s_alpha));
      ERROR_CHECK_STATUS(vxReleaseScalar(&s_alpha));
""" %(alpha))
                if (node.attr.get('group') > 1):
                    group = node.attr.get('group');
                    f.write( \
"""      vx_int32 groupCount = %d;
      vx_scalar s_groupCount = vxCreateScalarWithSize(context, VX_TYPE_INT32, &groupCount, sizeof(groupCount));
      ERROR_CHECK_STATUS(vxSetParameterByIndex(node, 6, (vx_reference) s_groupCount));
      ERROR_CHECK_STATUS(vxReleaseScalar(&s_groupCount));
""" % (group))
                f.write( \
"""      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""")
            elif node.type == 'conv_transpose':
                pads = node.attr.get('pads')
                dilations = node.attr.get('dilations')
                kernel_shape = node.attr.get('kernel_shape')
                output_pads = [(dilations[0] - 1) * (kernel_shape[0] - 1), \
                                (dilations[1] - 1) * (kernel_shape[1] - 1)]
                f.write( \
"""
    { vx_nn_deconvolution_params_t conv_params = { 0 };
      conv_params.padding_x = %d;
      conv_params.padding_y = %d;
      conv_params.overflow_policy = VX_CONVERT_POLICY_SATURATE;
      conv_params.rounding_policy = VX_ROUND_POLICY_TO_NEAREST_EVEN;
      conv_params.a_x = %d;
      conv_params.a_y = %d;
      vx_node node = vxDeconvolutionLayer(graph, %s, %s, %s, &conv_params, sizeof(conv_params), %s);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""" % (pads[0], pads[1], output_pads[0] , output_pads[1] , \
      node.inputs[0], node.inputs[1], node.inputs[2] if len(node.inputs) == 3 else 'NULL', node.outputs[0]))
            elif node.type == 'gemm':
                alpha = node.attr.get('alpha')
                beta = node.attr.get('beta')
                transA = node.attr.get('transA')
                transB = node.attr.get('transB')
                hasBias = False
                if beta == 1.0 and len(node.inputs) == 3 and len(graph.tensor_shapes[node.inputs[2]]) <= 2:
                    hasBias = True
                if alpha == 1.0 and transA == 0 and (beta == 0.0 or hasBias):
                    f.write( \
"""
    { vx_node node = vxFullyConnectedLayer(graph, %s, %s, %s, VX_CONVERT_POLICY_SATURATE, VX_ROUND_POLICY_TO_NEAREST_EVEN, %s);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""" % ( \
        node.inputs[0], node.inputs[1], node.inputs[2] if hasBias else 'NULL', node.outputs[0]))
                else:
                    raise ValueError("Unsupported gemm configuration by OpenVX: alpha={} beta={} transA={} transB={}".format(alpha, beta, transA, transB))
            elif node.type == 'matmul':
                alpha = node.attr.get('alpha')
                beta = node.attr.get('beta')
                transA = node.attr.get('transA')
                transB = node.attr.get('transB')
                f.write( \
"""
    { _vx_tensor_matrix_multiply_params_t matrix_mul_params = { 0 };
      matrix_mul_params.transpose_input1 = %d;
      matrix_mul_params.transpose_input2 = %d;
      matrix_mul_params.transpose_input3 = %d;
      vx_node node = vxTensorMatrixMultiplyNode(graph, %s, %s, %s, &matrix_mul_params, %s);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""" % ( \
        1 if transA else 0, 1 if transB else 0, 0, node.inputs[0], node.inputs[1], node.inputs[2] if beta else 'NULL', node.outputs[0]))
            elif node.type == 'max_pool' or node.type == 'avg_pool':
                f.write( \
"""
    { vx_node node = vxPoolingLayer(graph, %s, %s, %d, %d, %d, %d, VX_ROUND_POLICY_TO_NEAREST_EVEN, %s);
      ERROR_CHECK_OBJECT(node);
      vx_enum border_mode = %d;
      vx_scalar s_border_mode = vxCreateScalarWithSize(context, VX_TYPE_ENUM, &border_mode, sizeof(border_mode));
      ERROR_CHECK_OBJECT(s_border_mode);
      ERROR_CHECK_STATUS(vxSetParameterByIndex(node, 8, (vx_reference) s_border_mode));
      ERROR_CHECK_STATUS(vxReleaseScalar(&s_border_mode));
""" % (node.inputs[0], 'VX_NN_POOLING_AVG' if node.type == 'avg_pool' else 'VX_NN_POOLING_MAX', \
       node.attr.get('kernel_shape')[0], node.attr.get('kernel_shape')[1], \
       node.attr.get('pads')[0], node.attr.get('pads')[1], node.outputs[0], \
       (1 if node.attr.get('border_mode') == 'discard' else 0)))
                if (node.attr.get('mode') != 0):
                    f.write( \
"""      vx_int32 mode = %s;
      vx_scalar s_mode = vxCreateScalarWithSize(context, VX_TYPE_INT32, &mode, sizeof(mode));
      ERROR_CHECK_STATUS(vxSetParameterByIndex(node, 9, (vx_reference) s_mode));
      ERROR_CHECK_STATUS(vxReleaseScalar(&s_mode));
""" % (node.attr.get('mode')))
                f.write( \
"""      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""")
            elif node.type == 'global_avg_pool':
                f.write( \
"""
    { vx_node node = vxPoolingLayer(graph, %s, VX_NN_POOLING_AVG, %d, %d, %d, %d, VX_ROUND_POLICY_TO_NEAREST_EVEN, %s);
      ERROR_CHECK_OBJECT(node);
""" % (node.inputs[0], graph.tensor_shapes[node.inputs[0]][2], graph.tensor_shapes[node.inputs[0]][3], \
       node.attr.get('pads')[0], node.attr.get('pads')[1], node.outputs[0]))
                if (node.attr.get('mode') != 0):
                    f.write( \
"""      vx_int32 mode = %s;
      vx_scalar s_mode = vxCreateScalarWithSize(context, VX_TYPE_INT32, &mode, sizeof(mode));
      ERROR_CHECK_STATUS(vxSetParameterByIndex(node, 9, (vx_reference) s_mode));
      ERROR_CHECK_STATUS(vxReleaseScalar(&s_mode));
""" % (node.attr.get('mode')))
                f.write( \
"""      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""")
            elif node.type == 'relu':
                f.write( \
"""
    { vx_node node = vxActivationLayer(graph, %s, VX_NN_ACTIVATION_RELU, 0.0f, 0.0f, %s);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""" % (node.inputs[0], node.outputs[0]))
            elif node.type == 'leaky_relu':
                f.write( \
"""
    {  vx_node node = vxActivationLayer(graph, %s, VX_NN_ACTIVATION_LEAKY_RELU, %f, 0.0f, %s);
       ERROR_CHECK_OBJECT(node);
       ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""" % (node.inputs[0], node.attr.get('alpha'), node.outputs[0]))
            elif node.type == 'sigmoid':
                f.write( \
"""
    { vx_node node = vxActivationLayer(graph, %s, VX_NN_ACTIVATION_LOGISTIC, 0.0f, 0.0f, %s);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""" % (node.inputs[0], node.outputs[0]))
            elif node.type == 'add' or node.type == 'sum':
                if len(node.inputs) == 2:
                    f.write( \
"""
    { vx_node node = vxTensorAddNode(graph, %s, %s, VX_CONVERT_POLICY_SATURATE, %s);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""" % (node.inputs[0], node.inputs[1], node.outputs[0]))
                else:
                    raise ValueError("Unsupported number of input arguments by OpenVX: {}".format(node.type))
            elif node.type == 'sub':
                if len(node.inputs) == 2:
                    f.write( \
"""
    { vx_node node = vxTensorSubtractNode(graph, %s, %s, VX_CONVERT_POLICY_SATURATE, %s);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""" % (node.inputs[0], node.inputs[1], node.outputs[0]))
                else:
                    raise ValueError("Unsupported number of input arguments by OpenVX: {}".format(node.type))
            elif node.type == 'mul':
                if len(node.inputs) == 2:
                    f.write( \
"""
    { vx_float32 value = 1.0f;
      vx_scalar scale = vxCreateScalar(context, VX_TYPE_FLOAT32, &value);
      ERROR_CHECK_OBJECT(scale);
      vx_node node = vxTensorMultiplyNode(graph, %s, %s, scale, VX_CONVERT_POLICY_SATURATE, VX_ROUND_POLICY_TO_NEAREST_EVEN, %s);
      ERROR_CHECK_STATUS(vxReleaseScalar(&scale));
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""" % (node.inputs[0], node.inputs[1], node.outputs[0]))
                else:
                    raise ValueError("Unsupported number of input arguments by OpenVX: {}".format(node.type))
            elif node.type == 'muladd':
                tensor = graph.tensor_dict[node.inputs[0]]
                f.write( \
"""
    { vx_float32 value = 1.0f;
      vx_scalar scale = vxCreateScalar(context, VX_TYPE_FLOAT32, &value);
      ERROR_CHECK_OBJECT(scale);
      vx_size dims[%d] = { %s };
      vx_tensor tmp__tensor = vxCreateVirtualTensor(graph, %d, dims, %s, 0);
      ERROR_CHECK_OBJECT(tmp__tensor);
      vx_node node = vxTensorMultiplyNode(graph, %s, %s, scale, VX_CONVERT_POLICY_SATURATE, VX_ROUND_POLICY_TO_NEAREST_EVEN, tmp__tensor);
      ERROR_CHECK_STATUS(vxReleaseScalar(&scale));
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
      node = vxTensorAddNode(graph, tmp__tensor, %s, VX_CONVERT_POLICY_SATURATE, %s);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""" % (len(tensor.shape), ', '.join([str(v) for v in reversed(tensor.shape)]), len(tensor.shape), \
       tensor_type_nnir2openvx[tensor.type], node.inputs[0], node.inputs[1], node.inputs[2], node.outputs[0]))
            elif node.type == 'min':
                if len(node.inputs) == 2:
                    f.write( \
"""
    { vx_node node = vxTensorMinNode(graph, %s, %s, VX_CONVERT_POLICY_SATURATE, %s);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""" % (node.inputs[0], node.inputs[1], node.outputs[0]))
                else:
                    raise ValueError("Unsupported number of input arguments by OpenVX: {}".format(node.type))
            elif node.type == 'max':
                if len(node.inputs) == 2:
                    f.write( \
"""
    { vx_node node = vxTensorMaxNode(graph, %s, %s, VX_CONVERT_POLICY_SATURATE, %s);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""" % (node.inputs[0], node.inputs[1], node.outputs[0]))
                else:
                    raise ValueError("Unsupported number of input arguments by OpenVX: {}".format(node.type))
            elif node.type == 'clamp':
                if len(node.inputs) == 3:
                    tensor = graph.tensor_dict[node.inputs[0]]
                    f.write( \
"""
    { vx_size dims[%d] = { %s };
      vx_tensor tmp__tensor = vxCreateVirtualTensor(graph, %d, dims, %s, 0);
      ERROR_CHECK_OBJECT(tmp__tensor);
      vx_node node = vxTensorMaxNode(graph, %s, %s, VX_CONVERT_POLICY_SATURATE, tmp__tensor);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
      node = vxTensorMinNode(graph, tmp__tensor, %s, VX_CONVERT_POLICY_SATURATE, %s);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""" % (len(tensor.shape), ', '.join([str(v) for v in reversed(tensor.shape)]), len(tensor.shape), \
       tensor_type_nnir2openvx[tensor.type], node.inputs[0], node.inputs[1], node.inputs[2], node.outputs[0]))
                elif len(node.inputs) == 1:
                    tensor = graph.tensor_dict[node.inputs[0]]
                    max_value = node.attr.get('max')
                    min_value = node.attr.get('min')
                    f.write( \
"""
    { vx_size tensor_dims[%d] = { %s };
      vx_tensor max_tensor = vxCreateTensor(context, 4, tensor_dims, VX_TYPE_FLOAT32, 0);
      ERROR_CHECK_OBJECT(max_tensor);
      vx_tensor min_tensor = vxCreateTensor(context, 4, tensor_dims, VX_TYPE_FLOAT32, 0);
      ERROR_CHECK_OBJECT(min_tensor);
      float *max_data = new float[tensor_dims[0]*tensor_dims[1]*tensor_dims[2]*tensor_dims[3]];
      float *min_data = new float[tensor_dims[0]*tensor_dims[1]*tensor_dims[2]*tensor_dims[3]];
      for(int i = 0; i < tensor_dims[0]*tensor_dims[1]*tensor_dims[2]*tensor_dims[3]; i++)
      {
      	max_data[i] = %f;
      	min_data[i] = %f;
      }
      vx_size stride[4] = { sizeof(float), sizeof(float)*tensor_dims[0] , sizeof(float)*tensor_dims[0]*tensor_dims[1], sizeof(float)*tensor_dims[0]*tensor_dims[1]*tensor_dims[2]  };
      ERROR_CHECK_STATUS(vxCopyTensorPatch(max_tensor, 4, nullptr, nullptr, stride, max_data, VX_WRITE_ONLY, VX_MEMORY_TYPE_HOST));
      ERROR_CHECK_STATUS(vxCopyTensorPatch(min_tensor, 4, nullptr, nullptr, stride, min_data, VX_WRITE_ONLY, VX_MEMORY_TYPE_HOST));
      vx_tensor tmp__tensor = vxCreateVirtualTensor(graph, %d, tensor_dims, %s, 0);
      ERROR_CHECK_OBJECT(tmp__tensor);
      vx_node node = vxTensorMinNode(graph, %s, max_tensor, VX_CONVERT_POLICY_SATURATE, tmp__tensor);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
      node = vxTensorMaxNode(graph, tmp__tensor, min_tensor, VX_CONVERT_POLICY_SATURATE, %s);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
      
      delete [] max_data;
      delete [] min_data;
    }
""" % (len(tensor.shape), ', '.join([str(v) for v in reversed(tensor.shape)]), max_value, min_value, len(tensor.shape), \
       tensor_type_nnir2openvx[tensor.type], node.inputs[0], node.outputs[0]))
                else:
                    raise ValueError("Unsupported number of input arguments by OpenVX: {}".format(node.type))
            elif node.type == 'exp':
                if len(node.inputs) == 1:
                    f.write( \
"""
    { vx_node node = vxTensorExpNode(graph, %s, %s);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""" % (node.inputs[0], node.outputs[0]))
                else:
                    raise ValueError("Unsupported number of input arguments by OpenVX: {}".format(node.type))
            elif node.type == 'log':
                if len(node.inputs) == 1:
                    f.write( \
"""
    { vx_node node = vxTensorLogNode(graph, %s, %s);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""" % (node.inputs[0], node.outputs[0]))
                else:
                    raise ValueError("Unsupported number of input arguments by OpenVX: {}".format(node.type))
            elif node.type == 'batch_norm':
                f.write( \
"""
    { vx_node node = vxBatchNormalizationLayer(graph, %s, %s, %s, %s, %s, %ef, %s);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""" % (node.inputs[0], node.inputs[3], node.inputs[4], node.inputs[1], node.inputs[2], node.attr.get('epsilon'), node.outputs[0]))
            elif node.type == 'lrn':
                f.write( \
"""
    { vx_node node = vxNormalizationLayer(graph, %s, %s , %d, %ef, %ef, %s);
""" % (node.inputs[0], "VX_NN_NORMALIZATION_SAME_MAP" if node.attr.get('mode') == 0 else "VX_NN_NORMALIZATION_ACROSS_MAPS" , \
       node.attr.get('size'), node.attr.get('alpha'), node.attr.get('beta'), node.outputs[0]))
                if (node.attr.get('bias') != 1.0):
                    f.write( \
"""   vx_float32 bias = %s;
      vx_scalar s_bias = vxCreateScalarWithSize(context, VX_TYPE_FLOAT32, &bias, sizeof(bias));
      ERROR_CHECK_STATUS(vxSetParameterByIndex(node, 6, (vx_reference) s_bias));
      ERROR_CHECK_STATUS(vxReleaseScalar(&s_bias));
""" % (node.attr.get('bias')))
                f.write( \
"""   ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""")                        
            elif node.type == 'slice':
                f.write( \
"""
    { vx_node node = vxSliceLayer(graph, %s, %s, %s);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""" % (node.inputs[0], ', '.join([name for name in node.outputs]), \
       ', '.join(['NULL' for i in range(8 - len(node.outputs))])))
            elif node.type == 'concat':
                f.write( \
"""
    { vx_node node = vxConcatLayer(graph, %s, %s, %s, %d);
      ERROR_CHECK_OBJECT(node);ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""" % (node.outputs[0], ', '.join([name for name in node.inputs]), \
       ', '.join(['NULL' for i in range(8 - len(node.inputs))]), node.attr.get('axis')))                    
            elif node.type == 'softmax':
                f.write( \
"""
    { vx_node node = vxSoftmaxLayer(graph, %s, %s);
      ERROR_CHECK_OBJECT(node);
""" % (node.inputs[0], node.outputs[0]))
                if (node.attr.get('axis') > 1):
                    axis = node.attr.get('axis');
                    f.write( \
"""      vx_int32 axis = %d;
      vx_scalar s_axis = vxCreateScalarWithSize(context, VX_TYPE_INT32, &axis, sizeof(axis));
      ERROR_CHECK_STATUS(vxSetParameterByIndex(node, 2, (vx_reference) s_axis));
      ERROR_CHECK_STATUS(vxReleaseScalar(&s_axis));
""" % (axis))
                f.write( \
"""      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""")
            elif node.type == 'reshape' or node.type == 'flatten':
                f.write( \
"""
    { vx_node node = vxReshapeLayer(graph, %s, %s);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""" % (node.inputs[0], node.outputs[0]))
            elif node.type == 'copy':
                f.write( \
"""
    { vx_node node = vxCopyNode(graph, (vx_reference)%s, (vx_reference)%s);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""" % (node.inputs[0], node.outputs[0]))
            elif node.type == 'transpose' or node.type == 'permute': 
                if node.type == 'transpose':
                    order_list = node.attr.get('axes')
                elif node.type == 'permute':
                    order_list = node.attr.get('order')
                f.write( \
"""
    { 
      int order_value[4] = {%d,%d,%d,%d}; 
      vx_array order =  vxCreateArray(context, VX_TYPE_INT32, 4);
      ERROR_CHECK_STATUS(vxTruncateArray(order,0));
      int *order_ptr = &order_value[0];
      ERROR_CHECK_STATUS(vxAddArrayItems(order, 4, order_ptr, sizeof(int)));
      vx_node node = vxPermuteLayer(graph, %s, order, %s);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""" % (order_list[0],order_list[1],order_list[2],order_list[3],node.inputs[0], node.outputs[0]))
            elif node.type == 'prior_box':
                aspect_ratio = node.attr.get('aspect_ratio')
                aspect_ratio_len = len(aspect_ratio)
                variance = node.attr.get('variance')
                max_size = node.attr.get('max_size')
                f.write( \
"""
    { 
""")
                if(aspect_ratio_len == 2):
                  f.write( \
"""     
      float aspect_ratio_value[2] = {%f,%f};
      vx_array aspect_ratio =  vxCreateArray(context, VX_TYPE_FLOAT32, 2);
      ERROR_CHECK_STATUS(vxTruncateArray(aspect_ratio,0));
      float *aspect_ratio_ptr = &aspect_ratio_value[0];
      ERROR_CHECK_STATUS(vxAddArrayItems(aspect_ratio, 2, aspect_ratio_ptr, sizeof(float)));
""" %(aspect_ratio[0], aspect_ratio[1])) 
                elif aspect_ratio_len == 1:
                  f.write( \
"""     
      float aspect_ratio_value[1] = {%f};
      vx_array aspect_ratio =  vxCreateArray(context, VX_TYPE_FLOAT32, 1);
      ERROR_CHECK_STATUS(vxTruncateArray(aspect_ratio,0));
      float *aspect_ratio_ptr = &aspect_ratio_value[0];
      ERROR_CHECK_STATUS(vxAddArrayItems(aspect_ratio, 1, aspect_ratio_ptr, sizeof(float)));
""" %(aspect_ratio[0])) 
                f.write( \
"""
      float variance_value[4] = {%f,%f,%f,%f}; 
      vx_array variance =  vxCreateArray(context, VX_TYPE_FLOAT32, 4); 
      ERROR_CHECK_STATUS(vxTruncateArray(variance,0));
      float *variance_ptr = &variance_value[0];
      ERROR_CHECK_STATUS(vxAddArrayItems(variance, 4, variance_ptr, sizeof(float)));
      vx_node node = vxPriorBoxLayer(graph, %s, %s, %f, aspect_ratio , %d, %d, %f, %s, variance, %f);
      ERROR_CHECK_OBJECT(node); 
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""" % (variance[0],variance[1],variance[2],variance[3], node.inputs[0], node.inputs[1], node.attr.get('min_size'), node.attr.get('flip'),\
        node.attr.get('clip'), node.attr.get('prior_offset'), node.outputs[0], max_size))
            elif node.type == 'upsample':
                zoom_factor = node.attr.get('zoom_factor')
                if zoom_factor == 2:
                    f.write( \
"""
    { vx_node node = vxUpsampleNearestLayer(graph, %s, %s);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }    
""" % (node.inputs[0], node.outputs[0]))
                else:
                    raise ValueError("Unsupported scaling factor: {}".format(factor))
            elif node.type == 'crop':
                offset = node.attr.get('offset')
                f.write( \
"""
    { 
      vx_int32 axis = %d;
      vx_int32 offset1 = %d;
      vx_int32 offset2 = %d;
      vx_int32 offset3 = %d;
      vx_int32 offset4 = %d;
      vx_scalar s_axis = vxCreateScalarWithSize(context, VX_TYPE_INT32, &axis, sizeof(axis));      
      vx_scalar s_offset1 = vxCreateScalarWithSize(context, VX_TYPE_INT32, &offset1, sizeof(offset1));
      vx_scalar s_offset2 = vxCreateScalarWithSize(context, VX_TYPE_INT32, &offset2, sizeof(offset2));
      vx_scalar s_offset3 = vxCreateScalarWithSize(context, VX_TYPE_INT32, &offset3, sizeof(offset3));
      vx_scalar s_offset4 = vxCreateScalarWithSize(context, VX_TYPE_INT32, &offset4, sizeof(offset4));
      vx_node node = vxCropLayer(graph, %s, %s, %s, s_axis, s_offset1, s_offset2, s_offset3, s_offset4);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }    
""" 
    % (node.attr.get('axis'), offset[0], offset[1], offset[2], offset[3], node.inputs[0], node.inputs[1], node.outputs[0]))
            elif node.type == 'crop_and_resize':
                f.write( \
"""
    { 
      vx_node node = vxCropAndResizeLayer(graph, %s, %s, %d, %d, %d, %d, %d, %d);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }    
""" 
    % (node.inputs[0], node.outputs[0], node.attr.get('coord')[0], node.attr.get('coord')[1], node.attr.get('shape')[0], node.attr.get('shape')[1], node.attr.get('scale'), node.attr.get('mode')))
            elif node.type == 'cast':
                to = node.attr.get('to')
                f.write( \
"""
    { 
      vx_node node = vxCastLayer(graph, %s, %d, %s);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }    
""" % (node.inputs[0], to, node.outputs[0]))
            elif node.type == 'argmax':
                f.write( \
"""
    { 
      vx_node node = vxArgmaxLayer(graph, %s, (vx_reference)%s);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }    
"""  
    % (node.inputs[0], node.outputs[0]))
            elif node.type == 'topk':
                f.write( \
"""
    { 
      vx_node node = vxTopkLayer(graph, %s, %s, %d, %d, %d, %s, %s);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }    
"""  
    % (node.inputs[0], node.inputs[1], node.attr.get('axis'), node.attr.get('largest'), node.attr.get('sorted'), node.outputs[0], node.outputs[1]))
            elif node.type == 'nms':
                f.write( \
"""
    { 
      vx_node node = vxNMSLayer(graph, %s, %s, %d, %s, %s, %s, %s);
      ERROR_CHECK_OBJECT(node);     
        
"""
    % (node.inputs[0], node.inputs[1], node.attr.get('center_point_box'), node.outputs[0], node.inputs[2], node.inputs[3], node.inputs[4]))
            elif node.type == 'detection_output':
                f.write( \
"""
    { vx_node node = vxDetectionOutputLayer(graph, %s, %s, %s, %d, %d, %d, %f, %s, %d, %d, %s);
      ERROR_CHECK_OBJECT(node);     
        
"""
    % (node.inputs[0], node.inputs[1], node.inputs[2], node.attr.get('num_classes'), node.attr.get('share_location'), node.attr.get('background_label_id'), \
         node.attr.get('nms_threshold'), node.attr.get('code_type'), node.attr.get('keep_top_k'), node.attr.get('variance_encoded_in_target'), node.outputs[0]))
                if (node.attr.get('top_k') > -1):
                    top_k = node.attr.get('top_k');
                    f.write( \
"""      vx_int32 top_k = %d;
      vx_scalar s_topK = vxCreateScalarWithSize(context, VX_TYPE_INT32, &top_k, sizeof(top_k));
      ERROR_CHECK_STATUS(vxSetParameterByIndex(node, 12, (vx_reference) s_topK));
      ERROR_CHECK_STATUS(vxReleaseScalar(&s_topK));
""" % (top_k))
                if (node.attr.get('confidence_threshold') > -sys.float_info.max):
                    confidence_threshold = node.attr.get('confidence_threshold');
                    f.write( \
"""      vx_float32 confidence_threshold = %f;
      vx_scalar s_confidence_threshold = vxCreateScalarWithSize(context, VX_TYPE_FLOAT32, &confidence_threshold, sizeof(confidence_threshold));
      ERROR_CHECK_STATUS(vxSetParameterByIndex(node, 13, (vx_reference) s_confidence_threshold));
      ERROR_CHECK_STATUS(vxReleaseScalar(&s_confidence_threshold));
""" % (confidence_threshold))
                if (node.attr.get('eta') > 0.0 and node.attr.get('eta') <= 1.0):
                    eta = node.attr.get('eta');
                    f.write( \
"""      vx_float32 eta = %f;
      vx_scalar s_eta = vxCreateScalarWithSize(context, VX_TYPE_FLOAT32, &eta, sizeof(eta));
      ERROR_CHECK_STATUS(vxSetParameterByIndex(node, 11, (vx_reference) s_eta));
      ERROR_CHECK_STATUS(vxReleaseScalar(&s_eta));
""" % (eta))
                f.write( \
"""      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""")
            elif node.type == 'gather':
                f.write( \
"""
    { 
      vx_int32 axis = %d;
      vx_scalar s_axis = vxCreateScalarWithSize(context, VX_TYPE_INT32, &axis, sizeof(axis));      
      vx_node node = vxGatherLayer(graph, %s, %s, %s, s_axis);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }    
""" 
    % (node.attr.get('axis'), node.inputs[0], node.inputs[1], node.outputs[0]))
            elif node.type == 'tile':
                f.write( \
"""
    { 
      vx_node node = vxTileLayer(graph, %s, %s, %s);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }    
""" 
    % (node.inputs[0], node.inputs[1], node.outputs[0]))
            elif node.type == 'less':
                f.write( \
"""
    { 
      vx_node node = vxTensorCompareLayer(graph, %s, %s, %s, 0);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }    
""" 
    % (node.inputs[0], node.inputs[1], node.outputs[0]))
            elif node.type == 'greater':
                f.write( \
"""
    { 
      vx_node node = vxTensorCompareLayer(graph, %s, %s, %s, 1);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }    
""" 
    % (node.inputs[0], node.inputs[1], node.outputs[0]))
            elif node.type == 'less_equal':
                f.write( \
"""
    { 
      vx_node node = vxTensorCompareLayer(graph, %s, %s, %s, 2);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }    
""" 
    % (node.inputs[0], node.inputs[1], node.outputs[0]))
            elif node.type == 'greater_equal':
                f.write( \
"""
    { 
      vx_node node = vxTensorCompareLayer(graph, %s, %s, %s, 3);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }    
""" 
    % (node.inputs[0], node.inputs[1], node.outputs[0]))
            elif node.type == 'equal':
                f.write( \
"""
    { 
      vx_node node = vxTensorCompareLayer(graph, %s, %s, %s, 4);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }    
""" 
    % (node.inputs[0], node.inputs[1], node.outputs[0]))
            elif node.type == 'not_equal':
                f.write( \
"""
    { 
      vx_node node = vxTensorCompareLayer(graph, %s, %s, %s, 5);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }    
""" 
    % (node.inputs[0], node.inputs[1], node.outputs[0]))
            elif node.type == 'reduce_min':
                axes = node.attr.get('axes')
                axes_len = -1
                if axes is None:
                    axes = 4 #since cpp doesn't recognize None. And axes values can range between [-4,3]
                    axes_len = 0
                else:
                    axes_len = len(axes)
                if axes_len == 0:
                    f.write(\
"""
    {
      vx_node node = vxReduceMinLayer(graph, %s, %d, %d, %s);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""" % (node.inputs[0], axes, node.attr.get('keepdims'), node.outputs[0]))
                else:
                    if axes_len == 1:
                        f.write( \
"""
    { 
      int axes_list[1] = {%d};
    }
""" % (axes[0]))
                    elif axes_len == 2:
                        f.write( \
"""
    { 
      int axes_list[2] = {%d, %d};
    }
""" % (axes[0], axes[1]))
                    elif axes_len == 3:
                        f.write( \
"""
    { 
      int axes_list[3] = {%d, %d, %d};
    }
""" % (axes[0], axes[1], axes[2]))
                    elif axes_len == 4:
                        f.write( \
"""
    { 
      int axes_list[4] = {%d, %d, %d, %d};
    }
""" % (axes[0], axes[1], axes[2], axes[3]))
                    f.write( \
"""
    { 
      vx_array axes =  vxCreateArray(context, VX_TYPE_INT32, %d);
      ERROR_CHECK_STATUS(vxTruncateArray(axes,0));
      int *axes_ptr = &axes_list[0];
      ERROR_CHECK_STATUS(vxAddArrayItems(axes, %d, axes_ptr, sizeof(int)));
      vx_node node = vxReduceMinLayer(graph, %s, axes, %d, %s);
      ERROR_CHECK_OBJECT(node);
      ERROR_CHECK_STATUS(vxReleaseNode(&node));
    }
""" % (axes_len, axes_len, node.inputs[0], node.attr.get('keepdims'), node.outputs[0]))
            else:
                raise ValueError("Unsupported node by OpenVX: {}".format(node.type))
        f.write( \
"""
    // release local tensors
""")
        for idx, tensor in enumerate(graph.locals):
            if (not tensor.name in outputList) and (not tensor.name in localList[:idx]):
                f.write( \
"""    ERROR_CHECK_STATUS(vxReleaseTensor(&%s));
""" %(tensor.name))
        f.write( \
"""
    // release initializer tensors
""")
        for tensor in graph.initializers:
            f.write( \
"""    ERROR_CHECK_STATUS(vxReleaseTensor(&%s));
""" %(tensor.name))
        f.write( \
"""
    return VX_SUCCESS;
}
""")

def generatePythonH(graph, fileName, virtual_tensor_flag):
    """Write the C++ header that declares the python inference interface.

    The emitted file starts with the license text produced by
    generateLicenseForCPP, followed by an include guard
    (included_file_annpython_h), the pyif_ann_handle_t struct and the
    extern "C" prototypes of the ann* entry points.

    Args:
        graph: only ``graph.outputs`` is read here; one extra
            ``annCopyFromInferenceOutput_<i>`` prototype is declared for
            every output after the first.
        fileName: path of the header to create; the file is opened in
            'w' mode, so existing content is replaced.
        virtual_tensor_flag: when equal to 0 the header additionally gets
            the ``tensorMap`` struct member plus the ``annQueryLocals`` and
            ``annCopyFromInferenceLocal`` prototypes.
    """
    print('creating ' + fileName + ' ...')
    # The local-tensor declarations are emitted only for a flag of exactly 0.
    expose_locals = virtual_tensor_flag == 0
    with open(fileName, 'w') as header:
        generateLicenseForCPP(header)
        emit = header.write

        # Include guard, includes and the fixed members of the handle struct.
        emit(''.join((
            '\n',
            '#ifndef included_file_annpython_h\n',
            '#define included_file_annpython_h\n',
            '\n',
            '#include <VX/vx.h>\n',
            '#include <map>\n',
            '#include <string>\n',
            '#include <half.hpp>\n',
            'using half_float::half;\n',
            '\n',
            '////\n',
            '// python interface handle: upto 8 outputs\n',
            '//\n',
            'typedef struct pyif_ann_handle_t {\n',
            '    vx_context  context;\n',
            '    vx_graph    graph;\n',
            '    vx_tensor   input;\n',
            '    vx_tensor   output[8];\n',
            '    int         num_output;\n',
        )))
        if expose_locals:
            emit('    std::map<std::string, vx_tensor> tensorMap;\n')

        # Close the struct and begin the list of exported functions.
        emit(''.join((
            '} * pyif_ann_handle;\n',
            '\n',
            '////\n',
            '// python interface functions\n',
            '//\n',
            'extern "C" VX_API_ENTRY const char *    VX_API_CALL annQueryInference();\n',
        )))
        if expose_locals:
            emit('extern "C" VX_API_ENTRY const char *    VX_API_CALL annQueryLocals();\n')

        emit(''.join((
            'extern "C" VX_API_ENTRY pyif_ann_handle VX_API_CALL annCreateInference(const char * binaryFilename);\n',
            'extern "C" VX_API_ENTRY int             VX_API_CALL annReleaseInference(pyif_ann_handle handle);\n',
            'extern "C" VX_API_ENTRY int             VX_API_CALL annCopyToInferenceInput(pyif_ann_handle handle, float * inp_ptr, size_t inp_size, bool is_nhwc);\n',
            'extern "C" VX_API_ENTRY int             VX_API_CALL annCopyFromInferenceOutput(pyif_ann_handle handle, float * out_ptr, size_t out_size);\n',
        )))

        # Output 0 uses the unsuffixed prototype above; the remaining
        # outputs get a numbered variant each.
        extra_output_decl = 'extern "C" VX_API_ENTRY int             VX_API_CALL annCopyFromInferenceOutput_%d(pyif_ann_handle handle, float * out_ptr, size_t out_size);\n'
        for output_index in range(1, len(graph.outputs)):
            emit(extra_output_decl % (output_index))

        if expose_locals:
            emit('extern "C" VX_API_ENTRY int             VX_API_CALL annCopyFromInferenceLocal(pyif_ann_handle handle, const char *tensorName, float * out_ptr, size_t out_size);\n')

        # Final prototype and the end of the include guard.
        emit(''.join((
            'extern "C" VX_API_ENTRY int             VX_API_CALL annRunInference(pyif_ann_handle handle, int num_iterations);\n',
            '\n',
            '#endif\n',
        )))

def generatePythonCPP(graph,fileName, virtual_tensor_flag):
    print('creating ' + fileName + ' ...')
    with open(fileName, 'w') as f:
        generateLicenseForCPP(f)
        if len(graph.inputs) != 1: #or len(graph.outputs) != 1:
            f.write( \
"""
#include "annpython.h"

VX_API_ENTRY const char * VX_API_CALL annQueryInference()
{
    return "Unsupported";
}

VX_API_ENTRY pyif_ann_handle VX_API_CALL annCreateInference(const char * binaryFilename)
{
    return NULL;
}

VX_API_ENTRY int VX_API_CALL annReleaseInference(pyif_ann_handle handle)
{
    return -1;
}

VX_API_ENTRY int VX_API_CALL annRunInference(pyif_ann_handle handle, float * inp_ptr, size_t inp_size, float * out_ptr, size_t out_size)
{
    return -1;
}
""")
        else:
            input_shape = graph.inputs[0].shape
            input_data_type = graph.inputs[0].type
            if input_data_type == "F032":
                input_buf_size = eval('*'.join([str(v) for v in input_shape])) * 4
            elif input_data_type == "F016":
                input_buf_size = eval('*'.join([str(v) for v in input_shape])) * 2
            output_shape = []
            output_buf_size = []
            output_str = []
            config = 'input,' + graph.inputs[0].name + ',' + ','.join(str(v) for v in input_shape) + ';'
            for i in range(len(graph.outputs)):
                output_shape.append(graph.outputs[i].shape)
                if input_data_type == "F032":
                    output_buf_size.append(eval('*'.join([str(v) for v in output_shape[i]])) * 4)
                if input_data_type == "F016":
                    output_buf_size.append(eval('*'.join([str(v) for v in output_shape[i]])) * 2)
                config += 'output' + str(i) + ',' +graph.outputs[i].name + ',' + ','.join(str(v) for v in output_shape[i])+';'
                output_str.append('handle->output[' + str(i) + ']')
            local_shape = []
            local_buf_size = []
            configLocals = ''
            for i in range(len(graph.locals)):
                local_shape.append(graph.locals[i].shape)
                local_buf_size.append(eval('*'.join([str(v) for v in local_shape[i]])) * 4)
                configLocals += 'local' + str(i) + ',' +graph.locals[i].name + ',' + ','.join(str(v) for v in local_shape[i])+';'
            f.write( \
"""
#include "annpython.h"
#include "annmodule.h"
#include <stdio.h>
#include <string.h>
#include <string>
#include <vx_ext_amd.h>

static void VX_CALLBACK log_callback(vx_context context, vx_reference ref, vx_status status, const vx_char string[])
{
    size_t len = strlen(string);
    if (len > 0) {
        printf("%%s", string);
        if (string[len - 1] != '\\n')
            printf("\\n");
        fflush(stdout);
    }
}

VX_API_ENTRY const char * VX_API_CALL annQueryInference()
{
    return "%s";
}
""" % (config))
            if virtual_tensor_flag == 0:
                f.write( \
"""VX_API_ENTRY const char * VX_API_CALL annQueryLocals()
{
    return "%s";
}
""" % (configLocals));

            f.write( \
"""
VX_API_ENTRY pyif_ann_handle VX_API_CALL annCreateInference(const char * binaryFilename)
{
    bool successful = false;

    pyif_ann_handle handle = new pyif_ann_handle_t();
    if(!handle) {
        printf("ERROR: new pyif_ann_handle: failed (nullptr)\\n");
    }
    else {
        vx_status status;
        vxRegisterLogCallback(NULL, log_callback, vx_false_e);
        handle->context = vxCreateContext();
        if((status = vxGetStatus((vx_reference)handle->context)) != VX_SUCCESS) {
            printf("ERROR: vxCreateContext: failed (%%d)\\n", status);
        }
        else {
            handle->graph = vxCreateGraph(handle->context);
            if((status = vxGetStatus((vx_reference)handle->graph)) != VX_SUCCESS) {
                printf("ERROR: vxCreateGraph: failed (%%d)\\n", status);
            }
            else {
                vx_size inp_dim[4] = { %s };
""" %(', '.join([str(v) for v in reversed(input_shape)])))
            if input_data_type == "F032":
                f.write( \
"""                handle->input = vxCreateTensor(handle->context, 4, inp_dim, VX_TYPE_FLOAT32, 0); 
""" )                 
            elif input_data_type == "F016":
                f.write( \
"""                handle->input = vxCreateTensor(handle->context, 4, inp_dim, VX_TYPE_FLOAT16, 0);
""")
            f.write( \
"""                if((status = vxGetStatus((vx_reference)handle->input)) != VX_SUCCESS) {
                    printf("ERROR: vxCreateTensor(input:[%s]): failed (%%d)\\n", status);
                }
                else {
                    handle->num_output = %d;
""" % ('x'.join([str(v) for v in input_shape]),len(graph.outputs)))
            for i in range(len(graph.outputs)):
                f.write( \
"""                    vx_size out_dim_%d[%d] = { %s };
""" %(i, len(output_shape[i]), ', '.join([str(v) for v in reversed(output_shape[i])])))
                if input_data_type == "F032":
                    f.write( \
"""                    handle->output[%d] = vxCreateTensor(handle->context, %d, out_dim_%d, VX_TYPE_FLOAT32, 0);
"""% (i, len(output_shape[i]), i))
                elif input_data_type =="F016":
                    f.write( \
"""                    handle->output[%d] = vxCreateTensor(handle->context, %d, out_dim_%d, VX_TYPE_FLOAT16, 0);
"""% (i, len(output_shape[i]), i))
                f.write( \
"""                    if((status = vxGetStatus((vx_reference)handle->output[%d])) != VX_SUCCESS) {
                        printf("ERROR: vxCreateTensor(output:[%s]): failed (%%d)\\n", status);
                    }
"""% (i, 'x'.join([str(v) for v in output_shape[i]])))
            if virtual_tensor_flag == 0:
                f.write( \
"""					else if((status = annAddToGraph(handle->graph, handle->input, %s, handle->tensorMap, binaryFilename)) != VX_SUCCESS) {
                        printf("ERROR: annAddToGraph: failed (%%d)\\n", status);
                    }
"""  %(', '.join(output_str)))
            else:
                f.write( \
"""                    else if((status = annAddToGraph(handle->graph, handle->input, %s, binaryFilename)) != VX_SUCCESS) {
                        printf("ERROR: annAddToGraph: failed (%%d)\\n", status);
                    }
""" %(', '.join(output_str)))
            f.write( \
"""                    else if((status = vxVerifyGraph(handle->graph)) != VX_SUCCESS) {
                        printf("ERROR: vxVerifyGraph: failed (%d)\\n", status);
                    }
                    else {
                        printf("OK: annCreateInference: successful\\n");
                        successful = true;
                    }
""" )
            f.write( \
"""                }
            }
        }
    }
    if(!successful) {
        if(handle) {
            if(handle->graph)
                vxReleaseGraph(&handle->graph);
            if(handle->input)
                vxReleaseTensor(&handle->input);
""" )
            for i in range(len(graph.outputs)):
                f.write( \
"""            if(handle->output[%d])
                vxReleaseTensor(&handle->output[%d]);
""" % ( i, i))             
            f.write( \
"""            if(handle->context)
                vxReleaseContext(&handle->context);
            delete handle;
            handle = nullptr;
        }
    }
    return handle;
}

VX_API_ENTRY int VX_API_CALL annReleaseInference(pyif_ann_handle handle)
{
    vx_status status = VX_SUCCESS;
    if(!handle) {
        status = VX_FAILURE;
        printf("ERROR: annReleaseInference: invalid handle\\n");
    }
    else if(handle->graph && (status = vxReleaseGraph(&handle->graph)) != VX_SUCCESS) {
        printf("ERROR: annReleaseInference: vxReleaseGraph: failed (%d)\\n", status);
    }
    else if(handle->input && (status = vxReleaseTensor(&handle->input)) != VX_SUCCESS) {
        printf("ERROR: annReleaseInference: vxReleaseTensor(input): failed (%d)\\n", status);
    }
    else {
        for (int i=0; i<handle->num_output; i++) {
            if(handle->output[i] && (status = vxReleaseTensor(&handle->output[i])) != VX_SUCCESS) {
                printf("ERROR: annReleaseInference: vxReleaseTensor(output<%d>): failed (%d)\\n", i,status);
            }
        }
""" )
            f.write( \
"""    }   
	if(handle->context && (status = vxReleaseContext(&handle->context)) != VX_SUCCESS) {
        printf("ERROR: annReleaseInference: vxReleaseContext: failed (%d)\\n", status);
    }
    else {
        delete handle;
    }
    return status;
}
""")
            if input_data_type == "F032":
                f.write( \
"""
VX_API_ENTRY int VX_API_CALL annCopyToInferenceInput(pyif_ann_handle handle, float * inp_ptr, size_t inp_size, bool is_nhwc)
{
    vx_status status = VX_SUCCESS;
    vx_size stride[4] = { 4, %d, %d, %d };
    vx_map_id map_id;
    float * ptr = nullptr;
    if(!handle) {
        status = VX_FAILURE;
        printf("ERROR: annCopyToInferenceInput: invalid handle\\n");
    }
    else if(inp_size != %d) {
        status = VX_FAILURE;
        printf("ERROR: annCopyToInferenceInput: invalid input buffer size (must be %d) -- got %%d\\n", (int)inp_size);
    }
    else if(handle->input == nullptr) {
        printf("ERROR: annCopyToInferenceInput: input is not valid\\n");
    }
    else if(!is_nhwc) {
        if((status = vxCopyTensorPatch(handle->input, 4, nullptr, nullptr, stride, inp_ptr, VX_WRITE_ONLY, VX_MEMORY_TYPE_HOST)) != VX_SUCCESS) {
            printf("ERROR: annCopyToInferenceInput: vxCopyTensorPatch: failed (%%d)\\n", status);
        }
    }
    else if((status = vxMapTensorPatch(handle->input, 4, nullptr, nullptr, &map_id, stride, (void **)&ptr, VX_WRITE_ONLY, VX_MEMORY_TYPE_HOST)) != VX_SUCCESS) {
        printf("ERROR: annCopyToInferenceInput: vxMapTensorPatch: failed (%%d)\\n", status);
    }
    else {
        size_t N = %d, C = %d, H = %d, W = %d;
        for(size_t n = 0; n < N; n++) {
            for(size_t c = 0; c < C; c++) {
                for(size_t y = 0; y < H; y++) {
                    size_t tpos = n * C*H*W + c * H*W + y * W;
                    size_t ipos = n * H*W*C + y * W*C + c;
                    for(size_t x = 0; x < W; x++, tpos++, ipos += C) {
                        ptr[tpos] = inp_ptr[ipos];
                    }
                }
            }
        }
        if ((status = vxUnmapTensorPatch(handle->input, map_id)) != VX_SUCCESS) {
            printf("ERROR: annCopyToInferenceInput: vxUnmapTensorPatch: failed (%%d)\\n", status);
        }
    }
    return status;
}
""" % (input_shape[3]*4, input_shape[2]*input_shape[3]*4, input_shape[1]*input_shape[2]*input_shape[3]*4, \
       input_buf_size, input_buf_size, input_shape[0], input_shape[1], input_shape[2], input_shape[3]))
            elif input_data_type == "F016":
                f.write( \
"""
VX_API_ENTRY int VX_API_CALL annCopyToInferenceInput(pyif_ann_handle handle, float * inp_ptr, size_t inp_size, bool is_nhwc)
{
    vx_status status = VX_SUCCESS;
    vx_size stride[4] = { 2, %d, %d, %d };
    vx_map_id map_id;
    half * ptr = nullptr;
    if(!handle) {
        status = VX_FAILURE;
        printf("ERROR: annCopyToInferenceInput: invalid handle\\n");
    }
    else if(inp_size/2 != %d) {
        status = VX_FAILURE;
        printf("ERROR: annCopyToInferenceInput: invalid input buffer size (must be %d) -- got %%d\\n", (int)inp_size);
    }
    else if(handle->input == nullptr) {
        printf("ERROR: annCopyToInferenceInput: input is not valid\\n");
    }
    else if(!is_nhwc) {
        if((status = vxCopyTensorPatch(handle->input, 4, nullptr, nullptr, stride, inp_ptr, VX_WRITE_ONLY, VX_MEMORY_TYPE_HOST)) != VX_SUCCESS) {
            printf("ERROR: annCopyToInferenceInput: vxCopyTensorPatch: failed (%%d)\\n", status);
        }
    }
    else if((status = vxMapTensorPatch(handle->input, 4, nullptr, nullptr, &map_id, stride, (void **)&ptr, VX_WRITE_ONLY, VX_MEMORY_TYPE_HOST)) != VX_SUCCESS) {
        printf("ERROR: annCopyToInferenceInput: vxMapTensorPatch: failed (%%d)\\n", status);
    }
    else {
        size_t N = %d, C = %d, H = %d, W = %d;
        for(size_t n = 0; n < N; n++) {
            for(size_t c = 0; c < C; c++) {
                for(size_t y = 0; y < H; y++) {
                    size_t tpos = n * C*H*W + c * H*W + y * W;
                    size_t ipos = n * H*W*C + y * W*C + c;
                    for(size_t x = 0; x < W; x++, tpos++, ipos += C) {
                        ptr[tpos] = static_cast<half>(inp_ptr[ipos]);
                    }
                }
            }
        }
        if ((status = vxUnmapTensorPatch(handle->input, map_id)) != VX_SUCCESS) {
            printf("ERROR: annCopyToInferenceInput: vxUnmapTensorPatch: failed (%%d)\\n", status);
        }
    }
    return status;
}
""" % (input_shape[3]*2, input_shape[2]*input_shape[3]*2, input_shape[1]*input_shape[2]*input_shape[3]*2, \
       input_buf_size, input_buf_size, input_shape[0], input_shape[1], input_shape[2], input_shape[3]))           
            tshape = []
            for i in range(len(graph.outputs)):
                if len(output_shape[i]) == 4:
                    tshape.append([output_shape[i][0], output_shape[i][1], output_shape[i][2], output_shape[i][3]])
                else:
                    tshape.append([1, 1, output_shape[i][0], output_shape[i][1]])
            if input_data_type == "F032":
                f.write( \
"""
VX_API_ENTRY int VX_API_CALL annCopyFromInferenceOutput(pyif_ann_handle handle, float * out_ptr, size_t out_size)
{
    vx_status status = VX_SUCCESS;
    vx_size stride[4] = { 4, %d, %d, %d };
    if(!handle) {
        status = VX_FAILURE;
        printf("ERROR: annCopyFromInferenceOutput: invalid handle\\n");
    }
    else if(out_size != %d) {
        status = VX_FAILURE;
        printf("ERROR: annCopyFromInferenceOutput: invalid output buffer size (must be %d) -- got %%d\\n", (int)out_size);
    }
    else if(handle->output[0] && (status = vxCopyTensorPatch(handle->output[0], %d, nullptr, nullptr, stride, out_ptr, VX_READ_ONLY, VX_MEMORY_TYPE_HOST)) != VX_SUCCESS) {
        printf("ERROR: annCopyFromInferenceOutput: vxCopyTensorPatch: failed (%%d)\\n", status);
    }
    return status;
}
"""   % (tshape[0][3]*4, tshape[0][2]*tshape[0][3]*4, tshape[0][1]*tshape[0][2]*tshape[0][3]*4, output_buf_size[0], output_buf_size[0], len(output_shape[0])))
            elif input_data_type == "F016":
                f.write( \
"""
VX_API_ENTRY int VX_API_CALL annCopyFromInferenceOutput(pyif_ann_handle handle, float * out_ptr, size_t out_size)
{
    vx_status status = VX_SUCCESS;
    vx_size stride[4] = { 2, %d, %d, %d };
    vx_map_id map_id;
    half* ptr;
    int writeSize = %d * %d * %d;
    if(!handle) {
        status = VX_FAILURE;
        printf("ERROR: annCopyFromInferenceOutput: invalid handle\\n");
    }
    else if(out_size/2 != %d) {
        status = VX_FAILURE;
        printf("ERROR: annCopyFromInferenceOutput: invalid output buffer size (must be %d) -- got %%d\\n", (int)out_size);
    }
    else if(handle->output[0])
    {
        if((status = vxMapTensorPatch(handle->output[0], %d, nullptr, nullptr, &map_id, stride, (void **)&ptr, VX_READ_ONLY, VX_MEMORY_TYPE_HOST)) != VX_SUCCESS) {
            printf("ERROR: annCopyFromInferenceOutput: vxMapTensorPatch: failed (%%d)\\n", status);
        }
        for(int i = 0; i < writeSize; i++) {
            out_ptr[i] = static_cast<float>(ptr[i]);
        }
        if ((status = vxUnmapTensorPatch(handle->output[0], map_id)) != VX_SUCCESS) {
            printf("ERROR: annCopyFromInferenceOutput: vxUnmapTensorPatch: failed (%%d)\\n", status);
        }
    }
    return status;
}
""" % (tshape[0][3]*2, tshape[0][2]*tshape[0][3]*2, tshape[0][1]*tshape[0][2]*tshape[0][3]*2, tshape[i][1], tshape[i][2], tshape[i][3], output_buf_size[0], output_buf_size[0], len(output_shape[0])))
            if (len(graph.outputs) > 1):
                if input_data_type == "F032":
                    f.write( \
"""
VX_API_ENTRY int VX_API_CALL annCopyFromInferenceOutput_%d(pyif_ann_handle handle, float * out_ptr, size_t out_size)
{
    vx_status status = VX_SUCCESS;
    vx_size stride[4] = { 4, %d, %d, %d };
    if(!handle) {
        status = VX_FAILURE;
        printf("ERROR: annCopyFromInferenceOutput: invalid handle\\n");
    }
    else if(out_size != %d) {
        status = VX_FAILURE;
        printf("ERROR: annCopyFromInferenceOutput: invalid output buffer size (must be %d) -- got %%d\\n", (int)out_size);
    }
    else if(handle->output && (status = vxCopyTensorPatch(handle->output[1], %d, nullptr, nullptr, stride, out_ptr, VX_READ_ONLY, VX_MEMORY_TYPE_HOST)) != VX_SUCCESS) {
        printf("ERROR: annCopyFromInferenceOutput: vxCopyTensorPatch: failed (%%d)\\n", status);
    }
    return status;
}
"""   % (1, tshape[1][3]*4, tshape[1][2]*tshape[1][3]*4, tshape[1][1]*tshape[1][2]*tshape[1][3]*4, output_buf_size[1], output_buf_size[1], len(output_shape[1])))
                elif input_data_type == "F016":
                    f.write( \
"""
VX_API_ENTRY int VX_API_CALL annCopyFromInferenceOutput_%d(pyif_ann_handle handle, float * out_ptr, size_t out_size)
{
    vx_status status = VX_SUCCESS;
    vx_size stride[4] = { 2, %d, %d, %d };
    vx_map_id map_id;
    half* ptr;
    int writeSize = %d*%d*%d;
    if(!handle) {
        status = VX_FAILURE;
        printf("ERROR: annCopyFromInferenceOutput: invalid handle\\n");
    }
    else if(out_size/2 != %d) {
        status = VX_FAILURE;
        printf("ERROR: annCopyFromInferenceOutput: invalid output buffer size (must be %d) -- got %%d\\n", (int)out_size);
    }
    else if(handle->output)
    {
        if((status = vxMapTensorPatch(handle->output[1], %d, nullptr, nullptr, &map_id, stride, (void **)&ptr, VX_READ_ONLY, VX_MEMORY_TYPE_HOST)) != VX_SUCCESS)
        {
            printf("ERROR: annCopyFromInferenceOutput: vxMapTensorPatch: failed (%%d)\\n", status);
        }
        for(int i = 0; i < writeSize; i++)
        {
            out_ptr[i] = static_cast<float>(ptr[i]);
        }
        if ((status = vxUnmapTensorPatch(handle->output[1], map_id)) != VX_SUCCESS) {
            printf("ERROR: annCopyFromInferenceOutput: vxUnmapTensorPatch: failed (%%d)\\n", status);
        }
    }
    return status;
}
"""   % (1, tshape[1][3]*2, tshape[1][2]*tshape[1][3]*2, tshape[1][1]*tshape[1][2]*tshape[1][3]*2, tshape[i][1],tshape[i][2],tshape[i][3],output_buf_size[1], output_buf_size[1], len(output_shape[1])))
            if (len(graph.outputs) > 2):
                if input_data_type == "F032":
                    f.write( \
"""
VX_API_ENTRY int VX_API_CALL annCopyFromInferenceOutput_%d(pyif_ann_handle handle, float * out_ptr, size_t out_size)
{
    vx_status status = VX_SUCCESS;
    vx_size stride[4] = { 4, %d, %d, %d };
    if(!handle) {
        status = VX_FAILURE;
        printf("ERROR: annCopyFromInferenceOutput: invalid handle\\n");
    }
    else if(out_size != %d) {
        status = VX_FAILURE;
        printf("ERROR: annCopyFromInferenceOutput: invalid output buffer size (must be %d) -- got %%d\\n", (int)out_size);
    }
    else if(handle->output && (status = vxCopyTensorPatch(handle->output[1], %d, nullptr, nullptr, stride, out_ptr, VX_READ_ONLY, VX_MEMORY_TYPE_HOST)) != VX_SUCCESS) {
        printf("ERROR: annCopyFromInferenceOutput: vxCopyTensorPatch: failed (%%d)\\n", status);
    }
    return status;
}
""" % (2, tshape[2][3]*4, tshape[2][2]*tshape[2][3]*4, tshape[2][1]*tshape[2][2]*tshape[2][3]*4, output_buf_size[2], output_buf_size[2], len(output_shape[2])))
                elif input_data_type == "F016":
                    f.write( \
"""
VX_API_ENTRY int VX_API_CALL annCopyFromInferenceOutput_%d(pyif_ann_handle handle, float * out_ptr, size_t out_size)
{
    vx_status status = VX_SUCCESS;
    vx_size stride[4] = { 2, %d, %d, %d };
    vx_map_id map_id;
    half* ptr;
    int writeSize = %d*%d*%d;
    if(!handle) {
        status = VX_FAILURE;
        printf("ERROR: annCopyFromInferenceOutput: invalid handle\\n");
    }
    else if(out_size/2 != %d) {
        status = VX_FAILURE;
        printf("ERROR: annCopyFromInferenceOutput: invalid output buffer size (must be %d) -- got %%d\\n", (int)out_size);
    }
    else if(handle->output)
    {
        if((status = vxMapTensorPatch(handle->output[2], %d, nullptr, nullptr, &map_id, stride, (void **)&ptr, VX_READ_ONLY, VX_MEMORY_TYPE_HOST)) != VX_SUCCESS)
        {
            printf("ERROR: annCopyFromInferenceOutput: vxMapTensorPatch: failed (%%d)\\n", status);
        }
        for(int i = 0; i < writeSize; i++)
        {
            out_ptr[i] = static_cast<float>(ptr[i]);
        }
        if ((status = vxUnmapTensorPatch(handle->output[2], map_id)) != VX_SUCCESS) {
            printf("ERROR: annCopyFromInferenceOutput: vxUnmapTensorPatch: failed (%%d)\\n", status);
        }
    }
    return status;
}
"""   % (2, tshape[2][3]*2, tshape[2][2]*tshape[2][3]*2, tshape[2][1]*tshape[2][2]*tshape[2][3]*2, tshape[i][1],tshape[i][2],tshape[i][3],output_buf_size[2], output_buf_size[2], len(output_shape[2])))
            if virtual_tensor_flag == 0:
                f.write( \
"""
VX_API_ENTRY int VX_API_CALL annCopyFromInferenceLocal(pyif_ann_handle handle, const char *tensorName, float * out_ptr, size_t out_size)
{
    vx_status status = VX_SUCCESS;
    std::string tensorName_str = tensorName;
    auto it = handle->tensorMap.find(tensorName_str);
    vx_size dims[4];
    status = vxQueryTensor((vx_tensor)it->second, VX_TENSOR_DIMS, dims, sizeof(dims));
    if(status != VX_SUCCESS){
        printf("ERROR: annCopyFromInferenceLocal: vxQueryTensor: failed (%d)\\n", status);
    }
""")
                if input_data_type == "F032":
                    f.write( \
"""    vx_size stride[4] = { 4, dims[0]*4, dims[0]*dims[1]*4, dims[0]*dims[1]*dims[2]*4 };
""")
                elif input_data_type == "F016":
                    f.write( \
"""    vx_size stride[4] = { 2, dims[0]*2, dims[0]*dims[1]*2, dims[0]*dims[1]*dims[2]*2 };
""")
                f.write( \
"""    if(!handle) {
        status = VX_FAILURE;
        printf("ERROR: annCopyFromInferenceLocal: invalid handle\\n");
    }
""" )
                if input_data_type == "F032":
                    f.write(\
"""    else if(out_size/%d != stride[3]) {
        status = VX_FAILURE;
        printf("ERROR: annCopyFromInferenceLocal: invalid output buffer size (must be %%d) -- got %%d\\n", (int)stride[3],(int)out_size);
    }
""" % (input_shape[0]))
                elif input_data_type == "F016":
	                f.write (\
"""     else if(out_size/(2*%d) != stride[3]) {
        status = VX_FAILURE;
        printf("ERROR: annCopyFromInferenceLocal: invalid output buffer size (must be %%d) -- got %%d\\n", (int)stride[3],(int)out_size);
    }
""" % (input_shape[0]))
                f.write (\
"""    else if((vx_tensor)it->second && (status = vxCopyTensorPatch((vx_tensor)it->second, %d, nullptr, nullptr, stride, out_ptr, VX_READ_ONLY, VX_MEMORY_TYPE_HOST)) != VX_SUCCESS) {
        printf("ERROR: annCopyFromInferenceLocal: vxCopyTensorPatch: failed (%%d)\\n", status);
    }
    return status;
}
"""   % (len(local_shape[i])))
            f.write( \
"""

VX_API_ENTRY int VX_API_CALL annRunInference(pyif_ann_handle handle, int num_iterations)
{
    vx_status status = VX_SUCCESS;
    if(!handle) {
        status = VX_FAILURE;
        printf("ERROR: annRunInference: invalid handle\\n");
    }
    else {
        for(int i = 0; i < num_iterations; i++) {
            if((status = vxProcessGraph(handle->graph)) != VX_SUCCESS)
                break;
        }
    }
    return status;
}
""")

def generatePythonScriptSample(graph,fileName, virtual_tensor_flag):
    """Write a sample Python driver script for the generated inference library.

    The emitted script defines an ``AnnAPI`` class that loads the shared
    library through ctypes and binds its ``ann*`` entry points, followed by a
    ``__main__`` section that creates an inference handle, copies the input
    tensor file in, runs one iteration, writes the output tensor to a file
    and releases the handle.

    Args:
        graph: object whose ``inputs[0].type`` is read; when it equals
            ``"F016"`` the emitted script converts results to ``np.float16``
            before writing them out.
        fileName: path of the script file to create (opened with mode 'w').
        virtual_tensor_flag: when equal to 0, the emitted script additionally
            binds ``annQueryLocals`` / ``annCopyFromInferenceLocal`` and dumps
            every local tensor to ``dumpBuffers/<name>.bin``.
    """
    print('creating ' + fileName + ' ...')
    input_data_type = graph.inputs[0].type
    emit_locals = (virtual_tensor_flag == 0)
    is_fp16 = (input_data_type == "F016")
    with open(fileName, 'w') as f:
        generateLicenseForScript(f)
        # Imports and the start of the AnnAPI ctypes wrapper class.
        f.write( \
"""
import sys,os,ctypes 
import numpy as np
from numpy.ctypeslib import ndpointer

class AnnAPI:
    def __init__(self,library):
        self.lib = ctypes.cdll.LoadLibrary(library)
        self.annQueryInference = self.lib.annQueryInference
        self.annQueryInference.restype = ctypes.c_char_p
        self.annQueryInference.argtypes = []
""")
        if emit_locals:
            f.write( \
"""        self.annQueryLocals = self.lib.annQueryLocals
        self.annQueryLocals.restype = ctypes.c_char_p
        self.annQueryLocals.argtypes = []
""")
        f.write( \
"""        self.annCreateInference = self.lib.annCreateInference
        self.annCreateInference.restype = ctypes.c_void_p
        self.annCreateInference.argtypes = [ctypes.c_char_p]
        self.annReleaseInference = self.lib.annReleaseInference
        self.annReleaseInference.restype = ctypes.c_int
        self.annReleaseInference.argtypes = [ctypes.c_void_p]
        self.annCopyToInferenceInput = self.lib.annCopyToInferenceInput
        self.annCopyToInferenceInput.restype = ctypes.c_int
        self.annCopyToInferenceInput.argtypes = [ctypes.c_void_p, ndpointer(ctypes.c_float, flags="C_CONTIGUOUS"), ctypes.c_size_t, ctypes.c_bool]
        self.annCopyFromInferenceOutput = self.lib.annCopyFromInferenceOutput
        self.annCopyFromInferenceOutput.restype = ctypes.c_int
        self.annCopyFromInferenceOutput.argtypes = [ctypes.c_void_p, ndpointer(ctypes.c_float, flags="C_CONTIGUOUS"), ctypes.c_size_t]
""")
        if emit_locals:
            f.write( \
"""        self.annCopyFromInferenceLocal = self.lib.annCopyFromInferenceLocal
        self.annCopyFromInferenceLocal.restype = ctypes.c_int
        self.annCopyFromInferenceLocal.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ndpointer(ctypes.c_float, flags="C_CONTIGUOUS"), ctypes.c_size_t]
""")
        # End of AnnAPI and the main section of the emitted script.
        # Fixes relative to the previous template:
        #  * the script reads sys.argv[4], so it needs at least 5 argv entries
        #    (the old "< 4" guard let a 3-argument call die with IndexError);
        #  * annCreateInference takes a c_char_p, which needs bytes rather than
        #    text, so the weights path is encoded before the call.
        f.write( \
"""        self.annRunInference = self.lib.annRunInference
        self.annRunInference.restype = ctypes.c_int
        self.annRunInference.argtypes = [ctypes.c_void_p, ctypes.c_int]        
        print('OK: AnnAPI found "' + self.annQueryInference().decode("utf-8") + '" as configuration in ' + library)

if __name__ == '__main__':
    if len(sys.argv) < 5:
        print ("Usage : python anntest.py <libannpython> <weightsfile> <input_tensor_file> <output_tensor_file>")
        sys.exit(1)
    annlibPythonName = sys.argv[1]
    weightsFile = sys.argv[2]
    inputTensorFile = sys.argv[3]
    api = AnnAPI(annlibPythonName)
    if not os.path.exists("dumpBuffers"):
        os.makedirs("dumpBuffers")
    tensorOutputFile = sys.argv[4]
    for each in filter(None,api.annQueryInference().decode("utf-8").split(';')):
        types,name,n,c,h,w = each.split(',')
        if types[0:5] == "input":
            hdl = api.annCreateInference(weightsFile.encode('utf-8'))
            im = np.fromfile(inputTensorFile, dtype=np.float32)
            inp_size = int(n)*int(c)*int(h)*int(w)*4
            status = api.annCopyToInferenceInput(hdl, np.ascontiguousarray(im, dtype=np.float32), inp_size, 0)
            print('INFO: annCopyToInferenceInput status %d'  %(status))
            status = api.annRunInference(hdl, 1)
            print('INFO: annRunInference status %d ' %(status))
        elif types[0:6] == "output":
            out_size = int(n)*int(c)*int(h)*int(w)*4
            out_buf = bytearray(out_size)
            out = np.frombuffer(out_buf, dtype=np.float32)
            status = api.annCopyFromInferenceOutput(hdl, np.ascontiguousarray(out, dtype=np.float32), out_size)
            print('INFO: annCopyFromInferenceOutput status %d' %(status))
""" )
        if is_fp16:
            # The C API hands back float32; store the result in the model's
            # half-precision type.
            f.write( \
"""            out = out.astype(np.float16)
""")
        f.write( \
"""            fid = open('%s' %tensorOutputFile, 'wb+') 
            fid.write(out.tobytes())
            fid.close()
""")
        if emit_locals:
            # Dump every local tensor. The tensor name is passed as a c_char_p
            # and therefore encoded to bytes; the text form is still used for
            # the dump file name.
            f.write( \
"""    for each in filter(None,api.annQueryLocals().decode("utf-8").split(';')):
        types,name,n,c,h,w = each.split(',')
        local_size = int(n)*int(c)*int(h)*int(w)*4
        local_buf = bytearray(local_size)
        local = np.frombuffer(local_buf, dtype=np.float32)
        status = api.annCopyFromInferenceLocal(hdl, name.encode('utf-8'), np.ascontiguousarray(local, dtype=np.float32), local_size)
        print('INFO: annCopyFromInferenceLocal status %d' %(status))
""" )
            if is_fp16:
                f.write( \
"""        local = local.astype(np.float16)
""" )
            f.write( \
"""        fid = open('dumpBuffers/%s.bin' %name, 'wb+') 
        fid.write(local.tobytes())
        fid.close()
""")
        # Release the handle in every generated script. Previously this was
        # emitted only together with the locals dump, so scripts generated
        # with virtual tensors never released the inference handle.
        f.write( \
"""    status = api.annReleaseInference(hdl)
    print('INFO: annReleaseInference status %d' %(status))
""")

def generateTestCPP(graph,argmaxOutput,fileName,virtual_tensor_flag):
    print('creating ' + fileName + ' ...')
    with open(fileName, 'w') as f:
        generateLicenseForCPP(f)
        f.write( \
"""
#include "annmodule.h"
#include <vx_ext_amd.h>
#include <vx_amd_nn.h>
#include <iostream>
#include <sstream>
#include <vector>
#include <stdio.h>
#include <string.h>
#include <string>
#include <inttypes.h>
#include <chrono>
#include <unistd.h>
#include <math.h>
#include <half.hpp>
#include <immintrin.h>
#include <map>
using half_float::half;

#if ENABLE_OPENCV
#include <opencv2/opencv.hpp>
using namespace cv;
#if USE_OPENCV_4
#define CV_LOAD_IMAGE_COLOR IMREAD_COLOR
#endif
#endif

#define ERROR_CHECK_OBJECT(obj) { vx_status status = vxGetStatus((vx_reference)(obj)); if(status != VX_SUCCESS) { vxAddLogEntry((vx_reference)context, status     , "ERROR: failed with status = (%d) at " __FILE__ "#%d\\n", status, __LINE__); return status; } }
#define ERROR_CHECK_STATUS(call) { vx_status status = (call); if(status != VX_SUCCESS) { printf("ERROR: failed with status = (%d) at " __FILE__ "#%d\\n", status, __LINE__); return -1; } }
""")
        if type(argmaxOutput) is np.ndarray:
            f.write( \
"""
// LUT (R)
vx_uint8 g_lut__r[] = {
    %s
};
// LUT (G)
vx_uint8 g_lut__g[] = {
    %s
};
// LUT (B)
vx_uint8 g_lut__b[] = {
    %s
};
""" % (','.join([str(v) for v in argmaxOutput[0]]), ','.join([str(v) for v in argmaxOutput[1]]), ','.join([str(v) for v in argmaxOutput[2]])))
            if len(argmaxOutput) == 4:
                f.write( \
"""// LUT (A)
vx_uint8 g_lut__a[] = {
    %s
};
""" % (','.join([str(v) for v in argmaxOutput[3]])))
        f.write( \
"""
static void VX_CALLBACK log_callback(vx_context context, vx_reference ref, vx_status status, const vx_char string[])
{
    size_t len = strlen(string);
    if (len > 0) {
        printf("%s", string);
        if (string[len - 1] != '\\n')
            printf("\\n");
        fflush(stdout);
    }
}

inline int64_t clockCounter()
{
    return std::chrono::high_resolution_clock::now().time_since_epoch().count();
}

inline int64_t clockFrequency()
{
    return std::chrono::high_resolution_clock::period::den / std::chrono::high_resolution_clock::period::num;
}

static vx_status copyTensor(std::string tensorName, vx_tensor tensor, std::string args, vx_enum usage = VX_WRITE_ONLY, std::string add = "0,0,0", std::string multiply = "1,1,1")
{
    std::vector<float> addVec, mulVec;
    std::stringstream sa(add), sm(multiply);
    float i, j;
    while (sa >> i && sm >> j)
    {
        addVec.push_back(i);
        mulVec.push_back(j);
        if (sa.peek() == ',')
            sa.ignore();
        if (sm.peek() == ',')
            sm.ignore();
    }
    
    // split the args into fileName and other parameters
    std::vector<std::string> argList;
    std::istringstream sf(args);
    for(std::string s; std::getline(sf, s, ','); ) {
        argList.push_back(s);
    }
    std::cout << "Preprocessing add: " << addVec[0] << " " << addVec[1] << " " << addVec[2] << " " << std::endl;
    std::cout << "Preprocessing multiply: " << mulVec[0] << " " << mulVec[1] << " " << mulVec[2] << " " << std::endl;
    std::string fileName = argList[0];
    // access the tensor object
    vx_enum data_type = VX_TYPE_FLOAT32;
    vx_size num_of_dims = 4, dims[4] = { 1, 1, 1, 1 }, stride[4];
    vxQueryTensor(tensor, VX_TENSOR_DATA_TYPE, &data_type, sizeof(data_type));
    vxQueryTensor(tensor, VX_TENSOR_NUMBER_OF_DIMS, &num_of_dims, sizeof(num_of_dims));
    vxQueryTensor(tensor, VX_TENSOR_DIMS, &dims, sizeof(dims[0])*num_of_dims);
    if((data_type != VX_TYPE_FLOAT32) && (data_type != VX_TYPE_FLOAT16) && (data_type != VX_TYPE_INT64) && (data_type != VX_TYPE_INT32)) {
        std::cerr << "ERROR: copyTensor() supports only VX_TYPE_FLOAT32 or VX_TYPE_FLOAT16 or VX_TYPE_INT64 or VX_TYPE_INT32: invalid for " << fileName << std::endl;
        return -1;
    }
    vx_size count = dims[0] * dims[1] * dims[2] * dims[3];
    vx_map_id map_id;
    void * ptr;
    vx_status status = vxMapTensorPatch(tensor, num_of_dims, nullptr, nullptr, &map_id, stride, (void **)&ptr, usage, VX_MEMORY_TYPE_HOST);
    if(status) {
        std::cerr << "ERROR: vxMapTensorPatch() failed for " << fileName << std::endl;
        return -1;
    }
    if(usage == VX_WRITE_ONLY) {
#if ENABLE_OPENCV
        if(dims[2] == 3 && fileName.size() > 4 && (fileName.substr(fileName.size()-4, 4) == ".png" || fileName.substr(fileName.size()-4, 4) == ".jpg"))
        {
            for(size_t n = 0; n < dims[3]; n++) {
                char imgFileName[1024];
                sprintf(imgFileName, fileName.c_str(), (int)n);
                Mat img;
                img = imread(imgFileName, CV_LOAD_IMAGE_COLOR);
                if(!img.data || img.rows != dims[1] || img.cols != dims[0]) {
                    printf("ERROR: invalid image or dimensions: %s\\n", imgFileName);
                    return -1;
                }
                for(vx_size y = 0; y < dims[1]; y++) {
                    unsigned char * src = img.data + y*dims[0]*dims[2];
                    if(data_type == VX_TYPE_FLOAT32) {
                        float * dstR = (float *)ptr + ((n * stride[3] + y * stride[1]) >> 2);
                        float * dstG = dstR + (stride[2] >> 2);
                        float * dstB = dstG + (stride[2] >> 2);
                        for(vx_size x = 0; x < dims[0]; x++, src += 3) {
                            *dstR++ = (src[2] * mulVec[0]) + addVec[0];
                            *dstG++ = (src[1] * mulVec[1]) + addVec[1];
                            *dstB++ = (src[0] * mulVec[2]) + addVec[2];
                        }
                    } else if(data_type == VX_TYPE_FLOAT16) {
                        short * dstR = (short *)ptr + ((n * stride[3] + y * stride[1]) >> 1);
                        short * dstG = dstR + (stride[2] >> 2);
                        short * dstB = dstG + (stride[2] >> 2);                    
                        for(vx_size x = 0; x < dims[0]; x++, src += 3) {
                            *dstR++ = (src[2] * mulVec[0]) + addVec[0];
                            *dstG++ = (src[1] * mulVec[1]) + addVec[1];
                            *dstB++ = (src[0] * mulVec[2]) + addVec[2];
                        }
                    } else if(data_type == VX_TYPE_INT64) {
                        long int* dstR = (long int*)ptr + ((n * stride[3] + y * stride[1]) >> 3);
                        long int* dstG = dstR + (stride[2] >> 2);
                        long int* dstB = dstG + (stride[2] >> 2);                    
                        for(vx_size x = 0; x < dims[0]; x++, src += 3) {
                            *dstR++ = (src[2] * mulVec[0]) + addVec[0];
                            *dstG++ = (src[1] * mulVec[1]) + addVec[1];
                            *dstB++ = (src[0] * mulVec[2]) + addVec[2];
                        }
                    } else if(data_type == VX_TYPE_INT32) {
                        int* dstR = (int*)ptr + ((n * stride[3] + y * stride[1]) >> 2);
                        int* dstG = dstR + (stride[2] >> 2);
                        int* dstB = dstG + (stride[2] >> 2);                    
                        for(vx_size x = 0; x < dims[0]; x++, src += 3) {
                            *dstR++ = (src[2] * mulVec[0]) + addVec[0];
                            *dstG++ = (src[1] * mulVec[1]) + addVec[1];
                            *dstB++ = (src[0] * mulVec[2]) + addVec[2];
                        }
                    }
                }
            }
        }
        else
#endif
        {
            FILE * fp = fopen(fileName.c_str(), "rb");
            if(!fp) {
                std::cerr << "ERROR: unable to open: " << fileName << std::endl;
                return -1;
            }
            for(size_t n = 0; n < dims[3]; n++) {
                for(size_t c = 0; c < dims[2]; c++) {
                    for(size_t y = 0; y < dims[1]; y++) {
                        if(data_type == VX_TYPE_FLOAT32) {
                            float * ptrY = (float *)ptr + ((n * stride[3] + c * stride[2] + y * stride[1]) >> 2);
                            vx_size n = fread(ptrY, sizeof(float), dims[0], fp);
                            if(n != dims[0]) {
                                std::cerr << "ERROR: expected char[" << count*sizeof(float) << "], but got less in " << fileName << std::endl;
                                return -1;
                            }
                            for(size_t x = 0; x < dims[0]; x++) {
                                *(ptrY+x) = *(ptrY+x) * mulVec[c] + addVec[c];
                            }
                        } else if(data_type == VX_TYPE_FLOAT16){
                            short * ptrY = (short *)ptr + ((n * stride[3] + c * stride[2] + y * stride[1]) >> 1);
                            vx_size n = fread(ptrY, sizeof(short), dims[0], fp);
                            if(n != dims[0]) {
                                std::cerr << "ERROR: expected char[" << count*sizeof(short) << "], but got less in " << fileName << std::endl;
                                return -1;
                            }
                            for(size_t x = 0; x < dims[0]; x++) {
                                *(ptrY+x) = *(ptrY+x) * mulVec[c] + addVec[c];
                            }
                        } else if(data_type == VX_TYPE_INT64){
                            long int * ptrY = (long int *)ptr + ((n * stride[3] + c * stride[2] + y * stride[1]) >> 3);
                            vx_size n = fread(ptrY, sizeof(long int), dims[0], fp);
                            if(n != dims[0]) {
                                std::cerr << "ERROR: expected char[" << count*sizeof(long int) << "], but got less in " << fileName << std::endl;
                                return -1;
                            }
                            for(size_t x = 0; x < dims[0]; x++) {
                                *(ptrY+x) = *(ptrY+x) * mulVec[c] + addVec[c];
                            }
                        } else if(data_type == VX_TYPE_INT32){
                            int * ptrY = (int *)ptr + ((n * stride[3] + c * stride[2] + y * stride[1]) >> 3);
                            vx_size n = fread(ptrY, sizeof(int), dims[0], fp);
                            if(n != dims[0]) {
                                std::cerr << "ERROR: expected char[" << count*sizeof(int) << "], but got less in " << fileName << std::endl;
                                return -1;
                            }
                            for(size_t x = 0; x < dims[0]; x++) {
                                *(ptrY+x) = *(ptrY+x) * mulVec[c] + addVec[c];
                            }
                        }
                    }
                }
            }
            fclose(fp);
        }
    }
    else {""")
        if type(argmaxOutput) is np.ndarray:
            f.write( \
"""
        vx_size W = 1, H =1, C = 1, N = 1;
        if (num_of_dims == 2) {
            C = dims[0], N = dims[1];
        }
        else if(num_of_dims == 4) {
            W = dims[0], H = dims[1], C = dims[2], N = dims[3];
        }
        vx_size HW = H * W;
        vx_size NHW = N * HW;
        vx_size CHW = C * HW;
        if(C < sizeof(g_lut__r) || C < sizeof(g_lut__g) || C < sizeof(g_lut__b)) {
            std::cerr << "ERROR: LUT doesn't have enough entries for all channels with " << fileName << std::endl;
            return -1;
        }
        vx_uint8 * buf = new vx_uint8[NHW*%d], * pb = buf;
        for(vx_size n = 0; n < N; n++) {
            for(vx_size y = 0; y < H; y++) {
                for(vx_size x = 0; x < W; x++) {
                    vx_size best_c = 0;
                    float * pc = (float *)ptr + n * CHW + y * W + x;
                    if (data_type == VX_TYPE_FLOAT32) {
                        float best_v = *pc;
                        for(vx_size c = 1; c < C; c++, pc += HW) {
                            if(*pc > best_v) {
                                best_v = *pc;
                                best_c = c;
                            }
                        }
                    }
                    else if (data_type == VX_TYPE_FLOAT16) {
                        half * pc = (half *)((short *)ptr + n * CHW + y * W + x);
                        half best_v = *pc;
                        for(vx_size c = 1; c < C; c++, pc += HW) {
                            if(*pc > best_v) {
                                best_v = *pc;
                                best_c = c;
                            }
                        }
                    }
                    else if (data_type == VX_TYPE_INT64) {
                        long int * pc = (long int *)((long int *)ptr + n * CHW + y * W + x);
                        long int best_v = *pc;
                        for(vx_size c = 1; c < C; c++, pc += HW) {
                            if(*pc > best_v) {
                                best_v = *pc;
                                best_c = c;
                            }
                        }
                    }
                    else if (data_type == VX_TYPE_INT32) {
                        int * pc = (int *)((int *)ptr + n * CHW + y * W + x);
                        int best_v = *pc;
                        for(vx_size c = 1; c < C; c++, pc += HW) {
                            if(*pc > best_v) {
                                best_v = *pc;
                                best_c = c;
                            }
                        }
                    }
                    *pb++ = g_lut__r[best_c];
                    *pb++ = g_lut__g[best_c];
                    *pb++ = g_lut__b[best_c];
                    %s
                }
            }
        }
#if ENABLE_OPENCV
        if(fileName.size() > 4 && (fileName.substr(fileName.size()-4, 4) == ".png"))
        {
            for(size_t n = 0; n < N; n++) {
                char imgFileName[1024];
                sprintf(imgFileName, fileName.c_str(), (int)n);
                Mat img(dims[1], dims[0], %s);
                if(!img.data || img.rows != H || img.cols != W) {
                    printf("ERROR: invalid image or dimensions: %%s\\n", imgFileName);
                    return -1;
                }
                for(vx_size y = 0; y < H; y++) {
                    size_t outC = %d;
                    vx_uint8 * dst = img.data + y * W * outC;
                    vx_uint8 * src = (vx_uint8 *)(buf + (n * H + y) * W * outC);
                    for(vx_size x = 0; x < W; x++, src += outC, dst += outC) {
                        dst[2] = src[0];
                        dst[1] = src[1];
                        dst[0] = src[2];
                        %s
                    }
                }
                std::vector<int> compression_params;
                compression_params.push_back(IMWRITE_PNG_COMPRESSION);
                compression_params.push_back(9);
                try {
                    imwrite(imgFileName, img, compression_params);
                }
                catch (cv::Exception& ex) {
                    printf("ERROR: exception converting image to PNG format: %%s\\n", ex.what());
                    return -1;
                }
            }
        }
        else
#endif
        if(fileName != "-") {
            FILE * fp = fopen(fileName.c_str(), "wb");
            if(!fp) {
                std::cerr << "ERROR: unable to open: " << fileName << std::endl;
                return -1;
            }
            fwrite(buf, %d, NHW, fp);
            fclose(fp);
        }
        if(argList.size() >= 2) {
            size_t outC = %d;
            float errPercentLimit = 0;
            if(argList.size() >= 3) {
                errPercentLimit = (float)atof(argList[2].c_str());
            }
            vx_uint8 * gold = new vx_uint8 [NHW * outC];
            FILE * fp = fopen(argList[1].c_str(), "rb");
            if(!fp) {
                std::cerr << "ERROR: unable to open: " << argList[1] << std::endl;
                return -1;
            }
            if(fread(gold, outC, NHW, fp) != NHW) {
                std::cerr << "ERROR: not enought data (" << NHW * outC << " bytes) in " << argList[1] << std::endl;
                return -1;
            }
            fclose(fp);
            size_t errCount = 0;
            for(size_t i = 0; i < NHW * outC; i += outC) {
                bool isErr = false;
                for(size_t j = 0; j < outC; j++) {
                    if(!(buf[i+j] == gold[i+j]))
                        isErr = true;
                }
                if(isErr)
                    errCount++;
            }
            delete[] gold;
            float errPercent = 100.0f * (float)errCount / (float)count;
            bool isError = errPercent > errPercentLimit;
            printf("%%s: [Percent-Error %%e%%%%] for %%s with %%s\\n",
                isError ? "ERROR" : "OK", errPercent, tensorName.c_str(), argList[1].c_str());
            if(isError) {
                return -1;
            }
        }
        delete[] buf;
""" % (len(argmaxOutput), \
            '*pb++ = g_lut__a[best_c];' if len(argmaxOutput) == 4 else '', \
            'CV_8UC4' if len(argmaxOutput) == 4 else 'CV_8UC3', len(argmaxOutput), \
            'dst[3] = src[3];' if len(argmaxOutput) == 4 else '', len(argmaxOutput), \
            len(argmaxOutput)))
        elif type(argmaxOutput) is str:
            f.write( \
"""
        vx_size W = 1, H = 1, C = 1, N = 1;
        if(num_of_dims == 2) {
            C = dims[0], N = dims[1];
        }
        else if(num_of_dims == 4) {
            W = dims[0], H = dims[1], C = dims[2], N = dims[3];
        }
        vx_size HW = H * W;
        vx_size NHW = N * HW;
        vx_size CHW = C * HW;
        %s * buf = new %s[NHW], * pb = buf;
        for(vx_size n = 0; n < N; n++) {
            for(vx_size y = 0; y < H; y++) {
                for(vx_size x = 0; x < W; x++) {
                    vx_size best_c = 0;
                    if (data_type == VX_TYPE_FLOAT32) {
                        float * pc = (float *)ptr + n * CHW + y * W + x;
                        float best_v = *pc;
                        for(vx_size c = 1; c < C; c++, pc += HW) {
                            if(*pc > best_v) {
                                best_v = *pc;
                                best_c = c;
                            }
                        }
                    }else {
                        half * pc = (half *)ptr + n * CHW + y * W + x;
                        half best_v = *pc;
                        for(vx_size c = 1; c < C; c++, pc += HW) {
                            if(*pc > best_v) {
                                best_v = *pc;
                                best_c = c;
                            }
                        }
                    }
                    *pb++ = (%s)best_c;
                }
            }
        }
        if(fileName != "-") {
            FILE * fp = fopen(fileName.c_str(), "wb");
            if(!fp) {
                std::cerr << "ERROR: unable to open: " << fileName << std::endl;
                return -1;
            }
            fwrite(buf, sizeof(%s), NHW, fp);
            fclose(fp);
        }
        if(argList.size() >= 2) {
            float errPercentLimit = 0;
            if(argList.size() >= 3) {
                errPercentLimit = (float)atof(argList[2].c_str());
            }
            %s * gold = new %s [NHW];
            FILE * fp = fopen(argList[1].c_str(), "rb");
            if(!fp) {
                std::cerr << "ERROR: unable to open: " << argList[1] << std::endl;
                return -1;
            }
            if(fread(gold, sizeof(%s), NHW, fp) != NHW) {
                std::cerr << "ERROR: not enought data (" << count << " %s needed) in " << argList[1] << std::endl;
                return -1;
            }
            fclose(fp);
            size_t errCount = 0;
            for(size_t i = 0; i < NHW; i++) {
                if(!(buf[i] == gold[i]))
                    errCount++;
            }
            delete[] gold;
            float errPercent = 100.0f * (float)errCount / (float)count;
            bool isError = errPercent > errPercentLimit;
            printf("%%s: [Percent-Error %%e%%%%] for %%s with %%s\\n",
                isError ? "ERROR" : "OK", errPercent, tensorName.c_str(), argList[1].c_str());
            if(isError) {
                return -1;
            }
        }
        delete[] buf;
""" % (argmaxOutput, argmaxOutput, argmaxOutput, argmaxOutput, argmaxOutput, argmaxOutput, argmaxOutput, argmaxOutput))
        else:
            f.write( \
"""
        if(fileName != "-") {
            FILE * fp = fopen(fileName.c_str(), "wb");
            if(!fp) {
                std::cerr << "ERROR: unable to open: " << fileName << std::endl;
                return -1;
            }
            if (data_type == VX_TYPE_FLOAT32)
                fwrite(ptr, sizeof(float), count, fp);
            else if (data_type == VX_TYPE_FLOAT16)
                fwrite(ptr, sizeof(short), count, fp);
            else if (data_type == VX_TYPE_INT64)
                fwrite(ptr, sizeof(long int), count, fp); 
            else if (data_type == VX_TYPE_INT32)
                fwrite(ptr, sizeof(int), count, fp); 
            fclose(fp);
        }
        if(argList.size() >= 2) {
            float rmsErrorLimit = 0, maxErrorLimit = 0;
            if(argList.size() >= 3) {
                rmsErrorLimit = maxErrorLimit = (float)atof(argList[2].c_str());
                if(argList.size() >= 4) {
                    maxErrorLimit = (float)atof(argList[3].c_str());
                }
            }
            float * gold = new float [count];
            FILE * fp = fopen(argList[1].c_str(), "rb");
            if(!fp) {
                std::cerr << "ERROR: unable to open: " << argList[1] << std::endl;
                return -1;
            }
            if(fread(gold, sizeof(float), count, fp) != count) {
                std::cerr << "ERROR: not enought data (" << count << " floats needed) in " << argList[1] << std::endl;
                return -1;
            }
            fclose(fp);
            double sqrError = 0;
            float maxError = 0;
            if(data_type == VX_TYPE_FLOAT32) {
                for(size_t i = 0; i < count; i++) {
                    float err = ((float *)ptr)[i] - gold[i];
                    if(err < 0) err = -err;
                    sqrError += err * err;
                    if(!(err < maxError)) maxError = err;
                }
            }
            else if(data_type == VX_TYPE_FLOAT16)
            {
                for(size_t i = 0; i < count; i++) {
                    float src = _cvtsh_ss(((unsigned short*)ptr)[i]);
                    float err = src - gold[i];
                    if(err < 0) err = -err;
                    sqrError += err * err;
                    if(!(err < maxError)) maxError = err;
                }
            }
            delete[] gold;
            float rmsError = (float)sqrt(sqrError/count);
            bool isError = !(rmsError <= rmsErrorLimit) || !(maxError <= maxErrorLimit);
            printf("%s: [RMS-Error %e] [MAX-Error %e] for %s with %s\\n",
                isError ? "ERROR" : "OK", rmsError, maxError, tensorName.c_str(), argList[1].c_str());
            if(isError) {
                return -1;
            }
        }
""")
        f.write( \
"""
    }
    status = vxUnmapTensorPatch(tensor, map_id);
    if(status) {
        std::cerr << "ERROR: vxUnmapTensorPatch() failed for " << fileName << std::endl;
        return -1;
    }
    return 0;
}

int main(int argc, const char ** argv)
{
    // check command-line usage
    if(argc < 2) {
        printf(
            "\\n"
            "Usage: anntest <weights.bin> [<input-data-file(s)> [<output-data-file(s)>]] <--add ADD> <--multiply MULTIPLY>]\\n"
            "\\n"
            "   <weights.bin>: is a filename of the weights file to be used for the inference\\n"
            "   <input-data-file>: is a filename to initialize input tensor\\n"
            "   <output-data-file>: is a filename to initialize output tensor\\n"
""")
        f.write( \
"""#if ENABLE_OPENCV
            "     .jpg or .png: decode and initialize for 3 channel tensors\\n"
            "         (use %%04d in fileName when batch-size > 1: batch index starts from 0)\\n"
#endif
""")
        f.write( \
"""            "     other: initialize tensor with raw data from the file\\n"
            "\\n"
""")
        if type(argmaxOutput) is np.ndarray:
            f.write( \
"""            "   <output-data-file>[,<reference-for-compare>,<percentMismatchLimit>]:\\n"
            "     <referece-to-compare> is raw tensor data of LUT output for comparision\\n"
            "     <percentMismatchLimit> is max mismatches (percent) allowed\\n"
            "     <output-data-file> is filename for saving output tensor data\\n"
#if ENABLE_OPENCV
            "       .png: save LUT output as PNG file(s)\\n"
            "         (use %%04d in fileName to when batch-size > 1: batch index starts from 0)\\n"
#endif
""")
        elif type(argmaxOutput) is str:
            f.write( \
"""            "   <output-data-file>[,<reference-for-compare>,<percentMismatchLimit>]:\\n"
            "     <referece-to-compare> is raw tensor data of argmax output for comparision\\n"
            "     <percentMismatchLimit> is max mismatches (percent) allowed\\n"
            "     <output-data-file> is filename for saving output tensor data\\n"
""")
        else:
            f.write( \
"""            "   <output-data-file>[,<reference-for-compare>,<maxErrorLimit>,<rmsErrorLimit>]:\\n"
            "     <referece-to-compare> is raw tensor data for comparision\\n"
            "     <maxErrorLimit> is max absolute error allowed\\n"
            "     <rmsErrorLimit> is max RMS error allowed\\n"
            "     <output-data-file> is filename for saving output tensor data\\n"
""")
        f.write( \
"""            "       '-' to ignore\\n"
            "       other: save raw tensor into the file\\n"
            "   <add>: input preprocessing factor [optional - default:[0,0,0]]\\n"
            "   <multiply>: input preprocessing factor [optional - default:[1,1,1]]\\n"
            "\\n"
        );
        return -1;
    }
    const char * binaryFilename = argv[1];
    argc -= 2;
    argv += 2;

    std::string add = "0,0,0", multiply = "1,1,1";
    for (int i = 0; i < argc; i++) {
        if (strcasecmp(argv[i], "--add") == 0) {
            add = argv[i+1];
        }
        if (strcasecmp(argv[i], "--multiply") == 0) {
            multiply = argv[i+1];
        }
    }

    // create context, input, output, and graph
    vxRegisterLogCallback(NULL, log_callback, vx_false_e);
    vx_context context = vxCreateContext();
    vx_status status = vxGetStatus((vx_reference)context);
    if(status) {
        printf("ERROR: vxCreateContext() failed\\n");
        return -1;
    }
    vxRegisterLogCallback(context, log_callback, vx_false_e);
    vx_graph graph = vxCreateGraph(context);
    status = vxGetStatus((vx_reference)graph);
    if(status) {
        printf("ERROR: vxCreateGraph(...) failed (%d)\\n", status);
        return -1;
    }
""")
        for tensor in graph.inputs:
            f.write( \
"""
    // create and initialize input tensor %s
    vx_size dims_%s[%d] = { %s };
    vx_tensor %s = vxCreateTensor(context, %d, dims_%s, %s, 0);
    if(vxGetStatus((vx_reference)%s)) {
        printf("ERROR: vxCreateTensor() failed for %s\\n");
        return -1;
    }
    if(*argv) {
        if(strcmp(*argv, "-") != 0) {
            if(copyTensor("%s", %s, *argv, VX_WRITE_ONLY, add, multiply) < 0) {
                return -1;
            }
            printf("OK: initialized tensor '%s' from %%s\\n", *argv);
        }
        argv++;
    }
""" % (tensor.name, tensor.name, len(tensor.shape), ', '.join([str(v) for v in reversed(tensor.shape)]), \
       tensor.name, len(tensor.shape), tensor.name, tensor_type_nnir2openvx[tensor.type], \
       tensor.name, tensor.name, tensor.name, tensor.name, tensor.name))
        for tensor in graph.outputs:
            f.write( \
"""
    // create output tensor %s
    vx_size dims_%s[%d] = { %s };
    vx_tensor %s = vxCreateTensor(context, %d, dims_%s, %s, 0);
    if(vxGetStatus((vx_reference)%s)) {
        printf("ERROR: vxCreateTensor() failed for %s\\n");
        return -1;
    }
""" % (tensor.name, tensor.name, len(tensor.shape), ', '.join([str(v) for v in reversed(tensor.shape)]), \
       tensor.name, len(tensor.shape), tensor.name, tensor_type_nnir2openvx[tensor.type], \
       tensor.name, tensor.name))
        f.write( \
"""
    // build graph using annmodule
    int64_t freq = clockFrequency(), t0, t1;
    t0 = clockCounter();
""")
        if virtual_tensor_flag == 0:
            f.write( \
"""    std::map<std::string, vx_tensor> tensorMap;
    status = annAddToGraph(graph, %s, %s, tensorMap, binaryFilename);
""" % (', '.join([tensor.name for tensor in graph.inputs]), \
       ', '.join([tensor.name for tensor in graph.outputs])))
        else:
            f.write( \
"""    status = annAddToGraph(graph, %s, %s, binaryFilename);
""" % (', '.join([tensor.name for tensor in graph.inputs]), \
       ', '.join([tensor.name for tensor in graph.outputs])))
        f.write( \
"""    if(status) {
        printf("ERROR: annAddToGraph() failed (%d)\\n", status);
        return -1;
    }
    status = vxVerifyGraph(graph);
    if(status) {
        printf("ERROR: vxVerifyGraph(...) failed (%d)\\n", status);
        return -1;
    }
    t1 = clockCounter();
    printf("OK: graph initialization with annAddToGraph() took %.3f msec\\n", (float)(t1-t0)*1000.0f/(float)freq);

    t0 = clockCounter();
    status = vxProcessGraph(graph);
    t1 = clockCounter();
    if(status != VX_SUCCESS) {
        printf("ERROR: vxProcessGraph() failed (%d)\\n", status);
        return -1;
    }
    printf("OK: vxProcessGraph() took %.3f msec (1st iteration)\\n", (float)(t1-t0)*1000.0f/(float)freq);
""" )
        for tensor in graph.outputs:
            f.write( \
"""
    // save tensor %s
    if(*argv) {
        if(strcmp(*argv, "-") != 0) {
            if(copyTensor("%s", %s, *argv, VX_READ_ONLY, add, multiply) < 0) {
                return -1;
            }
            printf("OK: wrote tensor '%s' into %%s\\n", *argv);
        }
        argv++;
    }
""" % (tensor.name, tensor.name, tensor.name, tensor.name))
        if virtual_tensor_flag == 0:
            for tensor in graph.locals:
            	f.write( \
"""
    // save tensor %s
    auto it_%s = tensorMap.find("%s");
   	if(copyTensor("%s", it_%s->second, "%s", VX_READ_ONLY, add, multiply) < 0) {
        return -1;
    }
    printf("OK: wrote tensor '%s' into %s\\n");
""" % (tensor.name, tensor.name, tensor.name, tensor.name, tensor.name, tensor.name, tensor.name,tensor.name))
        f.write( \
"""
    t0 = clockCounter();
    int N = 100;
    for(int i = 0; i < N; i++) {
        status = vxProcessGraph(graph);
        if(status != VX_SUCCESS)
            break;
    }
    t1 = clockCounter();
    printf("OK: vxProcessGraph() took %.3f msec (average over %d iterations)\\n", (float)(t1-t0)*1000.0f/(float)freq/(float)N, N);

    // release resources
    ERROR_CHECK_STATUS(vxReleaseGraph(&graph));
""")
        for tensor in graph.inputs:
            f.write( \
"""    ERROR_CHECK_STATUS(vxReleaseTensor(&%s));
""" % (tensor.name))
        for tensor in graph.outputs:
            f.write( \
"""    ERROR_CHECK_STATUS(vxReleaseTensor(&%s));
""" % (tensor.name))
        f.write( \
"""    ERROR_CHECK_STATUS(vxReleaseContext(&context));
    printf("OK: successful\\n");

    return 0;
}
""")

def generateBinary(graph,fileName):
    """Write the initializer tensors of ``graph`` into the weights file ``fileName``.

    File layout, every integer packed with ``struct`` format ``'I'`` (native
    byte order and size):

      * one file-magic word,
      * for each tensor in ``graph.initializers``: a data-magic word, the
        payload length in bytes, then the payload ``graph.binaries[tensor.name]``,
      * one end-of-file magic word.

    A progress line is printed before the file is opened.
    """
    file_magic = 0xF00DD1E0
    data_magic = 0xF00DD1E1
    eof_magic = 0xF00DD1E2
    # Precompiled packers: one for a lone magic word, one for a record header.
    single_word = struct.Struct('I')
    record_header = struct.Struct('II')
    print('creating ' + fileName + ' ...')
    with open(fileName, 'wb') as out:
        emit = out.write
        emit(single_word.pack(file_magic))
        for initializer in graph.initializers:
            payload = graph.binaries[initializer.name]
            emit(record_header.pack(data_magic, len(payload)))
            emit(payload)
        emit(single_word.pack(eof_magic))

def generateCode(graph,argmaxOutput,outputFolder,virtual_tensor_flag):
    """Generate every output file of the OpenVX project for ``graph``.

    Creates ``outputFolder`` when needed and then runs each generator in turn:
    the CMake files, ``annmodule.h``/``annmodule.cpp``, ``weights.bin``,
    ``anntest.cpp``, ``annpython.h``/``annpython.cpp`` and ``anntest.py``.

    Args:
        graph: IR graph handed unchanged to every generator.
        argmaxOutput: argmax configuration, forwarded only to generateTestCPP.
        outputFolder: destination directory; missing parent directories are
            created as well.
        virtual_tensor_flag: forwarded to the module, test and python generators.

    Raises:
        OSError: if ``outputFolder`` cannot be created and is not already a
            directory (for example, the path names an existing regular file).
    """
    # EAFP with makedirs: supports nested output paths and avoids the race
    # between checking for the directory and creating it.
    try:
        os.makedirs(outputFolder)
    except OSError:
        if not os.path.isdir(outputFolder):
            raise
    generateCMakeFiles(graph,outputFolder)
    generateModuleH(graph,outputFolder + '/annmodule.h', virtual_tensor_flag)
    generateModuleCPP(graph,outputFolder + '/annmodule.cpp',virtual_tensor_flag)
    generateBinary(graph,outputFolder + '/weights.bin')
    generateTestCPP(graph,argmaxOutput,outputFolder + '/anntest.cpp', virtual_tensor_flag)
    generatePythonH(graph,outputFolder + '/annpython.h', virtual_tensor_flag)
    generatePythonCPP(graph,outputFolder + '/annpython.cpp', virtual_tensor_flag)
    generatePythonScriptSample(graph,outputFolder + '/anntest.py', virtual_tensor_flag)

def _loadArgmaxLUT(fileName):
    """Load the color look-up table used by ``--argmax <file>``.

    A file name ending in ``rgba.txt`` is read as 4 values per label, any
    other name as 3 values per label. Values are whitespace-separated integers.

    Returns:
        The table reshaped to one row per label and then transposed, so that
        row ``i`` of the result holds channel ``i`` for all labels.

    On a missing file, a non-integer entry, or a value count that is not a
    multiple of the channel count, an ERROR line is printed and the process
    exits with status 1.
    """
    if not os.path.isfile(fileName):
        print('ERROR: unable to open: %s' % (fileName))
        sys.exit(1)
    channels = 4 if fileName[-8:] == 'rgba.txt' else 3
    with open(fileName,'r') as f:
        tokens = f.read().split()
    try:
        values = [int(v) for v in tokens]
    except ValueError:
        print('ERROR: non-integer value in LUT file: %s' % (fileName))
        sys.exit(1)
    # Check up front so a truncated table gives a readable message instead of
    # a reshape traceback.
    if len(values) % channels != 0:
        print('ERROR: LUT file must have %d values per label: %s' % (channels, fileName))
        sys.exit(1)
    return np.reshape(np.array(values), [-1, channels]).transpose()

def main():
    """Command-line entry point: read an NNIR model folder and emit OpenVX C code.

    Leading ``--option value`` pairs of ``sys.argv`` are consumed first:

      * ``--argmax UINT8`` / ``--argmax UINT16`` select ``'vx_uint8'`` /
        ``'vx_uint16'`` as argmax output type; any other value is treated as a
        LUT file name and loaded into an array,
      * ``--virtual_tensor <int>`` sets the virtual tensor flag (default 1),
      * anything else (including ``--help``) prints the usage text and exits.

    The two remaining arguments are the NNIR input folder and the output
    folder. One ``#OUTPUT-TENSOR`` line is printed for every graph output of
    rank 1, 2 or 4 (shape padded with 1s to four values) before the code is
    generated.

    Every usage or option error exits the process with status 1.
    """
    usage = """
Usage: python nnir_to_openvx.py [OPTIONS] <nnirInputFolder> <outputFolder>

  OPTIONS:
    --virtual_tensor 0                -- to make tensors non-virtual  (default: 1)
    --argmax UINT8                    -- argmax at the end with 8-bit output
    --argmax UINT16                   -- argmax at the end with 16-bit output
    --argmax <fileNamePrefix>rgb.txt  -- argmax at the end with RGB color mapping using LUT
    --argmax <fileNamePrefix>rgba.txt -- argmax at the end with RGBA color mapping using LUT
    --help                            -- show this help message

  LUT File Format (RGB): 8-bit R G B values one per each label in text format
    R0 G0 B0
    R1 G1 B1
    ...

  LUT File Format (RGBA): 8-bit R G B A values one per each label in text format
    R0 G0 B0 A0
    R1 G1 B1 A1
    ...

"""
    pos = 1
    argmaxOutput = None
    virtual_tensor_flag = 1
    # Options are only recognized while at least two arguments remain, so a
    # lone trailing option falls through to the usage message below.
    while len(sys.argv[pos:]) >= 2 and sys.argv[pos][:2] == '--':
        option = sys.argv[pos]
        value = sys.argv[pos+1]
        if option == '--argmax':
            if value == 'UINT8':
                argmaxOutput = 'vx_uint8'
            elif value == 'UINT16':
                argmaxOutput = 'vx_uint16'
            else:
                argmaxOutput = _loadArgmaxLUT(value)
        elif option == '--virtual_tensor':
            try:
                virtual_tensor_flag = int(value)
            except ValueError:
                print('ERROR: invalid value for --virtual_tensor: %s' % (value))
                print(usage)
                sys.exit(1)
        else:
            if option != '--help':
                print('ERROR: invalid option: %s' % (option))
            print(usage)
            sys.exit(1)
        pos = pos + 2
    if len(sys.argv[pos:]) < 2:
        print(usage)
        sys.exit(1)
    inputFolder = sys.argv[pos]
    outputFolder = sys.argv[pos+1]
    print('reading IR model from ' + inputFolder + ' ...')
    graph = IrGraph(True)
    graph.fromFile(inputFolder)
    for tensor in graph.outputs:
        rank = len(tensor.shape)
        # Only ranks 1, 2 and 4 are reported; missing trailing dims print as 1.
        if rank in (1, 2, 4):
            dims = list(tensor.shape) + [1] * (4 - rank)
            print('#OUTPUT-TENSOR: %s %d %d %d %d ' %(tensor.name, dims[0], dims[1], dims[2], dims[3]))
    print('creating C code in ' + outputFolder + ' ...')
    generateCode(graph,argmaxOutput,outputFolder,virtual_tensor_flag)

# Run the converter only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
