Using RSA for modulo-multiplication leads to error on Java Card

别那么骄傲 2021-01-01 03:13

Hello, I'm working on a Java Card project that involves a lot of modulo-multiplication. I managed to implement a modulo-multiplication on this platform using RSA cryptos

2 Answers
  • 2021-01-01 04:04

    I managed to solve the problem by changing the mathematical formula used for the multiplication. The updated code is posted below.

    private byte[] multiply(byte[] x, short xOffset, short xLength, byte[] y,
            short yOffset, short yLength,short tempOutoffset)
    {
        normalize();
        // Copy x, left-padded with zeros, into the temporary RAM buffer at tempOutoffset.
        // The third argument of arrayFillNonAtomic is a length, not an end offset.
        Util.arrayFillNonAtomic(tempBuffer, tempOutoffset, Configuration.LENGTH_RSAOBJECT_MODULUS, (byte)0x00);
        Util.arrayCopy(x, xOffset, tempBuffer, (short)(tempOutoffset + Configuration.LENGTH_RSAOBJECT_MODULUS - xLength), xLength);
    
        // copy the y value, left-padded to match the size of rsa_object
        Util.arrayFillNonAtomic(ram_y, IConsts.OFFSET_START, (short) (Configuration.LENGTH_RSAOBJECT_MODULUS-1),(byte)0x00);
        Util.arrayCopy(y,yOffset,ram_y,(short)(Configuration.LENGTH_RSAOBJECT_MODULUS - yLength),yLength);
    
        Util.arrayFillNonAtomic(ram_y_prime, IConsts.OFFSET_START, (short) (Configuration.LENGTH_RSAOBJECT_MODULUS-1),(byte)0x00);
        Util.arrayCopy(y,yOffset,ram_y_prime,(short)(Configuration.LENGTH_RSAOBJECT_MODULUS - yLength),yLength);
    
        Util.arrayFillNonAtomic(ram_x, IConsts.OFFSET_START, (short) (Configuration.LENGTH_RSAOBJECT_MODULUS-1),(byte)0x00);
        Util.arrayCopy(x,xOffset,ram_x,(short)(Configuration.LENGTH_RSAOBJECT_MODULUS - xLength),xLength);
    
        // if x>y
        if(this.isGreater(ram_x, IConsts.OFFSET_START, Configuration.LENGTH_RSAOBJECT_MODULUS, ram_y,IConsts.OFFSET_START, Configuration.LENGTH_MODULUS)>0)
        {
    
            // x <- x-y
            JBigInteger.subtract(ram_x,IConsts.OFFSET_START,Configuration.LENGTH_RSAOBJECT_MODULUS, ram_y,
                    IConsts.OFFSET_START, Configuration.LENGTH_RSAOBJECT_MODULUS);
        }
        else
        {
    
            // y <- y-x
            JBigInteger.subtract(ram_y_prime,IConsts.OFFSET_START,Configuration.LENGTH_RSAOBJECT_MODULUS, ram_x,
                    IConsts.OFFSET_START, Configuration.LENGTH_MODULUS);
            // ram_y_prime now holds (y-x); copy it to ram_x
            Util.arrayCopy(ram_y_prime, IConsts.OFFSET_START,ram_x,IConsts.OFFSET_START,Configuration.LENGTH_RSAOBJECT_MODULUS);
    
        }
    
        // |x-y|^2
        mRsaCipherForSquaring.init(mRsaPublicKekForSquare, Cipher.MODE_ENCRYPT);
        mRsaCipherForSquaring.doFinal(ram_x, IConsts.OFFSET_START, Configuration.LENGTH_RSAOBJECT_MODULUS, ram_x,
                IConsts.OFFSET_START); // OK
    
        // x^2
        mRsaCipherForSquaring.doFinal(tempBuffer, tempOutoffset, Configuration.LENGTH_RSAOBJECT_MODULUS, tempBuffer, tempOutoffset); // OK
    
        // y^2
        mRsaCipherForSquaring.doFinal(ram_y, IConsts.OFFSET_START, Configuration.LENGTH_RSAOBJECT_MODULUS, ram_y, IConsts.OFFSET_START); // OK
    
        // y^2 + x^2; on overflow, subtract the modulus
        if (JBigInteger.add(ram_y, IConsts.OFFSET_START, Configuration.LENGTH_MODULUS, tempBuffer, tempOutoffset,
                Configuration.LENGTH_MODULUS)) {
            JBigInteger.subtract(ram_y, IConsts.OFFSET_START, Configuration.LENGTH_MODULUS, tempBuffer,
                    Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
        }
    
        // (x^2 + y^2) - (x-y)^2; on underflow, add the modulus back
        if (JBigInteger.subtract(ram_y, IConsts.OFFSET_START, Configuration.LENGTH_MODULUS, ram_x, IConsts.OFFSET_START,
                Configuration.LENGTH_MODULUS)) {
            JBigInteger.add(ram_y, IConsts.OFFSET_START, Configuration.LENGTH_MODULUS, tempBuffer,
                    Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
        }
    
        // (x^2 + y^2 - (x-y)^2)/2
        JBigInteger.modular_division_by_2(ram_y, IConsts.OFFSET_START, Configuration.LENGTH_MODULUS, tempBuffer, Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
        return ram_y;
    }
    

    The problem was that for some 1024-bit numbers a and b, the sum a+b exceeds the modulus p. Previously I subtracted p from a+b to keep the RSA operation working. I considered this mathematically incorrect, reasoning that (a+b)^2 mod p differs from ((a+b) mod p)^2 mod p. By changing the formula from ((x+y)^2 - x^2 - y^2)/2 to (x^2 + y^2 - (x-y)^2)/2, I can be sure that no overflow occurs, because |x-y| is always smaller than p. Based on the link above, I also changed the code so that all operations are performed in RAM.
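
The rewritten formula can be checked off-card. Below is a minimal sketch in plain Java, assuming `java.math.BigInteger` stands in for the card's byte buffers and `modPow` with exponent 2 stands in for the RSA squaring. The class and method names (`SquareOnlyModMul`, `square`, `half`, `multiply`) are illustrative and not part of the code above, and the sketch assumes p is odd so that halving modulo p is possible.

```java
import java.math.BigInteger;
import java.util.Random;

final class SquareOnlyModMul {
    private static final BigInteger TWO = BigInteger.valueOf(2);

    // Squaring mod p; on the card this is the role of the RSA cipher with public exponent 2.
    static BigInteger square(BigInteger v, BigInteger p) {
        return v.modPow(TWO, p);
    }

    // Halving mod an odd p: if v is odd, add p to make it even, then shift right by one bit.
    static BigInteger half(BigInteger v, BigInteger p) {
        return (v.testBit(0) ? v.add(p) : v).shiftRight(1);
    }

    // x*y mod p computed as (x^2 + y^2 - (x-y)^2)/2 mod p, for 0 <= x, y < p.
    static BigInteger multiply(BigInteger x, BigInteger y, BigInteger p) {
        BigInteger d = x.subtract(y).abs();              // |x-y| < p, so it never needs reducing
        BigInteger sum = square(x, p).add(square(y, p)).mod(p);
        BigInteger twoXy = sum.subtract(square(d, p)).mod(p); // BigInteger.mod is never negative
        return half(twoXy, p);
    }

    public static void main(String[] args) {
        Random r = new Random(12345L);
        BigInteger p = BigInteger.probablePrime(1024, r);
        for (int i = 0; i < 100; i++) {
            BigInteger x = new BigInteger(1024, r).mod(p);
            BigInteger y = new BigInteger(1024, r).mod(p);
            if (!multiply(x, y, p).equals(x.multiply(y).mod(p))) {
                throw new AssertionError("mismatch at iteration " + i);
            }
        }
        System.out.println("formula agrees with BigInteger for all samples");
    }
}
```

Because |x-y| < p whenever 0 <= x, y < p, every value handed to the squaring step is already a valid input for the cipher.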

  • 2021-01-01 04:08

    Below is a very simple unit test with a (hopefully) working variant of your code:

    package test.java.so;
    
    import java.math.BigInteger;
    import java.util.Random;
    
    import javacard.framework.JCSystem;
    import javacard.framework.Util;
    import javacard.security.KeyBuilder;
    import javacard.security.RSAPublicKey;
    import javacardx.crypto.Cipher;
    
    import org.apache.commons.lang3.ArrayUtils;
    import org.bouncycastle.util.Arrays;
    import org.junit.Assert;
    import org.junit.Test;
    
    import sutil.test.AbstractTest;
    
    public class So36966764_Test extends AbstractTest {
    
        private static final int NUM_BITS = 1024;
    
        // Dummy
        static class Configuration {
            public static final short LENGTH_MODULUS = NUM_BITS/8;
            public static final short LENGTH_RSAOBJECT_MODULUS = LENGTH_MODULUS;
            public static final short TEMP_OFFSET_MODULUS = 0;
            public static final short TEMP_OFFSET_RESULT = LENGTH_MODULUS;
        }
    
        private byte[] tempBuffer = JCSystem.makeTransientByteArray((short)(Configuration.TEMP_OFFSET_RESULT+Configuration.LENGTH_MODULUS), JCSystem.CLEAR_ON_DESELECT);
        private byte[] eempromTempBuffer = new byte[Configuration.LENGTH_MODULUS]; // Why EEPROM?
        private RSAPublicKey mRsaPublicKekForSquare = (RSAPublicKey)KeyBuilder.buildKey(KeyBuilder.TYPE_RSA_PUBLIC, (short)NUM_BITS, false);
        private Cipher mRsaCipherForSquaring = Cipher.getInstance(Cipher.ALG_RSA_NOPAD, false);
    
        // Assuming xLength==yLength==LENGTH_MODULUS
        public byte[] modMultiply(byte[] x, short xOffset, short xLength, byte[] y, short yOffset, short yLength, short tempOutoffset) {
    
            // copy the x value to the temporary RAM buffer
            Util.arrayCopy(x, xOffset, tempBuffer, tempOutoffset, xLength);
    
            // copy the y value, left-padded to match the size of rsa_object
            Util.arrayFillNonAtomic(eempromTempBuffer, (short)0, (short) (Configuration.LENGTH_RSAOBJECT_MODULUS-1),(byte)0x00);
            Util.arrayCopy(y,yOffset,eempromTempBuffer,(short)(Configuration.LENGTH_RSAOBJECT_MODULUS - yLength),yLength);
    
            // x+y
            if(add(x,xOffset,xLength, eempromTempBuffer, (short)0,Configuration.LENGTH_MODULUS)) {
                subtract(x,xOffset,xLength, tempBuffer, Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
            }
            while(isGreater(x, xOffset, xLength, tempBuffer,Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS)>0) {
                subtract(x,xOffset,xLength, tempBuffer,Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
            }
    
            // (x+y)^2
            mRsaCipherForSquaring.init(mRsaPublicKekForSquare, Cipher.MODE_ENCRYPT);
            mRsaCipherForSquaring.doFinal(x, xOffset, Configuration.LENGTH_RSAOBJECT_MODULUS, x, xOffset); // OK
    
            mRsaCipherForSquaring.doFinal(tempBuffer, tempOutoffset, Configuration.LENGTH_RSAOBJECT_MODULUS, tempBuffer, tempOutoffset); // OK
    
            if (subtract(x, xOffset, Configuration.LENGTH_MODULUS, tempBuffer, tempOutoffset,
                    Configuration.LENGTH_MODULUS)) {
                add(x, xOffset, Configuration.LENGTH_MODULUS, tempBuffer,
                        Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
            }
    
            /*WRONG OFFSET mRsaCipherForSquaring.doFinal(eempromTempBuffer, yOffset, Configuration.LENGTH_RSAOBJECT_MODULUS, eempromTempBuffer, yOffset); */
            mRsaCipherForSquaring.doFinal(eempromTempBuffer, (short)0, Configuration.LENGTH_RSAOBJECT_MODULUS, eempromTempBuffer, (short)0); //OK
    
            /*WRONG OFFSET if (subtract(x, xOffset, Configuration.LENGTH_MODULUS, eempromTempBuffer, yOffset,*/
            if (subtract(x, xOffset, Configuration.LENGTH_MODULUS, eempromTempBuffer, (short)0,Configuration.LENGTH_MODULUS)) {
                add(x, xOffset, Configuration.LENGTH_MODULUS, tempBuffer,
                        Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
            }
            // ((x+y)^2 - x^2 -y^2)/2
            modular_division_by_2(x, xOffset, Configuration.LENGTH_MODULUS, tempBuffer, Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
            return x;
        }
    
        public static boolean add(byte[] x, short xOffset, short xLength, byte[] y, short yOffset, short yLength) {
            short digit_mask = 0xff;
            short digit_len = 0x08;
            short result = 0;
            short i = (short) (xLength + xOffset - 1);
            short j = (short) (yLength + yOffset - 1);
    
            for (; i >= xOffset; i--, j--) {
                result = (short) (result + (short) (x[i] & digit_mask) + (short) (y[j] & digit_mask));
    
                x[i] = (byte) (result & digit_mask);
                result = (short) ((result >> digit_len) & digit_mask);
            }
            while (result > 0 && i >= xOffset) {
                result = (short) (result + (short) (x[i] & digit_mask));
                x[i] = (byte) (result & digit_mask);
                result = (short) ((result >> digit_len) & digit_mask);
                i--;
            }
    
            return result != 0;
        }
    
        public static boolean subtract(byte[] x, short xOffset, short xLength, byte[] y, short yOffset, short yLength) {
            short digit_mask = 0xff;
            short i = (short) (xLength + xOffset - 1);
            short j = (short) (yLength + yOffset - 1);
            short carry = 0;
            short subtraction_result = 0;
    
            for (; i >= xOffset && j >= yOffset; i--, j--) {
                subtraction_result = (short) ((x[i] & digit_mask)
                        - (y[j] & digit_mask) - carry);
                x[i] = (byte) (subtraction_result & digit_mask);
                carry = (short) (subtraction_result < 0 ? 1 : 0);
            }
            for (; i >= xOffset && carry > 0; i--) {
                if (x[i] != 0)
                    carry = 0;
                x[i] -= 1;
            }
    
            return carry > 0;
        }
    
        public short isGreater(byte[] x,short xOffset,short xLength,byte[] y ,short yOffset,short yLength)
        {
            // Beware: this part is not tested
            while(xLength>yLength) {
                if(x[xOffset++]!=0x00) {
                    return 1; // x is greater
                }
                xLength--;
            }
            while(yLength>xLength) {
                if(y[yOffset++]!=0x00) {
                    return -1; // y is greater
                }
                yLength--;
            }
            // Beware: this part is not tested END
            for(short i = 0; i < xLength; i++) {
                if (x[xOffset] != y[yOffset]) {
                    short srcShort = (short)(x[xOffset]&(short)0xFF);
                    short dstShort = (short)(y[yOffset]&(short)0xFF);
                    return ( ((srcShort > dstShort) ? (byte)1 : (byte)-1));
                }
                xOffset++;
                yOffset++;
            }
            return 0;
        }
    
        private void modular_division_by_2(byte[] input, short inOffset, short inLength, byte[] modulus, short modulusOffset, short modulusLength) {
            short carry = 0;
            short digit_mask = 0xff;
            short digit_first_bit_mask = 0x80;
            short lastIndex = (short) (inOffset + inLength - 1);
    
            short i = inOffset;
            if ((byte) (input[lastIndex] & 0x01) != 0) {
                if (add(input, inOffset, inLength, modulus, modulusOffset,
                        modulusLength)) {
                    carry = digit_first_bit_mask;
                }
            }
    
            for (; i <= lastIndex; i++) {
                if ((input[i] & 0x01) == 0) {
                    input[i] = (byte) (((input[i] & digit_mask) >> 1) | carry);
                    carry = 0;
                } else {
                    input[i] = (byte) (((input[i] & digit_mask) >> 1) | carry);
                    carry = digit_first_bit_mask;
                }
            }
        }
    
        @Test
        public void testModMultiply() {
            Random r = new Random(12345L);
            for(int iiii=0;iiii<10;iiii++) {
                BigInteger modulus = BigInteger.probablePrime(NUM_BITS, r);
                System.out.println(" M = " + modulus);
                byte[] modulusBytes = normalize(modulus.toByteArray());
                Util.arrayCopyNonAtomic(modulusBytes, (short)0, tempBuffer, Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
    
                mRsaPublicKekForSquare.setModulus(modulusBytes, (short)0, (short)modulusBytes.length);
                mRsaPublicKekForSquare.setExponent(new byte[] {0x02}, (short)0, (short)1);
    
                for(int iii=0;iii<1000;iii++) {
                    BigInteger x = new BigInteger(NUM_BITS, r).mod(modulus);
                    System.out.println(" x = " + x);
                    BigInteger y = new BigInteger(NUM_BITS, r).mod(modulus);
                    System.out.println(" y = " + y);
                    BigInteger accResult;
                    {
                        byte[] xBytes = normalize(x.toByteArray());
                        byte[] yBytes = normalize(y.toByteArray());
                        byte[] accResultBytes = modMultiply(xBytes, (short)0, (short)xBytes.length, yBytes, (short)0, (short)yBytes.length, Configuration.TEMP_OFFSET_RESULT);
                        accResult = new BigInteger(1, accResultBytes);
                    }
                    System.out.println(" Qr = " + accResult);
                    BigInteger realResult = x.multiply(y).mod(modulus);
                    System.out.println(" Rr = " + realResult);
                    Assert.assertEquals(realResult, accResult);
                }
            }
        }
    
        private byte[] normalize(byte[] xBytes) {
            if(xBytes.length<Configuration.LENGTH_MODULUS) {
                xBytes = ArrayUtils.addAll(new byte[Configuration.LENGTH_MODULUS-xBytes.length], xBytes);
            }
            if(xBytes.length>Configuration.LENGTH_MODULUS) {
                Assert.assertEquals(xBytes[0], 0x00);
                xBytes=Arrays.copyOfRange(xBytes, 1, xBytes.length);
            }
            return xBytes;
        }
    }
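
The overflow handling after `add()` in `modMultiply` can be illustrated with a much smaller width. The sketch below is only an illustration under stated assumptions: it uses a one-byte "buffer" held in an `int`, the names `OverflowCorrection` and `addMod` are made up for this example, and it requires 0 <= x, y < p <= 255. It shows why subtracting the modulus after a carry-out gives the right value even though the carry bit itself was lost.

```java
final class OverflowCorrection {
    static final int WIDTH_MASK = 0xFF; // one-byte buffer, for readability only

    // Adds x and y in a fixed 8-bit width and reduces the result below p.
    static int addMod(int x, int y, int p) {
        int sum = x + y;
        boolean carryOut = sum > WIDTH_MASK;      // what add() signals by returning true
        int stored = sum & WIDTH_MASK;            // what is actually left in the buffer
        if (carryOut) {
            // Wrap-around subtraction: (x + y - 256 - p) mod 256 equals x + y - p,
            // because 0 < x + y - p < p <= 255.
            stored = (stored - p) & WIDTH_MASK;
        }
        while (stored >= p) {                     // no carry, but still not below the modulus
            stored -= p;
        }
        return stored;
    }

    public static void main(String[] args) {
        int p = 251;
        int[][] samples = { {200, 100}, {100, 100}, {200, 60}, {250, 3} };
        for (int[] s : samples) {
            int expected = (s[0] + s[1]) % p;
            if (addMod(s[0], s[1], p) != expected) {
                throw new AssertionError(s[0] + " + " + s[1]);
            }
        }
        System.out.println("fixed-width reduction matches (x + y) % p");
    }
}
```

The same argument carries over to the 128-byte buffers: after a carry-out, x + y - p is still smaller than p, so it fits in the buffer again.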
    

    What was (IMHO) wrong:

    1. The isGreater() method -- although it is possible to use subtraction to compare numbers, it is much easier (and faster) to compare corresponding bytes starting from the most significant one and stop on the first mismatch. (In the subtraction case you would need to complete the subtraction and return the sign of the final result -- your original code ends on first "mismatch".)

    2. x+y overflow -- you should have kept the modulus subtraction for the addition overflow case (i.e. when add() returns true) in your last edit.

    3. Offsets into eempromTempBuffer -- in two places you used yOffset where you should have used 0 (the old lines are commented out and marked "WRONG OFFSET").

    4. Casting Configuration.LENGTH_RSAOBJECT_MODULUS-1 to byte is not a good idea for larger modulus lengths: values above 127 do not fit in a signed byte.
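
To make point 4 concrete, here is a small plain-Java sketch of what the narrowing cast does to a length. The class and method names (`ByteCastDemo`, `lengthAsByte`, `lengthAsShort`) are hypothetical and exist only for this illustration; the modulus sizes are example values.

```java
final class ByteCastDemo {
    // Length of "modulus minus one byte" narrowed to a byte, as in the problematic cast.
    static byte lengthAsByte(int modulusBytes) {
        return (byte) (modulusBytes - 1);
    }

    // The same length kept as a short, which is what the Java Card Util methods expect.
    static short lengthAsShort(int modulusBytes) {
        return (short) (modulusBytes - 1);
    }

    public static void main(String[] args) {
        int[] modulusSizes = { 128, 256 }; // 1024-bit and 2048-bit moduli, in bytes
        for (int size : modulusSizes) {
            System.out.println(size + " bytes: as byte = " + lengthAsByte(size)
                    + ", as short = " + lengthAsShort(size));
        }
    }
}
```

With a 128-byte modulus the cast happens to be harmless, which may be why the problem goes unnoticed until a longer modulus is used; a negative length passed to a fill or copy routine would be rejected at run time.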

    Some (random) comments:

    • the test relies on the already mentioned jcardsim to run

    • the code assumes that lengths of x and y are both LENGTH_MODULUS (as well as LENGTH_RSAOBJECT_MODULUS being equal to LENGTH_MODULUS)

    • it is not a good idea to have eempromTempBuffer in a non-volatile memory

    • your code is VERY similar to this code which is interesting

    • an interesting read regarding this topic is here (section 4.2.3).

    Good luck!

    Disclaimer: I am neither a crypto expert nor a mathematician, so please do validate my thoughts.
