
Commit

Fix default batch size, add docstrings
nik committed Feb 16, 2024
1 parent 889f78f commit bc912c4
Showing 4 changed files with 101 additions and 3 deletions.
1 change: 0 additions & 1 deletion adala/agents/base.py
@@ -42,7 +42,6 @@ class Agent(BaseModel, ABC):
>>> agent = Agent(skills=LinearSkillSet(skills=[TransformSkill()]), environment=StaticEnvironment())
>>> agent.learn() # starts the learning process
>>> predictions = agent.run() # runs the agent and returns the predictions
"""

environment: Optional[Union[Environment, AsyncEnvironment]] = None
27 changes: 27 additions & 0 deletions adala/environments/kafka.py
@@ -15,6 +15,16 @@


class AsyncKafkaEnvironment(AsyncEnvironment):
"""
Represents an asynchronous Kafka environment:
- agent can retrieve data batch by batch from the input topic
- agent can return its predictions to the output topic
Attributes:
kafka_bootstrap_servers (Union[str, List[str]]): The Kafka bootstrap servers.
kafka_input_topic (str): The Kafka input topic.
kafka_output_topic (str): The Kafka output topic.
"""

kafka_bootstrap_servers: Union[str, List[str]]
kafka_input_topic: str
@@ -78,6 +88,15 @@ async def set_predictions(self, predictions: InternalDataFrame):


class FileStreamAsyncKafkaEnvironment(AsyncKafkaEnvironment):
"""
Represents an asynchronous Kafka environment with file stream:
- agent can retrieve data batch by batch from the input topic
- agent can return its predictions to the output topic
- input data is read from `input_file`
- output data is written to `output_file`
- errors are saved to `error_file`
"""

input_file: str
output_file: str
error_file: str
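
For reference, a minimal sketch of constructing the environment these docstrings describe (keyword-argument construction via the pydantic base model is assumed; the broker address, topic names, and file paths are placeholders):

from adala.environments.kafka import FileStreamAsyncKafkaEnvironment

env = FileStreamAsyncKafkaEnvironment(
    kafka_bootstrap_servers="localhost:9092",   # placeholder broker address
    kafka_input_topic="adala-input",            # topic the agent reads batches from
    kafka_output_topic="adala-output",          # topic the agent writes predictions to
    input_file="s3://my-bucket/input.csv",      # rows from this CSV are pushed to the input topic
    output_file="output.csv",                   # predictions from the output topic are written here
    error_file="errors.csv",                    # failed records are saved here
)
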
@@ -107,6 +126,10 @@ def _iter_csv_s3(self, s3_uri):
yield row

async def initialize(self):
"""
Initialize the environment: read data from the input file and push it to the Kafka input topic.
"""

# TODO: Add support for file types other than CSV, and for other cloud storage services
if self.input_file.startswith("s3://"):
csv_reader = self._iter_csv_s3(self.input_file)
@@ -121,6 +144,10 @@ async def initialize(self):
await self.message_sender(producer, csv_reader, self.kafka_input_topic)

async def finalize(self):
"""
Finalize the environment: read data from the output Kafka topic and write it to the output file.
"""

consumer = AIOKafkaConsumer(
self.kafka_output_topic,
bootstrap_servers=self.kafka_bootstrap_servers,
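
A rough sketch of the lifecycle the `initialize`/`finalize` docstrings describe, using the `env` constructed above (the call order is inferred from the method names and is an assumption, not part of this commit):

import asyncio

async def run_file_stream(env):
    # Read input_file (local CSV or s3:// URI) and push its rows to the Kafka input topic.
    await env.initialize()
    # ... the agent consumes the input topic and produces predictions to the output topic ...
    # Drain the output topic and write the collected records to output_file.
    await env.finalize()

# asyncio.run(run_file_stream(env))
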
2 changes: 1 addition & 1 deletion adala/runtimes/base.py
@@ -212,7 +212,7 @@ async def batch_to_batch(

async def get_next_batch(self, data_iterator, batch_size: Optional[int]) -> InternalDataFrame:
if batch_size is None:
batch_size = self.optimal_batch_size
batch_size = self.batch_size
batch = []
try:
for _ in range(batch_size):
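
The one-line fix above changes the fallback from `self.optimal_batch_size` to the runtime's configured `self.batch_size`. A standalone sketch of the same batching pattern (hypothetical names, not the actual runtime class; the default of 100 is an arbitrary placeholder):

from typing import Iterator, List, Optional

def get_next_batch(data_iterator: Iterator[dict], batch_size: Optional[int], default_batch_size: int = 100) -> List[dict]:
    # Fall back to the configured default when no explicit batch size is given.
    if batch_size is None:
        batch_size = default_batch_size
    batch: List[dict] = []
    try:
        for _ in range(batch_size):
            batch.append(next(data_iterator))
    except StopIteration:
        pass  # iterator exhausted; return the partial (possibly empty) batch
    return batch
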
74 changes: 73 additions & 1 deletion adala/server/app.py
@@ -9,6 +9,8 @@


app = fastapi.FastAPI()

# TODO: add a correct middleware policy to handle CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
@@ -33,10 +35,39 @@ def model_dump(self, *args, **kwargs) -> Dict[str, Any]: # type: ignore


class JobCreated(BaseModel):
"""
Response model for a job created.
"""
job_id: str


class SubmitRequest(BaseModel):
"""
Request model for submitting a job.
Attributes:
agent (Agent): The agent to be used for the task. Example of serialized agent:
{
"skills": [{
"type": "classification",
"name": "text_classifier",
"instructions": "Classify the text.",
"input_template": "Text: {text}",
"output_template": "Classification result: {label}",
"labels": {
"label": ['label1', 'label2', 'label3']
}
}],
"runtimes": {
"default": {
"type": "openai-chat",
"model": "gpt-3.5-turbo",
"api_key": "..."
}
}
}
task_name (str): The name of the task to be executed by the agent.
"""
agent: Agent
task_name: str = "process_file"
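
As a usage sketch, a job matching the serialized-agent example in the docstring above could be submitted to `/submit` (host, port, and API key are placeholders; nesting the `JobCreated` payload under a `data` field is an assumption about how the `Response` wrapper serializes):

import requests

payload = {
    "agent": {
        "skills": [{
            "type": "classification",
            "name": "text_classifier",
            "instructions": "Classify the text.",
            "input_template": "Text: {text}",
            "output_template": "Classification result: {label}",
            "labels": {"label": ["label1", "label2", "label3"]},
        }],
        "runtimes": {
            "default": {
                "type": "openai-chat",
                "model": "gpt-3.5-turbo",
                "api_key": "...",  # placeholder
            }
        },
    },
    "task_name": "process_file",
}

resp = requests.post("http://localhost:8000/submit", json=payload)
job_id = resp.json()["data"]["job_id"]  # assumes Response[JobCreated] nests the payload under "data"
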

@@ -49,7 +80,13 @@ def get_index():
@app.post("/submit", response_model=Response[JobCreated])
async def submit(request: SubmitRequest):
"""
Submit a request to execute task in celery.
Submit a request to execute task `request.task_name` in celery.
Args:
request (SubmitRequest): The request model for submitting a job.
Returns:
Response[JobCreated]: The response model for a job created.
"""
# TODO: get task by name, e.g. request.task_name
task = process_file
@@ -60,30 +97,65 @@ async def submit(request: SubmitRequest):


class JobStatusRequest(BaseModel):
"""
Request model for getting the status of a job.
"""
job_id: str


class JobStatusResponse(BaseModel):
"""
Response model for getting the status of a job.
Attributes:
status (str): The status of the job.
processed_total (List[int]): The number of processed records and the total number of records in the job.
Example: [10, 100] means the job is 10% complete.
"""
status: str
processed_total: List[int] = Annotated[List[int], AfterValidator(lambda x: len(x) == 2)]
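
A worked example of the `processed_total` convention from the docstring above (the percentage computation is illustrative, not part of this commit):

processed, total = [10, 100]                 # processed_total value from JobStatusResponse
percent_complete = 100 * processed / total   # 10.0 -> the job is 10% complete
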


@app.get('/get-status')
def get_status(request: JobStatusRequest):
"""
Get the status of a job.
Args:
request (JobStatusRequest): The request model for getting the status of a job.
Returns:
JobStatusResponse: The response model for getting the status of a job.
"""
job = process_file.AsyncResult(request.job_id)
return Response[JobStatusResponse](data=JobStatusResponse(status=job.status))
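
Continuing the client sketch from `/submit`, the status could be polled with the returned `job_id` (note the route is declared as GET but reads a `JobStatusRequest` body, so the client sends one; host and port are placeholders):

status_resp = requests.get(
    "http://localhost:8000/get-status",
    json={"job_id": job_id},   # JobStatusRequest body
)
print(status_resp.json())      # carries the JobStatusResponse fields, e.g. status
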


class JobCancelRequest(BaseModel):
"""
Request model for cancelling a job.
"""
job_id: str


class JobCancelResponse(BaseModel):
"""
Response model for cancelling a job.
"""
status: str


@app.post('/cancel')
def cancel_job(request: JobCancelRequest):
"""
Cancel a job.
Args:
request (JobCancelRequest): The request model for cancelling a job.
Returns:
JobCancelResponse: The response model for cancelling a job.
"""
job = process_file.AsyncResult(request.job_id)
job.revoke()
return Response[JobCancelResponse](data=JobCancelResponse(status='cancelled'))
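
And, continuing the same client sketch, a job could be cancelled via `/cancel` (placeholder host and port):

cancel_resp = requests.post("http://localhost:8000/cancel", json={"job_id": job_id})
print(cancel_resp.json())   # expected to report status "cancelled"
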
