diff --git a/nativelink-config/Cargo.toml b/nativelink-config/Cargo.toml index 94428da49..c4881854b 100644 --- a/nativelink-config/Cargo.toml +++ b/nativelink-config/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" [dependencies] byte-unit = { version = "5.1.4", default-features = false, features = ["byte"] } humantime = "2.1.0" -serde = { version = "1.0.210", default-features = false } +serde = { version = "1.0.210", default-features = false, features = ["derive"] } serde_json5 = "0.1.0" shellexpand = { version = "3.1.0", default-features = false, features = ["base-0"] } diff --git a/nativelink-config/src/serde_utils.rs b/nativelink-config/src/serde_utils.rs index 99299be90..330e65846 100644 --- a/nativelink-config/src/serde_utils.rs +++ b/nativelink-config/src/serde_utils.rs @@ -15,98 +15,124 @@ use std::borrow::Cow; use std::fmt; use std::marker::PhantomData; -use std::str::FromStr; use byte_unit::Byte; use humantime::parse_duration; +use serde::de::Visitor; use serde::{de, Deserialize, Deserializer}; /// Helper for serde macro so you can use shellexpand variables in the json configuration /// files when the number is a numeric type. -pub fn convert_numeric_with_shellexpand<'de, D, T, E>(deserializer: D) -> Result +pub fn convert_numeric_with_shellexpand<'de, D, T>(deserializer: D) -> Result where D: Deserializer<'de>, - E: fmt::Display, - T: TryFrom + FromStr, + T: TryFrom, >::Error: fmt::Display, { - // define a visitor that deserializes - // `ActualData` encoded as json within a string - struct USizeVisitor>(PhantomData); + struct NumericVisitor>(PhantomData); - impl<'de, T, FromStrErr> de::Visitor<'de> for USizeVisitor + impl<'de, T> Visitor<'de> for NumericVisitor where - FromStrErr: fmt::Display, - T: TryFrom + FromStr, + T: TryFrom, >::Error: fmt::Display, { type Value = T; fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("a string containing json data") + formatter.write_str("an integer or a plain number string") } fn visit_i64(self, v: i64) -> Result { - v.try_into().map_err(de::Error::custom) + T::try_from(v).map_err(de::Error::custom) + } + + fn visit_u64(self, v: u64) -> Result { + let v_i64 = i64::try_from(v).map_err(de::Error::custom)?; + T::try_from(v_i64).map_err(de::Error::custom) } fn visit_str(self, v: &str) -> Result { - (*shellexpand::env(v).map_err(de::Error::custom)?) - .parse::() - .map_err(de::Error::custom) + let expanded = shellexpand::env(v).map_err(de::Error::custom)?; + let s = expanded.as_ref().trim(); + let parsed = s.parse::().map_err(de::Error::custom)?; + T::try_from(parsed).map_err(de::Error::custom) } } - deserializer.deserialize_any(USizeVisitor::(PhantomData:: {})) + deserializer.deserialize_any(NumericVisitor::(PhantomData)) } /// Same as convert_numeric_with_shellexpand, but supports `Option`. -pub fn convert_optional_numeric_with_shellexpand<'de, D, T, E>( +pub fn convert_optional_numeric_with_shellexpand<'de, D, T>( deserializer: D, ) -> Result, D::Error> where D: Deserializer<'de>, - E: fmt::Display, - T: TryFrom + FromStr, + T: TryFrom, >::Error: fmt::Display, { - // define a visitor that deserializes - // `ActualData` encoded as json within a string - struct USizeVisitor>(PhantomData); + struct OptionalNumericVisitor>(PhantomData); - impl<'de, T, FromStrErr> de::Visitor<'de> for USizeVisitor + impl<'de, T> Visitor<'de> for OptionalNumericVisitor where - FromStrErr: fmt::Display, - T: TryFrom + FromStr, + T: TryFrom, >::Error: fmt::Display, { type Value = Option; fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("a string containing json data") + formatter.write_str("an optional integer or a plain number string") + } + + fn visit_none(self) -> Result { + Ok(None) + } + + fn visit_unit(self) -> Result { + Ok(None) + } + + fn visit_some>( + self, + deserializer: D2, + ) -> Result { + deserializer.deserialize_any(self) } fn visit_i64(self, v: i64) -> Result { - Ok(Some(v.try_into().map_err(de::Error::custom)?)) + T::try_from(v).map(Some).map_err(de::Error::custom) + } + + fn visit_u64(self, v: u64) -> Result { + let v_i64 = i64::try_from(v).map_err(de::Error::custom)?; + T::try_from(v_i64).map(Some).map_err(de::Error::custom) } fn visit_str(self, v: &str) -> Result { if v.is_empty() { + return Err(de::Error::custom("empty string is not a valid number")); + } + if v.trim().is_empty() { return Ok(None); } - Ok(Some( - (*shellexpand::env(v).map_err(de::Error::custom)?) - .parse::() - .map_err(de::Error::custom)?, - )) + let expanded = shellexpand::env(v).map_err(de::Error::custom)?; + let s = expanded.as_ref().trim(); + let parsed = s.parse::().map_err(de::Error::custom)?; + T::try_from(parsed).map(Some).map_err(de::Error::custom) } } - deserializer.deserialize_any(USizeVisitor::(PhantomData:: {})) + deserializer.deserialize_option(OptionalNumericVisitor::(PhantomData)) } -/// Helper for serde macro so you can use shellexpand variables in the json configuration -/// files when the number is a numeric type. +/// Helper for serde macro so you can use shellexpand variables in the json +/// configuration files when the input is a string. +/// +/// Handles YAML/JSON values according to the YAML 1.2 specification: +/// - Empty string (`""`) remains an empty string +/// - `null` becomes `None` +/// - Missing field becomes `None` +/// - Whitespace is preserved pub fn convert_string_with_shellexpand<'de, D: Deserializer<'de>>( deserializer: D, ) -> Result { @@ -133,89 +159,119 @@ pub fn convert_optional_string_with_shellexpand<'de, D: Deserializer<'de>>( deserializer: D, ) -> Result, D::Error> { let value = Option::::deserialize(deserializer)?; - if let Some(value) = value { - Ok(Some( - (*(shellexpand::env(&value).map_err(de::Error::custom)?)).to_string(), - )) - } else { - Ok(None) + match value { + Some(v) if v.is_empty() => Ok(Some(String::new())), // Keep empty string as empty string + Some(v) => Ok(Some( + (*(shellexpand::env(&v).map_err(de::Error::custom)?)).to_string(), + )), + None => Ok(None), // Handle both null and field not present } } -pub fn convert_data_size_with_shellexpand<'de, D, T, E>(deserializer: D) -> Result +pub fn convert_data_size_with_shellexpand<'de, D, T>(deserializer: D) -> Result where D: Deserializer<'de>, - E: fmt::Display, - T: TryFrom + FromStr, - >::Error: fmt::Display, + T: TryFrom, + >::Error: fmt::Display, { - // define a visitor that deserializes - // `ActualData` encoded as json within a string - struct USizeVisitor>(PhantomData); + struct DataSizeVisitor>(PhantomData); - impl<'de, T, FromStrErr> de::Visitor<'de> for USizeVisitor + impl<'de, T> Visitor<'de> for DataSizeVisitor where - FromStrErr: fmt::Display, - T: TryFrom + FromStr, - >::Error: fmt::Display, + T: TryFrom, + >::Error: fmt::Display, { type Value = T; fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("a string containing json data") + formatter.write_str("either a number of bytes as an integer, or a string with a data size format (e.g., \"1GB\", \"500MB\", \"1.5TB\")") + } + + fn visit_u64(self, v: u64) -> Result { + T::try_from(u128::from(v)).map_err(de::Error::custom) } fn visit_i64(self, v: i64) -> Result { - v.try_into().map_err(de::Error::custom) + if v < 0 { + return Err(de::Error::custom("Negative data size is not allowed")); + } + T::try_from(v as u128).map_err(de::Error::custom) + } + + fn visit_u128(self, v: u128) -> Result { + T::try_from(v).map_err(de::Error::custom) + } + + fn visit_i128(self, v: i128) -> Result { + if v < 0 { + return Err(de::Error::custom("Negative data size is not allowed")); + } + T::try_from(v as u128).map_err(de::Error::custom) } fn visit_str(self, v: &str) -> Result { - let expanded = (*shellexpand::env(v).map_err(de::Error::custom)?).to_string(); - let byte_size = Byte::parse_str(expanded, true).map_err(de::Error::custom)?; - let byte_size_u128 = byte_size.as_u128(); - T::try_from(byte_size_u128.try_into().map_err(de::Error::custom)?) - .map_err(de::Error::custom) + let expanded = shellexpand::env(v).map_err(de::Error::custom)?; + let s = expanded.as_ref().trim(); + let byte_size = Byte::parse_str(s, true).map_err(de::Error::custom)?; + let bytes = byte_size.as_u128(); + T::try_from(bytes).map_err(de::Error::custom) } } - deserializer.deserialize_any(USizeVisitor::(PhantomData:: {})) + deserializer.deserialize_any(DataSizeVisitor::(PhantomData)) } -pub fn convert_duration_with_shellexpand<'de, D, T, E>(deserializer: D) -> Result +pub fn convert_duration_with_shellexpand<'de, D, T>(deserializer: D) -> Result where D: Deserializer<'de>, - E: fmt::Display, - T: TryFrom + FromStr, - >::Error: fmt::Display, + T: TryFrom, + >::Error: fmt::Display, { - // define a visitor that deserializes - // `ActualData` encoded as json within a string - struct USizeVisitor>(PhantomData); + struct DurationVisitor>(PhantomData); - impl<'de, T, FromStrErr> de::Visitor<'de> for USizeVisitor + impl<'de, T> Visitor<'de> for DurationVisitor where - FromStrErr: fmt::Display, - T: TryFrom + FromStr, - >::Error: fmt::Display, + T: TryFrom, + >::Error: fmt::Display, { type Value = T; fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("a string containing json data") + formatter.write_str("either a number of seconds as an integer, or a string with a duration format (e.g., \"1h2m3s\", \"30m\", \"1d\")") + } + + fn visit_u64(self, v: u64) -> Result { + T::try_from(v).map_err(de::Error::custom) } fn visit_i64(self, v: i64) -> Result { - v.try_into().map_err(de::Error::custom) + if v < 0 { + return Err(de::Error::custom("Negative duration is not allowed")); + } + T::try_from(v as u64).map_err(de::Error::custom) + } + + fn visit_u128(self, v: u128) -> Result { + let v_u64 = u64::try_from(v).map_err(de::Error::custom)?; + T::try_from(v_u64).map_err(de::Error::custom) + } + + fn visit_i128(self, v: i128) -> Result { + if v < 0 { + return Err(de::Error::custom("Negative duration is not allowed")); + } + let v_u64 = u64::try_from(v).map_err(de::Error::custom)?; + T::try_from(v_u64).map_err(de::Error::custom) } fn visit_str(self, v: &str) -> Result { - let expanded = (*shellexpand::env(v).map_err(de::Error::custom)?).to_string(); - let duration = parse_duration(&expanded).map_err(de::Error::custom)?; - let duration_secs = duration.as_secs(); - T::try_from(duration_secs.try_into().map_err(de::Error::custom)?) - .map_err(de::Error::custom) + let expanded = shellexpand::env(v).map_err(de::Error::custom)?; + let s = expanded.as_ref().trim(); + let duration = parse_duration(s).map_err(de::Error::custom)?; + let secs = duration.as_secs(); + T::try_from(secs).map_err(de::Error::custom) } } - deserializer.deserialize_any(USizeVisitor::(PhantomData:: {})) + deserializer.deserialize_any(DurationVisitor::(PhantomData)) } diff --git a/nativelink-config/tests/deserialization_test.rs b/nativelink-config/tests/deserialization_test.rs index 872600893..3af4961f2 100644 --- a/nativelink-config/tests/deserialization_test.rs +++ b/nativelink-config/tests/deserialization_test.rs @@ -14,54 +14,356 @@ use nativelink_config::serde_utils::{ convert_data_size_with_shellexpand, convert_duration_with_shellexpand, + convert_optional_numeric_with_shellexpand, convert_optional_string_with_shellexpand, }; -use pretty_assertions::assert_eq; use serde::Deserialize; -#[derive(Deserialize)] +#[derive(Deserialize, Debug)] struct DurationEntity { #[serde(default, deserialize_with = "convert_duration_with_shellexpand")] duration: usize, } -#[derive(Deserialize)] +#[derive(Deserialize, Debug)] struct DataSizeEntity { #[serde(default, deserialize_with = "convert_data_size_with_shellexpand")] data_size: usize, } -#[test] -fn test_duration_human_readable_deserialize() { - let example = r#" - {"duration": "1m 10s"} - "#; - let deserialized: DurationEntity = serde_json5::from_str(example).unwrap(); - assert_eq!(deserialized.duration, 70); +#[derive(Deserialize, Debug)] +struct OptionalNumericEntity { + #[serde( + default, + deserialize_with = "convert_optional_numeric_with_shellexpand" + )] + value: Option, } -#[test] -fn test_duration_usize_deserialize() { - let example = r#" - {"duration": 10} - "#; - let deserialized: DurationEntity = serde_json5::from_str(example).unwrap(); - assert_eq!(deserialized.duration, 10); +#[derive(Deserialize, Debug)] +struct OptionalStringEntity { + #[serde(default, deserialize_with = "convert_optional_string_with_shellexpand")] + value: Option, } -#[test] -fn test_data_size_unit_deserialize() { - let example = r#" - {"data_size": "1KiB"} - "#; - let deserialized: DataSizeEntity = serde_json5::from_str(example).unwrap(); - assert_eq!(deserialized.data_size, 1024); +mod duration_tests { + use super::*; + + #[test] + fn test_duration_parsing() { + let examples = [ + // Basic duration tests + (r#"{"duration": "1m 10s"}"#, 70), + (r#"{"duration": 10}"#, 10), + (r#"{"duration": " 1m 10s "}"#, 70), + // Complex duration formats + (r#"{"duration": "1y3w4d5h6m7s"}"#, 33_735_967), + (r#"{"duration": "0s"}"#, 0), + (r#"{"duration": "1ns"}"#, 0), // Sub-second rounds to 0 + (r#"{"duration": "999h"}"#, 3_596_400), + // Large numbers + (r#"{"duration": 0}"#, 0), + (r#"{"duration": 1000}"#, 1000), + // u32::MAX + (r#"{"duration": 4294967295}"#, 4_294_967_295), + ]; + + for (input, expected) in examples { + let deserialized: DurationEntity = serde_json5::from_str(input).unwrap(); + assert_eq!(deserialized.duration, expected); + } + } + + #[test] + fn test_duration_negative_rejected() { + let example = r#"{"duration": -10}"#; + let result: Result = serde_json5::from_str(example); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("Negative duration is not allowed")); + } + + #[test] + fn test_duration_errors() { + let examples = [ + ( + r#"{"duration": true}"#, + "expected either a number of seconds as an integer, or a string with a duration format (e.g., \"1h2m3s\", \"30m\", \"1d\")", + ), + ( + r#"{"duration": "invalid"}"#, + "expected number at 0", + ), + ( + r#"{"duration": "999999999999999999999s"}"#, + "number is too large", + ), + ]; + + for (input, expected_error) in examples { + let error = serde_json5::from_str::(input) + .unwrap_err() + .to_string(); + assert!(error.contains(expected_error)); + } + } + + #[test] + fn test_duration_whitespace_handling() { + let example = r#"{"duration": " 1m 10s "}"#; + let deserialized: DurationEntity = serde_json5::from_str(example).unwrap(); + assert_eq!(deserialized.duration, 70); + } + + #[test] + fn test_large_duration_numbers() { + let examples = [ + // u32::MAX + (r#"{"duration": 4294967295}"#, 4_294_967_295), + // u64::MAX - this will fail to parse as usize on 64-bit systems + // (r#"{"duration": 18446744073709551615}"#, 18_446_744_073_709_551_615), + ]; + + for (input, expected) in examples { + let deserialized: DurationEntity = serde_json5::from_str(input).unwrap(); + assert_eq!(deserialized.duration, expected); + } + } +} + +mod data_size_tests { + use super::*; + + #[test] + fn test_data_size_parsing() { + let examples = [ + // Basic size tests + (r#"{"data_size": "1KiB"}"#, 1024), + (r#"{"data_size": "1MiB"}"#, 1_048_576), + (r#"{"data_size": "1MB"}"#, 1_000_000), + (r#"{"data_size": "1M"}"#, 1_000_000), + (r#"{"data_size": "1Mi"}"#, 1_048_576), + // Large sizes + (r#"{"data_size": "9EiB"}"#, 10_376_293_541_461_622_784), + (r#"{"data_size": 10}"#, 10), + // Edge cases + (r#"{"data_size": "1B"}"#, 1), + (r#"{"data_size": "1.5GB"}"#, 1_500_000_000), + (r#"{"data_size": "1.5GiB"}"#, 1_610_612_736), + (r#"{"data_size": "0B"}"#, 0), + // Whitespace handling + (r#"{"data_size": " 1KiB "}"#, 1024), + ]; + + for (input, expected) in examples { + let deserialized: DataSizeEntity = serde_json5::from_str(input).unwrap(); + assert_eq!(deserialized.data_size, expected); + } + } + + #[test] + fn test_data_size_negative_rejected() { + let example = r#"{"data_size": -1024}"#; + let result: Result = serde_json5::from_str(example); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("Negative data size is not allowed")); + } + + #[test] + fn test_data_size_case_insensitivity() { + let examples = [ + r#"{"data_size": "1kb"}"#, + r#"{"data_size": "1KB"}"#, + r#"{"data_size": "1Kb"}"#, + r#"{"data_size": "1kB"}"#, + ]; + + for input in examples { + let deserialized: DataSizeEntity = serde_json5::from_str(input).unwrap(); + assert_eq!(deserialized.data_size, 1000); // All should be 1 kilobyte + } + } + + #[test] + fn test_data_size_errors() { + let examples = [ + ( + r#"{"data_size": true}"#, + "expected either a number of bytes as an integer, or a string with a data size format (e.g., \"1GB\", \"500MB\", \"1.5TB\")", + ), + ( + r#"{"data_size": "invalid"}"#, + "the character 'i' is not a number", + ), + ( + r#"{"data_size": "999999999999999999999B"}"#, + "the value 999999999999999999999 exceeds the valid range", + ), + ]; + + for (input, expected_error) in examples { + let error = serde_json5::from_str::(input) + .unwrap_err() + .to_string(); + assert!(error.contains(expected_error)); + } + } +} + +mod optional_values_tests { + use super::*; + + #[test] + fn test_optional_numeric_values() { + let examples = [ + (r#"{"value": null}"#, None), + (r#"{"value": 42}"#, Some(42)), + (r#"{}"#, None), // Missing field + ]; + + for (input, expected) in examples { + let deserialized: OptionalNumericEntity = serde_json5::from_str(input).unwrap(); + assert_eq!(deserialized.value, expected); + } + } + + #[test] + fn test_optional_numeric_large_numbers() { + // Test i64::MAX for optional numeric + let input = r#"{"value": "9223372036854775807"}"#; + let result: OptionalNumericEntity = serde_json5::from_str(input).unwrap(); + assert_eq!(result.value, Some(9_223_372_036_854_775_807)); + } + + #[test] + fn test_optional_numeric_errors() { + let examples = [ + ( + r#"{"value": {}}"#, + "expected an optional integer or a plain number string", + ), + ( + r#"{"value": "not_a_number"}"#, + "invalid digit found in string", + ), + ( + r#"{"value": "999999999999999999999"}"#, + "number too large to fit in target type", + ), + ]; + + for (input, expected_error) in examples { + let error = serde_json5::from_str::(input) + .unwrap_err() + .to_string(); + assert!(error.contains(expected_error)); + } + } + + #[test] + fn test_optional_string_values() { + let examples = [ + (r#"{"value": ""}"#, Some(String::new())), + (r#"{"value": null}"#, None), + (r#"{}"#, None), + (r#"{"value": " "}"#, Some(" ".to_string())), + ]; + + for (input, expected) in examples { + let deserialized: OptionalStringEntity = serde_json5::from_str(input).unwrap(); + assert_eq!(deserialized.value, expected); + } + } + + #[test] + fn test_mixed_optional_values() { + #[derive(Deserialize)] + struct MixedOptionals { + #[serde( + default, + deserialize_with = "convert_optional_numeric_with_shellexpand" + )] + number: Option, + #[serde(default, deserialize_with = "convert_optional_string_with_shellexpand")] + string: Option, + } + + let examples = [ + ( + r#"{"number": null, "string": "hello"}"#, + None, + Some("hello".to_string()), + ), + (r#"{"number": 42, "string": null}"#, Some(42), None), + (r#"{"number": null, "string": null}"#, None, None), + (r#"{}"#, None, None), + ( + r#"{"number": null, "string": ""}"#, + None, + Some(String::new()), + ), + ( + r#"{"number": null, "string": " "}"#, + None, + Some(" ".to_string()), + ), + ]; + + for (input, expected_number, expected_string) in examples { + let deserialized: MixedOptionals = serde_json5::from_str(input).unwrap(); + assert_eq!(deserialized.number, expected_number); + assert_eq!(deserialized.string, expected_string); + } + } } -#[test] -fn test_data_size_usize_deserialize() { - let example = r#" - {"data_size": 10} - "#; - let deserialized: DataSizeEntity = serde_json5::from_str(example).unwrap(); - assert_eq!(deserialized.data_size, 10); +mod shellexpand_tests { + use super::*; + + #[test] + fn test_shellexpand_functionality() { + std::env::set_var("TEST_DURATION", "5m"); + std::env::set_var("TEST_SIZE", "1GB"); + std::env::set_var("TEST_NUMBER", "42"); + std::env::set_var("TEST_VAR", "test_value"); + std::env::set_var("EMPTY_VAR", ""); + + // Test duration with environment variable + let duration_result = + serde_json5::from_str::(r#"{"duration": "${TEST_DURATION}"}"#).unwrap(); + assert_eq!(duration_result.duration, 300); + + // Test data size with environment variable + let size_result = + serde_json5::from_str::(r#"{"data_size": "${TEST_SIZE}"}"#).unwrap(); + assert_eq!(size_result.data_size, 1_000_000_000); + + // Test optional numeric with environment variable + let numeric_result = + serde_json5::from_str::(r#"{"value": "${TEST_NUMBER}"}"#) + .unwrap(); + assert_eq!(numeric_result.value, Some(42)); + + // Test optional string with environment variable + let string_result = + serde_json5::from_str::(r#"{"value": "${TEST_VAR}"}"#).unwrap(); + assert_eq!(string_result.value, Some("test_value".to_string())); + + // Test optional string with empty environment variable + let empty_string_result = + serde_json5::from_str::(r#"{"value": "${EMPTY_VAR}"}"#).unwrap(); + assert_eq!(empty_string_result.value, Some(String::new())); + + // Test undefined environment variable + let undefined_result = + serde_json5::from_str::(r#"{"value": "${UNDEFINED_VAR}"}"#); + assert!(undefined_result + .unwrap_err() + .to_string() + .contains("environment variable not found")); + } }