Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
elif resize == 'both':
message('Fitting network exactly to training set ', nl=False)
logger.info('Resizing network or given codec to {} code sequences'.format(len(gt_set.alphabet)))
gt_set.encode(None)
ncodec, del_labels = codec.merge(gt_set.codec)
logger.info('Deleting {} output classes from network ({} retained)'.format(len(del_labels), len(codec)-len(del_labels)))
gt_set.encode(ncodec)
nn.resize_output(ncodec.max_label()+1, del_labels)
message('\u2713', fg='green')
else:
raise click.BadOptionUsage('resize', 'Invalid resize value {}'.format(resize))
else:
gt_set.encode(codec)
logger.info('Creating new model {} with {} outputs'.format(spec, gt_set.codec.max_label()+1))
spec = '[{} O1c{}]'.format(spec[1:-1], gt_set.codec.max_label()+1)
nn = vgsl.TorchVGSLModel(spec)
# initialize weights
message('Initializing model ', nl=False)
nn.init_weights()
nn.add_codec(gt_set.codec)
# initialize codec
message('\u2713', fg='green')
# half the number of data loading processes if device isn't cuda and we haven't enabled preloading
if device == 'cpu' and not preload:
loader_threads = threads // 2
else:
loader_threads = threads
train_loader = DataLoader(gt_set, batch_size=1, shuffle=True, num_workers=loader_threads, pin_memory=True)
threads = max(threads-loader_threads, 1)
# don't encode validation set as the alphabets may not match causing encoding failures
def test_helper_train(self):
    """
    Tests train/eval mode helper methods.

    The assertions below show that ``train()``/``eval()`` flip both the
    wrapped module's ``training`` flag and torch's global autograd switch.
    Because the autograd switch is process-wide, the grad mode that was
    active before the test is restored afterwards so later tests are not
    silently run with gradients disabled.
    """
    grad_was_enabled = torch.is_grad_enabled()
    try:
        rnn = vgsl.TorchVGSLModel('[1,1,0,48 Lbx10 Do O1c57]')
        rnn.train()
        self.assertTrue(torch.is_grad_enabled())
        self.assertTrue(rnn.nn.training)
        rnn.eval()
        self.assertFalse(torch.is_grad_enabled())
        self.assertFalse(rnn.nn.training)
    finally:
        # eval() left autograd disabled globally; put back the previous mode
        torch.set_grad_enabled(grad_was_enabled)
def test_save_model(self):
    """
    Test model serialization.

    Saves a freshly constructed model into a temporary directory and checks
    that a regular file was written at the requested path.
    """
    rnn = vgsl.TorchVGSLModel('[1,1,0,48 Lbx10 Do O1c57]')
    # `tmp_dir` rather than `dir` to avoid shadowing the builtin
    with tempfile.TemporaryDirectory() as tmp_dir:
        model_path = os.path.join(tmp_dir, 'foo.mlmodel')
        rnn.save_model(model_path)
        # isfile (not exists) so a stray directory at this path would not pass
        self.assertTrue(os.path.isfile(model_path))
def test_del_resize(self):
    """
    Check that ``resize_output`` sets the final layer width when labels are deleted.

    A model built from an ``O1c57`` spec is resized to 80 outputs while the
    listed label indices are removed; the last layer's linear projection must
    then report exactly 80 output features.
    """
    deleted_labels = [2, 4, 5, 6, 7, 12, 25]
    target_size = 80
    model = vgsl.TorchVGSLModel('[1,1,0,48 Lbx10 Do O1c57]')
    model.resize_output(target_size, deleted_labels)
    final_layer = model.nn[-1]
    self.assertEqual(final_layer.lin.out_features, target_size)
Returns:
A kraken.lib.models.TorchSeqRecognizer object.
"""
nn = None
kind = ''
fname = abspath(expandvars(expanduser(fname)))
logger.info(u'Loading model from {}'.format(fname))
try:
nn = TorchVGSLModel.load_model(str(fname))
kind = 'vgsl'
except Exception:
try:
nn = TorchVGSLModel.load_clstm_model(fname)
kind = 'clstm'
except Exception:
nn = TorchVGSLModel.load_pronn_model(fname)
kind = 'pronn'
try:
nn = TorchVGSLModel.load_pyrnn_model(fname)
kind = 'pyrnn'
except Exception:
pass
if not nn:
raise KrakenInvalidModelException('File {} not loadable by any parser.'.format(fname))
seq = TorchSeqRecognizer(nn, train=train, device=device)
seq.kind = kind
return seq