Skip to content

Commit

Permalink
torch.compile
Browse files Browse the repository at this point in the history
  • Loading branch information
wwbrannon committed Sep 24, 2024
1 parent ae9d201 commit 4dd3697
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions images/sentiment-topic/topic-embeds.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,13 @@
class SentenceEmbedder:
def __init__(self, data, output_dir, cache_dir='sentence-embeds-cache',
padding=True, truncation=True, max_length=512, batch_size=128,
sort_length=True, autocast=True, device_ids=None):
sort_length=True, autocast=True, torch_compile=None,
device_ids=None):
super().__init__()

if torch_compile is None:
torch_compile = bool(int(os.getenv('TORCH_COMPILE', '0')))

self.data = data
self.output_dir = output_dir
self._cache_dir = cache_dir
Expand Down Expand Up @@ -78,11 +82,12 @@ def __init__(self, data, output_dir, cache_dir='sentence-embeds-cache',
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
self.model = AutoModel.from_pretrained(self.model_name)

if torch_compile:
self.model = torch.compile(self.model)

if len(self.devices) > 1:
self.model = nn.DataParallel(self.model, device_ids=self.devices)

# self.model = torch.compile(self.model)

self.model = self.model.to(self.device)

@property
Expand Down

0 comments on commit 4dd3697

Please sign in to comment.