// Copyright 2016 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "net/nqe/throughput_analyzer.h"

#include <stdint.h>

#include <deque>
#include <memory>
#include <string>

#include "base/bind.h"
#include "base/bind_helpers.h"
#include "base/logging.h"
#include "base/macros.h"
#include "base/run_loop.h"
#include "base/single_thread_task_runner.h"
#include "base/threading/thread_task_runner_handle.h"
#include "net/base/url_util.h"
#include "net/url_request/url_request.h"
#include "net/url_request/url_request_test_util.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace net {

namespace nqe {

namespace {

class TestThroughputAnalyzer : public internal::ThroughputAnalyzer {
 public:
  TestThroughputAnalyzer()
      : internal::ThroughputAnalyzer(
            base::ThreadTaskRunnerHandle::Get(),
            base::Bind(
                &TestThroughputAnalyzer::OnNewThroughputObservationAvailable,
                base::Unretained(this)),
            false,
            false),
        throughput_observations_received_(0),
        bits_received_(0) {}

  ~TestThroughputAnalyzer() override {}

  int32_t throughput_observations_received() const {
    return throughput_observations_received_;
  }

  void OnNewThroughputObservationAvailable(int32_t downstream_kbps) {
    throughput_observations_received_++;
  }

  int64_t GetBitsReceived() const override { return bits_received_; }

  void IncrementBitsReceived(int64_t additional_bits_received) {
    bits_received_ += additional_bits_received;
  }

  using internal::ThroughputAnalyzer::disable_throughput_measurements;

 private:
  int throughput_observations_received_;

  int64_t bits_received_;

  DISALLOW_COPY_AND_ASSIGN(TestThroughputAnalyzer);
};

// Checks that starting a large number of requests disables throughput
// measurements only when those requests are local.
TEST(ThroughputAnalyzerTest, MaximumRequests) {
  // Number of requests started per scenario.
  const size_t kRequestsToStart = 1000;

  // Each entry is one scenario: whether the started requests are local.
  const bool kUseLocalRequests[] = {false, true};

  for (const bool use_local_requests : kUseLocalRequests) {
    TestThroughputAnalyzer analyzer;

    TestDelegate delegate;
    TestURLRequestContext request_context;

    ASSERT_FALSE(analyzer.disable_throughput_measurements());

    // Keeps every started request alive until the end of the scenario.
    std::deque<std::unique_ptr<URLRequest>> in_flight_requests;

    const std::string request_url = use_local_requests
                                        ? "http://127.0.0.1/test.html"
                                        : "http://example.com/test.html";

    // Start more requests than the maximum number of requests that can be held
    // in the memory.
    while (in_flight_requests.size() < kRequestsToStart) {
      std::unique_ptr<URLRequest> request = request_context.CreateRequest(
          GURL(request_url), DEFAULT_PRIORITY, &delegate);
      ASSERT_EQ(use_local_requests, IsLocalhost(request->url().host()));

      analyzer.NotifyStartTransaction(*request);
      in_flight_requests.push_back(std::move(request));
    }

    // Too many local requests should cause the |analyzer| to disable
    // throughput measurements; non-local requests should leave them enabled.
    EXPECT_EQ(use_local_requests, analyzer.disable_throughput_measurements());
  }
}

// Tests if the throughput observation is taken correctly when local and network
// requests overlap.
TEST(ThroughputAnalyzerTest, TestThroughputWithMultipleRequestsOverlap) {
  struct TestCase {
    bool start_local_request;
    bool local_request_completes_first;
    bool expect_throughput_observation;
  };
  static const TestCase kTestCases[] = {
      // Network request only: an observation is expected.
      {false, false, true},
      // Local request spans the whole network request: no observation.
      {true, false, false},
      // Local request completes before the bits are received: an observation
      // is expected.
      {true, true, true},
  };

  for (const TestCase& test_case : kTestCases) {
    // Localhost requests are not allowed for estimation purposes.
    TestThroughputAnalyzer analyzer;

    TestDelegate delegate;
    TestURLRequestContext request_context;

    // Stays null unless the test case asks for a local request.
    std::unique_ptr<URLRequest> local_request;

    std::unique_ptr<URLRequest> network_request =
        request_context.CreateRequest(GURL("http://example.com/echo.html"),
                                      DEFAULT_PRIORITY, &delegate);
    network_request->Start();

    if (test_case.start_local_request) {
      local_request = request_context.CreateRequest(
          GURL("http://localhost/echo.html"), DEFAULT_PRIORITY, &delegate);
      local_request->Start();
    }

    base::RunLoop().Run();

    // The analyzer has not been told about any request yet.
    EXPECT_EQ(0, analyzer.throughput_observations_received());

    // When a local request is used, the analyzer is told about it before the
    // network request.
    if (test_case.start_local_request)
      analyzer.NotifyStartTransaction(*local_request);
    analyzer.NotifyStartTransaction(*network_request);

    // Optionally finish the local request before any bits are received, so
    // that it no longer overlaps with the network request.
    if (test_case.local_request_completes_first) {
      ASSERT_TRUE(test_case.start_local_request);
      analyzer.NotifyRequestCompleted(*local_request);
    }

    // Increment the bits received count to emulate the data received while
    // the requests were in flight.
    analyzer.IncrementBitsReceived(100 * 1000 * 8);

    analyzer.NotifyRequestCompleted(*network_request);

    // A local request that did not complete first ends after the network
    // request. It was then ongoing for the network request's whole lifetime,
    // so no throughput observation should be recorded.
    const bool local_request_still_pending =
        test_case.start_local_request &&
        !test_case.local_request_completes_first;
    if (local_request_still_pending)
      analyzer.NotifyRequestCompleted(*local_request);

    base::RunLoop().RunUntilIdle();

    if (test_case.expect_throughput_observation) {
      EXPECT_EQ(1, analyzer.throughput_observations_received());
    } else {
      EXPECT_EQ(0, analyzer.throughput_observations_received());
    }
  }
}

// Tests if the throughput observation is taken correctly when two network
// requests overlap.
TEST(ThroughputAnalyzerTest, TestThroughputWithNetworkRequestsOverlap) {
  static const struct {
    // Number of bits reported as received while both requests are in flight.
    int64_t increment_bits;
    // Whether a throughput observation is expected after both requests
    // complete.
    bool expect_throughput_observation;
  } tests[] = {
      {
          100 * 1000 * 8, true,
      },
      // NOTE(review): a single bit is expected to yield no observation,
      // presumably because it is below a minimum enforced by the analyzer;
      // that threshold is not visible in this file.
      {
          1, false,
      },
  };

  for (const auto& test : tests) {
    // Localhost requests are not allowed for estimation purposes.
    TestThroughputAnalyzer throughput_analyzer;
    TestDelegate test_delegate;
    TestURLRequestContext context;

    EXPECT_EQ(0, throughput_analyzer.throughput_observations_received());

    std::unique_ptr<URLRequest> request_network_1 = context.CreateRequest(
        GURL("http://example.com/echo.html"), DEFAULT_PRIORITY, &test_delegate);
    std::unique_ptr<URLRequest> request_network_2 = context.CreateRequest(
        GURL("http://example.com/echo.html"), DEFAULT_PRIORITY, &test_delegate);
    request_network_1->Start();
    request_network_2->Start();

    base::RunLoop().Run();

    // The analyzer has not been notified about any request yet, so it must
    // not have produced an observation. EXPECT_EQ is used because a
    // "0 <= count" check can never fail for a counter that only increments.
    EXPECT_EQ(0, throughput_analyzer.throughput_observations_received());

    throughput_analyzer.NotifyStartTransaction(*request_network_1);
    throughput_analyzer.NotifyStartTransaction(*request_network_2);

    // Increment the bits received count to emulate the data received for
    // |request_network_1| and |request_network_2|.
    throughput_analyzer.IncrementBitsReceived(test.increment_bits);

    throughput_analyzer.NotifyRequestCompleted(*request_network_1);
    throughput_analyzer.NotifyRequestCompleted(*request_network_2);
    base::RunLoop().RunUntilIdle();

    // Only one observation should be taken since two requests overlap.
    if (test.expect_throughput_observation) {
      EXPECT_EQ(1, throughput_analyzer.throughput_observations_received());
    } else {
      EXPECT_EQ(0, throughput_analyzer.throughput_observations_received());
    }
  }
}

}  // namespace

}  // namespace nqe

}  // namespace net
