Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def _compute_experiment_from_runs(self):
    """Builds a minimal Experiment protocol buffer from the scanned runs.

    Returns:
      An api_pb2.Experiment carrying the hparam infos and metric infos
      computed by this object, or None when no hparam infos were found.
    """
    hparams = self._compute_hparam_infos()
    if hparams:
        # Metric infos are only computed once at least one hparam info is
        # known; without hparams there is no experiment to report.
        metrics = self._compute_metric_infos()
        return api_pb2.Experiment(hparam_infos=hparams, metric_infos=metrics)
    return None
def _write_session_end():
    """Writes a STATUS_SUCCESS session-end summary, then drops the writer.

    The summary is serialized and written as a raw protocol buffer at step 0
    using the module-level `_writer` as the default writer. Afterwards
    `_writer` is reset to None.
    """
    global _writer
    with _writer.as_default():
        session_end = summary.session_end_pb(api_pb2.STATUS_SUCCESS)
        tf.summary.experimental.write_raw_pb(
            session_end.SerializeToString(), step=0
        )
    _writer = None
Args:
session: api_pb2.Session. The session to add.
start_info: The SessionStartInfo protobuffer associated with the session.
groups_by_name: A str to SessionGroup protobuffer dict. Representing the
session groups and sessions found so far.
"""
# If the group_name is empty, this session's group contains only
# this session. Use the session name for the group name since session
# names are unique.
group_name = start_info.group_name or session.name
if group_name in groups_by_name:
groups_by_name[group_name].sessions.extend([session])
else:
# Create the group and add the session as the first one.
group = api_pb2.SessionGroup(
name=group_name,
sessions=[session],
monitor_url=start_info.monitor_url,
)
# Copy hparams from the first session (all sessions should have the same
# hyperparameter values) into result.
# There doesn't seem to be a way to initialize a protobuffer map in the
# constructor.
for (key, value) in six.iteritems(start_info.hparams):
group.hparams[key].CopyFrom(value)
groups_by_name[group_name] = group
def as_proto(self):
    """Builds an api_pb2.MetricInfo from this object's fields.

    The metric name is assembled from the stored group and tag; the display
    name, description and dataset type are copied over unchanged.
    """
    metric_name = api_pb2.MetricName(group=self._group, tag=self._tag)
    return api_pb2.MetricInfo(
        name=metric_name,
        display_name=self._display_name,
        description=self._description,
        dataset_type=self._dataset_type,
    )
def _protobuf_value_type(value):
    """Maps a google.protobuf.Value message to its api_pb2.DataType.

    Args:
      value: google.protobuf.Value message.

    Returns:
      DATA_TYPE_FLOAT64, DATA_TYPE_STRING or DATA_TYPE_BOOL depending on
      which of 'number_value', 'string_value' or 'bool_value' is set in
      'value'; None if none of those fields is set.
    """
    # Checked in this order; the first field that is set decides the type.
    field_to_data_type = (
        ("number_value", api_pb2.DATA_TYPE_FLOAT64),
        ("string_value", api_pb2.DATA_TYPE_STRING),
        ("bool_value", api_pb2.DATA_TYPE_BOOL),
    )
    for field_name, data_type in field_to_data_type:
        if value.HasField(field_name):
            return data_type
    return None
def update_hparam_info(self, hparam_info):
    """Stores this object's data type and discrete values in 'hparam_info'.

    Args:
      hparam_info: message updated in place. Its 'type' is set from
        self._dtype, and its 'domain_discrete' field is cleared and then
        filled with self._values.

    Raises:
      KeyError: if self._dtype is not one of int, float, bool or str.
    """
    data_type_by_dtype = {
        int: api_pb2.DATA_TYPE_FLOAT64,  # TODO(#1998): Add int dtype.
        float: api_pb2.DATA_TYPE_FLOAT64,
        bool: api_pb2.DATA_TYPE_BOOL,
        str: api_pb2.DATA_TYPE_STRING,
    }
    hparam_info.type = data_type_by_dtype[self._dtype]
    # Clear the field first so the values replace, rather than append to,
    # whatever was stored there before.
    hparam_info.ClearField("domain_discrete")
    hparam_info.domain_discrete.extend(self._values)
"""
# Figure out the type from the values.
# Ignore values whose type is not listed in api_pb2.DataType
# If all values have the same type, then that is the type used.
# Otherwise, the returned type is DATA_TYPE_STRING.
result = api_pb2.HParamInfo(name=name, type=api_pb2.DATA_TYPE_UNSET)
distinct_values = set(
_protobuf_value_to_string(v)
for v in values
if _protobuf_value_type(v)
)
for v in values:
v_type = _protobuf_value_type(v)
if not v_type:
continue
if result.type == api_pb2.DATA_TYPE_UNSET:
result.type = v_type
elif result.type != v_type:
result.type = api_pb2.DATA_TYPE_STRING
if result.type == api_pb2.DATA_TYPE_STRING:
# A string result.type does not change, so we can exit the loop.
break
# If we couldn't figure out a type, then we can't compute the hparam_info.
if result.type == api_pb2.DATA_TYPE_UNSET:
return None
# If the result is a string, set the domain to be the distinct values if
# there aren't too many of them.
if (
result.type == api_pb2.DATA_TYPE_STRING
and len(distinct_values) <= self._max_domain_discrete_len
def _sort(self, session_groups):
    """Sorts 'session_groups' in place according to _request.col_params.

    Args:
      session_groups: list of session groups; reordered in place.
    """
    # Start from a by-name ordering so the final result is deterministic.
    session_groups.sort(key=operator.attrgetter("name"))

    # list.sort() is stable, so sorting by the least significant column
    # first and the most significant column last produces a lexicographic
    # ordering over all columns with a specified order. Hence the columns
    # are visited in reverse: the primary key is the last sort applied.
    columns = list(zip(self._request.col_params, self._extractors))
    for col_param, extractor in reversed(columns):
        order = col_param.order
        if order == api_pb2.ORDER_ASC:
            descending = False
        elif order == api_pb2.ORDER_DESC:
            descending = True
        else:
            # ORDER_UNSPECIFIED (or any other value): this column does not
            # take part in the sort.
            continue

        missing_first = col_param.missing_values_first
        # A reversed sort puts the largest keys first, so the flag passed
        # to the key function flips between the two directions.
        none_is_largest = missing_first if descending else not missing_first
        session_groups.sort(
            key=_create_key_func(extractor, none_is_largest=none_is_largest),
            reverse=descending,
        )
def run_tag_from_session_and_metric(session_name, metric_name):
    """Maps a session and a metric to the (run, tag) pair storing its
    evaluations.

    Args:
      session_name: str.
      metric_name: MetricName protobuffer.

    Returns:
      (run, tag) tuple. The run is the session name joined with the metric's
      group and normalized; the tag is the metric's tag.
    """
    assert isinstance(session_name, six.string_types)
    assert isinstance(metric_name, api_pb2.MetricName)
    # Joining with an empty group leaves a trailing separator, and it seems
    # multiplexer.Tensors won't recognize paths ending in '/'. normpath()
    # strips that trailing separator.
    joined = os.path.join(session_name, metric_name.group)
    return os.path.normpath(joined), metric_name.tag
def create_experiment_summary():
"""Returns a summary proto buffer holding this experiment."""
# Convert TEMPERATURE_LIST to google.protobuf.ListValue
temperature_list = struct_pb2.ListValue()
temperature_list.extend(TEMPERATURE_LIST)
materials = struct_pb2.ListValue()
materials.extend(HEAT_COEFFICIENTS.keys())
return summary.experiment_pb(
hparam_infos=[
api_pb2.HParamInfo(
name="initial_temperature",
display_name="Initial temperature",
type=api_pb2.DATA_TYPE_FLOAT64,
domain_discrete=temperature_list,
),
api_pb2.HParamInfo(
name="ambient_temperature",
display_name="Ambient temperature",
type=api_pb2.DATA_TYPE_FLOAT64,
domain_discrete=temperature_list,
),
api_pb2.HParamInfo(
name="material",
display_name="Material",
type=api_pb2.DATA_TYPE_STRING,
domain_discrete=materials,