diff --git a/python_api/api.py b/python_api/api.py index e806aeb..4816458 100644 --- a/python_api/api.py +++ b/python_api/api.py @@ -340,25 +340,30 @@ async def predict(request: PredictRequest): # experiment_id=create_expriement( cur, conn) - trail= get_trial_by_model_and_input( model_id, inputs) - # print(trail) + trial= get_trial_by_model_and_input( model_id, inputs) + if not experiment_id: + cur,conn=get_db_cur_con() + experiment_id=create_expriement(cur, conn) + + # print(trail) + # if trail[2] # print("trial") - if trail and trail[2] is not None: + if trial: #existing trial # print(trail[2]) - experiment_id = trail[0] + cur,conn=get_db_cur_con() - trial_id = trail[1] + source_trial = trial model=get_model_by_id(model_id,cur,conn) - if not experiment_id: - experiment_id=create_expriement(cur, conn) + new_trial_id=create_trial( model_id, experiment_id, cur, conn,source_trial) + # if not experiment_id: + # experiment_id=create_expriement(cur, conn) - return {"experimentId": experiment_id, "trialId": trial_id, "model_id": model["name"], "input_url": inputs} + return {"experimentId": experiment_id, "trialId": new_trial_id, "model_id": model["name"], "input_url": inputs} else: cur,conn=get_db_cur_con() - if not experiment_id: - experiment_id=create_expriement(cur, conn) + # create a new trial and generate a new uuid experiment # cur,conn=get_db_cur_con() @@ -396,9 +401,11 @@ async def delete_trial(trial_id: str): @app.get("/trial/{trial_id}") async def get_trial(trial_id: str): cur,conn=get_db_cur_con() + cur.execute(""" SELECT t.id AS trial_id, t.result, + t.source_trial_id as source_trial, t.completed_at, ti.url AS input_url, m.id AS modelId, @@ -443,7 +450,10 @@ async def get_trial(trial_id: str): row = cur.fetchone() if not row: - raise Exception(f"No trial found with ID {trial_id}") + # raise Exception(f"No trial found with ID {trial_id}") + return None + if row["source_trial"] is not None: + return get_trial(row["source_trial"]) # print(row) # Prepare the response structure result = { diff --git a/python_api/db.py b/python_api/db.py index cf33538..a39cfbb 100644 --- a/python_api/db.py +++ b/python_api/db.py @@ -30,9 +30,9 @@ def close_db_cur_con(cur, conn): cur.close() conn.close() -def create_trial( model_id, experiment_id, cur, conn): +def create_trial( model_id, experiment_id, cur, conn,source_id="",completed_at=None): trial_id= str(uuid.uuid4()) - cur.execute("INSERT INTO trials (id,model_id,created_at,updated_at ,experiment_id) VALUES (%s,%s,%s,%s, %s) RETURNING id", (trial_id,model_id, datetime.now(), datetime.now() , experiment_id)) + cur.execute("INSERT INTO trials (id,model_id,created_at,updated_at,completed_at,experiment_id,source_trial_id) VALUES (%s,%s,%s,%s,%s,%s,%s) RETURNING id", (trial_id,model_id, datetime.now(), datetime.now() ,datetime.now() , experiment_id,source_id)) conn.commit() return trial_id @@ -123,7 +123,7 @@ def get_trial_by_model_and_input(model_id, input_urls): # Assuming the columns are returned in the order you expect # You may need to adjust this depending on your database schema - return (trial['experiment_id'], trial['trial_id'], trial['completed_at']) + return (trial['trial_id']) except (Exception, psycopg2.DatabaseError) as error: