Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Feature/content to embedding model #18

Merged
merged 19 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ on:
pull_request:

env:
DJANGO_SECRET_KEY: "ci-test-insecure-django-secret-key"
COMPOSE_FILE: docker-compose.yml:gh-docker-compose.yml

jobs:
Expand Down
2 changes: 1 addition & 1 deletion .python-version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
3.12
3.10
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

import requests
from langchain.embeddings.base import Embeddings
from utils import EmbeddingModelType

from chatbotcore.utils import EmbeddingModelType


@dataclass
Expand Down
1 change: 0 additions & 1 deletion chatbot-core/database.py → chatbotcore/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def store_data(self, data: list) -> None:
point_vectors = [
{"id": str(uuid.uuid4()), "vector": v_representation, "payload": metadata} for v_representation, metadata in data
]

response = self.db_client.upsert(collection_name=self.collection_name, points=point_vectors)
return response

Expand Down
2 changes: 1 addition & 1 deletion chatbot-core/doc_loaders.py → chatbotcore/doc_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class DocumentLoader:
Base Class for Document Loaders
"""

chunk_size: int = 200
chunk_size: int = 100
chunk_overlap: int = 20

def _get_split_documents(self, documents: List[Document]):
Expand Down
17 changes: 9 additions & 8 deletions chatbot-core/llm.py → chatbotcore/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from dataclasses import dataclass, field
from typing import Any, Optional

from custom_embeddings import CustomEmbeddingsWrapper
from django.conf import settings
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains.history_aware_retriever import create_history_aware_retriever
Expand All @@ -16,6 +15,8 @@
from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient

from chatbotcore.custom_embeddings import CustomEmbeddingsWrapper

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -93,7 +94,7 @@ def get_prompt_template_for_response(self):
)
return llm_response_prompt

def get_db_retriever(self, collection_name: str, top_k_items: int = 5, score_threshold: float = 0.5):
def get_db_retriever(self, collection_name: str, top_k_items: int = 5, score_threshold: float = 0.7):
"""Get the database retriever"""
db_retriever = QdrantVectorStore(
client=self.qdrant_client, collection_name=collection_name, embedding=self.embedding_model
Expand All @@ -120,21 +121,21 @@ def create_chain(self, db_collection_name: str):
rag_chain = create_retrieval_chain(history_aware_retriever, chat_response_chain)
return rag_chain

def execute_chain(self, user_id: str, query: str, db_collection_name: str = settings.QDRANT_DB_COLLECTION_NAME):
async def execute_chain(self, user_id: str, query: str, db_collection_name: str = settings.QDRANT_DB_COLLECTION_NAME):
"""
Executes the chain
"""
if not self.rag_chain:
self.rag_chain = self.create_chain(db_collection_name=db_collection_name)

if "user_id" not in self.user_memory_mapping:
if user_id not in self.user_memory_mapping:
self.user_memory_mapping[user_id] = ConversationBufferWindowMemory(
k=self.conversation_max_window, memory_key=self.mem_key, return_messages=True
)

memory = self.user_memory_mapping[user_id]

response = self.rag_chain.invoke(
response = await self.rag_chain.ainvoke(
{"input": query, "chat_history": self.get_message_history(user_id=user_id)["chat_history"]}
)
response_text = response["answer"] if "answer" in response else "I don't know the answer."
Expand All @@ -147,13 +148,13 @@ def get_message_history(self, user_id: str):
"""
Returns the historical conversational data
"""
if "user_id" in self.user_memory_mapping:
if user_id in self.user_memory_mapping:
return self.user_memory_mapping[user_id].load_memory_variables({})
return {}
return {"chat_history": []}

def delete_message_history_by_user(self, user_id: str) -> bool:
"""Deletes the message history based on user id"""
if "user_id" in self.user_memory_mapping:
if user_id in self.user_memory_mapping:
del self.user_memory_mapping[user_id]
logger.info(f"Successfully delete the {user_id} conversational history.")
return True
Expand Down
File renamed without changes.
Empty file added common/utils.py
Empty file.
2 changes: 1 addition & 1 deletion content/migrations/0002_alter_content_document_type.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Generated by Django 5.1.1 on 2024-09-18 05:28
# Generated by Django 5.1.1 on 2024-09-18 08:20

from django.db import migrations, models

Expand Down
10 changes: 7 additions & 3 deletions content/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from django.utils.translation import gettext_lazy as _

from common.models import UserResource
from content.tasks import create_embedding_for_content_task


class Tag(models.Model):
Expand Down Expand Up @@ -43,7 +44,10 @@ def __str__(self):

def save(self, *args, **kwargs):
"""Save the content to the database."""
if self.pk is None:
if self.document_type == self.DocumentType.TEXT:
self.extracted_file = self.document_file
if self.pk is None and self.document_type == self.DocumentType.TEXT:
self.extracted_file = self.document_file
self.document_status = self.DocumentStatus.TEXT_EXTRACTED

super().save(*args, **kwargs)
if self.document_status == self.DocumentStatus.TEXT_EXTRACTED:
create_embedding_for_content_task(self.id)
6 changes: 6 additions & 0 deletions content/serializers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from rest_framework import serializers


class UserQuerySerializer(serializers.Serializer):
    """Validates the request payload for the user chat-query endpoint."""

    # Free-text question forwarded to the LLM; must be a non-empty, non-null string.
    query = serializers.CharField(required=True, allow_null=False, allow_blank=False)
    # Identifier used to key the per-user conversation history.
    user_id = serializers.UUIDField(required=True)
38 changes: 38 additions & 0 deletions content/tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import requests
from celery import shared_task
from django.conf import settings

from chatbotcore.database import QdrantDatabase
from chatbotcore.doc_loaders import LoaderFromText


# Fix: the original decorator used ``blind=True`` -- not a valid Celery option
# (a typo for ``bind=True``).  The task body never uses the task instance, so a
# plain ``@shared_task`` is correct and keeps the call signature unchanged for
# the synchronous call in content.models.Content.save().
@shared_task
def create_embedding_for_content_task(content_id):
    """Embed a Content object's extracted text and store the vectors in Qdrant.

    Splits the extracted text into chunks, posts them to the external
    embedding-model service, and upserts the returned vectors together with
    per-chunk metadata into the configured Qdrant collection.  Sets the
    content's ``document_status`` to ADDED_TO_VECTOR on success, FAILURE
    otherwise, and persists the content.
    """
    # Imported lazily: content.models imports this module, so a top-level
    # import would be circular.
    from content.models import Content

    content = Content.objects.get(id=content_id)
    # NOTE(review): FileField.read() returns bytes while the parameter is named
    # ``text`` -- confirm whether LoaderFromText needs a .decode("utf-8") here.
    data = content.extracted_file.read()
    loader = LoaderFromText(text=data)
    split_docs = loader.create_document_chunks()
    chunk_texts = [doc.page_content for doc in split_docs]

    payload = {
        "type_model": settings.EMBEDDING_MODEL_TYPE,
        "name_model": settings.EMBEDDING_MODEL_NAME,
        "texts": chunk_texts,
    }
    headers = {"Content-Type": "application/json"}
    # Timeout added so a hung embedding service cannot block the worker forever.
    response = requests.post(settings.EMBEDDING_MODEL_URL, headers=headers, json=payload, timeout=300)

    metadata = [
        {"source": "plain-text", "page_content": text, "uuid": content.content_id}
        for text in chunk_texts
    ]
    if response.status_code == 200:
        db = QdrantDatabase(
            host=settings.QDRANT_DB_HOST, port=settings.QDRANT_DB_PORT, collection_name=settings.QDRANT_DB_COLLECTION_NAME
        )
        db.set_collection()
        db.store_data(list(zip(response.json(), metadata)))
        content.document_status = Content.DocumentStatus.ADDED_TO_VECTOR
    else:
        content.document_status = Content.DocumentStatus.FAILURE
    content.save()
18 changes: 18 additions & 0 deletions content/views.py
Original file line number Diff line number Diff line change
@@ -1 +1,19 @@
# Create your views here.
import asyncio

from rest_framework.generics import GenericAPIView
from rest_framework.response import Response

from chatbotcore.llm import OllamaHandler
from content.serializers import UserQuerySerializer


class UserQuery(GenericAPIView):
    """Accept a user chat query and return the LLM-generated answer.

    POST body is validated by :class:`UserQuerySerializer`; invalid payloads
    get a 422 response containing the field errors.
    """

    serializer_class = UserQuerySerializer

    # NOTE(review): instantiated once at import time and shared by every
    # request/thread -- confirm OllamaHandler is safe for concurrent use.
    llm = OllamaHandler()

    def post(self, request, *args, **kwargs):
        serializer = self.get_serializer(data=request.data)
        if serializer.is_valid():
            data = serializer.validated_data
            # Fix: read from validated_data (validated/coerced values) instead
            # of raw request.data.  str() keeps the user_id a plain string, as
            # the original behavior passed the raw string through.
            result = asyncio.run(self.llm.execute_chain(str(data["user_id"]), data["query"]))
            return Response(result)
        return Response(serializer.errors, status=422)
97 changes: 69 additions & 28 deletions docker-compose.yml
Original file line number Diff line number Diff line change
@@ -1,3 +1,66 @@

name: chatbot

# Shared service template: backend containers (web, worker) extend this anchor
# so environment configuration is defined in exactly one place.
x-server: &base_server_setup
  build:
    context: .
  # Used for python debugging.
  stdin_open: true
  tty: true
  env_file:
    - .env
  environment:
    APP_ENVIRONMENT: ${APP_ENVIRONMENT:-development}
    APP_TYPE: web
    DJANGO_DEBUG: ${DJANGO_DEBUG:-true}
    DJANGO_SECRET_KEY: ${DJANGO_SECRET_KEY}
    DJANGO_TIME_ZONE: ${DJANGO_TIME_ZONE:-Asia/Kathmandu}
    # -- Domain configurations
    DJANGO_ALLOWED_HOSTS: ${DJANGO_ALLOWED_HOSTS:-*}
    APP_DOMAIN: localhost:8000
    APP_HTTP_PROTOCOL: ${APP_HTTP_PROTOCOL:-http}
    SESSION_COOKIE_DOMAIN: ${SESSION_COOKIE_DOMAIN:-localhost}
    # Database config
    DATABASE_NAME: ${DATABASE_NAME:-postgres}
    DATABASE_USER: ${DATABASE_USER:-postgres}
    DATABASE_PASSWORD: ${DATABASE_PASSWORD:-postgres}
    DATABASE_PORT: ${DATABASE_PORT:-5432}
    DATABASE_HOST: ${DATABASE_HOST:-db}
    # Redis / Celery config
    CELERY_REDIS_URL: ${CELERY_REDIS_URL:-redis://redis:6379/0}
    DJANGO_CACHE_REDIS_URL: ${DJANGO_CACHE_REDIS_URL:-redis://redis:6379/1}
    CELERY_BROKER_URL: ${CELERY_BROKER_URL:-redis://redis:6379/0}
    CELERY_RESULT_BACKEND: ${CELERY_RESULT_BACKEND:-redis://redis:6379/0}
    # Fix: original line read `${CELERY_ACCEPT_CONTENT},` -- the stray trailing
    # comma became part of the value and no default was supplied.
    CELERY_ACCEPT_CONTENT: ${CELERY_ACCEPT_CONTENT:-json}
    CELERY_TASK_SERIALIZER: ${CELERY_TASK_SERIALIZER:-json}
    CELERY_RESULT_SERIALIZER: ${CELERY_RESULT_SERIALIZER:-json}
    CELERY_TIMEZONE: ${CELERY_TIMEZONE:-UTC}
    # Embedding model / vector DB config
    QDRANT_DB_HOST: ${QDRANT_DB_HOST:-0.0.0.0}
    QDRANT_DB_PORT: ${QDRANT_DB_PORT:-6333}
    QDRANT_DB_COLLECTION_NAME: ${QDRANT_DB_COLLECTION_NAME:-test}
    EMBEDDING_MODEL_URL: ${EMBEDDING_MODEL_URL:-localhost}
    EMBEDDING_MODEL_NAME: ${EMBEDDING_MODEL_NAME:-embedding_model}
    EMBEDDING_MODEL_VECTOR_SIZE: ${EMBEDDING_MODEL_VECTOR_SIZE:-384}
    EMBEDDING_MODEL_TYPE: ${EMBEDDING_MODEL_TYPE:-1}
    OLLAMA_EMBEDDING_MODEL_BASE_URL: ${OLLAMA_EMBEDDING_MODEL_BASE_URL:-model}
    LLM_TYPE: ${LLM_TYPE:-1}
    LLM_MODEL_NAME: ${LLM_MODEL_NAME:-"mistral:latest"}
    LLM_OLLAMA_BASE_URL: ${LLM_OLLAMA_BASE_URL:-localhost:9000}
    OPENAI_API_KEY: ${OPENAI_API_KEY:-test_key}
  volumes:
    - .:/code
  depends_on:
    - db
    - redis
  logging:
    driver: "json-file"
    options:
      max-size: "100m"
      max-file: "5"
services:
db:
image: postgres:16-alpine
Expand Down Expand Up @@ -25,43 +88,21 @@ services:
# https://github.com/qdrant/qdrant/blob/master/config/config.yaml
QDRANT__SERVICE__HOST: 0.0.0.0
QDRANT__SERVICE__HTTP_PORT: 6333

web:
build: .
env_file:
- .env
environment:
DJANGO_DEBUG: ${DJANGO_DEBUG:-True}
DJANGO_ALLOWED_HOST: ${DJANGO_ALLOWED_HOST:-localhost}
DJNAGO_SECRET_KEY: ${DJANGO_SECRET_KEY}
DATABASE_NAME: ${DATABASE_NAME:-postgres}
DATABASE_USER: ${DATABASE_USER:-postgres}
DATABASE_PASSWORD: ${DATABASE_PASSWORD:-postgres}
DATABASE_PORT: ${DATABASE_PORT:-5432}
DATABASE_HOST: ${DATABASE_HOST:-db}
QDRANT_DB_HOST: ${QDRANT_DB_HOST:-0.0.0.0}
QDRANT_DB_PORT: ${QDRANT_DB_PORT:-6333}
QDRANT_DB_COLLECTION_NAME: ${QDRANT_DB_COLLECTION_NAME:-test}
EMBEDDING_MODEL_URL: ${EMBEDDING_MODEL_URL:-localhost}
EMBEDDING_MODEL_NAME: ${EMBEDDING_MODEL_NAME:-embedding_model}
EMBEDDING_MODEL_VECTOR_SIZE: ${EMBEDDING_MODEL_VECTOR_SIZE:-384}
EMBEDDING_MODEL_TYPE: ${EMBEDDING_MODEL_TYPE:-1}
OLLAMA_EMBEDDING_MODEL_BASE_URL: ${OLLAMA_EMBEDDING_MODEL_BASE_URL:-model}
LLM_TYPE: ${LLM_TYPE:-1}
LLM_MODEL_NAME: ${LLM_MODEL_NAME:-"mistral:latest"}
LLM_OLLAMA_BASE_URL: ${LLM_OLLAMA_BASE_URL:-localhost}
OPENAI_API_KEY: ${OPENAI_API_KEY:-test_key}
<<: *base_server_setup
command: bash -c 'python manage.py runserver 0.0.0.0:8001'
volumes:
- .:/code
ports:
- 127.0.0.1:8001:8001
- 8001:8001

depends_on:
- db
- redis
- qdrant

worker:
<<: *base_server_setup
command: celery -A main worker --loglevel=info

volumes:
postgres-data16:
Expand Down
7 changes: 4 additions & 3 deletions gh-docker-compose.yml
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
services:
web:
image: $DOCKER_IMAGE_BACKEND
build: !reset null
env_file: !reset null
environment:
CI: "true"
DJANGO_SECRET_KEY: "test"
APP_ENVIRONMENT: CI
APP_TYPE: web


volumes:
- ./coverage/:/code/coverage/

worker: !reset null

3 changes: 3 additions & 0 deletions main/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Import the Celery app whenever Django loads this package so that
# @shared_task decorators bind to it (standard Celery-with-Django setup).
from .celery import app as celery_app

__all__ = ("celery_app",)
13 changes: 13 additions & 0 deletions main/celery.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""Celery application instance for the ``main`` Django project."""

import os

from celery import Celery

# Set the default Django settings module for the 'celery' program.
# This must run before the Celery app reads Django settings below.
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "main.settings")

app = Celery("main")

# Pull configuration from Django settings; the namespace means all
# Celery-related settings keys must be prefixed with ``CELERY_``.
app.config_from_object("django.conf:settings", namespace="CELERY")

# Load task modules from all registered Django app configs.
app.autodiscover_tasks()
Loading
Loading