Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
def _data_processing(config, params):
    """Populate ``config`` with data-pipeline settings drawn from ``params``.

    Optional keys are read with ``params.get(key, default)``. The function
    writes the batch shape, chunk count, pre/post-processing callables, a
    fixed open-loop context of 5, the episode reader, the selected episode
    loader and the action-bounding callable onto ``config``.

    Args:
        config: Mutable attribute container that receives the settings.
        params: Mapping-like object supporting ``.get(key, default)``.

    Returns:
        The same ``config`` object, after mutation.

    Raises:
        KeyError: If ``params.get('loader', 'recent')`` is not one of
            ``'cache'``, ``'recent'``, ``'reload'`` or ``'dummy'``.
    """
    config.batch_shape = params.get('batch_shape', (50, 50))
    config.num_chunks = params.get('num_chunks', 1)

    # Both processing callables are bound to the same ``bits`` value.
    bits = params.get('image_bits', 5)
    config.preprocess_fn = tools.bind(tools.preprocess.preprocess, bits=bits)
    config.postprocess_fn = tools.bind(tools.preprocess.postprocess, bits=bits)

    config.open_loop_context = 5

    episodes = tools.numpy_episodes
    config.data_reader = episodes.episode_reader

    # All candidates are built up front; the chosen one is picked by name.
    every = params.get('loader_every', 1000)
    loaders = {
        'cache': tools.bind(episodes.cache_loader, every=every),
        'recent': tools.bind(episodes.recent_loader, every=every),
        'reload': episodes.reload_loader,
        'dummy': episodes.dummy_loader,
    }
    config.data_loader = loaders[params.get('loader', 'recent')]

    strategy = params.get('bound_action', 'clip')
    config.bound_action = tools.bind(tools.bound_action, strategy=strategy)
    return config
def _data_processing(config, params):
    """Attach data-pipeline settings from ``params`` onto ``config``.

    NOTE(review): this file defines ``_data_processing`` more than once;
    Python binds only the last definition at import time.

    Args:
        config: Mutable attribute container that receives the settings.
        params: Mapping-like object supporting ``.get(key, default)``.

    Returns:
        The same ``config`` object, after mutation.

    Raises:
        KeyError: If ``params.get('loader', 'recent')`` is not one of
            ``'cache'``, ``'recent'``, ``'reload'`` or ``'dummy'``.
    """
    config.batch_shape = params.get('batch_shape', (50, 50))
    config.num_chunks = params.get('num_chunks', 1)
    image_bits = params.get('image_bits', 5)
    config.preprocess_fn = tools.bind(
        tools.preprocess.preprocess, bits=image_bits)
    config.postprocess_fn = tools.bind(
        tools.preprocess.postprocess, bits=image_bits)
    config.open_loop_context = 5
    config.data_reader = tools.numpy_episodes.episode_reader
    config.data_loader = {
        'cache': tools.bind(
            tools.numpy_episodes.cache_loader,
            every=params.get('loader_every', 1000)),
        'recent': tools.bind(
            tools.numpy_episodes.recent_loader,
            every=params.get('loader_every', 1000)),
        'reload': tools.numpy_episodes.reload_loader,
        'dummy': tools.numpy_episodes.dummy_loader,
    }[params.get('loader', 'recent')]
    # The original copy stopped mid-call here; the arguments and the return
    # statement match the other definitions of this function.
    config.bound_action = tools.bind(
        tools.bound_action,
        strategy=params.get('bound_action', 'clip'))
    return config
def _data_processing(config, params):
    """Attach data-loading and preprocessing settings to ``config``.

    Most settings come from ``params.get`` with the defaults visible below;
    ``open_loop_context`` and ``data_reader`` are fixed.

    Args:
        config: Attribute container that is mutated in place.
        params: Mapping-like object providing ``.get(key, default)``.

    Returns:
        ``config`` itself.

    Raises:
        KeyError: When ``params.get('loader', 'recent')`` names a loader
            other than 'cache', 'recent', 'reload' or 'dummy'.
    """
    # Batch layout.
    config.batch_shape = params.get('batch_shape', (50, 50))
    config.num_chunks = params.get('num_chunks', 1)

    # Pre- and post-processing are bound to the same ``image_bits`` value.
    image_bits = params.get('image_bits', 5)
    preprocess = tools.preprocess
    config.preprocess_fn = tools.bind(preprocess.preprocess, bits=image_bits)
    config.postprocess_fn = tools.bind(
        preprocess.postprocess, bits=image_bits)

    config.open_loop_context = 5

    numpy_episodes = tools.numpy_episodes
    config.data_reader = numpy_episodes.episode_reader
    loader_every = params.get('loader_every', 1000)
    by_name = dict(
        cache=tools.bind(numpy_episodes.cache_loader, every=loader_every),
        recent=tools.bind(numpy_episodes.recent_loader, every=loader_every),
        reload=numpy_episodes.reload_loader,
        dummy=numpy_episodes.dummy_loader,
    )
    config.data_loader = by_name[params.get('loader', 'recent')]

    config.bound_action = tools.bind(
        tools.bound_action, strategy=params.get('bound_action', 'clip'))
    return config
def _define_simulation(
task, config, params, horizon, batch_size, objective='reward',
rewards=False):
planner = params.get('planner', 'cem')
if planner == 'cem':
planner_fn = tools.bind(
control.planning.cross_entropy_method,
amount=params.get('planner_amount', 1000),
iterations=params.get('planner_iterations', 10),
topk=params.get('planner_topk', 100),
horizon=horizon)
else:
raise NotImplementedError(planner)
return tools.AttrDict(
task=task,
num_agents=batch_size,
planner=planner_fn,
objective=tools.bind(getattr(objectives_lib, objective), params=params))
rewards=False):
planner = params.get('planner', 'cem')
if planner == 'cem':
planner_fn = tools.bind(
control.planning.cross_entropy_method,
amount=params.get('planner_amount', 1000),
iterations=params.get('planner_iterations', 10),
topk=params.get('planner_topk', 100),
horizon=horizon)
else:
raise NotImplementedError(planner)
return tools.AttrDict(
task=task,
num_agents=batch_size,
planner=planner_fn,
objective=tools.bind(getattr(objectives_lib, objective), params=params))
activation=config.activation)
config.encoder = network.encoder
config.decoder = network.decoder
config.heads = tools.AttrDict(_unlocked=True)
config.heads.image = config.decoder
size = params.get('model_size', 200)
state_size = params.get('state_size', 30)
model = params.get('model', 'rssm')
if model == 'ssm':
config.cell = tools.bind(
models.SSM, state_size, size,
params.get('mean_only', False),
config.activation,
params.get('min_stddev', 1e-1))
elif model == 'rssm':
config.cell = tools.bind(
models.RSSM, state_size, size, size,
params.get('future_rnn', True),
params.get('mean_only', False),
params.get('min_stddev', 1e-1),
config.activation,
params.get('model_layers', 1))
elif params.model == 'drnn':
config.cell = tools.bind(
models.DRNN, state_size, size, size,
params.get('mean_only', False),
params.get('min_stddev', 1e-1), config.activation,
params.get('drnn_encoder_to_decoder', False),
params.get('drnn_sample_to_sample', True),
params.get('drnn_sample_to_encoder', True),
params.get('drnn_decoder_to_encoder', False),
params.get('drnn_decoder_to_sample', True),