Initial kernel changes to support GaLore #1137

Open · wants to merge 8 commits into base: main
29 changes: 20 additions & 9 deletions bitsandbytes/functional.py
@@ -1555,9 +1555,9 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] =

def optimizer_update_32bit(
optimizer_name: str,
g: Tensor,
p: Tensor,
state1: Tensor,
g: torch.Tensor,
p: torch.Tensor,
state1: torch.Tensor,
beta1: float,
eps: float,
step: int,
@@ -1571,6 +1571,7 @@ def optimizer_update_32bit(
unorm_vec: Optional[torch.Tensor] = None,
max_unorm: float = 0.0,
skip_zeros=False,
return_updates: Optional[torch.Tensor] = None,
) -> None:
"""
Performs an inplace optimizer update with one or two optimizer states.
@@ -1613,6 +1614,8 @@ def optimizer_update_32bit(
The maximum update norm relative to the weight norm.
skip_zeros : bool
Whether to skip zero-valued gradients or not (default: False).
return_updates: Optional[torch.Tensor]
When provided, updates are written to this tensor and not applied directly to `p`. (default: None)
"""

param_norm = 0.0
@@ -1636,6 +1639,7 @@ def optimizer_update_32bit(
optim_func(
get_ptr(g),
get_ptr(p),
get_ptr(return_updates),
get_ptr(state1),
get_ptr(state2),
get_ptr(unorm_vec),
@@ -1658,25 +1662,26 @@ def optimizer_update_32bit(

def optimizer_update_8bit(
optimizer_name: str,
g: Tensor,
p: Tensor,
state1: Tensor,
g: torch.Tensor,
p: torch.Tensor,
state1: torch.Tensor,
state2: Optional[torch.Tensor],
beta1: float,
beta2: float,
eps: float,
step: int,
lr: float,
qmap1: Tensor,
qmap1: torch.Tensor,
qmap2: Optional[torch.Tensor],
max1: Tensor,
max1: torch.Tensor,
max2: Optional[torch.Tensor],
new_max1: Tensor,
new_max1: torch.Tensor,
new_max2: Optional[torch.Tensor],
weight_decay: float = 0.0,
gnorm_scale: float = 1.0,
unorm_vec: Optional[torch.Tensor] = None,
max_unorm: float = 0.0,
return_updates: Optional[torch.Tensor] = None,
) -> None:
"""
Performs an inplace Adam update.
@@ -1726,6 +1731,8 @@ def optimizer_update_8bit(
The tensor for the update norm.
max_unorm : float
The maximum update norm relative to the weight norm.
return_updates: Optional[torch.Tensor]
When provided, updates are written to this tensor and not applied directly to `p`. (default: None)
"""

param_norm = 0.0
@@ -1738,6 +1745,7 @@ def optimizer_update_8bit(
str2optimizer8bit[optimizer_name][0](
get_ptr(p),
get_ptr(g),
get_ptr(return_updates),
get_ptr(state1),
get_ptr(state2),
get_ptr(unorm_vec),
@@ -1762,6 +1770,7 @@ def optimizer_update_8bit(
str2optimizer8bit[optimizer_name][1](
get_ptr(p),
get_ptr(g),
get_ptr(return_updates),
get_ptr(state1),
get_ptr(state2),
get_ptr(unorm_vec),
@@ -1809,6 +1818,7 @@ def optimizer_update_8bit_blockwise(
weight_decay: float = 0.0,
gnorm_scale: float = 1.0,
skip_zeros=False,
return_updates: Optional[torch.Tensor] = None,
) -> None:
optim_func = None
prev_device = pre_call(g.device)
@@ -1835,6 +1845,7 @@ def optimizer_update_8bit_blockwise(
optim_func(
get_ptr(p),
get_ptr(g),
get_ptr(return_updates),
get_ptr(state1),
get_ptr(state2),
ct.c_float(beta1),
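For context, a minimal, hypothetical sketch of how a caller might use the new `return_updates` argument of `optimizer_update_32bit` once this PR is applied. It is not part of the PR; it assumes a CUDA build of bitsandbytes and uses the keyword names visible in the diff above.

import torch
import bitsandbytes.functional as F

p = torch.randn(1024, 1024, device="cuda")   # parameter; left untouched when return_updates is set
g = torch.randn_like(p)                      # gradient
state1 = torch.zeros_like(p)                 # Adam first moment (exp_avg)
state2 = torch.zeros_like(p)                 # Adam second moment (exp_avg_sq)
updates = torch.zeros_like(p)                # buffer that receives the raw update

F.optimizer_update_32bit(
    "adam", g, p, state1,
    beta1=0.9, eps=1e-8, step=1, lr=1e-3,
    state2=state2, beta2=0.999,
    return_updates=updates,                  # update is written here instead of into p
)

# The caller then applies (or, for GaLore, projects back) the update itself:
# p.data.add_(updates)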
1 change: 1 addition & 0 deletions bitsandbytes/optim/__init__.py
@@ -9,6 +9,7 @@
AdamW,
AdamW8bit,
AdamW32bit,
GaLoreAdamW8bit,
PagedAdamW,
PagedAdamW8bit,
PagedAdamW32bit,
123 changes: 122 additions & 1 deletion bitsandbytes/optim/adamw.py
@@ -2,7 +2,17 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer2State
import torch

from bitsandbytes.optim.optimizer import GaLoreWrappedParameter, Optimizer2State

_galore_available = False
try:
from galore_torch.galore_projector import GaLoreProjector

_galore_available = True
except ImportError:
pass


class AdamW(Optimizer2State):
@@ -127,6 +137,117 @@ def __init__(
)


class GaLoreAdamW8bit(Optimizer2State):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
amsgrad=False,
optim_bits=8,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
if not _galore_available:
raise RuntimeError("The galore_torch package must be installed to use GaLoreAdamW8bit.")
super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=is_paged,
)

@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.

Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()

overflows = []

if not self.initialized:
self.check_overrides()
self.to_gpu() # needed for fairseq pure fp16 training
self.initialized = True

# if self.is_paged: self.page_mng.prefetch_all()
for gindex, group in enumerate(self.param_groups):
for pindex, p in enumerate(group["params"]):
if p.grad is None:
continue
state = self.state[p]

if "step" not in state:
state["step"] = 0

# GaLore Projection
if "rank" in group:
if "projector" not in state:
state["projector"] = GaLoreProjector(
group["rank"],
update_proj_gap=group["update_proj_gap"],
scale=group["scale"],
proj_type=group["proj_type"],
)

grad = state["projector"].project(p.grad, state["step"])

# suboptimal implementation
# p.saved_data = p.data.clone()
# p.data = grad.clone().to(p.data.dtype).to(p.data.device)
# p.data.zero_()
# p.grad = grad
lor_update = torch.zeros_like(
grad, dtype=p.data.dtype, device=p.data.device, requires_grad=grad.requires_grad
)

if "state1" not in state:
self.init_state(group, p, gindex, pindex)

self.prefetch_state(p)

if "rank" in group:
galore_p = GaLoreWrappedParameter(p=p, grad=grad)
self.update_step(group, galore_p, gindex, pindex, return_updates=lor_update)

# GaLore Projection Back
p.data.add_(state["projector"].project_back(lor_update))
Review comment (Member, PR author) on the line above:
From Algorithm 1 in the paper:

update = project_back(lor_update)
weight.data += update

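To make the referenced algorithm concrete, below is a rough, self-contained sketch of the projection round-trip. It is not the galore_torch implementation; the real GaLoreProjector additionally handles `proj_type`, the `update_proj_gap` refresh schedule, and wide versus tall matrices.

import torch

m, n, rank, scale = 256, 512, 32, 0.25
grad = torch.randn(m, n)

# Low-rank projection matrix built from the gradient's top singular directions.
U, _, _ = torch.linalg.svd(grad, full_matrices=False)
P = U[:, :rank]                               # (m, rank)

lor_grad = P.T @ grad                         # project:      (rank, n) low-rank gradient
lor_update = -1e-3 * lor_grad                 # stand-in for the 8-bit Adam step in low-rank space
full_update = scale * (P @ lor_update)        # project_back: (m, n) full-rank update
# weight.data += full_update                  # as in Algorithm 1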

if "weight_decay" in group and group["weight_decay"] > 0:
p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay"])
else:
self.update_step(group, p, gindex, pindex)

torch.cuda.synchronize()

if self.is_paged:
# all paged operations are asynchronous, we need
# to sync to make sure all tensors are in the right state
torch.cuda.synchronize()

return loss


class AdamW32bit(Optimizer2State):
def __init__(
self,
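Before turning to the optimizer.py changes, here is a hypothetical usage sketch of the new GaLoreAdamW8bit class. It is not from the PR; it assumes the per-group keys read by `step()` above (`rank`, `update_proj_gap`, `scale`, `proj_type`, with "std" assumed to be a valid galore_torch projection type), an installed `galore_torch`, and a CUDA device.

import torch
import torch.nn as nn
from bitsandbytes.optim import GaLoreAdamW8bit

model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10)).cuda()

# Project only the 2-D weight matrices; leave biases and other params as regular AdamW8bit.
galore_params = [p for p in model.parameters() if p.ndim == 2]
other_params = [p for p in model.parameters() if p.ndim != 2]

optimizer = GaLoreAdamW8bit(
    [
        {"params": galore_params, "rank": 128, "update_proj_gap": 200, "scale": 0.25, "proj_type": "std"},
        {"params": other_params},
    ],
    lr=1e-3,
)

loss = model(torch.randn(8, 512, device="cuda")).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()

As elsewhere in bitsandbytes, parameters smaller than min_8bit_size (4096 elements by default) fall back to the 32-bit state path, so the small bias vectors here would not be quantized.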
44 changes: 33 additions & 11 deletions bitsandbytes/optim/optimizer.py
@@ -4,8 +4,9 @@
# LICENSE file in the root directory of this source tree.
from collections import abc as container_abcs, defaultdict
from copy import deepcopy
from dataclasses import dataclass
from itertools import chain
from typing import Optional
from typing import Any, Dict, Optional, Union

import torch

@@ -18,6 +19,12 @@ def __init__(self, initial_data):
setattr(self, key, initial_data[key])


@dataclass
class GaLoreWrappedParameter:
p: torch.Tensor
grad: torch.Tensor


class GlobalOptimManager:
"""
A global optimizer manager for enabling custom optimizer configs.
@@ -320,7 +327,7 @@ def get_config(self, gindex, pindex, group):
def init_state(self, group, p, gindex, pindex):
raise NotImplementedError("init_state method needs to be overridden")

def update_step(self, group, p, gindex, pindex):
def update_step(self, group, p, gindex, pindex, return_updates):
raise NotImplementedError("The update_step method needs to be overridden")

def get_state_buffer(self, p, dtype=torch.float32):
@@ -494,13 +501,25 @@ def init_state(self, group, p, gindex, pindex):
state["unorm_vec"] = torch.zeros((1,), device=p.device)

@torch.no_grad()
def update_step(self, group, p, gindex, pindex):
# avoid update error from non-contiguous memory layout
p.data = p.data.contiguous()
p.grad = p.grad.contiguous()
def update_step(
self,
group: Dict[str, Any],
p: Union[torch.Tensor, GaLoreWrappedParameter],
gindex: int,
pindex: int,
return_updates: Optional[torch.Tensor] = None,
):
if isinstance(p, GaLoreWrappedParameter):
# Unwrap for GaLore
param_to_optimize = p.p
else:
param_to_optimize = p

state = self.state[p]
grad = p.grad
state = self.state[param_to_optimize]

# avoid update error from non-contiguous memory layout
param_to_optimize.data = param_to_optimize.data.contiguous()
grad = p.grad.contiguous()

config = self.get_config(gindex, pindex, group)

@@ -521,7 +540,7 @@ def update_step(self, group, p, gindex, pindex):
F.optimizer_update_32bit(
self.optimizer_name,
grad,
p,
param_to_optimize,
state["state1"],
config["betas"][0],
config["eps"],
@@ -536,13 +555,14 @@ def update_step(self, group, p, gindex, pindex):
state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
max_unorm=config["max_unorm"],
skip_zeros=config["skip_zeros"],
return_updates=return_updates,
)

elif state["state1"].dtype == torch.uint8 and not config["block_wise"]:
F.optimizer_update_8bit(
self.optimizer_name,
grad,
p,
param_to_optimize,
state["state1"],
state["state2"],
config["betas"][0],
@@ -560,6 +580,7 @@ def update_step(self, group, p, gindex, pindex):
gnorm_scale=gnorm_scale,
unorm_vec=state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
max_unorm=config["max_unorm"],
return_updates=return_updates,
)

# swap maxes
@@ -569,7 +590,7 @@ def update_step(self, group, p, gindex, pindex):
F.optimizer_update_8bit_blockwise(
self.optimizer_name,
grad,
p,
param_to_optimize,
state["state1"],
state["state2"],
config["betas"][0],
Expand All @@ -586,6 +607,7 @@ def update_step(self, group, p, gindex, pindex):
config["weight_decay"],
gnorm_scale=gnorm_scale,
skip_zeros=config["skip_zeros"],
return_updates=return_updates,
)

