#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <pthread.h>
#include <time.h>
#include <sys/time.h>
#include <string.h>
#include <xmmintrin.h>

/* Wall-clock samples taken with gettimeofday() before and after each timed benchmark run. */
struct timeval starttime,endtime;

/* The two heap buffers the memory benchmarks operate on; allocated, filled with
 * rand() values and freed by main(). Worker threads reach them either through
 * mem_table slices (MemWork1-3) or directly (MemWork4). */
float *a;
float *b;
//float *c;
//float *d;
/* Loop iteration count read by MemWork4; set by main() before those threads start. */
int blk_size;

/* Four-float operand vectors: filled by main() from rand(), loaded into
 * xmm0 / xmm1 at the start of FloatWork. */
float z0[4];
float z1[4];

/*
 * Per-thread work descriptor for the memcpy/memset/memmove benchmarks.
 * mem_table is an array with one entry per worker thread, allocated and
 * filled in by main().
 */
struct mem_chunk{
	int thread_id;      /* index of this entry in mem_table; workers receive a pointer to this field */
	long block_size;    /* number of bytes passed to each memcpy/memset/memmove call */
	//long asize;
	float *a;           /* start of this thread's slice of the global buffer a */
	float *b;           /* start of this thread's slice of the global buffer b */
}*mem_table;

/*
 * Thread entry points handed to pthread_create() by main().
 * FloatWork and MemWork4 are started with a NULL argument; MemWork1-3 are
 * started with a pointer to the thread_id field of their mem_table entry.
 */
void *FloatWork(void *null);
//void AsmWrite(float *a, float *b, int *b_size);
void *MemWork1(void *mem_arg);
void *MemWork2(void *mem_arg);
void *MemWork3(void *mem_arg);
void *MemWork4(void *null);
//void *MemWork5(void *mem_arg);


/* Print the command-line synopsis to stdout. */
void usage(void){
	fputs("Usage: lsbench [threads] [max mem MB]\n"
	      "        -h: this help menu\n"
	      "\n",
	      stdout);
}

int main(int argc, char **argv){
	int thread_count;
	int memtotal;
	int memportion;
	int c, num1;
	int float_size;
	long asize;
	long block_size;
	double te0;

        if(argc != 3){
                usage();
                exit(1);
        }else{
                for(c=0; c<argc; c++){
                        if( !strcmp(argv[c], "-h") || !strcmp(argv[c], "-H") ){
                                usage();
                                exit(0);
                        }
                }
        
                thread_count = atoi(argv[1]);
                memtotal = atoi(argv[2]);
                if(memtotal == 0 || thread_count == 0){
                        usage();
                        exit(1);
                }
        
                printf("\nRunning %d threads\n", thread_count);
                printf("Using %dMB of mem\n", memtotal);
	}

	pthread_t *threads = (pthread_t *)malloc(sizeof(pthread_t) * thread_count);
	memportion = (int)(memtotal / 2 / thread_count );

	mem_table = (struct mem_chunk *)malloc( sizeof(struct mem_chunk) * thread_count);
        float_size=sizeof(float);
        asize = (long)(memtotal / 2 * 1024 * 1024 / float_size);
	block_size = (long)(memportion * 1024 * 1024);

	a = malloc(sizeof(float)*asize);
	b = malloc(sizeof(float)*asize);
	//put some random stuff in them
	srand ( time(NULL) );
	for(num1=0; num1<asize; num1++){
		a[num1] = (float)rand();
		b[num1] = (float)rand();
	}
	asize = (long)(memportion * 1024 * 1024 / float_size);
	for(num1=0; num1<thread_count; num1++){	
			//mem_table[num1].asize = asize;
			mem_table[num1].thread_id = num1;
			mem_table[num1].block_size = block_size;
			mem_table[num1].a = &a[num1*asize];
			mem_table[num1].b = &b[num1*asize];
	}


	gettimeofday(&starttime, NULL);
		for(num1=0; num1<thread_count; num1++){
			//mem_table.thread_id[num1]=num1;
			if (pthread_create(&threads[num1], NULL, MemWork1, &mem_table[num1].thread_id) != 0)
				perror("pthread_create"), exit(1);
		}
	

		for(num1=0; num1<thread_count; num1++){
			if (pthread_join(threads[num1], NULL) != 0)
				perror("pthread_join"),exit(1);
		}
	gettimeofday(&endtime, NULL);
	te0=((double)(endtime.tv_sec*1000000-starttime.tv_sec*1000000+endtime.tv_usec-starttime.tv_usec))/1000000;
	printf("\tMemory bandwidth\n");
	// 6 times read+write
	printf("\t\tmemcpy (read+write):\n");
	printf("\t\tMB/s: %lf\n", (double)memtotal*6/te0);


	gettimeofday(&starttime, NULL);
	for(num1=0; num1<thread_count; num1++){
		if (pthread_create(&threads[num1], NULL, MemWork2, &mem_table[num1].thread_id) != 0)
			perror("pthread_create"), exit(1);
	}
	for(num1=0; num1<thread_count; num1++){
		if (pthread_join(threads[num1], NULL) != 0)
			perror("pthread_join"),exit(1);
	}
	gettimeofday(&endtime, NULL);
	te0=((double)(endtime.tv_sec*1000000-starttime.tv_sec*1000000+endtime.tv_usec-starttime.tv_usec))/1000000;
	printf("\t\tmemset (write):\n");	
	printf("\t\tMB/s: %lf\n", (double)memtotal*6/te0);


	gettimeofday(&starttime, NULL);
	for(num1=0; num1<thread_count; num1++){
		if (pthread_create(&threads[num1], NULL, MemWork3, &mem_table[num1].thread_id) != 0)
			perror("pthread_create"), exit(1);
	}
	for(num1=0; num1<thread_count; num1++){
		if (pthread_join(threads[num1], NULL) != 0)
			perror("pthread_join"),exit(1);
	}
	gettimeofday(&endtime, NULL);
	te0=((double)(endtime.tv_sec*1000000-starttime.tv_sec*1000000+endtime.tv_usec-starttime.tv_usec))/1000000;
	printf("\t\tmemmove (read+write):\n");
	printf("\t\tMB/s: %lf\n", (double)memtotal*6/te0);

	float_size = 32;
	asize = (long)(memtotal / 2 * 1024 * 1024 / float_size);
	blk_size = asize;
	//c=&
	//d=&
	gettimeofday(&starttime, NULL);
	for(num1=0; num1<2; num1++){
		if (pthread_create(&threads[num1], NULL, MemWork4, NULL) != 0)
			perror("pthread_create"), exit(1);
	}
	for(num1=0; num1<2; num1++){
		if (pthread_join(threads[num1], NULL) != 0)
			perror("pthread_join"),exit(1);
	}
	gettimeofday(&endtime, NULL);
	te0=((double)(endtime.tv_sec*1000000-starttime.tv_sec*1000000+endtime.tv_usec-starttime.tv_usec))/1000000;
	printf("\t\tloop assignment (read 2 threads):\n");
	printf("\t\tMB/s: %lf\n", (double)memtotal*6/te0);

/*
	gettimeofday(&starttime, NULL);
	for(num1=0; num1<thread_count; num1++){
		if (pthread_create(&threads[num1], NULL, MemWork4, &mem_table[num1].thread_id) != 0)
			perror("pthread_create"), exit(1);
	}
	for(num1=0; num1<thread_count; num1++){
		if (pthread_join(threads[num1], NULL) != 0)
			perror("pthread_join"),exit(1);
	}
	gettimeofday(&endtime, NULL);
	te0=((double)(endtime.tv_sec*1000000-starttime.tv_sec*1000000+endtime.tv_usec-starttime.tv_usec))/1000000;
	printf("\t\tloop assignment (write):\n");
	printf("\t\tMB/s: %lf\n\n", (double)memtotal*6/te0);
*/
	free(a);
	free(b);

	printf("Running Floating Point Benchmarks\n\n");
                
	printf("\tSingle Threaded Performance:\n");


        for(num1=0; num1<4; num1++){
                srand (time(NULL));
                z0[num1]=rand()+1;
                z1[num1]=rand()-1;
                //z1[num1]=rand()+1;
                //z0[num1]=rand()-1;
        }


	gettimeofday(&starttime, NULL);

	if (pthread_create(&threads[0], NULL, FloatWork, NULL) != 0)
		perror("pthread_create"), exit(1);
	if (pthread_join(threads[0], NULL) != 0)
		perror("pthread_join"),exit(1);
	

	gettimeofday(&endtime, NULL);
	te0=((double)(endtime.tv_sec*1000000-starttime.tv_sec*1000000+endtime.tv_usec-starttime.tv_usec))/1000000;
	te0 = 64/te0;
	printf("\t\tGFLOPS(estimated): %lf\n\n", te0);

	printf("\tMulti-Threaded Performance:\n");

	/*
        for(num1=0; num1<8; num1++){
                srand ( num1 );
		for(num2=0; num2<thread_count; num2++){
                	z3[num2][num1]=rand();
                	while(z3[num2][num1] == 0){
                        	z3[num2][num1]=rand();
                	}
		}
        }
	*/

        for(num1=0; num1<4; num1++){
                srand (time(NULL));
                z0[num1]=rand()+1;
                z1[num1]=rand()-1;
                //z1[num1]=rand()+1;
                //z0[num1]=rand()-1;
        }


	gettimeofday(&starttime, NULL);
        for(num1=0; num1<thread_count; num1++){
		//ft[num1].thread_id=num1;
                if (pthread_create(&threads[num1], NULL, FloatWork, NULL) != 0)
                        perror("pthread_create"), exit(1);
        }
        for(num1=0; num1<thread_count; num1++){
                if (pthread_join(threads[num1], NULL) != 0)
                        perror("pthread_join"),exit(1);
        }
        gettimeofday(&endtime, NULL);
        te0=((double)(endtime.tv_sec*1000000-starttime.tv_sec*1000000+endtime.tv_usec-starttime.tv_usec))/1000000;
	te0 = (float)64*thread_count/te0;
	printf("\t\tGFLOPS(estimated): %lf\n\n", te0);



return 0;
}

/*
 * Thread entry point for the floating point benchmark.
 *
 * The argument is never read. The routine loads z0 into xmm0 and z1 into
 * xmm1, copies them into xmm2..xmm7, then executes a loop of two addps and
 * two mulps instructions with ecx initialised to 4000000000, and finally
 * ends the thread with pthread_exit(0).
 *
 * 4 packed instructions x 4 single-precision lanes x 4,000,000,000
 * iterations = 64e9 operations, which is the "64" main() divides by when it
 * prints the GFLOPS estimate.
 *
 * NOTE(review): the non-OSX branch switches the assembler to Intel syntax
 * and never issues a matching .att_syntax directive; presumably the file is
 * meant to be built with -masm=intel -- confirm against the build flags.
 * NOTE(review): the asm statements declare no operands or clobbers and rely
 * on staying consecutive in the generated code, and LOOP1 is a plain named
 * label -- verify this holds for the compiler and optimisation level in use.
 */
void *FloatWork(void *null)
{

#ifdef OSX
        /* Block-style inline assembly variant, selected when OSX is defined. */
        asm{
                movups xmm0 z0
                movups xmm1 z1
		movups xmm2 xmm1
		movups xmm3 xmm0
		movups xmm4 xmm1
		movups xmm5 xmm0
		movups xmm6 xmm1
		movups xmm7 xmm0
                mov ecx 4000000000
                LOOP1:
                	addps xmm0 xmm4
                	mulps xmm1 xmm5
                	addps xmm2 xmm6
			mulps xmm3 xmm7
		dec ecx
		jnz LOOP1
        }


#else
		/* Same instruction sequence, written as one basic asm statement per instruction. */
		asm(".intel_syntax noprefix\n");
			/* Load the operand vectors and fan them out to xmm2..xmm7. */
			asm("movups xmm0, [z0]\n");
			asm("movups xmm1, [z1]\n");
			asm("movups xmm2, xmm1\n");
			asm("movups xmm3, xmm0\n");
			asm("movups xmm4, xmm1\n");
			asm("movups xmm5, xmm0\n");
			asm("movups xmm6, xmm1\n");
			asm("movups xmm7, xmm0\n");
			/* Loop counter; the loop below runs until it reaches zero. */
			asm("mov ecx, 4000000000\n");
			asm("LOOP1:\n");
				asm("addps xmm0, xmm4\n");
				asm("mulps xmm1, xmm5\n");
                		asm("addps xmm2, xmm6\n");
                		asm("mulps xmm3, xmm7\n");
      			asm("dec ecx\n");
			asm("jnz LOOP1\n");
#endif

	/* Ends the thread here; the function never returns normally. */
	pthread_exit(0);
}

/*
 * Thread entry point for the memcpy benchmark.
 *
 * mem_arg points at an int holding this thread's index into mem_table.
 * The thread's slice is copied a -> b and then b -> a, three times over
 * (six memcpy calls of block_size bytes each), after which the thread ends
 * via pthread_exit(0).
 */
void *MemWork1(void *mem_arg)
{
	const int slot = *(int *)mem_arg;
	float *const slice_a = mem_table[slot].a;
	float *const slice_b = mem_table[slot].b;
	const long nbytes = mem_table[slot].block_size;
	int round;

	for (round = 0; round < 3; round++) {
		(void)memcpy(slice_b, slice_a, nbytes);
		(void)memcpy(slice_a, slice_b, nbytes);
	}

	pthread_exit(0);
}

/*
 * Thread entry point for the memset benchmark.
 *
 * mem_arg points at an int holding this thread's index into mem_table.
 * Makes six passes; each pass fills the thread's b slice and then its a
 * slice (block_size bytes each) with that pass's byte values. The thread
 * ends via pthread_exit(0).
 */
void *MemWork2(void *mem_arg)
{
	/* Fill byte used for b and for a on each successive pass. */
	static const int fill_for_b[6] = { 124, 12, 24, 2, 36, 96 };
	static const int fill_for_a[6] = { 122, 18, 22, 4, 48, 87 };

	const int slot = *(int *)mem_arg;
	float *const slice_a = mem_table[slot].a;
	float *const slice_b = mem_table[slot].b;
	const long nbytes = mem_table[slot].block_size;
	int pass;

	for (pass = 0; pass < 6; pass++) {
		(void)memset(slice_b, fill_for_b[pass], nbytes);
		(void)memset(slice_a, fill_for_a[pass], nbytes);
	}

	pthread_exit(0);
}

/*
 * Thread entry point for the memmove benchmark.
 *
 * mem_arg points at an int holding this thread's index into mem_table.
 * The thread's slice is moved a -> b and then b -> a, three times over
 * (six memmove calls of block_size bytes each), after which the thread ends
 * via pthread_exit(0).
 */
void *MemWork3(void *mem_arg)
{
	const int slot = *(int *)mem_arg;
	float *const slice_a = mem_table[slot].a;
	float *const slice_b = mem_table[slot].b;
	const long nbytes = mem_table[slot].block_size;
	int round;

	for (round = 0; round < 3; round++) {
		(void)memmove(slice_b, slice_a, nbytes);
		(void)memmove(slice_a, slice_b, nbytes);
	}

	pthread_exit(0);
}





/*
 * Thread entry point for the "loop assignment" benchmark.
 *
 * The argument is unused. The work is driven by the file-level state set up
 * by main(): the buffers a and b, and blk_size, the number of loop
 * iterations. Progress markers are printed before and after the loop, and
 * the thread ends via pthread_exit(0).
 *
 * Without OSX defined, each iteration loads eight consecutive floats
 * (32 bytes) from a and from b with aligned SSE loads and adds them pairwise.
 * _mm_load_ps requires 16-byte aligned addresses; this assumes the buffers
 * returned by malloc() are at least 16-byte aligned -- TODO confirm for the
 * target platform.
 */
void *MemWork4(void *null)
{
	(void)null;

printf("here 1: %d\n", blk_size);
	//64 bytes of register love
#ifdef OSX
	/* NOTE(review): this branch stores xmm0..xmm3 to [edi] / [esi] without
	 * ever advancing either pointer, and uses 32-bit address registers --
	 * verify that this is the intended access pattern. */
        asm{    
                movups xmm0 z0
                movups xmm1 z1
                movups xmm2 xmm1
                movups xmm3 xmm0

		mov esi a
		mov edi b
                mov ecx blk_size
                LOOP2:
 			movaps   [edi] xmm0
			movaps   [edi+0x10] xmm1
			movaps   [esi] xmm2
			movaps   [esi+0x10] xmm3
                dec ecx
                jnz LOOP2
        }


#else           
printf("here 4\n");

	/* Element offset into a and b. Kept in size_t because it ends at
	 * blk_size * 8, which exceeds INT_MAX for large buffers. */
	size_t offset = 0;
	int iter;

	for (iter = 0; iter < blk_size; iter++) {
		__m128 a_lo = _mm_load_ps(&a[offset]);
		__m128 b_lo = _mm_load_ps(&b[offset]);
		__m128 sum_lo = _mm_add_ps(a_lo, b_lo);
		__m128 a_hi = _mm_load_ps(&a[offset + 4]);
		__m128 b_hi = _mm_load_ps(&b[offset + 4]);
		__m128 sum_hi = _mm_add_ps(a_hi, b_hi);

		/* Empty asm that consumes both sums: it emits no instructions but
		 * stops an optimizing compiler from discarding the loads (and with
		 * them the whole measured loop) as dead code. */
		__asm__ __volatile__("" : : "x"(sum_lo), "x"(sum_hi));

		offset += 8;
	}

#endif  

printf("here 2\n");
	pthread_exit(0);
} 

      

/*
void *MemWork5(void *mem_arg)
{
        //pray for register, no asm skills
        char c1[]={22, 44, 11, 25, 36, 44, 69, 71};
                        
        int num1, num2;
        int index=*(int *)mem_arg;
        int b_size=mem_table[index].block_size-8;
        char *a = (char *)&mem_table[index].a[0];
        char *b = (char *)&mem_table[index].b[0];
                        
        for(num2=0; num2<6; num2++){
                for(num1=0; num1<b_size; num1+=8){
                        
                        (void)memcpy( &a, &c1[0], 8);  
                        (void)memcpy( &b, &c1[0], 8);  
                        a+=8;
                        b+=8;
                }
        }
        pthread_exit(0);
}
*/
