diff --git a/tests/test_particle_beam.py b/tests/test_particle_beam.py index 49003d8d..3cde97af 100644 --- a/tests/test_particle_beam.py +++ b/tests/test_particle_beam.py @@ -161,31 +161,76 @@ def test_only_sigma_vectorized(): assert beam.particles.shape == (2, 10_000, 7) -def test_indexing(): - # test batching with beamline parameters +def test_indexing_with_vectorized_beamline(): + """ + Test that indexing into a vectorised outgoing beam works when the vectorisation + originates in the beamline. + """ quadrupole = cheetah.Quadrupole( length=torch.tensor(0.2).unsqueeze(0), k1=torch.rand((5, 2)) ) - incoming = cheetah.ParticleBeam.from_parameters(sigma_x=torch.tensor(1e-5)) + incoming = cheetah.ParticleBeam.from_parameters( + num_particles=1_000, sigma_x=torch.tensor(1e-5) + ) outgoing = quadrupole.track(incoming) + sub_beam = outgoing[:3] + + assert sub_beam.particles.shape == torch.Size([3, 2, 1_000, 7]) + assert sub_beam.energy.shape == torch.Size([3, 2]) + assert sub_beam.particle_charges.shape == torch.Size([3, 2, 1_000]) + assert sub_beam.surival_probabilities.shape == torch.Size([3, 2, 1_000]) + + assert torch.all(sub_beam.particles == outgoing.particles[:3]) + assert torch.all(sub_beam.energy == outgoing.energy[:3]) + assert torch.all(sub_beam.particle_charges == outgoing.particle_charges[:3]) + assert torch.all( + sub_beam.surival_probabilities == outgoing.surival_probabilities[:3] + ) - sub_beam = outgoing[:2] - assert sub_beam.beta_x.shape == torch.Size([2, 2]) - assert torch.equal(sub_beam.particle_charges, incoming.particle_charges) - assert torch.equal(sub_beam.energy, incoming.energy) - # test batching with energy - incoming = cheetah.ParticleBeam.from_parameters(sigma_x=torch.tensor(1e-5)) - incoming.energy.data = torch.rand((5, 2)) +def test_indexing_with_vectorized_incoming_beam(): + """ + Test that indexing into a vectorised outgoing beam works when the vectorisation + originates in the incoming beam. + """ + quadrupole = cheetah.Quadrupole(length=torch.tensor(0.2), k1=torch.tensor(0.1)) + incoming = cheetah.ParticleBeam.from_parameters( + num_particles=1_000, sigma_x=torch.tensor(1e-5), energy=torch.rand((5, 2)) + ) outgoing = quadrupole.track(incoming) - sub_beam = outgoing[:2] + sub_beam = outgoing[:3] + + assert sub_beam.particles.shape == torch.Size([3, 2, 1_000, 7]) + assert sub_beam.energy.shape == torch.Size([3, 2]) + assert sub_beam.particle_charges.shape == torch.Size([3, 2, 1_000]) + assert sub_beam.surival_probabilities.shape == torch.Size([3, 2, 1_000]) + + assert torch.all(sub_beam.particles == outgoing.particles[:3]) + assert torch.all(sub_beam.energy == outgoing.energy[:3]) + assert torch.all(sub_beam.particle_charges == outgoing.particle_charges[:3]) + assert torch.all( + sub_beam.surival_probabilities == outgoing.surival_probabilities[:3] + ) - assert sub_beam.beta_x.shape == torch.Size([2, 2]) - assert torch.equal(sub_beam.particle_charges, incoming.particle_charges) - assert torch.equal(sub_beam.energy, incoming.energy[:2]) + +def test_indexing_fails_for_inconsitent_vectorization(): + """ + Test that indexing into a vectorised beam fails when the vectorisation is + inconsistent, i.e. not broadcastable. + """ + beam = cheetah.ParticleBeam.from_parameters( + sigma_x=torch.rand((5, 2)), energy=torch.rand((4, 2)) + ) with pytest.raises(RuntimeError): - outgoing.energy.data = torch.rand((4, 2)) - outgoing[:2] + _ = beam[:3] + + +def test_indexing_fails_for_invalid_index(): + """Test that indexing into a vectorised beam fails when the index is invalid.""" + beam = cheetah.ParticleBeam.from_parameters(energy=torch.rand((5, 2))) + + with pytest.raises(IndexError): + _ = beam[6]