diff --git a/estimator/lwe_dual.py b/estimator/lwe_dual.py
index 6ce6957..0344f7f 100644
--- a/estimator/lwe_dual.py
+++ b/estimator/lwe_dual.py
@@ -60,15 +60,12 @@ def dual_reduce(
         )
 
         # Compute new secret distribution
-        if params.Xs.is_sparse:
+        if type(params.Xs) is SparseTernary:
             h = params.Xs.hamming_weight
             if not 0 <= h1 <= h:
                 raise OutOfBoundsError(f"Splitting weight {h1} must be between 0 and h={h}.")
-            # assuming the non-zero entries are uniform
-            p = h1 / 2
-            red_Xs = SparseTernary(params.n - zeta, h / 2 - p)
-            slv_Xs = SparseTernary(zeta, p)
-
+            # Split the +1 and -1 entries in a balanced way.
+            slv_Xs, red_Xs = params.Xs.split_balanced(zeta, h1)
             if h1 == h:
                 # no reason to do lattice reduction if we assume
                 # that the hw on the reduction part is 0
@@ -176,7 +173,7 @@ def cost(
         Logging.log("dual", log_level, f"{repr(cost)}")
 
         rep = 1
-        if params.Xs.is_sparse:
+        if type(params.Xs) is SparseTernary:
             h = params.Xs.hamming_weight
             probability = RR(prob_drop(params.n, h, zeta, h1))
             rep = prob_amplify(success_probability, probability)
@@ -313,7 +310,7 @@ def f(beta):
         beta = cost["beta"]
 
         cost["zeta"] = zeta
-        if params.Xs.is_sparse:
+        if type(params.Xs) is SparseTernary:
             cost["h1"] = h1
         return cost
 
@@ -428,7 +425,7 @@ def __call__(
 
         params = params.normalize()
 
-        if params.Xs.is_sparse:
+        if type(params.Xs) is SparseTernary:
             Cost.register_impermanent(h1=False)
 
         def _optimize_blocksize(
diff --git a/estimator/nd.py b/estimator/nd.py
index 3f332c5..ab195b3 100644
--- a/estimator/nd.py
+++ b/estimator/nd.py
@@ -423,6 +423,27 @@ def resize(self, new_n):
         """
         return SparseTernary(new_n, self.p, self.m)
 
+    def split_balanced(self, new_n, new_hw=None):
+        """
+        Split the +1 and -1 entries in a balanced way, and return two SparseTernary distributions:
+        one of dimension `new_n` and the other of dimension `n - new_n`.
+
+        :param new_n: dimension of the first noise distribution
+        :param new_hw: hamming weight of the first noise distribution. If `None`, we take the most likely weight.
+        :return: tuple of (SparseTernary, SparseTernary)
+        """
+        n, hw = len(self), self.hamming_weight
+        if new_hw is None:
+            # The most likely split has the same density: new_hw / new_n = hw / n.
+            new_hw = int(round(hw * new_n / n))
+
+        new_p = int(round((new_hw * self.p) / hw))
+        new_m = new_hw - new_p
+        return (
+            SparseTernary(new_n, new_p, new_m),
+            SparseTernary(n - new_n, self.p - new_p, self.m - new_m)
+        )
+
     @property
     def hamming_weight(self):
         return self.p + self.m
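
Note on the new helper: below is a minimal usage sketch of `SparseTernary.split_balanced` as `dual_reduce` now calls it, assuming the `SparseTernary(n, p, m)` constructor from estimator/nd.py (dimension `n`, `p` entries equal to +1, `m` entries equal to -1, as the patched code suggests); the concrete parameter values are illustrative only.

    from estimator.nd import SparseTernary

    # Secret of dimension n = 1024 with 32 entries each of +1 and -1 (hamming weight 64).
    Xs = SparseTernary(1024, 32, 32)

    # dual_reduce splits off zeta = 256 guessed coordinates carrying h1 = 16 of the
    # weight: slv_Xs is the distribution on the guessed part, red_Xs on the part
    # that undergoes lattice reduction.
    slv_Xs, red_Xs = Xs.split_balanced(256, 16)

    assert slv_Xs.hamming_weight == 16 and red_Xs.hamming_weight == 48
    # The +1/-1 counts are split proportionally: round(16 * 32 / 64) = 8 of each
    # sign, whereas the old code's p = h1 / 2 could yield a non-integral count.
    assert (slv_Xs.p, slv_Xs.m) == (8, 8) and (red_Xs.p, red_Xs.m) == (24, 24)

    # With new_hw omitted, the most likely split preserves the density hw / n:
    # round(64 * 256 / 1024) = 16, so this gives the same split as above.
    slv_Xs2, _ = Xs.split_balanced(256)
    assert slv_Xs2.hamming_weight == 16

This also explains the `slv_Xs, red_Xs` ordering in the call site: `split_balanced(zeta, h1)` returns the `zeta`-dimensional part first, matching the guessed (solving) coordinates.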