Skip to content

Commit

Permalink
fix possible bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
ostix360 committed Jun 27, 2024
1 parent 8aa7deb commit 46c9d5f
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 11 deletions.
3 changes: 1 addition & 2 deletions audioenhancer/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
"""

SAMPLING_RATE = 16000
-UPSAMPLE_RATE = 32000
MAX_AUDIO_LENGTH = 10
BATCH_SIZE = 3
EPOCH = 1
Expand All @@ -12,4 +11,4 @@
SAVE_STEPS = 1000

INPUT_FREQ = 16000
-OUTPUT_FREQ = 16000
+OUTPUT_FREQ = 32000
6 changes: 3 additions & 3 deletions scripts/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch
import torchaudio

-from audioenhancer.constants import SAMPLING_RATE, MAX_AUDIO_LENGTH
+from audioenhancer.constants import SAMPLING_RATE, MAX_AUDIO_LENGTH, INPUT_FREQ
from audioenhancer.model.audio_ae.auto_encoder import AutoEncoder1d

parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -81,7 +81,7 @@ def load(waveform_path):
"""
waveform, sample_rate = torchaudio.load(waveform_path)
resampler = torchaudio.transforms.Resample(
-sample_rate, args.sampling_rate, dtype=waveform.dtype
+sample_rate, INPUT_FREQ, dtype=waveform.dtype
)
waveform = resampler(waveform)
if waveform.shape[0] == 1:
Expand Down Expand Up @@ -118,5 +118,5 @@ def load(waveform_path):
output = output.detach().cpu()
audio = audio.squeeze(0).detach().cpu()

-torchaudio.save("./data/input.mp3", audio.T, args.sampling_rate, channels_first=False)
+torchaudio.save("./data/input.mp3", audio.T, SAMPLING_RATE, channels_first=False)
torchaudio.save("./data/output.mp3", output.T, args.sampling_rate, channels_first=False)
9 changes: 3 additions & 6 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,8 @@

x = batch[0].to(device, dtype=dtype)
y = batch[1].to(device, dtype=dtype)
up_y = batch[2].to(device, dtype=dtype)

-y_hat = model(y)
+y_hat = model(x)
loss = sum(loss(y_hat, up_y) for loss in loss_fn)
loss.backward()

Expand Down Expand Up @@ -255,10 +254,8 @@
for batch in test_loader:
x = batch[0].to(device, dtype=dtype)
y = batch[1].to(device, dtype=dtype)
-up_y = batch[2].to(device, dtype=dtype)
-
-y_hat = model(y)
-loss = sum(loss(y_hat, up_y) for loss in loss_fn)
+y_hat = model(x)
+loss = sum(loss(y_hat, y) for loss in loss_fn)
# batch_disc = torch.cat([y, y_hat], dim=0)
# disc_pred = discriminator(batch_disc)
# disc_pred = torch.sigmoid(disc_pred).squeeze()
Expand Down

0 comments on commit 46c9d5f

Please sign in to comment.