Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def eval(self, data_set, eval_tensor_names=(), eval_ph_names=(), num_batches=None):
    """Set up an evaluation pass of this tower over ``data_set``.

    NOTE(review): only the set-up portion of the method is visible here.
    Several locals below (``progress``, ``num_iters``, the counters,
    ``eval_values``/``idxs``, ``N``, ``eval_args``) are computed but not
    used in these lines. The batch loop and the return value presumably
    follow elsewhere; confirm before relying on this docstring.

    Args:
        data_set: the ``DataSet`` to evaluate on (checked with ``isinstance``).
        eval_tensor_names: names of extra tensors to evaluate. Not used in
            the visible lines.
        eval_ph_names: not supported yet (see the TODO below).
        num_batches: number of batches to evaluate. When falsy (``None`` or
            0) it defaults to ``data_set.get_num_batches(partial=True)``.

    Raises:
        AssertionError: if ``data_set`` is not a ``DataSet`` or the tower
            has not been initialized.
    """
    # TODO : eval_ph_names
    assert isinstance(data_set, DataSet)
    # NOTE(review): the message mentions "training" although this is eval.
    # It is kept as-is because it is a runtime string.
    assert self.initialized, "Initialize tower before training."
    params = self.params
    sess = self.sess
    # Evaluate the 'epoch' tensor in the session to get its current value.
    epoch_op = self.tensors['epoch']
    epoch = sess.run(epoch_op)
    progress = params.progress
    # `or` also replaces an explicit 0 with the data set's own batch count.
    num_batches = num_batches or data_set.get_num_batches(partial=True)
    # Round up so that leftover batches (num_batches % num_towers) still get
    # an iteration; presumably each iteration feeds one batch per tower.
    # NOTE(review): relies on true division. Under Python 2 without
    # `from __future__ import division` the quotient would floor first.
    num_iters = int(np.ceil(num_batches / self.num_towers))
    num_corrects, total, total_loss = 0, 0, 0.0
    eval_values = []
    idxs = []
    # Upper bound on examples covered (batch_size * num_batches), capped at
    # the size of the data set because the last batch may be partial.
    N = data_set.batch_size * num_batches
    if N > data_set.num_examples:
        N = data_set.num_examples
    eval_args = self._get_eval_args(epoch)
def train(self, train_data_set, num_epochs, val_data_set=None, eval_ph_names=(),
          eval_tensor_names=(), num_batches=None, val_num_batches=None):
    """Set up a training run of ``num_epochs`` epochs on ``train_data_set``.

    NOTE(review): only the set-up portion of the method is visible here.
    Locals such as ``writer``, ``progress``, ``val_acc``, ``num_digits`` and
    ``epoch`` are assigned but not used in these lines. The epoch loop (and
    any use of ``val_data_set`` / ``val_num_batches``) presumably follows
    elsewhere; confirm.

    Args:
        train_data_set: the ``DataSet`` to train on (checked with ``isinstance``).
        num_epochs: number of epochs; echoed in the start-up message.
        val_data_set, eval_ph_names, eval_tensor_names, val_num_batches:
            not used in the visible lines.
        num_batches: batches per epoch. When falsy (``None`` or 0) it
            defaults to ``train_data_set.get_num_batches(partial=False)``.

    Raises:
        AssertionError: if ``train_data_set`` is not a ``DataSet`` or the
            tower has not been initialized.
    """
    assert isinstance(train_data_set, DataSet)
    assert self.initialized, "Initialize tower before training."
    sess = self.sess
    writer = self.writer
    params = self.params
    progress = params.progress
    val_acc = None
    # if num batches is specified, then train only that many
    num_batches = num_batches or train_data_set.get_num_batches(partial=False)
    # Truncates: when num_batches is not a multiple of num_towers, the
    # remaining num_batches % num_towers batches get no iteration.
    num_iters_per_epoch = int(num_batches / self.num_towers)
    # NOTE(review): int(log10(n)) is one less than n's decimal digit count
    # (e.g. n=100 -> 2). If num_batches could ever be 0, log10 returns
    # -inf and int() raises OverflowError. Confirm against how it is used.
    num_digits = int(np.log10(num_batches))
    # Evaluate the 'epoch' tensor in the session to get its current value.
    epoch_op = self.tensors['epoch']
    epoch = sess.run(epoch_op)
    print("training %d epochs ... " % num_epochs)
    logging.info("num iters per epoch: %d" % num_iters_per_epoch)
def read_data(params, mode, task):
    """Load one split of a task's data from disk and wrap it in a ``DataSet``.

    The files are read from
    ``<params.data_dir>/<params.lang>[-10k]/<task zero-padded to 2>/``. The
    ``-10k`` suffix is appended when ``params.large`` is truthy. That
    directory must contain:

    * ``mode2idxs.json``: a JSON object mapping each mode name to the
      indices that belong to it.
    * ``data.json``: the data passed unchanged to ``DataSet``.

    Args:
        params: configuration object providing ``lang``, ``large``,
            ``data_dir`` and ``batch_size``.
        mode: key looked up in ``mode2idxs.json``.
        task: task identifier, as a ``str`` or an ``int``; it is
            zero-padded to two characters to form the directory name.

    Returns:
        ``DataSet(mode, params.batch_size, data, idxs)``.

    Raises:
        OSError: if either JSON file cannot be opened.
        KeyError: if ``mode`` is not a key of ``mode2idxs.json``.
    """
    logging.info("loading {} data for task {}... ".format(mode, task))
    mid = params.lang + ("-10k" if params.large else "")
    # str() lets integer task ids work too; it is a no-op for str input.
    task_dir = os.path.join(params.data_dir, mid, str(task).zfill(2))
    batch_size = params.batch_size
    mode2idxs_path = os.path.join(task_dir, "mode2idxs.json")
    data_path = os.path.join(task_dir, "data.json")
    # Context managers close the handles deterministically; the original
    # json.load(open(...)) leaked them. JSON is UTF-8 by spec, so do not
    # depend on the platform's locale encoding.
    with open(mode2idxs_path, 'r', encoding='utf-8') as f:
        mode2idxs_dict = json.load(f)
    with open(data_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    idxs = mode2idxs_dict[mode]
    return DataSet(mode, batch_size, data, idxs)