From e7f67c59b7852513a4b5eb46743d4879bf416766 Mon Sep 17 00:00:00 2001 From: PascalEgn Date: Fri, 2 Aug 2024 14:58:09 +0200 Subject: [PATCH] classifier: fix memory usage --- inspire_classifier/api.py | 27 +++++++++++++++------------ inspire_classifier/app.py | 7 +++++-- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/inspire_classifier/api.py b/inspire_classifier/api.py index f22c7a5..ac6c515 100644 --- a/inspire_classifier/api.py +++ b/inspire_classifier/api.py @@ -151,24 +151,27 @@ def train(): train_and_save_classifier() -def predict_coreness(title, abstract): +def initialize_classifier(): """ - Predicts class-wise probabilities given the title and abstract. + Initializes the classifier. """ - text = title + " " + abstract - categories = ["rejected", "non_core", "core"] - try: - classifier = Classifier( - cuda_device_id=current_app.config["CLASSIFIER_CUDA_DEVICE_ID"] - ) - except IOError as error: - raise IOError("Data ITOS not found.") from error - + classifier = Classifier( + cuda_device_id=current_app.config["CLASSIFIER_CUDA_DEVICE_ID"] + ) try: classifier.load_trained_classifier_weights(path_for("trained_classifier")) except IOError as error: - raise IOError("Could not load the trained classifier weights.") from error + raise IOError("Could not load the trained classifier weights.", path_for("trained_classifier")) from error + + return classifier + +def predict_coreness(classifier, title, abstract): + """ + Predicts class-wise probabilities given the title and abstract. + """ + text = title + " " + abstract + categories = ["rejected", "non_core", "core"] class_probabilities = classifier.predict( text, temperature=current_app.config["CLASSIFIER_SOFTMAX_TEMPERATUR"] ) diff --git a/inspire_classifier/app.py b/inspire_classifier/app.py index eefadb1..c4cb346 100644 --- a/inspire_classifier/app.py +++ b/inspire_classifier/app.py @@ -28,7 +28,7 @@ from prometheus_flask_exporter.multiprocess import GunicornInternalPrometheusMetrics from webargs.flaskparser import use_args -from inspire_classifier.api import predict_coreness +from inspire_classifier.api import predict_coreness, initialize_classifier from . import serializers @@ -55,6 +55,8 @@ def create_app(): app.config["CLASSIFIER_BASE_PATH"] = app.instance_path app.config.from_object("inspire_classifier.config") app.config.from_pyfile("classifier.cfg", silent=True) + with app.app_context(): + classifier = initialize_classifier() @app.route("/api/health") def date(): @@ -69,7 +71,7 @@ def date(): ) def core_classifier(args): """Endpoint for the CORE classifier.""" - prediction = predict_coreness(args["title"], args["abstract"]) + prediction = predict_coreness(classifier, args["title"], args["abstract"]) response = coreness_schema.dump(prediction) return response @@ -89,3 +91,4 @@ def page_not_found(e): if __name__ == "__main__": app.run(host="0.0.0.0") +