def __init__(self,
             rnn_config: rnn.RNNConfig,
             prefix=C.BIDIRECTIONALRNN_PREFIX,
             layout=C.TIME_MAJOR,
             encoder_class: Callable = RecurrentEncoder) -> None:
    utils.check_condition(rnn_config.num_hidden % 2 == 0,
                          "num_hidden must be a multiple of 2 for BiDirectionalRNNEncoders.")
    super().__init__(rnn_config.dtype)
    self.rnn_config = rnn_config
    self.internal_rnn_config = rnn_config.copy(num_hidden=rnn_config.num_hidden // 2)
    if layout[0] == 'N':
        logger.warning("Batch-major layout for encoder input. Consider using time-major layout for faster speed")

    # time-major layout as _encode needs to swap layout for SequenceReverse
    self.forward_rnn = encoder_class(rnn_config=self.internal_rnn_config,
                                     prefix=prefix + C.FORWARD_PREFIX,
                                     layout=C.TIME_MAJOR)
    self.reverse_rnn = encoder_class(rnn_config=self.internal_rnn_config,
                                     prefix=prefix + C.REVERSE_PREFIX,
                                     layout=C.TIME_MAJOR)
    self.layout = layout
    self.prefix = prefix
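# Illustration of the size check above (hypothetical numbers, not taken from the
# snippet): with num_hidden == 512, each direction is built with 512 // 2 == 256
# units, so the concatenated forward/backward states are 512-dimensional again.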
def _test_parameter_averaging(model_path: str):
    """
    Runs parameter averaging with all available strategies.
    """
    for strategy in C.AVERAGE_CHOICES:
        # pick 4 checkpoints according to the strategy and metric, then average their parameters
        points = sockeye.average.find_checkpoints(model_path=model_path,
                                                  size=4,
                                                  strategy=strategy,
                                                  metric=C.PERPLEXITY)
        assert len(points) > 0
        averaged_params = sockeye.average.average(points)
        assert averaged_params
def test_check_condition_true():
    utils.check_condition(1 == 1, "Nice")
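# A complementary sketch for the failing case (assumptions: pytest and utils are
# imported as in the surrounding tests, and check_condition raises
# utils.SockeyeError carrying the given message when the condition is False):
def test_check_condition_false():
    with pytest.raises(utils.SockeyeError) as e:
        utils.check_condition(1 == 2, "Not nice")
    assert str(e.value) == "Not nice"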
def _get_later_major_version():
    release, major, minor = utils.parse_version(__version__)
    return "%s.%d.%s" % (release, int(major) + 1, minor)
def test_expand_requested_device_ids_exception(requested_device_ids, num_gpus_available):
    with pytest.raises(ValueError):
        utils._expand_requested_device_ids(requested_device_ids, num_gpus_available)
        use_source_factor, perplexity_thresh, bleu_thresh):
    """Task: sort short sequences of digits"""
    with tmp_digits_dataset("test_seq_sort.", _TRAIN_LINE_COUNT, _TRAIN_LINE_COUNT_EMPTY, _LINE_MAX_LENGTH, _DEV_LINE_COUNT, _LINE_MAX_LENGTH,
                            _TEST_LINE_COUNT, _TEST_LINE_COUNT_EMPTY, _TEST_MAX_LENGTH,
                            sort_target=True, seed_train=_SEED_TRAIN_DATA, seed_dev=_SEED_DEV_DATA,
                            with_source_factors=use_source_factor) as data:
        data = check_train_translate(train_params=train_params,
                                     translate_params=translate_params,
                                     data=data,
                                     use_prepared_data=use_prepared_data,
                                     max_seq_len=_LINE_MAX_LENGTH + C.SPACE_FOR_XOS,
                                     compare_output=True,
                                     seed=seed)

        # get best validation perplexity
        metrics = sockeye.utils.read_metrics_file(os.path.join(data['model'], C.METRICS_NAME))
        perplexity = min(m[C.PERPLEXITY + '-val'] for m in metrics)

        # compute metrics
        bleu = sockeye.evaluate.raw_corpus_bleu(hypotheses=data['test_outputs'], references=data['test_targets'])
        chrf = sockeye.evaluate.raw_corpus_chrf(hypotheses=data['test_outputs'], references=data['test_targets'])
        bleu_restrict = sockeye.evaluate.raw_corpus_bleu(hypotheses=data['test_outputs_restricted'],
                                                         references=data['test_targets'])

        logger.info("test: %s", name)
        logger.info("perplexity=%f, bleu=%f, bleu_restrict=%f, chrf=%f", perplexity, bleu, bleu_restrict, chrf)
        assert perplexity <= perplexity_thresh
        assert bleu >= bleu_thresh
        assert bleu_restrict >= bleu_thresh
train_max_length = 30
dev_line_count = 20
dev_max_length = 30
expected_mean = 1.0
expected_std = 0.0
test_line_count = 20
test_line_count_empty = 0
test_max_length = 30
batch_size = 5
with tmp_digits_dataset("tmp_corpus",
                        train_line_count, train_line_count_empty, train_max_length - C.SPACE_FOR_XOS,
                        dev_line_count, dev_max_length - C.SPACE_FOR_XOS,
                        test_line_count, test_line_count_empty,
                        test_max_length - C.SPACE_FOR_XOS) as data:
    # tmp common vocab
    vcb = vocab.build_from_paths([data['train_source'], data['train_target']])

    train_iter, val_iter, config_data, data_info = data_io.get_training_data_iters(
        sources=[data['train_source']],
        target=data['train_target'],
        validation_sources=[data['dev_source']],
        validation_target=data['dev_target'],
        source_vocabs=[vcb],
        target_vocab=vcb,
        source_vocab_paths=[None],
        target_vocab_path=None,
        shared_vocab=True,
        batch_size=batch_size,
        batch_by_words=False,
        batch_num_devices=1,
        max_seq_len_source=train_max_length,
        max_seq_len_target=train_max_length,
import tempfile
import os

import pytest

from sockeye import config


class ConfigTest(config.Config):
    yaml_tag = "!ConfigTest"

    def __init__(self, param, config=None):
        super().__init__()
        self.param = param
        self.config = config
def test_base_freeze():
    c = config.Config()
    c.param = 1
    assert c.param == 1
    c.freeze()
    with pytest.raises(AttributeError) as e:
        c.param = 2
    assert str(e.value) == "Cannot set 'param' in frozen config"
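# A small sketch of Config.copy, which the BiDirectionalRNNEncoder snippet above
# relies on to halve num_hidden (assumption: copy() returns a copy with the given
# keyword overrides applied while leaving the original config untouched):
def test_copy_with_override():
    c = ConfigTest(param=1)
    c2 = c.copy(param=2)
    assert c2.param == 2
    assert c.param == 1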
def create_parallel_sentence_iter(source_sentences, target_sentences, max_len, batch_size, batch_by_words):
    buckets = sockeye.data_io.define_parallel_buckets(max_len, max_len, 10)
    batch_num_devices = 1
    eos = 0
    pad = 1
    unk = 2
    bucket_iterator = sockeye.data_io.ParallelBucketSentenceIter(source_sentences,
                                                                 target_sentences,
                                                                 buckets,
                                                                 batch_size,
                                                                 batch_by_words,
                                                                 batch_num_devices,
                                                                 eos, pad, unk)
    return bucket_iterator
def test_io_args(test_params, expected_params):
    _test_args(test_params, expected_params, arguments.add_training_io_args)