Skip to content

Commit

Permalink
UPDATE: start from sync
Browse files Browse the repository at this point in the history
  • Loading branch information
mazzasaverio committed Feb 23, 2024
1 parent 6753f31 commit 61f1a1e
Show file tree
Hide file tree
Showing 10 changed files with 195 additions and 153 deletions.
37 changes: 13 additions & 24 deletions backend/app/app/api/v1/endpoints/data_ingestion.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,47 +30,36 @@ async def data_ingestion(
1, description="The maximum number of repositories per user."
),
db: AsyncSession = Depends(get_db),
extraction_service: ExtractionService = Depends(get_extraction_service),
text_process_service: TextProcessService = Depends(get_text_process_service),
embedding_service: EmbeddingService = Depends(get_embedding_service),
extraction_service: ExtractionService = Depends(get_extraction_service),
):
logger.info("------------ Starting data ingestion process")

try:
logger.info(" ------------ Extracting data")
all_data = await extraction_service.extract_data(
logger.info("1) ##### Extracting data")
all_data = extraction_service.extract_data(
location, max_users, max_repos_per_user
)

logger.info(" ------------ Filtering new READMEs")
logger.info("2) ##### Filtering new READMEs")
existing_repos = await db.execute(select(GitRepository.repo_name))
logger.info(f"#### {existing_repos.scalars().all()}")

existing_names = {repo.repo_name for repo in existing_repos.scalars().all()}
logger.info(f"#### {existing_names}")

new_data = [
data for data in all_data if data["repo_name"] not in existing_names
]

logger.info(" ------------ Processing READMEs")
logger.info("3) ##### Processing READMEs")
processed_readmes = [
text_process_service.process_text(data["readme"]) for data in new_data
]

logger.info(" ------------ Generating embeddings")
# embeddings = [
# embedding_service.generate_embeddings(readme)
# for readme in processed_readmes
# ]

# embeddings = await asyncio.gather(
# *(
# embedding_service.generate_embeddings(readme)
# for readme in processed_readmes
# )
# )
embeddings = []
for readme in processed_readmes:
embedding = embedding_service.generate_embeddings(readme)
embeddings.append(embedding)
logger.info("4) ##### Generating embeddings")
embeddings = [
embedding_service.generate_embeddings(readme)
for readme in processed_readmes
]

logger.info("Loading new data into database")
github_crud = GitHubCRUD()
Expand Down
6 changes: 5 additions & 1 deletion backend/app/app/api/v1/endpoints/similarity_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from app.api.dependencies import get_db, get_embedding_service, get_similarity_service
from app.services.embedding_service import EmbeddingService
from app.services.similarity_service import SimilarityService
from loguru import logger

router = APIRouter()

Expand All @@ -15,7 +16,7 @@ async def calculate_similarity(
similarity_service: SimilarityService = Depends(get_similarity_service),
):
"""
Calculate similarity scores between the input text and embeddings in the `git_repositories_n` table.
Calculate similarity scores between the input text and embeddings in the `git_repositories_n2` table.
Args:
input_text (str): The input text from the user.
Expand All @@ -28,10 +29,13 @@ async def calculate_similarity(
"""
# Generate embedding for the input text
input_embedding = embedding_service.generate_embeddings(input_text)
logger.info(f"Input embedding: {input_embedding}")

# Calculate similarity scores
similarities = await similarity_service.calculate_similarity(db, input_embedding)

logger.info(f"Similarities: {similarities}")

# Format and return the results
return {
"similarities": [
Expand Down
2 changes: 1 addition & 1 deletion backend/app/app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class Settings(BaseSettings):

@property
def DATABASE_URI(self) -> str:
return f"postgresql://{self.DB_USER}:{self.DB_PASS}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"
return f"postgresql+psycopg2://{self.DB_USER}:{self.DB_PASS}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"

@property
def ASYNC_DATABASE_URI(self) -> str:
Expand Down
6 changes: 3 additions & 3 deletions backend/app/app/models/github_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@


class GitUser(Base):
    """SQLAlchemy declarative model for a GitHub user (table ``git_users_n2``)."""

    __tablename__ = "git_users_n2"
    # GitHub login name; serves as the natural primary key.
    username = Column(String, primary_key=True)
    # Free-text location string taken from the user's GitHub profile.
    location = Column(String)
    # Row update timestamp. NOTE(review): datetime.utcnow yields a *naive*
    # UTC datetime — confirm downstream code does not expect tz-aware values.
    updated_at = Column(DateTime, default=datetime.utcnow)


class GitRepository(Base):
__tablename__ = "git_repositories_n"
__tablename__ = "git_repositories_n2"
repo_id = Column(Integer, primary_key=True)
repo_name = Column(String)
username = Column(String, ForeignKey("git_users_n.username"))
username = Column(String, ForeignKey("git_users_n2.username"))
readme_raw = Column(String)
readme_cleaned = Column(String)
readme_embedding = Column(Vector)
Expand Down
74 changes: 5 additions & 69 deletions backend/app/app/services/embedding_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,9 @@
from typing import List
import numpy as np
from loguru import logger

# Import OpenAI GPT-3
from openai import OpenAI

from app.core.config import settings

# # Import Sentence Transformers
# from sentence_transformers import SentenceTransformer


class EmbeddingGenerator(ABC):
@abstractmethod
Expand All @@ -29,86 +23,28 @@ def __init__(self, model: str = "text-embedding-ada-002", api_key: str = None):
def generate_embeddings(self, readme_text, model="text-embedding-3-small"):
# Split the input text into chunks that do not exceed the model's token limit
max_length = 8000 # Adjust based on the model's limitations
chunks = [
readme_text[i : i + max_length]
for i in range(0, len(readme_text), max_length)
]

logger.debug("Generating embeddings for cleaned text")

all_embeddings = []
for chunk in chunks:
# Generate embeddings for each chunk
chunk_embedding = (
self.client.embeddings.create(input=[chunk], model=model)
.data[0]
.embedding
)
all_embeddings.extend(chunk_embedding)

logger.debug("Embeddings generated")

# Optionally, you can average the embeddings from all chunks if needed
# This step depends on how you plan to use the embeddings
# Example: averaged_embedding = np.mean(np.array(all_embeddings), axis=0)
text = readme_text[:max_length]

return all_embeddings
embedding = (
self.client.embeddings.create(input=[text], model=model).data[0].embedding
)

# async def generate_embeddings(self, readme_text, model="text-embedding-3-small"):
# logger.info("Starting to generate embeddings")
# max_length = 8000 # Adjust based on the model's limitations
# chunks = [
# readme_text[i : i + max_length]
# for i in range(0, len(readme_text), max_length)
# ]

# logger.debug("Generating embeddings for cleaned text")

# all_embeddings = []
# for chunk in chunks:
# logger.debug(f"Generating embeddings for chunk with length {len(chunk)}")
# chunk_embedding = (
# await self.client.embeddings.create(input=[chunk], model=model)
# .data[0]
# .embedding
# )
# all_embeddings.extend(chunk_embedding)

# logger.debug("Embeddings generated")

# return all_embeddings


# class SentenceTransformerEmbeddingService(EmbeddingGenerator):
# def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
# logger.info(f"Initializing SentenceTransformerEmbeddingService with model: {model_name}")
# self.model = SentenceTransformer(model_name)

# def generate_embeddings(self, text: str) -> List[float]:
# logger.info("Starting to generate embeddings using SentenceTransformer")
# embedding = self.model.encode(text)
# logger.debug("Embeddings generated")
# return embedding.tolist()
return embedding


class EmbeddingService:
    """Facade that delegates embedding generation to a configured backend.

    Only the ``"openai"`` strategy is currently implemented; the constructor
    raises ``ValueError`` for any other strategy or when the API key is
    missing, so a successfully constructed instance always has a usable
    backend in ``self.strategy``.
    """

    def __init__(self, strategy: str = "openai", api_key: str = None):
        """Select and instantiate the embedding backend.

        Args:
            strategy: Backend name; only ``"openai"`` is supported.
            api_key: API key forwarded to the OpenAI client; required
                whenever ``strategy == "openai"``.

        Raises:
            ValueError: If the strategy is unknown or the API key is absent.
        """
        logger.info(f"Initializing EmbeddingService with strategy: {strategy}")
        if strategy == "openai" and api_key:
            self.strategy = OpenAIEmbeddingService(api_key=api_key)
        else:
            logger.error("Invalid embedding strategy or missing API key for OpenAI.")
            raise ValueError(
                "Invalid embedding strategy or missing API key for OpenAI."
            )

    def generate_embeddings(self, text: str) -> List[float]:
        """Return the embedding vector for *text* via the selected backend.

        Args:
            text: Raw text to embed.

        Returns:
            The embedding as produced by the backend's
            ``generate_embeddings`` (a list of floats).
        """
        logger.info("Generating embeddings for given text")

        embeddings = self.strategy.generate_embeddings(text)

        # Debug-level only: the full vector is large and noisy in logs.
        logger.debug(f"Embeddings generated: {embeddings}")

        return embeddings
Loading

0 comments on commit 61f1a1e

Please sign in to comment.