Skip to content

Commit

Permalink
Fixing the stub?
Browse files Browse the repository at this point in the history
  • Loading branch information
Narsil committed Jan 2, 2025
1 parent e61e872 commit 2eb3c66
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 21 deletions.
104 changes: 90 additions & 14 deletions bindings/python/py_src/safetensors/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ def deserialize(bytes):
Opens a safetensors lazily and returns tensors as asked
Args:
data (:obj:`bytes`):
data (`bytes`):
The byte content of a file
Returns:
(:obj:`List[str, Dict[str, Dict[str, any]]]`):
(`List[str, Dict[str, Dict[str, any]]]`):
The deserialized content is like:
[("tensor_name", {"shape": [2, 3], "dtype": "F32", "data": b"\0\0.." }), (...)]
"""
Expand All @@ -21,14 +21,14 @@ def serialize(tensor_dict, metadata=None):
Serializes raw data.
Args:
tensor_dict (:obj:`Dict[str, Dict[Any]]`):
tensor_dict (`Dict[str, Dict[Any]]`):
The tensor dict is like:
{"tensor_name": {"dtype": "F32", "shape": [2, 3], "data": b"\0\0"}}
metadata (:obj:`Dict[str, str]`, *optional*):
metadata (`Dict[str, str]`, *optional*):
The optional purely text annotations
Returns:
(:obj:`bytes`):
(`bytes`):
The serialized content.
"""
pass
Expand All @@ -39,16 +39,16 @@ def serialize_file(tensor_dict, filename, metadata=None):
Serializes raw data.
Args:
tensor_dict (:obj:`Dict[str, Dict[Any]]`):
tensor_dict (`Dict[str, Dict[Any]]`):
The tensor dict is like:
{"tensor_name": {"dtype": "F32", "shape": [2, 3], "data": b"\0\0"}}
filename (:obj:`str`):
filename (`str`, or `os.PathLike`):
The name of the file to write into.
metadata (:obj:`Dict[str, str]`, *optional*):
metadata (`Dict[str, str]`, *optional*):
The optional purely text annotations
Returns:
(:obj:`bytes`):
(`bytes`):
The serialized content.
"""
pass
Expand All @@ -58,16 +58,92 @@ class safe_open:
Opens a safetensors lazily and returns tensors as asked
Args:
filename (:obj:`str`):
filename (`str`, or `os.PathLike`):
The filename to open
framework (:obj:`str`):
The framework you want your tensors in. Supported values:
framework (`str`):
The framework you want you tensors in. Supported values:
`pt`, `tf`, `flax`, `numpy`.
device (:obj:`str`, defaults to :obj:`"cpu"`):
device (`str`, defaults to `"cpu"`):
The device on which you want the tensors.
"""

def __init__(self, filename, framework, device="cpu"):
def __init__(filename, framework, device=...):
pass
def __enter__(self):
"""
Start the context manager
"""
pass
def __exit__(self, _exc_type, _exc_value, _traceback):
"""
Exits the context manager
"""
pass
def get_slice(self, name):
"""
Returns a full slice view object
Args:
name (`str`):
The name of the tensor you want
Returns:
(`PySafeSlice`):
A dummy object you can slice into to get a real tensor
Example:
```python
from safetensors import safe_open
with safe_open("model.safetensors", framework="pt", device=0) as f:
tensor_part = f.get_slice("embedding")[:, ::8]
```
"""
pass
def get_tensor(self, name):
"""
Returns a full tensor
Args:
name (`str`):
The name of the tensor you want
Returns:
(`Tensor`):
The tensor in the framework you opened the file for.
Example:
```python
from safetensors import safe_open
with safe_open("model.safetensors", framework="pt", device=0) as f:
tensor = f.get_tensor("embedding")
```
"""
pass
def keys(self):
"""
Returns the names of the tensors in the file.
Returns:
(`List[str]`):
The name of the tensors contained in that file
"""
pass
def metadata(self):
"""
Return the special non tensor information in the header
Returns:
(`Dict[str, str]`):
The freeform metadata.
"""
pass

class SafetensorError(Exception):
"""
Custom Python Exception for Safetensor errors.
"""
2 changes: 2 additions & 0 deletions bindings/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -755,10 +755,12 @@ impl safe_open {
self.inner()?.get_slice(name)
}

/// Start the context manager
pub fn __enter__(slf: Py<Self>) -> Py<Self> {
slf
}

/// Exits the context manager
pub fn __exit__(&mut self, _exc_type: PyObject, _exc_value: PyObject, _traceback: PyObject) {
self.inner = None;
}
Expand Down
37 changes: 30 additions & 7 deletions bindings/python/stub.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,14 @@ def member_sort(member):
def fn_predicate(obj):
value = inspect.ismethoddescriptor(obj) or inspect.isbuiltin(obj)
if value:
return obj.__doc__ and obj.__text_signature__ and not obj.__name__.startswith("_")
return (
obj.__doc__
and obj.__text_signature__
and (
not obj.__name__.startswith("_")
or obj.__name__ in {"__enter__", "__exit__"}
)
)
if inspect.isgetsetdescriptor(obj):
return obj.__doc__ and not obj.__name__.startswith("_")
return False
Expand Down Expand Up @@ -74,7 +81,9 @@ def pyi_file(obj, indent=""):

body = ""
if obj.__doc__:
body += f'{indent}"""\n{indent}{do_indent(obj.__doc__, indent)}\n{indent}"""\n'
body += (
f'{indent}"""\n{indent}{do_indent(obj.__doc__, indent)}\n{indent}"""\n'
)

fns = inspect.getmembers(obj, fn_predicate)

Expand All @@ -84,7 +93,7 @@ def pyi_file(obj, indent=""):
body += f"{indent+INDENT}pass\n"
body += "\n"

for (name, fn) in fns:
for name, fn in fns:
body += pyi_file(fn, indent=indent)

if not body:
Expand Down Expand Up @@ -130,13 +139,18 @@ def do_black(content, is_pyi):
experimental_string_processing=False,
)
try:
content = content.replace("$self", "self")
return black.format_file_contents(content, fast=True, mode=mode)
except black.NothingChanged:
return content


def write(module, directory, origin, check=False):
submodules = [(name, member) for name, member in inspect.getmembers(module) if inspect.ismodule(member)]
submodules = [
(name, member)
for name, member in inspect.getmembers(module)
if inspect.ismodule(member)
]

filename = os.path.join(directory, "__init__.pyi")
pyi_content = pyi_file(module)
Expand All @@ -145,7 +159,9 @@ def write(module, directory, origin, check=False):
if check:
with open(filename, "r") as f:
data = f.read()
assert data == pyi_content, f"The content of {filename} seems outdated, please run `python stub.py`"
assert (
data == pyi_content
), f"The content of {filename} seems outdated, please run `python stub.py`"
else:
with open(filename, "w") as f:
f.write(pyi_content)
Expand All @@ -168,7 +184,9 @@ def write(module, directory, origin, check=False):
if check:
with open(filename, "r") as f:
data = f.read()
assert data == py_content, f"The content of {filename} seems outdated, please run `python stub.py`"
assert (
data == py_content
), f"The content of {filename} seems outdated, please run `python stub.py`"
else:
with open(filename, "w") as f:
f.write(py_content)
Expand All @@ -184,4 +202,9 @@ def write(module, directory, origin, check=False):
args = parser.parse_args()
import safetensors

write(safetensors.safetensors_rust, "py_src/safetensors/", "safetensors", check=args.check)
write(
safetensors._safetensors_rust,
"py_src/safetensors/",
"safetensors",
check=args.check,
)

0 comments on commit 2eb3c66

Please sign in to comment.