/*
 * Copyright © 2010 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
 * DEALINGS IN THE SOFTWARE.
 */

/**
 * \file brw_wm_channel_expressions.cpp
 *
 * Breaks vector operations down into operations on each component.
 *
 * The 965 fragment shader receives 8 or 16 pixels at a time, so each
 * channel of a vector is laid out as 1 or 2 8-float registers.  Each
 * ALU operation operates on one of those channel registers.  As a
 * result, there is no value to the 965 fragment shader in tracking
 * "vector" expressions in the sense of GLSL fragment shaders, when
 * doing a channel at a time may help in constant folding, algebraic
 * simplification, and reducing the liveness of channel registers.
 *
 * The exception to the desire to break everything down to floats is
 * texturing.  The texture sampler returns a writemasked 4/8-register
 * sequence containing the texture values.  We don't want to dispatch
 * to the sampler separately for each channel we need, so we retain
 * the vector types in that case.
 */

extern "C" {
#include "main/core.h"
#include "brw_wm.h"
}
#include "glsl/ir.h"
#include "glsl/ir_expression_flattening.h"
#include "glsl/glsl_types.h"

class ir_channel_expressions_visitor : public ir_hierarchical_visitor {
public:
   ir_channel_expressions_visitor()
      : progress(false), mem_ctx(NULL)
   {
   }

   /* Replaces an assignment of a vector-operand expression with one
    * assignment per channel (inserted before the original).
    */
   ir_visitor_status visit_leave(ir_assignment *);

   /* Returns an rvalue reading component \p element of \p var (or the
    * whole variable when it is a scalar).
    */
   ir_rvalue *get_element(ir_variable *var, unsigned int element);

   /* Inserts "lhs = val" with write mask (1 << elem) before \p ir. */
   void assign(ir_assignment *ir, int elem, ir_rvalue *val);

   /* Set to true once any assignment has been broken down. */
   bool progress;

   /* ralloc context for newly created IR; taken lazily from the parent
    * of the first assignment with an expression RHS.
    */
   void *mem_ctx;
};

/* Flattening predicate: selects expressions that have at least one
 * vector-typed operand, excluding the interpolation opcodes, which need
 * to act on the whole vector (just like texturing).
 */
static bool
channel_expressions_predicate(ir_instruction *ir)
{
   ir_expression *expr = ir->as_expression();

   if (expr == NULL)
      return false;

   if (expr->operation == ir_unop_interpolate_at_centroid ||
       expr->operation == ir_binop_interpolate_at_offset ||
       expr->operation == ir_binop_interpolate_at_sample)
      return false;

   for (unsigned int n = 0; n < expr->get_num_operands(); n++) {
      if (expr->operands[n]->type->is_vector())
         return true;
   }

   return false;
}

/* Entry point: breaks vector expressions in \p instructions down into
 * per-channel scalar operations.  Returns true if anything changed.
 */
bool
brw_do_channel_expressions(exec_list *instructions)
{
   ir_channel_expressions_visitor splitter;

   /* Pull every expression that has a vector operand (as selected by
    * channel_expressions_predicate) out into its own assignment to a
    * temporary.  The visitor then only has to handle expressions that
    * appear as the RHS of an assignment.
    */
   do_expression_flattening(instructions, channel_expressions_predicate);

   visit_list_elements(&splitter, instructions);

   return splitter.progress;
}

/* Returns a new rvalue reading channel \p elem of \p var.  Scalars are
 * returned as a plain dereference (\p elem is ignored); vectors get a
 * single-component swizzle.
 */
ir_rvalue *
ir_channel_expressions_visitor::get_element(ir_variable *var, unsigned int elem)
{
   if (!var->type->is_scalar()) {
      assert(elem < var->type->components());
      ir_dereference *whole = new(mem_ctx) ir_dereference_variable(var);
      return new(mem_ctx) ir_swizzle(whole, elem, 0, 0, 0, 1);
   }

   return new(mem_ctx) ir_dereference_variable(var);
}

/* Inserts, before \p ir, an assignment of \p val to channel \p elem of a
 * clone of ir's LHS.
 */
void
ir_channel_expressions_visitor::assign(ir_assignment *ir, int elem, ir_rvalue *val)
{
   ir_dereference *dest = ir->lhs->clone(mem_ctx, NULL);

   /* The assignment being split was produced by expression flattening,
    * which always writes the whole destination (it never skips
    * flattening, even for plain variable assignments).  So the original
    * write mask must cover every component.
    */
   assert(ir->write_mask == (1 << ir->lhs->type->components()) - 1);

   ir_assignment *per_channel =
      new(mem_ctx) ir_assignment(dest, val, NULL, (1 << elem));
   ir->insert_before(per_channel);
}

/* Breaks "lhs = expr(vector operands...)" into one assignment per channel.
 *
 * Each operand is first copied into a temporary so it can be read once per
 * channel.  The per-channel assignments are then inserted before \p ir, and
 * \p ir itself is removed.  Expressions with no vector operand, and the
 * interpolation opcodes (which must see the whole vector), are left alone.
 */
ir_visitor_status
ir_channel_expressions_visitor::visit_leave(ir_assignment *ir)
{
   ir_expression *expr = ir->rhs->as_expression();
   bool found_vector = false;
   unsigned int i, vector_elements = 1;
   /* One temporary per operand.  This is sized for quad operations because
    * the operand-copy loop below runs before the switch that rejects the
    * quadops as unreachable.  With only three entries, a vector quadop
    * would overrun the array.
    */
   ir_variable *op_var[4];

   if (!expr)
      return visit_continue;

   if (!this->mem_ctx)
      this->mem_ctx = ralloc_parent(ir);

   /* The channel count comes from the first vector-typed operand. */
   for (i = 0; i < expr->get_num_operands(); i++) {
      if (expr->operands[i]->type->is_vector()) {
         found_vector = true;
         vector_elements = expr->operands[i]->type->vector_elements;
         break;
      }
   }
   if (!found_vector)
      return visit_continue;

   switch (expr->operation) {
      case ir_unop_interpolate_at_centroid:
      case ir_binop_interpolate_at_offset:
      case ir_binop_interpolate_at_sample:
         return visit_continue;

      default:
         break;
   }

   assert(expr->get_num_operands() <= sizeof(op_var) / sizeof(op_var[0]));

   /* Store the expression operands in temps so we can use them
    * multiple times.
    */
   for (i = 0; i < expr->get_num_operands(); i++) {
      ir_assignment *assign;
      ir_dereference *deref;

      assert(!expr->operands[i]->type->is_matrix());

      op_var[i] = new(mem_ctx) ir_variable(expr->operands[i]->type,
                                           "channel_expressions",
                                           ir_var_temporary);
      ir->insert_before(op_var[i]);

      deref = new(mem_ctx) ir_dereference_variable(op_var[i]);
      assign = new(mem_ctx) ir_assignment(deref,
                                          expr->operands[i],
                                          NULL);
      ir->insert_before(assign);
   }

   /* Scalar type of each per-channel result, taken from the LHS base type. */
   const glsl_type *element_type = glsl_type::get_instance(ir->lhs->type->base_type,
                                                           1, 1);

   /* OK, time to break down this vector operation. */
   switch (expr->operation) {
   case ir_unop_bit_not:
   case ir_unop_logic_not:
   case ir_unop_neg:
   case ir_unop_abs:
   case ir_unop_sign:
   case ir_unop_rcp:
   case ir_unop_rsq:
   case ir_unop_sqrt:
   case ir_unop_exp:
   case ir_unop_log:
   case ir_unop_exp2:
   case ir_unop_log2:
   case ir_unop_bitcast_i2f:
   case ir_unop_bitcast_f2i:
   case ir_unop_bitcast_f2u:
   case ir_unop_bitcast_u2f:
   case ir_unop_i2u:
   case ir_unop_u2i:
   case ir_unop_f2i:
   case ir_unop_f2u:
   case ir_unop_i2f:
   case ir_unop_f2b:
   case ir_unop_b2f:
   case ir_unop_i2b:
   case ir_unop_b2i:
   case ir_unop_u2f:
   case ir_unop_trunc:
   case ir_unop_ceil:
   case ir_unop_floor:
   case ir_unop_fract:
   case ir_unop_round_even:
   case ir_unop_sin:
   case ir_unop_cos:
   case ir_unop_sin_reduced:
   case ir_unop_cos_reduced:
   case ir_unop_dFdx:
   case ir_unop_dFdy:
   case ir_unop_bitfield_reverse:
   case ir_unop_bit_count:
   case ir_unop_find_msb:
   case ir_unop_find_lsb:
      /* Component-wise unary: result.c = op(a.c). */
      for (i = 0; i < vector_elements; i++) {
         ir_rvalue *op0 = get_element(op_var[0], i);

         assign(ir, i, new(mem_ctx) ir_expression(expr->operation,
                                                  element_type,
                                                  op0,
                                                  NULL));
      }
      break;

   case ir_binop_add:
   case ir_binop_sub:
   case ir_binop_mul:
   case ir_binop_imul_high:
   case ir_binop_div:
   case ir_binop_carry:
   case ir_binop_borrow:
   case ir_binop_mod:
   case ir_binop_min:
   case ir_binop_max:
   case ir_binop_pow:
   case ir_binop_lshift:
   case ir_binop_rshift:
   case ir_binop_bit_and:
   case ir_binop_bit_xor:
   case ir_binop_bit_or:
   case ir_binop_less:
   case ir_binop_greater:
   case ir_binop_lequal:
   case ir_binop_gequal:
   case ir_binop_equal:
   case ir_binop_nequal:
      /* Component-wise binary; scalar operands are reused for every channel
       * (get_element ignores the index for scalars).
       */
      for (i = 0; i < vector_elements; i++) {
         ir_rvalue *op0 = get_element(op_var[0], i);
         ir_rvalue *op1 = get_element(op_var[1], i);

         assign(ir, i, new(mem_ctx) ir_expression(expr->operation,
                                                  element_type,
                                                  op0,
                                                  op1));
      }
      break;

   case ir_unop_any: {
      /* Reduce with logical OR into a single scalar result in channel 0. */
      ir_expression *temp;
      temp = new(mem_ctx) ir_expression(ir_binop_logic_or,
                                        element_type,
                                        get_element(op_var[0], 0),
                                        get_element(op_var[0], 1));

      for (i = 2; i < vector_elements; i++) {
         temp = new(mem_ctx) ir_expression(ir_binop_logic_or,
                                           element_type,
                                           get_element(op_var[0], i),
                                           temp);
      }
      assign(ir, 0, temp);
      break;
   }

   case ir_binop_dot: {
      /* Sum of per-channel products, written to channel 0. */
      ir_expression *last = NULL;
      for (i = 0; i < vector_elements; i++) {
         ir_rvalue *op0 = get_element(op_var[0], i);
         ir_rvalue *op1 = get_element(op_var[1], i);
         ir_expression *temp;

         temp = new(mem_ctx) ir_expression(ir_binop_mul,
                                           element_type,
                                           op0,
                                           op1);
         if (last) {
            last = new(mem_ctx) ir_expression(ir_binop_add,
                                              element_type,
                                              temp,
                                              last);
         } else {
            last = temp;
         }
      }
      assign(ir, 0, last);
      break;
   }

   case ir_binop_logic_and:
   case ir_binop_logic_xor:
   case ir_binop_logic_or:
      ir->fprint(stderr);
      fprintf(stderr, "\n");
      unreachable("not reached: expression operates on scalars only");
   case ir_binop_all_equal:
   case ir_binop_any_nequal: {
      /* Compare per channel, then combine with AND (all_equal) or OR
       * (any_nequal) into a single scalar in channel 0.
       */
      const ir_expression_operation join =
         (expr->operation == ir_binop_all_equal) ? ir_binop_logic_and
                                                 : ir_binop_logic_or;
      ir_expression *last = NULL;
      for (i = 0; i < vector_elements; i++) {
         ir_rvalue *op0 = get_element(op_var[0], i);
         ir_rvalue *op1 = get_element(op_var[1], i);
         ir_expression *temp;

         temp = new(mem_ctx) ir_expression(expr->operation,
                                           element_type,
                                           op0,
                                           op1);
         if (last) {
            last = new(mem_ctx) ir_expression(join,
                                              element_type,
                                              temp,
                                              last);
         } else {
            last = temp;
         }
      }
      assign(ir, 0, last);
      break;
   }
   case ir_unop_noise:
      unreachable("noise should have been broken down to function call");

   case ir_binop_bfm: {
      /* Does not need to be scalarized, since its result will be identical
       * for all channels.
       */
      ir_rvalue *op0 = get_element(op_var[0], 0);
      ir_rvalue *op1 = get_element(op_var[1], 0);

      assign(ir, 0, new(mem_ctx) ir_expression(expr->operation,
                                               element_type,
                                               op0,
                                               op1));
      break;
   }

   case ir_binop_ubo_load:
      unreachable("not yet supported");

   case ir_triop_fma:
   case ir_triop_lrp:
   case ir_triop_csel:
   case ir_triop_bitfield_extract:
      /* Component-wise ternary. */
      for (i = 0; i < vector_elements; i++) {
         ir_rvalue *op0 = get_element(op_var[0], i);
         ir_rvalue *op1 = get_element(op_var[1], i);
         ir_rvalue *op2 = get_element(op_var[2], i);

         assign(ir, i, new(mem_ctx) ir_expression(expr->operation,
                                                  element_type,
                                                  op0,
                                                  op1,
                                                  op2));
      }
      break;

   case ir_triop_bfi: {
      /* Only a single BFM is needed for multiple BFIs. */
      ir_rvalue *op0 = get_element(op_var[0], 0);

      for (i = 0; i < vector_elements; i++) {
         ir_rvalue *op1 = get_element(op_var[1], i);
         ir_rvalue *op2 = get_element(op_var[2], i);

         assign(ir, i, new(mem_ctx) ir_expression(expr->operation,
                                                  element_type,
                                                  op0->clone(mem_ctx, NULL),
                                                  op1,
                                                  op2));
      }
      break;
   }

   case ir_unop_pack_snorm_2x16:
   case ir_unop_pack_snorm_4x8:
   case ir_unop_pack_unorm_2x16:
   case ir_unop_pack_unorm_4x8:
   case ir_unop_pack_half_2x16:
   case ir_unop_unpack_snorm_2x16:
   case ir_unop_unpack_snorm_4x8:
   case ir_unop_unpack_unorm_2x16:
   case ir_unop_unpack_unorm_4x8:
   case ir_unop_unpack_half_2x16:
   case ir_binop_ldexp:
   case ir_binop_vector_extract:
   case ir_triop_vector_insert:
   case ir_quadop_bitfield_insert:
   case ir_quadop_vector:
      unreachable("should have been lowered");

   case ir_unop_unpack_half_2x16_split_x:
   case ir_unop_unpack_half_2x16_split_y:
   case ir_binop_pack_half_2x16_split:
   case ir_unop_interpolate_at_centroid:
   case ir_binop_interpolate_at_offset:
   case ir_binop_interpolate_at_sample:
      unreachable("not reached: expression operates on scalars only");
   }

   /* The per-channel assignments inserted above fully replace this one. */
   ir->remove();
   this->progress = true;

   return visit_continue;
}
