Hello I'm working on a project on Java Card which implies a lot of modulo-multiplication. I managed to implement an modulo-multiplication on this platform using RSA cryptosystem but it seems to work for certain numbers.
public byte[] modMultiply(byte[] x, short xOffset, short xLength, byte[] y,
short yOffset, short yLength, short tempOutoffset) {
//copy x value to temporary rambuffer
Util.arrayCopy(x, xOffset, tempBuffer, tempOutoffset, xLength);
// copy the y value to match th size of rsa_object
Util.arrayFillNonAtomic(eempromTempBuffer, (short)0, (byte) (Configuration.LENGTH_RSAOBJECT_MODULUS-1),(byte)0x00);
Util.arrayCopy(y,yOffset,eempromTempBuffer,(short)(Configuration.LENGTH_RSAOBJECT_MODULUS - yLength),yLength);
// x+y
if (JBigInteger.add(x,xOffset,xLength, eempromTempBuffer,
(short)0,Configuration.LENGTH_MODULUS)) ;
if(this.isGreater(x, xOffset, xLength, tempBuffer,Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS)>0)
{
JBigInteger.subtract(x,xOffset,xLength, tempBuffer,
Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
}
//(x+y)2
mRsaCipherForSquaring.init(mRsaPublicKekForSquare, Cipher.MODE_ENCRYPT);
mRsaCipherForSquaring.doFinal(x, xOffset, Configuration.LENGTH_RSAOBJECT_MODULUS, x,
xOffset); // OK
mRsaCipherForSquaring.doFinal(tempBuffer, tempOutoffset, Configuration.LENGTH_RSAOBJECT_MODULUS, tempBuffer, tempOutoffset); // OK
if (JBigInteger.subtract(x, xOffset, Configuration.LENGTH_MODULUS, tempBuffer, tempOutoffset,
Configuration.LENGTH_MODULUS)) {
JBigInteger.add(x, xOffset, Configuration.LENGTH_MODULUS, tempBuffer,
Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
}
mRsaCipherForSquaring.doFinal(eempromTempBuffer, yOffset, Configuration.LENGTH_RSAOBJECT_MODULUS, eempromTempBuffer, yOffset); //OK
if (JBigInteger.subtract(x, xOffset, Configuration.LENGTH_MODULUS, eempromTempBuffer, yOffset,
Configuration.LENGTH_MODULUS)) {
JBigInteger.add(x, xOffset, Configuration.LENGTH_MODULUS, tempBuffer,
Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
}
// ((x+y)^2 - x^2 -y^2)/2
JBigInteger.modular_division_by_2(x, xOffset,Configuration. LENGTH_MODULUS, tempBuffer, Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
return x;
}
public static boolean add(byte[] x, short xOffset, short xLength, byte[] y,
short yOffset, short yLength) {
short digit_mask = 0xff;
short digit_len = 0x08;
short result = 0;
short i = (short) (xLength + xOffset - 1);
short j = (short) (yLength + yOffset - 1);
for (; i >= xOffset; i--, j--) {
result = (short) (result + (short) (x[i] & digit_mask) + (short) (y[j] & digit_mask));
x[i] = (byte) (result & digit_mask);
result = (short) ((result >> digit_len) & digit_mask);
}
while (result > 0 && i >= xOffset) {
result = (short) (result + (short) (x[i] & digit_mask));
x[i] = (byte) (result & digit_mask);
result = (short) ((result >> digit_len) & digit_mask);
i--;
}
return result != 0;
}
public static boolean subtract(byte[] x, short xOffset, short xLength, byte[] y,
short yOffset, short yLength) {
short digit_mask = 0xff;
short i = (short) (xLength + xOffset - 1);
short j = (short) (yLength + yOffset - 1);
short carry = 0;
short subtraction_result = 0;
for (; i >= xOffset && j >= yOffset; i--, j--) {
subtraction_result = (short) ((x[i] & digit_mask)
- (y[j] & digit_mask) - carry);
x[i] = (byte) (subtraction_result & digit_mask);
carry = (short) (subtraction_result < 0 ? 1 : 0);
}
for (; i >= xOffset && carry > 0; i--) {
if (x[i] != 0)
carry = 0;
x[i] -= 1;
}
return carry > 0;
}
public short isGreater(byte[] x,short xOffset,short xLength,byte[] y ,short yOffset,short yLength)
{
if(xLength > yLength)
return (short)1;
if(xLength < yLength)
return (short)(-1);
short digit_mask = 0xff;
short digit_len = 0x08;
short result = 0;
short i = (short) (xLength + xOffset - 1);
short j = (short) (yLength + yOffset - 1);
for (; i >= xOffset; i--, j--) {
result = (short) (result + (short) (x[i] & digit_mask) - (short) (y[j] & digit_mask));
if(result > 0)
return (short)1;
if(result < 0)
return (short)-1;
}
return 0;
}
The code works well for little number but fails on bigger one
Below is a very simple unit test with a (hopefully) working variant of your code:
package test.java.so;
import java.math.BigInteger;
import java.util.Random;
import javacard.framework.JCSystem;
import javacard.framework.Util;
import javacard.security.KeyBuilder;
import javacard.security.RSAPublicKey;
import javacardx.crypto.Cipher;
import org.apache.commons.lang3.ArrayUtils;
import org.bouncycastle.util.Arrays;
import org.junit.Assert;
import org.junit.Test;
import sutil.test.AbstractTest;
public class So36966764_Test extends AbstractTest {
private static final int NUM_BITS = 1024;
// Dummy
static class Configuration {
public static final short LENGTH_MODULUS = NUM_BITS/8;
public static final short LENGTH_RSAOBJECT_MODULUS = LENGTH_MODULUS;
public static final short TEMP_OFFSET_MODULUS = 0;
public static final short TEMP_OFFSET_RESULT = LENGTH_MODULUS;
}
private byte[] tempBuffer = JCSystem.makeTransientByteArray((short)(Configuration.TEMP_OFFSET_RESULT+Configuration.LENGTH_MODULUS), JCSystem.CLEAR_ON_DESELECT);
private byte[] eempromTempBuffer = new byte[Configuration.LENGTH_MODULUS]; // Why EEPROM?
private RSAPublicKey mRsaPublicKekForSquare = (RSAPublicKey)KeyBuilder.buildKey(KeyBuilder.TYPE_RSA_PUBLIC, (short)NUM_BITS, false);
private Cipher mRsaCipherForSquaring = Cipher.getInstance(Cipher.ALG_RSA_NOPAD, false);
// Assuming xLength==yLength==LENGTH_MODULUS
public byte[] modMultiply(byte[] x, short xOffset, short xLength, byte[] y, short yOffset, short yLength, short tempOutoffset) {
//copy x value to temporary rambuffer
Util.arrayCopy(x, xOffset, tempBuffer, tempOutoffset, xLength);
// copy the y value to match th size of rsa_object
Util.arrayFillNonAtomic(eempromTempBuffer, (short)0, (short) (Configuration.LENGTH_RSAOBJECT_MODULUS-1),(byte)0x00);
Util.arrayCopy(y,yOffset,eempromTempBuffer,(short)(Configuration.LENGTH_RSAOBJECT_MODULUS - yLength),yLength);
// x+y
if(add(x,xOffset,xLength, eempromTempBuffer, (short)0,Configuration.LENGTH_MODULUS)) {
subtract(x,xOffset,xLength, tempBuffer, Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
}
while(isGreater(x, xOffset, xLength, tempBuffer,Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS)>0) {
subtract(x,xOffset,xLength, tempBuffer,Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
}
//(x+y)2
mRsaCipherForSquaring.init(mRsaPublicKekForSquare, Cipher.MODE_ENCRYPT);
mRsaCipherForSquaring.doFinal(x, xOffset, Configuration.LENGTH_RSAOBJECT_MODULUS, x, xOffset); // OK
mRsaCipherForSquaring.doFinal(tempBuffer, tempOutoffset, Configuration.LENGTH_RSAOBJECT_MODULUS, tempBuffer, tempOutoffset); // OK
if (subtract(x, xOffset, Configuration.LENGTH_MODULUS, tempBuffer, tempOutoffset,
Configuration.LENGTH_MODULUS)) {
add(x, xOffset, Configuration.LENGTH_MODULUS, tempBuffer,
Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
}
/*WRONG OFFSET mRsaCipherForSquaring.doFinal(eempromTempBuffer, yOffset, Configuration.LENGTH_RSAOBJECT_MODULUS, eempromTempBuffer, yOffset); */
mRsaCipherForSquaring.doFinal(eempromTempBuffer, (short)0, Configuration.LENGTH_RSAOBJECT_MODULUS, eempromTempBuffer, (short)0); //OK
/*WRONG OFFSET if (subtract(x, xOffset, Configuration.LENGTH_MODULUS, eempromTempBuffer, yOffset,*/
if (subtract(x, xOffset, Configuration.LENGTH_MODULUS, eempromTempBuffer, (short)0,Configuration.LENGTH_MODULUS)) {
add(x, xOffset, Configuration.LENGTH_MODULUS, tempBuffer,
Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
}
// ((x+y)^2 - x^2 -y^2)/2
modular_division_by_2(x, xOffset,Configuration. LENGTH_MODULUS, tempBuffer, Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
return x;
}
public static boolean add(byte[] x, short xOffset, short xLength, byte[] y, short yOffset, short yLength) {
short digit_mask = 0xff;
short digit_len = 0x08;
short result = 0;
short i = (short) (xLength + xOffset - 1);
short j = (short) (yLength + yOffset - 1);
for (; i >= xOffset; i--, j--) {
result = (short) (result + (short) (x[i] & digit_mask) + (short) (y[j] & digit_mask));
x[i] = (byte) (result & digit_mask);
result = (short) ((result >> digit_len) & digit_mask);
}
while (result > 0 && i >= xOffset) {
result = (short) (result + (short) (x[i] & digit_mask));
x[i] = (byte) (result & digit_mask);
result = (short) ((result >> digit_len) & digit_mask);
i--;
}
return result != 0;
}
public static boolean subtract(byte[] x, short xOffset, short xLength, byte[] y, short yOffset, short yLength) {
short digit_mask = 0xff;
short i = (short) (xLength + xOffset - 1);
short j = (short) (yLength + yOffset - 1);
short carry = 0;
short subtraction_result = 0;
for (; i >= xOffset && j >= yOffset; i--, j--) {
subtraction_result = (short) ((x[i] & digit_mask)
- (y[j] & digit_mask) - carry);
x[i] = (byte) (subtraction_result & digit_mask);
carry = (short) (subtraction_result < 0 ? 1 : 0);
}
for (; i >= xOffset && carry > 0; i--) {
if (x[i] != 0)
carry = 0;
x[i] -= 1;
}
return carry > 0;
}
public short isGreater(byte[] x,short xOffset,short xLength,byte[] y ,short yOffset,short yLength)
{
// Beware: this part is not tested
while(xLength>yLength) {
if(x[xOffset++]!=0x00) {
return 1; // x is greater
}
xLength--;
}
while(yLength>xLength) {
if(y[yOffset++]!=0x00) {
return -1; // y is greater
}
yLength--;
}
// Beware: this part is not tested END
for(short i = 0; i < xLength; i++) {
if (x[xOffset] != y[yOffset]) {
short srcShort = (short)(x[xOffset]&(short)0xFF);
short dstShort = (short)(y[yOffset]&(short)0xFF);
return ( ((srcShort > dstShort) ? (byte)1 : (byte)-1));
}
xOffset++;
yOffset++;
}
return 0;
}
private void modular_division_by_2(byte[] input, short inOffset, short inLength, byte[] modulus, short modulusOffset, short modulusLength) {
short carry = 0;
short digit_mask = 0xff;
short digit_first_bit_mask = 0x80;
short lastIndex = (short) (inOffset + inLength - 1);
short i = inOffset;
if ((byte) (input[lastIndex] & 0x01) != 0) {
if (add(input, inOffset, inLength, modulus, modulusOffset,
modulusLength)) {
carry = digit_first_bit_mask;
}
}
for (; i <= lastIndex; i++) {
if ((input[i] & 0x01) == 0) {
input[i] = (byte) (((input[i] & digit_mask) >> 1) | carry);
carry = 0;
} else {
input[i] = (byte) (((input[i] & digit_mask) >> 1) | carry);
carry = digit_first_bit_mask;
}
}
}
@Test
public void testModMultiply() {
Random r = new Random(12345L);
for(int iiii=0;iiii<10;iiii++) {
BigInteger modulus = BigInteger.probablePrime(NUM_BITS, r);
System.out.println(" M = " + modulus);
byte[] modulusBytes = normalize(modulus.toByteArray());
Util.arrayCopyNonAtomic(modulusBytes, (short)0, tempBuffer, Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
mRsaPublicKekForSquare.setModulus(modulusBytes, (short)0, (short)modulusBytes.length);
mRsaPublicKekForSquare.setExponent(new byte[] {0x02}, (short)0, (short)1);
for(int iii=0;iii<1000;iii++) {
BigInteger x = new BigInteger(NUM_BITS, r).mod(modulus);
System.out.println(" x = " + x);
BigInteger y = new BigInteger(NUM_BITS, r).mod(modulus);
System.out.println(" y = " + y);
BigInteger accResult;
{
byte[] xBytes = normalize(x.toByteArray());
byte[] yBytes = normalize(y.toByteArray());
byte[] accResultBytes = modMultiply(xBytes, (short)0, (short)xBytes.length, yBytes, (short)0, (short)yBytes.length, Configuration.TEMP_OFFSET_RESULT);
accResult = new BigInteger(1, accResultBytes);
}
System.out.println(" Qr = " + accResult);
BigInteger realResult = x.multiply(y).mod(modulus);
System.out.println(" Rr = " + realResult);
Assert.assertEquals(realResult, accResult);
}
}
}
private byte[] normalize(byte[] xBytes) {
if(xBytes.length<Configuration.LENGTH_MODULUS) {
xBytes = ArrayUtils.addAll(new byte[Configuration.LENGTH_MODULUS-xBytes.length], xBytes);
}
if(xBytes.length>Configuration.LENGTH_MODULUS) {
Assert.assertEquals(xBytes[0], 0x00);
xBytes=Arrays.copyOfRange(xBytes, 1, xBytes.length);
}
return xBytes;
}
}
What was (IMHO) wrong:
The
isGreater()
method -- although it is possible to use subtraction to compare numbers, it is much easier (and faster) to compare corresponding bytes starting from the most significant one and stop on the first mismatch. (In the subtraction case you would need to complete the subtraction and return the sign of the final result -- your original code ends on first "mismatch".)x+y
overflow -- you should have kept the modulus subtraction for the addition overflow case (i.e. whenadd()
returns true) in your last edit.Offsets into
eempromTempBuffer
-- on two places you usedyOffset
and should have used0
(commented out with a "WRONG OFFSET").Casting
Configuration.LENGTH_RSAOBJECT_MODULUS-1
tobyte
is not a good idea for larger values of modulus length
Some (random) comments:
the test uses already mentioned jcardsim to work
the code assumes that lengths of
x
andy
are bothLENGTH_MODULUS
(as well asLENGTH_RSAOBJECT_MODULUS
being equal toLENGTH_MODULUS
)it is not a good idea to have
eempromTempBuffer
in a non-volatile memoryyour code is VERY similar to this code which is interesting
an interesting read regarding this topic is here (section 4.2.3).
Good luck!
Disclaimer: I am not a crypto expert nor mathematician so please do validate my thoughts
I managed to solve the problem by changing the mathematical formula of multiplication.I posted below the updated code.
private byte[] multiply(byte[] x, short xOffset, short xLength, byte[] y,
short yOffset, short yLength,short tempOutoffset)
{
normalize();
//copy x value to temporary rambuffer
Util.arrayFillNonAtomic(tempBuffer, tempOutoffset,(short) (Configuration.LENGTH_RSAOBJECT_MODULUS+tempOutoffset),(byte)0x00);
Util.arrayCopy(x, xOffset, tempBuffer, (short)(Configuration.LENGTH_RSAOBJECT_MODULUS - xLength), xLength);
// copy the y value to match th size of rsa_object
Util.arrayFillNonAtomic(ram_y, IConsts.OFFSET_START, (short) (Configuration.LENGTH_RSAOBJECT_MODULUS-1),(byte)0x00);
Util.arrayCopy(y,yOffset,ram_y,(short)(Configuration.LENGTH_RSAOBJECT_MODULUS - yLength),yLength);
Util.arrayFillNonAtomic(ram_y_prime, IConsts.OFFSET_START, (short) (Configuration.LENGTH_RSAOBJECT_MODULUS-1),(byte)0x00);
Util.arrayCopy(y,yOffset,ram_y_prime,(short)(Configuration.LENGTH_RSAOBJECT_MODULUS - yLength),yLength);
Util.arrayFillNonAtomic(ram_x, IConsts.OFFSET_START, (short) (Configuration.LENGTH_RSAOBJECT_MODULUS-1),(byte)0x00);
Util.arrayCopy(x,xOffset,ram_x,(short)(Configuration.LENGTH_RSAOBJECT_MODULUS - xLength),xLength);
// if x>y
if(this.isGreater(ram_x, IConsts.OFFSET_START, Configuration.LENGTH_RSAOBJECT_MODULUS, ram_y,IConsts.OFFSET_START, Configuration.LENGTH_MODULUS)>0)
{
// x <- x-y
JBigInteger.subtract(ram_x,IConsts.OFFSET_START,Configuration.LENGTH_RSAOBJECT_MODULUS, ram_y,
IConsts.OFFSET_START, Configuration.LENGTH_RSAOBJECT_MODULUS);
}
else
{
// y <- y-x
JBigInteger.subtract(ram_y_prime,IConsts.OFFSET_START,Configuration.LENGTH_RSAOBJECT_MODULUS, ram_x,
IConsts.OFFSET_START, Configuration.LENGTH_MODULUS);
// ramy stores the (y-x) values copy value to ram_x
Util.arrayCopy(ram_y_prime, IConsts.OFFSET_START,ram_x,IConsts.OFFSET_START,Configuration.LENGTH_RSAOBJECT_MODULUS);
}
//|x-y|2
mRsaCipherForSquaring.init(mRsaPublicKekForSquare, Cipher.MODE_ENCRYPT);
mRsaCipherForSquaring.doFinal(ram_x, IConsts.OFFSET_START, Configuration.LENGTH_RSAOBJECT_MODULUS, ram_x,
IConsts.OFFSET_START); // OK
// x^2
mRsaCipherForSquaring.doFinal(tempBuffer, tempOutoffset, Configuration.LENGTH_RSAOBJECT_MODULUS, tempBuffer, tempOutoffset); // OK
// y^2
mRsaCipherForSquaring.doFinal(ram_y,IConsts.OFFSET_START, Configuration.LENGTH_RSAOBJECT_MODULUS, ram_y,IConsts.OFFSET_START); //OK
if (JBigInteger.add(ram_y, IConsts.OFFSET_START, Configuration.LENGTH_MODULUS, tempBuffer, tempOutoffset,
Configuration.LENGTH_MODULUS)) {
// y^2 + x^2
JBigInteger.subtract(ram_y, IConsts.OFFSET_START, Configuration.LENGTH_MODULUS, tempBuffer,
Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
}
// x^2 + y^2
if (JBigInteger.subtract(ram_y, IConsts.OFFSET_START, Configuration.LENGTH_MODULUS, ram_x, IConsts.OFFSET_START,
Configuration.LENGTH_MODULUS)) {
JBigInteger.add(ram_y, IConsts.OFFSET_START, Configuration.LENGTH_MODULUS, tempBuffer,
Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
}
// (x^2 + y^2 - (x-y)^2)/2
JBigInteger.modular_division_by_2(ram_y, IConsts.OFFSET_START,Configuration. LENGTH_MODULUS, tempBuffer, Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
return ram_y;
}
The problem was that for some numbers for same numbers a
and b
on 1024 bits the sum a+b
overcome the value p
of the modulus.In above code I subtract from a+b
the p
value in order to make the RSA functioning.But this thing is not mathematically correct because (a+b)^2 mod p
is different from ((a+b) mod p)^2 mod p
. By changing the formula from ((x+y)^2 -x^2 -y^2)/2
to (x^2 + y^2 - (x-y)^2)/2
I was sure I will never have overflow because a-b
is smaller than p
. Based on link above I changed the code moving all the operations in RAM.
来源:https://stackoverflow.com/questions/36966764/using-rsa-for-modulo-multiplication-leads-to-error-on-java-card