//===- MLIRServer.cpp - MLIR Generic Language Server ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "MLIRServer.h"
#include "../lsp-server-support/Logging.h"
#include "../lsp-server-support/Protocol.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/IR/Operation.h"
#include "mlir/Parser/AsmParserState.h"
#include "mlir/Parser/Parser.h"
#include "llvm/Support/SourceMgr.h"

using namespace mlir;

/// Returns a language server location from the given MLIR file location, or
/// None (after logging an error) if a URI cannot be created for its filename.
/// The line and column are shifted down by one to form the LSP position; a
/// value of 0 is kept at 0 so that the subtraction can never wrap around.
static Optional<lsp::Location> getLocationFromLoc(FileLineColLoc loc) {
  llvm::Expected<lsp::URIForFile> sourceURI =
      lsp::URIForFile::fromFile(loc.getFilename());
  if (!sourceURI) {
    lsp::Logger::error("Failed to create URI for file `{0}`: {1}",
                       loc.getFilename(),
                       llvm::toString(sourceURI.takeError()));
    return llvm::None;
  }

  // Guard both components: decrementing a zero line would otherwise produce a
  // wrapped value and therefore an invalid (negative) LSP line number.
  lsp::Position position;
  position.line = loc.getLine() ? loc.getLine() - 1 : 0;
  position.character = loc.getColumn() ? loc.getColumn() - 1 : 0;
  return lsp::Location{*sourceURI, lsp::Range(position)};
}

/// Walks `loc` and returns the language server location of the first nested
/// file location that can be converted, or None if there is no such location.
/// When `uri` is non-null, nested locations whose uri differs from `*uri` are
/// skipped.
static Optional<lsp::Location>
getLocationFromLoc(llvm::SourceMgr &sourceMgr, Location loc,
                   const lsp::URIForFile *uri = nullptr) {
  Optional<lsp::Location> result;
  loc->walk([&](Location subLoc) {
    auto fileLineCol = subLoc.dyn_cast<FileLineColLoc>();
    if (!fileLineCol)
      return WalkResult::advance();

    // Skip locations that fail to convert or that live in a filtered-out file.
    Optional<lsp::Location> candidate = getLocationFromLoc(fileLineCol);
    if (!candidate)
      return WalkResult::advance();
    if (uri && !(candidate->uri == *uri))
      return WalkResult::advance();

    result = *candidate;
    SMLoc startLoc = sourceMgr.FindLocForLineAndColumn(
        sourceMgr.getMainFileID(), fileLineCol.getLine(),
        fileLineCol.getColumn());

    // Default to a range of length 1; widen it to the end of the identifier
    // starting at the location when one can be found there.
    result->range.end.character += 1;
    if (Optional<SMRange> idRange =
            AsmParserState::convertIdLocToRange(startLoc)) {
      auto endLineCol = sourceMgr.getLineAndColumn(idRange->End);
      result->range.end.character =
          std::max(fileLineCol.getColumn() + 1, endLineCol.second - 1);
    }

    // The first acceptable location wins.
    return WalkResult::interrupt();
  });
  return result;
}

/// Appends to `locations` every file location nested within `loc` whose uri
/// differs from `uri`. Each distinct nested location is considered only once.
static void collectLocationsFromLoc(Location loc,
                                    std::vector<lsp::Location> &locations,
                                    const lsp::URIForFile &uri) {
  SetVector<Location> seen;
  loc->walk([&](Location subLoc) {
    auto fileLineCol = subLoc.dyn_cast<FileLineColLoc>();

    // Only file locations that have not been visited yet are of interest.
    if (fileLineCol && seen.insert(subLoc)) {
      Optional<lsp::Location> converted = getLocationFromLoc(fileLineCol);
      if (converted && converted->uri != uri)
        locations.push_back(*converted);
    }
    return WalkResult::advance();
  });
}

/// Returns true if `loc` lies within `range`. Unlike SMRange itself, the end
/// of the range is treated as inclusive here.
static bool contains(SMRange range, SMLoc loc) {
  const char *ptr = loc.getPointer();
  if (ptr < range.Start.getPointer())
    return false;
  return ptr <= range.End.getPointer();
}

/// Returns true if `loc` falls within the definition range of `def` or within
/// one of its use ranges. When `overlappedRange` is non-null and a match is
/// found, it receives the matching range. The definition is checked before
/// the uses.
static bool isDefOrUse(const AsmParserState::SMDefinition &def, SMLoc loc,
                       SMRange *overlappedRange = nullptr) {
  // Records the matching range for the caller (if requested) and reports
  // success.
  auto reportMatch = [&](SMRange match) {
    if (overlappedRange)
      *overlappedRange = match;
    return true;
  };

  if (contains(def.loc, loc))
    return reportMatch(def.loc);

  for (const SMRange &use : def.uses)
    if (contains(use, loc))
      return reportMatch(use);
  return false;
}

/// Given a location pointing to a result, return the result number it refers
/// to or None if it refers to all of the results. The number is the decimal
/// value following a `#` that directly trails the identifier at `loc`.
static Optional<unsigned> getResultNumberFromLoc(SMLoc loc) {
  // Skip all of the identifier characters. `llvm::isAlnum` is used rather than
  // the C library `isalnum`, because passing a negative `char` (e.g. a byte of
  // a non-ASCII character in the source text) to the latter is undefined.
  auto isIdentifierChar = [](char c) {
    return llvm::isAlnum(c) || c == '%' || c == '$' || c == '.' || c == '_' ||
           c == '-';
  };
  const char *curPtr = loc.getPointer();
  while (isIdentifierChar(*curPtr))
    ++curPtr;

  // Check to see if this location indexes into the result group, via `#`. If it
  // doesn't, we can't extract a sub result number.
  if (*curPtr != '#')
    return llvm::None;

  // Compute the sub result number from the remaining portion of the string.
  const char *numberStart = ++curPtr;
  while (llvm::isDigit(*curPtr))
    ++curPtr;
  StringRef numberStr(numberStart, curPtr - numberStart);

  // A true result from `consumeInteger` is mapped to None (e.g. when no digits
  // follow the `#`).
  unsigned resultNumber = 0;
  return numberStr.consumeInteger(10, resultNumber) ? Optional<unsigned>()
                                                    : resultNumber;
}

/// Returns the text spanned by `range` (from its start pointer up to its end
/// pointer), or None when the range is not valid.
static Optional<StringRef> getTextFromRange(SMRange range) {
  if (range.isValid()) {
    const char *begin = range.Start.getPointer();
    const char *end = range.End.getPointer();
    return StringRef(begin, end - begin);
  }
  return None;
}

/// Returns the index of `block` within the block list of its parent region.
static unsigned getBlockNumber(Block *block) {
  Region *parentRegion = block->getParent();
  return std::distance(parentRegion->begin(), block->getIterator());
}

/// Prints a name for `block` to `os`. If the text covered by `loc` starts with
/// `^`, that text is printed as the name; otherwise a placeholder containing
/// the block's number within its region is printed.
static void printDefBlockName(raw_ostream &os, Block *block, SMRange loc = {}) {
  Optional<StringRef> sourceName = getTextFromRange(loc);
  const bool hasSourceName = sourceName && sourceName->startswith("^");

  // No usable name in the source: fall back to the block number.
  if (!hasSourceName) {
    os << "<Block #" << getBlockNumber(block) << ">";
    return;
  }
  os << *sourceName;
}
/// Prints the name of the block described by `def`, using the source range of
/// its definition to look for a name.
static void printDefBlockName(raw_ostream &os,
                              const AsmParserState::BlockDefinition &def) {
  SMRange nameRange = def.definition.loc;
  printDefBlockName(os, def.block, nameRange);
}

/// Builds the LSP form of the MLIR diagnostic `diag`. The range is taken from
/// the first location within `diag` that belongs to `uri` (if any), and every
/// note attached to `diag` is added as related information.
static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr,
                                               Diagnostic &diag,
                                               const lsp::URIForFile &uri) {
  lsp::Diagnostic result;
  result.source = "mlir";

  // Note: every diagnostic is currently labelled as a parser issue, even
  // though some come from the parser and some from the verifier.
  result.category = "Parse Error";

  // Use the first file location found for this diagnostic.
  // TODO: For simplicity, we just grab the first one. It may be likely that we
  // will need a more interesting heuristic here.
  if (Optional<lsp::Location> primaryLoc =
          getLocationFromLoc(sourceMgr, diag.getLocation(), &uri))
    result.range = primaryLoc->range;

  // Translate the severity. Notes never reach this point as a top-level
  // diagnostic.
  switch (diag.getSeverity()) {
  case DiagnosticSeverity::Note:
    llvm_unreachable("expected notes to be handled separately");
  case DiagnosticSeverity::Error:
    result.severity = lsp::DiagnosticSeverity::Error;
    break;
  case DiagnosticSeverity::Warning:
    result.severity = lsp::DiagnosticSeverity::Warning;
    break;
  case DiagnosticSeverity::Remark:
    result.severity = lsp::DiagnosticSeverity::Information;
    break;
  }
  result.message = diag.str();

  // Convert each note into related information. A note without a usable
  // location is attached to `uri` with a default-constructed range.
  std::vector<lsp::DiagnosticRelatedInformation> related;
  for (Diagnostic &note : diag.getNotes()) {
    lsp::Location noteLoc;
    Optional<lsp::Location> resolved =
        getLocationFromLoc(sourceMgr, note.getLocation());
    if (resolved)
      noteLoc = *resolved;
    else
      noteLoc.uri = uri;
    related.emplace_back(noteLoc, note.str());
  }
  if (!related.empty())
    result.relatedInformation = std::move(related);

  return result;
}

//===----------------------------------------------------------------------===//
// MLIRDocument
//===----------------------------------------------------------------------===//

namespace {
/// This class represents all of the information pertaining to a specific MLIR
/// document: the parsed IR, the parser state describing where entities were
/// defined and used in the source text, and the source manager owning that
/// text. Instances are not copyable.
struct MLIRDocument {
  /// Parses `contents` as the MLIR source of the file `uri`. Diagnostics
  /// emitted on `context` during construction are converted to their LSP form
  /// and appended to `diagnostics`.
  MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri,
               StringRef contents, std::vector<lsp::Diagnostic> &diagnostics);
  MLIRDocument(const MLIRDocument &) = delete;
  MLIRDocument &operator=(const MLIRDocument &) = delete;

  //===--------------------------------------------------------------------===//
  // Definitions and References
  //===--------------------------------------------------------------------===//

  /// Appends to `locations` the definition locations of the entity found at
  /// `defPos`. Nothing is appended if no known entity is at that position.
  void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos,
                      std::vector<lsp::Location> &locations);
  /// Appends to `references` the definition and use ranges of the entity found
  /// at `pos`. Nothing is appended if no known entity is at that position.
  void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos,
                        std::vector<lsp::Location> &references);

  //===--------------------------------------------------------------------===//
  // Hover
  //===--------------------------------------------------------------------===//

  /// Returns the hover for the operation, result, block, or block argument
  /// found at `hoverPos`, or None if nothing known is at that position.
  Optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
                                 const lsp::Position &hoverPos);
  /// Builds the hover shown for the operation `op`, covering `hoverRange`.
  Optional<lsp::Hover>
  buildHoverForOperation(SMRange hoverRange,
                         const AsmParserState::OperationDefinition &op);
  /// Builds the hover shown for the results `[resultStart, resultEnd)` of
  /// `op`. `posLoc` is the hovered source location, used to narrow the hover
  /// down to a single result when it points at one.
  lsp::Hover buildHoverForOperationResult(SMRange hoverRange, Operation *op,
                                          unsigned resultStart,
                                          unsigned resultEnd, SMLoc posLoc);
  /// Builds the hover shown for the block described by `block`.
  lsp::Hover buildHoverForBlock(SMRange hoverRange,
                                const AsmParserState::BlockDefinition &block);
  /// Builds the hover shown for the argument `arg` of the block described by
  /// `block`.
  lsp::Hover
  buildHoverForBlockArgument(SMRange hoverRange, BlockArgument arg,
                             const AsmParserState::BlockDefinition &block);

  //===--------------------------------------------------------------------===//
  // Document Symbols
  //===--------------------------------------------------------------------===//

  /// Appends the document symbols of every top-level parsed operation to
  /// `symbols`.
  void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
  /// Appends the document symbols of `op`, and of the operations nested within
  /// its regions, to `symbols`.
  void findDocumentSymbols(Operation *op,
                           std::vector<lsp::DocumentSymbol> &symbols);

  //===--------------------------------------------------------------------===//
  // Fields
  //===--------------------------------------------------------------------===//

  /// The high level parser state used to find definitions and references within
  /// the source file.
  AsmParserState asmState;

  /// The container for the IR parsed from the input file.
  Block parsedIR;

  /// The source manager containing the contents of the input file.
  llvm::SourceMgr sourceMgr;
};
} // namespace

/// Parses `contents` as the MLIR source of the file `uri`, populating the
/// parsed IR and the parser state of this document. Diagnostics emitted on
/// `context` while the constructor runs are converted to their LSP form and
/// appended to `diagnostics`. If parsing fails, the document is left with no
/// IR and a fresh parser state.
MLIRDocument::MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri,
                           StringRef contents,
                           std::vector<lsp::Diagnostic> &diagnostics) {
  // Forward every diagnostic emitted during construction to the caller.
  ScopedDiagnosticHandler handler(&context, [&](Diagnostic &diag) {
    diagnostics.push_back(getLspDiagnoticFromDiag(sourceMgr, diag, uri));
  });

  // Try to parse the given IR string.
  auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file());
  if (!memBuffer) {
    // The format string needs a `{0}` placeholder for the file name to
    // actually appear in the log message.
    lsp::Logger::error("Failed to create memory buffer for file `{0}`",
                       uri.file());
    return;
  }

  sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
  if (failed(parseSourceFile(sourceMgr, &parsedIR, &context, nullptr,
                             &asmState))) {
    // If parsing failed, clear out any of the current state.
    parsedIR.clear();
    asmState = AsmParserState();
  }
}

//===----------------------------------------------------------------------===//
// MLIRDocument: Definitions and References
//===----------------------------------------------------------------------===//

/// Appends to `locations` the definition locations for whatever is found at
/// `defPos`: an operation (via its name, a result, or a symbol use), a block,
/// or a block argument. For operations, the file locations attached to the
/// operation that lie outside of `uri` are appended as well.
void MLIRDocument::getLocationsOf(const lsp::URIForFile &uri,
                                  const lsp::Position &defPos,
                                  std::vector<lsp::Location> &locations) {
  SMLoc cursor = defPos.getAsSMLoc(sourceMgr);

  // Records the definition range of `def` when the cursor touches the
  // definition or one of its uses; returns whether it did.
  auto recordIfHit = [&](const AsmParserState::SMDefinition &def) {
    const bool hit = isDefOrUse(def, cursor);
    if (hit)
      locations.emplace_back(uri, sourceMgr, def.loc);
    return hit;
  };

  // Operations: the operation name, its result groups, then its symbol uses.
  for (const AsmParserState::OperationDefinition &opDef :
       asmState.getOpDefs()) {
    if (contains(opDef.loc, cursor)) {
      collectLocationsFromLoc(opDef.op->getLoc(), locations, uri);
      return;
    }

    for (const auto &group : opDef.resultGroups) {
      if (!recordIfHit(group.definition))
        continue;
      collectLocationsFromLoc(opDef.op->getLoc(), locations, uri);
      return;
    }

    for (const auto &symUse : opDef.symbolUses) {
      if (!contains(symUse, cursor))
        continue;
      locations.emplace_back(uri, sourceMgr, opDef.loc);
      collectLocationsFromLoc(opDef.op->getLoc(), locations, uri);
      return;
    }
  }

  // Blocks: the block itself, then each of its arguments.
  for (const AsmParserState::BlockDefinition &blockDef :
       asmState.getBlockDefs()) {
    if (recordIfHit(blockDef.definition))
      return;
    for (const AsmParserState::SMDefinition &argDef : blockDef.arguments) {
      if (recordIfHit(argDef))
        return;
    }
  }
}

/// Appends to `references` the ranges that refer to whatever is found at
/// `pos`: an operation (via its name, a result, or a symbol use), a block, or
/// a block argument.
void MLIRDocument::findReferencesOf(const lsp::URIForFile &uri,
                                    const lsp::Position &pos,
                                    std::vector<lsp::Location> &references) {
  SMLoc cursor = pos.getAsSMLoc(sourceMgr);

  // Appends the definition range of `def` followed by all of its use ranges.
  auto appendDefAndUses = [&](const AsmParserState::SMDefinition &def) {
    references.emplace_back(uri, sourceMgr, def.loc);
    for (const SMRange &useRange : def.uses)
      references.emplace_back(uri, sourceMgr, useRange);
  };

  // Operations.
  for (const AsmParserState::OperationDefinition &opDef :
       asmState.getOpDefs()) {
    // Cursor on the operation itself: report every result group, plus any
    // symbol use that is under the cursor.
    if (contains(opDef.loc, cursor)) {
      for (const auto &group : opDef.resultGroups)
        appendDefAndUses(group.definition);
      for (const auto &symUse : opDef.symbolUses) {
        if (contains(symUse, cursor))
          references.emplace_back(uri, sourceMgr, symUse);
      }
      return;
    }

    // Cursor on a result definition or use: report that result group only.
    for (const auto &group : opDef.resultGroups) {
      if (isDefOrUse(group.definition, cursor)) {
        appendDefAndUses(group.definition);
        return;
      }
    }

    // Cursor on a symbol use: report all symbol uses of this operation.
    bool onSymbolUse = false;
    for (const auto &symUse : opDef.symbolUses) {
      if (contains(symUse, cursor)) {
        onSymbolUse = true;
        break;
      }
    }
    if (!onSymbolUse)
      continue;
    for (const auto &symUse : opDef.symbolUses)
      references.emplace_back(uri, sourceMgr, symUse);
    return;
  }

  // Blocks: the block itself, then each of its arguments.
  for (const AsmParserState::BlockDefinition &blockDef :
       asmState.getBlockDefs()) {
    if (isDefOrUse(blockDef.definition, cursor)) {
      appendDefAndUses(blockDef.definition);
      return;
    }

    for (const AsmParserState::SMDefinition &argDef : blockDef.arguments) {
      if (isDefOrUse(argDef, cursor)) {
        appendDefAndUses(argDef);
        return;
      }
    }
  }
}

//===----------------------------------------------------------------------===//
// MLIRDocument: Hover
//===----------------------------------------------------------------------===//

/// Returns the hover for whatever is found at `hoverPos`, or None if nothing
/// known is there. Operations (name, symbol uses, results) are checked before
/// blocks and block arguments.
Optional<lsp::Hover> MLIRDocument::findHover(const lsp::URIForFile &uri,
                                             const lsp::Position &hoverPos) {
  SMLoc cursor = hoverPos.getAsSMLoc(sourceMgr);

  // Receives the definition/use range that the cursor overlapped with.
  SMRange matchedRange;

  for (const AsmParserState::OperationDefinition &opDef :
       asmState.getOpDefs()) {
    // Cursor on the operation itself.
    if (contains(opDef.loc, cursor))
      return buildHoverForOperation(opDef.loc, opDef);

    // Cursor on one of the symbol uses of the operation.
    for (auto &symUse : opDef.symbolUses) {
      if (contains(symUse, cursor))
        return buildHoverForOperation(symUse, opDef);
    }

    // Cursor on the definition or a use of one of the result groups.
    for (unsigned i = 0, e = opDef.resultGroups.size(); i < e; ++i) {
      const auto &group = opDef.resultGroups[i];
      if (isDefOrUse(group.definition, cursor, &matchedRange)) {
        // A group extends up to the start of the next group, or up to the
        // last result of the operation when it is the final group.
        const bool isLastGroup = (i + 1 == e);
        unsigned firstResult = group.startIndex;
        unsigned endResult = isLastGroup
                                 ? opDef.op->getNumResults()
                                 : opDef.resultGroups[i + 1].startIndex;
        return buildHoverForOperationResult(matchedRange, opDef.op,
                                            firstResult, endResult, cursor);
      }
    }
  }

  for (const AsmParserState::BlockDefinition &blockDef :
       asmState.getBlockDefs()) {
    // Cursor on the block itself.
    if (isDefOrUse(blockDef.definition, cursor, &matchedRange))
      return buildHoverForBlock(matchedRange, blockDef);

    // Cursor on one of the block arguments.
    for (const auto &indexedArg : llvm::enumerate(blockDef.arguments)) {
      if (isDefOrUse(indexedArg.value(), cursor, &matchedRange))
        return buildHoverForBlockArgument(
            matchedRange, blockDef.block->getArgument(indexedArg.index()),
            blockDef);
    }
  }
  return llvm::None;
}

/// Builds the hover for the operation `op`, covering `hoverRange`. The hover
/// contains the operation name (followed by the visibility and symbol name
/// when the operation implements SymbolOpInterface) and the generic form of
/// the operation printed without the bodies of its regions.
Optional<lsp::Hover> MLIRDocument::buildHoverForOperation(
    SMRange hoverRange, const AsmParserState::OperationDefinition &op) {
  lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
  llvm::raw_string_ostream os(hover.contents.value);

  // Add the operation name to the hover.
  os << "\"" << op.op->getName() << "\"";
  if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op.op))
    os << " : " << symbol.getVisibility() << " @" << symbol.getName() << "";
  os << "\n\n";

  os << "Generic Form:\n\n```mlir\n";

  // Temporarily drop the regions of this operation so that they don't get
  // printed in the output. This helps keep the size of the output hover
  // small. The i-th entry of `regions` holds the body of the i-th region.
  SmallVector<std::unique_ptr<Region>> regions;
  for (Region &region : op.op->getRegions()) {
    regions.emplace_back(std::make_unique<Region>());
    regions.back()->takeBody(region);
  }

  op.op->print(
      os, OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs());
  os << "\n```\n";

  // Move the regions back to the current operation. Each region must get its
  // own saved body back: always taking from the last saved entry would hand
  // the wrong body to the first region and leave every other region empty.
  auto savedIt = regions.begin();
  for (Region &region : op.op->getRegions()) {
    region.takeBody(**savedIt);
    ++savedIt;
  }

  return hover;
}

/// Builds the hover for the results `[resultStart, resultEnd)` of `op`,
/// covering `hoverRange`. If `posLoc` points at a specific result within the
/// group (via a `#<number>` suffix) and that result is within the group, the
/// hover is narrowed down to that single result.
lsp::Hover MLIRDocument::buildHoverForOperationResult(SMRange hoverRange,
                                                      Operation *op,
                                                      unsigned resultStart,
                                                      unsigned resultEnd,
                                                      SMLoc posLoc) {
  lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
  llvm::raw_string_ostream os(hover.contents.value);

  // Add the parent operation name to the hover.
  os << "Operation: \"" << op->getName() << "\"\n\n";

  // Check to see if the location points to a specific result within the
  // group.
  if (Optional<unsigned> resultNumber = getResultNumberFromLoc(posLoc)) {
    if ((resultStart + *resultNumber) < resultEnd) {
      resultStart += *resultNumber;
      resultEnd = resultStart + 1;
    }
  }

  // Add the range of results and their types to the hover info.
  if ((resultStart + 1) == resultEnd) {
    os << "Result #" << resultStart << "\n\n"
       << "Type: `" << op->getResult(resultStart).getType() << "`\n\n";
  } else {
    os << "Result #[" << resultStart << ", " << (resultEnd - 1) << "]\n\n"
       << "Types: ";
    // `slice` takes a start index and a number of elements (not an end
    // index), so the size of the group has to be passed here.
    llvm::interleaveComma(
        op->getResults().slice(resultStart, resultEnd - resultStart), os,
        [&](Value result) { os << "`" << result.getType() << "`"; });
  }

  return hover;
}

/// Builds the hover for the block described by `block`, covering `hoverRange`.
/// The hover lists the parent operation, the block number, and the names of
/// the predecessor and successor blocks (when there are any).
lsp::Hover
MLIRDocument::buildHoverForBlock(SMRange hoverRange,
                                 const AsmParserState::BlockDefinition &block) {
  lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
  llvm::raw_string_ostream os(hover.contents.value);
  Block *hovered = block.block;

  // Prints the name of a neighbouring block, using its source definition when
  // the parser state has one.
  auto printNeighbor = [&](Block *other) {
    const auto *otherDef = asmState.getBlockDef(other);
    if (!otherDef) {
      printDefBlockName(os, other);
      return;
    }
    printDefBlockName(os, *otherDef);
  };

  // Parent operation and block number.
  os << "Operation: \"" << hovered->getParentOp()->getName() << "\"\n\n";
  os << "Block #" << getBlockNumber(hovered) << "\n\n";

  // Predecessors, if any.
  if (!hovered->hasNoPredecessors()) {
    os << "Predecessors: ";
    llvm::interleaveComma(hovered->getPredecessors(), os, printNeighbor);
    os << "\n\n";
  }

  // Successors, if any.
  if (!hovered->hasNoSuccessors()) {
    os << "Successors: ";
    llvm::interleaveComma(hovered->getSuccessors(), os, printNeighbor);
    os << "\n\n";
  }

  return hover;
}

/// Builds the hover for the block argument `arg`, covering `hoverRange`. The
/// hover lists the parent operation, the name of the owning block, the
/// argument number, and the argument type.
lsp::Hover MLIRDocument::buildHoverForBlockArgument(
    SMRange hoverRange, BlockArgument arg,
    const AsmParserState::BlockDefinition &block) {
  lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
  llvm::raw_string_ostream os(hover.contents.value);

  // Parent operation.
  os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n";

  // Owning block.
  os << "Block: ";
  printDefBlockName(os, block);

  // Argument number and type.
  os << "\n\nArgument #" << arg.getArgNumber() << "\n\n";
  os << "Type: `" << arg.getType() << "`\n\n";

  return hover;
}

//===----------------------------------------------------------------------===//
// MLIRDocument: Document Symbols
//===----------------------------------------------------------------------===//

/// Appends the document symbols of every operation held directly by the
/// parsed IR block to `symbols`.
void MLIRDocument::findDocumentSymbols(
    std::vector<lsp::DocumentSymbol> &symbols) {
  for (auto opIt = parsedIR.begin(), opEnd = parsedIR.end(); opIt != opEnd;
       ++opIt)
    findDocumentSymbols(&*opIt, symbols);
}

/// Appends the document symbols of `op` and of the operations nested within
/// its regions to `symbols`. An operation with known source information adds
/// an entry of its own when it defines a symbol or is a symbol table; symbols
/// of nested operations then become children of that entry.
void MLIRDocument::findDocumentSymbols(
    Operation *op, std::vector<lsp::DocumentSymbol> &symbols) {
  // Where symbols of nested operations go: `symbols` itself, unless this
  // operation contributes an entry of its own.
  std::vector<lsp::DocumentSymbol> *nestedSymbols = &symbols;

  if (const AsmParserState::OperationDefinition *opDef =
          asmState.getOpDef(op)) {
    // Adds an entry for this operation and redirects nested symbols to it.
    auto addEntry = [&](const Twine &name, lsp::SymbolKind kind) {
      symbols.emplace_back(name, kind,
                           lsp::Range(sourceMgr, opDef->scopeLoc),
                           lsp::Range(sourceMgr, opDef->loc));
      nestedSymbols = &symbols.back().children;
    };

    if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op)) {
      // Symbol definitions are recorded under their own name.
      lsp::SymbolKind kind = isa<FunctionOpInterface>(op)
                                 ? lsp::SymbolKind::Function
                                 : lsp::SymbolKind::Class;
      addEntry(symbol.getName(), kind);
    } else if (op->hasTrait<OpTrait::SymbolTable>()) {
      // Symbol tables that are not symbols get an anonymous entry.
      addEntry("<" + op->getName().getStringRef() + ">",
               lsp::SymbolKind::Namespace);
    }
  }

  // Recurse into the regions of this operation (a no-op without regions).
  for (Region &region : op->getRegions()) {
    for (Operation &nestedOp : region.getOps())
      findDocumentSymbols(&nestedOp, *nestedSymbols);
  }
}

//===----------------------------------------------------------------------===//
// MLIRTextFileChunk
//===----------------------------------------------------------------------===//

namespace {
/// A single chunk of an MLIR text file: one parsed document together with the
/// line offset of that chunk from the beginning of the file.
struct MLIRTextFileChunk {
  MLIRTextFileChunk(MLIRContext &context, uint64_t lineOffset,
                    const lsp::URIForFile &uri, StringRef contents,
                    std::vector<lsp::Diagnostic> &diagnostics)
      : lineOffset(lineOffset), document(context, uri, contents, diagnostics) {}

  /// Shifts the line of `pos` by the line offset of this chunk, so that it is
  /// anchored at the beginning of the file instead of the beginning of the
  /// chunk.
  void adjustLocForChunkOffset(lsp::Position &pos) { pos.line += lineOffset; }

  /// Shifts the lines of both endpoints of `range` by the line offset of this
  /// chunk, so that the range is anchored at the beginning of the file.
  void adjustLocForChunkOffset(lsp::Range &range) {
    range.start.line += lineOffset;
    range.end.line += lineOffset;
  }

  /// The line offset of this chunk from the beginning of the file.
  uint64_t lineOffset;
  /// The document referred to by this chunk.
  MLIRDocument document;
};
} // namespace

//===----------------------------------------------------------------------===//
// MLIRTextFile
//===----------------------------------------------------------------------===//

namespace {
/// This class represents a text file containing one or more MLIR documents.
/// The file contents are split into chunks, each of which is parsed as its own
/// document; queries are routed to the chunk containing the queried position.
class MLIRTextFile {
public:
  /// Splits `fileContents` into chunks and parses each of them. Diagnostics
  /// produced while parsing are appended to `diagnostics`.
  MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
               int64_t version, DialectRegistry &registry,
               std::vector<lsp::Diagnostic> &diagnostics);

  /// Return the current version of this text file.
  int64_t getVersion() const { return version; }

  //===--------------------------------------------------------------------===//
  // LSP Queries
  //===--------------------------------------------------------------------===//

  // Note: the query positions below are taken by value because they are
  // re-anchored to the start of the containing chunk before being forwarded.

  /// Appends to `locations` the definition locations of the entity at
  /// `defPos`.
  void getLocationsOf(const lsp::URIForFile &uri, lsp::Position defPos,
                      std::vector<lsp::Location> &locations);
  /// Appends to `references` the references of the entity at `pos`.
  void findReferencesOf(const lsp::URIForFile &uri, lsp::Position pos,
                        std::vector<lsp::Location> &references);
  /// Returns the hover for the entity at `hoverPos`, if there is one.
  Optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
                                 lsp::Position hoverPos);
  /// Appends the document symbols of this file to `symbols`.
  void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);

private:
  /// Find the MLIR document that contains the given position, and update the
  /// position to be anchored at the start of the found chunk instead of the
  /// beginning of the file.
  MLIRTextFileChunk &getChunkFor(lsp::Position &pos);

  /// The context used to hold the state contained by the parsed document.
  MLIRContext context;

  /// The full string contents of the file.
  std::string contents;

  /// The version of this file.
  int64_t version;

  /// The number of lines in the file.
  int64_t totalNumLines = 0;

  /// The chunks of this file. The order of these chunks is the order in which
  /// they appear in the text file.
  std::vector<std::unique_ptr<MLIRTextFileChunk>> chunks;
};
} // namespace

/// Copies `fileContents`, splits the copy at every `// -----` marker, and
/// parses each resulting piece as a separate chunk. Diagnostics produced by
/// chunks other than the first are shifted by the line offset of their chunk
/// so that they are anchored at the beginning of the file.
MLIRTextFile::MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
                           int64_t version, DialectRegistry &registry,
                           std::vector<lsp::Diagnostic> &diagnostics)
    : context(registry, MLIRContext::Threading::DISABLED),
      contents(fileContents.str()), version(version) {
  context.allowUnregisteredDialects();

  // Split the file into separate MLIR documents.
  // TODO: Find a way to share the split file marker with other tools. We don't
  // want to use `splitAndProcessBuffer` here, but we do want to make sure this
  // marker doesn't go out of sync.
  SmallVector<StringRef, 8> splitContents;
  StringRef(contents).split(splitContents, "// -----");

  // Number of newlines seen in the pieces processed so far, i.e. the line
  // offset of the next chunk.
  uint64_t linesSoFar = 0;
  for (unsigned idx = 0, numSplits = splitContents.size(); idx != numSplits;
       ++idx) {
    StringRef chunkText = splitContents[idx];
    const size_t firstNewDiag = diagnostics.size();
    auto chunk = std::make_unique<MLIRTextFileChunk>(context, linesSoFar, uri,
                                                     chunkText, diagnostics);
    linesSoFar += chunkText.count('\n');

    // Diagnostics of the first chunk are already anchored at the beginning of
    // the file; those of later chunks need to be shifted.
    if (idx != 0) {
      for (size_t d = firstNewDiag, e = diagnostics.size(); d != e; ++d) {
        lsp::Diagnostic &diag = diagnostics[d];
        chunk->adjustLocForChunkOffset(diag.range);

        // Related information is only shifted when it refers to this file.
        if (diag.relatedInformation) {
          for (auto &related : *diag.relatedInformation) {
            if (related.location.uri == uri)
              chunk->adjustLocForChunkOffset(related.location.range);
          }
        }
      }
    }
    chunks.emplace_back(std::move(chunk));
  }
  totalNumLines = linesSoFar;
}

/// Forwards the query to the chunk containing `defPos`, then shifts every
/// location in `locations` that refers to `uri` by the line offset of that
/// chunk.
void MLIRTextFile::getLocationsOf(const lsp::URIForFile &uri,
                                  lsp::Position defPos,
                                  std::vector<lsp::Location> &locations) {
  // `getChunkFor` also re-anchors `defPos` to the start of the chunk.
  MLIRTextFileChunk &owner = getChunkFor(defPos);
  owner.document.getLocationsOf(uri, defPos, locations);

  // Locations within this file are relative to the chunk; shift them back.
  if (owner.lineOffset != 0) {
    for (lsp::Location &found : locations) {
      if (found.uri == uri)
        owner.adjustLocForChunkOffset(found.range);
    }
  }
}

/// Forwards the query to the chunk containing `pos`, then shifts every
/// location in `references` that refers to `uri` by the line offset of that
/// chunk.
void MLIRTextFile::findReferencesOf(const lsp::URIForFile &uri,
                                    lsp::Position pos,
                                    std::vector<lsp::Location> &references) {
  // `getChunkFor` also re-anchors `pos` to the start of the chunk.
  MLIRTextFileChunk &owner = getChunkFor(pos);
  owner.document.findReferencesOf(uri, pos, references);

  // Locations within this file are relative to the chunk; shift them back.
  if (owner.lineOffset != 0) {
    for (lsp::Location &found : references) {
      if (found.uri == uri)
        owner.adjustLocForChunkOffset(found.range);
    }
  }
}

/// Forwards the query to the chunk containing `hoverPos` and shifts the range
/// of the resulting hover (if it has one) by the line offset of that chunk.
Optional<lsp::Hover> MLIRTextFile::findHover(const lsp::URIForFile &uri,
                                             lsp::Position hoverPos) {
  // `getChunkFor` also re-anchors `hoverPos` to the start of the chunk.
  MLIRTextFileChunk &owner = getChunkFor(hoverPos);
  Optional<lsp::Hover> result = owner.document.findHover(uri, hoverPos);

  // The hover range is relative to the chunk; shift it back when needed.
  const bool needsShift = owner.lineOffset != 0;
  if (needsShift && result && result->range)
    owner.adjustLocForChunkOffset(*result->range);
  return result;
}

/// Appends the document symbols of this file to `symbols`. A file with a
/// single chunk reports the symbols of that chunk directly; otherwise each
/// chunk gets a top-level symbol of its own that holds the symbols of the
/// chunk as children.
void MLIRTextFile::findDocumentSymbols(
    std::vector<lsp::DocumentSymbol> &symbols) {
  if (chunks.size() == 1)
    return chunks.front()->document.findDocumentSymbols(symbols);

  for (unsigned idx = 0, numChunks = chunks.size(); idx < numChunks; ++idx) {
    MLIRTextFileChunk &chunk = *chunks[idx];

    // A chunk spans from its own line offset to the offset of the next chunk,
    // or to the last line of the file for the final chunk.
    lsp::Position startPos(chunk.lineOffset);
    lsp::Position endPos((idx == numChunks - 1)
                             ? totalNumLines - 1
                             : chunks[idx + 1]->lineOffset);
    lsp::DocumentSymbol chunkSymbol("<file-split-" + Twine(idx) + ">",
                                    lsp::SymbolKind::Namespace,
                                    /*range=*/lsp::Range(startPos, endPos),
                                    /*selectionRange=*/lsp::Range(startPos));
    chunk.document.findDocumentSymbols(chunkSymbol.children);

    // Symbols of every chunk but the first are relative to the chunk; shift
    // them (and, transitively, all of their children) back.
    if (idx != 0) {
      SmallVector<lsp::DocumentSymbol *> pending;
      auto enqueueAll = [&](std::vector<lsp::DocumentSymbol> &list) {
        for (lsp::DocumentSymbol &entry : list)
          pending.push_back(&entry);
      };

      enqueueAll(chunkSymbol.children);
      while (!pending.empty()) {
        lsp::DocumentSymbol *current = pending.pop_back_val();
        chunk.adjustLocForChunkOffset(current->range);
        chunk.adjustLocForChunkOffset(current->selectionRange);
        enqueueAll(current->children);
      }
    }

    // Push the symbol for this chunk.
    symbols.emplace_back(std::move(chunkSymbol));
  }
}

/// Returns the chunk containing `pos`. When the file has more than one chunk,
/// the line of `pos` is additionally rewritten to be relative to the start of
/// the returned chunk.
MLIRTextFileChunk &MLIRTextFile::getChunkFor(lsp::Position &pos) {
  if (chunks.size() == 1)
    return *chunks.front();

  // True when `chunk` starts on a line after the queried position.
  auto startsAfter = [](const lsp::Position &query, const auto &chunk) {
    return static_cast<uint64_t>(query.line) < chunk->lineOffset;
  };

  // Locate the first chunk starting after `pos`; the chunk right before it is
  // the one containing `pos`. If there is no such chunk, `pos` is in the last
  // chunk.
  auto nextIt = llvm::upper_bound(chunks, pos, startsAfter);
  MLIRTextFileChunk *owner = nullptr;
  if (nextIt == chunks.end())
    owner = chunks.back().get();
  else
    owner = (--nextIt)->get();

  pos.line -= owner->lineOffset;
  return *owner;
}

//===----------------------------------------------------------------------===//
// MLIRServer::Impl
//===----------------------------------------------------------------------===//

/// The private state of the server: the dialect registry used when parsing
/// files, and the set of files currently held by the server.
struct lsp::MLIRServer::Impl {
  /// Binds this state to `registry`, which is held by reference (no copy is
  /// made).
  Impl(DialectRegistry &registry) : registry(registry) {}

  /// The registry containing dialects that can be recognized in parsed .mlir
  /// files.
  DialectRegistry &registry;

  /// The files held by the server, mapped by their URI file name.
  llvm::StringMap<std::unique_ptr<MLIRTextFile>> files;
};

//===----------------------------------------------------------------------===//
// MLIRServer
//===----------------------------------------------------------------------===//

/// Creates a server whose private implementation state refers to `registry`.
lsp::MLIRServer::MLIRServer(DialectRegistry &registry)
    : impl(std::make_unique<Impl>(registry)) {}
/// Defaulted here in the source file, after the definition of `Impl` above
/// (presumably so that `Impl` is a complete type where it is destroyed).
lsp::MLIRServer::~MLIRServer() = default;

/// Parses `contents` as a new text file for `uri` and stores it under the file
/// name of `uri`, replacing any file previously stored under that name.
/// Diagnostics produced while parsing are appended to `diagnostics`.
void lsp::MLIRServer::addOrUpdateDocument(
    const URIForFile &uri, StringRef contents, int64_t version,
    std::vector<Diagnostic> &diagnostics) {
  auto parsedFile = std::make_unique<MLIRTextFile>(
      uri, contents, version, impl->registry, diagnostics);
  impl->files[uri.file()] = std::move(parsedFile);
}

/// Removes the file stored for `uri` and returns the version it had, or None
/// if no file is stored for `uri`.
Optional<int64_t> lsp::MLIRServer::removeDocument(const URIForFile &uri) {
  auto entry = impl->files.find(uri.file());
  if (entry != impl->files.end()) {
    // Read the version before the file is destroyed by the erase.
    const int64_t removedVersion = entry->second->getVersion();
    impl->files.erase(entry);
    return removedVersion;
  }
  return llvm::None;
}

/// Forwards the query to the file stored for `uri`. Does nothing if no file is
/// stored for `uri`.
void lsp::MLIRServer::getLocationsOf(const URIForFile &uri,
                                     const Position &defPos,
                                     std::vector<Location> &locations) {
  auto entry = impl->files.find(uri.file());
  if (entry == impl->files.end())
    return;
  entry->second->getLocationsOf(uri, defPos, locations);
}

/// Forwards the query to the file stored for `uri`. Does nothing if no file is
/// stored for `uri`.
void lsp::MLIRServer::findReferencesOf(const URIForFile &uri,
                                       const Position &pos,
                                       std::vector<Location> &references) {
  auto entry = impl->files.find(uri.file());
  if (entry == impl->files.end())
    return;
  entry->second->findReferencesOf(uri, pos, references);
}

/// Forwards the query to the file stored for `uri` and returns its result, or
/// None if no file is stored for `uri`.
Optional<lsp::Hover> lsp::MLIRServer::findHover(const URIForFile &uri,
                                                const Position &hoverPos) {
  auto entry = impl->files.find(uri.file());
  if (entry == impl->files.end())
    return llvm::None;
  return entry->second->findHover(uri, hoverPos);
}

/// Forwards the query to the file stored for `uri`. Does nothing if no file is
/// stored for `uri`.
void lsp::MLIRServer::findDocumentSymbols(
    const URIForFile &uri, std::vector<DocumentSymbol> &symbols) {
  auto entry = impl->files.find(uri.file());
  if (entry == impl->files.end())
    return;
  entry->second->findDocumentSymbols(symbols);
}
