Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix modular edge case + modular sorting order #35562

Merged
merged 9 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/modular-transformers/configuration_my_new_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class MyNewModelConfig(PretrainedConfig):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with. MyNewModel 1 supports up to 2048 tokens,
MyNewModel 2 up to 4096, CodeMyNewModel up to 16384.
MyNewModel 2 up to 4096, CodeLlama up to 16384.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
Expand Down Expand Up @@ -110,7 +110,7 @@ class MyNewModelConfig(PretrainedConfig):
mlp_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
head_dim (`int`, *optional*):
The attention head dimension. If None, it will default to hidden_size // num_heads
The attention head dimension. If None, it will default to hidden_size // num_attention_heads
```python
>>> from transformers import MyNewModelModel, MyNewModelConfig
Expand Down
2 changes: 1 addition & 1 deletion examples/modular-transformers/modeling_dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,7 @@ def _update_causal_mask(
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None

Expand Down
2 changes: 1 addition & 1 deletion examples/modular-transformers/modeling_multimodal1.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,7 @@ def _update_causal_mask(
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None

Expand Down
2 changes: 1 addition & 1 deletion examples/modular-transformers/modeling_my_new_model2.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,7 @@ def _update_causal_mask(
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None

Expand Down
2 changes: 1 addition & 1 deletion examples/modular-transformers/modeling_super.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ def _update_causal_mask(
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None

Expand Down
8 changes: 2 additions & 6 deletions src/transformers/models/diffllama/modeling_diffllama.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,11 +612,7 @@ def _init_weights(self, module):


class DiffLlamaRotaryEmbedding(nn.Module):
def __init__(
self,
config: DiffLlamaConfig,
device=None,
):
def __init__(self, config: DiffLlamaConfig, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
Expand Down Expand Up @@ -898,7 +894,7 @@ def _update_causal_mask(
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None

Expand Down
63 changes: 63 additions & 0 deletions tests/repo_utils/modular/test_conversion_order.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import os
import sys
import unittest


ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
sys.path.append(os.path.join(ROOT_DIR, "utils"))

import create_dependency_mapping # noqa: E402


# This is equivalent to `all` in the current library state (as of 09/01/2025)
MODEL_ROOT = os.path.join("src", "transformers", "models")
FILES_TO_PARSE = [
os.path.join(MODEL_ROOT, "starcoder2", "modular_starcoder2.py"),
os.path.join(MODEL_ROOT, "gemma", "modular_gemma.py"),
os.path.join(MODEL_ROOT, "olmo2", "modular_olmo2.py"),
os.path.join(MODEL_ROOT, "diffllama", "modular_diffllama.py"),
os.path.join(MODEL_ROOT, "granite", "modular_granite.py"),
os.path.join(MODEL_ROOT, "gemma2", "modular_gemma2.py"),
os.path.join(MODEL_ROOT, "mixtral", "modular_mixtral.py"),
os.path.join(MODEL_ROOT, "olmo", "modular_olmo.py"),
os.path.join(MODEL_ROOT, "rt_detr", "modular_rt_detr.py"),
os.path.join(MODEL_ROOT, "qwen2", "modular_qwen2.py"),
os.path.join(MODEL_ROOT, "llava_next_video", "modular_llava_next_video.py"),
os.path.join(MODEL_ROOT, "cohere2", "modular_cohere2.py"),
os.path.join(MODEL_ROOT, "modernbert", "modular_modernbert.py"),
os.path.join(MODEL_ROOT, "colpali", "modular_colpali.py"),
os.path.join(MODEL_ROOT, "deformable_detr", "modular_deformable_detr.py"),
os.path.join(MODEL_ROOT, "aria", "modular_aria.py"),
os.path.join(MODEL_ROOT, "ijepa", "modular_ijepa.py"),
os.path.join(MODEL_ROOT, "bamba", "modular_bamba.py"),
os.path.join(MODEL_ROOT, "dinov2_with_registers", "modular_dinov2_with_registers.py"),
os.path.join(MODEL_ROOT, "instructblipvideo", "modular_instructblipvideo.py"),
os.path.join(MODEL_ROOT, "glm", "modular_glm.py"),
os.path.join(MODEL_ROOT, "phi", "modular_phi.py"),
os.path.join(MODEL_ROOT, "mistral", "modular_mistral.py"),
os.path.join(MODEL_ROOT, "phi3", "modular_phi3.py"),
os.path.join(MODEL_ROOT, "cohere", "modular_cohere.py"),
]


def appear_after(model1: str, model2: str, priority_list: list[str]) -> bool:
"""Return True if `model1` appear after `model2` in `priority_list`."""
return priority_list.index(model1) > priority_list.index(model2)


class ConversionOrderTest(unittest.TestCase):
def test_conversion_order(self):
# Find the order
priority_list = create_dependency_mapping.find_priority_list(FILES_TO_PARSE)
# Extract just the model names
model_priority_list = [file.rsplit("modular_")[-1].replace(".py", "") for file in priority_list]

# These are based on what the current library order should be (as of 09/01/2025)
self.assertTrue(appear_after("mixtral", "mistral", model_priority_list))
self.assertTrue(appear_after("gemma2", "gemma", model_priority_list))
self.assertTrue(appear_after("starcoder2", "mistral", model_priority_list))
self.assertTrue(appear_after("olmo2", "olmo", model_priority_list))
self.assertTrue(appear_after("diffllama", "mistral", model_priority_list))
self.assertTrue(appear_after("cohere2", "gemma2", model_priority_list))
self.assertTrue(appear_after("cohere2", "cohere", model_priority_list))
self.assertTrue(appear_after("phi3", "mistral", model_priority_list))
66 changes: 22 additions & 44 deletions utils/create_dependency_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,51 +3,29 @@


# Function to perform topological sorting
def topological_sort(dependencies):
new_dependencies = {}
graph = defaultdict(list)
def topological_sort(dependencies: dict):
# Nodes are the name of the models to convert (we only add those to the graph)
nodes = {node.rsplit("modular_", 1)[1].replace(".py", "") for node in dependencies.keys()}
# This will be a graph from models to convert, to models to convert that should be converted before (as they are a dependency)
graph = {}
name_mapping = {}
for node, deps in dependencies.items():
node_name = node.split("/")[-2]
for dep in deps:
dep_name = dep.split(".")[-2]
if dep_name == node_name:
# Skip self dependencies for topological sort as they create cycles
continue
if "example" not in node and "auto" not in dep and node_name not in graph[dep_name]:
graph[dep_name].append(node_name)
new_dependencies[node_name] = node

# Create a graph and in-degree count for each node
def filter_one_by_one(filtered_list, reverse):
if len(reverse) == 0:
return filtered_list

graph = defaultdict(list)
# Build the graph
for node, deps in reverse.items():
for dep in deps:
graph[dep].append(node)

base_modules = set(reverse.keys()) - set(graph.keys())
if base_modules == reverse.keys():
# we are at the end
return filtered_list + list(graph.keys())
to_add = []
for k in graph.keys():
if len(graph[k]) == 1 and graph[k][0] in base_modules:
if graph[k][0] in reverse:
del reverse[graph[k][0]]
if k not in filtered_list:
to_add += [k]
for k in base_modules:
if k not in filtered_list:
to_add += [k]
filtered_list += list(to_add)
return filter_one_by_one(filtered_list, reverse)

final_order = filter_one_by_one([], graph)

return [new_dependencies.get(k) for k in final_order if k in new_dependencies]
node_name = node.rsplit("modular_", 1)[1].replace(".py", "")
dep_names = {dep.split(".")[-2] for dep in deps}
dependencies = {dep for dep in dep_names if dep in nodes and dep != node_name}
graph[node_name] = dependencies
name_mapping[node_name] = node

sorting_list = []
while len(graph) > 0:
# Find the nodes with 0 out-degree
leaf_nodes = {node for node in graph if len(graph[node]) == 0}
# Add them to the list
sorting_list += list(leaf_nodes)
# Remove the leafs from the graph (and from the deps of other nodes)
graph = {node: deps - leaf_nodes for node, deps in graph.items() if node not in leaf_nodes}

return [name_mapping[x] for x in sorting_list]
Comment on lines +6 to +28
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congrats, much easier to read and understand from an (almost) outsider perspective. maybe naive notes/questions:

  • there can't be dependency cycles, right? because they'd break this, the loop would not end
  • I'd be wary of sets since they don't preserve ordering, we might want it later

Copy link
Member Author

@Cyrilvallez Cyrilvallez Jan 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @molbap, thanks for looking into it! You raise some good points, I should have provided a bit more details:

  1. Indeed in our case we can never have cycles as models are added chronologically, ensuring it (if at time 't', i.e. now, we dont' have cycles, we will never have any by adding more models: adding a model is equivalent to adding a new non existing node with 0 in-degree and as much out-degree as modeling imported). For our past model refactors, we might need to be slightly wary of this; it is our responsibility to ensure consistency here and not add cycles. But I'm extremely confident this case will never appear in practice, as we usually rely on the old "copied from" to know from which model we can refactor, which has the same notion of chronology. Worst case scenario, the person refactoring an old model and adding a cycle by mistake will immediately notice that there is a bug, and will be able to switch the "origin model" for the refactor.
    Thus, the topology of our current (and future) graph (directed acyclic graph) ensures that the algorithm will not get stuck 👌🏻

  2. At each step, the order of the nodes with 0 out-degree does not matter, so it is easier to use sets here as they have nice property and we can easily take the difference with the "-" operator!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. well put, thanks! understood, yes I was thinking people might get caught up in an error during development but it will never reach main. Dag rules
  2. if we're sure order is useless then yes, sets are the best choice indeed :)



# Function to extract class and import info from a file
Expand Down
10 changes: 7 additions & 3 deletions utils/modular_model_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def get_module_source_from_name(module_name: str) -> str:
def preserve_case_replace(text, patterns: dict, default_name: str):
# Create a regex pattern to match all variations
regex_pattern = "|".join(re.escape(key) for key in patterns.keys())
compiled_regex = re.compile(f"({regex_pattern})(.|$)", re.IGNORECASE | re.DOTALL)
compiled_regex = re.compile(f"(?<![a-z0-9])({regex_pattern})(.|$)", re.IGNORECASE | re.DOTALL)

def replace(match):
matched_pattern = match.group(1)
Expand Down Expand Up @@ -1691,9 +1691,13 @@ def save_modeling_file(modular_file, converted_file):
args = parser.parse_args()
if args.files_to_parse == ["all"]:
args.files_to_parse = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)
args.files_to_parse += glob.glob("examples/**/modular_*.py", recursive=True)
if args.files_to_parse == ["examples"]:
args.files_to_parse = glob.glob("examples/**/modular_*.py", recursive=True)

for file_name in find_priority_list(args.files_to_parse):
priority_list = find_priority_list(args.files_to_parse)
assert len(priority_list) == len(args.files_to_parse), "Some files will not be converted"

for file_name in priority_list:
print(f"Converting {file_name} to a single model single file format")
module_path = file_name.replace("/", ".").replace(".py", "").replace("src.", "")
converted_files = convert_modular_file(file_name)
Expand Down
Loading