# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# --- Data pipeline, metrics, and checkpoint-directory setup (script level) ---
# NOTE(review): relies on `opt`, `batch_size`, `num_workers`, `get_data_rec`,
# `get_data_loader`, `mx`, and `makedirs` defined earlier in the file.
if opt.use_rec:
    # RecordIO-backed iterators built from the packed .rec/.idx files.
    train_data, val_data, batch_fn = get_data_rec(
        opt.rec_train, opt.rec_train_idx,
        opt.rec_val, opt.rec_val_idx,
        batch_size, num_workers)
else:
    # Plain image-folder data loaders.
    train_data, val_data, batch_fn = get_data_loader(
        opt.data_dir, batch_size, num_workers)

# Training accuracy plus top-1 / top-5 validation metrics.
train_metric = mx.metric.Accuracy()
acc_top1 = mx.metric.Accuracy()
acc_top5 = mx.metric.TopKAccuracy(5)

# Checkpointing is active only when both a directory and a frequency are given.
save_frequency = opt.save_frequency
if not (opt.save_dir and save_frequency):
    save_dir = ''
    save_frequency = 0
else:
    save_dir = opt.save_dir
    makedirs(save_dir)
def test(ctx, val_data):
    # Evaluate the module-level `net` over `val_data`, accumulating results
    # into the module-level `acc_top1` / `acc_top5` metrics.
    # NOTE(review): this chunk appears truncated — the body ends right after
    # reading top-1; a trailing `return` (if any) is not visible here.
    if opt.use_rec:
        # RecordIO iterators must be rewound explicitly before reuse.
        val_data.reset()
    acc_top1.reset()
    acc_top5.reset()
    for i, batch in enumerate(val_data):
        data, label = batch_fn(batch, ctx)
        # Cast inputs to the configured dtype (e.g. float16) without copying.
        outputs = [net(X.astype(opt.dtype, copy=False)) for X in data]
        acc_top1.update(label, outputs)
        acc_top5.update(label, outputs)
    _, top1 = acc_top1.get()
from autogluon import config_choice
from gluoncv.utils import makedirs
def parse_args():
    """Parse command-line options for the kaggle-competition training script.

    Returns:
        argparse.Namespace with ``data_dir`` and ``dataset`` attributes.
    """
    ap = argparse.ArgumentParser(
        description='Train a model for different kaggle competitions.')
    ap.add_argument(
        '--data-dir', type=str, default='/home/ubuntu/workspace/data/dataset/',
        help='training and validation pictures to use.')
    ap.add_argument(
        '--dataset', type=str, default='dogs-vs-cats-redux-kernels-edition',
        help='the kaggle competition')
    return ap.parse_args()
# --- Script entry: parse options, set up logging, launch AutoGluon fit ---
opt = parse_args()
# data
local_path = os.path.dirname(__file__)
# Logs are written under a directory named after the competition dataset.
makedirs(opt.dataset)
logging_file = os.path.join(opt.dataset ,'summary.log')
filehandler = logging.FileHandler(logging_file)
streamhandler = logging.StreamHandler()
# Configure the root logger to write to both the summary file and the console.
logger = logging.getLogger('')
logger.setLevel(logging.INFO)
logger.addHandler(filehandler)
logger.addHandler(streamhandler)
logging.info(opt.dataset)
# `config_choice` maps a competition name to a dict of training settings
# (dataset path, net, optimizer, epochs, ...).
target = config_choice(opt.dataset, opt.data_dir)
load_dataset = task.Dataset(target['dataset'])
# NOTE(review): the call below is truncated in this chunk — the remaining
# keyword arguments and the closing parenthesis are not visible here.
classifier = task.fit(dataset = task.Dataset(target['dataset']),
                      net = target['net'],
                      optimizer = target['optimizer'],
                      epochs = target['epochs'],
# --- Data iterators, validation metrics, weight init, checkpoint settings ---
# NOTE(review): depends on `opt`, `batch_size`, `num_workers`, `mx`,
# `get_data_rec`, `get_data_loader`, and `makedirs` from earlier in the file.
if opt.use_rec:
    train_data, val_data, batch_fn = get_data_rec(
        opt.rec_train, opt.rec_train_idx,
        opt.rec_val, opt.rec_val_idx,
        batch_size, num_workers)
else:
    train_data, val_data, batch_fn = get_data_loader(
        opt.data_dir, batch_size, num_workers)

# Top-1 / top-5 validation metrics.
acc_top1 = mx.metric.Accuracy()
acc_top5 = mx.metric.TopKAccuracy(5)

# Xavier ("in"-scaled) initialization with gaussian draws, magnitude 2.
initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2)

# Checkpoints are written only when both a directory and a frequency are set.
save_frequency = opt.save_frequency
save_dir = opt.save_dir if (opt.save_dir and save_frequency) else ''
if save_dir:
    makedirs(save_dir)
else:
    save_frequency = 0
def test(ctx, val_data):
    # Validation pass: run the module-level `net` over `val_data` and update
    # the module-level `acc_top1` / `acc_top5` metrics.
    # NOTE(review): truncated in this chunk — the body stops right after
    # reading top-1; any trailing `return` is not visible here.
    if opt.use_rec:
        # RecordIO iterators must be rewound explicitly before reuse.
        val_data.reset()
    acc_top1.reset()
    acc_top5.reset()
    for i, batch in enumerate(val_data):
        data, label = batch_fn(batch, ctx)
        # Cast inputs to the configured dtype (e.g. float16) without copying.
        outputs = [net(X.astype(opt.dtype, copy=False)) for X in data]
        acc_top1.update(label, outputs)
        acc_top5.update(label, outputs)
    _, top1 = acc_top1.get()
def download_coco(path, overwrite=False):
    """Fetch the MS-COCO 2017 archives into *path* and unpack them there.

    Each archive is downloaded with sha1 verification, then extracted in
    place. Pass ``overwrite=True`` to re-download files that already exist.
    """
    archives = [
        ('http://images.cocodataset.org/zips/train2017.zip',
         '10ad623668ab00c62c096f0ed636d6aff41faca5'),
        ('http://images.cocodataset.org/annotations/annotations_trainval2017.zip',
         '8551ee4bb5860311e79dace7e79cb91e432e78b3'),
        ('http://images.cocodataset.org/zips/val2017.zip',
         '4950dc9d00dbe1c933ee0170f5797584351d2a41'),
        # Optional stuff annotations:
        # ('http://images.cocodataset.org/annotations/stuff_annotations_trainval2017.zip',
        #  '46cdcf715b6b4f67e980b529534e79c2edffe084'),
        # test2017.zip, for those who want to attend the competition.
        # ('http://images.cocodataset.org/zips/test2017.zip',
        #  '4e443f8a2eca6b1dac8a6c57641b67dd40621a49'),
    ]
    makedirs(path)
    for archive_url, sha1 in archives:
        local_zip = download(archive_url, path=path, overwrite=overwrite,
                             sha1_hash=sha1)
        # Unpack next to the downloaded archive.
        with zipfile.ZipFile(local_zip) as archive:
            archive.extractall(path=path)
# NOTE(review): this chunk begins mid-expression — the opening
# `get_data_rec(` call and its `if opt.use_rec:` guard are not visible here.
        opt.rec_val, opt.rec_val_idx,
        batch_size, num_workers)
else:
    train_data, val_data, batch_fn = get_data_loader(opt.data_dir, batch_size, num_workers)
# Mixup produces soft (non one-hot) targets, so training error is tracked
# with RMSE instead of classification accuracy.
if opt.mixup:
    train_metric = mx.metric.RMSE()
else:
    train_metric = mx.metric.Accuracy()
acc_top1 = mx.metric.Accuracy()
acc_top5 = mx.metric.TopKAccuracy(5)
# Checkpointing only when both a directory and a frequency are configured.
save_frequency = opt.save_frequency
if opt.save_dir and save_frequency:
    save_dir = opt.save_dir
    makedirs(save_dir)
else:
    save_dir = ''
    save_frequency = 0
def mixup_transform(label, classes, lam=1, eta=0.0):
    """Build mixup soft labels.

    Each label batch (or a single NDArray) is one-hot encoded with optional
    label smoothing ``eta``, then blended with the reversed batch:
    ``lam * y + (1 - lam) * y_reversed``.

    Returns a list of soft-label batches.
    """
    if isinstance(label, nd.NDArray):
        label = [label]
    on_value = 1 - eta + eta / classes
    off_value = eta / classes
    return [
        lam * l.one_hot(classes, on_value=on_value, off_value=off_value)
        + (1 - lam) * l[::-1].one_hot(classes, on_value=on_value, off_value=off_value)
        for l in label
    ]
def smooth(label, classes, eta=0.1):
    # NOTE(review): truncated — only the signature and the opening isinstance
    # check of this label-smoothing helper are visible in this chunk.
    if isinstance(label, nd.NDArray):
# NOTE(review): the lines below belong to a different, partially-visible
# chunk: the tail of an LR-scheduler construction, then optimizer and
# checkpoint setup.
        iters_per_epoch=num_batches,
        step_epoch=lr_decay_epoch,
        step_factor=lr_decay, power=2)
])
# optimizer = 'sgd'
# optimizer_params = {'wd': opt.wd, 'momentum': 0.9, 'lr_scheduler': lr_scheduler}
# Adam with weight decay and the composed LR schedule.
optimizer = 'adam'
optimizer_params = {'wd': opt.wd, 'lr_scheduler': lr_scheduler}
# Non-fp32 training keeps an fp32 master copy of the weights.
if opt.dtype != 'float32':
    optimizer_params['multi_precision'] = True
# Checkpointing only when both a directory and a frequency are configured.
save_frequency = opt.save_frequency
if opt.save_dir and save_frequency:
    save_dir = opt.save_dir
    makedirs(save_dir)
else:
    save_dir = ''
    save_frequency = 0
def train(ctx):
    # Training entry point. NOTE(review): truncated in this chunk — only the
    # context normalization and parameter initialization are visible.
    if isinstance(ctx, mx.Context):
        # Accept a single context or a list of contexts uniformly.
        ctx = [ctx]
    if opt.use_pretrained_base:
        # With a pretrained backbone, initialize only the newly-added layers.
        if model_name.startswith('simple'):
            net.deconv_layers.initialize(ctx=ctx)
            net.final_layer.initialize(ctx=ctx)
        elif model_name.startswith('mobile'):
            net.upsampling.initialize(ctx=ctx)
    else:
        # NOTE(review): indentation was lost in this chunk; this `else` is
        # assumed to pair with `if opt.use_pretrained_base` — confirm.
        # No pretrained base: initialize the whole network from scratch.
        net.initialize(mx.init.MSRAPrelu(), ctx=ctx)
def build_rec_process(img_dir, train=False, num_thread=1):
    # Prepare an ImageRecord (.rec) build for the images under `img_dir`:
    # downloads the im2rec.py tool and a prebuilt .lst index into a sibling
    # 'rec' directory. NOTE(review): truncated — the `cmd = [` list (and the
    # subsequent im2rec invocation) is cut off at the end of this chunk.
    from gluoncv.utils import download, makedirs
    rec_dir = os.path.abspath(os.path.join(img_dir, '../rec'))
    makedirs(rec_dir)
    prefix = 'train' if train else 'val'
    print('Building ImageRecord file for ' + prefix + ' ...')
    # to_path = rec_dir
    # download lst file and im2rec script
    script_path = os.path.join(rec_dir, 'im2rec.py')
    script_url = 'https://raw.githubusercontent.com/apache/incubator-mxnet/master/tools/im2rec.py'
    download(script_url, script_path)
    lst_path = os.path.join(rec_dir, prefix + '.lst')
    # NOTE(review): fetched over plain http — no transport integrity.
    lst_url = 'http://data.mxnet.io/models/imagenet/resnet/' + prefix + '.lst'
    download(lst_url, lst_path)
    # execution
    import sys
    cmd = [
# NOTE(review): this chunk begins mid-expression — the opening
# `get_data_rec(` call and its `if opt.use_rec:` guard are not visible here.
        opt.rec_val, opt.rec_val_idx,
        batch_size, num_workers)
else:
    train_data, val_data, batch_fn = get_data_loader(opt.data_dir, batch_size, num_workers)
# Mixup produces soft (non one-hot) targets, so training error is tracked
# with RMSE instead of classification accuracy.
if opt.mixup:
    train_metric = mx.metric.RMSE()
else:
    train_metric = mx.metric.Accuracy()
acc_top1 = mx.metric.Accuracy()
acc_top5 = mx.metric.TopKAccuracy(5)
# Checkpointing only when both a directory and a frequency are configured.
save_frequency = opt.save_frequency
if opt.save_dir and save_frequency:
    save_dir = opt.save_dir
    makedirs(save_dir)
else:
    save_dir = ''
    save_frequency = 0
def mixup_transform(label, classes, lam=1, eta=0.0):
    """Return mixup-blended soft labels for a batch (or list of batches).

    One-hot encodes with label smoothing ``eta``, then mixes each batch with
    its reversed ordering using weight ``lam``.
    """
    batches = [label] if isinstance(label, nd.NDArray) else label

    def _encode(t):
        # Smoothed one-hot encoding of a label batch.
        return t.one_hot(classes, on_value=1 - eta + eta / classes,
                         off_value=eta / classes)

    mixed = []
    for batch in batches:
        mixed.append(lam * _encode(batch) + (1 - lam) * _encode(batch[::-1]))
    return mixed
def smooth(label, classes, eta=0.1):
    # NOTE(review): truncated — only the signature and the opening isinstance
    # check of this label-smoothing helper are visible in this chunk.
    if isinstance(label, nd.NDArray):
# NOTE(review): the lines below come from a different, partially-visible
# chunk (CIFAR training script: model build, checkpointing, logging).
model_name = opt.model
# cifar_wideresnet models additionally take a dropout rate.
if model_name.startswith('cifar_wideresnet'):
    kwargs = {'classes': classes,
              'drop_rate': opt.drop_rate}
else:
    kwargs = {'classes': classes}
net = get_model(model_name, **kwargs)
# The saved-model name records that mixup training was used.
model_name += '_mixup'
if opt.resume_from:
    net.load_parameters(opt.resume_from, ctx = context)
optimizer = 'nag'
# Periodic checkpointing only when both a directory and a period are set.
save_period = opt.save_period
if opt.save_dir and save_period:
    save_dir = opt.save_dir
    makedirs(save_dir)
else:
    save_dir = ''
    save_period = 0
plot_name = opt.save_plot_dir
# Log to the console, plus a per-model file when a logging dir is given.
logging_handlers = [logging.StreamHandler()]
if opt.logging_dir:
    logging_dir = opt.logging_dir
    makedirs(logging_dir)
    logging_handlers.append(logging.FileHandler('%s/train_cifar10_%s.log'%(logging_dir, model_name)))
logging.basicConfig(level=logging.INFO, handlers = logging_handlers)
logging.info(opt)
# NOTE(review): truncated — the transform list continues beyond this chunk.
transform_train = transforms.Compose([
def build_rec_process(img_dir, train=False, num_thread=1):
    # Prepare an ImageRecord (.rec) build for the images under `img_dir`:
    # downloads the im2rec.py tool and a prebuilt .lst index into a sibling
    # 'rec' directory. NOTE(review): truncated — the `cmd = [` list (and the
    # subsequent im2rec invocation) is cut off at the end of this chunk.
    rec_dir = os.path.abspath(os.path.join(img_dir, '../rec'))
    makedirs(rec_dir)
    prefix = 'train' if train else 'val'
    print('Building ImageRecord file for ' + prefix + ' ...')
    to_path = rec_dir
    # download lst file and im2rec script
    script_path = os.path.join(rec_dir, 'im2rec.py')
    script_url = 'https://raw.githubusercontent.com/apache/incubator-mxnet/master/tools/im2rec.py'
    download(script_url, script_path)
    lst_path = os.path.join(rec_dir, prefix + '.lst')
    # NOTE(review): fetched over plain http — no transport integrity.
    lst_url = 'http://data.mxnet.io/models/imagenet/resnet/' + prefix + '.lst'
    download(lst_url, lst_path)
    # execution
    import sys
    cmd = [