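# NOTE: the snippets on this page are excerpted from several test modules
# (a polyaxon-client API-handler test and polyaxon/TensorFlow estimator tests).
# The preamble below is a hedged reconstruction of the imports and fixtures
# they appear to assume; the commented-out polyaxon import paths are guesses
# and may differ between versions. The first test additionally relies on a
# test-case setUp that creates self.api_config and self.api_handler.
import tempfile

import httpretty
import tensorflow as tf
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test
from tensorflow.python.training import saver

import polyaxon as plx
# Assumed polyaxon / polyaxon-client imports (paths not confirmed):
# from polyaxon_client.api.base import BaseApiHandler
# from polyaxon.estimators import Estimator, RunConfig
# from polyaxon.experiments import Experiment
# from polyaxon.estimators.estimator_spec import EstimatorSpec
# from polyaxon import Modes

_TMP_DIR = '/tmp'  # placeholder model_dir fixture used by several tests below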
def test_stop_notebook_without_commit(self):
    httpretty.register_uri(
        httpretty.POST,
        BaseApiHandler.build_url(
            self.api_config.base_url,
            "/",
            "username",
            "project_name",
            "notebook",
            "stop",
        ),
        content_type="application/json",
        status=200,
    )

    result = self.api_handler.stop_notebook(
        "username", "project_name", commit=False
    )
    assert result.status_code == 200
def test_default_model_dir(self):
    with test.mock.patch.object(tempfile, 'mkdtemp', return_value=_TMP_DIR):
        est = Estimator(model_fn=self.get_dummy_model_fn())
        self.assertIn(_TMP_DIR, est.config.model_dir)
        self.assertIn(_TMP_DIR, est.model_dir)
def test_min_eval_frequency_defaults(self):
    def dummy_model_fn(features, labels):  # pylint: disable=unused-argument
        pass

    # The default value when model_dir is on GCS.
    estimator = Estimator(dummy_model_fn, 'gs://dummy_bucket')
    ex = Experiment(estimator, train_input_fn=None, eval_input_fn=None)
    self.assertEquals(ex._eval_every_n_steps, 1)

    # The default value when model_dir is not on GCS is 1.
    estimator = Estimator(dummy_model_fn, '/tmp/dummy')
    ex = Experiment(estimator, train_input_fn=None, eval_input_fn=None)
    self.assertEquals(ex._eval_every_n_steps, 1)

    # Make sure the default is not used when explicitly set.
    estimator = Estimator(dummy_model_fn, 'gs://dummy_bucket')
    ex = Experiment(
        estimator,
        eval_every_n_steps=123,
        train_input_fn=None,
        eval_input_fn=None)
    self.assertEquals(ex._eval_every_n_steps, 123)

    # Make sure the default is not used when explicitly set to 0.
    estimator = Estimator(dummy_model_fn, 'gs://dummy_bucket')
    ex = Experiment(
        estimator,
        eval_every_n_steps=0,
        train_input_fn=None,
        eval_input_fn=None)
    self.assertEquals(ex._eval_every_n_steps, 0)
def test_features_labels_mode(self):
    given_features = {'test-features': [[1], [1]]}
    given_labels = {'test-labels': [[1], [1]]}

    def _input_fn():
        return given_features, given_labels

    def _model_fn(features, labels, mode):
        self.features, self.labels, self.mode = features, labels, mode
        return EstimatorSpec(
            mode=mode,
            loss=constant_op.constant(0.),
            train_op=constant_op.constant(0.),
            predictions=constant_op.constant([[0.]]))

    est = Estimator(model_fn=_model_fn)
    est.train(_input_fn, steps=1)
    est.evaluate(_input_fn, steps=1)
    self.assertEqual(given_features, self.features)
    self.assertEqual(given_labels, self.labels)
    self.assertTrue(Modes.is_eval(self.mode))
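# The remaining estimator tests reference module-level fixtures that are not
# shown on this page (dummy_input_fn, model_fn_global_step_incrementer,
# _model_fn_with_eval_metric_ops). The definitions below are hedged sketches
# reconstructed from how the tests use them, not the original fixtures.
def dummy_input_fn():
    return ({'x': constant_op.constant([[1.], [1.]])},
            constant_op.constant([[1.], [1.]]))


def model_fn_global_step_incrementer(features, labels, mode):
    _, _ = features, labels
    # Increment the global step on every train step so checkpoints record it.
    return EstimatorSpec(
        mode,
        loss=constant_op.constant(1.),
        train_op=tf.assign_add(tf.train.get_or_create_global_step(), 1))


def _model_fn_with_eval_metric_ops(features, labels, mode, params):
    _, _ = features, labels
    # Report a single constant metric whose name and value come from `params`.
    return EstimatorSpec(
        mode,
        loss=constant_op.constant(1.),
        train_op=tf.assign_add(tf.train.get_or_create_global_step(), 1),
        eval_metric_ops={
            params['metric_name']: tf.metrics.mean(
                constant_op.constant(params['metric_value']))})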
def test_hooks_should_be_session_run_hook(self):
    est = Estimator(model_fn=model_fn_global_step_incrementer)
    est.train(dummy_input_fn, steps=1)
    with self.assertRaisesRegexp(TypeError, 'must be a SessionRunHook'):
        est.evaluate(dummy_input_fn, steps=5, hooks=['NotAHook'])
def test_evaluate_from_checkpoint(self):
    params = {
        'metric_name': 'metric',
        'metric_value': 2.}
    est1 = Estimator(
        model_fn=_model_fn_with_eval_metric_ops,
        params=params)
    est1.train(dummy_input_fn, steps=5)
    est2 = Estimator(model_fn=_model_fn_with_eval_metric_ops, params=params)
    scores = est2.evaluate(
        dummy_input_fn, steps=1,
        checkpoint_path=saver.latest_checkpoint(est1.model_dir))
    self.assertEqual(5, scores['global_step'])
def test_batch_size_mismatch(self):
    def _model_fn(features, labels, mode):
        _, _ = features, labels
        return EstimatorSpec(
            mode,
            loss=constant_op.constant(0.),
            train_op=constant_op.constant(0.),
            predictions={
                'y1': constant_op.constant([[10.]]),
                'y2': constant_op.constant([[12.], [13]])
            })

    est = Estimator(model_fn=_model_fn)
    est.train(dummy_input_fn, steps=1)
    with self.assertRaisesRegexp(ValueError,
                                 'Batch length of predictions should be same'):
        next(est.predict(dummy_input_fn))
def test_same_model_dir_in_constructor_and_run_config(self):
    class FakeConfig(RunConfig):
        @property
        def model_dir(self):
            return _TMP_DIR

    est = Estimator(
        model_fn=self.get_dummy_model_fn(), config=FakeConfig(), model_dir=_TMP_DIR)
    self.assertEqual(_TMP_DIR, est.config.model_dir)
    self.assertEqual(_TMP_DIR, est.model_dir)
def test_variable_sharing(self):
    l = plx.layers.Dense(units=1)
    x = tf.placeholder(dtype=tf.float32, shape=[1, 1])
    y = tf.placeholder(dtype=tf.float32, shape=[2, 1])
    lx = l(x)
    ly = l(y)
    init_all_op = tf.global_variables_initializer()
    assign_op = l.variables[0].assign_add([[1]])

    with self.test_session() as sess:
        sess.run(init_all_op)
        lx_results = lx.eval({x: [[1]]})
        ly_results = ly.eval({y: [[1], [1]]})
        assert len(lx_results) == 1
        assert len(ly_results) == 2
def graph_fn1(mode, x):
    return plx.layers.Dense(units=1)(x)
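# A minimal, hypothetical usage sketch for graph_fn1: it simply applies a
# single-unit Dense layer to whatever tensor it receives (the `mode` argument
# is unused here). The placeholder shape below is an assumption for illustration.
inputs = tf.placeholder(dtype=tf.float32, shape=[None, 3])
outputs = graph_fn1(mode='train', x=inputs)  # -> a [None, 1] float32 tensor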