diff --git a/06_gpu_and_ml/protein_fold/protein_fold.py b/06_gpu_and_ml/protein_fold/protein_fold.py index ea478ba3c..02b4a366b 100644 --- a/06_gpu_and_ml/protein_fold/protein_fold.py +++ b/06_gpu_and_ml/protein_fold/protein_fold.py @@ -54,18 +54,16 @@ gpu = "A10G" -# ### Create a Volume to store weights and pdb data +# ### Create a Volume to store cached pdb data -# To minimize cold start times we'll store the model weights on a modal -# [Volume](https://modal.com/docs/guide/volumes) so they can be quickly loaded -# onto the GPU. We'll also store pdb data on the volume to minimize hitting -# the RCSB server for the same data multiple times. +# To minimize htting the RCSB server for pdb data we'll cache pdb data on a modal +# [Volume](https://modal.com/docs/guide/volumes) so we never have to read the +# same data twice. volume = modal.Volume.from_name("example-protein-fold", create_if_missing=True) VOLUME_PATH = Path("/vol/data") PDBS_PATH = VOLUME_PATH / "pdbs" -HTMLS_PATH = VOLUME_PATH / "htmls" # ### Define dependencies in container images @@ -144,7 +142,6 @@ class Model: def setup_model(self): self.model: ESM3InferenceClient = ESM3.from_pretrained( "esm3_sm_open_v1" - # How to store on volume FIXME ) @modal.build() @@ -165,8 +162,6 @@ def enter(self): @modal.method() def inference(self, sequence: str) -> bool: - start_time = time.monotonic() # TODO Remove - num_steps = min(len(sequence), self.max_steps) structure_generation_config = GenerationConfig( track="structure", @@ -177,9 +172,6 @@ def inference(self, sequence: str) -> bool: ESMProtein(sequence=sequence), structure_generation_config ) - latency_s = time.monotonic() - start_time - L.info(f"Inference latency: {latency_s:.2f} seconds.") - L.info("Checking for errors...") if hasattr(esm_protein, "error_msg"): raise ValueError(esm_protein.error_msg) @@ -189,7 +181,7 @@ def inference(self, sequence: str) -> bool: return esm_protein -# ### Serving a Gradio UI `web_endpoint` with an `asgi_app` +# ### Serving a Gradio UI with an `asgi_app` # The `ModelInference` class above is available for use # from any other Python environment with the right Modal credentials @@ -198,7 +190,10 @@ def inference(self, sequence: str) -> bool: # But we can also expose it via a web app using the `@asgi_app` decorator. Here # we will specifically create a Gradio web app that allows us to visualize # the 3D structure output of the ESM3 model as well as any known structures from -# the RCSB Protein Data Bank. +# the RCSB Protein Data Bank. In addition, to understand the confidence level +# of the ESM3 output, we'll include a visualization of the [pLDDT](https://www.ebi.ac.uk/training/online/courses/alphafold/inputs-and-outputs/evaluating-alphafolds-predicted-structures-using-confidence-scores/plddt-understanding-local-confidence/) (predicted Local Distance Difference Test) scores +# for each residue (amino acid) in the folded protein structure. +# # You should see the URL for this UI in the output of `modal deploy` # or on your [Modal app dashboard](https://modal.com/apps) for this app. @@ -231,7 +226,6 @@ def extract_data(esm_protein): def run_esm(sequence, maybe_stale_pdb_id): L.info("Removing whitespace from text input") sequence = sequence.strip() - maybe_stale_pdb_id = maybe_stale_pdb_id.strip() L.info("Running ESM") esm_protein: ESMProtein = Model().inference.remote(sequence) @@ -247,11 +241,10 @@ def run_esm(sequence, maybe_stale_pdb_id): width, height, esm_protein.to_pdb_string(), residue_pLDDTs ) + maybe_stale_pdb_id = maybe_stale_pdb_id.strip() pdb_sequence = get_sequence(maybe_stale_pdb_id) if pdb_sequence == sequence: pdb_id = maybe_stale_pdb_id - # if pdb_id.find("_") != -1: # "1CRN_A" -> "1CRN" - # pdb_id = pdb_id[:pdb_id.index("_")] # WHy was this here. L.info(f"Constructing HTML for RCSB entry for {pdb_id}") pdb_string, residue_id_to_sse = extract_pdb_and_residues(pdb_id) rcsb_sse_html = ( @@ -386,7 +379,7 @@ def extract_pdb_id(example_idx): ) htmls.append(gr.HTML()) - # Column 3: Database showing Crystalized form of Protein if avail. + # Column 3: Crystalized form of Protein if available. with gr.Column(): gr.Markdown("## Crystalized Structure from RCSB's PDB") gr.Image( # output image component @@ -413,7 +406,6 @@ def extract_pdb_id(example_idx): path="/", ) - @app.local_entrypoint() def main(): sequence = ( @@ -429,11 +421,15 @@ def main(): # The remainder of this code is boilerplate. + # ### Extracting Sequences and Residues from PDBs # A fair amount of code is required to extract sequences and residue # information from pdb strings, we build several helper functions for it below. +# There is more complex code for running py3Dmol which you can find in +# the `py3DmolWrapper.py` file. + def fetch_pdb_if_necessary(pdb_id, pdb_type): """Fetch PDB from RCSB if not already cached.""" @@ -514,5 +510,4 @@ def postprocess_html(html): f"""allow="midi; display-capture;" frameborder="0" """ f"""srcdoc='{html_wrapped}'>""" ) - return iframe_html # FIXME? - # return html_wrapped + return iframe_html