How to use the sru.SRUCell function in sru

To help you get started, we’ve selected a few sru examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github asappresearch / sru / misc / test_cpu_impl.py View on Github external
def profile_speed():
    """Profile the CPU forward speed of a bidirectional SRUCell.

    Runs 10 inference-only forward passes of a 400->200 bidirectional
    cell under cProfile and prints the cumulative-time statistics.
    """
    bcell = SRUCell(400, 200, bidirectional=True)
    bcell.eval()  # inference mode (disables dropout)
    # mask_pad marks padded timesteps; all zeros => no padding.
    mask = torch.zeros(200, 1)
    x = torch.randn(200, 1, 400)  # assumes (seq_len, batch, input_dim) -- TODO confirm
    pr = cProfile.Profile()
    pr.enable()
    with torch.no_grad():
        for _ in range(10):
            bcell(x, mask_pad=mask)  # result discarded; timing only
    pr.disable()
    s = io.StringIO()
    ps = pstats.Stats(pr, stream=s).sort_stats('cumulative')
    ps.print_stats()
    print(s.getvalue())
github asappresearch / sru / misc / test_cpu_impl.py View on Github external
def test_bi_fwd():
    """Compare no_grad vs. grad-tracking forward of a bidirectional SRUCell.

    Builds a padding mask marking the first and last timesteps, runs the
    same input twice (inside and outside torch.no_grad) and prints both
    outputs for visual comparison.
    """
    cell = SRUCell(5, 5, bidirectional=True)
    x = torch.randn(7, 1, 5)  # assumes (seq_len, batch, input_dim) -- TODO confirm
    mask = torch.zeros(7, 1)
    mask[0, 0] = 1  # first timestep is padding
    mask[6, 0] = 1  # last timestep is padding
    with torch.no_grad():
        # Fix: mask was constructed but never passed; sibling test_fwd
        # passes mask_pad=mask, so do the same here.
        out_1 = cell(x, mask_pad=mask)
    out_2 = cell(x, mask_pad=mask)
    print(out_1)
    print()
    print(out_2)
github asappresearch / sru / misc / test_cpu_impl.py View on Github external
def test_fwd():
    """Run a unidirectional tanh SRUCell twice — once under torch.no_grad,
    once with gradient tracking — and print both outputs for comparison."""
    cell = SRUCell(3, 5, use_tanh=True)
    pad = torch.zeros(7, 1)
    # Mark the first and last timesteps as padded positions.
    pad[0, 0] = 1
    pad[6, 0] = 1
    inp = torch.randn(7, 1, 3)
    with torch.no_grad():
        first = cell(inp, mask_pad=pad)
    second = cell(inp, mask_pad=pad)
    print(first)
    print()
    print(second)
github asappresearch / flambe / flambe / nn / transformer_sru.py View on Github external
dim_feedforward : int, optional
            The dimension of the feedforward network (default=2048).
        dropout : float, optional
            The dropout value (default=0.1).
        sru_dropout: float, optional
            Dropout for the SRU cell. If not given, uses the same
            dropout value as the rest of the transformer.

        Extra keyword arguments are passed to the SRUCell.

        """
        super().__init__()

        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.sru = SRUCell(d_model,
                           dim_feedforward,
                           dropout,
                           sru_dropout or dropout,
                           bidirectional=False,
                           has_skip_term=False, **kwargs)

        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)