How to use the chemprop.nn_utils.param_count function in chemprop

To help you get started, we've selected a few examples showing popular ways param_count is used in public projects.
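
Before looking at the project snippets, here is a minimal, self-contained sketch of what param_count reports. The toy nn.Sequential model below is purely illustrative and stands in for the MoleculeModel that chemprop's build_model(args) returns; in the version of chemprop these examples come from, param_count simply sums the element counts of a model's trainable parameters, so it accepts any torch.nn.Module.

import torch.nn as nn

from chemprop.nn_utils import param_count

# Toy model standing in for the MoleculeModel that build_model(args) returns.
model = nn.Sequential(
    nn.Linear(300, 300),
    nn.ReLU(),
    nn.Linear(300, 1),
)

# param_count sums the sizes of all trainable parameters, roughly
#   sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Number of parameters = {param_count(model):,}')
# Prints: Number of parameters = 90,601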

Example from wengong-jin/chemprop, chemprop/train/run_training.py:

for model_idx in range(args.ensemble_size):
    # Tensorboard writer
    save_dir = os.path.join(args.save_dir, f'model_{model_idx}')
    os.makedirs(save_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=save_dir)

    # Load/build model
    if args.checkpoint_paths is not None:
        debug(f'Loading model {model_idx} from {args.checkpoint_paths[model_idx]}')
        model = load_checkpoint(args.checkpoint_paths[model_idx], current_args=args, logger=logger)
    else:
        debug(f'Building model {model_idx}')
        model = build_model(args)

    debug(model)
    debug(f'Number of parameters = {param_count(model):,}')
    if args.cuda:
        debug('Moving model to cuda')
        model = model.cuda()

    # Ensure that model is saved in correct location for evaluation if 0 epochs
    save_checkpoint(os.path.join(save_dir, 'model.pt'), model, scaler, features_scaler, args)

    if args.adjust_weight_decay:
        args.pnorm_target = compute_pnorm(model)

    # Optimizers
    optimizer = build_optimizer(model, args)

    # Learning rate schedulers
    scheduler = build_lr_scheduler(optimizer, args)
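
In run_training.py, param_count is called right after each ensemble member is built or loaded, so the model's size is written to the debug log before training starts and before the model is moved to the GPU.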
Example from wengong-jin/chemprop, hyperparameter_optimization.py:

hyper_args = deepcopy(args)
if args.save_dir is not None:
    folder_name = '_'.join(f'{key}_{value}' for key, value in hyperparams.items())
    hyper_args.save_dir = os.path.join(hyper_args.save_dir, folder_name)
for key, value in hyperparams.items():
    setattr(hyper_args, key, value)

# Record hyperparameters
logger.info(hyperparams)

# Cross validate
mean_score, std_score = cross_validate(hyper_args, train_logger)

# Record results
temp_model = build_model(hyper_args)
num_params = param_count(temp_model)
logger.info(f'num params: {num_params:,}')
logger.info(f'{mean_score} +/- {std_score} {hyper_args.metric}')

results.append({
    'mean_score': mean_score,
    'std_score': std_score,
    'hyperparams': hyperparams,
    'num_params': num_params
})

# Deal with nan
if np.isnan(mean_score):
    if hyper_args.dataset_type == 'classification':
        mean_score = 0
    else:
        raise ValueError('Can\'t handle nan score for non-classification dataset.')
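
Here a throwaway model is built with build_model(hyper_args) solely so param_count can report its size; logging num_params next to each configuration's cross-validation score makes it easy to weigh model size against performance during hyperparameter search.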
Example from wengong-jin/chemprop, model_comparison.py:

modify_train_args(args)

# Set up logging for training
os.makedirs(args.save_dir, exist_ok=True)
fh = logging.FileHandler(os.path.join(args.save_dir, args.log_name))
fh.setLevel(logging.DEBUG)

# Cross validate
TRAIN_LOGGER.addHandler(fh)
mean_score, std_score = cross_validate(args, TRAIN_LOGGER)
TRAIN_LOGGER.removeHandler(fh)

# Record results
logger.info(f'{mean_score} +/- {std_score} {metric}')
temp_model = build_model(args)
logger.info(f'num params: {param_count(temp_model):,}')
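
model_comparison.py follows the same pattern: a temporary model is built only so its parameter count can be logged alongside the cross-validation results.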