-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #25 from erfanzar/mojo-beta
`Mistral` Models Added
- Loading branch information
Showing
16 changed files
with
898 additions
and
301 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
42 changes: 2 additions & 40 deletions
42
lib/python/EasyDel/modules/falcon/modelling_falcon_flax.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
from jax.interpreters import pxla | ||
from jax.experimental.pjit import with_sharding_constraint as wsc | ||
import jax | ||
from flax import linen as nn | ||
from functools import partial | ||
|
||
ACT2FN = { | ||
"gelu": partial(nn.gelu, approximate=False), | ||
"relu": nn.relu, | ||
"silu": nn.swish, | ||
"swish": nn.swish, | ||
"gelu_new": partial(nn.gelu, approximate=True), | ||
|
||
} | ||
|
||
|
||
def get_names_from_partition_spec(partition_specs): | ||
names = set() | ||
if isinstance(partition_specs, dict): | ||
partition_specs = partition_specs.values() | ||
for item in partition_specs: | ||
if item is None: | ||
continue | ||
elif isinstance(item, str): | ||
names.add(item) | ||
else: | ||
names.update(get_names_from_partition_spec(item)) | ||
|
||
return list(names) | ||
|
||
|
||
def names_in_mesh(*names): | ||
return set(names) <= set(pxla.thread_resources.env.physical_mesh.axis_names) | ||
|
||
|
||
def with_sharding_constraint(x, partition_specs): | ||
axis_names = get_names_from_partition_spec(partition_specs) | ||
if names_in_mesh(*axis_names): | ||
x = wsc(x, partition_specs) | ||
return x | ||
|
||
|
||
def get_gradient_checkpoint_policy(name): | ||
return { | ||
'everything_saveable': jax.checkpoint_policies.everything_saveable, | ||
'nothing_saveable': jax.checkpoint_policies.nothing_saveable, | ||
'checkpoint_dots': jax.checkpoint_policies.checkpoint_dots, | ||
'checkpoint_dots_with_no_batch_dims': jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims, | ||
}[name] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.