// Copyright 2017 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "components/cryptauth/secure_channel.h"

#include <memory>

#include "base/bind.h"
#include "base/memory/ptr_util.h"
#include "chromeos/components/proximity_auth/logging/logging.h"
#include "components/cryptauth/secure_message_delegate_impl.h"
#include "components/cryptauth/wire_message.h"

namespace cryptauth {

// static
SecureChannel::Factory* SecureChannel::Factory::factory_instance_ = nullptr;

// static
std::unique_ptr<SecureChannel> SecureChannel::Factory::NewInstance(
    std::unique_ptr<Connection> connection) {
  if (!factory_instance_) {
    factory_instance_ = new Factory();
  }
  return factory_instance_->BuildInstance(std::move(connection));
}

// static
void SecureChannel::Factory::SetInstanceForTesting(Factory* factory) {
  factory_instance_ = factory;
}

std::unique_ptr<SecureChannel> SecureChannel::Factory::BuildInstance(
    std::unique_ptr<Connection> connection) {
  return base::WrapUnique(new SecureChannel(std::move(connection)));
}

// static
std::string SecureChannel::StatusToString(const Status& status) {
  switch (status) {
    case Status::DISCONNECTED:
      return "[disconnected]";
    case Status::CONNECTING:
      return "[connecting]";
    case Status::CONNECTED:
      return "[connected]";
    case Status::AUTHENTICATING:
      return "[authenticating]";
    case Status::AUTHENTICATED:
      return "[authenticated]";
    case Status::DISCONNECTING:
      return "[disconnecting]";
    default:
      return "[unknown status]";
  }
}

SecureChannel::PendingMessage::PendingMessage(const std::string& feature,
                                              const std::string& payload,
                                              int sequence_number)
    : feature(feature), payload(payload), sequence_number(sequence_number) {}

SecureChannel::PendingMessage::~PendingMessage() {}

SecureChannel::SecureChannel(std::unique_ptr<Connection> connection)
    : status_(Status::DISCONNECTED),
      connection_(std::move(connection)),
      weak_ptr_factory_(this) {
  connection_->AddObserver(this);
}

SecureChannel::~SecureChannel() {
  connection_->RemoveObserver(this);
}

void SecureChannel::Initialize() {
  DCHECK(status_ == Status::DISCONNECTED);
  connection_->Connect();
  TransitionToStatus(Status::CONNECTING);
}

int SecureChannel::SendMessage(const std::string& feature,
                               const std::string& payload) {
  DCHECK(status_ == Status::AUTHENTICATED);

  int sequence_number = next_sequence_number_;
  next_sequence_number_++;

  queued_messages_.emplace(
      std::make_unique<PendingMessage>(feature, payload, sequence_number));
  ProcessMessageQueue();

  return sequence_number;
}

void SecureChannel::Disconnect() {
  if (connection_->IsConnected()) {
    TransitionToStatus(Status::DISCONNECTING);

    // If |connection_| is active, calling Disconnect() will eventually cause
    // its status to transition to DISCONNECTED, which will in turn cause this
    // class to transition to DISCONNECTED.
    connection_->Disconnect();
    return;
  }

  TransitionToStatus(Status::DISCONNECTED);
}

void SecureChannel::AddObserver(Observer* observer) {
  observer_list_.AddObserver(observer);
}

void SecureChannel::RemoveObserver(Observer* observer) {
  observer_list_.RemoveObserver(observer);
}

void SecureChannel::GetConnectionRssi(
    base::OnceCallback<void(base::Optional<int32_t>)> callback) {
  if (!connection_) {
    std::move(callback).Run(base::nullopt);
    return;
  }

  connection_->GetConnectionRssi(std::move(callback));
}

base::Optional<std::string> SecureChannel::GetChannelBindingData() {
  if (secure_context_)
    return secure_context_->GetChannelBindingData();

  return base::nullopt;
}

void SecureChannel::OnConnectionStatusChanged(
    Connection* connection,
    Connection::Status old_status,
    Connection::Status new_status) {
  DCHECK(connection == connection_.get());

  if (new_status == Connection::Status::CONNECTED) {
    TransitionToStatus(Status::CONNECTED);

    // Once the connection has succeeded, authenticate the connection by
    // initiating the security handshake.
    Authenticate();
    return;
  }

  if (new_status == Connection::Status::DISCONNECTED) {
    // If the connection is no longer active, disconnect.
    Disconnect();
    return;
  }
}

void SecureChannel::OnMessageReceived(const Connection& connection,
                                      const WireMessage& wire_message) {
  DCHECK(&connection == const_cast<const Connection*>(connection_.get()));
  if (wire_message.feature() == Authenticator::kAuthenticationFeature) {
    // If the message received was part of the authentication handshake, it
    // is a low-level message and should not be forwarded to observers.
    return;
  }

  if (!secure_context_) {
    PA_LOG(WARNING) << "Received unexpected message before authentication "
                    << "was complete. Feature: " << wire_message.feature()
                    << ", Payload: " << wire_message.payload();
    return;
  }

  secure_context_->Decode(wire_message.payload(),
                          base::Bind(&SecureChannel::OnMessageDecoded,
                                     weak_ptr_factory_.GetWeakPtr(),
                                     wire_message.feature()));
}

void SecureChannel::OnSendCompleted(const cryptauth::Connection& connection,
                                    const cryptauth::WireMessage& wire_message,
                                    bool success) {
  if (wire_message.feature() == Authenticator::kAuthenticationFeature) {
    // No need to process authentication messages; these are handled by
    // |authenticator_|.
    return;
  }

  if (!pending_message_) {
    PA_LOG(ERROR) << "OnSendCompleted(), but a send was not expected to be in "
                  << "progress. Disconnecting from "
                  << connection_->GetDeviceAddress();
    Disconnect();
    return;
  }

  if (success && status_ != Status::DISCONNECTED) {
    pending_message_.reset();

    // Create a WeakPtr to |this| before invoking observer callbacks. It is
    // possible that an Observer will respond to the OnMessageSent() call by
    // destroying the connection (e.g., if the client only wanted to send one
    // message and destroyed the connection after the message was sent).
    base::WeakPtr<SecureChannel> weak_this = weak_ptr_factory_.GetWeakPtr();

    if (wire_message.sequence_number() != -1) {
      for (auto& observer : observer_list_)
        observer.OnMessageSent(this, wire_message.sequence_number());
    }

    // Process the next message if possible. Note that if the SecureChannel was
    // deleted by the OnMessageSent() callback, this will be a no-op since
    // |weak_this| will have been invalidated in that case.
    if (weak_this.get())
      weak_this->ProcessMessageQueue();

    return;
  }

  PA_LOG(ERROR) << "Could not send message: {"
                << "payload: \"" << pending_message_->payload << "\", "
                << "feature: \"" << pending_message_->feature << "\""
                << "}";
  pending_message_.reset();

  // The connection automatically retries failed messges, so if |success| is
  // |false| here, a fatal error has occurred. Thus, there is no need to retry
  // the message; instead, disconnect.
  Disconnect();
}

void SecureChannel::TransitionToStatus(const Status& new_status) {
  if (new_status == status_) {
    // Only report changes to state.
    return;
  }

  Status old_status = status_;
  status_ = new_status;

  for (auto& observer : observer_list_)
    observer.OnSecureChannelStatusChanged(this, old_status, status_);
}

void SecureChannel::Authenticate() {
  DCHECK(status_ == Status::CONNECTED);
  DCHECK(!authenticator_);

  authenticator_ = DeviceToDeviceAuthenticator::Factory::NewInstance(
      connection_.get(), connection_->remote_device().user_id(),
      cryptauth::SecureMessageDelegateImpl::Factory::NewInstance());
  authenticator_->Authenticate(
      base::Bind(&SecureChannel::OnAuthenticationResult,
                 weak_ptr_factory_.GetWeakPtr()));

  TransitionToStatus(Status::AUTHENTICATING);
}

void SecureChannel::ProcessMessageQueue() {
  if (pending_message_ || queued_messages_.empty())  {
    return;
  }

  DCHECK(!connection_->is_sending_message());

  pending_message_ = std::move(queued_messages_.front());
  queued_messages_.pop();

  PA_LOG(INFO) << "Sending message to " << connection_->GetDeviceAddress()
               << ": {"
               << "feature: \"" << pending_message_->feature << "\", "
               << "payload: \"" << pending_message_->payload << "\""
               << "}";

  secure_context_->Encode(
      pending_message_->payload,
      base::Bind(&SecureChannel::OnMessageEncoded,
                 weak_ptr_factory_.GetWeakPtr(), pending_message_->feature,
                 pending_message_->sequence_number));
}

void SecureChannel::OnMessageEncoded(const std::string& feature,
                                     int sequence_number,
                                     const std::string& encoded_message) {
  connection_->SendMessage(std::make_unique<cryptauth::WireMessage>(
      encoded_message, feature, sequence_number));
}

void SecureChannel::OnMessageDecoded(
    const std::string& feature, const std::string& decoded_message) {
  PA_LOG(INFO) << "Received message from " << connection_->GetDeviceAddress()
               << ": {"
               << "feature: \"" << feature << "\", "
               << "payload: \"" << decoded_message << "\""
               << "}";

  for (auto& observer : observer_list_)
    observer.OnMessageReceived(this, feature, decoded_message);
}

void SecureChannel::OnAuthenticationResult(
    Authenticator::Result result,
    std::unique_ptr<SecureContext> secure_context) {
  DCHECK(status_ == Status::AUTHENTICATING);

  // The authenticator is no longer needed, so release it.
  authenticator_.reset();

  if (result != Authenticator::Result::SUCCESS) {
    PA_LOG(WARNING) << "Failed to authenticate connection to device with ID "
        << connection_->remote_device().GetTruncatedDeviceIdForLogs();
    Disconnect();
    return;
  }

  secure_context_ = std::move(secure_context);
  TransitionToStatus(Status::AUTHENTICATED);
}

}  // namespace cryptauth
