class ConvolutionalEmbeddingConfig(config.Config):
    def __init__(self,
                 num_embed: int,
                 output_dim: Optional[int],
                 max_filter_width: int,
                 num_filters: Tuple[int, ...],
                 pool_stride: int,
                 num_highway_layers: int,
                 dropout: float,
                 add_positional_encoding: bool,
                 dtype: str = C.DTYPE_FP32) -> None:
super().__init__()
self.num_embed = num_embed
self.output_dim = output_dim
self.max_filter_width = max_filter_width
self.num_filters = num_filters
self.pool_stride = pool_stride
self.num_highway_layers = num_highway_layers
self.dropout = dropout
self.add_positional_encoding = add_positional_encoding
if self.output_dim is None:
self.output_dim = sum(self.num_filters)
self.dtype = dtype
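
# Illustrative helper, not part of the original module: constructing the config above.
# All argument values are made up; when output_dim is None it falls back to sum(num_filters).
def _example_char_embedding_config() -> ConvolutionalEmbeddingConfig:
    embed_config = ConvolutionalEmbeddingConfig(num_embed=32,
                                                output_dim=None,
                                                max_filter_width=3,
                                                num_filters=(64, 64, 128),
                                                pool_stride=5,
                                                num_highway_layers=2,
                                                dropout=0.1,
                                                add_positional_encoding=False)
    assert embed_config.output_dim == 64 + 64 + 128  # sum of num_filters
    return embed_config
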
class ConvolutionalEmbeddingEncoder(Encoder):
"""
An encoder developed to map a sequence of character embeddings to a shorter sequence of segment
embeddings using convolutional, pooling, and highway layers. More generally, it maps a sequence
of input embeddings to a sequence of span embeddings.
* "Fully Character-Level Neural Machine Translation without Explicit Segmentation"
Jason Lee; Kyunghyun Cho; Thomas Hofmann (https://arxiv.org/pdf/1610.03017.pdf)
:param config: Convolutional embedding config.
:param prefix: Name prefix for symbols of this encoder.
"""
def __init__(self,
config: ConvolutionalEmbeddingConfig,
prefix: str = C.CHAR_SEQ_ENCODER_PREFIX) -> None:
        utils.check_condition(len(config.num_filters) == config.max_filter_width,
                              "num_filters must have max_filter_width elements.")
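
# Illustrative helper, not part of the original module. The check_condition above requires
# len(num_filters) == max_filter_width, which _example_char_embedding_config satisfies;
# the remainder of __init__ is omitted in this fragment.
def _example_char_embedding_encoder() -> ConvolutionalEmbeddingEncoder:
    return ConvolutionalEmbeddingEncoder(config=_example_char_embedding_config())
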
def get_positional_embedding(positional_embedding_type: str,
num_embed: int,
max_seq_len: int,
fixed_pos_embed_scale_up_input: bool = False,
fixed_pos_embed_scale_down_positions: bool = False,
prefix: str = '') -> PositionalEncoder:
cls, encoder_params = _get_positional_embedding_params(positional_embedding_type,
num_embed,
max_seq_len,
fixed_pos_embed_scale_up_input,
fixed_pos_embed_scale_down_positions,
prefix)
return cls(**encoder_params)
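
# Illustrative helper, not part of the original module: obtaining a positional encoder via the
# factory above. 'fixed' (sinusoidal) is assumed to be a supported positional_embedding_type;
# the valid values are defined by _get_positional_embedding_params.
def _example_positional_encoder(num_embed: int, max_seq_len: int) -> PositionalEncoder:
    return get_positional_embedding(positional_embedding_type='fixed',
                                    num_embed=num_embed,
                                    max_seq_len=max_seq_len,
                                    fixed_pos_embed_scale_up_input=True,
                                    fixed_pos_embed_scale_down_positions=False,
                                    prefix='example_pos_')
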
class EncoderSequence(Encoder):
"""
A sequence of encoders is itself an encoder.
:param encoders: List of encoders.
:param dtype: Data type.
"""
def __init__(self, encoders: List[Encoder], dtype: str = C.DTYPE_FP32) -> None:
super().__init__(dtype)
self.encoders = encoders
def encode(self,
data: mx.sym.Symbol,
data_length: mx.sym.Symbol,
seq_len: int) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]:
"""
:return: Instance of Encoder.
"""
params = dict(kwargs)
if infer_hidden:
params['num_hidden'] = self.get_num_hidden()
sig_params = inspect.signature(cls.__init__).parameters
if 'dtype' in sig_params and 'dtype' not in kwargs:
params['dtype'] = self.dtype
encoder = cls(**params)
self.encoders.append(encoder)
return encoder
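
# Illustrative helper, not part of the original module: composing encoders with append().
# ConvertLayout and RecurrentEncoder are defined later in this module; infer_hidden=True assumes
# that EncoderSequence.get_num_hidden() (not shown here) reports the size of the last encoder.
def _example_encoder_sequence(rnn_config: rnn.RNNConfig, num_embed: int) -> EncoderSequence:
    seq = EncoderSequence(encoders=[])
    # Recurrent encoders operate on time-major data, so convert the batch-major input first.
    seq.append(ConvertLayout, infer_hidden=False, target_layout=C.TIME_MAJOR, num_hidden=num_embed)
    seq.append(RecurrentEncoder, infer_hidden=False, rnn_config=rnn_config)
    # Convert back to batch-major; num_hidden is inferred from the recurrent encoder.
    seq.append(ConvertLayout, infer_hidden=True, target_layout=C.BATCH_MAJOR)
    return seq
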
class EmptyEncoder(Encoder):
"""
This encoder ignores the input data and simply returns zero-filled states in the expected shape.
:param config: configuration.
"""
def __init__(self,
config: EmptyEncoderConfig) -> None:
super().__init__(config.dtype)
self.num_embed = config.num_embed
self.num_hidden = config.num_hidden
def encode(self,
data: mx.sym.Symbol,
data_length: Optional[mx.sym.Symbol],
seq_len: int,
metadata=None) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]:
        """
        Ignores the input data and returns zero-filled states in the expected shape.

        :param data_length: Vector with sequence lengths.
        :param seq_len: Maximum sequence length.
        :return: Expected number of empty states (zero-filled).
        """
        # outputs: (batch_size, seq_len, num_hidden)
        outputs = mx.sym.dot(data, mx.sym.zeros((self.num_embed, self.num_hidden)))
        return outputs, data_length, seq_len

    def get_num_hidden(self):
        """
        Return the representation size of this encoder.
        """
        return self.num_hidden
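
# Illustrative helper, not part of the original module: the empty encoder produces all-zero
# outputs of shape (batch_size, seq_len, num_hidden), independent of the input values.
def _example_empty_encoder_graph(config: EmptyEncoderConfig, seq_len: int) -> mx.sym.Symbol:
    encoder = EmptyEncoder(config)
    data = mx.sym.Variable('data')                # (batch_size, seq_len, num_embed)
    data_length = mx.sym.Variable('data_length')  # (batch_size,)
    outputs, _, _ = encoder.encode(data, data_length, seq_len)
    return outputs
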
class GraphConvolutionEncoder(Encoder):
def __init__(self,
config: gcn.GCNConfig,
prefix: str = C.GCN_PREFIX):
super().__init__(config.dtype)
self._gcn = gcn.get_gcn(config, prefix)
self._num_hidden = config.output_dim
def encode(self,
data: mx.sym.Symbol,
data_length: Optional[mx.sym.Symbol],
seq_len: int,
metadata=None):
adj = metadata[0]
outputs = self._gcn.convolve(adj, data, seq_len)
        return outputs, data_length, seq_len

    def get_num_hidden(self) -> int:
        return self._num_hidden
def get_encoded_seq_len(self, seq_len: int) -> int:
"""
:return: The size of the encoded sequence.
"""
return seq_len
def get_max_seq_len(self) -> Optional[int]:
"""
:return: The maximum length supported by the encoder if such a restriction exists.
"""
return None
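
# Illustrative helper, not part of the original module: the adjacency matrix of the input graph
# is passed through the metadata argument, matching the encode() signature above. The shape of
# 'adj' is an assumption.
def _example_gcn_encoding(config: gcn.GCNConfig, seq_len: int) -> mx.sym.Symbol:
    encoder = GraphConvolutionEncoder(config)
    data = mx.sym.Variable('data')                # (batch_size, seq_len, num_embed)
    data_length = mx.sym.Variable('data_length')  # (batch_size,)
    adj = mx.sym.Variable('adj')                  # e.g. (batch_size, seq_len, seq_len)
    outputs, _, _ = encoder.encode(data, data_length, seq_len, metadata=[adj])
    return outputs
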
class ConvertLayout(Encoder):
"""
    Converts between batch-major and time-major layouts by swapping the first two dimensions and setting the __layout__ attribute.

    :param target_layout: The target layout to convert to (C.BATCH_MAJOR or C.TIME_MAJOR).
:param num_hidden: The number of hidden units of the previous encoder.
:param dtype: Data type.
"""
def __init__(self, target_layout: str, num_hidden: int, dtype: str = C.DTYPE_FP32) -> None:
assert target_layout == C.BATCH_MAJOR or target_layout == C.TIME_MAJOR
super().__init__(dtype)
self.num_hidden = num_hidden
self.target_layout = target_layout
def encode(self,
               data: mx.sym.Symbol,
               data_length: Optional[mx.sym.Symbol],
               seq_len: int) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]:
        # Swap the batch and time dimensions and record the resulting layout on the symbol.
        with mx.AttrScope(__layout__=self.target_layout):
            return mx.sym.swapaxes(data=data, dim1=0, dim2=1), data_length, seq_len
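
# Illustrative helper, not part of the original module: a batch-major symbol
# (batch_size, seq_len, num_hidden) converted to time-major (seq_len, batch_size, num_hidden)
# and back again.
def _example_layout_roundtrip(num_hidden: int, seq_len: int) -> mx.sym.Symbol:
    data = mx.sym.Variable('data')  # batch-major input
    data_length = mx.sym.Variable('data_length')
    to_time_major = ConvertLayout(C.TIME_MAJOR, num_hidden=num_hidden)
    to_batch_major = ConvertLayout(C.BATCH_MAJOR, num_hidden=num_hidden)
    data, data_length, seq_len = to_time_major.encode(data, data_length, seq_len)
    data, data_length, seq_len = to_batch_major.encode(data, data_length, seq_len)
    return data
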
"""
def encode(self,
data: mx.sym.Symbol,
data_length: Optional[mx.sym.Symbol],
seq_len: int,
metadata=None):
adj = metadata[0]
outputs = self._gcn.convolve(adj, data, seq_len)
return outputs, data_length, seq_len
def get_num_hidden(self) -> int:
return self._num_hidden
class TransformerEncoder(Encoder):
"""
Non-recurrent encoder based on the transformer architecture in:
Attention Is All You Need, Figure 1 (left)
Vaswani et al. (https://arxiv.org/pdf/1706.03762.pdf).
:param config: Configuration for transformer encoder.
:param prefix: Name prefix for operations in this encoder.
"""
def __init__(self,
config: transformer.TransformerConfig,
prefix: str = C.TRANSFORMER_ENCODER_PREFIX) -> None:
super().__init__(config.dtype)
self.config = config
self.prefix = prefix
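
# Illustrative helper, not part of the original module: the transformer encoder is built from a
# transformer.TransformerConfig whose fields are defined in the transformer module and omitted here.
def _example_transformer_encoder(config: transformer.TransformerConfig) -> TransformerEncoder:
    return TransformerEncoder(config=config, prefix=C.TRANSFORMER_ENCODER_PREFIX)
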
class BiDirectionalRNNEncoder(Encoder):
    """
    An encoder that runs a forward RNN and a reverse RNN over the input and concatenates
    their outputs along the hidden dimension.
    """
def get_num_hidden(self) -> int:
"""
Return the representation size of this encoder.
"""
return self.rnn_config.num_hidden
def get_rnn_cells(self) -> List[mx.rnn.BaseRNNCell]:
"""
Returns a list of RNNCells used by this encoder.
"""
return self.forward_rnn.get_rnn_cells() + self.reverse_rnn.get_rnn_cells()
class ConvolutionalEncoder(Encoder):
"""
Encoder that uses convolution instead of recurrent connections, similar to Gehring et al. 2017.
:param config: Configuration for convolutional encoder.
:param prefix: Name prefix for operations in this encoder.
"""
def __init__(self,
config: ConvolutionalEncoderConfig,
prefix: str = C.CNN_ENCODER_PREFIX) -> None:
super().__init__(config.dtype)
self.config = config
# initialize the weights of the linear transformation required for the residual connections
self.i2h_weight = mx.sym.Variable('%si2h_weight' % prefix)
    def encode(self,
               data: mx.sym.Symbol,
               data_length: mx.sym.Symbol,
               seq_len: int) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]:
        """
        :param data_length: Vector with sequence lengths.
        :param seq_len: Maximum sequence length.
        :return: Encoded version of the data (data, data_length, seq_len).
        """
        # Project the inputs to the convolutional hidden size for the residual connections.
        data = mx.sym.FullyConnected(data=data,
                                     num_hidden=self.config.cnn_config.num_hidden,
                                     no_bias=True,
                                     flatten=False,
                                     weight=self.i2h_weight)
        # Multiple layers with residual connections:
        for layer in self.layers:
            data = data + layer(data, data_length, seq_len)
        return data, data_length, seq_len

    def get_num_hidden(self) -> int:
        return self.config.cnn_config.num_hidden
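
# Illustrative helper, not part of the original module: the convolutional encoder keeps the
# sequence length unchanged and returns outputs of size cnn_config.num_hidden.
def _example_cnn_encoding(config: ConvolutionalEncoderConfig, seq_len: int) -> mx.sym.Symbol:
    encoder = ConvolutionalEncoder(config)
    data = mx.sym.Variable('data')                # (batch_size, seq_len, num_embed)
    data_length = mx.sym.Variable('data_length')  # (batch_size,)
    outputs, _, _ = encoder.encode(data, data_length, seq_len)
    return outputs
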
class RecurrentEncoder(Encoder):
"""
Uni-directional (multi-layered) recurrent encoder.
:param rnn_config: RNN configuration.
:param prefix: Prefix for variable names.
:param layout: Data layout.
"""
def __init__(self,
rnn_config: rnn.RNNConfig,
prefix: str = C.STACKEDRNN_PREFIX,
layout: str = C.TIME_MAJOR) -> None:
super().__init__(rnn_config.dtype)
self.rnn_config = rnn_config
self.layout = layout
self.rnn = rnn.get_stacked_rnn(rnn_config, prefix)
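
# Illustrative helper, not part of the original module: RecurrentEncoder defaults to time-major
# layout, which is why _example_encoder_sequence above wraps it between two ConvertLayout encoders.
def _example_recurrent_encoder(rnn_config: rnn.RNNConfig) -> RecurrentEncoder:
    return RecurrentEncoder(rnn_config=rnn_config, prefix=C.STACKEDRNN_PREFIX, layout=C.TIME_MAJOR)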