From bc912c42d6138dd501e56bcffe6fae1d703d344f Mon Sep 17 00:00:00 2001 From: nik Date: Fri, 16 Feb 2024 12:47:19 +0000 Subject: [PATCH] Fix default batch size, add docstrings --- adala/agents/base.py | 1 - adala/environments/kafka.py | 27 ++++++++++++++ adala/runtimes/base.py | 2 +- adala/server/app.py | 74 ++++++++++++++++++++++++++++++++++++- 4 files changed, 101 insertions(+), 3 deletions(-) diff --git a/adala/agents/base.py b/adala/agents/base.py index f638f53..f4d8841 100644 --- a/adala/agents/base.py +++ b/adala/agents/base.py @@ -42,7 +42,6 @@ class Agent(BaseModel, ABC): >>> agent = Agent(skills=LinearSkillSet(skills=[TransformSkill()]), environment=StaticEnvironment()) >>> agent.learn() # starts the learning process >>> predictions = agent.run() # runs the agent and returns the predictions - """ environment: Optional[Union[Environment, AsyncEnvironment]] = None diff --git a/adala/environments/kafka.py b/adala/environments/kafka.py index 8dd4a5c..3435a43 100644 --- a/adala/environments/kafka.py +++ b/adala/environments/kafka.py @@ -15,6 +15,16 @@ class AsyncKafkaEnvironment(AsyncEnvironment): + """ + Represents an asynchronous Kafka environment: + - agent can retrieve data batch by batch from the input topic + - agent can return its predictions to the output topic + + Attributes: + kafka_bootstrap_servers (Union[str, List[str]]): The Kafka bootstrap servers. + kafka_input_topic (str): The Kafka input topic. + kafka_output_topic (str): The Kafka output topic. + """ kafka_bootstrap_servers: Union[str, List[str]] kafka_input_topic: str @@ -78,6 +88,15 @@ async def set_predictions(self, predictions: InternalDataFrame): class FileStreamAsyncKafkaEnvironment(AsyncKafkaEnvironment): + """ + Represents an asynchronous Kafka environment with file stream: + - agent can retrieve data batch by batch from the input topic + - agent can return its predictions to the output topic + - input data is read from `input_file` + - output data is stored to the `output_file` + - errors are saved to the `error_file` + """ + input_file: str output_file: str error_file: str @@ -107,6 +126,10 @@ def _iter_csv_s3(self, s3_uri): yield row async def initialize(self): + """ + Initialize the environment: read data from the input file and push it to the kafka topic. + """ + # TODO: Add support for other file types except CSV, and also for other cloud storage services if self.input_file.startswith("s3://"): csv_reader = self._iter_csv_s3(self.input_file) @@ -121,6 +144,10 @@ async def initialize(self): await self.message_sender(producer, csv_reader, self.kafka_input_topic) async def finalize(self): + """ + Finalize the environment: read data from the output kafka topic and write it to the output file. + """ + consumer = AIOKafkaConsumer( self.kafka_output_topic, bootstrap_servers=self.kafka_bootstrap_servers, diff --git a/adala/runtimes/base.py b/adala/runtimes/base.py index 6ddc008..f5841bf 100644 --- a/adala/runtimes/base.py +++ b/adala/runtimes/base.py @@ -212,7 +212,7 @@ async def batch_to_batch( async def get_next_batch(self, data_iterator, batch_size: Optional[int]) -> InternalDataFrame: if batch_size is None: - batch_size = self.optimal_batch_size + batch_size = self.batch_size batch = [] try: for _ in range(batch_size): diff --git a/adala/server/app.py b/adala/server/app.py index 7f211a7..4db1649 100644 --- a/adala/server/app.py +++ b/adala/server/app.py @@ -9,6 +9,8 @@ app = fastapi.FastAPI() + +# TODO: add a correct middleware policy to handle CORS app.add_middleware( CORSMiddleware, allow_origins=["*"], @@ -33,10 +35,39 @@ def model_dump(self, *args, **kwargs) -> Dict[str, Any]: # type: ignore class JobCreated(BaseModel): + """ + Response model for a job created. + """ job_id: str class SubmitRequest(BaseModel): + """ + Request model for submitting a job. + + Attributes: + agent (Agent): The agent to be used for the task. Example of serialized agent: + { + "skills": [{ + "type": "classification", + "name": "text_classifier", + "instructions": "Classify the text.", + "input_template": "Text: {text}", + "output_template": "Classification result: {label}", + "labels": { + "label": ['label1', 'label2', 'label3'] + } + }], + "runtimes": { + "default": { + "type": "openai-chat", + "model": "gpt-3.5-turbo", + "api_key": "..." + } + } + } + task_name (str): The name of the task to be executed by the agent. + """ agent: Agent task_name: str = "process_file" @@ -49,7 +80,13 @@ def get_index(): @app.post("/submit", response_model=Response[JobCreated]) async def submit(request: SubmitRequest): """ - Submit a request to execute task in celery. + Submit a request to execute task `request.task_name` in celery. + + Args: + request (SubmitRequest): The request model for submitting a job. + + Returns: + Response[JobCreated]: The response model for a job created. """ # TODO: get task by name, e.g. request.task_name task = process_file @@ -60,30 +97,65 @@ async def submit(request: SubmitRequest): class JobStatusRequest(BaseModel): + """ + Request model for getting the status of a job. + """ job_id: str class JobStatusResponse(BaseModel): + """ + Response model for getting the status of a job. + + Attributes: + status (str): The status of the job. + processed_total (List[int]): The total number of processed records and the total number of records in job. + Example: [10, 100] means 10% of the completeness. + """ status: str + processed_total: List[int] = Annotated[List[int], AfterValidator(lambda x: len(x) == 2)] @app.get('/get-status') def get_status(request: JobStatusRequest): + """ + Get the status of a job. + + Args: + request (JobStatusRequest): The request model for getting the status of a job. + Returns: + JobStatusResponse: The response model for getting the status of a job. + """ job = process_file.AsyncResult(request.job_id) return Response[JobStatusResponse](data=JobStatusResponse(status=job.status)) class JobCancelRequest(BaseModel): + """ + Request model for cancelling a job. + """ job_id: str class JobCancelResponse(BaseModel): + """ + Response model for cancelling a job. + """ status: str @app.post('/cancel') def cancel_job(request: JobCancelRequest): + """ + Cancel a job. + + Args: + request (JobCancelRequest): The request model for cancelling a job. + + Returns: + JobCancelResponse: The response model for cancelling a job. + """ job = process_file.AsyncResult(request.job_id) job.revoke() return Response[JobCancelResponse](data=JobCancelResponse(status='cancelled'))