//===--- JSONTransport.cpp - sending and receiving LSP messages over JSON -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "Transport.h"
#include "Logging.h"
#include "Protocol.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/Support/Errno.h"
#include "llvm/Support/Error.h"
#include <system_error>

using namespace mlir;
using namespace mlir::lsp;

//===----------------------------------------------------------------------===//
// Reply
//===----------------------------------------------------------------------===//

namespace {
/// Function object to reply to an LSP call.
/// Each instance must be called exactly once, otherwise:
///  - if there was no reply, an error reply is sent
///  - if there were multiple replies, only the first is sent
class Reply {
public:
  Reply(const llvm::json::Value &id, StringRef method,
        JSONTransport &transport);
  Reply(Reply &&other);
  Reply &operator=(Reply &&) = delete;
  Reply(const Reply &) = delete;
  Reply &operator=(const Reply &) = delete;

  void operator()(llvm::Expected<llvm::json::Value> reply);

private:
  StringRef method;
  std::atomic<bool> replied = {false};
  llvm::json::Value id;
  JSONTransport *transport;
};
} // namespace

Reply::Reply(const llvm::json::Value &id, llvm::StringRef method,
             JSONTransport &transport)
    : id(id), transport(&transport) {}

Reply::Reply(Reply &&other)
    : replied(other.replied.load()), id(std::move(other.id)),
      transport(other.transport) {
  other.transport = nullptr;
}

void Reply::operator()(llvm::Expected<llvm::json::Value> reply) {
  if (replied.exchange(true)) {
    Logger::error("Replied twice to message {0}({1})", method, id);
    assert(false && "must reply to each call only once!");
    return;
  }
  assert(transport && "expected valid transport to reply to");

  if (reply) {
    Logger::info("--> reply:{0}({1})", method, id);
    transport->reply(std::move(id), std::move(reply));
  } else {
    llvm::Error error = reply.takeError();
    Logger::info("--> reply:{0}({1})", method, id, error);
    transport->reply(std::move(id), std::move(error));
  }
}

//===----------------------------------------------------------------------===//
// MessageHandler
//===----------------------------------------------------------------------===//

bool MessageHandler::onNotify(llvm::StringRef method, llvm::json::Value value) {
  Logger::info("--> {0}", method);

  if (method == "exit")
    return false;
  if (method == "$cancel") {
    // TODO: Add support for cancelling requests.
  } else {
    auto it = notificationHandlers.find(method);
    if (it != notificationHandlers.end())
      it->second(value);
  }
  return true;
}

bool MessageHandler::onCall(llvm::StringRef method, llvm::json::Value params,
                            llvm::json::Value id) {
  Logger::info("--> {0}({1})", method, id);

  Reply reply(id, method, transport);

  auto it = methodHandlers.find(method);
  if (it != methodHandlers.end()) {
    it->second(params, std::move(reply));
  } else {
    reply(llvm::make_error<LSPError>("method not found: " + method.str(),
                                     ErrorCode::MethodNotFound));
  }
  return true;
}

bool MessageHandler::onReply(llvm::json::Value id,
                             llvm::Expected<llvm::json::Value> result) {
  // TODO: Add support for reply callbacks when support for outgoing messages is
  // added. For now, we just log an error on any replies received.
  Callback<llvm::json::Value> replyHandler =
      [&id](llvm::Expected<llvm::json::Value> result) {
        Logger::error(
            "received a reply with ID {0}, but there was no such call", id);
        if (!result)
          llvm::consumeError(result.takeError());
      };

  // Log and run the reply handler.
  if (result)
    replyHandler(std::move(result));
  else
    replyHandler(result.takeError());
  return true;
}

//===----------------------------------------------------------------------===//
// JSONTransport
//===----------------------------------------------------------------------===//

/// Encode the given error as a JSON object.
static llvm::json::Object encodeError(llvm::Error error) {
  std::string message;
  ErrorCode code = ErrorCode::UnknownErrorCode;
  auto handlerFn = [&](const LSPError &lspError) -> llvm::Error {
    message = lspError.message;
    code = lspError.code;
    return llvm::Error::success();
  };
  if (llvm::Error unhandled = llvm::handleErrors(std::move(error), handlerFn))
    message = llvm::toString(std::move(unhandled));

  return llvm::json::Object{
      {"message", std::move(message)},
      {"code", int64_t(code)},
  };
}

/// Decode the given JSON object into an error.
llvm::Error decodeError(const llvm::json::Object &o) {
  StringRef msg = o.getString("message").getValueOr("Unspecified error");
  if (Optional<int64_t> code = o.getInteger("code"))
    return llvm::make_error<LSPError>(msg.str(), ErrorCode(*code));
  return llvm::make_error<llvm::StringError>(llvm::inconvertibleErrorCode(),
                                             msg.str());
}

void JSONTransport::notify(StringRef method, llvm::json::Value params) {
  sendMessage(llvm::json::Object{
      {"jsonrpc", "2.0"},
      {"method", method},
      {"params", std::move(params)},
  });
}
void JSONTransport::call(StringRef method, llvm::json::Value params,
                         llvm::json::Value id) {
  sendMessage(llvm::json::Object{
      {"jsonrpc", "2.0"},
      {"id", std::move(id)},
      {"method", method},
      {"params", std::move(params)},
  });
}
void JSONTransport::reply(llvm::json::Value id,
                          llvm::Expected<llvm::json::Value> result) {
  if (result) {
    return sendMessage(llvm::json::Object{
        {"jsonrpc", "2.0"},
        {"id", std::move(id)},
        {"result", std::move(*result)},
    });
  }

  sendMessage(llvm::json::Object{
      {"jsonrpc", "2.0"},
      {"id", std::move(id)},
      {"error", encodeError(result.takeError())},
  });
}

llvm::Error JSONTransport::run(MessageHandler &handler) {
  std::string json;
  while (!feof(in)) {
    if (ferror(in)) {
      return llvm::errorCodeToError(
          std::error_code(errno, std::system_category()));
    }

    if (succeeded(readMessage(json))) {
      if (llvm::Expected<llvm::json::Value> doc = llvm::json::parse(json)) {
        if (!handleMessage(std::move(*doc), handler))
          return llvm::Error::success();
      } else {
        Logger::error("JSON parse error: {0}", llvm::toString(doc.takeError()));
      }
    }
  }
  return llvm::errorCodeToError(std::make_error_code(std::errc::io_error));
}

void JSONTransport::sendMessage(llvm::json::Value msg) {
  outputBuffer.clear();
  llvm::raw_svector_ostream os(outputBuffer);
  os << llvm::formatv(prettyOutput ? "{0:2}\n" : "{0}", msg);
  out << "Content-Length: " << outputBuffer.size() << "\r\n\r\n"
      << outputBuffer;
  out.flush();
  Logger::debug(">>> {0}\n", outputBuffer);
}

bool JSONTransport::handleMessage(llvm::json::Value msg,
                                  MessageHandler &handler) {
  // Message must be an object with "jsonrpc":"2.0".
  llvm::json::Object *object = msg.getAsObject();
  if (!object ||
      object->getString("jsonrpc") != llvm::Optional<StringRef>("2.0"))
    return false;

  // `id` may be any JSON value. If absent, this is a notification.
  llvm::Optional<llvm::json::Value> id;
  if (llvm::json::Value *i = object->get("id"))
    id = std::move(*i);
  Optional<StringRef> method = object->getString("method");

  // This is a response.
  if (!method) {
    if (!id)
      return false;
    if (auto *err = object->getObject("error"))
      return handler.onReply(std::move(*id), decodeError(*err));
    // result should be given, use null if not.
    llvm::json::Value result = nullptr;
    if (llvm::json::Value *r = object->get("result"))
      result = std::move(*r);
    return handler.onReply(std::move(*id), std::move(result));
  }

  // Params should be given, use null if not.
  llvm::json::Value params = nullptr;
  if (llvm::json::Value *p = object->get("params"))
    params = std::move(*p);

  if (id)
    return handler.onCall(*method, std::move(params), std::move(*id));
  return handler.onNotify(*method, std::move(params));
}

/// Tries to read a line up to and including \n.
/// If failing, feof(), ferror(), or shutdownRequested() will be set.
LogicalResult readLine(std::FILE *in, SmallVectorImpl<char> &out) {
  // Big enough to hold any reasonable header line. May not fit content lines
  // in delimited mode, but performance doesn't matter for that mode.
  static constexpr int bufSize = 128;
  size_t size = 0;
  out.clear();
  for (;;) {
    out.resize_for_overwrite(size + bufSize);
    if (!std::fgets(&out[size], bufSize, in))
      return failure();

    clearerr(in);

    // If the line contained null bytes, anything after it (including \n) will
    // be ignored. Fortunately this is not a legal header or JSON.
    size_t read = std::strlen(&out[size]);
    if (read > 0 && out[size + read - 1] == '\n') {
      out.resize(size + read);
      return success();
    }
    size += read;
  }
}

// Returns None when:
//  - ferror(), feof(), or shutdownRequested() are set.
//  - Content-Length is missing or empty (protocol error)
LogicalResult JSONTransport::readStandardMessage(std::string &json) {
  // A Language Server Protocol message starts with a set of HTTP headers,
  // delimited  by \r\n, and terminated by an empty line (\r\n).
  unsigned long long contentLength = 0;
  llvm::SmallString<128> line;
  while (true) {
    if (feof(in) || ferror(in) || failed(readLine(in, line)))
      return failure();

    // Content-Length is a mandatory header, and the only one we handle.
    StringRef lineRef = line;
    if (lineRef.consume_front("Content-Length: ")) {
      llvm::getAsUnsignedInteger(lineRef.trim(), 0, contentLength);
    } else if (!lineRef.trim().empty()) {
      // It's another header, ignore it.
      continue;
    } else {
      // An empty line indicates the end of headers. Go ahead and read the JSON.
      break;
    }
  }

  // The fuzzer likes crashing us by sending "Content-Length: 9999999999999999"
  if (contentLength == 0 || contentLength > 1 << 30)
    return failure();

  json.resize(contentLength);
  for (size_t pos = 0, read; pos < contentLength; pos += read) {
    read = std::fread(&json[pos], 1, contentLength - pos, in);
    if (read == 0)
      return failure();

    // If we're done, the error was transient. If we're not done, either it was
    // transient or we'll see it again on retry.
    clearerr(in);
    pos += read;
  }
  return success();
}

/// For lit tests we support a simplified syntax:
/// - messages are delimited by '// -----' on a line by itself
/// - lines starting with // are ignored.
/// This is a testing path, so favor simplicity over performance here.
/// When returning failure: feof(), ferror(), or shutdownRequested() will be
/// set.
LogicalResult JSONTransport::readDelimitedMessage(std::string &json) {
  json.clear();
  llvm::SmallString<128> line;
  while (succeeded(readLine(in, line))) {
    StringRef lineRef = line.str().trim();
    if (lineRef.startswith("//")) {
      // Found a delimiter for the message.
      if (lineRef == "// -----")
        break;
      continue;
    }

    json += line;
  }

  return failure(ferror(in));
}
