Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions .github/workflows/ci_serve.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
name: Build and test serve mode
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
on: [push, pull_request]

env:
SCCACHE_GHA_ENABLED: "true"

jobs:
build_serve:
runs-on: ${{ matrix.os }}
name: test (${{ matrix.os }})
strategy:
fail-fast: false
matrix:
os: [ubuntu-22.04]
steps:
- uses: actions/checkout@v4
- uses: prefix-dev/setup-pixi@v0.8.10
with:
cache: true
cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
activate-environment: true
environments: >-
serve
- name: Configure sccache
uses: actions/github-script@v7
with:
script: |
core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || '');
core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || '');
- name: Build with serve mode
shell: pixi run bash -e {0}
run: |
meson setup bbdir \
--prefix=$CONDA_PREFIX \
--libdir=lib \
-Dwith_serve=true \
-Dwith_tests=true
meson compile -C bbdir
- name: Run serve mode tests
shell: pixi run bash -e {0}
run: |
meson test -C bbdir test_serve_spec
90 changes: 89 additions & 1 deletion client/CommandLine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@
#include "Potential.h"
#include "version.h"

#ifdef WITH_SERVE_MODE
#include "ServeMode.h"
#endif

#include <cstdlib>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <memory>
Expand Down Expand Up @@ -62,7 +67,28 @@ void commandLine(int argc, char **argv) {
"t,tolerance", "Distance tolerance",
cxxopts::value<double>()->default_value("0.1"))(
"p,potential", "The potential (e.g. qsc, lj, eam_al)",
cxxopts::value<std::string>())("h,help", "Print usage");
cxxopts::value<std::string>())
#ifdef WITH_SERVE_MODE
("serve",
"Serve potential(s) over rgpot Cap'n Proto RPC. "
"Spec: 'potential:port' or 'pot1:port1,pot2:port2'",
cxxopts::value<std::string>())(
"serve-host", "Host to bind RPC server(s) to",
cxxopts::value<std::string>()->default_value("localhost"))(
"serve-port", "Port for single-potential serve mode (used with -p)",
cxxopts::value<uint16_t>()->default_value("12345"))(
"replicas", "Number of replicated server instances (used with -p)",
cxxopts::value<size_t>()->default_value("1"))(
"gateway",
"Run a single gateway port backed by N pool instances "
"(use with -p and --replicas)",
cxxopts::value<bool>()->default_value("false"))(
"config",
"Config file for potential parameters (INI format, "
"e.g. [Metatomic] model_path=model.pt)",
cxxopts::value<std::string>())
#endif
("h,help", "Print usage");

try {
auto result = options.parse(argc, argv);
Expand Down Expand Up @@ -120,6 +146,68 @@ void commandLine(int argc, char **argv) {
exit(2);
}

#ifdef WITH_SERVE_MODE
// Load config file if provided (for potential-specific parameters
// like model_path, device, length_unit, etc.)
if (result.count("config")) {
auto config_path = result["config"].as<std::string>();
std::ifstream config_file(config_path);
if (!config_file.is_open()) {
std::cerr << "Cannot open config file: " << config_path << std::endl;
exit(2);
}
params.load(config_path);
}

// Handle --serve mode (does not require a con file)
if (result.count("serve")) {
auto spec = result["serve"].as<std::string>();
auto endpoints = parseServeSpec(spec);
if (endpoints.empty()) {
std::cerr << "No valid serve endpoints in spec: " << spec << std::endl;
exit(2);
}
serveMultiple(endpoints, params);
exit(0);
}

// Handle -p with serve flags (single potential serve mode)
if (pflag && !sflag && !mflag && !cflag &&
(result.count("serve-port") || result.count("replicas") ||
result.count("gateway"))) {
for (auto &ch : potential) {
ch = tolower(ch);
}
params.potential_options.potential =
magic_enum::enum_cast<PotType>(potential,
magic_enum::case_insensitive)
.value_or(PotType::UNKNOWN);
auto host = result["serve-host"].as<std::string>();
auto port = result["serve-port"].as<uint16_t>();
auto reps = result["replicas"].as<size_t>();
bool gw = result["gateway"].as<bool>();

if (gw) {
serveGateway(params, host, port, reps);
} else if (reps > 1) {
serveReplicated(params, host, port, reps);
} else {
serveMode(params, host, port);
}
exit(0);
}

// Config-driven serve (no -p or --serve, just --config with [Serve])
if (!pflag && !sflag && !mflag && !cflag && result.count("config") &&
!result.count("serve") &&
(!params.serve_options.endpoints.empty() ||
params.serve_options.gateway_port > 0 ||
params.serve_options.replicas > 1)) {
serveFromConfig(params);
exit(0);
}
#endif

if (!cflag) {
for (auto &ch : potential) {
ch = tolower(ch);
Expand Down
20 changes: 20 additions & 0 deletions client/Parameters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,13 @@ Parameters::Parameters() {
bgsd_options.grad2energy_convergence = 0.000001;
bgsd_options.grad2force_convergence = 0.0001;

// [Serve] //
serve_options.host = "localhost";
serve_options.port = 12345;
serve_options.replicas = 1;
serve_options.gateway_port = 0;
serve_options.endpoints = "";

// [CatLearn] //
// No reasonable default for catlearn_options.path
catlearn_options.model = "gp";
Expand Down Expand Up @@ -884,6 +891,19 @@ int Parameters::load(FILE *file) {
_variant.energy_uncertainty =
ini.GetValue("Metatomic", "variant_energy_uncertainty", "");
}
// [Serve]
if (ini.FindKey("Serve") != -1) {
serve_options.host = ini.GetValue("Serve", "host", serve_options.host);
serve_options.port = static_cast<uint16_t>(
ini.GetValueL("Serve", "port", serve_options.port));
serve_options.replicas = static_cast<size_t>(
ini.GetValueL("Serve", "replicas", serve_options.replicas));
serve_options.gateway_port = static_cast<uint16_t>(
ini.GetValueL("Serve", "gateway_port", serve_options.gateway_port));
serve_options.endpoints =
ini.GetValue("Serve", "endpoints", serve_options.endpoints);
}

// GP_NEB only
gp_surrogate_options.linear_path_always =
ini.GetValueB("Surrogate", "gp_linear_path_always",
Expand Down
9 changes: 9 additions & 0 deletions client/Parameters.h
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,15 @@ class Parameters {
double grad2force_convergence;
} bgsd_options;

// [Serve] //
struct serve_options_t {
string host;
uint16_t port;
size_t replicas;
uint16_t gateway_port; // 0 = disabled
string endpoints; // "pot:port,pot:host:port,..." spec string
} serve_options;

// [Debug] //
struct debug_options_t {
bool write_movies;
Expand Down
Loading
Loading