diff --git a/aggregator.go b/aggregator.go
new file mode 100644
index 0000000..f16a375
--- /dev/null
+++ b/aggregator.go
@@ -0,0 +1,219 @@
+package zeroconf
+
+import (
+	"log"
+	"math/rand"
+	"sync"
+	"time"
+
+	"github.com/miekg/dns"
+)
+
+const (
+	// RFC6762 Section 6: In any case where there may be multiple responses,
+	// each responder SHOULD delay its response by a random amount of time
+	// selected with uniform random distribution in the range 20-120 ms.
+	// RFC6762 Section 6.3: For query messages containing more than one
+	// question, all (non-defensive) answers SHOULD be randomly delayed in
+	// the range 20-120 ms.
+	responseMinDelay = 20 * time.Millisecond
+	responseMaxDelay = 120 * time.Millisecond
+
+	// RFC6762 Section 6.4: Earlier responses may be delayed by up to an
+	// additional 500ms to permit aggregation with other responses scheduled
+	// to go out a little later.
+	responseMaxAggregationDelay = 500 * time.Millisecond
+)
+
+// pendingResp holds a pending multicast response awaiting aggregated delivery.
+type pendingResp struct {
+	msg       *dns.Msg    // accumulated response; later responses are merged into it
+	firstSeen time.Time   // when the first response of this batch was scheduled
+	ifIndex   int         // interface index handed to Server.multicastResponse
+	timer     *time.Timer // delivery timer; replaced whenever the batch is rescheduled
+}
+
+// responseAggregator implements RFC6762 Section 6.4 Response Aggregation.
+//
+// RFC6762 Section 6.4 requires that a responder, for the sake of network
+// efficiency, aggregate as many responses as possible into a single Multicast
+// DNS response message. Earlier responses SHOULD be delayed by up to an
+// additional 500ms if that will permit them to be aggregated with other
+// responses.
+//
+// This reduces network traffic when many nodes are present on the network.
+//
+// The pending map and the closed flag are guarded by mu.
+type responseAggregator struct {
+	mu      sync.Mutex
+	pending map[int]*pendingResp // ifIndex -> pending aggregated response
+	closed  bool                 // set by shutdown; schedule drops responses afterwards
+	server  *Server
+}
+
+// newResponseAggregator returns an aggregator that delivers its batched
+// responses through s.multicastResponse.
+func newResponseAggregator(s *Server) *responseAggregator {
+	return &responseAggregator{
+		pending: make(map[int]*pendingResp),
+		server:  s,
+	}
+}
+
+// randomResponseDelay picks the RFC6762 Section 6 / 6.3 response delay,
+// uniformly distributed in [responseMinDelay, responseMaxDelay).
+func randomResponseDelay() time.Duration {
+	jitter := rand.Int63n(int64(responseMaxDelay - responseMinDelay))
+	return responseMinDelay + time.Duration(jitter)
+}
+
+// schedule schedules a multicast response for aggregated delivery.
+//
+// RFC6762 Section 6.4: If a response for this interface is already pending
+// within the aggregation window (500ms), the new response is merged into it
+// rather than sending a separate packet. Otherwise, a new response is
+// scheduled with a random delay of 20-120ms (RFC6762 Section 6 / 6.3).
+//
+// The first response of a batch is stored as a copy; later ones are merged
+// into that copy record by record. A nil msg, or any msg scheduled after
+// shutdown, is dropped.
+func (a *responseAggregator) schedule(msg *dns.Msg, ifIndex int) {
+	if msg == nil {
+		return
+	}
+
+	a.mu.Lock()
+	if a.closed {
+		// The connections are going away; arming a timer now would only
+		// produce a failed send after shutdown.
+		a.mu.Unlock()
+		return
+	}
+
+	p, ok := a.pending[ifIndex]
+	if !ok {
+		// Nothing pending for this interface: start a new batch.
+		p = &pendingResp{
+			msg:       msg.Copy(),
+			firstSeen: time.Now(),
+			ifIndex:   ifIndex,
+		}
+		a.pending[ifIndex] = p
+		a.arm(p, randomResponseDelay())
+		a.mu.Unlock()
+		return
+	}
+
+	mergeMsg(p.msg, msg)
+
+	elapsed := time.Since(p.firstSeen)
+	if elapsed >= responseMaxAggregationDelay {
+		// The batch has used up its aggregation window: send it right away.
+		p.timer.Stop()
+		delete(a.pending, ifIndex)
+		a.mu.Unlock()
+		a.send(p) // network I/O happens outside the lock
+		return
+	}
+
+	// Reschedule delivery from *now* by a fresh random delay, but never past
+	// the end of the aggregation window.
+	delay := randomResponseDelay()
+	if remaining := responseMaxAggregationDelay - elapsed; delay > remaining {
+		delay = remaining
+	}
+	// Stop is best-effort: a callback that already fired either delivers the
+	// batch itself or finds it gone from a.pending and returns.
+	p.timer.Stop()
+	a.arm(p, delay)
+	a.mu.Unlock()
+}
+
+// arm (re)starts the delivery timer of p. The caller must hold a.mu.
+func (a *responseAggregator) arm(p *pendingResp, delay time.Duration) {
+	p.timer = time.AfterFunc(delay, func() { a.deliver(p) })
+}
+
+// deliver is the timer callback: it sends p unless p has already been flushed,
+// superseded or cancelled by shutdown in the meantime.
+func (a *responseAggregator) deliver(p *pendingResp) {
+	a.mu.Lock()
+	if a.pending[p.ifIndex] != p {
+		a.mu.Unlock()
+		return
+	}
+	delete(a.pending, p.ifIndex)
+	a.mu.Unlock()
+
+	a.send(p)
+}
+
+// send multicasts p's accumulated message and logs any send error. A batch
+// without answers is skipped. It must be called without holding a.mu and only
+// after p has been removed from a.pending, so nothing merges into p.msg
+// concurrently.
+func (a *responseAggregator) send(p *pendingResp) {
+	if len(p.msg.Answer) == 0 {
+		return
+	}
+	if err := a.server.multicastResponse(p.msg, p.ifIndex); err != nil {
+		log.Printf("[ERR] zeroconf: failed to send aggregated response: %v", err)
+	}
+}
+
+// shutdown cancels all pending responses without sending them and makes every
+// later call to schedule a no-op.
+// Must be called before closing the network connections.
+func (a *responseAggregator) shutdown() {
+	a.mu.Lock()
+	defer a.mu.Unlock()
+
+	a.closed = true
+	for ifIndex, p := range a.pending {
+		p.timer.Stop()
+		delete(a.pending, ifIndex)
+	}
+}
+
+// mergeMsg merges records from src into dst, skipping duplicates.
+// RFC6762 Section 6.4: aggregate as many responses as possible into a single message.
+//
+// An answer from src that dst so far only carried as an additional record is
+// moved from Extra to Answer, so the merge never duplicates it across sections.
+func mergeMsg(dst, src *dns.Msg) {
+	for _, rr := range src.Answer {
+		if containsRR(dst.Answer, rr) {
+			continue
+		}
+		dst.Answer = append(dst.Answer, rr)
+		dst.Extra = removeRR(dst.Extra, rr)
+	}
+	for _, rr := range src.Extra {
+		// Do not add to Extra if already present in Answer or Extra.
+		if !containsRR(dst.Answer, rr) && !containsRR(dst.Extra, rr) {
+			dst.Extra = append(dst.Extra, rr)
+		}
+	}
+}
+
+// containsRR reports whether list contains a record equivalent to rr.
+func containsRR(list []dns.RR, rr dns.RR) bool {
+	for _, r := range list {
+		if dns.IsDuplicate(r, rr) {
+			return true
+		}
+	}
+	return false
+}
+
+// removeRR returns list without the records equivalent to rr. It filters in
+// place, reusing list's backing array.
+func removeRR(list []dns.RR, rr dns.RR) []dns.RR {
+	kept := list[:0]
+	for _, r := range list {
+		if !dns.IsDuplicate(r, rr) {
+			kept = append(kept, r)
+		}
+	}
+	return kept
+}
diff --git a/server.go b/server.go
index d895acd..78c1006 100644
--- a/server.go
+++ b/server.go
@@ -169,10 +169,11 @@ const (
 
 // Server structure encapsulates both IPv4/IPv6 UDP connections
 type Server struct {
-	service  *ServiceEntry
-	ipv4conn *ipv4.PacketConn
-	ipv6conn *ipv6.PacketConn
-	ifaces   []net.Interface
+	service    *ServiceEntry
+	ipv4conn   *ipv4.PacketConn
+	ipv6conn   *ipv6.PacketConn
+	ifaces     []net.Interface
+	aggregator *responseAggregator // RFC6762 6.4 Response Aggregation
 
 	shouldShutdown chan struct{}
 	shutdownLock   sync.Mutex
@@ -203,6 +204,7 @@ func newServer(ifaces []net.Interface, opts serverOpts) (*Server, error) {
 		ttl:            opts.ttl,
 		shouldShutdown: make(chan struct{}),
 	}
+	s.aggregator = newResponseAggregator(s)
 
 	return s, nil
 }
@@ -241,6 +243,9 @@ func (s *Server) Shutdown() {
 		return
 	}
 
+	// Cancel any pending aggregated responses before closing connections.
+	s.aggregator.shutdown()
+
 	if err := s.unregister(); err != nil {
 		log.Printf("failed to unregister: %s", err)
 	}
@@ -326,39 +331,60 @@ func (s *Server) handleQuery(query *dns.Msg, ifIndex int, from net.Addr) error {
 		return nil
 	}
 
-	// Handle each question
+	// RFC6762 6.4: Aggregate all multicast responses for this query into a
+	// single message. This reduces network traffic when many nodes are present.
+	multicastResp := dns.Msg{}
+	multicastResp.SetReply(query)
+	multicastResp.Compress = true
+	multicastResp.RecursionDesired = false
+	multicastResp.Authoritative = true
+	multicastResp.Question = nil // RFC6762 section 6: responses MUST NOT contain any questions
+	multicastResp.Answer = []dns.RR{}
+	multicastResp.Extra = []dns.RR{}
+
 	var err error
 	for _, q := range query.Question {
-		resp := dns.Msg{}
-		resp.SetReply(query)
-		resp.Compress = true
-		resp.RecursionDesired = false
-		resp.Authoritative = true
-		resp.Question = nil // RFC6762 section 6 "responses MUST NOT contain any questions"
-		resp.Answer = []dns.RR{}
-		resp.Extra = []dns.RR{}
-		if err = s.handleQuestion(q, &resp, query, ifIndex); err != nil {
-			// log.Printf("[ERR] zeroconf: failed to handle question %v: %v", q, err)
+		// Use a per-question scratch buffer so that isKnownAnswer's
+		// "resp.Answer = nil" cannot clobber answers already accumulated
+		// from previous questions into multicastResp.
+		perQ := dns.Msg{}
+		perQ.Answer = []dns.RR{}
+		perQ.Extra = []dns.RR{}
+		if e := s.handleQuestion(q, &perQ, query, ifIndex); e != nil {
+			// log.Printf("[ERR] zeroconf: failed to handle question %v: %v", q, e)
+			err = e
 			continue
 		}
-		// Check if there is an answer
-		if len(resp.Answer) == 0 {
+		if len(perQ.Answer) == 0 {
 			continue
 		}
 
 		if isUnicastQuestion(q) {
-			// Send unicast
-			if e := s.unicastResponse(&resp, ifIndex, from); e != nil {
+			// Unicast responses are sent immediately without aggregation.
+			unicastResp := dns.Msg{}
+			unicastResp.SetReply(query)
+			unicastResp.Compress = true
+			unicastResp.RecursionDesired = false
+			unicastResp.Authoritative = true
+			unicastResp.Question = nil // RFC6762 section 6 "responses MUST NOT contain any questions"
+			unicastResp.Answer = perQ.Answer
+			unicastResp.Extra = perQ.Extra
+			if e := s.unicastResponse(&unicastResp, ifIndex, from); e != nil {
 				err = e
 			}
 		} else {
-			// Send mulicast
-			if e := s.multicastResponse(&resp, ifIndex); e != nil {
-				err = e
-			}
+			// Merge answers into the aggregated multicast response.
+			mergeMsg(&multicastResp, &perQ)
 		}
 	}
 
+	// Schedule the aggregated multicast response.
+	// RFC6762 Section 6.4: the aggregator will delay delivery by 20-120ms
+	// and merge with other pending responses for the same interface.
+	if len(multicastResp.Answer) > 0 {
+		s.aggregator.schedule(&multicastResp, ifIndex)
+	}
+
 	return err
 }
 