Skip to content

Commit

Permalink
Add assistant cog
Browse files Browse the repository at this point in the history
  • Loading branch information
jotonedev committed Mar 20, 2024
1 parent e369cd0 commit dae5ab5
Show file tree
Hide file tree
Showing 8 changed files with 817 additions and 611 deletions.
11 changes: 5 additions & 6 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,6 @@ RUN pip wheel --no-cache-dir --no-deps --wheel-dir /build/wheels -r /build/wheel

FROM python:3.12-slim-bookworm

# Accept secrets as arguments
ARG TOKEN="discord_token"
ARG GUILD_ID="0"
ENV TOKEN $TOKEN
ENV GUILD_ID $GUILD_ID

ENV PYTHONFAULTHANDLER=1 \
PYTHONUNBUFFERED=1 \
PYTHONHASHSEED=random \
Expand All @@ -41,6 +35,11 @@ ENV PYTHONFAULTHANDLER=1 \
PIP_DISABLE_PIP_VERSION_CHECK=1 \
PIP_DEFAULT_TIMEOUT=100

ENV DS_TOKEN "YOUR_DISCORD_TOKEN" # discord token from the developer portal
ENV DS_GUILD_ID "YOUR_GUILD_ID" # the guild id where the bot will be used
ENV CF_CLIENT_ID "YOUR_CLIENT_ID" # cloudflare client id
ENV CF_TOKEN "YOUR_CLOUDFLARE_TOKEN" # cloudflare token

# Change workdir
WORKDIR /bot

Expand Down
15 changes: 10 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,28 @@

## Description

This is a very simple discord bot using slash commands to play music using lavalink on your guild channels.
This is a very simple discord bot using slash commands to play music using lavalink on your guild channels.

## Setup

### Docker

The preferred method to run this is with a docker container. To launch it, run the following command:

```bash
docker run -d -e TOKEN="DISCORD_TOKEN" -e GUILD_ID="YOUR_GUILD_ID" -v $(pwd)/lavalink.json:/bot/config/lavalink.json ghcr.io/jotonedev/dsmusic:v0.3.2
```
The file lavalink.json must be created using the [template](config/lavalink.example.json) in the repository.
If you haven't already set up a lavalink node, you can check the lavalink repository [here](https://github.com/lavalink-devs/Lavalink) on how to set up one.
After that you need to add its ip address, port and password in the lavalink.json file. You can add how many nodes you want, but only one is required.

The file lavalink.json must be created using the [template](config/lavalink.example.json) in the repository.
If you haven't already set up a lavalink node, you can check the lavalink
repository [here](https://github.com/lavalink-devs/Lavalink) on how to set up one.
After that you need to add its ip address, port and password in the lavalink.json file. You can add how many nodes you
want, but only one is required.

### Console

You can also launch the bot manually using the following commands (just remember to edit the lavalink.json appropriately):
You can also launch the bot manually using the following commands (just remember to edit the lavalink.json
appropriately):

```bash
# Add environment variables
Expand Down
37 changes: 31 additions & 6 deletions dsmusic/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,48 @@
else:
uvloop.install()

intents = discord.Intents(387)
intents = discord.Intents(
guilds=True,
members=True,
messages=True,
voice_states=True,
presences=True,
message_content=True
)

permissions = discord.Permissions(
send_messages=True,
read_messages=True,

connect=True,
speak=True,
use_voice_activation=True,
use_soundboard=True,

manage_threads=True,
send_messages_in_threads=True,

attach_files=True,
embed_links=True,
)

client = Client(
intents=intents,
command_prefix="!",
max_messages=None,
assume_unsync_clock=False,
activity=discord.CustomActivity(name="Get Joshed."),
activity=discord.CustomActivity(name="Gressinbon"),
status=discord.Status.online,
mentions=discord.AllowedMentions.none(),
help_command=None
)

oauth_url = discord.utils.oauth_url(839827510761488404, guild=discord.Object(os.getenv("GUILD_ID")))
oauth_url = discord.utils.oauth_url(
client_id=839827510761488404,
guild=discord.Object(os.getenv("DS_GUILD_ID")),
permissions=permissions
)
print(f"Bot URL: {oauth_url}")

token = os.getenv("TOKEN")
token = os.getenv("DS_TOKEN")

if token:
client.run(token=token)
Expand Down
Empty file added dsmusic/assistant/__init__.py
Empty file.
173 changes: 173 additions & 0 deletions dsmusic/assistant/cog.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
from dataclasses import dataclass, field, asdict
from typing import Any, Literal

import aiohttp
import discord
from discord import app_commands
from discord.ext import commands
from yarl import URL


@dataclass
class UnscopedPrompt:
prompt: str
raw: bool = field(default=False, init=False)
stream: bool = field(default=False, init=False)
max_tokens: int = field(default=256, init=False)


@dataclass
class Message:
content: str
role: Literal["user", "system", "assistant"] = "assistant"


@dataclass
class Response:
response: str


@app_commands.guild_only()
class Assistant(commands.Cog):
def __init__(
self,
bot: discord.Client,
cf_account_id: str,
cf_api_token: str,
session: aiohttp.ClientSession | None = None,
*,
rest_url: str | URL | None = None,
):
self.bot = bot

if session is None:
from orjson import dumps
self.session = aiohttp.ClientSession(
json_serialize=lambda obj: dumps(obj).decode("utf-8", errors="ignore")
)
else:
self.session = session

self._cf_account_id = cf_account_id
self._cf_api_token = cf_api_token

if rest_url is None:
self.rest_url = URL(f"https://api.cloudflare.com/client/v4/accounts/{cf_account_id}/ai/run/"),
else:
if isinstance(rest_url, URL):
self.rest_url = rest_url
else:
self.rest_url = URL(rest_url)

self.headers = {
"Authorization": f"Bearer {cf_api_token}",
"Content-Type": "application/json",
"DNT": "1",
"Accept-Encoding": "gzip, deflate, br",
"User-Agent": "ds-bot"
}
self.session.headers.update(self.headers)

async def _request(
self,
url: URL,
method: str,
headers: dict[str, Any] | None = None,
payload: dict[str, Any] | None = None,
query: dict[str, Any] | None = None
) -> dict[str, Any]:
"""
Make a request using aiohttp
:param url: the url to make the request to
:param method: method to use
:param headers: headers to use
:param payload: payload to send (used for POST requests)
:param query: query to append to the url
:return: the response as a dict (json
"""
async with self.session.request(
method,
url,
headers=headers,
json=payload,
params=query
) as resp:
return await resp.json()

async def unscoped_prompt(
self,
prompt: str,
raw: bool = False,
max_tokens: int = 256,
*, model: str = "@cf/openchat/openchat-3.5-0106"
) -> Response:
"""
Send an unscoped prompt to the llm
:param prompt: the prompt to send
:param raw: whether the prompt uses raw parameters in the prompt
:param max_tokens: the maximum number of tokens to generate
:param model: the model to use (list of models: https://developers.cloudflare.com/workers-ai/models/)
:return: the response from the llm
"""
data = {
"prompt": prompt,
"raw": raw,
"stream": False,
"max_tokens": max_tokens
}

url = self.rest_url / model
response = await self._request(url, "POST", payload=data)
return Response(**response)

async def scoped_prompt(
self,
messages: list[Message],
max_tokens: int = 256,
*, model: str = "@cf/openchat/openchat-3.5-0106"
):
"""
Send a scoped prompt to the llm.
This allows to have a conversation with the llm using the previous messages as knowledge.
:param messages: a list of messages to send
:param max_tokens: the maximum number of tokens to generate
:param model: the model to use (list of models: https://developers.cloudflare.com/workers-ai/models/)
:return: the response from the llm
"""
if messages[0].role != "system":
messages.insert(0, Message(
content="You are an assistant that can speak any language the user requires. You respond to any "
"question in a way that is helpful and correct.",
role="system"
))

payload = {
"messages": [asdict(message) for message in messages],
"stream": False,
"max_tokens": max_tokens
}

url = self.rest_url / model
response = await self._request(url, "POST", payload=payload)
return Response(**response)

@app_commands.command(name="ask", description="Ask a question to the assistant")
@app_commands.checks.cooldown(3, 10, key=lambda i: (i.guild_id, i.user.id))
@app_commands.describe(query="The question you want to ask", model="The model to use (use openchat by default)")
async def unscoped_prompt_command(self, interaction: discord.Interaction, prompt: str, model: str | None = None):
"""Send a prompt to the assistant"""
# noinspection PyTypeChecker
resp: discord.InteractionResponse = interaction.response

await resp.defer(thinking=True)
llm_response = await self.unscoped_prompt(prompt, model=model)

return await interaction.followup.send(llm_response.response)


async def setup(bot: commands.Bot) -> None:
from os import getenv
client_id = getenv("CF_ACCOUNT_ID")
api_token = getenv("CF_API_TOKEN")

await bot.add_cog(Assistant(bot, client_id, api_token, ))
12 changes: 6 additions & 6 deletions dsmusic/client.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
import asyncio
import json
import logging
import os
from os import getenv

import discord
import mafic
from mafic import NodeAlreadyConnected
from discord import app_commands
from discord.ext import commands
from mafic import NodeAlreadyConnected

__all__ = [
"Client"
]


logger = logging.getLogger('discord.dsbot')


Expand All @@ -26,17 +24,19 @@ def __init__(self, *args, **kwargs):
self.pool = mafic.NodePool(self)

# App commands
self.guild_id = discord.Object(id=getenv("GUILD_ID", 0))
self.guild_id = discord.Object(id=getenv("DS_GUILD_ID", 0))
self.tree.on_error = self.on_tree_error

async def setup_hook(self):
logger.info("Loading extensions")
await self.load_extension("dsmusic.tracker.cog")
await self.load_extension("dsmusic.music.cog")
await self.load_extension("dsmusic.assistant.cog")

logger.info("Extensions loaded")

# Add lavalink nodes
self.loop.create_task(self.add_nodes())
await self.loop.create_task(self.add_nodes())

# This copies the global commands over to your guild.
logger.info("Syncing command tree")
Expand Down Expand Up @@ -81,7 +81,6 @@ async def add_nodes(self):

if len(self.pool.nodes) == 0:
logger.error("No nodes connected")
return
else:
logger.info(f"{len(self.pool.nodes)} nodes connected")

Expand All @@ -95,3 +94,4 @@ async def on_tree_error(interaction: discord.Interaction, error: app_commands.Ap
return await interaction.response.send_message(f"You are not authorized to use that", ephemeral=True)
else:
logger.error(error)
return await interaction.response.send_message("An error occurred", ephemeral=True)
10 changes: 7 additions & 3 deletions dsmusic/music/cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,19 +73,23 @@ async def play(self, interaction: discord.Interaction, query: str):
# noinspection PyTypeChecker
resp: discord.InteractionResponse = interaction.response

await resp.defer()
await resp.defer(thinking=True)

if interaction.guild.voice_client is None:
vc: LavalinkPlayer = await interaction.user.voice.channel.connect(self_deaf=True, cls=LavalinkPlayer)
vc.is_connected()
else:
if interaction.guild.voice_client.channel != interaction.user.voice.channel:
return await resp.send_message("⚠️ Already on a different channel", ephemeral=True)
return await interaction.followup.send("⚠️ Already on a different channel", ephemeral=True)
else:
# noinspection PyTypeChecker
vc: LavalinkPlayer = interaction.guild.voice_client

tracks = await vc.fetch_tracks(query)
try:
with asyncio.timeout(10):
tracks = await vc.fetch_tracks(query)
except asyncio.TimeoutError:
return await interaction.followup.send("⚠️ Timed out (please, report to the bot owner)", ephemeral=True)

if tracks is None:
return await interaction.followup.send("⚠️ No song found", ephemeral=True)
Expand Down
Loading

0 comments on commit dae5ab5

Please sign in to comment.