Pytorch backend slow with pymc model #1110

Open
Ch0ronomato opened this issue Nov 30, 2024 · 6 comments

@Ch0ronomato
Contributor

Ch0ronomato commented Nov 30, 2024

Description

@ricardoV94 made a nice perf improvement in pymc-devs/pymc#7578 to speed up the jitted backends. I tried the PyTorch backend as well, and the model ran quite slowly:

mode                  t_sampling (s)   manual measure (s)
NUMBA                 2.483            11.346
PYTORCH (COMPILED)    206.503          270.188
PYTORCH (EAGER)       60.607           64.140

We need to investigate why:

  1. the PyTorch backend is so much slower than Numba, and
  2. torch.compile is slower than eager mode.

When doing perf evaluations, keep in mind that torch does a lot of caching. For a truly cache-free evaluation, either call torch.compiler.reset() between runs or set the environment variable that disables the Dynamo cache.
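A minimal sketch of a cold-start measurement using only the standard library. `time_cold` is a hypothetical helper; passing `torch.compiler.reset` as the `reset` hook (an assumption about how you would wire it up) clears Dynamo's compiled-graph cache before each call, so every timing includes compilation.

```python
import time
from typing import Any, Callable, List, Optional


def time_cold(
    fn: Callable[..., Any],
    *args: Any,
    reset: Optional[Callable[[], None]] = None,
    runs: int = 3,
) -> List[float]:
    """Time `runs` calls of fn(*args), calling `reset` before each one.

    With reset=torch.compiler.reset (assumed usage), each call starts
    without cached compiled graphs, so every timing includes compilation.
    """
    timings: List[float] = []
    for _ in range(runs):
        if reset is not None:
            reset()  # clear caches before this measurement (not timed)
        start = time.perf_counter()
        fn(*args)
        timings.append(time.perf_counter() - start)
    return timings
```

For example `time_cold(logp_fn, point, reset=torch.compiler.reset)`, where `logp_fn` and `point` are placeholders for a compiled PyTensor function and its input. The reset itself is kept outside the timed region so only the call is measured.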

@Ch0ronomato
Contributor Author

The script I used for reference

import arviz as az
import numpy as np
import multiprocessing
import pandas as pd
import pymc as pm
import pytensor as pt
import pytensor.tensor.random as ptr
import time

def main():
    # Load the radon dataset
    data = pd.read_csv(pm.get_data("radon.csv"))
    data["log_radon"] = data["log_radon"].astype(np.float64)
    county_idx, counties = pd.factorize(data.county)
    coords = {"county": counties, "obs_id": np.arange(len(county_idx))}

    # Create a simple hierarchical model for the radon dataset
    with pm.Model(coords=coords, check_bounds=False) as model:
        intercept = pm.Normal("intercept", sigma=10)

        # County effects
        raw = pm.ZeroSumNormal("county_raw", dims="county")
        sd = pm.HalfNormal("county_sd")
        county_effect = pm.Deterministic("county_effect", raw * sd, dims="county")

        # Global floor effect
        floor_effect = pm.Normal("floor_effect", sigma=2)

        # County:floor interaction
        raw = pm.ZeroSumNormal("county_floor_raw", dims="county")
        sd = pm.HalfNormal("county_floor_sd")
        county_floor_effect = pm.Deterministic(
            "county_floor_effect", raw * sd, dims="county"
        )

        mu = (
            intercept
            + county_effect[county_idx]
            + floor_effect * data.floor.values
            + county_floor_effect[county_idx] * data.floor.values
        )

        sigma = pm.HalfNormal("sigma", sigma=1.5)
        pm.Normal(
            "log_radon", mu=mu, sigma=sigma, observed=data.log_radon.values, dims="obs_id"
        )

    # Freeze dims and data into constants so the compiled graphs see static shapes
    from pymc.model.transform.optimization import freeze_dims_and_data
    model = freeze_dims_and_data(model)
    for mode in ("NUMBA", "PYTORCH"):
        start = time.perf_counter()
        trace = pm.sample(
            model=model, 
            cores=1,
            chains=1,
            tune=500, 
            draws=500, 
            progressbar=False, 
            compute_convergence_checks=False, 
            return_inferencedata=False,
            compile_kwargs=dict(mode=mode)
        )
        end = time.perf_counter()
        idata = pm.to_inference_data(trace, model=model)
        print(az.summary(idata, kind="diagnostics"))
        print(mode, trace._report.t_sampling, end - start)


if __name__ == "__main__":
    multiprocessing.freeze_support()
    main()

@Ch0ronomato
Contributor Author

I also compiled logp and dlogp separately to narrow down where the time goes:

mode      method   time (s)
PYTORCH   logp     5.891
PYTORCH   dlogp    4.054
NUMBA     logp     2.376
NUMBA     dlogp    2.244

@ricardoV94
Member

ricardoV94 commented Nov 30, 2024

You have to call it once to trigger the JIT compilation (perhaps asserting that both backends produce the same output), and only then time it with timeit.
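That procedure as a minimal sketch: one untimed warm-up call that compiles and is checked against a reference (for example the Numba output), then `timeit.repeat` on the steady-state calls. `time_warm` is a hypothetical helper; for array outputs pass something like `numpy.allclose` as `check`, since the default `==` is only meant for scalars.

```python
import timeit
from typing import Any, Callable, Optional, Sequence


def time_warm(
    fn: Callable[..., Any],
    args: Sequence[Any],
    reference: Optional[Any] = None,
    check: Callable[[Any, Any], bool] = lambda a, b: a == b,
    number: int = 100,
    repeat: int = 5,
) -> float:
    """Return the best per-call time of fn(*args), excluding the first call.

    The first call triggers JIT compilation (Numba or torch.compile) and
    is used only to verify the output against `reference`.
    """
    result = fn(*args)  # warm-up call: compiles, not timed
    if reference is not None and not check(result, reference):
        raise ValueError("backend output does not match the reference")
    # timeit.repeat returns the total time of `number` calls per repetition
    totals = timeit.repeat(lambda: fn(*args), number=number, repeat=repeat)
    return min(totals) / number
```

Usage would look like `time_warm(pytorch_logp, [point], reference=numba_logp(point), check=numpy.allclose)` with hypothetical names. Taking the minimum over repetitions follows the usual timeit advice, since slower repetitions mostly reflect interference from other processes. In IPython, `%timeit` after a first manual call does the same job.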

@Ch0ronomato
Contributor Author

@ricardoV94 can you assign this issue to me by chance?

I profiled a bit more. The logp and dlogp PyTensor functions don't take long to generate, but executing them is slower. With both Numba and torch, repeated calls get faster after the first one, but Numba is still much faster:

mode      method       time (s)
PYTORCH   logp.call    4.8469
PYTORCH   logp.call    0.0013
PYTORCH   logp.call    0.0008
PYTORCH   dlogp.call   4.1302
PYTORCH   dlogp.call   0.0016
PYTORCH   dlogp.call   0.0014
NUMBA     logp.call    2.8590
NUMBA     logp.call    0.0001
NUMBA     logp.call    0.0001
NUMBA     dlogp.call   3.7824
NUMBA     dlogp.call   0.0059
NUMBA     dlogp.call   0.0001

I'm seeing some interesting data in the .explain call: it looks like join and alloc cause a few different graph breaks, and there is a lot of data-dependent behavior as well. The fact that the runtime drops when logp is called repeatedly makes me think there is more recompilation happening. The full dump of the torch compile logs is attached; here is the header that outlines what happened to the graph:

Graph Count: 8
Graph Break Count: 7
Op Count: 198
Break Reasons:
  Break Reason 1:
    Reason: data dependent operator: aten._local_scalar_dense.default; to enable, set torch._dynamo.config.capture_scalar_outputs = True
    User Stack:
      <FrameSummary file /var/folders/2z/y4z6q41j7l93ysr2y2k_gsnh0000gn/T/tmpxndm7s5l, line 35 in pytorch_funcified_fgraph>
      <FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/link/pytorch/dispatch/basic.py, line 129 in join>
  Break Reason 2:
    Reason: data dependent operator: aten._local_scalar_dense.default; to enable, set torch._dynamo.config.capture_scalar_outputs = True
    User Stack:
      <FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/link/pytorch/dispatch/basic.py, line 129 in join>
  Break Reason 3:
    Reason: data dependent operator: aten._local_scalar_dense.default; to enable, set torch._dynamo.config.capture_scalar_outputs = True
    User Stack:
      <FrameSummary file /var/folders/2z/y4z6q41j7l93ysr2y2k_gsnh0000gn/T/tmpxndm7s5l, line 75 in torch_dynamo_resume_in_pytorch_funcified_fgraph_at_35>
      <FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/link/pytorch/dispatch/basic.py, line 129 in join>
  Break Reason 4:
    Reason: Dynamic slicing on data-dependent value is not supported
    User Stack:
      <FrameSummary file /var/folders/2z/y4z6q41j7l93ysr2y2k_gsnh0000gn/T/tmpxndm7s5l, line 133 in torch_dynamo_resume_in_pytorch_funcified_fgraph_at_75>
      <FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/link/pytorch/dispatch/subtensor.py, line 77 in inc_subtensor>
      <FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/tensor/subtensor.py, line 142 in indices_from_subtensor>
      <FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/tensor/subtensor.py, line 142 in <genexpr>>
      <FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/tensor/subtensor.py, line 131 in convert_indices>
  Break Reason 5:
    Reason: Dynamic slicing on data-dependent value is not supported
    User Stack:
      <FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/link/pytorch/dispatch/subtensor.py, line 81 in torch_dynamo_resume_in_inc_subtensor_at_78>
  Break Reason 6:
    Reason: Dynamic slicing on data-dependent value is not supported
    User Stack:
      <FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/link/pytorch/dispatch/subtensor.py, line 81 in torch_dynamo_resume_in_inc_subtensor_at_81>
  Break Reason 7:
    Reason: Dynamic slicing on data-dependent value is not supported
    User Stack:
      <FrameSummary file /var/folders/2z/y4z6q41j7l93ysr2y2k_gsnh0000gn/T/tmpxndm7s5l, line 193 in torch_dynamo_resume_in_pytorch_funcified_fgraph_at_133>
      <FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/link/pytorch/dispatch/subtensor.py, line 77 in inc_subtensor>
      <FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/tensor/subtensor.py, line 142 in indices_from_subtensor>
      <FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/tensor/subtensor.py, line 142 in <genexpr>>
      <FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/tensor/subtensor.py, line 131 in convert_indices>

explaination.txt
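To see at a glance which causes dominate, the break reasons in a dump like the one above can be tallied. A small sketch assuming that text layout (one `Reason:` line per break); `summarize_breaks` is a hypothetical helper, not part of PyTorch.

```python
import re
from collections import Counter


def summarize_breaks(explain_text: str) -> Counter:
    """Tally graph-break reasons from the text of an .explain dump.

    Assumes one "Reason: ..." line per break; any trailing
    "; to enable, ..." hint is dropped so identical reasons group together.
    """
    reasons = re.findall(
        r"^\s*Reason:\s*(.+?)\s*(?:;|$)", explain_text, flags=re.MULTILINE
    )
    return Counter(reasons)
```

On the header above it counts 3 breaks for `data dependent operator: aten._local_scalar_dense.default` and 4 for `Dynamic slicing on data-dependent value is not supported`.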

The first thing I'm going to do is clean up some of the warnings. These timings were only possible because I redirected the warning logs to /dev/null; without that, the timings balloon a bit.
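Until the warnings are cleaned up, they can be silenced in-process instead of redirected to /dev/null. A minimal sketch, assuming the noise comes through Python's `warnings` module or standard `logging` loggers; `quiet` is a hypothetical helper and the default logger names `"pytensor"` and `"torch"` are assumptions.

```python
import contextlib
import logging
import warnings
from typing import Iterator, Sequence


@contextlib.contextmanager
def quiet(logger_names: Sequence[str] = ("pytensor", "torch")) -> Iterator[None]:
    """Suppress Python warnings and the named loggers inside the block.

    Child loggers that set their own level are not affected by the
    parent's level, so they may need to be listed explicitly.
    """
    loggers = [logging.getLogger(name) for name in logger_names]
    saved = [lg.level for lg in loggers]
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")  # drop warnings.warn(...) output
        for lg in loggers:
            lg.setLevel(logging.CRITICAL + 1)  # above every standard level
        try:
            yield
        finally:
            for lg, level in zip(loggers, saved):
                lg.setLevel(level)  # restore the original levels
```

For example, wrap the `pm.sample(...)` call in `with quiet():`. This only hides the output; removing the warnings at the source is still the real fix.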

@ricardoV94
Member

ricardoV94 commented Dec 1, 2024

The first call is the compilation; it's less relevant since it's a one-time cost. If we're recompiling multiple times, that's a different matter.
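Whether later calls are recompiling can be eyeballed from per-call timings. A rough heuristic sketch (the helper name and the 10x threshold are arbitrary assumptions, not a PyTorch feature): skip the first call, which pays for compilation, and flag any later call far slower than the median of the rest. In recent PyTorch versions, running with `TORCH_LOGS="recompiles"` reports recompilations directly and is the more reliable check.

```python
import statistics
from typing import List, Sequence


def find_recompiles(timings: Sequence[float], factor: float = 10.0) -> List[int]:
    """Return indices of calls after the first that look like recompilations.

    Call 0 is assumed to include the initial compilation and is skipped.
    A later call is suspicious if it takes more than `factor` times the
    median of all later calls.
    """
    later = list(timings[1:])
    if not later:
        return []
    typical = statistics.median(later)
    return [i for i, t in enumerate(timings[1:], start=1) if t > factor * typical]
```

Fed the three PyTorch `logp.call` timings from the earlier table, it flags nothing, which is consistent with a single compilation on the first call.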

Also are you using %timeit after the first call? A single call is going to be noisy.

Did you confirm the outputs match?

@ricardoV94
Member

ricardoV94 commented Dec 1, 2024

Those breaks are interesting, but shouldn't most of them not be data-dependent? For example, the slice is constant in this model. Can you enable the capture_scalar_outputs option?

Also, are you freezing the model data and dims like in the original example? When we have static shapes, we could forward them to the dispatch.
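The idea of forwarding constants to the dispatch can be sketched without PyTorch, assuming the `join` breaks come from turning a runtime axis tensor into a Python scalar, as the `aten._local_scalar_dense` reason suggests. In this toy model `Constant`, `Variable`, `concat` and `funcify_join` are hypothetical stand-ins (not PyTensor's or PyTorch's API), with `concat` playing the role of `torch.cat`. When the axis is known at build time it is baked into the closure as a plain int, so the returned function has no data-dependent conversion left.

```python
from dataclasses import dataclass
from typing import Any, Callable, List


@dataclass
class Constant:
    """Hypothetical graph input whose value is known when the function is built."""
    data: Any


@dataclass
class Variable:
    """Hypothetical graph input whose value is only known at call time."""
    name: str


def concat(arrays: List[list], axis: int) -> list:
    """Toy concatenation of 2-D nested lists, standing in for torch.cat."""
    if axis == 0:
        return [row for arr in arrays for row in arr]
    return [sum(rows, []) for rows in zip(*arrays)]


def funcify_join(axis_input: Any) -> Callable[..., list]:
    """Build a join(axis, *arrays) function for one graph node."""
    if isinstance(axis_input, Constant):
        axis = int(axis_input.data)  # resolved once, at build time

        def join(_axis: Any, *arrays: list) -> list:
            # The runtime axis argument is ignored: the closure already holds
            # a plain Python int, so nothing here depends on tensor data.
            return concat(list(arrays), axis)

    else:

        def join(axis_value: Any, *arrays: list) -> list:
            # Fallback: the axis arrives as data and is converted on every
            # call (the analogue of .item(), a data-dependent operation).
            return concat(list(arrays), int(axis_value))

    return join
```

A real dispatch function would need access to the graph node to make this check; how PyTensor's PyTorch dispatch exposes it is not shown here and is left as an assumption.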
