Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for SHA-256 RSA PSS signatures #9

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
228 changes: 227 additions & 1 deletion lib/src/rsa.nr
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,74 @@ use dep::bignum::BigNum;
use dep::bignum::runtime_bignum::BigNumInstance;
use dep::bignum::runtime_bignum::BigNumTrait;
use dep::bignum::runtime_bignum::BigNumInstanceTrait;
use dep::bignum::runtime_bignum::BigNumParamsTrait;
use crate::types::{RSA, BN1024, BN2048, BNInst1024, BNInst2048, RSA1024, RSA2048, BN1964, BNInst1964, RSA1964};

use crate::types::{RSA, BN1024, BN2048, BNInst1024, BNInst2048, RSA1024, RSA2048};
global HASH_LEN: u32 = 32;

fn reverse_array<let N: u32>(array: [u8; N]) -> [u8; N] {
let mut reversed = [0 as u8; N];
for i in 0..N {
reversed[i] = array[N - i - 1];
}
reversed
}

fn get_array_slice<let N: u32, let M: u32>(array: [u8; N], start: u32, end: u32) -> [u8; M] {
jfecher marked this conversation as resolved.
Show resolved Hide resolved
let mut slice = [0 as u8; M];
for i in 0..M {
if i < end - start {
slice[i] = array[start + i];
}
}
slice
}

/**
* @brief Generate a mask from a seed using the MGF1 algorithm with SHA256 as the hash function
**/
fn mgf1_sha256<let SEED_LEN: u32, let MASK_LEN: u32>(seed: [u8; SEED_LEN]) -> [u8; MASK_LEN] {
// MASK_LEN must be less than 2^32 * HASH_LEN
dep::std::field::bn254::assert_lt(MASK_LEN as Field, 0xffffffff * HASH_LEN as Field + 1);

// HASH_LEN bytes are added at each iteration and there is at least 1 iteration
// so if HASH_LEN is not enough to fill MASK_LEN bytes in one iteration,
// another one is required and so on.
let iterations = (MASK_LEN / HASH_LEN) + 1;

let mut mask: [u8; MASK_LEN] = [0; MASK_LEN];
let mut hashed: [u8; HASH_LEN] = [0; HASH_LEN];

for i in 0..iterations {
// Hopefully one day we can use the line below, but for now we'll go with a fixed value
// let mut block: [u8; SEED_LEN + 4] = [0; SEED_LEN + 4];
let mut block: [u8; 256] = [0; 256];

// Copy seed to block
for j in 0..SEED_LEN {
block[j] = seed[j];
}

// Add counter to block
let counter_bytes = (i as Field).to_be_bytes(4);
for j in 0..4 {
block[SEED_LEN + j] = counter_bytes[j];
}

// Hash the block
// First SEED_LEN bytes are the seed, next 4 bytes are the counter
hashed = dep::std::hash::sha256_var(block, SEED_LEN as u64 + 4);

// Copy hashed output to mask
for j in 0..HASH_LEN {
if i * HASH_LEN + j < MASK_LEN {
mask[i * HASH_LEN + j] = hashed[j];
}
}
}

mask
}

/**
* @brief Compare a recovered byte hash from an RSA signature to the original message hash
Expand Down Expand Up @@ -48,6 +114,113 @@ fn compare_signature_sha256<let N: u32>(padded_sha256_hash: [u8; N], msg_hash: [
}
impl<BN, BNInstance, let NumBytes: u32> RSA<BN, BNInstance, NumBytes> where BN: BigNumTrait, BNInstance: BigNumInstanceTrait<BN> {
/**
* @brief Verify an RSA signature generated via the PSS signature scheme.
* @details `key_size` is the size of the RSA modulus in bits and is required to correctly decode the signature.
*
* @note We assume the public key exponent `e` is 65537
**/
pub fn verify_sha256_pss(
_: Self,
jfecher marked this conversation as resolved.
Show resolved Hide resolved
instance: BNInstance,
msg_hash: [u8; 32],
sig: BN,
key_size: u32
) -> bool {
// Exponentiate the signature assuming e = 65537
let mut exponentiated = instance.mul(sig, sig);
exponentiated = instance.mul(exponentiated, exponentiated);
exponentiated = instance.mul(exponentiated, exponentiated);
exponentiated = instance.mul(exponentiated, exponentiated);
exponentiated = instance.mul(exponentiated, exponentiated);
exponentiated = instance.mul(exponentiated, exponentiated);
exponentiated = instance.mul(exponentiated, exponentiated);
exponentiated = instance.mul(exponentiated, exponentiated);
exponentiated = instance.mul(exponentiated, exponentiated);
exponentiated = instance.mul(exponentiated, exponentiated);
exponentiated = instance.mul(exponentiated, exponentiated);
exponentiated = instance.mul(exponentiated, exponentiated);
exponentiated = instance.mul(exponentiated, exponentiated);
exponentiated = instance.mul(exponentiated, exponentiated);
exponentiated = instance.mul(exponentiated, exponentiated);
exponentiated = instance.mul(exponentiated, exponentiated);
TomAFrench marked this conversation as resolved.
Show resolved Hide resolved
exponentiated = instance.mul(exponentiated, sig);
// Convert the exponentiated signature to a byte array and reverse it to
// get it in big endian order, which is much easier to work with for
// the rest of the verification process
let em:[u8; NumBytes] = reverse_array(exponentiated.to_le_bytes());

// The modulus size in bits minus 1
let em_bits = key_size - 1;
// The actual length of the encoded message without any of the leftmost 0s
let em_len = (em_bits + 7) / 8;
// The length of the modulus in bytes
let key_len = (key_size + 7) / 8;
let h_len = 32;
let s_len = 32;

// Check if emLen < hLen + sLen + 2
assert(em_len >= h_len + s_len + 2);

// Check if eM ends with 0xBC
assert(em[em.len() - 1] == 0xBC);
TomAFrench marked this conversation as resolved.
Show resolved Hide resolved

let db_mask_len = em_len - h_len - 1;
// This offset is necessary for key sizes not divisible by 8
// Will equal 0 for key sizes divisible by 8
let offset = em.len() - key_len;
// 256 - 32 - 1 = 223
// As hash is 32 bytes and we also remove the 0xBC at the end, we have 223 bytes left for DB
let masked_db: [u8; 223] = get_array_slice(em, offset, db_mask_len + offset);
let h = get_array_slice(em, db_mask_len + offset, em.len() - 1);

// Mask the leftmost bits
let bits_to_mask = 8 * em_len - em_bits;
let mask = 0xFF >> bits_to_mask as u8;
let first_byte = masked_db[offset];
// Make sure the 8 * em_len - em_bits leftmost bits are 0
// c.f. https://github.com/RustCrypto/RSA/blob/aeedb5adf5297892fcb9e11f7c0f6c0157005c58/src/algorithms/pss.rs#L205
assert((first_byte & !mask) == 0);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be done more cheaply by just performing a division rather than bitwise operations.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, I switched to using powers of 2 now


// Generate dbMask using MGF1
let db_mask:[u8; 223] = mgf1_sha256(h);

// Compute DB = maskedDB xor dbMask
let mut db = [0 as u8; 223];
for i in 0..db_mask_len {
db[i] = masked_db[i] ^ db_mask[i];
}

// Set leftmost byte of DB to 0
db[0] = 0;

// Check if the leftmost octets of DB are zero
for i in 0..(em_len - h_len - s_len - 2) {
assert(db[i] == 0);
}

// Check if the octet at position emLen - hLen - sLen - 2 is 0x01
assert(db[em_len - h_len - s_len - 2] == 0x01);

// Extract salt
let salt: [u8; 32] = get_array_slice(db, db_mask_len - s_len, db_mask_len);

// Construct M'
// M' = (0x)00 00 00 00 00 00 00 00 || msg_hash || salt
let mut m_prime = [0 as u8; 72]; // 8 + h_len + s_len
for i in 8..40 {
m_prime[i] = msg_hash[i - 8];
}
for i in 40..72 {
m_prime[i] = salt[i - 40];
}

// Compute H'
let h_prime = dep::std::hash::sha256(m_prime);

// Compare H and H'
h == h_prime
}
/**
* @brief Verify an RSA signature generated via the pkcs1v15 signature scheme.
* @details The fourth function parameter is required to define the value of `NBytes`
* when converting a BigNum into a byte array, the number of bytes is required and currently cannot be inferred.
Expand Down Expand Up @@ -124,3 +297,56 @@ fn test_verify_sha256_pkcs1v15_2048() {
let rsa: RSA2048 = RSA {};
assert(rsa.verify_sha256_pkcs1v15(BNInstance, sha256_hash, signature));
}

#[test]
fn test_mgf1_sha256() {
let seed: [u8; 32] = dep::std::hash::sha256("Hello World! This is Noir-RSA".as_bytes());
let expected_mask: [u8; 32] = [
106, 93, 232, 46, 236, 203, 51, 228, 103, 104, 145, 29, 197, 74, 26, 194, 135, 200, 40, 232, 179, 172, 220, 135, 51, 185, 209, 35, 194, 131, 176, 190
];
let mask: [u8; 32] = mgf1_sha256(seed);
assert(mask == expected_mask);
}

#[test]
fn test_verify_sha256_pss_2048() {
let sha256_hash: [u8; 32] = dep::std::hash::sha256("Hello World! This is Noir-RSA".as_bytes());
let BNInstance: BNInst2048 = BigNumInstance::new(
[
0x9c149f9aa49db0ee279be7b1b9f2b1, 0x472c42ef0019a179b36a37313423d5, 0x6796724132fb7b7b14db1078126604, 0xafbc95dc121ab99608b8f4145536b7, 0xef72836154b7feaaa7b4a79e07a3fa, 0x44ef7a3c294a863a2b8820c17a5ce6, 0x7bbb5c1d1160b1136cf310155d04d3, 0x81e11b000f6d1a72cdd0c57225a29c, 0x9d02c16e2643b7d38e082bdde79b0a, 0xd8d1905fde1d7bafc516ef0a544793, 0x41aa69097fc68ca3efb4a20e953f45, 0xc1580306d50a17bdeef26eeace7d9e, 0x69ecdf56a571e2349c615851b28f62, 0x79d7c408c01d0d601a318c3c3fb840, 0x94c5ac3935445baf4de56e50c50b30, 0x4d5c5d278e3a0f7f0617908b86ef71, 0x2ad08f95d923c8c04d9d93d072e13d, 0xaf
],
[
0xabdcccb22151f142cbaa207bf6cc89, 0xeaad15b6065fcff60108c268554cfd, 0xc508f4e0a4559a78166f1739577c8e, 0xb140e6600daaeb1038e38081ed338c, 0x72e332e60e3c44cdedbe157b539714, 0xf60bff05409ef357f90d0087def9b3, 0x9c82957c546f138ee9c7edd6e8fe18, 0xe04e36bd481ece291a557975dab90, 0xfd7c8ea97e11164495d881394d304b, 0xfa5a096283f960b1256700d772e9, 0x61c0057ea007659e58ac8754fec77c, 0x57bc9dd907a5065cd6ba3d6bd6b0e9, 0x517515a3fd6408b5e5766be7a692c5, 0x2b817d4eb0fe5afd47900c0e56f93f, 0xf604e52448f245c64e0e66198018e8, 0x8db41fd0ec788dec05113586c205ea, 0x224597fab441cd104694d7d61e0a57, 0x176
]
);
let signature: BN2048 = BigNum::from_array(
[
0x17b138e0afc2893ae13d2eea1ce349, 0x5d18d222dfc4af50aa6eb31eeb8403, 0x29d2a6d0fc929addd1785228f0f350, 0x162ff4f4bfc98867ff1e9ae7e4dfa2, 0xd38c9fe831a9acb84eb5754fde653f, 0xd9629bcc05a5ba2fd4edbd257c8b22, 0xc130ab593612459d247b0a1430515f, 0x20f1b3ca2724f0f811a8c793130196, 0x507cc26d77029853a1ef7a46c7f3c1, 0xb335060a4a9f58b25173d84e50b600, 0x5836975b3491cfe85a98204fca00f2, 0x17f96851d11d36ab2b93adb7aa3c07, 0x8e36e106516fded0e68a2a1be0d6b6, 0x697f2906a826b9dee8f3f497596db6, 0x38fdb5f7ca447d7b55e3f4ab592050, 0x2511143f27e0ce03e03100532b68a5, 0x01db0955c41de01a6c2e4bb902b621, 0x3d
]
);

let rsa: RSA2048 = RSA {};
assert(rsa.verify_sha256_pss(BNInstance, sha256_hash, signature, 2048));
}

// The following test is using an unusual key size to test an edge case for key size not divisible by 8
#[test]
fn test_verify_sha256_pss_1964() {
let sha256_hash: [u8; 32] = dep::std::hash::sha256("Hello World! This is Noir-RSA".as_bytes());
let BNInstance: BNInst1964 = BigNumInstance::new(
[
0x2a035f94929bb130deaf854028f433, 0xbcfe17113a631a59158c3a81b85d4d, 0x136b271c541b8dbad5672777c48d9a, 0xa7f9a26a6fcebd3299c1604c501818, 0x8eb621862ba1bb8c432bc64c21e0a0, 0x1e0be8f31f728998e1e8783d06271d, 0x70f2a1450579e00f86bb39abe44cd8, 0xf26823e29557176366145bad958caa, 0x5b11eb8d6d8d5f0afb047b8826747a, 0xbfd7ee9d32a932bf274c75b63c8f40, 0x26bf6dc058ec401fd5f34c7ec1755e, 0xa04110ce6ae975c6f9075dc8917565, 0x24b194b23c695ba8eea35a6f336c2a, 0x153f3fec0be28f1f1636069149ba47, 0xca764dc95c1f96a325f9d51254a790, 0x67ea496073a3adcb0d093afec3c2a2, 0xcdd00d115e5
],
[
0xeb858462288906c78264566f2517bf, 0x307764ae5c67f7b179be6178742f97, 0x2eadd630b79ccfbea22e68e50701b9, 0x2683b85d13dffa94d1c986870489f1, 0x6d9f9788a78c52cab1918b4070842b, 0xd93261e5f34c5508a47183b62fb3b4, 0xd4b3a75c20554347708df7583ff81f, 0x63d7878156e7ea6a62f8c3f7e6f15f, 0x6cc531bba9e02310e414a3aa4e1a06, 0x19b50bf06c4444b2a788ec44c41c91, 0xd6bad67102a89597684b9d53f5370f, 0x726cc1262518f1b59aded3184d3ae4, 0xd018338b704c0a2535dee1fe70ce60, 0x9ec435e1f1652d2dde4e56cce77901, 0x41460a18f3fed1112e41afc93a2464, 0xc19e91fbbab21f155e4b71dbe9a7dc, 0x13e6ce9d6a14
]
);
let signature: BN1964 = BigNum::from_array(
[
0x29787913470259aee2641be22694c6, 0xec09bc1e2d77c6585d183298e7821f, 0x44b3474a5a85c68fd5b1c0d58b8a14, 0xb4305577f64847c7365e4c50bd3ce6, 0x6fa468f111116d3b240bc26a6f7aa1, 0xfa4e31aea797f4e13d832798e84698, 0x318e9578a49fa2cfea99f604597688, 0xc8be4cf64d62c6342b7ee14be70cb5, 0xc7a942ab90fbff9e8791d6a24ae841, 0xd75dc97acf08b6c3d6fbe52ed2d025, 0x06708e15a6d7649eac7c37ab06c9d6, 0xc5602556b8d74f4e98983b65eb698b, 0xd1ce86b68f8737ef86c35344f4d55e, 0xff6c2b0f9e8a62758c36bdcfe31ce9, 0xbadab8d8f0a51ff4d8d94600eb7c58, 0xcefdcd5b6b58a024020ee703e313b3, 0x0a1a27008f49
]
);

let rsa: RSA1964 = RSA {};
assert(rsa.verify_sha256_pss(BNInstance, sha256_hash, signature, 1964));
}
11 changes: 11 additions & 0 deletions lib/src/types.nr
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,24 @@ use dep::bignum::BigNum;
use dep::bignum::runtime_bignum::BigNumInstance;
use dep::bignum::fields::Params2048;
use dep::bignum::fields::Params1024;
use dep::bignum::runtime_bignum::BigNumParamsTrait;

struct RSA<BN, BNInstance, let NumBytes: u32>{}

struct Params1964 {}
impl BigNumParamsTrait<17> for Params1964 {
fn modulus_bits() -> u32 {
1964
}
}

type BN1024 = BigNum<9, Params1024>;
type BN2048 = BigNum<18, Params2048>;
type BN1964 = BigNum<17, Params1964>;
type BNInst1024 = BigNumInstance<9, Params1024>;
type BNInst2048 = BigNumInstance<18, Params2048>;
type BNInst1964 = BigNumInstance<17, Params1964>;

type RSA1024 = RSA<BN1024, BNInst1024, 128>;
type RSA2048 = RSA<BN2048, BNInst2048, 256>;
type RSA1964 = RSA<BN1964, BNInst1964, 255>;
29 changes: 21 additions & 8 deletions signature_gen/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use num_bigint::BigUint;
use rsa::pkcs1v15::Signature;
use rsa::pkcs1v15::VerifyingKey;
use rsa::{RsaPrivateKey, RsaPublicKey};
use signature::Keypair;
use signature::RandomizedSignerMut;
use std::env;
use toml::Value;

Expand Down Expand Up @@ -31,7 +33,7 @@ fn format_limbs_as_toml_value(limbs: &Vec<BigUint>) -> Vec<Value> {
.collect()
}

fn generate_2048_bit_signature_parameters(msg: &str, as_toml: bool) {
fn generate_2048_bit_signature_parameters(msg: &str, as_toml: bool, pss: bool) {
let mut hasher = Sha256::new();
hasher.update(msg.as_bytes());
let hashed_message = hasher.finalize();
Expand All @@ -48,12 +50,16 @@ fn generate_2048_bit_signature_parameters(msg: &str, as_toml: bool) {
RsaPrivateKey::new(&mut rng, bits).expect("failed to generate a key");
let pub_key: RsaPublicKey = priv_key.clone().into();

let signing_key = rsa::pkcs1v15::SigningKey::<Sha256>::new(priv_key);
let sig: Vec<u8> = signing_key.sign(msg.as_bytes()).to_vec();

let sig_bytes = &Signature::try_from(sig.as_slice()).unwrap().to_bytes();
let sig_bytes = if pss {
let mut signing_key = rsa::pss::BlindedSigningKey::<Sha256>::new(priv_key);
let sig = signing_key.sign_with_rng(&mut rng, msg.as_bytes());
sig.to_vec()
} else {
let signing_key = rsa::pkcs1v15::SigningKey::<Sha256>::new(priv_key);
signing_key.sign(msg.as_bytes()).to_vec()
};

let sig_uint: BigUint = BigUint::from_bytes_be(sig_bytes);
let sig_uint: BigUint = BigUint::from_bytes_be(&sig_bytes);

let sig_str = bn_limbs(sig_uint.clone(), 2048);

Expand Down Expand Up @@ -108,12 +114,19 @@ fn main() {
.long("toml")
.help("Print output in TOML format"),
)
.arg(
Arg::with_name("pss")
.short("p")
.long("pss")
.help("Use RSA PSS"),
)
.get_matches();

let msg = matches.value_of("msg").unwrap();
let as_toml = matches.is_present("toml");

generate_2048_bit_signature_parameters(msg, as_toml);
let pss = matches.is_present("pss");

generate_2048_bit_signature_parameters(msg, as_toml, pss);
}

fn test_signature_generation_impl() {
Expand Down
Loading