diff --git a/docs/API.rst b/docs/API.rst
index 7d4987e3..b510d72f 100644
--- a/docs/API.rst
+++ b/docs/API.rst
@@ -65,8 +65,7 @@ Koina interface
 .. autosummary::
     :toctree: api/pr
 
-    pr.grpc_predict
-    pr.infer_predictions
+    pr.predict
 
 Postprocessing koina response
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -74,7 +73,6 @@ Postprocessing koina response
 .. autosummary::
     :toctree: api/pr
 
-    pr.parse_fragment_labels
 
diff --git a/oktoberfest/data/spectra.py b/oktoberfest/data/spectra.py
index 9ec6b7ae..13b64924 100644
--- a/oktoberfest/data/spectra.py
+++ b/oktoberfest/data/spectra.py
@@ -30,7 +30,6 @@ class Spectra:
     INTENSITY_COLUMN_PREFIX = "INTENSITY_RAW"
     INTENSITY_PRED_PREFIX = "INTENSITY_PRED"
     MZ_COLUMN_PREFIX = "MZ_RAW"
-    EPSILON = 1e-7
     COLUMNS_FRAGMENT_ION = ["Y1+", "Y1++", "Y1+++", "B1+", "B1++", "B1+++"]
 
     spectra_data: pd.DataFrame
@@ -129,7 +128,7 @@ def add_matrix(self, intensity_data: pd.Series, fragment_type: FragmentType) ->
 
         # Change zeros to epsilon to keep the info of invalid values
         # change the -1 values to 0 (for better performance when converted to sparse representation)
-        intensity_array[intensity_array == 0] = Spectra.EPSILON
+        intensity_array[intensity_array == 0] = c.EPSILON
         intensity_array[intensity_array == -1] = 0.0
 
         # generate column names and build dataframe from sparse matrix
diff --git a/oktoberfest/predict/__init__.py b/oktoberfest/predict/__init__.py
index 8102e23c..12fff39e 100644
--- a/oktoberfest/predict/__init__.py
+++ b/oktoberfest/predict/__init__.py
@@ -1,2 +1,3 @@
 """Init predict."""
+from .koina import Koina
 from .predict import *
diff --git a/oktoberfest/predict/koina.py b/oktoberfest/predict/koina.py
new file mode 100644
index 00000000..9dc49744
--- /dev/null
+++ b/oktoberfest/predict/koina.py
@@ -0,0 +1,455 @@
+import time
+from functools import partial
+from typing import Dict, Generator, KeysView, List, Optional, Union
+
+import numpy as np
+import pandas as pd
+from tqdm.auto import tqdm
+from tritonclient.grpc import (
+    InferenceServerClient,
+    InferenceServerException,
+    InferInput,
+    InferRequestedOutput,
+    InferResult,
+)
+
+
+class Koina:
+    """A class for interacting with Koina models for inference."""
+
+    model_inputs: Dict[str, str]
+    model_outputs: Dict[str, str]
+    batchsize: int
+
+    def __init__(
+        self,
+        model_name: str,
+        server_url: str = "koina.proteomicsdb.org:443",
+        ssl: bool = True,
+        targets: Optional[List[str]] = None,
+    ):
+        """
+        Initialize a Koina instance with the specified parameters.
+
+        This constructor initializes the Koina instance, connecting it to the specified inference server.
+        It checks the availability of the server, the specified model, retrieves input and output information,
+        and determines the maximum batch size supported by the model's configuration.
+        Note: To use this class, ensure that the inference server is properly configured and running,
+        and that the specified model is available on the server.
+
+        :param model_name: The name of the Koina model to be used for inference.
+        :param server_url: The URL of the inference server. Defaults to "koina.proteomicsdb.org:443".
+        :param ssl: Indicates whether to use SSL for communication with the server. Defaults to True.
+        :param targets: An optional list of targets to predict. If this is None, all model targets are
+            predicted and received.
+        """
+        self.model_inputs = {}
+        self.model_outputs = {}
+
+        self.model_name = model_name
+        self.url = server_url
+        self.ssl = ssl
+        self.client = InferenceServerClient(url=server_url, ssl=ssl)
+
+        self.type_convert = {
+            "FP32": np.dtype("float32"),
+            "BYTES": np.dtype("O"),
+            "INT16": np.dtype("int16"),
+            "INT32": np.dtype("int32"),
+            "INT64": np.dtype("int64"),
+        }
+
+        self._is_server_ready()
+        self._is_model_ready()
+
+        self.__get_inputs()
+        self.__get_outputs(targets)
+        self.__get_batchsize()
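For orientation, here is a minimal usage sketch of the constructor above; the model name is borrowed from the unit test further down, and it assumes the public Koina server is reachable:

```python
from oktoberfest.predict import Koina

# Connect, validate server and model, and fetch input/output metadata plus batch size.
# Restricting `targets` means only the "irt" tensor is requested during inference.
model = Koina(
    "Prosit_2020_irt_TMT",
    server_url="koina.proteomicsdb.org:443",
    ssl=True,
    targets=["irt"],
)
print(model.model_inputs)  # e.g. {"peptide_sequences": "BYTES"}
print(model.batchsize)     # maximum batch size reported by the model config
```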
+ """ + self.model_inputs = {} + self.model_outputs = {} + # self.batchsize = No + + self.model_name = model_name + self.url = server_url + self.ssl = ssl + self.client = InferenceServerClient(url=server_url, ssl=ssl) + + self.type_convert = { + "FP32": np.dtype("float32"), + "BYTES": np.dtype("O"), + "INT16": np.dtype("int16"), + "INT32": np.dtype("int32"), + "INT64": np.dtype("int64"), + } + + self._is_server_ready() + self._is_model_ready() + + self.__get_inputs() + self.__get_outputs(targets) + self.__get_batchsize() + + def _is_server_ready(self): + """ + Check if the inference server is live and accessible. + + This method checks the availability of the inference server and raises an exception if it is not live or + accessible. It ensures that the server is properly running and can be used for inference with the Koina + model. Note: This method is primarily for internal use and typically called during model initialization. + + :raises ValueError: If the server is not live or accessible. + """ + try: + if not self.client.is_server_live(): + raise ValueError("Server not yet started.") + except InferenceServerException as e: + if self.url == "koina.proteomicsdb.org:443": + if self.ssl: + raise ValueError( + "The public koina network seems to be inaccessible at the moment. " + "Please notify ludwig.lautenbacher@tum.de." + ) from e + else: + raise ValueError("To use the public koina network you need to set `ssl=True`.") from e + raise + + def _is_model_ready(self): + """ + Check if the specified model is available on the server. + + This method checks if the specified Koina model is available on the inference server. If the model is not + available, it raises an exception indicating that the model is not accessible at the provided server URL. + Note: This method is primarily for internal use and typically called during model initialization. + + :raises ValueError: If the specified model is not available at the server. + """ + if not self.client.is_model_ready(self.model_name): + raise ValueError(f"The model {self.model_name} is not available at {self.url}") + + def __get_inputs(self): + """ + Retrieve the input names and datatypes for the model. + + This method fetches the names and data types of the input tensors for the Koina model and stores them in + the 'model_inputs' attribute. Note: This method is for internal use and is typically called during model + initialization. + """ + for i in self.client.get_model_metadata(self.model_name).inputs: + self.model_inputs[i.name] = i.datatype + + def __get_outputs(self, targets: Optional[List] = None): + """ + Retrieve the output names and datatypes for the model. + + This method fetches the names and data types of the output tensors for the Koina model and stores them in + the 'model_outputs' attribute. If a list of target names is supplied, the tensors are filtered for those. + In case that the targets contain a name that is not a valid output of the requested model, a ValueError is + raised. Note: This method is for internal use and is typically called during model initialization. + + :param targets: An optional list of target names to filter the predictions for. If this is None, all targets + are added to list of output tensors to predict. + :raises ValueError: If a target supplied is not a valid output name of the requested model. 
+ """ + model_outputs = self.client.get_model_metadata(self.model_name).outputs + model_targets = [out.name for out in model_outputs] + + if targets is None: + targets = model_targets + else: + for target in targets: + if target not in model_targets: + raise ValueError( + f"The supplied target {target} is not a valid output target of the model. " + f"Valid targets are {model_targets}." + ) + for i in model_outputs: + if i.name in targets: + self.model_outputs[i.name] = i.datatype + + def __get_batchsize(self): + """ + Get the maximum batch size supported by the model's configuration. + + This method determines the maximum batch size supported by the Koina model's configuration and stores it + in the 'batchsize' attribute. Note: This method is for internal use and is typically called during model + initialization. + """ + self.batchsize = self.client.get_model_config(self.model_name).config.max_batch_size + + @staticmethod + def __get_batch_outputs(names: KeysView[str]) -> List[InferRequestedOutput]: + """ + Create InferRequestedOutput objects for the given output names. + + This method generates InferRequestedOutput objects for the specified output names. InferRequestedOutput objects + are used to request specific outputs when performing inference. Note: This method is for internal use and is + typically called during inference. + + :param names: A list of output names for which InferRequestedOutput objects should be created. + + :return: A list of InferRequestedOutput objects. + """ + return [InferRequestedOutput(name) for name in names] + + def __get_batch_inputs(self, data: Dict[str, np.ndarray]) -> List[InferInput]: + """ + Prepare a list of InferInput objects for the input data. + + This method prepares a list of InferInput objects for the provided input data. InferInput objects are used to + specify the input tensors and their data when performing inference. Note: This method is for internal use and + is typically called during inference. + + :param data: A dictionary containing input data for inference. Keys are input names, and values are numpy arrays. + + :return: A list of InferInput objects for the input data. + """ + batch_inputs = [] + for iname, idtype in self.model_inputs.items(): + batch_inputs.append(InferInput(iname, (len(data[next(iter(data))]), 1), idtype)) + batch_inputs[-1].set_data_from_numpy(data[iname].reshape(-1, 1).astype(self.type_convert[idtype])) + return batch_inputs + + def __extract_predictions(self, infer_result: InferResult) -> Dict[str, np.ndarray]: + """ + Extract the predictions from an inference result. + + This method extracts the predictions from an inference result and organizes them in a dictionary with output + names as keys and corresponding arrays as values. Note: This method is for internal use and is typically called + during inference. + + :param infer_result: The result of an inference operation. + + :return: A dictionary containing the extracted predictions. Keys are output names, and values are numpy arrays. + """ + predictions = {} + for oname in self.model_outputs.keys(): + predictions[oname] = infer_result.as_numpy(oname) + return predictions + + def __predict_batch(self, data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + """ + Perform batch inference and return the predictions. + + This method performs batch inference on the provided input data using the configured Koina model and returns the + predictions. Note: This method is for internal use and is typically called during inference. 
+ + :param data: A dictionary containing input data for batch inference. Keys are input names, and values are numpy arrays. + + :return: A dictionary containing the model's predictions. Keys are output names, and values are numpy arrays + representing the model's output. + """ + batch_outputs = self.__get_batch_outputs(self.model_outputs.keys()) + batch_inputs = self.__get_batch_inputs(data) + infer_result = self.client.infer(self.model_name, inputs=batch_inputs, outputs=batch_outputs) + + return self.__extract_predictions(infer_result) + + def __predict_sequential( + self, data: Dict[str, np.ndarray], disable_progress_bar: bool = False + ) -> Dict[str, np.ndarray]: + """ + Perform sequential inference and return the predictions. + + This method performs sequential inference on the provided input data using the configured Koina model. It processes + the input data batch by batch and returns the predictions. You can choose to disable the progress bar during inference + using the 'disable_progress_bar' parameter. Note: This method is for internal use and is typically called during inference. + + :param data: A dictionary containing input data for inference. Keys are input names, and values are numpy arrays. + :param disable_progress_bar: If True, disable the progress bar during inference. + + :return: A dictionary containing the model's predictions. Keys are output names, and values are numpy arrays representing + the model's output. + """ + predictions: Dict[str, np.ndarray] = {} + for data_batch in tqdm( + self.__slice_dict(data, self.batchsize), desc="Getting predictions", disable=disable_progress_bar + ): + pred_batch = self.__predict_batch(data_batch) + if predictions: + predictions = self.__merge_array_dict(predictions, pred_batch) + else: + predictions = pred_batch # Only first iteration to initialize dict keys + return predictions + + @staticmethod + def __slice_dict(data: Dict[str, np.ndarray], batchsize: int) -> Generator[Dict[str, np.ndarray], None, None]: + """ + Slice the input data into batches of a specified batch size. + + This method takes the input data and divides it into smaller batches, each containing 'batchsize' elements. It yields + these batches one at a time, allowing for batched processing of input data. Note: This method is for internal use and + is typically called during batched inference. + + :param data: A dictionary containing input data for batch inference. Keys are input names, and values are numpy arrays. + :param batchsize: The desired batch size for slicing the data. + + :yield: A dictionary containing a batch of input data with keys and values corresponding to the input names and + batched arrays. + """ + len_inputs = list(data.values())[0].shape[0] + for i in range(0, len_inputs, batchsize): + dict_slice = {} + for k, v in data.items(): + dict_slice[k] = v[i : i + batchsize] + yield dict_slice + + @staticmethod + def __merge_array_dict(d1: Dict[str, np.ndarray], d2: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + """ + Merge two dictionaries of arrays. + + This method takes two dictionaries, 'd1' and 'd2', each containing arrays with identical keys. It merges the + arrays from both dictionaries, creating a new dictionary with the same keys and combined arrays. Note: This + method is for internal use and is typically called during batched inference. + + :param d1: A dictionary containing arrays. + :param d2: Another dictionary containing arrays with the same keys as d1. + + :raises NotImplementedError: If the keys in 'd1' and 'd2' do not match. 
+        :return: A dictionary containing merged arrays with the same keys as d1 and d2.
+
+        Example::
+            dict1 = {"output1": np.array([1.0, 2.0, 3.0]), "output2": np.array([4.0, 5.0, 6.0])}
+            dict2 = {"output1": np.array([7.0, 8.0, 9.0]), "output2": np.array([10.0, 11.0, 12.0])}
+            merged_dict = model.__merge_array_dict(dict1, dict2)
+            print(merged_dict)
+        """
+        if d1.keys() != d2.keys():
+            raise NotImplementedError(f"Keys in dictionary need to be equal {d1.keys(), d2.keys()}")
+        out = {}
+        for k in d1.keys():
+            out[k] = np.concatenate([d1[k], d2[k]])
+        return out
+
+    @staticmethod
+    def __merge_list_dict_array(dict_list: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]:
+        """
+        Merge a list of dictionaries of arrays.
+
+        This method takes a list of dictionaries, where each dictionary contains arrays with identical keys. It merges
+        the arrays from all dictionaries in the list, creating a new dictionary with the same keys and combined arrays.
+        Note: This method is for internal use and is typically called during batched inference.
+
+        :param dict_list: A list of dictionaries, each containing arrays with the same keys.
+        :raises NotImplementedError: If the keys of all dictionaries in the list do not match.
+
+        :return: A dictionary containing merged arrays with the same keys as the dictionaries in the list.
+
+        Example::
+            dict_list = [
+                {"output1": np.array([1.0, 2.0, 3.0]), "output2": np.array([4.0, 5.0, 6.0])},
+                {"output1": np.array([7.0, 8.0, 9.0]), "output2": np.array([10.0, 11.0, 12.0])},
+                {"output1": np.array([13.0, 14.0, 15.0]), "output2": np.array([16.0, 17.0, 18.0])},
+            ]
+            merged_dict = model.__merge_list_dict_array(dict_list)
+            print(merged_dict)
+        """
+        tmp = [x.keys() for x in dict_list]
+        if not np.all([tmp[0] == x for x in tmp]):
+            raise NotImplementedError(f"Keys of all dictionaries in the list need to be equal {tmp}")
+        out = {}
+        for k in tmp[0]:
+            out[k] = np.concatenate([x[k] for x in dict_list])
+        return out
+
+    def __async_callback(self, infer_results: List[InferResult], result: InferResult, error):
+        """
+        Callback function for asynchronous inference.
+
+        This method serves as a callback function for asynchronous inference. It is invoked when an asynchronous
+        inference task is completed. The result of the task is appended to the 'infer_results' list, and any
+        encountered error is checked and handled appropriately. Note: This method is for internal use and is typically
+        called during asynchronous inference.
+
+        :param infer_results: A list to which the results of asynchronous inference will be appended.
+        :param result: The result of an asynchronous inference operation.
+        :param error: An error, if any, encountered during asynchronous inference.
+        :raises error: If any exception was encountered during asynchronous inference.
+        """
+        if error:
+            raise error
+        else:
+            infer_results.append(result)
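The callback above is bound to its result list with `functools.partial` before being handed to `async_infer`. A server-free sketch of that pattern, using plain strings in place of `InferResult` objects:

```python
from functools import partial

def callback(results: list, result, error) -> None:
    # Mirrors __async_callback: fail loudly on error, otherwise collect the result.
    if error:
        raise error
    results.append(result)

collected: list = []
bound = partial(callback, collected)  # this is what async_infer would receive

bound("first result", None)   # simulates a completed request
bound("second result", None)
print(collected)              # ['first result', 'second result']
```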
+    def __async_predict_batch(
+        self, data: Dict[str, np.ndarray], infer_results: List[InferResult], request_id: int, timeout: int = 10
+    ):
+        """
+        Perform asynchronous batch inference on the given data using the Koina model.
+
+        This method initiates asynchronous batch inference on the provided input data using the configured Koina model.
+        Results will be appended to the 'infer_results' list as they become available. The 'request_id' parameter is
+        used to identify and order the results. The method itself returns as soon as the request is submitted; the
+        request fails if it does not complete within 'timeout' seconds.
+
+        :param data: A dictionary containing input data for batch inference. Keys are input names, and values are numpy arrays.
+        :param infer_results: A list to which the results of asynchronous inference will be appended.
+        :param request_id: An identifier for the inference request, used to track the order of completion.
+        :param timeout: The maximum time (in seconds) to wait for the inference to complete. Defaults to 10 seconds.
+        """
+        batch_outputs = self.__get_batch_outputs(self.model_outputs.keys())
+        batch_inputs = self.__get_batch_inputs(data)
+
+        self.client.async_infer(
+            model_name=self.model_name,
+            request_id=str(request_id),
+            inputs=batch_inputs,
+            callback=partial(self.__async_callback, infer_results),
+            outputs=batch_outputs,
+            client_timeout=timeout,
+        )
+
+    def predict(
+        self, data: Union[Dict[str, np.ndarray], pd.DataFrame], disable_progress_bar: bool = False, _async: bool = True
+    ) -> Dict[str, np.ndarray]:
+        """
+        Perform inference on the given data using the Koina model.
+
+        This method allows you to perform inference on the provided input data using the configured Koina model. You can
+        choose to perform inference asynchronously (in parallel) or sequentially, depending on the value of the '_async'
+        parameter. If asynchronous inference is selected, the method will return when all inference tasks are complete.
+        Note: Ensure that the model and server are properly configured and that the input data matches the model's
+        input requirements.
+
+        :param data: A dictionary or dataframe containing input data for inference. For the dictionary, keys are input names,
+            and values are numpy arrays. In case of a dataframe, the input fields for the requested model must be present
+            in the column names.
+        :param disable_progress_bar: If True, disable the progress bar during inference. Defaults to False.
+        :param _async: If True, perform asynchronous inference; if False, perform sequential inference. Defaults to True.
+
+        :return: A dictionary containing the model's predictions. Keys are output names, and values are numpy arrays
+            representing the model's output.
+
+        Example::
+            model = Koina("Prosit_2019_intensity")
+            size = 3
+            input_data = {
+                "peptide_sequences": np.array(["PEPTIDEK" for _ in range(size)]),
+                "precursor_charges": np.array([2 for _ in range(size)]),
+                "collision_energies": np.array([20 for _ in range(size)]),
+                "fragmentation_types": np.array(["HCD" for _ in range(size)]),
+                "instrument_types": np.array(["QE" for _ in range(size)])
+            }
+            predictions = model.predict(input_data)
+        """
+        if isinstance(data, pd.DataFrame):
+            data = {input_field: data[input_field].to_numpy() for input_field in self.model_inputs.keys()}
+        if _async:
+            pred_func = self.__predict_async
+        else:
+            pred_func = self.__predict_sequential
+        return pred_func(data, disable_progress_bar=disable_progress_bar)
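Since `predict` also accepts a dataframe and picks columns by the model's input names, a call can look like the following sketch (toy values; the peptide is borrowed from the test fixture below):

```python
import pandas as pd

from oktoberfest.predict import Koina

model = Koina("Prosit_2020_intensity_TMT")

# Column names must match the model's input names exactly; extra columns are ignored.
df = pd.DataFrame(
    {
        "peptide_sequences": ["[UNIMOD:737]-PEPTIDEK[UNIMOD:737]"],
        "precursor_charges": [2],
        "collision_energies": [30],
        "fragmentation_types": ["HCD"],
    }
)
predictions = model.predict(df)
print(predictions["intensities"].shape)
```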
+    def __predict_async(self, data: Dict[str, np.ndarray], disable_progress_bar: bool = False) -> Dict[str, np.ndarray]:
+        """
+        Perform asynchronous inference on the given data using the Koina model.
+
+        This method performs asynchronous inference on the provided input data using the configured Koina model.
+        Asynchronous inference allows for parallel processing of input data, potentially leading to faster results.
+        The method will return when all asynchronous inference tasks are complete. Note: Ensure that the model and server
+        are properly configured and that the input data matches the model's input requirements.
+
+        :param data: A dictionary containing input data for inference. Keys are input names, and values are numpy arrays.
+        :param disable_progress_bar: If True, disable the progress bar during asynchronous inference. Defaults to False.
+
+        :return: A dictionary containing the model's predictions. Keys are output names, and values are numpy arrays
+            representing the model's output.
+        """
+        infer_results: List[InferResult] = []
+        for i, data_batch in enumerate(self.__slice_dict(data, self.batchsize)):
+            self.__async_predict_batch(data_batch, infer_results, request_id=i)
+
+        with tqdm(total=i + 1, desc="Getting predictions", disable=disable_progress_bar) as pbar:
+            while len(infer_results) != i + 1:
+                pbar.n = len(infer_results)
+                pbar.refresh()
+                time.sleep(1)
+            pbar.n = len(infer_results)
+            pbar.refresh()
+
+        # sort according to request id
+        infer_results_to_return = [
+            self.__extract_predictions(infer_results[i])
+            for i in np.argsort(np.array([int(y.get_response("id")["id"]) for y in infer_results]))
+        ]
+
+        return self.__merge_list_dict_array(infer_results_to_return)
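If the asynchronous path ever needs debugging, the same call can be routed through the sequential implementation; a small sketch:

```python
import numpy as np

from oktoberfest.predict import Koina

model = Koina("Prosit_2020_irt_TMT")
data = {"peptide_sequences": np.array(["[UNIMOD:737]-PEPTIDE"])}

# _async=False uses __predict_sequential, which is easier to step through
# than the callback-based default.
irts = model.predict(data, _async=False, disable_progress_bar=True)
print(irts["irt"])
```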
diff --git a/oktoberfest/predict/predict.py b/oktoberfest/predict/predict.py
index 5a849bbc..9faf6b7a 100644
--- a/oktoberfest/predict/predict.py
+++ b/oktoberfest/predict/predict.py
@@ -1,164 +1,61 @@
 import logging
 import re
-from math import ceil
-from multiprocessing import current_process
-from typing import Dict, List, Tuple
+from typing import Dict
 
 import numpy as np
 import pandas as pd
 from spectrum_fundamentals.metrics.similarity import SimilarityMetrics
-from tqdm.auto import tqdm
-from tritonclient.grpc import InferenceServerClient, InferInput, InferRequestedOutput
 
 from ..data.spectra import FragmentType, Spectra
+from .koina import Koina
 
 logger = logging.getLogger(__name__)
 
 
-def grpc_predict(
-    library: Spectra,
-    url: str,
-    intensity_model: str,
-    irt_model: str,
-    ssl: bool = True,
-    alignment: bool = False,
-    job_type: str = "",
-):
+def predict(data: pd.DataFrame, *args, **kwargs) -> Dict[str, np.ndarray]:
     """
-    Use grpc to predict library and add predictions to library.
+    Retrieve predictions from koina.
 
-    :param library: Spectra object with the library
-    :param url: Url including the port of the prediction server
-    :param intensity_model: the name of the intensity model on the server
-    :param irt_model: the name of the irt model on the server
-    :param ssl: whether or not the server requires an ssl encrypted transportation, default = True
-    :param alignment: True if alignment present
-    :param job_type: TODO
-    :return: grpc predictions if we are trying to generate spectral library
-    """
-    triton_client = InferenceServerClient(url=url, ssl=ssl)
-    batch_size = 1000
-
-    intensity_outputs = ["intensities", "mz", "annotation"]
-    intensity_input_data = {
-        "peptide_sequences": (
-            library.spectra_data["MODIFIED_SEQUENCE"].to_numpy().reshape(-1, 1).astype(np.object_),
-            "BYTES",
-        ),
-        "collision_energies": (
-            library.spectra_data["COLLISION_ENERGY"].to_numpy().reshape(-1, 1).astype(np.float32),
-            "FP32",
-        ),
-        "precursor_charges": (
-            library.spectra_data["PRECURSOR_CHARGE"].to_numpy().reshape(-1, 1).astype(np.int32),
-            "INT32",
-        ),
-    }
-    if "tmt" in intensity_model.lower() or "ptm" in intensity_model.lower():
-        intensity_input_data["fragmentation_types"] = (
-            library.spectra_data["FRAGMENTATION"].to_numpy().reshape(-1, 1).astype(np.object_),
-            "BYTES",
-        )
+    This function takes a dataframe containing information about PSMs and predicts peptide
+    properties using a Koina server. The Koina connection is configured via args and kwargs,
+    which are forwarded to the Koina constructor. TODO, link this properly.
-    intensity_predictions = infer_predictions(
-        triton_client,
-        model=intensity_model,
-        input_data=intensity_input_data,
-        outputs=intensity_outputs,
-        batch_size=batch_size,
-    )
-    intensity_predictions["intensities"][np.where(intensity_predictions["intensities"] < 1e-7)] = 0.0
+    :param data: Dataframe containing the data for the prediction.
+    :param args: Additional positional arguments forwarded to the Koina constructor.
+    :param kwargs: Additional keyword arguments forwarded to the Koina constructor.
-    irt_input_data = {"peptide_sequences": intensity_input_data["peptide_sequences"]}
-    irt_outputs = ["irt"]
-    irt_predictions = infer_predictions(
-        triton_client,
-        model=irt_model,
-        input_data=irt_input_data,
-        outputs=irt_outputs,
-        batch_size=batch_size,
+    :return: a dictionary with targets (keys) and predictions (values)
+    """
+    predictor = Koina(*args, **kwargs)
+
+    data.rename(
+        columns={
+            "MODIFIED_SEQUENCE": "peptide_sequences",
+            "PRECURSOR_CHARGE": "precursor_charges",
+            "COLLISION_ENERGY": "collision_energies",
+            "FRAGMENTATION": "fragmentation_types",
+        },
+        inplace=True,
     )
-    if job_type == "SpectralLibraryGeneration":
-        intensity_prediction_dict = {
-            "intensity": intensity_predictions["intensities"],
-            "fragmentmz": intensity_predictions["mz"],
-            "annotation": parse_fragment_labels(
-                intensity_predictions["annotation"],
-                library.spectra_data["PRECURSOR_CHARGE"].to_numpy()[:, None],
-                library.spectra_data["PEPTIDE_LENGTH"].to_numpy()[:, None],
-            ),
-        }
-        output_dict = {intensity_model: intensity_prediction_dict, irt_model: irt_predictions["irt"]}
-        return output_dict
-
-    intensities_pred = pd.DataFrame()
-    intensities_pred["intensity"] = intensity_predictions["intensities"].tolist()
-    library.add_matrix(intensities_pred["intensity"], FragmentType.PRED)
+    results = predictor.predict(data)
-    if alignment:
-        return
+    data.rename(
+        columns={
+            "peptide_sequences": "MODIFIED_SEQUENCE",
+            "precursor_charges": "PRECURSOR_CHARGE",
+            "collision_energies": "COLLISION_ENERGY",
+            "fragmentation_types": "FRAGMENTATION",
+        },
+        inplace=True,
+    )
-    library.add_column(irt_predictions["irt"], name="PREDICTED_IRT")
+    return results
 
 
-def infer_predictions(
-    triton_client: InferenceServerClient,
-    model: str,
-    input_data: Dict[str, Tuple[np.ndarray, str]],
-    outputs: List[str],
-    batch_size: int,
+def parse_fragment_labels(
+    spectra_labels: np.ndarray, precursor_charges: np.ndarray, seq_lengths: np.ndarray
 ) -> Dict[str, np.ndarray]:
-    """
-    Infer predictions from a triton client.
-
-    :param triton_client: An inference client using grpc
-    :param model: a model that is recognized by the server specified in the triton client
-    :param input_data: a dictionary that contains the input names (key) for the specific model
-        and a tuple of the input_data as a numpy array of shape [:, 1] and the dtype recognized
-        by the triton client (value).
- :param outputs: a list of output names for the specific model - :param batch_size: the number of elements from the input_data that should be provided to the - triton client at once - :return: a dictionary containing the predictions (values) for the given outputs (keys) - """ - num_spec = len(input_data[list(input_data)[0]][0]) - predictions: Dict[str, List[np.ndarray]] = {output: [] for output in outputs} - - n_batches = ceil(num_spec / batch_size) - process_identity = current_process()._identity - if len(process_identity) > 0: - position = process_identity[0] - else: - position = 0 - - with tqdm( - total=n_batches, - position=position, - desc=f"Inferring predictions for {num_spec} spectra with batch site {batch_size}", - leave=True, - ) as progress: - for i in range(0, n_batches): - progress.update(1) - # logger.info(f"Predicting batch {i+1}/{n_batches}.") - infer_inputs = [] - for input_key, (data, dtype) in input_data.items(): - batch_data = data[i * batch_size : (i + 1) * batch_size] - infer_input = InferInput(input_key, batch_data.shape, dtype) - infer_input.set_data_from_numpy(batch_data) - infer_inputs.append(infer_input) - - infer_outputs = [InferRequestedOutput(output) for output in outputs] - - prediction = triton_client.infer(model, inputs=infer_inputs, outputs=infer_outputs) - - for output in outputs: - predictions[output].append(prediction.as_numpy(output)) - - return {key: np.vstack(value) for key, value in predictions.items()} - - -def parse_fragment_labels(spectra_labels: np.ndarray, precursor_charges: np.ndarray, seq_lengths: np.ndarray): """Uses regex to parse labels.""" pattern = rb"([y|b])([0-9]{1,2})\+([1-3])" fragment_types = [] @@ -231,11 +128,12 @@ def ce_calibration(library: Spectra, **server_kwargs) -> pd.Series: between predicted and observed intensities before returning the alignment library. 
     :param library: spectral library to perform CE calibration on
-    :param server_kwargs: Additional parameters that are forwarded to grpc_predict
+    :param server_kwargs: Additional parameters that are forwarded to the prediction method
     :return: the alignment library containing the spectral angle for all tested collision energies
     """
     alignment_library = _prepare_alignment_df(library)
-    grpc_predict(alignment_library, alignment=True, **server_kwargs)
+    intensities = predict(alignment_library.spectra_data, **server_kwargs)
+    alignment_library.add_matrix(pd.Series(intensities["intensities"].tolist(), name="intensities"), FragmentType.PRED)
     _alignment(alignment_library)
 
     return alignment_library
diff --git a/oktoberfest/runner.py b/oktoberfest/runner.py
index 9c914264..0650aa21 100644
--- a/oktoberfest/runner.py
+++ b/oktoberfest/runner.py
@@ -4,6 +4,7 @@
 from pathlib import Path
 from typing import List, Type, Union
 
+import pandas as pd
 from spectrum_io.spectral_library import MSP, DLib, SpectralLibrary, Spectronaut
 
 from oktoberfest import __copyright__, __version__
@@ -12,7 +13,7 @@
 from oktoberfest import preprocessing as pp
 from oktoberfest import rescore as re
 
-from .data.spectra import Spectra
+from .data.spectra import FragmentType, Spectra
 from .utils import Config, JobPool, ProcessStep
 
 logger = logging.getLogger(__name__)
@@ -90,10 +91,9 @@ def _get_best_ce(library: Spectra, spectra_file: Path, config: Config) -> int:
     results_dir.mkdir(exist_ok=True)
     if (library.spectra_data["FRAGMENTATION"] == "HCD").any():
         server_kwargs = {
-            "url": config.prediction_server,
+            "server_url": config.prediction_server,
             "ssl": config.ssl,
-            "intensity_model": config.models["intensity"],
-            "irt_model": config.models["irt"],
+            "model_name": config.models["intensity"],
         }
         alignment_library = pr.ce_calibration(library, **server_kwargs)
         ce_alignment = alignment_library.spectra_data.groupby(by=["COLLISION_ENERGY"])["SPECTRAL_ANGLE"].mean()
@@ -154,11 +154,8 @@ def generate_spectral_lib(config_path: Union[str, Path]):
     no_of_sections = no_of_spectra // 7000
 
     server_kwargs = {
-        "url": config.prediction_server,
+        "server_url": config.prediction_server,
         "ssl": config.ssl,
-        "intensity_model": config.models["intensity"],
-        "irt_model": config.models["irt"],
-        "job_type": "SpectralLibraryGeneration",
     }
 
     spectral_library: Type[SpectralLibrary]
@@ -192,9 +189,21 @@ def generate_spectral_lib(config_path: Union[str, Path]):
         else:
             break
 
-        grpc_output_sec = pr.grpc_predict(spectra_div, **server_kwargs)
+        pred_intensities = pr.predict(spectra_div.spectra_data, model_name=config.models["intensity"], **server_kwargs)
+        pred_irts = pr.predict(spectra_div.spectra_data, model_name=config.models["irt"], **server_kwargs)
+
+        intensity_prediction_dict = {
+            "intensity": pred_intensities["intensities"],
+            "fragmentmz": pred_intensities["mz"],
+            "annotation": pr.parse_fragment_labels(
+                pred_intensities["annotation"],
+                spectra_div.spectra_data["PRECURSOR_CHARGE"].to_numpy()[:, None],
+                spectra_div.spectra_data["PEPTIDE_LENGTH"].to_numpy()[:, None],
+            ),
+        }
+        output_dict = {config.models["intensity"]: intensity_prediction_dict, config.models["irt"]: pred_irts["irt"]}
 
-        out_lib = spectral_library(spectra_div.spectra_data, grpc_output_sec, out_file)
+        out_lib = spectral_library(spectra_div.spectra_data, output_dict, out_file)
         out_lib.prepare_spectrum()
         out_lib.write()
 
@@ -255,13 +264,21 @@ def _calculate_features(spectra_file: Path, config: Config):
         return
 
     server_kwargs = {
-        "url": config.prediction_server,
+        "server_url": config.prediction_server,
"ssl": config.ssl, - "intensity_model": config.models["intensity"], - "irt_model": config.models["irt"], } - pr.grpc_predict(library, **server_kwargs) + pred_intensities = pr.predict( + library.spectra_data, + model_name=config.models["intensity"], + targets=["intensities", "annotation"], + **server_kwargs, + ) + pred_irts = pr.predict(library.spectra_data, model_name=config.models["irt"], **server_kwargs) + + library.add_matrix(pd.Series(pred_intensities["intensities"].tolist(), name="intensities"), FragmentType.PRED) + library.add_column(pred_irts["irt"], name="PREDICTED_IRT") + library.write_pred_as_hdf5(config.output / "data" / spectra_file.with_suffix(".mzml.pred.hdf5").name) # produce percolator tab files diff --git a/tests/unit_tests/data/predictions/library_input.csv b/tests/unit_tests/data/predictions/library_input.csv index f8b6886e..b8dad5fd 100644 --- a/tests/unit_tests/data/predictions/library_input.csv +++ b/tests/unit_tests/data/predictions/library_input.csv @@ -1,4 +1,4 @@ -,"MODIFIED_SEQUENCE","COLLISION_ENERGY","PRECURSOR_CHARGE","FRAGMENTATION" -0,"[UNIMOD:737]-PEPTIDEK[UNIMOD:737]",30,2,"HCD" -1,"[UNIMOD:737]-PEPTIDE",30,2,"HCD" -2,"[UNIMOD:737]-M[UNIMOD:35]EC[UNIMOD:4]TIDEK[UNIMOD:737]",35,1,"CID" +"MODIFIED_SEQUENCE","COLLISION_ENERGY","PRECURSOR_CHARGE","FRAGMENTATION" +"[UNIMOD:737]-PEPTIDEK[UNIMOD:737]",30,2,"HCD" +"[UNIMOD:737]-PEPTIDE",30,2,"HCD" +"[UNIMOD:737]-M[UNIMOD:35]EC[UNIMOD:4]TIDEK[UNIMOD:737]",35,1,"CID" diff --git a/tests/unit_tests/data/predictions/library_output.csv b/tests/unit_tests/data/predictions/library_output.csv index 2dc9c742..ce89bef2 100644 --- a/tests/unit_tests/data/predictions/library_output.csv +++ b/tests/unit_tests/data/predictions/library_output.csv @@ -1,4 +1,4 @@ -,MODIFIED_SEQUENCE,COLLISION_ENERGY,PRECURSOR_CHARGE,FRAGMENTATION,INTENSITY_PRED_Y1+,INTENSITY_PRED_Y1++,INTENSITY_PRED_Y1+++,INTENSITY_PRED_B1+,INTENSITY_PRED_B1++,INTENSITY_PRED_B1+++,INTENSITY_PRED_Y2+,INTENSITY_PRED_Y2++,INTENSITY_PRED_Y2+++,INTENSITY_PRED_B2+,INTENSITY_PRED_B2++,INTENSITY_PRED_B2+++,INTENSITY_PRED_Y3+,INTENSITY_PRED_Y3++,INTENSITY_PRED_Y3+++,INTENSITY_PRED_B3+,INTENSITY_PRED_B3++,INTENSITY_PRED_B3+++,INTENSITY_PRED_Y4+,INTENSITY_PRED_Y4++,INTENSITY_PRED_Y4+++,INTENSITY_PRED_B4+,INTENSITY_PRED_B4++,INTENSITY_PRED_B4+++,INTENSITY_PRED_Y5+,INTENSITY_PRED_Y5++,INTENSITY_PRED_Y5+++,INTENSITY_PRED_B5+,INTENSITY_PRED_B5++,INTENSITY_PRED_B5+++,INTENSITY_PRED_Y6+,INTENSITY_PRED_Y6++,INTENSITY_PRED_Y6+++,INTENSITY_PRED_B6+,INTENSITY_PRED_B6++,INTENSITY_PRED_B6+++,INTENSITY_PRED_Y7+,INTENSITY_PRED_Y7++,INTENSITY_PRED_Y7+++,INTENSITY_PRED_B7+,INTENSITY_PRED_B7++,INTENSITY_PRED_B7+++,INTENSITY_PRED_Y8+,INTENSITY_PRED_Y8++,INTENSITY_PRED_Y8+++,INTENSITY_PRED_B8+,INTENSITY_PRED_B8++,INTENSITY_PRED_B8+++,INTENSITY_PRED_Y9+,INTENSITY_PRED_Y9++,INTENSITY_PRED_Y9+++,INTENSITY_PRED_B9+,INTENSITY_PRED_B9++,INTENSITY_PRED_B9+++,INTENSITY_PRED_Y10+,INTENSITY_PRED_Y10++,INTENSITY_PRED_Y10+++,INTENSITY_PRED_B10+,INTENSITY_PRED_B10++,INTENSITY_PRED_B10+++,INTENSITY_PRED_Y11+,INTENSITY_PRED_Y11++,INTENSITY_PRED_Y11+++,INTENSITY_PRED_B11+,INTENSITY_PRED_B11++,INTENSITY_PRED_B11+++,INTENSITY_PRED_Y12+,INTENSITY_PRED_Y12++,INTENSITY_PRED_Y12+++,INTENSITY_PRED_B12+,INTENSITY_PRED_B12++,INTENSITY_PRED_B12+++,INTENSITY_PRED_Y13+,INTENSITY_PRED_Y13++,INTENSITY_PRED_Y13+++,INTENSITY_PRED_B13+,INTENSITY_PRED_B13++,INTENSITY_PRED_B13+++,INTENSITY_PRED_Y14+,INTENSITY_PRED_Y14++,INTENSITY_PRED_Y14+++,INTENSITY_PRED_B14+,INTENSITY_PRED_B14++,INTENSITY_PRED_B14+++,INTENSITY_PRED_Y15+,INTEN
SITY_PRED_Y15++,INTENSITY_PRED_Y15+++,INTENSITY_PRED_B15+,INTENSITY_PRED_B15++,INTENSITY_PRED_B15+++,INTENSITY_PRED_Y16+,INTENSITY_PRED_Y16++,INTENSITY_PRED_Y16+++,INTENSITY_PRED_B16+,INTENSITY_PRED_B16++,INTENSITY_PRED_B16+++,INTENSITY_PRED_Y17+,INTENSITY_PRED_Y17++,INTENSITY_PRED_Y17+++,INTENSITY_PRED_B17+,INTENSITY_PRED_B17++,INTENSITY_PRED_B17+++,INTENSITY_PRED_Y18+,INTENSITY_PRED_Y18++,INTENSITY_PRED_Y18+++,INTENSITY_PRED_B18+,INTENSITY_PRED_B18++,INTENSITY_PRED_B18+++,INTENSITY_PRED_Y19+,INTENSITY_PRED_Y19++,INTENSITY_PRED_Y19+++,INTENSITY_PRED_B19+,INTENSITY_PRED_B19++,INTENSITY_PRED_B19+++,INTENSITY_PRED_Y20+,INTENSITY_PRED_Y20++,INTENSITY_PRED_Y20+++,INTENSITY_PRED_B20+,INTENSITY_PRED_B20++,INTENSITY_PRED_B20+++,INTENSITY_PRED_Y21+,INTENSITY_PRED_Y21++,INTENSITY_PRED_Y21+++,INTENSITY_PRED_B21+,INTENSITY_PRED_B21++,INTENSITY_PRED_B21+++,INTENSITY_PRED_Y22+,INTENSITY_PRED_Y22++,INTENSITY_PRED_Y22+++,INTENSITY_PRED_B22+,INTENSITY_PRED_B22++,INTENSITY_PRED_B22+++,INTENSITY_PRED_Y23+,INTENSITY_PRED_Y23++,INTENSITY_PRED_Y23+++,INTENSITY_PRED_B23+,INTENSITY_PRED_B23++,INTENSITY_PRED_B23+++,INTENSITY_PRED_Y24+,INTENSITY_PRED_Y24++,INTENSITY_PRED_Y24+++,INTENSITY_PRED_B24+,INTENSITY_PRED_B24++,INTENSITY_PRED_B24+++,INTENSITY_PRED_Y25+,INTENSITY_PRED_Y25++,INTENSITY_PRED_Y25+++,INTENSITY_PRED_B25+,INTENSITY_PRED_B25++,INTENSITY_PRED_B25+++,INTENSITY_PRED_Y26+,INTENSITY_PRED_Y26++,INTENSITY_PRED_Y26+++,INTENSITY_PRED_B26+,INTENSITY_PRED_B26++,INTENSITY_PRED_B26+++,INTENSITY_PRED_Y27+,INTENSITY_PRED_Y27++,INTENSITY_PRED_Y27+++,INTENSITY_PRED_B27+,INTENSITY_PRED_B27++,INTENSITY_PRED_B27+++,INTENSITY_PRED_Y28+,INTENSITY_PRED_Y28++,INTENSITY_PRED_Y28+++,INTENSITY_PRED_B28+,INTENSITY_PRED_B28++,INTENSITY_PRED_B28+++,INTENSITY_PRED_Y29+,INTENSITY_PRED_Y29++,INTENSITY_PRED_Y29+++,INTENSITY_PRED_B29+,INTENSITY_PRED_B29++,INTENSITY_PRED_B29+++,PREDICTED_IRT 
-0,[UNIMOD:737]-PEPTIDEK[UNIMOD:737],30,2,HCD,0.23809576034545898,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,0.8115681409835815,1.0000000116860974e-07,1.0000000116860974e-07,1.0,1.0000000116860974e-07,1.0000000116860974e-07,0.16860607266426086,1.0000000116860974e-07,1.0000000116860974e-07,0.009546236135065556,1.0000000116860974e-07,1.0000000116860974e-07,0.05823937803506851,1.0000000116860974e-07,1.0000000116860974e-07,0.0676455870270729,1.0000000116860974e-07,1.0000000116860974e-07,0.039169538766145706,1.0000000116860974e-07,1.0000000116860974e-07,0.07002319395542145,1.0000000116860974e-07,1.0000000116860974e-07,0.6715999841690063,0.01274949125945568,1.0000000116860974e-07,0.5392391085624695,1.0000000116860974e-07,1.0000000116860974e-07,0.02867981791496277,0.00021903926972299814,1.0000000116860974e-07,0.07469962537288666,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.00
00000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,39.420433 -1,[UNIMOD:737]-PEPTIDE,30,2,HCD,0.522221565246582,1.0271855899190996e-07,1.0000000116860974e-07,1.0271855899190996e-07,1.0271855899190996e-07,1.0000000116860974e-07,0.6478652358055115,1.0271855899190996e-07,1.0000000116860974e-07,0.6932786107063293,1.0271855899190996e-07,1.0000000116860974e-07,1.0271855899190996e-07,1.0271855899190996e-07,1.0000000116860974e-07,0.03741396963596344,0.04295269027352333,1.0000000116860974e-07,1.0271855899190996e-07,1.0271855899190996e-07,1.0000000116860974e-07,1.0,0.016776476055383682,1.0000000116860974e-07,0.024784639477729797,1.0271855899190996e-07,1.0000000116860974e-07,0.47810620069503784,1.0271855899190996e-07,1.0000000116860974e-07,1.0271855899190996e-07,1.0271855899190996e-07,1.0000000116860974e-07,0.6293804049491882,0.0024885176680982113,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1
.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,43.523045 -2,[UNIMOD:737]-M[UNIMOD:35]EC[UNIMOD:4]TIDEK[UNIMOD:737],35,1,CID,0.2831476628780365,1.0000000116860974e-07,1.0000000116860974e-07,1.3209350413490029e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0,1.0000000116860974e-07,1.0000000116860974e-07,1.3209350413490029e-07,1.0000000116860974e-07,1.0000000116860974e-07,0.38974782824516296,1.0000000116860974e-07,1.0000000116860974e-07,0.018663309514522552,1.0000000116860974e-07,1.0000000116860974e-07,0.16417913138866425,1.0000000116860974e-07,1.0000000116860974e-07,0.019735727459192276,1.0000000116860974e-07,1.0000000116860974e-07,0.27998510003089905,1.0000000116860974e-07,1.0000000116860974e-07,0.009887314401566982,1.0000000116860974e-07,1.0000000116860974e-07,0.48476091027259827,1.0000000116860974e-07,1.0000000116860974e-07,0.009575176984071732,1.0000000116860974e-07,1.0000000116860974e-07,0.26375487446784973,1.0000000116860974e-07,1.0000000116860974e-07,0.019234927371144295,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.000000
0116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,1.0000000116860974e-07,24.778168 +MODIFIED_SEQUENCE,COLLISION_ENERGY,PRECURSOR_CHARGE,FRAGMENTATION,INTENSITY_PRED_Y1+,INTENSITY_PRED_Y1++,INTENSITY_PRED_Y1+++,INTENSITY_PRED_B1+,INTENSITY_PRED_B1++,INTENSITY_PRED_B1+++,INTENSITY_PRED_Y2+,INTENSITY_PRED_Y2++,INTENSITY_PRED_Y2+++,INTENSITY_PRED_B2+,INTENSITY_PRED_B2++,INTENSITY_PRED_B2+++,INTENSITY_PRED_Y3+,INTENSITY_PRED_Y3++,INTENSITY_PRED_Y3+++,INTENSITY_PRED_B3+,INTENSITY_PRED_B3++,INTENSITY_PRED_B3+++,INTENSITY_PRED_Y4+,INTENSITY_PRED_Y4++,INTENSITY_PRED_Y4+++,INTENSITY_PRED_B4+,INTENSITY_PRED_B4++,INTENSITY_PRED_B4+++,INTENSITY_PRED_Y5+,INTENSITY_PRED_Y5++,INTENSITY_PRED_Y5+++,INTENSITY_PRED_B5+,INTENSITY_PRED_B5++,INTENSITY_PRED_B5+++,INTENSITY_PRED_Y6+,INTENSITY_PRED_Y6++,INTENSITY_PRED_Y6+++,INTENSITY_PRED_B6+,INTENSITY_PRED_B6++,INTENSITY_PRED_B6+++,INTENSITY_PRED_Y7+,INTENSITY_PRED_Y7++,INTENSITY_PRED_Y7+++,INTENSITY_PRED_B7+,INTENSITY_PRED_B7++,INTENSITY_PRED_B7+++,INTENSITY_PRED_Y8+,INTENSITY_PRED_Y8++,INTENSITY_PRED_Y8+++,INTENSITY_PRED_B8+,INTENSITY_PRED_B8++,INTENSITY_PRED_B8+++,INTENSITY_PRED_Y9+,INTENSITY_PRED_Y9++,INTENSITY_PRED_Y9+++,INTENSITY_PRED_B9+,INTENSITY_PRED_B9++,INTENSITY_PRED_B9+++,INTENSITY_PRED_Y10+,INTENSITY_PRED_Y10++,INTENSITY_PRED_Y10+++,INTENSITY_PRED_B10+,INTENSITY_PRED_B10++,INTENSITY_PRED_B10+++,INTENSITY_PRED_Y11+,INTENSITY_PRED_Y11++,INTENSITY_PRED_Y11+++,INTENSITY_PRED_B11+,INTENSITY_PRED_B11++,INTENSITY_PRED_B11+++,INTENSITY_PRED_Y12+,INTENSITY_PRED_Y12++,INTENSITY_PRED_Y12+++,INTENSITY_PRED_B12+,INTENSITY_PRED_B12++,INTENSITY_PRED_B12+++,INTENSITY_PRED_Y13+,INTENSITY_PRED_Y13++,INTENSITY_PRED_Y13+++,INTENSITY_PRED_B13+,INTENSITY_PRED_B13++,INTENSITY_PRED_B13+++,INTENSITY_PRED_Y14+,INTENSITY_PRED_Y14++,INTENSITY_PRED_Y14+++,INTENSITY_PRED_B14+,INTENSITY_PRED_B14++,INTENSITY_PRED_B14+++,INTENSITY_PRED_Y15+,INTENSITY_PRED_Y15++,INTENSITY_PRED_Y15+++,INTENSITY_PRED_B15+,INTENSITY_PRED_B15++,INTENSITY_PRED_B15+++,INTENSITY_PRED_Y16+,INTENSITY_PRED_Y16++,INTENSITY_PRED_Y16+++,INTENSITY_PRED_B16+,INTENSITY_PRED_B16++,INTENSITY_PRED_B16+++,INTENSITY_PRED_Y17+,INTENSITY_PRED_Y17++,INTENSITY_PRED_Y17+++,INTENSITY_PRED_B17+,INTENSITY_PRED_B17++,INTENSITY_PRED_B17+++,INTENSITY_PRED_Y18+,INTENSITY_PRED_Y18++,
INTENSITY_PRED_Y18+++,INTENSITY_PRED_B18+,INTENSITY_PRED_B18++,INTENSITY_PRED_B18+++,INTENSITY_PRED_Y19+,INTENSITY_PRED_Y19++,INTENSITY_PRED_Y19+++,INTENSITY_PRED_B19+,INTENSITY_PRED_B19++,INTENSITY_PRED_B19+++,INTENSITY_PRED_Y20+,INTENSITY_PRED_Y20++,INTENSITY_PRED_Y20+++,INTENSITY_PRED_B20+,INTENSITY_PRED_B20++,INTENSITY_PRED_B20+++,INTENSITY_PRED_Y21+,INTENSITY_PRED_Y21++,INTENSITY_PRED_Y21+++,INTENSITY_PRED_B21+,INTENSITY_PRED_B21++,INTENSITY_PRED_B21+++,INTENSITY_PRED_Y22+,INTENSITY_PRED_Y22++,INTENSITY_PRED_Y22+++,INTENSITY_PRED_B22+,INTENSITY_PRED_B22++,INTENSITY_PRED_B22+++,INTENSITY_PRED_Y23+,INTENSITY_PRED_Y23++,INTENSITY_PRED_Y23+++,INTENSITY_PRED_B23+,INTENSITY_PRED_B23++,INTENSITY_PRED_B23+++,INTENSITY_PRED_Y24+,INTENSITY_PRED_Y24++,INTENSITY_PRED_Y24+++,INTENSITY_PRED_B24+,INTENSITY_PRED_B24++,INTENSITY_PRED_B24+++,INTENSITY_PRED_Y25+,INTENSITY_PRED_Y25++,INTENSITY_PRED_Y25+++,INTENSITY_PRED_B25+,INTENSITY_PRED_B25++,INTENSITY_PRED_B25+++,INTENSITY_PRED_Y26+,INTENSITY_PRED_Y26++,INTENSITY_PRED_Y26+++,INTENSITY_PRED_B26+,INTENSITY_PRED_B26++,INTENSITY_PRED_B26+++,INTENSITY_PRED_Y27+,INTENSITY_PRED_Y27++,INTENSITY_PRED_Y27+++,INTENSITY_PRED_B27+,INTENSITY_PRED_B27++,INTENSITY_PRED_B27+++,INTENSITY_PRED_Y28+,INTENSITY_PRED_Y28++,INTENSITY_PRED_Y28+++,INTENSITY_PRED_B28+,INTENSITY_PRED_B28++,INTENSITY_PRED_B28+++,INTENSITY_PRED_Y29+,INTENSITY_PRED_Y29++,INTENSITY_PRED_Y29+++,INTENSITY_PRED_B29+,INTENSITY_PRED_B29++,INTENSITY_PRED_B29+++,PREDICTED_IRT +[UNIMOD:737]-PEPTIDEK[UNIMOD:737],30,2,HCD,0.23809576034545898,8.493546488352877e-08,0,8.493546488352877e-08,8.493546488352877e-08,0,0.8115681409835815,8.493546488352877e-08,0,1.0,8.493546488352877e-08,0,0.16860607266426086,8.493546488352877e-08,0,0.009546236135065556,8.493546488352877e-08,0,0.05823937803506851,8.493546488352877e-08,0,0.0676455870270729,8.493546488352877e-08,0,0.039169538766145706,8.493546488352877e-08,0,0.07002319395542145,8.493546488352877e-08,0,0.6715999841690063,0.01274949125945568,0,0.5392391085624695,8.493546488352877e-08,0,0.02867981791496277,0.00021903926972299814,0,0.07469962537288666,8.493546488352877e-08,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,39.420433 +[UNIMOD:737]-PEPTIDE,30,2,HCD,0.522221565246582,1.0271855899190996e-07,0,1.0271855899190996e-07,1.0271855899190996e-07,0,0.6478652358055115,1.0271855899190996e-07,0,0.6932786107063293,1.0271855899190996e-07,0,1.0271855899190996e-07,1.0271855899190996e-07,0,0.03741396963596344,0.04295269027352333,0,1.0271855899190996e-07,1.0271855899190996e-07,0,1.0,0.016776476055383682,0,0.024784639477729797,1.0271855899190996e-07,0,0.47810620069503784,1.0271855899190996e-07,0,1.0271855899190996e-07,1.0271855899190996e-07,0,0.6293804049491882,0.0024885176680982113,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,43.523045 
+[UNIMOD:737]-M[UNIMOD:35]EC[UNIMOD:4]TIDEK[UNIMOD:737],35,1,CID,0.2831476628780365,0,0,1.3209350413490029e-07,0,0,1.0,0,0,1.3209350413490029e-07,0,0,0.38974782824516296,0,0,0.018663309514522552,0,0,0.16417913138866425,0,0,0.019735727459192276,0,0,0.27998510003089905,0,0,0.009887314401566982,0,0,0.48476091027259827,0,0,0.009575176984071732,0,0,0.26375487446784973,0,0,0.019234927371144295,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,24.778168 diff --git a/tests/unit_tests/test_predictions.py b/tests/unit_tests/test_predictions.py index 4d03f1f9..d329b9d1 100644 --- a/tests/unit_tests/test_predictions.py +++ b/tests/unit_tests/test_predictions.py @@ -4,7 +4,8 @@ import pandas as pd from oktoberfest.data import Spectra -from oktoberfest.pr import grpc_predict +from oktoberfest.data.spectra import FragmentType +from oktoberfest.pr import predict class TestTMTProsit(unittest.TestCase): @@ -13,16 +14,22 @@ class TestTMTProsit(unittest.TestCase): def test_prosit_tmt(self): """Test retrieval of predictions from prosit tmt models via koina.""" library = Spectra.from_csv(Path(__file__).parent / "data" / "predictions" / "library_input.csv") - grpc_predict( - library=library, - url="koina.proteomicsdb.org:443", - intensity_model="Prosit_2020_intensity_TMT", - irt_model="Prosit_2020_irt_TMT", + input_data = library.spectra_data + + pred_intensities = predict( + input_data, + model_name="Prosit_2020_intensity_TMT", + server_url="koina.proteomicsdb.org:443", ssl=True, - alignment=False, - job_type="", + targets=["intensities", "annotation"], + ) + pred_irt = predict( + input_data, model_name="Prosit_2020_irt_TMT", server_url="koina.proteomicsdb.org:443", ssl=True ) + library.add_matrix(pd.Series(pred_intensities["intensities"].tolist(), name="intensities"), FragmentType.PRED) + library.add_column(pred_irt["irt"], name="PREDICTED_IRT") + expected_df = pd.read_csv(Path(__file__).parent / "data" / "predictions" / "library_output.csv") sparse_cols = [col for col in library.spectra_data.columns if col.startswith("INTENSITY_PRED")] for sparse_col in sparse_cols: