// Copyright 2014 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

// The purpose of this file is to determine what bitrate to use for mirroring.
// Ideally this should be as much as possible, without causing any frames to
// arrive late.

// The current algorithm is to measure how much bandwidth we've been using
// recently. We also keep track of how much data has been queued up for sending
// in a virtual "buffer" (this virtual buffer represents all the buffers between
// the sender and the receiver, including retransmissions and so forth.)
// If we estimate that our virtual buffer is mostly empty, we try to use
// more bandwidth than our recent usage, otherwise we use less.

#include "media/cast/sender/congestion_control.h"

#include <algorithm>
#include <deque>

#include "base/logging.h"
#include "base/macros.h"
#include "base/trace_event/trace_event.h"
#include "media/cast/constants.h"

namespace media {
namespace cast {

// CongestionControl that derives its bitrate from recently measured throughput
// and from an estimate of how full the virtual send buffer is (see the
// algorithm description at the top of this file).
class AdaptiveCongestionControl : public CongestionControl {
 public:
  AdaptiveCongestionControl(base::TickClock* clock,
                            int max_bitrate_configured,
                            int min_bitrate_configured,
                            double max_frame_rate);

  ~AdaptiveCongestionControl() final;

  // CongestionControl implementation.
  void UpdateRtt(base::TimeDelta rtt) final;
  void UpdateTargetPlayoutDelay(base::TimeDelta delay) final;
  void SendFrameToTransport(FrameId frame_id,
                            size_t frame_size_in_bits,
                            base::TimeTicks when) final;
  void AckFrame(FrameId frame_id, base::TimeTicks when) final;
  void AckLaterFrames(std::vector<FrameId> received_frames,
                      base::TimeTicks when) final;
  int GetBitrate(base::TimeTicks playout_time,
                 base::TimeDelta playout_delay) final;

 private:
  // Per-frame bookkeeping. A null timestamp means the corresponding event has
  // not been recorded for the frame yet.
  struct FrameStats {
    FrameStats();
    // Time this frame was first enqueued for transport (set by
    // SendFrameToTransport()).
    base::TimeTicks enqueue_time;
    // Time this frame was acked (set by AckFrame() or AckLaterFrames()).
    base::TimeTicks ack_time;
    // Size of encoded frame in bits.
    size_t frame_size_in_bits;
  };

  // Calculate how much "dead air" (idle time) there is between two frames:
  // the time from |a| being acked until |b| was enqueued, or zero if |b| was
  // enqueued no later than that.
  static base::TimeDelta DeadTime(const FrameStats& a, const FrameStats& b);
  // Get the FrameStats for a given |frame_id|, appending new entries if it is
  // newer than any frame seen so far.  Never returns nullptr (CHECK-fails if
  // |frame_id| is older than the retained history).
  // Note: Older FrameStats will be removed automatically.
  FrameStats* GetFrameStats(FrameId frame_id);
  // Discard the oldest FrameStats until at most |history_size_| remain.
  void PruneFrameStats();
  // Calculate a safe bitrate. This is based on how much we've been
  // sending in the past: acked bits divided by the time spent transmitting
  // them (excluding dead time).
  double CalculateSafeBitrate();

  // Estimate when the transport will start sending the data for a given frame.
  // |estimated_bitrate| is the current estimated transmit bitrate in bits per
  // second.
  base::TimeTicks EstimatedSendingTime(FrameId frame_id,
                                       double estimated_bitrate);

  base::TickClock* const clock_;  // Not owned by this class.
  // Bounds applied to the value returned by GetBitrate().
  const int max_bitrate_configured_;
  const int min_bitrate_configured_;
  // Used to convert the target playout delay into a frame count.
  const double max_frame_rate_;
  // History of per-frame stats. frame_stats_.back() belongs to
  // |last_frame_stats_| and each earlier element to the preceding FrameId.
  std::deque<FrameStats> frame_stats_;
  // FrameId corresponding to frame_stats_.back().
  FrameId last_frame_stats_;
  // This is the latest known frame that all previous frames (having smaller
  // |frame_id|) and this frame were acked by receiver.
  FrameId last_checkpoint_frame_;
  // Local time at which |last_checkpoint_frame_| became the checkpoint (the
  // |when| passed to AckFrame(), raised to at least the frame's enqueue time).
  base::TimeTicks last_checkpoint_time_;
  // Most recent |frame_id| passed to SendFrameToTransport().
  FrameId last_enqueued_frame_;
  // Smoothed round-trip time; see UpdateRtt().
  base::TimeDelta rtt_;
  // Maximum number of entries retained in |frame_stats_|.
  size_t history_size_;
  // Sum of frame_size_in_bits for frames passed by the checkpoint that are
  // still in |frame_stats_|.
  size_t acked_bits_in_history_;
  // Sum of DeadTime() between consecutive checkpointed frames still in
  // |frame_stats_|.
  base::TimeDelta dead_time_in_history_;

  DISALLOW_COPY_AND_ASSIGN(AdaptiveCongestionControl);
};

class FixedCongestionControl : public CongestionControl {
 public:
  explicit FixedCongestionControl(int bitrate) : bitrate_(bitrate) {}
  ~FixedCongestionControl() final {}

  // CongestionControl implementation.
  void UpdateRtt(base::TimeDelta rtt) final {}
  void UpdateTargetPlayoutDelay(base::TimeDelta delay) final {}
  void SendFrameToTransport(FrameId frame_id,
                            size_t frame_size_in_bits,
                            base::TimeTicks when) final {}
  void AckFrame(FrameId frame_id, base::TimeTicks when) final {}
  void AckLaterFrames(std::vector<FrameId> received_frames,
                      base::TimeTicks when) final {}
  int GetBitrate(base::TimeTicks playout_time,
                 base::TimeDelta playout_delay) final {
    return bitrate_;
  }

 private:
  const int bitrate_;

  DISALLOW_COPY_AND_ASSIGN(FixedCongestionControl);
};


// Creates an adaptive controller. The object is allocated with new and not
// retained here, so the caller takes ownership of it.
CongestionControl* NewAdaptiveCongestionControl(base::TickClock* clock,
                                                int max_bitrate_configured,
                                                int min_bitrate_configured,
                                                double max_frame_rate) {
  CongestionControl* const controller = new AdaptiveCongestionControl(
      clock, max_bitrate_configured, min_bitrate_configured, max_frame_rate);
  return controller;
}

// Creates a controller that always reports |bitrate|. The caller takes
// ownership of the returned object.
CongestionControl* NewFixedCongestionControl(int bitrate) {
  CongestionControl* const controller = new FixedCongestionControl(bitrate);
  return controller;
}

// Fraction of the virtual buffer we aim to keep empty (90%). GetBitrate()
// scales the safe bitrate by (measured empty fraction / this value): an
// emptier buffer raises the bandwidth and a fuller one lowers it. Making this
// smaller makes the congestion control more aggressive.
constexpr double kTargetEmptyBufferFraction = 0.9;

// Baseline size of our history, in frames. Larger values make the congestion
// control adapt more slowly.
constexpr size_t kHistorySize = 100;

AdaptiveCongestionControl::FrameStats::FrameStats() : frame_size_in_bits(0) {
}

AdaptiveCongestionControl::AdaptiveCongestionControl(base::TickClock* clock,
                                                     int max_bitrate_configured,
                                                     int min_bitrate_configured,
                                                     double max_frame_rate)
    : clock_(clock),
      max_bitrate_configured_(max_bitrate_configured),
      min_bitrate_configured_(min_bitrate_configured),
      max_frame_rate_(max_frame_rate),
      last_frame_stats_(FrameId::first() - 1),
      last_checkpoint_frame_(FrameId::first() - 1),
      last_enqueued_frame_(FrameId::first() - 1),
      history_size_(kHistorySize),
      acked_bits_in_history_(0) {
  DCHECK_GE(max_bitrate_configured, min_bitrate_configured) << "Invalid config";
  DCHECK_GT(min_bitrate_configured, 0);

  // Seed the history with two synthetic entries, both enqueued and acked at
  // construction time. The newer one (frame_stats_.back()) stands for the
  // initial checkpoint frame, FrameId::first() - 1, so the first AckFrame()
  // has a predecessor entry to measure dead time against.
  const base::TimeTicks start_time = clock_->NowTicks();
  frame_stats_.resize(2);
  for (FrameStats& seed : frame_stats_) {
    seed.enqueue_time = start_time;
    seed.ack_time = start_time;
  }
  last_checkpoint_time_ = start_time;
  DCHECK(!frame_stats_.front().ack_time.is_null());
}

// Out-of-line destructors; neither class owns resources that need cleanup.
CongestionControl::~CongestionControl() = default;
AdaptiveCongestionControl::~AdaptiveCongestionControl() = default;

// Folds a new round-trip-time sample into |rtt_| as an exponentially weighted
// moving average: the running estimate keeps 7/8 of the weight and the new
// sample contributes 1/8.
void AdaptiveCongestionControl::UpdateRtt(base::TimeDelta rtt) {
  const base::TimeDelta weighted_sum = rtt_ * 7 + rtt;
  rtt_ = weighted_sum / 8;
}

// Resizes the history window to cover every frame that may be unacknowledged
// within |delay|, plus the baseline kHistorySize frames, then prunes.
void AdaptiveCongestionControl::UpdateTargetPlayoutDelay(
    base::TimeDelta delay) {
  // Whole frames, at |max_frame_rate_|, that fit within |delay|.
  const int frames_within_delay = static_cast<int>(
      delay * max_frame_rate_ / base::TimeDelta::FromSeconds(1));
  // Upper bound on unacknowledged frames, capped at kMaxUnackedFrames.
  const int max_unacked_frames =
      std::min<int>(kMaxUnackedFrames, frames_within_delay + 1);
  DCHECK_GT(max_unacked_frames, 0);

  history_size_ = kHistorySize + max_unacked_frames;
  PruneFrameStats();
}

// Returns the idle gap between |a| being acked and |b| being enqueued, or a
// zero TimeDelta when |b| was enqueued no later than |a|'s ack.
base::TimeDelta AdaptiveCongestionControl::DeadTime(const FrameStats& a,
                                                    const FrameStats& b) {
  return b.enqueue_time > a.ack_time ? b.enqueue_time - a.ack_time
                                     : base::TimeDelta();
}

double AdaptiveCongestionControl::CalculateSafeBitrate() {
  double transmit_time =
      (GetFrameStats(last_checkpoint_frame_)->ack_time -
       frame_stats_.front().enqueue_time - dead_time_in_history_)
          .InSecondsF();

  if (acked_bits_in_history_ == 0 || transmit_time <= 0.0) {
    return min_bitrate_configured_;
  }
  return acked_bits_in_history_ / std::max(transmit_time, 1E-3);
}

// Returns the FrameStats entry for |frame_id|. If |frame_id| is newer than
// |last_frame_stats_|, default-constructed entries are appended so that
// frame_stats_.back() corresponds to |frame_id|. The history is then pruned to
// |history_size_| entries. CHECK-fails if |frame_id| is older than the oldest
// retained entry.
AdaptiveCongestionControl::FrameStats* AdaptiveCongestionControl::GetFrameStats(
    FrameId frame_id) {
  // NOTE(review): this bound uses kHistorySize rather than |history_size_|, so
  // it is stricter than the retained window after UpdateTargetPlayoutDelay().
  DCHECK_LT(frame_id - last_frame_stats_, static_cast<int64_t>(kHistorySize));
  // Position of |frame_id| relative to the newest entry (0 == back()).
  int offset = frame_id - last_frame_stats_;
  if (offset > 0) {
    // Grow the history so that |frame_id| becomes the newest entry.
    frame_stats_.resize(frame_stats_.size() + offset);
    last_frame_stats_ += offset;
    offset = 0;
  }
  PruneFrameStats();
  // Convert the non-positive offset from back() into an index from front().
  // An offset older than the retained history yields a negative index here.
  offset += frame_stats_.size() - 1;
  // TODO(miu): Change the following to DCHECK once crash fix is confirmed.
  // http://crbug.com/517145
  CHECK(offset >= 0 && offset < static_cast<int32_t>(frame_stats_.size()));
  return &frame_stats_[offset];
}

void AdaptiveCongestionControl::PruneFrameStats() {
 while (frame_stats_.size() > history_size_) {
    DCHECK_GT(frame_stats_.size(), 1UL);
    DCHECK(!frame_stats_[0].ack_time.is_null());
    acked_bits_in_history_ -= frame_stats_[0].frame_size_in_bits;
    dead_time_in_history_ -= DeadTime(frame_stats_[0], frame_stats_[1]);
    DCHECK_GE(acked_bits_in_history_, 0UL);
    VLOG(2) << "DT: " << dead_time_in_history_.InSecondsF();
    DCHECK_GE(dead_time_in_history_.InSecondsF(), 0.0);
    frame_stats_.pop_front();
  }
}

// Advances the checkpoint one frame at a time up to |frame_id|, which is
// reported as acked together with every earlier frame. Each newly checkpointed
// frame adds its size to |acked_bits_in_history_| and its gap from the previous
// frame to |dead_time_in_history_|. Stops early at the first frame that was
// never enqueued for transport.
void AdaptiveCongestionControl::AckFrame(FrameId frame_id,
                                         base::TimeTicks when) {
  FrameStats* frame_stats = GetFrameStats(last_checkpoint_frame_);
  while (last_checkpoint_frame_ < frame_id) {
    FrameStats* last_frame_stats = frame_stats;
    // NOTE(review): |last_frame_stats| remains valid across the call below
    // because std::deque growth at the back keeps element pointers valid; it
    // would dangle only if pruning popped that very element.
    frame_stats = GetFrameStats(last_checkpoint_frame_ + 1);
    if (frame_stats->enqueue_time.is_null()) {
      // Can't ack a frame that hasn't been sent yet.
      return;
    }
    last_checkpoint_frame_++;
    // An ACK cannot precede the send; the raised |when| also carries over to
    // the remaining frames of this loop.
    if (when < frame_stats->enqueue_time)
      when = frame_stats->enqueue_time;
    // Don't overwrite the ack time for those frames that were already acked in
    // previous extended ACKs.
    if (frame_stats->ack_time.is_null())
      frame_stats->ack_time = when;
    DCHECK_GE(when, frame_stats->ack_time);
    acked_bits_in_history_ += frame_stats->frame_size_in_bits;
    dead_time_in_history_ += DeadTime(*last_frame_stats, *frame_stats);
    last_checkpoint_time_ = when;
  }
}

void AdaptiveCongestionControl::AckLaterFrames(
    std::vector<FrameId> received_frames,
    base::TimeTicks when) {
  for (FrameId frame_id : received_frames) {
    if (last_checkpoint_frame_ < frame_id) {
      FrameStats* frame_stats = GetFrameStats(frame_id);
      if (frame_stats->enqueue_time.is_null()) {
        // Can't ack a frame that hasn't been sent yet.
        continue;
      }
      if (when < frame_stats->enqueue_time)
        when = frame_stats->enqueue_time;
      // Don't overwrite the ack time for those frames that were acked before.
      if (frame_stats->ack_time.is_null())
        frame_stats->ack_time = when;
      DCHECK_GE(when, frame_stats->ack_time);
    }
  }
}

// Records the enqueue time and encoded size of |frame_id|, and remembers it as
// the most recently enqueued frame. GetFrameStats() appends a history entry if
// |frame_id| is newer than any frame seen so far.
void AdaptiveCongestionControl::SendFrameToTransport(FrameId frame_id,
                                                     size_t frame_size_in_bits,
                                                     base::TimeTicks when) {
  FrameStats& stats = *GetFrameStats(frame_id);
  stats.frame_size_in_bits = frame_size_in_bits;
  stats.enqueue_time = when;
  last_enqueued_frame_ = frame_id;
}

// Walks forward from the checkpoint frame, modeling the transmission of each
// not-yet-acked frame at |estimated_bitrate| (bits per second), and returns the
// estimated time at which the transport starts sending |frame_id|. The result
// is never earlier than the frame's enqueue time or, if the frame has not been
// enqueued yet, earlier than now.
base::TimeTicks AdaptiveCongestionControl::EstimatedSendingTime(
    FrameId frame_id,
    double estimated_bitrate) {
  const base::TimeTicks now = clock_->NowTicks();

  // Starting with the time of the latest acknowledgement, extrapolate forward
  // to determine an estimated sending time for |frame_id|.
  //
  // |estimated_sending_time| will contain the estimated sending time for each
  // frame after the checkpoint frame.  It is possible for multiple frames to
  // be in-flight; and therefore it is common for the |estimated_sending_time|
  // for those frames to be before |now|.  The initial estimate is based on the
  // checkpoint frame's ACK time and the RTT.
  base::TimeTicks estimated_sending_time = last_checkpoint_time_ - rtt_;
  for (FrameId f = last_checkpoint_frame_ + 1; f < frame_id; ++f) {
    FrameStats* const stats = GetFrameStats(f);

    // |estimated_ack_time| is the local time when the sender receives the ACK,
    // and not the time when the ACK left the receiver.
    base::TimeTicks estimated_ack_time = stats->ack_time;

    // This frame was already acked individually (for frames past the
    // checkpoint, only AckLaterFrames() records an ACK time), so its packets
    // will not enter the packet send queue again; it does not advance the
    // estimate.
    if (!estimated_ack_time.is_null())
      continue;

    // Model: The |estimated_sending_time| is the time at which the first byte
    // of the encoded frame is transmitted.  Then, assume the transmission of
    // the remaining bytes is paced such that the last byte has just left the
    // sender at |frame_transmit_time| later.  This last byte then takes
    // ~RTT/2 amount of time to travel to the receiver.  Finally, the ACK from
    // the receiver is sent and this takes another ~RTT/2 amount of time to
    // reach the sender.
    const base::TimeDelta frame_transmit_time = base::TimeDelta::FromSecondsD(
        stats->frame_size_in_bits / estimated_bitrate);
    estimated_ack_time = std::max(estimated_sending_time, stats->enqueue_time) +
                         frame_transmit_time + rtt_;

    if (estimated_ack_time < now) {
      // The current frame has not yet been ACK'ed and yet the computed
      // |estimated_ack_time| is before |now|.  This contradiction must be
      // resolved.
      //
      // The solution below is a little counter-intuitive, but it seems to
      // work.  Basically, when we estimate that the ACK should have already
      // happened, we figure out how long ago it should have happened and
      // guess that the ACK will happen half of that time in the future.  This
      // will cause some over-estimation when acks are late, which is actually
      // the desired behavior.
      estimated_ack_time = now + (now - estimated_ack_time) / 2;
    }

    // Since we [in the common case] do not wait for an ACK before we start
    // sending the next frame, estimate the next frame's sending time as the
    // time just after the last byte of the current frame left the sender (see
    // Model comment above).
    estimated_sending_time =
        std::max(estimated_sending_time, estimated_ack_time - rtt_);
  }

  FrameStats* const frame_stats = GetFrameStats(frame_id);
  if (frame_stats->enqueue_time.is_null()) {
    // The frame has not yet been enqueued for transport.  Since it cannot be
    // enqueued in the past, ensure the result is lower-bounded by |now|.
    estimated_sending_time = std::max(estimated_sending_time, now);
  } else {
    // |frame_stats->enqueue_time| is the time the frame was enqueued for
    // transport.  The frame may not actually start being sent until a
    // point-in-time after that, because the transport is waiting for prior
    // frames to be acknowledged.
    estimated_sending_time =
        std::max(estimated_sending_time, frame_stats->enqueue_time);
  }

  return estimated_sending_time;
}

// Returns the bitrate (bits per second) to use for the next frame: the safe
// bitrate scaled by how empty the virtual buffer is relative to
// kTargetEmptyBufferFraction, clamped to
// [min_bitrate_configured_, max_bitrate_configured_].
int AdaptiveCongestionControl::GetBitrate(base::TimeTicks playout_time,
                                          base::TimeDelta playout_delay) {
  const double safe_bitrate = CalculateSafeBitrate();
  // Estimate when we might start sending the next frame.
  const base::TimeDelta time_to_catch_up =
      playout_time -
      EstimatedSendingTime(last_enqueued_frame_ + 1, safe_bitrate);

  // Fraction of the playout delay still available, i.e. how "empty" the
  // virtual buffer is. A zero |playout_delay| would make the ratio +/-inf or,
  // when |time_to_catch_up| is also zero, NaN (whose conversion to int is
  // undefined behavior). Handle it explicitly, matching what the clamped
  // infinities would give.
  double empty_buffer_fraction;
  if (playout_delay == base::TimeDelta()) {
    empty_buffer_fraction = time_to_catch_up > base::TimeDelta() ? 1.0 : 0.0;
  } else {
    empty_buffer_fraction =
        time_to_catch_up.InSecondsF() / playout_delay.InSecondsF();
  }
  // Clamp to [0, 1]. The bound is the first argument of std::max so that a NaN
  // (e.g. from saturated TimeDeltas) collapses to 0 instead of propagating.
  empty_buffer_fraction = std::min(1.0, std::max(0.0, empty_buffer_fraction));

  const double unclamped_bitrate =
      safe_bitrate * empty_buffer_fraction / kTargetEmptyBufferFraction;
  VLOG(3) << " FBR:" << (unclamped_bitrate / 1E6)
          << " EBF:" << empty_buffer_fraction
          << " SBR:" << (safe_bitrate / 1E6);
  TRACE_COUNTER_ID1("cast.stream", "Empty Buffer Fraction", this,
                    empty_buffer_fraction);

  // Clamp to the configured range *before* converting to int. |safe_bitrate|
  // has no upper bound, and converting an out-of-range double to int is
  // undefined behavior. On x86 such a conversion typically yields INT_MIN,
  // which an int-domain clamp would silently turn into the minimum bitrate.
  const double clamped_bitrate =
      std::min(static_cast<double>(max_bitrate_configured_),
               std::max(static_cast<double>(min_bitrate_configured_),
                        unclamped_bitrate));
  return static_cast<int>(clamped_bitrate);
}

}  // namespace cast
}  // namespace media
