Skip to content

Commit

Permalink
classifier: fix memory usage
Browse files (browse the repository at this point in the history)
  • Loading branch information
PascalEgn committed Aug 2, 2024
1 parent 86e3221 commit e7f67c5
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 14 deletions.
27 changes: 15 additions & 12 deletions inspire_classifier/api.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -151,24 +151,27 @@ def train():
train_and_save_classifier()


def predict_coreness(title, abstract):
def initialize_classifier():
"""
Predicts class-wise probabilities given the title and abstract.
Initializes the classifier.
"""
text = title + " <ENDTITLE> " + abstract
categories = ["rejected", "non_core", "core"]
try:
classifier = Classifier(
cuda_device_id=current_app.config["CLASSIFIER_CUDA_DEVICE_ID"]
)
except IOError as error:
raise IOError("Data ITOS not found.") from error

classifier = Classifier(
cuda_device_id=current_app.config["CLASSIFIER_CUDA_DEVICE_ID"]
)
try:
classifier.load_trained_classifier_weights(path_for("trained_classifier"))
except IOError as error:
raise IOError("Could not load the trained classifier weights.") from error
raise IOError("Could not load the trained classifier weights.", path_for("trained_classifier")) from error

return classifier


def predict_coreness(classifier, title, abstract):
"""
Predicts class-wise probabilities given the title and abstract.
"""
text = title + " <ENDTITLE> " + abstract
categories = ["rejected", "non_core", "core"]
class_probabilities = classifier.predict(
text, temperature=current_app.config["CLASSIFIER_SOFTMAX_TEMPERATUR"]
)
Expand Down
7 changes: 5 additions & 2 deletions inspire_classifier/app.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -28,7 +28,7 @@
from prometheus_flask_exporter.multiprocess import GunicornInternalPrometheusMetrics
from webargs.flaskparser import use_args

from inspire_classifier.api import predict_coreness
from inspire_classifier.api import predict_coreness, initialize_classifier

from . import serializers

Expand All @@ -55,6 +55,8 @@ def create_app():
app.config["CLASSIFIER_BASE_PATH"] = app.instance_path
app.config.from_object("inspire_classifier.config")
app.config.from_pyfile("classifier.cfg", silent=True)
with app.app_context():
classifier = initialize_classifier()

@app.route("/api/health")
def date():
Expand All @@ -69,7 +71,7 @@ def date():
)
def core_classifier(args):
"""Endpoint for the CORE classifier."""
prediction = predict_coreness(args["title"], args["abstract"])
prediction = predict_coreness(classifier, args["title"], args["abstract"])
response = coreness_schema.dump(prediction)
return response

Expand All @@ -89,3 +91,4 @@ def page_not_found(e):

if __name__ == "__main__":
app.run(host="0.0.0.0")

0 comments on commit e7f67c5

Please sign in to comment.