-
Notifications
You must be signed in to change notification settings - Fork 5
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Benchmarks for axes order #91
Labels
benchmark
Related to benchmarking something
Comments
lkct
added
enhancement
New feature or request
benchmark
Related to benchmarking something
labels
Jul 3, 2023
Book-keeping -- concatenation and indexing
1-D indexing for simple folding, 2-D indexing for mixing (includes a dimension for the input)
code
import itertools
import subprocess
import sys
import time
from typing import List
import torch
from torch import Tensor
from benchmark.utils.gpu_benchmark import timer
# Every tensor in this benchmark is allocated on this device.
device = "cuda"
# Axis sizes. The letters follow the issue's notation: B = batch, K = unit,
# F = folding.
B = 128
K = 512
F = 784
# Size of the "f" axis; bench() also uses it as the exclusive upper bound of
# the random indices it draws.
f = 288
# Maps each one-character axis name used in a shape string to its size.
sizes = {
"B": B,
"K": K,
"F": F,
"f": f,
"2": 2,
}
# Number of timed bench() calls that run() averages over.
N = 200
def bench(inputs: List[str], idx_shape: str, dim: int) -> float:
    """Time one concatenate-then-index pass over freshly generated random data.

    Args:
        inputs: Shape strings, one per tensor to concatenate. Every character
            is looked up in the module-level ``sizes`` table. The number of
            axes is taken from the first string.
        idx_shape: Shape string of the integer index tensor.
        dim: Axis along which the inputs are concatenated and then indexed.

    Returns:
        The second value returned by ``timer`` for one call of the timed
        closure (presumably the elapsed time; the unit is defined by
        ``timer``).
    """
    input_data = [torch.rand([sizes[x] for x in y], device=device) for y in inputs]
    idx_data = torch.randint(f, [sizes[x] for x in idx_shape], device=device)
    # Full slices everywhere except ``dim``, which gets the index tensor.
    # A tuple (not a list) is used so the subscript is unambiguously a
    # multi-axis index rather than relying on the legacy list heuristic.
    idx = tuple(idx_data if i == dim else slice(None) for i in range(len(inputs[0])))

    def proc() -> Tensor:
        # Only the concatenation and the gather are inside the timed region;
        # data generation above is deliberately excluded.
        result = torch.cat(input_data, dim=dim)
        return result[idx]

    _, t = timer(proc)
    return t
def run(inputs: str, idx_shape: str) -> None:
    """Benchmark one axis order and print the scaled mean time to stdout.

    If ``inputs`` contains an ``"F"`` axis, a single tensor is indexed along
    it. Otherwise the ``"f"`` axis is used and a second tensor, whose ``"f"``
    axis is replaced by ``"F"``, is concatenated along that axis first.

    Args:
        inputs: Shape string of the input tensor, e.g. ``"BKF"``.
        idx_shape: Shape string of the index tensor.
    """
    single_tensor = "F" in inputs
    dim = inputs.index("F" if single_tensor else "f")
    shapes = [inputs] if single_tensor else [inputs, inputs.replace("f", "F")]
    bench_args = (shapes, idx_shape, dim)

    # Untimed warm-up rounds; the original author questions whether such a
    # long warm-up is essential.
    warmup_rounds = 100
    for _ in range(warmup_rounds):
        bench(*bench_args)

    total = 0.0
    for _ in range(N):
        total += bench(*bench_args)

    # No trailing newline: the parent process stores stdout verbatim.
    mean_scaled = total / N * 1000
    print(f"{mean_scaled:9.3f}", end="")
def main() -> None:
    """Benchmark every axis order in its own subprocess and print a ranking.

    Each configuration is run by re-invoking ``bench.py`` (resolved relative
    to the current working directory) with the shape strings as arguments.
    Progress rows go to stderr; the final table, sorted by measured time in
    ascending order, goes to stdout.

    Raises:
        RuntimeError: If a child process exits with a non-zero status. The
            message is the child's stderr.
        ValueError: If a child's stdout cannot be parsed as a number.
    """
    ans = []
    for axes in itertools.permutations("BKF"):
        inputs = "".join(axes)
        for idx_shape in ("f",):
            result = subprocess.run(
                # sys.executable guarantees the child uses the same interpreter
                # (and therefore the same installed torch) as this process.
                [sys.executable, "bench.py", inputs, idx_shape],
                capture_output=True,
                check=False,
                text=True,
            )
            # Explicit check instead of ``assert`` so failures are still
            # detected when Python runs with -O.
            if result.returncode:
                raise RuntimeError(result.stderr)
            ans.append(["i", inputs, "idx", idx_shape, result.stdout])
            time.sleep(1)  # pause between runs (presumably to let the GPU settle)
            print(ans[-1], file=sys.stderr)
    # Sort numerically; comparing the padded strings only works while every
    # value fits in the fixed field width.
    ans.sort(key=lambda row: float(row[-1]))
    for item in ans:
        print(*item)
if __name__ == "__main__":
if len(sys.argv) == 1:
main()
else:
run(*sys.argv[1:])

1-concat, 1-D index

2-concat, 1-D index
code: differs by
for inputs in itertools.permutations("BKf"):

0-concat, 2-D index
code: differs by
def proc() -> Tensor:
result = input_data[0]
result = result.__getitem__(idx)
return result
...
for idx_shape in ("2f", "f2"):
|
CP-like layers -- 2-operand einsum, 3-D input, 2/3-D param, 3-D output
2-D param when the param is shared among foldings, 3-D when it is not
code
import itertools
import subprocess
import sys
import time
import torch
from torch import Tensor
from benchmark.utils.gpu_benchmark import timer
# Every tensor in this benchmark is allocated on this device.
device = "cuda"
# Axis sizes. The letters follow the issue's notation: B = batch, K = unit,
# F = folding, R = rank.
B = 128
K = 512
F = 288
# NOTE: R is assigned three times and only the last assignment (512) takes
# effect when this runs as written. 1 and 128 are the other settings for
# which result tables are reported; select one by removing the others.
R = 1
R = 128
R = 512
# Maps each one-character axis name used in a shape string to its size.
sizes = {
"B": B,
"K": K,
"F": F,
"R": R,
}
# Number of timed bench() calls that run() averages over.
N = 200
def bench(inputs: str, params: str, outputs: str, order: str) -> float:
    """Time a single two-operand einsum on freshly generated random data.

    Args:
        inputs: Shape string of the input tensor.
        params: Shape string of the parameter tensor.
        outputs: Shape string of the result.
        order: ``"ipo"`` passes the input operand first; any other value
            passes the parameter operand first.

    Returns:
        The second value returned by ``timer`` for one call of the timed
        closure.
    """
    input_data = torch.rand([sizes[axis] for axis in inputs], device=device)
    param_data = torch.rand([sizes[axis] for axis in params], device=device)

    if order == "ipo":
        equation = f"{inputs},{params}->{outputs}"
        first, second = input_data, param_data
    else:
        equation = f"{params},{inputs}->{outputs}"
        first, second = param_data, input_data

    def proc() -> Tensor:
        # .contiguous() penalizes a non-contiguous einsum result.
        return torch.einsum(equation, first, second).contiguous()

    _, elapsed = timer(proc)
    return elapsed
def run(inputs: str, params: str, outputs: str, order: str) -> None:
    """Benchmark one einsum configuration and print ``order,<mean>,`` to stdout.

    Args:
        inputs: Shape string of the input tensor.
        params: Shape string of the parameter tensor.
        outputs: Shape string of the result.
        order: Operand order label forwarded to ``bench`` and echoed in the
            output.
    """

    def measure() -> float:
        return bench(inputs, params, outputs, order)

    # Untimed warm-up rounds; the original author questions whether such a
    # long warm-up is essential.
    warmup_rounds = 100
    for _ in range(warmup_rounds):
        measure()

    total = 0.0
    for _ in range(N):
        total += measure()

    # Trailing comma instead of a newline: the parent splits stdout on ",".
    mean_scaled = total / N * 1000
    print(order, f"{mean_scaled:9.3f}", sep=",", end=",")
def main() -> None:
    """Benchmark every input/param axis order and print a ranking.

    Each (inputs, params, order) combination is measured by re-invoking
    ``bench.py`` (resolved relative to the current working directory) in a
    fresh subprocess. Every row ends up as
    ``["i", inputs, "p", params, "o", outputs, "ipo", t_ipo, "pio", t_pio]``.
    Rows are sorted in ascending order by the faster of the two timings and
    printed to stdout; progress rows go to stderr.

    Raises:
        subprocess.CalledProcessError: If a child process exits with a
            non-zero status.
        ValueError: If a child's timing output cannot be parsed as a number.
    """
    ans = []
    for input_axes in itertools.permutations("BKF"):
        inputs = "".join(input_axes)
        # The output keeps the input's axis order with the contracted K axis
        # replaced by R.
        outputs = inputs.replace("K", "R")
        for param_axes in itertools.permutations("KRF"):  # "KR" for w/o F
            params = "".join(param_axes)
            ans.append(["i", inputs, "p", params, "o", outputs])
            for order in ("ipo", "pio"):
                result = subprocess.run(
                    # sys.executable guarantees the child uses the same
                    # interpreter (and installed torch) as this process.
                    [sys.executable, "bench.py", inputs, params, outputs, order],
                    capture_output=True,
                    check=True,
                    text=True,
                )
                # Child prints "order,<time>,"; drop the empty trailing field.
                ans[-1].extend(result.stdout.split(",")[:-1])
                time.sleep(1)  # pause between runs (presumably to let the GPU settle)
            # Report the row once both operand orders have been measured.
            print(ans[-1], file=sys.stderr)
    # Compare timings numerically; comparing the padded strings only works
    # while every value fits in the fixed field width.
    ans.sort(key=lambda row: min(float(row[-1]), float(row[-3])))
    for item in ans:
        print(*item)
if __name__ == "__main__":
if len(sys.argv) == 1:
main()
else:
run(*sys.argv[1:])
param w/ F -- R = 1
|
input | param | output | ipo | pio |
---|---|---|---|---|
FBK | KRF | FBR | 113.788 | 119.096 |
FBK | KFR | FBR | 113.789 | 119.216 |
FBK | RKF | FBR | 113.912 | 119.255 |
FBK | FRK | FBR | 114.691 | 114.610 |
FBK | RFK | FBR | 114.623 | 114.810 |
FBK | FKR | FBR | 114.643 | 114.854 |
FKB | FKR | FRB | 117.210 | 117.104 |
FKB | FRK | FRB | 117.171 | 117.119 |
FKB | RFK | FRB | 117.231 | 117.146 |
FKB | RKF | FRB | 117.993 | 120.959 |
FKB | KRF | FRB | 118.009 | 120.986 |
FKB | KFR | FRB | 118.018 | 121.041 |
KFB | RFK | RFB | 118.541 | 118.592 |
KFB | FKR | RFB | 118.595 | 118.558 |
KFB | FRK | RFB | 118.566 | 118.568 |
KFB | KFR | RFB | 119.701 | 122.484 |
KFB | KRF | RFB | 119.773 | 122.480 |
KFB | RKF | RFB | 119.831 | 122.524 |
BFK | FKR | BFR | 119.897 | 119.930 |
BFK | RFK | BFR | 119.958 | 121.007 |
BFK | FRK | BFR | 119.960 | 120.003 |
BFK | RKF | BFR | 127.243 | 124.769 |
BFK | KFR | BFR | 127.366 | 124.811 |
BFK | KRF | BFR | 127.302 | 124.972 |
BKF | KFR | BRF | 1801.846 | 1957.580 |
KBF | RFK | RBF | 1955.285 | 1803.298 |
BKF | RFK | BRF | 1803.699 | 1947.905 |
KBF | FRK | RBF | 1955.196 | 1804.161 |
BKF | RKF | BRF | 1804.927 | 1959.471 |
BKF | FRK | BRF | 1805.174 | 1954.135 |
BKF | KRF | BRF | 1805.285 | 1951.911 |
BKF | FKR | BRF | 1806.474 | 1946.075 |
KBF | KFR | RBF | 1951.512 | 1809.683 |
KBF | RKF | RBF | 1953.504 | 1810.281 |
KBF | FKR | RBF | 1952.605 | 1811.003 |
KBF | KRF | RBF | 1953.056 | 1831.348 |
param w/ F -- R = 128
input | param | output | ipo | pio |
---|---|---|---|---|
FKB | FKR | FRB | 332.943 | 271.832 |
FKB | KFR | FRB | 334.048 | 273.649 |
FKB | FRK | FRB | 350.555 | 290.169 |
FBK | FKR | FBR | 290.291 | 350.455 |
FBK | KFR | FBR | 291.597 | 353.179 |
FKB | RFK | FRB | 354.897 | 294.839 |
KFB | KFR | RFB | 367.332 | 329.672 |
KFB | FKR | RFB | 367.710 | 329.767 |
KFB | FRK | RFB | 385.651 | 347.770 |
BFK | FKR | BFR | 349.022 | 388.739 |
KFB | RFK | RFB | 388.376 | 349.378 |
BFK | KFR | BFR | 351.018 | 388.244 |
FBK | FRK | FBR | 403.333 | 464.001 |
FBK | RFK | FBR | 425.401 | 486.442 |
BFK | FRK | BFR | 476.782 | 515.319 |
BFK | RFK | BFR | 530.967 | 568.218 |
FBK | KRF | FBR | 1985.983 | 2309.411 |
FKB | RKF | FRB | 2168.719 | 1988.884 |
FKB | KRF | FRB | 2009.858 | 2123.959 |
KFB | RKF | RFB | 2200.865 | 2034.394 |
KBF | KFR | RBF | 2265.354 | 2035.632 |
KBF | FKR | RBF | 2265.921 | 2037.660 |
BFK | KRF | BFR | 2049.831 | 2367.700 |
BKF | KFR | BRF | 2051.550 | 2241.513 |
KBF | FRK | RBF | 2378.948 | 2053.770 |
BKF | FKR | BRF | 2054.935 | 2239.223 |
KBF | RFK | RBF | 2391.599 | 2058.657 |
KFB | KRF | RFB | 2065.316 | 2191.561 |
FBK | RKF | FBR | 2121.255 | 2163.598 |
BKF | FRK | BRF | 2184.023 | 2255.856 |
BFK | RKF | BFR | 2188.250 | 2199.440 |
BKF | RFK | BRF | 2204.185 | 2259.588 |
KBF | RKF | RBF | 4100.396 | 3748.783 |
BKF | KRF | BRF | 3752.156 | 4108.266 |
BKF | RKF | BRF | 3912.224 | 3961.133 |
KBF | KRF | RBF | 3948.498 | 3917.926 |
param w/ F -- R = 512
input | param | output | ipo | pio |
---|---|---|---|---|
FKB | KFR | FRB | 1187.450 | 943.467 |
FKB | FKR | FRB | 1180.317 | 948.735 |
FBK | FKR | FBR | 964.116 | 1225.254 |
FBK | KFR | FBR | 964.148 | 1228.849 |
FKB | FRK | FRB | 1247.073 | 1004.016 |
FKB | RFK | FRB | 1278.828 | 1028.722 |
FBK | FRK | FBR | 1054.420 | 1308.914 |
FBK | RFK | FBR | 1093.861 | 1352.954 |
KFB | FKR | RFB | 1532.598 | 1152.051 |
KFB | KFR | RFB | 1533.872 | 1156.363 |
BFK | KFR | BFR | 1168.978 | 2314.418 |
BFK | FKR | BFR | 1172.014 | 2332.021 |
KFB | FRK | RFB | 1565.030 | 1215.263 |
KFB | RFK | RFB | 1594.081 | 1252.846 |
BFK | FRK | BFR | 1264.582 | 2425.236 |
BFK | RFK | BFR | 1719.074 | 2866.844 |
BKF | FKR | BRF | 3362.928 | 5286.904 |
BKF | KFR | BRF | 3364.260 | 5370.066 |
KBF | KFR | RBF | 3595.469 | 3372.718 |
KBF | FKR | RBF | 3601.643 | 3379.303 |
KBF | FRK | RBF | 3699.549 | 3445.736 |
KBF | RFK | RBF | 3749.442 | 3460.732 |
BKF | FRK | BRF | 3473.244 | 5420.293 |
BKF | RFK | BRF | 3512.626 | 5408.833 |
FBK | KRF | FBR | 11434.301 | 12477.709 |
FKB | RKF | FRB | 12378.257 | 11467.728 |
BFK | KRF | BFR | 11633.416 | 13637.055 |
FKB | KRF | FRB | 11635.588 | 12212.634 |
KFB | RKF | RFB | 12772.965 | 11691.638 |
FBK | RKF | FBR | 12164.019 | 11777.333 |
KFB | KRF | RFB | 12122.000 | 12442.438 |
BFK | RKF | BFR | 12386.609 | 12970.214 |
BKF | KRF | BRF | 13848.587 | 16703.520 |
KBF | RKF | RBF | 14832.411 | 14118.670 |
KBF | KRF | RBF | 14172.575 | 14668.185 |
BKF | RKF | BRF | 14619.742 | 16115.491 |
param w/o F -- R = 1
input | param | output | ipo | pio |
---|---|---|---|---|
KFB | KR | RFB | 116.458 | 116.501 |
KBF | KR | RBF | 119.778 | 116.487 |
KBF | RK | RBF | 116.583 | 116.493 |
KFB | RK | RFB | 116.495 | 116.550 |
FBK | RK | FBR | 117.029 | 117.115 |
BFK | KR | BFR | 117.116 | 117.057 |
BFK | RK | BFR | 117.084 | 117.155 |
FBK | KR | FBR | 117.092 | 117.174 |
FKB | RK | FRB | 358.544 | 339.967 |
FKB | KR | FRB | 358.303 | 340.043 |
BKF | RK | BRF | 365.138 | 350.999 |
BKF | KR | BRF | 365.350 | 351.294 |
param w/o F -- R = 128
input | param | output | ipo | pio |
---|---|---|---|---|
KFB | KR | RFB | 336.120 | 221.956 |
KBF | KR | RBF | 335.887 | 227.228 |
KBF | RK | RBF | 341.490 | 229.775 |
KFB | RK | RFB | 342.979 | 230.091 |
BFK | KR | BFR | 248.398 | 344.448 |
FBK | KR | FBR | 248.870 | 344.915 |
FBK | RK | FBR | 253.334 | 349.901 |
BFK | RK | BFR | 255.264 | 351.808 |
FKB | KR | FRB | 559.920 | 496.469 |
FKB | RK | FRB | 568.355 | 502.368 |
BKF | KR | BRF | 585.876 | 507.658 |
BKF | RK | BRF | 592.483 | 513.791 |
param w/o F -- R = 512
input | param | output | ipo | pio |
---|---|---|---|---|
KFB | KR | RFB | 1418.012 | 811.970 |
KBF | KR | RBF | 1446.548 | 814.111 |
BFK | KR | BFR | 829.528 | 1316.020 |
FBK | KR | FBR | 831.009 | 1311.129 |
KFB | RK | RFB | 1512.291 | 839.328 |
KBF | RK | RBF | 1511.988 | 840.207 |
BFK | RK | BFR | 1001.170 | 1482.165 |
FBK | RK | FBR | 1010.110 | 1475.463 |
FKB | KR | FRB | 1316.934 | 1228.216 |
BKF | KR | BRF | 1359.575 | 1237.370 |
FKB | RK | FRB | 1505.308 | 1259.962 |
BKF | RK | BRF | 1530.362 | 1270.497 |
Mixing (sum) layers -- 2-operand einsum, 4-D input, 3-D param, 3-D output
code
differs from above by
B = 128
# Axis sizes. The letters follow the issue's notation: K = unit,
# C = component, F = folding.
K = 512
C = 2
# NOTE: F is assigned twice and only the last assignment (288) takes effect
# when this runs as written. 32 is the other setting for which a result table
# is reported; select one by removing the other.
F = 32
F = 288
# Maps each one-character axis name used in a shape string to its size.
sizes = {
"B": B,
"K": K,
"F": F,
"C": C,
}
...
for inputs in itertools.permutations("BKFC"):
for params in itertools.permutations("KFC"):
inputs = "".join(inputs)
params = "".join(params)
outputs = inputs.replace("C", "")
result -- F = 288
|
input | param | output | ipo | pio |
---|---|---|---|---|
FKCB | FKC | FKB | 367.531 | 367.391 |
KFCB | KFC | KFB | 367.505 | 367.414 |
FKCB | CFK | FKB | 371.657 | 368.825 |
CFKB | FKC | FKB | 370.326 | 369.524 |
FKCB | KFC | FKB | 371.624 | 369.978 |
KFCB | KCF | KFB | 370.246 | 370.936 |
KFCB | CKF | KFB | 370.284 | 373.721 |
KFCB | FKC | KFB | 371.914 | 370.394 |
CKFB | KFC | KFB | 370.408 | 370.402 |
FKCB | FCK | FKB | 370.740 | 370.582 |
CFKB | FCK | FKB | 373.479 | 371.812 |
KFCB | FCK | KFB | 373.505 | 372.520 |
FKCB | CKF | FKB | 372.763 | 373.647 |
KFCB | CFK | KFB | 373.185 | 373.923 |
FKCB | KCF | FKB | 373.819 | 373.542 |
CKFB | KCF | KFB | 373.818 | 373.547 |
CKFB | CKF | KFB | 373.556 | 373.685 |
CFKB | CFK | FKB | 376.040 | 373.648 |
CFKB | KFC | FKB | 373.743 | 375.004 |
CKFB | FKC | KFB | 375.028 | 378.774 |
CFKB | KCF | FKB | 376.018 | 375.233 |
CFKB | CKF | FKB | 376.886 | 375.875 |
CKFB | CFK | KFB | 377.047 | 376.714 |
CKFB | FCK | KFB | 377.676 | 376.757 |
KCFB | KFC | KFB | 1540.021 | 810.576 |
FCKB | FKC | FKB | 1540.502 | 811.118 |
KCFB | FKC | KFB | 1547.769 | 813.003 |
FCKB | KFC | FKB | 1546.890 | 813.744 |
KCFB | CKF | KFB | 1579.278 | 813.753 |
FCKB | CFK | FKB | 1586.879 | 814.755 |
KCFB | CFK | KFB | 1567.768 | 815.677 |
FCKB | KCF | FKB | 1556.495 | 815.865 |
KCFB | FCK | KFB | 1545.086 | 815.905 |
KCFB | KCF | KFB | 1543.966 | 815.987 |
FCKB | CKF | FKB | 1550.423 | 817.128 |
FCKB | FCK | FKB | 1554.190 | 817.297 |
FCBK | FKC | FBK | 1827.928 | 1088.175 |
FCBK | FCK | FBK | 1831.495 | 1090.527 |
FCBK | CFK | FBK | 1889.879 | 1090.808 |
FCBK | KFC | FBK | 1834.945 | 1091.186 |
FCBK | CKF | FBK | 1837.739 | 1091.674 |
FCBK | KCF | FBK | 1838.717 | 1091.918 |
FKBC | FKC | FKB | 1103.236 | 1094.623 |
CFBK | FKC | FBK | 1844.333 | 1094.778 |
CFBK | FCK | FBK | 1843.633 | 1097.725 |
CFBK | KFC | FBK | 1843.886 | 1098.803 |
KFBC | FKC | KFB | 1099.041 | 1099.132 |
CFBK | CFK | FBK | 1877.116 | 1099.701 |
KFBC | CKF | KFB | 1151.676 | 1099.901 |
KFBC | KFC | KFB | 1115.185 | 1100.159 |
KFBC | KCF | KFB | 1108.606 | 1100.315 |
FKBC | KCF | FKB | 1124.793 | 1100.353 |
FKBC | CKF | FKB | 1114.925 | 1100.431 |
CFBK | KCF | FBK | 1855.247 | 1100.536 |
FKBC | KFC | FKB | 1100.970 | 1112.996 |
CFBK | CKF | FBK | 1841.919 | 1101.455 |
FKBC | FCK | FKB | 1102.943 | 1113.840 |
KFBC | FCK | KFB | 1104.292 | 1122.477 |
KFBC | CFK | KFB | 1123.885 | 1115.799 |
KBFC | CKF | KBF | 1846.746 | 1118.036 |
KBFC | CFK | KBF | 1813.970 | 1119.177 |
FKBC | CFK | FKB | 1135.750 | 1119.635 |
KBFC | KFC | KBF | 1809.347 | 1119.766 |
KCBF | KFC | KBF | 1830.465 | 1122.838 |
KBFC | KCF | KBF | 1814.562 | 1123.999 |
KBFC | FKC | KBF | 1813.773 | 1124.754 |
KCBF | KCF | KBF | 1836.964 | 1125.530 |
KCBF | FKC | KBF | 1837.573 | 1126.583 |
KBFC | FCK | KBF | 1815.755 | 1127.580 |
KCBF | CKF | KBF | 1870.705 | 1127.629 |
CKBF | CFK | KBF | 1847.998 | 1128.662 |
KCBF | FCK | KBF | 1836.367 | 1128.818 |
KCBF | CFK | KBF | 1840.175 | 1128.863 |
CKBF | KFC | KBF | 1844.957 | 1130.018 |
CKBF | CKF | KBF | 1874.794 | 1130.837 |
CKBF | KCF | KBF | 1843.871 | 1132.791 |
CKBF | FCK | KBF | 1843.299 | 1133.625 |
CKBF | FKC | KBF | 1848.352 | 1135.878 |
FBKC | FCK | FBK | 1810.536 | 1137.477 |
FBKC | FKC | FBK | 1805.791 | 1137.899 |
FBKC | KFC | FBK | 1856.439 | 1139.148 |
FBKC | KCF | FBK | 1812.757 | 1141.217 |
FBKC | CFK | FBK | 1846.279 | 1142.776 |
FBKC | CKF | FBK | 1817.862 | 1144.043 |
KBCF | KFC | KBF | 1835.764 | 1156.568 |
KBCF | KCF | KBF | 1841.768 | 1157.567 |
KBCF | FKC | KBF | 1854.994 | 1160.787 |
KBCF | CFK | KBF | 1837.696 | 1162.510 |
KBCF | CKF | KBF | 1866.511 | 1162.532 |
KBCF | FCK | KBF | 1836.835 | 1162.555 |
FBCK | FKC | FBK | 1830.111 | 1178.649 |
FBCK | FCK | FBK | 1836.005 | 1180.750 |
FBCK | KFC | FBK | 1853.291 | 1181.822 |
FBCK | KCF | FBK | 1837.737 | 1181.836 |
FBCK | CFK | FBK | 1868.068 | 1182.880 |
FBCK | CKF | FBK | 1840.923 | 1183.692 |
BKFC | KFC | BKF | 2691.465 | 2688.253 |
BFKC | FKC | BFK | 2694.193 | 2695.982 |
BKFC | KCF | BKF | 2695.214 | 2697.435 |
BFKC | CFK | BFK | 2741.626 | 2696.012 |
BFKC | KFC | BFK | 2704.749 | 2700.642 |
BKFC | FKC | BKF | 2701.903 | 2704.405 |
BFKC | CKF | BFK | 2702.773 | 2704.576 |
BFKC | FCK | BFK | 2703.092 | 2703.769 |
BFKC | KCF | BFK | 2704.245 | 2705.473 |
BKFC | FCK | BKF | 2704.247 | 2707.089 |
BKFC | CKF | BKF | 2730.353 | 2708.528 |
BKFC | CFK | BKF | 2719.952 | 2709.968 |
CBKF | KFC | BKF | 4263.530 | 3587.400 |
CBFK | FKC | BFK | 4238.395 | 3591.081 |
CBKF | FKC | BKF | 4269.617 | 3592.677 |
CBKF | KCF | BKF | 4258.191 | 3593.320 |
CBFK | KFC | BFK | 4248.524 | 3594.745 |
CBKF | CFK | BKF | 4258.850 | 3594.814 |
CBKF | CKF | BKF | 4283.232 | 3595.522 |
CBFK | CKF | BFK | 4249.298 | 3596.531 |
CBFK | CFK | BFK | 4304.656 | 3596.606 |
CBFK | FCK | BFK | 4257.009 | 3597.092 |
CBKF | FCK | BKF | 4251.767 | 3597.939 |
CBFK | KCF | BFK | 4248.681 | 3599.644 |
BKCF | FKC | BKF | 3736.378 | 3949.921 |
BKCF | FCK | BKF | 3738.488 | 3961.672 |
BKCF | CFK | BKF | 3740.531 | 3956.702 |
BFCK | KFC | BFK | 3747.073 | 3967.060 |
BFCK | KCF | BFK | 3749.707 | 3968.432 |
BFCK | FCK | BFK | 3752.576 | 3970.594 |
BFCK | FKC | BFK | 3753.589 | 3955.989 |
BKCF | KCF | BKF | 3753.634 | 3955.264 |
BKCF | KFC | BKF | 3757.144 | 3971.041 |
BFCK | CKF | BFK | 3762.168 | 3967.541 |
BKCF | CKF | BKF | 3772.368 | 3957.740 |
BFCK | CFK | BFK | 3793.712 | 3971.665 |
BCFK | FCK | BFK | 4335.131 | 3994.408 |
BCKF | KFC | BKF | 4330.439 | 3994.670 |
BCKF | FKC | BKF | 4336.629 | 3995.968 |
BCFK | KFC | BFK | 4329.605 | 3998.567 |
BCFK | KCF | BFK | 4329.171 | 3999.583 |
BCFK | CFK | BFK | 4368.991 | 4003.686 |
BCKF | FCK | BKF | 4352.477 | 4005.444 |
BCFK | FKC | BFK | 4334.350 | 4009.904 |
BCKF | CFK | BKF | 4333.976 | 4010.229 |
BCKF | CKF | BKF | 4371.505 | 4011.900 |
BCKF | KCF | BKF | 4331.937 | 4013.636 |
BCFK | CKF | BFK | 4335.028 | 4013.674 |
result -- F = 32
input | param | output | ipo | pio |
---|---|---|---|---|
CFKB | FKC | FKB | 79.598 | 79.153 |
CKFB | KFC | KFB | 79.566 | 81.817 |
FKCB | FKC | FKB | 79.955 | 97.287 |
CKFB | CKF | KFB | 80.210 | 92.526 |
KFCB | KFC | KFB | 80.294 | 80.422 |
CFKB | CFK | FKB | 80.497 | 99.437 |
FKCB | CFK | FKB | 80.926 | 91.666 |
KFCB | CKF | KFB | 82.341 | 90.974 |
FKCB | FCK | FKB | 92.289 | 92.344 |
CKFB | FCK | KFB | 92.920 | 92.574 |
KFCB | FCK | KFB | 93.374 | 92.578 |
CKFB | FKC | KFB | 92.658 | 94.965 |
KFCB | CFK | KFB | 93.099 | 92.690 |
CKFB | KCF | KFB | 92.756 | 110.558 |
CKFB | CFK | KFB | 92.970 | 92.800 |
KFCB | FKC | KFB | 93.085 | 92.824 |
KFCB | KCF | KFB | 92.919 | 92.848 |
CFKB | KCF | FKB | 94.191 | 92.948 |
CFKB | KFC | FKB | 93.094 | 93.425 |
CFKB | CKF | FKB | 94.684 | 93.358 |
CFKB | FCK | FKB | 93.367 | 93.559 |
FKCB | KFC | FKB | 93.462 | 93.947 |
FKCB | CKF | FKB | 94.052 | 94.461 |
FKCB | KCF | FKB | 94.733 | 94.473 |
KCFB | KFC | KFB | 211.477 | 130.044 |
FCKB | FKC | FKB | 210.620 | 130.102 |
KCFB | CKF | KFB | 215.384 | 132.373 |
FCKB | CFK | FKB | 215.085 | 133.753 |
KCFB | KCF | KFB | 213.093 | 137.435 |
KCFB | CFK | KFB | 214.065 | 137.480 |
FCKB | KFC | FKB | 213.780 | 137.704 |
KCFB | FCK | KFB | 213.443 | 137.870 |
KCFB | FKC | KFB | 213.155 | 137.957 |
FCKB | FCK | FKB | 214.328 | 138.155 |
FCKB | KCF | FKB | 214.054 | 138.580 |
FCKB | CKF | FKB | 213.860 | 138.975 |
KBFC | KFC | KBF | 238.672 | 157.704 |
KBCF | KFC | KBF | 240.246 | 158.713 |
KCBF | KFC | KBF | 255.124 | 158.941 |
CKBF | KFC | KBF | 241.658 | 160.352 |
FKBC | FKC | FKB | 161.263 | 160.596 |
KFBC | KFC | KFB | 160.904 | 161.044 |
KBFC | CKF | KBF | 243.652 | 161.269 |
KBCF | CKF | KBF | 246.309 | 162.309 |
FCBK | FKC | FBK | 243.916 | 162.420 |
FKBC | CFK | FKB | 162.771 | 172.927 |
CKBF | CKF | KBF | 247.095 | 162.829 |
KCBF | CKF | KBF | 246.178 | 163.045 |
KFBC | CKF | KFB | 163.281 | 173.811 |
CFBK | FKC | FBK | 246.422 | 164.112 |
FBKC | FKC | FBK | 242.724 | 165.993 |
CFBK | CFK | FBK | 250.655 | 166.025 |
FCBK | CFK | FBK | 249.122 | 166.101 |
KCBF | FKC | KBF | 244.671 | 166.421 |
KBCF | CFK | KBF | 243.866 | 166.584 |
KBFC | FKC | KBF | 242.311 | 166.680 |
KCBF | CFK | KBF | 244.461 | 166.827 |
KCBF | KCF | KBF | 244.174 | 166.907 |
KCBF | FCK | KBF | 244.131 | 166.992 |
KBFC | KCF | KBF | 241.301 | 167.124 |
KBCF | FCK | KBF | 243.406 | 167.519 |
CKBF | KCF | KBF | 244.784 | 167.618 |
KBCF | KCF | KBF | 243.740 | 167.685 |
KBFC | FCK | KBF | 241.152 | 167.970 |
KBFC | CFK | KBF | 241.688 | 168.077 |
CKBF | FCK | KBF | 244.305 | 168.107 |
FBCK | FKC | FBK | 244.686 | 168.707 |
KBCF | FKC | KBF | 244.034 | 169.067 |
CKBF | CFK | KBF | 245.370 | 169.176 |
FBKC | CFK | FBK | 247.511 | 169.392 |
BKCF | KFC | BKF | 248.457 | 169.842 |
BFCK | FKC | BFK | 248.430 | 169.880 |
CKBF | FKC | KBF | 244.315 | 169.925 |
FCBK | KCF | FBK | 247.931 | 170.809 |
FCBK | CKF | FBK | 256.509 | 170.834 |
CFBK | KCF | FBK | 249.278 | 171.005 |
CFBK | KFC | FBK | 249.086 | 171.034 |
FBCK | CFK | FBK | 249.342 | 171.165 |
CFBK | FCK | FBK | 249.204 | 171.501 |
CFBK | CKF | FBK | 248.849 | 172.072 |
FCBK | KFC | FBK | 248.645 | 172.564 |
FBKC | KCF | FBK | 268.819 | 172.959 |
FKBC | FCK | FKB | 173.119 | 174.640 |
FBKC | KFC | FBK | 245.996 | 173.209 |
FBKC | FCK | FBK | 245.642 | 173.410 |
FBKC | CKF | FBK | 246.213 | 173.414 |
KFBC | FKC | KFB | 173.680 | 175.566 |
FKBC | KCF | FKB | 173.894 | 174.260 |
BKCF | CKF | BKF | 251.848 | 173.899 |
KFBC | KCF | KFB | 173.975 | 176.086 |
FKBC | KFC | FKB | 173.998 | 174.837 |
KFBC | FCK | KFB | 174.170 | 174.683 |
KFBC | CFK | KFB | 174.195 | 175.851 |
FKBC | CKF | FKB | 174.294 | 175.056 |
CBFK | FKC | BFK | 258.601 | 174.520 |
BFCK | CFK | BFK | 253.164 | 174.660 |
BCFK | FKC | BFK | 259.629 | 174.748 |
CBKF | KFC | BKF | 258.871 | 175.114 |
BCKF | KFC | BKF | 259.257 | 175.316 |
BCFK | CFK | BFK | 263.828 | 176.710 |
FBCK | CKF | FBK | 247.714 | 176.801 |
FBCK | KCF | FBK | 247.170 | 177.100 |
BCKF | CKF | BKF | 265.588 | 177.627 |
FBCK | FCK | FBK | 247.046 | 178.085 |
FBCK | KFC | FBK | 247.356 | 178.092 |
CBFK | CFK | BFK | 264.375 | 178.169 |
BKCF | CFK | BKF | 250.636 | 178.263 |
BKCF | FCK | BKF | 250.638 | 178.286 |
CBKF | CKF | BKF | 263.181 | 178.454 |
BFCK | KCF | BFK | 252.320 | 178.705 |
BKCF | KCF | BKF | 250.549 | 178.750 |
BFCK | KFC | BFK | 251.316 | 179.074 |
BFCK | CKF | BFK | 251.315 | 179.160 |
BKCF | FKC | BKF | 250.134 | 180.515 |
CBFK | KFC | BFK | 266.976 | 182.966 |
FCBK | FCK | FBK | 247.765 | 183.024 |
BCFK | KFC | BFK | 268.268 | 183.291 |
BCKF | KCF | BKF | 267.738 | 183.309 |
CBFK | KCF | BFK | 267.570 | 183.435 |
BCKF | FKC | BKF | 269.241 | 183.463 |
BCKF | CFK | BKF | 268.594 | 183.560 |
CBKF | CFK | BKF | 267.082 | 183.668 |
CBKF | FKC | BKF | 267.374 | 183.678 |
CBKF | FCK | BKF | 266.321 | 183.830 |
BCFK | CKF | BFK | 269.935 | 183.875 |
CBFK | CKF | BFK | 275.909 | 184.310 |
CBKF | KCF | BKF | 267.492 | 184.363 |
BCFK | KCF | BFK | 267.446 | 184.509 |
BCFK | FCK | BFK | 268.312 | 184.614 |
CBFK | FCK | BFK | 266.665 | 184.909 |
BCKF | FCK | BKF | 267.597 | 185.701 |
BFCK | FCK | BFK | 250.912 | 189.132 |
BFKC | FKC | BFK | 207.639 | 207.256 |
BKFC | KFC | BKF | 208.248 | 275.361 |
BFKC | CFK | BFK | 209.477 | 217.516 |
BKFC | CKF | BKF | 209.764 | 219.431 |
BFKC | FCK | BFK | 218.109 | 219.913 |
BKFC | FKC | BKF | 218.851 | 219.514 |
BFKC | CKF | BFK | 219.123 | 218.918 |
BFKC | KCF | BFK | 218.981 | 219.248 |
BKFC | CFK | BKF | 220.066 | 219.263 |
BFKC | KFC | BFK | 220.841 | 219.490 |
BKFC | KCF | BKF | 220.788 | 219.631 |
BKFC | FCK | BKF | 220.577 | 220.304 |
Tucker (EinNet) layers -- 3-operand einsum, 2x 3-D input, 4-D param, 3-D output
code
differs from CP by
K = 64
# Maps each one-character axis name used in a shape string to its size.
# "I" and "J" reuse K's size (the issue's notation lists I and J alongside
# K as "unit" axes).
sizes = {
"B": B,
"K": K,
"F": F,
"I": K,
"J": K,
}
...
def bench(inputs: str, inputsj: str, params: str, outputs: str, order: str) -> float:
    """Time a single three-operand einsum on freshly generated random data.

    Args:
        inputs: Shape string of the first input tensor.
        inputsj: Shape string of the second input tensor.
        params: Shape string of the parameter tensor.
        outputs: Shape string of the result.
        order: Operand order: ``"ijpo"`` (input, inputj, param), ``"ipjo"``
            (input, param, inputj); any other value selects param, input,
            inputj.

    Returns:
        The second value returned by ``timer`` for one call of the timed
        closure.
    """
    input_data = torch.rand([sizes[axis] for axis in inputs], device=device)
    inputj_data = torch.rand([sizes[axis] for axis in inputsj], device=device)
    param_data = torch.rand([sizes[axis] for axis in params], device=device)

    if order == "ijpo":
        # jipo should be the same
        equation = f"{inputs},{inputsj},{params}->{outputs}"
        first, second, third = input_data, inputj_data, param_data
    elif order == "ipjo":
        # jpio should be the same
        equation = f"{inputs},{params},{inputsj}->{outputs}"
        first, second, third = input_data, param_data, inputj_data
    else:
        # pjio should be the same
        equation = f"{params},{inputs},{inputsj}->{outputs}"
        first, second, third = param_data, input_data, inputj_data

    def proc() -> Tensor:
        # .contiguous() penalizes a non-contiguous einsum result.
        return torch.einsum(equation, first, second, third).contiguous()

    _, elapsed = timer(proc)
    return elapsed
...
for inputs in itertools.permutations("BIF"):
for params in itertools.permutations("IJKF"):
inputs = "".join(inputs)
inputsj = inputs.replace("I", "J")
params = "".join(params)
outputs = inputs.replace("I", "K")
ans.append(["i", inputs, inputsj, "p", params, "o", outputs])
for order in ("ijpo", "ipjo", "pijo"):
...
ans.sort(key=lambda x: min(x[-1], x[-3], x[-5]))
result
|
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This issue is the tracker for benchmarks on pytorch ops with different axes order.
All times are measured in $\mu s$.
Notations:
- B: batch
- K (also I, J): unit
- F: folding
- R: rank
- C: component

Conclusions:
torch.bmm (possibly with a Tensor.expand to manually broadcast) and torch.matmul are the same as torch.einsum in their underlying calculations (all reduce to bmm -> baddbmm -> bgemm in BLAS).
The text was updated successfully, but these errors were encountered: