Skip to content
This repository has been archived by the owner on Jul 15, 2024. It is now read-only.

Commit

Permalink
Merge pull request caikit#675 from mynhardtburger/input_token_count
Browse files Browse the repository at this point in the history
Add input token count to embedding, reranker, sentence similarity
  • Loading branch information
evaline-ju authored Mar 6, 2024
2 parents 203b272 + 30f32a7 commit 75e65a8
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 13 deletions.
7 changes: 5 additions & 2 deletions caikit/interfaces/nlp/data_model/embedding_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data structures for embedding vector representations
"""
"""Data structures for embedding vector representations"""

# Standard
from dataclasses import dataclass
from typing import Optional

# First Party
from py_to_proto.dataclass_to_proto import Annotated, FieldNumber
Expand All @@ -36,6 +37,7 @@ class EmbeddingResult(DataObjectBase):

result: Annotated[Vector1D, FieldNumber(1)]
producer_id: Annotated[ProducerId, FieldNumber(2)]
input_token_count: Annotated[Optional[int], FieldNumber(3)]


@dataobject(package="caikit_data_model.caikit_nlp")
Expand All @@ -45,3 +47,4 @@ class EmbeddingResults(DataObjectBase):

results: Annotated[ListOfVector1D, FieldNumber(1)]
producer_id: Annotated[ProducerId, FieldNumber(2)]
input_token_count: Annotated[Optional[int], FieldNumber(3)]
2 changes: 2 additions & 0 deletions caikit/interfaces/nlp/data_model/reranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class RerankResult(DataObjectBase):

result: Annotated[RerankScores, FieldNumber(1)]
producer_id: Annotated[ProducerId, FieldNumber(2)]
input_token_count: Annotated[Optional[int], FieldNumber(3)]


@dataobject(package="caikit_data_model.caikit_nlp")
Expand All @@ -64,3 +65,4 @@ class RerankResults(DataObjectBase):

results: Annotated[List[RerankScores], FieldNumber(1)]
producer_id: Annotated[ProducerId, FieldNumber(2)]
input_token_count: Annotated[Optional[int], FieldNumber(3)]
8 changes: 5 additions & 3 deletions caikit/interfaces/nlp/data_model/sentence_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data structures for embedding vector representations
"""
"""Data structures for embedding vector representations"""

# Standard
from typing import List
from typing import List, Optional

# First Party
from py_to_proto.dataclass_to_proto import Annotated, FieldNumber
Expand Down Expand Up @@ -42,6 +42,7 @@ class SentenceSimilarityResult(DataObjectBase):

result: Annotated[SentenceSimilarityScores, FieldNumber(1)]
producer_id: Annotated[ProducerId, FieldNumber(2)]
input_token_count: Annotated[Optional[int], FieldNumber(3)]


@dataobject(package="caikit_data_model.caikit_nlp")
Expand All @@ -50,3 +51,4 @@ class SentenceSimilarityResults(DataObjectBase):

results: Annotated[List[SentenceSimilarityScores], FieldNumber(1)]
producer_id: Annotated[ProducerId, FieldNumber(2)]
input_token_count: Annotated[Optional[int], FieldNumber(3)]
25 changes: 17 additions & 8 deletions tests/interfaces/nlp/test_reranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test for reranker
"""
"""Test for reranker"""

# Standard
import random
Expand Down Expand Up @@ -100,12 +99,18 @@ def input_scores2(input_random_score, input_random_score_3):

@pytest.fixture
def input_result_1(input_scores):
return {"result": dm.RerankScores(query="foo", scores=input_scores)}
return {
"result": dm.RerankScores(query="foo", scores=input_scores),
"input_token_count": 0,
}


@pytest.fixture
def input_result_2(input_scores2):
return {"result": dm.RerankScores(query="bar", scores=input_scores2)}
return {
"result": dm.RerankScores(query="bar", scores=input_scores2),
"input_token_count": 0,
}


@pytest.fixture
Expand All @@ -114,7 +119,8 @@ def input_results(input_scores, input_scores2):
"results": [
dm.RerankScores(query="foo", scores=input_scores),
dm.RerankScores(query="bar", scores=input_scores2),
]
],
"input_token_count": 0,
}


Expand All @@ -125,7 +131,10 @@ def input_sentence_similarity_scores_1():

@pytest.fixture
def input_sentence_similarity_result(input_sentence_similarity_scores_1):
return {"result": dm.SentenceSimilarityScores(**input_sentence_similarity_scores_1)}
return {
"result": dm.SentenceSimilarityScores(**input_sentence_similarity_scores_1),
"input_token_count": 0,
}


@pytest.fixture
Expand All @@ -145,7 +154,7 @@ def input_sentence_similarities_scores(

@pytest.fixture
def input_sentence_similarity_results(input_sentence_similarities_scores):
return {"results": input_sentence_similarities_scores}
return {"results": input_sentence_similarities_scores, "input_token_count": 0}


## Tests ########################################################################
Expand All @@ -162,7 +171,7 @@ def input_sentence_similarity_results(input_sentence_similarities_scores):
(dm.SentenceSimilarityResults, "input_sentence_similarity_results"),
],
)
def test_data_object(data_object, inputs, request):
def test_data_object(data_object, inputs, request: pytest.FixtureRequest):
# Init data object
fixture_values = request.getfixturevalue(inputs)
new_do_from_init = data_object(**fixture_values)
Expand Down

0 comments on commit 75e65a8

Please sign in to comment.