Skip to content

Commit

Permalink
Merge pull request #18 from desy-ml/pytorch-histogramdd
Browse files Browse the repository at this point in the history
Replace custom histogramdd() with torch.histogramdd()
  • Loading branch information
jank324 authored Feb 5, 2023
2 parents bf326a2 + 928b12b commit a5901fa
Show file tree
Hide file tree
Showing 9 changed files with 269 additions and 295 deletions.
86 changes: 58 additions & 28 deletions benchmark/cheetah/cheetah.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion cheetah/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,7 +745,7 @@ def reading(self):
image = dist.pdf(pos)
image = np.flipud(image.T)
elif isinstance(self.read_beam, ParticleBeam):
image, _ = utils.histogramdd(
image, _ = torch.histogramdd(
torch.stack((self.read_beam.xs, self.read_beam.ys)),
bins=self.pixel_bin_edges,
)
Expand Down
127 changes: 2 additions & 125 deletions cheetah/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ def ocelot2cheetah(element, warnings=True):
elif isinstance(element, oc.Monitor) and "BSC" in element.id:
if warnings:
print(
"WARNING: Diagnostic screen was converted with default screen properties."
"WARNING: Diagnostic screen was converted with default screen"
" properties."
)
return acc.Screen((2448, 2040), (3.5488e-6, 2.5003e-6), name=element.id)
elif isinstance(element, oc.Monitor) and "BPM" in element.id:
Expand Down Expand Up @@ -141,127 +142,3 @@ def subcell_of_ocelot(cell, start, end):
break

return subcell


_range = range


def histogramdd(sample, bins=None, range=None, weights=None, remove_overflow=True):
    """
    Pytorch version of n-dimensional histogram.

    Adapted from https://github.com/miranov25/RootInteractive/blob/b54446e09072e90e17f3da72d5244a20c8fdd209/RootInteractive/Tools/Histograms/histogramdd.py

    :param sample: Tensor of shape (D, N): N points in D dimensions. NOTE: This
        is transposed relative to numpy.histogramdd. Assumed to be a
        floating-point tensor.
    :param bins: Number of bins for all dimensions (int), number of bins per
        dimension (sequence of ints), or explicit bin edges (a tensor of shape
        (D, m) or a sequence of D 1-d tensors).
    :param range: Lower and upper bin range per dimension as a (2, D) tensor or
        a sequence of D (min, max) pairs (entries may be None to fall back to
        that dimension's data range). Ignored when explicit edges are given.
        Defaults to the per-dimension min/max of sample.
    :param weights: Optional tensor of N per-sample weights.
    :param remove_overflow: If True (default), strip the under- and overflow
        bins so the histogram has bins[i] entries per dimension; otherwise it
        has bins[i] + 2.
    :return: Tuple of the histogram tensor and the bin edges.
    """
    edges = None
    custom_edges = False
    D, N = sample.shape
    device = sample.device
    if bins is None:
        bins = 10  # Default number of bins matches numpy.histogramdd
    try:
        M = bins.size(0)
        if M != D:
            raise ValueError(
                "The dimension of bins must be equal to the dimension of sample x."
            )
    except AttributeError:
        # bins is either an integer or a list
        if isinstance(bins, int):
            bins = torch.full([D], bins, dtype=torch.long, device=device)
        elif torch.is_tensor(bins[0]):
            # A sequence of per-dimension edge tensors
            custom_edges = True
            edges = bins
            bins = torch.empty(D, dtype=torch.long)
            for i in _range(len(edges)):
                bins[i] = edges[i].size(0) - 1
            bins = bins.to(device)
        else:
            bins = torch.as_tensor(bins)
        if bins.dim() == 2:
            # A (D, m) tensor is interpreted as explicit bin edges.
            custom_edges = True
            edges = bins
            bins = torch.full([D], bins.size(1) - 1, dtype=torch.long, device=device)
    if custom_edges:
        # NOTE(review): With explicit edges, searchsorted (right=False) makes
        # the bin intervals left-open, so a point exactly on the leftmost edge
        # falls into the underflow bin — verify this matches callers.
        use_old_edges = False
        if not torch.is_tensor(edges):
            # Pad the ragged per-dimension edge tensors into one (D, m) tensor
            # (repeating each dimension's last edge) so searchsorted can batch
            # over all dimensions at once.
            use_old_edges = True
            edges_old = edges
            m = max(i.size(0) for i in edges)
            tmp = torch.empty([D, m], device=edges[0].device)
            for i in _range(D):
                s = edges[i].size(0)
                tmp[i, :] = edges[i][-1]
                tmp[i, :s] = edges[i][:]
            edges = tmp.to(device)
        k = torch.searchsorted(edges, sample)
        k = torch.min(k, (bins + 1).reshape(-1, 1))  # Clamp into the overflow bin
        if use_old_edges:
            edges = edges_old
        else:
            edges = torch.unbind(edges)
    else:
        if range is None:  # Range is not defined
            range = torch.empty(2, D, device=device, dtype=sample.dtype)
            if N == 0:  # Empty histogram: use a dummy unit range
                range[0, :] = 0
                range[1, :] = 1
            else:
                range[0, :] = torch.min(sample, 1)[0]
                range[1, :] = torch.max(sample, 1)[0]
        elif not torch.is_tensor(range):  # Range is a sequence of (min, max)
            r = torch.empty(2, D)
            for i in _range(D):
                if range[i] is not None:
                    r[:, i] = torch.as_tensor(range[i])
                elif N == 0:  # Edge case: empty histogram
                    r[0, i] = 0
                    r[1, i] = 1
                else:
                    # BUGFIX: The original unconditionally overwrote the
                    # N == 0 case, reduced over the wrong axis of the (D, N)
                    # sample, and indexed the scalar result of torch.min with
                    # [0], which raises.
                    r[0, i] = torch.min(sample[i])
                    r[1, i] = torch.max(sample[i])
            range = r.to(device=device, dtype=sample.dtype)
        else:
            # BUGFIX: Clone so that padding singular ranges below does not
            # mutate a caller-supplied tensor in place.
            range = range.to(device=device, dtype=sample.dtype).clone()
        # If a range consists of only one point, pad it up to unit width.
        singular_range = torch.eq(range[0], range[1])
        range[0, singular_range] -= 0.5
        range[1, singular_range] += 0.5
        edges = [
            torch.linspace(
                range[0, i].item(), range[1, i].item(), int(bins[i]) + 1,
                device=device,
            )
            for i in _range(len(bins))
        ]
        # Affine transform mapping a coordinate x to its 1-based bin index:
        # index = tranges[0] + x * tranges[1]. Index 0 is the underflow bin,
        # index bins + 1 the overflow bin.
        tranges = torch.empty_like(range)
        tranges[1, :] = bins / (range[1, :] - range[0, :])
        tranges[0, :] = 1 - range[0, :] * tranges[1, :]
        k = torch.addcmul(
            tranges[0, :].reshape(-1, 1), sample, tranges[1, :].reshape(-1, 1)
        ).long()  # Get the right index
        k = torch.max(
            k, torch.zeros([], device=device, dtype=torch.long)
        )  # Underflow bin
        k = torch.min(k, (bins + 1).reshape(-1, 1))  # Overflow bin
        # BUGFIX: Points exactly on the rightmost edge belong to the last bin
        # (numpy/torch histogramdd convention), not to the overflow bin.
        k = torch.where(sample == range[1].reshape(-1, 1), bins.reshape(-1, 1), k)

    # Flatten the D per-dimension indices into one linear index using
    # row-major strides over (bins + 2) bins per dimension.
    multiindex = torch.ones_like(bins)
    multiindex[1:] = torch.cumprod(torch.flip(bins[1:], [0]) + 2, -1).long()
    multiindex = torch.flip(multiindex, [0])
    flat_index = torch.sum(k * multiindex.reshape(-1, 1), 0)
    hist = torch.bincount(
        flat_index, minlength=(multiindex[0] * (bins[0] + 2)).item(), weights=weights
    )
    hist = hist.reshape(tuple((bins + 2).tolist()))
    if remove_overflow:
        core = D * (slice(1, -1),)  # Strip under-/overflow bin in every dimension
        hist = hist[core]
    return hist, edges
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

setup(
name="cheetah-accelerator",
version="0.5.16",
version="0.5.17",
author="Jan Kaiser & Oliver Stein",
author_email="[email protected]",
url="https://github.com/desy-ml/cheetah",
Expand Down
107 changes: 69 additions & 38 deletions test/intro.ipynb

Large diffs are not rendered by default.

59 changes: 28 additions & 31 deletions test/ocelot_vs_joss.ipynb

Large diffs are not rendered by default.

67 changes: 32 additions & 35 deletions test/olivers_test.ipynb

Large diffs are not rendered by default.

93 changes: 67 additions & 26 deletions test/testmore.ipynb

Large diffs are not rendered by default.

21 changes: 11 additions & 10 deletions test/testtime.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
"output_type": "stream",
"text": [
"[INFO ] : : \u001b[0mbeam.py: module NUMBA is not installed. Install it to speed up calculation\u001b[0m\n",
"[INFO ] : : \u001b[0mhigh_order.py: module NUMBA is not installed. Install it to speed up calculation\u001b[0m\n",
"[INFO ] : : : : : : : : : : : \u001b[0mradiation_py.py: module NUMBA is not installed. Install it to speed up calculation\u001b[0m\n",
"[INFO ] : : : : : : : : \u001b[0mhigh_order.py: module NUMBA is not installed. Install it to speed up calculation\u001b[0m\n",
"[INFO ] \u001b[0mradiation_py.py: module NUMBA is not installed. Install it to speed up calculation\u001b[0m\n",
"[INFO ] \u001b[0mradiation_py.py: module NUMBA is not installed. Install it to speed up calculation\u001b[0m\n",
"[INFO ] \u001b[0mcsr.py: module NUMBA is not installed. Install it to speed up calculation\u001b[0m\n",
"[INFO ] \u001b[0mcsr.py: module PYFFTW is not installed. Install it to speed up calculation.\u001b[0m\n",
"[INFO ] \u001b[0mcsr.py: module NUMEXPR is not installed. Install it to speed up calculation\u001b[0m\n",
Expand Down Expand Up @@ -48,8 +49,8 @@
"metadata": {},
"outputs": [],
"source": [
"beam1 = cheetah.ParameterBeam.from_astra(\"../distributions/ACHIP_EA1_2021.1351.001\")\n",
"beam2 = cheetah.ParticleBeam.from_astra(\"../distributions/ACHIP_EA1_2021.1351.001\")"
"beam1 = cheetah.ParameterBeam.from_astra(\"../benchmark/cheetah/ACHIP_EA1_2021.1351.001\")\n",
"beam2 = cheetah.ParticleBeam.from_astra(\"../benchmark/cheetah/ACHIP_EA1_2021.1351.001\")"
]
},
{
Expand Down Expand Up @@ -77,7 +78,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"91.7 µs ± 221 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
"86.6 µs ± 816 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
]
}
],
Expand All @@ -95,7 +96,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"902 µs ± 21.7 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
"955 µs ± 15.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
]
}
],
Expand Down Expand Up @@ -127,7 +128,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"z = 42.34949999999999 / 42.34949999999999 : applied: 3.36 ms ± 1.63 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
"z = 42.34949999999999 / 42.34949999999999. Applied: 2.82 ms ± 1.33 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
Expand All @@ -146,7 +147,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.13 ('cheetah-test')",
"display_name": "rl39",
"language": "python",
"name": "python3"
},
Expand All @@ -160,12 +161,12 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
"version": "3.9.15"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "947d9ee3b458d99f0eb80dac10135f6d6a35887bd0ce9fb941e727a4631c373a"
"hash": "343fe3b89e2d7877d61a0509fd880204236e5c07449e4c121f53f2530ef83fc9"
}
}
},
Expand Down

0 comments on commit a5901fa

Please sign in to comment.