# (scraper artifact, not part of the original module)
# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
elif resize == 'both':
message('Fitting network exactly to training set ', nl=False)
logger.info('Resizing network or given codec to {} code sequences'.format(len(gt_set.alphabet)))
gt_set.encode(None)
ncodec, del_labels = codec.merge(gt_set.codec)
logger.info('Deleting {} output classes from network ({} retained)'.format(len(del_labels), len(codec)-len(del_labels)))
gt_set.encode(ncodec)
nn.resize_output(ncodec.max_label()+1, del_labels)
message('\u2713', fg='green')
else:
raise click.BadOptionUsage('resize', 'Invalid resize value {}'.format(resize))
else:
gt_set.encode(codec)
logger.info('Creating new model {} with {} outputs'.format(spec, gt_set.codec.max_label()+1))
spec = '[{} O1c{}]'.format(spec[1:-1], gt_set.codec.max_label()+1)
nn = vgsl.TorchVGSLModel(spec)
# initialize weights
message('Initializing model ', nl=False)
nn.init_weights()
nn.add_codec(gt_set.codec)
# initialize codec
message('\u2713', fg='green')
# half the number of data loading processes if device isn't cuda and we haven't enabled preloading
if device == 'cpu' and not preload:
loader_threads = threads // 2
else:
loader_threads = threads
train_loader = DataLoader(gt_set, batch_size=1, shuffle=True, num_workers=loader_threads, pin_memory=True)
threads = max(threads-loader_threads, 1)
# don't encode validation set as the alphabets may not match causing encoding failures
@raises(KrakenInputException)
def test_not_binarize_empty(self):
    """
    Ensure that already-bitonal ('1' mode) images are rejected instead of
    being binarized a second time.
    """
    bw = Image.new('1', (1000, 1000))
    try:
        # nlbin() is expected to raise KrakenInputException here
        nlbin(bw)
    finally:
        bw.close()
@raises(KrakenInputException)
def test_rpred_outbounds(self):
    """
    Tests correct handling of invalid line coordinates.
    """
    net = load_any(os.path.join(resources, 'toy.clstm'))
    # a box reaching outside the page image must trigger KrakenInputException
    bounds = {'boxes': [[-1, -1, 10000, 10000]], 'text_direction': 'horizontal'}
    predictions = rpred(net, self.im, bounds, True)
    next(predictions)
def test_load_any_proto(self):
    """
    Test load_any loads protobuf models.
    """
    path = os.path.join(resources, 'model.pronn')
    net = models.load_any(path)
    self.assertIsInstance(net, kraken.lib.models.TorchSeqRecognizer)
def test_load_invalid(self):
    """
    Tests correct handling of invalid files.

    Loading a file that is not a model must raise rather than return a
    bogus object.
    """
    # NOTE(review): the original body called load_any() bare, so the very
    # exception this test exists to provoke registered as a test *error*,
    # and a silent success would (wrongly) pass.  Sibling tests use a
    # @raises decorator for this pattern; it appears to have been lost
    # here.  Narrow Exception to the project's KrakenInvalidModelException
    # once its import at the top of the file is confirmed.
    with self.assertRaises(Exception):
        models.load_any(self.temp.name)
def test_load_any_pyrnn_py3(self):
    """
    Test load_any doesn't load pickled models on python 3
    """
    # NOTE(review): the original bound the result to an unused variable and
    # asserted nothing, so a successful (unwanted) load passed silently
    # while the expected rejection errored out.  Per the docstring the load
    # must fail on python 3, so wrap it in assertRaises.  Narrow Exception
    # to KrakenInvalidModelException once the import is confirmed.
    with self.assertRaises(Exception):
        models.load_any(os.path.join(resources, 'model.pyrnn.gz'))
def test_load_clstm(self):
    """
    Tests loading of valid clstm files.
    """
    # bytes paths must be accepted as well as str paths
    path = os.path.join(resources, 'toy.clstm').encode('utf-8')
    net = models.load_any(path)
    self.assertIsInstance(net, models.TorchSeqRecognizer)
def test_helper_train(self):
    """
    Tests train/eval mode helper methods
    """
    net = vgsl.TorchVGSLModel('[1,1,0,48 Lbx10 Do O1c57]')
    # train() must enable autograd and put the wrapped module in training mode
    net.train()
    self.assertTrue(torch.is_grad_enabled())
    self.assertTrue(net.nn.training)
    # eval() must undo both
    net.eval()
    self.assertFalse(torch.is_grad_enabled())
    self.assertFalse(net.nn.training)
def test_save_model(self):
    """
    Test model serialization.
    """
    rnn = vgsl.TorchVGSLModel('[1,1,0,48 Lbx10 Do O1c57]')
    # The original named the temporary directory 'dir' (shadowing the
    # builtin) and built the path by string concatenation; use an explicit
    # name and os.path.join instead.
    with tempfile.TemporaryDirectory() as tmpdir:
        target = os.path.join(tmpdir, 'foo.mlmodel')
        rnn.save_model(target)
        self.assertTrue(os.path.exists(target))
def test_del_resize(self):
    """
    Tests resizing of output layers with entry deletion.
    """
    net = vgsl.TorchVGSLModel('[1,1,0,48 Lbx10 Do O1c57]')
    # grow the 57-class output to 80 while dropping seven label indices
    dropped = [2, 4, 5, 6, 7, 12, 25]
    net.resize_output(80, dropped)
    self.assertEqual(net.nn[-1].lin.out_features, 80)