Merge pull request #77 from robertknight/propagate-recognition-errors
Propagate errors running recognition model, instead of panicking
robertknight authored May 23, 2024
2 parents edad43a + fb544ea commit 190baec
Showing 3 changed files with 44 additions and 11 deletions.
25 changes: 25 additions & 0 deletions ocrs/src/errors.rs
@@ -0,0 +1,25 @@
+use std::error::Error;
+use std::fmt;
+
+/// The error type returned when running a machine learning model fails.
+#[derive(Debug)]
+pub enum ModelRunError {
+    /// Model execution failed.
+    RunFailed(Box<dyn Error + Send + Sync>),
+
+    /// The model output had a different data type or shape than expected.
+    WrongOutput,
+}
+
+impl fmt::Display for ModelRunError {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
+        match self {
+            ModelRunError::RunFailed(err) => write!(f, "model run failed: {}", err),
+            ModelRunError::WrongOutput => {
+                write!(f, "model output had unexpected type or shape")
+            }
+        }
+    }
+}
+
+impl Error for ModelRunError {}
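
The new error type is easiest to judge from the caller's side. Below is a minimal sketch of a hypothetical in-crate consumer; it is not part of this diff, and since errors.rs is added as a private module, only code inside the ocrs crate could name ModelRunError like this.

// Hypothetical in-crate helper, for illustration only.
use crate::errors::ModelRunError;

fn describe_failure(err: &ModelRunError) -> String {
    match err {
        // Execution of the underlying model failed; the source error is boxed
        // inside the variant and can be displayed or inspected further.
        ModelRunError::RunFailed(source) => format!("model inference failed: {}", source),
        // The model ran, but its output had an unexpected data type or shape.
        ModelRunError::WrongOutput => "model output had unexpected type or shape".to_string(),
    }
}

Because ModelRunError implements std::error::Error and holds only Send + Sync data, callers that still return anyhow::Result (as run did before this change) can also propagate it directly with the ? operator.
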
1 change: 1 addition & 0 deletions ocrs/src/lib.rs
@@ -5,6 +5,7 @@ use rten_tensor::prelude::*;
 use rten_tensor::NdTensor;
 
 mod detection;
+mod errors;
 mod geom_util;
 mod layout_analysis;
 mod log;
29 changes: 18 additions & 11 deletions ocrs/src/recognition.rs
@@ -10,6 +10,7 @@ use rten_imageproc::{
 use rten_tensor::prelude::*;
 use rten_tensor::{NdTensor, NdTensorView, Tensor};
 
+use crate::errors::ModelRunError;
 use crate::geom_util::{downwards_line, leftmost_edge, rightmost_edge};
 use crate::preprocess::BLACK_VALUE;
 use crate::text_items::{TextChar, TextLine};
@@ -359,12 +360,14 @@ impl TextRecognizer {

     /// Run text recognition on an NCHW batch of text line images, and return
     /// a `[batch, seq, label]` tensor of class probabilities.
-    fn run(&self, input: NdTensor<f32, 4>) -> anyhow::Result<NdTensor<f32, 3>> {
+    fn run(&self, input: NdTensor<f32, 4>) -> Result<NdTensor<f32, 3>, ModelRunError> {
         let input: Tensor<f32> = input.into();
-        let [output] =
-            self.model
-                .run_n(&[(self.input_id, (&input).into())], [self.output_id], None)?;
-        let mut rec_sequence: NdTensor<f32, 3> = output.try_into()?;
+        let [output] = self
+            .model
+            .run_n(&[(self.input_id, (&input).into())], [self.output_id], None)
+            .map_err(|err| ModelRunError::RunFailed(err.into()))?;
+        let mut rec_sequence: NdTensor<f32, 3> =
+            output.try_into().map_err(|_| ModelRunError::WrongOutput)?;
 
         // Transpose from [seq, batch, class] => [batch, seq, class]
         rec_sequence.permute([1, 0, 2]);
@@ -470,9 +473,9 @@ impl TextRecognizer {
             .collect();
 
         // Run text recognition on batches of lines.
-        let mut line_rec_results: Vec<LineRecResult> = line_groups
+        let batch_rec_results: Result<Vec<Vec<LineRecResult>>, ModelRunError> = line_groups
             .into_par_iter()
-            .flat_map(|(group_width, lines)| {
+            .map(|(group_width, lines)| {
                 if debug {
                     println!(
                         "Processing group of {} lines of width {}",
@@ -489,12 +492,11 @@
                     group_width as usize,
                 );
 
-                // TODO - Propagate errors from recognition model to caller.
-                let rec_output = self.run(rec_input).expect("recognition failed");
+                let rec_output = self.run(rec_input)?;
                 let ctc_input_len = rec_output.shape()[1];
 
                 // Apply CTC decoding to get the label sequence for each line.
-                lines
+                let line_rec_results = lines
                     .into_iter()
                     .enumerate()
                     .map(|(group_line_index, line)| {
@@ -513,10 +515,15 @@
                             ctc_output,
                         }
                     })
-                    .collect::<Vec<_>>()
+                    .collect::<Vec<_>>();
+
+                Ok(line_rec_results)
             })
             .collect();
 
+        let mut line_rec_results: Vec<LineRecResult> =
+            batch_rec_results?.into_iter().flatten().collect();
+
         // The recognition outputs are in a different order than the inputs due to
         // batching and parallel processing. Re-sort them into input order.
         line_rec_results.sort_by_key(|result| result.line.index);
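
The core of the recognition.rs change is the switch from flat_map to map on the parallel iterator (rayon's into_par_iter): each line group now returns a Result, the groups are collected into Result<Vec<Vec<LineRecResult>>, ModelRunError>, which short-circuits once an error is produced, the ? surfaces any error to the caller, and flatten restores the flat list of per-line results. Below is a minimal standalone sketch of that collect-then-flatten pattern, using hypothetical integer groups instead of the crate's types and plain std iterators in place of rayon; the same collect into Result also works for parallel iterators, which is what the diff relies on.

// Sketch of the collect-into-Result-then-flatten pattern, with made-up data.
// Each group is processed as a unit; a failing group aborts the collection
// and its error is returned to the caller.
fn double_group(group: &[i32]) -> Result<Vec<i32>, String> {
    group
        .iter()
        .map(|&x| {
            if x < 0 {
                Err(format!("negative value: {}", x))
            } else {
                Ok(x * 2)
            }
        })
        .collect()
}

fn main() -> Result<(), String> {
    let groups = vec![vec![1, 2], vec![3, 4, 5]];

    // Collecting an iterator of Result<Vec<_>, _> into Result<Vec<Vec<_>>, _>
    // stops at the first Err, mirroring the batch_rec_results step above.
    let per_group: Result<Vec<Vec<i32>>, String> =
        groups.iter().map(|g| double_group(g)).collect();

    // The `?` check followed by flatten recovers the flat list of results,
    // mirroring the line_rec_results step in the diff.
    let flat: Vec<i32> = per_group?.into_iter().flatten().collect();

    assert_eq!(flat, vec![2, 4, 6, 8, 10]);
    Ok(())
}
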
