Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

clebsch gordan submodule - implementation of TorchScript interface #269

Merged
merged 24 commits into from
Feb 21, 2024

Conversation

agoscinski
Copy link
Collaborator

@agoscinski agoscinski commented Dec 29, 2023


📚 Documentation preview 📚: https://rascaline--269.org.readthedocs.build/en/269/

@agoscinski
Copy link
Collaborator Author

@Luthaf The torch-tests pass locally. Probably related to new Rust version? I checked the Rust version in PR #268 rustc 1.74.1 (a28077b28 2023-12-04) where the test pass and here it is version rustc 1.75.0 (82e1608df 2023-12-21)

Base automatically changed from cg_restructure to master January 4, 2024 12:37
@Luthaf
Copy link
Member

Luthaf commented Jan 4, 2024

Hmm, the error is a bit weird. I'll try to debug it.

@Luthaf
Copy link
Member

Luthaf commented Jan 4, 2024

Should be fixed by metatensor/metatensor#438

Copy link

github-actions bot commented Jan 8, 2024

Here is a pre-built version of the code in this pull request: wheels.zip, you can install it locally by unzipping wheels.zip and using pip to install the file matching your system

@agoscinski
Copy link
Collaborator Author

Currently get TorchScript error that is hard to interpret (see below). By going through the torch compile tree I could identify that it is in the _parse_selected_keys function. Will look more into this

    @pytest.mark.parametrize("skip_redundant", [True, False])                                                                                                                                                                                                                                                                                                             
    def test_scriptability_correlate_density_angular_selection(                                                                                                                                                                                                                                                                                                           
        selected_keys: Labels,                                                                                                                                                                                                                                                                                                                                            
        skip_redundant: bool,                                                                                                                                                                                                                                                                                                                                             
    ):                                                                                                                                                                                                                                                                                                                                                                    
        """                                                                                                                                                                                                                                                                                                                                                               
        Tests that the correct angular channels are output based on the specified                                                                                                                                                                                                                                                                                         
        ``selected_keys``.                                                                                                                                                                                                                                                                                                                                                
        """                                                                                                                                                                                                                                                                                                                                                               
        frames = h2o_isolated()                                                                                                                                                                                                                                                                                                                                           
        nu_1 = spherical_expansion(frames)                                                                                                                                                                                                                                                                                                                                
>       scripted_correlate_density = torch.jit.script(correlate_density)                                                                                                                                                                                                                                                                                                  
                                                                                                                                                                                                                                                                                                                                                                          
tests/utils/correlate_density.py:54:                                                                                                                                                                                                                                                                                                                                      
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _                                                                                                                                                                                     
../../.tox/torch-tests/lib/python3.11/site-packages/torch/jit/_script.py:1395: in script                                                                                                                                                                                                                                                                                  
    fn = torch._C._jit_script_compile(                                                                                                                                                                                                                                                                                                                                    
../../.tox/torch-tests/lib/python3.11/site-packages/torch/jit/_recursive.py:1003: in try_compile_fn                                                                                                                                                                                                                                                                       
    return torch.jit.script(fn, _rcb=rcb)                                                                                                                                                                                                                                                                                                                                 
../../.tox/torch-tests/lib/python3.11/site-packages/torch/jit/_script.py:1395: in script                                                                                                                                                                                                                                                                                  
    fn = torch._C._jit_script_compile(                                                                                                                                                                                                                                                                                                                                    
../../.tox/torch-tests/lib/python3.11/site-packages/torch/jit/_recursive.py:1003: in try_compile_fn                                                                                                                                                                                                                                                                       
    return torch.jit.script(fn, _rcb=rcb)                                                                                                                                                                                                                                                                                                                                 
../../.tox/torch-tests/lib/python3.11/site-packages/torch/jit/_script.py:1392: in script                                                                                                                                                                                                                                                                                  
    ast = get_jit_def(obj, obj.__name__)                                                                                                                                                                                                                                                                                                                                  
../../.tox/torch-tests/lib/python3.11/site-packages/torch/jit/frontend.py:372: in get_jit_def                                                                                                                                                                                                                                                                             
    return build_def(                                                                                                                                                                                                                                                                                                                                                     
../../.tox/torch-tests/lib/python3.11/site-packages/torch/jit/frontend.py:433: in build_def                                                                                                                                                                                                                                                                               
    return Def(Ident(r, def_name), decl, build_stmts(ctx, body))                                                                                                                                                                                                                                                                                                          
../../.tox/torch-tests/lib/python3.11/site-packages/torch/jit/frontend.py:195: in build_stmts                                                                                                                                                                                                                                                                             
    stmts = [build_stmt(ctx, s) for s in stmts]                                                                                                                                                                                                                                                                                                                           
../../.tox/torch-tests/lib/python3.11/site-packages/torch/jit/frontend.py:195: in <listcomp>                                                                                                                                                                                                                                                                              
    stmts = [build_stmt(ctx, s) for s in stmts]                                                                                                                                                                                                                                                                                                                           
../../.tox/torch-tests/lib/python3.11/site-packages/torch/jit/frontend.py:406: in __call__                                                                                                                                                                                                                                                                                
    return method(ctx, node)                                                                                                                                                                                                                                                                                                                                              
../../.tox/torch-tests/lib/python3.11/site-packages/torch/jit/frontend.py:773: in build_For                                                                                                                                                                                                                                                                               
    build_stmts(ctx, stmt.body),                                                                                                                                                                                                                                                                                                                                          
../../.tox/torch-tests/lib/python3.11/site-packages/torch/jit/frontend.py:195: in build_stmts                                                                                                                                                                                                                                                                             
    stmts = [build_stmt(ctx, s) for s in stmts]                                                                                                                                                                                                                                                                                                                           
../../.tox/torch-tests/lib/python3.11/site-packages/torch/jit/frontend.py:195: in <listcomp>                                                                                                                                                                                                                                                                              
    stmts = [build_stmt(ctx, s) for s in stmts]                                                                                                                                                                                                                                                                                                                           
../../.tox/torch-tests/lib/python3.11/site-packages/torch/jit/frontend.py:406: in __call__                                                                                                                                                                                                                                                                                
    return method(ctx, node)                                                                                                                                                                                                                                                                                                                                              
../../.tox/torch-tests/lib/python3.11/site-packages/torch/jit/frontend.py:782: in build_If                                                                                                                                                                                                                                                                                
    build_stmts(ctx, stmt.body),                                                                                                                                                                                                                                                                                                                                          
../../.tox/torch-tests/lib/python3.11/site-packages/torch/jit/frontend.py:195: in build_stmts                                                                                                                                                                                                                                                                             
    stmts = [build_stmt(ctx, s) for s in stmts]                                                                                                                                                                                                                                                                                                                           
../../.tox/torch-tests/lib/python3.11/site-packages/torch/jit/frontend.py:195: in <listcomp>                                                                                                                                                                                                                                                                              
    stmts = [build_stmt(ctx, s) for s in stmts]                                                                                                                                                                                                                                                                                                                           
../../.tox/torch-tests/lib/python3.11/site-packages/torch/jit/frontend.py:406: in __call__                                                                                                                                                                                                                                                                                
    return method(ctx, node)                                                                                                                                                                                                                                                                                                                                              
../../.tox/torch-tests/lib/python3.11/site-packages/torch/jit/frontend.py:781: in build_If                                                                                                                                                                                                                                                                                
    build_expr(ctx, stmt.test),                                                                                                                                                                                                                                                                                                                                           
../../.tox/torch-tests/lib/python3.11/site-packages/torch/jit/frontend.py:406: in __call__                                                                                                                                                                                                                                                                                
    return method(ctx, node)                                                                                                                                                                                                                     
../../.tox/torch-tests/lib/python3.11/site-packages/torch/jit/frontend.py:971: in build_UnaryOp                                                                                                                                                  
    sub_expr = build_expr(ctx, expr.operand)                                                                                                                                                                                                     
../../.tox/torch-tests/lib/python3.11/site-packages/torch/jit/frontend.py:406: in __call__                                                                                                                                                       
    return method(ctx, node)                                                                                                                                                                                                                     
../../.tox/torch-tests/lib/python3.11/site-packages/torch/jit/frontend.py:890: in build_Call                                                                                                                                                     
    args = [build_expr(ctx, py_arg) for py_arg in expr.args]                                                                                                                                                                                     
../../.tox/torch-tests/lib/python3.11/site-packages/torch/jit/frontend.py:890: in <listcomp>                                                                                                                                                     
    args = [build_expr(ctx, py_arg) for py_arg in expr.args]                                                                                                                                                                                     
../../.tox/torch-tests/lib/python3.11/site-packages/torch/jit/frontend.py:406: in __call__                                                                                                                                                       
    return method(ctx, node)                                                                                                                                                                                                                     
../../.tox/torch-tests/lib/python3.11/site-packages/torch/jit/frontend.py:1226: in build_ListComp                                                                                                                                                
    elt_expr = build_expr(ctx, stmt.elt)                                                                                                                                                                                                         
../../.tox/torch-tests/lib/python3.11/site-packages/torch/jit/frontend.py:406: in __call__                                                                                                                                                       
    return method(ctx, node)                                                                                                                                                                                                                     
../../.tox/torch-tests/lib/python3.11/site-packages/torch/jit/frontend.py:1014: in build_Compare                                                                                                                                                 
    operands = [build_expr(ctx, e) for e in [expr.left] + list(expr.comparators)]                                                                                                                                                                
../../.tox/torch-tests/lib/python3.11/site-packages/torch/jit/frontend.py:1014: in <listcomp>                                                                                                                                                    
    operands = [build_expr(ctx, e) for e in [expr.left] + list(expr.comparators)]                                                                                                                                                                
../../.tox/torch-tests/lib/python3.11/site-packages/torch/jit/frontend.py:406: in __call__                                                                                                                                                       
    return method(ctx, node)                                                                                                                                                                                                                     
../../.tox/torch-tests/lib/python3.11/site-packages/torch/jit/frontend.py:1140: in build_List                                                                                                                                                    
    [build_expr(ctx, e) for e in expr.elts],                                                                                                                                                                                                     
../../.tox/torch-tests/lib/python3.11/site-packages/torch/jit/frontend.py:1140: in <listcomp>                                                                                                                                                    
    [build_expr(ctx, e) for e in expr.elts],                                                                                                                                                                                                     
../../.tox/torch-tests/lib/python3.11/site-packages/torch/jit/frontend.py:406: in __call__                                                                                                                                                       
    return method(ctx, node)                                                                                                                                                                                                                     
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _                                                            
                                                                                                                                                                                                                                                 
ctx = <torch._sources.SourceContext object at 0x7fbafe6433b0>, expr = <ast.UnaryOp object at 0x7fbafe6222f0>                                                                                                                                     
                                                                                                                                                                                                                                                 
    @staticmethod                                                                                                                                                                                                                                
    def build_UnaryOp(ctx, expr):                                                                                                                                                                                                                
        sub_expr = build_expr(ctx, expr.operand)                                                                                                                                                                                                 
        op = type(expr.op)                                                                                                                                                                                                                       
        op_token = ExprBuilder.unop_map.get(op)                                                                                                                                                                                                  
        if op_token is None:                                                                                                                                                                                                                     
            raise NotSupportedError(                                                                                                                                                                                                             
>               expr.range(), "unsupported unary operator: " + op.__name__                                                                                                                                                                       
            )                                                                                                                                                                                                                                    
E           AttributeError: 'UnaryOp' object has no attribute 'range'                                                                                                                                                                            
                                                                                                                                                                                                                                                 
../../.tox/torch-tests/lib/python3.11/site-packages/torch/jit/frontend.py:976: AttributeError                                                                                                                                                    

@agoscinski
Copy link
Collaborator Author

agoscinski commented Feb 7, 2024

Debugging is actually quite easy there is a ctx variable when going up the debug stack this one can use get the currently compiled source code and there is also somewhere in the stack a node variable which has a lineo variable that gives you the line number. In this case the bug was just a + in a list definition [-1, +1]. Solution [-1, 1].

@Luthaf
Copy link
Member

Luthaf commented Feb 7, 2024

That's very interesting, I did not know about the ctx! Could you show a quick example of how you use it?

@agoscinski
Copy link
Collaborator Author

I have problems with TorchScript supporting the type List[Union[None, Labels]], I came to the conclusion that it is not possible to support Union types for dynamically allocated lists containing both examples of types in the list. Its not really a big issue, I will instead try to use empty Labels to replace None, and change the check here https://github.com/Luthaf/rascaline/blob/daab90a700b05e81314b88f07260a92c2eecb54a/python/rascaline/rascaline/utils/clebsch_gordan/_clebsch_gordan.py#L212-L218

@Luthaf
Copy link
Member

Luthaf commented Feb 7, 2024

I have problems with TorchScript supporting the type List[Union[None, Labels]]

Did you try with an explicit List[Optional[Labels]]? I know TorchScript has some special treatment for Optional internally

@agoscinski
Copy link
Collaborator Author

Did you try with an explicit List[Optional[Labels]]? I know TorchScript has some special treatment for Optional internally

Its not that much that it does not recognize the type. It is more that it does not allow to create an object of this type using any list functionalities to concatenate lists. This works perfectly fine

def foo(a: Labels):
    l: List[Union[None, Labels]] = [None, a]
    return l
torch.jit.script(foo)

But you cannot do this

def foo(a: Labels):
    l: List[Union[None, Labels]] = ([None] * 5) + [a]
    return l
torch.jit.script(foo)

def foo(a: Labels):
    l: List[Union[None, Labels]] = [None, a]
    l = [None] * 5
    l.append(a)
    return l
torch.jit.script(foo)

def foo(a: Labels):
    l: List[Union[None, Labels]] = [None, a]
    l = [None] * 6
    l[5] = a
    return l
torch.jit.script(foo)

@Luthaf
Copy link
Member

Luthaf commented Feb 7, 2024

This seems to work:

import torch
from typing import List, Union


class Labels:
    pass


def foo(a: Labels):
    l: List[Union[None, Labels]] = [torch.jit.annotate(Union[None, Labels], None)] * 5
    l.append(a)
    return l


torch.jit.script(foo)

But you need to torch.jit.annotate the list element, and it does not work on the whole list.

EDIT:

this is fine as well:

def foo(a: Labels):
    l = [torch.jit.annotate(Union[None, Labels], None)] * 5
    l.append(a)
    return l

@agoscinski
Copy link
Collaborator Author

That's very interesting, I did not know about the ctx! Could you show a quick example of how you use it?

So this worked for fixing the bug that was fixed with the commit d8f9906

Tor reproduce this I pushed a branch debug-unary-op

git checkout debug-unary-op
tox -e torch-tests -- --pdb

The debugging went like this

nd.py:406: in __call__
    return method(ctx, node)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

ctx = <torch._sources.SourceContext object at 0x7f557836f470>, expr = <ast.UnaryOp object at 0x7f557834b310>

    @staticmethod
    def build_UnaryOp(ctx, expr):
        sub_expr = build_expr(ctx, expr.operand)
        op = type(expr.op)
        op_token = ExprBuilder.unop_map.get(op)
        if op_token is None:
            raise NotSupportedError(
>               expr.range(), "unsupported unary operator: " + op.__name__
            )
E           AttributeError: 'UnaryOp' object has no attribute 'range'

../../.tox/torch-tests/lib/python3.11/site-packages/torch/jit/frontend.py:976: AttributeError
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> entering PDB >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>

>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> PDB post_mortem (IO-capturing turned off) >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
> /home/alexgo/code/rascaline/.tox/torch-tests/lib/python3.11/site-packages/torch/jit/frontend.py(976)build_UnaryOp()
-> expr.range(), "unsupported unary operator: " + op.__name__
(Pdb) up
> /home/alexgo/code/rascaline/.tox/torch-tests/lib/python3.11/site-packages/torch/jit/frontend.py(406)__call__()
-> return method(ctx, node)
(Pdb) print(ctx.source)
def _parse_selected_keys(
    n_iterations: int,
    angular_cutoff: Optional[int] = None,
    selected_keys: Optional[Union[Labels, List[Labels]]] = None,
    like=None,
) -> List[Union[None, Labels]]:
    """
    Parses the `selected_keys` argument passed to public functions. Checks the
    values and returns a :py:class:`list` of :py:class:`Labels` objects, one for
    each iteration of CG combination.

    `like` is required if a new :py:class:`Labels` object is to be created by
    :py:mod:`_dispatch`.
    """
    # Check angular_cutoff arg
    if angular_cutoff is not None:
        if not isinstance(angular_cutoff, int):
            raise TypeError("`angular_cutoff` must be passed as an int")
        if angular_cutoff < 1:
            raise ValueError("`angular_cutoff` must be >= 1")

    if selected_keys is None:
        if angular_cutoff is None:  # no selections at all
            selected_keys = [None] * n_iterations
        else:
            # Create a key selection with all angular channels <= the specified
            # angular cutoff
            selected_keys = [
                Labels(
                    names=["spherical_harmonics_l"],
                    values=_dispatch.int_range_like(
                        0, angular_cutoff, like=like
                    ).reshape(-1, 1),
                )
            ] * n_iterations

    if isinstance(selected_keys, Labels):
        # Create a list, but only apply a key selection at the final iteration
        selected_keys = [None] * (n_iterations - 1) + [selected_keys]

    # Check the selected_keys
    if not isinstance(selected_keys, List):
        raise TypeError(
            "`selected_keys` must be a `Labels` or List[Union[None, `Labels`]]"
        )
    if not len(selected_keys) == n_iterations:
        raise ValueError(
            "`selected_keys` must be a List[Union[None, Labels]] of length"
            " `correlation_order` - 1"
        )
    if not _dispatch.all(
        [isinstance(val, (Labels, type(None))) for val in selected_keys]
    ):
        raise TypeError("`selected_keys` must be a Labels or List[Union[None, Labels]]")

    # Now iterate over each of the Labels (or None) in the list and check
    for slct in selected_keys:
        if slct is None:
            continue
        assert isinstance(slct, Labels)
        if not _dispatch.all(
            [
                name in ["spherical_harmonics_l", "inversion_sigma"]
                for name in slct.names
            ]
        ):
            raise ValueError(
                "specified key names in `selected_keys` must be either"
                " 'spherical_harmonics_l' or 'inversion_sigma'"
            )
        if "spherical_harmonics_l" in slct.names:
            if angular_cutoff is not None:
                if not _dispatch.all(
                    slct.column("spherical_harmonics_l") <= angular_cutoff
                ):
                    raise ValueError(
                        "specified angular channels in `selected_keys` must be <= the"
                        " specified `angular_cutoff`"
                    )
            if not _dispatch.all(
                [angular_l >= 0 for angular_l in slct.column("spherical_harmonics_l")]
            ):
                raise ValueError(
                    "specified angular channels in `selected_keys` must be >= 0"
                )
        if "inversion_sigma" in slct.names:
            if not _dispatch.all(
                [parity_s in [-1, +1] for parity_s in slct.column("inversion_sigma")]
            ):
                raise ValueError(
                    "specified parities in `selected_keys` must be -1 or +1"
                )

    return selected_keys
(Pdb) node.lineno
88

From that I could infer that it is the line 88 in the printed source code and since the operation that caused an error was an add, I inferred it was the +1

@agoscinski
Copy link
Collaborator Author

agoscinski commented Feb 9, 2024

Something very slightly annoying not sure if worth an issue, since it could potentially be easy fix

RuntimeError: 
Variable 'le' is annotated with type __torch__.torch.classes.metatensor.LabelsEntry (of Python compilation unit at: 0) but is being assigned to a value of type Any:
  File "<ipython-input-379-f5b994008d5f>", line 2
def foo(l:Labels):
    le: LabelsEntry = l[0]
    ~~ <--- HERE
    return le.values

workaround

def foo(l:Labels):
    le = l.values[0]
    return le
torch.jit.script(foo)

EDIT:
@Luthaf mentioned that one can do

def foo(l:Labels):
    le: LabelsEntry = l.entry(0)
    return le.values
torch.jit.script(foo)

@agoscinski
Copy link
Collaborator Author

this breaks TorchScript (segmentation fault)

def foo():
    f: List[List[Any]] = [torch.jit.annotate(Any,[5,1])]
    return f

torch.jit.script(foo)

@agoscinski
Copy link
Collaborator Author

agoscinski commented Feb 9, 2024

So TorchScript seems to have problems associating List[*] types to Any

def foo():
    l: List[Any] = [5,3]
torch.jit.script(foo) # works

def foo():
    l: List[List[Any]] = [[5,3]]
torch.jit.script(foo) # does not work

def foo():
    l: List[Any] = [[5,3]]
torch.jit.script(foo) # works

def foo():
    l: List[List[int]] = [[5,3]]
    return foo2(l)
def foo2(l: List[Any]):
    return l
torch.jit.script(foo) # does not work

@agoscinski
Copy link
Collaborator Author

agoscinski commented Feb 11, 2024

Building seems to fail because of required rustc update.

    error: package `ahash v0.8.8` cannot be built because it requires rustc 1.72.0 or newer, while the currently active rustc version is 1.65.0
    Either upgrade to rustc 1.72.0 or newer, or use
    cargo update -p [email protected] --precise ver

The documentation fails because I use torch.jit.script in the torch clebsch utils for scripting is_labels the instance check function for Labels.

autodoc: failed to import class 'torch.CalculatorModule' from module 'rascaline'; the following exception was raised:
Traceback (most recent call last):
  File "/home/runner/work/rascaline/rascaline/.tox/docs/lib/python3.11/site-packages/sphinx/ext/autodoc/importer.py", line 69, in import_module
    return importlib.import_module(modname)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/hostedtoolcache/Python/3.11.7/x64/lib/python3.11/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<frozen importlib._bootstrap>", line 1204, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1176, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1147, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 690, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 940, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/home/runner/work/rascaline/rascaline/.tox/docs/lib/python3.11/site-packages/rascaline/torch/__init__.py", line 12, in <module>
    from . import utils  # noqa
    ^^^^^^^^^^^^^^^^^^^
  File "/home/runner/work/rascaline/rascaline/.tox/docs/lib/python3.11/site-packages/rascaline/torch/utils/__init__.py", line 3, in <module>
    from . import clebsch_gordan
  File "/home/runner/work/rascaline/rascaline/.tox/docs/lib/python3.11/site-packages/rascaline/torch/utils/clebsch_gordan.py", line 36, in <module>
    is_labels = torch.jit.script(is_labels)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/runner/work/rascaline/rascaline/.tox/docs/lib/python3.11/site-packages/torch/jit/_script.py", line 1395, in script
    fn = torch._C._jit_script_compile(
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/runner/work/rascaline/rascaline/.tox/docs/lib/python3.11/site-packages/torch/jit/annotations.py", line 501, in try_ann_to_type
    return torch.jit._script._recursive_compile_class(ann, loc)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/runner/work/rascaline/rascaline/.tox/docs/lib/python3.11/site-packages/torch/jit/_script.py", line 1568, in _recursive_compile_class
    return _compile_and_register_class(obj, rcb, _qual_name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/runner/work/rascaline/rascaline/.tox/docs/lib/python3.11/site-packages/torch/jit/_recursive.py", line 61, in _compile_and_register_class
    script_class = torch._C._jit_script_class_compile(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/runner/work/rascaline/rascaline/.tox/docs/lib/python3.11/site-packages/torch/jit/annotations.py", line 453, in try_ann_to_type
    maybe_type = try_ann_to_type(a, loc)
reading sources... [ 89%] references/calculators/index
reading sources... [ 90%] references/calculators/lode-spherical-expansion
reading sources... [ 91%] references/calculators/neighbor-list
reading sources... [ 93%] references/calculators/soap-power-spectrum
reading sources... [ 94%] references/calculators/soap-radial-spectrum
reading sources... [ 96%] references/calculators/sorted-distances
reading sources... [ 97%] references/calculators/spherical-expansion
reading sources... [ 99%] references/calculators/spherical-expansion-by-pair
reading sources... [100%] references/index
                 ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/runner/work/rascaline/rascaline/.tox/docs/lib/python3.11/site-packages/torch/jit/annotations.py", line 415, in try_ann_to_type
    return TupleType([try_ann_to_type(a, loc) for a in ann_args])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Can not create tuple with None type

Don't know yet how to fix this documentation error.
EDIT: I think I found how it is solved in metatensor-operations
https://github.com/lab-cosmo/metatensor/blob/5cfdbcdcefe850481d3b4abadc8385be64b70e5f/python/metatensor-torch/metatensor/torch/__init__.py#L9-L12

test torch array backend and add required dispatches

add _classes.py forgot to push

add torchscript test for correlate_density

start torchscript test

fixing TorchScript UnaryOp bug

checkpoint: all-deps test pass and made progress on TorchScriptabilty

checkpoint2: all-deps test pass and made progress on TorchScriptabilty

checkpoint3: all-deps test pass and made progress on TorchScriptabilty

checkpoint4: all-deps fails, invalid key pair (l1=0,l2=1,lam=2) is accessed in sparse_combine

checkpoint5: all-deps test pass and made progress on TorchScriptabilty

checkpoint6: all-deps test pass and made progress on TorchScriptabilty

- abstracted out Dict[Tuple[int,int,int], Array] and
  Dict[Tuple[int,int,int], Dict[Tuple[int,int,int], Array]] into custom
  classes with utilities that allow similar access

checkpoint7: all-deps test pass and made progress on TorchScriptabilty

- sparse property in ClebschGordanReal is determined by type of coeffs
  because we need to do anyway isinstance checks for TorchScript to
  distinguish the two types in the different functions
  Union[SparseCgDict, DenseCgDict] -> SparseCgDict
- replacing input parameter `return_empty_array` in `sparse combine`
  function in _cg_cache.py by `empty_combine` function that to distinguish
  None type for TorchScript

checkpoint8: all-deps test pass and made progress on TorchScriptabilty

- Labels.insert cannot broadcast in metatensor.torch, correct shape has to be given
- torch.tensors of type int32 cannot be converted to lists using tolist() pytorch/pytorch#76295 therefore added tolist dispatch function that first converts the array to int64

checkpoint9: all-deps test pass and made progress on TorchScriptabilty, torch-test test_torch_script_correlate_density_angular_selection passes

- changed to comlex type until real conversion otherwise (complex, real) operation

checkpoint10: all-deps test pass and made progress on TorchScriptabilty, torch-test test_torch_script_correlate_density_angular_selection passes

- made `like` parameter in _parse_selected_keys

all-deps test and torch-tests pass

- all isinstance check are moved to the false branch of jit.is_scripting

linting fixes and partial format

fix doctest

format

remove _dispatch.array_like

during documentation building the jit scripting of functions needs to be disabled

integrating labels_array_like into int_array_like

cleaned code, removed unnecessary dispatch operations

remove TODOs in cg_cache
@agoscinski
Copy link
Collaborator Author

agoscinski commented Feb 15, 2024

I document here a behavior of TorchScript which seems to me like a bug. This code here to deal with the TorchScript inferring a List[None] as List[Labels, None] (see discussion above)

    @property
    def selected_keys(self) -> List[Union[Labels, None]]:
        if torch_jit_is_scripting():
            if torch.jit.isinstance(self._selected_keys, List[Union[Labels, None]]):
                return self._selected_keys
            else:
                selected_keys_: List[Union[None, Labels]] = [
                    torch_jit_annotate(Union[None, Labels], None)
                ] * len(self._selected_keys)
                return selected_keys_
        return self._selected_keys

returned me a weird object instead of a List of Labels.

(Pdb) scripted_corr_calculator.selected_keys                                                                                                                                                                             
<property object at 0x7f8d8d83d710>                                                                                                                                                                                      
(Pdb) scripted_corr_calculator.selected_keys.__dir__()                                                                                                                                                                   
['__new__', '__getattribute__', '__get__', '__set__', '__delete__', '__init__', 'getter', 'setter', 'deleter', '__set_name__', 'fget', 'fset', 'fdel', '__doc__', '__isabstractmethod__', '__repr__', '__hash__', '__str_
_', '__setattr__', '__delattr__', '__lt__', '__le__', '__eq__', '__ne__', '__gt__', '__ge__', '__reduce_ex__', '__reduce__', '__getstate__', '__subclasshook__', '__init_subclass__', '__format__', '__sizeof__', '__dir_
_', '__class__']                                                                                                                                                                                                         

This is the variable in the non scripted one

(Pdb) corr_calculator.selected_keys                                                                                                                                                                                      
[Labels(
    spherical_harmonics_l
              1
              3
)]

In the first place I don't know why TorchScript complains when returning the variable inside a class method. The function that computes the value was totally fine working before.

Changing the type to Union[List[Union[Labels, None]], List[None], List[Labels]] seems to work.

UPDATE:

Changing the type to Union[List[Union[Labels, None]], List[None], List[Labels]] causes problems later in the isinstance branching during runtime. I think it makes sense, because a compiler needs to change the list type which might be problematic. This is the code I tried

        if torch_jit_is_scripting():
            if torch_jit_isinstance(selected_keys, List[None]):
                pass
            elif torch_jit_isinstance(selected_keys, List[Labels]):
                keys_1_entries, keys_2_entries, keys_out = _apply_key_selection(
                    keys_1_entries,
                    keys_2_entries,
                    keys_out,
                    selected_keys=selected_keys[iteration],
                )
            elif torch_jit_isinstance(selected_keys, List[Union[Labels, None]]):
                selected_keys_i = selected_keys[iteration]
                if selected_keys_i is not None:
                    keys_1_entries, keys_2_entries, keys_out = _apply_key_selection(
                        keys_1_entries,
                        keys_2_entries,
                        keys_out,
                        selected_keys=selected_keys_i,
                    )
            else:
                raise TypeError("Assumed object of type List[None], List[Labels] or "
                        "List[Union[Labels, None]] but got object {selected_keys}")
        else:
            selected_keys_i = selected_keys[iteration]
            if selected_keys_i is not None:
                keys_1_entries, keys_2_entries, keys_out = _apply_key_selection(
                    keys_1_entries,
                    keys_2_entries,
                    keys_out,
                    selected_keys=selected_keys_i,
                )

UPDATE

Seems like the torch_jit_annotate only works properly while scripted. Since we do not script it anymore
In the _parse_selected_keys function I did the type annotation correctly for the `torch_jit_is_scripting path"

WRONG

EDIT
So any kind of jit annotate I applied in the __init__ did not have any effect and was ignored by TorchScript, so use internally Labels.empty("_") to represent Nones, and on access of the public API (correlation_calculator.selected_keys) I convert the empty Labels to Nones back.

@agoscinski
Copy link
Collaborator Author

I will go on with the way I wrote in the end of last message. If someone finds a way to make it work please let me know. Here is the minimal breaking example

import torch
from typing import Union, List

class A(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._foo: List[Union[None, str]] = [torch.jit.annotate(Union[None, str], None)]
    @property
    def foo(self) -> List[Union[None, str]]:
        return self._foo


aobj = A()
print(torch.jit.script(aobj).foo)

@Luthaf
Copy link
Member

Luthaf commented Feb 15, 2024

You need to annotate at the class level:

class A(torch.nn.Module):
    _foo: List[Optional[str]]

    def __init__(self):
        super().__init__()
        self._foo = [None]

    @property
    def foo(self) -> List[Optional[str]]:
        return self._foo

From Dict of Tuple[int,int,int] to TensorMap
all-deps, all-deps-torch are passing
Copy link
Member

@jwa7 jwa7 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking really great thank you @agoscinski ! Only requested some minor changes for now

Copy link
Collaborator Author

@agoscinski agoscinski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jwa7 and me had a small discussion. In a future PR we would make max_angular for the init for DensityCorrelations optional if angular_cutoff is specified.

Copy link
Member

@jwa7 jwa7 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The work Alex did looks great to me, and I made some changes on top. As you weren't involved @Luthaf please may you provide an impartial review? :D

@Luthaf
Copy link
Member

Luthaf commented Feb 16, 2024

Please also add the public CG classes to the API reference documentation!

@jwa7 jwa7 requested a review from Luthaf February 19, 2024 16:58
python/rascaline-torch/tests/utils/correlate_density.py Outdated Show resolved Hide resolved
python/rascaline-torch/tests/utils/correlate_density.py Outdated Show resolved Hide resolved
python/rascaline-torch/tests/utils/data/h2o_isolated.xyz Outdated Show resolved Hide resolved
python/rascaline/rascaline/utils/_dispatch.py Outdated Show resolved Hide resolved
@jwa7 jwa7 requested a review from Luthaf February 21, 2024 13:39
Copy link
Member

@Luthaf Luthaf left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll merge once CI is happy!

@jwa7 jwa7 merged commit ddd8802 into master Feb 21, 2024
25 checks passed
@jwa7 jwa7 deleted the cg-torchscript branch February 21, 2024 17:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants