Skip to content

Commit

Permalink
Train with yelp data perhaps
Browse files Browse the repository at this point in the history
  • Loading branch information
jzarnett committed Sep 16, 2023
1 parent 57e6199 commit d9e3250
Showing 1 changed file with 57 additions and 0 deletions.
57 changes: 57 additions & 0 deletions lectures/live-coding/L24/yelp_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import evaluate
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer
from transformers import TrainingArguments

def tokenize_function(examples):
return tokenizer(examples["text"], padding="max_length", truncation=True)


def print_gpu_utilization():
nvmlInit()
handle = nvmlDeviceGetHandleByIndex(0)
info = nvmlDeviceGetMemoryInfo(handle)
print(f"GPU memory occupied: {info.used // 1024 ** 2} MB.")


def compute_metrics(eval_pred):
logits, labels = eval_pred
predictions = np.argmax(logits, axis=-1)
return metric.compute(predictions=predictions, references=labels)


def print_summary(res):
print(f"Time: {res.metrics['train_runtime']:.2f}")
print(f"Samples/second: {res.metrics['train_samples_per_second']:.2f}")
print_gpu_utilization()


print("Starting up. Initial GPU utilization:")
print_gpu_utilization()
torch.ones((1, 1)).to("cuda")
print("Initialized Torch; current GPU utilization:")
print_gpu_utilization()

dataset = load_dataset("yelp_review_full")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

tokenized_datasets = dataset.map(tokenize_function, batched=True)

small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))

model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=5)
training_args = TrainingArguments(per_device_train_batch_size=8, output_dir="test_trainer")
metric = evaluate.load("accuracy")

trainer = Trainer(
model=model,
args=training_args,
train_dataset=small_train_dataset,
eval_dataset=small_eval_dataset,
compute_metrics=compute_metrics,
)

result = trainer.train()
print_summary(result)

0 comments on commit d9e3250

Please sign in to comment.