#include <errno.h>
#include <limits.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

#include <pthread.h>
#include <sys/time.h>

/* Wall-clock timestamps taken by main() just before and just after the copy runs. */
struct timeval starttime;
struct timeval endtime;

//int block_size;

/*
 * Per-thread work descriptor.
 *
 * main() fills one entry per worker thread; MemWork() receives a pointer to
 * array_num and uses it to find its own entry in mem_table.
 */
struct mem_chunk{
	int array_num;		/* index of this entry inside mem_table */
	int block_size;		/* number of bytes MemWork() copies from a to b */
	long *a;		/* source buffer of the copy */
	long *b;		/* destination buffer of the copy */
};

/* Table of work descriptors, allocated by main() with one entry per thread. */
struct mem_chunk *mem_table;

/*
 * Thread entry point. mem_arg points to an int holding the index of the
 * mem_table entry whose buffer a is copied into buffer b.
 */
void *MemWork(void *mem_arg);

/* Print the command-line synopsis, followed by a blank line, to stdout. */
void usage(void){
	fputs("Usage: bench [threads] [max mem MB]\n"
	      "        -h: this help menu\n"
	      "\n", stdout);
}

/* Size of the "MB" unit used on the command line, and number of timed passes. */
enum { BYTES_PER_MB = 1024 * 1024, BENCH_RUNS = 4 };

/*
 * Parse a strictly positive decimal int.
 *
 * Returns 0 and stores the value in *out on success; returns -1 when text is
 * empty, has trailing characters, is out of range for int, or is <= 0.
 */
static int parse_positive_int(const char *text, int *out)
{
	char *end = NULL;
	long value;

	errno = 0;
	value = strtol(text, &end, 10);
	if (end == text || *end != '\0' || errno == ERANGE)
		return -1;
	if (value <= 0 || value > INT_MAX)
		return -1;
	*out = (int)value;
	return 0;
}

/*
 * Memory copy benchmark.
 *
 * Usage: bench [threads] [max mem MB]
 *
 * The memory budget is split evenly between the threads, and each thread's
 * share is split again into a source and a destination buffer. Every pass
 * starts one MemWork() thread per buffer pair and waits for all of them;
 * BENCH_RUNS passes are timed as a whole.
 *
 * Returns 0 on success. Exits with status 1 on bad arguments, allocation
 * failure or thread errors, and with status 0 after printing help for -h/-H.
 */
int main(int argc, char **argv){
	int thread_count;
	int memtotal;
	int memportion;
	int c, num1, runs, rc;
	size_t long_size = sizeof(long);
	size_t block_bytes, asize;
	double te0;
	pthread_t *threads;

	/* Honour -h/-H no matter how many arguments were supplied. */
	for(c=1; c<argc; c++){
		if( !strcmp(argv[c], "-h") || !strcmp(argv[c], "-H") ){
			usage();
			exit(0);
		}
	}

	if(argc != 3){
		usage();
		exit(1);
	}

	if(parse_positive_int(argv[1], &thread_count) != 0 ||
	   parse_positive_int(argv[2], &memtotal) != 0){
		usage();
		exit(1);
	}

	printf("\nRunning %d threads\n", thread_count);
	printf("Using %dMB of mem\n", memtotal);

	/* Each thread gets two buffers, hence the extra division by 2. */
	memportion = memtotal / thread_count / 2;
	if(memportion == 0){
		fprintf(stderr, "bench: %dMB is too little for %d threads (need at least 2MB per thread)\n",
			memtotal, thread_count);
		exit(1);
	}

	/* block_size is stored as an int, so the byte count must fit in one. */
	if(memportion > INT_MAX / BYTES_PER_MB){
		fprintf(stderr, "bench: %dMB per buffer is too large (maximum is %dMB)\n",
			memportion, INT_MAX / BYTES_PER_MB);
		exit(1);
	}
	block_bytes = (size_t)memportion * BYTES_PER_MB;
	asize = block_bytes / long_size;

	/* calloc checks the count * size multiplication for overflow. */
	threads = calloc((size_t)thread_count, sizeof *threads);
	mem_table = calloc((size_t)thread_count, sizeof *mem_table);
	if(threads == NULL || mem_table == NULL){
		perror("calloc");
		exit(1);
	}

	for(num1=0; num1<thread_count; num1++){
		mem_table[num1].a = calloc(asize, long_size);
		mem_table[num1].b = calloc(asize, long_size);
		if(mem_table[num1].a == NULL || mem_table[num1].b == NULL){
			perror("calloc");
			exit(1);
		}
		mem_table[num1].array_num = num1;
		mem_table[num1].block_size = (int)block_bytes;
	}

	printf("SIZE: %zu\n", sizeof(mem_table[0]) * (size_t)thread_count);

	gettimeofday(&starttime, NULL);
	for(runs=0; runs<BENCH_RUNS; runs++){
		for(num1=0; num1<thread_count; num1++){
			/* pthread functions return the error code; they do not set errno. */
			rc = pthread_create(&threads[num1], NULL, MemWork, &mem_table[num1].array_num);
			if(rc != 0){
				fprintf(stderr, "pthread_create: %s\n", strerror(rc));
				exit(1);
			}
		}

		for(num1=0; num1<thread_count; num1++){
			rc = pthread_join(threads[num1], NULL);
			if(rc != 0){
				fprintf(stderr, "pthread_join: %s\n", strerror(rc));
				exit(1);
			}
		}
	}
	gettimeofday(&endtime, NULL);

	/* Subtract before scaling so the arithmetic cannot overflow a 32-bit long. */
	te0 = (double)(endtime.tv_sec - starttime.tv_sec)
	    + (double)(endtime.tv_usec - starttime.tv_usec) / 1000000.0;

	printf("\tMemory bandwidth\n");
	/*
	 * Each pass copies memportion MB per thread, so memtotal approximates
	 * the MB read plus the MB written in one pass.
	 */
	if(te0 > 0.0)
		printf("\t\tMB/s: %lf\n\n", (double)memtotal * BENCH_RUNS / te0);
	else
		printf("\t\tMB/s: n/a (elapsed time too short to measure)\n\n");

	for(num1=0; num1<thread_count; num1++){
		free(mem_table[num1].a);
		free(mem_table[num1].b);
	}
	free(mem_table);
	mem_table = NULL;
	free(threads);

	return 0;
}

/*
 * Worker thread body.
 *
 * mem_arg points to an int that selects an entry of mem_table. The worker
 * copies block_size bytes from that entry's buffer a into its buffer b and
 * then terminates the calling thread with a null exit value.
 */
void *MemWork(void *mem_arg)
{
	/* Read the slot index once instead of dereferencing it per access. */
	const int slot = *(const int *)mem_arg;

	memcpy(mem_table[slot].b, mem_table[slot].a, (size_t)mem_table[slot].block_size);

	pthread_exit(NULL);
}
