Skip to content

Commit

Permalink
chore: Update get_trial function to handle source trials recursively
Browse files Browse the repository at this point in the history
  • Loading branch information
amirnd51 committed Aug 27, 2024
1 parent 3d5c030 commit 73ae791
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 14 deletions.
32 changes: 21 additions & 11 deletions python_api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,25 +340,30 @@ async def predict(request: PredictRequest):

# experiment_id=create_expriement( cur, conn)

trail= get_trial_by_model_and_input( model_id, inputs)
# print(trail)
trial= get_trial_by_model_and_input( model_id, inputs)
if not experiment_id:
cur,conn=get_db_cur_con()
experiment_id=create_expriement(cur, conn)


# print(trail)
# if trail[2]
# print("trial")
if trail and trail[2] is not None:
if trial: #existing trial
# print(trail[2])
experiment_id = trail[0]

cur,conn=get_db_cur_con()
trial_id = trail[1]
source_trial = trial

model=get_model_by_id(model_id,cur,conn)
if not experiment_id:
experiment_id=create_expriement(cur, conn)
new_trial_id=create_trial( model_id, experiment_id, cur, conn,source_trial)
# if not experiment_id:
# experiment_id=create_expriement(cur, conn)

return {"experimentId": experiment_id, "trialId": trial_id, "model_id": model["name"], "input_url": inputs}
return {"experimentId": experiment_id, "trialId": new_trial_id, "model_id": model["name"], "input_url": inputs}
else:
cur,conn=get_db_cur_con()
if not experiment_id:
experiment_id=create_expriement(cur, conn)

# create a new trial and generate a new uuid experiment
# cur,conn=get_db_cur_con()

Expand Down Expand Up @@ -396,9 +401,11 @@ async def delete_trial(trial_id: str):
@app.get("/trial/{trial_id}")
async def get_trial(trial_id: str):
cur,conn=get_db_cur_con()

cur.execute("""
SELECT t.id AS trial_id,
t.result,
t.source_trial_id as source_trial,
t.completed_at,
ti.url AS input_url,
m.id AS modelId,
Expand Down Expand Up @@ -443,7 +450,10 @@ async def get_trial(trial_id: str):
row = cur.fetchone()

if not row:
raise Exception(f"No trial found with ID {trial_id}")
# raise Exception(f"No trial found with ID {trial_id}")
return None
if row["source_trial"] is not None:
return get_trial(row["source_trial"])
# print(row)
# Prepare the response structure
result = {
Expand Down
6 changes: 3 additions & 3 deletions python_api/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ def close_db_cur_con(cur, conn):
cur.close()
conn.close()

def create_trial( model_id, experiment_id, cur, conn):
def create_trial( model_id, experiment_id, cur, conn,source_id="",completed_at=None):
trial_id= str(uuid.uuid4())
cur.execute("INSERT INTO trials (id,model_id,created_at,updated_at ,experiment_id) VALUES (%s,%s,%s,%s, %s) RETURNING id", (trial_id,model_id, datetime.now(), datetime.now() , experiment_id))
cur.execute("INSERT INTO trials (id,model_id,created_at,updated_at,completed_at,experiment_id,source_trial_id) VALUES (%s,%s,%s,%s,%s,%s,%s) RETURNING id", (trial_id,model_id, datetime.now(), datetime.now() ,datetime.now() , experiment_id,source_id))
conn.commit()
return trial_id

Expand Down Expand Up @@ -123,7 +123,7 @@ def get_trial_by_model_and_input(model_id, input_urls):

# Assuming the columns are returned in the order you expect
# You may need to adjust this depending on your database schema
return (trial['experiment_id'], trial['trial_id'], trial['completed_at'])
return (trial['trial_id'])


except (Exception, psycopg2.DatabaseError) as error:
Expand Down

0 comments on commit 73ae791

Please sign in to comment.