Skip to content

Commit

Permalink
Pythia v1.2
Browse files Browse the repository at this point in the history
  • Loading branch information
tschuelia committed Oct 30, 2024
1 parent 7914717 commit 6d23222
Show file tree
Hide file tree
Showing 17 changed files with 39,905 additions and 18,943 deletions.
88 changes: 36 additions & 52 deletions src/corax/difficulty/prediction.c
Original file line number Diff line number Diff line change
@@ -1,69 +1,53 @@

#include "prediction.h"

;

static const int32_t num_class[] = { 1, };

size_t get_num_class(void) {
return 1;
int32_t get_num_target(void) {
return N_TARGET;
}

size_t get_num_feature(void) {
return 10;
}

const char* get_pred_transform(void) {
return "identity";
}

float get_sigmoid_alpha(void) {
return 1.0;
}

float get_ratio_c(void) {
return 1.0;
void get_num_class(int32_t* out) {
for (int i = 0; i < N_TARGET; ++i) {
out[i] = num_class[i];
}
}

float get_global_bias(void) {
return 0.0;
int32_t get_num_feature(void) {
return 10;
}

const char* get_threshold_type(void) {
return "float64";
}

const char* get_leaf_output_type(void) {
return "float64";
}


static inline double pred_transform(double margin) {
return margin;
}
double predict(union Entry* data, int pred_margin) {
double sum = (double)0;
void predict(union Entry* data, int pred_margin, double* result) {
unsigned int tmp;
int nid, cond, fid; /* used for folded subtrees */
sum += predict_margin_unit0(data);
sum += predict_margin_unit1(data);
sum += predict_margin_unit2(data);
sum += predict_margin_unit3(data);
sum += predict_margin_unit4(data);
sum += predict_margin_unit5(data);
sum += predict_margin_unit6(data);
sum += predict_margin_unit7(data);
sum += predict_margin_unit8(data);
sum += predict_margin_unit9(data);
sum += predict_margin_unit10(data);
sum += predict_margin_unit11(data);
sum += predict_margin_unit12(data);
sum += predict_margin_unit13(data);
sum += predict_margin_unit14(data);

sum = sum + (double)(0);
if (!pred_margin) {
return pred_transform(sum);
} else {
return sum;
}
predict_unit0(data, result);
predict_unit1(data, result);
predict_unit2(data, result);
predict_unit3(data, result);
predict_unit4(data, result);
predict_unit5(data, result);
predict_unit6(data, result);
predict_unit7(data, result);
predict_unit8(data, result);
predict_unit9(data, result);
predict_unit10(data, result);
predict_unit11(data, result);
predict_unit12(data, result);
predict_unit13(data, result);
predict_unit14(data, result);

// Apply base_scores
result[0] += 0;

// Apply postprocessor
if (!pred_margin) { postprocess(result); }
}

void postprocess(double* result) {
// Do nothing
}

56 changes: 23 additions & 33 deletions src/corax/difficulty/prediction.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,45 +14,35 @@
#define UNLIKELY(x) (x)
#endif

#define N_TARGET 1
#define MAX_N_CLASS 1

union Entry {
int missing;
double fvalue;
int qvalue;
};

struct Node {
uint8_t default_left;
unsigned int split_index;
double threshold;
int left_child;
int right_child;
};

extern const unsigned char is_categorical[];


size_t get_num_class(void);
size_t get_num_feature(void);
const char* get_pred_transform(void);
float get_sigmoid_alpha(void);
float get_ratio_c(void);
float get_global_bias(void);
int32_t get_num_target(void);
void get_num_class(int32_t* out);
int32_t get_num_feature(void);
const char* get_threshold_type(void);
const char* get_leaf_output_type(void);
void predict(union Entry* data, int pred_margin, double* result);
void postprocess(double* result);
void predict_unit0(union Entry* data, double* result);
void predict_unit1(union Entry* data, double* result);
void predict_unit2(union Entry* data, double* result);
void predict_unit3(union Entry* data, double* result);
void predict_unit4(union Entry* data, double* result);
void predict_unit5(union Entry* data, double* result);
void predict_unit6(union Entry* data, double* result);
void predict_unit7(union Entry* data, double* result);
void predict_unit8(union Entry* data, double* result);
void predict_unit9(union Entry* data, double* result);
void predict_unit10(union Entry* data, double* result);
void predict_unit11(union Entry* data, double* result);
void predict_unit12(union Entry* data, double* result);
void predict_unit13(union Entry* data, double* result);
void predict_unit14(union Entry* data, double* result);

double predict(union Entry* data, int pred_margin);
double predict_margin_unit0(union Entry* data);
double predict_margin_unit1(union Entry* data);
double predict_margin_unit2(union Entry* data);
double predict_margin_unit3(union Entry* data);
double predict_margin_unit4(union Entry* data);
double predict_margin_unit5(union Entry* data);
double predict_margin_unit6(union Entry* data);
double predict_margin_unit7(union Entry* data);
double predict_margin_unit8(union Entry* data);
double predict_margin_unit9(union Entry* data);
double predict_margin_unit10(union Entry* data);
double predict_margin_unit11(union Entry* data);
double predict_margin_unit12(union Entry* data);
double predict_margin_unit13(union Entry* data);
double predict_margin_unit14(union Entry* data);
Loading

0 comments on commit 6d23222

Please sign in to comment.