Skip to content

Commit

Permalink
feat: Scoping rules and utilities for symbols, links and variables (#…
Browse files Browse the repository at this point in the history
…1754)

This PR introduces scoping for symbols, links and variables. It comes
with utilities that can be used to resolve names appropriately. Moreover
the model data structures are changed so that they always use direct
references by indices instead of names in order to streamline the
serialisation format.
  • Loading branch information
zrho authored Dec 17, 2024
1 parent e40b6c7 commit e065d70
Show file tree
Hide file tree
Showing 19 changed files with 1,367 additions and 720 deletions.
358 changes: 254 additions & 104 deletions hugr-core/src/export.rs

Large diffs are not rendered by default.

327 changes: 112 additions & 215 deletions hugr-core/src/import.rs

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion hugr-core/tests/snapshots/model__roundtrip_add.snap
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
---
source: hugr-core/tests/model.rs
expression: "roundtrip(include_str!(\"fixtures/model-add.edn\"))"
expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-add.edn\"))"
---
(hugr 0)

(import arithmetic.int.iadd)

(import arithmetic.int.types.int)

(define-func example.add
[(@ arithmetic.int.types.int) (@ arithmetic.int.types.int)]
[(@ arithmetic.int.types.int)]
Expand Down
4 changes: 3 additions & 1 deletion hugr-core/tests/snapshots/model__roundtrip_alias.snap
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
---
source: hugr-core/tests/model.rs
expression: "roundtrip(include_str!(\"fixtures/model-alias.edn\"))"
expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-alias.edn\"))"
---
(hugr 0)

(import arithmetic.int.types.int)

(declare-alias local.float type)

(define-alias local.int type (@ arithmetic.int.types.int))
Expand Down
4 changes: 4 additions & 0 deletions hugr-core/tests/snapshots/model__roundtrip_call.snap
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-call
---
(hugr 0)

(import prelude.json)

(import arithmetic.int.types.int)

(declare-func example.callee
(forall ?0 ext-set)
[(@ arithmetic.int.types.int)]
Expand Down
8 changes: 4 additions & 4 deletions hugr-core/tests/snapshots/model__roundtrip_cfg.snap
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cfg.
(cfg [%0] [%1]
(signature (fn [?0] [?0] (ext)))
(cfg
[%2] [%8]
[%4] [%8]
(signature (fn [?0] [?0] (ext)))
(block [%2] [%5]
(block [%4] [%5]
(signature (fn [(ctrl [?0])] [(ctrl [?0])] (ext)))
(dfg
[%3] [%4]
[%2] [%3]
(signature (fn [?0] [(adt [[?0]])] (ext)))
(tag 0 [%3] [%4] (signature (fn [?0] [(adt [[?0]])] (ext))))))
(tag 0 [%2] [%3] (signature (fn [?0] [(adt [[?0]])] (ext))))))
(block [%5] [%8]
(signature (fn [(ctrl [?0])] [(ctrl [?0])] (ext)))
(dfg
Expand Down
6 changes: 5 additions & 1 deletion hugr-core/tests/snapshots/model__roundtrip_cond.snap
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
---
source: hugr-core/tests/model.rs
expression: "roundtrip(include_str!(\"fixtures/model-cond.edn\"))"
expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cond.edn\"))"
---
(hugr 0)

(import arithmetic.int.types.int)

(import arithmetic.int.ineg)

(define-func example.cond
[(adt [[] []]) (@ arithmetic.int.types.int)]
[(@ arithmetic.int.types.int)]
Expand Down
2 changes: 2 additions & 0 deletions hugr-core/tests/snapshots/model__roundtrip_constraints.snap
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cons
---
(hugr 0)

(import prelude.Array)

(declare-func array.replicate
(forall ?0 type)
(forall ?1 nat)
Expand Down
58 changes: 25 additions & 33 deletions hugr-model/capnp/hugr-v0.capnp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ using NodeId = UInt32;
# The id of a `Link`.
using LinkId = UInt32;

# The index of a `Link`.
using LinkIndex = UInt32;

struct Module {
root @0 :RegionId;
nodes @1 :List(Node);
Expand All @@ -24,8 +27,8 @@ struct Module {

struct Node {
operation @0 :Operation;
inputs @1 :List(LinkRef);
outputs @2 :List(LinkRef);
inputs @1 :List(LinkIndex);
outputs @2 :List(LinkIndex);
params @3 :List(TermId);
regions @4 :List(RegionId);
meta @5 :List(MetaItem);
Expand All @@ -42,15 +45,16 @@ struct Operation {
funcDecl @5 :FuncDecl;
aliasDefn @6 :AliasDefn;
aliasDecl @7 :AliasDecl;
custom @8 :GlobalRef;
customFull @9 :GlobalRef;
custom @8 :NodeId;
customFull @9 :NodeId;
tag @10 :UInt16;
tailLoop @11 :Void;
conditional @12 :Void;
callFunc @13 :TermId;
loadFunc @14 :TermId;
constructorDecl @15 :ConstructorDecl;
operationDecl @16 :OperationDecl;
import @17 :Text;
}

struct FuncDefn {
Expand Down Expand Up @@ -97,13 +101,22 @@ struct Operation {

struct Region {
kind @0 :RegionKind;
sources @1 :List(LinkRef);
targets @2 :List(LinkRef);
sources @1 :List(LinkIndex);
targets @2 :List(LinkIndex);
children @3 :List(NodeId);
meta @4 :List(MetaItem);
signature @5 :OptionalTermId;
scope @6 :RegionScope;
}

struct RegionScope {
links @0 :UInt32;
ports @1 :UInt32;
}

# Either `0` for an open scope, or the number of links in the closed scope incremented by `1`.
using LinkScope = UInt32;

enum RegionKind {
dataFlow @0;
controlFlow @1;
Expand All @@ -115,37 +128,16 @@ struct MetaItem {
value @1 :UInt32;
}

struct LinkRef {
union {
id @0 :LinkId;
named @1 :Text;
}
}

struct GlobalRef {
union {
node @0 :NodeId;
named @1 :Text;
}
}

struct LocalRef {
union {
direct :group {
index @0 :UInt16;
node @1 :NodeId;
}
named @2 :Text;
}
}

struct Term {
union {
wildcard @0 :Void;
runtimeType @1 :Void;
staticType @2 :Void;
constraint @3 :Void;
variable @4 :LocalRef;
variable :group {
variableNode @4 :NodeId;
variableIndex @21 :UInt16;
}
apply @5 :Apply;
applyFull @6 :ApplyFull;
quote @7 :TermId;
Expand All @@ -165,12 +157,12 @@ struct Term {
}

struct Apply {
global @0 :GlobalRef;
symbol @0 :NodeId;
args @1 :List(TermId);
}

struct ApplyFull {
global @0 :GlobalRef;
symbol @0 :NodeId;
args @1 :List(TermId);
}

Expand Down
84 changes: 34 additions & 50 deletions hugr-model/src/v0/binary/read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ fn read_module<'a>(

fn read_node<'a>(bump: &'a Bump, reader: hugr_capnp::node::Reader) -> ReadResult<model::Node<'a>> {
let operation = read_operation(bump, reader.get_operation()?)?;
let inputs = read_list!(bump, reader, get_inputs, read_link_ref);
let outputs = read_list!(bump, reader, get_outputs, read_link_ref);
let inputs = read_scalar_list!(bump, reader, get_inputs, model::LinkIndex);
let outputs = read_scalar_list!(bump, reader, get_outputs, model::LinkIndex);
let params = read_scalar_list!(bump, reader, get_params, model::TermId);
let regions = read_scalar_list!(bump, reader, get_regions, model::RegionId);
let meta = read_list!(bump, reader, get_meta, read_meta_item);
Expand All @@ -89,43 +89,6 @@ fn read_node<'a>(bump: &'a Bump, reader: hugr_capnp::node::Reader) -> ReadResult
})
}

fn read_local_ref<'a>(
bump: &'a Bump,
reader: hugr_capnp::local_ref::Reader,
) -> ReadResult<model::LocalRef<'a>> {
use hugr_capnp::local_ref::Which;
Ok(match reader.which()? {
Which::Direct(reader) => {
let index = reader.get_index();
let node = model::NodeId(reader.get_node());
model::LocalRef::Index(node, index)
}
Which::Named(name) => model::LocalRef::Named(bump.alloc_str(name?.to_str()?)),
})
}

fn read_global_ref<'a>(
bump: &'a Bump,
reader: hugr_capnp::global_ref::Reader,
) -> ReadResult<model::GlobalRef<'a>> {
use hugr_capnp::global_ref::Which;
Ok(match reader.which()? {
Which::Node(node) => model::GlobalRef::Direct(model::NodeId(node)),
Which::Named(name) => model::GlobalRef::Named(bump.alloc_str(name?.to_str()?)),
})
}

fn read_link_ref<'a>(
bump: &'a Bump,
reader: hugr_capnp::link_ref::Reader,
) -> ReadResult<model::LinkRef<'a>> {
use hugr_capnp::link_ref::Which;
Ok(match reader.which()? {
Which::Id(id) => model::LinkRef::Id(model::LinkId(id)),
Which::Named(name) => model::LinkRef::Named(bump.alloc_str(name?.to_str()?)),
})
}

fn read_operation<'a>(
bump: &'a Bump,
reader: hugr_capnp::operation::Reader,
Expand Down Expand Up @@ -217,11 +180,11 @@ fn read_operation<'a>(
});
model::Operation::DeclareOperation { decl }
}
Which::Custom(name) => model::Operation::Custom {
operation: read_global_ref(bump, name?)?,
Which::Custom(operation) => model::Operation::Custom {
operation: model::NodeId(operation),
},
Which::CustomFull(name) => model::Operation::CustomFull {
operation: read_global_ref(bump, name?)?,
Which::CustomFull(operation) => model::Operation::CustomFull {
operation: model::NodeId(operation),
},
Which::Tag(tag) => model::Operation::Tag { tag },
Which::TailLoop(()) => model::Operation::TailLoop,
Expand All @@ -232,6 +195,9 @@ fn read_operation<'a>(
Which::LoadFunc(func) => model::Operation::LoadFunc {
func: model::TermId(func),
},
Which::Import(name) => model::Operation::Import {
name: bump.alloc_str(name?.to_str()?),
},
})
}

Expand All @@ -245,22 +211,35 @@ fn read_region<'a>(
hugr_capnp::RegionKind::Module => model::RegionKind::Module,
};

let sources = read_list!(bump, reader, get_sources, read_link_ref);
let targets = read_list!(bump, reader, get_targets, read_link_ref);
let sources = read_scalar_list!(bump, reader, get_sources, model::LinkIndex);
let targets = read_scalar_list!(bump, reader, get_targets, model::LinkIndex);
let children = read_scalar_list!(bump, reader, get_children, model::NodeId);
let meta = read_list!(bump, reader, get_meta, read_meta_item);
let signature = reader.get_signature().checked_sub(1).map(model::TermId);

let scope = if reader.has_scope() {
Some(read_region_scope(reader.get_scope()?)?)
} else {
None
};

Ok(model::Region {
kind,
sources,
targets,
children,
meta,
signature,
scope,
})
}

fn read_region_scope(reader: hugr_capnp::region_scope::Reader) -> ReadResult<model::RegionScope> {
let links = reader.get_links();
let ports = reader.get_ports();
Ok(model::RegionScope { links, ports })
}

fn read_term<'a>(bump: &'a Bump, reader: hugr_capnp::term::Reader) -> ReadResult<model::Term<'a>> {
use hugr_capnp::term::Which;
Ok(match reader.which()? {
Expand All @@ -274,20 +253,25 @@ fn read_term<'a>(bump: &'a Bump, reader: hugr_capnp::term::Reader) -> ReadResult
Which::NatType(()) => model::Term::NatType,
Which::ExtSetType(()) => model::Term::ExtSetType,
Which::ControlType(()) => model::Term::ControlType,
Which::Variable(local_ref) => model::Term::Var(read_local_ref(bump, local_ref?)?),

Which::Variable(reader) => {
let node = model::NodeId(reader.get_variable_node());
let index = reader.get_variable_index();
model::Term::Var(model::VarId(node, index))
}

Which::Apply(reader) => {
let reader = reader?;
let global = read_global_ref(bump, reader.get_global()?)?;
let symbol = model::NodeId(reader.get_symbol());
let args = read_scalar_list!(bump, reader, get_args, model::TermId);
model::Term::Apply { global, args }
model::Term::Apply { symbol, args }
}

Which::ApplyFull(reader) => {
let reader = reader?;
let global = read_global_ref(bump, reader.get_global()?)?;
let symbol = model::NodeId(reader.get_symbol());
let args = read_scalar_list!(bump, reader, get_args, model::TermId);
model::Term::ApplyFull { global, args }
model::Term::ApplyFull { symbol, args }
}

Which::Quote(r#type) => model::Term::Quote {
Expand Down
Loading

0 comments on commit e065d70

Please sign in to comment.