//======================================================================
// copyright (c) 2002-2020 Sergio Masci
// (further optimisation 2020)
//
// Fast Square Root
//
// this function is taken from the fp_sqrt function of the XCSB
// runtime library
//
// xsqrt is optimised for use on 8 bit micros
// it uses just shift, add, subtract and compare instructions
// it can compute the integer square root of an unsigned 16 bit
// number in just 8 iterations of the main loop
//======================================================================
#include <stdio.h>
#include <stdlib.h>
#include <math.h>

// boolean constants
// both are defined here only when FALSE has not already been
// defined; TRUE is not tested separately
#if !defined(FALSE)
#define FALSE 0
#define TRUE (!FALSE)
#endif

// type configuration
//
// BIGN   is the unsigned type of the value passed to xsqrt
// SMALLN is the unsigned type of the result, half the width of BIGN
//
// sizeof_BIGN and sizeof_SMALLN are the widths of those types in
// bits (not bytes), written as literals so that they can be tested
// with #if
//
// change the "#if 0" below to "#if 1" to select 32 bit values with
// 16 bit roots instead of 16 bit values with 8 bit roots
#if 0
#define	BIGN unsigned int
#define SMALLN unsigned short
#define sizeof_BIGN   32
#define sizeof_SMALLN 16
#else
#define	BIGN unsigned short
#define SMALLN unsigned char
#define sizeof_BIGN   16
#define sizeof_SMALLN  8
#endif

//======================================================================
// xsqrt
//
// returns the integer square root of val, that is the largest x for
// which x*x <= val
//
// the root is built one bit at a time, most significant bit first.
// with x the root found so far and n the bit being tried
//
//	(x+n)^2 = (x+n)(x+n) = x^2 + 2nx + n^2
//
// so the bit can be added to x when (2nx + n^2) can be subtracted
// from what is left of val. n is always a power of 2 so the 2nx and
// n^2 terms are maintained with shifts and adds alone
//
// see
//	xsqrt_changes
// for explanation of optimisations to
//	xsqrt_original
//======================================================================
SMALLN xsqrt(BIGN val)
{
	short	cnt;	// index of the root bit being tried

	BIGN	xp2,	// the 2nx term, equal to the root itself on exit
		acc,	// 2nx + n^2 for the bit being tried
		np2;	// n^2


	if (val == 0)
	{
		return 0;
	}

#if (sizeof_BIGN > 16)

	// cnt = index of the most significant set bit of val
	//
	// the count starts at -1 (val is known to be non zero so the
	// loop runs at least once) which keeps cnt in the range 0..31.
	// counting from 0 would give 32 for val >= 0x80000000 and the
	// shift used to build np2 below would then be a shift by 32,
	// which is undefined for a 32 bit type
	acc = val;

	for (cnt=-1; acc!=0; cnt++)
	{
		acc = acc >> 1;
	}
#else

	// only one byte of val needs to be scanned for its top set bit
	SMALLN	xacc;

	if ((val & 0xff00) != 0)
	{
		// this can be optimised to simply copying the
		// high byte
		xacc = (val >> 8) & 0xff;

		// the loop below adds 1..8, leaving cnt as the index
		// (8..15) of the most significant set bit of val
		cnt = 7;
	}
	else
	{
		// this can be optimised to simply copying the
		// low byte
		xacc = val & 0xff;

		// the loop below leaves cnt one higher than the index
		// of the most significant set bit. at worst the main
		// loop starts one bit too high and that first trial
		// fails
		cnt = 0;
	}

	for (; xacc!=0; cnt++)
	{
		xacc = (xacc >> 1) & 0x7f;
	}
#endif

	// approx sqrt
	// halving the bit position gives the position of the highest
	// bit that can be set in the root
	// cnt = cnt / 2;
	cnt = cnt >> 1;

	// select integer power of 2 for 'n' so that 'n.x'
	// is equivalent to simply shifting 'x' by log2(n)

	// n^2
	// np2 = 1 << (cnt * 2);
	// the shift is done on a BIGN value so that it is carried out
	// in a type at least as wide as BIGN
	np2 = (BIGN)1 << (cnt << 1);

	// the 2nx term, zero because no root bits have been found yet
	xp2 = 0;

	// one iteration for each candidate root bit, highest first
	while (cnt >= 0)
	{
		// amount by which (x+n)^2 exceeds x^2
		acc = xp2 + np2;

		// rescale the 2nx term for the next (half sized) n
		xp2 = xp2 >> 1;

		if (acc <= val)
		{
			// the bit belongs in the root: take its
			// contribution out of val and fold n into x
			val = val - acc;

			xp2 = xp2 + np2;
		}

		cnt--;

		// n halves so n^2 quarters
		np2 = np2 >> 2;
	}

	// after the final iteration xp2 holds the root
	return xp2;
}


#if 0

//======================================================================
//======================================================================
// xsqrt_changes
//
// annotated version of the square root routine. the NOTE comments
// record, step by step, how the x and n variables of xsqrt_original
// are eliminated so that xp2 itself holds the root when the loop ends
//
// val is the value whose integer square root is returned
SMALLN xsqrt_changes(BIGN val)
{
	short	cnt;

	// the obviation of x and n becomes apparent
	// once it is realised that "(xp2 << 1) == 0"
	// on entry to the main computation loop
	// and that "xp2 = xp2 << 1" can be moved from the
	// end of the loop. The redundant computation at the
	// last iteration of the loop now becomes a redundant
	// computation at the first iteration of the loop
	// leaving "xp2 == x" after the loop completes whereas
	// previously "xp2 == (x >> 1)"
	//
	// NOTE(review): the statement at the end of the
	// xsqrt_original loop is "xp2 = xp2 >> 1"; the "<<" above
	// presumably refers to the doubling in
	// "acc = (xp2 << 1) + np2" - confirm

	// SMALLN	x,
	//		n;

	BIGN	xp2,
		acc,
		np2;


	if (val == 0)
	{
		return 0;
	}

#if (sizeof_BIGN > 16)

	// count the significant bits of val
	acc = val;

	for (cnt=0; acc!=0; cnt++)
	{
		acc = acc >> 1;
	}
#else

	// only one byte of val is scanned for its top set bit
	SMALLN	xacc;

	if ((val & 0xff00) != 0)
	{
		// this can be optimised to simply copying the
		// high byte
		xacc = (val >> 8) & 0xff;

		cnt = 7;
	}
	else
	{
		// this can be optimised to simply copying the
		// low byte
		xacc = val & 0xff;

		cnt = 0;
	}

	for (; xacc!=0; cnt++)
	{
		xacc = (xacc >> 1) & 0x7f;
	}
#endif

	// approx sqrt
	// cnt = cnt / 2;
	cnt = cnt >> 1;

	// NOTE 0
	// because of NOTE 7
	// x   = 0;
	// is redundant

	// NOTE 0b
	// because of NOTE 7

	// select integer power of 2 for 'n' so that 'n.x'
	// is equivalent to simply shifting 'x' by log2(n)
	// n   = 1 << cnt;
	// is redundant

	// n^2
	// np2 = 1 << (cnt * 2);
	// NOTE(review): when sizeof_BIGN > 16 the bit count above can
	// reach 32, making this a shift by 32 - undefined if int is
	// 32 bits wide. confirm before enabling that variant
	np2 = 1 << (cnt << 1);

	// x^2
	xp2 = 0;

	// NOTE 1
	// xp2 = xp2 << 1;
	// xp2 was previously set to 0 so xp2 does not change

	while (cnt >= 0)
	{
		// NOTE 2
		// because of NOTE 1
		// acc = (xp2 << 1) + np2;
		// becomes:
		acc = xp2 + np2;

		// NOTE 3
		// because of NOTE 2
		xp2 = xp2 >> 1;

		if (acc <= val)
		{
			val = val - acc;

			// NOTE 4
			// because of NOTE 7
			// x = x + n;
			// this is now redundant

			xp2 = xp2 + np2;
		}

		cnt--;

		// NOTE 5
		// because of NOTE 4
		// n   = n   >> 1;
		// this is now redundant

		np2 = np2 >> 2;

		// NOTE 6
		// because of NOTE 3
		// xp2 = xp2 >> 1;
		// moved to NOTE 2
		// (the statement now sits directly under NOTE 3)
	}

	// NOTE 7
	// because of NOTE 6
	// return x;
	// becomes:

	return xp2;
}


//======================================================================
//======================================================================
// xsqrt_original
//
// the square root routine before the optimisations described in
// xsqrt_changes. the root x and the bit being tried n are kept in
// their own variables alongside xp2 and np2
//
// val is the value whose integer square root is returned
SMALLN xsqrt_original(BIGN val)
{
	short	cnt;

	SMALLN	x,	// root found so far
		n;	// bit currently being tried, a power of 2

	BIGN	xp2,
		acc,
		np2;


	if (val == 0)
	{
		return 0;
	}

#if (sizeof_BIGN > 16)

	// count the significant bits of val
	acc = val;

	for (cnt=0; acc!=0; cnt++)
	{
		acc = acc >> 1;
	}
#else

	// only one byte of val is scanned for its top set bit
	SMALLN	xacc;

	if ((val & 0xff00) != 0)
	{
		// this can be optimised to simply copying the
		// high byte
		xacc = (val >> 8) & 0xff;

		cnt = 7;
	}
	else
	{
		// this can be optimised to simply copying the
		// low byte
		xacc = val & 0xff;

		cnt = 0;
	}

	for (; xacc!=0; cnt++)
	{
		xacc = (xacc >> 1) & 0x7f;
	}
#endif

	// approx sqrt
	// cnt = cnt / 2;
	cnt = cnt >> 1;

	x   = 0;

	// select integer power of 2 for 'n' so that 'n.x'
	// is equivalent to simply shifting 'x' by log2(n)
	n   = 1 << cnt;

	// n^2
	// np2 = 1 << (cnt * 2);
	// NOTE(review): when sizeof_BIGN > 16 the bit count above can
	// reach 32, making this a shift by 32 - undefined if int is
	// 32 bits wide. confirm before enabling that variant
	np2 = 1 << (cnt << 1);

	// x^2
	// (zero here; inside the loop xp2 is kept equal to n.x so that
	// "xp2 << 1" supplies the 2nx term)
	xp2 = 0;

	// one iteration for each candidate root bit, highest first
	while (cnt >= 0)
	{
		// 2nx + n^2, the amount by which (x+n)^2 exceeds x^2
		acc = (xp2 << 1) + np2;

		if (acc <= val)
		{
			// the bit belongs in the root
			val = val - acc;

			x = x + n;

			xp2 = xp2 + np2;
		}

		cnt--;

		// move on to the next lower bit: n halves, n^2
		// quarters and n.x halves
		n   = n   >> 1;
		np2 = np2 >> 2;

		xp2 = xp2 >> 1;
	}

	return x;
}

#endif


//======================================================================
//======================================================================
int main(int argc, char *argv[])
{
	int	val1, val2;
	int	j;

	printf("compiled with BIGN=%d bits and SMALLN=%d bits\n",
		sizeof(BIGN)*8, sizeof(SMALLN)*8);

	if (sizeof_BIGN   != sizeof(BIGN)*8  ||
	    sizeof_SMALLN != sizeof(SMALLN)*8)
	{
		printf("something wrong with compiler settings, BIGN and SMALLN do not match expected options\n");
		exit(1);
	}


	for (j=0; j<0x10000; j++)
	{
		val1 = xsqrt(j);
		val2 = (int)sqrt(j);

		printf("%c %d %d %d\n", ((val1 != val2) ? '*' : ' '), j, val1, val2);
	}
}