How to use the `parlai.core.worlds.create_task` function in ParlAI

To help you get started, we've selected a few ParlAI examples based on popular ways `create_task` is used in public projects.

Secure your code as it's written: use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: facebookresearch/ParlAI — tests/tasks/test_wizard_of_wikipedia.py (view on GitHub)
def run_display_test(self, kwargs):
    """Build the task from ``kwargs``, run ``display`` and check its load banner.

    Everything from argument parsing through ``display(opt)`` runs with stdout
    redirected into an in-memory buffer. The test then asserts that the
    captured text contains the ``[ loaded N episodes with a total of M
    examples ]`` line, where N and M come from the world that ``create_task``
    built for the same options.

    Args:
        kwargs: option overrides applied via ``parser.set_defaults``.
    """
    buffer = io.StringIO()
    with redirect_stdout(buffer):
        arg_parser = setup_args()
        arg_parser.set_defaults(**kwargs)
        opt = arg_parser.parse_args()
        world = create_task(opt, RepeatLabelAgent(opt))
        display(opt)

    captured = buffer.getvalue()
    expected_banner = '[ loaded {} episodes with a total of {} examples ]'.format(
        world.num_episodes(), world.num_examples()
    )
    failure_msg = 'Wizard of Wikipedia failed with following args: {}'.format(opt)
    self.assertTrue(expected_banner in captured, failure_msg)
GitHub: facebookresearch/ParlAI — tests/test_pytorch_data_teacher.py (view on GitHub)
def get_acts_epochs_1_and_2(defaults):
            parser.set_defaults(**defaults)
            opt = parser.parse_args()
            build_dict(opt)
            agent = create_agent(opt)
            world_data = create_task(opt, agent)
            acts_epoch_1 = []
            acts_epoch_2 = []
            while not world_data.epoch_done():
                world_data.parley()
                acts_epoch_1.append(world_data.acts[0])
            world_data.reset()
            while not world_data.epoch_done():
                world_data.parley()
                acts_epoch_2.append(world_data.acts[0])
            acts_epoch_1 = [bb for b in acts_epoch_1 for bb in b]
            acts_epoch_1 = sorted(
                [b for b in acts_epoch_1 if 'text' in b], key=lambda x: x.get('text')
            )
            acts_epoch_2 = [bb for b in acts_epoch_2 for bb in b]
            acts_epoch_2 = sorted(
                [b for b in acts_epoch_2 if 'text' in b], key=lambda x: x.get('text')
GitHub: facebookresearch/ParlAI — projects/controllable_dialogue/controllable_seq2seq/arora.py (view on GitHub)
def get_word_counts(opt, count_inputs):
    """Goes through the dataset specified in opt, returns word counts and all utterances

    Inputs:
      count_inputs: If True, include both input and reply when counting words and
        utterances. Otherwise, only include reply text.

    Returns:
      word_counter: a Counter mapping each word to the total number of times it appears
      total_count: int. total word count, i.e. the sum of the counts for each word
      all_utts: list of strings. all the utterances that were used for counting words
    """
    # Create repeat label agent and assign it to the specified task
    agent = RepeatLabelAgent(opt)
    world = create_task(opt, agent)

    # Count word frequency for all words in dataset
    word_counter = Counter()
    total_count = 0
    all_utts = []
    log_timer = TimeLogger()
    while True:
        world.parley()

        # Count words in reply
        reply = world.acts[0].get('labels', world.acts[0].get('eval_labels'))[0]
        words = reply.split()
        word_counter.update(words)
        total_count += len(words)
        all_utts.append(reply)
GitHub: facebookresearch/ParlAI — parlai/tasks/convai2/worlds.py (view on GitHub)
def load_personas(opt):
    print('[ loading personas.. ]')
    # Create ConvAI2 data so we can assign personas.
    convai2_opt = opt.copy()
    convai2_opt['task'] = 'convai2:both'
    if convai2_opt['datatype'].startswith('train'):
        convai2_opt['datatype'] = 'train:evalmode'
    convai2_opt['interactive_task'] = False
    convai2_agent = RepeatLabelAgent(convai2_opt)
    convai2_world = create_task(convai2_opt, convai2_agent)
    personas = set()
    while not convai2_world.epoch_done():
        convai2_world.parley()
        msg = convai2_world.get_acts()[0]
        # Find a new episode
        if msg.get('episode_done', False) and not convai2_world.epoch_done():
            convai2_world.parley()
            msg = convai2_world.get_acts()[0]
            txt = msg.get('text', '').split('\n')
            a1_persona = ""
            a2_persona = ""
            for t in txt:
                if t.startswith("partner's persona:"):
                    a1_persona += (
                        t.replace("partner's persona:", 'your persona:') + '\n'
                    )
GitHub: facebookresearch/ParlAI — parlai/scripts/data_stats.py (view on GitHub)
def verify(opt, printargs=None, print_parser=None):
    if opt['datatype'] == 'train':
        print("[ note: changing datatype from train to train:ordered ]")
        opt['datatype'] = 'train:ordered'

    # create repeat label agent and assign it to the specified task
    agent = RepeatLabelAgent(opt)
    world = create_task(opt, agent)

    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()

    dictionary = DictionaryAgent(opt)
    ignore_tokens = opt.get('ignore_tokens').split(',')

    counts = {}
    for t in {'input', 'labels', 'both'}:
        counts['tokens_in_' + t] = 0
        counts['utterances_in_' + t] = 0
        counts['avg_utterance_length_in_' + t] = 0
        counts['unique_tokens_in_' + t] = 0
        counts['unique_utterances_in_' + t] = 0
GitHub: facebookresearch/ParlAI — projects/mastering_the_dungeon/projects/graph_world2/train.py (view on GitHub)
def main(opt, return_full=False, out_file=None):
    data_agent = prepro(opt)
    model_agent = (
        ObjectChecklistModelAgent(opt, data_agent=data_agent)
        if not opt['seq2seq']
        else Seq2SeqModelAgent(opt, data_agent=data_agent)
    )

    train_world = create_task(opt, model_agent)

    max_dict = model_agent.model.state_dict()

    max_acc, max_f1, max_data, last_max, max_acc_len = -1, 0, None, 0, None
    for iter in range(opt['max_iter']):
        if iter - last_max > opt['exit_iter']:
            break

        if 'inc_ratio' in opt and opt['inc_ratio'] > 0 and iter == opt['inc_pre_iters']:
            print('resetting best model for finetuning')
            model_agent.model.load_state_dict(max_dict)
            max_acc = 0

        train_world.parley()
        train_report = train_world.report()
GitHub: deepmipt/kpi2017 — parlai_tasks/base_train.py (view on GitHub)
def main():
    # Get command line arguments
    parser = ParlaiParser(True, True)
    parser.add_argument('-n', '--num-examples', default=10)
    opt = parser.parse_args()

    agent = Agent(opt)

    opt['datatype'] = 'train'
    world_train = create_task(opt, agent)

    opt['datatype'] = 'valid'
    world_valid = create_task(opt, agent)

    start = time.time()
    # train / valid loop
    for _ in range(1):
        print('[ training ]')
        for _ in range(10):  # train for a bit
            world_train.parley()

        print('[ training summary. ]')
        print(world_train.report())

        print('[ validating ]')
        for _ in range(1):  # check valid accuracy
GitHub: WattSocialBot/alana_learning_to_rank — train_convai2.py (view on GitHub)
build_dict.setup_args(parser)
    opt = parser.parse_args()

    opt['dict_file'] = os.path.join(opt['model_folder'], 'vocab')
    dictionary = build_dict.build_dict(opt)

    agents = []
    for _ in range(opt['num_agents']):
        agents.append(LearningToRankAgent(opt))

    opt['datatype'] = 'train'
    world_train = create_task(opt, agents)

    opt['datatype'] = 'valid'
    world_valid = create_task(opt, agents)

    start = time.time()
    # train / valid loop
    for _ in range(1):
        print('[ training ]')
        for _ in range(opt['num_iters']):  # train for a bit
            world_train.parley()

        print('[ training summary. ]')
        print(world_train.report())

        print('[ validating ]')
        for _ in range(1):  # check valid accuracy
            world_valid.parley()

        print('[ validation summary. ]')
GitHub: facebookresearch/ParlAI — projects/mastering_the_dungeon/projects/graph_world2/train.py (view on GitHub)
def validate(opt, agent):
    """Run the validation split through ``agent`` and return the world's report.

    A deep copy of ``opt`` is switched to ``datatype='valid'`` with
    ``terminate=True`` and ``batchsize=1`` before the task world is built.
    ``agent.opt['datatype']`` is temporarily set to ``'valid'`` for the run.

    Args:
        opt: task options. The caller's dict is deep-copied, not mutated.
        agent: the model agent to evaluate. Its ``opt['datatype']`` is
            restored to the previous value before returning, including when
            an exception propagates.

    Returns:
        The value of ``valid_world.report()`` after iterating the world.
    """
    old_datatype = agent.opt['datatype']
    agent.opt['datatype'] = 'valid'
    try:
        # Copy after switching the agent, matching the original ordering
        # (matters if ``opt`` happens to be the same object as ``agent.opt``).
        opt = deepcopy(opt)
        opt['datatype'] = 'valid'
        opt['terminate'] = True
        opt['batchsize'] = 1

        # Silence output printed while the task world is built. The devnull
        # handle is closed afterwards (it previously leaked), and sys.stdout
        # is restored even if create_task raises.
        # NOTE(review): if anything created here caches the sys.stdout object
        # and writes to it later, it would now hit a closed file — confirm.
        with open(os.devnull, 'w') as devnull:
            old_stdout = sys.stdout
            sys.stdout = devnull
            try:
                valid_world = create_task(opt, agent)
            finally:
                sys.stdout = old_stdout

        # Presumably the world's iterator stops once the validation epoch is
        # exhausted (terminate=True) — TODO confirm against the World class.
        for _ in valid_world:
            valid_world.parley()

        return valid_world.report()
    finally:
        # Always put the agent back in its previous mode, even on error.
        agent.opt['datatype'] = old_datatype
GitHub: facebookresearch/ParlAI — parlai/scripts/convert_data_to_parlai_format.py (view on GitHub)
def dump_data(opt):
    # create repeat label agent and assign it to the specified task
    agent = RepeatLabelAgent(opt)
    world = create_task(opt, agent)
    ignorefields = opt.get('ignore_fields', '')
    if opt['outfile'] is None:
        outfile = tempfile.mkstemp(
            prefix='{}_{}_'.format(opt['task'], opt['datatype']), suffix='.txt'
        )[1]
    else:
        outfile = opt['outfile']

    if opt['num_examples'] == -1:
        num_examples = world.num_examples()
    else:
        num_examples = opt['num_examples']
    log_timer = TimeLogger()

    print('[ starting to convert.. ]')
    print('[ saving output to {} ]'.format(outfile))