Skip to content

Commit

Permalink
Merge pull request #59 from hamishun/refactor
Browse files Browse the repository at this point in the history
Some refactoring
  • Loading branch information
malb authored Nov 29, 2022
2 parents 9687562 + 8d8ec87 commit 63fa248
Show file tree
Hide file tree
Showing 16 changed files with 179 additions and 252 deletions.
1 change: 1 addition & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ At present, this estimator is maintained by Martin Albrecht. Contributors are:
- Cedric Lefebvre
- Fernando Virdia
- Florian Göpfert
- Hamish Hunt
- James Owen
- Léo Ducas
- Markus Schmidt
Expand Down
4 changes: 2 additions & 2 deletions docker/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ docker exec lattice-estimator-tests sage -sh -c pytest

Note that due to [this open
ticket](https://trac.sagemath.org/ticket/34242#comment:20) on Sage, the
published sage container is using an OEL'ed version of Ubunutu, and so we have
published sage container is using an OEL'ed version of Ubuntu, and so we have
to wait for the sagemath image to be updated in order to use `Dockerfile`
standlone (`git clone`ing from inside the container).
standalone (`git clone` from inside the container).

2 changes: 1 addition & 1 deletion docs/contributing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ We use Python classes as namespaces, something like this::
We explain what is going on above:

1. LWE objects know how to normalize themselves by calling ``params.normalize()``. We assume that high-level functions (such as ``__call___`` above) call ``params.normalize()``.
1. LWE objects know how to normalize themselves by calling ``params.normalize()``. We assume that high-level functions (such as ``__call__`` above) call ``params.normalize()``.

2. Often optimizing parameters means finding the optimimum in some range. We provide some syntactical sugar to make this easy/readable.

Expand Down
101 changes: 38 additions & 63 deletions estimator/cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,22 @@ class Cost:
"problem": False,
}

@staticmethod
def _update_without_overwrite(dst, src):
keys_intersect = set(dst.keys()) & set(src.keys())
attempts = [
f"{k}: {dst[k]} with {src[k]}" for k in keys_intersect if dst[k] != src[k]
]
if len(attempts) > 0:
s = ", ".join(attempts)
raise ValueError(f"Attempting to overwrite {s}")
dst.update(src)

@classmethod
def register_impermanent(cls, data=None, **kwds):
if data is not None:
for k, v in data.items():
if cls.impermanents.get(k, v) != v:
raise ValueError(f"Attempting to overwrite {k}:{cls.impermanents[k]} with {v}")
cls.impermanents[k] = v

for k, v in kwds.items():
if cls.impermanents.get(k, v) != v:
raise ValueError(f"Attempting to overwrite {k}:{cls.impermanents[k]} with {v}")
cls.impermanents[k] = v
cls._update_without_overwrite(cls.impermanents, data)
cls._update_without_overwrite(cls.impermanents, kwds)

key_map = {
"delta": "δ",
Expand All @@ -56,7 +60,7 @@ def __init__(self, **kwds):
for k, v in kwds.items():
setattr(self, k, v)

def str(self, keyword_width=None, newline=None, round_bound=2048, compact=False): # noqa C901
def str(self, keyword_width=0, newline=False, round_bound=2048, compact=False): # noqa C901
"""
:param keyword_width: keys are printed with this width
Expand All @@ -73,19 +77,9 @@ def str(self, keyword_width=None, newline=None, round_bound=2048, compact=False)
"""

def wfmtf(k):
if keyword_width:
fmt = "%%%ss" % keyword_width
else:
fmt = "%s"
return fmt % k

d = self.__dict__
s = []
for k, v in d.items():
if k == "problem": # we store the problem instance in a cost object for reference
continue
kk = wfmtf(self.key_map.get(k, k))
def value_str(k, v):
kstr = self.key_map.get(k, k)
kk = f"{kstr:>{keyword_width}}"
try:
if (1 / round_bound < abs(v) < round_bound) or (not v) or (k in self.val_map):
if abs(v % 1) < 0.0000001:
Expand All @@ -96,19 +90,19 @@ def wfmtf(k):
vv = "%7s" % ("≈2^%.1f" % log(v, 2))
except TypeError: # strings and such
vv = "%8s" % v
if compact:
if compact is True:
kk = kk.strip()
vv = vv.strip()
s.append(f"{kk}: {vv}")
return f"{kk}: {vv}"

if not newline:
return ", ".join(s)
else:
return "\n".join(s)
# we store the problem instance in a cost object for reference
s = [value_str(k, v) for k, v in self.__dict__.items() if k != "problem"]
delimiter = "\n" if newline is True else ", "
return delimiter.join(s)

def reorder(self, *args):
"""
Return a new ordered dict from the key:value pairs in dictinonary but reordered such that the
Return a new ordered dict from the key:value pairs in dictionary but reordered such that the
keys given to this function come first.
:param args: keys which should come first (in order)
Expand All @@ -123,25 +117,18 @@ def reorder(self, *args):
b: 2, c: 3, a: 1
"""
keys = list(self.__dict__.keys())
for key in args:
keys.pop(keys.index(key))
keys = list(args) + keys
r = dict()
for key in keys:
r[key] = self.__dict__[key]
return Cost(**r)
reord = {k: self.__dict__[k] for k in args if k in self.__dict__}
reord.update(self.__dict__)
return Cost(**reord)

def filter(self, **keys):
"""
Return new ordered dictinonary from dictionary restricted to the keys.
Return new ordered dictionary from dictionary restricted to the keys.
:param dictionary: input dictionary
:param keys: keys which should be copied (ordered)
"""
r = dict()
for key in keys:
r[key] = self.__dict__[key]
r = {k: self.__dict__[k] for k in keys if k in self.__dict__}
return Cost(**r)

def repeat(self, times, select=None):
Expand Down Expand Up @@ -170,20 +157,14 @@ def repeat(self, times, select=None):
impermanents = dict(self.impermanents)

if select is not None:
for key in select:
impermanents[key] = select[key]
impermanents.update(select)

ret = dict()
for key in self.__dict__:
try:
if impermanents[key]:
ret[key] = times * self.__dict__[key]
else:
ret[key] = self.__dict__[key]
except KeyError:
raise NotImplementedError(
f"You found a bug, this function does not know about '{key}' but should."
)
try:
ret = {k: times * v if impermanents[k] else v for k, v in self.__dict__.items()}
except KeyError as error:
raise NotImplementedError(
f"You found a bug, this function does not know about about a key but should: {error}"
)
ret["repetitions"] = times * ret.get("repetitions", 1)
return Cost(**ret)

Expand All @@ -209,14 +190,8 @@ def combine(self, right, base=None):
c: 3, a: 1, b: 2
"""
if base is None:
cost = dict()
else:
cost = base.__dict__
for key in self.__dict__:
cost[key] = self.__dict__[key]
for key in right:
cost[key] = right.__dict__[key]
base_dict = {} if base is None else base.__dict__
cost = {**base_dict, **self.__dict__, **right.__dict__}
return Cost(**cost)

def __bool__(self):
Expand Down
63 changes: 31 additions & 32 deletions estimator/gb.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,10 @@ def gb_cost(n, D, omega=2, prec=None):

for dreg in range(prec):
if s[dreg] < 0:
retval["dreg"] = dreg
retval["rop"] = binomial(n + dreg, dreg) ** omega
retval["mem"] = binomial(n + dreg, dreg) ** 2
break
else:
return retval

retval["dreg"] = dreg
retval["rop"] = binomial(n + dreg, dreg) ** omega
retval["mem"] = binomial(n + dreg, dreg) ** 2

return retval

Expand Down Expand Up @@ -115,20 +112,26 @@ def cost_Gaussian_like(cls, params, success_probability=0.99, omega=2, log_level
dn = cls.equations_for_secret(params)

best, stuck = None, 0
for t in range(ceil(params.Xe.stddev), params.n):
d = 2 * t + 1

def t_and_m_can(t):
C = RR(t / params.Xe.stddev)
assert C >= 1 # if C is too small, we ignore it
# Pr[success]^m = Pr[overall success]
single_prob = AroraGB.ps_single(C)
if single_prob == 1:
m_can = 2**31 # some arbitrary max
else:
m_can = log(success_probability, 2) / log(single_prob, 2)
m_can = floor(m_can)
# log(success_probability, single_prob)
# == log(success_probability, 2) / log(single_prob, 2)
m_can = floor(log(success_probability, single_prob))

return t, m_can

for t, m_can in map(t_and_m_can, range(ceil(params.Xe.stddev), params.n)):
if m_can > params.m:
break

d = 2 * t + 1
current = gb_cost(params.n, [(d, m_can)] + dn, omega)

if current["dreg"] == oo:
Expand All @@ -143,18 +146,15 @@ def cost_Gaussian_like(cls, params, success_probability=0.99, omega=2, log_level

if best is None:
best = current
elif best > current:
best = current
stuck = 0
else:
if best > current:
best = current
stuck = 0
else:
stuck += 1
if stuck >= 5:
break

if best is None:
best = Cost(rop=oo, dreg=oo)
return best
stuck += 1
if stuck >= 5:
break

return best if best is not None else Cost(rop=oo, dreg=oo)

@classmethod
def equations_for_secret(cls, params):
Expand All @@ -164,18 +164,17 @@ def equations_for_secret(cls, params):
:param params: LWE parameters.
"""
if params.Xs <= params.Xe:
a, b = params.Xs.bounds
if b - a < oo:
d = b - a + 1
elif params.Xs.is_Gaussian_like:
d = 2 * ceil(3 * params.Xs.stddev) + 1
else:
raise NotImplementedError(f"Do not know how to handle {params.Xs}.")
dn = [(d, params.n)]
if params.Xs > params.Xe:
return []

a, b = params.Xs.bounds
if b - a < oo:
d = b - a + 1
elif params.Xs.is_Gaussian_like:
d = 2 * ceil(3 * params.Xs.stddev) + 1
else:
dn = []
return dn
raise NotImplementedError(f"Do not know how to handle {params.Xs}.")
return [(d, params.n)]

def __call__(
self, params: LWEParameters, success_probability=0.99, omega=2, log_level=1, **kwds
Expand Down
2 changes: 1 addition & 1 deletion estimator/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,5 @@ def set_level(lvl, loggers=None):
def log(cls, logger, level, msg, *args, **kwds):
level = int(level)
return logging.getLogger(logger).log(
cls.INFO - 2 * level, f"{{{level}}} " + msg, *args, **kwds
cls.INFO - 2 * level, f"{{{level}}} {msg}", *args, **kwds
)
34 changes: 15 additions & 19 deletions estimator/lwe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
High-level LWE interface
"""

from functools import partial
from .lwe_primal import primal_usvp, primal_bdd, primal_hybrid
from .lwe_bkw import coded_bkw
from .lwe_guess import exhaustive_search, mitm, distinguish # noqa
Expand Down Expand Up @@ -42,7 +43,6 @@ def rough(cls, params, jobs=1, catch_exceptions=True):
"""
# NOTE: Don't import these at the top-level to avoid circular imports
from functools import partial
from .reduction import RC
from .util import batch_estimate, f_name

Expand All @@ -58,8 +58,6 @@ def rough(cls, params, jobs=1, catch_exceptions=True):
algorithms["hybrid"] = partial(
primal_hybrid, red_cost_model=RC.ADPS16, red_shape_model="gsa"
)

if params.Xs.is_sparse:
algorithms["dual_mitm_hybrid"] = partial(
dual_hybrid, red_cost_model=RC.ADPS16, mitm_optimization=True
)
Expand All @@ -78,18 +76,18 @@ def rough(cls, params, jobs=1, catch_exceptions=True):
params, algorithms.values(), log_level=1, jobs=jobs, catch_exceptions=catch_exceptions
)
res_raw = res_raw[params]
res = {}
for algorithm in algorithms:
for k, v in res_raw.items():
if f_name(algorithms[algorithm]) == k:
res[algorithm] = v
res = {
algorithm: v for algorithm, attack in algorithms.items()
for k, v in res_raw.items()
if f_name(attack) == k
}

for algorithm in algorithms:
for k, v in res.items():
if algorithm == k:
if v["rop"] == oo:
continue
print(f"{algorithm:20s} :: {repr(v)}")
print(f"{algorithm:20s} :: {v!r}")
return res

def __call__(
Expand Down Expand Up @@ -173,20 +171,18 @@ def __call__(
dual_hybrid, red_cost_model=red_cost_model, mitm_optimization=True
)

for k in deny_list:
del algorithms[k]
for k, v in add_list:
algorithms[k] = v
algorithms = {k: v for k, v in algorithms.items() if k not in deny_list}
algorithms.update(add_list)

res_raw = batch_estimate(
params, algorithms.values(), log_level=1, jobs=jobs, catch_exceptions=catch_exceptions
)
res_raw = res_raw[params]
res = {}
for algorithm in algorithms:
for k, v in res_raw.items():
if f_name(algorithms[algorithm]) == k:
res[algorithm] = v
res = {
algorithm: v for algorithm, attack in algorithms.items()
for k, v in res_raw.items()
if f_name(attack) == k
}

for algorithm in algorithms:
for k, v in res.items():
Expand All @@ -197,7 +193,7 @@ def __call__(
continue
if k == "dual_mitm_hybrid" and res["dual_hybrid"]["rop"] < v["rop"]:
continue
print(f"{algorithm:20s} :: {repr(v)}")
print(f"{algorithm:20s} :: {v!r}")
return res


Expand Down
Loading

0 comments on commit 63fa248

Please sign in to comment.