Resnet fp16 training with fp32 master weight copy (tinygrad#4144)
* add casts to layers

* FLOAT flag

* detach

* no_grad for eval

* whitespace

* explicit fp32 initialization

* oops

* whitespace

* put back config['DEFAULT_FLOAT']

* bad

* live dangerously (don't hide bugs)

* don't bundle changes

---------

Co-authored-by: chenyu <[email protected]>
chaosagent and chenyuxyz authored Apr 14, 2024
1 parent e20d6f9 commit 593c90d
Showing 3 changed files with 19 additions and 11 deletions.
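
The change follows the standard mixed-precision recipe: parameters are created explicitly in float32 (the master copy) and cast to `dtypes.default_float` (e.g. half, chosen via the FLOAT / DEFAULT_FLOAT flag mentioned in the commit message) only at call time, so convolutions and matmuls run in fp16 while gradients and optimizer updates accumulate against the fp32 masters. Below is a minimal sketch of the pattern, not code from this commit; `MixedLinear` and the toy training step are hypothetical.

```python
# Illustrative sketch of fp32-master-weight / fp16-compute training in tinygrad.
# MixedLinear is a hypothetical layer; the commit applies the same idea inside
# examples/mlperf/initializers.py and examples/hlb_cifar10.py.
from tinygrad import Tensor, dtypes
from tinygrad.nn.optim import LAMB

dtypes.default_float = dtypes.half  # compute dtype (normally selected via the DEFAULT_FLOAT flag)

class MixedLinear:
  def __init__(self, in_features, out_features):
    # master weights stay in float32 regardless of the default float dtype
    self.weight = Tensor.normal((out_features, in_features), mean=0.0, std=0.01, dtype=dtypes.float32)
    self.bias = Tensor.zeros(out_features, dtype=dtypes.float32)
  def __call__(self, x: Tensor) -> Tensor:
    # cast to the compute dtype only at use; backward still updates the fp32 masters
    return x.linear(self.weight.cast(dtypes.default_float).transpose(), self.bias.cast(dtypes.default_float))

layer = MixedLinear(8, 4)
opt = LAMB([layer.weight, layer.bias], lr=1e-3)  # optimizer state follows the params' fp32 dtype
x = Tensor.rand(2, 8)                            # activations default to half
Tensor.training = True
loss = layer(x).float().square().mean()          # reduce the loss in fp32
opt.zero_grad()
loss.backward()
opt.step()
```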
7 changes: 4 additions & 3 deletions examples/hlb_cifar10.py
@@ -25,11 +25,12 @@ def __init__(self, sz:int, eps=1e-5, affine=True, track_running_stats=True, mome
     self.eps, self.track_running_stats, self.momentum = eps, track_running_stats, momentum
     self.num_devices = num_devices

-    if affine: self.weight, self.bias = Tensor.ones(sz), Tensor.zeros(sz)
+    if affine: self.weight, self.bias = Tensor.ones(sz, dtype=dtypes.float32), Tensor.zeros(sz, dtype=dtypes.float32)
     else: self.weight, self.bias = None, None

-    self.running_mean, self.running_var = Tensor.zeros(num_devices, sz, requires_grad=False), Tensor.ones(num_devices, sz, requires_grad=False)
-    self.num_batches_tracked = Tensor.zeros(1, requires_grad=False)
+    self.running_mean = Tensor.zeros(num_devices, sz, dtype=dtypes.float32, requires_grad=False)
+    self.running_var = Tensor.ones(num_devices, sz, dtype=dtypes.float32, requires_grad=False)
+    self.num_batches_tracked = Tensor.zeros(1, dtype=dtypes.int, requires_grad=False)

   def __call__(self, x:Tensor):
     if isinstance(x.lazydata, MultiLazyBuffer): assert x.lazydata.axis is None or x.lazydata.axis == 0 and len(x.lazydata.lbs) == self.num_devices
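
Keeping the batchnorm affine parameters and running statistics in float32 matters because fp16 silently drops small increments once the accumulated value grows. A minimal illustration, assuming a backend with native fp16:

```python
# fp16 cannot represent odd integers above 2048, so a small update is lost entirely;
# in fp32 the same update survives. This is why master copies and running stats stay float32.
from tinygrad import Tensor, dtypes

acc_half = Tensor([2048.0], dtype=dtypes.half)
acc_fp32 = Tensor([2048.0], dtype=dtypes.float32)
print((acc_half + 1).numpy())  # [2048.] -- the +1 is rounded away in half precision
print((acc_fp32 + 1).numpy())  # [2049.]
```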
17 changes: 12 additions & 5 deletions examples/mlperf/initializers.py
@@ -1,6 +1,6 @@
 import math

-from tinygrad import Tensor, nn
+from tinygrad import Tensor, nn, dtypes
 from tinygrad.helpers import prod, argfix

 # rejection sampling truncated randn
@@ -17,11 +17,18 @@ def he_normal(*shape, a: float = 0.00, **kwargs) -> Tensor:
   return std * rand_truncn(*shape, **kwargs)

 class Conv2dHeNormal(nn.Conv2d):
-  def initialize_weight(self, out_channels, in_channels, groups):
-    return he_normal(out_channels, in_channels//groups, *self.kernel_size, a=0.0)
+  def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
+    super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
+    self.weight = he_normal(out_channels, in_channels//groups, *self.kernel_size, a=0.0, dtype=dtypes.float32)
+    if bias: self.bias = self.bias.cast(dtypes.float32)
+  def __call__(self, x: Tensor):
+    return x.conv2d(self.weight.cast(dtypes.default_float), self.bias.cast(dtypes.default_float) if self.bias is not None else None,
+                    padding=self.padding, stride=self.stride, dilation=self.dilation, groups=self.groups)

 class Linear(nn.Linear):
   def __init__(self, in_features, out_features, bias=True):
     super().__init__(in_features, out_features, bias=bias)
-    self.weight = Tensor.normal((out_features, in_features), mean=0.0, std=0.01)
-    if bias: self.bias = Tensor.zeros(out_features)
+    self.weight = Tensor.normal((out_features, in_features), mean=0.0, std=0.01, dtype=dtypes.float32)
+    if bias: self.bias = Tensor.zeros(out_features, dtype=dtypes.float32)
+  def __call__(self, x:Tensor):
+    return x.linear(self.weight.cast(dtypes.default_float).transpose(), self.bias.cast(dtypes.default_float) if self.bias is not None else None)
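
A hedged usage sketch of the layers above (assumes the repository root is on PYTHONPATH and that the default float has been switched to half): the stored weight stays float32 while the output follows the compute dtype.

```python
# Usage sketch, not from the commit: master weight stays fp32, activations are half.
from tinygrad import Tensor, dtypes
from examples.mlperf.initializers import Conv2dHeNormal

dtypes.default_float = dtypes.half      # what the DEFAULT_FLOAT flag would set in practice
conv = Conv2dHeNormal(3, 8, kernel_size=3, padding=1)
x = Tensor.rand(1, 3, 32, 32)           # half, via default_float
print(conv.weight.dtype)                # -> float32 (master copy)
print(conv(x).dtype)                    # -> half (compute dtype)
```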
6 changes: 3 additions & 3 deletions tinygrad/nn/optim.py
@@ -40,7 +40,7 @@ class LARS(Optimizer):
   def __init__(self, params:List[Tensor], lr=0.001, momentum=0.9, weight_decay=1e-4, nesterov=False, classic=True, tcoef=0.001):
     super().__init__(params, lr)
     self.momentum, self.wd, self.nesterov, self.classic, self.tcoef = momentum, weight_decay, nesterov, classic, tcoef
-    self.b = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params] if self.momentum else []
+    self.b = [Tensor.zeros(*t.shape, dtype=t.dtype, device=t.device, requires_grad=False) for t in self.params] if self.momentum else []

   def _step(self) -> List[Tensor]:
     for i, t in enumerate(self.params):
@@ -73,8 +73,8 @@ def __init__(self, params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, w
     super().__init__(params, lr)
     self.eps, self.wd, self.adam = eps, wd, adam
     self.b1, self.b2, self.t = (Tensor([x], device=self.device, requires_grad=False).realize() for x in [b1, b2, 0])
-    self.m = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False).contiguous() for t in self.params]
-    self.v = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False).contiguous() for t in self.params]
+    self.m = [Tensor.zeros(*t.shape, dtype=t.dtype, device=t.device, requires_grad=False).contiguous() for t in self.params]
+    self.v = [Tensor.zeros(*t.shape, dtype=t.dtype, device=t.device, requires_grad=False).contiguous() for t in self.params]

   def _step(self) -> List[Tensor]:
     self.t.assign(self.t + 1)
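
The optimizer change mirrors each parameter's dtype into its state buffers, so fp32 master weights get fp32 momentum and second-moment buffers, while any remaining fp16 tensors keep fp16 state. A small sketch of the effect (not from the commit):

```python
# Sketch: LAMB's m/v buffers now take each param's dtype instead of the default float.
from tinygrad import Tensor, dtypes
from tinygrad.nn.optim import LAMB

w_master = Tensor.zeros(4, 4, dtype=dtypes.float32, requires_grad=True)
w_half = Tensor.zeros(4, 4, dtype=dtypes.half, requires_grad=True)
opt = LAMB([w_master, w_half], lr=1e-3)
print([m.dtype for m in opt.m])  # -> [float32, half]
```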
