def run_display_test(self, kwargs):
    f = io.StringIO()
    with redirect_stdout(f):
        parser = setup_args()
        parser.set_defaults(**kwargs)
        opt = parser.parse_args()
        agent = RepeatLabelAgent(opt)
        world = create_task(opt, agent)
        display(opt)
    str_output = f.getvalue()
    self.assertTrue(
        '[ loaded {} episodes with a total of {} examples ]'.format(
            world.num_episodes(), world.num_examples()
        )
        in str_output,
        'Wizard of Wikipedia failed with following args: {}'.format(opt),
    )
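

# Illustrative sketch (assumed, not from the original file): one way to wire
# run_display_test into a unittest.TestCase. The task name and datatype below
# are placeholder values; the usual ParlAI imports are assumed to be in scope.
import unittest


class TestWizardDisplayData(unittest.TestCase):
    def test_display_data(self):
        run_display_test(self, {'task': 'wizard_of_wikipedia', 'datatype': 'train'})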
def get_acts_epochs_1_and_2(defaults):
    # 'parser' is expected to come from the enclosing scope; this helper is
    # nested inside a test that builds the parser before calling it.
    parser.set_defaults(**defaults)
    opt = parser.parse_args()
    build_dict(opt)
    agent = create_agent(opt)
    world_data = create_task(opt, agent)
    acts_epoch_1 = []
    acts_epoch_2 = []
    while not world_data.epoch_done():
        world_data.parley()
        acts_epoch_1.append(world_data.acts[0])
    world_data.reset()
    while not world_data.epoch_done():
        world_data.parley()
        acts_epoch_2.append(world_data.acts[0])
    # Flatten the batched acts and sort by text so the two epochs can be compared.
    acts_epoch_1 = [bb for b in acts_epoch_1 for bb in b]
    acts_epoch_1 = sorted(
        [b for b in acts_epoch_1 if 'text' in b], key=lambda x: x.get('text')
    )
    acts_epoch_2 = [bb for b in acts_epoch_2 for bb in b]
    acts_epoch_2 = sorted(
        [b for b in acts_epoch_2 if 'text' in b], key=lambda x: x.get('text')
    )
    return acts_epoch_1, acts_epoch_2
def get_word_counts(opt, count_inputs):
    """Go through the dataset specified in opt and return word counts and all utterances.

    Inputs:
      count_inputs: If True, include both input and reply when counting words and
        utterances. Otherwise, only include reply text.

    Returns:
      word_counter: a Counter mapping each word to the total number of times it appears
      total_count: int. total word count, i.e. the sum of the counts for each word
      all_utts: list of strings. all the utterances that were used for counting words
    """
    # Create repeat label agent and assign it to the specified task
    agent = RepeatLabelAgent(opt)
    world = create_task(opt, agent)

    # Count word frequency for all words in the dataset
    word_counter = Counter()
    total_count = 0
    all_utts = []
    log_timer = TimeLogger()
    while True:
        world.parley()

        # Count words in the reply
        reply = world.acts[0].get('labels', world.acts[0].get('eval_labels'))[0]
        words = reply.split()
        word_counter.update(words)
        total_count += len(words)
        all_utts.append(reply)

        # Optionally also count words in the input text, as the docstring describes
        if count_inputs:
            input_text = world.acts[0].get('text', '')
            words = input_text.split()
            word_counter.update(words)
            total_count += len(words)
            all_utts.append(input_text)

        if world.epoch_done():
            break

    return word_counter, total_count, all_utts
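

# Illustrative usage sketch (assumed, not in the original file): the task name
# below is a placeholder; any ParlAI task id works, and the standard ParlAI
# imports (ParlaiParser, etc.) are assumed to be available.
def _example_count_words():
    parser = ParlaiParser(True, True)
    opt = parser.parse_args(['--task', 'dailydialog', '--datatype', 'valid'])
    word_counter, total_count, all_utts = get_word_counts(opt, count_inputs=False)
    print('[ counted {} total words, {} unique ]'.format(total_count, len(word_counter)))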
def load_personas(opt):
    print('[ loading personas.. ]')
    # Create ConvAI2 data so we can assign personas.
    convai2_opt = opt.copy()
    convai2_opt['task'] = 'convai2:both'
    if convai2_opt['datatype'].startswith('train'):
        convai2_opt['datatype'] = 'train:evalmode'
    convai2_opt['interactive_task'] = False
    convai2_agent = RepeatLabelAgent(convai2_opt)
    convai2_world = create_task(convai2_opt, convai2_agent)
    personas = set()
    while not convai2_world.epoch_done():
        convai2_world.parley()
        msg = convai2_world.get_acts()[0]
        # Find a new episode
        if msg.get('episode_done', False) and not convai2_world.epoch_done():
            convai2_world.parley()
            msg = convai2_world.get_acts()[0]
            txt = msg.get('text', '').split('\n')
            a1_persona = ""
            a2_persona = ""
            for t in txt:
                if t.startswith("partner's persona:"):
                    a1_persona += (
                        t.replace("partner's persona:", 'your persona:') + '\n'
                    )
def verify(opt, printargs=None, print_parser=None):
    if opt['datatype'] == 'train':
        print("[ note: changing datatype from train to train:ordered ]")
        opt['datatype'] = 'train:ordered'

    # create repeat label agent and assign it to the specified task
    agent = RepeatLabelAgent(opt)
    world = create_task(opt, agent)

    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()

    dictionary = DictionaryAgent(opt)
    ignore_tokens = opt.get('ignore_tokens', '').split(',')

    counts = {}
    for t in {'input', 'labels', 'both'}:
        counts['tokens_in_' + t] = 0
        counts['utterances_in_' + t] = 0
        counts['avg_utterance_length_in_' + t] = 0
        counts['unique_tokens_in_' + t] = 0
        counts['unique_utterances_in_' + t] = 0
def main(opt, return_full=False, out_file=None):
    data_agent = prepro(opt)
    model_agent = (
        ObjectChecklistModelAgent(opt, data_agent=data_agent)
        if not opt['seq2seq']
        else Seq2SeqModelAgent(opt, data_agent=data_agent)
    )
    train_world = create_task(opt, model_agent)

    max_dict = model_agent.model.state_dict()
    max_acc, max_f1, max_data, last_max, max_acc_len = -1, 0, None, 0, None

    for iter in range(opt['max_iter']):
        if iter - last_max > opt['exit_iter']:
            break
        if 'inc_ratio' in opt and opt['inc_ratio'] > 0 and iter == opt['inc_pre_iters']:
            print('resetting best model for finetuning')
            model_agent.model.load_state_dict(max_dict)
            max_acc = 0
        train_world.parley()
        train_report = train_world.report()
def main():
    # Get command line arguments
    parser = ParlaiParser(True, True)
    parser.add_argument('-n', '--num-examples', default=10)
    opt = parser.parse_args()

    agent = Agent(opt)

    opt['datatype'] = 'train'
    world_train = create_task(opt, agent)

    opt['datatype'] = 'valid'
    world_valid = create_task(opt, agent)

    start = time.time()
    # train / valid loop
    for _ in range(1):
        print('[ training ]')
        for _ in range(10):  # train for a bit
            world_train.parley()

        print('[ training summary. ]')
        print(world_train.report())

        print('[ validating ]')
        for _ in range(1):  # check valid accuracy
            world_valid.parley()
# ('parser' is assumed to have been created and configured earlier in this script.)
build_dict.setup_args(parser)
opt = parser.parse_args()
opt['dict_file'] = os.path.join(opt['model_folder'], 'vocab')
dictionary = build_dict.build_dict(opt)

agents = []
for _ in range(opt['num_agents']):
    agents.append(LearningToRankAgent(opt))

opt['datatype'] = 'train'
world_train = create_task(opt, agents)

opt['datatype'] = 'valid'
world_valid = create_task(opt, agents)

start = time.time()
# train / valid loop
for _ in range(1):
    print('[ training ]')
    for _ in range(opt['num_iters']):  # train for a bit
        world_train.parley()

    print('[ training summary. ]')
    print(world_train.report())

    print('[ validating ]')
    for _ in range(1):  # check valid accuracy
        world_valid.parley()

    print('[ validation summary. ]')
def validate(opt, agent):
    old_datatype = agent.opt['datatype']
    agent.opt['datatype'] = 'valid'
    opt = deepcopy(opt)
    opt['datatype'] = 'valid'
    opt['terminate'] = True
    opt['batchsize'] = 1

    # Temporarily silence stdout while the validation world is created
    old_stdout = sys.stdout
    sys.stdout = open(os.devnull, 'w')
    valid_world = create_task(opt, agent)
    sys.stdout = old_stdout

    for _ in valid_world:
        valid_world.parley()
    stats = valid_world.report()
    agent.opt['datatype'] = old_datatype
    return stats
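

# Illustrative sketch (assumed, not from the original file): calling validate()
# from a simple epoch loop. 'opt' and 'agent' come from the usual ParlAI setup;
# the helper name and epoch count are placeholders.
def _example_train_with_validation(opt, agent, num_epochs=3):
    train_world = create_task(opt, agent)
    for epoch in range(num_epochs):
        while not train_world.epoch_done():
            train_world.parley()
        train_world.reset()
        stats = validate(opt, agent)
        print('[ epoch {} validation report: {} ]'.format(epoch, stats))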
def dump_data(opt):
    # create repeat label agent and assign it to the specified task
    agent = RepeatLabelAgent(opt)
    world = create_task(opt, agent)
    ignorefields = opt.get('ignore_fields', '')

    if opt['outfile'] is None:
        outfile = tempfile.mkstemp(
            prefix='{}_{}_'.format(opt['task'], opt['datatype']), suffix='.txt'
        )[1]
    else:
        outfile = opt['outfile']

    if opt['num_examples'] == -1:
        num_examples = world.num_examples()
    else:
        num_examples = opt['num_examples']
    log_timer = TimeLogger()

    print('[ starting to convert.. ]')
    print('[ saving output to {} ]'.format(outfile))