def main():
    # Args
    args = get_args()
    # Context
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    logger.info(ctx)
    nn.set_default_context(ctx)
    nn.set_auto_forward(True)
    # Monitor
    monitor = Monitor(args.monitor_path)
    # Validation
    logger.info("Start validation")
    num_images = args.valid_samples
    num_batches = num_images // args.batch_size
    # DataIterator
    di = data_iterator(args.img_path, args.batch_size,
                       imsize=(args.imsize, args.imsize),
                       num_samples=args.valid_samples,
                       dataset_name=args.dataset_name)
    distributed = args.distributed
    compute_acc = args.compute_acc
    if distributed:
        # Communicator and Context
        from nnabla.ext_utils import get_extension_context
        extension_module = "cudnn"
        ctx = get_extension_context(
            extension_module, type_config=args.type_config)
        comm = C.MultiProcessDataParalellCommunicator(ctx)
        comm.init()
        n_devices = comm.size
        mpi_rank = comm.rank
        device_id = mpi_rank
        ctx.device_id = str(device_id)
        nn.set_default_context(ctx)
    else:
        # Get context.
        from nnabla.ext_utils import get_extension_context
        extension_module = args.context
        if args.context is None:
            extension_module = 'cpu'
        logger.info("Running in %s" % extension_module)
        ctx = get_extension_context(
            extension_module, device_id=args.device_id, type_config=args.type_config)
        nn.set_default_context(ctx)
        n_devices = 1
        device_id = 0
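    # A minimal sketch of how the validation loop might continue (assumed, not in
    # the snippet above): `image` and `pred` are hypothetical input/output Variables
    # of the evaluated network; `di` yields (image, label) minibatches.
    correct = 0
    for _ in range(num_batches):
        x, t = di.next()
        image.d = x
        pred.forward(clear_buffer=True)
        correct += (pred.d.argmax(axis=1) == t.flatten()).sum()
    logger.info("Top-1 accuracy: %f" %
                (correct / float(num_batches * args.batch_size)))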
# training data
data = data_iterator_segmentation(
    args.train_samples, args.batch_size, args.train_dir, args.train_label_dir,
    target_width=args.image_width, target_height=args.image_height)
max_iteration = args.max_iteration
lr_decay_interval = args.lr_decay_interval
lr_decay = args.lr_decay
iter_per_epoch = args.iter_per_epoch
iter_per_valid = args.iter_per_valid
n_episode_for_valid = args.n_episode_for_valid
n_episode_for_test = args.n_episode_for_test
work_dir = args.work_dir
# Set context
from nnabla.ext_utils import get_extension_context
logger.info("Running in %s" % args.context)
ctx = get_extension_context(
    args.context, device_id=args.device_id, type_config=args.type_config)
nn.set_default_context(ctx)
# Monitor outputs
from nnabla.monitor import Monitor, MonitorSeries
monitor = Monitor(args.work_dir)
monitor_loss = MonitorSeries(
    "Training loss", monitor, interval=iter_per_epoch)
monitor_valid_err = MonitorSeries(
    "Validation error", monitor, interval=iter_per_valid)
monitor_test_err = MonitorSeries("Test error", monitor)
monitor_test_conf = MonitorSeries("Test error confidence", monitor)
# Output files
param_file = work_dir + "params.h5"
tsne_file = work_dir + "tsne.png"
# Load data
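# A minimal sketch of how the monitors and output files above are typically used
# (assumed, not in the snippet): `i`, `loss` and `valid_error` are hypothetical
# values produced by a training loop that is not shown here.
monitor_loss.add(i, loss.d.copy())
monitor_valid_err.add(i, valid_error)
nn.save_parameters(param_file)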
def train():
    """
    * Execute forwardprop on the training graph.
    * Compute training error
    * Set parameter gradients zero
    * Execute backprop.
    * Solver updates parameters by using gradients computed by backprop.
    """
    args = get_args()
    # Get context.
    from nnabla.contrib.context import extension_context
    extension_module = args.context
    if args.context is None:
        extension_module = 'cpu'
    logger.info("Running in %s" % extension_module)
    ctx = extension_context(extension_module, device_id=args.device_id)
    nn.set_default_context(ctx)
    # Create CNN network for both training and testing.
    mnist_cnn_prediction = mnist_lenet_prediction
    if args.net == 'resnet':
        mnist_cnn_prediction = mnist_resnet_prediction
    # TRAIN
    # Create input variables.
    image = nn.Variable([args.batch_size, 1, 28, 28])
    label = nn.Variable([args.batch_size, 1])
    # Create prediction graph.
    pred = mnist_cnn_prediction(image, test=False)
    pred.persistent = True
    # Create loss function.
    loss = F.mean(F.softmax_cross_entropy(pred, label))
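    # A minimal sketch of the loop described in the docstring above (the solver
    # choice, the data iterator call and `args.max_iter` are assumptions, not the
    # original script):
    import nnabla.solvers as S
    solver = S.Adam(alpha=1e-3)
    solver.set_parameters(nn.get_parameters())
    data = data_iterator_mnist(args.batch_size, True)
    for i in range(args.max_iter):
        image.d, label.d = data.next()
        loss.forward(clear_no_need_grad=True)
        solver.zero_grad()
        loss.backward(clear_buffer=True)
        solver.update()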
def generate(args):
    # Communicator and Context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    nn.set_default_context(ctx)
    # Args
    latent = args.latent
    maps = args.maps
    batch_size = args.batch_size
    image_size = args.image_size
    n_classes = args.n_classes
    not_sn = args.not_sn
    threshold = args.truncation_threshold
    # Model
    nn.load_parameters(args.model_load_path)
    z = nn.Variable([batch_size, latent])
    y_fake = nn.Variable([batch_size])
    x_fake = generator(z, y_fake, maps=maps, n_classes=n_classes, test=True, sn=not_sn)\
        .apply(persistent=True)
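    # A minimal sketch of running the generator (the sampling strategy is an
    # assumption; truncated sampling with `threshold` is omitted here for brevity):
    z.d = np.random.randn(batch_size, latent)
    y_fake.d = np.random.randint(0, n_classes, size=batch_size)
    x_fake.forward(clear_buffer=True)
    images = x_fake.d  # generated batch as a NumPy array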
def train():
    args = get_args()
    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)
    # Create CNN network for both training and testing.
    if args.net == "cifar10_resnet23_prediction":
        model_prediction = cifar10_resnet23_prediction
    # TRAIN
    maps = 64
    data_iterator = data_iterator_cifar10
    c = 3
    h = w = 32
    n_train = 50000
    n_valid = 10000
    # Create input variables.
    image = nn.Variable([args.batch_size, c, h, w])
    label = nn.Variable([args.batch_size, 1])
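    # A minimal sketch of how the graph setup might continue (the call signatures
    # of model_prediction and data_iterator_cifar10 are assumptions):
    pred = model_prediction(image, maps=maps, test=False)
    loss = F.mean(F.softmax_cross_entropy(pred, label))
    tdata = data_iterator(args.batch_size, True)
    image.d, label.d = tdata.next()
    loss.forward()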
def main():
    args = get_args()
    rng = np.random.RandomState(1223)
    # Get context
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)
    miou = validate(args)
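    # A minimal sketch of reporting the result (the monitor path and series name
    # are assumptions, not the original script):
    from nnabla.monitor import Monitor, MonitorSeries
    monitor = Monitor(args.monitor_path)
    MonitorSeries("Mean IoU", monitor, interval=1).add(0, miou)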
def train():
    """
    * Compute error rate for validation data (periodically)
    * Get a next minibatch.
    * Set parameter gradients zero
    * Execute forwardprop on the training graph.
    * Execute backprop.
    * Solver updates parameters by using gradients computed by backprop.
    * Compute training error
    """
    args = get_args(monitor_path='tmp.monitor.bnn')
    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)
    # Initialize DataIterator for MNIST.
    data = data_iterator_mnist(args.batch_size, True)
    vdata = data_iterator_mnist(args.batch_size, False)
    # Create CNN network for both training and testing.
    mnist_cnn_prediction = mnist_binary_connect_lenet_prediction
    if args.net == 'bincon':
        mnist_cnn_prediction = mnist_binary_connect_lenet_prediction
    elif args.net == 'binnet':
        mnist_cnn_prediction = mnist_binary_net_lenet_prediction
    elif args.net == 'bwn':
        mnist_cnn_prediction = mnist_binary_weight_lenet_prediction
    elif args.net == 'bincon_resnet':
        mnist_cnn_prediction = mnist_binary_connect_resnet_prediction
    elif args.net == 'binnet_resnet':
        mnist_cnn_prediction = mnist_binary_net_resnet_prediction
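    # A minimal sketch of wiring up the selected network (mirrors the LeNet snippet
    # earlier; the rest of this script is not shown):
    image = nn.Variable([args.batch_size, 1, 28, 28])
    label = nn.Variable([args.batch_size, 1])
    pred = mnist_cnn_prediction(image, test=False)
    loss = F.mean(F.softmax_cross_entropy(pred, label))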
def morph(args):
    # Communicator and Context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    nn.set_default_context(ctx)
    # Args
    latent = args.latent
    maps = args.maps
    batch_size = args.batch_size
    image_size = args.image_size
    n_classes = args.n_classes
    not_sn = args.not_sn
    threshold = args.truncation_threshold
    # Model
    nn.load_parameters(args.model_load_path)
    z = nn.Variable([batch_size, latent])
    alpha = nn.Variable.from_numpy_array(np.zeros([1, 1]))
    beta = (nn.Variable.from_numpy_array(np.ones([1, 1])) - alpha)
    y_fake_a = nn.Variable([batch_size])
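    # A minimal sketch of the morphing idea (the generator call that mixes the two
    # class labels is truncated above, so no output Variable is built here): sweep
    # alpha from 0 to 1 so the class conditioning blends from y_fake_a (weight beta)
    # toward a second label y_fake_b (weight alpha).
    y_fake_b = nn.Variable([batch_size])
    for a in np.linspace(0.0, 1.0, 8):
        alpha.d.fill(a)
        # forward the (truncated) generator graph here, e.g. x_fake.forward(clear_buffer=True)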