Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_pickle(self):
    """
    Round-trip a TensorFlow classifier through `pickle` and check that the restored instance matches the original.

    The classifier is written to `DATA_PATH/my_classifier` (creating the folder if necessary), read back, and
    compared on `_clip_values`, `_channel_index` and the set of instance attribute names.
    """
    classifier, sess = get_classifier_tf()
    full_path = os.path.join(DATA_PATH, 'my_classifier')
    folder = os.path.split(full_path)[0]
    if not os.path.exists(folder):
        os.makedirs(folder)

    # Write through a context manager so the file is flushed and closed before it is re-opened for reading.
    # An unmanaged `open()` leaves closing (and therefore flushing) to the garbage collector, which is not
    # immediate on every Python implementation.
    with open(full_path, 'wb') as f:
        pickle.dump(classifier, f)

    # Unpickle:
    with open(full_path, 'rb') as f:
        loaded = pickle.load(f)

    self.assertEqual(classifier._clip_values, loaded._clip_values)
    self.assertEqual(classifier._channel_index, loaded._channel_index)
    self.assertEqual(set(classifier.__dict__.keys()), set(loaded.__dict__.keys()))

    # Test predict
    predictions_1 = classifier.predict(self.x_test)
def test_save_image(self):
    """
    Check that `save_image` writes files under `DATA_PATH` in four cases: a PNG name, a JPG name, a name inside a
    new sub-folder, and a name inside a nested sub-folder.

    Each written file is asserted to exist. The first three files (and the single sub-folder) are removed again.
    """
    (x, _), (_, _), _, _ = load_mnist(raw=True)
    # Plain PNG file name directly inside DATA_PATH.
    f_name = 'image1.png'
    save_image(x[0], f_name)
    path = os.path.join(DATA_PATH, f_name)
    self.assertTrue(os.path.isfile(path))
    os.remove(path)
    # Plain JPG file name directly inside DATA_PATH.
    f_name = 'image2.jpg'
    save_image(x[1], f_name)
    path = os.path.join(DATA_PATH, f_name)
    self.assertTrue(os.path.isfile(path))
    os.remove(path)
    # File inside a sub-folder that save_image has to create.
    folder = 'images123456'
    f_name_with_dir = os.path.join(folder, 'image3.png')
    save_image(x[3], f_name_with_dir)
    path = os.path.join(DATA_PATH, f_name_with_dir)
    self.assertTrue(os.path.isfile(path))
    os.remove(path)
    os.rmdir(os.path.split(path)[0]) # Remove also test folder
    # File inside a nested sub-folder ('images123456/inner'); both levels must be created.
    # NOTE(review): the visible code does not remove image4.png or the nested folders; confirm that cleanup
    # happens after this point, otherwise the test leaves files behind in DATA_PATH.
    folder = os.path.join('images123456', 'inner')
    f_name_with_dir = os.path.join(folder, 'image4.png')
    save_image(x[3], f_name_with_dir)
    path_nested = os.path.join(DATA_PATH, f_name_with_dir)
    self.assertTrue(os.path.isfile(path_nested))
def save_image(image_array, f_name):
    """
    Write an image array to a file located under `DATA_PATH`.

    Any missing folders in the path described by `f_name` are created first. PIL infers the output format from
    the file extension.

    :param image_array: Image data to write.
    :type image_array: `np.ndarray`
    :param f_name: File name including its extension, optionally prefixed by sub-folders, e.g. `my_img.jpg`,
                   `my_img.png` or `my_images/my_img.png`.
    :type f_name: `str`
    :return: `None`
    """
    target = os.path.join(DATA_PATH, f_name)

    parent = os.path.dirname(target)
    if not os.path.exists(parent):
        os.makedirs(parent)

    # Imported lazily so PIL is only required when an image is actually saved.
    from PIL import Image

    Image.fromarray(image_array).save(target)
    logger.info('Image saved to %s.', target)
# NOTE(review): excerpt starts mid-function. The enclosing definition, the `try` matched by `except ImportError`
# below, and the loop/branch the next line belongs to are not shown.
# Presumably builds default colours in matplotlib's 'CN' colour-cycle notation ('C0', 'C1', ...).
colors.append('C' + str(i))
else:
# Caller-supplied colours must provide exactly one colour per distinct label.
if len(colors) != len(np.unique(labels)):
# NOTE(review): '3pd' in this message looks like a typo for '3d'.
raise ValueError('The amount of provided colors should match the number of labels in the 3pd plot.')
fig = plt.figure()
axis = plt.axes(projection='3d')
# Plot each point individually, coloured by its label. Labels index directly into `colors`, so they are
# expected to start at 0 and be contiguous. An IndexError from `labels[i]` (fewer labels than points) or
# from `colors[...]` (label too large) is reported as ValueError.
# NOTE(review): a negative label would silently index from the end of `colors` instead of raising.
for i, coord in enumerate(points):
try:
color_point = labels[i]
axis.scatter3D(coord[0], coord[1], coord[2], color=colors[color_point])
except IndexError:
raise ValueError('Labels outside the range. Should start from zero and be sequential there after')
if save:
# Resolve the output file under DATA_PATH (realpath normalises '..' and symlinks), then create its folder.
file_name = os.path.realpath(os.path.join(DATA_PATH, f_name))
folder = os.path.split(file_name)[0]
if not os.path.exists(folder):
os.makedirs(folder)
fig.savefig(file_name, bbox_inches='tight')
logger.info('3d-plot saved to %s.', file_name)
return fig
# Reached when an import in the (not shown) `try` fails; the message indicates matplotlib is the dependency.
except ImportError:
logger.warning("matplotlib not installed. For this reason, cluster visualization was not displayed.")
def save(self, filename, path=None):
    """
    Persist the wrapped model to disk in the backend framework's native format (`.h5` for Keras).

    :param filename: Name of the file that will hold the model.
    :type filename: `str`
    :param path: Folder in which to store the model. When omitted, the library's default data location
                 `DATA_PATH` is used.
    :type path: `str`
    :return: None
    """
    import os

    if path is not None:
        target_dir = path
    else:
        from art import DATA_PATH
        target_dir = DATA_PATH
    full_path = os.path.join(target_dir, filename)

    # `filename` may itself contain sub-folders, so create whatever parent folder the final path needs.
    parent_dir = os.path.split(full_path)[0]
    if not os.path.exists(parent_dir):
        os.makedirs(parent_dir)

    self._model.save(str(full_path))
    logger.info('Model saved in path: %s.', full_path)
can also be extracted. This is a simplified version of the function with the same name in Keras.
:param filename: Name of the file.
:type filename: `str`
:param url: Download URL.
:type url: `str`
:param path: Folder to store the download (`~` is expanded). If not specified, `art.DATA_PATH` is used instead.
If the chosen folder is not writable, `/tmp/.art` is used.
:type path: `str`
:param extract: If true, tries to extract the archive.
:type extract: `bool`
:return: Path to the downloaded file.
:rtype: `str`
"""
# Resolve the download folder, expanding a leading '~'.
if path is None:
from art import DATA_PATH
path_ = os.path.expanduser(DATA_PATH)
else:
path_ = os.path.expanduser(path)
# Fall back to /tmp/.art when the chosen folder is not writable.
# NOTE(review): os.access() also returns False when path_ does not exist yet, so a folder that merely has not
# been created triggers this fallback too.
if not os.access(path_, os.W_OK):
path_ = os.path.join('/tmp', '.art')
if not os.path.exists(path_):
os.makedirs(path_)
# With extract=True the archive is stored as '<filename>.tar.gz'; '<filename>' is the extraction target.
if extract:
extract_path = os.path.join(path_, filename)
full_path = extract_path + '.tar.gz'
else:
full_path = os.path.join(path_, filename)
# Determine if dataset needs downloading
download = not os.path.exists(full_path)
# NOTE(review): excerpt ends here; the download/extraction logic is not shown.
def _unpickle_classifier(file_name):
    """
    Load a pickled classifier stored inside `art.DATA_PATH`.

    Security note: `pickle.load` can execute arbitrary code embedded in the file, so only use this on files
    produced by a trusted source.

    :param file_name: Name of the pickle file, relative to `art.DATA_PATH`.
    :type file_name: `str`
    :return: The unpickled classifier object.
    """
    from art import DATA_PATH
    import pickle

    classifier_path = os.path.join(DATA_PATH, file_name)
    logger.info('Loading classifier from %s', classifier_path)
    with open(classifier_path, 'rb') as handle:
        return pickle.load(handle)
def __setstate__(self, state):
    """
    Restore a `KerasClassifier` from its pickled state.

    The entries of `state` are copied onto the instance. The Keras model is then loaded from the file named by
    `state['model_name']` inside `DATA_PATH`, and the classifier's Keras-related parameters are re-initialised
    from the restored state.

    :param state: State dictionary with instance parameters to restore.
    :type state: `dict`
    """
    self.__dict__.update(state)

    # Reload the Keras model from DATA_PATH and rebuild everything derived from it.
    import os
    from art import DATA_PATH
    from keras.models import load_model

    model_file = str(os.path.join(DATA_PATH, state['model_name']))
    self._model = load_model(model_file)
    self._initialize_params(self._model, state['_use_logits'], state['_input_layer'], state['_output_layer'],
                            state['_custom_activation'])
# NOTE(review): excerpt starts mid-function; the opening of `file_` and the condition selecting the branch
# below (presumably a Python 2 vs 3 check) are not shown.
content = cPickle.load(file_)
else:
# encoding='bytes' makes the dictionary keys come back as bytes; they are decoded to str just below.
content = cPickle.load(file_, encoding='bytes')
content_decoded = {}
for key, value in content.items():
content_decoded[key.decode('utf8')] = value
content = content_decoded
data = content['data']
labels = content['labels']
# Reshape the flat rows to (N, 3, 32, 32), presumably channels-first colour images; this assumes each row
# holds exactly 3 * 32 * 32 values.
data = data.reshape(data.shape[0], 3, 32, 32)
return data, labels
# NOTE(review): excerpt starts inside a function whose header is not shown.
from art import DATA_PATH
# Fetch (if not already present) and extract the CIFAR-10 Python archive under DATA_PATH.
# NOTE(review): the archive is downloaded over plain http and its contents are presumably unpickled by
# load_batch; consider https and/or a checksum check.
path = get_file('cifar-10-batches-py', extract=True, path=DATA_PATH,
url='http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz')
# Five training batches ('data_batch_1' .. 'data_batch_5') of 10000 samples each fill the 50000-row arrays.
num_train_samples = 50000
x_train = np.zeros((num_train_samples, 3, 32, 32), dtype=np.uint8)
y_train = np.zeros((num_train_samples,), dtype=np.uint8)
for i in range(1, 6):
fpath = os.path.join(path, 'data_batch_' + str(i))
data, labels = load_batch(fpath)
x_train[(i - 1) * 10000: i * 10000, :, :, :] = data
y_train[(i - 1) * 10000: i * 10000] = labels
fpath = os.path.join(path, 'test_batch')
x_test, y_test = load_batch(fpath)
# Training labels become a column vector of shape (N, 1).
y_train = np.reshape(y_train, (len(y_train), 1))
# NOTE(review): excerpt ends here; any remaining post-processing is not shown.