Skip to content

Commit

Permalink
Raise pytorch dependency to 2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
martenlienen committed Nov 10, 2023
1 parent 1218aea commit f10e2b3
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 20 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## 0.2.0 - 2023-11-10

### Changed

- Removed compatibility with pytorch 1.x

## 0.1.9 - 2023-08-29

### Added
Expand Down
3 changes: 0 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,6 @@ step_size_controller = to.IntegralController(atol=1e-6, rtol=1e-3, term=term)
solver = to.AutoDiffAdjoint(step_method, step_size_controller)
jit_solver = torch.compile(solver)

# For pytorch versions < 2.0, use the older TorchScript compiler
#jit_solver = torch.jit.script(solver)

sol = jit_solver.solve(to.InitialValueProblem(y0=y0, t_eval=t_eval))
print(sol.stats)
# => {'n_f_evals': tensor([26, 26]), 'n_steps': tensor([4, 2]),
Expand Down
5 changes: 1 addition & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,8 @@ classifiers = [
]
dependencies = [
"sympy ~= 1.10",
"torch >= 1.11",
"torch >= 2.0",
"torchtyping ~= 0.1.4",
# functorch has been integrated into pytorch and this is just a dummy dependency for
# compatibility with pytorch <= 1.13
"functorch",
]

[project.optional-dependencies]
Expand Down
17 changes: 5 additions & 12 deletions tests/torch_script_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,12 @@
torch_version = version.parse(torch.__version__)


def compile(module):
if torch_version.major < 2:
return torch.jit.script(module)
else:
return torch.compile(module)


@pytest.mark.parametrize("step_method", [Dopri5, Heun, Tsit5, Euler])
def test_can_be_jitted_with_torch_script(step_method):
_, term, problem = get_problem("sine", [[0.1, 0.15, 1.0], [1.0, 1.9, 2.0]])
step_size_controller = IntegralController(1e-3, 1e-3, term=term)
adjoint = AutoDiffAdjoint(step_method(term), step_size_controller)
jitted = compile(adjoint)
jitted = torch.compile(adjoint)

dt0 = torch.tensor([0.01, 0.01]) if step_method is Euler else None
solution = adjoint.solve(problem, dt0=dt0)
Expand All @@ -33,9 +26,9 @@ def test_can_be_jitted_with_torch_script(step_method):

methods = [Dopri5, Heun, Tsit5]
v = torch_version
if (v.major, v.minor) not in [(1, 13), (2, 0)]:
# In pytorch 1.13.0 and 2.0, Euler triggers an internal error in the JIT compiler
# specifically in this next test, so we just exclude it
if (v.major, v.minor) not in [(2, 0)]:
# In pytorch 2.0, Euler triggers an internal error in the JIT compiler specifically
# in this next test, so we just exclude it
methods.append(Euler)


Expand All @@ -51,7 +44,7 @@ def test_passing_term_dynamically_equals_fixed_term(step_method):

controller_jit = IntegralController(1e-3, 1e-3, term=term)
adjoint_jit = AutoDiffAdjoint(step_method(term), controller_jit)
solution_jit = compile(adjoint_jit).solve(problem, dt0=dt0)
solution_jit = torch.compile(adjoint_jit).solve(problem, dt0=dt0)

assert solution.ts == approx(solution_jit.ts)
assert solution.ys == approx(solution_jit.ys, abs=1e-3, rel=1e-3)
2 changes: 1 addition & 1 deletion torchode/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""A parallel ODE solver for PyTorch"""

__version__ = "0.1.9"
__version__ = "0.2.0"

from .adjoints import AutoDiffAdjoint, BacksolveAdjoint, JointBacksolveAdjoint
from .interface import register_method, solve_ivp
Expand Down

0 comments on commit f10e2b3

Please sign in to comment.