From 2eb3c66ad34de0d09f9532467b7246889ee58f6b Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 2 Jan 2025 18:24:34 +0100 Subject: [PATCH] Fixing the stub? --- .../python/py_src/safetensors/__init__.pyi | 104 +++++++++++++++--- bindings/python/src/lib.rs | 2 + bindings/python/stub.py | 37 +++++-- 3 files changed, 122 insertions(+), 21 deletions(-) diff --git a/bindings/python/py_src/safetensors/__init__.pyi b/bindings/python/py_src/safetensors/__init__.pyi index 2de57a8e..81442bd3 100644 --- a/bindings/python/py_src/safetensors/__init__.pyi +++ b/bindings/python/py_src/safetensors/__init__.pyi @@ -5,11 +5,11 @@ def deserialize(bytes): Opens a safetensors lazily and returns tensors as asked Args: - data (:obj:`bytes`): + data (`bytes`): The byte content of a file Returns: - (:obj:`List[str, Dict[str, Dict[str, any]]]`): + (`List[str, Dict[str, Dict[str, any]]]`): The deserialized content is like: [("tensor_name", {"shape": [2, 3], "dtype": "F32", "data": b"\0\0.." }), (...)] """ @@ -21,14 +21,14 @@ def serialize(tensor_dict, metadata=None): Serializes raw data. Args: - tensor_dict (:obj:`Dict[str, Dict[Any]]`): + tensor_dict (`Dict[str, Dict[Any]]`): The tensor dict is like: {"tensor_name": {"dtype": "F32", "shape": [2, 3], "data": b"\0\0"}} - metadata (:obj:`Dict[str, str]`, *optional*): + metadata (`Dict[str, str]`, *optional*): The optional purely text annotations Returns: - (:obj:`bytes`): + (`bytes`): The serialized content. """ pass @@ -39,16 +39,16 @@ def serialize_file(tensor_dict, filename, metadata=None): Serializes raw data. Args: - tensor_dict (:obj:`Dict[str, Dict[Any]]`): + tensor_dict (`Dict[str, Dict[Any]]`): The tensor dict is like: {"tensor_name": {"dtype": "F32", "shape": [2, 3], "data": b"\0\0"}} - filename (:obj:`str`): + filename (`str`, or `os.PathLike`): The name of the file to write into. - metadata (:obj:`Dict[str, str]`, *optional*): + metadata (`Dict[str, str]`, *optional*): The optional purely text annotations Returns: - (:obj:`bytes`): + (`bytes`): The serialized content. """ pass @@ -58,16 +58,92 @@ class safe_open: Opens a safetensors lazily and returns tensors as asked Args: - filename (:obj:`str`): + filename (`str`, or `os.PathLike`): The filename to open - framework (:obj:`str`): - The framework you want your tensors in. Supported values: + framework (`str`): + The framework you want you tensors in. Supported values: `pt`, `tf`, `flax`, `numpy`. - device (:obj:`str`, defaults to :obj:`"cpu"`): + device (`str`, defaults to `"cpu"`): The device on which you want the tensors. """ - def __init__(self, filename, framework, device="cpu"): + def __init__(filename, framework, device=...): pass + def __enter__(self): + """ + Start the context manager + """ + pass + def __exit__(self, _exc_type, _exc_value, _traceback): + """ + Exits the context manager + """ + pass + def get_slice(self, name): + """ + Returns a full slice view object + + Args: + name (`str`): + The name of the tensor you want + + Returns: + (`PySafeSlice`): + A dummy object you can slice into to get a real tensor + Example: + ```python + from safetensors import safe_open + + with safe_open("model.safetensors", framework="pt", device=0) as f: + tensor_part = f.get_slice("embedding")[:, ::8] + + ``` + """ + pass + def get_tensor(self, name): + """ + Returns a full tensor + + Args: + name (`str`): + The name of the tensor you want + + Returns: + (`Tensor`): + The tensor in the framework you opened the file for. + + Example: + ```python + from safetensors import safe_open + + with safe_open("model.safetensors", framework="pt", device=0) as f: + tensor = f.get_tensor("embedding") + + ``` + """ + pass + def keys(self): + """ + Returns the names of the tensors in the file. + + Returns: + (`List[str]`): + The name of the tensors contained in that file + """ + pass + def metadata(self): + """ + Return the special non tensor information in the header + + Returns: + (`Dict[str, str]`): + The freeform metadata. + """ + pass + +class SafetensorError(Exception): + """ + Custom Python Exception for Safetensor errors. + """ diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs index 21062bd4..2ecd9a7f 100644 --- a/bindings/python/src/lib.rs +++ b/bindings/python/src/lib.rs @@ -755,10 +755,12 @@ impl safe_open { self.inner()?.get_slice(name) } + /// Start the context manager pub fn __enter__(slf: Py) -> Py { slf } + /// Exits the context manager pub fn __exit__(&mut self, _exc_type: PyObject, _exc_value: PyObject, _traceback: PyObject) { self.inner = None; } diff --git a/bindings/python/stub.py b/bindings/python/stub.py index 340fda55..78ddc47f 100644 --- a/bindings/python/stub.py +++ b/bindings/python/stub.py @@ -39,7 +39,14 @@ def member_sort(member): def fn_predicate(obj): value = inspect.ismethoddescriptor(obj) or inspect.isbuiltin(obj) if value: - return obj.__doc__ and obj.__text_signature__ and not obj.__name__.startswith("_") + return ( + obj.__doc__ + and obj.__text_signature__ + and ( + not obj.__name__.startswith("_") + or obj.__name__ in {"__enter__", "__exit__"} + ) + ) if inspect.isgetsetdescriptor(obj): return obj.__doc__ and not obj.__name__.startswith("_") return False @@ -74,7 +81,9 @@ def pyi_file(obj, indent=""): body = "" if obj.__doc__: - body += f'{indent}"""\n{indent}{do_indent(obj.__doc__, indent)}\n{indent}"""\n' + body += ( + f'{indent}"""\n{indent}{do_indent(obj.__doc__, indent)}\n{indent}"""\n' + ) fns = inspect.getmembers(obj, fn_predicate) @@ -84,7 +93,7 @@ def pyi_file(obj, indent=""): body += f"{indent+INDENT}pass\n" body += "\n" - for (name, fn) in fns: + for name, fn in fns: body += pyi_file(fn, indent=indent) if not body: @@ -130,13 +139,18 @@ def do_black(content, is_pyi): experimental_string_processing=False, ) try: + content = content.replace("$self", "self") return black.format_file_contents(content, fast=True, mode=mode) except black.NothingChanged: return content def write(module, directory, origin, check=False): - submodules = [(name, member) for name, member in inspect.getmembers(module) if inspect.ismodule(member)] + submodules = [ + (name, member) + for name, member in inspect.getmembers(module) + if inspect.ismodule(member) + ] filename = os.path.join(directory, "__init__.pyi") pyi_content = pyi_file(module) @@ -145,7 +159,9 @@ def write(module, directory, origin, check=False): if check: with open(filename, "r") as f: data = f.read() - assert data == pyi_content, f"The content of {filename} seems outdated, please run `python stub.py`" + assert ( + data == pyi_content + ), f"The content of {filename} seems outdated, please run `python stub.py`" else: with open(filename, "w") as f: f.write(pyi_content) @@ -168,7 +184,9 @@ def write(module, directory, origin, check=False): if check: with open(filename, "r") as f: data = f.read() - assert data == py_content, f"The content of {filename} seems outdated, please run `python stub.py`" + assert ( + data == py_content + ), f"The content of {filename} seems outdated, please run `python stub.py`" else: with open(filename, "w") as f: f.write(py_content) @@ -184,4 +202,9 @@ def write(module, directory, origin, check=False): args = parser.parse_args() import safetensors - write(safetensors.safetensors_rust, "py_src/safetensors/", "safetensors", check=args.check) + write( + safetensors._safetensors_rust, + "py_src/safetensors/", + "safetensors", + check=args.check, + )