# imports used across the excerpts below
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

results_file = open(temp_file, 'w')
n_correct = 0
gold_list = []
pred_list = []
for data_batch_idx, data_batch in enumerate(dataset_iter):
    scores = model(data_batch)
    if args.dataset == 'EntityDetection':
        # an example counts as correct only if every token tag in the sequence matches
        n_correct += torch.sum(torch.sum(torch.max(scores, 1)[1].view(data_batch.ed.size()).data ==
                                         data_batch.ed.data, dim=1)
                               == data_batch.ed.size()[0]).item()
        index_tag = np.transpose(torch.max(scores, 1)[1].view(data_batch.ed.size()).cpu().data.numpy())
        tag_array = index2tag[index_tag]
        index_question = np.transpose(data_batch.text.cpu().data.numpy())
        question_array = index2word[index_question]
        gold_list.append(np.transpose(data_batch.ed.cpu().data.numpy()))
        gold_array = index2tag[np.transpose(data_batch.ed.cpu().data.numpy())]
        pred_list.append(index_tag)
        for question, label, gold in zip(question_array, tag_array, gold_array):
            results_file.write("{}\t{}\t{}\n".format(" ".join(question), " ".join(label), " ".join(gold)))
    else:
        print("Wrong dataset")
        exit()
if args.dataset == 'EntityDetection':
    P, R, F = evaluation(gold_list, pred_list, index2tag, type=False)
    print("{} Precision: {:10.6f}% Recall: {:10.6f}% F1 Score: {:10.6f}%".format("Dev", 100. * P, 100. * R, 100. * F))
else:
    print("Wrong dataset")
    exit()
results_file.flush()
results_file.close()
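# For reference, a minimal sketch of what the `evaluation` helper called above
# could compute: token-level precision/recall/F1 over non-"O" entity tags.
# This is an assumption about its behavior, not the original implementation;
# the "O" tag name is an assumption and the `type` flag is unused here.
def evaluation_sketch(gold_list, pred_list, index2tag, type=False):
    n_pred = n_gold = n_hit = 0
    for gold_batch, pred_batch in zip(gold_list, pred_list):
        for gold_seq, pred_seq in zip(gold_batch, pred_batch):
            for g, p in zip(gold_seq, pred_seq):
                if index2tag[p] != 'O':
                    n_pred += 1
                if index2tag[g] != 'O':
                    n_gold += 1
                if index2tag[g] != 'O' and g == p:
                    n_hit += 1
    precision = n_hit / n_pred if n_pred else 0.0
    recall = n_hit / n_gold if n_gold else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1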
n_correct = 0
fname = "{}.txt".format(data_name)
results_file = open(os.path.join(results_path, fname), 'w')
n_retrieved = 0
fid = open(os.path.join(args.data_dir, "lineids_{}.txt".format(data_name)))
sent_id = [x.strip() for x in fid.readlines()]
fid.close()
for data_batch_idx, data_batch in enumerate(dataset_iter):
    scores = model(data_batch)
    if args.dataset == 'RelationPrediction':
        n_correct += torch.sum(torch.max(scores, 1)[1].view(data_batch.relation.size()).data ==
                               data_batch.relation.data).item()
        # get the top k predictions per example
        top_k_scores, top_k_indices = torch.topk(scores, k=args.hits, dim=1, sorted=True)  # shape: (batch_size, k)
        top_k_scores_array = top_k_scores.cpu().data.numpy()
        top_k_indices_array = top_k_indices.cpu().data.numpy()
        top_k_relations_array = index2tag[top_k_indices_array]
        for i, (relations_row, scores_row) in enumerate(zip(top_k_relations_array, top_k_scores_array)):
            index = (data_batch_idx * args.batch_size) + i
            example = data_batch.dataset.examples[index]
            for j, (rel, score) in enumerate(zip(relations_row, scores_row)):
                if rel == example.relation:
                    label = 1
                    n_retrieved += 1
                else:
                    label = 0
                results_file.write("{} %%%% {} %%%% {} %%%% {}\n".format(sent_id[index], rel, label, score))
    else:
        print("Wrong dataset")
        exit()
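# A possible follow-up (not in the source): report the top-k retrieval rate
# from the counters above. `dataset` standing for the evaluated split, and its
# use with len(), are assumptions.
print("hits@{}: retrieved {} of {} ({:6.2f}%)".format(
    args.hits, n_retrieved, len(dataset), 100. * n_retrieved / len(dataset)))
results_file.close()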
if torch.cuda.is_available():
    model = model.cuda()
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=args.lr)
max_acc = 0
model.train()
train_pred_label = []
n_iter = 0
for batch in iter(train_iter):
    optimizer.zero_grad()
    # forward pass
    probs = model(batch.text)
    train_pred_label.extend(torch.max(probs, 2)[1].squeeze(1).cpu().data.numpy())
    # compute loss
    loss = criterion(probs.view(-1, n_classes), batch.labels.view(-1))
    print_loss = loss.item()
    # backward pass and parameter update
    loss.backward()
    optimizer.step()
    n_iter += 1
    # report the loss normalized by sequence length
    print('Batch idx: (%d / %d) loss: %.6f' % (n_iter, args.max_iter, print_loss / len(batch.text)))
    # map predicted label indices back to tag strings
    train_pred = [list(map(lambda x: label_vocab.itos[x], y)) for y in train_pred_label]
    # validate periodically and keep the best checkpoint
    if n_iter % 100 == 0:
        val_loss, accuracy = validate(model, val_iter, n_classes, criterion)
        if accuracy > max_acc:
            max_acc = accuracy
            save_model(args, model)
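# A minimal sketch of the `validate` helper called above, assuming it mirrors
# the training forward pass: mean NLL loss and token-level accuracy over the
# validation split. This is an assumption, not the original implementation.
def validate_sketch(model, val_iter, n_classes, criterion):
    model.eval()
    total_loss, n_correct, n_total = 0.0, 0, 0
    with torch.no_grad():
        for batch in iter(val_iter):
            probs = model(batch.text)
            total_loss += criterion(probs.view(-1, n_classes), batch.labels.view(-1)).item()
            pred = torch.max(probs, 2)[1]
            n_correct += (pred.view(-1) == batch.labels.view(-1)).sum().item()
            n_total += batch.labels.view(-1).size(0)
    model.train()
    return total_loss / max(n_total, 1), 100. * n_correct / max(n_total, 1)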
dataset_iter.init_epoch()
n_correct = 0
linenum = 1
fname = "main-{}-results.txt".format(data_name)
results_file = open(os.path.join(args.results_path, fname), 'w')
gold_list = []
pred_list = []
for data_batch_idx, data_batch in enumerate(dataset_iter):
    scores = model(data_batch)
    # count examples whose full tag sequence matches the gold labels
    n_correct += ((torch.max(scores, 1)[1].view(data_batch.label.size()).data ==
                   data_batch.label.data).sum(dim=0) == data_batch.label.size()[0]).sum().item()
    index_tag = np.transpose(torch.max(scores, 1)[1].view(data_batch.label.size()).cpu().data.numpy())
    tag_array = index2tag[index_tag]
    index_question = np.transpose(data_batch.question.cpu().data.numpy())
    question_array = index2word[index_question]
    # print and write the result
    for i in range(data_batch.batch_size):
        line_to_print = "{}-{} %%%% {} %%%% {}".format(data_name, linenum, " ".join(question_array[i]), " ".join(tag_array[i]))
        # print(line_to_print)
        results_file.write(line_to_print + "\n")
        linenum += 1
    gold_list.append(np.transpose(data_batch.label.cpu().data.numpy()))
    pred_list.append(index_tag)
# print("no. correct: {} out of {}".format(n_correct, len(dataset)))
# accuracy = 100. * n_correct / len(dataset)
# print("{} accuracy: {:8.6f}%".format(data_name, accuracy))
# log line truncated in the source; the format string is a reconstruction
print("loss: {:.6f}  acc: {:.4f}  dev MAP: {:.4f}  dev MRR: {:.4f}".format(
    loss_num, acc / tot, dev_map, dev_mrr))
qids = []
predictions = []
labels = []
for test_batch_idx, test_batch in enumerate(test_iter):
    # evaluate in batches; scoring one example at a time is just the
    # batched case with dev_size = 1
    scores = pw_model.convModel(test_batch)
    # scores = pw_model.linearLayer(scores)
    scores = pw_model.predict(scores)
    qid_array = np.transpose(test_batch.id.cpu().data.numpy())
    score_array = scores.cpu().data.numpy().reshape(-1)
    true_label_array = np.transpose(test_batch.label.cpu().data.numpy())
    qids.extend(qid_array.tolist())
    predictions.extend(score_array.tolist())
    labels.extend(true_label_array.tolist())
if args.tensorboard:
    writer.add_scalar('{}/dev/map'.format(args.dataset), dev_map, dev_index)
    writer.add_scalar('{}/dev/mrr'.format(args.dataset), dev_mrr, dev_index)
    writer.add_scalar('{}/lr'.format(args.dataset), optimizer.param_groups[0]['lr'], dev_index)
    writer.add_scalar('{}/train/loss'.format(args.dataset), loss_num, dev_index)
if best_dev_mrr < dev_mrr:
    # track the new best score and snapshot the model
    best_dev_mrr = dev_mrr
    snapshot_path = os.path.join(args.save_path, args.dataset, args.mode + '_best_model.pt')
    torch.save(pw_model, snapshot_path)
    iters_not_improved = 0
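# A minimal sketch of computing MAP/MRR from qids/predictions/labels, assuming
# candidates are grouped per question id and ranked by predicted score. The
# helper name is an assumption; the repo may delegate this to trec_eval.
from collections import defaultdict

def map_mrr_sketch(qids, predictions, labels):
    groups = defaultdict(list)
    for qid, score, label in zip(qids, predictions, labels):
        groups[qid].append((score, label))
    aps, rrs = [], []
    for qid, cands in groups.items():
        cands.sort(key=lambda x: x[0], reverse=True)
        n_rel, ap, rr = 0, 0.0, 0.0
        for rank, (score, label) in enumerate(cands, start=1):
            if label == 1:
                n_rel += 1
                ap += n_rel / rank
                if rr == 0.0:
                    rr = 1.0 / rank
        if n_rel:  # skip questions with no relevant candidate
            aps.append(ap / n_rel)
            rrs.append(rr)
    return sum(aps) / len(aps), sum(rrs) / len(rrs)

dev_map, dev_mrr = map_mrr_sketch(qids, predictions, labels)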
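# A hedged sketch of the matching else branch (not in the excerpt), assuming a
# patience-based early-stopping scheme; `args.patience` and `early_stop` are
# assumed names.
else:
    iters_not_improved += 1
    if iters_not_improved >= args.patience:
        early_stop = True
        print("Early stopping: dev MRR has not improved in {} checkpoints".format(args.patience))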
def write_top_k_results(dataset_iter=train_iter, dataset=train, data_name="train"):
    print("Dataset: {}".format(data_name))
    model.eval()
    dataset_iter.init_epoch()
    n_correct = 0
    n_retrieved = 0
    fname = "topk-retrieval-{}-hits-{}.txt".format(data_name, args.hits)
    results_file = open(os.path.join(args.results_path, fname), 'w')
    for data_batch_idx, data_batch in enumerate(dataset_iter):
        scores = model(data_batch)
        n_correct += (torch.max(scores, 1)[1].view(data_batch.relation.size()).data ==
                      data_batch.relation.data).sum().item()
        # get the predicted top k relations
        top_k_scores, top_k_indices = torch.topk(scores, k=args.hits, dim=1, sorted=True)  # shape: (batch_size, k)
        top_k_indices_array = top_k_indices.cpu().data.numpy()
        top_k_scores_array = top_k_scores.cpu().data.numpy()
        top_k_relations_array = index2rel[top_k_indices_array]  # shape: (batch_size, k)
        # write to file
        for i, (relations_row, scores_row) in enumerate(zip(top_k_relations_array, top_k_scores_array)):
            index = (data_batch_idx * args.batch_size) + i
            example = data_batch.dataset.examples[index]
            # correct_relation = index2rel[test_batch.relation.data[i]]
            # results_file.write("{}-{} %%%% {} %%%% {}\n".format(data_name, index+1, " ".join(example.question), example.relation))
            found = (False, -1)
            for j, (rel, score) in enumerate(zip(relations_row, scores_row)):
                # results_file.write("{} %%%% {}\n".format(rel, score))
                if rel == example.relation:
                    label = 1
                    n_retrieved += 1
                    found = (True, j)
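                # The excerpt is truncated at this point. A minimal completion,
                # assuming the loop mirrors the RelationPrediction writer above
                # (label 0 for misses, one "%%%%"-separated line per candidate);
                # this is a reconstruction, not the original code.
                else:
                    label = 0
                results_file.write("{}-{} %%%% {} %%%% {} %%%% {}\n".format(
                    data_name, index + 1, rel, label, score))
    print("no. retrieved: {} out of {}".format(n_retrieved, len(dataset)))
    results_file.close()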
def predict(dataset_iter=test_iter, dataset=test, data_name="test"):
    print("Dataset: {}".format(data_name))
    model.eval()
    dataset_iter.init_epoch()
    n_correct = 0
    linenum = 1
    fname = "main-{}-results.txt".format(data_name)
    results_file = open(os.path.join(args.results_path, fname), 'w')
    for data_batch_idx, data_batch in enumerate(dataset_iter):
        scores = model(data_batch)
        n_correct += (torch.max(scores, 1)[1].view(data_batch.relation.size()).data ==
                      data_batch.relation.data).sum().item()
        # get the top predicted relation for each example
        top_scores, top_indices = torch.max(scores, dim=1)  # shape: (batch_size,)
        top_indices_array = top_indices.cpu().data.numpy().reshape(-1)
        top_scores_array = top_scores.cpu().data.numpy().reshape(-1)
        top_relations_array = index2rel[top_indices_array]  # shape: (batch_size,)
        # write to file
        for i in range(data_batch.batch_size):
            line_to_print = "{}-{} %%%% {} %%%% {}".format(data_name, linenum, top_relations_array[i], top_scores_array[i])
            results_file.write(line_to_print + "\n")
            linenum += 1
    print("no. correct: {} out of {}".format(n_correct, len(dataset)))
    accuracy = 100. * n_correct / len(dataset)
    print("{} accuracy: {:8.6f}%".format(data_name, accuracy))
    print("-" * 80)
    results_file.close()
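# Typical usage (assumed, not shown in the source): run inference over each
# split; `dev_iter` / `dev` are assumed names for the validation split.
predict(dataset_iter=dev_iter, dataset=dev, data_name="valid")
predict(dataset_iter=test_iter, dataset=test, data_name="test")
write_top_k_results(dataset_iter=test_iter, dataset=test, data_name="test")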