#!/usr/bin/python
# Copyright 2016 The ANGLE Project Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
#
# angle_format.py:
#  Utils for ANGLE formats.

import json
import os
import re

def reject_duplicate_keys(pairs):
    """Build a dict from (key, value) pairs, refusing repeated keys.

    load_json passes this as json's object_pairs_hook, so a JSON object that
    names the same key twice is an error rather than silently keeping the last
    value. load_forward_table / load_inverse_table also call it directly on the
    loaded entries to catch repeated GL formats.

    Raises:
        ValueError: on the first key that has already been seen.
    """
    result = {}
    for entry_key, entry_value in pairs:
        if entry_key in result:
            raise ValueError("duplicate key: %r" % (entry_key,))
        result[entry_key] = entry_value
    return result

def load_json(path):
    """Parse the JSON file at *path* and return the decoded value.

    Every JSON object is built through reject_duplicate_keys, so a key repeated
    inside any object raises ValueError instead of silently keeping the last value.
    """
    with open(path) as map_file:
        # The with-statement closes the file on exit; no explicit close() is needed.
        return json.load(map_file, object_pairs_hook=reject_duplicate_keys)

def load_forward_table(path):
    """Load a GL-format -> ANGLE-format dict from the JSON file at *path*.

    The file is expected to hold a sequence of two-element entries, GL format
    first and ANGLE format second. A GL format listed more than once raises
    ValueError (via reject_duplicate_keys).
    """
    entries = load_json(path)
    reject_duplicate_keys(entries)
    table = {}
    for gl_format, angle_format in entries:
        table[gl_format] = angle_format
    return table

def load_inverse_table(path):
    """Load the table from *path* keyed the other way: ANGLE format -> GL format.

    Repeated GL formats raise ValueError (via reject_duplicate_keys). Repeated
    ANGLE formats are NOT rejected: when several GL formats share one ANGLE
    format, the entry that appears last in the file wins.
    """
    entries = load_json(path)
    reject_duplicate_keys(entries)
    inverse = {}
    for gl_format, angle_format in entries:
        inverse[angle_format] = gl_format
    return inverse

def load_without_override():
    """Return the default GL -> ANGLE format table.

    Reads angle_format_map.json from the directory that contains this script,
    after resolving symlinks in the script's own path.
    """
    script_dir = os.path.dirname(os.path.realpath(__file__))
    map_path = os.path.join(script_dir, 'angle_format_map.json')
    return load_forward_table(map_path)

def load_with_override(override_path):
    """Return the default GL -> ANGLE format table with overrides applied.

    Entries from the JSON object in *override_path* replace (or add to) the
    defaults loaded by load_without_override.
    """
    results = load_without_override()
    overrides = load_json(override_path)

    # dict.update replaces the former iteritems() loop, which only existed on
    # Python 2; the resulting mapping is the same.
    results.update(overrides)

    return results

def get_component_type(format_id):
    """Classify an ANGLE format id by its component type.

    Substring markers are tried in a fixed order (SNORM before UNORM, etc.) and
    the first hit wins. Note that SINT maps to "int", not "sint". After the
    markers, "NONE" maps to "none", sRGB formats are treated as "unorm", and the
    shared-exponent format R9G9B9E5_SHAREDEXP is "float".

    Raises:
        ValueError: if none of the rules match.
    """
    ordered_markers = (
        ("SNORM", "snorm"),
        ("UNORM", "unorm"),
        ("FLOAT", "float"),
        ("UINT", "uint"),
        ("SINT", "int"),
    )
    for marker, component in ordered_markers:
        if marker in format_id:
            return component

    if format_id == "NONE":
        return "none"
    if "SRGB" in format_id:
        return "unorm"
    if format_id == "R9G9B9E5_SHAREDEXP":
        return "float"

    raise ValueError("Unknown component type for " + format_id)

# One channel token: a channel letter followed by its bit width, e.g. 'R8', 'D24', 'S8'.
_CHANNEL_TOKEN_RE = re.compile(r'([ABDGLRS]\d+)')

def get_channel_tokens(format_id):
    """Return the channel tokens of an ANGLE format id, in order, as a list.

    For example 'R8G8B8A8_UNORM' -> ['R8', 'G8', 'B8', 'A8'] and
    'D24_UNORM_S8_UINT' -> ['D24', 'S8']. Returns an empty list when there are none.

    findall yields exactly the pieces the previous split-then-filter produced,
    but always as a real list (filter() is a lazy iterator on Python 3, which
    broke callers that take len() of the result).
    """
    return _CHANNEL_TOKEN_RE.findall(format_id)

def get_channels(format_id):
    """Return the lower-case channel letters of an ANGLE format id, in order.

    For example 'R8G8B8A8_UNORM' -> 'rgba' and 'D24_UNORM_S8_UINT' -> 'ds'.
    Returns None when the id contains no channel tokens.
    """
    # list() makes len()/truthiness work whether the tokenizer hands back a
    # list or a lazy iterator.
    tokens = list(get_channel_tokens(format_id))
    if not tokens:
        return None
    return ''.join(token[0].lower() for token in tokens)

def get_bits(format_id):
    """Map each channel letter of an ANGLE format id to its bit width.

    For example 'R5G6B5_UNORM' -> {'R': 5, 'G': 6, 'B': 5}. Returns None when
    the id contains no channel tokens. If a letter appears more than once, the
    later width wins.
    """
    # list() makes len()/truthiness work whether the tokenizer hands back a
    # list or a lazy iterator.
    tokens = list(get_channel_tokens(format_id))
    if not tokens:
        return None
    return {token[0]: int(token[1:]) for token in tokens}

def get_format_info(format_id):
    """Return (component_type, bits, channels) for an ANGLE format id.

    The three parts come from get_component_type, get_bits and get_channels,
    evaluated in that order.
    """
    component_type = get_component_type(format_id)
    bits = get_bits(format_id)
    channels = get_channels(format_id)
    return component_type, bits, channels

# TODO(oetuaho): Expand this code so that it could generate the gl format info tables as well.
def gl_format_channels(internal_format):
    """Return the lower-case channel letters of a GL internal format name.

    Examples: 'GL_RGBA8' -> 'rgba', 'GL_LUMINANCE8_ALPHA8_EXT' -> 'la',
    'GL_DEPTH24_STENCIL8' -> 'ds', 'GL_ETC1_RGB8_OES' -> 'rgb'. A few names
    whose channel layout cannot be read off the leading letters are
    special-cased first.

    Raises:
        ValueError: if no 'GL_<LETTERS>' name can be found in internal_format.
    """
    if internal_format == 'GL_BGR5_A1_ANGLEX':
        return 'bgra'
    if internal_format == 'GL_R11F_G11F_B10F':
        return 'rgb'
    if internal_format == 'GL_RGB5_A1':
        return 'rgba'
    if internal_format.startswith('GL_RGB10_A2'):
        return 'rgba'

    # Skip the optional COMPRESSED_/SIGNED_/ETCn_ prefixes, then take the run of
    # capital letters naming the channels (e.g. 'RGBA', 'LUMINANCE', 'DEPTH').
    # Raw string: a bare '\d' in a normal literal is an invalid escape on Python 3.
    channels_pattern = re.compile(r'GL_(COMPRESSED_)?(SIGNED_)?(ETC\d_)?([A-Z]+)')
    match = channels_pattern.search(internal_format)
    if match is None:
        raise ValueError('Unrecognized GL internal format: ' + internal_format)
    channels_string = match.group(4)

    if channels_string == 'ALPHA':
        return 'a'
    if channels_string == 'LUMINANCE':
        return 'la' if 'ALPHA' in internal_format else 'l'
    if channels_string == 'SRGB':
        return 'rgba' if 'ALPHA' in internal_format else 'rgb'
    if channels_string == 'DEPTH':
        return 'ds' if 'STENCIL' in internal_format else 'd'
    if channels_string == 'STENCIL':
        return 's'
    return channels_string.lower()

# C++ initializer instantiations keyed by (component type, red-channel bit width)
# of an RGBA ANGLE format. In each, the template arguments are zero for the
# first three components and the storage type's encoding of one (1, 0xFF,
# 0x7F, Float16One, ...) for the fourth.
_ALPHA_ONE_INITIALIZERS = {
    ('uint', 8): 'Initialize4ComponentData<GLubyte, 0x00, 0x00, 0x00, 0x01>',
    ('unorm', 8): 'Initialize4ComponentData<GLubyte, 0x00, 0x00, 0x00, 0xFF>',
    # 16-bit storage type: the 0xFFFF alpha value does not fit in an 8-bit GLubyte.
    ('unorm', 16): 'Initialize4ComponentData<GLushort, 0x0000, 0x0000, 0x0000, 0xFFFF>',
    ('int', 8): 'Initialize4ComponentData<GLbyte, 0x00, 0x00, 0x00, 0x01>',
    ('snorm', 8): 'Initialize4ComponentData<GLbyte, 0x00, 0x00, 0x00, 0x7F>',
    ('snorm', 16): 'Initialize4ComponentData<GLushort, 0x0000, 0x0000, 0x0000, 0x7FFF>',
    ('float', 16): 'Initialize4ComponentData<GLhalf, 0x0000, 0x0000, 0x0000, gl::Float16One>',
    ('uint', 16): 'Initialize4ComponentData<GLushort, 0x0000, 0x0000, 0x0000, 0x0001>',
    ('int', 16): 'Initialize4ComponentData<GLshort, 0x0000, 0x0000, 0x0000, 0x0001>',
    ('float', 32): 'Initialize4ComponentData<GLfloat, 0x00000000, 0x00000000, 0x00000000, gl::Float32One>',
    ('int', 32): 'Initialize4ComponentData<GLint, 0x00000000, 0x00000000, 0x00000000, 0x00000001>',
    ('uint', 32): 'Initialize4ComponentData<GLuint, 0x00000000, 0x00000000, 0x00000000, 0x00000001>',
}

def get_internal_format_initializer(internal_format, format_id):
    """Return the C++ initializer for a GL format without alpha stored in an RGBA ANGLE format.

    Args:
        internal_format: GL internal format name, e.g. 'GL_RGB8'.
        format_id: ANGLE format id, e.g. 'R8G8B8A8_UNORM'.

    Returns:
        'nullptr' when no initializer is generated: the GL format's channels are
        neither 'rgb' nor 'l', the ANGLE format's channels are not 'rgba', or the
        ANGLE format is a BC1 format. Otherwise an Initialize4ComponentData<...>
        instantiation chosen by the ANGLE format's component type and red-channel
        bit width.

    Raises:
        ValueError: if no initializer is known for that component type and width.
    """
    gl_channels = gl_format_channels(internal_format)
    gl_format_no_alpha = gl_channels in ('rgb', 'l')
    component_type, bits, channels = get_format_info(format_id)

    if not gl_format_no_alpha or channels != 'rgba':
        return 'nullptr'

    if 'BC1_' in format_id:
        # BC1 is a special case since the texture data determines whether each block has an alpha channel or not.
        # This if statement is hit by COMPRESSED_RGB_S3TC_DXT1, which is a bit of a mess.
        # TODO(oetuaho): Look into whether COMPRESSED_RGB_S3TC_DXT1 works right in general.
        # Reference: https://www.opengl.org/registry/specs/EXT/texture_compression_s3tc.txt
        return 'nullptr'

    initializer = _ALPHA_ONE_INITIALIZERS.get((component_type, bits['R']))
    if initializer is None:
        raise ValueError('warning: internal format initializer could not be generated and may be needed for ' + internal_format)
    return initializer
