From a8165fc22fe9feeea70906d36e40d3d3d0d1dd14 Mon Sep 17 00:00:00 2001 From: Grigory Buteyko Date: Mon, 8 Jun 2026 19:06:40 +0300 Subject: [PATCH] return RPC tests --- pkg/rpc/client_test.go | 43 ++ pkg/rpc/clientmulti_test.go | 290 ++++++++++++ pkg/rpc/clientserver_test.go | 223 ++++++++++ pkg/rpc/crypto_test.go | 154 +++++++ pkg/rpc/forward_test.go | 139 ++++++ pkg/rpc/gracefull_shutdown_test.go | 145 ++++++ pkg/rpc/handshake_test.go | 337 ++++++++++++++ pkg/rpc/packetconn_test.go | 168 +++++++ pkg/rpc/rpc_format_test.go | 133 ++++++ pkg/rpc/rpc_test.go | 66 +++ pkg/rpc/server_longpoll_test.go | 680 +++++++++++++++++++++++++++++ pkg/rpc/server_shutdown_test.go | 194 ++++++++ pkg/rpc/udp_server_test.go | 47 ++ 13 files changed, 2619 insertions(+) create mode 100644 pkg/rpc/client_test.go create mode 100644 pkg/rpc/clientmulti_test.go create mode 100644 pkg/rpc/clientserver_test.go create mode 100644 pkg/rpc/crypto_test.go create mode 100644 pkg/rpc/forward_test.go create mode 100644 pkg/rpc/gracefull_shutdown_test.go create mode 100644 pkg/rpc/handshake_test.go create mode 100644 pkg/rpc/packetconn_test.go create mode 100644 pkg/rpc/rpc_format_test.go create mode 100644 pkg/rpc/rpc_test.go create mode 100644 pkg/rpc/server_longpoll_test.go create mode 100644 pkg/rpc/server_shutdown_test.go create mode 100644 pkg/rpc/udp_server_test.go diff --git a/pkg/rpc/client_test.go b/pkg/rpc/client_test.go new file mode 100644 index 000000000..04e160649 --- /dev/null +++ b/pkg/rpc/client_test.go @@ -0,0 +1,43 @@ +// Copyright 2025 V Kontakte LLC +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package rpc + +import ( + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestNewClient(t *testing.T) { + var logStr string + c := NewClient( + ClientWithLogf(func(format string, args ...any) { + logStr = fmt.Sprintf("my prefix "+format, args...) + }), + ClientWithForceEncryption(true), + ClientWithCryptoKey("crypto-key-crypto-key-crypto-key"), + ClientWithConnReadBufSize(123), + ClientWithConnWriteBufSize(456), + ClientWithPacketTimeout(2*time.Second), + ClientWithTrustedSubnetGroups([][]string{{"10.32.0.0/11"}}), + ).(*ClientImpl) + + require.Equal(t, 123, c.opts.ConnReadBufSize) + require.Equal(t, 456, c.opts.ConnWriteBufSize) + require.Equal(t, 2*time.Second, c.opts.PacketTimeout) + require.Equal(t, "crypto-key-crypto-key-crypto-key", c.opts.CryptoKey) + require.Equal(t, true, c.opts.ForceEncryption) + + c.Logf("123") + require.Equal(t, "my prefix 123", logStr) + + expectedTrustedSubnetGroups, errs := ParseTrustedSubnets([][]string{{"10.32.0.0/11"}}) + require.Equal(t, expectedTrustedSubnetGroups, c.opts.TrustedSubnetGroups) + require.Nil(t, errs) +} diff --git a/pkg/rpc/clientmulti_test.go b/pkg/rpc/clientmulti_test.go new file mode 100644 index 000000000..49b1cb44e --- /dev/null +++ b/pkg/rpc/clientmulti_test.go @@ -0,0 +1,290 @@ +// Copyright 2025 V Kontakte LLC +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package rpc + +import ( + "context" + "errors" + "net" + "sync" + "testing" + + "pgregory.net/rapid" +) + +func TestDoMultiErrorError(t *testing.T) { + t.Parallel() + + addr := NetAddr{ + Network: "tcp4", + Address: "127.0.0.1:10000", + } + wrappedErr := errors.New("boom") + + tests := []struct { + name string + input DoMultiError + expected string + }{ + { + name: "with wrapped error", + input: DoMultiError{ + Addr: addr, + Err: wrappedErr, + msg: "failed to prepare request for", + }, + expected: "failed to prepare request for tcp4://127.0.0.1:10000: boom", + }, + { + name: "without wrapped error", + input: DoMultiError{ + Addr: addr, + msg: "failed to prepare request for", + }, + expected: "failed to prepare request for tcp4://127.0.0.1:10000", + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := tt.input.Error(); got != tt.expected { + t.Fatalf("unexpected Error() string: got %q, want %q", got, tt.expected) + } + }) + } +} + +func TestDoMultiErrorUnwrap(t *testing.T) { + t.Parallel() + + innerErr := errors.New("inner") + err := DoMultiError{ + Addr: NetAddr{ + Network: "tcp4", + Address: "127.0.0.1:10000", + }, + Err: innerErr, + msg: "failed to handle response from", + } + + if !errors.Is(err, innerErr) { + t.Fatalf("expected wrapped error %q, got %q", innerErr, err) + } +} + +func TestRPCMultiRoundtrip(t *testing.T) { + t.Parallel() + + // this is not really a property-based test, since it is not deterministic + // however, biased integer generators from rapid are very convenient + rapid.Check(t, testRPCMultiRoundtrip) +} + +func TestDoMultiReturnsDoMultiErrorWithActorID(t *testing.T) { + t.Parallel() + + var c ClientImpl + prepareErr := errors.New("prepare failed") + addr := NetAddr{ + Network: "tcp4", + Address: "127.0.0.1:10000", + } + const actorID int64 = 42 + + err := c.DoMulti( + context.Background(), + []NetAddr{addr}, + func(_ NetAddr, req *Request) error { + req.ActorID = actorID + return prepareErr + }, + func(_ NetAddr, _ *Response, _ error) error { return nil }, + ) + if err == nil { + t.Fatal("expected error") + } + + var doMultiErr DoMultiError + if !errors.As(err, &doMultiErr) { + t.Fatalf("expected DoMultiError, got %T", err) + } + if doMultiErr.Addr != addr { + t.Fatalf("unexpected addr in error: got %+v, want %+v", doMultiErr.Addr, addr) + } + if doMultiErr.ActorID != actorID { + t.Fatalf("unexpected actorID in error: got %d, want %d", doMultiErr.ActorID, actorID) + } + if !errors.Is(err, prepareErr) { + t.Fatalf("expected wrapped prepare error %q, got %q", prepareErr, err) + } +} + +func TestDoMultiReturnsDoMultiErrorWithActorIDFromProcessResponse(t *testing.T) { + t.Parallel() + + ln, err := net.Listen("tcp4", "127.0.0.1:") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + s := NewServer( + ServerWithHandler(handler), + ServerWithCryptoKeys(testCryptoKeys), + ) + serverErr := make(chan error, 1) + go func() { + serverErr <- s.Serve(ln) + }() + defer func() { + if closeErr := s.Close(); closeErr != nil { + t.Errorf("failed to close server: %v", closeErr) + } + if serveErr := <-serverErr; serveErr != nil { + t.Errorf("server serve error: %v", serveErr) + } + }() + + c := NewClient( + ClientWithProtocolVersion(LatestProtocolVersion), + ClientWithCryptoKey(testCryptoKeys[0]), + ) + defer func() { + if closeErr := c.Close(); closeErr != nil { + t.Errorf("failed to close client: %v", closeErr) + } + }() + + addr := NetAddr{ + Network: "tcp4", + Address: ln.Addr().String(), + } + const actorID int64 = 777 + processErr := errors.New("process failed") + + err = c.DoMulti( + context.Background(), + []NetAddr{addr}, + func(_ NetAddr, req *Request) error { + _ = prepareTestRequest(req) + req.ActorID = actorID + return nil + }, + func(_ NetAddr, _ *Response, _ error) error { + return processErr + }, + ) + if err == nil { + t.Fatal("expected error") + } + + var doMultiErr DoMultiError + if !errors.As(err, &doMultiErr) { + t.Fatalf("expected DoMultiError, got %T", err) + } + if doMultiErr.Addr != addr { + t.Fatalf("unexpected addr in error: got %+v, want %+v", doMultiErr.Addr, addr) + } + if doMultiErr.ActorID != actorID { + t.Fatalf("unexpected actorID in error: got %d, want %d", doMultiErr.ActorID, actorID) + } + if !errors.Is(err, processErr) { + t.Fatalf("expected wrapped process error %q, got %q", processErr, err) + } +} + +func testRPCMultiRoundtrip(t *rapid.T) { + ln, err := net.Listen("tcp4", "127.0.0.1:") + if err != nil { + t.Fatal(err) + } + + clients := rapid.SliceOf(rapid.Custom(genClient)).Draw(t, "clients") + numRequests := rapid.IntRange(1, 10).Draw(t, "numRequests") + + s := NewServer( + ServerWithHandler(handler), + ServerWithCryptoKeys(testCryptoKeys), + ServerWithMaxConns(rapid.IntRange(0, 3).Draw(t, "maxConns")), + ServerWithMaxWorkers(rapid.IntRange(-1, 3).Draw(t, "maxWorkers")), + ServerWithConnReadBufSize(rapid.IntRange(0, 64).Draw(t, "connReadBufSize")), + ServerWithConnWriteBufSize(rapid.IntRange(0, 64).Draw(t, "connWriteBufSize")), + ServerWithRequestBufSize(rapid.IntRange(512, 1024).Draw(t, "requestBufSize")), + ServerWithResponseBufSize(rapid.IntRange(512, 1024).Draw(t, "responseBufSize")), + ) + serverErr := make(chan error) + go func() { + serverErr <- s.Serve(ln) + }() + + var wg sync.WaitGroup + for _, c := range clients { + wg.Add(1) + go func(c Client) { + defer wg.Done() + + m := c.Multi(numRequests) + defer m.Close() + + queryIDs := map[int64]struct{}{} + queryIDToBodyCopy := map[int64]string{} + + for j := 0; j < numRequests; j++ { + req := c.GetRequest() + queryID := req.QueryID() + bodyCopy := prepareTestRequest(req) + + err := m.Start(context.Background(), "tcp4", ln.Addr().String(), req) + if err != nil { + t.Errorf("failed to start request %v: %v", j, err) + } + + queryIDs[queryID] = struct{}{} + queryIDToBodyCopy[queryID] = bodyCopy + } + + for k := 0; k < numRequests; k++ { + var queryID int64 + var resp *Response + var err error + if k%2 == 0 { + for qID := range queryIDs { + queryID = qID // get the first request ID from the map + break + } + resp, err = m.Wait(context.Background(), queryID) + } else { + queryID, resp, err = m.WaitAny(context.Background()) + } + + bodyCopy := queryIDToBodyCopy[queryID] + delete(queryIDToBodyCopy, queryID) + delete(queryIDs, queryID) + checkTestResponse(t, resp, err, bodyCopy) + c.PutResponse(resp) + } + + err := c.Close() + if err != nil { + t.Errorf("failed to close client: %v", err) + } + }(c) + } + + wg.Wait() + + err = s.Close() + if err != nil { + t.Fatal(err) + } + err = <-serverErr + if err != nil { + t.Fatal(err) + } +} diff --git a/pkg/rpc/clientserver_test.go b/pkg/rpc/clientserver_test.go new file mode 100644 index 000000000..81aac5541 --- /dev/null +++ b/pkg/rpc/clientserver_test.go @@ -0,0 +1,223 @@ +// Copyright 2025 V Kontakte LLC +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package rpc + +import ( + "context" + "encoding/binary" + "net" + "reflect" + "strconv" + "sync" + "testing" + + "pgregory.net/rand" + "pgregory.net/rapid" + + "github.com/VKCOM/tl/internal/vkgo/pkg/basictl" + "github.com/VKCOM/tl/pkg/rpc/internal/gen/tl" +) + +const ( + testRequestType = uint32(0x12345678) + testRequestMethod = "test.request" +) + +var testCryptoKeys = []string{"test-crypto-key-crypto-key-crypto-key", "2-crypto-key-crypto-key-crypto-key", "1test-crypto-key-crypto-key-crypto-key"} + +func handler(_ context.Context, hctx *HandlerContext) (err error) { + bodyCopy := string(hctx.Request) + hctx.Request, err = basictl.NatReadExactTag(hctx.Request, testRequestType) + if err != nil { + return err + } + + var n uint32 + if hctx.bodyFormatTL2 { // pretend we have variation in format + var sz int + hctx.Request, sz, err = basictl.TL2ParseSize(hctx.Request) + n = uint32(sz) + } else { + hctx.Request, err = basictl.NatRead(hctx.Request, &n) + } + if err != nil { + return err + } + + if n%7 == 0 { + hctx.ResponseExtra.SetBinlogPos(int64(n) * 100) + } + + if n%2 != 0 { + rpcErr := &Error{ + Code: int32(n), + Description: strconv.Itoa(int(n)), + } + if n/7%2 == 0 || hctx.BodyFormatTL2() { + return rpcErr // will serialize by default. + } + switch (n / 2) % 3 { + case 0: + // serialize manually to check parsing correctness + hctx.Response = basictl.NatWrite(hctx.Response, tl.ReqError{}.TLTag()) + hctx.Response = basictl.IntWrite(hctx.Response, rpcErr.Code) + hctx.Response = basictl.StringWrite(hctx.Response, rpcErr.Description) + return nil + case 1: + // serialize manually to check parsing correctness + hctx.Response = basictl.NatWrite(hctx.Response, tl.RpcReqResultErrorWrapped{}.TLTag()) + hctx.Response = basictl.IntWrite(hctx.Response, rpcErr.Code) + hctx.Response = basictl.StringWrite(hctx.Response, rpcErr.Description) + return nil + default: + // serialize manually to check parsing correctness + hctx.Response = basictl.NatWrite(hctx.Response, tl.RpcReqResultError{}.TLTag()) + hctx.Response = basictl.LongWrite(hctx.Response, 0) // unused + hctx.Response = basictl.IntWrite(hctx.Response, rpcErr.Code) + hctx.Response = basictl.StringWrite(hctx.Response, rpcErr.Description) + return nil + } + } + hctx.Response = append(hctx.Response, bodyCopy...) + return nil +} + +func prepareTestRequest(req *Request) string { + req.FunctionName = testRequestMethod + j := rand.Int63() + if j%2 == 0 { + req.ActorID = j + } + if (j/2)%2 == 0 { + req.Extra.SetIntForward(j) + } + n := uint32(req.QueryID()) + + if (j/4)%2 == 0 { + req.BodyFormatTL2 = true + } + req.Body = binary.LittleEndian.AppendUint32(req.Body, testRequestType) + if req.BodyFormatTL2 { // pretend we have variation in format + req.Body = basictl.TL2WriteSize(req.Body, int(n)) + } else { + req.Body = binary.LittleEndian.AppendUint32(req.Body, n) + } + req.Body = append(req.Body, "body"...) + return string(req.Body) +} + +func checkTestResponse(t *rapid.T, resp *Response, err error, bodyCopy string) { + n := uint32(resp.QueryID()) + if n%2 != 0 { + refErr := &Error{ + Code: int32(n), + Description: strconv.Itoa(int(n)), + } + + if !reflect.DeepEqual(err, refErr) { + t.Errorf("got error %q instead of %q", err, refErr) + } + } else if resp == nil || bodyCopy != string(resp.Body) { + t.Errorf("sent %q, got back %v (%v)", bodyCopy, resp, err) + } +} + +func dorequest(t *rapid.T, c Client, addr string) { + req := c.GetRequest() + bodyCopy := prepareTestRequest(req) + resp, err := c.Do(context.Background(), "tcp4", addr, req) + defer c.PutResponse(resp) + checkTestResponse(t, resp, err, bodyCopy) +} + +func TestRPCRoundtrip(t *testing.T) { + t.Parallel() + + // this is not really a property-based test, since it is not deterministic + // however, biased integer generators from rapid are very convenient + rapid.Check(t, testRPCRoundtrip) +} + +func genClient(t *rapid.T) Client { + return NewClient( + ClientWithProtocolVersion(LatestProtocolVersion), + ClientWithForceEncryption(rapid.Bool().Draw(t, "forceEncryption")), + ClientWithCryptoKey(testCryptoKeys[rapid.IntRange(0, 2).Draw(t, "cryptoKeyIndex")]), + ClientWithConnReadBufSize(rapid.IntRange(0, 64).Draw(t, "connReadBufSize")), + ClientWithConnWriteBufSize(rapid.IntRange(0, 64).Draw(t, "connWriteBufSize")), + ) +} + +func genServer(t *rapid.T) *Server { + return NewServer( + ServerWithHandler(handler), + ServerWithCryptoKeys(testCryptoKeys), + ServerWithMaxConns(rapid.IntRange(0, 3).Draw(t, "maxConns")), + ServerWithMaxWorkers(rapid.IntRange(-1, 3).Draw(t, "maxWorkers")), + ServerWithConnReadBufSize(rapid.IntRange(0, 64).Draw(t, "connReadBufSize")), + ServerWithConnWriteBufSize(rapid.IntRange(0, 64).Draw(t, "connWriteBufSize")), + ServerWithRequestBufSize(rapid.IntRange(512, 1024).Draw(t, "requestBufSize")), + ServerWithResponseBufSize(rapid.IntRange(512, 1024).Draw(t, "responseBufSize")), + ) +} + +func testRPCRoundtrip(t *rapid.T) { + ln, err := net.Listen("tcp4", "127.0.0.1:") + if err != nil { + t.Fatal(err) + } + + clients := rapid.SliceOf(rapid.Custom(genClient)).Draw(t, "clients") + numRequests := rapid.IntRange(1, 10).Draw(t, "numRequests") + + s := genServer(t) + + serverErr := make(chan error, 1) + var wg sync.WaitGroup + go func() { + serverErr <- s.Serve(ln) + }() + + for _, c := range clients { + wg.Add(1) + go func(c Client) { + defer wg.Done() + + var cwg sync.WaitGroup + for j := 0; j < numRequests; j++ { + cwg.Add(1) + go func(j int) { + defer cwg.Done() + + dorequest(t, c, ln.Addr().String()) + }(j) + } + + cwg.Wait() + + err := c.Close() + if err != nil { + t.Errorf("failed to close client: %v", err) + } + }(c) + } + // s.Shutdown() + // ctx, cancel := context.WithTimeout(context.Background(), time.Second) + // defer cancel() + // _ = s.CloseWait(ctx) + + wg.Wait() + + err = s.Close() + if err != nil { + t.Fatal(err) + } + err = <-serverErr + if err != nil { + t.Fatal(err) + } +} diff --git a/pkg/rpc/crypto_test.go b/pkg/rpc/crypto_test.go new file mode 100644 index 000000000..99d13ad6e --- /dev/null +++ b/pkg/rpc/crypto_test.go @@ -0,0 +1,154 @@ +// Copyright 2025 V Kontakte LLC +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package rpc + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/des" + "io" + "testing" + + "pgregory.net/rapid" +) + +type cryptoRWMachine struct { + buf *bytes.Buffer + r *cryptoReader + w *cryptoWriter + enc cipher.BlockMode + read *bytes.Buffer + written *bytes.Buffer + flushed int + encStart int +} + +func (c *cryptoRWMachine) init(t *rapid.T) { + rb := rapid.IntRange(0, 4*des.BlockSize).Draw(t, "rb") + wb := rapid.IntRange(0, 4*des.BlockSize).Draw(t, "wb") + + c.buf = &bytes.Buffer{} + c.r = newCryptoReader(c.buf, rb) + c.w = newCryptoWriter(c.buf, wb) + c.read = &bytes.Buffer{} + c.written = &bytes.Buffer{} +} + +func (c *cryptoRWMachine) Check(t *rapid.T) { + nr, nw := c.read.Len(), c.written.Len() + + if nr > nw { + t.Fatalf("read %v bytes, written %v bytes", nr, nw) + } + if !bytes.Equal(c.read.Bytes(), c.written.Bytes()[:nr]) { + t.Fatalf("read %q, but written %q", c.read.Bytes(), c.written.Bytes()[:nr]) + } +} + +func (c *cryptoRWMachine) Encrypt(t *rapid.T) { + if c.enc != nil { + t.Skip("already encrypted") + } + + key := rapid.SliceOfN(rapid.Byte(), 8, 8).Draw(t, "key") + e, err := des.NewCipher(key) + if err != nil { + t.Fatal(err) + } + + c.encStart = c.written.Len() + iv := rapid.SliceOfN(rapid.Byte(), des.BlockSize, des.BlockSize).Draw(t, "iv") + c.w.encrypt(cipher.NewCBCEncrypter(e, iv)) + c.enc = cipher.NewCBCDecrypter(e, iv) +} + +func (c *cryptoRWMachine) Read(t *rapid.T) { + n := rapid.IntRange(0, 32768).Draw(t, "n") + if c.encStart >= c.read.Len() && c.encStart < c.read.Len()+n && c.r.enc == nil { + n = c.encStart - c.read.Len() + } + p := make([]byte, n) + + shouldReadFlushed := c.read.Len()+n >= c.flushed + + m, err := c.r.Read(p) + if err != nil && err != io.EOF { + t.Fatalf("read failed: %v", err) + } + if m > len(p) { + t.Fatalf("long? read: %v instead of max %v", m, len(p)) + } + if shouldReadFlushed && c.read.Len()+m < c.flushed { + t.Fatalf("read only %v total, with %v flushed", c.read.Len()+m, c.flushed) + } + + c.read.Write(p[:m]) + + if c.w.enc != nil && c.r.enc == nil && c.encStart == c.read.Len() { + c.r.encrypt(c.enc) + } +} + +func (c *cryptoRWMachine) Write(t *rapid.T) { + p := rapid.SliceOf(rapid.Byte()).Draw(t, "p") + q := append([]byte(nil), p...) + + n, err := c.w.Write(p) + if err != nil { + t.Fatalf("write failed: %v", err) + } + if n != len(p) { + t.Fatalf("short write: %v instead of %v", n, len(p)) + } + if !bytes.Equal(p, q) { + t.Fatalf("write buffer modified: %q instead of %q", p, q) + } + + c.written.Write(p) +} + +func (c *cryptoRWMachine) Flush(t *rapid.T) { + err := c.w.Flush() + if err != nil { + t.Fatalf("flush failed: %v", err) + } + + if c.enc != nil { + c.flushed = c.encStart + roundDownPow2(c.written.Len()-c.encStart, c.w.blockSize) + } else { + c.flushed = c.written.Len() + } +} + +func TestCryptoRWRoundtrip(t *testing.T) { + t.Parallel() + + rapid.Check(t, func(t *rapid.T) { + var m cryptoRWMachine + m.init(t) + t.Repeat(rapid.StateMachineActions(&m)) + }) +} + +func BenchmarkCryptoWriter_Write(b *testing.B) { + w := newCryptoWriter(io.Discard, 0) + e, err := aes.NewCipher(make([]byte, 16)) + if err != nil { + b.Fatal(err) + } + w.encrypt(cipher.NewCBCEncrypter(e, make([]byte, e.BlockSize()))) + b.ResetTimer() + + var msg [64]byte + for i := 0; i < b.N; i++ { + _, err = w.Write(msg[:]) + if err != nil { + b.Fatal(err) + } + } +} diff --git a/pkg/rpc/forward_test.go b/pkg/rpc/forward_test.go new file mode 100644 index 000000000..80170d633 --- /dev/null +++ b/pkg/rpc/forward_test.go @@ -0,0 +1,139 @@ +// Copyright 2025 V Kontakte LLC +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package rpc + +import ( + "net" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" + "pgregory.net/rapid" +) + +type forwardPacketMachine struct { + encryptClient bool + encryptServer bool + protocolClient uint32 + protocolServer uint32 + + client *PacketConn + server *PacketConn + proxyClient *PacketConn + proxyServer *PacketConn +} + +func newForwardPacketMachine(t *rapid.T) (_ forwardPacketMachine, _ func(), err error) { + startTime := uint32(time.Now().Unix()) + cryptoKey := "crypto_key" + strings.Repeat("_", 32) // crypto_key must be longer then 32 bytes + handshakeServer := func(conn net.Conn, forceEncryption bool) (*PacketConn, error) { + res := NewPacketConn(conn, DefaultServerRequestBufSize, DefaultServerResponseBufSize) + _, _, err := res.HandshakeServer([]string{cryptoKey}, nil, forceEncryption, startTime, 0) + return res, err + } + handshakeClient := func(conn net.Conn, version uint32, forceEncryption bool) (*PacketConn, error) { + res := NewPacketConn(conn, DefaultServerRequestBufSize, DefaultServerResponseBufSize) + err := res.HandshakeClient(cryptoKey, nil, forceEncryption, startTime, 0, 0, version) + return res, err + } + protocolVersion := func(label string) uint32 { + if rapid.Bool().Draw(t, label) { + return DefaultProtocolVersion + } else { + return LatestProtocolVersion + } + } + client, proxyServer := net.Pipe() + proxyClient, server := net.Pipe() + cancel := func() { + client.Close() + proxyServer.Close() + proxyClient.Close() + server.Close() + } + defer func() { + if err != nil { + cancel() + } + }() + res := forwardPacketMachine{ + encryptClient: rapid.Bool().Draw(t, "encrypt_client"), + encryptServer: rapid.Bool().Draw(t, "encrypt_server"), + protocolClient: protocolVersion("protocol_client"), + protocolServer: protocolVersion("protocol_proxy"), + } + var group errgroup.Group + group.Go(func() (err error) { + res.server, err = handshakeServer(server, res.encryptServer) + return err + }) + group.Go(func() (err error) { + if res.proxyServer, err = handshakeServer(proxyServer, res.encryptClient); err == nil { + res.proxyClient, err = handshakeClient(proxyClient, res.protocolServer, res.encryptServer) + } + return err + }) + if res.client, err = handshakeClient(client, res.protocolClient, res.encryptClient); err == nil { + err = group.Wait() + } + return res, cancel, err +} + +func (m *forwardPacketMachine) run(t *rapid.T) { + type message struct { + tip uint32 + body []byte + } + minBodyLen := 1 + legacyProtocol := m.protocolClient == 0 + if legacyProtocol { + minBodyLen = 4 + } + bodyBuf := make([]byte, 256) + for i := 0; i < 512; i++ { + var forward errgroup.Group + forward.Go(func() error { + res := ForwardPacket(m.proxyClient, m.proxyServer, bodyBuf, forwardPacketOptions{testEnv: false}) + return res.Error() + }) + sent := message{ + tip: 0x1234567, + body: rapid.SliceOfN(rapid.Byte(), minBodyLen, 1024).Draw(t, "body"), + } + if legacyProtocol { + sent.body = sent.body[:len(sent.body)-len(sent.body)%4] + } + err := m.client.WritePacket(sent.tip, sent.body, DefaultPacketTimeout) + require.NoError(t, err) + err = m.client.Flush() + require.NoError(t, err) + var receive errgroup.Group + var received message + receive.Go(func() (err error) { + received.tip, received.body, err = m.server.ReadPacket(nil, DefaultPacketTimeout) + return err + }) + require.NoError(t, forward.Wait()) + require.NoError(t, receive.Wait()) + if m.protocolServer == 0 { + writeAlignTo4 := int(-uint(len(sent.body)) & 3) + sent.body = append(sent.body, forwardPacketTrailer[writeAlignTo4]...) + } + require.Equal(t, sent, received) + } +} + +func TestForwardPacket(t *testing.T) { + rapid.Check(t, func(t *rapid.T) { + machine, shutdown, err := newForwardPacketMachine(t) + require.NoError(t, err) + machine.run(t) + shutdown() + }) +} diff --git a/pkg/rpc/gracefull_shutdown_test.go b/pkg/rpc/gracefull_shutdown_test.go new file mode 100644 index 000000000..8b4adc7ba --- /dev/null +++ b/pkg/rpc/gracefull_shutdown_test.go @@ -0,0 +1,145 @@ +// Copyright 2025 V Kontakte LLC +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package rpc + +import ( + "context" + "fmt" + "net" + "sync" + "sync/atomic" + "testing" + "time" + + "pgregory.net/rapid" +) + +func TestRPCGraceful(t *testing.T) { + t.Parallel() + + // this is not really a property-based test, since it is not deterministic + // however, biased integer generators from rapid are very convenient + rapid.Check(t, testRPCGraceful) +} + +func testRPCGraceful(t *rapid.T) { + if debugPrint { + fmt.Printf("---- testRPCGraceful\n") + } + var lc net.ListenConfig + lc.Control = ControlSetTCPReuse(true, true) + + ln, err := lc.Listen(context.Background(), "tcp4", "127.0.0.1:41577") + if err != nil { + t.Fatal(err) + } + addr := ln.Addr().String() + _ = ln.Close() // otherwise, we must listen to it or lose SYNs + var servers []*Server + serverNum := 0 + for i := 0; i != 10; i++ { + servers = append(servers, genServer(t)) + } + clients := []Client{genClient(t)} // rapid.SliceOf(rapid.Custom(genClient)).Draw(t, "clients") + numRequests := 10 // rapid.IntRange(1, 10).Draw(t, "numRequests") + + var smu sync.Mutex + var server *Server + + var wg sync.WaitGroup + var sg sync.WaitGroup + replaceServer := func() { + sg.Add(1) + smu.Lock() + if serverNum >= len(servers) { + smu.Unlock() + return + } + s := servers[serverNum] + serverNum++ + if server != nil { + if debugPrint { + fmt.Printf("%v shutdown of %p\n", time.Now(), server) + } + server.Shutdown() + } + server = s + smu.Unlock() + + ln2, err := lc.Listen(context.Background(), "tcp4", addr) + if err != nil { + t.Fatal(err) + } + go func() { + defer sg.Done() + if debugPrint { + fmt.Printf("%v serve of %p\n", time.Now(), s) + } + if err := s.Serve(ln2); err != nil { + t.Fatal(err) + } + if debugPrint { + fmt.Printf("%v quit of %p\n", time.Now(), s) + } + }() + } + + replaceServer() + + var reqCounter atomic.Int64 + + for _, c := range clients { + wg.Add(1) + go func(c Client) { + defer wg.Done() + + for j := 0; j < numRequests*2; j++ { + reqID := reqCounter.Add(1) + if reqID%int64(numRequests) == 0 { + replaceServer() + } + dorequest(t, c, addr) + } + if debugPrint { + fmt.Printf("%v before client close\n", time.Now()) + } + if err := c.Close(); err != nil { + t.Fatal(err) + } + if debugPrint { + fmt.Printf("%v after client close\n", time.Now()) + } + }(c) + } + wg.Wait() // All clients finished + if debugPrint { + fmt.Printf("%v all clients finished\n", time.Now()) + } + + smu.Lock() + if server != nil { + if debugPrint { + fmt.Printf("%v shutdown (last) of %p\n", time.Now(), server) + } + server.Shutdown() + } + smu.Unlock() + // ctx, cancel := context.WithTimeout(context.Background(), time.Second) + // defer cancel() + // _ = s.CloseWait(ctx) + + sg.Wait() + if debugPrint { + fmt.Printf("%v all servers finished\n", time.Now()) + } + + for _, s := range servers { + if err := s.Close(); err != nil { + t.Fatal(err) + } + } +} diff --git a/pkg/rpc/handshake_test.go b/pkg/rpc/handshake_test.go new file mode 100644 index 000000000..51368bf5b --- /dev/null +++ b/pkg/rpc/handshake_test.go @@ -0,0 +1,337 @@ +// Copyright 2025 V Kontakte LLC +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package rpc + +import ( + "bytes" + cryptorand "crypto/rand" + "encoding/binary" + "encoding/hex" + "reflect" + "testing" + + "github.com/VKCOM/tl/pkg/rpc/udp" + "pgregory.net/rapid" + + "golang.org/x/crypto/curve25519" +) + +func TestDeriveCryptoKeys(t *testing.T) { + clientNonce := [16]byte{'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p'} + serverNonce := [16]byte{'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P'} + serverSend := deriveCryptoKeys(false, "hren", 0x01020304, + clientNonce, 0x05060708, 0x090a, + serverNonce, 0x0d0e0f10, 0x1112, nil) + clientSend := deriveCryptoKeys(true, "hren", 0x01020304, + clientNonce, 0x05060708, 0x090a, + serverNonce, 0x0d0e0f10, 0x1112, nil) + if hex.EncodeToString(clientSend.Key[:]) != "28b5a5313b3ea9e2f6f0293e0748b2f743b0e112779faa77a3ee9d71ae70dda6" { + t.Fatalf("readKey") + } + if hex.EncodeToString(clientSend.IV[:]) != "80387128489168b336d998762bce6fef" { + t.Fatalf("readIV") + } + if hex.EncodeToString(serverSend.Key[:]) != "e3cf8557ea4ad963c3b637d466388403841d2e989a1fc684ac691c44b05ac9bb" { + t.Fatalf("writeKey") + } + if hex.EncodeToString(serverSend.IV[:]) != "1efd4c8aa43a87d1ea5488a1bc669269" { + t.Fatalf("writeIV") + } +} + +func TestDeriveCryptoKeysV1(t *testing.T) { + clientNonce := [16]byte{'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p'} + serverNonce := [16]byte{'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P'} + serverSend := deriveCryptoKeys(false, "hren", 0x01020304, + clientNonce, 0, 0, + serverNonce, 0x0d0e0f10, 0, nil) + clientSend := deriveCryptoKeys(true, "hren", 0x01020304, + clientNonce, 0, 0, + serverNonce, 0x0d0e0f10, 0, nil) + if hex.EncodeToString(clientSend.Key[:]) != "373374076f52d8f6bb5b063f17b9eb9fb4194e429cf02e207300add4c28a8e57" { + t.Fatalf("readKey") + } + if hex.EncodeToString(clientSend.IV[:]) != "cea8f827019de36741f73e5948aea5be" { + t.Fatalf("readIV") + } + if hex.EncodeToString(serverSend.Key[:]) != "3ce0c95487d99754688e0508a036c8c02727f297d0311db6273d69c07ac7a0d2" { + t.Fatalf("writeKey") + } + if hex.EncodeToString(serverSend.IV[:]) != "34411262ac3e172bc1a2d086b4f1ecb5" { + t.Fatalf("writeIV") + } +} + +func TestDeriveCryptoKeysV2(t *testing.T) { + clientScalar := [32]byte{'0', '1', '2', '3', '4', '4', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'} + serverScalar := [32]byte{'5', '6', '7', '8', '9', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'} + clientPoint, err := curve25519.X25519(clientScalar[:], curve25519.Basepoint) + if err != nil { + t.Fatalf("client D-H point error: %v", err) + } + serverPoint, err := curve25519.X25519(serverScalar[:], curve25519.Basepoint) + if err != nil { + t.Fatalf("server D-H point error:%v", err) + } + clientSharedSecret, err := curve25519.X25519(clientScalar[:], serverPoint) + if err != nil { + t.Fatalf("client D-H shared secret error:%v", err) + } + serverSharedSecret, err := curve25519.X25519(serverScalar[:], clientPoint) + if err != nil { + t.Fatalf("server D-H shared secret error:%v", err) + } + if hex.EncodeToString(serverSharedSecret) != hex.EncodeToString(clientSharedSecret) { + t.Fatalf("different shared secrets") + } + if hex.EncodeToString(clientSharedSecret) != "4541d9fd5263298736d6ecdfa8c5834e12b54e2ad3bb95a50d2085dd4075f458" { + t.Fatalf("clientSharedSecret") + } + + clientNonce := [16]byte{'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p'} + serverNonce := [16]byte{'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P'} + serverSend := deriveCryptoKeys(false, "hren", 0x01020304, + clientNonce, 0, 0, + serverNonce, 0x0d0e0f10, 0, serverSharedSecret) + clientSend := deriveCryptoKeys(true, "hren", 0x01020304, + clientNonce, 0, 0, + serverNonce, 0x0d0e0f10, 0, serverSharedSecret) + if hex.EncodeToString(clientSend.Key[:]) != "c513a88366728c719ffe885d943b0faa701ff7f0b061311b9af5fa5a0ec830ef" { + t.Fatalf("readKey") + } + if hex.EncodeToString(clientSend.IV[:]) != "bbaf9484282c1d021c21d9da05e822c0" { + t.Fatalf("readIV") + } + if hex.EncodeToString(serverSend.Key[:]) != "987d9938b0ea97bae1604e78d47131a5b0dc426054d5f9423d14f867480dce1d" { + t.Fatalf("writeKey") + } + if hex.EncodeToString(serverSend.IV[:]) != "cf55ffd9615629f9cc7fc6b14d9a48f8" { + t.Fatalf("writeIV") + } +} + +var result cryptoKeys + +func BenchmarkCryptoKeysV1(b *testing.B) { + var serverNonce [16]byte + _, _ = cryptorand.Read(serverNonce[:]) + for i := 0; i < b.N; i++ { + var clientNonce [16]byte + _, _ = cryptorand.Read(clientNonce[:]) + serverSend := deriveCryptoKeys(false, "hren", 0x01020304, + clientNonce, 0, 0, + serverNonce, 0x0d0e0f10, 0, nil) + clientSend := deriveCryptoKeys(true, "hren", 0x01020304, + clientNonce, 0, 0, + serverNonce, 0x0d0e0f10, 0, nil) + for j, b := range serverSend.Key { + result.Key[j] ^= b + } + for j, b := range serverSend.IV { + result.IV[j] ^= b + } + for j, b := range clientSend.Key { + result.Key[j] ^= b + } + for j, b := range clientSend.IV { + result.IV[j] ^= b + } + } +} + +func BenchmarkCryptoKeysV2(b *testing.B) { + var serverNonce [16]byte + _, _ = cryptorand.Read(serverNonce[:]) + var serverScalar [32]byte + _, _ = cryptorand.Read(serverScalar[:]) + + for i := 0; i < b.N; i++ { + var clientScalar [32]byte + _, _ = cryptorand.Read(clientScalar[:]) + + serverPoint, _ := curve25519.X25519(serverScalar[:], curve25519.Basepoint) + clientSharedSecret, _ := curve25519.X25519(clientScalar[:], serverPoint) + + var clientNonce [16]byte + _, _ = cryptorand.Read(clientNonce[:]) + serverSend := deriveCryptoKeys(false, "hren", 0x01020304, + clientNonce, 0, 0, + serverNonce, 0x0d0e0f10, 0, clientSharedSecret) + clientSend := deriveCryptoKeys(true, "hren", 0x01020304, + clientNonce, 0, 0, + serverNonce, 0x0d0e0f10, 0, clientSharedSecret) + for j, b := range serverSend.Key { + result.Key[j] ^= b + } + for j, b := range serverSend.IV { + result.IV[j] ^= b + } + for j, b := range clientSend.Key { + result.Key[j] ^= b + } + for j, b := range clientSend.IV { + result.IV[j] ^= b + } + } +} + +func TestDeriveCryptoKeysUDP(t *testing.T) { + localPID := NetPID{ + Ip: 0x01020304, + PortPid: 0x05060708, + Utime: 0x090a0b0c, + } + remotePID := NetPID{ + Ip: 0x11121314, + PortPid: 0x15161718, + Utime: 0x191a1b1c, + } + keys := udp.DeriveCryptoKeysUdp("hren", &localPID, &remotePID, 0x20212223) + if hex.EncodeToString(keys.ReadKey[:]) != "be5a0ca85071077fb48f030ab7627f88bf6c4ac3dc4c2c72b96d12692f4c33cf" { + t.Fatalf("ReadKey") + } + if hex.EncodeToString(keys.WriteKey[:]) != "ff861195b0cffc0562e10667acd5ee3b7d884fb2bdcd63a68e03a7294c18be60" { + t.Fatalf("WriteKey") + } +} + +func TestNetPIDRoundtrip(t *testing.T) { + rapid.Check(t, func(t *rapid.T) { + pid1 := NetPID{ + Ip: rapid.Uint32().Draw(t, "ip"), + PortPid: rapid.Uint32().Draw(t, "portPid"), + Utime: rapid.Uint32().Draw(t, "time"), + } + + b1 := pid1.WriteTL1(nil) + + var b2 bytes.Buffer + err := binary.Write(&b2, binary.LittleEndian, pid1) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(b1, b2.Bytes()) { + t.Fatalf("got 0x%x instead of 0x%x", b1, b2.Bytes()) + } + + var pid2 NetPID + _, err = pid2.ReadTL1(b1) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(pid1, pid2) { + t.Fatalf("read back %#v, expected %#v", pid2, pid1) + } + }) +} + +func TestInvokeExtraRoundtrip(t *testing.T) { + rapid.Check(t, func(t *rapid.T) { + extra1 := InvokeReqExtra{} + if rapid.Bool().Draw(t, "WaitBinlogPos") { + extra1.SetWaitBinlogPos(rapid.Int64().Draw(t, "WaitBinlogPos")) + } + if rapid.Bool().Draw(t, "SetStringForwardKeys") { + extra1.SetStringForwardKeys(rapid.SliceOfN(rapid.String(), 0, 2).Draw(t, "SetStringForwardKeys")) + if len(extra1.StringForwardKeys) == 0 { + extra1.StringForwardKeys = nil + } + } + if rapid.Bool().Draw(t, "SetIntForwardKeys") { + extra1.SetIntForwardKeys(rapid.SliceOfN(rapid.Int64(), 0, 2).Draw(t, "SetIntForwardKeys")) + if len(extra1.IntForwardKeys) == 0 { + extra1.IntForwardKeys = nil + } + } + if rapid.Bool().Draw(t, "StringForward") { + extra1.SetStringForward(rapid.String().Draw(t, "StringForward")) + } + if rapid.Bool().Draw(t, "IntForward") { + extra1.SetIntForward(rapid.Int64().Draw(t, "IntForward")) + } + if rapid.Bool().Draw(t, "CustomTimeoutMs") { + extra1.SetCustomTimeoutMs(rapid.Int32().Draw(t, "CustomTimeoutMs")) + } + if rapid.Bool().Draw(t, "SupportedCompressionVersion") { + extra1.SetSupportedCompressionVersion(rapid.Int32().Draw(t, "SupportedCompressionVersion")) + } + if rapid.Bool().Draw(t, "RandomDelay") { + extra1.SetRandomDelay(rapid.Float64().Draw(t, "RandomDelay")) + } + if rapid.Bool().Draw(t, "ReturnBinlogPos") { + extra1.SetReturnBinlogPos(true) + } + if rapid.Bool().Draw(t, "NoResult") { + extra1.SetNoResult(true) + } + + b1 := extra1.WriteTL1(nil) + + var extra2 InvokeReqExtra + if _, err := extra2.ReadTL1(b1); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(extra1, extra2) { + t.Fatalf("read back %#v, expected %#v", extra2, extra1) + } + }) +} + +func TestResultExtraRoundtrip(t *testing.T) { + rapid.Check(t, func(t *rapid.T) { + pid1 := NetPID{ + Ip: rapid.Uint32().Draw(t, "ip"), + PortPid: rapid.Uint32().Draw(t, "portPid"), + Utime: rapid.Uint32().Draw(t, "time"), + } + + extra1 := ResponseExtra{} + if rapid.Bool().Draw(t, "BinlogPos") { + extra1.SetBinlogPos(rapid.Int64().Draw(t, "BinlogPos")) + } + if rapid.Bool().Draw(t, "BinlogTime") { + extra1.SetBinlogTime(rapid.Int64().Draw(t, "BinlogTime")) + } + if rapid.Bool().Draw(t, "EnginePID") { + extra1.SetEnginePid(pid1) + } + if rapid.Bool().Draw(t, "SetRequestSize") { + extra1.SetRequestSize(rapid.Int32().Draw(t, "SetRequestSize")) + } + if rapid.Bool().Draw(t, "SetResponseSize") { + extra1.SetResponseSize(rapid.Int32().Draw(t, "SetResponseSize")) + } + if rapid.Bool().Draw(t, "SetFailedSubqueries") { + extra1.SetFailedSubqueries(rapid.Int32().Draw(t, "SetFailedSubqueries")) + } + if rapid.Bool().Draw(t, "SetCompressionVersion") { + extra1.SetCompressionVersion(rapid.Int32().Draw(t, "SetCompressionVersion")) + } + if rapid.Bool().Draw(t, "SetStats") { + extra1.SetStats(rapid.MapOfN(rapid.String(), rapid.String(), 0, 2).Draw(t, "SetStats")) + if len(extra1.Stats) == 0 { + extra1.Stats = nil // canonical form + } + } + if rapid.Bool().Draw(t, "SetViewNumber") { + extra1.SetViewNumber(rapid.Int64Min(0).Draw(t, "SetViewNumber")) + if rapid.Bool().Draw(t, "SetEpochNumber") { + extra1.EpochNumber = rapid.Int64Min(0).Draw(t, "SetEpochNumber") + } + } + + b1 := extra1.WriteTL1(nil) + + var extra2 ResponseExtra + if _, err := extra2.ReadTL1(b1); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(extra1, extra2) { + t.Fatalf("read back %#v, expected %#v", extra2, extra1) + } + }) +} diff --git a/pkg/rpc/packetconn_test.go b/pkg/rpc/packetconn_test.go new file mode 100644 index 000000000..957334cd8 --- /dev/null +++ b/pkg/rpc/packetconn_test.go @@ -0,0 +1,168 @@ +// Copyright 2025 V Kontakte LLC +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package rpc + +import ( + "bytes" + "crypto/aes" + "fmt" + "net" + "sync" + "testing" + "time" + + "pgregory.net/rapid" +) + +const ( + rwTimeout = 1000 * time.Millisecond // shorter timeout is better for shrinking, but can lead to flaky tests +) + +type packetContent struct { + packetType uint32 + body []byte +} + +type connEx struct { + pc *PacketConn + send []packetContent + recv []packetContent +} + +type packetConnMachine struct { + c1 *connEx + c2 *connEx +} + +func (p *packetConnMachine) init(t *rapid.T) { + nc1, nc2 := net.Pipe() + rb1 := rapid.IntRange(0, 4*aes.BlockSize).Draw(t, "rb1") + wb1 := rapid.IntRange(0, 4*aes.BlockSize).Draw(t, "wb1") + rb2 := rapid.IntRange(0, 4*aes.BlockSize).Draw(t, "rb2") + wb2 := rapid.IntRange(0, 4*aes.BlockSize).Draw(t, "wb2") + + p.c1 = &connEx{ + pc: NewPacketConn(nc1, rb1, wb1), + } + p.c2 = &connEx{ + pc: NewPacketConn(nc2, rb2, wb2), + } + + enc := rapid.Bool().Draw(t, "enc") + if enc { + key1 := rapid.SliceOfN(rapid.Byte(), 16, 16).Draw(t, "key1") + key2 := rapid.SliceOfN(rapid.Byte(), 16, 16).Draw(t, "key2") + iv1 := rapid.SliceOfN(rapid.Byte(), aes.BlockSize, aes.BlockSize).Draw(t, "iv1") + iv2 := rapid.SliceOfN(rapid.Byte(), aes.BlockSize, aes.BlockSize).Draw(t, "iv2") + + err := p.c1.pc.encrypt(key1, iv1, key2, iv2) + if err != nil { + t.Fatal(err) + } + + err = p.c2.pc.encrypt(key2, iv2, key1, iv1) + if err != nil { + t.Fatal(err) + } + } +} + +func (p *packetConnMachine) cleanup() { + _ = p.c1.pc.Close() + _ = p.c2.pc.Close() +} + +func (p *packetConnMachine) Check(t *rapid.T) { + if i := len(p.c1.send) - 1; i >= 0 { + if !equalPackets(p.c1.send[i], p.c2.recv[i]) { + t.Fatalf("c1 send %#v, c2 recv %#v", p.c1.send[i], p.c2.recv[i]) + } + } + if i := len(p.c2.send) - 1; i >= 0 { + if !equalPackets(p.c2.send[i], p.c1.recv[i]) { + t.Fatalf("c2 send %#v, c1 recv %#v", p.c2.send[i], p.c2.recv[i]) + } + } +} + +func equalPackets(p packetContent, q packetContent) bool { + return p.packetType == q.packetType && bytes.Equal(p.body, q.body) +} + +func (p *packetConnMachine) Send12(t *rapid.T) { + send(t, p.c1, p.c2) +} + +func (p *packetConnMachine) Send21(t *rapid.T) { + send(t, p.c2, p.c1) +} + +func send(t *rapid.T, from *connEx, to *connEx) { + t.Helper() + + pc := packetContent{ + body: bytes.Repeat(rapid.SliceOf(rapid.Byte()).Draw(t, "body"), 4), + } + if from.pc.writeSeqNum == startSeqNum { + pc.packetType = packetTypeRPCNonce + } else if from.pc.writeSeqNum == startSeqNum+1 { + pc.packetType = packetTypeRPCHandshake + } else { + pc.packetType = rapid.Uint32().Draw(t, "type") + } + + var wg sync.WaitGroup + wg.Add(2) + errCh := make(chan error, 2) + + go func() { + defer wg.Done() + + err := from.pc.WritePacket(pc.packetType, pc.body, rwTimeout) + if err != nil { + err = fmt.Errorf("failed to write packet: %w", err) + } + errCh <- err + + from.send = append(from.send, pc) + }() + + go func() { + defer wg.Done() + + typ, b, err := to.pc.ReadPacket(nil, rwTimeout) + if err != nil { + err = fmt.Errorf("failed to read packet: %w", err) + } + errCh <- err + + to.recv = append(to.recv, packetContent{ + packetType: typ, + body: b, + }) + }() + + wg.Wait() + close(errCh) + + for err := range errCh { + if err != nil { + t.Fatal(err) + } + } +} + +func TestPacketConn(t *testing.T) { + t.Parallel() + + rapid.Check(t, func(t *rapid.T) { + var m packetConnMachine + m.init(t) + defer m.cleanup() + t.Repeat(rapid.StateMachineActions(&m)) + }) +} diff --git a/pkg/rpc/rpc_format_test.go b/pkg/rpc/rpc_format_test.go new file mode 100644 index 000000000..d9ee796d2 --- /dev/null +++ b/pkg/rpc/rpc_format_test.go @@ -0,0 +1,133 @@ +// Copyright 2025 V Kontakte LLC +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package rpc + +import ( + "encoding/hex" + "testing" + + "github.com/VKCOM/tl/pkg/rpc/internal/gen/tl" +) + +func testRPCRequestRoundTrip(t *testing.T, req *Request, mustBeBody string) { + if err := preparePacket(req); err != nil { + t.Error(err) + } + body := append(append([]byte(nil), req.Body[req.extraStart:]...), req.Body[:req.extraStart]...) + req.Body = req.Body[:req.extraStart] + if hex.EncodeToString(body) != mustBeBody { + t.Fatalf("not equal body %x", body) + } + + opts := &ServerOptions{} + hctx := &HandlerContext{ + Request: body, + } + if err := hctx.ParseInvokeReq(opts); err != nil { + t.Error(err) + } + if hctx.QueryID() != req.QueryID() { + t.Fatalf("not equal query id") + } + if hctx.ActorID() != req.ActorID { + t.Fatalf("not equal actor id") + } + if hctx.RequestExtra.String() != req.Extra.String() { + t.Fatalf("not equal extra") + } +} + +func testRPCResponseRoundTrip(t *testing.T, hctx *HandlerContext, sendErr error, mustBeBody string) { + hctx.ResponseExtra.SetBinlogPos(555) // test that extra fields not delivered + hctx.ResponseExtra.SetViewNumber(777) + if err := hctx.prepareResponseBody(sendErr); err != nil { + t.Error(err) + } + body := append(append([]byte(nil), hctx.Response[hctx.extraStart:]...), hctx.Response[:hctx.extraStart]...) + hctx.Response = hctx.Response[:hctx.extraStart] + if hex.EncodeToString(body) != mustBeBody { + t.Fatalf("not equal body %x", body) + } + + var header tl.RpcReqResultHeader + var err error + if body, err = header.ReadTL1(body); err != nil { + t.Error(err) + } + + var extra ResponseExtra + if _, err = parseResponseExtra(hctx.bodyFormatTL2, &extra, body); err != nil && err.Error() != sendErr.Error() { + t.Error(err) + } + if hctx.QueryID() != header.QueryId { + t.Fatalf("not equal query id") + } + if a, b := hctx.ResponseExtra.String(), extra.String(); a != b { + t.Fatalf("not equal extra %s\n%s", a, b) + } +} + +// Simple regression test +func TestRPCFormat(t *testing.T) { + req := &Request{ + Body: []byte{0xaa, 0xbb, 0xcc, 0xdd}, + FunctionName: "memcache.Get", + queryID: 222, + } + testRPCRequestRoundTrip(t, req, "de00000000000000aabbccdd") + testRPCRequestRoundTrip(t, req, "de00000000000000aabbccdd") + req.ActorID = 111 + testRPCRequestRoundTrip(t, req, "de00000000000000bdaa68756f00000000000000aabbccdd") + req.ActorID = 0 + req.Extra.SetCustomTimeoutMs(255) + req.Extra.SetReturnViewNumber(true) + requestExtraFieldsmask := req.Extra.Flags + testRPCRequestRoundTrip(t, req, "de000000000000005e0352e300008008ff000000aabbccdd") + req.ActorID = 111 + testRPCRequestRoundTrip(t, req, "de00000000000000f7aca5f06f0000000000000000008008ff000000aabbccdd") + + hctx := &HandlerContext{ + queryID: 222, + Response: []byte{0xa1, 0xb2, 0xc3, 0xd4}, + handlerContextFields: handlerContextFields{ + requestExtraFieldsmask: requestExtraFieldsmask, + }, + } + testRPCResponseRoundTrip(t, hctx, nil, "de00000000000000e14cc88c0000000800000000000000000903000000000000a1b2c3d4") + testRPCResponseRoundTrip(t, hctx, &Error{Code: 444, Description: "bad"}, "de00000000000000e14cc88c0000000800000000000000000903000000000000f532e47ade00000000000000bc01000003626164") +} + +func TestRPCFormatTL2(t *testing.T) { + req := &Request{ + Body: []byte{0xaa, 0xbb, 0xcc, 0xdd}, + FunctionName: "memcache.Get", + queryID: 222, + BodyFormatTL2: true, + } + testRPCRequestRoundTrip(t, req, "de00000000000000544c3230aabbccdd") + testRPCRequestRoundTrip(t, req, "de00000000000000544c3230aabbccdd") + req.ActorID = 111 + testRPCRequestRoundTrip(t, req, "de00000000000000bdaa68756f00000000000000544c3230aabbccdd") + req.ActorID = 0 + req.Extra.SetCustomTimeoutMs(255) + req.Extra.SetReturnViewNumber(true) + requestExtraFieldsmask := req.Extra.Flags + testRPCRequestRoundTrip(t, req, "de000000000000005e0352e300008008ff000000544c3230aabbccdd") + req.ActorID = 111 + testRPCRequestRoundTrip(t, req, "de00000000000000f7aca5f06f0000000000000000008008ff000000544c3230aabbccdd") + + hctx := &HandlerContext{ + queryID: 222, + Response: []byte{0xa1, 0xb2, 0xc3, 0xd4}, + handlerContextFields: handlerContextFields{ + requestExtraFieldsmask: requestExtraFieldsmask, + bodyFormatTL2: true, + }, + } + testRPCResponseRoundTrip(t, hctx, nil, "de00000000000000e14cc88c0000000800000000000000000903000000000000544c3230a1b2c3d4") + testRPCResponseRoundTrip(t, hctx, &Error{Code: 444, Description: "bad"}, "de00000000000000e14cc88c0000000800000000000000000903000000000000f532e47ade00000000000000bc01000003626164") +} diff --git a/pkg/rpc/rpc_test.go b/pkg/rpc/rpc_test.go new file mode 100644 index 000000000..2cf5a318f --- /dev/null +++ b/pkg/rpc/rpc_test.go @@ -0,0 +1,66 @@ +// Copyright 2025 V Kontakte LLC +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package rpc + +import ( + "context" + "fmt" + "net" + "testing" + + "github.com/VKCOM/tl/pkg/rpc/internal/gen/tltracing" + "github.com/stretchr/testify/assert" +) + +func TestErrorTag(t *testing.T) { + foo := fmt.Errorf("foo") + bar := fmt.Errorf("foo: %w", foo) + baz := &tagError{tag: "baz", err: bar} + bzz := fmt.Errorf("bzz: %w", baz) + tag := ErrorTag(bzz) + assert.Equal(t, tag, "baz:foo") + assert.Empty(t, ErrorTag(nil)) + assert.Empty(t, ErrorTag(&tagError{tag: ""})) + assert.Equal(t, ErrorTag(&tagError{tag: "", err: &net.OpError{Err: context.DeadlineExceeded}}), "timeout") +} + +func TestTraceID(t *testing.T) { + // case by case test is enough here + tEmpty := tltracing.TraceID{} + str := TraceIDToString(tEmpty) + assert.Equal(t, str, "") + t1, err := TraceIDFromString(str) + assert.NoError(t, err) + assert.Equal(t, t1, tEmpty) + + t0 := tltracing.TraceID{ + Lo: 0x0123456789abcdef, + Hi: 0x123456789abcdef0, + } + t0strCanonical := "123456789abcdef00123456789abcdef" + t0strUppercase := "123456789ABCDEF00123456789ABCDEF" + t0strMixed := "123456789ABCDEF00123456789abcdef" + str = TraceIDToString(t0) + assert.Equal(t, str, t0strCanonical) + + t1, err = TraceIDFromString(t0strCanonical) + assert.NoError(t, err) + assert.Equal(t, t1, t0) + t1, err = TraceIDFromString(t0strUppercase) + assert.NoError(t, err) + assert.Equal(t, t1, t0) + t1, err = TraceIDFromString(t0strMixed) + assert.NoError(t, err) + assert.Equal(t, t1, t0) + + t1, err = TraceIDFromString("hren") // wrong length + assert.Error(t, err) + assert.Equal(t, t1, tEmpty) + t1, err = TraceIDFromString("G0123456789ABCDE0123456789ABCDEF") // wrong char + assert.Error(t, err) + assert.Equal(t, t1, tEmpty) +} diff --git a/pkg/rpc/server_longpoll_test.go b/pkg/rpc/server_longpoll_test.go new file mode 100644 index 000000000..5f60ec44d --- /dev/null +++ b/pkg/rpc/server_longpoll_test.go @@ -0,0 +1,680 @@ +// Copyright 2025 V Kontakte LLC +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package rpc + +import ( + "context" + "fmt" + "net" + "strconv" + "sync" + "testing" + "time" + + "pgregory.net/rand" + "pgregory.net/rapid" + + "github.com/VKCOM/tl/internal/vkgo/pkg/basictl" +) + +const emptyBody = "empty response" + +func TestLongpollServer(t *testing.T) { + t.Parallel() + + // this is not really a property-based test, since it is not deterministic + // however, biased integer generators from rapid are very convenient + rapid.Check(t, testLongpollServer) +} + +func testLongpollServer(t *rapid.T) { + if debugPrint { + fmt.Printf("---- TestLongpollServer\n") + } + ln, err := net.Listen("tcp4", "127.0.0.1:") + if err != nil { + t.Fatal(err) + } + + // var clients []*Client + clients := rapid.SliceOf(rapid.Custom(genClient)).Draw(t, "clients") + if len(clients) == 0 { + clients = append(clients, genClient(t)) + } + numRequests := rapid.IntRange(1, 10).Draw(t, "numRequests") + + ts := shutdownTestServer{clients: map[LongpollHandle]int32{}} + s := NewServer( + ServerWithSyncHandler(ts.testShutdownHandler), + ServerWithCryptoKeys(testCryptoKeys), + ServerWithDebugRPC(true), + ServerWithMaxConns(len(clients)), // rapid.IntRange(0, 3).Draw(t, "maxConns") + ServerWithMaxWorkers(rapid.IntRange(-1, 3).Draw(t, "maxWorkers")), + ServerWithConnReadBufSize(rapid.IntRange(0, 64).Draw(t, "connReadBufSize")), + ServerWithConnWriteBufSize(rapid.IntRange(0, 64).Draw(t, "connWriteBufSize")), + ServerWithRequestBufSize(rapid.IntRange(512, 1024).Draw(t, "requestBufSize")), + ServerWithResponseBufSize(rapid.IntRange(512, 1024).Draw(t, "responseBufSize")), + ) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + if err := s.Serve(ln); err != nil { + t.Fatal(err) + } + wg.Done() + }() + + var sendWG sync.WaitGroup + var receiveWG sync.WaitGroup + var cancelMu sync.Mutex + var cancelFuncs []func() + + sendWG.Add(len(clients) * numRequests) + receiveWG.Add(len(clients) * numRequests) + for _, c := range clients { + go func(c Client) { + for j := 0; j < numRequests; j++ { + go func() { + n := rand.New().Int31() + req := c.GetRequest() + req.FailIfNoConnection = true + req.Body = basictl.NatWrite(req.Body, testRequestType) + req.Body = basictl.IntWrite(req.Body, n) + + ctx := context.Background() + if n%2 == 0 { + ctx2, cancel := context.WithCancel(context.Background()) + ctx = ctx2 + cancelMu.Lock() + cancelFuncs = append(cancelFuncs, cancel) + cancelMu.Unlock() + } + sendWG.Done() + resp, _ := c.Do(ctx, "tcp4", ln.Addr().String(), req) + defer c.PutResponse(resp) + receiveWG.Done() + }() + } + }(c) + } + time.Sleep(10 * time.Millisecond) // bad + sendWG.Wait() // everything sent + if debugPrint { + fmt.Printf("everything sent\n") + } + for _, c := range cancelFuncs { + c() + } + ts.sendSomeResponses() + time.Sleep(20 * time.Millisecond) // bad + s.Shutdown() + receiveWG.Wait() + if debugPrint { + fmt.Printf("everything received\n") + } + for _, c := range clients { + _ = c.Close() + } + wg.Wait() + + err = s.Close() + if err != nil { + t.Fatal(err) + } + if len(ts.clients) != 0 { + t.Fatal("long poll contexts did not clear in server") + } +} + +type longpollTestServer struct { + mu sync.Mutex + handleToId map[LongpollHandle]int + idToHandle map[int]LongpollHandle + cancellationsCount int + emptyResponsesCount int + handlerCallback func() + cancellationCallback func() + emptyResponseCallback func() + customTimeout time.Duration +} + +type longpollTestServerOption func(lts *longpollTestServer) + +func withCustomTimeout(timeout time.Duration) longpollTestServerOption { + return func(lts *longpollTestServer) { + lts.customTimeout = timeout + } +} + +func newLongpollTestServer(opts ...longpollTestServerOption) *longpollTestServer { + lts := &longpollTestServer{ + handleToId: map[LongpollHandle]int{}, + idToHandle: map[int]LongpollHandle{}, + } + + for _, opt := range opts { + opt(lts) + } + + return lts +} + +func (s *longpollTestServer) CancelLongpoll(lh LongpollHandle) { + s.mu.Lock() + id, exists := s.handleToId[lh] + if !exists { + s.mu.Unlock() + return + } + s.cancellationsCount++ + delete(s.handleToId, lh) + delete(s.idToHandle, id) + s.mu.Unlock() + if s.cancellationCallback != nil { + s.cancellationCallback() + } +} + +func (s *longpollTestServer) WriteEmptyResponse(lh LongpollHandle, resp *HandlerContext) error { + s.mu.Lock() + id, exists := s.handleToId[lh] + if !exists { + s.mu.Unlock() + return ErrLongpollNoEmptyResponse + } + s.emptyResponsesCount++ + delete(s.handleToId, lh) + delete(s.idToHandle, id) + s.mu.Unlock() + + resp.Response = basictl.StringWrite(resp.Response, emptyBody) + if s.emptyResponseCallback != nil { + s.emptyResponseCallback() + } + return nil +} + +func (s *longpollTestServer) WakeUpLongpoll(id int) { + s.mu.Lock() + defer s.mu.Unlock() + + lh, exists := s.idToHandle[id] + if !exists { + return + } + + delete(s.idToHandle, id) + delete(s.handleToId, lh) + + hctx, _ := lh.FinishLongpoll() + if hctx == nil { + return + } + + // Don't store int int tl here, because then I actually need maybe instead of string + hctx.Response = basictl.StringWrite(hctx.Response, strconv.FormatInt(int64(id), 10)) + hctx.SendLongpollResponse(nil) +} + +func (s *longpollTestServer) Handler(_ context.Context, hctx *HandlerContext) (err error) { + if s.handlerCallback != nil { + defer s.handlerCallback() + } + + if hctx.Request, err = basictl.NatReadExactTag(hctx.Request, testRequestType); err != nil { + return err + } + var id int32 + if hctx.Request, err = basictl.IntRead(hctx.Request, &id); err != nil { + return err + } + + var lh LongpollHandle + if s.customTimeout != 0 { + lh, err = hctx.StartLongpollWithTimeoutDeprecated(s, s.customTimeout) + } else { + lh, err = hctx.StartLongpoll(s) + } + if err != nil { + return err + } + + s.mu.Lock() + defer s.mu.Unlock() + + s.handleToId[lh] = int(id) + s.idToHandle[int(id)] = lh + + return nil +} + +func (s *longpollTestServer) CancellationsCount() int { + s.mu.Lock() + defer s.mu.Unlock() + + return s.cancellationsCount +} + +func (s *longpollTestServer) EmptyResponsesCount() int { + s.mu.Lock() + defer s.mu.Unlock() + + return s.emptyResponsesCount +} + +func TestLongpollTimeout(t *testing.T) { + t.Run("empty body is sent after timeout", func(t *testing.T) { + ts := newLongpollTestServer() + s := NewServer( + ServerWithSyncHandler(ts.Handler), + ServerWithCryptoKeys(testCryptoKeys), + ServerWithDebugRPC(true), + ServerWithMinimumLongpollTimeout(500*time.Millisecond), + ) + + ln, err := net.Listen("tcp4", "127.0.0.1:") + if err != nil { + t.Fatal(err) + } + + errCh := make(chan error) + go func() { + errCh <- s.Serve(ln) + }() + + c := NewClient() + req := c.GetRequest() + req.FailIfNoConnection = true + req.Body = basictl.NatWrite(req.Body, testRequestType) + + n := rand.New().Int31() + req.Body = basictl.IntWrite(req.Body, n) + + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + resp, err := c.Do(ctx, "tcp4", ln.Addr().String(), req) + if err != nil { + t.Fatalf("unexpected err in client do: %s", err.Error()) + } + + var body string + if _, err := basictl.StringRead(resp.Body, &body); err != nil { + t.Fatalf("unexpected err in string read: %s", err.Error()) + } + + if body != emptyBody { + t.Fatalf("unexpected body in response: %s", body) + } + + s.Shutdown() + if err := s.Close(); err != nil { + t.Fatalf("unexpected err in close: %s", err.Error()) + } + + err = <-errCh + if err != nil { + t.Fatalf("unexpected error in serve: %s", err.Error()) + } + + if ts.cancellationsCount != 0 { + t.Fatalf("unexpected cancel in sever: %d", ts.cancellationsCount) + } + }) + + t.Run("cancel during waiting", func(t *testing.T) { + handleChan := make(chan struct{}, 1) + cancelChan := make(chan struct{}, 1) + ts := newLongpollTestServer() + ts.handlerCallback = func() { + handleChan <- struct{}{} + } + ts.cancellationCallback = func() { + cancelChan <- struct{}{} + } + + s := NewServer( + ServerWithSyncHandler(ts.Handler), + ServerWithCryptoKeys(testCryptoKeys), + ServerWithDebugRPC(true), + ) + + ln, err := net.Listen("tcp4", "127.0.0.1:") + if err != nil { + t.Fatal(err) + } + + errCh := make(chan error) + go func() { + errCh <- s.Serve(ln) + }() + + c := NewClient() + req := c.GetRequest() + req.FailIfNoConnection = true + req.Body = basictl.NatWrite(req.Body, testRequestType) + + n := rand.New().Int31() + req.Body = basictl.IntWrite(req.Body, n) + + clientErrCh := make(chan error) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Hour) + go func() { + _, err = c.Do(ctx, "tcp4", ln.Addr().String(), req) + clientErrCh <- err + }() + // Make sure that request was received by the server + <-handleChan + cancel() + if err := <-clientErrCh; err != context.Canceled { + t.Fatalf("unexpected err in client do: %s", err.Error()) + } + // Make sure that cancel was called by the server + <-cancelChan + if err := s.Close(); err != nil { + t.Fatalf("unexpected err in close: %s", err.Error()) + } + + err = <-errCh + if err != nil { + t.Fatalf("unexpected error in serve: %s", err.Error()) + } + + if ts.CancellationsCount() != 1 { + t.Fatalf("expected exactly one cancel: %d", ts.CancellationsCount()) + } + if s.longpollTree.Size() != 0 { + t.Fatalf("longpoll tree expected to be empty") + } + }) + + t.Run("shutdown during waiting", func(t *testing.T) { + handleChan := make(chan struct{}, 1) + emptyResponseChan := make(chan struct{}, 1) + ts := newLongpollTestServer() + ts.handlerCallback = func() { + handleChan <- struct{}{} + } + ts.emptyResponseCallback = func() { + emptyResponseChan <- struct{}{} + } + + s := NewServer( + ServerWithSyncHandler(ts.Handler), + ServerWithCryptoKeys(testCryptoKeys), + ServerWithDebugRPC(true), + ) + + ln, err := net.Listen("tcp4", "127.0.0.1:") + if err != nil { + t.Fatal(err) + } + + errCh := make(chan error) + go func() { + errCh <- s.Serve(ln) + }() + + c := NewClient() + req := c.GetRequest() + req.FailIfNoConnection = true + req.Body = basictl.NatWrite(req.Body, testRequestType) + + n := rand.New().Int31() + req.Body = basictl.IntWrite(req.Body, n) + + clientErrCh := make(chan error) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Hour) + defer cancel() + go func() { + _, err = c.Do(ctx, "tcp4", ln.Addr().String(), req) + clientErrCh <- err + }() + // Make sure that request was received by the server + <-handleChan + s.Shutdown() + if err := <-clientErrCh; err != nil { + t.Fatalf("unexpected err in client do: %v", err) + } + // Make sure that cancel was called by the server + <-emptyResponseChan + if err := s.Close(); err != nil { + t.Fatalf("unexpected err in close: %v", err) + } + + err = <-errCh + if err != nil { + t.Fatalf("unexpected error in serve: %s", err.Error()) + } + + if ts.EmptyResponsesCount() != 1 { + t.Fatalf("expected exactly one empty response: %d", ts.EmptyResponsesCount()) + } + if s.longpollTree.Size() != 0 { + t.Fatalf("longpoll tree expected to be empty") + } + }) + + t.Run( + "StartLongpollWithTimeoutDeprecated rewrites rpc.Request extra timeout", + func(t *testing.T) { + ts := newLongpollTestServer(withCustomTimeout(time.Millisecond)) + s := NewServer( + ServerWithSyncHandler(ts.Handler), + ServerWithCryptoKeys(testCryptoKeys), + ServerWithDebugRPC(true), + ) + + ln, err := net.Listen("tcp4", "127.0.0.1:") + if err != nil { + t.Fatal(err) + } + + errCh := make(chan error) + go func() { + errCh <- s.Serve(ln) + }() + defer s.Shutdown() + + c := NewClient() + req := c.GetRequest() + req.FailIfNoConnection = true + req.Body = basictl.NatWrite(req.Body, testRequestType) + + n := rand.New().Int31() + req.Body = basictl.IntWrite(req.Body, n) + + // Set really big timeout here + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Hour) + defer cancel() + + // If custom timeout works then this req will take 1 ms + resp, err := c.Do(ctx, "tcp4", ln.Addr().String(), req) + if err != nil { + t.Fatal(err) + } + + var body string + if _, err := basictl.StringRead(resp.Body, &body); err != nil { + t.Fatalf("unexpected err in string read: %s", err.Error()) + } + + if body != emptyBody { + t.Fatalf("unexpected body in response: %s", body) + } + + s.Shutdown() + if err := <-errCh; err != nil { + t.Fatalf("serve returned an error: %s", err.Error()) + } + }, + ) +} + +type client struct { + id int32 + ctx context.Context + cancel context.CancelFunc + readyCh chan struct { + response *Response + err error + } + deadline time.Time +} + +type longpollServerStateMachine struct { + mode string // udp or tcp + clients []client // waiting longpolls + server *longpollTestServer + rpcServer *Server + rpcServerErr chan error + addr string + emptyResponses int + nextId int32 +} + +func NewLongpollServerStateMachine(t *rapid.T) *longpollServerStateMachine { + mode := "tcp4" + // TODO: там отдельный метод для udp + // if rapid.Bool().Draw(t, "udp") { + // mode = "udp" + // } + + longpollServer := newLongpollTestServer() + s := NewServer( + ServerWithSyncHandler(longpollServer.Handler), + ServerWithCryptoKeys(testCryptoKeys), + // ServerWithDebugRPC(true), + ) + ln, err := net.Listen(mode, "127.0.0.1:") + if err != nil { + t.Fatal(err) + } + + rpcServerErr := make(chan error) + go func() { + rpcServerErr <- s.Serve(ln) + }() + + return &longpollServerStateMachine{ + mode: mode, + server: longpollServer, + rpcServer: s, + rpcServerErr: rpcServerErr, + addr: ln.Addr().String(), + } +} + +func (lssm *longpollServerStateMachine) SendLongpoll(t *rapid.T) { + var opts []ClientOptionsFunc + if lssm.mode == "udp" { + opts = append(opts, ClientWithExperimentalLocalUDPAddress("127.0.0.1:")) + } + + c := NewClient(opts...) + req := c.GetRequest() + req.FailIfNoConnection = true + + req.Body = basictl.NatWrite(req.Body, testRequestType) + + id := lssm.nextId + lssm.nextId++ + + req.Body = basictl.IntWrite(req.Body, id) + + timeout := rapid.IntRange(100, 200).Draw(t, "requestTimeout") + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second) + readyCh := make(chan struct { + response *Response + err error + }, 1) + + lssm.clients = append(lssm.clients, client{ + id: id, + ctx: ctx, + cancel: cancel, + readyCh: readyCh, + deadline: time.Now().Add(time.Duration(timeout) * time.Millisecond), + }) + + go func() { + resp, err := c.Do(ctx, lssm.mode, lssm.addr, req) + readyCh <- struct { + response *Response + err error + }{ + response: resp, + err: err, + } + }() +} + +// func (lssm *longpollServerStateMachine) CancelLongpoll(t *rapid.T) { + +// } + +// func (lssm *longpollServerStateMachine) WaitEmptyResponse(t *rapid.T) { + +// } + +func (lssm *longpollServerStateMachine) WakeUpLongpoll(t *rapid.T) { + if len(lssm.clients) == 0 { + return + } + + // wake up first client + client := lssm.clients[0] + lssm.clients = lssm.clients[1:] + lssm.server.WakeUpLongpoll(int(client.id)) + + respInfo := <-client.readyCh + if respInfo.err != nil { + t.Fatalf("unexpected err in wake up: %s", respInfo.err.Error()) + } + + var body string + if _, err := basictl.StringRead(respInfo.response.Body, &body); err != nil { + t.Fatalf("unexpected err in string read: %s", err.Error()) + } + + if body == emptyBody { + lssm.emptyResponses++ + return + } + + parsedId, err := strconv.Atoi(body) + if err != nil { + t.Fatalf("can't parse id from body: %s", err.Error()) + } + + if parsedId != int(client.id) { + t.Fatalf("unexpected id in resp: %d", parsedId) + } +} + +func (lssm *longpollServerStateMachine) Check(t *rapid.T) { + +} + +// func TestLongpollProertyBased(t *testing.T) { +// t.Parallel() + +// Now there is a race condition between the moment when longpoll arrives +// and the moment when the longpoll is woken up +// so this test isn't working, should working once the go version is >= 1.24 +// because https://pkg.go.dev/testing/synctest#Test could be used +// rapid.Check(t, func(t *rapid.T) { +// sm := NewLongpollServerStateMachine(t) +// t.Repeat(rapid.StateMachineActions(sm)) +// sm.rpcServer.Close() +// <-sm.rpcServerErr +// // TODO: в конце попритить пустые запросы и проверить, что нету тех кому ответ не пришел +// }) +// } diff --git a/pkg/rpc/server_shutdown_test.go b/pkg/rpc/server_shutdown_test.go new file mode 100644 index 000000000..857a1e6f7 --- /dev/null +++ b/pkg/rpc/server_shutdown_test.go @@ -0,0 +1,194 @@ +// Copyright 2025 V Kontakte LLC +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package rpc + +import ( + "context" + "net" + "sync" + "testing" + "time" + + "pgregory.net/rand" + "pgregory.net/rapid" + + "github.com/VKCOM/tl/internal/vkgo/pkg/basictl" +) + +type shutdownTestServer struct { + mu sync.Mutex + clients map[LongpollHandle]int32 +} + +func (s *shutdownTestServer) CancelLongpoll(lh LongpollHandle) { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.clients, lh) +} + +func (s *shutdownTestServer) WriteEmptyResponse(lh LongpollHandle, hctx *HandlerContext) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.clients, lh) + + return ErrLongpollNoEmptyResponse +} + +func (s *shutdownTestServer) sendSomeResponses() { + s.mu.Lock() + counter := 0 + for lh, n := range s.clients { + delete(s.clients, lh) + hctx, _ := lh.FinishLongpoll() + if hctx == nil { + continue + } + + hctx.Response = basictl.IntWrite(hctx.Response, n) + hctx.SendLongpollResponse(nil) + counter++ + if counter >= len(s.clients) { + break // approx. half sent + } + } + s.mu.Unlock() +} + +func (s *shutdownTestServer) testShutdownHandler(_ context.Context, hctx *HandlerContext) (err error) { + if hctx.Request, err = basictl.NatReadExactTag(hctx.Request, testRequestType); err != nil { + return err + } + var n int32 + if hctx.Request, err = basictl.IntRead(hctx.Request, &n); err != nil { + return err + } + + hctx.Response = basictl.IntWrite(hctx.Response, n) + if (n/2)%2 == 0 { + s.mu.Lock() + defer s.mu.Unlock() + lctx, err := hctx.StartLongpoll(s) + if err != nil { + return err + } + s.clients[lctx] = n + return nil + } + return nil +} + +func TestShutdownClient(t *testing.T) { + t.Parallel() + + // this is not really a property-based test, since it is not deterministic + // however, biased integer generators from rapid are very convenient + rapid.Check(t, testShutdownClient) +} + +func testShutdownClient(t *rapid.T) { + ln, err := net.Listen("tcp4", "127.0.0.1:") + if err != nil { + t.Fatal(err) + } + + clients := rapid.SliceOf(rapid.Custom(genClient)).Draw(t, "clients") + if len(clients) == 0 { + clients = append(clients, genClient(t)) + } + clients = clients[:1] + numRequests := rapid.IntRange(1, 10).Draw(t, "numRequests") + + ts := shutdownTestServer{clients: map[LongpollHandle]int32{}} + s := NewServer( + ServerWithSyncHandler(ts.testShutdownHandler), + ServerWithCryptoKeys(testCryptoKeys), + ServerWithMaxConns(rapid.IntRange(0, 3).Draw(t, "maxConns")), + ServerWithMaxWorkers(rapid.IntRange(-1, 3).Draw(t, "maxWorkers")), + ServerWithConnReadBufSize(rapid.IntRange(0, 64).Draw(t, "connReadBufSize")), + ServerWithConnWriteBufSize(rapid.IntRange(0, 64).Draw(t, "connWriteBufSize")), + ServerWithRequestBufSize(rapid.IntRange(512, 1024).Draw(t, "requestBufSize")), + ServerWithResponseBufSize(rapid.IntRange(512, 1024).Draw(t, "responseBufSize")), + ) + + serverErrChan := make(chan error, 1) + var wg sync.WaitGroup + wg.Add(1) + go func() { + serverErr := s.Serve(ln) + serverErrChan <- serverErr + wg.Done() + }() + + for _, c := range clients { + wg.Add(1) + go func(c Client) { + defer wg.Done() + + var cwg sync.WaitGroup + for j := 0; j < numRequests; j++ { + cwg.Add(1) + go func(j int) { + defer cwg.Done() + + n := rand.New().Int31() + req := c.GetRequest() + req.Body = basictl.NatWrite(req.Body, testRequestType) + req.Body = basictl.IntWrite(req.Body, n) + + resp, _ := c.Do(context.Background(), "tcp4", ln.Addr().String(), req) + defer c.PutResponse(resp) + }(j) + } + + cwg.Wait() + }(c) + } + + time.Sleep(10 * time.Millisecond) // bad + ts.sendSomeResponses() + time.Sleep(10 * time.Millisecond) // bad + for _, c := range clients { + _ = c.Close() + } + time.Sleep(10 * time.Millisecond) // bad + s.Shutdown() + wg.Wait() + + err = s.Close() + if err != nil { + t.Fatal(err) + } + err = <-serverErrChan + if err != nil { + t.Fatal(err) + } + if len(ts.clients) != 0 { + t.Fatal("long poll contexts did not clear in server") + } +} + +func TestPreferTLVersionFunc(t *testing.T) { + ret3 := func(ctx context.Context, req *Request, defaultTLVersion int) int { + return 3 + } + ctx := WithPreferTLVersionFunc(context.Background(), ret3) + v := 0 + if f := GetPreferTLVersionFunc(ctx); f != nil { + v = f(context.Background(), nil, v) + } + if v != 3 { + t.Fatal("preferTLVersionFunc returned unexpected value") + } + ctx = context.Background() + v = 0 + if f := GetPreferTLVersionFunc(ctx); f != nil { + v = f(context.Background(), nil, v) + } + if v != 0 { + t.Fatal("preferTLVersionFunc returned unexpected value") + } +} diff --git a/pkg/rpc/udp_server_test.go b/pkg/rpc/udp_server_test.go new file mode 100644 index 000000000..29f9cfcbc --- /dev/null +++ b/pkg/rpc/udp_server_test.go @@ -0,0 +1,47 @@ +// Copyright 2025 V Kontakte LLC +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package rpc + +import "testing" + +func TestUDPRequestBufPoolDropsLargeBuffers(t *testing.T) { + t.Skip("flaky test, see https://teamcity.vkteam.ru/buildConfiguration/Backend_Vkgo_Tests_CheckBase/14028168?buildTab=log&linesState=624&logView=flowAware&focusLine=22473") + + const ( + requestBufSize = 1024 + responseBufSize = 2048 + ) + + s := NewServer( + ServerWithRequestBufSize(requestBufSize), + ServerWithResponseBufSize(responseBufSize), + ) + + pooled := s.allocateRequestBufUDP(responseBufSize) + pooledCap := cap(*pooled) + if pooledCap < responseBufSize { + t.Fatalf("pooled buffer cap = %d, want at least %d", pooledCap, responseBufSize) + } + s.deAllocateRequestBufUDP(pooled) + + reused := s.allocateRequestBufUDP(requestBufSize) + if cap(*reused) != pooledCap { + t.Fatalf("reused buffer cap = %d, want %d", cap(*reused), pooledCap) + } + s.deAllocateRequestBufUDP(reused) + + large := s.allocateRequestBufUDP(responseBufSize + 1) + if cap(*large) <= responseBufSize { + t.Fatalf("large buffer cap = %d, want greater than %d", cap(*large), responseBufSize) + } + s.deAllocateRequestBufUDP(large) + + afterLarge := s.allocateRequestBufUDP(requestBufSize) + if cap(*afterLarge) > responseBufSize { + t.Fatalf("buffer pool retained large buffer cap = %d, limit %d", cap(*afterLarge), responseBufSize) + } +}