Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions cmd/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func main() {

// Initialize services and handlers
authService := setupServices(db, cfg)
handlers := setupHandlers(authService, logger, db)
handlers := setupHandlers(authService, logger, db, cfg)

// Setup routes and middleware
handler := setupRoutes(handlers, logger, authService, cfg)
Expand Down Expand Up @@ -87,7 +87,8 @@ func runMigrations(cfg *config.Config, logger *logger.Logger) {

func setupServices(db *database.Database, cfg *config.Config) services.AuthService {
userRepo := repositories.NewUserRepository(db.Pool())
authService := services.NewAuthService(userRepo, &cfg.JWT)
refreshTokenRepo := repositories.NewRefreshTokenRepository(db.Pool())
authService := services.NewAuthService(userRepo, refreshTokenRepo, &cfg.JWT)
return authService
}

Expand All @@ -96,9 +97,9 @@ type Handlers struct {
Health *httphandler.DetailedHealthHandler
}

func setupHandlers(authService services.AuthService, logger *logger.Logger, db *database.Database) *Handlers {
func setupHandlers(authService services.AuthService, logger *logger.Logger, db *database.Database, cfg *config.Config) *Handlers {
return &Handlers{
Auth: httphandler.NewAuthHandlers(authService, logger),
Auth: httphandler.NewAuthHandlers(authService, logger, cfg),
Health: httphandler.NewDetailedHealthHandler(logger, db.Pool()),
}
}
Expand Down
5 changes: 3 additions & 2 deletions env.example
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@ RATE_LIMIT_BURST_SIZE=10

# CORS Configuration
# Comma-separated list of allowed origins
CORS_ALLOWED_ORIGINS=http://localhost:3000,http://localhost:3001,http://localhost:4200
CORS_ALLOWED_ORIGINS=http://localhost:3000,http://localhost:3001,http://localhost:4200,https://your-frontend-domain.com
# Comma-separated list of allowed methods
CORS_ALLOWED_METHODS=GET,POST,PUT,DELETE,OPTIONS
# Comma-separated list of allowed headers
CORS_ALLOWED_HEADERS=Accept,Authorization,Content-Type,X-Request-ID
# Comma-separated list of exposed headers
CORS_EXPOSED_HEADERS=X-Request-ID
# Allow credentials (true/false)
# Allow credentials (true/false) - REQUIRED for cross-domain cookies
CORS_ALLOW_CREDENTIALS=true
# Max age for preflight requests in seconds
CORS_MAX_AGE=86400
Expand All @@ -50,6 +50,7 @@ CORS_MAX_AGE=86400
# Set to 'production' for HTTPS cookies, leave empty for development
ENVIRONMENT=


# Security Headers Configuration
# HSTS max age in seconds (1 year = 31536000)
SECURITY_HSTS_MAX_AGE=31536000
Expand Down
59 changes: 31 additions & 28 deletions internal/http/auth_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,46 +3,44 @@ package http
import (
"encoding/json"
"net/http"
"os"

"github.com/aleksandr/strive-api/internal/config"
"github.com/aleksandr/strive-api/internal/logger"
"github.com/aleksandr/strive-api/internal/models"
"github.com/aleksandr/strive-api/internal/services"
"github.com/aleksandr/strive-api/internal/validation"
)

const (
productionEnv = "production"
)

type AuthHandlers struct {
authService services.AuthService
logger *logger.Logger
securityLogger *SecurityLogger
config *config.Config
}

func NewAuthHandlers(authService services.AuthService, logger *logger.Logger) *AuthHandlers {
func NewAuthHandlers(authService services.AuthService, logger *logger.Logger, cfg *config.Config) *AuthHandlers {
return &AuthHandlers{
authService: authService,
logger: logger,
securityLogger: NewSecurityLogger(logger),
config: cfg,
}
}

func getCookieSettings() bool {
return os.Getenv("ENVIRONMENT") == productionEnv
func (h *AuthHandlers) getCookieSettings() (secure bool, sameSite http.SameSite) {
return false, http.SameSiteNoneMode
}

func setSecureCookie(w http.ResponseWriter, name, value string, maxAge int) {
secure := getCookieSettings()
func (h *AuthHandlers) setSecureCookie(w http.ResponseWriter, name, value string, maxAge int) {
secure, sameSite := h.getCookieSettings()

cookie := &http.Cookie{
Name: name,
Value: value,
Path: "/",
Secure: secure,
HttpOnly: true,
SameSite: http.SameSiteStrictMode,
SameSite: sameSite,
MaxAge: maxAge,
}

Expand All @@ -64,11 +62,10 @@ type RefreshRequest struct {
}

type AuthResponse struct {
AccessToken string `json:"access_token,omitempty" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."`
RefreshToken string `json:"refresh_token,omitempty" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."`
ExpiresIn int `json:"expires_in" example:"900"`
TokenType string `json:"token_type" example:"Bearer"`
Message string `json:"message,omitempty" example:"Login successful"`
AccessToken string `json:"access_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."`
ExpiresIn int `json:"expires_in" example:"900"`
TokenType string `json:"token_type" example:"Bearer"`
Message string `json:"message,omitempty" example:"Login successful"`
}

type ErrorResponse struct {
Expand Down Expand Up @@ -204,13 +201,13 @@ func (h *AuthHandlers) Login(w http.ResponseWriter, r *http.Request) {

h.logger.Info("User logged in successfully", "email", req.Email)

setSecureCookie(w, "access-token", accessToken, 900)
setSecureCookie(w, "refresh-token", refreshToken, 604800)
h.setSecureCookie(w, "refresh-token", refreshToken, 604800)

response := AuthResponse{
ExpiresIn: 900,
TokenType: "Bearer",
Message: "Login successful",
AccessToken: accessToken,
ExpiresIn: 900,
TokenType: "Bearer",
Message: "Login successful",
}

w.Header().Set("Content-Type", "application/json")
Expand Down Expand Up @@ -254,13 +251,13 @@ func (h *AuthHandlers) Refresh(w http.ResponseWriter, r *http.Request) {

h.logger.Info("Token refreshed successfully")

setSecureCookie(w, "access-token", accessToken, 900)
setSecureCookie(w, "refresh-token", refreshToken, 604800)
h.setSecureCookie(w, "refresh-token", refreshToken, 604800)

response := AuthResponse{
ExpiresIn: 900,
TokenType: "Bearer",
Message: "Token refreshed successfully",
AccessToken: accessToken,
ExpiresIn: 900,
TokenType: "Bearer",
Message: "Token refreshed successfully",
}

w.Header().Set("Content-Type", "application/json")
Expand All @@ -277,8 +274,14 @@ func (h *AuthHandlers) Refresh(w http.ResponseWriter, r *http.Request) {
// @Success 200 {object} map[string]interface{} "Logout successful"
// @Router /api/v1/auth/logout [post]
func (h *AuthHandlers) Logout(w http.ResponseWriter, r *http.Request) {
setSecureCookie(w, "access-token", "", -1)
setSecureCookie(w, "refresh-token", "", -1)
refreshTokenCookie, err := r.Cookie("refresh-token")
if err == nil && refreshTokenCookie.Value != "" {
if err := h.authService.Logout(r.Context(), refreshTokenCookie.Value); err != nil {
h.logger.Error("Failed to logout user", "error", err)
}
}

h.setSecureCookie(w, "refresh-token", "", -1)

h.logger.Info("User logged out successfully")

Expand Down
26 changes: 16 additions & 10 deletions internal/http/auth_handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net/http/httptest"
"testing"

"github.com/aleksandr/strive-api/internal/config"
"github.com/aleksandr/strive-api/internal/logger"
"github.com/aleksandr/strive-api/internal/models"
"github.com/google/uuid"
Expand Down Expand Up @@ -71,7 +72,8 @@ func TestAuthHandlers_Register(t *testing.T) {
mockService := new(MockAuthService)
tt.mockSetup(mockService)

handlers := NewAuthHandlers(mockService, logger)
cfg := &config.Config{}
handlers := NewAuthHandlers(mockService, logger, cfg)

body, _ := json.Marshal(tt.requestBody)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", bytes.NewReader(body))
Expand Down Expand Up @@ -152,7 +154,8 @@ func TestAuthHandlers_Login(t *testing.T) {
mockService := new(MockAuthService)
tt.mockSetup(mockService)

handlers := NewAuthHandlers(mockService, logger)
cfg := &config.Config{}
handlers := NewAuthHandlers(mockService, logger, cfg)

body, _ := json.Marshal(tt.requestBody)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", bytes.NewReader(body))
Expand All @@ -169,20 +172,22 @@ func TestAuthHandlers_Login(t *testing.T) {
assert.NoError(t, err)
assert.Contains(t, response, "error")
} else {
// Check that cookies are set instead of JSON tokens
// Check that access token is in JSON response and refresh token is in cookie
var response map[string]interface{}
err := json.Unmarshal(rr.Body.Bytes(), &response)
assert.NoError(t, err)
assert.Contains(t, response, "access_token")
assert.Equal(t, "access_token", response["access_token"])

cookies := rr.Result().Cookies()
var accessTokenCookie, refreshTokenCookie *http.Cookie
var refreshTokenCookie *http.Cookie
for _, cookie := range cookies {
if cookie.Name == "access-token" {
accessTokenCookie = cookie
} else if cookie.Name == "refresh-token" {
if cookie.Name == "refresh-token" {
refreshTokenCookie = cookie
}
}

assert.NotNil(t, accessTokenCookie, "access-token cookie should be set")
assert.NotNil(t, refreshTokenCookie, "refresh-token cookie should be set")
assert.Equal(t, "access_token", accessTokenCookie.Value)
assert.Equal(t, "refresh_token", refreshTokenCookie.Value)
}

Expand Down Expand Up @@ -215,7 +220,8 @@ func TestAuthHandlers_Me(t *testing.T) {
mockService := &MockAuthService{}
tt.mockSetup(mockService)

handlers := NewAuthHandlers(mockService, logger)
cfg := &config.Config{}
handlers := NewAuthHandlers(mockService, logger, cfg)

req := httptest.NewRequest("GET", "/api/v1/auth/me", http.NoBody)
rr := httptest.NewRecorder()
Expand Down
34 changes: 9 additions & 25 deletions internal/http/auth_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,34 +56,20 @@ func logAuthFailure(log *logger.Logger, r *http.Request, reason string) {
func AuthMiddleware(authService services.AuthService, log *logger.Logger) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var tokenString string
var authSource string

// Try to get token from Authorization header first
authHeader := r.Header.Get("Authorization")
if authHeader != "" {
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) == 2 && parts[0] == "Bearer" && parts[1] != "" {
tokenString = parts[1]
authSource = "header"
}
}

// If no token in header, try to get from cookie
if tokenString == "" {
accessTokenCookie, err := r.Cookie("access-token")
if err == nil && accessTokenCookie.Value != "" {
tokenString = accessTokenCookie.Value
authSource = "cookie"
}
if authHeader == "" {
writeAuthError(w, log, r, "UNAUTHORIZED", "Authentication required", "missing_authorization_header")
return
}

// If still no token, return error
if tokenString == "" {
writeAuthError(w, log, r, "UNAUTHORIZED", "Authentication required", "missing_token")
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) != 2 || parts[0] != "Bearer" || parts[1] == "" {
writeAuthError(w, log, r, "UNAUTHORIZED", "Invalid authorization header format", "invalid_authorization_format")
return
}

tokenString := parts[1]

claims, err := authService.ValidateToken(tokenString)
if err != nil {
var code, message, reason string
Expand All @@ -107,11 +93,9 @@ func AuthMiddleware(authService services.AuthService, log *logger.Logger) func(h
return
}

// Log successful authentication with source
log.Debug("Authentication successful",
"user_id", claims.UserID,
"email", claims.Email,
"source", authSource)
"email", claims.Email)

ctx := context.WithValue(r.Context(), UserIDKey, claims.UserID.String())
ctx = context.WithValue(ctx, UserEmailKey, claims.Email)
Expand Down
4 changes: 4 additions & 0 deletions internal/http/auth_middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ func (m *mockAuthService) VerifyPassword(hashedPassword, password string) error
return nil
}

func (m *mockAuthService) Logout(ctx context.Context, refreshToken string) error {
return nil
}

func TestAuthMiddleware(t *testing.T) {
log := logger.New("INFO", "json")

Expand Down
5 changes: 5 additions & 0 deletions internal/http/mocks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,8 @@ func (m *MockAuthService) RefreshToken(ctx context.Context, refreshToken string)
args := m.Called(ctx, refreshToken)
return args.String(0), args.String(1), args.Error(2)
}

func (m *MockAuthService) Logout(ctx context.Context, refreshToken string) error {
args := m.Called(ctx, refreshToken)
return args.Error(0)
}
16 changes: 16 additions & 0 deletions internal/models/refresh_token.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package models

import (
"time"

"github.com/google/uuid"
)

type RefreshToken struct {
ID uuid.UUID `json:"id" db:"id"`
UserID uuid.UUID `json:"user_id" db:"user_id"`
Token string `json:"token" db:"token"`
ExpiresAt time.Time `json:"expires_at" db:"expires_at"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
}
Loading