Skip to content

Commit

Permalink
fix review
Browse files (browse the repository at this point in the history)
  • Loading branch information
hijkzzz committed Nov 15, 2023
1 parent 107f2f9 commit 65bb79a
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 25 deletions.
9 changes: 2 additions & 7 deletions examples/nlp/gpt/offline/launch_random_sampler.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -11,13 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from datetime import datetime, timedelta

import jsonlines
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from megatron.core import mpu
from pytorch_lightning.trainer.trainer import Trainer
from tqdm import tqdm
from utils import get_max_time_per_run, load_nemo_or_checkpoint, set_seed
Expand All @@ -29,13 +31,6 @@
from nemo_aligner.data.nlp.offline.builders import build_data_loader, build_dataset
from nemo_aligner.utils.utils import set_autocast_gpu_dtype

try:
from megatron.core import mpu

HAVE_MEGATRON_CORE = True
except (ImportError, ModuleNotFoundError):
HAVE_MEGATRON_CORE = False

if not torch.cuda.is_available():
raise OSError("GPU is needed for the inference")

Expand Down
9 changes: 2 additions & 7 deletions examples/nlp/gpt/offline/launch_reward_labeler.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -11,13 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from datetime import datetime, timedelta

import jsonlines
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from megatron.core import mpu
from pytorch_lightning.trainer.trainer import Trainer
from tqdm import tqdm
from utils import get_max_time_per_run, load_nemo_or_checkpoint, set_seed
Expand All @@ -30,13 +32,6 @@
from nemo_aligner.models.nlp.gpt.megatron_gpt_reward_model import MegatronGPTRewardModel
from nemo_aligner.utils.utils import set_autocast_gpu_dtype

try:
from megatron.core import mpu

HAVE_MEGATRON_CORE = True
except (ImportError, ModuleNotFoundError):
HAVE_MEGATRON_CORE = False

if not torch.cuda.is_available():
raise OSError("GPU is needed for the inference")

Expand Down
8 changes: 1 addition & 7 deletions examples/nlp/gpt/offline/utils.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -18,6 +18,7 @@

import numpy as np
import torch
from megatron.core import mpu
from omegaconf import OmegaConf, open_dict

from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel
Expand All @@ -26,13 +27,6 @@
from nemo.utils.model_utils import inject_model_parallel_rank
from nemo_aligner.models.nlp.gpt.megatron_gpt_reward_model import MegatronGPTRewardModel

try:
from megatron.core import mpu

HAVE_MEGATRON_CORE = True
except (ImportError, ModuleNotFoundError):
HAVE_MEGATRON_CORE = False


def load_nemo_or_checkpoint(model_class, trainer, cfg):
if cfg.gpt_model_file:
Expand Down
5 changes: 1 addition & 4 deletions examples/nlp/gpt/train_gpt_sft.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -20,6 +20,7 @@
from megatron.core import parallel_state
from omegaconf.omegaconf import OmegaConf, open_dict

from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_chat_dataset import get_prompt_template_example
from nemo.collections.nlp.data.language_modeling.megatron.megatron_batch_samplers import (
MegatronPretrainingBatchSampler,
)
Expand Down Expand Up @@ -90,10 +91,6 @@ def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False):

if cfg.model.data.get("chat", False):
# chat model, overwrite the prompt template
from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_chat_dataset import (
get_prompt_template_example,
)

prompt_template = get_prompt_template_example(cfg.model.data.chat_prompt_tokens)
gpt_cfg.data.train_ds.prompt_template = prompt_template
gpt_cfg.data.validation_ds.prompt_template = prompt_template
Expand Down

0 comments on commit 65bb79a

Please sign in to comment.