Skip to content

Commit

Permalink
Fix the bug in topological sort
Browse files Browse the repository at this point in the history
  • Loading branch information
Cyrilvallez committed Jan 8, 2025
1 parent bdbb053 commit 6bc67f9
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 49 deletions.
4 changes: 2 additions & 2 deletions examples/modular-transformers/configuration_my_new_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class MyNewModelConfig(PretrainedConfig):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with. MyNewModel 1 supports up to 2048 tokens,
MyNewModel 2 up to 4096, CodeMyNewModel up to 16384.
MyNewModel 2 up to 4096, CodeLlama up to 16384.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
Expand Down Expand Up @@ -110,7 +110,7 @@ class MyNewModelConfig(PretrainedConfig):
mlp_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
head_dim (`int`, *optional*):
The attention head dimension. If None, it will default to hidden_size // num_heads
The attention head dimension. If None, it will default to hidden_size // num_attention_heads
```python
>>> from transformers import MyNewModelModel, MyNewModelConfig
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/diffllama/modeling_diffllama.py
Original file line number Diff line number Diff line change
Expand Up @@ -901,7 +901,7 @@ def _update_causal_mask(
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None

Expand Down
65 changes: 21 additions & 44 deletions utils/create_dependency_mapping.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,30 @@
import ast
from collections import defaultdict


# Function to perform topological sorting
def topological_sort(dependencies):
new_dependencies = {}
graph = defaultdict(list)
def topological_sort(dependencies: dict):
# Nodes are the name of the models to convert (we only add those to the graph)
nodes = {node.rsplit("modular_", 1)[1].replace(".py", "") for node in dependencies.keys()}
# This will be a graph from models to convert, to models to convert that should be converted before (as they are a dependency)
graph = {}
name_mapping = {}
for node, deps in dependencies.items():
node_name = node.rsplit("modular_", 1)[1].replace(".py", "")
for dep in deps:
dep_name = dep.split(".")[-2]
if dep_name == node_name:
# Skip self dependencies for topological sort as they create cycles
continue
if "auto" not in dep and node_name not in graph[dep_name]:
graph[dep_name].append(node_name)
new_dependencies[node_name] = node

# Create a graph and in-degree count for each node
def filter_one_by_one(filtered_list, reverse):
if len(reverse) == 0:
return filtered_list

graph = defaultdict(list)
# Build the graph
for node, deps in reverse.items():
for dep in deps:
graph[dep].append(node)

base_modules = set(reverse.keys()) - set(graph.keys())
if base_modules == reverse.keys():
# we are at the end
return filtered_list + list(graph.keys())
to_add = []
for k in graph.keys():
if len(graph[k]) == 1 and graph[k][0] in base_modules:
if graph[k][0] in reverse:
del reverse[graph[k][0]]
if k not in filtered_list:
to_add += [k]
for k in base_modules:
if k not in filtered_list:
to_add += [k]
filtered_list += list(to_add)
return filter_one_by_one(filtered_list, reverse)

final_order = filter_one_by_one([], graph)

return [new_dependencies.get(k) for k in final_order if k in new_dependencies]
dep_names = {dep.split(".")[-2] for dep in deps}
dependencies = {dep for dep in dep_names if dep in nodes and dep != node_name}
graph[node_name] = dependencies
name_mapping[node_name] = node

sorting_list = []
while len(graph) > 0:
# Find the nodes with 0 out-degree
leaf_nodes = {node for node in graph if len(graph[node]) == 0}
# Add them to the list
sorting_list += list(leaf_nodes)
# Remove the leafs from the graph (and from the deps of other nodes)
graph = {node: deps - leaf_nodes for node, deps in graph.items() if node not in leaf_nodes}

return [name_mapping[x] for x in sorting_list]


# Function to extract class and import info from a file
Expand Down
8 changes: 6 additions & 2 deletions utils/modular_model_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1691,9 +1691,13 @@ def save_modeling_file(modular_file, converted_file):
args = parser.parse_args()
if args.files_to_parse == ["all"]:
args.files_to_parse = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)
args.files_to_parse += glob.glob("examples/**/modular_*.py", recursive=True)
if args.files_to_parse == ["examples"]:
args.files_to_parse = glob.glob("examples/**/modular_*.py", recursive=True)

for file_name in find_priority_list(args.files_to_parse):
priority_list = find_priority_list(args.files_to_parse)
assert len(priority_list) == len(args.files_to_parse), "Some files will not be converted"

for file_name in priority_list:
print(f"Converting {file_name} to a single model single file format")
module_path = file_name.replace("/", ".").replace(".py", "").replace("src.", "")
converted_files = convert_modular_file(file_name)
Expand Down

0 comments on commit 6bc67f9

Please sign in to comment.