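# Code fragments, apparently from kraken's CLI tools (ketos.py / kraken.py).
# Free names such as `output`, `normalization` or `nn` are parameters of the
# elided enclosing click commands.

# Fragment: extract line images and ground truth text from HTML
# transcription pages; each <section> embeds its line image as a base64
# data URI.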
try:
    os.mkdir(output)
except Exception:
    # directory already exists (or is otherwise unusable)
    pass
text_transforms = []
if normalization:
    text_transforms.append(lambda x: unicodedata.normalize(normalization, x))
if normalize_whitespace:
    text_transforms.append(lambda x: regex.sub(r'\s', ' ', x))
if reorder:
    text_transforms.append(get_display)
idx = 0
manifest = []
with log.progressbar(transcriptions, label='Reading transcriptions') as bar:
    for fp in bar:
        logger.info('Reading {}'.format(fp.name))
        doc = html.parse(fp)
        etree.strip_tags(doc, etree.Comment)
        td = doc.find(".//meta[@itemprop='text_direction']")
        if td is None:
            td = 'horizontal-lr'
        else:
            td = td.attrib['content']
        im = None
        dest_dict = {'output': output, 'idx': 0, 'src': fp.name, 'uuid': str(uuid.uuid4())}
        for section in doc.xpath('//section'):
            img = section.xpath('.//img')[0].get('src')
            # strip the 'data:...;base64,' prefix before decoding
            fd = BytesIO(base64.b64decode(img.split(',')[1]))
            im = Image.open(fd)
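# Fragment: helper that reads the ground truth transcription stored next
# to a line image under the `<image>.gt.txt` convention.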
def _get_text(im):
    with open(os.path.splitext(im)[0] + '.gt.txt', 'r') as fp:
        return get_display(fp.read())
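# Fragment: evaluate one or more recognition models (`nn` maps file path
# to loaded network) on `test_set`, reporting per-model character accuracy
# and an aggregate mean/stddev.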
acc_list = []
for p, net in nn.items():
    algn_gt: List[str] = []
    algn_pred: List[str] = []
    chars = 0
    error = 0
    message('Evaluating {}'.format(p))
    logger.info('Evaluating {}'.format(p))
    batch, channels, height, width = net.nn.input
    ts = generate_input_transforms(batch, height, width, channels, pad)
    with log.progressbar(test_set, label='Evaluating') as bar:
        for im_path in bar:
            i = ts(Image.open(im_path))
            text = _get_text(im_path)
            pred = net.predict_string(i)
            chars += len(text)
            # edit distance and alignment between ground truth and prediction
            c, algn1, algn2 = global_align(text, pred)
            algn_gt.extend(algn1)
            algn_pred.extend(algn2)
            error += c
    acc_list.append((chars - error) / chars)
    confusions, scripts, ins, dels, subs = compute_confusions(algn_gt, algn_pred)
    rep = render_report(p, chars, error, confusions, scripts, ins, dels, subs)
    logger.info(rep)
    message(rep)
logger.info('Average accuracy: {:0.2f}%, (stddev: {:0.2f})'.format(np.mean(acc_list) * 100, np.std(acc_list) * 100))
message('Average accuracy: {:0.2f}%, (stddev: {:0.2f})'.format(np.mean(acc_list) * 100, np.std(acc_list) * 100))
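# Fragment: build and print a summary of the alphabet occurring in the
# input lines, separating combining characters from plain symbols.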
for line in lines:
    alphabet.update(line)
chars = []
combining = []
for char in sorted(alphabet):
    k = make_printable(char)
    if k != char:
        combining.append(k)
    else:
        chars.append(k)
message('Σ (len: {})'.format(len(alphabet)))
message('Symbols: {}'.format(''.join(chars)))
if combining:
    message('Combining Characters: {}'.format(', '.join(combining)))
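# Fragment: render text lines to images, optionally applying degradation
# and distortion for training data augmentation.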
lg = linegen.LineGenerator(font, font_size, font_weight, language)
with log.progressbar(lines, label='Writing images') as bar:
    for idx, line in enumerate(bar):
        logger.info(line)
        try:
            if renormalize:
                im = lg.render_line(unicodedata.normalize(renormalize, line))
            else:
                im = lg.render_line(line)
        except KrakenCairoSurfaceException as e:
            logger.info('{}: {} {}'.format(e.message, e.width, e.height))
            continue
        if not disable_degradation and not legacy:
            # synthetic degradation and random distortion of the rendered line
            im = linegen.degrade_line(im, alpha=alpha, beta=beta)
            im = linegen.distort_line(im, abs(np.random.normal(distort)), abs(np.random.normal(distortion_sigma)))
        elif legacy:
            im = linegen.ocropy_degrade(im)
        im.save('{}/{:06d}.png'.format(output, idx))
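# Fragment: run recognition over segmented line bounds; multi-model
# dispatch is used when script detection data is available.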
# reconstructed context: the enclosing `if` and the loop over all lines are
# not part of the fragment; kraken branches on whether the segmentation
# result carries per-line script detection data
if bounds.get('script_detection'):
    for l in bounds['boxes']:
        for t in l:
            scripts.add(t[0])
    it = rpred.mm_rpred(model, im, bounds, pad,
                        bidi_reordering=bidi_reordering,
                        script_ignore=script_ignore)
else:
    it = rpred.rpred(model['default'], im, bounds, pad,
                     bidi_reordering=bidi_reordering)
if not lines and no_segmentation:
    logger.debug('Removing temporary segmentation file.')
    os.unlink(lines.name)
preds = []
with log.progressbar(it, label='Processing', length=len(bounds['boxes'])) as bar:
    for pred in bar:
        preds.append(pred)
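# Fragment: serialize the collected predictions, either through kraken's
# serializer for structured formats or as plain text.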
ctx = click.get_current_context()
with open_file(output, 'w', encoding='utf-8') as fp:
    fp = cast(IO[Any], fp)
    message('Writing recognition results for {}\t'.format(base_image), nl=False)
    logger.info('Serializing as {} into {}'.format(ctx.meta['mode'], output))
    if ctx.meta['mode'] != 'text':
        from kraken import serialization
        fp.write(serialization.serialize(preds, base_image,
                                         Image.open(base_image).size,
                                         ctx.meta['text_direction'],
                                         scripts,
                                         ctx.meta['mode']))
    else:
        # the fragment breaks off at this branch; writing one prediction
        # per line is the plausible plain-text body
        fp.write('\n'.join(s.prediction for s in preds))
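# Fragment: drive model training through KrakenTrainer; the two callbacks
# update and reset a click progress bar per training stage.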
#for param in hyper_fields:
#    logger.debug('Setting \'{}\' to \'{}\' in model metadata'.format(param, locals()[param]))
#    nn.user_metadata[param] = locals()[param]
trainer = train.KrakenTrainer(model=nn,
                              optimizer=optim,
                              device=device,
                              filename_prefix=output,
                              event_frequency=freq,
                              train_set=train_loader,
                              val_set=val_set,
                              stopper=st_it)
trainer.add_lr_scheduler(tr_it)
with log.progressbar(label='stage {}/{}'.format(1, trainer.stopper.epochs if trainer.stopper.epochs > 0 else '∞'),
                     length=trainer.event_it, show_pos=True) as bar:

    def _draw_progressbar():
        bar.update(1)

    def _print_eval(epoch, accuracy, chars, error):
        message('Accuracy report ({}) {:0.4f} {} {}'.format(epoch, accuracy, chars, error))
        # reset progress bar
        bar.label = 'stage {}/{}'.format(epoch + 1, trainer.stopper.epochs if trainer.stopper.epochs > 0 else '∞')
        bar.pos = 0
        bar.finished = False

    trainer.run(_print_eval, _draw_progressbar)
if quit == 'early':
    message('Moving best model {0}_{1}.mlmodel ({2}) to {0}_best.mlmodel'.format(output, trainer.stopper.best_epoch, trainer.stopper.best_loss))
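# Fragment: choose the evaluation lines, pick a safe tensor sharing
# strategy, and build the training dataset.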
# reconstructed context: the guarding `if` lies outside the fragment; an
# explicitly supplied evaluation set presumably takes precedence over
# splitting off a share of the ground truth
if evaluation_files:
    te_im = evaluation_files
else:
    te_im = ground_truth[int(len(ground_truth) * partition):]
    logger.debug('Taking {} lines from training for evaluation'.format(len(te_im)))
# set multiprocessing tensor sharing strategy
if 'file_system' in torch.multiprocessing.get_all_sharing_strategies():
    logger.debug('Setting multiprocessing tensor sharing strategy to file_system')
    torch.multiprocessing.set_sharing_strategy('file_system')
gt_set = GroundTruthDataset(normalization=normalization,
                            whitespace_normalization=normalize_whitespace,
                            reorder=reorder,
                            im_transforms=transforms,
                            preload=preload)
with log.progressbar(tr_im, label='Building training set') as bar:
    for im in bar:
        logger.debug('Adding line {} to training set'.format(im))
        try:
            gt_set.add(im)
        except FileNotFoundError as e:
            logger.warning('{}: {}. Skipping.'.format(e.strerror, e.filename))
        except KrakenInputException as e:
            logger.warning(str(e))
val_set = GroundTruthDataset(normalization=normalization,
                             whitespace_normalization=normalize_whitespace,
                             reorder=reorder,
                             im_transforms=transforms,
                             preload=preload)
with log.progressbar(te_im, label='Building validation set') as bar:
    for im in bar:
        logger.debug('Adding line {} to validation set'.format(im))
        try:
            val_set.add(im)
        except FileNotFoundError as e:
            logger.warning('{}: {}. Skipping.'.format(e.strerror, e.filename))
        except KrakenInputException as e:
            logger.warning(str(e))
logger.info('Training set {} lines, validation set {} lines, alphabet {} symbols'.format(len(gt_set._images), len(val_set._images), len(gt_set.alphabet)))
alpha_diff_only_train = set(gt_set.alphabet).difference(set(val_set.alphabet))
alpha_diff_only_val = set(val_set.alphabet).difference(set(gt_set.alphabet))
if alpha_diff_only_train:
logger.warning('alphabet mismatch: chars in training set only: {} (not included in accuracy test during training)'.format(alpha_diff_only_train))
if alpha_diff_only_val:
logger.warning('alphabet mismatch: chars in validation set only: {} (not trained)'.format(alpha_diff_only_val))
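# Fragment: tail of the synthetic line generator command; the `def` line
# and leading parameters are cut off in the fragment.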
                   reorder, font_size, font_weight, language, max_length, strip,
                   disable_degradation, alpha, beta, distort, distortion_sigma,
                   legacy, output, text):
    """
    Generates artificial text line training data.
    """
    import errno
    import numpy as np

    from kraken import linegen
    from kraken.lib.util import make_printable

    lines: Set[str] = set()
    if not text:
        return
    with log.progressbar(text, label='Reading texts') as bar:
        for t in bar:  # iterate the wrapped iterable so the bar advances
            with click.open_file(t, encoding=encoding) as fp:
                logger.info('Reading {}'.format(t))
                for l in fp:
                    lines.add(l.rstrip('\r\n'))
    if normalization:
        lines = {unicodedata.normalize(normalization, line) for line in lines}
    if strip:
        lines = {line.strip() for line in lines}
    if max_length:
        lines = {line for line in lines if len(line) < max_length}
    logger.info('Read {} lines'.format(len(lines)))
    message('Read {} unique lines'.format(len(lines)))
    if maxlines and maxlines < len(lines):
        message('Sampling {} lines\t'.format(maxlines), nl=False)
        llist = list(lines)