Skip to content

Add Deflation methods#2044

Open
dpanici wants to merge 54 commits intomasterfrom
dp/deflated-continuation
Open

Add Deflation methods#2044
dpanici wants to merge 54 commits intomasterfrom
dp/deflated-continuation

Conversation

@dpanici
Copy link
Collaborator

@dpanici dpanici commented Dec 22, 2025

Deflation method motivation: find multiple solutions to non-convex optimization problems (which can include certain equilibirum solves)

This PR adds ways to apply deflation methods in stellarator optimization and equilibrium solving through the new DeflationOperator objective

  • adds a new generic objective DeflationOperator whose cost is simply M(x;y) = 1/(x-y)^p + sigma (to add as constraints to an optimization like in Tarek 2022 work). This can be used as a standalone metric, or another _Objective can be passed to it to wrap it and return as the cost M(x;y)f(x) where f(x) is that _Objective's compute value, like is done in usual deflation
  • Adds tutorial covering these

References:

  • Riley 2024 - for the "exp" deflation type
  • Farrell 2015 - for the addition of the shift parameter and the general form of deflation used

TODO

  • add tests
  • update changelog
  • update ForceBalanceDeflated to use pytree inputs for params_to_deflate_with
  • add option for using single shift, like discussed in Riley 2024
  • Add a wrapper objective so that one can multiply the deflation operator with any arbitrary objective, instead of only able to add it as an extra cost when doing stage one/two optimization.
  • figure out how to avoid recompilation

Future work for another PR:

  • Implement algorithms from Riley 2024
    • Implement deflated line-search Gauss-Newton algorithm like they use
    • Adapt their algorithm for our usual trust-region approach
  • add _equilibrium as attribute of DeflationOperator and test using it in proximal-lsq-exact

@dpanici dpanici requested review from a team, YigitElma, ddudt, f0uriest, rahulgaur104 and unalmis and removed request for a team December 22, 2025 19:00
@github-actions
Copy link
Contributor

github-actions bot commented Dec 22, 2025

Memory benchmark result

|               Test Name                |      %Δ      |    Master (MB)     |      PR (MB)       |    Δ (MB)    |    Time PR (s)     |  Time Master (s)   |
| -------------------------------------- | ------------ | ------------------ | ------------------ | ------------ | ------------------ | ------------------ |
  test_objective_jac_w7x                 |    4.16 %    |     3.871e+03      |     4.032e+03      |    160.85    |       40.27        |       37.56        |
  test_proximal_jac_w7x_with_eq_update   |   -2.52 %    |     6.655e+03      |     6.487e+03      |   -168.01    |       167.11       |       167.35       |
  test_proximal_freeb_jac                |    0.13 %    |     1.318e+04      |     1.320e+04      |    16.87     |       85.56        |       85.96        |
  test_proximal_freeb_jac_blocked        |   -0.10 %    |     7.477e+03      |     7.469e+03      |    -7.44     |       75.83        |       76.62        |
  test_proximal_freeb_jac_batched        |   -0.45 %    |     7.527e+03      |     7.492e+03      |    -34.23    |       75.24        |       75.83        |
  test_proximal_jac_ripple               |    2.79 %    |     3.429e+03      |     3.525e+03      |    95.75     |       67.01        |       67.21        |
  test_proximal_jac_ripple_bounce1d      |    1.40 %    |     3.499e+03      |     3.548e+03      |    48.82     |       79.39        |       78.57        |
  test_eq_solve                          |   -0.55 %    |     2.006e+03      |     1.995e+03      |    -10.99    |       95.10        |       95.19        |

For the memory plots, go to the summary of Memory Benchmarks workflow and download the artifact.

Copy link
Member

@f0uriest f0uriest left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So as I see it there are two main ways of applying deflation

  1. As a multiplicative factor on the objective eg f(x) -> M(x,x*)f(x)
  2. As an additional inequality constraint M(x,x*)<r

The new DeflationOperator objective seems to cover the 2nd case, but for the first case ForceBalanceDeflated only works for equilibrium problems. I think it would be better as a sort of "wrapper objective" that can be applied to any objective to multiply the deflation operator.

We could possibly combine the two and make it a single objective like

class DeflationOperator:
    """Multiplicative or constraint type deflation"""

def __init__(self, objective=None, ...):
    self.objective = objective

def compute(self, x):
    if self.objective is not None:
        f = self.objective.compute(x)
    else:
        f = 1
    return M(x,x*)*f

This would cover both cases, either treating the deflation as an extra constraint (with objective=None) or applying multiplicative deflation to an arbitrary objective (eg by passing objective=ForceBalance())

@dpanici dpanici requested review from YigitElma and f0uriest January 26, 2026 17:00
self._dim_f = 1

self._is_none_mask = []
self._is_not_none_mask = []
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i don't think you ever have conditionals on the mask in the compute function? it looks like its just used as the where arg to prod/sum?

# if wrapping an objective, but all things are None, make deflation do
# nothing when multiplying f, so here we add 1 to it as it is 0 right now
# if all things are None
deflation_parameter += 1.0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you might want this to be += jnp.invert(self._not_all_things_to_deflate_are_None) that way it only adds 1 if all the deflated things are None

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes you are right. the logic was getting confusing

@dpanici dpanici requested review from YigitElma, ddudt and f0uriest and removed request for YigitElma, ddudt and f0uriest January 27, 2026 18:15
Comment on lines +825 to +827
assert isinstance(
self._objective, _Objective
), "objective passed in must be an _Objective!"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer using our errorif util function for this

Comment on lines +831 to +833
self._units = self._objective._units
self._scalar = self._objective._scalar
self._coordinates = self._objective._coordinates
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this not cause problems when self._objective = None?

self._dim_f = 1

self._is_none_mask = []
self._is_not_none_mask = []
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dpanici I don't see self._is_none_mask get used anywhere in the compute function. Can we remove that array? Then this would be resolved.

Comment on lines +910 to +914
assert np.all(self._bounds[0] <= lower_bound_min), (
"Provided lower bound for deflation operator is too high compared "
f"to the minimum value of {lower_bound_min} it can take based off "
"of sigma, use a smaller lower bound"
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer using errorif for consistent formatting throughout the code.

Comment on lines +965 to +972
if self._multiple_deflation_type == "prod":
deflation_parameter = jnp.prod(
M_i, initial=1.0, where=self._is_not_none_mask
) + self._sigma * (self._single_shift)
else:
deflation_parameter = jnp.sum(
M_i, where=self._is_not_none_mask, initial=0.0
) + self._sigma * (self._single_shift)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change to be fancy and make the code more readable:

fun = jnp.prod if self._multiple_deflation_type == "prod" else jnp.sum
deflation_parameter = fun(
    M_i, initial=1.0, where=self._is_not_none_mask
) + self._sigma * (self._single_shift)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

initial needs to be different for sum vs prod

Comment on lines +975 to +977
deflation_parameter = (
deflation_parameter * self._not_all_things_to_deflate_are_None
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change:

deflation_parameter *= self._not_all_things_to_deflate_are_None

Comment on lines +725 to +738
things_to_deflate: list containing elements of type {Optimizable, None}
list of objects to use in deflation operator. Should be same type
as thing. Can also contain None elements, in which case those will be ignored.
The utility of allowing the None element and ignoring them is if one is using
this objective in a loop with a pre-determined number of iterations and adding
each result of the loop iterate to the things_to_deflate, it may trigger
recompilation of the objective's compute and jac/grad functions each time,
which is wasteful. You can instead pass in a list containing None elements
padding the list out to the max length it will attain. In this way, no
recompilations will be triggered, and the entire loop will be completed
much more quickly.
If all things_to_deflate are None, this objective has zero cost (if not
wrapping another objective) or simply returns the wrapped objective's
cost (if wrapping another objective)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How common of a use case is the list of Nones? Because this is a lot of clunky logic to prevent recompilation for something that is happening outside the scope of this objective (in a higher level loop). If it is not a common use case and/or the recompilation time is not significant, then I vote to not worry about it. The code could be simplified a lot.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's worth it. The whole point of deflation is that you call it in a loop multiple times, and in common cases the compilation time is a significant fraction of the overall solve.

Comment on lines +3581 to +3587
# put res of self.eq back to original
res = self.res_array[0]
self.eq.change_resolution(
L_grid=int(self.eq.L * res),
M_grid=int(self.eq.M * res),
N_grid=int(self.eq.N * res),
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of modifying the equilibrium and then reverting the changes back, it would be safer and cleaner to make a copy of eq at the beginning of this function and only modify that copy.

assert not np.any(np.isnan(g)), "plasma vessel distance"

@pytest.mark.unit
def test_objective_no_nangrad_ForceBalanceDeflated(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ForceBalanceDeflated isn't actually its own objective, right? I assume this was from an old version of the code

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants

Comments