# WMT En-Fr "big" transformer: lower dropout, otherwise the En-De "big" setup.
@register_model_architecture('transformer', 'transformer_vaswani_wmt_en_fr_big')
def transformer_vaswani_wmt_en_fr_big(args):
    args.dropout = getattr(args, 'dropout', 0.1)
    transformer_vaswani_wmt_en_de_big(args)

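# All snippets on this page share the same defaulting idiom:
# getattr(args, name, default) fills in a value only when the user has not
# already set it, so command-line flags always win over architecture defaults.
# A minimal, runnable sketch of the pattern (the attribute names below are
# illustrative, not tied to any particular fairseq model):
from argparse import Namespace

def example_architecture(args):
    args.dropout = getattr(args, 'dropout', 0.1)
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)

args = Namespace(dropout=0.3)
example_architecture(args)
assert args.dropout == 0.3            # user-supplied value is kept
assert args.encoder_embed_dim == 512  # missing value gets the default
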
# XLM-style masked language model: encoder defaults for the base variant.
@register_model_architecture('masked_lm', 'xlm_base')
def xlm_architecture(args):
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024)
    args.share_encoder_input_output_embed = getattr(
        args, 'share_encoder_input_output_embed', True)
    args.no_token_positional_embeddings = getattr(
        args, 'no_token_positional_embeddings', False)
    args.encoder_learned_pos = getattr(args, 'encoder_learned_pos', True)
    args.num_segment = getattr(args, 'num_segment', 1)
    args.encoder_layers = getattr(args, 'encoder_layers', 6)
    args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 8)
    args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096)
    args.bias_kv = getattr(args, 'bias_kv', False)
    args.zero_attn = getattr(args, 'zero_attn', False)

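# register_model_architecture(model_name, arch_name) attaches a named set of
# defaults to an already-registered fairseq model, selectable via the --arch
# flag. A hedged sketch of defining a custom variant on top of xlm_base
# ('my_xlm_wide' is a hypothetical name, not part of fairseq):
from fairseq.models import register_model_architecture

@register_model_architecture('masked_lm', 'my_xlm_wide')
def my_xlm_wide(args):
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 2048)
    xlm_architecture(args)  # inherit the remaining xlm_base defaults
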
# The WMT En-De Levenshtein transformer is simply the base configuration.
@register_model_architecture(
    "levenshtein_transformer", "levenshtein_transformer_wmt_en_de"
)
def levenshtein_transformer_wmt_en_de(args):
    levenshtein_base_architecture(args)

# WSJ LSTM language model: an alias for the base LM configuration.
@register_model_architecture('lstm_lm', 'lstm_lm_wsj')
def lstm_lm_wsj(args):
    base_lm_architecture(args)

# Convolutional seq2seq defaults. Note that the layer specs (and
# decoder_attention) are stored as strings and evaluated at model build time.
@register_model_architecture('fconv', 'fconv')
def base_architecture(args):
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
    args.encoder_layers = getattr(args, 'encoder_layers', '[(512, 3)] * 20')
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
    args.decoder_layers = getattr(args, 'decoder_layers', '[(512, 3)] * 20')
    args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 256)
    args.decoder_attention = getattr(args, 'decoder_attention', 'True')
    args.share_input_output_embed = getattr(args, 'share_input_output_embed', False)

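# The fconv layer spec is a string holding a Python expression that expands
# into a list of (out_channels, kernel_width) tuples when the model is built.
# A sketch of the expansion (eval here simply mirrors the stored format, not
# fairseq's exact code path):
layers = eval('[(512, 3)] * 20')
assert len(layers) == 20 and layers[0] == (512, 3)
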
# Non-autoregressive transformer defaults; decoder dimensions fall back to
# the encoder's resolved values.
@register_model_architecture(
    "nonautoregressive_transformer", "nonautoregressive_transformer"
)
def base_architecture(args):
    args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
    args.encoder_layers = getattr(args, "encoder_layers", 6)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
    args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
    args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
    args.decoder_ffn_embed_dim = getattr(
        args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
    )
    args.decoder_layers = getattr(args, "decoder_layers", 6)

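# Ordering matters above: the decoder defaults fall back to the encoder
# values, so the encoder attributes must be resolved first. An illustrative
# sketch of the effect:
from argparse import Namespace

args = Namespace(encoder_embed_dim=1024)  # user widens the encoder
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', args.encoder_embed_dim)
assert args.decoder_embed_dim == 1024  # decoder tracks the encoder override
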
# Speech transformer: a convolutional front-end (channels, kernel sizes, and
# strides given as string-encoded lists) feeding a transformer encoder.
@register_model_architecture('speech_transformer', 'speech_transformer')
def base_architecture(args):
    args.encoder_conv_channels = getattr(
        args, 'encoder_conv_channels', '[64, 64, 128, 128]',
    )
    args.encoder_conv_kernel_sizes = getattr(
        args, 'encoder_conv_kernel_sizes', '[(3, 3), (3, 3), (3, 3), (3, 3)]',
    )
    args.encoder_conv_strides = getattr(
        args, 'encoder_conv_strides', '[(1, 1), (2, 2), (1, 1), (2, 2)]',
    )
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
    args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 512)
    args.encoder_layers = getattr(args, 'encoder_layers', 6)
    args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 8)
    args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
    args.decoder_embed_path = getattr(args, 'decoder_embed_path', None)

# Default captioning architecture: only the encoder depth is overridden.
@register_model_architecture('default-captioning-model', 'default-captioning-arch')
def default_captioning_arch(args):
    args.encoder_layers = getattr(args, 'encoder_layers', 3)

# WMT En-De fconv: 15 encoder/decoder layers with growing channel widths.
# Using getattr (and delegating to base_architecture last) keeps any
# user-supplied command-line values intact instead of clobbering them.
@register_model_architecture('fconv', 'fconv_wmt_en_de')
def fconv_wmt_en_de(args):
    convs = '[(512, 3)] * 9'       # first 9 layers have 512 units
    convs += ' + [(1024, 3)] * 4'  # next 4 layers have 1024 units
    convs += ' + [(2048, 1)] * 2'  # final 2 layers use 1x1 convolutions
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 768)
    args.encoder_layers = getattr(args, 'encoder_layers', convs)
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 768)
    args.decoder_layers = getattr(args, 'decoder_layers', convs)
    args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 512)
    base_architecture(args)

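# Evaluating the concatenated spec yields the full 15-layer configuration
# (a sketch mirroring how fconv interprets the string, not fairseq's code):
convs = '[(512, 3)] * 9' + ' + [(1024, 3)] * 4' + ' + [(2048, 1)] * 2'
assert len(eval(convs)) == 15  # 9 + 4 + 2 convolutional layers
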
# GBW lightconv language model: set the variant's overrides first, then
# delegate; since the base function also uses getattr, these values survive.
@register_model_architecture('lightconv_lm', 'lightconv_lm_gbw')
def lightconv_lm_gbw(args):
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
    args.dropout = getattr(args, 'dropout', 0.1)
    args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
    args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 4096)
    args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 16)
    base_lm_architecture(args)

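# Because the base function also uses getattr, overrides set by a variant
# before delegating are preserved rather than clobbered. A minimal sketch
# with hypothetical functions (not fairseq's actual base_lm_architecture):
from argparse import Namespace

def hypothetical_base(args):
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 128)

def hypothetical_variant(args):
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
    hypothetical_base(args)

args = Namespace()
hypothetical_variant(args)
assert args.decoder_embed_dim == 512  # the variant's default wins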