Skip to content

Commit

Permalink
Concurrent execution
Browse files Browse the repository at this point in the history
  • Loading branch information
kdziedzic68 committed Nov 19, 2024
1 parent f5f251e commit 9abe2e5
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 17 deletions.
2 changes: 1 addition & 1 deletion examples/evaluation/dataset-generator/config/generate.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
input_name: query
input_name: topic
name: synthetic-RAG-data
tasks:
- type: ragbits.evaluate.dataset_generator.tasks.corpus_generation:CorpusGenerationStep
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,9 @@ def from_dict_config(cls, dict_config: DictConfig) -> "DatasetGenerationPipeline
type=task_config.type,
llm=LLMConfigForTask(
provider_type=task_config.llm.provider_type,
kwargs=OmegaConf.to_container(task_config.llm.kwargs),
kwargs=OmegaConf.to_container(task_config.llm.kwargs), # type: ignore
),
kwargs=OmegaConf.to_container(task_config.kwargs),
kwargs=OmegaConf.to_container(task_config.kwargs), # type: ignore
filters=getattr(task_config, "filters", None),
)
for task_config in dict_config.tasks
Expand Down Expand Up @@ -125,10 +125,10 @@ def _parse_pipeline_steps(self) -> list[Step]:
task_kwargs.update(task_config.kwargs or {}) # type: ignore
task = get_cls_from_config(task_config.type, module)(**task_kwargs)
tasks.append(task)
if getattr(task_config, "filters", None):
for filter_type in task_config.filters:
filter = get_cls_from_config(filter_type, module)(tasks[-1])
tasks.append(filter)
filter_types = getattr(task_config, "filters", None) or []
for filter_type in filter_types:
filter = get_cls_from_config(filter_type, module)(tasks[-1])
tasks.append(filter)
return tasks

def _instantiate_pipeline(self) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def inputs(self) -> list[str]:
Returns:
list of input fields
"""
return ["query"]
return ["topic"]

@property
def outputs(self) -> list[str]:
Expand All @@ -54,13 +54,26 @@ def process(self, *inputs: StepInput) -> "StepOutput":
Returns:
a generated corpus
"""
result = []
for topic in inputs[0]:
for _ in range(self._num_per_query):
new_inp = deepcopy(topic)
prompt_inp = self._prompt_class.input_type(**{self.inputs[0]: new_inp[self.inputs[0]]}) # type: ignore
new_inp[self.outputs[0]] = asyncio.get_event_loop().run_until_complete(
self._llm.generate(prompt=self._prompt_class(prompt_inp))
)
result.append(new_inp)
result = asyncio.get_event_loop().run_until_complete(self._process_topics(topics=inputs[0]))
yield result

async def _process_topics(self, topics: list[dict]) -> list[dict]:
    """
    Generate corpus entries for all topics concurrently.

    Each topic is expanded into ``self._num_per_query`` entries; every
    resulting LLM call is scheduled at once and awaited together with a
    single ``asyncio.gather``.

    NOTE(review): no concurrency cap is applied here — all requests are
    in flight simultaneously. The previous docstring claimed a "batch
    size limit" was respected, but none exists in the code.

    Args:
        topics: the topic dicts to process.

    Returns:
        The processed topic dicts. ``gather`` preserves scheduling order,
        so results are grouped replica-by-replica over the input topics.
    """
    coros = [self._process_topic(topic) for _ in range(self._num_per_query) for topic in topics]
    return await asyncio.gather(*coros)

async def _process_topic(self, topic: dict) -> dict:
    """
    Build a single corpus entry from *topic*.

    Copies the incoming dict, feeds its input field to the prompt class,
    awaits the LLM generation, and stores the result under the output field.

    Args:
        topic: the source topic dict; must contain ``self.inputs[0]``.

    Returns:
        A new dict with the generated text added under ``self.outputs[0]``.
    """
    entry = deepcopy(topic)
    input_field = self.inputs[0]
    prompt_input = self._prompt_class.input_type(**{input_field: entry[input_field]})  # type: ignore
    entry[self.outputs[0]] = await self._llm.generate(prompt=self._prompt_class(prompt_input))
    return entry

0 comments on commit 9abe2e5

Please sign in to comment.