Skip to content

Commit

Permalink
updata
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Tars committed Dec 27, 2024
1 parent dd595fa commit 385cef5
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
1 change: 0 additions & 1 deletion examples/BuddyLlama/import-llama2.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@

# Retrieve the LLaMA model path from environment variables.
model_path = os.environ.get("LLAMA_MODEL_PATH")
print(model_path)
if model_path is None:
raise EnvironmentError(
"The environment variable 'LLAMA_MODEL_PATH' is not set or is invalid."
Expand Down
26 changes: 26 additions & 0 deletions examples/BuddyStableDiffusion/import-stable-diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,32 @@
driver_unet = GraphDriver(graphs_unet[0])
driver_vae = GraphDriver(graphs_vae[0])

driver_text_encoder._subgraphs[
"subgraph0_text_encoder"
] = driver_text_encoder._subgraphs.pop("subgraph0")
driver_text_encoder._subgraphs_inputs[
"subgraph0_text_encoder"
] = driver_text_encoder._subgraphs_inputs.pop("subgraph0")
driver_text_encoder._subgraphs_outputs[
"subgraph0_text_encoder"
] = driver_text_encoder._subgraphs_outputs.pop("subgraph0")
driver_unet._subgraphs["subgraph0_unet"] = driver_unet._subgraphs.pop(
"subgraph0"
)
driver_unet._subgraphs_inputs[
"subgraph0_unet"
] = driver_unet._subgraphs_inputs.pop("subgraph0")
driver_unet._subgraphs_outputs[
"subgraph0_unet"
] = driver_unet._subgraphs_outputs.pop("subgraph0")
driver_vae._subgraphs["subgraph0_vae"] = driver_vae._subgraphs.pop("subgraph0")
driver_vae._subgraphs_inputs[
"subgraph0_vae"
] = driver_vae._subgraphs_inputs.pop("subgraph0")
driver_vae._subgraphs_outputs[
"subgraph0_vae"
] = driver_vae._subgraphs_outputs.pop("subgraph0")

driver_text_encoder.subgraphs[0]._func_name = "subgraph0_text_encoder"
driver_unet.subgraphs[0]._func_name = "subgraph0_unet"
driver_vae.subgraphs[0]._func_name = "subgraph0_vae"
Expand Down

0 comments on commit 385cef5

Please sign in to comment.