From 4dd3697612015cf35a9f9d20b33e5d753d8a38a9 Mon Sep 17 00:00:00 2001
From: William Brannon
Date: Tue, 24 Sep 2024 20:17:41 +0000
Subject: [PATCH] torch.compile

Add a `torch_compile` argument to `SentenceEmbedder.__init__`. When it is
left as None, the value is read from the TORCH_COMPILE environment
variable (parsed as an integer, defaulting to '0').

When enabled, the model is wrapped with torch.compile() right after it is
loaded, i.e. before the optional nn.DataParallel wrapping and before the
move to the target device. The previously commented-out torch.compile()
call is removed.

---
 images/sentiment-topic/topic-embeds.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/images/sentiment-topic/topic-embeds.py b/images/sentiment-topic/topic-embeds.py
index 07d7c86..a410415 100755
--- a/images/sentiment-topic/topic-embeds.py
+++ b/images/sentiment-topic/topic-embeds.py
@@ -26,9 +26,13 @@ class SentenceEmbedder:
 
     def __init__(self, data, output_dir, cache_dir='sentence-embeds-cache',
                  padding=True, truncation=True, max_length=512, batch_size=128,
-                 sort_length=True, autocast=True, device_ids=None):
+                 sort_length=True, autocast=True, torch_compile=None,
+                 device_ids=None):
         super().__init__()
 
+        if torch_compile is None:
+            torch_compile = bool(int(os.getenv('TORCH_COMPILE', '0')))
+
         self.data = data
         self.output_dir = output_dir
         self._cache_dir = cache_dir
@@ -78,11 +82,12 @@ def __init__(self, data, output_dir, cache_dir='sentence-embeds-cache',
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
         self.model = AutoModel.from_pretrained(self.model_name)
 
+        if torch_compile:
+            self.model = torch.compile(self.model)
+
         if len(self.devices) > 1:
             self.model = nn.DataParallel(self.model, device_ids=self.devices)
 
-        # self.model = torch.compile(self.model)
-
         self.model = self.model.to(self.device)
 
     @property