I'm looking for a fast method to efficiently compute $(a \cdot b) \bmod n$ (in the mathematical sense of that) for 64-bit unsigned integers $a$, $b$ and $n$ (`uint64_t`).
Ok, how about this (not tested)
modmul:
; Computes (a * b) mod n from the full 128-bit product in rdx:rax.
; Arguments (Microsoft x64 calling convention, as used by MSVC):
; rcx = a
; rdx = b
; r8 = n
; Result: the remainder, returned in rax.
; Precondition: n != 0 and the quotient (a * b) / n must fit in 64 bits
; (i.e. the high word of a * b must be < n); otherwise div raises a
; divide error (#DE).
mov rax, rdx        ; rax = b
mul rcx             ; rdx:rax = b * a (unsigned 64x64 -> 128-bit product)
div r8              ; rax = quotient, rdx = remainder of rdx:rax / n
mov rax, rdx        ; return the remainder
ret
The precondition is that `a * b / n <= ~0ULL`; otherwise there will be a divide error. That's a slightly less strict condition than `a < n && b < n`: one of them can be bigger than `n` as long as the other is small enough.
Unfortunately it has to be assembled and linked in separately, because MSVC doesn't support inline asm for 64-bit targets.
It's also still slow: the real problem is that 64-bit `div`, which can take nearly a hundred cycles (seriously, up to 90 cycles on Nehalem, for example).
You could do it the old-fashioned way with shift/add/subtract. The code below assumes $a < n$ and $n < 2^{63}$ (so things don't overflow):
/*
 * mulmod: compute (a * b) mod n with binary double-and-add multiplication,
 * never forming the full 128-bit product.
 *
 * Preconditions: n > 0 and n < 2^63, so that rv + a and a + a (each < 2n)
 * cannot overflow 64 bits. `a` and `b` may be any values; `a` is reduced
 * modulo n first.
 *
 * Returns a value in [0, n).
 */
uint64_t mulmod(uint64_t a, uint64_t b, uint64_t n) {
    uint64_t rv = 0;

    /* The single conditional subtractions below only keep rv and a in
     * [0, n) if a starts there; an unreduced a would leak into the result
     * (e.g. mulmod(10, 3, 7) would yield 9 instead of 2). */
    if (a >= n)
        a %= n;

    while (b) {
        if (b & 1) {
            rv += a;
            if (rv >= n)
                rv -= n;
        }
        a += a;             /* a = 2a mod n for the next bit of b */
        if (a >= n)
            a -= n;
        b >>= 1;
    }
    return rv;
}
You could use `while (a && b)` for the loop instead, to short-circuit things if it's likely that `a` becomes 0 modulo `n` as it is repeatedly doubled (this happens when `n` divides $a \cdot 2^k$ for some `k`). It will be slightly slower (more comparisons, though likely correctly predicted branches) when that does not happen.
If you really, absolutely need that last bit (allowing $n$ up to $2^{64}-1$), you can use:
/*
 * mulmod: compute (a * b) mod n for any n > 0, including n up to 2^64 - 1,
 * with binary double-and-add multiplication (no 128-bit product needed).
 *
 * When n >= 2^63, rv + a and a + a can exceed 2^64. A wrap-around is
 * detected by the 64-bit sum being smaller than an operand. The true sum is
 * then (wrapped + 2^64) < 2n, so subtracting n (mod 2^64) still yields the
 * correct residue in [0, n).
 *
 * `a` and `b` may be any values; `a` is reduced modulo n first.
 */
uint64_t mulmod(uint64_t a, uint64_t b, uint64_t n) {
    uint64_t rv = 0;

    /* The loop relies on a and rv staying in [0, n); an unreduced a would
     * leak into the result (e.g. mulmod(10, 3, 7) would yield 9). */
    if (a >= n)
        a %= n;

    while (b) {
        if (b & 1) {
            rv += a;
            if (rv < a || rv >= n)      /* carry out of 64 bits, or >= n */
                rv -= n;
        }
        uint64_t prev = a;
        a += a;
        if (a < prev || a >= n)         /* carry out of 64 bits, or >= n */
            a -= n;
        b >>= 1;
    }
    return rv;
}
Alternatively, just use GCC extended inline assembly to access the underlying x64 instructions:
/*
 * mulmod: compute (a * b) mod n on x86-64 using GCC/Clang extended asm for
 * the hardware 64x64->128 multiply (MUL) and 128/64 divide (DIV).
 *
 * Requires n > 0; any a, b are accepted. DIV raises a divide error (#DE)
 * when the quotient does not fit in 64 bits, i.e. when the high word of the
 * product is >= n. Reducing the high word modulo n first leaves the
 * remainder unchanged, because
 *   (hi * 2^64 + lo) mod n == ((hi mod n) * 2^64 + lo) mod n.
 *
 * `__asm__` is used because plain `asm` is not a keyword in strict ISO
 * modes (e.g. -std=c11). `static inline` avoids the C99 rule that a plain
 * `inline` definition provides no external definition, which would make an
 * un-inlined call fail to link.
 */
static inline uint64_t mulmod(uint64_t a, uint64_t b, uint64_t n) {
    uint64_t lo, hi, rem;

    /* MUL: rdx:rax = rax * b */
    __asm__("mulq %3" : "=a"(lo), "=d"(hi) : "0"(a), "rm"(b) : "cc");

    if (hi >= n)
        hi %= n;    /* rare path; guarantees the DIV quotient fits */

    /* DIV: rax = rdx:rax / n (discarded), rdx = rdx:rax % n */
    __asm__("divq %4" : "=a"(lo), "=d"(rem) : "0"(lo), "1"(hi), "rm"(n) : "cc");
    return rem;
}
The 64-bit div instruction is really slow, however, so the loop might actually be faster. You'd need to profile to be sure.
On MSVC, the corresponding intrinsic for the unsigned 64×64→128-bit multiply is `_umul128` (`_mul128` is the signed variant).
typedef unsigned long long BIG;
/*
 * shl_mod: compute (v * 2^by) mod n by `by` modular doublings, never
 * overflowing 64 bits.
 *
 * Works for any n > 0, not only when the high bit of n is set: v is fully
 * reduced first, and each step keeps v in [0, n):
 *   - if 2v >= n (equivalently v >= n - v), the result is
 *     2v - n == v - (n - v), computed without overflow;
 *   - otherwise 2v < n, so v << 1 cannot overflow.
 * A non-positive `by` returns v mod n.
 */
BIG shl_mod( BIG v, BIG n, int by )
{
    if (v >= n)
        v %= n;
    while (by-- > 0) {
        if (v >= n - v)     /* '>=' so that 2v == n reduces to 0 */
            v -= n - v;
        else
            v <<= 1;
    }
    return v;
}
Now you can use `shl_mod(B, n, 64)` to compute $B \cdot 2^{64} \bmod n$, e.g. to reduce the high word of a 128-bit product.
Having no inline assembly kind of sucks. Anyway, the function call overhead is actually extremely small. Parameters are passed in volatile registers and no cleanup is needed.
I don't have an assembler, and x64 targets don't support __asm, so I had no choice but to "assemble" my function from opcodes myself.
Obviously it depends on Windows-specific APIs (`VirtualProtect` is used to make the opcode bytes executable). I'm using MPIR (GMP) as a reference to show the function produces correct results.
#include "stdafx.h"
// mulmod64(a, b, m) == (a * b) % m
// Signature of the hand-assembled routine below. On x64, MSVC accepts but
// ignores __cdecl; the Microsoft x64 calling convention applies, so a is in
// rcx, b in rdx, m in r8 and the result comes back in rax. These are the
// registers the opcodes below use.
typedef uint64_t(__cdecl *mulmod64_fnptr_t)(uint64_t a, uint64_t b, uint64_t m);
// Raw x86-64 machine code for mulmod64. Precondition: m != 0 and the
// quotient (a * b) / m must fit in 64 bits; otherwise `div` raises a divide
// error.
uint8_t mulmod64_opcodes[] = {
	0x48, 0x89, 0xC8, // mov rax, rcx   ; rax = a
	0x48, 0xF7, 0xE2, // mul rdx        ; rdx:rax = a * b
	0x4C, 0x89, 0xC1, // mov rcx, r8    ; rcx = m
	0x48, 0xF7, 0xF1, // div rcx        ; rax = quotient, rdx = remainder
	0x48, 0x89, 0xD0, // mov rax,rdx    ; return the remainder
	0xC3              // ret
};
mulmod64_fnptr_t mulmod64_fnptr;
/*
 * init: make mulmod64_opcodes executable and publish it via mulmod64_fnptr.
 *
 * VirtualProtect works on whole pages, so any other data sharing a page with
 * mulmod64_opcodes also becomes executable. If VirtualProtect fails,
 * mulmod64_fnptr is left NULL instead of pointing at non-executable data
 * (calling it would fault); callers should check it before use.
 */
void init(void) {
	DWORD dwOldProtect;
	if (!VirtualProtect(
			&mulmod64_opcodes,
			sizeof(mulmod64_opcodes),
			PAGE_EXECUTE_READWRITE,
			&dwOldProtect)) {
		mulmod64_fnptr = NULL;
		return;
	}
	// NOTE: reinterpret byte array as a function pointer (an object-to-
	// function pointer conversion: not ISO C, but accepted by MSVC)
	mulmod64_fnptr = (mulmod64_fnptr_t)(void*)mulmod64_opcodes;
}
int main() {
init();
uint64_t a64 = 2139018971924123ull;
uint64_t b64 = 1239485798578921ull;
uint64_t m64 = 8975489368910167ull;
// reference code
mpz_t a, b, c, m, r;
mpz_inits(a, b, c, m, r, NULL);
mpz_set_ui(a, a64);
mpz_set_ui(b, b64);
mpz_set_ui(m, m64);
mpz_mul(c, a, b);
mpz_mod(r, c, m);
gmp_printf("(%Zd * %Zd) mod %Zd = %Zd\n", a, b, m, r);
// using mulmod64
uint64_t r64 = mulmod64_fnptr(a64, b64, m64);
printf("(%llu * %llu) mod %llu = %llu\n", a64, b64, m64, r64);
return 0;
}
Seven years later, I got a solution working in Visual Studio 2019 using the `_umul128` and `_udiv128` intrinsics:
#include <stdint.h>
#include <intrin.h>
#pragma intrinsic(_umul128)
#pragma intrinsic(_udiv128)
// compute (a*b)%n with 128-bit intermediary result
// assumes n>0; valid for all a, b.
// _udiv128 performs a 128/64 division whose quotient must fit in 64 bits
// (like the DIV instruction, which faults otherwise), i.e. the high word of
// a*b must be < n. That (rare) case is pre-reduced, which leaves the
// remainder unchanged: (hi*2**64 + lo) % n == ((hi % n)*2**64 + lo) % n.
inline uint64_t mulmod(uint64_t a, uint64_t b, uint64_t n) {
    uint64_t hi, lo = _umul128(a, b, &hi);
    if (hi >= n)
        hi %= n;
    uint64_t rem;
    (void)_udiv128(hi, lo, n, &rem);
    return rem;
}
// compute (a*b)%n with 128-bit intermediary result
// assumes n>0, works including if a*b >= n * 2**64
// Reducing a first guarantees (a % n) * b < n * 2**64, so the high word of
// the product is < n and the quotient of the 128/64 division fits in
// 64 bits.
inline uint64_t mulmod1(uint64_t a, uint64_t b, uint64_t n) {
    const uint64_t a_mod_n = a % n;
    uint64_t prod_hi;
    const uint64_t prod_lo = _umul128(a_mod_n, b, &prod_hi);
    uint64_t remainder;
    (void)_udiv128(prod_hi, prod_lo, n, &remainder);
    return remainder;
}