diff --git a/chiya/bot.py b/chiya/bot.py index 6ae25c85..55934d77 100644 --- a/chiya/bot.py +++ b/chiya/bot.py @@ -23,9 +23,7 @@ @bot.event async def on_ready() -> None: - """ - Called when the client is done preparing the data received from Discord. - """ + "Called when the client is done preparing the data received from Discord." log.info(f"Logged in as: {str(bot.user)}") await bot.tree.sync(guild=discord.Object(config.guild_id)) @@ -43,9 +41,7 @@ async def setup_logger(): log.remove() class InterceptHandler(logging.Handler): - """ - Setup up an Interceptor class to redirect all logs from the standard logging library to loguru. - """ + "Setup up an Interceptor class to redirect all logs from the standard logging library to loguru." def emit(self, record: logging.LogRecord) -> None: # Get corresponding Loguru level if it exists. diff --git a/chiya/cogs/commands/reminder.py b/chiya/cogs/commands/reminder.py index fea5116d..debb1e16 100644 --- a/chiya/cogs/commands/reminder.py +++ b/chiya/cogs/commands/reminder.py @@ -1,5 +1,3 @@ -import asyncio - import discord from discord.ext import commands from discord import app_commands diff --git a/chiya/config.py b/chiya/config.py index d18f4ae1..efa3ff35 100644 --- a/chiya/config.py +++ b/chiya/config.py @@ -75,8 +75,7 @@ class ChiyaConfig(ParentModel): workspace = Path(__file__).parent.parent config_file = workspace / "config.toml" - if not config_file.is_file(): - raise FileNotFoundError("Unable to load config.yml, exiting...") + raise FileNotFoundError("Unable to load config.toml, exiting...") config = ChiyaConfig.model_validate(tomllib.load(config_file.open("rb"))) diff --git a/chiya/database.py b/chiya/database.py index d554a8ee..fde2b9ad 100644 --- a/chiya/database.py +++ b/chiya/database.py @@ -1,108 +1,82 @@ -import dataset -from loguru import logger as log -from sqlalchemy import create_engine -from sqlalchemy_utils import database_exists, create_database +from sqlalchemy import create_engine, BigInteger, Boolean, Column, Integer, Text +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker from chiya.config import config -class Database: - def __init__(self) -> None: - host = config.database.host - database = config.database.database - user = config.database.user - password = config.database.password - - if not all([host, database, user, password]): - log.error("One or more database connection variables are missing, exiting...") - raise SystemExit - - # self.url = f"mysql://{user}:{password}@{host}/{database}?charset=utf8mb4" - self.url = config.database.url - - def get(self) -> dataset.Database: - """Returns the dataset database object.""" - return dataset.connect(url=self.url) - - def setup(self) -> None: - """Sets up the tables needed for Chiya.""" - engine = create_engine(self.url) - if not database_exists(engine.url): - create_database(engine.url) - - db = self.get() - - if "mod_logs" not in db: - mod_logs = db.create_table("mod_logs") - mod_logs.create_column("user_id", db.types.bigint) - mod_logs.create_column("mod_id", db.types.bigint) - mod_logs.create_column("timestamp", db.types.bigint) - mod_logs.create_column("reason", db.types.text) - mod_logs.create_column("duration", db.types.text) - mod_logs.create_column("type", db.types.text) - log.info("Created missing table: mod_logs") - - if "remind_me" not in db: - remind_me = db.create_table("remind_me") - remind_me.create_column("reminder_location", db.types.bigint) - remind_me.create_column("author_id", db.types.bigint) - remind_me.create_column("date_to_remind", db.types.bigint) - remind_me.create_column("message", db.types.text) - remind_me.create_column("sent", db.types.boolean, default=False) - log.info("Created missing table: remind_me") - - if "timed_mod_actions" not in db: - timed_mod_actions = db.create_table("timed_mod_actions") - timed_mod_actions.create_column("user_id", db.types.bigint) - timed_mod_actions.create_column("mod_id", db.types.bigint) - timed_mod_actions.create_column("action_type", db.types.text) - timed_mod_actions.create_column("start_time", db.types.bigint) - timed_mod_actions.create_column("end_time", db.types.bigint) - timed_mod_actions.create_column("is_done", db.types.boolean, default=False) - timed_mod_actions.create_column("reason", db.types.text) - log.info("Created missing table: timed_mod_actions") - - if "tickets" not in db: - tickets = db.create_table("tickets") - tickets.create_column("user_id", db.types.bigint) - tickets.create_column("guild", db.types.bigint) - tickets.create_column("timestamp", db.types.bigint) - tickets.create_column("ticket_subject", db.types.text) - tickets.create_column("ticket_message", db.types.text) - tickets.create_column("log_url", db.types.text) - tickets.create_column("status", db.types.boolean) - log.info("Created missing table: tickets") - - if "starboard" not in db: - starboard = db.create_table("starboard") - starboard.create_column("channel_id", db.types.bigint) - starboard.create_column("message_id", db.types.bigint) - starboard.create_column("star_embed_id", db.types.bigint) - log.info("Created missing table: starboard") - - if "joyboard" not in db: - joyboard = db.create_table("joyboard") - joyboard.create_column("channel_id", db.types.bigint) - joyboard.create_column("message_id", db.types.bigint) - joyboard.create_column("joy_embed_id", db.types.bigint) - log.info("Created missing table: joyboard") - - if "highlights" not in db: - highlights = db.create_table("highlights") - highlights.create_column("term", db.types.text) - highlights.create_column("users", db.types.text) - log.info("Created missing table: highlights") - - # utf8mb4_unicode_ci is required to support emojis and other unicode. - # dataset does not expose collation in any capacity so rather than - # checking an object property, we have to do this hacky way of checking - # the charset via queries and updating it where necessary. - # for table in db.tables: - # charset = next(db.query(f"SHOW TABLE STATUS WHERE NAME = '{table}';"))["Collation"] - # if charset == "utf8mb4_unicode_ci": - # continue - # db.query(f"ALTER TABLE {table} CONVERT TO CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;") - # log.info(f"Converted table to utf8mb4_unicode_ci: {table}") - - db.commit() - db.close() +Base = declarative_base() +engine = create_engine(config.database.url, connect_args={"check_same_thread": False}) +session = sessionmaker(autocommit=False, autoflush=False, bind=engine) + + +class BaseModel(Base): + __abstract__ = True + + def save(self): + session.add(self) + session.commit() + return self + + def delete(self): + session.delete(self) + session.commit() + return self + + def flush(self): + session.add(self) + session.flush() + return self + + +class ModLog(Base): + __tablename__ = "mod_logs" + + id = Column(Integer, primar_key=True) + user_id = Column(BigInteger, nullable=False) + mod_id = Column(BigInteger, nullable=False) + timestamp = Column(BigInteger, nullable=False) + reason = Column(Text, nullable=False) + duration = Column(Text, nullable=False) + type = Column(Text, nullable=False) + + +class RemindMe(Base): + __tablename__ = "remind_me" + + id = Column(Integer, primar_key=True) + reminder_location = Column(BigInteger, nullable=False) + author_id = Column(BigInteger, nullable=False) + date_to_remind = Column(BigInteger, nullable=False) + message = Column(Text, nullable=False) + sent = Column(Boolean, nullable=False, default=False) + + +class Ticket(Base): + __tablename__ = "tickets" + + id = Column(Integer, primary_key=True) + user_id = Column(BigInteger, nullable=False) + guild = Column(BigInteger, nullable=False) + timestamp = Column(BigInteger, nullable=False) + ticket_subject = Column(Text, nullable=False) + ticket_message = Column(Text, nullable=False) + log_url = Column(Text, nullable=False) + status = Column(Boolean) + + +class Joyboard(Base): + __tablename__ = "joyboard" + + id = Column(Integer, primary_key=True) + channel_id = Column(BigInteger, nullable=False) + message_id = Column(BigInteger, nullable=False) + joy_embed_id = Column(BigInteger, nullable=False) + + +class Highlight(Base): + __tablename__ = "highlights" + + id = Column(Integer, primary_key=True) + term = Column(Text, nullable=False) + users = Column(Text, nullable=False) diff --git a/pyproject.toml b/pyproject.toml index 142df41a..10479616 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,6 +5,8 @@ description = "A moderation-heavy general purpose Discord bot" readme = "README.md" requires-python = ">=3.11" dependencies = [ + "alembic>=1.14.0", + "arrow>=1.3.0", "dataset==1.6.2", "discord-ext-menus>=1.1", "discord-py[speed]==2.4.0", @@ -14,6 +16,7 @@ dependencies = [ "privatebinapi==1.0.0", "pydantic>=2.10.4", "requests>=2.32.3", + "sqlalchemy>=1.4.54", "sqlalchemy-utils>=0.41.2", ] @@ -24,4 +27,4 @@ dev = [ ] [tool.ruff] -line-length = 120 \ No newline at end of file +line-length = 120 diff --git a/uv.lock b/uv.lock index 4cb1ca66..6f182be4 100644 --- a/uv.lock +++ b/uv.lock @@ -133,6 +133,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a0/7a/4daaf3b6c08ad7ceffea4634ec206faeff697526421c20f07628c7372156/anyio-4.7.0-py3-none-any.whl", hash = "sha256:ea60c3723ab42ba6fff7e8ccb0488c898ec538ff4df1f1d5e642c3601d07e352", size = 93052 }, ] +[[package]] +name = "arrow" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "python-dateutil" }, + { name = "types-python-dateutil" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/2e/00/0f6e8fcdb23ea632c866620cc872729ff43ed91d284c866b515c6342b173/arrow-1.3.0.tar.gz", hash = "sha256:d4540617648cb5f895730f1ad8c82a65f2dad0166f57b75f3ca54759c4d67a85", size = 131960 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/ed/e97229a566617f2ae958a6b13e7cc0f585470eac730a73e9e82c32a3cdd2/arrow-1.3.0-py3-none-any.whl", hash = "sha256:c728b120ebc00eb84e01882a6f5e7927a53960aa990ce7dd2b10f39005a67f80", size = 66419 }, +] + [[package]] name = "attrs" version = "24.3.0" @@ -327,6 +340,8 @@ name = "chiya" version = "0.1.0" source = { virtual = "." } dependencies = [ + { name = "alembic" }, + { name = "arrow" }, { name = "dataset" }, { name = "discord-ext-menus" }, { name = "discord-py", extra = ["speed"] }, @@ -336,6 +351,7 @@ dependencies = [ { name = "privatebinapi" }, { name = "pydantic" }, { name = "requests" }, + { name = "sqlalchemy" }, { name = "sqlalchemy-utils" }, ] @@ -347,6 +363,8 @@ dev = [ [package.metadata] requires-dist = [ + { name = "alembic", specifier = ">=1.14.0" }, + { name = "arrow", specifier = ">=1.3.0" }, { name = "dataset", specifier = "==1.6.2" }, { name = "discord-ext-menus", specifier = ">=1.1" }, { name = "discord-py", extras = ["speed"], specifier = "==2.4.0" }, @@ -356,6 +374,7 @@ requires-dist = [ { name = "privatebinapi", specifier = "==1.0.0" }, { name = "pydantic", specifier = ">=2.10.4" }, { name = "requests", specifier = ">=2.32.3" }, + { name = "sqlalchemy", specifier = ">=1.4.54" }, { name = "sqlalchemy-utils", specifier = ">=0.41.2" }, ] @@ -1023,6 +1042,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/51/b2/b2b50d5ecf21acf870190ae5d093602d95f66c9c31f9d5de6062eb329ad1/pydantic_core-2.27.2-cp313-cp313-win_arm64.whl", hash = "sha256:ac4dbfd1691affb8f48c2c13241a2e3b60ff23247cbcf981759c768b6633cf8b", size = 1885186 }, ] +[[package]] +name = "python-dateutil" +version = "2.9.0.post0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "six" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/66/c0/0c8b6ad9f17a802ee498c46e004a0eb49bc148f2fd230864601a86dcf6db/python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3", size = 342432 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892 }, +] + [[package]] name = "requests" version = "2.32.3" @@ -1063,6 +1094,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/13/9f/026e18ca7d7766783d779dae5e9c656746c6ede36ef73c6d934aaf4a6dec/ruff-0.8.4-py3-none-win_arm64.whl", hash = "sha256:9183dd615d8df50defa8b1d9a074053891ba39025cf5ae88e8bcb52edcc4bf08", size = 9074500 }, ] +[[package]] +name = "six" +version = "1.17.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/94/e7/b2c673351809dca68a0e064b6af791aa332cf192da575fd474ed7d6f16a2/six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81", size = 34031 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050 }, +] + [[package]] name = "sjcl" version = "0.2.1" @@ -1114,6 +1154,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d5/f0/dc4757b83ac1ab853cf222df8535ed73973e0c203d983982ba7b8bc60508/SQLAlchemy_Utils-0.41.2-py3-none-any.whl", hash = "sha256:85cf3842da2bf060760f955f8467b87983fb2e30f1764fd0e24a48307dc8ec6e", size = 93083 }, ] +[[package]] +name = "types-python-dateutil" +version = "2.9.0.20241206" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a9/60/47d92293d9bc521cd2301e423a358abfac0ad409b3a1606d8fbae1321961/types_python_dateutil-2.9.0.20241206.tar.gz", hash = "sha256:18f493414c26ffba692a72369fea7a154c502646301ebfe3d56a04b3767284cb", size = 13802 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0f/b3/ca41df24db5eb99b00d97f89d7674a90cb6b3134c52fb8121b6d8d30f15c/types_python_dateutil-2.9.0.20241206-py3-none-any.whl", hash = "sha256:e248a4bc70a486d3e3ec84d0dc30eec3a5f979d6e7ee4123ae043eedbb987f53", size = 14384 }, +] + [[package]] name = "typing-extensions" version = "4.12.2"