Fix errors with server runs
nik committed Feb 19, 2024
1 parent e17a2a1 commit 88c84f4
Showing 11 changed files with 114 additions and 58 deletions.
20 changes: 13 additions & 7 deletions adala/agents/base.py
@@ -1,4 +1,5 @@
from pydantic import BaseModel, Field, SkipValidation, field_validator, model_validator
import logging
from pydantic import BaseModel, Field, SkipValidation, field_validator, model_validator, SerializeAsAny
from abc import ABC, abstractmethod
from typing import Any, Optional, List, Dict, Union, Tuple
from rich import print
@@ -21,6 +22,8 @@
)
from adala.utils.internal_data import InternalDataFrame

logger = logging.getLogger(__name__)


class Agent(BaseModel, ABC):
"""
@@ -45,11 +48,11 @@ class Agent(BaseModel, ABC):
>>> predictions = agent.run() # runs the agent and returns the predictions
"""

environment: Optional[Union[Environment, AsyncEnvironment]] = None
skills: SkillSet
environment: Optional[SerializeAsAny[Union[Environment, AsyncEnvironment]]] = None
skills: Union[Skill, SkillSet]

memory: Memory = Field(default=None)
runtimes: Dict[str, Union[Runtime, AsyncRuntime]] = Field(
runtimes: Dict[str, SerializeAsAny[Union[Runtime, AsyncRuntime]]] = Field(
default_factory=lambda: {
"default": GuidanceRuntime()
# 'openai': OpenAIChatRuntime(model='gpt-3.5-turbo'),
@@ -62,7 +65,7 @@ class Agent(BaseModel, ABC):
# )
}
)
teacher_runtimes: Dict[str, Runtime] = Field(
teacher_runtimes: Dict[str, SerializeAsAny[Runtime]] = Field(
default_factory=lambda: {
"default": OpenAIChatRuntime(model="gpt-3.5-turbo"),
# 'openai-gpt4': OpenAIChatRuntime(model='gpt-4')
@@ -100,6 +103,7 @@ def environment_validator(cls, v) -> Environment:
Validates and possibly transforms the environment attribute:
if the environment is an InternalDataFrame, it is transformed into a StaticEnvironment.
"""
logger.debug(f"Validating environment attribute: {v}")
if isinstance(v, InternalDataFrame):
v = StaticEnvironment(df=v)
elif isinstance(v, dict) and "type" in v:
@@ -118,7 +122,7 @@ def skills_validator(cls, v) -> SkillSet:
elif isinstance(v, list):
return LinearSkillSet(skills=v)
else:
raise ValueError(f"skills must be of type SkillSet or Skill, not {type(v)}")
raise ValueError(f"skills must be of type SkillSet or Skill, but received type {type(v)}")

@field_validator('runtimes', mode='before')
def runtimes_validator(cls, v) -> Dict[str, Union[Runtime, AsyncRuntime]]:
@@ -132,7 +136,8 @@ def runtimes_validator(cls, v) -> Dict[str, Union[Runtime, AsyncRuntime]]:
raise ValueError(
f"Runtime {runtime_name} must have a 'type' field to specify the runtime type."
)
runtime_value = Runtime.create_from_registry(runtime_value.pop('type'), **runtime_value)
type_name = runtime_value.pop("type")
runtime_value = Runtime.create_from_registry(type=type_name, **runtime_value)
out[runtime_name] = runtime_value
return out

@@ -262,6 +267,7 @@ async def arun(
try:
data_batch = await self.environment.get_data_batch(batch_size=runtime.batch_size)
if data_batch.empty:
print_text("No more data in the environment. Exiting.")
break
except Exception as e:
# TODO: environment should raise a specific exception + log error
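
Note: the SerializeAsAny wrappers address a pydantic v2 behavior: a field is serialized by its declared annotation, so subclass-only attributes of a concrete Runtime or Environment would silently disappear from model_dump(). A minimal sketch of the difference, using simplified stand-in classes rather than the repository's actual models:

from pydantic import BaseModel, SerializeAsAny

class StubRuntime(BaseModel):
    verbose: bool = False

class StubOpenAIRuntime(StubRuntime):
    model: str = "gpt-3.5-turbo"

class PlainAgent(BaseModel):
    runtime: StubRuntime                    # serialized as the declared base type

class DuckTypedAgent(BaseModel):
    runtime: SerializeAsAny[StubRuntime]    # serialized as the concrete subclass

rt = StubOpenAIRuntime()
print(PlainAgent(runtime=rt).model_dump())      # {'runtime': {'verbose': False}}
print(DuckTypedAgent(runtime=rt).model_dump())  # {'runtime': {'verbose': False, 'model': 'gpt-3.5-turbo'}}
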
18 changes: 18 additions & 0 deletions adala/environments/base.py
@@ -62,6 +62,24 @@ class Environment(BaseModelInRegistry):
dataset conversion, and state persistence.
"""

@abstractmethod
def initialize(self):
"""
Initialize the environment, e.g. by connecting to a database, reading a file into memory, or starting a stream.
Raises:
NotImplementedError: If the subclass does not implement this method.
"""

@abstractmethod
def finalize(self):
"""
Finalize the environment, e.g. by closing a database connection, writing memory to a file, or stopping a stream.
Raises:
NotImplementedError: If the subclass does not implement this method.
"""

@abstractmethod
def get_data_batch(self, batch_size: Optional[int]) -> InternalDataFrame:
"""
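
Note: concrete environments are now expected to provide both hooks, and the Celery task below calls them around agent.arun(). A hypothetical, deliberately simplified subclass sketching the intended lifecycle (plain Python for brevity; the real classes are pydantic models registered via BaseModelInRegistry):

import pandas as pd  # assumed available; adala's InternalDataFrame is pandas-based

class CsvEnvironment:
    """Hypothetical environment that reads a CSV once and serves it in batches."""

    def __init__(self, path: str):
        self.path = path
        self.df = None

    def initialize(self):
        # acquire resources up front, e.g. load the whole file into memory
        self.df = pd.read_csv(self.path)

    def finalize(self):
        # release resources once the run is over
        self.df = None

    def get_data_batch(self, batch_size=None):
        batch, self.df = self.df.iloc[:batch_size], self.df.iloc[batch_size:]
        return batch
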
10 changes: 6 additions & 4 deletions adala/environments/kafka.py
@@ -1,3 +1,4 @@
import logging
import abc
import boto3
import json
@@ -70,7 +71,8 @@ async def get_data_batch(self, batch_size: Optional[int]) -> InternalDataFrame:
self.kafka_input_topic,
bootstrap_servers=self.kafka_bootstrap_servers,
value_deserializer=lambda v: json.loads(v.decode('utf-8')),
auto_offset_reset='earliest'
auto_offset_reset='earliest',
group_id='adala-consumer-group' # TODO: make it configurable based on the environment
)

data_stream = self.message_receiver(consumer)
@@ -82,8 +84,8 @@ async def set_predictions(self, predictions: InternalDataFrame):
bootstrap_servers=self.kafka_bootstrap_servers,
value_serializer=lambda v: json.dumps(v).encode('utf-8')
)

await self.message_sender(producer, predictions, self.kafka_output_topic)
predictions_iter = (r.to_dict() for _, r in predictions.iterrows())
await self.message_sender(producer, predictions_iter, self.kafka_output_topic)


class FileStreamAsyncKafkaEnvironment(AsyncKafkaEnvironment):
@@ -167,7 +169,7 @@ async def _write_to_csv_fileobj(self, fileobj, data_stream, column_names):
while True:
try:
record = await anext(data_stream)
csv_writer.writerow(record)
csv_writer.writerow({k: record.get(k, '') for k in column_names})
except StopAsyncIteration:
break

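
Note: the writerow change protects the CSV writer from records whose keys do not match the header exactly: csv.DictWriter raises ValueError on unexpected keys, and the comprehension also fills absent columns with an empty string. A self-contained sketch of the same normalization:

import csv
import io

column_names = ["text", "label"]
records = [
    {"text": "hello", "label": "greeting", "score": 0.9},  # extra key would raise in DictWriter
    {"text": "bye"},                                        # missing key becomes an empty cell
]

buf = io.StringIO()
csv_writer = csv.DictWriter(buf, fieldnames=column_names)
csv_writer.writeheader()
for record in records:
    # keep only the declared columns, defaulting missing ones to ''
    csv_writer.writerow({k: record.get(k, "") for k in column_names})
print(buf.getvalue())
# text,label
# hello,greeting
# bye,
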
1 change: 1 addition & 0 deletions adala/runtimes/_openai.py
@@ -439,6 +439,7 @@ async def record_to_record(
field_schema: Optional[Dict] = None,
instructions_first: bool = True,
) -> Dict[str, str]:

raise NotImplementedError("record_to_record is not implemented")


14 changes: 0 additions & 14 deletions adala/runtimes/base.py
@@ -208,17 +208,3 @@ async def batch_to_batch(
instructions_first=instructions_first,
)
return output

async def get_next_batch(self, data_iterator, batch_size: Optional[int]) -> InternalDataFrame:
if batch_size is None:
batch_size = self.batch_size
batch = []
try:
for _ in range(batch_size):
data = await anext(data_iterator, None)
if data is None: # This checks if the iterator is exhausted
break
batch.append(data)
except StopAsyncIteration:
pass
return InternalDataFrame(batch)
9 changes: 7 additions & 2 deletions adala/server/app.py
@@ -1,11 +1,16 @@
import fastapi
import logging
import pickle
from fastapi.middleware.cors import CORSMiddleware
from typing import Generic, TypeVar, Optional, List, Dict, Any
from typing_extensions import Annotated
from pydantic import BaseModel
from pydantic.functional_validators import AfterValidator
from adala.agents import Agent
from adala.server.tasks.process_file import process_file
from log_middleware import LogMiddleware

logger = logging.getLogger(__name__)


app = fastapi.FastAPI()
@@ -18,6 +23,7 @@
allow_headers=["*"],
allow_credentials=True
)
app.add_middleware(LogMiddleware)

ResponseData = TypeVar("ResponseData")

@@ -90,8 +96,7 @@ async def submit(request: SubmitRequest):
"""
# TODO: get task by name, e.g. request.task_name
task = process_file

serialized_agent = request.agent.model_dump_json()
serialized_agent = pickle.dumps(request.agent)
result = task.delay(serialized_agent=serialized_agent)
return Response[JobCreated](data=JobCreated(job_id=result.id))

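
Note: the switch from model_dump_json to pickle matters because the worker previously re-validated the JSON into a fresh Agent, which rebuilds polymorphic fields as their declared base types, whereas pickle restores the identical object graph (both processes must import the same class definitions for unpickling to work). A rough sketch of the difference with stand-in models rather than the real Agent:

import pickle
from typing import Dict
from pydantic import BaseModel

class StubRuntime(BaseModel):
    verbose: bool = False

class StubOpenAIRuntime(StubRuntime):
    model: str = "gpt-3.5-turbo"

class StubAgent(BaseModel):
    runtimes: Dict[str, StubRuntime] = {}

agent = StubAgent(runtimes={"default": StubOpenAIRuntime()})

# JSON round trip re-validates against the declared base annotation ...
rebuilt = StubAgent.model_validate_json(agent.model_dump_json())
print(type(rebuilt.runtimes["default"]).__name__)   # StubRuntime -> concrete class lost

# ... while pickle restores the exact objects.
restored = pickle.loads(pickle.dumps(agent))
print(type(restored.runtimes["default"]).__name__)  # StubOpenAIRuntime
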
2 changes: 1 addition & 1 deletion adala/server/docker-compose.yml
@@ -45,6 +45,6 @@ services:
redis:
image: redis:alpine
ports:
- "6379"
- "6379:6379"
healthcheck:
test: [ "CMD", "redis-cli", "ping" ]
42 changes: 42 additions & 0 deletions adala/server/log_middleware.py
@@ -0,0 +1,42 @@
import json
import logging
from logging import Formatter
from starlette.middleware.base import BaseHTTPMiddleware


class JsonFormatter(Formatter):
def __init__(self):
super(JsonFormatter, self).__init__()

def format(self, record):
json_record = {}
json_record["message"] = record.getMessage()
if "url" in record.__dict__:
json_record["url"] = record.__dict__["url"]
if "method" in record.__dict__:
json_record["method"] = record.__dict__["method"]
if "status_code" in record.__dict__:
json_record["status_code"] = record.__dict__["status_code"]
return json.dumps(json_record)


logger = logging.root
handler = logging.StreamHandler()
handler.setFormatter(JsonFormatter())
logger.handlers = [handler]
logger.setLevel(logging.DEBUG)
logging.getLogger("uvicorn.access").disabled = True


class LogMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request, call_next):
response = await call_next(request)
logger.info(
"Request",
extra={
"method": request.method,
"url": str(request.url),
"status_code": response.status_code
},
)
return response
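
Note: standalone, the formatter and middleware produce one JSON line per request. A small usage sketch showing the expected output (the URL is illustrative):

import logging
from log_middleware import JsonFormatter  # assuming the module above is importable

demo_logger = logging.getLogger("demo")
handler = logging.StreamHandler()
handler.setFormatter(JsonFormatter())
demo_logger.addHandler(handler)
demo_logger.setLevel(logging.INFO)

demo_logger.info(
    "Request",
    extra={"method": "GET", "url": "http://localhost:8000/submit", "status_code": 200},
)
# {"message": "Request", "url": "http://localhost:8000/submit", "method": "GET", "status_code": 200}
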
38 changes: 12 additions & 26 deletions adala/server/tasks/process_file.py
@@ -1,10 +1,10 @@
import asyncio
import json
import pickle
import os
import logging
from celery import Celery
from typing import List
from adala.agents import Agent
from adala.environments.kafka import FileStreamAsyncKafkaEnvironment

logger = logging.getLogger(__name__)

REDIS_URL = os.getenv('REDIS_URL', 'redis://localhost:6379/0')
KAFKA_BOOTSTRAP_SERVERS = os.getenv('KAFKA_BOOTSTRAP_SERVERS', 'localhost:9092')
@@ -14,29 +14,15 @@


@app.task(name='process_file')
def process_file(
input_file: str,
serialized_agent: str,
output_file: str,
error_file: str,
output_columns: List[str]
):
agent = json.loads(serialized_agent)
env = FileStreamAsyncKafkaEnvironment(
kafka_bootstrap_servers=KAFKA_BOOTSTRAP_SERVERS,
kafka_input_topic=KAFKA_INPUT_TOPIC,
kafka_output_topic=KAFKA_OUTPUT_TOPIC
)

# Define an agent
agent = Agent(**json.loads(serialized_agent))
agent.environment = env

# Read data from a file and send it to the Kafka input topic
asyncio.run(env.read_from_file(input_file))
def process_file(serialized_agent: bytes):
# Load the agent
agent = pickle.loads(serialized_agent)
# Initialize the environment, e.g. read data from a file and send it to the Kafka input topic
asyncio.run(agent.environment.initialize())

# run the agent
asyncio.run(agent.arun())

# Finalize the environment, e.g. dump the output to a file
asyncio.run(env.write_to_file(output_file, output_columns))
asyncio.run(agent.environment.finalize())
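
Note: end to end, the task now assumes the caller has already attached a fully configured environment to the agent, since only the pickled agent crosses the Celery boundary. A hypothetical caller-side sketch (the agent and skill configuration is elided, the topic names are illustrative, and passing raw pickle bytes through .delay() presumes the Celery app's serializer accepts them, e.g. task_serializer='pickle'):

import pickle
from adala.environments.kafka import FileStreamAsyncKafkaEnvironment
from adala.server.tasks.process_file import process_file

# `agent` is assumed to be an already-built Agent; any file-related fields the
# FileStream environment needs for initialize()/finalize() are omitted here.
agent.environment = FileStreamAsyncKafkaEnvironment(
    kafka_bootstrap_servers="localhost:9092",
    kafka_input_topic="adala-input",
    kafka_output_topic="adala-output",
)
result = process_file.delay(serialized_agent=pickle.dumps(agent))
print(result.id)  # the job id the API reports back to the client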

4 changes: 2 additions & 2 deletions adala/skills/skillset.py
@@ -32,7 +32,7 @@ class SkillSet(BaseModel, ABC):
skills (Dict[str, Skill]): A dictionary of skills in the skill set.
"""

skills: Dict[str, Skill]
skills: Union[List, Dict[str, Skill]]

@field_validator("skills", mode="before")
def skills_validator(cls, v: Union[List, Dict]) -> Dict[str, Skill]:
@@ -63,7 +63,7 @@ def skills_validator(cls, v: Union[List, Dict]) -> Dict[str, Skill]:
elif isinstance(v, dict):
skills = v
else:
raise ValueError(f"skills must be a list or dictionary, not {type(skills)}")
raise ValueError(f"skills must be a list or dictionary, but received type {type(v)}")
return skills

@abstractmethod
14 changes: 12 additions & 2 deletions adala/utils/registry.py
@@ -1,14 +1,23 @@
import logging
from typing import Optional
from pydantic import BaseModel
from pydantic import BaseModel, field_serializer
from abc import ABC

logger = logging.getLogger(__name__)

_registry = {}


class BaseModelInRegistry(BaseModel, ABC):

type: Optional[str] = None # TODO: this is a workaround for the `type` being represented in OpenAPI schema. If you have a better idea, feel free to fix it

@field_serializer('type')
def serialize_type(self, v: str) -> str:
if v is None:
v = self.__class__.__name__
return v

def __init_subclass__(cls, **kwargs):
global _registry

@@ -27,4 +36,5 @@ def create_from_registry(cls, type, **kwargs):
if type not in _registry:
raise ValueError(f"Class type '{type}' is not registered. "
f"Available types: {list(_registry.keys())}")
return _registry[type](type=type, **kwargs)
obj = _registry[type](type=type, **kwargs)
return obj
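
Note: the explicit type field plus serializer closes the loop with the environment and runtime validators in agents/base.py: a dumped model now carries its registry key, so a plain dict with "type" can be turned back into the right subclass. A sketch of the intended round trip, assuming subclasses are registered under their class name (hypothetical example classes, not adala's real runtimes):

class ExampleRuntime(BaseModelInRegistry):
    pass

class ExampleOpenAIRuntime(ExampleRuntime):
    model: str = "gpt-3.5-turbo"

payload = ExampleOpenAIRuntime().model_dump()
print(payload)  # {'type': 'ExampleOpenAIRuntime', 'model': 'gpt-3.5-turbo'}

restored = ExampleRuntime.create_from_registry(type=payload.pop("type"), **payload)
print(type(restored).__name__)  # ExampleOpenAIRuntime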
