Revert "[inductor] Fix bug handling output_strides in fx graph cache (p…
Browse files Browse the repository at this point in the history
…ytorch#112041)"

This reverts commit 3d2041b.

Reverted pytorch#112041 on behalf of https://github.com/ZainRizvi due to fbcode failures ([comment](pytorch#112041 (comment)))
pytorchmergebot authored and Skylion007 committed Nov 14, 2023
1 parent 373a2b3 commit 7152a73
Showing 3 changed files with 21 additions and 58 deletions.
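Background (not stated in the commit message, inferred from the diff below): the reverted change stored each graph's output strides on the cached CompiledFxGraph object and had compile_fx_inner copy them into TracingContext, rather than writing them into the tracing context only while the graph is being compiled; presumably this was so a cache hit, which skips codegen, could still report strides to the caller. The snippet below is a minimal, self-contained sketch of that idea using made-up FakeTracingContext / FakeCompiledGraph classes, not PyTorch's real ones.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple


@dataclass
class FakeTracingContext:
    # Stand-in for the output_strides list on torch._guards.TracingContext.
    output_strides: List[Optional[Tuple[int, ...]]] = field(default_factory=list)


@dataclass
class FakeCompiledGraph:
    # The reverted patch added a field like this to CompiledFxGraph so the
    # strides survive a round trip through the on-disk cache.
    key: str
    output_strides: List[Optional[Tuple[int, ...]]] = field(default_factory=list)


_cache: Dict[str, FakeCompiledGraph] = {}


def compile_or_load(key: str, strides: List[Optional[Tuple[int, ...]]],
                    ctx: FakeTracingContext) -> FakeCompiledGraph:
    graph = _cache.get(key)
    if graph is None:
        # "Compilation" is the only place strides get computed.
        graph = FakeCompiledGraph(key=key, output_strides=list(strides))
        _cache[key] = graph
    # Because the strides live on the cached object, both the miss path and
    # the hit path can hand them back to the caller through the context.
    ctx.output_strides.extend(graph.output_strides)
    return graph


ctx1, ctx2 = FakeTracingContext(), FakeTracingContext()
compile_or_load("g", [(4, 1)], ctx1)  # miss: compiles and records strides
compile_or_load("g", [(4, 1)], ctx2)  # hit: strides are still available
assert ctx1.output_strides == ctx2.output_strides == [(4, 1)]
```

After this revert, the strides are written into the tracing context only on the compile path, which is the behavior the reverted PR had been changing.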
46 changes: 13 additions & 33 deletions test/inductor/test_codecache.py
@@ -60,19 +60,6 @@ def test_codecache_fork():
     _run_codecache_test("fork")


-class MyModelConv2d(torch.nn.Module):
-    def __init__(self, dim=512):
-        super().__init__()
-        self.conv1 = torch.nn.Conv2d(3, dim, kernel_size=3, stride=2, bias=False)
-        self.conv2 = torch.nn.Conv2d(dim, dim, kernel_size=3, stride=2, bias=False)
-
-    def forward(self, x):
-        x = self.conv1(x)
-        torch._dynamo.graph_break()
-        x = self.conv2(x)
-        return x
-
-
 @instantiate_parametrized_tests
 class TestFxGraphCache(TestCase):
     @classmethod
@@ -137,40 +124,33 @@ def fn(x, y):
     @requires_triton()
     @config.patch({"fx_graph_cache": True})
     @parametrize("device", ("cuda", "cpu"))
-    @parametrize("dtype", (torch.float32, torch.float16))
+    @parametrize("dtype", (torch.float32, torch.bfloat16))
     def test_cache_load_model(self, device, dtype):
         """
         Verify that we can populate and load models from the cache.
         """
         if device == "cuda" and not HAS_CUDA:
             raise unittest.SkipTest("requires CUDA")
+        if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
+            raise unittest.SkipTest("requires SM80 or later")

-        def fn(mod, x):
-            mod.zero_grad()
-            mod(x).sum().backward()
-            return [p.grad for p in mod.parameters()]
+        model = MyModel().to(dtype=dtype, device=device)

-        compiled_fn = torch.compile(fn, dynamic=False)
+        a = torch.rand(10, 10, dtype=dtype, device=device)

-        mod = MyModelConv2d().to(device=device, dtype=dtype)
-        inp = torch.randn(2, 3, 16, 16, device=device, dtype=dtype)
+        compiled_model = torch.compile(model, dynamic=False)

-        # The first call should see all cache misses.
-        counters.clear()
-        grads1 = compiled_fn(mod, inp)
-        self.assertGreater(counters["inductor"]["fxgraph_cache_miss"], 0)
+        # A first call should miss in the cache.
+        self.assertEqual(model(a), compiled_model(a))
+        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
         self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)

-        # The second should see all hits. (First reset so in-memory guards
+        # A second call should hit. (First reset so in-memory guards
         # don't prevent compilation).
-        counters.clear()
         torch._dynamo.reset()
-        grads2 = compiled_fn(mod, inp)
-        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
-        self.assertGreater(counters["inductor"]["fxgraph_cache_hit"], 0)
-
-        # And the results should be the same.
-        self.assertEqual(grads1, grads2)
+        self.assertEqual(model(a), compiled_model(a))
+        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
+        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)


 class TestFxGraphCacheHashing(TestCase):
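To exercise the restored test behavior outside the test harness, a rough standalone sketch follows. Only the "fx_graph_cache" config flag, the inductor counters, and torch._dynamo.reset() come from the diff itself; TinyModel and the surrounding harness are illustrative stand-ins (MyModel's definition is not part of this diff), and the exact counter values assume a single FX graph and a clean run.

```python
import torch
from torch._dynamo.utils import counters
from torch._inductor import config


class TinyModel(torch.nn.Module):
    """Made-up stand-in for the test file's MyModel (not shown in this diff)."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.linear(x)


def check_cache_roundtrip():
    model = TinyModel()
    a = torch.rand(10, 10)
    with config.patch({"fx_graph_cache": True}):
        compiled_model = torch.compile(model, dynamic=False)

        # First call compiles the FX graph, so it should register a cache miss.
        counters.clear()
        torch.testing.assert_close(model(a), compiled_model(a))
        assert counters["inductor"]["fxgraph_cache_miss"] == 1
        assert counters["inductor"]["fxgraph_cache_hit"] == 0

        # Reset dynamo so in-memory guards don't short-circuit recompilation;
        # the second compilation should then be served from the on-disk cache.
        torch._dynamo.reset()
        torch.testing.assert_close(model(a), compiled_model(a))
        assert counters["inductor"]["fxgraph_cache_hit"] == 1


if __name__ == "__main__":
    check_cache_roundtrip()
```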
1 change: 0 additions & 1 deletion torch/_inductor/codecache.py
@@ -659,7 +659,6 @@ class CompiledFxGraph:
     mutated_inputs: Set[str] = field(default_factory=set)
     mutated_input_idxs: Set[int] = field(default_factory=set)
     constants: Dict[str, torch.Tensor] = field(default_factory=dict)
-    output_strides: Optional[List[Optional[Tuple[int, ...]]]] = None

     _boxed_call: Optional[bool] = None

32 changes: 8 additions & 24 deletions torch/_inductor/compile_fx.py
@@ -8,17 +8,7 @@
 import warnings
 from itertools import count

-from typing import (
-    Any,
-    Callable,
-    Dict,
-    FrozenSet,
-    List,
-    Optional,
-    Sequence,
-    Tuple,
-    Union,
-)
+from typing import Any, Callable, Dict, FrozenSet, List, Optional, Sequence, Union
 from unittest import mock

 from functorch.compile import min_cut_rematerialization_partition
@@ -394,12 +384,6 @@ def compile_fx_inner(

log.debug("FX codegen and compilation took %.3fs", time.time() - start)

# Return the output strides to the caller via TracingContext
context = torch._guards.TracingContext.get()
if context is not None and context.output_strides is not None:
assert len(context.output_strides) == 0
context.output_strides.extend(compiled_graph.output_strides)

if aot_mode:
return compiled_graph

@@ -598,19 +582,20 @@ def fx_codegen_and_compile(
         )
         with V.set_graph_handler(graph):
             graph.run(*example_inputs)
-            output_strides: List[Optional[Tuple[int, ...]]] = []
-            if graph.graph_outputs is not None:
-                # We'll put the output strides in the compiled graph so we
-                # can later return them to the caller via TracingContext
+            context = torch._guards.TracingContext.get()
+            if context is not None and context.output_strides is not None:
+                # Return the output strides to the caller via TracingContext
+                assert len(context.output_strides) == 0
+                assert graph.graph_outputs is not None
                 for out in graph.graph_outputs:
                     if hasattr(out, "layout"):
-                        output_strides.append(
+                        context.output_strides.append(
                             tuple(  # type: ignore[arg-type]
                                 V.graph.sizevars.size_hint(s) for s in out.layout.stride
                             )
                         )
                     else:
-                        output_strides.append(None)
+                        context.output_strides.append(None)

             compiled_fn = graph.compile_to_fn()

@@ -630,7 +615,6 @@
                 mutated_inputs=graph.mutated_inputs,
                 mutated_input_idxs=set(graph.mutated_input_idxs),
                 constants=graph.constants,
-                output_strides=output_strides,
             )
     return compiled_graph
