From 7a75c696a6b634364b51de6ce0324e7522fe6274 Mon Sep 17 00:00:00 2001 From: oeb25 Date: Wed, 31 Jul 2024 14:23:24 +0200 Subject: [PATCH] Add user defined functions and sorts and refactor term sorts --- Cargo.lock | 5 +- lowlevel/src/ast.rs | 6 +- lowlevel/src/lib.rs | 91 ++++++++-- smtlib/Cargo.toml | 1 + smtlib/examples/queens.rs | 20 +-- smtlib/examples/queens_bv.rs | 18 +- smtlib/examples/queens_bv2.rs | 18 +- smtlib/examples/simplify.rs | 4 +- smtlib/src/funs.rs | 51 ++++++ smtlib/src/lib.rs | 30 +++- smtlib/src/solver.rs | 159 ++++++++++++++++-- smtlib/src/sorts.rs | 120 +++++++++++++ smtlib/src/terms.rs | 148 ++++++++++++---- smtlib/src/theories/core.rs | 41 +++-- smtlib/src/theories/fixed_size_bit_vectors.rs | 21 +-- smtlib/src/theories/ints.rs | 27 ++- smtlib/src/theories/reals.rs | 19 ++- xtask/src/spec.toml | 2 +- 18 files changed, 642 insertions(+), 139 deletions(-) create mode 100644 smtlib/src/funs.rs create mode 100644 smtlib/src/sorts.rs diff --git a/Cargo.lock b/Cargo.lock index b1756b9..3c31ea5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -382,9 +382,9 @@ checksum = "ce23b50ad8242c51a442f3ff322d56b02f08852c77e4c0b4d3fd684abc89c683" [[package]] name = "indexmap" -version = "2.1.0" +version = "2.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d530e1a18b1cb4c484e6e34556a0d948706958449fca0cab753d649f2bce3d1f" +checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" dependencies = [ "equivalent", "hashbrown", @@ -795,6 +795,7 @@ name = "smtlib" version = "0.2.0" dependencies = [ "futures", + "indexmap", "insta", "itertools", "miette", diff --git a/lowlevel/src/ast.rs b/lowlevel/src/ast.rs index 402627b..86b353c 100644 --- a/lowlevel/src/ast.rs +++ b/lowlevel/src/ast.rs @@ -1246,10 +1246,10 @@ impl SmtlibParse for DatatypeDec { Err(p.stuck("datatype_dec")) } } -/// `( (*) )` +/// `( (*) )` #[derive(Debug, Clone, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct FunctionDec(pub Symbol, pub Vec, pub Sort); +pub struct FunctionDec(pub Symbol, pub Vec, pub Sort); impl std::fmt::Display for FunctionDec { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { write!(f, "({} ({}) {})", self.0, self.1.iter().format(" "), self.2) @@ -1268,7 +1268,7 @@ impl SmtlibParse for FunctionDec { p.expect(Token::LParen)?; let m0 = ::parse(p)?; p.expect(Token::LParen)?; - let m1 = p.any::()?; + let m1 = p.any::()?; p.expect(Token::RParen)?; let m2 = ::parse(p)?; p.expect(Token::RParen)?; diff --git a/lowlevel/src/lib.rs b/lowlevel/src/lib.rs index 6bc966e..302f429 100644 --- a/lowlevel/src/lib.rs +++ b/lowlevel/src/lib.rs @@ -34,9 +34,44 @@ pub enum Error { IO(#[from] std::io::Error), } -#[derive(Debug)] +pub trait Logger: 'static { + fn exec(&self, cmd: &ast::Command); + fn response(&self, cmd: &ast::Command, res: &str); +} + +impl Logger for (F, G) +where + F: Fn(&ast::Command) + 'static, + G: Fn(&ast::Command, &str) + 'static, +{ + fn exec(&self, cmd: &ast::Command) { + (self.0)(cmd) + } + + fn response(&self, cmd: &ast::Command, res: &str) { + (self.1)(cmd, res) + } +} + pub struct Driver { backend: B, + logger: Option>, +} + +impl std::fmt::Debug for Driver { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Driver") + .field("backend", &self.backend) + .field( + "logger", + if self.logger.is_some() { + &"Some(...)" + } else { + &"None" + }, + ) + .finish() + } } impl Driver @@ -44,15 +79,26 @@ where B: Backend, { pub fn new(backend: B) -> Result { - let mut driver = Self { backend }; + let mut driver = Self { + backend, + logger: None, + }; driver.exec(&Command::SetOption(ast::Option::PrintSuccess(true)))?; Ok(driver) } + pub fn set_logger(&mut self, logger: impl Logger) { + self.logger = Some(Box::new(logger)) + } pub fn exec(&mut self, cmd: &Command) -> Result { - // println!("> {cmd}"); + if let Some(logger) = &self.logger { + logger.exec(cmd); + } let res = self.backend.exec(cmd)?; + if let Some(logger) = &self.logger { + logger.response(cmd, &res); + } let res = if let Some(res) = cmd.parse_response(&res)? { GeneralResponse::SpecificSuccessResponse(res) } else { @@ -67,12 +113,28 @@ pub mod tokio { use crate::{ ast::{self, Command, GeneralResponse}, backend::tokio::TokioBackend, - Error, + Error, Logger, }; - #[derive(Debug)] pub struct TokioDriver { backend: B, + logger: Option>, + } + + impl std::fmt::Debug for TokioDriver { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("TokioDriver") + .field("backend", &self.backend) + .field( + "logger", + if self.logger.is_some() { + &"Some(...)" + } else { + &"None" + }, + ) + .finish() + } } impl TokioDriver @@ -80,7 +142,10 @@ pub mod tokio { B: TokioBackend, { pub async fn new(backend: B) -> Result { - let mut driver = Self { backend }; + let mut driver = Self { + backend, + logger: None, + }; driver .exec(&Command::SetOption(ast::Option::PrintSuccess(true))) @@ -88,9 +153,17 @@ pub mod tokio { Ok(driver) } + pub fn set_logger(&mut self, logger: impl Logger) { + self.logger = Some(Box::new(logger)) + } pub async fn exec(&mut self, cmd: &Command) -> Result { - // println!("> {cmd}"); + if let Some(logger) = &self.logger { + logger.exec(cmd); + } let res = self.backend.exec(cmd).await?; + if let Some(logger) = &self.logger { + logger.response(cmd, &res); + } let res = if let Some(res) = cmd.parse_response(&res)? { GeneralResponse::SpecificSuccessResponse(res) } else { @@ -107,9 +180,7 @@ impl Term { match self { Term::SpecConstant(_) => HashSet::new(), Term::Identifier(q) => std::iter::once(q).collect(), - Term::Application(q, args) => std::iter::once(q) - .chain(args.iter().flat_map(|arg| arg.all_consts())) - .collect(), + Term::Application(_, args) => args.iter().flat_map(|arg| arg.all_consts()).collect(), Term::Let(_, _) => todo!(), // TODO Term::Forall(_, _) => HashSet::new(), diff --git a/smtlib/Cargo.toml b/smtlib/Cargo.toml index b1d8f15..d25ea62 100644 --- a/smtlib/Cargo.toml +++ b/smtlib/Cargo.toml @@ -23,6 +23,7 @@ miette.workspace = true smtlib-lowlevel.workspace = true serde = { workspace = true, optional = true } thiserror.workspace = true +indexmap = "2.2.6" [dev-dependencies] futures = "0.3.29" diff --git a/smtlib/examples/queens.rs b/smtlib/examples/queens.rs index 610ff0e..9559033 100644 --- a/smtlib/examples/queens.rs +++ b/smtlib/examples/queens.rs @@ -6,22 +6,22 @@ use smtlib::{ and, backend::{cvc5_binary::Cvc5Binary, z3_binary::Z3Binary, Backend}, distinct, or, - terms::Sort, + prelude::*, Int, Logic, SatResultWithModel, Solver, }; fn queens(backend: B) -> miette::Result<()> { - let x0 = Int::from_name("x0"); - let x1 = Int::from_name("x1"); - let x2 = Int::from_name("x2"); - let x3 = Int::from_name("x3"); - let x4 = Int::from_name("x4"); - let x5 = Int::from_name("x5"); - let x6 = Int::from_name("x6"); - let x7 = Int::from_name("x7"); + let x0 = Int::new_const("x0"); + let x1 = Int::new_const("x1"); + let x2 = Int::new_const("x2"); + let x3 = Int::new_const("x3"); + let x4 = Int::new_const("x4"); + let x5 = Int::new_const("x5"); + let x6 = Int::new_const("x6"); + let x7 = Int::new_const("x7"); let xs = [x0, x1, x2, x3, x4, x5, x6, x7]; - let n = Int::from_name("N"); + let n = Int::new_const("N"); let mut solver = Solver::new(backend)?; diff --git a/smtlib/examples/queens_bv.rs b/smtlib/examples/queens_bv.rs index ea6f1cc..6198856 100644 --- a/smtlib/examples/queens_bv.rs +++ b/smtlib/examples/queens_bv.rs @@ -6,19 +6,19 @@ use smtlib::{ and, backend::{cvc5_binary::Cvc5Binary, z3_binary::Z3Binary, Backend}, distinct, or, - terms::Sort, + prelude::*, BitVec, SatResultWithModel, Solver, }; fn queens(backend: B) -> miette::Result<()> { - let x0 = BitVec::<4>::from_name("x0"); - let x1 = BitVec::<4>::from_name("x1"); - let x2 = BitVec::<4>::from_name("x2"); - let x3 = BitVec::<4>::from_name("x3"); - let x4 = BitVec::<4>::from_name("x4"); - let x5 = BitVec::<4>::from_name("x5"); - let x6 = BitVec::<4>::from_name("x6"); - let x7 = BitVec::<4>::from_name("x7"); + let x0 = BitVec::<4>::new_const("x0"); + let x1 = BitVec::<4>::new_const("x1"); + let x2 = BitVec::<4>::new_const("x2"); + let x3 = BitVec::<4>::new_const("x3"); + let x4 = BitVec::<4>::new_const("x4"); + let x5 = BitVec::<4>::new_const("x5"); + let x6 = BitVec::<4>::new_const("x6"); + let x7 = BitVec::<4>::new_const("x7"); let xs = [x0, x1, x2, x3, x4, x5, x6, x7]; let mut solver = Solver::new(backend)?; diff --git a/smtlib/examples/queens_bv2.rs b/smtlib/examples/queens_bv2.rs index dd16bb6..911256f 100644 --- a/smtlib/examples/queens_bv2.rs +++ b/smtlib/examples/queens_bv2.rs @@ -6,19 +6,19 @@ use smtlib::{ and, backend::{cvc5_binary::Cvc5Binary, z3_binary::Z3Binary, Backend}, distinct, or, - terms::Sort, + prelude::*, BitVec, Logic, SatResultWithModel, Solver, }; fn queens(backend: B) -> miette::Result<()> { - let x0 = BitVec::<8>::from_name("x0"); - let x1 = BitVec::<8>::from_name("x1"); - let x2 = BitVec::<8>::from_name("x2"); - let x3 = BitVec::<8>::from_name("x3"); - let x4 = BitVec::<8>::from_name("x4"); - let x5 = BitVec::<8>::from_name("x5"); - let x6 = BitVec::<8>::from_name("x6"); - let x7 = BitVec::<8>::from_name("x7"); + let x0 = BitVec::<8>::new_const("x0"); + let x1 = BitVec::<8>::new_const("x1"); + let x2 = BitVec::<8>::new_const("x2"); + let x3 = BitVec::<8>::new_const("x3"); + let x4 = BitVec::<8>::new_const("x4"); + let x5 = BitVec::<8>::new_const("x5"); + let x6 = BitVec::<8>::new_const("x6"); + let x7 = BitVec::<8>::new_const("x7"); let xs = [x0, x1, x2, x3, x4, x5, x6, x7]; let mut solver = Solver::new(backend)?; diff --git a/smtlib/examples/simplify.rs b/smtlib/examples/simplify.rs index f103c99..1fbfbba 100644 --- a/smtlib/examples/simplify.rs +++ b/smtlib/examples/simplify.rs @@ -1,5 +1,5 @@ use miette::IntoDiagnostic; -use smtlib::Sort; +use smtlib::prelude::*; #[derive(Debug, Clone)] enum Expr { @@ -98,7 +98,7 @@ fn expr_to_smt_bool(expr: &Expr) -> smtlib::Bool { fn expr_to_smt(expr: &Expr) -> smtlib::terms::Dynamic { match expr { Expr::Num(n) => smtlib::Int::from(*n).into(), - Expr::Var(v) => smtlib::Int::from_name(v).into(), + Expr::Var(v) => smtlib::Int::new_const(v).into(), Expr::Bool(b) => smtlib::Bool::from(*b).into(), Expr::Add(l, r) => (expr_to_smt_int(l) + expr_to_smt_int(r)).into(), Expr::Sub(l, r) => (expr_to_smt_int(l) - expr_to_smt_int(r)).into(), diff --git a/smtlib/src/funs.rs b/smtlib/src/funs.rs new file mode 100644 index 0000000..01e6997 --- /dev/null +++ b/smtlib/src/funs.rs @@ -0,0 +1,51 @@ +use smtlib_lowlevel::{ast, lexicon::Symbol}; + +use crate::{ + sorts::Sort, + terms::{qual_ident, Dynamic}, +}; + +#[derive(Debug)] +pub struct Fun { + pub name: String, + pub vars: Vec, + pub return_sort: Sort, +} + +impl Fun { + pub fn new(name: impl Into, vars: Vec, return_ty: Sort) -> Self { + Self { + name: name.into(), + vars, + return_sort: return_ty, + } + } + + pub fn call(&self, args: &[Dynamic]) -> Result { + if self.vars.len() != args.len() { + todo!() + } + for (expected, given) in self.vars.iter().zip(args) { + if expected != given.sort() { + todo!("expected {expected:?} given {:?}", given.sort()) + } + } + let term = if args.is_empty() { + ast::Term::Identifier(qual_ident(self.name.clone(), None)) + } else { + ast::Term::Application( + qual_ident(self.name.clone(), None), + args.iter().map(|arg| (*arg).into()).collect(), + ) + }; + Ok(Dynamic::from_term_sort(term, self.return_sort.clone())) + } + + pub fn ast(&self) -> ast::FunctionDec { + ast::FunctionDec( + Symbol(self.name.to_string()), + self.vars.iter().map(|sort| sort.ast()).collect(), + self.return_sort.ast(), + ) + } +} diff --git a/smtlib/src/lib.rs b/smtlib/src/lib.rs index e872126..f89ce56 100644 --- a/smtlib/src/lib.rs +++ b/smtlib/src/lib.rs @@ -8,17 +8,19 @@ use std::collections::HashMap; use itertools::Itertools; use smtlib_lowlevel::ast; use terms::Const; -pub use terms::Sort; +pub use terms::Sorted; pub use backend::Backend; pub use logics::Logic; -pub use smtlib_lowlevel::{self as lowlevel, backend}; +pub use smtlib_lowlevel::{self as lowlevel, backend, Logger}; #[cfg(feature = "tokio")] mod tokio_solver; #[rustfmt::skip] mod logics; +pub mod funs; mod solver; +pub mod sorts; pub mod terms; pub mod theories; @@ -27,6 +29,10 @@ pub use theories::{core::*, fixed_size_bit_vectors::*, ints::*, reals::*}; #[cfg(feature = "tokio")] pub use tokio_solver::TokioSolver; +pub mod prelude { + pub use crate::terms::{Sorted, StaticSorted}; +} + /// The satisfiability result produced by a solver #[derive(Debug)] pub enum SatResult { @@ -80,6 +86,7 @@ impl SatResultWithModel { /// An error that occurred during any stage of using `smtlib`. #[derive(Debug, thiserror::Error, miette::Diagnostic)] +#[non_exhaustive] pub enum Error { #[error(transparent)] #[diagnostic(transparent)] @@ -105,6 +112,11 @@ pub enum Error { /// The actual sat result actual: SatResult, }, + #[error("tried to cast a dynamic of sort {expected} to {actual}")] + DynamicCastSortMismatch { + expected: sorts::Sort, + actual: sorts::Sort, + }, } /// A [`Model`] contains the values of all named constants returned through @@ -168,7 +180,7 @@ impl Model { /// # Ok(()) /// # } /// ``` - pub fn eval(&self, x: Const) -> Option + pub fn eval(&self, x: Const) -> Option where T::Inner: From, { @@ -178,14 +190,16 @@ impl Model { #[cfg(test)] mod tests { - use crate::terms::{forall, Sort}; + use terms::StaticSorted; + + use crate::terms::{forall, Sorted}; use super::*; #[test] fn int_math() { - let x = Int::from_name("x"); - let y = Int::from_name("hello"); + let x = Int::new_const("x"); + let y = Int::new_const("hello"); // let x_named = x.labeled(); let mut z = 12 + y * 4; z += 3; @@ -195,8 +209,8 @@ mod tests { #[test] fn quantifiers() { - let x = Int::from_name("x"); - let y = Int::from_name("y"); + let x = Int::new_const("x"); + let y = Int::new_const("y"); let res = forall((x, y), (x + 2)._eq(y)); println!("{}", ast::Term::from(res)); diff --git a/smtlib/src/solver.rs b/smtlib/src/solver.rs index 4177432..48a433c 100644 --- a/smtlib/src/solver.rs +++ b/smtlib/src/solver.rs @@ -1,13 +1,16 @@ -use std::collections::{hash_map::Entry, HashMap}; - +use indexmap::{map::Entry, IndexMap, IndexSet}; use smtlib_lowlevel::{ ast::{self, Identifier, QualIdentifier}, backend, - lexicon::Symbol, - Driver, + lexicon::{Numeral, Symbol}, + Driver, Logger, }; -use crate::{terms::Dynamic, Bool, Error, Logic, Model, SatResult, SatResultWithModel}; +use crate::{ + funs, sorts, + terms::{qual_ident, Dynamic}, + Bool, Error, Logic, Model, SatResult, SatResultWithModel, +}; /// The [`Solver`] type is the primary entrypoint to interaction with the /// solver. Checking for validity of a set of assertions requires: @@ -38,7 +41,15 @@ use crate::{terms::Dynamic, Bool, Error, Logic, Model, SatResult, SatResultWithM #[derive(Debug)] pub struct Solver { driver: Driver, - decls: HashMap, + push_pop_stack: Vec, + decls: IndexMap, + declared_sorts: IndexSet, +} + +#[derive(Debug)] +struct StackSizes { + decls: usize, + declared_sorts: usize, } impl Solver @@ -52,9 +63,25 @@ where pub fn new(backend: B) -> Result { Ok(Self { driver: Driver::new(backend)?, + push_pop_stack: Vec::new(), decls: Default::default(), + declared_sorts: Default::default(), }) } + pub fn set_logger(&mut self, logger: impl Logger) { + self.driver.set_logger(logger) + } + pub fn set_timeout(&mut self, ms: usize) -> Result<(), Error> { + let cmd = ast::Command::SetOption(ast::Option::Attribute(ast::Attribute::WithValue( + smtlib_lowlevel::lexicon::Keyword(":timeout".to_string()), + ast::AttributeValue::SpecConstant(ast::SpecConstant::Numeral(Numeral(ms.to_string()))), + ))); + match self.driver.exec(&cmd)? { + ast::GeneralResponse::Success => Ok(()), + ast::GeneralResponse::Error(e) => Err(Error::Smt(e, cmd.to_string())), + _ => todo!(), + } + } /// Explicitly sets the logic for the solver. For some backends this is not /// required, as they will infer what ever logic fits the current program. /// @@ -132,6 +159,27 @@ where res => todo!("{res:?}"), } } + pub fn declare_fun(&mut self, fun: &funs::Fun) -> Result<(), Error> { + for var in &fun.vars { + self.declare_sort(&var.ast())?; + } + self.declare_sort(&fun.return_sort.ast())?; + + if fun.vars.is_empty() { + return self.declare_const(&qual_ident(fun.name.clone(), Some(fun.return_sort.ast()))); + } + + let cmd = ast::Command::DeclareFun( + Symbol(fun.name.clone()), + fun.vars.iter().map(|s| s.ast()).collect(), + fun.return_sort.ast(), + ); + match self.driver.exec(&cmd)? { + ast::GeneralResponse::Success => Ok(()), + ast::GeneralResponse::Error(e) => Err(Error::Smt(e, cmd.to_string())), + _ => todo!(), + } + } /// Simplifies the given term pub fn simplify(&mut self, t: Dynamic) -> Result { self.declare_all_consts(&t.into())?; @@ -146,11 +194,58 @@ where } } + pub fn scope( + &mut self, + f: impl FnOnce(&mut Solver) -> Result, + ) -> Result { + self.push(1)?; + let res = f(self)?; + self.pop(1)?; + Ok(res) + } + + fn push(&mut self, levels: usize) -> Result<(), Error> { + self.push_pop_stack.push(StackSizes { + decls: self.decls.len(), + declared_sorts: self.declared_sorts.len(), + }); + + let cmd = ast::Command::Push(Numeral(levels.to_string())); + Ok(match self.driver.exec(&cmd)? { + ast::GeneralResponse::Success => {} + ast::GeneralResponse::Error(e) => return Err(Error::Smt(e, cmd.to_string())), + _ => todo!(), + }) + } + + fn pop(&mut self, levels: usize) -> Result<(), Error> { + if let Some(sizes) = self.push_pop_stack.pop() { + self.decls.truncate(sizes.decls); + self.declared_sorts.truncate(sizes.declared_sorts); + } + + let cmd = ast::Command::Pop(Numeral(levels.to_string())); + Ok(match self.driver.exec(&cmd)? { + ast::GeneralResponse::Success => {} + ast::GeneralResponse::Error(e) => return Err(Error::Smt(e, cmd.to_string())), + _ => todo!(), + }) + } + fn declare_all_consts(&mut self, t: &ast::Term) -> Result<(), Error> { for q in t.all_consts() { - match q { - QualIdentifier::Identifier(_) => {} - QualIdentifier::Sorted(i, s) => match self.decls.entry(i.clone()) { + self.declare_const(q)?; + } + Ok(()) + } + + fn declare_const(&mut self, q: &QualIdentifier) -> Result<(), Error> { + Ok(match q { + QualIdentifier::Identifier(_) => {} + QualIdentifier::Sorted(i, s) => { + self.declare_sort(s)?; + + match self.decls.entry(i.clone()) { Entry::Occupied(stored) => assert_eq!(s, stored.get()), Entry::Vacant(v) => { v.insert(s.clone()); @@ -162,9 +257,51 @@ where Identifier::Indexed(_, _) => todo!(), } } - }, + } } + }) + } + + fn declare_sort(&mut self, s: &ast::Sort) -> Result<(), Error> { + if self.declared_sorts.contains(s) { + return Ok(()); + } + self.declared_sorts.insert(s.clone()); + + let cmd = match s { + ast::Sort::Sort(ident) => { + let sym = match ident { + Identifier::Simple(sym) => sym, + Identifier::Indexed(_, _) => { + // TODO: is it correct that only sorts from theores can + // be indexed, and thus does not need to be declared? + return Ok(()); + } + }; + if sorts::is_built_in_sort(&sym.0) { + return Ok(()); + } + ast::Command::DeclareSort(sym.clone(), Numeral("0".to_string())) + } + ast::Sort::Parametric(ident, params) => { + let sym = match ident { + Identifier::Simple(sym) => sym, + Identifier::Indexed(_, _) => { + // TODO: is it correct that only sorts from theores can + // be indexed, and thus does not need to be declared? + return Ok(()); + } + }; + if sorts::is_built_in_sort(&sym.0) { + return Ok(()); + } + ast::Command::DeclareSort(sym.clone(), Numeral(params.len().to_string())) + } + }; + match self.driver.exec(&cmd)? { + ast::GeneralResponse::Success => Ok(()), + ast::GeneralResponse::Error(e) => return Err(Error::Smt(e, cmd.to_string())), + _ => todo!(), } - Ok(()) } } diff --git a/smtlib/src/sorts.rs b/smtlib/src/sorts.rs new file mode 100644 index 0000000..1d16b46 --- /dev/null +++ b/smtlib/src/sorts.rs @@ -0,0 +1,120 @@ +use smtlib_lowlevel::{ + ast::{self, Identifier}, + lexicon::{Numeral, Symbol}, +}; + +use crate::terms::{self, qual_ident}; + +pub struct SortTemplate { + pub name: String, + pub index: Vec, + pub arity: usize, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Sort { + pub name: String, + pub index: Vec, + pub parameters: Vec, +} + +impl std::fmt::Display for Sort { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.ast().fmt(f) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Index { + Numeral(usize), + Symbol(String), +} + +impl SortTemplate { + pub fn instantiate(&self, parameters: Vec) -> Result { + if self.arity != parameters.len() { + return Err(todo!()); + } + + Ok(Sort { + name: self.name.clone(), + index: self.index.clone(), + parameters, + }) + } +} + +impl Index { + fn ast(&self) -> ast::Index { + match self { + Index::Numeral(n) => ast::Index::Numeral(Numeral(n.to_string())), + Index::Symbol(s) => ast::Index::Symbol(Symbol(s.to_string())), + } + } +} + +pub(crate) fn is_built_in_sort(name: &str) -> bool { + match name { + "Int" | "Bool" => true, + _ => false, + } +} + +impl Sort { + pub fn new(name: impl Into) -> Self { + let mut name = name.into(); + if !is_built_in_sort(&name) { + // HACK: how should we handle this? or should we event handle it? + name += "_xxx"; + } + Self { + name, + index: Vec::new(), + parameters: Vec::new(), + } + } + pub fn new_parametric(name: impl Into, parameters: Vec) -> Self { + Self { + name: name.into(), + index: Vec::new(), + parameters, + } + } + pub fn new_indexed(name: impl Into, index: Vec) -> Self { + Self { + name: name.into(), + index, + parameters: Vec::new(), + } + } + + pub fn ast(&self) -> ast::Sort { + let ident = if self.index.is_empty() { + Identifier::Simple(Symbol(self.name.to_string())) + } else { + Identifier::Indexed( + Symbol(self.name.to_string()), + self.index.iter().map(|idx| idx.ast()).collect(), + ) + }; + if self.parameters.is_empty() { + ast::Sort::Sort(ident) + } else { + ast::Sort::Parametric( + ident, + self.parameters.iter().map(|param| param.ast()).collect(), + ) + } + } + + pub fn new_const(&self, name: impl Into) -> terms::Const { + let name: &'static str = String::leak(name.into()); + terms::Const( + name, + terms::Dynamic::from_term_sort( + ast::Term::Identifier(qual_ident(name.into(), Some(self.ast()))), + self.clone(), + ), + ) + } +} diff --git a/smtlib/src/terms.rs b/smtlib/src/terms.rs index b9da88e..b67b20a 100644 --- a/smtlib/src/terms.rs +++ b/smtlib/src/terms.rs @@ -11,7 +11,7 @@ use smtlib_lowlevel::{ lexicon::{Keyword, Symbol}, }; -use crate::Bool; +use crate::{sorts::Sort, Bool}; pub(crate) fn fun(name: &str, args: Vec) -> Term { Term::Application(qual_ident(name.to_string(), None), args) @@ -50,42 +50,71 @@ impl std::ops::Deref for Const { /// This type wraps terms loosing all static type information. It is particular /// useful when constructing terms dynamically. #[derive(Debug, Clone, Copy)] -pub struct Dynamic(&'static Term); +pub struct Dynamic(&'static Term, &'static Sort); impl std::fmt::Display for Dynamic { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { Term::from(*self).fmt(f) } } +pub trait StaticSorted: Into + From { + type Inner: StaticSorted; + fn static_sort() -> Sort; + fn new_const(name: impl Into) -> Const { + let name = name.into(); + let bv = Term::Identifier(qual_ident(name.clone(), Some(Self::static_sort().ast()))).into(); + let name = String::leak(name); + Const(name, bv) + } +} + /// An trait for statically typing STM-LIB terms. /// /// This trait indicates that a type can construct a [`Term`] which is the /// low-level primitive that is used to define expressions for the SMT solvers /// to evaluate. -pub trait Sort: Into { +pub trait Sorted: Into { /// The inner type of the term. This is used for [`Const`](Const) where the inner type is `T`. - type Inner: Sort; + type Inner: Sorted; /// The sort of the term - fn sort() -> ast::Sort; - /// Construct a constant of this sort. See the documentation of [`Const`] - /// for more information about constants. - fn from_name(name: impl Into) -> Const + fn sort(&self) -> Sort; + /// The sort of the term + fn is_sort(sort: &Sort) -> bool; + // /// Construct a constant of this sort. See the documentation of [`Const`] + // /// for more information about constants. + // fn from_name(name: impl Into) -> Const + // where + // Self: From, + // { + // // TODO: Only add |_| if necessary + // let name = format!("|{}|", name.into()); + // Const( + // Box::leak(name.clone().into_boxed_str()), + // Term::Identifier(qual_ident(name, Some(Self::sort().ast()))).into(), + // ) + // } + /// Casts a dynamically typed term into a concrete type + fn from_dynamic(d: Dynamic) -> Self where - Self: From, + Self: From<(Term, Sort)>, { - // TODO: Only add |_| if necessary - let name = format!("|{}|", name.into()); - Const( - Box::leak(name.clone().into_boxed_str()), - Term::Identifier(qual_ident(name, Some(Self::sort()))).into(), - ) + (d.0.clone(), d.1.clone()).into() } - /// Casts a dynamically typed term into a concrete type - fn from_dynamic(d: Dynamic) -> Self + /// Casts a dynamically typed term into a concrete type iff the dynamic sort + /// matches + fn try_from_dynamic(d: Dynamic) -> Option where - Self: From, + Self: From<(Term, Sort)>, { - d.0.clone().into() + if Self::is_sort(d.sort()) { + Some((d.0.clone(), d.1.clone()).into()) + } else { + None + } + } + fn into_dynamic(self) -> Dynamic { + let sort = self.sort(); + Dynamic::from_term_sort(self.into(), sort) } /// Construct the term representing `(= self other)` fn _eq(self, other: impl Into) -> Bool { @@ -122,10 +151,25 @@ impl> From> for Term { c.1.into() } } -impl Sort for Const { +impl Sorted for Const { type Inner = T; - fn sort() -> ast::Sort { - T::sort() + fn sort(&self) -> Sort { + T::sort(self) + } + fn is_sort(sort: &Sort) -> bool { + T::is_sort(sort) + } +} + +impl Sorted for T { + type Inner = T::Inner; + + fn sort(&self) -> Sort { + Self::static_sort() + } + + fn is_sort(sort: &Sort) -> bool { + sort == &Self::static_sort() } } @@ -162,15 +206,45 @@ impl From for Term { d.0.clone() } } -impl From for Dynamic { - fn from(t: Term) -> Self { - Dynamic(Box::leak(Box::new(t))) +impl From<(Term, Sort)> for Dynamic { + fn from((t, sort): (Term, Sort)) -> Self { + Dynamic::from_term_sort(t, sort) + } +} +impl Dynamic { + pub fn from_term_sort(t: Term, sort: Sort) -> Self { + Dynamic(Box::leak(Box::new(t)), Box::leak(Box::new(sort))) + } + + pub fn sort(&self) -> &Sort { + &self.1 + } + + pub fn as_int(&self) -> Result { + crate::Int::try_from_dynamic(self.clone()).ok_or_else(|| { + crate::Error::DynamicCastSortMismatch { + expected: crate::Int::static_sort(), + actual: self.1.clone(), + } + }) + } + + pub fn as_bool(&self) -> Result { + crate::Bool::try_from_dynamic(self.clone()).ok_or_else(|| { + crate::Error::DynamicCastSortMismatch { + expected: crate::Bool::static_sort(), + actual: self.1.clone(), + } + }) } } -impl Sort for Dynamic { +impl Sorted for Dynamic { type Inner = Self; - fn sort() -> ast::Sort { - ast::Sort::Sort(Identifier::Simple(Symbol("dynamic".into()))) + fn sort(&self) -> Sort { + self.1.clone() + } + fn is_sort(_sort: &Sort) -> bool { + true } } @@ -228,21 +302,24 @@ pub trait QuantifierVars { impl QuantifierVars for Const where - A: Sort, + A: StaticSorted, { fn into_vars(self) -> Vec { - vec![SortedVar(Symbol(self.0.to_string()), A::sort())] + vec![SortedVar( + Symbol(self.0.to_string()), + A::static_sort().ast(), + )] } } macro_rules! impl_quantifiers { ($($x:ident $n:tt),+ $(,)?) => { impl<$($x,)+> QuantifierVars for ($(Const<$x>),+) where - $($x: Sort),+ + $($x: StaticSorted),+ { fn into_vars(self) -> Vec { vec![ - $(SortedVar(Symbol((self.$n).0.into()), $x::sort())),+ + $(SortedVar(Symbol((self.$n).0.into()), $x::static_sort().ast())),+ ] } } @@ -260,6 +337,13 @@ impl_quantifiers!(A 0, B 1, C 2, D 3, E 4); // .collect() // } // } +impl QuantifierVars for Vec> { + fn into_vars(self) -> Vec { + self.into_iter() + .map(|v| SortedVar(Symbol(v.0.into()), v.1 .1.ast())) + .collect() + } +} impl QuantifierVars for Vec { fn into_vars(self) -> Vec { self diff --git a/smtlib/src/theories/core.rs b/smtlib/src/theories/core.rs index c42029a..14bfdc5 100644 --- a/smtlib/src/theories/core.rs +++ b/smtlib/src/theories/core.rs @@ -1,13 +1,11 @@ #![doc = concat!("```ignore\n", include_str!("./Core.smt2"), "```")] -use smtlib_lowlevel::{ - ast::{self, Identifier, Term}, - lexicon::Symbol, -}; +use smtlib_lowlevel::ast::{self, Term}; use crate::{ impl_op, - terms::{fun, qual_ident, Const, Dynamic, Sort}, + sorts::Sort, + terms::{fun, qual_ident, Const, Dynamic, Sorted, StaticSorted}, }; /// A [`Bool`] is a term containing a @@ -16,6 +14,12 @@ use crate::{ #[derive(Clone, Copy)] pub struct Bool(BoolImpl); +impl Bool { + pub fn new(value: bool) -> Bool { + value.into() + } +} + impl std::fmt::Debug for Bool { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.0.fmt(f) @@ -40,7 +44,7 @@ impl std::fmt::Display for Bool { } impl From for Dynamic { fn from(b: Bool) -> Self { - Term::from(b).into() + b.into_dynamic() } } impl From for Bool { @@ -66,13 +70,25 @@ impl From for Bool { Bool(BoolImpl::Term(Box::leak(Box::new(t)))) } } -impl Sort for Bool { +impl From<(Term, Sort)> for Bool { + fn from((t, _): (Term, Sort)) -> Self { + t.into() + } +} +impl StaticSorted for Bool { type Inner = Self; - fn sort() -> ast::Sort { - ast::Sort::Sort(Identifier::Simple(Symbol("Bool".into()))) + fn static_sort() -> Sort { + Sort::new("Bool") + } + fn new_const(name: impl Into) -> Const { + let name = String::leak(name.into()); + Const(name, Bool(BoolImpl::Const(name))) } } impl Bool { + pub fn sort() -> Sort { + Self::static_sort() + } fn binop(self, op: &str, other: Bool) -> Self { fun(op, vec![self.into(), other.into()]).into() } @@ -91,8 +107,11 @@ impl Bool { /// and an if statement: /// - **C-style notation:** `self ? then : otherwise` /// - **Rust notation:** `if self { then } else { otherwise }` - pub fn ite(self, then: Bool, otherwise: Bool) -> Bool { - fun("ite", vec![self.into(), then.into(), otherwise.into()]).into() + pub fn ite>(self, then: T, otherwise: T) -> T { + let sort = then.sort(); + let term = fun("ite", vec![self.into(), then.into(), otherwise.into()]); + let dyn_term = Dynamic::from_term_sort(term, sort); + T::from_dynamic(dyn_term) } } diff --git a/smtlib/src/theories/fixed_size_bit_vectors.rs b/smtlib/src/theories/fixed_size_bit_vectors.rs index 78a4bd3..1c8e6d5 100644 --- a/smtlib/src/theories/fixed_size_bit_vectors.rs +++ b/smtlib/src/theories/fixed_size_bit_vectors.rs @@ -1,13 +1,11 @@ #![doc = concat!("```ignore\n", include_str!("./FixedSizeBitVectors.smt2"), "```")] use itertools::Itertools; -use smtlib_lowlevel::{ - ast::{self, Identifier, Index, Term}, - lexicon::{Numeral, Symbol}, -}; +use smtlib_lowlevel::ast::{self, Term}; use crate::{ - terms::{fun, qual_ident, Const, Dynamic, Sort}, + sorts::{Index, Sort}, + terms::{fun, qual_ident, Const, Dynamic, Sorted, StaticSorted}, Bool, }; @@ -30,7 +28,7 @@ impl std::fmt::Display for BitVec { impl From> for Dynamic { fn from(i: BitVec) -> Self { - Term::from(i).into() + i.into_dynamic() } } @@ -78,13 +76,10 @@ impl TryFrom> for [bool; M] { } } -impl Sort for BitVec { +impl StaticSorted for BitVec { type Inner = Self; - fn sort() -> ast::Sort { - ast::Sort::Sort(Identifier::Indexed( - Symbol("BitVec".to_string()), - vec![Index::Numeral(Numeral(M.to_string()))], - )) + fn static_sort() -> Sort { + Sort::new_indexed("BitVec", vec![Index::Numeral(M)]) } } impl From<[bool; M]> for BitVec { @@ -279,7 +274,7 @@ impl_op!(BitVec, [bool; M], Shl, shl, bvshl, ShlAssign, shl_assign, <<); mod tests { use smtlib_lowlevel::backend::Z3Binary; - use crate::{terms::Sort, Solver}; + use crate::{terms::Sorted, Solver}; use super::BitVec; diff --git a/smtlib/src/theories/ints.rs b/smtlib/src/theories/ints.rs index 787a335..5891fe7 100644 --- a/smtlib/src/theories/ints.rs +++ b/smtlib/src/theories/ints.rs @@ -1,13 +1,11 @@ #![doc = concat!("```ignore\n", include_str!("./Ints.smt2"), "```")] -use smtlib_lowlevel::{ - ast::{self, Identifier, Term}, - lexicon::Symbol, -}; +use smtlib_lowlevel::ast::Term; use crate::{ impl_op, - terms::{fun, qual_ident, Const, Dynamic, Sort}, + sorts::Sort, + terms::{fun, qual_ident, Const, Dynamic, Sorted, StaticSorted}, Bool, }; @@ -29,7 +27,7 @@ impl std::fmt::Display for Int { impl From for Dynamic { fn from(i: Int) -> Self { - Term::from(i).into() + i.into_dynamic() } } @@ -43,10 +41,15 @@ impl From for Int { Int(Box::leak(Box::new(t))) } } -impl Sort for Int { +impl From<(Term, Sort)> for Int { + fn from((t, _): (Term, Sort)) -> Self { + t.into() + } +} +impl StaticSorted for Int { type Inner = Self; - fn sort() -> ast::Sort { - ast::Sort::Sort(Identifier::Simple(Symbol("Int".into()))) + fn static_sort() -> Sort { + Sort::new("Int") } } impl From for Int { @@ -55,6 +58,9 @@ impl From for Int { } } impl Int { + pub fn sort() -> Sort { + Self::static_sort() + } fn binop>(self, op: &str, other: Int) -> T { fun(op, vec![self.into(), other.into()]).into() } @@ -92,3 +98,6 @@ impl_op!(Int, i64, Add, add, "+", AddAssign, add_assign, +); impl_op!(Int, i64, Sub, sub, "-", SubAssign, sub_assign, -); impl_op!(Int, i64, Mul, mul, "*", MulAssign, mul_assign, *); impl_op!(Int, i64, Div, div, "div", DivAssign, div_assign, /); +impl_op!(Int, i64, Rem, rem, "rem", RemAssign, rem_assign, %); +impl_op!(Int, i64, Shl, shl, "hsl", ShlAssign, shl_assign, <<); +impl_op!(Int, i64, Shr, shr, "hsr", ShrAssign, shr_assign, >>); diff --git a/smtlib/src/theories/reals.rs b/smtlib/src/theories/reals.rs index ad27efd..6bd77db 100644 --- a/smtlib/src/theories/reals.rs +++ b/smtlib/src/theories/reals.rs @@ -1,13 +1,11 @@ #![doc = concat!("```ignore\n", include_str!("./Reals.smt2"), "```")] -use smtlib_lowlevel::{ - ast::{self, Identifier, Term}, - lexicon::Symbol, -}; +use smtlib_lowlevel::ast::Term; use crate::{ impl_op, - terms::{fun, qual_ident, Const, Dynamic, Sort}, + sorts::Sort, + terms::{fun, qual_ident, Const, Dynamic, Sorted, StaticSorted}, Bool, }; @@ -29,7 +27,7 @@ impl std::fmt::Display for Real { impl From for Dynamic { fn from(i: Real) -> Self { - Term::from(i).into() + i.into_dynamic() } } @@ -43,10 +41,10 @@ impl From for Real { Real(Box::leak(Box::new(t))) } } -impl Sort for Real { +impl StaticSorted for Real { type Inner = Self; - fn sort() -> ast::Sort { - ast::Sort::Sort(Identifier::Simple(Symbol("Real".into()))) + fn static_sort() -> Sort { + Sort::new("Real") } } impl From for Real { @@ -65,6 +63,9 @@ impl From for Real { } } impl Real { + pub fn sort() -> Sort { + Self::static_sort() + } fn binop>(self, op: &str, other: Real) -> T { fun(op, vec![self.into(), other.into()]).into() } diff --git a/xtask/src/spec.toml b/xtask/src/spec.toml index 128081c..5774fde 100644 --- a/xtask/src/spec.toml +++ b/xtask/src/spec.toml @@ -143,7 +143,7 @@ syntax = "( + )" syntax = "( par ( + ) ( + ) )" [function_dec] -syntax = "( ( * ) )" +syntax = "( ( * ) )" [function_def] syntax = " ( * ) "