clebsch gordan submodule - implementation of TorchScript interface #269
Conversation
Hmm, the error is a bit weird. I'll try to debug it.

Should be fixed by metatensor/metatensor#438

Here is a pre-built version of the code in this pull request: wheels.zip; you can install it locally by unzipping it.
Currently I get a TorchScript error that is hard to interpret (see below). By going through the torch compile tree I could identify that it is in the …

Debugging is actually quite easy: there is a …

That's very interesting, I did not know about the ctx! Could you show a quick example of how you use it?

I have problems with TorchScript supporting the type …

Did you try with an explicit …
It's not so much that it does not recognize the type; it is more that it does not allow creating an object of this type using any list functionality that concatenates lists. This works perfectly fine:

```python
def foo(a: Labels):
    l: List[Union[None, Labels]] = [None, a]
    return l

torch.jit.script(foo)
```

But you cannot do this:

```python
def foo(a: Labels):
    l: List[Union[None, Labels]] = ([None] * 5) + [a]
    return l

torch.jit.script(foo)
```

Nor either of these:

```python
def foo(a: Labels):
    l: List[Union[None, Labels]] = [None, a]
    l = [None] * 5
    l.append(a)
    return l

torch.jit.script(foo)
```

```python
def foo(a: Labels):
    l: List[Union[None, Labels]] = [None, a]
    l = [None] * 6
    l[5] = a
    return l

torch.jit.script(foo)
```
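One pattern that, in my experience, TorchScript does accept is building the padded list element by element with a loop instead of concatenating. This is a minimal sketch, not code from this PR, and it uses `str` as a stand-in for the `Labels` class:

```python
import torch
from typing import List, Optional

# Sketch of a workaround: build the padded list with an explicit loop instead
# of `([None] * n) + [tail]`. `str` stands in for `Labels` here.
@torch.jit.script
def pad_then_append(n: int, tail: str) -> List[Optional[str]]:
    out: List[Optional[str]] = []
    for _ in range(n):
        out.append(None)
    out.append(tail)
    return out

result = pad_then_append(3, "labels")
```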
This seems to work:

```python
import torch
from typing import List, Union


class Labels:
    pass


def foo(a: Labels):
    l: List[Union[None, Labels]] = [torch.jit.annotate(Union[None, Labels], None)] * 5
    l.append(a)
    return l


torch.jit.script(foo)
```

But you need to …

EDIT: this is fine as well:

```python
def foo(a: Labels):
    l = [torch.jit.annotate(Union[None, Labels], None)] * 5
    l.append(a)
    return l
```
So this worked for fixing the bug that was fixed with commit d8f9906. To reproduce this I pushed a branch:

```shell
git checkout debug-unary-op
tox -e torch-tests -- --pdb
```

The debugging went like this:

```
nd.py:406: in __call__
    return method(ctx, node)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
ctx = <torch._sources.SourceContext object at 0x7f557836f470>, expr = <ast.UnaryOp object at 0x7f557834b310>

    @staticmethod
    def build_UnaryOp(ctx, expr):
        sub_expr = build_expr(ctx, expr.operand)
        op = type(expr.op)
        op_token = ExprBuilder.unop_map.get(op)
        if op_token is None:
            raise NotSupportedError(
>               expr.range(), "unsupported unary operator: " + op.__name__
            )
E           AttributeError: 'UnaryOp' object has no attribute 'range'

../../.tox/torch-tests/lib/python3.11/site-packages/torch/jit/frontend.py:976: AttributeError
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> entering PDB >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
>>>>>>>>>>>>>>>> PDB post_mortem (IO-capturing turned off) >>>>>>>>>>>>>>>>>>>
> /home/alexgo/code/rascaline/.tox/torch-tests/lib/python3.11/site-packages/torch/jit/frontend.py(976)build_UnaryOp()
-> expr.range(), "unsupported unary operator: " + op.__name__
(Pdb) up
> /home/alexgo/code/rascaline/.tox/torch-tests/lib/python3.11/site-packages/torch/jit/frontend.py(406)__call__()
-> return method(ctx, node)
(Pdb) print(ctx.source)
def _parse_selected_keys(
    n_iterations: int,
    angular_cutoff: Optional[int] = None,
    selected_keys: Optional[Union[Labels, List[Labels]]] = None,
    like=None,
) -> List[Union[None, Labels]]:
    """
    Parses the `selected_keys` argument passed to public functions. Checks the
    values and returns a :py:class:`list` of :py:class:`Labels` objects, one for
    each iteration of CG combination.

    `like` is required if a new :py:class:`Labels` object is to be created by
    :py:mod:`_dispatch`.
    """
    # Check angular_cutoff arg
    if angular_cutoff is not None:
        if not isinstance(angular_cutoff, int):
            raise TypeError("`angular_cutoff` must be passed as an int")
        if angular_cutoff < 1:
            raise ValueError("`angular_cutoff` must be >= 1")

    if selected_keys is None:
        if angular_cutoff is None:  # no selections at all
            selected_keys = [None] * n_iterations
        else:
            # Create a key selection with all angular channels <= the specified
            # angular cutoff
            selected_keys = [
                Labels(
                    names=["spherical_harmonics_l"],
                    values=_dispatch.int_range_like(
                        0, angular_cutoff, like=like
                    ).reshape(-1, 1),
                )
            ] * n_iterations

    if isinstance(selected_keys, Labels):
        # Create a list, but only apply a key selection at the final iteration
        selected_keys = [None] * (n_iterations - 1) + [selected_keys]

    # Check the selected_keys
    if not isinstance(selected_keys, List):
        raise TypeError(
            "`selected_keys` must be a `Labels` or List[Union[None, `Labels`]]"
        )
    if not len(selected_keys) == n_iterations:
        raise ValueError(
            "`selected_keys` must be a List[Union[None, Labels]] of length"
            " `correlation_order` - 1"
        )
    if not _dispatch.all(
        [isinstance(val, (Labels, type(None))) for val in selected_keys]
    ):
        raise TypeError("`selected_keys` must be a Labels or List[Union[None, Labels]]")

    # Now iterate over each of the Labels (or None) in the list and check
    for slct in selected_keys:
        if slct is None:
            continue
        assert isinstance(slct, Labels)
        if not _dispatch.all(
            [
                name in ["spherical_harmonics_l", "inversion_sigma"]
                for name in slct.names
            ]
        ):
            raise ValueError(
                "specified key names in `selected_keys` must be either"
                " 'spherical_harmonics_l' or 'inversion_sigma'"
            )
        if "spherical_harmonics_l" in slct.names:
            if angular_cutoff is not None:
                if not _dispatch.all(
                    slct.column("spherical_harmonics_l") <= angular_cutoff
                ):
                    raise ValueError(
                        "specified angular channels in `selected_keys` must be <= the"
                        " specified `angular_cutoff`"
                    )
            if not _dispatch.all(
                [angular_l >= 0 for angular_l in slct.column("spherical_harmonics_l")]
            ):
                raise ValueError(
                    "specified angular channels in `selected_keys` must be >= 0"
                )
        if "inversion_sigma" in slct.names:
            if not _dispatch.all(
                [parity_s in [-1, +1] for parity_s in slct.column("inversion_sigma")]
            ):
                raise ValueError(
                    "specified parities in `selected_keys` must be -1 or +1"
                )
    return selected_keys
(Pdb) node.lineno
88
```

From that I could infer that the failure is on line 88 of the printed source, and since the operation that caused the error was an add, I inferred it was the …
Something very slightly annoying; not sure if it is worth an issue, since it could potentially be an easy fix:

```
RuntimeError:
Variable 'le' is annotated with type __torch__.torch.classes.metatensor.LabelsEntry (of Python compilation unit at: 0) but is being assigned to a value of type Any:
  File "<ipython-input-379-f5b994008d5f>", line 2
    def foo(l:Labels):
        le: LabelsEntry = l[0]
        ~~ <--- HERE
        return le.values
```

Workaround:

```python
def foo(l: Labels):
    le = l.values[0]
    return le

torch.jit.script(foo)
```

EDIT:

```python
def foo(l: Labels):
    le: LabelsEntry = l.entry(0)
    return le.values

torch.jit.script(foo)
```
This breaks TorchScript (segmentation fault):

```python
def foo():
    f: List[List[Any]] = [torch.jit.annotate(Any, [5, 1])]
    return f

torch.jit.script(foo)
```
So TorchScript seems to have problems associating …

```python
def foo():
    l: List[Any] = [5, 3]

torch.jit.script(foo)  # works


def foo():
    l: List[List[Any]] = [[5, 3]]

torch.jit.script(foo)  # does not work


def foo():
    l: List[Any] = [[5, 3]]

torch.jit.script(foo)  # works


def foo():
    l: List[List[int]] = [[5, 3]]
    return foo2(l)

def foo2(l: List[Any]):
    return l

torch.jit.script(foo)  # does not work
```
Building seems to fail because a rustc update is required.
The documentation build fails because I use autodoc:

```
failed to import class 'torch.CalculatorModule' from module 'rascaline'; the following exception was raised:
Traceback (most recent call last):
  File "/home/runner/work/rascaline/rascaline/.tox/docs/lib/python3.11/site-packages/sphinx/ext/autodoc/importer.py", line 69, in import_module
    return importlib.import_module(modname)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/hostedtoolcache/Python/3.11.7/x64/lib/python3.11/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<frozen importlib._bootstrap>", line 1204, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1176, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1147, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 690, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 940, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/home/runner/work/rascaline/rascaline/.tox/docs/lib/python3.11/site-packages/rascaline/torch/__init__.py", line 12, in <module>
    from . import utils  # noqa
  File "/home/runner/work/rascaline/rascaline/.tox/docs/lib/python3.11/site-packages/rascaline/torch/utils/__init__.py", line 3, in <module>
    from . import clebsch_gordan
  File "/home/runner/work/rascaline/rascaline/.tox/docs/lib/python3.11/site-packages/rascaline/torch/utils/clebsch_gordan.py", line 36, in <module>
    is_labels = torch.jit.script(is_labels)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/runner/work/rascaline/rascaline/.tox/docs/lib/python3.11/site-packages/torch/jit/_script.py", line 1395, in script
    fn = torch._C._jit_script_compile(
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/runner/work/rascaline/rascaline/.tox/docs/lib/python3.11/site-packages/torch/jit/annotations.py", line 501, in try_ann_to_type
    return torch.jit._script._recursive_compile_class(ann, loc)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/runner/work/rascaline/rascaline/.tox/docs/lib/python3.11/site-packages/torch/jit/_script.py", line 1568, in _recursive_compile_class
    return _compile_and_register_class(obj, rcb, _qual_name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/runner/work/rascaline/rascaline/.tox/docs/lib/python3.11/site-packages/torch/jit/_recursive.py", line 61, in _compile_and_register_class
    script_class = torch._C._jit_script_class_compile(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/runner/work/rascaline/rascaline/.tox/docs/lib/python3.11/site-packages/torch/jit/annotations.py", line 453, in try_ann_to_type
    maybe_type = try_ann_to_type(a, loc)
                 ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/runner/work/rascaline/rascaline/.tox/docs/lib/python3.11/site-packages/torch/jit/annotations.py", line 415, in try_ann_to_type
    return TupleType([try_ann_to_type(a, loc) for a in ann_args])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Can not create tuple with None type
```

I don't know yet how to fix this documentation error.
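One possible direction, sketched here only as an illustration (the helper name `maybe_script` and the environment variable `RASCALINE_BUILD_DOCS` are hypothetical, not what this PR actually does), is to skip `torch.jit.script` while Sphinx imports the module, so autodoc never triggers the failing compilation:

```python
import os
import torch

# Hypothetical sketch: leave functions as plain Python during documentation
# builds so that autodoc can import the module without compiling TorchScript.
# The environment variable name is illustrative only.
def maybe_script(fn):
    if os.environ.get("RASCALINE_BUILD_DOCS") == "1":
        return fn  # autodoc sees the original, unscripted function
    return torch.jit.script(fn)

# Simulate a documentation build:
os.environ["RASCALINE_BUILD_DOCS"] = "1"

def is_labels(obj) -> bool:
    return False

unscripted = maybe_script(is_labels)
```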
- test torch array backend and add required dispatches
- add _classes.py
- forgot to push
- add torchscript test for correlate_density
- start torchscript test
- fixing TorchScript UnaryOp bug
- checkpoint: all-deps tests pass and made progress on TorchScriptability
- checkpoint2: all-deps tests pass and made progress on TorchScriptability
- checkpoint3: all-deps tests pass and made progress on TorchScriptability
- checkpoint4: all-deps fails, invalid key pair (l1=0, l2=1, lam=2) is accessed in sparse_combine
- checkpoint5: all-deps tests pass and made progress on TorchScriptability
- checkpoint6: all-deps tests pass and made progress on TorchScriptability; abstracted out `Dict[Tuple[int,int,int], Array]` and `Dict[Tuple[int,int,int], Dict[Tuple[int,int,int], Array]]` into custom classes with utilities that allow similar access
- checkpoint7: all-deps tests pass and made progress on TorchScriptability; the sparse property in ClebschGordanReal is determined by the type of coeffs, because we need isinstance checks anyway for TorchScript to distinguish the two types in the different functions (`Union[SparseCgDict, DenseCgDict]` -> `SparseCgDict`); replaced the input parameter `return_empty_array` of the `sparse_combine` function in _cg_cache.py by an `empty_combine` function, to distinguish the None type for TorchScript
- checkpoint8: all-deps tests pass and made progress on TorchScriptability; Labels.insert cannot broadcast in metatensor.torch, the correct shape has to be given; torch tensors of type int32 cannot be converted to lists using tolist() (pytorch/pytorch#76295), therefore added a tolist dispatch function that first converts the array to int64
- checkpoint9: all-deps tests pass and torch-test test_torch_script_correlate_density_angular_selection passes; changed to complex type until the real conversion, otherwise a (complex, real) operation …
- checkpoint10: all-deps tests pass and torch-test test_torch_script_correlate_density_angular_selection passes; made the `like` parameter in _parse_selected_keys …
- all-deps tests and torch-tests pass; all isinstance checks are moved to the false branch of jit.is_scripting
- linting fixes and partial format fix
- doctest format
- remove _dispatch.array_like
- during documentation building the jit scripting of functions needs to be disabled
- integrating labels_array_like into int_array_like
- cleaned code, removed unnecessary dispatch operations
- remove TODOs in cg_cache
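The tolist point in checkpoint8 can be sketched as follows. This is a hypothetical version of the dispatch function, not the PR's actual code; only the idea (cast to int64 before calling `.tolist()`, to sidestep the int32 issue referenced as pytorch/pytorch#76295) comes from the commit message:

```python
import torch

# Hypothetical sketch of the tolist dispatch from checkpoint8: cast integer
# tensors to int64 before converting to a Python list.
def tolist(array):
    if isinstance(array, torch.Tensor):
        return array.to(torch.int64).tolist()
    return array.tolist()  # e.g. numpy arrays convert directly

values = tolist(torch.tensor([1, 2, 3], dtype=torch.int32))
```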
feee5b7
to
3f8369f
Compare
I document here a behavior of TorchScript which seems to me like a bug. This code, meant to deal with TorchScript inferring a …:

```python
@property
def selected_keys(self) -> List[Union[Labels, None]]:
    if torch_jit_is_scripting():
        if torch.jit.isinstance(self._selected_keys, List[Union[Labels, None]]):
            return self._selected_keys
        else:
            selected_keys_: List[Union[None, Labels]] = [
                torch_jit_annotate(Union[None, Labels], None)
            ] * len(self._selected_keys)
            return selected_keys_
    return self._selected_keys
```

returned me a weird object instead of a list of Labels. This is the variable in the non-scripted case.

In the first place, I don't know why TorchScript complains when returning the variable inside a class method; the function that computes the value was working totally fine before. UPDATE: Changing the type to …

```python
if torch_jit_is_scripting():
    if torch_jit_isinstance(selected_keys, List[None]):
        pass
    elif torch_jit_isinstance(selected_keys, List[Labels]):
        keys_1_entries, keys_2_entries, keys_out = _apply_key_selection(
            keys_1_entries,
            keys_2_entries,
            keys_out,
            selected_keys=selected_keys[iteration],
        )
    elif torch_jit_isinstance(selected_keys, List[Union[Labels, None]]):
        selected_keys_i = selected_keys[iteration]
        if selected_keys_i is not None:
            keys_1_entries, keys_2_entries, keys_out = _apply_key_selection(
                keys_1_entries,
                keys_2_entries,
                keys_out,
                selected_keys=selected_keys_i,
            )
    else:
        raise TypeError(
            "Assumed object of type List[None], List[Labels] or "
            "List[Union[Labels, None]] but got object {selected_keys}"
        )
else:
    selected_keys_i = selected_keys[iteration]
    if selected_keys_i is not None:
        keys_1_entries, keys_2_entries, keys_out = _apply_key_selection(
            keys_1_entries,
            keys_2_entries,
            keys_out,
            selected_keys=selected_keys_i,
        )
```
I will go on with the way I wrote at the end of the last message. If someone finds a way to make it work, please let me know. Here is the minimal breaking example:

```python
import torch
from typing import Union, List


class A(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._foo: List[Union[None, str]] = [torch.jit.annotate(Union[None, str], None)]

    @property
    def foo(self) -> List[Union[None, str]]:
        return self._foo


aobj = A()
print(torch.jit.script(aobj).foo)
```
You need to annotate at the class level:

```python
class A(torch.nn.Module):
    _foo: List[Optional[str]]

    def __init__(self):
        super().__init__()
        self._foo = [None]

    @property
    def foo(self) -> List[Optional[str]]:
        return self._foo
```
From Dict of Tuple[int,int,int] to TensorMap; all-deps and all-deps-torch are passing.
Looking really great, thank you @agoscinski! I have only requested some minor changes for now.
@jwa7 and I had a small discussion. In a future PR we would make `max_angular` in the init of `DensityCorrelations` optional if `angular_cutoff` is specified.
The work Alex did looks great to me, and I made some changes on top. As you weren't involved, @Luthaf, could you please provide an impartial review? :D
Please also add the public CG classes to the API reference documentation!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll merge once CI is happy!
📚 Documentation preview 📚: https://rascaline--269.org.readthedocs.build/en/269/