Skip to content

Commit

Permalink
Update pairvote function ... it's fast but less accurate. There for c…
Browse files Browse the repository at this point in the history
…ompleteness.

Signed-off by: David Rowenhorst <[email protected]>
  • Loading branch information
drowenhorst-nrl committed Apr 30, 2024
1 parent d86568e commit 345d93b
Showing 1 changed file with 54 additions and 41 deletions.
95 changes: 54 additions & 41 deletions pyebsdindex/tripletvote.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,8 +514,8 @@ def bandindex(self, band_norms, band_intensity = None, band_widths=None, verbose

accumulator, bandFam, bandRank, band_cm, accumulator_nw \
= self._tripvote_numba(bandnorms, band_intensity, self.lut, self.angTol, tripangs, tripid, nfam)
#accumulator, bandFam, bandRank, band_cm = self._pairvote_numba(bandangs, self.angTol, pairangs, pairfam,
# nfam, n_bands)
#accumulator, bandFam, bandRank, band_cm, accumulator_nw \
# = self._pairvote_numba(bandnorms, band_intensity, self.angTol, pairangs, pairfam, nfam)

bandRank_arg = np.argsort(bandRank, axis=1).astype(np.int64)

Expand Down Expand Up @@ -1141,53 +1141,66 @@ def _tripvote_numba(bandnorms, band_intensity, LUT, angTol, tripAngles, tripID,

@staticmethod
@numba.jit(nopython=True, cache=True, fastmath=True, parallel=False)
def _pairvote_numba(bandangs, angTol, pairAngs, pairID, nfam, n_bands):
def _pairvote_numba(bandnorms,band_intensity, angTol, pairAngs, pairID, nfam):

npats = bandnorms.shape[0]
n_bands = bandnorms.shape[1]

accumulator = np.zeros((npats, nfam, n_bands), dtype=np.float32)
accumulatorW = np.zeros((npats, nfam, n_bands), dtype=np.float32)
mxvote = np.zeros((npats, n_bands), dtype=np.int32)
tvotes = np.zeros((npats, n_bands), dtype=np.int32)
band_cm = np.zeros((npats, n_bands), dtype=np.float32)
bandRank = np.zeros((npats, n_bands), dtype=np.float32)
bandFam = np.zeros((npats, n_bands), dtype=np.int32)

accumulator = np.zeros((nfam, n_bands), dtype=np.float32)
pairshape = np.shape(pairAngs)
npair = int(pairshape[0])
count = 0.0
#count = 0.0
# angTest2 = np.zeros(ntrip, dtype=numba.boolean)
# angTest2 = np.empty(ntrip,dtype=numba.boolean)
for i in range(n_bands):
for j in range(i + 1, n_bands):
bandangpair = bandangs[i, j]
angTest = (np.abs(pairAngs - bandangpair)).astype(np.float32)
# print(angTest0.shape)


for q in range(npair):
if angTest[q] <= angTol:
w1 = (angTol - angTest[q])

# print(w1, w2, w3)
accumulator[pairID[q,0], i] += w1
accumulator[pairID[q,1], i] += w1
accumulator[pairID[q,0], j] += w1
accumulator[pairID[q,1], j] += w1


mxvote = np.zeros(n_bands, dtype=np.int32)
tvotes = np.zeros(n_bands, dtype=np.int32)
band_cm = np.zeros(n_bands, dtype=np.float32)
for q in range(n_bands):
mxvote[q] = np.amax(accumulator[:, q])
tvotes[q] = np.sum(accumulator[:, q])
for p in range(npats):
bandangs = np.abs(bandnorms[p, ...].dot(bandnorms[p, ...].T))
bandangs = np.clip(bandangs, -1.0, 1.0)
bandangs = np.arccos(bandangs) * RADEG
for i in range(n_bands):
if band_intensity[p,i] < 1e-6: # invalid band
bandangs[i,:] = 10000.0
bandangs[:, i] = 10000.0
for i in range(n_bands):
for j in range(i + 1, n_bands):
bandangpair = bandangs[i, j]
angTest = (np.abs(pairAngs - bandangpair)).astype(np.float32)
# print(angTest0.shape)
for q in range(npair):
if angTest[q] <= angTol:
w1 = (angTol - angTest[q])

# print(w1, w2, w3)
accumulator[p, pairID[q, 0], i] += 1
accumulator[p, pairID[q, 1], i] += 1
accumulator[p, pairID[q, 0], j] += 1
accumulator[p, pairID[q, 1], j] += 1

accumulatorW[p, pairID[q,0], i] += w1
accumulatorW[p, pairID[q,1], i] += w1
accumulatorW[p, pairID[q,0], j] += w1
accumulatorW[p, pairID[q,1], j] += w1

for i in range(n_bands):
if tvotes[i] < 1:
band_cm[i] = 0.0
else:
srt = np.argsort(accumulator[:, i])
band_cm[i] = (accumulator[srt[-1], i] - accumulator[srt[-2], i]) / tvotes[i]
for q in range(n_bands):
mxvote[p, q] = np.amax(accumulatorW[p, :, q])
tvotes[p, q] = np.sum(accumulatorW[p, :, q])
# for i in range(n_bands):
if tvotes[p, q] < 1:
band_cm[p, q] = 0.0
else:
srt = np.argsort(accumulatorW[p, :, q])
band_cm[p, q] = (accumulatorW[p, srt[-1], q] - accumulatorW[p, srt[-2], q]) / (tvotes[p, q])

bandRank = np.zeros(n_bands, dtype=np.float32)
bandFam = np.zeros(n_bands, dtype=np.int32)
for q in range(n_bands):
bandFam[q] = np.argmax(accumulator[:, q])
bandRank = (n_bands - np.arange(n_bands)) / n_bands * band_cm * mxvote
bandFam[p, q] = np.argmax(accumulatorW[p, :, q])
bandRank[p, :] = (n_bands - np.arange(n_bands)) / n_bands * band_cm[p, :] * mxvote[p, :]

return accumulator, bandFam, bandRank, band_cm
return accumulatorW, bandFam, bandRank, band_cm, accumulator

@staticmethod
@numba.jit(nopython=True, cache=True, fastmath=True,parallel=False)
Expand Down

0 comments on commit 345d93b

Please sign in to comment.