Skip to content

Commit

Permalink
fix ruff format
Browse files Browse the repository at this point in the history
  • Loading branch information
kdziedzic68 committed Nov 12, 2024
1 parent 6d82936 commit 56f3611
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def _parse_pipeline_steps(self) -> list[Step]:
llm = get_cls_from_config(llm_config.provider_type, module)(**llm_kwargs)
task_kwargs: dict[Any, Any] = {"llm": llm}
if getattr(task_config, "kwargs", None):
task_kwargs.update(OmegaConf.to_container(task_config.kwargs)) #type: ignore
task_kwargs.update(OmegaConf.to_container(task_config.kwargs)) # type: ignore
task = get_cls_from_config(task_config.type, module)(**task_kwargs)
tasks.append(task)
if getattr(task_config, "filters", None):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def process(self, *inputs: StepInput) -> "StepOutput":
for inp in inputs[0]:
for _ in range(self._num_per_query):
new_inp = deepcopy(inp)
prompt_inp = self._prompt_class.input_type(**{self.inputs[0]: new_inp[self.inputs[0]]}) #type: ignore
prompt_inp = self._prompt_class.input_type(**{self.inputs[0]: new_inp[self.inputs[0]]}) # type: ignore
new_inp[self.outputs[0]] = asyncio.get_event_loop().run_until_complete(
self._llm.generate(prompt=self._prompt_class(prompt_inp))
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def format_input(self, input: dict[str, Any]) -> ChatFormat:
Returns:
The formatted chat object containing the input for query generation.
"""
chat = self._prompt_class(self._prompt_class.input_type(**input)).chat #type: ignore
chat = self._prompt_class(self._prompt_class.input_type(**input)).chat # type: ignore
return chat

@abstractmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def format_output(self, output: str, input: dict[str, Any] | None = None) -> dic
Returns:
A dictionary containing "chunk" and "question".
"""
return {"chunk": input["chunk"], "question": output} #type: ignore
return {"chunk": input["chunk"], "question": output} # type: ignore


class PassagesGenTask(BaseDistilabelTask):
Expand Down Expand Up @@ -63,15 +63,15 @@ def format_output(self, output: str, input: dict[str, Any] | None = None) -> dic
matched_passages: list[str] = []

for passage in passages:
if passage in input["chunk"]: #type: ignore
if passage in input["chunk"]: # type: ignore
matched_passages.append(passage)
else:
matched_passage = get_closest_substring(input["chunk"], passage) #type: ignore
matched_passage = get_closest_substring(input["chunk"], passage) # type: ignore
matched_passages.append(matched_passage)

return {"chunk": input["chunk"], "question": input["question"], "passages": matched_passages} #type: ignore
return {"chunk": input["chunk"], "question": input["question"], "passages": matched_passages} # type: ignore

return {"chunk": input["chunk"], "question": input["question"], "passages": passages} #type: ignore
return {"chunk": input["chunk"], "question": input["question"], "passages": passages} # type: ignore


class AnswerGenTask(BaseDistilabelTask):
Expand Down

0 comments on commit 56f3611

Please sign in to comment.