Skip to content

Commit

Permalink
More cleanup related to unwrap and arc
Browse files Browse the repository at this point in the history
  • Loading branch information
tlepoint committed Sep 6, 2023
1 parent 8cb13e1 commit 6234350
Show file tree
Hide file tree
Showing 8 changed files with 42 additions and 38 deletions.
5 changes: 5 additions & 0 deletions crates/fhe-math/src/rq/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@ impl Context {
}
}

/// Creates a context in an `Arc`.
pub fn new_arc(moduli: &[u64], degree: usize) -> Result<Arc<Self>> {
Self::new(moduli, degree).map(Arc::new)
}

/// Returns the modulus as a BigUint.
pub fn modulus(&self) -> &BigUint {
self.rns.modulus()
Expand Down
15 changes: 7 additions & 8 deletions crates/fhe/examples/mulpir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ fn main() -> Result<(), Box<dyn Error>> {
Some("Invalid `--database_size` command".to_string()),
)
} else {
database_size = a[0].parse::<usize>().unwrap()
database_size = a[0].parse::<usize>()?
}
} else if arg.starts_with("--element_size") {
let a: Vec<&str> = arg.rsplit('=').collect();
Expand All @@ -89,7 +89,7 @@ fn main() -> Result<(), Box<dyn Error>> {
Some("Invalid `--element_size` command".to_string()),
)
} else {
elements_size = a[0].parse::<usize>().unwrap()
elements_size = a[0].parse::<usize>()?
}
} else {
print_notice_and_exit(
Expand Down Expand Up @@ -128,8 +128,7 @@ fn main() -> Result<(), Box<dyn Error>> {
.set_degree(degree)
.set_plaintext_modulus(plaintext_modulus)
.set_moduli_sizes(&moduli_sizes)
.build_arc()
.unwrap()
.build_arc()?
);

// Proprocess the database on the server side: the database will be reshaped
Expand Down Expand Up @@ -233,7 +232,7 @@ fn main() -> Result<(), Box<dyn Error>> {
out += &(&dot_product_mod_switch(i, &preprocessed_database)? * ci)
}
rk.relinearizes(&mut out)?;
out.mod_switch_to_last_level();
out.mod_switch_to_last_level()?;
out.to_bytes()
});
println!("📄 Response: {}", HumanBytes(response.len() as u64));
Expand All @@ -243,10 +242,10 @@ fn main() -> Result<(), Box<dyn Error>> {
// (remember the database was reshaped to maximize how many elements) were
// embedded in a single ciphertext.
let answer = timeit!("Client answer", {
let response = bfv::Ciphertext::from_bytes(&response, &params).unwrap();
let response = bfv::Ciphertext::from_bytes(&response, &params)?;

let pt = sk.try_decrypt(&response).unwrap();
let pt = Vec::<u64>::try_decode(&pt, bfv::Encoding::poly_at_level(2)).unwrap();
let pt = sk.try_decrypt(&response)?;
let pt = Vec::<u64>::try_decode(&pt, bfv::Encoding::poly_at_level(2))?;
let plaintext = transcode_to_bytes(&pt, ilog2(plaintext_modulus));
let offset = index
% number_elements_per_plaintext(
Expand Down
18 changes: 8 additions & 10 deletions crates/fhe/examples/sealpir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ fn main() -> Result<(), Box<dyn Error>> {
Some("Invalid `--database_size` command".to_string()),
)
} else {
database_size = a[0].parse::<usize>().unwrap()
database_size = a[0].parse::<usize>()?
}
} else if arg.starts_with("--element_size") {
let a: Vec<&str> = arg.rsplit('=').collect();
Expand All @@ -91,7 +91,7 @@ fn main() -> Result<(), Box<dyn Error>> {
Some("Invalid `--element_size` command".to_string()),
)
} else {
elements_size = a[0].parse::<usize>().unwrap()
elements_size = a[0].parse::<usize>()?
}
} else {
print_notice_and_exit(
Expand Down Expand Up @@ -129,8 +129,7 @@ fn main() -> Result<(), Box<dyn Error>> {
.set_degree(degree)
.set_plaintext_modulus(plaintext_modulus)
.set_moduli_sizes(&moduli_sizes)
.build_arc()
.unwrap()
.build_arc()?
);

// Proprocess the database on the server side: the database will be reshaped
Expand Down Expand Up @@ -209,16 +208,15 @@ fn main() -> Result<(), Box<dyn Error>> {
// The operation is done `5` times to compute an average response time.
let responses: Vec<Vec<u8>> = timeit_n!("Server response", 5, {
let start = std::time::Instant::now();
let query = bfv::Ciphertext::from_bytes(&query, &params);
let query = query.unwrap();
let query = bfv::Ciphertext::from_bytes(&query, &params)?;
let expanded_query = ek_expansion.expands(&query, dim1 + dim2)?;
println!("Expand: {}", DisplayDuration(start.elapsed()));

let query_vec = &expanded_query[..dim1];
let dot_product_mod_switch = move |i, database: &[bfv::Plaintext]| {
let column = database.iter().skip(i).step_by(dim2);
let mut c = bfv::dot_product_scalar(query_vec.iter(), column)?;
c.mod_switch_to_last_level();
c.mod_switch_to_last_level()?;
Ok(c)
};

Expand Down Expand Up @@ -259,7 +257,7 @@ fn main() -> Result<(), Box<dyn Error>> {
expanded_query[dim1..].iter(),
fold.iter().map(|pts| pts.get(i).unwrap()),
)?;
outi.mod_switch_to_last_level();
outi.mod_switch_to_last_level()?;
Ok(outi.to_bytes())
})
.collect::<fhe::Result<Vec<Vec<u8>>>>()?
Expand All @@ -281,7 +279,7 @@ fn main() -> Result<(), Box<dyn Error>> {
.collect_vec();
let decrypted_pt = responses
.iter()
.map(|r| sk.try_decrypt(r).unwrap())
.flat_map(|r| sk.try_decrypt(r))
.collect_vec();
let decrypted_vec = decrypted_pt
.iter()
Expand Down Expand Up @@ -317,7 +315,7 @@ fn main() -> Result<(), Box<dyn Error>> {
)?;

let pt = sk.try_decrypt(&ct).unwrap();
let pt = Vec::<u64>::try_decode(&pt, bfv::Encoding::poly_at_level(2)).unwrap();
let pt = Vec::<u64>::try_decode(&pt, bfv::Encoding::poly_at_level(2))?;
let plaintext = transcode_to_bytes(&pt, ilog2(plaintext_modulus));
let offset = index
% number_elements_per_plaintext(
Expand Down
24 changes: 13 additions & 11 deletions crates/fhe/src/bfv/ciphertext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,30 +31,32 @@ pub struct Ciphertext {

impl Ciphertext {
/// Modulo switch the ciphertext to the last level.
pub fn mod_switch_to_last_level(&mut self) {
pub fn mod_switch_to_last_level(&mut self) -> Result<()> {
self.level = self.par.max_level();
let last_ctx = self.par.ctx_at_level(self.level).unwrap();
let last_ctx = self.par.ctx_at_level(self.level)?;
self.seed = None;
self.c.iter_mut().for_each(|ci| {
for ci in self.c.iter_mut() {
if ci.ctx() != last_ctx {
ci.change_representation(Representation::PowerBasis);
assert!(ci.mod_switch_down_to(last_ctx).is_ok());
ci.mod_switch_down_to(last_ctx)?;
ci.change_representation(Representation::Ntt);
}
});
}
Ok(())
}

/// Modulo switch the ciphertext to the next level.
pub fn mod_switch_to_next_level(&mut self) {
pub fn mod_switch_to_next_level(&mut self) -> Result<()> {
if self.level < self.par.max_level() {
self.seed = None;
self.c.iter_mut().for_each(|ci| {
for ci in self.c.iter_mut() {
ci.change_representation(Representation::PowerBasis);
assert!(ci.mod_switch_down_next().is_ok());
ci.mod_switch_down_next()?;
ci.change_representation(Representation::Ntt);
});
}
self.level += 1
}
Ok(())
}

/// Create a ciphertext from a vector of polynomials.
Expand Down Expand Up @@ -262,7 +264,7 @@ mod tests {
);
assert_eq!(ct3.level, 0);

ct3.mod_switch_to_last_level();
ct3.mod_switch_to_last_level()?;

let c0 = ct3.get(0).unwrap();
let c1 = ct3.get(1).unwrap();
Expand Down Expand Up @@ -290,7 +292,7 @@ mod tests {
let mut ct: Ciphertext = sk.try_encrypt(&pt, &mut rng)?;

assert_eq!(ct.level, 0);
ct.mod_switch_to_last_level();
ct.mod_switch_to_last_level()?;
assert_eq!(ct.level, params.max_level());

let decrypted = sk.try_decrypt(&ct)?;
Expand Down
2 changes: 1 addition & 1 deletion crates/fhe/src/bfv/keys/public_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ impl FheEncrypter<Plaintext, Ciphertext> for PublicKey {
) -> Result<Ciphertext> {
let mut ct = self.c.clone();
while ct.level != pt.level {
ct.mod_switch_to_next_level();
ct.mod_switch_to_next_level()?;
}

let ctx = self.par.ctx_at_level(ct.level)?;
Expand Down
2 changes: 1 addition & 1 deletion crates/fhe/src/bfv/ops/mul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ impl Multiplicator {
};

if self.mod_switch {
c.mod_switch_to_next_level();
c.mod_switch_to_next_level()?;
} else {
c.c.iter_mut()
.for_each(|p| p.change_representation(Representation::Ntt));
Expand Down
8 changes: 4 additions & 4 deletions crates/fhe/src/bfv/parameters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ impl BfvParametersBuilder {

let op = NttOperator::new(&plaintext_modulus, self.degree);

let plaintext_ctx = Arc::new(Context::new(&moduli[..1], self.degree)?);
let plaintext_ctx = Context::new_arc(&moduli[..1], self.degree)?;

let mut delta_rests = vec![];
for m in &moduli {
Expand All @@ -379,8 +379,8 @@ impl BfvParametersBuilder {
let mut scalers = Vec::with_capacity(moduli.len());
let mut mul_params = Vec::with_capacity(moduli.len());
for i in 0..moduli.len() {
let rns = RnsContext::new(&moduli[..moduli.len() - i]).unwrap();
let ctx_i = Arc::new(Context::new(&moduli[..moduli.len() - i], self.degree).unwrap());
let rns = RnsContext::new(&moduli[..moduli.len() - i])?;
let ctx_i = Context::new_arc(&moduli[..moduli.len() - i], self.degree)?;
let mut p = Poly::try_convert_from(
&[rns.lift((&delta_rests).into())],
&ctx_i,
Expand Down Expand Up @@ -409,7 +409,7 @@ impl BfvParametersBuilder {
let mut mul_1_moduli = vec![];
mul_1_moduli.append(&mut moduli[..moduli_sizes.len() - i].to_vec());
mul_1_moduli.append(&mut extended_basis[..n_moduli].to_vec());
let mul_1_ctx = Arc::new(Context::new(&mul_1_moduli, self.degree)?);
let mul_1_ctx = Context::new_arc(&mul_1_moduli, self.degree)?;
mul_params.push(MultiplicationParameters::new(
&ctx_i,
&mul_1_ctx,
Expand Down
6 changes: 3 additions & 3 deletions crates/fhe/src/bfv/plaintext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,18 +280,18 @@ mod tests {

let plaintext = Plaintext::try_encode(&a, Encoding::simd(), &params);
assert!(plaintext.is_ok());
let b = Vec::<u64>::try_decode(&plaintext.unwrap(), Encoding::simd())?;
let b = Vec::<u64>::try_decode(&plaintext?, Encoding::simd())?;
assert_eq!(b, a);

let a = unsafe { params.plaintext.center_vec_vt(&a) };
let plaintext = Plaintext::try_encode(&a, Encoding::poly(), &params);
assert!(plaintext.is_ok());
let b = Vec::<i64>::try_decode(&plaintext.unwrap(), Encoding::poly())?;
let b = Vec::<i64>::try_decode(&plaintext?, Encoding::poly())?;
assert_eq!(b, a);

let plaintext = Plaintext::try_encode(&a, Encoding::simd(), &params);
assert!(plaintext.is_ok());
let b = Vec::<i64>::try_decode(&plaintext.unwrap(), Encoding::simd())?;
let b = Vec::<i64>::try_decode(&plaintext?, Encoding::simd())?;
assert_eq!(b, a);

Ok(())
Expand Down

0 comments on commit 6234350

Please sign in to comment.