Skip to content

Commit

Permalink
Make shellexpand fields more robust (TraceMachina#1471)
Browse files Browse the repository at this point in the history
These turned out to be much buggier than anticipated. The new
implementation behaves like the old one but no longer requires `FromStr`
and has consistent behavior with what we'd expect from yaml-to-json
conversions.
  • Loading branch information
aaronmondal authored Nov 12, 2024
1 parent 545793c commit b6cf659
Show file tree
Hide file tree
Showing 3 changed files with 472 additions and 114 deletions.
2 changes: 1 addition & 1 deletion nativelink-config/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ edition = "2021"
[dependencies]
byte-unit = { version = "5.1.4", default-features = false, features = ["byte"] }
humantime = "2.1.0"
serde = { version = "1.0.210", default-features = false }
serde = { version = "1.0.210", default-features = false, features = ["derive"] }
serde_json5 = "0.1.0"
shellexpand = { version = "3.1.0", default-features = false, features = ["base-0"] }

Expand Down
220 changes: 138 additions & 82 deletions nativelink-config/src/serde_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,106 +15,132 @@
use std::borrow::Cow;
use std::fmt;
use std::marker::PhantomData;
use std::str::FromStr;

use byte_unit::Byte;
use humantime::parse_duration;
use serde::de::Visitor;
use serde::{de, Deserialize, Deserializer};

/// Helper for serde macro so you can use shellexpand variables in the json configuration
/// files when the number is a numeric type.
pub fn convert_numeric_with_shellexpand<'de, D, T, E>(deserializer: D) -> Result<T, D::Error>
pub fn convert_numeric_with_shellexpand<'de, D, T>(deserializer: D) -> Result<T, D::Error>
where
D: Deserializer<'de>,
E: fmt::Display,
T: TryFrom<i64> + FromStr<Err = E>,
T: TryFrom<i64>,
<T as TryFrom<i64>>::Error: fmt::Display,
{
// define a visitor that deserializes
// `ActualData` encoded as json within a string
struct USizeVisitor<T: TryFrom<i64>>(PhantomData<T>);
struct NumericVisitor<T: TryFrom<i64>>(PhantomData<T>);

impl<'de, T, FromStrErr> de::Visitor<'de> for USizeVisitor<T>
impl<'de, T> Visitor<'de> for NumericVisitor<T>
where
FromStrErr: fmt::Display,
T: TryFrom<i64> + FromStr<Err = FromStrErr>,
T: TryFrom<i64>,
<T as TryFrom<i64>>::Error: fmt::Display,
{
type Value = T;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a string containing json data")
formatter.write_str("an integer or a plain number string")
}

fn visit_i64<E: de::Error>(self, v: i64) -> Result<Self::Value, E> {
v.try_into().map_err(de::Error::custom)
T::try_from(v).map_err(de::Error::custom)
}

fn visit_u64<E: de::Error>(self, v: u64) -> Result<Self::Value, E> {
let v_i64 = i64::try_from(v).map_err(de::Error::custom)?;
T::try_from(v_i64).map_err(de::Error::custom)
}

fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
(*shellexpand::env(v).map_err(de::Error::custom)?)
.parse::<T>()
.map_err(de::Error::custom)
let expanded = shellexpand::env(v).map_err(de::Error::custom)?;
let s = expanded.as_ref().trim();
let parsed = s.parse::<i64>().map_err(de::Error::custom)?;
T::try_from(parsed).map_err(de::Error::custom)
}
}

deserializer.deserialize_any(USizeVisitor::<T>(PhantomData::<T> {}))
deserializer.deserialize_any(NumericVisitor::<T>(PhantomData))
}

/// Same as convert_numeric_with_shellexpand, but supports `Option<T>`.
pub fn convert_optional_numeric_with_shellexpand<'de, D, T, E>(
/// Same as `convert_numeric_with_shellexpand`, but supports `Option<T>`.
pub fn convert_optional_numeric_with_shellexpand<'de, D, T>(
deserializer: D,
) -> Result<Option<T>, D::Error>
where
D: Deserializer<'de>,
E: fmt::Display,
T: TryFrom<i64> + FromStr<Err = E>,
T: TryFrom<i64>,
<T as TryFrom<i64>>::Error: fmt::Display,
{
// define a visitor that deserializes
// `ActualData` encoded as json within a string
struct USizeVisitor<T: TryFrom<i64>>(PhantomData<T>);
struct OptionalNumericVisitor<T: TryFrom<i64>>(PhantomData<T>);

impl<'de, T, FromStrErr> de::Visitor<'de> for USizeVisitor<T>
impl<'de, T> Visitor<'de> for OptionalNumericVisitor<T>
where
FromStrErr: fmt::Display,
T: TryFrom<i64> + FromStr<Err = FromStrErr>,
T: TryFrom<i64>,
<T as TryFrom<i64>>::Error: fmt::Display,
{
type Value = Option<T>;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a string containing json data")
formatter.write_str("an optional integer or a plain number string")
}

fn visit_none<E: de::Error>(self) -> Result<Self::Value, E> {
Ok(None)
}

fn visit_unit<E: de::Error>(self) -> Result<Self::Value, E> {
Ok(None)
}

fn visit_some<D2: Deserializer<'de>>(
self,
deserializer: D2,
) -> Result<Self::Value, D2::Error> {
deserializer.deserialize_any(self)
}

fn visit_i64<E: de::Error>(self, v: i64) -> Result<Self::Value, E> {
Ok(Some(v.try_into().map_err(de::Error::custom)?))
T::try_from(v).map(Some).map_err(de::Error::custom)
}

fn visit_u64<E: de::Error>(self, v: u64) -> Result<Self::Value, E> {
let v_i64 = i64::try_from(v).map_err(de::Error::custom)?;
T::try_from(v_i64).map(Some).map_err(de::Error::custom)
}

fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
if v.is_empty() {
return Err(de::Error::custom("empty string is not a valid number"));
}
if v.trim().is_empty() {
return Ok(None);
}
Ok(Some(
(*shellexpand::env(v).map_err(de::Error::custom)?)
.parse::<T>()
.map_err(de::Error::custom)?,
))
let expanded = shellexpand::env(v).map_err(de::Error::custom)?;
let s = expanded.as_ref().trim();
let parsed = s.parse::<i64>().map_err(de::Error::custom)?;
T::try_from(parsed).map(Some).map_err(de::Error::custom)
}
}

deserializer.deserialize_any(USizeVisitor::<T>(PhantomData::<T> {}))
deserializer.deserialize_option(OptionalNumericVisitor::<T>(PhantomData))
}

/// Helper for serde macro so you can use shellexpand variables in the json configuration
/// files when the number is a numeric type.
/// Helper for serde macro so you can use shellexpand variables in the json
/// configuration files when the input is a string.
///
/// The value is deserialized as a `String`, passed through `shellexpand::env`,
/// and returned as an owned `String`. This function applies no trimming of its
/// own, so whitespace in the input is handed to the expansion step as written.
///
/// NOTE(review): the return type is `String`, not `Option<String>`, so this
/// function cannot yield `None` for `null` or for a missing field. Presumably
/// those cases are meant for `convert_optional_string_with_shellexpand`, and
/// `String::deserialize` rejects `null` here -- confirm against the tests.
///
/// # Errors
///
/// Returns `D::Error` if the input cannot be deserialized as a `String`, or if
/// `shellexpand::env` fails (its error is wrapped via `de::Error::custom`).
pub fn convert_string_with_shellexpand<'de, D: Deserializer<'de>>(
deserializer: D,
) -> Result<String, D::Error> {
// Read the raw, unexpanded string from the config first.
let value = String::deserialize(deserializer)?;
// Expand environment variable references, then copy the expansion result
// into an owned `String` for the caller.
Ok((*(shellexpand::env(&value).map_err(de::Error::custom)?)).to_string())
}

/// Same as convert_string_with_shellexpand, but supports `Vec<String>`.
/// Same as `convert_string_with_shellexpand`, but supports `Vec<String>`.
pub fn convert_vec_string_with_shellexpand<'de, D: Deserializer<'de>>(
deserializer: D,
) -> Result<Vec<String>, D::Error> {
Expand All @@ -128,94 +154,124 @@ pub fn convert_vec_string_with_shellexpand<'de, D: Deserializer<'de>>(
.collect()
}

/// Same as convert_string_with_shellexpand, but supports `Option<String>`.
/// Same as `convert_string_with_shellexpand`, but supports `Option<String>`.
pub fn convert_optional_string_with_shellexpand<'de, D: Deserializer<'de>>(
deserializer: D,
) -> Result<Option<String>, D::Error> {
let value = Option::<String>::deserialize(deserializer)?;
if let Some(value) = value {
Ok(Some(
(*(shellexpand::env(&value).map_err(de::Error::custom)?)).to_string(),
))
} else {
Ok(None)
match value {
Some(v) if v.is_empty() => Ok(Some(String::new())), // Keep empty string as empty string
Some(v) => Ok(Some(
(*(shellexpand::env(&v).map_err(de::Error::custom)?)).to_string(),
)),
None => Ok(None), // Handle both null and field not present
}
}

pub fn convert_data_size_with_shellexpand<'de, D, T, E>(deserializer: D) -> Result<T, D::Error>
pub fn convert_data_size_with_shellexpand<'de, D, T>(deserializer: D) -> Result<T, D::Error>
where
D: Deserializer<'de>,
E: fmt::Display,
T: TryFrom<i64> + FromStr<Err = E>,
<T as TryFrom<i64>>::Error: fmt::Display,
T: TryFrom<u128>,
<T as TryFrom<u128>>::Error: fmt::Display,
{
// define a visitor that deserializes
// `ActualData` encoded as json within a string
struct USizeVisitor<T: TryFrom<i64>>(PhantomData<T>);
struct DataSizeVisitor<T: TryFrom<u128>>(PhantomData<T>);

impl<'de, T, FromStrErr> de::Visitor<'de> for USizeVisitor<T>
impl<'de, T> Visitor<'de> for DataSizeVisitor<T>
where
FromStrErr: fmt::Display,
T: TryFrom<i64> + FromStr<Err = FromStrErr>,
<T as TryFrom<i64>>::Error: fmt::Display,
T: TryFrom<u128>,
<T as TryFrom<u128>>::Error: fmt::Display,
{
type Value = T;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a string containing json data")
formatter.write_str("either a number of bytes as an integer, or a string with a data size format (e.g., \"1GB\", \"500MB\", \"1.5TB\")")
}

fn visit_u64<E: de::Error>(self, v: u64) -> Result<Self::Value, E> {
T::try_from(u128::from(v)).map_err(de::Error::custom)
}

fn visit_i64<E: de::Error>(self, v: i64) -> Result<Self::Value, E> {
v.try_into().map_err(de::Error::custom)
if v < 0 {
return Err(de::Error::custom("Negative data size is not allowed"));
}
T::try_from(v as u128).map_err(de::Error::custom)
}

fn visit_u128<E: de::Error>(self, v: u128) -> Result<Self::Value, E> {
T::try_from(v).map_err(de::Error::custom)
}

fn visit_i128<E: de::Error>(self, v: i128) -> Result<Self::Value, E> {
if v < 0 {
return Err(de::Error::custom("Negative data size is not allowed"));
}
T::try_from(v as u128).map_err(de::Error::custom)
}

fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
let expanded = (*shellexpand::env(v).map_err(de::Error::custom)?).to_string();
let byte_size = Byte::parse_str(expanded, true).map_err(de::Error::custom)?;
let byte_size_u128 = byte_size.as_u128();
T::try_from(byte_size_u128.try_into().map_err(de::Error::custom)?)
.map_err(de::Error::custom)
let expanded = shellexpand::env(v).map_err(de::Error::custom)?;
let s = expanded.as_ref().trim();
let byte_size = Byte::parse_str(s, true).map_err(de::Error::custom)?;
let bytes = byte_size.as_u128();
T::try_from(bytes).map_err(de::Error::custom)
}
}

deserializer.deserialize_any(USizeVisitor::<T>(PhantomData::<T> {}))
deserializer.deserialize_any(DataSizeVisitor::<T>(PhantomData))
}

pub fn convert_duration_with_shellexpand<'de, D, T, E>(deserializer: D) -> Result<T, D::Error>
pub fn convert_duration_with_shellexpand<'de, D, T>(deserializer: D) -> Result<T, D::Error>
where
D: Deserializer<'de>,
E: fmt::Display,
T: TryFrom<i64> + FromStr<Err = E>,
<T as TryFrom<i64>>::Error: fmt::Display,
T: TryFrom<u64>,
<T as TryFrom<u64>>::Error: fmt::Display,
{
// define a visitor that deserializes
// `ActualData` encoded as json within a string
struct USizeVisitor<T: TryFrom<i64>>(PhantomData<T>);
struct DurationVisitor<T: TryFrom<u64>>(PhantomData<T>);

impl<'de, T, FromStrErr> de::Visitor<'de> for USizeVisitor<T>
impl<'de, T> Visitor<'de> for DurationVisitor<T>
where
FromStrErr: fmt::Display,
T: TryFrom<i64> + FromStr<Err = FromStrErr>,
<T as TryFrom<i64>>::Error: fmt::Display,
T: TryFrom<u64>,
<T as TryFrom<u64>>::Error: fmt::Display,
{
type Value = T;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a string containing json data")
formatter.write_str("either a number of seconds as an integer, or a string with a duration format (e.g., \"1h2m3s\", \"30m\", \"1d\")")
}

fn visit_u64<E: de::Error>(self, v: u64) -> Result<Self::Value, E> {
T::try_from(v).map_err(de::Error::custom)
}

fn visit_i64<E: de::Error>(self, v: i64) -> Result<Self::Value, E> {
v.try_into().map_err(de::Error::custom)
if v < 0 {
return Err(de::Error::custom("Negative duration is not allowed"));
}
T::try_from(v as u64).map_err(de::Error::custom)
}

fn visit_u128<E: de::Error>(self, v: u128) -> Result<Self::Value, E> {
let v_u64 = u64::try_from(v).map_err(de::Error::custom)?;
T::try_from(v_u64).map_err(de::Error::custom)
}

fn visit_i128<E: de::Error>(self, v: i128) -> Result<Self::Value, E> {
if v < 0 {
return Err(de::Error::custom("Negative duration is not allowed"));
}
let v_u64 = u64::try_from(v).map_err(de::Error::custom)?;
T::try_from(v_u64).map_err(de::Error::custom)
}

fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
let expanded = (*shellexpand::env(v).map_err(de::Error::custom)?).to_string();
let duration = parse_duration(&expanded).map_err(de::Error::custom)?;
let duration_secs = duration.as_secs();
T::try_from(duration_secs.try_into().map_err(de::Error::custom)?)
.map_err(de::Error::custom)
let expanded = shellexpand::env(v).map_err(de::Error::custom)?;
let expanded = expanded.as_ref().trim();
let duration = parse_duration(expanded).map_err(de::Error::custom)?;
let secs = duration.as_secs();
T::try_from(secs).map_err(de::Error::custom)
}
}

deserializer.deserialize_any(USizeVisitor::<T>(PhantomData::<T> {}))
deserializer.deserialize_any(DurationVisitor::<T>(PhantomData))
}
Loading

0 comments on commit b6cf659

Please sign in to comment.