# open_registration_example.py — 134 lines (112 loc), 4.69 KB
# (GitHub page chrome and the scraped line-number gutter were removed;
#  the actual file content begins below.)
import torch
from utils.custom_device_mode import foo_module, enable_foo_device
# This file contains an example of how to create a custom device extension
# in PyTorch, through the dispatcher.
# It also shows what two possible user API's for custom devices look like. Either:
# (1) Expose your custom device as an object, device=my_device_obj
# (2) Allow users to directly use device strings: device="my_device"
# Running this file prints the following:
# (Correctly) unable to create tensor on device='bar'
# (Correctly) unable to create tensor on device='foo:2'
# Creating x on device 'foo:0'
# Custom aten::empty.memory_format() called!
# Custom allocator's allocate() called!
# Creating y on device 'foo:0'
# Custom aten::empty.memory_format() called!
# Custom allocator's allocate() called!
# Test START
# x.device=foo:0, x.is_cpu=False
# y.device=foo:0, y.is_cpu=False
# Calling z = x + y
# Custom aten::add.Tensor() called!
# Custom aten::empty.memory_format() called!
# Custom allocator's allocate() called!
# z.device=foo:0, z.is_cpu=False
# Calling z = z.to(device="cpu")
# Custom aten::_copy_from() called!
# z_cpu.device=cpu, z_cpu.is_cpu=True
# Calling z2 = z_cpu + z_cpu
# Test END
# Custom allocator's delete() called!
# Creating x on device 'foo:1'
# Custom aten::empty.memory_format() called!
# Custom allocator's allocate() called!
# Creating y on device 'foo:1'
# Custom aten::empty.memory_format() called!
# Custom allocator's allocate() called!
# Test START
# x.device=foo:0, x.is_cpu=False
# y.device=foo:0, y.is_cpu=False
# (NOTE(review): the two lines above show foo:0 even though x/y were created
#  on 'foo:1' — confirm whether this transcript is stale or the device index
#  really is reported as 0 here.)
# Calling z = x + y
# Custom aten::add.Tensor() called!
# Custom aten::empty.memory_format() called!
# Custom allocator's allocate() called!
# z.device=foo:0, z.is_cpu=False
# Calling z = z.to(device="cpu")
# Custom aten::_copy_from() called!
# z_cpu.device=cpu, z_cpu.is_cpu=True
# Calling z2 = z_cpu + z_cpu
# Test END
# Custom allocator's delete() called!
# Custom allocator's delete() called!
# Custom allocator's delete() called!
# Custom allocator's delete() called!
# Custom allocator's delete() called!
def test(x, y):
print()
print("Test START")
# Check that our device is correct.
print(f'x.device={x.device}, x.is_cpu={x.is_cpu}')
print(f'y.device={y.device}, y.is_cpu={y.is_cpu}')
# calls out custom add kernel, registered to the dispatcher
print('Calling z = x + y')
z = x + y
print(f'z.device={z.device}, z.is_cpu={z.is_cpu}')
print('Calling z = z.to(device="cpu")')
z_cpu = z.to(device='cpu')
# Check that our cross-device copy correctly copied the data to cpu
print(f'z_cpu.device={z_cpu.device}, z_cpu.is_cpu={z_cpu.is_cpu}')
# Confirm that calling the add kernel no longer invokes our custom kernel,
# since we're using CPU t4ensors.
print('Calling z2 = z_cpu + z_cpu')
z2 = z_cpu + z_cpu
print("Test END")
print()
# Option 1: Use torch.register_privateuse1_backend("foo"), which will allow
# "foo" as a device string to work seamlessly with pytorch's API's.
# You may need a more recent nightly of PyTorch for this.
torch.register_privateuse1_backend('foo')

# Sanity check: a device string that was never registered must be rejected.
# `raise SystemExit(msg)` prints msg to stderr and exits with status 1; unlike
# the site-module helper `exit()`, the builtin exception is always available
# (e.g. under `python -S`), so it is the right way to abort a script.
try:
    x = torch.ones(4, 4, device='bar')
    raise SystemExit("Error: you should not be able to make a tensor on an arbitrary 'bar' device.")
except RuntimeError:
    print("(Correctly) unable to create tensor on device='bar'")

# Sanity check: an out-of-range index on the registered device must also fail.
try:
    x = torch.ones(4, 4, device='foo:2')
    raise SystemExit("Error: the foo device only has two valid indices: foo:0 and foo:1")
except RuntimeError:
    print("(Correctly) unable to create tensor on device='foo:2'")

# With the backend registered, plain device strings now route to the
# custom allocator/kernels (see the expected transcript above).
print("Creating x on device 'foo:0'")
x1 = torch.ones(4, 4, device='foo:0')
print("Creating y on device 'foo:0'")
y1 = torch.ones(4, 4, device='foo:0')
test(x1, y1)
# Option 2: Directly expose a custom device object
# The extension's factory takes an optional index selecting which device to use.
foo_index1 = foo_module.custom_device(1)
print("Creating x on device 'foo:1'")
x_foo1 = torch.ones(4, 4, device=foo_index1)
print("Creating y on device 'foo:1'")
y_foo1 = torch.ones(4, 4, device=foo_index1)

# Option 3: Enable a TorchFunctionMode object in user land, which converts
# `device="foo"` calls into our custom device objects automatically.
# Option 1 is strictly better here (in particular, printing a.device() will
# still print "privateuseone" instead of your custom device name). This option
# is shown mostly because:
# (a) Torch Function Modes have been around for longer, and the API in
#     Option 1 is only available on a more recent nightly.
# (b) It is a cool example of how powerful torch_function and torch_dispatch
#     modes can be!
# holder = enable_foo_device()
# del holder

test(x_foo1, y_foo1)