#include "globals.h"
#include "sha.h"
#include "salsa.h"
#include "shittify.h"
#include "scrypt.h"

//__attribute__((vec_type_hint(uint)))

/*
 * search - scrypt proof-of-work search kernel.
 *
 * Each work-item derives one candidate value `gid`, mixes it with the block
 * header words, runs the result through scrypt_core(), hashes the output
 * again and reports `gid` through SETFOUND() when the final hash word is at
 * or below `target`.
 *
 * Parameters:
 *   input       Global buffer of uint4. Elements 0..4 are read: input[0..3]
 *               are fed to SHA256() as one block, input[4].s0..s2 are combined
 *               with gid, and input[4].s3 is added to get_global_id(0) to form
 *               gid unless GOFFSET is defined.
 *   output      Not referenced directly in this body; presumably written by
 *               the SETFOUND() macro - confirm in the header that defines it.
 *   padcache    Passed unchanged to scrypt_core(). Element type is uint16 when
 *               CLSIZE == 64 and uint8 otherwise.
 *   midstate0,
 *   midstate16  Concatenated into the 8-word state used by the first SHA256()
 *               call.
 *   target      A candidate is reported when EndianSwapa(ostate.s7) <= target.
 *
 * NOTE(review): SHA256(), RND(), EndianSwapa(), SETFOUND(), scrypt_core(),
 * shittify()/unshittify(), the SK00..SK05 and zero/one constants and the
 * fixedWa..fixedWh tables all come from the included headers; their exact
 * semantics are not visible here. The overall shape looks like scrypt's
 * PBKDF2-HMAC-SHA256 -> mixing core -> PBKDF2-HMAC-SHA256 pipeline, and the
 * inline comments below are written on that assumption - verify against
 * sha.h / globals.h before relying on them.
 *
 * NOTE(review): an earlier, since-removed draft of this code listed
 * {0x80000000, 0x00000280, 0x5C5C5C5C, 0x36363636} in a constant named SKa,
 * which suggests SK00 is the SHA padding bit, SK01 a message bit length and
 * SK02/SK03 the HMAC outer/inner pad words - TODO confirm.
 */
__attribute__((reqd_work_group_size(WORKSIZE, 1, 1)))
__kernel void search(const __global uint4 *restrict input,
volatile __global uint *restrict output,
#if (CLSIZE == 64)
__global uint16 *restrict padcache,
#else
__global uint8 *restrict padcache,
#endif
const uint4 midstate0, const uint4 midstate16, const uint target){

	// Working buffer handed to scrypt_core(). Both layouts hold 32 uints:
	// two uint16 when CLSIZE == 64, four uint8 otherwise.
#if (CLSIZE == 64)
	uint16 X[2];
#else
	uint8 X[4];
#endif
	// First three words of input[4]; cached in scalars because they are
	// re-inserted into tmpa on every iteration of the loop further down.
	uint in0 = input[4].s0;
	uint in1 = input[4].s1;
	uint in2 = input[4].s2;

	// Candidate value for this work-item. With GOFFSET the global id is used
	// as-is (presumably the host supplies a global work offset - confirm);
	// without it, input[4].s3 is used as the base.
#ifdef GOFFSET
	uint gid = get_global_id(0);
#else
	uint gid = input[4].s3 + get_global_id(0);
#endif

	// Both running states start from the standard SHA-256 initial hash values.
	// Judging by the later use of SK02/SK03, ostate and tstate appear to become
	// the HMAC outer and inner states respectively - TODO confirm.
	uint8 ostate = {0x6a09e667U, 0xBB67AE85U, 0x3C6EF372U, 0xA54FF53AU, 0x510e527fU, 0x9b05688cU, 0x1F83D9ABU, 0x5BE0CD19U};
	uint8 tstate = ostate;

	// tmpa/tmpb: the two 8-word halves of the block passed to SHA256().
	// tmpc: state for the first call, assembled word-by-word from the two
	// midstate arguments.
	uint8 tmpa = {in0, in1, in2, gid, SK00, zero, zero, zero};
	uint8 tmpb = {zero, zero, zero, zero, zero, zero, zero, SK01}; // seven zero words followed by SK01
	uint8 tmpc = {midstate0.s0, midstate0.s1, midstate0.s2, midstate0.s3, midstate16.s0, midstate16.s1, midstate16.s2, midstate16.s3};
	// Scalar copy of tstate (see the backup below). A commented-out earlier
	// version kept this in a single uint8; the reason for the scalar form is
	// not recorded.
	uint tsb0, tsb1, tsb2, tsb3, tsb4, tsb5, tsb6, tsb7;

	// Counter written into tmpa.s4 inside the loop below; takes the values
	// 1, 2, 3, 4 (presumably the PBKDF2 block index - confirm).
	uint lnum = 1;

	// Update tmpc (midstate) with the block {tmpa, tmpb}; the result is then
	// XORed with SK02 / SK03 to seed ostate and tstate.
	SHA256(&tmpc, &tmpa, &tmpb, one);
	// ostate <- SHA256 over {tmpc ^ SK02, SK02 repeated}.
	tmpb = tmpc^SK02;
	tmpa = SK02;
	SHA256(&ostate, &tmpb, &tmpa, zero);
	// tstate <- SHA256 over {tmpc ^ SK03, SK03 repeated}.
	tmpb = tmpc^SK03;
	tmpa = SK03;
	SHA256(&tstate, &tmpb, &tmpa, zero);

	// Back up tstate before it is advanced over the header; it is restored
	// after scrypt_core() for the second hashing pass.
	tsb0 = tstate.s0;
	tsb1 = tstate.s1;
	tsb2 = tstate.s2;
	tsb3 = tstate.s3;
	tsb4 = tstate.s4;
	tsb5 = tstate.s5;
	tsb6 = tstate.s6;
	tsb7 = tstate.s7;


	// Load input[0..3] (16 words) as one block.
	tmpb.lo = input[0];
	tmpb.hi = input[1];
	tmpc.lo = input[2];
	tmpc.hi = input[3];


	// Advance tstate over that block.
	SHA256(&tstate, &tmpb, &tmpc, one);

	// Clear tmpb once. The loop below only rewrites .s0 and .s7, so .s1-.s6
	// rely on staying zero from here on.
	tmpb = zero;
#if (CLSIZE == 64)
	// Produce four 8-word results and pack them into X[0].lo, X[0].hi,
	// X[1].lo, X[1].hi in that order.
	for(uint i=0; i<4; i++){

		// Rebuild the first block half: header tail words, gid, the counter
		// (post-incremented for the next pass) and SK00.
		tmpa.s0 = in0;
		tmpa.s1 = in1;
		tmpa.s2 = in2;
		tmpa.s3 = gid;
		tmpa.s4 = lnum++;
		tmpa.s5 = SK00;
		tmpa.s6 = zero;
		tmpa.s7 = zero;
		// Second block half: all zero except the last word, SK04.
		tmpb.s0 = zero;
		tmpb.s7 = SK04;
		// Hash {tmpa, tmpb} on top of a copy of tstate; result lands in tmpc.
		tmpc = tstate;
		SHA256(&tmpc, &tmpa, &tmpb, one);
		// Hash {tmpc, SK00, 0, ..., 0, SK05} on top of a copy of ostate;
		// result lands in tmpa.
		tmpa = ostate;
		tmpb.s0 = SK00;
		tmpb.s7 = SK05;
		SHA256(&tmpa, &tmpc, &tmpb, one);

		// Odd i fills the high half of X[i>>1], even i the low half.
		if(i&one)
			X[i>>1].hi = tmpa;
		else
			X[i>>1].lo = tmpa;

	}
#else
	// Same computation as the CLSIZE == 64 branch, but each 8-word result
	// is stored directly as X[i].
	for(uint i=0; i<4; i++){
		// Rebuild the first block half: header tail words, gid, the counter
		// (post-incremented for the next pass) and SK00.
		tmpa.s0 = in0;
		tmpa.s1 = in1;
		tmpa.s2 = in2;
		tmpa.s3 = gid;
		tmpa.s4 = lnum++;
		tmpa.s5 = SK00;
		tmpa.s6 = zero;
		tmpa.s7 = zero;
		// Second block half: all zero except the last word, SK04.
		tmpb.s0 = zero;
		tmpb.s7 = SK04;
		// Hash {tmpa, tmpb} on top of a copy of tstate; result lands in tmpc.
		tmpc = tstate;
		SHA256(&tmpc, &tmpa, &tmpb, one);
		// Hash {tmpc, SK00, 0, ..., 0, SK05} on top of a copy of ostate;
		// result lands in tmpa.
		tmpa = ostate;
		tmpb.s0 = SK00;
		tmpb.s7 = SK05;
		SHA256(&tmpa, &tmpc, &tmpb, one);
		X[i] = tmpa;

	}
#endif

	// Run the mixing core over X using padcache as scratch. The uint8 layout
	// is additionally passed through shittify()/unshittify() around the call
	// (presumably a word reordering expected by scrypt_core - confirm in
	// shittify.h). Commented-out EndianSwapa shuffles that used to surround
	// the CLSIZE == 64 call have been dropped from this comment block.
#if (CLSIZE == 64)
	scrypt_core(X, padcache);
#else
	shittify(X);
	scrypt_core(X, padcache);
	unshittify(X);
#endif

	// Restore tstate to the value saved before it was advanced over the
	// header.
	tstate.s0 = tsb0;
	tstate.s1 = tsb1;
	tstate.s2 = tsb2;
	tstate.s3 = tsb3;
	tstate.s4 = tsb4;
	tstate.s5 = tsb5;
	tstate.s6 = tsb6;
	tstate.s7 = tsb7;

	// Absorb all 32 words of X into tstate, 16 words per SHA256() call.
#if (CLSIZE == 64)
	for(uint i=0; i<2; i++){
		// Copies are passed rather than pointers into X; whether SHA256()
		// requires this is not visible here.
		tmpa = X[i].lo;
		tmpc = X[i].hi;
		SHA256(&tstate, &tmpa, &tmpc, one);
	}
#else
	for(uint i=0; i<4; i+=2)
		SHA256(&tstate, X+i, X+i+1, one);

#endif



	// Take a scalar working copy of tstate for the RND calls below; tstate
	// itself keeps the pre-round value for the additions after the loop.
	tsb0 = tstate.s0;
	tsb1 = tstate.s1;
	tsb2 = tstate.s2;
	tsb3 = tstate.s3;
	tsb4 = tstate.s4;
	tsb5 = tstate.s5;
	tsb6 = tstate.s6;
	tsb7 = tstate.s7;

	// 8 iterations x 8 RND invocations = 64 RND calls, each taking its last
	// argument from one of the fixedWa..fixedWh tables. Presumably these are
	// the 64 SHA-256 rounds over a constant block with a precomputed message
	// schedule - TODO confirm against the table definitions.
	// A..H are temporary aliases for tsb0..tsb7; the argument order rotates
	// by one position per call instead of shuffling the variables.
	for(uint i=0; i<8; i++){
#define A tsb0
#define B tsb1
#define C tsb2
#define D tsb3
#define E tsb4
#define F tsb5
#define G tsb6
#define H tsb7
		RND(A,B,C,D,E,F,G,H, fixedWa[i]);
		RND(H,A,B,C,D,E,F,G, fixedWb[i]);
		RND(G,H,A,B,C,D,E,F, fixedWc[i]);
		RND(F,G,H,A,B,C,D,E, fixedWd[i]);
		RND(E,F,G,H,A,B,C,D, fixedWe[i]);
		RND(D,E,F,G,H,A,B,C, fixedWf[i]);
		RND(C,D,E,F,G,H,A,B, fixedWg[i]);
		RND(B,C,D,E,F,G,H,A, fixedWh[i]);
#undef A
#undef B
#undef C
#undef D
#undef E
#undef F
#undef G
#undef H
	}

	// Add the working copy back onto the pre-round state, word by word.
	tstate.s0 += tsb0;
	tstate.s1 += tsb1;
	tstate.s2 += tsb2;
	tstate.s3 += tsb3;
	tstate.s4 += tsb4;
	tstate.s5 += tsb5;
	tstate.s6 += tsb6;
	tstate.s7 += tsb7;

	// tmpb is deliberately reused: the last loop iteration above left it as
	// {SK00, 0, 0, 0, 0, 0, 0, SK05} and nothing has assigned to it since.
	// NOTE(review): this relies on SHA256() not writing through its second
	// and third arguments - confirm in sha.h before reordering anything.
	SHA256(&ostate, &tstate, &tmpb, one);

	// Report gid when the byte-swapped last state word does not exceed target.
	bool found = (EndianSwapa((ostate.s7)) <= target);
	if(found)
		SETFOUND(gid);
}
