Skip to content

Commit

Permalink
Fix dtype checking and sigma broadcasting in ParticleBeam
Browse files Browse the repository at this point in the history
  • Loading branch information
jank324 committed Nov 20, 2024
1 parent 2cbe869 commit 8b7e654
Showing 1 changed file with 71 additions and 108 deletions.
179 changes: 71 additions & 108 deletions cheetah/particles/particle_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,7 @@
from torch.distributions import MultivariateNormal

from cheetah.particles.beam import Beam
from cheetah.utils import (
are_all_the_same_device,
are_all_the_same_dtype,
elementwise_linspace,
extract_argument_shape,
verify_device_and_dtype,
)
from cheetah.utils import elementwise_linspace, verify_device_and_dtype

speed_of_light = torch.tensor(constants.speed_of_light) # In m/s
electron_mass = torch.tensor(constants.electron_mass) # In kg
Expand Down Expand Up @@ -65,7 +59,7 @@ def __init__(
@classmethod
def from_parameters(
cls,
num_particles: Optional[int] = None,
num_particles: int = 100_000,
mu_x: Optional[torch.Tensor] = None,
mu_y: Optional[torch.Tensor] = None,
mu_px: Optional[torch.Tensor] = None,
Expand Down Expand Up @@ -110,10 +104,9 @@ def from_parameters(
:param device: Device to move the beam's particle array to. If set to `"auto"` a
CUDA GPU is selected if available. The CPU is used otherwise.
"""
# Figure out if arguments were passed
not_nones = [
argument
for argument in [
# Extract device and dtype from given arguments
device, dtype = verify_device_and_dtype(
[
mu_x,
mu_px,
mu_y,
Expand All @@ -129,19 +122,13 @@ def from_parameters(
cor_tau,
energy,
total_charge,
]
if argument is not None
]

# Extract device and dtype from given arguments
device = device if device is not None else are_all_the_same_device(not_nones)
dtype = dtype if dtype is not None else are_all_the_same_dtype(not_nones)
],
device,
dtype,
)
factory_kwargs = {"device": device, "dtype": dtype}

# Set default values without function call in function signature
num_particles = (
num_particles if num_particles is not None else torch.tensor(100_000)
)
mu_x = mu_x if mu_x is not None else torch.tensor(0.0, **factory_kwargs)
mu_px = mu_px if mu_px is not None else torch.tensor(0.0, **factory_kwargs)
mu_y = mu_y if mu_y is not None else torch.tensor(0.0, **factory_kwargs)
Expand Down Expand Up @@ -222,15 +209,18 @@ def from_parameters(
cov[..., 5, 4] = cor_tau
cov[..., 5, 5] = sigma_p**2

particles = torch.ones((*mean.shape[:-1], num_particles, 7), **factory_kwargs)
vector_shape = torch.broadcast_shapes(mean.shape[:-1], cov.shape[:-2])
mean = mean.expand(*vector_shape, 6)
cov = cov.expand(*vector_shape, 6, 6)
particles = torch.ones((*vector_shape, num_particles, 7), **factory_kwargs)
distributions = [
MultivariateNormal(sample_mean, covariance_matrix=sample_cov)
for sample_mean, sample_cov in zip(mean.view(-1, 6), cov.view(-1, 6, 6))
]
particles[..., :6] = torch.stack(
[distribution.sample((num_particles,)) for distribution in distributions],
dim=0,
).view(*particles.shape[:-2], num_particles, 6)
).view(*vector_shape, num_particles, 6)

return cls(
particles,
Expand All @@ -243,7 +233,7 @@ def from_parameters(
@classmethod
def from_twiss(
cls,
num_particles: Optional[int] = None,
num_particles: int = 100_000,
beta_x: Optional[torch.Tensor] = None,
alpha_x: Optional[torch.Tensor] = None,
emittance_x: Optional[torch.Tensor] = None,
Expand All @@ -258,10 +248,9 @@ def from_twiss(
device=None,
dtype=None,
) -> "ParticleBeam":
# Figure out if arguments were passed
not_nones = [
argument
for argument in [
# Extract device and dtype from given arguments
device, dtype = verify_device_and_dtype(
[
beta_x,
alpha_x,
emittance_x,
Expand All @@ -273,60 +262,45 @@ def from_twiss(
sigma_p,
cor_tau,
total_charge,
]
if argument is not None
]

# Extract shape, device and dtype from given arguments
shape = extract_argument_shape(not_nones)
device = device if device is not None else are_all_the_same_device(not_nones)
dtype = dtype if dtype is not None else are_all_the_same_dtype(not_nones)
],
device,
dtype,
)
factory_kwargs = {"device": device, "dtype": dtype}

# Set default values without function call in function signature
num_particles = num_particles if num_particles is not None else 1_000_000
beta_x = (
beta_x if beta_x is not None else torch.full(shape, 0.0, **factory_kwargs)
)
beta_x = beta_x if beta_x is not None else torch.tensor(0.0, **factory_kwargs)
alpha_x = (
alpha_x if alpha_x is not None else torch.full(shape, 0.0, **factory_kwargs)
alpha_x if alpha_x is not None else torch.tensor(0.0, **factory_kwargs)
)
emittance_x = (
emittance_x
if emittance_x is not None
else torch.full(shape, 7.1971891e-13, **factory_kwargs)
)
beta_y = (
beta_y if beta_y is not None else torch.full(shape, 0.0, **factory_kwargs)
else torch.tensor(7.1971891e-13, **factory_kwargs)
)
beta_y = beta_y if beta_y is not None else torch.tensor(0.0, **factory_kwargs)
alpha_y = (
alpha_y if alpha_y is not None else torch.full(shape, 0.0, **factory_kwargs)
alpha_y if alpha_y is not None else torch.tensor(0.0, **factory_kwargs)
)
emittance_y = (
emittance_y
if emittance_y is not None
else torch.full(shape, 7.1971891e-13, **factory_kwargs)
)
energy = (
energy if energy is not None else torch.full(shape, 1e8, **factory_kwargs)
else torch.tensor(7.1971891e-13, **factory_kwargs)
)
energy = energy if energy is not None else torch.tensor(1e8, **factory_kwargs)
sigma_tau = (
sigma_tau
if sigma_tau is not None
else torch.full(shape, 1e-6, **factory_kwargs)
sigma_tau if sigma_tau is not None else torch.tensor(1e-6, **factory_kwargs)
)
sigma_p = (
sigma_p
if sigma_p is not None
else torch.full(shape, 1e-6, **factory_kwargs)
sigma_p if sigma_p is not None else torch.tensor(1e-6, **factory_kwargs)
)
cor_tau = (
cor_tau if cor_tau is not None else torch.full(shape, 0.0, **factory_kwargs)
cor_tau if cor_tau is not None else torch.tensor(0.0, **factory_kwargs)
)
total_charge = (
total_charge
if total_charge is not None
else torch.full(shape, 0.0, **factory_kwargs)
else torch.tensor(0.0, **factory_kwargs)
)

sigma_x = torch.sqrt(beta_x * emittance_x)
Expand All @@ -338,10 +312,10 @@ def from_twiss(

return cls.from_parameters(
num_particles=num_particles,
mu_x=torch.full(shape, 0.0, **factory_kwargs),
mu_px=torch.full(shape, 0.0, **factory_kwargs),
mu_y=torch.full(shape, 0.0, **factory_kwargs),
mu_py=torch.full(shape, 0.0, **factory_kwargs),
mu_x=torch.tensor(0.0, **factory_kwargs),
mu_px=torch.tensor(0.0, **factory_kwargs),
mu_y=torch.tensor(0.0, **factory_kwargs),
mu_py=torch.tensor(0.0, **factory_kwargs),
sigma_x=sigma_x,
sigma_px=sigma_px,
sigma_y=sigma_y,
Expand All @@ -360,7 +334,7 @@ def from_twiss(
@classmethod
def uniform_3d_ellipsoid(
cls,
num_particles: Optional[int] = None,
num_particles: int = 100_000,
radius_x: Optional[torch.Tensor] = None,
radius_y: Optional[torch.Tensor] = None,
radius_tau: Optional[torch.Tensor] = None,
Expand Down Expand Up @@ -399,11 +373,9 @@ def uniform_3d_ellipsoid(
:return: ParticleBeam with uniformly distributed particles inside an ellipsoid.
"""

# Figure out if arguments were passed
not_nones = [
argument
for argument in [
# Extract device and dtype from given arguments
device, dtype = verify_device_and_dtype(
[
radius_x,
radius_y,
radius_tau,
Expand All @@ -412,42 +384,34 @@ def uniform_3d_ellipsoid(
sigma_p,
energy,
total_charge,
]
if argument is not None
]

# Extract shape, device and dtype from given arguments
shape = extract_argument_shape(not_nones)
device = device if device is not None else are_all_the_same_device(not_nones)
dtype = dtype if dtype is not None else are_all_the_same_dtype(not_nones)
],
device,
dtype,
)
factory_kwargs = {"device": device, "dtype": dtype}

# Expand to vectorised version for beam creation
vector_shape = shape if len(shape) > 0 else torch.Size([1])

# Set default values without function call in function signature
# NOTE that this does not need to be done for values that are passed to the
# Gaussian beam generation.
num_particles = (
num_particles if num_particles is not None else torch.tensor(1_000_000)
)
radius_x = (
radius_x.expand(vector_shape)
if radius_x is not None
else torch.full(vector_shape, 1e-3, **factory_kwargs)
radius_x if radius_x is not None else torch.tensor(1e-3, **factory_kwargs)
)
radius_y = (
radius_y.expand(vector_shape)
if radius_y is not None
else torch.full(vector_shape, 1e-3, **factory_kwargs)
radius_y if radius_y is not None else torch.tensor(1e-3, **factory_kwargs)
)
radius_tau = (
radius_tau.expand(vector_shape)
radius_tau
if radius_tau is not None
else torch.full(vector_shape, 1e-3, **factory_kwargs)
else torch.tensor(1e-3, **factory_kwargs)
)

# Generate x, y and ss within the ellipsoid
# Generate x, y and tau within the ellipsoid
# Broadcasting with (1,) is a hack to make the loop work. Interestingly it
# this does not break the assigments into the non-vectorised particle tensor of
# the beam object.
vector_shape = torch.broadcast_shapes(
radius_x.shape, radius_y.shape, radius_tau.shape, (1,)
)
flattened_x = torch.empty(
*vector_shape, num_particles, **factory_kwargs
).flatten(end_dim=-2)
Expand Down Expand Up @@ -484,10 +448,13 @@ def uniform_3d_ellipsoid(
# Generate an uncorrelated Gaussian beam
beam = cls.from_parameters(
num_particles=num_particles,
mu_px=torch.full(shape, 0.0, **factory_kwargs),
mu_py=torch.full(shape, 0.0, **factory_kwargs),
mu_px=torch.tensor(0.0, **factory_kwargs),
mu_py=torch.tensor(0.0, **factory_kwargs),
sigma_x=radius_x, # Only a placeholder, will be overwritten
sigma_px=sigma_px,
sigma_y=radius_y, # Only a placeholder, will be overwritten
sigma_py=sigma_py,
sigma_tau=radius_tau, # Only a placeholder, will be overwritten
sigma_p=sigma_p,
energy=energy,
total_charge=total_charge,
Expand All @@ -496,9 +463,9 @@ def uniform_3d_ellipsoid(
)

# Replace the spatial coordinates with the generated ones
beam.x = flattened_x.view(*shape, num_particles)
beam.y = flattened_y.view(*shape, num_particles)
beam.tau = flattened_tau.view(*shape, num_particles)
beam.x = flattened_x.view(*vector_shape, num_particles)
beam.y = flattened_y.view(*vector_shape, num_particles)
beam.tau = flattened_tau.view(*vector_shape, num_particles)

return beam

Expand Down Expand Up @@ -542,10 +509,9 @@ def make_linspaced(
:param device: Device to move the beam's particle array to. If set to `"auto"` a
CUDA GPU is selected if available. The CPU is used otherwise.
"""
# Figure out if arguments were passed
not_nones = [
argument
for argument in [
# Extract device and dtype from given arguments
device, dtype = verify_device_and_dtype(
[
mu_x,
mu_px,
mu_y,
Expand All @@ -558,13 +524,10 @@ def make_linspaced(
sigma_p,
energy,
total_charge,
]
if argument is not None
]

# Extract device and dtype from given arguments
device = device if device is not None else are_all_the_same_device(not_nones)
dtype = dtype if dtype is not None else are_all_the_same_dtype(not_nones)
],
device,
dtype,
)
factory_kwargs = {"device": device, "dtype": dtype}

# Set default values without function call in function signature
Expand Down

0 comments on commit 8b7e654

Please sign in to comment.