Conversation
Memory benchmark result

| Test Name | %Δ | Master (MB) | PR (MB) | Δ (MB) | Time PR (s) | Time Master (s) |
| --- | --- | --- | --- | --- | --- | --- |
| test_objective_jac_w7x | 4.16 % | 3.871e+03 | 4.032e+03 | 160.85 | 40.27 | 37.56 |
| test_proximal_jac_w7x_with_eq_update | -2.52 % | 6.655e+03 | 6.487e+03 | -168.01 | 167.11 | 167.35 |
| test_proximal_freeb_jac | 0.13 % | 1.318e+04 | 1.320e+04 | 16.87 | 85.56 | 85.96 |
| test_proximal_freeb_jac_blocked | -0.10 % | 7.477e+03 | 7.469e+03 | -7.44 | 75.83 | 76.62 |
| test_proximal_freeb_jac_batched | -0.45 % | 7.527e+03 | 7.492e+03 | -34.23 | 75.24 | 75.83 |
| test_proximal_jac_ripple | 2.79 % | 3.429e+03 | 3.525e+03 | 95.75 | 67.01 | 67.21 |
| test_proximal_jac_ripple_bounce1d | 1.40 % | 3.499e+03 | 3.548e+03 | 48.82 | 79.39 | 78.57 |
| test_eq_solve | -0.55 % | 2.006e+03 | 1.995e+03 | -10.99 | 95.10 | 95.19 |

For the memory plots, go to the summary of
So as I see it there are two main ways of applying deflation:

- As a multiplicative factor on the objective, e.g. f(x) -> M(x, x*) f(x)
- As an additional inequality constraint, M(x, x*) < r

The new DeflationOperator objective seems to cover the 2nd case, but for the first case ForceBalanceDeflated only works for equilibrium problems. I think it would be better as a sort of "wrapper objective" that can be applied to any objective and multiply it by the deflation operator.
We could possibly combine the two and make it a single objective like

```python
class DeflationOperator:
    """Multiplicative or constraint type deflation."""

    def __init__(self, objective=None, ...):
        self.objective = objective

    def compute(self, x):
        if self.objective is not None:
            f = self.objective.compute(x)
        else:
            f = 1
        return M(x, x*) * f
```

This would cover both cases, either treating the deflation as an extra constraint (with objective=None) or applying multiplicative deflation to an arbitrary objective (e.g. by passing objective=ForceBalance()).
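To make the idea concrete, here is a toy, self-contained sketch of the combined design. The quadratic stand-in objective, the concrete choice M(x, x*) = 1/|x - x*|^p + sigma, and all names here are illustrative assumptions, not DESC's actual implementation:

```python
import numpy as np


def M(x, x_star, p=2, sigma=1.0):
    # toy deflation operator: 1 / |x - x*|^p + sigma
    return 1.0 / np.linalg.norm(np.atleast_1d(x) - x_star) ** p + sigma


class Quadratic:
    """Stand-in for a real objective like ForceBalance: f(x) = sum(x^2)."""

    def compute(self, x):
        return float(np.sum(np.atleast_1d(x) ** 2))


class DeflationOperator:
    """Constraint-type deflation (objective=None) or multiplicative deflation."""

    def __init__(self, x_star, objective=None):
        self.x_star = np.atleast_1d(x_star)
        self.objective = objective

    def compute(self, x):
        # f = 1 reduces the cost to M(x, x*) alone (the constraint case)
        f = self.objective.compute(x) if self.objective is not None else 1.0
        return M(x, self.x_star) * f


bare = DeflationOperator(x_star=0.0)                            # constraint case
wrapped = DeflationOperator(x_star=0.0, objective=Quadratic())  # wrapper case
```

With `x_star = 0` and `x = 2`, the bare operator gives M = 1/4 + 1 = 1.25, and the wrapped one gives M * f = 1.25 * 4.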
```python
self._dim_f = 1

self._is_none_mask = []
self._is_not_none_mask = []
```
I don't think you ever have conditionals on the mask in the compute function? It looks like it's just used as the `where` arg to prod/sum.
desc/objectives/_generic.py
```python
# if wrapping an objective, but all things are None, make deflation do
# nothing when multiplying f, so here we add 1 to it as it is 0 right now
# if all things are None
deflation_parameter += 1.0
```
I think you might want this to be `+= jnp.invert(self._not_all_things_to_deflate_are_None)`; that way it only adds 1 if all the deflated things are None.
Yes, you are right. The logic was getting confusing.
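For reference, the suggested guard behaves like this minimal sketch (numpy is used here for illustration; `jnp.invert` acts the same way on a boolean flag):

```python
import numpy as np

not_all_none = np.bool_(False)  # e.g. every entry of things_to_deflate is None
deflation_parameter = 0.0       # per the code comment, it is 0 in this case
# invert flips the flag, so 1 is added only when everything is None
deflation_parameter += np.invert(not_all_none)
```

When `not_all_none` is True (something real to deflate), `np.invert` yields False and nothing is added.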
```python
assert isinstance(
    self._objective, _Objective
), "objective passed in must be an _Objective!"
```
I prefer using our `errorif` util function for this.
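For illustration, here is a minimal stand-in for that util; the signature is assumed from typical usage, not copied from DESC:

```python
def errorif(condition, err=ValueError, msg=""):
    """Raise err(msg) if condition is True (stand-in for DESC's errorif)."""
    if condition:
        raise err(msg)


class _Objective:
    """Stand-in for the real desc.objectives._Objective base class."""


# the assert above would then read roughly as:
obj = _Objective()
errorif(
    not isinstance(obj, _Objective),
    TypeError,
    "objective passed in must be an _Objective!",
)
```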
```python
self._units = self._objective._units
self._scalar = self._objective._scalar
self._coordinates = self._objective._coordinates
```
Does this not cause problems when self._objective = None?
```python
self._dim_f = 1

self._is_none_mask = []
self._is_not_none_mask = []
```
@dpanici I don't see `self._is_none_mask` get used anywhere in the compute function. Can we remove that array? Then this would be resolved.
```python
assert np.all(self._bounds[0] <= lower_bound_min), (
    "Provided lower bound for deflation operator is too high compared "
    f"to the minimum value of {lower_bound_min} it can take based off "
    "of sigma, use a smaller lower bound"
)
```
I prefer using errorif for consistent formatting throughout the code.
```python
if self._multiple_deflation_type == "prod":
    deflation_parameter = jnp.prod(
        M_i, initial=1.0, where=self._is_not_none_mask
    ) + self._sigma * (self._single_shift)
else:
    deflation_parameter = jnp.sum(
        M_i, where=self._is_not_none_mask, initial=0.0
    ) + self._sigma * (self._single_shift)
```
Suggested change to be fancy and make the code more readable:

```python
fun = jnp.prod if self._multiple_deflation_type == "prod" else jnp.sum
deflation_parameter = fun(
    M_i, initial=1.0, where=self._is_not_none_mask
) + self._sigma * (self._single_shift)
```
`initial` needs to be different for sum vs prod (1.0 for prod, 0.0 for sum).
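A version of the suggestion that fixes this by pairing each reduction with its identity element, sketched with numpy for illustration (`jnp.prod`/`jnp.sum` take the same `initial`/`where` arguments; the function and argument names here are assumptions):

```python
import numpy as np


def reduce_deflation(M_i, mask, sigma=0.0, shift=0.0, mode="prod"):
    # pick the reduction together with its identity element:
    # prod needs initial=1.0, sum needs initial=0.0
    fun, initial = (np.prod, 1.0) if mode == "prod" else (np.sum, 0.0)
    return fun(M_i, initial=initial, where=mask) + sigma * shift


M_i = np.array([2.0, 3.0, 5.0])
mask = np.array([True, True, False])  # last entry corresponds to a None thing
```

Masked-out entries are skipped entirely, and the identity element guarantees an all-False mask yields 1.0 (prod) or 0.0 (sum) rather than garbage.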
```python
deflation_parameter = (
    deflation_parameter * self._not_all_things_to_deflate_are_None
)
```
Suggested change:

```python
deflation_parameter *= self._not_all_things_to_deflate_are_None
```
```python
things_to_deflate : list containing elements of type {Optimizable, None}
    list of objects to use in deflation operator. Should be same type
    as thing. Can also contain None elements, in which case those will be
    ignored. The utility of allowing the None elements and ignoring them is
    that if one is using this objective in a loop with a pre-determined
    number of iterations and adding each result of the loop iterate to
    things_to_deflate, it may trigger recompilation of the objective's
    compute and jac/grad functions each time, which is wasteful. You can
    instead pass in a list containing None elements padding the list out to
    the max length it will attain. In this way, no recompilations will be
    triggered, and the entire loop will be completed much more quickly.
    If all things_to_deflate are None, this objective has zero cost (if not
    wrapping another objective) or simply returns the wrapped objective's
    cost (if wrapping another objective).
```
How common of a use case is the list of Nones? Because this is a lot of clunky logic to prevent recompilation for something that is happening outside the scope of this objective (in a higher level loop). If it is not a common use case and/or the recompilation time is not significant, then I vote to not worry about it. The code could be simplified a lot.
I think it's worth it. The whole point of deflation is that you call it in a loop multiple times, and in common cases the compilation time is a significant fraction of the overall solve.
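The padding pattern under discussion can be sketched outside DESC (all names here are hypothetical; the point is that the list length, and hence the shapes JAX traces, never change across loop iterations):

```python
import numpy as np

MAX_ITERS = 4  # assumed upper bound on the deflation loop

# pad up front so the pytree structure is fixed from the first compile
things_to_deflate = [None] * MAX_ITERS
is_not_none_mask = np.zeros(MAX_ITERS, dtype=bool)

for i in range(2):  # e.g. only 2 solutions found so far
    # fill a slot in place instead of appending a new element,
    # so a jitted compute never sees a new input structure
    things_to_deflate[i] = np.array([float(i)])
    is_not_none_mask[i] = True
```

Appending instead would change the pytree structure each iteration and force a retrace/recompile of the compute and jac/grad functions.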
```python
# put res of self.eq back to original
res = self.res_array[0]
self.eq.change_resolution(
    L_grid=int(self.eq.L * res),
    M_grid=int(self.eq.M * res),
    N_grid=int(self.eq.N * res),
)
```
Instead of modifying the equilibrium and then reverting the changes back, it would be safer and cleaner to make a copy of eq at the beginning of this function and only modify that copy.
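The copy-instead-of-revert pattern, as a generic sketch with a toy object (not DESC's actual Equilibrium API):

```python
import copy


class Toy:
    """Toy stand-in for an object with mutable resolution state."""

    def __init__(self, L_grid):
        self.L_grid = L_grid


def solve_at_scaled_res(obj, res):
    # modify a deep copy; the caller's object is never touched, so no
    # revert step is needed (and no stale state is left on exceptions)
    local = copy.deepcopy(obj)
    local.L_grid = int(local.L_grid * res)
    return local


orig = Toy(L_grid=8)
scaled = solve_at_scaled_res(orig, 2)
```

Besides being cleaner, this is exception-safe: an error mid-function cannot leave the caller's object at the wrong resolution.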
```python
assert not np.any(np.isnan(g)), "plasma vessel distance"


@pytest.mark.unit
def test_objective_no_nangrad_ForceBalanceDeflated(self):
```
ForceBalanceDeflated isn't actually its own objective, right? I assume this was from an old version of the code
Deflation method motivation: find multiple solutions to non-convex optimization problems (which can include certain equilibrium solves).

This PR adds ways to apply deflation methods in stellarator optimization and equilibrium solving through the new `DeflationOperator` objective, whose cost is simply M(x; y) = 1/(x - y)^p + sigma (to add as constraints to an optimization like in Tarek 2022 work). This can be used as a standalone metric, or another `_Objective` can be passed to it to wrap it and return as the cost M(x; y) f(x), where f(x) is that `_Objective`'s compute value, like is done in usual deflation.

References:

TODO:

- `"exp"` deflation type
- `ForceBalanceDeflated` to use pytree inputs for `params_to_deflate_with`

Future work for another PR:

- `_equilibrium` as attribute of `DeflationOperator` and test using it in proximal-lsq-exact