Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.
# Constructor of a dataset backed by a ".db" database file. The enclosing
# class is not part of this snippet (presumably the base atoms-data class;
# TODO confirm).
# NOTE(review): this snippet is truncated after `if units is None:`; the rest
# of the constructor (including the units handling) is not shown here.
def __init__(
self,
# Path to the database file; must end in ".db" (validated below).
dbpath,
# Stored unchanged as self.subset; its semantics are not visible here.
subset=None,
# Passed to self.get_available_properties (defined elsewhere).
available_properties=None,
# Properties to load; when None, all available properties are used.
load_only=None,
units=None,
# NOTE(review): this default is evaluated once, when the function is
# defined, so every call that omits the argument shares the same provider
# instance. That is harmless only if the provider keeps no per-dataset
# state; confirm.
environment_provider=SimpleEnvironmentProvider(),
collect_triples=False,
center_positions=True,
):
# Reject paths without the ".db" extension before storing anything.
if not dbpath.endswith(".db"):
raise AtomsDataError(
"Invalid dbpath! Please make sure to add the file extension '.db' to "
"your dbpath."
)
self.dbpath = dbpath
self.subset = subset
self.load_only = load_only
# Delegates to get_available_properties (not shown) to settle the final
# property list from the argument.
self.available_properties = self.get_available_properties(available_properties)
# Default: load every available property.
if load_only is None:
self.load_only = self.available_properties
if units is None:
# Constructor for the QM9 dataset (inferred from the QM9.* property keys used
# below; the class header is not part of this snippet).
# NOTE(review): this snippet is truncated inside the available_properties
# list; the rest of the list and of the constructor is not shown here.
def __init__(
self,
dbpath,
download=True,
subset=None,
load_only=None,
collect_triples=False,
# Stored on self below; how it is consumed is not visible in this snippet.
remove_uncharacterized=False,
# NOTE(review): default instance is created once at definition time and is
# shared by every call that omits this argument; confirm the provider is
# stateless.
environment_provider=spk.environment.SimpleEnvironmentProvider(),
**kwargs
):
self.remove_uncharacterized = remove_uncharacterized
# Property keys exposed by this dataset (the list continues past this
# snippet).
available_properties = [
QM9.A,
QM9.B,
QM9.C,
QM9.mu,
QM9.alpha,
QM9.homo,
QM9.lumo,
QM9.gap,
QM9.r2,
QM9.zpve,
# Constructor for the ANI-1 dataset (inferred from the ANI1.energy key used
# below; the class header is not part of this snippet).
# NOTE(review): this snippet is truncated inside the super().__init__(...)
# call; the remaining arguments and body are not shown here.
def __init__(
self,
dbpath,
download=True,
subset=None,
load_only=None,
collect_triples=False,
# Both stored on self below; their effect is not visible in this snippet.
num_heavy_atoms=8,
high_energies=False,
# NOTE(review): default instance is created once at definition time and is
# shared by every call that omits this argument; confirm the provider is
# stateless.
environment_provider=spk.environment.SimpleEnvironmentProvider(),
):
# A single property, ANI1.energy, with unit Hartree (units presumably pair
# positionally with available_properties; confirm in the parent class).
available_properties = [ANI1.energy]
units = [Hartree]
self.num_heavy_atoms = num_heavy_atoms
self.high_energies = high_energies
# Forward the shared dataset options plus the fixed property/unit lists to
# the parent constructor.
super().__init__(
dbpath=dbpath,
subset=subset,
download=download,
load_only=load_only,
collect_triples=collect_triples,
available_properties=available_properties,
units=units,
environment_provider=environment_provider,
# Constructor for the ISO17 dataset (inferred from the ISO17.* keys and the
# "iso17" directory below). Selects a single fold, whose database file is
# <datapath>/iso17/<fold>.db.
# NOTE(review): this snippet is truncated inside the super().__init__(...)
# call; the remaining arguments and body are not shown here.
def __init__(
self,
# Root directory; the fold database is looked up under <datapath>/iso17/.
datapath,
# Fold name; must be a member of self.existing_folds (defined elsewhere).
fold,
download=True,
load_only=None,
subset=None,
collect_triples=False,
# NOTE(review): default instance is created once at definition time and is
# shared by every call that omits this argument; confirm the provider is
# stateless.
environment_provider=spk.environment.SimpleEnvironmentProvider(),
):
# Reject unknown fold names early. `{:s}` and `fold + ".db"` below both
# assume fold is a str.
if fold not in self.existing_folds:
raise ValueError("Fold {:s} does not exist".format(fold))
# Two properties (presumably energy E and forces F; confirm) with unit
# factors of 1.0.
available_properties = [ISO17.E, ISO17.F]
units = [1.0, 1.0]
self.path = datapath
self.fold = fold
# Each fold lives in its own database file.
dbpath = os.path.join(datapath, "iso17", fold + ".db")
super().__init__(
dbpath=dbpath,
subset=subset,
load_only=load_only,
# Constructor of a calculator that wraps a model (inferred from the explicit
# Calculator.__init__ call; presumably ASE's Calculator base class; the class
# header is not part of this snippet).
# NOTE(review): energy, forces, energy_units and forces_units are accepted but
# not used in the lines shown. Presumably they are handled in code not
# included in this snippet; confirm.
def __init__(
self,
model,
# Forwarded to AtomsConverter below.
device="cpu",
collect_triples=False,
# NOTE(review): default instance is created once at definition time and is
# shared by every call that omits this argument; confirm the provider is
# stateless.
environment_provider=SimpleEnvironmentProvider(),
energy=None,
forces=None,
energy_units="eV",
forces_units="eV/Angstrom",
# Remaining keyword arguments go to the Calculator base class.
**kwargs
):
Calculator.__init__(self, **kwargs)
self.model = model
# Converter configured with the same neighbor-list provider, triples flag
# and device; presumably used to turn atoms objects into model inputs.
self.atoms_converter = AtomsConverter(
environment_provider=environment_provider,
collect_triples=collect_triples,
device=device,
)
# Constructor for the Materials Project dataset (inferred from the
# super(MaterialsProject, self) call below).
# NOTE(review): this snippet is truncated inside the super().__init__(...)
# call; the remaining arguments and body are not shown here.
def __init__(
self,
dbpath,
# Stored as self.apikey; how it is used is not visible in this snippet
# (presumably for downloading the data; confirm).
apikey=None,
download=True,
subset=None,
load_only=None,
collect_triples=False,
# NOTE(review): default instance is created once at definition time and is
# shared by every call that omits this argument; confirm the provider is
# stateless.
environment_provider=spk.environment.SimpleEnvironmentProvider(),
):
# Four properties. If units pair positionally with available_properties
# (confirm in the parent class), the first three are in eV and the
# magnetization uses a unit factor of 1.0.
available_properties = [
MaterialsProject.EformationPerAtom,
MaterialsProject.EPerAtom,
MaterialsProject.BandGap,
MaterialsProject.TotalMagnetization,
]
units = [eV, eV, eV, 1.0]
self.apikey = apikey
super(MaterialsProject, self).__init__(
dbpath=dbpath,
subset=subset,
# Constructor for the Organic Materials Database (OMDB) dataset (inferred
# from OrganicMaterialsDatabase.BandGap and the download URL below).
# NOTE(review): this snippet stops after the existence check. download,
# subset, load_only, collect_triples, environment_provider and the locals
# available_properties/units are not used in the lines shown; presumably they
# are consumed in code not included here.
def __init__(
self,
# Path to the manually downloaded archive, e.g. "OMDB-GAP1.tar.gz".
path,
download=True,
subset=None,
load_only=None,
collect_triples=False,
# NOTE(review): default instance is created once at definition time and is
# shared by every call that omits this argument; confirm the provider is
# stateless.
environment_provider=spk.environment.SimpleEnvironmentProvider(),
):
available_properties = [OrganicMaterialsDatabase.BandGap]
units = [eV]
self.path = path
# The database path is derived from the archive path ("*.tar.gz" -> "*.db").
# NOTE(review): str.replace substitutes every occurrence. If path does not
# contain ".tar.gz", dbpath ends up equal to path.
dbpath = self.path.replace(".tar.gz", ".db")
self.dbpath = dbpath
# The error message tells the user to download the archive by hand, so
# require either the archive or an already-converted database to exist.
if not os.path.exists(path) and not os.path.exists(dbpath):
raise FileNotFoundError(
"Download OMDB dataset (e.g. OMDB-GAP1.tar.gz) from "
"https://omdb.diracmaterials.org/dataset/ and set datapath to this file"
)
def _convert_atoms(
atoms,
environment_provider=SimpleEnvironmentProvider(),
collect_triples=False,
center_positions=False,
output=None,
):
"""
Helper function to convert ASE atoms object to SchNetPack input format.
Args:
atoms (ase.Atoms): Atoms object of molecule
environment_provider (callable): Neighbor list provider.
device (str): Device for computation (default='cpu')
output (dict): Destination for converted atoms, if not None
Returns:
dict of torch.Tensor: Properties including neighbor lists and masks
reformated into SchNetPack input format.