Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions protocols/frost/sign/round1.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package sign

import (
"crypto/rand"
"sync"

"github.com/luxfi/threshold/internal/round"
"github.com/luxfi/threshold/pkg/math/curve"
Expand Down Expand Up @@ -122,6 +123,7 @@ func (r *round1) Finalize(out chan<- *round.Message) (round.Session, error) {
e_i: eI,
D: D,
E: E,
deMu: &sync.Mutex{},
}, nil
}

Expand Down
44 changes: 33 additions & 11 deletions protocols/frost/sign/round2.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package sign
import (
"fmt"
"sort"
"sync"

"github.com/cronokirby/saferith"
"github.com/gtank/merlin"
Expand Down Expand Up @@ -35,7 +36,8 @@ type round2 struct {
// D[i] = Dᵢ will contain all of the commitments created by each party, ourself included.
D map[party.ID]curve.Point
// E[i] = Eᵢ will contain all of the commitments created by each party, ourself included.
E map[party.ID]curve.Point
E map[party.ID]curve.Point
deMu *sync.Mutex
}

type broadcast2 struct {
Expand Down Expand Up @@ -70,12 +72,6 @@ func (r *round2) StoreBroadcastMessage(msg round.Message) error {
return fmt.Errorf("nonce commitment is the identity point")
}

// Only skip if we already have BOTH; otherwise we could drop one
if r.D[msg.From] != nil && r.E[msg.From] != nil {
// Already have both values for this party, skip
return nil
}

// Deep copy points to avoid aliasing issues - use marshal/unmarshal for clean copy
dBytes, err := body.D_i.MarshalBinary()
if err != nil {
Expand All @@ -95,8 +91,15 @@ func (r *round2) StoreBroadcastMessage(msg round.Message) error {
return fmt.Errorf("failed to unmarshal E_i: %w", err)
}

r.deMu.Lock()
// Only skip if we already have BOTH; otherwise we could drop one
if r.D[msg.From] != nil && r.E[msg.From] != nil {
r.deMu.Unlock()
return nil
}
r.D[msg.From] = dCopy
r.E[msg.From] = eCopy
r.deMu.Unlock()
return nil
}

Expand All @@ -111,25 +114,43 @@ func (r *round2) Finalize(out chan<- *round.Message) (round.Session, error) {
// Check if we have all D and E values from ALL signers
// This is critical - we MUST have D,E from every signer before proceeding
signers := r.PartyIDs()

r.deMu.Lock()
missingCount := 0
for _, l := range signers {
if r.D[l] == nil || r.E[l] == nil {
missingCount++
}
// Also verify they're not identity points (shouldn't happen but double-check)
if r.D[l] != nil && r.D[l].IsIdentity() {
r.deMu.Unlock()
return r, fmt.Errorf("party %s has identity point for D", l)
}
if r.E[l] != nil && r.E[l].IsIdentity() {
r.deMu.Unlock()
return r, fmt.Errorf("party %s has identity point for E", l)
}
}

if missingCount > 0 {
r.deMu.Unlock()
// Not ready yet, return self to continue waiting for broadcasts
return r, nil
}

// Snapshot D and E under the lock, then release.
// After this point no new StoreBroadcastMessage calls will arrive
// for this round (protocol guarantees), so the copies are final.
D := make(map[party.ID]curve.Point, len(r.D))
E := make(map[party.ID]curve.Point, len(r.E))
for k, v := range r.D {
D[k] = v
}
for k, v := range r.E {
E[k] = v
}
r.deMu.Unlock()

// This essentially follows parts of Figure 3.

// 4. "Each Pᵢ then computes the set of binding values ρₗ = H₁(l, m, B).
Expand Down Expand Up @@ -165,13 +186,13 @@ func (r *round2) Finalize(out chan<- *round.Message) (round.Session, error) {
Bytes: []byte(l),
})
// Write canonical encoding of D[l]
dBytes, _ := r.D[l].MarshalBinary()
dBytes, _ := D[l].MarshalBinary()
_ = rhoPreHash.WriteAny(&hash.BytesWithDomain{
TheDomain: "D",
Bytes: dBytes,
})
// Write canonical encoding of E[l]
eBytes, _ := r.E[l].MarshalBinary()
eBytes, _ := E[l].MarshalBinary()
_ = rhoPreHash.WriteAny(&hash.BytesWithDomain{
TheDomain: "E",
Bytes: eBytes,
Expand All @@ -190,8 +211,8 @@ func (r *round2) Finalize(out chan<- *round.Message) (round.Session, error) {
RShares := make(map[party.ID]curve.Point)
// Use sorted order to ensure consistent R computation
for _, l := range sortedSigners {
RShares[l] = rho[l].Act(r.E[l])
RShares[l] = RShares[l].Add(r.D[l])
RShares[l] = rho[l].Act(E[l])
RShares[l] = RShares[l].Add(D[l])
R = R.Add(RShares[l])
}
var c curve.Scalar
Expand Down Expand Up @@ -302,6 +323,7 @@ func (r *round2) Finalize(out chan<- *round.Message) (round.Session, error) {
RShares: RShares,
c: c,
z: map[party.ID]curve.Scalar{r.SelfID(): zI},
zMu: &sync.Mutex{},
Lambda: Lambdas,
}, nil
}
Expand Down
15 changes: 13 additions & 2 deletions protocols/frost/sign/round3.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package sign

import (
"fmt"
"sync"

"github.com/luxfi/threshold/internal/round"
"github.com/luxfi/threshold/pkg/math/curve"
Expand All @@ -28,7 +29,8 @@ type round3 struct {
// z contains the response from each participant
//
// z[i] corresponds to zᵢ in the Frost paper
z map[party.ID]curve.Scalar
z map[party.ID]curve.Scalar
zMu *sync.Mutex

// Lambda contains all Lagrange coefficients of the parties participating in this session.
// Lambda[l] = λₗ
Expand Down Expand Up @@ -75,7 +77,9 @@ func (r *round3) StoreBroadcastMessage(msg round.Message) error {
return fmt.Errorf("failed to verify response from %v", from)
}

r.zMu.Lock()
r.z[from] = body.ZI
r.zMu.Unlock()

return nil
}
Expand All @@ -91,8 +95,15 @@ func (r *round3) Finalize(chan<- *round.Message) (round.Session, error) {
// These steps come from Figure 3 of the Frost paper.

// 7.c "Compute the group's response z = ∑ᵢ zᵢ"
r.zMu.Lock()
zMap := make(map[party.ID]curve.Scalar, len(r.z))
for k, v := range r.z {
zMap[k] = v
}
r.zMu.Unlock()

z := r.Group().NewScalar()
for _, z_l := range r.z {
for _, z_l := range zMap {
z.Add(z_l)
}

Expand Down
Loading