From 736ecdf92c73164bd2ff089b2ad6ba9456c78794 Mon Sep 17 00:00:00 2001 From: mreid-tt <943378+mreid-tt@users.noreply.github.com> Date: Fri, 28 Apr 2023 21:36:33 -0400 Subject: [PATCH 1/2] Add threading to create operations --- spkrepo/views/api.py | 212 +++++++++++++++++++++++-------------------- 1 file changed, 116 insertions(+), 96 deletions(-) diff --git a/spkrepo/views/api.py b/spkrepo/views/api.py index 639c9e5..875506c 100644 --- a/spkrepo/views/api.py +++ b/spkrepo/views/api.py @@ -3,6 +3,7 @@ import os import re import shutil +import threading from functools import wraps from flask import Blueprint, _request_ctx_stack, current_app, request @@ -33,6 +34,9 @@ firmware_re = re.compile(r"^(?P\d\.\d)-(?P\d{3,6})$") version_re = re.compile(r"^(?P.*)-(?P\d+)$") +# Create two locks +package_lock = threading.Lock() +version_lock = threading.Lock() def api_auth_required(f): @wraps(f) @@ -127,93 +131,121 @@ def post(self): if firmware is None: abort(422, message="Unknown firmware") + # path to save files + data_path = current_app.config["DATA_PATH"] + # Package - create_package = False - package = Package.find(spk.info["package"]) - if package is None: - if not current_user.has_role("package_admin"): - abort(403, message="Insufficient permissions to create new packages") - create_package = True - package = Package(name=spk.info["package"], author=current_user) - elif ( - not current_user.has_role("package_admin") - and current_user not in package.maintainers - ): - abort(403, message="Insufficient permissions on this package") + with package_lock: + create_package = False + package = Package.find(spk.info["package"]) + if package is None: + if not current_user.has_role("package_admin"): + abort(403, message="Insufficient permissions to create new packages") + create_package = True + package = Package(name=spk.info["package"], author=current_user) + elif ( + not current_user.has_role("package_admin") + and current_user not in package.maintainers + ): + abort(403, message="Insufficient permissions on this package") + + if create_package: + try: + os.mkdir(os.path.join(data_path, package.name)) + except Exception as e: # pragma: no cover + shutil.rmtree(os.path.join(data_path, package.name), ignore_errors=True) + abort(500, message="Failed to create directory", details=str(e)) + # Add package to database + db.session.add(package) + db.session.commit() # Version - create_version = False - match = version_re.match(spk.info["version"]) - if not match: - abort(422, message="Invalid version") - # TODO: check discrepencies with what's in the database - version = {v.version: v for v in package.versions}.get( - int(match.group("version")) - ) - if version is None: - create_version = True - version_startable = None - if spk.info.get("startable") is False or spk.info.get("ctl_stop") is False: - version_startable = False - elif spk.info.get("startable") is True or spk.info.get("ctl_stop") is True: - version_startable = True - version = Version( - package=package, - upstream_version=match.group("upstream_version"), - version=int(match.group("version")), - changelog=spk.info.get("changelog"), - report_url=spk.info.get("report_url"), - distributor=spk.info.get("distributor"), - distributor_url=spk.info.get("distributor_url"), - maintainer=spk.info.get("maintainer"), - maintainer_url=spk.info.get("maintainer_url"), - dependencies=spk.info.get("install_dep_packages"), - conf_dependencies=spk.conf_dependencies, - conflicts=spk.info.get("install_conflict_packages"), - conf_conflicts=spk.conf_conflicts, - conf_privilege=spk.conf_privilege, - conf_resource=spk.conf_resource, - install_wizard="install" in spk.wizards, - upgrade_wizard="upgrade" in spk.wizards, - startable=version_startable, - license=spk.license, + with version_lock: + create_version = False + match = version_re.match(spk.info["version"]) + if not match: + abort(422, message="Invalid version") + # TODO: check discrepencies with what's in the database + version = {v.version: v for v in package.versions}.get( + int(match.group("version")) ) + if version is None: + create_version = True + version_startable = None + if spk.info.get("startable") is False or spk.info.get("ctl_stop") is False: + version_startable = False + elif spk.info.get("startable") is True or spk.info.get("ctl_stop") is True: + version_startable = True + version = Version( + package=package, + upstream_version=match.group("upstream_version"), + version=int(match.group("version")), + changelog=spk.info.get("changelog"), + report_url=spk.info.get("report_url"), + distributor=spk.info.get("distributor"), + distributor_url=spk.info.get("distributor_url"), + maintainer=spk.info.get("maintainer"), + maintainer_url=spk.info.get("maintainer_url"), + dependencies=spk.info.get("install_dep_packages"), + conf_dependencies=spk.conf_dependencies, + conflicts=spk.info.get("install_conflict_packages"), + conf_conflicts=spk.conf_conflicts, + conf_privilege=spk.conf_privilege, + conf_resource=spk.conf_resource, + install_wizard="install" in spk.wizards, + upgrade_wizard="upgrade" in spk.wizards, + startable=version_startable, + license=spk.license, + ) - for key, value in spk.info.items(): - if key == "install_dep_services": - for service_name in value.split(): - version.service_dependencies.append(Service.find(service_name)) - elif key == "displayname": - version.displaynames["enu"] = DisplayName( - language=Language.find("enu"), displayname=value - ) - elif key.startswith("displayname_"): - language = Language.find(key.split("_", 1)[1]) - if not language: - abort(422, message="Unknown INFO displayname language") - version.displaynames[language.code] = DisplayName( - language=language, displayname=value - ) - elif key == "description": - version.descriptions["enu"] = Description( - description=value, language=Language.find("enu") - ) - elif key.startswith("description_"): - language = Language.find(key.split("_", 1)[1]) - if not language: - abort(422, message="Unknown INFO description language") - version.descriptions[language.code] = Description( - language=language, description=value + for key, value in spk.info.items(): + if key == "install_dep_services": + for service_name in value.split(): + version.service_dependencies.append(Service.find(service_name)) + elif key == "displayname": + version.displaynames["enu"] = DisplayName( + language=Language.find("enu"), displayname=value + ) + elif key.startswith("displayname_"): + language = Language.find(key.split("_", 1)[1]) + if not language: + abort(422, message="Unknown INFO displayname language") + version.displaynames[language.code] = DisplayName( + language=language, displayname=value + ) + elif key == "description": + version.descriptions["enu"] = Description( + description=value, language=Language.find("enu") + ) + elif key.startswith("description_"): + language = Language.find(key.split("_", 1)[1]) + if not language: + abort(422, message="Unknown INFO description language") + version.descriptions[language.code] = Description( + language=language, description=value + ) + + # Icon + for size, icon in spk.icons.items(): + version.icons[size] = Icon( + path=os.path.join( + package.name, str(version.version), "icon_%s.png" % size + ), + size=size, ) - # Icon - for size, icon in spk.icons.items(): - version.icons[size] = Icon( - path=os.path.join( - package.name, str(version.version), "icon_%s.png" % size - ), - size=size, - ) + if create_version: + try: + os.mkdir(os.path.join(data_path, package.name, str(version.version))) + except Exception as e: # pragma: no cover + shutil.rmtree( + os.path.join(data_path, package.name, str(version.version)), + ignore_errors=True, + ) + abort(500, message="Failed to create directory", details=str(e)) + # Add version to database + db.session.add(version) + db.session.commit() # Build if version.id: @@ -254,27 +286,15 @@ def post(self): # save files try: - data_path = current_app.config["DATA_PATH"] - if create_package: - os.mkdir(os.path.join(data_path, package.name)) if create_version: - os.mkdir(os.path.join(data_path, package.name, str(version.version))) for size, icon in build.version.icons.items(): icon.save(spk.icons[size]) build.save(spk.stream) except Exception as e: # pragma: no cover - if create_package: - shutil.rmtree(os.path.join(data_path, package.name), ignore_errors=True) - elif create_version: - shutil.rmtree( - os.path.join(data_path, package.name, str(version.version)), - ignore_errors=True, - ) - else: - try: - os.remove(os.path.join(data_path, build.path)) - except OSError: - pass + try: + os.remove(os.path.join(data_path, build.path)) + except OSError: + pass abort(500, message="Failed to save files", details=str(e)) # insert the package into database From 1240c5e843881c5fc8608eaaa23cf46ddc311649 Mon Sep 17 00:00:00 2001 From: mreid-tt <943378+mreid-tt@users.noreply.github.com> Date: Sat, 29 Apr 2023 17:16:48 -0400 Subject: [PATCH 2/2] Use distributed lock for multiple worker processes --- spkrepo/views/api.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/spkrepo/views/api.py b/spkrepo/views/api.py index 875506c..a3fe5b1 100644 --- a/spkrepo/views/api.py +++ b/spkrepo/views/api.py @@ -3,7 +3,7 @@ import os import re import shutil -import threading +import redis from functools import wraps from flask import Blueprint, _request_ctx_stack, current_app, request @@ -29,14 +29,13 @@ from ..utils import SPK api = Blueprint("api", __name__) +redis_client = redis.Redis(host='localhost', port=6379, db=0) +redis_lock = redis_client.lock('my_lock') # regexes firmware_re = re.compile(r"^(?P\d\.\d)-(?P\d{3,6})$") version_re = re.compile(r"^(?P.*)-(?P\d+)$") -# Create two locks -package_lock = threading.Lock() -version_lock = threading.Lock() def api_auth_required(f): @wraps(f) @@ -135,7 +134,10 @@ def post(self): data_path = current_app.config["DATA_PATH"] # Package - with package_lock: + try: + # Acquire the Redis lock + redis_lock.acquire() + create_package = False package = Package.find(spk.info["package"]) if package is None: @@ -158,9 +160,16 @@ def post(self): # Add package to database db.session.add(package) db.session.commit() + + finally: + # Release the Redis lock + redis_lock.release() # Version - with version_lock: + try: + # Acquire the Redis lock + redis_lock.acquire() + create_version = False match = version_re.match(spk.info["version"]) if not match: @@ -247,6 +256,10 @@ def post(self): db.session.add(version) db.session.commit() + finally: + # Release the Redis lock + redis_lock.release() + # Build if version.id: # check for conflicts