diff --git a/.gitignore b/.gitignore index ead0f8b..fafbf4d 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,4 @@ bin node_modules config.json mrrowisp -mrrowisp.exe \ No newline at end of file +mrrowisp.exe diff --git a/main.go b/main.go index 64c292a..c0512ad 100644 --- a/main.go +++ b/main.go @@ -2,545 +2,45 @@ package main import ( "context" - "crypto/ed25519" - "encoding/hex" - "encoding/json" "flag" "fmt" "net" "net/http" "os" "os/signal" - "strings" "syscall" "time" "mrrowisp/wisp" ) -type PortEntry struct { - Min int - Max int -} - -func (p *PortEntry) UnmarshalJSON(data []byte) error { - var single int - if err := json.Unmarshal(data, &single); err == nil { - if single <= 0 || single > 65535 { - return fmt.Errorf("invalid port %d", single) - } - p.Min = single - p.Max = single - return nil - } - var pair [2]int - if err := json.Unmarshal(data, &pair); err != nil { - return fmt.Errorf("port entry must be an integer or [min, max] pair: %w", err) - } - if pair[0] <= 0 || pair[1] > 65535 || pair[0] > pair[1] { - return fmt.Errorf("invalid port range [%d, %d]", pair[0], pair[1]) - } - p.Min = pair[0] - p.Max = pair[1] - return nil -} - -type Config struct { - Port int `json:"port"` - AllowTCP bool `json:"allowTCP"` - AllowUDP bool `json:"allowUDP"` - AllowDirectIP bool `json:"allowDirectIP"` - AllowPrivateIPs bool `json:"allowPrivateIPs"` - AllowLoopbackIPs bool `json:"allowLoopbackIPs"` - TcpBufferSize int `json:"tcpBufferSize"` - BufferRemainingLength uint32 `json:"bufferRemainingLength"` - TcpNoDelay bool `json:"tcpNoDelay"` - WebsocketTcpNoDelay bool `json:"websocketTcpNoDelay"` - StreamLimitPerHost int `json:"streamLimitPerHost"` - StreamLimitTotal int `json:"streamLimitTotal"` - - Blacklist struct { - Hostnames []string `json:"hostnames"` - Ports []PortEntry `json:"ports"` - } `json:"blacklist"` - Whitelist struct { - Hostnames []string `json:"hostnames"` - Ports []PortEntry `json:"ports"` - } `json:"whitelist"` - - Proxy string `json:"proxy"` - WebsocketPermessageDeflate bool `json:"websocketPermessageDeflate"` - DnsServers []string `json:"dnsServers"` - DnsTTLSeconds int `json:"dnsTTLSeconds"` - DnsMethod string `json:"dnsMethod"` - DnsResultOrder string `json:"dnsResultOrder"` - - EnableTwisp bool `json:"enableTwisp"` - - EnableV2 bool `json:"enableV2"` - Motd string `json:"motd"` - PasswordAuth bool `json:"passwordAuth"` - PasswordAuthRequired bool `json:"passwordAuthRequired"` - PasswordUsers map[string]string `json:"passwordUsers"` - CertAuth bool `json:"certAuth"` - CertAuthRequired bool `json:"certAuthRequired"` - CertAuthPublicKeys []string `json:"certAuthPublicKeys"` - EnableStreamConfirm bool `json:"enableStreamConfirm"` - MaxConnectsPerSecond int `json:"maxConnectsPerSecond"` - - BandwidthLimitKbps int `json:"bandwidthLimitKbps"` - ConnectionsLimitPerIP int `json:"connectionsLimitPerIP"` - ConnectionWindowSeconds int `json:"connectionWindowSeconds"` - ParseRealIP bool `json:"parseRealIP"` - ParseRealIPFrom []string `json:"parseRealIPFrom"` - MaxMessageSize int `json:"maxMessageSize"` - StaticDir string `json:"staticDir"` - NonWSResponse string `json:"nonWSResponse"` - AllowedOrigins []string `json:"allowedOrigins"` - WriteTimeoutSeconds int `json:"writeTimeoutSeconds"` - FrameReadTimeoutSeconds int `json:"frameReadTimeoutSeconds"` - LogLevel string `json:"logLevel"` - - BanEnabled bool `json:"banEnabled"` - BanDurationSeconds int `json:"banDurationSeconds"` - BanMaxStrikes int `json:"banMaxStrikes"` - BanEscalationMultiplier int `json:"banEscalationMultiplier"` - MaxHandshakeFailures int `json:"maxHandshakeFailures"` - - MaxPacketRate int `json:"maxPacketRate"` - MaxConnectionLifetimeSec int `json:"maxConnectionLifetimeSeconds"` - MaxStreamsPerConnection int `json:"maxStreamsPerConnection"` - MaxConnectionsPerIP int `json:"maxConnectionsPerIP"` - GlobalMaxConnections int `json:"globalMaxConnections"` - WriteQueueSize int `json:"writeQueueSize"` - MaxInboundBytesPerSecond int `json:"maxInboundBytesPerSecond"` -} - -const ( - defaultStreamLimitPerHost = 512 - defaultStreamLimitTotal = 16384 - defaultMaxConnectsPerSecond = 20 - defaultConnectionsLimitPerIP = 120 - defaultHandshakeFailures = 10 -) - -func defaultConfig() Config { - return Config{ - Port: 6001, - AllowTCP: true, - AllowUDP: true, - AllowDirectIP: false, - AllowPrivateIPs: false, - AllowLoopbackIPs: false, - TcpBufferSize: 32768, - BufferRemainingLength: 65536, - TcpNoDelay: true, - WebsocketTcpNoDelay: true, - StreamLimitPerHost: defaultStreamLimitPerHost, - StreamLimitTotal: defaultStreamLimitTotal, - WebsocketPermessageDeflate: false, - EnableTwisp: false, - EnableV2: false, - PasswordAuth: false, - PasswordAuthRequired: false, - PasswordUsers: make(map[string]string), - CertAuth: false, - CertAuthRequired: false, - EnableStreamConfirm: false, - MaxConnectsPerSecond: defaultMaxConnectsPerSecond, - DnsTTLSeconds: 120, - DnsMethod: "lookup", - DnsResultOrder: "verbatim", - ConnectionWindowSeconds: 1, - ConnectionsLimitPerIP: defaultConnectionsLimitPerIP, - ParseRealIP: true, - ParseRealIPFrom: []string{"127.0.0.1"}, - WriteTimeoutSeconds: 15, - FrameReadTimeoutSeconds: 30, - LogLevel: "debug", - NonWSResponse: "not found", - BanEnabled: true, - BanDurationSeconds: 3600, - BanMaxStrikes: 10, - BanEscalationMultiplier: 0, - MaxHandshakeFailures: defaultHandshakeFailures, - MaxPacketRate: 500, - MaxConnectionLifetimeSec: 0, - MaxStreamsPerConnection: 0, - MaxConnectionsPerIP: 0, - GlobalMaxConnections: 0, - WriteQueueSize: 4096, - MaxInboundBytesPerSecond: 0, - } -} - -func loadConfig(config string) (Config, error) { - cfg := defaultConfig() - - trimConfig := strings.TrimSpace(config) - if strings.HasPrefix(trimConfig, "{") { - if err := json.Unmarshal([]byte(trimConfig), &cfg); err != nil { - return cfg, err - } - return cfg, nil - } - - file, err := os.Open(config) - if err != nil { - return cfg, err - } - defer file.Close() - - decoder := json.NewDecoder(file) - if err := decoder.Decode(&cfg); err != nil { - return cfg, err - } - - return cfg, nil -} - -func portEntriesToRanges(entries []PortEntry) []wisp.PortRange { - out := make([]wisp.PortRange, 0, len(entries)) - for _, e := range entries { - out = append(out, wisp.PortRange{Min: e.Min, Max: e.Max}) - } - return out -} - -func createWispConfig(cfg Config) *wisp.Config { - normalizeHostname := func(host string) string { - host = strings.TrimSpace(strings.ToLower(host)) - host = strings.TrimSuffix(host, ".") - return host - } - - if cfg.TcpBufferSize <= 0 { - cfg.TcpBufferSize = 32768 - } - if cfg.BufferRemainingLength == 0 { - cfg.BufferRemainingLength = 65536 - } - if cfg.ConnectionWindowSeconds <= 0 { - cfg.ConnectionWindowSeconds = 1 - } - if cfg.BandwidthLimitKbps < 0 { - cfg.BandwidthLimitKbps = 0 - } - if cfg.ConnectionsLimitPerIP < 0 { - cfg.ConnectionsLimitPerIP = 0 - } - if cfg.MaxConnectsPerSecond < 0 { - cfg.MaxConnectsPerSecond = 0 - } - if cfg.MaxMessageSize < 0 { - cfg.MaxMessageSize = 0 - } - if cfg.WriteTimeoutSeconds < 0 { - cfg.WriteTimeoutSeconds = 0 - } - if cfg.FrameReadTimeoutSeconds < 0 { - cfg.FrameReadTimeoutSeconds = 0 - } - if cfg.MaxHandshakeFailures <= 0 { - cfg.MaxHandshakeFailures = defaultHandshakeFailures - } - if cfg.MaxPacketRate <= 0 { - cfg.MaxPacketRate = 500 - } - if cfg.WriteQueueSize <= 0 { - cfg.WriteQueueSize = 4096 - } - if cfg.MaxConnectionLifetimeSec < 0 { - cfg.MaxConnectionLifetimeSec = 0 - } - if cfg.MaxStreamsPerConnection < 0 { - cfg.MaxStreamsPerConnection = 0 - } - if cfg.MaxConnectionsPerIP < 0 { - cfg.MaxConnectionsPerIP = 0 - } - if cfg.GlobalMaxConnections < 0 { - cfg.GlobalMaxConnections = 0 - } - if cfg.MaxInboundBytesPerSecond < 0 { - cfg.MaxInboundBytesPerSecond = 0 - } - if len(cfg.AllowedOrigins) > 0 { - filtered := make([]string, 0, len(cfg.AllowedOrigins)) - for _, origin := range cfg.AllowedOrigins { - origin = strings.TrimSpace(origin) - if origin == "" { - continue - } - filtered = append(filtered, origin) - } - cfg.AllowedOrigins = filtered - } - - blacklistedHostnames := make(map[string]struct{}) - for _, host := range cfg.Blacklist.Hostnames { - normalized := normalizeHostname(host) - if normalized == "" { - continue - } - blacklistedHostnames[normalized] = struct{}{} - } - - whitelistedHostnames := make(map[string]struct{}) - for _, host := range cfg.Whitelist.Hostnames { - normalized := normalizeHostname(host) - if normalized == "" { - continue - } - whitelistedHostnames[normalized] = struct{}{} - } - - var pubKeys []ed25519.PublicKey - for _, hexKey := range cfg.CertAuthPublicKeys { - hexKeyBytes, err := hex.DecodeString(hexKey) - if err != nil { - fmt.Printf("warning: invalid public key hex %q: %v\n", hexKey, err) - continue - } - if len(hexKeyBytes) != ed25519.PublicKeySize { - fmt.Printf("warning: public key %q has invalid length %d (expected %d)\n", hexKey, len(hexKeyBytes), ed25519.PublicKeySize) - continue - } - pubKeys = append(pubKeys, ed25519.PublicKey(hexKeyBytes)) - } - - parseReal := make(map[string]struct{}) - for _, ip := range cfg.ParseRealIPFrom { - normalized := strings.TrimSpace(ip) - if normalized == "" { - continue - } - if net.ParseIP(normalized) == nil { - fmt.Printf("warning: invalid parse-real-ip-from value %q\n", ip) - continue - } - parseReal[normalized] = struct{}{} - } - - wispCfg := &wisp.Config{ - AllowTCP: cfg.AllowTCP, - AllowUDP: cfg.AllowUDP, - AllowDirectIP: cfg.AllowDirectIP, - AllowPrivateIPs: cfg.AllowPrivateIPs, - AllowLoopbackIPs: cfg.AllowLoopbackIPs, - TcpBufferSize: cfg.TcpBufferSize, - BufferRemainingLength: cfg.BufferRemainingLength, - TcpNoDelay: cfg.TcpNoDelay, - WebsocketTcpNoDelay: cfg.WebsocketTcpNoDelay, - StreamLimitPerHost: cfg.StreamLimitPerHost, - StreamLimitTotal: cfg.StreamLimitTotal, - Blacklist: struct { - Hostnames map[string]struct{} - Ports []wisp.PortRange - }{ - Hostnames: blacklistedHostnames, - Ports: portEntriesToRanges(cfg.Blacklist.Ports), - }, - Whitelist: struct { - Hostnames map[string]struct{} - Ports []wisp.PortRange - }{ - Hostnames: whitelistedHostnames, - Ports: portEntriesToRanges(cfg.Whitelist.Ports), - }, - Proxy: cfg.Proxy, - WebsocketPermessageDeflate: cfg.WebsocketPermessageDeflate, - DnsServers: cfg.DnsServers, - DnsTTLSeconds: cfg.DnsTTLSeconds, - DnsMethod: cfg.DnsMethod, - DnsResultOrder: cfg.DnsResultOrder, - EnableTwisp: cfg.EnableTwisp, - EnableV2: cfg.EnableV2, - Motd: cfg.Motd, - PasswordAuth: cfg.PasswordAuth, - PasswordAuthRequired: cfg.PasswordAuthRequired, - PasswordUsers: cfg.PasswordUsers, - CertAuth: cfg.CertAuth, - CertAuthRequired: cfg.CertAuthRequired, - CertAuthPublicKeys: pubKeys, - EnableStreamConfirm: cfg.EnableStreamConfirm, - MaxConnectsPerSecond: cfg.MaxConnectsPerSecond, - BandwidthLimitKbps: cfg.BandwidthLimitKbps, - ConnectionsLimitPerIP: cfg.ConnectionsLimitPerIP, - ConnectionWindowSeconds: cfg.ConnectionWindowSeconds, - ParseRealIP: cfg.ParseRealIP, - ParseRealIPFrom: parseReal, - MaxMessageSize: cfg.MaxMessageSize, - NonWSResponse: cfg.NonWSResponse, - AllowedOrigins: cfg.AllowedOrigins, - WriteTimeout: time.Duration(cfg.WriteTimeoutSeconds) * time.Second, - FrameReadTimeout: time.Duration(cfg.FrameReadTimeoutSeconds) * time.Second, - LogLevel: cfg.LogLevel, - BanEnabled: cfg.BanEnabled, - BanDuration: time.Duration(cfg.BanDurationSeconds) * time.Second, - BanMaxStrikes: cfg.BanMaxStrikes, - BanEscalationMultiplier: cfg.BanEscalationMultiplier, - MaxHandshakeFailures: cfg.MaxHandshakeFailures, - MaxPacketRate: cfg.MaxPacketRate, - MaxConnectionLifetime: time.Duration(cfg.MaxConnectionLifetimeSec) * time.Second, - MaxStreamsPerConnection: cfg.MaxStreamsPerConnection, - MaxConnectionsPerIP: cfg.MaxConnectionsPerIP, - GlobalMaxConnections: cfg.GlobalMaxConnections, - WriteQueueSize: cfg.WriteQueueSize, - MaxInboundBytesPerSecond: cfg.MaxInboundBytesPerSecond, - } - - if wispCfg.PasswordUsers == nil { - wispCfg.PasswordUsers = make(map[string]string) - } - if wispCfg.StreamLimitPerHost < 0 { - wispCfg.StreamLimitPerHost = 0 - } - if wispCfg.StreamLimitTotal < 0 { - wispCfg.StreamLimitTotal = 0 - } - if wispCfg.AllowedOrigins == nil { - wispCfg.AllowedOrigins = []string{} - } - - return wispCfg -} - func main() { fConfig := flag.String("config", "", "config to load (file or json string)") fPort := flag.Int("port", 0, "port to run on") - fLogLevel := flag.String("log-level", "", "log level (debug, info, warn, error)") - fAllowTCP := flag.Bool("allow-tcp", true, "allow TCP streams") - fAllowUDP := flag.Bool("allow-udp", true, "allow UDP streams") - fAllowDirectIP := flag.Bool("allow-direct-ip", false, "allow direct IP targets") - fAllowPrivateIPs := flag.Bool("allow-private", false, "allow private IP targets") fAllowLoopbackIPs := flag.Bool("allow-loopback", false, "allow loopback IP targets") - fStreamLimitPerHost := flag.Int("stream-limit-per-host", 0, "max streams per host (0 = unlimited)") - fStreamLimitTotal := flag.Int("stream-limit-total", 0, "max total streams (0 = unlimited)") - fBandwidthLimit := flag.Int("bandwidth", 0, "bandwidth limit per IP in KB/s") - fConnectionsLimit := flag.Int("connections", 0, "connections per IP per window") - fWindow := flag.Int("window", 1, "rate limit window in seconds") - fDnsServers := flag.String("dns", "", "comma-separated DNS servers") - fDnsMethod := flag.String("dns-method", "", "DNS method (lookup|resolve)") - fDnsOrder := flag.String("dns-order", "", "DNS result order (ipv4first|ipv6first|verbatim)") - fDnsTTL := flag.Int("dns-ttl", 0, "DNS cache TTL seconds") - fStatic := flag.String("static", "", "static directory to serve") - fNonWS := flag.String("non-ws-response", "", "response body for non-websocket requests") - fParseRealIP := flag.Bool("parse-real-ip", true, "parse client IP from forwarded headers") - fParseRealIPFrom := flag.String("parse-real-ip-from", "", "comma-separated list of IPs allowed to set real IP") - fMaxMessageSize := flag.Int("max-message-size", 0, "max websocket message size in bytes") - fWriteTimeout := flag.Int("write-timeout", 0, "write timeout in seconds (0 = disabled)") - fAllowedOrigins := flag.String("allowed-origins", "", "comma-separated list of allowed origins") - fMaxPacketRate := flag.Int("max-packet-rate", 0, "max wisp packets/sec per connection (0=default)") - fMaxConnLifetime := flag.Int("max-conn-lifetime", 0, "max connection lifetime in seconds (0=unlimited)") - fMaxStreamsPerConn := flag.Int("max-streams-per-conn", 0, "max concurrent streams per connection (0=unlimited)") - fMaxConnPerIP := flag.Int("max-conn-per-ip", 0, "hard connection cap per IP (0=unlimited)") - fGlobalMaxConn := flag.Int("global-max-conn", 0, "global connection cap (0=unlimited)") - fInboundBPS := flag.Int("inbound-bps", 0, "max inbound bytes/sec per connection (0=unlimited)") flag.Parse() - var cfg Config + var cfg wisp.Config var err error if *fConfig != "" { - cfg, err = loadConfig(*fConfig) + cfg, err = wisp.LoadConfig(*fConfig) if err != nil { fmt.Printf("Failed to load config: %v\n", err) return } } else { - cfg = defaultConfig() + cfg = wisp.DefaultConfig() } if *fPort != 0 { cfg.Port = *fPort } - if *fLogLevel != "" { - cfg.LogLevel = *fLogLevel - } - if *fAllowTCP != true { - cfg.AllowTCP = *fAllowTCP - } - if *fAllowUDP != true { - cfg.AllowUDP = *fAllowUDP - } - if *fAllowDirectIP != false { - cfg.AllowDirectIP = *fAllowDirectIP - } - if *fAllowPrivateIPs != false { - cfg.AllowPrivateIPs = *fAllowPrivateIPs - } if *fAllowLoopbackIPs != false { cfg.AllowLoopbackIPs = *fAllowLoopbackIPs } - if *fStreamLimitPerHost != 0 { - cfg.StreamLimitPerHost = *fStreamLimitPerHost - } - if *fStreamLimitTotal != 0 { - cfg.StreamLimitTotal = *fStreamLimitTotal - } - if *fBandwidthLimit != 0 { - cfg.BandwidthLimitKbps = *fBandwidthLimit - } - if *fConnectionsLimit != 0 { - cfg.ConnectionsLimitPerIP = *fConnectionsLimit - } - if *fWindow != 0 { - cfg.ConnectionWindowSeconds = *fWindow - } - if *fDnsServers != "" { - cfg.DnsServers = strings.Split(*fDnsServers, ",") - } - if *fDnsMethod != "" { - cfg.DnsMethod = *fDnsMethod - } - if *fDnsOrder != "" { - cfg.DnsResultOrder = *fDnsOrder - } - if *fDnsTTL != 0 { - cfg.DnsTTLSeconds = *fDnsTTL - } - if *fStatic != "" { - cfg.StaticDir = *fStatic - } - if *fNonWS != "" { - cfg.NonWSResponse = *fNonWS - } - if *fParseRealIP != true { - cfg.ParseRealIP = *fParseRealIP - } - if *fParseRealIPFrom != "" { - cfg.ParseRealIPFrom = strings.Split(*fParseRealIPFrom, ",") - } - if *fMaxMessageSize != 0 { - cfg.MaxMessageSize = *fMaxMessageSize - } - if *fWriteTimeout != 0 { - cfg.WriteTimeoutSeconds = *fWriteTimeout - } - if *fMaxPacketRate != 0 { - cfg.MaxPacketRate = *fMaxPacketRate - } - if *fMaxConnLifetime != 0 { - cfg.MaxConnectionLifetimeSec = *fMaxConnLifetime - } - if *fMaxStreamsPerConn != 0 { - cfg.MaxStreamsPerConnection = *fMaxStreamsPerConn - } - if *fMaxConnPerIP != 0 { - cfg.MaxConnectionsPerIP = *fMaxConnPerIP - } - if *fGlobalMaxConn != 0 { - cfg.GlobalMaxConnections = *fGlobalMaxConn - } - if *fInboundBPS != 0 { - cfg.MaxInboundBytesPerSecond = *fInboundBPS - } - if *fAllowedOrigins != "" { - cfg.AllowedOrigins = strings.Split(*fAllowedOrigins, ",") - } - wispConfig := createWispConfig(cfg) + wispConfig := wisp.CreateWispConfig(&cfg) wispHandler := wisp.CreateWispHandler(wispConfig) @@ -550,7 +50,7 @@ func main() { } else { http.HandleFunc("/", wispHandler) } - fmt.Printf("Starting Mrrowisp on port %d. . .\n", cfg.Port) + fmt.Printf("[INFO] Starting Mrrowisp on port %d. . .\n", cfg.Port) server := &http.Server{ Addr: fmt.Sprintf(":%d", cfg.Port), ReadHeaderTimeout: 5 * time.Second, @@ -562,16 +62,16 @@ func main() { go func() { sig := <-sigch - fmt.Printf("Shutting down (signal: %s)\n", sig.String()) + fmt.Printf("[INFO] Shutting down (signal: %s)\n", sig.String()) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() if shutdownErr := server.Shutdown(ctx); shutdownErr != nil { - fmt.Printf("Shutdown error: %v\n", shutdownErr) + fmt.Printf("[INFO] Shutdown error: %v\n", shutdownErr) } }() err = server.ListenAndServe() if err != nil && err != http.ErrServerClosed { - fmt.Printf("Failed to start Mrrowisp: %v", err) + fmt.Printf("[INFO] Failed to start Mrrowisp: %v", err) } } diff --git a/src/index.ts b/src/index.ts index ce19203..538fc65 100644 --- a/src/index.ts +++ b/src/index.ts @@ -9,68 +9,152 @@ import type { Socket } from "node:net"; type PortEntry = number | [number, number]; type MrrowispConfig = { + /** + * TCP port the server listens on. + */ port: number; + + /** + * Allow clients to open TCP streams. + */ allowTCP: boolean; + + /** + * Allow clients to open UDP streams. + */ allowUDP: boolean; + + /** + * Allow direct connections to IP addresses. + */ allowDirectIP: boolean; + + /** + * Allow connections to private/local IP ranges. + */ allowPrivateIPs: boolean; + + /** + * Allow connections to loopback IP addresses. + */ allowLoopbackIPs: boolean; + + /** + * Size of the TCP stream buffer in bytes. + */ tcpBufferSize: number; - bufferRemainingLength: number; + + /** + * Enable TCP_NODELAY on TCP sockets. + */ tcpNoDelay: boolean; - websocketTcpNoDelay: boolean; - streamLimitPerHost: number; - streamLimitTotal: number; + + /** + * Hostname and port blacklist rules. + */ blacklist: { hostnames: string[]; ports: PortEntry[]; }; + + /** + * Hostname and port whitelist rules. + */ whitelist: { hostnames: string[]; ports: PortEntry[]; }; - proxy: string; - websocketPermessageDeflate: boolean; + + /** + * DNS servers used for hostname resolution. + */ dnsServers: string[]; - dnsTTLSeconds: number; + + /** + * DNS resolution method. + */ dnsMethod: "lookup" | "resolve"; + + /** + * Preferred ordering for resolved IP addresses. + */ dnsResultOrder: "ipv4first" | "ipv6first" | "verbatim"; + + /** + * Enable TWisp experimental protocol. + */ enableTwisp: boolean; + + /** + * Enable Wisp v2 protocol support. + */ enableV2: boolean; + + /** + * Message of the day sent during handshake. + */ motd: string; + + /** + * Enable password authentication. + */ passwordAuth: boolean; + + /** + * Require password authentication for all clients. + */ passwordAuthRequired: boolean; + + /** + * Username/password credential map. + */ passwordUsers: Map; - certAuth: boolean; - certAuthRequired: boolean; - certAuthPublicKeys: string[]; - enableStreamConfirm: boolean; - maxConnectsPerSecond: number; - bandwidthLimitKbps: number; - connectionsLimitPerIP: number; - connectionWindowSeconds: number; + + /** + * Parse reverse-proxy real IP headers. + */ parseRealIP: boolean; - parseRealIPFrom: string[]; + + /** + * HTTP response returned for non-WebSocket requests. + */ + nonWSResponse: string; + + /** + * Logging verbosity level. + */ + logLevel: "debug" | "warn" | "error" | "info"; + + /** + * Optional upstream proxy URL (SOCKS/HTTP). + */ + proxy: string; + + /** + * Maximum WebSocket message size in bytes. + */ maxMessageSize: number; + + /** + * Directory for static file serving. + */ staticDir: string; - nonWSResponse: string; - allowedOrigins: string[]; - writeTimeoutSeconds: number; - frameReadTimeoutSeconds: number; - logLevel: "debug" | "info" | "warn" | "error"; - banEnabled: boolean; - banDurationSeconds: number; - banMaxStrikes: number; - banEscalationMultiplier: number; - maxHandshakeFailures: number; - maxPacketRate: number; - maxConnectionLifetimeSeconds: number; - maxStreamsPerConnection: number; - maxConnectionsPerIP: number; - globalMaxConnections: number; - writeQueueSize: number; - maxInboundBytesPerSecond: number; -} + + /** + * Bandwidth limit per IP in Kbps. + */ + bandwidthLimitKbps: number; + + /** + * Connection rate limit per IP. + */ + connectionsLimitPerIP: number; + + /** + * Connection rate limit window in seconds. + */ + connectionWindowSeconds: number; +}; const defaultConfig: MrrowispConfig = JSON.parse(fs.readFileSync(configPath, "utf-8")); diff --git a/src/logger.ts b/src/logger.ts index b7ff9ab..2251d36 100644 --- a/src/logger.ts +++ b/src/logger.ts @@ -1,12 +1,13 @@ import chalk from "chalk"; -export type LogLevel = "debug" | "info" | "warn" | "error"; +export type LogLevel = "debug" | "warn" | "error" | "info" | "none"; const levelPriority: Record = { debug: 0, warn: 1, error: 2, info: 3, + none: 4, }; class Logger { diff --git a/wisp/config.go b/wisp/config.go new file mode 100644 index 0000000..8e4f85d --- /dev/null +++ b/wisp/config.go @@ -0,0 +1,179 @@ +package wisp + +import ( + "encoding/json" + "net" + "os" + "strings" + "sync" +) + +type FilterList struct { + Hostnames []string `json:"hostnames"` + Ports []interface{} `json:"ports"` +} + +type Config struct { + Port int `json:"port"` + + AllowTCP bool `json:"allowTCP"` + AllowUDP bool `json:"allowUDP"` + + AllowDirectIP bool `json:"allowDirectIP"` + AllowPrivateIPs bool `json:"allowPrivateIPs"` + AllowLoopbackIPs bool `json:"allowLoopbackIPs"` + + TcpBufferSize int `json:"tcpBufferSize"` + TcpNoDelay bool `json:"tcpNoDelay"` + + Blacklist struct { + Hostnames map[string]struct{} + Ports map[uint16]struct{} + } + Whitelist struct { + Hostnames map[string]struct{} + Ports map[uint16]struct{} + } + + WebsocketPermessageDeflate bool + + DnsServers []string `json:"dnsServers"` + DnsMethod string `json:"dnsMethod"` + DnsResultOrder string `json:"dnsResultOrder"` + + EnableTwisp bool `json:"enableTwisp"` + + EnableV2 bool `json:"enableV2"` + Motd string `json:"motd"` + PasswordAuth bool `json:"passwordAuth"` + PasswordAuthRequired bool `json:"passwordAuthRequired"` + PasswordUsers map[string]string `json:"passwordUsers"` + + ParseRealIP bool `json:"parseRealIP"` + NonWSResponse string `json:"nonWSResponse"` + + LogLevel string `json:"logLevel"` + + Proxy string `json:"proxy"` + MaxMessageSize int `json:"maxMessageSize"` + StaticDir string `json:"staticDir"` + BandwidthLimitKbps int `json:"bandwidthLimitKbps"` + ConnectionsLimitPerIP int `json:"connectionsLimitPerIP"` + ConnectionWindowSeconds int `json:"connectionWindowSeconds"` + + BufferRemainingLength uint32 `json:"bufferRemainingLength"` + + Logger Logger + DNSCache *DNSCache + ReadBufPool *sync.Pool + Dialer net.Dialer +} + +func DefaultConfig() Config { + return Config{ + Port: 6001, + + AllowTCP: true, + AllowUDP: true, + + AllowDirectIP: false, + AllowPrivateIPs: false, + AllowLoopbackIPs: false, + + TcpBufferSize: 32768, + TcpNoDelay: true, + + DnsServers: []string{}, + DnsMethod: "resolve", + DnsResultOrder: "ipv4first", + + EnableTwisp: false, + + EnableV2: true, + Motd: "", + PasswordAuth: false, + PasswordAuthRequired: false, + PasswordUsers: map[string]string{}, + + ParseRealIP: true, + NonWSResponse: "", + + LogLevel: "info", + + Proxy: "", + MaxMessageSize: 0, + StaticDir: "", + BandwidthLimitKbps: 0, + ConnectionsLimitPerIP: 0, + ConnectionWindowSeconds: 0, + BufferRemainingLength: 32768, + } +} + +func CreateWispConfig(cfg *Config) *Config { + wispCfg := &Config{ + AllowTCP: cfg.AllowTCP, + AllowUDP: cfg.AllowUDP, + + AllowDirectIP: cfg.AllowDirectIP, + AllowPrivateIPs: cfg.AllowPrivateIPs, + AllowLoopbackIPs: cfg.AllowLoopbackIPs, + + TcpBufferSize: cfg.TcpBufferSize, + TcpNoDelay: cfg.TcpNoDelay, + + Blacklist: cfg.Blacklist, + Whitelist: cfg.Whitelist, + + DnsServers: cfg.DnsServers, + DnsMethod: cfg.DnsMethod, + DnsResultOrder: cfg.DnsResultOrder, + + EnableTwisp: cfg.EnableTwisp, + + EnableV2: cfg.EnableV2, + Motd: cfg.Motd, + PasswordAuth: cfg.PasswordAuth, + PasswordAuthRequired: cfg.PasswordAuthRequired, + PasswordUsers: cfg.PasswordUsers, + + ParseRealIP: cfg.ParseRealIP, + NonWSResponse: cfg.NonWSResponse, + + LogLevel: cfg.LogLevel, + + Proxy: cfg.Proxy, + MaxMessageSize: cfg.MaxMessageSize, + BandwidthLimitKbps: cfg.BandwidthLimitKbps, + ConnectionsLimitPerIP: cfg.ConnectionsLimitPerIP, + ConnectionWindowSeconds: cfg.ConnectionWindowSeconds, + + BufferRemainingLength: cfg.BufferRemainingLength, + } + + return wispCfg +} + +func LoadConfig(config string) (Config, error) { + cfg := DefaultConfig() + + trimConfig := strings.TrimSpace(config) + if strings.HasPrefix(trimConfig, "{") { + if err := json.Unmarshal([]byte(trimConfig), &cfg); err != nil { + return cfg, err + } + return cfg, nil + } + + file, err := os.Open(config) + if err != nil { + return cfg, err + } + defer file.Close() + + decoder := json.NewDecoder(file) + if err := decoder.Decode(&cfg); err != nil { + return cfg, err + } + return cfg, nil +} diff --git a/wisp/v2.go b/wisp/v2.go index adaaa6f..7ae9027 100644 --- a/wisp/v2.go +++ b/wisp/v2.go @@ -1,9 +1,6 @@ package wisp import ( - "crypto/ed25519" - "crypto/rand" - "crypto/sha256" "crypto/subtle" "encoding/binary" "errors" @@ -42,31 +39,10 @@ func (c *wispConnection) buildServerInfoPacket() []byte { extensions = addExtension(extensions, extensionPasswordAuth, meta[:]) } - if c.config.CertAuth && len(c.config.CertAuthPublicKeys) > 0 { - challenge := make([]byte, 64) - if _, err := rand.Read(challenge); err != nil { - c.config.Logger.Warn("certificate auth challenge: rand.Read failed", "error", err) - } else { - c.v2Challenge = challenge - - meta := make([]byte, 2+len(challenge)) - if c.config.CertAuthRequired { - meta[0] = 1 - } - meta[1] = sigEd25519 - copy(meta[2:], challenge) - extensions = addExtension(extensions, extensionCertificateAuth, meta) - } - } - if c.config.Motd != "" { extensions = addExtension(extensions, extensionMotd, []byte(c.config.Motd)) } - if c.config.EnableStreamConfirm { - extensions = addExtension(extensions, extensionStreamConfirm, nil) - } - payload := make([]byte, 5+2+len(extensions)) payload[0] = packetTypeInfo payload[5] = wispMajorVersion @@ -176,7 +152,7 @@ func (c *wispConnection) handleInfo(streamId uint32, payload []byte) { return } - authRequired := c.config.PasswordAuthRequired || c.config.CertAuthRequired + authRequired := c.config.PasswordAuthRequired authPassed := false if c.config.PasswordAuth && clientExts.passwordUsername != "" { @@ -193,16 +169,6 @@ func (c *wispConnection) handleInfo(streamId uint32, payload []byte) { } } - if c.config.CertAuth && len(clientExts.certificateSig) > 0 && c.v2Challenge != nil { - if c.verifyCertificate(clientExts) { - authPassed = true - } else { - c.sendClosePacket(0, closeReasonAuthBadSignature) - c.close() - return - } - } - if authRequired && !authPassed { c.sendClosePacket(0, closeReasonAuthRequired) c.close() @@ -210,7 +176,7 @@ func (c *wispConnection) handleInfo(streamId uint32, payload []byte) { } c.authenticated.Store(authPassed) - c.streamConfirm = c.config.EnableStreamConfirm && clientExts.streamConfirm + c.streamConfirm = clientExts.streamConfirm c.sendPacket(0, c.config.BufferRemainingLength) @@ -218,24 +184,6 @@ func (c *wispConnection) handleInfo(streamId uint32, payload []byte) { close(c.handshakeDone) c.handshakeDone = nil } - -func (c *wispConnection) verifyCertificate(exts *extensions) bool { - if exts.certificateSelected&sigEd25519 == 0 { - return false - } - - for _, pubKey := range c.config.CertAuthPublicKeys { - hash := sha256.Sum256([]byte(pubKey)) - if hash == exts.certificatePubkeyHash { - if ed25519.Verify(pubKey, c.v2Challenge, exts.certificateSig) { - return true - } - } - } - - return false -} - func (c *wispConnection) sendRawFrame(packet []byte) { totalLen := len(packet) var frame []byte diff --git a/wisp/wisp-connection.go b/wisp/wisp-connection.go index d0fe085..4a0a778 100644 --- a/wisp/wisp-connection.go +++ b/wisp/wisp-connection.go @@ -53,18 +53,16 @@ const maxConcurrentDials = 50 const maxPendingStreamBytes = 16 * 1024 * 1024 type wispConnection struct { - netConn net.Conn - writeCh chan writeReq - streams sync.Map - cachedStreamId uint32 - cachedStream unsafe.Pointer - isClosed atomic.Bool - shutdownOnce sync.Once - config *Config - twispStreams *twispRegistry - connectLimiter *connectRateLimiter - remoteIP string - handshakeFailures int + netConn net.Conn + writeCh chan writeReq + streams sync.Map + cachedStreamId uint32 + cachedStream unsafe.Pointer + isClosed atomic.Bool + shutdownOnce sync.Once + config *Config + twispStreams *twispRegistry + remoteIP string isV2 bool handshakeDone chan struct{} @@ -72,20 +70,10 @@ type wispConnection struct { v2Challenge []byte authenticated atomic.Bool - dialSem chan struct{} - closeCh chan struct{} - createdAt time.Time - packetLimiter *prot.PacketRateLimiter - inboundLimiter *prot.InboundRateLimiter - streamCount atomic.Int32 -} - -func (c *wispConnection) terminateNetwork() { - c.shutdownOnce.Do(func() { - c.isClosed.Store(true) - close(c.closeCh) - c.netConn.Close() - }) + dialSem chan struct{} + closeCh chan struct{} + createdAt time.Time + streamCount atomic.Int32 } func (c *wispConnection) close() { @@ -154,42 +142,17 @@ func (c *wispConnection) queueWritePooled(data []byte) { } func (c *wispConnection) releaseFrame(data []byte) { - if c.config == nil || c.config.FramePool == nil || len(data) == 0 { + if c.config == nil || len(data) == 0 { return } - if cap(data) < minFramePoolCap { + if cap(data) < 64*1024 { return } buf := data if len(buf) != cap(buf) { buf = data[:cap(data)] } - c.config.FramePool.Put(buf) -} - -func (c *wispConnection) noteHandshakeFailure() { - if c.config == nil || c.config.MaxHandshakeFailures <= 0 { - return - } - c.handshakeFailures++ - if c.handshakeFailures >= c.config.MaxHandshakeFailures { - c.config.Logger.Warn("handshake failures exceeded", "ip", c.remoteIP) - c.terminateNetwork() - } -} - -func (c *wispConnection) lifetimeWatchdog() { - if c.config.MaxConnectionLifetime <= 0 { - return - } - timer := time.NewTimer(c.config.MaxConnectionLifetime) - defer timer.Stop() - select { - case <-timer.C: - c.config.Logger.Warn("connection lifetime exceeded", "ip", c.remoteIP) - c.terminateNetwork() - case <-c.closeCh: - } + // cfg.config.FramePool.Put(buf) } func (c *wispConnection) handlePacket(packetType uint8, streamId uint32, payload []byte) { @@ -224,12 +187,11 @@ func (c *wispConnection) handleConnectPacket(streamId uint32, payload []byte) { hostname := string(payload[3:]) c.config.Logger.Debug("creating stream", "ip", c.remoteIP, "streamId", streamId, "hostname", hostname, "port", port, "type", streamType) - action, normalizedHostname, reason := guard.allowConnect(c, streamType, hostname, port) - if action == connectBlocked { - c.sendClosePacket(streamId, reason) - return - } - if action == connectTwisp { + if streamType == streamTypeTerm { + if !c.config.EnableTwisp { + c.sendClosePacket(streamId, closeReasonBlocked) + return + } go handleTwisp(c, streamId, hostname) return } @@ -264,7 +226,7 @@ func (c *wispConnection) handleConnectPacket(streamId uint32, payload []byte) { } c.streamCount.Add(1) - go stream.handleConnect(streamType, port, normalizedHostname) + go stream.handleConnect(streamType, port, hostname) } func (c *wispConnection) handleDataPacket(streamId uint32, payload []byte) { @@ -315,7 +277,7 @@ func (c *wispConnection) handleDataPacket(streamId uint32, payload []byte) { stream.pendingMutex.Lock() if !stream.connReadyDone.Load() { - if stream.pendingBytes+len(payload) > maxPendingStreamBytes { + if stream.pendingBytes+len(payload) > 16*1024*1024 { stream.pendingMutex.Unlock() stream.close(closeReasonThrottled) return @@ -421,10 +383,7 @@ func (c *wispConnection) deleteWispStream(streamId uint32) { } func (c *wispConnection) deleteAllWispStreams() { - c.terminateNetwork() - if c.config.ConnectionCounter != nil { - c.config.ConnectionCounter.Remove(c.remoteIP) - } + c.close() c.config.Logger.Info("connection closed", "ip", c.remoteIP) c.streams.Range(func(key, value any) bool { stream := value.(*wispStream) diff --git a/wisp/wisp-stream.go b/wisp/wisp-stream.go index a2b991d..6269764 100644 --- a/wisp/wisp-stream.go +++ b/wisp/wisp-stream.go @@ -9,8 +9,6 @@ import ( "sync/atomic" "time" - prot "mrrowisp/wisp/protection" - "golang.org/x/net/proxy" ) @@ -35,15 +33,19 @@ type wispStream struct { const dnsLookupTimeout = 10 * time.Second +func NormalizeTargetHostname(host string) string { + host = strings.TrimSpace(strings.ToLower(host)) + host = strings.TrimSuffix(host, ".") + return host +} + +const dnsLookupTimeout = 10 * time.Second + func (s *wispStream) handleConnect(streamType uint8, port string, hostname string) { defer s.signalConnReady() cfg := s.wispConn.config - s.hostname = prot.NormalizeTargetHostname(hostname) - if s.hostname == "" { - s.close(closeReasonInvalidInfo) - return - } + s.hostname = NormalizeTargetHostname(hostname) guard := newProtection(cfg) @@ -104,7 +106,6 @@ func (s *wispStream) handleConnect(streamType uint8, port string, hostname strin proxyURL = strings.Replace(proxyURL, "socks4a://", "socks4://", 1) dialer, proxyErr := proxy.SOCKS5("tcp", stripScheme(proxyURL), nil, proxy.Direct) if proxyErr != nil { - <-s.wispConn.dialSem cfg.Logger.Warn("proxy dialer creation failed", "ip", s.wispConn.remoteIP, "error", proxyErr) s.close(closeReasonNetworkError) return diff --git a/wisp/wisp.go b/wisp/wisp.go index 20d9974..7802b7c 100644 --- a/wisp/wisp.go +++ b/wisp/wisp.go @@ -1,7 +1,6 @@ package wisp import ( - "crypto/ed25519" "net" "net/http" "strings" @@ -13,191 +12,22 @@ import ( "github.com/lxzan/gws" ) -type PortRange struct { - Min int - Max int -} - -func (r PortRange) Contains(p int) bool { - return p >= r.Min && p <= r.Max -} - -type Config struct { - AllowTCP bool - AllowUDP bool - AllowDirectIP bool - AllowPrivateIPs bool - AllowLoopbackIPs bool - - TcpBufferSize int - BufferRemainingLength uint32 - TcpNoDelay bool - WebsocketTcpNoDelay bool - - StreamLimitPerHost int - StreamLimitTotal int - - Blacklist struct { - Hostnames map[string]struct{} - Ports []PortRange - } - Whitelist struct { - Hostnames map[string]struct{} - Ports []PortRange - } - - Proxy string - WebsocketPermessageDeflate bool - - DnsServers []string - DnsTTLSeconds int - DnsMethod string - DnsResultOrder string - - EnableTwisp bool - - EnableV2 bool - Motd string - PasswordAuth bool - PasswordAuthRequired bool - PasswordUsers map[string]string - CertAuth bool - CertAuthRequired bool - CertAuthPublicKeys []ed25519.PublicKey - EnableStreamConfirm bool - MaxConnectsPerSecond int - - BandwidthLimitKbps int - ConnectionsLimitPerIP int - ConnectionWindowSeconds int - ParseRealIP bool - ParseRealIPFrom map[string]struct{} - - MaxMessageSize int - AllowedOrigins []string - WriteTimeout time.Duration - FrameReadTimeout time.Duration - - NonWSResponse string - LogLevel string - - Logger Logger - BandwidthLimiter *protection.BandwidthLimiter - ConnectionLimiter *protection.ConnectionLimiter - ConnectionCounter *protection.ConnectionCounter - StreamLimiter *protection.StreamLimiter - FramePool *sync.Pool - - DNSCache *DNSCache - ReadBufPool sync.Pool - Dialer net.Dialer - - BanEnabled bool - BanDuration time.Duration - BanMaxStrikes int - BanEscalationMultiplier int - BanList *protection.BanList - MaxHandshakeFailures int - - MaxPacketRate int - MaxConnectionLifetime time.Duration - MaxStreamsPerConnection int - MaxConnectionsPerIP int - GlobalMaxConnections int - WriteQueueSize int - MaxInboundBytesPerSecond int -} - const ( defaultStreamLimitPerHost = 512 defaultStreamLimitTotal = 16384 + defaultMaxConnectsPerSecond = 20 defaultConnectionsLimitPerIP = 120 + defaultHandshakeFailures = 10 ) -func DefaultConfig() *Config { - return &Config{ - AllowTCP: true, - AllowUDP: true, - AllowDirectIP: false, - AllowPrivateIPs: false, - AllowLoopbackIPs: false, - TcpBufferSize: 32768, - BufferRemainingLength: 65536, - TcpNoDelay: true, - WebsocketTcpNoDelay: true, - StreamLimitPerHost: defaultStreamLimitPerHost, - StreamLimitTotal: defaultStreamLimitTotal, - MaxConnectsPerSecond: maxConnectsPerSecond, - PasswordUsers: make(map[string]string), - DnsTTLSeconds: 120, - DnsMethod: "lookup", - DnsResultOrder: "verbatim", - ConnectionWindowSeconds: 1, - ConnectionsLimitPerIP: defaultConnectionsLimitPerIP, - MaxHandshakeFailures: 10, - BanEnabled: true, - BanDuration: time.Hour, - BanMaxStrikes: 10, - BanEscalationMultiplier: 0, - WriteTimeout: 15 * time.Second, - FrameReadTimeout: 30 * time.Second, - MaxPacketRate: 500, - MaxConnectionLifetime: 0, - MaxStreamsPerConnection: 0, - MaxConnectionsPerIP: 0, - GlobalMaxConnections: 0, - WriteQueueSize: 4096, - MaxInboundBytesPerSecond: 0, - } -} - -func (c *Config) InitResolver() { - c.DNSCache = NewDNSCache( +func (cfg *Config) InitResolver() { + cfg.DNSCache = NewDNSCache( DNSCacheConfig{ - Servers: c.DnsServers, - TTLSeconds: c.DnsTTLSeconds, - Method: c.DnsMethod, - ResultOrder: c.DnsResultOrder, + Servers: cfg.DnsServers, + Method: cfg.DnsMethod, + ResultOrder: cfg.DnsResultOrder, }) - if c.LogLevel == "" { - c.LogLevel = "info" - } - if c.Logger == nil { - c.Logger = newLogger(c.LogLevel) - } - if c.BandwidthLimitKbps > 0 { - c.BandwidthLimiter = protection.NewBandwidthLimiter(c.BandwidthLimitKbps, time.Duration(c.ConnectionWindowSeconds)*time.Second) - } - if c.ConnectionsLimitPerIP > 0 { - c.ConnectionLimiter = protection.NewConnectionLimiter(c.ConnectionsLimitPerIP, time.Duration(c.ConnectionWindowSeconds)*time.Second) - } - if c.StreamLimiter == nil { - c.StreamLimiter = protection.NewStreamLimiter() - } - if c.ParseRealIPFrom == nil { - c.ParseRealIPFrom = make(map[string]struct{}) - } - if c.MaxConnectionsPerIP > 0 || c.GlobalMaxConnections > 0 { - c.ConnectionCounter = protection.NewConnectionCounter() - } - if c.BanEnabled { - c.BanList = protection.NewBanListEscalated(c.BanDuration, c.BanMaxStrikes, c.BanEscalationMultiplier) - } - if c.FramePool == nil { - readBufSize := 15 + c.TcpBufferSize - c.FramePool = &sync.Pool{ - New: func() any { - buf := make([]byte, readBufSize) - return buf - }, - } - } - if c.WriteTimeout < 0 { - c.WriteTimeout = 0 - } - if c.FrameReadTimeout < 0 { - c.FrameReadTimeout = 0 - } + cfg.Logger = newLogger(cfg.LogLevel) } type upgradeHandler struct { @@ -208,7 +38,7 @@ func CreateWispHandler(config *Config) http.HandlerFunc { config.InitResolver() readBufSize := 15 + config.TcpBufferSize - config.ReadBufPool = sync.Pool{ + config.ReadBufPool = &sync.Pool{ New: func() any { buf := make([]byte, readBufSize) return &buf @@ -292,9 +122,6 @@ func CreateWispHandler(config *Config) http.HandlerFunc { netConn := wsConn.NetConn() if tc, ok := netConn.(*net.TCPConn); ok { - if config.WebsocketTcpNoDelay { - tc.SetNoDelay(true) - } tc.SetReadBuffer(1 << 20) tc.SetWriteBuffer(1 << 20) } @@ -340,25 +167,9 @@ func CreateWispHandler(config *Config) http.HandlerFunc { } } -func (c *Config) requiresV2() bool { - if c == nil { - return false - } - return c.PasswordAuthRequired || c.CertAuthRequired || c.EnableTwisp -} - -func originAllowed(r *http.Request, allowedOrigins []string) bool { - if len(allowedOrigins) == 0 { - return true - } - origin := strings.TrimSpace(r.Header.Get("Origin")) - if origin == "" { +func (cfg *Config) requiresV2() bool { + if cfg == nil { return false } - for _, allowed := range allowedOrigins { - if origin == strings.TrimSpace(allowed) { - return true - } - } - return false + return cfg.PasswordAuthRequired || cfg.EnableTwisp } diff --git a/wisp/wsreader.go b/wisp/wsreader.go index 4d21716..faef8c1 100644 --- a/wisp/wsreader.go +++ b/wisp/wsreader.go @@ -32,7 +32,6 @@ func (c *wispConnection) readLoop() { if rsv != 0 || !masked || !fin { c.sendWSClose(1002) - c.noteHandshakeFailure() return } @@ -55,7 +54,6 @@ func (c *wispConnection) readLoop() { isControlFrame := opcode >= 0x8 if isControlFrame && payloadLen > 125 { c.sendWSClose(1002) - c.noteHandshakeFailure() return } @@ -68,7 +66,6 @@ func (c *wispConnection) readLoop() { if payloadLen > c.maxPayloadSize() { c.sendWSClose(1009) - c.noteHandshakeFailure() return } @@ -114,7 +111,6 @@ func (c *wispConnection) readLoop() { default: if opcode != 0x0 { - c.noteHandshakeFailure() } continue }