// Copyright 2014 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef COMPONENTS_CAST_CHANNEL_CAST_SOCKET_H_
#define COMPONENTS_CAST_CHANNEL_CAST_SOCKET_H_

#include <stdint.h>

#include <queue>
#include <string>

#include "base/cancelable_callback.h"
#include "base/gtest_prod_util.h"
#include "base/macros.h"
#include "base/memory/ref_counted.h"
#include "base/observer_list.h"
#include "base/threading/thread_checker.h"
#include "base/timer/timer.h"
#include "components/cast_channel/cast_auth_util.h"
#include "components/cast_channel/cast_channel_enum.h"
#include "components/cast_channel/cast_socket.h"
#include "components/cast_channel/cast_transport.h"
#include "net/base/completion_callback.h"
#include "net/base/io_buffer.h"
#include "net/base/ip_endpoint.h"
#include "net/log/net_log_source.h"

namespace net {
class CertVerifier;
class CTPolicyEnforcer;
class CTVerifier;
class NetLog;
class SSLClientSocket;
class StreamSocket;
class TransportClientSocket;
class TransportSecurityState;
class X509Certificate;
}

namespace cast_channel {
class CastMessage;
class Logger;
struct LastError;

// Cast device capabilities.
enum CastDeviceCapability : int {
  NONE = 0,
  VIDEO_OUT = 1 << 0,
  VIDEO_IN = 1 << 1,
  AUDIO_OUT = 1 << 2,
  AUDIO_IN = 1 << 3,
  DEV_MODE = 1 << 4,
  MULTIZONE_GROUP = 1 << 5
};

// Public interface of the CastSocket class.
class CastSocket {
 public:
  // Invoked when CastSocket opens.
  // |socket|: raw pointer of opened socket (this pointer). Guaranteed to be
  // valid in callback function. Do not pass |socket| around.
  using OnOpenCallback = base::OnceCallback<void(CastSocket* socket)>;

  class Observer {
   public:
    virtual ~Observer() {}

    // Invoked when an error occurs on |socket|.
    virtual void OnError(const CastSocket& socket,
                         ChannelError error_state) = 0;

    // Invoked when |socket| receives a message.
    virtual void OnMessage(const CastSocket& socket,
                           const CastMessage& message) = 0;
  };

  virtual ~CastSocket() {}

  // Used by BrowserContextKeyedAPIFactory.
  static const char* service_name() { return "CastSocketImplManager"; }

  // Connects the channel to the peer. If successful, the channel will be in
  // READY_STATE_OPEN.  DO NOT delete the CastSocket object in |callback|.
  // Instead use Close().
  // |callback| will be invoked with any ChannelError that occurred, or
  // CHANNEL_ERROR_NONE if successful.
  // If the CastSocket is destroyed while the connection is pending, |callback|
  // will be invoked with CHANNEL_ERROR_UNKNOWN. In this case, invoking
  // |callback| must not result in any re-entrancy behavior.
  virtual void Connect(OnOpenCallback callback) = 0;

  // Closes the channel if not already closed. On completion, the channel will
  // be in READY_STATE_CLOSED.
  //
  // It is fine to delete this object in |callback|.
  virtual void Close(const net::CompletionCallback& callback) = 0;

  // The IP endpoint for the destination of the channel.
  virtual const net::IPEndPoint& ip_endpoint() const = 0;

  // Channel id generated by the CastChannelService.
  virtual int id() const = 0;

  // Sets the channel id generated by CastChannelService.
  virtual void set_id(int id) = 0;

  // The ready state of the channel.
  virtual ReadyState ready_state() const = 0;

  // Returns the last error that occurred on this channel, or
  // CHANNEL_ERROR_NONE if no error has occurred.
  virtual ChannelError error_state() const = 0;

  // True when keep-alive signaling is handled for this socket.
  virtual bool keep_alive() const = 0;

  // Whether the channel is audio only as identified by the device
  // certificate during channel authentication.
  virtual bool audio_only() const = 0;

  // Marks a socket as invalid due to an error, and sends an OnError
  // event to |delegate_|.
  // The OnError event receipient is responsible for closing the socket in the
  // event of an error.
  // Setting the error state does not close the socket if it is open.
  virtual void SetErrorState(ChannelError error_state) = 0;

  // Returns a pointer to the socket's message transport layer. Can be used to
  // send and receive CastMessages over the socket.
  virtual CastTransport* transport() const = 0;

  // Registers |observer| with the socket to receive messages and error events.
  virtual void AddObserver(Observer* observer) = 0;

  // Unregisters |observer|.
  virtual void RemoveObserver(Observer* observer) = 0;
};

// Holds parameters necessary to open a Cast channel (CastSocket) to a Cast
// device.
struct CastSocketOpenParams {
  // IP endpoint of the Cast device.
  net::IPEndPoint ip_endpoint;

  // Log of socket events.
  net::NetLog* net_log;

  // Connection timeout interval. If this value is not set, Cast socket will not
  // report CONNECT_TIMEOUT error and may hang when connecting to a Cast device.
  base::TimeDelta connect_timeout;

  // Amount of idle time to wait before disconnecting. Cast socket will ping
  // Cast device periodically at |ping_interval| to check liveness. If it does
  // not receive response in |liveness_timeout|, it reports PING_TIMEOUT error.
  // |liveness_timeout| should always be larger than or equal to
  // |ping_interval|.
  // If this value is not set, there is not periodic ping and Cast socket is
  // always assumed alive.
  base::TimeDelta liveness_timeout;

  // Amount of idle time to wait before pinging the Cast device. See comments
  // for |liveness_timeout|.
  base::TimeDelta ping_interval;

  // A bit vector representing the capabilities of the sink. The values are
  // defined in components/cast_channel/cast_socket.h.
  uint64_t device_capabilities;

  CastSocketOpenParams(const net::IPEndPoint& ip_endpoint,
                       net::NetLog* net_log,
                       base::TimeDelta connect_timeout);
  CastSocketOpenParams(const net::IPEndPoint& ip_endpoint,
                       net::NetLog* net_log,
                       base::TimeDelta connect_timeout,
                       base::TimeDelta liveness_timeout,
                       base::TimeDelta ping_interval,
                       uint64_t device_capabilities);
};

// This class implements a channel between Chrome and a Cast device using a TCP
// socket with SSL.  The channel may authenticate that the receiver is a genuine
// Cast device.  All CastSocketImpl objects must be used only on the IO thread.
//
// NOTE: Not called "CastChannel" to reduce confusion with the generated API
// code.
class CastSocketImpl : public CastSocket {
 public:
  CastSocketImpl(const CastSocketOpenParams& open_params,
                 const scoped_refptr<Logger>& logger);

  CastSocketImpl(const CastSocketOpenParams& open_params,
                 const scoped_refptr<Logger>& logger,
                 const AuthContext& auth_context);

  // Ensures that the socket is closed.
  ~CastSocketImpl() override;

  // CastSocket interface.
  void Connect(OnOpenCallback callback) override;
  CastTransport* transport() const override;
  void Close(const net::CompletionCallback& callback) override;
  const net::IPEndPoint& ip_endpoint() const override;
  int id() const override;
  void set_id(int channel_id) override;
  ReadyState ready_state() const override;
  ChannelError error_state() const override;
  bool keep_alive() const override;
  bool audio_only() const override;
  void AddObserver(Observer* observer) override;
  void RemoveObserver(Observer* observer) override;

 protected:
  // CastTransport::Delegate methods for receiving handshake messages.
  class AuthTransportDelegate : public CastTransport::Delegate {
   public:
    explicit AuthTransportDelegate(CastSocketImpl* socket);

    // Gets the error state of the channel.
    // Returns CHANNEL_ERROR_NONE if no errors are present.
    ChannelError error_state() const;

    // Gets recorded error details.
    LastError last_error() const;

    // CastTransport::Delegate interface.
    void OnError(ChannelError error_state) override;
    void OnMessage(const CastMessage& message) override;
    void Start() override;

   private:
    CastSocketImpl* socket_;
    ChannelError error_state_;
    LastError last_error_;
  };

  // CastTransport::Delegate methods to receive normal messages and errors.
  class CastSocketMessageDelegate : public CastTransport::Delegate {
   public:
    CastSocketMessageDelegate(CastSocketImpl* socket);
    ~CastSocketMessageDelegate() override;

    // CastTransport::Delegate implementation.
    void OnError(ChannelError error_state) override;
    void OnMessage(const CastMessage& message) override;
    void Start() override;

   private:
    CastSocketImpl* const socket_;
    DISALLOW_COPY_AND_ASSIGN(CastSocketMessageDelegate);
  };

  // Replaces the internally-constructed transport object with one provided
  // by the caller (e.g. a mock).
  void SetTransportForTesting(std::unique_ptr<CastTransport> transport);

  // Verifies whether the socket complies with cast channel policy.
  // Audio only channel policy mandates that a device declaring a video out
  // capability must not have a certificate with audio only policy.
  bool VerifyChannelPolicy(const AuthResult& result);

  void Connect();

 private:
  FRIEND_TEST_ALL_PREFIXES(MockCastSocketTest, TestObservers);
  friend class AuthTransportDelegate;

  void SetErrorState(ChannelError error_state) override;

  // Frees resources and cancels pending callbacks.  |ready_state_| will be set
  // READY_STATE_CLOSED on completion.  A no-op if |ready_state_| is already
  // READY_STATE_CLOSED.
  void CloseInternal();

  // Creates an instance of TransportClientSocket.
  virtual std::unique_ptr<net::TransportClientSocket> CreateTcpSocket();
  // Creates an instance of SSLClientSocket with the given underlying |socket|.
  virtual std::unique_ptr<net::SSLClientSocket> CreateSslSocket(
      std::unique_ptr<net::StreamSocket> socket);
  // Extracts peer certificate from SSLClientSocket instance when the socket
  // is in cert error state.
  // Returns null if the certificate could not be extracted.
  // TODO(kmarshall): Use MockSSLClientSocket for tests instead of overriding
  // this function.
  virtual scoped_refptr<net::X509Certificate> ExtractPeerCert();
  // Verifies whether the challenge reply received from the peer is valid:
  // 1. Signature in the reply is valid.
  // 2. Certificate is rooted to a trusted CA.
  virtual bool VerifyChallengeReply();

  // Invoked by a cancelable closure when connection setup time
  // exceeds the interval specified at |connect_timeout|.
  void OnConnectTimeout();

  /////////////////////////////////////////////////////////////////////////////
  // Following methods work together to implement the following flow:
  // 1. Create a new TCP socket and connect to it
  // 2. Create a new SSL socket and try connecting to it
  // 3. If connection fails due to invalid cert authority, then extract the
  //    peer certificate from the error.
  // 4. Whitelist the peer certificate and try #1 and #2 again.
  // 5. If SSL socket is connected successfully, and if protocol is casts://
  //    then issue an auth challenge request.
  // 6. Validate the auth challenge response.
  //
  // Main method that performs connection state transitions.
  void DoConnectLoop(int result);
  // Each of the below Do* method is executed in the corresponding
  // connection state. For example when connection state is TCP_CONNECT
  // DoTcpConnect is called, and so on.
  int DoTcpConnect();
  int DoTcpConnectComplete(int result);
  int DoSslConnect();
  int DoSslConnectComplete(int result);
  int DoAuthChallengeSend();
  int DoAuthChallengeSendComplete(int result);
  int DoAuthChallengeReplyComplete(int result);
  /////////////////////////////////////////////////////////////////////////////

  // Resets the cancellable callback used for async invocations of
  // DoConnectLoop.
  void ResetConnectLoopCallback();

  // Posts a task to invoke |connect_loop_callback_| with |result| on the
  // current message loop.
  void PostTaskToStartConnectLoop(int result);

  // Runs the external connection callback and resets it.
  void DoConnectCallback();

  virtual base::OneShotTimer* GetTimer();

  void SetConnectState(ConnectionState connect_state);
  void SetReadyState(ReadyState ready_state);

  THREAD_CHECKER(thread_checker_);

  // The id of the channel.
  int channel_id_;

  // The NetLog source for this service.
  net::NetLogSource net_log_source_;

  // Cast socket related settings.
  CastSocketOpenParams open_params_;

  // Shared logging object, used to log CastSocket events for diagnostics.
  scoped_refptr<Logger> logger_;

  // CertVerifier is owned by us but should be deleted AFTER SSLClientSocket
  // since in some cases the destructor of SSLClientSocket may call a method
  // to cancel a cert verification request.
  std::unique_ptr<net::CertVerifier> cert_verifier_;
  std::unique_ptr<net::TransportSecurityState> transport_security_state_;
  std::unique_ptr<net::CTVerifier> cert_transparency_verifier_;
  std::unique_ptr<net::CTPolicyEnforcer> ct_policy_enforcer_;

  // Owned ptr to the underlying TCP socket.
  std::unique_ptr<net::TransportClientSocket> tcp_socket_;

  // Owned ptr to the underlying SSL socket.
  std::unique_ptr<net::SSLClientSocket> socket_;

  // Certificate of the peer. This field may be empty if the peer
  // certificate is not yet fetched.
  scoped_refptr<net::X509Certificate> peer_cert_;

  // The challenge context for the current connection.
  const AuthContext auth_context_;

  // Reply received from the receiver to a challenge request.
  std::unique_ptr<CastMessage> challenge_reply_;

  // Callbacks invoked when the socket is connected or fails to connect.
  std::vector<OnOpenCallback> connect_callbacks_;

  // Callback invoked by |connect_timeout_timer_| to cancel the connection.
  base::CancelableClosure connect_timeout_callback_;

  // Timer invoked when the connection has timed out.
  std::unique_ptr<base::OneShotTimer> connect_timeout_timer_;

  // Set when a timeout is triggered and the connection process has
  // canceled.
  bool is_canceled_;

  // Whether the channel is audio only as identified by the device
  // certificate during channel authentication.
  bool audio_only_;

  // Connection flow state machine state.
  ConnectionState connect_state_;

  // Write flow state machine state.
  WriteState write_state_;

  // Read flow state machine state.
  ReadState read_state_;

  // The last error encountered by the channel.
  ChannelError error_state_;

  // The current status of the channel.
  ReadyState ready_state_;

  // Callback which, when invoked, will re-enter the connection state machine.
  // Oustanding callbacks will be cancelled when |this| is destroyed.
  // The callback signature is based on net::CompletionCallback, which passes
  // operation result codes as byte counts in the success case, or as
  // net::Error enum values for error cases.
  base::CancelableCallback<void(int)> connect_loop_callback_;

  // Cast message formatting and parsing layer.
  std::unique_ptr<CastTransport> transport_;

  // Caller's message read and error handling delegate.
  std::unique_ptr<CastTransport::Delegate> delegate_;

  // Raw pointer to the auth handshake delegate. Used to get detailed error
  // information.
  AuthTransportDelegate* auth_delegate_;

  // List of socket observers.
  base::ObserverList<Observer> observers_;

  DISALLOW_COPY_AND_ASSIGN(CastSocketImpl);
};
}  // namespace cast_channel

#endif  // COMPONENTS_CAST_CHANNEL_CAST_SOCKET_H_
