diff --git a/.env.sample b/.env.sample index 52c5dbc..ec2104a 100644 --- a/.env.sample +++ b/.env.sample @@ -9,14 +9,26 @@ SECRET_KEY=CHANGE_ME # 【请修改】管理员邮箱,第一次启动会创建此账号,默认密码为 123123,请及时修改密码! # 之后每次程序启动时,会将此账号设置为管理员 ADMIN_EMAIL=admin@moeflow.com +# initial password for auto-created admin user +ADMIN_INITIAL_PASSWORD= + +# Database +# MONGODB_URI="" # takes precedence over other MONGODB_* entries +# MONGODB_DB_NAME=moeflow +# MONGODB_USER=moeflow +# MONGODB_PASS=CHANGE_ME + +# Job Queue +# CELERY_BROKER_URL="" # takes precedence over other RABBITMQ_* entries +# RABBITMQ_USER=moeflow +# RABBITMQ_PASS=CHANGE_ME +# RABBITMQ_VHOST_NAME=moeflow -# APP 专用的 MongoDB 数据库名称 -MONGODB_DB_NAME=moeflow # ----------- # Storage 配置 # ----------- -# 目前支持 LOCAL_STORAGE 和 OSS 和 OPENDAL +# 目前支持 LOCAL_STORAGE 和 OSS STORAGE_TYPE=LOCAL_STORAGE # STORAGE_DOMAIN: 返回给客户端的图片URL前缀 # 1. 如果STORAGE_TYPE为OSS @@ -41,7 +53,7 @@ OSS_PROCESS_COVER_NAME=cover OSS_PROCESS_SAFE_CHECK_NAME=safe-check # ----------- -# CDN 配置 +# CDN 配置 # ----------- # 如果绑定了 CDN 来加速 OSS,且开启了 CDN 的[阿里云 OSS 私有 Bucket 回源]和[URL 鉴权], # 此时需要设置 OSS_VIA_CDN = True,并设置 CDN URL 鉴权主/备 KEY diff --git a/.env.test.sample b/.env.test.sample new file mode 100644 index 0000000..20e1239 --- /dev/null +++ b/.env.test.sample @@ -0,0 +1,77 @@ +# Env variables for tests +TESTING=YES +LOG_LEVEL=DEBUG +SITE_NAME=萌翻TEST +SECRET_KEY=SECRET +ADMIN_EMAIL=admin@moeflow.com + +MONGODB_URI="mongodb://moeflow:CHANGE_ME@127.0.0.1:27017/moeflow_test?authSource=admin" +# MONGODB_DB_NAME=moeflow +# MONGODB_USER=moeflow +# MONGODB_PASS=CHANGE_ME +CELERY_BROKER_URL="amqp://moeflow:CHANGE_ME@127.0.0.1:5672/moeflow" # takes precedence over other RABBITMQ_* entries +# RABBITMQ_USER=moeflow +# RABBITMQ_PASS=CHANGE_ME +# RABBITMQ_VHOST_NAME=moeflow_test + +# Storage +STORAGE_TYPE=LOCAL_STORAGE +# STORAGE_DOMAIN: 返回给客户端的图片URL前缀 +# 1. 如果STORAGE_TYPE为OSS +# - 未设置自定义域名则填写阿里云提供的 OSS 域名,格式如:https://..aliyuncs.com/ +# - 如果绑定了 CDN 来加速 OSS,则填写绑定在 CDN 的域名 +# 2. 如果STORAGE_TYPE为LOCAL_STORAGE +# - 本地储存填写绑定到服务器的域名+"/storage/",格式如:http(s)://.com/storage/, +# 3. 如果STORAGE_TYPE为OPENDAL: 不生效 (图片URL将由OPENDAL_STORAGE_PROVIDER决定) +STORAGE_DOMAIN=http://127.0.0.1:5000/storage/ +# (可不修改) 允许上传文件的最大大小(MB),默认 1GB +MAX_CONTENT_LENGTH_MB=1024 + +## OSS_*: STORAGE_TYPE为OSS时的配置 +OSS_ACCESS_KEY_ID= +OSS_ACCESS_KEY_SECRET= +# OSS Endpoint(地域节点) +# 含协议名,形如 https://oss-cn-shanghai.aliyuncs.com/ +OSS_ENDPOINT= +OSS_BUCKET_NAME= +# (可不修改) OSS 图片处理规则名称 +OSS_PROCESS_COVER_NAME=cover +OSS_PROCESS_SAFE_CHECK_NAME=safe-check + +# ----------- +# CDN 配置 +# ----------- +# 如果绑定了 CDN 来加速 OSS,且开启了 CDN 的[阿里云 OSS 私有 Bucket 回源]和[URL 鉴权], +# 此时需要设置 OSS_VIA_CDN = True,并设置 CDN URL 鉴权主/备 KEY +OSS_VIA_CDN=True +CDN_URL_KEY_A= +CDN_URL_KEY_B= + +# ----------- +# Email 配置 +# ----------- +# 是否发送用户邮件(验证码等) +ENABLE_USER_EMAIL=False +# 是否发送日志邮件 +ENABLE_LOG_EMAIL=False +# SMTP 服务器地址 +EMAIL_SMTP_HOST= +# SMTP 服务器端口 +EMAIL_SMTP_PORT= +# 是否使用 SSL 连接 SMTP 服务器 +EMAIL_USE_SSL=True +# 发件邮箱地址 +EMAIL_ADDRESS= +# SMTP 服务器登陆用户名,通常是邮箱全称 +EMAIL_USERNAME= +# SMTP 服务器登陆密码 +EMAIL_PASSWORD= +# 用户回信邮箱地址 +EMAIL_REPLY_ADDRESS= +# 网站错误报告邮箱地址 +EMAIL_ERROR_ADDRESS= + +# Options for non default features. only enable if you know what you are doing. +# CELERY_BROKER_URL="amqp://moeflow:PLEASE_CHANGE_THIS@moeflow-rabbitmq:5672/moeflow" +# CELERY_BACKEND_URL='mongodb://moeflow:PLEASE_CHANGE_THIS@moeflow-mongodb:27017/moeflow?authSource=admin' +# MIT_STORAGE_ROOT=/app/storage diff --git a/.github/workflows/check-pr.yml b/.github/workflows/check-pr.yml index 45ea40b..0d37025 100644 --- a/.github/workflows/check-pr.yml +++ b/.github/workflows/check-pr.yml @@ -5,7 +5,7 @@ on: workflow_dispatch: jobs: - check-pr: + static-check: runs-on: ubuntu-latest steps: @@ -17,3 +17,31 @@ jobs: - run: pip install -r requirements.txt - run: ruff check . - run: ruff format --diff . + + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - run: docker-compose -f tests/deps.yaml up -d + - uses: actions/setup-python@v5 + with: + python-version: '3.10' + cache: 'pip' + - run: pip install -r requirements.txt + - run: cp -rv .env.test.sample .env.test + - uses: pavelzw/pytest-action@v2 + with: + emoji: false + verbose: true + job-summary: true + - uses: codecov/codecov-action@v4.0.1 + if: always() + with: + token: ${{ secrets.CODECOV_TOKEN }} + # XXX: can't if (SECRET_DEFINED) for this step + - name: save test report + uses: actions/upload-artifact@v4 + if: always() + with: + name: report.html + path: report.html diff --git a/.gitignore b/.gitignore index 71b2e8e..e191704 100644 --- a/.gitignore +++ b/.gitignore @@ -36,6 +36,7 @@ pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports +/report.html htmlcov/ .tox/ .coverage @@ -80,7 +81,8 @@ celerybeat-schedule *.sage.py # dotenv -.env +.env* +!.*.sample # virtualenv .venv @@ -117,4 +119,4 @@ gen /files/tmp/ # 储存文件 -/storage/* \ No newline at end of file +/storage/* diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..5a3d863 --- /dev/null +++ b/Makefile @@ -0,0 +1,44 @@ +PYTEST_COV_ARGS = + +FORCE: ; + +create-venv: + python3.10 -mvenv venv + +deps: + venv/bin/pip install -r requirements.txt + +remove-venv: FORCE + rm -rf venv + +recreate-venv: remove-venv create-venv + +lint: + venv/bin/ruff check + +lint-fix: + venv/bin/ruff --fix + +format: + venv/bin/ruff format + +requirements.txt: deps-top.txt recreate-venv + venv/bin/pip install -r deps-top.txt + echo '# GENERATED: run make requirements.txt to recreate lock file' > requirements.txt + venv/bin/pipdeptree --freeze >> requirements.txt + +test: test_all + +test_all: + venv/bin/pytest + +test_all_parallel: + # TODO: fix this + venv/bin/pytest -n 8 + +test_single: + venv/bin/pytest tests/api/test_file_api.py + +test_logging: + #--capture=no + venv/bin/pytest --capture=sys --log-cli-level=DEBUG tests/base/test_logging.py diff --git a/README.md b/README.md index 1b87cfb..d30c7a3 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ # 萌翻[MoeFlow]后端项目 +[![codecov](https://codecov.io/gh/moeflow-com/moeflow-backend/graph/badge.svg?token=LQJBLB495F)](https://codecov.io/gh/moeflow-com/moeflow-backend) + 由于此版本调整了部分 API 接口, **请配合萌翻前端 Version.1.0.1 版本使用!** 直接使用旧版可能在修改(创建)团队和项目时报错。 此版本需配置 **阿里云 OSS** 作为文件存储。如果需要使用其他文件存储方式,可以选择使用以下的分支版本: diff --git a/app/__init__.py b/app/__init__.py index b0acfd9..0b9116e 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -1,162 +1,40 @@ import os +import logging -from celery import Celery from flask import Flask, g, request -from flask_apikit import APIKit -from flask_babel import Babel -from app.constants.locale import Locale -from app.core.rbac import AllowApplyType, ApplicationCheckType -from app.services.google_storage import GoogleStorage -from app.services.oss import OSS -from app.utils.logging import configure_logger, logger +from .factory import ( + app_config, + create_celery, + create_flask_app, + init_flask_app, + babel, + oss, + gs_vision, +) -from .apis import register_apis -import app.config as _app_config +from app.constants.locale import Locale +from app.utils.logging import configure_root_logger, configure_extra_logs -app_config = { - k: getattr(_app_config, k) for k in dir(_app_config) if not k.startswith("_") -} +configure_root_logger() +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) # 基本路径 APP_PATH = os.path.abspath(os.path.dirname(__file__)) FILE_PATH = os.path.abspath(os.path.join(APP_PATH, "..", "files")) # 一般文件 TMP_PATH = os.path.abspath(os.path.join(FILE_PATH, "tmp")) # 临时文件存放地址 STORAGE_PATH = os.path.abspath(os.path.join(APP_PATH, "..", "storage")) # 储存地址 -# 插件 -babel = Babel() -oss = OSS() -gs_vision = GoogleStorage() -apikit = APIKit() - -config_path_env = "CONFIG_PATH" - - -def create_default_team(admin_user): - from app.models.team import Team, TeamRole - from app.models.site_setting import SiteSetting - logger.info("-" * 50) - if Team.objects().count() == 0: - logger.info("已建立默认团队") - team = Team.create( - name="默认团队", - creator=admin_user, - ) - team.intro = "所有新用户会自动加入此团队,如不需要,站点管理员可以在“站点管理-自动加入的团队 ID”中删除此团队 ID。" - team.allow_apply_type = AllowApplyType.ALL - team.application_check_type = ApplicationCheckType.ADMIN_CHECK - team.default_role = TeamRole.by_system_code("member") - team.save() - site_setting = SiteSetting.get() - site_setting.auto_join_team_ids = [team.id] - site_setting.save() - else: - logger.info("已有团队,跳过建立默认团队") - - -def create_or_override_default_admin(app): - """创建或覆盖默认管理员""" - from app.models.user import User - - admin_user = User.get_by_email(app.config["ADMIN_EMAIL"]) - if admin_user: - if admin_user.admin is False: - admin_user.admin = True - admin_user.save() - logger.info("-" * 50) - logger.info("已将 {} 设置为管理员".format(app.config["ADMIN_EMAIL"])) - else: - admin_user = User.create( - name="Admin", - email=app.config["ADMIN_EMAIL"], - password="123123", - ) - admin_user.admin = True - admin_user.save() - logger.info( - "已创建管理员 {}, 默认密码为 123123,请及时修改!".format(admin_user.email) - ) - return admin_user +# Singletons +flask_app = create_flask_app(Flask(__name__)) +configure_extra_logs(flask_app) +celery = create_celery(flask_app) +init_flask_app(flask_app) def create_app(): - app = Flask(__name__) - app.config.from_mapping(app_config) - configure_logger(app) # 配置日志记录(放在最前,会被下面调用) - - logger.info("-" * 50) - # 连接数据库 - from app.models import connect_db - - connect_db(app.config) - # 注册api蓝本 - register_apis(app) - # 初始化插件 - babel.init_app(app) - apikit.init_app(app) - - logger.info("-" * 50) - logger.info("站点支持语言: " + str([str(i) for i in babel.list_translations()])) - oss.init(app.config) # 文件储存 - - return app - - -def init_db(app: Flask): - # 初始化角色,语言 - from app.models.language import Language - from app.models.project import ProjectRole - from app.models.team import TeamRole - from app.models.site_setting import SiteSetting - - TeamRole.init_system_roles() - ProjectRole.init_system_roles() - Language.init_system_languages() - SiteSetting.init_site_setting() - admin_user = create_or_override_default_admin(app) - create_default_team(admin_user) - - -def create_celery() -> Celery: - # 为celery创建app - app = Flask(__name__) - app.config.from_mapping(app_config) - # 通过app配置创建celery实例 - created = Celery( - app.name, - broker=app.config["CELERY_BROKER_URL"], - backend=app.config["CELERY_BACKEND_URL"], - mongodb_backend_settings=app.config["CELERY_MONGODB_BACKEND_SETTINGS"], - ) - created.conf.update({"app_config": app.config}) - created.autodiscover_tasks( - packages=[ - "app.tasks.email", - "app.tasks.file_parse", - "app.tasks.output_team_projects", - "app.tasks.output_project", - "app.tasks.ocr", - "app.tasks.import_from_labelplus", - "app.tasks.thumbnail", - "app.tasks.mit", # only included for completeness's sake. its impl is in other repo. - ], - related_name=None, - ) - created.conf.task_routes = ( - [ - # TODO 'output' should be named better. - # its original purpose was cpu-intensive jobs that may block light ones. - ("tasks.output_project_task", {"queue": "output"}), - ("tasks.import_from_labelplus_task", {"queue": "output"}), - ("tasks.mit.*", {"queue": "mit"}), - ("*", {"queue": "default"}), # default queue for all other tasks - ], - ) - return created - - -celery = create_celery() + return flask_app @babel.localeselector @@ -179,3 +57,15 @@ def get_locale(): # if current_user: # if current_user.timezone: # return current_user.timezone + +__all__ = [ + "oss", + "gs_vision", + "flask_app", + "app_config", + "celery", + "APP_PATH", + "STORAGE_PATH", + "TMP_PATH", + "FILE_PATH", +] diff --git a/app/apis/__init__.py b/app/apis/__init__.py index d6cc4fc..3d4c77d 100644 --- a/app/apis/__init__.py +++ b/app/apis/__init__.py @@ -2,10 +2,12 @@ 所有的API编写在此 """ -from flask import Blueprint, Flask +import logging -from app.utils.logging import logger +from flask import Blueprint, Flask +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) """ @apiDefine TokenHeader @@ -40,13 +42,12 @@ def register_apis(app: Flask): :param app: :return: """ - logger.info("-" * 50) - logger.info("注册蓝本:") + logger.info("Register route blueprints") # 获取urls中所有蓝本 from . import urls blueprints = [v for k, v in vars(urls).items() if isinstance(v, Blueprint)] for blueprint in blueprints: prefix = "/" if blueprint.url_prefix is None else blueprint.url_prefix - logger.info(" - {}: {}".format(blueprint.name, prefix)) + logger.debug(" - {}: {}".format(blueprint.name, prefix)) app.register_blueprint(blueprint) diff --git a/app/apis/site_setting.py b/app/apis/site_setting.py index b491199..6d4f85c 100644 --- a/app/apis/site_setting.py +++ b/app/apis/site_setting.py @@ -52,8 +52,8 @@ def put(self): site_setting.whitelist_emails = data["whitelist_emails"] site_setting.only_allow_admin_create_team = data["only_allow_admin_create_team"] site_setting.auto_join_team_ids = data["auto_join_team_ids"] - site_setting.homepage_html = data["homepage_html"] - site_setting.homepage_css = data["homepage_css"] + site_setting.homepage_html = data.get("homepage_html", "") + site_setting.homepage_css = data.get("homepage_css", "") site_setting.save() site_setting.reload() return site_setting.to_api() diff --git a/app/config.py b/app/config.py index 54056f3..83399fb 100644 --- a/app/config.py +++ b/app/config.py @@ -3,21 +3,31 @@ # 开发测试配置可放在 configs 文件夹下(已 gitignore)或项目外 # =========== from os import environ as env +import urllib.parse as urlparse # ----------- # 基础设置 # ----------- SITE_NAME = env["SITE_NAME"] -DOMAIN = env["DOMAIN"] SECRET_KEY = env["SECRET_KEY"] # 必填 - 密钥 -DEBUG = False -TESTING = False +LOG_LEVEL = env.get("LOG_LEVEL", "INFO") +# DEPRECATED: please use modern container logging collector +LOG_PATH = env.get("LOG_PATH") MAX_CONTENT_LENGTH = int(env.get("MAX_CONTENT_LENGTH_MB", 1024)) * 1024 * 1024 ADMIN_EMAIL = env["ADMIN_EMAIL"] +ADMIN_INITIAL_PASSWORD = env.get("ADMIN_INITIAL_PASSWORD", "123123") +# TODO reduce code relying on this +TESTING = env.get("TESTING") == "YES" # ----------- # Mongo 数据库 # ----------- -DB_URI = f"mongodb://{env['MONGODB_USER']}:{env['MONGODB_PASS']}@moeflow-mongodb:27017/{env['MONGODB_DB_NAME']}?authSource=admin" +DB_URI = env.get("MONGODB_URI") + +DB_URI = ( + DB_URI + or f"mongodb://{env['MONGODB_USER']}:{env['MONGODB_PASS']}@moeflow-mongodb:27017/{env['MONGODB_DB_NAME']}?authSource=admin" +) + # ----------- # i18n # ----------- @@ -45,7 +55,7 @@ # 未设置自定义域名则填写阿里云提供的 OSS 域名,格式如:https://..aliyuncs.com/ # 如果绑定了 CDN 来加速 OSS,则填写绑定在 CDN 的域名 # 本地储存填写绑定到服务器的域名,需用 nginx 指向 storage 文件夹,格式如:https://.com/storage/ -STORAGE_DOMAIN = env.get("STORAGE_DOMAIN", "http://" + DOMAIN + "/storage/") +STORAGE_DOMAIN = env["STORAGE_DOMAIN"] OSS_ACCESS_KEY_ID = env.get("OSS_ACCESS_KEY_ID", "") OSS_ACCESS_KEY_SECRET = env.get("OSS_ACCESS_KEY_SECRET", "") OSS_ENDPOINT = env.get("OSS_ENDPOINT", "") @@ -112,14 +122,20 @@ # ----------- # Celery # ----------- -CELERY_BROKER_URL = env.get( - "CELERY_BROKER_URL", - f"amqp://{env['RABBITMQ_USER']}:{env['RABBITMQ_PASS']}@moeflow-rabbitmq:5672/{env['RABBITMQ_VHOST_NAME']}", +CELERY_BROKER_URL = env.get("CELERY_BROKER_URL") + +CELERY_BROKER_URL = ( + CELERY_BROKER_URL + or f"amqp://{env['RABBITMQ_USER']}:{env['RABBITMQ_PASS']}@moeflow-rabbitmq:5672/{env['RABBITMQ_VHOST_NAME']}" ) CELERY_BACKEND_URL = env.get("CELERY_BACKEND_URL", DB_URI) -CELERY_MONGODB_BACKEND_SETTINGS = { - "database": env["MONGODB_DB_NAME"], - "taskmeta_collection": "celery_taskmeta", + +_DB_URI_PARSED = urlparse.urlparse(DB_URI) +CELERY_BACKEND_SETTINGS = { + "mongodb_backend_settings": { + "database": _DB_URI_PARSED.path[1:], + "taskmeta_collection": "celery_taskmeta", + } } # ----------- # APIKit diff --git a/app/core/rbac.py b/app/core/rbac.py index 21cff56..e0fe0dc 100644 --- a/app/core/rbac.py +++ b/app/core/rbac.py @@ -4,6 +4,7 @@ """ import datetime +import logging from app.exceptions import UserNotExistError, CreatorCanNotLeaveError from flask_babel import gettext, lazy_gettext @@ -28,10 +29,12 @@ from app.models.invitation import Invitation from app.constants.base import IntType from app.constants.role import RoleType -from app.utils.logging import logger from app.utils.mongo import mongo_order, mongo_slice from typing import List, Any, Type +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARN) + class AllowApplyType(IntType): """ diff --git a/app/core/views.py b/app/core/views.py index fff9470..98c90d5 100644 --- a/app/core/views.py +++ b/app/core/views.py @@ -1,13 +1,11 @@ -from flask import g +from typing import Optional +from flask import g from flask_apikit.views import APIView -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from app.models.user import User +from app.models.user import User class MoeAPIView(APIView): @property - def current_user(self) -> "User": + def current_user(self) -> Optional[User]: return g.get("current_user") diff --git a/app/factory.py b/app/factory.py new file mode 100644 index 0000000..60e081d --- /dev/null +++ b/app/factory.py @@ -0,0 +1,140 @@ +import logging +from celery import Celery +from flask import Flask +from flask_apikit import APIKit +from flask_babel import Babel +from app.core.rbac import AllowApplyType, ApplicationCheckType +from app.services.google_storage import GoogleStorage +import app.config as _app_config +from app.services.oss import OSS +from .apis import register_apis + +from app.models import connect_db + +logger = logging.getLogger(__name__) + +# singleton modules +babel = Babel() +apikit = APIKit() +oss = OSS() +gs_vision = GoogleStorage() + +app_config = { + k: getattr(_app_config, k) for k in dir(_app_config) if not k.startswith("_") +} + + +def create_flask_app(app: Flask) -> Flask: + app.config.from_mapping(app_config) + connect_db(app.config) + # print("WTF", app.logger.level) + # WTF: why is logging so fuking hard in py ecosystem? + # prevent flask from duplicating logs + # app.logger.removeHandler(flask_default_handler) + # app.logger.propagate = False + return app + + +def init_flask_app(app: Flask): + register_apis(app) + babel.init_app(app) + apikit.init_app(app) + logger.info("-" * 50) + logger.info("站点支持语言: " + str([str(i) for i in babel.list_translations()])) + oss.init(app.config) # 文件储存 + + +def create_celery(app: Flask) -> Celery: + # 通过app配置创建celery实例 + created = Celery( + app.name, + broker=app.config["CELERY_BROKER_URL"], + backend=app.config["CELERY_BACKEND_URL"], + **app.config["CELERY_BACKEND_SETTINGS"], + ) + created.conf.update({"app_config": app.config}) + created.autodiscover_tasks( + packages=[ + "app.tasks.email", + "app.tasks.file_parse", + "app.tasks.output_team_projects", + "app.tasks.output_project", + "app.tasks.ocr", + "app.tasks.import_from_labelplus", + "app.tasks.thumbnail", + "app.tasks.mit", # only included for completeness's sake. its impl is in other repo. + ], + related_name=None, + ) + created.conf.task_routes = ( + [ + # TODO 'output' should be named better. + # its original purpose was cpu-intensive jobs that may block light ones. + ("tasks.output_project_task", {"queue": "output"}), + ("tasks.import_from_labelplus_task", {"queue": "output"}), + ("tasks.mit.*", {"queue": "mit"}), + ("*", {"queue": "default"}), # default queue for all other tasks + ], + ) + return created + + +def create_or_override_default_admin(app: Flask): + """创建或覆盖默认管理员""" + from app.models.user import User + + admin_user = User.get_by_email(app.config["ADMIN_EMAIL"]) + if admin_user: + if admin_user.admin is False: + admin_user.admin = True + admin_user.save() + logger.debug("已将 {} 设置为管理员".format(app.config["ADMIN_EMAIL"])) + else: + admin_user = User.create( + name="Admin", + email=app.config["ADMIN_EMAIL"], + password=app.config["ADMIN_INITIAL_PASSWORD"], + ) + admin_user.admin = True + admin_user.save() + logger.debug( + "已创建管理员 {}, 默认密码为 123123,请及时修改!".format(admin_user.email) + ) + return admin_user + + +def create_default_team(admin_user): + from app.models.team import Team, TeamRole + from app.models.site_setting import SiteSetting + + if Team.objects().count() == 0: + logger.debug("已建立默认团队") + team = Team.create( + name="默认团队", + creator=admin_user, + ) + team.intro = "所有新用户会自动加入此团队,如不需要,站点管理员可以在“站点管理-自动加入的团队 ID”中删除此团队 ID。" + team.allow_apply_type = AllowApplyType.ALL + team.application_check_type = ApplicationCheckType.ADMIN_CHECK + team.default_role = TeamRole.by_system_code("member") + team.save() + site_setting = SiteSetting.get() + site_setting.auto_join_team_ids = [team.id] + site_setting.save() + else: + logger.debug("已有团队,跳过建立默认团队") + + +def init_db(app: Flask): + # 初始化角色,语言 + from app.models.language import Language + from app.models.project import ProjectRole + from app.models.team import TeamRole + from app.models.site_setting import SiteSetting + + TeamRole.init_system_roles() + ProjectRole.init_system_roles() + Language.init_system_languages() + SiteSetting.init_site_setting() + admin_user = create_or_override_default_admin(app) + create_default_team(admin_user) diff --git a/app/models/__init__.py b/app/models/__init__.py index ba8eef8..97c19f9 100644 --- a/app/models/__init__.py +++ b/app/models/__init__.py @@ -2,16 +2,17 @@ 模型 """ +import logging from mongoengine import connect -from app.utils.logging import logger +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) def connect_db(config): - logger.info("-" * 50) - logger.info("连接 mongodb:") + logger.info("Connect mongodb") uri = config["DB_URI"] - logger.info(" - uri: {}".format(uri)) + logger.debug(" - $DB_URI: {}".format(uri)) return connect(host=uri) diff --git a/app/models/file.py b/app/models/file.py index 02f04ab..31f3c22 100644 --- a/app/models/file.py +++ b/app/models/file.py @@ -1,8 +1,10 @@ -from typing import NoReturn, Union +from io import BufferedReader +from typing import NoReturn, Union, BinaryIO import datetime import math import re import mongoengine +import logging from bson import ObjectId from flask import current_app from flask_babel import gettext @@ -63,10 +65,12 @@ from app.utils import default from app.utils.file import get_file_size from app.utils.hash import get_file_md5 -from app.utils.logging import logger from app.utils.mongo import mongo_order, mongo_slice from app.utils.type import is_number +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + default_translations_order = ["-selected", "-proofread_content", "-edit_time"] @@ -133,7 +137,7 @@ def _get_prefix_and_suffix(self): suffix = prefix_and_suffix[1] return prefix, suffix - def _get_sort_name(self, width): + def _get_sort_name(self, width: int): """返回用于排序的名称,使用前缀排序,将前缀中数字补足一定位数用于排序""" # 将前缀中数字与其他字符拆成列表 # 形如 ['book', '1', '-', '002', '.jpg'] @@ -641,7 +645,9 @@ def download_real_file(self, local_path=None): ) @only_file - def upload_real_file(self, real_file, do_safe_scan=False): + def upload_real_file( + self, real_file: Union[BufferedReader, BinaryIO], do_safe_scan=False + ): """ 上传源文件 """ diff --git a/app/models/language.py b/app/models/language.py index 7af4a79..d5c8b3f 100644 --- a/app/models/language.py +++ b/app/models/language.py @@ -1,10 +1,13 @@ +import logging from flask_babel import gettext from mongoengine import Document, BooleanField, StringField, IntField, QuerySet from app.exceptions.language import LanguageNotExistError -from app.utils.logging import logger from typing import List, Any, Dict +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + class Language(Document): en_name: str = StringField(db_field="e") # 英文名称 @@ -775,9 +778,8 @@ class Language(Document): @classmethod def init_system_languages(cls) -> None: """初始化语言表""" - logger.info("-" * 50) if cls.objects.count() > 0: - logger.info("已存在语言表,跳过初始化") + logger.debug("已存在语言表,跳过初始化") return sort = 0 for lang in cls.SYSTEM_LANGUAGES_DATA: @@ -791,7 +793,7 @@ def init_system_languages(cls) -> None: sort=sort, ).save() sort += 1 - logger.info(f"初始化语言表,共添加{len(cls.SYSTEM_LANGUAGES_DATA)}种语言") + logger.debug(f"初始化语言表,共添加{len(cls.SYSTEM_LANGUAGES_DATA)}种语言") @classmethod def create( diff --git a/app/models/project.py b/app/models/project.py index 95709c0..f57276e 100644 --- a/app/models/project.py +++ b/app/models/project.py @@ -1,5 +1,7 @@ from app.exceptions.project import LabelplusParseFailedError import datetime +import logging +from typing import List, Union, BinaryIO, TYPE_CHECKING from bson import ObjectId from io import BufferedReader @@ -50,6 +52,9 @@ from app.models.language import Language from app.models.target import Target from app.models.term import TermBank + +if TYPE_CHECKING: + from app.models.team import Team from app.models.output import Output from app.tasks.file_parse import find_terms from app.constants.file import ( @@ -63,10 +68,9 @@ ProjectStatus, ) from app.utils.mongo import mongo_order, mongo_slice -from typing import List, TYPE_CHECKING -if TYPE_CHECKING: - from app.models.team import Team +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) class ProjectAllowApplyType(AllowApplyType): @@ -605,6 +609,9 @@ def create_file(self, name: str, parent: File = None) -> File: :param parent: 所属文件夹,顶层则为None :return: """ + logging.debug( + f"Project(id={self.id}).create_file(name={name}, parent={parent})" + ) # 确认父级是文件夹且存在于本项目 if parent is not None: parent = self.get_folder(parent) @@ -651,7 +658,9 @@ def create_file(self, name: str, parent: File = None) -> File: file.inc_cache("file_count", 1, update_self=False) return file - def upload(self, filename: str, real_file: BufferedReader, parent=None): + def upload( + self, filename: str, real_file: Union[BufferedReader, BinaryIO], parent=None + ) -> File: """ 上传文件 @@ -672,13 +681,15 @@ def upload_from_github(self): """从github导入项目""" @classmethod - def by_id(cls, id): - project = cls.objects(id=id).first() + def by_id(cls, id_: str): + project = cls.objects(id=id_).first() if project is None: raise ProjectNotExistError() return project - def get_files(self, name, parent="all", activated=True): + def get_files( + self, name, parent: Union[str, ObjectId, File] = "all", activated=True + ): """通过文件名获取文件或文件夹(大小写不敏感),默认仅获取激活的修订版""" file = File.objects(name__iexact=name, project=self) # 限制文件夹 @@ -691,7 +702,7 @@ def get_files(self, name, parent="all", activated=True): file = file.filter(activated=activated) return file - def get_folder(self, folder): + def get_folder(self, folder: Union[str, ObjectId, File]) -> File: """ 尝试获取本项目下的文件夹,检查文件夹是否存在于本项目 如没有则raise FolderNotExistError diff --git a/app/models/site_setting.py b/app/models/site_setting.py index e74ae15..293dcc9 100644 --- a/app/models/site_setting.py +++ b/app/models/site_setting.py @@ -1,3 +1,4 @@ +import logging from mongoengine import ( Document, ListField, @@ -5,7 +6,9 @@ StringField, ObjectIdField, ) -from app.utils.logging import logger + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) class SiteSetting(Document): @@ -29,11 +32,10 @@ class SiteSetting(Document): @classmethod def init_site_setting(cls): - logger.info("-" * 50) if cls.objects(type="site").count() > 0: - logger.info("已有站点设置,跳过初始化") + logger.debug("已有站点设置,跳过初始化") else: - logger.info("初始化站点设置") + logger.debug("初始化站点设置") cls(type="site").save() @classmethod diff --git a/app/services/oss.py b/app/services/oss.py index d8176dc..b1d3661 100644 --- a/app/services/oss.py +++ b/app/services/oss.py @@ -2,12 +2,14 @@ 对接阿里云OSS储存服务 """ -from io import BufferedReader, BytesIO +from io import BufferedReader, BytesIO, FileIO import os import re import shutil import time import hashlib +import logging +from typing import Union from urllib import parse import oss2 @@ -16,6 +18,8 @@ from app.constants.storage import StorageType +logger = logging.getLogger(__name__) + def md5sum(src): m = hashlib.md5() @@ -77,7 +81,14 @@ def init(self, config): self.oss_domain = config["STORAGE_DOMAIN"] self.STORAGE_PATH = STORAGE_PATH - def upload(self, path, filename, file, headers=None, progress_callback=None): + def upload( + self, + path: str, + filename: str, + file: Union[str, BufferedReader, FileIO], + headers=None, + progress_callback=None, + ): """上传文件""" if self.storage_type == StorageType.OSS: return self.bucket.put_object( @@ -96,9 +107,12 @@ def upload(self, path, filename, file, headers=None, progress_callback=None): with open(os.path.join(folder_path, filename), "w") as saved_file: saved_file.write(file) else: - file.save(os.path.join(folder_path, filename)) + file.save( + os.path.join(folder_path, filename) + ) # XXX: what's the type of file here? + logging.debug("saved file : %s / %s", folder_path, filename) - def download(self, path, filename, /, *, local_path=None): + def download(self, path, filename: str, /, *, local_path=None): """下载文件""" # 如果提供local_path,则下载到本地 if self.storage_type == StorageType.OSS: @@ -141,7 +155,7 @@ def is_exist(self, path, filename, process_name=None): ) ) - def delete(self, path, filename): + def delete(self, path, filename: Union[str, list[str]]): """(批量)删除文件""" if self.storage_type == StorageType.OSS: # 如果给予列表,则批量删除 diff --git a/app/tasks/__init__.py b/app/tasks/__init__.py index 22577de..a3541ae 100644 --- a/app/tasks/__init__.py +++ b/app/tasks/__init__.py @@ -17,6 +17,9 @@ from asgiref.sync import async_to_sync +_FORCE_SYNC_TASK: bool = celery_app.conf["app_config"].get("TESTING", False) + + class SyncResult: """和celery的delay异步返回类似的结果,用于同步、异步切换""" diff --git a/app/tasks/import_from_labelplus.py b/app/tasks/import_from_labelplus.py index 823311f..95d3999 100644 --- a/app/tasks/import_from_labelplus.py +++ b/app/tasks/import_from_labelplus.py @@ -11,7 +11,7 @@ from app import celery from app.models import connect_db -from . import SyncResult +from . import SyncResult, _FORCE_SYNC_TASK from celery.utils.log import get_task_logger from app.utils.labelplus import load_from_labelplus from app.constants.source import SourcePositionType @@ -103,7 +103,7 @@ def import_from_labelplus_task(project_id): def import_from_labelplus(project_id, /, *, run_sync=False) -> SyncResult | AsyncResult: alive_workers = celery.control.ping() - if len(alive_workers) == 0 or run_sync: + if len(alive_workers) == 0 or run_sync or _FORCE_SYNC_TASK: # 同步执行 import_from_labelplus_task(project_id) return SyncResult() diff --git a/app/tasks/output_project.py b/app/tasks/output_project.py index 49ed157..a0f8a28 100644 --- a/app/tasks/output_project.py +++ b/app/tasks/output_project.py @@ -17,7 +17,7 @@ from app import oss from app.models import connect_db from app.regexs import SAFE_FILENAME_REGEX -from . import SyncResult +from . import SyncResult, _FORCE_SYNC_TASK from celery.utils.log import get_task_logger logger = get_task_logger(__name__) @@ -224,7 +224,7 @@ def output_project_task(output_id): def output_project(output_id, /, *, run_sync=False): alive_workers = celery.control.ping() - if len(alive_workers) == 0 or run_sync: + if len(alive_workers) == 0 or run_sync or _FORCE_SYNC_TASK: # 同步执行 output_project_task(output_id) return SyncResult() diff --git a/app/utils/file.py b/app/utils/file.py index c1a11df..b4c793b 100644 --- a/app/utils/file.py +++ b/app/utils/file.py @@ -1,7 +1,7 @@ import os -def get_file_size(file, unit="kb"): +def get_file_size(file, unit="kb") -> int: """获取文件大小,默认返回kb为单位的数值""" file.seek(0, os.SEEK_END) # 移动到文件尾部 size = file.tell() # 获取文件大小,单位是Byte diff --git a/app/utils/logging.py b/app/utils/logging.py index 27f0629..8ea3d0a 100644 --- a/app/utils/logging.py +++ b/app/utils/logging.py @@ -7,8 +7,11 @@ import os import logging from logging.handlers import SMTPHandler +from flask import Flask +from typing import Optional -logger = logging.Logger(__name__) +logger = logging.getLogger(__name__) +root_logger = logging.getLogger("root") class SMTPSSLHandler(SMTPHandler): @@ -47,19 +50,54 @@ def emit(self, record): self.handleError(record) -def configure_logger(app): - """ - 通过app.config自动配置logger +_logger_configured = False - :param app: - :return: - """ - logger.setLevel(logging.DEBUG) - # 各种格式 - stream_formatter = logging.Formatter("[%(asctime)s] (%(levelname)s) %(message)s") + +def configure_root_logger(override: Optional[str] = None): + global _logger_configured + if _logger_configured: + raise AssertionError("configure_root_logger already executed") + _logger_configured = True + logging.debug( + "configuring root logger %s %s", + root_logger.level, + root_logger.getEffectiveLevel(), + ) + level = override or os.environ.get("LOG_LEVEL") + if not level: + return + logging.basicConfig( + format="[%(asctime)s] %(levelname)s %(name)s %(message)s", + datefmt="%Y-%m-%dT%H:%M:%S%z", + force=True, # why the f is this required? + level=getattr(logging, level.upper()), + ) + logging.debug("reset log level %s", level) + + +def configure_extra_logs(app: Flask): + if app.config.get("ENABLE_LOG_EMAIL"): + _enable_email_error_log(app) + if app.config.get("LOG_PATH"): + _enable_file_log(app) + + +def _enable_file_log(app: Flask): file_formatter = logging.Formatter( "[%(asctime)s %(pathname)s:%(lineno)d] (%(levelname)s) %(message)s" ) + log_path = app.config.get("LOG_PATH") + log_folder = os.path.dirname(log_path) + if not os.path.isdir(log_folder): + os.makedirs(log_folder) + file_handler = logging.FileHandler(log_path) + file_handler.setFormatter(file_formatter) + logger.addHandler(file_handler) + + +def _enable_email_error_log(app: Flask): + # === 邮件输出 === + mail_formatter = logging.Formatter( """ Message type: %(levelname)s @@ -73,60 +111,17 @@ def configure_logger(app): %(message)s """ ) - - if app.config["DEBUG"]: - # 控制台输出 - stream_handler = logging.StreamHandler() - # 如果测试只输出ERROR - if app.config["TESTING"]: - stream_handler.setLevel(logging.ERROR) - else: - stream_handler.setLevel(logging.DEBUG) - stream_handler.setFormatter(stream_formatter) # 格式设置 - # 附加到logger - logger.addHandler(stream_handler) - app.logger.addHandler(stream_handler) - else: - # 设置了LOG_PATH则使用,否则使用默认的logs文件夹 - if app.config.get("LOG_PATH"): - log_path = app.config.get("LOG_PATH") - log_folder = os.path.dirname(log_path) - else: - log_folder = "./logs" - log_file = "log.txt" - log_path = os.path.join(log_folder, log_file) - # 不存在记录文件夹自动创建 - if not os.path.isdir(log_folder): - os.makedirs(log_folder) - - # === 控制台输出 === - stream_handler = logging.StreamHandler() - stream_handler.setLevel(logging.INFO) - stream_handler.setFormatter(stream_formatter) - - # === 文件输出 === - file_handler = logging.FileHandler(log_path) - file_handler.setLevel(logging.WARNING) - file_handler.setFormatter(file_formatter) - - # === 邮件输出 === - if app.config["ENABLE_LOG_EMAIL"]: - mail_handler = SMTPSSLHandler( - (app.config["EMAIL_SMTP_HOST"], app.config["EMAIL_SMTP_PORT"]), - app.config["EMAIL_ADDRESS"], - app.config["EMAIL_ERROR_ADDRESS"], - "萌翻站点发生错误", - credentials=( - app.config["EMAIL_ADDRESS"], - app.config["EMAIL_PASSWORD"], - ), - ) - mail_handler.setLevel(logging.ERROR) - mail_handler.setFormatter(mail_formatter) - logger.addHandler(mail_handler) - app.logger.addHandler(mail_handler) - - logger.addHandler(stream_handler) - logger.addHandler(file_handler) - app.logger.addHandler(stream_handler) - app.logger.addHandler(file_handler) + mail_handler = SMTPSSLHandler( + (app.config["EMAIL_SMTP_HOST"], app.config["EMAIL_SMTP_PORT"]), + app.config["EMAIL_ADDRESS"], + app.config["EMAIL_ERROR_ADDRESS"], + "萌翻站点发生错误", + credentials=( + app.config["EMAIL_ADDRESS"], + app.config["EMAIL_PASSWORD"], + ), + ) + mail_handler.setLevel(logging.ERROR) + mail_handler.setFormatter(mail_formatter) + logger.addHandler(mail_handler) + app.logger.addHandler(mail_handler) diff --git a/app/validators/project.py b/app/validators/project.py index f1927d6..5f1a45e 100644 --- a/app/validators/project.py +++ b/app/validators/project.py @@ -46,6 +46,7 @@ def to_model(self, in_data): raise ProjectSetNotExistError in_data["project_set"] = project_set return in_data + return in_data class SearchUserProjectSchema(DefaultSchema): diff --git a/deps-top.txt b/deps-top.txt index 0ef74c9..6d1551d 100644 --- a/deps-top.txt +++ b/deps-top.txt @@ -12,9 +12,13 @@ itsdangerous==2.0.1 # MUST NOT upgrade this, we still use TimedJSO # werkzeug==2.0.2 flask-apikit==0.0.7 gunicorn==20.0.4 # 生产环境服务器 -pytest==6.1.1 # 测试框架 -pytest-cov==2.7.1 # 测试覆盖率 -pytest-xdist==1.29.0 # 并发测试支持 +pytest==8.2.1 # 测试框架 +pytest-cov==5.0.0 # 测试覆盖率 +pytest-xdist==3.6.1 # 并发测试支持 +pytest-dotenv==0.5.2 +pytest-html==4.1.1 +pytest-md==0.2.0 + flask-babel==1.0.0 # i18n mongoengine==0.20.0 # Mongo数据库 mongomock==4.1.2 diff --git a/manage.py b/manage.py index 9d48102..011ed98 100644 --- a/manage.py +++ b/manage.py @@ -3,7 +3,8 @@ import click import logging -from app import create_app, init_db +from app import flask_app +from app.factory import init_db logging.basicConfig( level=logging.INFO, @@ -34,8 +35,7 @@ def migrate(): """ Initialize the database """ - app = create_app() - init_db(app) + init_db(flask_app) @click.command() diff --git a/pytest.ini b/pytest.ini index 2268bb3..abd34de 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,8 +1,9 @@ [pytest] +generate_report_on_test = True ; 覆盖率配置,需要时开启(可配合VSCode的Coverage Gutters使用) -addopts = -s -; addopts = -s --cov=app --cov-report=term --cov-report=xml:cov.xml +addopts = --log-level=DEBUG --capture=no --cov=app --cov-report=term --cov-report=xml:cov.xml --html=report.html --self-contained-html ; 忽略第三方库Warining filterwarnings = ignore:count is deprecated. Use Collection.count_documents instead.:DeprecationWarning - ignore:Using or importing the ABCs .* is deprecated.*:DeprecationWarning \ No newline at end of file + ignore:Using or importing the ABCs .* is deprecated.*:DeprecationWarning +env_files = .env.test diff --git a/requirements.txt b/requirements.txt index 5aeee23..814bdc9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,25 +1,25 @@ +# GENERATED: run make requirements.txt to recreate lock file asgiref==3.7.2 - typing_extensions==4.11.0 -blinker==1.7.0 + typing_extensions==4.12.2 Flask-APIKit==0.0.7 Flask==2.2.5 click==8.1.3 itsdangerous==2.0.1 - Jinja2==3.1.3 + Jinja2==3.1.4 MarkupSafe==2.1.5 - Werkzeug==3.0.2 + Werkzeug==3.0.3 MarkupSafe==2.1.5 marshmallow==3.0.0b20 Flask-Babel==1.0.0 - Babel==2.14.0 + Babel==2.15.0 Flask==2.2.5 click==8.1.3 itsdangerous==2.0.1 - Jinja2==3.1.3 + Jinja2==3.1.4 MarkupSafe==2.1.5 - Werkzeug==3.0.2 + Werkzeug==3.0.3 MarkupSafe==2.1.5 - Jinja2==3.1.3 + Jinja2==3.1.4 MarkupSafe==2.1.5 pytz==2024.1 flower==0.9.5 @@ -32,7 +32,7 @@ flower==0.9.5 click==8.1.3 click-repl==0.3.0 click==8.1.3 - prompt-toolkit==3.0.43 + prompt_toolkit==3.0.46 wcwidth==0.2.13 kombu==5.3.7 amqp==5.2.0 @@ -45,7 +45,7 @@ flower==0.9.5 humanize==4.9.0 prometheus-client==0.8.0 pytz==2024.1 - tornado==6.4 + tornado==6.4.1 google-cloud-storage==1.33.0 google-auth==1.35.0 cachetools==4.2.4 @@ -65,11 +65,11 @@ google-cloud-storage==1.33.0 pyasn1==0.6.0 setuptools==65.5.0 six==1.16.0 - googleapis-common-protos==1.63.0 + googleapis-common-protos==1.63.1 protobuf==4.25.3 protobuf==4.25.3 requests==2.22.0 - certifi==2024.2.2 + certifi==2024.6.2 chardet==3.0.4 idna==2.8 urllib3==1.25.11 @@ -86,7 +86,7 @@ google-cloud-storage==1.33.0 google-crc32c==1.5.0 six==1.16.0 requests==2.22.0 - certifi==2024.2.2 + certifi==2024.6.2 chardet==3.0.4 idna==2.8 urllib3==1.25.11 @@ -101,50 +101,69 @@ oss2==2.7.0 aliyun-python-sdk-core-v3==2.13.3 jmespath==0.10.0 pycryptodome==3.20.0 - aliyun-python-sdk-kms==2.16.2 + aliyun-python-sdk-kms==2.16.3 aliyun-python-sdk-core==2.15.1 - cryptography==42.0.5 + cryptography==42.0.8 cffi==1.16.0 pycparser==2.22 jmespath==0.10.0 crcmod==1.7 pycryptodome==3.20.0 requests==2.22.0 - certifi==2024.2.2 + certifi==2024.6.2 chardet==3.0.4 idna==2.8 urllib3==1.25.11 Pillow==8.0.1 pip==23.0.1 pipdeptree==2.13.2 -pytest-cov==2.7.1 - coverage==7.4.4 - pytest==6.1.1 - attrs==23.2.0 +pytest-cov==5.0.0 + coverage==7.5.3 + pytest==8.2.1 + exceptiongroup==1.2.1 iniconfig==2.0.0 packaging==24.0 - pluggy==0.13.1 - py==1.11.0 - toml==0.10.2 -pytest-xdist==1.29.0 - execnet==2.1.1 - pytest==6.1.1 - attrs==23.2.0 + pluggy==1.5.0 + tomli==2.0.1 +pytest-dotenv==0.5.2 + pytest==8.2.1 + exceptiongroup==1.2.1 iniconfig==2.0.0 packaging==24.0 - pluggy==0.13.1 - py==1.11.0 - toml==0.10.2 - pytest-forked==1.6.0 - py==1.11.0 - pytest==6.1.1 - attrs==23.2.0 + pluggy==1.5.0 + tomli==2.0.1 + python-dotenv==1.0.1 +pytest-html==4.1.1 + Jinja2==3.1.4 + MarkupSafe==2.1.5 + pytest==8.2.1 + exceptiongroup==1.2.1 + iniconfig==2.0.0 + packaging==24.0 + pluggy==1.5.0 + tomli==2.0.1 + pytest-metadata==3.1.1 + pytest==8.2.1 + exceptiongroup==1.2.1 iniconfig==2.0.0 packaging==24.0 - pluggy==0.13.1 - py==1.11.0 - toml==0.10.2 - six==1.16.0 + pluggy==1.5.0 + tomli==2.0.1 +pytest-md==0.2.0 + pytest==8.2.1 + exceptiongroup==1.2.1 + iniconfig==2.0.0 + packaging==24.0 + pluggy==1.5.0 + tomli==2.0.1 +pytest-xdist==3.6.1 + execnet==2.1.1 + pytest==8.2.1 + exceptiongroup==1.2.1 + iniconfig==2.0.0 + packaging==24.0 + pluggy==1.5.0 + tomli==2.0.1 redis==5.0.3 async-timeout==4.0.3 -ruff==0.4.1 +ruff==0.4.8 diff --git a/tests/__init__.py b/tests/__init__.py index a664456..1a9add1 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -7,6 +7,7 @@ from mongoengine import connection from app import create_app, FILE_PATH +from app.factory import init_db from app.models.site_setting import SiteSetting from app.models.user import User from app.models.team import Team @@ -23,14 +24,13 @@ def create_test_app(): # 先创建app连接上数据库 - create_app() + app = create_app() # 不是_test结尾则停止测试,防止数据覆盖 if not connection.get_db().name.endswith("_test"): - raise RuntimeError("Please use *_test database") - # 清库,防止测试交叉影响 + raise AssertionError("Please use *_test database") + # reset the db connection.get_db().client.drop_database(connection.get_db().name) - # 再次创建app,保证角色之类初始化成功 - app = create_app() + init_db(app) return app @@ -148,11 +148,11 @@ def open(self, *args, **kwargs): if kwargs.get("headers") is None: kwargs["headers"] = { # 默认增加跨域参数 - "Origin": "https://example.com" + "Origin": "https://example.com", } # 如果给予data参数,则转换为json,并设置json内容头 json_data = kwargs.pop("json", None) - if json_data: + if isinstance(json_data, dict): kwargs["data"] = json.dumps(json_data) kwargs["headers"]["Content-Type"] = "application/json" token = kwargs.pop("token", None) diff --git a/tests/api/test_file_api.py b/tests/api/test_file_api.py index b1660f4..deb706a 100644 --- a/tests/api/test_file_api.py +++ b/tests/api/test_file_api.py @@ -242,7 +242,7 @@ def test_edit_file_name(self): ) self.assertErrorEqual(data, NoPermissionError) # == 缺少filename == - data = self.put("/v1/files/{}".format(file1.id), token=token) + data = self.put("/v1/files/{}".format(file1.id), json={}, token=token) self.assertErrorEqual(data, ValidateError) # == 错误的名字 == data = self.put( @@ -300,7 +300,7 @@ def test_move_file(self): ) self.assertErrorEqual(data, NoPermissionError) # == 缺少parent_id == - data = self.put("/v1/files/{}".format(file1.id), token=token) + data = self.put("/v1/files/{}".format(file1.id), json={}, token=token) self.assertErrorEqual(data, ValidateError) # == 缺少parent_id,null等同于缺少 == data = self.put( diff --git a/tests/api/test_project_api.py b/tests/api/test_project_api.py index 8c9f995..83a6ddf 100644 --- a/tests/api/test_project_api.py +++ b/tests/api/test_project_api.py @@ -111,13 +111,13 @@ def test_create_project(self): set1 = team1.default_project_set set2 = team2.default_project_set # 未登录,没有权限创建 - data = self.post(f"/v1/teams/{str(team1.id)}/projects") + data = self.post(f"/v1/teams/{str(team1.id)}/projects", json={}) self.assertErrorEqual(data, NeedTokenError) # user2没有权限创建 - data = self.post(f"/v1/teams/{str(team1.id)}/projects", token=token2) + data = self.post(f"/v1/teams/{str(team1.id)}/projects", json={}, token=token2) self.assertErrorEqual(data, NoPermissionError) # 没有参数不能创建 - data = self.post(f"/v1/teams/{str(team1.id)}/projects", token=token1) + data = self.post(f"/v1/teams/{str(team1.id)}/projects", json={}, token=token1) self.assertErrorEqual(data, ValidateError) # 缺少参数不能创建 data = self.post( @@ -377,7 +377,7 @@ def test_edit_project(self): data = self.put(f"/v1/projects/{str(project1.id)}", token=token2) self.assertErrorEqual(data, NoPermissionError) # 没有参数不能修改 - data = self.put(f"/v1/projects/{str(project1.id)}", token=token1) + data = self.put(f"/v1/projects/{str(project1.id)}", json={}, token=token1) self.assertErrorEqual(data, RequestDataEmptyError) # 创建一个自定义角色 role1 = project1.create_role( @@ -1639,7 +1639,9 @@ def test_create_target(self): self.assertErrorEqual(data, NoPermissionError) self.assertEqual(project.targets().count(), 1) # user1,缺少语言 - data = self.post(f"/v1/projects/{str(project.id)}/targets", token=token1) + data = self.post( + f"/v1/projects/{str(project.id)}/targets", json={}, token=token1 + ) self.assertErrorEqual(data, ValidateError) self.assertEqual(project.targets().count(), 1) # user1,语言重复 diff --git a/tests/api/test_source_api.py b/tests/api/test_source_api.py index fb9531e..2692d32 100644 --- a/tests/api/test_source_api.py +++ b/tests/api/test_source_api.py @@ -97,16 +97,22 @@ def test_create_image_sources(self): with self.app.test_request_context(): # === 错误测试 === # 没登录不能获取 - data = self.post("/v1/files/{}/sources".format(image_file.id)) + data = self.post("/v1/files/{}/sources".format(image_file.id), json={}) self.assertErrorEqual(data, NeedTokenError) # 其他用户不能登录 - data = self.post("/v1/files/{}/sources".format(image_file.id), token=token2) + data = self.post( + "/v1/files/{}/sources".format(image_file.id), json={}, token=token2 + ) self.assertErrorEqual(data, NoPermissionError) # 文件类型不能创建原文 - data = self.post("/v1/files/{}/sources".format(text_file.id), token=token) + data = self.post( + "/v1/files/{}/sources".format(text_file.id), json={}, token=token + ) self.assertErrorEqual(data, FileTypeNotSupportError) # === 什么都不携带为空source === - data = self.post("/v1/files/{}/sources".format(image_file.id), token=token) + data = self.post( + "/v1/files/{}/sources".format(image_file.id), json={}, token=token + ) self.assertErrorEqual(data) source = Source.by_id(data.json["id"]) self.assertEqual("", source.content) @@ -184,16 +190,18 @@ def test_edit_image_sources(self): with self.app.test_request_context(): # === 错误测试 === # 没登录不能获取 - data = self.put("/v1/sources/{}".format(source.id)) + data = self.put("/v1/sources/{}".format(source.id), json={}) self.assertErrorEqual(data, NeedTokenError) # 其他用户不能登录 - data = self.put("/v1/sources/{}".format(source.id), token=token2) + data = self.put("/v1/sources/{}".format(source.id), json={}, token=token2) self.assertErrorEqual(data, NoPermissionError) # 文件类型不能创建原文 - data = self.put("/v1/sources/{}".format(text_source.id), token=token) + data = self.put( + "/v1/sources/{}".format(text_source.id), json={}, token=token + ) self.assertErrorEqual(data, FileTypeNotSupportError) # 空json报错 - data = self.put("/v1/sources/{}".format(source.id), token=token) + data = self.put("/v1/sources/{}".format(source.id), json={}, token=token) self.assertErrorEqual(data, ValidateError) # 类型不符报错 data = self.put( @@ -338,7 +346,9 @@ def test_edit_image_source_rank(self): ) self.assertErrorEqual(data, FileTypeNotSupportError) # 空json报错 - data = self.put("/v1/sources/{}/rank".format(source1.id), token=token) + data = self.put( + "/v1/sources/{}/rank".format(source1.id), json={}, token=token + ) self.assertErrorEqual(data, ValidateError) # 类型不符报错 data = self.put( diff --git a/tests/api/test_team_api.py b/tests/api/test_team_api.py index 25b01f9..f887f3d 100644 --- a/tests/api/test_team_api.py +++ b/tests/api/test_team_api.py @@ -164,7 +164,7 @@ def test_create_team(self): data = self.post("/v1/teams", json={"name": "t1"}) self.assertErrorEqual(data, NeedTokenError) # == 缺少名称字段时,不能创建 == - data = self.post("/v1/teams", token=token) + data = self.post("/v1/teams", json={}, token=token) self.assertErrorEqual(data, ValidateError) self.assertIsNotNone(data.json["message"].get("name")) # == 名称长度错误时,不能创建 == @@ -342,7 +342,7 @@ def test_edit_team(self): self.assertErrorEqual(data, ValidateError) self.assertIsNotNone(data.json["message"].get("name")) # == 空json报错 == - data = self.put(f"/v1/teams/{str(team1.id)}", token=token) + data = self.put(f"/v1/teams/{str(team1.id)}", json={}, token=token) self.assertErrorEqual(data, RequestDataEmptyError) data = self.put(f"/v1/teams/{str(team1.id)}", json={}, token=token) self.assertErrorEqual(data, RequestDataEmptyError) diff --git a/tests/api/test_translation_api.py b/tests/api/test_translation_api.py index 0051fe0..186787d 100644 --- a/tests/api/test_translation_api.py +++ b/tests/api/test_translation_api.py @@ -84,7 +84,7 @@ def test_create_translation(self): self.assertErrorEqual(data, NoPermissionError) # 空参数 data = self.post( - "/v1/sources/{}/translations".format(source.id), token=token + "/v1/sources/{}/translations".format(source.id), json={}, token=token ) self.assertErrorEqual(data, ValidateError) # === 完整的参数 === @@ -173,7 +173,9 @@ def test_edit_translation(self): ) self.assertErrorEqual(data, NoPermissionError) # 空json报错 - data = self.put("/v1/translations/{}".format(translation.id), token=token) + data = self.put( + "/v1/translations/{}".format(translation.id), json={}, token=token + ) self.assertErrorEqual(data, ValidateError) # 翻译的初始化状态 translation.reload() diff --git a/tests/base/test_default_admin.py b/tests/base/test_default_admin.py index bf0215b..e53a496 100644 --- a/tests/base/test_default_admin.py +++ b/tests/base/test_default_admin.py @@ -1,4 +1,5 @@ from app import create_app +from app.factory import init_db from app.models.user import User from tests import MoeTestCase @@ -20,7 +21,7 @@ def test_reset_default_admin(self): admin_user.save() admin_user.reload() self.assertEqual(admin_user.admin, False) - create_app() + init_db(create_app()) admin_user.reload() self.assertEqual(admin_user.admin, True) # 测试其他用户权限不受影响 @@ -34,7 +35,7 @@ def test_reset_default_admin_when_true(self): self.assertEqual(admin_user.email, self.app.config["ADMIN_EMAIL"]) self.assertEqual(admin_user.admin, True) self.assertEqual(user.admin, False) - create_app() + init_db(create_app()) admin_user.reload() self.assertEqual(admin_user.admin, True) # 测试其他用户权限不受影响 diff --git a/tests/base/test_logging.py b/tests/base/test_logging.py new file mode 100644 index 0000000..7b5ea7d --- /dev/null +++ b/tests/base/test_logging.py @@ -0,0 +1,29 @@ +from app import flask_app, create_app +import app.utils.logging as app_logging +import logging + +# app_logging.configure_root_logger() +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + + +def test_logging(): + assert 1 is 1 + logger.log(0, "notset?") + logger.debug("debug") + logger.info("info") + logger.warning("warn") + logger.error("local logger %d / %d", logger.level, logger.getEffectiveLevel()) + app_logging.logger.debug("debug to global logger") + app_logging.logger.info("info to global logger") + app_logging.logger.warning("warn to global logger") + app_logging.logger.error( + "global logger %d / %d", + app_logging.logger.level, + app_logging.logger.getEffectiveLevel(), + ) + root_logger = logging.getLogger("root") + logging.error( + "root logger %d / %d", root_logger.level, root_logger.getEffectiveLevel() + ) + create_app() diff --git a/tests/deps.yaml b/tests/deps.yaml new file mode 100644 index 0000000..5585527 --- /dev/null +++ b/tests/deps.yaml @@ -0,0 +1,39 @@ +# genearted from moeflow-deploy repo like +# docker-compose -f docker-compose.yml -f docker-compose.dev.yml config moeflow-mongodb moeflow-rabbitmq +version: '3.3' +# name: moeflow-backend-test-deps +services: + moeflow-mongodb: + environment: + MONGO_INITDB_ROOT_PASSWORD: CHANGE_ME + MONGO_INITDB_ROOT_USERNAME: moeflow + healthcheck: + test: + - CMD + - mongo + - --eval + - db.adminCommand('ping') + timeout: 5s + interval: 15s + start_period: 10s + image: docker.io/mongo:4.4.1 + ports: + - 127.0.0.1:27017:27017 + restart: unless-stopped + moeflow-rabbitmq: + environment: + RABBITMQ_DEFAULT_PASS: CHANGE_ME + RABBITMQ_DEFAULT_USER: moeflow + RABBITMQ_DEFAULT_VHOST: moeflow + healthcheck: + test: + - CMD-SHELL + - rabbitmq-diagnostics -q ping + timeout: 5s + interval: 5s + start_period: 10s + image: docker.io/rabbitmq:3.8.9-management + ports: + - 127.0.0.1:5672:5672 # AMQP + - 127.0.0.1:15672:15672 # management UI + restart: unless-stopped diff --git a/tests/model/test_file_storage_model.py b/tests/model/test_file_storage_model.py index b9bf3bc..848d9f4 100644 --- a/tests/model/test_file_storage_model.py +++ b/tests/model/test_file_storage_model.py @@ -1,8 +1,8 @@ import os -import requests from bson import ObjectId +# FIXME: testee should be parametrized instance, not singleton from app import TMP_PATH, oss from app.constants.storage import StorageType from tests import MoeTestCase @@ -76,16 +76,10 @@ def test_sign_url(self): self.assertFalse(oss.is_exist(self.path, filename1)) oss.upload(self.path, filename1, filename1) self.assertTrue(oss.is_exist(self.path, filename1)) - # 直接用链接访问,报错 - if self.app.config["STORAGE_TYPE"] == StorageType.OSS: - response = requests.get( - self.app.config["STORAGE_DOMAIN"] + self.path + filename1 - ) - self.assertEqual(403, response.status_code) - # 签名后的url可以访问 + # it creates an url for user agent url = oss.sign_url(self.path, filename1) - response = requests.get(url) - self.assertEqual(filename1, response.text) + # FIXME this only works in CI test + self.assertEqual(url, f"http://127.0.0.1:5000/storage/test/{filename1}") # 清理,删除这个文件 oss.delete(self.path, [filename1]) self.assertFalse(oss.is_exist(self.path, filename1))