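# The snippets below are encoder-construction code excerpted from a
# chemprop-style message-passing network codebase. Each is a method body
# lifted from its class. A sketch of the shared imports they assume
# (the helpers Vocab, MPN, MPNEncoder, JTNN, Classifier, MMD, HLoss,
# get_atom_fdim and get_bond_fdim are defined elsewhere in the same
# package; their import paths are omitted rather than guessed):
from argparse import Namespace
from typing import Dict

import torch
import torch.nn as nn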
def __init__(self, args: Namespace):
    super(JTNN, self).__init__()
    # Load the junction-tree cluster vocabulary, one entry per line
    with open(args.vocab_path) as f:
        self.vocab = Vocab([line.strip("\r\n ") for line in f])
    self.hidden_size = args.hidden_size
    self.depth = args.depth
    self.args = args
    # Tree-level encoder: consumes pre-built junction-tree graphs whose node
    # features are cluster embeddings, so atom/bond dims equal hidden_size
    self.jtnn = MPN(args, atom_fdim=self.hidden_size, bond_fdim=self.hidden_size, graph_input=True)
    self.embedding = nn.Embedding(self.vocab.size(), self.hidden_size)
    # Graph-level encoder over the raw molecular graph
    self.mpn = MPN(args)
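# What the vocab-loading line above does: str.strip("\r\n ") removes any
# mix of carriage returns, newlines and spaces from both ends of a line,
# so vocab files with Windows or Unix line endings load identically.
line = "C1=CC=CC=C1 \r\n"         # hypothetical vocab entry
print(repr(line.strip("\r\n ")))  # 'C1=CC=CC=C1'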
def __init__(self,
             args: Namespace,
             atom_fdim: int = None,
             bond_fdim: int = None,
             graph_input: bool = False,
             params: Dict[str, nn.Parameter] = None):
    super(MPN, self).__init__()
    self.args = args
    # Fall back to the feature dimensions implied by args when no explicit
    # overrides are given
    self.atom_fdim = atom_fdim or get_atom_fdim(args)
    # With bond-centered messages (atom_messages=False), each bond feature
    # vector is concatenated with its source atom's features
    self.bond_fdim = bond_fdim or get_bond_fdim(args) + (not args.atom_messages) * self.atom_fdim
    self.graph_input = graph_input
    self.encoder = MPNEncoder(self.args, self.atom_fdim, self.bond_fdim, params=params)
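# A worked sketch of the bond_fdim arithmetic above, with made-up feature
# sizes (the real values come from get_atom_fdim / get_bond_fdim):
atom_fdim = 133        # hypothetical per-atom feature length
raw_bond_fdim = 14     # hypothetical per-bond feature length
atom_messages = False
# (not atom_messages) evaluates to 1, so bond messages also carry atom features:
bond_fdim = raw_bond_fdim + (not atom_messages) * atom_fdim
print(bond_fdim)  # 147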
def create_encoder(self, args: Namespace, params: Dict[str, nn.Parameter] = None):
    if args.jtnn:
        if params is not None:
            raise ValueError('Setting parameters not yet supported for JTNN')
        self.encoder = JTNN(args)
    elif args.dataset_type == 'bert_pretraining':
        # BERT-style pretraining hands pre-constructed graphs to the encoder
        self.encoder = MPN(args, graph_input=True, params=params)
    else:
        self.encoder = MPN(args, params=params)

    if args.freeze_encoder:
        for param in self.encoder.parameters():
            param.requires_grad = False

        if args.gradual_unfreezing:
            self.create_unfreeze_queue(args)
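# The freeze pattern above in isolation: setting requires_grad = False on
# every parameter stops gradients from reaching the encoder (a stand-alone
# demonstration, not code from this repository):
frozen = nn.Linear(8, 4)
for p in frozen.parameters():
    p.requires_grad = False
print(all(not p.requires_grad for p in frozen.parameters()))  # True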
def __init__(self, args: Namespace):
    super(MOE, self).__init__()
    self.args = args
    self.num_sources = args.num_sources
    # One classifier head per source domain
    self.classifiers = nn.ModuleList([Classifier(args) for _ in range(args.num_sources)])
    self.encoder = MPN(args)
    self.mmd = MMD(args)
    # Low-rank projection matrices, one per source; the zeros are replaced
    # during initialization later
    self.Us = nn.ParameterList(
        [nn.Parameter(torch.zeros((args.hidden_size, args.m_rank)), requires_grad=True)
         for _ in range(args.num_sources)])

    # reduction='none' keeps per-example losses so they can be reweighted
    if args.dataset_type == 'regression':
        self.mtl_criterion = nn.MSELoss(reduction='none')
        self.moe_criterion = nn.MSELoss(reduction='none')
    elif args.dataset_type == 'classification':  # note: this branch is untested
        self.mtl_criterion = nn.BCELoss(reduction='none')
        self.moe_criterion = nn.BCELoss(reduction='none')
    self.entropy_criterion = HLoss()
    self.lambda_moe = args.lambda_moe
    self.lambda_critic = args.lambda_critic
    self.lambda_entropy = args.lambda_entropy
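# Why reduction='none' above: mixture-of-experts training typically weights
# each example's loss before averaging. A sketch with made-up numbers (the
# weighting scheme here is illustrative, not this repository's exact one):
criterion = nn.MSELoss(reduction='none')
preds = torch.tensor([0.2, 0.8, 0.5])
targets = torch.tensor([0.0, 1.0, 0.5])
weights = torch.tensor([0.5, 0.3, 0.2])  # hypothetical per-example weights
weighted_loss = (criterion(preds, targets) * weights).sum()
print(weighted_loss)  # tensor(0.0320)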