From a28f58bb1e9b710a8b4010617bc2ee5ec762e3f1 Mon Sep 17 00:00:00 2001 From: Sophia <193290223+soap-phia@users.noreply.github.com> Date: Mon, 18 May 2026 14:46:58 +0000 Subject: [PATCH 1/2] lowk --- .gitignore | 2 +- Dockerfile | 24 +++ build.sh | 48 +++--- example.config.json | 48 +++++- go.mod | 5 +- go.sum | 6 +- main.go | 193 +++++------------------ package.json | 24 +-- pnpm-lock.yaml | 94 +++++++++++ src/index.ts | 268 +++++++++++++++++++++++++++++++- src/logger.ts | 40 +++++ src/path.ts | 16 +- src/server/index.ts | 243 ----------------------------- src/types.d.ts | 71 --------- wisp/config.go | 178 +++++++++++++++++++++ wisp/dnscache.go | 151 +++++++++++++++--- wisp/logger.go | 58 +++++++ wisp/protection.go | 173 +++++++++++++++++++++ wisp/protection/banlist.go | 95 +++++++++++ wisp/protection/iputil.go | 150 ++++++++++++++++++ wisp/protection/limits.go | 177 +++++++++++++++++++++ wisp/protection/streamlimits.go | 38 +++++ wisp/twisp.go | 12 +- wisp/v2.go | 76 +++------ wisp/windows.go | 62 ++++++++ wisp/wisp-connection.go | 178 ++++++++++++++++++--- wisp/wisp-stream.go | 112 ++++++++----- wisp/wisp.go | 152 +++++++++--------- wisp/wsreader.go | 37 ++++- 29 files changed, 2005 insertions(+), 726 deletions(-) create mode 100644 Dockerfile mode change 100644 => 100755 build.sh create mode 100644 pnpm-lock.yaml create mode 100644 src/logger.ts delete mode 100644 src/server/index.ts delete mode 100644 src/types.d.ts create mode 100644 wisp/config.go create mode 100644 wisp/logger.go create mode 100644 wisp/protection.go create mode 100644 wisp/protection/banlist.go create mode 100644 wisp/protection/iputil.go create mode 100644 wisp/protection/limits.go create mode 100644 wisp/protection/streamlimits.go create mode 100644 wisp/windows.go diff --git a/.gitignore b/.gitignore index 0c01eb2..8ff7b72 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ dist node_modules config.json -mrrowisp +mrrowisp* \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..84be1de --- /dev/null +++ b/Dockerfile @@ -0,0 +1,24 @@ +FROM golang:1.21.7-alpine AS builder + +WORKDIR /app + +COPY go.mod go.sum ./ +RUN go mod download + +COPY . . + +RUN CGO_ENABLED=0 GOOS=linux go build -o mrrowisp main.go + +FROM alpine:3.19 + +RUN apk --no-cache add ca-certificates && addgroup -S mrrowisp && adduser -S mrrowisp -G mrrowisp + +WORKDIR /app + +COPY --from=builder /app/mrrowisp . + +USER mrrowisp + +EXPOSE 6001 + +CMD ["./mrrowisp", "-config", "config.json"] diff --git a/build.sh b/build.sh old mode 100644 new mode 100755 index 1c56f50..fc39ced --- a/build.sh +++ b/build.sh @@ -1,26 +1,38 @@ #!/bin/bash -mkdir -p dist +rm -rf bin +mkdir -p bin if [ ! -f "main.go" ]; then - echo "main.go not found." - exit 1 + echo "main.go not found." + exit 1 fi -for os in linux darwin; do - if [ "$os" = "win32" ]; then - goos="windows" - else - goos=$os - fi - for arch in x64 arm64; do - if [ "$arch" = "x64" ]; then - goarch="amd64" - else - goarch=$arch - fi - GOOS=$goos GOARCH=$goarch go build -o ./dist/${os}-${arch}/mrrowisp main.go - done +for os in linux darwin win32; do + if [ "$os" = "win32" ]; then + goos="windows" + else + goos=$os + fi + for arch in x64 arm64; do + if [ "$arch" = "x64" ]; then + goarch="amd64" + else + goarch=$arch + fi + if [ "$os" = "win32" ]; then + ext=".exe" + else + ext="" + fi + mkdir -p bin/${os}-${arch} + GOOS=$goos GOARCH=$goarch go build -o ./bin/${os}-${arch}/mrrowisp${ext} main.go + done done -echo "Finished building. Binaries in ./dist/PLATFORM-ARCH/mrrowisp" \ No newline at end of file +echo "Finished building to ./bin/" + +cp package.json README.md LICENSE dist/ +cp example.config.json dist/config.json + +echo "Finished copying package files to ./dist/" diff --git a/example.config.json b/example.config.json index eb3d3cc..97128f1 100644 --- a/example.config.json +++ b/example.config.json @@ -1,19 +1,30 @@ { "port": 6001, - "disableUDP": false, + "allowTCP": true, + "allowUDP": true, + "allowDirectIP": false, + "allowPrivateIPs": false, + "allowLoopbackIPs": false, "tcpBufferSize": 65535, "bufferRemainingLength": 1024, "tcpNoDelay": true, "websocketTcpNoDelay": true, + "streamLimitPerHost": 64, + "streamLimitTotal": 2048, "blacklist": { - "hostnames": [] + "hostnames": [], + "ports": [] }, "whitelist": { - "hostnames": [] + "hostnames": [], + "ports": [] }, "proxy": "", "websocketPermessageDeflate": false, - "dnsServer": [], + "dnsServers": [], + "dnsTTLSeconds": 120, + "dnsMethod": "lookup", + "dnsResultOrder": "verbatim", "enableTwisp": false, "enableV2": true, "motd": "", @@ -23,5 +34,30 @@ "certAuth": false, "certAuthRequired": false, "certAuthPublicKeys": [], - "enableStreamConfirm": false -} \ No newline at end of file + "enableStreamConfirm": false, + "maxConnectsPerSecond": 5, + "bandwidthLimitKbps": 5000, + "connectionsLimitPerIP": 20, + "connectionWindowSeconds": 10, + "parseRealIP": true, + "parseRealIPFrom": ["127.0.0.1"], + "maxMessageSize": 0, + "staticDir": "", + "nonWSResponse": "mrrow merp >w<", + "allowedOrigins": [], + "writeTimeoutSeconds": 15, + "frameReadTimeoutSeconds": 30, + "logLevel": "info", + "banEnabled": true, + "banDurationSeconds": 3600, + "banMaxStrikes": 5, + "banEscalationMultiplier": 0, + "maxHandshakeFailures": 10, + "maxPacketRate": 500, + "maxConnectionLifetimeSeconds": 0, + "maxStreamsPerConnection": 0, + "maxConnectionsPerIP": 0, + "globalMaxConnections": 0, + "writeQueueSize": 4096, + "maxInboundBytesPerSecond": 0 +} diff --git a/go.mod b/go.mod index c52691a..9de587f 100644 --- a/go.mod +++ b/go.mod @@ -1,11 +1,12 @@ module mrrowisp -go 1.21 +go 1.25.0 require ( github.com/creack/pty v1.1.21 github.com/lxzan/gws v1.8.3 - golang.org/x/net v0.24.0 + golang.org/x/net v0.53.0 + golang.org/x/sync v0.9.0 ) require ( diff --git a/go.sum b/go.sum index 071bc24..7342238 100644 --- a/go.sum +++ b/go.sum @@ -12,7 +12,9 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w= -golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8= +golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA= +golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs= +golang.org/x/sync v0.9.0 h1:fEo0HyrW1GIgZdpbhCRO0PkJajUS5H9IFUztCgEo2jQ= +golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/main.go b/main.go index eff978f..31d5e41 100644 --- a/main.go +++ b/main.go @@ -1,187 +1,76 @@ package main import ( - "crypto/ed25519" - "encoding/hex" - "encoding/json" + "context" "flag" "fmt" "net/http" "os" - "strings" + "os/signal" + "syscall" + "time" "mrrowisp/wisp" ) -type Config struct { - Port int `json:"port"` - DisableUDP bool `json:"disableUDP"` - TcpBufferSize int `json:"tcpBufferSize"` - BufferRemainingLength uint32 `json:"bufferRemainingLength"` - TcpNoDelay bool `json:"tcpNoDelay"` - WebsocketTcpNoDelay bool `json:"websocketTcpNoDelay"` - - Blacklist struct { - Hostnames []string `json:"hostnames"` - } `json:"blacklist"` - Whitelist struct { - Hostnames []string `json:"hostnames"` - } `json:"whitelist"` - - Proxy string `json:"proxy"` - WebsocketPermessageDeflate bool `json:"websocketPermessageDeflate"` - DnsServers []string `json:"dnsServers"` - - EnableTwisp bool `json:"enableTwisp"` - - EnableV2 bool `json:"enableV2"` - Motd string `json:"motd"` - PasswordAuth bool `json:"passwordAuth"` - PasswordAuthRequired bool `json:"passwordAuthRequired"` - PasswordUsers map[string]string `json:"passwordUsers"` - CertAuth bool `json:"certAuth"` - CertAuthRequired bool `json:"certAuthRequired"` - CertAuthPublicKeys []string `json:"certAuthPublicKeys"` - EnableStreamConfirm bool `json:"enableStreamConfirm"` -} - -func defaultConfig() Config { - return Config{ - Port: 6001, - DisableUDP: false, - TcpBufferSize: 32768, - BufferRemainingLength: 65536, - TcpNoDelay: true, - WebsocketTcpNoDelay: true, - WebsocketPermessageDeflate: false, - EnableTwisp: false, - EnableV2: false, - PasswordAuth: false, - PasswordAuthRequired: false, - PasswordUsers: make(map[string]string), - CertAuth: false, - CertAuthRequired: false, - EnableStreamConfirm: false, - } -} - -func loadConfig(config string) (Config, error) { - cfg := defaultConfig() - - trimConfig := strings.TrimSpace(config) - if strings.HasPrefix(trimConfig, "{") { - if err := json.Unmarshal([]byte(trimConfig), &cfg); err != nil { - return cfg, err - } - return cfg, nil - } - - file, err := os.Open(config) - if err != nil { - return cfg, err - } - defer file.Close() - - decoder := json.NewDecoder(file) - if err := decoder.Decode(&cfg); err != nil { - return cfg, err - } - return cfg, nil -} - -func createWispConfig(cfg Config) *wisp.Config { - blacklistedHostnames := make(map[string]struct{}) - for _, host := range cfg.Blacklist.Hostnames { - blacklistedHostnames[host] = struct{}{} - } - - whitelistedHostnames := make(map[string]struct{}) - for _, host := range cfg.Whitelist.Hostnames { - whitelistedHostnames[host] = struct{}{} - } - - var pubKeys []ed25519.PublicKey - for _, hexKey := range cfg.CertAuthPublicKeys { - hexKeyBytes, err := hex.DecodeString(hexKey) - if err != nil { - fmt.Printf("warning: invalid public key hex %q: %v\n", hexKey, err) - continue - } - if len(hexKeyBytes) != ed25519.PublicKeySize { - fmt.Printf("warning: public key %q has invalid length %d (expected %d)\n", hexKey, len(hexKeyBytes), ed25519.PublicKeySize) - continue - } - pubKeys = append(pubKeys, ed25519.PublicKey(hexKeyBytes)) - } - - wispCfg := &wisp.Config{ - DisableUDP: cfg.DisableUDP, - TcpBufferSize: cfg.TcpBufferSize, - BufferRemainingLength: cfg.BufferRemainingLength, - TcpNoDelay: cfg.TcpNoDelay, - WebsocketTcpNoDelay: cfg.WebsocketTcpNoDelay, - Blacklist: struct { - Hostnames map[string]struct{} - }{ - Hostnames: blacklistedHostnames, - }, - Whitelist: struct { - Hostnames map[string]struct{} - }{ - Hostnames: whitelistedHostnames, - }, - Proxy: cfg.Proxy, - WebsocketPermessageDeflate: cfg.WebsocketPermessageDeflate, - DnsServers: cfg.DnsServers, - EnableTwisp: cfg.EnableTwisp, - EnableV2: cfg.EnableV2, - Motd: cfg.Motd, - PasswordAuth: cfg.PasswordAuth, - PasswordAuthRequired: cfg.PasswordAuthRequired, - PasswordUsers: cfg.PasswordUsers, - CertAuth: cfg.CertAuth, - CertAuthRequired: cfg.CertAuthRequired, - CertAuthPublicKeys: pubKeys, - EnableStreamConfirm: cfg.EnableStreamConfirm, - } - - if wispCfg.PasswordUsers == nil { - wispCfg.PasswordUsers = make(map[string]string) - } - - return wispCfg -} - func main() { fConfig := flag.String("config", "", "config to load (file or json string)") fPort := flag.Int("port", 0, "port to run on") + fAllowLoopbackIPs := flag.Bool("allow-loopback", false, "allow loopback IP targets") flag.Parse() - var cfg Config + var cfg wisp.Config var err error if *fConfig != "" { - cfg, err = loadConfig(*fConfig) + cfg, err = wisp.LoadConfig(*fConfig) if err != nil { fmt.Printf("Failed to load config: %v\n", err) return } } else { - cfg = defaultConfig() + cfg = wisp.DefaultConfig() } if *fPort != 0 { cfg.Port = *fPort } + if *fAllowLoopbackIPs != false { + cfg.AllowLoopbackIPs = *fAllowLoopbackIPs + } - wispConfig := createWispConfig(cfg) + wispConfig := wisp.CreateWispConfig(cfg) wispHandler := wisp.CreateWispHandler(wispConfig) - http.HandleFunc("/", wispHandler) - fmt.Printf("Starting Mrrowisp on port %d. . .", cfg.Port) - err = http.ListenAndServe(fmt.Sprintf(":%d", cfg.Port), nil) - if err != nil { - fmt.Printf("Failed to start Mrrowisp: %v", err) + if cfg.StaticDir != "" { + http.Handle("/", http.FileServer(http.Dir(cfg.StaticDir))) + http.HandleFunc("/wisp", wispHandler) + } else { + http.HandleFunc("/", wispHandler) + } + fmt.Printf("[INFO] Starting Mrrowisp on port %d. . .\n", cfg.Port) + server := &http.Server{ + Addr: fmt.Sprintf(":%d", cfg.Port), + ReadHeaderTimeout: 5 * time.Second, + IdleTimeout: 120 * time.Second, + } + + sigch := make(chan os.Signal, 1) + signal.Notify(sigch, syscall.SIGINT, syscall.SIGTERM) + + go func() { + sig := <-sigch + fmt.Printf("[INFO] Shutting down (signal: %s)\n", sig.String()) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if shutdownErr := server.Shutdown(ctx); shutdownErr != nil { + fmt.Printf("[INFO] Shutdown error: %v\n", shutdownErr) + } + }() + + err = server.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + fmt.Printf("[INFO] Failed to start Mrrowisp: %v", err) } } diff --git a/package.json b/package.json index 2dd0f35..379a210 100644 --- a/package.json +++ b/package.json @@ -1,25 +1,31 @@ { "name": "mrrowisp", - "version": "1.2.2", + "repository": { + "type": "git", + "url": "https://github.com/soap-phia/mrrowisp" + }, + "version": "2.10.0", "module": "index.ts", "type": "module", "author": "soap-phia", "license": "BSD-3-Clause", "private": false, "main": "dist/index.js", + "types": "dist/index.d.ts", "devDependencies": { - "@types/bun": "^1.3.11" + "@types/bun": "^1.3.11", + "@types/node": "^25.5.0", + "@types/ws": "^8.18.1" }, "peerDependencies": { - "typescript": "^5" + "typescript": "^6.0.3", + "chalk": "^5.6.2" }, "dependencies": { - "@types/node": "^25.5.0", - "chalk": "^5.6.2" + "detect-port": "^2.1.0" }, "scripts": { - "build:ts": "tsc && cp ./example.config.json ./dist/config.json", - "build:bin": "bash build.sh", - "build": "bun run build:ts && bun run build:bin" + "compile": "tsc", + "build": "pnpm run compile && chmod +x ./build.sh && ./build.sh" } -} +} \ No newline at end of file diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml new file mode 100644 index 0000000..4c05795 --- /dev/null +++ b/pnpm-lock.yaml @@ -0,0 +1,94 @@ +lockfileVersion: '9.0' + +settings: + autoInstallPeers: true + excludeLinksFromLockfile: false + +importers: + + .: + dependencies: + chalk: + specifier: ^5.6.2 + version: 5.6.2 + detect-port: + specifier: ^2.1.0 + version: 2.1.0 + typescript: + specifier: ^6.0.3 + version: 6.0.3 + devDependencies: + '@types/bun': + specifier: ^1.3.11 + version: 1.3.14 + '@types/node': + specifier: ^25.5.0 + version: 25.8.0 + '@types/ws': + specifier: ^8.18.1 + version: 8.18.1 + +packages: + + '@types/bun@1.3.14': + resolution: {integrity: sha512-h1hFqFVcvAvD9j9K7ZW7vd82aSA+rTdznZa+5bwvCwqSB1jmmfLcbIWhOLx1/+boy/xmjgCs/OMUL8hRJSmnPw==} + + '@types/node@25.8.0': + resolution: {integrity: sha512-TCFSk8IZh+iLX1xtksoBVtdmgL+1IX0fC9BeU4QqFSuNdN/K+HUlhqOzEmSYYpZUVsLYcPqc9KX+60iDuninSQ==} + + '@types/ws@8.18.1': + resolution: {integrity: sha512-ThVF6DCVhA8kUGy+aazFQ4kXQ7E1Ty7A3ypFOe0IcJV8O/M511G99AW24irKrW56Wt44yG9+ij8FaqoBGkuBXg==} + + address@2.0.3: + resolution: {integrity: sha512-XNAb/a6TCqou+TufU8/u11HCu9x1gYvOoxLwtlXgIqmkrYQADVv6ljyW2zwiPhHz9R1gItAWpuDrdJMmrOBFEA==} + engines: {node: '>= 16.0.0'} + + bun-types@1.3.14: + resolution: {integrity: sha512-4N0ig0fEomHt5R0KCFWjovxow98rIoRwKolrYdCcknNwMekCXRnWEUvgu5soYV8QXtVsrUD8B95MBOZGPvr6KQ==} + + chalk@5.6.2: + resolution: {integrity: sha512-7NzBL0rN6fMUW+f7A6Io4h40qQlG+xGmtMxfbnH/K7TAtt8JQWVQK+6g0UXKMeVJoyV5EkkNsErQ8pVD3bLHbA==} + engines: {node: ^12.17.0 || ^14.13 || >=16.0.0} + + detect-port@2.1.0: + resolution: {integrity: sha512-epZuWb/6Q62L+nDHJc/hQAqf8pylsqgk3BpZXVBx1CDnr3nkrVNn73Uu1rXcFzkNcc+hkP3whuOg7JZYaQB65Q==} + engines: {node: '>= 16.0.0'} + hasBin: true + + typescript@6.0.3: + resolution: {integrity: sha512-y2TvuxSZPDyQakkFRPZHKFm+KKVqIisdg9/CZwm9ftvKXLP8NRWj38/ODjNbr43SsoXqNuAisEf1GdCxqWcdBw==} + engines: {node: '>=14.17'} + hasBin: true + + undici-types@7.24.6: + resolution: {integrity: sha512-WRNW+sJgj5OBN4/0JpHFqtqzhpbnV0GuB+OozA9gCL7a993SmU+1JBZCzLNxYsbMfIeDL+lTsphD5jN5N+n0zg==} + +snapshots: + + '@types/bun@1.3.14': + dependencies: + bun-types: 1.3.14 + + '@types/node@25.8.0': + dependencies: + undici-types: 7.24.6 + + '@types/ws@8.18.1': + dependencies: + '@types/node': 25.8.0 + + address@2.0.3: {} + + bun-types@1.3.14: + dependencies: + '@types/node': 25.8.0 + + chalk@5.6.2: {} + + detect-port@2.1.0: + dependencies: + address: 2.0.3 + + typescript@6.0.3: {} + + undici-types@7.24.6: {} diff --git a/src/index.ts b/src/index.ts index c530be0..538fc65 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,3 +1,267 @@ -export { wispConfigPath, wispPath } from "./path.js"; -export * from "./server/index.js"; +import { spawn, type ChildProcess } from "child_process"; +import { binPath, configPath } from "./path.js"; +import * as fs from "node:fs"; +import { detect } from 'detect-port'; +import logger from "./logger.js"; +import { request, type IncomingMessage } from "node:http"; +import type { Socket } from "node:net"; +type PortEntry = number | [number, number]; + +type MrrowispConfig = { + /** + * TCP port the server listens on. + */ + port: number; + + /** + * Allow clients to open TCP streams. + */ + allowTCP: boolean; + + /** + * Allow clients to open UDP streams. + */ + allowUDP: boolean; + + /** + * Allow direct connections to IP addresses. + */ + allowDirectIP: boolean; + + /** + * Allow connections to private/local IP ranges. + */ + allowPrivateIPs: boolean; + + /** + * Allow connections to loopback IP addresses. + */ + allowLoopbackIPs: boolean; + + /** + * Size of the TCP stream buffer in bytes. + */ + tcpBufferSize: number; + + /** + * Enable TCP_NODELAY on TCP sockets. + */ + tcpNoDelay: boolean; + + /** + * Hostname and port blacklist rules. + */ + blacklist: { + hostnames: string[]; + ports: PortEntry[]; + }; + + /** + * Hostname and port whitelist rules. + */ + whitelist: { + hostnames: string[]; + ports: PortEntry[]; + }; + + /** + * DNS servers used for hostname resolution. + */ + dnsServers: string[]; + + /** + * DNS resolution method. + */ + dnsMethod: "lookup" | "resolve"; + + /** + * Preferred ordering for resolved IP addresses. + */ + dnsResultOrder: "ipv4first" | "ipv6first" | "verbatim"; + + /** + * Enable TWisp experimental protocol. + */ + enableTwisp: boolean; + + /** + * Enable Wisp v2 protocol support. + */ + enableV2: boolean; + + /** + * Message of the day sent during handshake. + */ + motd: string; + + /** + * Enable password authentication. + */ + passwordAuth: boolean; + + /** + * Require password authentication for all clients. + */ + passwordAuthRequired: boolean; + + /** + * Username/password credential map. + */ + passwordUsers: Map; + + /** + * Parse reverse-proxy real IP headers. + */ + parseRealIP: boolean; + + /** + * HTTP response returned for non-WebSocket requests. + */ + nonWSResponse: string; + + /** + * Logging verbosity level. + */ + logLevel: "debug" | "warn" | "error" | "info"; + + /** + * Optional upstream proxy URL (SOCKS/HTTP). + */ + proxy: string; + + /** + * Maximum WebSocket message size in bytes. + */ + maxMessageSize: number; + + /** + * Directory for static file serving. + */ + staticDir: string; + + /** + * Bandwidth limit per IP in Kbps. + */ + bandwidthLimitKbps: number; + + /** + * Connection rate limit per IP. + */ + connectionsLimitPerIP: number; + + /** + * Connection rate limit window in seconds. + */ + connectionWindowSeconds: number; +}; + +const defaultConfig: MrrowispConfig = JSON.parse(fs.readFileSync(configPath, "utf-8")); + +export class Mrrowisp { + config: MrrowispConfig; + process: ChildProcess | undefined; + + constructor(config?: Partial) { + this.config = defaultConfig; + this.process = undefined; + if (config) { + this.config = { ...this.config, ...config }; + } + logger.level = this.config.logLevel; + } + + async start() { + if (await detect(this.config.port) !== this.config.port) { + logger.error(`port ${this.config.port} is not available!! >w<`); + return; + } + + this.process = spawn(binPath, ["--config", JSON.stringify(this.config)], { + stdio: "pipe" + }); + + const handleData = (data: Buffer) => { + const msg = data.toString().trim(); + const levelMatch = msg.match(/^\[(DEBUG|INFO|WARN|ERROR)\]/); + if (levelMatch) { + switch (levelMatch[1]) { + case "DEBUG": logger.debug(msg); break; + case "INFO": logger.info(msg); break; + case "WARN": logger.warn(msg); break; + case "ERROR": logger.error(msg); break; + } + } else { + logger.error(msg); + } + }; + + this.process.stdout?.on("data", handleData); + this.process.stderr?.on("data", handleData); + + this.process.on("close", (code) => { + logger.info(`child process exited with code ${code} D:`); + this.process = undefined; + }); + } + + async route(req: IncomingMessage, socket: Socket, head: Buffer) { + if (!this.process) { + logger.error("mrrowisp is not running!! >w<"); + socket.destroy(); + return; + } + + const proxyReq = request({ + hostname: "127.0.0.1", + port: this.config.port, + path: req.url, + method: req.method, + headers: req.headers, + }); + + proxyReq.on("upgrade", (proxyRes, proxySocket, proxyHead) => { + socket.write( + `HTTP/1.1 101 Switching Protocols\r\n` + + Object.entries(proxyRes.headers) + .map(([k, v]) => `${k}: ${v}`) + .join("\r\n") + + "\r\n\r\n" + ); + + if (proxyHead?.length) proxySocket.unshift(proxyHead); + if (head?.length) socket.unshift(head); + + proxySocket.pipe(socket); + socket.pipe(proxySocket); + + proxySocket.on("error", () => socket.destroy()); + socket.on("error", () => proxySocket.destroy()); + }); + + proxyReq.on("error", (err) => { + logger.error(`proxy request error: ${err.message}`); + socket.destroy(); + }); + + proxyReq.end(); + } + + async stop() { + if (this.process) { + this.process.kill("SIGTERM"); + this.process = undefined; + } else { + logger.warn("mrrowisp is not running..."); + } + } + + async kill() { + if (this.process) { + this.process.kill("SIGKILL"); + this.process = undefined; + } else { + logger.warn("mrrowisp is not running..."); + } + } +} diff --git a/src/logger.ts b/src/logger.ts new file mode 100644 index 0000000..2251d36 --- /dev/null +++ b/src/logger.ts @@ -0,0 +1,40 @@ +import chalk from "chalk"; + +export type LogLevel = "debug" | "warn" | "error" | "info" | "none"; + +const levelPriority: Record = { + debug: 0, + warn: 1, + error: 2, + info: 3, + none: 4, +}; + +class Logger { + level: LogLevel = "info"; + + private shouldLog(method: LogLevel): boolean { + return levelPriority[method] >= levelPriority[this.level]; + } + + info(message: string) { + if (!this.shouldLog("info")) return; + console.log(chalk.bold(chalk.hex("#ebaaee")(`[mrrowisp]: ${message}`))); + } + error(message: string) { + if (!this.shouldLog("error")) return; + console.log(chalk.bold(chalk.hex("#f38fad")(`[mrrowisp]: ${message}`))); + } + warn(message: string) { + if (!this.shouldLog("warn")) return; + console.log(chalk.bold(chalk.hex("#f9dca1")(`[mrrowisp]: ${message}`))); + } + debug(message: string) { + if (!this.shouldLog("debug")) return; + console.log(chalk.bold(chalk.hex("#89b4fa")(`[mrrowisp]: ${message}`))); + } +} + +const logger = new Logger(); + +export default logger; diff --git a/src/path.ts b/src/path.ts index bc9b6e1..86d2c7e 100644 --- a/src/path.ts +++ b/src/path.ts @@ -1,11 +1,11 @@ -import * as os from "os"; +import * as path from "node:path"; +import * as os from "node:os"; +import { fileURLToPath } from "node:url"; -const arch = os.arch() -const platform = os.platform() +const bin = os.platform() === "win32" ? "mrrowisp.exe" : "mrrowisp"; -const pkg = `${platform}-${arch}` -const wispConfigPath = new URL("../dist/config.json", import.meta.url).pathname; -const wispPath = new URL(`../dist/${pkg}/mrrowisp`, import.meta.url).pathname; - -export { wispConfigPath, wispPath }; +const __dirname = path.dirname(fileURLToPath(import.meta.url)); +const root = path.resolve(__dirname, ".."); +export const configPath = path.join(root, "dist", "config.json"); +export const binPath = path.join(root, "bin", `${os.platform()}-${os.arch()}`, bin); diff --git a/src/server/index.ts b/src/server/index.ts deleted file mode 100644 index a27fd0c..0000000 --- a/src/server/index.ts +++ /dev/null @@ -1,243 +0,0 @@ -import { spawn, type ChildProcess } from "child_process"; -import * as fs from "fs"; -import { wispConfigPath, wispPath } from "../path.js"; -import type { Config, WispBuilder, WispEvents, WispServer } from "../types.js"; - -type EventListeners = { - [E in keyof WispEvents]: Array; -}; - -class WispServerImpl implements WispServer { - readonly process: ChildProcess; - readonly config: Config; - private _running: boolean = true; - private listeners: EventListeners; - - constructor(process: ChildProcess, config: Config, listeners: EventListeners) { - this.process = process; - this.config = config; - this.listeners = listeners; - - this.process.on("exit", (code, signal) => { - this._running = false; - this.listeners.exit.forEach((cb) => cb(code, signal)); - }); - - this.process.on("error", (err) => { - this._running = false; - this.listeners.error.forEach((cb) => cb(err)); - }); - } - - get running(): boolean { - return this._running; - } - - stop(): Promise { - return new Promise((resolve, reject) => { - if (!this._running) { - resolve(); - return; - } - - const timeout = setTimeout(() => { - this.process.kill("SIGKILL"); - }, 5000); - - this.process.once("exit", () => { - clearTimeout(timeout); - resolve(); - }); - - this.process.once("error", (err) => { - clearTimeout(timeout); - reject(err); - }); - - this.process.kill("SIGTERM"); - }); - } - - kill(signal: NodeJS.Signals = "SIGKILL"): void { - if (this._running) { - this.process.kill(signal); - } - } - - on(event: K, listener: WispEvents[K]): WispServer { - (this.listeners[event] as Array).push(listener); - return this; - } - - off(event: K, listener: WispEvents[K]): WispServer { - const arr = this.listeners[event] as Array; - const idx = arr.indexOf(listener); - if (idx !== -1) { - arr.splice(idx, 1); - } - return this; - } -} - -class WispBuilderImpl implements WispBuilder { - private config: Config; - private listeners: EventListeners = { - ready: [], - error: [], - exit: [], - stdout: [], - stderr: [], - }; - - constructor() { - this.config = JSON.parse(fs.readFileSync(wispConfigPath, "utf-8")); - } - - fromFile(path: string): WispBuilder { - const fileConfig = JSON.parse(fs.readFileSync(path, "utf-8")); - this.config = { ...this.config, ...fileConfig }; - return this; - } - - fromJSON(json: string): WispBuilder { - const parsed = JSON.parse(json); - this.config = { ...this.config, ...parsed }; - return this; - } - - withConfig(config: Partial): WispBuilder { - this.config = { ...this.config, ...config }; - return this; - } - - port(port: number): WispBuilder { - this.config.port = port; - return this; - } - - udp(enabled: boolean): WispBuilder { - this.config.disableUDP = !enabled; - return this; - } - - v2(enabled: boolean): WispBuilder { - this.config.enableV2 = enabled; - return this; - } - - twisp(enabled: boolean): WispBuilder { - this.config.enableTwisp = enabled; - return this; - } - - motd(message: string): WispBuilder { - this.config.motd = message; - return this; - } - - blacklist(hostnames: string[]): WispBuilder { - this.config.blacklist = { hostnames }; - return this; - } - - whitelist(hostnames: string[]): WispBuilder { - this.config.whitelist = { hostnames }; - return this; - } - - proxy(url: string): WispBuilder { - this.config.proxy = url; - return this; - } - - dns(servers: string | string[]): WispBuilder { - this.config.dnsServer = Array.isArray(servers) ? servers : [servers]; - return this; - } - - onReady(callback: () => void): WispBuilder { - this.listeners.ready.push(callback); - return this; - } - - onError(callback: (error: Error) => void): WispBuilder { - this.listeners.error.push(callback); - return this; - } - - onExit(callback: (code: number | null, signal: NodeJS.Signals | null) => void): WispBuilder { - this.listeners.exit.push(callback); - return this; - } - - onStdout(callback: (data: string) => void): WispBuilder { - this.listeners.stdout.push(callback); - return this; - } - - onStderr(callback: (data: string) => void): WispBuilder { - this.listeners.stderr.push(callback); - return this; - } - - getConfig(): Config { - return { ...this.config }; - } - - start(): Promise { - return new Promise((resolve, reject) => { - let resolved = false; - - const process = spawn(wispPath, ["--config", JSON.stringify(this.config)]); - - const server = new WispServerImpl(process, this.config, this.listeners); - - process.stdout.on("data", (data: Buffer) => { - const str = data.toString(); - this.listeners.stdout.forEach((cb) => cb(str)); - - if (!resolved && str.includes("Starting Mrrowisp")) { - resolved = true; - this.listeners.ready.forEach((cb) => cb()); - resolve(server); - } - }); - - process.stderr.on("data", (data: Buffer) => { - const str = data.toString(); - this.listeners.stderr.forEach((cb) => cb(str)); - }); - - process.on("error", (err) => { - if (!resolved) { - resolved = true; - this.listeners.error.forEach((cb) => cb(err)); - reject(err); - } - }); - - process.on("exit", (code, signal) => { - if (!resolved) { - resolved = true; - const err = new Error(`Server exited before ready (code: ${code}, signal: ${signal})`); - this.listeners.error.forEach((cb) => cb(err)); - reject(err); - } - }); - - setTimeout(() => { - if (!resolved) { - resolved = true; - const err = new Error("Server startup timed out after 10 seconds"); - this.listeners.error.forEach((cb) => cb(err)); - process.kill("SIGKILL"); - reject(err); - } - }, 10000); - }); - } -} - -export function createMrrowisp(): WispBuilder { - return new WispBuilderImpl(); -} \ No newline at end of file diff --git a/src/types.d.ts b/src/types.d.ts deleted file mode 100644 index ad5555b..0000000 --- a/src/types.d.ts +++ /dev/null @@ -1,71 +0,0 @@ -import type { ChildProcess } from "child_process"; - -export type Config = { - port?: number; - disableUDP?: boolean; - tcpBufferSize?: number; - bufferRemainingLength?: number; - tcpNoDelay?: boolean; - websocketTcpNoDelay?: boolean; - blacklist?: { - hostnames: string[]; - }; - whitelist?: { - hostnames: string[]; - }; - proxy?: string; - websocketPermessageDeflate?: boolean; - dnsServer?: string[]; - enableTwisp?: boolean; - enableV2: boolean; - motd?: string; - passwordAuth?: boolean; - passwordAuthRequired?: boolean; - passwordUsers?: { - [username: string]: string; - }; - certAuth?: boolean; - certAuthRequired?: boolean; - certAuthPublicKeys?: string[]; - enableStreamConfirm?: boolean; -}; - -export type WispEvents = { - ready: () => void; - error: (error: Error) => void; - exit: (code: number | null, signal: NodeJS.Signals | null) => void; - stdout: (data: string) => void; - stderr: (data: string) => void; -}; - -export type WispServer = { - readonly process: ChildProcess; - readonly config: Config; - readonly running: boolean; - stop(): Promise; - kill(signal?: NodeJS.Signals): void; - on(event: K, listener: WispEvents[K]): WispServer; - off(event: K, listener: WispEvents[K]): WispServer; -}; - -export type WispBuilder = { - fromFile(path: string): WispBuilder; - fromJSON(json: string): WispBuilder; - withConfig(config: Partial): WispBuilder; - port(port: number): WispBuilder; - udp(enabled: boolean): WispBuilder; - v2(enabled: boolean): WispBuilder; - twisp(enabled: boolean): WispBuilder; - motd(message: string): WispBuilder; - blacklist(hostnames: string[]): WispBuilder; - whitelist(hostnames: string[]): WispBuilder; - proxy(url: string): WispBuilder; - dns(servers: string | string[]): WispBuilder; - onReady(callback: () => void): WispBuilder; - onError(callback: (error: Error) => void): WispBuilder; - onExit(callback: (code: number | null, signal: NodeJS.Signals | null) => void): WispBuilder; - onStdout(callback: (data: string) => void): WispBuilder; - onStderr(callback: (data: string) => void): WispBuilder; - getConfig(): Config; - start(): Promise; -}; diff --git a/wisp/config.go b/wisp/config.go new file mode 100644 index 0000000..c47dadc --- /dev/null +++ b/wisp/config.go @@ -0,0 +1,178 @@ +package wisp + +import ( + "encoding/json" + "net" + "os" + "strings" +) + +type FilterList struct { + Hostnames []string `json:"hostnames"` + Ports []interface{} `json:"ports"` +} + +type Config struct { + Port int `json:"port"` + + AllowTCP bool `json:"allowTCP"` + AllowUDP bool `json:"allowUDP"` + + AllowDirectIP bool `json:"allowDirectIP"` + AllowPrivateIPs bool `json:"allowPrivateIPs"` + AllowLoopbackIPs bool `json:"allowLoopbackIPs"` + + TcpBufferSize int `json:"tcpBufferSize"` + TcpNoDelay bool `json:"tcpNoDelay"` + + Blacklist FilterList `json:"blacklist"` + Whitelist FilterList `json:"whitelist"` + + DnsServers []string `json:"dnsServers"` + DnsMethod string `json:"dnsMethod"` + DnsResultOrder string `json:"dnsResultOrder"` + + EnableTwisp bool `json:"enableTwisp"` + + EnableV2 bool `json:"enableV2"` + Motd string `json:"motd"` + PasswordAuth bool `json:"passwordAuth"` + PasswordAuthRequired bool `json:"passwordAuthRequired"` + PasswordUsers map[string]string `json:"passwordUsers"` + + ParseRealIP bool `json:"parseRealIP"` + NonWSResponse string `json:"nonWSResponse"` + + LogLevel string `json:"logLevel"` + + Proxy string `json:"proxy"` + MaxMessageSize int `json:"maxMessageSize"` + StaticDir string `json:"staticDir"` + BandwidthLimitKbps int `json:"bandwidthLimitKbps"` + ConnectionsLimitPerIP int `json:"connectionsLimitPerIP"` + ConnectionWindowSeconds int `json:"connectionWindowSeconds"` + + BufferRemainingLength uint32 `json:"bufferRemainingLength"` + + Logger Logger + DNSCache *DNSCache + Dialer net.Dialer +} + +func DefaultConfig() Config { + return Config{ + Port: 6001, + + AllowTCP: true, + AllowUDP: true, + + AllowDirectIP: false, + AllowPrivateIPs: false, + AllowLoopbackIPs: false, + + TcpBufferSize: 32768, + TcpNoDelay: true, + + Blacklist: FilterList{ + Hostnames: []string{}, + Ports: []interface{}{}, + }, + Whitelist: FilterList{ + Hostnames: []string{}, + Ports: []interface{}{}, + }, + + DnsServers: []string{}, + DnsMethod: "resolve", + DnsResultOrder: "ipv4first", + + EnableTwisp: false, + + EnableV2: true, + Motd: "", + PasswordAuth: false, + PasswordAuthRequired: false, + PasswordUsers: map[string]string{}, + + ParseRealIP: true, + NonWSResponse: "", + + LogLevel: "info", + + Proxy: "", + MaxMessageSize: 0, + StaticDir: "", + BandwidthLimitKbps: 0, + ConnectionsLimitPerIP: 0, + ConnectionWindowSeconds: 0, + BufferRemainingLength: 32768, + } +} + +func CreateWispConfig(cfg Config) *Config { + wispCfg := &Config{ + AllowTCP: cfg.AllowTCP, + AllowUDP: cfg.AllowUDP, + + AllowDirectIP: cfg.AllowDirectIP, + AllowPrivateIPs: cfg.AllowPrivateIPs, + AllowLoopbackIPs: cfg.AllowLoopbackIPs, + + TcpBufferSize: cfg.TcpBufferSize, + TcpNoDelay: cfg.TcpNoDelay, + + Blacklist: cfg.Blacklist, + Whitelist: cfg.Whitelist, + + DnsServers: cfg.DnsServers, + DnsMethod: cfg.DnsMethod, + DnsResultOrder: cfg.DnsResultOrder, + + EnableTwisp: cfg.EnableTwisp, + + EnableV2: cfg.EnableV2, + Motd: cfg.Motd, + PasswordAuth: cfg.PasswordAuth, + PasswordAuthRequired: cfg.PasswordAuthRequired, + PasswordUsers: cfg.PasswordUsers, + + ParseRealIP: cfg.ParseRealIP, + NonWSResponse: cfg.NonWSResponse, + + LogLevel: cfg.LogLevel, + + Proxy: cfg.Proxy, + MaxMessageSize: cfg.MaxMessageSize, + BandwidthLimitKbps: cfg.BandwidthLimitKbps, + ConnectionsLimitPerIP: cfg.ConnectionsLimitPerIP, + ConnectionWindowSeconds: cfg.ConnectionWindowSeconds, + + BufferRemainingLength: cfg.BufferRemainingLength, + } + + return wispCfg +} + +func LoadConfig(config string) (Config, error) { + cfg := DefaultConfig() + + trimConfig := strings.TrimSpace(config) + if strings.HasPrefix(trimConfig, "{") { + if err := json.Unmarshal([]byte(trimConfig), &cfg); err != nil { + return cfg, err + } + return cfg, nil + } + + file, err := os.Open(config) + if err != nil { + return cfg, err + } + defer file.Close() + + decoder := json.NewDecoder(file) + if err := decoder.Decode(&cfg); err != nil { + return cfg, err + } + return cfg, nil +} diff --git a/wisp/dnscache.go b/wisp/dnscache.go index 0857c84..3197981 100644 --- a/wisp/dnscache.go +++ b/wisp/dnscache.go @@ -3,8 +3,11 @@ package wisp import ( "context" "net" + "strings" "sync" "time" + + "golang.org/x/sync/singleflight" ) type dnsEntry struct { @@ -13,37 +16,103 @@ type dnsEntry struct { err error } +type DNSCacheConfig struct { + Servers []string + TTLSeconds int + Method string + ResultOrder string +} + type DNSCache struct { - servers []string - resolver *net.Resolver + servers []string + resolver *net.Resolver + ttl time.Duration + resultOrder string mu sync.RWMutex cache map[string]dnsEntry + group singleflight.Group } -func NewDNSCache(servers []string) *DNSCache { +func NewDNSCache(cfg DNSCacheConfig) *DNSCache { + ttl := time.Duration(cfg.TTLSeconds) * time.Second + if ttl <= 0 { + ttl = 120 * time.Second + } cache := &DNSCache{ - servers: servers, - cache: make(map[string]dnsEntry), + servers: cfg.Servers, + ttl: ttl, + resultOrder: cfg.ResultOrder, + cache: make(map[string]dnsEntry), } - cache.initResolver() + cache.initResolver(cfg.Method) + cache.cleanup() return cache } -func (d *DNSCache) initResolver() { - if len(d.servers) > 0 { +func (d *DNSCache) cleanup() { + interval := d.ttl / 2 + if interval < time.Minute { + interval = time.Minute + } + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for range ticker.C { + d.expire() + } + }() +} + +func (d *DNSCache) expire() { + now := time.Now() + d.mu.Lock() + defer d.mu.Unlock() + for host, entry := range d.cache { + if now.After(entry.expiresAt) { + delete(d.cache, host) + } + } +} + +func (d *DNSCache) initResolver(method string) { + method = strings.ToLower(strings.TrimSpace(method)) + if method == "resolve" && len(d.servers) > 0 { + server := firstDNSServer(d.servers) + if server == "" { + d.resolver = net.DefaultResolver + return + } d.resolver = &net.Resolver{ PreferGo: true, Dial: func(ctx context.Context, network, address string) (net.Conn, error) { dialer := net.Dialer{ Timeout: 5 * time.Second, } - return dialer.DialContext(ctx, "udp", d.servers[0]) + return dialer.DialContext(ctx, "udp", server) }, } - } else { - d.resolver = net.DefaultResolver + return } + d.resolver = net.DefaultResolver +} + +func firstDNSServer(servers []string) string { + for _, server := range servers { + server = strings.TrimSpace(server) + if server == "" { + continue + } + return normalizeDNSServer(server) + } + return "" +} + +func normalizeDNSServer(server string) string { + if _, _, err := net.SplitHostPort(server); err == nil { + return server + } + return net.JoinHostPort(server, "53") } func (d *DNSCache) LookupIPAddr(ctx context.Context, host string) ([]net.IPAddr, error) { @@ -64,19 +133,59 @@ func (d *DNSCache) LookupIPAddr(ctx context.Context, host string) ([]net.IPAddr, return entry.ips, nil } - ips, err := d.resolver.LookupIPAddr(ctx, host) + v, err, _ := d.group.Do(host, func() (any, error) { + ips, resolveErr := d.resolver.LookupIPAddr(ctx, host) + if resolveErr == nil { + ips = reorderIPs(ips, d.resultOrder) + } + entry := dnsEntry{ + ips: ips, + expiresAt: time.Now().Add(d.ttl), + err: resolveErr, + } + d.mu.Lock() + d.cache[host] = entry + d.mu.Unlock() + return entry, resolveErr + }) + if err != nil { + return nil, err + } + entry, ok = v.(dnsEntry) + if !ok { + return nil, err + } + if entry.err != nil { + return nil, entry.err + } + return entry.ips, nil +} - d.mu.Lock() - d.cache[host] = dnsEntry{ - ips: ips, - expiresAt: now.Add(120 * time.Second), - err: err, +func reorderIPs(ips []net.IPAddr, order string) []net.IPAddr { + if len(ips) <= 1 { + return ips + } + order = strings.ToLower(strings.TrimSpace(order)) + if order == "verbatim" || order == "" { + return ips } - d.mu.Unlock() - if err != nil { - return nil, err + var v4 []net.IPAddr + var v6 []net.IPAddr + for _, ip := range ips { + if ip.IP.To4() != nil { + v4 = append(v4, ip) + } else { + v6 = append(v6, ip) + } + } + + if order == "ipv4first" { + return append(v4, v6...) + } + if order == "ipv6first" { + return append(v6, v4...) } - return ips, nil + return ips } diff --git a/wisp/logger.go b/wisp/logger.go new file mode 100644 index 0000000..7014479 --- /dev/null +++ b/wisp/logger.go @@ -0,0 +1,58 @@ +package wisp + +import ( + "log" + "strings" +) + +type Logger interface { + Debug(msg string, kv ...any) + Info(msg string, kv ...any) + Warn(msg string, kv ...any) + Error(msg string, kv ...any) +} + +type logLevel int + +const ( + levelDebug logLevel = iota + levelInfo + levelWarn + levelError +) + +type Log struct { + level logLevel + inner *log.Logger +} + +func newLogger(level string) Logger { + lvl := levelInfo + switch strings.ToLower(strings.TrimSpace(level)) { + case "debug": + lvl = levelDebug + case "info": + lvl = levelInfo + case "warn", "warning": + lvl = levelWarn + case "error": + lvl = levelError + } + return &Log{level: lvl, inner: log.Default()} +} + +func (l *Log) Debug(msg string, kv ...any) { l.log(levelDebug, "DEBUG", msg, kv...) } +func (l *Log) Info(msg string, kv ...any) { l.log(levelInfo, "INFO", msg, kv...) } +func (l *Log) Warn(msg string, kv ...any) { l.log(levelWarn, "WARN", msg, kv...) } +func (l *Log) Error(msg string, kv ...any) { l.log(levelError, "ERROR", msg, kv...) } + +func (l *Log) log(lvl logLevel, prefix string, msg string, kv ...any) { + if l == nil || l.inner == nil || lvl < l.level { + return + } + if len(kv) == 0 { + l.inner.Printf("[%s] %s", prefix, msg) + return + } + l.inner.Printf("[%s] %s %v", prefix, msg, kv) +} diff --git a/wisp/protection.go b/wisp/protection.go new file mode 100644 index 0000000..b033fe8 --- /dev/null +++ b/wisp/protection.go @@ -0,0 +1,173 @@ +package wisp + +import ( + "net" + "net/http" + "strconv" + "strings" + + prot "mrrowisp/wisp/protection" +) + +type guard struct { + config *Config +} + +type connectAction uint8 + +const ( + connectBlocked connectAction = iota + connectStream + connectTwisp +) + +func newProtection(config *Config) *guard { + return &guard{config: config} +} + +func (p *guard) allowHTTP(r *http.Request, remoteIP string, useV2 bool) (int, string, bool) { + cfg := p.config + + if !isWebsocketUpgrade(r) { + if cfg.NonWSResponse != "" { + return http.StatusOK, cfg.NonWSResponse, false + } + return http.StatusBadRequest, "", false + } + + if !useV2 && cfg.requiresV2() { + cfg.Logger.Warn("websocket v1 downgrade blocked", "ip", remoteIP) + return http.StatusUnauthorized, cfg.NonWSResponse, false + } + + return 0, "", true +} + +func isWebsocketUpgrade(r *http.Request) bool { + return strings.Contains(strings.ToLower(r.Header.Get("Connection")), "upgrade") && + strings.ToLower(r.Header.Get("Upgrade")) == "websocket" +} + +func (p *guard) allowConnect(c *wispConnection, streamType uint8, hostname string, port string) (connectAction, string, uint8) { + cfg := p.config + + if len(hostname) > 2048 || strings.IndexByte(hostname, 0) >= 0 { + return connectBlocked, "", closeReasonInvalidInfo + } + + if !c.connectLimiter.allow() { + cfg.Logger.Warn("connect rate limit exceeded", "ip", c.remoteIP) + return connectBlocked, "", closeReasonThrottled + } + + if streamType == streamTypeTerm { + if !cfg.EnableTwisp || !c.twispAuthorized() { + cfg.Logger.Warn("terminal stream blocked", "ip", c.remoteIP) + return connectBlocked, "", closeReasonBlocked + } + return connectTwisp, "", 0 + } + + if streamType == streamTypeTCP && !cfg.AllowTCP { + cfg.Logger.Warn("TCP streams blocked", "ip", c.remoteIP, "hostname", hostname) + return connectBlocked, "", closeReasonBlocked + } + if streamType == streamTypeUDP && !cfg.AllowUDP { + cfg.Logger.Warn("UDP streams blocked", "ip", c.remoteIP, "hostname", hostname) + return connectBlocked, "", closeReasonBlocked + } + + normalizedHostname := prot.NormalizeTargetHostname(hostname) + if normalizedHostname == "" { + return connectBlocked, "", closeReasonInvalidInfo + } + + if !cfg.AllowLoopbackIPs && prot.IsOwnIP(normalizedHostname) { + cfg.Logger.Warn("self-targeting stream blocked", "ip", c.remoteIP, "hostname", hostname) + return connectBlocked, "", closeReasonBlocked + } + + return connectStream, normalizedHostname, 0 +} + +func (p *guard) allowHostPort(hostname string, port string) (uint8, bool) { + cfg := p.config + + portNum, err := strconv.Atoi(port) + if err != nil { + return closeReasonInvalidInfo, false + } + + if len(cfg.Whitelist.Ports) > 0 { + allowed := false + type portContains interface{ Contains(int) bool } + for _, r := range cfg.Whitelist.Ports { + if c, ok := r.(portContains); ok { + if c.Contains(portNum) { + allowed = true + break + } + } + } + if !allowed { + return closeReasonBlocked, false + } + } else if len(cfg.Blacklist.Ports) > 0 { + type portContains interface{ Contains(int) bool } + for _, r := range cfg.Blacklist.Ports { + if c, ok := r.(portContains); ok { + if c.Contains(portNum) { + return closeReasonBlocked, false + } + } + } + } + + return 0, true +} + +func (p *guard) allowDirectIP(ip net.IP, remoteIP string, hostname string) (uint8, bool) { + cfg := p.config + + if !cfg.AllowDirectIP { + return closeReasonBlocked, false + } + if !prot.IsAllowedTargetIP(ip, prot.IPConfig{ + AllowDirectIP: cfg.AllowDirectIP, + AllowPrivateIPs: cfg.AllowPrivateIPs, + AllowLoopbackIPs: cfg.AllowLoopbackIPs, + }) { + return closeReasonBlocked, false + } + if !cfg.AllowLoopbackIPs && prot.IsOwnIP(ip.String()) { + cfg.Logger.Warn("self-targeting stream blocked", "ip", remoteIP, "hostname", hostname) + return closeReasonBlocked, false + } + + return 0, true +} + +func (p *guard) selectAllowedIP(ips []net.IPAddr, remoteIP string, hostname string) (string, uint8, bool) { + cfg := p.config + + selected, ok := prot.FirstAllowedIP(ips, prot.IPConfig{ + AllowDirectIP: cfg.AllowDirectIP, + AllowPrivateIPs: cfg.AllowPrivateIPs, + AllowLoopbackIPs: cfg.AllowLoopbackIPs, + }) + if !ok { + cfg.Logger.Warn("DNS returned only blocked IPs", "ip", remoteIP, "hostname", hostname) + return "", closeReasonBlocked, false + } + if !cfg.AllowLoopbackIPs && prot.IsOwnIP(selected) { + cfg.Logger.Warn("self-targeting stream blocked", "ip", remoteIP, "hostname", hostname) + return "", closeReasonBlocked, false + } + + return selected, 0, true +} + +func (p *guard) allowMessageSize(size int) bool { + max := p.config.MaxMessageSize + return max <= 0 || size <= max +} diff --git a/wisp/protection/banlist.go b/wisp/protection/banlist.go new file mode 100644 index 0000000..0d7f50c --- /dev/null +++ b/wisp/protection/banlist.go @@ -0,0 +1,95 @@ +package protection + +import ( + "net" + "net/http" + "sync" + "time" +) + +type BanList struct { + mutex sync.RWMutex + bans map[string]time.Time + banDur time.Duration + strikes map[string]int + maxStrikes int + escalationMultiplier int +} + +func NewBanList(banDuration time.Duration, maxStrikes int) *BanList { + return NewBanListEscalated(banDuration, maxStrikes, 0) +} + +func NewBanListEscalated(banDuration time.Duration, maxStrikes int, escalation int) *BanList { + if banDuration <= 0 { + banDuration = time.Hour + } + if maxStrikes <= 0 { + maxStrikes = 10 + } + b := &BanList{ + bans: make(map[string]time.Time), + strikes: make(map[string]int), + mutex: sync.RWMutex{}, + banDur: banDuration, + maxStrikes: maxStrikes, + escalationMultiplier: escalation, + } + go b.cleanup() + return b +} + +func (b *BanList) Strike(ip string) (banned bool) { + b.mutex.Lock() + defer b.mutex.Unlock() + b.strikes[ip]++ + if b.strikes[ip] >= b.maxStrikes { + dur := b.banDur + if b.escalationMultiplier > 0 { + for i := 1; i < b.strikes[ip]/b.maxStrikes; i++ { + dur *= time.Duration(b.escalationMultiplier) + } + } + b.bans[ip] = time.Now().Add(dur) + delete(b.strikes, ip) + return true + } + return false +} + +func (b *BanList) IsBanned(ip string) bool { + b.mutex.RLock() + defer b.mutex.RUnlock() + unbanAt, exists := b.bans[ip] + if !exists { + return false + } + return time.Now().Before(unbanAt) +} + +func (b *BanList) cleanup() { + for range time.Tick(5 * time.Minute) { + b.mutex.Lock() + now := time.Now() + for ip, unbanAt := range b.bans { + if now.After(unbanAt) { + delete(b.bans, ip) + } + } + b.mutex.Unlock() + } +} + +func (b *BanList) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ip, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + ip = r.RemoteAddr + } + if b.IsBanned(ip) { + http.Error(w, "banned", http.StatusForbidden) + return + } + next.ServeHTTP(w, r) + }) +} diff --git a/wisp/protection/iputil.go b/wisp/protection/iputil.go new file mode 100644 index 0000000..5ba0316 --- /dev/null +++ b/wisp/protection/iputil.go @@ -0,0 +1,150 @@ +package protection + +import ( + "net" + "net/http" + "strings" +) + +type IPConfig struct { + AllowDirectIP bool + AllowPrivateIPs bool + AllowLoopbackIPs bool + ParseRealIP bool + ParseRealIPFrom map[string]struct{} +} + +func RemoteIPFromRequest(r *http.Request, cfg IPConfig) string { + if r == nil { + return "" + } + if cfg.ParseRealIP { + if ip := parseForwardedIP(r, cfg.ParseRealIPFrom); ip != "" { + return ip + } + } + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err == nil { + return host + } + return r.RemoteAddr +} + +func parseForwardedIP(r *http.Request, allowed map[string]struct{}) string { + if r == nil { + return "" + } + if len(allowed) == 0 { + return "" + } + + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return "" + } + if !isIPAllowed(host, allowed) { + return "" + } + + xff := r.Header.Get("X-Forwarded-For") + if xff != "" { + parts := strings.Split(xff, ",") + if len(parts) > 0 { + ip := strings.TrimSpace(parts[0]) + if net.ParseIP(ip) != nil { + return ip + } + } + } + + xrip := strings.TrimSpace(r.Header.Get("X-Real-IP")) + if xrip != "" && net.ParseIP(xrip) != nil { + return xrip + } + + return "" +} + +func isIPAllowed(ip string, allowed map[string]struct{}) bool { + if _, ok := allowed[ip]; ok { + return true + } + parsed := net.ParseIP(ip) + if parsed == nil { + return false + } + for entry := range allowed { + _, cidr, err := net.ParseCIDR(entry) + if err == nil && cidr.Contains(parsed) { + return true + } + } + return false +} + +func IsAllowedTargetIP(ip net.IP, cfg IPConfig) bool { + if ip == nil { + return false + } + if ip.IsUnspecified() || ip.IsMulticast() { + return false + } + if isCarrierGradeNAT(ip) || isBenchmarkingIP(ip) { + return false + } + if !cfg.AllowLoopbackIPs && ip.IsLoopback() { + return false + } + if !cfg.AllowPrivateIPs && (ip.IsPrivate() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast()) { + return false + } + return true +} + +func FirstAllowedIP(ips []net.IPAddr, cfg IPConfig) (string, bool) { + for _, addr := range ips { + if IsAllowedTargetIP(addr.IP, cfg) { + return addr.IP.String(), true + } + } + return "", false +} + +func isCarrierGradeNAT(ip net.IP) bool { + ipv4 := ip.To4() + if ipv4 == nil { + return false + } + return ipv4[0] == 100 && ipv4[1]&0xC0 == 0x40 +} + +func isBenchmarkingIP(ip net.IP) bool { + ipv4 := ip.To4() + if ipv4 == nil { + return false + } + return ipv4[0] == 198 && (ipv4[1] == 18 || ipv4[1] == 19) +} + +func NormalizeTargetHostname(host string) string { + host = strings.TrimSpace(strings.ToLower(host)) + host = strings.TrimSuffix(host, ".") + return host +} + +func IsOwnIP(resolvedIP string) bool { + ifaces, err := net.Interfaces() + if err != nil { + return false + } + for _, iface := range ifaces { + ifaceAddrs, _ := iface.Addrs() + for _, ifaceAddr := range ifaceAddrs { + ip, _, _ := net.ParseCIDR(ifaceAddr.String()) + if ip != nil && ip.String() == resolvedIP { + return true + } + } + } + return false +} diff --git a/wisp/protection/limits.go b/wisp/protection/limits.go new file mode 100644 index 0000000..d7279d4 --- /dev/null +++ b/wisp/protection/limits.go @@ -0,0 +1,177 @@ +package protection + +import ( + "sync" + "time" +) + +type BandwidthLimiter struct { + mu sync.Mutex + window time.Duration + bytes map[string]uint64 + start time.Time + limit uint64 +} + +func NewBandwidthLimiter(kbps int, window time.Duration) *BandwidthLimiter { + if window <= 0 { + window = time.Second + } + limit := uint64(kbps) * 1024 + return &BandwidthLimiter{window: window, start: time.Now(), limit: limit, bytes: make(map[string]uint64)} +} + +func (b *BandwidthLimiter) Allow(ip string, n uint64) bool { + if b == nil || b.limit == 0 { + return true + } + b.mu.Lock() + defer b.mu.Unlock() + now := time.Now() + if now.Sub(b.start) >= b.window { + b.start = now + b.bytes = make(map[string]uint64) + } + used := b.bytes[ip] + if used+n > b.limit { + return false + } + b.bytes[ip] = used + n + return true +} + +type ConnectionLimiter struct { + mu sync.Mutex + window time.Duration + start time.Time + counts map[string]int + limit int +} + +func NewConnectionLimiter(limit int, window time.Duration) *ConnectionLimiter { + if window <= 0 { + window = time.Second + } + return &ConnectionLimiter{window: window, start: time.Now(), limit: limit, counts: make(map[string]int)} +} + +func (c *ConnectionLimiter) Allow(ip string) bool { + if c == nil || c.limit <= 0 { + return true + } + c.mu.Lock() + defer c.mu.Unlock() + now := time.Now() + if now.Sub(c.start) >= c.window { + c.start = now + c.counts = make(map[string]int) + } + c.counts[ip]++ + return c.counts[ip] <= c.limit +} + +type PacketRateLimiter struct { + mu sync.Mutex + interval time.Duration + limit int + count int + resetAt time.Time +} + +func NewPacketRateLimiter(packetsPerSec int) *PacketRateLimiter { + if packetsPerSec <= 0 { + packetsPerSec = 500 + } + return &PacketRateLimiter{ + interval: time.Second, + limit: packetsPerSec, + resetAt: time.Now().Add(time.Second), + } +} + +func (p *PacketRateLimiter) Allow() bool { + p.mu.Lock() + defer p.mu.Unlock() + now := time.Now() + if now.After(p.resetAt) { + p.count = 0 + p.resetAt = now.Add(p.interval) + } + p.count++ + return p.count <= p.limit +} + +type ConnectionCounter struct { + mu sync.Mutex + perIP map[string]int + global int +} + +func NewConnectionCounter() *ConnectionCounter { + return &ConnectionCounter{perIP: make(map[string]int)} +} + +func (c *ConnectionCounter) TryAdd(ip string, maxPerIP int, maxGlobal int) bool { + c.mu.Lock() + defer c.mu.Unlock() + if maxGlobal > 0 && c.global >= maxGlobal { + return false + } + if maxPerIP > 0 && c.perIP[ip] >= maxPerIP { + return false + } + c.perIP[ip]++ + c.global++ + return true +} + +func (c *ConnectionCounter) Remove(ip string) { + c.mu.Lock() + defer c.mu.Unlock() + if c.perIP[ip] > 0 { + c.perIP[ip]-- + if c.perIP[ip] <= 0 { + delete(c.perIP, ip) + } + } + if c.global > 0 { + c.global-- + } +} + +type InboundRateLimiter struct { + mu sync.Mutex + interval time.Duration + limit int + count int + resetAt time.Time +} + +func NewInboundRateLimiter(bytesPerSec int) *InboundRateLimiter { + if bytesPerSec <= 0 { + bytesPerSec = 0 + } + return &InboundRateLimiter{ + interval: time.Second, + limit: bytesPerSec, + resetAt: time.Now().Add(time.Second), + } +} + +func (r *InboundRateLimiter) Allow(n int) bool { + if r == nil || r.limit <= 0 { + return true + } + r.mu.Lock() + defer r.mu.Unlock() + now := time.Now() + if now.After(r.resetAt) { + r.count = 0 + r.resetAt = now.Add(r.interval) + } + if r.count+n > r.limit { + return false + } + r.count += n + return true +} diff --git a/wisp/protection/streamlimits.go b/wisp/protection/streamlimits.go new file mode 100644 index 0000000..b55d633 --- /dev/null +++ b/wisp/protection/streamlimits.go @@ -0,0 +1,38 @@ +package protection + +import "sync" + +type StreamLimiter struct { + mutex sync.Mutex + pH map[string]int + total int +} + +func NewStreamLimiter() *StreamLimiter { + return &StreamLimiter{pH: make(map[string]int)} +} + +func (s *StreamLimiter) Allow(host string, perHostLimit int, totalLimit int) bool { + s.mutex.Lock() + defer s.mutex.Unlock() + if totalLimit > 0 && s.total >= totalLimit { + return false + } + if perHostLimit > 0 && s.pH[host] >= perHostLimit { + return false + } + s.total++ + s.pH[host]++ + return true +} + +func (s *StreamLimiter) Release(host string) { + s.mutex.Lock() + defer s.mutex.Unlock() + if s.total > 0 { + s.total-- + } + if s.pH[host] > 0 { + s.pH[host]-- + } +} diff --git a/wisp/twisp.go b/wisp/twisp.go index d04e192..2ac95d0 100644 --- a/wisp/twisp.go +++ b/wisp/twisp.go @@ -91,9 +91,7 @@ func handleTwisp(wc *wispConnection, streamId uint32, command string) { func (ts *twispStream) readPty() { const maxHeaderLen = 15 - bufPool := ts.wispConn.config.ReadBufPool.Get().(*[]byte) - buf := *bufPool - defer ts.wispConn.config.ReadBufPool.Put(bufPool) + buf := make([]byte, maxHeaderLen+65535) streamId := ts.streamId @@ -117,10 +115,10 @@ func (ts *twispStream) readPty() { frameStart = 0 buf[0] = 0x82 buf[1] = 127 - buf[2] = 0 - buf[3] = 0 - buf[4] = 0 - buf[5] = 0 + buf[2] = byte(totalPayload >> 56) + buf[3] = byte(totalPayload >> 48) + buf[4] = byte(totalPayload >> 40) + buf[5] = byte(totalPayload >> 32) buf[6] = byte(totalPayload >> 24) buf[7] = byte(totalPayload >> 16) buf[8] = byte(totalPayload >> 8) diff --git a/wisp/v2.go b/wisp/v2.go index 8fdc47f..7ae9027 100644 --- a/wisp/v2.go +++ b/wisp/v2.go @@ -1,15 +1,16 @@ package wisp import ( - "crypto/ed25519" - "crypto/rand" - "crypto/sha256" + "crypto/subtle" "encoding/binary" "errors" + "time" ) var errorInvalid = errors.New("invalid wisp v2 payload") +const v2HandshakeTimeout = 15 * time.Second + type extensions struct { udp bool streamConfirm bool @@ -26,7 +27,7 @@ type extensions struct { func (c *wispConnection) buildServerInfoPacket() []byte { var extensions []byte - if !c.config.DisableUDP { + if c.config.AllowUDP { extensions = addExtension(extensions, extensionUDP, nil) } @@ -38,28 +39,10 @@ func (c *wispConnection) buildServerInfoPacket() []byte { extensions = addExtension(extensions, extensionPasswordAuth, meta[:]) } - if c.config.CertAuth && len(c.config.CertAuthPublicKeys) > 0 { - challenge := make([]byte, 64) - rand.Read(challenge) - c.v2Challenge = challenge - - meta := make([]byte, 2+len(challenge)) - if c.config.CertAuthRequired { - meta[0] = 1 - } - meta[1] = sigEd25519 - copy(meta[2:], challenge) - extensions = addExtension(extensions, extensionCertificateAuth, meta) - } - if c.config.Motd != "" { extensions = addExtension(extensions, extensionMotd, []byte(c.config.Motd)) } - if c.config.EnableStreamConfirm { - extensions = addExtension(extensions, extensionStreamConfirm, nil) - } - payload := make([]byte, 5+2+len(extensions)) payload[0] = packetTypeInfo payload[5] = wispMajorVersion @@ -87,7 +70,10 @@ func parseClientInfo(payload []byte) (*extensions, error) { exts := &extensions{} data := payload[2:] - for len(data) >= 5 { + for len(data) > 0 { + if len(data) < 5 { + return nil, errorInvalid + } extID := data[0] extLen := binary.LittleEndian.Uint32(data[1:5]) data = data[5:] @@ -143,6 +129,7 @@ func parseClientInfo(payload []byte) (*extensions, error) { func (c *wispConnection) v2Handshake() { c.handshakeDone = make(chan struct{}) + _ = c.netConn.SetReadDeadline(time.Now().Add(v2HandshakeTimeout)) infoPayload := c.buildServerInfoPacket() c.sendRawFrame(infoPayload) @@ -154,6 +141,9 @@ func (c *wispConnection) handleInfo(streamId uint32, payload []byte) { if streamId != 0 { return } + if c.handshakeDone == nil { + return + } clientExts, err := parseClientInfo(payload) if err != nil { @@ -162,12 +152,15 @@ func (c *wispConnection) handleInfo(streamId uint32, payload []byte) { return } - authRequired := c.config.PasswordAuthRequired || c.config.CertAuthRequired + authRequired := c.config.PasswordAuthRequired authPassed := false if c.config.PasswordAuth && clientExts.passwordUsername != "" { expectedPassword, userExists := c.config.PasswordUsers[clientExts.passwordUsername] - if userExists && expectedPassword == clientExts.passwordPassword { + expBytes := []byte(expectedPassword) + gotBytes := []byte(clientExts.passwordPassword) + ok := userExists && len(expBytes) == len(gotBytes) && subtle.ConstantTimeCompare(expBytes, gotBytes) == 1 + if ok { authPassed = true } else { c.sendClosePacket(0, closeReasonAuthBadPassword) @@ -176,46 +169,21 @@ func (c *wispConnection) handleInfo(streamId uint32, payload []byte) { } } - if c.config.CertAuth && len(clientExts.certificateSig) > 0 && c.v2Challenge != nil { - if c.verifyCertificate(clientExts) { - authPassed = true - } else { - c.sendClosePacket(0, closeReasonAuthBadSignature) - c.close() - return - } - } - if authRequired && !authPassed { c.sendClosePacket(0, closeReasonAuthRequired) c.close() return } - c.streamConfirm = c.config.EnableStreamConfirm && clientExts.streamConfirm + c.authenticated.Store(authPassed) + c.streamConfirm = clientExts.streamConfirm c.sendPacket(0, c.config.BufferRemainingLength) + _ = c.netConn.SetReadDeadline(time.Time{}) close(c.handshakeDone) + c.handshakeDone = nil } - -func (c *wispConnection) verifyCertificate(exts *extensions) bool { - if exts.certificateSelected&sigEd25519 == 0 { - return false - } - - for _, pubKey := range c.config.CertAuthPublicKeys { - hash := sha256.Sum256([]byte(pubKey)) - if hash == exts.certificatePubkeyHash { - if ed25519.Verify(pubKey, c.v2Challenge, exts.certificateSig) { - return true - } - } - } - - return false -} - func (c *wispConnection) sendRawFrame(packet []byte) { totalLen := len(packet) var frame []byte diff --git a/wisp/windows.go b/wisp/windows.go new file mode 100644 index 0000000..c826c52 --- /dev/null +++ b/wisp/windows.go @@ -0,0 +1,62 @@ +//go:build windows + +package wisp + +import ( + "sync" + "sync/atomic" +) + +type twispStream struct { + wispConn *wispConnection + streamId uint32 + isOpen atomic.Bool +} + +type twispRegistry struct { + mu sync.RWMutex + streams map[uint32]*twispStream +} + +func newTwisp() *twispRegistry { + return &twispRegistry{ + streams: make(map[uint32]*twispStream), + } +} + +func (r *twispRegistry) add(id uint32, s *twispStream) { + r.mu.Lock() + r.streams[id] = s + r.mu.Unlock() +} + +func (r *twispRegistry) remove(id uint32) { + r.mu.Lock() + delete(r.streams, id) + r.mu.Unlock() +} + +func (r *twispRegistry) get(id uint32) *twispStream { + r.mu.RLock() + s := r.streams[id] + r.mu.RUnlock() + return s +} + +func handleTwisp(wc *wispConnection, streamId uint32, command string) { + wc.sendClosePacket(streamId, closeReasonBlocked) +} + +func (ts *twispStream) writePty(data []byte) error { + return nil +} + +func (ts *twispStream) resize(rows, cols uint16) {} + +func (ts *twispStream) close(reason uint8) { + if !ts.isOpen.CompareAndSwap(true, false) { + return + } + ts.wispConn.twispStreams.remove(ts.streamId) + ts.wispConn.sendClosePacket(ts.streamId, reason) +} diff --git a/wisp/wisp-connection.go b/wisp/wisp-connection.go index c2dd0d4..2603a74 100644 --- a/wisp/wisp-connection.go +++ b/wisp/wisp-connection.go @@ -4,17 +4,54 @@ import ( "encoding/binary" "net" "strconv" - "strings" "sync" "sync/atomic" + "time" "unsafe" + + prot "mrrowisp/wisp/protection" +) + +const ( + maxConnectsPerSecond = 20 + connectRateWindow = time.Second + minFramePoolCap = 64 * 1024 ) +type connectRateLimiter struct { + mutex sync.Mutex + windowStart time.Time + count int + limit int +} + +func newConnectRateLimiter(limit int) *connectRateLimiter { + if limit <= 0 { + limit = maxConnectsPerSecond + } + return &connectRateLimiter{windowStart: time.Now(), limit: limit} +} + +func (r *connectRateLimiter) allow() bool { + r.mutex.Lock() + defer r.mutex.Unlock() + now := time.Now() + if now.Sub(r.windowStart) >= connectRateWindow { + r.windowStart = now + r.count = 0 + } + r.count++ + return r.count <= r.limit +} + type writeReq struct { data []byte pool bool } +const maxConcurrentDials = 50 +const maxPendingStreamBytes = 16 * 1024 * 1024 + type wispConnection struct { netConn net.Conn writeCh chan writeReq @@ -22,35 +59,60 @@ type wispConnection struct { cachedStreamId uint32 cachedStream unsafe.Pointer isClosed atomic.Bool + shutdownOnce sync.Once config *Config twispStreams *twispRegistry + connectLimiter *connectRateLimiter + remoteIP string isV2 bool handshakeDone chan struct{} streamConfirm bool v2Challenge []byte + authenticated atomic.Bool + + dialSem chan struct{} + closeCh chan struct{} + createdAt time.Time + packetLimiter *prot.PacketRateLimiter + inboundLimiter *prot.InboundRateLimiter + streamCount atomic.Int32 } func (c *wispConnection) close() { - if !c.isClosed.CompareAndSwap(false, true) { - return - } - c.netConn.Close() + c.shutdownOnce.Do(func() { + c.isClosed.Store(true) + close(c.closeCh) + c.netConn.Close() + }) } func (c *wispConnection) writeLoop() { for req := range c.writeCh { - bufs := net.Buffers{req.data} + reqs := []writeReq{req} n := len(c.writeCh) for i := 0; i < n; i++ { - r := <-c.writeCh + reqs = append(reqs, <-c.writeCh) + } + bufs := make(net.Buffers, 0, len(reqs)) + for _, r := range reqs { bufs = append(bufs, r.data) } + // if cfg.config != nil { + // _ = cfg.netConn.SetWriteDeadline(time.Now().Add(cfg.config.WriteTimeout)) + // } if _, err := bufs.WriteTo(c.netConn); err != nil { - c.isClosed.Store(true) - c.netConn.Close() + c.close() return } + // if cfg.config != nil && cfg.config.WriteTimeout > 0 { + // _ = cfg.netConn.SetWriteDeadline(time.Time{}) + // } + for _, r := range reqs { + if r.pool { + c.releaseFrame(r.data) + } + } } } @@ -61,7 +123,43 @@ func (c *wispConnection) queueWrite(data []byte) { defer func() { recover() }() - c.writeCh <- writeReq{data: data} + select { + case c.writeCh <- writeReq{data: data}: + case <-c.closeCh: + return + } +} + +func (c *wispConnection) queueWritePooled(data []byte) { + if c.isClosed.Load() { + c.releaseFrame(data) + return + } + defer func() { + if recover() != nil { + c.releaseFrame(data) + } + }() + select { + case c.writeCh <- writeReq{data: data, pool: true}: + case <-c.closeCh: + c.releaseFrame(data) + return + } +} + +func (c *wispConnection) releaseFrame(data []byte) { + if c.config == nil || len(data) == 0 { + return + } + if cap(data) < minFramePoolCap { + return + } + buf := data + if len(buf) != cap(buf) { + buf = data[:cap(data)] + } + // cfg.config.FramePool.Put(buf) } func (c *wispConnection) handlePacket(packetType uint8, streamId uint32, payload []byte) { @@ -90,15 +188,18 @@ func (c *wispConnection) handleConnectPacket(streamId uint32, payload []byte) { if len(payload) < 3 { return } + guard := newProtection(c.config) streamType := payload[0] port := strconv.FormatUint(uint64(binary.LittleEndian.Uint16(payload[1:3])), 10) hostname := string(payload[3:]) - if streamType == streamTypeTerm { - if !c.config.EnableTwisp { - c.sendClosePacket(streamId, closeReasonBlocked) - return - } + c.config.Logger.Debug("creating stream", "ip", c.remoteIP, "streamId", streamId, "hostname", hostname, "port", port, "type", streamType) + action, normalizedHostname, reason := guard.allowConnect(c, streamType, hostname, port) + if action == connectBlocked { + c.sendClosePacket(streamId, reason) + return + } + if action == connectTwisp { go handleTwisp(c, streamId, hostname) return } @@ -107,7 +208,7 @@ func (c *wispConnection) handleConnectPacket(streamId uint32, payload []byte) { wispConn: c, streamId: streamId, connReady: make(chan struct{}), - hostname: strings.ToLower(strings.TrimSpace(hostname)), + hostname: normalizedHostname, } stream.isOpen.Store(true) @@ -116,10 +217,24 @@ func (c *wispConnection) handleConnectPacket(streamId uint32, payload []byte) { return } - go stream.handleConnect(streamType, port, hostname) + c.streamCount.Add(1) + go stream.handleConnect(streamType, port, normalizedHostname) } func (c *wispConnection) handleDataPacket(streamId uint32, payload []byte) { + guard := newProtection(c.config) + if c.packetLimiter != nil && !c.packetLimiter.Allow() { + c.sendClosePacket(streamId, closeReasonThrottled) + return + } + if c.inboundLimiter != nil && !c.inboundLimiter.Allow(len(payload)) { + c.sendClosePacket(streamId, closeReasonThrottled) + return + } + if !guard.allowMessageSize(len(payload)) { + c.sendClosePacket(streamId, closeReasonInvalidInfo) + return + } var stream *wispStream if c.cachedStreamId == streamId { stream = (*wispStream)(atomic.LoadPointer(&c.cachedStream)) @@ -136,7 +251,7 @@ func (c *wispConnection) handleDataPacket(streamId uint32, payload []byte) { return } } - go c.sendClosePacket(streamId, closeReasonInvalidInfo) + c.sendClosePacket(streamId, closeReasonInvalidInfo) return } stream = v.(*wispStream) @@ -148,14 +263,21 @@ func (c *wispConnection) handleDataPacket(streamId uint32, payload []byte) { return } + stream.pendingMutex.Lock() if !stream.connReadyDone.Load() { + if stream.pendingBytes+len(payload) > maxPendingStreamBytes { + stream.pendingMutex.Unlock() + stream.close(closeReasonThrottled) + return + } dataCopy := make([]byte, len(payload)) copy(dataCopy, payload) - stream.pendingMutex.Lock() stream.pendingData = append(stream.pendingData, dataCopy) + stream.pendingBytes += len(dataCopy) stream.pendingMutex.Unlock() return } + stream.pendingMutex.Unlock() _, err := stream.conn.Write(payload) if err != nil { @@ -166,12 +288,16 @@ func (c *wispConnection) handleDataPacket(streamId uint32, payload []byte) { if stream.streamType == streamTypeTCP { stream.bufferRemaining-- if stream.bufferRemaining == 0 { - stream.bufferRemaining = c.config.BufferRemainingLength + // stream.bufferRemaining = c.config.BufferRemainingLength c.sendPacket(streamId, stream.bufferRemaining) } } } +func (c *wispConnection) twispAuthorized() bool { + return c.isV2 && c.authenticated.Load() +} + func (c *wispConnection) handleClosePacket(streamId uint32, payload []byte) { if len(payload) < 1 { return @@ -241,21 +367,27 @@ func (c *wispConnection) deleteWispStream(streamId uint32) { if c.cachedStreamId == streamId { atomic.StorePointer(&c.cachedStream, nil) } + c.streamCount.Add(-1) } func (c *wispConnection) deleteAllWispStreams() { - c.isClosed.Store(true) + c.close() + c.config.Logger.Info("connection closed", "ip", c.remoteIP) c.streams.Range(func(key, value any) bool { stream := value.(*wispStream) stream.close(closeReasonUnspecified) return true }) if c.twispStreams != nil { - c.twispStreams.mu.RLock() + c.twispStreams.mu.Lock() + streams := make([]*twispStream, 0, len(c.twispStreams.streams)) for _, ts := range c.twispStreams.streams { + streams = append(streams, ts) + } + c.twispStreams.mu.Unlock() + for _, ts := range streams { ts.close(closeReasonUnspecified) } - c.twispStreams.mu.RUnlock() } defer func() { recover() }() close(c.writeCh) diff --git a/wisp/wisp-stream.go b/wisp/wisp-stream.go index ed9a92e..9d73d71 100644 --- a/wisp/wisp-stream.go +++ b/wisp/wisp-stream.go @@ -7,6 +7,9 @@ import ( "strings" "sync" "sync/atomic" + "time" + + prot "mrrowisp/wisp/protection" "golang.org/x/net/proxy" ) @@ -27,40 +30,58 @@ type wispStream struct { pendingMutex sync.Mutex pendingData [][]byte + pendingBytes int } +const dnsLookupTimeout = 10 * time.Second + func (s *wispStream) handleConnect(streamType uint8, port string, hostname string) { defer s.signalConnReady() cfg := s.wispConn.config - s.hostname = strings.ToLower(strings.TrimSpace(hostname)) + s.hostname = prot.NormalizeTargetHostname(hostname) + if s.hostname == "" { + s.close(closeReasonInvalidInfo) + return + } - if len(cfg.Whitelist.Hostnames) > 0 { - if _, ok := cfg.Whitelist.Hostnames[s.hostname]; !ok { - s.close(closeReasonBlocked) + guard := newProtection(cfg) + + if reason, ok := guard.allowHostPort(s.hostname, port); !ok { + s.close(reason) + return + } + + resolvedHostname := s.hostname + + if ip := net.ParseIP(resolvedHostname); ip != nil { + if reason, ok := guard.allowDirectIP(ip, s.wispConn.remoteIP, s.hostname); !ok { + s.close(reason) return } - } else if len(cfg.Blacklist.Hostnames) > 0 { - if _, ok := cfg.Blacklist.Hostnames[s.hostname]; ok { - s.close(closeReasonBlocked) + resolvedHostname = ip.String() + } else if cfg.Proxy != "" { + resolvedHostname = s.hostname + } else if cfg.DNSCache != nil { + ctx, cancel := context.WithTimeout(context.Background(), dnsLookupTimeout) + ips, err := cfg.DNSCache.LookupIPAddr(ctx, resolvedHostname) + cancel() + if err != nil { + cfg.Logger.Warn("DNS lookup failed", "ip", s.wispConn.remoteIP, "hostname", resolvedHostname, "error", err) + s.close(closeReasonUnreachable) return } - } - - resolvedHostname := hostname - if cfg.DNSCache != nil { - if _, whitelisted := cfg.Whitelist.Hostnames[hostname]; !whitelisted { - ips, err := cfg.DNSCache.LookupIPAddr(context.Background(), hostname) - if err != nil { - s.close(closeReasonUnreachable) - return - } - if len(ips) == 0 { - s.close(closeReasonUnreachable) - return - } - resolvedHostname = ips[0].IP.String() + if len(ips) == 0 { + cfg.Logger.Warn("DNS returned no results", "ip", s.wispConn.remoteIP, "hostname", resolvedHostname) + s.close(closeReasonUnreachable) + return } + selected, reason, ok := guard.selectAllowedIP(ips, s.wispConn.remoteIP, resolvedHostname) + if !ok { + s.close(reason) + return + } + resolvedHostname = selected } s.streamType = streamType @@ -71,18 +92,29 @@ func (s *wispStream) handleConnect(streamType uint8, port string, hostname strin var err error switch streamType { case streamTypeTCP: + select { + case s.wispConn.dialSem <- struct{}{}: + case <-s.wispConn.closeCh: + return + } if cfg.Proxy != "" { - dialer, proxyErr := proxy.SOCKS5("tcp", cfg.Proxy, nil, proxy.Direct) + proxyURL := cfg.Proxy + proxyURL = strings.Replace(proxyURL, "socks5h://", "socks5://", 1) + proxyURL = strings.Replace(proxyURL, "socks4a://", "socks4://", 1) + dialer, proxyErr := proxy.SOCKS5("tcp", stripScheme(proxyURL), nil, proxy.Direct) if proxyErr != nil { + <-s.wispConn.dialSem + cfg.Logger.Warn("proxy dialer creation failed", "ip", s.wispConn.remoteIP, "error", proxyErr) s.close(closeReasonNetworkError) return } - s.conn, err = dialer.Dial("tcp", destination) + s.conn, err = dialer.Dial("tcp", net.JoinHostPort(s.hostname, port)) } else { s.conn, err = cfg.Dialer.Dial("tcp", destination) } + <-s.wispConn.dialSem case streamTypeUDP: - if cfg.DisableUDP || cfg.Proxy != "" { + if cfg.Proxy != "" || !cfg.AllowUDP { s.close(closeReasonBlocked) return } @@ -93,6 +125,7 @@ func (s *wispStream) handleConnect(streamType uint8, port string, hostname strin } if err != nil { + cfg.Logger.Warn("stream connection failed", "ip", s.wispConn.remoteIP, "hostname", hostname, "port", port, "error", err) s.close(mapDialError(err)) return } @@ -109,12 +142,12 @@ func (s *wispStream) handleConnect(streamType uint8, port string, hostname strin s.wispConn.sendPacket(s.streamId, s.bufferRemaining) } - s.signalConnReady() - s.pendingMutex.Lock() pending := s.pendingData s.pendingData = nil + s.pendingBytes = 0 s.pendingMutex.Unlock() + for _, data := range pending { if !s.isOpen.Load() { return @@ -125,9 +158,19 @@ func (s *wispStream) handleConnect(streamType uint8, port string, hostname strin } } + // Signal ready only after all pending data has been written in order. + s.signalConnReady() + s.readFromConnection() } +func stripScheme(url string) string { + if idx := strings.Index(url, "://"); idx >= 0 { + return url[idx+3:] + } + return url +} + func (s *wispStream) signalConnReady() { if s.connReadyDone.CompareAndSwap(false, true) { close(s.connReady) @@ -136,9 +179,7 @@ func (s *wispStream) signalConnReady() { func (s *wispStream) readFromConnection() { const maxHeaderLen = 15 - bufp := s.wispConn.config.ReadBufPool.Get().(*[]byte) - buf := *bufp - defer s.wispConn.config.ReadBufPool.Put(bufp) + buf := make([]byte, maxHeaderLen+65535) streamId := s.streamId @@ -162,10 +203,10 @@ func (s *wispStream) readFromConnection() { frameStart = 0 buf[0] = 0x82 buf[1] = 127 - buf[2] = 0 - buf[3] = 0 - buf[4] = 0 - buf[5] = 0 + buf[2] = byte(totalPayload >> 56) + buf[3] = byte(totalPayload >> 48) + buf[4] = byte(totalPayload >> 40) + buf[5] = byte(totalPayload >> 32) buf[6] = byte(totalPayload >> 24) buf[7] = byte(totalPayload >> 16) buf[8] = byte(totalPayload >> 8) @@ -181,12 +222,13 @@ func (s *wispStream) readFromConnection() { frame := make([]byte, maxHeaderLen+n-frameStart) copy(frame, buf[frameStart:maxHeaderLen+n]) - s.wispConn.queueWrite(frame) + s.wispConn.queueWritePooled(frame) } if err != nil { if err == io.EOF { s.close(closeReasonVoluntary) } else { + s.wispConn.config.Logger.Warn("stream read error", "ip", s.wispConn.remoteIP, "hostname", s.hostname, "error", err) s.close(closeReasonNetworkError) } return diff --git a/wisp/wisp.go b/wisp/wisp.go index fd1ee3d..dc3de61 100644 --- a/wisp/wisp.go +++ b/wisp/wisp.go @@ -1,65 +1,38 @@ package wisp import ( - "crypto/ed25519" "net" "net/http" - "sync" + "strings" "time" + "mrrowisp/wisp/protection" + "github.com/lxzan/gws" ) -type Config struct { - DisableUDP bool - - TcpBufferSize int - BufferRemainingLength uint32 - TcpNoDelay bool - WebsocketTcpNoDelay bool - - Blacklist struct { - Hostnames map[string]struct{} - } - Whitelist struct { - Hostnames map[string]struct{} - } - - Proxy string - WebsocketPermessageDeflate bool - - DnsServers []string - - EnableTwisp bool - - EnableV2 bool - Motd string - PasswordAuth bool - PasswordAuthRequired bool - PasswordUsers map[string]string - CertAuth bool - CertAuthRequired bool - CertAuthPublicKeys []ed25519.PublicKey - EnableStreamConfirm bool - - DNSCache *DNSCache - ReadBufPool sync.Pool - Dialer net.Dialer -} - -func DefaultConfig() *Config { - return &Config{ - DisableUDP: false, - TcpBufferSize: 32768, - BufferRemainingLength: 65536, - TcpNoDelay: true, - WebsocketTcpNoDelay: true, - PasswordUsers: make(map[string]string), - } -} +const ( + defaultStreamLimitPerHost = 512 + defaultStreamLimitTotal = 16384 + defaultMaxConnectsPerSecond = 20 + defaultConnectionsLimitPerIP = 120 + defaultHandshakeFailures = 10 +) -func (c *Config) InitResolver() { - c.DNSCache = NewDNSCache(c.DnsServers) +func (cfg *Config) InitResolver() { + cfg.DNSCache = NewDNSCache( + DNSCacheConfig{ + Servers: cfg.DnsServers, + Method: cfg.DnsMethod, + ResultOrder: cfg.DnsResultOrder, + }) + // if cfg.BandwidthLimitKbps > 0 { + // cfg.BandwidthLimiter = protection.NewBandwidthLimiter(cfg.BandwidthLimitKbps, time.Duration(cfg.ConnectionWindowSeconds)*time.Second) + // } + // if cfg.ConnectionsLimitPerIP > 0 { + // cfg.ConnectionLimiter = protection.NewConnectionLimiter(cfg.ConnectionsLimitPerIP, time.Duration(cfg.ConnectionWindowSeconds)*time.Second) + // } + cfg.Logger = newLogger(cfg.LogLevel) } type upgradeHandler struct { @@ -69,51 +42,67 @@ type upgradeHandler struct { func CreateWispHandler(config *Config) http.HandlerFunc { config.InitResolver() - readBufSize := 15 + config.TcpBufferSize - config.ReadBufPool = sync.Pool{ - New: func() any { - buf := make([]byte, readBufSize) - return &buf - }, - } - - config.Dialer = net.Dialer{ - Timeout: 15 * time.Second, - KeepAlive: 30 * time.Second, - } - upgrader := gws.NewUpgrader(&upgradeHandler{}, &gws.ServerOption{ PermessageDeflate: gws.PermessageDeflate{ - Enabled: config.WebsocketPermessageDeflate, + Enabled: false, }, }) + guard := newProtection(config) + return func(w http.ResponseWriter, r *http.Request) { useV2 := config.EnableV2 && r.Header.Get("Sec-WebSocket-Protocol") != "" + remoteIP := protection.RemoteIPFromRequest(r, protection.IPConfig{ + AllowDirectIP: config.AllowDirectIP, + AllowPrivateIPs: config.AllowPrivateIPs, + AllowLoopbackIPs: config.AllowLoopbackIPs, + ParseRealIP: config.ParseRealIP, + }) + config.Logger.Info("incoming connection", "ip", remoteIP, "path", r.URL.Path, "origin", r.Header.Get("Origin")) + if config.requiresV2() && !useV2 { + config.Logger.Warn("v2 required but not negotiated", "ip", remoteIP) + w.WriteHeader(http.StatusUnauthorized) + return + } + + if status, response, ok := guard.allowHTTP(r, remoteIP, useV2); !ok { + w.WriteHeader(status) + if response != "" { + _, _ = w.Write([]byte(response)) + } + return + } wsConn, err := upgrader.Upgrade(w, r) if err != nil { + if config.NonWSResponse != "" { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(config.NonWSResponse)) + } + config.Logger.Debug("websocket upgrade failed", "error", err) return } netConn := wsConn.NetConn() if tc, ok := netConn.(*net.TCPConn); ok { - if config.WebsocketTcpNoDelay { - tc.SetNoDelay(true) - } tc.SetReadBuffer(1 << 20) tc.SetWriteBuffer(1 << 20) } wc := &wispConnection{ - netConn: netConn, - writeCh: make(chan writeReq, 4096), // funny number + netConn: netConn, + // writeCh: make(chan writeReq, writeQSize), config: config, twispStreams: newTwisp(), isV2: useV2, + remoteIP: remoteIP, + dialSem: make(chan struct{}, maxConcurrentDials), + closeCh: make(chan struct{}), + createdAt: time.Now(), } + config.Logger.Info("connection established", "ip", remoteIP, "v2", useV2) go wc.writeLoop() if useV2 { @@ -124,3 +113,26 @@ func CreateWispHandler(config *Config) http.HandlerFunc { } } } + +func (cfg *Config) requiresV2() bool { + if cfg == nil { + return false + } + return cfg.PasswordAuthRequired || cfg.EnableTwisp +} + +func originAllowed(r *http.Request, allowedOrigins []string) bool { + if len(allowedOrigins) == 0 { + return true + } + origin := strings.TrimSpace(r.Header.Get("Origin")) + if origin == "" { + return false + } + for _, allowed := range allowedOrigins { + if origin == strings.TrimSpace(allowed) { + return true + } + } + return false +} diff --git a/wisp/wsreader.go b/wisp/wsreader.go index 4ab00e6..df9b66f 100644 --- a/wisp/wsreader.go +++ b/wisp/wsreader.go @@ -16,14 +16,24 @@ func (c *wispConnection) readLoop() { var headerBuffer [14]byte for { + // if c.config != nil && c.config.FrameReadTimeout > 0 { + // _ = c.netConn.SetReadDeadline(time.Now().Add(c.config.FrameReadTimeout)) + // } if _, err := io.ReadFull(reader, headerBuffer[:2]); err != nil { return } - data := headerBuffer[0] & 0x0F + fin := headerBuffer[0]&0x80 != 0 + rsv := headerBuffer[0] & 0x70 + opcode := headerBuffer[0] & 0x0F masked := headerBuffer[1]&0x80 != 0 lengthCode := headerBuffer[1] & 0x7F + if rsv != 0 || !masked || !fin { + c.sendWSClose(1002) + return + } + var payloadLen uint64 switch { case lengthCode <= 125: @@ -40,6 +50,12 @@ func (c *wispConnection) readLoop() { payloadLen = binary.BigEndian.Uint64(headerBuffer[2:10]) } + isControlFrame := opcode >= 0x8 + if isControlFrame && payloadLen > 125 { + c.sendWSClose(1002) + return + } + var maskKey [4]byte if masked { if _, err := io.ReadFull(reader, maskKey[:]); err != nil { @@ -47,6 +63,11 @@ func (c *wispConnection) readLoop() { } } + if payloadLen > c.maxPayloadSize() { + c.sendWSClose(1009) + return + } + var payload []byte if payloadLen <= PayloadBufferSize { payload = PayloadBuffer[:payloadLen] @@ -64,7 +85,7 @@ func (c *wispConnection) readLoop() { maskXOR(payload, maskKey) } - switch data { + switch opcode { case 0x2: c.handleWispFrame(payload) @@ -87,9 +108,21 @@ func (c *wispConnection) readLoop() { c.handleWispFrame(payload) default: + if opcode != 0x0 { + } continue } + + } +} + +const DefaultMaxPayloadSize = 256 * 1024 + +func (c *wispConnection) maxPayloadSize() uint64 { + if c != nil && c.config != nil && c.config.MaxMessageSize > 0 { + return uint64(c.config.MaxMessageSize) } + return DefaultMaxPayloadSize } func (c *wispConnection) handleWispFrame(packet []byte) { From 568c34b3983b9d82ba4e3390e4e9a6cb28e1c535 Mon Sep 17 00:00:00 2001 From: Sophia <193290223+soap-phia@users.noreply.github.com> Date: Tue, 19 May 2026 02:37:02 +0000 Subject: [PATCH 2/2] roll back prot features --- main.go | 2 +- wisp/config.go | 31 +++--- wisp/protection.go | 173 ------------------------------- wisp/protection/banlist.go | 95 ----------------- wisp/protection/iputil.go | 150 --------------------------- wisp/protection/limits.go | 177 -------------------------------- wisp/protection/streamlimits.go | 38 ------- wisp/twisp.go | 4 +- wisp/v2.go | 8 +- wisp/wisp-connection.go | 114 +++++--------------- wisp/wisp-stream.go | 78 ++++++-------- wisp/wisp.go | 75 +++----------- wisp/wsreader.go | 3 - 13 files changed, 91 insertions(+), 857 deletions(-) delete mode 100644 wisp/protection.go delete mode 100644 wisp/protection/banlist.go delete mode 100644 wisp/protection/iputil.go delete mode 100644 wisp/protection/limits.go delete mode 100644 wisp/protection/streamlimits.go diff --git a/main.go b/main.go index 31d5e41..072ab82 100644 --- a/main.go +++ b/main.go @@ -39,7 +39,7 @@ func main() { cfg.AllowLoopbackIPs = *fAllowLoopbackIPs } - wispConfig := wisp.CreateWispConfig(cfg) + wispConfig := wisp.CreateWispConfig(&cfg) wispHandler := wisp.CreateWispHandler(wispConfig) diff --git a/wisp/config.go b/wisp/config.go index c47dadc..8e4f85d 100644 --- a/wisp/config.go +++ b/wisp/config.go @@ -5,6 +5,7 @@ import ( "net" "os" "strings" + "sync" ) type FilterList struct { @@ -25,8 +26,16 @@ type Config struct { TcpBufferSize int `json:"tcpBufferSize"` TcpNoDelay bool `json:"tcpNoDelay"` - Blacklist FilterList `json:"blacklist"` - Whitelist FilterList `json:"whitelist"` + Blacklist struct { + Hostnames map[string]struct{} + Ports map[uint16]struct{} + } + Whitelist struct { + Hostnames map[string]struct{} + Ports map[uint16]struct{} + } + + WebsocketPermessageDeflate bool DnsServers []string `json:"dnsServers"` DnsMethod string `json:"dnsMethod"` @@ -54,9 +63,10 @@ type Config struct { BufferRemainingLength uint32 `json:"bufferRemainingLength"` - Logger Logger - DNSCache *DNSCache - Dialer net.Dialer + Logger Logger + DNSCache *DNSCache + ReadBufPool *sync.Pool + Dialer net.Dialer } func DefaultConfig() Config { @@ -73,15 +83,6 @@ func DefaultConfig() Config { TcpBufferSize: 32768, TcpNoDelay: true, - Blacklist: FilterList{ - Hostnames: []string{}, - Ports: []interface{}{}, - }, - Whitelist: FilterList{ - Hostnames: []string{}, - Ports: []interface{}{}, - }, - DnsServers: []string{}, DnsMethod: "resolve", DnsResultOrder: "ipv4first", @@ -109,7 +110,7 @@ func DefaultConfig() Config { } } -func CreateWispConfig(cfg Config) *Config { +func CreateWispConfig(cfg *Config) *Config { wispCfg := &Config{ AllowTCP: cfg.AllowTCP, AllowUDP: cfg.AllowUDP, diff --git a/wisp/protection.go b/wisp/protection.go deleted file mode 100644 index b033fe8..0000000 --- a/wisp/protection.go +++ /dev/null @@ -1,173 +0,0 @@ -package wisp - -import ( - "net" - "net/http" - "strconv" - "strings" - - prot "mrrowisp/wisp/protection" -) - -type guard struct { - config *Config -} - -type connectAction uint8 - -const ( - connectBlocked connectAction = iota - connectStream - connectTwisp -) - -func newProtection(config *Config) *guard { - return &guard{config: config} -} - -func (p *guard) allowHTTP(r *http.Request, remoteIP string, useV2 bool) (int, string, bool) { - cfg := p.config - - if !isWebsocketUpgrade(r) { - if cfg.NonWSResponse != "" { - return http.StatusOK, cfg.NonWSResponse, false - } - return http.StatusBadRequest, "", false - } - - if !useV2 && cfg.requiresV2() { - cfg.Logger.Warn("websocket v1 downgrade blocked", "ip", remoteIP) - return http.StatusUnauthorized, cfg.NonWSResponse, false - } - - return 0, "", true -} - -func isWebsocketUpgrade(r *http.Request) bool { - return strings.Contains(strings.ToLower(r.Header.Get("Connection")), "upgrade") && - strings.ToLower(r.Header.Get("Upgrade")) == "websocket" -} - -func (p *guard) allowConnect(c *wispConnection, streamType uint8, hostname string, port string) (connectAction, string, uint8) { - cfg := p.config - - if len(hostname) > 2048 || strings.IndexByte(hostname, 0) >= 0 { - return connectBlocked, "", closeReasonInvalidInfo - } - - if !c.connectLimiter.allow() { - cfg.Logger.Warn("connect rate limit exceeded", "ip", c.remoteIP) - return connectBlocked, "", closeReasonThrottled - } - - if streamType == streamTypeTerm { - if !cfg.EnableTwisp || !c.twispAuthorized() { - cfg.Logger.Warn("terminal stream blocked", "ip", c.remoteIP) - return connectBlocked, "", closeReasonBlocked - } - return connectTwisp, "", 0 - } - - if streamType == streamTypeTCP && !cfg.AllowTCP { - cfg.Logger.Warn("TCP streams blocked", "ip", c.remoteIP, "hostname", hostname) - return connectBlocked, "", closeReasonBlocked - } - if streamType == streamTypeUDP && !cfg.AllowUDP { - cfg.Logger.Warn("UDP streams blocked", "ip", c.remoteIP, "hostname", hostname) - return connectBlocked, "", closeReasonBlocked - } - - normalizedHostname := prot.NormalizeTargetHostname(hostname) - if normalizedHostname == "" { - return connectBlocked, "", closeReasonInvalidInfo - } - - if !cfg.AllowLoopbackIPs && prot.IsOwnIP(normalizedHostname) { - cfg.Logger.Warn("self-targeting stream blocked", "ip", c.remoteIP, "hostname", hostname) - return connectBlocked, "", closeReasonBlocked - } - - return connectStream, normalizedHostname, 0 -} - -func (p *guard) allowHostPort(hostname string, port string) (uint8, bool) { - cfg := p.config - - portNum, err := strconv.Atoi(port) - if err != nil { - return closeReasonInvalidInfo, false - } - - if len(cfg.Whitelist.Ports) > 0 { - allowed := false - type portContains interface{ Contains(int) bool } - for _, r := range cfg.Whitelist.Ports { - if c, ok := r.(portContains); ok { - if c.Contains(portNum) { - allowed = true - break - } - } - } - if !allowed { - return closeReasonBlocked, false - } - } else if len(cfg.Blacklist.Ports) > 0 { - type portContains interface{ Contains(int) bool } - for _, r := range cfg.Blacklist.Ports { - if c, ok := r.(portContains); ok { - if c.Contains(portNum) { - return closeReasonBlocked, false - } - } - } - } - - return 0, true -} - -func (p *guard) allowDirectIP(ip net.IP, remoteIP string, hostname string) (uint8, bool) { - cfg := p.config - - if !cfg.AllowDirectIP { - return closeReasonBlocked, false - } - if !prot.IsAllowedTargetIP(ip, prot.IPConfig{ - AllowDirectIP: cfg.AllowDirectIP, - AllowPrivateIPs: cfg.AllowPrivateIPs, - AllowLoopbackIPs: cfg.AllowLoopbackIPs, - }) { - return closeReasonBlocked, false - } - if !cfg.AllowLoopbackIPs && prot.IsOwnIP(ip.String()) { - cfg.Logger.Warn("self-targeting stream blocked", "ip", remoteIP, "hostname", hostname) - return closeReasonBlocked, false - } - - return 0, true -} - -func (p *guard) selectAllowedIP(ips []net.IPAddr, remoteIP string, hostname string) (string, uint8, bool) { - cfg := p.config - - selected, ok := prot.FirstAllowedIP(ips, prot.IPConfig{ - AllowDirectIP: cfg.AllowDirectIP, - AllowPrivateIPs: cfg.AllowPrivateIPs, - AllowLoopbackIPs: cfg.AllowLoopbackIPs, - }) - if !ok { - cfg.Logger.Warn("DNS returned only blocked IPs", "ip", remoteIP, "hostname", hostname) - return "", closeReasonBlocked, false - } - if !cfg.AllowLoopbackIPs && prot.IsOwnIP(selected) { - cfg.Logger.Warn("self-targeting stream blocked", "ip", remoteIP, "hostname", hostname) - return "", closeReasonBlocked, false - } - - return selected, 0, true -} - -func (p *guard) allowMessageSize(size int) bool { - max := p.config.MaxMessageSize - return max <= 0 || size <= max -} diff --git a/wisp/protection/banlist.go b/wisp/protection/banlist.go deleted file mode 100644 index 0d7f50c..0000000 --- a/wisp/protection/banlist.go +++ /dev/null @@ -1,95 +0,0 @@ -package protection - -import ( - "net" - "net/http" - "sync" - "time" -) - -type BanList struct { - mutex sync.RWMutex - bans map[string]time.Time - banDur time.Duration - strikes map[string]int - maxStrikes int - escalationMultiplier int -} - -func NewBanList(banDuration time.Duration, maxStrikes int) *BanList { - return NewBanListEscalated(banDuration, maxStrikes, 0) -} - -func NewBanListEscalated(banDuration time.Duration, maxStrikes int, escalation int) *BanList { - if banDuration <= 0 { - banDuration = time.Hour - } - if maxStrikes <= 0 { - maxStrikes = 10 - } - b := &BanList{ - bans: make(map[string]time.Time), - strikes: make(map[string]int), - mutex: sync.RWMutex{}, - banDur: banDuration, - maxStrikes: maxStrikes, - escalationMultiplier: escalation, - } - go b.cleanup() - return b -} - -func (b *BanList) Strike(ip string) (banned bool) { - b.mutex.Lock() - defer b.mutex.Unlock() - b.strikes[ip]++ - if b.strikes[ip] >= b.maxStrikes { - dur := b.banDur - if b.escalationMultiplier > 0 { - for i := 1; i < b.strikes[ip]/b.maxStrikes; i++ { - dur *= time.Duration(b.escalationMultiplier) - } - } - b.bans[ip] = time.Now().Add(dur) - delete(b.strikes, ip) - return true - } - return false -} - -func (b *BanList) IsBanned(ip string) bool { - b.mutex.RLock() - defer b.mutex.RUnlock() - unbanAt, exists := b.bans[ip] - if !exists { - return false - } - return time.Now().Before(unbanAt) -} - -func (b *BanList) cleanup() { - for range time.Tick(5 * time.Minute) { - b.mutex.Lock() - now := time.Now() - for ip, unbanAt := range b.bans { - if now.After(unbanAt) { - delete(b.bans, ip) - } - } - b.mutex.Unlock() - } -} - -func (b *BanList) Middleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ip, _, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { - ip = r.RemoteAddr - } - if b.IsBanned(ip) { - http.Error(w, "banned", http.StatusForbidden) - return - } - next.ServeHTTP(w, r) - }) -} diff --git a/wisp/protection/iputil.go b/wisp/protection/iputil.go deleted file mode 100644 index 5ba0316..0000000 --- a/wisp/protection/iputil.go +++ /dev/null @@ -1,150 +0,0 @@ -package protection - -import ( - "net" - "net/http" - "strings" -) - -type IPConfig struct { - AllowDirectIP bool - AllowPrivateIPs bool - AllowLoopbackIPs bool - ParseRealIP bool - ParseRealIPFrom map[string]struct{} -} - -func RemoteIPFromRequest(r *http.Request, cfg IPConfig) string { - if r == nil { - return "" - } - if cfg.ParseRealIP { - if ip := parseForwardedIP(r, cfg.ParseRealIPFrom); ip != "" { - return ip - } - } - host, _, err := net.SplitHostPort(r.RemoteAddr) - if err == nil { - return host - } - return r.RemoteAddr -} - -func parseForwardedIP(r *http.Request, allowed map[string]struct{}) string { - if r == nil { - return "" - } - if len(allowed) == 0 { - return "" - } - - host, _, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { - return "" - } - if !isIPAllowed(host, allowed) { - return "" - } - - xff := r.Header.Get("X-Forwarded-For") - if xff != "" { - parts := strings.Split(xff, ",") - if len(parts) > 0 { - ip := strings.TrimSpace(parts[0]) - if net.ParseIP(ip) != nil { - return ip - } - } - } - - xrip := strings.TrimSpace(r.Header.Get("X-Real-IP")) - if xrip != "" && net.ParseIP(xrip) != nil { - return xrip - } - - return "" -} - -func isIPAllowed(ip string, allowed map[string]struct{}) bool { - if _, ok := allowed[ip]; ok { - return true - } - parsed := net.ParseIP(ip) - if parsed == nil { - return false - } - for entry := range allowed { - _, cidr, err := net.ParseCIDR(entry) - if err == nil && cidr.Contains(parsed) { - return true - } - } - return false -} - -func IsAllowedTargetIP(ip net.IP, cfg IPConfig) bool { - if ip == nil { - return false - } - if ip.IsUnspecified() || ip.IsMulticast() { - return false - } - if isCarrierGradeNAT(ip) || isBenchmarkingIP(ip) { - return false - } - if !cfg.AllowLoopbackIPs && ip.IsLoopback() { - return false - } - if !cfg.AllowPrivateIPs && (ip.IsPrivate() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast()) { - return false - } - return true -} - -func FirstAllowedIP(ips []net.IPAddr, cfg IPConfig) (string, bool) { - for _, addr := range ips { - if IsAllowedTargetIP(addr.IP, cfg) { - return addr.IP.String(), true - } - } - return "", false -} - -func isCarrierGradeNAT(ip net.IP) bool { - ipv4 := ip.To4() - if ipv4 == nil { - return false - } - return ipv4[0] == 100 && ipv4[1]&0xC0 == 0x40 -} - -func isBenchmarkingIP(ip net.IP) bool { - ipv4 := ip.To4() - if ipv4 == nil { - return false - } - return ipv4[0] == 198 && (ipv4[1] == 18 || ipv4[1] == 19) -} - -func NormalizeTargetHostname(host string) string { - host = strings.TrimSpace(strings.ToLower(host)) - host = strings.TrimSuffix(host, ".") - return host -} - -func IsOwnIP(resolvedIP string) bool { - ifaces, err := net.Interfaces() - if err != nil { - return false - } - for _, iface := range ifaces { - ifaceAddrs, _ := iface.Addrs() - for _, ifaceAddr := range ifaceAddrs { - ip, _, _ := net.ParseCIDR(ifaceAddr.String()) - if ip != nil && ip.String() == resolvedIP { - return true - } - } - } - return false -} diff --git a/wisp/protection/limits.go b/wisp/protection/limits.go deleted file mode 100644 index d7279d4..0000000 --- a/wisp/protection/limits.go +++ /dev/null @@ -1,177 +0,0 @@ -package protection - -import ( - "sync" - "time" -) - -type BandwidthLimiter struct { - mu sync.Mutex - window time.Duration - bytes map[string]uint64 - start time.Time - limit uint64 -} - -func NewBandwidthLimiter(kbps int, window time.Duration) *BandwidthLimiter { - if window <= 0 { - window = time.Second - } - limit := uint64(kbps) * 1024 - return &BandwidthLimiter{window: window, start: time.Now(), limit: limit, bytes: make(map[string]uint64)} -} - -func (b *BandwidthLimiter) Allow(ip string, n uint64) bool { - if b == nil || b.limit == 0 { - return true - } - b.mu.Lock() - defer b.mu.Unlock() - now := time.Now() - if now.Sub(b.start) >= b.window { - b.start = now - b.bytes = make(map[string]uint64) - } - used := b.bytes[ip] - if used+n > b.limit { - return false - } - b.bytes[ip] = used + n - return true -} - -type ConnectionLimiter struct { - mu sync.Mutex - window time.Duration - start time.Time - counts map[string]int - limit int -} - -func NewConnectionLimiter(limit int, window time.Duration) *ConnectionLimiter { - if window <= 0 { - window = time.Second - } - return &ConnectionLimiter{window: window, start: time.Now(), limit: limit, counts: make(map[string]int)} -} - -func (c *ConnectionLimiter) Allow(ip string) bool { - if c == nil || c.limit <= 0 { - return true - } - c.mu.Lock() - defer c.mu.Unlock() - now := time.Now() - if now.Sub(c.start) >= c.window { - c.start = now - c.counts = make(map[string]int) - } - c.counts[ip]++ - return c.counts[ip] <= c.limit -} - -type PacketRateLimiter struct { - mu sync.Mutex - interval time.Duration - limit int - count int - resetAt time.Time -} - -func NewPacketRateLimiter(packetsPerSec int) *PacketRateLimiter { - if packetsPerSec <= 0 { - packetsPerSec = 500 - } - return &PacketRateLimiter{ - interval: time.Second, - limit: packetsPerSec, - resetAt: time.Now().Add(time.Second), - } -} - -func (p *PacketRateLimiter) Allow() bool { - p.mu.Lock() - defer p.mu.Unlock() - now := time.Now() - if now.After(p.resetAt) { - p.count = 0 - p.resetAt = now.Add(p.interval) - } - p.count++ - return p.count <= p.limit -} - -type ConnectionCounter struct { - mu sync.Mutex - perIP map[string]int - global int -} - -func NewConnectionCounter() *ConnectionCounter { - return &ConnectionCounter{perIP: make(map[string]int)} -} - -func (c *ConnectionCounter) TryAdd(ip string, maxPerIP int, maxGlobal int) bool { - c.mu.Lock() - defer c.mu.Unlock() - if maxGlobal > 0 && c.global >= maxGlobal { - return false - } - if maxPerIP > 0 && c.perIP[ip] >= maxPerIP { - return false - } - c.perIP[ip]++ - c.global++ - return true -} - -func (c *ConnectionCounter) Remove(ip string) { - c.mu.Lock() - defer c.mu.Unlock() - if c.perIP[ip] > 0 { - c.perIP[ip]-- - if c.perIP[ip] <= 0 { - delete(c.perIP, ip) - } - } - if c.global > 0 { - c.global-- - } -} - -type InboundRateLimiter struct { - mu sync.Mutex - interval time.Duration - limit int - count int - resetAt time.Time -} - -func NewInboundRateLimiter(bytesPerSec int) *InboundRateLimiter { - if bytesPerSec <= 0 { - bytesPerSec = 0 - } - return &InboundRateLimiter{ - interval: time.Second, - limit: bytesPerSec, - resetAt: time.Now().Add(time.Second), - } -} - -func (r *InboundRateLimiter) Allow(n int) bool { - if r == nil || r.limit <= 0 { - return true - } - r.mu.Lock() - defer r.mu.Unlock() - now := time.Now() - if now.After(r.resetAt) { - r.count = 0 - r.resetAt = now.Add(r.interval) - } - if r.count+n > r.limit { - return false - } - r.count += n - return true -} diff --git a/wisp/protection/streamlimits.go b/wisp/protection/streamlimits.go deleted file mode 100644 index b55d633..0000000 --- a/wisp/protection/streamlimits.go +++ /dev/null @@ -1,38 +0,0 @@ -package protection - -import "sync" - -type StreamLimiter struct { - mutex sync.Mutex - pH map[string]int - total int -} - -func NewStreamLimiter() *StreamLimiter { - return &StreamLimiter{pH: make(map[string]int)} -} - -func (s *StreamLimiter) Allow(host string, perHostLimit int, totalLimit int) bool { - s.mutex.Lock() - defer s.mutex.Unlock() - if totalLimit > 0 && s.total >= totalLimit { - return false - } - if perHostLimit > 0 && s.pH[host] >= perHostLimit { - return false - } - s.total++ - s.pH[host]++ - return true -} - -func (s *StreamLimiter) Release(host string) { - s.mutex.Lock() - defer s.mutex.Unlock() - if s.total > 0 { - s.total-- - } - if s.pH[host] > 0 { - s.pH[host]-- - } -} diff --git a/wisp/twisp.go b/wisp/twisp.go index 2ac95d0..3da5283 100644 --- a/wisp/twisp.go +++ b/wisp/twisp.go @@ -91,7 +91,9 @@ func handleTwisp(wc *wispConnection, streamId uint32, command string) { func (ts *twispStream) readPty() { const maxHeaderLen = 15 - buf := make([]byte, maxHeaderLen+65535) + bufPool := ts.wispConn.config.ReadBufPool.Get().(*[]byte) + buf := *bufPool + defer ts.wispConn.config.ReadBufPool.Put(bufPool) streamId := ts.streamId diff --git a/wisp/v2.go b/wisp/v2.go index 7ae9027..017a036 100644 --- a/wisp/v2.go +++ b/wisp/v2.go @@ -9,8 +9,6 @@ import ( var errorInvalid = errors.New("invalid wisp v2 payload") -const v2HandshakeTimeout = 15 * time.Second - type extensions struct { udp bool streamConfirm bool @@ -70,10 +68,7 @@ func parseClientInfo(payload []byte) (*extensions, error) { exts := &extensions{} data := payload[2:] - for len(data) > 0 { - if len(data) < 5 { - return nil, errorInvalid - } + for len(data) >= 5 { extID := data[0] extLen := binary.LittleEndian.Uint32(data[1:5]) data = data[5:] @@ -129,7 +124,6 @@ func parseClientInfo(payload []byte) (*extensions, error) { func (c *wispConnection) v2Handshake() { c.handshakeDone = make(chan struct{}) - _ = c.netConn.SetReadDeadline(time.Now().Add(v2HandshakeTimeout)) infoPayload := c.buildServerInfoPacket() c.sendRawFrame(infoPayload) diff --git a/wisp/wisp-connection.go b/wisp/wisp-connection.go index 2603a74..e102469 100644 --- a/wisp/wisp-connection.go +++ b/wisp/wisp-connection.go @@ -4,54 +4,18 @@ import ( "encoding/binary" "net" "strconv" + "strings" "sync" "sync/atomic" "time" "unsafe" - - prot "mrrowisp/wisp/protection" -) - -const ( - maxConnectsPerSecond = 20 - connectRateWindow = time.Second - minFramePoolCap = 64 * 1024 ) -type connectRateLimiter struct { - mutex sync.Mutex - windowStart time.Time - count int - limit int -} - -func newConnectRateLimiter(limit int) *connectRateLimiter { - if limit <= 0 { - limit = maxConnectsPerSecond - } - return &connectRateLimiter{windowStart: time.Now(), limit: limit} -} - -func (r *connectRateLimiter) allow() bool { - r.mutex.Lock() - defer r.mutex.Unlock() - now := time.Now() - if now.Sub(r.windowStart) >= connectRateWindow { - r.windowStart = now - r.count = 0 - } - r.count++ - return r.count <= r.limit -} - type writeReq struct { data []byte pool bool } -const maxConcurrentDials = 50 -const maxPendingStreamBytes = 16 * 1024 * 1024 - type wispConnection struct { netConn net.Conn writeCh chan writeReq @@ -62,7 +26,6 @@ type wispConnection struct { shutdownOnce sync.Once config *Config twispStreams *twispRegistry - connectLimiter *connectRateLimiter remoteIP string isV2 bool @@ -71,48 +34,32 @@ type wispConnection struct { v2Challenge []byte authenticated atomic.Bool - dialSem chan struct{} - closeCh chan struct{} - createdAt time.Time - packetLimiter *prot.PacketRateLimiter - inboundLimiter *prot.InboundRateLimiter - streamCount atomic.Int32 + dialSem chan struct{} + closeCh chan struct{} + createdAt time.Time + streamCount atomic.Int32 } func (c *wispConnection) close() { - c.shutdownOnce.Do(func() { - c.isClosed.Store(true) - close(c.closeCh) - c.netConn.Close() - }) + if !c.isClosed.CompareAndSwap(false, true) { + return + } + c.netConn.Close() } func (c *wispConnection) writeLoop() { for req := range c.writeCh { - reqs := []writeReq{req} + bufs := net.Buffers{req.data} n := len(c.writeCh) for i := 0; i < n; i++ { - reqs = append(reqs, <-c.writeCh) - } - bufs := make(net.Buffers, 0, len(reqs)) - for _, r := range reqs { + r := <-c.writeCh bufs = append(bufs, r.data) } - // if cfg.config != nil { - // _ = cfg.netConn.SetWriteDeadline(time.Now().Add(cfg.config.WriteTimeout)) - // } if _, err := bufs.WriteTo(c.netConn); err != nil { - c.close() + c.isClosed.Store(true) + c.netConn.Close() return } - // if cfg.config != nil && cfg.config.WriteTimeout > 0 { - // _ = cfg.netConn.SetWriteDeadline(time.Time{}) - // } - for _, r := range reqs { - if r.pool { - c.releaseFrame(r.data) - } - } } } @@ -152,7 +99,7 @@ func (c *wispConnection) releaseFrame(data []byte) { if c.config == nil || len(data) == 0 { return } - if cap(data) < minFramePoolCap { + if cap(data) < 64*1024 { return } buf := data @@ -188,18 +135,16 @@ func (c *wispConnection) handleConnectPacket(streamId uint32, payload []byte) { if len(payload) < 3 { return } - guard := newProtection(c.config) streamType := payload[0] port := strconv.FormatUint(uint64(binary.LittleEndian.Uint16(payload[1:3])), 10) hostname := string(payload[3:]) c.config.Logger.Debug("creating stream", "ip", c.remoteIP, "streamId", streamId, "hostname", hostname, "port", port, "type", streamType) - action, normalizedHostname, reason := guard.allowConnect(c, streamType, hostname, port) - if action == connectBlocked { - c.sendClosePacket(streamId, reason) - return - } - if action == connectTwisp { + if streamType == streamTypeTerm { + if !c.config.EnableTwisp { + c.sendClosePacket(streamId, closeReasonBlocked) + return + } go handleTwisp(c, streamId, hostname) return } @@ -208,7 +153,7 @@ func (c *wispConnection) handleConnectPacket(streamId uint32, payload []byte) { wispConn: c, streamId: streamId, connReady: make(chan struct{}), - hostname: normalizedHostname, + hostname: strings.ToLower(strings.TrimSpace(hostname)), } stream.isOpen.Store(true) @@ -218,23 +163,10 @@ func (c *wispConnection) handleConnectPacket(streamId uint32, payload []byte) { } c.streamCount.Add(1) - go stream.handleConnect(streamType, port, normalizedHostname) + go stream.handleConnect(streamType, port, hostname) } func (c *wispConnection) handleDataPacket(streamId uint32, payload []byte) { - guard := newProtection(c.config) - if c.packetLimiter != nil && !c.packetLimiter.Allow() { - c.sendClosePacket(streamId, closeReasonThrottled) - return - } - if c.inboundLimiter != nil && !c.inboundLimiter.Allow(len(payload)) { - c.sendClosePacket(streamId, closeReasonThrottled) - return - } - if !guard.allowMessageSize(len(payload)) { - c.sendClosePacket(streamId, closeReasonInvalidInfo) - return - } var stream *wispStream if c.cachedStreamId == streamId { stream = (*wispStream)(atomic.LoadPointer(&c.cachedStream)) @@ -265,7 +197,7 @@ func (c *wispConnection) handleDataPacket(streamId uint32, payload []byte) { stream.pendingMutex.Lock() if !stream.connReadyDone.Load() { - if stream.pendingBytes+len(payload) > maxPendingStreamBytes { + if stream.pendingBytes+len(payload) > 16*1024*1024 { stream.pendingMutex.Unlock() stream.close(closeReasonThrottled) return @@ -288,7 +220,7 @@ func (c *wispConnection) handleDataPacket(streamId uint32, payload []byte) { if stream.streamType == streamTypeTCP { stream.bufferRemaining-- if stream.bufferRemaining == 0 { - // stream.bufferRemaining = c.config.BufferRemainingLength + stream.bufferRemaining = c.config.BufferRemainingLength c.sendPacket(streamId, stream.bufferRemaining) } } diff --git a/wisp/wisp-stream.go b/wisp/wisp-stream.go index 9d73d71..525c5af 100644 --- a/wisp/wisp-stream.go +++ b/wisp/wisp-stream.go @@ -9,8 +9,6 @@ import ( "sync/atomic" "time" - prot "mrrowisp/wisp/protection" - "golang.org/x/net/proxy" ) @@ -35,53 +33,44 @@ type wispStream struct { const dnsLookupTimeout = 10 * time.Second +func NormalizeTargetHostname(host string) string { + host = strings.TrimSpace(strings.ToLower(host)) + host = strings.TrimSuffix(host, ".") + return host +} + func (s *wispStream) handleConnect(streamType uint8, port string, hostname string) { defer s.signalConnReady() cfg := s.wispConn.config - s.hostname = prot.NormalizeTargetHostname(hostname) - if s.hostname == "" { - s.close(closeReasonInvalidInfo) - return - } - - guard := newProtection(cfg) + s.hostname = NormalizeTargetHostname(hostname) - if reason, ok := guard.allowHostPort(s.hostname, port); !ok { - s.close(reason) - return - } - - resolvedHostname := s.hostname - - if ip := net.ParseIP(resolvedHostname); ip != nil { - if reason, ok := guard.allowDirectIP(ip, s.wispConn.remoteIP, s.hostname); !ok { - s.close(reason) - return - } - resolvedHostname = ip.String() - } else if cfg.Proxy != "" { - resolvedHostname = s.hostname - } else if cfg.DNSCache != nil { - ctx, cancel := context.WithTimeout(context.Background(), dnsLookupTimeout) - ips, err := cfg.DNSCache.LookupIPAddr(ctx, resolvedHostname) - cancel() - if err != nil { - cfg.Logger.Warn("DNS lookup failed", "ip", s.wispConn.remoteIP, "hostname", resolvedHostname, "error", err) - s.close(closeReasonUnreachable) + if len(cfg.Whitelist.Hostnames) > 0 { + if _, ok := cfg.Whitelist.Hostnames[s.hostname]; !ok { + s.close(closeReasonBlocked) return } - if len(ips) == 0 { - cfg.Logger.Warn("DNS returned no results", "ip", s.wispConn.remoteIP, "hostname", resolvedHostname) - s.close(closeReasonUnreachable) + } else if len(cfg.Blacklist.Hostnames) > 0 { + if _, ok := cfg.Blacklist.Hostnames[s.hostname]; ok { + s.close(closeReasonBlocked) return } - selected, reason, ok := guard.selectAllowedIP(ips, s.wispConn.remoteIP, resolvedHostname) - if !ok { - s.close(reason) - return + } + + resolvedHostname := hostname + if cfg.DNSCache != nil { + if _, whitelisted := cfg.Whitelist.Hostnames[hostname]; !whitelisted { + ips, err := cfg.DNSCache.LookupIPAddr(context.Background(), hostname) + if err != nil { + s.close(closeReasonUnreachable) + return + } + if len(ips) == 0 { + s.close(closeReasonUnreachable) + return + } + resolvedHostname = ips[0].IP.String() } - resolvedHostname = selected } s.streamType = streamType @@ -92,18 +81,12 @@ func (s *wispStream) handleConnect(streamType uint8, port string, hostname strin var err error switch streamType { case streamTypeTCP: - select { - case s.wispConn.dialSem <- struct{}{}: - case <-s.wispConn.closeCh: - return - } if cfg.Proxy != "" { proxyURL := cfg.Proxy proxyURL = strings.Replace(proxyURL, "socks5h://", "socks5://", 1) proxyURL = strings.Replace(proxyURL, "socks4a://", "socks4://", 1) dialer, proxyErr := proxy.SOCKS5("tcp", stripScheme(proxyURL), nil, proxy.Direct) if proxyErr != nil { - <-s.wispConn.dialSem cfg.Logger.Warn("proxy dialer creation failed", "ip", s.wispConn.remoteIP, "error", proxyErr) s.close(closeReasonNetworkError) return @@ -112,7 +95,6 @@ func (s *wispStream) handleConnect(streamType uint8, port string, hostname strin } else { s.conn, err = cfg.Dialer.Dial("tcp", destination) } - <-s.wispConn.dialSem case streamTypeUDP: if cfg.Proxy != "" || !cfg.AllowUDP { s.close(closeReasonBlocked) @@ -179,7 +161,9 @@ func (s *wispStream) signalConnReady() { func (s *wispStream) readFromConnection() { const maxHeaderLen = 15 - buf := make([]byte, maxHeaderLen+65535) + bufp := s.wispConn.config.ReadBufPool.Get().(*[]byte) + buf := *bufp + defer s.wispConn.config.ReadBufPool.Put(bufp) streamId := s.streamId diff --git a/wisp/wisp.go b/wisp/wisp.go index dc3de61..0a96ed5 100644 --- a/wisp/wisp.go +++ b/wisp/wisp.go @@ -3,11 +3,9 @@ package wisp import ( "net" "net/http" - "strings" + "sync" "time" - "mrrowisp/wisp/protection" - "github.com/lxzan/gws" ) @@ -26,12 +24,6 @@ func (cfg *Config) InitResolver() { Method: cfg.DnsMethod, ResultOrder: cfg.DnsResultOrder, }) - // if cfg.BandwidthLimitKbps > 0 { - // cfg.BandwidthLimiter = protection.NewBandwidthLimiter(cfg.BandwidthLimitKbps, time.Duration(cfg.ConnectionWindowSeconds)*time.Second) - // } - // if cfg.ConnectionsLimitPerIP > 0 { - // cfg.ConnectionLimiter = protection.NewConnectionLimiter(cfg.ConnectionsLimitPerIP, time.Duration(cfg.ConnectionWindowSeconds)*time.Second) - // } cfg.Logger = newLogger(cfg.LogLevel) } @@ -42,44 +34,30 @@ type upgradeHandler struct { func CreateWispHandler(config *Config) http.HandlerFunc { config.InitResolver() + readBufSize := 15 + config.TcpBufferSize + config.ReadBufPool = &sync.Pool{ + New: func() any { + buf := make([]byte, readBufSize) + return &buf + }, + } + + config.Dialer = net.Dialer{ + Timeout: 15 * time.Second, + KeepAlive: 30 * time.Second, + } + upgrader := gws.NewUpgrader(&upgradeHandler{}, &gws.ServerOption{ PermessageDeflate: gws.PermessageDeflate{ Enabled: false, }, }) - guard := newProtection(config) - return func(w http.ResponseWriter, r *http.Request) { useV2 := config.EnableV2 && r.Header.Get("Sec-WebSocket-Protocol") != "" - remoteIP := protection.RemoteIPFromRequest(r, protection.IPConfig{ - AllowDirectIP: config.AllowDirectIP, - AllowPrivateIPs: config.AllowPrivateIPs, - AllowLoopbackIPs: config.AllowLoopbackIPs, - ParseRealIP: config.ParseRealIP, - }) - config.Logger.Info("incoming connection", "ip", remoteIP, "path", r.URL.Path, "origin", r.Header.Get("Origin")) - if config.requiresV2() && !useV2 { - config.Logger.Warn("v2 required but not negotiated", "ip", remoteIP) - w.WriteHeader(http.StatusUnauthorized) - return - } - - if status, response, ok := guard.allowHTTP(r, remoteIP, useV2); !ok { - w.WriteHeader(status) - if response != "" { - _, _ = w.Write([]byte(response)) - } - return - } wsConn, err := upgrader.Upgrade(w, r) if err != nil { - if config.NonWSResponse != "" { - w.WriteHeader(http.StatusBadRequest) - _, _ = w.Write([]byte(config.NonWSResponse)) - } - config.Logger.Debug("websocket upgrade failed", "error", err) return } @@ -91,18 +69,13 @@ func CreateWispHandler(config *Config) http.HandlerFunc { } wc := &wispConnection{ - netConn: netConn, - // writeCh: make(chan writeReq, writeQSize), + netConn: netConn, + writeCh: make(chan writeReq, 4096), // funny number config: config, twispStreams: newTwisp(), isV2: useV2, - remoteIP: remoteIP, - dialSem: make(chan struct{}, maxConcurrentDials), - closeCh: make(chan struct{}), - createdAt: time.Now(), } - config.Logger.Info("connection established", "ip", remoteIP, "v2", useV2) go wc.writeLoop() if useV2 { @@ -120,19 +93,3 @@ func (cfg *Config) requiresV2() bool { } return cfg.PasswordAuthRequired || cfg.EnableTwisp } - -func originAllowed(r *http.Request, allowedOrigins []string) bool { - if len(allowedOrigins) == 0 { - return true - } - origin := strings.TrimSpace(r.Header.Get("Origin")) - if origin == "" { - return false - } - for _, allowed := range allowedOrigins { - if origin == strings.TrimSpace(allowed) { - return true - } - } - return false -} diff --git a/wisp/wsreader.go b/wisp/wsreader.go index df9b66f..5a73942 100644 --- a/wisp/wsreader.go +++ b/wisp/wsreader.go @@ -16,9 +16,6 @@ func (c *wispConnection) readLoop() { var headerBuffer [14]byte for { - // if c.config != nil && c.config.FrameReadTimeout > 0 { - // _ = c.netConn.SetReadDeadline(time.Now().Add(c.config.FrameReadTimeout)) - // } if _, err := io.ReadFull(reader, headerBuffer[:2]); err != nil { return }