import sys

from glob import glob
from os.path import join, exists

# pylint: disable=import-error
import numpy as np
import pandas as pd
# pylint: enable=import-error

from .separate import entrypoint as separate_entrypoint
from ..utils.logging import get_logger

try:
    import musdb
    import museval
except ImportError:
    logger = get_logger()
    logger.error('Extra dependencies musdb and museval not found')
    logger.error('Please install musdb and museval first, aborting')
    sys.exit(1)
__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
_SPLIT = 'test'
_MIXTURE = 'mixture.wav'
_AUDIO_DIRECTORY = 'audio'
_METRICS_DIRECTORY = 'metrics'
_INSTRUMENTS = ('vocals', 'drums', 'bass', 'other')
_METRICS = ('SDR', 'SAR', 'SIR', 'ISR')
if bitrate:
    output_kwargs['audio_bitrate'] = bitrate
if codec is not None and codec != 'wav':
    output_kwargs['codec'] = codec
process = (
    ffmpeg
    .input('pipe:', format='f32le', **input_kwargs)
    .output(path, **output_kwargs)
    .overwrite_output()
    .run_async(pipe_stdin=True, quiet=True))
try:
    # Stream the waveform as little-endian float32 bytes, matching the
    # 'f32le' format declared on the ffmpeg input above.
    process.stdin.write(data.astype('<f4').tobytes())
    process.stdin.close()
    process.wait()
except IOError:
    raise SpleeterError(f'FFMPEG error: {process.stderr.read()}')
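# Usage sketch (hypothetical: `adapter` is assumed to be an instance of
# the audio adapter whose save method contains the pipe above): encode
# two seconds of stereo silence to MP3.
import numpy as np

waveform = np.zeros((2 * 44100, 2), dtype=np.float32)
adapter.save('/tmp/silence.mp3', waveform, sample_rate=44100, codec='mp3', bitrate='128k')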
def cache(self, dataset, cache, wait):
    """ Cache the given dataset if caching is enabled, optionally
    waiting for the cache to become available (useful if another
    process is already computing it) when the wait flag is True.

    :param dataset: Dataset to be cached if cache is required.
    :param cache: Path of cache directory to be used, None if no cache.
    :param wait: If caching is enabled, True if the cache should be
        waited for.
    :returns: Cached dataset if needed, original dataset otherwise.
    """
    if cache is not None:
        if wait:
            while not exists(f'{cache}.index'):
                get_logger().info(
                    'Cache not available, waiting %s seconds',
                    self.WAIT_PERIOD)
                time.sleep(self.WAIT_PERIOD)
        cache_path = os.path.split(cache)[0]
        os.makedirs(cache_path, exist_ok=True)
        return dataset.cache(cache)
    return dataset
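# Usage sketch (names hypothetical): the first process builds the cache
# with wait=False, while concurrent workers pass wait=True to block
# until the '<cache>.index' file appears on disk.
dataset = builder.cache(dataset, cache='/tmp/spleeter_cache/train', wait=True)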
def safe_load(path, offset, duration, sample_rate, dtype):
    """ Load audio without raising: failures are reported through the
    returned error flag so they can be filtered out of the pipeline.
    Note that `self` is captured from the enclosing method scope.
    """
    get_logger().info(
        f'Loading audio {path} from {offset} to {offset + duration}')
    try:
        (data, _) = self.load(
            path.numpy(),
            offset.numpy(),
            duration.numpy(),
            sample_rate.numpy(),
            dtype=dtype.numpy())
        get_logger().info('Audio data loaded successfully')
        return (data, False)
    except Exception as e:
        get_logger().warning(e)
        # Sentinel value with the error flag set.
        return (np.float32(-1.0), True)
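# Hedged sketch of how safe_load is intended to be used (the exact
# wiring is an assumption): wrapped in tf.py_function, the .numpy()
# calls above become valid and loading errors surface as a boolean
# tensor rather than an exception inside tf.data.
import tensorflow as tf

waveform, error = tf.py_function(
    safe_load,
    [path, offset, duration, sample_rate, dtype],  # tensors from the dataset
    (tf.float32, tf.bool))
# Failed samples carry the -1.0 sentinel and can then be dropped, e.g.:
# dataset = dataset.filter(lambda waveform, error: tf.logical_not(error))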
def entrypoint(arguments, params):
    """ Command entrypoint.

    :param arguments: Parsed command line arguments, as an argparse.Namespace.
    :param params: Deserialized JSON configuration file provided in CLI args.
    """
    audio_adapter = get_audio_adapter(arguments.audio_adapter)
    audio_path = arguments.audio_path
    estimator = _create_estimator(params)
    train_spec = _create_train_spec(params, audio_adapter, audio_path)
    evaluation_spec = _create_evaluation_spec(
        params,
        audio_adapter,
        audio_path)
    get_logger().info('Start model training')
    tf.estimator.train_and_evaluate(
        estimator,
        train_spec,
        evaluation_spec)
    get_logger().info('Model training done')
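# Hedged sketch of the specs consumed by train_and_evaluate (input
# function, shapes and step counts are illustrative assumptions, not
# the library's actual factories):
import tensorflow as tf

def _toy_input_fn():
    # One constant batch standing in for the real audio pipeline.
    features = {'waveform': tf.zeros((2, 44100, 2))}
    labels = {'vocals': tf.zeros((2, 44100, 2))}
    return tf.data.Dataset.from_tensors((features, labels)).repeat()

toy_train_spec = tf.estimator.TrainSpec(input_fn=_toy_input_fn, max_steps=1000)
toy_eval_spec = tf.estimator.EvalSpec(input_fn=_toy_input_fn, steps=10, throttle_secs=300)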
def _compute_musdb_metrics(
        arguments,
        musdb_root_directory,
        audio_output_directory):
    """ Generates musdb metrics from previously computed audio estimates.

    :param arguments: Entrypoint arguments.
    :param musdb_root_directory: Root directory of the musdb dataset.
    :param audio_output_directory: Directory to get audio estimates from.
    :returns: Path of generated metrics directory.
    """
    metrics_output_directory = join(
        arguments.output_path,
        _METRICS_DIRECTORY)
    get_logger().info('Starting musdb evaluation (this could be long) ...')
    dataset = musdb.DB(
        root=musdb_root_directory,
        is_wav=True,
        subsets=[_SPLIT])
    museval.eval_mus_dir(
        dataset=dataset,
        estimates_dir=audio_output_directory,
        output_dir=metrics_output_directory)
    get_logger().info('musdb evaluation done')
    return metrics_output_directory
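# Hedged sketch (the per-track JSON layout is assumed from museval's
# output format): aggregate median metric values per instrument with
# pandas, using the module constants defined above.
import json

rows = []
for metric_file in glob(join(metrics_output_directory, _SPLIT, '*.json')):
    with open(metric_file) as stream:
        track = json.load(stream)
    for target in track['targets']:
        for frame in target['frames']:
            rows.append({
                'target': target['name'],
                **{metric: frame['metrics'][metric] for metric in _METRICS}})
scores = pd.DataFrame(rows).groupby('target').median()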
try:
    parser = create_argument_parser()
    arguments = parser.parse_args(argv[1:])
    enable_logging()
    if arguments.verbose:
        enable_tensorflow_logging()
    if arguments.command == 'separate':
        from .commands.separate import entrypoint
    elif arguments.command == 'train':
        from .commands.train import entrypoint
    elif arguments.command == 'evaluate':
        from .commands.evaluate import entrypoint
    else:
        # Guard against an unknown command leaving entrypoint unbound.
        raise SpleeterError(f'Unknown command {arguments.command}')
    params = load_configuration(arguments.configuration)
    entrypoint(arguments, params)
except SpleeterError as e:
    get_logger().error(e)
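# Usage sketch (the enclosing function name is an assumption): the
# dispatcher above is typically fed sys.argv, e.g.
#   main(['spleeter', 'separate', '-i', 'audio.mp3', '-o', 'output'])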
# Assumed format string: composes the archive URL from the repository,
# release path, release tag and model name.
url = '{}/{}/{}/{}.tar.gz'.format(
    self._repository,
    self.RELEASE_PATH,
    self._release,
    name)
get_logger().info('Downloading model archive %s', url)
response = requests.get(url, stream=True)
if response.status_code != 200:
    raise IOError(f'Resource {url} not found')
with TemporaryFile() as stream:
    copyfileobj(response.raw, stream)
    get_logger().info('Extracting downloaded %s archive', name)
    stream.seek(0)
    with tarfile.open(fileobj=stream) as tar:
        tar.extractall(path=path)
get_logger().info('%s model file(s) extracted', name)
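# Usage sketch (provider instance and model name are hypothetical,
# assuming `provider` is an instance of the class defining the
# download method above):
provider.download('2stems', 'pretrained_models/2stems')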