Add promote atom and affine broadcasting#12
Conversation
SteveDiamond
left a comment
There was a problem hiding this comment.
Reviewed the full diff, built the branch, and ran the suite: all 255 tests pass and clippy is clean. This is a solid, well-tested PR — promote is wired correctly through every Expr match site (curvature, sign, eval, shape, variable collection, canonicalization, problem shape collection), and folding the old ad-hoc Constraint::broadcast_scalar into a single shared broadcast_exprs is a nice consolidation. The fix to #11 is real and demonstrated by test_scalar_affine_broadcasts_in_vector_constraint.
A few findings, mostly minor (inline). Summary:
🟠 Medium — dimension mismatch now panics instead of erroring. The new asserts in elementwise_mul_const_lin turn a shape mismatch into a hard panic during solve(). Repro:
let x = variable((3, 3));
let c = constant_matrix(vec![1.,2.,3.,4.], 2, 2); // incompatible
Problem::minimize(sum(&(c * x))).subject_to([x.ge(0.0)]).build().solve();
// → panic at canonicalizer.rs (elementwise assert)This is arguably an improvement over the old silently-wrong size.min(coeff.nrows()), and since canonicalize_expr is infallible an assert! is consistent with the architecture. But a library panicking on user input is rough — ideally incompatible shapes are caught at expression/constraint construction (where a Result/CvxError::InvalidProblem is available). At minimum worth flagging that this is now a panic path.
🟡 Low — unsupported/incompatible broadcasts silently report scalar shape. broadcast_binary_shape falls back to Shape::scalar(), so (1,3) + (2,1) reports shape() == (). Mutual broadcasting (1,n)+(m,1) → (m,n) (which CVXPY supports) isn't handled. Pre-existing pattern (unwrap_or_else(Shape::scalar)), but the new code makes it more prominent.
🟡 Low — efficiency regression on scalar * vector/matrix. broadcast_exprs now wraps the scalar in Promote for every c * X, including the hot 3.0 * x. The l.scale(scalar) fast path in canonicalize_mul matches only constant_value() (bare Constant, not Promote), so promoted scalars now fall through to elementwise_mul_const_lin, which materializes a full constant array with csc↔dense roundtrips. Correct but heavier. Letting canonicalize_mul (or constant_value()) see through a Promote of a scalar constant would restore the fast path.
🔵 Question — canonicalize_matmul result-shape change is broader than broadcasting. The ncols==1 && nrows>1 → Shape::vector logic changes matmul result shapes globally (any matmul producing (m,1) is now vector(m) instead of matrix(m,1)). Worth a sentence on why it's needed and confirming nothing relies on the old shape.
🔵 Nit — duplicated broadcasting rules. broadcast_2d_to/broadcast_exprs (affine.rs) and broadcast_binary_shape (expression.rs) encode the same rules in two places and can drift. Consider a single source of truth.
Comfortable approving after the panic finding is acknowledged; the rest are optional polish. Nice work!
| assert_eq!( | ||
| c_flat.len(), | ||
| size, | ||
| "elementwise multiplication requires matching sizes after broadcasting" |
There was a problem hiding this comment.
Medium: these asserts turn a shape mismatch into a hard panic during solve() (e.g. constant_matrix(_,2,2) * variable((3,3))). Better than the old silently-wrong size.min(coeff.nrows()), and consistent with the infallible canonicalize_expr design — but a library panicking on user input is rough. Ideally validate shapes at expression/constraint construction where a CvxError::InvalidProblem can be returned. At minimum worth calling out that this is now a panic path.
| } | ||
| } | ||
|
|
||
| Shape::scalar() |
There was a problem hiding this comment.
The Shape::scalar() fallback means genuinely-incompatible or unsupported broadcasts silently report a scalar shape — e.g. (1,3) + (2,1) yields shape() == (). Mutual broadcasting (1,n)+(m,1) → (m,n) (supported by CVXPY) is not handled here. Pre-existing pattern, but more prominent now; consider erroring rather than collapsing to ().
| let new_const = &a_mat * &b.constant; | ||
| let shape = Shape::matrix(new_const.nrows(), new_const.ncols()); | ||
| let shape = if new_const.ncols() == 1 && new_const.nrows() > 1 { | ||
| Shape::vector(new_const.nrows()) |
There was a problem hiding this comment.
This reshapes any matmul result with ncols==1 && nrows>1 from matrix(m,1) to vector(m) — a global matmul semantics change beyond broadcasting. Worth a comment on why it's required and confirming no existing matmul-produces-column path depends on the old matrix(m,1) shape.
|
@haozhu10015 I had claude take a look and I think there is room to improve the PR. I would be fine with merging it as is though if you prefer. |
|
@SteveDiamond Thanks for the review! I think they should be addressed now. The following are the summary of changes:
One comment I should leave is that, currently I always use the panic path for capturing the errors. Ideally in Rust, we should handling the error with something like |
This PR fixes scalar affine broadcasting by adding a
promoteaffine atom.Close #11.
Changes
promote(expr, shape)API andExpr::Promote.(),(1,), and(1, 1)shaped variables.(1, n)to(m, n)viaones((m, 1)) @ expr(m, 1)to(m, n)viaexpr @ ones((1, n))elementwise_mul_const_linso mismatched canonical dimensions fail directly.Tests