Memory leaks in DynamicsBackend.run #358

Open
xyzdxf opened this issue Aug 15, 2024 · 2 comments
Labels
bug Something isn't working


xyzdxf commented Aug 15, 2024

Information

  • Qiskit Dynamics version: 0.5.1
  • Python version: 3.9.18
  • Operating system: Darwin

What is the current behavior?

When running jobs with DynamicsBackend, the memory usage keeps increasing.
[Image: mem]

Steps to reproduce the problem

Create a file named pulse_memory.py

# Suppress a parallelism warning that JAX raises due to something outside of Dynamics
import warnings
warnings.filterwarnings("ignore", message="os.fork")

# Configure JAX
import jax
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")

from qiskit import QuantumCircuit
from qiskit import pulse
from qiskit_ibm_provider import ibm_provider
from qiskit.qobj.utils import MeasLevel
from qiskit_dynamics import DynamicsBackend
import gc

# `profile` is injected by memory_profiler when the script is run via `mprof run --python`
@profile
def run(circuit, backend):
    result = backend.run(circuit, meas_level=MeasLevel.KERNELED)
    del result
    gc.collect()

if __name__ == '__main__':
    # initialize the backend
    provider = ibm_provider.IBMProvider()
    kyoto = provider.get_backend('ibm_kyoto')
    sim_backend = DynamicsBackend.from_backend(kyoto, subsystem_list=[0], array_library="jax", rotating_frame="auto")
    dt = sim_backend.dt
    solver_options = {"method": "jax_odeint", "atol": 1e-6, "rtol": 1e-8, "hmax": dt}
    sim_backend.options.solver_options = solver_options

    # circuit to run
    qc = QuantumCircuit(1,1)
    qc.h(0)
    qc.measure([0],[0])

    with pulse.build() as h_q0:
        pulse.play(
            pulse.library.Gaussian(duration=256, amp=0.2, sigma=50, name="custom"),
            pulse.DriveChannel(0)
        )
    qc.add_calibration("h", qubits=[0], schedule=h_q0)


    # Repeat the experiment
    for _ in range(500):
        run(qc, sim_backend)

Then run it under memory_profiler and plot the result:

mprof run --python pulse_memory.py
mprof plot
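
As a cross-check on the mprof plot, the growth can also be sampled in-process. This is only a minimal sketch using the standard library: the helper names `peak_rss_mib` and `sample_peak_rss` are hypothetical, and the `resource` module is available only on Unix-like systems. Note that `ru_maxrss` reports peak (not current) resident memory, so it can show growth but never a decrease.

```python
import resource
import sys


def peak_rss_mib():
    """Return this process's peak resident set size in MiB."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is in bytes on macOS (Darwin) and in kilobytes on Linux.
    divisor = 1024 * 1024 if sys.platform == "darwin" else 1024
    return peak / divisor


def sample_peak_rss(run_once, n_repeats, every=50):
    """Call run_once n_repeats times, recording peak RSS every `every` calls.

    Hypothetical helper: run_once is any zero-argument callable.
    """
    samples = []
    for i in range(1, n_repeats + 1):
        run_once()
        if i % every == 0:
            samples.append((i, peak_rss_mib()))
    return samples
```

With the script above, something like `sample_peak_rss(lambda: run(qc, sim_backend), 500)` would give one (iteration, MiB) pair every 50 runs, assuming the `@profile` decorator is removed when not running under memory_profiler.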

What is the expected behavior?

The memory usage should not keep increasing ...

Suggested solutions

xyzdxf added the bug label on Aug 15, 2024
DanPuzzuoli (Collaborator) commented

Hi @xyzdxf

Thanks for sharing this. My memory is vague, but I think something like this has come up before, and it may have had something to do with how JAX stores compiled functions.

I think calling jax.clear_caches() within the loop after each call could solve this issue. Obviously you don't want to do this if you're genuinely re-using compiled functions, but in this case some compilation is happening behind the scenes and the results aren't being re-used anyway.
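
A minimal sketch of that suggestion, under some assumptions: the helper name `repeat_with_cache_clearing` and its parameters are hypothetical, and `run_once` stands in for a call such as `run(qc, sim_backend)` from the reproduction script. The only JAX call used is `jax.clear_caches()`, which clears JAX's compilation and staging caches.

```python
import gc


def repeat_with_cache_clearing(run_once, n_repeats, clear_caches=None):
    """Call run_once n_repeats times, clearing JAX's caches after each call.

    Hypothetical helper: clear_caches defaults to jax.clear_caches.
    """
    if clear_caches is None:
        # Imported lazily so the default is only resolved when needed.
        import jax

        clear_caches = jax.clear_caches
    for _ in range(n_repeats):
        run_once()
        clear_caches()  # drop compiled functions cached by JAX
        gc.collect()  # then collect any Python objects that are now unreachable
```

In the reproduction script, the loop would become `repeat_with_cache_clearing(lambda: run(qc, sim_backend), 500)`. Clearing the caches forces recompilation on every iteration, so this trades speed for memory and is only sensible when the compiled functions are not being re-used.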

xyzdxf (Author) commented Aug 20, 2024

Hi @DanPuzzuoli

Thanks. After putting jax.clear_caches() within the loop after each call, the memory rise is reduced by a factor of ~2.
[Image: Figure_1]

I have checked the messages in Slack, and it appears that the issue persists. The memory leak comes from JAX, and I'm not sure how to fix it...
