// Copyright 2013 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "content/public/browser/tracing_controller.h"

#include <stdint.h>
#include <utility>

#include "base/files/file_util.h"
#include "base/memory/ref_counted_memory.h"
#include "base/run_loop.h"
#include "base/strings/pattern.h"
#include "base/threading/thread_restrictions.h"
#include "build/build_config.h"
#include "content/browser/tracing/tracing_controller_impl.h"
#include "content/public/browser/browser_thread.h"
#include "content/public/browser/trace_uploader.h"
#include "content/public/test/browser_test_utils.h"
#include "content/public/test/content_browser_test.h"
#include "content/public/test/content_browser_test_utils.h"
#include "content/shell/browser/shell.h"
#include "content/test/test_content_browser_client.h"

using base::trace_event::RECORD_CONTINUOUSLY;
using base::trace_event::RECORD_UNTIL_FULL;
using base::trace_event::TraceConfig;

namespace content {

namespace {

// Metadata keys (matched as base::MatchPattern patterns) that the test
// tracing delegate lets through unmodified. Both the array and the pointed-to
// strings are immutable.
const char* const kMetadataWhitelist[] = {
  "cpu-brand",
  "network-type",
  "os-name",
  "user-agent"
};

// Metadata filter predicate used by the test tracing delegate. Returns true
// as soon as |metadata_name| matches one of the kMetadataWhitelist patterns,
// false if none of them match.
bool IsMetadataWhitelisted(const std::string& metadata_name) {
  for (const char* pattern : kMetadataWhitelist) {
    if (base::MatchPattern(metadata_name, pattern))
      return true;
  }
  return false;
}

// Argument filter predicate installed on the TraceLog by the filter test.
// Returns true only when the category group matches "benchmark" and the event
// name matches "whitelisted"; |arg_filter| is never written.
bool IsTraceEventArgsWhitelisted(
    const char* category_group_name,
    const char* event_name,
    base::trace_event::ArgumentNameFilterPredicate* arg_filter) {
  const bool in_benchmark_category =
      base::MatchPattern(category_group_name, "benchmark");
  return in_benchmark_category &&
         base::MatchPattern(event_name, "whitelisted");
}

}  // namespace

// Test-only TraceDataEndpoint that appends every received chunk to an
// in-memory string and, when the final contents arrive, posts |done_callback_|
// to the UI thread with the metadata and the concatenated trace data.
// The destructor is protected, so instances are expected to be owned through
// a reference-counted pointer (they are created with `new` and handed to a
// sink) — presumably TraceDataEndpoint is ref-counted; confirm in its header.
class TracingControllerTestEndpoint : public TraceDataEndpoint {
 public:
  explicit TracingControllerTestEndpoint(
      base::Callback<void(std::unique_ptr<const base::DictionaryValue>,
                          base::RefCountedString*)> done_callback)
      : done_callback_(std::move(done_callback)) {}

  // Every chunk delivered by the controller is expected to be non-empty.
  void ReceiveTraceChunk(std::unique_ptr<std::string> chunk) override {
    EXPECT_FALSE(chunk->empty());
    trace_ += *chunk;
  }

  void ReceiveTraceFinalContents(
      std::unique_ptr<const base::DictionaryValue> metadata) override {
    // TakeString moves the accumulated buffer into a ref-counted string, so
    // the data can be handed to another thread without a copy.
    scoped_refptr<base::RefCountedString> chunk_ptr =
        base::RefCountedString::TakeString(&trace_);

    BrowserThread::PostTask(
        BrowserThread::UI, FROM_HERE,
        base::Bind(done_callback_, base::Passed(std::move(metadata)),
                   base::RetainedRef(chunk_ptr)));
  }

 protected:
  ~TracingControllerTestEndpoint() override {}

  // Trace data accumulated so far.
  std::string trace_;
  // Invoked on the UI thread once the final contents have been received.
  base::Callback<void(std::unique_ptr<const base::DictionaryValue>,
                      base::RefCountedString*)>
      done_callback_;
};

// Browser client that supplies a tracing delegate with no trace uploader and
// a metadata filter that only lets kMetadataWhitelist keys through.
class TracingTestBrowserClient : public TestContentBrowserClient {
 public:
  // Allocates a new delegate on every call. The caller is assumed to take
  // ownership of the returned pointer (no one here deletes it) — NOTE(review):
  // confirm against ContentBrowserClient::GetTracingDelegate.
  TracingDelegate* GetTracingDelegate() override {
    return new TestTracingDelegate();
  }

 private:
  class TestTracingDelegate : public TracingDelegate {
   public:
    // Uploading traces is not exercised by these tests.
    std::unique_ptr<TraceUploader> GetTraceUploader(
        net::URLRequestContextGetter* request_context) override {
      return nullptr;
    }
    MetadataFilterPredicate GetMetadataFilterPredicate() override {
      return base::Bind(IsMetadataWhitelisted);
    }
  };
};

class TracingControllerTest : public ContentBrowserTest {
 public:
  TracingControllerTest() {}

  void SetUp() override {
    get_categories_done_callback_count_ = 0;
    enable_recording_done_callback_count_ = 0;
    disable_recording_done_callback_count_ = 0;
    ContentBrowserTest::SetUp();
  }

  void TearDown() override { ContentBrowserTest::TearDown(); }

  void Navigate(Shell* shell) {
    NavigateToURL(shell, GetTestUrl("", "title.html"));
  }

  void GetCategoriesDoneCallbackTest(base::Closure quit_callback,
                                     const std::set<std::string>& categories) {
    get_categories_done_callback_count_++;
    EXPECT_TRUE(categories.size() > 0);
    quit_callback.Run();
  }

  void StartTracingDoneCallbackTest(base::Closure quit_callback) {
    enable_recording_done_callback_count_++;
    quit_callback.Run();
  }

  void StopTracingStringDoneCallbackTest(
      base::Closure quit_callback,
      std::unique_ptr<const base::DictionaryValue> metadata,
      base::RefCountedString* data) {
    disable_recording_done_callback_count_++;
    last_metadata_.reset(metadata.release());
    last_data_ = data->data();
    EXPECT_TRUE(data->size() > 0);
    quit_callback.Run();
  }

  void StopTracingFileDoneCallbackTest(base::Closure quit_callback,
                                            const base::FilePath& file_path) {
    disable_recording_done_callback_count_++;
    {
      base::ThreadRestrictions::ScopedAllowIO allow_io_for_test_verifications;
      EXPECT_TRUE(PathExists(file_path));
      int64_t file_size;
      base::GetFileSize(file_path, &file_size);
      EXPECT_GT(file_size, 0);
    }
    quit_callback.Run();
    last_actual_recording_file_path_ = file_path;
  }

    int get_categories_done_callback_count() const {
    return get_categories_done_callback_count_;
  }

  int enable_recording_done_callback_count() const {
    return enable_recording_done_callback_count_;
  }

  int disable_recording_done_callback_count() const {
    return disable_recording_done_callback_count_;
  }

  base::FilePath last_actual_recording_file_path() const {
    return last_actual_recording_file_path_;
  }

  const base::DictionaryValue* last_metadata() const {
    return last_metadata_.get();
  }

  const std::string& last_data() const {
    return last_data_;
  }

  void TestStartAndStopTracingString() {
    Navigate(shell());

    TracingController* controller = TracingController::GetInstance();

    {
      base::RunLoop run_loop;
      TracingController::StartTracingDoneCallback callback =
          base::Bind(&TracingControllerTest::StartTracingDoneCallbackTest,
                     base::Unretained(this),
                     run_loop.QuitClosure());
      bool result = controller->StartTracing(
          TraceConfig(), callback);
      ASSERT_TRUE(result);
      run_loop.Run();
      EXPECT_EQ(enable_recording_done_callback_count(), 1);
    }

    {
      base::RunLoop run_loop;
      base::Callback<void(std::unique_ptr<const base::DictionaryValue>,
                          base::RefCountedString*)>
          callback = base::Bind(
              &TracingControllerTest::StopTracingStringDoneCallbackTest,
              base::Unretained(this), run_loop.QuitClosure());
      bool result = controller->StopTracing(
          TracingController::CreateStringSink(callback));
      ASSERT_TRUE(result);
      run_loop.Run();
      EXPECT_EQ(disable_recording_done_callback_count(), 1);
    }
  }

  void TestStartAndStopTracingStringWithFilter() {
    TracingTestBrowserClient client;
    ContentBrowserClient* old_client = SetBrowserClientForTesting(&client);
    Navigate(shell());

    base::trace_event::TraceLog::GetInstance()->SetArgumentFilterPredicate(
        base::Bind(&IsTraceEventArgsWhitelisted));

    TracingController* controller = TracingController::GetInstance();

    {
      base::RunLoop run_loop;
      TracingController::StartTracingDoneCallback callback =
          base::Bind(&TracingControllerTest::StartTracingDoneCallbackTest,
                     base::Unretained(this),
                     run_loop.QuitClosure());

      TraceConfig config = TraceConfig();
      config.EnableArgumentFilter();

      bool result = controller->StartTracing(config, callback);
      ASSERT_TRUE(result);
      run_loop.Run();
      EXPECT_EQ(enable_recording_done_callback_count(), 1);
    }

    {
      base::RunLoop run_loop;
      base::Callback<void(std::unique_ptr<const base::DictionaryValue>,
                          base::RefCountedString*)>
          callback = base::Bind(
              &TracingControllerTest::StopTracingStringDoneCallbackTest,
              base::Unretained(this), run_loop.QuitClosure());

      scoped_refptr<TracingController::TraceDataSink> trace_data_sink =
          TracingController::CreateStringSink(callback);

      base::DictionaryValue metadata;
      metadata.SetString("not-whitelisted", "this_not_found");
      controller->AddMetadata(metadata);

      bool result = controller->StopTracing(trace_data_sink);
      ASSERT_TRUE(result);
      run_loop.Run();
      EXPECT_EQ(disable_recording_done_callback_count(), 1);
    }
    SetBrowserClientForTesting(old_client);
  }

  void TestStartAndStopTracingCompressed() {
    Navigate(shell());

    TracingController* controller = TracingController::GetInstance();

    {
      base::RunLoop run_loop;
      TracingController::StartTracingDoneCallback callback =
          base::Bind(&TracingControllerTest::StartTracingDoneCallbackTest,
                     base::Unretained(this), run_loop.QuitClosure());
      bool result = controller->StartTracing(TraceConfig(), callback);
      ASSERT_TRUE(result);
      run_loop.Run();
      EXPECT_EQ(enable_recording_done_callback_count(), 1);
    }

    {
      base::RunLoop run_loop;
      base::Callback<void(std::unique_ptr<const base::DictionaryValue>,
                          base::RefCountedString*)>
          callback = base::Bind(
              &TracingControllerTest::StopTracingStringDoneCallbackTest,
              base::Unretained(this), run_loop.QuitClosure());
      bool result = controller->StopTracing(
          TracingControllerImpl::CreateCompressedStringSink(
              new TracingControllerTestEndpoint(callback)));
      ASSERT_TRUE(result);
      run_loop.Run();
      EXPECT_EQ(disable_recording_done_callback_count(), 1);
    }
  }

  void TestStartAndStopTracingFile(
      const base::FilePath& result_file_path) {
    Navigate(shell());

    TracingController* controller = TracingController::GetInstance();

    {
      base::RunLoop run_loop;
      TracingController::StartTracingDoneCallback callback =
          base::Bind(&TracingControllerTest::StartTracingDoneCallbackTest,
                     base::Unretained(this),
                     run_loop.QuitClosure());
      bool result = controller->StartTracing(TraceConfig(), callback);
      ASSERT_TRUE(result);
      run_loop.Run();
      EXPECT_EQ(enable_recording_done_callback_count(), 1);
    }

    {
      base::RunLoop run_loop;
      base::Closure callback = base::Bind(
          &TracingControllerTest::StopTracingFileDoneCallbackTest,
          base::Unretained(this),
          run_loop.QuitClosure(),
          result_file_path);
      bool result = controller->StopTracing(
          TracingController::CreateFileSink(result_file_path, callback));
      ASSERT_TRUE(result);
      run_loop.Run();
      EXPECT_EQ(disable_recording_done_callback_count(), 1);
    }
  }

 private:
  int get_categories_done_callback_count_;
  int enable_recording_done_callback_count_;
  int disable_recording_done_callback_count_;
  base::FilePath last_actual_recording_file_path_;
  std::unique_ptr<const base::DictionaryValue> last_metadata_;
  std::string last_data_;
};

IN_PROC_BROWSER_TEST_F(TracingControllerTest, GetCategories) {
  Navigate(shell());

  // Request the category list and spin a RunLoop until the completion
  // callback has fired exactly once.
  base::RunLoop loop;
  TracingController::GetCategoriesDoneCallback on_categories = base::Bind(
      &TracingControllerTest::GetCategoriesDoneCallbackTest,
      base::Unretained(this), loop.QuitClosure());
  ASSERT_TRUE(TracingController::GetInstance()->GetCategories(on_categories));
  loop.Run();
  EXPECT_EQ(get_categories_done_callback_count(), 1);
}

// Full start/stop cycle into an in-memory string sink.
IN_PROC_BROWSER_TEST_F(TracingControllerTest, EnableAndStopTracing) {
  TestStartAndStopTracingString();
}

IN_PROC_BROWSER_TEST_F(TracingControllerTest, DisableRecordingStoresMetadata) {
  TestStartAndStopTracingString();
  // Check that a number of important keys exist in the metadata dictionary. The
  // values are not checked to ensure the test is robust.
  // ASSERT (not EXPECT) so a missing dictionary ends the test instead of
  // dereferencing a null pointer below.
  ASSERT_TRUE(last_metadata() != nullptr);
  const char* const kRequiredKeys[] = {"network-type", "user-agent", "os-name",
                                       "cpu-brand"};
  for (const char* key : kRequiredKeys) {
    std::string value;
    EXPECT_TRUE(last_metadata()->GetString(key, &value)) << key;
    EXPECT_FALSE(value.empty()) << key;
  }
}

// TODO(crbug.com/642991) Disabled for flakiness.
IN_PROC_BROWSER_TEST_F(TracingControllerTest,
                       DISABLED_NotWhitelistedMetadataStripped) {
  TestStartAndStopTracingStringWithFilter();
  // Check that a number of important keys exist in the metadata dictionary.
  // ASSERT so a missing dictionary stops the test before any dereference.
  ASSERT_TRUE(last_metadata() != nullptr);
  const char* const kWhitelistedKeys[] = {"cpu-brand", "network-type",
                                          "os-name", "user-agent"};
  for (const char* key : kWhitelistedKeys) {
    std::string value;
    last_metadata()->GetString(key, &value);
    EXPECT_FALSE(value.empty()) << key;
    EXPECT_NE("__stripped__", value) << key;
  }

  // Check that the not whitelisted metadata is stripped.
  std::string not_whitelisted;
  last_metadata()->GetString("not-whitelisted", &not_whitelisted);
  EXPECT_EQ("__stripped__", not_whitelisted);

  // Also check the string data: every key is present, but the value of the
  // non-whitelisted entry must not leak into the output.
  EXPECT_FALSE(last_data().empty());
  for (const char* key : kWhitelistedKeys)
    EXPECT_NE(std::string::npos, last_data().find(key)) << key;

  EXPECT_NE(std::string::npos, last_data().find("not-whitelisted"));
  EXPECT_EQ(std::string::npos, last_data().find("this_not_found"));
}

IN_PROC_BROWSER_TEST_F(TracingControllerTest,
                       EnableAndStopTracingWithFilePath) {
  base::FilePath file_path;
  {
    base::ThreadRestrictions::ScopedAllowIO allow_io_for_creating_test_file;
    // Without a temp file there is nothing meaningful to trace into.
    ASSERT_TRUE(base::CreateTemporaryFile(&file_path));
  }
  TestStartAndStopTracingFile(file_path);
  EXPECT_EQ(file_path.value(), last_actual_recording_file_path().value());
  {
    // Best-effort cleanup so repeated runs don't accumulate temp files.
    base::ThreadRestrictions::ScopedAllowIO allow_io_for_deleting_test_file;
    base::DeleteFile(file_path, false);
  }
}

// Full start/stop cycle into a compressed string sink.
IN_PROC_BROWSER_TEST_F(TracingControllerTest,
                       EnableAndStopTracingWithCompression) {
  TestStartAndStopTracingCompressed();
}

// Starting with a null done-callback and stopping with a null sink must both
// be accepted; the trailing RunUntilIdle lets any posted work drain.
IN_PROC_BROWSER_TEST_F(TracingControllerTest,
                       EnableAndStopTracingWithEmptyFileAndNullCallback) {
  Navigate(shell());

  TracingController* controller = TracingController::GetInstance();
  EXPECT_TRUE(controller->StartTracing(
      TraceConfig(),
      TracingController::StartTracingDoneCallback()));
  EXPECT_TRUE(controller->StopTracing(nullptr));
  base::RunLoop().RunUntilIdle();
}

}  // namespace content
