Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def get_loaders(transform):
    """Build the training and validation data loaders.

    ``train_data``, ``valid_data``, ``bs`` and ``num_workers`` are not
    parameters of this function; they are resolved from the enclosing scope.

    Args:
        transform: forwarded to ``utils.get_loader`` as ``dict_transform``
            for both loaders.

    Returns:
        collections.OrderedDict: loaders stored under ``"train"`` and
        ``"valid"``, in that order.
    """
    def _as_sample(pair):
        # Turn an indexable (features, targets) item into a keyed sample.
        return {"features": pair[0], "targets": pair[1]}

    def _make_loader(data, shuffle):
        # Both loaders share every setting except the data and shuffling.
        return utils.get_loader(
            data,
            open_fn=_as_sample,
            dict_transform=transform,
            batch_size=bs,
            num_workers=num_workers,
            shuffle=shuffle,
        )

    return collections.OrderedDict([
        ("train", _make_loader(train_data, True)),   # shuffled
        ("valid", _make_loader(valid_data, False)),  # kept in order
    ])
model_params_ = utils.process_model_params(
model[model_key_], layerwise_params, no_bias_weight_decay,
lr_scaling
)
model_params.extend(model_params_)
else:
raise ValueError("unknown type of model_params")
load_from_previous_stage = \
params.pop("load_from_previous_stage", False)
optimizer_key = params.pop("optimizer_key", None)
optimizer = OPTIMIZERS.get_from_params(**params, params=model_params)
if load_from_previous_stage and self.stages.index(stage) != 0:
checkpoint_path = f"{self.logdir}/checkpoints/best_full.pth"
checkpoint = utils.load_checkpoint(checkpoint_path)
dict2load = optimizer
if optimizer_key is not None:
dict2load = {optimizer_key: optimizer}
utils.unpack_checkpoint(checkpoint, optimizer=dict2load)
# move optimizer to device
device = utils.get_device()
for param in model_params:
param = param["params"][0]
state = optimizer.state[param]
for key, value in state.items():
state[key] = utils.any2device(value, device)
# update optimizer params
for key, value in params.items():
def get_optimizer(
    self,
    stage: str,
    model: nn.Module
) -> _Optimizer:
    """Create the optimizer configured for ``stage``.

    Args:
        stage (str): key into ``self.stages_config``.
        model (nn.Module): model whose ``parameters()`` are passed through
            ``utils.get_optimizable_params``.

    Returns:
        _Optimizer: whatever ``self._get_optimizer`` returns for the
        collected parameters and the stage's ``"optimizer_params"``
        (an empty dict when the stage does not define them).
    """
    trainable_params = utils.get_optimizable_params(model.parameters())
    stage_config = self.stages_config[stage]
    optimizer_kwargs = stage_config.get("optimizer_params", {})
    return self._get_optimizer(
        model_params=trainable_params, **optimizer_kwargs
    )
parser.add_argument(
"--batch-size",
type=int,
dest="batch_size",
help="Dataloader batch size",
default=128
)
parser.add_argument(
"--verbose",
dest="verbose",
action="store_true",
default=False,
help="Print additional information"
)
parser.add_argument("--seed", type=int, default=42)
utils.boolean_flag(
parser, "deterministic",
default=None,
help="Deterministic mode if running in CuDNN backend"
)
utils.boolean_flag(
parser, "benchmark",
default=None,
help="Use CuDNN benchmark"
)
return parser
def _get_logdir(self, config: Dict) -> str:
    """Compose a log directory name for ``config``.

    The name is ``"<timestamp>.<hash>"``, where the timestamp comes from
    ``utils.get_utcnow_time()`` and the hash from
    ``utils.get_short_hash(config)``. When ``self.distributed_params`` has a
    ``"rank"`` greater than ``-1``, ``".rankNN"`` (zero-padded to two digits)
    is appended.

    Args:
        config (Dict): configuration to hash.

    Returns:
        str: the composed directory name.
    """
    stamp = utils.get_utcnow_time()
    digest = utils.get_short_hash(config)
    base_name = f"{stamp}.{digest}"

    # A missing rank defaults to -1, which adds no suffix.
    rank = self.distributed_params.get("rank", -1)
    return f"{base_name}.rank{rank:02d}" if rank > -1 else base_name
utils.save_checkpoint(
logdir=Path(f"{logdir}/checkpoints/"),
checkpoint=checkpoint,
suffix=f"{suffix}_full",
is_best=is_best,
is_last=True,
special_suffix="_full"
)
exclude = ["criterion", "optimizer", "scheduler"]
checkpoint = {
key: value
for key, value in checkpoint.items()
if all(z not in key for z in exclude)
}
filepath = utils.save_checkpoint(
checkpoint=checkpoint,
logdir=Path(f"{logdir}/checkpoints/"),
suffix=suffix,
is_best=is_best,
is_last=True
)
valid_metrics = checkpoint["valid_metrics"]
checkpoint_metric = valid_metrics[main_metric]
metrics_record = (filepath, checkpoint_metric, valid_metrics)
self.top_best_metrics.append(metrics_record)
self.epochs_metrics.append(metrics_record)
self.truncate_checkpoints(minimize_metric=minimize_metric)
metrics = self.get_metric(valid_metrics)
self.save_metric(logdir, metrics)
def on_exception(self, state: RunnerState):
exception = state.exception
if not is_exception(exception):
return
try:
valid_metrics = state.metrics.valid_values
epoch_metrics = state.metrics.epoch_values
checkpoint = utils.pack_checkpoint(
model=state.model,
criterion=state.criterion,
optimizer=state.optimizer,
scheduler=state.scheduler,
epoch_metrics=epoch_metrics,
valid_metrics=valid_metrics,
stage=state.stage,
epoch=state.epoch_log,
checkpoint_data=state.checkpoint_data
)
suffix = self.get_checkpoint_suffix(checkpoint)
suffix = f"{suffix}.exception_{exception.__class__.__name__}"
utils.save_checkpoint(
logdir=Path(f"{state.logdir}/checkpoints/"),
checkpoint=checkpoint,
suffix=suffix,
def main(args, _=None):
global IMG_SIZE
utils.set_global_seed(args.seed)
utils.prepare_cudnn(args.deterministic, args.benchmark)
IMG_SIZE = (args.img_size, args.img_size)
if args.traced_model is not None:
device = utils.get_device()
model = torch.jit.load(str(args.traced_model), map_location=device)
else:
model = ResnetEncoder(arch=args.arch, pooling=args.pooling)
model = model.eval()
model, _, _, _, device = utils.process_components(model=model)
images_df = pd.read_csv(args.in_csv)
images_df = images_df.reset_index().drop("index", axis=1)
images_df = list(images_df.to_dict("index").values())
open_fn = ImageReader(
input_key=args.img_col, output_key="image", datapath=args.datapath
)
dataloader = utils.get_loader(
images_df,
def on_batch_end(self, runner: IRunner):
    """Write every image of the current batch to disk.

    Names are read from ``runner.input[self.outpath_key]`` and images from
    ``runner.input[self.input_key]``; each image is written to the path
    returned by ``self._get_image_path(runner.logdir, name)``.

    Args:
        runner (IRunner): current runner
    """
    out_names = runner.input[self.outpath_key]
    batch = runner.input[self.input_key]
    # Detach and move to CPU before converting to uint8 image arrays.
    cpu_batch = batch.detach().cpu()
    frames = utils.tensor_to_ndimage(cpu_batch, dtype=np.uint8)
    for frame, out_name in zip(frames, out_names):
        target_path = self._get_image_path(runner.logdir, out_name)
        imageio.imwrite(target_path, frame)
def _get_stages_config(self, stages_config):
    """Expand per-stage settings by merging them over the shared defaults.

    Top-level entries of ``stages_config`` whose key is in
    ``self.STAGE_KEYWORDS`` are treated as defaults. Every other entry with
    a non-``None`` value is treated as a stage: for each keyword the default
    section and the stage's own section are deep-copied and combined with
    ``utils.merge_dicts`` (defaults first, stage section second).

    Args:
        stages_config: mapping holding both default sections and stages.

    Returns:
        OrderedDict: stage name -> {keyword: merged section}, in the
        iteration order of ``stages_config``.
    """
    defaults = {
        keyword: deepcopy(stages_config.get(keyword, {}))
        for keyword in self.STAGE_KEYWORDS
    }

    expanded = OrderedDict()
    for stage_name in stages_config:
        is_stage = (
            stage_name not in self.STAGE_KEYWORDS
            and stages_config.get(stage_name) is not None
        )
        if is_stage:
            stage_section = stages_config[stage_name]
            expanded[stage_name] = {
                keyword: utils.merge_dicts(
                    deepcopy(defaults.get(keyword, {})),
                    deepcopy(stage_section.get(keyword, {})),
                )
                for keyword in self.STAGE_KEYWORDS
            }
    return expanded