Skip to content

Commit

Permalink
Add user defined functions and sorts and refactor term sorts
Browse files Browse the repository at this point in the history
  • Loading branch information
oeb25 committed Jul 31, 2024
1 parent 7217e5a commit 7a75c69
Show file tree
Hide file tree
Showing 18 changed files with 642 additions and 139 deletions.
5 changes: 3 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions lowlevel/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1246,10 +1246,10 @@ impl SmtlibParse for DatatypeDec {
Err(p.stuck("datatype_dec"))
}
}
/// `(<symbol> (<sorted_var>*) <sort>)`
/// `(<symbol> (<sort>*) <sort>)`
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct FunctionDec(pub Symbol, pub Vec<SortedVar>, pub Sort);
pub struct FunctionDec(pub Symbol, pub Vec<Sort>, pub Sort);
impl std::fmt::Display for FunctionDec {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "({} ({}) {})", self.0, self.1.iter().format(" "), self.2)
Expand All @@ -1268,7 +1268,7 @@ impl SmtlibParse for FunctionDec {
p.expect(Token::LParen)?;
let m0 = <Symbol as SmtlibParse>::parse(p)?;
p.expect(Token::LParen)?;
let m1 = p.any::<SortedVar>()?;
let m1 = p.any::<Sort>()?;
p.expect(Token::RParen)?;
let m2 = <Sort as SmtlibParse>::parse(p)?;
p.expect(Token::RParen)?;
Expand Down
91 changes: 81 additions & 10 deletions lowlevel/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,25 +34,71 @@ pub enum Error {
IO(#[from] std::io::Error),
}

#[derive(Debug)]
pub trait Logger: 'static {
fn exec(&self, cmd: &ast::Command);
fn response(&self, cmd: &ast::Command, res: &str);
}

impl<F, G> Logger for (F, G)
where
F: Fn(&ast::Command) + 'static,
G: Fn(&ast::Command, &str) + 'static,
{
fn exec(&self, cmd: &ast::Command) {
(self.0)(cmd)
}

fn response(&self, cmd: &ast::Command, res: &str) {
(self.1)(cmd, res)
}
}

pub struct Driver<B> {
backend: B,
logger: Option<Box<dyn Logger>>,
}

impl<B: std::fmt::Debug> std::fmt::Debug for Driver<B> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Driver")
.field("backend", &self.backend)
.field(
"logger",
if self.logger.is_some() {
&"Some(...)"
} else {
&"None"
},
)
.finish()
}
}

impl<B> Driver<B>
where
B: Backend,
{
pub fn new(backend: B) -> Result<Self, Error> {
let mut driver = Self { backend };
let mut driver = Self {
backend,
logger: None,
};

driver.exec(&Command::SetOption(ast::Option::PrintSuccess(true)))?;

Ok(driver)
}
pub fn set_logger(&mut self, logger: impl Logger) {
self.logger = Some(Box::new(logger))
}
pub fn exec(&mut self, cmd: &Command) -> Result<GeneralResponse, Error> {
// println!("> {cmd}");
if let Some(logger) = &self.logger {
logger.exec(cmd);
}
let res = self.backend.exec(cmd)?;
if let Some(logger) = &self.logger {
logger.response(cmd, &res);
}
let res = if let Some(res) = cmd.parse_response(&res)? {
GeneralResponse::SpecificSuccessResponse(res)
} else {
Expand All @@ -67,30 +113,57 @@ pub mod tokio {
use crate::{
ast::{self, Command, GeneralResponse},
backend::tokio::TokioBackend,
Error,
Error, Logger,
};

#[derive(Debug)]
pub struct TokioDriver<B> {
backend: B,
logger: Option<Box<dyn Logger>>,
}

impl<B: std::fmt::Debug> std::fmt::Debug for TokioDriver<B> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("TokioDriver")
.field("backend", &self.backend)
.field(
"logger",
if self.logger.is_some() {
&"Some(...)"
} else {
&"None"
},
)
.finish()
}
}

impl<B> TokioDriver<B>
where
B: TokioBackend,
{
pub async fn new(backend: B) -> Result<Self, Error> {
let mut driver = Self { backend };
let mut driver = Self {
backend,
logger: None,
};

driver
.exec(&Command::SetOption(ast::Option::PrintSuccess(true)))
.await?;

Ok(driver)
}
pub fn set_logger(&mut self, logger: impl Logger) {
self.logger = Some(Box::new(logger))
}
pub async fn exec(&mut self, cmd: &Command) -> Result<GeneralResponse, Error> {
// println!("> {cmd}");
if let Some(logger) = &self.logger {
logger.exec(cmd);
}
let res = self.backend.exec(cmd).await?;
if let Some(logger) = &self.logger {
logger.response(cmd, &res);
}
let res = if let Some(res) = cmd.parse_response(&res)? {
GeneralResponse::SpecificSuccessResponse(res)
} else {
Expand All @@ -107,9 +180,7 @@ impl Term {
match self {
Term::SpecConstant(_) => HashSet::new(),
Term::Identifier(q) => std::iter::once(q).collect(),
Term::Application(q, args) => std::iter::once(q)
.chain(args.iter().flat_map(|arg| arg.all_consts()))
.collect(),
Term::Application(_, args) => args.iter().flat_map(|arg| arg.all_consts()).collect(),
Term::Let(_, _) => todo!(),
// TODO
Term::Forall(_, _) => HashSet::new(),
Expand Down
1 change: 1 addition & 0 deletions smtlib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ miette.workspace = true
smtlib-lowlevel.workspace = true
serde = { workspace = true, optional = true }
thiserror.workspace = true
indexmap = "2.2.6"

[dev-dependencies]
futures = "0.3.29"
Expand Down
20 changes: 10 additions & 10 deletions smtlib/examples/queens.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,22 @@ use smtlib::{
and,
backend::{cvc5_binary::Cvc5Binary, z3_binary::Z3Binary, Backend},
distinct, or,
terms::Sort,
prelude::*,
Int, Logic, SatResultWithModel, Solver,
};

fn queens<B: Backend>(backend: B) -> miette::Result<()> {
let x0 = Int::from_name("x0");
let x1 = Int::from_name("x1");
let x2 = Int::from_name("x2");
let x3 = Int::from_name("x3");
let x4 = Int::from_name("x4");
let x5 = Int::from_name("x5");
let x6 = Int::from_name("x6");
let x7 = Int::from_name("x7");
let x0 = Int::new_const("x0");
let x1 = Int::new_const("x1");
let x2 = Int::new_const("x2");
let x3 = Int::new_const("x3");
let x4 = Int::new_const("x4");
let x5 = Int::new_const("x5");
let x6 = Int::new_const("x6");
let x7 = Int::new_const("x7");
let xs = [x0, x1, x2, x3, x4, x5, x6, x7];

let n = Int::from_name("N");
let n = Int::new_const("N");

let mut solver = Solver::new(backend)?;

Expand Down
18 changes: 9 additions & 9 deletions smtlib/examples/queens_bv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,19 @@ use smtlib::{
and,
backend::{cvc5_binary::Cvc5Binary, z3_binary::Z3Binary, Backend},
distinct, or,
terms::Sort,
prelude::*,
BitVec, SatResultWithModel, Solver,
};

fn queens<B: Backend>(backend: B) -> miette::Result<()> {
let x0 = BitVec::<4>::from_name("x0");
let x1 = BitVec::<4>::from_name("x1");
let x2 = BitVec::<4>::from_name("x2");
let x3 = BitVec::<4>::from_name("x3");
let x4 = BitVec::<4>::from_name("x4");
let x5 = BitVec::<4>::from_name("x5");
let x6 = BitVec::<4>::from_name("x6");
let x7 = BitVec::<4>::from_name("x7");
let x0 = BitVec::<4>::new_const("x0");
let x1 = BitVec::<4>::new_const("x1");
let x2 = BitVec::<4>::new_const("x2");
let x3 = BitVec::<4>::new_const("x3");
let x4 = BitVec::<4>::new_const("x4");
let x5 = BitVec::<4>::new_const("x5");
let x6 = BitVec::<4>::new_const("x6");
let x7 = BitVec::<4>::new_const("x7");
let xs = [x0, x1, x2, x3, x4, x5, x6, x7];

let mut solver = Solver::new(backend)?;
Expand Down
18 changes: 9 additions & 9 deletions smtlib/examples/queens_bv2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,19 @@ use smtlib::{
and,
backend::{cvc5_binary::Cvc5Binary, z3_binary::Z3Binary, Backend},
distinct, or,
terms::Sort,
prelude::*,
BitVec, Logic, SatResultWithModel, Solver,
};

fn queens<B: Backend>(backend: B) -> miette::Result<()> {
let x0 = BitVec::<8>::from_name("x0");
let x1 = BitVec::<8>::from_name("x1");
let x2 = BitVec::<8>::from_name("x2");
let x3 = BitVec::<8>::from_name("x3");
let x4 = BitVec::<8>::from_name("x4");
let x5 = BitVec::<8>::from_name("x5");
let x6 = BitVec::<8>::from_name("x6");
let x7 = BitVec::<8>::from_name("x7");
let x0 = BitVec::<8>::new_const("x0");
let x1 = BitVec::<8>::new_const("x1");
let x2 = BitVec::<8>::new_const("x2");
let x3 = BitVec::<8>::new_const("x3");
let x4 = BitVec::<8>::new_const("x4");
let x5 = BitVec::<8>::new_const("x5");
let x6 = BitVec::<8>::new_const("x6");
let x7 = BitVec::<8>::new_const("x7");
let xs = [x0, x1, x2, x3, x4, x5, x6, x7];

let mut solver = Solver::new(backend)?;
Expand Down
4 changes: 2 additions & 2 deletions smtlib/examples/simplify.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use miette::IntoDiagnostic;
use smtlib::Sort;
use smtlib::prelude::*;

#[derive(Debug, Clone)]
enum Expr {
Expand Down Expand Up @@ -98,7 +98,7 @@ fn expr_to_smt_bool(expr: &Expr) -> smtlib::Bool {
fn expr_to_smt(expr: &Expr) -> smtlib::terms::Dynamic {
match expr {
Expr::Num(n) => smtlib::Int::from(*n).into(),
Expr::Var(v) => smtlib::Int::from_name(v).into(),
Expr::Var(v) => smtlib::Int::new_const(v).into(),
Expr::Bool(b) => smtlib::Bool::from(*b).into(),
Expr::Add(l, r) => (expr_to_smt_int(l) + expr_to_smt_int(r)).into(),
Expr::Sub(l, r) => (expr_to_smt_int(l) - expr_to_smt_int(r)).into(),
Expand Down
51 changes: 51 additions & 0 deletions smtlib/src/funs.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
use smtlib_lowlevel::{ast, lexicon::Symbol};

use crate::{
sorts::Sort,
terms::{qual_ident, Dynamic},
};

#[derive(Debug)]
pub struct Fun {
pub name: String,
pub vars: Vec<Sort>,
pub return_sort: Sort,
}

impl Fun {
pub fn new(name: impl Into<String>, vars: Vec<Sort>, return_ty: Sort) -> Self {
Self {
name: name.into(),
vars,
return_sort: return_ty,
}
}

pub fn call(&self, args: &[Dynamic]) -> Result<Dynamic, crate::Error> {
if self.vars.len() != args.len() {
todo!()
}
for (expected, given) in self.vars.iter().zip(args) {
if expected != given.sort() {
todo!("expected {expected:?} given {:?}", given.sort())
}
}
let term = if args.is_empty() {
ast::Term::Identifier(qual_ident(self.name.clone(), None))
} else {
ast::Term::Application(
qual_ident(self.name.clone(), None),
args.iter().map(|arg| (*arg).into()).collect(),
)
};
Ok(Dynamic::from_term_sort(term, self.return_sort.clone()))
}

pub fn ast(&self) -> ast::FunctionDec {
ast::FunctionDec(
Symbol(self.name.to_string()),
self.vars.iter().map(|sort| sort.ast()).collect(),
self.return_sort.ast(),
)
}
}
Loading

0 comments on commit 7a75c69

Please sign in to comment.