Skip to content

Commit

Permalink
Merge pull request #308 from Hespe/fix-clone-warnings
Browse files Browse the repository at this point in the history
Fix warnings in test suite
  • Loading branch information
jank324 authored Dec 13, 2024
2 parents b92e3c9 + 286d09b commit 2713510
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 10 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ This is a major release with significant upgrades under the hood of Cheetah. Des
- Fix plotting for segments that contain tensors with `require_grad=True` (see #288) (@hespe)
- Fix bug where `Element.length` could not be set as a `torch.nn.Parameter` (see #301) (@jank324, @hespe)
- Fix registration of `torch.nn.Parameter` at initilization for elements and beams (see #303) (@hespe)
- Fix warnings about NumPy deprecations and unintentional tensor clones (see #308) (@hespe)

### 🐆 Other

Expand Down
7 changes: 3 additions & 4 deletions cheetah/accelerator/dipole.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Literal, Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.patches import Rectangle
from scipy.constants import physical_constants
Expand Down Expand Up @@ -104,14 +103,14 @@ def __init__(
if fringe_integral is not None:
self.fringe_integral = torch.as_tensor(fringe_integral, **factory_kwargs)
self.fringe_integral_exit = (
torch.tensor(fringe_integral_exit, **factory_kwargs)
torch.as_tensor(fringe_integral_exit, **factory_kwargs)
if fringe_integral_exit is not None
else self.fringe_integral
)
if gap is not None:
self.gap = torch.as_tensor(gap, **factory_kwargs)
self.gap_exit = (
torch.tensor(gap_exit, **factory_kwargs)
torch.as_tensor(gap_exit, **factory_kwargs)
if gap_exit is not None
else self.gap
)
Expand Down Expand Up @@ -493,7 +492,7 @@ def plot(self, ax: plt.Axes, s: float, vector_idx: Optional[tuple] = None) -> No
plot_angle = self.angle[vector_idx] if self.angle.dim() > 0 else self.angle

alpha = 1 if self.is_active else 0.2
height = 0.8 * (np.sign(plot_angle) if self.is_active else 1)
height = 0.8 * (torch.sign(plot_angle) if self.is_active else 1)

patch = Rectangle(
(plot_s, 0), plot_length, height, color="tab:green", alpha=alpha, zorder=2
Expand Down
3 changes: 1 addition & 2 deletions cheetah/accelerator/horizontal_corrector.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.patches import Rectangle

Expand Down Expand Up @@ -88,7 +87,7 @@ def plot(self, ax: plt.Axes, s: float, vector_idx: Optional[tuple] = None) -> No
plot_angle = self.angle[vector_idx] if self.angle.dim() > 0 else self.angle

alpha = 1 if self.is_active else 0.2
height = 0.8 * (np.sign(plot_angle) if self.is_active else 1)
height = 0.8 * (torch.sign(plot_angle) if self.is_active else 1)

patch = Rectangle(
(plot_s, 0), plot_length, height, color="tab:blue", alpha=alpha, zorder=2
Expand Down
3 changes: 1 addition & 2 deletions cheetah/accelerator/quadrupole.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Literal, Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.patches import Rectangle
from scipy.constants import physical_constants
Expand Down Expand Up @@ -215,7 +214,7 @@ def plot(self, ax: plt.Axes, s: float, vector_idx: Optional[tuple] = None) -> No
plot_length = self.length[vector_idx] if self.length.dim() > 0 else self.length

alpha = 1 if self.is_active else 0.2
height = 0.8 * (np.sign(plot_k1) if self.is_active else 1)
height = 0.8 * (torch.sign(plot_k1) if self.is_active else 1)
patch = Rectangle(
(plot_s, 0), plot_length, height, color="tab:red", alpha=alpha, zorder=2
)
Expand Down
3 changes: 1 addition & 2 deletions cheetah/accelerator/vertical_corrector.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.patches import Rectangle
from scipy.constants import physical_constants
Expand Down Expand Up @@ -91,7 +90,7 @@ def plot(self, ax: plt.Axes, s: float, vector_idx: Optional[tuple] = None) -> No
plot_angle = self.angle[vector_idx] if self.angle.dim() > 0 else self.angle

alpha = 1 if self.is_active else 0.2
height = 0.8 * (np.sign(plot_angle) if self.is_active else 1)
height = 0.8 * (torch.sign(plot_angle) if self.is_active else 1)

patch = Rectangle(
(plot_s, 0), plot_length, height, color="tab:cyan", alpha=alpha, zorder=2
Expand Down

0 comments on commit 2713510

Please sign in to comment.