diff --git a/cli/src/template_builder.rs b/cli/src/template_builder.rs index f453ce09e2b..d7fdd085636 100644 --- a/cli/src/template_builder.rs +++ b/cli/src/template_builder.rs @@ -488,6 +488,7 @@ fn build_keyword<'a, L: TemplateLanguage<'a> + ?Sized>( name, name_span, args: vec![], + keyword_args: vec![], args_span: name_span.end_pos().span(&name_span.end_pos()), }; language diff --git a/cli/src/template_parser.rs b/cli/src/template_parser.rs index 4b5cfc8fd9b..08405b20c0f 100644 --- a/cli/src/template_parser.rs +++ b/cli/src/template_parser.rs @@ -418,6 +418,7 @@ fn parse_function_call_node(pair: Pair) -> TemplateParseResult( #[cfg(test)] mod tests { use assert_matches::assert_matches; + use jj_lib::dsl_util::KeywordArgument; use super::*; @@ -749,6 +751,15 @@ mod tests { name: function.name, name_span: empty_span(), args: normalize_list(function.args), + keyword_args: function + .keyword_args + .into_iter() + .map(|arg| KeywordArgument { + name: arg.name, + name_span: empty_span(), + value: normalize_tree(arg.value), + }) + .collect(), args_span: empty_span(), } } diff --git a/lib/src/dsl_util.rs b/lib/src/dsl_util.rs index 8e9f783f718..f67ea8f551b 100644 --- a/lib/src/dsl_util.rs +++ b/lib/src/dsl_util.rs @@ -15,7 +15,7 @@ //! Domain-specific language helpers. use std::collections::HashMap; -use std::fmt; +use std::{array, fmt}; use itertools::Itertools as _; use pest::iterators::Pairs; @@ -46,11 +46,23 @@ pub struct FunctionCallNode<'i, T> { pub name_span: pest::Span<'i>, /// List of positional arguments. pub args: Vec>, - // TODO: revset supports keyword args + /// List of keyword arguments. + pub keyword_args: Vec>, /// Span of the arguments list. pub args_span: pest::Span<'i>, } +/// Keyword argument pair in AST. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct KeywordArgument<'i, T> { + /// Parameter name. + pub name: &'i str, + /// Span of the parameter name. + pub name_span: pest::Span<'i>, + /// Value expression. + pub value: ExpressionNode<'i, T>, +} + impl<'i, T> FunctionCallNode<'i, T> { /// Ensures that no arguments passed. pub fn expect_no_arguments(&self) -> Result<(), InvalidArguments<'i>> { @@ -71,6 +83,7 @@ impl<'i, T> FunctionCallNode<'i, T> { pub fn expect_some_arguments( &self, ) -> Result<(&[ExpressionNode<'i, T>; N], &[ExpressionNode<'i, T>]), InvalidArguments<'i>> { + self.ensure_no_keyword_arguments()?; if self.args.len() >= N { let (required, rest) = self.args.split_at(N); Ok((required.try_into().unwrap(), rest)) @@ -90,6 +103,7 @@ impl<'i, T> FunctionCallNode<'i, T> { ), InvalidArguments<'i>, > { + self.ensure_no_keyword_arguments()?; let count_range = N..=(N + M); if count_range.contains(&self.args.len()) { let (required, rest) = self.args.split_at(N); @@ -105,11 +119,94 @@ impl<'i, T> FunctionCallNode<'i, T> { } } - fn invalid_arguments(&self, message: String) -> InvalidArguments<'i> { + /// Extracts N required arguments and M optional arguments. Some of them can + /// be specified as keyword arguments. + /// + /// `names` is a list of parameter names. Unnamed positional arguments + /// should be padded with `""`. + #[allow(clippy::type_complexity)] + pub fn expect_named_arguments( + &self, + names: &[&str], + ) -> Result< + ( + [&ExpressionNode<'i, T>; N], + [Option<&ExpressionNode<'i, T>>; M], + ), + InvalidArguments<'i>, + > { + if self.keyword_args.is_empty() { + let (required, optional) = self.expect_arguments::()?; + // TODO: use .each_ref() if MSRV is bumped to 1.77.0 + Ok((array::from_fn(|i| &required[i]), optional)) + } else { + let (required, optional) = self.expect_named_arguments_vec(names, N, N + M)?; + Ok(( + required.try_into().ok().unwrap(), + optional.try_into().ok().unwrap(), + )) + } + } + + #[allow(clippy::type_complexity)] + fn expect_named_arguments_vec( + &self, + names: &[&str], + min: usize, + max: usize, + ) -> Result< + ( + Vec<&ExpressionNode<'i, T>>, + Vec>>, + ), + InvalidArguments<'i>, + > { + assert!(names.len() <= max); + + if self.args.len() > max { + return Err(self.invalid_arguments_count(min, Some(max))); + } + let mut extracted = Vec::with_capacity(max); + extracted.extend(self.args.iter().map(Some)); + extracted.resize(max, None); + + for arg in &self.keyword_args { + let name = arg.name; + let span = arg.name_span.start_pos().span(&arg.value.span.end_pos()); + let pos = names.iter().position(|&n| n == name).ok_or_else(|| { + self.invalid_arguments(format!(r#"Unexpected keyword argument "{name}""#), span) + })?; + if extracted[pos].is_some() { + return Err(self.invalid_arguments( + format!(r#"Got multiple values for keyword "{name}""#), + span, + )); + } + extracted[pos] = Some(&arg.value); + } + + let optional = extracted.split_off(min); + let required = extracted.into_iter().flatten().collect_vec(); + if required.len() != min { + return Err(self.invalid_arguments_count(min, Some(max))); + } + Ok((required, optional)) + } + + fn ensure_no_keyword_arguments(&self) -> Result<(), InvalidArguments<'i>> { + if let (Some(first), Some(last)) = (self.keyword_args.first(), self.keyword_args.last()) { + let span = first.name_span.start_pos().span(&last.value.span.end_pos()); + Err(self.invalid_arguments("Unexpected keyword arguments".to_owned(), span)) + } else { + Ok(()) + } + } + + fn invalid_arguments(&self, message: String, span: pest::Span<'i>) -> InvalidArguments<'i> { InvalidArguments { name: self.name, message, - span: self.args_span, + span, } } @@ -119,7 +216,7 @@ impl<'i, T> FunctionCallNode<'i, T> { (min, Some(max)) => format!("Expected {min} to {max} arguments"), (min, None) => format!("Expected at least {min} arguments"), }; - self.invalid_arguments(message) + self.invalid_arguments(message, self.args_span) } } @@ -200,6 +297,17 @@ where name: function.name, name_span: function.name_span, args: fold_expression_nodes(folder, function.args)?, + keyword_args: function + .keyword_args + .into_iter() + .map(|arg| { + Ok(KeywordArgument { + name: arg.name, + name_span: arg.name_span, + value: folder.fold_expression(arg.value)?, + }) + }) + .try_collect()?, args_span: function.args_span, }) } @@ -451,7 +559,12 @@ where span: pest::Span<'i>, ) -> Result { if let Some((id, params, defn)) = self.aliases_map.get_function(function.name) { + // TODO: add support for keyword arguments and arity-based + // overloading (#2966)? let arity = params.len(); + function + .ensure_no_keyword_arguments() + .map_err(E::invalid_arguments)?; if function.args.len() != arity { return Err(E::invalid_arguments( function.invalid_arguments_count(arity, Some(arity)), @@ -503,3 +616,110 @@ where .dedup() .collect() } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_expect_arguments() { + fn empty_span() -> pest::Span<'static> { + pest::Span::new("", 0, 0).unwrap() + } + + fn function( + name: &'static str, + args: impl Into>>, + keyword_args: impl Into>>, + ) -> FunctionCallNode<'static, u32> { + FunctionCallNode { + name, + name_span: empty_span(), + args: args.into(), + keyword_args: keyword_args.into(), + args_span: empty_span(), + } + } + + fn value(v: u32) -> ExpressionNode<'static, u32> { + ExpressionNode::new(v, empty_span()) + } + + fn keyword(name: &'static str, v: u32) -> KeywordArgument<'static, u32> { + KeywordArgument { + name, + name_span: empty_span(), + value: value(v), + } + } + + let f = function("foo", [], []); + assert!(f.expect_no_arguments().is_ok()); + assert!(f.expect_some_arguments::<0>().is_ok()); + assert!(f.expect_arguments::<0, 0>().is_ok()); + assert!(f.expect_named_arguments::<0, 0>(&[]).is_ok()); + + let f = function("foo", [value(0)], []); + assert!(f.expect_no_arguments().is_err()); + assert_eq!( + f.expect_some_arguments::<0>().unwrap(), + (&[], [value(0)].as_slice()) + ); + assert_eq!( + f.expect_some_arguments::<1>().unwrap(), + (&[value(0)], [].as_slice()) + ); + assert!(f.expect_arguments::<0, 0>().is_err()); + assert_eq!( + f.expect_arguments::<0, 1>().unwrap(), + (&[], [Some(&value(0))]) + ); + assert_eq!(f.expect_arguments::<1, 1>().unwrap(), (&[value(0)], [None])); + assert!(f.expect_named_arguments::<0, 0>(&[]).is_err()); + assert_eq!( + f.expect_named_arguments::<0, 1>(&["a"]).unwrap(), + ([], [Some(&value(0))]) + ); + assert_eq!( + f.expect_named_arguments::<1, 0>(&["a"]).unwrap(), + ([&value(0)], []) + ); + + let f = function("foo", [], [keyword("a", 0)]); + assert!(f.expect_no_arguments().is_err()); + assert!(f.expect_some_arguments::<1>().is_err()); + assert!(f.expect_arguments::<0, 1>().is_err()); + assert!(f.expect_arguments::<1, 0>().is_err()); + assert!(f.expect_named_arguments::<0, 0>(&[]).is_err()); + assert!(f.expect_named_arguments::<0, 1>(&[]).is_err()); + assert!(f.expect_named_arguments::<1, 0>(&[]).is_err()); + assert_eq!( + f.expect_named_arguments::<1, 0>(&["a"]).unwrap(), + ([&value(0)], []) + ); + assert_eq!( + f.expect_named_arguments::<1, 1>(&["a", "b"]).unwrap(), + ([&value(0)], [None]) + ); + assert!(f.expect_named_arguments::<1, 1>(&["b", "a"]).is_err()); + + let f = function("foo", [value(0)], [keyword("a", 1), keyword("b", 2)]); + assert!(f.expect_named_arguments::<0, 0>(&[]).is_err()); + assert!(f.expect_named_arguments::<1, 1>(&["a", "b"]).is_err()); + assert_eq!( + f.expect_named_arguments::<1, 2>(&["c", "a", "b"]).unwrap(), + ([&value(0)], [Some(&value(1)), Some(&value(2))]) + ); + assert_eq!( + f.expect_named_arguments::<2, 1>(&["c", "b", "a"]).unwrap(), + ([&value(0), &value(2)], [Some(&value(1))]) + ); + assert_eq!( + f.expect_named_arguments::<0, 3>(&["c", "b", "a"]).unwrap(), + ([], [Some(&value(0)), Some(&value(2)), Some(&value(1))]) + ); + + let f = function("foo", [], [keyword("a", 0), keyword("a", 1)]); + assert!(f.expect_named_arguments::<1, 1>(&["", "a"]).is_err()); + } +} diff --git a/lib/src/fileset_parser.rs b/lib/src/fileset_parser.rs index 6650df480e7..4ffe51764a2 100644 --- a/lib/src/fileset_parser.rs +++ b/lib/src/fileset_parser.rs @@ -205,6 +205,7 @@ fn parse_function_call_node(pair: Pair) -> FilesetParseResult Result { parse_program(text) @@ -361,6 +363,15 @@ mod tests { name: function.name, name_span: empty_span(), args: normalize_list(function.args), + keyword_args: function + .keyword_args + .into_iter() + .map(|arg| KeywordArgument { + name: arg.name, + name_span: empty_span(), + value: normalize_tree(arg.value), + }) + .collect(), args_span: empty_span(), } }