Computing SOAP features

This example shows how to compute the SOAP power spectrum descriptor for each atom in each system of a systems file, here dataset.xyz.

Finally, the descriptor is transformed into a form compatible with most classic machine learning algorithms (such as PCA or linear regression): a single two-dimensional array with one row per atom-centered environment.

The workflow is the same for every descriptor rascaline provides. Take a look at the Reference guides for a list of all descriptors and their specific parameters.

You can obtain a testing dataset from our website.

import chemfiles

from rascaline import SoapPowerSpectrum

Read the systems using chemfiles:

with chemfiles.Trajectory("dataset.xyz") as trajectory:
    systems = [s for s in trajectory]

Rascaline can also handle systems read by ASE:

import ase.io

systems = ase.io.read("dataset.xyz", ":")

We can now define the hyper-parameters for the calculation:

HYPER_PARAMETERS = {
    "cutoff": 5.0,
    "max_radial": 6,
    "max_angular": 4,
    "atomic_gaussian_width": 0.3,
    "center_atom_weight": 1.0,
    "radial_basis": {
        "Gto": {},
    },
    "cutoff_function": {
        "ShiftedCosine": {"width": 0.5},
    },
}

calculator = SoapPowerSpectrum(**HYPER_PARAMETERS)
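The cutoff_function entry smoothly switches off the contribution of neighbors close to the cutoff. The sketch below is a minimal pure-Python illustration, assuming the usual shifted-cosine form: a weight of 1 below cutoff - width, a decay of 0.5 (1 + cos(π (r - cutoff + width) / width)) between cutoff - width and cutoff, and 0 beyond. The helper name shifted_cosine is hypothetical and not part of rascaline.

```python
import math


def shifted_cosine(r: float, cutoff: float = 5.0, width: float = 0.5) -> float:
    """Hypothetical helper: smooth switching weight for a neighbor at distance r.

    Assumed form: 1 below cutoff - width, cosine decay to 0 at the cutoff,
    and 0 beyond it. Defaults mirror the values in HYPER_PARAMETERS.
    """
    if r <= cutoff - width:
        return 1.0
    if r >= cutoff:
        return 0.0
    # Map r from [cutoff - width, cutoff] onto [0, 1], then onto [0, pi]
    x = (r - cutoff + width) / width
    return 0.5 * (1.0 + math.cos(math.pi * x))


for r in (3.0, 4.5, 4.75, 5.0, 6.0):
    print(f"r = {r:4.2f}  f(r) = {shifted_cosine(r):.3f}")
```

The cosine makes the weight and its first derivative continuous at both ends of the switching region, which matters because the example also requests gradients with respect to positions.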

Then run the actual calculation, including gradients with respect to the atomic positions:

descriptor = calculator.compute(systems, gradients=["positions"])

The descriptor is a metatensor TensorMap containing multiple blocks, one for each combination of center_type, neighbor_1_type and neighbor_2_type. We can transform it into a single block containing a dense representation, with one sample for each atom-centered environment, by using keys_to_samples and keys_to_properties:

print("before: ", len(descriptor.keys))

descriptor = descriptor.keys_to_samples("center_type")
descriptor = descriptor.keys_to_properties(["neighbor_1_type", "neighbor_2_type"])
print("after: ", len(descriptor.keys))

before:  40
after:  1
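To build intuition for what the two calls do, here is a toy pure-Python model of the transformation. It is not metatensor's implementation: the blocks are plain dicts mapping a key (center_type, neighbor_1_type, neighbor_2_type) to per-atom rows, and the helper densify and all the numbers are hypothetical. Moving center_type to the samples stacks the blocks of different central types vertically, and moving the neighbor types to the properties places the blocks side by side, filling entries that a block does not contain with zeros.

```python
# Hypothetical toy data, NOT a metatensor object: each key
# (center_type, neighbor_1_type, neighbor_2_type) maps samples (system, atom)
# to a short list of property values (2 properties per block here).
blocks = {
    (1, 1, 1): {(0, 1): [0.1, 0.2], (0, 2): [0.3, 0.4]},
    (1, 1, 8): {(0, 1): [0.5, 0.6]},  # atom (0, 2) is absent from this block
    (8, 1, 1): {(0, 0): [0.7, 0.8]},
    (8, 1, 8): {(0, 0): [0.9, 1.0]},
}
N_PROPERTIES = 2


def densify(blocks, n_properties):
    """Toy version of keys_to_samples("center_type") followed by
    keys_to_properties(["neighbor_1_type", "neighbor_2_type"])."""
    # Each neighbor pair becomes a group of columns in the single output block
    pairs = sorted({(n1, n2) for (_, n1, n2) in blocks})
    # Each (system, atom, center_type) becomes one row; sorted for readability
    samples = sorted(
        {(s, a, c) for (c, _, _), rows in blocks.items() for (s, a) in rows}
    )
    dense = []
    for system, atom, center in samples:
        row = []
        for n1, n2 in pairs:
            values = blocks.get((center, n1, n2), {}).get((system, atom))
            # A missing block or a missing sample contributes zeros
            row.extend(values if values is not None else [0.0] * n_properties)
        dense.append(row)
    return samples, pairs, dense


samples, pairs, dense = densify(blocks, N_PROPERTIES)
print("keys before:", len(blocks), "-> one dense block")
for sample, row in zip(samples, dense):
    print(sample, row)
```

In metatensor itself, the merged sample and property labels keep track of center_type and of the neighbor types, so no information is lost by the transformation.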

You can now use descriptor.block().values as the input of a machine learning algorithm:

print(descriptor.block().values.shape)
(1380, 1800)
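The printed numbers can be cross-checked with a little arithmetic. The sketch below assumes that power spectrum keys use unordered neighbor pairs (neighbor_1_type <= neighbor_2_type), that each block has one property per (l, n1, n2) with l from 0 to max_angular and n from 0 to max_radial - 1, and that the dataset contains four atomic types. The last point is inferred from the printed counts rather than stated in this example, and the type numbers used are placeholders.

```python
from itertools import combinations_with_replacement

# Values taken from HYPER_PARAMETERS in this example
max_radial = 6
max_angular = 4

# Assumption: four atomic types; the actual type numbers are placeholders
atomic_types = [1, 6, 7, 8]

# Assumed key layout: unordered neighbor pairs (neighbor_1_type <= neighbor_2_type)
neighbor_pairs = list(combinations_with_replacement(atomic_types, 2))

# Before the transformation: one block per (center_type, neighbor pair)
n_blocks = len(atomic_types) * len(neighbor_pairs)

# Assumed property layout of each block: one column per (l, n1, n2)
properties_per_block = (max_angular + 1) * max_radial**2

# After keys_to_properties, the blocks for all neighbor pairs sit side by side
n_columns = len(neighbor_pairs) * properties_per_block

print(n_blocks, properties_per_block, n_columns)
```

Under these assumptions the script prints 40 180 1800, matching the 40 blocks reported before the transformation and the 1800 columns of the final array. The 1380 rows are simply the number of atom-centered environments in the dataset.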