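# Imports assumed for the snippets below; smdebug module paths may differ
# slightly between versions.
import os
import socket
import uuid
from datetime import datetime
from pathlib import Path

import numpy as np

from smdebug import modes
from smdebug.core.collection import Collection
from smdebug.core.collection_manager import CollectionManager
from smdebug.core.config_constants import DEFAULT_COLLECTIONS_FILE_NAME
from smdebug.core.utils import get_path_to_collections, is_s3
from smdebug.core.writer import FileWriter
from smdebug.trials import create_trial
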
def test_mode_data():
    run_id = "trial_" + datetime.now().strftime("%Y%m%d-%H%M%S%f")
    trial_dir = "/tmp/ts_outputs/" + run_id

    c = CollectionManager()
    c.add("default")
    c.get("default").tensor_names = ["arr"]
    c.export(trial_dir, DEFAULT_COLLECTIONS_FILE_NAME)
    tr = create_trial(trial_dir)
    worker = socket.gethostname()
    for s in range(0, 10):
        fw = FileWriter(trial_dir=trial_dir, step=s, worker=worker)
        if s % 2 == 0:
            fw.write_tensor(
                tdata=np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32),
                tname="arr",
                mode=modes.TRAIN,
                mode_step=s // 2,
            )
        else:
            # Assumed completion: odd steps mirror the TRAIN branch under
            # modes.EVAL, so each mode sees mode_steps 0 through 4.
            fw.write_tensor(
                tdata=np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32),
                tname="arr",
                mode=modes.EVAL,
                mode_step=s // 2,
            )
        # Close the writer so each step's events are flushed to disk.
        fw.close()
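
# Hedged sketch (not part of the original test): reading the mode-wise data
# back through the smdebug Trial API; steps() and value() follow the analysis
# docs and may differ slightly across versions.
def read_back_mode_data(trial_dir):
    tr = create_trial(trial_dir)
    # Five even global steps went to TRAIN, five odd ones to EVAL.
    assert len(tr.tensor("arr").steps(mode=modes.TRAIN)) == 5
    assert len(tr.tensor("arr").steps(mode=modes.EVAL)) == 5
    return tr.tensor("arr").value(0, mode=modes.TRAIN)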

def test_manager_export_load():
    cm = CollectionManager()
    cm.create_collection("default")
    cm.get("default").include("loss")
    cm.add(Collection("trial1"))
    cm.add("trial2")
    cm.get("trial2").include("total_loss")
    cm.export("/tmp/dummy_trial", DEFAULT_COLLECTIONS_FILE_NAME)
    cm2 = CollectionManager.load(
        os.path.join(get_path_to_collections("/tmp/dummy_trial"), DEFAULT_COLLECTIONS_FILE_NAME)
    )
    assert cm == cm2

def dummy_trial_creator(trial_dir, num_workers, job_ended):
    Path(trial_dir).mkdir(parents=True, exist_ok=True)
    cm = CollectionManager()
    for i in range(num_workers):
        collection_file_name = f"worker_{i}_collections.json"
        cm.export(trial_dir, collection_file_name)
    if job_ended:
        Path(os.path.join(trial_dir, "training_job_end.ts")).touch()
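
# Minimal usage sketch for dummy_trial_creator (the path is illustrative):
# it lays down per-worker collection files plus training_job_end.ts, the
# sentinel that trial readers use to detect a finished job.
dummy_trial_creator(trial_dir="/tmp/dummy_ended_trial", num_workers=2, job_ended=True)
assert os.path.exists("/tmp/dummy_ended_trial/training_job_end.ts")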

def generate_data(
    # Leading parameters inferred from the body and the call sites below.
    path,
    trial,
    step,
    tname_prefix,
    num_tensors,
    worker,
    shape,
    dtype=np.float32,
    rank=None,
    mode=None,
    mode_step=None,
    export_colls=True,
    data=None,
):
    with FileWriter(trial_dir=os.path.join(path, trial), step=step, worker=worker) as fw:
        for i in range(num_tensors):
            if data is None:
                data = np.ones(shape=shape, dtype=dtype) * step
            fw.write_tensor(tdata=data, tname=f"{tname_prefix}_{i}", mode=mode, mode_step=mode_step)
    if export_colls:
        c = CollectionManager()
        c.add("default")
        c.get("default").tensor_names = [f"{tname_prefix}_{i}" for i in range(num_tensors)]
        c.add("gradients")
        c.get("gradients").tensor_names = [f"{tname_prefix}_{i}" for i in range(num_tensors)]
        c.export(os.path.join(path, trial), DEFAULT_COLLECTIONS_FILE_NAME)

def help_test_multiple_trials(num_steps=20, num_tensors=10):
    trial_name = str(uuid.uuid4())
    bucket = "smdebug-testing"
    path = "s3://" + os.path.join(bucket, "outputs/")

    c = CollectionManager()
    c.add("default")
    c.get("default").tensor_names = ["foo_" + str(i) for i in range(num_tensors)]
    c.export(path + trial_name, DEFAULT_COLLECTIONS_FILE_NAME)
    for i in range(num_steps):
        generate_data(
            path=path,
            trial=trial_name,
            num_tensors=num_tensors,
            step=i,
            tname_prefix="foo",
            worker="algo-1",
            shape=(3, 3, 3),
            rank=0,
        )
    _, bucket, prefix = is_s3(os.path.join(path, trial_name))
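
# For reference: is_s3 splits an s3:// URI into (flag, bucket, prefix), e.g.
# is_s3("s3://smdebug-testing/outputs/abc") should yield
# (True, "smdebug-testing", "outputs/abc"), and a False flag for local paths.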

def gen_y_and_y_hat(path, trial, step, y, y_name, y_hat, y_hat_name, colls={}):
    trial_dir = os.path.join(path, trial)
    with FileWriter(trial_dir=trial_dir, step=step, worker="algo-1") as fw:
        fw.write_tensor(tdata=y, tname=y_name)
        fw.write_tensor(tdata=y_hat, tname=y_hat_name)
    c = CollectionManager()
    for coll in colls:
        c.add(coll)
        c.get(coll).tensor_names = colls[coll]
    c.export(trial_dir, DEFAULT_COLLECTIONS_FILE_NAME)
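
# Illustrative driver for gen_y_and_y_hat (names, shapes, and the "labels"
# collection here are assumptions, not from the original tests):
y = np.random.rand(16, 10).astype(np.float32)
y_hat = np.random.rand(16, 10).astype(np.float32)
gen_y_and_y_hat(
    path="/tmp/ts_outputs",
    trial="y_trial",
    step=0,
    y=y,
    y_name="y",
    y_hat=y_hat,
    y_hat_name="y_hat",
    colls={"labels": ["y", "y_hat"]},
)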

def __eq__(self, other):
    if not isinstance(other, CollectionManager):
        return NotImplemented
    return self.collections == other.collections

def add_tensor(self, tensor):
    # Assumed enclosing method: registers the tensor under its export name.
    self.add_tensor_name(tensor.export_name)

def has_tensor(self, name):
    # tf object name
    return name in self._tensors

def add_keras_layer(self, layer, inputs=False, outputs=True):
    if inputs:
        input_tensor_regex = layer.name + "/inputs/"
        self.include(input_tensor_regex)
    if outputs:
        output_tensor_regex = layer.name + "/outputs/"
        self.include(output_tensor_regex)
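
# Hedged sketch of add_keras_layer in use, assuming smdebug's TF Keras hook
# API (smd.KerasHook, hook.get_collection); the model and layer name are
# illustrative.
import tensorflow as tf
import smdebug.tensorflow as smd

model = tf.keras.Sequential([tf.keras.layers.Dense(4, name="dense_1", input_shape=(8,))])
hook = smd.KerasHook(out_dir="/tmp/layer_io_trial")
coll = hook.get_collection("layer_io")
# This registers "dense_1/inputs/" and "dense_1/outputs/" as include regexes.
coll.add_keras_layer(model.get_layer("dense_1"), inputs=True, outputs=True)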

class CollectionManager(BaseCollectionManager):
    def __init__(self, collections=None, create_default=True):
        super().__init__(collections=collections)
        if create_default:
            for n in [
                CollectionKeys.DEFAULT,
                CollectionKeys.WEIGHTS,
                CollectionKeys.BIASES,
                CollectionKeys.GRADIENTS,
                CollectionKeys.LOSSES,
                CollectionKeys.METRICS,
                CollectionKeys.INPUTS,
                CollectionKeys.OUTPUTS,
                CollectionKeys.ALL,
                CollectionKeys.SM_METRICS,
            ]:
                self.create_collection(n)
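
# Because create_default is True by default, a fresh manager already holds the
# standard collections, so lookups like these need no prior setup:
cm = CollectionManager()
cm.get(CollectionKeys.LOSSES).include("loss")
cm.get(CollectionKeys.GRADIENTS).include("gradients/.*")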

def _read_collections(self, collection_files):
    first_collection_file = collection_files[0]  # first collection file
    key = os.path.join(first_collection_file)
    collections_req = ReadObjectRequest(self._get_s3_location(key))
    obj_data = S3Handler.get_objects([collections_req])[0]
    obj_data = obj_data.decode("utf-8")
    self.collection_manager = CollectionManager.load_from_string(obj_data)
    self.num_workers = self.collection_manager.get_num_workers()