Commit 29a2e05

fix writing to wrong graph
MadcowD committed Oct 13, 2024
1 parent 8f4628b commit 29a2e05
Showing 6 changed files with 36 additions and 35 deletions.
33 changes: 18 additions & 15 deletions ell-studio/src/components/depgraph/graphUtils.js
@@ -62,21 +62,24 @@ export const getInitialGraph = (lmps, traces, evals) => {
     };
   });
 
-  // Create LMP nodes, excluding those that are part of evaluations
-  const lmpNodes = lmps.filter(Boolean).filter(lmp => !evalLmpIds.has(lmp.lmp_id)).map(lmp => {
-    const dimensions = calculateNodeDimensions('lmp', lmp);
-    return {
-      id: `${lmp.lmp_id}`,
-      type: "lmp",
-      data: {
-        label: lmp.name,
-        lmp,
-        isEvalLabeler: evalLmpIds.has(lmp.lmp_id),
-        ...dimensions
-      },
-      position: { x: 0, y: 0 },
-    };
-  });
+  // Create LMP nodes, excluding those that are part of evaluations and those of type "metric"
+  const lmpNodes = lmps.filter(Boolean)
+    .filter(lmp => !evalLmpIds.has(lmp.lmp_id) && lmp.lmp_type !== "LABELER")
+    .map(lmp => {
+      const dimensions = calculateNodeDimensions('lmp', lmp);
+      console.log(lmp);
+      return {
+        id: `${lmp.lmp_id}`,
+        type: "lmp",
+        data: {
+          label: lmp.name,
+          lmp,
+          isEvalLabeler: evalLmpIds.has(lmp.lmp_id),
+          ...dimensions
+        },
+        position: { x: 0, y: 0 },
+      };
+    });
 
   const deadNodes = lmps.flatMap(lmp =>
     (lmp.uses || [])
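
The exclusion rule in the new filter is the substance of the change: labeler LMPs, and any LMP that belongs to an evaluation, no longer get their own graph nodes. A minimal sketch of the same rule in Python for illustration (the dict-shaped records and example IDs are hypothetical, mirroring only the fields the JS code reads):

eval_lmp_ids = {"lmp-1"}  # IDs of LMPs that belong to an evaluation

lmps = [
    {"lmp_id": "lmp-1", "lmp_type": "LM", "name": "predict_capital"},
    {"lmp_id": "lmp-2", "lmp_type": "LABELER", "name": "is_correct"},
    {"lmp_id": "lmp-3", "lmp_type": "LM", "name": "write_poem"},
    None,  # filter(Boolean) drops falsy entries
]

# Keep only LMPs that are neither evaluation members nor labelers,
# matching the two conditions in the new .filter() call above.
lmp_nodes = [
    {"id": lmp["lmp_id"], "type": "lmp", "data": {"label": lmp["name"]}}
    for lmp in lmps
    if lmp and lmp["lmp_id"] not in eval_lmp_ids and lmp["lmp_type"] != "LABELER"
]
print(lmp_nodes)  # only lmp-3 survives
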
1 change: 1 addition & 0 deletions ell-studio/src/components/evaluations/MetricGraphGrid.js
@@ -17,6 +17,7 @@ const MetricGraphGrid = ({ evaluation, groupedRuns, onActiveIndexChange }) => {
       if (summary) {
         const { mean, std, min, max } = summary.data;
         const count = summary.count;
+        console.log(count)
 
         // Calculate Standard Error of the Mean (SEM)
         const sem = std / Math.sqrt(count);
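
The context around the added console.log is the error-band math: each run's summary is reduced to mean ± SEM. A hedged sketch of that arithmetic in Python (the 1.96 multiplier for an approximate 95% interval is an illustration, not something the component computes here):

import math

def sem_band(mean, std, count):
    # Standard Error of the Mean, as in the JS above.
    sem = std / math.sqrt(count)
    # Illustrative ~95% band; an assumption, not taken from the component.
    return mean - 1.96 * sem, mean + 1.96 * sem

low, high = sem_band(mean=0.82, std=0.10, count=25)
print(round(low, 4), round(high, 4))  # 0.7808 0.8592
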
5 changes: 3 additions & 2 deletions examples/eval.py
@@ -55,14 +55,15 @@ def is_correct(datapoint, output):
     return float(output.lower() == label.lower())
 
 eval = ell.evaluation.Evaluation(
-    name="capital_prediction", dataset=dataset, metrics={"score": is_correct, "length": lambda _, output: len(output)}
+    name="capital_prediction", dataset=dataset, metrics={"score": is_correct, "length": lambda _, output: len(output)}, samples_per_datapoint=5
 
 )
 
 # ell.init(verbose=True, store='./logdir')
 @ell.simple(model="gpt-4o")
 def predict_capital(question: str):
     """
-    Answer only with the capital of the country. If hotdog land, answer Banana.
+    Answer only with the capital of the country. If hotdog land, answer hotdog land.
     """
     # print(question[0])
     return f"Answer the following question. {question}"
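
For readers following the example: every entry in the metrics dict takes (datapoint, output) and returns a number, which is how is_correct and the length lambda line up. A standalone sketch of that contract, with a hypothetical expected_output field standing in for however the dataset actually stores its labels:

metrics = {
    "score": lambda dp, out: float(out.lower() == dp["expected_output"].lower()),
    "length": lambda _, out: len(out),
}

datapoint = {"question": "What is the capital of France?", "expected_output": "Paris"}
output = "Paris"

results = {name: fn(datapoint, output) for name, fn in metrics.items()}
print(results)  # {'score': 1.0, 'length': 5}
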
28 changes: 13 additions & 15 deletions src/ell/evaluation.py
@@ -1,16 +1,16 @@
 from collections import defaultdict
 from dataclasses import dataclass
 from datetime import datetime
-from functools import partial
+from functools import partial, wraps
 import itertools
 from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, TypeVar, Union, cast
 from concurrent.futures import ThreadPoolExecutor, as_completed
 import uuid
 
+from ell.types.studio import LMPType
 import openai
 
 import numpy as np
-from pydantic import BaseModel, ConfigDict, Field, field_validator
+from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
 from sqlmodel._compat import SQLModelConfig
 from ell.lmp._track import _track
 from ell.types.message import LMP
@@ -196,7 +196,6 @@ def _write(self, evaluation_id: str):
         config.store.write_evaluation_run_labeler_summaries(summaries)
 
 class Evaluation(BaseModel):
-    """Simple evaluation for prompt engineering rigorously"""
 
     model_config = ConfigDict(arbitrary_types_allowed=True)
     name: str
@@ -219,16 +218,15 @@ class Evaluation(BaseModel):
     serialized: Optional[SerializedEvaluation] = Field(default=None)
 
     id: Optional[str] = Field(default=None)
-
-    def __init__(self, *args, **kwargs):
-        assert (
-            "dataset" in kwargs or "n_evals" in kwargs
-        ), "Either dataset or n_evals must be set"
-        assert not (
-            "dataset" in kwargs and "n_evals" in kwargs
-        ), "Either dataset or samples_per_datapoint must be set, not both"
-
-        super().__init__(*args, **kwargs)
+
+    @model_validator(mode='before')
+    @classmethod
+    def validate_dataset_or_n_evals(cls, values):
+        if 'dataset' not in values and 'n_evals' not in values:
+            raise ValueError("Either dataset or n_evals must be set")
+        if 'dataset' in values and 'n_evals' in values:
+            raise ValueError("Either dataset or n_evals must be set, not both")
+        return values
 
     @field_validator("metrics", "annotations", "criterion", mode="before")
     def wrap_callables_in_lmp_function(cls, value):
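
This hunk swaps constructor-time asserts for a Pydantic before-mode model validator, which sees the raw input mapping before field parsing and surfaces failures as ValidationError rather than AssertionError (asserts also vanish under python -O). It incidentally fixes the old second assert, whose message mentioned samples_per_datapoint while the check was on n_evals. A self-contained sketch of the pattern under Pydantic v2, with stand-in fields:

from typing import Optional
from pydantic import BaseModel, ValidationError, model_validator

class EitherOr(BaseModel):
    dataset: Optional[list] = None  # stand-ins for Evaluation's real fields
    n_evals: Optional[int] = None

    @model_validator(mode="before")
    @classmethod
    def exactly_one(cls, values):
        # mode="before" receives the raw kwargs dict, so membership tests
        # behave like the old `"dataset" in kwargs` asserts did.
        if "dataset" not in values and "n_evals" not in values:
            raise ValueError("Either dataset or n_evals must be set")
        if "dataset" in values and "n_evals" in values:
            raise ValueError("Either dataset or n_evals must be set, not both")
        return values

EitherOr(dataset=[1, 2, 3])  # passes
try:
    EitherOr(dataset=[1], n_evals=5)
except ValidationError as err:
    print(err.errors()[0]["msg"])  # Value error, Either dataset or n_evals must be set, not both
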
@@ -237,7 +235,7 @@ def wrap_callables_in_lmp_function(cls, value):
         if isinstance(value, dict):
             return {
                 k: (
-                    function()(v)
+                    function(type=LMPType.LABELR)(v)
                     if callable(v) and not hasattr(v, "__ell_track__")
                     else v
                 )
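
The only change here is that auto-wrapped metric callables are now tagged with the new LABELR type. The normalization idea can be sketched independently of ell's internals; the tracked decorator below is a hypothetical stand-in for ell's function(type=...):

from typing import Callable, Dict
from pydantic import BaseModel, field_validator

def tracked(type: str = "LABELR"):
    # Hypothetical stand-in for ell's function(type=LMPType.LABELR) decorator.
    def decorate(fn: Callable) -> Callable:
        fn.__ell_track__ = type  # mark it so it is never wrapped twice
        return fn
    return decorate

class HasMetrics(BaseModel):
    metrics: Dict[str, Callable]

    @field_validator("metrics", mode="before")
    @classmethod
    def wrap_plain_callables(cls, value):
        # Wrap bare callables; leave already-tracked ones untouched.
        if isinstance(value, dict):
            return {
                k: tracked()(v) if callable(v) and not hasattr(v, "__ell_track__") else v
                for k, v in value.items()
            }
        return value

m = HasMetrics(metrics={"length": lambda dp, out: len(out)})
print(m.metrics["length"].__ell_track__)  # LABELR
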
1 change: 0 additions & 1 deletion src/ell/stores/sql.py
@@ -51,7 +51,6 @@ def write_lmp(self, serialized_lmp: SerializedLMP, uses: Dict[str, Any]) -> Optional[Any]:
                 return None
         except sqlalchemy.exc.IntegrityError as e:
             session.rollback()
-            print("race condition")
             return None
 
     def write_invocation(self, invocation: Invocation, consumes: Set[str]) -> Optional[Any]:
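
Removing the print leaves the IntegrityError branch as a silent idempotent write: if a concurrent writer already committed the same LMP, the loser rolls back and treats the write as a no-op. A minimal sketch of that pattern with plain SQLAlchemy (illustrative session handling, not ell's actual store code):

import sqlalchemy
from sqlalchemy.orm import Session

def write_once(session: Session, row) -> None:
    try:
        session.add(row)
        session.commit()
    except sqlalchemy.exc.IntegrityError:
        # Another writer won the race on the primary key; treat as a no-op.
        session.rollback()
    return None
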
3 changes: 1 addition & 2 deletions src/ell/types/studio/core.py
@@ -63,8 +63,7 @@ def UTCTimestampField(index: bool = False, **kwargs: Any):
 class LMPType(str, enum.Enum):
     LM = "LM"
     TOOL = "TOOL"
-    MULTIMODAL = "MULTIMODAL"
-    METRIC = "METRIC"
+    LABELR = "LABELR"
     FUNCTION = "FUNCTION"
     OTHER = "OTHER"
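
Because LMPType subclasses str, members compare equal to their raw string values, which is what lets stored rows and API responses round-trip through plain strings. A quick illustration using the enum as it stands after this change:

import enum

class LMPType(str, enum.Enum):
    LM = "LM"
    TOOL = "TOOL"
    LABELR = "LABELR"
    FUNCTION = "FUNCTION"
    OTHER = "OTHER"

assert LMPType.LABELR == "LABELR"           # str subclass: equal to raw value
assert LMPType("LABELR") is LMPType.LABELR  # round-trips from stored strings
print(LMPType.LABELR.value)                 # LABELR
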
