Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# Positional indices used by the test harness; the containers they index
# are built elsewhere.
SHOULD_RUN_INDEX, TEST_INFO_INDEX = 2, 3

# Index of the value list followed by the four run-mode slots (0..4).
VALUES_INDEX, SERIAL_MODE_INDEX, PARALLEL_MODE_INDEX, LOCAL_MODE_INDEX, S3_MODE_INDEX = range(5)

# Indices of the train/test scripts and their argument entries (0..3).
TRAIN_SCRIPT_INDEX, TRAIN_SCRIPT_ARGS_INDEX, TEST_SCRIPT_INDEX, TEST_SCRIPT_ARGS_INDEX = range(4)

# S3 bucket name for integration tests.
INTEGRATION_TEST_S3_BUCKET = "tornasolecodebuildtest"

logger = get_logger()
def create_if_not_exists(path):
    """Create the directory ``path`` unless something already exists there.

    Only the last path component is created; a missing parent directory still
    raises ``FileNotFoundError``. If ``path`` already exists (as a directory or
    any other file), it is left untouched.

    Existence check and creation happen in a single ``os.mkdir`` call, so a
    concurrent creator of the same path can no longer make this raise
    ``FileExistsError`` (the previous check-then-create had that race).
    """
    try:
        os.mkdir(path)
    except FileExistsError:
        # Something already occupies the path: nothing to do.
        pass
async def del_folder(bucket, keys):
    """Delete the given object keys from an S3 bucket using aioboto3.

    Despite the name, this deletes exactly the keys passed in. It does not
    list or expand a prefix.

    :param bucket: name of the S3 bucket.
    :param keys: iterable of object keys to delete.
    :raises: the first exception raised by any ``delete_object`` call.
    """
    loop = asyncio.get_event_loop()
    client = aioboto3.client("s3", loop=loop)
    try:
        # One delete_object request per key, all awaited concurrently.
        await asyncio.gather(*[client.delete_object(Bucket=bucket, Key=key) for key in keys])
    finally:
        # Always release the client, even when a delete fails; previously a
        # failing delete skipped close() and leaked the client's connections.
        await client.close()
# Store the path to the config file and the test mode for testing a rule script with a training script
def get_json_config_as_dict(json_config_path) -> Dict:
    """Load a JSON config file and return its parsed contents.

    Path resolution order: ``json_config_path`` when it is not None; otherwise
    the environment variable named by ``CONFIG_FILE_PATH_ENV_STR``; otherwise
    ``DEFAULT_CONFIG_FILE_PATH``.

    Raises FileNotFoundError if the resolved path does not exist.
    """
    path = (
        json_config_path
        if json_config_path is not None
        else os.getenv(CONFIG_FILE_PATH_ENV_STR, DEFAULT_CONFIG_FILE_PATH)
    )
    with open(path) as config_fp:
        config = json.load(config_fp)
    get_logger().info(f"Creating hook from json_config at {path}.")
    return config
# Third Party
from botocore.exceptions import ClientError
# First Party
from smdebug.core.access_layer.s3handler import DeleteRequest, ListRequest, S3Handler
from smdebug.core.logger import get_logger
from smdebug.core.sagemaker_utils import is_sagemaker_job
from smdebug.core.utils import is_s3
# Local
from .file import TSAccessFile
from .s3 import TSAccessS3
logger = get_logger()

# File name joined onto a trial prefix by training_has_ended to build the
# end-of-job marker path.
END_OF_JOB_FILENAME = "training_job_end.ts"
def training_has_ended(trial_prefix):
    """Handle the end of training for the trial rooted at ``trial_prefix``.

    Logs and returns immediately when ``is_sagemaker_job()`` is true.
    Otherwise builds the end-of-job marker path (``END_OF_JOB_FILENAME``
    joined onto ``trial_prefix``).

    NOTE(review): this definition appears truncated in this view. ``file_path``
    is computed on the last visible line but not used here. The code that
    writes the marker presumably follows; confirm against the full file.
    """
    # Emit the end of training file only if the job is not running under SageMaker.
    if is_sagemaker_job():
        logger.info(
            f"The end of training job file will not be written for jobs running under SageMaker."
        )
        return
    try:
        check_dir_exists(trial_prefix)
        # if path does not exist, then we don't need to write a file
        # NOTE(review): nothing here returns early, so execution falls through
        # to build file_path whether or not check_dir_exists raises. Confirm
        # whether an early return was intended.
    except RuntimeError:
        # dir exists
        # (presumably check_dir_exists signals an existing directory by
        # raising RuntimeError; verify against its definition)
        pass
    file_path = os.path.join(trial_prefix, END_OF_JOB_FILENAME)
def __init__(self, base_trial, other_trials=None):
    """Record the base trial and any additional trials, then set up helpers.

    :param base_trial: primary trial; always the first entry of ``self.trials``.
    :param other_trials: optional iterable of further trials, appended to
        ``self.trials`` after ``base_trial`` in iteration order.
    """
    self.base_trial = base_trial
    self.other_trials = other_trials
    # Flat list of every trial, base first.
    self.trials = [base_trial] if other_trials is None else [base_trial, *other_trials]
    self.req_tensors = RequiredTensors(self.base_trial, self.other_trials)
    self.logger = get_logger()
    # The concrete subclass name doubles as the rule's display name.
    self.rule_name = self.__class__.__name__
def __init__(self, path, mode):
    """Remember ``path``/``mode`` and open the file through ``self.open``.

    For a mode in ``WRITE_MODES``, the file opened is a temporary path derived
    from ``path`` by ``get_temp_path`` (kept in ``self.temp_path``). Any other
    mode opens ``path`` directly. ``ensure_dir`` is applied to each path
    before use.
    """
    super().__init__()
    self.path = path
    self.mode = mode
    self.logger = get_logger()
    ensure_dir(path)
    if mode not in WRITE_MODES:
        # Non-write modes operate on the real path.
        self.open(self.path, mode)
        return
    # Write modes go through a temporary path first.
    self.temp_path = get_temp_path(self.path)
    ensure_dir(self.temp_path)
    self.open(self.temp_path, mode)
# Standard Library
from enum import Enum
# Third Party
import tensorflow as tf
from tensorflow.python.distribute import values
# First Party
from smdebug.core.logger import get_logger
logger = get_logger()


def get_tf_names(arg):
    """Return the list of TensorFlow names that identify ``arg``.

    - ``values.MirroredVariable``: one name per component variable, read from
      the private ``arg._values`` attribute.
    - ``tf.Variable`` / ``tf.Tensor``: the single-element list ``[arg.name]``.

    The MirroredVariable check runs first because it is the more specific
    type. In TensorFlow releases where distributed variables subclass
    ``tf.Variable``, checking ``tf.Variable`` first made the per-replica branch
    unreachable.

    :raises NotImplementedError: for any other type.
    """
    if isinstance(arg, values.MirroredVariable):
        return [v.name for v in arg._values]
    if isinstance(arg, (tf.Variable, tf.Tensor)):
        return [arg.name]
    raise NotImplementedError(f"get_tf_names does not support objects of type {type(arg)}")
class TensorType(Enum):
    """Enum of tensor categories.

    NOTE(review): only REGULAR is visible here; the definition may continue
    beyond this view. Confirm the full member list against the file.
    """

    REGULAR = 1
# Third Party
import numpy as np
# First Party
from smdebug.core.logger import get_logger
# Local
from .proto.tensor_pb2 import TensorProto
from .proto.tensor_shape_pb2 import TensorShapeProto
logger = get_logger()
# hash value of ndarray.dtype is not the same as np.float class
# so we need to convert the type classes below to np.dtype object
# Maps numpy dtypes to TensorProto dtype names (strings). NOTE(review): this
# literal appears truncated in this view (no closing brace is visible).
_NP_DATATYPE_TO_PROTO_DATATYPE = {
    # NOTE(review): float16 maps to "DT_INT32" while every other entry maps to
    # its matching DT_ name. Confirm whether this is deliberate (e.g. half values
    # carried as int32 bits) or whether "DT_HALF" was intended.
    np.dtype(np.float16): "DT_INT32",
    np.dtype(np.float32): "DT_FLOAT",
    np.dtype(np.float64): "DT_DOUBLE",
    np.dtype(np.int32): "DT_INT32",
    np.dtype(np.int64): "DT_INT64",
    np.dtype(np.uint8): "DT_UINT8",
    np.dtype(np.uint16): "DT_UINT16",
    np.dtype(np.uint32): "DT_UINT32",
    np.dtype(np.uint64): "DT_UINT64",
    np.dtype(np.int8): "DT_INT8",
    np.dtype(np.int16): "DT_INT16",
    np.dtype(np.complex64): "DT_COMPLEX64",
import time
# First Party
from smdebug.core.config_constants import (
CHECKPOINT_CONFIG_FILE_PATH_ENV_VAR,
CHECKPOINT_DIR_KEY,
DEFAULT_CHECKPOINT_CONFIG_FILE,
LATEST_GLOBAL_STEP_SAVED,
LATEST_GLOBAL_STEP_SEEN,
LATEST_MODE_STEP,
METADATA_FILENAME,
TRAINING_RUN,
)
from smdebug.core.logger import get_logger
logger = get_logger()


def _rule_for_sorting(state):
    """Sort key for saved states: the value stored under LATEST_GLOBAL_STEP_SEEN."""
    seen_step = state[LATEST_GLOBAL_STEP_SEEN]
    return seen_step
class StateStore:
    """Holds saved checkpoint states and the path of the file they live in.

    NOTE(review): this class is truncated in this view. The ``max(`` call on
    the last visible line is unclosed, and only the start of ``__init__`` is
    shown.
    """

    def __init__(self):
        # List of saved states; presumably filled by _read_states_file
        # (defined elsewhere). Verify.
        self._saved_states = []
        self._checkpoint_update_timestamp = 0
        self._states_file = None
        self._checkpoint_dir = None
        # Expected to set self._checkpoint_dir when a checkpoint location is
        # configured (defined elsewhere). Verify.
        self._retrieve_path_to_checkpoint()
        if self._checkpoint_dir is not None:
            # The states file is METADATA_FILENAME inside the checkpoint dir.
            self._states_file = os.path.join(self._checkpoint_dir, METADATA_FILENAME)
            self._read_states_file()
            self._checkpoint_update_timestamp = max(
# First Party
from smdebug.core.logger import get_logger
from smdebug.exceptions import (
RuleEvaluationConditionMet,
StepUnavailable,
TensorUnavailable,
TensorUnavailableForStep,
)
logger = get_logger()


def invoke_rule(rule_obj, start_step=0, end_step=None, raise_eval_cond=False):
    """Call ``rule_obj.invoke`` on consecutive steps.

    Begins at ``start_step`` (``None`` is treated as 0) and stops before
    ``end_step``. With ``end_step=None`` there is no upper bound: the loop ends
    only when ``invoke`` raises an exception not handled below.

    TensorUnavailableForStep, StepUnavailable and TensorUnavailable are logged
    at debug level and the step is skipped. RuleEvaluationConditionMet is
    re-raised when ``raise_eval_cond`` is true; otherwise it is logged at
    debug level.
    """
    current_step = 0 if start_step is None else start_step
    logger.info(
        "Started execution of rule {} at step {}".format(type(rule_obj).__name__, current_step)
    )
    while True:
        if end_step is not None and not current_step < end_step:
            break
        try:
            rule_obj.invoke(current_step)
        except (TensorUnavailableForStep, StepUnavailable, TensorUnavailable) as exc:
            logger.debug(str(exc))
        except RuleEvaluationConditionMet as exc:
            if raise_eval_cond:
                raise exc
            logger.debug(str(exc))
        current_step += 1
"""
Easy-to-use methods for getting the singleton SessionHook.
Sample usage:
import smdebug.(pytorch | tensorflow | mxnet) as smd
hook = smd.hook()
"""
# Standard Library
import atexit
# First Party
from smdebug.core.logger import get_logger
logger = get_logger()
# Module-level slot for the singleton hook. NOTE(review): presumably assigned
# by set_hook / read by a getter defined elsewhere in this module. Verify.
_ts_hook = None


def _create_hook(json_config_path, hook_class):
    """Build a ``hook_class`` hook from a JSON config and register it via set_hook.

    :param json_config_path: forwarded as ``json_file_path`` to
        ``hook_class.create_from_json_file``.
    :param hook_class: hook class to build; must be a subclass of BaseHook.
    :raises TypeError: if ``hook_class`` is not a subclass of BaseHook.

    If ``create_from_json_file`` raises FileNotFoundError, the error is
    swallowed and no hook is registered. Nothing is returned in the visible
    code; the definition may continue past this view.
    """
    from smdebug.core.hook import BaseHook # prevent circular imports
    if not issubclass(hook_class, BaseHook):
        raise TypeError("hook_class needs to be a subclass of BaseHook", hook_class)
    # Either returns a hook or None
    # NOTE(review): the hook is handed to set_hook rather than returned; this
    # note presumably describes create_from_json_file. Confirm.
    try:
        hook = hook_class.create_from_json_file(json_file_path=json_config_path)
        set_hook(custom_hook=hook)
    except FileNotFoundError:
        # Missing config file: deliberately best-effort, leave no hook registered.
        pass