Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: zeroize secret stuff #21

Merged
merged 4 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,15 @@ version = "0.1.0"
edition = "2021"

[dependencies]
bitvec = { version = "1.0.1", default-features = false, features = ["alloc"] }
sha2 = { version = "0.10.6", default-features = false }

thiserror = { version = "1", optional = true }
zeroize = {version = "1.8.1", features = ["derive"]}

[dev-dependencies]
hex = "0.4.3"

[features]
default = ["std", "sufficient-memory"]
std = ["thiserror"]
std = []
sufficient-memory = []

[lib]
Expand Down
38 changes: 29 additions & 9 deletions src/error.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
#[cfg(feature = "std")]
use std::fmt::{Debug, Display, Formatter, Result};
use std::string::String;

#[cfg(not(feature = "std"))]
use alloc::string::String;

#[cfg(feature = "std")]
use thiserror::Error;
use std::fmt::{Debug, Display, Formatter, Result as FmtResult};

#[cfg(not(feature = "std"))]
use core::fmt::{Debug, Display, Formatter, Result as FmtResult};

#[derive(Debug)]
#[cfg_attr(feature = "std", derive(Error))]
pub enum ErrorWordList {
pub enum ErrorMnemonic {
DamagedWord,
InvalidChecksum,
InvalidEntropy,
Expand All @@ -14,10 +20,24 @@ pub enum ErrorWordList {
WordsNumber,
}

// TODO: provide actual error descriptions.
#[cfg(feature = "std")]
impl Display for ErrorWordList {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
<Self as Debug>::fmt(self, f)
impl ErrorMnemonic {
fn error_text(&self) -> String {
match &self {
ErrorMnemonic::DamagedWord => String::from("Unable to extract a word from the word list."),
ErrorMnemonic::InvalidChecksum => String::from("Invalid text mnemonic: the checksum does not match."),
ErrorMnemonic::InvalidEntropy => String::from("Unable to calculate the mnemonic from entropy. Invalid entropy length."),
ErrorMnemonic::InvalidWordNumber => String::from("Ordinal number for word requested is higher than total number of words in the word list."),
ErrorMnemonic::NoWord => String::from("Requested word in not in the word list."),
ErrorMnemonic::WordsNumber => String::from("Invalid text mnemonic: unexpected number of words."),
}
}
}

impl Display for ErrorMnemonic {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
write!(f, "{}", self.error_text())
}
}

#[cfg(feature = "std")]
impl std::error::Error for ErrorMnemonic {}
131 changes: 98 additions & 33 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

#[cfg(not(feature = "std"))]
extern crate alloc;

#[cfg(feature = "std")]
#[macro_use]
extern crate std;

#[cfg(not(feature = "std"))]
Expand All @@ -12,8 +14,8 @@ use alloc::{string::String, vec::Vec};
#[cfg(feature = "std")]
use std::{string::String, vec::Vec};

use bitvec::prelude::{BitSlice, BitVec, Msb0};
use sha2::{Digest, Sha256};
use zeroize::{Zeroize, ZeroizeOnDrop};

pub mod error;

Expand All @@ -26,43 +28,44 @@ mod tests;
#[cfg(any(feature = "sufficient-memory", test))]
pub mod wordlist;

use crate::error::ErrorWordList;
use crate::error::ErrorMnemonic;

pub const TOTAL_WORDS: usize = 2048;
pub const WORD_MAX_LEN: usize = 8;
pub const SEPARATOR_LEN: usize = 1;

pub const MAX_SEED_LEN: usize = 24;

#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, Zeroize)]
pub struct Bits11(u16);

impl Bits11 {
pub fn bits(self) -> u16 {
self.0
}
pub fn from(i: u16) -> Result<Self, ErrorWordList> {
pub fn from(i: u16) -> Result<Self, ErrorMnemonic> {
if (i as usize) < TOTAL_WORDS {
Ok(Self(i))
} else {
Err(ErrorWordList::InvalidWordNumber)
Err(ErrorMnemonic::InvalidWordNumber)
}
}
}

#[derive(Clone, Debug)]
pub struct WordListElement<L: AsWordList + ?Sized> {
pub word: L::Word,
pub bits11: Bits11,
}

pub trait AsWordList {
type Word: AsRef<str>;
fn get_word(&self, bits: Bits11) -> Result<Self::Word, ErrorWordList>;
fn get_word(&self, bits: Bits11) -> Result<Self::Word, ErrorMnemonic>;
fn get_words_by_prefix(
&self,
prefix: &str,
) -> Result<Vec<WordListElement<Self>>, ErrorWordList>;
fn bits11_for_word(&self, word: &str) -> Result<Bits11, ErrorWordList>;
) -> Result<Vec<WordListElement<Self>>, ErrorMnemonic>;
fn bits11_for_word(&self, word: &str) -> Result<Bits11, ErrorMnemonic>;
}

#[derive(Debug, Copy, Clone)]
Expand All @@ -75,14 +78,14 @@ pub enum MnemonicType {
}

impl MnemonicType {
fn from(len: usize) -> Result<Self, ErrorWordList> {
fn from(len: usize) -> Result<Self, ErrorMnemonic> {
match len {
12 => Ok(Self::Words12),
15 => Ok(Self::Words15),
18 => Ok(Self::Words18),
21 => Ok(Self::Words21),
24 => Ok(Self::Words24),
_ => Err(ErrorWordList::WordsNumber),
_ => Err(ErrorMnemonic::WordsNumber),
}
}
fn checksum_bits(&self) -> u8 {
Expand All @@ -108,27 +111,67 @@ impl MnemonicType {
}
}

#[derive(Clone, Debug, ZeroizeOnDrop)]
struct BitsHelper {
bits: Vec<bool>,
}

impl BitsHelper {
fn with_capacity(cap: usize) -> Self {
Self {
bits: Vec::with_capacity(cap),
}
}

fn extend_from_byte(&mut self, byte: u8) {
for i in (0..BITS_IN_BYTE).rev() {
let bit = (byte & (1 << i)) != 0;
self.bits.push(bit);
}
}

fn extend_from_bits11(&mut self, bits11: &Bits11) {
let two_bytes = bits11.0.to_be_bytes();

// last 3 bits of first byte - others are always zero
for i in (0..BITS_IN_U11 % BITS_IN_BYTE).rev() {
let bit = (two_bytes[0] & (1 << i)) != 0;
self.bits.push(bit);
}

// all bits of second byte
self.extend_from_byte(two_bytes[1])
}
}

pub const BITS_IN_BYTE: usize = 8;
pub const BITS_IN_U11: usize = 11;

#[derive(Clone, Debug, ZeroizeOnDrop)]
pub struct WordSet {
pub bits11_set: Vec<Bits11>,
}

impl WordSet {
pub fn from_entropy(entropy: &[u8]) -> Result<Self, ErrorWordList> {
pub fn from_entropy(entropy: &[u8]) -> Result<Self, ErrorMnemonic> {
if entropy.len() < 16 || entropy.len() > 32 || entropy.len() % 4 != 0 {
return Err(ErrorWordList::InvalidEntropy);
return Err(ErrorMnemonic::InvalidEntropy);
}

let checksum_byte = sha256_first_byte(entropy);
let mut entropy_bits: BitVec<u8, Msb0> = BitVec::with_capacity((entropy.len() + 1) * 8);
entropy_bits.extend_from_bitslice(&BitVec::<u8, Msb0>::from_slice(entropy));
entropy_bits.extend_from_bitslice(&BitVec::<u8, Msb0>::from_element(checksum_byte));

let mut bits11_set: Vec<Bits11> = Vec::new();
for chunk in entropy_bits.chunks_exact(11usize) {
let mut entropy_bits = BitsHelper::with_capacity((entropy.len() + 1) * BITS_IN_BYTE);
for byte in entropy {
entropy_bits.extend_from_byte(*byte);
}
entropy_bits.extend_from_byte(checksum_byte);

let mut bits11_set: Vec<Bits11> = Vec::with_capacity(MAX_SEED_LEN);
for chunk in entropy_bits.bits.chunks_exact(BITS_IN_U11) {
let mut bits11: u16 = 0;
for (i, bit) in chunk.into_iter().enumerate() {
for (i, bit) in chunk.iter().rev().enumerate() {
if *bit {
bits11 |= 1 << (10 - i)
bits11 |= 1 << i
}
}
bits11_set.push(Bits11(bits11));
Expand All @@ -146,7 +189,7 @@ impl WordSet {
&mut self,
word: &str,
wordlist: &L,
) -> Result<(), ErrorWordList> {
) -> Result<(), ErrorMnemonic> {
if self.bits11_set.len() < MAX_SEED_LEN {
let bits11 = wordlist.bits11_for_word(word)?;
self.bits11_set.push(bits11);
Expand All @@ -158,19 +201,40 @@ impl WordSet {
MnemonicType::from(self.bits11_set.len()).is_ok()
}

pub fn to_entropy(&self) -> Result<Vec<u8>, ErrorWordList> {
pub fn to_entropy(&self) -> Result<Vec<u8>, ErrorMnemonic> {
let mnemonic_type = MnemonicType::from(self.bits11_set.len())?;

let mut entropy_bits: BitVec<u8, Msb0> = BitVec::with_capacity(mnemonic_type.total_bits());
let mut entropy_bits = BitsHelper::with_capacity(mnemonic_type.total_bits());

for bits11 in self.bits11_set.iter() {
entropy_bits.extend_from_bitslice(
&BitSlice::<u8, Msb0>::from_slice(&bits11.bits().to_be_bytes())[5..16],
)
entropy_bits.extend_from_bits11(bits11);
}

let mut entropy = entropy_bits.into_vec();
let entropy_len = mnemonic_type.entropy_bits() / 8;
let mut entropy: Vec<u8> = Vec::with_capacity(mnemonic_type.total_bits() / BITS_IN_BYTE);

let chunks_exact = entropy_bits.bits.chunks_exact(BITS_IN_BYTE);
let remainder = chunks_exact.remainder();

for chunk in chunks_exact {
let mut byte: u8 = 0;
for (i, bit) in chunk.iter().rev().enumerate() {
if *bit {
byte |= 1 << i
}
}
entropy.push(byte);
}

let mut last_byte: u8 = 0;
for (i, bit) in remainder.iter().rev().enumerate() {
if *bit {
last_byte |= 1 << (BITS_IN_BYTE - remainder.len() + i)
}
}

entropy.push(last_byte);

let entropy_len = mnemonic_type.entropy_bits().div_ceil(BITS_IN_BYTE);

let actual_checksum = checksum(entropy[entropy_len], mnemonic_type.checksum_bits());

Expand All @@ -181,15 +245,16 @@ impl WordSet {
let expected_checksum = checksum(checksum_byte, mnemonic_type.checksum_bits());

if actual_checksum != expected_checksum {
Err(ErrorWordList::InvalidChecksum)
Err(ErrorMnemonic::InvalidChecksum)
} else {
Ok(entropy)
}
}

pub fn to_phrase<L: AsWordList>(&self, wordlist: &L) -> Result<String, ErrorWordList> {
let mut phrase =
String::with_capacity(self.bits11_set.len() * (WORD_MAX_LEN + SEPARATOR_LEN) - 1);
pub fn to_phrase<L: AsWordList>(&self, wordlist: &L) -> Result<String, ErrorMnemonic> {
let mut phrase = String::with_capacity(
self.bits11_set.len() * (WORD_MAX_LEN + SEPARATOR_LEN) - SEPARATOR_LEN,
);
for bits11 in self.bits11_set.iter() {
if !phrase.is_empty() {
phrase.push(' ')
Expand All @@ -202,8 +267,8 @@ impl WordSet {
}

fn checksum(source: u8, bits: u8) -> u8 {
assert!(bits <= 8);
source >> (8 - bits)
assert!(bits <= BITS_IN_BYTE as u8);
source >> (BITS_IN_BYTE as u8 - bits)
}

fn sha256_first_byte(input: &[u8]) -> u8 {
Expand Down
10 changes: 5 additions & 5 deletions src/regular.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use alloc::vec::Vec;
#[cfg(feature = "std")]
use std::vec::Vec;

use crate::error::ErrorWordList;
use crate::error::ErrorMnemonic;
use crate::wordlist::WORDLIST_ENGLISH;
use crate::{AsWordList, Bits11, WordListElement};

Expand All @@ -13,15 +13,15 @@ pub struct InternalWordList;
impl AsWordList for InternalWordList {
type Word = &'static str;

fn get_word(&self, bits: Bits11) -> Result<Self::Word, ErrorWordList> {
fn get_word(&self, bits: Bits11) -> Result<Self::Word, ErrorMnemonic> {
let word_order = bits.bits() as usize;
Ok(WORDLIST_ENGLISH[word_order])
}

fn get_words_by_prefix(
&self,
prefix: &str,
) -> Result<Vec<WordListElement<Self>>, ErrorWordList> {
) -> Result<Vec<WordListElement<Self>>, ErrorMnemonic> {
let mut out: Vec<WordListElement<Self>> = Vec::new();
for (i, word) in WORDLIST_ENGLISH.iter().enumerate() {
if word.starts_with(prefix) {
Expand All @@ -34,12 +34,12 @@ impl AsWordList for InternalWordList {
Ok(out)
}

fn bits11_for_word(&self, word: &str) -> Result<Bits11, ErrorWordList> {
fn bits11_for_word(&self, word: &str) -> Result<Bits11, ErrorMnemonic> {
for (i, element) in WORDLIST_ENGLISH.iter().enumerate() {
if element == &word {
return Bits11::from(i as u16);
}
}
Err(ErrorWordList::NoWord)
Err(ErrorMnemonic::NoWord)
}
}
Loading
Loading