Skip to content

Commit

Permalink
Hard code 3-element sort in numba functions for more speed.
Browse files Browse the repository at this point in the history
Signed-off by: David Rowenhorst <[email protected]>
  • Loading branch information
drowenhorst-nrl committed Apr 29, 2024
1 parent 8fc8f6c commit 99af42c
Showing 1 changed file with 49 additions and 6 deletions.
55 changes: 49 additions & 6 deletions pyebsdindex/tripletvote.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,25 @@
import numpy as np
import numba

#keep this around for profiling numba functions
#import ctypes
#import time

# # Access the _PyTime_AsSecondsDouble and _PyTime_GetSystemClock functions from pythonapi
# get_system_clock = ctypes.pythonapi._PyTime_GetSystemClock
# as_seconds_double = ctypes.pythonapi._PyTime_AsSecondsDouble
#
# # Set the argument types and return types of the functions
# get_system_clock.argtypes = []
# get_system_clock.restype = ctypes.c_int64
#
# as_seconds_double.argtypes = [ctypes.c_int64]
# as_seconds_double.restype = ctypes.c_double
# @numba.jit(nopython=True, cache=True,fastmath=True,parallel=False)
# def ntime()-> np.float64:
# return np.float64(as_seconds_double(get_system_clock()))
### END of numba timer ####

from pyebsdindex import crystal_sym, rotlib, crystallometry


Expand Down Expand Up @@ -953,6 +972,8 @@ def _orientation_quest_nb(polescart, bandnorms, weights):
@staticmethod
@numba.jit(nopython=True, cache=True,fastmath=True,parallel=False)
def _tripvote_numba(bandnorms, band_intensity, LUT, angTol, tripAngles, tripID, nfam):
timing1 = 0.0
timing2 = 0.0
npats = bandnorms.shape[0]
n_bands = bandnorms.shape[1]
LUTTemp = np.asarray(LUT).copy()
Expand Down Expand Up @@ -984,16 +1005,38 @@ def _tripvote_numba(bandnorms, band_intensity, LUT, angTol, tripAngles, tripID,
for i in range(n_bands):
for j in range(i + 1,n_bands):
for k in range(j + 1,n_bands):
# tic = ntime()
angtri = np.array([bandangs[i,j],bandangs[i,k],bandangs[j,k]], dtype=np.float32)
srt = angtri.argsort(kind='quicksort') #np.array(np.argsort(angtri), dtype=numba.int64)
#srt = np.array(np.argsort(angtri), dtype=numba.int64)
# Equivalent to the argsort above, but hard-coding the sort is MUCH faster for just three numbers.
srt = np.array([0,1,2], dtype=np.uint64)
if angtri[srt[0]] > angtri[srt[2]]:
srt[2], srt[0] = srt[0], srt[2]
if angtri[srt[0]] > angtri[srt[1]]:
srt[1], srt[0] = srt[0], srt[1]
if angtri[srt[1]] > angtri[srt[2]]:
srt[2], srt[1] = srt[1], srt[2]
##### end hard-coded argsort ######

srt2 = np.asarray(LUTTemp[:,srt[0],srt[1],srt[2]], dtype=np.int64).copy()
unsrtFID = np.argsort(srt2,kind='quicksort').astype(np.int64)
angtriSRT = np.asarray(angtri[srt])
#unsrtFID = np.argsort(srt2,kind='quicksort').astype(np.int64)
# Again, hard-coding the argsort above for speed.
unsrtFID = np.array([0,1,2], dtype=np.uint64)
if srt2[unsrtFID[0]] > srt2[unsrtFID[2]]:
unsrtFID[2], unsrtFID[0] = unsrtFID[0], unsrtFID[2]
if srt2[unsrtFID[0]] > srt2[unsrtFID[1]]:
unsrtFID[1], unsrtFID[0] = unsrtFID[0], unsrtFID[1]
if srt2[unsrtFID[1]] > srt2[unsrtFID[2]]:
unsrtFID[2], unsrtFID[1] = unsrtFID[1], unsrtFID[2]
##### end hard-coded argsort ######
angtriSRT = np.asarray(angtri[srt], dtype = np.float32)

#angTest0 = (np.abs(tripAngles - angtriSRT)).astype(np.float32)
#print(angTest0.shape)
#angTest = (angTest0 <= angTol)#.astype(np.int)

# toc = ntime()
# timing1 += toc - tic
# toc = ntime()
for q in range(ntrip):
#print('____')
#print(tripAngles[q,:], angtriSRT)
Expand Down Expand Up @@ -1076,7 +1119,7 @@ def _tripvote_numba(bandnorms, band_intensity, LUT, angTol, tripAngles, tripID,
accumulator[p,f[1], k] += 1
accumulator[p,f[2], i] += 1


# timing2 += ntime() - toc

for q in range(n_bands):
mxvote[p,q] = np.amax(accumulatorW[p,:,q])
Expand All @@ -1090,7 +1133,7 @@ def _tripvote_numba(bandnorms, band_intensity, LUT, angTol, tripAngles, tripID,
#for q in range(n_bands):
bandFam[p,q] = np.argmax(accumulatorW[p,:,q])
bandRank[p,:] = (n_bands - np.arange(n_bands)) / n_bands * band_cm[p,:] * mxvote[p,:]

# print(timing1, timing2)
return accumulatorW, bandFam, bandRank, band_cm, accumulator

@staticmethod
Expand Down

0 comments on commit 99af42c

Please sign in to comment.