Skip to content

Commit

Permalink
Revert "quantize_block C->C++, use std::thread everywhere (#1024)"
Browse files Browse the repository at this point in the history
This reverts commit 332530b.
  • Loading branch information
akx committed Mar 7, 2024
1 parent 87e029b commit 16541ff
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 27 deletions.
24 changes: 14 additions & 10 deletions csrc/common.cpp
Original file line number Diff line number Diff line change
@@ -1,35 +1,39 @@
#include <common.h>
#include <float.h>

void quantize_block(const quantize_block_args& args) {
void *quantize_block(void *arguments) {
// 1. find absmax in block
// 2. divide input value by absmax to normalize into [-1.0, 1.0]
// 3. do binary search to find the closest value
// 4. check minimal distance
// 5. store index

struct quantize_block_args *args = (quantize_block_args *) arguments;

// 1. find absmax in block
float absmax_block = -FLT_MAX;
for (long long i = args.block_idx; i < args.block_end; i++)
absmax_block = fmax(absmax_block, fabs(args.A[i]));
for (long long i = args->block_idx; i < args->block_end; i++)
absmax_block = fmax(absmax_block, fabs(args->A[i]));

args.absmax[args.block_idx / args.blocksize] = absmax_block;
args->absmax[args->block_idx / args->blocksize] = absmax_block;

for (long long i = args.block_idx; i < args.block_end; i++) {
for (long long i = args->block_idx; i < args->block_end; i++) {
// 2. divide input value by absmax to normalize into [-1.0, 1.0]
// 3. do binary search to find the closest value
float normed_value = args.A[i] / absmax_block;
long long idx = args.bin_searcher->scalar(normed_value);
float normed_value = args->A[i] / absmax_block;
long long idx = args->bin_searcher->scalar(normed_value);

// 4. check minimal distance
// The binary search returns always the value to the left, which might not be the closest value
if (idx < 255) {
float dist_left = fabs(normed_value - (args.code[idx]));
float dist_right = fabs(normed_value - (args.code[idx + 1]));
float dist_left = fabs(normed_value - (args->code[idx]));
float dist_right = fabs(normed_value - (args->code[idx + 1]));
if (dist_right < dist_left) { idx += 1; }
}

// 5. store index
args.out[i] = (unsigned char) idx;
args->out[i] = (unsigned char) idx;
}

return NULL;
}
2 changes: 1 addition & 1 deletion csrc/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ struct quantize_block_args {
};


void quantize_block(const quantize_block_args& args);
void *quantize_block(void *arguments);

#endif
59 changes: 43 additions & 16 deletions csrc/cpu_ops.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
#include <BinSearch.h>
#include <common.h>
#ifdef _WIN32
#include <thread>
#else
#include <pthread.h>
#endif
#include <common.h>

using namespace BinSearch;

Expand All @@ -26,38 +30,61 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long
BinAlgo<Scalar, float, Direct2> bin_searcher(code, elements_code);

int thread_wave_size = 256;
// we chunk the threads into waves of 256 since the max limit is
// we chunk the thresds into waves of 256 since the max limit is
// between 16k and 64k on Linux (we reach this when running BLOOM-176B with a large batch size)
for(long long offset = 0; offset < num_blocks; offset+=thread_wave_size)
{
long long valid_chunks = num_blocks - offset >= thread_wave_size ? thread_wave_size : num_blocks - offset;
std::vector<std::thread> threads(valid_chunks);
std::vector<quantize_block_args> args(valid_chunks);
#ifdef _WIN32
std::thread *threads = (std::thread *) malloc(sizeof(std::thread) * valid_chunks);
#else
pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * valid_chunks);
#endif

struct quantize_block_args **args = (quantize_block_args **) malloc(valid_chunks * sizeof(quantize_block_args *));

for(long long i = 0; i < valid_chunks; i++)
args[i] = (quantize_block_args *) malloc(sizeof(quantize_block_args));

int chunks_processed = 0;
for(long long block_idx = offset*blocksize; block_idx < n; block_idx += blocksize)
{
long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx;
long long block_end = block_idx + valid_items;

struct quantize_block_args& arg = args[chunks_processed];
arg.bin_searcher = &bin_searcher;
arg.code = code;
arg.A = A;
arg.absmax = absmax;
arg.out = out;
arg.block_end = block_end;
arg.block_idx = block_idx;
arg.threadidx = block_idx / blocksize;
arg.blocksize = blocksize;

threads[chunks_processed] = std::thread([arg] { quantize_block(arg); });
struct quantize_block_args *arg = args[chunks_processed];
arg->bin_searcher = &bin_searcher;
arg->code = code;
arg->A = A;
arg->absmax = absmax;
arg->out = out;
arg->block_end = block_end;
arg->block_idx = block_idx;
arg->threadidx = block_idx / blocksize;
arg->blocksize = blocksize;

#ifdef _WIN32
new (&threads[chunks_processed]) std::thread(quantize_block, arg);
#else
pthread_create(&threads[chunks_processed], NULL, &quantize_block, (void *) arg);
#endif
chunks_processed += 1;
if(chunks_processed == valid_chunks){ break; }
}

for (int i = 0; i < valid_chunks; i++)
{
#ifdef _WIN32
threads[i].join();
#else
int err = pthread_join(threads[i], NULL);
#endif
}
free(threads);
for (int i = 0; i < valid_chunks; i++)
free(args[i]);
free(args);

}

}

0 comments on commit 16541ff

Please sign in to comment.