"""Util functions."""
import jax
from jax import numpy as jnp
from jax import random as jr
from jax import tree_util as jtu
from jax import Array
from optax import GradientTransformation
from typing import Tuple
import ml_dtypes
# Other util functions.
def merge_non_zero_dict(target, source):
    """Merges non-zero items in source dictionary into target dictionary.

    This is a mutable operation.
    """
    for key, value in source.items():
        if value != 0:
            target[key] = value


# Util functions for tree manipulation.
def zero_tree(tree):
    """Returns an all-zero tree with the same structure as the input."""
    return jtu.tree_map(jnp.zeros_like, tree)


def tree_add(tree1, tree2):
    """Entry-wise sum of two trees with the same structure."""
    return jtu.tree_map(lambda x, y: x + y, tree1, tree2)


def tree_subtract(tree1, tree2):
    """Entry-wise difference of two trees with the same structure."""
    return jtu.tree_map(lambda x, y: x - y, tree1, tree2)


def tree_multiply(tree1, tree2):
    """Entry-wise product of two trees with the same structure."""
    return jtu.tree_map(lambda x, y: x * y, tree1, tree2)


def tree_dot(tree1, tree2):
    """Sums `jnp.dot` of corresponding leaves over the whole tree."""
    return jtu.tree_reduce(
        lambda x, y: x + y,
        jtu.tree_map(lambda x, y: jnp.dot(x, y), tree1, tree2)
    )


def negative_tree(tree):
    """A `jtu.tree_map`-broadcasted version of tree -> -tree."""
    return jtu.tree_map(lambda t: -t, tree)


def tree_scalar_multiply(tree, scalar):
    """Multiplies every leaf of the tree by a scalar."""
    return jtu.tree_map(lambda x: scalar * x, tree)


def tree_l1_norm(tree):
    """Returns the l1 norm of the vectorized tree."""
    return jtu.tree_reduce(
        lambda x, y: x + y,
        jtu.tree_map(lambda x: jnp.sum(jnp.abs(x)), tree)
    )


def tree_l2_norm(tree):
    """Returns the l2 norm of the vectorized tree."""
    return jnp.sqrt(
        jtu.tree_reduce(
            lambda x, y: x + y, jtu.tree_map(lambda x: jnp.sum(x * x), tree)
        )
    )


def tree_squared_l2_norm(tree):
    """Returns the sum of squares across the entire tree, without taking the square root."""
    return jtu.tree_reduce(
        lambda x, y: x + y,
        jtu.tree_map(lambda x: jnp.sum(x * x), tree)
    )


# TODO: deprecated, to be removed.
def tree_norm(tree):
    """Returns the l2 norm of the vectorized tree."""
    return tree_l2_norm(tree)


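# Example (illustrative sketch of the norm helpers above, on a hypothetical tree):
#   t = {"a": jnp.array([3.0, -4.0])}
#   tree_l1_norm(t)          # -> 7.0
#   tree_l2_norm(t)          # -> 5.0
#   tree_squared_l2_norm(t)  # -> 25.0

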
def is_zero_tree(tree):
    """Checks if a tree only has zero entries."""
    return jtu.tree_reduce(
        lambda x, y: x & y, jtu.tree_map(lambda x: jnp.all(x == 0), tree)
    )


def is_finite_tree(tree):
    """Returns whether every leaf of the tree is finite."""
    leaves = jtu.tree_flatten(tree)[0]
    return jnp.all(
        jnp.array([jnp.all(jnp.isfinite(node)) for node in leaves]))


def tree_normalize(tree):
    """Scales the tree to unit l2 norm; an all-zero tree is returned unchanged."""
    # Use jax.lax.cond to avoid tracing issues.
    return jax.lax.cond(
        is_zero_tree(tree),
        true_fun=lambda _: tree,
        false_fun=lambda _: tree_scalar_multiply(tree, 1 / tree_norm(tree)),
        operand=None,
    )


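# Example (illustrative sketch; `grads` is a hypothetical pytree of arrays):
#   grads = {"w": jnp.array([3.0, 4.0]), "b": jnp.array([0.0])}
#   tree_l2_norm(tree_normalize(grads))   # -> 1.0
#   tree_normalize(zero_tree(grads))      # all-zero trees are returned unchanged

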
def tree_inner_product(tree1, tree2):
    """Returns the inner product of two trees, viewed as flattened vectors."""
    leaves1, _ = jtu.tree_flatten(tree1)
    leaves2, _ = jtu.tree_flatten(tree2)
    return sum(jnp.sum(a * b) for a, b in zip(leaves1, leaves2))


def tree_cosine_similarity(tree1, tree2):
    """Returns the cosine similarity of two trees."""
    return tree_inner_product(tree_normalize(tree1), tree_normalize(tree2))


def tree_norm_direction_decomposition(tree):
    """Decomposes a tree into its norm and its direction.

    Returns:
        The l2 norm of the tree (a 0-dimensional array) and the normalized tree.
        If the tree is all zeros, returns 0 as the norm and the all-zero tree.
    """
    def true_fun(_):
        return jnp.zeros([], jnp.float32), tree

    def false_fun(_):
        norm = tree_norm(tree)
        return norm, tree_scalar_multiply(tree, 1 / norm)

    return jax.lax.cond(
        is_zero_tree(tree), true_fun, false_fun, operand=None)
    # norm = tree_norm(tree)
    # # NOTE: we need to return a jax.Array to make sure both branches return the
    # # same data structure and thus avoid a lax.cond issue.
    # if norm == 0:
    #     return jnp.zeros([], jnp.float32), tree
    # return norm, tree_scalar_multiply(tree, 1/norm)


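# Example (illustrative sketch; `grads` is a hypothetical pytree of arrays):
#   norm, direction = tree_norm_direction_decomposition(grads)
#   tree_scalar_multiply(direction, norm)   # recovers `grads` up to float error

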
def random_unit_vector(key, tree):
    """Samples a pytree with the same structure as the input whose leaves jointly
    form a uniformly random unit vector.

    Returns:
        A new PRNGKey and a uniform random vector on the unit sphere.
    """
    # Construct a pytree of random keys.
    key, new_key = jr.split(key)
    keys = jr.split(key, num=len(jtu.tree_leaves(tree)))
    keys_tree = jtu.tree_unflatten(jtu.tree_structure(tree), keys)
    # Sample a Gaussian vector and normalize it onto the unit sphere.
    normal_vector = jtu.tree_map(
        lambda t, k: jr.normal(k, shape=t.shape, dtype=t.dtype),
        tree, keys_tree
    )
    return new_key, tree_normalize(normal_vector)


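# Example (illustrative sketch; `params` is a hypothetical pytree of float arrays):
#   key = jr.PRNGKey(0)
#   key, direction = random_unit_vector(key, params)   # thread `key` for later reuse
#   tree_l2_norm(direction)                            # -> ~1.0

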
def check_tree_structures_match(tree1, tree2):
    """Checks whether tree1 and tree2 have the same tree structure.

    Raises an error when the structures do not match.
    """
    if jtu.tree_structure(tree1) != jtu.tree_structure(tree2):
        raise ValueError("Input Pytrees do not have the same structure")


# ===============================================
# Other util functions
# ===============================================
def merge_dicts(*to_merge):
    """Merges dictionaries into a new dictionary; later entries override earlier ones."""
    result = {}
    for d in to_merge:
        result.update(d)
    return result


def get_accuracy(logits: Array, batch: Tuple[Array, Array], ignore_index: int = -100):
    """Computes token-level accuracy, skipping positions labeled with `ignore_index`."""
    input, target = batch  # [N, L], [N, L]
    predictions = jnp.argmax(logits, axis=2)  # [N, L, C] -> [N, L]
    return jnp.sum(predictions == target) / jnp.sum(target != ignore_index)


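# Example (illustrative sketch with hypothetical shapes N=1, L=3, C=2):
#   logits = jnp.array([[[2.0, 0.0], [0.0, 2.0], [1.0, 0.0]]])  # argmax -> [[0, 1, 0]]
#   tokens = jnp.zeros([1, 3], dtype=jnp.int32)                 # unused by the metric
#   labels = jnp.array([[0, 0, -100]])                          # last position is ignored
#   get_accuracy(logits, (tokens, labels))                      # -> 0.5 (1 of 2 counted)

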
def get_dtype(dtype: str):
    """Looks up a dtype by its (case-insensitive) name."""
    registry = {
        "bfloat16": ml_dtypes.bfloat16,
        "float16": jnp.float16,
    }
    return registry[dtype.lower()]


# TODO: This is hella slow. Needs a better solution.
def log_optax(base_optimizer, log_fn):
    """Wraps an optimizer so that `log_fn(updates, state, params)` runs on every update."""
    def init_fn(params):
        return base_optimizer.init(params)

    def update_fn(updates, state, params, hint=None):
        log_fn(updates, state, params)
        return base_optimizer.update(updates, state, params, hint)

    return GradientTransformation(init_fn, update_fn)


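# Example (illustrative sketch; assumes `base_optimizer` is a GradientTransformation
# whose `update` accepts the extra `hint` argument used above; the logging callback
# below is hypothetical):
#   def print_update_norm(updates, state, params):
#       jax.debug.print("update l2 norm = {n}", n=tree_l2_norm(updates))
#   optimizer = log_optax(base_optimizer, print_update_norm)

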
# Basically the same as the PyTorch function `cross_entropy`.
def softmax_cross_entropy(
    input,
    target,
    weight=None,
    ignore_index=-100,
    reduction="mean",
    label_smoothing=0.0,
    axis=None,
):
    """Computes softmax cross entropy between sets of logits and integer labels.

    Measures the probability error in discrete classification tasks in which
    the classes are mutually exclusive (each entry is in exactly one class).
    For example, each CIFAR-10 image is labeled with one and only one label:
    an image can be a dog or a truck, but not both.

    References:
        [Goodfellow et al, 2016](http://www.deeplearningbook.org/contents/prob.html)

    Args:
        input: Unnormalized log probabilities (logits), with shape `[..., num_classes]`.
        target: Integers specifying the correct class for each input, with shape
            `[...]`; alternatively, class probabilities with the same shape as `input`.
        weight: Optional per-class rescaling weights, with shape `[num_classes]`.
        ignore_index: Integer label whose positions are excluded from the loss.
        reduction: One of `"none"`, `"mean"`, or `"sum"`.
        label_smoothing: Fraction of probability mass redistributed uniformly over classes.
        axis: Class axis of `input`; defaults to the last axis.

    Returns:
        Cross entropy between each prediction and the corresponding target
        distribution, with shape `[...]` before reduction.
    """
    # This is like jnp.take_along_axis(jax.nn.log_softmax(...), ...) except that
    # we avoid subtracting the normalizer from all values, just from the values
    # for the correct labels.
    if axis is None:
        axis = input.ndim - 1
    if axis < 0:
        axis = input.ndim + axis

    C = input.shape[axis]
    if weight is not None:
        weight_shape = (
            (1,) * axis + (input.shape[axis],) + (1,) * (input.ndim - axis - 1)
        )
        weight = weight.reshape(weight_shape)

    if isinstance(target, int) or target.ndim != input.ndim:
        # Integer class labels.
        no_ignore = jax.lax.stop_gradient(target != ignore_index)
        logits_max = jnp.max(
            input, axis=axis, keepdims=True
        )  # , where=no_ignore, initial=-jnp.inf)
        logits = input - jax.lax.stop_gradient(logits_max)
        broadcast_shape = logits.shape[:axis] + (1,) + logits.shape[axis + 1:]
        log_normalizers = jnp.log(
            jnp.sum(
                jnp.exp(logits), axis=axis, where=no_ignore.reshape(broadcast_shape)
            )
        )
        labels_no_ignore = jnp.where(no_ignore, target, 0)
        if label_smoothing != 0 or weight is not None:
            one_hot_labels = jax.nn.one_hot(labels_no_ignore, num_classes=C, axis=axis)
            target_probs = (
                one_hot_labels * (1.0 - label_smoothing)
                + jnp.ones_like(one_hot_labels) / C * label_smoothing
            )
            if weight is not None:
                target_probs = target_probs * weight
                log_normalizers = log_normalizers * jnp.sum(target_probs, axis=axis)
            losses = -(
                jnp.sum(
                    target_probs * logits,
                    where=no_ignore.reshape(broadcast_shape),
                    axis=axis,
                )
                - log_normalizers
            )
        else:
            label_logits = jnp.take_along_axis(
                logits, labels_no_ignore[..., None], axis=axis
            )[..., 0]
            losses = log_normalizers - label_logits
        losses = jnp.where(no_ignore, losses, 0.0)
    else:
        # Probability targets (e.g. soft labels) with the same shape as `input`.
        target_probs = (
            target * (1.0 - label_smoothing)
            + jnp.ones_like(target) / C * label_smoothing
        )
        logits_max = jnp.max(input, axis=axis, keepdims=True)
        logits = input - jax.lax.stop_gradient(logits_max)
        log_normalizers = jnp.log(jnp.sum(jnp.exp(logits), axis=axis))
        if weight is not None:
            target_probs = target_probs * weight
            log_normalizers = log_normalizers * jnp.sum(target_probs, axis=axis)
        losses = -(jnp.sum(target_probs * logits, axis=axis) - log_normalizers)
        no_ignore = None

    if reduction == "none":
        return losses
    if reduction == "mean":
        return jnp.mean(losses, where=no_ignore)
    if reduction == "sum":
        return jnp.sum(losses, where=no_ignore)
    raise ValueError(f"Unknown reduction: {reduction}")
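

# Example (illustrative sketch; shapes and values are hypothetical):
#   logits = jnp.array([[2.0, 0.5, -1.0], [0.1, 0.2, 0.3]])  # [N, C]
#   labels = jnp.array([0, -100])                            # second entry is ignored
#   softmax_cross_entropy(logits, labels)                    # mean loss over non-ignored entries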