From 9e000eda2716aa627e562d0ba6a8fb6113fb22cf Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Wed, 8 Jan 2025 12:06:51 +0100
Subject: [PATCH 1/9] look-ahead negation

---
 utils/modular_model_converter.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py
index 2f7512639f9..7f0677de119 100644
--- a/utils/modular_model_converter.py
+++ b/utils/modular_model_converter.py
@@ -58,7 +58,7 @@ def get_module_source_from_name(module_name: str) -> str:
 def preserve_case_replace(text, patterns: dict, default_name: str):
     # Create a regex pattern to match all variations
     regex_pattern = "|".join(re.escape(key) for key in patterns.keys())
-    compiled_regex = re.compile(f"({regex_pattern})(.|$)", re.IGNORECASE | re.DOTALL)
+    compiled_regex = re.compile(f"(?<![a-z0-9])({regex_pattern})(.|$)", re.IGNORECASE | re.DOTALL)
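A quick sketch of what the added negative look-behind changes (the look-behind fragment mirrors the `+` line above; the sample vocabulary is invented for illustration):

```python
import re

patterns = {"llama": "my_new_model", "Llama": "MyNewModel"}
regex_pattern = "|".join(re.escape(key) for key in patterns)

old = re.compile(f"({regex_pattern})(.|$)", re.IGNORECASE | re.DOTALL)
new = re.compile(f"(?<![a-z0-9])({regex_pattern})(.|$)", re.IGNORECASE | re.DOTALL)

text = "CodeLlama is based on Llama"
# The old pattern also fires on the "Llama" inside "CodeLlama"; with
# re.IGNORECASE the look-behind rejects any alphanumeric predecessor.
print([m.group(1) for m in old.finditer(text)])  # ['Llama', 'Llama']
print([m.group(1) for m in new.finditer(text)])  # ['Llama']
```

That over-eager mid-identifier match is what produced names like `CodeMyNewModel` in the regenerated example docstrings, which patch 3 below reverts back to `CodeLlama`.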
From: Cyril Vallez
Date: Wed, 8 Jan 2025 13:01:00 +0100
Subject: [PATCH 2/9] re add examples by default

---
 examples/modular-transformers/modeling_dummy.py         | 2 +-
 examples/modular-transformers/modeling_multimodal1.py   | 2 +-
 examples/modular-transformers/modeling_my_new_model2.py | 2 +-
 examples/modular-transformers/modeling_super.py         | 2 +-
 utils/create_dependency_mapping.py                      | 4 ++--
 5 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py
index 3e0aa6e9b2a..382b87bd384 100644
--- a/examples/modular-transformers/modeling_dummy.py
+++ b/examples/modular-transformers/modeling_dummy.py
@@ -597,7 +597,7 @@ def _update_causal_mask(
         output_attentions: bool,
     ):
         if self.config._attn_implementation == "flash_attention_2":
-            if attention_mask is not None and 0.0 in attention_mask:
+            if attention_mask is not None and (attention_mask == 0.0).any():
                 return attention_mask
             return None

diff --git a/examples/modular-transformers/modeling_multimodal1.py b/examples/modular-transformers/modeling_multimodal1.py
index c4f90a5cbad..df23a83b341 100644
--- a/examples/modular-transformers/modeling_multimodal1.py
+++ b/examples/modular-transformers/modeling_multimodal1.py
@@ -597,7 +597,7 @@ def _update_causal_mask(
         output_attentions: bool,
     ):
         if self.config._attn_implementation == "flash_attention_2":
-            if attention_mask is not None and 0.0 in attention_mask:
+            if attention_mask is not None and (attention_mask == 0.0).any():
                 return attention_mask
             return None

diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py
index b8d5b5eb910..9288b1a2930 100644
--- a/examples/modular-transformers/modeling_my_new_model2.py
+++ b/examples/modular-transformers/modeling_my_new_model2.py
@@ -602,7 +602,7 @@ def _update_causal_mask(
         output_attentions: bool,
     ):
         if self.config._attn_implementation == "flash_attention_2":
-            if attention_mask is not None and 0.0 in attention_mask:
+            if attention_mask is not None and (attention_mask == 0.0).any():
                 return attention_mask
             return None

diff --git a/examples/modular-transformers/modeling_super.py b/examples/modular-transformers/modeling_super.py
index 42d8108ee72..1f5aa55c469 100644
--- a/examples/modular-transformers/modeling_super.py
+++ b/examples/modular-transformers/modeling_super.py
@@ -519,7 +519,7 @@ def _update_causal_mask(
         output_attentions: bool,
     ):
         if self.config._attn_implementation == "flash_attention_2":
-            if attention_mask is not None and 0.0 in attention_mask:
+            if attention_mask is not None and (attention_mask == 0.0).any():
                 return attention_mask
             return None

diff --git a/utils/create_dependency_mapping.py b/utils/create_dependency_mapping.py
index 4a9955c976f..2cece89f8a9 100644
--- a/utils/create_dependency_mapping.py
+++ b/utils/create_dependency_mapping.py
@@ -7,13 +7,13 @@ def topological_sort(dependencies):
     new_dependencies = {}
     graph = defaultdict(list)
     for node, deps in dependencies.items():
-        node_name = node.split("/")[-2]
+        node_name = node.rsplit("modular_", 1)[1].replace(".py", "")
         for dep in deps:
             dep_name = dep.split(".")[-2]
             if dep_name == node_name:
                 # Skip self dependencies for topological sort as they create cycles
                 continue
-            if "example" not in node and "auto" not in dep and node_name not in graph[dep_name]:
+            if "auto" not in dep and node_name not in graph[dep_name]:
                 graph[dep_name].append(node_name)
         new_dependencies[node_name] = node
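All four regenerated example models pick up the same upstream change to `_update_causal_mask`. Both spellings test for a zero entry; a minimal sketch (mask values invented) of the difference in how they read:

```python
import torch

attention_mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])

# Old spelling: `in` routes through Tensor.__contains__, which is easy to
# misread as Python-list membership.
print(0.0 in attention_mask)          # True

# New spelling: element-wise comparison plus an explicit reduction,
# i.e. "does the mask contain any masked-out (0.0) position?"
print((attention_mask == 0.0).any())  # tensor(True)
```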
From f9763db12e4390cde163ac43cab8c6c09e0d280f Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Wed, 8 Jan 2025 17:34:48 +0100
Subject: [PATCH 3/9] Fix the bug in topological sort

---
 .../configuration_my_new_model.py                        |  4 +-
 .../models/diffllama/modeling_diffllama.py               |  2 +-
 utils/create_dependency_mapping.py                       | 65 ++++++-------------
 utils/modular_model_converter.py                         |  8 ++-
 4 files changed, 30 insertions(+), 49 deletions(-)

diff --git a/examples/modular-transformers/configuration_my_new_model.py b/examples/modular-transformers/configuration_my_new_model.py
index 7042c586cbb..59637e02d3f 100644
--- a/examples/modular-transformers/configuration_my_new_model.py
+++ b/examples/modular-transformers/configuration_my_new_model.py
@@ -43,7 +43,7 @@ class MyNewModelConfig(PretrainedConfig):
             The non-linear activation function (function or string) in the decoder.
         max_position_embeddings (`int`, *optional*, defaults to 2048):
             The maximum sequence length that this model might ever be used with. MyNewModel 1 supports up to 2048 tokens,
-            MyNewModel 2 up to 4096, CodeMyNewModel up to 16384.
+            MyNewModel 2 up to 4096, CodeLlama up to 16384.
         initializer_range (`float`, *optional*, defaults to 0.02):
             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
         rms_norm_eps (`float`, *optional*, defaults to 1e-06):
@@ -110,7 +110,7 @@ class MyNewModelConfig(PretrainedConfig):
         mlp_bias (`bool`, *optional*, defaults to `False`):
             Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
         head_dim (`int`, *optional*):
-            The attention head dimension. If None, it will default to hidden_size // num_heads
+            The attention head dimension. If None, it will default to hidden_size // num_attention_heads

     ```python
     >>> from transformers import MyNewModelModel, MyNewModelConfig

diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py
index ac2be71e5fd..725d3c31024 100644
--- a/src/transformers/models/diffllama/modeling_diffllama.py
+++ b/src/transformers/models/diffllama/modeling_diffllama.py
@@ -898,7 +898,7 @@ def _update_causal_mask(
         output_attentions: bool,
     ):
         if self.config._attn_implementation == "flash_attention_2":
-            if attention_mask is not None and 0.0 in attention_mask:
+            if attention_mask is not None and (attention_mask == 0.0).any():
                 return attention_mask
             return None

diff --git a/utils/create_dependency_mapping.py b/utils/create_dependency_mapping.py
index 2cece89f8a9..0c29f47d884 100644
--- a/utils/create_dependency_mapping.py
+++ b/utils/create_dependency_mapping.py
@@ -1,53 +1,30 @@
 import ast
 from collections import defaultdict

-
 # Function to perform topological sorting
-def topological_sort(dependencies):
-    new_dependencies = {}
-    graph = defaultdict(list)
+def topological_sort(dependencies: dict):
+    # Nodes are the names of the models to convert (we only add those to the graph)
+    nodes = {node.rsplit("modular_", 1)[1].replace(".py", "") for node in dependencies.keys()}
+    # This will be a graph mapping each model to convert to the models it depends on (which must be converted first)
+    graph = {}
+    name_mapping = {}
     for node, deps in dependencies.items():
         node_name = node.rsplit("modular_", 1)[1].replace(".py", "")
-        for dep in deps:
-            dep_name = dep.split(".")[-2]
-            if dep_name == node_name:
-                # Skip self dependencies for topological sort as they create cycles
-                continue
-            if "auto" not in dep and node_name not in graph[dep_name]:
-                graph[dep_name].append(node_name)
-        new_dependencies[node_name] = node
-
-    # Create a graph and in-degree count for each node
-    def filter_one_by_one(filtered_list, reverse):
-        if len(reverse) == 0:
-            return filtered_list
-
-        graph = defaultdict(list)
-        # Build the graph
-        for node, deps in reverse.items():
-            for dep in deps:
-                graph[dep].append(node)
-
-        base_modules = set(reverse.keys()) - set(graph.keys())
-        if base_modules == reverse.keys():
-            # we are at the end
-            return filtered_list + list(graph.keys())
-        to_add = []
-        for k in graph.keys():
-            if len(graph[k]) == 1 and graph[k][0] in base_modules:
-                if graph[k][0] in reverse:
-                    del reverse[graph[k][0]]
-                if k not in filtered_list:
-                    to_add += [k]
-        for k in base_modules:
-            if k not in filtered_list:
-                to_add += [k]
-        filtered_list += list(to_add)
-        return filter_one_by_one(filtered_list, reverse)
-
-    final_order = filter_one_by_one([], graph)
-
-    return [new_dependencies.get(k) for k in final_order if k in new_dependencies]
+        dep_names = {dep.split(".")[-2] for dep in deps}
+        dependencies = {dep for dep in dep_names if dep in nodes and dep != node_name}
+        graph[node_name] = dependencies
+        name_mapping[node_name] = node
+
+    sorting_list = []
+    while len(graph) > 0:
+        # Find the nodes with 0 out-degree
+        leaf_nodes = {node for node in graph if len(graph[node]) == 0}
+        # Add them to the list
+        sorting_list += list(leaf_nodes)
+        # Remove the leaves from the graph (and from the deps of other nodes)
+        graph = {node: deps - leaf_nodes for node, deps in graph.items() if node not in leaf_nodes}
+
+    return [name_mapping[x] for x in sorting_list]


 # Function to extract class and import info from a file

diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py
index 7f0677de119..7cba82f6df1 100644
--- a/utils/modular_model_converter.py
+++ b/utils/modular_model_converter.py
@@ -1691,9 +1691,13 @@ def save_modeling_file(modular_file, converted_file):
     args = parser.parse_args()
     if args.files_to_parse == ["all"]:
         args.files_to_parse = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)
-        args.files_to_parse += glob.glob("examples/**/modular_*.py", recursive=True)
+    if args.files_to_parse == ["examples"]:
+        args.files_to_parse = glob.glob("examples/**/modular_*.py", recursive=True)

-    for file_name in find_priority_list(args.files_to_parse):
+    priority_list = find_priority_list(args.files_to_parse)
+    assert len(priority_list) == len(args.files_to_parse), "Some files will not be converted"
+
+    for file_name in priority_list:
        print(f"Converting {file_name} to a single model single file format")
        module_path = file_name.replace("/", ".").replace(".py", "").replace("src.", "")
        converted_files = convert_modular_file(file_name)
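To make the new iterative (Kahn-style) sort concrete, here is a toy run with invented paths shaped like the converter's real input — a mapping from each modular file to the modules it imports (this assumes `utils/` is on `sys.path`, as the test added later in the series arranges):

```python
from create_dependency_mapping import topological_sort

dependencies = {
    "src/transformers/models/mistral/modular_mistral.py": ["transformers.models.llama.modeling_llama"],
    "src/transformers/models/mixtral/modular_mixtral.py": ["transformers.models.mistral.modeling_mistral"],
    "src/transformers/models/starcoder2/modular_starcoder2.py": ["transformers.models.mistral.modeling_mistral"],
}

# llama has no modular file here, so it is not a node and its edge is dropped;
# mistral then has no in-graph dependency and is scheduled before mixtral and
# starcoder2, which both derive from it.
print(topological_sort(dependencies))
```

Each pass of the `while` loop peels off every currently dependency-free model at once, so the result is a valid conversion order whenever the input is acyclic (self-dependencies are already filtered out when the graph is built).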
From a8ae2811c2b772349815433e693df1277125ac55 Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Wed, 8 Jan 2025 17:35:52 +0100
Subject: [PATCH 4/9] Update create_dependency_mapping.py

---
 utils/create_dependency_mapping.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/utils/create_dependency_mapping.py b/utils/create_dependency_mapping.py
index 0c29f47d884..5cf38cdd1f8 100644
--- a/utils/create_dependency_mapping.py
+++ b/utils/create_dependency_mapping.py
@@ -1,6 +1,7 @@
 import ast
 from collections import defaultdict

+
 # Function to perform topological sorting
 def topological_sort(dependencies: dict):
     # Nodes are the names of the models to convert (we only add those to the graph)

From ed4cbc199dcb88defd2f0ddf5ef88a28e842bd32 Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Thu, 9 Jan 2025 10:53:36 +0100
Subject: [PATCH 5/9] start adding test

---
 tests/repo_utils/modular/test_conversion_order.py | 43 +++++++++++++++++++
 1 file changed, 43 insertions(+)
 create mode 100644 tests/repo_utils/modular/test_conversion_order.py

diff --git a/tests/repo_utils/modular/test_conversion_order.py b/tests/repo_utils/modular/test_conversion_order.py
new file mode 100644
index 00000000000..9d7ff06ee65
--- /dev/null
+++ b/tests/repo_utils/modular/test_conversion_order.py
@@ -0,0 +1,43 @@
+import os
+import sys
+
+ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
+sys.path.append(os.path.join(ROOT_DIR, "utils"))
+
+import create_dependency_mapping  # noqa: E402
+
+# This is equivalent to `all` in the current library state (as of 09/01/2025)
+model_root = os.path.join("src", "transformers", "models")
+files_to_parse = [
+    os.path.join(model_root, "starcoder2", "modular_starcoder2.py"),
+    os.path.join(model_root, "gemma", "modular_gemma.py"),
+    os.path.join(model_root, "olmo2", "modular_olmo2.py"),
+    os.path.join(model_root, "diffllama", "modular_diffllama.py"),
+    os.path.join(model_root, "granite", "modular_granite.py"),
+    os.path.join(model_root, "gemma2", "modular_gemma2.py"),
+    os.path.join(model_root, "mixtral", "modular_mixtral.py"),
+    os.path.join(model_root, "olmo", "modular_olmo.py"),
+    os.path.join(model_root, "rt_detr", "modular_rt_detr.py"),
+    os.path.join(model_root, "qwen2", "modular_qwen2.py"),
+    os.path.join(model_root, "llava_next_video", "modular_llava_next_video.py"),
+    os.path.join(model_root, "cohere2", "modular_cohere2.py"),
+    os.path.join(model_root, "modernbert", "modular_modernbert.py"),
+    os.path.join(model_root, "colpali", "modular_colpali.py"),
+    os.path.join(model_root, "deformable_detr", "modular_deformable_detr.py"),
+    os.path.join(model_root, "aria", "modular_aria.py"),
+    os.path.join(model_root, "ijepa", "modular_ijepa.py"),
+    os.path.join(model_root, "bamba", "modular_bamba.py"),
+    os.path.join(model_root, "dinov2_with_registers", "modular_dinov2_with_registers.py"),
+    os.path.join(model_root, "instructblipvideo", "modular_instructblipvideo.py"),
+    os.path.join(model_root, "glm", "modular_glm.py"),
+    os.path.join(model_root, "phi", "modular_phi.py"),
+    os.path.join(model_root, "mistral", "modular_mistral.py"),
+]
+
+# Find the order
+priority_list = create_dependency_mapping.find_priority_list(files_to_parse)
+# Extract just the model names
+model_priority_list = []
+
+def appear_after(model1: str, file2: str) -> bool:
+    pass
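The hard-coded `files_to_parse` list above freezes what `all` resolved to at the time of writing; for reference, an equivalent list could be regenerated with the same glob pattern the converter uses in patch 3 (assuming the working directory is the repository root, which the relative paths already require):

```python
import glob

# Same pattern as the `all` branch of the converter CLI; sorted only to get
# a stable baseline, since find_priority_list decides the actual order.
files_to_parse = sorted(glob.glob("src/transformers/models/**/modular_*.py", recursive=True))
```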
os.path.join(model_root, "deformable_detr", "modular_deformable_detr.py"), + os.path.join(model_root, "aria", "modular_aria.py"), + os.path.join(model_root, "ijepa", "modular_ijepa.py"), + os.path.join(model_root, "bamba", "modular_bamba.py"), + os.path.join(model_root, "dinov2_with_registers", "modular_dinov2_with_registers.py"), + os.path.join(model_root, "instructblipvideo", "modular_instructblipvideo.py"), + os.path.join(model_root, "glm", "modular_glm.py"), + os.path.join(model_root, "phi", "modular_phi.py"), + os.path.join(model_root, "mistral", "modular_mistral.py"), +] + +# Find the order +priority_list = create_dependency_mapping.find_priority_list(files_to_parse) +# Extract just the model names +model_priority_list = [] + +def appear_after(model1: str, file2: str) -> bool: + pass From 76ebe751690bb743fa4817096074317093a23587 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 9 Jan 2025 16:43:52 +0100 Subject: [PATCH 6/9] finalize test --- .../modular/test_conversion_order.py | 77 +++++++++++-------- 1 file changed, 46 insertions(+), 31 deletions(-) diff --git a/tests/repo_utils/modular/test_conversion_order.py b/tests/repo_utils/modular/test_conversion_order.py index 9d7ff06ee65..101655020da 100644 --- a/tests/repo_utils/modular/test_conversion_order.py +++ b/tests/repo_utils/modular/test_conversion_order.py @@ -1,5 +1,6 @@ import os import sys +import unittest ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) sys.path.append(os.path.join(ROOT_DIR, "utils")) @@ -7,37 +8,51 @@ import create_dependency_mapping # noqa: E402 # This is equivalent to `all` in the current library state (as of 09/01/2025) -model_root = os.path.join("src", "transformers", "models") -files_to_parse = [ - os.path.join(model_root, "starcoder2", "modular_starcoder2.py"), - os.path.join(model_root, "gemma", "modular_gemma.py"), - os.path.join(model_root, "olmo2", "modular_olmo2.py"), - os.path.join(model_root, "diffllama", "modular_diffllama.py"), - os.path.join(model_root, "granite", "modular_granite.py"), - os.path.join(model_root, "gemma2", "modular_gemma2.py"), - os.path.join(model_root, "mixtral", "modular_mixtral.py"), - os.path.join(model_root, "olmo", "modular_olmo.py"), - os.path.join(model_root, "rt_detr", "modular_rt_detr.py"), - os.path.join(model_root, "qwen2", "modular_qwen2.py"), - os.path.join(model_root, "llava_next_video", "modular_llava_next_video.py"), - os.path.join(model_root, "cohere2", "modular_cohere2.py"), - os.path.join(model_root, "modernbert", "modular_modernbert.py"), - os.path.join(model_root, "colpali", "modular_colpali.py"), - os.path.join(model_root, "deformable_detr", "modular_deformable_detr.py"), - os.path.join(model_root, "aria", "modular_aria.py"), - os.path.join(model_root, "ijepa", "modular_ijepa.py"), - os.path.join(model_root, "bamba", "modular_bamba.py"), - os.path.join(model_root, "dinov2_with_registers", "modular_dinov2_with_registers.py"), - os.path.join(model_root, "instructblipvideo", "modular_instructblipvideo.py"), - os.path.join(model_root, "glm", "modular_glm.py"), - os.path.join(model_root, "phi", "modular_phi.py"), - os.path.join(model_root, "mistral", "modular_mistral.py"), +MODEL_ROOT = os.path.join("src", "transformers", "models") +FILES_TO_PARSE = [ + os.path.join(MODEL_ROOT, "starcoder2", "modular_starcoder2.py"), + os.path.join(MODEL_ROOT, "gemma", "modular_gemma.py"), + os.path.join(MODEL_ROOT, "olmo2", "modular_olmo2.py"), + os.path.join(MODEL_ROOT, "diffllama", "modular_diffllama.py"), + 
os.path.join(MODEL_ROOT, "granite", "modular_granite.py"), + os.path.join(MODEL_ROOT, "gemma2", "modular_gemma2.py"), + os.path.join(MODEL_ROOT, "mixtral", "modular_mixtral.py"), + os.path.join(MODEL_ROOT, "olmo", "modular_olmo.py"), + os.path.join(MODEL_ROOT, "rt_detr", "modular_rt_detr.py"), + os.path.join(MODEL_ROOT, "qwen2", "modular_qwen2.py"), + os.path.join(MODEL_ROOT, "llava_next_video", "modular_llava_next_video.py"), + os.path.join(MODEL_ROOT, "cohere2", "modular_cohere2.py"), + os.path.join(MODEL_ROOT, "modernbert", "modular_modernbert.py"), + os.path.join(MODEL_ROOT, "colpali", "modular_colpali.py"), + os.path.join(MODEL_ROOT, "deformable_detr", "modular_deformable_detr.py"), + os.path.join(MODEL_ROOT, "aria", "modular_aria.py"), + os.path.join(MODEL_ROOT, "ijepa", "modular_ijepa.py"), + os.path.join(MODEL_ROOT, "bamba", "modular_bamba.py"), + os.path.join(MODEL_ROOT, "dinov2_with_registers", "modular_dinov2_with_registers.py"), + os.path.join(MODEL_ROOT, "instructblipvideo", "modular_instructblipvideo.py"), + os.path.join(MODEL_ROOT, "glm", "modular_glm.py"), + os.path.join(MODEL_ROOT, "phi", "modular_phi.py"), + os.path.join(MODEL_ROOT, "mistral", "modular_mistral.py"), ] -# Find the order -priority_list = create_dependency_mapping.find_priority_list(files_to_parse) -# Extract just the model names -model_priority_list = [] -def appear_after(model1: str, file2: str) -> bool: - pass +def appear_after(model1: str, model2: str, priority_list: list[str]) -> bool: + """Return True if `model1` appear after `model2` in `priority_list`.""" + return priority_list.index(model1) > priority_list.index(model2) + + +class ConversionOrderTest(unittest.TestCase): + + def test_conversion_order(self): + # Find the order + priority_list = create_dependency_mapping.find_priority_list(FILES_TO_PARSE) + # Extract just the model names + model_priority_list = [file.rsplit("modular_")[-1].replace(".py", "") for file in priority_list] + + # These are based on what the current library order should be (as of 09/01/2025) + self.assertTrue(appear_after("mixtral", "mistral", model_priority_list)) + self.assertTrue(appear_after("gemma2", "gemma", model_priority_list)) + self.assertTrue(appear_after("starcoder2", "mistral", model_priority_list)) + self.assertTrue(appear_after("olmo2", "olmo", model_priority_list)) + self.assertTrue(appear_after("diffllama", "mistral", model_priority_list)) + self.assertTrue(appear_after("cohere2", "gemma2", model_priority_list)) \ No newline at end of file From 2916d3dff926333a8b4674dbfeb30877c9e656ea Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 9 Jan 2025 16:47:26 +0100 Subject: [PATCH 7/9] more tests --- tests/repo_utils/modular/test_conversion_order.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/repo_utils/modular/test_conversion_order.py b/tests/repo_utils/modular/test_conversion_order.py index 101655020da..a1df9b12509 100644 --- a/tests/repo_utils/modular/test_conversion_order.py +++ b/tests/repo_utils/modular/test_conversion_order.py @@ -33,6 +33,8 @@ os.path.join(MODEL_ROOT, "glm", "modular_glm.py"), os.path.join(MODEL_ROOT, "phi", "modular_phi.py"), os.path.join(MODEL_ROOT, "mistral", "modular_mistral.py"), + os.path.join(MODEL_ROOT, "phi3", "modular_phi3.py"), + os.path.join(MODEL_ROOT, "cohere", "modular_cohere.py"), ] @@ -55,4 +57,6 @@ def test_conversion_order(self): self.assertTrue(appear_after("starcoder2", "mistral", model_priority_list)) self.assertTrue(appear_after("olmo2", "olmo", model_priority_list)) 
self.assertTrue(appear_after("diffllama", "mistral", model_priority_list)) - self.assertTrue(appear_after("cohere2", "gemma2", model_priority_list)) \ No newline at end of file + self.assertTrue(appear_after("cohere2", "gemma2", model_priority_list)) + self.assertTrue(appear_after("cohere2", "cohere", model_priority_list)) + self.assertTrue(appear_after("phi3", "mistral", model_priority_list)) \ No newline at end of file From 36824d34b749e94806cb4b4d22bff384bc0e418b Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 9 Jan 2025 16:47:46 +0100 Subject: [PATCH 8/9] style --- tests/repo_utils/modular/test_conversion_order.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/repo_utils/modular/test_conversion_order.py b/tests/repo_utils/modular/test_conversion_order.py index a1df9b12509..f5e133ce1fe 100644 --- a/tests/repo_utils/modular/test_conversion_order.py +++ b/tests/repo_utils/modular/test_conversion_order.py @@ -2,11 +2,13 @@ import sys import unittest + ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) sys.path.append(os.path.join(ROOT_DIR, "utils")) import create_dependency_mapping # noqa: E402 + # This is equivalent to `all` in the current library state (as of 09/01/2025) MODEL_ROOT = os.path.join("src", "transformers", "models") FILES_TO_PARSE = [ @@ -44,7 +46,6 @@ def appear_after(model1: str, model2: str, priority_list: list[str]) -> bool: class ConversionOrderTest(unittest.TestCase): - def test_conversion_order(self): # Find the order priority_list = create_dependency_mapping.find_priority_list(FILES_TO_PARSE) @@ -59,4 +60,4 @@ def test_conversion_order(self): self.assertTrue(appear_after("diffllama", "mistral", model_priority_list)) self.assertTrue(appear_after("cohere2", "gemma2", model_priority_list)) self.assertTrue(appear_after("cohere2", "cohere", model_priority_list)) - self.assertTrue(appear_after("phi3", "mistral", model_priority_list)) \ No newline at end of file + self.assertTrue(appear_after("phi3", "mistral", model_priority_list)) From d3810cf2af048efac6b9bd96b4bc8d70ce5110ae Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 9 Jan 2025 17:01:41 +0100 Subject: [PATCH 9/9] style --- src/transformers/models/diffllama/modeling_diffllama.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index 725d3c31024..7d6d6af3824 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -612,11 +612,7 @@ def _init_weights(self, module): class DiffLlamaRotaryEmbedding(nn.Module): - def __init__( - self, - config: DiffLlamaConfig, - device=None, - ): + def __init__(self, config: DiffLlamaConfig, device=None): super().__init__() # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: