// Copyright 2016 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "base/win/wait_chain.h"

#include <memory>
#include <string>

#include "base/bind.h"
#include "base/command_line.h"
#include "base/strings/string_number_conversions.h"
#include "base/strings/string_piece.h"
#include "base/test/multiprocess_test.h"
#include "base/threading/simple_thread.h"
#include "base/win/win_util.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "testing/multiprocess_func_list.h"

namespace base {
namespace win {

namespace {

// Converts |handle| to its unsigned integer string form and attaches it to
// |command_line| as the value of the switch named |switch_name|.
void AppendSwitchHandle(CommandLine* command_line,
                        StringPiece switch_name,
                        HANDLE handle) {
  const std::string handle_value = UintToString(HandleToUint32(handle));
  command_line->AppendSwitchASCII(switch_name.as_string(), handle_value);
}

// Retrieves the handle stored under |switch_name| on |command_line|.
//
// The switch value is expected to be the unsigned integer string produced by
// AppendSwitchHandle(). Returns an invalid ScopedHandle (after logging an error
// through DLOG) when the switch is missing, empty, or not a valid unsigned
// integer. The returned ScopedHandle takes ownership of the handle value.
ScopedHandle GetSwitchValueHandle(CommandLine* command_line,
                                  StringPiece switch_name) {
  std::string switch_string =
      command_line->GetSwitchValueASCII(switch_name.as_string());
  unsigned int switch_uint = 0;
  if (switch_string.empty() || !StringToUint(switch_string, &switch_uint)) {
    DLOG(ERROR) << "Missing or invalid " << switch_name << " argument.";
    return ScopedHandle();
  }
  // Use the inverse of HandleToUint32() rather than reinterpret_cast-ing a
  // 32-bit integer directly to a pointer-sized HANDLE: the helper widens the
  // value explicitly, so the round trip is well defined on 64-bit builds.
  return ScopedHandle(Uint32ToHandle(switch_uint));
}

// Creates an unnamed mutex that is not initially owned by the caller.
// |inheritable| selects whether the resulting handle is marked as inheritable
// in its security attributes.
ScopedHandle CreateMutex(bool inheritable) {
  SECURITY_ATTRIBUTES attributes = {};
  attributes.nLength = sizeof(attributes);
  attributes.lpSecurityDescriptor = nullptr;
  attributes.bInheritHandle = inheritable ? TRUE : FALSE;

  ScopedHandle mutex(::CreateMutex(&attributes, FALSE, nullptr));
  return mutex;
}

// Creates an unnamed, auto-reset event that starts out non-signaled.
// |inheritable| selects whether the resulting handle is marked as inheritable
// in its security attributes.
ScopedHandle CreateEvent(bool inheritable) {
  SECURITY_ATTRIBUTES attributes = {};
  attributes.nLength = sizeof(attributes);
  attributes.lpSecurityDescriptor = nullptr;
  attributes.bInheritHandle = inheritable ? TRUE : FALSE;

  // Second argument FALSE: auto-reset. Third argument FALSE: not signaled.
  ScopedHandle event(::CreateEvent(&attributes, FALSE, FALSE, nullptr));
  return event;
}

// A SimpleThread whose whole body consists of running a single Closure; the
// thread finishes as soon as that closure returns.
class SingleTaskThread : public SimpleThread {
 public:
  // |task| is copied and executed on the new thread once it is started.
  explicit SingleTaskThread(const Closure& task)
      : SimpleThread("WaitChainTest SingleTaskThread"), task_(task) {}

  // SimpleThread:
  void Run() override {
    task_.Run();
  }

 private:
  // The work performed by Run().
  Closure task_;

  DISALLOW_COPY_AND_ASSIGN(SingleTaskThread);
};

// Helper thread to cause a deadlock by acquiring 2 mutexes in a given order.
class DeadlockThread : public SimpleThread {
 public:
  DeadlockThread(HANDLE mutex_1, HANDLE mutex_2)
      : SimpleThread("WaitChainTest DeadlockThread"),
        wait_event_(CreateEvent(false)),
        mutex_acquired_event_(CreateEvent(false)),
        mutex_1_(mutex_1),
        mutex_2_(mutex_2) {}

  void Run() override {
    // Acquire the mutex then signal the main thread.
    EXPECT_EQ(WAIT_OBJECT_0, ::WaitForSingleObject(mutex_1_, INFINITE));
    EXPECT_TRUE(::SetEvent(mutex_acquired_event_.Get()));

    // Wait until both threads are holding their mutex before trying to acquire
    // the other one.
    EXPECT_EQ(WAIT_OBJECT_0,
              ::WaitForSingleObject(wait_event_.Get(), INFINITE));

    // To unblock the deadlock, one of the threads will get terminated (via
    // TerminateThread()) without releasing the mutex. This causes the other
    // thread to wake up with WAIT_ABANDONED.
    EXPECT_EQ(WAIT_ABANDONED, ::WaitForSingleObject(mutex_2_, INFINITE));
  }

  // Blocks until a mutex is acquired.
  void WaitForMutexAcquired() {
    EXPECT_EQ(WAIT_OBJECT_0,
              ::WaitForSingleObject(mutex_acquired_event_.Get(), INFINITE));
  }

  // Signal the thread to acquire the second mutex.
  void SignalToAcquireMutex() { EXPECT_TRUE(::SetEvent(wait_event_.Get())); }

  // Terminates the thread.
  bool Terminate() {
    ScopedHandle thread_handle(::OpenThread(THREAD_TERMINATE, FALSE, tid()));
    return ::TerminateThread(thread_handle.Get(), 0);
  }

 private:
  ScopedHandle wait_event_;
  ScopedHandle mutex_acquired_event_;

  // The 2 mutex to acquire.
  HANDLE mutex_1_;
  HANDLE mutex_2_;

  DISALLOW_COPY_AND_ASSIGN(DeadlockThread);
};

// Starts and returns a thread whose only task is to call Join() on
// |thread_to_join|; the new thread therefore stays blocked until
// |thread_to_join| finishes. |thread_to_join| is passed Unretained, so it must
// outlive the returned thread's task.
std::unique_ptr<SingleTaskThread> CreateJoiningThread(
    SimpleThread* thread_to_join) {
  const Closure join_task =
      Bind(&SimpleThread::Join, Unretained(thread_to_join));

  std::unique_ptr<SingleTaskThread> joiner(new SingleTaskThread(join_task));
  joiner->Start();
  return joiner;
}

// Starts and returns a thread whose only task is an infinite
// WaitForSingleObject() on |handle|. The wait result is ignored, so the thread
// simply ends once the wait returns.
std::unique_ptr<SingleTaskThread> CreateWaitingThread(HANDLE handle) {
  const Closure wait_task =
      Bind(IgnoreResult(&::WaitForSingleObject), handle, INFINITE);

  std::unique_ptr<SingleTaskThread> waiter(new SingleTaskThread(wait_task));
  waiter->Start();
  return waiter;
}

// Starts a DeadlockThread that takes |mutex_1| first and will later block on
// |mutex_2|. Does not return until the new thread reports that it is holding
// |mutex_1|.
std::unique_ptr<DeadlockThread> CreateDeadlockThread(HANDLE mutex_1,
                                                     HANDLE mutex_2) {
  std::unique_ptr<DeadlockThread> deadlock_thread(
      new DeadlockThread(mutex_1, mutex_2));
  deadlock_thread->Start();

  // Returning before the first mutex is held would let the caller race with
  // the acquisition, so block here until it has happened.
  deadlock_thread->WaitForMutexAcquired();
  return deadlock_thread;
}

// Child process used to test the cross-process capability of the WCT API.
// This process simulates a hang while holding a mutex that the parent process
// is waiting on.
//
// Both handles are read from the "mutex" and "sync_event" command line
// switches; the process CHECK-fails if either is missing or invalid, or if any
// of the waits/signals below fails. Returns 0 on normal completion.
MULTIPROCESS_TEST_MAIN(WaitChainTestProc) {
  CommandLine* command_line = CommandLine::ForCurrentProcess();

  // Handle values passed by the parent on the command line.
  ScopedHandle mutex = GetSwitchValueHandle(command_line, "mutex");
  CHECK(mutex.IsValid());

  ScopedHandle sync_event(GetSwitchValueHandle(command_line, "sync_event"));
  CHECK(sync_event.IsValid());

  // Acquire mutex.
  CHECK(::WaitForSingleObject(mutex.Get(), INFINITE) == WAIT_OBJECT_0);

  // Signal back to the parent process that the mutex is held.
  CHECK(::SetEvent(sync_event.Get()));

  // Wait on a signal from the parent process before terminating.
  // NOTE(review): this waits on the same event that was just signaled above.
  // If |sync_event| is auto-reset and the parent is not already blocked on it
  // when SetEvent() runs, this wait could presumably consume the process's own
  // signal -- confirm, and consider using a separate event per direction.
  CHECK(::WaitForSingleObject(sync_event.Get(), INFINITE) == WAIT_OBJECT_0);

  // The mutex is never released explicitly; the process just exits.
  return 0;
}

// Spawns the "WaitChainTestProc" child process. The values of |mutex| and
// |sync_event| are passed to it through the "mutex" and "sync_event" command
// line switches, and both handles are listed for inheritance by the child.
Process StartChildProcess(HANDLE mutex, HANDLE sync_event) {
  // Handles the child is allowed to inherit.
  HandlesToInheritVector inherited_handles;
  inherited_handles.push_back(mutex);
  inherited_handles.push_back(sync_event);

  LaunchOptions launch_options;
  launch_options.handles_to_inherit = &inherited_handles;

  // Tell the child which handle values to use.
  CommandLine child_command_line = GetMultiProcessTestChildBaseCommandLine();
  AppendSwitchHandle(&child_command_line, "mutex", mutex);
  AppendSwitchHandle(&child_command_line, "sync_event", sync_event);

  return SpawnMultiProcessTestChild("WaitChainTestProc", child_command_line,
                                    launch_options);
}

// Returns true if |wait_chain| alternates between thread nodes and non-thread
// (synchronization object) nodes, starting with a thread node at index 0. An
// empty chain is considered correct.
bool WaitChainStructureIsCorrect(const WaitChainNodeVector& wait_chain) {
  for (size_t index = 0; index < wait_chain.size(); ++index) {
    // Even positions must hold threads; odd positions must hold anything else.
    const bool expect_thread = (index % 2) == 0;
    const bool is_thread = wait_chain[index].ObjectType == WctThreadType;
    if (is_thread != expect_thread)
      return false;
  }
  return true;
}

// Returns true if the |wait_chain| goes through more than 1 process.
bool WaitChainIsCrossProcess(const WaitChainNodeVector& wait_chain) {
  if (wait_chain.size() == 0)
    return false;

  // Just check that the process id changes somewhere in the chain.
  // Note: ThreadObjects are every 2 nodes.
  DWORD first_process = wait_chain[0].ThreadObject.ProcessId;
  for (size_t i = 2; i < wait_chain.size(); i += 2) {
    if (first_process != wait_chain[i].ThreadObject.ProcessId)
      return true;
  }
  return false;
}

}  // namespace

// Creates 2 threads that acquire their designated mutex and then try to
// acquire each other's mutex to cause a deadlock. Two more threads are chained
// on top of the deadlock and the wait chain of the last one is inspected.
TEST(WaitChainTest, Deadlock) {
  // 2 mutexes are needed to get a deadlock.
  ScopedHandle mutex_1 = CreateMutex(false);
  ASSERT_TRUE(mutex_1.IsValid());
  ScopedHandle mutex_2 = CreateMutex(false);
  ASSERT_TRUE(mutex_2.IsValid());

  // Each thread holds its first mutex by the time CreateDeadlockThread()
  // returns.
  std::unique_ptr<DeadlockThread> deadlock_thread_1 =
      CreateDeadlockThread(mutex_1.Get(), mutex_2.Get());
  std::unique_ptr<DeadlockThread> deadlock_thread_2 =
      CreateDeadlockThread(mutex_2.Get(), mutex_1.Get());

  // Signal the threads to try to acquire the other mutex.
  deadlock_thread_1->SignalToAcquireMutex();
  deadlock_thread_2->SignalToAcquireMutex();
  // Sleep to give the 2 threads a chance to execute. This is best-effort only:
  // nothing here guarantees that both threads are blocked afterwards.
  Sleep(10);

  // Create a few waiting threads to get a longer wait chain.
  std::unique_ptr<SingleTaskThread> waiting_thread_1 =
      CreateJoiningThread(deadlock_thread_1.get());
  std::unique_ptr<SingleTaskThread> waiting_thread_2 =
      CreateJoiningThread(waiting_thread_1.get());

  WaitChainNodeVector wait_chain;
  bool is_deadlock = false;
  // Use a non-fatal check here: a fatal ASSERT would return immediately and
  // skip the cleanup below, leaving the threads deadlocked and un-joined.
  const bool got_wait_chain = GetThreadWaitChain(
      waiting_thread_2->tid(), &wait_chain, &is_deadlock, nullptr, nullptr);
  EXPECT_TRUE(got_wait_chain);

  if (got_wait_chain) {
    // Expected chain, alternating thread and wait-object nodes:
    // waiting_thread_2 -> waiting_thread_1 -> deadlock_thread_1 ->
    // deadlock_thread_2 -> deadlock_thread_1, i.e. 5 threads + 4 waits.
    EXPECT_EQ(9U, wait_chain.size());
    EXPECT_TRUE(is_deadlock);
    EXPECT_TRUE(WaitChainStructureIsCorrect(wait_chain));
    EXPECT_FALSE(WaitChainIsCrossProcess(wait_chain));
  }

  // Break the deadlock. This stays fatal on purpose: if the thread cannot be
  // terminated, the Join() calls below would block forever.
  ASSERT_TRUE(deadlock_thread_1->Terminate());

  // The SimpleThread API expects Join() to be called before destruction.
  // |deadlock_thread_1| and |waiting_thread_1| are joined by the tasks of
  // |waiting_thread_1| and |waiting_thread_2| respectively.
  deadlock_thread_2->Join();
  waiting_thread_2->Join();
}

// Creates a child process that acquires a mutex and then blocks. A chain of
// threads in this process then blocks on that mutex, and the wait chain of the
// last thread is expected to reach into the child process.
TEST(WaitChainTest, CrossProcess) {
  // Both handles are inheritable so that the child process can use them.
  ScopedHandle mutex = CreateMutex(true);
  ASSERT_TRUE(mutex.IsValid());
  ScopedHandle sync_event = CreateEvent(true);
  ASSERT_TRUE(sync_event.IsValid());

  Process child_process = StartChildProcess(mutex.Get(), sync_event.Get());
  ASSERT_TRUE(child_process.IsValid());

  // Wait for the child process to signal when it's holding the mutex.
  EXPECT_EQ(WAIT_OBJECT_0, ::WaitForSingleObject(sync_event.Get(), INFINITE));

  // Create a few waiting threads to get a longer wait chain.
  std::unique_ptr<SingleTaskThread> waiting_thread_1 =
      CreateWaitingThread(mutex.Get());
  std::unique_ptr<SingleTaskThread> waiting_thread_2 =
      CreateJoiningThread(waiting_thread_1.get());
  std::unique_ptr<SingleTaskThread> waiting_thread_3 =
      CreateJoiningThread(waiting_thread_2.get());

  WaitChainNodeVector wait_chain;
  bool is_deadlock = false;
  // Use a non-fatal check here: a fatal ASSERT would return immediately and
  // skip the cleanup below, leaving the child blocked and the threads
  // un-joined.
  const bool got_wait_chain = GetThreadWaitChain(
      waiting_thread_3->tid(), &wait_chain, &is_deadlock, nullptr, nullptr);
  EXPECT_TRUE(got_wait_chain);

  if (got_wait_chain) {
    // Expected chain, alternating thread and wait-object nodes:
    // waiting_thread_3 -> waiting_thread_2 -> waiting_thread_1 -> mutex ->
    // child process thread, i.e. 4 threads + 3 waits.
    EXPECT_EQ(7U, wait_chain.size());
    EXPECT_FALSE(is_deadlock);
    EXPECT_TRUE(WaitChainStructureIsCorrect(wait_chain));
    EXPECT_TRUE(WaitChainIsCrossProcess(wait_chain));
  }

  // Unblock child process and wait for it to terminate. These stay fatal on
  // purpose: if either step fails, the Join() below could block forever.
  ASSERT_TRUE(::SetEvent(sync_event.Get()));
  ASSERT_TRUE(child_process.WaitForExit(nullptr));

  // The SimpleThread API expects Join() to be called before destruction.
  // |waiting_thread_1| and |waiting_thread_2| are joined by the tasks of
  // |waiting_thread_2| and |waiting_thread_3| respectively.
  waiting_thread_3->Join();
}

}  // namespace win
}  // namespace base
