Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from deeprank import rank_module
class DeepRankNet(rank_module.RankNet):
    """DeepRank-style ranking network built on ``rank_module.RankNet``.

    Only the embedding setup is visible here: the constructor builds two
    ``nn.Embedding`` tables over the same vocabulary and padding index.
    """

    def __init__(self, config):
        """Initialise the parent network and create the embedding tables.

        Args:
            config: mapping passed unchanged to ``RankNet.__init__`` and read
                here for the keys ``'vocab_size'``, ``'dim_weight'``,
                ``'pad_value'``, ``'embed_dim'`` and ``'finetune_embed'``.
                A missing key raises ``KeyError``.
        """
        super().__init__(config)
        # NOTE(review): the meaning of 'LL' is defined by RankNet / the data
        # pipeline, not here -- presumably a pair of list-style inputs; confirm.
        self.input_type = 'LL'

        # Keys are read in the same order as before, so a missing key
        # reports the same name.
        vocab_size = config['vocab_size']
        weight_dim = config['dim_weight']
        pad_idx = config['pad_value']
        # Separate table of width `dim_weight`. The name suggests per-query-word
        # weights -- TODO confirm against the forward pass.
        # Keep this construction before `self.embedding`: creation order
        # affects random initialisation and parameter registration order.
        self.qw_embedding = nn.Embedding(vocab_size, weight_dim,
                                         padding_idx=pad_idx)

        word_dim = config['embed_dim']
        # Main word-embedding table of width `embed_dim`.
        self.embedding = nn.Embedding(vocab_size, word_dim,
                                      padding_idx=pad_idx)
        # If `finetune_embed` is falsy, gradients are not tracked for this
        # weight, so it stays frozen. This applies only to `self.embedding`;
        # `qw_embedding` stays trainable.
        self.embedding.weight.requires_grad = config['finetune_embed']
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from deeprank import rank_module
class MatchPyramidNet(rank_module.RankNet):
def __init__(self, config):
super().__init__(config)
self.input_type = 'S'
self.embedding = nn.Embedding(
config['vocab_size'],
config['embed_dim'],
padding_idx=config['pad_value']
)
self.embedding.weight.requires_grad = config['finetune_embed']
cin = config['simmat_channel']
self.conv_layers = []
for cout, h, w in config['conv_params']: