Skip to content

Commit

Permalink
make reference_context optional (run-llama#9266)
Browse files Browse the repository at this point in the history
* make reference_context optional

* lint

* make entry to chlog
  • Loading branch information
nerdai authored Dec 2, 2023
1 parent 1f9ba34 commit a89a4c7
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 11 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

### New Features

- Make `reference_contexts` optional in `LabelledRagDataset` (#9266)
- Re-organize `download` module (#9253)
- Added document management to ingestion pipeline (#9135)
- Add docs for `LabelledRagDataset` (#9228)

Expand Down
20 changes: 17 additions & 3 deletions llama_index/llama_dataset/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
from abc import abstractmethod
from enum import Enum
from typing import List, Optional, Type
from typing import List, Optional, Type, Union

import tqdm
from pandas import DataFrame as PandasDataFrame
Expand Down Expand Up @@ -58,10 +58,17 @@ def class_name(self) -> str:

class BaseLlamaPredictionDataset(BaseModel):
_prediction_type: Type[BaseLlamaExamplePrediction] = BaseLlamaExamplePrediction # type: ignore[misc]
predictions: Optional[List[BaseLlamaExamplePrediction]] = Field(
default=None, description="Predictions on train_examples."
predictions: List[BaseLlamaExamplePrediction] = Field(
default=list, description="Predictions on train_examples."
)

def __getitem__(self, val: Union[slice, int]) -> List[BaseLlamaExamplePrediction]:
"""Enable slicing and indexing.
Returns the desired slice on `predictions`.
"""
return self.predictions[val]

@abstractmethod
def to_pandas(self) -> PandasDataFrame:
"""Create pandas dataframe."""
Expand Down Expand Up @@ -99,6 +106,13 @@ class BaseLlamaDataset(BaseModel):
default=[], description="Data examples of this dataset."
)

def __getitem__(self, val: Union[slice, int]) -> List[BaseLlamaDataExample]:
"""Enable slicing and indexing.
Returns the desired slice on `examples`.
"""
return self.examples[val]

@abstractmethod
def to_pandas(self) -> PandasDataFrame:
"""Create pandas dataframe."""
Expand Down
18 changes: 10 additions & 8 deletions llama_index/llama_dataset/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,17 @@ class RagExamplePrediction(BaseLlamaExamplePrediction):
"""RAG example prediction class.
Args:
response: str
contexts: List[str]
response (str): The response generated by the LLM.
contexts (Optional[List[str]]): The retrieved context (text) for generating
response.
"""

response: str = Field(
default_factory=str,
description="The generated (predicted) response that can be compared to a reference (ground-truth) answer.",
)
contexts: List[str] = Field(
default_factory=List,
contexts: Optional[List[str]] = Field(
default_factory=None,
description="The contexts in raw text form used to generate the response.",
)

Expand All @@ -45,10 +46,11 @@ class LabelledRagDataExample(BaseLlamaDataExample):
Args:
query (str): The user query
kind (LlamaRagDataExampleKind): The example is generated by human or ai
reference_contexts (List[str] or List[TextNode]): The contexts used for response
query_by (CreatedBy): Query generated by human or ai (model-name)
reference_contexts (Optional[List[str]]): The contexts used for response
reference_answer ([str]): Reference answer to the query. An answer
that would receive full marks upon evaluation.
reference_answer_by: The reference answer generated by human or ai (model-name).
"""

query: str = Field(
Expand All @@ -57,8 +59,8 @@ class LabelledRagDataExample(BaseLlamaDataExample):
query_by: Optional[CreatedBy] = Field(
default=None, description="What generated the query."
)
reference_contexts: List[str] = Field(
default_factory=List,
reference_contexts: Optional[List[str]] = Field(
default_factory=None,
description="The contexts used to generate the reference answer.",
)
reference_answer: str = Field(
Expand Down

0 comments on commit a89a4c7

Please sign in to comment.