From 041d48f7e717fbc54c4ed82021bec8c2ee425a72 Mon Sep 17 00:00:00 2001 From: Henri Lemoine Date: Tue, 15 Aug 2023 23:47:01 -0400 Subject: [PATCH] minor refactor --- align_data/finetuning/finetuning_dataset.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/align_data/finetuning/finetuning_dataset.py b/align_data/finetuning/finetuning_dataset.py index 7ce60c88..c6d28c3d 100644 --- a/align_data/finetuning/finetuning_dataset.py +++ b/align_data/finetuning/finetuning_dataset.py @@ -2,9 +2,7 @@ import random from typing import List, Tuple, Generator -import torch -from sqlalchemy import func -from torch.utils.data import IterableDataset +from torch.utils.data import IterableDataset, get_worker_info from align_data.pinecone.pinecone_db_handler import PineconeDB from align_data.pinecone.text_splitter import ParagraphSentenceUnitTextSplitter @@ -39,7 +37,7 @@ def __init__( self.total_articles = session.query(Article).count() def __iter__(self): - worker_info = torch.utils.data.get_worker_info() + worker_info = get_worker_info() if worker_info is None: # Single-process loading return self._generate_pairs() else: # Multi-process loading