Skip to content

Commit b01cdc5

Browse files
committed
feat: add SABRInterpolation binding with calibration support
1 parent 79dc7cd commit b01cdc5

7 files changed

Lines changed: 275 additions & 2 deletions

File tree

docs/api/math.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,28 @@ Log-space mixed interpolation: log-linear for the first `n` points, log-cubic fo
281281
.. autoclass:: pyquantlib.LogMixedLinearCubicNaturalSpline
282282
```
283283

284+
### SABRInterpolation
285+
286+
```{eval-rst}
287+
.. autoclass:: pyquantlib.SABRInterpolation
288+
```
289+
290+
SABR smile interpolation that calibrates alpha, beta, nu, rho to market volatilities.
291+
292+
```python
293+
strikes = [0.03, 0.04, 0.05, 0.06, 0.07]
294+
vols = [0.25, 0.22, 0.20, 0.21, 0.23]
295+
296+
interp = ql.SABRInterpolation(
297+
strikes, vols, expiry=1.0, forward=0.05,
298+
alpha=0.2, beta=0.5, nu=0.4, rho=-0.3,
299+
alphaIsFixed=False, betaIsFixed=True,
300+
nuIsFixed=False, rhoIsFixed=False,
301+
)
302+
interp.update()
303+
print(f"alpha={interp.alpha():.4f}, rho={interp.rho():.4f}, rms={interp.rmsError():.6f}")
304+
```
305+
284306
### ForwardFlatInterpolation
285307

286308
```{eval-rst}

include/pyquantlib/pyquantlib.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ namespace ql_math {
9999
void cubicinterpolation(py::module_&);
100100
void mixedinterpolation(py::module_&);
101101
void loginterpolation(py::module_&);
102+
void sabrinterpolation(py::module_&);
102103
void normaldistribution(py::module_&);
103104
void bivariatenormaldistribution(py::module_&);
104105
void solvers1d(py::module_&);

pyquantlib/__init__.pyi

Lines changed: 2 additions & 1 deletion
Large diffs are not rendered by default.

pyquantlib/_pyquantlib/__init__.pyi

Lines changed: 49 additions & 1 deletion
Large diffs are not rendered by default.

src/math/all.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ DECLARE_MODULE_BINDINGS(math_bindings) {
3434
ADD_MAIN_BINDING(ql_math::cubicinterpolation, "Cubic interpolation");
3535
ADD_MAIN_BINDING(ql_math::mixedinterpolation, "Mixed linear/cubic interpolation");
3636
ADD_MAIN_BINDING(ql_math::loginterpolation, "Log-cubic and log-mixed interpolation");
37+
ADD_MAIN_BINDING(ql_math::sabrinterpolation, "SABR interpolation");
3738
ADD_MAIN_BINDING(ql_math::normaldistribution, "Normal distribution functions");
3839
ADD_MAIN_BINDING(ql_math::bivariatenormaldistribution, "Bivariate cumulative normal distribution");
3940
ADD_MAIN_BINDING(ql_math::solvers1d, "1-D root-finding solvers");
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
/*
2+
* PyQuantLib: Python bindings for QuantLib
3+
* https://github.com/quantales/pyquantlib
4+
*
5+
* Copyright (c) 2025 Yassine Idyiahia
6+
* SPDX-License-Identifier: BSD-3-Clause
7+
* See LICENSE for details.
8+
*
9+
* ---
10+
* QuantLib is Copyright (c) 2000-2025 The QuantLib Authors
11+
* https://www.quantlib.org/
12+
*/
13+
14+
#include "pyquantlib/pyquantlib.h"
15+
#include <ql/math/interpolations/sabrinterpolation.hpp>
16+
#include <ql/math/optimization/levenbergmarquardt.hpp>
17+
#include <pybind11/pybind11.h>
18+
#include <pybind11/stl.h>
19+
20+
namespace py = pybind11;
21+
using namespace QuantLib;
22+
23+
namespace {
24+
25+
// Extended data holder that also stores forward on the heap.
26+
// SABRInterpolation stores const Real& forward, so forward needs
27+
// a stable address that survives the move into shared_ptr's deleter.
28+
struct SABRDataHolder {
29+
std::vector<Real> x, y;
30+
ext::shared_ptr<Real> forward;
31+
32+
template <typename T>
33+
void operator()(T* p) const { delete p; }
34+
};
35+
36+
} // namespace
37+
38+
void ql_math::sabrinterpolation(py::module_& m) {
39+
py::class_<SABRInterpolation, Interpolation,
40+
ext::shared_ptr<SABRInterpolation>>(
41+
m, "SABRInterpolation",
42+
"SABR smile interpolation between discrete volatility points.")
43+
.def(py::init([](std::vector<Real> strikes,
44+
std::vector<Real> vols,
45+
Real t,
46+
Real forward,
47+
Real alpha, Real beta, Real nu, Real rho,
48+
bool alphaIsFixed, bool betaIsFixed,
49+
bool nuIsFixed, bool rhoIsFixed,
50+
bool vegaWeighted,
51+
const ext::shared_ptr<EndCriteria>& endCriteria,
52+
const ext::shared_ptr<OptimizationMethod>& optMethod,
53+
Real errorAccept,
54+
bool useMaxError,
55+
Size maxGuesses,
56+
Real shift) {
57+
58+
QL_REQUIRE(strikes.size() == vols.size(),
59+
"strikes and vols must have the same size");
60+
QL_REQUIRE(strikes.size() >= 2,
61+
"at least 2 points required");
62+
63+
SABRDataHolder holder{
64+
std::move(strikes), std::move(vols),
65+
ext::make_shared<Real>(forward)};
66+
67+
auto* ptr = new SABRInterpolation(
68+
holder.x.begin(), holder.x.end(), holder.y.begin(),
69+
t, *holder.forward,
70+
alpha, beta, nu, rho,
71+
alphaIsFixed, betaIsFixed, nuIsFixed, rhoIsFixed,
72+
vegaWeighted, endCriteria, optMethod,
73+
errorAccept, useMaxError, maxGuesses,
74+
shift);
75+
76+
return ext::shared_ptr<SABRInterpolation>(ptr, std::move(holder));
77+
}),
78+
py::arg("strikes"), py::arg("vols"),
79+
py::arg("expiry"), py::arg("forward"),
80+
py::arg("alpha"), py::arg("beta"), py::arg("nu"), py::arg("rho"),
81+
py::arg("alphaIsFixed"), py::arg("betaIsFixed"),
82+
py::arg("nuIsFixed"), py::arg("rhoIsFixed"),
83+
py::arg("vegaWeighted") = true,
84+
py::arg("endCriteria") = ext::shared_ptr<EndCriteria>(),
85+
py::arg("optMethod") = ext::shared_ptr<OptimizationMethod>(),
86+
py::arg("errorAccept") = 0.0020,
87+
py::arg("useMaxError") = false,
88+
py::arg("maxGuesses") = Size(50),
89+
py::arg("shift") = 0.0,
90+
"Constructs SABR interpolation from strikes and volatilities.")
91+
.def("expiry", &SABRInterpolation::expiry, "Returns the expiry.")
92+
.def("forward", &SABRInterpolation::forward, "Returns the forward.")
93+
.def("alpha", &SABRInterpolation::alpha, "Returns calibrated alpha.")
94+
.def("beta", &SABRInterpolation::beta, "Returns calibrated beta.")
95+
.def("nu", &SABRInterpolation::nu, "Returns calibrated nu.")
96+
.def("rho", &SABRInterpolation::rho, "Returns calibrated rho.")
97+
.def("rmsError", &SABRInterpolation::rmsError, "Returns RMS calibration error.")
98+
.def("maxError", &SABRInterpolation::maxError, "Returns max calibration error.")
99+
.def("endCriteria", &SABRInterpolation::endCriteria, "Returns the end criteria type.")
100+
.def("interpolationWeights", &SABRInterpolation::interpolationWeights,
101+
py::return_value_policy::reference_internal,
102+
"Returns the interpolation weights.");
103+
}

tests/test_math_interpolations.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,3 +492,100 @@ def make():
492492

493493
interp = make()
494494
assert interp(3.0) == pytest.approx(0.95)
495+
496+
497+
# ---------------------------------------------------------------------------
498+
# SABRInterpolation
499+
# ---------------------------------------------------------------------------
500+
501+
def _sabr_test_data():
502+
"""Generate SABR test data from known parameters (mirrors QL test suite)."""
503+
forward = 0.039
504+
expiry = 1.0
505+
alpha, beta, nu, rho_val = 0.3, 0.6, 0.02, 0.01
506+
507+
strikes = [0.03 + 0.002 * i for i in range(31)]
508+
vols = [
509+
ql.sabrVolatility(K, forward, expiry, alpha, beta, nu, rho_val)
510+
for K in strikes
511+
]
512+
return strikes, vols, forward, expiry, alpha, beta, nu, rho_val
513+
514+
515+
def test_sabr_interpolation_fixed():
516+
"""SABRInterpolation with all params fixed evaluates correctly."""
517+
strikes, vols, forward, expiry, alpha, beta, nu, rho_val = _sabr_test_data()
518+
519+
interp = ql.SABRInterpolation(
520+
strikes, vols, expiry, forward,
521+
alpha=alpha, beta=beta, nu=nu, rho=rho_val,
522+
alphaIsFixed=True, betaIsFixed=True,
523+
nuIsFixed=True, rhoIsFixed=True,
524+
)
525+
interp.update()
526+
527+
assert isinstance(interp, ql.base.Interpolation)
528+
assert interp.expiry() == pytest.approx(expiry)
529+
assert interp.forward() == pytest.approx(forward)
530+
assert interp.alpha() == pytest.approx(alpha)
531+
assert interp.beta() == pytest.approx(beta)
532+
533+
# Should reproduce input vols exactly
534+
for K, v in zip(strikes, vols):
535+
assert interp(K) == pytest.approx(v, rel=1e-10)
536+
537+
538+
def test_sabr_interpolation_calibration():
539+
"""Calibrated SABR recovers original parameters."""
540+
strikes, vols, forward, expiry, alpha, beta, nu, rho_val = _sabr_test_data()
541+
542+
interp = ql.SABRInterpolation(
543+
strikes, vols, expiry, forward,
544+
alpha=0.2, beta=beta, nu=0.1, rho=0.0,
545+
alphaIsFixed=False, betaIsFixed=True,
546+
nuIsFixed=False, rhoIsFixed=False,
547+
vegaWeighted=True,
548+
)
549+
interp.update()
550+
551+
assert interp.alpha() == pytest.approx(alpha, rel=0.01)
552+
assert interp.nu() == pytest.approx(nu, rel=0.1)
553+
assert interp.rmsError() < 1e-4
554+
555+
556+
def test_sabr_interpolation_accessors():
557+
"""SABR accessors return calibrated values."""
558+
strikes, vols, forward, expiry, alpha, beta, nu, rho_val = _sabr_test_data()
559+
560+
interp = ql.SABRInterpolation(
561+
strikes, vols, expiry, forward,
562+
alpha=alpha, beta=beta, nu=nu, rho=rho_val,
563+
alphaIsFixed=True, betaIsFixed=True,
564+
nuIsFixed=True, rhoIsFixed=True,
565+
)
566+
interp.update()
567+
568+
assert interp.rmsError() < 1e-10
569+
assert interp.maxError() < 1e-10
570+
assert interp.endCriteria() is not None
571+
weights = interp.interpolationWeights()
572+
assert len(weights) == len(strikes)
573+
574+
575+
def test_sabr_interpolation_data_lifetime():
576+
"""SABR interpolation survives after input lists go out of scope."""
577+
def make():
578+
strikes, vols, forward, expiry, alpha, beta, nu, rho_val = _sabr_test_data()
579+
interp = ql.SABRInterpolation(
580+
strikes, vols, expiry, forward,
581+
alpha=alpha, beta=beta, nu=nu, rho=rho_val,
582+
alphaIsFixed=True, betaIsFixed=True,
583+
nuIsFixed=True, rhoIsFixed=True,
584+
)
585+
interp.update()
586+
return interp
587+
588+
interp = make()
589+
val = interp(0.05)
590+
assert math.isfinite(val)
591+
assert val > 0

0 commit comments

Comments
 (0)