Skip to content

Commit

Permalink
feat: CircuitBuilder::append_with_output_arr (#871)
Browse files Browse the repository at this point in the history
- Fleshes out `utils::collect_array`, and uses it in all the methods
that return a fixed size array.
- Adds `CircuitBuilder::append_with_output_arr`.

Follow up to
#867 (comment)
  • Loading branch information
aborgna-q authored Mar 8, 2024
1 parent 441c887 commit 87b9aab
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 30 deletions.
6 changes: 2 additions & 4 deletions quantinuum-hugr/src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::hugr::validate::InterGraphEdgeError;
use crate::hugr::views::HugrView;
use crate::hugr::{NodeMetadata, ValidationError};
use crate::ops::{self, LeafOp, OpTag, OpTrait, OpType};
use crate::utils::collect_array;
use crate::{IncomingPort, Node, OutgoingPort};

use std::iter;
Expand Down Expand Up @@ -275,10 +276,7 @@ pub trait Dataflow: Container {
///
/// Panics if the number of input Wires does not match the size of the array.
fn input_wires_arr<const N: usize>(&self) -> [Wire; N] {
self.input_wires()
.collect_vec()
.try_into()
.expect(&format!("Incorrect number of wires: {N}")[..])
collect_array(self.input_wires())
}

/// Return a builder for a [`crate::ops::DFG`] node, i.e. a nested dataflow subgraph.
Expand Down
50 changes: 31 additions & 19 deletions quantinuum-hugr/src/builder/circuit.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use std::collections::HashMap;
use std::mem;

use itertools::Itertools;
use thiserror::Error;

use crate::ops::{OpName, OpType};
use crate::utils::collect_array;

use super::{BuildError, Dataflow};
use crate::{CircuitUnit, Wire};
Expand Down Expand Up @@ -74,17 +74,13 @@ impl<'a, T: Dataflow + ?Sized> CircuitBuilder<'a, T> {
}

/// Returns an array with the tracked linear units.
///
/// # Panics
///
/// If the number of outputs does not match `N`.
#[must_use]
pub fn tracked_units_arr<const N: usize>(&self) -> [usize; N] {
self.tracked_units()
.collect_vec()
.try_into()
.unwrap_or_else(|ws: Vec<usize>| {
panic!(
"Incorrect number of linear units: Expected {N} but got {}",
ws.len()
)
})
collect_array(self.tracked_units())
}

#[inline]
Expand Down Expand Up @@ -121,7 +117,7 @@ impl<'a, T: Dataflow + ?Sized> CircuitBuilder<'a, T> {
///
/// # Errors
///
/// This function will return an error if an index is invalid.
/// Returns an error on an invalid input unit.
pub fn append_with_outputs<A: Into<CircuitUnit>>(
&mut self,
op: impl Into<OpType>,
Expand Down Expand Up @@ -181,6 +177,28 @@ impl<'a, T: Dataflow + ?Sized> CircuitBuilder<'a, T> {
Ok(nonlinear_outputs)
}

/// Append an `op` with some inputs being the stored wires.
/// Any inputs of the form [`CircuitUnit::Linear`] are used to index the
/// stored wires.
/// The outputs at those indices are used to replace the stored wire.
/// The remaining outputs are returned as an array.
///
/// # Errors
///
/// Returns an error on an invalid input unit.
///
/// # Panics
///
/// If the number of outputs does not match `N`.
pub fn append_with_outputs_arr<const N: usize, A: Into<CircuitUnit>>(
&mut self,
op: impl Into<OpType>,
inputs: impl IntoIterator<Item = A>,
) -> Result<[Wire; N], BuildError> {
let outputs = self.append_with_outputs(op, inputs)?;
Ok(collect_array(outputs))
}

/// Add a wire to the list of tracked wires.
///
/// Returns the new unit index.
Expand Down Expand Up @@ -302,21 +320,15 @@ mod test {
assert_eq!(circ.n_wires(), 1);

let [q0] = circ.tracked_units_arr();
let [ancilla] = circ
.append_with_outputs::<CircuitUnit>(q_alloc(), [])?
.try_into()
.expect("Expected a single ancilla wire");
let [ancilla] = circ.append_with_outputs_arr(q_alloc(), [] as [CircuitUnit; 0])?;
let ancilla = circ.track_wire(ancilla);

assert_ne!(ancilla, 0);
assert_eq!(circ.n_wires(), 2);
assert_eq!(circ.tracked_units_arr(), [q0, ancilla]);

circ.append(cx_gate(), [q0, ancilla])?;
let [_bit] = circ
.append_with_outputs(measure(), [q0])?
.try_into()
.unwrap();
let [_bit] = circ.append_with_outputs_arr(measure(), [q0])?;

let q0 = circ.untrack_wire(q0)?;

Expand Down
19 changes: 14 additions & 5 deletions quantinuum-hugr/src/builder/handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
//!
use crate::ops::handle::{BasicBlockID, CaseID, DfgID, FuncID, NodeHandle, TailLoopID};
use crate::ops::OpTag;
use crate::utils::collect_array;
use crate::{Node, OutgoingPort, Wire};

use itertools::Itertools;
use std::iter::FusedIterator;

#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
Expand Down Expand Up @@ -49,10 +49,7 @@ impl<T: NodeHandle> BuildHandle<T> {

/// Attempt to cast outputs in to array of Wires.
pub fn outputs_arr<const N: usize>(&self) -> [Wire; N] {
self.outputs()
.collect_vec()
.try_into()
.expect(&format!("Incorrect number of wires: {}", N)[..])
self.outputs().to_array()
}

#[inline]
Expand Down Expand Up @@ -113,6 +110,18 @@ pub struct Outputs {
range: std::ops::Range<usize>,
}

impl Outputs {
#[inline]
/// Returns the output wires as an array.
///
/// # Panics
///
/// If the length of the slice is not equal to `N`.
pub fn to_array<const N: usize>(self) -> [Wire; N] {
collect_array(self)
}
}

impl Iterator for Outputs {
type Item = Wire;

Expand Down
42 changes: 40 additions & 2 deletions quantinuum-hugr/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,46 @@ where
}

/// Collect a vector into an array.
pub fn collect_array<const N: usize, T: Debug>(arr: &[T]) -> [&T; N] {
arr.iter().collect_vec().try_into().unwrap()
///
/// This is useful for deconstructing a vectors content.
///
/// # Example
///
/// ```ignore
/// let iter = 0..3;
/// let [a, b, c] = crate::utils::collect_array(iter);
/// assert_eq!(b, 1);
/// ```
///
/// # Panics
///
/// If the length of the slice is not equal to `N`.
///
/// See also [`try_collect_array`] for a non-panicking version.
#[inline]
pub fn collect_array<const N: usize, T: Debug>(arr: impl IntoIterator<Item = T>) -> [T; N] {
try_collect_array(arr).unwrap_or_else(|v| panic!("Expected {} elements, got {:?}", N, v))
}

/// Collect a vector into an array.
///
/// This is useful for deconstructing a vectors content.
///
/// # Example
///
/// ```ignore
/// let iter = 0..3;
/// let [a, b, c] = crate::utils::try_collect_array(iter)
/// .unwrap_or_else(|v| panic!("Expected 3 elements, got {:?}", v));
/// assert_eq!(b, 1);
/// ```
///
/// See also [`collect_array`].
#[inline]
pub fn try_collect_array<const N: usize, T>(
arr: impl IntoIterator<Item = T>,
) -> Result<[T; N], Vec<T>> {
arr.into_iter().collect_vec().try_into()
}

/// Helper method to skip serialization of default values in serde.
Expand Down

0 comments on commit 87b9aab

Please sign in to comment.