From e065d70f81a9aa787f523a1668b0f6aacbb0cbd8 Mon Sep 17 00:00:00 2001 From: Lukas Heidemann Date: Tue, 17 Dec 2024 19:45:51 +0100 Subject: [PATCH] feat: Scoping rules and utilities for symbols, links and variables (#1754) This PR introduces scoping for symbols, links and variables. It comes with utilities that can be used to resolve names appropriately. Moreover the model data structures are changed so that they always use direct references by indices instead of names in order to streamline the serialisation format. --- hugr-core/src/export.rs | 358 ++++++++++----- hugr-core/src/import.rs | 327 +++++-------- .../tests/snapshots/model__roundtrip_add.snap | 6 +- .../snapshots/model__roundtrip_alias.snap | 4 +- .../snapshots/model__roundtrip_call.snap | 4 + .../tests/snapshots/model__roundtrip_cfg.snap | 8 +- .../snapshots/model__roundtrip_cond.snap | 6 +- .../model__roundtrip_constraints.snap | 2 + hugr-model/capnp/hugr-v0.capnp | 58 +-- hugr-model/src/v0/binary/read.rs | 84 ++-- hugr-model/src/v0/binary/write.rs | 66 ++- hugr-model/src/v0/mod.rs | 165 ++++--- hugr-model/src/v0/scope/link.rs | 125 +++++ hugr-model/src/v0/scope/mod.rs | 8 + hugr-model/src/v0/scope/symbol.rs | 198 ++++++++ hugr-model/src/v0/scope/vars.rs | 151 ++++++ hugr-model/src/v0/text/hugr.pest | 2 + hugr-model/src/v0/text/parse.rs | 429 ++++++++++++------ hugr-model/src/v0/text/print.rs | 86 ++-- 19 files changed, 1367 insertions(+), 720 deletions(-) create mode 100644 hugr-model/src/v0/scope/link.rs create mode 100644 hugr-model/src/v0/scope/mod.rs create mode 100644 hugr-model/src/v0/scope/symbol.rs create mode 100644 hugr-model/src/v0/scope/vars.rs diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index 0892822e9..224fc47f1 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -14,11 +14,8 @@ use crate::{ use bumpalo::{collections::String as BumpString, collections::Vec as BumpVec, Bump}; use fxhash::FxHashMap; use hugr_model::v0::{self as model}; -use 
indexmap::IndexSet; use std::fmt::Write; -type FxIndexSet = IndexSet; - pub(crate) const OP_FUNC_CALL_INDIRECT: &str = "func.call-indirect"; const TERM_PARAM_TUPLE: &str = "param.tuple"; const TERM_JSON: &str = "prelude.json"; @@ -37,9 +34,6 @@ struct Context<'a> { hugr: &'a Hugr, /// The module that is being built. module: model::Module<'a>, - /// Mapping from ports to link indices. - /// This only includes the minimum port among groups of linked ports. - links: FxIndexSet<(Node, Port)>, /// The arena in which the model is allocated. bump: &'a Bump, /// Stores the terms that we have already seen to avoid duplicates. @@ -61,6 +55,23 @@ struct Context<'a> { /// Mapping from extension operations to their declarations. decl_operations: FxHashMap<(ExtensionId, OpName), model::NodeId>, + + /// Table that is used to track which ports are connected. + /// + /// Each group of ports that is connected together is represented by a + /// single link. When traversing the [`Hugr`] graph we assign a link to each + /// port by finding the smallest node/port pair among all the linked ports + /// and looking up the link for that pair in this table. + links: model::scope::LinkTable<(Node, Port)>, + + /// The symbol table tracking symbols that are currently in scope. + symbols: model::scope::SymbolTable<'a>, + + /// Mapping from implicit imports to their node ids. + implicit_imports: FxHashMap<&'a str, model::NodeId>, + + /// Map from node ids in the [`Hugr`] to the corresponding node ids in the model. + node_indices: FxHashMap, } impl<'a> Context<'a> { @@ -72,49 +83,68 @@ impl<'a> Context<'a> { hugr, module, bump, - links: IndexSet::default(), term_map: FxHashMap::default(), local_scope: None, decl_operations: FxHashMap::default(), local_constraints: Vec::new(), + symbols: model::scope::SymbolTable::default(), + implicit_imports: FxHashMap::default(), + node_indices: FxHashMap::default(), + links: model::scope::LinkTable::default(), } } /// Exports the root module of the HUGR graph. 
pub fn export_root(&mut self) { + self.module.root = self.module.insert_region(model::Region::default()); + self.symbols.enter(self.module.root); + self.links.enter(self.module.root); + let hugr_children = self.hugr.children(self.hugr.root()); let mut children = Vec::with_capacity(hugr_children.size_hint().0); - for child in self.hugr.children(self.hugr.root()) { - children.push(self.export_node(child)); + for child in hugr_children.clone() { + children.push(self.export_node_shallow(child)); + } + + for (child, child_node_id) in hugr_children.zip(children.iter().copied()) { + self.export_node_deep(child, child_node_id); } - children.extend(self.decl_operations.values().copied()); + let mut all_children = BumpVec::with_capacity_in( + children.len() + self.decl_operations.len() + self.implicit_imports.len(), + self.bump, + ); - let root = self.module.insert_region(model::Region { + all_children.extend(self.implicit_imports.drain().map(|(_, id)| id)); + all_children.extend(self.decl_operations.values().copied()); + all_children.extend(children); + + let (links, ports) = self.links.exit(); + self.symbols.exit(); + + self.module.regions[self.module.root.index()] = model::Region { kind: model::RegionKind::Module, sources: &[], targets: &[], - children: self.bump.alloc_slice_copy(&children), + children: all_children.into_bump_slice(), meta: &[], // TODO: Export metadata signature: None, - }); - - self.module.root = root; + scope: Some(model::RegionScope { links, ports }), + }; } /// Returns the edge id for a given port, creating a new edge if necessary. /// /// Any two ports that are linked will be represented by the same link. - fn get_link_id(&mut self, node: Node, port: impl Into) -> model::LinkId { + fn get_link_index(&mut self, node: Node, port: impl Into) -> model::LinkIndex { // To ensure that linked ports are represented by the same edge, we take the minimum port // among all the linked ports, including the one we started with. 
let port = port.into(); let linked_ports = self.hugr.linked_ports(node, port); let all_ports = std::iter::once((node, port)).chain(linked_ports); let repr = all_ports.min().unwrap(); - let edge = self.links.insert_full(repr).0 as _; - model::LinkId(edge) + self.links.use_link(repr) } pub fn make_ports( @@ -122,12 +152,12 @@ impl<'a> Context<'a> { node: Node, direction: Direction, num_ports: usize, - ) -> &'a [model::LinkRef<'a>] { + ) -> &'a [model::LinkIndex] { let ports = self.hugr.node_ports(node, direction); let mut links = BumpVec::with_capacity_in(ports.size_hint().0, self.bump); for port in ports.take(num_ports) { - links.push(model::LinkRef::Id(self.get_link_id(node, port))); + links.push(self.get_link_index(node, port)); } links.into_bump_slice() @@ -160,8 +190,9 @@ impl<'a> Context<'a> { &mut self, extension: &IdentList, name: impl AsRef, - ) -> model::GlobalRef<'a> { - model::GlobalRef::Named(self.make_qualified_name(extension, name)) + ) -> model::NodeId { + let symbol = self.make_qualified_name(extension, name); + self.resolve_symbol(symbol) } /// Get the node that declares or defines the function associated with the given @@ -195,12 +226,31 @@ impl<'a> Context<'a> { result } - pub fn export_node(&mut self, node: Node) -> model::NodeId { + fn export_node_shallow(&mut self, node: Node) -> model::NodeId { + let node_id = self.module.insert_node(model::Node::default()); + self.node_indices.insert(node, node_id); + + let symbol = match self.hugr.get_optype(node) { + OpType::FuncDefn(func_defn) => Some(func_defn.name.as_str()), + OpType::FuncDecl(func_decl) => Some(func_decl.name.as_str()), + OpType::AliasDecl(alias_decl) => Some(alias_decl.name.as_str()), + OpType::AliasDefn(alias_defn) => Some(alias_defn.name.as_str()), + _ => None, + }; + + if let Some(symbol) = symbol { + self.symbols + .insert(symbol, node_id) + .expect("duplicate symbol"); + } + + node_id + } + + fn export_node_deep(&mut self, node: Node, node_id: model::NodeId) { // We insert a dummy 
node with the invalid operation at this point to reserve // the node id. This is necessary to establish the correct node id for the // local scope introduced by some operations. We will overwrite this node later. - let node_id = self.module.insert_node(model::Node::default()); - let mut params: &[_] = &[]; let mut regions: &[_] = &[]; @@ -219,14 +269,18 @@ impl<'a> Context<'a> { OpType::DFG(dfg) => { let extensions = self.export_ext_set(&dfg.signature.runtime_reqs); - regions = self - .bump - .alloc_slice_copy(&[self.export_dfg(node, extensions)]); + regions = self.bump.alloc_slice_copy(&[self.export_dfg( + node, + extensions, + model::ScopeClosure::Open, + )]); model::Operation::Dfg } OpType::CFG(_) => { - regions = self.bump.alloc_slice_copy(&[self.export_cfg(node)]); + regions = self + .bump + .alloc_slice_copy(&[self.export_cfg(node, model::ScopeClosure::Open)]); model::Operation::Cfg } @@ -240,9 +294,11 @@ impl<'a> Context<'a> { OpType::DataflowBlock(block) => { let extensions = self.export_ext_set(&block.extension_delta); - regions = self - .bump - .alloc_slice_copy(&[self.export_dfg(node, extensions)]); + regions = self.bump.alloc_slice_copy(&[self.export_dfg( + node, + extensions, + model::ScopeClosure::Open, + )]); model::Operation::Block } @@ -256,9 +312,11 @@ impl<'a> Context<'a> { signature, }); let extensions = this.export_ext_set(&func.signature.body().runtime_reqs); - regions = this - .bump - .alloc_slice_copy(&[this.export_dfg(node, extensions)]); + regions = this.bump.alloc_slice_copy(&[this.export_dfg( + node, + extensions, + model::ScopeClosure::Closed, + )]); model::Operation::DefineFunc { decl } }), @@ -300,26 +358,25 @@ impl<'a> Context<'a> { OpType::Call(call) => { // TODO: If the node is not connected to a function, we should do better than panic. 
let node = self.connected_function(node).unwrap(); - let name = model::GlobalRef::Named(self.get_func_name(node).unwrap()); - + let symbol = self.node_indices[&node]; let mut args = BumpVec::new_in(self.bump); args.extend(call.type_args.iter().map(|arg| self.export_type_arg(arg))); let args = args.into_bump_slice(); - let func = self.make_term(model::Term::ApplyFull { global: name, args }); + let func = self.make_term(model::Term::ApplyFull { symbol, args }); model::Operation::CallFunc { func } } OpType::LoadFunction(load) => { // TODO: If the node is not connected to a function, we should do better than panic. let node = self.connected_function(node).unwrap(); - let name = model::GlobalRef::Named(self.get_func_name(node).unwrap()); + let symbol = self.node_indices[&node]; let mut args = BumpVec::new_in(self.bump); args.extend(load.type_args.iter().map(|arg| self.export_type_arg(arg))); let args = args.into_bump_slice(); - let func = self.make_term(model::Term::ApplyFull { global: name, args }); + let func = self.make_term(model::Term::ApplyFull { symbol, args }); model::Operation::LoadFunc { func } } @@ -327,16 +384,18 @@ impl<'a> Context<'a> { OpType::LoadConstant(_) => todo!("Export load constant?"), OpType::CallIndirect(_) => model::Operation::CustomFull { - operation: model::GlobalRef::Named(OP_FUNC_CALL_INDIRECT), + operation: self.resolve_symbol(OP_FUNC_CALL_INDIRECT), }, OpType::Tag(tag) => model::Operation::Tag { tag: tag.tag as _ }, OpType::TailLoop(tail_loop) => { let extensions = self.export_ext_set(&tail_loop.extension_delta); - regions = self - .bump - .alloc_slice_copy(&[self.export_dfg(node, extensions)]); + regions = self.bump.alloc_slice_copy(&[self.export_dfg( + node, + extensions, + model::ScopeClosure::Open, + )]); model::Operation::TailLoop } @@ -361,7 +420,9 @@ impl<'a> Context<'a> { // as that of the node. This might change in the future. 
let extensions = self.export_ext_set(&op.extension_delta()); - if let Some(region) = self.export_dfg_if_present(node, extensions) { + if let Some(region) = + self.export_dfg_if_present(node, extensions, model::ScopeClosure::Closed) + { regions = self.bump.alloc_slice_copy(&[region]); } @@ -381,7 +442,9 @@ impl<'a> Context<'a> { // as that of the node. This might change in the future. let extensions = self.export_ext_set(&op.extension_delta()); - if let Some(region) = self.export_dfg_if_present(node, extensions) { + if let Some(region) = + self.export_dfg_if_present(node, extensions, model::ScopeClosure::Closed) + { regions = self.bump.alloc_slice_copy(&[region]); } @@ -417,8 +480,7 @@ impl<'a> Context<'a> { None => &[], }; - // Replace the placeholder node with the actual node. - *self.module.get_node_mut(node_id).unwrap() = model::Node { + self.module.nodes[node_id.index()] = model::Node { operation, inputs, outputs, @@ -427,8 +489,6 @@ impl<'a> Context<'a> { meta, signature, }; - - node_id } /// Export an `OpDef` as an operation declaration. @@ -438,7 +498,7 @@ impl<'a> Context<'a> { /// of the operation. The node is added to the `decl_operations` map so that /// at the end of the export, the operation declaration nodes can be added /// to the module as children of the module region. 
- pub fn export_opdef(&mut self, opdef: &OpDef) -> model::GlobalRef<'a> { + pub fn export_opdef(&mut self, opdef: &OpDef) -> model::NodeId { use std::collections::hash_map::Entry; let poly_func_type = match opdef.signature_func() { @@ -450,9 +510,7 @@ impl<'a> Context<'a> { let entry = self.decl_operations.entry(key); let node = match entry { - Entry::Occupied(occupied_entry) => { - return model::GlobalRef::Direct(*occupied_entry.get()) - } + Entry::Occupied(occupied_entry) => return *occupied_entry.get(), Entry::Vacant(vacant_entry) => { *vacant_entry.insert(self.module.insert_node(model::Node { operation: model::Operation::Invalid, @@ -502,7 +560,7 @@ impl<'a> Context<'a> { node_data.operation = model::Operation::DeclareOperation { decl }; node_data.meta = meta; - model::GlobalRef::Direct(node) + node } /// Export the signature of a `DataflowBlock`. Here we can't use `OpType::dataflow_signature` @@ -545,18 +603,45 @@ impl<'a> Context<'a> { &mut self, node: Node, extensions: model::TermId, + closure: model::ScopeClosure, ) -> Option { if self.hugr.children(node).next().is_none() { None } else { - Some(self.export_dfg(node, extensions)) + Some(self.export_dfg(node, extensions, closure)) } } /// Creates a data flow region from the given node's children. /// /// `Input` and `Output` nodes are used to determine the source and target ports of the region. - pub fn export_dfg(&mut self, node: Node, extensions: model::TermId) -> model::RegionId { + pub fn export_dfg( + &mut self, + node: Node, + extensions: model::TermId, + closure: model::ScopeClosure, + ) -> model::RegionId { + let region = self.module.insert_region(model::Region::default()); + + self.symbols.enter(region); + if closure == model::ScopeClosure::Closed { + self.links.enter(region); + } + + let region_children = { + let children = self.hugr.children(node); + + // We skip the first two children, which are the `Input` and `Output` nodes. 
+ // These nodes are not exported as model nodes themselves, but are used to determine + // the region's sources and targets. + let mut region_children = + BumpVec::with_capacity_in(children.size_hint().0 - 2, self.bump); + for child in children.skip(2) { + region_children.push(self.export_node_shallow(child)); + } + region_children.into_bump_slice() + }; + let mut children = self.hugr.children(node); // The first child is an `Input` node, which we use to determine the region's sources. @@ -574,10 +659,8 @@ impl<'a> Context<'a> { let targets = self.make_ports(output_node, Direction::Incoming, output_op.types.len()); // Export the remaining children of the node. - let mut region_children = BumpVec::with_capacity_in(children.size_hint().0, self.bump); - - for child in children { - region_children.push(self.export_node(child)); + for (child, child_node_id) in children.zip(region_children.iter().copied()) { + self.export_node_deep(child, child_node_id); } let signature = { @@ -591,62 +674,114 @@ impl<'a> Context<'a> { })) }; - self.module.insert_region(model::Region { + let scope = match closure { + model::ScopeClosure::Closed => { + let (links, ports) = self.links.exit(); + Some(model::RegionScope { links, ports }) + } + model::ScopeClosure::Open => None, + }; + self.symbols.exit(); + + self.module.regions[region.index()] = model::Region { kind: model::RegionKind::DataFlow, sources, targets, - children: region_children.into_bump_slice(), + children: region_children, meta: &[], // TODO: Export metadata signature, - }) + scope, + }; + + region } /// Creates a control flow region from the given node's children. 
- pub fn export_cfg(&mut self, node: Node) -> model::RegionId { - let mut children = self.hugr.children(node); - let mut region_children = BumpVec::with_capacity_in(children.size_hint().0 + 1, self.bump); + pub fn export_cfg(&mut self, node: Node, closure: model::ScopeClosure) -> model::RegionId { + let region = self.module.insert_region(model::Region::default()); + self.symbols.enter(region); + + if closure == model::ScopeClosure::Closed { + self.links.enter(region); + } + + let region_children = { + let children = self.hugr.children(node); + let mut region_children = + BumpVec::with_capacity_in(children.size_hint().0 - 1, self.bump); + + // First export the children shallowly to allocate their IDs and register symbols. + for (i, child) in children.enumerate() { + // The second node is the exit block, which is not exported as a node itself. + if i == 1 { + continue; + } + + region_children.push(self.export_node_shallow(child)); + } + + region_children.into_bump_slice() + }; + + let mut children_iter = self.hugr.children(node); + let mut region_children_iter = region_children.iter().copied(); // The first child is the entry block. // We create a source port on the control flow region and connect it to the // first input port of the exported entry block. 
- let entry_block = children.next().unwrap(); + let source = { + let entry_block = children_iter.next().unwrap(); + let entry_node_id = region_children_iter.next().unwrap(); - let OpType::DataflowBlock(_) = self.hugr.get_optype(entry_block) else { - panic!("expected a `DataflowBlock` node as the first child node"); - }; + let OpType::DataflowBlock(_) = self.hugr.get_optype(entry_block) else { + panic!("expected a `DataflowBlock` node as the first child node"); + }; - let source = model::LinkRef::Id(self.get_link_id(entry_block, IncomingPort::from(0))); - region_children.push(self.export_node(entry_block)); + self.export_node_deep(entry_block, entry_node_id); + self.get_link_index(entry_block, IncomingPort::from(0)) + }; - // The last child is the exit block. + // The second child is the exit block. // Contrary to the entry block, the exit block does not have a dataflow subgraph. // We therefore do not export the block itself, but simply use its output ports // as the target ports of the control flow region. - let exit_block = children.next_back().unwrap(); - - // Export the remaining children of the node, except for the last one. - for child in children { - region_children.push(self.export_node(child)); - } + let exit_block = children_iter.next().unwrap(); let OpType::ExitBlock(_) = self.hugr.get_optype(exit_block) else { - panic!("expected an `ExitBlock` node as the last child node"); + panic!("expected an `ExitBlock` node as the second child node"); }; + // Export the remaining children of the node (everything after the entry and exit blocks). + for (child, child_node_id) in children_iter.zip(region_children_iter) { + self.export_node_deep(child, child_node_id); + } + let targets = self.make_ports(exit_block, Direction::Incoming, 1); // Get the signature of the control flow region. // This is the same as the signature of the parent node. 
let signature = Some(self.export_func_type(&self.hugr.signature(node).unwrap())); - self.module.insert_region(model::Region { + let scope = match closure { + model::ScopeClosure::Closed => { + let (links, ports) = self.links.exit(); + Some(model::RegionScope { links, ports }) + } + model::ScopeClosure::Open => None, + }; + self.symbols.exit(); + + self.module.regions[region.index()] = model::Region { kind: model::RegionKind::ControlFlow, sources: self.bump.alloc_slice_copy(&[source]), targets, - children: region_children.into_bump_slice(), + children: region_children, meta: &[], // TODO: Export metadata signature, - }) + scope, + }; + + region } /// Export the `Case` node children of a `Conditional` node as data flow regions. @@ -660,7 +795,7 @@ impl<'a> Context<'a> { }; let extensions = self.export_ext_set(&case_op.signature.runtime_reqs); - regions.push(self.export_dfg(child, extensions)); + regions.push(self.export_dfg(child, extensions, model::ScopeClosure::Open)); } regions.into_bump_slice() @@ -683,7 +818,7 @@ impl<'a> Context<'a> { for (i, param) in t.params().iter().enumerate() { let name = self.bump.alloc_str(&i.to_string()); - let r#type = self.export_type_param(param, Some(model::LocalRef::Index(scope, i as _))); + let r#type = self.export_type_param(param, Some((scope, i as _))); let param = model::Param { name, r#type, @@ -706,14 +841,17 @@ impl<'a> Context<'a> { match t { TypeEnum::Extension(ext) => self.export_custom_type(ext), TypeEnum::Alias(alias) => { - let name = model::GlobalRef::Named(self.bump.alloc_str(alias.name())); + let global = self.resolve_symbol(self.bump.alloc_str(alias.name())); let args = &[]; - self.make_term(model::Term::ApplyFull { global: name, args }) + self.make_term(model::Term::ApplyFull { + symbol: global, + args, + }) } TypeEnum::Function(func) => self.export_func_type(func), TypeEnum::Variable(index, _) => { let node = self.local_scope.expect("local variable out of scope"); - 
self.make_term(model::Term::Var(model::LocalRef::Index(node, *index as _))) + self.make_term(model::Term::Var(model::VarId(node, *index as _))) } TypeEnum::RowVar(rv) => self.export_row_var(rv.as_rv()), TypeEnum::Sum(sum) => self.export_sum_type(sum), @@ -732,12 +870,12 @@ impl<'a> Context<'a> { } pub fn export_custom_type(&mut self, t: &CustomType) -> model::TermId { - let global = self.make_named_global_ref(t.extension(), t.name()); + let symbol = self.make_named_global_ref(t.extension(), t.name()); let args = self .bump .alloc_slice_fill_iter(t.args().iter().map(|p| self.export_type_arg(p))); - let term = model::Term::ApplyFull { global, args }; + let term = model::Term::ApplyFull { symbol, args }; self.make_term(term) } @@ -762,15 +900,12 @@ impl<'a> Context<'a> { pub fn export_type_arg_var(&mut self, var: &TypeArgVariable) -> model::TermId { let node = self.local_scope.expect("local variable out of scope"); - self.make_term(model::Term::Var(model::LocalRef::Index( - node, - var.index() as _, - ))) + self.make_term(model::Term::Var(model::VarId(node, var.index() as _))) } pub fn export_row_var(&mut self, t: &RowVariable) -> model::TermId { let node = self.local_scope.expect("local variable out of scope"); - self.make_term(model::Term::Var(model::LocalRef::Index(node, t.0 as _))) + self.make_term(model::Term::Var(model::VarId(node, t.0 as _))) } pub fn export_sum_type(&mut self, t: &SumType) -> model::TermId { @@ -834,12 +969,12 @@ impl<'a> Context<'a> { pub fn export_type_param( &mut self, t: &TypeParam, - var: Option>, + var: Option<(model::NodeId, model::VarIndex)>, ) -> model::TermId { match t { TypeParam::Type { b } => { - if let (Some(var), TypeBound::Copyable) = (var, b) { - let term = self.make_term(model::Term::Var(var)); + if let (Some((node, index)), TypeBound::Copyable) = (var, b) { + let term = self.make_term(model::Term::Var(model::VarId(node, index))); let non_linear = self.make_term(model::Term::NonLinearConstraint { term }); 
self.local_constraints.push(non_linear); } @@ -860,8 +995,9 @@ impl<'a> Context<'a> { .map(|param| model::ListPart::Item(self.export_type_param(param, None))), ); let types = self.make_term(model::Term::List { parts }); + let symbol = self.resolve_symbol(TERM_PARAM_TUPLE); self.make_term(model::Term::ApplyFull { - global: model::GlobalRef::Named(TERM_PARAM_TUPLE), + symbol, args: self.bump.alloc_slice_copy(&[types]), }) } @@ -879,10 +1015,9 @@ impl<'a> Context<'a> { for ext in ext_set.iter() { // `ExtensionSet`s represent variables by extension names that parse to integers. match ext.parse::() { - Ok(var) => { + Ok(index) => { let node = self.local_scope.expect("local variable out of scope"); - let local_ref = model::LocalRef::Index(node, var); - let term = self.make_term(model::Term::Var(local_ref)); + let term = self.make_term(model::Term::Var(model::VarId(node, index))); parts.push(model::ExtSetPart::Splice(term)); } Err(_) => parts.push(model::ExtSetPart::Extension(self.bump.alloc_str(ext))), @@ -913,11 +1048,26 @@ impl<'a> Context<'a> { let value = serde_json::to_string(value).expect("json values are always serializable"); let value = self.make_term(model::Term::Str(self.bump.alloc_str(&value))); let value = self.bump.alloc_slice_copy(&[value]); + let symbol = self.resolve_symbol(TERM_JSON); self.make_term(model::Term::ApplyFull { - global: model::GlobalRef::Named(TERM_JSON), + symbol, args: value, }) } + + fn resolve_symbol(&mut self, name: &'a str) -> model::NodeId { + let result = self.symbols.resolve(name); + + match result { + Ok(node) => node, + Err(_) => *self.implicit_imports.entry(name).or_insert_with(|| { + self.module.insert_node(model::Node { + operation: model::Operation::Import { name }, + ..model::Node::default() + }) + }), + } + } } #[cfg(test)] diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index 532802903..72fe8601a 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -22,16 +22,13 @@ use crate::{ Direction, 
Hugr, HugrView, Node, Port, }; use fxhash::FxHashMap; -use hugr_model::v0::{self as model, GlobalRef}; -use indexmap::IndexMap; +use hugr_model::v0::{self as model}; use itertools::Either; use smol_str::{SmolStr, ToSmolStr}; use thiserror::Error; const TERM_JSON: &str = "prelude.json"; -type FxIndexMap = IndexMap; - /// Error during import. #[derive(Debug, Clone, Error)] pub enum ImportError { @@ -76,22 +73,18 @@ pub fn import_hugr( module: &model::Module, extensions: &ExtensionRegistry, ) -> Result { - let names = Names::new(module)?; - // TODO: Module should know about the number of edges, so that we can use a vector here. // For now we use a hashmap, which will be slower. - let edge_ports = FxHashMap::default(); - let mut ctx = Context { module, - names, hugr: Hugr::new(OpType::Module(Module {})), - link_ports: edge_ports, + link_ports: FxHashMap::default(), static_edges: Vec::new(), extensions, nodes: FxHashMap::default(), - local_variables: IndexMap::default(), + local_vars: FxHashMap::default(), custom_name_cache: FxHashMap::default(), + region_scope: model::RegionId::default(), }; ctx.import_root()?; @@ -105,31 +98,28 @@ struct Context<'a> { /// The module being imported. module: &'a model::Module<'a>, - names: Names<'a>, - /// The HUGR graph being constructed. hugr: Hugr, /// The ports that are part of each link. This is used to connect the ports at the end of the /// import process. - link_ports: FxHashMap, Vec<(Node, Port)>>, + link_ports: FxHashMap<(model::RegionId, model::LinkIndex), Vec<(Node, Port)>>, /// Pairs of nodes that should be connected by a static edge. /// These are collected during the import process and connected at the end. static_edges: Vec<(model::NodeId, model::NodeId)>, - // /// The `(Node, Port)` pairs for each `PortId` in the module. - // imported_ports: Vec>, /// The ambient extension registry to use for importing. extensions: &'a ExtensionRegistry, /// A map from `NodeId` to the imported `Node`. 
nodes: FxHashMap, - /// The local variables that are currently in scope. - local_variables: FxIndexMap<&'a str, LocalVar>, + local_vars: FxHashMap, custom_name_cache: FxHashMap<&'a str, (ExtensionId, SmolStr)>, + + region_scope: model::RegionId, } impl<'a> Context<'a> { @@ -166,25 +156,6 @@ impl<'a> Context<'a> { .ok_or_else(|| model::ModelError::RegionNotFound(region_id).into()) } - /// Looks up a [`LocalRef`] within the current scope. - fn resolve_local_ref( - &self, - local_ref: &model::LocalRef, - ) -> Result<(usize, LocalVar), ImportError> { - let term = match local_ref { - model::LocalRef::Index(_, index) => self - .local_variables - .get_index(*index as usize) - .map(|(_, v)| (*index as usize, *v)), - model::LocalRef::Named(name) => self - .local_variables - .get_full(name) - .map(|(index, _, v)| (index, *v)), - }; - - term.ok_or_else(|| model::ModelError::InvalidLocal(local_ref.to_string()).into()) - } - fn make_node( &mut self, node_id: model::NodeId, @@ -209,13 +180,16 @@ impl<'a> Context<'a> { } /// Associate links with the ports of the given node in the given direction. - fn record_links(&mut self, node: Node, direction: Direction, links: &'a [model::LinkRef<'a>]) { + fn record_links(&mut self, node: Node, direction: Direction, links: &'a [model::LinkIndex]) { let optype = self.hugr.get_optype(node); // NOTE: `OpType::port_count` copies the signature, which significantly slows down the import. 
debug_assert!(links.len() <= optype.port_count(direction)); for (link, port) in links.iter().zip(self.hugr.node_ports(node, direction)) { - self.link_ports.entry(*link).or_default().push((node, port)); + self.link_ports + .entry((self.region_scope, *link)) + .or_default() + .push((node, port)); } } @@ -242,8 +216,9 @@ impl<'a> Context<'a> { if inputs.is_empty() || outputs.is_empty() { return Err(error_unsupported!( - "link {} is missing either an input or an output port", - link_id + "link {}#{} is missing either an input or an output port", + link_id.0, + link_id.1 )); } @@ -279,60 +254,13 @@ impl<'a> Context<'a> { Ok(()) } - fn with_local_socpe( - &mut self, - f: impl FnOnce(&mut Self) -> Result, - ) -> Result { - let previous = std::mem::take(&mut self.local_variables); - let result = f(self); - self.local_variables = previous; - result - } - - fn resolve_global_ref( - &self, - global_ref: &model::GlobalRef, - ) -> Result { - match global_ref { - model::GlobalRef::Direct(node_id) => Ok(*node_id), - model::GlobalRef::Named(name) => { - let item = self - .names - .items - .get(name) - .ok_or_else(|| model::ModelError::InvalidGlobal(global_ref.to_string()))?; - - match item { - NamedItem::FuncDecl(node) => Ok(*node), - NamedItem::FuncDefn(node) => Ok(*node), - NamedItem::CtrDecl(node) => Ok(*node), - NamedItem::OperationDecl(node) => Ok(*node), - } - } - } - } - - fn get_global_name(&self, global_ref: model::GlobalRef<'a>) -> Result<&'a str, ImportError> { - match global_ref { - model::GlobalRef::Direct(node_id) => { - let node_data = self.get_node(node_id)?; - - let name = match node_data.operation { - model::Operation::DefineFunc { decl } => decl.name, - model::Operation::DeclareFunc { decl } => decl.name, - model::Operation::DefineAlias { decl, .. 
} => decl.name, - model::Operation::DeclareAlias { decl } => decl.name, - model::Operation::DeclareConstructor { decl } => decl.name, - model::Operation::DeclareOperation { decl } => decl.name, - _ => { - return Err(model::ModelError::InvalidGlobal(global_ref.to_string()).into()); - } - }; - - Ok(name) - } - model::GlobalRef::Named(name) => Ok(name), - } + fn get_symbol_name(&self, node_id: model::NodeId) -> Result<&'a str, ImportError> { + let node_data = self.get_node(node_id)?; + let name = node_data + .operation + .symbol() + .ok_or(model::ModelError::InvalidSymbol(node_id))?; + Ok(name) } fn get_func_signature( @@ -345,11 +273,12 @@ impl<'a> Context<'a> { _ => return Err(model::ModelError::UnexpectedOperation(func_node).into()), }; - self.import_poly_func_type(*decl, |_, signature| Ok(signature)) + self.import_poly_func_type(func_node, *decl, |_, signature| Ok(signature)) } /// Import the root region of the module. fn import_root(&mut self) -> Result<(), ImportError> { + self.region_scope = self.module.root; let region_data = self.get_region(self.module.root)?; for node in region_data.children { @@ -400,7 +329,7 @@ impl<'a> Context<'a> { } model::Operation::DefineFunc { decl } => { - self.import_poly_func_type(*decl, |ctx, signature| { + self.import_poly_func_type(node_id, *decl, |ctx, signature| { let optype = OpType::FuncDefn(FuncDefn { name: decl.name.to_string(), signature, @@ -419,7 +348,7 @@ impl<'a> Context<'a> { } model::Operation::DeclareFunc { decl } => { - self.import_poly_func_type(*decl, |ctx, signature| { + self.import_poly_func_type(node_id, *decl, |ctx, signature| { let optype = OpType::FuncDecl(FuncDecl { name: decl.name.to_string(), signature, @@ -432,19 +361,18 @@ impl<'a> Context<'a> { } model::Operation::CallFunc { func } => { - let model::Term::ApplyFull { global: name, args } = self.get_term(func)? else { + let model::Term::ApplyFull { symbol, args } = self.get_term(func)? 
else { return Err(model::ModelError::TypeError(func).into()); }; - let func_node = self.resolve_global_ref(name)?; - let func_sig = self.get_func_signature(func_node)?; + let func_sig = self.get_func_signature(*symbol)?; let type_args = args .iter() .map(|term| self.import_type_arg(*term)) .collect::, _>>()?; - self.static_edges.push((func_node, node_id)); + self.static_edges.push((*symbol, node_id)); let optype = OpType::Call(Call::try_new(func_sig, type_args)?); let node = self.make_node(node_id, optype, parent)?; @@ -452,19 +380,18 @@ impl<'a> Context<'a> { } model::Operation::LoadFunc { func } => { - let model::Term::ApplyFull { global: name, args } = self.get_term(func)? else { + let model::Term::ApplyFull { symbol, args } = self.get_term(func)? else { return Err(model::ModelError::TypeError(func).into()); }; - let func_node = self.resolve_global_ref(name)?; - let func_sig = self.get_func_signature(func_node)?; + let func_sig = self.get_func_signature(*symbol)?; let type_args = args .iter() .map(|term| self.import_type_arg(*term)) .collect::, _>>()?; - self.static_edges.push((func_node, node_id)); + self.static_edges.push((*symbol, node_id)); let optype = OpType::LoadFunction(LoadFunction::try_new(func_sig, type_args)?); @@ -481,16 +408,16 @@ impl<'a> Context<'a> { Ok(Some(node)) } - model::Operation::CustomFull { - operation: GlobalRef::Named(name), - } if name == OP_FUNC_CALL_INDIRECT => { - let signature = self.get_node_signature(node_id)?; - let optype = OpType::CallIndirect(CallIndirect { signature }); - let node = self.make_node(node_id, optype, parent)?; - Ok(Some(node)) - } - model::Operation::CustomFull { operation } => { + let name = self.get_symbol_name(operation)?; + + if name == OP_FUNC_CALL_INDIRECT { + let signature = self.get_node_signature(node_id)?; + let optype = OpType::CallIndirect(CallIndirect { signature }); + let node = self.make_node(node_id, optype, parent)?; + return Ok(Some(node)); + } + let signature = 
self.get_node_signature(node_id)?; let args = node_data .params @@ -498,7 +425,6 @@ impl<'a> Context<'a> { .map(|param| self.import_type_arg(*param)) .collect::, _>>()?; - let name = self.get_global_name(operation)?; let (extension, name) = self.import_custom_name(name)?; // TODO: Currently we do not have the description or any other metadata for @@ -529,7 +455,7 @@ impl<'a> Context<'a> { "custom operation with implicit parameters" )), - model::Operation::DefineAlias { decl, value } => self.with_local_socpe(|ctx| { + model::Operation::DefineAlias { decl, value } => { if !decl.params.is_empty() { return Err(error_unsupported!( "parameters or constraints in alias definition" @@ -538,14 +464,14 @@ impl<'a> Context<'a> { let optype = OpType::AliasDefn(AliasDefn { name: decl.name.to_smolstr(), - definition: ctx.import_type(value)?, + definition: self.import_type(value)?, }); - let node = ctx.make_node(node_id, optype, parent)?; + let node = self.make_node(node_id, optype, parent)?; Ok(Some(node)) - }), + } - model::Operation::DeclareAlias { decl } => self.with_local_socpe(|ctx| { + model::Operation::DeclareAlias { decl } => { if !decl.params.is_empty() { return Err(error_unsupported!( "parameters or constraints in alias declaration" @@ -557,9 +483,9 @@ impl<'a> Context<'a> { bound: TypeBound::Copyable, }); - let node = ctx.make_node(node_id, optype, parent)?; + let node = self.make_node(node_id, optype, parent)?; Ok(Some(node)) - }), + } model::Operation::Tag { tag } => { let signature = node_data @@ -578,6 +504,8 @@ impl<'a> Context<'a> { Ok(Some(node)) } + model::Operation::Import { .. } => Ok(None), + model::Operation::DeclareConstructor { .. } => Ok(None), model::Operation::DeclareOperation { .. 
} => Ok(None), } @@ -591,6 +519,11 @@ impl<'a> Context<'a> { ) -> Result<(), ImportError> { let region_data = self.get_region(region)?; + let prev_region = self.region_scope; + if region_data.scope.is_some() { + self.region_scope = region; + } + if region_data.kind != model::RegionKind::DataFlow { return Err(model::ModelError::InvalidRegions(node_id).into()); } @@ -623,6 +556,8 @@ impl<'a> Context<'a> { self.import_node(*child, node)?; } + self.region_scope = prev_region; + Ok(()) } @@ -755,6 +690,11 @@ impl<'a> Context<'a> { return Err(model::ModelError::InvalidRegions(node_id).into()); } + let prev_region = self.region_scope; + if region_data.scope.is_some() { + self.region_scope = region; + } + let (region_source, region_targets, _) = self.get_func_type( region_data .signature @@ -861,6 +801,8 @@ impl<'a> Context<'a> { self.record_links(exit, Direction::Incoming, region_data.targets); } + self.region_scope = prev_region; + Ok(()) } @@ -899,43 +841,43 @@ impl<'a> Context<'a> { fn import_poly_func_type( &mut self, + node: model::NodeId, decl: model::FuncDecl<'a>, in_scope: impl FnOnce(&mut Self, PolyFuncTypeBase) -> Result, ) -> Result { - self.with_local_socpe(|ctx| { - let mut imported_params = Vec::with_capacity(decl.params.len()); + let mut imported_params = Vec::with_capacity(decl.params.len()); - ctx.local_variables.extend( - decl.params - .iter() - .map(|param| (param.name, LocalVar::new(param.r#type))), - ); + for (index, param) in decl.params.iter().enumerate() { + self.local_vars + .insert(model::VarId(node, index as _), LocalVar::new(param.r#type)); + } - for constraint in decl.constraints { - match ctx.get_term(*constraint)? { - model::Term::NonLinearConstraint { term } => { - let model::Term::Var(var) = ctx.get_term(*term)? 
else { - return Err(error_unsupported!( - "constraint on term that is not a variable" - )); - }; - - let var = ctx.resolve_local_ref(var)?.0; - ctx.local_variables[var].bound = TypeBound::Copyable; - } - _ => return Err(error_unsupported!("constraint other than copy or discard")), + for constraint in decl.constraints { + match self.get_term(*constraint)? { + model::Term::NonLinearConstraint { term } => { + let model::Term::Var(var) = self.get_term(*term)? else { + return Err(error_unsupported!( + "constraint on term that is not a variable" + )); + }; + + self.local_vars + .get_mut(var) + .ok_or(model::ModelError::InvalidVar(*var))? + .bound = TypeBound::Copyable; } + _ => return Err(error_unsupported!("constraint other than copy or discard")), } + } - for (index, param) in decl.params.iter().enumerate() { - // NOTE: `PolyFuncType` only has explicit type parameters at present. - let bound = ctx.local_variables[index].bound; - imported_params.push(ctx.import_type_param(param.r#type, bound)?); - } + for (index, param) in decl.params.iter().enumerate() { + // NOTE: `PolyFuncType` only has explicit type parameters at present. + let bound = self.local_vars[&model::VarId(node, index as _)].bound; + imported_params.push(self.import_type_param(param.r#type, bound)?); + } - let body = ctx.import_func_type::(decl.signature)?; - in_scope(ctx, PolyFuncTypeBase::new(imported_params, body)) - }) + let body = self.import_func_type::(decl.signature)?; + in_scope(self, PolyFuncTypeBase::new(imported_params, body)) } /// Import a [`TypeParam`] from a term that represents a static type. @@ -951,7 +893,7 @@ impl<'a> Context<'a> { model::Term::StaticType => Err(error_unsupported!("`type` as `TypeParam`")), model::Term::Constraint => Err(error_unsupported!("`constraint` as `TypeParam`")), - model::Term::Var(_) => Err(error_unsupported!("type variable as `TypeParam`")), + model::Term::Var { .. } => Err(error_unsupported!("type variable as `TypeParam`")), model::Term::Apply { .. 
} => Err(error_unsupported!("custom type as `TypeParam`")), model::Term::ApplyFull { .. } => Err(error_unsupported!("custom type as `TypeParam`")), @@ -995,9 +937,12 @@ impl<'a> Context<'a> { } model::Term::Var(var) => { - let (index, var) = self.resolve_local_ref(var)?; - let decl = self.import_type_param(var.r#type, var.bound)?; - Ok(TypeArg::new_var_use(index, decl)) + let var_info = self + .local_vars + .get(var) + .ok_or(model::ModelError::InvalidVar(*var))?; + let decl = self.import_type_param(var_info.r#type, var_info.bound)?; + Ok(TypeArg::new_var_use(var.1 as _, decl)) } model::Term::List { .. } => { @@ -1052,9 +997,8 @@ impl<'a> Context<'a> { match self.get_term(term_id)? { model::Term::Wildcard => return Err(error_uninferred!("wildcard")), - model::Term::Var(var) => { - let (index, _) = self.resolve_local_ref(var)?; - es.insert_type_var(index); + model::Term::Var(model::VarId(_, index)) => { + es.insert_type_var(*index as _); } model::Term::ExtSet { parts } => { @@ -1091,13 +1035,13 @@ impl<'a> Context<'a> { Err(error_uninferred!("application with implicit parameters")) } - model::Term::ApplyFull { global: name, args } => { + model::Term::ApplyFull { symbol, args } => { let args = args .iter() .map(|arg| self.import_type_arg(*arg)) .collect::, _>>()?; - let name = self.get_global_name(*name)?; + let name = self.get_symbol_name(*symbol)?; let (extension, id) = self.import_custom_name(name)?; let extension_ref = @@ -1119,10 +1063,8 @@ impl<'a> Context<'a> { ))) } - model::Term::Var(var) => { - // We pretend that all `TypeBound`s are copyable. - let (index, _) = self.resolve_local_ref(var)?; - Ok(TypeBase::new_var_use(index, TypeBound::Copyable)) + model::Term::Var(model::VarId(_, index)) => { + Ok(TypeBase::new_var_use(*index as _, TypeBound::Copyable)) } model::Term::FuncType { .. 
} => { @@ -1255,9 +1197,8 @@ impl<'a> Context<'a> { } } } - model::Term::Var(var) => { - let (index, _) = ctx.resolve_local_ref(var)?; - let var = RV::try_from_rv(RowVariable(index, TypeBound::Any)) + model::Term::Var(model::VarId(_, index)) => { + let var = RV::try_from_rv(RowVariable(*index as _, TypeBound::Any)) .map_err(|_| model::ModelError::TypeError(term_id))?; types.push(TypeBase::new(TypeEnum::RowVar(var))); } @@ -1298,13 +1239,14 @@ impl<'a> Context<'a> { term_id: model::TermId, ) -> Result { let (global, args) = match self.get_term(term_id)? { - model::Term::Apply { global, args } | model::Term::ApplyFull { global, args } => { - (global, args) + model::Term::Apply { symbol, args } | model::Term::ApplyFull { symbol, args } => { + (symbol, args) } _ => return Err(model::ModelError::TypeError(term_id).into()), }; - if global != &GlobalRef::Named(TERM_JSON) { + let global = self.get_symbol_name(*global)?; + if global != TERM_JSON { return Err(model::ModelError::TypeError(term_id).into()); } @@ -1323,51 +1265,6 @@ impl<'a> Context<'a> { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -enum NamedItem { - FuncDecl(model::NodeId), - FuncDefn(model::NodeId), - CtrDecl(model::NodeId), - OperationDecl(model::NodeId), -} - -struct Names<'a> { - items: FxHashMap<&'a str, NamedItem>, -} - -impl<'a> Names<'a> { - pub fn new(module: &model::Module<'a>) -> Result { - let mut items = FxHashMap::default(); - - for (node_id, node_data) in module.nodes.iter().enumerate() { - let node_id = model::NodeId(node_id as _); - - let item = match node_data.operation { - model::Operation::DefineFunc { decl } => { - Some((decl.name, NamedItem::FuncDecl(node_id))) - } - model::Operation::DeclareFunc { decl } => { - Some((decl.name, NamedItem::FuncDefn(node_id))) - } - model::Operation::DeclareConstructor { decl } => { - Some((decl.name, NamedItem::CtrDecl(node_id))) - } - model::Operation::DeclareOperation { decl } => { - Some((decl.name, 
NamedItem::OperationDecl(node_id))) - } - _ => None, - }; - - if let Some((name, item)) = item { - // TODO: Deal with duplicates - items.insert(name, item); - } - } - - Ok(Self { items }) - } -} - /// Information about a local variable. #[derive(Debug, Clone, Copy)] struct LocalVar { diff --git a/hugr-core/tests/snapshots/model__roundtrip_add.snap b/hugr-core/tests/snapshots/model__roundtrip_add.snap index b7de139fe..7ffec5ef9 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_add.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_add.snap @@ -1,9 +1,13 @@ --- source: hugr-core/tests/model.rs -expression: "roundtrip(include_str!(\"fixtures/model-add.edn\"))" +expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-add.edn\"))" --- (hugr 0) +(import arithmetic.int.iadd) + +(import arithmetic.int.types.int) + (define-func example.add [(@ arithmetic.int.types.int) (@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] diff --git a/hugr-core/tests/snapshots/model__roundtrip_alias.snap b/hugr-core/tests/snapshots/model__roundtrip_alias.snap index c279c5d6a..27fdd4740 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_alias.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_alias.snap @@ -1,9 +1,11 @@ --- source: hugr-core/tests/model.rs -expression: "roundtrip(include_str!(\"fixtures/model-alias.edn\"))" +expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-alias.edn\"))" --- (hugr 0) +(import arithmetic.int.types.int) + (declare-alias local.float type) (define-alias local.int type (@ arithmetic.int.types.int)) diff --git a/hugr-core/tests/snapshots/model__roundtrip_call.snap b/hugr-core/tests/snapshots/model__roundtrip_call.snap index 5ddc4eb32..2b37b5a20 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_call.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_call.snap @@ -4,6 +4,10 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-call --- (hugr 0) +(import 
prelude.json) + +(import arithmetic.int.types.int) + (declare-func example.callee (forall ?0 ext-set) [(@ arithmetic.int.types.int)] diff --git a/hugr-core/tests/snapshots/model__roundtrip_cfg.snap b/hugr-core/tests/snapshots/model__roundtrip_cfg.snap index 41a8f0d62..e39f0d37d 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_cfg.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_cfg.snap @@ -13,14 +13,14 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cfg. (cfg [%0] [%1] (signature (fn [?0] [?0] (ext))) (cfg - [%2] [%8] + [%4] [%8] (signature (fn [?0] [?0] (ext))) - (block [%2] [%5] + (block [%4] [%5] (signature (fn [(ctrl [?0])] [(ctrl [?0])] (ext))) (dfg - [%3] [%4] + [%2] [%3] (signature (fn [?0] [(adt [[?0]])] (ext))) - (tag 0 [%3] [%4] (signature (fn [?0] [(adt [[?0]])] (ext)))))) + (tag 0 [%2] [%3] (signature (fn [?0] [(adt [[?0]])] (ext)))))) (block [%5] [%8] (signature (fn [(ctrl [?0])] [(ctrl [?0])] (ext))) (dfg diff --git a/hugr-core/tests/snapshots/model__roundtrip_cond.snap b/hugr-core/tests/snapshots/model__roundtrip_cond.snap index fe55e965f..92ab0cb4d 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_cond.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_cond.snap @@ -1,9 +1,13 @@ --- source: hugr-core/tests/model.rs -expression: "roundtrip(include_str!(\"fixtures/model-cond.edn\"))" +expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cond.edn\"))" --- (hugr 0) +(import arithmetic.int.types.int) + +(import arithmetic.int.ineg) + (define-func example.cond [(adt [[] []]) (@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] diff --git a/hugr-core/tests/snapshots/model__roundtrip_constraints.snap b/hugr-core/tests/snapshots/model__roundtrip_constraints.snap index 291c2de48..d7cb2bf01 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_constraints.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_constraints.snap @@ -4,6 +4,8 @@ expression: 
"roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cons --- (hugr 0) +(import prelude.Array) + (declare-func array.replicate (forall ?0 type) (forall ?1 nat) diff --git a/hugr-model/capnp/hugr-v0.capnp b/hugr-model/capnp/hugr-v0.capnp index 366de92eb..b3bb1f0f2 100644 --- a/hugr-model/capnp/hugr-v0.capnp +++ b/hugr-model/capnp/hugr-v0.capnp @@ -15,6 +15,9 @@ using NodeId = UInt32; # The id of a `Link`. using LinkId = UInt32; +# The index of a `Link`. +using LinkIndex = UInt32; + struct Module { root @0 :RegionId; nodes @1 :List(Node); @@ -24,8 +27,8 @@ struct Module { struct Node { operation @0 :Operation; - inputs @1 :List(LinkRef); - outputs @2 :List(LinkRef); + inputs @1 :List(LinkIndex); + outputs @2 :List(LinkIndex); params @3 :List(TermId); regions @4 :List(RegionId); meta @5 :List(MetaItem); @@ -42,8 +45,8 @@ struct Operation { funcDecl @5 :FuncDecl; aliasDefn @6 :AliasDefn; aliasDecl @7 :AliasDecl; - custom @8 :GlobalRef; - customFull @9 :GlobalRef; + custom @8 :NodeId; + customFull @9 :NodeId; tag @10 :UInt16; tailLoop @11 :Void; conditional @12 :Void; @@ -51,6 +54,7 @@ struct Operation { loadFunc @14 :TermId; constructorDecl @15 :ConstructorDecl; operationDecl @16 :OperationDecl; + import @17 :Text; } struct FuncDefn { @@ -97,13 +101,22 @@ struct Operation { struct Region { kind @0 :RegionKind; - sources @1 :List(LinkRef); - targets @2 :List(LinkRef); + sources @1 :List(LinkIndex); + targets @2 :List(LinkIndex); children @3 :List(NodeId); meta @4 :List(MetaItem); signature @5 :OptionalTermId; + scope @6 :RegionScope; +} + +struct RegionScope { + links @0 :UInt32; + ports @1 :UInt32; } +# Either `0` for an open scope, or the number of links in the closed scope incremented by `1`. 
+using LinkScope = UInt32; + enum RegionKind { dataFlow @0; controlFlow @1; @@ -115,37 +128,16 @@ struct MetaItem { value @1 :UInt32; } -struct LinkRef { - union { - id @0 :LinkId; - named @1 :Text; - } -} - -struct GlobalRef { - union { - node @0 :NodeId; - named @1 :Text; - } -} - -struct LocalRef { - union { - direct :group { - index @0 :UInt16; - node @1 :NodeId; - } - named @2 :Text; - } -} - struct Term { union { wildcard @0 :Void; runtimeType @1 :Void; staticType @2 :Void; constraint @3 :Void; - variable @4 :LocalRef; + variable :group { + variableNode @4 :NodeId; + variableIndex @21 :UInt16; + } apply @5 :Apply; applyFull @6 :ApplyFull; quote @7 :TermId; @@ -165,12 +157,12 @@ struct Term { } struct Apply { - global @0 :GlobalRef; + symbol @0 :NodeId; args @1 :List(TermId); } struct ApplyFull { - global @0 :GlobalRef; + symbol @0 :NodeId; args @1 :List(TermId); } diff --git a/hugr-model/src/v0/binary/read.rs b/hugr-model/src/v0/binary/read.rs index 2dfe67efc..b14ca4482 100644 --- a/hugr-model/src/v0/binary/read.rs +++ b/hugr-model/src/v0/binary/read.rs @@ -71,8 +71,8 @@ fn read_module<'a>( fn read_node<'a>(bump: &'a Bump, reader: hugr_capnp::node::Reader) -> ReadResult> { let operation = read_operation(bump, reader.get_operation()?)?; - let inputs = read_list!(bump, reader, get_inputs, read_link_ref); - let outputs = read_list!(bump, reader, get_outputs, read_link_ref); + let inputs = read_scalar_list!(bump, reader, get_inputs, model::LinkIndex); + let outputs = read_scalar_list!(bump, reader, get_outputs, model::LinkIndex); let params = read_scalar_list!(bump, reader, get_params, model::TermId); let regions = read_scalar_list!(bump, reader, get_regions, model::RegionId); let meta = read_list!(bump, reader, get_meta, read_meta_item); @@ -89,43 +89,6 @@ fn read_node<'a>(bump: &'a Bump, reader: hugr_capnp::node::Reader) -> ReadResult }) } -fn read_local_ref<'a>( - bump: &'a Bump, - reader: hugr_capnp::local_ref::Reader, -) -> ReadResult> { - use 
hugr_capnp::local_ref::Which; - Ok(match reader.which()? { - Which::Direct(reader) => { - let index = reader.get_index(); - let node = model::NodeId(reader.get_node()); - model::LocalRef::Index(node, index) - } - Which::Named(name) => model::LocalRef::Named(bump.alloc_str(name?.to_str()?)), - }) -} - -fn read_global_ref<'a>( - bump: &'a Bump, - reader: hugr_capnp::global_ref::Reader, -) -> ReadResult> { - use hugr_capnp::global_ref::Which; - Ok(match reader.which()? { - Which::Node(node) => model::GlobalRef::Direct(model::NodeId(node)), - Which::Named(name) => model::GlobalRef::Named(bump.alloc_str(name?.to_str()?)), - }) -} - -fn read_link_ref<'a>( - bump: &'a Bump, - reader: hugr_capnp::link_ref::Reader, -) -> ReadResult> { - use hugr_capnp::link_ref::Which; - Ok(match reader.which()? { - Which::Id(id) => model::LinkRef::Id(model::LinkId(id)), - Which::Named(name) => model::LinkRef::Named(bump.alloc_str(name?.to_str()?)), - }) -} - fn read_operation<'a>( bump: &'a Bump, reader: hugr_capnp::operation::Reader, @@ -217,11 +180,11 @@ fn read_operation<'a>( }); model::Operation::DeclareOperation { decl } } - Which::Custom(name) => model::Operation::Custom { - operation: read_global_ref(bump, name?)?, + Which::Custom(operation) => model::Operation::Custom { + operation: model::NodeId(operation), }, - Which::CustomFull(name) => model::Operation::CustomFull { - operation: read_global_ref(bump, name?)?, + Which::CustomFull(operation) => model::Operation::CustomFull { + operation: model::NodeId(operation), }, Which::Tag(tag) => model::Operation::Tag { tag }, Which::TailLoop(()) => model::Operation::TailLoop, @@ -232,6 +195,9 @@ fn read_operation<'a>( Which::LoadFunc(func) => model::Operation::LoadFunc { func: model::TermId(func), }, + Which::Import(name) => model::Operation::Import { + name: bump.alloc_str(name?.to_str()?), + }, }) } @@ -245,12 +211,18 @@ fn read_region<'a>( hugr_capnp::RegionKind::Module => model::RegionKind::Module, }; - let sources = read_list!(bump, 
reader, get_sources, read_link_ref); - let targets = read_list!(bump, reader, get_targets, read_link_ref); + let sources = read_scalar_list!(bump, reader, get_sources, model::LinkIndex); + let targets = read_scalar_list!(bump, reader, get_targets, model::LinkIndex); let children = read_scalar_list!(bump, reader, get_children, model::NodeId); let meta = read_list!(bump, reader, get_meta, read_meta_item); let signature = reader.get_signature().checked_sub(1).map(model::TermId); + let scope = if reader.has_scope() { + Some(read_region_scope(reader.get_scope()?)?) + } else { + None + }; + Ok(model::Region { kind, sources, @@ -258,9 +230,16 @@ fn read_region<'a>( children, meta, signature, + scope, }) } +fn read_region_scope(reader: hugr_capnp::region_scope::Reader) -> ReadResult { + let links = reader.get_links(); + let ports = reader.get_ports(); + Ok(model::RegionScope { links, ports }) +} + fn read_term<'a>(bump: &'a Bump, reader: hugr_capnp::term::Reader) -> ReadResult> { use hugr_capnp::term::Which; Ok(match reader.which()? 
{ @@ -274,20 +253,25 @@ fn read_term<'a>(bump: &'a Bump, reader: hugr_capnp::term::Reader) -> ReadResult Which::NatType(()) => model::Term::NatType, Which::ExtSetType(()) => model::Term::ExtSetType, Which::ControlType(()) => model::Term::ControlType, - Which::Variable(local_ref) => model::Term::Var(read_local_ref(bump, local_ref?)?), + + Which::Variable(reader) => { + let node = model::NodeId(reader.get_variable_node()); + let index = reader.get_variable_index(); + model::Term::Var(model::VarId(node, index)) + } Which::Apply(reader) => { let reader = reader?; - let global = read_global_ref(bump, reader.get_global()?)?; + let symbol = model::NodeId(reader.get_symbol()); let args = read_scalar_list!(bump, reader, get_args, model::TermId); - model::Term::Apply { global, args } + model::Term::Apply { symbol, args } } Which::ApplyFull(reader) => { let reader = reader?; - let global = read_global_ref(bump, reader.get_global()?)?; + let symbol = model::NodeId(reader.get_symbol()); let args = read_scalar_list!(bump, reader, get_args, model::TermId); - model::Term::ApplyFull { global, args } + model::Term::ApplyFull { symbol, args } } Which::Quote(r#type) => model::Term::Quote { diff --git a/hugr-model/src/v0/binary/write.rs b/hugr-model/src/v0/binary/write.rs index aa377e2ec..ea495db54 100644 --- a/hugr-model/src/v0/binary/write.rs +++ b/hugr-model/src/v0/binary/write.rs @@ -31,8 +31,8 @@ fn write_module(mut builder: hugr_capnp::module::Builder, module: &model::Module fn write_node(mut builder: hugr_capnp::node::Builder, node: &model::Node) { write_operation(builder.reborrow().init_operation(), &node.operation); - write_list!(builder, init_inputs, write_link_ref, node.inputs); - write_list!(builder, init_outputs, write_link_ref, node.outputs); + let _ = builder.set_inputs(model::LinkIndex::unwrap_slice(node.inputs)); + let _ = builder.set_outputs(model::LinkIndex::unwrap_slice(node.outputs)); write_list!(builder, init_meta, write_meta_item, node.meta); let _ = 
builder.set_params(model::TermId::unwrap_slice(node.params)); let _ = builder.set_regions(model::RegionId::unwrap_slice(node.regions)); @@ -47,11 +47,9 @@ fn write_operation(mut builder: hugr_capnp::operation::Builder, operation: &mode model::Operation::TailLoop => builder.set_tail_loop(()), model::Operation::Conditional => builder.set_conditional(()), model::Operation::Tag { tag } => builder.set_tag(*tag), - model::Operation::Custom { operation } => { - write_global_ref(builder.init_custom(), operation) - } + model::Operation::Custom { operation } => builder.set_custom(operation.0), model::Operation::CustomFull { operation } => { - write_global_ref(builder.init_custom_full(), operation) + builder.set_custom_full(operation.0); } model::Operation::CallFunc { func } => builder.set_call_func(func.0), model::Operation::LoadFunc { func } => builder.set_load_func(func.0), @@ -100,6 +98,10 @@ fn write_operation(mut builder: hugr_capnp::operation::Builder, operation: &mode builder.set_type(decl.r#type.0); } + model::Operation::Import { name } => { + builder.set_import(*name); + } + model::Operation::Invalid => builder.set_invalid(()), } } @@ -113,31 +115,6 @@ fn write_param(mut builder: hugr_capnp::param::Builder, param: &model::Param) { }); } -fn write_global_ref(mut builder: hugr_capnp::global_ref::Builder, global_ref: &model::GlobalRef) { - match global_ref { - model::GlobalRef::Direct(node) => builder.set_node(node.0), - model::GlobalRef::Named(name) => builder.set_named(name), - } -} - -fn write_link_ref(mut builder: hugr_capnp::link_ref::Builder, link_ref: &model::LinkRef) { - match link_ref { - model::LinkRef::Id(id) => builder.set_id(id.0), - model::LinkRef::Named(name) => builder.set_named(name), - } -} - -fn write_local_ref(mut builder: hugr_capnp::local_ref::Builder, local_ref: &model::LocalRef) { - match local_ref { - model::LocalRef::Index(node, index) => { - let mut builder = builder.init_direct(); - builder.set_node(node.0); - builder.set_index(*index); - } 
- model::LocalRef::Named(name) => builder.set_named(name), - } -} - fn write_meta_item(mut builder: hugr_capnp::meta_item::Builder, meta_item: &model::MetaItem) { builder.set_name(meta_item.name); builder.set_value(meta_item.value.0) @@ -150,11 +127,20 @@ fn write_region(mut builder: hugr_capnp::region::Builder, region: &model::Region model::RegionKind::Module => hugr_capnp::RegionKind::Module, }); - write_list!(builder, init_sources, write_link_ref, region.sources); - write_list!(builder, init_targets, write_link_ref, region.targets); + let _ = builder.set_sources(model::LinkIndex::unwrap_slice(region.sources)); + let _ = builder.set_targets(model::LinkIndex::unwrap_slice(region.targets)); let _ = builder.set_children(model::NodeId::unwrap_slice(region.children)); write_list!(builder, init_meta, write_meta_item, region.meta); builder.set_signature(region.signature.map_or(0, |t| t.0 + 1)); + + if let Some(scope) = &region.scope { + write_region_scope(builder.init_scope(), scope); + } +} + +fn write_region_scope(mut builder: hugr_capnp::region_scope::Builder, scope: &model::RegionScope) { + builder.set_links(scope.links); + builder.set_ports(scope.ports); } fn write_term(mut builder: hugr_capnp::term::Builder, term: &model::Term) { @@ -163,7 +149,11 @@ fn write_term(mut builder: hugr_capnp::term::Builder, term: &model::Term) { model::Term::Type => builder.set_runtime_type(()), model::Term::StaticType => builder.set_static_type(()), model::Term::Constraint => builder.set_constraint(()), - model::Term::Var(local_ref) => write_local_ref(builder.init_variable(), local_ref), + model::Term::Var(model::VarId(node, index)) => { + let mut builder = builder.init_variable(); + builder.set_variable_node(node.0); + builder.set_variable_index(*index); + } model::Term::ListType { item_type } => builder.set_list_type(item_type.0), model::Term::Str(value) => builder.set_string(value), model::Term::StrType => builder.set_string_type(()), @@ -175,15 +165,15 @@ fn write_term(mut builder:
hugr_capnp::term::Builder, term: &model::Term) { model::Term::Control { values } => builder.set_control(values.0), model::Term::ControlType => builder.set_control_type(()), - model::Term::Apply { global, args } => { + model::Term::Apply { symbol, args } => { let mut builder = builder.init_apply(); - write_global_ref(builder.reborrow().init_global(), global); + builder.set_symbol(symbol.0); let _ = builder.set_args(model::TermId::unwrap_slice(args)); } - model::Term::ApplyFull { global, args } => { + model::Term::ApplyFull { symbol, args } => { let mut builder = builder.init_apply_full(); - write_global_ref(builder.reborrow().init_global(), global); + builder.set_symbol(symbol.0); let _ = builder.set_args(model::TermId::unwrap_slice(args)); } diff --git a/hugr-model/src/v0/mod.rs b/hugr-model/src/v0/mod.rs index 2b0dc1eaf..ad3733079 100644 --- a/hugr-model/src/v0/mod.rs +++ b/hugr-model/src/v0/mod.rs @@ -91,6 +91,7 @@ use smol_str::SmolStr; use thiserror::Error; pub mod binary; +pub mod scope; pub mod text; macro_rules! define_index { @@ -132,7 +133,7 @@ macro_rules! define_index { } define_index! { - /// Index of a node in a hugr graph. + /// Id of a node in a hugr graph. #[derive(Debug, derive_more::Display, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)] pub struct NodeId(pub u32); } @@ -140,21 +141,31 @@ define_index! { define_index! { /// Index of a link in a hugr graph. #[derive(Debug, derive_more::Display, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)] - pub struct LinkId(pub u32); + pub struct LinkIndex(pub u32); } define_index! { - /// Index of a region in a hugr graph. + /// Id of a region in a hugr graph. #[derive(Debug, derive_more::Display, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)] pub struct RegionId(pub u32); } define_index! { - /// Index of a term in a hugr graph. + /// Id of a term in a hugr graph. 
#[derive(Debug, derive_more::Display, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)] pub struct TermId(pub u32); } +/// The id of a link consisting of its region and the link index. +#[derive(Debug, derive_more::Display, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[display("{_0}#{_1}")] +pub struct LinkId(pub RegionId, pub LinkIndex); + +/// The id of a variable consisting of its node and the variable index. +#[derive(Debug, derive_more::Display, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[display("{_0}#{_1}")] +pub struct VarId(pub NodeId, pub VarIndex); + /// A module consisting of a hugr graph together with terms. #[derive(Debug, Clone, Default, PartialEq, Eq, Hash)] pub struct Module<'a> { @@ -233,9 +244,9 @@ pub struct Node<'a> { /// The operation that the node performs. pub operation: Operation<'a>, /// The input ports of the node. - pub inputs: &'a [LinkRef<'a>], + pub inputs: &'a [LinkIndex], /// The output ports of the node. - pub outputs: &'a [LinkRef<'a>], + pub outputs: &'a [LinkIndex], /// The parameters of the node. pub params: &'a [TermId], /// The regions of the node. @@ -290,8 +301,8 @@ pub enum Operation<'a> { /// becomes known by resolving the reference, the node can be transformed into a [`Operation::CustomFull`] /// by inferring terms for the implicit parameters or at least filling them in with a wildcard term. Custom { - /// The name of the custom operation. - operation: GlobalRef<'a>, + /// The symbol of the custom operation. + operation: NodeId, }, /// Custom operation with full parameters. /// @@ -299,8 +310,8 @@ pub enum Operation<'a> { /// Since this can be tedious to write, the [`Operation::Custom`] variant can be used to indicate that /// the implicit parameters should be inferred. CustomFull { - /// The name of the custom operation. - operation: GlobalRef<'a>, + /// The symbol of the custom operation. + operation: NodeId, }, /// Alias definitions. 
DefineAlias { @@ -358,17 +369,39 @@ pub enum Operation<'a> { /// The declaration of the operation. decl: &'a OperationDecl<'a>, }, + + /// Import a symbol. + Import { + /// The name of the symbol to be imported. + name: &'a str, + }, +} + +impl<'a> Operation<'a> { + /// Returns the symbol introduced by the operation, if any. + pub fn symbol(&self) -> Option<&'a str> { + match self { + Operation::DefineFunc { decl } => Some(decl.name), + Operation::DeclareFunc { decl } => Some(decl.name), + Operation::DefineAlias { decl, .. } => Some(decl.name), + Operation::DeclareAlias { decl } => Some(decl.name), + Operation::DeclareConstructor { decl } => Some(decl.name), + Operation::DeclareOperation { decl } => Some(decl.name), + Operation::Import { name } => Some(name), + _ => None, + } + } } /// A region in the hugr. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)] pub struct Region<'a> { /// The kind of the region. See [`RegionKind`] for details. pub kind: RegionKind, /// The source ports of the region. - pub sources: &'a [LinkRef<'a>], + pub sources: &'a [LinkIndex], /// The target ports of the region. - pub targets: &'a [LinkRef<'a>], + pub targets: &'a [LinkIndex], /// The nodes in the region. The order of the nodes is not significant. pub children: &'a [NodeId], /// The metadata attached to the region. @@ -377,12 +410,34 @@ pub struct Region<'a> { /// /// Can be `None` to indicate that the region signature should be inferred. pub signature: Option, + /// Information about the scope defined by this region, if the region is closed. + pub scope: Option, +} + +/// Information about the scope defined by a closed region. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct RegionScope { + /// The number of links in the scope. + pub links: u32, + /// The number of ports in the scope. + pub ports: u32, +} + +/// Type to indicate whether scopes are open or closed. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)] +pub enum ScopeClosure { + /// A scope that is open and therefore not isolated from its parent scope. + #[default] + Open, + /// A scope that is closed and therefore isolated from its parent scope. + Closed, } /// The kind of a region. -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)] pub enum RegionKind { /// Data flow region. + #[default] DataFlow = 0, /// Control flow region. ControlFlow = 1, @@ -449,63 +504,8 @@ pub struct MetaItem<'a> { pub value: TermId, } -/// A reference to a global variable. -/// -/// Global variables are defined in nodes. -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum GlobalRef<'a> { - /// Reference to the global that is defined by the given node. - Direct(NodeId), - /// Reference to the global with the given name. - Named(&'a str), -} - -impl std::fmt::Display for GlobalRef<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - GlobalRef::Direct(id) => write!(f, ":{}", id.index()), - GlobalRef::Named(name) => write!(f, "{}", name), - } - } -} - -/// A reference to a local variable. -/// -/// Local variables are defined as parameters to nodes. -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum LocalRef<'a> { - /// Reference to the local variable by its parameter index and its defining node. - Index(NodeId, u16), - /// Reference to the local variable by its name. - Named(&'a str), -} - -impl std::fmt::Display for LocalRef<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - LocalRef::Index(node, index) => write!(f, "?:{}:{}", node.index(), index), - LocalRef::Named(name) => write!(f, "?{}", name), - } - } -} - -/// A reference to a link. 
-#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum LinkRef<'a> { - /// Reference to the link by its id. - Id(LinkId), - /// Reference to the link by its name. - Named(&'a str), -} - -impl std::fmt::Display for LinkRef<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - LinkRef::Id(id) => write!(f, "%:{})", id.index()), - LinkRef::Named(name) => write!(f, "%{}", name), - } - } -} +/// An index of a variable within a node's parameter list. +pub type VarIndex = u16; /// A term in the compile time meta language. #[derive(Debug, Clone, PartialEq, Eq, Hash, Default)] @@ -530,7 +530,7 @@ pub enum Term<'a> { Constraint, /// A local variable. - Var(LocalRef<'a>), + Var(VarId), /// A symbolic function application. /// @@ -540,8 +540,8 @@ pub enum Term<'a> { /// /// `(GLOBAL ARG-0 ... ARG-n)` Apply { - /// Reference to the global declaration to apply. - global: GlobalRef<'a>, + /// Reference to the symbol to apply. + symbol: NodeId, /// Arguments to the function, covering only the explicit parameters. args: &'a [TermId], }, @@ -553,8 +553,8 @@ pub enum Term<'a> { /// /// `(@GLOBAL ARG-0 ... ARG-n)` ApplyFull { - /// Reference to the global declaration to apply. - global: GlobalRef<'a>, + /// Reference to the symbol to apply. + symbol: NodeId, /// Arguments to the function, covering both implicit and explicit parameters. args: &'a [TermId], }, @@ -718,13 +718,12 @@ pub enum ModelError { /// There is a reference to a region that does not exist. #[error("region not found: {0}")] RegionNotFound(RegionId), - /// There is a local reference that does not resolve. - #[error("local variable invalid: {0}")] - InvalidLocal(String), - /// There is a global reference that does not resolve to a node - /// that defines a global variable. - #[error("global variable invalid: {0}")] - InvalidGlobal(String), + /// Invalid variable reference. 
+ #[error("variable {0} invalid")] + InvalidVar(VarId), + /// Invalid symbol reference. + #[error("symbol reference {0} invalid")] + InvalidSymbol(NodeId), /// The model contains an operation in a place where it is not allowed. #[error("unexpected operation on node: {0}")] UnexpectedOperation(NodeId), diff --git a/hugr-model/src/v0/scope/link.rs b/hugr-model/src/v0/scope/link.rs new file mode 100644 index 000000000..e45bbb345 --- /dev/null +++ b/hugr-model/src/v0/scope/link.rs @@ -0,0 +1,125 @@ +use std::hash::{BuildHasherDefault, Hash}; + +use fxhash::FxHasher; +use indexmap::IndexSet; + +use crate::v0::{LinkIndex, RegionId}; + +type FxIndexSet<K> = IndexSet<K, BuildHasherDefault<FxHasher>>; + +/// Table for tracking links between ports. +/// +/// Two ports are connected when they share the same link. Links are named and +/// scoped via closed regions. Links from one closed region are not visible +/// in another. Open regions are considered to form the same scope as their +/// parent region. Links do not have a unique point of declaration. +/// +/// The link table keeps track of an association between a key of type `K` and +/// the link indices within each closed region. When resolving links from a text format, +/// `K` is the name of the link as a string slice. However the link table +/// is also useful in other contexts where the key is not a string when constructing +/// a module from a different representation.
+/// +/// # Examples +/// +/// ``` +/// # pub use hugr_model::v0::RegionId; +/// # pub use hugr_model::v0::scope::LinkTable; +/// let mut links = LinkTable::new(); +/// links.enter(RegionId(0)); +/// let foo_0 = links.use_link("foo"); +/// let bar_0 = links.use_link("bar"); +/// assert_eq!(foo_0, links.use_link("foo")); +/// assert_eq!(bar_0, links.use_link("bar")); +/// let (num_links, num_ports) = links.exit(); +/// assert_eq!(num_links, 2); +/// assert_eq!(num_ports, 4); +/// ``` +#[derive(Debug, Clone)] +pub struct LinkTable { + /// The set of links in the currently active region and all parent regions. + /// + /// The order in this index set is the order in which links were added to the table. + /// This is used to efficiently remove all links from the current region when exiting a scope. + links: FxIndexSet<(RegionId, K)>, + + /// The stack of scopes that are currently open. + scopes: Vec, +} + +impl LinkTable +where + K: Copy + Eq + Hash, +{ + /// Create a new empty link table. + pub fn new() -> Self { + Self { + links: FxIndexSet::default(), + scopes: Vec::new(), + } + } + + /// Enter a new scope for the given closed region. + pub fn enter(&mut self, region: RegionId) { + self.scopes.push(LinkScope { + link_stack: self.links.len(), + link_count: 0, + port_count: 0, + region, + }); + } + + /// Exit a previously entered scope, returning the number of links and ports in the scope. + pub fn exit(&mut self) -> (u32, u32) { + let scope = self.scopes.pop().unwrap(); + self.links.drain(scope.link_stack..); + debug_assert_eq!(self.links.len(), scope.link_stack); + (scope.link_count, scope.port_count) + } + + /// Resolve a link key to a link index, adding one more port to the current scope. + /// + /// If the key has not been used in the current scope before, it will be added to the link table. + /// + /// # Panics + /// + /// Panics if there are no open scopes. 
+ pub fn use_link(&mut self, key: K) -> LinkIndex { + let scope = self.scopes.last_mut().unwrap(); + let (map_index, inserted) = self.links.insert_full((scope.region, key)); + + if inserted { + scope.link_count += 1; + } + + scope.port_count += 1; + LinkIndex::new(map_index - scope.link_stack) + } + + /// Reset the link table to an empty state while preserving allocated memory. + pub fn clear(&mut self) { + self.links.clear(); + self.scopes.clear(); + } +} + +impl Default for LinkTable +where + K: Copy + Eq + Hash, +{ + fn default() -> Self { + Self::new() + } +} + +#[derive(Debug, Clone)] +struct LinkScope { + /// The length of `LinkTable::links` when the scope was opened. + link_stack: usize, + /// The number of links in this scope. + link_count: u32, + /// The number of ports in this scope. + port_count: u32, + /// The region that introduces this scope. + region: RegionId, +} diff --git a/hugr-model/src/v0/scope/mod.rs b/hugr-model/src/v0/scope/mod.rs new file mode 100644 index 000000000..546d97d61 --- /dev/null +++ b/hugr-model/src/v0/scope/mod.rs @@ -0,0 +1,8 @@ +//! Utilities for working with scoped symbols, variables and links. +mod link; +mod symbol; +mod vars; + +pub use link::LinkTable; +pub use symbol::{DuplicateSymbolError, SymbolTable, UnknownSymbolError}; +pub use vars::{DuplicateVarError, UnknownVarError, VarTable}; diff --git a/hugr-model/src/v0/scope/symbol.rs b/hugr-model/src/v0/scope/symbol.rs new file mode 100644 index 000000000..863a25751 --- /dev/null +++ b/hugr-model/src/v0/scope/symbol.rs @@ -0,0 +1,198 @@ +use std::{borrow::Cow, hash::BuildHasherDefault}; + +use fxhash::FxHasher; +use indexmap::IndexMap; +use thiserror::Error; + +use crate::v0::{NodeId, RegionId}; + +type FxIndexMap = IndexMap>; + +/// Symbol binding table that keeps track of symbol resolution and scoping. +/// +/// Nodes may introduce a symbol so that other parts of the IR can refer to the +/// node. Symbols have an associated name and are scoped via regions. 
A symbol +/// can shadow another symbol with the same name from an outer region, but +/// within any single region each symbol name must be unique. +/// +/// When a symbol is referred to directly by the id of the node, the symbol must +/// be in scope at the point of reference as if the reference was by name. This +/// guarantees that transformations between directly indexed and named formats +/// are always valid. +/// +/// # Examples +/// +/// ``` +/// # pub use hugr_model::v0::{NodeId, RegionId}; +/// # pub use hugr_model::v0::scope::SymbolTable; +/// let mut symbols = SymbolTable::new(); +/// symbols.enter(RegionId(0)); +/// symbols.insert("foo", NodeId(0)).unwrap(); +/// assert_eq!(symbols.resolve("foo").unwrap(), NodeId(0)); +/// symbols.enter(RegionId(1)); +/// assert_eq!(symbols.resolve("foo").unwrap(), NodeId(0)); +/// symbols.insert("foo", NodeId(1)).unwrap(); +/// assert_eq!(symbols.resolve("foo").unwrap(), NodeId(1)); +/// assert!(!symbols.is_visible(NodeId(0))); +/// symbols.exit(); +/// assert_eq!(symbols.resolve("foo").unwrap(), NodeId(0)); +/// assert!(symbols.is_visible(NodeId(0))); +/// assert!(!symbols.is_visible(NodeId(1))); +/// ``` +#[derive(Debug, Clone, Default)] +pub struct SymbolTable<'a> { + symbols: FxIndexMap<&'a str, BindingIndex>, + bindings: FxIndexMap, + scopes: FxIndexMap, +} + +impl<'a> SymbolTable<'a> { + /// Create a new symbol table. + pub fn new() -> Self { + Self::default() + } + + /// Enter a new scope for the given region. + pub fn enter(&mut self, region: RegionId) { + self.scopes.insert( + region, + Scope { + binding_stack: self.bindings.len(), + }, + ); + } + + /// Exit a previously entered scope. + /// + /// # Panics + /// + /// Panics if there are no remaining open scopes. 
+ pub fn exit(&mut self) { + let (_, scope) = self.scopes.pop().unwrap(); + + for _ in scope.binding_stack..self.bindings.len() { + let (_, binding) = self.bindings.pop().unwrap(); + + if let Some(shadows) = binding.shadows { + self.symbols[binding.symbol_index] = shadows; + } else { + let last = self.symbols.pop(); + debug_assert_eq!(last.unwrap().1, self.bindings.len()); + } + } + } + + /// Insert a new symbol into the current scope. + /// + /// # Errors + /// + /// Returns an error if the symbol is already defined in the current scope. + /// In the case of an error the table remains unchanged. + /// + /// # Panics + /// + /// Panics if there is no current scope. + pub fn insert(&mut self, name: &'a str, node: NodeId) -> Result<(), DuplicateSymbolError> { + let scope_depth = self.scopes.len() as u16 - 1; + let (symbol_index, shadowed) = self.symbols.insert_full(name, self.bindings.len()); + + if let Some(shadowed) = shadowed { + let (shadowed_node, shadowed_binding) = self.bindings.get_index(shadowed).unwrap(); + if shadowed_binding.scope_depth == scope_depth { + self.symbols.insert(name, shadowed); + return Err(DuplicateSymbolError(name.into(), node, *shadowed_node)); + } + } + + self.bindings.insert( + node, + Binding { + scope_depth, + shadows: shadowed, + symbol_index, + }, + ); + + Ok(()) + } + + /// Check whether a symbol is currently visible in the current scope. + pub fn is_visible(&self, node: NodeId) -> bool { + let Some(binding) = self.bindings.get(&node) else { + return false; + }; + + // Check that the symbol has not been shadowed at this point. + self.symbols[binding.symbol_index] == binding.symbol_index + } + + /// Tries to resolve a symbol name in the current scope. + pub fn resolve(&self, name: &'a str) -> Result { + let index = *self + .symbols + .get(name) + .ok_or(UnknownSymbolError(name.into()))?; + + // NOTE: The unwrap is safe because the `symbols` map + // points to valid indices in the `bindings` map. 
+ let (node, _) = self.bindings.get_index(index).unwrap(); + Ok(*node) + } + + /// Returns the depth of the given region, if it corresponds to a currently open scope. + pub fn region_to_depth(&self, region: RegionId) -> Option<ScopeDepth> { + Some(self.scopes.get_index_of(&region)? as _) + } + + /// Returns the region corresponding to the scope at the given depth. + pub fn depth_to_region(&self, depth: ScopeDepth) -> Option<RegionId> { + let (region, _) = self.scopes.get_index(depth as _)?; + Some(*region) + } + + /// Resets the symbol table to its initial state while maintaining its + /// allocated memory. + pub fn clear(&mut self) { + self.symbols.clear(); + self.bindings.clear(); + self.scopes.clear(); + } +} + +#[derive(Debug, Clone, Copy)] +struct Binding { + /// The depth of the scope in which this binding is defined. + scope_depth: ScopeDepth, + + /// The index of the binding that is shadowed by this one, if any. + shadows: Option<BindingIndex>, + + /// The index of this binding's symbol in the symbol table. + /// + /// The symbol table always points to the currently visible binding for a + /// symbol. Therefore this index is only valid if this binding is not shadowed. + /// In particular, we detect shadowing by checking if the entry in the symbol + /// table at this index does indeed point to this binding. + symbol_index: SymbolIndex, +} + +#[derive(Debug, Clone, Copy)] +struct Scope { + /// The length of the `bindings` stack when this scope was entered. + binding_stack: usize, +} + +type BindingIndex = usize; +type SymbolIndex = usize; + +pub type ScopeDepth = u16; + +/// Error that occurs when trying to resolve an unknown symbol. +#[derive(Debug, Clone, Error)] +#[error("symbol name `{0}` not found in this scope")] +pub struct UnknownSymbolError<'a>(pub Cow<'a, str>); + +/// Error that occurs when trying to introduce a symbol that is already defined in the current scope.
+#[derive(Debug, Clone, Error)] +#[error("symbol `{0}` is already defined in this scope")] +pub struct DuplicateSymbolError<'a>(pub Cow<'a, str>, pub NodeId, pub NodeId); diff --git a/hugr-model/src/v0/scope/vars.rs b/hugr-model/src/v0/scope/vars.rs new file mode 100644 index 000000000..596e46809 --- /dev/null +++ b/hugr-model/src/v0/scope/vars.rs @@ -0,0 +1,151 @@ +use fxhash::FxHasher; +use indexmap::IndexSet; +use std::hash::BuildHasherDefault; +use thiserror::Error; + +use crate::v0::{NodeId, VarId}; + +type FxIndexSet = IndexSet>; + +/// Table for keeping track of node parameters. +/// +/// Variables refer to the parameters of a node which introduces a symbol. +/// Variables have an associated name and are scoped via nodes. The types of +/// parameters of a node may only refer to earlier parameters in the same node +/// in the order they are defined. A variable name must be unique within a +/// single node. Each node that introduces a symbol introduces a new isolated +/// scope for variables. +/// +/// # Examples +/// +/// ``` +/// # pub use hugr_model::v0::{NodeId, VarId}; +/// # pub use hugr_model::v0::scope::VarTable; +/// let mut vars = VarTable::new(); +/// vars.enter(NodeId(0)); +/// vars.insert("foo").unwrap(); +/// assert_eq!(vars.resolve("foo").unwrap(), VarId(NodeId(0), 0)); +/// assert!(!vars.is_visible(VarId(NodeId(0), 1))); +/// vars.insert("bar").unwrap(); +/// assert!(vars.is_visible(VarId(NodeId(0), 1))); +/// assert_eq!(vars.resolve("bar").unwrap(), VarId(NodeId(0), 1)); +/// vars.enter(NodeId(1)); +/// assert!(vars.resolve("foo").is_err()); +/// assert!(!vars.is_visible(VarId(NodeId(0), 0))); +/// vars.exit(); +/// assert_eq!(vars.resolve("foo").unwrap(), VarId(NodeId(0), 0)); +/// assert!(vars.is_visible(VarId(NodeId(0), 0))); +/// ``` +#[derive(Debug, Clone, Default)] +pub struct VarTable<'a> { + /// The set of variables in the currently active node and all its parent nodes. 
+ /// + /// The order in this index set is the order in which variables were added to the table. + /// This is used to efficiently remove all variables from the current node when exiting a scope. + vars: FxIndexSet<(NodeId, &'a str)>, + /// The stack of scopes that are currently open. + scopes: Vec, +} + +impl<'a> VarTable<'a> { + /// Create a new empty variable table. + pub fn new() -> Self { + Self::default() + } + + /// Enter a new scope for the given node. + pub fn enter(&mut self, node: NodeId) { + self.scopes.push(VarScope { + node, + var_count: 0, + var_stack: self.vars.len(), + }) + } + + /// Exit a previously entered scope. + /// + /// # Panics + /// + /// Panics if there are no open scopes. + pub fn exit(&mut self) { + let scope = self.scopes.pop().unwrap(); + self.vars.drain(scope.var_stack..); + } + + /// Resolve a variable name to a node and variable index. + /// + /// # Errors + /// + /// Returns an error if the variable is not defined in the current scope. + /// + /// # Panics + /// + /// Panics if there are no open scopes. + pub fn resolve(&self, name: &'a str) -> Result> { + let scope = self.scopes.last().unwrap(); + let set_index = self + .vars + .get_index_of(&(scope.node, name)) + .ok_or(UnknownVarError(scope.node, name))?; + let var_index = (set_index - scope.var_stack) as u16; + Ok(VarId(scope.node, var_index)) + } + + /// Check if a variable is visible in the current scope. + /// + /// # Panics + /// + /// Panics if there are no open scopes. + pub fn is_visible(&self, var: VarId) -> bool { + let scope = self.scopes.last().unwrap(); + scope.node == var.0 && var.1 < scope.var_count + } + + /// Insert a new variable into the current scope. + /// + /// # Errors + /// + /// Returns an error if the variable is already defined in the current scope. + /// + /// # Panics + /// + /// Panics if there are no open scopes. 
+ pub fn insert(&mut self, name: &'a str) -> Result> { + let scope = self.scopes.last_mut().unwrap(); + let inserted = self.vars.insert((scope.node, name)); + + if !inserted { + return Err(DuplicateVarError(scope.node, name)); + } + + let var_index = scope.var_count; + scope.var_count += 1; + Ok(VarId(scope.node, var_index)) + } + + /// Reset the variable table to an empty state while preserving the allocations. + pub fn clear(&mut self) { + self.vars.clear(); + self.scopes.clear(); + } +} + +#[derive(Debug, Clone)] +struct VarScope { + /// The node that introduces this scope. + node: NodeId, + /// The number of variables in this scope. + var_count: u16, + /// The length of `VarTable::vars` when the scope was opened. + var_stack: usize, +} + +/// Error that occurs when a node defines two parameters with the same name. +#[derive(Debug, Clone, Error)] +#[error("node {0} already has a variable named `{1}`")] +pub struct DuplicateVarError<'a>(NodeId, &'a str); + +/// Error that occurs when a variable is not defined in the current scope. +#[derive(Debug, Clone, Error)] +#[error("can not resolve variable `{1}` in node {0}")] +pub struct UnknownVarError<'a>(NodeId, &'a str); diff --git a/hugr-model/src/v0/text/hugr.pest b/hugr-model/src/v0/text/hugr.pest index fc52b8271..4fd34f223 100644 --- a/hugr-model/src/v0/text/hugr.pest +++ b/hugr-model/src/v0/text/hugr.pest @@ -34,6 +34,7 @@ node = { | node_tail_loop | node_cond | node_tag + | node_import | node_custom } @@ -51,6 +52,7 @@ node_declare_operation = { "(" ~ "declare-operation" ~ operation_header ~ meta* node_tail_loop = { "(" ~ "tail-loop" ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } node_cond = { "(" ~ "cond" ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } node_tag = { "(" ~ "tag" ~ tag ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } +node_import = { "(" ~ "import" ~ symbol ~ meta* ~ ")" } node_custom = { "(" ~ (term_apply | term_apply_full) ~ port_lists? ~ signature? 
~ meta* ~ region* ~ ")" } signature = { "(" ~ "signature" ~ term ~ ")" } diff --git a/hugr-model/src/v0/text/parse.rs b/hugr-model/src/v0/text/parse.rs index 8527f1a00..4ad77d914 100644 --- a/hugr-model/src/v0/text/parse.rs +++ b/hugr-model/src/v0/text/parse.rs @@ -1,4 +1,5 @@ use bumpalo::{collections::String as BumpString, collections::Vec as BumpVec, Bump}; +use fxhash::FxHashMap; use pest::{ iterators::{Pair, Pairs}, Parser, RuleType, @@ -6,9 +7,10 @@ use pest::{ use thiserror::Error; use crate::v0::{ - AliasDecl, ConstructorDecl, ExtSetPart, FuncDecl, GlobalRef, LinkRef, ListPart, LocalRef, - MetaItem, Module, Node, NodeId, Operation, OperationDecl, Param, ParamSort, Region, RegionId, - RegionKind, Term, TermId, + scope::{LinkTable, SymbolTable, UnknownSymbolError, VarTable}, + AliasDecl, ConstructorDecl, ExtSetPart, FuncDecl, LinkIndex, ListPart, MetaItem, Module, Node, + NodeId, Operation, OperationDecl, Param, ParamSort, Region, RegionId, RegionKind, RegionScope, + ScopeClosure, Term, TermId, }; mod pest_parser { @@ -50,12 +52,20 @@ pub fn parse<'a>(input: &'a str, bump: &'a Bump) -> Result, Par struct ParseContext<'a> { module: Module<'a>, bump: &'a Bump, + vars: VarTable<'a>, + links: LinkTable<&'a str>, + symbols: SymbolTable<'a>, + implicit_imports: FxHashMap<&'a str, NodeId>, } impl<'a> ParseContext<'a> { fn new(bump: &'a Bump) -> Self { Self { module: Module::default(), + symbols: SymbolTable::default(), + links: LinkTable::default(), + vars: VarTable::default(), + implicit_imports: FxHashMap::default(), bump, } } @@ -63,20 +73,38 @@ impl<'a> ParseContext<'a> { fn parse_module(&mut self, pair: Pair<'a, Rule>) -> ParseResult<()> { debug_assert_eq!(pair.as_rule(), Rule::module); let mut inner = pair.into_inner(); + + self.module.root = self.module.insert_region(Region::default()); + self.symbols.enter(self.module.root); + self.links.enter(self.module.root); + + // TODO: What scope does the metadata live in? 
let meta = self.parse_meta(&mut inner)?; + let explicit_children = self.parse_nodes(&mut inner)?; - let children = self.parse_nodes(&mut inner)?; + let mut children = BumpVec::with_capacity_in( + explicit_children.len() + self.implicit_imports.len(), + self.bump, + ); + children.extend(explicit_children); + children.extend(self.implicit_imports.drain().map(|(_, node)| node)); + let children = children.into_bump_slice(); + + let (link_count, port_count) = self.links.exit(); + self.symbols.exit(); - let root_region = self.module.insert_region(Region { + self.module.regions[self.module.root.index()] = Region { kind: RegionKind::Module, sources: &[], targets: &[], children, meta, signature: None, - }); - - self.module.root = root_region; + scope: Some(RegionScope { + links: link_count, + ports: port_count, + }), + }; Ok(()) } @@ -87,143 +115,195 @@ impl<'a> ParseContext<'a> { let rule = pair.as_rule(); let mut inner = pair.into_inner(); - let term = match rule { - Rule::term_wildcard => Term::Wildcard, - Rule::term_type => Term::Type, - Rule::term_static => Term::StaticType, - Rule::term_constraint => Term::Constraint, - Rule::term_str_type => Term::StrType, - Rule::term_nat_type => Term::NatType, - Rule::term_ctrl_type => Term::ControlType, - Rule::term_ext_set_type => Term::ExtSetType, - - Rule::term_var => { - let name_token = inner.next().unwrap(); - let name = name_token.as_str(); - Term::Var(LocalRef::Named(name)) - } - - Rule::term_apply => { - let name = GlobalRef::Named(self.parse_symbol(&mut inner)?); - let mut args = Vec::new(); - - for token in inner { - args.push(self.parse_term(token)?); - } + let term = + match rule { + Rule::term_wildcard => Term::Wildcard, + Rule::term_type => Term::Type, + Rule::term_static => Term::StaticType, + Rule::term_constraint => Term::Constraint, + Rule::term_str_type => Term::StrType, + Rule::term_nat_type => Term::NatType, + Rule::term_ctrl_type => Term::ControlType, + Rule::term_ext_set_type => Term::ExtSetType, + + 
Rule::term_var => { + let name_token = inner.next().unwrap(); + let name = name_token.as_str(); + + let var = self.vars.resolve(name).map_err(|err| { + ParseError::custom(&err.to_string(), name_token.as_span()) + })?; - Term::Apply { - global: name, - args: self.bump.alloc_slice_copy(&args), + Term::Var(var) } - } - Rule::term_apply_full => { - let name = GlobalRef::Named(self.parse_symbol(&mut inner)?); - let mut args = Vec::new(); + Rule::term_apply => { + let symbol = self.parse_symbol_use(&mut inner)?; + let mut args = Vec::new(); - for token in inner { - args.push(self.parse_term(token)?); - } + for token in inner { + args.push(self.parse_term(token)?); + } - Term::ApplyFull { - global: name, - args: self.bump.alloc_slice_copy(&args), + Term::Apply { + symbol, + args: self.bump.alloc_slice_copy(&args), + } } - } - Rule::term_quote => { - let r#type = self.parse_term(inner.next().unwrap())?; - Term::Quote { r#type } - } + Rule::term_apply_full => { + let symbol = self.parse_symbol_use(&mut inner)?; + let mut args = Vec::new(); - Rule::term_list => { - let mut parts = BumpVec::with_capacity_in(inner.len(), self.bump); + for token in inner { + args.push(self.parse_term(token)?); + } - for token in inner { - match token.as_rule() { - Rule::term => parts.push(ListPart::Item(self.parse_term(token)?)), - Rule::spliced_term => { - let term_token = token.into_inner().next().unwrap(); - parts.push(ListPart::Splice(self.parse_term(term_token)?)) - } - _ => unreachable!(), + Term::ApplyFull { + symbol, + args: self.bump.alloc_slice_copy(&args), } } - Term::List { - parts: parts.into_bump_slice(), + Rule::term_quote => { + let r#type = self.parse_term(inner.next().unwrap())?; + Term::Quote { r#type } } - } - Rule::term_list_type => { - let item_type = self.parse_term(inner.next().unwrap())?; - Term::ListType { item_type } - } + Rule::term_list => { + let mut parts = BumpVec::with_capacity_in(inner.len(), self.bump); + + for token in inner { + match token.as_rule() { + 
Rule::term => parts.push(ListPart::Item(self.parse_term(token)?)), + Rule::spliced_term => { + let term_token = token.into_inner().next().unwrap(); + parts.push(ListPart::Splice(self.parse_term(term_token)?)) + } + _ => unreachable!(), + } + } - Rule::term_str => { - let value = self.parse_string(inner.next().unwrap())?; - Term::Str(value) - } + Term::List { + parts: parts.into_bump_slice(), + } + } - Rule::term_nat => { - let value = inner.next().unwrap().as_str().parse().unwrap(); - Term::Nat(value) - } + Rule::term_list_type => { + let item_type = self.parse_term(inner.next().unwrap())?; + Term::ListType { item_type } + } - Rule::term_ext_set => { - let mut parts = BumpVec::with_capacity_in(inner.len(), self.bump); + Rule::term_str => { + let value = self.parse_string(inner.next().unwrap())?; + Term::Str(value) + } - for token in inner { - match token.as_rule() { - Rule::ext_name => { - parts.push(ExtSetPart::Extension(self.bump.alloc_str(token.as_str()))) - } - Rule::spliced_term => { - let term_token = token.into_inner().next().unwrap(); - parts.push(ExtSetPart::Splice(self.parse_term(term_token)?)) + Rule::term_nat => { + let value = inner.next().unwrap().as_str().parse().unwrap(); + Term::Nat(value) + } + + Rule::term_ext_set => { + let mut parts = BumpVec::with_capacity_in(inner.len(), self.bump); + + for token in inner { + match token.as_rule() { + Rule::ext_name => parts + .push(ExtSetPart::Extension(self.bump.alloc_str(token.as_str()))), + Rule::spliced_term => { + let term_token = token.into_inner().next().unwrap(); + parts.push(ExtSetPart::Splice(self.parse_term(term_token)?)) + } + _ => unreachable!(), } - _ => unreachable!(), + } + + Term::ExtSet { + parts: parts.into_bump_slice(), } } - Term::ExtSet { - parts: parts.into_bump_slice(), + Rule::term_adt => { + let variants = self.parse_term(inner.next().unwrap())?; + Term::Adt { variants } } - } - Rule::term_adt => { - let variants = self.parse_term(inner.next().unwrap())?; - Term::Adt { variants } - 
} + Rule::term_func_type => { + let inputs = self.parse_term(inner.next().unwrap())?; + let outputs = self.parse_term(inner.next().unwrap())?; + let extensions = self.parse_term(inner.next().unwrap())?; + Term::FuncType { + inputs, + outputs, + extensions, + } + } - Rule::term_func_type => { - let inputs = self.parse_term(inner.next().unwrap())?; - let outputs = self.parse_term(inner.next().unwrap())?; - let extensions = self.parse_term(inner.next().unwrap())?; - Term::FuncType { - inputs, - outputs, - extensions, + Rule::term_ctrl => { + let values = self.parse_term(inner.next().unwrap())?; + Term::Control { values } } - } - Rule::term_ctrl => { - let values = self.parse_term(inner.next().unwrap())?; - Term::Control { values } - } + Rule::term_non_linear => { + let term = self.parse_term(inner.next().unwrap())?; + Term::NonLinearConstraint { term } + } - Rule::term_non_linear => { - let term = self.parse_term(inner.next().unwrap())?; - Term::NonLinearConstraint { term } - } + r => unreachable!("term: {:?}", r), + }; + + Ok(self.module.insert_term(term)) + } - r => unreachable!("term: {:?}", r), + fn parse_node_shallow(&mut self, pair: Pair<'a, Rule>) -> ParseResult { + debug_assert_eq!(pair.as_rule(), Rule::node); + let pair = pair.into_inner().next().unwrap(); + let span = pair.as_span(); + let rule = pair.as_rule(); + let mut inner = pair.into_inner(); + + let symbol = match rule { + Rule::node_define_func => { + let mut func_header = inner.next().unwrap().into_inner(); + Some(self.parse_symbol(&mut func_header)?) + } + Rule::node_declare_func => { + let mut func_header = inner.next().unwrap().into_inner(); + Some(self.parse_symbol(&mut func_header)?) + } + Rule::node_define_alias => { + let mut alias_header = inner.next().unwrap().into_inner(); + Some(self.parse_symbol(&mut alias_header)?) + } + Rule::node_declare_alias => { + let mut alias_header = inner.next().unwrap().into_inner(); + Some(self.parse_symbol(&mut alias_header)?) 
+ } + Rule::node_declare_ctr => { + let mut ctr_header = inner.next().unwrap().into_inner(); + Some(self.parse_symbol(&mut ctr_header)?) + } + Rule::node_declare_operation => { + let mut op_header = inner.next().unwrap().into_inner(); + Some(self.parse_symbol(&mut op_header)?) + } + Rule::node_import => Some(self.parse_symbol(&mut inner)?), + _ => None, }; - Ok(self.module.insert_term(term)) + let node = self.module.insert_node(Node::default()); + + if let Some(symbol) = symbol { + self.symbols + .insert(symbol, node) + .map_err(|err| ParseError::custom(&err.to_string(), span))?; + } + + Ok(node) } - fn parse_node(&mut self, pair: Pair<'a, Rule>) -> ParseResult { + fn parse_node_deep(&mut self, pair: Pair<'a, Rule>, node: NodeId) -> ParseResult> { debug_assert_eq!(pair.as_rule(), Rule::node); let pair = pair.into_inner().next().unwrap(); let rule = pair.as_rule(); @@ -236,7 +316,7 @@ impl<'a> ParseContext<'a> { let outputs = self.parse_port_list(&mut inner)?; let signature = self.parse_signature(&mut inner)?; let meta = self.parse_meta(&mut inner)?; - let regions = self.parse_regions(&mut inner)?; + let regions = self.parse_regions(&mut inner, ScopeClosure::Open)?; Node { operation: Operation::Dfg, inputs, @@ -253,7 +333,7 @@ impl<'a> ParseContext<'a> { let outputs = self.parse_port_list(&mut inner)?; let signature = self.parse_signature(&mut inner)?; let meta = self.parse_meta(&mut inner)?; - let regions = self.parse_regions(&mut inner)?; + let regions = self.parse_regions(&mut inner, ScopeClosure::Open)?; Node { operation: Operation::Cfg, inputs, @@ -270,7 +350,7 @@ impl<'a> ParseContext<'a> { let outputs = self.parse_port_list(&mut inner)?; let signature = self.parse_signature(&mut inner)?; let meta = self.parse_meta(&mut inner)?; - let regions = self.parse_regions(&mut inner)?; + let regions = self.parse_regions(&mut inner, ScopeClosure::Open)?; Node { operation: Operation::Block, inputs, @@ -283,9 +363,11 @@ impl<'a> ParseContext<'a> { } Rule::node_define_func 
=> { + self.vars.enter(node); let decl = self.parse_func_header(inner.next().unwrap())?; let meta = self.parse_meta(&mut inner)?; - let regions = self.parse_regions(&mut inner)?; + let regions = self.parse_regions(&mut inner, ScopeClosure::Closed)?; + self.vars.exit(); Node { operation: Operation::DefineFunc { decl }, inputs: &[], @@ -298,8 +380,10 @@ impl<'a> ParseContext<'a> { } Rule::node_declare_func => { + self.vars.enter(node); let decl = self.parse_func_header(inner.next().unwrap())?; let meta = self.parse_meta(&mut inner)?; + self.vars.exit(); Node { operation: Operation::DeclareFunc { decl }, inputs: &[], @@ -346,9 +430,11 @@ impl<'a> ParseContext<'a> { } Rule::node_define_alias => { + self.vars.enter(node); let decl = self.parse_alias_header(inner.next().unwrap())?; let value = self.parse_term(inner.next().unwrap())?; let meta = self.parse_meta(&mut inner)?; + self.vars.exit(); Node { operation: Operation::DefineAlias { decl, value }, inputs: &[], @@ -361,8 +447,10 @@ impl<'a> ParseContext<'a> { } Rule::node_declare_alias => { + self.vars.enter(node); let decl = self.parse_alias_header(inner.next().unwrap())?; let meta = self.parse_meta(&mut inner)?; + self.vars.exit(); Node { operation: Operation::DeclareAlias { decl }, inputs: &[], @@ -383,7 +471,7 @@ impl<'a> ParseContext<'a> { let op_rule = op.as_rule(); let mut op_inner = op.into_inner(); - let name = GlobalRef::Named(self.parse_symbol(&mut op_inner)?); + let operation = self.parse_symbol_use(&mut op_inner)?; let mut params = Vec::new(); @@ -392,8 +480,8 @@ impl<'a> ParseContext<'a> { } let operation = match op_rule { - Rule::term_apply_full => Operation::CustomFull { operation: name }, - Rule::term_apply => Operation::Custom { operation: name }, + Rule::term_apply_full => Operation::CustomFull { operation }, + Rule::term_apply => Operation::Custom { operation }, _ => unreachable!(), }; @@ -401,7 +489,7 @@ impl<'a> ParseContext<'a> { let outputs = self.parse_port_list(&mut inner)?; let signature = 
self.parse_signature(&mut inner)?; let meta = self.parse_meta(&mut inner)?; - let regions = self.parse_regions(&mut inner)?; + let regions = self.parse_regions(&mut inner, ScopeClosure::Closed)?; Node { operation, inputs, @@ -418,7 +506,7 @@ impl<'a> ParseContext<'a> { let outputs = self.parse_port_list(&mut inner)?; let signature = self.parse_signature(&mut inner)?; let meta = self.parse_meta(&mut inner)?; - let regions = self.parse_regions(&mut inner)?; + let regions = self.parse_regions(&mut inner, ScopeClosure::Open)?; Node { operation: Operation::TailLoop, inputs, @@ -435,7 +523,7 @@ impl<'a> ParseContext<'a> { let outputs = self.parse_port_list(&mut inner)?; let signature = self.parse_signature(&mut inner)?; let meta = self.parse_meta(&mut inner)?; - let regions = self.parse_regions(&mut inner)?; + let regions = self.parse_regions(&mut inner, ScopeClosure::Open)?; Node { operation: Operation::Conditional, inputs, @@ -465,8 +553,10 @@ impl<'a> ParseContext<'a> { } Rule::node_declare_ctr => { + self.vars.enter(node); let decl = self.parse_ctr_header(inner.next().unwrap())?; let meta = self.parse_meta(&mut inner)?; + self.vars.exit(); Node { operation: Operation::DeclareConstructor { decl }, inputs: &[], @@ -479,8 +569,10 @@ impl<'a> ParseContext<'a> { } Rule::node_declare_operation => { + self.vars.enter(node); let decl = self.parse_op_header(inner.next().unwrap())?; let meta = self.parse_meta(&mut inner)?; + self.vars.exit(); Node { operation: Operation::DeclareOperation { decl }, inputs: &[], @@ -495,25 +587,38 @@ impl<'a> ParseContext<'a> { _ => unreachable!(), }; - let node_id = self.module.insert_node(node); - - Ok(node_id) + Ok(node) } - fn parse_regions(&mut self, pairs: &mut Pairs<'a, Rule>) -> ParseResult<&'a [RegionId]> { + fn parse_regions( + &mut self, + pairs: &mut Pairs<'a, Rule>, + closure: ScopeClosure, + ) -> ParseResult<&'a [RegionId]> { let mut regions = Vec::new(); for pair in filter_rule(pairs, Rule::region) { - 
regions.push(self.parse_region(pair)?); + regions.push(self.parse_region(pair, closure)?); } Ok(self.bump.alloc_slice_copy(®ions)) } - fn parse_region(&mut self, pair: Pair<'a, Rule>) -> ParseResult { + fn parse_region( + &mut self, + pair: Pair<'a, Rule>, + closure: ScopeClosure, + ) -> ParseResult { debug_assert_eq!(pair.as_rule(), Rule::region); let pair = pair.into_inner().next().unwrap(); let rule = pair.as_rule(); let mut inner = pair.into_inner(); + let region = self.module.insert_region(Region::default()); + self.symbols.enter(region); + + if closure == ScopeClosure::Closed { + self.links.enter(region); + } + let kind = match rule { Rule::region_cfg => RegionKind::ControlFlow, Rule::region_dfg => RegionKind::DataFlow, @@ -526,24 +631,48 @@ impl<'a> ParseContext<'a> { let meta = self.parse_meta(&mut inner)?; let children = self.parse_nodes(&mut inner)?; - Ok(self.module.insert_region(Region { + let scope = match closure { + ScopeClosure::Closed => { + let (links, ports) = self.links.exit(); + Some(RegionScope { links, ports }) + } + ScopeClosure::Open => None, + }; + + self.symbols.exit(); + + self.module.regions[region.index()] = Region { kind, sources, targets, children, meta, signature, - })) + scope, + }; + + Ok(region) } fn parse_nodes(&mut self, pairs: &mut Pairs<'a, Rule>) -> ParseResult<&'a [NodeId]> { - let mut nodes = Vec::new(); + let nodes = { + let mut pairs = pairs.clone(); + let mut nodes = BumpVec::with_capacity_in(pairs.len(), self.bump); + + for pair in filter_rule(&mut pairs, Rule::node) { + nodes.push(self.parse_node_shallow(pair)?); + } - for pair in filter_rule(pairs, Rule::node) { - nodes.push(self.parse_node(pair)?); + nodes.into_bump_slice() + }; + + for (i, pair) in filter_rule(pairs, Rule::node).enumerate() { + let node = nodes[i]; + let node_data = self.parse_node_deep(pair, node)?; + self.module.nodes[node.index()] = node_data; } - Ok(self.bump.alloc_slice_copy(&nodes)) + Ok(nodes) } fn parse_func_header(&mut self, pair: Pair<'a, 
Rule>) -> ParseResult<&'a FuncDecl<'a>> { @@ -627,6 +756,7 @@ impl<'a> ParseContext<'a> { for pair in filter_rule(pairs, Rule::param) { let param = pair.into_inner().next().unwrap(); + let param_span = param.as_span(); let param = match param.as_rule() { Rule::param_implicit => { @@ -652,6 +782,10 @@ impl<'a> ParseContext<'a> { _ => unreachable!(), }; + self.vars + .insert(param.name) + .map_err(|err| ParseError::custom(&err.to_string(), param_span))?; + params.push(param); } @@ -679,27 +813,27 @@ impl<'a> ParseContext<'a> { Ok(Some(signature)) } - fn parse_port_list(&mut self, pairs: &mut Pairs<'a, Rule>) -> ParseResult<&'a [LinkRef<'a>]> { + fn parse_port_list(&mut self, pairs: &mut Pairs<'a, Rule>) -> ParseResult<&'a [LinkIndex]> { let Some(Rule::port_list) = pairs.peek().map(|p| p.as_rule()) else { return Ok(&[]); }; let pair = pairs.next().unwrap(); let inner = pair.into_inner(); - let mut links = Vec::new(); + let mut links = BumpVec::with_capacity_in(inner.len(), self.bump); for token in inner { links.push(self.parse_port(token)?); } - Ok(self.bump.alloc_slice_copy(&links)) + Ok(links.into_bump_slice()) } - fn parse_port(&mut self, pair: Pair<'a, Rule>) -> ParseResult> { + fn parse_port(&mut self, pair: Pair<'a, Rule>) -> ParseResult { debug_assert_eq!(pair.as_rule(), Rule::port); let mut inner = pair.into_inner(); - let link = LinkRef::Named(&inner.next().unwrap().as_str()[1..]); - Ok(link) + let name = &inner.next().unwrap().as_str()[1..]; + Ok(self.links.use_link(name)) } fn parse_meta(&mut self, pairs: &mut Pairs<'a, Rule>) -> ParseResult<&'a [MetaItem<'a>]> { @@ -715,6 +849,21 @@ impl<'a> ParseContext<'a> { Ok(self.bump.alloc_slice_copy(&items)) } + fn parse_symbol_use(&mut self, pairs: &mut Pairs<'a, Rule>) -> ParseResult { + let name = self.parse_symbol(pairs)?; + let resolved = self.symbols.resolve(name); + + Ok(match resolved { + Ok(node) => node, + Err(UnknownSymbolError(_)) => *self.implicit_imports.entry(name).or_insert_with(|| { + 
self.module.insert_node(Node { + operation: Operation::Import { name }, + ..Node::default() + }) + }), + }) + } + fn parse_symbol(&mut self, pairs: &mut Pairs<'a, Rule>) -> ParseResult<&'a str> { let pair = pairs.next().unwrap(); if let Rule::symbol = pair.as_rule() { diff --git a/hugr-model/src/v0/text/print.rs b/hugr-model/src/v0/text/print.rs index 4132338d9..ba7874e45 100644 --- a/hugr-model/src/v0/text/print.rs +++ b/hugr-model/src/v0/text/print.rs @@ -2,8 +2,8 @@ use pretty::{Arena, DocAllocator, RefDoc}; use std::borrow::Cow; use crate::v0::{ - ExtSetPart, GlobalRef, LinkRef, ListPart, LocalRef, MetaItem, ModelError, Module, NodeId, - Operation, Param, ParamSort, RegionId, RegionKind, Term, TermId, + ExtSetPart, LinkIndex, ListPart, MetaItem, ModelError, Module, NodeId, Operation, Param, + ParamSort, RegionId, RegionKind, Term, TermId, VarId, }; type PrintError = ModelError; @@ -247,10 +247,10 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { Operation::Custom { operation } => { this.print_group(|this| { if node_data.params.is_empty() { - this.print_global_ref(*operation)?; + this.print_symbol(*operation)?; } else { this.print_parens(|this| { - this.print_global_ref(*operation)?; + this.print_symbol(*operation)?; for param in node_data.params { this.print_term(*param)?; @@ -271,7 +271,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_group(|this| { this.print_parens(|this| { this.print_text("@"); - this.print_global_ref(*operation)?; + this.print_symbol(*operation)?; for param in node_data.params { this.print_term(*param)?; @@ -364,6 +364,12 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_signature(node_data.signature)?; this.print_meta(node_data.meta) } + + Operation::Import { name } => { + this.print_text("import"); + this.print_text(*name); + this.print_meta(node_data.meta) + } }) } @@ -413,8 +419,8 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { fn print_port_lists( &mut self, - first: &'a [LinkRef<'a>], - second: &'a [LinkRef<'a>], + first: &'a 
[LinkIndex], + second: &'a [LinkIndex], ) -> PrintResult<()> { if !first.is_empty() && !second.is_empty() { self.print_group(|this| { @@ -426,20 +432,17 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { } } - fn print_port_list(&mut self, links: &'a [LinkRef<'a>]) -> PrintResult<()> { + fn print_port_list(&mut self, links: &'a [LinkIndex]) -> PrintResult<()> { self.print_brackets(|this| { for link in links { - this.print_link_ref(*link); + this.print_link_index(*link); } Ok(()) }) } - fn print_link_ref(&mut self, link_ref: LinkRef<'a>) { - match link_ref { - LinkRef::Id(link_id) => self.print_text(format!("%{}", link_id.0)), - LinkRef::Named(name) => self.print_text(format!("%{}", name)), - } + fn print_link_index(&mut self, link_index: LinkIndex) { + self.print_text(format!("%{}", link_index.0)); } fn print_params(&mut self, params: &'a [Param<'a>]) -> PrintResult<()> { @@ -492,13 +495,13 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { self.print_text("constraint"); Ok(()) } - Term::Var(local_ref) => self.print_local_ref(*local_ref), - Term::Apply { global: name, args } => { + Term::Var(var) => self.print_var(*var), + Term::Apply { symbol, args } => { if args.is_empty() { - self.print_global_ref(*name)?; + self.print_symbol(*symbol)?; } else { self.print_parens(|this| { - this.print_global_ref(*name)?; + this.print_symbol(*symbol)?; for arg in args.iter() { this.print_term(*arg)?; } @@ -508,9 +511,9 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { Ok(()) } - Term::ApplyFull { global: name, args } => self.print_parens(|this| { + Term::ApplyFull { symbol, args } => self.print_parens(|this| { this.print_text("@"); - this.print_global_ref(*name)?; + this.print_symbol(*symbol)?; for arg in args.iter() { this.print_term(*arg)?; } @@ -628,44 +631,27 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { Ok(()) } - fn print_local_ref(&mut self, local_ref: LocalRef<'a>) -> PrintResult<()> { - let name = match local_ref { - LocalRef::Index(_, i) => { - let Some(name) = self.locals.get(i as usize) else { 
- return Err(PrintError::InvalidLocal(local_ref.to_string())); - }; - - name - } - LocalRef::Named(name) => name, + fn print_var(&mut self, var: VarId) -> PrintResult<()> { + let Some(name) = self.locals.get(var.1 as usize) else { + return Err(PrintError::InvalidVar(var)); }; self.print_text(format!("?{}", name)); Ok(()) } - fn print_global_ref(&mut self, global_ref: GlobalRef<'a>) -> PrintResult<()> { - match global_ref { - GlobalRef::Direct(node_id) => { - let node_data = self - .module - .get_node(node_id) - .ok_or(PrintError::NodeNotFound(node_id))?; - - let name = match &node_data.operation { - Operation::DefineFunc { decl } => decl.name, - Operation::DeclareFunc { decl } => decl.name, - Operation::DefineAlias { decl, .. } => decl.name, - Operation::DeclareAlias { decl } => decl.name, - _ => return Err(PrintError::UnexpectedOperation(node_id)), - }; - - self.print_text(name) - } + fn print_symbol(&mut self, node_id: NodeId) -> PrintResult<()> { + let node_data = self + .module + .get_node(node_id) + .ok_or(PrintError::NodeNotFound(node_id))?; - GlobalRef::Named(symbol) => self.print_text(symbol), - } + let name = node_data + .operation + .symbol() + .ok_or(PrintError::UnexpectedOperation(node_id))?; + self.print_text(name); Ok(()) }