Skip to content

Commit

Permalink
enable parallel writes
Browse files Browse the repository at this point in the history
  • Loading branch information
MadcowD committed Oct 13, 2024
1 parent 49379c4 commit 8f4628b
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 24 deletions.
21 changes: 19 additions & 2 deletions ell-studio/src/components/evaluations/MetricGraphGrid.js
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { Card } from '../common/Card';
import Graph from '../graphing/Graph';
import { GraphProvider } from '../graphing/GraphSystem';
import MetricDisplay from './MetricDisplay';
import { Link } from 'react-router-dom';

const MetricGraphGrid = ({ evaluation, groupedRuns, onActiveIndexChange }) => {
const [activeIndex, setActiveIndex] = useState(null);
Expand Down Expand Up @@ -40,7 +41,7 @@ const MetricGraphGrid = ({ evaluation, groupedRuns, onActiveIndexChange }) => {
}, { means: [], stdDevs: [], errors: [], confidenceIntervals: [] });
};

const xData = Array.from({ length: getHistoricalData(evaluation.labelers?.[0]).means.length}, (_, i) => `Run ${i + 1}`);
const xData = Array.from({ length: getHistoricalData(evaluation.labelers?.[0]).means.length}, (_, i) => `${i + 1}`);

const handleHover = useCallback((index) => {
setActiveIndex(index);
Expand All @@ -53,7 +54,6 @@ const MetricGraphGrid = ({ evaluation, groupedRuns, onActiveIndexChange }) => {
}, [onActiveIndexChange]);

const hasMultipleValues = getHistoricalData(evaluation.labelers[0]).means.length > 1;

return (
<GraphProvider
xData={xData}
Expand All @@ -69,6 +69,21 @@ const MetricGraphGrid = ({ evaluation, groupedRuns, onActiveIndexChange }) => {
position: 'average',
// TODO: Make the label custom so when we click it takes us to that run id.
},
},
scales: {
x: {
display: true,
title: {
display: true,
text: 'Run Number'
}
},
y: {
display: true,
title: {
display: false
}
}
}
}
}}
Expand All @@ -86,6 +101,7 @@ const MetricGraphGrid = ({ evaluation, groupedRuns, onActiveIndexChange }) => {
return (
<Card key={labeler.id}>
<div className={`flex justify-between items-center p-2 ${hasMultipleValues ? 'border-b border-gray-800' : ''}`}>
<Link to={`/lmp/${labeler.labeling_lmp.name}/${labeler.labeling_lmp.lmp_id}`}>
<LMPCardTitle
lmp={labeler.labeling_lmp}
nameOverridePrint={labeler.name}
Expand All @@ -95,6 +111,7 @@ const MetricGraphGrid = ({ evaluation, groupedRuns, onActiveIndexChange }) => {
paddingClassOverride="p-0"
shortVersion={true}
/>
</Link>
<MetricDisplay
currentValue={currentValue}
previousValue={previousValue}
Expand Down
4 changes: 3 additions & 1 deletion ell-studio/src/components/graphing/GraphSystem.js
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ export const GraphRenderer = ({ graphId }) => {
dataset.errorBars && dataset.errorBars.some(error => error > 0 || (error.low - error.high > 0))
);

let yAxisScale = {};
let yAxisScale = {
...sharedConfig.options.scales.y,};
if (hasNonZeroErrorBars || true) {
// Calculate min and max values including error bars
const minMaxValues = data.datasets.reduce((acc, dataset) => {
Expand All @@ -150,6 +151,7 @@ export const GraphRenderer = ({ graphId }) => {

yAxisScale = {
y: {
...sharedConfig.options.scales.y,
beginAtZero: false,
min: Math.max(0, minMaxValues.min - yAxisPadding),
max: minMaxValues.max + yAxisPadding,
Expand Down
2 changes: 1 addition & 1 deletion ell-studio/src/pages/Evaluation.js
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ function Evaluation() {

<div className="mb-6">
<div className="flex border-b border-border">
{['Runs', 'Metrics', 'Version History'].map((tab) => (
{['Runs', 'Version History'].map((tab) => (
<button
key={tab}
className={`px-4 py-2 focus:outline-none ${
Expand Down
8 changes: 4 additions & 4 deletions examples/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def is_correct(datapoint, output):
return float(output.lower() == label.lower())

eval = ell.evaluation.Evaluation(
name="test", dataset=dataset, metrics={"score": is_correct, "length": lambda _, output: len(output)}
name="capital_prediction", dataset=dataset, metrics={"score": is_correct, "length": lambda _, output: len(output)}
)

# ell.init(verbose=True, store='./logdir')
Expand Down Expand Up @@ -265,6 +265,6 @@ def score(datapoint, output):


if __name__ == "__main__":
test_poem_eval()
# ell.init(verbose=True, store="./logdir")
# test_predictor_evaluation()
# test_poem_eval()
ell.init(verbose=True, store="./logdir")
test_predictor_evaluation()
38 changes: 22 additions & 16 deletions src/ell/stores/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
from typing import Any, Optional, Dict, List, Set, Union
from pydantic import BaseModel
import sqlalchemy
from sqlmodel import Session, SQLModel, create_engine, select
import ell.store
import cattrs
Expand Down Expand Up @@ -31,22 +32,27 @@ def __init__(self, db_uri: str, blob_store: Optional[ell.store.BlobStore] = None

def write_lmp(self, serialized_lmp: SerializedLMP, uses: Dict[str, Any]) -> Optional[Any]:
with Session(self.engine) as session:
# Bind the serialized_lmp to the session
lmp = session.exec(select(SerializedLMP).filter(SerializedLMP.lmp_id == serialized_lmp.lmp_id)).first()

if lmp:
# Already added to the DB.
return lmp
else:
session.add(serialized_lmp)

for use_id in uses:
used_lmp = session.exec(select(SerializedLMP).where(SerializedLMP.lmp_id == use_id)).first()
if used_lmp:
serialized_lmp.uses.append(used_lmp)

session.commit()
return None
try:
# Bind the serialized_lmp to the session
lmp = session.exec(select(SerializedLMP).filter(SerializedLMP.lmp_id == serialized_lmp.lmp_id)).first()

if lmp:
# Already added to the DB.
return lmp
else:
session.add(serialized_lmp)

for use_id in uses:
used_lmp = session.exec(select(SerializedLMP).where(SerializedLMP.lmp_id == use_id)).first()
if used_lmp:
serialized_lmp.uses.append(used_lmp)

session.commit()
return None
except sqlalchemy.exc.IntegrityError as e:
session.rollback()
print("race condition")
return None

def write_invocation(self, invocation: Invocation, consumes: Set[str]) -> Optional[Any]:
with Session(self.engine) as session:
Expand Down

0 comments on commit 8f4628b

Please sign in to comment.