Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from marshmallow import fields, post_load
from faculty.clients.base import BaseSchema, BaseClient
# Immutable record produced by DatasetsSecretsSchema: bucket name, access/secret
# key pair, region, and a ``verified`` flag.
DatasetsSecrets = namedtuple(
    "DatasetsSecrets", "bucket access_key secret_key region verified"
)
class DatasetsSecretsSchema(BaseSchema):
    """Load a datasets-secrets payload into :class:`DatasetsSecrets`.

    Every field is required on load.
    """

    bucket = fields.String(required=True)
    access_key = fields.String(required=True)
    secret_key = fields.String(required=True)
    region = fields.String(required=True)
    verified = fields.Boolean(required=True)

    @post_load
    def make_project_datasets_secrets(self, data, **kwargs):
        # marshmallow 3 invokes post-load hooks with ``many``/``partial``
        # keyword arguments; accept and ignore them so loading does not fail
        # with TypeError. Versions that pass only ``data`` still work.
        return DatasetsSecrets(**data)
class SecretClient(BaseClient):
_SERVICE_NAME = "secret-service"
def datasets_secrets(self, project_id):
"TagSort": _TagSortSchema,
"ParamSort": _ParamSortSchema,
"MetricSort": _MetricSortSchema,
}
class RunQuerySchema(BaseSchema):
    """Schema for a run query: a filter, a list of sorts and a page."""

    # Wrapped in the module's _OptionalField; presumably this lets the filter
    # be absent/None — confirm against _OptionalField's definition.
    filter = _OptionalField(fields.Nested(FilterSchema))
    sort = fields.List(fields.Nested(SortSchema))
    # When the key is missing on load, ``page`` defaults to None.
    page = fields.Nested(PageSchema, missing=None)
# Schemas for responses returned from API:
class DeleteExperimentRunsResponseSchema(BaseSchema):
    """Load a delete-runs response into ``DeleteExperimentRunsResponse``.

    Both lists are required and contain run UUIDs.
    """

    # JSON key ``deletedRunIds``.
    deleted_run_ids = fields.List(
        fields.UUID(), data_key="deletedRunIds", required=True
    )
    # JSON key ``conflictedRunIds``.
    conflicted_run_ids = fields.List(
        fields.UUID(), data_key="conflictedRunIds", required=True
    )

    @post_load
    def make_delete_runs_response(self, data, **kwargs):
        # Accept the ``many``/``partial`` keyword arguments marshmallow 3
        # passes to post-load hooks, so loading does not raise TypeError.
        return DeleteExperimentRunsResponse(**data)
class RestoreExperimentRunsResponseSchema(BaseSchema):
restored_run_ids = fields.List(
fields.UUID(), data_key="restoredRunIds", required=True
)
return ReportVersion(**data)
class ReportSchema(BaseSchema):
    """Load a report together with its active version into ``Report``."""

    created_at = fields.DateTime(required=True)
    # The payload uses ``report_name`` / ``report_id``; the loaded object
    # exposes them as ``name`` / ``id``.
    name = fields.String(required=True, data_key="report_name")
    id = fields.UUID(required=True, data_key="report_id")
    description = fields.String(required=True)
    active_version = fields.Nested(ReportVersionSchema, required=True)

    @post_load
    def make_report(self, data, **kwargs):
        # Accept the ``many``/``partial`` keyword arguments marshmallow 3
        # passes to post-load hooks, so loading does not raise TypeError.
        return Report(**data)
class ReportWithVersionsSchema(BaseSchema):
    """Load a report plus its list of versions into ``ReportWithVersions``."""

    created_at = fields.DateTime(required=True)
    # The payload uses ``report_name`` / ``report_id``; the loaded object
    # exposes them as ``name`` / ``id``.
    name = fields.String(required=True, data_key="report_name")
    id = fields.UUID(required=True, data_key="report_id")
    description = fields.String(required=True)
    # The active version is referenced by ID here, alongside ``versions``.
    active_version_id = fields.UUID(required=True)
    versions = fields.Nested(ReportVersionSchema, required=True, many=True)

    @post_load
    def make_report_with_versions(self, data, **kwargs):
        # Accept the ``many``/``partial`` keyword arguments marshmallow 3
        # passes to post-load hooks, so loading does not raise TypeError.
        return ReportWithVersions(**data)
class ReportClient(BaseClient):
_SERVICE_NAME = "tavern"
"""Field that serialises/deserialises an apt package version."""
def _deserialize(self, value, attr, obj, **kwargs):
    """Return "latest" as-is; load any other value with AptVersionSchema."""
    return "latest" if value == "latest" else AptVersionSchema().load(value)
def _serialize(self, value, attr, obj, **kwargs):
    """Return "latest" as-is; dump any other value with AptVersionSchema."""
    return "latest" if value == "latest" else AptVersionSchema().dump(value)
class PythonPackageSchema(BaseSchema):
    """Load a Python package (name and version) into ``PythonPackage``."""

    name = fields.String(required=True)
    version = PythonVersionField(required=True)

    @post_load
    def make_python_package(self, data, **kwargs):
        # Accept the ``many``/``partial`` keyword arguments marshmallow 3
        # passes to post-load hooks, so loading does not raise TypeError.
        return PythonPackage(**data)
class PipSchema(BaseSchema):
extra_index_urls = fields.List(
fields.String(), data_key="extraIndexUrls", required=True
)
packages = fields.List(fields.Nested(PythonPackageSchema()), required=True)
@post_load
def make_pip(self, data):
value = _FilterValueField(EnumField(ExperimentRunStatus, by_value=True))
by = fields.Constant("status", dump_only=True)
@pre_dump
def check_operator(self, obj, **kwargs):
    """Validate ``obj.operator`` with ``_validate_discrete`` before dumping.

    Returns ``obj`` unchanged. ``**kwargs`` absorbs the ``many`` keyword
    argument that marshmallow 3 passes to pre-dump hooks.
    """
    _validate_discrete(obj.operator)
    return obj
class _DeletedAtFilterSchema(BaseSchema):
    """Filter on ``deletedAt``; ``by`` is emitted as a constant on dump."""

    operator = EnumField(ComparisonOperator, by_value=True)
    # DateTime comparison value, wrapped by the module's _FilterValueField.
    value = _FilterValueField(fields.DateTime())
    by = fields.Constant("deletedAt", dump_only=True)
class _TagFilterSchema(BaseSchema):
    """Filter on a run tag selected by ``key``; ``by`` is always "tag"."""

    key = fields.String()
    operator = EnumField(ComparisonOperator, by_value=True)
    value = _FilterValueField(fields.String())
    by = fields.Constant("tag", dump_only=True)

    @pre_dump
    def check_operator(self, obj, **kwargs):
        # Validate the operator with ``_validate_discrete`` before dumping.
        # ``**kwargs`` absorbs the ``many`` argument marshmallow 3 passes.
        _validate_discrete(obj.operator)
        return obj
class _ParamFilterSchema(BaseSchema):
    """Filter on a run parameter selected by ``key``; ``by`` is "param"."""

    key = fields.String()
    operator = EnumField(ComparisonOperator, by_value=True)
    # Value handled by the module's _ParamFilterValueField (not shown here).
    value = _FilterValueField(_ParamFilterValueField())
    by = fields.Constant("param", dump_only=True)
"RunIdFilter": _RunIdFilterSchema,
"RunStatusFilter": _RunStatusFilterSchema,
"DeletedAtFilter": _DeletedAtFilterSchema,
"TagFilter": _TagFilterSchema,
"ParamFilter": _ParamFilterSchema,
"MetricFilter": _MetricFilterSchema,
"CompoundFilter": _CompoundFilterSchema,
}
class _StartedAtSortSchema(BaseSchema):
    """Sort on ``startedAt``; the sort order is dumped by enum value."""

    order = EnumField(SortOrder, by_value=True)
    by = fields.Constant("startedAt", dump_only=True)
class _RunNumberSortSchema(BaseSchema):
    """Sort on ``runNumber``; the sort order is dumped by enum value."""

    order = EnumField(SortOrder, by_value=True)
    by = fields.Constant("runNumber", dump_only=True)
class _DurationSortSchema(BaseSchema):
    """Sort on ``duration``; the sort order is dumped by enum value."""

    order = EnumField(SortOrder, by_value=True)
    by = fields.Constant("duration", dump_only=True)
class _TagSortSchema(BaseSchema):
    """Sort on the tag named by ``key``; ``by`` is always "tag"."""

    key = fields.String()
    order = EnumField(SortOrder, by_value=True)
    by = fields.Constant("tag", dump_only=True)
class _ParamSortSchema(BaseSchema):
# Schemas for payloads sent to API:
class ExperimentRunDataSchema(BaseSchema):
    """Payload with lists of run metrics, params and tags (none required)."""

    metrics = fields.List(fields.Nested(MetricSchema))
    params = fields.List(fields.Nested(ParamSchema))
    tags = fields.List(fields.Nested(TagSchema))
class ExperimentRunInfoSchema(BaseSchema):
    """Payload with a run's status and (optional) end time."""

    status = EnumField(ExperimentRunStatus, required=True, by_value=True)
    # ``endedAt`` may be omitted; it then loads as None.
    ended_at = fields.DateTime(missing=None, data_key="endedAt")
class ListExperimentRunsResponseSchema(BaseSchema):
    """Load runs plus pagination info into ``ListExperimentRunsResponse``."""

    pagination = fields.Nested(PaginationSchema, required=True)
    runs = fields.Nested(ExperimentRunSchema, many=True, required=True)

    @post_load
    def make_list_runs_response_schema(self, data, **kwargs):
        # Accept the ``many``/``partial`` keyword arguments marshmallow 3
        # passes to post-load hooks, so loading does not raise TypeError.
        return ListExperimentRunsResponse(**data)
class CreateRunSchema(BaseSchema):
    """Payload for creating a run; only ``tags`` is marked required."""

    name = fields.String()
    # camelCase JSON keys for the remaining optional fields.
    parent_run_id = fields.UUID(data_key="parentRunId")
    started_at = fields.DateTime(data_key="startedAt")
    artifact_location = fields.String(data_key="artifactLocation")
    tags = fields.Nested(TagSchema, required=True, many=True)
value = _FilterValueField(EnumField(ExperimentRunStatus, by_value=True))
by = fields.Constant("status", dump_only=True)
@pre_dump
def check_operator(self, obj, **kwargs):
    """Validate ``obj.operator`` with ``_validate_discrete`` before dumping.

    Returns ``obj`` unchanged. ``**kwargs`` absorbs the ``many`` keyword
    argument that marshmallow 3 passes to pre-dump hooks.
    """
    _validate_discrete(obj.operator)
    return obj
class _DeletedAtFilterSchema(BaseSchema):
    """Filter on ``deletedAt``; ``by`` is emitted as a constant on dump."""

    operator = EnumField(ComparisonOperator, by_value=True)
    # DateTime comparison value, wrapped by the module's _FilterValueField.
    value = _FilterValueField(fields.DateTime())
    by = fields.Constant("deletedAt", dump_only=True)
class _TagFilterSchema(BaseSchema):
    """Filter on a run tag selected by ``key``; ``by`` is always "tag"."""

    key = fields.String()
    operator = EnumField(ComparisonOperator, by_value=True)
    value = _FilterValueField(fields.String())
    by = fields.Constant("tag", dump_only=True)

    @pre_dump
    def check_operator(self, obj, **kwargs):
        # Validate the operator with ``_validate_discrete`` before dumping.
        # ``**kwargs`` absorbs the ``many`` argument marshmallow 3 passes.
        _validate_discrete(obj.operator)
        return obj
class _ParamFilterSchema(BaseSchema):
    """Filter on a run parameter selected by ``key``; ``by`` is "param"."""

    key = fields.String()
    operator = EnumField(ComparisonOperator, by_value=True)
    # Value handled by the module's _ParamFilterValueField (not shown here).
    value = _FilterValueField(_ParamFilterValueField())
    by = fields.Constant("param", dump_only=True)
@post_load
def make_experiment_run(self, data, **kwargs):
    """Build an ``ExperimentRun`` from the loaded field values.

    ``**kwargs`` absorbs the ``many``/``partial`` keyword arguments that
    marshmallow 3 passes to post-load hooks.
    """
    return ExperimentRun(**data)
# Schemas for payloads sent to API:
class _ExperimentRunDataSchema(BaseSchema):
    """Payload with lists of run metrics, params and tags (none required)."""

    metrics = fields.List(fields.Nested(_MetricSchema))
    params = fields.List(fields.Nested(_ParamSchema))
    tags = fields.List(fields.Nested(_TagSchema))
class _ExperimentRunInfoSchema(BaseSchema):
    """Payload with a run's status and (optional) end time."""

    status = EnumField(ExperimentRunStatus, required=True, by_value=True)
    # ``endedAt`` may be omitted; it then loads as None.
    ended_at = fields.DateTime(missing=None, data_key="endedAt")
class _ListExperimentRunsResponseSchema(BaseSchema):
    """Load runs plus pagination info into ``ListExperimentRunsResponse``."""

    pagination = fields.Nested(_PaginationSchema, required=True)
    runs = fields.Nested(_ExperimentRunSchema, many=True, required=True)

    @post_load
    def make_list_runs_response_schema(self, data, **kwargs):
        # Accept the ``many``/``partial`` keyword arguments marshmallow 3
        # passes to post-load hooks, so loading does not raise TypeError.
        return ListExperimentRunsResponse(**data)
class _CreateRunSchema(BaseSchema):
name = fields.String()
parent_run_id = fields.UUID(data_key="parentRunId")
@post_load
def make_job_metadata(self, data, **kwargs):
    """Build a ``JobMetadata`` from the loaded field values.

    ``**kwargs`` absorbs the ``many``/``partial`` keyword arguments that
    marshmallow 3 passes to post-load hooks.
    """
    return JobMetadata(**data)
class JobSummarySchema(BaseSchema):
    """Load a job summary (ID plus metadata) into ``JobSummary``."""

    # Payload keys are ``jobId`` and ``meta``.
    id = fields.UUID(data_key="jobId", required=True)
    metadata = fields.Nested(JobMetadataSchema, data_key="meta", required=True)

    @post_load
    def make_job_summary(self, data, **kwargs):
        # Accept the ``many``/``partial`` keyword arguments marshmallow 3
        # passes to post-load hooks, so loading does not raise TypeError.
        return JobSummary(**data)
class InstanceSizeSchema(BaseSchema):
    """Load an instance size into ``InstanceSize``.

    Units follow the payload key names: ``milliCpus`` and ``memoryMb``.
    """

    milli_cpus = fields.Integer(data_key="milliCpus", required=True)
    memory_mb = fields.Integer(data_key="memoryMb", required=True)

    @post_load
    def make_instance_size(self, data, **kwargs):
        # Accept the ``many``/``partial`` keyword arguments marshmallow 3
        # passes to post-load hooks, so loading does not raise TypeError.
        return InstanceSize(**data)
class JobParameterSchema(BaseSchema):
name = fields.String(required=True)
type = EnumField(
ParameterType, data_key="type", by_value=True, required=True
)
default = fields.String(required=True)
required = fields.Boolean(required=True)