How to use the parlai.agents.repeat_label.repeat_label.RepeatLabelAgent class in parlai

To help you get started, we’ve selected a few RepeatLabelAgent examples based on popular ways it is used in public projects.

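As the snippets below illustrate, the usual pattern is to build an opt dict with ParlaiParser, wrap the chosen task with a RepeatLabelAgent (a trivial agent that simply repeats each example's label), and iterate over the data with world.parley(). Here is a minimal sketch of that pattern, modeled on the display_data.py example further down; the 'babi:task10k:1' task name is only a placeholder, so substitute any installed ParlAI task.

from parlai.core.params import ParlaiParser
from parlai.core.worlds import create_task
from parlai.agents.repeat_label.repeat_label import RepeatLabelAgent

# Build an opt dict; '--task' may point at any installed ParlAI task
# ('babi:task10k:1' is just an illustrative choice).
parser = ParlaiParser()
opt = parser.parse_args(['--task', 'babi:task10k:1'])

# RepeatLabelAgent does no learning: it simply echoes each example's label,
# which makes it a convenient stand-in agent for stepping through a dataset.
agent = RepeatLabelAgent(opt)
world = create_task(opt, agent)

# Display a handful of examples.
for _ in range(10):
    world.parley()
    print(world.display() + '\n~~')
    if world.epoch_done():
        print('EPOCH DONE')
        break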

facebookresearch/ParlAI: tests/test_loader.py
def test_load_agent(self):
        agent_module = load_agent_module(OPTIONS['agent'])
        self.assertEqual(agent_module, RepeatLabelAgent)

facebookresearch/ParlAI: parlai/scripts/extract_image_feature.py
        bsz = 1
        try:
            import torch
            from torch.utils.data import DataLoader
        except ImportError:
            raise ImportError('Need to install Pytorch: go to pytorch.org')

        dataset = get_dataset_class(opt)(opt)
        pre_image_path, _ = os.path.split(dataset.image_path)
        image_path = os.path.join(pre_image_path, opt.get('image_mode'))
        images_built_file = image_path + '.built'

        if not os.path.exists(image_path) or not os.path.isfile(images_built_file):
            '''Image features have not been computed yet'''
            opt['num_load_threads'] = 20
            agent = RepeatLabelAgent(opt)
            if opt['task'] == 'pytorch_teacher':
                if opt.get('pytorch_teacher_task'):
                    opt['task'] = opt['pytorch_teacher_task']
                else:
                    opt['task'] = opt['pytorch_teacher_dataset']
            world = create_task(opt, agent)
            exs_seen = 0
            total_exs = world.num_examples()
            pbar = tqdm.tqdm(unit='ex', total=total_exs)
            print('[ Computing and Saving Image Features ]')
            while exs_seen < total_exs:
                world.parley()
                exs_seen += bsz
                pbar.update(bsz)
            pbar.close()
            print('[ Feature Computation Done ]')

facebookresearch/ParlAI: parlai/mturk/tasks/wizard_of_wikipedia/extract_and_save_personas.py
def extract_and_save(opt, save=True):
    agent = RepeatLabelAgent(opt)
    world = create_task(opt, agent)
    teacher = world.agents[0]

    personas_path = opt.get('personas_path')
    if not os.path.exists(personas_path):
        os.makedirs(personas_path)

    new_episode = True
    personas = []
    while not teacher.epoch_done():
        act = teacher.act()
        if new_episode:
            persona_text = act['text'].split('\n')[:-1]
            if opt.get('persona_type') == 'both':
                persona_1 = [p for p in persona_text if 'your persona:' in p]
                persona_2 = [p for p in persona_text if 'partner\'s persona:' in p]

facebookresearch/ParlAI: parlai/scripts/extract_image_feature.py
        print('[ Deprecated Warning: extract_feats should be passed opt not Parser ]')
        opt = opt.parse_args()
    # Get command line arguments
    opt = copy.deepcopy(opt)
    dt = opt['datatype'].split(':')[0] + ':ordered'
    opt['datatype'] = dt
    bsz = opt.get('batchsize', 1)
    opt['no_cuda'] = False
    opt['gpu'] = 0
    opt['num_epochs'] = 1
    opt['use_hdf5'] = False
    opt['num_load_threads'] = 20
    print("[ Loading Images ]")
    # create repeat label agent and assign it to the specified task
    if opt.get('pytorch_teacher_dataset') is None:
        agent = RepeatLabelAgent(opt)
        world = create_task(opt, agent)

        total_exs = world.num_examples()
        # TODO: wrap in a tqdm
        while not world.epoch_done():
            world.parley()
    elif opt.get('use_hdf5_extraction', False):
        '''One can specify a Pytorch Dataset for custom image loading'''
        nw = opt.get('numworkers', 1)
        im = opt.get('image_mode', 'raw')
        opt['batchsize'] = 1
        opt['extract_image'] = True
        bsz = 1
        try:
            import torch
            from torch.utils.data import DataLoader

facebookresearch/ParlAI: projects/controllable_dialogue/controllable_seq2seq/nidf.py
def get_word_counts(opt, count_inputs):
    """Goes through the dataset specified in opt and gets word counts.

    Inputs:
      count_inputs: If True, include both input and reply when counting words
        and utterances. Otherwise, only include reply text.

    Returns:
      word_counter_per_sent: a Counter mapping each word to the number of
        utterances in which it appears.
      num_sents: int. number of utterances counted
    """
    # Create repeat label agent and assign it to the specified task
    agent = RepeatLabelAgent(opt)
    world = create_task(opt, agent)

    # Count word frequency for all words in dataset
    word_counter_per_sent = Counter()
    num_sents = 0
    count = 0
    log_timer = TimeLogger()
    while True:
        count += 1

        world.parley()
        reply = world.acts[0].get('labels', world.acts[0].get('eval_labels'))[0]

        words = reply.split()
        words_no_dups = list(set(words))  # remove duplicates
        word_counter_per_sent.update(words_no_dups)

facebookresearch/ParlAI: projects/convai2/interactive.py
    if isinstance(opt, ParlaiParser):
        print('[ Deprecated Warning: interactive should be passed opt not Parser ]')
        opt = opt.parse_args()
    opt['task'] = 'parlai.agents.local_human.local_human:LocalHumanAgent'
    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)
    world = create_task(opt, agent)
    if print_parser:
        # Show arguments after loading model
        print_parser.opt = agent.opt
        print_parser.print_args()

    # Create ConvAI2 data so we can assign personas.
    convai2_opt = opt.copy()
    convai2_opt['task'] = 'convai2:both'
    convai2_agent = RepeatLabelAgent(convai2_opt)
    convai2_world = create_task(convai2_opt, convai2_agent)

    def get_new_personas():
        # Find a new episode
        while True:
            convai2_world.parley()
            msg = convai2_world.get_acts()[0]
            if msg['episode_done']:
                convai2_world.parley()
                msg = convai2_world.get_acts()[0]
                break
        txt = msg.get('text', '').split('\n')
        bot_persona = ""
        for t in txt:
            if t.startswith("partner's persona:"):
                print(t.replace("partner's ", 'your '))

facebookresearch/ParlAI: parlai/tasks/light_dialog/worlds.py
def load_personas(self):
        # Create Light data so we can assign personas.
        light_opt = self.opt.copy()
        light_opt['task'] = 'light_dialog'
        light_opt['interactive_task'] = False
        light_agent = RepeatLabelAgent(light_opt)
        self.light_world = create_task(light_opt, light_agent)
        self.cnt = 0

facebookresearch/ParlAI: projects/controllable_dialogue/get_bucket_lowerbounds.py
def bucket_data(opt):
    # create repeat label agent and assign it to the specified task
    agent = RepeatLabelAgent(opt)
    world = create_task(opt, agent)

    if opt['num_examples'] == -1:
        num_examples = world.num_examples()
    else:
        num_examples = opt['num_examples']
    log_timer = TimeLogger()

    assert opt['control'] != ''
    ctrl = opt['control']

    num_buckets = opt['num_buckets']

    ctrl_vals = []  # list of floats

    for _ in range(num_examples):

facebookresearch/ParlAI: parlai/scripts/display_data.py
def display_data(opt):
    # create repeat label agent and assign it to the specified task
    agent = RepeatLabelAgent(opt)
    world = create_task(opt, agent)

    # Show some example dialogs.
    for _ in range(opt['num_examples']):
        world.parley()

        # NOTE: If you want to look at the data from here rather than calling
        # world.display() you could access world.acts[0] directly
        print(world.display() + '\n~~')

        if world.epoch_done():
            print('EPOCH DONE')
            break

    try:
        # print dataset size if available