import cProfile
import io
import pstats

import torch

from sru import SRUCell


def profile_speed():
    # Profile 10 CPU forward passes of a bidirectional SRUCell.
    bcell = SRUCell(400, 200, bidirectional=True)
    bcell.eval()
    mask = torch.zeros(200, 1)      # (length, batch) padding mask; all zeros = no padding
    x = torch.randn(200, 1, 400)    # (length, batch, input_size)
    pr = cProfile.Profile()
    pr.enable()
    with torch.no_grad():
        for i in range(10):
            r = bcell(x, mask_pad=mask)
    pr.disable()
    # Report the profile sorted by cumulative time per call.
    s = io.StringIO()
    sortby = 'cumulative'
    ps = pstats.Stats(pr, stream=s).sort_stats(sortby)
    ps.print_stats()
    print(s.getvalue())
def test_bi_fwd():
    # Two forward passes of a bidirectional cell on the same input should match.
    cell = SRUCell(5, 5, bidirectional=True)
    x = torch.randn(7, 1, 5)        # (length, batch, input_size)
    mask = torch.zeros(7, 1)        # padding mask marking the first and last steps
    mask[0, 0] = 1                  # (built here but not passed to the cell in this test)
    mask[6, 0] = 1
    with torch.no_grad():
        out_1 = cell(x)
        out_2 = cell(x)
    print(out_1)
    print()
    print(out_2)
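# A minimal sketch of turning the eyeball check above into an assertion,
# assuming the cell returns an (output, state) pair and that repeated
# eval-mode passes are deterministic. test_bi_fwd_deterministic is not part
# of the original file.
def test_bi_fwd_deterministic():
    cell = SRUCell(5, 5, bidirectional=True)
    cell.eval()
    x = torch.randn(7, 1, 5)
    with torch.no_grad():
        h1, c1 = cell(x)
        h2, c2 = cell(x)
    # Identical inputs and weights should produce identical outputs and states.
    assert torch.allclose(h1, h2)
    assert torch.allclose(c1, c2)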
def test_fwd():
    # Forward pass with an explicit padding mask; repeated calls should match.
    cell = SRUCell(3, 5, use_tanh=True)
    mask = torch.zeros(7, 1)        # padding mask marking the first and last steps
    mask[0, 0] = 1
    mask[6, 0] = 1
    x = torch.randn(7, 1, 3)        # (length, batch, input_size)
    with torch.no_grad():
        out_1 = cell(x, mask_pad=mask)
        out_2 = cell(x, mask_pad=mask)
    print(out_1)
    print()
    print(out_2)
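# A simple entry point for running the snippets above from the command line;
# the guard itself is an addition, not part of the original snippet.
if __name__ == '__main__':
    test_fwd()
    test_bi_fwd()
    profile_speed()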
    dim_feedforward : int, optional
        The dimension of the feedforward network (default=2048).
    dropout : float, optional
        The dropout value (default=0.1).
    sru_dropout : float, optional
        Dropout for the SRU cell. If not given, uses the same
        dropout value as the rest of the transformer.

    Extra keyword arguments are passed to the SRUCell.
    """
    super().__init__()
    self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
    self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
    # The SRU cell stands in for the usual position-wise feedforward network;
    # linear2 projects its dim_feedforward output back to d_model.
    self.sru = SRUCell(d_model,
                       dim_feedforward,
                       dropout,
                       sru_dropout or dropout,
                       bidirectional=False,
                       has_skip_term=False, **kwargs)
    self.linear2 = nn.Linear(dim_feedforward, d_model)
    self.norm1 = nn.LayerNorm(d_model)
    self.norm2 = nn.LayerNorm(d_model)
    self.norm3 = nn.LayerNorm(d_model)
    self.dropout1 = nn.Dropout(dropout)
    self.dropout2 = nn.Dropout(dropout)
    self.dropout3 = nn.Dropout(dropout)
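# A plausible forward pass for the layer excerpted above, modeled on
# nn.TransformerDecoderLayer. This is a sketch under that assumption, not the
# original implementation: the wiring of the attention, SRU, norm, and dropout
# sub-modules may differ in the source repository.
def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
            tgt_key_padding_mask=None, memory_key_padding_mask=None):
    # Self-attention block with residual connection and layer norm.
    tgt2 = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask,
                          key_padding_mask=tgt_key_padding_mask)[0]
    tgt = self.norm1(tgt + self.dropout1(tgt2))
    # Cross-attention over the encoder memory.
    tgt2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask,
                               key_padding_mask=memory_key_padding_mask)[0]
    tgt = self.norm2(tgt + self.dropout2(tgt2))
    # SRU feedforward block: the cell returns (output, state); linear2 maps
    # the output back to d_model before the final residual and norm.
    tgt2, _ = self.sru(tgt)
    tgt2 = self.linear2(tgt2)
    tgt = self.norm3(tgt + self.dropout3(tgt2))
    return tgt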