Fix issue with CustomTransferMap saving and cloning
jank324 committed Dec 4, 2024
1 parent 21023b9 commit cd5652a
Showing 4 changed files with 18 additions and 21 deletions.
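
Background for the diffs below, inferred from the changes themselves rather than stated anywhere in the commit: CustomTransferMap stored its matrix in a buffer named _transfer_map, while defining_features advertised a feature called transfer_map, which is actually the name of a method on the class. Any generic save or clone machinery that looks defining features up with getattr would therefore fetch a bound method instead of the tensor. A minimal, self-contained sketch of that failure mode (hypothetical class, not Cheetah code):

import torch

class Broken:
    """Hypothetical stand-in for the pre-fix CustomTransferMap layout."""

    def __init__(self, transfer_map: torch.Tensor) -> None:
        self._transfer_map = transfer_map  # stored under a private name

    def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
        return self._transfer_map  # the method shadows the feature name

    @property
    def defining_features(self) -> list[str]:
        return ["transfer_map"]

element = Broken(torch.eye(7))
feature = getattr(element, "transfer_map")  # bound method, not the tensor
print(callable(feature))  # True -- saving/cloning by feature name breaks

Renaming the stored tensor to predefined_transfer_map, as the first file below does, removes the collision.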
29 changes: 13 additions & 16 deletions cheetah/accelerator/custom_transfer_map.py
@@ -18,28 +18,31 @@ class CustomTransferMap(Element):

     def __init__(
         self,
-        transfer_map: torch.Tensor,
+        predefined_transfer_map: torch.Tensor,
         length: Optional[torch.Tensor] = None,
         name: Optional[str] = None,
         device=None,
         dtype=None,
     ) -> None:
-        device, dtype = verify_device_and_dtype([transfer_map, length], device, dtype)
+        device, dtype = verify_device_and_dtype(
+            [predefined_transfer_map, length], device, dtype
+        )
         factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__(name=name)

-        assert isinstance(transfer_map, torch.Tensor)
-        assert transfer_map.shape[-2:] == (7, 7)
+        assert isinstance(predefined_transfer_map, torch.Tensor)
+        assert predefined_transfer_map.shape[-2:] == (7, 7)

         self.register_buffer(
-            "_transfer_map", torch.as_tensor(transfer_map, **factory_kwargs)
+            "predefined_transfer_map",
+            torch.as_tensor(predefined_transfer_map, **factory_kwargs),
         )
         self.register_buffer(
             "length",
             (
                 torch.as_tensor(length, **factory_kwargs)
                 if length is not None
-                else torch.zeros(transfer_map.shape[:-2], **factory_kwargs)
+                else torch.zeros(predefined_transfer_map.shape[:-2], **factory_kwargs)
             ),
         )
@@ -83,33 +86,27 @@ def from_merging_elements(
         )

     def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
-        return self._transfer_map
+        return self.predefined_transfer_map

     @property
     def is_skippable(self) -> bool:
         return True

     def __repr__(self):
         return (
-            f"{self.__class__.__name__}(transfer_map={repr(self._transfer_map)}, "
+            f"{self.__class__.__name__}("
+            + f"predefined_transfer_map={repr(self.predefined_transfer_map)}, "
             + f"length={repr(self.length)}, "
             + f"name={repr(self.name)})"
         )

     @property
     def defining_features(self) -> list[str]:
-        return super().defining_features + ["transfer_map"]
+        return super().defining_features + ["length", "predefined_transfer_map"]

     def split(self, resolution: torch.Tensor) -> list[Element]:
         return [self]

-    def clone(self) -> "CustomTransferMap":
-        return CustomTransferMap(
-            transfer_map=self._transfer_map.clone(),
-            length=self.length.clone(),
-            name=self.name,
-        )
-
     def plot(self, ax: plt.Axes, s: float, vector_idx: Optional[tuple] = None) -> None:
         plot_s = s[vector_idx] if s.dim() > 0 else s
         plot_length = self.length[vector_idx] if self.length.dim() > 0 else self.length
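
With the buffer now registered under the same name as its constructor argument, and with both it and length listed in defining_features, the bespoke clone() override above becomes redundant: the inherited implementation can rebuild the element generically. The diff does not show that base-class machinery, so the following is only a sketch under the assumption that Element.clone() reconstructs an element from its defining features; all Sketch* names are hypothetical:

from copy import deepcopy
from typing import Optional

import torch
from torch import nn

class SketchElement(nn.Module):
    """Hypothetical stand-in for Cheetah's base Element clone logic."""

    def __init__(self, name: Optional[str] = None) -> None:
        super().__init__()
        self.name = name

    @property
    def defining_features(self) -> list[str]:
        return ["name"]

    def clone(self) -> "SketchElement":
        # Rebuild the element from its defining features: tensors are
        # cloned, everything else (e.g. the name string) is deep-copied.
        # This only works when every defining feature matches a
        # constructor argument by name, which the rename above restores.
        return self.__class__(
            **{
                feature: (
                    getattr(self, feature).clone()
                    if isinstance(getattr(self, feature), torch.Tensor)
                    else deepcopy(getattr(self, feature))
                )
                for feature in self.defining_features
            }
        )

class SketchCustomTransferMap(SketchElement):
    def __init__(
        self, predefined_transfer_map: torch.Tensor, name: Optional[str] = None
    ) -> None:
        super().__init__(name=name)
        self.register_buffer("predefined_transfer_map", predefined_transfer_map)

    @property
    def defining_features(self) -> list[str]:
        return super().defining_features + ["predefined_transfer_map"]

element = SketchCustomTransferMap(torch.eye(7), name="m1")
copy = element.clone()  # works: every feature now resolves to real data
assert torch.allclose(copy.predefined_transfer_map, element.predefined_transfer_map)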
2 changes: 1 addition & 1 deletion cheetah/converters/elegant.py
@@ -250,7 +250,7 @@ def convert_element(

         return cheetah.CustomTransferMap(
             length=torch.tensor(parsed["l"]),
-            transfer_map=R,
+            predefined_transfer_map=R,
             name=name,
             device=device,
             dtype=dtype,
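
The keyword rename is a breaking change for any downstream code constructing the element directly; passing the old keyword now raises a TypeError. A hypothetical before/after for such a call site (R and the length value are placeholders):

import torch

import cheetah

R = torch.eye(7)

# Before this commit:
#   element = cheetah.CustomTransferMap(transfer_map=R, length=torch.tensor(0.3))
# After this commit:
element = cheetah.CustomTransferMap(
    predefined_transfer_map=R, length=torch.tensor(0.3)
)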
2 changes: 1 addition & 1 deletion tests/test_elegant_conversion.py
@@ -76,7 +76,7 @@ def test_custom_transfer_map_import():
         ]
     )

-    assert torch.allclose(converted.c1e._transfer_map, correct_transfer_map)
+    assert torch.allclose(converted.c1e.predefined_transfer_map, correct_transfer_map)


 @pytest.mark.parametrize(
6 changes: 3 additions & 3 deletions tests/test_lattice_json.py
@@ -40,7 +40,7 @@ def test_save_and_reload_custom_transfer_map(tmp_path):
     never appears in the ARES lattice and must therefore be tested separately.
     """
     custom_transfer_map_element = cheetah.CustomTransferMap(
-        transfer_map=torch.eye(7, 7),
+        predefined_transfer_map=torch.eye(7, 7),
         length=torch.tensor(1.0),
         name="my_custom_transfer_map_element",
    )
@@ -61,8 +61,8 @@ def test_save_and_reload_custom_transfer_map(tmp_path):
     reloaded_custom_transfer_map_element = reloaded_segment.elements[0]

     assert torch.allclose(
-        custom_transfer_map_element._transfer_map,
-        reloaded_custom_transfer_map_element._transfer_map,
+        custom_transfer_map_element.predefined_transfer_map,
+        reloaded_custom_transfer_map_element.predefined_transfer_map,
     )
     assert torch.allclose(
         custom_transfer_map_element.length, reloaded_custom_transfer_map_element.length
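
The length assertion at the end of this test only holds if the save machinery knows to persist length, which is what adding it to defining_features accomplishes. A short sketch of the reconstruction a feature-driven save/reload performs, with the feature list written out explicitly for illustration:

import torch

import cheetah

element = cheetah.CustomTransferMap(
    predefined_transfer_map=torch.eye(7),
    length=torch.tensor(1.0),
    name="m1",
)

# Rebuild from defining features alone, as generic save/clone machinery
# would. Before this commit the list was just ["transfer_map"], so length
# would have been silently reset to its zero default on reload.
features = {f: getattr(element, f) for f in ["length", "predefined_transfer_map"]}
rebuilt = cheetah.CustomTransferMap(name=element.name, **features)

assert torch.allclose(element.length, rebuilt.length)
assert torch.allclose(element.predefined_transfer_map, rebuilt.predefined_transfer_map)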
