Skip to content

Commit

Permalink
feature: Add Measure operator (#75)
Browse files Browse the repository at this point in the history
  • Loading branch information
kshyatt-aws authored May 7, 2024
1 parent e4b3c3b commit 9430d59
Show file tree
Hide file tree
Showing 14 changed files with 221 additions and 49 deletions.
31 changes: 21 additions & 10 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -144,19 +144,30 @@ jobs:
docs:
name: Documentation
runs-on: ubuntu-latest
if: ${{ github.event_name == 'push' || !github.event.pull_request.draft }}
needs: [os-test, version-test]
permissions:
actions: write # needed to allow julia-actions/cache to proactively delete old caches that it has created
contents: write
statuses: write
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
with:
version: '1'
- run: |
julia --project=docs -e '
using Pkg
Pkg.develop(PackageSpec(path=pwd()))
Pkg.instantiate()
include("docs/make.jl")'
- uses: julia-actions/cache@v1
- name: Configure doc environment
shell: julia --project=docs --color=yes {0}
run: |
using Pkg
Pkg.develop(PackageSpec(path=pwd()))
Pkg.instantiate()
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-docdeploy@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }}
- name: Run doctests
shell: julia --project=docs --color=yes {0}
run: |
using Documenter: DocMeta, doctest
using Braket
DocMeta.setdocmeta!(Braket, :DocTestSetup, :(using Braket); recursive=true)
doctest(Braket)
2 changes: 2 additions & 0 deletions docs/src/circuits.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ FreeParameter
depth
qubits
qubit_count
measure
Measure
```

## Output to IR
Expand Down
2 changes: 1 addition & 1 deletion src/Braket.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ export provider_name, properties, type
export apply_gate_noise!, apply
export logs, log_metric, metrics, @hybrid_job
export depth, qubit_count, qubits, ir, IRType, OpenQASMSerializationProperties
export OpenQasmProgram
export OpenQasmProgram, Measure, measure
export simulate
export QueueDepthInfo, QueueType, Normal, Priority, queue_depth, queue_position

Expand Down
86 changes: 77 additions & 9 deletions src/circuit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,14 @@ mutable struct Circuit
parameters::Set{FreeParameter}
observables_simultaneously_measureable::Bool
has_compiler_directives::Bool
measure_targets::Vector{Int}

@doc """
Circuit()
Construct an empty `Circuit`.
"""
Circuit() = new(Moments(), [], [], [], Dict(), Dict(), Set{Int}(), Set(), true, false)
Circuit() = new(Moments(), [], [], [], Dict(), Dict(), Set{Int}(), Set(), true, false, Int[])
end
"""
Circuit(m::Moments, ixs::Vector, rts::Vector{Result}, bri::Vector)
Expand Down Expand Up @@ -175,6 +176,7 @@ end
(c::Circuit)(::Type{T}, args...) where {T<:Gate} = apply_gate!(T, c, args...)
(c::Circuit)(::Type{T}, args...) where {T<:Noise} = apply_noise!(T, c, args...)
(c::Circuit)(::Type{T}) where {T<:CompilerDirective} = add_instruction!(c, Instruction{T}(T()))
(c::Circuit)(::Type{Measure}, args...) = foreach(target->add_instruction!(c, Instruction{Measure}(Measure(), target)), args)
(c::Circuit)(g::QO, args...) where {QO<:QuantumOperator} = add_instruction!(c, Instruction(g, args...))
(c::Circuit)(g::CD) where {CD<:CompilerDirective} = add_instruction!(c, Instruction{CD}(g))
(c::Circuit)(v::AbstractVector) = foreach(vi->c(vi...), v)
Expand All @@ -201,15 +203,15 @@ julia> θ = FreeParameter(:theta);
julia> circ = Circuit();
julia> circ = H(circ, 0)
julia> circ = H(circ, 0);
julia> circ = Rx(circ, 1, α)
julia> circ = Rx(circ, 1, α);
julia> circ = Ry(circ, 0, θ)
julia> circ = Ry(circ, 0, θ);
julia> circ = Probability(circ)
julia> circ = Probability(circ);
julia> new_circ = circ(theta=2.0, alpha=1.0)
julia> new_circ = circ(theta=2.0, alpha=1.0);
```
"""
function (c::Circuit)(arg::Number; kwargs...)
Expand Down Expand Up @@ -261,7 +263,9 @@ julia> H(c, 0);
julia> CNot(c, 0, 1);
julia> qubits(c)
QubitSet(0, 1)
QubitSet with 2 elements:
0
1
```
"""
qubits(c::Circuit) = (qs = union!(copy(c.moments._qubits), c.qubit_observable_set); QubitSet(qs))
Expand Down Expand Up @@ -297,12 +301,65 @@ Base.convert(::Type{Program}, c::Circuit) = (basis_rotation_instructions!(c); re
Circuit(p::Program) = convert(Circuit, p)
Program(c::Circuit) = convert(Program, c)

function _add_measure!(c::Circuit, target_qubits::QubitSet)
for (idx, target) in enumerate(target_qubits)
num_qubits_measured = !isempty(c.measure_targets) && length(target_qubits) == 1 ? length(c.measure_targets) : 0
add_instruction!(c, Instruction(Measure(idx - 1 + num_qubits_measured), target))
push!(c.measure_targets, target)
end
return c
end
_add_measure!(c::Circuit, target_qubits::Vector{IntOrQubit}) = _add_measure!(c, QubitSet(target_qubits))
_add_measure!(c::Circuit, target_qubit::IntOrQubit) = _add_measure!(c, QubitSet(target_qubit))

"""
measure(c::Circuit, target_qubits) -> Circuit
Add a [`Measure`](@ref) operator to `c`, ensuring only the targeted qubits are measured.
A `Measure` operation can **only** be applied if the circuit does **not** contain any result types.
If `c` has no qubits defined, or `target_qubits` are not within the qubit range of `c`,
an `ArgumentError` is raised. If the circuit `c` contains any result types, or any of the
target qubits are already measured, an `ErrorException` is raised.
# Examples
```jldoctest
julia> circ = Circuit([(H, 0), (CNot, 0, 1)]);
julia> circ = measure(circ, 0);
julia> circ.instructions
3-element Vector{Braket.Instruction}:
Braket.Instruction{H}(H(), QubitSet(0))
Braket.Instruction{CNot}(CNot(), QubitSet(0, 1))
Braket.Instruction{Measure}(Measure(0), QubitSet(0))
julia> circ = Circuit([(H, 0), (CNot, 0, 1), (StateVector,)]);
julia> circ = measure(circ, 0);
ERROR: a circuit cannot contain both measure instructions and result types.
[...]
julia> circ = Circuit([(H, 0), (CNot, 0, 1)]);
julia> circ = measure(circ, [0, 1, 0]);
ERROR: cannot repeat qubit(s) in the same measurement.
[...]
```
"""
function measure(c::Circuit, target_qubits)
isempty(c.result_types) || error("a circuit cannot contain both measure instructions and result types.")
# Check if there are repeated qubits in the same measurement
allunique(target_qubits) || error("cannot repeat qubit(s) in the same measurement.")
return _add_measure!(c, target_qubits)
end

function openqasm_header(c::Circuit, sps::SerializationProperties=OpenQASMSerializationProperties())
ir_instructions = ["OPENQASM 3.0;"]
for p in sort(string.(c.parameters))
push!(ir_instructions, "input float $p;")
end
isempty(c.result_types) && push!(ir_instructions, "bit[$(qubit_count(c))] b;")
bit_count = isempty(c.measure_targets) ? qubit_count(c) : length(c.measure_targets)
isempty(c.result_types) && push!(ir_instructions, "bit[$bit_count] b;")
if sps.qubit_reference_type == VIRTUAL
total_qubits = real(maximum(qubits(c))) + 1
push!(ir_instructions, "qubit[$total_qubits] q;")
Expand All @@ -326,7 +383,7 @@ function ir(c::Circuit, ::Val{:OpenQASM}; serialization_properties::Serializatio
ixs = map(ix->ir(ix, Val(:OpenQASM); serialization_properties=serialization_properties), c.instructions)
if !isempty(c.result_types)
rts = map(rt->ir(rt, Val(:OpenQASM); serialization_properties=serialization_properties), c.result_types)
else
elseif isempty(c.measure_targets) # measure all qubits if not explicitly specified
rts = ["b[$(idx-1)] = measure $(format(Int(qubit), serialization_properties));" for (idx, qubit) in enumerate(qubits(c))]
end
return OpenQasmProgram(header_dict[OpenQasmProgram], join(vcat(header, ixs, rts), "\n"), nothing)
Expand Down Expand Up @@ -488,6 +545,7 @@ add_to_qubit_observable_set!(c::Circuit, rt::AdjointGradient) = union!(c.qubit_
add_to_qubit_observable_set!(c::Circuit, rt::Result) = c.qubit_observable_set

function add_result_type!(c::Circuit, rt::Result)
isempty(c.measure_targets) || error("cannot add a result type to a circuit which already contains a measure instruction.")
rt_to_add = rt
if rt_to_add c.result_types
if rt_to_add isa AdjointGradient && any(rt_ isa AdjointGradient for rt_ in c.result_types)
Expand All @@ -505,7 +563,16 @@ end
add_result_type!(c::Circuit, rt::Result, target) = add_result_type!(c, remap(rt, target))
add_result_type!(c::Circuit, rt::Result, target_mapping::Dict{<:Integer, <:Integer}) = add_result_type!(c, remap(rt, target_mapping))

function _check_if_qubit_measured(c::Circuit, qubit::Int)
isempty(c.measure_targets) && return
# check if there is a measure instruction on the targeted qubit(s)
isempty(intersect(c.measure_targets, qubit)) || error("cannot apply instruction to measured qubits.")
end
_check_if_qubit_measured(c::Circuit, qubit::Qubit) = _check_if_qubit_measured(c, Int(qubit))
_check_if_qubit_measured(c::Circuit, qubits) = foreach(q->_check_if_qubit_measured(c, q), qubits)

function add_instruction!(c::Circuit, ix::Instruction{O}) where {O<:Operator}
_check_if_qubit_measured(c, ix.target)
to_add = [ix]
if ix.operator isa QuantumOperator && Parametrizable(ix.operator) == Parametrized()
for param in parameters(ix.operator)
Expand Down Expand Up @@ -536,6 +603,7 @@ end

function add_verbatim_box!(c::Circuit, verbatim_circuit::Circuit, target_mapping::Dict{<:Integer, <:Integer}=Dict{Int, Int}())
!isempty(verbatim_circuit.result_types) && throw(ErrorException("verbatim subcircuit is not measured and cannot have result types."))
isempty(verbatim_circuit.measure_targets) || error("cannot measure a subcircuit inside a verbatim box.")
isempty(verbatim_circuit.instructions) && return c
c = add_instruction!(c, Instruction(StartVerbatimBox()))
for ix in verbatim_circuit.instructions
Expand Down
2 changes: 1 addition & 1 deletion src/device.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ generate its ARN when passed to the appropriate function.
# Examples
```jldoctest
julia> d = Braket.SV1()
julia> d = Braket.SV1();
julia> arn(d)
"arn:aws:braket:::device/quantum-simulator/amazon/sv1"
Expand Down
3 changes: 0 additions & 3 deletions src/gates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,6 @@ function ir(g::Unitary, target::QubitSet, ::Val{:JAQCD}; kwargs...)
return IR.Unitary(t_c, mat, "unitary")
end
StructTypes.StructType(::Type{<:Gate}) = StructTypes.Struct()
abstract type Parametrizable end
struct Parametrized end
struct NonParametrized end

Parametrizable(g::AngledGate) = Parametrized()
Parametrizable(g::Gate) = NonParametrized()
Expand Down
12 changes: 6 additions & 6 deletions src/observables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ Struct representing an observable of an arbitrary complex Hermitian matrix.
# Examples
```jldoctest
julia> ho = Braket.Observables.HermitianObservable([0 1; 1 0])
Braket.Observables.HermitianObservable(Complex{Int64}[0 + 0im 1 + 0im; 1 + 0im 0 + 0im])
HermitianObservable((2, 2))
```
"""
struct HermitianObservable <: NonCompositeObservable
Expand Down Expand Up @@ -110,12 +110,12 @@ Struct representing a tensor product of smaller observables.
# Examples
```jldoctest
julia> Braket.Observables.TensorProduct(["x", "h"])
Braket.Observables.TensorProduct(Braket.Observables.Observable[Braket.Observables.X(), Braket.Observables.H()])
X @ H
julia> ho = Braket.Observables.HermitianObservable([0 1; 1 0]);
julia> Braket.Observables.TensorProduct([ho, Braket.Observables.Z()])
Braket.Observables.TensorProduct(Braket.Observables.Observable[Braket.Observables.HermitianObservable(Complex{Int64}[0 + 0im 1 + 0im; 1 + 0im 0 + 0im]), Braket.Observables.Z()])
HermitianObservable((2, 2)) @ Z
```
"""
struct TensorProduct{O} <: Observable where {O<:Observable}
Expand Down Expand Up @@ -207,12 +207,12 @@ Struct representing the sum of observables.
# Examples
```jldoctest
julia> o1 = 2.0 * Observables.I() @ Observables.Z();
julia> o1 = 2.0 * Braket.Observables.I() * Braket.Observables.Z();
julia> o2 = 3.0 * Observables.X() @ Observables.X();
julia> o2 = 3.0 * Braket.Observables.X() * Braket.Observables.X();
julia> o = o1 + o2
Braket.Observables.Sum()
Sum(2.0 * I @ Z, 3.0 * X @ X)
```
"""
struct Sum <: Observable
Expand Down
28 changes: 28 additions & 0 deletions src/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ ir(o::Operator, t::QubitSet; kwargs...) = ir(o, t, Val(IRType[]); kwargs...)
ir(o::Operator, t::IntOrQubit; kwargs...) = ir(o, QubitSet(t), Val(IRType[]); kwargs...)
ir(o::Operator, t::IntOrQubit, args...; kwargs...) = ir(o, QubitSet(t), args...; kwargs...)

abstract type Parametrizable end
struct Parametrized end
struct NonParametrized end

struct PauliEigenvalues{N}
coeff::Float64
PauliEigenvalues{N}(coeff::Float64=1.0) where {N} = new(coeff)
Expand Down Expand Up @@ -55,3 +59,27 @@ function Base.getindex(p::PauliEigenvalues{N}, i::Int)::Float64 where N
end
end
Base.getindex(p::PauliEigenvalues{N}, ix::Vector{Int}) where {N} = [p[i] for i in ix]

"""
Measure(index) <: QuantumOperator
Represents a measurement operation on targeted qubit, stored in the classical register at `index`.
"""
struct Measure <: QuantumOperator
index::Int
end
Measure() = Measure(-1)
Parametrizable(m::Measure) = NonParametrized()
chars(::Type{Measure}) = ("M",)
chars(m::Measure) = ("M",)
qubit_count(::Type{Measure}) = 1
ir(m::Measure, target::QubitSet, ::Val{:JAQCD}; kwargs...) = error("measure instructions are not supported with JAQCD.")
function ir(m::Measure, target::QubitSet, ::Val{:OpenQASM}; serialization_properties=OpenQASMSerializationProperties())
instructions = Vector{String}(undef, length(target))
for (idx, qubit) in enumerate(target)
bit_index = m.index > 0 && length(targets) == 1 ? m.index : idx - 1
t = format_qubits(qubit, serialization_properties)
instructions[idx] = "b[$bit_index] = measure $t;"
end
return join(instructions, "\n")
end
12 changes: 9 additions & 3 deletions src/qubit_set.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,22 @@ Elements may be `Int`s or [`Qubit`](@ref)s.
# Examples
```jldoctest
julia> QubitSet(1, Qubit(0))
QubitSet(1, Qubit(0))
QubitSet with 2 elements:
1
Qubit(0)
julia> QubitSet([2, 1])
QubitSet(2, 1)
QubitSet with 2 elements:
2
1
julia> QubitSet()
QubitSet()
julia> QubitSet(QubitSet(5, 1))
QubitSet(5, 1)
QubitSet with 2 elements:
5
1
```
"""
struct QubitSet <: AbstractSet{Int}
Expand Down
12 changes: 6 additions & 6 deletions src/schemas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ and a `target` set of qubits to which the `operator` is applied.
# Examples
```jldoctest
julia> Instruction(H(), 1)
Braket.Instruction(H(), QubitSet(1))
julia> Braket.Instruction(H(), 1)
Braket.Instruction{H}(H(), QubitSet(1))
julia> Instruction(CNot(), [1, Qubit(4)])
Braket.Instruction(CNot(), QubitSet(1, Qubit(4)))
julia> Braket.Instruction(CNot(), [1, Qubit(4)])
Braket.Instruction{CNot}(CNot(), QubitSet(1, Qubit(4)))
julia> Instruction(StartVerbatimBox(), QubitSet())
Braket.Instruction(StartVerbatimBox(), QubitSet())
julia> Braket.Instruction(StartVerbatimBox(), QubitSet())
Braket.Instruction{StartVerbatimBox}(StartVerbatimBox(), QubitSet())
```
"""
struct Instruction{O<:Operator}
Expand Down
29 changes: 29 additions & 0 deletions test/integ_tests/measure.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
using AWS, Braket, Test

SHOTS = 8000

IONQ_ARN = "arn:aws:braket:us-east-1::device/qpu/ionq/Harmony"
SIMULATOR_ARN = "arn:aws:braket:::device/quantum-simulator/amazon/sv1"
OQC_ARN = "arn:aws:braket:eu-west-2::device/qpu/oqc/Lucy"

@testset "Measure operator" begin
@testset "Unsupported devices" begin
@testset "Arn $arn" for arn in (IONQ_ARN, SIMULATOR_ARN)
device = AwsDevice(arn)
status(device) == "OFFLINE" && continue
circ = Circuit([(H, 0), (CNot, 0, 1), (H, 2), (Measure, 0, 1)])
# TODO check error message
@test_throws AWS.AWSExceptions.AWSException device(circ, shots=1000)
end
end
@testset "Supported devices" begin
@testset "Arn $arn" for arn in (OQC_ARN,)
device = AwsDevice(arn)
status(device) == "OFFLINE" && continue
circ = Circuit([(H, 0), (CNot, 0, 1), (Measure, 0)])
res = result(device(circ, shots=SHOTS))
@test all(m->length(m) == 1, res.measurements)
@test res.measured_qubits == [0]
end
end
end
Loading

0 comments on commit 9430d59

Please sign in to comment.