diff --git a/cmd/server/main.go b/cmd/server/main.go index 356ac3d..dce2696 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -46,7 +46,7 @@ func main() { // Initialize services and handlers authService := setupServices(db, cfg) - handlers := setupHandlers(authService, logger, db) + handlers := setupHandlers(authService, logger, db, cfg) // Setup routes and middleware handler := setupRoutes(handlers, logger, authService, cfg) @@ -87,7 +87,8 @@ func runMigrations(cfg *config.Config, logger *logger.Logger) { func setupServices(db *database.Database, cfg *config.Config) services.AuthService { userRepo := repositories.NewUserRepository(db.Pool()) - authService := services.NewAuthService(userRepo, &cfg.JWT) + refreshTokenRepo := repositories.NewRefreshTokenRepository(db.Pool()) + authService := services.NewAuthService(userRepo, refreshTokenRepo, &cfg.JWT) return authService } @@ -96,9 +97,9 @@ type Handlers struct { Health *httphandler.DetailedHealthHandler } -func setupHandlers(authService services.AuthService, logger *logger.Logger, db *database.Database) *Handlers { +func setupHandlers(authService services.AuthService, logger *logger.Logger, db *database.Database, cfg *config.Config) *Handlers { return &Handlers{ - Auth: httphandler.NewAuthHandlers(authService, logger), + Auth: httphandler.NewAuthHandlers(authService, logger, cfg), Health: httphandler.NewDetailedHealthHandler(logger, db.Pool()), } } diff --git a/env.example b/env.example index 0bdd2eb..bdb0b0d 100644 --- a/env.example +++ b/env.example @@ -34,14 +34,14 @@ RATE_LIMIT_BURST_SIZE=10 # CORS Configuration # Comma-separated list of allowed origins -CORS_ALLOWED_ORIGINS=http://localhost:3000,http://localhost:3001,http://localhost:4200 +CORS_ALLOWED_ORIGINS=http://localhost:3000,http://localhost:3001,http://localhost:4200,https://your-frontend-domain.com # Comma-separated list of allowed methods CORS_ALLOWED_METHODS=GET,POST,PUT,DELETE,OPTIONS # Comma-separated list of allowed headers CORS_ALLOWED_HEADERS=Accept,Authorization,Content-Type,X-Request-ID # Comma-separated list of exposed headers CORS_EXPOSED_HEADERS=X-Request-ID -# Allow credentials (true/false) +# Allow credentials (true/false) - REQUIRED for cross-domain cookies CORS_ALLOW_CREDENTIALS=true # Max age for preflight requests in seconds CORS_MAX_AGE=86400 @@ -50,6 +50,7 @@ CORS_MAX_AGE=86400 # Set to 'production' for HTTPS cookies, leave empty for development ENVIRONMENT= + # Security Headers Configuration # HSTS max age in seconds (1 year = 31536000) SECURITY_HSTS_MAX_AGE=31536000 diff --git a/internal/http/auth_handlers.go b/internal/http/auth_handlers.go index b74558d..6fe04ad 100644 --- a/internal/http/auth_handlers.go +++ b/internal/http/auth_handlers.go @@ -3,38 +3,36 @@ package http import ( "encoding/json" "net/http" - "os" + "github.com/aleksandr/strive-api/internal/config" "github.com/aleksandr/strive-api/internal/logger" "github.com/aleksandr/strive-api/internal/models" "github.com/aleksandr/strive-api/internal/services" "github.com/aleksandr/strive-api/internal/validation" ) -const ( - productionEnv = "production" -) - type AuthHandlers struct { authService services.AuthService logger *logger.Logger securityLogger *SecurityLogger + config *config.Config } -func NewAuthHandlers(authService services.AuthService, logger *logger.Logger) *AuthHandlers { +func NewAuthHandlers(authService services.AuthService, logger *logger.Logger, cfg *config.Config) *AuthHandlers { return &AuthHandlers{ authService: authService, logger: logger, securityLogger: NewSecurityLogger(logger), + config: cfg, } } -func getCookieSettings() bool { - return os.Getenv("ENVIRONMENT") == productionEnv +func (h *AuthHandlers) getCookieSettings() (secure bool, sameSite http.SameSite) { + return false, http.SameSiteNoneMode } -func setSecureCookie(w http.ResponseWriter, name, value string, maxAge int) { - secure := getCookieSettings() +func (h *AuthHandlers) setSecureCookie(w http.ResponseWriter, name, value string, maxAge int) { + secure, sameSite := h.getCookieSettings() cookie := &http.Cookie{ Name: name, @@ -42,7 +40,7 @@ func setSecureCookie(w http.ResponseWriter, name, value string, maxAge int) { Path: "/", Secure: secure, HttpOnly: true, - SameSite: http.SameSiteStrictMode, + SameSite: sameSite, MaxAge: maxAge, } @@ -64,11 +62,10 @@ type RefreshRequest struct { } type AuthResponse struct { - AccessToken string `json:"access_token,omitempty" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."` - RefreshToken string `json:"refresh_token,omitempty" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."` - ExpiresIn int `json:"expires_in" example:"900"` - TokenType string `json:"token_type" example:"Bearer"` - Message string `json:"message,omitempty" example:"Login successful"` + AccessToken string `json:"access_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."` + ExpiresIn int `json:"expires_in" example:"900"` + TokenType string `json:"token_type" example:"Bearer"` + Message string `json:"message,omitempty" example:"Login successful"` } type ErrorResponse struct { @@ -204,13 +201,13 @@ func (h *AuthHandlers) Login(w http.ResponseWriter, r *http.Request) { h.logger.Info("User logged in successfully", "email", req.Email) - setSecureCookie(w, "access-token", accessToken, 900) - setSecureCookie(w, "refresh-token", refreshToken, 604800) + h.setSecureCookie(w, "refresh-token", refreshToken, 604800) response := AuthResponse{ - ExpiresIn: 900, - TokenType: "Bearer", - Message: "Login successful", + AccessToken: accessToken, + ExpiresIn: 900, + TokenType: "Bearer", + Message: "Login successful", } w.Header().Set("Content-Type", "application/json") @@ -254,13 +251,13 @@ func (h *AuthHandlers) Refresh(w http.ResponseWriter, r *http.Request) { h.logger.Info("Token refreshed successfully") - setSecureCookie(w, "access-token", accessToken, 900) - setSecureCookie(w, "refresh-token", refreshToken, 604800) + h.setSecureCookie(w, "refresh-token", refreshToken, 604800) response := AuthResponse{ - ExpiresIn: 900, - TokenType: "Bearer", - Message: "Token refreshed successfully", + AccessToken: accessToken, + ExpiresIn: 900, + TokenType: "Bearer", + Message: "Token refreshed successfully", } w.Header().Set("Content-Type", "application/json") @@ -277,8 +274,14 @@ func (h *AuthHandlers) Refresh(w http.ResponseWriter, r *http.Request) { // @Success 200 {object} map[string]interface{} "Logout successful" // @Router /api/v1/auth/logout [post] func (h *AuthHandlers) Logout(w http.ResponseWriter, r *http.Request) { - setSecureCookie(w, "access-token", "", -1) - setSecureCookie(w, "refresh-token", "", -1) + refreshTokenCookie, err := r.Cookie("refresh-token") + if err == nil && refreshTokenCookie.Value != "" { + if err := h.authService.Logout(r.Context(), refreshTokenCookie.Value); err != nil { + h.logger.Error("Failed to logout user", "error", err) + } + } + + h.setSecureCookie(w, "refresh-token", "", -1) h.logger.Info("User logged out successfully") diff --git a/internal/http/auth_handlers_test.go b/internal/http/auth_handlers_test.go index 3011da7..ece3efe 100644 --- a/internal/http/auth_handlers_test.go +++ b/internal/http/auth_handlers_test.go @@ -8,6 +8,7 @@ import ( "net/http/httptest" "testing" + "github.com/aleksandr/strive-api/internal/config" "github.com/aleksandr/strive-api/internal/logger" "github.com/aleksandr/strive-api/internal/models" "github.com/google/uuid" @@ -71,7 +72,8 @@ func TestAuthHandlers_Register(t *testing.T) { mockService := new(MockAuthService) tt.mockSetup(mockService) - handlers := NewAuthHandlers(mockService, logger) + cfg := &config.Config{} + handlers := NewAuthHandlers(mockService, logger, cfg) body, _ := json.Marshal(tt.requestBody) req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", bytes.NewReader(body)) @@ -152,7 +154,8 @@ func TestAuthHandlers_Login(t *testing.T) { mockService := new(MockAuthService) tt.mockSetup(mockService) - handlers := NewAuthHandlers(mockService, logger) + cfg := &config.Config{} + handlers := NewAuthHandlers(mockService, logger, cfg) body, _ := json.Marshal(tt.requestBody) req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", bytes.NewReader(body)) @@ -169,20 +172,22 @@ func TestAuthHandlers_Login(t *testing.T) { assert.NoError(t, err) assert.Contains(t, response, "error") } else { - // Check that cookies are set instead of JSON tokens + // Check that access token is in JSON response and refresh token is in cookie + var response map[string]interface{} + err := json.Unmarshal(rr.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Contains(t, response, "access_token") + assert.Equal(t, "access_token", response["access_token"]) + cookies := rr.Result().Cookies() - var accessTokenCookie, refreshTokenCookie *http.Cookie + var refreshTokenCookie *http.Cookie for _, cookie := range cookies { - if cookie.Name == "access-token" { - accessTokenCookie = cookie - } else if cookie.Name == "refresh-token" { + if cookie.Name == "refresh-token" { refreshTokenCookie = cookie } } - assert.NotNil(t, accessTokenCookie, "access-token cookie should be set") assert.NotNil(t, refreshTokenCookie, "refresh-token cookie should be set") - assert.Equal(t, "access_token", accessTokenCookie.Value) assert.Equal(t, "refresh_token", refreshTokenCookie.Value) } @@ -215,7 +220,8 @@ func TestAuthHandlers_Me(t *testing.T) { mockService := &MockAuthService{} tt.mockSetup(mockService) - handlers := NewAuthHandlers(mockService, logger) + cfg := &config.Config{} + handlers := NewAuthHandlers(mockService, logger, cfg) req := httptest.NewRequest("GET", "/api/v1/auth/me", http.NoBody) rr := httptest.NewRecorder() diff --git a/internal/http/auth_middleware.go b/internal/http/auth_middleware.go index 2d8cb12..f4f0059 100644 --- a/internal/http/auth_middleware.go +++ b/internal/http/auth_middleware.go @@ -56,34 +56,20 @@ func logAuthFailure(log *logger.Logger, r *http.Request, reason string) { func AuthMiddleware(authService services.AuthService, log *logger.Logger) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var tokenString string - var authSource string - - // Try to get token from Authorization header first authHeader := r.Header.Get("Authorization") - if authHeader != "" { - parts := strings.SplitN(authHeader, " ", 2) - if len(parts) == 2 && parts[0] == "Bearer" && parts[1] != "" { - tokenString = parts[1] - authSource = "header" - } - } - - // If no token in header, try to get from cookie - if tokenString == "" { - accessTokenCookie, err := r.Cookie("access-token") - if err == nil && accessTokenCookie.Value != "" { - tokenString = accessTokenCookie.Value - authSource = "cookie" - } + if authHeader == "" { + writeAuthError(w, log, r, "UNAUTHORIZED", "Authentication required", "missing_authorization_header") + return } - // If still no token, return error - if tokenString == "" { - writeAuthError(w, log, r, "UNAUTHORIZED", "Authentication required", "missing_token") + parts := strings.SplitN(authHeader, " ", 2) + if len(parts) != 2 || parts[0] != "Bearer" || parts[1] == "" { + writeAuthError(w, log, r, "UNAUTHORIZED", "Invalid authorization header format", "invalid_authorization_format") return } + tokenString := parts[1] + claims, err := authService.ValidateToken(tokenString) if err != nil { var code, message, reason string @@ -107,11 +93,9 @@ func AuthMiddleware(authService services.AuthService, log *logger.Logger) func(h return } - // Log successful authentication with source log.Debug("Authentication successful", "user_id", claims.UserID, - "email", claims.Email, - "source", authSource) + "email", claims.Email) ctx := context.WithValue(r.Context(), UserIDKey, claims.UserID.String()) ctx = context.WithValue(ctx, UserEmailKey, claims.Email) diff --git a/internal/http/auth_middleware_test.go b/internal/http/auth_middleware_test.go index 7824e09..912139a 100644 --- a/internal/http/auth_middleware_test.go +++ b/internal/http/auth_middleware_test.go @@ -46,6 +46,10 @@ func (m *mockAuthService) VerifyPassword(hashedPassword, password string) error return nil } +func (m *mockAuthService) Logout(ctx context.Context, refreshToken string) error { + return nil +} + func TestAuthMiddleware(t *testing.T) { log := logger.New("INFO", "json") diff --git a/internal/http/mocks_test.go b/internal/http/mocks_test.go index fad5234..4fa0215 100644 --- a/internal/http/mocks_test.go +++ b/internal/http/mocks_test.go @@ -47,3 +47,8 @@ func (m *MockAuthService) RefreshToken(ctx context.Context, refreshToken string) args := m.Called(ctx, refreshToken) return args.String(0), args.String(1), args.Error(2) } + +func (m *MockAuthService) Logout(ctx context.Context, refreshToken string) error { + args := m.Called(ctx, refreshToken) + return args.Error(0) +} diff --git a/internal/models/refresh_token.go b/internal/models/refresh_token.go new file mode 100644 index 0000000..cbdd6bf --- /dev/null +++ b/internal/models/refresh_token.go @@ -0,0 +1,16 @@ +package models + +import ( + "time" + + "github.com/google/uuid" +) + +type RefreshToken struct { + ID uuid.UUID `json:"id" db:"id"` + UserID uuid.UUID `json:"user_id" db:"user_id"` + Token string `json:"token" db:"token"` + ExpiresAt time.Time `json:"expires_at" db:"expires_at"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` +} diff --git a/internal/repositories/refresh_token_repository.go b/internal/repositories/refresh_token_repository.go new file mode 100644 index 0000000..35b4465 --- /dev/null +++ b/internal/repositories/refresh_token_repository.go @@ -0,0 +1,137 @@ +package repositories + +import ( + "context" + "fmt" + + "github.com/aleksandr/strive-api/internal/models" + "github.com/google/uuid" + "github.com/jackc/pgx/v5/pgxpool" +) + +type RefreshTokenRepository interface { + Create(ctx context.Context, token *models.RefreshToken) error + GetByToken(ctx context.Context, token string) (*models.RefreshToken, error) + GetByUserID(ctx context.Context, userID uuid.UUID) ([]*models.RefreshToken, error) + Delete(ctx context.Context, token string) error + DeleteByUserID(ctx context.Context, userID uuid.UUID) error + DeleteExpired(ctx context.Context) error +} + +type refreshTokenRepository struct { + pool *pgxpool.Pool +} + +func NewRefreshTokenRepository(pool *pgxpool.Pool) RefreshTokenRepository { + return &refreshTokenRepository{ + pool: pool, + } +} + +func (r *refreshTokenRepository) Create(ctx context.Context, token *models.RefreshToken) error { + query := ` + INSERT INTO refresh_tokens (id, user_id, token, expires_at, created_at, updated_at) + VALUES ($1, $2, $3, $4, $5, $6) + ` + + _, err := r.pool.Exec(ctx, query, token.ID, token.UserID, token.Token, token.ExpiresAt, token.CreatedAt, token.UpdatedAt) + if err != nil { + return fmt.Errorf("failed to create refresh token: %w", err) + } + + return nil +} + +func (r *refreshTokenRepository) GetByToken(ctx context.Context, token string) (*models.RefreshToken, error) { + query := ` + SELECT id, user_id, token, expires_at, created_at, updated_at + FROM refresh_tokens + WHERE token = $1 AND expires_at > NOW() + ` + + refreshToken := &models.RefreshToken{} + err := r.pool.QueryRow(ctx, query, token).Scan( + &refreshToken.ID, + &refreshToken.UserID, + &refreshToken.Token, + &refreshToken.ExpiresAt, + &refreshToken.CreatedAt, + &refreshToken.UpdatedAt, + ) + if err != nil { + return nil, fmt.Errorf("failed to get refresh token: %w", err) + } + + return refreshToken, nil +} + +func (r *refreshTokenRepository) GetByUserID(ctx context.Context, userID uuid.UUID) ([]*models.RefreshToken, error) { + query := ` + SELECT id, user_id, token, expires_at, created_at, updated_at + FROM refresh_tokens + WHERE user_id = $1 AND expires_at > NOW() + ORDER BY created_at DESC + ` + + rows, err := r.pool.Query(ctx, query, userID) + if err != nil { + return nil, fmt.Errorf("failed to get refresh tokens by user id: %w", err) + } + defer rows.Close() + + var tokens []*models.RefreshToken + for rows.Next() { + token := &models.RefreshToken{} + err := rows.Scan( + &token.ID, + &token.UserID, + &token.Token, + &token.ExpiresAt, + &token.CreatedAt, + &token.UpdatedAt, + ) + if err != nil { + return nil, fmt.Errorf("failed to scan refresh token: %w", err) + } + tokens = append(tokens, token) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("failed to iterate refresh tokens: %w", err) + } + + return tokens, nil +} + +func (r *refreshTokenRepository) Delete(ctx context.Context, token string) error { + query := `DELETE FROM refresh_tokens WHERE token = $1` + + _, err := r.pool.Exec(ctx, query, token) + if err != nil { + return fmt.Errorf("failed to delete refresh token: %w", err) + } + + return nil +} + +func (r *refreshTokenRepository) DeleteByUserID(ctx context.Context, userID uuid.UUID) error { + query := `DELETE FROM refresh_tokens WHERE user_id = $1` + + _, err := r.pool.Exec(ctx, query, userID) + if err != nil { + return fmt.Errorf("failed to delete refresh tokens by user id: %w", err) + } + + return nil +} + +func (r *refreshTokenRepository) DeleteExpired(ctx context.Context) error { + query := `DELETE FROM refresh_tokens WHERE expires_at <= NOW()` + + _, err := r.pool.Exec(ctx, query) + if err != nil { + return fmt.Errorf("failed to delete expired refresh tokens: %w", err) + } + + return nil +} diff --git a/internal/services/auth_service.go b/internal/services/auth_service.go index 433356b..aba76d7 100644 --- a/internal/services/auth_service.go +++ b/internal/services/auth_service.go @@ -2,6 +2,8 @@ package services import ( "context" + "crypto/rand" + "encoding/hex" "errors" "fmt" "strings" @@ -30,6 +32,7 @@ type AuthService interface { ValidateToken(tokenString string) (*Claims, error) HashPassword(password string) (string, error) VerifyPassword(hashedPassword, password string) error + Logout(ctx context.Context, refreshToken string) error } type Claims struct { @@ -39,18 +42,24 @@ type Claims struct { } type authService struct { - userRepo repositories.UserRepository - config *config.JWTConfig - accessTTL time.Duration - refreshTTL time.Duration + userRepo repositories.UserRepository + refreshTokenRepo repositories.RefreshTokenRepository + config *config.JWTConfig + accessTTL time.Duration + refreshTTL time.Duration } -func NewAuthService(userRepo repositories.UserRepository, jwtConfig *config.JWTConfig) AuthService { +func NewAuthService( + userRepo repositories.UserRepository, + refreshTokenRepo repositories.RefreshTokenRepository, + jwtConfig *config.JWTConfig, +) AuthService { return &authService{ - userRepo: userRepo, - config: jwtConfig, - accessTTL: 15 * time.Minute, - refreshTTL: 7 * 24 * time.Hour, + userRepo: userRepo, + refreshTokenRepo: refreshTokenRepo, + config: jwtConfig, + accessTTL: 15 * time.Minute, + refreshTTL: 7 * 24 * time.Hour, } } @@ -103,21 +112,34 @@ func (s *authService) Login(ctx context.Context, email, password string) (string return "", "", fmt.Errorf("failed to generate access token: %w", err) } - refreshToken, err := s.generateToken(user, s.refreshTTL) + refreshToken, err := s.generateRefreshToken() if err != nil { return "", "", fmt.Errorf("failed to generate refresh token: %w", err) } + refreshTokenModel := &models.RefreshToken{ + ID: uuid.New(), + UserID: user.ID, + Token: refreshToken, + ExpiresAt: time.Now().Add(s.refreshTTL), + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + if err := s.refreshTokenRepo.Create(ctx, refreshTokenModel); err != nil { + return "", "", fmt.Errorf("failed to save refresh token: %w", err) + } + return accessToken, refreshToken, nil } func (s *authService) RefreshToken(ctx context.Context, refreshToken string) (string, string, error) { - claims, err := s.ValidateToken(refreshToken) + refreshTokenModel, err := s.refreshTokenRepo.GetByToken(ctx, refreshToken) if err != nil { return "", "", fmt.Errorf("invalid refresh token") } - user, err := s.userRepo.GetByID(ctx, claims.UserID) + user, err := s.userRepo.GetByID(ctx, refreshTokenModel.UserID) if err != nil { return "", "", fmt.Errorf("user not found") } @@ -127,11 +149,28 @@ func (s *authService) RefreshToken(ctx context.Context, refreshToken string) (st return "", "", fmt.Errorf("failed to generate access token: %w", err) } - newRefreshToken, err := s.generateToken(user, s.refreshTTL) + newRefreshToken, err := s.generateRefreshToken() if err != nil { return "", "", fmt.Errorf("failed to generate refresh token: %w", err) } + if err := s.refreshTokenRepo.Delete(ctx, refreshToken); err != nil { + return "", "", fmt.Errorf("failed to delete old refresh token: %w", err) + } + + newRefreshTokenModel := &models.RefreshToken{ + ID: uuid.New(), + UserID: user.ID, + Token: newRefreshToken, + ExpiresAt: time.Now().Add(s.refreshTTL), + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + if err := s.refreshTokenRepo.Create(ctx, newRefreshTokenModel); err != nil { + return "", "", fmt.Errorf("failed to save new refresh token: %w", err) + } + return accessToken, newRefreshToken, nil } @@ -215,6 +254,21 @@ func (s *authService) generateToken(user *models.User, ttl time.Duration) (strin return token.SignedString([]byte(s.config.Secret)) } +func (s *authService) generateRefreshToken() (string, error) { + bytes := make([]byte, 32) + if _, err := rand.Read(bytes); err != nil { + return "", fmt.Errorf("failed to generate random bytes: %w", err) + } + return hex.EncodeToString(bytes), nil +} + +func (s *authService) Logout(ctx context.Context, refreshToken string) error { + if err := s.refreshTokenRepo.Delete(ctx, refreshToken); err != nil { + return fmt.Errorf("failed to delete refresh token: %w", err) + } + return nil +} + func (s *authService) addLoginDelay() { time.Sleep(500 * time.Millisecond) } diff --git a/internal/services/auth_service_test.go b/internal/services/auth_service_test.go index 3e1d626..de5fded 100644 --- a/internal/services/auth_service_test.go +++ b/internal/services/auth_service_test.go @@ -55,17 +55,71 @@ func (m *mockUserRepository) Delete(ctx context.Context, id uuid.UUID) error { return fmt.Errorf("user not found") } +type mockRefreshTokenRepository struct { + tokens map[string]*models.RefreshToken +} + +func (m *mockRefreshTokenRepository) Create(ctx context.Context, token *models.RefreshToken) error { + m.tokens[token.Token] = token + return nil +} + +func (m *mockRefreshTokenRepository) GetByToken(ctx context.Context, token string) (*models.RefreshToken, error) { + refreshToken, exists := m.tokens[token] + if !exists { + return nil, fmt.Errorf("refresh token not found") + } + return refreshToken, nil +} + +func (m *mockRefreshTokenRepository) GetByUserID(ctx context.Context, userID uuid.UUID) ([]*models.RefreshToken, error) { + var tokens []*models.RefreshToken + for _, token := range m.tokens { + if token.UserID == userID { + tokens = append(tokens, token) + } + } + return tokens, nil +} + +func (m *mockRefreshTokenRepository) Delete(ctx context.Context, token string) error { + delete(m.tokens, token) + return nil +} + +func (m *mockRefreshTokenRepository) DeleteByUserID(ctx context.Context, userID uuid.UUID) error { + for token, refreshToken := range m.tokens { + if refreshToken.UserID == userID { + delete(m.tokens, token) + } + } + return nil +} + +func (m *mockRefreshTokenRepository) DeleteExpired(ctx context.Context) error { + now := time.Now() + for token, refreshToken := range m.tokens { + if refreshToken.ExpiresAt.Before(now) { + delete(m.tokens, token) + } + } + return nil +} + func TestAuthService_Register(t *testing.T) { mockRepo := &mockUserRepository{ users: make(map[string]*models.User), } + mockRefreshRepo := &mockRefreshTokenRepository{ + tokens: make(map[string]*models.RefreshToken), + } jwtConfig := &config.JWTConfig{ Secret: "test-secret", Issuer: "test-issuer", Audience: "test-audience", ClockSkew: 1 * time.Minute, } - authService := NewAuthService(mockRepo, jwtConfig) + authService := NewAuthService(mockRepo, mockRefreshRepo, jwtConfig) req := &models.CreateUserRequest{ Email: "test@example.com", @@ -94,13 +148,16 @@ func TestAuthService_Login(t *testing.T) { mockRepo := &mockUserRepository{ users: make(map[string]*models.User), } + mockRefreshRepo := &mockRefreshTokenRepository{ + tokens: make(map[string]*models.RefreshToken), + } jwtConfig := &config.JWTConfig{ Secret: "test-secret", Issuer: "test-issuer", Audience: "test-audience", ClockSkew: 1 * time.Minute, } - authService := NewAuthService(mockRepo, jwtConfig) + authService := NewAuthService(mockRepo, mockRefreshRepo, jwtConfig) // First register a user req := &models.CreateUserRequest{ @@ -131,13 +188,16 @@ func TestAuthService_LoginCaseInsensitive(t *testing.T) { mockRepo := &mockUserRepository{ users: make(map[string]*models.User), } + mockRefreshRepo := &mockRefreshTokenRepository{ + tokens: make(map[string]*models.RefreshToken), + } jwtConfig := &config.JWTConfig{ Secret: "test-secret", Issuer: "test-issuer", Audience: "test-audience", ClockSkew: 1 * time.Minute, } - authService := NewAuthService(mockRepo, jwtConfig) + authService := NewAuthService(mockRepo, mockRefreshRepo, jwtConfig) // Register user with lowercase email req := &models.CreateUserRequest{ @@ -168,13 +228,16 @@ func TestAuthService_HashPassword(t *testing.T) { mockRepo := &mockUserRepository{ users: make(map[string]*models.User), } + mockRefreshRepo := &mockRefreshTokenRepository{ + tokens: make(map[string]*models.RefreshToken), + } jwtConfig := &config.JWTConfig{ Secret: "test-secret", Issuer: "test-issuer", Audience: "test-audience", ClockSkew: 1 * time.Minute, } - authService := NewAuthService(mockRepo, jwtConfig) + authService := NewAuthService(mockRepo, mockRefreshRepo, jwtConfig) password := "testpassword123" hashed, err := authService.HashPassword(password) diff --git a/migrations/000002_refresh_tokens.down.sql b/migrations/000002_refresh_tokens.down.sql new file mode 100644 index 0000000..2afa74d --- /dev/null +++ b/migrations/000002_refresh_tokens.down.sql @@ -0,0 +1,2 @@ +-- Drop refresh_tokens table +DROP TABLE IF EXISTS refresh_tokens CASCADE; diff --git a/migrations/000002_refresh_tokens.up.sql b/migrations/000002_refresh_tokens.up.sql new file mode 100644 index 0000000..0a83382 --- /dev/null +++ b/migrations/000002_refresh_tokens.up.sql @@ -0,0 +1,18 @@ +-- Create refresh_tokens table +CREATE TABLE IF NOT EXISTS refresh_tokens ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + token VARCHAR(255) UNIQUE NOT NULL, + expires_at TIMESTAMP WITH TIME ZONE NOT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() +); + +-- Create index for faster lookups +CREATE INDEX idx_refresh_tokens_token ON refresh_tokens(token); +CREATE INDEX idx_refresh_tokens_user_id ON refresh_tokens(user_id); +CREATE INDEX idx_refresh_tokens_expires_at ON refresh_tokens(expires_at); + +-- Create trigger for updated_at +CREATE TRIGGER update_refresh_tokens_updated_at BEFORE UPDATE ON refresh_tokens + FOR EACH ROW EXECUTE FUNCTION update_updated_at_column();