//===- ir.c - Simple test of C APIs ---------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

/* RUN: mlir-capi-ir-test 2>&1 | FileCheck %s
 */

#include "mlir-c/IR.h"
#include "mlir-c/AffineExpr.h"
#include "mlir-c/AffineMap.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h"
#include "mlir-c/Dialect/Standard.h"
#include "mlir-c/IntegerSet.h"
#include "mlir-c/Registration.h"
#include "mlir-c/Support.h"

#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

void populateLoopBody(MlirContext ctx, MlirBlock loopBody,
                      MlirLocation location, MlirBlock funcBody) {
  MlirValue iv = mlirBlockGetArgument(loopBody, 0);
  MlirValue funcArg0 = mlirBlockGetArgument(funcBody, 0);
  MlirValue funcArg1 = mlirBlockGetArgument(funcBody, 1);
  MlirType f32Type =
      mlirTypeParseGet(ctx, mlirStringRefCreateFromCString("f32"));

  MlirOperationState loadLHSState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("memref.load"), location);
  MlirValue loadLHSOperands[] = {funcArg0, iv};
  mlirOperationStateAddOperands(&loadLHSState, 2, loadLHSOperands);
  mlirOperationStateAddResults(&loadLHSState, 1, &f32Type);
  MlirOperation loadLHS = mlirOperationCreate(&loadLHSState);
  mlirBlockAppendOwnedOperation(loopBody, loadLHS);

  MlirOperationState loadRHSState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("memref.load"), location);
  MlirValue loadRHSOperands[] = {funcArg1, iv};
  mlirOperationStateAddOperands(&loadRHSState, 2, loadRHSOperands);
  mlirOperationStateAddResults(&loadRHSState, 1, &f32Type);
  MlirOperation loadRHS = mlirOperationCreate(&loadRHSState);
  mlirBlockAppendOwnedOperation(loopBody, loadRHS);

  MlirOperationState addState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("arith.addf"), location);
  MlirValue addOperands[] = {mlirOperationGetResult(loadLHS, 0),
                             mlirOperationGetResult(loadRHS, 0)};
  mlirOperationStateAddOperands(&addState, 2, addOperands);
  mlirOperationStateAddResults(&addState, 1, &f32Type);
  MlirOperation add = mlirOperationCreate(&addState);
  mlirBlockAppendOwnedOperation(loopBody, add);

  MlirOperationState storeState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("memref.store"), location);
  MlirValue storeOperands[] = {mlirOperationGetResult(add, 0), funcArg0, iv};
  mlirOperationStateAddOperands(&storeState, 3, storeOperands);
  MlirOperation store = mlirOperationCreate(&storeState);
  mlirBlockAppendOwnedOperation(loopBody, store);

  MlirOperationState yieldState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("scf.yield"), location);
  MlirOperation yield = mlirOperationCreate(&yieldState);
  mlirBlockAppendOwnedOperation(loopBody, yield);
}

MlirModule makeAndDumpAdd(MlirContext ctx, MlirLocation location) {
  MlirModule moduleOp = mlirModuleCreateEmpty(location);
  MlirBlock moduleBody = mlirModuleGetBody(moduleOp);

  MlirType memrefType =
      mlirTypeParseGet(ctx, mlirStringRefCreateFromCString("memref<?xf32>"));
  MlirType funcBodyArgTypes[] = {memrefType, memrefType};
  MlirRegion funcBodyRegion = mlirRegionCreate();
  MlirBlock funcBody = mlirBlockCreate(
      sizeof(funcBodyArgTypes) / sizeof(MlirType), funcBodyArgTypes);
  mlirRegionAppendOwnedBlock(funcBodyRegion, funcBody);

  MlirAttribute funcTypeAttr = mlirAttributeParseGet(
      ctx,
      mlirStringRefCreateFromCString("(memref<?xf32>, memref<?xf32>) -> ()"));
  MlirAttribute funcNameAttr =
      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("\"add\""));
  MlirNamedAttribute funcAttrs[] = {
      mlirNamedAttributeGet(
          mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("type")),
          funcTypeAttr),
      mlirNamedAttributeGet(
          mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("sym_name")),
          funcNameAttr)};
  MlirOperationState funcState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("builtin.func"), location);
  mlirOperationStateAddAttributes(&funcState, 2, funcAttrs);
  mlirOperationStateAddOwnedRegions(&funcState, 1, &funcBodyRegion);
  MlirOperation func = mlirOperationCreate(&funcState);
  mlirBlockInsertOwnedOperation(moduleBody, 0, func);

  MlirType indexType =
      mlirTypeParseGet(ctx, mlirStringRefCreateFromCString("index"));
  MlirAttribute indexZeroLiteral =
      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
  MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
      mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
      indexZeroLiteral);
  MlirOperationState constZeroState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("arith.constant"), location);
  mlirOperationStateAddResults(&constZeroState, 1, &indexType);
  mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
  MlirOperation constZero = mlirOperationCreate(&constZeroState);
  mlirBlockAppendOwnedOperation(funcBody, constZero);

  MlirValue funcArg0 = mlirBlockGetArgument(funcBody, 0);
  MlirValue constZeroValue = mlirOperationGetResult(constZero, 0);
  MlirValue dimOperands[] = {funcArg0, constZeroValue};
  MlirOperationState dimState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("memref.dim"), location);
  mlirOperationStateAddOperands(&dimState, 2, dimOperands);
  mlirOperationStateAddResults(&dimState, 1, &indexType);
  MlirOperation dim = mlirOperationCreate(&dimState);
  mlirBlockAppendOwnedOperation(funcBody, dim);

  MlirRegion loopBodyRegion = mlirRegionCreate();
  MlirBlock loopBody = mlirBlockCreate(0, NULL);
  mlirBlockAddArgument(loopBody, indexType);
  mlirRegionAppendOwnedBlock(loopBodyRegion, loopBody);

  MlirAttribute indexOneLiteral =
      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("1 : index"));
  MlirNamedAttribute indexOneValueAttr = mlirNamedAttributeGet(
      mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
      indexOneLiteral);
  MlirOperationState constOneState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("arith.constant"), location);
  mlirOperationStateAddResults(&constOneState, 1, &indexType);
  mlirOperationStateAddAttributes(&constOneState, 1, &indexOneValueAttr);
  MlirOperation constOne = mlirOperationCreate(&constOneState);
  mlirBlockAppendOwnedOperation(funcBody, constOne);

  MlirValue dimValue = mlirOperationGetResult(dim, 0);
  MlirValue constOneValue = mlirOperationGetResult(constOne, 0);
  MlirValue loopOperands[] = {constZeroValue, dimValue, constOneValue};
  MlirOperationState loopState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("scf.for"), location);
  mlirOperationStateAddOperands(&loopState, 3, loopOperands);
  mlirOperationStateAddOwnedRegions(&loopState, 1, &loopBodyRegion);
  MlirOperation loop = mlirOperationCreate(&loopState);
  mlirBlockAppendOwnedOperation(funcBody, loop);

  populateLoopBody(ctx, loopBody, location, funcBody);

  MlirOperationState retState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("std.return"), location);
  MlirOperation ret = mlirOperationCreate(&retState);
  mlirBlockAppendOwnedOperation(funcBody, ret);

  MlirOperation module = mlirModuleGetOperation(moduleOp);
  mlirOperationDump(module);
  // clang-format off
  // CHECK: module {
  // CHECK:   func @add(%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: memref<?xf32>) {
  // CHECK:     %[[C0:.*]] = arith.constant 0 : index
  // CHECK:     %[[DIM:.*]] = memref.dim %[[ARG0]], %[[C0]] : memref<?xf32>
  // CHECK:     %[[C1:.*]] = arith.constant 1 : index
  // CHECK:     scf.for %[[I:.*]] = %[[C0]] to %[[DIM]] step %[[C1]] {
  // CHECK:       %[[LHS:.*]] = memref.load %[[ARG0]][%[[I]]] : memref<?xf32>
  // CHECK:       %[[RHS:.*]] = memref.load %[[ARG1]][%[[I]]] : memref<?xf32>
  // CHECK:       %[[SUM:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
  // CHECK:       memref.store %[[SUM]], %[[ARG0]][%[[I]]] : memref<?xf32>
  // CHECK:     }
  // CHECK:     return
  // CHECK:   }
  // CHECK: }
  // clang-format on

  return moduleOp;
}

struct OpListNode {
  MlirOperation op;
  struct OpListNode *next;
};
typedef struct OpListNode OpListNode;

struct ModuleStats {
  unsigned numOperations;
  unsigned numAttributes;
  unsigned numBlocks;
  unsigned numRegions;
  unsigned numValues;
  unsigned numBlockArguments;
  unsigned numOpResults;
};
typedef struct ModuleStats ModuleStats;

int collectStatsSingle(OpListNode *head, ModuleStats *stats) {
  MlirOperation operation = head->op;
  stats->numOperations += 1;
  stats->numValues += mlirOperationGetNumResults(operation);
  stats->numAttributes += mlirOperationGetNumAttributes(operation);

  unsigned numRegions = mlirOperationGetNumRegions(operation);

  stats->numRegions += numRegions;

  intptr_t numResults = mlirOperationGetNumResults(operation);
  for (intptr_t i = 0; i < numResults; ++i) {
    MlirValue result = mlirOperationGetResult(operation, i);
    if (!mlirValueIsAOpResult(result))
      return 1;
    if (mlirValueIsABlockArgument(result))
      return 2;
    if (!mlirOperationEqual(operation, mlirOpResultGetOwner(result)))
      return 3;
    if (i != mlirOpResultGetResultNumber(result))
      return 4;
    ++stats->numOpResults;
  }

  MlirRegion region = mlirOperationGetFirstRegion(operation);
  while (!mlirRegionIsNull(region)) {
    for (MlirBlock block = mlirRegionGetFirstBlock(region);
         !mlirBlockIsNull(block); block = mlirBlockGetNextInRegion(block)) {
      ++stats->numBlocks;
      intptr_t numArgs = mlirBlockGetNumArguments(block);
      stats->numValues += numArgs;
      for (intptr_t j = 0; j < numArgs; ++j) {
        MlirValue arg = mlirBlockGetArgument(block, j);
        if (!mlirValueIsABlockArgument(arg))
          return 5;
        if (mlirValueIsAOpResult(arg))
          return 6;
        if (!mlirBlockEqual(block, mlirBlockArgumentGetOwner(arg)))
          return 7;
        if (j != mlirBlockArgumentGetArgNumber(arg))
          return 8;
        ++stats->numBlockArguments;
      }

      for (MlirOperation child = mlirBlockGetFirstOperation(block);
           !mlirOperationIsNull(child);
           child = mlirOperationGetNextInBlock(child)) {
        OpListNode *node = malloc(sizeof(OpListNode));
        node->op = child;
        node->next = head->next;
        head->next = node;
      }
    }
    region = mlirRegionGetNextInOperation(region);
  }
  return 0;
}

int collectStats(MlirOperation operation) {
  OpListNode *head = malloc(sizeof(OpListNode));
  head->op = operation;
  head->next = NULL;

  ModuleStats stats;
  stats.numOperations = 0;
  stats.numAttributes = 0;
  stats.numBlocks = 0;
  stats.numRegions = 0;
  stats.numValues = 0;
  stats.numBlockArguments = 0;
  stats.numOpResults = 0;

  do {
    int retval = collectStatsSingle(head, &stats);
    if (retval)
      return retval;
    OpListNode *next = head->next;
    free(head);
    head = next;
  } while (head);

  if (stats.numValues != stats.numBlockArguments + stats.numOpResults)
    return 100;

  fprintf(stderr, "@stats\n");
  fprintf(stderr, "Number of operations: %u\n", stats.numOperations);
  fprintf(stderr, "Number of attributes: %u\n", stats.numAttributes);
  fprintf(stderr, "Number of blocks: %u\n", stats.numBlocks);
  fprintf(stderr, "Number of regions: %u\n", stats.numRegions);
  fprintf(stderr, "Number of values: %u\n", stats.numValues);
  fprintf(stderr, "Number of block arguments: %u\n", stats.numBlockArguments);
  fprintf(stderr, "Number of op results: %u\n", stats.numOpResults);
  // clang-format off
  // CHECK-LABEL: @stats
  // CHECK: Number of operations: 12
  // CHECK: Number of attributes: 4
  // CHECK: Number of blocks: 3
  // CHECK: Number of regions: 3
  // CHECK: Number of values: 9
  // CHECK: Number of block arguments: 3
  // CHECK: Number of op results: 6
  // clang-format on
  return 0;
}

static void printToStderr(MlirStringRef str, void *userData) {
  (void)userData;
  fwrite(str.data, 1, str.length, stderr);
}

static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
  // Assuming we are given a module, go to the first operation of the first
  // function.
  MlirRegion region = mlirOperationGetRegion(operation, 0);
  MlirBlock block = mlirRegionGetFirstBlock(region);
  operation = mlirBlockGetFirstOperation(block);
  region = mlirOperationGetRegion(operation, 0);
  MlirOperation parentOperation = operation;
  block = mlirRegionGetFirstBlock(region);
  operation = mlirBlockGetFirstOperation(block);
  assert(mlirModuleIsNull(mlirModuleFromOperation(operation)));

  // Verify that parent operation and block report correctly.
  // CHECK: Parent operation eq: 1
  fprintf(stderr, "Parent operation eq: %d\n",
          mlirOperationEqual(mlirOperationGetParentOperation(operation),
                             parentOperation));
  // CHECK: Block eq: 1
  fprintf(stderr, "Block eq: %d\n",
          mlirBlockEqual(mlirOperationGetBlock(operation), block));
  // CHECK: Block parent operation eq: 1
  fprintf(
      stderr, "Block parent operation eq: %d\n",
      mlirOperationEqual(mlirBlockGetParentOperation(block), parentOperation));
  // CHECK: Block parent region eq: 1
  fprintf(stderr, "Block parent region eq: %d\n",
          mlirRegionEqual(mlirBlockGetParentRegion(block), region));

  // In the module we created, the first operation of the first function is
  // an "memref.dim", which has an attribute and a single result that we can
  // use to test the printing mechanism.
  mlirBlockPrint(block, printToStderr, NULL);
  fprintf(stderr, "\n");
  fprintf(stderr, "First operation: ");
  mlirOperationPrint(operation, printToStderr, NULL);
  fprintf(stderr, "\n");
  // clang-format off
  // CHECK:   %[[C0:.*]] = arith.constant 0 : index
  // CHECK:   %[[DIM:.*]] = memref.dim %{{.*}}, %[[C0]] : memref<?xf32>
  // CHECK:   %[[C1:.*]] = arith.constant 1 : index
  // CHECK:   scf.for %[[I:.*]] = %[[C0]] to %[[DIM]] step %[[C1]] {
  // CHECK:     %[[LHS:.*]] = memref.load %{{.*}}[%[[I]]] : memref<?xf32>
  // CHECK:     %[[RHS:.*]] = memref.load %{{.*}}[%[[I]]] : memref<?xf32>
  // CHECK:     %[[SUM:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
  // CHECK:     memref.store %[[SUM]], %{{.*}}[%[[I]]] : memref<?xf32>
  // CHECK:   }
  // CHECK: return
  // CHECK: First operation: {{.*}} = arith.constant 0 : index
  // clang-format on

  // Get the operation name and print it.
  MlirIdentifier ident = mlirOperationGetName(operation);
  MlirStringRef identStr = mlirIdentifierStr(ident);
  fprintf(stderr, "Operation name: '");
  for (size_t i = 0; i < identStr.length; ++i)
    fputc(identStr.data[i], stderr);
  fprintf(stderr, "'\n");
  // CHECK: Operation name: 'arith.constant'

  // Get the identifier again and verify equal.
  MlirIdentifier identAgain = mlirIdentifierGet(ctx, identStr);
  fprintf(stderr, "Identifier equal: %d\n",
          mlirIdentifierEqual(ident, identAgain));
  // CHECK: Identifier equal: 1

  // Get the block terminator and print it.
  MlirOperation terminator = mlirBlockGetTerminator(block);
  fprintf(stderr, "Terminator: ");
  mlirOperationPrint(terminator, printToStderr, NULL);
  fprintf(stderr, "\n");
  // CHECK: Terminator: return

  // Get the attribute by index.
  MlirNamedAttribute namedAttr0 = mlirOperationGetAttribute(operation, 0);
  fprintf(stderr, "Get attr 0: ");
  mlirAttributePrint(namedAttr0.attribute, printToStderr, NULL);
  fprintf(stderr, "\n");
  // CHECK: Get attr 0: 0 : index

  // Now re-get the attribute by name.
  MlirAttribute attr0ByName = mlirOperationGetAttributeByName(
      operation, mlirIdentifierStr(namedAttr0.name));
  fprintf(stderr, "Get attr 0 by name: ");
  mlirAttributePrint(attr0ByName, printToStderr, NULL);
  fprintf(stderr, "\n");
  // CHECK: Get attr 0 by name: 0 : index

  // Get a non-existing attribute and assert that it is null (sanity).
  fprintf(stderr, "does_not_exist is null: %d\n",
          mlirAttributeIsNull(mlirOperationGetAttributeByName(
              operation, mlirStringRefCreateFromCString("does_not_exist"))));
  // CHECK: does_not_exist is null: 1

  // Get result 0 and its type.
  MlirValue value = mlirOperationGetResult(operation, 0);
  fprintf(stderr, "Result 0: ");
  mlirValuePrint(value, printToStderr, NULL);
  fprintf(stderr, "\n");
  fprintf(stderr, "Value is null: %d\n", mlirValueIsNull(value));
  // CHECK: Result 0: {{.*}} = arith.constant 0 : index
  // CHECK: Value is null: 0

  MlirType type = mlirValueGetType(value);
  fprintf(stderr, "Result 0 type: ");
  mlirTypePrint(type, printToStderr, NULL);
  fprintf(stderr, "\n");
  // CHECK: Result 0 type: index

  // Set a custom attribute.
  mlirOperationSetAttributeByName(operation,
                                  mlirStringRefCreateFromCString("custom_attr"),
                                  mlirBoolAttrGet(ctx, 1));
  fprintf(stderr, "Op with set attr: ");
  mlirOperationPrint(operation, printToStderr, NULL);
  fprintf(stderr, "\n");
  // CHECK: Op with set attr: {{.*}} {custom_attr = true}

  // Remove the attribute.
  fprintf(stderr, "Remove attr: %d\n",
          mlirOperationRemoveAttributeByName(
              operation, mlirStringRefCreateFromCString("custom_attr")));
  fprintf(stderr, "Remove attr again: %d\n",
          mlirOperationRemoveAttributeByName(
              operation, mlirStringRefCreateFromCString("custom_attr")));
  fprintf(stderr, "Removed attr is null: %d\n",
          mlirAttributeIsNull(mlirOperationGetAttributeByName(
              operation, mlirStringRefCreateFromCString("custom_attr"))));
  // CHECK: Remove attr: 1
  // CHECK: Remove attr again: 0
  // CHECK: Removed attr is null: 1

  // Add a large attribute to verify printing flags.
  int64_t eltsShape[] = {4};
  int32_t eltsData[] = {1, 2, 3, 4};
  mlirOperationSetAttributeByName(
      operation, mlirStringRefCreateFromCString("elts"),
      mlirDenseElementsAttrInt32Get(
          mlirRankedTensorTypeGet(1, eltsShape, mlirIntegerTypeGet(ctx, 32),
                                  mlirAttributeGetNull()),
          4, eltsData));
  MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
  mlirOpPrintingFlagsElideLargeElementsAttrs(flags, 2);
  mlirOpPrintingFlagsPrintGenericOpForm(flags);
  mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/0);
  mlirOpPrintingFlagsUseLocalScope(flags);
  fprintf(stderr, "Op print with all flags: ");
  mlirOperationPrintWithFlags(operation, flags, printToStderr, NULL);
  fprintf(stderr, "\n");
  // clang-format off
  // CHECK: Op print with all flags: %{{.*}} = "arith.constant"() {elts = opaque<"_", "0xDEADBEEF"> : tensor<4xi32>, value = 0 : index} : () -> index loc(unknown)
  // clang-format on

  mlirOpPrintingFlagsDestroy(flags);
}

static int constructAndTraverseIr(MlirContext ctx) {
  MlirLocation location = mlirLocationUnknownGet(ctx);

  MlirModule moduleOp = makeAndDumpAdd(ctx, location);
  MlirOperation module = mlirModuleGetOperation(moduleOp);
  assert(!mlirModuleIsNull(mlirModuleFromOperation(module)));

  int errcode = collectStats(module);
  if (errcode)
    return errcode;

  printFirstOfEach(ctx, module);

  mlirModuleDestroy(moduleOp);
  return 0;
}

/// Creates an operation with a region containing multiple blocks with
/// operations and dumps it. The blocks and operations are inserted using
/// block/operation-relative API and their final order is checked.
static void buildWithInsertionsAndPrint(MlirContext ctx) {
  MlirLocation loc = mlirLocationUnknownGet(ctx);
  mlirContextSetAllowUnregisteredDialects(ctx, true);

  MlirRegion owningRegion = mlirRegionCreate();
  MlirBlock nullBlock = mlirRegionGetFirstBlock(owningRegion);
  MlirOperationState state = mlirOperationStateGet(
      mlirStringRefCreateFromCString("insertion.order.test"), loc);
  mlirOperationStateAddOwnedRegions(&state, 1, &owningRegion);
  MlirOperation op = mlirOperationCreate(&state);
  MlirRegion region = mlirOperationGetRegion(op, 0);

  // Use integer types of different bitwidth as block arguments in order to
  // differentiate blocks.
  MlirType i1 = mlirIntegerTypeGet(ctx, 1);
  MlirType i2 = mlirIntegerTypeGet(ctx, 2);
  MlirType i3 = mlirIntegerTypeGet(ctx, 3);
  MlirType i4 = mlirIntegerTypeGet(ctx, 4);
  MlirBlock block1 = mlirBlockCreate(1, &i1);
  MlirBlock block2 = mlirBlockCreate(1, &i2);
  MlirBlock block3 = mlirBlockCreate(1, &i3);
  MlirBlock block4 = mlirBlockCreate(1, &i4);
  // Insert blocks so as to obtain the 1-2-3-4 order,
  mlirRegionInsertOwnedBlockBefore(region, nullBlock, block3);
  mlirRegionInsertOwnedBlockBefore(region, block3, block2);
  mlirRegionInsertOwnedBlockAfter(region, nullBlock, block1);
  mlirRegionInsertOwnedBlockAfter(region, block3, block4);

  MlirOperationState op1State =
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op1"), loc);
  MlirOperationState op2State =
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op2"), loc);
  MlirOperationState op3State =
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op3"), loc);
  MlirOperationState op4State =
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op4"), loc);
  MlirOperationState op5State =
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op5"), loc);
  MlirOperationState op6State =
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op6"), loc);
  MlirOperationState op7State =
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op7"), loc);
  MlirOperation op1 = mlirOperationCreate(&op1State);
  MlirOperation op2 = mlirOperationCreate(&op2State);
  MlirOperation op3 = mlirOperationCreate(&op3State);
  MlirOperation op4 = mlirOperationCreate(&op4State);
  MlirOperation op5 = mlirOperationCreate(&op5State);
  MlirOperation op6 = mlirOperationCreate(&op6State);
  MlirOperation op7 = mlirOperationCreate(&op7State);

  // Insert operations in the first block so as to obtain the 1-2-3-4 order.
  MlirOperation nullOperation = mlirBlockGetFirstOperation(block1);
  assert(mlirOperationIsNull(nullOperation));
  mlirBlockInsertOwnedOperationBefore(block1, nullOperation, op3);
  mlirBlockInsertOwnedOperationBefore(block1, op3, op2);
  mlirBlockInsertOwnedOperationAfter(block1, nullOperation, op1);
  mlirBlockInsertOwnedOperationAfter(block1, op3, op4);

  // Append operations to the rest of blocks to make them non-empty and thus
  // printable.
  mlirBlockAppendOwnedOperation(block2, op5);
  mlirBlockAppendOwnedOperation(block3, op6);
  mlirBlockAppendOwnedOperation(block4, op7);

  mlirOperationDump(op);
  mlirOperationDestroy(op);
  mlirContextSetAllowUnregisteredDialects(ctx, false);
  // clang-format off
  // CHECK-LABEL:  "insertion.order.test"
  // CHECK:      ^{{.*}}(%{{.*}}: i1
  // CHECK:        "dummy.op1"
  // CHECK-NEXT:   "dummy.op2"
  // CHECK-NEXT:   "dummy.op3"
  // CHECK-NEXT:   "dummy.op4"
  // CHECK:      ^{{.*}}(%{{.*}}: i2
  // CHECK:        "dummy.op5"
  // CHECK:      ^{{.*}}(%{{.*}}: i3
  // CHECK:        "dummy.op6"
  // CHECK:      ^{{.*}}(%{{.*}}: i4
  // CHECK:        "dummy.op7"
  // clang-format on
}

/// Creates operations with type inference and tests various failure modes.
static int createOperationWithTypeInference(MlirContext ctx) {
  MlirLocation loc = mlirLocationUnknownGet(ctx);
  MlirAttribute iAttr = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 32), 4);

  // The shape.const_size op implements result type inference and is only used
  // for that reason.
  MlirOperationState state = mlirOperationStateGet(
      mlirStringRefCreateFromCString("shape.const_size"), loc);
  MlirNamedAttribute valueAttr = mlirNamedAttributeGet(
      mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")), iAttr);
  mlirOperationStateAddAttributes(&state, 1, &valueAttr);
  mlirOperationStateEnableResultTypeInference(&state);

  // Expect result type inference to succeed.
  MlirOperation op = mlirOperationCreate(&state);
  if (mlirOperationIsNull(op)) {
    fprintf(stderr, "ERROR: Result type inference unexpectedly failed");
    return 1;
  }

  // CHECK: RESULT_TYPE_INFERENCE: !shape.size
  fprintf(stderr, "RESULT_TYPE_INFERENCE: ");
  mlirTypeDump(mlirValueGetType(mlirOperationGetResult(op, 0)));
  fprintf(stderr, "\n");
  mlirOperationDestroy(op);
  return 0;
}

/// Dumps instances of all builtin types to check that C API works correctly.
/// Additionally, performs simple identity checks that a builtin type
/// constructed with C API can be inspected and has the expected type. The
/// latter achieves full coverage of C API for builtin types. Returns 0 on
/// success and a non-zero error code on failure.
static int printBuiltinTypes(MlirContext ctx) {
  // Integer types.
  MlirType i32 = mlirIntegerTypeGet(ctx, 32);
  MlirType si32 = mlirIntegerTypeSignedGet(ctx, 32);
  MlirType ui32 = mlirIntegerTypeUnsignedGet(ctx, 32);
  if (!mlirTypeIsAInteger(i32) || mlirTypeIsAF32(i32))
    return 1;
  if (!mlirTypeIsAInteger(si32) || !mlirIntegerTypeIsSigned(si32))
    return 2;
  if (!mlirTypeIsAInteger(ui32) || !mlirIntegerTypeIsUnsigned(ui32))
    return 3;
  if (mlirTypeEqual(i32, ui32) || mlirTypeEqual(i32, si32))
    return 4;
  if (mlirIntegerTypeGetWidth(i32) != mlirIntegerTypeGetWidth(si32))
    return 5;
  fprintf(stderr, "@types\n");
  mlirTypeDump(i32);
  fprintf(stderr, "\n");
  mlirTypeDump(si32);
  fprintf(stderr, "\n");
  mlirTypeDump(ui32);
  fprintf(stderr, "\n");
  // CHECK-LABEL: @types
  // CHECK: i32
  // CHECK: si32
  // CHECK: ui32

  // Index type.
  MlirType index = mlirIndexTypeGet(ctx);
  if (!mlirTypeIsAIndex(index))
    return 6;
  mlirTypeDump(index);
  fprintf(stderr, "\n");
  // CHECK: index

  // Floating-point types.
  MlirType bf16 = mlirBF16TypeGet(ctx);
  MlirType f16 = mlirF16TypeGet(ctx);
  MlirType f32 = mlirF32TypeGet(ctx);
  MlirType f64 = mlirF64TypeGet(ctx);
  if (!mlirTypeIsABF16(bf16))
    return 7;
  if (!mlirTypeIsAF16(f16))
    return 9;
  if (!mlirTypeIsAF32(f32))
    return 10;
  if (!mlirTypeIsAF64(f64))
    return 11;
  mlirTypeDump(bf16);
  fprintf(stderr, "\n");
  mlirTypeDump(f16);
  fprintf(stderr, "\n");
  mlirTypeDump(f32);
  fprintf(stderr, "\n");
  mlirTypeDump(f64);
  fprintf(stderr, "\n");
  // CHECK: bf16
  // CHECK: f16
  // CHECK: f32
  // CHECK: f64

  // None type.
  MlirType none = mlirNoneTypeGet(ctx);
  if (!mlirTypeIsANone(none))
    return 12;
  mlirTypeDump(none);
  fprintf(stderr, "\n");
  // CHECK: none

  // Complex type.
  MlirType cplx = mlirComplexTypeGet(f32);
  if (!mlirTypeIsAComplex(cplx) ||
      !mlirTypeEqual(mlirComplexTypeGetElementType(cplx), f32))
    return 13;
  mlirTypeDump(cplx);
  fprintf(stderr, "\n");
  // CHECK: complex<f32>

  // Vector (and Shaped) type. ShapedType is a common base class for vectors,
  // memrefs and tensors, one cannot create instances of this class so it is
  // tested on an instance of vector type.
  int64_t shape[] = {2, 3};
  MlirType vector =
      mlirVectorTypeGet(sizeof(shape) / sizeof(int64_t), shape, f32);
  if (!mlirTypeIsAVector(vector) || !mlirTypeIsAShaped(vector))
    return 14;
  if (!mlirTypeEqual(mlirShapedTypeGetElementType(vector), f32) ||
      !mlirShapedTypeHasRank(vector) || mlirShapedTypeGetRank(vector) != 2 ||
      mlirShapedTypeGetDimSize(vector, 0) != 2 ||
      mlirShapedTypeIsDynamicDim(vector, 0) ||
      mlirShapedTypeGetDimSize(vector, 1) != 3 ||
      !mlirShapedTypeHasStaticShape(vector))
    return 15;
  mlirTypeDump(vector);
  fprintf(stderr, "\n");
  // CHECK: vector<2x3xf32>

  // Ranked tensor type.
  MlirType rankedTensor = mlirRankedTensorTypeGet(
      sizeof(shape) / sizeof(int64_t), shape, f32, mlirAttributeGetNull());
  if (!mlirTypeIsATensor(rankedTensor) ||
      !mlirTypeIsARankedTensor(rankedTensor) ||
      !mlirAttributeIsNull(mlirRankedTensorTypeGetEncoding(rankedTensor)))
    return 16;
  mlirTypeDump(rankedTensor);
  fprintf(stderr, "\n");
  // CHECK: tensor<2x3xf32>

  // Unranked tensor type.
  MlirType unrankedTensor = mlirUnrankedTensorTypeGet(f32);
  if (!mlirTypeIsATensor(unrankedTensor) ||
      !mlirTypeIsAUnrankedTensor(unrankedTensor) ||
      mlirShapedTypeHasRank(unrankedTensor))
    return 17;
  mlirTypeDump(unrankedTensor);
  fprintf(stderr, "\n");
  // CHECK: tensor<*xf32>

  // MemRef type.
  MlirAttribute memSpace2 = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 64), 2);
  MlirType memRef = mlirMemRefTypeContiguousGet(
      f32, sizeof(shape) / sizeof(int64_t), shape, memSpace2);
  if (!mlirTypeIsAMemRef(memRef) ||
      !mlirAttributeEqual(mlirMemRefTypeGetMemorySpace(memRef), memSpace2))
    return 18;
  mlirTypeDump(memRef);
  fprintf(stderr, "\n");
  // CHECK: memref<2x3xf32, 2>

  // Unranked MemRef type.
  MlirAttribute memSpace4 = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 64), 4);
  MlirType unrankedMemRef = mlirUnrankedMemRefTypeGet(f32, memSpace4);
  if (!mlirTypeIsAUnrankedMemRef(unrankedMemRef) ||
      mlirTypeIsAMemRef(unrankedMemRef) ||
      !mlirAttributeEqual(mlirUnrankedMemrefGetMemorySpace(unrankedMemRef),
                          memSpace4))
    return 19;
  mlirTypeDump(unrankedMemRef);
  fprintf(stderr, "\n");
  // CHECK: memref<*xf32, 4>

  // Tuple type.
  MlirType types[] = {unrankedMemRef, f32};
  MlirType tuple = mlirTupleTypeGet(ctx, 2, types);
  if (!mlirTypeIsATuple(tuple) || mlirTupleTypeGetNumTypes(tuple) != 2 ||
      !mlirTypeEqual(mlirTupleTypeGetType(tuple, 0), unrankedMemRef) ||
      !mlirTypeEqual(mlirTupleTypeGetType(tuple, 1), f32))
    return 20;
  mlirTypeDump(tuple);
  fprintf(stderr, "\n");
  // CHECK: tuple<memref<*xf32, 4>, f32>

  // Function type.
  MlirType funcInputs[2] = {mlirIndexTypeGet(ctx), mlirIntegerTypeGet(ctx, 1)};
  MlirType funcResults[3] = {mlirIntegerTypeGet(ctx, 16),
                             mlirIntegerTypeGet(ctx, 32),
                             mlirIntegerTypeGet(ctx, 64)};
  MlirType funcType = mlirFunctionTypeGet(ctx, 2, funcInputs, 3, funcResults);
  if (mlirFunctionTypeGetNumInputs(funcType) != 2)
    return 21;
  if (mlirFunctionTypeGetNumResults(funcType) != 3)
    return 22;
  if (!mlirTypeEqual(funcInputs[0], mlirFunctionTypeGetInput(funcType, 0)) ||
      !mlirTypeEqual(funcInputs[1], mlirFunctionTypeGetInput(funcType, 1)))
    return 23;
  if (!mlirTypeEqual(funcResults[0], mlirFunctionTypeGetResult(funcType, 0)) ||
      !mlirTypeEqual(funcResults[1], mlirFunctionTypeGetResult(funcType, 1)) ||
      !mlirTypeEqual(funcResults[2], mlirFunctionTypeGetResult(funcType, 2)))
    return 24;
  mlirTypeDump(funcType);
  fprintf(stderr, "\n");
  // CHECK: (index, i1) -> (i16, i32, i64)

  return 0;
}

void callbackSetFixedLengthString(const char *data, intptr_t len,
                                  void *userData) {
  strncpy(userData, data, len);
}

bool stringIsEqual(const char *lhs, MlirStringRef rhs) {
  if (strlen(lhs) != rhs.length) {
    return false;
  }
  return !strncmp(lhs, rhs.data, rhs.length);
}

int printBuiltinAttributes(MlirContext ctx) {
  MlirAttribute floating =
      mlirFloatAttrDoubleGet(ctx, mlirF64TypeGet(ctx), 2.0);
  if (!mlirAttributeIsAFloat(floating) ||
      fabs(mlirFloatAttrGetValueDouble(floating) - 2.0) > 1E-6)
    return 1;
  fprintf(stderr, "@attrs\n");
  mlirAttributeDump(floating);
  // CHECK-LABEL: @attrs
  // CHECK: 2.000000e+00 : f64

  // Exercise mlirAttributeGetType() just for the first one.
  MlirType floatingType = mlirAttributeGetType(floating);
  mlirTypeDump(floatingType);
  // CHECK: f64

  MlirAttribute integer = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 32), 42);
  if (!mlirAttributeIsAInteger(integer) ||
      mlirIntegerAttrGetValueInt(integer) != 42)
    return 2;
  mlirAttributeDump(integer);
  // CHECK: 42 : i32

  MlirAttribute boolean = mlirBoolAttrGet(ctx, 1);
  if (!mlirAttributeIsABool(boolean) || !mlirBoolAttrGetValue(boolean))
    return 3;
  mlirAttributeDump(boolean);
  // CHECK: true

  const char data[] = "abcdefghijklmnopqestuvwxyz";
  MlirAttribute opaque =
      mlirOpaqueAttrGet(ctx, mlirStringRefCreateFromCString("std"), 3, data,
                        mlirNoneTypeGet(ctx));
  if (!mlirAttributeIsAOpaque(opaque) ||
      !stringIsEqual("std", mlirOpaqueAttrGetDialectNamespace(opaque)))
    return 4;

  MlirStringRef opaqueData = mlirOpaqueAttrGetData(opaque);
  if (opaqueData.length != 3 ||
      strncmp(data, opaqueData.data, opaqueData.length))
    return 5;
  mlirAttributeDump(opaque);
  // CHECK: #std.abc

  MlirAttribute string =
      mlirStringAttrGet(ctx, mlirStringRefCreate(data + 3, 2));
  if (!mlirAttributeIsAString(string))
    return 6;

  MlirStringRef stringValue = mlirStringAttrGetValue(string);
  if (stringValue.length != 2 ||
      strncmp(data + 3, stringValue.data, stringValue.length))
    return 7;
  mlirAttributeDump(string);
  // CHECK: "de"

  MlirAttribute flatSymbolRef =
      mlirFlatSymbolRefAttrGet(ctx, mlirStringRefCreate(data + 5, 3));
  if (!mlirAttributeIsAFlatSymbolRef(flatSymbolRef))
    return 8;

  MlirStringRef flatSymbolRefValue =
      mlirFlatSymbolRefAttrGetValue(flatSymbolRef);
  if (flatSymbolRefValue.length != 3 ||
      strncmp(data + 5, flatSymbolRefValue.data, flatSymbolRefValue.length))
    return 9;
  mlirAttributeDump(flatSymbolRef);
  // CHECK: @fgh

  MlirAttribute symbols[] = {flatSymbolRef, flatSymbolRef};
  MlirAttribute symbolRef =
      mlirSymbolRefAttrGet(ctx, mlirStringRefCreate(data + 8, 2), 2, symbols);
  if (!mlirAttributeIsASymbolRef(symbolRef) ||
      mlirSymbolRefAttrGetNumNestedReferences(symbolRef) != 2 ||
      !mlirAttributeEqual(mlirSymbolRefAttrGetNestedReference(symbolRef, 0),
                          flatSymbolRef) ||
      !mlirAttributeEqual(mlirSymbolRefAttrGetNestedReference(symbolRef, 1),
                          flatSymbolRef))
    return 10;

  MlirStringRef symbolRefLeaf = mlirSymbolRefAttrGetLeafReference(symbolRef);
  MlirStringRef symbolRefRoot = mlirSymbolRefAttrGetRootReference(symbolRef);
  if (symbolRefLeaf.length != 3 ||
      strncmp(data + 5, symbolRefLeaf.data, symbolRefLeaf.length) ||
      symbolRefRoot.length != 2 ||
      strncmp(data + 8, symbolRefRoot.data, symbolRefRoot.length))
    return 11;
  mlirAttributeDump(symbolRef);
  // CHECK: @ij::@fgh::@fgh

  MlirAttribute type = mlirTypeAttrGet(mlirF32TypeGet(ctx));
  if (!mlirAttributeIsAType(type) ||
      !mlirTypeEqual(mlirF32TypeGet(ctx), mlirTypeAttrGetValue(type)))
    return 12;
  mlirAttributeDump(type);
  // CHECK: f32

  MlirAttribute unit = mlirUnitAttrGet(ctx);
  if (!mlirAttributeIsAUnit(unit))
    return 13;
  mlirAttributeDump(unit);
  // CHECK: unit

  int64_t shape[] = {1, 2};

  int bools[] = {0, 1};
  uint8_t uints8[] = {0u, 1u};
  int8_t ints8[] = {0, 1};
  uint32_t uints32[] = {0u, 1u};
  int32_t ints32[] = {0, 1};
  uint64_t uints64[] = {0u, 1u};
  int64_t ints64[] = {0, 1};
  float floats[] = {0.0f, 1.0f};
  double doubles[] = {0.0, 1.0};
  MlirAttribute encoding = mlirAttributeGetNull();
  MlirAttribute boolElements = mlirDenseElementsAttrBoolGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1), encoding),
      2, bools);
  MlirAttribute uint8Elements = mlirDenseElementsAttrUInt8Get(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 8),
                              encoding),
      2, uints8);
  MlirAttribute int8Elements = mlirDenseElementsAttrInt8Get(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 8), encoding),
      2, ints8);
  MlirAttribute uint32Elements = mlirDenseElementsAttrUInt32Get(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 32),
                              encoding),
      2, uints32);
  MlirAttribute int32Elements = mlirDenseElementsAttrInt32Get(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32), encoding),
      2, ints32);
  MlirAttribute uint64Elements = mlirDenseElementsAttrUInt64Get(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 64),
                              encoding),
      2, uints64);
  MlirAttribute int64Elements = mlirDenseElementsAttrInt64Get(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64), encoding),
      2, ints64);
  MlirAttribute floatElements = mlirDenseElementsAttrFloatGet(
      mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), encoding), 2,
      floats);
  MlirAttribute doubleElements = mlirDenseElementsAttrDoubleGet(
      mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx), encoding), 2,
      doubles);

  if (!mlirAttributeIsADenseElements(boolElements) ||
      !mlirAttributeIsADenseElements(uint8Elements) ||
      !mlirAttributeIsADenseElements(int8Elements) ||
      !mlirAttributeIsADenseElements(uint32Elements) ||
      !mlirAttributeIsADenseElements(int32Elements) ||
      !mlirAttributeIsADenseElements(uint64Elements) ||
      !mlirAttributeIsADenseElements(int64Elements) ||
      !mlirAttributeIsADenseElements(floatElements) ||
      !mlirAttributeIsADenseElements(doubleElements))
    return 14;

  if (mlirDenseElementsAttrGetBoolValue(boolElements, 1) != 1 ||
      mlirDenseElementsAttrGetUInt8Value(uint8Elements, 1) != 1 ||
      mlirDenseElementsAttrGetInt8Value(int8Elements, 1) != 1 ||
      mlirDenseElementsAttrGetUInt32Value(uint32Elements, 1) != 1 ||
      mlirDenseElementsAttrGetInt32Value(int32Elements, 1) != 1 ||
      mlirDenseElementsAttrGetUInt64Value(uint64Elements, 1) != 1 ||
      mlirDenseElementsAttrGetInt64Value(int64Elements, 1) != 1 ||
      fabsf(mlirDenseElementsAttrGetFloatValue(floatElements, 1) - 1.0f) >
          1E-6f ||
      fabs(mlirDenseElementsAttrGetDoubleValue(doubleElements, 1) - 1.0) > 1E-6)
    return 15;

  mlirAttributeDump(boolElements);
  mlirAttributeDump(uint8Elements);
  mlirAttributeDump(int8Elements);
  mlirAttributeDump(uint32Elements);
  mlirAttributeDump(int32Elements);
  mlirAttributeDump(uint64Elements);
  mlirAttributeDump(int64Elements);
  mlirAttributeDump(floatElements);
  mlirAttributeDump(doubleElements);
  // CHECK: dense<{{\[}}[false, true]]> : tensor<1x2xi1>
  // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui8>
  // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi8>
  // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui32>
  // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi32>
  // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui64>
  // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi64>
  // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf32>
  // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf64>

  MlirAttribute splatBool = mlirDenseElementsAttrBoolSplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1), encoding),
      1);
  MlirAttribute splatUInt8 = mlirDenseElementsAttrUInt8SplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 8),
                              encoding),
      1);
  MlirAttribute splatInt8 = mlirDenseElementsAttrInt8SplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 8), encoding),
      1);
  MlirAttribute splatUInt32 = mlirDenseElementsAttrUInt32SplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 32),
                              encoding),
      1);
  MlirAttribute splatInt32 = mlirDenseElementsAttrInt32SplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32), encoding),
      1);
  MlirAttribute splatUInt64 = mlirDenseElementsAttrUInt64SplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 64),
                              encoding),
      1);
  MlirAttribute splatInt64 = mlirDenseElementsAttrInt64SplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64), encoding),
      1);
  MlirAttribute splatFloat = mlirDenseElementsAttrFloatSplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), encoding), 1.0f);
  MlirAttribute splatDouble = mlirDenseElementsAttrDoubleSplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx), encoding), 1.0);

  if (!mlirAttributeIsADenseElements(splatBool) ||
      !mlirDenseElementsAttrIsSplat(splatBool) ||
      !mlirAttributeIsADenseElements(splatUInt8) ||
      !mlirDenseElementsAttrIsSplat(splatUInt8) ||
      !mlirAttributeIsADenseElements(splatInt8) ||
      !mlirDenseElementsAttrIsSplat(splatInt8) ||
      !mlirAttributeIsADenseElements(splatUInt32) ||
      !mlirDenseElementsAttrIsSplat(splatUInt32) ||
      !mlirAttributeIsADenseElements(splatInt32) ||
      !mlirDenseElementsAttrIsSplat(splatInt32) ||
      !mlirAttributeIsADenseElements(splatUInt64) ||
      !mlirDenseElementsAttrIsSplat(splatUInt64) ||
      !mlirAttributeIsADenseElements(splatInt64) ||
      !mlirDenseElementsAttrIsSplat(splatInt64) ||
      !mlirAttributeIsADenseElements(splatFloat) ||
      !mlirDenseElementsAttrIsSplat(splatFloat) ||
      !mlirAttributeIsADenseElements(splatDouble) ||
      !mlirDenseElementsAttrIsSplat(splatDouble))
    return 16;

  if (mlirDenseElementsAttrGetBoolSplatValue(splatBool) != 1 ||
      mlirDenseElementsAttrGetUInt8SplatValue(splatUInt8) != 1 ||
      mlirDenseElementsAttrGetInt8SplatValue(splatInt8) != 1 ||
      mlirDenseElementsAttrGetUInt32SplatValue(splatUInt32) != 1 ||
      mlirDenseElementsAttrGetInt32SplatValue(splatInt32) != 1 ||
      mlirDenseElementsAttrGetUInt64SplatValue(splatUInt64) != 1 ||
      mlirDenseElementsAttrGetInt64SplatValue(splatInt64) != 1 ||
      fabsf(mlirDenseElementsAttrGetFloatSplatValue(splatFloat) - 1.0f) >
          1E-6f ||
      fabs(mlirDenseElementsAttrGetDoubleSplatValue(splatDouble) - 1.0) > 1E-6)
    return 17;

  uint8_t *uint8RawData =
      (uint8_t *)mlirDenseElementsAttrGetRawData(uint8Elements);
  int8_t *int8RawData = (int8_t *)mlirDenseElementsAttrGetRawData(int8Elements);
  uint32_t *uint32RawData =
      (uint32_t *)mlirDenseElementsAttrGetRawData(uint32Elements);
  int32_t *int32RawData =
      (int32_t *)mlirDenseElementsAttrGetRawData(int32Elements);
  uint64_t *uint64RawData =
      (uint64_t *)mlirDenseElementsAttrGetRawData(uint64Elements);
  int64_t *int64RawData =
      (int64_t *)mlirDenseElementsAttrGetRawData(int64Elements);
  float *floatRawData = (float *)mlirDenseElementsAttrGetRawData(floatElements);
  double *doubleRawData =
      (double *)mlirDenseElementsAttrGetRawData(doubleElements);
  if (uint8RawData[0] != 0u || uint8RawData[1] != 1u || int8RawData[0] != 0 ||
      int8RawData[1] != 1 || uint32RawData[0] != 0u || uint32RawData[1] != 1u ||
      int32RawData[0] != 0 || int32RawData[1] != 1 || uint64RawData[0] != 0u ||
      uint64RawData[1] != 1u || int64RawData[0] != 0 || int64RawData[1] != 1 ||
      floatRawData[0] != 0.0f || floatRawData[1] != 1.0f ||
      doubleRawData[0] != 0.0 || doubleRawData[1] != 1.0)
    return 18;

  mlirAttributeDump(splatBool);
  mlirAttributeDump(splatUInt8);
  mlirAttributeDump(splatInt8);
  mlirAttributeDump(splatUInt32);
  mlirAttributeDump(splatInt32);
  mlirAttributeDump(splatUInt64);
  mlirAttributeDump(splatInt64);
  mlirAttributeDump(splatFloat);
  mlirAttributeDump(splatDouble);
  // CHECK: dense<true> : tensor<1x2xi1>
  // CHECK: dense<1> : tensor<1x2xui8>
  // CHECK: dense<1> : tensor<1x2xi8>
  // CHECK: dense<1> : tensor<1x2xui32>
  // CHECK: dense<1> : tensor<1x2xi32>
  // CHECK: dense<1> : tensor<1x2xui64>
  // CHECK: dense<1> : tensor<1x2xi64>
  // CHECK: dense<1.000000e+00> : tensor<1x2xf32>
  // CHECK: dense<1.000000e+00> : tensor<1x2xf64>

  mlirAttributeDump(mlirElementsAttrGetValue(floatElements, 2, uints64));
  mlirAttributeDump(mlirElementsAttrGetValue(doubleElements, 2, uints64));
  // CHECK: 1.000000e+00 : f32
  // CHECK: 1.000000e+00 : f64

  int64_t indices[] = {0, 1};
  int64_t one = 1;
  MlirAttribute indicesAttr = mlirDenseElementsAttrInt64Get(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64), encoding),
      2, indices);
  MlirAttribute valuesAttr = mlirDenseElementsAttrFloatGet(
      mlirRankedTensorTypeGet(1, &one, mlirF32TypeGet(ctx), encoding), 1,
      floats);
  MlirAttribute sparseAttr = mlirSparseElementsAttribute(
      mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), encoding),
      indicesAttr, valuesAttr);
  mlirAttributeDump(sparseAttr);
  // CHECK: sparse<{{\[}}[0, 1]], 0.000000e+00> : tensor<1x2xf32>

  return 0;
}

int printAffineMap(MlirContext ctx) {
  MlirAffineMap emptyAffineMap = mlirAffineMapEmptyGet(ctx);
  MlirAffineMap affineMap = mlirAffineMapZeroResultGet(ctx, 3, 2);
  MlirAffineMap constAffineMap = mlirAffineMapConstantGet(ctx, 2);
  MlirAffineMap multiDimIdentityAffineMap =
      mlirAffineMapMultiDimIdentityGet(ctx, 3);
  MlirAffineMap minorIdentityAffineMap =
      mlirAffineMapMinorIdentityGet(ctx, 3, 2);
  unsigned permutation[] = {1, 2, 0};
  MlirAffineMap permutationAffineMap = mlirAffineMapPermutationGet(
      ctx, sizeof(permutation) / sizeof(unsigned), permutation);

  fprintf(stderr, "@affineMap\n");
  mlirAffineMapDump(emptyAffineMap);
  mlirAffineMapDump(affineMap);
  mlirAffineMapDump(constAffineMap);
  mlirAffineMapDump(multiDimIdentityAffineMap);
  mlirAffineMapDump(minorIdentityAffineMap);
  mlirAffineMapDump(permutationAffineMap);
  // CHECK-LABEL: @affineMap
  // CHECK: () -> ()
  // CHECK: (d0, d1, d2)[s0, s1] -> ()
  // CHECK: () -> (2)
  // CHECK: (d0, d1, d2) -> (d0, d1, d2)
  // CHECK: (d0, d1, d2) -> (d1, d2)
  // CHECK: (d0, d1, d2) -> (d1, d2, d0)

  if (!mlirAffineMapIsIdentity(emptyAffineMap) ||
      mlirAffineMapIsIdentity(affineMap) ||
      mlirAffineMapIsIdentity(constAffineMap) ||
      !mlirAffineMapIsIdentity(multiDimIdentityAffineMap) ||
      mlirAffineMapIsIdentity(minorIdentityAffineMap) ||
      mlirAffineMapIsIdentity(permutationAffineMap))
    return 1;

  if (!mlirAffineMapIsMinorIdentity(emptyAffineMap) ||
      mlirAffineMapIsMinorIdentity(affineMap) ||
      !mlirAffineMapIsMinorIdentity(multiDimIdentityAffineMap) ||
      !mlirAffineMapIsMinorIdentity(minorIdentityAffineMap) ||
      mlirAffineMapIsMinorIdentity(permutationAffineMap))
    return 2;

  if (!mlirAffineMapIsEmpty(emptyAffineMap) ||
      mlirAffineMapIsEmpty(affineMap) || mlirAffineMapIsEmpty(constAffineMap) ||
      mlirAffineMapIsEmpty(multiDimIdentityAffineMap) ||
      mlirAffineMapIsEmpty(minorIdentityAffineMap) ||
      mlirAffineMapIsEmpty(permutationAffineMap))
    return 3;

  if (mlirAffineMapIsSingleConstant(emptyAffineMap) ||
      mlirAffineMapIsSingleConstant(affineMap) ||
      !mlirAffineMapIsSingleConstant(constAffineMap) ||
      mlirAffineMapIsSingleConstant(multiDimIdentityAffineMap) ||
      mlirAffineMapIsSingleConstant(minorIdentityAffineMap) ||
      mlirAffineMapIsSingleConstant(permutationAffineMap))
    return 4;

  if (mlirAffineMapGetSingleConstantResult(constAffineMap) != 2)
    return 5;

  if (mlirAffineMapGetNumDims(emptyAffineMap) != 0 ||
      mlirAffineMapGetNumDims(affineMap) != 3 ||
      mlirAffineMapGetNumDims(constAffineMap) != 0 ||
      mlirAffineMapGetNumDims(multiDimIdentityAffineMap) != 3 ||
      mlirAffineMapGetNumDims(minorIdentityAffineMap) != 3 ||
      mlirAffineMapGetNumDims(permutationAffineMap) != 3)
    return 6;

  if (mlirAffineMapGetNumSymbols(emptyAffineMap) != 0 ||
      mlirAffineMapGetNumSymbols(affineMap) != 2 ||
      mlirAffineMapGetNumSymbols(constAffineMap) != 0 ||
      mlirAffineMapGetNumSymbols(multiDimIdentityAffineMap) != 0 ||
      mlirAffineMapGetNumSymbols(minorIdentityAffineMap) != 0 ||
      mlirAffineMapGetNumSymbols(permutationAffineMap) != 0)
    return 7;

  if (mlirAffineMapGetNumResults(emptyAffineMap) != 0 ||
      mlirAffineMapGetNumResults(affineMap) != 0 ||
      mlirAffineMapGetNumResults(constAffineMap) != 1 ||
      mlirAffineMapGetNumResults(multiDimIdentityAffineMap) != 3 ||
      mlirAffineMapGetNumResults(minorIdentityAffineMap) != 2 ||
      mlirAffineMapGetNumResults(permutationAffineMap) != 3)
    return 8;

  if (mlirAffineMapGetNumInputs(emptyAffineMap) != 0 ||
      mlirAffineMapGetNumInputs(affineMap) != 5 ||
      mlirAffineMapGetNumInputs(constAffineMap) != 0 ||
      mlirAffineMapGetNumInputs(multiDimIdentityAffineMap) != 3 ||
      mlirAffineMapGetNumInputs(minorIdentityAffineMap) != 3 ||
      mlirAffineMapGetNumInputs(permutationAffineMap) != 3)
    return 9;

  if (!mlirAffineMapIsProjectedPermutation(emptyAffineMap) ||
      !mlirAffineMapIsPermutation(emptyAffineMap) ||
      mlirAffineMapIsProjectedPermutation(affineMap) ||
      mlirAffineMapIsPermutation(affineMap) ||
      mlirAffineMapIsProjectedPermutation(constAffineMap) ||
      mlirAffineMapIsPermutation(constAffineMap) ||
      !mlirAffineMapIsProjectedPermutation(multiDimIdentityAffineMap) ||
      !mlirAffineMapIsPermutation(multiDimIdentityAffineMap) ||
      !mlirAffineMapIsProjectedPermutation(minorIdentityAffineMap) ||
      mlirAffineMapIsPermutation(minorIdentityAffineMap) ||
      !mlirAffineMapIsProjectedPermutation(permutationAffineMap) ||
      !mlirAffineMapIsPermutation(permutationAffineMap))
    return 10;

  intptr_t sub[] = {1};

  MlirAffineMap subMap = mlirAffineMapGetSubMap(
      multiDimIdentityAffineMap, sizeof(sub) / sizeof(intptr_t), sub);
  MlirAffineMap majorSubMap =
      mlirAffineMapGetMajorSubMap(multiDimIdentityAffineMap, 1);
  MlirAffineMap minorSubMap =
      mlirAffineMapGetMinorSubMap(multiDimIdentityAffineMap, 1);

  mlirAffineMapDump(subMap);
  mlirAffineMapDump(majorSubMap);
  mlirAffineMapDump(minorSubMap);
  // CHECK: (d0, d1, d2) -> (d1)
  // CHECK: (d0, d1, d2) -> (d0)
  // CHECK: (d0, d1, d2) -> (d2)

  return 0;
}

int printAffineExpr(MlirContext ctx) {
  MlirAffineExpr affineDimExpr = mlirAffineDimExprGet(ctx, 5);
  MlirAffineExpr affineSymbolExpr = mlirAffineSymbolExprGet(ctx, 5);
  MlirAffineExpr affineConstantExpr = mlirAffineConstantExprGet(ctx, 5);
  MlirAffineExpr affineAddExpr =
      mlirAffineAddExprGet(affineDimExpr, affineSymbolExpr);
  MlirAffineExpr affineMulExpr =
      mlirAffineMulExprGet(affineDimExpr, affineSymbolExpr);
  MlirAffineExpr affineModExpr =
      mlirAffineModExprGet(affineDimExpr, affineSymbolExpr);
  MlirAffineExpr affineFloorDivExpr =
      mlirAffineFloorDivExprGet(affineDimExpr, affineSymbolExpr);
  MlirAffineExpr affineCeilDivExpr =
      mlirAffineCeilDivExprGet(affineDimExpr, affineSymbolExpr);

  // Tests mlirAffineExprDump.
  fprintf(stderr, "@affineExpr\n");
  mlirAffineExprDump(affineDimExpr);
  mlirAffineExprDump(affineSymbolExpr);
  mlirAffineExprDump(affineConstantExpr);
  mlirAffineExprDump(affineAddExpr);
  mlirAffineExprDump(affineMulExpr);
  mlirAffineExprDump(affineModExpr);
  mlirAffineExprDump(affineFloorDivExpr);
  mlirAffineExprDump(affineCeilDivExpr);
  // CHECK-LABEL: @affineExpr
  // CHECK: d5
  // CHECK: s5
  // CHECK: 5
  // CHECK: d5 + s5
  // CHECK: d5 * s5
  // CHECK: d5 mod s5
  // CHECK: d5 floordiv s5
  // CHECK: d5 ceildiv s5

  // Tests methods of affine binary operation expression, takes add expression
  // as an example.
  mlirAffineExprDump(mlirAffineBinaryOpExprGetLHS(affineAddExpr));
  mlirAffineExprDump(mlirAffineBinaryOpExprGetRHS(affineAddExpr));
  // CHECK: d5
  // CHECK: s5

  // Tests methods of affine dimension expression.
  if (mlirAffineDimExprGetPosition(affineDimExpr) != 5)
    return 1;

  // Tests methods of affine symbol expression.
  if (mlirAffineSymbolExprGetPosition(affineSymbolExpr) != 5)
    return 2;

  // Tests methods of affine constant expression.
  if (mlirAffineConstantExprGetValue(affineConstantExpr) != 5)
    return 3;

  // Tests methods of affine expression.
  if (mlirAffineExprIsSymbolicOrConstant(affineDimExpr) ||
      !mlirAffineExprIsSymbolicOrConstant(affineSymbolExpr) ||
      !mlirAffineExprIsSymbolicOrConstant(affineConstantExpr) ||
      mlirAffineExprIsSymbolicOrConstant(affineAddExpr) ||
      mlirAffineExprIsSymbolicOrConstant(affineMulExpr) ||
      mlirAffineExprIsSymbolicOrConstant(affineModExpr) ||
      mlirAffineExprIsSymbolicOrConstant(affineFloorDivExpr) ||
      mlirAffineExprIsSymbolicOrConstant(affineCeilDivExpr))
    return 4;

  if (!mlirAffineExprIsPureAffine(affineDimExpr) ||
      !mlirAffineExprIsPureAffine(affineSymbolExpr) ||
      !mlirAffineExprIsPureAffine(affineConstantExpr) ||
      !mlirAffineExprIsPureAffine(affineAddExpr) ||
      mlirAffineExprIsPureAffine(affineMulExpr) ||
      mlirAffineExprIsPureAffine(affineModExpr) ||
      mlirAffineExprIsPureAffine(affineFloorDivExpr) ||
      mlirAffineExprIsPureAffine(affineCeilDivExpr))
    return 5;

  if (mlirAffineExprGetLargestKnownDivisor(affineDimExpr) != 1 ||
      mlirAffineExprGetLargestKnownDivisor(affineSymbolExpr) != 1 ||
      mlirAffineExprGetLargestKnownDivisor(affineConstantExpr) != 5 ||
      mlirAffineExprGetLargestKnownDivisor(affineAddExpr) != 1 ||
      mlirAffineExprGetLargestKnownDivisor(affineMulExpr) != 1 ||
      mlirAffineExprGetLargestKnownDivisor(affineModExpr) != 1 ||
      mlirAffineExprGetLargestKnownDivisor(affineFloorDivExpr) != 1 ||
      mlirAffineExprGetLargestKnownDivisor(affineCeilDivExpr) != 1)
    return 6;

  if (!mlirAffineExprIsMultipleOf(affineDimExpr, 1) ||
      !mlirAffineExprIsMultipleOf(affineSymbolExpr, 1) ||
      !mlirAffineExprIsMultipleOf(affineConstantExpr, 5) ||
      !mlirAffineExprIsMultipleOf(affineAddExpr, 1) ||
      !mlirAffineExprIsMultipleOf(affineMulExpr, 1) ||
      !mlirAffineExprIsMultipleOf(affineModExpr, 1) ||
      !mlirAffineExprIsMultipleOf(affineFloorDivExpr, 1) ||
      !mlirAffineExprIsMultipleOf(affineCeilDivExpr, 1))
    return 7;

  if (!mlirAffineExprIsFunctionOfDim(affineDimExpr, 5) ||
      mlirAffineExprIsFunctionOfDim(affineSymbolExpr, 5) ||
      mlirAffineExprIsFunctionOfDim(affineConstantExpr, 5) ||
      !mlirAffineExprIsFunctionOfDim(affineAddExpr, 5) ||
      !mlirAffineExprIsFunctionOfDim(affineMulExpr, 5) ||
      !mlirAffineExprIsFunctionOfDim(affineModExpr, 5) ||
      !mlirAffineExprIsFunctionOfDim(affineFloorDivExpr, 5) ||
      !mlirAffineExprIsFunctionOfDim(affineCeilDivExpr, 5))
    return 8;

  // Tests 'IsA' methods of affine binary operation expression.
  if (!mlirAffineExprIsAAdd(affineAddExpr))
    return 9;

  if (!mlirAffineExprIsAMul(affineMulExpr))
    return 10;

  if (!mlirAffineExprIsAMod(affineModExpr))
    return 11;

  if (!mlirAffineExprIsAFloorDiv(affineFloorDivExpr))
    return 12;

  if (!mlirAffineExprIsACeilDiv(affineCeilDivExpr))
    return 13;

  if (!mlirAffineExprIsABinary(affineAddExpr))
    return 14;

  // Test other 'IsA' method on affine expressions.
  if (!mlirAffineExprIsAConstant(affineConstantExpr))
    return 15;

  if (!mlirAffineExprIsADim(affineDimExpr))
    return 16;

  if (!mlirAffineExprIsASymbol(affineSymbolExpr))
    return 17;

  // Test equality and nullity.
  MlirAffineExpr otherDimExpr = mlirAffineDimExprGet(ctx, 5);
  if (!mlirAffineExprEqual(affineDimExpr, otherDimExpr))
    return 18;

  if (mlirAffineExprIsNull(affineDimExpr))
    return 19;

  return 0;
}

int affineMapFromExprs(MlirContext ctx) {
  MlirAffineExpr affineDimExpr = mlirAffineDimExprGet(ctx, 0);
  MlirAffineExpr affineSymbolExpr = mlirAffineSymbolExprGet(ctx, 1);
  MlirAffineExpr exprs[] = {affineDimExpr, affineSymbolExpr};
  MlirAffineMap map = mlirAffineMapGet(ctx, 3, 3, 2, exprs);

  // CHECK-LABEL: @affineMapFromExprs
  fprintf(stderr, "@affineMapFromExprs");
  // CHECK: (d0, d1, d2)[s0, s1, s2] -> (d0, s1)
  mlirAffineMapDump(map);

  if (mlirAffineMapGetNumResults(map) != 2)
    return 1;

  if (!mlirAffineExprEqual(mlirAffineMapGetResult(map, 0), affineDimExpr))
    return 2;

  if (!mlirAffineExprEqual(mlirAffineMapGetResult(map, 1), affineSymbolExpr))
    return 3;

  MlirAffineExpr affineDim2Expr = mlirAffineDimExprGet(ctx, 1);
  MlirAffineExpr composed = mlirAffineExprCompose(affineDim2Expr, map);
  // CHECK: s1
  mlirAffineExprDump(composed);
  if (!mlirAffineExprEqual(composed, affineSymbolExpr))
    return 4;

  return 0;
}

int printIntegerSet(MlirContext ctx) {
  MlirIntegerSet emptySet = mlirIntegerSetEmptyGet(ctx, 2, 1);

  // CHECK-LABEL: @printIntegerSet
  fprintf(stderr, "@printIntegerSet");

  // CHECK: (d0, d1)[s0] : (1 == 0)
  mlirIntegerSetDump(emptySet);

  if (!mlirIntegerSetIsCanonicalEmpty(emptySet))
    return 1;

  MlirIntegerSet anotherEmptySet = mlirIntegerSetEmptyGet(ctx, 2, 1);
  if (!mlirIntegerSetEqual(emptySet, anotherEmptySet))
    return 2;

  // Construct a set constrained by:
  //   d0 - s0 == 0,
  //   d1 - 42 >= 0.
  MlirAffineExpr negOne = mlirAffineConstantExprGet(ctx, -1);
  MlirAffineExpr negFortyTwo = mlirAffineConstantExprGet(ctx, -42);
  MlirAffineExpr d0 = mlirAffineDimExprGet(ctx, 0);
  MlirAffineExpr d1 = mlirAffineDimExprGet(ctx, 1);
  MlirAffineExpr s0 = mlirAffineSymbolExprGet(ctx, 0);
  MlirAffineExpr negS0 = mlirAffineMulExprGet(negOne, s0);
  MlirAffineExpr d0minusS0 = mlirAffineAddExprGet(d0, negS0);
  MlirAffineExpr d1minus42 = mlirAffineAddExprGet(d1, negFortyTwo);
  MlirAffineExpr constraints[] = {d0minusS0, d1minus42};
  bool flags[] = {true, false};

  MlirIntegerSet set = mlirIntegerSetGet(ctx, 2, 1, 2, constraints, flags);
  // CHECK: (d0, d1)[s0] : (
  // CHECK-DAG: d0 - s0 == 0
  // CHECK-DAG: d1 - 42 >= 0
  mlirIntegerSetDump(set);

  // Transform d1 into s0.
  MlirAffineExpr s1 = mlirAffineSymbolExprGet(ctx, 1);
  MlirAffineExpr repl[] = {d0, s1};
  MlirIntegerSet replaced = mlirIntegerSetReplaceGet(set, repl, &s0, 1, 2);
  // CHECK: (d0)[s0, s1] : (
  // CHECK-DAG: d0 - s0 == 0
  // CHECK-DAG: s1 - 42 >= 0
  mlirIntegerSetDump(replaced);

  if (mlirIntegerSetGetNumDims(set) != 2)
    return 3;
  if (mlirIntegerSetGetNumDims(replaced) != 1)
    return 4;

  if (mlirIntegerSetGetNumSymbols(set) != 1)
    return 5;
  if (mlirIntegerSetGetNumSymbols(replaced) != 2)
    return 6;

  if (mlirIntegerSetGetNumInputs(set) != 3)
    return 7;

  if (mlirIntegerSetGetNumConstraints(set) != 2)
    return 8;

  if (mlirIntegerSetGetNumEqualities(set) != 1)
    return 9;

  if (mlirIntegerSetGetNumInequalities(set) != 1)
    return 10;

  MlirAffineExpr cstr1 = mlirIntegerSetGetConstraint(set, 0);
  MlirAffineExpr cstr2 = mlirIntegerSetGetConstraint(set, 1);
  bool isEq1 = mlirIntegerSetIsConstraintEq(set, 0);
  bool isEq2 = mlirIntegerSetIsConstraintEq(set, 1);
  if (!mlirAffineExprEqual(cstr1, isEq1 ? d0minusS0 : d1minus42))
    return 11;
  if (!mlirAffineExprEqual(cstr2, isEq2 ? d0minusS0 : d1minus42))
    return 12;

  return 0;
}

int registerOnlyStd() {
  MlirContext ctx = mlirContextCreate();
  // The built-in dialect is always loaded.
  if (mlirContextGetNumLoadedDialects(ctx) != 1)
    return 1;

  MlirDialectHandle stdHandle = mlirGetDialectHandle__std__();

  MlirDialect std = mlirContextGetOrLoadDialect(
      ctx, mlirDialectHandleGetNamespace(stdHandle));
  if (!mlirDialectIsNull(std))
    return 2;

  mlirDialectHandleRegisterDialect(stdHandle, ctx);

  std = mlirContextGetOrLoadDialect(ctx,
                                    mlirDialectHandleGetNamespace(stdHandle));
  if (mlirDialectIsNull(std))
    return 3;

  MlirDialect alsoStd = mlirDialectHandleLoadDialect(stdHandle, ctx);
  if (!mlirDialectEqual(std, alsoStd))
    return 4;

  MlirStringRef stdNs = mlirDialectGetNamespace(std);
  MlirStringRef alsoStdNs = mlirDialectHandleGetNamespace(stdHandle);
  if (stdNs.length != alsoStdNs.length ||
      strncmp(stdNs.data, alsoStdNs.data, stdNs.length))
    return 5;

  fprintf(stderr, "@registration\n");
  // CHECK-LABEL: @registration

  // CHECK: std.cond_br is_registered: 1
  fprintf(stderr, "std.cond_br is_registered: %d\n",
          mlirContextIsRegisteredOperation(
              ctx, mlirStringRefCreateFromCString("std.cond_br")));

  // CHECK: std.not_existing_op is_registered: 0
  fprintf(stderr, "std.not_existing_op is_registered: %d\n",
          mlirContextIsRegisteredOperation(
              ctx, mlirStringRefCreateFromCString("std.not_existing_op")));

  // CHECK: not_existing_dialect.not_existing_op is_registered: 0
  fprintf(stderr, "not_existing_dialect.not_existing_op is_registered: %d\n",
          mlirContextIsRegisteredOperation(
              ctx, mlirStringRefCreateFromCString(
                       "not_existing_dialect.not_existing_op")));

  mlirContextDestroy(ctx);
  return 0;
}

/// Tests backreference APIs
static int testBackreferences() {
  fprintf(stderr, "@test_backreferences\n");

  MlirContext ctx = mlirContextCreate();
  mlirContextSetAllowUnregisteredDialects(ctx, true);
  MlirLocation loc = mlirLocationUnknownGet(ctx);

  MlirOperationState opState =
      mlirOperationStateGet(mlirStringRefCreateFromCString("invalid.op"), loc);
  MlirRegion region = mlirRegionCreate();
  MlirBlock block = mlirBlockCreate(0, NULL);
  mlirRegionAppendOwnedBlock(region, block);
  mlirOperationStateAddOwnedRegions(&opState, 1, &region);
  MlirOperation op = mlirOperationCreate(&opState);
  MlirIdentifier ident =
      mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("identifier"));

  if (!mlirContextEqual(ctx, mlirOperationGetContext(op))) {
    fprintf(stderr, "ERROR: Getting context from operation failed\n");
    return 1;
  }
  if (!mlirOperationEqual(op, mlirBlockGetParentOperation(block))) {
    fprintf(stderr, "ERROR: Getting parent operation from block failed\n");
    return 2;
  }
  if (!mlirContextEqual(ctx, mlirIdentifierGetContext(ident))) {
    fprintf(stderr, "ERROR: Getting context from identifier failed\n");
    return 3;
  }

  mlirOperationDestroy(op);
  mlirContextDestroy(ctx);

  // CHECK-LABEL: @test_backreferences
  return 0;
}

/// Tests operand APIs.
int testOperands() {
  fprintf(stderr, "@testOperands\n");
  // CHECK-LABEL: @testOperands

  MlirContext ctx = mlirContextCreate();
  mlirRegisterAllDialects(ctx);
  mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("test"));
  MlirLocation loc = mlirLocationUnknownGet(ctx);
  MlirType indexType = mlirIndexTypeGet(ctx);

  // Create some constants to use as operands.
  MlirAttribute indexZeroLiteral =
      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
  MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
      mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
      indexZeroLiteral);
  MlirOperationState constZeroState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("arith.constant"), loc);
  mlirOperationStateAddResults(&constZeroState, 1, &indexType);
  mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
  MlirOperation constZero = mlirOperationCreate(&constZeroState);
  MlirValue constZeroValue = mlirOperationGetResult(constZero, 0);

  MlirAttribute indexOneLiteral =
      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("1 : index"));
  MlirNamedAttribute indexOneValueAttr = mlirNamedAttributeGet(
      mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
      indexOneLiteral);
  MlirOperationState constOneState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("arith.constant"), loc);
  mlirOperationStateAddResults(&constOneState, 1, &indexType);
  mlirOperationStateAddAttributes(&constOneState, 1, &indexOneValueAttr);
  MlirOperation constOne = mlirOperationCreate(&constOneState);
  MlirValue constOneValue = mlirOperationGetResult(constOne, 0);

  // Create the operation under test.
  mlirContextSetAllowUnregisteredDialects(ctx, true);
  MlirOperationState opState =
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op"), loc);
  MlirValue initialOperands[] = {constZeroValue};
  mlirOperationStateAddOperands(&opState, 1, initialOperands);
  MlirOperation op = mlirOperationCreate(&opState);

  // Test operand APIs.
  intptr_t numOperands = mlirOperationGetNumOperands(op);
  fprintf(stderr, "Num Operands: %" PRIdPTR "\n", numOperands);
  // CHECK: Num Operands: 1

  MlirValue opOperand = mlirOperationGetOperand(op, 0);
  fprintf(stderr, "Original operand: ");
  mlirValuePrint(opOperand, printToStderr, NULL);
  // CHECK: Original operand: {{.+}} arith.constant 0 : index

  mlirOperationSetOperand(op, 0, constOneValue);
  opOperand = mlirOperationGetOperand(op, 0);
  fprintf(stderr, "Updated operand: ");
  mlirValuePrint(opOperand, printToStderr, NULL);
  // CHECK: Updated operand: {{.+}} arith.constant 1 : index

  mlirOperationDestroy(op);
  mlirOperationDestroy(constZero);
  mlirOperationDestroy(constOne);
  mlirContextDestroy(ctx);

  return 0;
}

/// Tests clone APIs.
int testClone() {
  fprintf(stderr, "@testClone\n");
  // CHECK-LABEL: @testClone

  MlirContext ctx = mlirContextCreate();
  mlirRegisterAllDialects(ctx);
  mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("std"));
  MlirLocation loc = mlirLocationUnknownGet(ctx);
  MlirType indexType = mlirIndexTypeGet(ctx);
  MlirStringRef valueStringRef = mlirStringRefCreateFromCString("value");

  MlirAttribute indexZeroLiteral =
      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
  MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
      mlirIdentifierGet(ctx, valueStringRef), indexZeroLiteral);
  MlirOperationState constZeroState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("arith.constant"), loc);
  mlirOperationStateAddResults(&constZeroState, 1, &indexType);
  mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
  MlirOperation constZero = mlirOperationCreate(&constZeroState);

  MlirAttribute indexOneLiteral =
      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("1 : index"));
  MlirOperation constOne = mlirOperationClone(constZero);
  mlirOperationSetAttributeByName(constOne, valueStringRef, indexOneLiteral);

  mlirOperationPrint(constZero, printToStderr, NULL);
  mlirOperationPrint(constOne, printToStderr, NULL);
  // CHECK: arith.constant 0 : index
  // CHECK: arith.constant 1 : index

  mlirOperationDestroy(constZero);
  mlirOperationDestroy(constOne);
  mlirContextDestroy(ctx);
  return 0;
}

// Wraps a diagnostic into additional text we can match against.
MlirLogicalResult errorHandler(MlirDiagnostic diagnostic, void *userData) {
  fprintf(stderr, "processing diagnostic (userData: %" PRIdPTR ") <<\n",
          (intptr_t)userData);
  mlirDiagnosticPrint(diagnostic, printToStderr, NULL);
  fprintf(stderr, "\n");
  MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
  mlirLocationPrint(loc, printToStderr, NULL);
  assert(mlirDiagnosticGetNumNotes(diagnostic) == 0);
  fprintf(stderr, "\n>> end of diagnostic (userData: %" PRIdPTR ")\n",
          (intptr_t)userData);
  return mlirLogicalResultSuccess();
}

// Logs when the delete user data callback is called
static void deleteUserData(void *userData) {
  fprintf(stderr, "deleting user data (userData: %" PRIdPTR ")\n",
          (intptr_t)userData);
}

int testTypeID(MlirContext ctx) {
  fprintf(stderr, "@testTypeID\n");

  // Test getting and comparing type and attribute type ids.
  MlirType i32 = mlirIntegerTypeGet(ctx, 32);
  MlirTypeID i32ID = mlirTypeGetTypeID(i32);
  MlirType ui32 = mlirIntegerTypeUnsignedGet(ctx, 32);
  MlirTypeID ui32ID = mlirTypeGetTypeID(ui32);
  MlirType f32 = mlirF32TypeGet(ctx);
  MlirTypeID f32ID = mlirTypeGetTypeID(f32);
  MlirAttribute i32Attr = mlirIntegerAttrGet(i32, 1);
  MlirTypeID i32AttrID = mlirAttributeGetTypeID(i32Attr);

  if (mlirTypeIDIsNull(i32ID) || mlirTypeIDIsNull(ui32ID) ||
      mlirTypeIDIsNull(f32ID) || mlirTypeIDIsNull(i32AttrID)) {
    fprintf(stderr, "ERROR: Expected type ids to be present\n");
    return 1;
  }

  if (!mlirTypeIDEqual(i32ID, ui32ID) ||
      mlirTypeIDHashValue(i32ID) != mlirTypeIDHashValue(ui32ID)) {
    fprintf(
        stderr,
        "ERROR: Expected different integer types to have the same type id\n");
    return 2;
  }

  if (mlirTypeIDEqual(i32ID, f32ID) ||
      mlirTypeIDHashValue(i32ID) == mlirTypeIDHashValue(f32ID)) {
    fprintf(stderr,
            "ERROR: Expected integer type id to not equal float type id\n");
    return 3;
  }

  if (mlirTypeIDEqual(i32ID, i32AttrID) ||
      mlirTypeIDHashValue(i32ID) == mlirTypeIDHashValue(i32AttrID)) {
    fprintf(stderr, "ERROR: Expected integer type id to not equal integer "
                    "attribute type id\n");
    return 4;
  }

  MlirLocation loc = mlirLocationUnknownGet(ctx);
  MlirType indexType = mlirIndexTypeGet(ctx);
  MlirStringRef valueStringRef = mlirStringRefCreateFromCString("value");

  // Create a registered operation, which should have a type id.
  MlirAttribute indexZeroLiteral =
      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
  MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
      mlirIdentifierGet(ctx, valueStringRef), indexZeroLiteral);
  MlirOperationState constZeroState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("arith.constant"), loc);
  mlirOperationStateAddResults(&constZeroState, 1, &indexType);
  mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
  MlirOperation constZero = mlirOperationCreate(&constZeroState);

  if (!mlirOperationVerify(constZero)) {
    fprintf(stderr, "ERROR: Expected operation to verify correctly\n");
    return 5;
  }

  if (mlirOperationIsNull(constZero)) {
    fprintf(stderr, "ERROR: Expected registered operation to be present\n");
    return 6;
  }

  MlirTypeID registeredOpID = mlirOperationGetTypeID(constZero);

  if (mlirTypeIDIsNull(registeredOpID)) {
    fprintf(stderr,
            "ERROR: Expected registered operation type id to be present\n");
    return 7;
  }

  // Create an unregistered operation, which should not have a type id.
  mlirContextSetAllowUnregisteredDialects(ctx, true);
  MlirOperationState opState =
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op"), loc);
  MlirOperation unregisteredOp = mlirOperationCreate(&opState);
  if (mlirOperationIsNull(unregisteredOp)) {
    fprintf(stderr, "ERROR: Expected unregistered operation to be present\n");
    return 8;
  }

  MlirTypeID unregisteredOpID = mlirOperationGetTypeID(unregisteredOp);

  if (!mlirTypeIDIsNull(unregisteredOpID)) {
    fprintf(stderr,
            "ERROR: Expected unregistered operation type id to be null\n");
    return 9;
  }

  mlirOperationDestroy(constZero);
  mlirOperationDestroy(unregisteredOp);

  return 0;
}

int testSymbolTable(MlirContext ctx) {
  fprintf(stderr, "@testSymbolTable\n");

  const char *moduleString = "func private @foo()"
                             "func private @bar()";
  const char *otherModuleString = "func private @qux()"
                                  "func private @foo()";

  MlirModule module =
      mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
  MlirModule otherModule = mlirModuleCreateParse(
      ctx, mlirStringRefCreateFromCString(otherModuleString));

  MlirSymbolTable symbolTable =
      mlirSymbolTableCreate(mlirModuleGetOperation(module));

  MlirOperation funcFoo =
      mlirSymbolTableLookup(symbolTable, mlirStringRefCreateFromCString("foo"));
  if (mlirOperationIsNull(funcFoo))
    return 1;

  MlirOperation funcBar =
      mlirSymbolTableLookup(symbolTable, mlirStringRefCreateFromCString("bar"));
  if (mlirOperationEqual(funcFoo, funcBar))
    return 2;

  MlirOperation missing =
      mlirSymbolTableLookup(symbolTable, mlirStringRefCreateFromCString("qux"));
  if (!mlirOperationIsNull(missing))
    return 3;

  MlirBlock moduleBody = mlirModuleGetBody(module);
  MlirBlock otherModuleBody = mlirModuleGetBody(otherModule);
  MlirOperation operation = mlirBlockGetFirstOperation(otherModuleBody);
  mlirOperationRemoveFromParent(operation);
  mlirBlockAppendOwnedOperation(moduleBody, operation);

  // At this moment, the operation is still missing from the symbol table.
  MlirOperation stillMissing =
      mlirSymbolTableLookup(symbolTable, mlirStringRefCreateFromCString("qux"));
  if (!mlirOperationIsNull(stillMissing))
    return 4;

  // After it is added to the symbol table, and not only the operation with
  // which the table is associated, it can be looked up.
  mlirSymbolTableInsert(symbolTable, operation);
  MlirOperation funcQux =
      mlirSymbolTableLookup(symbolTable, mlirStringRefCreateFromCString("qux"));
  if (!mlirOperationEqual(operation, funcQux))
    return 5;

  // Erasing from the symbol table also removes the operation.
  mlirSymbolTableErase(symbolTable, funcBar);
  MlirOperation nowMissing =
      mlirSymbolTableLookup(symbolTable, mlirStringRefCreateFromCString("bar"));
  if (!mlirOperationIsNull(nowMissing))
    return 6;

  // Adding a symbol with the same name to the table should rename.
  MlirOperation duplicateNameOp = mlirBlockGetFirstOperation(otherModuleBody);
  mlirOperationRemoveFromParent(duplicateNameOp);
  mlirBlockAppendOwnedOperation(moduleBody, duplicateNameOp);
  MlirAttribute newName = mlirSymbolTableInsert(symbolTable, duplicateNameOp);
  MlirStringRef newNameStr = mlirStringAttrGetValue(newName);
  if (mlirStringRefEqual(newNameStr, mlirStringRefCreateFromCString("foo")))
    return 7;
  MlirAttribute updatedName = mlirOperationGetAttributeByName(
      duplicateNameOp, mlirSymbolTableGetSymbolAttributeName());
  if (!mlirAttributeEqual(updatedName, newName))
    return 8;

  mlirOperationDump(mlirModuleGetOperation(module));
  mlirOperationDump(mlirModuleGetOperation(otherModule));
  // clang-format off
  // CHECK-LABEL: @testSymbolTable
  // CHECK: module
  // CHECK:   func private @foo
  // CHECK:   func private @qux
  // CHECK:   func private @foo{{.+}}
  // CHECK: module
  // CHECK-NOT: @qux
  // CHECK-NOT: @foo
  // clang-format on

  mlirSymbolTableDestroy(symbolTable);
  mlirModuleDestroy(module);
  mlirModuleDestroy(otherModule);

  return 0;
}

void testDiagnostics() {
  MlirContext ctx = mlirContextCreate();
  MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler(
      ctx, errorHandler, (void *)42, deleteUserData);
  fprintf(stderr, "@test_diagnostics\n");
  MlirLocation unknownLoc = mlirLocationUnknownGet(ctx);
  mlirEmitError(unknownLoc, "test diagnostics");
  MlirLocation fileLineColLoc = mlirLocationFileLineColGet(
      ctx, mlirStringRefCreateFromCString("file.c"), 1, 2);
  mlirEmitError(fileLineColLoc, "test diagnostics");
  MlirLocation callSiteLoc = mlirLocationCallSiteGet(
      mlirLocationFileLineColGet(
          ctx, mlirStringRefCreateFromCString("other-file.c"), 2, 3),
      fileLineColLoc);
  mlirEmitError(callSiteLoc, "test diagnostics");
  MlirLocation null = {0};
  MlirLocation nameLoc =
      mlirLocationNameGet(ctx, mlirStringRefCreateFromCString("named"), null);
  mlirEmitError(nameLoc, "test diagnostics");
  MlirLocation locs[2] = {nameLoc, callSiteLoc};
  MlirAttribute nullAttr = {0};
  MlirLocation fusedLoc = mlirLocationFusedGet(ctx, 2, locs, nullAttr);
  mlirEmitError(fusedLoc, "test diagnostics");
  mlirContextDetachDiagnosticHandler(ctx, id);
  mlirEmitError(unknownLoc, "more test diagnostics");
  // CHECK-LABEL: @test_diagnostics
  // CHECK: processing diagnostic (userData: 42) <<
  // CHECK:   test diagnostics
  // CHECK:   loc(unknown)
  // CHECK: >> end of diagnostic (userData: 42)
  // CHECK: processing diagnostic (userData: 42) <<
  // CHECK:   test diagnostics
  // CHECK:   loc("file.c":1:2)
  // CHECK: >> end of diagnostic (userData: 42)
  // CHECK: processing diagnostic (userData: 42) <<
  // CHECK:   test diagnostics
  // CHECK:   loc(callsite("other-file.c":2:3 at "file.c":1:2))
  // CHECK: >> end of diagnostic (userData: 42)
  // CHECK: processing diagnostic (userData: 42) <<
  // CHECK:   test diagnostics
  // CHECK:   loc("named")
  // CHECK: >> end of diagnostic (userData: 42)
  // CHECK: processing diagnostic (userData: 42) <<
  // CHECK:   test diagnostics
  // CHECK:   loc(fused["named", callsite("other-file.c":2:3 at "file.c":1:2)])
  // CHECK: deleting user data (userData: 42)
  // CHECK-NOT: processing diagnostic
  // CHECK:     more test diagnostics
  mlirContextDestroy(ctx);
}

int main() {
  MlirContext ctx = mlirContextCreate();
  mlirRegisterAllDialects(ctx);
  if (constructAndTraverseIr(ctx))
    return 1;
  buildWithInsertionsAndPrint(ctx);
  if (createOperationWithTypeInference(ctx))
    return 2;

  if (printBuiltinTypes(ctx))
    return 3;
  if (printBuiltinAttributes(ctx))
    return 4;
  if (printAffineMap(ctx))
    return 5;
  if (printAffineExpr(ctx))
    return 6;
  if (affineMapFromExprs(ctx))
    return 7;
  if (printIntegerSet(ctx))
    return 8;
  if (registerOnlyStd())
    return 9;
  if (testBackreferences())
    return 10;
  if (testOperands())
    return 11;
  if (testClone())
    return 12;
  if (testTypeID(ctx))
    return 13;
  if (testSymbolTable(ctx))
    return 14;

  mlirContextDestroy(ctx);

  testDiagnostics();
  return 0;
}
