Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_load_agent(self):
    """Assert that ``load_agent_module(OPTIONS['agent'])`` yields ``RepeatLabelAgent``."""
    self.assertEqual(load_agent_module(OPTIONS['agent']), RepeatLabelAgent)
# NOTE(review): this excerpt starts mid-function and its indentation is lost;
# `opt` is defined outside the visible lines.
# One example is accounted for per parley in the progress loop below.
bsz = 1
try:
# Imported here so that a missing install surfaces the explicit message below.
import torch
from torch.utils.data import DataLoader
except ImportError:
raise ImportError('Need to install Pytorch: go to pytorch.org')
dataset = get_dataset_class(opt)(opt)
# Feature path = parent of the dataset's image_path joined with the image_mode value.
pre_image_path, _ = os.path.split(dataset.image_path)
image_path = os.path.join(pre_image_path, opt.get('image_mode'))
# The '<image_path>.built' file is checked together with the path itself;
# if either is missing the features are (re)computed.
images_built_file = image_path + '.built'
if not os.path.exists(image_path) or not os.path.isfile(images_built_file):
'''Image features have not been computed yet'''
opt['num_load_threads'] = 20
agent = RepeatLabelAgent(opt)
# When the task is the generic 'pytorch_teacher', substitute the concrete task
# name if one is given, otherwise fall back to the dataset name.
if opt['task'] == 'pytorch_teacher':
if opt.get('pytorch_teacher_task'):
opt['task'] = opt['pytorch_teacher_task']
else:
opt['task'] = opt['pytorch_teacher_dataset']
world = create_task(opt, agent)
exs_seen = 0
total_exs = world.num_examples()
pbar = tqdm.tqdm(unit='ex', total=total_exs)
print('[ Computing and Saving Image Features ]')
# Step the world until the example counter reaches the reported total,
# advancing the progress bar by `bsz` per parley.
while exs_seen < total_exs:
world.parley()
exs_seen += bsz
pbar.update(bsz)
pbar.close()
print('[ Feature Computation Done ]')
def extract_and_save(opt, save=True):
# NOTE(review): excerpt is truncated and its indentation is lost; `save` is
# not referenced in the visible lines, and what happens to the collected
# personas is not shown.
agent = RepeatLabelAgent(opt)
world = create_task(opt, agent)
# The first agent of the world is used directly as the teacher.
teacher = world.agents[0]
personas_path = opt.get('personas_path')
# Create the personas directory when it does not exist yet.
if not os.path.exists(personas_path):
os.makedirs(personas_path)
new_episode = True
personas = []
# Pull messages straight from the teacher until it reports the epoch is done.
while not teacher.epoch_done():
act = teacher.act()
if new_episode:
# Keep every line of the message text except the last one
# (presumably the persona lines precede the first utterance -- confirm).
persona_text = act['text'].split('\n')[:-1]
if opt.get('persona_type') == 'both':
# Separate the lines by which speaker's persona marker they contain.
persona_1 = [p for p in persona_text if 'your persona:' in p]
persona_2 = [p for p in persona_text if 'partner\'s persona:' in p]
# NOTE(review): excerpt of extract_feats with indentation lost; the condition
# guarding the deprecation branch below is outside the visible lines, and the
# excerpt ends inside the `try:` at the bottom.
print('[ Deprecated Warning: extract_feats should be passed opt not Parser ]')
opt = opt.parse_args()
# Get command line arguments
# Work on a deep copy so the overrides below stay local to this routine.
opt = copy.deepcopy(opt)
# Keep only the first ':'-separated component of the datatype and append ':ordered'.
dt = opt['datatype'].split(':')[0] + ':ordered'
opt['datatype'] = dt
bsz = opt.get('batchsize', 1)
# Fixed overrides applied for the extraction run.
opt['no_cuda'] = False
opt['gpu'] = 0
opt['num_epochs'] = 1
opt['use_hdf5'] = False
opt['num_load_threads'] = 20
print("[ Loading Images ]")
# create repeat label agent and assign it to the specified task
if opt.get('pytorch_teacher_dataset') is None:
agent = RepeatLabelAgent(opt)
world = create_task(opt, agent)
# total_exs is not used in the visible lines (see the TODO below).
total_exs = world.num_examples()
# TODO: wrap in a tqdm
# Step the world until it reports the epoch is done.
while not world.epoch_done():
world.parley()
elif opt.get('use_hdf5_extraction', False):
'''One can specify a Pytorch Dataset for custom image loading'''
nw = opt.get('numworkers', 1)
im = opt.get('image_mode', 'raw')
# This branch forces a batch size of one and turns on the 'extract_image' option.
opt['batchsize'] = 1
opt['extract_image'] = True
bsz = 1
try:
import torch
from torch.utils.data import DataLoader
def get_word_counts(opt, count_inputs):
"""Goes through the dataset specified in opt and gets word counts.
Inputs:
count_inputs: If True, include both input and reply when counting words
and utterances. Otherwise, only include reply text.
Returns:
word_counter_per_sent: a Counter mapping each word to the number of
utterances in which it appears.
num_sents: int. number of utterances counted
"""
# NOTE(review): excerpt is truncated and its indentation is lost; the loop
# exit, the handling of `count_inputs`, the update of `num_sents` and the
# return statement are not in the visible lines.
# Create repeat label agent and assign it to the specified task
agent = RepeatLabelAgent(opt)
world = create_task(opt, agent)
# Count word frequency for all words in dataset
word_counter_per_sent = Counter()
num_sents = 0
count = 0
log_timer = TimeLogger()
while True:
count += 1
world.parley()
# Use the first entry of 'labels', falling back to 'eval_labels' when
# 'labels' is absent from the first act.
reply = world.acts[0].get('labels', world.acts[0].get('eval_labels'))[0]
words = reply.split()
words_no_dups = list(set(words)) # remove duplicates
# Each distinct word is counted once per reply.
word_counter_per_sent.update(words_no_dups)
# NOTE(review): excerpt of an interactive script with indentation lost;
# `opt` and `print_parser` are defined outside the visible lines.
# Accept a parser for backwards compatibility: warn, then parse it into opt.
if isinstance(opt, ParlaiParser):
print('[ Deprecated Warning: interactive should be passed opt not Parser ]')
opt = opt.parse_args()
# The task is set to the local human agent.
opt['task'] = 'parlai.agents.local_human.local_human:LocalHumanAgent'
# Create model and assign it to the specified task
agent = create_agent(opt, requireModelExists=True)
world = create_task(opt, agent)
if print_parser:
# Show arguments after loading model
print_parser.opt = agent.opt
print_parser.print_args()
# Create ConvAI2 data so we can assign personas.
# A copy of opt is used so the 'convai2:both' task override does not replace
# the task set on `opt` above.
convai2_opt = opt.copy()
convai2_opt['task'] = 'convai2:both'
convai2_agent = RepeatLabelAgent(convai2_opt)
convai2_world = create_task(convai2_opt, convai2_agent)
def get_new_personas():
# NOTE(review): excerpt is truncated and its indentation is lost; uses
# `convai2_world` from the enclosing scope. `bot_persona` is initialised
# but not updated or returned in the visible lines.
# Find a new episode
while True:
convai2_world.parley()
msg = convai2_world.get_acts()[0]
# Once a message flagged 'episode_done' is seen, step one more time and
# keep that next message (presumably the start of the next episode -- confirm).
if msg['episode_done']:
convai2_world.parley()
msg = convai2_world.get_acts()[0]
break
txt = msg.get('text', '').split('\n')
bot_persona = ""
for t in txt:
# Lines describing the partner's persona are printed with "partner's "
# rewritten to 'your '.
if t.startswith("partner's persona:"):
print(t.replace("partner's ", 'your '))
def load_personas(self):
    """Create the Light dialog world used to assign personas.

    A copy of ``self.opt`` is pointed at the ``'light_dialog'`` task with
    ``'interactive_task'`` set to False. The resulting world is stored on
    ``self.light_world`` and ``self.cnt`` is reset to 0.
    """
    persona_opt = self.opt.copy()
    persona_opt['task'] = 'light_dialog'
    persona_opt['interactive_task'] = False
    # A RepeatLabelAgent is built from the same options and paired with the task.
    self.light_world = create_task(persona_opt, RepeatLabelAgent(persona_opt))
    self.cnt = 0
def bucket_data(opt):
# NOTE(review): excerpt is truncated and its indentation is lost; the body of
# the final loop and any return value are not in the visible lines.
# create repeat label agent and assign it to the specified task
agent = RepeatLabelAgent(opt)
world = create_task(opt, agent)
# A 'num_examples' of -1 means: use every example the world reports.
if opt['num_examples'] == -1:
num_examples = world.num_examples()
else:
num_examples = opt['num_examples']
log_timer = TimeLogger()
# A non-empty 'control' option is required.
assert opt['control'] != ''
ctrl = opt['control']
num_buckets = opt['num_buckets']
ctrl_vals = [] # list of floats
for _ in range(num_examples):
def display_data(opt):
# create repeat label agent and assign it to the specified task
agent = RepeatLabelAgent(opt)
world = create_task(opt, agent)
# Show some example dialogs.
for _ in range(opt['num_examples']):
world.parley()
# NOTE: If you want to look at the data from here rather than calling
# world.display() you could access world.acts[0] directly
print(world.display() + '\n~~')
if world.epoch_done():
print('EPOCH DONE')
break
try:
# print dataset size if available