//-----------------------------------------------------------------------------
//
//  Copyright 2002 Philips Semiconductors Limited
//
//  Philips Semiconductors - Millbrook Industrial Estate
//  Southampton - SO15 0DJ - UK
//
//  All rights are reserved. Reproduction in whole or part is prohibited
//  without the written prior consent of the copyright owner.
//
//  Company Confidential
//
//  Filename: mont.o1.c
//  Project: Centaurus2 / Tamperproof
//
//  Rev		Date		Author		Comments
//  -------------------------------------------------------------------
//  001		05/12/2002	A.Badey		Original version for Centaurus2
//
//  Additional information:
//  This code is based on the mont.c code by Bruce Murray. Refer to this code
//  for more information on the revisions made.
//
//  Function:
//  This module is a generic implementation of the RSA algorithm with a key
//  of variable length.
//
//  Warning:
//  This code is covered by the Export Control Laws on Cryptographic Material.
//  Exportation of this software should be made in accordance with the European
//  Union and United Kingdom Community General Export Authorisation legislation.
//  For more information see
//  http://pww.export-control.corp.philips.com/man_ec/index.htm
//
//----------------------------------------------------------------------------- 


//--------------------
// INCLUDE FILES
//--------------------
#include "bn.h"
#include "mont.h"

//--------------------
// LOCAL MACROS 
//--------------------

//--------------------
// EXPORTED DATA
//--------------------

//--------------------
// LOCAL TYPEDEFS
//--------------------

//--------------------
// STATIC DATA
//--------------------

//--------------------
// FUNCTION PROTOTYPES
//--------------------



//----------------------------------------------------
//
// Name: rsa_decode_256
//       This function performs the modular exponentiation of InputBlock
//       (InputBlock^d mod p) as specified by the RSA algorithm.
//       This function is a combination of the InitExp and ModExp functions
//		 so that everything is done here. (see Bruce Murray's code)
//		 It redefines the MontInfo variable inside of the function (it is
//		 placed in the caller-supplied work area) and does not use the
//		 global variable usually defined in mont.h .
//		 There is NO algorithmic difference between the code in this function
//		 and the functionality of the previous InitExp and ModExp
// Returns: 0 if computed result is valid, 1 otherwise
//
// Parameter	Flow	Description
// ----------------------------------------------
// OutputBlock	out		OutputBlock (Result of exponentiation)
// InputBlock	in		InputBlock  (Number to exponentiate)
// p			in		Modulus
// d			in		Exponent
// workarea		in/out	Scratch memory used to hold the MONTGOMERY_CONSTS data
//
// Additional information:
//
//----------------------------------------------------

// Computes OutputBlock = InputBlock^d mod p with a left-to-right binary
// (square and multiply) exponentiation carried out in Montgomery space.
//
// OutputBlock  out     receives the result; it is also used as the accumulator
// InputBlock   in      number to exponentiate (MnConvToMont expects it to be
//                      already reduced mod p)
// p            in      modulus; its least significant slice must be odd
// d            in      exponent; must have at least one significant bit
// workarea     in/out  scratch memory, overlaid with a MONTGOMERY_CONSTS structure
//
// Returns 0 when the result has been computed, 1 when the modulus is even or
// the exponent has no significant bit (OutputBlock is not written in that case).
int rsa_decode_256(Bignum_t *OutputBlock, Bignum_t *InputBlock, Bignum_t *p, Bignum_t *d, Bignum_t *workarea)
{

	// The MontInfo variable lives in the work area supplied by the caller
	MONTGOMERY_CONSTS *MontInfo = (MONTGOMERY_CONSTS*)workarea;
	
	int n;
	int NumBitsInExponent;

	#ifdef __DEBUG__
	printf (">> rsa_decode: beginning.....\n\r");
	#endif

	// Montgomery arithmetic needs the modulus to be coprime with 2^32: the value
	// m' = -p^(-1) mod 2^32 computed below only exists when the least significant
	// slice of p is odd. Refuse an even modulus instead of producing garbage.
	if ((p->slice[0] & 1) == 0)
	{
		#ifdef __DEBUG__
		printf (">> rsa_decode: returning 1 (fail)\n\r");
		#endif

		return (1);
	}

	// Here we use the output block as an accumulator
	
	// Initialise the modulus value in the MontInfo structure
#ifdef EMULATOR_TEST
	 debug_write(char_R,0,(unsigned long)&MontInfo->global_p,(unsigned long)p);
#endif

#ifdef REMOVE_MEMCPY
     BnCpy((unsigned long *)&MontInfo->global_p,(unsigned long *)p);
#else

	// NOTE(review): this copy only honours u32s_per_bignum when ENDIANNESS_VARIABLE
	// is defined, whereas the accumulator copy further down always does - confirm
	// that the difference is intended.
 #ifdef ENDIANNESS_VARIABLE
	if(u32s_per_bignum == U32s_PER_BIGNUM)
		memcpy (&MontInfo->global_p, p, BYTES_PER_BIGNUM);
	else
		memcpy (&MontInfo->global_p, p, BYTES_PER_1KBITS);
 #else
		memcpy (&MontInfo->global_p, p, BYTES_PER_BIGNUM);
 #endif

#endif

#ifdef EMULATOR_TEST_DEBUG
	 debug_write(char_r,1,(unsigned long)&MontInfo->global_p,(unsigned long)p);
#endif

	// Initialise the number of digits (slices) in the modulus
	MontInfo->Modulus_Digits = BnNumDigits (p); 
#ifdef EMULATOR_TEST_DEBUG
	 debug_write(char_m,2,(unsigned long)&MontInfo->Modulus_Digits,(unsigned long)p);
#endif
   
	// Initialise the value of R in MontInfo: this is the number of bits in the modulus
	// saved as a power of 2
	MontInfo->R_power = (MontInfo->Modulus_Digits * 32);
 
   	// Initialise the value of ModHash in MontInfo: this is m' = -p^(-1) mod 2^32
	MontInfo->Mod_dash = (0 - MnInvModTwoToTheThirtyTwo (p->slice[0]));

	// Initialise the Omura correction factor in MontInfo
	Omura (&MontInfo->Omura_corr, (Bignum_t *) p);  

#ifdef EMULATOR_TEST_DEBUG
	 debug_write(char_o,3,(unsigned long)&MontInfo->Omura_corr,(unsigned long)p);
#endif

	// Convert the input to Montgomery space and store its value in MontInfo
	MnConvToMont (&MontInfo->InputBlock, InputBlock, MontInfo);

#ifdef EMULATOR_TEST_DEBUG
	 debug_write(char_c,4,(unsigned long)InputBlock,(unsigned long)MontInfo);
#endif
	
	// Get the number of significant bits in the exponent ie we know the MSB equals 1
	// during the computation
	NumBitsInExponent = BnNumBitsIn (d);

#ifdef EMULATOR_TEST_DEBUG
	 debug_write(0xE,5,(unsigned long)NumBitsInExponent,(unsigned long)d);
#endif

	// The loop below starts from "accumulator = input", i.e. it assumes that the
	// exponent has a leading 1 bit. An exponent without any significant bit would
	// skip the loop and hand back the input unchanged, so report it as invalid.
	if (NumBitsInExponent <= 0)
	{
		#ifdef __DEBUG__
		printf (">> rsa_decode: returning 1 (fail)\n\r");
		#endif

		return (1);
	}

#ifdef EMULATOR_TEST_DEBUG
	 debug_write(0xEC,0,(unsigned long)OutputBlock,(unsigned long)&MontInfo->InputBlock);
#endif

#ifdef REMOVE_MEMCPY
     BnCpy((unsigned long *)OutputBlock, (unsigned long *)&MontInfo->InputBlock);
#else
	// Initialise the value of the accumulator with the value of the Input Block
	if(u32s_per_bignum == U32s_PER_BIGNUM)
		memcpy (OutputBlock, &MontInfo->InputBlock, BYTES_PER_BIGNUM);
	else
		memcpy (OutputBlock, &MontInfo->InputBlock, BYTES_PER_1KBITS);

#endif

	#ifdef __DEBUG__
	printf (">> rsa_decode: running Mont exponantiation main routine\n\r");
	#endif

#ifdef EMULATOR_TEST_DEBUG
	 debug_write(char_d,4,(unsigned long)OutputBlock,(unsigned long)&MontInfo->InputBlock);
#endif

	// Now run the Montgomery exponentiation main routine. The most significant
	// bit of the exponent is already accounted for by the initial accumulator
	// value, so the scan starts one bit below it.
	for (n = (NumBitsInExponent - 2); n >= 0; n--)
	{
		// Square the Accumulator
		MnMontMul (OutputBlock, OutputBlock, OutputBlock, MontInfo);
   
		// Multiply Acc by the (Mont) input value if bit n of exponent is set
		if (BnTestBit (d, n))
		{
			MnMontMul (OutputBlock, OutputBlock, &MontInfo->InputBlock, MontInfo);
		}
	}

	// Convert result from Montgomery space back to normal space
	// We use the property x = MonPro(x_mon,1)
	// We need to define a bignum variable initialised to 1 => we use MontInfo->InputBlock
	//  as we don't need its value anymore
	
	#ifdef __DEBUG__
	printf (">> rsa_decode: convert result back to normal space\n\r");
	#endif
	
	BnMake    (&MontInfo->InputBlock, 1);

	#ifdef __DEBUG__
	printf (">> rsa_decode: running MnMontMul\n\r");
	#endif

	MnMontMul (OutputBlock, OutputBlock, &MontInfo->InputBlock, MontInfo);

	#ifdef __DEBUG__
	printf (">> rsa_decode: returning 0 (pass)\n\r");
	#endif

	return (0);

}


//----------------------------------------------------
//
// Name: DualMacWithShift
//       performs the function: Result = ((Result + Vector1*Scalar1 + Vector2*Scalar2) >> 32)
// Returns: void
//
// Parameter	Flow	Description
// ----------------------------------------------
// Result		in/out	
// Vec1			in		
// Scalar1		in		
// Vec2			in		
// Scalar2		in		
//
// Additional information:
// -> 'Result' must be one ULONG larger than the other input vectors
// to be able to hold the possible final 1 bit overflow.
// -> This function is coded twice: once in pure C (portable between
// platforms) and once with MIPS assembly code optimisation (the assembly
// version is currently commented out below)
//
//----------------------------------------------------


// Dual multiply-accumulate with a one word right shift:
//    Result = ((Result + Vec1*Scalar1 + Vec2*Scalar2) >> 32)
// Vec1 and Vec2 are read as u32s_per_bignum words. Result is read and written
// as u32s_per_bignum + 1 words; its top word receives the final carry count.
static void DualMacWithShift (unsigned int *Result, unsigned int *Vec1, unsigned int Scalar1, unsigned int *Vec2, unsigned int Scalar2)
{
   unsigned long long wide;            // 64 bit intermediate of one multiply-add
   unsigned int       low;             // low word of the current intermediate
   unsigned int       carry1;          // carry word of the Vec1*Scalar1 chain
   unsigned int       carry2;          // carry word of the Vec2*Scalar2 chain
   unsigned int       top_carry = 0;   // carries out of the most significant word
   int                out = 0;         // index of the next word of Result to write
   int                remaining;

   // Word 0: its low word is the one shifted out, only the carries are kept
   wide   = ((unsigned long long) Vec1[0] * Scalar1) + Result[0];
   carry1 = (unsigned int) (wide >> 32);
   low    = (unsigned int) wide;

   wide   = ((unsigned long long) Vec2[0] * Scalar2) + low;
   carry2 = (unsigned int) (wide >> 32);

   // Words 1 .. u32s_per_bignum-1: each sum is stored one word lower (the shift)
   for (remaining = (u32s_per_bignum - 1); remaining != 0; remaining--)
   {
      wide   = ((unsigned long long) Vec1[out + 1] * Scalar1) + carry1 + Result[out + 1];
      carry1 = (unsigned int) (wide >> 32);
      low    = (unsigned int) wide;

      wide   = ((unsigned long long) Vec2[out + 1] * Scalar2) + carry2 + low;
      carry2 = (unsigned int) (wide >> 32);
      low    = (unsigned int) wide;

      Result[out] = low;
      out++;
   }

   // Fold both outstanding carry words into the extra top word of Result,
   // detecting each 32 bit wrap-around by "sum < addend".
   low = carry1 + Result[out + 1];
   if (low < carry1)
   {
      top_carry = 1;
   }

   low += carry2;
   if (low < carry2)
   {
      top_carry++;
   }

   Result[out]     = low;
   Result[out + 1] = top_carry;
}

// This function has been optimised for the MIPS by implementing some of the main
// routines in MIPS assembly code. Also some variables have been register assigned
// so that accesses to external memories are kept as low as possible.
// Some questions still remain (that will have to be solved when the code is simulated
// on a MIPS simulation platform...): does the compiler save and restore correctly all
// the registers (seems the case for some of them but not all)? Are the assembly coded
// routines functionally correct (should be yes but cannot be proved for the moment)?
/*
static void DualMacWithShift (unsigned int *Result, unsigned int *Vec1, unsigned int Scalar1, unsigned int *Vec2, unsigned int Scalar2)
{
	int i;
   
	register unsigned long result_v asm ("$t0");
	register unsigned long vec1_v asm ("$t2");
	register unsigned long vec2_v asm ("$t3");
	register unsigned long CarryWord1 asm ("$s0");
	register unsigned long CarryWord2 asm ("$s1");
	register unsigned long temp asm ("$s2");
	
	vec1_v = *Vec1;
    Vec1++;

	vec2_v = *Vec2;
	Vec2++;
	
	result_v = *Result;
	
	// The following routine is equivalent to:
	// (mfhi,mflo) = vec1_v * scalar1
	// temp = mflo + result0
	// carryword1 = 1 if overflow else 0 (based on result < operand if overflow)
	// carryword1 = mfhi + overflow (with overflow being carryword1)
	asm("multu %0, %1;
	     mfhi $k0;
	     mflo $k1;
		 addu $s2, $k1, %2;
		 sltu $s0, $s2, $k1;
		 addu $s0, $s0, $k0" : : "d" (vec1_v), "d" (Scalar1), "d" (result_v) );
	
	// The following routine is equivalent to:
	// (mfhi,mflo) = vec2_v * scalar2
	// temp = mflo + temp
	// carryword2 = 1 if overflow else 0 (based on result < operand if overflow)
	// carryword2 = mfhi + overflow (with overflow being carryword2)
	asm("multu %0, %1;
         mfhi $k0;
	     mflo $k1;
		 addu $s2, $k1, $s2;
		 sltu $s1, $k1, $s2;
		 addu $s1, $s1, $k0" : : "d" (vec2_v), "d" (Scalar2) );
	
	for (i = (u32s_per_bignum - 1); i; i--)
	{
	
		result_v = *(Result + 1);
		
		vec1_v = *Vec1;
		Vec1++;
		
		vec2_v = *Vec2;
		Vec2++;
		
		// The following routine is equivalent to:
		// (mfhi,mflo) = vec1 * scalar1
		// temp = carryword1 + result1
		// carryword1 = 1 if overflow else 0 (based on result < operand if overflow)
		// k0 = mfhi + overflow
		// temp = temp + mflo (ie temp = mflo + carryword1 + result)
		// carryword1 = 1 if overflow else 0
		// carryword1 = a0 + overflow
		asm("multu %0,%1;
		     mfhi $k0;
		     mflo $k1;
			 addu $s2, $s0, %2;
			 sltu $s0, $s2, $s0;
			 addu $k0, $k0, $s0;
			 addu $s2, $s2, $k1;
			 sltu $s0, $s2, $k1;
			 addu $s0, $s0, $k0" : : "d" (vec1_v), "d" (Scalar1), "d" (result_v) );
		
		// The following routine is equivalent to:
		// (mfhi,mflo) = vec2 * scalar2
		// temp = carryword2 + temp
		// carryword2 = 1 if overflow else 0 (based on result < operand if overflow)
		// k0 = mfhi + overflow
		// temp = temp + mflo
		// carryword2 = 1 if overflow else 0
		// carryword2 = a0 + overflow
		asm("multu %0,%1;
		     mfhi $k0;
		     mflo $k1;
			 addu $s2, $s1, $s2;
			 sltu $s1, $s2, $s1;
			 addu $k0, $k0, $s1;
			 addu $s2, $s2, $k1;
			 sltu $s1, $s2, $k1;
			 addu $s1, $s1, $k0" : : "d" (vec2_v), "d" (Scalar2) );
		
		*Result = temp;
		Result++;
	}
	
	result_v = *(Result + 1);
	
	// The following routine is equivalent to:
	// temp = carryword1 + *(result + 1)
	// carry1 = 1 if overflow else 0
	// temp = temp + carryword2
	// carry2 = 1 if overflow else 0
	// overflow = carry1 + carry2 (with overflow = carryword1)
	asm("addu $s2, $s0, %0;
	     sltu $s0, $s2, $s0;
		 addu $s2, $s2, $s1;
		 sltu $s1, $s2, $s1;
		 addu $s0, $s0, $s1" : : "d" (result_v) );
	
	*Result = temp;
	Result++;
	*Result = CarryWord1;

}
*/


//----------------------------------------------------
//
// Name: MnMontMul
//       Multiplies two numbers together in Montgomery space
// Returns: void
//
// Parameter	Flow	Description
// ----------------------------------------------
// z			out		answer        (ptr to Bignum_t)
// x			in		multiplier    (ptr to Bignum_t)
// y			in		multiplicand  (ptr to Bignum_t)
// ModInfo		in		Derived (from modulus) constants * 
//
// Additionnal information:
// Modulus         IN      The Modulus - global p
// Modulus_Digits  IN      The number of base b digits in the modulus
//  
//
//   The logic is (remembering n is the number of base b digits in the
//                 modulus):
//  
//     Acc=0
//
//     For( i = 0 to (n-1))
//     {
//        U[i] = ((x[i]*y[0] + Acc[0]) * m')mod b
// 
//        Acc = (Acc + (x[i] * y)+ (U[i]* Modulus))/b
//     }
//     if Acc >= Modulus (Can't be >= 2* Modulus)
//        Acc -= Modulus
//     return Acc
//
//----------------------------------------------------

// Montgomery product of x and y, stored in z.
// The product is built in a local accumulator and z is only written at the very
// end, so z may be the same object as x and/or y (rsa_decode_256 relies on it).
void MnMontMul (Bignum_t *z, Bignum_t *x, Bignum_t *y, MONTGOMERY_CONSTS *ModInfo)
{
	unsigned int  Acc[u32s_per_bignum + 1];         // one extra word for the overflow
	unsigned int *acc_words = (unsigned int *) &Acc;
	Bignum_t     *acc_bn    = (Bignum_t *) &Acc;     // accumulator seen as a bignum
	int           word;
	int           digit;

	// Clear the accumulator by hand (no memset)
	for (word = 0; word < u32s_per_bignum + 1; word++)
	{
		acc_words[word] = 0;
	}

	// One pass per base 2^32 digit of the modulus
	for (digit = 0; digit < ModInfo->Modulus_Digits; digit++)
	{
		// u   = ((x[digit] * y[0] + Acc[0]) * m') mod b  (where b == 2^32)
		unsigned int u = (((x->slice[digit] * y->slice[0]) + acc_words[0]) * ModInfo->Mod_dash);

		// Acc = ((Acc + (y * x[digit]) + (modulus * u)) >> 32)
		DualMacWithShift (acc_words, (unsigned int *) y, x->slice[digit], (unsigned int *) &ModInfo->global_p, u);
	}

	// A non zero extra word (0 and 1 are the only possibilities) means the value
	// does not fit in a bignum: apply the Omura correction, and apply it a second
	// time if the first correction itself reports an overflow.
	if ((acc_words[u32s_per_bignum] != 0) && BnAcc (acc_bn, &ModInfo->Omura_corr))
	{
		BnAcc (acc_bn, &ModInfo->Omura_corr);
	}

	// From here on the accumulator is handled as a plain bignum.
	// An out of range value is reduced by adding the Omura correction (the
	// resulting overflow drops off the end) rather than by subtracting the
	// modulus, because BnSub is unlikely to be in the cache at this point.
	if (BnCmp (acc_bn, &ModInfo->global_p) >= 0)
	{
		BnAcc (acc_bn, &ModInfo->Omura_corr);
	}

	// Hand the result over to the caller
 #ifdef REMOVE_MEMCPY
	BnCpy((unsigned long*)z, (unsigned long*)&Acc);
 #else
	memcpy (z, &Acc, (u32s_per_bignum == U32s_PER_BIGNUM) ? BYTES_PER_BIGNUM : BYTES_PER_1KBITS);
 #endif
}


//----------------------------------------------------
//
// Name: MnConvToMont
//       Converts a number into Mont. space
// Returns: void
//
// Parameter	Flow	Description
// ----------------------------------------------
// out			out		The converted number
// in			in		The number to be converted (reduced Mod Modulus)
// ModInfo		in		Derived (from modulus) constants
//
// Additionnal information:
// Mod Info contains (among other things) 
// Modulus			The Modulus (any length) - global p
// R_power			(R=2^R_power)
// Omura_corr 		Omura correction factor.
//
// out = in * R mod Modulus
//
// Can be used to do conversions in place (i.e. out = in)
//
// This is really a dedicated one-bit multiplier. If the modulus
// is not a maximum number of bits number then we reduce every time we double.
// If modulus  is a max number of bits number we can get overflows so we do
// Omura corrections. 
//
//----------------------------------------------------

// Converts a number into Montgomery space: out = in * R mod Modulus, where
// R = 2^(ModInfo->R_power). The value is doubled R_power times and brought back
// into range after every doubling.
//
// out      out  the converted number
// in       in   the number to convert (expected to be already reduced mod Modulus)
// ModInfo  in   constants derived from the modulus (R_power, Omura_corr, global_p)
//
// out may be the same object as in (conversion in place).
void MnConvToMont (Bignum_t *out, Bignum_t *in, MONTGOMERY_CONSTS *ModInfo)
{
	int shifts = ModInfo->R_power;

	// For an in-place conversion there is nothing to copy, and calling memcpy
	// with identical source and destination is undefined behaviour in C.
	if (out != in)
	{
 #ifdef REMOVE_MEMCPY
		BnCpy((unsigned long *)out, (unsigned long *)in);
 #else
		if(u32s_per_bignum == U32s_PER_BIGNUM)
			memcpy (out, in, BYTES_PER_BIGNUM); 
		else
			memcpy (out, in, BYTES_PER_1KBITS); 
 #endif
	}
  
	while (shifts > 0)
	{
		// Double the number
		int Carry = BnASL (out, out);

		// Do Omura Corrections....
		// if we double a number which is mod m and the answer 
		// is BITS_IN_Bignum_t+1 bits then the modulus is
		// BITS_IN_Bignum_t long so Omura will work...
		while (Carry != 0)
		{
			Carry = BnAcc (out, &ModInfo->Omura_corr);
		}

		shifts--;
      
		// Reduce our number mod m if the modulus is too small for Omura corrections
		// This is done EVERY time we go around the loop
		if (BnCmp (out, &ModInfo->global_p) >= 0)
		{
			BnDec (out, &ModInfo->global_p);
		}
	}
}
 
 
 
//----------------------------------------------------
//
// Name: MnInvModTwoToTheThirtyTwo
//       Specialist algorithm that returns the inverse of a number mod 2^32
//       (a ULONG is 32 bits long)
// Returns: the inverse
//
// Parameter	Flow	Description
// ----------------------------------------------
// x			in		The number to invert (mod 2^32)
//
// Additionnal information:
// This algorithm is based on the one shown on page 60 of the 
// "High Speed RSA Implementation Document" obtainable at:
// ftp://ftp.rsa.com/pub/pdfs/tr201.pdf
//
// It is based on the fact that for any number T:
// T mod 2^i = T mod 2^(i-1) OR 
//			(T mod 2^(i-1)) + 2^(i-1)
//
// Therefore T^-1 mod 2^i = T^-1 mod 2^(i-1) OR
//						 ((T^-1) mod 2^(i-1)) + 2^(i-1)
// (Test one and if it's not that one it must be the other!)
//
//----------------------------------------------------

// Returns the multiplicative inverse of x modulo 2^32, built one bit at a time:
// once the inverse is correct modulo 2^i, the inverse modulo 2^(i+1) is either
// the same value or that value with bit i set.
// x must be odd (an even number has no inverse modulo a power of two).
unsigned int MnInvModTwoToTheThirtyTwo (unsigned int x)
{
	unsigned int inverse = 1;	// correct modulo 2 for any odd x
	unsigned int bit = 2;		// weight of the bit being decided

	while (bit != 0)
	{
		// All bits up to and including the one being decided
		unsigned int mask = (bit << 1) - 1;

		// If the candidate fails modulo 2*bit, setting this bit repairs it
		if (((x * inverse) & mask) != 1)
		{
			inverse += bit;
		}

		bit <<= 1;
	}

	return (inverse);
}


