diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 5d6f7921..0b8e381c 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -197,17 +197,21 @@ def patch_mistral_nemo_config(config): source = inspect.getsource(transformers.cache_utils.DynamicCache.__getitem__) start = source.find("def") - spaces = start*" " source = source.split("\n") source = "\n".join(x[start:] for x in source) + where = source.find("raise KeyError") - source = source[:where] + \ - f"if len(self) == 0:\n{spaces}{spaces}"\ - " raise RuntimeError('Unsloth: You must call `FastLanguageModel.for_inference(model)` before doing inference for Unsloth models.')\n" + \ - f"{spaces}{spaces}else:\n{spaces}{spaces}{spaces}" + source[where:] + source = ( + source[:where] + + "if len(self) == 0:\n" + + " raise RuntimeError('Unsloth: You must call `FastLanguageModel.for_inference(model)` before doing inference for Unsloth models.')\n" + + " else:\n" + + source[where:] + ) source = source.replace("__getitem__", "__cache_utils_getitem__", 1) exec(source) transformers.cache_utils.DynamicCache.__getitem__ = __cache_utils_getitem__ + pass # =============================================