Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated docstrings, added missing docstrings, minor formatting/spacing fixes #849

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 76 additions & 14 deletions demucs/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

progress_bar_num = 0


class BagOfModels(nn.Module):
def __init__(self, models: tp.List[Model],
weights: tp.Optional[tp.List[tp.List[float]]] = None,
Expand All @@ -43,7 +44,7 @@ def __init__(self, models: tp.List[Model],
segment (None or float): overrides the `segment` attribute of each model
(this is performed inplace, be careful if you reuse the models passed).
"""

super().__init__()
assert len(models) > 0
first = models[0]
Expand All @@ -68,10 +69,29 @@ def __init__(self, models: tp.List[Model],
self.weights = weights

def forward(self, x):
    """Not supported for a bag of models.

    A ``BagOfModels`` cannot be called directly; use ``apply_model``
    instead, which combines the predictions of the sub-models.

    Parameters:
        x: The input data (ignored).

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("Call `apply_model` on this.")

class TensorChunk:
"""
This class represents a chunk of a tensor.
"""
def __init__(self, tensor, offset=0, length=None):
"""
Initialize a new TensorChunk with a tensor, offset, and length.

Parameters:
tensor: The tensor object.
offset (int): The starting offset of the chunk.
length (int): The length of the chunk.
"""
total_length = tensor.shape[-1]
assert offset >= 0
assert offset < total_length
Expand All @@ -92,11 +112,23 @@ def __init__(self, tensor, offset=0, length=None):

@property
def shape(self):
    """Shape of the chunk: the underlying tensor's shape with the last
    dimension replaced by the chunk length."""
    return [*self.tensor.shape[:-1], self.length]

def padded(self, target_length):
"""
Pad the tensor chunk to a target length.

Parameters:
target_length (int): The desired length of the padded tensor chunk.

Returns:
The padded tensor chunk.
"""
delta = target_length - self.length
total_length = self.tensor.shape[-1]
assert delta >= 0
Expand All @@ -115,12 +147,25 @@ def padded(self, target_length):
return out

def tensor_chunk(tensor_or_chunk):
    """Wrap the argument in a TensorChunk if it is not one already.

    An existing TensorChunk is returned unchanged; a torch tensor is
    wrapped in a fresh TensorChunk covering its full length.

    Parameters:
        tensor_or_chunk: A th.Tensor or a TensorChunk.

    Returns:
        TensorChunk: the corresponding chunk.
    """
    if isinstance(tensor_or_chunk, TensorChunk):
        return tensor_or_chunk
    # Guard clause: anything else must be a plain tensor.
    assert isinstance(tensor_or_chunk, th.Tensor)
    return TensorChunk(tensor_or_chunk)


def apply_model(model,
mix,
shifts=1,
Expand All @@ -137,24 +182,32 @@ def apply_model(model,
Apply model to a given mixture.

Args:
shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
and apply the oppositve shift to the output. This is repeated `shifts` time and
model: The model to apply.
mix: The mixture to apply the model on.
shifts (int): If > 0, will shift in time 'mix' by a random amount between 0 and 0.5 sec
and apply the opposite shift to the output. This is repeated 'shifts' times and
all predictions are averaged. This effectively makes the model time equivariant
and improves SDR by up to 0.2 points.
split (bool): if True, the input will be broken down in 8 seconds extracts
split (bool): If True, the input will be broken down into 8 seconds extracts
and predictions will be performed individually on each and concatenated.
Useful for model with large memory footprint like Tasnet.
progress (bool): if True, show a progress bar (requires split=True)
Useful for models with large memory footprint like Tasnet.
overlap (float): The overlap between the extracted segments.
transition_power (float): The power to apply to the transition weights.
static_shifts (int): The number of static shifts to apply.
set_progress_bar: A function to set the progress bar.
device (torch.device, str, or None): if provided, device on which to
execute the computation, otherwise `mix.device` is assumed.
When `device` is different from `mix.device`, only local computations will
be on `device`, while the entire tracks will be stored on `mix.device`.
progress (bool): If True, show a progress bar (requires split=True).
num_workers (int): The number of workers for parallel computation.
pool: The thread pool executor.
"""

global fut_length
global bag_num
global prog_bar

if device is None:
device = mix.device
else:
Expand All @@ -164,7 +217,7 @@ def apply_model(model,
pool = ThreadPoolExecutor(num_workers)
else:
pool = DummyPoolExecutor()

kwargs = {
'shifts': shifts,
'split': split,
Expand All @@ -176,7 +229,7 @@ def apply_model(model,
'set_progress_bar': set_progress_bar,
'static_shifts': static_shifts,
}

if isinstance(model, BagOfModels):
# Special treatment for bag of model.
# We explicitly apply multiple times `apply_model` so that the random shifts
Expand Down Expand Up @@ -209,7 +262,7 @@ def apply_model(model,
model.eval()
assert transition_power >= 1, "transition_power < 1 leads to weird behavior."
batch, channels, length = mix.shape

if shifts:
kwargs['shifts'] = 0
max_shift = int(0.5 * model.samplerate)
Expand Down Expand Up @@ -270,9 +323,18 @@ def apply_model(model,
with th.no_grad():
out = model(padded_mix)
return center_trim(out, length)

def demucs_segments(demucs_segment, demucs_model):

"""
Function to assign a segment value to models in demucs_model based on the value of demucs_segment.

Parameters:
demucs_segment (str or int): The segment value to assign to the models.
demucs_model (BagOfModels or Model): The models to assign the segment value to.

Returns:
demucs_model: The models with the segment value assigned.
"""

if demucs_segment == 'Default':
segment = None
if isinstance(demucs_model, BagOfModels):
Expand Down Expand Up @@ -301,5 +363,5 @@ def demucs_segments(demucs_segment, demucs_model):
else:
if segment is not None:
sub.segment = segment
return demucs_model

return demucs_model
Loading