Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
(dict): return the state.
"""
return self.state
def load_state_dict(self, state):
    """
    Overwrite the held state with an independent copy of ``state``.

    A deep copy is stored, so later mutation of the caller's dictionary
    (or of any nested value inside it) does not leak into this object.

    Args:
        state (dict): a key value dictionary.
    """
    restored = copy.deepcopy(state)
    self.state = restored
trained_model, fresh_model = self._create_model(), self._create_model()
with TemporaryDirectory() as f:
checkpointables = CheckpointableObj()
checkpointer = Checkpointer(
trained_model,
save_dir=f,
save_to_disk=True,
checkpointables=checkpointables,
)
checkpointer.save("checkpoint_file")
# in the same folder
fresh_checkpointer = Checkpointer(fresh_model, save_dir=f)
self.assertTrue(fresh_checkpointer.has_checkpoint())
self.assertEqual(
fresh_checkpointer.get_checkpoint_file(),
os.path.join(f, "checkpoint_file.pth"),
)
checkpoint = fresh_checkpointer.load(
fresh_checkpointer.get_checkpoint_file()
)
(nn.DataParallel(self._create_model()), self._create_model()),
(self._create_model(), nn.DataParallel(self._create_model())),
(
nn.DataParallel(self._create_model()),
nn.DataParallel(self._create_model()),
),
]:
with TemporaryDirectory() as f:
checkpointer = Checkpointer(
trained_model, save_dir=f, save_to_disk=True
)
checkpointer.save("checkpoint_file")
# on different folders.
with TemporaryDirectory() as g:
fresh_checkpointer = Checkpointer(fresh_model, save_dir=g)
self.assertFalse(fresh_checkpointer.has_checkpoint())
self.assertEqual(
fresh_checkpointer.get_checkpoint_file(), ""
)
fresh_checkpointer.load(
os.path.join(f, "checkpoint_file.pth")
)
for trained_p, loaded_p in zip(
trained_model.parameters(), fresh_model.parameters()
):
# different tensor references.
self.assertFalse(id(trained_p) == id(loaded_p))
# same content.
self.assertTrue(trained_p.cpu().equal(loaded_p.cpu()))
def test_from_last_checkpoint_model(self):
    """
    Test that a fresh checkpointer over the same save directory discovers
    the last saved checkpoint and loads it into a new model.

    Every pairing of a plain model and an ``nn.DataParallel``-wrapped model
    is exercised. The trained and fresh models may therefore differ by the
    wrapper, and so by a prefix on their state-dict keys.
    """
    for trained_model, fresh_model in [
        (self._create_model(), self._create_model()),
        (nn.DataParallel(self._create_model()), self._create_model()),
        (self._create_model(), nn.DataParallel(self._create_model())),
        (
            nn.DataParallel(self._create_model()),
            nn.DataParallel(self._create_model()),
        ),
    ]:
        with TemporaryDirectory() as f:
            checkpointer = Checkpointer(trained_model, save_dir=f)
            checkpointer.save("checkpoint_file")

            # A second checkpointer pointed at the same folder should report
            # the file written above as its latest checkpoint.
            fresh_checkpointer = Checkpointer(fresh_model, save_dir=f)
            self.assertTrue(fresh_checkpointer.has_checkpoint())
            self.assertEqual(
                fresh_checkpointer.get_checkpoint_file(),
                os.path.join(f, "checkpoint_file.pth"),
            )
            fresh_checkpointer.load(fresh_checkpointer.get_checkpoint_file())

        # The parameter tensors must be distinct objects with equal values:
        # loading copies the data rather than aliasing the trained tensors.
        for trained_p, loaded_p in zip(
            trained_model.parameters(), fresh_model.parameters()
        ):
            # different tensor references
            self.assertIsNot(trained_p, loaded_p)
            # same content
            self.assertTrue(trained_p.cpu().equal(loaded_p.cpu()))
state (dict): a key value dictionary.
"""
self.state = copy.deepcopy(state)
trained_model, fresh_model = self._create_model(), self._create_model()
with TemporaryDirectory() as f:
checkpointables = CheckpointableObj()
checkpointer = Checkpointer(
trained_model,
save_dir=f,
save_to_disk=True,
checkpointables=checkpointables,
)
checkpointer.save("checkpoint_file")
# in the same folder
fresh_checkpointer = Checkpointer(fresh_model, save_dir=f)
self.assertTrue(fresh_checkpointer.has_checkpoint())
self.assertEqual(
fresh_checkpointer.get_checkpoint_file(),
os.path.join(f, "checkpoint_file.pth"),
)
checkpoint = fresh_checkpointer.load(
fresh_checkpointer.get_checkpoint_file()
)
state_dict = checkpointables.state_dict()
for key, _ in state_dict.items():
self.assertTrue(
checkpoint["checkpointables"].get(key) is not None
)
self.assertTrue(
checkpoint["checkpointables"][key] == state_dict[key]
)
def test_periodic_checkpointer(self):
    """
    Test that PeriodicCheckpointer writes ``model_{iteration:07d}.pth``
    exactly on the iterations where ``iteration + 1`` is a multiple of the
    period, and on no other iteration.

    Both a plain model and an ``nn.DataParallel``-wrapped model are checked.
    """
    period = 10
    max_iter = 100
    for trained_model in [
        self._create_model(),
        nn.DataParallel(self._create_model()),
    ]:
        with TemporaryDirectory() as f:
            checkpointer = Checkpointer(
                trained_model, save_dir=f, save_to_disk=True
            )
            # NOTE(review): the third positional argument is the literal 99
            # (one less than max_iter); confirm against PeriodicCheckpointer's
            # signature whether that offset is intentional.
            periodic_checkpointer = PeriodicCheckpointer(
                checkpointer, period, 99
            )
            for iteration in range(max_iter):
                periodic_checkpointer.step(iteration)
                path = os.path.join(f, "model_{:07d}.pth".format(iteration))
                # Only iterations that close a full period should leave a
                # numbered checkpoint file behind.
                if (iteration + 1) % period == 0:
                    self.assertTrue(os.path.exists(path))
                else:
                    self.assertFalse(os.path.exists(path))
def test_from_name_file_model(self):
"""
test that loading works even if they differ by a prefix.
"""
for trained_model, fresh_model in [
(self._create_model(), self._create_model()),
(nn.DataParallel(self._create_model()), self._create_model()),
(self._create_model(), nn.DataParallel(self._create_model())),
(
nn.DataParallel(self._create_model()),
nn.DataParallel(self._create_model()),
),
]:
with TemporaryDirectory() as f:
checkpointer = Checkpointer(
trained_model, save_dir=f, save_to_disk=True
)
checkpointer.save("checkpoint_file")
# on different folders.
with TemporaryDirectory() as g:
fresh_checkpointer = Checkpointer(fresh_model, save_dir=g)
self.assertFalse(fresh_checkpointer.has_checkpoint())
self.assertEqual(
fresh_checkpointer.get_checkpoint_file(), ""
)
fresh_checkpointer.load(
os.path.join(f, "checkpoint_file.pth")
)
for trained_p, loaded_p in zip(