From df6fa314ccbb1b23bb5d2d727784b74f6184651e Mon Sep 17 00:00:00 2001 From: Matthew Harrigan Date: Fri, 6 Oct 2023 10:45:19 -0700 Subject: [PATCH] Enable decomposition for IntEffect too --- qualtran/bloqs/basic_gates/z_basis.py | 29 +++++++++++++++------- qualtran/bloqs/basic_gates/z_basis_test.py | 4 +++ 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/qualtran/bloqs/basic_gates/z_basis.py b/qualtran/bloqs/basic_gates/z_basis.py index 86d082129..1d6fd4a56 100644 --- a/qualtran/bloqs/basic_gates/z_basis.py +++ b/qualtran/bloqs/basic_gates/z_basis.py @@ -230,15 +230,26 @@ def signature(self) -> Signature: side = Side.RIGHT if self.state else Side.LEFT return Signature([Register('val', bitsize=self.bitsize, side=side)]) - def build_composite_bloq(self, bb: 'BloqBuilder') -> Dict[str, 'SoquetT']: - states = [ZeroState(), OneState()] - xs = [] - for bit in ints_to_bits(np.array([self.val]), w=self.bitsize)[0]: - x = bb.add(states[bit]) - xs.append(x) - xs = np.array(xs) - - return {'val': bb.join(xs)} + def build_composite_bloq(self, bb: 'BloqBuilder', **val) -> Dict[str, 'SoquetT']: + bits = ints_to_bits(np.array([self.val]), w=self.bitsize)[0] + + if self.state: + assert not val + states = [ZeroState(), OneState()] + xs = [] + for bit in bits: + x = bb.add(states[bit]) + xs.append(x) + xs = np.array(xs) + + return {'val': bb.join(xs)} + + val = val['val'] + xs = bb.split(val) + effects = [ZeroEffect(), OneEffect()] + for i, bit in enumerate(bits): + bb.add(effects[bit], q=xs[i]) + return {} def add_my_tensors( self, diff --git a/qualtran/bloqs/basic_gates/z_basis_test.py b/qualtran/bloqs/basic_gates/z_basis_test.py index 12bbb0e9b..9798a798f 100644 --- a/qualtran/bloqs/basic_gates/z_basis_test.py +++ b/qualtran/bloqs/basic_gates/z_basis_test.py @@ -15,6 +15,7 @@ import numpy as np import pytest +import qualtran.testing as qlt_testing from qualtran import BloqBuilder from qualtran.bloqs.basic_gates import ( IntEffect, @@ -147,6 +148,9 @@ def test_int_effect(): with pytest.raises(AssertionError): k.call_classically(val=245) + qlt_testing.assert_valid_bloq_decomposition(k) + np.testing.assert_allclose(k.tensor_contract(), k.decompose_bloq().tensor_contract()) + def test_to_cirq(): bb = BloqBuilder()