-
Notifications
You must be signed in to change notification settings - Fork 1
/
worker-device.patch
39 lines (37 loc) · 2.33 KB
/
worker-device.patch
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
diff --git a/TrainingInterfaces/Text_to_Spectrogram/AutoAligner/autoaligner_train_loop.py b/TrainingInterfaces/Text_to_Spectrogram/AutoAligner/autoaligner_train_loop.py
index 5fe2cd9c..ff5c5bdc 100644
--- a/TrainingInterfaces/Text_to_Spectrogram/AutoAligner/autoaligner_train_loop.py
+++ b/TrainingInterfaces/Text_to_Spectrogram/AutoAligner/autoaligner_train_loop.py
@@ -48,7 +48,7 @@ def train_loop(train_dataset,
train_loader = DataLoader(batch_size=batch_size,
dataset=train_dataset,
drop_last=True,
- num_workers=8,
+ num_workers=12 if os.cpu_count() > 12 else max(os.cpu_count() - 2, 1),
pin_memory=False,
shuffle=True,
prefetch_factor=16,
diff --git a/TrainingInterfaces/Text_to_Spectrogram/PortaSpeech/Glow.py b/TrainingInterfaces/Text_to_Spectrogram/PortaSpeech/Glow.py
index 804548a4..4e3cfac6 100644
--- a/TrainingInterfaces/Text_to_Spectrogram/PortaSpeech/Glow.py
+++ b/TrainingInterfaces/Text_to_Spectrogram/PortaSpeech/Glow.py
@@ -119,7 +119,7 @@ class InvConvNear(nn.Module):
if self.no_jacobian:
logdet = 0
- weight = weight.view(self.n_split, self.n_split, 1, 1)
+ weight = weight.view(self.n_split, self.n_split, 1, 1).to(x.device)
z = F.conv2d(x, weight)
z = z.view(b, self.n_sqz, self.n_split // self.n_sqz, c // self.n_split, t)
diff --git a/TrainingInterfaces/Text_to_Spectrogram/PortaSpeech/portaspeech_train_loop.py b/TrainingInterfaces/Text_to_Spectrogram/PortaSpeech/portaspeech_train_loop.py
index 3b596ddb..ea9c80f9 100644
--- a/TrainingInterfaces/Text_to_Spectrogram/PortaSpeech/portaspeech_train_loop.py
+++ b/TrainingInterfaces/Text_to_Spectrogram/PortaSpeech/portaspeech_train_loop.py
@@ -78,7 +78,7 @@ def train_loop(net,
train_loader = DataLoader(batch_size=batch_size,
dataset=train_dataset,
drop_last=True,
- num_workers=12,
+ num_workers=12 if os.cpu_count() > 12 else max(os.cpu_count() - 2, 1),
pin_memory=True,
shuffle=True,
prefetch_factor=8,