Skip to content

Commit

Permalink
Speed up coded_bkw computation
Browse files Browse the repository at this point in the history
Avoids symbolic computation at all cost for faster computation,
especially since on line 62 there's a `sum` which computes the sum of
(I believe) hundreds of expressions. Benchmarks on my laptop:

Kyber512: 277.6s -> 12.25s
Kyber1024: Long time -> 17.40s
TFHE630: 539.4s -> 20.15s

Look at that, 20x speed up!
  • Loading branch information
grhkm21 committed Jul 30, 2023
1 parent 7d5b7fe commit f5b8886
Showing 1 changed file with 20 additions and 6 deletions.
26 changes: 20 additions & 6 deletions estimator/lwe_bkw.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,18 @@
"""
See :ref:`Coded-BKW for LWE` for what is available.
"""
from sage.all import ZZ, ceil, log, floor, sqrt, var, find_root, erf, oo, cached_function
from sage.rings.all import QQ
from sage.calculus.var import var
from sage.functions.all import log
from sage.functions.error import erf
from sage.functions.other import floor, ceil
from sage.misc.cachefunc import cached_function
from sage.misc.functional import sqrt
from sage.numerical.optimize import find_root
from sage.rings.infinity import infinity as oo
from sage.rings.integer_ring import ZZ
from sage.rings.real_mpfr import RR

from .lwe_parameters import LWEParameters
from .util import local_minimum
from .cost import Cost
Expand Down Expand Up @@ -44,18 +55,21 @@ def ntest(n, ell, t1, t2, b, q):
return 0

# solve for ntest by aiming for ntop == 0
ntest = var("ntest")
sigma_set = sqrt(q ** (2 * (1 - ell / ntest)) / 12)
ncod = sum(CodedBKW.N(i, sigma_set, b, q) for i in range(1, t2 + 1))
ntop = n - ncod - ntest - t1 * b
def ntop(ntest):
# Patch so that `find_root` (which uses float) doesn't error
ntest = RR(ntest)
sigma_set = sqrt(q ** (2 * (1 - ell / ntest)) / 12)
ncod = sum(CodedBKW.N(i, sigma_set, b, q) for i in range(1, t2 + 1))
res = n - ncod - ntest - t1 * b
return res

try:
start = max(int(round(find_root(ntop, 2, n - t1 * b + 1, rtol=0.1))) - 1, 2)
except RuntimeError:
start = 2
ntest_min = 1
for ntest in range(start, n - t1 * b + 1):
if abs(ntop(ntest=ntest).n()) >= abs(ntop(ntest=ntest_min).n()):
if abs(ntop(ntest).n()) >= abs(ntop(ntest_min).n()):
break
ntest_min = ntest
return int(ntest_min)
Expand Down

0 comments on commit f5b8886

Please sign in to comment.