diff --git a/06_gpu_and_ml/protein_fold/protein_fold.py b/06_gpu_and_ml/protein_fold/protein_fold.py index 664ca4a25..ea478ba3c 100644 --- a/06_gpu_and_ml/protein_fold/protein_fold.py +++ b/06_gpu_and_ml/protein_fold/protein_fold.py @@ -2,13 +2,14 @@ # output-directory: "/tmp/protein-fold" # --- -# # Protein fold anything with ESM3, plus some powerful visualizations +# # Protein Fold sequences with ESM3 -# If you haven't heard of AlphaFold or ESM3, you've likely been living under a -# rock, it's time to come out! Protein Folding is a hot topic in the world of -# bioinformatics and machine learning. In this example, we'll show you how to +# If you haven't heard of [AlphaFold](https://deepmind.google/technologies/alphafold/) +# or [ESM3](https://github.com/facebookresearch/esm), you've likely been living under a +# rock, it's time to come out! Protein folding is all the rage in the world of +# bioinformatics and ML. In this example, we'll show you how to # use ESM3 to predict the three-dimensional structure of a protein from its raw -# amino acid sequence. We'll also show you how to visualize the results using +# amino acid sequence. As a bonus, we'll also show you how to visualize the results using # [py3DMol](https://3dmol.org/). # If you're new to protein folding check out the next section for a brief @@ -16,16 +17,25 @@ # ## What is Protein Folding? -# A protein's three-dimensional shape determines its function in the body, from -# catalyzing biochemical reactions to building cellular structures. To develop -# new drugs and treatments for diseases, researchers need to understand how -# potential drugs bind to target proteins in the human body and thus need to predict -# their threre-dimensional structure from their raw amino acid sequence. +# A protein's three-dimensional shape determines its function in the body, +# i.e. everything from catalyzing biochemical reactions to building cellular +# structures. 
To develop new drugs and treatments for diseases, researchers +# need to understand how potential drugs bind to target proteins in the human +# body and thus need to predict their three-dimensional structure from their +# raw amino acid sequence. # Historically determining the protein structure of a single sequence could # take years of wet-lab experiments and millions of dollars, but with the advent -# of deep protein folding models like ESM3, this can be done in seconds for a -# few dollars. +# of deep protein folding models like Meta's ESM3, a quality approximation can be +# done in seconds for a few cents on Modal. + +# The bioinformatics community has created a number of resources for managing +# protein data, including the [Research Collaboratory for Structural Bioinformatics](https://www.rcsb.org/) (RCSB), which stores the known structure of over 200,000 proteins. +# These sets of protein structures are known as Protein Data Banks (PDB). +# We'll also be using the open source [Biotite](https://www.biotite-python.org/) +# library to help us extract sequences and residues from PDBs. Biotite was +# created by Patrick Kunzmann and members of the Protein Bioinformatics lab +# at Ruhr University Bochum in Germany. # ## Basic Setup @@ -44,37 +54,43 @@ gpu = "A10G" -# ### Create a Volume to store weights, data, and HTML +# ### Create a Volume to store weights and pdb data -# On the first launch we'll be downloading the model weights from Hugging -# Face(?), to speed up future cold starts we'll store the weights on a -# [Volume](https://modal.com/docs/guide/volumes) for quick reads. We'll also -# use this volume to cache PDB data from the web and store HTML files for -# visualization. -volume = modal.Volume.from_name( - "example-protein-fold", create_if_missing=True -) +# To minimize cold start times we'll store the model weights on a Modal +# [Volume](https://modal.com/docs/guide/volumes) so they can be quickly loaded +# onto the GPU. 
We'll also store pdb data on the volume to minimize hitting +# the RCSB server for the same data multiple times. + + +volume = modal.Volume.from_name("example-protein-fold", create_if_missing=True) VOLUME_PATH = Path("/vol/data") PDBS_PATH = VOLUME_PATH / "pdbs" HTMLS_PATH = VOLUME_PATH / "htmls" + # ### Define dependencies in container images # The container image for inference is based on Modal's default slim Debian -# Linux image with `esm` for loading and running or model -# for the +# Linux image with `esm` for loading and running the model, and `hf_transfer` +# for a higher bandwidth download of the model weights from hugging face. + + esm3_image = ( - modal.Image.debian_slim(python_version="3.12") - .pip_install( + modal.Image.debian_slim(python_version="3.12").pip_install( "esm==3.0.5", - "torch==2.4.1", # Needed? FIXME - ) + "torch==2.4.1", + "huggingface_hub[hf_transfer]==0.26.2", + ).env({"HF_HUB_ENABLE_HF_TRANSFER": "1"}) ) -# We'll also define a web app image for extracting PDB data. FIXME + +# We'll also define a web app image as a frontend for the entire system that +# includes `gradio` for building a UI, `biotite` for extracting sequences and +# residues from PDBs, and `py3Dmol` for visualizing the 3D structures. + + web_app_image = ( - modal.Image.debian_slim(python_version="3.11") - .pip_install( + modal.Image.debian_slim(python_version="3.12").pip_install( "esm==3.0.5", "gradio~=4.44.0", "biotite==0.41.2", @@ -86,27 +102,37 @@ ) +# Here we "pre-import" libraries that will be used by the functions we run +# on Modal in a given image using the `with image.imports` context manager. + -# We can also "pre-import" libraries that will be used by the functions we run on Modal in a given image -# using the `with image.imports` context manager. 
with esm3_image.imports(): + from esm.models.esm3 import ESM3 from esm.sdk.api import ESM3InferenceClient, ESMProtein, GenerationConfig import torch + with web_app_image.imports(): from py3DmolWrapper import py3DMolViewWrapper + import os + + import biotite.database.rcsb as rcsb + import biotite.structure as b_structure + import biotite.structure.io.pdbx as pdbx + from biotite.structure.io.pdbx import CIFFile + from esm.sdk.api import ESMProtein + # ## Defining a `Model` inference class for ESM3 # Next, we map the model's setup and inference code onto Modal. -# 1. We run any setup that can be persisted to disk in methods decorated with `@build`. -# In this example, that includes downloading the model weights. -# 2. We run any additional setup, like moving the model to the GPU, in methods decorated with `@enter`. -# We do our model optimizations in this step. For details, see the section on `torch.compile` below. -# 3. We run the actual inference in methods decorated with `@method`. +# 1. To ensure the cached image includes the model weights, we download them in a +# method decorated with `@build`; this reduces cold boot times. +# 2. We put any additional setup code that involves CPU or GPU RAM in a +# method decorated with `@enter`, which runs on container start. +# 3. 
To run the actual inference, we put it in a method decorated with `@method` -# Here we define the @app.cls( image=esm3_image, @@ -114,45 +140,69 @@ gpu=gpu, timeout=20 * MINUTES, ) -class ModelInference: +class Model: + def setup_model(self): + self.model: ESM3InferenceClient = ESM3.from_pretrained( + "esm3_sm_open_v1" + # How to store on volume FIXME + ) + @modal.build() - @modal.enter() def build(self): - from esm.models.esm3 import ESM3 - self.model: ESM3InferenceClient = ESM3.from_pretrained("esm3_sm_open_v1").to("cuda") - # XXX Remove: - # https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/protein_folding.ipynb#scrollTo=cc0f0186 + self.setup_model() + + @modal.enter() + def enter(self): + self.setup_model() + self.model.to("cuda") + + # Enable half precision for faster inference self.model = self.model.half() torch.backends.cuda.matmul.allow_tf32 = True self.max_steps = 250 + L.info(f"Setting max ESM steps to: {self.max_steps}") @modal.method() - def generate(self, sequence: str) -> bool: - start_time = time.monotonic() # TODO Remove + def inference(self, sequence: str) -> bool: + start_time = time.monotonic() # TODO Remove num_steps = min(len(sequence), self.max_steps) structure_generation_config = GenerationConfig( track="structure", num_steps=num_steps, ) + L.info("Running ESM3 inference with num_steps={num_steps}") esm_protein = self.model.generate( - ESMProtein(sequence=sequence), - structure_generation_config + ESMProtein(sequence=sequence), structure_generation_config ) latency_s = time.monotonic() - start_time - print (f"Inference latency: {latency_s:.2f} seconds.") + L.info(f"Inference latency: {latency_s:.2f} seconds.") - # Check that esm_protein did not error. 
- print ("Checking for errors...") + L.info("Checking for errors...") if hasattr(esm_protein, "error_msg"): raise ValueError(esm_protein.error_msg) - print ("Moving all data off GPU before returning to caller...") - esm_protein.ptm = esm_protein.ptm.to('cpu') + L.info("Moving all data off GPU before returning to caller...") + esm_protein.ptm = esm_protein.ptm.to("cpu") return esm_protein + +# ### Serving a Gradio UI `web_endpoint` with an `asgi_app` + +# The `Model` class above is available for use +# from any other Python environment with the right Modal credentials +# and the `modal` package installed -- just use [`lookup`](https://modal.com/docs/reference/modal.Cls#lookup). + +# But we can also expose it via a web app using the `@asgi_app` decorator. Here +# we will specifically create a Gradio web app that allows us to visualize +# the 3D structure output of the ESM3 model as well as any known structures from +# the RCSB Protein Data Bank. + +# You should see the URL for this UI in the output of `modal deploy` +# or on your [Modal app dashboard](https://modal.com/apps) for this app. + assets_path = Path(__file__).parent / "assets" @app.function( @@ -170,172 +220,53 @@ def protein_fold_fastapi_app(): import gradio as gr from gradio.routes import mount_gradio_app - import biotite.database.rcsb as rcsb - import biotite.structure as b_structure - import biotite.structure.io.pdbx as pdbx - from biotite.structure.io.pdbx import CIFFile - import biotite.application.blast as blast - from esm.sdk.api import ESMProtein - - import os - width, height = 400, 400 - def fetch_pdb_if_necessary(pdb_id, pdb_type): - # Only fetch from server if it's not on the volume already. 
- assert pdb_type in ("pdb", "pdbx") - file_path = PDBS_PATH / f"{pdb_id}.{pdb_type}" - if not os.path.exists(file_path): - print (f"Loading PDB {file_path} from server...") - rcsb.fetch(pdb_id, pdb_type, str(PDBS_PATH)) - return file_path - - def get_sequence(pdb_id): - try: - pdb_id = pdb_id.strip() # Remove whitespace - pdbx_file_path = fetch_pdb_if_necessary(pdb_id, "pdbx") - - structure = pdbx.get_structure(CIFFile.read(pdbx_file_path), model=1) - - # TODO Excluding atoms that are not in a Amino Acid? - (amino_sequences, _) = ( - b_structure.to_sequence( - structure[b_structure.filter_amino_acids(structure)])) - # Note: len(amino_sequences) == # of chains - - # TODO Should I put something between chains? - sequence = "".join([str(s) for s in amino_sequences]) - return sequence - - except Exception as e: - return f"Error: {e}" - - def find_pdb_id_from_sequence(query_sequence, maybe_stale_pdb_id): - pdb_id_sequence = get_sequence(maybe_stale_pdb_id) - - if pdb_id_sequence == query_sequence: - print ("Skipped BLAST run, PDB ID already known.") - not_stale_pdb_id = maybe_stale_pdb_id - return not_stale_pdb_id - else: - # TODO Use Blast or some kind of Sequence Database - # Skipping Blast which is rate limited for now. - return None - - print ("Running blast to find PDB ID of sequence...") - blast_app = blast.BlastWebApp( - "blastp", query_sequence, database="pdb") - blast_app.start() - blast_app.join() - - alignments = blast_app.get_alignments() - if len(alignments) == 0: - return None - return alignments[0].hit_id - - def build_database_html(sequence, maybe_stale_pdb_id): - pdb_id = find_pdb_id_from_sequence(sequence, maybe_stale_pdb_id) - if pdb_id is None: - return "

Folding Structure of Sequence not found.

" - - # Remove chain information if present. - if pdb_id.find("_") != -1: # "1CRN_A" -> "1CRN" - pdb_id = pdb_id[:pdb_id.index("_")] - - # Extract secondary structure from PDBX file. - pdbx_file_path = fetch_pdb_if_necessary(pdb_id, "pdbx") - structure = pdbx.get_structure(CIFFile.read(pdbx_file_path), model=1) - # TODO Should I only use AA atoms? - atoms = structure[ - b_structure.filter_amino_acids(structure)] - residue_secondary_structures = b_structure.annotate_sse(atoms) - residue_ids = b_structure.get_residues(atoms)[0] - residue_id_to_sse = {} - for (residue_id, sse) in zip(residue_ids, residue_secondary_structures): - residue_id_to_sse[int(residue_id)] = sse - - # Extract PDB string from PDB file. - pdb_file_path = fetch_pdb_if_necessary(pdb_id, "pdb") - pdb_string = Path(pdb_file_path).read_text() - - return py3DMolViewWrapper().build_html_with_secondary_structure( - width, height, pdb_string, residue_id_to_sse) - - def remove_nested_quotes(html): - """Remove nested quotes in HEADER""" - if html.find("HEADER") == -1: - return html - - # Need this because srcdoc adds another quote and triple is tricky. - i = html.index("HEADER") - # Getting the HEADER double quote end by matching: \n","pdb"); - j = html[i:].index('"pdb");') + i - len('",') - header_html = html[i:j] - for delete_me in ("\\'", '\\"', "'", '"'): - header_html = header_html.replace(delete_me, "") - html = html[:i] + header_html + html[j:] - return html - - def postprocess_html(html): - html = remove_nested_quotes(html) - html = html.replace("'", '"') - - html_wrapped = f"""{html}""" - iframe_html = (f"""""") - return iframe_html - # return html_wrapped - - def run_esm_and_graph_debug(sequence, maybe_stale_pdb_id): - # Data cleaning - sequence = sequence.strip() # Remove whitespace - maybe_stale_pdb_id = maybe_stale_pdb_id.strip() # Remove whitespace - - database_sse_html = build_database_html(sequence, maybe_stale_pdb_id) - - # Save HTMLs to volume for debugging issues. 
- HTMLS_PATH.mkdir(parents=True, exist_ok=True) - with open(str(HTMLS_PATH / "db.html"), "w") as f: - f.write(postprocess_html(database_sse_html)) - - return [postprocess_html(h) - # for h in (esm_sse_html, esm_pLDDT_html, database_sse_html)] - for h in [database_sse_html] * 3] - - def run_esm_and_graph(sequence, maybe_stale_pdb_id): - # Data cleaning - sequence = sequence.strip() # Remove whitespace - maybe_stale_pdb_id = maybe_stale_pdb_id.strip() # Remove whitespace - - # Run ESM - esm_protein: ESMProtein = ModelInference().generate.remote(sequence) + def extract_data(esm_protein): residue_pLDDTs = (100 * esm_protein.plddt).tolist() atoms = esm_protein.to_protein_chain().atom_array - residue_secondary_structures = b_structure.annotate_sse(atoms) - residue_ids = b_structure.get_residues(atoms)[0] - residue_id_to_sse = {} - for (residue_id, sse) in zip(residue_ids, residue_secondary_structures): - residue_id_to_sse[int(residue_id)] = sse + residue_id_to_sse = extract_residues(atoms) + return residue_pLDDTs, residue_id_to_sse - esm_sse_html = ( - py3DMolViewWrapper().build_html_with_secondary_structure( - width, height, esm_protein.to_pdb_string(), - residue_id_to_sse)) + def run_esm(sequence, maybe_stale_pdb_id): + L.info("Removing whitespace from text input") + sequence = sequence.strip() + maybe_stale_pdb_id = maybe_stale_pdb_id.strip() - esm_pLDDT_html = ( - py3DMolViewWrapper().build_html_with_pLDDTs( - width, height, esm_protein.to_pdb_string(), residue_pLDDTs)) + L.info("Running ESM") + esm_protein: ESMProtein = Model().inference.remote(sequence) + residue_pLDDTs, residue_id_to_sse = extract_data(esm_protein) - database_sse_html = build_database_html(sequence, maybe_stale_pdb_id) + L.info("Constructing HTML for ESM3 prediction with residues.") + esm_sse_html = py3DMolViewWrapper().build_html_with_secondary_structure( + width, height, esm_protein.to_pdb_string(), residue_id_to_sse + ) - # Save HTMLs to volume for debugging issues. 
- HTMLS_PATH.mkdir(parents=True, exist_ok=True) - with open(str(HTMLS_PATH / "db.html"), "w") as f: - f.write(postprocess_html(database_sse_html)) + L.info("Constructing HTML for ESM3 prediction with confidence.") + esm_pLDDT_html = py3DMolViewWrapper().build_html_with_pLDDTs( + width, height, esm_protein.to_pdb_string(), residue_pLDDTs + ) - return [postprocess_html(h) - for h in (esm_sse_html, esm_pLDDT_html, database_sse_html)] + pdb_sequence = get_sequence(maybe_stale_pdb_id) + if pdb_sequence == sequence: + pdb_id = maybe_stale_pdb_id + # if pdb_id.find("_") != -1: # "1CRN_A" -> "1CRN" + # pdb_id = pdb_id[:pdb_id.index("_")] # WHy was this here. + L.info(f"Constructing HTML for RCSB entry for {pdb_id}") + pdb_string, residue_id_to_sse = extract_pdb_and_residues(pdb_id) + rcsb_sse_html = ( + py3DMolViewWrapper().build_html_with_secondary_structure( + width, height, pdb_string, residue_id_to_sse + ) + ) + else: + L.info("Sequence structure is unknown, generating HTML as such.") + rcsb_sse_html = "

Folding Structure of Sequence not found.

" + + return [ + postprocess_html(h) + for h in (esm_sse_html, esm_pLDDT_html, rcsb_sse_html) + ] web_app = FastAPI() @@ -365,23 +296,25 @@ async def background(): "Actin [1ATN]", "Ribonuclease A [5RSA]", ] + + # Number the examples. example_pdbs = [f"{i+1}) {x}" for i, x in enumerate(example_pdbs)] with gr.Blocks( theme=theme, css=css, title="ESM3 Protein Folding" ) as interface: - # Title gr.Markdown("# Fold Proteins using ESM3 Fold") with gr.Row(): - # PDB ID text box + Get sequence from it button with gr.Column(): gr.Markdown("## Custom PDB ID") pdb_id_box = gr.Textbox( label="Enter PDB ID or select one on the right", - placeholder="e.g. '5JQ4', '1MBO', '1TPO', etc.") - get_sequence_button = ( - gr.Button("Retrieve Sequence from PDB ID", variant="primary")) + placeholder="e.g. '5JQ4', '1MBO', '1TPO', etc.", + ) + get_sequence_button = gr.Button( + "Retrieve Sequence from PDB ID", variant="primary" + ) pdb_link_button = gr.Button(value="Open PDB page for ID") rcsb_link = "https://www.rcsb.org/structure/" @@ -389,50 +322,53 @@ async def background(): fn=None, inputs=pdb_id_box, js=f"""(pdb_id) => {{ window.open("{rcsb_link}" + pdb_id) }}""", - ) + ) with gr.Column(): - def extract_pdb_id(idx): - pdb = example_pdbs[idx] - s = pdb.index("[") + 1 - e = pdb.index("]") - return pdb[s:e] - h = int(len(example_pdbs) / 2) + def extract_pdb_id(example_idx): + pdb = example_pdbs[example_idx] + return pdb[pdb.index("[") + 1 : pdb.index("]")] + + half_len = int(len(example_pdbs) / 2) gr.Markdown("## Example PDB IDs") with gr.Row(): with gr.Column(): - # add in a few examples to inspire users - for i, pdb in enumerate(example_pdbs[:h]): + for i, pdb in enumerate(example_pdbs[:half_len]): btn = gr.Button(pdb, variant="secondary") btn.click( - fn=lambda j=i: extract_pdb_id(j), outputs=pdb_id_box) + fn=lambda j=i: extract_pdb_id(j), + outputs=pdb_id_box, + ) with gr.Column(): - # more examples - for i, pdb in enumerate(example_pdbs[h:]): + for i, pdb in 
enumerate(example_pdbs[half_len:]): btn = gr.Button(pdb, variant="secondary") btn.click( - fn=lambda j=i+4: extract_pdb_id(j), outputs=pdb_id_box) + fn=lambda j=i + half_len: extract_pdb_id(j), + outputs=pdb_id_box, + ) - # Sequence text box + Run ESM from it button gr.Markdown("## Sequence") sequence_box = gr.Textbox( label="Enter a sequence or retrieve it from a PDB ID", - placeholder="e.g. 'MVTRLE...', 'GKQEG...', etc.") - run_esm_and_graph_button = ( - gr.Button("Run ESM3 Fold on Sequence", variant="primary")) + placeholder="e.g. 'MVTRLE...', 'GKQEG...', etc.", + ) + run_esm_button = gr.Button( + "Run ESM3 Fold on Sequence", variant="primary" + ) - # 3 Columns of Protein Folding htmls = [] + legend_height = 100 with gr.Row(): - # Column 1: ESM Prediction with SSE coloring. with gr.Column(): gr.Markdown("## ESM3 Prediction - Secondary Structs") gr.Image( # output image component - height=100, width=400, + height=legend_height, + width=width, value="/assets/secondaryStructureLegend.png", - show_download_button=False, show_label=False, + show_download_button=False, + show_label=False, show_fullscreen_button=False, ) htmls.append(gr.HTML()) @@ -441,27 +377,34 @@ def extract_pdb_id(idx): with gr.Column(): gr.Markdown("## ESM3 Prediction - PLTT Confidence") gr.Image( # output image component - height=100, width=400, value="/assets/plddtLegend2.png", - show_download_button=False, show_label=False, + height=legend_height, + width=width, + value="/assets/plddtLegend2.png", + show_download_button=False, + show_label=False, show_fullscreen_button=False, ) htmls.append(gr.HTML()) # Column 3: Database showing Crystalized form of Protein if avail. 
with gr.Column(): - gr.Markdown("## Crystalized Structure from PDB") + gr.Markdown("## Crystalized Structure from RCSB's PDB") gr.Image( # output image component - height=100, width=400, + height=legend_height, + width=width, value="/assets/secondaryStructureLegend.png", - show_download_button=False, show_label=False, + show_download_button=False, + show_label=False, show_fullscreen_button=False, ) htmls.append(gr.HTML()) get_sequence_button.click( - fn=get_sequence, inputs=[pdb_id_box], outputs=[sequence_box]) - run_esm_and_graph_button.click(fn=run_esm_and_graph, - inputs=[sequence_box, pdb_id_box], outputs=htmls) + fn=get_sequence, inputs=[pdb_id_box], outputs=[sequence_box] + ) + run_esm_button.click( + fn=run_esm, inputs=[sequence_box, pdb_id_box], outputs=htmls + ) # mount for execution on Modal return mount_gradio_app( @@ -470,18 +413,106 @@ def extract_pdb_id(idx): path="/", ) + @app.local_entrypoint() def main(): - pdb_id = "1ZNI" - # pdb_id = "1MBO" - test = retrieve_pdb.remote(pdb_id) - - # # YFP? - # sequence = ( - # "AMFSKVNNQKMLEDCFYIRKKVFVEEQGIPEESEIDEYESESIHLIGYDNGQPVATARIRPINETTVKIERVAVMKSHRGQGMGRMLMQAVESLAKDEGFYVATMNAQCHAIPFYESLNFKMRGNIFLEEGIEHIEMTKKLT") - # for i in range(10): - # ModelInference().generate.remote(sequence) - # x = 1 - -if __name__ == "__main__": - main() + sequence = ( + "AMFSKVNNQKMLEDCFYIRKKVFVEEQGIPEESEIDEYESESIHLIGYDNGQPVATARIRPINETTVKI" + "ERVAVMKSHRGQGMGRMLMQAVESLAKDEGFYVATMNAQCHAIPFYESLNFKMRGNIFLEEGIEHIEMT" + "KKLT" + ) + esm_protein = Model().inference.remote(sequence) + L.info(esm_protein) + + +# ## Addenda + +# The remainder of this code is boilerplate. + +# ### Extracting Sequences and Residues from PDBs + +# A fair amount of code is required to extract sequences and residue +# information from pdb strings, we build several helper functions for it below. 
+ + +def fetch_pdb_if_necessary(pdb_id, pdb_type): + """Fetch PDB from RCSB if not already cached.""" + + assert pdb_type in ("pdb", "pdbx") + + file_path = PDBS_PATH / f"{pdb_id}.{pdb_type}" + if not os.path.exists(file_path): + L.info(f"Loading PDB {file_path} from server...") + rcsb.fetch(pdb_id, pdb_type, str(PDBS_PATH)) + return file_path + + +def get_sequence(pdb_id): + try: + pdb_id = pdb_id.strip() + pdbx_file_path = fetch_pdb_if_necessary(pdb_id, "pdbx") + + structure = pdbx.get_structure(CIFFile.read(pdbx_file_path), model=1) + + amino_sequences, _ = b_structure.to_sequence( + structure[b_structure.filter_amino_acids(structure)] + ) + + sequence = "".join([str(s) for s in amino_sequences]) + return sequence + + except Exception as e: + return f"Error: {e}" + + +def extract_residues(atoms): + residue_secondary_structures = b_structure.annotate_sse(atoms) + residue_ids = b_structure.get_residues(atoms)[0] + residue_id_to_sse = {} + for residue_id, sse in zip(residue_ids, residue_secondary_structures): + residue_id_to_sse[int(residue_id)] = sse + return residue_id_to_sse + + +def extract_pdb_and_residues(pdb_id): + pdb_file_path = fetch_pdb_if_necessary(pdb_id, "pdb") + pdb_string = Path(pdb_file_path).read_text() + + pdbx_file_path = fetch_pdb_if_necessary(pdb_id, "pdbx") + structure = pdbx.get_structure(CIFFile.read(pdbx_file_path), model=1) + atoms = structure[b_structure.filter_amino_acids(structure)] + residue_id_to_sse = extract_residues(atoms) + + return pdb_string, residue_id_to_sse + +# ### Miscellaneous +# The remaining code includes small helper functions for cleaning up HTML +# generated by py3DMol for viewing on a webpage. 
+ +def remove_nested_quotes(html): + """Remove triple nested quotes in HEADER""" + if html.find("HEADER") == -1: + return html + + i = html.index("HEADER") + j = html[i:].index('"pdb");') + i - len('",') + header_html = html[i:j] + for delete_me in ("\\'", '\\"', "'", '"'): + header_html = header_html.replace(delete_me, "") + html = html[:i] + header_html + html[j:] + return html + + +def postprocess_html(html): + html = remove_nested_quotes(html) + html = html.replace("'", '"') + + L.info("Wrapping py3DMol HTML in iframe so we can view multiple HTMLs.") + html_wrapped = f"""{html}""" + iframe_html = ( + f"""""" + ) + return iframe_html # FIXME? + # return html_wrapped diff --git a/06_gpu_and_ml/protein_fold/py3DmolWrapper.py b/06_gpu_and_ml/protein_fold/py3DmolWrapper.py index 91c59c67e..30512e542 100644 --- a/06_gpu_and_ml/protein_fold/py3DmolWrapper.py +++ b/06_gpu_and_ml/protein_fold/py3DmolWrapper.py @@ -3,10 +3,12 @@ # For PDB file formatting and conventions refer to: # https://www.biostat.jhsph.edu/~iruczins/teaching/260.655/links/pdbformat.pdf -import py3Dmol + +from dataclasses import dataclass import logging as L import numpy as np -from dataclasses import dataclass +import py3Dmol + @dataclass class pLDDTBands: @@ -15,7 +17,7 @@ class pLDDTBands: name: str color: str -class py3DMolViewWrapper(object): +class py3DMolViewWrapper: def __init__(self): # Color & Ranges AlphaFold colab self.pLDDT_bands = [ @@ -34,7 +36,7 @@ def __init__(self): 'a' : magenta_hex, # Alpha Helix 'b': gold_hex, # Beta Sheet 'c': white_hex, # Coil - '': black_hex} # Loop? 
+ '': black_hex} # Loop ################################# ### Secondary Structures Plot ### @@ -64,13 +66,8 @@ def build_html_with_secondary_structure(self, view = py3Dmol.view(width=width, height=height) view.addModel(pdb_string, "pdb") - lowest_residue_id = self.setup_secondary_structures_plot( - pdb_string, residue_id_to_sse) - color_map = {rid : self.secondary_structure_to_color[sse] for rid, sse in residue_id_to_sse.items()} - print (lowest_residue_id) - print (color_map) view.setStyle( {"cartoon": @@ -78,7 +75,6 @@ def build_html_with_secondary_structure(self, {"prop": "resi", # refers to residual index in pdb_string "map": color_map}}}) view.zoomTo() - # TODO: pdb_string sometimes creates invalid HTML (missing ')') return view._make_html() ############################### @@ -134,7 +130,6 @@ def setup_pLDDTs_plot( assert num_residues == len(residue_pLDDTs), ( f"{num_residues} != {len(residue_pLDDTs)}") - new_pdb_string = '\n'.join(new_lines) return new_pdb_string @@ -155,5 +150,4 @@ def build_html_with_pLDDTs(self, width, height, pdb_string, residue_pLDDTs): 'map': color_map}}}) view.zoomTo() - # TODO: pdb_string sometimes creates invalid HTML (missing ')') return view._make_html()