Skip to content

Commit

Permalink
fix: Working with non-main threads
Browse files Browse the repository at this point in the history
  • Loading branch information
kshyatt-aws committed Aug 14, 2024
1 parent 841fa69 commit 58b7ad0
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 35 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ test = [
"pytest-benchmark",
"pytest-cov",
"pytest-rerunfailures",
"pytest-timeout",
"pytest-xdist",
"qiskit-braket-provider",
"qiskit-algorithms",
"sphinx",
"sphinx-rtd-theme",
"sphinxcontrib-apidoc",
Expand Down
4 changes: 4 additions & 0 deletions src/braket/juliapkg.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
"BraketSimulator": {
"uuid": "76d27892-9a0b-406c-98e4-7c178e9b3dff",
"rev": "ksh/nobraket"
},
"JSON3": {
"uuid": "0f8b85d8-7281-11e9-16c2-39a750bddbf1",
"version": "1.14.0"
}
}
}
99 changes: 66 additions & 33 deletions src/braket/simulator_v2/base_simulator_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,28 +26,29 @@ def _openqasm_to_jl(self, openqasm_ir: OpenQASMProgram):
jl.convert(jl.String, openqasm_ir.braketSchemaHeader.version),
)
if openqasm_ir.inputs:
jl_inputs = jl.Dict(
[
jl_inputs = jl.Dict[jl.String, jl.Any](
jl.Pair(
jl.convert(jl.String, input_key),
(
jl.convert(jl.String, input_key),
(
jl.convert(jl.String, input_val)
if isinstance(input_val, str)
else jl.convert(jl.Number, input_val)
),
)
for (input_key, input_val) in openqasm_ir.inputs.items()
]
jl.convert(jl.String, input_val)
if isinstance(input_val, str)
else jl.convert(jl.Number, input_val)
),
)
for (input_key, input_val) in openqasm_ir.inputs.items()
)
else:
jl_inputs = jl.nothing
jl_inputs = jl.Dict[jl.String, jl.Float64]()
jl_source = jl.convert(jl.String, openqasm_ir.source)
return jl.BraketSimulator.OpenQasmProgram(
jl_braket_schema_header,
jl_source,
jl_inputs,
)

def _ir_list_to_jl(self, payloads: list[OpenQASMProgram], shots: int):
return [self._openqasm_to_jl(ir) for ir in payloads]

def run_openqasm(
self,
openqasm_ir: OpenQASMProgram,
Expand All @@ -70,29 +71,29 @@ def run_openqasm(
are requested when shots>0.
"""
try:
r = jl.simulate._jl_call_nogil(
self._device, self._openqasm_to_jl(openqasm_ir), shots
)
jl_ir = self._openqasm_to_jl(openqasm_ir)
jl_shots = jl.convert(jl.Int, shots)
jl_result = jl.simulate(self._device, [jl_ir], jl_shots)[0]
except JuliaError as e:
_handle_julia_error(e)
r.additionalMetadata.action = openqasm_ir

result = GateModelTaskResult.parse_raw_schema(jl_result)
result.additionalMetadata.action = openqasm_ir

# attach the result types
if not shots:
r = _result_value_to_ndarray(r)
result = _result_value_to_ndarray(result)
else:
r.resultTypes = [rt.type for rt in r.resultTypes]
return r

def _ir_list_to_jl(self, payloads: list[OpenQASMProgram], shots: int):
return [self._openqasm_to_jl(ir) for ir in payloads]
result.resultTypes = [rt.type for rt in result.resultTypes]
return result

def run_multiple(
self,
programs: Sequence[OpenQASMProgram],
max_parallel: Optional[int] = -1,
shots: Optional[int] = 0,
inputs: Optional[Union[dict, Sequence[dict]]] = None,
) -> list[GateModelTaskResult]:
): # -> list[GateModelTaskResult]:
"""
Run the tasks specified by the given IR programs.
Extra arguments will contain any additional information necessary to run the tasks,
Expand All @@ -107,24 +108,31 @@ def run_multiple(
"""
try:
julia_irs = self._ir_list_to_jl(programs, shots)
results = jl.simulate(
julia_inputs = (
jl.Dict[jl.String, jl.Float64]() if inputs is None else inputs
)
jl_results = jl.simulate(
self._device,
julia_irs,
max_parallel=max_parallel,
shots=shots,
inputs=inputs,
jl.convert(jl.Int, shots),
inputs=julia_inputs,
max_parallel=jl.convert(jl.Int, max_parallel),
)

except JuliaError as e:
_handle_julia_error(e)
results = [
GateModelTaskResult.parse_raw_schema(jl_result) for jl_result in jl_results
]
for p_ix, program in enumerate(programs):
results[p_ix].additionalMetadata.action = program

for r_ix, result in enumerate(results):
results[r_ix].additionalMetadata.action = programs[r_ix]
# attach the result types
if not shots:
results[r_ix] = _result_value_to_ndarray(result)
else:
results[r_ix].resultTypes = [rt.type for rt in result.resultTypes]

return results


Expand All @@ -135,11 +143,36 @@ def _result_value_to_ndarray(
np.ndarray. This must be done because the wrapper Julia simulator results Python lists to comply
with the pydantic specification for ResultTypeValues.
"""

def reconstruct_complex(v):
if isinstance(v, list):
return complex(v[0], v[1])
else:
return v

for result_ind, result_type in enumerate(task_result.resultTypes):
if isinstance(result_type.type, (StateVector, DensityMatrix, Probability)):
task_result.resultTypes[result_ind].value = np.asarray(
task_result.resultTypes[result_ind].value
)
# Amplitude
if isinstance(result_type.value, dict):
val = task_result.resultTypes[result_ind].value
task_result.resultTypes[result_ind].value = {
k: reconstruct_complex(v) for (k, v) in val.items()
}
if isinstance(result_type.type, StateVector):
val = task_result.resultTypes[result_ind].value
# complex are stored as tuples of reals
fixed_val = [reconstruct_complex(v) for v in val]
task_result.resultTypes[result_ind].value = np.asarray(fixed_val)
if isinstance(result_type.type, DensityMatrix):
val = task_result.resultTypes[result_ind].value
# complex are stored as tuples of reals
fixed_val = [
[reconstruct_complex(v) for v in inner_val] for inner_val in val
]
task_result.resultTypes[result_ind].value = np.asarray(fixed_val)
if isinstance(result_type.type, Probability):
val = task_result.resultTypes[result_ind].value
task_result.resultTypes[result_ind].value = np.asarray(val)

return task_result


Expand Down
3 changes: 1 addition & 2 deletions src/braket/simulator_v2/julia_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,5 @@

jl = juliacall.Base.Module()

jl.seval("using PythonCall, BraketSimulator")
jl.seval("using PythonCall: Py, pyconvert")
jl.seval("using JSON3, BraketSimulator")
jlBraketSimulator = jl.BraketSimulator
49 changes: 49 additions & 0 deletions test/unit_tests/braket/simulator_v2/test_qiskit_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import pytest
from qiskit.circuit.library import TwoLocal

# Import some utilities
from qiskit.primitives import BackendEstimator
from qiskit.quantum_info import SparsePauliOp
from qiskit_algorithms import VQE
from qiskit_algorithms.optimizers import SLSQP
from qiskit_braket_provider import BraketLocalBackend

# For now, simply test that this completes in reasonable
# time and doesn't hang due to Python vs Julia threading
# issues or the GIL being locked


@pytest.fixture
def H2_op():
# Define the Hamiltonian operator for H2 in terms of Pauli spin operators
return SparsePauliOp.from_list(
[
("II", -1.052373245772859),
("IZ", 0.39793742484318045),
("ZI", -0.39793742484318045),
("ZZ", -0.01128010425623538),
("XX", 0.18093119978423156),
]
)


@pytest.fixture
def vqe():
local_simulator = BraketLocalBackend(name="braket_sv_v2")
# Define a `BackendEstimator` with a Braket backend
qi = BackendEstimator(local_simulator, options={"seed_simulator": 42})
qi.set_transpile_options(seed_transpiler=42)

# Specify VQE configuration
ansatz = TwoLocal(rotation_blocks="ry", entanglement_blocks="cz")
slsqp = SLSQP(maxiter=1)
return VQE(estimator=qi, ansatz=ansatz, optimizer=slsqp)


@pytest.mark.timeout(10)
def test_qiskit_vqe(H2_op, vqe):
# Find the ground state
print("Computing VQE", flush=True)
result = vqe.compute_minimum_eigenvalue(H2_op)
print("Done computing VQE", flush=True)
assert result.eigenvalue < 0.0

0 comments on commit 58b7ad0

Please sign in to comment.