Skip to content

Commit

Permalink
feat(s3s-policy/pattern): add PatternSet
Browse files Browse the repository at this point in the history
  • Loading branch information
Nugine committed Oct 10, 2024
1 parent e4e2eb0 commit fa09f77
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 0 deletions.
1 change: 1 addition & 0 deletions crates/s3s-policy/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ license.workspace = true
indexmap = { version = "2.5.0", features = ["serde"] }
serde = { version = "1.0.210", features = ["derive"] }
serde_json = "1.0.128"
thiserror = "1.0.64"
4 changes: 4 additions & 0 deletions crates/s3s-policy/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@
#![warn(
clippy::dbg_macro, //
)]
#![allow(
clippy::module_name_repetitions, //
)]

pub mod model;
pub mod pattern;

#[cfg(test)]
mod tests;
124 changes: 124 additions & 0 deletions crates/s3s-policy/src/pattern.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
pub struct PatternSet {
// TODO: rewrite the naive implementation with something like Aho-Corasick
patterns: Vec<Pattern>,
}

#[derive(Debug, thiserror::Error)]
pub enum PatternError {
#[error("Invalid pattern")]
InvalidPattern,
}

#[derive(Debug)]
struct Pattern {
bytes: Vec<u8>,
}

impl PatternSet {
/// Create a new matcher from a list of patterns.
///
/// Patterns can contain
/// + `*` to match any sequence of characters (including empty sequence)
/// + `?` to match any single character
/// + any other character to match itself
///
/// # Errors
/// Returns an error if any pattern is invalid.
pub fn new<'a>(patterns: impl IntoIterator<Item = &'a str>) -> Result<PatternSet, PatternError> {
let patterns = patterns.into_iter().map(Self::parse_pattern).collect::<Result<_, _>>()?;
Ok(PatternSet { patterns })
}

fn parse_pattern(pattern: &str) -> Result<Pattern, PatternError> {
if pattern.is_empty() {
return Err(PatternError::InvalidPattern);
}
Ok(Pattern {
bytes: pattern.as_bytes().to_owned(),
})
}

/// Check if the input matches any of the patterns.
#[must_use]
pub fn is_match(&self, input: &str) -> bool {
for pattern in &self.patterns {
if Self::match_pattern(&pattern.bytes, input.as_bytes()) {
return true;
}
}
false
}

/// <https://leetcode.com/problems/wildcard-matching/>
fn match_pattern(pattern: &[u8], input: &[u8]) -> bool {
let mut p_idx = 0;
let mut s_idx = 0;

let mut p_back = usize::MAX - 1;
let mut s_back = usize::MAX - 1;

loop {
if p_idx < pattern.len() {
let p = pattern[p_idx];
if p == b'*' {
p_idx += 1;
p_back = p_idx;
s_back = s_idx;
continue;
}

if s_idx < input.len() {
let c = input[s_idx];
if p == c || p == b'?' {
p_idx += 1;
s_idx += 1;
continue;
}
}
} else if s_idx == input.len() {
return true;
}

if p_back == pattern.len() {
return true;
}

if s_back + 1 < input.len() {
s_back += 1;
p_idx = p_back;
s_idx = s_back;
continue;
}

return false;
}
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_match() {
let cases = &[
("*", "", true),
("**", "", true),
("***", "abc", true),
("a", "aa", false),
("***a", "aaaa", true),
("*abc???def", "abcdefabc123def", true),
("a*c?b", "acdcb", false),
("*a*b*c*", "abc", true),
("a*b*c*", "abc", true),
("*a*b*c", "abc", true),
("a*b*c", "abc", true),
];

for &(pattern, input, expected) in cases {
let pattern = PatternSet::parse_pattern(pattern).unwrap();
let ans = PatternSet::match_pattern(&pattern.bytes, input.as_bytes());
assert_eq!(ans, expected, "pattern: {pattern:?}, input: {input:?}");
}
}
}

0 comments on commit fa09f77

Please sign in to comment.