Commit
fix mypy and precommit hooks
picciama committed Nov 9, 2023
1 parent fd01c24 commit a0f8a15
Showing 4 changed files with 43 additions and 28 deletions.
32 changes: 20 additions & 12 deletions oktoberfest/predict/koina.py
@@ -1,17 +1,23 @@
 import time
 from functools import partial
-from typing import Dict, List, Optional, Union
+from typing import Dict, Generator, KeysView, List, Optional, Union

 import numpy as np
 import pandas as pd
 from tqdm.auto import tqdm
-from tritonclient.grpc import InferenceServerClient, InferenceServerException, InferInput, InferRequestedOutput
+from tritonclient.grpc import (
+    InferenceServerClient,
+    InferenceServerException,
+    InferInput,
+    InferRequestedOutput,
+    InferResult,
+)


 class Koina:
     """A class for interacting with Koina models for inference."""

-    model_inputs: Dict[str, np.ndarray]
+    model_inputs: Dict[str, str]
     model_outputs: Dict[str, np.ndarray]
     batch_size: int
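The `model_inputs` annotation changes because mypy checks a class-level annotation against every later assignment to that attribute. A minimal sketch of the kind of mismatch this hunk fixes (the class and the dictionary values here are hypothetical, not the actual Koina internals):

```python
from typing import Dict


class Example:
    # Before the fix this was annotated Dict[str, np.ndarray]; the attribute
    # actually holds a name -> dtype-string mapping, so mypy flagged every
    # assignment until the annotation became Dict[str, str].
    model_inputs: Dict[str, str]

    def __init__(self) -> None:
        self.model_inputs = {"peptide_sequences": "BYTES"}  # hypothetical values
```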

@@ -149,7 +155,7 @@ def __get_batchsize(self):
         self.batchsize = self.client.get_model_config(self.model_name).config.max_batch_size

     @staticmethod
-    def __get_batch_outputs(names: List[str]) -> List[InferRequestedOutput]:
+    def __get_batch_outputs(names: KeysView[str]) -> List[InferRequestedOutput]:
         """
         Create InferRequestedOutput objects for the given output names.
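The `KeysView` fix is the standard remedy for passing `dict.keys()` straight into a helper: `dict.keys()` returns a `KeysView[str]`, not a `List[str]`, so the old annotation made mypy reject the call site. In miniature (standalone example, not the project's code):

```python
from typing import KeysView


def first_sorted(names: KeysView[str]) -> str:
    return sorted(names)[0]


outputs = {"intensities": 0, "mz": 1}
print(first_sorted(outputs.keys()))  # OK: dict.keys() is a KeysView[str]
```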
@@ -181,7 +187,7 @@ def __get_batch_inputs(self, data: Dict[str, np.ndarray]) -> List[InferInput]:
             batch_inputs[-1].set_data_from_numpy(data[iname].reshape(-1, 1).astype(self.type_convert[idtype]))
         return batch_inputs

-    def __extract_predictions(self, infer_result: InferRequestedOutput) -> Dict[str, np.ndarray]:
+    def __extract_predictions(self, infer_result: InferResult) -> Dict[str, np.ndarray]:
         """
         Extract the predictions from an inference result.
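`InferRequestedOutput` only names an output the server should return; the object predictions can be read from is `InferResult`, which is what `__extract_predictions` actually receives. A sketch of the distinction (the helper is hypothetical; `as_numpy` is the real `InferResult` accessor):

```python
import numpy as np
from tritonclient.grpc import InferResult


def extract_one(infer_result: InferResult, name: str) -> np.ndarray:
    # An InferRequestedOutput merely requests an output by name when the call
    # is made; the InferResult that comes back exposes the data via as_numpy().
    return infer_result.as_numpy(name)
```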
@@ -216,7 +222,9 @@ def __predict_batch(self, data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:

         return self.__extract_predictions(infer_result)

-    def __predict_sequential(self, data: Dict[str, np.ndarray], disable_progress_bar: bool) -> Dict[str, np.ndarray]:
+    def __predict_sequential(
+        self, data: Dict[str, np.ndarray], disable_progress_bar: bool = False
+    ) -> Dict[str, np.ndarray]:
         """
         Perform sequential inference and return the predictions.
@@ -242,7 +250,7 @@ def __predict_sequential(self, data: Dict[str, np.ndarray], disable_progress_bar
         return predictions

     @staticmethod
-    def __slice_dict(data: Dict[str, np.ndarray], batchsize: int) -> Dict[str, np.ndarray]:
+    def __slice_dict(data: Dict[str, np.ndarray], batchsize: int) -> Generator[Dict[str, np.ndarray], None, None]:
         """
         Slice the input data into batches of a specified batch size.
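The new return type reflects that any function whose body contains `yield` returns a generator, so the old `Dict[...]` annotation could never match. The method body is hidden by the diff; a minimal slicer consistent with this signature might look like:

```python
from typing import Dict, Generator

import numpy as np


def slice_dict(data: Dict[str, np.ndarray], batchsize: int) -> Generator[Dict[str, np.ndarray], None, None]:
    """Yield consecutive batches across all arrays at once (sketch, not the project code)."""
    n_rows = len(next(iter(data.values())))
    for start in range(0, n_rows, batchsize):
        yield {k: v[start : start + batchsize] for k, v in data.items()}
```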
@@ -324,7 +332,7 @@ def __merge_list_dict_array(dict_list: List[Dict[str, np.ndarray]]) -> Dict[str,
             out[k] = np.concatenate([x[k] for x in dict_list])
         return out

-    def __async_callback(self, infer_results: List[InferRequestedOutput], result: np.ndarray, error):
+    def __async_callback(self, infer_results: List[InferResult], result: InferResult, error):
         """
         Callback function for asynchronous inference.
@@ -344,7 +352,7 @@ def __async_callback(self, infer_results: List[InferRequestedOutput], result: np
         infer_results.append(result)

     def __async_predict_batch(
-        self, data: Dict[str, np.ndarray], infer_results: List[InferRequestedOutput], request_id: int, timeout: int = 10
+        self, data: Dict[str, np.ndarray], infer_results: List[InferResult], request_id: int, timeout: int = 10
     ):
         """
         Perform asynchronous batch inference on the given data using the Koina model.
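For context, `tritonclient.grpc` delivers results for `InferenceServerClient.async_infer` through exactly such a callback, which receives the finished `InferResult` (or an error); that is why both the callback and the shared results list are now typed with `InferResult`. A hedged sketch of the overall shape (`async_infer` and its keyword arguments are the real API; the wrapper itself is an assumption, since the method body is not shown in the diff):

```python
from functools import partial
from typing import List

from tritonclient.grpc import InferenceServerClient, InferInput, InferRequestedOutput, InferResult


def async_predict_batch_sketch(
    client: InferenceServerClient,
    model_name: str,
    inputs: List[InferInput],
    outputs: List[InferRequestedOutput],
    infer_results: List[InferResult],
    request_id: int,
    timeout: int = 10,
) -> None:
    """async_infer returns immediately; the server later invokes the callback."""

    def callback(results: List[InferResult], result: InferResult, error) -> None:
        # Mirrors the diff: each successful InferResult is appended to a shared list.
        if error is None:
            results.append(result)

    client.async_infer(
        model_name=model_name,
        inputs=inputs,
        outputs=outputs,
        callback=partial(callback, infer_results),
        request_id=str(request_id),
        client_timeout=timeout,
    )
```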
@@ -426,7 +434,7 @@ def __predict_async(self, data: Dict[str, np.ndarray], disable_progress_bar: boo
         :return: A dictionary containing the model's predictions. Keys are output names, and values are numpy arrays
             representing the model's output.
         """
-        infer_results: List[Dict[str, np.ndarray]] = []
+        infer_results: List[InferResult] = []
         for i, data_batch in enumerate(self.__slice_dict(data, self.batchsize)):
             self.__async_predict_batch(data_batch, infer_results, request_id=i)

@@ -439,9 +447,9 @@ def __predict_async(self, data: Dict[str, np.ndarray], disable_progress_bar: boo
             pbar.refresh()

         # sort according to request id
-        infer_results = [
+        infer_results_to_return = [
             self.__extract_predictions(infer_results[i])
             for i in np.argsort(np.array([int(y.get_response("id")["id"]) for y in infer_results]))
         ]

-        return self.__merge_list_dict_array(infer_results)
+        return self.__merge_list_dict_array(infer_results_to_return)
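The rename to `infer_results_to_return` sidesteps a standard mypy complaint: a variable gets one inferred type per scope, so rebinding a `List[InferResult]` name to a list of extracted prediction dictionaries is an incompatible assignment. The pattern in miniature:

```python
from typing import List

raw: List[int] = [3, 1, 2]
# raw = [str(x) for x in raw]  # mypy: incompatible types in assignment
as_text: List[str] = [str(x) for x in raw]  # a fresh name keeps both types consistent
```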
8 changes: 4 additions & 4 deletions tests/unit_tests/data/predictions/library_input.csv
@@ -1,4 +1,4 @@
-,"MODIFIED_SEQUENCE","COLLISION_ENERGY","PRECURSOR_CHARGE","FRAGMENTATION"
-0,"[UNIMOD:737]-PEPTIDEK[UNIMOD:737]",30,2,"HCD"
-1,"[UNIMOD:737]-PEPTIDE",30,2,"HCD"
-2,"[UNIMOD:737]-M[UNIMOD:35]EC[UNIMOD:4]TIDEK[UNIMOD:737]",35,1,"CID"
+"MODIFIED_SEQUENCE","COLLISION_ENERGY","PRECURSOR_CHARGE","FRAGMENTATION"
+"[UNIMOD:737]-PEPTIDEK[UNIMOD:737]",30,2,"HCD"
+"[UNIMOD:737]-PEPTIDE",30,2,"HCD"
+"[UNIMOD:737]-M[UNIMOD:35]EC[UNIMOD:4]TIDEK[UNIMOD:737]",35,1,"CID"
