快速平方平方计算

为了加速我的数字divisons,我需要加速对bigint的操作y = x ^ 2,表示为无符号DWORD的动态数组。 要清楚:

DWORD x[n+1] = { LSW, ......, MSW }; 
  • 其中n + 1是使用的DWORD的数量
  • 所以数值x = x [0] + x [1] << 32 + … x [N] << 32 *(n)

问题是: 如何在没有精确度损失的情况下尽可能快地计算y = x ^ 2? – 使用C ++和整数算术(32位与进位)。

我目前的做法是应用乘法,y = x * x,并避免多重乘法。

例如:

 x = x[0] + x[1]<<32 + ... x[n]<<32*(n) 

为了简单起见,让我重写一下:

 x = x0+ x1 + x2 + ... + xn 

其中索引代表数组内的地址,所以:

 y = x*x y = (x0 + x1 + x2 + ...xn)*(x0 + x1 + x2 + ...xn) y = x0*(x0 + x1 + x2 + ...xn) + x1*(x0 + x1 + x2 + ...xn) + x2*(x0 + x1 + x2 + ...xn) + ...xn*(x0 + x1 + x2 + ...xn) y0 = x0*x0 y1 = x1*x0 + x0*x1 y2 = x2*x0 + x1*x1 + x0*x2 y3 = x3*x0 + x2*x1 + x1*x2 ... y(2n-3) = xn(n-2)*x(n ) + x(n-1)*x(n-1) + x(n )*x(n-2) y(2n-2) = xn(n-1)*x(n ) + x(n )*x(n-1) y(2n-1) = xn(n )*x(n ) 

经过仔细观察,很明显几乎所有xi xj都出现了两次(不是第一次也是最后一次),这意味着N N次乘法可以用(N + 1)*(N / 2)次乘法代替。 PS 32位* 32位= 64位,所以每个多路+加操作的结果都是64 + 1位。

有没有更好的方法来计算这个快速? 所有我在搜索过程中发现sqrt算法,而不是sqr …

快速平方米

! 请注意,我的代码中的所有数字都是MSW,而不是上面的测试(为简化方程式,LSW首先是LSW,否则就是索引混乱)。

当前的功能fsqr实现

 void arbnum::sqr(const arbnum &x) { // O((N+1)*N/2) arbnum c; DWORD h, l; int N, nx, nc, i, i0, i1, k; c._alloc(x.siz + x.siz + 1); nx = x.siz - 1; nc = c.siz - 1; N = nx + nx; for (i=0; i<=nc; i++) c.dat[i]=0; for (i=1; i<N; i++) for (i0=0; (i0<=nx) && (i0<=i); i0++) { i1 = i - i0; if (i0 >= i1) break; if (i1 > nx) continue; h = x.dat[nx-i0]; if (!h) continue; l = x.dat[nx-i1]; if (!l) continue; alu.mul(h, l, h, l); k = nc - i; if (k >= 0) alu.add(c.dat[k], c.dat[k], l); k--; if (k>=0) alu.adc(c.dat[k], c.dat[k],h); k--; for (; (alu.cy) && (k>=0); k--) alu.inc(c.dat[k]); } c.shl(1); for (i = 0; i <= N; i += 2) { i0 = i>>1; h = x.dat[nx-i0]; if (!h) continue; alu.mul(h, l, h, h); k = nc - i; if (k >= 0) alu.add(c.dat[k], c.dat[k],l); k--; if (k>=0) alu.adc(c.dat[k], c.dat[k], h); k--; for (; (alu.cy) && (k >= 0); k--) alu.inc(c.dat[k]); } c.bits = c.siz<<5; c.exp = x.exp + x.exp + ((c.siz - x.siz - x.siz)<<5) + 1; c.sig = sig; *this = c; } 

使用Karatsuba乘法

(感谢Calpis)

我实现了Karatsuba乘法,但是结果比使用简单的O(N ^ 2)乘法慢得多,可能是因为那种可怕的递归,我看不到任何方法可以避免。 这是一个折衷必须在非常大的数字(大于几百位数)…但即使如此,有很多的内存传输。 有没有一种方法来避免递归调用(非递归变体,…几乎所有的递归算法可以这样做)。 不过,我会试着调整一下,看看会发生什么(避免正常化等,也可能是代码中的一些愚蠢的错误)。 无论如何,在解决Karatsuba案例x * x之后,没有太多的性能增益。

优化的Karatsuba乘法

y = x ^ 2循环1000倍,0.9 <x <1〜32 * 98位的性能测试:

 x = 0.98765588997654321000000009876... | 98*32 bits sqr [ 213.989 ms ] ... O((N+1)*N/2) fast sqr mul1[ 363.472 ms ] ... O(N^2) classic multiplication mul2[ 349.384 ms ] ... O(3*(N^log2(3))) optimized Karatsuba multiplication mul3[ 9345.127 ms] ... O(3*(N^log2(3))) unoptimized Karatsuba multiplication x = 0.98765588997654321000... | 195*32 bits sqr [ 883.01 ms ] mul1[ 1427.02 ms ] mul2[ 1089.84 ms ] x = 0.98765588997654321000... | 389*32 bits sqr [ 3189.19 ms ] mul1[ 5553.23 ms ] mul2[ 3159.07 ms ] 

在对Karatsuba进行优化之后,代码比以前快得多了。 尽管如此,对于较小的数字来说,它稍微比我的O(N ^ 2)乘法的一半速度慢。 对于更大的数字,用Booth乘法的复杂度给出的比率更快。 乘法门限大约为32 * 98位,而sqr大约为32 * 389位,所以如果输入位的总和超过了这个门限值,那么Karatsuba乘法将被用于加速乘法,对于sqr也是如此。

顺便说一句,优化包括:

  • 通过太大的递归参数来最小化堆垃圾
  • 避免使用携带的任何bignum aritmetics(+, – )32位ALU代替。
  • 忽略0 * y或x * 0或0 * 0的情况
  • 重新格式化输入x,y的数字大小为2的幂,以避免重新分配
  • 对z1 =(x0 + x1)*(y0 + y1)执行模乘法以使递归最小化

修改Schönhage-Strassen乘法到sqr实现

我已经测试过使用FTT和NTT转换来加速sqr计算。 结果是这些:

  1. FTT

    • 失去准确性,因此需要高精度的复数
    • 这实际上大大减缓了事情,所以没有加速。
    • 结果不准确(可能是错误的四舍五入)
    • FTT是无法使用的
  2. NTT

    • NTT是有限域DFT,所以不会发生精度损失。
    • 需要对无符号整数进行模块化运算:modpow,modmul,modadd和modsub
    • 我使用DWORD(32位无符号整数)。
    • 由于溢出问题,NTT输入/输出矢量大小受限制! 对于32位模块化运算,N被限制为(2 ^ 32)/(max(input [])^ 2),所以bigint必须被分成更小的块(我使用BYTES,所以bigint处理的最大尺寸是^ 32)/((2 ^ 8)^ 2)= 2 ^ 16个字节= 2 ^ 14个DWORD = 16384个DWORD)。
    • sqr仅使用1xNTT + 1xINTT而不是2xNTT + 1xINTT来进行乘法运算。
    • NTT的使用速度太慢,并且在我的实现(对于mul和sqr)中实际使用的阈值数量太大,甚至可能超过溢出限制,所以应该使用64位模块化算术,这会减慢速度下来更多。
    • NTT是为我的目的也是无法使用的

一些测量:

 a = 0.98765588997654321000 | 389*32 bits looped 1x times sqr1[ 3.177 ms ] fast sqr sqr2[ 720.419 ms ] NTT sqr mul1[ 5.588 ms ] simpe mul mul2[ 3.172 ms ] karatsuba mul mul3[ 1053.382 ms ] NTT mul 

我的实现:

 void arbnum::sqr_NTT(const arbnum &x) { // O(N*log(N)*(log(log(N)))) - 1x NTT // Schönhage-Strassen sqr // To prevent NTT overflow: n <= 48K * 8 bit -> result siz <= 12K * 32 bit -> x.siz + y.siz <= 12K!!! int i, j, k, n; int s = x.sig*x.sig, exp0 = x.exp + x.exp - ((x.siz+x.siz)<<5) + 2; i = x.siz; for (n = 1; n < i; n<<=1) ; if (n + n > 0x3000) { _error(_arbnum_error_TooBigNumber); zero(); return; } n <<= 3; DWORD *xx, *yy, q, qq; xx = new DWORD[n+n]; #ifdef _mmap_h if (xx) mmap_new(xx, (n+n) << 2); #endif if (xx==NULL) { _error(_arbnum_error_NotEnoughMemory); zero(); return; } yy = xx + n; // Zero padding (and split DWORDs to BYTEs) for (i--, k=0; i >= 0; i--) { q = x.dat[i]; xx[k] = q&0xFF; k++; q>>=8; xx[k] = q&0xFF; k++; q>>=8; xx[k] = q&0xFF; k++; q>>=8; xx[k] = q&0xFF; k++; } for (;k<n;k++) xx[k] = 0; //NTT fourier_NTT ntt; ntt.NTT(yy,xx,n); // init NTT for n // Convolution for (i=0; i<n; i++) yy[i] = modmul(yy[i], yy[i], ntt.p); //INTT ntt.INTT(xx, yy); //suma q=0; for (i = 0, j = 0; i<n; i++) { qq = xx[i]; q += qq&0xFF; yy[ni-1] = q&0xFF; q>>=8; qq>>=8; q+=qq; } // Merge WORDs to DWORDs and copy them to result _alloc(n>>2); for (i = 0, j = 0; i<siz; i++) { q =(yy[j]<<24)&0xFF000000; j++; q |=(yy[j]<<16)&0x00FF0000; j++; q |=(yy[j]<< 8)&0x0000FF00; j++; q |=(yy[j] )&0x000000FF; j++; dat[i] = q; } #ifdef _mmap_h if (xx) mmap_del(xx); #endif delete xx; bits = siz<<5; sig = s; exp = exp0 + (siz<<5) - 1; // _normalize(); } 

结论

对于较小的数字,这是我的快速sqr方法的最佳选择,并在阈值karatsuba乘法后更好。 但我仍然认为应该有一些我们忽视的微不足道的东西。 有没有其他的想法?

NTT优化

经过大规模优化(主要是NTT)之后:堆栈溢出问题模块化算法和NTT(有限域DFT)优化

一些值已经改变:

 a = 0.98765588997654321000 | 1553*32bits looped 10x times mul2[ 28.585 ms ] Karatsuba mul mul3[ 26.311 ms ] NTT mul 

所以现在NTT乘法在大约1500 * 32位的门限之后终于比Karatsuba快。

一些测量和bug发现

 a = 0.99991970486 | 1553*32 bits looped: 10x sqr1[ 58.656 ms ] fast sqr sqr2[ 13.447 ms ] NTT sqr mul1[ 102.563 ms ] simpe mul mul2[ 28.916 ms ] Karatsuba mul Error mul3[ 19.470 ms ] NTT mul 

我发现我的Karatsuba(over / under)流过了每个双字节的LSB。 当我研究了,我会更新代码…

而且,在NTT优化之后,阈值改变了,所以NTT sqr是310×32位= 9920位操作数NTT多位是1396×32位= 44672位 结果 (操作数位的总和)。

由@greybeard修复Karatsuba代码

 //--------------------------------------------------------------------------- void arbnum::_mul_karatsuba(DWORD *z, DWORD *x, DWORD *y, int n) { // Recursion for Karatsuba // z[2n] = x[n]*y[n]; // n=2^m int i; for (i=0; i<n; i++) if (x[i]) { i=-1; break; } // x==0 ? if (i < 0) for (i = 0; i<n; i++) if (y[i]) { i = -1; break; } // y==0 ? if (i >= 0) { for (i = 0; i < n + n; i++) z[i]=0; return; } // 0.? = 0 if (n == 1) { alu.mul(z[0], z[1], x[0], y[0]); return; } if (n< 1) return; int n2 = n>>1; _mul_karatsuba(z+n, x+n2, y+n2, n2); // z0 = x0.y0 _mul_karatsuba(z , x , y , n2); // z2 = x1.y1 DWORD *q = new DWORD[n<<1], *q0, *q1, *qq; BYTE cx,cy; if (q == NULL) { _error(_arbnum_error_NotEnoughMemory); return; } #define _add { alu.add(qq[i], q0[i], q1[i]); for (i--; i>=0; i--) alu.adc(qq[i], q0[i], q1[i]); } // qq = q0 + q1 ...[i..0] #define _sub { alu.sub(qq[i], q0[i], q1[i]); for (i--; i>=0; i--) alu.sbc(qq[i], q0[i], q1[i]); } // qq = q0 - q1 ...[i..0] qq = q; q0 = x + n2; q1 = x; i = n2 - 1; _add; cx = alu.cy; // =x0+x1 qq = q + n2; q0 = y + n2; q1 = y; i = n2 - 1; _add; cy = alu.cy; // =y0+y1 _mul_karatsuba(q + n, q + n2, q, n2); // =(x0+x1)(y0+y1) mod ((2^N)-1) if (cx) { qq = q + n; q0 = qq; q1 = q + n2; i = n2 - 1; _add; cx = alu.cy; }// += cx*(y0 + y1) << n2 if (cy) { qq = q + n; q0 = qq; q1 = q; i = n2 -1; _add; cy = alu.cy; }// +=cy*(x0+x1)<<n2 qq = q + n; q0 = qq; q1 = z + n; i = n - 1; _sub; // -=z0 qq = q + n; q0 = qq; q1 = z; i = n - 1; _sub; // -=z2 qq = z + n2; q0 = qq; q1 = q + n; i = n - 1; _add; // z1=(x0+x1)(y0+y1)-z0-z2 DWORD ccc=0; if (alu.cy) ccc++; // Handle carry from last operation if (cx || cy) ccc++; // Handle carry from before last operation if (ccc) { i = n2 - 1; alu.add(z[i], z[i], ccc); for (i--; i>=0; i--) if (alu.cy) alu.inc(z[i]); else break; } delete[] q; #undef _add #undef _sub } //--------------------------------------------------------------------------- void arbnum::mul_karatsuba(const arbnum &x, const arbnum &y) { // O(3*(N)^log2(3)) ~ O(3*(N^1.585)) // Karatsuba multiplication // int s = x.sig*y.sig; arbnum a, b; a = x; b = y; a.sig = +1; b.sig = +1; int i, n; for (n = 1; (n < a.siz) || (n < b.siz); n <<= 1) ; a._realloc(n); b._realloc(n); _alloc(n + n); for (i=0; i < siz; i++) dat[i]=0; _mul_karatsuba(dat, a.dat, b.dat, n); bits = siz << 5; sig = s; exp = a.exp + b.exp + ((siz-a.siz-b.siz)<<5) + 1; // _normalize(); } //--------------------------------------------------------------------------- 

我的arbnum号码表示:

 // dat is MSDW first ... LSDW last DWORD *dat; int siz,exp,sig,bits; 
  • dat[siz]是mantisa。 LSDW表示最不重要的DWORD。
  • expdat[0]的MSB的指数,
  • 第一个非零位在尾数中!

     // |-----|---------------------------|---------------|------| // | sig | MSB mantisa LSB | exponent | bits | // |-----|---------------------------|---------------|------| // | +1 | 0.(0 ... 0) | 2^0 | 0 | +zero // | -1 | 0.(0 ... 0) | 2^0 | 0 | -zero // |-----|---------------------------|---------------|------| // | +1 | 1.(dat[0] ... dat[siz-1]) | 2^exp | n | +number // | -1 | 1.(dat[0] ... dat[siz-1]) | 2^exp | n | -number // |-----|---------------------------|---------------|------| // | +1 | 1.0 | 2^+0x7FFFFFFE | 1 | +infinity // | -1 | 1.0 | 2^+0x7FFFFFFE | 1 | -infinity // |-----|---------------------------|---------------|------| 

如果我正确地理解了你的算法,看起来O(n^2)其中n是数字的位数。

你看过Karatsuba算法吗? 它使用分而治之的方法来加速乘法。 这可能值得一看。

如果你正在寻找一个新的更好的指数,你可能需要编写它。 这是来自golang的代码。

https://code.google.com/p/go/source/browse/src/pkg/math/exp_amd64.s