From 38162c8b42ee9eb2b69d0a2f1671c97eae7b9a51 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 20 Nov 2024 13:40:23 +0100 Subject: [PATCH] Combine `verify_device_and_dtype` arguments and clean up naming --- cheetah/accelerator/aperture.py | 2 +- cheetah/accelerator/cavity.py | 2 +- cheetah/accelerator/custom_transfer_map.py | 2 +- cheetah/accelerator/dipole.py | 2 +- cheetah/accelerator/horizontal_corrector.py | 2 +- cheetah/accelerator/quadrupole.py | 2 +- cheetah/accelerator/screen.py | 6 +- cheetah/accelerator/solenoid.py | 2 +- cheetah/accelerator/space_charge_kick.py | 7 +- .../transverse_deflecting_cavity.py | 2 +- cheetah/accelerator/vertical_corrector.py | 2 +- cheetah/particles/parameter_beam.py | 14 ++-- cheetah/particles/particle_beam.py | 22 +++--- cheetah/utils/__init__.py | 4 +- cheetah/utils/argument_verification.py | 67 +++++++++++-------- 15 files changed, 70 insertions(+), 68 deletions(-) diff --git a/cheetah/accelerator/aperture.py b/cheetah/accelerator/aperture.py index 7214e49b..6f135e22 100644 --- a/cheetah/accelerator/aperture.py +++ b/cheetah/accelerator/aperture.py @@ -33,7 +33,7 @@ def __init__( device=None, dtype=None, ) -> None: - device, dtype = verify_device_and_dtype([], [x_max, y_max], device, dtype) + device, dtype = verify_device_and_dtype([x_max, y_max], device, dtype) factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name) diff --git a/cheetah/accelerator/cavity.py b/cheetah/accelerator/cavity.py index e213a298..b6cb536a 100644 --- a/cheetah/accelerator/cavity.py +++ b/cheetah/accelerator/cavity.py @@ -43,7 +43,7 @@ def __init__( dtype=None, ) -> None: device, dtype = verify_device_and_dtype( - [length], [voltage, phase, frequency], device, dtype + [length, voltage, phase, frequency], device, dtype ) factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name) diff --git a/cheetah/accelerator/custom_transfer_map.py b/cheetah/accelerator/custom_transfer_map.py index abaeae3e..e8158a3a 100644 --- a/cheetah/accelerator/custom_transfer_map.py +++ b/cheetah/accelerator/custom_transfer_map.py @@ -25,7 +25,7 @@ def __init__( device=None, dtype=None, ) -> None: - device, dtype = verify_device_and_dtype([transfer_map], [length], device, dtype) + device, dtype = verify_device_and_dtype([transfer_map, length], device, dtype) factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name) diff --git a/cheetah/accelerator/dipole.py b/cheetah/accelerator/dipole.py index bd803626..a787eec1 100644 --- a/cheetah/accelerator/dipole.py +++ b/cheetah/accelerator/dipole.py @@ -65,8 +65,8 @@ def __init__( dtype=None, ): device, dtype = verify_device_and_dtype( - [length], [ + length, angle, k1, e1, diff --git a/cheetah/accelerator/horizontal_corrector.py b/cheetah/accelerator/horizontal_corrector.py index a19aaa15..66c48892 100644 --- a/cheetah/accelerator/horizontal_corrector.py +++ b/cheetah/accelerator/horizontal_corrector.py @@ -35,7 +35,7 @@ def __init__( device=None, dtype=None, ) -> None: - device, dtype = verify_device_and_dtype([length], [angle], device, dtype) + device, dtype = verify_device_and_dtype([length, angle], device, dtype) factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name) diff --git a/cheetah/accelerator/quadrupole.py b/cheetah/accelerator/quadrupole.py index b34a9e21..045c75a4 100644 --- a/cheetah/accelerator/quadrupole.py +++ b/cheetah/accelerator/quadrupole.py @@ -45,7 +45,7 @@ def __init__( dtype=None, ) -> None: device, dtype = verify_device_and_dtype( - [length], [k1, misalignment, tilt], device, dtype + [length, k1, misalignment, tilt], device, dtype ) factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name) diff --git a/cheetah/accelerator/screen.py b/cheetah/accelerator/screen.py index 359a6d88..02a9d515 100644 --- a/cheetah/accelerator/screen.py +++ b/cheetah/accelerator/screen.py @@ -56,11 +56,7 @@ def __init__( dtype=None, ) -> None: device, dtype = verify_device_and_dtype( - [], # No required tensor arguments - # Excludes resolution and binning, since those are integer valued, not float - [pixel_size, misalignment, kde_bandwidth], - device, - dtype, + [pixel_size, misalignment, kde_bandwidth], device, dtype ) factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name) diff --git a/cheetah/accelerator/solenoid.py b/cheetah/accelerator/solenoid.py index ed7bc6e7..f86a90ff 100644 --- a/cheetah/accelerator/solenoid.py +++ b/cheetah/accelerator/solenoid.py @@ -43,7 +43,7 @@ def __init__( dtype=None, ) -> None: device, dtype = verify_device_and_dtype( - [length], [k, misalignment], device, dtype + [length, k, misalignment], device, dtype ) factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name) diff --git a/cheetah/accelerator/space_charge_kick.py b/cheetah/accelerator/space_charge_kick.py index 8fb52e7c..a30f6dc2 100644 --- a/cheetah/accelerator/space_charge_kick.py +++ b/cheetah/accelerator/space_charge_kick.py @@ -59,12 +59,7 @@ def __init__( device=None, dtype=None, ) -> None: - device, dtype = verify_device_and_dtype( - [effect_length], - [], # TODO: Add grid_extend_{x,y,tau}, needs torch.Tensor default - device, - dtype, - ) + device, dtype = verify_device_and_dtype([effect_length], device, dtype) self.factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name) diff --git a/cheetah/accelerator/transverse_deflecting_cavity.py b/cheetah/accelerator/transverse_deflecting_cavity.py index de5ad9e1..6cdb00c9 100644 --- a/cheetah/accelerator/transverse_deflecting_cavity.py +++ b/cheetah/accelerator/transverse_deflecting_cavity.py @@ -47,7 +47,7 @@ def __init__( dtype=None, ) -> None: device, dtype = verify_device_and_dtype( - [length], [voltage, phase, frequency, misalignment, tilt], device, dtype + [length, voltage, phase, frequency, misalignment, tilt], device, dtype ) factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name) diff --git a/cheetah/accelerator/vertical_corrector.py b/cheetah/accelerator/vertical_corrector.py index 4010c4f4..babc3f38 100644 --- a/cheetah/accelerator/vertical_corrector.py +++ b/cheetah/accelerator/vertical_corrector.py @@ -38,7 +38,7 @@ def __init__( device=None, dtype=None, ) -> None: - device, dtype = verify_device_and_dtype([length], [angle], device, dtype) + device, dtype = verify_device_and_dtype([length, angle], device, dtype) factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name) diff --git a/cheetah/particles/parameter_beam.py b/cheetah/particles/parameter_beam.py index d17022dc..cf271ca9 100644 --- a/cheetah/particles/parameter_beam.py +++ b/cheetah/particles/parameter_beam.py @@ -6,8 +6,8 @@ from cheetah.particles.beam import Beam from cheetah.particles.particle_beam import ParticleBeam from cheetah.utils import ( - extract_argument_device, - extract_argument_dtype, + are_all_the_same_device, + are_all_the_same_dtype, extract_argument_shape, verify_device_and_dtype, ) @@ -35,7 +35,7 @@ def __init__( dtype=None, ) -> None: device, dtype = verify_device_and_dtype( - [mu, cov, energy], [total_charge], device, dtype + [mu, cov, energy, total_charge], device, dtype ) factory_kwargs = {"device": device, "dtype": dtype} super().__init__() @@ -97,8 +97,8 @@ def from_parameters( ] # Extract device and dtype from given arguments - device = device if device is not None else extract_argument_device(not_nones) - dtype = dtype if dtype is not None else extract_argument_dtype(not_nones) + device = device if device is not None else are_all_the_same_device(not_nones) + dtype = dtype if dtype is not None else are_all_the_same_dtype(not_nones) factory_kwargs = {"device": device, "dtype": dtype} # Set default values without function call in function signature @@ -232,8 +232,8 @@ def from_twiss( # Extract shape, device and dtype from given arguments shape = extract_argument_shape(not_nones) - device = device if device is not None else extract_argument_device(not_nones) - dtype = dtype if dtype is not None else extract_argument_dtype(not_nones) + device = device if device is not None else are_all_the_same_device(not_nones) + dtype = dtype if dtype is not None else are_all_the_same_dtype(not_nones) factory_kwargs = {"device": device, "dtype": dtype} # Set default values without function call in function signature diff --git a/cheetah/particles/particle_beam.py b/cheetah/particles/particle_beam.py index 3b2196a8..b084fe70 100644 --- a/cheetah/particles/particle_beam.py +++ b/cheetah/particles/particle_beam.py @@ -7,9 +7,9 @@ from cheetah.particles.beam import Beam from cheetah.utils import ( + are_all_the_same_device, + are_all_the_same_dtype, elementwise_linspace, - extract_argument_device, - extract_argument_dtype, extract_argument_shape, verify_device_and_dtype, ) @@ -42,7 +42,7 @@ def __init__( dtype=None, ) -> None: device, dtype = verify_device_and_dtype( - [particles, energy], [particle_charges], device, dtype + [particles, energy, particle_charges], device, dtype ) factory_kwargs = {"device": device, "dtype": dtype} super().__init__() @@ -134,8 +134,8 @@ def from_parameters( ] # Extract device and dtype from given arguments - device = device if device is not None else extract_argument_device(not_nones) - dtype = dtype if dtype is not None else extract_argument_dtype(not_nones) + device = device if device is not None else are_all_the_same_device(not_nones) + dtype = dtype if dtype is not None else are_all_the_same_dtype(not_nones) factory_kwargs = {"device": device, "dtype": dtype} # Set default values without function call in function signature @@ -279,8 +279,8 @@ def from_twiss( # Extract shape, device and dtype from given arguments shape = extract_argument_shape(not_nones) - device = device if device is not None else extract_argument_device(not_nones) - dtype = dtype if dtype is not None else extract_argument_dtype(not_nones) + device = device if device is not None else are_all_the_same_device(not_nones) + dtype = dtype if dtype is not None else are_all_the_same_dtype(not_nones) factory_kwargs = {"device": device, "dtype": dtype} # Set default values without function call in function signature @@ -418,8 +418,8 @@ def uniform_3d_ellipsoid( # Extract shape, device and dtype from given arguments shape = extract_argument_shape(not_nones) - device = device if device is not None else extract_argument_device(not_nones) - dtype = dtype if dtype is not None else extract_argument_dtype(not_nones) + device = device if device is not None else are_all_the_same_device(not_nones) + dtype = dtype if dtype is not None else are_all_the_same_dtype(not_nones) factory_kwargs = {"device": device, "dtype": dtype} # Expand to vectorised version for beam creation @@ -563,8 +563,8 @@ def make_linspaced( ] # Extract device and dtype from given arguments - device = device if device is not None else extract_argument_device(not_nones) - dtype = dtype if dtype is not None else extract_argument_dtype(not_nones) + device = device if device is not None else are_all_the_same_device(not_nones) + dtype = dtype if dtype is not None else are_all_the_same_dtype(not_nones) factory_kwargs = {"device": device, "dtype": dtype} # Set default values without function call in function signature diff --git a/cheetah/utils/__init__.py b/cheetah/utils/__init__.py index e5c574a0..ff233a9c 100644 --- a/cheetah/utils/__init__.py +++ b/cheetah/utils/__init__.py @@ -1,7 +1,7 @@ from . import bmadx # noqa: F401 from .argument_verification import ( # noqa: F401 - extract_argument_device, - extract_argument_dtype, + are_all_the_same_device, + are_all_the_same_dtype, extract_argument_shape, verify_device_and_dtype, ) diff --git a/cheetah/utils/argument_verification.py b/cheetah/utils/argument_verification.py index 7c67b7a1..89263e70 100644 --- a/cheetah/utils/argument_verification.py +++ b/cheetah/utils/argument_verification.py @@ -3,54 +3,65 @@ import torch -def extract_argument_device(arguments: list[torch.Tensor]) -> torch.device: +def are_all_the_same_device(tensors: list[torch.Tensor]) -> torch.device: """ - Determines whether all arguments are on the same device and returns the default - pytorch device if no argumente are passed. + Determines whether all arguments are on the same device and, if so, returns that + device. If no arguments are passed, global default PyTorch device is returned. """ - if len(arguments) > 1: + if len(tensors) > 1: assert all( - argument.device == arguments[0].device for argument in arguments - ), "Arguments must be on the same device." + argument.device == tensors[0].device for argument in tensors + ), "All tensors must be on the same device." - return arguments[0].device if len(arguments) > 0 else torch.get_default_device() + return tensors[0].device if len(tensors) > 0 else torch.get_default_device() -def extract_argument_dtype(arguments: list[torch.Tensor]) -> torch.dtype: +def are_all_the_same_dtype(tensors: list[torch.Tensor]) -> torch.dtype: """ - Determines whether all arguments have the same dtype and returns the default - pytorch dtype if no argumente are passed. + Determines whether all arguments have the same dtype and, if so, returns that dtype. + If no arguments are passed, global default PyTorch dtype is returned. """ - if len(arguments) > 1: + if len(tensors) > 1: assert all( - argument.dtype == arguments[0].dtype for argument in arguments - ), "Arguments must have the same dtype." + argument.dtype == tensors[0].dtype for argument in tensors + ), "All arguments must have the same dtype." - return arguments[0].dtype if len(arguments) > 0 else torch.get_default_dtype() + return tensors[0].dtype if len(tensors) > 0 else torch.get_default_dtype() -def extract_argument_shape(arguments: list[torch.Tensor]) -> torch.Size: +def extract_argument_shape(tensors: list[torch.Tensor]) -> torch.Size: """Determines whether all arguments have the same shape.""" - if len(arguments) > 1: + if len(tensors) > 1: assert all( - argument.shape == arguments[0].shape for argument in arguments + argument.shape == tensors[0].shape for argument in tensors ), "Arguments must have the same shape." - return arguments[0].shape if len(arguments) > 0 else torch.Size([1]) + return tensors[0].shape if len(tensors) > 0 else torch.Size([1]) def verify_device_and_dtype( - required: list[torch.Tensor], - optionals: list[Optional[torch.Tensor]], - device: torch.device, - dtype: torch.dtype, + tensors: list[Optional[torch.Tensor]], + desired_device: Optional[torch.device], + desired_dtype: Optional[torch.dtype], ) -> tuple[torch.device, torch.dtype]: """ - Verifies that all required & given optional arguments have the same device and - dtype if no defaults are provided. + Verifies that passed tensors (if they are tensors and not `None`) have the same + device and dtype and that that device and dtype are the same as the desired device + and dtype if they are requested. + + If all verifications pass, this function returns the device and dtype shared by all + tensors. """ - not_nones = required + [argument for argument in optionals if argument is not None] + not_nones = [tensor for tensor in tensors if tensor is not None] - device = device if device is not None else extract_argument_device(not_nones) - dtype = dtype if dtype is not None else extract_argument_dtype(not_nones) - return (device, dtype) + chosen_device = ( + desired_device + if desired_device is not None + else are_all_the_same_device(not_nones) + ) + chosen_dtype = ( + desired_dtype + if desired_dtype is not None + else are_all_the_same_dtype(not_nones) + ) + return (chosen_device, chosen_dtype)