How to use the dtlpy.entities.Dataset class in dtlpy

To help you get started, we've selected a few dtlpy examples based on popular ways dtlpy.entities.Dataset is used in public projects.
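
Before the project examples below, here is a minimal sketch of the typical way a Dataset entity is obtained and used. The project and dataset names are placeholders, and the snippet assumes you can authenticate interactively with dl.login().

import dtlpy as dl

dl.login()  # authenticate (opens a browser window)

# fetch a project and one of its datasets by name (placeholder names)
project = dl.projects.get(project_name='My Project')
dataset = project.datasets.get(dataset_name='My Dataset')
assert isinstance(dataset, dl.entities.Dataset)

# upload a local file and iterate over the dataset's items
dataset.items.upload(local_path='/path/to/image.jpg')
pages = dataset.items.list()
for page in pages:
    for item in page:
        print(item.name)

The project examples that follow assume the Dataset entity arrives as a function argument rather than being fetched this way.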


github dataloop-ai / ZazuML / dataloop_services / trial_module.py (View on GitHub)
def run(self, dataset, train_query, val_query, model_specs, hp_values, configs=None, progress=None):
        maybe_download_data(dataset, train_query, val_query)

        # get project
        # project = dataset.project
        assert isinstance(dataset, dl.entities.Dataset)
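        # a Dataset keeps the IDs of the projects it belongs to; the first one is used to fetch the parent Project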
        project = dl.projects.get(project_id=dataset.projects[0])

        # start tune
        cls = getattr(import_module('.adapter', 'ObjectDetNet.' + model_specs['name']), 'AdapterModel')
        # TODO: without roberto work with path / or github
        inputs_dict = {'devices': {'gpu_index': 0}, 'model_specs': model_specs, 'hp_values': hp_values}
        # save the run inputs to a local checkpoint file
        # TODO: make sure two concurrent runs don't both save the same checkpoint
        torch.save(inputs_dict, 'checkpoint.pt')

        adapter = cls()
        adapter.load()
        if hasattr(adapter, 'reformat'):
            adapter.reformat()
        if hasattr(adapter, 'data_loader'):
            adapter.data_loader()

github dataloop-ai / ZazuML / dataloop_services / predict_module.py (View on GitHub)
def run(self, dataset, val_query, checkpoint_path, model_specs, configs=None, progress=None):
        self.logger.info('checkpoint path: ' + str(checkpoint_path))
        self.logger.info('Beginning to download checkpoint')
        dataset.items.get(filepath='/checkpoints').download(local_path=os.getcwd())
        self.logger.info('checkpoint downloaded, working dir contents: ' + str(os.listdir('.')))
        self.logger.info('downloading data')
        maybe_download_pred_data(dataset, val_query)
        self.logger.info('data downloaded')
        assert isinstance(dataset, dl.entities.Dataset)
        project = dl.projects.get(project_id=dataset.projects[0])
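        # dynamically import the AdapterModel class for the chosen model architecture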
        cls = getattr(import_module('.adapter', 'ObjectDetNet.' + model_specs['name']), 'AdapterModel')

        home_path = model_specs['data']['home_path']

        inputs_dict = {'checkpoint_path': checkpoint_path['checkpoint_path'], 'home_path': home_path}
        torch.save(inputs_dict, 'predict_checkpoint.pt')

        adapter = cls()
        output_path = adapter.predict(home_path=home_path, checkpoint_path=checkpoint_path['checkpoint_path'])
        save_info = {
            'package_name': self.package_name,
            'execution_id': progress.execution.id
        }
        project.artifacts.upload(filepath=output_path,
                                 package_name=save_info['package_name'],