-
Notifications
You must be signed in to change notification settings - Fork 0
/
py_module.c
77 lines (63 loc) · 2.29 KB
/
py_module.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#include <Python.h>
#include <numpy/arrayobject.h>
//#include "/usr/local/lib/python3.7/site-packages/numpy/core/include/numpy/arrayobject.h"
#include "roc.h"
static char module_docstring[] = "This module provides an interface for calculating mean ROC-AUC";
static char mroc_docstring[] = "Calculate mean of ROC-AUC's aggregated by label.";
static PyObject *mroc_mean_roc_auc(PyObject *self, PyObject *args);
static PyMethodDef module_methods[] = {
{"mean_roc_auc", mroc_mean_roc_auc, METH_VARARGS, mroc_docstring},
{NULL, NULL, 0, NULL}
};
/*
 * Module initialisation hook, invoked by CPython on `import _mroc`.
 *
 * The NumPy C-API is loaded *before* the module object is created:
 * import_array() is a macro that executes `return NULL;` from this function
 * when NumPy cannot be imported. Calling it after PyModule_Create() (as
 * before) leaked the freshly created module on that failure path.
 *
 * Returns a new reference to the module, or NULL with an exception set.
 */
PyMODINIT_FUNC PyInit__mroc(void) {
    static struct PyModuleDef moduledef = {
        PyModuleDef_HEAD_INIT,
        "_mroc",            /* m_name */
        module_docstring,   /* m_doc */
        -1,                 /* m_size: -1 => no per-module state; module does
                               not support sub-interpreters */
        module_methods,     /* m_methods */
        NULL,               /* m_slots */
        NULL,               /* m_traverse */
        NULL,               /* m_clear */
        NULL                /* m_free */
    };

    // Load `numpy` functionality; returns NULL from this function on failure.
    import_array();

    return PyModule_Create(&moduledef);
}
/*
 * Python-callable wrapper: mean_roc_auc(labels, actuals, preds) -> float.
 *
 * Each argument is converted with PyArray_FROM_OTF into an aligned,
 * C-contiguous array: `labels` and `actuals` as C int, `preds` as C double
 * (the element types passed on to mean_roc_auc() from roc.h). All three must
 * be 1-D and have the same length; otherwise mean_roc_auc() would read past
 * the end of the shorter buffers.
 *
 * Returns a new reference to a Python float, or NULL with an exception set
 * (argument-parsing / NumPy conversion errors, or ValueError on a
 * dimensionality or length mismatch).
 */
static PyObject *mroc_mean_roc_auc(PyObject *self, PyObject *args) {
    PyObject *labels_obj, *actuals_obj, *preds_obj;
    PyArrayObject *labels_array = NULL;
    PyArrayObject *actuals_array = NULL;
    PyArrayObject *preds_array = NULL;
    PyObject *ret = NULL;
    npy_intp n = 0;

    (void) self;  /* unused: plain module-level function */

    // Parse the input tuple
    if (!PyArg_ParseTuple(args, "OOO", &labels_obj, &actuals_obj, &preds_obj)) {
        return NULL;
    }

    /*
     * Interpret the input objects as aligned, C-contiguous numpy arrays.
     * NPY_ARRAY_IN_ARRAY is the current name of the deprecated NPY_IN_ARRAY.
     * Stop at the first failure so no further C-API calls are made while an
     * exception is pending.
     */
    labels_array = (PyArrayObject *) PyArray_FROM_OTF(labels_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
    if (labels_array == NULL) {
        goto cleanup;
    }
    actuals_array = (PyArrayObject *) PyArray_FROM_OTF(actuals_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
    if (actuals_array == NULL) {
        goto cleanup;
    }
    preds_array = (PyArrayObject *) PyArray_FROM_OTF(preds_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
    if (preds_array == NULL) {
        goto cleanup;
    }

    /* PyArray_DIM(arr, 0) is only meaningful for arrays with ndim >= 1. */
    if (PyArray_NDIM(labels_array) != 1 ||
        PyArray_NDIM(actuals_array) != 1 ||
        PyArray_NDIM(preds_array) != 1) {
        PyErr_SetString(PyExc_ValueError,
                        "labels, actuals and preds must be 1-dimensional arrays");
        goto cleanup;
    }

    /* mean_roc_auc() reads n elements from every buffer. */
    n = PyArray_DIM(labels_array, 0);
    if (PyArray_DIM(actuals_array, 0) != n || PyArray_DIM(preds_array, 0) != n) {
        PyErr_SetString(PyExc_ValueError,
                        "labels, actuals and preds must have the same length");
        goto cleanup;
    }

    ret = PyFloat_FromDouble(mean_roc_auc((int *) PyArray_DATA(labels_array),
                                          (int *) PyArray_DATA(actuals_array),
                                          (double *) PyArray_DATA(preds_array),
                                          (size_t) n));

cleanup:
    // Clean up (safe on partially-initialised state).
    Py_XDECREF(labels_array);
    Py_XDECREF(actuals_array);
    Py_XDECREF(preds_array);
    return ret;
}
// Article: https://dfm.io/posts/python-c-extensions/