diff --git a/examples/braket_dwave_sampler_factoring.py b/examples/braket_dwave_sampler_factoring.py index f5d7163..84a66f3 100644 --- a/examples/braket_dwave_sampler_factoring.py +++ b/examples/braket_dwave_sampler_factoring.py @@ -26,10 +26,8 @@ # Getting the SampleSet from the quantum task object -# Declare folder to save S3 results -s3_destination_folder = ("your-s3-bucket", "your-folder") # Declare sampler -sampler = BraketDWaveSampler(s3_destination_folder) +sampler = BraketDWaveSampler() integer_to_factor = 15 diff --git a/examples/braket_dwave_sampler_min_vertex.py b/examples/braket_dwave_sampler_min_vertex.py index d7b21db..495fa93 100644 --- a/examples/braket_dwave_sampler_min_vertex.py +++ b/examples/braket_dwave_sampler_min_vertex.py @@ -17,8 +17,7 @@ from braket.ocean_plugin import BraketDWaveSampler -s3_destination_folder = ("your-s3-bucket", "your-folder") -sampler = BraketDWaveSampler(s3_destination_folder) +sampler = BraketDWaveSampler() star_graph = nx.star_graph(4) # star graph where node 0 is connected to 4 other nodes diff --git a/examples/braket_sampler_min_vertex.py b/examples/braket_sampler_min_vertex.py index e9c8330..72c6811 100644 --- a/examples/braket_sampler_min_vertex.py +++ b/examples/braket_sampler_min_vertex.py @@ -13,18 +13,13 @@ import dwave_networkx as dnx import networkx as nx -from braket.aws import AwsDevice from dwave.system.composites import EmbeddingComposite from braket.ocean_plugin import BraketSampler -s3_destination_folder = ("your-s3-bucket", "your-folder") - -# Get an online D-Wave device ARN -device_arn = AwsDevice.get_devices(provider_names=["D-Wave Systems"], statuses=["ONLINE"])[0].arn -print("Using device ARN", device_arn) - -sampler = BraketSampler(s3_destination_folder, device_arn) +# Use a default online D-Wave device ARN +sampler = BraketSampler() +print("Using device ARN", sampler.solver.arn) star_graph = nx.star_graph(4) # star graph where node 0 is connected to 4 other nodes diff --git a/examples/debug_braket_dwave_sampler_min_vertex.py b/examples/debug_braket_dwave_sampler_min_vertex.py index cf987f4..1cf9116 100644 --- a/examples/debug_braket_dwave_sampler_min_vertex.py +++ b/examples/debug_braket_dwave_sampler_min_vertex.py @@ -25,10 +25,8 @@ ) # configure to log to file logger.setLevel(logging.DEBUG) # log to file all log messages with level DEBUG or above -s3_destination_folder = ("your-s3-bucket", "your-folder") - # Pass in logger to BraketDWaveSampler -sampler = BraketDWaveSampler(s3_destination_folder, logger=logger) +sampler = BraketDWaveSampler(logger=logger) star_graph = nx.star_graph(4) # star graph where node 0 is connected to 4 other nodes diff --git a/examples/debug_braket_sampler_min_vertex.py b/examples/debug_braket_sampler_min_vertex.py index a4e2c7c..0b2baa3 100644 --- a/examples/debug_braket_sampler_min_vertex.py +++ b/examples/debug_braket_sampler_min_vertex.py @@ -16,7 +16,6 @@ import dwave_networkx as dnx import networkx as nx -from braket.aws import AwsDevice from dwave.system.composites import EmbeddingComposite from braket.ocean_plugin import BraketSampler @@ -25,14 +24,8 @@ logger.addHandler(logging.StreamHandler(stream=sys.stdout)) # configure to print to sys.stdout logger.setLevel(logging.DEBUG) # print to sys.stdout all log messages with level DEBUG or above -s3_destination_folder = ("your-s3-bucket", "your-folder") - -# Get an online D-Wave device ARN -device_arn = AwsDevice.get_devices(provider_names=["D-Wave Systems"], statuses=["ONLINE"])[0].arn -print("Using device ARN", device_arn) - # Pass in logger to BraketSampler -sampler = BraketSampler(s3_destination_folder, device_arn, logger=logger) +sampler = BraketSampler(logger=logger) star_graph = nx.star_graph(4) # star graph where node 0 is connected to 4 other nodes diff --git a/setup.py b/setup.py index 3de64d6..236be29 100644 --- a/setup.py +++ b/setup.py @@ -27,7 +27,7 @@ packages=find_namespace_packages(where="src", exclude=("test",)), package_dir={"": "src"}, install_requires=[ - "amazon-braket-sdk", + "amazon-braket-sdk>=1.10.0", "boto3>=1.18.13", "boltons>=20.0.0", "colorama>=0.4.3", diff --git a/src/braket/ocean_plugin/braket_dwave_sampler.py b/src/braket/ocean_plugin/braket_dwave_sampler.py index 2c649c8..a244f90 100644 --- a/src/braket/ocean_plugin/braket_dwave_sampler.py +++ b/src/braket/ocean_plugin/braket_dwave_sampler.py @@ -20,7 +20,7 @@ import jsonref from boltons.dictutils import FrozenDict -from braket.aws import AwsDevice, AwsSession +from braket.aws import AwsSession from braket.tasks import QuantumTask from dimod import SampleSet @@ -48,18 +48,11 @@ class BraketDWaveSampler(BraketSampler): def __init__( self, - s3_destination_folder: AwsSession.S3DestinationFolder, + s3_destination_folder: AwsSession.S3DestinationFolder = None, device_arn: str = None, aws_session: AwsSession = None, logger: Logger = getLogger(__name__), ): - if not device_arn: - try: - device_arn = AwsDevice.get_devices( - provider_names=["D-Wave Systems"], statuses=["ONLINE"] - )[0].arn - except IndexError: - raise RuntimeError("No D-Wave devices online") super().__init__(s3_destination_folder, device_arn, aws_session, logger) @property @@ -138,7 +131,7 @@ def sample_ising( >>> from braket.ocean_plugin import BraketDWaveSampler >>> device_arn_1 = "arn:aws:braket:::device/qpu/d-wave/DW_2000Q_6" - >>> sampler = BraketDWaveSampler(s3_destination_folder, device_arn_1) + >>> sampler = BraketDWaveSampler(device_arn_1) >>> h = {0: -1, 1: 1} >>> sampleset = sampler.sample_ising(h, {}, answer_mode="HISTOGRAM") >>> for sample in sampleset.samples(): @@ -151,7 +144,7 @@ def sample_ising( >>> from braket.ocean_plugin import BraketDWaveSampler >>> device_arn_1 = "arn:aws:braket:::device/qpu/d-wave/Advantage_system4" - >>> sampler = BraketDWaveSampler(s3_destination_folder, device_arn_1) + >>> sampler = BraketDWaveSampler(device_arn_1) >>> h = {30: -1, 31: 1} >>> sampleset = sampler.sample_ising(h, {}, answer_mode="HISTOGRAM") >>> for sample in sampleset.samples(): @@ -191,7 +184,7 @@ def sample_ising_quantum_task( >>> from braket.ocean_plugin import BraketDWaveSampler >>> device_arn_1 = "arn:aws:braket:::device/qpu/d-wave/DW_2000Q_6" - >>> sampler = BraketDWaveSampler(s3_destination_folder, device_arn_1) + >>> sampler = BraketDWaveSampler(device_arn_1) >>> Q = {0: 1, 1: 1} >>> task = sampler.sample_ising_quantum_task(Q, {}, answer_mode="HISTOGRAM") >>> sampleset = BraketDWaveSampler.get_task_sample_set(task) @@ -206,7 +199,7 @@ def sample_ising_quantum_task( >>> from braket.ocean_plugin import BraketDWaveSampler >>> device_arn_1 = "arn:aws:braket:::device/qpu/d-wave/Advantage_system4" - >>> sampler = BraketDWaveSampler(s3_destination_folder, device_arn_1) + >>> sampler = BraketDWaveSampler(device_arn_1) >>> Q = {30: 1, 31: 1} >>> task = sampler.sample_ising_quantum_task(Q, {}, answer_mode="HISTOGRAM") >>> sampleset = BraketDWaveSampler.get_task_sample_set(task) @@ -236,7 +229,7 @@ def sample_qubo(self, Q: Dict[Tuple[int, int], float], **kwargs) -> SampleSet: >>> from braket.ocean_plugin import BraketDWaveSampler >>> device_arn_1 = "arn:aws:braket:::device/qpu/d-wave/DW_2000Q_6" - >>> sampler = BraketDWaveSampler(s3_destination_folder, device_arn_1) + >>> sampler = BraketDWaveSampler(device_arn_1) >>> Q = {(0, 0): -1, (4, 4): -1, (0, 4): 2} >>> sampleset = sampler.sample_qubo(Q, postprocess="SAMPLING", num_reads=100) >>> for sample in sampleset.samples(): @@ -249,7 +242,7 @@ def sample_qubo(self, Q: Dict[Tuple[int, int], float], **kwargs) -> SampleSet: 30 and 31 on a sampler on the D-Wave Advantage4 device. >>> from braket.ocean_plugin import BraketDWaveSampler >>> device_arn_1 = "arn:aws:braket:::device/qpu/d-wave/Advantage_system4" - >>> sampler = BraketDWaveSampler(s3_destination_folder, device_arn_1) + >>> sampler = BraketDWaveSampler(device_arn_1) >>> Q = {(30, 30): -1, (31, 31): -1, (30, 31): 2} >>> sampleset = sampler.sample_qubo(Q, num_reads=100) >>> for sample in sampleset.samples(): @@ -280,7 +273,7 @@ def sample_qubo_quantum_task(self, Q: Dict[Tuple[int, int], float], **kwargs) -> >>> from braket.ocean_plugin import BraketDWaveSampler >>> device_arn_1 = "arn:aws:braket:::device/qpu/d-wave/DW_2000Q_6" - >>> sampler = BraketDWaveSampler(s3_destination_folder, device_arn_1) + >>> sampler = BraketDWaveSampler(device_arn_1) >>> Q = {(0, 0): -1, (4, 4): -1, (0, 4): 2} >>> task = sampler.sample_qubo_quantum_task(Q, answer_mode="HISTOGRAM", num_reads=100) >>> sampleset = BraketDWaveSampler.get_task_sample_set(task) @@ -295,7 +288,7 @@ def sample_qubo_quantum_task(self, Q: Dict[Tuple[int, int], float], **kwargs) -> >>> from braket.ocean_plugin import BraketDWaveSampler >>> device_arn_1 = "arn:aws:braket:::device/qpu/d-wave/Advantage_system4" - >>> sampler = BraketDWaveSampler(s3_destination_folder, device_arn_1) + >>> sampler = BraketDWaveSampler(device_arn_1) >>> Q = {(30, 30): -1, (31, 31): -1, (30, 31): 2} >>> task = sampler.sample_qubo_quantum_task(Q, answer_mode="HISTOGRAM", num_reads=100) >>> sampleset = BraketDWaveSampler.get_task_sample_set(task) diff --git a/src/braket/ocean_plugin/braket_sampler.py b/src/braket/ocean_plugin/braket_sampler.py index c50ec0f..d3aab7e 100644 --- a/src/braket/ocean_plugin/braket_sampler.py +++ b/src/braket/ocean_plugin/braket_sampler.py @@ -51,11 +51,19 @@ class BraketSampler(Sampler, Structured): def __init__( self, - s3_destination_folder: AwsSession.S3DestinationFolder, - device_arn: str, + s3_destination_folder: AwsSession.S3DestinationFolder = None, + device_arn: str = None, aws_session: AwsSession = None, logger: Logger = getLogger(__name__), ): + if not device_arn: + try: + device_arn = AwsDevice.get_devices( + provider_names=["D-Wave Systems"], statuses=["ONLINE"] + )[0].arn + except IndexError: + raise RuntimeError("No D-Wave devices online") + self._s3_destination_folder = s3_destination_folder self._device_arn = device_arn self._logger = logger @@ -179,7 +187,7 @@ def sample_ising( >>> from braket.ocean_plugin import BraketSampler >>> device_arn_1 = "arn:aws:braket:::device/qpu/d-wave/DW_2000Q_6" - >>> sampler = BraketSampler(s3_destination_folder, device_arn_1) + >>> sampler = BraketSampler(device_arn_1) >>> h = {0: -1, 1: 1} >>> sampleset = sampler.sample_ising(h, {}, resultFormat="HISTOGRAM") >>> for sample in sampleset.samples(): @@ -192,7 +200,7 @@ def sample_ising( >>> from braket.ocean_plugin import BraketSampler >>> device_arn_1 = "arn:aws:braket:::device/qpu/d-wave/Advantage_system4" - >>> sampler = BraketSampler(s3_destination_folder, device_arn_1) + >>> sampler = BraketSampler(device_arn_1) >>> h = {30: -1, 31: 1} >>> sampleset = sampler.sample_ising(h, {}, resultFormat="HISTOGRAM") >>> for sample in sampleset.samples(): @@ -243,7 +251,7 @@ def sample_ising_quantum_task( >>> from braket.ocean_plugin import BraketSampler >>> device_arn_1 = "arn:aws:braket:::device/qpu/d-wave/DW_2000Q_6" - >>> sampler = BraketSampler(s3_destination_folder, device_arn_1) + >>> sampler = BraketSampler(device_arn_1) >>> Q = {0: 1, 1: 1} >>> task = sampler.sample_ising_quantum_task(Q, {}, resultFormat="HISTOGRAM") >>> sampleset = BraketSampler.get_task_sample_set(task) @@ -257,7 +265,7 @@ def sample_ising_quantum_task( >>> from braket.ocean_plugin import BraketSampler >>> device_arn_1 = "arn:aws:braket:::device/qpu/d-wave/Advantage_system4" - >>> sampler = BraketSampler(s3_destination_folder, device_arn_1) + >>> sampler = BraketSampler(device_arn_1) >>> Q = {30: 1, 31: 1} >>> task = sampler.sample_ising_quantum_task(Q, {}, resultFormat="HISTOGRAM") >>> sampleset = BraketSampler.get_task_sample_set(task) @@ -308,7 +316,7 @@ def sample_qubo(self, Q: Dict[Tuple[int, int], float], **kwargs) -> SampleSet: >>> from braket.ocean_plugin import BraketSampler >>> device_arn_1 = "arn:aws:braket:::device/qpu/d-wave/DW_2000Q_6" - >>> sampler = BraketSampler(s3_destination_folder, device_arn_1) + >>> sampler = BraketSampler(device_arn_1) >>> Q = {(0, 0): -1, (4, 4): -1, (0, 4): 2} >>> sampleset = sampler.sample_qubo(Q, postprocessingType="SAMPLING", shots=100) >>> for sample in sampleset.samples(): @@ -321,7 +329,7 @@ def sample_qubo(self, Q: Dict[Tuple[int, int], float], **kwargs) -> SampleSet: 30 and 31 on a sampler on the D-Wave Advantage4 device. >>> from braket.ocean_plugin import BraketSampler >>> device_arn_1 = "arn:aws:braket:::device/qpu/d-wave/Advantage_system4" - >>> sampler = BraketSampler(s3_destination_folder, device_arn_1) + >>> sampler = BraketSampler(device_arn_1) >>> Q = {(30, 30): -1, (31, 31): -1, (30, 31): 2} >>> sampleset = sampler.sample_qubo(Q, shots=100) >>> for sample in sampleset.samples(): @@ -359,7 +367,7 @@ def sample_qubo_quantum_task(self, Q: Dict[Tuple[int, int], float], **kwargs) -> >>> from braket.ocean_plugin import BraketSampler >>> device_arn_1 = "arn:aws:braket:::device/qpu/d-wave/DW_2000Q_6" - >>> sampler = BraketSampler(s3_destination_folder, device_arn_1) + >>> sampler = BraketSampler(device_arn_1) >>> Q = {(0, 0): -1, (4, 4): -1, (0, 4): 2} >>> task = sampler.sample_qubo_quantum_task(Q, resultFormat="HISTOGRAM", shots=100) >>> sampleset = BraketSampler.get_task_sample_set(task) @@ -374,7 +382,7 @@ def sample_qubo_quantum_task(self, Q: Dict[Tuple[int, int], float], **kwargs) -> >>> from braket.ocean_plugin import BraketSampler >>> device_arn_1 = "arn:aws:braket:::device/qpu/d-wave/Advantage_system4" - >>> sampler = BraketSampler(s3_destination_folder, device_arn_1) + >>> sampler = BraketSampler(device_arn_1) >>> Q = {(30, 30): -1, (31, 31): -1, (30, 31): 2} >>> task = sampler.sample_qubo_quantum_task(Q, resultFormat="HISTOGRAM", shots=100) >>> sampleset = BraketSampler.get_task_sample_set(task) @@ -387,7 +395,6 @@ def sample_qubo_quantum_task(self, Q: Dict[Tuple[int, int], float], **kwargs) -> solver_kwargs = self._process_solver_kwargs(**kwargs) sorted_edges = frozenset((u, v) if u < v else (v, u) for u, v in Q) - print(self._access_optimized_edgelist()) for u, v in sorted_edges: if u not in self._access_optimized_nodelist(): raise BinaryQuadraticModelStructureError( diff --git a/test/integ_tests/test_braket_dwave_sampler_running.py b/test/integ_tests/test_braket_dwave_sampler_running.py index 7c237bc..0cb31e3 100644 --- a/test/integ_tests/test_braket_dwave_sampler_running.py +++ b/test/integ_tests/test_braket_dwave_sampler_running.py @@ -21,7 +21,7 @@ def test_factoring_embedded_composite( dwave_arn, aws_session, s3_destination_folder, factoring_bqm, integer_to_factor ): sampler = BraketDWaveSampler( - s3_destination_folder, device_arn=dwave_arn, aws_session=aws_session + s3_destination_folder=s3_destination_folder, device_arn=dwave_arn, aws_session=aws_session ) embedding_sampler = EmbeddingComposite(sampler) response = embedding_sampler.sample( diff --git a/test/integ_tests/test_braket_sampler_running.py b/test/integ_tests/test_braket_sampler_running.py index 8902f99..9f79154 100644 --- a/test/integ_tests/test_braket_sampler_running.py +++ b/test/integ_tests/test_braket_sampler_running.py @@ -21,7 +21,9 @@ def test_factoring_minorminer( dwave_arn, aws_session, s3_destination_folder, factoring_bqm, integer_to_factor ): - sampler = BraketSampler(s3_destination_folder, device_arn=dwave_arn, aws_session=aws_session) + sampler = BraketSampler( + s3_destination_folder=s3_destination_folder, device_arn=dwave_arn, aws_session=aws_session + ) _, target_edgelist, target_adjacency = sampler.structure embedding = minorminer.find_embedding(factoring_bqm.quadratic, target_edgelist) bqm_embedded = embed_bqm(factoring_bqm, embedding, target_adjacency, 3.0) diff --git a/test/unit_tests/braket/ocean_plugin/test_braket_dwave_sampler.py b/test/unit_tests/braket/ocean_plugin/test_braket_dwave_sampler.py index 8d4a553..77069b4 100644 --- a/test/unit_tests/braket/ocean_plugin/test_braket_dwave_sampler.py +++ b/test/unit_tests/braket/ocean_plugin/test_braket_dwave_sampler.py @@ -50,6 +50,7 @@ def device_parameters_2(): return {BraketSolverMetadata.DWAVE["device_parameters_key_name"]: {}} +# Removed s3_destination_folder fixture parameter. @pytest.fixture @patch("braket.ocean_plugin.braket_sampler.AwsDevice") def braket_dwave_sampler( @@ -62,9 +63,7 @@ def braket_dwave_sampler( @patch("braket.ocean_plugin.braket_sampler.AwsDevice") -@patch("braket.ocean_plugin.braket_dwave_sampler.AwsDevice") def test_default_device_arn( - dwave_sampler_mock_qpu, sampler_mock_qpu, braket_sampler_properties, s3_destination_folder, @@ -73,7 +72,7 @@ def test_default_device_arn( ): mock_device = Mock() mock_device.arn = dwave_arn - dwave_sampler_mock_qpu.get_devices.return_value = [mock_device] + sampler_mock_qpu.get_devices.return_value = [mock_device] sampler_mock_qpu.return_value.properties = braket_sampler_properties sampler = BraketDWaveSampler(s3_destination_folder, None, Mock(), logger) assert isinstance(sampler, BraketSampler) @@ -81,11 +80,11 @@ def test_default_device_arn( @pytest.mark.xfail(raises=RuntimeError) -@patch("braket.ocean_plugin.braket_dwave_sampler.AwsDevice") -def test_default_device_arn_error(dwave_sampler_mock_qpu, s3_destination_folder, logger, dwave_arn): +@patch("braket.ocean_plugin.braket_sampler.AwsDevice") +def test_default_device_arn_error(sampler_mock_qpu, s3_destination_folder, logger, dwave_arn): mock_device = Mock() mock_device.arn = dwave_arn - dwave_sampler_mock_qpu.get_devices.return_value = [] + sampler_mock_qpu.get_devices.return_value = [] BraketDWaveSampler(s3_destination_folder, None, Mock(), logger) diff --git a/test/unit_tests/braket/ocean_plugin/test_braket_sampler.py b/test/unit_tests/braket/ocean_plugin/test_braket_sampler.py index d3ae196..b8ca129 100644 --- a/test/unit_tests/braket/ocean_plugin/test_braket_sampler.py +++ b/test/unit_tests/braket/ocean_plugin/test_braket_sampler.py @@ -98,6 +98,23 @@ def test_nodelist(braket_sampler): assert braket_sampler.nodelist == (0, 1, 2) +@patch("braket.ocean_plugin.braket_sampler.AwsDevice") +def test_default_device_arn( + sampler_mock_qpu, + braket_sampler_properties, + s3_destination_folder, + logger, + dwave_arn, +): + mock_device = Mock() + mock_device.arn = dwave_arn + sampler_mock_qpu.get_devices.return_value = [mock_device] + sampler_mock_qpu.return_value.properties = braket_sampler_properties + sampler = BraketSampler(s3_destination_folder, None, Mock(), logger) + assert isinstance(sampler, BraketSampler) + assert sampler._device_arn == dwave_arn + + @pytest.mark.xfail(raises=BinaryQuadraticModelStructureError) @pytest.mark.parametrize( "h, J", [({0: -1, 500: 1}, {}), ({0: -1, 1: 1}, {(0, 1): 3}), ({0: -1, 2: 1}, {(3, 500): 3})]