Skip to content

Commit

Permalink
Fixing src. bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
erfanzar committed Oct 2, 2023
1 parent cc83767 commit 6f57394
Show file tree
Hide file tree
Showing 20 changed files with 34 additions and 34 deletions.
2 changes: 1 addition & 1 deletion examples/serving/causal-lm/llama-2-chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.lax
from EasyDel import JAXServer
from fjutils import get_float_dtype_by_name
from src.EasyDel.transform import llama_from_pretrained
from EasyDel.transform import llama_from_pretrained
from transformers import AutoTokenizer
import gradio as gr

Expand Down
4 changes: 2 additions & 2 deletions examples/serving/causal-lm/llama.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import absl.app
from transformers import AutoTokenizer
from src.EasyDel.serve import JAXServer
from EasyDel.serve import JAXServer
import EasyDel
from absl import flags

Expand Down Expand Up @@ -108,7 +108,7 @@


def main(argv):
conf = src.EasyDel.configs.configs.llama_configs[FLAGS.model_type]
conf = EasyDel.configs.configs.llama_configs[FLAGS.model_type]
config = EasyDel.LlamaConfig(**conf, rotary_type=FLAGS.rotary_type)
config.use_flash_attention = FLAGS.use_flash_attention
config.use_sacn_mlp = FLAGS.use_sacn_mlp
Expand Down
4 changes: 2 additions & 2 deletions examples/training/causal-lm/falcon.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from src.EasyDel import TrainArguments, CausalLMTrainer
from EasyDel import TrainArguments, CausalLMTrainer
from datasets import load_dataset
from huggingface_hub import HfApi
from src import EasyDel
Expand Down Expand Up @@ -159,7 +159,7 @@ def main(argv):
config.use_flash_attention = FLAGS.use_flash_attention
config.use_sacn_mlp = FLAGS.use_sacn_mlp
else:
conf = src.EasyDel.configs.configs.falcon_configs[FLAGS.model_type]
conf = EasyDel.configs.configs.falcon_configs[FLAGS.model_type]
config = EasyDel.FalconConfig(**conf, rotary_type=FLAGS.rotary_type)
config.use_flash_attention = FLAGS.use_flash_attention
config.use_sacn_mlp = FLAGS.use_sacn_mlp
Expand Down
2 changes: 1 addition & 1 deletion examples/training/causal-lm/gpt-j.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def main(argv):
config.use_flash_attention = FLAGS.use_flash_attention
config.use_sacn_mlp = FLAGS.use_sacn_mlp
else:
conf = src.EasyDel.configs.configs.gptj_configs[FLAGS.model_type]
conf = EasyDel.configs.configs.gptj_configs[FLAGS.model_type]
config = EasyDel.GPTJConfig(**conf, rotary_type=FLAGS.rotary_type)
config.use_flash_attention = FLAGS.use_flash_attention
config.use_sacn_mlp = FLAGS.use_sacn_mlp
Expand Down
4 changes: 2 additions & 2 deletions examples/training/causal-lm/llama.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import flax.core

from src.EasyDel import llama_from_pretrained
from EasyDel import llama_from_pretrained

from src.EasyDel import TrainArguments, CausalLMTrainer
from EasyDel import TrainArguments, CausalLMTrainer
from datasets import load_dataset
from huggingface_hub import HfApi
from src import EasyDel
Expand Down
4 changes: 2 additions & 2 deletions examples/training/causal-lm/mpt.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from src.EasyDel import TrainArguments, CausalLMTrainer
from EasyDel import TrainArguments, CausalLMTrainer
from datasets import load_dataset
from huggingface_hub import HfApi
from src import EasyDel
Expand Down Expand Up @@ -159,7 +159,7 @@ def main(argv):
config.use_flash_attention = FLAGS.use_flash_attention
config.use_sacn_mlp = FLAGS.use_sacn_mlp
else:
conf = src.EasyDel.configs.configs.mpt_configs[FLAGS.model_type]
conf = EasyDel.configs.configs.mpt_configs[FLAGS.model_type]
config = EasyDel.MptConfig(**conf, rotary_type=FLAGS.rotary_type)
config.use_flash_attention = FLAGS.use_flash_attention
config.use_sacn_mlp = FLAGS.use_sacn_mlp
Expand Down
4 changes: 2 additions & 2 deletions lib/python/EasyDel/configs/get_model_attr.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import logging

from src.EasyDel import FlaxLlamaForCausalLM, FlaxMptForCausalLM, FlaxGPTJForCausalLM, FlaxGPTNeoXForCausalLM, \
from EasyDel import FlaxLlamaForCausalLM, FlaxMptForCausalLM, FlaxGPTJForCausalLM, FlaxGPTNeoXForCausalLM, \
FlaxT5ForConditionalGeneration, FlaxOPTForCausalLM, FlaxFalconForCausalLM

from src.EasyDel.configs.configs import llama_2_configs, gptj_configs, mpt_configs, opt_configs, falcon_configs, \
from EasyDel.configs.configs import llama_2_configs, gptj_configs, mpt_configs, opt_configs, falcon_configs, \
llama_configs

logger = logging.getLogger()
Expand Down
2 changes: 1 addition & 1 deletion lib/python/EasyDel/rlhf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from einops import rearrange
from typing import Union, Optional, List

from src.EasyDel import FlaxMptForCausalLM, FlaxLlamaForCausalLM, MptConfig, LlamaConfig
from EasyDel import FlaxMptForCausalLM, FlaxLlamaForCausalLM, MptConfig, LlamaConfig


# Converted from Pytorch To jax from LudicRrain Guy :)
Expand Down
6 changes: 3 additions & 3 deletions lib/python/EasyDel/serve/serve_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@
from fjutils import get_float_dtype_by_name
from jax.sharding import Mesh, PartitionSpec as Ps
from transformers import GenerationConfig, TextIteratorStreamer
from src.EasyDel.serve import seafoam
from EasyDel.serve import seafoam
import logging
from src.EasyDel.utils import RNG
from EasyDel.utils import RNG
import multiprocessing as mp
import torch
from src.EasyDel.utils import prefix_str
from EasyDel.utils import prefix_str

pjit = pjit.pjit

Expand Down
6 changes: 3 additions & 3 deletions lib/python/EasyDel/trainer/fsdp_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,21 @@
import wandb
from datasets import Dataset

from src.EasyDel.trainer.config import TrainArguments
from EasyDel.trainer.config import TrainArguments

import jax
import flax
from transformers import FlaxAutoModelForCausalLM, AutoConfig
from tqdm import tqdm
from src.EasyDel.utils import Timers
from EasyDel.utils import Timers
from EasyDel.smi import initialise_tracking, get_mem
from jax.experimental.pjit import pjit, with_sharding_constraint
from jax.sharding import PartitionSpec
from flax.training import train_state
from jax import numpy as jnp
from torch.utils.data import DataLoader
from fjutils import match_partition_rules, make_shard_and_gather_fns, StreamingCheckpointer, count_params
from src.EasyDel.utils import prefix_print
from EasyDel.utils import prefix_print


def calculate_accuracy(predictions: jax.Array, targets: jax.Array):
Expand Down
4 changes: 2 additions & 2 deletions lib/python/EasyDel/trainer/training_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
import os
import typing

from src.EasyDel.trainer.config import TrainArguments
from EasyDel.trainer.config import TrainArguments

import jax
import flax
import optax
from transformers import FlaxAutoModelForCausalLM, AutoConfig
from jax.sharding import PartitionSpec
from src.EasyDel.utils import Timers
from EasyDel.utils import Timers

from jax.experimental.pjit import pjit, with_sharding_constraint
from flax.training import train_state
Expand Down
2 changes: 1 addition & 1 deletion lib/python/EasyDel/transform/falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from jax import numpy as jnp
from tqdm import tqdm
from transformers import FalconForCausalLM
from src.EasyDel.modules.falcon import FalconConfig
from EasyDel.modules.falcon import FalconConfig


def falcon_from_pretrained(model_id, device=jax.devices('cpu')[0]):
Expand Down
2 changes: 1 addition & 1 deletion lib/python/EasyDel/transform/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import jax
import torch
from transformers import LlamaForCausalLM
from src.EasyDel.modules.llama import LlamaConfig
from EasyDel.modules.llama import LlamaConfig


def llama_convert_hf_to_flax_load(checkpoints_dir, num_hidden_layers=32, num_attention_heads=32, hidden_size=4096,
Expand Down
2 changes: 1 addition & 1 deletion lib/python/EasyDel/transform/mpt.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from src.EasyDel import MptConfig, FlaxMptForCausalLM
from EasyDel import MptConfig, FlaxMptForCausalLM
from jax import numpy as jnp
import jax
import torch
Expand Down
2 changes: 1 addition & 1 deletion lib/python/examples/serving/causal-lm/llama-2-chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.lax
from EasyDel import JAXServer
from fjutils import get_float_dtype_by_name
from src.EasyDel.transform import llama_from_pretrained
from EasyDel.transform import llama_from_pretrained
from transformers import AutoTokenizer
import gradio as gr

Expand Down
4 changes: 2 additions & 2 deletions lib/python/examples/serving/causal-lm/llama.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import absl.app
from transformers import AutoTokenizer
from src.EasyDel.serve import JAXServer
from EasyDel.serve import JAXServer
import EasyDel
from absl import flags

Expand Down Expand Up @@ -108,7 +108,7 @@


def main(argv):
conf = src.EasyDel.configs.configs.llama_configs[FLAGS.model_type]
conf = EasyDel.configs.configs.llama_configs[FLAGS.model_type]
config = EasyDel.LlamaConfig(**conf, rotary_type=FLAGS.rotary_type)
config.use_flash_attention = FLAGS.use_flash_attention
config.use_sacn_mlp = FLAGS.use_sacn_mlp
Expand Down
4 changes: 2 additions & 2 deletions lib/python/examples/training/causal-lm/falcon.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from src.EasyDel import TrainArguments, CausalLMTrainer
from EasyDel import TrainArguments, CausalLMTrainer
from datasets import load_dataset
from huggingface_hub import HfApi
from src import EasyDel
Expand Down Expand Up @@ -159,7 +159,7 @@ def main(argv):
config.use_flash_attention = FLAGS.use_flash_attention
config.use_sacn_mlp = FLAGS.use_sacn_mlp
else:
conf = src.EasyDel.configs.configs.falcon_configs[FLAGS.model_type]
conf = EasyDel.configs.configs.falcon_configs[FLAGS.model_type]
config = EasyDel.FalconConfig(**conf, rotary_type=FLAGS.rotary_type)
config.use_flash_attention = FLAGS.use_flash_attention
config.use_sacn_mlp = FLAGS.use_sacn_mlp
Expand Down
2 changes: 1 addition & 1 deletion lib/python/examples/training/causal-lm/gpt-j.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def main(argv):
config.use_flash_attention = FLAGS.use_flash_attention
config.use_sacn_mlp = FLAGS.use_sacn_mlp
else:
conf = src.EasyDel.configs.configs.gptj_configs[FLAGS.model_type]
conf = EasyDel.configs.configs.gptj_configs[FLAGS.model_type]
config = EasyDel.GPTJConfig(**conf, rotary_type=FLAGS.rotary_type)
config.use_flash_attention = FLAGS.use_flash_attention
config.use_sacn_mlp = FLAGS.use_sacn_mlp
Expand Down
4 changes: 2 additions & 2 deletions lib/python/examples/training/causal-lm/llama.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import flax.core

from src.EasyDel import llama_from_pretrained
from EasyDel import llama_from_pretrained

from src.EasyDel import TrainArguments, CausalLMTrainer
from EasyDel import TrainArguments, CausalLMTrainer
from datasets import load_dataset
from huggingface_hub import HfApi
from src import EasyDel
Expand Down
4 changes: 2 additions & 2 deletions lib/python/examples/training/causal-lm/mpt.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from src.EasyDel import TrainArguments, CausalLMTrainer
from EasyDel import TrainArguments, CausalLMTrainer
from datasets import load_dataset
from huggingface_hub import HfApi
from src import EasyDel
Expand Down Expand Up @@ -159,7 +159,7 @@ def main(argv):
config.use_flash_attention = FLAGS.use_flash_attention
config.use_sacn_mlp = FLAGS.use_sacn_mlp
else:
conf = src.EasyDel.configs.configs.mpt_configs[FLAGS.model_type]
conf = EasyDel.configs.configs.mpt_configs[FLAGS.model_type]
config = EasyDel.MptConfig(**conf, rotary_type=FLAGS.rotary_type)
config.use_flash_attention = FLAGS.use_flash_attention
config.use_sacn_mlp = FLAGS.use_sacn_mlp
Expand Down

0 comments on commit 6f57394

Please sign in to comment.