Using RSA for modular multiplication leads to an error on Java Card

Posted by 二次信任 on 2019-11-30 09:56:58

Below is a very simple unit test with a (hopefully) working variant of your code:

package test.java.so;

import java.math.BigInteger;
import java.util.Random;

import javacard.framework.JCSystem;
import javacard.framework.Util;
import javacard.security.KeyBuilder;
import javacard.security.RSAPublicKey;
import javacardx.crypto.Cipher;

import org.apache.commons.lang3.ArrayUtils;
import org.bouncycastle.util.Arrays;
import org.junit.Assert;
import org.junit.Test;

import sutil.test.AbstractTest;

/**
 * jcardsim-based unit test for a modular multiplication routine built on RSA.
 *
 * <p>With an RSA "public key" whose modulus is p and whose exponent is 2, an
 * {@code ALG_RSA_NOPAD} encryption computes a^2 mod p. The product is then obtained as
 * x*y mod p = ((x+y)^2 - x^2 - y^2) / 2 mod p.
 */
public class So36966764_Test extends AbstractTest {

    private static final int NUM_BITS = 1024;

    // Dummy stand-in for the applet's configuration constants.
    static class Configuration {
        public static final short LENGTH_MODULUS = NUM_BITS/8;
        public static final short LENGTH_RSAOBJECT_MODULUS = LENGTH_MODULUS;
        public static final short TEMP_OFFSET_MODULUS = 0;
        public static final short TEMP_OFFSET_RESULT = LENGTH_MODULUS;
    }

    // The test stores the modulus at [TEMP_OFFSET_MODULUS, +LENGTH_MODULUS);
    // [TEMP_OFFSET_RESULT, +LENGTH_MODULUS) is used as scratch space for a copy of x / x^2.
    private byte[] tempBuffer = JCSystem.makeTransientByteArray((short)(Configuration.TEMP_OFFSET_RESULT+Configuration.LENGTH_MODULUS), JCSystem.CLEAR_ON_DESELECT);
    // Holds the zero-padded y and later y^2.
    // NOTE: on a real card this belongs in transient (RAM) memory; EEPROM writes are slow and wear the chip.
    private byte[] eempromTempBuffer = new byte[Configuration.LENGTH_MODULUS];
    // RSA public key with modulus p and exponent 2 (configured in the test) used as a modular squarer.
    private RSAPublicKey mRsaPublicKekForSquare = (RSAPublicKey)KeyBuilder.buildKey(KeyBuilder.TYPE_RSA_PUBLIC, (short)NUM_BITS, false);
    private Cipher mRsaCipherForSquaring = Cipher.getInstance(Cipher.ALG_RSA_NOPAD, false);

    /**
     * Computes x*y mod p in place in {@code x}.
     *
     * <p>p is the modulus stored in {@code tempBuffer} at {@code TEMP_OFFSET_MODULUS}, and
     * {@code mRsaPublicKekForSquare} must be set to (p, 2). Assumes
     * xLength == yLength == LENGTH_MODULUS == LENGTH_RSAOBJECT_MODULUS and x, y &lt; p.
     *
     * @param tempOutoffset offset of a LENGTH_MODULUS-byte scratch area in {@code tempBuffer}
     *                      that must not overlap the modulus
     * @return {@code x}, now holding x*y mod p
     */
    public byte[] modMultiply(byte[] x, short xOffset, short xLength, byte[] y, short yOffset, short yLength, short tempOutoffset) {

        // Keep a copy of x in scratch RAM; it is squared later to obtain x^2.
        Util.arrayCopy(x, xOffset, tempBuffer, tempOutoffset, xLength);

        // Left-pad y with zeros to the RSA modulus size (clear the whole buffer first).
        Util.arrayFillNonAtomic(eempromTempBuffer, (short) 0, Configuration.LENGTH_RSAOBJECT_MODULUS, (byte) 0x00);
        Util.arrayCopy(y, yOffset, eempromTempBuffer, (short) (Configuration.LENGTH_RSAOBJECT_MODULUS - yLength), yLength);

        // x <- (x + y) mod p. On overflow past 2^(8*LENGTH), subtracting p (mod 2^(8*LENGTH)) yields x+y-p.
        if (add(x, xOffset, xLength, eempromTempBuffer, (short) 0, Configuration.LENGTH_MODULUS)) {
            subtract(x, xOffset, xLength, tempBuffer, Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
        }
        // RSA NOPAD input must be strictly smaller than the modulus, so reduce while x >= p
        // (the previous "> 0" left x == p unreduced).
        while (isGreater(x, xOffset, xLength, tempBuffer, Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS) >= 0) {
            subtract(x, xOffset, xLength, tempBuffer, Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
        }

        // x <- (x+y)^2 mod p
        mRsaCipherForSquaring.init(mRsaPublicKekForSquare, Cipher.MODE_ENCRYPT);
        mRsaCipherForSquaring.doFinal(x, xOffset, Configuration.LENGTH_RSAOBJECT_MODULUS, x, xOffset);

        // scratch <- x^2 mod p
        mRsaCipherForSquaring.doFinal(tempBuffer, tempOutoffset, Configuration.LENGTH_RSAOBJECT_MODULUS, tempBuffer, tempOutoffset);

        // x <- (x+y)^2 - x^2 mod p; both operands are < p, so one addition of p fixes a borrow.
        if (subtract(x, xOffset, Configuration.LENGTH_MODULUS, tempBuffer, tempOutoffset,
                Configuration.LENGTH_MODULUS)) {
            add(x, xOffset, Configuration.LENGTH_MODULUS, tempBuffer,
                    Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
        }

        // padded y <- y^2 mod p (the padded copy always starts at offset 0, not yOffset).
        mRsaCipherForSquaring.doFinal(eempromTempBuffer, (short) 0, Configuration.LENGTH_RSAOBJECT_MODULUS, eempromTempBuffer, (short) 0);

        // x <- (x+y)^2 - x^2 - y^2 mod p
        if (subtract(x, xOffset, Configuration.LENGTH_MODULUS, eempromTempBuffer, (short) 0, Configuration.LENGTH_MODULUS)) {
            add(x, xOffset, Configuration.LENGTH_MODULUS, tempBuffer,
                    Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
        }

        // x <- ((x+y)^2 - x^2 - y^2) / 2 mod p == x*y mod p
        modular_division_by_2(x, xOffset, Configuration.LENGTH_MODULUS, tempBuffer, Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
        return x;
    }

    /**
     * Big-endian unsigned addition x &lt;- x + y, truncated to xLength bytes.
     *
     * <p>If y is shorter than x, the carry is propagated into the remaining high bytes of x.
     * High bytes of y beyond xLength are ignored.
     *
     * @return true if a carry out of the most significant byte of x occurred
     */
    public static boolean add(byte[] x, short xOffset, short xLength, byte[] y, short yOffset, short yLength) {
        short carry = 0;
        short i = (short) (xOffset + xLength - 1);
        short j = (short) (yOffset + yLength - 1);

        // Add the overlapping bytes, least significant first; stop when either operand runs out
        // (previously y was indexed below yOffset when yLength < xLength).
        for (; i >= xOffset && j >= yOffset; i--, j--) {
            short sum = (short) ((x[i] & 0xFF) + (y[j] & 0xFF) + carry);
            x[i] = (byte) sum;
            carry = (short) ((sum >> 8) & 0xFF);
        }
        // Propagate any remaining carry through the higher bytes of x.
        for (; carry != 0 && i >= xOffset; i--) {
            short sum = (short) ((x[i] & 0xFF) + carry);
            x[i] = (byte) sum;
            carry = (short) ((sum >> 8) & 0xFF);
        }

        return carry != 0;
    }

    /**
     * Big-endian unsigned subtraction x &lt;- x - y, modulo 2^(8*xLength).
     *
     * @return true if a borrow out of the most significant byte of x occurred (i.e. x &lt; y)
     */
    public static boolean subtract(byte[] x, short xOffset, short xLength, byte[] y, short yOffset, short yLength) {
        short digit_mask = 0xff;
        short i = (short) (xLength + xOffset - 1);
        short j = (short) (yLength + yOffset - 1);
        short carry = 0;
        short subtraction_result = 0;

        // Subtract the overlapping bytes, least significant first, tracking the borrow.
        for (; i >= xOffset && j >= yOffset; i--, j--) {
            subtraction_result = (short) ((x[i] & digit_mask)
                    - (y[j] & digit_mask) - carry);
            x[i] = (byte) (subtraction_result & digit_mask);
            carry = (short) (subtraction_result < 0 ? 1 : 0);
        }
        // Propagate the borrow: a zero byte becomes 0xFF and keeps borrowing; a non-zero byte absorbs it.
        for (; i >= xOffset && carry > 0; i--) {
            if (x[i] != 0)
                carry = 0;
            x[i] -= 1;
        }

        return carry > 0;
    }

    /**
     * Compares two big-endian unsigned numbers, possibly of different lengths.
     *
     * @return 1 if x &gt; y, -1 if x &lt; y, 0 if equal
     */
    public short isGreater(byte[] x,short xOffset,short xLength,byte[] y ,short yOffset,short yLength)
    {
        // Excess leading bytes of the longer operand decide the result unless they are all zero.
        // Beware: this part is not tested
        while(xLength>yLength) {
            if(x[xOffset++]!=0x00) {
                return 1; // x is greater
            }
            xLength--;
        }
        while(yLength>xLength) {
            if(y[yOffset++]!=0x00) {
                return -1; // y is greater
            }
            yLength--;
        }
        // Beware: this part is not tested END
        // Equal lengths now: the first differing byte (most significant first) decides.
        for(short i = 0; i < xLength; i++) {
            if (x[xOffset] != y[yOffset]) {
                short srcShort = (short)(x[xOffset]&(short)0xFF);
                short dstShort = (short)(y[yOffset]&(short)0xFF);
                return (short) (srcShort > dstShort ? 1 : -1);
            }
            xOffset++;
            yOffset++;
        }
        return 0;
    }

    /**
     * Computes input &lt;- input / 2 mod p in place (p odd, input &lt; p).
     *
     * <p>If input is odd, p is added first so the value becomes even. A carry out of that
     * addition is shifted back in as the new top bit.
     */
    private void modular_division_by_2(byte[] input, short inOffset, short inLength, byte[] modulus, short modulusOffset, short modulusLength) {
        short carry = 0;
        short lastIndex = (short) (inOffset + inLength - 1);

        if ((input[lastIndex] & 0x01) != 0) {
            if (add(input, inOffset, inLength, modulus, modulusOffset, modulusLength)) {
                carry = 0x80;
            }
        }

        // Right-shift the whole big-endian number by one bit, most significant byte first;
        // each byte's low bit becomes the next byte's high bit.
        for (short i = inOffset; i <= lastIndex; i++) {
            short current = (short) (input[i] & 0xFF);
            input[i] = (byte) ((current >> 1) | carry);
            carry = (short) ((current & 0x01) != 0 ? 0x80 : 0);
        }
    }

    @Test
    public void testModMultiply() {
        Random r = new Random(12345L);
        for(int iiii=0;iiii<10;iiii++) {
            BigInteger modulus = BigInteger.probablePrime(NUM_BITS, r);
            System.out.println(" M = " + modulus);
            byte[] modulusBytes = normalize(modulus.toByteArray());
            Util.arrayCopyNonAtomic(modulusBytes, (short)0, tempBuffer, Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);

            // Exponent 2 turns the RSA public operation into modular squaring.
            mRsaPublicKekForSquare.setModulus(modulusBytes, (short)0, (short)modulusBytes.length);
            mRsaPublicKekForSquare.setExponent(new byte[] {0x02}, (short)0, (short)1);

            for(int iii=0;iii<1000;iii++) {
                BigInteger x = new BigInteger(NUM_BITS, r).mod(modulus);
                System.out.println(" x = " + x);
                BigInteger y = new BigInteger(NUM_BITS, r).mod(modulus);
                System.out.println(" y = " + y);
                BigInteger accResult;
                {
                    byte[] xBytes = normalize(x.toByteArray());
                    byte[] yBytes = normalize(y.toByteArray());
                    byte[] accResultBytes = modMultiply(xBytes, (short)0, (short)xBytes.length, yBytes, (short)0, (short)yBytes.length, Configuration.TEMP_OFFSET_RESULT);
                    accResult = new BigInteger(1, accResultBytes);
                }
                System.out.println(" Qr = " + accResult);
                BigInteger realResult = x.multiply(y).mod(modulus);
                System.out.println(" Rr = " + realResult);
                Assert.assertEquals(realResult, accResult);
            }
        }
    }

    /**
     * Converts a BigInteger two's-complement byte array to exactly LENGTH_MODULUS unsigned bytes.
     * Short arrays are left-padded with zeros. A single extra leading sign byte must be zero and is dropped.
     */
    private byte[] normalize(byte[] xBytes) {
        if(xBytes.length<Configuration.LENGTH_MODULUS) {
            xBytes = ArrayUtils.addAll(new byte[Configuration.LENGTH_MODULUS-xBytes.length], xBytes);
        }
        if(xBytes.length>Configuration.LENGTH_MODULUS) {
            Assert.assertEquals(0x00, xBytes[0]); // expected value first
            xBytes=Arrays.copyOfRange(xBytes, 1, xBytes.length);
        }
        return xBytes;
    }
}

What was (IMHO) wrong:

  1. The isGreater() method -- although it is possible to use subtraction to compare numbers, it is much easier (and faster) to compare corresponding bytes starting from the most significant one and stop on the first mismatch. (In the subtraction case you would need to complete the subtraction and return the sign of the final result -- your original code ends on first "mismatch".)

  2. x+y overflow -- you should have kept the modulus subtraction for the addition overflow case (i.e. when add() returns true) in your last edit.

  3. Offsets into eempromTempBuffer -- on two places you used yOffset and should have used 0 (commented out with a "WRONG OFFSET").

  4. Casting Configuration.LENGTH_RSAOBJECT_MODULUS-1 to byte is not a good idea for larger modulus lengths: a byte holds at most 127, so for moduli longer than 1024 bits (128 bytes) the cast overflows.

Some (random) comments:

  • the test relies on the already-mentioned jcardsim simulator

  • the code assumes that lengths of x and y are both LENGTH_MODULUS (as well as LENGTH_RSAOBJECT_MODULUS being equal to LENGTH_MODULUS)

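  • it is not a good idea to keep eempromTempBuffer in non-volatile memory (EEPROM writes are slow and wear out the memory; a transient RAM array is the usual choice for temporary data)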

  • your code is VERY similar to this code which is interesting

  • an interesting read regarding this topic is here (section 4.2.3).

Good luck!

Disclaimer: I am neither a crypto expert nor a mathematician, so please do validate my thoughts.

Alberto12

I managed to solve the problem by changing the mathematical formula used for the multiplication. I posted the updated code below.

/**
 * Computes x*y mod p using three RSA "squarings" (modulus p, exponent 2) and the identity
 * x*y = (x^2 + y^2 - (x-y)^2) / 2.
 *
 * <p>|x-y| is always smaller than p, so it is a valid RSA NOPAD input without a modular reduction.
 *
 * <p>Presumed preconditions (not checkable from this fragment):
 * <ul>
 *   <li>x, y &lt; p.</li>
 *   <li>xLength, yLength &lt;= LENGTH_RSAOBJECT_MODULUS == LENGTH_MODULUS.</li>
 *   <li>p is stored in tempBuffer at TEMP_OFFSET_MODULUS.</li>
 *   <li>[tempOutoffset, tempOutoffset + LENGTH) does not overlap p.</li>
 *   <li>mRsaPublicKekForSquare is set to (p, 2).</li>
 * </ul>
 *
 * @return ram_y, holding x*y mod p
 */
private byte[] multiply(byte[] x, short xOffset, short xLength, byte[] y,
        short yOffset, short yLength, short tempOutoffset)
{
    normalize(); // NOTE(review): defined elsewhere; its effect is not visible in this fragment
    final short len = Configuration.LENGTH_RSAOBJECT_MODULUS;

    // Scratch area in tempBuffer <- x, left-padded with zeros (squared in place later).
    // arrayFillNonAtomic takes a LENGTH, not an end index. The copy must land inside the
    // scratch area at tempOutoffset, otherwise it can overwrite the modulus.
    Util.arrayFillNonAtomic(tempBuffer, tempOutoffset, len, (byte) 0x00);
    Util.arrayCopy(x, xOffset, tempBuffer, (short) (tempOutoffset + len - xLength), xLength);

    // ram_y and ram_y_prime <- y, ram_x <- x, each left-padded to the RSA modulus size.
    Util.arrayFillNonAtomic(ram_y, IConsts.OFFSET_START, len, (byte) 0x00);
    Util.arrayCopy(y, yOffset, ram_y, (short) (IConsts.OFFSET_START + len - yLength), yLength);

    Util.arrayFillNonAtomic(ram_y_prime, IConsts.OFFSET_START, len, (byte) 0x00);
    Util.arrayCopy(y, yOffset, ram_y_prime, (short) (IConsts.OFFSET_START + len - yLength), yLength);

    Util.arrayFillNonAtomic(ram_x, IConsts.OFFSET_START, len, (byte) 0x00);
    Util.arrayCopy(x, xOffset, ram_x, (short) (IConsts.OFFSET_START + len - xLength), xLength);

    // ram_x <- |x - y|
    if (this.isGreater(ram_x, IConsts.OFFSET_START, len, ram_y, IConsts.OFFSET_START, Configuration.LENGTH_MODULUS) > 0) {
        // x > y: ram_x <- x - y
        JBigInteger.subtract(ram_x, IConsts.OFFSET_START, len, ram_y,
                IConsts.OFFSET_START, len);
    } else {
        // y >= x: ram_y_prime <- y - x, then copy it into ram_x
        JBigInteger.subtract(ram_y_prime, IConsts.OFFSET_START, len, ram_x,
                IConsts.OFFSET_START, Configuration.LENGTH_MODULUS);
        Util.arrayCopy(ram_y_prime, IConsts.OFFSET_START, ram_x, IConsts.OFFSET_START, len);
    }

    mRsaCipherForSquaring.init(mRsaPublicKekForSquare, Cipher.MODE_ENCRYPT);
    // ram_x <- (x - y)^2 mod p
    mRsaCipherForSquaring.doFinal(ram_x, IConsts.OFFSET_START, len, ram_x, IConsts.OFFSET_START);
    // scratch <- x^2 mod p
    mRsaCipherForSquaring.doFinal(tempBuffer, tempOutoffset, len, tempBuffer, tempOutoffset);
    // ram_y <- y^2 mod p
    mRsaCipherForSquaring.doFinal(ram_y, IConsts.OFFSET_START, len, ram_y, IConsts.OFFSET_START);

    // ram_y <- (x^2 + y^2) mod p. Both terms are < p, so the sum is < 2p. Subtract p once,
    // either on overflow or when the non-overflowing sum is still >= p. Reducing only on
    // overflow could leave the final result >= p.
    if (JBigInteger.add(ram_y, IConsts.OFFSET_START, Configuration.LENGTH_MODULUS, tempBuffer, tempOutoffset,
                Configuration.LENGTH_MODULUS)
            || this.isGreater(ram_y, IConsts.OFFSET_START, Configuration.LENGTH_MODULUS, tempBuffer,
                Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS) >= 0) {
        JBigInteger.subtract(ram_y, IConsts.OFFSET_START, Configuration.LENGTH_MODULUS, tempBuffer,
                Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
    }

    // ram_y <- (x^2 + y^2 - (x - y)^2) mod p; add p back on borrow.
    if (JBigInteger.subtract(ram_y, IConsts.OFFSET_START, Configuration.LENGTH_MODULUS, ram_x, IConsts.OFFSET_START,
            Configuration.LENGTH_MODULUS)) {
        JBigInteger.add(ram_y, IConsts.OFFSET_START, Configuration.LENGTH_MODULUS, tempBuffer,
                Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
    }

    // ram_y <- ram_y / 2 mod p == x*y mod p
    JBigInteger.modular_division_by_2(ram_y, IConsts.OFFSET_START, Configuration.LENGTH_MODULUS, tempBuffer,
            Configuration.TEMP_OFFSET_MODULUS, Configuration.LENGTH_MODULUS);
    return ram_y;
}

The problem was that, for some 1024-bit numbers a and b, the sum a+b exceeded the modulus p. In the code above I subtracted p from a+b to make the RSA operation work. But this is not mathematically correct, because (a+b)^2 mod p is different from ((a+b) mod p)^2 mod p. [Editor's note: in fact the two are always equal, so reducing a+b modulo p before squaring is valid. It is also required, since the RSA input must be smaller than the modulus.] By changing the formula from ((x+y)^2 - x^2 - y^2)/2 to (x^2 + y^2 - (x-y)^2)/2, I made sure I would never have an overflow, because |x-y| is always smaller than p. Based on the link above, I also changed the code to perform all the operations in RAM.

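All resources on 易学教程 come from the Internet or from user submissions; if any content violates the law, please report it.
Didn't this article solve your problem? Click to ask a question and describe your problem so more people can discuss it together!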