Add stablediffusion inference reference model (#8027)
qihqi authored Sep 26, 2024
1 parent f4c8b73 commit f088810
Showing 10 changed files with 114 additions and 13 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/build_and_test.yml
@@ -5,13 +5,13 @@ on:
- master
- r[0-9]+.[0-9]+
paths-ignore:
-      - 'experimental/torch_xla2/**'
+      - 'experimental/**'
push:
branches:
- master
- r[0-9]+.[0-9]+
paths-ignore:
-      - 'experimental/torch_xla2/**'
+      - 'experimental/**'
workflow_dispatch:

concurrency:
2 changes: 1 addition & 1 deletion .github/workflows/build_upstream_image.yml
@@ -5,7 +5,7 @@ on:
- master
- r[0-9]+.[0-9]+
paths-ignore:
-      - 'experimental/torch_xla2/**'
+      - 'experimental/**'
workflow_dispatch:
jobs:
build:
20 changes: 20 additions & 0 deletions experimental/reference_models/README.md
@@ -0,0 +1,20 @@
This directory will contain a list of reference models that
we have optimized and that run well on TPU.

The contents of this directory are organized as follows (see the layout sketch after this list):

* Every subdirectory is a self-contained model, packaged as a separate pip package.

* Every subdirectory must have a README indicating:
  * whether it is for training or inference
  * on what devices it has been tested / developed
  * instructions for running it.

* Every subdirectory contains its own set of shell scripts, with all the flags
  set to the best-performing values we have tuned, be it for training or inference.

* Each subdirectory can specify its own dependencies, and can depend on models / layers
  defined in well-known OSS libraries, such as HuggingFace transformers; subdirectories should ideally not depend on each other.

* (Optional) Each model can also have a GPU "original" version that illustrates and attributes where this model code came from, if any. This also helps to showcase what changes we have made to make it performant on TPU.
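
For illustration only, a hypothetical subdirectory following these conventions might look like this (the file names are examples, not contents of this commit):

    stablediffusion/
      README.md          # states inference vs. training, tested devices, run instructions
      pyproject.toml     # makes the model a self-contained pip package with its own deps
      run_inference.sh   # flags pre-set to the tuned, best-performing values
      model.py           # model code, possibly adapted from an OSS "original"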

1 change: 1 addition & 0 deletions experimental/torch_xla2/test/test_exports.py
@@ -34,6 +34,7 @@ class ExportTest(unittest.TestCase):

def setUp(self):
torch.manual_seed(0)
torch_xla2.enable_accuracy_mode()

def test_interpolate(self):

1 change: 1 addition & 0 deletions experimental/torch_xla2/test/test_functions.py
@@ -10,6 +10,7 @@ class TestTorchFunctions(parameterized.TestCase):

def setUp(self):
self.env = torch_xla2.tensor.Environment()
torch_xla2.enable_accuracy_mode()

@parameterized.named_parameters(
('tensor_2d', lambda: torch.tensor([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]])),
4 changes: 3 additions & 1 deletion experimental/torch_xla2/test/test_ops.py
@@ -7,6 +7,7 @@
instantiate_device_type_tests, ops)
from torch.utils import _pytree as pytree
from torch_xla2 import tensor
import torch_xla2


skiplist = {
@@ -259,7 +260,8 @@ def setUpClass(cls):
print('op_db size: ', len(op_db), 'testing: ', len(ops_to_test))

def setUp(self):
-    self.env = tensor.Environment()
+    self.env = torch_xla2.default_env()
torch_xla2.enable_accuracy_mode()
#self.env.config.debug_accuracy_for_each_op = True
torch.manual_seed(0)

15 changes: 13 additions & 2 deletions experimental/torch_xla2/torch_xla2/__init__.py
@@ -16,7 +16,6 @@

from jax._src import xla_bridge
os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1')
-jax.config.update('jax_enable_x64', True)

# torch_xla2:oss-begin
old_pjrt_options = jax.config.jax_pjrt_client_create_options
@@ -80,4 +79,16 @@ def disable_globally():
unsupported_dtype=unsupported_dtype)

import jax
torch._register_device_module('jax', jax)


def enable_accuracy_mode():
jax.config.update('jax_enable_x64', True)
jax.config.update('jax_default_matmul_precision', 'highest')
default_env().config.internal_respect_torch_return_dtypes = True


def enable_performance_mode():
jax.config.update('jax_enable_x64', False)
jax.config.update('jax_default_matmul_precision', 'default')
default_env().config.internal_respect_torch_return_dtypes = False
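
A minimal usage sketch of the two new helpers (using nothing beyond what the diff above defines):

    import torch_xla2

    # Numerics-sensitive work (e.g. op-by-op accuracy tests): enable x64,
    # force 'highest' matmul precision, and respect torch's return dtypes.
    torch_xla2.enable_accuracy_mode()

    # Benchmarking or serving: disable x64 and use JAX's default matmul
    # precision, trading a little accuracy for speed.
    torch_xla2.enable_performance_mode()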
1 change: 1 addition & 0 deletions experimental/torch_xla2/torch_xla2/config.py
@@ -15,3 +15,4 @@ class Configuration:
# device
treat_cuda_as_jax_device: bool = True
use_torch_native_for_cpu_tensor: bool = False
internal_respect_torch_return_dtypes: bool = False
75 changes: 70 additions & 5 deletions experimental/torch_xla2/torch_xla2/interop.py
@@ -47,19 +47,25 @@ def set_one(module, prefix):
set_one(m, '')


-class JittableModule:
+class JittableModule(torch.nn.Module):

-  def __init__(self, m: torch.nn.Module):
+  # TODO: add statedict loading hook
+
+  def __init__(self, m: torch.nn.Module, extra_jit_args={}):
super().__init__()
self.params, self.buffers = extract_all_buffers(m)
self._model = m
self._jitted = {}

self._extra_jit_args = extra_jit_args


def __call__(self, *args, **kwargs):
-    res = self._model(*args, **kwargs)
-    return res
+    return self.forward(*args, **kwargs)


def functional_call(
-      self, method_name, params, buffers, args, kwargs=None):
+      self, method_name, params, buffers, *args, **kwargs):
kwargs = kwargs or {}
params_copy = copy.copy(params)
params_copy.update(buffers)
@@ -68,6 +74,65 @@ def functional_call(
return res


def forward(self, *args, **kwargs):
if 'forward' not in self._jitted:
jitted = jax_jit(
functools.partial(self.functional_call, 'forward'),
kwargs_for_jax_jit=self._extra_jit_args,
)
def jitted_forward(*args, **kwargs):
return jitted(self.params, self.buffers, *args, **kwargs)
self._jitted['forward'] = jitted_forward
return self._jitted['forward'](*args, **kwargs)

def __getattr__(self, key):
if key == '_model':
return super().__getattr__(key)
if key in self._jitted:
return self._jitted[key]
return getattr(self._model, key)

def make_jitted(self, key):
jitted = jax_jit(
functools.partial(self.functional_call, key),
kwargs_for_jax_jit=self._extra_jit_args)
def call(*args, **kwargs):
return jitted(self.params, self.buffers, *args, **kwargs)
self._jitted[key] = call
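
A minimal usage sketch of JittableModule (the model, inputs, and environment handling are illustrative assumptions, not code from this commit): wrapping extracts the module's parameters and buffers, and the first call to forward builds and caches a jitted functional version that later calls reuse.

    import torch
    import torch_xla2
    from torch_xla2.interop import JittableModule

    env = torch_xla2.default_env()
    with env:  # assumes the Environment is usable as a context manager, as in the tests
      model = torch.nn.Linear(8, 8)   # any torch.nn.Module
      jittable = JittableModule(model)
      x = torch.randn(2, 8)
      y = jittable(x)   # first call jits 'forward' and caches it
      y = jittable(x)   # subsequent calls reuse the cached jitted function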





class CompileMixin:

def functional_call(
self, method, params, buffers, *args, **kwargs):
kwargs = kwargs or {}
params_copy = copy.copy(params)
params_copy.update(buffers)
with torch_stateless._reparametrize_module(self, params_copy):
res = method(*args, **kwargs)
return res

  def jit(self, method):
    jitted = jax_jit(functools.partial(self.functional_call, method))
    def call(*args, **kwargs):
      # Pass parameters/buffers as dicts so functional_call can copy and merge them.
      return jitted(dict(self.named_parameters()), dict(self.named_buffers()), *args, **kwargs)
    return call


def compile_nn_module(m: torch.nn.Module, methods=None):
  if methods is None:
    methods = ['forward']

  # type() needs (name, bases, dict); swap the instance's class so it
  # inherits CompileMixin alongside its original class.
  new_parent = type(
      m.__class__.__name__ + '_with_CompileMixin',
      (CompileMixin, m.__class__),
      {},
  )
  m.__class__ = new_parent
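
A hedged sketch of how the class swap above is meant to be used (hypothetical model; as written, compile_nn_module only swaps the class, so jitting individual methods is still up to the caller):

    import torch

    class TinyModel(torch.nn.Module):
      def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(4, 4)

      def forward(self, x):
        return self.lin(x)

    m = TinyModel()
    compile_nn_module(m)               # m's class now also inherits CompileMixin
    forward_jitted = m.jit(m.forward)  # jit a single method via the mixin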


def _torch_view(t: JaxValue) -> TorchValue:
# t is an object from jax land
# view it as-if it's a torch land object
4 changes: 2 additions & 2 deletions experimental/torch_xla2/torch_xla2/tensor.py
@@ -331,8 +331,8 @@ def _to_copy(self, the_tensor, new_dtype, new_device):
the_tensor = the_tensor.to(new_dtype)
jax_device = self.get_as_jax_device(new_device)
if jax_device:
-      with jax.default_device(jax_device):
-        arr = t2j(the_tensor)
+      arr = t2j(the_tensor)
+      arr = jax.device_put(arr, jax_device)
else:
with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
return torch_tensor.to(new_device)
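
For context on the replaced lines above: jax.device_put places an already-created array onto an explicit device, rather than relying on the ambient default device during conversion. A standalone sketch in plain JAX:

    import jax
    import jax.numpy as jnp

    arr = jnp.ones((2, 2))             # created on the current default device
    target = jax.devices()[0]          # an explicit target device
    arr = jax.device_put(arr, target)  # commit the array to that device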