diff --git a/caikit/interfaces/nlp/data_model/embedding_vectors.py b/caikit/interfaces/nlp/data_model/embedding_vectors.py
index 30163f142..ec9479e17 100644
--- a/caikit/interfaces/nlp/data_model/embedding_vectors.py
+++ b/caikit/interfaces/nlp/data_model/embedding_vectors.py
@@ -11,10 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Data structures for embedding vector representations
-"""
+"""Data structures for embedding vector representations"""
+
 # Standard
 from dataclasses import dataclass
+from typing import Optional
 
 # First Party
 from py_to_proto.dataclass_to_proto import Annotated, FieldNumber
@@ -36,6 +37,7 @@ class EmbeddingResult(DataObjectBase):
 
     result: Annotated[Vector1D, FieldNumber(1)]
     producer_id: Annotated[ProducerId, FieldNumber(2)]
+    input_token_count: Annotated[Optional[int], FieldNumber(3)]
 
 
 @dataobject(package="caikit_data_model.caikit_nlp")
@@ -45,3 +47,4 @@ class EmbeddingResults(DataObjectBase):
 
     results: Annotated[ListOfVector1D, FieldNumber(1)]
     producer_id: Annotated[ProducerId, FieldNumber(2)]
+    input_token_count: Annotated[Optional[int], FieldNumber(3)]
diff --git a/caikit/interfaces/nlp/data_model/reranker.py b/caikit/interfaces/nlp/data_model/reranker.py
index be700cafe..5844169dc 100644
--- a/caikit/interfaces/nlp/data_model/reranker.py
+++ b/caikit/interfaces/nlp/data_model/reranker.py
@@ -54,6 +54,7 @@ class RerankResult(DataObjectBase):
 
     result: Annotated[RerankScores, FieldNumber(1)]
     producer_id: Annotated[ProducerId, FieldNumber(2)]
+    input_token_count: Annotated[Optional[int], FieldNumber(3)]
 
 
 @dataobject(package="caikit_data_model.caikit_nlp")
@@ -64,3 +65,4 @@ class RerankResults(DataObjectBase):
 
     results: Annotated[List[RerankScores], FieldNumber(1)]
     producer_id: Annotated[ProducerId, FieldNumber(2)]
+    input_token_count: Annotated[Optional[int], FieldNumber(3)]
diff --git a/caikit/interfaces/nlp/data_model/sentence_similarity.py b/caikit/interfaces/nlp/data_model/sentence_similarity.py
index 5c908a1dc..20fabd0d9 100644
--- a/caikit/interfaces/nlp/data_model/sentence_similarity.py
+++ b/caikit/interfaces/nlp/data_model/sentence_similarity.py
@@ -11,10 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Data structures for embedding vector representations
-"""
+"""Data structures for embedding vector representations"""
+
 # Standard
-from typing import List
+from typing import List, Optional
 
 # First Party
 from py_to_proto.dataclass_to_proto import Annotated, FieldNumber
@@ -42,6 +42,7 @@ class SentenceSimilarityResult(DataObjectBase):
 
     result: Annotated[SentenceSimilarityScores, FieldNumber(1)]
     producer_id: Annotated[ProducerId, FieldNumber(2)]
+    input_token_count: Annotated[Optional[int], FieldNumber(3)]
 
 
 @dataobject(package="caikit_data_model.caikit_nlp")
@@ -50,3 +51,4 @@ class SentenceSimilarityResults(DataObjectBase):
 
     results: Annotated[List[SentenceSimilarityScores], FieldNumber(1)]
     producer_id: Annotated[ProducerId, FieldNumber(2)]
+    input_token_count: Annotated[Optional[int], FieldNumber(3)]
diff --git a/tests/interfaces/nlp/test_reranker.py b/tests/interfaces/nlp/test_reranker.py
index 16840ef25..8d916c6fc 100644
--- a/tests/interfaces/nlp/test_reranker.py
+++ b/tests/interfaces/nlp/test_reranker.py
@@ -11,8 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Test for reranker
-"""
+"""Test for reranker"""
 
 # Standard
 import random
@@ -100,12 +99,18 @@ def input_scores2(input_random_score, input_random_score_3):
 
 @pytest.fixture
 def input_result_1(input_scores):
-    return {"result": dm.RerankScores(query="foo", scores=input_scores)}
+    return {
+        "result": dm.RerankScores(query="foo", scores=input_scores),
+        "input_token_count": 0,
+    }
 
 
 @pytest.fixture
 def input_result_2(input_scores2):
-    return {"result": dm.RerankScores(query="bar", scores=input_scores2)}
+    return {
+        "result": dm.RerankScores(query="bar", scores=input_scores2),
+        "input_token_count": 0,
+    }
 
 
 @pytest.fixture
@@ -114,7 +119,8 @@ def input_results(input_scores, input_scores2):
         "results": [
             dm.RerankScores(query="foo", scores=input_scores),
             dm.RerankScores(query="bar", scores=input_scores2),
-        ]
+        ],
+        "input_token_count": 0,
     }
 
 
@@ -125,7 +131,10 @@ def input_sentence_similarity_scores_1():
 
 @pytest.fixture
 def input_sentence_similarity_result(input_sentence_similarity_scores_1):
-    return {"result": dm.SentenceSimilarityScores(**input_sentence_similarity_scores_1)}
+    return {
+        "result": dm.SentenceSimilarityScores(**input_sentence_similarity_scores_1),
+        "input_token_count": 0,
+    }
 
 
 @pytest.fixture
@@ -145,7 +154,7 @@ def input_sentence_similarities_scores(
 
 @pytest.fixture
 def input_sentence_similarity_results(input_sentence_similarities_scores):
-    return {"results": input_sentence_similarities_scores}
+    return {"results": input_sentence_similarities_scores, "input_token_count": 0}
 
 
 ## Tests ########################################################################
@@ -162,7 +171,7 @@ def input_sentence_similarity_results(input_sentence_similarities_scores):
         (dm.SentenceSimilarityResults, "input_sentence_similarity_results"),
     ],
 )
-def test_data_object(data_object, inputs, request):
+def test_data_object(data_object, inputs, request: pytest.FixtureRequest):
     # Init data object
     fixture_values = request.getfixturevalue(inputs)
     new_do_from_init = data_object(**fixture_values)