diff --git a/env.example b/env.example index f1d67a9..adb19c2 100644 --- a/env.example +++ b/env.example @@ -29,6 +29,7 @@ JWT_CLOCK_SKEW=2m # Rate Limiting Configuration RATE_LIMIT_ENABLED=true RATE_LIMIT_AUTH_PER_MINUTE=5 +RATE_LIMIT_REFRESH_PER_MINUTE=20 RATE_LIMIT_GENERAL_PER_MINUTE=60 RATE_LIMIT_BURST_SIZE=10 diff --git a/internal/config/config.go b/internal/config/config.go index d8ebbd4..034775e 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -56,6 +56,7 @@ type JWTConfig struct { type RateLimitConfig struct { AuthRequestsPerMinute int + RefreshRequestsPerMinute int GeneralRequestsPerMinute int BurstSize int Enabled bool @@ -115,6 +116,7 @@ func Load() (*Config, error) { }, RateLimit: RateLimitConfig{ AuthRequestsPerMinute: getEnvInt("RATE_LIMIT_AUTH_PER_MINUTE", 5), + RefreshRequestsPerMinute: getEnvInt("RATE_LIMIT_REFRESH_PER_MINUTE", 20), GeneralRequestsPerMinute: getEnvInt("RATE_LIMIT_GENERAL_PER_MINUTE", 60), BurstSize: getEnvInt("RATE_LIMIT_BURST_SIZE", 10), Enabled: getEnv("RATE_LIMIT_ENABLED", trueStr) == trueStr, diff --git a/internal/http/rate_limit_middleware.go b/internal/http/rate_limit_middleware.go index a093464..50e563d 100644 --- a/internal/http/rate_limit_middleware.go +++ b/internal/http/rate_limit_middleware.go @@ -130,7 +130,9 @@ func (rl *RateLimiter) RateLimitMiddleware() func(http.Handler) http.Handler { clientID := getClientIP(r) limit := rl.config.GeneralRequestsPerMinute - if IsAuthEndpoint(r.URL.Path) { + if IsRefreshEndpoint(r.URL.Path) { + limit = rl.config.RefreshRequestsPerMinute + } else if IsAuthEndpoint(r.URL.Path) { limit = rl.config.AuthRequestsPerMinute } @@ -144,11 +146,14 @@ func (rl *RateLimiter) RateLimitMiddleware() func(http.Handler) http.Handler { } } +func IsRefreshEndpoint(path string) bool { + return path == "/api/v1/auth/refresh" +} + func IsAuthEndpoint(path string) bool { authPaths := []string{ "/api/v1/auth/login", "/api/v1/auth/register", - "/api/v1/auth/refresh", } for _, authPath := range authPaths { diff --git a/internal/http/rate_limit_middleware_test.go b/internal/http/rate_limit_middleware_test.go index 462ad29..a098c88 100644 --- a/internal/http/rate_limit_middleware_test.go +++ b/internal/http/rate_limit_middleware_test.go @@ -14,6 +14,7 @@ const testClientIP = "192.168.1.1:12345" func TestRateLimiter_GeneralRequests(t *testing.T) { cfg := &config.RateLimitConfig{ AuthRequestsPerMinute: 5, + RefreshRequestsPerMinute: 20, GeneralRequestsPerMinute: 3, BurstSize: 5, Enabled: true, @@ -51,6 +52,7 @@ func TestRateLimiter_GeneralRequests(t *testing.T) { func TestRateLimiter_AuthRequests(t *testing.T) { cfg := &config.RateLimitConfig{ AuthRequestsPerMinute: 2, + RefreshRequestsPerMinute: 20, GeneralRequestsPerMinute: 10, BurstSize: 5, Enabled: true, @@ -85,9 +87,48 @@ func TestRateLimiter_AuthRequests(t *testing.T) { } } +func TestRateLimiter_RefreshRequests(t *testing.T) { + cfg := &config.RateLimitConfig{ + AuthRequestsPerMinute: 2, + RefreshRequestsPerMinute: 5, + GeneralRequestsPerMinute: 10, + BurstSize: 5, + Enabled: true, + } + + log := logger.New("INFO", "text") + rateLimiter := NewRateLimiter(cfg, log) + + handler := rateLimiter.RateLimitMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("OK")) + })) + + // Test refresh endpoint + req := httptest.NewRequest("POST", "/api/v1/auth/refresh", http.NoBody) + req.RemoteAddr = testClientIP + + // First 5 requests should succeed + for i := 0; i < 5; i++ { + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Errorf("Request %d: expected status 200, got %d", i+1, w.Code) + } + } + + // 6th request should be rate limited + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + if w.Code != http.StatusTooManyRequests { + t.Errorf("Expected status 429, got %d", w.Code) + } +} + func TestRateLimiter_Disabled(t *testing.T) { cfg := &config.RateLimitConfig{ AuthRequestsPerMinute: 1, + RefreshRequestsPerMinute: 1, GeneralRequestsPerMinute: 1, BurstSize: 1, Enabled: false, @@ -117,6 +158,7 @@ func TestRateLimiter_Disabled(t *testing.T) { func TestRateLimiter_DifferentClients(t *testing.T) { cfg := &config.RateLimitConfig{ AuthRequestsPerMinute: 2, + RefreshRequestsPerMinute: 20, GeneralRequestsPerMinute: 2, BurstSize: 5, Enabled: true, @@ -162,7 +204,7 @@ func TestIsAuthEndpoint(t *testing.T) { }{ {"/api/v1/auth/login", true}, {"/api/v1/auth/register", true}, - {"/api/v1/auth/refresh", true}, + {"/api/v1/auth/refresh", false}, {"/health", false}, {"/api/v1/user/profile", false}, {"/swagger/", false}, @@ -177,3 +219,26 @@ func TestIsAuthEndpoint(t *testing.T) { }) } } + +func TestIsRefreshEndpoint(t *testing.T) { + tests := []struct { + path string + expected bool + }{ + {"/api/v1/auth/refresh", true}, + {"/api/v1/auth/login", false}, + {"/api/v1/auth/register", false}, + {"/health", false}, + {"/api/v1/user/profile", false}, + {"/swagger/", false}, + } + + for _, tt := range tests { + t.Run(tt.path, func(t *testing.T) { + result := IsRefreshEndpoint(tt.path) + if result != tt.expected { + t.Errorf("isRefreshEndpoint(%s) = %v, expected %v", tt.path, result, tt.expected) + } + }) + } +} diff --git a/server b/server index 4b9dfe8..a53fd9c 100755 Binary files a/server and b/server differ