feat(eval): rewrite log_completions to save completions to directory (A…
xingyaoww authored Oct 25, 2024
1 parent c3da25f commit 7340b78
Showing 6 changed files with 46 additions and 20 deletions.
evaluation/integration_tests/run_infer.py (10 additions, 1 deletion)
@@ -33,6 +33,7 @@

 def get_config(
     metadata: EvalMetadata,
+    instance_id: str,
 ) -> AppConfig:
     config = AppConfig(
         default_agent=metadata.agent_class,
@@ -49,6 +49,14 @@ def get_config(
         workspace_base=None,
         workspace_mount_path=None,
     )
+    if metadata.llm_config.log_completions:
+        metadata.llm_config.log_completions_folder = os.path.join(
+            metadata.eval_output_dir, 'llm_completions', instance_id
+        )
+        logger.info(
+            f'Logging LLM completions for instance {instance_id} to '
+            f'{metadata.llm_config.log_completions_folder}'
+        )
     config.set_llm_config(metadata.llm_config)
     return config

@@ -58,7 +67,7 @@ def process_instance(
     metadata: EvalMetadata,
     reset_logger: bool = True,
 ) -> EvalOutput:
-    config = get_config(metadata)
+    config = get_config(metadata, instance.instance_id)

     # Setup the logger properly, so you can run multi-processing to parallelize the evaluation
     if reset_logger:
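With this change, raw completions no longer ride along in memory on the eval output; each instance writes them under its own folder inside the eval output directory. A minimal sketch of how the per-instance path is assembled (the directory names come from the diff above; the concrete values are hypothetical):

import os

eval_output_dir = 'evaluation/outputs/integration_tests'  # hypothetical run directory
instance_id = 't01_fix_simple_typo'                       # hypothetical instance id
folder = os.path.join(eval_output_dir, 'llm_completions', instance_id)
print(folder)  # evaluation/outputs/integration_tests/llm_completions/t01_fix_simple_typo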
evaluation/swe_bench/run_infer.py (8 additions, 1 deletion)
@@ -143,6 +143,14 @@ def get_config(
         workspace_base=None,
         workspace_mount_path=None,
     )
+    if metadata.llm_config.log_completions:
+        metadata.llm_config.log_completions_folder = os.path.join(
+            metadata.eval_output_dir, 'llm_completions', instance['instance_id']
+        )
+        logger.info(
+            f'Logging LLM completions for instance {instance["instance_id"]} to '
+            f'{metadata.llm_config.log_completions_folder}'
+        )
     config.set_llm_config(metadata.llm_config)
     return config

@@ -432,7 +440,6 @@ def process_instance(
         metadata=metadata,
         history=histories,
         metrics=metrics,
-        llm_completions=state.extra_data.get('llm_completions', []),
         error=state.last_error if state and state.last_error else None,
     )
     return output
evaluation/utils/shared.py (0 additions, 1 deletion)
@@ -61,7 +61,6 @@ class EvalOutput(BaseModel):
     history: (
         list[dict[str, Any]] | list[tuple[dict[str, Any], dict[str, Any]]] | None
     ) = None
-    llm_completions: list[dict[str, Any]] | None = None
     metrics: dict[str, Any] | None = None
     error: str | None = None

openhands/controller/agent_controller.py (0 additions, 4 deletions)
@@ -132,10 +132,6 @@ def update_state_before_step(self):
     async def update_state_after_step(self):
         # update metrics especially for cost. Use deepcopy to avoid it being modified by agent.reset()
         self.state.local_metrics = copy.deepcopy(self.agent.llm.metrics)
-        if 'llm_completions' not in self.state.extra_data:
-            self.state.extra_data['llm_completions'] = []
-        self.state.extra_data['llm_completions'].extend(self.agent.llm.llm_completions)
-        self.agent.llm.llm_completions.clear()

     async def report_error(self, message: str, exception: Exception | None = None):
         """Reports an error to the user and sends the exception to the LLM next step, in the hope it can self-correct.
openhands/core/config/llm_config.py (2 additions, 0 deletions)
@@ -40,6 +40,7 @@ class LLMConfig:
         disable_vision: If model is vision capable, this option allows to disable image processing (useful for cost reduction).
         caching_prompt: Use the prompt caching feature if provided by the LLM and supported by the provider.
         log_completions: Whether to log LLM completions to the state.
+        log_completions_folder: The folder to log LLM completions to. Required if log_completions is True.
         draft_editor: A more efficient LLM to use for file editing. Introduced in [PR 3985](https://github.com/All-Hands-AI/OpenHands/pull/3985).
     """

@@ -73,6 +74,7 @@ class LLMConfig:
     disable_vision: bool | None = None
     caching_prompt: bool = True
     log_completions: bool = False
+    log_completions_folder: str | None = None
     draft_editor: Optional['LLMConfig'] = None

     def defaults_to_dict(self) -> dict:
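Taken together with the new check in LLM.__init__ below, turning logging on now requires a folder. A minimal sketch of a config that enables the feature (field names come from the diff; the import path, constructor usage, and concrete values are assumptions):

from openhands.core.config.llm_config import LLMConfig  # import path is an assumption

llm_config = LLMConfig(
    model='gpt-4o-2024-08-06',  # hypothetical model name
    log_completions=True,
    # required whenever log_completions is True; created on demand by LLM.__init__
    log_completions_folder='logs/llm_completions/my_run',  # hypothetical path
)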
openhands/llm/llm.py (26 additions, 13 deletions)
@@ -1,4 +1,6 @@
 import copy
+import json
+import os
 import time
 import warnings
 from functools import partial
@@ -77,11 +79,6 @@ def __init__(
         self.cost_metric_supported: bool = True
         self.config: LLMConfig = copy.deepcopy(config)

-        # list of LLM completions (for logging purposes). Each completion is a dict with the following keys:
-        # - 'messages': list of messages
-        # - 'response': response from the LLM
-        self.llm_completions: list[dict[str, Any]] = []
-
         # litellm actually uses base Exception here for unknown model
         self.model_info: ModelInfo | None = None
         try:
@@ -95,6 +92,13 @@ def __init__(
         except Exception as e:
             logger.warning(f'Could not get model info for {config.model}:\n{e}')

+        if self.config.log_completions:
+            if self.config.log_completions_folder is None:
+                raise RuntimeError(
+                    'log_completions_folder is required when log_completions is enabled'
+                )
+            os.makedirs(self.config.log_completions_folder, exist_ok=True)
+
         # Set the max tokens in an LM-specific way if not set
         if self.config.max_input_tokens is None:
             if (
@@ -194,14 +198,24 @@ def wrapper(*args, **kwargs):

             # log for evals or other scripts that need the raw completion
             if self.config.log_completions:
-                self.llm_completions.append(
-                    {
-                        'messages': messages,
-                        'response': resp,
-                        'timestamp': time.time(),
-                        'cost': self._completion_cost(resp),
-                    }
+                assert self.config.log_completions_folder is not None
+                log_file = os.path.join(
+                    self.config.log_completions_folder,
+                    # use the metric model name (for draft editor)
+                    f'{self.metrics.model_name}-{time.time()}.json',
                 )
+                with open(log_file, 'w') as f:
+                    json.dump(
+                        {
+                            'messages': messages,
+                            'response': resp,
+                            'args': args,
+                            'kwargs': kwargs,
+                            'timestamp': time.time(),
+                            'cost': self._completion_cost(resp),
+                        },
+                        f,
+                    )

             message_back: str = resp['choices'][0]['message']['content']
@@ -400,7 +414,6 @@ def __repr__(self):

     def reset(self):
         self.metrics.reset()
-        self.llm_completions = []

     def format_messages_for_llm(self, messages: Message | list[Message]) -> list[dict]:
         if isinstance(messages, Message):
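Since completions now live on disk rather than on State.extra_data, downstream scripts read them back from the folder. A minimal sketch (not part of the commit): the filename pattern and JSON keys mirror the wrapper code above, while the folder value is hypothetical.

import glob
import json
import os

folder = 'logs/llm_completions/my_run'  # hypothetical folder
for path in sorted(glob.glob(os.path.join(folder, '*.json'))):
    with open(path) as f:
        entry = json.load(f)
    # keys written by the wrapper: messages, response, args, kwargs, timestamp, cost
    print(os.path.basename(path), entry['cost'], len(entry['messages']))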
