How to use the babi.read_data.DataSet class in babi

To help you get started, we’ve selected a few babi examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: uwnlp/qrn — babi/base_model.py (view on GitHub)
def eval(self, data_set, eval_tensor_names=(), eval_ph_names=(), num_batches=None):
        """Prepare an evaluation pass of the model over ``data_set``.

        NOTE(review): this snippet is truncated right after ``eval_args`` is
        computed; the per-batch loop and the return value are not visible here.

        :param data_set: the ``DataSet`` to evaluate on (checked by assert).
        :param eval_tensor_names: names of extra tensors to evaluate --
            presumably looked up in ``self.tensors``; TODO confirm.
        :param eval_ph_names: accepted but not used in the visible code (see TODO).
        :param num_batches: how many batches to evaluate. Any falsy value
            (``None`` *or* ``0``) falls back to
            ``data_set.get_num_batches(partial=True)``, which presumably also
            counts a trailing partial batch.
        """
        # TODO : eval_ph_names (parameter is accepted but not yet honoured)
        assert isinstance(data_set, DataSet)
        # NOTE(review): the message says "training" even though this is eval,
        # and asserts are stripped under ``python -O``.
        assert self.initialized, "Initialize tower before training."

        params = self.params
        sess = self.sess
        # Fetch the current epoch value by running the 'epoch' tensor in the session.
        epoch_op = self.tensors['epoch']
        epoch = sess.run(epoch_op)
        progress = params.progress
        # ``or`` means an explicit num_batches=0 is treated the same as None.
        num_batches = num_batches or data_set.get_num_batches(partial=True)
        # Round up so every batch is covered when num_batches is not a multiple
        # of num_towers (presumably one batch per tower per iteration).
        # NOTE(review): relies on true division (Python 3); under Python 2
        # integer division would make the ceil a no-op.
        num_iters = int(np.ceil(num_batches / self.num_towers))
        # Accumulators -- presumably filled by the (truncated) batch loop.
        num_corrects, total, total_loss = 0, 0, 0.0
        eval_values = []
        idxs = []
        # N = min(batch_size * num_batches, num_examples): the batch-count
        # product can exceed the dataset size when the last batch is partial.
        N = data_set.batch_size * num_batches
        if N > data_set.num_examples:
            N = data_set.num_examples
        eval_args = self._get_eval_args(epoch)
GitHub: uwnlp/qrn — babi/base_model.py (view on GitHub)
def train(self, train_data_set, num_epochs, val_data_set=None, eval_ph_names=(),
              eval_tensor_names=(), num_batches=None, val_num_batches=None):
        """Prepare and start training of the model for ``num_epochs`` epochs.

        NOTE(review): this snippet is truncated after the first log line; the
        epoch loop, any validation step and the return value are not visible.

        :param train_data_set: the ``DataSet`` to train on (checked by assert).
        :param num_epochs: number of epochs to train for.
        :param val_data_set: optional validation ``DataSet`` -- its use is not
            visible in this snippet.
        :param eval_ph_names: not used in the visible code.
        :param eval_tensor_names: not used in the visible code.
        :param num_batches: batches per epoch. Any falsy value (``None`` *or*
            ``0``) falls back to ``get_num_batches(partial=False)``, which
            presumably excludes a trailing incomplete batch.
        :param val_num_batches: not used in the visible code.
        """
        assert isinstance(train_data_set, DataSet)
        assert self.initialized, "Initialize tower before training."

        sess = self.sess
        # writer, progress and val_acc are not used in the visible snippet;
        # presumably consumed by the truncated training loop.
        writer = self.writer
        params = self.params
        progress = params.progress
        val_acc = None
        # if num batches is specified, then train only that many
        num_batches = num_batches or train_data_set.get_num_batches(partial=False)
        # int() truncates, so leftover batches that cannot fill every tower are
        # not counted in the per-epoch iteration count.
        num_iters_per_epoch = int(num_batches / self.num_towers)
        # NOTE(review): floor(log10(n)) is one less than the decimal digit count
        # of n, and if get_num_batches() returns 0 this becomes int(-inf), which
        # raises OverflowError -- confirm intended use.
        num_digits = int(np.log10(num_batches))

        # Fetch the current epoch value by running the 'epoch' tensor in the session.
        epoch_op = self.tensors['epoch']
        epoch = sess.run(epoch_op)
        print("training %d epochs ... " % num_epochs)
        logging.info("num iters per epoch: %d" % num_iters_per_epoch)
GitHub: uwnlp/qrn — babi/read_data.py (view on GitHub)
def read_data(params, mode, task):
    """Load one split of a bAbI task and wrap it in a ``DataSet``.

    The task directory is ``<params.data_dir>/<params.lang>[-10k]/<task>``,
    where ``-10k`` is appended when ``params.large`` is truthy and ``task`` is
    zero-padded to two characters. Inside it, ``data.json`` holds the data and
    ``mode2idxs.json`` maps each split name to its example indices.

    :param params: config object; reads ``lang``, ``large``, ``data_dir`` and
        ``batch_size``.
    :param mode: split name, used as a key into ``mode2idxs.json``.
    :param task: task id, either a string or an int; it is converted with
        ``str`` and zero-padded, so ``"1"`` and ``1`` both map to ``01``.
    :return: ``DataSet(mode, params.batch_size, data, idxs)``.
    :raises OSError: if either JSON file cannot be opened.
    :raises KeyError: if ``mode`` is not a key of ``mode2idxs.json``.
    """
    logging.info("loading %s data for task %s... ", mode, task)
    mid = params.lang + ("-10k" if params.large else "")
    # str() lets callers pass an int task id; zfill pads to the directory name.
    task_dir = os.path.join(params.data_dir, mid, str(task).zfill(2))
    batch_size = params.batch_size

    mode2idxs_path = os.path.join(task_dir, "mode2idxs.json")
    data_path = os.path.join(task_dir, "data.json")
    # Context managers close the handles deterministically instead of leaving
    # them to the garbage collector; JSON text is UTF-8 by specification.
    with open(mode2idxs_path, 'r', encoding='utf-8') as f:
        mode2idxs_dict = json.load(f)
    with open(data_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    idxs = mode2idxs_dict[mode]
    return DataSet(mode, batch_size, data, idxs)