Skip to content

Commit

Permalink
added numba fast matmul
Browse files Browse the repository at this point in the history
  • Loading branch information
alexlib committed Jun 19, 2024
1 parent d0e9458 commit 14549eb
Showing 1 changed file with 34 additions and 1 deletion.
35 changes: 34 additions & 1 deletion openptv_python/ray_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Tuple

import numpy as np
from numba import njit
from numba import float64, int64, njit, prange

from .calibration import Calibration
from .parameters import MultimediaPar
Expand Down Expand Up @@ -306,3 +306,36 @@ def fast_ray_tracing(
# out = vec_add(tmp1, tmp2)

# return X, out

# import numpy as np
# import time

@njit(float64[:, :](float64[:, :], float64[:, :], float64[:, :], int64, int64, int64), parallel=True)
def matmul_numba_optimized(a, b, c, m, n, k):
for i in prange(m):
for j in range(k):
temp = 0.0
for ll in range(n):
temp += b[i, ll] * c[ll, j]
a[i, j] = temp
return a

# # Define the same inputs as in the C test
# b = np.array([
# [1.0, 2.0, 3.0],
# [4.0, 5.0, 6.0]
# ], dtype=np.float64)
# c = np.array([
# [1.0, 0.0],
# [0.0, 1.0],
# [1.0, 1.0]
# ], dtype=np.float64)
# a = np.zeros((2, 2), dtype=np.float64)

# start_time = time.time()
# matmul_numba_optimized(a, b, c, 2, 3, 2)
# end_time = time.time()

# print("Optimized Python Numba Time:", end_time - start_time, "seconds")
# print("Result from Python function:")
# print(a)

0 comments on commit 14549eb

Please sign in to comment.