Skip to content

Commit

Permalink
change add probability to nnom_predic()
Browse files Browse the repository at this point in the history
change mnist-simple example
  • Loading branch information
majianjia committed Mar 30, 2019
1 parent 1f4c136 commit d4ec7b4
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 9 deletions.
8 changes: 5 additions & 3 deletions docs/api_evaluation.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,19 +52,21 @@ Total Memory cost (Network and NNoM): 32876
---
## nnom_predic_one()
## nnom_predic()
~~~C
int32_t nnom_predic_one(nnom_model_t *m);
int32_t nnom_predic(nnom_model_t *m, uint32_t *label, float *prob);
~~~

To predict one set of input data. This is the standalone API which does not require `printf()` to print results but only return the predicted label.
A standalone evaluation method, run single prodiction, return probability and top-1 label.

This method is basicly `model_run()` + `index(top-1)`

**Arguments**

- **m:** the model to run prediction (evaluation).
- **label:** the variable to store top-1 label.
- **prob:** the variable to store probability. Range from 0~1.

**Return**

Expand Down
3 changes: 3 additions & 0 deletions docs/example_mnist_simple_cn.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ prediction start..
Time: 62 tick
Truth label: 8
Predicted label: 8
Probability: 100%
~~~

额,如果恶心到你了,那我道歉...
Expand All @@ -134,6 +136,7 @@ Predicted label: 8
- 此次预测的时间,这里用了 `62 tick`,我这是相当于 62ms
- 这张图片的真实数字是 `8`
- 网络计算的这张照片的数字 `8`
- 可能性是100%

赶快去试试,其他的 9 张图片吧。

Expand Down
5 changes: 3 additions & 2 deletions examples/mnist-simple/mcu/main.c
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ void print_img(int8_t * buf)
void mnist(int argc, char** argv)
{
uint32_t tick, time;
int32_t result;
uint32_t predic_label;
uint32_t prob;
int32_t index = atoi(argv[1]);

if(index >= TOTAL_IMAGE || argc != 2)
Expand All @@ -67,7 +68,7 @@ void mnist(int argc, char** argv)

// copy data and do prediction
memcpy(nnom_input_data, (int8_t*)&img[index][0], 784);
result = nnom_predic_one(model);
nnom_predic_one(model, predic_label, prob);
time = rt_tick_get() - tick;

//print original image to console
Expand Down
2 changes: 1 addition & 1 deletion inc/nnom_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ void prediction_summary(nnom_predic_t *pre);
// this api test one set of data, return the prediction
// return the predicted label
// return NN_ARGUMENT_ERROR if parameter error
int32_t nnom_predic_one(nnom_model_t *m);
nnom_status_t nnom_predic(nnom_model_t *m, uint32_t *label, float *prob);

void model_stat(nnom_model_t *m);

Expand Down
12 changes: 9 additions & 3 deletions src/nnom_utils.c
Original file line number Diff line number Diff line change
Expand Up @@ -223,9 +223,9 @@ void prediction_summary(nnom_predic_t *pre)

// stand alone prediction API
// this api test one set of data, return the prediction
int32_t nnom_predic_one(nnom_model_t *m)
nnom_status_t nnom_predic(nnom_model_t *m, uint32_t *label, float *prob)
{
int32_t max_val, max_index;
int32_t max_val, max_index, sum;
int8_t *output;

if (!m)
Expand All @@ -239,15 +239,21 @@ int32_t nnom_predic_one(nnom_model_t *m)
// Top 1
max_val = output[0];
max_index = 0;
sum = max_val;
for (uint32_t i = 1; i < shape_size(&m->tail->out->shape); i++)
{
if (output[i] > max_val)
{
max_val = output[i];
max_index = i;
}
sum += output[i];
}
return max_index;
// send results
*label = max_index;
*prob = (float)max_val/(float)sum;

return NN_SUCCESS;
}

static void layer_stat(nnom_layer_t *layer)
Expand Down

0 comments on commit d4ec7b4

Please sign in to comment.