diff --git a/pennylane_qiskit/qiskit_device.py b/pennylane_qiskit/qiskit_device.py index 1dea51405..10057fabc 100644 --- a/pennylane_qiskit/qiskit_device.py +++ b/pennylane_qiskit/qiskit_device.py @@ -114,8 +114,8 @@ class QiskitDevice(QubitDevice, abc.ABC): _eigs = {} - def __init__(self, wires, provider, backend, shots=1024, **kwargs): - super().__init__(wires=wires, shots=shots) + def __init__(self, wires, provider, backend, shots=1024, cache=0, **kwargs): + super().__init__(wires=wires, shots=shots, cache=cache) # Keep track if the user specified analytic to be True if shots is None and backend not in self._state_backends: @@ -379,6 +379,20 @@ def analytic_probability(self, wires=None): def batch_execute(self, circuits): + # Check if any of the circuit results was cached + if self._cache: + non_cached_circuits = [] + cached_results = {} + for idx, circuit in enumerate(circuits): + circuit_hash = circuit.graph.hash + if circuit_hash in self._cache_execute: + cached_results[idx] = self._cache_execute[circuit_hash] + else: + non_cached_circuits.append(circuit) + + # Only keep the non-cached circuits + circuits = non_cached_circuits + compiled_circuits = [] # Compile each circuit object @@ -393,9 +407,11 @@ def batch_execute(self, circuits): compiled_circ.name = f"circ{len(compiled_circuits)}" compiled_circuits.append(compiled_circ) - # Send the batch of circuit objects using backend.run - self._current_job = self.backend.run(compiled_circuits, shots=self.shots, **self.run_args) - result = self._current_job.result() + if compiled_circuits: + + # Send the batch of circuit objects using backend.run + self._current_job = self.backend.run(compiled_circuits, shots=self.shots, **self.run_args) + result = self._current_job.result() # Compute statistics using the state and/or samples results = [] @@ -412,6 +428,21 @@ def batch_execute(self, circuits): res = np.asarray(res) results.append(res) + if self._cache: + + # Store the computed results if applicable + for res, circ in zip(results, circuits): + circuit_hash = circ.graph.hash + if circuit_hash not in self._cache_execute: + self._cache_execute[circuit_hash] = res + + if len(self._cache_execute) > self._cache: + self._cache_execute.popitem(last=False) + + # Insert the cached results obtained at the start + for k, v in cached_results.items(): + results.insert(k, v) + if self.tracker.active: self.tracker.update(batches=1, batch_len=len(circuits)) self.tracker.record()