Skip to content

Commit

Permalink
Combine verify_device_and_dtype arguments and clean up naming
Browse files Browse the repository at this point in the history
  • Loading branch information
jank324 committed Nov 20, 2024
1 parent f58dcce commit 38162c8
Show file tree
Hide file tree
Showing 15 changed files with 70 additions and 68 deletions.
2 changes: 1 addition & 1 deletion cheetah/accelerator/aperture.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(
device=None,
dtype=None,
) -> None:
device, dtype = verify_device_and_dtype([], [x_max, y_max], device, dtype)
device, dtype = verify_device_and_dtype([x_max, y_max], device, dtype)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(name=name)

Expand Down
2 changes: 1 addition & 1 deletion cheetah/accelerator/cavity.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(
dtype=None,
) -> None:
device, dtype = verify_device_and_dtype(
[length], [voltage, phase, frequency], device, dtype
[length, voltage, phase, frequency], device, dtype
)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(name=name)
Expand Down
2 changes: 1 addition & 1 deletion cheetah/accelerator/custom_transfer_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(
device=None,
dtype=None,
) -> None:
device, dtype = verify_device_and_dtype([transfer_map], [length], device, dtype)
device, dtype = verify_device_and_dtype([transfer_map, length], device, dtype)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(name=name)

Expand Down
2 changes: 1 addition & 1 deletion cheetah/accelerator/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ def __init__(
dtype=None,
):
device, dtype = verify_device_and_dtype(
[length],
[
length,
angle,
k1,
e1,
Expand Down
2 changes: 1 addition & 1 deletion cheetah/accelerator/horizontal_corrector.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(
device=None,
dtype=None,
) -> None:
device, dtype = verify_device_and_dtype([length], [angle], device, dtype)
device, dtype = verify_device_and_dtype([length, angle], device, dtype)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(name=name)

Expand Down
2 changes: 1 addition & 1 deletion cheetah/accelerator/quadrupole.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(
dtype=None,
) -> None:
device, dtype = verify_device_and_dtype(
[length], [k1, misalignment, tilt], device, dtype
[length, k1, misalignment, tilt], device, dtype
)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(name=name)
Expand Down
6 changes: 1 addition & 5 deletions cheetah/accelerator/screen.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,7 @@ def __init__(
dtype=None,
) -> None:
device, dtype = verify_device_and_dtype(
[], # No required tensor arguments
# Excludes resolution and binning, since those are integer valued, not float
[pixel_size, misalignment, kde_bandwidth],
device,
dtype,
[pixel_size, misalignment, kde_bandwidth], device, dtype
)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(name=name)
Expand Down
2 changes: 1 addition & 1 deletion cheetah/accelerator/solenoid.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(
dtype=None,
) -> None:
device, dtype = verify_device_and_dtype(
[length], [k, misalignment], device, dtype
[length, k, misalignment], device, dtype
)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(name=name)
Expand Down
7 changes: 1 addition & 6 deletions cheetah/accelerator/space_charge_kick.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,7 @@ def __init__(
device=None,
dtype=None,
) -> None:
device, dtype = verify_device_and_dtype(
[effect_length],
[], # TODO: Add grid_extend_{x,y,tau}, needs torch.Tensor default
device,
dtype,
)
device, dtype = verify_device_and_dtype([effect_length], device, dtype)
self.factory_kwargs = {"device": device, "dtype": dtype}

super().__init__(name=name)
Expand Down
2 changes: 1 addition & 1 deletion cheetah/accelerator/transverse_deflecting_cavity.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(
dtype=None,
) -> None:
device, dtype = verify_device_and_dtype(
[length], [voltage, phase, frequency, misalignment, tilt], device, dtype
[length, voltage, phase, frequency, misalignment, tilt], device, dtype
)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(name=name)
Expand Down
2 changes: 1 addition & 1 deletion cheetah/accelerator/vertical_corrector.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(
device=None,
dtype=None,
) -> None:
device, dtype = verify_device_and_dtype([length], [angle], device, dtype)
device, dtype = verify_device_and_dtype([length, angle], device, dtype)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(name=name)

Expand Down
14 changes: 7 additions & 7 deletions cheetah/particles/parameter_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from cheetah.particles.beam import Beam
from cheetah.particles.particle_beam import ParticleBeam
from cheetah.utils import (
extract_argument_device,
extract_argument_dtype,
are_all_the_same_device,
are_all_the_same_dtype,
extract_argument_shape,
verify_device_and_dtype,
)
Expand Down Expand Up @@ -35,7 +35,7 @@ def __init__(
dtype=None,
) -> None:
device, dtype = verify_device_and_dtype(
[mu, cov, energy], [total_charge], device, dtype
[mu, cov, energy, total_charge], device, dtype
)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
Expand Down Expand Up @@ -97,8 +97,8 @@ def from_parameters(
]

# Extract device and dtype from given arguments
device = device if device is not None else extract_argument_device(not_nones)
dtype = dtype if dtype is not None else extract_argument_dtype(not_nones)
device = device if device is not None else are_all_the_same_device(not_nones)
dtype = dtype if dtype is not None else are_all_the_same_dtype(not_nones)
factory_kwargs = {"device": device, "dtype": dtype}

# Set default values without function call in function signature
Expand Down Expand Up @@ -232,8 +232,8 @@ def from_twiss(

# Extract shape, device and dtype from given arguments
shape = extract_argument_shape(not_nones)
device = device if device is not None else extract_argument_device(not_nones)
dtype = dtype if dtype is not None else extract_argument_dtype(not_nones)
device = device if device is not None else are_all_the_same_device(not_nones)
dtype = dtype if dtype is not None else are_all_the_same_dtype(not_nones)
factory_kwargs = {"device": device, "dtype": dtype}

# Set default values without function call in function signature
Expand Down
22 changes: 11 additions & 11 deletions cheetah/particles/particle_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@

from cheetah.particles.beam import Beam
from cheetah.utils import (
are_all_the_same_device,
are_all_the_same_dtype,
elementwise_linspace,
extract_argument_device,
extract_argument_dtype,
extract_argument_shape,
verify_device_and_dtype,
)
Expand Down Expand Up @@ -42,7 +42,7 @@ def __init__(
dtype=None,
) -> None:
device, dtype = verify_device_and_dtype(
[particles, energy], [particle_charges], device, dtype
[particles, energy, particle_charges], device, dtype
)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
Expand Down Expand Up @@ -134,8 +134,8 @@ def from_parameters(
]

# Extract device and dtype from given arguments
device = device if device is not None else extract_argument_device(not_nones)
dtype = dtype if dtype is not None else extract_argument_dtype(not_nones)
device = device if device is not None else are_all_the_same_device(not_nones)
dtype = dtype if dtype is not None else are_all_the_same_dtype(not_nones)
factory_kwargs = {"device": device, "dtype": dtype}

# Set default values without function call in function signature
Expand Down Expand Up @@ -279,8 +279,8 @@ def from_twiss(

# Extract shape, device and dtype from given arguments
shape = extract_argument_shape(not_nones)
device = device if device is not None else extract_argument_device(not_nones)
dtype = dtype if dtype is not None else extract_argument_dtype(not_nones)
device = device if device is not None else are_all_the_same_device(not_nones)
dtype = dtype if dtype is not None else are_all_the_same_dtype(not_nones)
factory_kwargs = {"device": device, "dtype": dtype}

# Set default values without function call in function signature
Expand Down Expand Up @@ -418,8 +418,8 @@ def uniform_3d_ellipsoid(

# Extract shape, device and dtype from given arguments
shape = extract_argument_shape(not_nones)
device = device if device is not None else extract_argument_device(not_nones)
dtype = dtype if dtype is not None else extract_argument_dtype(not_nones)
device = device if device is not None else are_all_the_same_device(not_nones)
dtype = dtype if dtype is not None else are_all_the_same_dtype(not_nones)
factory_kwargs = {"device": device, "dtype": dtype}

# Expand to vectorised version for beam creation
Expand Down Expand Up @@ -563,8 +563,8 @@ def make_linspaced(
]

# Extract device and dtype from given arguments
device = device if device is not None else extract_argument_device(not_nones)
dtype = dtype if dtype is not None else extract_argument_dtype(not_nones)
device = device if device is not None else are_all_the_same_device(not_nones)
dtype = dtype if dtype is not None else are_all_the_same_dtype(not_nones)
factory_kwargs = {"device": device, "dtype": dtype}

# Set default values without function call in function signature
Expand Down
4 changes: 2 additions & 2 deletions cheetah/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from . import bmadx # noqa: F401
from .argument_verification import ( # noqa: F401
extract_argument_device,
extract_argument_dtype,
are_all_the_same_device,
are_all_the_same_dtype,
extract_argument_shape,
verify_device_and_dtype,
)
Expand Down
67 changes: 39 additions & 28 deletions cheetah/utils/argument_verification.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,54 +3,65 @@
import torch


def extract_argument_device(arguments: list[torch.Tensor]) -> torch.device:
def are_all_the_same_device(tensors: list[torch.Tensor]) -> torch.device:
"""
Determines whether all arguments are on the same device and returns the default
pytorch device if no argumente are passed.
Determines whether all arguments are on the same device and, if so, returns that
device. If no arguments are passed, global default PyTorch device is returned.
"""
if len(arguments) > 1:
if len(tensors) > 1:
assert all(
argument.device == arguments[0].device for argument in arguments
), "Arguments must be on the same device."
argument.device == tensors[0].device for argument in tensors
), "All tensors must be on the same device."

return arguments[0].device if len(arguments) > 0 else torch.get_default_device()
return tensors[0].device if len(tensors) > 0 else torch.get_default_device()


def extract_argument_dtype(arguments: list[torch.Tensor]) -> torch.dtype:
def are_all_the_same_dtype(tensors: list[torch.Tensor]) -> torch.dtype:
"""
Determines whether all arguments have the same dtype and returns the default
pytorch dtype if no argumente are passed.
Determines whether all arguments have the same dtype and, if so, returns that dtype.
If no arguments are passed, global default PyTorch dtype is returned.
"""
if len(arguments) > 1:
if len(tensors) > 1:
assert all(
argument.dtype == arguments[0].dtype for argument in arguments
), "Arguments must have the same dtype."
argument.dtype == tensors[0].dtype for argument in tensors
), "All arguments must have the same dtype."

return arguments[0].dtype if len(arguments) > 0 else torch.get_default_dtype()
return tensors[0].dtype if len(tensors) > 0 else torch.get_default_dtype()


def extract_argument_shape(arguments: list[torch.Tensor]) -> torch.Size:
def extract_argument_shape(tensors: list[torch.Tensor]) -> torch.Size:
"""Determines whether all arguments have the same shape."""
if len(arguments) > 1:
if len(tensors) > 1:
assert all(
argument.shape == arguments[0].shape for argument in arguments
argument.shape == tensors[0].shape for argument in tensors
), "Arguments must have the same shape."

return arguments[0].shape if len(arguments) > 0 else torch.Size([1])
return tensors[0].shape if len(tensors) > 0 else torch.Size([1])


def verify_device_and_dtype(
required: list[torch.Tensor],
optionals: list[Optional[torch.Tensor]],
device: torch.device,
dtype: torch.dtype,
tensors: list[Optional[torch.Tensor]],
desired_device: Optional[torch.device],
desired_dtype: Optional[torch.dtype],
) -> tuple[torch.device, torch.dtype]:
"""
Verifies that all required & given optional arguments have the same device and
dtype if no defaults are provided.
Verifies that passed tensors (if they are tensors and not `None`) have the same
device and dtype and that that device and dtype are the same as the desired device
and dtype if they are requested.
If all verifications pass, this function returns the device and dtype shared by all
tensors.
"""
not_nones = required + [argument for argument in optionals if argument is not None]
not_nones = [tensor for tensor in tensors if tensor is not None]

device = device if device is not None else extract_argument_device(not_nones)
dtype = dtype if dtype is not None else extract_argument_dtype(not_nones)
return (device, dtype)
chosen_device = (
desired_device
if desired_device is not None
else are_all_the_same_device(not_nones)
)
chosen_dtype = (
desired_dtype
if desired_dtype is not None
else are_all_the_same_dtype(not_nones)
)
return (chosen_device, chosen_dtype)

0 comments on commit 38162c8

Please sign in to comment.