#include "globals.h"
#include "sha.h"
//#include "salsa.h"
//#include "shittify.h"
#include "scrypt.h"

/*
 * search: per-work-item scrypt hash of a block header with a per-item nonce.
 *
 * input      - header words: input[0..3] (64 bytes) are absorbed into the inner
 *              state, input[4].s0..s2 are the following three words, and (unless
 *              GOFFSET is defined) input[4].s3 is a base nonce added to
 *              get_global_id(0).
 * output     - not referenced directly here; presumably written by SETFOUND
 *              (defined elsewhere) — TODO confirm.
 * padcache   - scratch buffer handed to scrypt_core; element width follows CLSIZE.
 * midstate0,
 * midstate16 - the two 4-word halves of an 8-word SHA-256 state from which the
 *              header hash is resumed.
 * target     - threshold compared against the byte-swapped last word of the
 *              final state.
 *
 * NOTE(review): the HMAC/PBKDF2 wording below is inferred from the data flow;
 * the meaning of SK00..SK05 lives in globals.h and should be confirmed there.
 */
__attribute__((reqd_work_group_size(WORKSIZE, 1, 1)))
__kernel void search(volatile __global uint4 *restrict input,
volatile __global uint *restrict output,
#if (CLSIZE == 64)
__global uint16 *restrict padcache,
#else
__global uint8 *restrict padcache,
#endif
volatile uint4 midstate0, volatile uint4 midstate16, const uint target){

	/* 128-byte working block for scrypt_core, split to match padcache's element width. */
#if (CLSIZE == 64)
	uint16 X[2];
#else
	uint8 X[4];
#endif

	/* Header words following the first 64 bytes; the fourth slot carries the nonce. */
	uint in0 = input[4].s0;
	uint in1 = input[4].s1;
	uint in2 = input[4].s2;

#ifdef GOFFSET
	uint gid = get_global_id(0);
#else
	uint gid = input[4].s3 + get_global_id(0);
#endif

	/* SHA-256 initial hash values; both states start from them. */
	uint8 ostate = {0x6a09e667U, 0xBB67AE85U, 0x3C6EF372U, 0xA54FF53AU, 0x510e527fU, 0x9b05688cU, 0x1F83D9ABU, 0x5BE0CD19U};
	uint8 tstate = ostate;

	uint8 tmpa = {in0, in1, in2, gid, SK00, zero, zero, zero};
	uint8 tmpb = {zero, zero, zero, zero, zero, zero, zero, SK01};
	uint8 tmpc = {midstate0.s0, midstate0.s1, midstate0.s2, midstate0.s3, midstate16.s0, midstate16.s1, midstate16.s2, midstate16.s3};

	/* Resume the header hash from the midstate; tmpc ends up holding its digest. */
	SHA256(&tmpc, &tmpa, &tmpb);

	/* Key the outer (SK02) and inner (SK03) states with the digest
	 * (presumably the HMAC opad / ipad patterns — TODO confirm). */
	tmpb = tmpc ^ SK02;
	tmpa = SK02;
	SHA256(&ostate, &tmpb, &tmpa);

	tmpb = tmpc ^ SK03;
	tmpa = SK03;
	SHA256(&tstate, &tmpb, &tmpa);

	/* The keyed inner state is needed again for the final hash after scrypt_core. */
	uint8 tstate_keyed = tstate;

	/* Absorb the first 64 header bytes into the inner state. */
	tmpb.lo = input[0];
	tmpb.hi = input[1];
	tmpc.lo = input[2];
	tmpc.hi = input[3];
	SHA256(&tstate, &tmpb, &tmpc);

	/* Produce the four 32-byte pieces of X; only the store differs between CLSIZE layouts. */
	tmpb = zero;
	for (uint i = 0; i < 4; i++) {
		/* Inner tail block: remaining header words, nonce, 1-based block index, SK00. */
		tmpa.s0 = in0;
		tmpa.s1 = in1;
		tmpa.s2 = in2;
		tmpa.s3 = gid;
		tmpa.s4 = i + 1;
		tmpa.s5 = SK00;
		tmpa.s6 = zero;
		tmpa.s7 = zero;
		tmpb.s0 = zero;
		tmpb.s7 = SK04;
		tmpc = tstate;
		SHA256(&tmpc, &tmpa, &tmpb);

		/* Outer hash of the 32-byte inner digest. */
		tmpa = ostate;
		tmpb.s0 = SK00;
		tmpb.s7 = SK05;
		SHA256(&tmpa, &tmpc, &tmpb);

#if (CLSIZE == 64)
		if (i & one) {
			X[i >> 1].hi = tmpa;
		} else {
			X[i >> 1].lo = tmpa;
		}
#else
		X[i] = tmpa;
#endif
	}

	tstate = tstate_keyed;

	scrypt_core(X, padcache);

	/* Hash the 128 bytes of X into the keyed inner state. */
#if (CLSIZE == 64)
	for (uint i = 0; i < 2; i++) {
		/* Copy out the halves: SHA256 takes uint8 blocks, X holds uint16. */
		tmpa = X[i].lo;
		tmpc = X[i].hi;
		SHA256(&tstate, &tmpa, &tmpc);
	}
#else
	for (uint i = 0; i < 4; i += 2)
		SHA256(&tstate, X + i, X + i + 1);
#endif

	/* Final inner tail block: presumably block index 1, the 0x80000000 padding
	 * word and the total bit length 0x620 (1568 bits = 64 + 128 + 4 bytes). */
	tmpa = zero;
	tmpa.s0 = 0x00000001U;
	tmpa.s1 = 0x80000000U;
	tmpc = zero;
	tmpc.s7 = 0x00000620U;
	SHA256(&tstate, &tmpa, &tmpc);

	/* Outer hash of the final inner digest. The padding block is rebuilt here
	 * instead of relying on whatever the last loop iteration left in tmpb. */
	tmpb = zero;
	tmpb.s0 = SK00;
	tmpb.s7 = SK05;
	SHA256(&ostate, &tstate, &tmpb);

	/* Only the last state word is tested against target. */
	if (EndianSwapa(ostate.s7) <= target) {
		SETFOUND(gid);
	}
}