diff --git a/services/trainer/poetry.lock b/services/trainer/poetry.lock index bae792b..ef4a820 100644 --- a/services/trainer/poetry.lock +++ b/services/trainer/poetry.lock @@ -75,6 +75,22 @@ category = "main" optional = false python-versions = ">=3.6" +[[package]] +name = "cachecontrol" +version = "0.12.12" +description = "httplib2 caching for requests" +category = "main" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +msgpack = ">=0.5.2" +requests = "*" + +[package.extras] +filecache = ["filelock (>=3.8.0)"] +redis = ["redis (>=2.10.5)"] + [[package]] name = "cachetools" version = "5.2.0" @@ -132,6 +148,25 @@ category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +[[package]] +name = "cryptography" +version = "38.0.2" +description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers." +category = "main" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +cffi = ">=1.12" + +[package.extras] +docs = ["sphinx (>=1.6.5,!=1.8.0,!=3.1.0,!=3.1.1)", "sphinx-rtd-theme"] +docstest = ["pyenchant (>=1.6.11)", "twine (>=1.12.0)", "sphinxcontrib-spelling (>=4.0.1)"] +pep8test = ["black", "flake8", "flake8-import-order", "pep8-naming"] +sdist = ["setuptools-rust (>=0.11.4)"] +ssh = ["bcrypt (>=3.1.5)"] +test = ["pytest (>=6.2.0)", "pytest-benchmark", "pytest-cov", "pytest-subtests", "pytest-xdist", "pretend", "iso8601", "pytz", "hypothesis (>=1.11.4,!=3.79.2)"] + [[package]] name = "datasets" version = "2.4.0" @@ -156,7 +191,7 @@ tqdm = ">=4.62.1" xxhash = "*" [package.extras] -dev = ["conllu", "absl-py", "pytest", "pytest-datadir", "pytest-xdist", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "aiobotocore (>=2.0.1)", "boto3 (>=1.19.8)", "botocore (>=1.22.8)", "faiss-cpu (>=1.6.4)", "fsspec", "moto[s3,server] (==2.0.4)", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "torch", "torchaudio", "soundfile", "transformers", "bs4", "h5py", "langdetect", "lxml", "lz4", "mwparserfromhell", "nltk", "openpyxl", "py7zr", "tldextract", "zstandard", "sentencepiece", "sacremoses", "bert-score (>=0.3.6)", "jiwer", "mauve-text", "rouge-score (<0.0.7)", "sacrebleu", "scikit-learn", "scipy", "seqeval", "toml (>=0.10.1)", "requests-file (>=1.5.1)", "tldextract (>=3.1.0)", "texttable (>=1.6.3)", "Werkzeug (>=1.0.1)", "six (>=1.15.0,<1.16.0)", "Pillow (>=6.2.1)", "librosa", "black (>=22.0,<23.0)", "flake8 (>=3.8.3)", "isort (>=5.0.0)", "pyyaml (>=5.3.1)", "importlib-resources"] +dev = ["conllu", "absl-py", "pytest", "pytest-datadir", "pytest-xdist", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "aiobotocore (>=2.0.1)", "boto3 (>=1.19.8)", "botocore (>=1.22.8)", "faiss-cpu (>=1.6.4)", "fsspec", "moto[server,s3] (==2.0.4)", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "torch", "torchaudio", "soundfile", "transformers", "bs4", "h5py", "langdetect", "lxml", "lz4", "mwparserfromhell", "nltk", "openpyxl", "py7zr", "tldextract", "zstandard", "sentencepiece", "sacremoses", "bert-score (>=0.3.6)", "jiwer", "mauve-text", "rouge-score (<0.0.7)", "sacrebleu", "scikit-learn", "scipy", "seqeval", "toml (>=0.10.1)", "requests-file (>=1.5.1)", "tldextract (>=3.1.0)", "texttable (>=1.6.3)", "Werkzeug (>=1.0.1)", "six (>=1.15.0,<1.16.0)", "Pillow (>=6.2.1)", "librosa", "black (>=22.0,<23.0)", "flake8 (>=3.8.3)", "isort (>=5.0.0)", "pyyaml (>=5.3.1)", "importlib-resources"] apache-beam = ["apache-beam (>=2.26.0)"] audio = ["librosa"] benchmarks = ["numpy (==1.18.5)", "tensorflow (==2.3.0)", "torch (==1.6.0)", "transformers (==3.0.2)"] @@ -165,7 +200,7 @@ quality = ["black (>=22.0,<23.0)", "flake8 (>=3.8.3)", "isort (>=5.0.0)", "pyyam s3 = ["fsspec", "boto3", "botocore", "s3fs"] tensorflow = ["tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)"] tensorflow_gpu = ["tensorflow-gpu (>=2.2.0,!=2.6.0,!=2.6.1)"] -tests = ["absl-py", "pytest", "pytest-datadir", "pytest-xdist", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "aiobotocore (>=2.0.1)", "boto3 (>=1.19.8)", "botocore (>=1.22.8)", "faiss-cpu (>=1.6.4)", "fsspec", "moto[s3,server] (==2.0.4)", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "torch", "torchaudio", "soundfile", "transformers", "bs4", "conllu", "h5py", "langdetect", "lxml", "lz4", "mwparserfromhell", "nltk", "openpyxl", "py7zr", "tldextract", "zstandard", "sentencepiece", "sacremoses", "bert-score (>=0.3.6)", "jiwer", "mauve-text", "rouge-score (<0.0.7)", "sacrebleu", "scikit-learn", "scipy", "seqeval", "toml (>=0.10.1)", "requests-file (>=1.5.1)", "tldextract (>=3.1.0)", "texttable (>=1.6.3)", "Werkzeug (>=1.0.1)", "six (>=1.15.0,<1.16.0)", "Pillow (>=6.2.1)", "librosa", "importlib-resources"] +tests = ["absl-py", "pytest", "pytest-datadir", "pytest-xdist", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "aiobotocore (>=2.0.1)", "boto3 (>=1.19.8)", "botocore (>=1.22.8)", "faiss-cpu (>=1.6.4)", "fsspec", "moto[server,s3] (==2.0.4)", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "torch", "torchaudio", "soundfile", "transformers", "bs4", "conllu", "h5py", "langdetect", "lxml", "lz4", "mwparserfromhell", "nltk", "openpyxl", "py7zr", "tldextract", "zstandard", "sentencepiece", "sacremoses", "bert-score (>=0.3.6)", "jiwer", "mauve-text", "rouge-score (<0.0.7)", "sacrebleu", "scikit-learn", "scipy", "seqeval", "toml (>=0.10.1)", "requests-file (>=1.5.1)", "tldextract (>=3.1.0)", "texttable (>=1.6.3)", "Werkzeug (>=1.0.1)", "six (>=1.15.0,<1.16.0)", "Pillow (>=6.2.1)", "librosa", "importlib-resources"] torch = ["torch"] vision = ["Pillow (>=6.2.1)"] @@ -200,6 +235,22 @@ python-versions = ">=3.7" docs = ["furo (>=2022.6.21)", "sphinx (>=5.1.1)", "sphinx-autodoc-typehints (>=1.19.1)"] testing = ["covdefaults (>=2.2)", "coverage (>=6.4.2)", "pytest (>=7.1.2)", "pytest-cov (>=3)", "pytest-timeout (>=2.1)"] +[[package]] +name = "firebase-admin" +version = "6.0.1" +description = "Firebase Admin Python SDK" +category = "main" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +cachecontrol = ">=0.12.6" +google-api-core = {version = ">=1.22.1,<3.0.0dev", extras = ["grpc"], markers = "platform_python_implementation != \"PyPy\""} +google-api-python-client = ">=1.7.8" +google-cloud-firestore = {version = ">=2.1.0", markers = "platform_python_implementation != \"PyPy\""} +google-cloud-storage = ">=1.37.1" +pyjwt = {version = ">=2.5.0", extras = ["crypto"]} + [[package]] name = "flask" version = "2.2.2" @@ -272,6 +323,8 @@ python-versions = ">=3.7" [package.dependencies] google-auth = ">=1.25.0,<3.0dev" googleapis-common-protos = ">=1.56.2,<2.0dev" +grpcio = {version = ">=1.33.2,<2.0dev", optional = true, markers = "extra == \"grpc\""} +grpcio-status = {version = ">=1.33.2,<2.0dev", optional = true, markers = "extra == \"grpc\""} protobuf = ">=3.20.1,<5.0.0dev" requests = ">=2.18.0,<3.0.0dev" @@ -280,6 +333,21 @@ grpc = ["grpcio (>=1.33.2,<2.0dev)", "grpcio-status (>=1.33.2,<2.0dev)"] grpcgcp = ["grpcio-gcp (>=0.2.2,<1.0dev)"] grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0dev)"] +[[package]] +name = "google-api-python-client" +version = "2.64.0" +description = "Google API Client Library for Python" +category = "main" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +google-api-core = ">=1.31.5,<2.0.0 || >2.3.0,<3.0.0dev" +google-auth = ">=1.19.0,<3.0.0dev" +google-auth-httplib2 = ">=0.1.0" +httplib2 = ">=0.15.0,<1dev" +uritemplate = ">=3.0.1,<5" + [[package]] name = "google-auth" version = "2.11.0" @@ -300,6 +368,19 @@ enterprise_cert = ["cryptography (==36.0.2)", "pyopenssl (==22.0.0)"] pyopenssl = ["pyopenssl (>=20.0.0)"] reauth = ["pyu2f (>=0.1.5)"] +[[package]] +name = "google-auth-httplib2" +version = "0.1.0" +description = "Google Authentication Library: httplib2 transport" +category = "main" +optional = false +python-versions = "*" + +[package.dependencies] +google-auth = "*" +httplib2 = ">=0.15.0" +six = "*" + [[package]] name = "google-cloud-core" version = "2.3.2" @@ -315,6 +396,20 @@ google-auth = ">=1.25.0,<3.0dev" [package.extras] grpc = ["grpcio (>=1.38.0,<2.0dev)"] +[[package]] +name = "google-cloud-firestore" +version = "2.7.1" +description = "Google Cloud Firestore API client library" +category = "main" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +google-api-core = {version = ">=1.32.0,<2.0.0 || >=2.8.0,<3.0.0dev", extras = ["grpc"]} +google-cloud-core = ">=1.4.1,<3.0.0dev" +proto-plus = ">=1.22.0,<2.0.0dev" +protobuf = ">=3.20.2,<5.0.0dev" + [[package]] name = "google-cloud-storage" version = "2.5.0" @@ -373,6 +468,33 @@ protobuf = ">=3.15.0,<5.0.0dev" [package.extras] grpc = ["grpcio (>=1.0.0,<2.0.0dev)"] +[[package]] +name = "grpcio" +version = "1.49.1" +description = "HTTP/2-based RPC framework" +category = "main" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +six = ">=1.5.2" + +[package.extras] +protobuf = ["grpcio-tools (>=1.49.1)"] + +[[package]] +name = "grpcio-status" +version = "1.49.1" +description = "Status proto mapping for gRPC" +category = "main" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +googleapis-common-protos = ">=1.5.5" +grpcio = ">=1.49.1" +protobuf = ">=4.21.3" + [[package]] name = "gunicorn" version = "20.1.0" @@ -387,6 +509,17 @@ gevent = ["gevent (>=1.4.0)"] setproctitle = ["setproctitle"] tornado = ["tornado (>=0.2)"] +[[package]] +name = "httplib2" +version = "0.20.4" +description = "A comprehensive HTTP client library." +category = "main" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" + +[package.dependencies] +pyparsing = {version = ">=2.4.2,<3.0.0 || >3.0.0,<3.0.1 || >3.0.1,<3.0.2 || >3.0.2,<3.0.3 || >3.0.3,<4", markers = "python_version > \"3.0\""} + [[package]] name = "huggingface-hub" version = "0.9.1" @@ -515,6 +648,14 @@ category = "main" optional = false python-versions = ">=3.7" +[[package]] +name = "msgpack" +version = "1.0.4" +description = "MessagePack serializer" +category = "main" +optional = false +python-versions = "*" + [[package]] name = "multidict" version = "6.0.2" @@ -618,6 +759,20 @@ xxhash = ["xxhash (>=1.4.3)"] sftp = ["paramiko (>=2.7.0)"] progress = ["tqdm (>=4.41.0,<5.0.0)"] +[[package]] +name = "proto-plus" +version = "1.22.1" +description = "Beautiful, Pythonic protocol buffers." +category = "main" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +protobuf = ">=3.19.0,<5.0.0dev" + +[package.extras] +testing = ["google-api-core[grpc] (>=1.31.5)"] + [[package]] name = "protobuf" version = "4.21.5" @@ -680,6 +835,24 @@ category = "main" optional = false python-versions = "*" +[[package]] +name = "pyjwt" +version = "2.5.0" +description = "JSON Web Token implementation in Python" +category = "main" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +cryptography = {version = ">=3.3.1", optional = true, markers = "extra == \"crypto\""} +types-cryptography = {version = ">=3.3.21", optional = true, markers = "extra == \"crypto\""} + +[package.extras] +crypto = ["cryptography (>=3.3.1)", "types-cryptography (>=3.3.21)"] +dev = ["sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface", "cryptography (>=3.3.1)", "types-cryptography (>=3.3.21)", "pytest (>=6.0.0,<7.0.0)", "coverage[toml] (==5.0.4)", "pre-commit"] +docs = ["sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"] +tests = ["pytest (>=6.0.0,<7.0.0)", "coverage[toml] (==5.0.4)"] + [[package]] name = "pyparsing" version = "3.0.9" @@ -1007,6 +1180,14 @@ torch-speech = ["torchaudio", "librosa", "pyctcdecode (>=0.3.0)", "phonemizer", torchhub = ["filelock", "huggingface-hub (>=0.1.0,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf (<=3.20.1)", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=1.0,<1.12)", "tokenizers (>=0.11.1,!=0.11.3,<0.13)", "tqdm (>=4.27)"] vision = ["pillow"] +[[package]] +name = "types-cryptography" +version = "3.3.23.1" +description = "Typing stubs for cryptography" +category = "main" +optional = false +python-versions = "*" + [[package]] name = "typing-extensions" version = "4.3.0" @@ -1015,6 +1196,14 @@ category = "main" optional = false python-versions = ">=3.7" +[[package]] +name = "uritemplate" +version = "4.1.1" +description = "Implementation of RFC 6570 URI Templates" +category = "main" +optional = false +python-versions = ">=3.6" + [[package]] name = "urllib3" version = "1.26.12" @@ -1076,7 +1265,7 @@ multidict = ">=4.0" [metadata] lock-version = "1.1" python-versions = ">=3.10,<3.12" -content-hash = "7820568c10ef97389c45c095088b376465d67ad5742878b0e1d43073e8280919" +content-hash = "b3568a78e5ca9083d27941b63440abc199782f8297b03ad0b625908544ff61cf" [metadata.files] aiohttp = [] @@ -1086,27 +1275,36 @@ async-timeout = [] atomicwrites = [] attrs = [] audioread = [] +cachecontrol = [] cachetools = [] certifi = [] cffi = [] charset-normalizer = [] click = [] colorama = [] +cryptography = [] datasets = [] decorator = [] dill = [] filelock = [] +firebase-admin = [] flask = [] frozenlist = [] fsspec = [] google-api-core = [] +google-api-python-client = [] google-auth = [] +google-auth-httplib2 = [] google-cloud-core = [] +google-cloud-firestore = [] google-cloud-storage = [] google-crc32c = [] google-resumable-media = [] googleapis-common-protos = [] +grpcio = [] +grpcio-status = [] gunicorn = [] +httplib2 = [] huggingface-hub = [] idna = [] iniconfig = [] @@ -1117,6 +1315,7 @@ librosa = [] llvmlite = [] loguru = [] markupsafe = [] +msgpack = [] multidict = [] multiprocess = [] nodeenv = [] @@ -1126,6 +1325,7 @@ packaging = [] pandas = [] pluggy = [] pooch = [] +proto-plus = [] protobuf = [] py = [] pyarrow = [] @@ -1133,6 +1333,7 @@ pyasn1 = [] pyasn1-modules = [] pycparser = [] pyhumps = [] +pyjwt = [] pyparsing = [] pyright = [] pytest = [] @@ -1155,7 +1356,9 @@ toml = [] torch = [] tqdm = [] transformers = [] +types-cryptography = [] typing-extensions = [] +uritemplate = [] urllib3 = [] werkzeug = [] win32-setctime = [] diff --git a/services/trainer/pyproject.toml b/services/trainer/pyproject.toml index 95144bc..59653e2 100644 --- a/services/trainer/pyproject.toml +++ b/services/trainer/pyproject.toml @@ -18,6 +18,7 @@ pyhumps = "^3.7.3" google-cloud-storage = "^2.5.0" scipy = "^1.9.1" librosa = "^0.9.2" +firebase-admin = "^6.0.1" [tool.poetry.dev-dependencies] pytest = "^6.2.5" diff --git a/services/trainer/trainer/firebase.py b/services/trainer/trainer/firebase.py new file mode 100644 index 0000000..7b13029 --- /dev/null +++ b/services/trainer/trainer/firebase.py @@ -0,0 +1,14 @@ +import firebase_admin +from firebase_admin import credentials, firestore + +PROJECT_ID = "elpiscloud" + + +def get_firestore_client() -> firestore.firestore.Client: + """Returns a firestore client for the elpiscloud project.""" + cred = credentials.ApplicationDefault() + app = firebase_admin.initialize_app( + cred, + {"projectId": PROJECT_ID}, + ) + return firestore.client(app) diff --git a/services/trainer/trainer/main.py b/services/trainer/trainer/main.py index d771fd1..07d3103 100644 --- a/services/trainer/trainer/main.py +++ b/services/trainer/trainer/main.py @@ -3,11 +3,14 @@ import os from http import HTTPStatus from pathlib import Path +from typing import Optional from flask import Flask, Response, request +from google.cloud.firestore import DocumentReference from loguru import logger from trainer.cloud_storage import download_blob, list_blobs_with_prefix, upload_blob -from trainer.model_metadata import ModelMetadata +from trainer.firebase import get_firestore_client +from trainer.model_metadata import ModelMetadata, TrainingStatus from trainer.trainer import train app = Flask(__name__) @@ -74,16 +77,32 @@ def process_training_request(metadata: ModelMetadata): Paramters: metadata: The metadata of the model training job to perform. """ - logger.info("Begin executing training job") + status = get_model_status(metadata) + if status is None: + logger.info("Model was deleted from firestore. Exiting.") + return - dataset_path = DATA_PATH / "datasets" / metadata.user_id / metadata.model_name - download_dataset(metadata=metadata, dataset_path=dataset_path) + logger.info(f"Firestore model status: {status}") + if status == TrainingStatus.TRAINING: + logger.info(f"Already training! Exiting.") + return - model_path = train( - metadata=metadata, data_path=DATA_PATH, dataset_path=dataset_path - ) + set_model_status(metadata, TrainingStatus.TRAINING) + + try: + dataset_path = DATA_PATH / "datasets" / metadata.user_id / metadata.model_name + download_dataset(metadata=metadata, dataset_path=dataset_path) + model_path = train( + metadata=metadata, data_path=DATA_PATH, dataset_path=dataset_path + ) - upload_model(metadata, model_path) + upload_model(metadata, model_path) + set_model_status(metadata, TrainingStatus.FINISHED) + logger.success("Training successful!") + + except: + logger.error(f"Training failed for model: {metadata}") + set_model_status(metadata, TrainingStatus.ERROR) def download_dataset(metadata: ModelMetadata, dataset_path: Path) -> None: @@ -114,6 +133,56 @@ def upload_model(metadata: ModelMetadata, model_path: Path) -> None: upload_blob(MODEL_BUCKET, model_path / file, blob_name) +def get_model_status(metadata: ModelMetadata) -> Optional[TrainingStatus]: + """Gets the current status of the model in firestore, if it exists. + + Parameters: + metadata: Some information about the model to be trained. + + Returns: + The status of the model, or None if the model was deleted. + """ + document = _get_model_document_reference(metadata) + snapshot = document.get() + + data = snapshot.to_dict() + # Checks if dataset was deleted before we start training. + if data is None: + return + + try: + updated_model = ModelMetadata.from_dict(data) + return updated_model.status + except: + return + + +def set_model_status(metadata: ModelMetadata, status: TrainingStatus) -> None: + """Sets the current status of the model in firestore, if it exists. + + Parameters: + metadata: Some information about the model to be trained. + status: The status to set on the firestore model. + """ + document = _get_model_document_reference(metadata) + snapshot = document.get() + + if not snapshot.exists: + return + + document.update({"status": status.value}) + + +def _get_model_document_reference(metadata: ModelMetadata) -> DocumentReference: + db = get_firestore_client() + return ( + db.collection("users") + .document(metadata.user_id) + .collection("models") + .document(metadata.model_name) + ) + + if __name__ == "__main__": PORT = int(os.getenv("PORT", "8080"))