Skip to content
This repository has been archived by the owner on Oct 11, 2023. It is now read-only.

Commit

Permalink
change: make s3_folder optional (#66)
Browse files Browse the repository at this point in the history
* change: make s3_folder optional, make device_arn optional for BraketSampler

Co-authored-by: Aaron Berdy <[email protected]>
  • Loading branch information
virajvchaudhari and ajberdy authored Feb 17, 2022
1 parent bde2922 commit b7d3ded
Show file tree
Hide file tree
Showing 12 changed files with 62 additions and 61 deletions.
4 changes: 1 addition & 3 deletions examples/braket_dwave_sampler_factoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,8 @@
# Getting the SampleSet from the quantum task object


# Declare folder to save S3 results
s3_destination_folder = ("your-s3-bucket", "your-folder")
# Declare sampler
sampler = BraketDWaveSampler(s3_destination_folder)
sampler = BraketDWaveSampler()

integer_to_factor = 15

Expand Down
3 changes: 1 addition & 2 deletions examples/braket_dwave_sampler_min_vertex.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@

from braket.ocean_plugin import BraketDWaveSampler

s3_destination_folder = ("your-s3-bucket", "your-folder")
sampler = BraketDWaveSampler(s3_destination_folder)
sampler = BraketDWaveSampler()

star_graph = nx.star_graph(4) # star graph where node 0 is connected to 4 other nodes

Expand Down
11 changes: 3 additions & 8 deletions examples/braket_sampler_min_vertex.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,13 @@

import dwave_networkx as dnx
import networkx as nx
from braket.aws import AwsDevice
from dwave.system.composites import EmbeddingComposite

from braket.ocean_plugin import BraketSampler

s3_destination_folder = ("your-s3-bucket", "your-folder")

# Get an online D-Wave device ARN
device_arn = AwsDevice.get_devices(provider_names=["D-Wave Systems"], statuses=["ONLINE"])[0].arn
print("Using device ARN", device_arn)

sampler = BraketSampler(s3_destination_folder, device_arn)
# Use a default online D-Wave device ARN
sampler = BraketSampler()
print("Using device ARN", sampler.solver.arn)

star_graph = nx.star_graph(4) # star graph where node 0 is connected to 4 other nodes

Expand Down
4 changes: 1 addition & 3 deletions examples/debug_braket_dwave_sampler_min_vertex.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,8 @@
) # configure to log to file
logger.setLevel(logging.DEBUG) # log to file all log messages with level DEBUG or above

s3_destination_folder = ("your-s3-bucket", "your-folder")

# Pass in logger to BraketDWaveSampler
sampler = BraketDWaveSampler(s3_destination_folder, logger=logger)
sampler = BraketDWaveSampler(logger=logger)

star_graph = nx.star_graph(4) # star graph where node 0 is connected to 4 other nodes

Expand Down
9 changes: 1 addition & 8 deletions examples/debug_braket_sampler_min_vertex.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import dwave_networkx as dnx
import networkx as nx
from braket.aws import AwsDevice
from dwave.system.composites import EmbeddingComposite

from braket.ocean_plugin import BraketSampler
Expand All @@ -25,14 +24,8 @@
logger.addHandler(logging.StreamHandler(stream=sys.stdout)) # configure to print to sys.stdout
logger.setLevel(logging.DEBUG) # print to sys.stdout all log messages with level DEBUG or above

s3_destination_folder = ("your-s3-bucket", "your-folder")

# Get an online D-Wave device ARN
device_arn = AwsDevice.get_devices(provider_names=["D-Wave Systems"], statuses=["ONLINE"])[0].arn
print("Using device ARN", device_arn)

# Pass in logger to BraketSampler
sampler = BraketSampler(s3_destination_folder, device_arn, logger=logger)
sampler = BraketSampler(logger=logger)

star_graph = nx.star_graph(4) # star graph where node 0 is connected to 4 other nodes

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
packages=find_namespace_packages(where="src", exclude=("test",)),
package_dir={"": "src"},
install_requires=[
"amazon-braket-sdk",
"amazon-braket-sdk>=1.10.0",
"boto3>=1.18.13",
"boltons>=20.0.0",
"colorama>=0.4.3",
Expand Down
27 changes: 10 additions & 17 deletions src/braket/ocean_plugin/braket_dwave_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import jsonref
from boltons.dictutils import FrozenDict
from braket.aws import AwsDevice, AwsSession
from braket.aws import AwsSession
from braket.tasks import QuantumTask
from dimod import SampleSet

Expand Down Expand Up @@ -48,18 +48,11 @@ class BraketDWaveSampler(BraketSampler):

def __init__(
self,
s3_destination_folder: AwsSession.S3DestinationFolder,
s3_destination_folder: AwsSession.S3DestinationFolder = None,
device_arn: str = None,
aws_session: AwsSession = None,
logger: Logger = getLogger(__name__),
):
if not device_arn:
try:
device_arn = AwsDevice.get_devices(
provider_names=["D-Wave Systems"], statuses=["ONLINE"]
)[0].arn
except IndexError:
raise RuntimeError("No D-Wave devices online")
super().__init__(s3_destination_folder, device_arn, aws_session, logger)

@property
Expand Down Expand Up @@ -138,7 +131,7 @@ def sample_ising(
>>> from braket.ocean_plugin import BraketDWaveSampler
>>> device_arn_1 = "arn:aws:braket:::device/qpu/d-wave/DW_2000Q_6"
>>> sampler = BraketDWaveSampler(s3_destination_folder, device_arn_1)
>>> sampler = BraketDWaveSampler(device_arn_1)
>>> h = {0: -1, 1: 1}
>>> sampleset = sampler.sample_ising(h, {}, answer_mode="HISTOGRAM")
>>> for sample in sampleset.samples():
Expand All @@ -151,7 +144,7 @@ def sample_ising(
>>> from braket.ocean_plugin import BraketDWaveSampler
>>> device_arn_1 = "arn:aws:braket:::device/qpu/d-wave/Advantage_system4"
>>> sampler = BraketDWaveSampler(s3_destination_folder, device_arn_1)
>>> sampler = BraketDWaveSampler(device_arn_1)
>>> h = {30: -1, 31: 1}
>>> sampleset = sampler.sample_ising(h, {}, answer_mode="HISTOGRAM")
>>> for sample in sampleset.samples():
Expand Down Expand Up @@ -191,7 +184,7 @@ def sample_ising_quantum_task(
>>> from braket.ocean_plugin import BraketDWaveSampler
>>> device_arn_1 = "arn:aws:braket:::device/qpu/d-wave/DW_2000Q_6"
>>> sampler = BraketDWaveSampler(s3_destination_folder, device_arn_1)
>>> sampler = BraketDWaveSampler(device_arn_1)
>>> Q = {0: 1, 1: 1}
>>> task = sampler.sample_ising_quantum_task(Q, {}, answer_mode="HISTOGRAM")
>>> sampleset = BraketDWaveSampler.get_task_sample_set(task)
Expand All @@ -206,7 +199,7 @@ def sample_ising_quantum_task(
>>> from braket.ocean_plugin import BraketDWaveSampler
>>> device_arn_1 = "arn:aws:braket:::device/qpu/d-wave/Advantage_system4"
>>> sampler = BraketDWaveSampler(s3_destination_folder, device_arn_1)
>>> sampler = BraketDWaveSampler(device_arn_1)
>>> Q = {30: 1, 31: 1}
>>> task = sampler.sample_ising_quantum_task(Q, {}, answer_mode="HISTOGRAM")
>>> sampleset = BraketDWaveSampler.get_task_sample_set(task)
Expand Down Expand Up @@ -236,7 +229,7 @@ def sample_qubo(self, Q: Dict[Tuple[int, int], float], **kwargs) -> SampleSet:
>>> from braket.ocean_plugin import BraketDWaveSampler
>>> device_arn_1 = "arn:aws:braket:::device/qpu/d-wave/DW_2000Q_6"
>>> sampler = BraketDWaveSampler(s3_destination_folder, device_arn_1)
>>> sampler = BraketDWaveSampler(device_arn_1)
>>> Q = {(0, 0): -1, (4, 4): -1, (0, 4): 2}
>>> sampleset = sampler.sample_qubo(Q, postprocess="SAMPLING", num_reads=100)
>>> for sample in sampleset.samples():
Expand All @@ -249,7 +242,7 @@ def sample_qubo(self, Q: Dict[Tuple[int, int], float], **kwargs) -> SampleSet:
30 and 31 on a sampler on the D-Wave Advantage4 device.
>>> from braket.ocean_plugin import BraketDWaveSampler
>>> device_arn_1 = "arn:aws:braket:::device/qpu/d-wave/Advantage_system4"
>>> sampler = BraketDWaveSampler(s3_destination_folder, device_arn_1)
>>> sampler = BraketDWaveSampler(device_arn_1)
>>> Q = {(30, 30): -1, (31, 31): -1, (30, 31): 2}
>>> sampleset = sampler.sample_qubo(Q, num_reads=100)
>>> for sample in sampleset.samples():
Expand Down Expand Up @@ -280,7 +273,7 @@ def sample_qubo_quantum_task(self, Q: Dict[Tuple[int, int], float], **kwargs) ->
>>> from braket.ocean_plugin import BraketDWaveSampler
>>> device_arn_1 = "arn:aws:braket:::device/qpu/d-wave/DW_2000Q_6"
>>> sampler = BraketDWaveSampler(s3_destination_folder, device_arn_1)
>>> sampler = BraketDWaveSampler(device_arn_1)
>>> Q = {(0, 0): -1, (4, 4): -1, (0, 4): 2}
>>> task = sampler.sample_qubo_quantum_task(Q, answer_mode="HISTOGRAM", num_reads=100)
>>> sampleset = BraketDWaveSampler.get_task_sample_set(task)
Expand All @@ -295,7 +288,7 @@ def sample_qubo_quantum_task(self, Q: Dict[Tuple[int, int], float], **kwargs) ->
>>> from braket.ocean_plugin import BraketDWaveSampler
>>> device_arn_1 = "arn:aws:braket:::device/qpu/d-wave/Advantage_system4"
>>> sampler = BraketDWaveSampler(s3_destination_folder, device_arn_1)
>>> sampler = BraketDWaveSampler(device_arn_1)
>>> Q = {(30, 30): -1, (31, 31): -1, (30, 31): 2}
>>> task = sampler.sample_qubo_quantum_task(Q, answer_mode="HISTOGRAM", num_reads=100)
>>> sampleset = BraketDWaveSampler.get_task_sample_set(task)
Expand Down
29 changes: 18 additions & 11 deletions src/braket/ocean_plugin/braket_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,19 @@ class BraketSampler(Sampler, Structured):

def __init__(
self,
s3_destination_folder: AwsSession.S3DestinationFolder,
device_arn: str,
s3_destination_folder: AwsSession.S3DestinationFolder = None,
device_arn: str = None,
aws_session: AwsSession = None,
logger: Logger = getLogger(__name__),
):
if not device_arn:
try:
device_arn = AwsDevice.get_devices(
provider_names=["D-Wave Systems"], statuses=["ONLINE"]
)[0].arn
except IndexError:
raise RuntimeError("No D-Wave devices online")

self._s3_destination_folder = s3_destination_folder
self._device_arn = device_arn
self._logger = logger
Expand Down Expand Up @@ -179,7 +187,7 @@ def sample_ising(
>>> from braket.ocean_plugin import BraketSampler
>>> device_arn_1 = "arn:aws:braket:::device/qpu/d-wave/DW_2000Q_6"
>>> sampler = BraketSampler(s3_destination_folder, device_arn_1)
>>> sampler = BraketSampler(device_arn_1)
>>> h = {0: -1, 1: 1}
>>> sampleset = sampler.sample_ising(h, {}, resultFormat="HISTOGRAM")
>>> for sample in sampleset.samples():
Expand All @@ -192,7 +200,7 @@ def sample_ising(
>>> from braket.ocean_plugin import BraketSampler
>>> device_arn_1 = "arn:aws:braket:::device/qpu/d-wave/Advantage_system4"
>>> sampler = BraketSampler(s3_destination_folder, device_arn_1)
>>> sampler = BraketSampler(device_arn_1)
>>> h = {30: -1, 31: 1}
>>> sampleset = sampler.sample_ising(h, {}, resultFormat="HISTOGRAM")
>>> for sample in sampleset.samples():
Expand Down Expand Up @@ -243,7 +251,7 @@ def sample_ising_quantum_task(
>>> from braket.ocean_plugin import BraketSampler
>>> device_arn_1 = "arn:aws:braket:::device/qpu/d-wave/DW_2000Q_6"
>>> sampler = BraketSampler(s3_destination_folder, device_arn_1)
>>> sampler = BraketSampler(device_arn_1)
>>> Q = {0: 1, 1: 1}
>>> task = sampler.sample_ising_quantum_task(Q, {}, resultFormat="HISTOGRAM")
>>> sampleset = BraketSampler.get_task_sample_set(task)
Expand All @@ -257,7 +265,7 @@ def sample_ising_quantum_task(
>>> from braket.ocean_plugin import BraketSampler
>>> device_arn_1 = "arn:aws:braket:::device/qpu/d-wave/Advantage_system4"
>>> sampler = BraketSampler(s3_destination_folder, device_arn_1)
>>> sampler = BraketSampler(device_arn_1)
>>> Q = {30: 1, 31: 1}
>>> task = sampler.sample_ising_quantum_task(Q, {}, resultFormat="HISTOGRAM")
>>> sampleset = BraketSampler.get_task_sample_set(task)
Expand Down Expand Up @@ -308,7 +316,7 @@ def sample_qubo(self, Q: Dict[Tuple[int, int], float], **kwargs) -> SampleSet:
>>> from braket.ocean_plugin import BraketSampler
>>> device_arn_1 = "arn:aws:braket:::device/qpu/d-wave/DW_2000Q_6"
>>> sampler = BraketSampler(s3_destination_folder, device_arn_1)
>>> sampler = BraketSampler(device_arn_1)
>>> Q = {(0, 0): -1, (4, 4): -1, (0, 4): 2}
>>> sampleset = sampler.sample_qubo(Q, postprocessingType="SAMPLING", shots=100)
>>> for sample in sampleset.samples():
Expand All @@ -321,7 +329,7 @@ def sample_qubo(self, Q: Dict[Tuple[int, int], float], **kwargs) -> SampleSet:
30 and 31 on a sampler on the D-Wave Advantage4 device.
>>> from braket.ocean_plugin import BraketSampler
>>> device_arn_1 = "arn:aws:braket:::device/qpu/d-wave/Advantage_system4"
>>> sampler = BraketSampler(s3_destination_folder, device_arn_1)
>>> sampler = BraketSampler(device_arn_1)
>>> Q = {(30, 30): -1, (31, 31): -1, (30, 31): 2}
>>> sampleset = sampler.sample_qubo(Q, shots=100)
>>> for sample in sampleset.samples():
Expand Down Expand Up @@ -359,7 +367,7 @@ def sample_qubo_quantum_task(self, Q: Dict[Tuple[int, int], float], **kwargs) ->
>>> from braket.ocean_plugin import BraketSampler
>>> device_arn_1 = "arn:aws:braket:::device/qpu/d-wave/DW_2000Q_6"
>>> sampler = BraketSampler(s3_destination_folder, device_arn_1)
>>> sampler = BraketSampler(device_arn_1)
>>> Q = {(0, 0): -1, (4, 4): -1, (0, 4): 2}
>>> task = sampler.sample_qubo_quantum_task(Q, resultFormat="HISTOGRAM", shots=100)
>>> sampleset = BraketSampler.get_task_sample_set(task)
Expand All @@ -374,7 +382,7 @@ def sample_qubo_quantum_task(self, Q: Dict[Tuple[int, int], float], **kwargs) ->
>>> from braket.ocean_plugin import BraketSampler
>>> device_arn_1 = "arn:aws:braket:::device/qpu/d-wave/Advantage_system4"
>>> sampler = BraketSampler(s3_destination_folder, device_arn_1)
>>> sampler = BraketSampler(device_arn_1)
>>> Q = {(30, 30): -1, (31, 31): -1, (30, 31): 2}
>>> task = sampler.sample_qubo_quantum_task(Q, resultFormat="HISTOGRAM", shots=100)
>>> sampleset = BraketSampler.get_task_sample_set(task)
Expand All @@ -387,7 +395,6 @@ def sample_qubo_quantum_task(self, Q: Dict[Tuple[int, int], float], **kwargs) ->
solver_kwargs = self._process_solver_kwargs(**kwargs)

sorted_edges = frozenset((u, v) if u < v else (v, u) for u, v in Q)
print(self._access_optimized_edgelist())
for u, v in sorted_edges:
if u not in self._access_optimized_nodelist():
raise BinaryQuadraticModelStructureError(
Expand Down
2 changes: 1 addition & 1 deletion test/integ_tests/test_braket_dwave_sampler_running.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def test_factoring_embedded_composite(
dwave_arn, aws_session, s3_destination_folder, factoring_bqm, integer_to_factor
):
sampler = BraketDWaveSampler(
s3_destination_folder, device_arn=dwave_arn, aws_session=aws_session
s3_destination_folder=s3_destination_folder, device_arn=dwave_arn, aws_session=aws_session
)
embedding_sampler = EmbeddingComposite(sampler)
response = embedding_sampler.sample(
Expand Down
4 changes: 3 additions & 1 deletion test/integ_tests/test_braket_sampler_running.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
def test_factoring_minorminer(
dwave_arn, aws_session, s3_destination_folder, factoring_bqm, integer_to_factor
):
sampler = BraketSampler(s3_destination_folder, device_arn=dwave_arn, aws_session=aws_session)
sampler = BraketSampler(
s3_destination_folder=s3_destination_folder, device_arn=dwave_arn, aws_session=aws_session
)
_, target_edgelist, target_adjacency = sampler.structure
embedding = minorminer.find_embedding(factoring_bqm.quadratic, target_edgelist)
bqm_embedded = embed_bqm(factoring_bqm, embedding, target_adjacency, 3.0)
Expand Down
11 changes: 5 additions & 6 deletions test/unit_tests/braket/ocean_plugin/test_braket_dwave_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def device_parameters_2():
return {BraketSolverMetadata.DWAVE["device_parameters_key_name"]: {}}


# Removed s3_destination_folder fixture parameter.
@pytest.fixture
@patch("braket.ocean_plugin.braket_sampler.AwsDevice")
def braket_dwave_sampler(
Expand All @@ -62,9 +63,7 @@ def braket_dwave_sampler(


@patch("braket.ocean_plugin.braket_sampler.AwsDevice")
@patch("braket.ocean_plugin.braket_dwave_sampler.AwsDevice")
def test_default_device_arn(
dwave_sampler_mock_qpu,
sampler_mock_qpu,
braket_sampler_properties,
s3_destination_folder,
Expand All @@ -73,19 +72,19 @@ def test_default_device_arn(
):
mock_device = Mock()
mock_device.arn = dwave_arn
dwave_sampler_mock_qpu.get_devices.return_value = [mock_device]
sampler_mock_qpu.get_devices.return_value = [mock_device]
sampler_mock_qpu.return_value.properties = braket_sampler_properties
sampler = BraketDWaveSampler(s3_destination_folder, None, Mock(), logger)
assert isinstance(sampler, BraketSampler)
assert sampler._device_arn == dwave_arn


@pytest.mark.xfail(raises=RuntimeError)
@patch("braket.ocean_plugin.braket_dwave_sampler.AwsDevice")
def test_default_device_arn_error(dwave_sampler_mock_qpu, s3_destination_folder, logger, dwave_arn):
@patch("braket.ocean_plugin.braket_sampler.AwsDevice")
def test_default_device_arn_error(sampler_mock_qpu, s3_destination_folder, logger, dwave_arn):
mock_device = Mock()
mock_device.arn = dwave_arn
dwave_sampler_mock_qpu.get_devices.return_value = []
sampler_mock_qpu.get_devices.return_value = []
BraketDWaveSampler(s3_destination_folder, None, Mock(), logger)


Expand Down
17 changes: 17 additions & 0 deletions test/unit_tests/braket/ocean_plugin/test_braket_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,23 @@ def test_nodelist(braket_sampler):
assert braket_sampler.nodelist == (0, 1, 2)


@patch("braket.ocean_plugin.braket_sampler.AwsDevice")
def test_default_device_arn(
sampler_mock_qpu,
braket_sampler_properties,
s3_destination_folder,
logger,
dwave_arn,
):
mock_device = Mock()
mock_device.arn = dwave_arn
sampler_mock_qpu.get_devices.return_value = [mock_device]
sampler_mock_qpu.return_value.properties = braket_sampler_properties
sampler = BraketSampler(s3_destination_folder, None, Mock(), logger)
assert isinstance(sampler, BraketSampler)
assert sampler._device_arn == dwave_arn


@pytest.mark.xfail(raises=BinaryQuadraticModelStructureError)
@pytest.mark.parametrize(
"h, J", [({0: -1, 500: 1}, {}), ({0: -1, 1: 1}, {(0, 1): 3}), ({0: -1, 2: 1}, {(3, 500): 3})]
Expand Down

0 comments on commit b7d3ded

Please sign in to comment.