Conversation
Memory benchmark result

| Test Name | %Δ | Master (MB) | PR (MB) | Δ (MB) | Time PR (s) | Time Master (s) |
| ------------------------------------ | ------- | ----------- | --------- | ------- | ----------- | --------------- |
| test_objective_jac_w7x | 3.82 % | 3.850e+03 | 3.997e+03 | 147.16 | 39.06 | 36.34 |
| test_proximal_jac_w7x_with_eq_update | 1.56 % | 6.493e+03 | 6.594e+03 | 101.02 | 162.92 | 161.90 |
| test_proximal_freeb_jac | -0.16 % | 1.321e+04 | 1.319e+04 | -20.82 | 85.01 | 83.21 |
| test_proximal_freeb_jac_blocked | -0.71 % | 7.546e+03 | 7.492e+03 | -53.29 | 74.37 | 74.28 |
| test_proximal_freeb_jac_batched | -0.00 % | 7.487e+03 | 7.486e+03 | -0.21 | 73.78 | 73.48 |
| test_proximal_jac_ripple | -1.39 % | 3.550e+03 | 3.501e+03 | -49.41 | 66.38 | 67.49 |
| test_proximal_jac_ripple_bounce1d | 5.89 % | 3.455e+03 | 3.659e+03 | 203.66 | 77.63 | 79.54 |
| test_eq_solve | -0.33 % | 2.030e+03 | 2.023e+03 | -6.73 | 95.13 | 94.21 |

For the memory plots, go to the summary of
Codecov Report ❌ Patch coverage is

Additional details and impacted files

```diff
@@            Coverage Diff            @@
##           master    #2041   +/-   ##
========================================
  Coverage   94.53%   94.53%
========================================
  Files         102      102
  Lines       28712    28823    +111
========================================
+ Hits        27143    27249    +106
- Misses       1569     1574      +5
```
dpanici left a comment
just the small docstring fix, should be explicit that `x_scale="auto"` does no scaling here
f0uriest left a comment
I'd double check that the x_scale logic is correct
Also, did you look at whether we could just wrap stuff from optax?
From the examples, e.g. https://optax.readthedocs.io/en/latest/api/optimizers.html#optax.adam, it looks like the user could just pass in an optax solver and then we can just do

```python
opt_state = solver.init(x0)
...
g = grad(x) * x_scale
updates, opt_state = solver.update(g, opt_state, x)
x = optax.apply_updates(x, x_scale * updates)
```

or something similar. That would give users access to a much wider array of first order optimizers, and save us having to do it all ourselves.
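A slightly more complete version of that loop, as a hedged sketch (the names `loss_fn`, `x0`, `x_scale`, and `n_steps` are illustrative; it assumes flat `jnp` arrays rather than the actual DESC objective interface):

```python
import jax
import jax.numpy as jnp
import optax


def run_optax(loss_fn, x0, x_scale, solver, n_steps=100):
    """Minimal first-order loop around a user-supplied optax solver."""
    opt_state = solver.init(x0)
    grad_fn = jax.grad(loss_fn)
    x = x0
    for _ in range(n_steps):
        g = grad_fn(x) * x_scale  # gradient in the scaled variables
        updates, opt_state = solver.update(g, opt_state, x)
        x = optax.apply_updates(x, x_scale * updates)  # step back in the unscaled variables
    return x


# e.g. with Adam:
# x_opt = run_optax(loss_fn, x0, jnp.ones_like(x0), optax.adam(1e-3), n_steps=500)
```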
desc/optimize/stochastic.py (Outdated)

```diff
- def sgd(
+ def generic_sgd(
```
sgd is technically public (https://desc-docs.readthedocs.io/en/stable/_api/optimize/desc.optimize.sgd.html#desc.optimize.sgd) so if we want to change the name we should keep an alias to the old one with a deprecation warning. That said, I'm not sure we really need to change the name. "SGD" is already used pretty generically in the ML community for a bunch of first order stochastic methods like ADAM, ADAGRAD, RMSPROP, etc
Yeah, SGD is in fact the general name. I can revert to the old name and just emphasize that the "sgd" option uses Nesterov momentum. I was trying to make a distinction, I guess.
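If the old public name does get superseded (whether by a rename or by the `optax-sgd` route discussed below), a minimal sketch of keeping `sgd` as a deprecated alias; the `generic_sgd` name is taken from the diff above and its signature here is only a placeholder:

```python
import warnings


def generic_sgd(fun, x0, grad, **options):
    """Renamed implementation (placeholder signature for illustration)."""
    ...


def sgd(*args, **kwargs):
    """Deprecated alias, kept because ``sgd`` is part of the documented public API."""
    warnings.warn(
        "`sgd` is deprecated; use `generic_sgd` instead. The old name will be "
        "removed in a future release.",
        DeprecationWarning,
        stacklevel=2,
    )
    return generic_sgd(*args, **kwargs)
```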
desc/optimize/stochastic.py (Outdated)

```
    for the update rule chosen.

    - ``"alpha"`` : (float > 0) Learning rate. Defaults to
      1e-1 * ||x_scaled|| / ||g_scaled||.
```
this seems pretty large (steps would be 10% of x), have you checked how robust this is?
I was trying to solve equilibria with these, and even though none of them converged, 10% was better for a variety of equilibria. I haven't checked other optimization problems. Reverted the change and added a safeguard against 0 and NaNs.
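For concreteness, a hedged sketch of a scaled default step size with a guard against zero or NaN gradients; the 10% factor comes from this thread, while the fallback value is an assumption, not the final implementation:

```python
import jax.numpy as jnp


def default_alpha(x_scaled, g_scaled, frac=1e-1, fallback=1e-2):
    """Default learning rate ~ frac * ||x_scaled|| / ||g_scaled||, guarded against 0/NaN."""
    alpha = frac * jnp.linalg.norm(x_scaled) / jnp.linalg.norm(g_scaled)
    # fall back to a fixed step if the ratio is zero, infinite, or NaN
    return jnp.where(jnp.isfinite(alpha) & (alpha > 0), alpha, fallback)
```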
desc/optimize/stochastic.py (Outdated)

```
    Where alpha is the step size and beta is the momentum parameter.
    Update rule for ``'sgd'``:

    .. math::
```
personally I prefer unicode for stuff like this. TeX looks nice in the rendered html docs, but is much harder to read as code.
I mostly agree and don't have a strong stance either way. My general preference is to use LaTeX for complex equations or public-facing objectives that users will first encounter in the documentation (like guiding center equations, optimization algorithms). For internal development notes or specific compute functions that aren't usually viewed on the web, I’m fine with Unicode since it keeps the source code more readable.
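For reference, the textbook Nesterov-momentum update being discussed, with step size α and momentum β (the standard form, not necessarily the exact convention used in `_sgd`):

```latex
v_{k+1} = \beta\, v_k - \alpha\, \nabla f\!\left(x_k + \beta\, v_k\right),
\qquad
x_{k+1} = x_k + v_{k+1}
```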
desc/optimize/stochastic.py (Outdated)

```python
    return result


def _sgd(g, v, alpha, beta):
```
is this the same as some version of optax-sgd? if so I'd vote to remove this and just do something like if method == "sgd": method = "optax-sgd". Then we can simplify a lot of the code here and just always assume we're using optax stuff
I am not 100% sure, but it can be equivalent to optax-sgd with momentum=beta, learning_rate=alpha, and nesterov=True. The amount of code dedicated to that is not much, and it also handles future implementations (I don't know if anyone wants to add their own sgd optimizers, but anyway). If people want, I can add a deprecation, but it doesn't seem urgent to me.
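If someone wants to verify the equivalence, a hedged sketch of a one-step comparison on the optax side (the in-house `_sgd(g, v, alpha, beta)` step would be computed separately and compared against `x_optax`):

```python
import jax.numpy as jnp
import optax

alpha, beta = 1e-2, 0.9
solver = optax.sgd(learning_rate=alpha, momentum=beta, nesterov=True)

x = jnp.array([1.0, 2.0, 3.0])
g = jnp.array([0.1, -0.2, 0.3])

state = solver.init(x)
updates, state = solver.update(g, state, x)
x_optax = optax.apply_updates(x, updates)
# compare x_optax to the step produced by _sgd(g, v, alpha, beta) for the same
# (alpha, beta) and zero initial momentum to confirm (or refute) the equivalence
```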
A simple helper test to check:

```python
@pytest.mark.unit
def test_available_optax_optimizers(self):
    """Test that all optax optimizers are included in _all_optax_optimizers."""
    optimizers = []
    # Optax doesn't have a specific module for optimizers, and there is no specific
    # base class for optimizers, so we have to manually exclude some outliers. The
    # class optax.GradientTransformationExtraArgs is the closest thing, but there
    # are some other classes that inherit from it that are not optimizers. Since
    # the optimizers are actually a function that returns an instance of
    # optax.GradientTransformationExtraArgs, we call every public callable in optax
    # and keep the ones whose return value is of that type.
    names_to_exclude = [
        "GradientTransformationExtraArgs",
        "freeze",
        "scale_by_backtracking_linesearch",
        "scale_by_polyak",
        "scale_by_zoom_linesearch",
        "optimistic_adam",  # deprecated
    ]
    for name, obj in inspect.getmembers(optax):
        if name.startswith("_"):
            continue
        if callable(obj):
            try:
                sig = inspect.signature(obj)
                # fill every required argument with a dummy value
                ins = {
                    p.name: 0.1
                    for p in sig.parameters.values()
                    if p.default is inspect.Parameter.empty
                }
                if name == "noisy_sgd":
                    ins["key"] = 0
                out = obj(**ins)
                if isinstance(out, optax.GradientTransformationExtraArgs):
                    if name not in names_to_exclude:
                        optimizers.append(name)
            except Exception:
                print(f"Could not instantiate: {name}")
    msg = (
        "Wrapped optax optimizers can be out of date. If the newly added callable "
        "is not an optimizer, add it to the names_to_exclude list in this test."
    )
    print(optimizers)
    assert len(set(optimizers)) == len(_all_optax_optimizers), msg
    assert sorted(set(optimizers)) == sorted(_all_optax_optimizers), msg
    assert len(set(_all_optax_optimizers)) == len(_all_optax_optimizers), msg
```
We are in favor of removing the sgd code and having `sgd` be an alias to `optax-sgd`.
```python
)
deprecated_sgd = False
if method == "sgd":
    # warn the user but do not fail the pytest
```
why is this here? shouldn't this be in the test?
There are multiple tests and some are parameterized; this seemed easier.
But doesn't this mean the warning isn't actually emitted?
It is not treated as a warning, but the message is still printed as a warning. So this is equivalent to something like `print("DeprecationWarning: SGD is deprecated.")` with proper syntax for the warnings. It would be more annoying to add conditionals to parameterized tests; I would rather not do that.
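For reference, a hedged sketch of emitting a real `DeprecationWarning` while keeping parameterized tests green via a `filterwarnings` mark instead of per-case conditionals; `run_optimizer` and the test below are illustrative stand-ins, not the DESC test suite:

```python
import warnings

import pytest


def run_optimizer(method):
    """Stand-in for the real optimizer call; only the warning behaviour matters here."""
    if method == "sgd":
        warnings.warn("'sgd' is deprecated, use 'optax-sgd'.", DeprecationWarning)
    return method


# the mark silences only the expected DeprecationWarning, so users still see it,
# while the parameterized test passes even if warnings are otherwise turned into errors
@pytest.mark.filterwarnings("ignore:'sgd' is deprecated:DeprecationWarning")
@pytest.mark.parametrize("method", ["sgd", "optax-sgd"])
def test_optimizer_warning_handling(method):
    assert run_optimizer(method) == method
```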

- `x_scale` is now used with SGD methods too
- Wrapped `optax` optimizers, and they can be called by `optax-name`
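A hedged usage sketch of the new naming convention, assuming the wrapped optimizers are selected through `desc.optimize.Optimizer` by an `optax-` prefixed string (exact names and signatures may differ):

```python
from desc.optimize import Optimizer

# first-order optax optimizers are addressed by the "optax-" prefix
opt_adam = Optimizer("optax-adam")
opt_sgd = Optimizer("optax-sgd")

# the original "sgd" string still refers to the Nesterov-momentum SGD discussed above
opt_legacy = Optimizer("sgd")
```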