|
// Number Theoretic Transform (NTT) support for BigDecimal multiplication.

// Every NTT modulus used here has the form (base << NTT_PRIME_SHIFT) | 1.
#define NTT_PRIME_SHIFT 27
#define NTT_PRIME_BASE1 24
#define NTT_PRIME_BASE2 26
#define NTT_PRIME_BASE3 29
#define NTT_PRIME_FROM_BASE(base) (((uint32_t)(base) << NTT_PRIME_SHIFT) | 1)
#define NTT_PRIME1 NTT_PRIME_FROM_BASE(NTT_PRIME_BASE1)
#define NTT_PRIME2 NTT_PRIME_FROM_BASE(NTT_PRIME_BASE2)
#define NTT_PRIME3 NTT_PRIME_FROM_BASE(NTT_PRIME_BASE3)

// Value handed to ntt() as r_base for all three moduli.
#define NTT_PRIMITIVE_ROOT 17

// Upper bound on log2 of the transform length accepted by ntt_multiply().
#define MAX_NTT32_BITS 27

// Radix of one stored digit: each array element holds nine decimal digits.
#define NTT_DECDIG_BASE 1000000000
| 13 | + |
/* Modular exponentiation: returns (base ** ex) % mod.
 *
 * Uses binary square-and-multiply, consuming the exponent from its least
 * significant bit. Intermediate products are widened to 64 bits so they
 * cannot overflow. An exponent of 0 yields 1 regardless of base and mod.
 */
static uint32_t
mod_pow(uint32_t base, uint32_t ex, uint32_t mod) {
    uint32_t result = 1;
    for (;;) {
        if (ex & 1u) {
            result = (uint32_t)(((uint64_t)result * base) % mod);
        }
        ex >>= 1;
        // Stop before squaring once no exponent bits remain.
        if (ex == 0) return result;
        base = (uint32_t)(((uint64_t)base * base) % mod);
    }
}
| 30 | + |
/* Runs butterfly stages 0..depth of a transform of length 1 << size_bits.
 *
 * Stage `depth` reads the result of stage `depth - 1`, which is produced by a
 * recursive call with the roles of `output` and `tmp` swapped and with r
 * squared; stage 0 reads `input` directly. `input` is only read, never written.
 * The result of stage `depth` is left in `output`.
 *
 * size_bits: log2 of the transform length (must be at least 1)
 * depth:     index of the last stage to run; ntt() passes size_bits - 1
 * r:         root of unity used for the twiddle factors of stage `depth`
 * prime:     modulus
 */
static void
ntt_recursive(int size_bits, uint32_t *input, uint32_t *output, uint32_t *tmp, int depth, uint32_t r, uint32_t prime) {
    const uint32_t *src;
    if (depth <= 0) {
        src = input;
    } else {
        uint32_t r_squared = (uint32_t)(((uint64_t)r * r) % prime);
        ntt_recursive(size_bits, input, tmp, output, depth - 1, r_squared, prime);
        src = tmp;
    }

    const uint32_t half = (uint32_t)1 << (size_bits - 1);
    const uint32_t stride = (uint32_t)1 << (size_bits - depth - 1);
    const uint32_t groups = half / stride;

    // Results are written sequentially into the two halves of output.
    uint32_t *lower = output;
    uint32_t *upper = output + half;

    // twiddle = r**g and neg_twiddle = -(r**g), both modulo prime.
    uint32_t twiddle = 1;
    uint32_t neg_twiddle = prime - 1;

    for (uint32_t g = 0; g < groups; g++) {
        const uint32_t *even = src + g * 2 * stride;
        const uint32_t *odd = even + stride;
        for (uint32_t k = 0; k < stride; k++) {
            uint64_t x = even[k];
            uint64_t y = odd[k];
            *lower++ = (uint32_t)((x + twiddle * y) % prime);
            *upper++ = (uint32_t)((x + neg_twiddle * y) % prime);
        }
        twiddle = (uint32_t)(((uint64_t)twiddle * r) % prime);
        neg_twiddle = (uint32_t)(((uint64_t)neg_twiddle * r) % prime);
    }
}
| 55 | + |
/* Forward or inverse NTT of `input` modulo the prime (base << shift | 1).
 *
 * size_bits: log2 of the transform length; must satisfy 1 <= size_bits <= shift
 * input:     source array of 1 << size_bits values; it is not modified
 * output:    destination array of 1 << size_bits values
 * tmp:       scratch array of 1 << size_bits values, distinct from input and output
 * r_base:    primitive root modulo the prime
 * base, shift: describe the prime as (base << shift | 1)
 * dir:       negative selects the inverse transform (including the 1/size
 *            scaling); any other value selects the forward transform
 */
static void
ntt(int size_bits, uint32_t *input, uint32_t *output, uint32_t *tmp, int r_base, int base, int shift, int dir) {
    const uint32_t prime = ((uint32_t)base << shift) | 1;
    const uint32_t length = (uint32_t)1 << size_bits;
    const int inverse = dir < 0;

    // top_root has order 1 << shift:  top_root**(1 << shift) % prime == 1
    // root has order `length`:        root**length % prime == 1
    const uint32_t top_root = mod_pow(r_base, base, prime);
    uint32_t root = mod_pow(top_root, (uint32_t)1 << (shift - size_bits), prime);

    // The inverse transform uses the modular inverse of the root (Fermat: x**(p-2)).
    if (inverse) root = mod_pow(root, prime - 2, prime);

    ntt_recursive(size_bits, input, output, tmp, size_bits - 1, root, prime);

    if (!inverse) return;

    // Finish the inverse transform by multiplying every element by 1/length.
    const uint32_t length_inv = mod_pow(length, prime - 2, prime);
    for (uint32_t *p = output, *end = output + length; p != end; p++) {
        *p = (uint32_t)(((uint64_t)*p * length_inv) % prime);
    }
}
| 81 | + |
| 82 | +/* Calculate c that satisfies: c % PRIME1 == mod1 && c % PRIME2 == mod2 && c % PRIME3 == mod3 |
| 83 | + * c = (mod1 * 35002755423056150739595925972 + mod2 * 14584479687667766215746868453 + mod3 * 37919651490985126265126719818) % (PRIME1 * PRIME2 * PRIME3) |
| 84 | + */ |
| 85 | +static inline void |
| 86 | +mod_restore_prime_24_26_29_shift_27(uint32_t mod1, uint32_t mod2, uint32_t mod3, uint32_t *digits) { |
| 87 | + // Use mixed radix notation to eliminate modulo by PRIME1 * PRIME2 * PRIME3 |
| 88 | + // [DIG0, DIG1, DIG2] = DIG0 + DIG1 * PRIME1 + DIG2 * PRIME1 * PRIME2 |
| 89 | + // DIG0: 0...PRIME1, DIG1: 0...PRIME2, DIG2: 0...PRIME3 |
| 90 | + // 35002755423056150739595925972 = [1, 3489660916, 3113851359] |
| 91 | + // 14584479687667766215746868453 = [0, 13, 1297437912] |
| 92 | + // 37919651490985126265126719818 = [0, 0, 3373338954] |
| 93 | + uint64_t c0 = mod1; |
| 94 | + uint64_t c1 = (uint64_t)mod2 * 13 + (uint64_t)mod1 * 3489660916; |
| 95 | + uint64_t c2 = (uint64_t)mod3 * 3373338954 % NTT_PRIME3 + (uint64_t)mod2 * 1297437912 % NTT_PRIME3 + (uint64_t)mod1 * 3113851359 % NTT_PRIME3; |
| 96 | + c2 += c1 / NTT_PRIME2; |
| 97 | + c1 %= NTT_PRIME2; |
| 98 | + c2 %= NTT_PRIME3; |
| 99 | + // Base conversion |
| 100 | + c1 += c2 % NTT_DECDIG_BASE * NTT_PRIME2; |
| 101 | + c0 += c1 % NTT_DECDIG_BASE * NTT_PRIME1; |
| 102 | + c1 /= NTT_DECDIG_BASE; |
| 103 | + digits[0] = c0 % NTT_DECDIG_BASE; |
| 104 | + c0 /= NTT_DECDIG_BASE; |
| 105 | + c1 += c2 / NTT_DECDIG_BASE % NTT_DECDIG_BASE * NTT_PRIME2; |
| 106 | + c0 += c1 % NTT_DECDIG_BASE * NTT_PRIME1; |
| 107 | + c1 /= NTT_DECDIG_BASE; |
| 108 | + digits[1] = c0 % NTT_DECDIG_BASE; |
| 109 | + c0 = c0 / NTT_DECDIG_BASE + c1 % NTT_DECDIG_BASE * NTT_PRIME1; |
| 110 | + digits[2] = c0 % NTT_DECDIG_BASE; |
| 111 | + digits[3] = (c0 / NTT_DECDIG_BASE + c1 / NTT_DECDIG_BASE % NTT_DECDIG_BASE * NTT_PRIME1) % NTT_DECDIG_BASE; |
| 112 | +} |
| 113 | + |
| 114 | +/* |
| 115 | + * NTT multiplication |
| 116 | + * Uses three NTTs with mod (24 << 27 | 1), (26 << 27 | 1), and (29 << 27 | 1) |
| 117 | + */ |
| 118 | +static void |
| 119 | +ntt_multiply(size_t a_size, size_t b_size, uint32_t *a, uint32_t *b, uint32_t *c) { |
| 120 | + if (a_size < b_size) { |
| 121 | + ntt_multiply(b_size, a_size, b, a, c); |
| 122 | + return; |
| 123 | + } |
| 124 | + |
| 125 | + int b_bits = 0; |
| 126 | + while (((uint32_t)1 << b_bits) < (uint32_t)b_size) b_bits++; |
| 127 | + int ntt_size_bits = b_bits + 1; |
| 128 | + if (ntt_size_bits > MAX_NTT32_BITS) { |
| 129 | + rb_raise(rb_eArgError, "Multiply size too large"); |
| 130 | + } |
| 131 | + |
| 132 | + // To calculate large_a * small_b faster, split into several batches. |
| 133 | + uint32_t ntt_size = (uint32_t)1 << ntt_size_bits; |
| 134 | + uint32_t batch_size = ntt_size - (uint32_t)b_size; |
| 135 | + uint32_t batch_count = (uint32_t)((a_size + batch_size - 1) / batch_size); |
| 136 | + |
| 137 | + uint32_t *ntt1 = ruby_xcalloc(sizeof(uint32_t), ntt_size); |
| 138 | + uint32_t *ntt2 = ruby_xcalloc(sizeof(uint32_t), ntt_size); |
| 139 | + uint32_t *ntt3 = ruby_xcalloc(sizeof(uint32_t), ntt_size); |
| 140 | + uint32_t *tmp1 = ruby_xcalloc(sizeof(uint32_t), ntt_size); |
| 141 | + uint32_t *tmp2 = ruby_xcalloc(sizeof(uint32_t), ntt_size); |
| 142 | + uint32_t *tmp3 = ruby_xcalloc(sizeof(uint32_t), ntt_size); |
| 143 | + uint32_t *conv1 = ruby_xcalloc(sizeof(uint32_t), ntt_size); |
| 144 | + uint32_t *conv2 = ruby_xcalloc(sizeof(uint32_t), ntt_size); |
| 145 | + uint32_t *conv3 = ruby_xcalloc(sizeof(uint32_t), ntt_size); |
| 146 | + |
| 147 | + // Calculate NTT for b in three primes. Result is reused for each batch of a. |
| 148 | + memcpy(tmp1, b, b_size * sizeof(uint32_t)); |
| 149 | + memset(tmp1 + b_size, 0, (ntt_size - b_size) * sizeof(uint32_t)); |
| 150 | + ntt(ntt_size_bits, tmp1, ntt1, tmp2, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE1, NTT_PRIME_SHIFT, +1); |
| 151 | + ntt(ntt_size_bits, tmp1, ntt2, tmp2, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE2, NTT_PRIME_SHIFT, +1); |
| 152 | + ntt(ntt_size_bits, tmp1, ntt3, tmp2, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE3, NTT_PRIME_SHIFT, +1); |
| 153 | + |
| 154 | + memset(c, 0, (a_size + b_size) * sizeof(uint32_t)); |
| 155 | + for (uint32_t idx = 0; idx < batch_count; idx++) { |
| 156 | + if (idx == batch_count - 1) { |
| 157 | + uint32_t len = (uint32_t)a_size - idx * batch_size; |
| 158 | + memcpy(tmp1, a + idx * batch_size, len * sizeof(uint32_t)); |
| 159 | + memset(tmp1 + len, 0, (ntt_size - len) * sizeof(uint32_t)); |
| 160 | + } else { |
| 161 | + memcpy(tmp1, a + idx * batch_size, batch_size * sizeof(uint32_t)); |
| 162 | + } |
| 163 | + // Calculate convolution for this batch in three primes |
| 164 | + ntt(ntt_size_bits, tmp1, tmp2, tmp3, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE1, NTT_PRIME_SHIFT, +1); |
| 165 | + for (uint32_t i = 0; i < ntt_size; i++) tmp2[i] = ((uint64_t)tmp2[i] * ntt1[i]) % NTT_PRIME1; |
| 166 | + ntt(ntt_size_bits, tmp2, conv1, tmp3, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE1, NTT_PRIME_SHIFT, -1); |
| 167 | + ntt(ntt_size_bits, tmp1, tmp2, tmp3, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE2, NTT_PRIME_SHIFT, +1); |
| 168 | + for (uint32_t i = 0; i < ntt_size; i++) tmp2[i] = ((uint64_t)tmp2[i] * ntt2[i]) % NTT_PRIME2; |
| 169 | + ntt(ntt_size_bits, tmp2, conv2, tmp3, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE2, NTT_PRIME_SHIFT, -1); |
| 170 | + ntt(ntt_size_bits, tmp1, tmp2, tmp3, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE3, NTT_PRIME_SHIFT, +1); |
| 171 | + for (uint32_t i = 0; i < ntt_size; i++) tmp2[i] = ((uint64_t)tmp2[i] * ntt3[i]) % NTT_PRIME3; |
| 172 | + ntt(ntt_size_bits, tmp2, conv3, tmp3, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE3, NTT_PRIME_SHIFT, -1); |
| 173 | + |
| 174 | + // Restore the original convolution value from three convolutions calculated in three primes |
| 175 | + for (uint32_t i = 0; i < ntt_size; i++) { |
| 176 | + uint32_t dig[4]; |
| 177 | + mod_restore_prime_24_26_29_shift_27(conv1[i], conv2[i], conv3[i], dig); |
| 178 | + for (int j = 0; j < 4; j++) { |
| 179 | + // Maximum overlap(4) * maximum_value(999999999) does not overflow 32-bit integer. |
| 180 | + // Index check: if dig[j] is non-zero, assign index is within valid range. |
| 181 | + if (dig[j]) c[idx * batch_size + i + 1 - j] += dig[j]; |
| 182 | + } |
| 183 | + } |
| 184 | + } |
| 185 | + uint32_t carry = 0; |
| 186 | + for (int32_t i = (uint32_t)(a_size + b_size - 1); i >= 0; i--) { |
| 187 | + uint32_t v = c[i] + carry; |
| 188 | + c[i] = v % NTT_DECDIG_BASE; |
| 189 | + carry = v / NTT_DECDIG_BASE; |
| 190 | + } |
| 191 | + ruby_xfree(ntt1); |
| 192 | + ruby_xfree(ntt2); |
| 193 | + ruby_xfree(ntt3); |
| 194 | + ruby_xfree(tmp1); |
| 195 | + ruby_xfree(tmp2); |
| 196 | + ruby_xfree(tmp3); |
| 197 | + ruby_xfree(conv1); |
| 198 | + ruby_xfree(conv2); |
| 199 | + ruby_xfree(conv3); |
| 200 | +} |
0 commit comments