# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def __init__(self, cl_path, dataset_path, image_size, batch_size, shuffle=True):
    """Store the loader settings, resolve the class labels and build the dataset.

    Args:
        cl_path: Handed to ``Dataset.get_class_labels`` to obtain the labels.
        dataset_path: Kept as ``self.dataset_path``.
        image_size: Kept as ``self.image_size``.
        batch_size: Kept as ``self.batch_size``.
        shuffle: Kept as ``self.shuffle``. Use shuffle only with train,
            not with test.
    """
    self.dataset_path = dataset_path
    self.image_size = image_size
    self.batch_size = batch_size
    self.shuffle = shuffle
    self.class_labels = Dataset.get_class_labels(cl_path)

    # The class count has to be a concrete value before the dataset is
    # built, so the length of the label tensor is evaluated eagerly in a
    # throw-away session.  If this step causes errors, comment out the
    # session block and assign ``num_classes`` manually from outside.
    label_count = tf.shape(self.class_labels)[0]
    with tf.Session() as sess:
        self.num_classes = sess.run(label_count)

    self.data = self.get_dataset()
@staticmethod
def get_train_test_dataset(cl_path, dataset_path, image_size, batch_size):
    """Build the train and test ``Dataset`` objects for one dataset root.

    Declared as a static method because it takes no ``self``/``cls``; this
    keeps ``Dataset.get_train_test_dataset(...)`` working and also makes the
    call valid on an instance.

    Args:
        cl_path: Forwarded unchanged to both ``Dataset`` constructors.
        dataset_path: Root path; ``'/train'`` and ``'/test'`` are appended
            to it to locate the two splits.
        image_size: Forwarded unchanged to both ``Dataset`` constructors.
        batch_size: Forwarded unchanged to both ``Dataset`` constructors.

    Returns:
        A ``(train, test)`` tuple. The train dataset is created with
        ``shuffle=True`` and the test dataset with ``shuffle=False``.
    """
    train = Dataset(cl_path, dataset_path + '/train', image_size, batch_size, shuffle=True)
    test = Dataset(cl_path, dataset_path + '/test', image_size, batch_size, shuffle=False)
    return train, test
def get_image_and_class(self, image, classl):
    """Convert one encoded JPEG and its label into an (image, one-hot) pair.

    Args:
        image: Encoded JPEG contents, decoded here with three channels.
        classl: Label value that is compared against ``self.class_labels``.

    Returns:
        ``(image, classl)``: the image decoded, resized with padding to
        ``self.image_size[0]`` x ``self.image_size[1]``, cast to float32 and
        run through ``Dataset.preprocess_image``; and the label one-hot
        encoded over ``self.num_classes`` entries.
    """
    # Locate the label inside class_labels: equality mask -> 0/1 ints ->
    # index of the largest entry -> one-hot vector.
    # NOTE(review): a label missing from class_labels gives an all-zero
    # mask; confirm which index argmax yields in that case.
    label_mask = tf.cast(tf.math.equal(self.class_labels, classl), tf.int32)
    label_index = tf.argmax(label_mask, axis=-1)
    one_hot_label = tf.one_hot(label_index, self.num_classes)

    decoded = tf.image.decode_jpeg(image, channels=3)
    # image_size[0] / image_size[1] are passed as target height / width.
    padded = tf.image.resize_image_with_pad(
        decoded, self.image_size[0], self.image_size[1]
    )
    pixels = Dataset.preprocess_image(tf.cast(padded, tf.float32))
    return pixels, one_hot_label