add queries for adding new tables and routes for molecule data download #53

Merged
merged 7 commits on Oct 5, 2023
68 changes: 68 additions & 0 deletions add_data_tables.sql
@@ -0,0 +1,68 @@
CREATE TABLE ml_data (
    molecule_id INTEGER,
    property TEXT,
    max DOUBLE PRECISION,
    min DOUBLE PRECISION,
    delta DOUBLE PRECISION,
    vburminconf DOUBLE PRECISION,
    boltzmann_average DOUBLE PRECISION
);

ALTER TABLE ml_data
    ADD CONSTRAINT fk_molecule_id
    FOREIGN KEY (molecule_id) REFERENCES molecule(molecule_id);

CREATE INDEX idx_ml_data_molecule_id ON ml_data(molecule_id);

\COPY ml_data FROM 'ml_data_json_table.csv' DELIMITER ',' CSV HEADER;

CREATE TABLE dft_data (
    molecule_id INTEGER,
    property TEXT,
    max DOUBLE PRECISION,
    min DOUBLE PRECISION,
    delta DOUBLE PRECISION,
    vburminconf DOUBLE PRECISION,
    boltzmann_average DOUBLE PRECISION
);

ALTER TABLE dft_data
    ADD CONSTRAINT fk_molecule_id
    FOREIGN KEY (molecule_id) REFERENCES molecule(molecule_id);

CREATE INDEX idx_dft_data_molecule_id ON dft_data(molecule_id);

\COPY dft_data FROM 'dft_data_json_table.csv' DELIMITER ',' CSV HEADER;

CREATE TABLE xtb_data (
    molecule_id INTEGER,
    property TEXT,
    max DOUBLE PRECISION,
    min DOUBLE PRECISION,
    boltzmann_average DOUBLE PRECISION
);

ALTER TABLE xtb_data
    ADD CONSTRAINT fk_molecule_id
    FOREIGN KEY (molecule_id) REFERENCES molecule(molecule_id);

CREATE INDEX idx_xtb_data_molecule_id ON xtb_data(molecule_id);

\COPY xtb_data FROM 'xtb_data_json_table.csv' DELIMITER ',' CSV HEADER;

CREATE TABLE xtb_ni_data (
    molecule_id INTEGER,
    property TEXT,
    boltzmann_average DOUBLE PRECISION,
    max DOUBLE PRECISION,
    min DOUBLE PRECISION
);

ALTER TABLE xtb_ni_data
    ADD CONSTRAINT fk_molecule_id
    FOREIGN KEY (molecule_id) REFERENCES molecule(molecule_id);

CREATE INDEX idx_xtb_ni_data_molecule_id ON xtb_ni_data(molecule_id);

\COPY xtb_ni_data FROM 'xtb_ni_data_json_table.csv' DELIMITER ',' CSV HEADER;
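
Each \COPY above expects a long-format CSV with one row per (molecule_id, property) pair and a header row matching the column order of its table. A minimal sketch of what such an input file could look like for ml_data; the property labels and values below are hypothetical placeholders, not actual dataset contents:

import pandas as pd

# Long format: one row per (molecule_id, property); columns follow the ml_data table.
sample = pd.DataFrame(
    {
        "molecule_id": [1, 1],
        "property": ["dipole", "volume"],  # hypothetical property labels
        "max": [3.2, 55.0],
        "min": [2.8, 52.0],
        "delta": [0.4, 3.0],
        "vburminconf": [3.0, 53.5],
        "boltzmann_average": [3.1, 54.2],
    }
)
sample.to_csv("ml_data_json_table.csv", index=False)
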
187 changes: 164 additions & 23 deletions backend/app/app/api/v2/endpoints/molecule.py
@@ -1,17 +1,70 @@
from multiprocessing.sharedctypes import Value
from re import sub
"""
API endpoints for molecules.
Prefixed with /molecules
"""

import io
from typing import List, Optional, Any

from app import schemas
from app.api import deps
from app.db.session import models
import pandas as pd

from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from rdkit import Chem
from sqlalchemy import exc, text
from sqlalchemy.orm import Session

from app import schemas
from app.api import deps
from app.db.session import models

router = APIRouter()

def _pandas_long_to_wide(df):
    """
    Internal function for reshaping from long to wide format for CSV export.
    """
    # Reshape the data into wide format
    df_wide = df.pivot(index=["molecule_id", "smiles"], columns="property")

    # Flatten multi-level columns and reset the index
    df_wide.columns = ['_'.join(col[::-1]).strip() for col in df_wide.columns.values]

    df_wide.reset_index(inplace=True)

    df_wide.dropna(axis=1, inplace=True)

    return df_wide

def _pandas_to_buffer(df):
    """Internal function for converting dataframe to buffer"""

    # Create a buffer to hold the csv file.
    buffer = io.StringIO()

    # Write the dataframe to the buffer.
    df.to_csv(buffer, index=False)

    # Set the buffer to the beginning of the file.
    buffer.seek(0)

    return buffer
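
As a quick illustration of what _pandas_long_to_wide produces (an illustrative sketch, not part of the diff, using made-up property names): a long-format frame with one row per (molecule, property) pair pivots so each property/statistic pair becomes a "<property>_<statistic>" column.

import pandas as pd

long_df = pd.DataFrame(
    {
        "molecule_id": [1, 1],
        "smiles": ["CCO", "CCO"],
        "property": ["dipole", "volume"],  # hypothetical property labels
        "max": [3.2, 55.0],
        "min": [2.8, 52.0],
    }
)
# Applying the same pivot/flatten logic as the helper:
wide = long_df.pivot(index=["molecule_id", "smiles"], columns="property")
wide.columns = ["_".join(col[::-1]).strip() for col in wide.columns.values]
wide.reset_index(inplace=True)
print(wide.columns.tolist())
# ['molecule_id', 'smiles', 'dipole_max', 'volume_max', 'dipole_min', 'volume_min']
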

def _valid_molecule_id(molecule_id, db):
    """Raise an HTTPException if a molecule_id is outside the valid range."""

    # Generalized - get max molecule id.
    query = text("SELECT MAX(molecule_id) FROM molecule;")
    max_molecule_id = db.execute(query).fetchall()[0][0]

    # Check to see if the molecule_id is within range.
    if molecule_id > max_molecule_id:
        raise HTTPException(status_code=404, detail=f"Molecule with ID supplied not found, the maximum ID is {max_molecule_id}")

    # Zero or negative molecule IDs are invalid.
    if molecule_id <= 0:
        raise HTTPException(status_code=500)

    return

def valid_smiles(smiles):
    """Check to see if a smile string is valid to represent a molecule.
@@ -47,6 +100,110 @@ def valid_smiles(smiles):

    return smiles

@router.get("/data/export/{molecule_id}")
async def get_molecule_data(molecule_id: int,
data_type: str="ml",
return_type: str="csv",
db: Session = Depends(deps.get_db)):

# Check to see if the molecule_id is valid.
_valid_molecule_id(molecule_id, db)

# Check for valid data type.
if data_type.lower() not in ["ml", "dft", "xtb", "xtb_ni"]:
raise HTTPException(status_code=400, detail="Invalid data type.")

if return_type.lower() not in ["csv", "json"]:
raise HTTPException(status_code=400, detail="Invalid return type.")

# Use pandas.rea`` d_sql_query to get the data.
table_name = f"{data_type}_data"
query = text(f"""
SELECT t.*, m.SMILES
FROM {table_name} t
JOIN molecule m ON t.molecule_id = m.molecule_id
WHERE t.molecule_id = :molecule_id
""")

stmt = query.bindparams(molecule_id=molecule_id)

df = pd.read_sql_query(stmt, db.bind)

df_wide = _pandas_long_to_wide(df)

if return_type.lower() == "json":
json_data = df_wide.to_dict(orient="records")[0]
return json_data
else:
buffer = _pandas_to_buffer(df_wide)

# Return the buffer as a streaming response.
response = StreamingResponse(buffer, media_type="text/csv")
response.headers["Content-Disposition"] = f"attachment; filename={molecule_id}_{data_type}.csv"
return response

@router.get("/data/export")
async def get_molecules_data(molecule_ids: str,
data_type: str="ml",
return_type: str="csv",
context: Optional[str]=None,
db: Session = Depends(deps.get_db)):


# Sanitize molecule ids
int_check = [x.strip().isdigit() for x in molecule_ids.split(",")]

if not all(int_check):
raise HTTPException(status_code=400, detail="Invalid molecule ids.")

molecule_ids_list = [int(x) for x in molecule_ids.split(",")]
first_molecule_id = molecule_ids_list[0]
num_molecules = len(molecule_ids_list)

if context:
if context.lower() not in ["substructure", "pca_neighbors", "umap_neighbors"]:
raise HTTPException(status_code=400, detail="Invalid context.")

# Check to see if all molecule ids are valid.
[ _valid_molecule_id(int(x), db) for x in molecule_ids.split(",") ]

# Check for valid data type.
if data_type.lower() not in ["ml", "dft", "xtb", "xtb_ni"]:
raise HTTPException(status_code=400, detail="Invalid data type.")

if return_type.lower() not in ["csv", "json"]:
raise HTTPException(status_code=400, detail="Invalid return type.")

# Use pandas.read_sql_query to get the data.
table_name = f"{data_type}_data"

query = text(f"""
SELECT t.*, m.SMILES
FROM {table_name} t
JOIN molecule m ON t.molecule_id = m.molecule_id
WHERE t.molecule_id IN ({molecule_ids})
""")

df = pd.read_sql_query(query, db.bind)

df_wide = _pandas_long_to_wide(df)

if return_type.lower() == "json":
json_data = df_wide.to_dict(orient="records")
return json_data
else:
buffer = _pandas_to_buffer(df_wide)

# Return the buffer as a streaming response.
filename = f"{data_type}_{first_molecule_id}_{num_molecules}"
if context:
filename += f"_{context}"
filename += ".csv"
response = StreamingResponse(buffer, media_type="text/csv")
response.headers["Content-Disposition"] = f"attachment; filename={filename}"

return response


@router.get("/umap", response_model=List[schemas.MoleculeSimple])
def get_molecule_umap(
@@ -100,13 +257,7 @@ def get_molecule_umap(
@router.get("/{molecule_id}", response_model=schemas.Molecule)
def get_a_single_molecule(molecule_id: int, db: Session = Depends(deps.get_db)):

# Generalized - get max molecule id.
query = text(f"SELECT MAX(molecule_id) FROM molecule;")
max_molecule_id = db.execute(query).fetchall()[0][0]

# Check to see if the molecule_id is within range.
if molecule_id > max_molecule_id:
raise HTTPException(status_code=404, detail=f"Molecule with ID supplied not found, the maximum ID is {max_molecule_id}")
_valid_molecule_id(molecule_id, db)

molecule = (
db.query(models.molecule)
@@ -175,17 +326,7 @@ def search_neighbors(

    type = type.lower()

    # Generalized - get max molecule id.
    query = text(f"SELECT MAX(molecule_id) FROM molecule;")
    max_molecule_id = db.execute(query).fetchall()[0][0]

    # Check to see if the molecule_id is within range.
    if molecule_id > max_molecule_id:
        raise HTTPException(status_code=404, detail=f"Molecule with ID supplied not found, the maximum ID is {max_molecule_id}")

    # Check to see if the molecule_id is within range.
    if molecule_id <= 0:
        raise HTTPException(status_code=500)
    _valid_molecule_id(molecule_id, db)

    # Check for valid neighbor type.
    if type not in ["pca", "umap"]:
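
Taken together, the new export routes can be exercised as sketched below. This is an illustrative client snippet rather than part of the diff: the host, port, and /molecules prefix are assumptions based on the module docstring, and the molecule IDs are arbitrary.

import requests

BASE = "http://localhost:8000/molecules"  # assumed host and route prefix

# Single-molecule export as CSV (data_type defaults to "ml", return_type to "csv").
resp = requests.get(f"{BASE}/data/export/1", params={"data_type": "dft"})
with open("1_dft.csv", "wb") as fh:
    fh.write(resp.content)

# Multi-molecule export as JSON records.
resp = requests.get(
    f"{BASE}/data/export",
    params={"molecule_ids": "1,2,3", "data_type": "xtb", "return_type": "json"},
)
records = resp.json()  # one wide-format record per molecule
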
1 change: 1 addition & 0 deletions backend/environment.yaml
@@ -10,6 +10,7 @@ dependencies:
- alembic
- psycopg2-binary
- sqlalchemy
- pandas
- tenacity
- uvicorn
- curl