From 2c8b87cd451dc9f7969bef3f39ca3084baaa9523 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 28 Oct 2024 14:24:26 +0100 Subject: [PATCH 01/40] rework converter --- utils/modular_model_converter.py | 1168 ++++++++++++++++-------------- 1 file changed, 619 insertions(+), 549 deletions(-) diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index bda143c2577..966054cbe22 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -19,6 +19,7 @@ import re from collections import defaultdict, deque from typing import Dict, List, Optional, Set +from abc import ABC, abstractmethod import libcst as cst from check_copies import run_ruff @@ -34,13 +35,6 @@ logger = logging.get_logger(__name__) -# This is used to avoid overwriting these top-level assignments even if they are in the dependency graph. Otherwise, the -# value from the dependency is used, then mapped to current name convention, resulting in wrong value. -# The corresponding mapped value is used to define the file target for the assignment -ASSIGNMENTS_TO_KEEP = { - "_CHECKPOINT_FOR_DOC": "modeling", -} - AUTO_GENERATED_MESSAGE = """# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # This file was automatically generated from {relative_path}. # Do NOT edit this file manually as any edits will be overwritten by the generation of @@ -61,8 +55,141 @@ def get_module_source_from_name(module_name: str) -> str: return source_code -class ClassFinder(CSTVisitor): - """A visitor class which analyses a module, creating a mapping of dependencies between classes and functions. +def find_all_dependencies( + dependency_mapping: Dict[str, set], + start_entity: str | None = None, + initial_dependencies: set | None = None, + initial_checked_dependencies: set | None = None, + return_parent: bool = False, +) -> list | set: + """Return all the dependencies of the given `start_entity` or `initial_dependencies`. This is basically some kind of + BFS traversal algorithm. It can either start from `start_entity`, or `initial_dependencies`. + + Args: + dependency_mapping (`Dict[str, set]`): + A mapping from entities (usually function names), to immediate dependencies. That is, for function names, + a mapping {"foo": {"bar", "test"}} would indicate that functions `bar` and `test` are immediately called + in `foo`'s definition. + start_entity (str | None, *optional*): + A key of `dependency_mapping`, indicating from which entity to start the search. + initial_dependencies (set | None, *optional*): + If `start_entity` is not provided, this can be used as an alternative. In this case, `initial_dependencies` + the search will continue from all the entities in `initial_dependencies`, if they are in `dependency_mapping`. + initial_checked_dependencies (set | None, *optional*): + If provided, entities already present in `initial_checked_dependencies` will not be part of the returned dependencies. + return_parent (bool, *optional*): + If `True`, will return a list consisting of tuples (dependency, parent) instead of a simple set of dependencies. Note + that the order of the items in the list reflects the traversal order. Thus, no parent can ever appear before childs. + Returns: + A set of all the dependencies, or a list containing parents as well if `return_parent=True`. + + Example: + Given the following structure in the `modular_xxx.py` file: + ``` + def foo1(): + pass + + def foo2(): + pass + + def bar(): + foo1() + + def foobar(): + bar() + foo2() + + class MyLayer(SomeOtherModelLayer): + def forward(...): + foobar() + ``` + and the `dependency_mapping` created when visiting the `modular_xxx.py` file, we get: + ``` + dependency_mapping = {'bar': {'foo1'}, 'foobar': {'bar', 'foo2'}} + find_all_dependencies(dependency_mapping, start_entity='foobar', return_parent=True) + >>> [('bar', 'foobar'), ('foo2', 'foobar'), ('foo1', 'bar')] + ``` + That is, all the functions needed (and potentially their immediate parent) so that the function to be added + in MyLayer (`foobar`) can work correctly. + """ + if initial_dependencies is None and start_entity is not None: + initial_dependencies = dependency_mapping[start_entity] + if initial_checked_dependencies is None: + initial_checked_dependencies = set() + + dependency_queue = deque(initial_dependencies) + all_dependencies = set() + all_dependencies_with_parent = [] + checked_dependencies = set(initial_checked_dependencies) + parents = {initial_dep: start_entity for initial_dep in initial_dependencies} + while len(dependency_queue) > 0: + # Pick element to visit + current = dependency_queue.popleft() + if current not in checked_dependencies: + # Add the dependencies + all_dependencies.add(current) + all_dependencies_with_parent += [(current, parents[current])] + if current in dependency_mapping.keys(): + # Update dependency queue + dependency_queue.extend(dependency_mapping[current]) + parents.update({dep: current for dep in dependency_mapping[current]}) + # add visited node to the list + checked_dependencies.add(current) + + if not return_parent: + return all_dependencies + # no child can ever appear before its parent thanks to the queue (needed to add them at the correct location in the body later) + return all_dependencies_with_parent + + +# These top-level variables will always use the value in the `modular_xxx.py` file +ASSIGNMENTS_TO_KEEP = { + "_CHECKPOINT_FOR_DOC", +} + +class ClassDependencyMapper(CSTVisitor): + """A visitor which analyzes classes to get their dependencies. If `global_names` is passed, only dependencies + present in the `global_names` will be added. + This class is used through the 3 convenient class methods allowing to get the dependencies of a given class node. + The `ClassFinder` uses it to get dependencies of classes. + """ + + def __init__(self, class_name: str, global_names: set | None): + super().__init__() + self.class_name = class_name + self.dependencies = set() + self.global_names = global_names + + def visit_Name(self, node): + if node.value != self.class_name and node.value in self.global_names: + self.dependencies.add(node.value) + + @classmethod + def dependencies_for_node(cls, node: cst.ClassDef, global_names: set) -> set: + """Create dependencies for a node in the `ModuleMapper`.""" + temp_module = cst.Module(body=[node]) + visitor = cls(node.name.value, global_names) + temp_module.visit(visitor) + return visitor.dependencies + + @classmethod + def dependencies_for_new_node(cls, updated_node: cst.ClassDef, mapper: "ModuleMapper") -> set: + """Create dependencies for a node in the `ModularFileMapper` (which may have been changed by + `replace_call_to_super`). + """ + temp_module = cst.Module(body=[updated_node]) + visitor = cls(updated_node.name.value, set(mapper.global_nodes.keys())) + temp_module.visit(visitor) + return mapper.augment_dependencies_with_functions(visitor.dependencies) + + +class ModuleMapper(CSTVisitor, ABC): + """A visitor class which analyses a module, creating a mapping of dependencies for classes and functions. + The `full_class_dependency_mapping` created contains 1st-level classes and assignments dependencies, as well + as all (recursively) functions dependencies. + The `function_call_recursive_dependecy_mapping` created contains all function definitions, and all their (recursively) + dependencies. + For example if the visited code has ```python3 def init_value(): return 1 @@ -72,15 +199,10 @@ def __init__(self): super().__init__(self) self.value = init_value() ``` - then the `class_dependency_mapping` should be: `{"LlamaModel":["PreTrainedModel","init_value"], "init_value":[]} + then the `class_dependency_mapping` should be: `{"LlamaModel": {"PreTrainedModel", "init_value"}} - The dependency mapping is updated via the `visit_Name`, `visit_Arg` and `visit_Decorator`. This is very broad, and by - checking the parent node, or the scope of a `cst.Name` or `cst.Arg` or `cst.Decorator` we are able to map the - dependence parent -> child. - - When visiting such nodes, we update the dependency of the parent node, to take into account the visited node. - - All `visit_XXX` correspond to the code executed when vising the cst.Node of type XXX. + The dependency mapping is updated via the `ClassDependencyMapper`, then augmented with `augment_dependencies_with_functions` + to get all functions dependencies. """ METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider) @@ -88,32 +210,15 @@ def __init__(self): def __init__(self, python_module: cst.Module): # fmt: off self.python_module: cst.Module = python_module # original cst.Module being visited - self.classes: Dict[str, cst.ClassDef] = {} # stores a mapping from classname to the cst.Node + self.classes: Dict[str, cst.ClassDef] = {} # mapping from class names to Nodes self.imports = {} # stores all import statements - self.function_def = {} # stores global scope function definition - self.assignments = {} # LLAMA_DOCSTRING - self.class_dependency_mapping = {} # "LlamaModel":["LlamaDecoderLayer, "LlamaRMSNorm", "LlamaPreTrainedModel"], "LlamaDecoderLayer":["LlamaAttention","Llama"] - self.first_lvl_dependency_mapping = {} # "LlamaModel":["LlamaDecoderLayer, "LlamaRMSNorm", "LlamaPreTrainedModel"], "LlamaDecoderLayer":["LlamaAttention","Llama"] + self.functions = {} # mapping of global scope function names to Nodes + self.assignments = {} # mapping of global assignments names to Nodes + self.class_dependency_mapping = {} # mapping of function name to immediate dependencies + self.current_function = None + self.function_call_dependency_mapping = defaultdict(set) # fmt: on - def _update_class_dependency(self, name, value): - """Update the dependency mapping for `name` with `value` by appending the previous - dependencies to the new `value`. - """ - dep = set(self.first_lvl_dependency_mapping.get(name, set())) | set({value}) - self.first_lvl_dependency_mapping[name] = dep - - dep = set(self.class_dependency_mapping.get(value, set())) - dep |= set(self.class_dependency_mapping.get(name, {})) | set({value}) - self.class_dependency_mapping[name] = dep - - def visit_ClassDef(self, node: ClassDef) -> None: - """We don't have non global scope class defs in transformers. Here we add the inheritance dependencies""" - self.classes[node.name.value] = node - for k in node.bases: # deal with inheritance - base_name = self.python_module.code_for_node(k) - self._update_class_dependency(node.name.value, base_name) - def visit_SimpleStatementLine(self, node): """ Global Assigns like `GEMMA_INPUT_DOCSTRING = 'THIS IS THE INPUT' and all import statements @@ -124,74 +229,202 @@ def visit_SimpleStatementLine(self, node): ): left_hand_side = node.body[0].targets[0].target if hasattr(left_hand_side, "value"): - if left_hand_side.value not in ASSIGNMENTS_TO_KEEP.keys(): - self.assignments[left_hand_side.value] = node + self.assignments[left_hand_side.value] = node else: for idx, target in enumerate(list(left_hand_side.elements)): - if target.value.value not in ASSIGNMENTS_TO_KEEP.keys(): - self.assignments[target.value.value] = node.body[0].value.elements[idx].value + self.assignments[target.value.value] = node.body[0].value.elements[idx].value if m.matches(node, m.SimpleStatementLine(body=[m.Import() | m.ImportFrom()])): self.imports[node.body[0].names] = node def visit_FunctionDef(self, node): parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) if m.matches(parent_node, m.Module()): - self.function_def[node.name.value] = node + self.current_function = node.name.value + self.functions[node.name.value] = node + + def leave_FunctionDef(self, node): + parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) + if m.matches(parent_node, m.Module()): + self.current_function = None def leave_If(self, node): for stmt in node.body.body: if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])): self.imports[stmt.body[0].names] = node - def leave_Name(self, node): - if node.value in self.classes.keys() | self.assignments.keys() | self.function_def.keys(): - parent = self.get_metadata(cst.metadata.ScopeProvider, node) - if not isinstance(parent, cst.metadata.scope_provider.GlobalScope): - self._update_class_dependency(parent._name_prefix.split(".")[0], node.value) - - def leave_Arg(self, node): - if m.matches(node.value, m.Name()): - parent = self.get_metadata(ParentNodeProvider, node) - if m.matches(parent, m.ClassDef()) and parent.bases: - self._update_class_dependency(parent.name.value, node.value.value) - - def leave_Dict(self, node): - parent = self.get_metadata(cst.metadata.ParentNodeProvider, node) - if m.matches(parent, m.Assign(targets=[m.AssignTarget()])): - name = parent.targets[0].target.value - if name in self.assignments: - for k in node.elements: - dep_name = k.value.value - if dep_name in self.classes: - self._update_class_dependency(name, dep_name) - - def leave_Decorator(self, node): - if hasattr(node.decorator, "args"): - for k in node.decorator.args: - if m.matches(k.value, m.Call(func=m.Attribute(value=m.Name()))): # and k.value.func.value.value: - if k.value.func.value.value not in self.assignments: - raise ValueError( - f"We detected a call to {k.value.func.value.value}, but it was not assigned. See the list of assigments {self.assignments.keys()}" - ) - parent = self.get_metadata(cst.metadata.ParentNodeProvider, node) - scope = self.get_metadata(cst.metadata.ScopeProvider, node) - name = scope._name_prefix.split(".")[0] if scope._name_prefix != "" else parent.name.value - self._update_class_dependency(name, k.value.func.value.value) - elif m.matches(k, m.Arg(value=m.Name())) and k.value.value in self.assignments: - parent = self.get_metadata(cst.metadata.ParentNodeProvider, node) - scope = self.get_metadata(cst.metadata.ScopeProvider, node) - name = scope._name_prefix.split(".")[0] if scope._name_prefix != "" else parent.name.value - self._update_class_dependency(name, k.value.value) + def visit_ClassDef(self, node: ClassDef) -> None: + """Record class nodes to create their dependencies at the end.""" + self.classes[node.name.value] = node + + def visit_Call(self, node: cst.Call): + """This is used to create a mapping from top-level functions to functions called inside them. + Important note: we only rely on direct Call to the functions here, not indirect mentions (such as assigning a variable with the function, + add calling the variable later). This should be enough as the `modular_xxx` and `modeling_xxx` structures should be as simple as possible. + """ + if self.current_function is not None: + # Simple function calls such as foo() + if m.matches(node.func, m.Name()): + self.function_call_dependency_mapping[self.current_function].add(node.func.value) def leave_Module(self, node): - """When leaving the module, we store the position of each global scoped node (Assigns, function def and class def) - to allow sorting the dependencies based on their position in the code. We use the PositionProvider metadata wrapper for this. + """When leaving the module, we finally create the `function_call_recursive_dependency_mapping`, then we + compute the dependencies for all recorded classes based on all the nodes we visited. + We also store the position of each global scoped node to allow sorting the dependencies based on their + position in the code later. We use the PositionProvider metadata wrapper for this. """ - self.global_nodes = {**self.assignments, **self.classes, **self.function_def} + # assign all nodes + self.global_nodes = {**self.assignments, **self.classes, **self.functions} # now sort the class dependency_mapping based on the position of the nodes - self.class_start_line = {} + self.start_lines = {} for id, node in self.global_nodes.items(): - self.class_start_line[id] = self.get_metadata(cst.metadata.PositionProvider, node).start.line + self.start_lines[id] = self.get_metadata(cst.metadata.PositionProvider, node).start.line + + def _compute_recursive_function_dependencies(self) -> dict[str, set]: + """Based on the 1st level function dependency mapping, create the recursive dependency mapping.""" + recursive_dependencies = {} + for function_name in self.function_call_dependency_mapping.keys(): + # We need to check if they are present in self.functions to avoid built-in functions + all_dependencies = { + dep + for dep in find_all_dependencies(self.function_call_dependency_mapping, start_entity=function_name) + if dep in self.functions.keys() + } + recursive_dependencies[function_name] = all_dependencies + return recursive_dependencies + + def augment_dependencies_with_functions(self, dependencies: set) -> set: + """For a set of `dependencies`, augment them by adding all potential functions which are dependencies of + the functions present in the `dependencies`. + """ + new_dependencies = dependencies.copy() + # Go through the set of dependencies + for dep in tuple(dependencies): + if dep in self.function_call_recursive_dependency_mapping.keys(): + new_dependencies.update(self.function_call_recursive_dependency_mapping[dep]) + return new_dependencies + + def compute_class_dependencies(self): + """For each visited class, find its dependencies based on visited the current file + potential merged dependencies. + Note: This function takes care of updating `global_nodes` and `function_call_recursive_dependency_mapping` as well after the + merge with other files dependencies. + """ + # Correctly re-set the global nodes at this point + self.global_nodes = {**self.assignments, **self.classes, **self.functions} + # Create the global mapping of recursive dependencies for functions + self.function_call_recursive_dependency_mapping = self._compute_recursive_function_dependencies() + + for class_name, class_node in self.classes.items(): + dependencies = ClassDependencyMapper.dependencies_for_node(class_node, set(self.global_nodes.keys())) + # Corretcly augment class dependencies with all needed functions + self.class_dependency_mapping[class_name] = self.augment_dependencies_with_functions(dependencies) + + @abstractmethod + def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]: + pass + + +class ModelFileMapper(ModuleMapper): + + def __init__(self, python_module: cst.Module): + super().__init__(python_module) + + def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]: + """Compute the relative order that the `missing_dependencies` should have between themselves in the output file. + """ + relative_order = {} + idx = 0 + classes = sorted([dep for dep in tuple(missing_dependencies) if dep in self.classes], key=lambda x: self.start_line[x]) + # This is because for merged dependencies, we only have relative order in the other visited file, so we need + # to track dependency order relative to a given class + if len(classes) > 0 and not hasattr(self, "full_class_dependency_mapping"): + raise ValueError("Cannot correctly find the relative order of the dependencies.") + + remaining_dependencies = missing_dependencies.copy() + + # Start by tracking relative order class by class + for class_name in classes: + class_dependencies = tuple(self.full_class_dependency_mapping[class_name] & remaining_dependencies) + original_dependencies = [] + merged_dependencies = [] + # We need to differentiate between nodes that were already present (we can get relative order globally) and + # nodes that were merged (we can get relative order only relative to the class the dependencies relate to) + for class_dep in class_dependencies: + if class_dep in self.modular_file_start_lines: + merged_dependencies.append(class_dep) + else: + original_dependencies.append(class_dep) + # Sort both list according to the order in their respective file + original_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines[x]) + merged_dependencies = sorted(merged_dependencies, key=lambda x: self.modular_file_start_lines[x]) + + # Add all original node first, then merged ones + for dep in original_dependencies + merged_dependencies: + remaining_dependencies.remove(dep) + relative_order[dep] = idx + idx += 1 + # Add the class itself + remaining_dependencies.remove(class_name) + relative_order[class_name] = idx + idx += 1 + + # Now add what still remains + remaining_dependencies = tuple(remaining_dependencies) + original_dependencies = [] + merged_dependencies = [] + for dep in remaining_dependencies: + if dep in self.modular_file_start_lines: + merged_dependencies.append(dep) + else: + original_dependencies.append(dep) + # Sort both list according to the order in their respective file + original_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines[x]) + merged_dependencies = sorted(merged_dependencies, key=lambda x: self.modular_file_start_lines[x]) + + # Add all original node first, then merged ones + for dep in original_dependencies + merged_dependencies: + relative_order[dep] = idx + idx += 1 + + return relative_order + + def _merge_functions(self, functions: dict[str, cst.CSTNode], function_call_mapping: dict[str, set]): + """Update the global nodes and function dependency mapping with those from the modular file. + + Merging rule: if any function with the same name was redefined in the modular, use it and its dependencies + instead of the original ones (this may mean to add new functions as well, if any redefined function uses a new one). + """ + # Add/overwrite all needed function nodes and dependencies + self.functions.update(functions) + self.function_call_dependency_mapping.update(function_call_mapping) + + def _merge_assignments(self, assignments: dict[str, cst.CSTNode]): + """Update the global nodes with the assignment from the modular file. + + Merging rule: if any assignment with the same name was redefined in the modular, we use it ONLY if it is + in `ASSIGNMENTS_TO_KEEP`. Otherwise, we use the original value. This rule was chosen to avoid having to rewrite the + big docstrings. + """ + for assignment, node in assignments.items(): + if assignment in ASSIGNMENTS_TO_KEEP or assignment not in self.assignments: + self.assignments[assignment] = node + + def merge_modular_dependencies(self, functions, function_mapping, assignments, start_lines): + """Merge both functions and assignments from the modular definitions into the current module file, + then compute the relative order of all nodes.""" + self._merge_functions(functions, function_mapping) + self._merge_assignments(assignments) + self.modular_file_start_lines = start_lines + + @classmethod + def visit_and_merge_dependencies(cls, module: cst.Module, functions, function_mapping, assignments, start_lines) -> "ModelFileMapper": + wrapper = MetadataWrapper(module) + mapper = cls(module) + wrapper.visit(mapper) + # Merge dependencies + mapper.merge_modular_dependencies(functions, function_mapping, assignments, start_lines) + # Create the class dependencies graph + mapper.compute_class_dependencies() + return mapper class ReplaceNameTransformer(m.MatcherDecoratableTransformer): @@ -210,8 +443,6 @@ def __init__( new_name, given_old_name=None, given_new_name=None, - old_class_name: str = None, - new_class_name: str = None, ): super().__init__() self.old_name = old_name @@ -232,14 +463,6 @@ def __init__( self.default_old_name = CONFIG_MAPPING_NAMES[self.old_name].replace("Config", "") if self.default_old_name.isupper(): self.default_old_name = self.default_old_name.capitalize() - if new_class_name is not None and old_class_name is not None and old_class_name not in self.patterns: - # In last recourse, when the suffix of the new class is not the same as the old class, - # and if the old and new classes start with the default name, we keep the default class name - # and replace the old suffix with the new one. - # Useful when we have a class like `ColPaliForRetrieval` inheriting from `PaliGemmaForConditionalGeneration` - # where a model extends another model, but is used for a different task. - if old_class_name.startswith(self.default_old_name) and new_class_name.startswith(self.default_name): - self.patterns[old_class_name[len(self.default_old_name) :]] = new_class_name[len(self.default_name) :] def preserve_case_replace(self, text): # Create a regex pattern to match all variations @@ -271,33 +494,6 @@ def leave_ClassDef(self, original_node, updated_node): return updated_node.with_changes(name=cst.Name(self.convert_to_camelcase(updated_node.name.value))) -def find_classes_in_file( - module: cst.Module, - old_id="llama", - new_id="gemma", - given_old_name=None, - given_new_name=None, - old_class_name=None, - new_class_name=None, -): - """Helper function to rename and then parse a source file using the ClassFinder""" - transformer = ReplaceNameTransformer( - old_id, - new_id, - given_old_name=given_old_name, - given_new_name=given_new_name, - old_class_name=old_class_name, - new_class_name=new_class_name, - ) - new_module = module.visit(transformer) - - wrapper = MetadataWrapper(new_module) - - class_finder = ClassFinder(new_module) - wrapper.visit(class_finder) - return class_finder - - DOCSTRING_NODE = m.SimpleStatementLine( body=[ m.Expr( @@ -437,7 +633,6 @@ def update_body(self, existing_body, new_statements): if m.matches(node, m.SimpleStatementLine(body=[m.Del()])): target = self.python_module.code_for_node(node.body[0].target) self.deleted_targets[target] = node - continue for stmt in existing_body: if m.matches(stmt, m.SimpleStatementLine(body=[m.Assign()])): @@ -447,6 +642,9 @@ def update_body(self, existing_body, new_statements): continue if target in self.all_assign_target: stmt = self.all_assign_target[target] + # Skip the docstring (will be added later on, at the beginning) + elif m.matches(stmt, DOCSTRING_NODE): + continue comment_less_code = re.sub(r"#.*", "", self.python_module.code_for_node(stmt)).strip() comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip() deduplicated_new_body.append(stmt) @@ -456,17 +654,41 @@ def update_body(self, existing_body, new_statements): code = self.python_module.code_for_node(node) comment_less_code = re.sub(r"#.*", "", code).strip() comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip() - if ( - node not in deduplicated_new_body - and "super().__init__" not in comment_less_code - and comment_less_code not in existing_nodes - ): + if node not in deduplicated_new_body and comment_less_code not in existing_nodes: if not m.matches(node, m.SimpleStatementLine(body=[m.Del()])): - # HACK here to fix the pos_init() that has to be last we kinda do this. - deduplicated_new_body = deduplicated_new_body[:-1] + [node] + deduplicated_new_body[-1:] + deduplicated_new_body.append(node) existing_nodes.add(comment_less_code) + + # Fix the post_init() that has to be last + for i, node in enumerate(deduplicated_new_body): + code = self.python_module.code_for_node(node) + comment_less_code = re.sub(r"#.*", "", code).strip() + comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip() + if "self.post_init(" in comment_less_code and i < len(deduplicated_new_body) - 1: + # Remove it and add it again at the end + deduplicated_new_body.pop(i) + deduplicated_new_body.append(node) + break + return deduplicated_new_body + def _fix_init_location(self, new_body): + """Fix the location of the super()__init__ in the new body, if we had new statements before it.""" + start_index = 0 + for i, node in enumerate(new_body): + if m.matches(node, DOCSTRING_NODE) and i == start_index: + start_index += 1 + continue + code = self.python_module.code_for_node(node) + comment_less_code = re.sub(r"#.*", "", code).strip() + comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip() + if "super().__init__" in comment_less_code and i > start_index: + # Remove it and add it again at the top after the docstrings + node = new_body.pop(i) + new_body = new_body[:start_index] + [node] + new_body[start_index:] + break + return new_body + def replace_super_calls(self, node: cst.IndentedBlock, func_name: str) -> cst.CSTNode: """Updates the body of the input `node`'s `func_name` function by replacing calls to super().func_name() with the source code of the parent class' `func_name`. @@ -479,10 +701,11 @@ def replace_super_calls(self, node: cst.IndentedBlock, func_name: str) -> cst.CS new_body = [] has_super_call = False - for expr in node.body: + for i, expr in enumerate(node.body): if is_call_to_super(expr, func_name): has_super_call = True - new_body.extend(self.update_body(self.original_methods[func_name].body.body, node.body)) + new_body.extend(self.update_body(self.original_methods[func_name].body.body, node.body[i + 1 :])) + new_body = self._fix_init_location(new_body) else: expr = expr.visit(self.transformer) if m.matches(expr, DOCSTRING_NODE): @@ -524,8 +747,8 @@ def leave_Return(self, original_node: cst.Return, updated_node: cst.Return) -> c return updated_node -def replace_call_to_super( - class_finder: ClassFinder, updated_node: cst.ClassDef, class_name: str, all_bases: List[str] +def replace_class_node( + mapper: ModelFileMapper, updated_node: cst.ClassDef, class_name: str, all_bases: List[str] ): """ Given the `class_name`, the `updated_node`'s call to super are unpacked. @@ -547,13 +770,13 @@ def replace_call_to_super( | self.post_init() | ``` """ - original_node = class_finder.classes[class_name] + original_node = mapper.classes[class_name] original_methods = { - f.name.value if hasattr(f, "name") else class_finder.python_module.code_for_node(f): f + f.name.value if hasattr(f, "name") else mapper.python_module.code_for_node(f): f for f in original_node.body.body } updated_methods = { - f.name.value if hasattr(f, "name") else class_finder.python_module.code_for_node(f): f + f.name.value if hasattr(f, "name") else mapper.python_module.code_for_node(f): f for f in updated_node.body.body } end_meth = [] @@ -562,7 +785,7 @@ def replace_call_to_super( docstring_node = [] # Iterate directly from node.body as there can be property/setters with same names which are overwritten when we use a dict for func in original_node.body.body: - name = func.name.value if hasattr(func, "name") else class_finder.python_module.code_for_node(func) + name = func.name.value if hasattr(func, "name") else mapper.python_module.code_for_node(func) if m.matches(func, m.FunctionDef()) and name in updated_methods and updated_methods[name] is not None: new_params = updated_methods[name].params # Replace the method in the replacement class, preserving decorators @@ -573,19 +796,23 @@ def replace_call_to_super( new_params = new_params.with_changes( params=list(parent_params.values()), star_kwarg=func.params.star_kwarg ) + # Keep decorators in `modular_xxx.py` if any, else original decorators + new_decorators = ( + updated_methods[name].decorators if len(updated_methods[name].decorators) > 0 else func.decorators + ) if not re.match( r"\ndef .*\(.*\):\n raise.*Error\(.*", - class_finder.python_module.code_for_node(updated_methods[name]), + mapper.python_module.code_for_node(updated_methods[name]), ): - func = func.with_changes(body=updated_methods[name].body, params=new_params) + func = func.with_changes(body=updated_methods[name].body, params=new_params, decorators=new_decorators) else: continue if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])): - target = class_finder.python_module.code_for_node(func.body[0].targets[0]) + target = mapper.python_module.code_for_node(func.body[0].targets[0]) assign_targets[target] = func elif m.matches(func, m.SimpleStatementLine(body=[m.AnnAssign()])): - target = class_finder.python_module.code_for_node(func.body[0].target) + target = mapper.python_module.code_for_node(func.body[0].target) assign_targets[target] = func elif m.matches(func, DOCSTRING_NODE): docstring_node = [func] @@ -594,7 +821,7 @@ def replace_call_to_super( # Port new methods that are defined only in modular-file and append at the end for func in updated_node.body.body: - name = func.name.value if hasattr(func, "name") else class_finder.python_module.code_for_node(func) + name = func.name.value if hasattr(func, "name") else mapper.python_module.code_for_node(func) if m.matches(func, DOCSTRING_NODE): # This processes the docstring of the class! # Extract the original docstring updated_docstring = func.body[0].value.value @@ -608,10 +835,10 @@ def replace_call_to_super( end_meth.append(func) if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])): # TODO we only use single assign might cause issues - target = class_finder.python_module.code_for_node(func.body[0].targets[0]) + target = mapper.python_module.code_for_node(func.body[0].targets[0]) assign_targets[target] = func if m.matches(func, m.SimpleStatementLine(body=[m.AnnAssign()])): - target = class_finder.python_module.code_for_node(func.body[0].target) + target = mapper.python_module.code_for_node(func.body[0].target) assign_targets[target] = func end_meth = docstring_node + list(assign_targets.values()) + end_meth @@ -623,7 +850,13 @@ def replace_call_to_super( ) new_replacement_body = new_replacement_class.body[0].body # get the indented block - return original_node.with_changes(body=new_replacement_body) + # Use decorators redefined in `modular_xxx.py` if any + new_decorators = updated_node.decorators if len(updated_node.decorators) > 0 else original_node.decorators + + # Always use the new name of the class (in case we use e.g. `ColPaliForRetrieval` inheriting from `PaliGemmaForConditionalGeneration`) + name = updated_node.name + + return original_node.with_changes(body=new_replacement_body, decorators=new_decorators, name=name) TYPE_TO_FILE_TYPE = { @@ -660,119 +893,26 @@ def get_new_part(class_name, base_class): return snake_case -def find_all_dependencies(function: str, dependency_mapping: Dict[str, set]): - """Return all the dependencies of the given top-level function. Given the following structure in the `modular_xxx.py` file: - ``` - def foo1(): - pass - - def foo2(): - pass - - def bar(): - foo1() - - def foobar(): - bar() - foo2() - - class MyLayer(SomeOtherModelLayer): - def forward(...): - foobar() - ``` - and the `dependency_mapping` created when visiting the `modular_xxx.py` file, we get: - ``` - dependency_mapping = {'bar': {'foo1'}, 'foobar': {'bar', 'foo2'}} - find_all_dependencies('foobar', dependency_mapping) - >>> [('bar', 'foobar'), ('foo2', 'foobar'), ('foo1', 'bar')] - ``` - That is, all the functions needed (and their immediate parent) so that the function to be added in MyLayer (`foobar`) can - work correctly. - """ - all_dependencies = deque(dependency_mapping[function]) - all_dependencies_with_parent = [(dep, function) for dep in dependency_mapping[function]] - checked_dependencies = set(function) - while len(all_dependencies) > 0: - # Pick element to visit - parent = all_dependencies.popleft() - if parent not in checked_dependencies: - # Update dependencies - all_dependencies.extend(dependency_mapping[parent]) - all_dependencies_with_parent += [(dependency, parent) for dependency in dependency_mapping[parent]] - # add visited node to the list - checked_dependencies.add(parent) +# These top-level variables will always appear the very beginning of the file, in the order they are defined in +# this list (this is to avoid having variables at weird places, even if they are not used before) +VARIABLES_AT_THE_BEGINNING = [ + "logger", + "_CHECKPOINT_FOR_DOC", + "_CONFIG_FOR_DOC", +] - # no child can ever appear before its parent thanks to the queue (needed to add them at the correct location in the body later) - return all_dependencies_with_parent - - -class PostModularConverterCleaner(CSTTransformer): - """Allow simple cleaning after conversion. Remove top-level functions/classes without any calls (they may arise due - to dependency mapping, even if code parts with those functions/classes were overwritten)""" - - METADATA_DEPENDENCIES = (ParentNodeProvider,) - - def __init__(self, added_dependencies: set): - super().__init__() - self.top_level_functions_or_classes = {} - self.all_used_functions_or_classes = set() - self.added_dependencies = added_dependencies - - def visit_FunctionDef(self, node): - parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) - if m.matches(parent_node, m.Module()): - self.top_level_functions_or_classes[node.name.value] = node - - def visit_ClassDef(self, node): - parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) - if m.matches(parent_node, m.Module()): - self.top_level_functions_or_classes[node.name.value] = node - - def visit_Name(self, node: cst.Name): - """This is used to find any mention of a top-level function or class except its own definition. - It will contain other names as well, but those will not be used. This is the most general way to do it - since mentions may appear in a lot of different contexts (apart from simple Call to the function/class). - e.g. Attention classes are only mentionned by their name in a dict assignment. - """ - parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) - - if not ( - (m.matches(parent_node, m.ClassDef()) and parent_node.name.value == node.value) - or (m.matches(parent_node, m.FunctionDef()) and parent_node.name.value == node.value) - ): - self.all_used_functions_or_classes.add(node.value) - - def leave_Module(self, original_node: cst.Module, node): - # Find any class/function that was mistakenly added as part of the dependencies and remove it - unused = self.added_dependencies - self.all_used_functions_or_classes - nodes_to_remove = [ - self.top_level_functions_or_classes[name] for name in unused if name in self.top_level_functions_or_classes - ] - new_body = [node_ for node_ in original_node.body if node_ not in nodes_to_remove] - # Return a new module with the updated body - return node.with_changes(body=new_body) - - -class ModularConverterTransformer(CSTTransformer): +class ModularFileMapper(ModuleMapper): METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider) def __init__(self, python_module, new_name, given_old_name=None, given_new_name=None): - super().__init__() - self.model_name = ( - new_name # name of the model being defined. Should be in the format of `llama` or `layout_xlm` our `phi3` - ) + super().__init__(python_module) + self.model_name = new_name # name of the model being defined. Should be in the format of `llama` or `layout_xlm` or `phi3` self.given_old_name = given_old_name self.given_new_name = given_new_name - # fmt: off - self.python_module = python_module # we store the original module to use `code_for_node` - self.transformers_imports = {} # maps the imports name like "from transformers.models.xxx" to the parsed AST module - self.imported_mapping = {} # stores the name of the imported classes, with their source {"LlamaModel":"transformers.model.llama.modeling_llama"} - self.visited_module = {} # modules visited like "transformers.models.llama.modeling_llama" - self.inserted_deps = [] # nodes inserted via super dependency - self.all_imports = [] # just stores all of the imports - self.all_safe_imports = [] # stores the import under simple statements - self.global_scope_index = 0 - # fmt: on + + self.model_specific_imported_objects: Dict[str, str] = {} # e.g. {"LlamaModel": "transformers.models.llama.modeling_llama"} + self.model_specific_modules: Dict[str, cst.Module] = {} # e.g. {"transformers.models.llama.modeling_llama": cst.Module} + self.files = { # mapping for different component bodies "modeling": {}, "configuration": {}, @@ -782,15 +922,8 @@ def __init__(self, python_module, new_name, given_old_name=None, given_new_name= "feature_extractor": {}, } self.match_patterns = "|".join(self.files.keys()) - self.all_definitions = {} - self.class_to_file_type = {} - self.current_class = None # keep track of current top-level class during visit - self.current_top_level_function = None # keep track of current top-level function during visit - # Mapping from top-level functions to classes using them - self.function_call_class_mapping = defaultdict(lambda: set()) - # Mapping from top-level functions to other top-level functions dependencies - self.function_call_dependency_mapping = defaultdict(lambda: set()) - self.added_dependencies = set() + self.all_imports = [] + self.all_all_to_add = {} def visit_ImportFrom(self, node: cst.ImportFrom) -> None: """When visiting imports from `transformers.models.xxx` we need to: @@ -799,6 +932,8 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None: 3. Add this import to `self.transformers_imports` as visited to not parse it twice """ import_statement = self.python_module.code_for_node(node.module) + if "auto.modeling_auto" in import_statement: + return if m.matches(node.module, m.Attribute()): for imported_ in node.names: _import = re.search(rf"(transformers\.models\..|..)*\.({self.match_patterns})_.*", import_statement) @@ -808,74 +943,167 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None: raise ValueError( f"You are importing {self.python_module.code_for_node(imported_)} from the modeling file. Import from the `configuration_xxxx.py` file instead" ) - if import_statement not in self.transformers_imports: + if import_statement not in self.model_specific_modules: if "models" not in import_statement: import_statement = "models." + import_statement if "transformers" not in import_statement: import_statement = "transformers." + import_statement source_code = get_module_source_from_name(import_statement) tree = cst.parse_module(source_code) - self.transformers_imports[import_statement] = tree - imported_class = self.python_module.code_for_node(imported_.name) - self.imported_mapping[imported_class] = import_statement + self.model_specific_modules[import_statement] = tree + imported_object = self.python_module.code_for_node(imported_.name) + self.model_specific_imported_objects[imported_object] = import_statement if m.matches(node.module, m.Name()): if "transformers" == import_statement: raise ValueError( f"You are importing from {import_statement} directly using global imports. Import from the correct local path" ) - def leave_SimpleStatementLine(self, original_node, updated_node): - parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node) + def visit_SimpleStatementLine(self, node): + parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) + simple_top_level_assign_structure = m.SimpleStatementLine( + body=[m.Assign(targets=[m.AssignTarget(target=m.Name())])] + ) if m.matches(parent_node, m.Module()): - if m.matches(updated_node, m.SimpleStatementLine(body=[m.Import()])): - if updated_node not in self.all_imports: - self.all_imports.append(updated_node) - return updated_node - elif m.matches(updated_node, m.SimpleStatementLine(body=[m.ImportFrom()])): - full_statement = self.python_module.code_for_node(updated_node.body[0].module) - if re.search( - rf"(transformers\.models\..|..)*\.({self.match_patterns})_.*", full_statement - ): # OR MATCH ..llama.modeling_llama - return cst.RemoveFromParent() - if updated_node not in self.all_imports: - self.all_imports.append(updated_node) - return updated_node - elif m.matches(original_node, m.SimpleStatementLine(body=[m.Assign()])): - if original_node.body[0].targets[0].target.value in ASSIGNMENTS_TO_KEEP.keys(): - file_ = ASSIGNMENTS_TO_KEEP[original_node.body[0].targets[0].target.value] - self.files[file_][original_node.body[0].targets[0].target.value] = { - "node": original_node, - "insert_idx": self.global_scope_index, - } - self.global_scope_index += 100 - return updated_node - - def visit_ClassDef(self, node: cst.ClassDef): - """Used to keep track of current class""" - self.current_class = node.name.value + if m.matches(node, m.SimpleStatementLine(body=[m.Import()])): + if node not in self.all_imports: + self.all_imports.append(node) + elif m.matches(node, m.SimpleStatementLine(body=[m.ImportFrom()])): + full_statement = self.python_module.code_for_node(node.body[0].module) + if ( + # OR MATCH ..llama.modeling_llama + re.search(rf"(transformers\.models\..|..)*\.({self.match_patterns})_.*", full_statement) + and "auto.modeling_auto" not in full_statement + ): + return + if node not in self.all_imports: + self.all_imports.append(node) + elif m.matches(node, simple_top_level_assign_structure): + assigned_variable = node.body[0].targets[0].target.value + # __all__ is treated differently and not added to general assignments + if assigned_variable != "__all__": + self.assignments[assigned_variable] = node + else: + assign_node = node.body[0] + if isinstance(assign_node.value, cst.List): + # Extract the elements from the list + all_all_to_add = defaultdict(list) + for element in assign_node.value.elements: + if isinstance(element.value, cst.SimpleString): + # Remove quotes and add the string to the elements list + class_name = element.value.value + file = self.find_file_type(class_name) + all_all_to_add[file] += [class_name] + for file, new_alls in all_all_to_add.items(): + new_node = assign_node.with_changes( + value=cst.List(elements=[cst.Element(value=cst.SimpleString(value=k)) for k in new_alls]) + ) + self.all_all_to_add[file] = node.with_changes(body=[new_node]) - def leave_ClassDef(self, original_node, updated_node): + def leave_Module(self, node): + """When leaving the module, we finally create the `function_call_recursive_dependency_mapping`, then we + compute the dependencies for all recorded classes based on all the nodes we visited. + We also store the position of each global scoped node to allow sorting the dependencies based on their + position in the code later. We use the PositionProvider metadata wrapper for this. """ - 1. Filter the `base` classes of this class - If they are from `transformers.models.xx` then: - - take the AST tree of the module it comes from and parse it with a `ClassFinder`. - - rename all every instance of `old_name` (llama) to `new_name` (gemma) - 2. We insert the modules which the inherited base depends on. This has to be done in - the order of the dependencies. If on is already in the new_body (because it's defined in the diff file) - then we remove it from the new body to add it again in the correct order. - 3. Replace the calls to `super().xxxx` merging parent code + super().leave_Module(node) + + # Now, visit every model-specific files found in the imports, and merge their dependencies + self.visited_modules = {} + for file, module in self.model_specific_modules.items(): + file_model_name = re.search(r"models\.\w*?\.\w*?_(\S*)", file).groups()[0] + renamer = ReplaceNameTransformer(file_model_name, self.model_name, self.given_old_name, self.given_new_name) + renamed_module = module.visit(renamer) + self.visited_modules[file] = ModelFileMapper.visit_and_merge_dependencies(renamed_module, self.functions, self.function_call_dependency_mapping, + self.assignments, self.start_lines) + + # In turn, we need to add the imported functions/assignments to the dependencies of the modular mapper, using the + # definitions found in the visited files + self.merge_model_specific_imports(self.visited_modules) + + # Re-assign all nodes + self.global_nodes = {**self.assignments, **self.classes, **self.functions} + + def merge_model_specific_imports(self, visited_modules): + # In turn, we need to add the imported functions/assignments to the dependencies of the modular mapper, using the + # definitions found in the visited files + self.start_lines_file_mapping = {} + self.added_objects_file_mapping = {} + for object_name, file in self.model_specific_imported_objects.items(): + visited_module = visited_modules[file] + self.start_lines_file_mapping[file] = visited_module.start_lines + # Add functions and their dependencies + if object_name in visited_module.functions and object_name not in self.functions: + self.functions[object_name] = visited_module.functions[object_name] + self.added_objects_file_mapping[object_name] = file + dependencies = visited_module.function_call_recursive_dependency_mapping.get(object_name, None) + if dependencies is not None: + self.function_call_recursive_dependency_mapping[object_name] = dependencies + for dep in dependencies: + self.added_objects_file_mapping[dep] = file + self.functions[dep] = visited_module.global_nodes[dep] + + # Add assignments + elif object_name in visited_module.assignments and object_name not in self.assignments: + self.added_objects_file_mapping[object_name] = file + self.assignments[object_name] = visited_module.assignments[object_name] + + def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]: + """Compute the relative order that the `missing_dependencies` should have between themselves in the output file. """ - class_name = original_node.name.value - bases = [k.value.value for k in original_node.bases if k.value.value in self.imported_mapping] - all_bases = [k.value.value for k in original_node.bases] - self.global_scope_index += 100 - for super_class in bases: - if super_class not in self.imported_mapping: - raise ImportError( - f"{super_class} was not imported using `from transformers.models.xxxxx.modeling_xxxx import {super_class}" - ) + relative_order = {} + idx = 0 + + original_dependencies = [] + other_files_dependencies = defaultdict(list) + for dep in tuple(missing_dependencies): + if dep in self.added_objects_file_mapping: + file = self.added_objects_file_mapping[dep] + other_files_dependencies[file].append(dep) + else: + original_dependencies.append(dep) + # Sort all lists according to the order in their respective file + all_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines[x]) + for file, dependencies in other_files_dependencies.items(): + sorted_dependencies = sorted(dependencies, key=lambda x: self.start_lines_file_mapping[file][x]) + all_dependencies += sorted_dependencies + + # Add all original node first, then merged ones (one file at a time) + for dep in all_dependencies: + relative_order[dep] = idx + idx += 1 + + return relative_order + + + def find_file_type(self, class_name: str) -> str: + match_pattern = "|".join(TYPE_TO_FILE_TYPE.keys()) + match = re.search(rf"({match_pattern})$", class_name) + if match: + file_type = TYPE_TO_FILE_TYPE[match.group(1)] + else: + file_type = "modeling" + return file_type + + + def add_class_node(self, class_name: str, node: cst.CSTNode, files: dict[str, dict]): + """Add a single class node (and its dependencies), to the `files`.""" + + bases = [k.value.value for k in node.bases if k.value.value in self.model_specific_imported_objects] + if len(bases) > 1: + raise ValueError( + f"{class_name} was defined with more than 1 model-specific super class. This is unsupported. We found {*bases,}." + ) + all_bases = [k.value.value for k in node.bases] + + file_type = self.find_file_type(class_name) + file_to_update = files[file_type] + + # We need to replace the class node with the super class node + if len(bases) == 1: + super_class = bases[0] + super_file_name = self.model_specific_imported_objects[super_class] - super_file_name = self.imported_mapping[super_class] # we need to get the parsed tree model_name = re.search(r"models\.\w*?\.\w*?_(\S*)", super_file_name) if model_name: model_name = model_name.groups()[0] @@ -883,247 +1111,89 @@ def leave_ClassDef(self, original_node, updated_node): raise ValueError( f"Tried parsing the name of the imported package from {super_file_name}, could not extract the model name" ) - file_type = re.search(r"models?\.\w*?\.(\w*?)_", super_file_name).groups()[0] - visited_module = self.visited_module - if super_file_name not in visited_module: # only extract classes once - class_finder = find_classes_in_file( - self.transformers_imports[super_file_name], - model_name, - self.model_name, - self.given_old_name, - self.given_new_name, - ) - visited_module[super_file_name] = class_finder - list_dependencies = { - dep: class_finder.class_start_line.get(dep, 1000) - for dep in class_finder.class_dependency_mapping.get(class_name, []) - } - else: # we are re-using the previously parsed data - class_finder = visited_module[super_file_name] - - list_dependencies = { - dep: class_finder.class_start_line.get(dep, 1000) - for dep in class_finder.class_dependency_mapping.get(class_name, []) - } - if len(list_dependencies) == 0: - # so, maybe standard renaming did not work (the class name is different) - # we try with another renaming pattern - potential_given_name = get_new_part(class_name, super_class) - del visited_module[super_file_name] - class_finder = find_classes_in_file( - self.transformers_imports[super_file_name], - model_name, - potential_given_name, - self.model_name, - potential_given_name, - ) - list_dependencies = { - dep: class_finder.class_start_line.get(dep, 1000) - for dep in class_finder.class_dependency_mapping.get(class_name, []) - } - if len(list_dependencies) == 0: - # last recourse, if the suffix of the new class is different from the one of the super class - # e.g. MyNewClassForSegmentation extends MyOldClassForObjectDetection - # we try with another renaming pattern - class_finder = find_classes_in_file( - self.transformers_imports[super_file_name], - model_name, - self.model_name, - self.given_old_name, - self.given_new_name, - super_class, - class_name, - ) - visited_module[super_file_name] = class_finder - list_dependencies = { - dep: class_finder.class_start_line.get(dep, 1000) - for dep in class_finder.class_dependency_mapping.get(class_name, []) - } - if len(list_dependencies) == 0: - raise ValueError( - f"We were unable to find dependencies for {class_name} (based on inheriting from {super_class})" - f" Here are all the global dependencies that we found in you modular file: {list(class_finder.class_dependency_mapping.keys())}." - f" This usually means that the name of `{class_name}` does not match the pattern of `{super_class}`" - ) - list_dependencies = sorted(list_dependencies.items(), key=lambda x: x[1], reverse=True) - start_insert_idx = self.global_scope_index - file_to_update = self.files[file_type] - is_empty_node = self.python_module.code_for_node(original_node.body) == "pass\n" - for dependency, _ in list_dependencies: - # we can write to the correct body, using the source of the parent class - node = class_finder.global_nodes.get(dependency, None) - if node is not None: - if dependency not in file_to_update: - node = self.all_definitions.pop(dependency, node) - start_insert_idx -= 1 - file_to_update[dependency] = {"insert_idx": start_insert_idx, "node": node} - self.added_dependencies.add(dependency) - elif dependency not in self.inserted_deps: - # make sure the node is written after its dependencies - start_insert_idx = file_to_update[dependency]["insert_idx"] - 1 - if ( - dependency in file_to_update.keys() - and dependency in class_finder.first_lvl_dependency_mapping[class_name] - ): - # If dependency is defined, but not used, raise error - calls = m.findall(original_node, m.Call(func=m.Name(dependency))) - if not calls and not is_empty_node and dependency not in all_bases: - raise ValueError( - f"""You defined `{dependency}` in the modular_{self.model_name}.py, it should be used - when you define `{class_name}`, as it is one of it's direct dependencies. Make sure - you use it in the `__init__` function.""" - ) - self.inserted_deps.append(dependency) - - if len(list_dependencies) > 0: - updated_node = replace_call_to_super(class_finder, updated_node, class_name, all_bases) - - # Now, if a class was defined without parents, we look for the name - match_pattern = "|".join(TYPE_TO_FILE_TYPE.keys()) - match = re.search(rf"({match_pattern})$", class_name) - if match: - key = TYPE_TO_FILE_TYPE[match.group(1)] - self.class_to_file_type[class_name] = key - self.files[key][class_name] = {"insert_idx": self.global_scope_index, "node": updated_node} - else: - self.class_to_file_type[class_name] = "modeling" - self.files["modeling"][class_name] = {"insert_idx": self.global_scope_index, "node": updated_node} + # Get the mapper corresponding to the inherited class + mapper = self.visited_modules[super_file_name] - self.current_class = None - return updated_node + # Create the new class node + updated_node = replace_class_node(mapper, node, class_name, all_bases) - def visit_FunctionDef(self, node): - parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) - if m.matches(parent_node, m.Module()): - self.current_top_level_function = node.name.value - - def leave_FunctionDef(self, original_node, node): - parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node) - if m.matches(parent_node, m.Module()): - self.all_definitions[node.name.value] = node - return node - - def visit_Assign(self, node: cst.Assign) -> None: - # Check if the assignment target is '__all__' - if isinstance(node.targets[0].target, cst.Name) and node.targets[0].target.value == "__all__": - if isinstance(node.value, cst.List): - # Extract the elements from the list - all_all_to_add = defaultdict(list) - for elt in node.value.elements: - if isinstance(elt.value, cst.SimpleString): - # Remove quotes and add the string to the elements list - class_name = elt.value.value - file = self.class_to_file_type[ - elt.value.evaluated_value - ] # evaluated value give the content of the string - all_all_to_add[file] += [class_name] - for f_type, new_alls in all_all_to_add.items(): - updated_node = node.with_changes( - value=cst.List(elements=[cst.Element(value=cst.SimpleString(value=k)) for k in new_alls]) - ) - self.files[f_type][class_name] = { - "insert_idx": self.global_scope_index + 100, - "node": updated_node, - } - - def leave_If(self, original_node, node): - parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node) - if m.matches(parent_node, m.Module()): - full_statement = self.python_module.code_for_node(original_node.test) - if re.search(r"[\s\S]*is_.*available", full_statement): - self.all_safe_imports.append(node) - elif full_statement not in self.all_imports: - logger.warning(f"one import is protected with `if`. Hard guess where it's used {full_statement}") - return node - - def visit_Call(self, node: cst.Call): - """This is used to create a mapping from functions to class calling them, and from top-level functions to functions called inside them. - Important note: we only rely on direct Call to the functions here, not indirect mentions (such as assigning a variable with the function, - add calling the variable later). This should be enough as the `modular_xxx` and `modeling_xxx` structures should be as simple as possible.""" - # Only map function calls if we're inside a class (i.e., current_class is set) - if self.current_class is not None: - # Simple function calls such as foo() - if isinstance(node.func, cst.Name): - self.function_call_class_mapping[node.func.value].add(self.current_class) - elif self.current_top_level_function is not None: - # Simple function calls such as foo() - if isinstance(node.func, cst.Name): - self.function_call_dependency_mapping[self.current_top_level_function].add(node.func.value) + # The node was modified -> look for all dependencies (recursively) of the new node + new_node_dependencies = ClassDependencyMapper.dependencies_for_new_node(updated_node, mapper) + all_dependencies_to_add = find_all_dependencies( + dependency_mapping=mapper.class_dependency_mapping, + initial_dependencies=new_node_dependencies, + initial_checked_dependencies=set(file_to_update.keys()), + ) - def _maybe_add_function_to_body( - self, - top_level_function: str, - body: dict, - function_node: cst.FunctionDef, - matching_callers: Optional[set] = None, - parent: Optional[str] = None, - ) -> bool: - """Check if the `top_level_function` should be added to the body (i.e. it is not already present, and `matching_callers` - is not empy, or `parent`is provided). If it should be added, do it (in the correct location, just before its caller) and return - `True`. Return `False` otherwise. - """ - if matching_callers is None and parent is None: - raise ValueError("Cannot add function if both the parent and the matching callers are None.") - if matching_callers is None: - matching_callers = {parent} - if len(matching_callers) > 0 and top_level_function not in body.keys(): - # Add the function just before the first class using it - new_idx = min([body[element]["insert_idx"] for element in matching_callers]) - # Reorder the elements - for element in body.keys(): - if body[element]["insert_idx"] >= new_idx: - body[element]["insert_idx"] += 1 - # Assign new element to body (after changing the count to avoid messing it) - body[top_level_function] = {"insert_idx": new_idx, "node": function_node} - return True - return False - - def _recursively_add_all_new_needed_functions_in_files(self): - """For all top-level functions which were newly defined in the `modular_xxx.py`, check if they are used in a class in - the different files, and add them to the file if it is the case (also recursively adding all other functions that - may be needed in that function body).""" - # At this point, `self.all_definitions` only contains newly defined top-level functions in the `modualr_xxx.py` - for top_level_function, function_node in self.all_definitions.items(): - calling_entities = self.function_call_class_mapping[top_level_function] - # The function may be needed in different files, we need to iterate on them - for file, body in self.files.items(): - file_elements = set(body.keys()) - # If the intersection is not null, top_level_func must be added to file - matching_callers = calling_entities & file_elements - added = self._maybe_add_function_to_body(top_level_function, body, function_node, matching_callers) - # If the function was added, we need to recursively add all its dependencies - if added: - for dependency, parent in find_all_dependencies( - top_level_function, self.function_call_dependency_mapping - ): - self._maybe_add_function_to_body( - dependency, body, self.all_definitions[dependency], parent=parent - ) + relative_dependency_order = mapper.compute_relative_order(all_dependencies_to_add) + nodes_to_add = { + dep: (relative_dependency_order[dep], mapper.global_nodes[dep]) + for dep in all_dependencies_to_add + } - def leave_Module(self, original_node: cst.Module, node): + # No super class, just check functions and assignments dependency in the imports from other model-specific files + else: + updated_node = node + # The node was NOT modified -> no need to look for recursive dependencies + all_dependencies_to_add = ClassDependencyMapper.dependencies_for_node(updated_node, self.global_nodes) + + relative_dependency_order = self.compute_relative_order(all_dependencies_to_add) + nodes_to_add = { + dep: (relative_dependency_order[dep], self.global_nodes[dep]) + for dep in all_dependencies_to_add + } + + # Add the class node itself to the nodes to add + class_idx = max(relative_dependency_order.values()) + 1 if len(relative_dependency_order) > 0 else 0 + nodes_to_add[class_name] = (class_idx, updated_node) + + return nodes_to_add, file_type + + + def create_files(self) -> dict[str, cst.Module]: + + files = defaultdict(dict) + current_file_indices = defaultdict(lambda: 0) + + # For each class defined in modular, potentially replace the node and add it with its dependencies + for class_name, node in self.classes.items(): + nodes_to_add, file_type = self.add_class_node(class_name, node, files) + # Sort the nodes according to their relative order + nodes_to_add = sorted(nodes_to_add.items(), key=lambda x: x[1][0]) + # Write all nodes to file + for dependency, (_, node) in nodes_to_add: + # This is used to keep certain variables at the beginning of the file + try: + # The -1000 is arbitrary -> just keep it bigger than the list + idx = -1000 + VARIABLES_AT_THE_BEGINNING.index(dependency) + except ValueError: + idx = current_file_indices[file_type] + current_file_indices[file_type] += 1 + files[file_type][dependency] = {"insert_idx": idx, "node": node} + + # Add the __all__ statement to files + for file_type, node in self.all_all_to_add.items(): + idx = current_file_indices[file_type] + files[file_type]["__all__"] = {"insert_idx": idx, "node": node} + + # Merge imports + # TODO: use scope solution instead imports = {self.python_module.code_for_node(k): k for k in self.all_imports} - dependency_imports = {file_type: imports.copy() for file_type in self.files} - for super_file_name, visiter in self.visited_module.items(): + dependency_imports = {file_type: imports.copy() for file_type in files} + for super_file_name, visiter in self.visited_modules.items(): file_type = re.search(r"models?\.\w*?\.(\w*?)_", super_file_name).groups()[0] dependency_imports[file_type].update( {self.python_module.code_for_node(k): k for k in visiter.imports.values()} ) - # Check if any new top-level function from the `modular_xxx.py` should be added to the different files - # (if it is called in a class in the file, then it will be copy pasted from `modular.py` to that file). - self._recursively_add_all_new_needed_functions_in_files() - - for file, body in self.files.items(): + for file, body in files.items(): new_body = [k[1]["node"] for k in sorted(body.items(), key=lambda x: x[1]["insert_idx"])] - if len(new_body) > 0: - if file in dependency_imports.keys(): - new_body = list(dependency_imports[file].values()) + new_body - new_module = cst.Module(body=[*new_body], header=node.header) - # Final cleanup - new_module = MetadataWrapper(new_module).visit(PostModularConverterCleaner(self.added_dependencies)) - self.files[file] = new_module - return node + new_body = list(dependency_imports[file].values()) + new_body + new_module = cst.Module(body=[*new_body], header=self.python_module.header) + files[file] = new_module + + return files def convert_modular_file(modular_file, old_model_name=None, new_model_name=None, cst_transformers=None): @@ -1137,9 +1207,9 @@ def convert_modular_file(modular_file, old_model_name=None, new_model_name=None, module = cst.parse_module(code) wrapper = MetadataWrapper(module) if cst_transformers is None: - cst_transformers = ModularConverterTransformer(module, model_name, old_model_name, new_model_name) + cst_transformers = ModularFileMapper(module, model_name, old_model_name, new_model_name) wrapper.visit(cst_transformers) - for file, node in cst_transformers.files.items(): + for file, node in cst_transformers.create_files().items(): if node != {}: # Get relative path starting from src/transformers/ relative_path = re.search( @@ -1180,7 +1250,7 @@ def save_modeling_file(modular_file, converted_file): parser = argparse.ArgumentParser() parser.add_argument( "--files_to_parse", - default=["src/transformers/models/roberta/modular_roberta.py"], + default=["src/transformers/models/glm/modular_glm.py"], nargs="+", help="A list of `modular_xxxx` files that should be converted to single model file", ) From 31809d10a7c536fe7078a6cdb1bdfb10f1efedb4 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 28 Oct 2024 14:32:12 +0100 Subject: [PATCH 02/40] Update modular_model_converter.py --- utils/modular_model_converter.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index 966054cbe22..350315656a3 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -185,7 +185,7 @@ def dependencies_for_new_node(cls, updated_node: cst.ClassDef, mapper: "ModuleMa class ModuleMapper(CSTVisitor, ABC): """A visitor class which analyses a module, creating a mapping of dependencies for classes and functions. - The `full_class_dependency_mapping` created contains 1st-level classes and assignments dependencies, as well + The `class_dependency_mapping` created contains 1st-level classes and assignments dependencies, as well as all (recursively) functions dependencies. The `function_call_recursive_dependecy_mapping` created contains all function definitions, and all their (recursively) dependencies. @@ -333,17 +333,17 @@ def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]: """ relative_order = {} idx = 0 - classes = sorted([dep for dep in tuple(missing_dependencies) if dep in self.classes], key=lambda x: self.start_line[x]) + classes = sorted([dep for dep in tuple(missing_dependencies) if dep in self.classes], key=lambda x: self.start_lines[x]) # This is because for merged dependencies, we only have relative order in the other visited file, so we need # to track dependency order relative to a given class - if len(classes) > 0 and not hasattr(self, "full_class_dependency_mapping"): + if len(classes) > 0 and not hasattr(self, "class_dependency_mapping"): raise ValueError("Cannot correctly find the relative order of the dependencies.") remaining_dependencies = missing_dependencies.copy() # Start by tracking relative order class by class for class_name in classes: - class_dependencies = tuple(self.full_class_dependency_mapping[class_name] & remaining_dependencies) + class_dependencies = tuple(self.class_dependency_mapping[class_name] & remaining_dependencies) original_dependencies = [] merged_dependencies = [] # We need to differentiate between nodes that were already present (we can get relative order globally) and @@ -1008,6 +1008,8 @@ def leave_Module(self, node): """ super().leave_Module(node) + self.function_call_recursive_dependency_mapping = self._compute_recursive_function_dependencies() + # Now, visit every model-specific files found in the imports, and merge their dependencies self.visited_modules = {} for file, module in self.model_specific_modules.items(): From b7acc3544af177a1d06293ff5695c9f90ce1b643 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 28 Oct 2024 15:22:00 +0100 Subject: [PATCH 03/40] Update modular_model_converter.py --- utils/modular_model_converter.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index 350315656a3..ecb2dbab7a8 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -154,6 +154,8 @@ class ClassDependencyMapper(CSTVisitor): The `ClassFinder` uses it to get dependencies of classes. """ + METADATA_DEPENDENCIES = (ParentNodeProvider,) + def __init__(self, class_name: str, global_names: set | None): super().__init__() self.class_name = class_name @@ -162,14 +164,18 @@ def __init__(self, class_name: str, global_names: set | None): def visit_Name(self, node): if node.value != self.class_name and node.value in self.global_names: - self.dependencies.add(node.value) + parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) + # If it is only an annotation, do not add dependency + if not m.matches(parent_node, m.Annotation()): + self.dependencies.add(node.value) @classmethod def dependencies_for_node(cls, node: cst.ClassDef, global_names: set) -> set: """Create dependencies for a node in the `ModuleMapper`.""" temp_module = cst.Module(body=[node]) + wrapper = MetadataWrapper(temp_module) visitor = cls(node.name.value, global_names) - temp_module.visit(visitor) + wrapper.visit(visitor) return visitor.dependencies @classmethod @@ -178,8 +184,9 @@ def dependencies_for_new_node(cls, updated_node: cst.ClassDef, mapper: "ModuleMa `replace_call_to_super`). """ temp_module = cst.Module(body=[updated_node]) + wrapper = MetadataWrapper(temp_module) visitor = cls(updated_node.name.value, set(mapper.global_nodes.keys())) - temp_module.visit(visitor) + wrapper.visit(visitor) return mapper.augment_dependencies_with_functions(visitor.dependencies) @@ -1143,7 +1150,7 @@ def add_class_node(self, class_name: str, node: cst.CSTNode, files: dict[str, di relative_dependency_order = self.compute_relative_order(all_dependencies_to_add) nodes_to_add = { dep: (relative_dependency_order[dep], self.global_nodes[dep]) - for dep in all_dependencies_to_add + for dep in all_dependencies_to_add if dep not in file_to_update.keys() } # Add the class node itself to the nodes to add @@ -1252,7 +1259,7 @@ def save_modeling_file(modular_file, converted_file): parser = argparse.ArgumentParser() parser.add_argument( "--files_to_parse", - default=["src/transformers/models/glm/modular_glm.py"], + default=["src/transformers/models/gemma/modular_gemma.py"], nargs="+", help="A list of `modular_xxxx` files that should be converted to single model file", ) From 2ca25c224171054804bf1ee98ded6c8a78ea9043 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 28 Oct 2024 15:32:57 +0100 Subject: [PATCH 04/40] Update modular_model_converter.py --- utils/modular_model_converter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index ecb2dbab7a8..3cea396628e 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -999,7 +999,7 @@ def visit_SimpleStatementLine(self, node): if isinstance(element.value, cst.SimpleString): # Remove quotes and add the string to the elements list class_name = element.value.value - file = self.find_file_type(class_name) + file = self.find_file_type(element.value.evaluated_value) all_all_to_add[file] += [class_name] for file, new_alls in all_all_to_add.items(): new_node = assign_node.with_changes( From aaee9ae787918ab73618ea863b1e3f4eefadc9fa Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 28 Oct 2024 15:55:37 +0100 Subject: [PATCH 05/40] Update modular_model_converter.py --- utils/modular_model_converter.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index 3cea396628e..5defa0b5f22 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -920,15 +920,7 @@ def __init__(self, python_module, new_name, given_old_name=None, given_new_name= self.model_specific_imported_objects: Dict[str, str] = {} # e.g. {"LlamaModel": "transformers.models.llama.modeling_llama"} self.model_specific_modules: Dict[str, cst.Module] = {} # e.g. {"transformers.models.llama.modeling_llama": cst.Module} - self.files = { # mapping for different component bodies - "modeling": {}, - "configuration": {}, - "tokenization": {}, - "processing": {}, - "image_processing": {}, - "feature_extractor": {}, - } - self.match_patterns = "|".join(self.files.keys()) + self.match_patterns = "|".join(list(TYPE_TO_FILE_TYPE.values()).append("modeling")) self.all_imports = [] self.all_all_to_add = {} From 2d26196e35149f12fcc47556b9469a02e2c6e465 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 28 Oct 2024 17:03:47 +0100 Subject: [PATCH 06/40] cleaning --- utils/modular_model_converter.py | 137 ++++++++++++++----------------- 1 file changed, 61 insertions(+), 76 deletions(-) diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index 5defa0b5f22..78c5dc3ad0d 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -148,12 +148,9 @@ def forward(...): } class ClassDependencyMapper(CSTVisitor): - """A visitor which analyzes classes to get their dependencies. If `global_names` is passed, only dependencies - present in the `global_names` will be added. - This class is used through the 3 convenient class methods allowing to get the dependencies of a given class node. - The `ClassFinder` uses it to get dependencies of classes. + """A visitor which is designed to analyze a single class node to get all its dependencies that are mutual with `global_names`. + This class is used through the 2 convenient class methods. """ - METADATA_DEPENDENCIES = (ParentNodeProvider,) def __init__(self, class_name: str, global_names: set | None): @@ -191,39 +188,20 @@ def dependencies_for_new_node(cls, updated_node: cst.ClassDef, mapper: "ModuleMa class ModuleMapper(CSTVisitor, ABC): - """A visitor class which analyses a module, creating a mapping of dependencies for classes and functions. - The `class_dependency_mapping` created contains 1st-level classes and assignments dependencies, as well - as all (recursively) functions dependencies. - The `function_call_recursive_dependecy_mapping` created contains all function definitions, and all their (recursively) - dependencies. - - For example if the visited code has - ```python3 - def init_value(): return 1 - - class LlamaModel(PreTrainedModel): - def __init__(self): - super().__init__(self) - self.value = init_value() - ``` - then the `class_dependency_mapping` should be: `{"LlamaModel": {"PreTrainedModel", "init_value"}} - - The dependency mapping is updated via the `ClassDependencyMapper`, then augmented with `augment_dependencies_with_functions` - to get all functions dependencies. + """An abstract visitor class which analyses a module, creating a mapping of dependencies for classes and functions. + It defines common visiting patterns between the modular file and the model-specific modules that are imported in the modular file. """ - - METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider) + METADATA_DEPENDENCIES = (ParentNodeProvider, PositionProvider) def __init__(self, python_module: cst.Module): # fmt: off - self.python_module: cst.Module = python_module # original cst.Module being visited - self.classes: Dict[str, cst.ClassDef] = {} # mapping from class names to Nodes - self.imports = {} # stores all import statements - self.functions = {} # mapping of global scope function names to Nodes - self.assignments = {} # mapping of global assignments names to Nodes - self.class_dependency_mapping = {} # mapping of function name to immediate dependencies + self.python_module: cst.Module = python_module # original cst.Module being visited + self.classes: Dict[str, cst.ClassDef] = {} # mapping from class names to Nodes + self.imports = [] # stores all import statements + self.functions: Dict[str, cst.FunctionDef] = {} # mapping of global scope function names to Nodes + self.function_call_dependency_mapping = defaultdict(set) # 1st-level function dependency mapping + self.assignments: Dict[str, cst.SimpleStatementLine] = {} # mapping of global assignments names to Nodes self.current_function = None - self.function_call_dependency_mapping = defaultdict(set) # fmt: on def visit_SimpleStatementLine(self, node): @@ -241,7 +219,7 @@ def visit_SimpleStatementLine(self, node): for idx, target in enumerate(list(left_hand_side.elements)): self.assignments[target.value.value] = node.body[0].value.elements[idx].value if m.matches(node, m.SimpleStatementLine(body=[m.Import() | m.ImportFrom()])): - self.imports[node.body[0].names] = node + self.imports.append(node) def visit_FunctionDef(self, node): parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) @@ -257,7 +235,7 @@ def leave_FunctionDef(self, node): def leave_If(self, node): for stmt in node.body.body: if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])): - self.imports[stmt.body[0].names] = node + self.imports.append(node) def visit_ClassDef(self, node: ClassDef) -> None: """Record class nodes to create their dependencies at the end.""" @@ -320,6 +298,7 @@ def compute_class_dependencies(self): # Create the global mapping of recursive dependencies for functions self.function_call_recursive_dependency_mapping = self._compute_recursive_function_dependencies() + self.class_dependency_mapping = {} for class_name, class_node in self.classes.items(): dependencies = ClassDependencyMapper.dependencies_for_node(class_node, set(self.global_nodes.keys())) # Corretcly augment class dependencies with all needed functions @@ -331,6 +310,10 @@ def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]: class ModelFileMapper(ModuleMapper): + """A mapper designed for model-specific files (i.e. a `transformers.models.xxx` file). When encountering such a file + in the `modular_xxx.py` file, we need to correctly visit it and merge the dependencies of the modular and current file. + For this reason, this class should only be instantiated from the class method `visit_and_merge_dependencies`, which takes + care of correctly merging dependencies, then finalizes all dependency graph computations.""" def __init__(self, python_module: cst.Module): super().__init__(python_module) @@ -900,6 +883,16 @@ def get_new_part(class_name, base_class): return snake_case +def find_file_type(class_name: str) -> str: + """Based on a class name, find the file type corresponding to the class.""" + match_pattern = "|".join(TYPE_TO_FILE_TYPE.keys()) + match = re.search(rf"({match_pattern})$", class_name) + if match: + file_type = TYPE_TO_FILE_TYPE[match.group(1)] + else: + file_type = "modeling" + return file_type + # These top-level variables will always appear the very beginning of the file, in the order they are defined in # this list (this is to avoid having variables at weird places, even if they are not used before) VARIABLES_AT_THE_BEGINNING = [ @@ -909,8 +902,10 @@ def get_new_part(class_name, base_class): ] class ModularFileMapper(ModuleMapper): - METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider) - + """This is a Mapper for a modular file. It visits the whole file, recording dependency, then visits all model-specific + files that should be visited, and manages their mutual dependencies. + Calling the method `create_modules()` after visit will create all modules based on this modular file. + """ def __init__(self, python_module, new_name, given_old_name=None, given_new_name=None): super().__init__(python_module) self.model_name = new_name # name of the model being defined. Should be in the format of `llama` or `layout_xlm` or `phi3` @@ -921,14 +916,11 @@ def __init__(self, python_module, new_name, given_old_name=None, given_new_name= self.model_specific_modules: Dict[str, cst.Module] = {} # e.g. {"transformers.models.llama.modeling_llama": cst.Module} self.match_patterns = "|".join(list(TYPE_TO_FILE_TYPE.values()).append("modeling")) - self.all_imports = [] self.all_all_to_add = {} def visit_ImportFrom(self, node: cst.ImportFrom) -> None: - """When visiting imports from `transformers.models.xxx` we need to: - 1. Get the original source code - 2. Parse it into an AST Tree - 3. Add this import to `self.transformers_imports` as visited to not parse it twice + """When visiting imports from model-specific files (i.e. `transformers.models.xxx`) we get the code, parse it, + and record it in `self.model_specific_modules`. The imported objects are recorded in `self.model_specific_imported_objects`. """ import_statement = self.python_module.code_for_node(node.module) if "auto.modeling_auto" in import_statement: @@ -959,14 +951,17 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None: ) def visit_SimpleStatementLine(self, node): + """If we visit an import statement not previously visited, record it. If we visit a top-level assignment, + simply record it or, if it is `__all__`, split it between files where we should dispatch it. + """ parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) simple_top_level_assign_structure = m.SimpleStatementLine( body=[m.Assign(targets=[m.AssignTarget(target=m.Name())])] ) if m.matches(parent_node, m.Module()): if m.matches(node, m.SimpleStatementLine(body=[m.Import()])): - if node not in self.all_imports: - self.all_imports.append(node) + if node not in self.imports: + self.imports.append(node) elif m.matches(node, m.SimpleStatementLine(body=[m.ImportFrom()])): full_statement = self.python_module.code_for_node(node.body[0].module) if ( @@ -975,8 +970,8 @@ def visit_SimpleStatementLine(self, node): and "auto.modeling_auto" not in full_statement ): return - if node not in self.all_imports: - self.all_imports.append(node) + if node not in self.imports: + self.imports.append(node) elif m.matches(node, simple_top_level_assign_structure): assigned_variable = node.body[0].targets[0].target.value # __all__ is treated differently and not added to general assignments @@ -991,7 +986,7 @@ def visit_SimpleStatementLine(self, node): if isinstance(element.value, cst.SimpleString): # Remove quotes and add the string to the elements list class_name = element.value.value - file = self.find_file_type(element.value.evaluated_value) + file = find_file_type(element.value.evaluated_value) all_all_to_add[file] += [class_name] for file, new_alls in all_all_to_add.items(): new_node = assign_node.with_changes( @@ -1000,10 +995,11 @@ def visit_SimpleStatementLine(self, node): self.all_all_to_add[file] = node.with_changes(body=[new_node]) def leave_Module(self, node): - """When leaving the module, we finally create the `function_call_recursive_dependency_mapping`, then we - compute the dependencies for all recorded classes based on all the nodes we visited. - We also store the position of each global scoped node to allow sorting the dependencies based on their - position in the code later. We use the PositionProvider metadata wrapper for this. + """When we leave the modular file, we do the following in order: + - compute recursive function dependencies + - for each model-specific file found in the imports, rename it with the new model name, visit it, and update + its dependency graph with the new function and assignment definitions found in the modular + - update the modular dependency graph with the imported functions and assignments (found when visiting the matching files) """ super().leave_Module(node) @@ -1026,8 +1022,8 @@ def leave_Module(self, node): self.global_nodes = {**self.assignments, **self.classes, **self.functions} def merge_model_specific_imports(self, visited_modules): - # In turn, we need to add the imported functions/assignments to the dependencies of the modular mapper, using the - # definitions found in the visited files + """Merge the model-specific imported functions and assignments to the modular nodes and dependency graph, + based on the visited files.""" self.start_lines_file_mapping = {} self.added_objects_file_mapping = {} for object_name, file in self.model_specific_imported_objects.items(): @@ -1075,21 +1071,11 @@ def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]: idx += 1 return relative_order - - - def find_file_type(self, class_name: str) -> str: - match_pattern = "|".join(TYPE_TO_FILE_TYPE.keys()) - match = re.search(rf"({match_pattern})$", class_name) - if match: - file_type = TYPE_TO_FILE_TYPE[match.group(1)] - else: - file_type = "modeling" - return file_type - - - def add_class_node(self, class_name: str, node: cst.CSTNode, files: dict[str, dict]): - """Add a single class node (and its dependencies), to the `files`.""" + def add_class_node(self, class_name: str, node: cst.CSTNode, files: dict[str, dict]) -> tuple[dict, str]: + """Return a single class node (and all its dependency nodes), to be added to the `files`. It creates the new + class node based on the inherited classes if needed. + """ bases = [k.value.value for k in node.bases if k.value.value in self.model_specific_imported_objects] if len(bases) > 1: raise ValueError( @@ -1097,7 +1083,7 @@ def add_class_node(self, class_name: str, node: cst.CSTNode, files: dict[str, di ) all_bases = [k.value.value for k in node.bases] - file_type = self.find_file_type(class_name) + file_type = find_file_type(class_name) file_to_update = files[file_type] # We need to replace the class node with the super class node @@ -1151,9 +1137,8 @@ def add_class_node(self, class_name: str, node: cst.CSTNode, files: dict[str, di return nodes_to_add, file_type - - def create_files(self) -> dict[str, cst.Module]: - + def create_modules(self) -> dict[str, cst.Module]: + """Create all the new modules based on visiting the modular file. It replaces all classes as necesary.""" files = defaultdict(dict) current_file_indices = defaultdict(lambda: 0) @@ -1173,14 +1158,14 @@ def create_files(self) -> dict[str, cst.Module]: current_file_indices[file_type] += 1 files[file_type][dependency] = {"insert_idx": idx, "node": node} - # Add the __all__ statement to files + # Add the __all__ statement to files at the end for file_type, node in self.all_all_to_add.items(): idx = current_file_indices[file_type] files[file_type]["__all__"] = {"insert_idx": idx, "node": node} # Merge imports # TODO: use scope solution instead - imports = {self.python_module.code_for_node(k): k for k in self.all_imports} + imports = {self.python_module.code_for_node(k): k for k in self.imports} dependency_imports = {file_type: imports.copy() for file_type in files} for super_file_name, visiter in self.visited_modules.items(): file_type = re.search(r"models?\.\w*?\.(\w*?)_", super_file_name).groups()[0] @@ -1210,8 +1195,8 @@ def convert_modular_file(modular_file, old_model_name=None, new_model_name=None, if cst_transformers is None: cst_transformers = ModularFileMapper(module, model_name, old_model_name, new_model_name) wrapper.visit(cst_transformers) - for file, node in cst_transformers.create_files().items(): - if node != {}: + for file, module in cst_transformers.create_modules().items(): + if module != {}: # Get relative path starting from src/transformers/ relative_path = re.search( r"(src/transformers/.*|examples/.*)", os.path.abspath(modular_file).replace("\\", "/") @@ -1220,7 +1205,7 @@ def convert_modular_file(modular_file, old_model_name=None, new_model_name=None, header = AUTO_GENERATED_MESSAGE.format( relative_path=relative_path, short_name=os.path.basename(relative_path) ) - ruffed_code = run_ruff(header + node.code, True) + ruffed_code = run_ruff(header + module.code, True) formatted_code = run_ruff(ruffed_code, False) output[file] = [formatted_code, ruffed_code] return output From 2c675f251cf3bd56d656c5db5a1e74a77abe47b0 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 28 Oct 2024 17:26:39 +0100 Subject: [PATCH 07/40] cleaning --- utils/modular_model_converter.py | 1227 +++++++++++++++--------------- 1 file changed, 615 insertions(+), 612 deletions(-) diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index 78c5dc3ad0d..c66247d4105 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -54,694 +54,695 @@ def get_module_source_from_name(module_name: str) -> str: source_code = file.read() return source_code +class ReplaceNameTransformer(m.MatcherDecoratableTransformer): + """A transformer that replaces `old_name` with `new_name` in comments, string and any references. + It should take into account name like `MyNewModel`, or `my_new_model`. Without using the AUTO_MAPPING. + Supported renaming patterns: + - llama -> my_new_model and my_new_model -> llama + - Llama -> MyNewModel and MyNewModel -> Llama + - LLAMA -> MY_NEW_MODEL and MY_NEW_MODEL -> LLAMA + - LLaMa -> MyNewModel abd MyNewModel -> Llama + """ -def find_all_dependencies( - dependency_mapping: Dict[str, set], - start_entity: str | None = None, - initial_dependencies: set | None = None, - initial_checked_dependencies: set | None = None, - return_parent: bool = False, -) -> list | set: - """Return all the dependencies of the given `start_entity` or `initial_dependencies`. This is basically some kind of - BFS traversal algorithm. It can either start from `start_entity`, or `initial_dependencies`. + def __init__( + self, + old_name, + new_name, + given_old_name=None, + given_new_name=None, + ): + super().__init__() + self.old_name = old_name + self.new_name = new_name + self.default_name = "".join(x.title() for x in new_name.split("_")) + if self.new_name in CONFIG_MAPPING_NAMES: + self.default_name = CONFIG_MAPPING_NAMES[self.new_name].replace( + "Config", "" + ) # the best source of truth for class names. Could also just use the ones de + self.patterns = { + old_name: new_name, + old_name.upper(): new_name.upper(), + "".join(x.title() for x in old_name.split("_")): self.default_name, + } + if given_old_name is not None and given_new_name is not None and given_old_name not in self.patterns: + self.patterns[given_old_name] = given_new_name + if self.old_name in CONFIG_MAPPING_NAMES: + self.default_old_name = CONFIG_MAPPING_NAMES[self.old_name].replace("Config", "") + if self.default_old_name.isupper(): + self.default_old_name = self.default_old_name.capitalize() - Args: - dependency_mapping (`Dict[str, set]`): - A mapping from entities (usually function names), to immediate dependencies. That is, for function names, - a mapping {"foo": {"bar", "test"}} would indicate that functions `bar` and `test` are immediately called - in `foo`'s definition. - start_entity (str | None, *optional*): - A key of `dependency_mapping`, indicating from which entity to start the search. - initial_dependencies (set | None, *optional*): - If `start_entity` is not provided, this can be used as an alternative. In this case, `initial_dependencies` - the search will continue from all the entities in `initial_dependencies`, if they are in `dependency_mapping`. - initial_checked_dependencies (set | None, *optional*): - If provided, entities already present in `initial_checked_dependencies` will not be part of the returned dependencies. - return_parent (bool, *optional*): - If `True`, will return a list consisting of tuples (dependency, parent) instead of a simple set of dependencies. Note - that the order of the items in the list reflects the traversal order. Thus, no parent can ever appear before childs. - Returns: - A set of all the dependencies, or a list containing parents as well if `return_parent=True`. + def preserve_case_replace(self, text): + # Create a regex pattern to match all variations + regex_pattern = "|".join(re.escape(key) for key in self.patterns.keys()) + compiled_regex = re.compile(regex_pattern, re.IGNORECASE) - Example: - Given the following structure in the `modular_xxx.py` file: - ``` - def foo1(): - pass + def replace(match): + word = match.group(0) + result = self.patterns.get(word, self.default_name) + return result - def foo2(): - pass + return compiled_regex.sub(replace, text) - def bar(): - foo1() + def convert_to_camelcase(self, text): + # Regex pattern to match consecutive uppercase letters and lowercase the first set + result = re.sub( + rf"^({self.old_name})(?=[a-z]+)", lambda m: self.default_old_name, text, flags=re.IGNORECASE, count=1 + ) + return result - def foobar(): - bar() - foo2() + @m.leave(m.Name() | m.SimpleString() | m.Comment()) + def replace_name(self, original_node, updated_node): + if re.findall(r"# Copied from", updated_node.value): + return cst.RemoveFromParent() + update = self.preserve_case_replace(updated_node.value) + return updated_node.with_changes(value=update) - class MyLayer(SomeOtherModelLayer): - def forward(...): - foobar() - ``` - and the `dependency_mapping` created when visiting the `modular_xxx.py` file, we get: - ``` - dependency_mapping = {'bar': {'foo1'}, 'foobar': {'bar', 'foo2'}} - find_all_dependencies(dependency_mapping, start_entity='foobar', return_parent=True) - >>> [('bar', 'foobar'), ('foo2', 'foobar'), ('foo1', 'bar')] - ``` - That is, all the functions needed (and potentially their immediate parent) so that the function to be added - in MyLayer (`foobar`) can work correctly. - """ - if initial_dependencies is None and start_entity is not None: - initial_dependencies = dependency_mapping[start_entity] - if initial_checked_dependencies is None: - initial_checked_dependencies = set() + def leave_ClassDef(self, original_node, updated_node): + return updated_node.with_changes(name=cst.Name(self.convert_to_camelcase(updated_node.name.value))) - dependency_queue = deque(initial_dependencies) - all_dependencies = set() - all_dependencies_with_parent = [] - checked_dependencies = set(initial_checked_dependencies) - parents = {initial_dep: start_entity for initial_dep in initial_dependencies} - while len(dependency_queue) > 0: - # Pick element to visit - current = dependency_queue.popleft() - if current not in checked_dependencies: - # Add the dependencies - all_dependencies.add(current) - all_dependencies_with_parent += [(current, parents[current])] - if current in dependency_mapping.keys(): - # Update dependency queue - dependency_queue.extend(dependency_mapping[current]) - parents.update({dep: current for dep in dependency_mapping[current]}) - # add visited node to the list - checked_dependencies.add(current) - if not return_parent: - return all_dependencies - # no child can ever appear before its parent thanks to the queue (needed to add them at the correct location in the body later) - return all_dependencies_with_parent +DOCSTRING_NODE = m.SimpleStatementLine( + body=[ + m.Expr( + value=m.SimpleString( + # match anything between """ """ + value=m.MatchIfTrue(lambda value: re.search(r"\"\"\"[\s\S]*\"\"\"", value) is not None) + ) + ) + ] +) -# These top-level variables will always use the value in the `modular_xxx.py` file -ASSIGNMENTS_TO_KEEP = { - "_CHECKPOINT_FOR_DOC", -} +def SUPER_CALL_NODE(func_name): + return m.Call(func=m.Attribute(value=m.Call(func=m.Name("super")), attr=m.Name(func_name))) -class ClassDependencyMapper(CSTVisitor): - """A visitor which is designed to analyze a single class node to get all its dependencies that are mutual with `global_names`. - This class is used through the 2 convenient class methods. - """ - METADATA_DEPENDENCIES = (ParentNodeProvider,) - def __init__(self, class_name: str, global_names: set | None): - super().__init__() - self.class_name = class_name - self.dependencies = set() - self.global_names = global_names +def is_call_to_super(node, func_name): + return m.matches( + node, m.SimpleStatementLine(body=[m.Return(SUPER_CALL_NODE(func_name)) | m.Expr(SUPER_CALL_NODE(func_name))]) + ) - def visit_Name(self, node): - if node.value != self.class_name and node.value in self.global_names: - parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) - # If it is only an annotation, do not add dependency - if not m.matches(parent_node, m.Annotation()): - self.dependencies.add(node.value) - @classmethod - def dependencies_for_node(cls, node: cst.ClassDef, global_names: set) -> set: - """Create dependencies for a node in the `ModuleMapper`.""" - temp_module = cst.Module(body=[node]) - wrapper = MetadataWrapper(temp_module) - visitor = cls(node.name.value, global_names) - wrapper.visit(visitor) - return visitor.dependencies +# Transformer class to replace ClassB.call_to_method and ClassB().call_to_method with super().call_to_method +class ReplaceMethodCallTransformer(cst.CSTTransformer): + def __init__(self, all_bases: Set[str]): + self.all_bases = all_bases - @classmethod - def dependencies_for_new_node(cls, updated_node: cst.ClassDef, mapper: "ModuleMapper") -> set: - """Create dependencies for a node in the `ModularFileMapper` (which may have been changed by - `replace_call_to_super`). - """ - temp_module = cst.Module(body=[updated_node]) - wrapper = MetadataWrapper(temp_module) - visitor = cls(updated_node.name.value, set(mapper.global_nodes.keys())) - wrapper.visit(visitor) - return mapper.augment_dependencies_with_functions(visitor.dependencies) + def leave_Attribute(self, original_node: cst.Attribute, updated_node: cst.Attribute) -> cst.CSTNode: + # Handle ClassB.call_to_method + if ( + isinstance(original_node.value, cst.Name) + and original_node.value.value in self.all_bases + and isinstance(original_node.attr, cst.Name) + ): + # Replace with super().call_to_method + return updated_node.with_changes( + value=cst.Call(cst.Name("super")), + ) + # Handle ClassB().call_to_method + elif ( + isinstance(original_node.value, cst.Call) + and isinstance(original_node.value.func, cst.Name) + and original_node.value.func.value in self.all_bases + and isinstance(original_node.attr, cst.Name) + ): + # Replace with super().call_to_method + return updated_node.with_changes(func=cst.Attribute(value=cst.Call(func=cst.Name("super")))) + return updated_node + def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.CSTNode: + # Check if the function being called is of the form ClassB().func_a or ClassB.func_a + if isinstance(original_node.func, cst.Attribute) and ( + # Match ClassB().func_a(...) + ( + isinstance(original_node.func.value, cst.Call) + and isinstance(original_node.func.value.func, cst.Name) + and original_node.func.value.func.value in self.all_bases + ) + or + # Match ClassB.func_a(...) + (isinstance(original_node.func.value, cst.Name) and original_node.func.value.value in self.all_bases) + ): + # Check if the first argument is 'self', and remove it + if len(original_node.args) > 0 and m.matches(original_node.args[0].value, m.Name("self")): + # Create the new argument list without 'self' + new_args = updated_node.args[1:] + else: + new_args = updated_node.args -class ModuleMapper(CSTVisitor, ABC): - """An abstract visitor class which analyses a module, creating a mapping of dependencies for classes and functions. - It defines common visiting patterns between the modular file and the model-specific modules that are imported in the modular file. - """ - METADATA_DEPENDENCIES = (ParentNodeProvider, PositionProvider) + return updated_node.with_changes(args=new_args) + return updated_node - def __init__(self, python_module: cst.Module): - # fmt: off - self.python_module: cst.Module = python_module # original cst.Module being visited - self.classes: Dict[str, cst.ClassDef] = {} # mapping from class names to Nodes - self.imports = [] # stores all import statements - self.functions: Dict[str, cst.FunctionDef] = {} # mapping of global scope function names to Nodes - self.function_call_dependency_mapping = defaultdict(set) # 1st-level function dependency mapping - self.assignments: Dict[str, cst.SimpleStatementLine] = {} # mapping of global assignments names to Nodes - self.current_function = None - # fmt: on - def visit_SimpleStatementLine(self, node): - """ - Global Assigns like `GEMMA_INPUT_DOCSTRING = 'THIS IS THE INPUT' and all import statements - are extracted and saved in their corresponding dict. They are then used when updating dependency mappings. - """ - if m.matches(node, m.SimpleStatementLine(body=[m.Assign()])) and m.matches( - self.get_metadata(cst.metadata.ParentNodeProvider, node), m.Module() - ): - left_hand_side = node.body[0].targets[0].target - if hasattr(left_hand_side, "value"): - self.assignments[left_hand_side.value] = node - else: - for idx, target in enumerate(list(left_hand_side.elements)): - self.assignments[target.value.value] = node.body[0].value.elements[idx].value - if m.matches(node, m.SimpleStatementLine(body=[m.Import() | m.ImportFrom()])): - self.imports.append(node) +def get_docstring_indent(docstring): + # Match the first line after the opening triple quotes + match = re.search(r'(?:"""|\'\'\'|```)\n(\s+)', docstring) + if match: + # Return the indentation spaces captured + return len(match.group(1)) + return 0 - def visit_FunctionDef(self, node): - parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) - if m.matches(parent_node, m.Module()): - self.current_function = node.name.value - self.functions[node.name.value] = node - def leave_FunctionDef(self, node): - parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) - if m.matches(parent_node, m.Module()): - self.current_function = None +def merge_docstrings(original_docstring, updated_docstring): + # indent_level = get_docstring_indent(updated_docstring) + original_level = get_docstring_indent(original_docstring) + if not re.findall(r"\n\s*Args:\n", updated_docstring): + # Split the docstring at the example section, assuming `"""` is used to define the docstring + parts = original_docstring.split("```") + if "```" in updated_docstring and len(parts) > 1: + updated_docstring = updated_docstring.lstrip('r"') + new_parts = updated_docstring.split("```") + if len(new_parts) != 3: + raise ValueError("There should only be one example, and it should have opening and closing '```'") + parts[1] = new_parts[1] + updated_docstring = "".join( + [ + parts[0].rstrip(" \n") + new_parts[0], + f"\n{original_level*' '}```", + parts[1], + "```", + parts[2], + ] + ) + elif updated_docstring not in original_docstring: + # add tabulation if we are at the lowest level. + if re.search(r"\n\s*.*\(.*\)\:\n\s*\w", updated_docstring): + updated_docstring = updated_docstring.replace("\n ", "\n ") + updated_docstring = original_docstring.rstrip('"') + "\n" + updated_docstring.lstrip('r"\n') + return updated_docstring - def leave_If(self, node): - for stmt in node.body.body: - if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])): - self.imports.append(node) - def visit_ClassDef(self, node: ClassDef) -> None: - """Record class nodes to create their dependencies at the end.""" - self.classes[node.name.value] = node +class SuperTransformer(cst.CSTTransformer): + METADATA_DEPENDENCIES = (ParentNodeProvider,) - def visit_Call(self, node: cst.Call): - """This is used to create a mapping from top-level functions to functions called inside them. - Important note: we only rely on direct Call to the functions here, not indirect mentions (such as assigning a variable with the function, - add calling the variable later). This should be enough as the `modular_xxx` and `modeling_xxx` structures should be as simple as possible. - """ - if self.current_function is not None: - # Simple function calls such as foo() - if m.matches(node.func, m.Name()): - self.function_call_dependency_mapping[self.current_function].add(node.func.value) + def __init__(self, python_module: cst.Module, original_methods, updated_methods, class_name="", all_bases=None): + self.python_module = python_module + self.original_methods = original_methods + self.updated_methods = updated_methods + self.all_assign_target = {} + self.deleted_targets = {} # child node can delete some arguments + self.class_name = class_name + self.all_bases = all_bases or [] + self.transformer = ReplaceMethodCallTransformer(set(self.all_bases)) - def leave_Module(self, node): - """When leaving the module, we finally create the `function_call_recursive_dependency_mapping`, then we - compute the dependencies for all recorded classes based on all the nodes we visited. - We also store the position of each global scoped node to allow sorting the dependencies based on their - position in the code later. We use the PositionProvider metadata wrapper for this. + def update_body(self, existing_body, new_statements): """ - # assign all nodes - self.global_nodes = {**self.assignments, **self.classes, **self.functions} - # now sort the class dependency_mapping based on the position of the nodes - self.start_lines = {} - for id, node in self.global_nodes.items(): - self.start_lines[id] = self.get_metadata(cst.metadata.PositionProvider, node).start.line + Helper method to update the body by removing duplicates before adding new statements. + `existing_body` is the body of the original method, the parent class + `new_statements` are the additional statements + """ + deduplicated_new_body = [] + existing_nodes = set() + for node in new_statements: + if m.matches(node, m.SimpleStatementLine(body=[m.Assign()])): + target = self.python_module.code_for_node(node.body[0].targets[0].target) + self.all_assign_target[target] = node + if m.matches(node, m.SimpleStatementLine(body=[m.Del()])): + target = self.python_module.code_for_node(node.body[0].target) + self.deleted_targets[target] = node - def _compute_recursive_function_dependencies(self) -> dict[str, set]: - """Based on the 1st level function dependency mapping, create the recursive dependency mapping.""" - recursive_dependencies = {} - for function_name in self.function_call_dependency_mapping.keys(): - # We need to check if they are present in self.functions to avoid built-in functions - all_dependencies = { - dep - for dep in find_all_dependencies(self.function_call_dependency_mapping, start_entity=function_name) - if dep in self.functions.keys() - } - recursive_dependencies[function_name] = all_dependencies - return recursive_dependencies + for stmt in existing_body: + if m.matches(stmt, m.SimpleStatementLine(body=[m.Assign()])): + target = self.python_module.code_for_node(stmt.body[0].targets[0].target) + if target in self.deleted_targets: + logger.warning(f"Deleted the assign for {target}") + continue + if target in self.all_assign_target: + stmt = self.all_assign_target[target] + # Skip the docstring (will be added later on, at the beginning) + elif m.matches(stmt, DOCSTRING_NODE): + continue + comment_less_code = re.sub(r"#.*", "", self.python_module.code_for_node(stmt)).strip() + comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip() + deduplicated_new_body.append(stmt) + existing_nodes.add(comment_less_code) - def augment_dependencies_with_functions(self, dependencies: set) -> set: - """For a set of `dependencies`, augment them by adding all potential functions which are dependencies of - the functions present in the `dependencies`. - """ - new_dependencies = dependencies.copy() - # Go through the set of dependencies - for dep in tuple(dependencies): - if dep in self.function_call_recursive_dependency_mapping.keys(): - new_dependencies.update(self.function_call_recursive_dependency_mapping[dep]) - return new_dependencies - - def compute_class_dependencies(self): - """For each visited class, find its dependencies based on visited the current file + potential merged dependencies. - Note: This function takes care of updating `global_nodes` and `function_call_recursive_dependency_mapping` as well after the - merge with other files dependencies. - """ - # Correctly re-set the global nodes at this point - self.global_nodes = {**self.assignments, **self.classes, **self.functions} - # Create the global mapping of recursive dependencies for functions - self.function_call_recursive_dependency_mapping = self._compute_recursive_function_dependencies() + for node in new_statements: + code = self.python_module.code_for_node(node) + comment_less_code = re.sub(r"#.*", "", code).strip() + comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip() + if node not in deduplicated_new_body and comment_less_code not in existing_nodes: + if not m.matches(node, m.SimpleStatementLine(body=[m.Del()])): + deduplicated_new_body.append(node) + existing_nodes.add(comment_less_code) - self.class_dependency_mapping = {} - for class_name, class_node in self.classes.items(): - dependencies = ClassDependencyMapper.dependencies_for_node(class_node, set(self.global_nodes.keys())) - # Corretcly augment class dependencies with all needed functions - self.class_dependency_mapping[class_name] = self.augment_dependencies_with_functions(dependencies) + # Fix the post_init() that has to be last + for i, node in enumerate(deduplicated_new_body): + code = self.python_module.code_for_node(node) + comment_less_code = re.sub(r"#.*", "", code).strip() + comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip() + if "self.post_init(" in comment_less_code and i < len(deduplicated_new_body) - 1: + # Remove it and add it again at the end + deduplicated_new_body.pop(i) + deduplicated_new_body.append(node) + break - @abstractmethod - def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]: - pass - + return deduplicated_new_body -class ModelFileMapper(ModuleMapper): - """A mapper designed for model-specific files (i.e. a `transformers.models.xxx` file). When encountering such a file - in the `modular_xxx.py` file, we need to correctly visit it and merge the dependencies of the modular and current file. - For this reason, this class should only be instantiated from the class method `visit_and_merge_dependencies`, which takes - care of correctly merging dependencies, then finalizes all dependency graph computations.""" + def _fix_init_location(self, new_body): + """Fix the location of the super()__init__ in the new body, if we had new statements before it.""" + start_index = 0 + for i, node in enumerate(new_body): + if m.matches(node, DOCSTRING_NODE) and i == start_index: + start_index += 1 + continue + code = self.python_module.code_for_node(node) + comment_less_code = re.sub(r"#.*", "", code).strip() + comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip() + if "super().__init__" in comment_less_code and i > start_index: + # Remove it and add it again at the top after the docstrings + node = new_body.pop(i) + new_body = new_body[:start_index] + [node] + new_body[start_index:] + break + return new_body - def __init__(self, python_module: cst.Module): - super().__init__(python_module) - - def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]: - """Compute the relative order that the `missing_dependencies` should have between themselves in the output file. + def replace_super_calls(self, node: cst.IndentedBlock, func_name: str) -> cst.CSTNode: + """Updates the body of the input `node`'s `func_name` function by replacing calls + to super().func_name() with the source code of the parent class' `func_name`. + It keeps everything that is defined before `super().func_name()`. """ - relative_order = {} - idx = 0 - classes = sorted([dep for dep in tuple(missing_dependencies) if dep in self.classes], key=lambda x: self.start_lines[x]) - # This is because for merged dependencies, we only have relative order in the other visited file, so we need - # to track dependency order relative to a given class - if len(classes) > 0 and not hasattr(self, "class_dependency_mapping"): - raise ValueError("Cannot correctly find the relative order of the dependencies.") - - remaining_dependencies = missing_dependencies.copy() - - # Start by tracking relative order class by class - for class_name in classes: - class_dependencies = tuple(self.class_dependency_mapping[class_name] & remaining_dependencies) - original_dependencies = [] - merged_dependencies = [] - # We need to differentiate between nodes that were already present (we can get relative order globally) and - # nodes that were merged (we can get relative order only relative to the class the dependencies relate to) - for class_dep in class_dependencies: - if class_dep in self.modular_file_start_lines: - merged_dependencies.append(class_dep) - else: - original_dependencies.append(class_dep) - # Sort both list according to the order in their respective file - original_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines[x]) - merged_dependencies = sorted(merged_dependencies, key=lambda x: self.modular_file_start_lines[x]) - - # Add all original node first, then merged ones - for dep in original_dependencies + merged_dependencies: - remaining_dependencies.remove(dep) - relative_order[dep] = idx - idx += 1 - # Add the class itself - remaining_dependencies.remove(class_name) - relative_order[class_name] = idx - idx += 1 + self.has_docstring = False + parent_has_docstring = False + if func_name in self.original_methods: + parent_has_docstring = m.matches(self.original_methods[func_name].body.body[0], DOCSTRING_NODE) + new_body = [] + has_super_call = False - # Now add what still remains - remaining_dependencies = tuple(remaining_dependencies) - original_dependencies = [] - merged_dependencies = [] - for dep in remaining_dependencies: - if dep in self.modular_file_start_lines: - merged_dependencies.append(dep) + for i, expr in enumerate(node.body): + if is_call_to_super(expr, func_name): + has_super_call = True + new_body.extend(self.update_body(self.original_methods[func_name].body.body, node.body[i + 1 :])) + new_body = self._fix_init_location(new_body) else: - original_dependencies.append(dep) - # Sort both list according to the order in their respective file - original_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines[x]) - merged_dependencies = sorted(merged_dependencies, key=lambda x: self.modular_file_start_lines[x]) - - # Add all original node first, then merged ones - for dep in original_dependencies + merged_dependencies: - relative_order[dep] = idx - idx += 1 - - return relative_order - - def _merge_functions(self, functions: dict[str, cst.CSTNode], function_call_mapping: dict[str, set]): - """Update the global nodes and function dependency mapping with those from the modular file. - - Merging rule: if any function with the same name was redefined in the modular, use it and its dependencies - instead of the original ones (this may mean to add new functions as well, if any redefined function uses a new one). - """ - # Add/overwrite all needed function nodes and dependencies - self.functions.update(functions) - self.function_call_dependency_mapping.update(function_call_mapping) + expr = expr.visit(self.transformer) + if m.matches(expr, DOCSTRING_NODE): + self.has_docstring = True + if parent_has_docstring: # actually here we ought to de-duplicate? + original_docstring = self.original_methods[func_name].body.body[0].body[0].value.value + updated_docstring = expr.body[0].value.value + merged_doc = merge_docstrings(original_docstring, updated_docstring) + new_node = [expr.with_changes(body=[cst.Expr(value=cst.SimpleString(value=merged_doc))])] + else: + new_node = [expr] + new_body.extend(new_node) + elif not m.matches(expr, m.SimpleStatementLine(body=[m.Del()])) and not has_super_call: + new_body.append(expr) + if not self.has_docstring and parent_has_docstring: + new_body = [self.original_methods[func_name].body.body[0]] + new_body + return node.with_changes(body=new_body) - def _merge_assignments(self, assignments: dict[str, cst.CSTNode]): - """Update the global nodes with the assignment from the modular file. + def leave_FunctionDef(self, original_node: cst.Call, updated_node: cst.Call) -> cst.CSTNode: + if updated_node.name.value in self.updated_methods: + name = updated_node.name.value + new_body = self.replace_super_calls(updated_node.body, name) + return updated_node.with_changes(body=new_body, params=updated_node.params) + return updated_node - Merging rule: if any assignment with the same name was redefined in the modular, we use it ONLY if it is - in `ASSIGNMENTS_TO_KEEP`. Otherwise, we use the original value. This rule was chosen to avoid having to rewrite the - big docstrings. - """ - for assignment, node in assignments.items(): - if assignment in ASSIGNMENTS_TO_KEEP or assignment not in self.assignments: - self.assignments[assignment] = node + def leave_Return(self, original_node: cst.Return, updated_node: cst.Return) -> cst.CSTNode: + """ "When a return statement is reached, it is replaced with the unrolled super code""" + if m.matches(updated_node.value, m.Call(func=m.Attribute(attr=m.Name("super")))): + func_def = self.get_metadata(ParentNodeProvider, original_node) + if m.matched(func_def, m.FunctionDef()) and func_def.name.value in self.original_methods: + updated_return_value = updated_node.value.with_changes( + args=[ + cst.Arg( + value=cst.Call(func=cst.Name("super"), args=[cst.Arg(value=cst.Name(func_def.name.value))]) + ) + ] + ) + return updated_node.with_changes(value=updated_return_value) + return updated_node - def merge_modular_dependencies(self, functions, function_mapping, assignments, start_lines): - """Merge both functions and assignments from the modular definitions into the current module file, - then compute the relative order of all nodes.""" - self._merge_functions(functions, function_mapping) - self._merge_assignments(assignments) - self.modular_file_start_lines = start_lines - @classmethod - def visit_and_merge_dependencies(cls, module: cst.Module, functions, function_mapping, assignments, start_lines) -> "ModelFileMapper": - wrapper = MetadataWrapper(module) - mapper = cls(module) - wrapper.visit(mapper) - # Merge dependencies - mapper.merge_modular_dependencies(functions, function_mapping, assignments, start_lines) - # Create the class dependencies graph - mapper.compute_class_dependencies() - return mapper +def find_all_dependencies( + dependency_mapping: Dict[str, set], + start_entity: str | None = None, + initial_dependencies: set | None = None, + initial_checked_dependencies: set | None = None, + return_parent: bool = False, +) -> list | set: + """Return all the dependencies of the given `start_entity` or `initial_dependencies`. This is basically some kind of + BFS traversal algorithm. It can either start from `start_entity`, or `initial_dependencies`. + Args: + dependency_mapping (`Dict[str, set]`): + A mapping from entities (usually function names), to immediate dependencies. That is, for function names, + a mapping {"foo": {"bar", "test"}} would indicate that functions `bar` and `test` are immediately called + in `foo`'s definition. + start_entity (str | None, *optional*): + A key of `dependency_mapping`, indicating from which entity to start the search. + initial_dependencies (set | None, *optional*): + If `start_entity` is not provided, this can be used as an alternative. In this case, `initial_dependencies` + the search will continue from all the entities in `initial_dependencies`, if they are in `dependency_mapping`. + initial_checked_dependencies (set | None, *optional*): + If provided, entities already present in `initial_checked_dependencies` will not be part of the returned dependencies. + return_parent (bool, *optional*): + If `True`, will return a list consisting of tuples (dependency, parent) instead of a simple set of dependencies. Note + that the order of the items in the list reflects the traversal order. Thus, no parent can ever appear before childs. + Returns: + A set of all the dependencies, or a list containing parents as well if `return_parent=True`. -class ReplaceNameTransformer(m.MatcherDecoratableTransformer): - """A transformer that replaces `old_name` with `new_name` in comments, string and any references. - It should take into account name like `MyNewModel`, or `my_new_model`. Without using the AUTO_MAPPING. - Supported renaming patterns: - - llama -> my_new_model and my_new_model -> llama - - Llama -> MyNewModel and MyNewModel -> Llama - - LLAMA -> MY_NEW_MODEL and MY_NEW_MODEL -> LLAMA - - LLaMa -> MyNewModel abd MyNewModel -> Llama - """ + Example: + Given the following structure in the `modular_xxx.py` file: + ``` + def foo1(): + pass - def __init__( - self, - old_name, - new_name, - given_old_name=None, - given_new_name=None, - ): - super().__init__() - self.old_name = old_name - self.new_name = new_name - self.default_name = "".join(x.title() for x in new_name.split("_")) - if self.new_name in CONFIG_MAPPING_NAMES: - self.default_name = CONFIG_MAPPING_NAMES[self.new_name].replace( - "Config", "" - ) # the best source of truth for class names. Could also just use the ones de - self.patterns = { - old_name: new_name, - old_name.upper(): new_name.upper(), - "".join(x.title() for x in old_name.split("_")): self.default_name, - } - if given_old_name is not None and given_new_name is not None and given_old_name not in self.patterns: - self.patterns[given_old_name] = given_new_name - if self.old_name in CONFIG_MAPPING_NAMES: - self.default_old_name = CONFIG_MAPPING_NAMES[self.old_name].replace("Config", "") - if self.default_old_name.isupper(): - self.default_old_name = self.default_old_name.capitalize() + def foo2(): + pass - def preserve_case_replace(self, text): - # Create a regex pattern to match all variations - regex_pattern = "|".join(re.escape(key) for key in self.patterns.keys()) - compiled_regex = re.compile(regex_pattern, re.IGNORECASE) + def bar(): + foo1() - def replace(match): - word = match.group(0) - result = self.patterns.get(word, self.default_name) - return result + def foobar(): + bar() + foo2() - return compiled_regex.sub(replace, text) + class MyLayer(SomeOtherModelLayer): + def forward(...): + foobar() + ``` + and the `dependency_mapping` created when visiting the `modular_xxx.py` file, we get: + ``` + dependency_mapping = {'bar': {'foo1'}, 'foobar': {'bar', 'foo2'}} + find_all_dependencies(dependency_mapping, start_entity='foobar', return_parent=True) + >>> [('bar', 'foobar'), ('foo2', 'foobar'), ('foo1', 'bar')] + ``` + That is, all the functions needed (and potentially their immediate parent) so that the function to be added + in MyLayer (`foobar`) can work correctly. + """ + if initial_dependencies is None and start_entity is not None: + initial_dependencies = dependency_mapping[start_entity] + if initial_checked_dependencies is None: + initial_checked_dependencies = set() - def convert_to_camelcase(self, text): - # Regex pattern to match consecutive uppercase letters and lowercase the first set - result = re.sub( - rf"^({self.old_name})(?=[a-z]+)", lambda m: self.default_old_name, text, flags=re.IGNORECASE, count=1 - ) - return result + dependency_queue = deque(initial_dependencies) + all_dependencies = set() + all_dependencies_with_parent = [] + checked_dependencies = set(initial_checked_dependencies) + parents = {initial_dep: start_entity for initial_dep in initial_dependencies} + while len(dependency_queue) > 0: + # Pick element to visit + current = dependency_queue.popleft() + if current not in checked_dependencies: + # Add the dependencies + all_dependencies.add(current) + all_dependencies_with_parent += [(current, parents[current])] + if current in dependency_mapping.keys(): + # Update dependency queue + dependency_queue.extend(dependency_mapping[current]) + parents.update({dep: current for dep in dependency_mapping[current]}) + # add visited node to the list + checked_dependencies.add(current) - @m.leave(m.Name() | m.SimpleString() | m.Comment()) - def replace_name(self, original_node, updated_node): - if re.findall(r"# Copied from", updated_node.value): - return cst.RemoveFromParent() - update = self.preserve_case_replace(updated_node.value) - return updated_node.with_changes(value=update) + if not return_parent: + return all_dependencies + # no child can ever appear before its parent thanks to the queue (needed to add them at the correct location in the body later) + return all_dependencies_with_parent - def leave_ClassDef(self, original_node, updated_node): - return updated_node.with_changes(name=cst.Name(self.convert_to_camelcase(updated_node.name.value))) +# These top-level variables will always use the value in the `modular_xxx.py` file +ASSIGNMENTS_TO_KEEP = { + "_CHECKPOINT_FOR_DOC", +} -DOCSTRING_NODE = m.SimpleStatementLine( - body=[ - m.Expr( - value=m.SimpleString( - # match anything between """ """ - value=m.MatchIfTrue(lambda value: re.search(r"\"\"\"[\s\S]*\"\"\"", value) is not None) - ) - ) - ] -) +class ClassDependencyMapper(CSTVisitor): + """A visitor which is designed to analyze a single class node to get all its dependencies that are mutual with `global_names`. + This class is used through the 2 convenient class methods. + """ + METADATA_DEPENDENCIES = (ParentNodeProvider,) + def __init__(self, class_name: str, global_names: set | None): + super().__init__() + self.class_name = class_name + self.dependencies = set() + self.global_names = global_names -def SUPER_CALL_NODE(func_name): - return m.Call(func=m.Attribute(value=m.Call(func=m.Name("super")), attr=m.Name(func_name))) + def visit_Name(self, node): + if node.value != self.class_name and node.value in self.global_names: + parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) + # If it is only an annotation, do not add dependency + if not m.matches(parent_node, m.Annotation()): + self.dependencies.add(node.value) + @classmethod + def dependencies_for_node(cls, node: cst.ClassDef, global_names: set) -> set: + """Create dependencies for a node in the `ModuleMapper`.""" + temp_module = cst.Module(body=[node]) + wrapper = MetadataWrapper(temp_module) + visitor = cls(node.name.value, global_names) + wrapper.visit(visitor) + return visitor.dependencies -def is_call_to_super(node, func_name): - return m.matches( - node, m.SimpleStatementLine(body=[m.Return(SUPER_CALL_NODE(func_name)) | m.Expr(SUPER_CALL_NODE(func_name))]) - ) + @classmethod + def dependencies_for_new_node(cls, updated_node: cst.ClassDef, mapper: "ModuleMapper") -> set: + """Create dependencies for a node in the `ModularFileMapper` (which may have been changed by + `replace_call_to_super`). + """ + temp_module = cst.Module(body=[updated_node]) + wrapper = MetadataWrapper(temp_module) + visitor = cls(updated_node.name.value, set(mapper.global_nodes.keys())) + wrapper.visit(visitor) + return mapper.augment_dependencies_with_functions(visitor.dependencies) -# Transformer class to replace ClassB.call_to_method and ClassB().call_to_method with super().call_to_method -class ReplaceMethodCallTransformer(cst.CSTTransformer): - def __init__(self, all_bases: Set[str]): - self.all_bases = all_bases +class ModuleMapper(CSTVisitor, ABC): + """An abstract visitor class which analyses a module, creating a mapping of dependencies for classes and functions. + It defines common visiting patterns between the modular file and the model-specific module files that will be visited. + """ + METADATA_DEPENDENCIES = (ParentNodeProvider, PositionProvider) - def leave_Attribute(self, original_node: cst.Attribute, updated_node: cst.Attribute) -> cst.CSTNode: - # Handle ClassB.call_to_method - if ( - isinstance(original_node.value, cst.Name) - and original_node.value.value in self.all_bases - and isinstance(original_node.attr, cst.Name) - ): - # Replace with super().call_to_method - return updated_node.with_changes( - value=cst.Call(cst.Name("super")), - ) - # Handle ClassB().call_to_method - elif ( - isinstance(original_node.value, cst.Call) - and isinstance(original_node.value.func, cst.Name) - and original_node.value.func.value in self.all_bases - and isinstance(original_node.attr, cst.Name) - ): - # Replace with super().call_to_method - return updated_node.with_changes(func=cst.Attribute(value=cst.Call(func=cst.Name("super")))) - return updated_node + def __init__(self, python_module: cst.Module): + # fmt: off + self.python_module: cst.Module = python_module # original cst.Module being visited + self.classes: Dict[str, cst.ClassDef] = {} # mapping from class names to Nodes + self.imports = [] # stores all import statements + self.functions: Dict[str, cst.FunctionDef] = {} # mapping of global scope function names to Nodes + self.function_call_dependency_mapping = defaultdict(set) # 1st-level function dependency mapping + self.assignments: Dict[str, cst.SimpleStatementLine] = {} # mapping of global assignments names to Nodes + self.current_function = None + # fmt: on - def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.CSTNode: - # Check if the function being called is of the form ClassB().func_a or ClassB.func_a - if isinstance(original_node.func, cst.Attribute) and ( - # Match ClassB().func_a(...) - ( - isinstance(original_node.func.value, cst.Call) - and isinstance(original_node.func.value.func, cst.Name) - and original_node.func.value.func.value in self.all_bases - ) - or - # Match ClassB.func_a(...) - (isinstance(original_node.func.value, cst.Name) and original_node.func.value.value in self.all_bases) + def visit_SimpleStatementLine(self, node): + """ + Global Assigns like `GEMMA_INPUT_DOCSTRING = 'THIS IS THE INPUT' and all import statements + are extracted and saved in their corresponding dict. They are then used when updating dependency mappings. + """ + if m.matches(node, m.SimpleStatementLine(body=[m.Assign()])) and m.matches( + self.get_metadata(cst.metadata.ParentNodeProvider, node), m.Module() ): - # Check if the first argument is 'self', and remove it - if len(original_node.args) > 0 and m.matches(original_node.args[0].value, m.Name("self")): - # Create the new argument list without 'self' - new_args = updated_node.args[1:] + left_hand_side = node.body[0].targets[0].target + if hasattr(left_hand_side, "value"): + self.assignments[left_hand_side.value] = node else: - new_args = updated_node.args + for idx, target in enumerate(list(left_hand_side.elements)): + self.assignments[target.value.value] = node.body[0].value.elements[idx].value + if m.matches(node, m.SimpleStatementLine(body=[m.Import() | m.ImportFrom()])): + self.imports.append(node) - return updated_node.with_changes(args=new_args) - return updated_node + def visit_FunctionDef(self, node): + parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) + if m.matches(parent_node, m.Module()): + self.current_function = node.name.value + self.functions[node.name.value] = node + def leave_FunctionDef(self, node): + parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) + if m.matches(parent_node, m.Module()): + self.current_function = None -def get_docstring_indent(docstring): - # Match the first line after the opening triple quotes - match = re.search(r'(?:"""|\'\'\'|```)\n(\s+)', docstring) - if match: - # Return the indentation spaces captured - return len(match.group(1)) - return 0 + def leave_If(self, node): + for stmt in node.body.body: + if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])): + self.imports.append(node) + def visit_ClassDef(self, node: ClassDef) -> None: + """Record class nodes to create their dependencies at the end.""" + self.classes[node.name.value] = node -def merge_docstrings(original_docstring, updated_docstring): - # indent_level = get_docstring_indent(updated_docstring) - original_level = get_docstring_indent(original_docstring) - if not re.findall(r"\n\s*Args:\n", updated_docstring): - # Split the docstring at the example section, assuming `"""` is used to define the docstring - parts = original_docstring.split("```") - if "```" in updated_docstring and len(parts) > 1: - updated_docstring = updated_docstring.lstrip('r"') - new_parts = updated_docstring.split("```") - if len(new_parts) != 3: - raise ValueError("There should only be one example, and it should have opening and closing '```'") - parts[1] = new_parts[1] - updated_docstring = "".join( - [ - parts[0].rstrip(" \n") + new_parts[0], - f"\n{original_level*' '}```", - parts[1], - "```", - parts[2], - ] - ) - elif updated_docstring not in original_docstring: - # add tabulation if we are at the lowest level. - if re.search(r"\n\s*.*\(.*\)\:\n\s*\w", updated_docstring): - updated_docstring = updated_docstring.replace("\n ", "\n ") - updated_docstring = original_docstring.rstrip('"') + "\n" + updated_docstring.lstrip('r"\n') - return updated_docstring + def visit_Call(self, node: cst.Call): + """This is used to create a mapping from top-level functions to functions called inside them. + Important note: we only rely on direct Call to the functions here, not indirect mentions (such as assigning a variable with the function, + add calling the variable later). This should be enough as the `modular_xxx` and `modeling_xxx` structures should be as simple as possible. + """ + if self.current_function is not None: + # Simple function calls such as foo() + if m.matches(node.func, m.Name()): + self.function_call_dependency_mapping[self.current_function].add(node.func.value) + + def leave_Module(self, node): + """When leaving the module, we finally create the `function_call_recursive_dependency_mapping`, then we + compute the dependencies for all recorded classes based on all the nodes we visited. + We also store the position of each global scoped node to allow sorting the dependencies based on their + position in the code later. We use the PositionProvider metadata wrapper for this. + """ + # assign all nodes + self.global_nodes = {**self.assignments, **self.classes, **self.functions} + # now sort the class dependency_mapping based on the position of the nodes + self.start_lines = {} + for id, node in self.global_nodes.items(): + self.start_lines[id] = self.get_metadata(cst.metadata.PositionProvider, node).start.line + + def _compute_recursive_function_dependencies(self) -> dict[str, set]: + """Based on the 1st level function dependency mapping, create the recursive dependency mapping.""" + recursive_dependencies = {} + for function_name in self.function_call_dependency_mapping.keys(): + # We need to check if they are present in self.functions to avoid built-in functions + all_dependencies = { + dep + for dep in find_all_dependencies(self.function_call_dependency_mapping, start_entity=function_name) + if dep in self.functions.keys() + } + recursive_dependencies[function_name] = all_dependencies + return recursive_dependencies + def augment_dependencies_with_functions(self, dependencies: set) -> set: + """For a set of `dependencies`, augment them by adding all potential functions which are dependencies of + the functions present in the `dependencies`. + """ + new_dependencies = dependencies.copy() + # Go through the set of dependencies + for dep in tuple(dependencies): + if dep in self.function_call_recursive_dependency_mapping.keys(): + new_dependencies.update(self.function_call_recursive_dependency_mapping[dep]) + return new_dependencies + + def compute_class_dependencies(self): + """For each visited class, find its dependencies based on visited the current file + potential merged dependencies. + Note: This function takes care of updating `global_nodes` and `function_call_recursive_dependency_mapping` as well after the + merge with other files dependencies. + """ + # Correctly re-set the global nodes at this point + self.global_nodes = {**self.assignments, **self.classes, **self.functions} + # Create the global mapping of recursive dependencies for functions + self.function_call_recursive_dependency_mapping = self._compute_recursive_function_dependencies() -class SuperTransformer(cst.CSTTransformer): - METADATA_DEPENDENCIES = (ParentNodeProvider,) + self.class_dependency_mapping = {} + for class_name, class_node in self.classes.items(): + dependencies = ClassDependencyMapper.dependencies_for_node(class_node, set(self.global_nodes.keys())) + # Corretcly augment class dependencies with all needed functions + self.class_dependency_mapping[class_name] = self.augment_dependencies_with_functions(dependencies) - def __init__(self, python_module: cst.Module, original_methods, updated_methods, class_name="", all_bases=None): - self.python_module = python_module - self.original_methods = original_methods - self.updated_methods = updated_methods - self.all_assign_target = {} - self.deleted_targets = {} # child node can delete some arguments - self.class_name = class_name - self.all_bases = all_bases or [] - self.transformer = ReplaceMethodCallTransformer(set(self.all_bases)) + @abstractmethod + def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]: + pass + - def update_body(self, existing_body, new_statements): - """ - Helper method to update the body by removing duplicates before adding new statements. - `existing_body` is the body of the original method, the parent class - `new_statements` are the additional statements - """ - deduplicated_new_body = [] - existing_nodes = set() - for node in new_statements: - if m.matches(node, m.SimpleStatementLine(body=[m.Assign()])): - target = self.python_module.code_for_node(node.body[0].targets[0].target) - self.all_assign_target[target] = node - if m.matches(node, m.SimpleStatementLine(body=[m.Del()])): - target = self.python_module.code_for_node(node.body[0].target) - self.deleted_targets[target] = node +class ModelFileMapper(ModuleMapper): + """A mapper designed for model-specific files (i.e. a `transformers.models.xxx` file). When encountering such a file + in the `modular_xxx.py` file, we need to correctly visit it and merge the dependencies of the modular and current file. + For this reason, this class should only be instantiated from the class method `visit_and_merge_dependencies`, which takes + care of correctly merging dependencies, then finalizes all dependency graph computations.""" - for stmt in existing_body: - if m.matches(stmt, m.SimpleStatementLine(body=[m.Assign()])): - target = self.python_module.code_for_node(stmt.body[0].targets[0].target) - if target in self.deleted_targets: - logger.warning(f"Deleted the assign for {target}") - continue - if target in self.all_assign_target: - stmt = self.all_assign_target[target] - # Skip the docstring (will be added later on, at the beginning) - elif m.matches(stmt, DOCSTRING_NODE): - continue - comment_less_code = re.sub(r"#.*", "", self.python_module.code_for_node(stmt)).strip() - comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip() - deduplicated_new_body.append(stmt) - existing_nodes.add(comment_less_code) + def __init__(self, python_module: cst.Module): + super().__init__(python_module) + + def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]: + """Compute the relative order that the `missing_dependencies` should have between themselves in the output file. + """ + relative_order = {} + idx = 0 + classes = sorted([dep for dep in tuple(missing_dependencies) if dep in self.classes], key=lambda x: self.start_lines[x]) + # This is because for merged dependencies, we only have relative order in the other visited file, so we need + # to track dependency order relative to a given class + if len(classes) > 0 and not hasattr(self, "class_dependency_mapping"): + raise ValueError("Cannot correctly find the relative order of the dependencies.") + + remaining_dependencies = missing_dependencies.copy() - for node in new_statements: - code = self.python_module.code_for_node(node) - comment_less_code = re.sub(r"#.*", "", code).strip() - comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip() - if node not in deduplicated_new_body and comment_less_code not in existing_nodes: - if not m.matches(node, m.SimpleStatementLine(body=[m.Del()])): - deduplicated_new_body.append(node) - existing_nodes.add(comment_less_code) + # Start by tracking relative order class by class + for class_name in classes: + class_dependencies = tuple(self.class_dependency_mapping[class_name] & remaining_dependencies) + original_dependencies = [] + merged_dependencies = [] + # We need to differentiate between nodes that were already present (we can get relative order globally) and + # nodes that were merged (we can get relative order only relative to the class the dependencies relate to) + for class_dep in class_dependencies: + if class_dep in self.modular_file_start_lines: + merged_dependencies.append(class_dep) + else: + original_dependencies.append(class_dep) + # Sort both list according to the order in their respective file + original_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines[x]) + merged_dependencies = sorted(merged_dependencies, key=lambda x: self.modular_file_start_lines[x]) + + # Add all original node first, then merged ones + for dep in original_dependencies + merged_dependencies: + remaining_dependencies.remove(dep) + relative_order[dep] = idx + idx += 1 + # Add the class itself + remaining_dependencies.remove(class_name) + relative_order[class_name] = idx + idx += 1 - # Fix the post_init() that has to be last - for i, node in enumerate(deduplicated_new_body): - code = self.python_module.code_for_node(node) - comment_less_code = re.sub(r"#.*", "", code).strip() - comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip() - if "self.post_init(" in comment_less_code and i < len(deduplicated_new_body) - 1: - # Remove it and add it again at the end - deduplicated_new_body.pop(i) - deduplicated_new_body.append(node) - break + # Now add what still remains + remaining_dependencies = tuple(remaining_dependencies) + original_dependencies = [] + merged_dependencies = [] + for dep in remaining_dependencies: + if dep in self.modular_file_start_lines: + merged_dependencies.append(dep) + else: + original_dependencies.append(dep) + # Sort both list according to the order in their respective file + original_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines[x]) + merged_dependencies = sorted(merged_dependencies, key=lambda x: self.modular_file_start_lines[x]) + + # Add all original node first, then merged ones + for dep in original_dependencies + merged_dependencies: + relative_order[dep] = idx + idx += 1 - return deduplicated_new_body + return relative_order - def _fix_init_location(self, new_body): - """Fix the location of the super()__init__ in the new body, if we had new statements before it.""" - start_index = 0 - for i, node in enumerate(new_body): - if m.matches(node, DOCSTRING_NODE) and i == start_index: - start_index += 1 - continue - code = self.python_module.code_for_node(node) - comment_less_code = re.sub(r"#.*", "", code).strip() - comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip() - if "super().__init__" in comment_less_code and i > start_index: - # Remove it and add it again at the top after the docstrings - node = new_body.pop(i) - new_body = new_body[:start_index] + [node] + new_body[start_index:] - break - return new_body + def _merge_functions(self, functions: dict[str, cst.CSTNode], function_call_mapping: dict[str, set]): + """Update the global nodes and function dependency mapping with those from the modular file. - def replace_super_calls(self, node: cst.IndentedBlock, func_name: str) -> cst.CSTNode: - """Updates the body of the input `node`'s `func_name` function by replacing calls - to super().func_name() with the source code of the parent class' `func_name`. - It keeps everything that is defined before `super().func_name()`. + Merging rule: if any function with the same name was redefined in the modular, use it and its dependencies + instead of the original ones (this may mean to add new functions as well, if any redefined function uses a new one). """ - self.has_docstring = False - parent_has_docstring = False - if func_name in self.original_methods: - parent_has_docstring = m.matches(self.original_methods[func_name].body.body[0], DOCSTRING_NODE) - new_body = [] - has_super_call = False + # Add/overwrite all needed function nodes and dependencies + self.functions.update(functions) + self.function_call_dependency_mapping.update(function_call_mapping) - for i, expr in enumerate(node.body): - if is_call_to_super(expr, func_name): - has_super_call = True - new_body.extend(self.update_body(self.original_methods[func_name].body.body, node.body[i + 1 :])) - new_body = self._fix_init_location(new_body) - else: - expr = expr.visit(self.transformer) - if m.matches(expr, DOCSTRING_NODE): - self.has_docstring = True - if parent_has_docstring: # actually here we ought to de-duplicate? - original_docstring = self.original_methods[func_name].body.body[0].body[0].value.value - updated_docstring = expr.body[0].value.value - merged_doc = merge_docstrings(original_docstring, updated_docstring) - new_node = [expr.with_changes(body=[cst.Expr(value=cst.SimpleString(value=merged_doc))])] - else: - new_node = [expr] - new_body.extend(new_node) - elif not m.matches(expr, m.SimpleStatementLine(body=[m.Del()])) and not has_super_call: - new_body.append(expr) - if not self.has_docstring and parent_has_docstring: - new_body = [self.original_methods[func_name].body.body[0]] + new_body - return node.with_changes(body=new_body) + def _merge_assignments(self, assignments: dict[str, cst.CSTNode]): + """Update the global nodes with the assignment from the modular file. - def leave_FunctionDef(self, original_node: cst.Call, updated_node: cst.Call) -> cst.CSTNode: - if updated_node.name.value in self.updated_methods: - name = updated_node.name.value - new_body = self.replace_super_calls(updated_node.body, name) - return updated_node.with_changes(body=new_body, params=updated_node.params) - return updated_node + Merging rule: if any assignment with the same name was redefined in the modular, we use it ONLY if it is + in `ASSIGNMENTS_TO_KEEP`. Otherwise, we use the original value. This rule was chosen to avoid having to rewrite the + big docstrings. + """ + for assignment, node in assignments.items(): + if assignment in ASSIGNMENTS_TO_KEEP or assignment not in self.assignments: + self.assignments[assignment] = node - def leave_Return(self, original_node: cst.Return, updated_node: cst.Return) -> cst.CSTNode: - """ "When a return statement is reached, it is replaced with the unrolled super code""" - if m.matches(updated_node.value, m.Call(func=m.Attribute(attr=m.Name("super")))): - func_def = self.get_metadata(ParentNodeProvider, original_node) - if m.matched(func_def, m.FunctionDef()) and func_def.name.value in self.original_methods: - updated_return_value = updated_node.value.with_changes( - args=[ - cst.Arg( - value=cst.Call(func=cst.Name("super"), args=[cst.Arg(value=cst.Name(func_def.name.value))]) - ) - ] - ) - return updated_node.with_changes(value=updated_return_value) - return updated_node + def merge_modular_dependencies(self, functions, function_mapping, assignments, start_lines): + """Merge both functions and assignments from the modular definitions into the current module file, + then compute the relative order of all nodes.""" + self._merge_functions(functions, function_mapping) + self._merge_assignments(assignments) + self.modular_file_start_lines = start_lines + @classmethod + def visit_and_merge_dependencies(cls, module: cst.Module, functions, function_mapping, assignments, start_lines) -> "ModelFileMapper": + wrapper = MetadataWrapper(module) + mapper = cls(module) + wrapper.visit(mapper) + # Merge dependencies + mapper.merge_modular_dependencies(functions, function_mapping, assignments, start_lines) + # Create the class dependencies graph + mapper.compute_class_dependencies() + return mapper + -def replace_class_node( - mapper: ModelFileMapper, updated_node: cst.ClassDef, class_name: str, all_bases: List[str] -): +def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef): """ - Given the `class_name`, the `updated_node`'s call to super are unpacked. + Replace a class node which inherits from an imported model-class. This function works in the following way: + - start from the class node of the inherited class + - replace all methods with the same name with the ones defined in the modular + - append all new methods defined in the modular + - replace all calls to super() with the unravelled code | ```python | | ```python | class GemmaModel(LlamaModel): | | class GemmaModel(nn.Module): @@ -760,6 +761,9 @@ def replace_class_node( | self.post_init() | ``` """ + all_bases = [k.value.value for k in class_node.bases] + class_name = class_node.name.value + original_node = mapper.classes[class_name] original_methods = { f.name.value if hasattr(f, "name") else mapper.python_module.code_for_node(f): f @@ -767,7 +771,7 @@ def replace_class_node( } updated_methods = { f.name.value if hasattr(f, "name") else mapper.python_module.code_for_node(f): f - for f in updated_node.body.body + for f in class_node.body.body } end_meth = [] @@ -810,7 +814,7 @@ def replace_class_node( end_meth.append(func) # Port new methods that are defined only in modular-file and append at the end - for func in updated_node.body.body: + for func in class_node.body.body: name = func.name.value if hasattr(func, "name") else mapper.python_module.code_for_node(func) if m.matches(func, DOCSTRING_NODE): # This processes the docstring of the class! # Extract the original docstring @@ -832,6 +836,7 @@ def replace_class_node( assign_targets[target] = func end_meth = docstring_node + list(assign_targets.values()) + end_meth + # Replace the calls to `super()` with the unrolled code result_node = original_node.with_changes(body=cst.IndentedBlock(body=end_meth)) temp_module = cst.Module(body=[result_node]) new_module = MetadataWrapper(temp_module) @@ -841,10 +846,9 @@ def replace_class_node( new_replacement_body = new_replacement_class.body[0].body # get the indented block # Use decorators redefined in `modular_xxx.py` if any - new_decorators = updated_node.decorators if len(updated_node.decorators) > 0 else original_node.decorators - + new_decorators = class_node.decorators if len(class_node.decorators) > 0 else original_node.decorators # Always use the new name of the class (in case we use e.g. `ColPaliForRetrieval` inheriting from `PaliGemmaForConditionalGeneration`) - name = updated_node.name + name = class_node.name return original_node.with_changes(body=new_replacement_body, decorators=new_decorators, name=name) @@ -1081,7 +1085,6 @@ class node based on the inherited classes if needed. raise ValueError( f"{class_name} was defined with more than 1 model-specific super class. This is unsupported. We found {*bases,}." ) - all_bases = [k.value.value for k in node.bases] file_type = find_file_type(class_name) file_to_update = files[file_type] @@ -1103,7 +1106,7 @@ class node based on the inherited classes if needed. mapper = self.visited_modules[super_file_name] # Create the new class node - updated_node = replace_class_node(mapper, node, class_name, all_bases) + updated_node = replace_class_node(mapper, node) # The node was modified -> look for all dependencies (recursively) of the new node new_node_dependencies = ClassDependencyMapper.dependencies_for_new_node(updated_node, mapper) From 8f3b764ce1c6fc154e31e4108cba53682099085b Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 29 Oct 2024 10:19:21 +0100 Subject: [PATCH 08/40] finalize imports --- utils/modular_model_converter.py | 214 +++++++++++++++++++++---------- 1 file changed, 144 insertions(+), 70 deletions(-) diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index c66247d4105..53d04e88e2d 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -17,14 +17,14 @@ import importlib import os import re -from collections import defaultdict, deque -from typing import Dict, List, Optional, Set from abc import ABC, abstractmethod +from collections import defaultdict, deque +from typing import Dict, Set import libcst as cst from check_copies import run_ruff from create_dependency_mapping import find_priority_list -from libcst import ClassDef, CSTTransformer, CSTVisitor +from libcst import ClassDef, CSTVisitor from libcst import matchers as m from libcst.metadata import MetadataWrapper, ParentNodeProvider, PositionProvider, ScopeProvider @@ -54,6 +54,7 @@ def get_module_source_from_name(module_name: str) -> str: source_code = file.read() return source_code + class ReplaceNameTransformer(m.MatcherDecoratableTransformer): """A transformer that replaces `old_name` with `new_name` in comments, string and any references. It should take into account name like `MyNewModel`, or `my_new_model`. Without using the AUTO_MAPPING. @@ -466,10 +467,12 @@ def forward(...): "_CHECKPOINT_FOR_DOC", } + class ClassDependencyMapper(CSTVisitor): """A visitor which is designed to analyze a single class node to get all its dependencies that are mutual with `global_names`. This class is used through the 2 convenient class methods. """ + METADATA_DEPENDENCIES = (ParentNodeProvider,) def __init__(self, class_name: str, global_names: set | None): @@ -510,6 +513,7 @@ class ModuleMapper(CSTVisitor, ABC): """An abstract visitor class which analyses a module, creating a mapping of dependencies for classes and functions. It defines common visiting patterns between the modular file and the model-specific module files that will be visited. """ + METADATA_DEPENDENCIES = (ParentNodeProvider, PositionProvider) def __init__(self, python_module: cst.Module): @@ -528,17 +532,17 @@ def visit_SimpleStatementLine(self, node): Global Assigns like `GEMMA_INPUT_DOCSTRING = 'THIS IS THE INPUT' and all import statements are extracted and saved in their corresponding dict. They are then used when updating dependency mappings. """ - if m.matches(node, m.SimpleStatementLine(body=[m.Assign()])) and m.matches( - self.get_metadata(cst.metadata.ParentNodeProvider, node), m.Module() - ): - left_hand_side = node.body[0].targets[0].target - if hasattr(left_hand_side, "value"): - self.assignments[left_hand_side.value] = node - else: - for idx, target in enumerate(list(left_hand_side.elements)): - self.assignments[target.value.value] = node.body[0].value.elements[idx].value - if m.matches(node, m.SimpleStatementLine(body=[m.Import() | m.ImportFrom()])): - self.imports.append(node) + parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) + if m.matches(parent_node, m.Module()): + if m.matches(node, m.SimpleStatementLine(body=[m.Assign()])): + left_hand_side = node.body[0].targets[0].target + if hasattr(left_hand_side, "value"): + self.assignments[left_hand_side.value] = node + else: + for idx, target in enumerate(list(left_hand_side.elements)): + self.assignments[target.value.value] = node.body[0].value.elements[idx].value + elif m.matches(node, m.SimpleStatementLine(body=[m.Import() | m.ImportFrom()])): + self.imports.append(node) def visit_FunctionDef(self, node): parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) @@ -551,7 +555,7 @@ def leave_FunctionDef(self, node): if m.matches(parent_node, m.Module()): self.current_function = None - def leave_If(self, node): + def visit_If(self, node): for stmt in node.body.body: if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])): self.imports.append(node) @@ -606,7 +610,7 @@ def augment_dependencies_with_functions(self, dependencies: set) -> set: if dep in self.function_call_recursive_dependency_mapping.keys(): new_dependencies.update(self.function_call_recursive_dependency_mapping[dep]) return new_dependencies - + def compute_class_dependencies(self): """For each visited class, find its dependencies based on visited the current file + potential merged dependencies. Note: This function takes care of updating `global_nodes` and `function_call_recursive_dependency_mapping` as well after the @@ -626,7 +630,7 @@ def compute_class_dependencies(self): @abstractmethod def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]: pass - + class ModelFileMapper(ModuleMapper): """A mapper designed for model-specific files (i.e. a `transformers.models.xxx` file). When encountering such a file @@ -636,18 +640,19 @@ class ModelFileMapper(ModuleMapper): def __init__(self, python_module: cst.Module): super().__init__(python_module) - + def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]: - """Compute the relative order that the `missing_dependencies` should have between themselves in the output file. - """ + """Compute the relative order that the `missing_dependencies` should have between themselves in the output file.""" relative_order = {} idx = 0 - classes = sorted([dep for dep in tuple(missing_dependencies) if dep in self.classes], key=lambda x: self.start_lines[x]) + classes = sorted( + [dep for dep in tuple(missing_dependencies) if dep in self.classes], key=lambda x: self.start_lines[x] + ) # This is because for merged dependencies, we only have relative order in the other visited file, so we need # to track dependency order relative to a given class if len(classes) > 0 and not hasattr(self, "class_dependency_mapping"): raise ValueError("Cannot correctly find the relative order of the dependencies.") - + remaining_dependencies = missing_dependencies.copy() # Start by tracking relative order class by class @@ -665,7 +670,7 @@ def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]: # Sort both list according to the order in their respective file original_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines[x]) merged_dependencies = sorted(merged_dependencies, key=lambda x: self.modular_file_start_lines[x]) - + # Add all original node first, then merged ones for dep in original_dependencies + merged_dependencies: remaining_dependencies.remove(dep) @@ -688,7 +693,7 @@ def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]: # Sort both list according to the order in their respective file original_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines[x]) merged_dependencies = sorted(merged_dependencies, key=lambda x: self.modular_file_start_lines[x]) - + # Add all original node first, then merged ones for dep in original_dependencies + merged_dependencies: relative_order[dep] = idx @@ -699,7 +704,7 @@ def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]: def _merge_functions(self, functions: dict[str, cst.CSTNode], function_call_mapping: dict[str, set]): """Update the global nodes and function dependency mapping with those from the modular file. - Merging rule: if any function with the same name was redefined in the modular, use it and its dependencies + Merging rule: if any function with the same name was redefined in the modular, use it and its dependencies instead of the original ones (this may mean to add new functions as well, if any redefined function uses a new one). """ # Add/overwrite all needed function nodes and dependencies @@ -709,7 +714,7 @@ def _merge_functions(self, functions: dict[str, cst.CSTNode], function_call_mapp def _merge_assignments(self, assignments: dict[str, cst.CSTNode]): """Update the global nodes with the assignment from the modular file. - Merging rule: if any assignment with the same name was redefined in the modular, we use it ONLY if it is + Merging rule: if any assignment with the same name was redefined in the modular, we use it ONLY if it is in `ASSIGNMENTS_TO_KEEP`. Otherwise, we use the original value. This rule was chosen to avoid having to rewrite the big docstrings. """ @@ -725,7 +730,9 @@ def merge_modular_dependencies(self, functions, function_mapping, assignments, s self.modular_file_start_lines = start_lines @classmethod - def visit_and_merge_dependencies(cls, module: cst.Module, functions, function_mapping, assignments, start_lines) -> "ModelFileMapper": + def visit_and_merge_dependencies( + cls, module: cst.Module, functions, function_mapping, assignments, start_lines + ) -> "ModelFileMapper": wrapper = MetadataWrapper(module) mapper = cls(module) wrapper.visit(mapper) @@ -734,7 +741,7 @@ def visit_and_merge_dependencies(cls, module: cst.Module, functions, function_ma # Create the class dependencies graph mapper.compute_class_dependencies() return mapper - + def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef): """ @@ -770,8 +777,7 @@ def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef): for f in original_node.body.body } updated_methods = { - f.name.value if hasattr(f, "name") else mapper.python_module.code_for_node(f): f - for f in class_node.body.body + f.name.value if hasattr(f, "name") else mapper.python_module.code_for_node(f): f for f in class_node.body.body } end_meth = [] @@ -897,6 +903,7 @@ def find_file_type(class_name: str) -> str: file_type = "modeling" return file_type + # These top-level variables will always appear the very beginning of the file, in the order they are defined in # this list (this is to avoid having variables at weird places, even if they are not used before) VARIABLES_AT_THE_BEGINNING = [ @@ -905,22 +912,92 @@ def find_file_type(class_name: str) -> str: "_CONFIG_FOR_DOC", ] +def get_module_name(node: cst.ImportFrom) -> str: + """Recursively get the fully dotted name of a module in a cst.ImportFrom.""" + if m.matches(node, m.Name()): + return node.value + elif m.matches(node, m.Attribute()): + # Recursively get the full name for attributes + return f"{get_module_name(node.value)}.{node.attr.value}" + return "" + +def append_new_import_node(node: cst.CSTNode, unused_imports: set[str], imports_to_keep: dict[str, cst.CSTNode], current_idx: int): + """Insert the new `node` to the dict of `imports_to_keep` in-place, if it is not part of the `unused_imports`. + This function takes cares of aggregating similar ImportFrom, i.e. if we ever saw a statement such as + `from typing import Any`, and later another one `from typing import List`, we will aggregate as + `from typing import Any, List` in a single statement. + """ + import_node = node.body[0] + if m.matches(import_node, m.ImportFrom()): + module_name = get_module_name(import_node.module) + else: + module_name = current_idx + + # If we have a new import from with the same module name, write new names to the same import statement + names_to_keep = [name for name in imports_to_keep[module_name].body[0].names] if module_name in imports_to_keep else [] + for name in import_node.names: + name_value = name.evaluated_name + if name_value not in unused_imports: + names_to_keep.append(name.with_changes(comma=cst.MaybeSentinel.DEFAULT)) + if len(names_to_keep) > 0: + imports_to_keep[module_name] = node.with_changes(body=[import_node.with_changes(names=names_to_keep)]) + + +def get_needed_imports(body: dict[str, dict], all_imports: list[cst.CSTNode]) -> list[cst.CSTNode]: + """Get all the imports needed in the `body`, from the list of `all_imports`. + Note: we need to use `isinstance` on assignements, m.matches apparently does not work here yet! + """ + new_body = [k[1]["node"] for k in sorted(body.items(), key=lambda x: x[1]["insert_idx"])] + wrapper = MetadataWrapper(cst.Module(body=all_imports + new_body)) + scopes = set(wrapper.resolve(ScopeProvider).values()) + unused_imports = set() + import_ref_count = {} + for scope in scopes: + for assignment in scope.assignments: + node = assignment.node + if isinstance(assignment, cst.metadata.Assignment) and isinstance(node, (cst.Import, cst.ImportFrom)): + ref_count = len(assignment.references) + name = assignment.name + # Similar imports may be redefined, and only used between their 1st and 2nd definition + # so if we already have a ref count > 0, the imports is not unused + if (ref_count == 0 and import_ref_count.get(name, -1) <= 0) or name in body.keys(): + unused_imports.add(name) + import_ref_count[name] = ref_count + + # Note that dicts implicitly keep the order of insertion + imports_to_keep = {} + for idx, node in enumerate(all_imports): + if m.matches(node, m.If()): + new_statements = {} + for second_idx, stmt_node in enumerate(node.body.body): + append_new_import_node(stmt_node, unused_imports, new_statements, second_idx) + if len(new_statements) > 0: + imports_to_keep[idx] = node.with_changes(body=node.body.with_changes(body=list(new_statements.values()))) + else: + append_new_import_node(node, unused_imports, imports_to_keep, idx) + + return list(imports_to_keep.values()) + + class ModularFileMapper(ModuleMapper): """This is a Mapper for a modular file. It visits the whole file, recording dependency, then visits all model-specific files that should be visited, and manages their mutual dependencies. Calling the method `create_modules()` after visit will create all modules based on this modular file. """ + def __init__(self, python_module, new_name, given_old_name=None, given_new_name=None): super().__init__(python_module) + # fmt: off self.model_name = new_name # name of the model being defined. Should be in the format of `llama` or `layout_xlm` or `phi3` self.given_old_name = given_old_name self.given_new_name = given_new_name - self.model_specific_imported_objects: Dict[str, str] = {} # e.g. {"LlamaModel": "transformers.models.llama.modeling_llama"} + self.model_specific_imported_objects: Dict[str, str] = {} # e.g. {"LlamaModel": "transformers.models.llama.modeling_llama"} self.model_specific_modules: Dict[str, cst.Module] = {} # e.g. {"transformers.models.llama.modeling_llama": cst.Module} - self.match_patterns = "|".join(list(TYPE_TO_FILE_TYPE.values()).append("modeling")) + self.match_patterns = "|".join(list(TYPE_TO_FILE_TYPE.values()) + ["modeling"]) self.all_all_to_add = {} + # fmt: on def visit_ImportFrom(self, node: cst.ImportFrom) -> None: """When visiting imports from model-specific files (i.e. `transformers.models.xxx`) we get the code, parse it, @@ -964,17 +1041,14 @@ def visit_SimpleStatementLine(self, node): ) if m.matches(parent_node, m.Module()): if m.matches(node, m.SimpleStatementLine(body=[m.Import()])): - if node not in self.imports: - self.imports.append(node) + self.imports.append(node) elif m.matches(node, m.SimpleStatementLine(body=[m.ImportFrom()])): full_statement = self.python_module.code_for_node(node.body[0].module) - if ( + if not ( # OR MATCH ..llama.modeling_llama re.search(rf"(transformers\.models\..|..)*\.({self.match_patterns})_.*", full_statement) and "auto.modeling_auto" not in full_statement ): - return - if node not in self.imports: self.imports.append(node) elif m.matches(node, simple_top_level_assign_structure): assigned_variable = node.body[0].targets[0].target.value @@ -994,7 +1068,9 @@ def visit_SimpleStatementLine(self, node): all_all_to_add[file] += [class_name] for file, new_alls in all_all_to_add.items(): new_node = assign_node.with_changes( - value=cst.List(elements=[cst.Element(value=cst.SimpleString(value=k)) for k in new_alls]) + value=cst.List( + elements=[cst.Element(value=cst.SimpleString(value=k)) for k in new_alls] + ) ) self.all_all_to_add[file] = node.with_changes(body=[new_node]) @@ -1013,18 +1089,25 @@ def leave_Module(self, node): self.visited_modules = {} for file, module in self.model_specific_modules.items(): file_model_name = re.search(r"models\.\w*?\.\w*?_(\S*)", file).groups()[0] - renamer = ReplaceNameTransformer(file_model_name, self.model_name, self.given_old_name, self.given_new_name) + renamer = ReplaceNameTransformer( + file_model_name, self.model_name, self.given_old_name, self.given_new_name + ) renamed_module = module.visit(renamer) - self.visited_modules[file] = ModelFileMapper.visit_and_merge_dependencies(renamed_module, self.functions, self.function_call_dependency_mapping, - self.assignments, self.start_lines) + self.visited_modules[file] = ModelFileMapper.visit_and_merge_dependencies( + renamed_module, + self.functions, + self.function_call_dependency_mapping, + self.assignments, + self.start_lines, + ) - # In turn, we need to add the imported functions/assignments to the dependencies of the modular mapper, using the + # In turn, we need to add the imported functions/assignments to the dependencies of the modular mapper, using the # definitions found in the visited files self.merge_model_specific_imports(self.visited_modules) # Re-assign all nodes self.global_nodes = {**self.assignments, **self.classes, **self.functions} - + def merge_model_specific_imports(self, visited_modules): """Merge the model-specific imported functions and assignments to the modular nodes and dependency graph, based on the visited files.""" @@ -1043,15 +1126,14 @@ def merge_model_specific_imports(self, visited_modules): for dep in dependencies: self.added_objects_file_mapping[dep] = file self.functions[dep] = visited_module.global_nodes[dep] - + # Add assignments elif object_name in visited_module.assignments and object_name not in self.assignments: self.added_objects_file_mapping[object_name] = file self.assignments[object_name] = visited_module.assignments[object_name] def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]: - """Compute the relative order that the `missing_dependencies` should have between themselves in the output file. - """ + """Compute the relative order that the `missing_dependencies` should have between themselves in the output file.""" relative_order = {} idx = 0 @@ -1068,7 +1150,7 @@ def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]: for file, dependencies in other_files_dependencies.items(): sorted_dependencies = sorted(dependencies, key=lambda x: self.start_lines_file_mapping[file][x]) all_dependencies += sorted_dependencies - + # Add all original node first, then merged ones (one file at a time) for dep in all_dependencies: relative_order[dep] = idx @@ -1094,14 +1176,6 @@ class node based on the inherited classes if needed. super_class = bases[0] super_file_name = self.model_specific_imported_objects[super_class] - model_name = re.search(r"models\.\w*?\.\w*?_(\S*)", super_file_name) - if model_name: - model_name = model_name.groups()[0] - else: - raise ValueError( - f"Tried parsing the name of the imported package from {super_file_name}, could not extract the model name" - ) - # Get the mapper corresponding to the inherited class mapper = self.visited_modules[super_file_name] @@ -1118,16 +1192,15 @@ class node based on the inherited classes if needed. relative_dependency_order = mapper.compute_relative_order(all_dependencies_to_add) nodes_to_add = { - dep: (relative_dependency_order[dep], mapper.global_nodes[dep]) - for dep in all_dependencies_to_add + dep: (relative_dependency_order[dep], mapper.global_nodes[dep]) for dep in all_dependencies_to_add } # No super class, just check functions and assignments dependency in the imports from other model-specific files else: updated_node = node - # The node was NOT modified -> no need to look for recursive dependencies + # The node was NOT modified -> no need to look for dependencies recursively all_dependencies_to_add = ClassDependencyMapper.dependencies_for_node(updated_node, self.global_nodes) - + relative_dependency_order = self.compute_relative_order(all_dependencies_to_add) nodes_to_add = { dep: (relative_dependency_order[dep], self.global_nodes[dep]) @@ -1166,20 +1239,21 @@ def create_modules(self) -> dict[str, cst.Module]: idx = current_file_indices[file_type] files[file_type]["__all__"] = {"insert_idx": idx, "node": node} - # Merge imports - # TODO: use scope solution instead - imports = {self.python_module.code_for_node(k): k for k in self.imports} - dependency_imports = {file_type: imports.copy() for file_type in files} - for super_file_name, visiter in self.visited_modules.items(): - file_type = re.search(r"models?\.\w*?\.(\w*?)_", super_file_name).groups()[0] - dependency_imports[file_type].update( - {self.python_module.code_for_node(k): k for k in visiter.imports.values()} - ) + # Aggregate all the imports statements (we look for duplicates with the code_for_node, not the nodes themselves) + all_imports = self.imports.copy() + all_imports_code = {self.python_module.code_for_node(node) for node in all_imports} + for file, mapper in self.visited_modules.items(): + new_imports = [node for node in mapper.imports if mapper.python_module.code_for_node(node) not in all_imports_code] + new_imports_code = {mapper.python_module.code_for_node(node) for node in new_imports} + all_imports.extend(new_imports) + all_imports_code.update(new_imports_code) + # Find the correct imports, and write the new modules for file, body in files.items(): new_body = [k[1]["node"] for k in sorted(body.items(), key=lambda x: x[1]["insert_idx"])] - new_body = list(dependency_imports[file].values()) + new_body - new_module = cst.Module(body=[*new_body], header=self.python_module.header) + needed_imports = get_needed_imports(body, all_imports) + full_module = needed_imports + new_body + new_module = cst.Module(body=full_module, header=self.python_module.header) files[file] = new_module return files From 1084ca70633cf0e3dd0c5afa73abf339c43f7e71 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 29 Oct 2024 10:49:32 +0100 Subject: [PATCH 09/40] imports --- .../modular_llava_next_video.py | 1 - utils/modular_model_converter.py | 21 +++++++++++++------ 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/llava_next_video/modular_llava_next_video.py b/src/transformers/models/llava_next_video/modular_llava_next_video.py index 2025140bb6e..450685e3606 100644 --- a/src/transformers/models/llava_next_video/modular_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modular_llava_next_video.py @@ -309,7 +309,6 @@ def get_video_features( video_features = torch.split(video_features, frames, dim=0) return video_features - @replace_return_docstrings(output_type=LlavaNextVideoCausalLMOutputWithPast, config_class="LlavaNextVideoConfig") def forward( self, input_ids: torch.LongTensor = None, diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index 53d04e88e2d..3655b1f0716 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -959,7 +959,7 @@ def get_needed_imports(body: dict[str, dict], all_imports: list[cst.CSTNode]) -> ref_count = len(assignment.references) name = assignment.name # Similar imports may be redefined, and only used between their 1st and 2nd definition - # so if we already have a ref count > 0, the imports is not unused + # so if we already have a ref count > 0, the imports is actually used if (ref_count == 0 and import_ref_count.get(name, -1) <= 0) or name in body.keys(): unused_imports.add(name) import_ref_count[name] = ref_count @@ -976,7 +976,15 @@ def get_needed_imports(body: dict[str, dict], all_imports: list[cst.CSTNode]) -> else: append_new_import_node(node, unused_imports, imports_to_keep, idx) - return list(imports_to_keep.values()) + protected_import_nodes = [node for node in imports_to_keep.values() if m.matches(node, m.If())] + usual_import_nodes = [node for node in imports_to_keep.values() if not m.matches(node, m.If())] + # If the same import is both protected and unprotected, only keep the protected one + for protected_node in protected_import_nodes: + for stmt_node in protected_node.body.body: + usual_import_nodes = [node for node in usual_import_nodes if node.body[0] != stmt_node.body[0]] + + # Protected imports always appear at the end of all imports + return usual_import_nodes + protected_import_nodes class ModularFileMapper(ModuleMapper): @@ -1239,12 +1247,13 @@ def create_modules(self) -> dict[str, cst.Module]: idx = current_file_indices[file_type] files[file_type]["__all__"] = {"insert_idx": idx, "node": node} - # Aggregate all the imports statements (we look for duplicates with the code_for_node, not the nodes themselves) + # Aggregate all the imports statements (we look for duplicates with the code_for_node, not the nodes themselves because + # they are wrapped in SimpleStatementLine or If which could have different newlines, blanks etc) all_imports = self.imports.copy() - all_imports_code = {self.python_module.code_for_node(node) for node in all_imports} + all_imports_code = {self.python_module.code_for_node(node).strip() for node in all_imports} for file, mapper in self.visited_modules.items(): - new_imports = [node for node in mapper.imports if mapper.python_module.code_for_node(node) not in all_imports_code] - new_imports_code = {mapper.python_module.code_for_node(node) for node in new_imports} + new_imports = [node for node in mapper.imports if mapper.python_module.code_for_node(node).strip() not in all_imports_code] + new_imports_code = {mapper.python_module.code_for_node(node).strip() for node in new_imports} all_imports.extend(new_imports) all_imports_code.update(new_imports_code) From 39a0a89788d19c161ca5a9e875034225bd822075 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 29 Oct 2024 10:54:02 +0100 Subject: [PATCH 10/40] Update modular_model_converter.py --- utils/modular_model_converter.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index 3655b1f0716..a801cd6d244 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -912,6 +912,7 @@ def find_file_type(class_name: str) -> str: "_CONFIG_FOR_DOC", ] + def get_module_name(node: cst.ImportFrom) -> str: """Recursively get the fully dotted name of a module in a cst.ImportFrom.""" if m.matches(node, m.Name()): @@ -921,9 +922,12 @@ def get_module_name(node: cst.ImportFrom) -> str: return f"{get_module_name(node.value)}.{node.attr.value}" return "" -def append_new_import_node(node: cst.CSTNode, unused_imports: set[str], imports_to_keep: dict[str, cst.CSTNode], current_idx: int): + +def append_new_import_node( + node: cst.CSTNode, unused_imports: set[str], imports_to_keep: dict[str, cst.CSTNode], current_idx: int +): """Insert the new `node` to the dict of `imports_to_keep` in-place, if it is not part of the `unused_imports`. - This function takes cares of aggregating similar ImportFrom, i.e. if we ever saw a statement such as + This function takes cares of aggregating similar ImportFrom, i.e. if we ever saw a statement such as `from typing import Any`, and later another one `from typing import List`, we will aggregate as `from typing import Any, List` in a single statement. """ @@ -934,7 +938,8 @@ def append_new_import_node(node: cst.CSTNode, unused_imports: set[str], imports_ module_name = current_idx # If we have a new import from with the same module name, write new names to the same import statement - names_to_keep = [name for name in imports_to_keep[module_name].body[0].names] if module_name in imports_to_keep else [] + names_to_keep = list(imports_to_keep[module_name].body[0].names) if module_name in imports_to_keep else [] + for name in import_node.names: name_value = name.evaluated_name if name_value not in unused_imports: @@ -972,7 +977,9 @@ def get_needed_imports(body: dict[str, dict], all_imports: list[cst.CSTNode]) -> for second_idx, stmt_node in enumerate(node.body.body): append_new_import_node(stmt_node, unused_imports, new_statements, second_idx) if len(new_statements) > 0: - imports_to_keep[idx] = node.with_changes(body=node.body.with_changes(body=list(new_statements.values()))) + imports_to_keep[idx] = node.with_changes( + body=node.body.with_changes(body=list(new_statements.values())) + ) else: append_new_import_node(node, unused_imports, imports_to_keep, idx) @@ -1212,7 +1219,8 @@ class node based on the inherited classes if needed. relative_dependency_order = self.compute_relative_order(all_dependencies_to_add) nodes_to_add = { dep: (relative_dependency_order[dep], self.global_nodes[dep]) - for dep in all_dependencies_to_add if dep not in file_to_update.keys() + for dep in all_dependencies_to_add + if dep not in file_to_update.keys() } # Add the class node itself to the nodes to add @@ -1252,7 +1260,11 @@ def create_modules(self) -> dict[str, cst.Module]: all_imports = self.imports.copy() all_imports_code = {self.python_module.code_for_node(node).strip() for node in all_imports} for file, mapper in self.visited_modules.items(): - new_imports = [node for node in mapper.imports if mapper.python_module.code_for_node(node).strip() not in all_imports_code] + new_imports = [ + node + for node in mapper.imports + if mapper.python_module.code_for_node(node).strip() not in all_imports_code + ] new_imports_code = {mapper.python_module.code_for_node(node).strip() for node in new_imports} all_imports.extend(new_imports) all_imports_code.update(new_imports_code) From 3ba751a74d7b0f2be34184dd2c8be939a430ef54 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 29 Oct 2024 11:33:03 +0100 Subject: [PATCH 11/40] Better renaming to avoid visiting same file multiple times --- utils/modular_model_converter.py | 88 ++++++++++++-------------------- 1 file changed, 32 insertions(+), 56 deletions(-) diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index a801cd6d244..e577313fb77 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -54,6 +54,24 @@ def get_module_source_from_name(module_name: str) -> str: source_code = file.read() return source_code +def preserve_case_replace(text, patterns: dict, default_name: str): + # Create a regex pattern to match all variations + regex_pattern = "|".join(re.escape(key) for key in patterns.keys()) + compiled_regex = re.compile(regex_pattern, re.IGNORECASE) + + def replace(match): + word = match.group(0) + result = patterns.get(word, default_name) + return result + + return compiled_regex.sub(replace, text) + +def convert_to_camelcase(text, old_name: str, default_old_name: str): + # Regex pattern to match consecutive uppercase letters and lowercase the first set + result = re.sub( + rf"^({old_name})(?=[a-z]+)", lambda m: default_old_name, text, flags=re.IGNORECASE, count=1 + ) + return result class ReplaceNameTransformer(m.MatcherDecoratableTransformer): """A transformer that replaces `old_name` with `new_name` in comments, string and any references. @@ -92,34 +110,15 @@ def __init__( if self.default_old_name.isupper(): self.default_old_name = self.default_old_name.capitalize() - def preserve_case_replace(self, text): - # Create a regex pattern to match all variations - regex_pattern = "|".join(re.escape(key) for key in self.patterns.keys()) - compiled_regex = re.compile(regex_pattern, re.IGNORECASE) - - def replace(match): - word = match.group(0) - result = self.patterns.get(word, self.default_name) - return result - - return compiled_regex.sub(replace, text) - - def convert_to_camelcase(self, text): - # Regex pattern to match consecutive uppercase letters and lowercase the first set - result = re.sub( - rf"^({self.old_name})(?=[a-z]+)", lambda m: self.default_old_name, text, flags=re.IGNORECASE, count=1 - ) - return result - @m.leave(m.Name() | m.SimpleString() | m.Comment()) def replace_name(self, original_node, updated_node): if re.findall(r"# Copied from", updated_node.value): return cst.RemoveFromParent() - update = self.preserve_case_replace(updated_node.value) + update = preserve_case_replace(updated_node.value, self.patterns, self.default_name) return updated_node.with_changes(value=update) def leave_ClassDef(self, original_node, updated_node): - return updated_node.with_changes(name=cst.Name(self.convert_to_camelcase(updated_node.name.value))) + return updated_node.with_changes(name=cst.Name(convert_to_camelcase(updated_node.name.value, self.old_name, self.default_old_name))) DOCSTRING_NODE = m.SimpleStatementLine( @@ -236,13 +235,12 @@ def merge_docstrings(original_docstring, updated_docstring): class SuperTransformer(cst.CSTTransformer): METADATA_DEPENDENCIES = (ParentNodeProvider,) - def __init__(self, python_module: cst.Module, original_methods, updated_methods, class_name="", all_bases=None): + def __init__(self, python_module: cst.Module, original_methods, updated_methods, all_bases=None): self.python_module = python_module self.original_methods = original_methods self.updated_methods = updated_methods self.all_assign_target = {} self.deleted_targets = {} # child node can delete some arguments - self.class_name = class_name self.all_bases = all_bases or [] self.transformer = ReplaceMethodCallTransformer(set(self.all_bases)) @@ -743,7 +741,7 @@ def visit_and_merge_dependencies( return mapper -def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef): +def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef, renamed_super_class: str): """ Replace a class node which inherits from an imported model-class. This function works in the following way: - start from the class node of the inherited class @@ -769,9 +767,8 @@ def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef): | ``` """ all_bases = [k.value.value for k in class_node.bases] - class_name = class_node.name.value - original_node = mapper.classes[class_name] + original_node = mapper.classes[renamed_super_class] original_methods = { f.name.value if hasattr(f, "name") else mapper.python_module.code_for_node(f): f for f in original_node.body.body @@ -846,9 +843,7 @@ def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef): result_node = original_node.with_changes(body=cst.IndentedBlock(body=end_meth)) temp_module = cst.Module(body=[result_node]) new_module = MetadataWrapper(temp_module) - new_replacement_class = new_module.visit( - SuperTransformer(temp_module, original_methods, updated_methods, class_name, all_bases) - ) + new_replacement_class = new_module.visit(SuperTransformer(temp_module, original_methods, updated_methods, all_bases)) new_replacement_body = new_replacement_class.body[0].body # get the indented block # Use decorators redefined in `modular_xxx.py` if any @@ -867,32 +862,6 @@ def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef): "FeatureExtractor": "feature_extractor", } - -def get_new_part(class_name, base_class): - """ - When `MyClassNameAttention` inherits from `MistralAttention`, we need - to process the name to properly find dependencies. - - Here we take what is the same (Attention) and what is different - when finding the dependencies. - """ - common_suffix_len = 0 - for i in range(1, min(len(class_name), len(base_class)) + 1): - if class_name[-i] == base_class[-i]: - common_suffix_len += 1 - else: - break - - if common_suffix_len > 0: - new_part = class_name[:-common_suffix_len] - else: - new_part = class_name - - # Convert the remaining new part to snake_case - snake_case = re.sub(r"(? str: """Based on a class name, find the file type corresponding to the class.""" match_pattern = "|".join(TYPE_TO_FILE_TYPE.keys()) @@ -1102,6 +1071,7 @@ def leave_Module(self, node): # Now, visit every model-specific files found in the imports, and merge their dependencies self.visited_modules = {} + self.renamers = {} for file, module in self.model_specific_modules.items(): file_model_name = re.search(r"models\.\w*?\.\w*?_(\S*)", file).groups()[0] renamer = ReplaceNameTransformer( @@ -1115,6 +1085,8 @@ def leave_Module(self, node): self.assignments, self.start_lines, ) + # We record it so that we can rename classes later the exact same way + self.renamers[file] = renamer # In turn, we need to add the imported functions/assignments to the dependencies of the modular mapper, using the # definitions found in the visited files @@ -1193,9 +1165,13 @@ class node based on the inherited classes if needed. # Get the mapper corresponding to the inherited class mapper = self.visited_modules[super_file_name] + # Rename the super class according to the exact same rule we used when renaming the whole module + renamer = self.renamers[super_file_name] + renamed_super_class = preserve_case_replace(super_class, renamer.patterns, renamer.default_name) + renamed_super_class = convert_to_camelcase(renamed_super_class, renamer.old_name, renamer.default_old_name) # Create the new class node - updated_node = replace_class_node(mapper, node) + updated_node = replace_class_node(mapper, node, renamed_super_class) # The node was modified -> look for all dependencies (recursively) of the new node new_node_dependencies = ClassDependencyMapper.dependencies_for_new_node(updated_node, mapper) From 7416080e389770b75af3c874b35dc92fbd41ae0d Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 29 Oct 2024 11:39:21 +0100 Subject: [PATCH 12/40] start converting files --- .../models/gemma/configuration_gemma.py | 1 - .../models/gemma/modeling_gemma.py | 24 +- .../models/gemma2/configuration_gemma2.py | 2 - .../models/gemma2/modeling_gemma2.py | 35 +-- src/transformers/models/glm/modeling_glm.py | 30 +-- .../modeling_instructblipvideo.py | 198 ++++++++-------- .../modeling_llava_next_video.py | 222 +++++++++--------- .../modular_llava_next_video.py | 1 - 8 files changed, 249 insertions(+), 264 deletions(-) diff --git a/src/transformers/models/gemma/configuration_gemma.py b/src/transformers/models/gemma/configuration_gemma.py index e170803ccca..346f386ba69 100644 --- a/src/transformers/models/gemma/configuration_gemma.py +++ b/src/transformers/models/gemma/configuration_gemma.py @@ -20,7 +20,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - from ...configuration_utils import PretrainedConfig diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 9a4de1022c5..dec4e26fee2 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -23,7 +23,6 @@ from typing import List, Optional, Tuple, Union import torch -import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN @@ -49,7 +48,10 @@ from .configuration_gemma import GemmaConfig +logger = logging.get_logger(__name__) + _CHECKPOINT_FOR_DOC = "google/gemma-7b" +_CONFIG_FOR_DOC = "GemmaConfig" class GemmaRMSNorm(nn.Module): @@ -72,9 +74,6 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.eps}" -logger = logging.get_logger(__name__) - - class GemmaRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -156,13 +155,6 @@ def forward(self, x): return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -390,6 +382,13 @@ def forward( return attn_output, None, past_key_value +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + class GemmaFlashAttention2(GemmaAttention): """ Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stays @@ -624,9 +623,6 @@ def _init_weights(self, module): module.weight.data[module.padding_idx].zero_() -_CONFIG_FOR_DOC = "GemmaConfig" - - GEMMA_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): diff --git a/src/transformers/models/gemma2/configuration_gemma2.py b/src/transformers/models/gemma2/configuration_gemma2.py index 74976bdd340..45006b8ca2f 100644 --- a/src/transformers/models/gemma2/configuration_gemma2.py +++ b/src/transformers/models/gemma2/configuration_gemma2.py @@ -19,8 +19,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - from ...configuration_utils import PretrainedConfig diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 6d61c47619f..aa3a926f2c9 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -23,7 +23,6 @@ import torch import torch.nn as nn -import torch.utils.checkpoint from ...activations import ACT2FN from ...cache_utils import Cache, HybridCache @@ -40,6 +39,7 @@ add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + is_flash_attn_2_available, is_flash_attn_greater_or_equal, is_flash_attn_greater_or_equal_2_10, logging, @@ -48,7 +48,15 @@ from .configuration_gemma2 import Gemma2Config +if is_flash_attn_2_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + + +logger = logging.get_logger(__name__) + + _CHECKPOINT_FOR_DOC = "google/gemma2-7b" +_CONFIG_FOR_DOC = "Gemma2Config" class Gemma2RMSNorm(nn.Module): @@ -86,9 +94,6 @@ def forward(self, x): return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) -logger = logging.get_logger(__name__) - - class Gemma2RotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -117,13 +122,6 @@ def forward(self, x, position_ids, seq_len=None): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -163,6 +161,13 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + class Gemma2Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -198,12 +203,12 @@ def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) - self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None self.rotary_emb = Gemma2RotaryEmbedding( self.head_dim, max_position_embeddings=self.max_position_embeddings, base=self.rope_theta, ) + self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None def forward( self, @@ -495,12 +500,12 @@ def __init__(self, config: Gemma2Config, layer_idx: int): self.self_attn = GEMMA2_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) self.mlp = Gemma2MLP(config) self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.config = config self.is_sliding = not bool(layer_idx % 2) self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.sliding_window = config.sliding_window - self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, @@ -638,9 +643,6 @@ def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False): return config -_CONFIG_FOR_DOC = "Gemma2Config" - - GEMMA2_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): @@ -865,6 +867,7 @@ def forward( attentions=all_self_attns, ) + @torch.no_grad() def _update_causal_mask( self, attention_mask: torch.Tensor, diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 5f8eaf89ed9..248ec402179 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -24,7 +24,6 @@ import torch import torch.nn as nn -import torch.utils.checkpoint from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache @@ -50,7 +49,10 @@ from .configuration_glm import GlmConfig +logger = logging.get_logger(__name__) + _CHECKPOINT_FOR_DOC = "THUDM/glm-4-9b" +_CONFIG_FOR_DOC = "GlmConfig" class GlmRMSNorm(nn.Module): @@ -121,7 +123,16 @@ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: return self.down_proj(up_states) -logger = logging.get_logger(__name__) +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) def rotate_half(x): @@ -172,18 +183,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - class GlmAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -608,9 +607,6 @@ def _init_weights(self, module): module.weight.data[module.padding_idx].zero_() -_CONFIG_FOR_DOC = "GlmConfig" - - GLM_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py index c9f12391666..0a8f383380d 100644 --- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py @@ -24,7 +24,6 @@ from typing import Any, Optional, Tuple, Union import torch -import torch.utils.checkpoint from torch import nn from torch.nn import CrossEntropyLoss @@ -347,104 +346,6 @@ def _init_weights(self, module): module.bias.data.zero_() -INSTRUCTBLIPVIDEO_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`InstructBlipVideoConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -INSTRUCTBLIPVIDEO_VISION_INPUTS_DOCSTRING = r""" - Args: - pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using [`InstructBlipVideoProcessor`]. See - [`InstructBlipVideoProcessor.__call__`] for details. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): - Whether to interpolate the pre-trained position encodings. -""" - -INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING = r""" - Args: - pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using [`InstructBlipVideoProcessor`]. See - [`InstructBlipVideoProcessor.__call__`] for details. - - qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided - to serve as text prompt, which the Q-Former model will encode. - - Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for - details. - - [What are input IDs?](../glossary#input-ids) - - qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of input sequence tokens in the vocabulary of the language model. Input tokens can optionally be - provided to serve as text prompt, which the language model can continue. - - Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for - details. - - [What are input IDs?](../glossary#input-ids) - - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): - Indices of decoder input sequence tokens in the vocabulary of the language model. Only relevant in case an - encoder-decoder language model (like T5) is used. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. [What are decoder input IDs?](../glossary#decoder-input-ids) - - decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. - - Only relevant in case an encoder-decoder language model (like T5) is used. - - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): - Whether to interpolate the pre-trained position encodings. -""" - - class InstructBlipVideoEncoder(nn.Module): """ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a @@ -531,6 +432,24 @@ def forward( ) +INSTRUCTBLIPVIDEO_VISION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`InstructBlipVideoProcessor`]. See + [`InstructBlipVideoProcessor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. +""" + + class InstructBlipVideoVisionModel(InstructBlipVideoPreTrainedModel): main_input_name = "pixel_values" config_class = InstructBlipVideoVisionConfig @@ -1268,6 +1187,87 @@ def forward( ) +INSTRUCTBLIPVIDEO_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`InstructBlipVideoConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`InstructBlipVideoProcessor`]. See + [`InstructBlipVideoProcessor.__call__`] for details. + + qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided + to serve as text prompt, which the Q-Former model will encode. + + Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for + details. + + [What are input IDs?](../glossary#input-ids) + + qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of input sequence tokens in the vocabulary of the language model. Input tokens can optionally be + provided to serve as text prompt, which the language model can continue. + + Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for + details. + + [What are input IDs?](../glossary#input-ids) + + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary of the language model. Only relevant in case an + encoder-decoder language model (like T5) is used. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. [What are decoder input IDs?](../glossary#decoder-input-ids) + + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + Only relevant in case an encoder-decoder language model (like T5) is used. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. +""" + + @add_start_docstrings( """ InstructBlipVideo Model for generating text given an image and an optional text prompt. The model consists of a vision diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index a2328c1d2d9..1eb94508ee4 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -25,7 +25,6 @@ import numpy as np import torch -import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN @@ -33,12 +32,7 @@ from ...image_processing_utils import select_best_resolution from ...modeling_outputs import ModelOutput from ...modeling_utils import PreTrainedModel -from ...utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from ..auto import AutoModel, AutoModelForCausalLM from .configuration_llava_next_video import LlavaNextVideoConfig @@ -48,113 +42,6 @@ _CONFIG_FOR_DOC = "LlavaNextVideoConfig" -def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): - """ - Calculate the shape of the image patch grid after the preprocessing for images of any resolution. - - Args: - image_size (`tuple`): - The size of the input image in the format (width, height). - grid_pinpoints (`List`): - A list containing possible resolutions. Each item in the list should be a tuple or list - of the form `(height, width)`. - patch_size (`int`): - The size of each image patch. - - Returns: - tuple: The shape of the image patch grid in the format (width, height). - """ - if not isinstance(grid_pinpoints, list): - raise TypeError("grid_pinpoints should be a list of tuples or lists") - - # ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate - if not isinstance(image_size, (list, tuple)): - if not isinstance(image_size, (torch.Tensor, np.ndarray)): - raise TypeError( - f"image_size invalid type: {type(image_size)} not valid, should be either list, tuple, np.ndarray or tensor" - ) - image_size = image_size.tolist() - - height, width = select_best_resolution(image_size, grid_pinpoints) - return height // patch_size, width // patch_size - - -def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int): - """ - Calculate the number of patches after the preprocessing for images of any resolution. - - Args: - image_size (`torch.LongTensor` or `np.ndarray` or `Tuple[int, int]`): - The size of the input image in the format (height, width). ? - grid_pinpoints (`List`): - A list containing possible resolutions. Each item in the list should be a tuple or list - of the form `(height, width)`. - patch_size (`int`): - The size of each image patch. - - Returns: - int: the number of patches - """ - if not isinstance(grid_pinpoints, list): - raise TypeError("grid_pinpoints should be a list of tuples or lists") - - # ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate - if not isinstance(image_size, (list, tuple)): - if not isinstance(image_size, (torch.Tensor, np.ndarray)): - raise TypeError(f"image_size invalid type {type(image_size)} with value {image_size}") - image_size = image_size.tolist() - - best_resolution = select_best_resolution(image_size, grid_pinpoints) - height, width = best_resolution - num_patches = 0 - # consider change to ceil(height/patch_size)*ceil(width/patch_size) + 1 - for i in range(0, height, patch_size): - for j in range(0, width, patch_size): - num_patches += 1 - # add the base patch - num_patches += 1 - return num_patches - - -def unpad_image(tensor, original_size): - """ - Unpads a PyTorch tensor of a padded and resized image. - - Args: - tensor (`torch.Tensor`): - The image tensor, assumed to be of shape (num_channels, height, width). - original_size (`tuple`): - The original size of the image (height, width). - - Returns: - `torch.Tensor`: The unpadded image tensor. - """ - if not isinstance(original_size, (list, tuple)): - if not isinstance(original_size, (torch.Tensor, np.ndarray)): - raise TypeError( - f"image_size invalid type: {type(original_size)} not valid, should be either list, tuple, np.ndarray or tensor" - ) - original_size = original_size.tolist() - original_height, original_width = original_size - current_height, current_width = tensor.shape[1:] - - original_aspect_ratio = original_width / original_height - current_aspect_ratio = current_width / current_height - - if original_aspect_ratio > current_aspect_ratio: - scale_factor = current_width / original_width - new_height = int(round(original_height * scale_factor, 7)) - padding = (current_height - new_height) // 2 - unpadded_tensor = tensor[:, padding : current_height - padding, :] - else: - scale_factor = current_height / original_height - new_width = int(round(original_width * scale_factor, 7)) - padding = (current_width - new_width) // 2 - unpadded_tensor = tensor[:, :, padding : current_width - padding] - - return unpadded_tensor - - @dataclass class LlavaNextVideoCausalLMOutputWithPast(ModelOutput): """ @@ -304,6 +191,113 @@ def _init_weights(self, module): module.weight.data[module.padding_idx].zero_() +def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): + """ + Calculate the shape of the image patch grid after the preprocessing for images of any resolution. + + Args: + image_size (`tuple`): + The size of the input image in the format (width, height). + grid_pinpoints (`List`): + A list containing possible resolutions. Each item in the list should be a tuple or list + of the form `(height, width)`. + patch_size (`int`): + The size of each image patch. + + Returns: + tuple: The shape of the image patch grid in the format (width, height). + """ + if not isinstance(grid_pinpoints, list): + raise TypeError("grid_pinpoints should be a list of tuples or lists") + + # ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate + if not isinstance(image_size, (list, tuple)): + if not isinstance(image_size, (torch.Tensor, np.ndarray)): + raise TypeError( + f"image_size invalid type: {type(image_size)} not valid, should be either list, tuple, np.ndarray or tensor" + ) + image_size = image_size.tolist() + + height, width = select_best_resolution(image_size, grid_pinpoints) + return height // patch_size, width // patch_size + + +def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int): + """ + Calculate the number of patches after the preprocessing for images of any resolution. + + Args: + image_size (`torch.LongTensor` or `np.ndarray` or `Tuple[int, int]`): + The size of the input image in the format (height, width). ? + grid_pinpoints (`List`): + A list containing possible resolutions. Each item in the list should be a tuple or list + of the form `(height, width)`. + patch_size (`int`): + The size of each image patch. + + Returns: + int: the number of patches + """ + if not isinstance(grid_pinpoints, list): + raise TypeError("grid_pinpoints should be a list of tuples or lists") + + # ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate + if not isinstance(image_size, (list, tuple)): + if not isinstance(image_size, (torch.Tensor, np.ndarray)): + raise TypeError(f"image_size invalid type {type(image_size)} with value {image_size}") + image_size = image_size.tolist() + + best_resolution = select_best_resolution(image_size, grid_pinpoints) + height, width = best_resolution + num_patches = 0 + # consider change to ceil(height/patch_size)*ceil(width/patch_size) + 1 + for i in range(0, height, patch_size): + for j in range(0, width, patch_size): + num_patches += 1 + # add the base patch + num_patches += 1 + return num_patches + + +def unpad_image(tensor, original_size): + """ + Unpads a PyTorch tensor of a padded and resized image. + + Args: + tensor (`torch.Tensor`): + The image tensor, assumed to be of shape (num_channels, height, width). + original_size (`tuple`): + The original size of the image (height, width). + + Returns: + `torch.Tensor`: The unpadded image tensor. + """ + if not isinstance(original_size, (list, tuple)): + if not isinstance(original_size, (torch.Tensor, np.ndarray)): + raise TypeError( + f"image_size invalid type: {type(original_size)} not valid, should be either list, tuple, np.ndarray or tensor" + ) + original_size = original_size.tolist() + original_height, original_width = original_size + current_height, current_width = tensor.shape[1:] + + original_aspect_ratio = original_width / original_height + current_aspect_ratio = current_width / current_height + + if original_aspect_ratio > current_aspect_ratio: + scale_factor = current_width / original_width + new_height = int(original_height * scale_factor) + padding = (current_height - new_height) // 2 + unpadded_tensor = tensor[:, padding : current_height - padding, :] + else: + scale_factor = current_height / original_height + new_width = int(original_width * scale_factor) + padding = (current_width - new_width) // 2 + unpadded_tensor = tensor[:, :, padding : current_width - padding] + + return unpadded_tensor + + LLAVA_NEXT_VIDEO_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): diff --git a/src/transformers/models/llava_next_video/modular_llava_next_video.py b/src/transformers/models/llava_next_video/modular_llava_next_video.py index 450685e3606..8018afa7244 100644 --- a/src/transformers/models/llava_next_video/modular_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modular_llava_next_video.py @@ -30,7 +30,6 @@ from ...configuration_utils import PretrainedConfig from ...utils import ( logging, - replace_return_docstrings, ) from ..auto import CONFIG_MAPPING From 4545b63bd8aeee15e64479a675ba0ac4455da1e2 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 29 Oct 2024 11:46:59 +0100 Subject: [PATCH 13/40] style --- utils/modular_model_converter.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index e577313fb77..ef7db834c44 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -54,6 +54,7 @@ def get_module_source_from_name(module_name: str) -> str: source_code = file.read() return source_code + def preserve_case_replace(text, patterns: dict, default_name: str): # Create a regex pattern to match all variations regex_pattern = "|".join(re.escape(key) for key in patterns.keys()) @@ -66,13 +67,13 @@ def replace(match): return compiled_regex.sub(replace, text) + def convert_to_camelcase(text, old_name: str, default_old_name: str): # Regex pattern to match consecutive uppercase letters and lowercase the first set - result = re.sub( - rf"^({old_name})(?=[a-z]+)", lambda m: default_old_name, text, flags=re.IGNORECASE, count=1 - ) + result = re.sub(rf"^({old_name})(?=[a-z]+)", lambda m: default_old_name, text, flags=re.IGNORECASE, count=1) return result + class ReplaceNameTransformer(m.MatcherDecoratableTransformer): """A transformer that replaces `old_name` with `new_name` in comments, string and any references. It should take into account name like `MyNewModel`, or `my_new_model`. Without using the AUTO_MAPPING. @@ -118,7 +119,9 @@ def replace_name(self, original_node, updated_node): return updated_node.with_changes(value=update) def leave_ClassDef(self, original_node, updated_node): - return updated_node.with_changes(name=cst.Name(convert_to_camelcase(updated_node.name.value, self.old_name, self.default_old_name))) + return updated_node.with_changes( + name=cst.Name(convert_to_camelcase(updated_node.name.value, self.old_name, self.default_old_name)) + ) DOCSTRING_NODE = m.SimpleStatementLine( @@ -843,7 +846,9 @@ def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef, rename result_node = original_node.with_changes(body=cst.IndentedBlock(body=end_meth)) temp_module = cst.Module(body=[result_node]) new_module = MetadataWrapper(temp_module) - new_replacement_class = new_module.visit(SuperTransformer(temp_module, original_methods, updated_methods, all_bases)) + new_replacement_class = new_module.visit( + SuperTransformer(temp_module, original_methods, updated_methods, all_bases) + ) new_replacement_body = new_replacement_class.body[0].body # get the indented block # Use decorators redefined in `modular_xxx.py` if any @@ -862,6 +867,7 @@ def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef, rename "FeatureExtractor": "feature_extractor", } + def find_file_type(class_name: str) -> str: """Based on a class name, find the file type corresponding to the class.""" match_pattern = "|".join(TYPE_TO_FILE_TYPE.keys()) From 5958f6461bdb00771b0f9a0cadaf9870c57c3436 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 29 Oct 2024 17:35:35 +0100 Subject: [PATCH 14/40] address most comments --- utils/modular_model_converter.py | 235 ++++++++++++++++++------------- 1 file changed, 140 insertions(+), 95 deletions(-) diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index ef7db834c44..cdbc382cd2a 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -119,9 +119,8 @@ def replace_name(self, original_node, updated_node): return updated_node.with_changes(value=update) def leave_ClassDef(self, original_node, updated_node): - return updated_node.with_changes( - name=cst.Name(convert_to_camelcase(updated_node.name.value, self.old_name, self.default_old_name)) - ) + new_name = convert_to_camelcase(updated_node.name.value, self.old_name, self.default_old_name) + return updated_node.with_changes(name=cst.Name(new_name)) DOCSTRING_NODE = m.SimpleStatementLine( @@ -288,21 +287,27 @@ def update_body(self, existing_body, new_statements): deduplicated_new_body.append(node) existing_nodes.add(comment_less_code) + deduplicated_new_body = self._fix_post_init_location(deduplicated_new_body) + + return deduplicated_new_body + + def _fix_post_init_location(self, new_body: list[cst.CSTNode]): + """Fix the location of the `post_init()` in the new body, if we added statements after the call to + `super()` (it needs to be the very last statement called)""" # Fix the post_init() that has to be last - for i, node in enumerate(deduplicated_new_body): + for i, node in enumerate(new_body): code = self.python_module.code_for_node(node) comment_less_code = re.sub(r"#.*", "", code).strip() comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip() - if "self.post_init(" in comment_less_code and i < len(deduplicated_new_body) - 1: + if "self.post_init(" in comment_less_code and i < len(new_body) - 1: # Remove it and add it again at the end - deduplicated_new_body.pop(i) - deduplicated_new_body.append(node) + new_body.pop(i) + new_body.append(node) break - - return deduplicated_new_body + return new_body def _fix_init_location(self, new_body): - """Fix the location of the super()__init__ in the new body, if we had new statements before it.""" + """Fix the location of the `super().__init__()` in the new body, if we had new statements before it.""" start_index = 0 for i, node in enumerate(new_body): if m.matches(node, DOCSTRING_NODE) and i == start_index: @@ -394,15 +399,15 @@ def find_all_dependencies( start_entity (str | None, *optional*): A key of `dependency_mapping`, indicating from which entity to start the search. initial_dependencies (set | None, *optional*): - If `start_entity` is not provided, this can be used as an alternative. In this case, `initial_dependencies` - the search will continue from all the entities in `initial_dependencies`, if they are in `dependency_mapping`. + If `start_entity` is not provided, this can be used as an alternative. In this case, the search will continue + from all the entities in `initial_dependencies`, if they are in `dependency_mapping`. initial_checked_dependencies (set | None, *optional*): If provided, entities already present in `initial_checked_dependencies` will not be part of the returned dependencies. return_parent (bool, *optional*): If `True`, will return a list consisting of tuples (dependency, parent) instead of a simple set of dependencies. Note that the order of the items in the list reflects the traversal order. Thus, no parent can ever appear before childs. Returns: - A set of all the dependencies, or a list containing parents as well if `return_parent=True`. + A set of all the dependencies, or a list of tuples `(dependency, parent)` if `return_parent=True`. Example: Given the following structure in the `modular_xxx.py` file: @@ -470,7 +475,8 @@ def forward(...): class ClassDependencyMapper(CSTVisitor): - """A visitor which is designed to analyze a single class node to get all its dependencies that are mutual with `global_names`. + """A visitor which is designed to analyze a single class node to get all its dependencies that are shared with the set of + `global_names`. This class is used through the 2 convenient class methods. """ @@ -512,7 +518,10 @@ def dependencies_for_new_node(cls, updated_node: cst.ClassDef, mapper: "ModuleMa class ModuleMapper(CSTVisitor, ABC): """An abstract visitor class which analyses a module, creating a mapping of dependencies for classes and functions. - It defines common visiting patterns between the modular file and the model-specific module files that will be visited. + Class dependencies are computed with `compute_class_dependencies()`, while function dependencies are stored in + `self.function_recursive_dependency_mapping` (can be computed by `_compute_recursive_function_dependencies()`). + It defines common visiting patterns (i.e. common visit_xxx/leave_xxx functions) between the modular file and the + modeling files that will be visited. """ METADATA_DEPENDENCIES = (ParentNodeProvider, PositionProvider) @@ -523,14 +532,14 @@ def __init__(self, python_module: cst.Module): self.classes: Dict[str, cst.ClassDef] = {} # mapping from class names to Nodes self.imports = [] # stores all import statements self.functions: Dict[str, cst.FunctionDef] = {} # mapping of global scope function names to Nodes - self.function_call_dependency_mapping = defaultdict(set) # 1st-level function dependency mapping + self.function_dependency_mapping = defaultdict(set) # immediate function dependency mapping (i.e. dependencies immediately in the function definition) self.assignments: Dict[str, cst.SimpleStatementLine] = {} # mapping of global assignments names to Nodes - self.current_function = None + self.current_function = None # this keeps track of the current module-scope function # fmt: on def visit_SimpleStatementLine(self, node): """ - Global Assigns like `GEMMA_INPUT_DOCSTRING = 'THIS IS THE INPUT' and all import statements + Global Assigns like `GEMMA_INPUT_DOCSTRING = 'THIS IS THE INPUT'` and all import statements are extracted and saved in their corresponding dict. They are then used when updating dependency mappings. """ parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) @@ -565,21 +574,16 @@ def visit_ClassDef(self, node: ClassDef) -> None: """Record class nodes to create their dependencies at the end.""" self.classes[node.name.value] = node - def visit_Call(self, node: cst.Call): - """This is used to create a mapping from top-level functions to functions called inside them. - Important note: we only rely on direct Call to the functions here, not indirect mentions (such as assigning a variable with the function, - add calling the variable later). This should be enough as the `modular_xxx` and `modeling_xxx` structures should be as simple as possible. - """ + def visit_Name(self, node: cst.Call): + """This is used to create a mapping from module-scope functions to objects used inside them.""" if self.current_function is not None: - # Simple function calls such as foo() - if m.matches(node.func, m.Name()): - self.function_call_dependency_mapping[self.current_function].add(node.func.value) + self.function_dependency_mapping[self.current_function].add(node.value) def leave_Module(self, node): - """When leaving the module, we finally create the `function_call_recursive_dependency_mapping`, then we - compute the dependencies for all recorded classes based on all the nodes we visited. - We also store the position of each global scoped node to allow sorting the dependencies based on their - position in the code later. We use the PositionProvider metadata wrapper for this. + """When leaving the module, we store the position of each global scoped node to allow sorting the dependencies + based on their position in the code later. We use the PositionProvider metadata wrapper for this. + We also make sure to update `self.function_dependency_mapping` so that it contains only names recorded in + `self.global_nodes`. """ # assign all nodes self.global_nodes = {**self.assignments, **self.classes, **self.functions} @@ -588,39 +592,55 @@ def leave_Module(self, node): for id, node in self.global_nodes.items(): self.start_lines[id] = self.get_metadata(cst.metadata.PositionProvider, node).start.line + # Since we added every Name as part of `self.function_dependency_mapping`, we now remove those that + # are not part of the recorded objects (i.e. built-in variables, imports, etc) + global_objects = set(self.global_nodes.keys()) + for function_name, dependencies in self.function_dependency_mapping.items(): + self.function_dependency_mapping[function_name] = {dep for dep in dependencies if dep in global_objects} + def _compute_recursive_function_dependencies(self) -> dict[str, set]: - """Based on the 1st level function dependency mapping, create the recursive dependency mapping.""" + """Based on immediate dependency mapping, create the recursive dependency mapping. For example, given the + following file: + ``` + def foo(): + pass + + def bar(): + foo() + + def test(): + bar() + ``` + this visitor can only record immediate dependencies, i.e. it will record the following + `self.function_dependency_mapping = {"test": {"bar"}, "bar": {"foo}}`. This function is used to create + the recursive mapping, i.e. `recursive_dependencies = {"test": {"bar", "foo"}, "bar": {"foo}}`. + """ recursive_dependencies = {} - for function_name in self.function_call_dependency_mapping.keys(): - # We need to check if they are present in self.functions to avoid built-in functions - all_dependencies = { - dep - for dep in find_all_dependencies(self.function_call_dependency_mapping, start_entity=function_name) - if dep in self.functions.keys() - } + for function_name in self.function_dependency_mapping.keys(): + all_dependencies = find_all_dependencies(self.function_dependency_mapping, start_entity=function_name) recursive_dependencies[function_name] = all_dependencies return recursive_dependencies - def augment_dependencies_with_functions(self, dependencies: set) -> set: - """For a set of `dependencies`, augment them by adding all potential functions which are dependencies of - the functions present in the `dependencies`. + def augment_dependencies_with_functions(self, dependencies: set[str]) -> set[str]: + """For a set of `dependencies`, augment them by adding all potential dependencies of the **functions** present + in the `dependencies`. """ new_dependencies = dependencies.copy() # Go through the set of dependencies for dep in tuple(dependencies): - if dep in self.function_call_recursive_dependency_mapping.keys(): - new_dependencies.update(self.function_call_recursive_dependency_mapping[dep]) + if dep in self.function_recursive_dependency_mapping.keys(): + new_dependencies.update(self.function_recursive_dependency_mapping[dep]) return new_dependencies def compute_class_dependencies(self): - """For each visited class, find its dependencies based on visited the current file + potential merged dependencies. + """For each visited class, find its dependencies based on visiting the current file + potential merged dependencies. Note: This function takes care of updating `global_nodes` and `function_call_recursive_dependency_mapping` as well after the merge with other files dependencies. """ # Correctly re-set the global nodes at this point self.global_nodes = {**self.assignments, **self.classes, **self.functions} # Create the global mapping of recursive dependencies for functions - self.function_call_recursive_dependency_mapping = self._compute_recursive_function_dependencies() + self.function_recursive_dependency_mapping = self._compute_recursive_function_dependencies() self.class_dependency_mapping = {} for class_name, class_node in self.classes.items(): @@ -634,16 +654,22 @@ def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]: class ModelFileMapper(ModuleMapper): - """A mapper designed for model-specific files (i.e. a `transformers.models.xxx` file). When encountering such a file + """A mapper designed to parse modeling files (like `modeling_llama.py`). When encountering such a file in the `modular_xxx.py` file, we need to correctly visit it and merge the dependencies of the modular and current file. For this reason, this class should only be instantiated from the class method `visit_and_merge_dependencies`, which takes - care of correctly merging dependencies, then finalizes all dependency graph computations.""" + care of correctly merging dependencies, then finalizes all dependency graph computations. + Note that we only merge functions and assignments here, as classes will be treated later on as they may be modified. + For example, if you redefine `apply_rotary_pos_emb()` in the modular, the new node should be used in the dependencies + of the modeling files as well. + """ def __init__(self, python_module: cst.Module): super().__init__(python_module) - def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]: - """Compute the relative order that the `missing_dependencies` should have between themselves in the output file.""" + def compute_relative_order(self, missing_dependencies: set[str]) -> dict[str, int]: + """Compute in which relative order the `missing_dependencies` should appear when the nodes are added to the final file that + will be created based on the modular. + """ relative_order = {} idx = 0 classes = sorted( @@ -710,7 +736,7 @@ def _merge_functions(self, functions: dict[str, cst.CSTNode], function_call_mapp """ # Add/overwrite all needed function nodes and dependencies self.functions.update(functions) - self.function_call_dependency_mapping.update(function_call_mapping) + self.function_dependency_mapping.update(function_call_mapping) def _merge_assignments(self, assignments: dict[str, cst.CSTNode]): """Update the global nodes with the assignment from the modular file. @@ -746,10 +772,10 @@ def visit_and_merge_dependencies( def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef, renamed_super_class: str): """ - Replace a class node which inherits from an imported model-class. This function works in the following way: - - start from the class node of the inherited class - - replace all methods with the same name with the ones defined in the modular - - append all new methods defined in the modular + Replace a class node which inherits from another modeling class. This function works in the following way: + - start from the base class node of the inherited class (a cst.Node) + - replace all methods of the base node with the methods defined in the child class + - append all new methods defined in the child class - replace all calls to super() with the unravelled code | ```python | | ```python @@ -869,7 +895,10 @@ def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef, rename def find_file_type(class_name: str) -> str: - """Based on a class name, find the file type corresponding to the class.""" + """Based on a class name, find the file type corresponding to the class. + If the class name is `LlamaConfig` it will return `configuration`. + The list of suffixes is in `TYPE_TO_FILE_TYPE`. If there are no match, we match by default to `modeling` + """ match_pattern = "|".join(TYPE_TO_FILE_TYPE.keys()) match = re.search(rf"({match_pattern})$", class_name) if match: @@ -879,14 +908,18 @@ def find_file_type(class_name: str) -> str: return file_type -# These top-level variables will always appear the very beginning of the file, in the order they are defined in +# These top-level variables will always appear at the very beginning of the file, in the order they are defined in # this list (this is to avoid having variables at weird places, even if they are not used before) -VARIABLES_AT_THE_BEGINNING = [ +VARIABLES_AT_THE_BEGINNING = ( "logger", "_CHECKPOINT_FOR_DOC", "_CONFIG_FOR_DOC", -] +) +# These specific modeling imports should not be visited as other modeling files +IMPORTS_TO_SKIP_IN_MODULAR = ( + "auto.modeling_auto", +) def get_module_name(node: cst.ImportFrom) -> str: """Recursively get the fully dotted name of a module in a cst.ImportFrom.""" @@ -925,7 +958,8 @@ def append_new_import_node( def get_needed_imports(body: dict[str, dict], all_imports: list[cst.CSTNode]) -> list[cst.CSTNode]: """Get all the imports needed in the `body`, from the list of `all_imports`. - Note: we need to use `isinstance` on assignements, m.matches apparently does not work here yet! + `body` is a dict with the following structure `{str: {"insert_idx": int, "node": cst.CSTNode}}`. + Note: we need to use `isinstance` on scope assignements, m.matches apparently does not work here yet! """ new_body = [k[1]["node"] for k in sorted(body.items(), key=lambda x: x[1]["insert_idx"])] wrapper = MetadataWrapper(cst.Module(body=all_imports + new_body)) @@ -947,7 +981,7 @@ def get_needed_imports(body: dict[str, dict], all_imports: list[cst.CSTNode]) -> # Note that dicts implicitly keep the order of insertion imports_to_keep = {} for idx, node in enumerate(all_imports): - if m.matches(node, m.If()): + if m.matches(node, m.If()): # handle safe imports new_statements = {} for second_idx, stmt_node in enumerate(node.body.body): append_new_import_node(stmt_node, unused_imports, new_statements, second_idx) @@ -969,9 +1003,32 @@ def get_needed_imports(body: dict[str, dict], all_imports: list[cst.CSTNode]) -> return usual_import_nodes + protected_import_nodes +def split_all_assignment(node: cst.CSTNode) -> dict[str, cst.CSTNode]: + """Split the `__all__` assignment found in the modular between each corresponding files.""" + all_all_per_file = {} + assign_node = node.body[0] + if isinstance(assign_node.value, cst.List): + # Extract the elements from the list + all_all_to_add = defaultdict(list) + for element in assign_node.value.elements: + if isinstance(element.value, cst.SimpleString): + # Remove quotes and add the string to the elements list + class_name = element.value.value + file = find_file_type(element.value.evaluated_value) + all_all_to_add[file] += [class_name] + for file, new_alls in all_all_to_add.items(): + new_node = assign_node.with_changes( + value=cst.List( + elements=[cst.Element(value=cst.SimpleString(value=k)) for k in new_alls] + ) + ) + all_all_per_file[file] = node.with_changes(body=[new_node]) + return all_all_per_file + + class ModularFileMapper(ModuleMapper): - """This is a Mapper for a modular file. It visits the whole file, recording dependency, then visits all model-specific - files that should be visited, and manages their mutual dependencies. + """This is a Mapper to visit a modular file (like `modular_llama.py`). It visits the whole file, recording dependency, + then visits all imported modeling files (like `modeling_llama.py`), and manages their mutual dependencies. Calling the method `create_modules()` after visit will create all modules based on this modular file. """ @@ -991,10 +1048,10 @@ def __init__(self, python_module, new_name, given_old_name=None, given_new_name= def visit_ImportFrom(self, node: cst.ImportFrom) -> None: """When visiting imports from model-specific files (i.e. `transformers.models.xxx`) we get the code, parse it, - and record it in `self.model_specific_modules`. The imported objects are recorded in `self.model_specific_imported_objects`. + and save it in `self.model_specific_modules` to later visit. The imported objects are saved in `self.model_specific_imported_objects`. """ import_statement = self.python_module.code_for_node(node.module) - if "auto.modeling_auto" in import_statement: + if any(import_to_skip in import_statement for import_to_skip in IMPORTS_TO_SKIP_IN_MODULAR): return if m.matches(node.module, m.Attribute()): for imported_ in node.names: @@ -1022,7 +1079,7 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None: ) def visit_SimpleStatementLine(self, node): - """If we visit an import statement not previously visited, record it. If we visit a top-level assignment, + """If we visit an import statement not previously visited, record it. If we visit a module-scope assignment, simply record it or, if it is `__all__`, split it between files where we should dispatch it. """ parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) @@ -1046,36 +1103,22 @@ def visit_SimpleStatementLine(self, node): if assigned_variable != "__all__": self.assignments[assigned_variable] = node else: - assign_node = node.body[0] - if isinstance(assign_node.value, cst.List): - # Extract the elements from the list - all_all_to_add = defaultdict(list) - for element in assign_node.value.elements: - if isinstance(element.value, cst.SimpleString): - # Remove quotes and add the string to the elements list - class_name = element.value.value - file = find_file_type(element.value.evaluated_value) - all_all_to_add[file] += [class_name] - for file, new_alls in all_all_to_add.items(): - new_node = assign_node.with_changes( - value=cst.List( - elements=[cst.Element(value=cst.SimpleString(value=k)) for k in new_alls] - ) - ) - self.all_all_to_add[file] = node.with_changes(body=[new_node]) + self.all_all_to_add = split_all_assignment(node) def leave_Module(self, node): """When we leave the modular file, we do the following in order: - - compute recursive function dependencies - - for each model-specific file found in the imports, rename it with the new model name, visit it, and update + 1. compute the nested (recursive) function dependencies + 2. for each model-specific file found in the imports, rename it with the new model name, visit it, and update its dependency graph with the new function and assignment definitions found in the modular - - update the modular dependency graph with the imported functions and assignments (found when visiting the matching files) + 3. update the modular dependency graph with the imported functions and assignments (found when visiting the matching files) """ + # Takes care of finalizing our visit super().leave_Module(node) - self.function_call_recursive_dependency_mapping = self._compute_recursive_function_dependencies() + # 1. compute the nested (recursive) function dependencies + self.function_recursive_dependency_mapping = self._compute_recursive_function_dependencies() - # Now, visit every model-specific files found in the imports, and merge their dependencies + # 2. for each model-specific file found in the imports, rename it with the new model name, visit it, and update dependencies self.visited_modules = {} self.renamers = {} for file, module in self.model_specific_modules.items(): @@ -1087,20 +1130,17 @@ def leave_Module(self, node): self.visited_modules[file] = ModelFileMapper.visit_and_merge_dependencies( renamed_module, self.functions, - self.function_call_dependency_mapping, + self.function_dependency_mapping, self.assignments, self.start_lines, ) # We record it so that we can rename classes later the exact same way self.renamers[file] = renamer - # In turn, we need to add the imported functions/assignments to the dependencies of the modular mapper, using the + # 3. in turn, we need to add the imported functions/assignments to the dependencies of the modular mapper, using the # definitions found in the visited files self.merge_model_specific_imports(self.visited_modules) - # Re-assign all nodes - self.global_nodes = {**self.assignments, **self.classes, **self.functions} - def merge_model_specific_imports(self, visited_modules): """Merge the model-specific imported functions and assignments to the modular nodes and dependency graph, based on the visited files.""" @@ -1113,9 +1153,9 @@ def merge_model_specific_imports(self, visited_modules): if object_name in visited_module.functions and object_name not in self.functions: self.functions[object_name] = visited_module.functions[object_name] self.added_objects_file_mapping[object_name] = file - dependencies = visited_module.function_call_recursive_dependency_mapping.get(object_name, None) + dependencies = visited_module.function_recursive_dependency_mapping.get(object_name, None) if dependencies is not None: - self.function_call_recursive_dependency_mapping[object_name] = dependencies + self.function_recursive_dependency_mapping[object_name] = dependencies for dep in dependencies: self.added_objects_file_mapping[dep] = file self.functions[dep] = visited_module.global_nodes[dep] @@ -1125,8 +1165,13 @@ def merge_model_specific_imports(self, visited_modules): self.added_objects_file_mapping[object_name] = file self.assignments[object_name] = visited_module.assignments[object_name] + # Do not forget to re-assign all nodes after the merge + self.global_nodes = {**self.assignments, **self.classes, **self.functions} + def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]: - """Compute the relative order that the `missing_dependencies` should have between themselves in the output file.""" + """Compute in which relative order the `missing_dependencies` should appear when the nodes are added to the final file that + will be created based on the modular. + """ relative_order = {} idx = 0 From cfdafe3391beb8ecaa9a07a5f8bdc11bc6699265 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 29 Oct 2024 17:36:32 +0100 Subject: [PATCH 15/40] style --- utils/modular_model_converter.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index cdbc382cd2a..6c90ff20fcb 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -290,7 +290,7 @@ def update_body(self, existing_body, new_statements): deduplicated_new_body = self._fix_post_init_location(deduplicated_new_body) return deduplicated_new_body - + def _fix_post_init_location(self, new_body: list[cst.CSTNode]): """Fix the location of the `post_init()` in the new body, if we added statements after the call to `super()` (it needs to be the very last statement called)""" @@ -604,10 +604,10 @@ def _compute_recursive_function_dependencies(self) -> dict[str, set]: ``` def foo(): pass - + def bar(): foo() - + def test(): bar() ``` @@ -896,7 +896,7 @@ def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef, rename def find_file_type(class_name: str) -> str: """Based on a class name, find the file type corresponding to the class. - If the class name is `LlamaConfig` it will return `configuration`. + If the class name is `LlamaConfig` it will return `configuration`. The list of suffixes is in `TYPE_TO_FILE_TYPE`. If there are no match, we match by default to `modeling` """ match_pattern = "|".join(TYPE_TO_FILE_TYPE.keys()) @@ -917,9 +917,8 @@ def find_file_type(class_name: str) -> str: ) # These specific modeling imports should not be visited as other modeling files -IMPORTS_TO_SKIP_IN_MODULAR = ( - "auto.modeling_auto", -) +IMPORTS_TO_SKIP_IN_MODULAR = ("auto.modeling_auto",) + def get_module_name(node: cst.ImportFrom) -> str: """Recursively get the fully dotted name of a module in a cst.ImportFrom.""" @@ -981,7 +980,7 @@ def get_needed_imports(body: dict[str, dict], all_imports: list[cst.CSTNode]) -> # Note that dicts implicitly keep the order of insertion imports_to_keep = {} for idx, node in enumerate(all_imports): - if m.matches(node, m.If()): # handle safe imports + if m.matches(node, m.If()): # handle safe imports new_statements = {} for second_idx, stmt_node in enumerate(node.body.body): append_new_import_node(stmt_node, unused_imports, new_statements, second_idx) @@ -1018,9 +1017,7 @@ def split_all_assignment(node: cst.CSTNode) -> dict[str, cst.CSTNode]: all_all_to_add[file] += [class_name] for file, new_alls in all_all_to_add.items(): new_node = assign_node.with_changes( - value=cst.List( - elements=[cst.Element(value=cst.SimpleString(value=k)) for k in new_alls] - ) + value=cst.List(elements=[cst.Element(value=cst.SimpleString(value=k)) for k in new_alls]) ) all_all_per_file[file] = node.with_changes(body=[new_node]) return all_all_per_file From bc7e20b6643fb3c7cad2fdc5a3fc7bad01d4c8e9 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 29 Oct 2024 18:04:22 +0100 Subject: [PATCH 16/40] remove unused stuff in get_needed_imports --- utils/modular_model_converter.py | 53 +++++++++----------------------- 1 file changed, 15 insertions(+), 38 deletions(-) diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index 6c90ff20fcb..6b80b53583b 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -920,39 +920,18 @@ def find_file_type(class_name: str) -> str: IMPORTS_TO_SKIP_IN_MODULAR = ("auto.modeling_auto",) -def get_module_name(node: cst.ImportFrom) -> str: - """Recursively get the fully dotted name of a module in a cst.ImportFrom.""" - if m.matches(node, m.Name()): - return node.value - elif m.matches(node, m.Attribute()): - # Recursively get the full name for attributes - return f"{get_module_name(node.value)}.{node.attr.value}" - return "" - - -def append_new_import_node( - node: cst.CSTNode, unused_imports: set[str], imports_to_keep: dict[str, cst.CSTNode], current_idx: int -): - """Insert the new `node` to the dict of `imports_to_keep` in-place, if it is not part of the `unused_imports`. - This function takes cares of aggregating similar ImportFrom, i.e. if we ever saw a statement such as - `from typing import Any`, and later another one `from typing import List`, we will aggregate as - `from typing import Any, List` in a single statement. +def append_new_import_node(node: cst.CSTNode, unused_imports: set[str], imports_to_keep: list[cst.CSTNode]): + """Insert the new `node` to the list of `imports_to_keep` in-place, if it is not part of the `unused_imports`. """ import_node = node.body[0] - if m.matches(import_node, m.ImportFrom()): - module_name = get_module_name(import_node.module) - else: - module_name = current_idx - - # If we have a new import from with the same module name, write new names to the same import statement - names_to_keep = list(imports_to_keep[module_name].body[0].names) if module_name in imports_to_keep else [] - + names_to_keep = [] for name in import_node.names: name_value = name.evaluated_name if name_value not in unused_imports: names_to_keep.append(name.with_changes(comma=cst.MaybeSentinel.DEFAULT)) if len(names_to_keep) > 0: - imports_to_keep[module_name] = node.with_changes(body=[import_node.with_changes(names=names_to_keep)]) + new_node = node.with_changes(body=[import_node.with_changes(names=names_to_keep)]) + imports_to_keep.append(new_node) def get_needed_imports(body: dict[str, dict], all_imports: list[cst.CSTNode]) -> list[cst.CSTNode]: @@ -977,22 +956,20 @@ def get_needed_imports(body: dict[str, dict], all_imports: list[cst.CSTNode]) -> unused_imports.add(name) import_ref_count[name] = ref_count - # Note that dicts implicitly keep the order of insertion - imports_to_keep = {} - for idx, node in enumerate(all_imports): + imports_to_keep = [] + for node in all_imports: if m.matches(node, m.If()): # handle safe imports - new_statements = {} - for second_idx, stmt_node in enumerate(node.body.body): - append_new_import_node(stmt_node, unused_imports, new_statements, second_idx) + new_statements = [] + for stmt_node in node.body.body: + append_new_import_node(stmt_node, unused_imports, new_statements) if len(new_statements) > 0: - imports_to_keep[idx] = node.with_changes( - body=node.body.with_changes(body=list(new_statements.values())) - ) + new_node = node.with_changes(body=node.body.with_changes(body=new_statements)) + imports_to_keep.append(new_node) else: - append_new_import_node(node, unused_imports, imports_to_keep, idx) + append_new_import_node(node, unused_imports, imports_to_keep) - protected_import_nodes = [node for node in imports_to_keep.values() if m.matches(node, m.If())] - usual_import_nodes = [node for node in imports_to_keep.values() if not m.matches(node, m.If())] + protected_import_nodes = [node for node in imports_to_keep if m.matches(node, m.If())] + usual_import_nodes = [node for node in imports_to_keep if not m.matches(node, m.If())] # If the same import is both protected and unprotected, only keep the protected one for protected_node in protected_import_nodes: for stmt_node in protected_node.body.body: From 2ab7f56c9a908e0aa01a5f621461c11a0f9ce45b Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 29 Oct 2024 18:04:43 +0100 Subject: [PATCH 17/40] style --- utils/modular_model_converter.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index 6b80b53583b..c67a8ba5817 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -921,8 +921,7 @@ def find_file_type(class_name: str) -> str: def append_new_import_node(node: cst.CSTNode, unused_imports: set[str], imports_to_keep: list[cst.CSTNode]): - """Insert the new `node` to the list of `imports_to_keep` in-place, if it is not part of the `unused_imports`. - """ + """Insert the new `node` to the list of `imports_to_keep` in-place, if it is not part of the `unused_imports`.""" import_node = node.body[0] names_to_keep = [] for name in import_node.names: From 197d93707dcf05a402441ed8bb12e56dc57e1eb0 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 29 Oct 2024 18:58:38 +0100 Subject: [PATCH 18/40] move class dependency functions outside class --- utils/modular_model_converter.py | 43 ++++++++++++++++---------------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index c67a8ba5817..a1847c0fec8 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -477,7 +477,6 @@ def forward(...): class ClassDependencyMapper(CSTVisitor): """A visitor which is designed to analyze a single class node to get all its dependencies that are shared with the set of `global_names`. - This class is used through the 2 convenient class methods. """ METADATA_DEPENDENCIES = (ParentNodeProvider,) @@ -495,25 +494,25 @@ def visit_Name(self, node): if not m.matches(parent_node, m.Annotation()): self.dependencies.add(node.value) - @classmethod - def dependencies_for_node(cls, node: cst.ClassDef, global_names: set) -> set: - """Create dependencies for a node in the `ModuleMapper`.""" - temp_module = cst.Module(body=[node]) - wrapper = MetadataWrapper(temp_module) - visitor = cls(node.name.value, global_names) - wrapper.visit(visitor) - return visitor.dependencies - @classmethod - def dependencies_for_new_node(cls, updated_node: cst.ClassDef, mapper: "ModuleMapper") -> set: - """Create dependencies for a node in the `ModularFileMapper` (which may have been changed by - `replace_call_to_super`). - """ - temp_module = cst.Module(body=[updated_node]) - wrapper = MetadataWrapper(temp_module) - visitor = cls(updated_node.name.value, set(mapper.global_nodes.keys())) - wrapper.visit(visitor) - return mapper.augment_dependencies_with_functions(visitor.dependencies) +def dependencies_for_class_node(node: cst.ClassDef, global_names: set) -> set: + """Create immediate dependencies for a class node based on the `global_names`.""" + temp_module = cst.Module(body=[node]) + wrapper = MetadataWrapper(temp_module) + visitor = ClassDependencyMapper(node.name.value, global_names) + wrapper.visit(visitor) + return visitor.dependencies + + +def augmented_dependencies_for_class_node(node: cst.ClassDef, mapper: "ModuleMapper") -> set: + """Create augmented dependencies for a class node based on a `mapper`. + Augmented dependencies means immediate dependencies + recursive function dependencies. + """ + temp_module = cst.Module(body=[node]) + wrapper = MetadataWrapper(temp_module) + visitor = ClassDependencyMapper(node.name.value, set(mapper.global_nodes.keys())) + wrapper.visit(visitor) + return mapper.augment_dependencies_with_functions(visitor.dependencies) class ModuleMapper(CSTVisitor, ABC): @@ -644,7 +643,7 @@ def compute_class_dependencies(self): self.class_dependency_mapping = {} for class_name, class_node in self.classes.items(): - dependencies = ClassDependencyMapper.dependencies_for_node(class_node, set(self.global_nodes.keys())) + dependencies = dependencies_for_class_node(class_node, set(self.global_nodes.keys())) # Corretcly augment class dependencies with all needed functions self.class_dependency_mapping[class_name] = self.augment_dependencies_with_functions(dependencies) @@ -1198,7 +1197,7 @@ class node based on the inherited classes if needed. updated_node = replace_class_node(mapper, node, renamed_super_class) # The node was modified -> look for all dependencies (recursively) of the new node - new_node_dependencies = ClassDependencyMapper.dependencies_for_new_node(updated_node, mapper) + new_node_dependencies = augmented_dependencies_for_class_node(updated_node, mapper) all_dependencies_to_add = find_all_dependencies( dependency_mapping=mapper.class_dependency_mapping, initial_dependencies=new_node_dependencies, @@ -1214,7 +1213,7 @@ class node based on the inherited classes if needed. else: updated_node = node # The node was NOT modified -> no need to look for dependencies recursively - all_dependencies_to_add = ClassDependencyMapper.dependencies_for_node(updated_node, self.global_nodes) + all_dependencies_to_add = dependencies_for_class_node(updated_node, self.global_nodes) relative_dependency_order = self.compute_relative_order(all_dependencies_to_add) nodes_to_add = { From 459be8f2b0cc9060b7569afd760f07556bb1f8e9 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 29 Oct 2024 19:03:41 +0100 Subject: [PATCH 19/40] Move main functions outside class --- utils/modular_model_converter.py | 212 ++++++++++++++++--------------- 1 file changed, 107 insertions(+), 105 deletions(-) diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index a1847c0fec8..b02017f8965 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -1168,115 +1168,117 @@ def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]: return relative_order - def add_class_node(self, class_name: str, node: cst.CSTNode, files: dict[str, dict]) -> tuple[dict, str]: - """Return a single class node (and all its dependency nodes), to be added to the `files`. It creates the new - class node based on the inherited classes if needed. - """ - bases = [k.value.value for k in node.bases if k.value.value in self.model_specific_imported_objects] - if len(bases) > 1: - raise ValueError( - f"{class_name} was defined with more than 1 model-specific super class. This is unsupported. We found {*bases,}." - ) - file_type = find_file_type(class_name) - file_to_update = files[file_type] - - # We need to replace the class node with the super class node - if len(bases) == 1: - super_class = bases[0] - super_file_name = self.model_specific_imported_objects[super_class] - - # Get the mapper corresponding to the inherited class - mapper = self.visited_modules[super_file_name] - # Rename the super class according to the exact same rule we used when renaming the whole module - renamer = self.renamers[super_file_name] - renamed_super_class = preserve_case_replace(super_class, renamer.patterns, renamer.default_name) - renamed_super_class = convert_to_camelcase(renamed_super_class, renamer.old_name, renamer.default_old_name) - - # Create the new class node - updated_node = replace_class_node(mapper, node, renamed_super_class) - - # The node was modified -> look for all dependencies (recursively) of the new node - new_node_dependencies = augmented_dependencies_for_class_node(updated_node, mapper) - all_dependencies_to_add = find_all_dependencies( - dependency_mapping=mapper.class_dependency_mapping, - initial_dependencies=new_node_dependencies, - initial_checked_dependencies=set(file_to_update.keys()), - ) +def add_class_node(modular_mapper: ModularFileMapper, class_name: str, node: cst.CSTNode, files: dict[str, dict]) -> tuple[dict, str]: + """Return a single class node (and all its dependency nodes), to be added to the `files`. It creates the new + class node based on the inherited classes if needed. + """ + bases = [k.value.value for k in node.bases if k.value.value in modular_mapper.model_specific_imported_objects] + if len(bases) > 1: + raise ValueError( + f"{class_name} was defined with more than 1 model-specific super class. This is unsupported. We found {*bases,}." + ) - relative_dependency_order = mapper.compute_relative_order(all_dependencies_to_add) - nodes_to_add = { - dep: (relative_dependency_order[dep], mapper.global_nodes[dep]) for dep in all_dependencies_to_add - } + file_type = find_file_type(class_name) + file_to_update = files[file_type] + + # We need to replace the class node with the super class node + if len(bases) == 1: + super_class = bases[0] + super_file_name = modular_mapper.model_specific_imported_objects[super_class] + + # Get the mapper corresponding to the inherited class + mapper = modular_mapper.visited_modules[super_file_name] + # Rename the super class according to the exact same rule we used when renaming the whole module + renamer = modular_mapper.renamers[super_file_name] + renamed_super_class = preserve_case_replace(super_class, renamer.patterns, renamer.default_name) + renamed_super_class = convert_to_camelcase(renamed_super_class, renamer.old_name, renamer.default_old_name) + + # Create the new class node + updated_node = replace_class_node(mapper, node, renamed_super_class) + + # The node was modified -> look for all dependencies (recursively) of the new node + new_node_dependencies = augmented_dependencies_for_class_node(updated_node, mapper) + all_dependencies_to_add = find_all_dependencies( + dependency_mapping=mapper.class_dependency_mapping, + initial_dependencies=new_node_dependencies, + initial_checked_dependencies=set(file_to_update.keys()), + ) - # No super class, just check functions and assignments dependency in the imports from other model-specific files - else: - updated_node = node - # The node was NOT modified -> no need to look for dependencies recursively - all_dependencies_to_add = dependencies_for_class_node(updated_node, self.global_nodes) - - relative_dependency_order = self.compute_relative_order(all_dependencies_to_add) - nodes_to_add = { - dep: (relative_dependency_order[dep], self.global_nodes[dep]) - for dep in all_dependencies_to_add - if dep not in file_to_update.keys() - } - - # Add the class node itself to the nodes to add - class_idx = max(relative_dependency_order.values()) + 1 if len(relative_dependency_order) > 0 else 0 - nodes_to_add[class_name] = (class_idx, updated_node) - - return nodes_to_add, file_type - - def create_modules(self) -> dict[str, cst.Module]: - """Create all the new modules based on visiting the modular file. It replaces all classes as necesary.""" - files = defaultdict(dict) - current_file_indices = defaultdict(lambda: 0) - - # For each class defined in modular, potentially replace the node and add it with its dependencies - for class_name, node in self.classes.items(): - nodes_to_add, file_type = self.add_class_node(class_name, node, files) - # Sort the nodes according to their relative order - nodes_to_add = sorted(nodes_to_add.items(), key=lambda x: x[1][0]) - # Write all nodes to file - for dependency, (_, node) in nodes_to_add: - # This is used to keep certain variables at the beginning of the file - try: - # The -1000 is arbitrary -> just keep it bigger than the list - idx = -1000 + VARIABLES_AT_THE_BEGINNING.index(dependency) - except ValueError: - idx = current_file_indices[file_type] - current_file_indices[file_type] += 1 - files[file_type][dependency] = {"insert_idx": idx, "node": node} - - # Add the __all__ statement to files at the end - for file_type, node in self.all_all_to_add.items(): - idx = current_file_indices[file_type] - files[file_type]["__all__"] = {"insert_idx": idx, "node": node} - - # Aggregate all the imports statements (we look for duplicates with the code_for_node, not the nodes themselves because - # they are wrapped in SimpleStatementLine or If which could have different newlines, blanks etc) - all_imports = self.imports.copy() - all_imports_code = {self.python_module.code_for_node(node).strip() for node in all_imports} - for file, mapper in self.visited_modules.items(): - new_imports = [ - node - for node in mapper.imports - if mapper.python_module.code_for_node(node).strip() not in all_imports_code - ] - new_imports_code = {mapper.python_module.code_for_node(node).strip() for node in new_imports} - all_imports.extend(new_imports) - all_imports_code.update(new_imports_code) + relative_dependency_order = mapper.compute_relative_order(all_dependencies_to_add) + nodes_to_add = { + dep: (relative_dependency_order[dep], mapper.global_nodes[dep]) for dep in all_dependencies_to_add + } - # Find the correct imports, and write the new modules - for file, body in files.items(): - new_body = [k[1]["node"] for k in sorted(body.items(), key=lambda x: x[1]["insert_idx"])] - needed_imports = get_needed_imports(body, all_imports) - full_module = needed_imports + new_body - new_module = cst.Module(body=full_module, header=self.python_module.header) - files[file] = new_module + # No super class, just check functions and assignments dependency in the imports from other model-specific files + else: + updated_node = node + # The node was NOT modified -> no need to look for dependencies recursively + all_dependencies_to_add = dependencies_for_class_node(updated_node, modular_mapper.global_nodes) + + relative_dependency_order = modular_mapper.compute_relative_order(all_dependencies_to_add) + nodes_to_add = { + dep: (relative_dependency_order[dep], modular_mapper.global_nodes[dep]) + for dep in all_dependencies_to_add + if dep not in file_to_update.keys() + } - return files + # Add the class node itself to the nodes to add + class_idx = max(relative_dependency_order.values()) + 1 if len(relative_dependency_order) > 0 else 0 + nodes_to_add[class_name] = (class_idx, updated_node) + + return nodes_to_add, file_type + + +def create_modules(modular_mapper: ModularFileMapper) -> dict[str, cst.Module]: + """Create all the new modules based on visiting the modular file. It replaces all classes as necesary.""" + files = defaultdict(dict) + current_file_indices = defaultdict(lambda: 0) + + # For each class defined in modular, potentially replace the node and add it with its dependencies + for class_name, node in modular_mapper.classes.items(): + nodes_to_add, file_type = add_class_node(modular_mapper, class_name, node, files) + # Sort the nodes according to their relative order + nodes_to_add = sorted(nodes_to_add.items(), key=lambda x: x[1][0]) + # Write all nodes to file + for dependency, (_, node) in nodes_to_add: + # This is used to keep certain variables at the beginning of the file + try: + # The -1000 is arbitrary -> just keep it bigger than the list + idx = -1000 + VARIABLES_AT_THE_BEGINNING.index(dependency) + except ValueError: + idx = current_file_indices[file_type] + current_file_indices[file_type] += 1 + files[file_type][dependency] = {"insert_idx": idx, "node": node} + + # Add the __all__ statement to files at the end + for file_type, node in modular_mapper.all_all_to_add.items(): + idx = current_file_indices[file_type] + files[file_type]["__all__"] = {"insert_idx": idx, "node": node} + + # Aggregate all the imports statements (we look for duplicates with the code_for_node, not the nodes themselves because + # they are wrapped in SimpleStatementLine or If which could have different newlines, blanks etc) + all_imports = modular_mapper.imports.copy() + all_imports_code = {modular_mapper.python_module.code_for_node(node).strip() for node in all_imports} + for file, mapper in modular_mapper.visited_modules.items(): + new_imports = [ + node + for node in mapper.imports + if mapper.python_module.code_for_node(node).strip() not in all_imports_code + ] + new_imports_code = {mapper.python_module.code_for_node(node).strip() for node in new_imports} + all_imports.extend(new_imports) + all_imports_code.update(new_imports_code) + + # Find the correct imports, and write the new modules + for file, body in files.items(): + new_body = [k[1]["node"] for k in sorted(body.items(), key=lambda x: x[1]["insert_idx"])] + needed_imports = get_needed_imports(body, all_imports) + full_module = needed_imports + new_body + new_module = cst.Module(body=full_module, header=modular_mapper.python_module.header) + files[file] = new_module + + return files def convert_modular_file(modular_file, old_model_name=None, new_model_name=None, cst_transformers=None): @@ -1292,7 +1294,7 @@ def convert_modular_file(modular_file, old_model_name=None, new_model_name=None, if cst_transformers is None: cst_transformers = ModularFileMapper(module, model_name, old_model_name, new_model_name) wrapper.visit(cst_transformers) - for file, module in cst_transformers.create_modules().items(): + for file, module in create_modules(cst_transformers).items(): if module != {}: # Get relative path starting from src/transformers/ relative_path = re.search( From 128986d5bcabfe7a9a936ba7f0b14a83628fab10 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 29 Oct 2024 19:04:11 +0100 Subject: [PATCH 20/40] style --- utils/modular_model_converter.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index b02017f8965..c22ad7aa624 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -1169,7 +1169,9 @@ def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]: return relative_order -def add_class_node(modular_mapper: ModularFileMapper, class_name: str, node: cst.CSTNode, files: dict[str, dict]) -> tuple[dict, str]: +def add_class_node( + modular_mapper: ModularFileMapper, class_name: str, node: cst.CSTNode, files: dict[str, dict] +) -> tuple[dict, str]: """Return a single class node (and all its dependency nodes), to be added to the `files`. It creates the new class node based on the inherited classes if needed. """ @@ -1262,9 +1264,7 @@ def create_modules(modular_mapper: ModularFileMapper) -> dict[str, cst.Module]: all_imports_code = {modular_mapper.python_module.code_for_node(node).strip() for node in all_imports} for file, mapper in modular_mapper.visited_modules.items(): new_imports = [ - node - for node in mapper.imports - if mapper.python_module.code_for_node(node).strip() not in all_imports_code + node for node in mapper.imports if mapper.python_module.code_for_node(node).strip() not in all_imports_code ] new_imports_code = {mapper.python_module.code_for_node(node).strip() for node in new_imports} all_imports.extend(new_imports) From 79113cf2cfe2f78e9f3537dd89ce4c19bb7e2404 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 29 Oct 2024 19:08:13 +0100 Subject: [PATCH 21/40] Update modular_model_converter.py --- utils/modular_model_converter.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index c22ad7aa624..9a15ecd83c3 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -1019,7 +1019,7 @@ def __init__(self, python_module, new_name, given_old_name=None, given_new_name= # fmt: on def visit_ImportFrom(self, node: cst.ImportFrom) -> None: - """When visiting imports from model-specific files (i.e. `transformers.models.xxx`) we get the code, parse it, + """When visiting imports from modeling files (i.e. `transformers.models.xxx`) we get the code, parse it, and save it in `self.model_specific_modules` to later visit. The imported objects are saved in `self.model_specific_imported_objects`. """ import_statement = self.python_module.code_for_node(node.module) @@ -1080,7 +1080,7 @@ def visit_SimpleStatementLine(self, node): def leave_Module(self, node): """When we leave the modular file, we do the following in order: 1. compute the nested (recursive) function dependencies - 2. for each model-specific file found in the imports, rename it with the new model name, visit it, and update + 2. for each modeling file found in the imports, rename it with the new model name, visit it, and update its dependency graph with the new function and assignment definitions found in the modular 3. update the modular dependency graph with the imported functions and assignments (found when visiting the matching files) """ @@ -1090,7 +1090,7 @@ def leave_Module(self, node): # 1. compute the nested (recursive) function dependencies self.function_recursive_dependency_mapping = self._compute_recursive_function_dependencies() - # 2. for each model-specific file found in the imports, rename it with the new model name, visit it, and update dependencies + # 2. for each modeling file found in the imports, rename it with the new model name, visit it, and update dependencies self.visited_modules = {} self.renamers = {} for file, module in self.model_specific_modules.items(): @@ -1114,7 +1114,7 @@ def leave_Module(self, node): self.merge_model_specific_imports(self.visited_modules) def merge_model_specific_imports(self, visited_modules): - """Merge the model-specific imported functions and assignments to the modular nodes and dependency graph, + """Merge the functions and assignments imported from the modeling files to the modular nodes and dependency graph, based on the visited files.""" self.start_lines_file_mapping = {} self.added_objects_file_mapping = {} @@ -1212,7 +1212,7 @@ class node based on the inherited classes if needed. dep: (relative_dependency_order[dep], mapper.global_nodes[dep]) for dep in all_dependencies_to_add } - # No super class, just check functions and assignments dependency in the imports from other model-specific files + # No super class, just check functions and assignments dependency in the imports from other modeling files else: updated_node = node # The node was NOT modified -> no need to look for dependencies recursively From 8d26fa912c75fdeb8ae77d84d27eaa166f003897 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 29 Oct 2024 19:10:48 +0100 Subject: [PATCH 22/40] rename func --- utils/modular_model_converter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index 9a15ecd83c3..ecfc8d0dfc2 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -1169,7 +1169,7 @@ def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]: return relative_order -def add_class_node( +def get_class_node_and_dependencies( modular_mapper: ModularFileMapper, class_name: str, node: cst.CSTNode, files: dict[str, dict] ) -> tuple[dict, str]: """Return a single class node (and all its dependency nodes), to be added to the `files`. It creates the new @@ -1239,7 +1239,7 @@ def create_modules(modular_mapper: ModularFileMapper) -> dict[str, cst.Module]: # For each class defined in modular, potentially replace the node and add it with its dependencies for class_name, node in modular_mapper.classes.items(): - nodes_to_add, file_type = add_class_node(modular_mapper, class_name, node, files) + nodes_to_add, file_type = get_class_node_and_dependencies(modular_mapper, class_name, node, files) # Sort the nodes according to their relative order nodes_to_add = sorted(nodes_to_add.items(), key=lambda x: x[1][0]) # Write all nodes to file From b2503673038d3cbe7433538f28000ff475089de1 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 29 Oct 2024 19:19:00 +0100 Subject: [PATCH 23/40] add augmented dependencies --- src/transformers/models/gemma/modeling_gemma.py | 14 +++++++------- src/transformers/models/gemma2/modeling_gemma2.py | 14 +++++++------- utils/modular_model_converter.py | 4 ++-- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index dec4e26fee2..fa3fadc4349 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -155,6 +155,13 @@ def forward(self, x): return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -382,13 +389,6 @@ def forward( return attn_output, None, past_key_value -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - class GemmaFlashAttention2(GemmaAttention): """ Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stays diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index aa3a926f2c9..626e5537fc0 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -122,6 +122,13 @@ def forward(self, x, position_ids, seq_len=None): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -161,13 +168,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - class Gemma2Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index ecfc8d0dfc2..18b9fed8508 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -1215,8 +1215,8 @@ class node based on the inherited classes if needed. # No super class, just check functions and assignments dependency in the imports from other modeling files else: updated_node = node - # The node was NOT modified -> no need to look for dependencies recursively - all_dependencies_to_add = dependencies_for_class_node(updated_node, modular_mapper.global_nodes) + # The node was NOT modified -> no need to look recursively for other class dependencies (they should all be defined) + all_dependencies_to_add = augmented_dependencies_for_class_node(updated_node, modular_mapper) relative_dependency_order = modular_mapper.compute_relative_order(all_dependencies_to_add) nodes_to_add = { From 33dbde7f35cbf96a9e6b4c82ac09e60c2e967931 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 29 Oct 2024 19:28:40 +0100 Subject: [PATCH 24/40] Update modular_model_converter.py --- utils/modular_model_converter.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index 18b9fed8508..a1fc7f3dda4 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -1215,7 +1215,8 @@ class node based on the inherited classes if needed. # No super class, just check functions and assignments dependency in the imports from other modeling files else: updated_node = node - # The node was NOT modified -> no need to look recursively for other class dependencies (they should all be defined) + # The node was NOT modified -> no need to look recursively for other class dependencies. Indeed, even if they are not + # already defined (which would mean a weird order of the code in the modular...), they will be in the future all_dependencies_to_add = augmented_dependencies_for_class_node(updated_node, modular_mapper) relative_dependency_order = modular_mapper.compute_relative_order(all_dependencies_to_add) From 9fcddb87716740b8e3aec84210d6aab6e0c2b327 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 29 Oct 2024 20:25:48 +0100 Subject: [PATCH 25/40] Add types_to_file_type + tweak annotation handling --- utils/modular_model_converter.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index a1fc7f3dda4..c3f1249ed39 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -490,9 +490,13 @@ def __init__(self, class_name: str, global_names: set | None): def visit_Name(self, node): if node.value != self.class_name and node.value in self.global_names: parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) - # If it is only an annotation, do not add dependency - if not m.matches(parent_node, m.Annotation()): - self.dependencies.add(node.value) + # If it is only an annotation inside a method definition, do not add dependency (however, do it for + # annotations that are variable definitions, i.e. for Kwargs classes) + if m.matches(parent_node, m.Annotation()): + grand_parent = self.get_metadata(cst.metadata.ParentNodeProvider, parent_node) + if m.matches(grand_parent, m.Param() | m.FunctionDef()): + return + self.dependencies.add(node.value) def dependencies_for_class_node(node: cst.ClassDef, global_names: set) -> set: @@ -890,6 +894,9 @@ def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef, rename "Processor": "processing", "ImageProcessor": "image_processing", "FeatureExtractor": "feature_extractor", + "ProcessorKwargs": "processing", + "ImagesKwargs": "processing", + "TextKwargs": "processing", } From 70f006b4cbaed558f97308dc55ca368a6800b10f Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 30 Oct 2024 23:03:05 +0100 Subject: [PATCH 26/40] Allow assignment dependency mapping + fix regex --- utils/modular_model_converter.py | 168 +++++++++++++++++-------------- 1 file changed, 94 insertions(+), 74 deletions(-) diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index c3f1249ed39..f98af57a554 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -393,7 +393,7 @@ def find_all_dependencies( Args: dependency_mapping (`Dict[str, set]`): - A mapping from entities (usually function names), to immediate dependencies. That is, for function names, + A mapping from entities (usually function/assignment names), to immediate dependencies. That is, for function names, a mapping {"foo": {"bar", "test"}} would indicate that functions `bar` and `test` are immediately called in `foo`'s definition. start_entity (str | None, *optional*): @@ -510,19 +510,19 @@ def dependencies_for_class_node(node: cst.ClassDef, global_names: set) -> set: def augmented_dependencies_for_class_node(node: cst.ClassDef, mapper: "ModuleMapper") -> set: """Create augmented dependencies for a class node based on a `mapper`. - Augmented dependencies means immediate dependencies + recursive function dependencies. + Augmented dependencies means immediate dependencies + recursive function and assignments dependencies. """ temp_module = cst.Module(body=[node]) wrapper = MetadataWrapper(temp_module) visitor = ClassDependencyMapper(node.name.value, set(mapper.global_nodes.keys())) wrapper.visit(visitor) - return mapper.augment_dependencies_with_functions(visitor.dependencies) + return mapper.augment_dependencies(visitor.dependencies) class ModuleMapper(CSTVisitor, ABC): - """An abstract visitor class which analyses a module, creating a mapping of dependencies for classes and functions. - Class dependencies are computed with `compute_class_dependencies()`, while function dependencies are stored in - `self.function_recursive_dependency_mapping` (can be computed by `_compute_recursive_function_dependencies()`). + """An abstract visitor class which analyses a module, creating a mapping of dependencies for classes, functions and assignments. + Class dependencies are computed with `compute_class_dependencies()`, while function and assignment dependencies are stored in + `self.object_recursive_dependency_mapping` (can be computed by `_compute_recursive_object_dependencies()`). It defines common visiting patterns (i.e. common visit_xxx/leave_xxx functions) between the modular file and the modeling files that will be visited. """ @@ -535,9 +535,10 @@ def __init__(self, python_module: cst.Module): self.classes: Dict[str, cst.ClassDef] = {} # mapping from class names to Nodes self.imports = [] # stores all import statements self.functions: Dict[str, cst.FunctionDef] = {} # mapping of global scope function names to Nodes - self.function_dependency_mapping = defaultdict(set) # immediate function dependency mapping (i.e. dependencies immediately in the function definition) + self.object_dependency_mapping = defaultdict(set) # immediate function/assignment dependency mapping (i.e. dependencies immediately in the function/assignment definition) self.assignments: Dict[str, cst.SimpleStatementLine] = {} # mapping of global assignments names to Nodes self.current_function = None # this keeps track of the current module-scope function + self.current_assignment = None # this keeps track of the current module-scope assignment # fmt: on def visit_SimpleStatementLine(self, node): @@ -546,17 +547,22 @@ def visit_SimpleStatementLine(self, node): are extracted and saved in their corresponding dict. They are then used when updating dependency mappings. """ parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) + simple_top_level_assign_structure = m.SimpleStatementLine( + body=[m.Assign(targets=[m.AssignTarget(target=m.Name())])] + ) if m.matches(parent_node, m.Module()): - if m.matches(node, m.SimpleStatementLine(body=[m.Assign()])): - left_hand_side = node.body[0].targets[0].target - if hasattr(left_hand_side, "value"): - self.assignments[left_hand_side.value] = node - else: - for idx, target in enumerate(list(left_hand_side.elements)): - self.assignments[target.value.value] = node.body[0].value.elements[idx].value + if m.matches(node, simple_top_level_assign_structure): + left_hand_side = node.body[0].targets[0].target.value + self.current_assignment = left_hand_side + self.assignments[left_hand_side] = node elif m.matches(node, m.SimpleStatementLine(body=[m.Import() | m.ImportFrom()])): self.imports.append(node) + def leave_SimpleStatementLine(self, node): + # No need to check for the parent here -> everytime we exit one, it should be None anyway independently of where the + # SimpleStatement is located + self.current_assignment = None + def visit_FunctionDef(self, node): parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) if m.matches(parent_node, m.Module()): @@ -578,14 +584,16 @@ def visit_ClassDef(self, node: ClassDef) -> None: self.classes[node.name.value] = node def visit_Name(self, node: cst.Call): - """This is used to create a mapping from module-scope functions to objects used inside them.""" + """This is used to create a mapping from module-scope functions and assignments to objects used inside them.""" if self.current_function is not None: - self.function_dependency_mapping[self.current_function].add(node.value) + self.object_dependency_mapping[self.current_function].add(node.value) + if self.current_assignment is not None: + self.object_dependency_mapping[self.current_assignment].add(node.value) def leave_Module(self, node): """When leaving the module, we store the position of each global scoped node to allow sorting the dependencies based on their position in the code later. We use the PositionProvider metadata wrapper for this. - We also make sure to update `self.function_dependency_mapping` so that it contains only names recorded in + We also make sure to update `self.object_dependency_mapping` so that it contains only names recorded in `self.global_nodes`. """ # assign all nodes @@ -595,13 +603,13 @@ def leave_Module(self, node): for id, node in self.global_nodes.items(): self.start_lines[id] = self.get_metadata(cst.metadata.PositionProvider, node).start.line - # Since we added every Name as part of `self.function_dependency_mapping`, we now remove those that + # Since we added every Name as part of `self.object_dependency_mapping`, we now remove those that # are not part of the recorded objects (i.e. built-in variables, imports, etc) global_objects = set(self.global_nodes.keys()) - for function_name, dependencies in self.function_dependency_mapping.items(): - self.function_dependency_mapping[function_name] = {dep for dep in dependencies if dep in global_objects} + for object_name, dependencies in self.object_dependency_mapping.items(): + self.object_dependency_mapping[object_name] = {dep for dep in dependencies if dep in global_objects} - def _compute_recursive_function_dependencies(self) -> dict[str, set]: + def _compute_recursive_object_dependencies(self) -> dict[str, set]: """Based on immediate dependency mapping, create the recursive dependency mapping. For example, given the following file: ``` @@ -615,41 +623,41 @@ def test(): bar() ``` this visitor can only record immediate dependencies, i.e. it will record the following - `self.function_dependency_mapping = {"test": {"bar"}, "bar": {"foo}}`. This function is used to create + `self.object_dependency_mapping = {"test": {"bar"}, "bar": {"foo}}`. This function is used to create the recursive mapping, i.e. `recursive_dependencies = {"test": {"bar", "foo"}, "bar": {"foo}}`. """ recursive_dependencies = {} - for function_name in self.function_dependency_mapping.keys(): - all_dependencies = find_all_dependencies(self.function_dependency_mapping, start_entity=function_name) - recursive_dependencies[function_name] = all_dependencies + for object_name in self.object_dependency_mapping.keys(): + all_dependencies = find_all_dependencies(self.object_dependency_mapping, start_entity=object_name) + recursive_dependencies[object_name] = all_dependencies return recursive_dependencies - def augment_dependencies_with_functions(self, dependencies: set[str]) -> set[str]: - """For a set of `dependencies`, augment them by adding all potential dependencies of the **functions** present - in the `dependencies`. + def augment_dependencies(self, dependencies: set[str]) -> set[str]: + """For a set of `dependencies`, augment them by adding all potential dependencies of the **functions** and + **assignments** present in the `dependencies`. """ new_dependencies = dependencies.copy() # Go through the set of dependencies for dep in tuple(dependencies): - if dep in self.function_recursive_dependency_mapping.keys(): - new_dependencies.update(self.function_recursive_dependency_mapping[dep]) + if dep in self.object_recursive_dependency_mapping.keys(): + new_dependencies.update(self.object_recursive_dependency_mapping[dep]) return new_dependencies def compute_class_dependencies(self): """For each visited class, find its dependencies based on visiting the current file + potential merged dependencies. - Note: This function takes care of updating `global_nodes` and `function_call_recursive_dependency_mapping` as well after the + Note: This function takes care of updating `global_nodes` and `object_recursive_dependency_mapping` as well after the merge with other files dependencies. """ # Correctly re-set the global nodes at this point self.global_nodes = {**self.assignments, **self.classes, **self.functions} - # Create the global mapping of recursive dependencies for functions - self.function_recursive_dependency_mapping = self._compute_recursive_function_dependencies() + # Create the global mapping of recursive dependencies for functions and assignments + self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies() self.class_dependency_mapping = {} for class_name, class_node in self.classes.items(): dependencies = dependencies_for_class_node(class_node, set(self.global_nodes.keys())) - # Corretcly augment class dependencies with all needed functions - self.class_dependency_mapping[class_name] = self.augment_dependencies_with_functions(dependencies) + # Correctly augment class dependencies with all needed objects + self.class_dependency_mapping[class_name] = self.augment_dependencies(dependencies) @abstractmethod def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]: @@ -731,7 +739,7 @@ def compute_relative_order(self, missing_dependencies: set[str]) -> dict[str, in return relative_order - def _merge_functions(self, functions: dict[str, cst.CSTNode], function_call_mapping: dict[str, set]): + def _merge_functions(self, functions: dict[str, cst.CSTNode], object_mapping: dict[str, set]): """Update the global nodes and function dependency mapping with those from the modular file. Merging rule: if any function with the same name was redefined in the modular, use it and its dependencies @@ -739,9 +747,9 @@ def _merge_functions(self, functions: dict[str, cst.CSTNode], function_call_mapp """ # Add/overwrite all needed function nodes and dependencies self.functions.update(functions) - self.function_dependency_mapping.update(function_call_mapping) + self.object_dependency_mapping.update({obj: dep for obj, dep in object_mapping.items() if obj in functions.keys()}) - def _merge_assignments(self, assignments: dict[str, cst.CSTNode]): + def _merge_assignments(self, assignments: dict[str, cst.CSTNode], object_mapping: dict[str, set]): """Update the global nodes with the assignment from the modular file. Merging rule: if any assignment with the same name was redefined in the modular, we use it ONLY if it is @@ -751,23 +759,25 @@ def _merge_assignments(self, assignments: dict[str, cst.CSTNode]): for assignment, node in assignments.items(): if assignment in ASSIGNMENTS_TO_KEEP or assignment not in self.assignments: self.assignments[assignment] = node + if assignment in object_mapping: + self.object_dependency_mapping[assignment] = object_mapping[assignment] - def merge_modular_dependencies(self, functions, function_mapping, assignments, start_lines): + def merge_modular_dependencies(self, functions, assignments, object_mapping, start_lines): """Merge both functions and assignments from the modular definitions into the current module file, - then compute the relative order of all nodes.""" - self._merge_functions(functions, function_mapping) - self._merge_assignments(assignments) + then record the relative order of all nodes.""" + self._merge_functions(functions, object_mapping) + self._merge_assignments(assignments, object_mapping) self.modular_file_start_lines = start_lines @classmethod def visit_and_merge_dependencies( - cls, module: cst.Module, functions, function_mapping, assignments, start_lines + cls, module: cst.Module, functions, assignments, object_mapping, start_lines ) -> "ModelFileMapper": wrapper = MetadataWrapper(module) mapper = cls(module) wrapper.visit(mapper) # Merge dependencies - mapper.merge_modular_dependencies(functions, function_mapping, assignments, start_lines) + mapper.merge_modular_dependencies(functions, assignments, object_mapping, start_lines) # Create the class dependencies graph mapper.compute_class_dependencies() return mapper @@ -1029,32 +1039,33 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None: """When visiting imports from modeling files (i.e. `transformers.models.xxx`) we get the code, parse it, and save it in `self.model_specific_modules` to later visit. The imported objects are saved in `self.model_specific_imported_objects`. """ - import_statement = self.python_module.code_for_node(node.module) + import_module = self.python_module.code_for_node(node.module) + import_statement = "."*len(node.relative) + import_module if any(import_to_skip in import_statement for import_to_skip in IMPORTS_TO_SKIP_IN_MODULAR): return if m.matches(node.module, m.Attribute()): for imported_ in node.names: - _import = re.search(rf"(transformers\.models\..|..)*\.({self.match_patterns})_.*", import_statement) + _import = re.search(rf"(?:transformers\.models\.)|(?:\.\.)\w+\.({self.match_patterns})_.*", import_statement) if _import: - source = _import.groups()[0] + source = _import.group(1) if source == "modeling" and "Config" in self.python_module.code_for_node(imported_): raise ValueError( f"You are importing {self.python_module.code_for_node(imported_)} from the modeling file. Import from the `configuration_xxxx.py` file instead" ) - if import_statement not in self.model_specific_modules: - if "models" not in import_statement: - import_statement = "models." + import_statement - if "transformers" not in import_statement: - import_statement = "transformers." + import_statement - source_code = get_module_source_from_name(import_statement) + if import_module not in self.model_specific_modules: + if "models" not in import_module: + import_module = "models." + import_module + if "transformers" not in import_module: + import_module = "transformers." + import_module + source_code = get_module_source_from_name(import_module) tree = cst.parse_module(source_code) - self.model_specific_modules[import_statement] = tree + self.model_specific_modules[import_module] = tree imported_object = self.python_module.code_for_node(imported_.name) - self.model_specific_imported_objects[imported_object] = import_statement + self.model_specific_imported_objects[imported_object] = import_module if m.matches(node.module, m.Name()): - if "transformers" == import_statement: + if "transformers" == import_module: raise ValueError( - f"You are importing from {import_statement} directly using global imports. Import from the correct local path" + f"You are importing from {import_module} directly using global imports. Import from the correct local path" ) def visit_SimpleStatementLine(self, node): @@ -1069,24 +1080,24 @@ def visit_SimpleStatementLine(self, node): if m.matches(node, m.SimpleStatementLine(body=[m.Import()])): self.imports.append(node) elif m.matches(node, m.SimpleStatementLine(body=[m.ImportFrom()])): - full_statement = self.python_module.code_for_node(node.body[0].module) + import_module = self.python_module.code_for_node(node.body[0].module) + import_statement = "."*len(node.body[0].relative) + import_module if not ( - # OR MATCH ..llama.modeling_llama - re.search(rf"(transformers\.models\..|..)*\.({self.match_patterns})_.*", full_statement) - and "auto.modeling_auto" not in full_statement + re.search(rf"(?:transformers\.models\.)|(?:\.\.)\w+\.({self.match_patterns})_.*", import_statement) + and not any(import_to_skip in import_statement for import_to_skip in IMPORTS_TO_SKIP_IN_MODULAR) ): self.imports.append(node) elif m.matches(node, simple_top_level_assign_structure): assigned_variable = node.body[0].targets[0].target.value # __all__ is treated differently and not added to general assignments - if assigned_variable != "__all__": - self.assignments[assigned_variable] = node - else: + if assigned_variable == "__all__": self.all_all_to_add = split_all_assignment(node) + else: + self.assignments[assigned_variable] = node def leave_Module(self, node): """When we leave the modular file, we do the following in order: - 1. compute the nested (recursive) function dependencies + 1. compute the nested (recursive) function and assignment dependencies 2. for each modeling file found in the imports, rename it with the new model name, visit it, and update its dependency graph with the new function and assignment definitions found in the modular 3. update the modular dependency graph with the imported functions and assignments (found when visiting the matching files) @@ -1094,8 +1105,8 @@ def leave_Module(self, node): # Takes care of finalizing our visit super().leave_Module(node) - # 1. compute the nested (recursive) function dependencies - self.function_recursive_dependency_mapping = self._compute_recursive_function_dependencies() + # 1. compute the nested (recursive) function and assignment dependencies + self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies() # 2. for each modeling file found in the imports, rename it with the new model name, visit it, and update dependencies self.visited_modules = {} @@ -1109,8 +1120,8 @@ def leave_Module(self, node): self.visited_modules[file] = ModelFileMapper.visit_and_merge_dependencies( renamed_module, self.functions, - self.function_dependency_mapping, self.assignments, + self.object_dependency_mapping, self.start_lines, ) # We record it so that we can rename classes later the exact same way @@ -1132,17 +1143,25 @@ def merge_model_specific_imports(self, visited_modules): if object_name in visited_module.functions and object_name not in self.functions: self.functions[object_name] = visited_module.functions[object_name] self.added_objects_file_mapping[object_name] = file - dependencies = visited_module.function_recursive_dependency_mapping.get(object_name, None) + dependencies = visited_module.object_recursive_dependency_mapping.get(object_name, None) if dependencies is not None: - self.function_recursive_dependency_mapping[object_name] = dependencies + self.object_recursive_dependency_mapping[object_name] = dependencies for dep in dependencies: - self.added_objects_file_mapping[dep] = file - self.functions[dep] = visited_module.global_nodes[dep] + if dep not in self.global_nodes: + self.added_objects_file_mapping[dep] = file + self.functions[dep] = visited_module.global_nodes[dep] - # Add assignments + # Add assignments and their dependencies elif object_name in visited_module.assignments and object_name not in self.assignments: - self.added_objects_file_mapping[object_name] = file self.assignments[object_name] = visited_module.assignments[object_name] + self.added_objects_file_mapping[object_name] = file + dependencies = visited_module.object_recursive_dependency_mapping.get(object_name, None) + if dependencies is not None: + self.object_recursive_dependency_mapping[object_name] = dependencies + for dep in dependencies: + if dep not in self.global_nodes: + self.added_objects_file_mapping[dep] = file + self.assignments[dep] = visited_module.global_nodes[dep] # Do not forget to re-assign all nodes after the merge self.global_nodes = {**self.assignments, **self.classes, **self.functions} @@ -1360,6 +1379,7 @@ def save_modeling_file(modular_file, converted_file): args = parser.parse_args() if args.files_to_parse == ["all"]: args.files_to_parse = glob.glob("src/transformers/models/**/modular_*.py", recursive=True) + args.files_to_parse += glob.glob("examples/**/modular_*.py", recursive=True) for file_name in find_priority_list(args.files_to_parse): print(f"Converting {file_name} to a single model single file format") From b5879b14026897f90ac2dce66d878c406cf27f57 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 30 Oct 2024 23:07:20 +0100 Subject: [PATCH 27/40] style + update modular examples --- .../configuration_my_new_model.py | 27 +- .../configuration_my_new_model2.py | 113 +- .../configuration_new_model.py | 27 +- .../modular-transformers/modeling_dummy.py | 208 ++-- .../modeling_dummy_bert.py | 190 ++- .../modeling_my_new_model2.py | 164 ++- .../modeling_new_task_model.py | 148 +-- .../modular-transformers/modeling_roberta.py | 1017 +++++++++++++++++ .../modular-transformers/modeling_super.py | 227 ++-- utils/modular_model_converter.py | 12 +- 10 files changed, 1459 insertions(+), 674 deletions(-) create mode 100644 examples/modular-transformers/modeling_roberta.py diff --git a/examples/modular-transformers/configuration_my_new_model.py b/examples/modular-transformers/configuration_my_new_model.py index 3c7848e6956..aa0aac55ba9 100644 --- a/examples/modular-transformers/configuration_my_new_model.py +++ b/examples/modular-transformers/configuration_my_new_model.py @@ -1,9 +1,9 @@ -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from . -# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the modular. If any change should be done, please apply the change to the -# modular_xxx.py file directly. One of our CI enforces this -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from examples/modular-transformers/modular_my_new_model.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_my_new_model.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 from ...configuration_utils import PretrainedConfig from ...modeling_rope_utils import rope_config_validation @@ -158,6 +158,13 @@ def __init__( new_param=0, **kwargs, ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings self.hidden_size = hidden_size @@ -187,11 +194,3 @@ def __init__( self.rope_scaling["rope_type"] = self.rope_scaling["type"] rope_config_validation(self) self.new_param = new_param - - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) diff --git a/examples/modular-transformers/configuration_my_new_model2.py b/examples/modular-transformers/configuration_my_new_model2.py index 5fef1cecc70..f05ace94b62 100644 --- a/examples/modular-transformers/configuration_my_new_model2.py +++ b/examples/modular-transformers/configuration_my_new_model2.py @@ -1,9 +1,9 @@ -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from . -# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the modular. If any change should be done, please apply the change to the -# modular_xxx.py file directly. One of our CI enforces this -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from examples/modular-transformers/modular_my_new_model2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_my_new_model2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 from ...configuration_utils import PretrainedConfig from ...modeling_rope_utils import rope_config_validation @@ -11,106 +11,6 @@ class MyNewModel2Config(PretrainedConfig): r""" - This is the configuration class to store the configuration of a [`MyNewModel2Model`]. It is used to instantiate an MyNewModel2 - model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the MyNewModel2-7B. - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - - Args: - vocab_size (`int`, *optional*, defaults to 32000): - Vocabulary size of the MyNewModel2 model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`MyNewModel2Model`] - hidden_size (`int`, *optional*, defaults to 4096): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 11008): - Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 32): - Number of hidden layers in the Transformer decoder. - num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the Transformer decoder. - num_key_value_heads (`int`, *optional*): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details checkout [this - paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to - `num_attention_heads`. - hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the decoder. - max_position_embeddings (`int`, *optional*, defaults to 2048): - The maximum sequence length that this model might ever be used with. MyNewModel2 1 supports up to 2048 tokens, - MyNewModel2 2 up to 4096, CodeMyNewModel2 up to 16384. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-06): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - pad_token_id (`int`, *optional*): - Padding token id. - bos_token_id (`int`, *optional*, defaults to 1): - Beginning of stream token id. - eos_token_id (`int`, *optional*, defaults to 2): - End of stream token id. - pretraining_tp (`int`, *optional*, defaults to 1): - Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this - document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to - understand more about it. This value is necessary to ensure exact reproducibility of the pretraining - results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232). - tie_word_embeddings (`bool`, *optional*, defaults to `False`): - Whether to tie weight embeddings - rope_theta (`float`, *optional*, defaults to 10000.0): - The base period of the RoPE embeddings. - rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type - and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value - accordingly. - Expected contents: - `rope_type` (`str`): - The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', - 'my_new_model23'], with 'default' being the original RoPE implementation. - `factor` (`float`, *optional*): - Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In - most scaling types, a `factor` of x will enable the model to handle sequences of length x * - original maximum pre-trained length. - `original_max_position_embeddings` (`int`, *optional*): - Used with 'dynamic', 'longrope' and 'my_new_model23'. The original max position embeddings used during - pretraining. - `attention_factor` (`float`, *optional*): - Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention - computation. If unspecified, it defaults to value recommended by the implementation, using the - `factor` field to infer the suggested value. - `beta_fast` (`float`, *optional*): - Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear - ramp function. If unspecified, it defaults to 32. - `beta_slow` (`float`, *optional*): - Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear - ramp function. If unspecified, it defaults to 1. - `short_factor` (`List[float]`, *optional*): - Only used with 'longrope'. The scaling factor to be applied to short contexts (< - `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden - size divided by the number of attention heads divided by 2 - `long_factor` (`List[float]`, *optional*): - Only used with 'longrope'. The scaling factor to be applied to long contexts (< - `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden - size divided by the number of attention heads divided by 2 - `low_freq_factor` (`float`, *optional*): - Only used with 'my_new_model23'. Scaling factor applied to low frequency components of the RoPE - `high_freq_factor` (`float`, *optional*): - Only used with 'my_new_model23'. Scaling factor applied to high frequency components of the RoPE - attention_bias (`bool`, *optional*, defaults to `False`): - Whether to use a bias in the query, key, value and output projection layers during self-attention. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. - mlp_bias (`bool`, *optional*, defaults to `False`): - Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. - head_dim (`int`, *optional*): - The attention head dimension. If None, it will default to hidden_size // num_heads This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate an Gemma model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the Gemma-7B. @@ -121,7 +21,6 @@ class MyNewModel2Config(PretrainedConfig): vocab_size (`int`, *optional*, defaults to 256000): Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`GemmaModel`] - ```python >>> from transformers import GemmaModel, GemmaConfig >>> # Initializing a Gemma gemma-7b style configuration diff --git a/examples/modular-transformers/configuration_new_model.py b/examples/modular-transformers/configuration_new_model.py index 8bc8ef52cee..4d164fe3e75 100644 --- a/examples/modular-transformers/configuration_new_model.py +++ b/examples/modular-transformers/configuration_new_model.py @@ -1,9 +1,9 @@ -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from . -# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the modular. If any change should be done, please apply the change to the -# modular_xxx.py file directly. One of our CI enforces this -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from examples/modular-transformers/modular_new_model.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_new_model.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # Example where we only want to overwrite the defaults of an init from ...configuration_utils import PretrainedConfig @@ -104,6 +104,13 @@ def __init__( attention_dropout=0.0, **kwargs, ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings self.hidden_size = hidden_size @@ -121,14 +128,6 @@ def __init__( self.attention_bias = attention_bias self.attention_dropout = attention_dropout - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) - @property def num_heads(self): return self.num_attention_heads diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py index b5b1fc6aec8..ed87fb66f0a 100644 --- a/examples/modular-transformers/modeling_dummy.py +++ b/examples/modular-transformers/modeling_dummy.py @@ -1,26 +1,24 @@ -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from . -# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the modular. If any change should be done, please apply the change to the -# modular_xxx.py file directly. One of our CI enforces this -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from examples/modular-transformers/modular_dummy.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_dummy.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 import math from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F -import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import _flash_attention_forward -from ...modeling_outputs import ( - BaseModelOutputWithPast, -) +from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward +from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -33,59 +31,6 @@ logger = logging.get_logger(__name__) -def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - min_dtype: float, - cache_position: torch.Tensor, - batch_size: int, -): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to plcae the 4D attention mask on. - min_dtype (`float`): - The minimum value representable with the dtype `dtype`. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - class DummyRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -193,40 +138,6 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 4] - x2 = x[..., x.shape[-1] // 4 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - class DummyMLP(nn.Module): def __init__(self, config): super().__init__() @@ -261,6 +172,33 @@ def forward(self, x): return down_proj +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, @@ -273,6 +211,13 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 4] + x2 = x[..., x.shape[-1] // 4 :] + return torch.cat((-x2, x1), dim=-1) + + class DummyAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -423,6 +368,7 @@ def forward( use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if isinstance(past_key_value, StaticCache): raise ValueError( @@ -507,6 +453,7 @@ def forward( sliding_window=getattr(self, "sliding_window", None), use_top_left_mask=self._flash_attn_uses_top_left_mask, is_causal=self.is_causal, + **kwargs, ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() @@ -871,6 +818,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -952,6 +900,7 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + **flash_attn_kwargs, ) hidden_states = layer_outputs[0] @@ -1011,10 +960,9 @@ def _update_causal_mask( return None dtype, device = input_tensor.dtype, input_tensor.device - min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] if using_static_cache: - target_length = past_key_values.get_max_length() + target_length = past_key_values.get_max_cache_shape() else: target_length = ( attention_mask.shape[-1] @@ -1023,13 +971,12 @@ def _update_causal_mask( ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, dtype=dtype, device=device, - min_dtype=min_dtype, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -1043,6 +990,63 @@ def _update_causal_mask( # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask diff --git a/examples/modular-transformers/modeling_dummy_bert.py b/examples/modular-transformers/modeling_dummy_bert.py index 611d7be961f..e18e6a19e8a 100644 --- a/examples/modular-transformers/modeling_dummy_bert.py +++ b/examples/modular-transformers/modeling_dummy_bert.py @@ -1,27 +1,20 @@ -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from . -# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the modular. If any change should be done, please apply the change to the -# modular_xxx.py file directly. One of our CI enforces this -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from examples/modular-transformers/modular_dummy_bert.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_dummy_bert.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 import math import os from typing import List, Optional, Tuple, Union import torch -import torch.utils.checkpoint from packaging import version from torch import nn from ...activations import ACT2FN -from ...modeling_attn_mask_utils import ( - _prepare_4d_attention_mask_for_sdpa, - _prepare_4d_causal_attention_mask_for_sdpa, -) -from ...modeling_outputs import ( - BaseModelOutputWithPastAndCrossAttentions, - BaseModelOutputWithPoolingAndCrossAttentions, -) +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa +from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( @@ -40,79 +33,6 @@ _CONFIG_FOR_DOC = "DummyBertConfig" -def load_tf_weights_in_dummy_bert(model, config, tf_checkpoint_path): - """Load tf checkpoints in a pytorch model.""" - try: - import re - - import numpy as np - import tensorflow as tf - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." - ) - raise - tf_path = os.path.abspath(tf_checkpoint_path) - logger.info(f"Converting TensorFlow checkpoint from {tf_path}") - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - arrays = [] - for name, shape in init_vars: - logger.info(f"Loading TF weight {name} with shape {shape}") - array = tf.train.load_variable(tf_path, name) - names.append(name) - arrays.append(array) - - for name, array in zip(names, arrays): - name = name.split("/") - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model - if any( - n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] - for n in name - ): - logger.info(f"Skipping {'/'.join(name)}") - continue - pointer = model - for m_name in name: - if re.fullmatch(r"[A-Za-z]+_\d+", m_name): - scope_names = re.split(r"_(\d+)", m_name) - else: - scope_names = [m_name] - if scope_names[0] == "kernel" or scope_names[0] == "gamma": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "output_bias" or scope_names[0] == "beta": - pointer = getattr(pointer, "bias") - elif scope_names[0] == "output_weights": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "squad": - pointer = getattr(pointer, "classifier") - else: - try: - pointer = getattr(pointer, scope_names[0]) - except AttributeError: - logger.info(f"Skipping {'/'.join(name)}") - continue - if len(scope_names) >= 2: - num = int(scope_names[1]) - pointer = pointer[num] - if m_name[-11:] == "_embeddings": - pointer = getattr(pointer, "weight") - elif m_name == "kernel": - array = np.transpose(array) - try: - if pointer.shape != array.shape: - raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") - except ValueError as e: - e.args += (pointer.shape, array.shape) - raise - logger.info(f"Initialize PyTorch weight {name}") - pointer.data = torch.from_numpy(array) - return model - - class DummyBertEmbeddings(nn.Module): """Construct the embeddings from word, position and token_type embeddings.""" @@ -706,6 +626,79 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return pooled_output +def load_tf_weights_in_dummy_bert(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + continue + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == "_embeddings": + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except ValueError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + class DummyBertPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained @@ -871,26 +864,6 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: - r""" - encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if - the model is configured as a decoder. - encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, target_length)`, *optional*): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - """ r""" encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if @@ -1027,7 +1000,6 @@ def forward( if not return_dict: return (sequence_output, pooled_output) + encoder_outputs[1:] - return super().forward(input_ids) return BaseModelOutputWithPoolingAndCrossAttentions( last_hidden_state=sequence_output, diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index 49cdd274162..16f9e525a05 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -1,25 +1,20 @@ -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from . -# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the modular. If any change should be done, please apply the change to the -# modular_xxx.py file directly. One of our CI enforces this -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from examples/modular-transformers/modular_my_new_model2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_my_new_model2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 import math from typing import List, Optional, Tuple, Union import torch -import torch.utils.checkpoint from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import _flash_attention_forward -from ...modeling_outputs import ( - BaseModelOutputWithPast, - SequenceClassifierOutputWithPast, -) +from ...modeling_outputs import BaseModelOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, @@ -30,6 +25,9 @@ from .configuration_my_new_model2 import MyNewModel2Config +logger = logging.get_logger(__name__) + + class MyNewModel2RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-6): super().__init__() @@ -50,9 +48,6 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.eps}" -logger = logging.get_logger(__name__) - - class MyNewModel2RotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -448,59 +443,6 @@ def forward( return attn_output, attn_weights, past_key_value -def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - min_dtype: float, - cache_position: torch.Tensor, - batch_size: int, -): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to plcae the 4D attention mask on. - min_dtype (`float`): - The minimum value representable with the dtype `dtype`. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - MY_NEW_MODEL2_ATTENTION_CLASSES = { "eager": MyNewModel2Attention, "flash_attention_2": MyNewModel2FlashAttention2, @@ -893,10 +835,9 @@ def _update_causal_mask( return None dtype, device = input_tensor.dtype, input_tensor.device - min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] if using_static_cache: - target_length = past_key_values.get_max_length() + target_length = past_key_values.get_max_cache_shape() else: target_length = ( attention_mask.shape[-1] @@ -905,13 +846,12 @@ def _update_causal_mask( ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, dtype=dtype, device=device, - min_dtype=min_dtype, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -925,10 +865,67 @@ def _update_causal_mask( # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + @add_start_docstrings( """ @@ -1019,27 +1016,8 @@ def forward( loss = None if labels is not None: - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, labels) + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + if not return_dict: output = (pooled_logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output diff --git a/examples/modular-transformers/modeling_new_task_model.py b/examples/modular-transformers/modeling_new_task_model.py index 640331ace1d..4556308f1ea 100644 --- a/examples/modular-transformers/modeling_new_task_model.py +++ b/examples/modular-transformers/modeling_new_task_model.py @@ -8,7 +8,6 @@ from typing import ClassVar, List, Optional, Tuple, Union import torch -import torch.utils.checkpoint from torch import nn from ...cache_utils import Cache, StaticCache @@ -18,92 +17,15 @@ ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - logging, replace_return_docstrings, ) -from .configuration_new_task_model import NewTaskModelConfig - - -if is_flash_attn_2_available(): - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa - from ..auto import AutoModel, AutoModelForCausalLM +from .configuration_new_task_model import NewTaskModelConfig -logger = logging.get_logger(__name__) - _CONFIG_FOR_DOC = "NewTaskModelConfig" -# Adapted from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position -# But NewTaskModel has no causal mask on prefix -def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - min_dtype: float, - cache_position: torch.Tensor, - batch_size: int, - is_training: bool = False, - token_type_ids: torch.Tensor = None, -): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to plcae the 4D attention mask on. - min_dtype (`float`): - The minimum value representable with the dtype `dtype`. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - is_training (`bool`): - Whether the model is in training mode or in inference. The condition is checked by presence/absence of `token_type_ids/labels` - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) - # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below - if sequence_length != 1: - if is_training: - causal_mask = torch.triu(causal_mask, diagonal=1) - else: - causal_mask[:, :sequence_length] = 0.0 - - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - # we are training thus we need to create a full mask on the image + prefix but causal on suffix - if is_training: - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0 - ) - return causal_mask - - @dataclass class NewTaskModelCausalLMOutputWithPast(ModelOutput): """ @@ -182,12 +104,12 @@ class NewTaskModelPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["NewTaskModelMultiModalProjector"] _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = False _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True - _supports_sdpa = True _supports_cache_class = True + _supports_flash_attn_2 = True + _supports_sdpa = True def _init_weights(self, module): # important: this ported version of NewTaskModelisn't meant for training from scratch - only @@ -210,14 +132,6 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - @property - def _supports_sdpa(self): - """ - Retrieve language_model's attribute to check whether the model supports - SDPA or not. - """ - return self.language_model._supports_sdpa - NEW_TASK_MODEL_INPUTS_DOCSTRING = r""" Args: @@ -301,11 +215,8 @@ def __init__(self, config): self.vision_tower = AutoModel.from_config(config=config.vision_config) self.multi_modal_projector = NewTaskModelMultiModalProjector(config) self.vocab_size = config.text_config.vocab_size - self._attn_implementation = config._attn_implementation - language_model = AutoModelForCausalLM.from_config( - config=config.text_config, attn_implementation=self._attn_implementation - ) + language_model = AutoModelForCausalLM.from_config(config=config.text_config) if language_model._tied_weights_keys is not None: self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys] @@ -344,6 +255,11 @@ def tie_weights(self): def _update_causal_mask( self, attention_mask, token_type_ids, inputs_embeds, past_key_values, cache_position, is_training: bool = False ): + if self.config.text_config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + using_static_cache = isinstance(past_key_values, StaticCache) dtype = inputs_embeds.dtype min_dtype = torch.finfo(dtype).min @@ -388,6 +304,22 @@ def _update_causal_mask( ) return causal_mask + def get_image_features(self, pixel_values: torch.FloatTensor): + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) + The tensors corresponding to the input images. + Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). + """ + image_outputs = self.vision_tower(pixel_values) + selected_image_feature = image_outputs.last_hidden_state + image_features = self.multi_modal_projector(selected_image_feature) + image_features = image_features / (self.config.hidden_size**0.5) + return image_features + @add_start_docstrings_to_model_forward(NEW_TASK_MODEL_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=NewTaskModelCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -426,9 +358,9 @@ def forward( ```python >>> from PIL import Image >>> import requests - >>> from transformers import AutoProcessor, NewTaskModelForNewTask + >>> from transformers import AutoProcessor, NewTaskModelForConditionalGeneration - >>> model = NewTaskModelForNewTask.from_pretrained("google/NewTaskModel-test-224px-hf") + >>> model = NewTaskModelForConditionalGeneration.from_pretrained("google/NewTaskModel-test-224px-hf") >>> processor = AutoProcessor.from_pretrained("google/NewTaskModel-test-224px-hf") >>> prompt = "answer en Where is the cow standing?" @@ -484,6 +416,7 @@ def prepare_inputs_for_generation( num_logits_to_keep=None, **kwargs, ): + # Overwritten -- custom `position_ids` and `pixel_values` handling model_inputs = self.language_model.prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, @@ -493,33 +426,10 @@ def prepare_inputs_for_generation( cache_position=cache_position, use_cache=use_cache, num_logits_to_keep=num_logits_to_keep, + token_type_ids=token_type_ids, **kwargs, ) - if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: - if model_inputs["inputs_embeds"] is not None: - batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape - device = model_inputs["inputs_embeds"].device - else: - batch_size, sequence_length = model_inputs["input_ids"].shape - device = model_inputs["input_ids"].device - - dtype = self.get_output_embeddings().weight.dtype - min_dtype = torch.finfo(dtype).min - - model_inputs["attention_mask"] = _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=past_key_values.get_max_length(), - dtype=dtype, - device=device, - min_dtype=min_dtype, - cache_position=cache_position, - batch_size=batch_size, - ) - - model_inputs["token_type_ids"] = token_type_ids - # position_ids in NewTaskModel are 1-indexed if model_inputs.get("position_ids") is not None: model_inputs["position_ids"] += 1 diff --git a/examples/modular-transformers/modeling_roberta.py b/examples/modular-transformers/modeling_roberta.py new file mode 100644 index 00000000000..47fea7d32d3 --- /dev/null +++ b/examples/modular-transformers/modeling_roberta.py @@ -0,0 +1,1017 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from examples/modular-transformers/modular_roberta.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_roberta.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +import math +import os +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from packaging import version + +from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa +from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + get_torch_version, + logging, +) +from .configuration_roberta import RobertaConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "google-roberta/roberta-base-uncased" +_CONFIG_FOR_DOC = "RobertaConfig" + + +class RobertaEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size, config.pad_token_id + ) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + self.register_buffer( + "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False + ) + self.pad_token_id = config.pad_token_id + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values_length: int = 0, + ) -> torch.Tensor: + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + + # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs + # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves + # issue #5664 + if token_type_ids is None: + if hasattr(self, "token_type_ids"): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class RobertaSelfAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute" + ) + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + + self.is_decoder = config.is_decoder + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + use_cache = past_key_value is not None + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + query_length, key_length = query_layer.shape[2], key_layer.shape[2] + if use_cache: + position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( + -1, 1 + ) + else: + position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + +class RobertaSdpaSelfAttention(RobertaSelfAttention): + def __init__(self, config, position_embedding_type=None): + super().__init__(config, position_embedding_type=position_embedding_type) + self.dropout_prob = config.attention_probs_dropout_prob + self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") + + # Adapted from RobertaSelfAttention + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None: + # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented. + logger.warning_once( + "RobertaSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support " + "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to " + "the manual attention implementation, but specifying the manual implementation will be required from " + "Transformers version v5.0.0 onwards. This warning can be removed using the argument " + '`attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + bsz, tgt_len, _ = hidden_states.size() + + query_layer = self.transpose_for_scores(self.query(hidden_states)) + + # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention + # mask needs to be such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + attention_mask = encoder_attention_mask if is_cross_attention else attention_mask + + # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning + if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: + key_layer, value_layer = past_key_value + else: + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) + if past_key_value is not None and not is_cross_attention: + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom + # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. + # Reference: https://github.com/pytorch/pytorch/issues/112577 + if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None: + query_layer = query_layer.contiguous() + key_layer = key_layer.contiguous() + value_layer = value_layer.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create + # a causal mask in case tgt_len == 1. + is_causal = ( + True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False + ) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attention_mask, + dropout_p=self.dropout_prob if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size) + + outputs = (attn_output,) + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + +class RobertaSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +ROBERTA_SELF_ATTENTION_CLASSES = { + "eager": RobertaSelfAttention, + "sdpa": RobertaSdpaSelfAttention, +} + + +class RobertaAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + self.self = ROBERTA_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, position_embedding_type=position_embedding_type + ) + self.output = RobertaSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class RobertaIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class RobertaOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class RobertaLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = RobertaAttention(config) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = RobertaAttention(config, position_embedding_type="absolute") + self.intermediate = RobertaIntermediate(config) + self.output = RobertaOutput(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + cross_attn_past_key_value, + output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class RobertaEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([RobertaLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class RobertaPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +def load_tf_weights_in_roberta(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + continue + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == "_embeddings": + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except ValueError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + +class RobertaPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = RobertaConfig + load_tf_weights = load_tf_weights_in_roberta + base_model_prefix = "roberta" + supports_gradient_checkpointing = True + _supports_sdpa = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +ROBERTA_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`RobertaConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +ROBERTA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`or `(batch_size, sequence_length, target_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Roberta Model transformer outputting raw hidden-states without any specific head on top.", + ROBERTA_START_DOCSTRING, +) +class RobertaModel(RobertaPreTrainedModel): + """ + + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in [Attention is + all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set + to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and + `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. + """ + + _no_split_modules = ["RobertaEmbeddings", "RobertaLayer"] + + def __init__(self, config): + super().__init__(config) + self.config = config + + self.embeddings = RobertaEmbeddings(config) + self.encoder = RobertaEncoder(config) + + self.pooler = RobertaPooler(config) if add_pooling_layer else None + + self.attn_implementation = config._attn_implementation + self.position_embedding_type = config.position_embedding_type + + # Initialize weights and apply final processing + self.post_init() + # Error out here. Why? Because `RobertaEmbeddings` is defined but not used. + # no, because it's defined, and RobertaModel should use RobertaEmbedding + # here if initialized that way it won't use the new embedding. + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPoolingAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, target_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device) + + use_sdpa_attention_masks = ( + self.attn_implementation == "sdpa" + and self.position_embedding_type == "absolute" + and head_mask is None + and not output_attentions + ) + + # Expand the attention mask + if use_sdpa_attention_masks and attention_mask.dim() == 2: + # Expand the attention mask for SDPA. + # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len] + if self.config.is_decoder: + extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + embedding_output, + past_key_values_length, + ) + else: + extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( + attention_mask, embedding_output.dtype, tgt_len=seq_length + ) + else: + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + + if use_sdpa_attention_masks and encoder_attention_mask.dim() == 2: + # Expand the attention mask for SDPA. + # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len] + encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length + ) + else: + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) diff --git a/examples/modular-transformers/modeling_super.py b/examples/modular-transformers/modeling_super.py index d91bdb1820c..7df04bcc2a9 100644 --- a/examples/modular-transformers/modeling_super.py +++ b/examples/modular-transformers/modeling_super.py @@ -1,26 +1,24 @@ -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from . -# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the diff. If any change should be done, please apply the change to the -# diff.py file directly. One of our CI enforces this -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from examples/modular-transformers/modular_super.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_super.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 import math from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F -import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN from ...cache_utils import Cache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import _flash_attention_forward -from ...modeling_outputs import ( - BaseModelOutputWithPast, -) +from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward +from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -33,59 +31,6 @@ logger = logging.get_logger(__name__) -def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - min_dtype: float, - cache_position: torch.Tensor, - batch_size: int, -): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to plcae the 4D attention mask on. - min_dtype (`float`): - The minimum value representable with the dtype `dtype`. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - class SuperRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -123,7 +68,7 @@ def __init__( if config is None: logger.warning_once( "`SuperRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.45" + "`config` argument. All other arguments will be removed in v4.46" ) self.rope_kwargs = { "rope_type": rope_type, @@ -193,40 +138,6 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - class SuperMLP(nn.Module): def __init__(self, config): super().__init__() @@ -261,6 +172,40 @@ def forward(self, x): return down_proj +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, @@ -302,7 +247,7 @@ def __init__(self, config: SuperConfig, layer_idx: Optional[int] = None): self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) - # TODO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers) + # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers) self.rotary_emb = SuperRotaryEmbedding(config=self.config) def forward( @@ -314,7 +259,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -349,7 +294,7 @@ def forward( logger.warning_once( "The attention layers in this model are transitioning from computing the RoPE embeddings internally " "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " "removed and `position_embeddings` will be mandatory." ) cos, sin = self.rotary_emb(value_states, position_ids) @@ -422,7 +367,8 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if isinstance(past_key_value, StaticCache): raise ValueError( @@ -449,7 +395,7 @@ def forward( logger.warning_once( "The attention layers in this model are transitioning from computing the RoPE embeddings internally " "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " "removed and `position_embeddings` will be mandatory." ) cos, sin = self.rotary_emb(value_states, position_ids) @@ -507,6 +453,7 @@ def forward( sliding_window=getattr(self, "sliding_window", None), use_top_left_mask=self._flash_attn_uses_top_left_mask, is_causal=self.is_causal, + **kwargs, ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() @@ -535,7 +482,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: @@ -569,7 +516,7 @@ def forward( logger.warning_once( "The attention layers in this model are transitioning from computing the RoPE embeddings internally " "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " "removed and `position_embeddings` will be mandatory." ) cos, sin = self.rotary_emb(value_states, position_ids) @@ -644,7 +591,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -790,7 +737,8 @@ def _init_weights(self, module): returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. Two formats are allowed: - - a [`~cache_utils.Cache`] instance; + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy cache format. @@ -916,10 +864,9 @@ def _update_causal_mask( return None dtype, device = input_tensor.dtype, input_tensor.device - min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] if using_static_cache: - target_length = past_key_values.get_max_length() + target_length = past_key_values.get_max_cache_shape() else: target_length = ( attention_mask.shape[-1] @@ -928,13 +875,12 @@ def _update_causal_mask( ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, dtype=dtype, device=device, - min_dtype=min_dtype, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -948,6 +894,63 @@ def _update_causal_mask( # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index f98af57a554..26665cb9ee0 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -747,7 +747,9 @@ def _merge_functions(self, functions: dict[str, cst.CSTNode], object_mapping: di """ # Add/overwrite all needed function nodes and dependencies self.functions.update(functions) - self.object_dependency_mapping.update({obj: dep for obj, dep in object_mapping.items() if obj in functions.keys()}) + self.object_dependency_mapping.update( + {obj: dep for obj, dep in object_mapping.items() if obj in functions.keys()} + ) def _merge_assignments(self, assignments: dict[str, cst.CSTNode], object_mapping: dict[str, set]): """Update the global nodes with the assignment from the modular file. @@ -1040,12 +1042,14 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None: and save it in `self.model_specific_modules` to later visit. The imported objects are saved in `self.model_specific_imported_objects`. """ import_module = self.python_module.code_for_node(node.module) - import_statement = "."*len(node.relative) + import_module + import_statement = "." * len(node.relative) + import_module if any(import_to_skip in import_statement for import_to_skip in IMPORTS_TO_SKIP_IN_MODULAR): return if m.matches(node.module, m.Attribute()): for imported_ in node.names: - _import = re.search(rf"(?:transformers\.models\.)|(?:\.\.)\w+\.({self.match_patterns})_.*", import_statement) + _import = re.search( + rf"(?:transformers\.models\.)|(?:\.\.)\w+\.({self.match_patterns})_.*", import_statement + ) if _import: source = _import.group(1) if source == "modeling" and "Config" in self.python_module.code_for_node(imported_): @@ -1081,7 +1085,7 @@ def visit_SimpleStatementLine(self, node): self.imports.append(node) elif m.matches(node, m.SimpleStatementLine(body=[m.ImportFrom()])): import_module = self.python_module.code_for_node(node.body[0].module) - import_statement = "."*len(node.body[0].relative) + import_module + import_statement = "." * len(node.body[0].relative) + import_module if not ( re.search(rf"(?:transformers\.models\.)|(?:\.\.)\w+\.({self.match_patterns})_.*", import_statement) and not any(import_to_skip in import_statement for import_to_skip in IMPORTS_TO_SKIP_IN_MODULAR) From efdbe788c95605238d44fe78119f4fa282dc8108 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 30 Oct 2024 23:16:36 +0100 Subject: [PATCH 28/40] fix modular_roberta example (wrong redefinition of __init__) --- examples/modular-transformers/modeling_roberta.py | 5 +---- examples/modular-transformers/modular_roberta.py | 7 ++----- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/examples/modular-transformers/modeling_roberta.py b/examples/modular-transformers/modeling_roberta.py index 47fea7d32d3..e50cf60c3a4 100644 --- a/examples/modular-transformers/modeling_roberta.py +++ b/examples/modular-transformers/modeling_roberta.py @@ -816,7 +816,7 @@ class RobertaModel(RobertaPreTrainedModel): _no_split_modules = ["RobertaEmbeddings", "RobertaLayer"] - def __init__(self, config): + def __init__(self, config, add_pooling_layer=True): super().__init__(config) self.config = config @@ -830,9 +830,6 @@ def __init__(self, config): # Initialize weights and apply final processing self.post_init() - # Error out here. Why? Because `RobertaEmbeddings` is defined but not used. - # no, because it's defined, and RobertaModel should use RobertaEmbedding - # here if initialized that way it won't use the new embedding. def get_input_embeddings(self): return self.embeddings.word_embeddings diff --git a/examples/modular-transformers/modular_roberta.py b/examples/modular-transformers/modular_roberta.py index a3e0218f932..8ca74e52674 100644 --- a/examples/modular-transformers/modular_roberta.py +++ b/examples/modular-transformers/modular_roberta.py @@ -13,8 +13,5 @@ def __init__(self, config): class RobertaModel(BertModel): - def __init__(self, config): - super().__init__(self, config) - # Error out here. Why? Because `RobertaEmbeddings` is defined but not used. - # no, because it's defined, and RobertaModel should use RobertaEmbedding - # here if initialized that way it won't use the new embedding. + def __init__(self, config, add_pooling_layer=True): + super().__init__(self, config) \ No newline at end of file From e8fe360177aa9f2fcd1638f9fb98827d778ec039 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 31 Oct 2024 09:38:05 +0100 Subject: [PATCH 29/40] slightly correct order in which dependencies will appear --- examples/modular-transformers/modeling_dummy.py | 14 +++++++------- utils/modular_model_converter.py | 6 +++--- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py index ed87fb66f0a..ed7e3c64d7a 100644 --- a/examples/modular-transformers/modeling_dummy.py +++ b/examples/modular-transformers/modeling_dummy.py @@ -172,6 +172,13 @@ def forward(self, x): return down_proj +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 4] + x2 = x[..., x.shape[-1] // 4 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -211,13 +218,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 4] - x2 = x[..., x.shape[-1] // 4 :] - return torch.cat((-x2, x1), dim=-1) - - class DummyAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index 26665cb9ee0..dbebe17a1d9 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -701,10 +701,10 @@ def compute_relative_order(self, missing_dependencies: set[str]) -> dict[str, in # We need to differentiate between nodes that were already present (we can get relative order globally) and # nodes that were merged (we can get relative order only relative to the class the dependencies relate to) for class_dep in class_dependencies: - if class_dep in self.modular_file_start_lines: - merged_dependencies.append(class_dep) - else: + if class_dep in self.start_lines: original_dependencies.append(class_dep) + else: + merged_dependencies.append(class_dep) # Sort both list according to the order in their respective file original_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines[x]) merged_dependencies = sorted(merged_dependencies, key=lambda x: self.modular_file_start_lines[x]) From dea43c8810c9d05d2a790384698a347366864af0 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 31 Oct 2024 09:38:51 +0100 Subject: [PATCH 30/40] style --- examples/modular-transformers/modular_roberta.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/modular-transformers/modular_roberta.py b/examples/modular-transformers/modular_roberta.py index 8ca74e52674..13dca4845c1 100644 --- a/examples/modular-transformers/modular_roberta.py +++ b/examples/modular-transformers/modular_roberta.py @@ -14,4 +14,4 @@ def __init__(self, config): class RobertaModel(BertModel): def __init__(self, config, add_pooling_layer=True): - super().__init__(self, config) \ No newline at end of file + super().__init__(self, config) From 9a8a7e07bc47174a39c215e59627a9a963a629b1 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 31 Oct 2024 09:46:22 +0100 Subject: [PATCH 31/40] review comments --- utils/modular_model_converter.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index dbebe17a1d9..ecdf85d6ebe 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -499,7 +499,7 @@ def visit_Name(self, node): self.dependencies.add(node.value) -def dependencies_for_class_node(node: cst.ClassDef, global_names: set) -> set: +def dependencies_for_class_node(node: cst.ClassDef, global_names: set[str]) -> set: """Create immediate dependencies for a class node based on the `global_names`.""" temp_module = cst.Module(body=[node]) wrapper = MetadataWrapper(temp_module) @@ -661,7 +661,7 @@ def compute_class_dependencies(self): @abstractmethod def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]: - pass + raise NotImplementedError class ModelFileMapper(ModuleMapper): @@ -1214,7 +1214,7 @@ class node based on the inherited classes if needed. file_type = find_file_type(class_name) file_to_update = files[file_type] - # We need to replace the class node with the super class node + # We need to replace the class node with the transformers (modeling file) super class node if len(bases) == 1: super_class = bases[0] super_file_name = modular_mapper.model_specific_imported_objects[super_class] @@ -1242,7 +1242,8 @@ class node based on the inherited classes if needed. dep: (relative_dependency_order[dep], mapper.global_nodes[dep]) for dep in all_dependencies_to_add } - # No super class, just check functions and assignments dependency in the imports from other modeling files + # No transformers (modeling file) super class, just check functions and assignments dependency in the imports from + # other modeling files else: updated_node = node # The node was NOT modified -> no need to look recursively for other class dependencies. Indeed, even if they are not From 38a574ace801df332b2a0aa8fc63362d3c490212 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 31 Oct 2024 12:37:49 +0100 Subject: [PATCH 32/40] Performance + better handling of dependencies when they are imported --- utils/modular_model_converter.py | 72 +++++++++++++++++++++++--------- 1 file changed, 52 insertions(+), 20 deletions(-) diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index ecdf85d6ebe..679007cab15 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -479,46 +479,46 @@ class ClassDependencyMapper(CSTVisitor): `global_names`. """ - METADATA_DEPENDENCIES = (ParentNodeProvider,) - - def __init__(self, class_name: str, global_names: set | None): + def __init__(self, class_name: str, global_names: set[str], objects_imported_from_modeling: set[str] | None = None): super().__init__() self.class_name = class_name self.dependencies = set() self.global_names = global_names + self.objects_imported_from_modeling = set() if objects_imported_from_modeling is None else objects_imported_from_modeling def visit_Name(self, node): - if node.value != self.class_name and node.value in self.global_names: - parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) - # If it is only an annotation inside a method definition, do not add dependency (however, do it for - # annotations that are variable definitions, i.e. for Kwargs classes) - if m.matches(parent_node, m.Annotation()): - grand_parent = self.get_metadata(cst.metadata.ParentNodeProvider, parent_node) - if m.matches(grand_parent, m.Param() | m.FunctionDef()): - return + if node.value != self.class_name and node.value in self.global_names and node.value not in self.objects_imported_from_modeling: self.dependencies.add(node.value) def dependencies_for_class_node(node: cst.ClassDef, global_names: set[str]) -> set: """Create immediate dependencies for a class node based on the `global_names`.""" temp_module = cst.Module(body=[node]) - wrapper = MetadataWrapper(temp_module) visitor = ClassDependencyMapper(node.name.value, global_names) - wrapper.visit(visitor) + temp_module.visit(visitor) return visitor.dependencies -def augmented_dependencies_for_class_node(node: cst.ClassDef, mapper: "ModuleMapper") -> set: +def augmented_dependencies_for_class_node(node: cst.ClassDef, mapper: "ModuleMapper", objects_imported_from_modeling: set[str] | None = None) -> set: """Create augmented dependencies for a class node based on a `mapper`. Augmented dependencies means immediate dependencies + recursive function and assignments dependencies. """ temp_module = cst.Module(body=[node]) - wrapper = MetadataWrapper(temp_module) - visitor = ClassDependencyMapper(node.name.value, set(mapper.global_nodes.keys())) - wrapper.visit(visitor) + visitor = ClassDependencyMapper(node.name.value, set(mapper.global_nodes.keys()), objects_imported_from_modeling) + temp_module.visit(visitor) return mapper.augment_dependencies(visitor.dependencies) +# All the potential file types to create +ALL_FILE_TYPES = ( + "modeling", + "configuration", + "tokenization", + "processing", + "image_processing", + "feature_extractor", +) + class ModuleMapper(CSTVisitor, ABC): """An abstract visitor class which analyses a module, creating a mapping of dependencies for classes, functions and assignments. Class dependencies are computed with `compute_class_dependencies()`, while function and assignment dependencies are stored in @@ -539,8 +539,26 @@ def __init__(self, python_module: cst.Module): self.assignments: Dict[str, cst.SimpleStatementLine] = {} # mapping of global assignments names to Nodes self.current_function = None # this keeps track of the current module-scope function self.current_assignment = None # this keeps track of the current module-scope assignment + # this keeps track of objects imported from modeling files (`from .configuration import Config`) -> `Config` should not be a dependency + self.objects_imported_from_modeling = set() + # regex pattern joining every possible file type + self.match_patterns = "|".join(ALL_FILE_TYPES) # fmt: on + def visit_ImportFrom(self, node): + """This keeps track of objects imported from neighbor modeling files (e.g. in `modeling_xxx.py, we have + `from .configuration_xxx import Config`, then `Config` should be recorded as it is not a dependency that needs + to be added (because it will be part of the imports)""" + import_module = self.python_module.code_for_node(node.module) + import_statement = "." * len(node.relative) + import_module + if re.search(rf"^\.({self.match_patterns})_.*", import_statement): + for imported_object in node.names: + # If an alias is present, we record it and not the original name + if imported_object.evaluated_alias is not None: + self.objects_imported_from_modeling.add(imported_object.evaluated_alias) + else: + self.objects_imported_from_modeling.add(imported_object.evaluated_name) + def visit_SimpleStatementLine(self, node): """ Global Assigns like `GEMMA_INPUT_DOCSTRING = 'THIS IS THE INPUT'` and all import statements @@ -1033,7 +1051,6 @@ def __init__(self, python_module, new_name, given_old_name=None, given_new_name= self.model_specific_imported_objects: Dict[str, str] = {} # e.g. {"LlamaModel": "transformers.models.llama.modeling_llama"} self.model_specific_modules: Dict[str, cst.Module] = {} # e.g. {"transformers.models.llama.modeling_llama": cst.Module} - self.match_patterns = "|".join(list(TYPE_TO_FILE_TYPE.values()) + ["modeling"]) self.all_all_to_add = {} # fmt: on @@ -1135,6 +1152,14 @@ def leave_Module(self, node): # definitions found in the visited files self.merge_model_specific_imports(self.visited_modules) + # We need to keep track of which objects were imported directly into which modeling file to not add them wrongly later + # Note that we may visit several of the same file types, thus we save them per file type, not file + self.imported_objects_per_file = defaultdict(set) + for file, mapper in self.visited_modules.items(): + file_type = re.search(rf"^transformers\.models\.\w+\.({self.match_patterns})_.*", file).group(1) + self.imported_objects_per_file[file_type].update(mapper.objects_imported_from_modeling) + print(self.imported_objects_per_file) + def merge_model_specific_imports(self, visited_modules): """Merge the functions and assignments imported from the modeling files to the modular nodes and dependency graph, based on the visited files.""" @@ -1214,6 +1239,13 @@ class node based on the inherited classes if needed. file_type = find_file_type(class_name) file_to_update = files[file_type] + # This is used to avoid adding objects to the dependencies graph if they will be imported + # e.g. Config is imported in modeling, it should not be redefined into it + if file_type in modular_mapper.imported_objects_per_file: + imported_objects = modular_mapper.imported_objects_per_file[file_type] + else: + imported_objects = None + # We need to replace the class node with the transformers (modeling file) super class node if len(bases) == 1: super_class = bases[0] @@ -1230,7 +1262,7 @@ class node based on the inherited classes if needed. updated_node = replace_class_node(mapper, node, renamed_super_class) # The node was modified -> look for all dependencies (recursively) of the new node - new_node_dependencies = augmented_dependencies_for_class_node(updated_node, mapper) + new_node_dependencies = augmented_dependencies_for_class_node(updated_node, mapper, imported_objects) all_dependencies_to_add = find_all_dependencies( dependency_mapping=mapper.class_dependency_mapping, initial_dependencies=new_node_dependencies, @@ -1248,7 +1280,7 @@ class node based on the inherited classes if needed. updated_node = node # The node was NOT modified -> no need to look recursively for other class dependencies. Indeed, even if they are not # already defined (which would mean a weird order of the code in the modular...), they will be in the future - all_dependencies_to_add = augmented_dependencies_for_class_node(updated_node, modular_mapper) + all_dependencies_to_add = augmented_dependencies_for_class_node(updated_node, modular_mapper, imported_objects) relative_dependency_order = modular_mapper.compute_relative_order(all_dependencies_to_add) nodes_to_add = { From 0b7c10396a8fef25c941d319bd83a02e61ca2ada Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 31 Oct 2024 12:38:54 +0100 Subject: [PATCH 33/40] style --- utils/modular_model_converter.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index 679007cab15..5c5804d4f5d 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -479,15 +479,23 @@ class ClassDependencyMapper(CSTVisitor): `global_names`. """ - def __init__(self, class_name: str, global_names: set[str], objects_imported_from_modeling: set[str] | None = None): + def __init__( + self, class_name: str, global_names: set[str], objects_imported_from_modeling: set[str] | None = None + ): super().__init__() self.class_name = class_name self.dependencies = set() self.global_names = global_names - self.objects_imported_from_modeling = set() if objects_imported_from_modeling is None else objects_imported_from_modeling + self.objects_imported_from_modeling = ( + set() if objects_imported_from_modeling is None else objects_imported_from_modeling + ) def visit_Name(self, node): - if node.value != self.class_name and node.value in self.global_names and node.value not in self.objects_imported_from_modeling: + if ( + node.value != self.class_name + and node.value in self.global_names + and node.value not in self.objects_imported_from_modeling + ): self.dependencies.add(node.value) @@ -499,7 +507,9 @@ def dependencies_for_class_node(node: cst.ClassDef, global_names: set[str]) -> s return visitor.dependencies -def augmented_dependencies_for_class_node(node: cst.ClassDef, mapper: "ModuleMapper", objects_imported_from_modeling: set[str] | None = None) -> set: +def augmented_dependencies_for_class_node( + node: cst.ClassDef, mapper: "ModuleMapper", objects_imported_from_modeling: set[str] | None = None +) -> set: """Create augmented dependencies for a class node based on a `mapper`. Augmented dependencies means immediate dependencies + recursive function and assignments dependencies. """ @@ -519,6 +529,7 @@ def augmented_dependencies_for_class_node(node: cst.ClassDef, mapper: "ModuleMap "feature_extractor", ) + class ModuleMapper(CSTVisitor, ABC): """An abstract visitor class which analyses a module, creating a mapping of dependencies for classes, functions and assignments. Class dependencies are computed with `compute_class_dependencies()`, while function and assignment dependencies are stored in @@ -557,7 +568,7 @@ def visit_ImportFrom(self, node): if imported_object.evaluated_alias is not None: self.objects_imported_from_modeling.add(imported_object.evaluated_alias) else: - self.objects_imported_from_modeling.add(imported_object.evaluated_name) + self.objects_imported_from_modeling.add(imported_object.evaluated_name) def visit_SimpleStatementLine(self, node): """ @@ -1158,7 +1169,6 @@ def leave_Module(self, node): for file, mapper in self.visited_modules.items(): file_type = re.search(rf"^transformers\.models\.\w+\.({self.match_patterns})_.*", file).group(1) self.imported_objects_per_file[file_type].update(mapper.objects_imported_from_modeling) - print(self.imported_objects_per_file) def merge_model_specific_imports(self, visited_modules): """Merge the functions and assignments imported from the modeling files to the modular nodes and dependency graph, @@ -1239,7 +1249,7 @@ class node based on the inherited classes if needed. file_type = find_file_type(class_name) file_to_update = files[file_type] - # This is used to avoid adding objects to the dependencies graph if they will be imported + # This is used to avoid adding objects to the dependencies graph if they will be imported # e.g. Config is imported in modeling, it should not be redefined into it if file_type in modular_mapper.imported_objects_per_file: imported_objects = modular_mapper.imported_objects_per_file[file_type] From dde85dc30e11ecd3e7779f1854ff92c096245104 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 31 Oct 2024 16:45:41 +0100 Subject: [PATCH 34/40] Add advanced new classes capabilities --- utils/modular_model_converter.py | 122 ++++++++++++++++++++++++------- 1 file changed, 94 insertions(+), 28 deletions(-) diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index 5c5804d4f5d..ce76fb9b726 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -543,7 +543,7 @@ class ModuleMapper(CSTVisitor, ABC): def __init__(self, python_module: cst.Module): # fmt: off self.python_module: cst.Module = python_module # original cst.Module being visited - self.classes: Dict[str, cst.ClassDef] = {} # mapping from class names to Nodes + self.classes: Dict[str, cst.ClassDef] = {} # mapping from class names to Nodes (it will be ordered by default!!) self.imports = [] # stores all import statements self.functions: Dict[str, cst.FunctionDef] = {} # mapping of global scope function names to Nodes self.object_dependency_mapping = defaultdict(set) # immediate function/assignment dependency mapping (i.e. dependencies immediately in the function/assignment definition) @@ -674,14 +674,7 @@ def augment_dependencies(self, dependencies: set[str]) -> set[str]: def compute_class_dependencies(self): """For each visited class, find its dependencies based on visiting the current file + potential merged dependencies. - Note: This function takes care of updating `global_nodes` and `object_recursive_dependency_mapping` as well after the - merge with other files dependencies. """ - # Correctly re-set the global nodes at this point - self.global_nodes = {**self.assignments, **self.classes, **self.functions} - # Create the global mapping of recursive dependencies for functions and assignments - self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies() - self.class_dependency_mapping = {} for class_name, class_node in self.classes.items(): dependencies = dependencies_for_class_node(class_node, set(self.global_nodes.keys())) @@ -783,8 +776,8 @@ def _merge_functions(self, functions: dict[str, cst.CSTNode], object_mapping: di def _merge_assignments(self, assignments: dict[str, cst.CSTNode], object_mapping: dict[str, set]): """Update the global nodes with the assignment from the modular file. - Merging rule: if any assignment with the same name was redefined in the modular, we use it ONLY if it is - in `ASSIGNMENTS_TO_KEEP`. Otherwise, we use the original value. This rule was chosen to avoid having to rewrite the + Merging rule: if any assignment with the same name was redefined in the modular, we use it and its dependencies ONLY if it is + in `ASSIGNMENTS_TO_KEEP`. Otherwise, we use the original value and dependencies. This rule was chosen to avoid having to rewrite the big docstrings. """ for assignment, node in assignments.items(): @@ -793,22 +786,40 @@ def _merge_assignments(self, assignments: dict[str, cst.CSTNode], object_mapping if assignment in object_mapping: self.object_dependency_mapping[assignment] = object_mapping[assignment] - def merge_modular_dependencies(self, functions, assignments, object_mapping, start_lines): - """Merge both functions and assignments from the modular definitions into the current module file, - then record the relative order of all nodes.""" + def _merge_classes(self, classes: dict[str, cst.CSTNode]): + """Update the global nodes with the new classes from the modular. We do NOT update any dependency mapping here. + This is because we only need the names of newly defined classes in the modular to be discoverable when computing dependencies + for new nodes later on. For this reason, we do not add the new classes to `self.classes`, but only to `global_nodes`. + """ + # Add/overwrite all needed function nodes and dependencies + self.global_nodes.update({name: node for name, node in classes.items() if name not in self.classes}) + + def merge_modular_dependencies(self, classes, functions, assignments, object_mapping, start_lines): + """Merge classes, functions and assignments from the modular definitions into the current module file, + then record the relative order of all nodes. + Note: This function takes care of updating `global_nodes` and `object_recursive_dependency_mapping` as well after the + merge with other files dependencies. + """ self._merge_functions(functions, object_mapping) self._merge_assignments(assignments, object_mapping) + self._merge_classes(classes) self.modular_file_start_lines = start_lines + # Correctly re-set the global nodes at this point + self.global_nodes.update(self.functions) + self.global_nodes.update(self.assignments) + # Create the global mapping of recursive dependencies for functions and assignments + self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies() + @classmethod def visit_and_merge_dependencies( - cls, module: cst.Module, functions, assignments, object_mapping, start_lines + cls, module: cst.Module, classes, functions, assignments, object_mapping, start_lines ) -> "ModelFileMapper": wrapper = MetadataWrapper(module) mapper = cls(module) wrapper.visit(mapper) # Merge dependencies - mapper.merge_modular_dependencies(functions, assignments, object_mapping, start_lines) + mapper.merge_modular_dependencies(classes, functions, assignments, object_mapping, start_lines) # Create the class dependencies graph mapper.compute_class_dependencies() return mapper @@ -1151,6 +1162,7 @@ def leave_Module(self, node): renamed_module = module.visit(renamer) self.visited_modules[file] = ModelFileMapper.visit_and_merge_dependencies( renamed_module, + self.classes, self.functions, self.assignments, self.object_dependency_mapping, @@ -1234,11 +1246,52 @@ def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]: return relative_order +def check_dependencies_and_create_import_node(file_type: str, new_dependencies: set[str], mapper: ModuleMapper, + new_name: str) -> tuple[set[str], dict[str, cst.CSTNode]]: + """Check that all class nodes in the `new_dependencies` belong to the correct `file_type`. If this is not the case, + we need to remove it from the dependencies, and create a new import to it instead. + This scenario may appear in the following case: + If a new class in the `modular_xxx.py` file does not belong to `type_xxx.py`, but is used somewhere in `other_type_xxx.py` + (e.g. as a type hint), but none of the visited files had a similar class, then it would be imported in `type_xxx.py` as + part of the standard dependency graph (because we never encountered an import towards this new class in any file). + For example imagine the following `modular.py`: + ``` + from ..llama.modeling_llama import LlamaModel + + class NewNameTextConfig(PretrainedConfig): + ... + + class NewNameConfig(PretrainedConfig): + ... + + class NewNameModel(LlamaModel): + config = NewNameConfig() + text_config = NewNameTextConfig() + ... + ``` + then without the help of this function, `NewNameTextConfig` would be imported in the `modeling_newname.py` as well as + `configuration_newname.py`, because `modeling_llama.py` tells us to not import `NewNameConfig`, but has no + knowledge of `NewNameTextConfig`. + """ + class_dependencies = {dep for dep in new_dependencies if m.matches(mapper.global_nodes[dep], m.ClassDef())} + corrected_dependencies = new_dependencies.copy() + new_imports = {} + for class_name in class_dependencies: + class_file_type = find_file_type(class_name) + # In this case, we need to remove it from the dependencies and create a new import instead + if class_file_type != file_type: + corrected_dependencies.remove(class_name) + import_statement = f"from .{class_file_type}_{new_name} import {class_name}" + new_imports[class_name] = cst.parse_statement(import_statement) + + return corrected_dependencies, new_imports + def get_class_node_and_dependencies( modular_mapper: ModularFileMapper, class_name: str, node: cst.CSTNode, files: dict[str, dict] -) -> tuple[dict, str]: +) -> tuple[dict, str, dict]: """Return a single class node (and all its dependency nodes), to be added to the `files`. It creates the new - class node based on the inherited classes if needed. + class node based on the inherited classes if needed. Also returns any new imports of a new class defined in + the modular that we nay need. """ bases = [k.value.value for k in node.bases if k.value.value in modular_mapper.model_specific_imported_objects] if len(bases) > 1: @@ -1248,13 +1301,10 @@ class node based on the inherited classes if needed. file_type = find_file_type(class_name) file_to_update = files[file_type] + model_name = modular_mapper.model_name - # This is used to avoid adding objects to the dependencies graph if they will be imported - # e.g. Config is imported in modeling, it should not be redefined into it - if file_type in modular_mapper.imported_objects_per_file: - imported_objects = modular_mapper.imported_objects_per_file[file_type] - else: - imported_objects = None + # This is used to avoid adding objects to the dependencies graph if they will be imported already + imported_objects = modular_mapper.imported_objects_per_file[file_type] # We need to replace the class node with the transformers (modeling file) super class node if len(bases) == 1: @@ -1271,8 +1321,15 @@ class node based on the inherited classes if needed. # Create the new class node updated_node = replace_class_node(mapper, node, renamed_super_class) - # The node was modified -> look for all dependencies (recursively) of the new node + # Grab all immediate dependencies of the new node new_node_dependencies = augmented_dependencies_for_class_node(updated_node, mapper, imported_objects) + + # At this point, if any class dependency is found, but belongs to another file, it means that we need to remove + # it from the dependencies, and add a new import of it instead + new_node_dependencies, new_imports = check_dependencies_and_create_import_node(file_type, new_node_dependencies, + mapper, model_name) + + # The node was modified -> look for all recursive dependencies of the new node all_dependencies_to_add = find_all_dependencies( dependency_mapping=mapper.class_dependency_mapping, initial_dependencies=new_node_dependencies, @@ -1284,14 +1341,18 @@ class node based on the inherited classes if needed. dep: (relative_dependency_order[dep], mapper.global_nodes[dep]) for dep in all_dependencies_to_add } - # No transformers (modeling file) super class, just check functions and assignments dependency in the imports from - # other modeling files + # No transformers (modeling file) super class, just check functions and assignments dependencies else: updated_node = node # The node was NOT modified -> no need to look recursively for other class dependencies. Indeed, even if they are not # already defined (which would mean a weird order of the code in the modular...), they will be in the future all_dependencies_to_add = augmented_dependencies_for_class_node(updated_node, modular_mapper, imported_objects) + # At this point, if any class dependency is found, but belongs to another file, it means that we need to remove + # it from the dependencies, and add a new import of it instead + all_dependencies_to_add, new_imports = check_dependencies_and_create_import_node(file_type, all_dependencies_to_add, + modular_mapper, model_name) + relative_dependency_order = modular_mapper.compute_relative_order(all_dependencies_to_add) nodes_to_add = { dep: (relative_dependency_order[dep], modular_mapper.global_nodes[dep]) @@ -1303,7 +1364,7 @@ class node based on the inherited classes if needed. class_idx = max(relative_dependency_order.values()) + 1 if len(relative_dependency_order) > 0 else 0 nodes_to_add[class_name] = (class_idx, updated_node) - return nodes_to_add, file_type + return nodes_to_add, file_type, new_imports def create_modules(modular_mapper: ModularFileMapper) -> dict[str, cst.Module]: @@ -1313,7 +1374,12 @@ def create_modules(modular_mapper: ModularFileMapper) -> dict[str, cst.Module]: # For each class defined in modular, potentially replace the node and add it with its dependencies for class_name, node in modular_mapper.classes.items(): - nodes_to_add, file_type = get_class_node_and_dependencies(modular_mapper, class_name, node, files) + nodes_to_add, file_type, new_imports = get_class_node_and_dependencies(modular_mapper, class_name, node, files) + + # Add the new potential new imports that we may need to the `modular_mapper` variable + modular_mapper.imported_objects_per_file[file_type].update(new_imports.keys()) + modular_mapper.imports.extend(list(new_imports.values())) + # Sort the nodes according to their relative order nodes_to_add = sorted(nodes_to_add.items(), key=lambda x: x[1][0]) # Write all nodes to file From 7cd1eff6bb42c886764f7fc978fcdd5f2619fdb4 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 31 Oct 2024 16:46:38 +0100 Subject: [PATCH 35/40] style --- utils/modular_model_converter.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index ce76fb9b726..9db60eb6110 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -673,8 +673,7 @@ def augment_dependencies(self, dependencies: set[str]) -> set[str]: return new_dependencies def compute_class_dependencies(self): - """For each visited class, find its dependencies based on visiting the current file + potential merged dependencies. - """ + """For each visited class, find its dependencies based on visiting the current file + potential merged dependencies.""" self.class_dependency_mapping = {} for class_name, class_node in self.classes.items(): dependencies = dependencies_for_class_node(class_node, set(self.global_nodes.keys())) @@ -1246,8 +1245,9 @@ def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]: return relative_order -def check_dependencies_and_create_import_node(file_type: str, new_dependencies: set[str], mapper: ModuleMapper, - new_name: str) -> tuple[set[str], dict[str, cst.CSTNode]]: +def check_dependencies_and_create_import_node( + file_type: str, new_dependencies: set[str], mapper: ModuleMapper, new_name: str +) -> tuple[set[str], dict[str, cst.CSTNode]]: """Check that all class nodes in the `new_dependencies` belong to the correct `file_type`. If this is not the case, we need to remove it from the dependencies, and create a new import to it instead. This scenario may appear in the following case: @@ -1260,7 +1260,7 @@ def check_dependencies_and_create_import_node(file_type: str, new_dependencies: class NewNameTextConfig(PretrainedConfig): ... - + class NewNameConfig(PretrainedConfig): ... @@ -1286,6 +1286,7 @@ class NewNameModel(LlamaModel): return corrected_dependencies, new_imports + def get_class_node_and_dependencies( modular_mapper: ModularFileMapper, class_name: str, node: cst.CSTNode, files: dict[str, dict] ) -> tuple[dict, str, dict]: @@ -1326,8 +1327,9 @@ class node based on the inherited classes if needed. Also returns any new import # At this point, if any class dependency is found, but belongs to another file, it means that we need to remove # it from the dependencies, and add a new import of it instead - new_node_dependencies, new_imports = check_dependencies_and_create_import_node(file_type, new_node_dependencies, - mapper, model_name) + new_node_dependencies, new_imports = check_dependencies_and_create_import_node( + file_type, new_node_dependencies, mapper, model_name + ) # The node was modified -> look for all recursive dependencies of the new node all_dependencies_to_add = find_all_dependencies( @@ -1350,8 +1352,9 @@ class node based on the inherited classes if needed. Also returns any new import # At this point, if any class dependency is found, but belongs to another file, it means that we need to remove # it from the dependencies, and add a new import of it instead - all_dependencies_to_add, new_imports = check_dependencies_and_create_import_node(file_type, all_dependencies_to_add, - modular_mapper, model_name) + all_dependencies_to_add, new_imports = check_dependencies_and_create_import_node( + file_type, all_dependencies_to_add, modular_mapper, model_name + ) relative_dependency_order = modular_mapper.compute_relative_order(all_dependencies_to_add) nodes_to_add = { From cc58d435df1e5a874686402c3e681614795cd9b0 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 31 Oct 2024 16:53:06 +0100 Subject: [PATCH 36/40] add forgotten check --- utils/modular_model_converter.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index 9db60eb6110..b1dfa18a7a9 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -786,12 +786,19 @@ def _merge_assignments(self, assignments: dict[str, cst.CSTNode], object_mapping self.object_dependency_mapping[assignment] = object_mapping[assignment] def _merge_classes(self, classes: dict[str, cst.CSTNode]): - """Update the global nodes with the new classes from the modular. We do NOT update any dependency mapping here. - This is because we only need the names of newly defined classes in the modular to be discoverable when computing dependencies - for new nodes later on. For this reason, we do not add the new classes to `self.classes`, but only to `global_nodes`. + """Update the global nodes with the new classes from the modular (i.e. classes which do not exist in current file, and + are not imported). We do NOT update any dependency mapping here. This is because we only need the names of newly defined + classes in the modular to be discoverable when computing dependencies for new nodes later on. For this reason, we + do not add the new classes to `self.classes`, but only to `global_nodes`. """ # Add/overwrite all needed function nodes and dependencies - self.global_nodes.update({name: node for name, node in classes.items() if name not in self.classes}) + self.global_nodes.update( + { + name: node + for name, node in classes.items() + if name not in self.classes and name not in self.objects_imported_from_modeling + } + ) def merge_modular_dependencies(self, classes, functions, assignments, object_mapping, start_lines): """Merge classes, functions and assignments from the modular definitions into the current module file, From f05849ab5d6cfa3be74acaf3ac84bead89cafb32 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 31 Oct 2024 17:35:35 +0100 Subject: [PATCH 37/40] Update modeling_llava_next_video.py --- .../models/llava_next_video/modeling_llava_next_video.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index 1eb94508ee4..73118f4bfcd 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -286,12 +286,12 @@ def unpad_image(tensor, original_size): if original_aspect_ratio > current_aspect_ratio: scale_factor = current_width / original_width - new_height = int(original_height * scale_factor) + new_height = int(round(original_height * scale_factor, 7)) padding = (current_height - new_height) // 2 unpadded_tensor = tensor[:, padding : current_height - padding, :] else: scale_factor = current_height / original_height - new_width = int(original_width * scale_factor) + new_width = int(round(original_width * scale_factor, 7)) padding = (current_width - new_width) // 2 unpadded_tensor = tensor[:, :, padding : current_width - padding] From be70f7de0a6dd5760001720e1ff45e937b78c512 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 31 Oct 2024 18:15:13 +0100 Subject: [PATCH 38/40] Add prority list ordering in check_conversion as well --- utils/check_modular_conversion.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/utils/check_modular_conversion.py b/utils/check_modular_conversion.py index 09b237c1e6c..97692d5bccc 100644 --- a/utils/check_modular_conversion.py +++ b/utils/check_modular_conversion.py @@ -6,6 +6,7 @@ # Console for rich printing from modular_model_converter import convert_modular_file +from create_dependency_mapping import find_priority_list from rich.console import Console from rich.syntax import Syntax @@ -69,7 +70,7 @@ def compare_files(modular_file_path, fix_and_overwrite=False): if args.files == ["all"]: args.files = glob.glob("src/transformers/models/**/modular_*.py", recursive=True) non_matching_files = 0 - for modular_file_path in args.files: + for modular_file_path in find_priority_list(args.files): non_matching_files += compare_files(modular_file_path, args.fix_and_overwrite) if non_matching_files and not args.fix_and_overwrite: From c8a4d4d5fbab9f1cdf4a189734acd1b720d7924d Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 31 Oct 2024 18:18:32 +0100 Subject: [PATCH 39/40] Update check_modular_conversion.py --- utils/check_modular_conversion.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/utils/check_modular_conversion.py b/utils/check_modular_conversion.py index 97692d5bccc..86af396e03a 100644 --- a/utils/check_modular_conversion.py +++ b/utils/check_modular_conversion.py @@ -4,9 +4,10 @@ import logging from io import StringIO +from create_dependency_mapping import find_priority_list + # Console for rich printing from modular_model_converter import convert_modular_file -from create_dependency_mapping import find_priority_list from rich.console import Console from rich.syntax import Syntax From cfec75d14c97a58e27faa97ec59f5e0acde0c743 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 1 Nov 2024 09:44:54 +0100 Subject: [PATCH 40/40] Update configuration_gemma.py --- src/transformers/models/gemma/configuration_gemma.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/gemma/configuration_gemma.py b/src/transformers/models/gemma/configuration_gemma.py index 346f386ba69..e170803ccca 100644 --- a/src/transformers/models/gemma/configuration_gemma.py +++ b/src/transformers/models/gemma/configuration_gemma.py @@ -20,6 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + from ...configuration_utils import PretrainedConfig