Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 20 additions & 19 deletions src/ML-METATOMIC/fix_metatomic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
------------------------------------------------------------------------- */
#include "metatomic_types.h"
#include "metatomic_system.h"
#include "metatomic_units.h"

#include "fix_metatomic.h"

Expand Down Expand Up @@ -62,25 +63,15 @@ FixMetatomic::FixMetatomic(LAMMPS *lmp, int narg, char **arg): Fix(lmp, narg, ar

// Determine unit system for the ML model
// Currently only 'metal' units are fully supported for momenta
std::string energy_unit;
std::string length_unit;
if (strcmp(update->unit_style, "metal") == 0) {
length_unit = "angstrom";
this->momentum_conversion_factor = 10.1805057179 / 1000.0;
} else if (strcmp(update->unit_style, "real") == 0) {
length_unit = "angstrom";
this->momentum_conversion_factor = 10.1805057179;
} else if (strcmp(update->unit_style, "si") == 0) {
length_unit = "m";
this->momentum_conversion_factor = 10.1805057179 / 1.6605390666e-22;
} else {
if (strcmp(update->unit_style, "lj") == 0) {
error->all(FLERR, "unsupported units '{}' for fix metatomic", update->unit_style);
}

// For now, only metal units are fully tested and supported
if (strcmp(update->unit_style, "metal") != 0) {
error->all(FLERR, "fix metatomic currently only supports 'metal' units");
}
std::string energy_unit= unit_map.at("energy").at(update->unit_style);
std::string length_unit = unit_map.at("position").at(update->unit_style);
std::string mass_unit = unit_map.at("mass").at(update->unit_style);
std::string velocity_unit = unit_map.at("velocity").at(update->unit_style);
std::string momentum_unit = mass_unit + "*" + velocity_unit;
this->momentum_conversion_factor = metatomic_torch::unit_conversion_factor(momentum_unit, "(u*eV)^(1/2)");

if (narg < 4) {
error->all(FLERR,
Expand Down Expand Up @@ -445,16 +436,26 @@ void FixMetatomic::initial_integrate(int /*vflag*/) {
error->all(FLERR, "the model requested an unsupported dtype '{}'", mta_data->capabilities->dtype());
}

// deal with the model requested inputs
std::map<std::string, metatomic_torch::ModelOutput> input_holders;
auto requested_inputs = mta_data->model->run_method("requested_inputs").toGenericDict();
for (const auto& entry : requested_inputs) {
input_holders.emplace(
entry.key().toStringRef(),
entry.value().toCustomClass<metatomic_torch::ModelOutputHolder>()
);
}
// transform from LAMMPS to metatomic System
auto system = this->system_adaptor->system_from_lmp(
mta_list,
static_cast<bool>(vflag_global),
dtype,
mta_data->device
mta_data->device,
input_holders
);

// add the required additional inputs
this->system_adaptor->add_masses(system, 1.0);
this->system_adaptor->add_masses(system, metatomic_torch::unit_conversion_factor(unit_map.at("mass").at(update->unit_style), "u"));
this->system_adaptor->add_momenta(system, this->momentum_conversion_factor);

// Configure selected atoms for evaluation
Expand Down
147 changes: 97 additions & 50 deletions src/ML-METATOMIC/metatomic_system.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@
------------------------------------------------------------------------- */
#include "metatomic_system.h"
#include "metatomic_timer.h"
#include "metatomic_units.h"

#include "atom.h"
#include "comm.h"
#include "domain.h"
#include "error.h"
#include "update.h"

#include "neigh_list.h"

#include <map>
#include <string>

#include <metatensor/torch.hpp>
Expand Down Expand Up @@ -206,6 +209,56 @@ static std::array<int32_t, 3> cell_shifts(
return {shift_a, shift_b, shift_c};
}

metatensor_torch::TensorMap LAMMPS_NS::make_per_atom_tensormap(
const torch::Tensor& values,
const torch::ScalarType& dtype,
const torch::Device& device,
const std::string& property_name,
const std::vector<std::string>& component_names
) {
assert (values.dim() == static_cast<int64_t>(component_names.size() + 1));

auto n_atoms = values.size(0);
auto label_tensor_options = torch::TensorOptions().dtype(torch::kInt32).device(device);

auto keys = metatensor_torch::LabelsHolder::single()->to(device);
auto samples_values = torch::column_stack({
torch::zeros(n_atoms, label_tensor_options).unsqueeze(1),
torch::arange(n_atoms, label_tensor_options).unsqueeze(1)
});
auto samples = torch::make_intrusive<metatensor_torch::LabelsHolder>(
std::vector<std::string>{"system", "atom"},
samples_values,
metatensor::assume_unique{}
);

auto components = std::vector<metatensor_torch::Labels>{};
for (size_t axis = 0; axis < component_names.size(); axis++) {
auto component_values = torch::arange(values.size(axis + 1), label_tensor_options).unsqueeze(1);
auto component = torch::make_intrusive<metatensor_torch::LabelsHolder>(
std::vector<std::string>{component_names[axis]},
component_values
);
components.push_back(component);
}

auto properties = torch::make_intrusive<metatensor_torch::LabelsHolder>(
std::vector<std::string>{property_name},
torch::tensor({{0}}, label_tensor_options)
);
auto block = torch::make_intrusive<metatensor_torch::TensorBlockHolder>(
values.to(dtype).to(device).unsqueeze(-1),
samples,
components,
properties
);

return torch::make_intrusive<metatensor_torch::TensorMapHolder>(
keys,
std::vector<metatensor_torch::TensorBlock>{block}
);
}

void MetatomicSystemAdaptor::guess_periodic_ghosts() {
auto _ = MetatomicTimer("identifying periodic ghosts");
auto total_n_atoms = atom->nlocal + atom->nghost;
Expand Down Expand Up @@ -520,7 +573,8 @@ metatomic_torch::System MetatomicSystemAdaptor::system_from_lmp(
NeighList* list,
bool do_virial,
torch::ScalarType dtype,
torch::Device device
torch::Device device,
const std::map<std::string, torch::intrusive_ptr<metatomic_torch::ModelOutputHolder>>& inputs
) {
auto _ = MetatomicTimer("creating System from LAMMPS data");

Expand Down Expand Up @@ -589,6 +643,21 @@ metatomic_torch::System MetatomicSystemAdaptor::system_from_lmp(

this->setup_neighbors(system, list);

for (const auto& [property, input]: inputs) {
const auto& property_name = property.c_str();
const auto& unit = input->unit().c_str();
if (strcmp(property_name, "masses") == 0) {
add_masses(system, metatomic_torch::unit_conversion_factor(unit_map.at("mass").at(update->unit_style), unit));
} else if (strcmp(property_name, "momenta") == 0) {
const auto& momentum_unit = unit_map.at("mass").at(update->unit_style) + "*" + unit_map.at("velocity").at(update->unit_style);
add_momenta(system, metatomic_torch::unit_conversion_factor(momentum_unit, unit));
} else if (strcmp(property_name, "velocities") == 0) {
add_velocities(system, metatomic_torch::unit_conversion_factor(unit_map.at("velocity").at(update->unit_style), unit));
} else {
error->all(FLERR, "compute metatomic: the model requested an unsupported additional input of '{}'", property_name);
}
}

return system;
}

Expand Down Expand Up @@ -628,29 +697,7 @@ void MetatomicSystemAdaptor::add_masses(metatomic_torch::System& system, double

masses = masses.index_select(0, mta_to_lmp_tensor);
masses = masses * unit_conversion;

auto keys = metatensor_torch::LabelsHolder::single()->to(device);
auto label_tensor_options = torch::TensorOptions().dtype(torch::kInt32).device(device);
auto samples_values = torch::column_stack({
torch::zeros(system->size(), label_tensor_options).unsqueeze(1),
torch::arange(system->size(), label_tensor_options).unsqueeze(1)
});
auto samples = torch::make_intrusive<metatensor_torch::LabelsHolder>(
std::vector<std::string>{"system","atom"},
samples_values,
metatensor::assume_unique{}
);
auto properties = metatensor_torch::LabelsHolder::single()->to(device);
auto block = torch::make_intrusive<metatensor_torch::TensorBlockHolder>(
masses.to(dtype).to(device).unsqueeze(-1), // add property dimension
samples,
std::vector<metatensor_torch::Labels>{},
properties
);
auto tensor = torch::make_intrusive<metatensor_torch::TensorMapHolder>(
keys,
std::vector<metatensor_torch::TensorBlock>{block}
);
auto tensor = make_per_atom_tensormap(masses, dtype, device, "mass");

system->add_data("masses", tensor, /*override=*/true);
}
Expand Down Expand Up @@ -685,37 +732,37 @@ void MetatomicSystemAdaptor::add_momenta(metatomic_torch::System& system, double

momenta = momenta.index_select(0, mta_to_lmp_tensor);
momenta = momenta * unit_conversion;
auto tensor = make_per_atom_tensormap(momenta, dtype, device, "momentum", {"xyz"});

auto keys = metatensor_torch::LabelsHolder::single()->to(device);
system->add_data("momenta", tensor, /*override=*/true);
}

auto label_tensor_options = torch::TensorOptions().dtype(torch::kInt32).device(device);
auto samples_values = torch::column_stack({
torch::zeros(system->size(), label_tensor_options).unsqueeze(1),
torch::arange(system->size(), label_tensor_options).unsqueeze(1)
});
auto samples = torch::make_intrusive<metatensor_torch::LabelsHolder>(
std::vector<std::string>{"system","atom"},
samples_values,
metatensor::assume_unique{}
);
void MetatomicSystemAdaptor::add_velocities(metatomic_torch::System& system, double unit_conversion) {
double** v = atom->v;

auto component_values = torch::arange(3, label_tensor_options).unsqueeze(1);
auto component = torch::make_intrusive<metatensor_torch::LabelsHolder>(
std::vector<std::string>{"xyz"}, component_values
);
auto total_n_atoms = atom->nlocal + atom->nghost;

auto properties = metatensor_torch::LabelsHolder::single()->to(device);
auto device = system->device();
auto dtype = system->scalar_type();

auto block = torch::make_intrusive<metatensor_torch::TensorBlockHolder>(
momenta.to(dtype).to(device).unsqueeze(-1),
samples,
std::vector<metatensor_torch::Labels>{component},
properties
);
auto tensor = torch::make_intrusive<metatensor_torch::TensorMapHolder>(
keys,
std::vector<metatensor_torch::TensorBlock>{block}
auto mta_to_lmp_tensor = torch::from_blob(
mta_to_lmp.data(),
{static_cast<int64_t>(mta_to_lmp.size())},
torch::TensorOptions().dtype(torch::kInt).device(torch::kCPU)
);

system->add_data("momenta", tensor, /*override=*/true);
// gather momenta (per-atom) in a tensor and ship to device
torch::Tensor velocities = torch::zeros({total_n_atoms, 3}, torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCPU));
auto velocities_accessor = velocities.accessor<double, 2>();
for (int i=0; i<total_n_atoms; i++) {
velocities_accessor[i][0] = v[i][0];
velocities_accessor[i][1] = v[i][1];
velocities_accessor[i][2] = v[i][2];
}

velocities = velocities.index_select(0, mta_to_lmp_tensor);
velocities = velocities * unit_conversion;
auto tensor = make_per_atom_tensormap(velocities, dtype, device, "velocity", {"xyz"});

system->add_data("velocities", tensor, /*override=*/true);
}
17 changes: 16 additions & 1 deletion src/ML-METATOMIC/metatomic_system.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include <vector>
#include <array>
#include <string>

#include "pointers.h"
#include "pair.h"
Expand Down Expand Up @@ -57,6 +58,16 @@ struct MetatomicNeighborsData {
std::vector<std::array<float, 3>> distances_f32;
};

// Build a per-atom TensorMap from values shaped as [atoms, components...].
// A property dimension is appended automatically.
metatensor_torch::TensorMap make_per_atom_tensormap(
const torch::Tensor& values,
const torch::ScalarType& dtype,
const torch::Device& device,
const std::string& property_name,
const std::vector<std::string>& component_names = {}
);

class MetatomicSystemAdaptor : public Pointers {
public:
MetatomicSystemAdaptor(LAMMPS *lmp, MetatomicSystemOptions options);
Expand All @@ -72,7 +83,8 @@ class MetatomicSystemAdaptor : public Pointers {
NeighList* list,
bool do_virial,
torch::ScalarType dtype,
torch::Device device
torch::Device device,
const std::map<std::string, torch::intrusive_ptr<metatomic_torch::ModelOutputHolder>>& inputs = {}
);

// Add masses as extra data to this system, only for atoms which are not
Expand All @@ -81,6 +93,9 @@ class MetatomicSystemAdaptor : public Pointers {
// Add momenta as extra data to this system, only for atoms which are not
// periodic images of other atoms
virtual void add_momenta(metatomic_torch::System& system, double unit_conversion);
// Add velocities as extra data to this system, only for atoms which are not
// periodic images of other atoms
virtual void add_velocities(metatomic_torch::System& system, double unit_conversion);

// Explicit strain for virial calculations. This uses the same dtype/device
// as LAMMPS data (positions, …)
Expand Down
Loading