Skip to content

Commit

Permalink
Fix and split tests; add invalid index test
Browse files Browse the repository at this point in the history
  • Loading branch information
jank324 committed Dec 3, 2024
1 parent 900f557 commit 7f277da
Showing 1 changed file with 61 additions and 16 deletions.
77 changes: 61 additions & 16 deletions tests/test_particle_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,31 +161,76 @@ def test_only_sigma_vectorized():
assert beam.particles.shape == (2, 10_000, 7)


def test_indexing():
# test batching with beamline parameters
def test_indexing_with_vectorized_beamline():
"""
Test that indexing into a vectorised outgoing beam works when the vectorisation
originates in the beamline.
"""
quadrupole = cheetah.Quadrupole(
length=torch.tensor(0.2).unsqueeze(0), k1=torch.rand((5, 2))
)
incoming = cheetah.ParticleBeam.from_parameters(sigma_x=torch.tensor(1e-5))
incoming = cheetah.ParticleBeam.from_parameters(
num_particles=1_000, sigma_x=torch.tensor(1e-5)
)

outgoing = quadrupole.track(incoming)
sub_beam = outgoing[:3]

assert sub_beam.particles.shape == torch.Size([3, 2, 1_000, 7])
assert sub_beam.energy.shape == torch.Size([3, 2])
assert sub_beam.particle_charges.shape == torch.Size([3, 2, 1_000])
assert sub_beam.surival_probabilities.shape == torch.Size([3, 2, 1_000])

assert torch.all(sub_beam.particles == outgoing.particles[:3])
assert torch.all(sub_beam.energy == outgoing.energy[:3])
assert torch.all(sub_beam.particle_charges == outgoing.particle_charges[:3])
assert torch.all(
sub_beam.surival_probabilities == outgoing.surival_probabilities[:3]
)

sub_beam = outgoing[:2]
assert sub_beam.beta_x.shape == torch.Size([2, 2])
assert torch.equal(sub_beam.particle_charges, incoming.particle_charges)
assert torch.equal(sub_beam.energy, incoming.energy)

# test batching with energy
incoming = cheetah.ParticleBeam.from_parameters(sigma_x=torch.tensor(1e-5))
incoming.energy.data = torch.rand((5, 2))
def test_indexing_with_vectorized_incoming_beam():
"""
Test that indexing into a vectorised outgoing beam works when the vectorisation
originates in the incoming beam.
"""
quadrupole = cheetah.Quadrupole(length=torch.tensor(0.2), k1=torch.tensor(0.1))
incoming = cheetah.ParticleBeam.from_parameters(
num_particles=1_000, sigma_x=torch.tensor(1e-5), energy=torch.rand((5, 2))
)

outgoing = quadrupole.track(incoming)
sub_beam = outgoing[:2]
sub_beam = outgoing[:3]

assert sub_beam.particles.shape == torch.Size([3, 2, 1_000, 7])
assert sub_beam.energy.shape == torch.Size([3, 2])
assert sub_beam.particle_charges.shape == torch.Size([3, 2, 1_000])
assert sub_beam.surival_probabilities.shape == torch.Size([3, 2, 1_000])

assert torch.all(sub_beam.particles == outgoing.particles[:3])
assert torch.all(sub_beam.energy == outgoing.energy[:3])
assert torch.all(sub_beam.particle_charges == outgoing.particle_charges[:3])
assert torch.all(
sub_beam.surival_probabilities == outgoing.surival_probabilities[:3]
)

assert sub_beam.beta_x.shape == torch.Size([2, 2])
assert torch.equal(sub_beam.particle_charges, incoming.particle_charges)
assert torch.equal(sub_beam.energy, incoming.energy[:2])

def test_indexing_fails_for_inconsitent_vectorization():
"""
Test that indexing into a vectorised beam fails when the vectorisation is
inconsistent, i.e. not broadcastable.
"""
beam = cheetah.ParticleBeam.from_parameters(
sigma_x=torch.rand((5, 2)), energy=torch.rand((4, 2))
)

with pytest.raises(RuntimeError):
outgoing.energy.data = torch.rand((4, 2))
outgoing[:2]
_ = beam[:3]


def test_indexing_fails_for_invalid_index():
"""Test that indexing into a vectorised beam fails when the index is invalid."""
beam = cheetah.ParticleBeam.from_parameters(energy=torch.rand((5, 2)))

with pytest.raises(IndexError):
_ = beam[6]

0 comments on commit 7f277da

Please sign in to comment.