diff --git a/tests/test_plotting.py b/tests/test_plotting.py index c1b3bfe7..dd94b574 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -103,3 +103,26 @@ def test_reference_particle_plot_vectorized_2d(): # Run the plotting to see if it raises an exception segment.plot_overview(incoming=incoming, resolution=0.1, vector_idx=(0, 2)) + + +def test_plotting_with_nonleave_tensors(): + """ + Test that the plotting routines can handle elements with non-leave tensors. + """ + segment = cheetah.Segment( + elements=[ + cheetah.Drift(length=torch.tensor(1.0, requires_grad=True)), + cheetah.BPM(is_active=True), + ] + ) + + incoming = cheetah.ParticleBeam.from_astra( + "tests/resources/ACHIP_EA1_2021.1351.001" + ) + + # Prepopulate the segment + segment.track(incoming) + + # Test that plotting does not raise an exception + segment.plot_overview(incoming=incoming) + segment.plot_twiss(incoming=incoming)