# RUN: %PYTHON %s | FileCheck %s

import gc
from mlir.ir import *


def run(f):
  """Run test function ``f`` immediately and verify it leaked no Context.

  Prints a ``TEST: <name>`` banner (matched by the FileCheck labels in this
  file), calls ``f``, forces a garbage collection, and asserts that no
  ``Context`` objects remain alive afterwards. ``f`` is returned unchanged,
  so this works as a decorator that executes each test at import time.
  """
  test_name = f.__name__
  print("\nTEST:", test_name)
  f()
  # Free objects kept alive only by reference cycles before counting.
  gc.collect()
  leaked = Context._get_live_count()
  assert leaked == 0
  return f


# CHECK-LABEL: TEST: testAffineMapCapsule
@run
def testAffineMapCapsule():
  """Round-trip an AffineMap through its C-API capsule and back."""
  with Context() as ctx:
    empty_map = AffineMap.get_empty(ctx)
  # CHECK: mlir.ir.AffineMap._CAPIPtr
  capsule = empty_map._CAPIPtr
  print(capsule)
  # Rebuilding from the capsule must give an equal map that is bound to the
  # very same context object.
  restored = AffineMap._CAPICreate(capsule)
  assert restored == empty_map
  assert restored.context is ctx


# CHECK-LABEL: TEST: testAffineMapGet
@run
def testAffineMapGet():
  """Build maps with each AffineMap factory and exercise the error paths."""

  def print_error(expected_type, factory, *args):
    # Call ``factory(*args)``; if it raises ``expected_type``, print the
    # exception message. Any other exception type propagates unchanged.
    try:
      factory(*args)
    except expected_type as err:
      print(err)

  with Context():
    d0 = AffineDimExpr.get(0)
    d1 = AffineDimExpr.get(1)
    c2 = AffineConstantExpr.get(2)

    # CHECK: (d0, d1)[s0, s1, s2] -> ()
    print(AffineMap.get(2, 3, []))

    # CHECK: (d0, d1)[s0, s1, s2] -> (d1, 2)
    map1 = AffineMap.get(2, 3, [d1, c2])
    print(map1)

    # CHECK: () -> (2)
    map2 = AffineMap.get(0, 0, [c2])
    print(map2)

    # CHECK: (d0, d1) -> (d0, d1)
    map3 = AffineMap.get(2, 0, [d0, d1])
    print(map3)

    # CHECK: (d0, d1) -> (d1)
    map4 = AffineMap.get(2, 0, [d1])
    print(map4)

    # CHECK: (d0, d1, d2) -> (d2, d0, d1)
    print(AffineMap.get_permutation([2, 0, 1]))

    # Maps built in different ways must compare equal.
    equal_pairs = (
        (map1, AffineMap.get(2, 3, [d1, c2])),
        (AffineMap.get(0, 0, []), AffineMap.get_empty()),
        (map2, AffineMap.get_constant(2)),
        (map3, AffineMap.get_identity(2)),
        (map4, AffineMap.get_minor_identity(2, 1)),
    )
    for lhs, rhs in equal_pairs:
      assert lhs == rhs

    # Invalid inputs: a non-expression, None, a repeated permutation index,
    # and out-of-range submap requests.
    # CHECK: Invalid expression when attempting to create an AffineMap
    print_error(RuntimeError, AffineMap.get, 1, 1, [1])
    # CHECK: Invalid expression (None?) when attempting to create an AffineMap
    print_error(RuntimeError, AffineMap.get, 1, 1, [None])
    # CHECK: Invalid permutation when attempting to create an AffineMap
    print_error(RuntimeError, AffineMap.get_permutation, [1, 0, 1])
    # CHECK: result position out of bounds
    print_error(ValueError, map3.get_submap, [42])
    # CHECK: number of results out of bounds
    print_error(ValueError, map3.get_minor_submap, 42)
    # CHECK: number of results out of bounds
    print_error(ValueError, map3.get_major_submap, 42)


# CHECK-LABEL: TEST: testAffineMapDerive
@run
def testAffineMapDerive():
  """Derive submaps (arbitrary, leading, trailing) from a 5-D identity."""
  with Context():
    full = AffineMap.get_identity(5)

    # Results at explicit positions 1, 2 and 3.
    # CHECK: (d0, d1, d2, d3, d4) -> (d1, d2, d3)
    print(full.get_submap([1, 2, 3]))

    # The first two results.
    # CHECK: (d0, d1, d2, d3, d4) -> (d0, d1)
    print(full.get_major_submap(2))

    # The last two results.
    # CHECK: (d0, d1, d2, d3, d4) -> (d3, d4)
    print(full.get_minor_submap(2))


# CHECK-LABEL: TEST: testAffineMapProperties
@run
def testAffineMapProperties():
  """Print the permutation predicates for three representative maps."""
  with Context():
    d0, d1, d2 = (AffineDimExpr.get(i) for i in range(3))
    # Drops d1: expected to be a projected permutation but not a permutation.
    dropped = AffineMap.get(3, 0, [d2, d0])
    # Reorders all three dims: expected to satisfy both predicates.
    reordered = AffineMap.get(3, 0, [d2, d0, d1])
    # Same reordering plus one symbol: expected to satisfy neither.
    with_symbol = AffineMap.get(3, 1, [d2, d0, d1])

    # Expected output, two lines per map (is_permutation, then
    # is_projected_permutation), in the order the maps are visited below:
    # CHECK: False
    # CHECK: True
    # CHECK: True
    # CHECK: True
    # CHECK: False
    # CHECK: False
    for affine_map in (dropped, reordered, with_symbol):
      print(affine_map.is_permutation)
      print(affine_map.is_projected_permutation)


# CHECK-LABEL: TEST: testAffineMapExprs
@run
def testAffineMapExprs():
  """Inspect dim/symbol counts and the result-expression list of a map."""
  with Context():
    d0 = AffineDimExpr.get(0)
    d1 = AffineDimExpr.get(1)
    d2 = AffineDimExpr.get(2)
    # Three dims and one symbol; results permuted as (d2, d0, d1).
    map3 = AffineMap.get(3, 1, [d2, d0, d1])

    # CHECK: 3
    print(map3.n_dims)
    # CHECK: 4
    print(map3.n_inputs)
    # CHECK: 1
    print(map3.n_symbols)
    # n_inputs counts both dims and symbols.
    assert map3.n_inputs == map3.n_dims + map3.n_symbols

    # CHECK: 3
    print(len(map3.results))
    # Forward iteration over the results list.
    for expr in map3.results:
      # CHECK: d2
      # CHECK: d0
      # CHECK: d1
      print(expr)
    # Negative-step slice: exercises slicing support on the results list and
    # yields the results in reverse order.
    for expr in map3.results[-1:-4:-1]:
      # CHECK: d1
      # CHECK: d0
      # CHECK: d2
      print(expr)
    # The results compare equal to the expressions the map was built from.
    assert list(map3.results) == [d2, d0, d1]


# CHECK-LABEL: TEST: testCompressUnusedSymbols
@run
def testCompressUnusedSymbols():
  """Drop symbols no map in a list uses, renumbering the remaining ones."""
  with Context() as ctx:
    d0, d1, d2 = [AffineDimExpr.get(i) for i in range(3)]
    s0, s1, s2 = [AffineSymbolExpr.get(i) for i in range(3)]
    # Only s2 appears (in the second map); s0 and s1 are unused everywhere.
    result_lists = ([d2, d0, d1], [d2, d0 + s2, d1], [d1, d2, d0])
    maps = [AffineMap.get(3, 3, results) for results in result_lists]

    compressed_maps = AffineMap.compress_unused_symbols(maps, ctx)

    #      CHECK: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d2, d0, d1))
    # CHECK-SAME: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s2, d1))
    # CHECK-SAME: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d1, d2, d0))
    print(maps)

    #      CHECK: AffineMap((d0, d1, d2)[s0] -> (d2, d0, d1))
    # CHECK-SAME: AffineMap((d0, d1, d2)[s0] -> (d2, d0 + s0, d1))
    # CHECK-SAME: AffineMap((d0, d1, d2)[s0] -> (d1, d2, d0))
    print(compressed_maps)


# CHECK-LABEL: TEST: testReplace
@run
def testReplace():
  """Substitute the constant 42 for each symbol of a map in turn."""
  with Context():
    d0, d1, d2 = [AffineDimExpr.get(i) for i in range(3)]
    s0, s1, s2 = [AffineSymbolExpr.get(i) for i in range(3)]
    base = AffineMap.get(3, 3, [d2, d0 + s1 + s2, d1 + s0])
    forty_two = AffineConstantExpr.get(42)

    # The trailing arguments give the result's dim and symbol counts; per the
    # expected output, replacing s2 with a symbol count of 2 drops s2.
    # CHECK: (d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s1 + s2, d1 + 42)
    # CHECK: (d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s2 + 42, d1 + s0)
    # CHECK: (d0, d1, d2)[s0, s1] -> (d2, d0 + s1 + 42, d1 + s0)
    for symbol, n_symbols in ((s0, 3), (s1, 3), (s2, 2)):
      print(base.replace(symbol, forty_two, 3, n_symbols))


# CHECK-LABEL: TEST: testHash
@run
def testHash():
  """Check that AffineMap hashing is consistent with equality.

  Equal maps obtained from separate ``AffineMap.get`` calls must hash
  identically and be interchangeable as dict keys, while different maps must
  remain distinct keys.
  """
  with Context():
    d0, d1 = AffineDimExpr.get(0), AffineDimExpr.get(1)
    m1 = AffineMap.get(2, 0, [d0, d1])
    m2 = AffineMap.get(2, 0, [d1, d0])
    assert hash(m1) == hash(AffineMap.get(2, 0, [d0, d1]))

    dictionary = {}
    dictionary[m1] = 1
    dictionary[m2] = 2
    assert m1 in dictionary
    # m2 has different results, so it must be a separate key.
    assert m2 in dictionary
    assert len(dictionary) == 2
    assert dictionary[m2] == 2
    # A freshly built map equal to m1 must find m1's entry; this needs both
    # __hash__ and __eq__ to agree.
    assert dictionary[AffineMap.get(2, 0, [d0, d1])] == 1
