Commit 6753f31 (1 parent: 92927d9)
Showing 26 changed files with 768 additions and 339 deletions.
app/api/dependencies.py (new file)
@@ -0,0 +1,46 @@
from typing import AsyncGenerator
from fastapi import Depends, HTTPException, status
from fastapi.security.api_key import APIKeyHeader
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.database.session import AsyncSessionFactory
from app.services.embedding_service import EmbeddingService, OpenAIEmbeddingService
from app.services.extraction_service import ExtractionService
from app.services.similarity_service import SimilarityService
from app.services.text_process_service import TextProcessService

api_key_header = APIKeyHeader(name="access_token")


async def get_db() -> AsyncGenerator[AsyncSession, None]:
    # Yield one session per request; the async context manager already
    # closes the session when the request finishes, even on error.
    async with AsyncSessionFactory() as session:
        yield session


async def get_api_key(api_key_header: str = Depends(api_key_header)) -> str:
    # Compare the "access_token" request header against the configured key.
    if api_key_header == settings.API_KEY:
        return api_key_header
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Could not validate credentials",
    )


def get_embedding_service() -> EmbeddingService:
    return OpenAIEmbeddingService(api_key=settings.OPENAI_API_KEY)


def get_extraction_service() -> ExtractionService:
    return ExtractionService(access_token=settings.GITHUB_ACCESS_TOKEN)


def get_similarity_service() -> SimilarityService:
    return SimilarityService()


def get_text_process_service() -> TextProcessService:
    return TextProcessService()
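For context, a minimal sketch of how an endpoint might consume these dependencies. The /health_check route and its module are hypothetical, not part of this commit; it only illustrates FastAPI's Depends wiring:

from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies import get_api_key, get_db

router = APIRouter()


@router.get("/health_check")  # hypothetical route, for illustration only
async def health_check(
    db: AsyncSession = Depends(get_db),
    api_key: str = Depends(get_api_key),
):
    # FastAPI resolves both dependencies per request; get_api_key raises
    # a 403 before this body runs if the access_token header is wrong.
    return {"status": "ok"}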
(Router-registration module; file path not rendered in this view)
@@ -1,6 +1,14 @@
 from fastapi import APIRouter
-from .endpoints import score
+from .endpoints import data_ingestion, similarity_score

 api_router = APIRouter()
-api_router.include_router(score.router, tags=["score"])
+api_router.include_router(score.router, tags=["start-etl-process"])
+
+# Including router for data ingestion
+api_router.include_router(
+    data_ingestion.router, prefix="/data_ingestion", tags=["Data Ingestion"]
+)
+
+# Including router for similarity score calculations
+api_router.include_router(
+    similarity_score.router, prefix="/similarity_score", tags=["Similarity Score"]
+)
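For orientation, a sketch of how this aggregated router is typically mounted on the application. The app.api.api import path and the FastAPI entrypoint are assumptions, not shown in this commit:

from fastapi import FastAPI
from app.api.api import api_router  # assumed module path for the router above

app = FastAPI()
app.include_router(api_router)

Note that with the prefixes registered above, the ingestion route defined below resolves to POST /data_ingestion/data_ingestion, since the router prefix and the route path repeat the same segment.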
app/api/endpoints/data_ingestion.py (new file)
@@ -0,0 +1,102 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from loguru import logger

from app.api.dependencies import (
    get_db,
    get_embedding_service,
    get_text_process_service,
    get_extraction_service,
)
from app.models.github_model import GitRepository
from app.crud.github_crud import GitHubCRUD
from app.services.text_process_service import TextProcessService
from app.services.embedding_service import EmbeddingService
from app.services.extraction_service import ExtractionService

router = APIRouter()


@router.post("/data_ingestion")
async def data_ingestion(
    location: str = Query(
        "Milan", description="The location to filter GitHub users by."
    ),
    max_users: int = Query(2, description="The maximum number of users to fetch."),
    max_repos_per_user: int = Query(
        1, description="The maximum number of repositories per user."
    ),
    db: AsyncSession = Depends(get_db),
    text_process_service: TextProcessService = Depends(get_text_process_service),
    embedding_service: EmbeddingService = Depends(get_embedding_service),
    extraction_service: ExtractionService = Depends(get_extraction_service),
):
    logger.info("Starting data ingestion process")
    try:
        logger.info("Extracting data")
        all_data = await extraction_service.extract_data(
            location, max_users, max_repos_per_user
        )

        logger.info("Filtering new READMEs")
        # A Result can only be consumed once, so materialize the repository
        # names into a set before reusing them.
        existing_repos = await db.execute(select(GitRepository.repo_name))
        existing_names = set(existing_repos.scalars().all())
        logger.debug(f"Existing repositories: {existing_names}")
        new_data = [
            data for data in all_data if data["repo_name"] not in existing_names
        ]

        logger.info("Processing READMEs")
        processed_readmes = [
            text_process_service.process_text(data["readme"]) for data in new_data
        ]

        logger.info("Generating embeddings")
        # Embeddings are generated one README at a time; a concurrent
        # asyncio.gather variant is sketched after this diff.
        embeddings = []
        for readme in processed_readmes:
            embedding = embedding_service.generate_embeddings(readme)
            embeddings.append(embedding)

        logger.info("Loading new data into database")
        github_crud = GitHubCRUD()
        for data, readme_cleaned, embedding in zip(
            new_data, processed_readmes, embeddings
        ):
            user = await github_crud.create_git_user(
                db,
                {
                    "username": data["username"],
                    "location": location,
                },
            )
            logger.info(f"Creating repository: {data['repo_name']}")
            await github_crud.create_git_repository(
                db,
                {
                    "username": user.username,
                    "repo_name": data["repo_name"],
                    "readme_raw": data["readme"],
                    "readme_cleaned": readme_cleaned,
                    "readme_embedding": embedding,
                },
            )
        logger.info("Data ingestion completed successfully")
        return {"status": "Data ingestion successful"}
    except Exception as e:
        logger.error(f"Data ingestion process failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
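The commented-out code in the original commit hinted at a concurrent embedding step. A sketch of that variant, assuming generate_embeddings is an async coroutine (the sequential loop above does not await it, so the service may equally well be synchronous):

import asyncio

# Drop-in replacement for the sequential embedding loop inside
# data_ingestion, under the assumption that generate_embeddings is async.
# gather schedules all coroutines at once and preserves input order, so
# embeddings[i] still corresponds to processed_readmes[i].
embeddings = await asyncio.gather(
    *(embedding_service.generate_embeddings(readme) for readme in processed_readmes)
)

Whether unbounded concurrency is safe depends on the embedding provider's rate limits; bounding it with an asyncio.Semaphore is the usual compromise.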
The commit also includes an empty file and three deleted files, which are not rendered in this view.