/* HIT_START
 * BUILD: %t %s ../../test_common.cpp NVCC_OPTIONS -std=c++11
 * TEST_NAMED: %t hipMultiThreadDevice-serial --tests 0x1
 * TEST_NAMED: %t hipMultiThreadDevice-pyramid --tests 0x4
 * TEST_NAMED: %t hipMultiThreadDevice-nearzero --tests 0x10
 * HIT_END
 */

#include "hip/hip_runtime_api.h"
#include "test_common.h"

// Burst size (streams created before any are destroyed) used by the first, largest
// worker thread in multiThread_pyramid; also used there to compute the printed stream count.
// NOTE(review): the Windows value is smaller -- the reason is not recorded in this file
// (presumably a platform resource or run-time limit); confirm before changing.
#ifdef _WIN32
#define MAX_BURST_SIZE   40
#else
#define MAX_BURST_SIZE   100
#endif

// Stream churn worker: for each of `iterations` rounds, create `burstSize` streams
// back-to-back and then destroy every one of them before the next round begins.
//
//   iterations - number of create/destroy rounds to run.
//   burstSize  - number of streams alive at the peak of each round.
//
// Logging is controlled by p_verbose: bit 0x1 prints one line per round, bit 0x2 prints
// one line per individual create/destroy call.
void createThenDestroyStreams(int iterations, int burstSize) {
    const bool logRounds = (p_verbose & 0x1) != 0;
    const bool logCalls = (p_verbose & 0x2) != 0;

    // Scratch array holding the handles of the current round; reused across rounds.
    hipStream_t* streams = new hipStream_t[burstSize];

    for (int round = 0; round < iterations; ++round) {
        if (logRounds) {
            printf("%s iter=%d, create %d then destroy %d\n", __func__, round, burstSize, burstSize);
        }

        // Phase 1: bring the whole burst to life.
        for (int j = 0; j < burstSize; ++j) {
            if (logCalls) {
                printf("  %d.%d streamCreate\n", round, j);
            }
            HIPCHECK(hipStreamCreate(&streams[j]));
        }

        // Phase 2: tear the burst down again, in creation order.
        for (int j = 0; j < burstSize; ++j) {
            if (logCalls) {
                printf("  %d.%d streamDestroy\n", round, j);
            }
            HIPCHECK(hipStreamDestroy(streams[j]));
        }
    }

    delete[] streams;
}


// Synchronization worker: calls hipDeviceSynchronize() `iterations` times in a row.
// On its own this is uninteresting; the test runs it while other threads are
// concurrently creating and destroying streams on the same device.
void waitStreams(int iterations) {
    for (int remaining = iterations; remaining > 0; --remaining) {
        HIPCHECK(hipDeviceSynchronize());
    }
}


// Launch three worker threads that all create and destroy streams on the same device,
// with different shapes of work: one runs few rounds of large bursts, one runs more rounds
// of medium bursts, and one runs many rounds of single-stream bursts.
//
//   serialize - when true, each worker is joined (and "tN done" is printed) before the
//               next worker is started; when false, all three run concurrently and are
//               joined together at the end.
//   iters     - scale factor applied to every worker's round count.
void multiThread_pyramid(bool serialize, int iters) {
    printf("%s creating %d streams\n", __func__, iters * MAX_BURST_SIZE);

    // In serialized mode, wait for the worker right away and report it; otherwise do nothing.
    auto joinIfSerialized = [serialize](std::thread& worker, const char* doneMsg) {
        if (!serialize) {
            return;
        }
        worker.join();
        printf("%s", doneMsg);
    };

    std::thread t1(createThenDestroyStreams, iters * 1, MAX_BURST_SIZE);
    joinIfSerialized(t1, "t1 done\n");

    std::thread t2(createThenDestroyStreams, iters * 10, 10);
    joinIfSerialized(t2, "t2 done\n");

    std::thread t3(createThenDestroyStreams, iters * 100, 1);
    joinIfSerialized(t3, "t3 done\n");

    if (serialize) {
        // Every worker has already been joined above.
        return;
    }

    // Parallel mode: all three workers were left running; collect them now.
    t1.join();
    t2.join();
    t3.join();
}


// Launch three worker threads on the same device: two repeatedly create and immediately
// destroy a single stream, and the third repeatedly synchronizes the device. The idea is
// to keep the number of live streams hovering near zero, to cause problems.
//
//   serialize - when true, each worker is joined before the next one is started;
//               when false, all three run concurrently and are joined at the end.
//               Either way "tN done" is printed after each join.
//   iters     - rounds for each create/destroy worker; the sync worker runs iters * 50.
void multiThread_nearzero(bool serialize, int iters) {
    printf("%s creating %d streams x 3 threads\n", __func__, iters);

    // Wait for one worker and announce that it finished.
    auto joinAndReport = [](std::thread& worker, const char* doneMsg) {
        worker.join();
        printf("%s", doneMsg);
    };

    std::thread t1(createThenDestroyStreams, iters, 1);
    if (serialize) {
        joinAndReport(t1, "t1 done\n");
    }

    std::thread t2(createThenDestroyStreams, iters, 1);
    if (serialize) {
        joinAndReport(t2, "t2 done\n");
    }

    std::thread t3(waitStreams, iters * 50);
    if (serialize) {
        joinAndReport(t3, "t3 done\n");
    } else {
        // Parallel mode: nothing has been joined yet, so collect the workers in launch order.
        joinAndReport(t1, "t1 done\n");
        joinAndReport(t2, "t2 done\n");
        joinAndReport(t3, "t3 done\n");
    }
}

// Test driver. After standard argument parsing, p_tests is used as a bitmask that selects
// which scenarios run:
//   0x1  - single-threaded createThenDestroyStreams
//   0x2  - multiThread_pyramid with workers run one after another
//   0x4  - multiThread_pyramid with workers run concurrently
//   0x10 - multiThread_nearzero with workers run concurrently
// Bit 0x8 belongs to a larger pyramid run that is currently commented out below.
// Finishes by calling passed().
int main(int argc, char* argv[]) {
    HipTest::parseStandardArguments(argc, argv, true);

    // Serial version, just call once:
    if (p_tests & 0x1) {
        printf("\ntest 0x1 : serial createThenDestroyStreams(10) \n");
        createThenDestroyStreams(10, 10);
    }

    // Serialized pyramid. It is not in the HIT test list at the top of this file: it takes
    // a while, and if the parallel variant below works there is no need to run it. It still
    // runs when bit 0x2 is requested explicitly.
    if (p_tests & 0x2) {
        printf("\ntest 0x2 : serialized multiThread_pyramid(3) \n");
        multiThread_pyramid(true, 3);
    }

    if (p_tests & 0x4) {
        printf("\ntest 0x4 : parallel multiThread_pyramid(3) \n");
        multiThread_pyramid(false, 3);
    }

    // if (p_tests & 0x8) {
    //    printf ("test 0x8 : multiThread_pyramid(100) \n");
    //    multiThread_pyramid(false, 100);
    // }

    if (p_tests & 0x10) {
        printf("\ntest 0x10 : parallel multiThread_nearzero(1000) \n");
        multiThread_nearzero(false, 1000);
    }

    passed();
}
