Adding a new calculator

Introduction

Before adding a new calculator, it might be worth checking whether the one you want already exists in our list of supported calculators.

In this tutorial, we will go over all the steps required to create a new calculator. For simplicity's sake, the calculator we will implement will be very basic, keeping the focus on how different bits of the code interact with one another instead of complex math or performance tricks.

The calculator that we will create computes an atom-centered representation, where each atomic environment is represented with the moments of the positions of the neighbors up to a maximal order. Each atomic type in the neighborhood will be considered separately. The resulting descriptor will represent an atom-centered environment \(\ket{\mathcal{A}_i}\) on a basis of atomic types \(\alpha\) and moment order \(k\):

\[\braket{\alpha k | \mathcal{A}_i} = \frac{1}{N_\text{neighbors}} \sum_{j \in \mathcal{A}_i} r_{ij}^k \ \delta_{\alpha, \alpha_j}\]
[Figure: the geometric moments descriptor (moments-descriptor.svg)]
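To make the equation concrete before touching any rascaline code, here is a minimal std-only sketch that evaluates it for a single central atom. The function name geometric_moments and its input format (one (atomic_type, r_ij) entry per neighbor inside the cutoff) are hypothetical and only meant for illustration. Following the equation, N_neighbors counts every neighbor, whatever its type, while the Kronecker delta keeps only neighbors of type alpha in the sum.

```rust
/// Illustrative sketch (not rascaline code): computes <alpha k | A_i> for a
/// single central atom i, for every k in 0..=max_moment.
///
/// `neighbors` is a hypothetical input format: one `(atomic_type, r_ij)`
/// entry for every neighbor j inside the cutoff, of any atomic type.
pub fn geometric_moments(neighbors: &[(i32, f64)], alpha: i32, max_moment: usize) -> Vec<f64> {
    let mut moments = vec![0.0; max_moment + 1];
    if neighbors.is_empty() {
        // no neighbors: all moments are zero (and we avoid dividing by zero)
        return moments;
    }

    // N_neighbors counts all neighbors, whatever their atomic type
    let n_neighbors = neighbors.len() as f64;
    for &(neighbor_type, distance) in neighbors {
        // delta_{alpha, alpha_j}: only neighbors of type alpha contribute
        if neighbor_type != alpha {
            continue;
        }
        for (k, moment) in moments.iter_mut().enumerate() {
            *moment += distance.powi(k as i32) / n_neighbors;
        }
    }
    moments
}
```

For example, with two neighbors of type 1 at distances 1.0 and 2.0 and one neighbor of type 8, the k = 1 moment for alpha = 1 is (1.0 + 2.0) / 3 = 1.0.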

Throughout this tutorial, very basic knowledge of the Rust and Python programming languages is assumed. If you are just starting out, you may find the official Rust book useful, as well as the documentation for the standard library and the API documentation for rascaline itself.

We will also assume that you have a local copy of the rascaline git repository, and can build the code and run the tests. If not, please look at the Getting started sections.

The traits we’ll use

Two of the three core concepts in rascaline are represented in the code as Rust traits: systems implement the System trait, and calculators implement the CalculatorBase trait. Traits (also called interfaces in other languages) define contracts that the implementing code must follow, in the form of a set of functions and documented behavior for these functions. Fulfilling this contract makes it possible to add new systems that work with all calculators, already implemented or not, and new calculators that can use any system, already implemented or not.

In this tutorial, our goal is to write a new struct implementing CalculatorBase. This implementation will take as input a slice of boxed System trait objects, and use data from those to fill a TensorMap (defined in the metatensor crate).

Implementation

Let’s start by creating a new file in rascaline/src/calculators/moments.rs, and importing everything we’ll need. Everything in here will be explained when we get to using it.

use metatensor::{Labels, TensorMap, LabelsBuilder};

use crate::{System, Error};
use crate::labels::{CenterSingleNeighborsTypesKeys, KeysBuilder};
use crate::labels::{AtomCenteredSamples, SamplesBuilder, AtomicTypeFilter};
use crate::calculators::CalculatorBase;

Then, we can define a struct for our new calculator GeometricMoments. It will contain two fields: cutoff to store the cutoff radius, and max_moment to store the maximal moment to compute.

#[derive(Clone, Debug)]
struct GeometricMoments {
    cutoff: f64,
    max_moment: usize,
}

We can then write a skeleton implementation for the CalculatorBase trait, leaving all functions unimplemented with the todo!() macro. CalculatorBase is the trait defining all the functions required for a calculator. Users might be more familiar with the concrete struct Calculator, which uses a Box<dyn CalculatorBase> (i.e. a pointer to a CalculatorBase trait object) to provide its functionality.

impl CalculatorBase for GeometricMoments {
    fn name(&self) -> String {
        todo!()
    }

    fn parameters(&self) -> String {
        todo!()
    }

    fn cutoffs(&self) -> &[f64] {
        todo!()
    }

    fn keys(&self, systems: &mut [Box<dyn System>]) -> Result<Labels, Error> {
        todo!()
    }

    fn sample_names(&self) -> Vec<&str> {
        todo!()
    }

    fn samples(&self, keys: &Labels, systems: &mut [Box<dyn System>]) -> Result<Vec<Labels>, Error> {
        todo!()
    }

    fn supports_gradient(&self, parameter: &str) -> bool {
        todo!()
    }

    fn positions_gradient_samples(&self, keys: &Labels, samples: &[Labels], systems: &mut [Box<dyn System>]) -> Result<Vec<Labels>, Error> {
        todo!()
    }

    fn components(&self, keys: &Labels) -> Vec<Vec<Labels>> {
        todo!()
    }

    fn property_names(&self) -> Vec<&str> {
        todo!()
    }

    fn properties(&self, keys: &Labels) -> Vec<Labels> {
        todo!()
    }

    fn compute(&mut self, systems: &mut [Box<dyn System>], descriptor: &mut TensorMap) -> Result<(), Error> {
        todo!()
    }
}

We’ll go over these functions one by one, explaining what they do as we go. Most of the functions here are used to communicate metadata about the calculator and the representation, and the compute function does the main part of the work.
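The relationship between the concrete Calculator struct and the CalculatorBase trait relies on Rust trait objects. The sketch below shows the same pattern in miniature, using only the standard library. ToyCalculatorBase, ToyCalculator, ToyMoments and ToyOther are hypothetical names, and the trait keeps a single function instead of the full CalculatorBase interface.

```rust
/// Hypothetical, heavily simplified stand-in for `CalculatorBase`,
/// keeping a single function.
pub trait ToyCalculatorBase {
    fn name(&self) -> String;
}

/// Hypothetical stand-in for the concrete `Calculator` struct: it owns a
/// boxed trait object and forwards calls to it (dynamic dispatch), so it
/// works with any type implementing the trait.
pub struct ToyCalculator {
    implementation: Box<dyn ToyCalculatorBase>,
}

impl ToyCalculator {
    pub fn new(implementation: Box<dyn ToyCalculatorBase>) -> ToyCalculator {
        ToyCalculator { implementation }
    }

    pub fn name(&self) -> String {
        self.implementation.name()
    }
}

pub struct ToyMoments;

impl ToyCalculatorBase for ToyMoments {
    fn name(&self) -> String {
        "geometric moments".to_string()
    }
}

pub struct ToyOther;

impl ToyCalculatorBase for ToyOther {
    fn name(&self) -> String {
        "another calculator".to_string()
    }
}
```

Because ToyCalculator only stores a Box<dyn ToyCalculatorBase>, a Vec<ToyCalculator> can mix different calculator implementations, which is the flexibility the trait-based design aims for.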

Calculator metadata

The first function returning metadata about the calculator itself is name, which should return a user-facing name for the current instance of the descriptor. As a quick refresher on Rust, a function returns the value of its last expression when that expression is not followed by a semicolon (here, the last and only expression). This expression creates a reference to a str (&str) and then converts it to a heap-allocated String using to_string(), provided by the ToString trait.

fn name(&self) -> String {
    "geometric moments".to_string()
}
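As a small std-only illustration of the refresher above, the three functions below produce the same String: two rely on the implicit return of the final expression, and one uses an explicit return statement. The function names are hypothetical; to_string(), into() and String::from are standard library conversions from &str to String.

```rust
// The final expression (no trailing semicolon) is the return value.
pub fn name_to_string() -> String {
    "geometric moments".to_string() // uses the `ToString` trait
}

// Same implicit return, converting through `Into<String>` for `&str`.
pub fn name_into() -> String {
    "geometric moments".into()
}

// An explicit `return` statement also works, and needs the semicolon.
pub fn name_explicit_return() -> String {
    return String::from("geometric moments"); // uses `From<&str>`
}
```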

Then, the parameters function should return the parameters used to create the current instance of the calculator in JSON format. rascaline uses serde and serde_json for this everywhere, so it is a good idea to do the same here. Let’s start by adding the corresponding #[derive] to the definition of GeometricMoments, and use it to implement the function.

#[derive(Clone, Debug)]
#[derive(serde::Serialize, serde::Deserialize)]
struct GeometricMoments {
    cutoff: f64,
    max_moment: usize,
}

fn parameters(&self) -> String {
    serde_json::to_string(self).expect("failed to serialize to JSON")
}

One interesting thing here is that serde_json::to_string returns a Result<String, serde_json::Error>, and we use expect to extract the string value. This Result would only contain an error if serializing GeometricMoments could fail, for example if it contained maps with non-string keys, which is not the case here. expect allows us to indicate that we never expect this function to fail; if it were to return an error anyway, the code would immediately stop and show the given message (using a panic).

Finally, the cutoffs function should return all the radial cutoffs used in neighbors lists. Here, we only have one (self.cutoff), and we use std::slice::from_ref to construct a slice containing a single element from a reference to a scalar.

fn cutoffs(&self) -> &[f64] {
    std::slice::from_ref(&self.cutoff)
}
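The following std-only sketch shows what std::slice::from_ref does in isolation. MomentsParameters is a hypothetical struct mirroring the two fields of GeometricMoments; the returned slice borrows the existing field, so no memory is allocated.

```rust
/// Hypothetical struct mirroring the fields of `GeometricMoments`.
pub struct MomentsParameters {
    pub cutoff: f64,
    pub max_moment: usize,
}

impl MomentsParameters {
    /// Borrow the single cutoff as a one-element slice, without allocating.
    pub fn cutoffs(&self) -> &[f64] {
        std::slice::from_ref(&self.cutoff)
    }
}
```

Writing &[self.cutoff] in the body instead would not compile, because the array would be a temporary value created inside the function, and a reference to it cannot be returned.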

Representation metadata

The next set of functions in the CalculatorBase trait is used to communicate metadata about the representation, and is called by the concrete Calculator struct when initializing and allocating the corresponding memory.

Keys

First, we have one function defining the set of keys that will be in the final TensorMap. In our case, we want the central atom type and the neighbor atom type as keys. This allows us to only store data for a given neighbor type if it is actually present around a given central atom type.

We could manually create a set of Labels with a LabelsBuilder and return them. But since multiple calculators need the same kind of keys, rascaline already provides implementations for typical atomic-type keys. Here we use CenterSingleNeighborsTypesKeys to create a set of keys containing the central atom type and one neighbor type. This key builder requires a cutoff (to determine which neighbors it should use) and a self_pairs flag indicating whether atoms should be considered their own neighbor or not.

fn keys(&self, systems: &mut [Box<dyn System>]) -> Result<Labels, Error> {
    let builder = CenterSingleNeighborsTypesKeys {
        cutoff: self.cutoff,
        // self pairs would have a distance of 0 and would not contribute
        // anything meaningful to a GeometricMoments representation
        self_pairs: false,
    };
    return builder.keys(systems);
}
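To give an idea of what such a key builder computes, here is a hypothetical std-only sketch for a single system, with self_pairs disabled. It is not rascaline's implementation: the function name and the input format are assumptions, with pairs being a half neighbor list in which each pair (i, j) closer than the cutoff appears only once. Each pair makes the second atom a neighbor of the first and vice versa, so both type orderings become keys.

```rust
use std::collections::BTreeSet;

/// Hypothetical sketch of the (center_type, neighbor_type) keys for one
/// system. `types[i]` is the atomic type of atom i, and `pairs` is a half
/// neighbor list: each (i, j) pair within the cutoff appears once, i != j.
pub fn center_neighbor_type_keys(types: &[i32], pairs: &[(usize, usize)]) -> BTreeSet<(i32, i32)> {
    let mut keys = BTreeSet::new();
    for &(first, second) in pairs {
        // `second` is a neighbor of `first`, and `first` of `second`
        keys.insert((types[first], types[second]));
        keys.insert((types[second], types[first]));
    }
    // a BTreeSet removes duplicates and keeps the keys sorted
    keys
}
```

With several systems, the keys would be the union of the sets computed for each system.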

Samples

Having defined the keys, we need to define the metadata associated with each block. For each block, the first set of metadata, called the samples, describes the rows of the data. Two functions are used to define the samples: first, sample_names defines the names associated with the different columns in the sample labels. Then, samples determines the set of samples associated with each key/block. The return type of the samples function takes some unpacking: we return a Result since any call to a System function can fail, and the non-error case of the result is a Vec<Labels>, since we need one set of Labels for each key/block.

fn sample_names(&self) -> Vec<&str> {
    AtomCenteredSamples::sample_names()
}

fn samples(&self, keys: &Labels, systems: &mut [Box<dyn System>]) -> Result<Vec<Labels>, Error> {
    assert_eq!(keys.names(), ["center_type", "neighbor_type"]);

    let mut samples = Vec::new();
    for [center_type, neighbor_type] in keys.iter_fixed_size() {
        let builder = AtomCenteredSamples {
            cutoff: self.cutoff,
            // only include central atoms of this type
            center_type: AtomicTypeFilter::Single(center_type.i32()),
            // with a neighbor of this type somewhere in the neighborhood
            // defined by the spherical `cutoff`.
            neighbor_type: AtomicTypeFilter::Single(neighbor_type.i32()),
            self_pairs: false,
        };

        samples.push(builder.samples(systems)?);
    }

    return Ok(samples);
}

Like for CalculatorBase::keys, we could manually write code to detect the right set of samples for each key. But since a lot of representations are built on atom-centered neighborhoods, there is already a tool to create the right set of samples in the form of AtomCenteredSamples.
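As an illustration of the filtering described in the comments of samples, here is a hypothetical std-only sketch for one key and one system: it keeps the atoms of type center_type that have at least one neighbor of type neighbor_type. The function name, the (system, atom) tuple layout and the half neighbor list input (each pair within the cutoff listed once) are assumptions made for this sketch; AtomCenteredSamples works on System objects instead.

```rust
use std::collections::BTreeSet;

/// Hypothetical sketch of the atom-centered samples for a single
/// (center_type, neighbor_type) key in system `system`. `pairs` is a half
/// neighbor list: each (i, j) pair within the cutoff appears once.
pub fn atom_centered_samples(
    system: usize,
    types: &[i32],
    pairs: &[(usize, usize)],
    center_type: i32,
    neighbor_type: i32,
) -> Vec<(usize, usize)> {
    let mut centers = BTreeSet::new();
    for &(first, second) in pairs {
        // a half-list pair must be checked in both directions
        if types[first] == center_type && types[second] == neighbor_type {
            centers.insert(first);
        }
        if types[second] == center_type && types[first] == neighbor_type {
            centers.insert(second);
        }
    }
    // (system, atom) samples, sorted by atom index
    centers.into_iter().map(|atom| (system, atom)).collect()
}
```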

Components

The next set of metadata associated with a block are the components. Each block can have zero or more components, which store metadata about symmetry operations or any other kind of tensorial components.

Here, we don’t have any components (the GeometricMoments representation is invariant), so we just return a list (one for each key) of empty vectors.

fn components(&self, keys: &Labels) -> Vec<Vec<Labels>> {
    return vec![vec![]; keys.count()];
}

Properties

The properties define metadata associated with the columns of the data arrays. Like for the samples, we have one function to define the set of names associated with each variable in the properties Labels, and one function to compute the set of properties defined for each key.

In our case, there is only one variable in the properties labels, the power \(k\) used to compute the moment. When building the full list of Labels for each key in CalculatorBase::properties, we use the fact that the properties are the same for each key/block and make copies of the Labels (since Labels are reference-counted, the copies are actually quite cheap).

fn property_names(&self) -> Vec<&str> {
    vec!["k"]
}

fn properties(&self, keys: &Labels) -> Vec<Labels> {
    let mut builder = LabelsBuilder::new(self.property_names());
    for k in 0..=self.max_moment {
        builder.add(&[k]);
    }
    let properties = builder.finish();

    return vec![properties; keys.count()];
}
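The "copies are cheap" argument can be illustrated with std::rc::Rc, used here as a stand-in for the reference counting of Labels (the exact mechanism inside metatensor is not shown in this tutorial). In this hypothetical sketch, the list of moment orders is built once and every key gets a new handle to the same data; the vector itself is never copied.

```rust
use std::rc::Rc;

/// Hypothetical analogue of `properties`: build the moment orders once, then
/// share them between all keys. Cloning an `Rc` only increments a reference
/// count, so all entries point to the same vector.
pub fn shared_properties(max_moment: usize, n_keys: usize) -> Vec<Rc<Vec<usize>>> {
    let properties: Vec<usize> = (0..=max_moment).collect();
    vec![Rc::new(properties); n_keys]
}
```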

Gradients

Finally, we have metadata related to the gradients. First, the supports_gradient function should return whether the current calculator can compute a given kind of gradient. Typically, parameter is either "positions", "cell", or "strain". Here we only support computing the gradients with respect to positions.

fn supports_gradient(&self, parameter: &str) -> bool {
    match parameter {
        "positions" => true,
        _ => false,
    }
}
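The same check can also be written more compactly with the standard matches! macro, which returns true when the value matches the given pattern. This is only a stylistic alternative, shown here as a free function for testing.

```rust
/// Equivalent to the `match` above: only "positions" gradients are supported.
pub fn supports_gradient(parameter: &str) -> bool {
    matches!(parameter, "positions")
}
```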

If the user requests the calculation of some gradients, and the calculator supports them, the next step is to define the same set of metadata as for the values above: samples, components and properties. Properties are easy, because they are the same for the values and the gradients. The components are also similar, with some additional components added at the beginning depending on the kind of gradient. For example, if a calculator uses [first, second] as its set of components, the "positions" gradient would use [xyz, first, second], where xyz contains 3 entries. Similarly, the "strain" gradients would use [xyz_1, xyz_2, first, second] and the "cell" gradients would use [abc, xyz, first, second].
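The component rules above translate directly into array shapes: samples first, then the extra gradient components, then the value components, then the properties. The helper below is a hypothetical std-only sketch of these rules. For GeometricMoments, which has no components, it gives [n_samples, n_properties] for the values and [n_gradient_samples, 3, n_properties] for the "positions" gradients.

```rust
/// Hypothetical helper computing the array shape of a block's values, or of
/// one kind of gradient, from the sizes of the value components and the
/// number of properties. `n_samples` is the number of rows of this array.
pub fn gradient_shape(kind: &str, n_samples: usize, components: &[usize], n_properties: usize) -> Option<Vec<usize>> {
    let extra: &[usize] = match kind {
        "values" => &[],
        "positions" => &[3], // xyz
        "strain" => &[3, 3], // xyz_1, xyz_2
        "cell" => &[3, 3],   // abc, xyz
        _ => return None,
    };
    let mut shape = vec![n_samples];
    shape.extend_from_slice(extra);
    shape.extend_from_slice(components);
    shape.push(n_properties);
    Some(shape)
}
```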

Finally, the samples need to be defined. For the "cell" or "strain" gradients, there is always exactly one gradient sample per value sample. For the "positions" gradients, we could have one gradient sample for each atom in the same system for each value sample. However, this would create a very large number of gradient samples (number of atoms squared), and a lot of entries would be filled with zeros. Instead, each calculator that supports positions gradients must implement the positions_gradient_samples function, and use it to return only the samples associated with non-zero gradients. This function gets as input the set of keys, the list of samples associated with each key, and the list of systems on which we want to run the calculation.

We are again using the AtomCenteredSamples here to share code between multiple calculators all using atom-centered samples.

fn positions_gradient_samples(&self, keys: &Labels, samples: &[Labels], systems: &mut [Box<dyn System>]) -> Result<Vec<Labels>, Error> {
    assert_eq!(keys.names(), ["center_type", "neighbor_type"]);
    debug_assert_eq!(keys.count(), samples.len());

    let mut gradient_samples = Vec::new();
    for ([center_type, neighbor_type], samples_for_key) in keys.iter_fixed_size().zip(samples) {
        let builder = AtomCenteredSamples {
            cutoff: self.cutoff,
            center_type: AtomicTypeFilter::Single(center_type.i32()),
            // only include gradients with respect to neighbor atoms with
            // this atomic type (the other atoms do not contribute to the
            // gradients in the current block).
            neighbor_type: AtomicTypeFilter::Single(neighbor_type.i32()),
            self_pairs: false,
        };

        gradient_samples.push(builder.gradients_for(systems, samples_for_key)?);
    }

    return Ok(gradient_samples);
}
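Here is a hypothetical std-only sketch of the gradient samples such a builder is expected to produce for one block. Each entry is (sample, system, atom), matching the order used later when calling gradient.samples.position. For every value sample, we keep the central atom itself (its descriptor depends on its own position) and every neighbor of the block's neighbor type. The function name, the tuple layout and the half neighbor list input are assumptions, and including the center is our reading of the self-gradient term used in compute, not a documented guarantee of AtomCenteredSamples.

```rust
use std::collections::BTreeSet;

/// Hypothetical sketch of "positions" gradient samples for one block.
/// `samples` are the (system, atom) value samples of the block; `types` and
/// `pairs` (a half neighbor list) describe the single system `system`.
pub fn positions_gradient_samples(
    samples: &[(usize, usize)],
    system: usize,
    types: &[i32],
    pairs: &[(usize, usize)],
    neighbor_type: i32,
) -> Vec<(usize, usize, usize)> {
    let mut gradient_samples = Vec::new();
    for (sample_i, &(sample_system, center)) in samples.iter().enumerate() {
        if sample_system != system {
            continue;
        }
        // the center always depends on its own position
        let mut atoms = BTreeSet::from([center]);
        for &(first, second) in pairs {
            if first == center && types[second] == neighbor_type {
                atoms.insert(second);
            }
            if second == center && types[first] == neighbor_type {
                atoms.insert(first);
            }
        }
        for atom in atoms {
            gradient_samples.push((sample_i, system, atom));
        }
    }
    gradient_samples
}
```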

We are now done defining the metadata associated with our GeometricMoments calculator! In the next section, we’ll go over the actual calculation of the representation, and how to use the functions provided by System.

The compute function

We are finally approaching the most important function in CalculatorBase, compute. This function takes as input a list of systems and a TensorMap in which to write the results of the calculation. The function also returns a Result, to be able to indicate that an error occurred during the calculation.

The TensorMap is initialized by the concrete Calculator struct, according to parameters provided by the user. In particular, the tensor map will only contain the samples and properties requested by the user, meaning that the code in compute should check, for each block, whether a particular sample (respectively property) is present in block.samples (resp. block.properties) before computing it.

This being said, let’s start writing our compute function. We’ll defensively check that the tensor map keys match what we expect from them, and return a unit value () wrapped in Ok at the end of the function.

fn compute(&mut self, systems: &mut [Box<dyn System>], descriptor: &mut TensorMap) -> Result<(), Error> {
    assert_eq!(descriptor.keys().names(), ["center_type", "neighbor_type"]);

    // we'll add more code here

    return Ok(());
}

From here, the easiest way to implement our geometric moments descriptor is to iterate over the systems, and then iterate over the pairs in the system. Before we can get the pairs with system.pairs(), we need to compute the neighbors list for our current cutoff, using system.compute_neighbors(), which requires a mutable reference to the system to be able to store the list of computed pairs (hence the iteration using systems.iter_mut()).

All the functions on the System trait return a Result, but contrary to the CalculatorBase::parameters function above, we want to send the possible errors back to the user so that they can deal with them as they see fit. The question mark operator ? does exactly that: if the value returned by the called function is Err(e), ? immediately returns Err(e) from the current function (converting e to the function's error type with From); and if the result is Ok(v), ? extracts v and the execution continues.

fn compute(&mut self, systems: &mut [Box<dyn System>], descriptor: &mut TensorMap) -> Result<(), Error> {
    assert_eq!(descriptor.keys().names(), ["center_type", "neighbor_type"]);

    for (system_i, system) in systems.iter_mut().enumerate() {
        system.compute_neighbors(self.cutoff)?;

        for pair in system.pairs()? {
            // more code to come here
        }
    }

    return Ok(());
}
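The difference between ? and expect can be seen in a small std-only example, independent of rascaline. The function names are hypothetical; both parse a cutoff from text, but the first one hands parsing errors back to its caller, while the second one panics on error, which is only appropriate when an error would indicate a bug.

```rust
use std::num::ParseFloatError;

/// `?` returns the `Err` to the caller as soon as `parse` fails.
pub fn parse_cutoff(text: &str) -> Result<f64, ParseFloatError> {
    let cutoff: f64 = text.trim().parse()?;
    Ok(cutoff)
}

/// `expect` stops the program with a panic and the given message on error.
pub fn parse_cutoff_or_panic(text: &str) -> f64 {
    text.trim().parse().expect("invalid cutoff")
}
```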

For each pair, we now have to find the corresponding block (using the center and neighbor atomic types), and check whether the corresponding sample was requested by the user.

To find blocks and check for samples, we can use the Labels::position function on the keys and the samples Labels. This function returns an Option<usize>, which is None if the label (key or sample) was not found, and Some(position), where position is an unsigned integer, if the label was found. A block is missing if the user did not request the corresponding key, so we keep the block index as an Option<usize> and only look for the sample when the block exists. The samples are also kept as Option<usize>, and we will deal with missing samples later.

One thing to keep in mind is that a given pair can participate in two different samples. If two atoms i and j are closer than the cutoff, the list of pairs will only contain the i-j pair, and not the j-i pair (it is a so-called half neighbors list). To handle this, we get the list of atomic types with system.types() before the loop over pairs, and then construct the two candidate samples and check for their presence. If neither of the samples was requested, then we can skip the calculation for this pair. We also use system.pairs_containing() to get the number of neighbors of a given center.

fn compute(&mut self, systems: &mut [Box<dyn System>], descriptor: &mut TensorMap) -> Result<(), Error> {
    assert_eq!(descriptor.keys().names(), ["center_type", "neighbor_type"]);

    for (system_i, system) in systems.iter_mut().enumerate() {
        system.compute_neighbors(self.cutoff)?;

        // add this line
        let types = system.types()?;

        for pair in system.pairs()? {
            // get the block where the first atom is the center
            let first_block_id = descriptor.keys().position(&[
                types[pair.first].into(), types[pair.second].into(),
            ]);

            // get the sample corresponding to the first atom as a center
            //
            // This will be `None` if the block or samples are not present
            // in the descriptor, i.e. if the user did not request them.
            let first_sample_position = if let Some(block_id) = first_block_id {
                descriptor.block_by_id(block_id).samples().position(&[
                    system_i.into(), pair.first.into()
                ])
            } else {
                None
            };

            // get the id of the block where the second atom is the center
            let second_block_id = descriptor.keys().position(&[
                types[pair.second].into(), types[pair.first].into(),
            ]);
            // get the sample corresponding to the second atom as a center
            let second_sample_position = if let Some(block_id) = second_block_id {
                descriptor.block_by_id(block_id).samples().position(&[
                    system_i.into(), pair.second.into()
                ])
            } else {
                None
            };

            // skip calculation if neither of the samples is present
            if first_sample_position.is_none() && second_sample_position.is_none() {
                continue;
            }

            let n_neighbors_first = system.pairs_containing(pair.first)?.len() as f64;
            let n_neighbors_second = system.pairs_containing(pair.second)?.len() as f64;

            // more code coming up here!
        }
    }

    return Ok(());
}

Now, we can check if the samples are present, and if they are, iterate over the requested properties, compute the moments for the current pair distance, and accumulate them in the descriptor values array:

fn compute(&mut self, systems: &mut [Box<dyn System>], descriptor: &mut TensorMap) -> Result<(), Error> {
    // ...
    for (system_i, system) in systems.iter_mut().enumerate() {
        // ...
        for pair in system.pairs()? {
            // ...

            let n_neighbors_first = system.pairs_containing(pair.first)?.len() as f64;
            let n_neighbors_second = system.pairs_containing(pair.second)?.len() as f64;

            if let Some(sample_i) = first_sample_position {
                let block_id = first_block_id.expect("we have a sample in this block");
                let mut block = descriptor.block_mut_by_id(block_id);
                let block = block.data_mut();
                let array = block.values.to_array_mut();

                for (property_i, [k]) in block.properties.iter_fixed_size().enumerate() {
                    let value = f64::powi(pair.distance, k.i32()) / n_neighbors_first;
                    array[[sample_i, property_i]] += value;
                }
            }

            if let Some(sample_i) = second_sample_position {
                let block_id = second_block_id.expect("we have a sample in this block");
                let mut block = descriptor.block_mut_by_id(block_id);
                let block = block.data_mut();
                let array = block.values.to_array_mut();

                for (property_i, [k]) in block.properties.iter_fixed_size().enumerate() {
                    let value = f64::powi(pair.distance, k.i32()) / n_neighbors_second;
                    array[[sample_i, property_i]] += value;
                }
            }

            // more code coming up
        }
    }
    return Ok(());
}
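The logic of this loop can be checked outside of rascaline with a hypothetical std-only sketch for one system and one (center_type, neighbor_type) block. The input format (a half neighbor list of (first, second, distance) entries) and the dense moments[atom][k] output are assumptions made for illustration. As in compute, each pair is tried with both atoms as the center, and the division uses the total number of neighbors of that center.

```rust
/// Hypothetical std-only version of the value loop for one system and one
/// block. `pairs` is a half neighbor list of (first, second, distance).
/// Returns moments[atom][k]; rows of atoms that are not centers stay zero.
pub fn moments_from_half_list(
    types: &[i32],
    pairs: &[(usize, usize, f64)],
    center_type: i32,
    neighbor_type: i32,
    max_moment: usize,
) -> Vec<Vec<f64>> {
    // each half-list pair is a neighbor of both of its atoms
    let mut n_neighbors = vec![0usize; types.len()];
    for &(first, second, _) in pairs {
        n_neighbors[first] += 1;
        n_neighbors[second] += 1;
    }

    let mut moments = vec![vec![0.0; max_moment + 1]; types.len()];
    for &(first, second, distance) in pairs {
        // the pair can contribute to both candidate centers
        for (center, neighbor) in [(first, second), (second, first)] {
            if types[center] != center_type || types[neighbor] != neighbor_type {
                continue;
            }
            let n = n_neighbors[center] as f64;
            for (k, moment) in moments[center].iter_mut().enumerate() {
                *moment += distance.powi(k as i32) / n;
            }
        }
    }
    moments
}
```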

Finally, we can deal with the gradients. We first check if gradient data is defined in the descriptor we need to fill, by checking if it is defined on the first block (we know it is either defined on all blocks or none).

If we need to compute the gradients with respect to atomic positions, we will use the following expression:

\[\frac{\partial}{\partial \vec{r_{j}}} \braket{\alpha k | \chi_i} = \frac{\vec{r_{ij}}}{r_{ij}} \cdot \frac{k \ r_{ij}^{k - 1} \ \delta_{\alpha, \alpha_j}}{N_\text{neighbors}} = \vec{r_{ij}} \frac{k \ r_{ij}^{k - 2} \ \delta_{\alpha, \alpha_j}}{N_\text{neighbors}}\]
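A quick way to gain confidence in this expression is to compare it against central finite differences for a single term of the sum, f(r_j) = |r_j - r_i|^k. The std-only sketch below does this; the function names are hypothetical, and the check ignores the constant factor δ/N_neighbors, which does not change the derivative.

```rust
/// One term of the moment sum, f(r_j) = |r_j - r_i|^k.
pub fn moment_term(r_i: [f64; 3], r_j: [f64; 3], k: i32) -> f64 {
    let d = [r_j[0] - r_i[0], r_j[1] - r_i[1], r_j[2] - r_i[2]];
    (d[0] * d[0] + d[1] * d[1] + d[2] * d[2]).sqrt().powi(k)
}

/// Analytic gradient with respect to r_j: k r_ij^(k - 2) (r_j - r_i).
pub fn analytic_gradient(r_i: [f64; 3], r_j: [f64; 3], k: i32) -> [f64; 3] {
    let v = [r_j[0] - r_i[0], r_j[1] - r_i[1], r_j[2] - r_i[2]];
    let distance = (v[0] * v[0] + v[1] * v[1] + v[2] * v[2]).sqrt();
    let factor = k as f64 * distance.powi(k - 2);
    [v[0] * factor, v[1] * factor, v[2] * factor]
}

/// Central finite differences of `moment_term` along each component of r_j.
pub fn finite_difference_gradient(r_i: [f64; 3], r_j: [f64; 3], k: i32, delta: f64) -> [f64; 3] {
    let mut gradient = [0.0; 3];
    for a in 0..3 {
        let mut plus = r_j;
        let mut minus = r_j;
        plus[a] += delta;
        minus[a] -= delta;
        gradient[a] = (moment_term(r_i, plus, k) - moment_term(r_i, minus, k)) / (2.0 * delta);
    }
    gradient
}
```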

The code to compute gradients is very similar to the code computing the representation, checking the existence of a given gradient sample before writing to it. There are now four possible contributions for a given pair: \(\partial \ket{\chi_i} / \partial r_j\), \(\partial \ket{\chi_j} / \partial r_i\), \(\partial \ket{\chi_i} / \partial r_i\) and \(\partial \ket{\chi_j} / \partial r_j\), where \(\ket{\chi_i}\) is the representation around atom \(i\). In other words, in addition to the gradient of the descriptor centered on \(i\) with respect to the position of atom \(j\), we also need to account for the gradient of the descriptor centered on \(i\) with respect to its own position. Since \(r_{ij}\) only depends on \(\vec{r_j} - \vec{r_i}\), this self contribution is the negative of the neighbor contribution, which is why the code below adds one and subtracts the other.

fn compute(&mut self, systems: &mut [Box<dyn System>], descriptor: &mut TensorMap) -> Result<(), Error> {
    // ...

    // add these lines
    assert!(descriptor.keys().count() > 0);
    let do_positions_gradients = descriptor.block_by_id(0).gradient("positions").is_some();

    for (system_i, system) in systems.iter_mut().enumerate() {
        // ...
        for pair in system.pairs()? {
            // ...

            if do_positions_gradients {
                let mut moment_gradients = Vec::new();
                for k in 0..=self.max_moment {
                    moment_gradients.push([
                        pair.vector[0] * k as f64 * f64::powi(pair.distance, (k as i32) - 2),
                        pair.vector[1] * k as f64 * f64::powi(pair.distance, (k as i32) - 2),
                        pair.vector[2] * k as f64 * f64::powi(pair.distance, (k as i32) - 2),
                    ]);
                }

                if let Some(sample_position) = first_sample_position {
                    let block_id = first_block_id.expect("we have a sample in this block");
                    let mut block = descriptor.block_mut_by_id(block_id);

                    let mut gradient = block.gradient_mut("positions").expect("missing gradient storage");
                    let gradient = gradient.data_mut();
                    let array = gradient.values.to_array_mut();

                    let gradient_wrt_second = gradient.samples.position(&[
                        sample_position.into(), system_i.into(), pair.second.into()
                    ]);
                    let gradient_wrt_self = gradient.samples.position(&[
                        sample_position.into(), system_i.into(), pair.first.into()
                    ]);

                    for (property_i, [k]) in gradient.properties.iter_fixed_size().enumerate() {
                        if let Some(sample_i) = gradient_wrt_second {
                            let grad = moment_gradients[k.usize()];
                            // There is one extra dimension in the gradients
                            // array compared to the values, accounting for
                            // each of the Cartesian directions.
                            array[[sample_i, 0, property_i]] += grad[0] / n_neighbors_first;
                            array[[sample_i, 1, property_i]] += grad[1] / n_neighbors_first;
                            array[[sample_i, 2, property_i]] += grad[2] / n_neighbors_first;
                        }

                        if let Some(sample_i) = gradient_wrt_self {
                            let grad = moment_gradients[k.usize()];
                            array[[sample_i, 0, property_i]] -= grad[0] / n_neighbors_first;
                            array[[sample_i, 1, property_i]] -= grad[1] / n_neighbors_first;
                            array[[sample_i, 2, property_i]] -= grad[2] / n_neighbors_first;
                        }
                    }
                }

                if let Some(sample_position) = second_sample_position {
                    let block_id = second_block_id.expect("we have a sample in this block");
                    let mut block = descriptor.block_mut_by_id(block_id);

                    let mut gradient = block.gradient_mut("positions").expect("missing gradient storage");
                    let gradient = gradient.data_mut();
                    let array = gradient.values.to_array_mut();

                    let gradient_wrt_first = gradient.samples.position(&[
                        sample_position.into(), system_i.into(), pair.first.into()
                    ]);
                    let gradient_wrt_self = gradient.samples.position(&[
                        sample_position.into(), system_i.into(), pair.second.into()
                    ]);

                    for (property_i, [k]) in gradient.properties.iter_fixed_size().enumerate() {
                        if let Some(sample_i) = gradient_wrt_first {
                            let grad = moment_gradients[k.usize()];
                            array[[sample_i, 0, property_i]] -= grad[0] / n_neighbors_second;
                            array[[sample_i, 1, property_i]] -= grad[1] / n_neighbors_second;
                            array[[sample_i, 2, property_i]] -= grad[2] / n_neighbors_second;
                        }

                        if let Some(sample_i) = gradient_wrt_self {
                            let grad = moment_gradients[k.usize()];
                            array[[sample_i, 0, property_i]] += grad[0] / n_neighbors_second;
                            array[[sample_i, 1, property_i]] += grad[1] / n_neighbors_second;
                            array[[sample_i, 2, property_i]] += grad[2] / n_neighbors_second;
                        }
                    }
                }
            }
        }
    }

    return Ok(());
}
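The sign conventions used in this code can be summarized in a small hypothetical std-only function computing the four contributions of one pair, for a single moment order k. It assumes, as the code above does, that the pair vector points from the first atom to the second (vector = r_second - r_first); the function name and output layout are illustrative only. For each center, the neighbor and self contributions cancel, reflecting that the descriptor does not change when both atoms are translated together.

```rust
/// Hypothetical sketch of the four gradient contributions of one pair for a
/// single moment order k, assuming vector = r_second - r_first. Returns
/// [d chi_first / d r_second, d chi_first / d r_first,
///  d chi_second / d r_first, d chi_second / d r_second],
/// each divided by the neighbor count of the corresponding center.
pub fn pair_gradient_contributions(
    vector: [f64; 3],
    k: i32,
    n_neighbors_first: f64,
    n_neighbors_second: f64,
) -> [[f64; 3]; 4] {
    let distance = (vector[0].powi(2) + vector[1].powi(2) + vector[2].powi(2)).sqrt();
    let factor = k as f64 * distance.powi(k - 2);
    let g = [vector[0] * factor, vector[1] * factor, vector[2] * factor];
    // apply a sign and divide by the neighbor count of the center
    let scale = |sign: f64, n: f64| [sign * g[0] / n, sign * g[1] / n, sign * g[2] / n];
    [
        scale(1.0, n_neighbors_first),   // first center, w.r.t. its neighbor
        scale(-1.0, n_neighbors_first),  // first center, w.r.t. itself
        scale(-1.0, n_neighbors_second), // second center, w.r.t. its neighbor
        scale(1.0, n_neighbors_second),  // second center, w.r.t. itself
    ]
}
```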

Here is the final implementation of the compute function:
fn compute(&mut self, systems: &mut [Box<dyn System>], descriptor: &mut TensorMap) -> Result<(), Error> {
    assert_eq!(descriptor.keys().names(), ["center_type", "neighbor_type"]);
    assert!(descriptor.keys().count() > 0);

    let do_positions_gradients = descriptor.block_by_id(0).gradient("positions").is_some();

    for (system_i, system) in systems.iter_mut().enumerate() {
        system.compute_neighbors(self.cutoff)?;
        let types = system.types()?;

        for pair in system.pairs()? {
            let first_block_id = descriptor.keys().position(&[
                types[pair.first].into(), types[pair.second].into(),
            ]);

            let first_sample_position = if let Some(block_id) = first_block_id {
                descriptor.block_by_id(block_id).samples().position(&[
                    system_i.into(), pair.first.into()
                ])
            } else {
                None
            };

            let second_block_id = descriptor.keys().position(&[
                types[pair.second].into(), types[pair.first].into(),
            ]);
            let second_sample_position = if let Some(block_id) = second_block_id {
                descriptor.block_by_id(block_id).samples().position(&[
                    system_i.into(), pair.second.into()
                ])
            } else {
                None
            };

            if first_sample_position.is_none() && second_sample_position.is_none() {
                continue;
            }

            let n_neighbors_first = system.pairs_containing(pair.first)?.len() as f64;
            let n_neighbors_second = system.pairs_containing(pair.second)?.len() as f64;

            if let Some(sample_i) = first_sample_position {
                let block_id = first_block_id.expect("we have a sample in this block");
                let mut block = descriptor.block_mut_by_id(block_id);
                let block = block.data_mut();
                let array = block.values.to_array_mut();

                for (property_i, [k]) in block.properties.iter_fixed_size().enumerate() {
                    let value = f64::powi(pair.distance, k.i32()) / n_neighbors_first;
                    array[[sample_i, property_i]] += value;
                }
            }

            if let Some(sample_i) = second_sample_position {
                let block_id = second_block_id.expect("we have a sample in this block");
                let mut block = descriptor.block_mut_by_id(block_id);
                let block = block.data_mut();
                let array = block.values.to_array_mut();

                for (property_i, [k]) in block.properties.iter_fixed_size().enumerate() {
                    let value = f64::powi(pair.distance, k.i32()) / n_neighbors_second;
                    array[[sample_i, property_i]] += value;
                }
            }

            if do_positions_gradients {
                let mut moment_gradients = Vec::new();
                for k in 0..=self.max_moment {
                    moment_gradients.push([
                        pair.vector[0] * k as f64 * f64::powi(pair.distance, (k as i32) - 2),
                        pair.vector[1] * k as f64 * f64::powi(pair.distance, (k as i32) - 2),
                        pair.vector[2] * k as f64 * f64::powi(pair.distance, (k as i32) - 2),
                    ]);
                }

                if let Some(sample_position) = first_sample_position {
                    let block_id = first_block_id.expect("we have a sample in this block");
                    let mut block = descriptor.block_mut_by_id(block_id);

                    let mut gradient = block.gradient_mut("positions").expect("missing gradient storage");
                    let gradient = gradient.data_mut();
                    let array = gradient.values.to_array_mut();

                    let gradient_wrt_second = gradient.samples.position(&[
                        sample_position.into(), system_i.into(), pair.second.into()
                    ]);
                    let gradient_wrt_self = gradient.samples.position(&[
                        sample_position.into(), system_i.into(), pair.first.into()
                    ]);

                    for (property_i, [k]) in gradient.properties.iter_fixed_size().enumerate() {
                        if let Some(sample_i) = gradient_wrt_second {
                            let grad = moment_gradients[k.usize()];
                            array[[sample_i, 0, property_i]] += grad[0] / n_neighbors_first;
                            array[[sample_i, 1, property_i]] += grad[1] / n_neighbors_first;
                            array[[sample_i, 2, property_i]] += grad[2] / n_neighbors_first;
                        }

                        if let Some(sample_i) = gradient_wrt_self {
                            let grad = moment_gradients[k.usize()];
                            array[[sample_i, 0, property_i]] -= grad[0] / n_neighbors_first;
                            array[[sample_i, 1, property_i]] -= grad[1] / n_neighbors_first;
                            array[[sample_i, 2, property_i]] -= grad[2] / n_neighbors_first;
                        }
                    }
                }

                if let Some(sample_position) = second_sample_position {
                    let block_id = second_block_id.expect("we have a sample in this block");
                    let mut block = descriptor.block_mut_by_id(block_id);

                    let mut gradient = block.gradient_mut("positions").expect("missing gradient storage");
                    let gradient = gradient.data_mut();
                    let array = gradient.values.to_array_mut();

                    let gradient_wrt_first = gradient.samples.position(&[
                        sample_position.into(), system_i.into(), pair.first.into()
                    ]);
                    let gradient_wrt_self = gradient.samples.position(&[
                        sample_position.into(), system_i.into(), pair.second.into()
                    ]);

                    for (property_i, [k]) in gradient.properties.iter_fixed_size().enumerate() {
                        if let Some(sample_i) = gradient_wrt_first {
                            let grad = moment_gradients[k.usize()];
                            array[[sample_i, 0, property_i]] -= grad[0] / n_neighbors_second;
                            array[[sample_i, 1, property_i]] -= grad[1] / n_neighbors_second;
                            array[[sample_i, 2, property_i]] -= grad[2] / n_neighbors_second;
                        }

                        if let Some(sample_i) = gradient_wrt_self {
                            let grad = moment_gradients[k.usize()];
                            array[[sample_i, 0, property_i]] += grad[0] / n_neighbors_second;
                            array[[sample_i, 1, property_i]] += grad[1] / n_neighbors_second;
                            array[[sample_i, 2, property_i]] += grad[2] / n_neighbors_second;
                        }
                    }
                }
            }
        }
    }
    return Ok(());
}
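The signs in the gradient loop follow from differentiating the definition of the descriptor. Write \(\vec{r}_{ij} = \vec{r}_j - \vec{r}_i\) for pair.vector. This assumes that pair.vector points from pair.first to pair.second, which is consistent with the signs used above. Also treat \(N_\text{neighbors}\) as constant, which holds as long as no atom crosses the cutoff:

\[\frac{\partial r_{ij}^k}{\partial \vec{r}_j} = k\, r_{ij}^{k-1} \frac{\vec{r}_{ij}}{r_{ij}} = k\, r_{ij}^{k-2}\, \vec{r}_{ij} = -\frac{\partial r_{ij}^k}{\partial \vec{r}_i}\]

\[\frac{\partial \braket{\alpha k | \mathcal{A}_i}}{\partial \vec{r}_j} = \frac{1}{N_\text{neighbors}}\, k\, r_{ij}^{k-2}\, \vec{r}_{ij}\, \delta_{\alpha, \alpha_j}\]

The vector stored in moment_gradients[k] is \(k\, r_{ij}^{k-2}\, \vec{r}_{ij}\). After division by n_neighbors_first, it is added to the gradient with respect to pair.second and subtracted from the gradient with respect to pair.first, the center itself. When pair.second is the center, the relevant vector is \(\vec{r}_{ji} = -\vec{r}_{ij}\), which explains the flipped signs in the last part of the loop.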

Registering the new calculator

Now that we are done with the code for this calculator, we need to make it available to users. The entry point for users is the Calculator struct, which needs to be constructed from a calculator name and hyper-parameters in JSON format.

When the user calls Calculator::new("calculator_name", "{\"hyper_parameters\": 1}"), rascaline looks for "calculator_name" in the global calculator registry, and tries to create an instance using the hyper-parameters. In order to make our calculator available to all users, we need to add it to this registry, in rascaline/src/calculator.rs. The registry looks like this:

static REGISTERED_CALCULATORS: Lazy<BTreeMap<&'static str, CalculatorCreator>> = Lazy::new(|| {
    let mut map = BTreeMap::new();
    add_calculator!(map, "atomic_composition", AtomicComposition);
    add_calculator!(map, "dummy_calculator", DummyCalculator);
    add_calculator!(map, "neighbor_list", NeighborList);
    add_calculator!(map, "sorted_distances", SortedDistances);

    add_calculator!(map, "spherical_expansion_by_pair", SphericalExpansionByPair, SphericalExpansionParameters);
    add_calculator!(map, "spherical_expansion", SphericalExpansion, SphericalExpansionParameters);
    add_calculator!(map, "soap_radial_spectrum", SoapRadialSpectrum, RadialSpectrumParameters);
    add_calculator!(map, "soap_power_spectrum", SoapPowerSpectrum, PowerSpectrumParameters);

    add_calculator!(map, "lode_spherical_expansion", LodeSphericalExpansion, LodeSphericalExpansionParameters);
    return map;
});

add_calculator! is a local macro that takes three or four arguments: the registry itself (a BTreeMap), the calculator name, the struct implementing CalculatorBase, and optionally the struct holding the hyper-parameters used to create it. In our case, we want the three-argument version: add_calculator!(map, "geometric_moments", GeometricMoments);. You’ll also need to bring the new calculator into scope with a use item.
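To make the lookup mechanism concrete, here is a self-contained toy model of such a registry. It is not rascaline's code. ToyCalculator, Creator, GeometricMoments::create and new_calculator are hypothetical names. The parameters are kept as a raw string instead of being deserialized from JSON, and only the three-argument form of add_calculator! is modelled.

```rust
use std::collections::BTreeMap;

/// Hypothetical stand-in for `CalculatorBase`, reduced to a single method.
pub trait ToyCalculator {
    fn describe(&self) -> String;
}

/// Function building a calculator from its (raw, not deserialized) parameters.
pub type Creator = fn(&str) -> Result<Box<dyn ToyCalculator>, String>;

/// Simplified three-argument form: registry, name, type with a `create` function.
macro_rules! add_calculator {
    ($map:expr, $name:literal, $calculator:ident) => {
        $map.insert($name, $calculator::create as Creator);
    };
}

pub struct GeometricMoments {
    parameters: String,
}

impl GeometricMoments {
    fn create(parameters: &str) -> Result<Box<dyn ToyCalculator>, String> {
        Ok(Box::new(GeometricMoments { parameters: parameters.to_owned() }))
    }
}

impl ToyCalculator for GeometricMoments {
    fn describe(&self) -> String {
        format!("geometric_moments({})", self.parameters)
    }
}

/// Builds the registry, mapping calculator names to their creators.
pub fn registry() -> BTreeMap<&'static str, Creator> {
    let mut map = BTreeMap::new();
    add_calculator!(map, "geometric_moments", GeometricMoments);
    map
}

/// Mimics the lookup done by `Calculator::new`: find the name, then call the creator.
pub fn new_calculator(name: &str, parameters: &str) -> Result<Box<dyn ToyCalculator>, String> {
    let creator = registry()
        .get(name)
        .copied()
        .ok_or_else(|| format!("unknown calculator '{}'", name))?;
    creator(parameters)
}
```

In this sketch, the registry stores plain function pointers. Any calculator that provides a constructor with the common signature can therefore be registered with a single macro call.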

Additionally, you may want to add a convenience class in Python for our new calculator. For this, you can add a class like this to python/rascaline/calculators.py:

class GeometricMoments(CalculatorBase):
    """TODO: documentation"""

    def __init__(self, cutoff, max_moment):
        parameters = {
            "cutoff": cutoff,
            "max_moment": max_moment,
        }
        super().__init__("geometric_moments", parameters)


#############################################################################

# this allows using the calculator like this
from rascaline.calculators import GeometricMoments
calculator = GeometricMoments(cutoff=3.5, max_moment=6)

# instead of
from rascaline.calculators import CalculatorBase
calculator = CalculatorBase(
    "geometric_moments",
    {"cutoff": 3.5, "max_moment": 6},
)

We have now finished our implementation of the geometric moments calculator! In the next steps, we’ll see how to write tests to ensure the calculator works and how to write some documentation for it.

Testing the new calculator

Before we can release our new calculator into the world, we need to make sure it currently behaves as intended, and that we have a way to ensure it keeps behaving as intended as the code changes. To achieve both goals, rascaline uses unit tests and regression tests. Unit tests are written in the same file as the main part of the code, in a tests module, and are expected to test high-level properties of the code. For example, unit tests allow us to check that the computed gradients match the derivatives of the computed values, or that the right values are computed when the user requests a subset of samples and properties. Regression tests, on the other hand, check the exact values produced by a given calculator on a specific system, and that these values stay the same as we modify the code, for example when optimizing it. These regression tests live in the rascaline/tests folder, with one file per test.

This tutorial will focus on unit tests and introduce some test utilities that apply to all calculators. To write regression tests, take inspiration from existing tests such as the spherical-expansion test. Each Rust file in rascaline/tests is associated with a Python file in rascaline/tests/data that generates the reference values checked by the regression test, so you’ll need to write one of these as well.
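At its core, a regression check compares freshly computed values against stored reference values with a relative tolerance. The sketch below illustrates only that comparison. first_mismatch is a hypothetical helper written for this example, not part of rascaline, and loading the reference data is left out.

```rust
/// Hypothetical regression check: returns the index and values of the first
/// entry where `computed` differs from `reference` by more than `max_relative`
/// times the larger of the two magnitudes, or `None` if all entries agree.
pub fn first_mismatch(computed: &[f64], reference: &[f64], max_relative: f64) -> Option<(usize, f64, f64)> {
    // a change in the number of values is always a regression
    assert_eq!(computed.len(), reference.len(), "number of values changed");
    computed
        .iter()
        .zip(reference)
        .enumerate()
        .find_map(|(i, (&value, &expected))| {
            let scale = value.abs().max(expected.abs());
            if (value - expected).abs() > max_relative * scale {
                Some((i, value, expected))
            } else {
                None
            }
        })
}
```

Reporting the first mismatching entry, rather than a simple boolean, makes it easier to see which part of the output changed when a regression test fails.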

Testing properties

If this is the first time you are writing tests in Rust, you should read the corresponding chapter in the official Rust book for a great introduction to this subject.

Depending on the representation you are working with, you should write tests that check its fundamental properties. For example, in our geometric moments representation, the zeroth moment (k = 0) should always be the number of neighbors of the block's neighbor type divided by the total number of neighbors of the center. A test checking this property would look like this:

#[cfg(test)]
mod tests {
    use super::*;
    use crate::Calculator;
    use crate::systems::test_utils::test_systems;

    use approx::assert_relative_eq;

    use ndarray::array;

    #[test]
    fn zeroth_moment() {
        // Create a Calculator wrapping a GeometricMoments instance
        let mut calculator = Calculator::from(Box::new(GeometricMoments{
            cutoff: 2.5,
            max_moment: 0,
        }) as Box<dyn CalculatorBase>);

        // create a bunch of systems in a format compatible with `calculator.compute`.
        // Available systems include "water" and "methane" for the corresponding
        // molecules, and "CH" for a basic two-atom system.
        let mut systems = test_systems(&["water", "CH"]);

        // run the calculation using default parameters
        let descriptor = calculator.compute(&mut systems, Default::default()).unwrap();

        // check the results
        assert_eq!(*descriptor.keys(), Labels::new(
            ["center_type", "neighbor_type"],
            &[[-42, 1], [1, -42], [1, 1], [1, 6], [6, 1]]
        ));

        let expected_properties = Labels::new(["k"], &[[0]]);

        /**********************************************************************/
        // O center, H neighbor
        let block = &descriptor.block_by_id(0);
        assert_eq!(block.samples(), Labels::new(
            ["system", "atom"],
            &[[0, 0]]
        ));

        assert_eq!(block.properties(), expected_properties);

        assert_relative_eq!(block.values().as_array(), &array![[2.0 / 2.0]].into_dyn());

        /**********************************************************************/
        // H center, O neighbor
        let block = &descriptor.block_by_id(1);
        assert_eq!(block.samples(), Labels::new(
            ["system", "atom"],
            &[[0, 1], [0, 2]]
        ));

        assert_eq!(block.properties(), expected_properties);

        assert_relative_eq!(block.values().as_array(), &array![[1.0 / 2.0], [1.0 / 2.0]].into_dyn());

        /**********************************************************************/
        // H center, H neighbor
        let block = &descriptor.block_by_id(2);
        assert_eq!(block.samples(), Labels::new(
            ["system", "atom"],
            &[[0, 1], [0, 2]]
        ));

        assert_eq!(block.properties(), expected_properties);

        assert_relative_eq!(block.values().as_array(), &array![[1.0 / 2.0], [1.0 / 2.0]].into_dyn());

        /**********************************************************************/
        // H center, C neighbor
        let block = &descriptor.block_by_id(3);
        assert_eq!(block.samples(), Labels::new(
            ["system", "atom"],
            &[[1, 1]]
        ));

        assert_eq!(block.properties(), expected_properties);

        assert_relative_eq!(block.values().as_array(), &array![[1.0 / 1.0]].into_dyn());

        /**********************************************************************/
        // C center, H neighbor
        let block = &descriptor.block_by_id(4);
        assert_eq!(block.samples(), Labels::new(
            ["system", "atom"],
            &[[1, 0]]
        ));

        assert_eq!(block.properties(), expected_properties);

        assert_relative_eq!(block.values().as_array(), &array![[1.0 / 1.0]].into_dyn());
    }
}

The rascaline::systems::test_utils::test_systems function provides a couple of very simple systems to be used for testing.
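The following standalone sketch shows where the expected zeroth moments in this test come from. It evaluates the definition of \(\braket{\alpha k | \mathcal{A}_i}\) for a single center, given the type and distance of each neighbor. The moments function is illustrative only and is not rascaline's API. The distances in the test below are not the actual geometry of the water test system, and -42 is used for oxygen only to mirror the keys checked above.

```rust
use std::collections::BTreeMap;

/// Evaluates <alpha k | A_i> = 1/N sum_j r_ij^k delta(alpha, alpha_j) for one
/// center. `neighbors` lists (atomic type, distance) for every neighbor inside
/// the cutoff; the result maps (alpha, k) to the value of the moment.
pub fn moments(neighbors: &[(i32, f64)], max_moment: i32) -> BTreeMap<(i32, i32), f64> {
    let n_neighbors = neighbors.len() as f64;
    let mut result = BTreeMap::new();
    for &(neighbor_type, distance) in neighbors {
        for k in 0..=max_moment {
            // each neighbor only contributes to the entries of its own type
            *result.entry((neighbor_type, k)).or_insert(0.0) += distance.powi(k) / n_neighbors;
        }
    }
    result
}
```

Because \(r_{ij}^0 = 1\), the zeroth moment only counts neighbors and does not depend on their distances. This is why the expected values in the test can be written directly as fractions such as 2.0 / 2.0 or 1.0 / 2.0.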

Testing partial calculations

One property that all calculators must respect is that computing only a subset of samples or properties should give the same values as computing everything. Rascaline provides a function (calculators::tests_utils::compute_partial) that checks this for you, simplifying the tests a bit. Here is how one can use it with the GeometricMoments calculator:

    #[test]
    fn compute_partial() {
        let calculator = Calculator::from(Box::new(GeometricMoments{
            cutoff: 2.5,
            max_moment: 6,
        }) as Box<dyn CalculatorBase>);

        let mut systems = test_systems(&["water", "methane"]);

        // build a list of samples to compute
        let samples = Labels::new(
            ["system", "atom"],
            &[[0, 1], [0, 2], [1, 0], [1, 2]]
        );

        // create some properties. There is no need to order them in the same way
        // as the default calculator
        let properties = Labels::new(["k"], &[[2], [1], [5]]);

        // Some keys (more than the calculator would produce by default)
        let keys = Labels::new(
            ["center_type", "neighbor_type"],
            &[[-42, 1], [1, 8], [1, -42], [8, 8], [1, 1], [1, 6], [6, 1]]
        );

        // this function will check that selecting keys/samples/properties will
        // not change the result of the calculation
        crate::calculators::tests_utils::compute_partial(
            calculator, &mut systems, &keys, &samples, &properties
        );
    }
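The property checked by compute_partial can be illustrated on a toy, single-type version of the descriptor. Computing only some samples and properties, in any order, must give the same numbers as slicing the full result. compute and check_partial below are hypothetical helpers for this illustration. Rascaline's compute_partial works on full descriptors and also covers key selection.

```rust
/// Toy, single-type descriptor: one row per center (given by the distances to
/// its neighbors) and one column per requested moment order `k`.
pub fn compute(centers: &[Vec<f64>], orders: &[i32]) -> Vec<Vec<f64>> {
    centers
        .iter()
        .map(|distances| {
            let n_neighbors = distances.len() as f64;
            orders
                .iter()
                .map(|&k| distances.iter().map(|r| r.powi(k)).sum::<f64>() / n_neighbors)
                .collect::<Vec<f64>>()
        })
        .collect()
}

/// Checks that computing only the selected samples and orders (in any order)
/// gives the same values as computing everything and picking entries out.
pub fn check_partial(centers: &[Vec<f64>], max_moment: i32, samples: &[usize], orders: &[i32]) -> bool {
    let all_orders: Vec<i32> = (0..=max_moment).collect();
    let full = compute(centers, &all_orders);

    let selected: Vec<Vec<f64>> = samples.iter().map(|&s| centers[s].clone()).collect();
    let partial = compute(&selected, orders);

    samples.iter().zip(&partial).all(|(&sample, row)| {
        orders.iter().zip(row).all(|(&k, &value)| {
            let expected = full[sample][k as usize];
            (value - expected).abs() <= 1e-12 * expected.abs().max(1.0)
        })
    })
}
```

In this toy the property holds by construction. In a real calculator it can break if values are written to the wrong column when the requested properties are not in the default order. This is why the compute loop above reads k from each property label instead of using the column index.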

Testing gradients

If a calculator can compute gradients, it is a good idea to check that the gradients match a finite-difference approximation of the derivatives. Rascaline provides calculators::tests_utils::finite_differences_positions to help check this for gradients with respect to atomic positions.

    #[test]
    fn finite_differences() {
        let calculator = Calculator::from(Box::new(GeometricMoments{
            cutoff: 2.5,
            max_moment: 7,
        }) as Box<dyn CalculatorBase>);

        let system = crate::systems::test_utils::test_system("water");

        let options = crate::calculators::tests_utils::FinalDifferenceOptions {
            displacement: 1e-6,
            max_relative: 1e-6,
            epsilon: 1e-9,
        };

        crate::calculators::tests_utils::finite_differences_positions(calculator, &system, options);
    }
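The sketch below gives a rough picture of what a finite-difference check does, reduced to a single pair and a single moment order. It compares the analytic gradient \(k\, r_{ij}^{k-2}\, \vec{r}_{ij}\) used in moment_gradients with a central difference of \(r_{ij}^k\) with respect to the neighbor position. check_gradient and its helpers are hypothetical. Their parameters only loosely mirror the displacement, max_relative and epsilon fields above, and rascaline's finite_differences_positions may combine them differently.

```rust
/// Returns r_ij = r_j - r_i and its norm.
fn pair_vector(ri: [f64; 3], rj: [f64; 3]) -> ([f64; 3], f64) {
    let v = [rj[0] - ri[0], rj[1] - ri[1], rj[2] - ri[2]];
    let r = (v[0] * v[0] + v[1] * v[1] + v[2] * v[2]).sqrt();
    (v, r)
}

/// Value of r_ij^k.
fn moment(ri: [f64; 3], rj: [f64; 3], k: i32) -> f64 {
    pair_vector(ri, rj).1.powi(k)
}

/// Analytic gradient of r_ij^k with respect to r_j: k r_ij^(k-2) r_ij,
/// the same expression as `moment_gradients` in the compute loop.
fn moment_gradient(ri: [f64; 3], rj: [f64; 3], k: i32) -> [f64; 3] {
    let (v, r) = pair_vector(ri, rj);
    let factor = k as f64 * r.powi(k - 2);
    [v[0] * factor, v[1] * factor, v[2] * factor]
}

/// Compares the analytic gradient to a central finite difference along each
/// Cartesian direction, accepting a component if the absolute difference is
/// below `epsilon` or the relative difference is below `max_relative`.
pub fn check_gradient(ri: [f64; 3], rj: [f64; 3], k: i32, displacement: f64, max_relative: f64, epsilon: f64) -> bool {
    let analytic = moment_gradient(ri, rj, k);
    (0..3usize).all(|x| {
        let mut plus = rj;
        plus[x] += displacement;
        let mut minus = rj;
        minus[x] -= displacement;
        let finite = (moment(ri, plus, k) - moment(ri, minus, k)) / (2.0 * displacement);
        let difference = (finite - analytic[x]).abs();
        difference <= epsilon || difference <= max_relative * finite.abs().max(analytic[x].abs())
    })
}
```

You can check the center in the same way by applying the displacement to ri instead. It should give the opposite vector, matching the -= updates in the compute loop.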

Documenting the new calculator

Warning

Work in progress

This section of the documentation is not yet written