Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make shellexpand fields more robust #1471

Merged
merged 1 commit into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion nativelink-config/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ edition = "2021"
[dependencies]
byte-unit = { version = "5.1.4", default-features = false, features = ["byte"] }
humantime = "2.1.0"
serde = { version = "1.0.210", default-features = false }
serde = { version = "1.0.210", default-features = false, features = ["derive"] }
serde_json5 = "0.1.0"
shellexpand = { version = "3.1.0", default-features = false, features = ["base-0"] }

Expand Down
220 changes: 138 additions & 82 deletions nativelink-config/src/serde_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,106 +15,132 @@
use std::borrow::Cow;
use std::fmt;
use std::marker::PhantomData;
use std::str::FromStr;

use byte_unit::Byte;
use humantime::parse_duration;
use serde::de::Visitor;
use serde::{de, Deserialize, Deserializer};

/// Helper for serde macro so you can use shellexpand variables in the json configuration
/// files when the number is a numeric type.
pub fn convert_numeric_with_shellexpand<'de, D, T, E>(deserializer: D) -> Result<T, D::Error>
pub fn convert_numeric_with_shellexpand<'de, D, T>(deserializer: D) -> Result<T, D::Error>
where
D: Deserializer<'de>,
E: fmt::Display,
T: TryFrom<i64> + FromStr<Err = E>,
T: TryFrom<i64>,
<T as TryFrom<i64>>::Error: fmt::Display,
{
// define a visitor that deserializes
// `ActualData` encoded as json within a string
struct USizeVisitor<T: TryFrom<i64>>(PhantomData<T>);
struct NumericVisitor<T: TryFrom<i64>>(PhantomData<T>);

impl<'de, T, FromStrErr> de::Visitor<'de> for USizeVisitor<T>
impl<'de, T> Visitor<'de> for NumericVisitor<T>
where
FromStrErr: fmt::Display,
T: TryFrom<i64> + FromStr<Err = FromStrErr>,
T: TryFrom<i64>,
<T as TryFrom<i64>>::Error: fmt::Display,
{
type Value = T;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a string containing json data")
formatter.write_str("an integer or a plain number string")
}

fn visit_i64<E: de::Error>(self, v: i64) -> Result<Self::Value, E> {
v.try_into().map_err(de::Error::custom)
T::try_from(v).map_err(de::Error::custom)
}

fn visit_u64<E: de::Error>(self, v: u64) -> Result<Self::Value, E> {
let v_i64 = i64::try_from(v).map_err(de::Error::custom)?;
T::try_from(v_i64).map_err(de::Error::custom)
}

fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
(*shellexpand::env(v).map_err(de::Error::custom)?)
.parse::<T>()
.map_err(de::Error::custom)
let expanded = shellexpand::env(v).map_err(de::Error::custom)?;
let s = expanded.as_ref().trim();
let parsed = s.parse::<i64>().map_err(de::Error::custom)?;
T::try_from(parsed).map_err(de::Error::custom)
}
}

deserializer.deserialize_any(USizeVisitor::<T>(PhantomData::<T> {}))
deserializer.deserialize_any(NumericVisitor::<T>(PhantomData))
}

/// Same as convert_numeric_with_shellexpand, but supports `Option<T>`.
pub fn convert_optional_numeric_with_shellexpand<'de, D, T, E>(
/// Same as `convert_numeric_with_shellexpand`, but supports `Option<T>`.
pub fn convert_optional_numeric_with_shellexpand<'de, D, T>(
deserializer: D,
) -> Result<Option<T>, D::Error>
where
D: Deserializer<'de>,
E: fmt::Display,
T: TryFrom<i64> + FromStr<Err = E>,
T: TryFrom<i64>,
<T as TryFrom<i64>>::Error: fmt::Display,
{
// define a visitor that deserializes
// `ActualData` encoded as json within a string
struct USizeVisitor<T: TryFrom<i64>>(PhantomData<T>);
struct OptionalNumericVisitor<T: TryFrom<i64>>(PhantomData<T>);

impl<'de, T, FromStrErr> de::Visitor<'de> for USizeVisitor<T>
impl<'de, T> Visitor<'de> for OptionalNumericVisitor<T>
where
FromStrErr: fmt::Display,
T: TryFrom<i64> + FromStr<Err = FromStrErr>,
T: TryFrom<i64>,
<T as TryFrom<i64>>::Error: fmt::Display,
{
type Value = Option<T>;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a string containing json data")
formatter.write_str("an optional integer or a plain number string")
}

fn visit_none<E: de::Error>(self) -> Result<Self::Value, E> {
Ok(None)
}

fn visit_unit<E: de::Error>(self) -> Result<Self::Value, E> {
Ok(None)
}

fn visit_some<D2: Deserializer<'de>>(
self,
deserializer: D2,
) -> Result<Self::Value, D2::Error> {
deserializer.deserialize_any(self)
}

fn visit_i64<E: de::Error>(self, v: i64) -> Result<Self::Value, E> {
Ok(Some(v.try_into().map_err(de::Error::custom)?))
T::try_from(v).map(Some).map_err(de::Error::custom)
}

fn visit_u64<E: de::Error>(self, v: u64) -> Result<Self::Value, E> {
let v_i64 = i64::try_from(v).map_err(de::Error::custom)?;
T::try_from(v_i64).map(Some).map_err(de::Error::custom)
}

fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
if v.is_empty() {
return Err(de::Error::custom("empty string is not a valid number"));
}
if v.trim().is_empty() {
return Ok(None);
}
Ok(Some(
(*shellexpand::env(v).map_err(de::Error::custom)?)
.parse::<T>()
.map_err(de::Error::custom)?,
))
let expanded = shellexpand::env(v).map_err(de::Error::custom)?;
let s = expanded.as_ref().trim();
let parsed = s.parse::<i64>().map_err(de::Error::custom)?;
T::try_from(parsed).map(Some).map_err(de::Error::custom)
}
}

deserializer.deserialize_any(USizeVisitor::<T>(PhantomData::<T> {}))
deserializer.deserialize_option(OptionalNumericVisitor::<T>(PhantomData))
}

/// Helper for serde macro so you can use shellexpand variables in the json configuration
/// files when the number is a numeric type.
/// Helper for serde macro so you can use shellexpand variables in the json
/// configuration files when the input is a string.
///
/// Handles YAML/JSON values according to the YAML 1.2 specification:
/// - Empty string (`""`) remains an empty string
/// - `null` becomes `None`
/// - Missing field becomes `None`
/// - Whitespace is preserved
pub fn convert_string_with_shellexpand<'de, D: Deserializer<'de>>(
deserializer: D,
) -> Result<String, D::Error> {
let value = String::deserialize(deserializer)?;
Ok((*(shellexpand::env(&value).map_err(de::Error::custom)?)).to_string())
}

/// Same as convert_string_with_shellexpand, but supports `Vec<String>`.
/// Same as `convert_string_with_shellexpand`, but supports `Vec<String>`.
pub fn convert_vec_string_with_shellexpand<'de, D: Deserializer<'de>>(
deserializer: D,
) -> Result<Vec<String>, D::Error> {
Expand All @@ -128,94 +154,124 @@ pub fn convert_vec_string_with_shellexpand<'de, D: Deserializer<'de>>(
.collect()
}

/// Same as convert_string_with_shellexpand, but supports `Option<String>`.
/// Same as `convert_string_with_shellexpand`, but supports `Option<String>`.
pub fn convert_optional_string_with_shellexpand<'de, D: Deserializer<'de>>(
deserializer: D,
) -> Result<Option<String>, D::Error> {
let value = Option::<String>::deserialize(deserializer)?;
if let Some(value) = value {
Ok(Some(
(*(shellexpand::env(&value).map_err(de::Error::custom)?)).to_string(),
))
} else {
Ok(None)
match value {
Some(v) if v.is_empty() => Ok(Some(String::new())), // Keep empty string as empty string
Some(v) => Ok(Some(
(*(shellexpand::env(&v).map_err(de::Error::custom)?)).to_string(),
)),
None => Ok(None), // Handle both null and field not present
}
}

pub fn convert_data_size_with_shellexpand<'de, D, T, E>(deserializer: D) -> Result<T, D::Error>
pub fn convert_data_size_with_shellexpand<'de, D, T>(deserializer: D) -> Result<T, D::Error>
where
D: Deserializer<'de>,
E: fmt::Display,
T: TryFrom<i64> + FromStr<Err = E>,
<T as TryFrom<i64>>::Error: fmt::Display,
T: TryFrom<u128>,
<T as TryFrom<u128>>::Error: fmt::Display,
{
// define a visitor that deserializes
// `ActualData` encoded as json within a string
struct USizeVisitor<T: TryFrom<i64>>(PhantomData<T>);
struct DataSizeVisitor<T: TryFrom<u128>>(PhantomData<T>);

impl<'de, T, FromStrErr> de::Visitor<'de> for USizeVisitor<T>
impl<'de, T> Visitor<'de> for DataSizeVisitor<T>
where
FromStrErr: fmt::Display,
T: TryFrom<i64> + FromStr<Err = FromStrErr>,
<T as TryFrom<i64>>::Error: fmt::Display,
T: TryFrom<u128>,
<T as TryFrom<u128>>::Error: fmt::Display,
{
type Value = T;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a string containing json data")
formatter.write_str("either a number of bytes as an integer, or a string with a data size format (e.g., \"1GB\", \"500MB\", \"1.5TB\")")
}

fn visit_u64<E: de::Error>(self, v: u64) -> Result<Self::Value, E> {
T::try_from(u128::from(v)).map_err(de::Error::custom)
}

fn visit_i64<E: de::Error>(self, v: i64) -> Result<Self::Value, E> {
v.try_into().map_err(de::Error::custom)
if v < 0 {
return Err(de::Error::custom("Negative data size is not allowed"));
}
T::try_from(v as u128).map_err(de::Error::custom)
}

fn visit_u128<E: de::Error>(self, v: u128) -> Result<Self::Value, E> {
T::try_from(v).map_err(de::Error::custom)
}

fn visit_i128<E: de::Error>(self, v: i128) -> Result<Self::Value, E> {
if v < 0 {
return Err(de::Error::custom("Negative data size is not allowed"));
}
T::try_from(v as u128).map_err(de::Error::custom)
}

fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
let expanded = (*shellexpand::env(v).map_err(de::Error::custom)?).to_string();
let byte_size = Byte::parse_str(expanded, true).map_err(de::Error::custom)?;
let byte_size_u128 = byte_size.as_u128();
T::try_from(byte_size_u128.try_into().map_err(de::Error::custom)?)
.map_err(de::Error::custom)
let expanded = shellexpand::env(v).map_err(de::Error::custom)?;
let s = expanded.as_ref().trim();
let byte_size = Byte::parse_str(s, true).map_err(de::Error::custom)?;
let bytes = byte_size.as_u128();
T::try_from(bytes).map_err(de::Error::custom)
}
}

deserializer.deserialize_any(USizeVisitor::<T>(PhantomData::<T> {}))
deserializer.deserialize_any(DataSizeVisitor::<T>(PhantomData))
}

pub fn convert_duration_with_shellexpand<'de, D, T, E>(deserializer: D) -> Result<T, D::Error>
pub fn convert_duration_with_shellexpand<'de, D, T>(deserializer: D) -> Result<T, D::Error>
where
D: Deserializer<'de>,
E: fmt::Display,
T: TryFrom<i64> + FromStr<Err = E>,
<T as TryFrom<i64>>::Error: fmt::Display,
T: TryFrom<u64>,
<T as TryFrom<u64>>::Error: fmt::Display,
{
// define a visitor that deserializes
// `ActualData` encoded as json within a string
struct USizeVisitor<T: TryFrom<i64>>(PhantomData<T>);
struct DurationVisitor<T: TryFrom<u64>>(PhantomData<T>);

impl<'de, T, FromStrErr> de::Visitor<'de> for USizeVisitor<T>
impl<'de, T> Visitor<'de> for DurationVisitor<T>
where
FromStrErr: fmt::Display,
T: TryFrom<i64> + FromStr<Err = FromStrErr>,
<T as TryFrom<i64>>::Error: fmt::Display,
T: TryFrom<u64>,
<T as TryFrom<u64>>::Error: fmt::Display,
{
type Value = T;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a string containing json data")
formatter.write_str("either a number of seconds as an integer, or a string with a duration format (e.g., \"1h2m3s\", \"30m\", \"1d\")")
}

fn visit_u64<E: de::Error>(self, v: u64) -> Result<Self::Value, E> {
T::try_from(v).map_err(de::Error::custom)
}

fn visit_i64<E: de::Error>(self, v: i64) -> Result<Self::Value, E> {
v.try_into().map_err(de::Error::custom)
if v < 0 {
return Err(de::Error::custom("Negative duration is not allowed"));
}
T::try_from(v as u64).map_err(de::Error::custom)
}

fn visit_u128<E: de::Error>(self, v: u128) -> Result<Self::Value, E> {
let v_u64 = u64::try_from(v).map_err(de::Error::custom)?;
T::try_from(v_u64).map_err(de::Error::custom)
}

fn visit_i128<E: de::Error>(self, v: i128) -> Result<Self::Value, E> {
if v < 0 {
return Err(de::Error::custom("Negative duration is not allowed"));
}
let v_u64 = u64::try_from(v).map_err(de::Error::custom)?;
T::try_from(v_u64).map_err(de::Error::custom)
}

fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
let expanded = (*shellexpand::env(v).map_err(de::Error::custom)?).to_string();
let duration = parse_duration(&expanded).map_err(de::Error::custom)?;
let duration_secs = duration.as_secs();
T::try_from(duration_secs.try_into().map_err(de::Error::custom)?)
.map_err(de::Error::custom)
let expanded = shellexpand::env(v).map_err(de::Error::custom)?;
let expanded = expanded.as_ref().trim();
let duration = parse_duration(expanded).map_err(de::Error::custom)?;
let secs = duration.as_secs();
T::try_from(secs).map_err(de::Error::custom)
}
}

deserializer.deserialize_any(USizeVisitor::<T>(PhantomData::<T> {}))
deserializer.deserialize_any(DurationVisitor::<T>(PhantomData))
}
Loading