Skip to content

Commit e9c711a

Browse files
committed
Implement faster multiplication using Number Theoretic Transform
Performs ntt with three primes (29<<27|1, 26<<27|1, 24<<27|1)
1 parent 99cc2d5 commit e9c711a

File tree

4 files changed

+253
-0
lines changed

4 files changed

+253
-0
lines changed

bigdecimal.gemspec

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ Gem::Specification.new do |s|
4646
ext/bigdecimal/feature.h
4747
ext/bigdecimal/missing.c
4848
ext/bigdecimal/missing.h
49+
ext/bigdecimal/ntt.h
4950
ext/bigdecimal/missing/dtoa.c
5051
ext/bigdecimal/static_assert.h
5152
]

ext/bigdecimal/bigdecimal.c

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,12 @@
3131
#include "bits.h"
3232
#include "static_assert.h"
3333

34+
#if SIZEOF_DECDIG == 4
35+
#define USE_NTT_MULTIPLICATION 1
36+
#include "ntt.h"
37+
#define NTT_MULTIPLICATION_THRESHOLD 100
38+
#endif
39+
3440
#define BIGDECIMAL_VERSION "3.2.2"
3541

3642
/* #define ENABLE_NUMERIC_STRING */
@@ -3251,6 +3257,25 @@ BigDecimal_vpmult(VALUE self, VALUE v) {
32513257
RB_GC_GUARD(b.bigdecimal);
32523258
return c.bigdecimal;
32533259
}
3260+
3261+
#if SIZEOF_DECDIG == 4
3262+
VALUE
3263+
BigDecimal_nttmult(VALUE self, VALUE v) {
3264+
BDVALUE a,b,c;
3265+
a = GetBDValueMust(self);
3266+
b = GetBDValueMust(v);
3267+
c = NewZeroWrap(1, VPMULT_RESULT_PREC(a.real, b.real) * BASE_FIG);
3268+
ntt_multiply(a.real->Prec, b.real->Prec, a.real->frac, b.real->frac, c.real->frac);
3269+
VpSetSign(c.real, a.real->sign * b.real->sign);
3270+
c.real->exponent = a.real->exponent + b.real->exponent;
3271+
c.real->Prec = a.real->Prec + b.real->Prec;
3272+
VpNmlz(c.real);
3273+
RB_GC_GUARD(a.bigdecimal);
3274+
RB_GC_GUARD(b.bigdecimal);
3275+
return c.bigdecimal;
3276+
}
3277+
#endif
3278+
32543279
#endif /* BIGDECIMAL_USE_VP_TEST_METHODS */
32553280

32563281
/* Document-class: BigDecimal
@@ -3623,6 +3648,9 @@ Init_bigdecimal(void)
36233648
#ifdef BIGDECIMAL_USE_VP_TEST_METHODS
36243649
rb_define_method(rb_cBigDecimal, "vpdivd", BigDecimal_vpdivd, 2);
36253650
rb_define_method(rb_cBigDecimal, "vpmult", BigDecimal_vpmult, 1);
3651+
#ifdef USE_NTT_MULTIPLICATION
3652+
rb_define_method(rb_cBigDecimal, "nttmult", BigDecimal_nttmult, 1);
3653+
#endif
36263654
#endif /* BIGDECIMAL_USE_VP_TEST_METHODS */
36273655

36283656
#define ROUNDING_MODE(i, name, value) \
@@ -4926,6 +4954,15 @@ VpMult(Real *c, Real *a, Real *b)
49264954
if (w) rbd_free_struct(c);
49274955
return 0;
49284956
}
4957+
4958+
#ifdef USE_NTT_MULTIPLICATION
4959+
if (b->Prec >= NTT_MULTIPLICATION_THRESHOLD) {
4960+
ntt_multiply((uint32_t)a->Prec, (uint32_t)b->Prec, a->frac, b->frac, c->frac);
4961+
c->Prec = a->Prec + b->Prec;
4962+
goto Cleanup;
4963+
}
4964+
#endif
4965+
49294966
carry = 0;
49304967
nc = ind_c = MxIndAB;
49314968
memset(c->frac, 0, (nc + 1) * sizeof(DECDIG)); /* Initialize c */
@@ -4972,6 +5009,8 @@ VpMult(Real *c, Real *a, Real *b)
49725009
}
49735010
}
49745011
}
5012+
5013+
Cleanup:
49755014
VpNmlz(c);
49765015
if (w != NULL) { /* free work variable */
49775016
VpAsgn(w, c, 10);

ext/bigdecimal/ntt.h

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
// NTT (Number Theoretic Transform) implementation for BigDecimal multiplication
2+
3+
#define NTT_PRIMITIVE_ROOT 17
4+
#define NTT_PRIME_BASE1 24
5+
#define NTT_PRIME_BASE2 26
6+
#define NTT_PRIME_BASE3 29
7+
#define NTT_PRIME_SHIFT 27
8+
#define NTT_PRIME1 (((uint32_t)NTT_PRIME_BASE1 << NTT_PRIME_SHIFT) | 1)
9+
#define NTT_PRIME2 (((uint32_t)NTT_PRIME_BASE2 << NTT_PRIME_SHIFT) | 1)
10+
#define NTT_PRIME3 (((uint32_t)NTT_PRIME_BASE3 << NTT_PRIME_SHIFT) | 1)
11+
#define MAX_NTT32_BITS 27
12+
#define NTT_DECDIG_BASE 1000000000
13+
14+
// Calculates base**ex % mod
15+
static uint32_t
16+
mod_pow(uint32_t base, uint32_t ex, uint32_t mod) {
17+
uint32_t res = 1;
18+
uint32_t bit = 1;
19+
while (true) {
20+
if (ex & bit) {
21+
ex ^= bit;
22+
res = ((uint64_t)res * base) % mod;
23+
}
24+
if (!ex) break;
25+
base = ((uint64_t)base * base) % mod;
26+
bit <<= 1;
27+
}
28+
return res;
29+
}
30+
31+
// Recursively performs butterfly operations of NTT
32+
static void
33+
ntt_recursive(int size_bits, uint32_t *input, uint32_t *output, uint32_t *tmp, int depth, uint32_t r, uint32_t prime) {
34+
if (depth > 0) {
35+
ntt_recursive(size_bits, input, tmp, output, depth - 1, ((uint64_t)r * r) % prime, prime);
36+
} else {
37+
tmp = input;
38+
}
39+
uint32_t size_half = (uint32_t)1 << (size_bits - 1);
40+
uint32_t stride = (uint32_t)1 << (size_bits - depth - 1);
41+
uint32_t n = size_half / stride;
42+
uint32_t rn = 1, rm = prime - 1;
43+
uint32_t idx = 0;
44+
for (uint32_t i = 0; i < n; i++) {
45+
uint32_t j = i * 2 * stride;
46+
for (uint32_t k = 0; k < stride; k++, j++, idx++) {
47+
uint32_t a = tmp[j], b = tmp[j + stride];
48+
output[idx] = (a + (uint64_t)rn * b) % prime;
49+
output[idx + size_half] = (a + (uint64_t)rm * b) % prime;
50+
}
51+
rn = ((uint64_t)rn * r) % prime;
52+
rm = ((uint64_t)rm * r) % prime;
53+
}
54+
}
55+
56+
/* Perform NTT on input array.
57+
* base, shift: Represent the prime number as (base << shift | 1)
58+
* r_base: Primitive root of unity modulo prime
59+
* size_bits: log2 of the size of the input array. Should be less or equal to shift
60+
* input: input array of size 1 << size_bits
61+
*/
62+
static void
63+
ntt(int size_bits, uint32_t *input, uint32_t *output, uint32_t *tmp, int r_base, int base, int shift, int dir) {
64+
uint32_t size = (uint32_t)1 << size_bits;
65+
uint32_t prime = ((uint32_t)base << shift) | 1;
66+
67+
// rmax**(1 << shift) % prime == 1
68+
// r**size % prime == 1
69+
uint32_t rmax = mod_pow(r_base, base, prime);
70+
uint32_t r = mod_pow(rmax, (uint32_t)1 << (shift - size_bits), prime);
71+
72+
if (dir < 0) r = mod_pow(r, prime - 2, prime);
73+
ntt_recursive(size_bits, input, output, tmp, size_bits - 1, r, prime);
74+
if (dir < 0) {
75+
uint32_t n_inv = mod_pow((uint32_t)size, prime - 2, prime);
76+
for (uint32_t i = 0; i < size; i++) {
77+
output[i] = ((uint64_t)output[i] * n_inv) % prime;
78+
}
79+
}
80+
}
81+
82+
/* Calculate c that satisfies: c % PRIME1 == mod1 && c % PRIME2 == mod2 && c % PRIME3 == mod3
83+
* c = (mod1 * 35002755423056150739595925972 + mod2 * 14584479687667766215746868453 + mod3 * 37919651490985126265126719818) % (PRIME1 * PRIME2 * PRIME3)
84+
*/
85+
static inline void
86+
mod_restore_prime_24_26_29_shift_27(uint32_t mod1, uint32_t mod2, uint32_t mod3, uint32_t *digits) {
87+
// Use mixed radix notation to eliminate modulo by PRIME1 * PRIME2 * PRIME3
88+
// [DIG0, DIG1, DIG2] = DIG0 + DIG1 * PRIME1 + DIG2 * PRIME1 * PRIME2
89+
// DIG0: 0...PRIME1, DIG1: 0...PRIME2, DIG2: 0...PRIME3
90+
// 35002755423056150739595925972 = [1, 3489660916, 3113851359]
91+
// 14584479687667766215746868453 = [0, 13, 1297437912]
92+
// 37919651490985126265126719818 = [0, 0, 3373338954]
93+
uint64_t c0 = mod1;
94+
uint64_t c1 = (uint64_t)mod2 * 13 + (uint64_t)mod1 * 3489660916;
95+
uint64_t c2 = (uint64_t)mod3 * 3373338954 % NTT_PRIME3 + (uint64_t)mod2 * 1297437912 % NTT_PRIME3 + (uint64_t)mod1 * 3113851359 % NTT_PRIME3;
96+
c2 += c1 / NTT_PRIME2;
97+
c1 %= NTT_PRIME2;
98+
c2 %= NTT_PRIME3;
99+
// Base conversion
100+
c1 += c2 % NTT_DECDIG_BASE * NTT_PRIME2;
101+
c0 += c1 % NTT_DECDIG_BASE * NTT_PRIME1;
102+
c1 /= NTT_DECDIG_BASE;
103+
digits[0] = c0 % NTT_DECDIG_BASE;
104+
c0 /= NTT_DECDIG_BASE;
105+
c1 += c2 / NTT_DECDIG_BASE % NTT_DECDIG_BASE * NTT_PRIME2;
106+
c0 += c1 % NTT_DECDIG_BASE * NTT_PRIME1;
107+
c1 /= NTT_DECDIG_BASE;
108+
digits[1] = c0 % NTT_DECDIG_BASE;
109+
c0 = c0 / NTT_DECDIG_BASE + c1 % NTT_DECDIG_BASE * NTT_PRIME1;
110+
digits[2] = c0 % NTT_DECDIG_BASE;
111+
digits[3] = (c0 / NTT_DECDIG_BASE + c1 / NTT_DECDIG_BASE % NTT_DECDIG_BASE * NTT_PRIME1) % NTT_DECDIG_BASE;
112+
}
113+
114+
/*
115+
* NTT multiplication
116+
* Uses three NTTs with mod (24 << 27 | 1), (26 << 27 | 1), and (29 << 27 | 1)
117+
*/
118+
static void
119+
ntt_multiply(size_t a_size, size_t b_size, uint32_t *a, uint32_t *b, uint32_t *c) {
120+
if (a_size < b_size) {
121+
ntt_multiply(b_size, a_size, b, a, c);
122+
return;
123+
}
124+
125+
int b_bits = 0;
126+
while (((uint32_t)1 << b_bits) < (uint32_t)b_size) b_bits++;
127+
int ntt_size_bits = b_bits + 1;
128+
if (ntt_size_bits > MAX_NTT32_BITS) {
129+
rb_raise(rb_eArgError, "Multiply size too large");
130+
}
131+
132+
// To calculate large_a * small_b faster, split into several batches.
133+
uint32_t ntt_size = (uint32_t)1 << ntt_size_bits;
134+
uint32_t batch_size = ntt_size - (uint32_t)b_size;
135+
uint32_t batch_count = (uint32_t)((a_size + batch_size - 1) / batch_size);
136+
137+
uint32_t *ntt1 = ruby_xcalloc(sizeof(uint32_t), ntt_size);
138+
uint32_t *ntt2 = ruby_xcalloc(sizeof(uint32_t), ntt_size);
139+
uint32_t *ntt3 = ruby_xcalloc(sizeof(uint32_t), ntt_size);
140+
uint32_t *tmp1 = ruby_xcalloc(sizeof(uint32_t), ntt_size);
141+
uint32_t *tmp2 = ruby_xcalloc(sizeof(uint32_t), ntt_size);
142+
uint32_t *tmp3 = ruby_xcalloc(sizeof(uint32_t), ntt_size);
143+
uint32_t *conv1 = ruby_xcalloc(sizeof(uint32_t), ntt_size);
144+
uint32_t *conv2 = ruby_xcalloc(sizeof(uint32_t), ntt_size);
145+
uint32_t *conv3 = ruby_xcalloc(sizeof(uint32_t), ntt_size);
146+
147+
// Calculate NTT for b in three primes. Result is reused for each batch of a.
148+
memcpy(tmp1, b, b_size * sizeof(uint32_t));
149+
memset(tmp1 + b_size, 0, (ntt_size - b_size) * sizeof(uint32_t));
150+
ntt(ntt_size_bits, tmp1, ntt1, tmp2, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE1, NTT_PRIME_SHIFT, +1);
151+
ntt(ntt_size_bits, tmp1, ntt2, tmp2, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE2, NTT_PRIME_SHIFT, +1);
152+
ntt(ntt_size_bits, tmp1, ntt3, tmp2, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE3, NTT_PRIME_SHIFT, +1);
153+
154+
memset(c, 0, (a_size + b_size) * sizeof(uint32_t));
155+
for (uint32_t idx = 0; idx < batch_count; idx++) {
156+
if (idx == batch_count - 1) {
157+
uint32_t len = (uint32_t)a_size - idx * batch_size;
158+
memcpy(tmp1, a + idx * batch_size, len * sizeof(uint32_t));
159+
memset(tmp1 + len, 0, (ntt_size - len) * sizeof(uint32_t));
160+
} else {
161+
memcpy(tmp1, a + idx * batch_size, batch_size * sizeof(uint32_t));
162+
}
163+
// Calculate convolution for this batch in three primes
164+
ntt(ntt_size_bits, tmp1, tmp2, tmp3, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE1, NTT_PRIME_SHIFT, +1);
165+
for (uint32_t i = 0; i < ntt_size; i++) tmp2[i] = ((uint64_t)tmp2[i] * ntt1[i]) % NTT_PRIME1;
166+
ntt(ntt_size_bits, tmp2, conv1, tmp3, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE1, NTT_PRIME_SHIFT, -1);
167+
ntt(ntt_size_bits, tmp1, tmp2, tmp3, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE2, NTT_PRIME_SHIFT, +1);
168+
for (uint32_t i = 0; i < ntt_size; i++) tmp2[i] = ((uint64_t)tmp2[i] * ntt2[i]) % NTT_PRIME2;
169+
ntt(ntt_size_bits, tmp2, conv2, tmp3, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE2, NTT_PRIME_SHIFT, -1);
170+
ntt(ntt_size_bits, tmp1, tmp2, tmp3, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE3, NTT_PRIME_SHIFT, +1);
171+
for (uint32_t i = 0; i < ntt_size; i++) tmp2[i] = ((uint64_t)tmp2[i] * ntt3[i]) % NTT_PRIME3;
172+
ntt(ntt_size_bits, tmp2, conv3, tmp3, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE3, NTT_PRIME_SHIFT, -1);
173+
174+
// Restore the original convolution value from three convolutions calculated in three primes
175+
for (uint32_t i = 0; i < ntt_size; i++) {
176+
uint32_t dig[4];
177+
mod_restore_prime_24_26_29_shift_27(conv1[i], conv2[i], conv3[i], dig);
178+
for (int j = 0; j < 4; j++) {
179+
// Maximum overlap(4) * maximum_value(999999999) does not overflow 32-bit integer.
180+
// Index check: if dig[j] is non-zero, assign index is within valid range.
181+
if (dig[j]) c[idx * batch_size + i + 1 - j] += dig[j];
182+
}
183+
}
184+
}
185+
uint32_t carry = 0;
186+
for (int32_t i = (uint32_t)(a_size + b_size - 1); i >= 0; i--) {
187+
uint32_t v = c[i] + carry;
188+
c[i] = v % NTT_DECDIG_BASE;
189+
carry = v / NTT_DECDIG_BASE;
190+
}
191+
ruby_xfree(ntt1);
192+
ruby_xfree(ntt2);
193+
ruby_xfree(ntt3);
194+
ruby_xfree(tmp1);
195+
ruby_xfree(tmp2);
196+
ruby_xfree(tmp3);
197+
ruby_xfree(conv1);
198+
ruby_xfree(conv2);
199+
ruby_xfree(conv3);
200+
}

test/bigdecimal/test_vp_operation.rb

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ def setup
1313
end
1414
end
1515

16+
def ntt_mult_available?
17+
BASE_FIG == 9
18+
end
19+
1620
def test_vpmult
1721
assert_equal(BigDecimal('121932631112635269'), BigDecimal('123456789').vpmult(BigDecimal('987654321')))
1822
assert_equal(BigDecimal('12193263.1112635269'), BigDecimal('123.456789').vpmult(BigDecimal('98765.4321')))
@@ -21,6 +25,15 @@ def test_vpmult
2125
assert_equal(BigDecimal("#{x * y}e-300"), BigDecimal("#{x}e-100").vpmult(BigDecimal("#{y}e-200")))
2226
end
2327

28+
def test_nttmult
29+
omit 'NTT multiplication is only available for 32-bit DECDIG' unless ntt_mult_available?
30+
[*1..32].repeated_permutation(2) do |a, b|
31+
x = BigDecimal(10 ** (BASE_FIG * a) / 7)
32+
y = BigDecimal(10 ** (BASE_FIG * b) / 13)
33+
assert_equal(x.to_i * y.to_i, x.nttmult(y))
34+
end
35+
end
36+
2437
def test_vpdivd
2538
# a[0] > b[0]
2639
# XXXX_YYYY_ZZZZ / 1111 #=> 000X_000Y_000Z

0 commit comments

Comments
 (0)