Skip to content

Commit

Permalink
adding typing to apply_model
Browse files Browse the repository at this point in the history
  • Loading branch information
adefossez committed May 23, 2023
1 parent 5a79870 commit a8154eb
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 29 deletions.
2 changes: 1 addition & 1 deletion demucs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

__version__ = "4.0.0"
__version__ = "4.0.1a1"
69 changes: 46 additions & 23 deletions demucs/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ def __init__(self, models: tp.List[Model],
assert len(weight) == len(first.sources)
self.weights = weights

@property
def max_allowed_segment(self) -> float:
max_allowed_segment = float('inf')
for model in self.models:
max_allowed_segment = min(max_allowed_segment, float(model.segment))
return max_allowed_segment

def forward(self, x):
raise NotImplementedError("Call `apply_model` on this.")

Expand Down Expand Up @@ -121,9 +128,13 @@ def tensor_chunk(tensor_or_chunk):
return TensorChunk(tensor_or_chunk)


def apply_model(model, mix, shifts=1, split=True,
overlap=0.25, transition_power=1., progress=False, device=None,
num_workers=0, segment=None, pool=None):
def apply_model(model: tp.Union[BagOfModels, Model],
mix: tp.Union[th.Tensor, TensorChunk],
shifts: int = 1, split: bool = True,
overlap: float = 0.25, transition_power: float = 1.,
progress: bool = False, device=None,
num_workers: int = 0, segment: tp.Optional[float] = None,
pool=None) -> th.Tensor:
"""
Apply model to a given mixture.
Expand All @@ -140,6 +151,9 @@ def apply_model(model, mix, shifts=1, split=True,
execute the computation, otherwise `mix.device` is assumed.
When `device` is different from `mix.device`, only local computations will
be on `device`, while the entire tracks will be stored on `mix.device`.
num_workers (int): if non zero, device is 'cpu', how many threads to
use in parallel.
segment (float or None): override the model segment parameter.
"""
if device is None:
device = mix.device
Expand All @@ -150,7 +164,7 @@ def apply_model(model, mix, shifts=1, split=True,
pool = ThreadPoolExecutor(num_workers)
else:
pool = DummyPoolExecutor()
kwargs = {
kwargs: tp.Dict[str, tp.Any] = {
'shifts': shifts,
'split': split,
'overlap': overlap,
Expand All @@ -160,24 +174,26 @@ def apply_model(model, mix, shifts=1, split=True,
'pool': pool,
'segment': segment,
}
out: tp.Union[float, th.Tensor]
if isinstance(model, BagOfModels):
# Special treatment for bag of model.
# We explicitely apply multiple times `apply_model` so that the random shifts
# are different for each model.
estimates = 0
totals = [0] * len(model.sources)
for sub_model, weight in zip(model.models, model.weights):
estimates: tp.Union[float, th.Tensor] = 0.
totals = [0.] * len(model.sources)
for sub_model, model_weights in zip(model.models, model.weights):
original_model_device = next(iter(sub_model.parameters())).device
sub_model.to(device)

out = apply_model(sub_model, mix, **kwargs)
sub_model.to(original_model_device)
for k, inst_weight in enumerate(weight):
for k, inst_weight in enumerate(model_weights):
out[:, k, :, :] *= inst_weight
totals[k] += inst_weight
estimates += out
del out

assert isinstance(estimates, th.Tensor)
for k in range(estimates.shape[1]):
estimates[:, k, :, :] /= totals[k]
return estimates
Expand All @@ -190,60 +206,67 @@ def apply_model(model, mix, shifts=1, split=True,
kwargs['shifts'] = 0
max_shift = int(0.5 * model.samplerate)
mix = tensor_chunk(mix)
assert isinstance(mix, TensorChunk)
padded_mix = mix.padded(length + 2 * max_shift)
out = 0
out = 0.
for _ in range(shifts):
offset = random.randint(0, max_shift)
shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
shifted_out = apply_model(model, shifted, **kwargs)
out += shifted_out[..., max_shift - offset:]
out /= shifts
assert isinstance(out, th.Tensor)
return out
elif split:
kwargs['split'] = False
out = th.zeros(batch, len(model.sources), channels, length, device=mix.device)
sum_weight = th.zeros(length, device=mix.device)
if segment is None:
segment = model.segment
segment_old = model.segment
model.segment = segment
segment = int(model.samplerate * segment)
stride = int((1 - overlap) * segment)
assert segment is not None and segment > 0.
segment_length: int = int(model.samplerate * segment)
stride = int((1 - overlap) * segment_length)
offsets = range(0, length, stride)
scale = float(format(stride / model.samplerate, ".2f"))
# We start from a triangle shaped weight, with maximal weight in the middle
# of the segment. Then we normalize and take to the power `transition_power`.
# Large values of transition power will lead to sharper transitions.
weight = th.cat([th.arange(1, segment // 2 + 1, device=device),
th.arange(segment - segment // 2, 0, -1, device=device)])
assert len(weight) == segment
weight = th.cat([th.arange(1, segment_length // 2 + 1, device=device),
th.arange(segment_length - segment_length // 2, 0, -1, device=device)])
assert len(weight) == segment_length
# If the overlap < 50%, this will translate to linear transition when
# transition_power is 1.
weight = (weight / weight.max())**transition_power
futures = []
for offset in offsets:
chunk = TensorChunk(mix, offset, segment)
chunk = TensorChunk(mix, offset, segment_length)
future = pool.submit(apply_model, model, chunk, **kwargs)
futures.append((future, offset))
offset += segment
offset += segment_length
if progress:
futures = tqdm.tqdm(futures, unit_scale=scale, ncols=120, unit='seconds')
for future, offset in futures:
chunk_out = future.result()
chunk_length = chunk_out.shape[-1]
out[..., offset:offset + segment] += (weight[:chunk_length] * chunk_out).to(mix.device)
sum_weight[offset:offset + segment] += weight[:chunk_length].to(mix.device)
model.segment = segment_old
out[..., offset:offset + segment_length] += (
weight[:chunk_length] * chunk_out).to(mix.device)
sum_weight[offset:offset + segment_length] += weight[:chunk_length].to(mix.device)
assert sum_weight.min() > 0
out /= sum_weight
assert isinstance(out, th.Tensor)
return out
else:
if hasattr(model, 'valid_length'):
valid_length = model.valid_length(length)
valid_length: int
if isinstance(model, HTDemucs) and segment is not None:
valid_length = int(segment * model.samplerate)
elif hasattr(model, 'valid_length'):
valid_length = model.valid_length(length) # type: ignore
else:
valid_length = length
mix = tensor_chunk(mix)
assert isinstance(mix, TensorChunk)
padded_mix = mix.padded(valid_length).to(device)
with th.no_grad():
out = model(padded_mix)
assert isinstance(out, th.Tensor)
return center_trim(out, length)
14 changes: 9 additions & 5 deletions demucs/separate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from .apply import apply_model, BagOfModels
from .audio import AudioFile, convert_audio, save_audio
from .htdemucs import HTDemucs
from .pretrained import get_model_from_args, add_model_flags, ModelLoadingError


Expand Down Expand Up @@ -127,11 +128,14 @@ def main(opts=None):
except ModelLoadingError as error:
fatal(error.args[0])

if args.segment is not None and args.segment < 8:
fatal("Segment must greater than 8. ")

if '..' in args.filename.replace("\\", "/").split("/"):
fatal('".." must not appear in filename. ')
max_allowed_segment: float = float('inf')
if isinstance(model, HTDemucs):
max_allowed_segment = float(model.segment)
elif isinstance(model, BagOfModels):
max_allowed_segment = model.max_allowed_segment
if args.segment is not None and args.segment > max_allowed_segment:
fatal("Cannot use a Transformer model with a longer segment "
f"than it was trained for. Maximum segment is: {max_allowed_segment}")

if isinstance(model, BagOfModels):
print(f"Selected model is a bag of {len(model.models)} models. "
Expand Down
4 changes: 4 additions & 0 deletions docs/release.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Release notes for Demucs

## V4.0.1a1, TBD

Various improvements by @CarlGao4. Support for `segment` param inside of HTDemucs
model.

## V4.0.0, 7th of December 2022

Expand Down

0 comments on commit a8154eb

Please sign in to comment.