ChatGLM-6B LoRA fine-tuning raises "iteration over a 0-d tensor" once training reaches the configured eval_steps. The failure is shown below:
The code is as follows:

```python
# Imports inferred from the snippet; args, config, device_map, max_train_steps,
# CastOutputToFloat, data_collator and ModifiedTrainer are defined elsewhere
# in the poster's script.
import os
import sys

import torch
from torch.utils.tensorboard import SummaryWriter
from datasets import load_from_disk
from peft import LoraConfig, TaskType, get_peft_model, get_peft_model_state_dict
from transformers import AutoModel, TrainingArguments
from transformers.integrations import TensorBoardCallback


def train_v2(model, train_data, val_data):
    writer = SummaryWriter()
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    train_args = TrainingArguments(
        output_dir=args.output_path,
        do_train=True,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        gradient_accumulation_steps=1,
        learning_rate=3e-4,
        lr_scheduler_type="linear",
        warmup_ratio=0.05,
        max_steps=max_train_steps,
        fp16=True,
        logging_steps=100,
        eval_steps=100,
        save_steps=100,
        evaluation_strategy="steps" if args.test_size > 0 else "no",
        save_strategy="steps",
        load_best_model_at_end=True,
        remove_unused_columns=False,
        ddp_find_unused_parameters=False if ddp else None,
        ignore_data_skip=False,
        seed=10,
        data_seed=10,
        group_by_length=False,
        # deepspeed="./config/ds_config.json"
    )
    trainer = ModifiedTrainer(
        model=model,
        # optimizers=(optimizer, lr_scheduler),
        train_dataset=train_data,
        eval_dataset=val_data,
        args=train_args,
        callbacks=[TensorBoardCallback(writer)],
        # data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
        data_collator=data_collator,
    )
    old_state_dict = model.state_dict
    model.state_dict = (
        lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
    ).__get__(model, type(model))
    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)
    print("\n If there's a warning about missing keys above, please disregard :)")
    # trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
    trainer.train()
    writer.close()
    model.save_pretrained(args.output_path)


model = AutoModel.from_pretrained(
    args.model_name_or_path,
    config=config,
    # load_in_8bit=True,
    torch_dtype=torch.float16,
    device_map=device_map,
    trust_remote_code=True,
    revision="",
)
model = model.half()
model.supports_gradient_checkpointing = True
model.gradient_checkpointing_enable()
model.enable_input_require_grads()
model.config.use_cache = False
model.lm_head = CastOutputToFloat(model.lm_head)
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

train_data = load_from_disk(args.data_path)
train_v2(model, train_data, None)
```
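For context, the snippet references two helpers that are not shown in the issue (`CastOutputToFloat` and `data_collator`). The sketch below is an assumption modeled on the common PEFT/ChatGLM LoRA examples and may not match the poster's actual code (the pad token id of 0 is also an assumption):

```python
# Hypothetical versions of the helpers the snippet assumes but does not show.
import torch
import torch.nn as nn


class CastOutputToFloat(nn.Sequential):
    # Wraps lm_head so its fp16 logits are cast back to fp32 for a stable loss.
    def forward(self, x):
        return super().forward(x).to(torch.float32)


def data_collator(features):
    # Minimal collator: pad input_ids to the longest sequence in the batch
    # (pad id 0 assumed) and mask padded label positions with -100 so they
    # are ignored by the loss.
    max_len = max(len(f["input_ids"]) for f in features)
    input_ids, labels = [], []
    for f in features:
        pad = max_len - len(f["input_ids"])
        input_ids.append(list(f["input_ids"]) + [0] * pad)
        labels.append(list(f["labels"]) + [-100] * pad)
    return {
        "input_ids": torch.tensor(input_ids, dtype=torch.long),
        "labels": torch.tensor(labels, dtype=torch.long),
    }
```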
Modify the ModifiedTrainer part:

```python
from transformers import Trainer


class ModifiedTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model(
            input_ids=inputs["input_ids"],
            labels=inputs["labels"],
        )
        loss = outputs.loss
        return (loss, outputs) if return_outputs else loss

    def save_model(self, output_dir=None, _internal_call=False):
        self.model.save_pretrained(output_dir)
```
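A note beyond what is posted in the thread: if the error still appears during evaluation, one common workaround (an assumption on my part, not the fix above) is to keep the evaluation step loss-only so the Trainer never tries to iterate over 0-d model outputs. A minimal sketch, assuming the transformers 4.x `Trainer.prediction_step` signature:

```python
# Sketch only: restrict evaluation to the scalar loss. This is a workaround
# idea, not the poster's fix, and it disables metric computation on logits.
import torch
from transformers import Trainer


class LossOnlyTrainer(Trainer):
    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        inputs = self._prepare_inputs(inputs)
        with torch.no_grad():
            loss = self.compute_loss(model, inputs)
        # Return only the detached loss; logits and labels are skipped on purpose.
        return (loss.detach(), None, None)
```

Returning `(loss, None, None)` tells the evaluation loop there are no logits or labels to gather, which sidesteps the output handling entirely at the cost of loss-only metrics.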