Skip to content

Commit

Permalink
create evaluate_sklearn()
Browse files Browse the repository at this point in the history
  • Loading branch information
tzamalisp committed Aug 20, 2020
1 parent 5c568ae commit 5367478
Showing 1 changed file with 8 additions and 25 deletions.
33 changes: 8 additions & 25 deletions dataset_eval/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,28 +60,8 @@ def evaluate_dataset(eval_job, dataset_dir, storage_dir):
with open(groundtruth_path, "w") as f:
yaml.dump(create_groundtruth_dict(snapshot["data"]["name"], train), f)

# logging.info("Training GAIA model...")
# evaluate_gaia(eval_location, groundtruth_path, filelist_path, storage_dir, eval_job)

# Passing more user preferences to train the model.
logging.info("Training model...")
results = gaia_wrapper.train_model(
project_dir=eval_location,
groundtruth_file=groundtruth_path,
filelist_file=filelist_path,
c_values=eval_job["options"].get("c_values", []),
gamma_values=eval_job["options"].get("gamma_values", []),
preprocessing_values=eval_job["options"].get("preprocessing_values", []),
)
logging.info("Saving results...")
save_history_file(storage_dir, results["history_path"], eval_job["id"])
db.dataset_eval.set_job_result(eval_job["id"], json.dumps({
"project_path": eval_location,
"parameters": results["parameters"],
"accuracy": results["accuracy"],
"confusion_matrix": results["confusion_matrix"],
"history_path": results["history_path"],
}))
logging.info("Training GAIA model...")
evaluate_gaia(eval_location, groundtruth_path, filelist_path, storage_dir, eval_job)

db.dataset_eval.set_job_status(eval_job["id"], db.dataset_eval.STATUS_DONE)
logging.info("Evaluation job %s has been completed." % eval_job["id"])
Expand All @@ -104,6 +84,7 @@ def evaluate_dataset(eval_job, dataset_dir, storage_dir):


def evaluate_gaia(eval_location, groundtruth_path, filelist_path, storage_dir, eval_job):
# Passing more user preferences to train the model.
results = gaia_wrapper.train_model(
project_dir=eval_location,
groundtruth_file=groundtruth_path,
Expand All @@ -123,9 +104,11 @@ def evaluate_gaia(eval_location, groundtruth_path, filelist_path, storage_dir, e
}))


def evaluate_sklearn(eval_location, groundtruth_path, filelist_path, storage_dir, eval_job):
# create_classification_project(ground_truth_directory=groundtruth_path)
pass
def evaluate_sklearn(eval_location, dataset_dir, storage_dir, eval_job):
create_classification_project(ground_truth_directory=dataset_dir,
project_file=eval_job["id"],
exports_directory=eval_job["id"],
exports_path=eval_location)


def create_groundtruth_dict(name, datadict):
Expand Down

0 comments on commit 5367478

Please sign in to comment.