Skip to content

Commit

Permalink
New quantization, some fixes in evaluation scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
canfirtina committed Mar 19, 2024
1 parent 4aa5a72 commit 5ebddcf
Show file tree
Hide file tree
Showing 27 changed files with 696 additions and 543 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ RawHash performs real-time mapping of nanopore raw signals. When the prefix of r

# Recent changes

* We came up with a better and more accurate quantization mechanism in RawHash2. The new quantization mechanism dynamically arranges the bucket sizes that each signal value is quantized depending on the normalized distribution of the signal values. **This provides significant improvements in both accuracy and performance.**

* We have integrated the signal alignment functionality with DTW as proposed in RawAlign (see the citation below). The parameters may still not be highly optimized as this is still in experimental stage. Use it with caution.

* Offline overlapping functionality is integrated.
Expand Down
2 changes: 1 addition & 1 deletion src/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -252,4 +252,4 @@ hit.o: rmap.h kalloc.h khash.h
rmap.o: rindex.h rsig.h kthread.h rh_kvec.h rutils.h rsketch.h revent.h sequence_until.h dtw.h
revent.o: roptions.h kalloc.h
rindex.o: roptions.h rutils.h rsketch.h rsig.h bseq.h khash.h rh_kvec.h kthread.h
main:o rawhash.h ketopt.h rutils.h
main:o rawhash.h ketopt.h rutils.h
186 changes: 99 additions & 87 deletions src/main.cpp

Large diffs are not rendered by default.

259 changes: 179 additions & 80 deletions src/revent.c
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <assert.h>
#include <float.h>
#include <math.h>
#include "rutils.h"

//Some of the functions here are adopted from the Sigmap implementation (https://github.com/haowenz/sigmap/tree/c9a40483264c9514587a36555b5af48d3f054f6f). We have optimized the Sigmap implementation to work with the hash tables efficiently.

Expand All @@ -19,8 +20,11 @@ typedef struct ri_detect_s {
int valid_peak;
}ri_detect_t;

static inline void comp_prefix_prefixsq(const float *sig, uint32_t s_len, float* prefix_sum, float* prefix_sum_square) {

static inline void comp_prefix_prefixsq(const float *sig,
const uint32_t s_len,
float* prefix_sum,
float* prefix_sum_square)
{
assert(s_len > 0);

prefix_sum[0] = 0.0f;
Expand All @@ -31,22 +35,17 @@ static inline void comp_prefix_prefixsq(const float *sig, uint32_t s_len, float*
}
}

static inline float* comp_tstat(void *km, const float *prefix_sum, const float *prefix_sum_square, uint32_t s_len, uint32_t w_len) {

static inline float* comp_tstat(void *km,
const float *prefix_sum,
const float *prefix_sum_square,
const uint32_t s_len,
const uint32_t w_len)
{
const float eta = FLT_MIN;

// rh_kvec_t(float) tstat = {0,0,0};
// rh_kv_resize(float, 0, tstat, s_len+1);
// rh_kv_pushp(float, 0, tstat, &s);

float* tstat = (float*)ri_kcalloc(km, s_len+1, sizeof(float));
// Quick return:
// t-test not defined for number of points less than 2
// need at least as many points as twice the window length
if (s_len < 2*w_len || w_len < 2) return tstat;
// fudge boundaries
memset(tstat, 0, w_len*sizeof(float));
// get to work on the rest

for (uint32_t i = w_len; i <= s_len - w_len; ++i) {
float sum1 = prefix_sum[i];
float sumsq1 = prefix_sum_square[i];
Expand All @@ -58,46 +57,59 @@ static inline float* comp_tstat(void *km, const float *prefix_sum, const float *
float sumsq2 = prefix_sum_square[i + w_len] - prefix_sum_square[i];
float mean1 = sum1 / w_len;
float mean2 = sum2 / w_len;
float combined_var = sumsq1 / w_len - mean1 * mean1 + sumsq2 / w_len - mean2 * mean2;
float combined_var = (sumsq1/w_len - mean1*mean1 + sumsq2/w_len - mean2*mean2)/w_len;
// Prevent problem due to very small variances
combined_var = fmaxf(combined_var, eta);
// t-stat
// Formula is a simplified version of Student's t-statistic for the
// special case where there are two samples of equal size with
// differing variance
const float delta_mean = mean2 - mean1;
tstat[i] = fabs(delta_mean) / sqrt(combined_var / w_len);
tstat[i] = fabs(delta_mean) / sqrt(combined_var);
}
// fudge boundaries
memset(tstat+s_len-w_len+1, 0, (w_len)*sizeof(float));

return tstat;
}

static inline uint32_t gen_peaks(ri_detect_t *short_detector, ri_detect_t *long_detector, const float peak_height, uint32_t* peaks) {

assert(short_detector->s_len == long_detector->s_len);
// static inline float calculate_adaptive_peak_height(const float *prefix_sum, const float *prefix_sum_square, uint32_t current_index, uint32_t window_length, float base_peak_height) {
// // Ensure we don't go beyond signal bounds
// uint32_t start_index = current_index > window_length ? current_index - window_length : 0;
// uint32_t end_index = current_index + window_length;

uint32_t curInd = 0;
// float sum = prefix_sum[end_index] - prefix_sum[start_index];
// float sumsq = prefix_sum_square[end_index] - prefix_sum_square[start_index];
// float mean = sum / (end_index - start_index);
// float variance = (sumsq / (end_index - start_index)) - (mean * mean);
// float stddev = sqrtf(variance);

// // Example adaptive strategy: Increase peak height in high-variance regions
// return base_peak_height * (1 + stddev);
// }

static inline uint32_t gen_peaks(ri_detect_t **detectors,
const uint32_t n_detectors,
const float peak_height,
const float *prefix_sum,
const float *prefix_sum_square,
uint32_t* peaks) {

uint32_t ndetector = 2;
ri_detect_t *detectors[ndetector]; // = {short_detector, long_detector};
detectors[0] = short_detector;
detectors[1] = long_detector;
for (uint32_t i = 0; i < short_detector->s_len; i++) {
for (uint32_t k = 0; k < ndetector; k++) {
uint32_t curInd = 0;
for (uint32_t i = 0; i < detectors[0]->s_len; i++) {
for (uint32_t k = 0; k < n_detectors; k++) {
ri_detect_t *detector = detectors[k];
// Carry on if we've been masked out
if (detector->masked_to >= i) continue;

float current_value = detector->sig[i];
// float adaptive_peak_height = calculate_adaptive_peak_height(prefix_sum, prefix_sum_square, i, detector->window_length, peak_height);

if (detector->peak_pos == detector->DEF_PEAK_POS) {
// CASE 1: We've not yet recorded a maximum
if (current_value < detector->peak_value) {
// Either record a deeper minimum...
// CASE 1: We've not yet recorded any maximum
if (current_value < detector->peak_value) { // A deeper minimum:
detector->peak_value = current_value;
} else if (current_value - detector->peak_value > peak_height) { // TODO(Haowen): this might cause overflow, need to fix this
// ...or we've seen a qualifying maximum
} else if (current_value - detector->peak_value > peak_height) {
// ...or a qualifying maximum:
detector->peak_value = current_value;
detector->peak_pos = i;
// otherwise, wait to rise high enough to be considered a peak
Expand All @@ -109,22 +121,22 @@ static inline uint32_t gen_peaks(ri_detect_t *short_detector, ri_detect_t *long_
detector->peak_value = current_value;
detector->peak_pos = i;
}
// Dominate other tstat signals if we're going to fire at some point
if (detector == short_detector) {
if (detector->peak_value > detector->threshold) {
long_detector->masked_to = detector->peak_pos + detector->window_length;
long_detector->peak_pos = long_detector->DEF_PEAK_POS;
long_detector->peak_value = long_detector->DEF_PEAK_VAL;
long_detector->valid_peak = 0;
// Tell other detectors no need to check for a peak until a certain point
if (detector->peak_value > detector->threshold) {
for(int n_d = k+1; n_d < n_detectors; n_d++){
detectors[n_d]->masked_to = detector->peak_pos + detectors[0]->window_length;
detectors[n_d]->peak_pos = detectors[n_d]->DEF_PEAK_POS;
detectors[n_d]->peak_value = detectors[n_d]->DEF_PEAK_VAL;
detectors[n_d]->valid_peak = 0;
}
}
// Have we convinced ourselves we've seen a peak
if (detector->peak_value - current_value > peak_height && detector->peak_value > detector->threshold) {
// There is a good peak
if (detector->peak_value - current_value > peak_height &&
detector->peak_value > detector->threshold) {
detector->valid_peak = 1;
}
// Finally, check the distance if this is a good peak
// Check if we are now further away from the current peak
if (detector->valid_peak && (i - detector->peak_pos) > detector->window_length / 2) {
// Emit the boundary and reset
peaks[curInd++] = detector->peak_pos;
detector->peak_pos = detector->DEF_PEAK_POS;
detector->peak_value = current_value;
Expand All @@ -137,78 +149,165 @@ static inline uint32_t gen_peaks(ri_detect_t *short_detector, ri_detect_t *long_
return curInd;
}

int compare_floats(const void* a, const void* b) {
const float* da = (const float*) a;
const float* db = (const float*) b;
return (*da > *db) - (*da < *db);
}

float calculate_mean_of_filtered_segment(float* segment,
const uint32_t segment_length)
{
// Calculate median and IQR
qsort(segment, segment_length, sizeof(float), compare_floats); // Assuming compare_floats is already defined
float q1 = segment[segment_length / 4];
float q3 = segment[3 * segment_length / 4];
float iqr = q3 - q1;
float lower_bound = q1 - iqr;
float upper_bound = q3 + iqr;

float sum = 0.0;
uint32_t count = 0;
for (uint32_t i = 0; i < segment_length; i++) {
if (segment[i] >= lower_bound && segment[i] <= upper_bound) {
sum += segment[i];
++count;
}
}

// Return the mean of the filtered segment
return count > 0 ? sum / count : 0; // Ensure we don't divide by zero
}

/**
* @brief Generates events from peaks, prefix sums and s_len.
*
* @param km Pointer to memory manager.
* @param peaks Array of peak positions.
* @param peak_size Size of peaks array.
* @param prefix_sum Array of prefix sums.
* @param prefix_sum_square Array of prefix sums squared.
* @param s_len Length of the signal.
* @param n_events Pointer to the number of events generated.
* @return float* Pointer to the array of generated events.
*/
static inline float* gen_events(void *km,
float* sig,
const uint32_t *peaks,
const uint32_t peak_size,
const float *prefix_sum,
const float *prefix_sum_square,
const uint32_t s_len,
uint32_t* n_events)
{
uint32_t n_ev = 1;
double mean = 0, std_dev = 0, sum = 0, sum2 = 0;
uint32_t n_ev = 0;

for (uint32_t i = 1; i < peak_size; ++i)
if (peaks[i] > 0 && peaks[i] < s_len)
n_ev++;
for (uint32_t pi = 0; pi < peak_size; ++pi)
if (peaks[pi] > 0 && peaks[pi] < s_len) n_ev++;

float* events = (float*)ri_kmalloc(km, n_ev*sizeof(float));
float l_prefixsum = 0, l_peak = 0;

for (uint32_t pi = 0; pi < n_ev - 1; pi++){
events[pi] = (prefix_sum[peaks[pi]] - l_prefixsum)/(peaks[pi]-l_peak);
sum += events[pi];
sum2 += events[pi]*events[pi];
l_prefixsum = prefix_sum[peaks[pi]];
l_peak = peaks[pi];
}

events[n_ev-1] = (prefix_sum[s_len] - l_prefixsum)/(s_len-l_peak);
sum += events[n_ev-1];
sum2 += events[n_ev-1]*events[n_ev-1];
uint32_t start_idx = 0, segment_length = 0;

//normalization
mean = sum/n_ev;
std_dev = sqrt(sum2/n_ev - (mean)*(mean));
for (uint32_t pi = 0, i = 0; pi < peak_size && i < n_ev; pi++){
if (!(peaks[pi] > 0 && peaks[pi] < s_len)) continue;

for(uint32_t i = 0; i < n_ev; ++i){
events[i] = (events[i]-mean)/std_dev;
segment_length = peaks[pi] - start_idx;
events[i++] = calculate_mean_of_filtered_segment(sig + start_idx, segment_length);
start_idx = peaks[pi];
}

(*n_events) = n_ev;
return events;
}

float* detect_events(void *km, uint32_t s_len, const float* sig, uint32_t window_length1, uint32_t window_length2, float threshold1, float threshold2, float peak_height, uint32_t *n) // kt_for() callback
static inline float* normalize_signal(void *km,
const float* sig,
const uint32_t s_len,
double* mean_sum,
double* std_dev_sum,
uint32_t* n_events_sum,
uint32_t* n_sig)
{
double sum = (*mean_sum), sum2 = (*std_dev_sum);
double mean = 0, std_dev = 0;
float* events = (float*)ri_kcalloc(km, s_len, sizeof(float));

for (uint32_t i = 0; i < s_len; ++i) {
sum += sig[i];
sum2 += sig[i]*sig[i];
}

(*n_events_sum) += s_len;
(*mean_sum) = sum;
(*std_dev_sum) = sum2;

mean = sum/(*n_events_sum);
std_dev = sqrt(sum2/(*n_events_sum) - (mean)*(mean));

float norm_val = 0;
int k = 0;
for(uint32_t i = 0; i < s_len; ++i){
norm_val = (sig[i]-mean)/std_dev;
if(norm_val < 3 && norm_val > -3) events[k++] = norm_val;
}

(*n_sig) = k;

return events;
}

float* detect_events(void *km,
const uint32_t s_len,
const float* sig,
const uint32_t window_length1,
const uint32_t window_length2,
const float threshold1,
const float threshold2,
const float peak_height,
double* mean_sum,
double* std_dev_sum,
uint32_t* n_events_sum,
uint32_t *n_events)
{
float* prefix_sum = (float*)ri_kcalloc(km, s_len+1, sizeof(float));
float* prefix_sum_square = (float*)ri_kcalloc(km, s_len+1, sizeof(float));

comp_prefix_prefixsq(sig, s_len, prefix_sum, prefix_sum_square);
float* tstat1 = comp_tstat(km, prefix_sum, prefix_sum_square, s_len, window_length1);
float* tstat2 = comp_tstat(km, prefix_sum, prefix_sum_square, s_len, window_length2);
ri_detect_t short_detector = {.DEF_PEAK_POS = -1, .DEF_PEAK_VAL = FLT_MAX, .sig = tstat1, .s_len = s_len, .threshold = threshold1,
.window_length = window_length1, .masked_to = 0, .peak_pos = -1, .peak_value = FLT_MAX, .valid_peak = 0};
//Normalize the signal
uint32_t n_signals = 0;
float* norm_signals = normalize_signal(km, sig, s_len, mean_sum, std_dev_sum, n_events_sum, &n_signals);
if(n_signals == 0) return 0;
comp_prefix_prefixsq(norm_signals, n_signals, prefix_sum, prefix_sum_square);

float* tstat1 = comp_tstat(km, prefix_sum, prefix_sum_square, n_signals, window_length1);
float* tstat2 = comp_tstat(km, prefix_sum, prefix_sum_square, n_signals, window_length2);
ri_detect_t short_detector = {.DEF_PEAK_POS = -1,
.DEF_PEAK_VAL = FLT_MAX,
.sig = tstat1,
.s_len = n_signals,
.threshold = threshold1,
.window_length = window_length1,
.masked_to = 0,
.peak_pos = -1,
.peak_value = FLT_MAX,
.valid_peak = 0};

ri_detect_t long_detector = {.DEF_PEAK_POS = -1,
.DEF_PEAK_VAL = FLT_MAX,
.sig = tstat2,
.s_len = n_signals,
.threshold = threshold2,
.window_length = window_length2,
.masked_to = 0,
.peak_pos = -1,
.peak_value = FLT_MAX,
.valid_peak = 0};

uint32_t* peaks = (uint32_t*)ri_kmalloc(km, n_signals * sizeof(uint32_t));
ri_detect_t *detectors[2] = {&short_detector, &long_detector};
uint32_t n_peaks = gen_peaks(detectors, 2, peak_height, prefix_sum, prefix_sum_square, peaks);
ri_kfree(km, tstat1); ri_kfree(km, tstat2); ri_kfree(km, prefix_sum); ri_kfree(km, prefix_sum_square);

ri_detect_t long_detector = {.DEF_PEAK_POS = -1, .DEF_PEAK_VAL = FLT_MAX, .sig = tstat2, .s_len = s_len, .threshold = threshold2,
.window_length = window_length2, .masked_to = 0, .peak_pos = -1, .peak_value = FLT_MAX, .valid_peak = 0};
uint32_t* peaks = (uint32_t*)ri_kmalloc(km, s_len * sizeof(uint32_t));
uint32_t n_peaks = gen_peaks(&short_detector, &long_detector, peak_height, peaks);
float* events = 0;
if(n_peaks > 0) events = gen_events(km, peaks, n_peaks, prefix_sum, prefix_sum_square, s_len, n);
ri_kfree(km, tstat1); ri_kfree(km, tstat2); ri_kfree(km, prefix_sum); ri_kfree(km, prefix_sum_square); ri_kfree(km, peaks);
if(n_peaks > 0) events = gen_events(km, norm_signals, peaks, n_peaks, n_signals, n_events);
ri_kfree(km, norm_signals); ri_kfree(km, peaks);

return events;
}
Loading

0 comments on commit 5ebddcf

Please sign in to comment.