Skip to content

Commit

Permalink
A little cleanup on OpenPMD conversion implementations
Browse files Browse the repository at this point in the history
  • Loading branch information
jank324 committed Jan 10, 2025
1 parent 77b7eec commit 4edc949
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 65 deletions.
2 changes: 1 addition & 1 deletion cheetah/accelerator/aperture.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class Aperture(Element):
Physical aperture.
NOTE: The aperture currently only affects beams of type `ParticleBeam` and only has
an effect when the aperture is active.
an effect when the aperture is active.
:param x_max: half size horizontal offset in [m].
:param y_max: half size vertical offset in [m].
Expand Down
2 changes: 1 addition & 1 deletion cheetah/accelerator/drift.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class Drift(Element):
"""
Drift section in a particle accelerator.
Note: the transfer map now uses the linear approximation.
NOTE: The transfer map now uses the linear approximation.
Including the R_56 = L / (beta**2 * gamma **2)
:param length: Length in meters.
Expand Down
4 changes: 2 additions & 2 deletions cheetah/accelerator/horizontal_corrector.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
class HorizontalCorrector(Element):
"""
Horizontal corrector magnet in a particle accelerator.
Note: This is modeled as a drift section with
a thin-kick in the horizontal plane.
NOTE: This is modeled as a drift section with a thin-kick in the horizontal plane.
:param length: Length in meters.
:param angle: Particle deflection angle in the horizontal plane in rad.
Expand Down
124 changes: 63 additions & 61 deletions cheetah/particles/particle_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from typing import List, Literal, Optional, Tuple, Union

import numpy as np
import pmd_beamphysics as openpmd
import torch
from matplotlib import pyplot as plt
from pmd_beamphysics import ParticleGroup
from scipy import constants
from scipy.constants import physical_constants
from scipy.ndimage import gaussian_filter
Expand Down Expand Up @@ -672,33 +672,32 @@ def from_astra(cls, path: str, device=None, dtype=torch.float32) -> "ParticleBea

@classmethod
def from_openpmd_file(
cls, path: str, energy: torch.Tensor, device=None, dtype=torch.float32
cls, path: str, energy: torch.Tensor, device=None, dtype=None
) -> "ParticleBeam":
"""Load an OpenPMD particle distribution file as a Cheetah Beam."""

particle_group = ParticleGroup(path)
"""Load an OpenPMD particle group HDF5 file as a Cheetah `ParticleBeam`."""
particle_group = openpmd.ParticleGroup(path)
return cls.from_openpmd_particlegroup(
particle_group, energy, device=device, dtype=dtype
)

@classmethod
def from_openpmd_particlegroup(
cls,
particle_group: ParticleGroup,
particle_group: openpmd.ParticleGroup,
energy: torch.Tensor,
device=None,
dtype=torch.float32,
dtype=None,
) -> "ParticleBeam":
"""Convert an OpenPMD ParticleGroup to a Cheetah Beam.
"""
Create a Cheetah `ParticleBeam` from an OpenPMD `ParticleGroup` object.
:param particle_group: OpenPMD ParticleGroup object.
:param particle_group: OpenPMD `ParticleGroup` object.
:param energy: Reference energy of the beam in eV.
:param device: Device to move the beam's particle array to. If set to `"auto"` a
CUDA GPU is selected if available. The CPU is used otherwise.
:param dtype: Data type of the generated particles.
"""

# Assume only electron now
# For now, assume an electron beam
p0c = torch.sqrt(energy**2 - electron_mass_eV**2)

x = torch.from_numpy(particle_group.x)
Expand All @@ -721,6 +720,59 @@ def from_openpmd_particlegroup(
dtype=dtype,
)

def save_as_openpmd_h5(self, path: str) -> None:
"""
Save the `ParticleBeam` as an OpenPMD particle group HDF5 file.
:param path: Path to the file where the beam should be saved.
"""
particle_group = self.to_openpmd_particlegroup()
particle_group.write(path)

def to_openpmd_particlegroup(self) -> openpmd.ParticleGroup:
"""
Convert the `ParticleBeam` to an OpenPMD `ParticleGroup` object.
NOTE: OpenPMD uses boolean particle status flags, i.e. alive or dead. Cheetah's
survival probabilities are converted to status flags by thresholding at 0.5.
NOTE: At the moment this method only supports non-batched particles
distributions.
:return: OpenPMD `ParticleGroup` object with the `ParticleBeam`'s particles.
"""
# For now only support non-batched particles
if len(self.particles.shape) != 2:
raise ValueError("Only non-batched particles are supported.")

n_particles = self.num_particles
weights = np.ones(n_particles)
px = self.px * self.p0c
py = self.py * self.p0c
p_total = torch.sqrt(self.energies**2 - electron_mass_eV**2)
pz = torch.sqrt(p_total**2 - px**2 - py**2)
t = self.tau / speed_of_light
weights = self.particle_charges
# To be discussed
status = self.survival_probabilities > 0.5

data = {
"x": self.x.numpy(),
"y": self.y.numpy(),
"z": self.tau.numpy(),
"px": px.numpy(),
"py": py.numpy(),
"pz": pz.numpy(),
"t": t.numpy(),
"weight": weights,
"status": status,
# TODO: Modify when support for other species was added
"species": "electron",
}
particle_group = openpmd.ParticleGroup(data=data)

return particle_group

def transformed_to(
self,
mu_x: Optional[torch.Tensor] = None,
Expand Down Expand Up @@ -1462,56 +1514,6 @@ def momenta(self) -> torch.Tensor:
"""Momenta of the individual particles."""
return torch.sqrt(self.energies**2 - electron_mass_eV**2)

def save_as_openpmd_h5(self, filename: str) -> None:
"""Save the ParticleBeam as an OpenPMD beamphysics HDF5 file.
:param filename: Path to the file where the beam should be saved.
"""
particle_group = self.to_openpmd_particlegroup()
particle_group.write(filename)

def to_openpmd_particlegroup(self) -> None:
"""
Convert the beam to an OpenPMD-beamphysics ParticleGroup object.
Note: Currently OpenPMD-beamphysics only supports boolean particle status
flags, i.e. alive or dead. The survival_probabilities will be converted to
status flags by thresholding at 0.5.
:return: OpenPMD-beamphysics ParticleGroup object with the beam's particles.
"""
# For now only support none-batched particles
if len(self.particles.shape) != 2:
raise ValueError("Only non-batched particles are supported.")

n_particles = self.num_particles
weights = np.ones(n_particles)
px = self.px * self.p0c
py = self.py * self.p0c
p_total = torch.sqrt(self.energies**2 - electron_mass_eV**2)
pz = torch.sqrt(p_total**2 - px**2 - py**2)
t = self.tau / speed_of_light
weights = self.particle_charges
# To be discussed
status = self.survival_probabilities > 0.5

data = {
"x": self.x.numpy(),
"y": self.y.numpy(),
"z": self.tau.numpy(),
"px": px.numpy(),
"py": py.numpy(),
"pz": pz.numpy(),
"t": t.numpy(),
"weight": weights,
"status": status,
# To be modified after adding support for other species
"species": "electron",
}
particle_group = ParticleGroup(data=data)

return particle_group

def clone(self) -> "ParticleBeam":
return ParticleBeam(
particles=self.particles.clone(),
Expand Down

0 comments on commit 4edc949

Please sign in to comment.