Skip to content

Commit

Permalink
more refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
shariqm-modal committed Dec 4, 2024
1 parent b0a8e69 commit eec8f6c
Showing 1 changed file with 16 additions and 21 deletions.
37 changes: 16 additions & 21 deletions 06_gpu_and_ml/protein_fold/protein_fold.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,18 +54,16 @@

gpu = "A10G"

# ### Create a Volume to store weights and pdb data
# ### Create a Volume to store cached pdb data

# To minimize cold start times we'll store the model weights on a modal
# [Volume](https://modal.com/docs/guide/volumes) so they can be quickly loaded
# onto the GPU. We'll also store pdb data on the volume to minimize hitting
# the RCSB server for the same data multiple times.
# To minimize hitting the RCSB server for pdb data, we'll cache pdb data on a Modal
# [Volume](https://modal.com/docs/guide/volumes) so we never have to fetch the
# same data twice.


volume = modal.Volume.from_name("example-protein-fold", create_if_missing=True)
VOLUME_PATH = Path("/vol/data")
PDBS_PATH = VOLUME_PATH / "pdbs"
HTMLS_PATH = VOLUME_PATH / "htmls"


# ### Define dependencies in container images
Expand Down Expand Up @@ -144,7 +142,6 @@ class Model:
def setup_model(self):
self.model: ESM3InferenceClient = ESM3.from_pretrained(
"esm3_sm_open_v1"
# How to store on volume FIXME
)

@modal.build()
Expand All @@ -165,8 +162,6 @@ def enter(self):

@modal.method()
def inference(self, sequence: str) -> bool:
start_time = time.monotonic() # TODO Remove

num_steps = min(len(sequence), self.max_steps)
structure_generation_config = GenerationConfig(
track="structure",
Expand All @@ -177,9 +172,6 @@ def inference(self, sequence: str) -> bool:
ESMProtein(sequence=sequence), structure_generation_config
)

latency_s = time.monotonic() - start_time
L.info(f"Inference latency: {latency_s:.2f} seconds.")

L.info("Checking for errors...")
if hasattr(esm_protein, "error_msg"):
raise ValueError(esm_protein.error_msg)
Expand All @@ -189,7 +181,7 @@ def inference(self, sequence: str) -> bool:
return esm_protein


# ### Serving a Gradio UI `web_endpoint` with an `asgi_app`
# ### Serving a Gradio UI with an `asgi_app`

# The `Model` class above is available for use
# from any other Python environment with the right Modal credentials
Expand All @@ -198,7 +190,10 @@ def inference(self, sequence: str) -> bool:
# But we can also expose it via a web app using the `@asgi_app` decorator. Here
# we will specifically create a Gradio web app that allows us to visualize
# the 3D structure output of the ESM3 model as well as any known structures from
# the RCSB Protein Data Bank.
# the RCSB Protein Data Bank. In addition, to understand the confidence level
# of the ESM3 output, we'll include a visualization of the [pLDDT](https://www.ebi.ac.uk/training/online/courses/alphafold/inputs-and-outputs/evaluating-alphafolds-predicted-structures-using-confidence-scores/plddt-understanding-local-confidence/) (predicted Local Distance Difference Test) scores
# for each residue (amino acid) in the folded protein structure.
#

# You should see the URL for this UI in the output of `modal deploy`
# or on your [Modal app dashboard](https://modal.com/apps) for this app.
Expand Down Expand Up @@ -231,7 +226,6 @@ def extract_data(esm_protein):
def run_esm(sequence, maybe_stale_pdb_id):
L.info("Removing whitespace from text input")
sequence = sequence.strip()
maybe_stale_pdb_id = maybe_stale_pdb_id.strip()

L.info("Running ESM")
esm_protein: ESMProtein = Model().inference.remote(sequence)
Expand All @@ -247,11 +241,10 @@ def run_esm(sequence, maybe_stale_pdb_id):
width, height, esm_protein.to_pdb_string(), residue_pLDDTs
)

maybe_stale_pdb_id = maybe_stale_pdb_id.strip()
pdb_sequence = get_sequence(maybe_stale_pdb_id)
if pdb_sequence == sequence:
pdb_id = maybe_stale_pdb_id
# if pdb_id.find("_") != -1: # "1CRN_A" -> "1CRN"
# pdb_id = pdb_id[:pdb_id.index("_")] # WHy was this here.
L.info(f"Constructing HTML for RCSB entry for {pdb_id}")
pdb_string, residue_id_to_sse = extract_pdb_and_residues(pdb_id)
rcsb_sse_html = (
Expand Down Expand Up @@ -386,7 +379,7 @@ def extract_pdb_id(example_idx):
)
htmls.append(gr.HTML())

# Column 3: Database showing Crystalized form of Protein if avail.
# Column 3: Crystalized form of Protein if available.
with gr.Column():
gr.Markdown("## Crystalized Structure from RCSB's PDB")
gr.Image( # output image component
Expand All @@ -413,7 +406,6 @@ def extract_pdb_id(example_idx):
path="/",
)


@app.local_entrypoint()
def main():
sequence = (
Expand All @@ -429,11 +421,15 @@ def main():

# The remainder of this code is boilerplate.


# ### Extracting Sequences and Residues from PDBs

# A fair amount of code is required to extract sequences and residue
# information from pdb strings, so we build several helper functions for it below.

# There is more complex code for running py3Dmol which you can find in
# the `py3DmolWrapper.py` file.


def fetch_pdb_if_necessary(pdb_id, pdb_type):
"""Fetch PDB from RCSB if not already cached."""
Expand Down Expand Up @@ -514,5 +510,4 @@ def postprocess_html(html):
f"""allow="midi; display-capture;" frameborder="0" """
f"""srcdoc='{html_wrapped}'></iframe>"""
)
return iframe_html # FIXME?
# return html_wrapped
return iframe_html

0 comments on commit eec8f6c

Please sign in to comment.