diff --git a/Cargo.lock b/Cargo.lock index e133f3bd64..1facda817a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1968,6 +1968,7 @@ dependencies = [ "prost-types", "rand", "serde", + "serde_json", "sha2", "tokio", "tokio-stream", @@ -2803,9 +2804,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.127" +version = "1.0.128" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8043c06d9f82bd7271361ed64f415fe5e12a77fdb52e573e7f06a516dea329ad" +checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8" dependencies = [ "indexmap 2.5.0", "itoa", diff --git a/nativelink-store/src/filesystem_store.rs b/nativelink-store/src/filesystem_store.rs index 101c4a78bd..9aedba07d8 100644 --- a/nativelink-store/src/filesystem_store.rs +++ b/nativelink-store/src/filesystem_store.rs @@ -307,7 +307,7 @@ fn make_temp_digest(digest: &mut DigestInfo) { .fetch_add(1, Ordering::Relaxed) .to_le_bytes(), ); - digest.set_packed_hash(hash); + digest.set_packed_hash(*hash); } impl LenEntry for FileEntryImpl { @@ -595,13 +595,10 @@ impl FilesystemStore { } pub async fn get_file_entry_for_digest(&self, digest: &DigestInfo) -> Result, Error> { - self.evicting_map.get(digest).await.ok_or_else(|| { - make_err!( - Code::NotFound, - "{} not found in filesystem store", - digest.hash_str() - ) - }) + self.evicting_map + .get(digest) + .await + .ok_or_else(|| make_err!(Code::NotFound, "{digest} not found in filesystem store")) } async fn update_file<'a>( @@ -852,13 +849,10 @@ impl StoreDriver for FilesystemStore { return Ok(()); } - let entry = self.evicting_map.get(&digest).await.ok_or_else(|| { - make_err!( - Code::NotFound, - "{} not found in filesystem store", - digest.hash_str() - ) - })?; + let entry = + self.evicting_map.get(&digest).await.ok_or_else(|| { + make_err!(Code::NotFound, "{digest} not found in filesystem store") + })?; let read_limit = length.unwrap_or(usize::MAX) as u64; let mut resumeable_temp_file = entry.read_file_part(offset as u64, read_limit).await?; diff --git a/nativelink-store/src/grpc_store.rs b/nativelink-store/src/grpc_store.rs index 8366d7638d..cbbe1a0795 100644 --- a/nativelink-store/src/grpc_store.rs +++ b/nativelink-store/src/grpc_store.rs @@ -588,7 +588,7 @@ impl StoreDriver for GrpcStore { "{}/uploads/{}/blobs/{}/{}", &self.instance_name, Uuid::new_v4().hyphenated().encode_lower(&mut buf), - digest.hash_str(), + digest.packed_hash(), digest.size_bytes(), ); @@ -673,7 +673,7 @@ impl StoreDriver for GrpcStore { let resource_name = format!( "{}/blobs/{}/{}", &self.instance_name, - digest.hash_str(), + digest.packed_hash(), digest.size_bytes(), ); diff --git a/nativelink-store/src/verify_store.rs b/nativelink-store/src/verify_store.rs index 8d9c87358c..9ee668b939 100644 --- a/nativelink-store/src/verify_store.rs +++ b/nativelink-store/src/verify_store.rs @@ -21,6 +21,7 @@ use nativelink_metric::MetricsComponent; use nativelink_util::buf_channel::{ make_buf_channel_pair, DropCloserReadHalf, DropCloserWriteHalf, }; +use nativelink_util::common::PackedHash; use nativelink_util::digest_hasher::{ default_digest_hasher_func, DigestHasher, ACTIVE_HASHER_FUNC, }; @@ -61,7 +62,7 @@ impl VerifyStore { mut tx: DropCloserWriteHalf, mut rx: DropCloserReadHalf, maybe_expected_digest_size: Option, - original_hash: &[u8; 32], + original_hash: &PackedHash, mut maybe_hasher: Option<&mut D>, ) -> Result<(), Error> { let mut sum_size: u64 = 0; @@ -119,9 +120,7 @@ impl VerifyStore { if original_hash != hash_result { self.hash_verification_failures.inc(); return Err(make_input_err!( - "Hashes do not match, got: {} but digest hash was {}", - hex::encode(original_hash), - hex::encode(hash_result), + "Hashes do not match, got: {original_hash} but digest hash was {hash_result}", )); } } diff --git a/nativelink-util/BUILD.bazel b/nativelink-util/BUILD.bazel index b5598b8724..0a3ec302ca 100644 --- a/nativelink-util/BUILD.bazel +++ b/nativelink-util/BUILD.bazel @@ -79,6 +79,7 @@ rust_test_suite( srcs = [ "tests/buf_channel_test.rs", "tests/channel_body_for_tests_test.rs", + "tests/common_test.rs", "tests/evicting_map_test.rs", "tests/fastcdc_test.rs", "tests/fs_test.rs", @@ -109,6 +110,7 @@ rust_test_suite( "@crates//:parking_lot", "@crates//:pretty_assertions", "@crates//:rand", + "@crates//:serde_json", "@crates//:sha2", "@crates//:tokio", "@crates//:tokio-stream", diff --git a/nativelink-util/Cargo.toml b/nativelink-util/Cargo.toml index d77113e001..b7887d03ec 100644 --- a/nativelink-util/Cargo.toml +++ b/nativelink-util/Cargo.toml @@ -23,7 +23,7 @@ pin-project = "1.1.5" # Release PR: https://github.com/tokio-rs/console/pull/576 console-subscriber = { git = "https://github.com/tokio-rs/console", rev = "5f6faa2" , default-features = false } futures = { version = "0.3.30", default-features = false } -hex = { version = "0.4.3", default-features = false } +hex = { version = "0.4.3", default-features = false, features = ["std"] } hyper = "1.4.1" hyper-util = "0.1.6" lru = { version = "0.12.3", default-features = false } @@ -49,3 +49,4 @@ nativelink-macro = { path = "../nativelink-macro" } http-body-util = "0.1.2" pretty_assertions = { version = "1.4.0", features = ["std"] } rand = { version = "0.8.5", default-features = false } +serde_json = { version = "1.0.128", default-features = false } diff --git a/nativelink-util/src/common.rs b/nativelink-util/src/common.rs index 272d8b5403..1556da0c89 100644 --- a/nativelink-util/src/common.rs +++ b/nativelink-util/src/common.rs @@ -15,8 +15,9 @@ use std::cmp::Ordering; use std::collections::HashMap; use std::fmt; -use std::fmt::Display; use std::hash::Hash; +use std::io::{Cursor, Write}; +use std::ops::{Deref, DerefMut}; use bytes::{BufMut, Bytes, BytesMut}; use nativelink_error::{make_input_err, Error, ResultExt}; @@ -25,12 +26,14 @@ use nativelink_metric::{ }; use nativelink_proto::build::bazel::remote::execution::v2::Digest; use prost::Message; -use serde::{Deserialize, Serialize}; +use serde::de::Visitor; +use serde::ser::Error as _; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; use tracing::{event, Level}; pub use crate::fs; -#[derive(Serialize, Deserialize, Default, Clone, Copy, Eq, PartialEq, Hash)] +#[derive(Default, Clone, Copy, Eq, PartialEq, Hash)] #[repr(C)] pub struct DigestInfo { /// Raw hash in packed form. @@ -67,16 +70,21 @@ impl DigestInfo { let size_bytes = size_bytes .try_into() .map_err(|_| make_input_err!("Could not convert {} into u64", size_bytes))?; + // The proto `Digest` takes an i64, so to keep compatibility + // we only allow sizes that can fit into an i64. + if size_bytes > i64::MAX as u64 { + return Err(make_input_err!( + "Size bytes is too large: {} - max: {}", + size_bytes, + i64::MAX + )); + } Ok(DigestInfo { size_bytes, packed_hash, }) } - pub fn hash_str(&self) -> String { - format!("{}", self.packed_hash) - } - pub const fn zero_digest() -> DigestInfo { DigestInfo { size_bytes: 0, @@ -84,8 +92,8 @@ impl DigestInfo { } } - pub const fn packed_hash(&self) -> &[u8; 32] { - &self.packed_hash.0 + pub const fn packed_hash(&self) -> &PackedHash { + &self.packed_hash } pub fn set_packed_hash(&mut self, packed_hash: [u8; 32]) { @@ -95,19 +103,157 @@ impl DigestInfo { pub const fn size_bytes(&self) -> u64 { self.size_bytes } + + /// Returns a struct that can turn the `DigestInfo` into a string. + const fn stringifier(&self) -> DigestStackStringifier<'_> { + DigestStackStringifier::new(self) + } +} + +/// Counts the number of digits a number needs if it were to be +/// converted to a string. +const fn count_digits(mut num: u64) -> usize { + let mut count = 0; + while num != 0 { + count += 1; + num /= 10; + } + count +} + +/// An optimized version of a function that can convert a `DigestInfo` +/// into a str on the stack. +struct DigestStackStringifier<'a> { + digest: &'a DigestInfo, + /// Buffer that can hold the string representation of the `DigestInfo`. + /// - Hex is '2 * sizeof(PackedHash)'. + /// - Digits can be at most `count_digits(u64::MAX)`. + /// - We also have a hyphen separator. + buf: [u8; std::mem::size_of::() * 2 + count_digits(u64::MAX) + 1], } -impl Display for DigestInfo { +impl<'a> DigestStackStringifier<'a> { + const fn new(digest: &'a DigestInfo) -> Self { + DigestStackStringifier { + digest, + buf: [0u8; std::mem::size_of::() * 2 + count_digits(u64::MAX) + 1], + } + } + + fn as_str(&mut self) -> Result<&str, Error> { + // Populate the buffer and return the amount of bytes written + // to the buffer. + let len = { + let mut cursor = Cursor::new(&mut self.buf[..]); + cursor + .write_fmt(format_args!( + "{}-{}", + self.digest.packed_hash(), + self.digest.size_bytes() + )) + .err_tip(|| { + format!( + "Could not serialize DigestInfo into string - {}", + self.digest, + ) + })?; + cursor.position() as usize + }; + // Convert the buffer into utf8 string. + std::str::from_utf8(&self.buf[..len]).map_err(|e| { + make_input_err!( + "Could not convert [u8] to string - {} - {:?} - {:?}", + self.digest, + self.buf, + e, + ) + }) + } +} + +/// Custom serializer for `DigestInfo` because the default Serializer +/// would try to encode the data as a byte array, but we use {hex}-{size}. +impl Serialize for DigestInfo { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut stringifier = self.stringifier(); + serializer.serialize_str( + stringifier + .as_str() + .err_tip(|| "During serialization of DigestInfo") + .map_err(S::Error::custom)?, + ) + } +} + +/// Custom deserializer for `DigestInfo` becaues the default Deserializer +/// would try to decode the data as a byte array, but we use {hex}-{size}. +impl<'de> Deserialize<'de> for DigestInfo { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct DigestInfoVisitor; + impl<'a> Visitor<'a> for DigestInfoVisitor { + type Value = DigestInfo; + + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + formatter.write_str("a string representing a DigestInfo") + } + + fn visit_str(self, s: &str) -> Result + where + E: serde::de::Error, + { + let Some((hash, size)) = s.split_once('-') else { + return Err(E::custom( + "Invalid DigestInfo format, expected '-' separator", + )); + }; + let size_bytes = size + .parse::() + .map_err(|e| E::custom(format!("Could not parse size_bytes: {e:?}")))?; + DigestInfo::try_new(hash, size_bytes) + .map_err(|e| E::custom(format!("Could not create DigestInfo: {e:?}"))) + } + } + deserializer.deserialize_str(DigestInfoVisitor) + } +} + +impl fmt::Display for DigestInfo { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}-{}", self.packed_hash, self.size_bytes) + let mut stringifier = self.stringifier(); + f.write_str( + stringifier + .as_str() + .err_tip(|| "During serialization of DigestInfo") + .map_err(|e| { + event!( + Level::ERROR, + "Could not convert DigestInfo to string - {e:?}" + ); + fmt::Error + })?, + ) } } impl fmt::Debug for DigestInfo { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_tuple("DigestInfo") - .field(&format!("{}-{}", self.packed_hash, self.size_bytes)) - .finish() + let mut stringifier = self.stringifier(); + match stringifier.as_str() { + Ok(s) => f.debug_tuple("DigestInfo").field(&s).finish(), + Err(e) => { + event!( + Level::ERROR, + "Could not convert DigestInfo to string - {e:?}" + ); + Err(fmt::Error) + } + } } } @@ -162,7 +308,7 @@ impl TryFrom<&Digest> for DigestInfo { impl From for Digest { fn from(val: DigestInfo) -> Self { Digest { - hash: val.hash_str(), + hash: val.packed_hash.to_string(), size_bytes: val.size_bytes.try_into().unwrap_or_else(|e| { event!( Level::ERROR, @@ -180,7 +326,7 @@ impl From for Digest { impl From<&DigestInfo> for Digest { fn from(val: &DigestInfo) -> Self { Digest { - hash: val.hash_str(), + hash: val.packed_hash.to_string(), size_bytes: val.size_bytes.try_into().unwrap_or_else(|e| { event!( Level::ERROR, @@ -196,10 +342,10 @@ impl From<&DigestInfo> for Digest { } #[derive(Serialize, Deserialize, Default, Clone, Copy, Eq, PartialEq, Hash, PartialOrd, Ord)] -struct PackedHash([u8; 32]); +pub struct PackedHash([u8; 32]); impl PackedHash { - pub const fn new() -> Self { + const fn new() -> Self { PackedHash([0; 32]) } @@ -225,9 +371,17 @@ impl fmt::Display for PackedHash { } } -impl fmt::Debug for PackedHash { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_fmt(format_args!("{self}")) +impl Deref for PackedHash { + type Target = [u8; 32]; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for PackedHash { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 } } diff --git a/nativelink-util/tests/common_test.rs b/nativelink-util/tests/common_test.rs new file mode 100644 index 0000000000..d102bd03d7 --- /dev/null +++ b/nativelink-util/tests/common_test.rs @@ -0,0 +1,121 @@ +// Copyright 2024 The NativeLink Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use nativelink_error::{make_input_err, Error}; +use nativelink_macro::nativelink_test; +use nativelink_util::common::DigestInfo; +use pretty_assertions::assert_eq; + +const MIN_DIGEST: &str = "0000000000000000000000000000000000000000000000000000000000000000-0"; +const MAX_SAFE_DIGEST: &str = + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff-9223372036854775807"; +const MAX_UNSAFE_DIGEST: &str = + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff-18446744073709551615"; + +#[nativelink_test] +async fn digest_info_min_max_test() -> Result<(), Error> { + { + const DIGEST: DigestInfo = DigestInfo::new([0u8; 32], u64::MIN); + assert_eq!(&format!("{DIGEST}"), MIN_DIGEST); + } + { + const DIGEST: DigestInfo = DigestInfo::new([255u8; 32], u64::MAX); + assert_eq!(&format!("{DIGEST}"), MAX_UNSAFE_DIGEST); + } + Ok(()) +} + +#[nativelink_test] +async fn digest_info_try_new_min_max_test() -> Result<(), Error> { + { + let digest_parts: (&str, u64) = MIN_DIGEST + .split_once('-') + .map(|(s, n)| (s, n.parse().unwrap())) + .unwrap(); + let digest = DigestInfo::try_new(digest_parts.0, digest_parts.1).unwrap(); + assert_eq!(&format!("{digest}"), MIN_DIGEST); + } + { + let digest_parts: (&str, u64) = MAX_SAFE_DIGEST + .split_once('-') + .map(|(s, n)| (s, n.parse().unwrap())) + .unwrap(); + let digest = DigestInfo::try_new(digest_parts.0, digest_parts.1).unwrap(); + assert_eq!(&format!("{digest}"), MAX_SAFE_DIGEST); + } + { + let digest_parts: (&str, u64) = MAX_UNSAFE_DIGEST + .split_once('-') + .map(|(s, n)| (s, n.parse().unwrap())) + .unwrap(); + let digest_res = DigestInfo::try_new(digest_parts.0, digest_parts.1); + assert_eq!( + digest_res, + Err(make_input_err!( + "Size bytes is too large: 18446744073709551615 - max: 9223372036854775807" + )) + ); + } + Ok(()) +} + +#[nativelink_test] +async fn digest_info_serialize_test() -> Result<(), Error> { + { + const DIGEST: DigestInfo = DigestInfo::new([0u8; 32], u64::MIN); + assert_eq!( + serde_json::to_string(&DIGEST).unwrap(), + format!("\"{MIN_DIGEST}\"") + ); + } + { + const DIGEST: DigestInfo = DigestInfo::new([255u8; 32], i64::MAX as u64); + assert_eq!( + serde_json::to_string(&DIGEST).unwrap(), + format!("\"{MAX_SAFE_DIGEST}\"") + ); + } + { + const DIGEST: DigestInfo = DigestInfo::new([255u8; 32], u64::MAX); + assert_eq!( + serde_json::to_string(&DIGEST).unwrap(), + format!("\"{MAX_UNSAFE_DIGEST}\"") + ); + } + Ok(()) +} + +#[nativelink_test] +async fn digest_info_deserialize_test() -> Result<(), Error> { + { + assert_eq!( + serde_json::from_str::(&format!("\"{MIN_DIGEST}\"")).unwrap(), + DigestInfo::new([0u8; 32], u64::MIN) + ); + } + { + assert_eq!( + serde_json::from_str::(&format!("\"{MAX_SAFE_DIGEST}\"")).unwrap(), + DigestInfo::new([255u8; 32], i64::MAX as u64) + ); + } + { + let digest_res = serde_json::from_str::(&format!("\"{MAX_UNSAFE_DIGEST}\"")); + assert_eq!( + format!("{}", digest_res.err().unwrap()), + "Could not create DigestInfo: Error { code: InvalidArgument, messages: [\"Size bytes is too large: 18446744073709551615 - max: 9223372036854775807\"] } at line 1 column 87", + ); + } + Ok(()) +} diff --git a/nativelink-util/tests/proto_stream_utils_test.rs b/nativelink-util/tests/proto_stream_utils_test.rs index b1e19104ec..2d20d7e062 100644 --- a/nativelink-util/tests/proto_stream_utils_test.rs +++ b/nativelink-util/tests/proto_stream_utils_test.rs @@ -39,7 +39,7 @@ async fn ensure_no_errors_if_only_first_message_has_resource_name_set() -> Resul let message1 = WriteRequest { resource_name: format!( "{INSTANCE_NAME}/uploads/some-uuid/blobs/{}/{}", - DIGEST.hash_str(), + DIGEST.packed_hash(), DIGEST.size_bytes() ), write_offset: 0, diff --git a/nativelink-worker/src/running_actions_manager.rs b/nativelink-worker/src/running_actions_manager.rs index d9393f8851..570ea22fe0 100644 --- a/nativelink-worker/src/running_actions_manager.rs +++ b/nativelink-worker/src/running_actions_manager.rs @@ -1472,10 +1472,16 @@ impl UploadActionResults { "digest_function", hasher.proto_digest_func().as_str_name().to_lowercase(), ); - template_str.replace("action_digest_hash", action_digest_info.hash_str()); + template_str.replace( + "action_digest_hash", + action_digest_info.packed_hash().to_string(), + ); template_str.replace("action_digest_size", action_digest_info.size_bytes()); if let Some(historical_digest_info) = maybe_historical_digest_info { - template_str.replace("historical_results_hash", historical_digest_info.hash_str()); + template_str.replace( + "historical_results_hash", + format!("{}", historical_digest_info.packed_hash()), + ); template_str.replace( "historical_results_size", historical_digest_info.size_bytes(),