Skip to content

Commit

Permalink
round 1 cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
shariqm-modal committed Nov 21, 2024
1 parent 1984cf0 commit 5a4e76a
Showing 1 changed file with 95 additions and 51 deletions.
146 changes: 95 additions & 51 deletions 06_gpu_and_ml/protein_fold/protein_fold.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,42 +2,76 @@
# output-directory: "/tmp/protein-fold"
# ---

# Guide
# - A protein is a sequence of atoms, some but not all of those atoms may belong
# to an Amino Acid (AA).
# - Residue is the same as Amino Acid.
# - A residue ID corresponds to a specific AA, it is not an index in the
# protein. e.g. residue ID 2 has the name "ILE" which is Isoleucine.
#
# Acronyms
## AA - Amino Acid
## SSE - Secondary Structure
# Background
## Atom14 refers to a reduced representation of 14 atoms per Amino Acid. Alpha
## Fold and ESMFold seem to run in this representation space for efficiency.
## Atom37 refers to an expanded representation that is computed from Atom14 and
## more closely resembles PDB structure.

# TODO
# - Do we need PDB & PDBX files? Kind of silly but py3DMol doesn't seem to like
# PDB files.
# - Error handling for bad sequence. E.g. non amino acid letter.

from fastapi import FastAPI
from fastapi.responses import FileResponse
# # Protein fold anything with ESM3, plus some powerful visualizations

# If you haven't heard of AlphaFold or ESM3, you've likely been living under a
# rock, it's time to come out! Protein Folding is a hot topic in the world of
# bioinformatics and machine learning. In this example, we'll show you how to
# use ESM3 to predict the three-dimensional structure of a protein from its raw
# amino acid sequence. We'll also show you how to visualize the results using
# [py3DMol](https://3dmol.org/).

# If you're new to protein folding check out the next section for a brief
# overview, otherwise you can skip to Basic Setup.

# ## What is Protein Folding?

# A protein's three-dimensional shape determines its function in the body, from
# catalyzing biochemical reactions to building cellular structures. To develop
# new drugs and treatments for diseases, researchers need to understand how
# potential drugs bind to target proteins in the human body and thus need to predict
# their threre-dimensional structure from their raw amino acid sequence.

# Historically determining the protein structure of a single sequence could
# take years of wet-lab experiments and millions of dollars, but with the advent
# of deep protein folding models like ESM3, this can be done in seconds for a
# few dollars.

# ## Basic Setup

import logging as L
import modal
from pathlib import Path
import time

MINUTES = 60 # seconds

app_name = "protein_fold"
app = modal.App(app_name)

# We'll use A10G GPUs for inference, which are able to generate 3D structures
# in seconds for a fews cents.

gpu = "A10G"

# ### Create a Volume to store weights, data, and HTML

# On the first launch we'll be downloading the model weights from Hugging
# Face(?), to speed up future cold starts we'll store the weights on a
# [Volume](https://modal.com/docs/guide/volumes) for quick reads. We'll also
# use this volume to cache PDB data from the web and store HTML files for
# visualization.
volume = modal.Volume.from_name(
"example-protein-fold", create_if_missing=True
)
VOLUME_PATH = Path("/vol/data")
PDBS_PATH = VOLUME_PATH / "pdbs"
HTMLS_PATH = VOLUME_PATH / "htmls"

# ### Define dependencies in container images

# The container image for inference is based on Modal's default slim Debian
# Linux image with `esm` for loading and running or model
# for the
esm3_image = (
modal.Image.debian_slim(python_version="3.12")
.pip_install(
"esm==3.0.5",
"torch==2.4.1",
"torch==2.4.1", # Needed? FIXME
)
)

# We'll also define a web app image for extracting PDB data. FIXME
web_app_image = (
modal.Image.debian_slim(python_version="3.11")
.pip_install(
Expand All @@ -47,39 +81,46 @@
"pydssp==0.9.0",
"py3Dmol==2.4.0",
"torch==2.4.1",
"fastapi[standard]==0.115.4",
)
)

app = modal.App("protein_fold")

volume = modal.Volume.from_name(
"example-protein-fold", create_if_missing=True
)
volume_path = Path("/vol/data")
PDBS_PATH = volume_path / "pdbs"
HTMLS_PATH = volume_path / "htmls"

# We can also "pre-import" libraries that will be used by the functions we run on Modal in a given image
# using the `with image.imports` context manager.
with esm3_image.imports():
from esm.sdk.api import ESM3InferenceClient, ESMProtein, GenerationConfig
import torch

with web_app_image.imports():
from py3DmolWrapper import py3DMolViewWrapper

SECONDS_PER_MINUTE = 60
# ## Defining a `Model` inference class for ESM3

# Next, we map the model's setup and inference code onto Modal.

# 1. We run any setup that can be persisted to disk in methods decorated with `@build`.
# In this example, that includes downloading the model weights.
# 2. We run any additional setup, like moving the model to the GPU, in methods decorated with `@enter`.
# We do our model optimizations in this step. For details, see the section on `torch.compile` below.
# 3. We run the actual inference in methods decorated with `@method`.

# Here we define the

@app.cls(
gpu='A10G',
image=esm3_image,
timeout=20 * SECONDS_PER_MINUTE,
secrets=[modal.Secret.from_name("huggingface-secret")],
gpu=gpu,
timeout=20 * MINUTES,
)
class ModelInference:
@modal.build()
@modal.enter()
def build(self):
from esm.models.esm3 import ESM3
self.model: ESM3InferenceClient = ESM3.from_pretrained("esm3_sm_open_v1").to("cuda")
# Optimizations from:
# XXX Remove:
# https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/protein_folding.ipynb#scrollTo=cc0f0186
self.model = self.model.half()
torch.backends.cuda.matmul.allow_tf32 = True
Expand All @@ -88,38 +129,44 @@ def build(self):

@modal.method()
def generate(self, sequence: str) -> bool:
start_time = time.monotonic()
# num_steps cannot be longer than the seq length in ESM.
start_time = time.monotonic() # TODO Remove

num_steps = min(len(sequence), self.max_steps)
structure_generation_config = GenerationConfig(
track="structure",
num_steps=min(len(sequence), self.max_steps)
)
num_steps=num_steps,
)
esm_protein = self.model.generate(
ESMProtein(sequence=sequence),
structure_generation_config
)
)

latency_s = time.monotonic() - start_time
print (f"Latency {latency_s:.2f} seconds.")
print (f"Inference latency: {latency_s:.2f} seconds.")

# Check that esm_protein did not error.
print ("Checking for errors...")
if hasattr(esm_protein, "error_msg"):
raise ValueError(esm_protein.error_msg)

esm_protein.ptm = esm_protein.ptm.to('cpu') # Move off GPU
print ("Moving all data off GPU before returning to caller...")
esm_protein.ptm = esm_protein.ptm.to('cpu')
return esm_protein

assets_path = Path(__file__).parent / "assets"

web_app = FastAPI()
@app.function(
image=web_app_image,
concurrency_limit=1,
allow_concurrent_inputs=1000,
volumes={volume_path: volume},
volumes={VOLUME_PATH: volume},
mounts=[modal.Mount.from_local_dir(assets_path, remote_path="/assets")],
)
@modal.asgi_app()
def protein_fold_fastapi_app():
from fastapi import FastAPI
from fastapi.responses import FileResponse

import gradio as gr
from gradio.routes import mount_gradio_app

Expand Down Expand Up @@ -216,7 +263,6 @@ def build_database_html(sequence, maybe_stale_pdb_id):

def remove_nested_quotes(html):
"""Remove nested quotes in HEADER"""

if html.find("HEADER") == -1:
return html

Expand All @@ -225,13 +271,9 @@ def remove_nested_quotes(html):
# Getting the HEADER double quote end by matching: \n","pdb");
j = html[i:].index('"pdb");') + i - len('",')
header_html = html[i:j]
header_html = header_html.replace("\\'", "")
header_html = header_html.replace('\\"', "")
header_html = header_html.replace("'", "")
header_html = header_html.replace('"', "")
print ("Len before:", len(html))
for delete_me in ("\\'", '\\"', "'", '"'):
header_html = header_html.replace(delete_me, "")
html = html[:i] + header_html + html[j:]
print ("Len after:", len(html))
return html

def postprocess_html(html):
Expand Down Expand Up @@ -295,6 +337,8 @@ def run_esm_and_graph(sequence, maybe_stale_pdb_id):
return [postprocess_html(h)
for h in (esm_sse_html, esm_pLDDT_html, database_sse_html)]

web_app = FastAPI()

# custom styles: an icon, a background, and a theme
@web_app.get("/favicon.ico", include_in_schema=False)
async def favicon():
Expand Down

0 comments on commit 5a4e76a

Please sign in to comment.