diff --git a/cmd/server/main.go b/cmd/server/main.go index cd55271..073160e 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -101,32 +101,40 @@ func runMigrations(cfg *config.Config, logger *logger.Logger) { } type Services struct { - Auth services.AuthService - User services.UserService + Auth services.AuthService + User services.UserService + Calorie services.CalorieService } func setupServices(db *database.Database, cfg *config.Config) *Services { userRepo := repositories.NewUserRepository(db.Pool()) refreshTokenRepo := repositories.NewRefreshTokenRepository(db.Pool()) + calorieRepo := repositories.NewCalorieRepository(db.Pool()) + authService := services.NewAuthService(userRepo, refreshTokenRepo, &cfg.JWT) userService := services.NewUserService(userRepo) + calorieService := services.NewCalorieService(calorieRepo) + return &Services{ - Auth: authService, - User: userService, + Auth: authService, + User: userService, + Calorie: calorieService, } } type Handlers struct { - Auth *httphandler.AuthHandlers - User *httphandler.UserHandlers - Health *httphandler.DetailedHealthHandler + Auth *httphandler.AuthHandlers + User *httphandler.UserHandlers + Calorie *httphandler.CalorieHandlers + Health *httphandler.DetailedHealthHandler } func setupHandlers(services *Services, logger *logger.Logger, db *database.Database, cfg *config.Config) *Handlers { return &Handlers{ - Auth: httphandler.NewAuthHandlers(services.Auth, logger, cfg), - User: httphandler.NewUserHandlers(services.User, logger), - Health: httphandler.NewDetailedHealthHandler(logger, db.Pool()), + Auth: httphandler.NewAuthHandlers(services.Auth, logger, cfg), + User: httphandler.NewUserHandlers(services.User, logger), + Calorie: httphandler.NewCalorieHandlers(services.Calorie, logger), + Health: httphandler.NewDetailedHealthHandler(logger, db.Pool()), } } @@ -166,6 +174,13 @@ func setupProtectedRoutes(mux *http.ServeMux, authService services.AuthService, userProtectedMux.HandleFunc("/theme", handlers.User.UpdateTheme) userProtectedHandler := httphandler.AuthMiddleware(authService, logger)(userProtectedMux) mux.Handle("/api/v1/user/", http.StripPrefix("/api/v1/user", userProtectedHandler)) + + // Calorie calculator protected routes + calorieProtectedMux := http.NewServeMux() + calorieProtectedMux.HandleFunc("/calculate", handlers.Calorie.CalculateCalories) + calorieProtectedMux.HandleFunc("/last", handlers.Calorie.GetLastCalculation) + calorieProtectedHandler := httphandler.AuthMiddleware(authService, logger)(calorieProtectedMux) + mux.Handle("/api/v1/calorie/", http.StripPrefix("/api/v1/calorie", calorieProtectedHandler)) } func applyMiddleware(mux *http.ServeMux, logger *logger.Logger, cfg *config.Config) http.Handler { diff --git a/docs/calorie-diary-mvp-plan.md b/docs/calorie-diary-mvp-plan.md new file mode 100644 index 0000000..20f5b5c --- /dev/null +++ b/docs/calorie-diary-mvp-plan.md @@ -0,0 +1,311 @@ +# План MVP дневника калорий для Strive API + +## Обзор + +Данный документ содержит план реализации MVP версии дневника калорий с интеграцией Open Food Facts API для сканирования штрихкодов продуктов. + +**Связанные планы:** +- [План разработки API дневника тренировок](./plan.md) - основная архитектура +- [Будущие улучшения](./stages/09-future.md) - расширенная функциональность +- [План улучшения безопасности](./security-improvement-plan.md) - меры безопасности + +## 🎯 Цели MVP + +- Отслеживание потребления продуктов через сканирование штрихкодов +- Интеграция с Open Food Facts API для получения данных о продуктах +- Ведение дневника питания с расчетом калорий и БЖУ +- Базовая аналитика потребления + +## 🏗️ Архитектурные решения + +**Расширение существующей архитектуры:** +- Использование текущей JWT аутентификации +- Расширение схемы БД новыми таблицами +- Добавление новых сервисов в существующую структуру +- Сохранение принципов Clean Architecture + +## 📊 Схема базы данных + +```sql +-- Таблица продуктов +CREATE TABLE products ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + barcode VARCHAR(20) UNIQUE NOT NULL, + name VARCHAR(255) NOT NULL, + brand VARCHAR(255), + category VARCHAR(100), + image_url TEXT, + nutrition_per_100g JSONB NOT NULL, -- калории, белки, жиры, углеводы + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() +); + +-- Таблица дневника питания +CREATE TABLE food_diary_entries ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + product_id UUID NOT NULL REFERENCES products(id) ON DELETE CASCADE, + quantity_grams DECIMAL(10,2) NOT NULL, + consumed_at TIMESTAMP WITH TIME ZONE NOT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() +); + +-- Индексы для производительности +CREATE INDEX idx_food_diary_user_date ON food_diary_entries(user_id, consumed_at); +CREATE INDEX idx_products_barcode ON products(barcode); +``` + +## 🔌 API Endpoints + +**Новые эндпоинты для дневника калорий:** + +``` +# Поиск продуктов +GET /api/v1/products/search?query=название +GET /api/v1/products/barcode/{barcode} + +# Дневник питания +GET /api/v1/diary/entries?date=2024-01-01 +POST /api/v1/diary/entries +PUT /api/v1/diary/entries/{id} +DELETE /api/v1/diary/entries/{id} + +# Аналитика +GET /api/v1/diary/summary?date=2024-01-01 +GET /api/v1/diary/summary?from=2024-01-01&to=2024-01-07 +``` + +## 🏛️ Структура кода + +**Новые компоненты в существующей архитектуре:** + +``` +internal/ +├── models/ +│ ├── product.go # Модели продуктов +│ ├── food_diary.go # Модели дневника питания +│ └── nutrition.go # Модели питательных веществ +├── services/ +│ ├── product_service.go # Бизнес-логика продуктов +│ ├── food_diary_service.go # Бизнес-логика дневника +│ └── nutrition_service.go # Расчеты БЖУ +├── repositories/ +│ ├── product_repository.go +│ └── food_diary_repository.go +├── http/ +│ ├── product_handlers.go +│ └── food_diary_handlers.go +└── external/ + └── openfoodfacts_client.go # Клиент для Open Food Facts API +``` + +## 🔌 Интеграция с Open Food Facts API + +**Основные функции:** +- Поиск продукта по штрихкоду +- Получение данных о питательности +- Кэширование результатов для оптимизации +- Обработка ошибок API + +```go +type OpenFoodFactsClient interface { + GetProductByBarcode(ctx context.Context, barcode string) (*Product, error) + SearchProducts(ctx context.Context, query string) ([]*Product, error) +} +``` + +## 📋 Модели данных + +```go +// Продукт +type Product struct { + ID uuid.UUID `json:"id" db:"id"` + Barcode string `json:"barcode" db:"barcode"` + Name string `json:"name" db:"name"` + Brand string `json:"brand" db:"brand"` + Category string `json:"category" db:"category"` + ImageURL string `json:"image_url" db:"image_url"` + Nutrition *Nutrition `json:"nutrition" db:"nutrition_per_100g"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` +} + +// Питательные вещества +type Nutrition struct { + Calories float64 `json:"calories"` + Protein float64 `json:"protein"` + Fat float64 `json:"fat"` + Carbs float64 `json:"carbs"` + Fiber float64 `json:"fiber"` + Sugar float64 `json:"sugar"` +} + +// Запись в дневнике +type FoodDiaryEntry struct { + ID uuid.UUID `json:"id" db:"id"` + UserID uuid.UUID `json:"user_id" db:"user_id"` + ProductID uuid.UUID `json:"product_id" db:"product_id"` + Product *Product `json:"product,omitempty"` + QuantityGrams float64 `json:"quantity_grams" db:"quantity_grams"` + ConsumedAt time.Time `json:"consumed_at" db:"consumed_at"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` +} +``` + +## 🚀 Этапы реализации + +### **Этап 1: Подготовка инфраструктуры (1-2 дня)** +1. **Миграции БД** + - Создать таблицы `products` и `food_diary_entries` + - Добавить индексы для производительности + - Создать миграции up/down + +2. **Модели данных** + - `Product`, `FoodDiaryEntry`, `Nutrition` + - Request/Response модели + - Валидация данных + +### **Этап 2: Интеграция с Open Food Facts (2-3 дня)** +1. **HTTP клиент** + - Реализация `OpenFoodFactsClient` + - Обработка ошибок API + - Rate limiting и retry логика + +2. **Сервис продуктов** + - `ProductService` с методами поиска + - Кэширование результатов + - Обработка отсутствующих продуктов + +### **Этап 3: Репозитории и сервисы (2-3 дня)** +1. **Репозитории** + - `ProductRepository` - CRUD операции + - `FoodDiaryRepository` - управление дневником + +2. **Бизнес-логика** + - `FoodDiaryService` - расчеты калорий + - `NutritionService` - расчеты БЖУ + - Валидация данных + +### **Этап 4: HTTP слой (2-3 дня)** +1. **Handlers** + - `ProductHandlers` - поиск продуктов + - `FoodDiaryHandlers` - управление дневником + +2. **Middleware** + - Валидация входных данных + - Обработка ошибок + +### **Этап 5: Тестирование (2-3 дня)** +1. **Unit тесты** + - Тесты сервисов и репозиториев + - Моки для внешних API + - Покрытие >70% + +2. **Integration тесты** + - Тесты API endpoints + - Тесты с реальной БД + +### **Этап 6: Документация и развертывание (1-2 дня)** +1. **Swagger документация** + - Описание новых endpoints + - Примеры запросов/ответов + +2. **Конфигурация** + - ENV переменные для Open Food Facts + - Настройки кэширования + +## ⚙️ Технические детали + +### **Конфигурация** +```bash +# Open Food Facts API +OPENFOODFACTS_API_URL=https://world.openfoodfacts.org/api/v0 +OPENFOODFACTS_USER_AGENT=StriveAPI/1.0 +OPENFOODFACTS_TIMEOUT=10s +OPENFOODFACTS_RETRY_ATTEMPTS=3 + +# Кэширование продуктов +PRODUCT_CACHE_TTL=24h +PRODUCT_CACHE_SIZE=1000 +``` + +### **Обработка ошибок** +- Продукт не найден в Open Food Facts +- Неверный формат штрихкода +- Ошибки сети при обращении к API +- Валидация данных пользователя + +### **Производительность** +- Кэширование продуктов в памяти +- Индексы БД для быстрого поиска +- Пагинация для больших списков +- Оптимизация SQL запросов + +## 📊 Аналитика и отчеты + +### **Сводка по дням** +```json +{ + "date": "2024-01-01", + "total_calories": 2500, + "total_protein": 120.5, + "total_fat": 85.2, + "total_carbs": 300.8, + "entries_count": 5, + "entries": [...] +} +``` + +### **Статистика по периодам** +- Дневная сводка +- Недельная статистика +- Месячные отчеты +- Тренды потребления + +## ✅ Критерии готовности MVP + +- ✅ Пользователь может найти продукт по штрихкоду +- ✅ Пользователь может добавить продукт в дневник +- ✅ Система рассчитывает калории и БЖУ +- ✅ Пользователь видит сводку за день +- ✅ Все операции покрыты тестами +- ✅ API документирован в Swagger + +## 🚀 Следующие шаги + +1. **Создать ветку для фичи**: `git checkout -b feature/calorie-diary` +2. **Начать с миграций БД** - создать схему таблиц +3. **Реализовать модели данных** - Product, FoodDiaryEntry +4. **Интегрировать Open Food Facts API** - HTTP клиент +5. **Добавить бизнес-логику** - сервисы и репозитории +6. **Создать API endpoints** - handlers и middleware +7. **Написать тесты** - unit и integration +8. **Обновить документацию** - Swagger и README + +## 📈 Время выполнения + +**Общее время**: 10-15 дней +- Этап 1: 1-2 дня +- Этап 2: 2-3 дня +- Этап 3: 2-3 дня +- Этап 4: 2-3 дня +- Этап 5: 2-3 дня +- Этап 6: 1-2 дня + +## 🎯 Приоритеты + +1. **Высокий**: Базовая функциональность дневника, интеграция с Open Food Facts +2. **Средний**: Аналитика и отчеты, кэширование +3. **Низкий**: Расширенная аналитика, оптимизация производительности + +## 🔗 Связанные задачи + +### Дополнительные возможности (будущие этапы): +- **Расширенная аналитика** - тренды, рекомендации по питанию +- **Экспорт данных** - выгрузка дневника в CSV/JSON +- **Уведомления** - напоминания о приеме пищи +- **Социальные функции** - обмен рецептами, достижения + +*Эти задачи планируются в рамках будущих этапов развития* diff --git a/docs/security-improvement-plan.md b/docs/security-improvement-plan.md index 0a62089..334f29f 100644 --- a/docs/security-improvement-plan.md +++ b/docs/security-improvement-plan.md @@ -2,17 +2,13 @@ ## Обзор -Данный документ содержит план по устранению оставшихся проблем безопасности в проекте Strive API. Все критические проблемы уже решены. +Данный документ содержит план по устранению оставшихся проблем безопасности в проекте Strive API. -## 🔴 Критические проблемы (Приоритет 1) +**Связанные планы:** +- [План будущих улучшений](../stages/09-future.md) - включает 2FA, аудит безопасности +- [Архитектура проекта](../../.cursor/rules/project-architecture.mdc) - обзор реализованных мер безопасности -**Статус**: ✅ **ВСЕ ВЫПОЛНЕНЫ** -- JWT secret валидация -- Rate limiting защита -- HTTP Security Headers -- CORS конфигурация - -## 🟠 Серьезные проблемы (Приоритет 2) +## 🟠 Серьезные проблемы (Приоритет 1) ### 1. Отсутствие защиты от CSRF @@ -23,22 +19,11 @@ **Решение**: - [ ] Добавить CSRF middleware - [ ] Использовать Double Submit Cookie pattern -- [ ] Настроить SameSite атрибуты для cookies - [ ] Добавить проверку Referer header -## 🟡 Средние проблемы (Приоритет 3) - -### 2. Недостаточная защита от timing атак - -**Проблема**: Нет constant-time сравнения паролей +## 🟡 Средние проблемы (Приоритет 2) -**Риск**: Enumeration атаки - -**Решение**: -- [ ] Использовать constant-time сравнение для паролей -- [ ] Добавить jitter для времени ответа - -### 3. Отсутствие валидации размера запросов +### 1. Отсутствие валидации размера запросов **Проблема**: Нет ограничений на размер тела запроса @@ -55,18 +40,23 @@ 1. Реализация CSRF защиты ### Приоритет 2 (Средние проблемы) -1. Constant-time сравнение паролей -2. Валидация размера запросов +1. Валидация размера запросов ## 📊 Статистика -- **Критические проблемы**: 4/4 (100%) ✅ -- **Серьезные проблемы**: 2/3 (67%) ⚠️ -- **Средние проблемы**: 1/3 (33%) ⚠️ -- **Общий прогресс**: 7/10 (70%) ✅ +- **Серьезные проблемы**: 0/1 (0%) ⚠️ +- **Средние проблемы**: 0/1 (0%) ⚠️ +- **Общий прогресс**: 0/2 (0%) ⚠️ ## 🎯 Рекомендации 1. **Приоритет 1**: Реализовать CSRF защиту -2. **Приоритет 2**: Добавить constant-time сравнение паролей -3. **Приоритет 3**: Добавить валидацию размера запросов +2. **Приоритет 2**: Добавить валидацию размера запросов + +## 🔗 Связанные задачи из плана будущих улучшений + +### Дополнительные меры безопасности (низкий приоритет): +- **2FA (Двухфакторная аутентификация)** - TOTP, SMS коды, backup codes +- **Аудит безопасности** - логирование событий безопасности, мониторинг подозрительной активности + +*Эти задачи планируются в рамках [Этапа 9: Будущие улучшения](../stages/09-future.md)* diff --git a/docs/stages/09-future.md b/docs/stages/09-future.md index da9a764..233c8b2 100644 --- a/docs/stages/09-future.md +++ b/docs/stages/09-future.md @@ -3,32 +3,9 @@ ## Цель этапа Дополнительные возможности для улучшения производительности, безопасности и функциональности API. -## ✅ Уже реализованные функции - -### Безопасность -- ✅ **Rate Limiting** - защита от DDoS атак с настраиваемыми лимитами -- ✅ **HTTP Security Headers** - HSTS, CSP, X-Frame-Options, XSS Protection -- ✅ **CORS конфигурация** - настраиваемая через ENV переменные -- ✅ **JWT аутентификация** - с расширенной валидацией и refresh токенами -- ✅ **bcrypt хеширование** - безопасное хранение паролей -- ✅ **SQL Injection защита** - через pgx драйвер - -### Мониторинг и логирование -- ✅ **Структурированное логирование** - JSON/текст форматы -- ✅ **Request ID трассировка** - для отслеживания запросов -- ✅ **Security логирование** - события безопасности -- ✅ **Health checks** - базовые проверки состояния системы - -### DevOps и развертывание -- ✅ **CI/CD Pipeline** - GitHub Actions с автоматическими тестами -- ✅ **Docker контейнеризация** - с multi-stage build -- ✅ **Автоматическое тестирование** - с покрытием кода -- ✅ **Линтинг и форматирование** - golangci-lint, gofumpt, goimports - ## 9.1 Производительность и масштабируемость - [ ] Кэширование часто запрашиваемых данных - [ ] Метрики производительности (Prometheus) -- [ ] Оптимизация запросов к базе данных ### Кэширование - Кэшировать список пользователей (TTL 5 минут) @@ -71,12 +48,7 @@ - [ ] GET `/api/v1/stats/sessions` - статистика сессий ## 9.3 Безопасность -- [ ] HTTPS в продакшене - [ ] Двухфакторная аутентификация (2FA) -- [ ] Аудит безопасности - -### HTTPS и TLS -- Настройка TLS сертификатов ### 2FA - TOTP (Time-based One-Time Password) @@ -84,12 +56,11 @@ - Backup codes ## 9.4 Мониторинг и алертинг -- [ ] Расширенные health checks - [ ] Алерты на критические ошибки - [ ] Dashboard для мониторинга - [ ] Distributed tracing -### Расширенные health checks +### Дополнительные health checks - [ ] GET `/health/cache` - статус кэша - [ ] GET `/health/external` - статус внешних сервисов - [ ] GET `/health/memory` - использование памяти @@ -107,7 +78,6 @@ - [ ] Helm charts ### CI/CD Pipeline -- [ ] Безопасность сканирование - [ ] Автоматическое развертывание ### Kubernetes @@ -121,7 +91,6 @@ - [ ] Метрики собираются и доступны - [ ] Экспорт/импорт данных работает - [ ] Мониторинг показывает корректные данные -- [ ] HTTPS настроен в продакшене - [ ] 2FA работает корректно ## Время выполнения @@ -129,11 +98,9 @@ ## Приоритеты 1. **Высокий**: Метрики, мониторинг, кэширование -2. **Средний**: Экспорт/импорт, расширенные health checks, HTTPS +2. **Средний**: Экспорт/импорт, дополнительные health checks 3. **Низкий**: 2FA, Kubernetes, advanced DevOps -## Предыдущий этап -[Этап 8: Дополнительные возможности](./08-additional.md) ## Завершение проекта После завершения всех этапов проект будет готов к enterprise-уровню использования с полным набором возможностей для мониторинга, безопасности и масштабируемости. diff --git a/internal/http/auth_handlers.go b/internal/http/auth_handlers.go index 89e2690..17f6dc6 100644 --- a/internal/http/auth_handlers.go +++ b/internal/http/auth_handlers.go @@ -306,9 +306,9 @@ func (h *AuthHandlers) Logout(w http.ResponseWriter, r *http.Request) { // @Failure 500 {object} AuthError "Internal server error" // @Router /api/v1/auth/me [get] func (h *AuthHandlers) Me(w http.ResponseWriter, r *http.Request) { - userID, ok := GetUserIDFromContext(r.Context()) - if !ok { - h.logger.Error("User ID not found in context") + userID, err := GetUserIDFromContext(r.Context()) + if err != nil { + h.logger.Error("User ID not found in context", "error", err) http.Error(w, `{"error":{"code":"INTERNAL_ERROR","message":"User ID not found in context"}}`, http.StatusInternalServerError) return } diff --git a/internal/http/auth_handlers_test.go b/internal/http/auth_handlers_test.go index ece3efe..14e31a2 100644 --- a/internal/http/auth_handlers_test.go +++ b/internal/http/auth_handlers_test.go @@ -227,7 +227,7 @@ func TestAuthHandlers_Me(t *testing.T) { rr := httptest.NewRecorder() // Add user context to request (simulating auth middleware) - ctx := context.WithValue(req.Context(), UserIDKey, "test-user-id") + ctx := context.WithValue(req.Context(), UserIDKey, "5891f008-f598-4aad-86bd-cae765c22fed") ctx = context.WithValue(ctx, UserEmailKey, "test@example.com") req = req.WithContext(ctx) @@ -247,7 +247,7 @@ func TestAuthHandlers_Me(t *testing.T) { assert.Contains(t, response, "user_id") assert.Contains(t, response, "email") assert.Contains(t, response, "message") - assert.Equal(t, "test-user-id", response["user_id"]) + assert.Equal(t, "5891f008-f598-4aad-86bd-cae765c22fed", response["user_id"]) assert.Equal(t, "test@example.com", response["email"]) } diff --git a/internal/http/auth_middleware.go b/internal/http/auth_middleware.go index f4f0059..0d84e95 100644 --- a/internal/http/auth_middleware.go +++ b/internal/http/auth_middleware.go @@ -11,13 +11,6 @@ import ( "github.com/aleksandr/strive-api/internal/services" ) -type contextKey string - -const ( - UserIDKey contextKey = "user_id" - UserEmailKey contextKey = "user_email" -) - type AuthError struct { Error struct { Code string `json:"code"` @@ -104,13 +97,3 @@ func AuthMiddleware(authService services.AuthService, log *logger.Logger) func(h }) } } - -func GetUserIDFromContext(ctx context.Context) (string, bool) { - userID, ok := ctx.Value(UserIDKey).(string) - return userID, ok -} - -func GetUserEmailFromContext(ctx context.Context) (string, bool) { - email, ok := ctx.Value(UserEmailKey).(string) - return email, ok -} diff --git a/internal/http/auth_middleware_test.go b/internal/http/auth_middleware_test.go index 912139a..c92225b 100644 --- a/internal/http/auth_middleware_test.go +++ b/internal/http/auth_middleware_test.go @@ -214,12 +214,12 @@ func TestAuthMiddleware(t *testing.T) { middleware := AuthMiddleware(mockAuth, log) handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - contextUserID, ok := GetUserIDFromContext(r.Context()) - assert.True(t, ok) - assert.Equal(t, userID.String(), contextUserID) + contextUserID, err := GetUserIDFromContext(r.Context()) + assert.NoError(t, err) + assert.Equal(t, userID, contextUserID) - contextEmail, ok := GetUserEmailFromContext(r.Context()) - assert.True(t, ok) + contextEmail, err := GetUserEmailFromContext(r.Context()) + assert.NoError(t, err) assert.Equal(t, email, contextEmail) w.WriteHeader(http.StatusOK) @@ -251,12 +251,12 @@ func TestAuthMiddleware(t *testing.T) { middleware := AuthMiddleware(mockAuth, log) handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - contextUserID, ok := GetUserIDFromContext(r.Context()) - assert.True(t, ok) - assert.Equal(t, testUser.ID.String(), contextUserID) + contextUserID, err := GetUserIDFromContext(r.Context()) + assert.NoError(t, err) + assert.Equal(t, testUser.ID, contextUserID) - contextEmail, ok := GetUserEmailFromContext(r.Context()) - assert.True(t, ok) + contextEmail, err := GetUserEmailFromContext(r.Context()) + assert.NoError(t, err) assert.Equal(t, testUser.Email, contextEmail) w.WriteHeader(http.StatusOK) diff --git a/internal/http/calorie_handlers.go b/internal/http/calorie_handlers.go new file mode 100644 index 0000000..e6efd95 --- /dev/null +++ b/internal/http/calorie_handlers.go @@ -0,0 +1,92 @@ +package http + +import ( + "encoding/json" + "net/http" + + "github.com/aleksandr/strive-api/internal/logger" + "github.com/aleksandr/strive-api/internal/models" + "github.com/aleksandr/strive-api/internal/services" + "github.com/aleksandr/strive-api/internal/validation" +) + +type CalorieHandlers struct { + calorieService services.CalorieService + logger *logger.Logger + validator *validation.Validator +} + +func NewCalorieHandlers(calorieService services.CalorieService, logger *logger.Logger) *CalorieHandlers { + return &CalorieHandlers{ + calorieService: calorieService, + logger: logger, + validator: &validation.Validator{}, + } +} + +func (h *CalorieHandlers) CalculateCalories(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var data models.CalorieCalculationData + if err := json.NewDecoder(r.Body).Decode(&data); err != nil { + h.logger.Error("Failed to decode request body", "error", err) + http.Error(w, "Invalid JSON", http.StatusBadRequest) + return + } + + if err := h.validator.Validate(data); err != nil { + h.logger.Error("Validation failed", "error", err) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + userUUID, err := GetUserIDFromContext(r.Context()) + if err != nil { + h.logger.Error("User ID not found in context", "error", err) + http.Error(w, "Invalid user ID", http.StatusInternalServerError) + return + } + + results, err := h.calorieService.CalculateCalories(r.Context(), userUUID, &data) + if err != nil { + h.logger.Error("Failed to calculate calories", "error", err) + http.Error(w, "Failed to calculate calories", http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(results); err != nil { + h.logger.Error("Failed to encode response", "error", err) + } +} + +func (h *CalorieHandlers) GetLastCalculation(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + userUUID, err := GetUserIDFromContext(r.Context()) + if err != nil { + h.logger.Error("User ID not found in context", "error", err) + http.Error(w, "Invalid user ID", http.StatusInternalServerError) + return + } + + response, err := h.calorieService.GetLastCalculation(r.Context(), userUUID) + if err != nil { + h.logger.Error("Failed to get last calculation", "error", err) + http.Error(w, "No calculation found", http.StatusNotFound) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(response); err != nil { + h.logger.Error("Failed to encode response", "error", err) + } +} diff --git a/internal/http/calorie_handlers_test.go b/internal/http/calorie_handlers_test.go new file mode 100644 index 0000000..54ff778 --- /dev/null +++ b/internal/http/calorie_handlers_test.go @@ -0,0 +1,325 @@ +package http + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/aleksandr/strive-api/internal/logger" + "github.com/aleksandr/strive-api/internal/models" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +type mockCalorieService struct { + mock.Mock +} + +func (m *mockCalorieService) CalculateCalories( + ctx context.Context, + userID uuid.UUID, + data *models.CalorieCalculationData, +) (*models.CalorieResults, error) { + args := m.Called(ctx, userID, data) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*models.CalorieResults), args.Error(1) +} + +func (m *mockCalorieService) GetLastCalculation(ctx context.Context, userID uuid.UUID) (*models.CalorieCalculationResponse, error) { + args := m.Called(ctx, userID) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*models.CalorieCalculationResponse), args.Error(1) +} + +func TestCalorieHandlers_CalculateCalories_Success(t *testing.T) { + mockService := new(mockCalorieService) + logger := logger.New("INFO", "json") + + handlers := NewCalorieHandlers(mockService, logger) + + userID := uuid.New() + expectedResults := &models.CalorieResults{ + BMR: 1650, + TDEE: 2558, + TargetCalories: 2558, + Macros: models.Macronutrients{ + ProteinGrams: 196, + ProteinPercentage: 30.6, + FatGrams: 85, + FatPercentage: 30.0, + CarbsGrams: 256, + CarbsPercentage: 39.4, + }, + } + + mockService.On("CalculateCalories", mock.Anything, userID, mock.AnythingOfType("*models.CalorieCalculationData")). + Return(expectedResults, nil) + + requestData := models.CalorieCalculationData{ + Gender: models.GenderMale, + Age: 25, + Height: 175.0, + Weight: 70.0, + ActivityLevel: models.ActivityModeratelyActive, + Goal: models.GoalMaintainWeight, + } + + jsonData, _ := json.Marshal(requestData) + req := httptest.NewRequest("POST", "/api/v1/calorie/calculate", bytes.NewBuffer(jsonData)) + req.Header.Set("Content-Type", "application/json") + req = req.WithContext(context.WithValue(req.Context(), UserIDKey, userID.String())) + + w := httptest.NewRecorder() + handlers.CalculateCalories(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "application/json", w.Header().Get("Content-Type")) + + var response models.CalorieResults + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, expectedResults.BMR, response.BMR) + assert.Equal(t, expectedResults.TDEE, response.TDEE) + assert.Equal(t, expectedResults.TargetCalories, response.TargetCalories) + assert.Equal(t, expectedResults.Macros.ProteinGrams, response.Macros.ProteinGrams) + assert.Equal(t, expectedResults.Macros.FatGrams, response.Macros.FatGrams) + assert.Equal(t, expectedResults.Macros.CarbsGrams, response.Macros.CarbsGrams) + + mockService.AssertExpectations(t) +} + +func TestCalorieHandlers_CalculateCalories_InvalidMethod(t *testing.T) { + mockService := new(mockCalorieService) + logger := logger.New("INFO", "json") + + handlers := NewCalorieHandlers(mockService, logger) + + req := httptest.NewRequest("GET", "/api/v1/calorie/calculate", http.NoBody) + w := httptest.NewRecorder() + handlers.CalculateCalories(w, req) + + assert.Equal(t, http.StatusMethodNotAllowed, w.Code) +} + +func TestCalorieHandlers_CalculateCalories_InvalidJSON(t *testing.T) { + mockService := new(mockCalorieService) + logger := logger.New("INFO", "json") + + handlers := NewCalorieHandlers(mockService, logger) + + req := httptest.NewRequest("POST", "/api/v1/calorie/calculate", bytes.NewBufferString("invalid json")) + req.Header.Set("Content-Type", "application/json") + userID := uuid.New() + req = req.WithContext(context.WithValue(req.Context(), UserIDKey, userID.String())) + + w := httptest.NewRecorder() + handlers.CalculateCalories(w, req) + + assert.Equal(t, http.StatusBadRequest, w.Code) +} + +func TestCalorieHandlers_CalculateCalories_ValidationError(t *testing.T) { + mockService := new(mockCalorieService) + logger := logger.New("INFO", "json") + + handlers := NewCalorieHandlers(mockService, logger) + + // Invalid data - age too low + requestData := models.CalorieCalculationData{ + Gender: models.GenderMale, + Age: 10, // Too young + Height: 175.0, + Weight: 70.0, + ActivityLevel: models.ActivityModeratelyActive, + Goal: models.GoalMaintainWeight, + } + + jsonData, _ := json.Marshal(requestData) + req := httptest.NewRequest("POST", "/api/v1/calorie/calculate", bytes.NewBuffer(jsonData)) + req.Header.Set("Content-Type", "application/json") + userID := uuid.New() + req = req.WithContext(context.WithValue(req.Context(), UserIDKey, userID.String())) + + w := httptest.NewRecorder() + handlers.CalculateCalories(w, req) + + assert.Equal(t, http.StatusBadRequest, w.Code) +} + +func TestCalorieHandlers_CalculateCalories_ServiceError(t *testing.T) { + mockService := new(mockCalorieService) + logger := logger.New("INFO", "json") + + handlers := NewCalorieHandlers(mockService, logger) + + userID := uuid.New() + mockService.On("CalculateCalories", mock.Anything, userID, mock.AnythingOfType("*models.CalorieCalculationData")). + Return(nil, assert.AnError) + + requestData := models.CalorieCalculationData{ + Gender: models.GenderMale, + Age: 25, + Height: 175.0, + Weight: 70.0, + ActivityLevel: models.ActivityModeratelyActive, + Goal: models.GoalMaintainWeight, + } + + jsonData, _ := json.Marshal(requestData) + req := httptest.NewRequest("POST", "/api/v1/calorie/calculate", bytes.NewBuffer(jsonData)) + req.Header.Set("Content-Type", "application/json") + req = req.WithContext(context.WithValue(req.Context(), UserIDKey, userID.String())) + + w := httptest.NewRecorder() + handlers.CalculateCalories(w, req) + + assert.Equal(t, http.StatusInternalServerError, w.Code) +} + +func TestCalorieHandlers_CalculateCalories_NoUserID(t *testing.T) { + mockService := new(mockCalorieService) + logger := logger.New("INFO", "json") + + handlers := NewCalorieHandlers(mockService, logger) + + requestData := models.CalorieCalculationData{ + Gender: models.GenderMale, + Age: 25, + Height: 175.0, + Weight: 70.0, + ActivityLevel: models.ActivityModeratelyActive, + Goal: models.GoalMaintainWeight, + } + + jsonData, _ := json.Marshal(requestData) + req := httptest.NewRequest("POST", "/api/v1/calorie/calculate", bytes.NewBuffer(jsonData)) + req.Header.Set("Content-Type", "application/json") + // No user ID in context + + w := httptest.NewRecorder() + handlers.CalculateCalories(w, req) + + assert.Equal(t, http.StatusInternalServerError, w.Code) +} + +func TestCalorieHandlers_GetLastCalculation_Success(t *testing.T) { + mockService := new(mockCalorieService) + logger := logger.New("INFO", "json") + + handlers := NewCalorieHandlers(mockService, logger) + + userID := uuid.New() + expectedResponse := &models.CalorieCalculationResponse{ + Data: models.CalorieCalculationData{ + Gender: models.GenderMale, + Age: 25, + Height: 175.0, + Weight: 70.0, + ActivityLevel: models.ActivityModeratelyActive, + Goal: models.GoalMaintainWeight, + }, + Results: models.CalorieResults{ + BMR: 1650, + TDEE: 2558, + TargetCalories: 2558, + Formula: models.FormulaMifflin, + Macros: models.Macronutrients{ + ProteinGrams: 196, + ProteinPercentage: 30.6, + FatGrams: 85, + FatPercentage: 30.0, + CarbsGrams: 256, + CarbsPercentage: 39.4, + }, + }, + Timestamp: time.Now(), + } + + mockService.On("GetLastCalculation", mock.Anything, userID).Return(expectedResponse, nil) + + req := httptest.NewRequest("GET", "/api/v1/calorie/last", http.NoBody) + req = req.WithContext(context.WithValue(req.Context(), UserIDKey, userID.String())) + + w := httptest.NewRecorder() + handlers.GetLastCalculation(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "application/json", w.Header().Get("Content-Type")) + + var response models.CalorieCalculationResponse + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, expectedResponse.Data.Gender, response.Data.Gender) + assert.Equal(t, expectedResponse.Data.Age, response.Data.Age) + assert.Equal(t, expectedResponse.Data.Height, response.Data.Height) + assert.Equal(t, expectedResponse.Data.Weight, response.Data.Weight) + assert.Equal(t, expectedResponse.Data.ActivityLevel, response.Data.ActivityLevel) + assert.Equal(t, expectedResponse.Data.Goal, response.Data.Goal) + assert.Equal(t, expectedResponse.Results.BMR, response.Results.BMR) + assert.Equal(t, expectedResponse.Results.TDEE, response.Results.TDEE) + assert.Equal(t, expectedResponse.Results.TargetCalories, response.Results.TargetCalories) + assert.Equal(t, expectedResponse.Results.Formula, response.Results.Formula) + assert.Equal(t, expectedResponse.Results.Macros.ProteinGrams, response.Results.Macros.ProteinGrams) + assert.Equal(t, expectedResponse.Results.Macros.ProteinPercentage, response.Results.Macros.ProteinPercentage) + assert.Equal(t, expectedResponse.Results.Macros.FatGrams, response.Results.Macros.FatGrams) + assert.Equal(t, expectedResponse.Results.Macros.FatPercentage, response.Results.Macros.FatPercentage) + assert.Equal(t, expectedResponse.Results.Macros.CarbsGrams, response.Results.Macros.CarbsGrams) + assert.Equal(t, expectedResponse.Results.Macros.CarbsPercentage, response.Results.Macros.CarbsPercentage) + + mockService.AssertExpectations(t) +} + +func TestCalorieHandlers_GetLastCalculation_InvalidMethod(t *testing.T) { + mockService := new(mockCalorieService) + logger := logger.New("INFO", "json") + + handlers := NewCalorieHandlers(mockService, logger) + + req := httptest.NewRequest("POST", "/api/v1/calorie/last", http.NoBody) + w := httptest.NewRecorder() + handlers.GetLastCalculation(w, req) + + assert.Equal(t, http.StatusMethodNotAllowed, w.Code) +} + +func TestCalorieHandlers_GetLastCalculation_ServiceError(t *testing.T) { + mockService := new(mockCalorieService) + logger := logger.New("INFO", "json") + + handlers := NewCalorieHandlers(mockService, logger) + + userID := uuid.New() + mockService.On("GetLastCalculation", mock.Anything, userID).Return(nil, assert.AnError) + + req := httptest.NewRequest("GET", "/api/v1/calorie/last", http.NoBody) + req = req.WithContext(context.WithValue(req.Context(), UserIDKey, userID.String())) + + w := httptest.NewRecorder() + handlers.GetLastCalculation(w, req) + + assert.Equal(t, http.StatusNotFound, w.Code) +} + +func TestCalorieHandlers_GetLastCalculation_NoUserID(t *testing.T) { + mockService := new(mockCalorieService) + logger := logger.New("INFO", "json") + + handlers := NewCalorieHandlers(mockService, logger) + + req := httptest.NewRequest("GET", "/api/v1/calorie/last", http.NoBody) + // No user ID in context + + w := httptest.NewRecorder() + handlers.GetLastCalculation(w, req) + + assert.Equal(t, http.StatusInternalServerError, w.Code) +} diff --git a/internal/http/helpers.go b/internal/http/helpers.go new file mode 100644 index 0000000..ab5d473 --- /dev/null +++ b/internal/http/helpers.go @@ -0,0 +1,36 @@ +package http + +import ( + "context" + "fmt" + + "github.com/google/uuid" +) + +type contextKey string + +const ( + UserIDKey contextKey = "user_id" + UserEmailKey contextKey = "user_email" +) + +func GetUserIDFromContext(ctx context.Context) (uuid.UUID, error) { + userID, ok := ctx.Value(UserIDKey).(string) + if !ok { + return uuid.Nil, fmt.Errorf("user ID not found in context") + } + + userUUID, err := uuid.Parse(userID) + if err != nil { + return uuid.Nil, fmt.Errorf("invalid user ID format: %w", err) + } + return userUUID, nil +} + +func GetUserEmailFromContext(ctx context.Context) (string, error) { + email, ok := ctx.Value(UserEmailKey).(string) + if !ok { + return "", fmt.Errorf("user email not found in context") + } + return email, nil +} diff --git a/internal/http/user_handlers.go b/internal/http/user_handlers.go index bb27dba..3470953 100644 --- a/internal/http/user_handlers.go +++ b/internal/http/user_handlers.go @@ -7,7 +7,6 @@ import ( "github.com/aleksandr/strive-api/internal/logger" "github.com/aleksandr/strive-api/internal/services" "github.com/aleksandr/strive-api/internal/validation" - "github.com/google/uuid" ) type UserHandlers struct { @@ -40,28 +39,21 @@ type UpdateUserThemeRequest struct { // @Failure 500 {object} ErrorResponse "Internal server error" // @Router /api/v1/user/me [get] func (h *UserHandlers) Me(w http.ResponseWriter, r *http.Request) { - userID, ok := GetUserIDFromContext(r.Context()) - if !ok { - h.logger.Error("User ID not found in context") - http.Error(w, `{"error":{"code":"INTERNAL_ERROR","message":"User ID not found in context"}}`, http.StatusInternalServerError) - return - } - - userUUID, err := uuid.Parse(userID) + userUUID, err := GetUserIDFromContext(r.Context()) if err != nil { - h.logger.Error("Invalid user ID format", "error", err, "user_id", userID) - http.Error(w, `{"error":{"code":"INVALID_USER_ID","message":"Invalid user ID format"}}`, http.StatusInternalServerError) + h.logger.Error("User ID not found in context", "error", err) + http.Error(w, `{"error":{"code":"INTERNAL_ERROR","message":"User ID not found in context"}}`, http.StatusInternalServerError) return } user, err := h.userService.GetUserProfile(r.Context(), userUUID) if err != nil { - h.logger.Error("Failed to get user profile", "error", err, "user_id", userID) + h.logger.Error("Failed to get user profile", "error", err, "user_id", userUUID) http.Error(w, `{"error":{"code":"USER_NOT_FOUND","message":"User not found"}}`, http.StatusNotFound) return } - h.logger.Info("User profile requested", "user_id", userID, "email", user.Email, "theme", user.Theme) + h.logger.Info("User profile requested", "user_id", userUUID, "email", user.Email, "theme", user.Theme) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) @@ -82,9 +74,9 @@ func (h *UserHandlers) Me(w http.ResponseWriter, r *http.Request) { // @Failure 500 {object} ErrorResponse "Internal server error" // @Router /api/v1/user/theme [put] func (h *UserHandlers) UpdateTheme(w http.ResponseWriter, r *http.Request) { - userID, ok := GetUserIDFromContext(r.Context()) - if !ok { - h.logger.Error("User ID not found in context") + userUUID, err := GetUserIDFromContext(r.Context()) + if err != nil { + h.logger.Error("User ID not found in context", "error", err) http.Error(w, `{"error":{"code":"INTERNAL_ERROR","message":"User ID not found in context"}}`, http.StatusInternalServerError) return } @@ -117,15 +109,8 @@ func (h *UserHandlers) UpdateTheme(w http.ResponseWriter, r *http.Request) { return } - userUUID, err := uuid.Parse(userID) - if err != nil { - h.logger.Error("Invalid user ID format", "error", err, "user_id", userID) - http.Error(w, `{"error":{"code":"INVALID_USER_ID","message":"Invalid user ID format"}}`, http.StatusInternalServerError) - return - } - if err := h.userService.UpdateUserTheme(r.Context(), userUUID, req.Theme); err != nil { - h.logger.Error("Failed to update user theme", "error", err, "user_id", userID, "theme", req.Theme) + h.logger.Error("Failed to update user theme", "error", err, "user_id", userUUID, "theme", req.Theme) if err == services.ErrInvalidTheme { http.Error(w, `{"error":{"code":"INVALID_THEME","message":"Invalid theme value"}}`, http.StatusBadRequest) } else { @@ -134,7 +119,7 @@ func (h *UserHandlers) UpdateTheme(w http.ResponseWriter, r *http.Request) { return } - h.logger.Info("User theme updated successfully", "user_id", userID, "theme", req.Theme) + h.logger.Info("User theme updated successfully", "user_id", userUUID, "theme", req.Theme) response := map[string]interface{}{ "message": "Theme updated successfully", diff --git a/internal/models/calorie.go b/internal/models/calorie.go new file mode 100644 index 0000000..c851007 --- /dev/null +++ b/internal/models/calorie.go @@ -0,0 +1,63 @@ +package models + +import ( + "time" + + "github.com/google/uuid" +) + +type CalorieCalculationData struct { + Gender string `json:"gender" validate:"required,oneof=male female"` + Age int `json:"age" validate:"required,min=15,max=120"` + Height float64 `json:"height" validate:"required,min=100,max=250"` + Weight float64 `json:"weight" validate:"required,min=30,max=300"` + ActivityLevel string `json:"activityLevel" validate:"required"` + Goal string `json:"goal" validate:"required,oneof=lose_weight maintain_weight gain_weight"` + BodyFatPercentage *float64 `json:"bodyFatPercentage,omitempty" validate:"omitempty,min=5,max=50"` +} + +type Macronutrients struct { + ProteinGrams int `json:"proteinGrams"` + ProteinPercentage float64 `json:"proteinPercentage"` + FatGrams int `json:"fatGrams"` + FatPercentage float64 `json:"fatPercentage"` + CarbsGrams int `json:"carbsGrams"` + CarbsPercentage float64 `json:"carbsPercentage"` +} + +type CalorieResults struct { + BMR int `json:"bmr"` + TDEE int `json:"tdee"` + TargetCalories int `json:"targetCalories"` + Formula string `json:"formula"` + Macros Macronutrients `json:"macros"` +} + +type CalorieCalculation struct { + ID uuid.UUID `json:"id" db:"id"` + UserID uuid.UUID `json:"user_id" db:"user_id"` + Gender string `json:"gender" db:"gender"` + Age int `json:"age" db:"age"` + Height float64 `json:"height" db:"height"` + Weight float64 `json:"weight" db:"weight"` + ActivityLevel string `json:"activityLevel" db:"activity_level"` + Goal string `json:"goal" db:"goal"` + BMR int `json:"bmr" db:"bmr"` + TDEE int `json:"tdee" db:"tdee"` + TargetCalories int `json:"targetCalories" db:"target_calories"` + Formula string `json:"formula" db:"formula"` + ProteinGrams int `json:"proteinGrams" db:"protein_grams"` + ProteinPercentage float64 `json:"proteinPercentage" db:"protein_percentage"` + FatGrams int `json:"fatGrams" db:"fat_grams"` + FatPercentage float64 `json:"fatPercentage" db:"fat_percentage"` + CarbsGrams int `json:"carbsGrams" db:"carbs_grams"` + CarbsPercentage float64 `json:"carbsPercentage" db:"carbs_percentage"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` +} + +type CalorieCalculationResponse struct { + Data CalorieCalculationData `json:"data"` + Results CalorieResults `json:"results"` + Timestamp time.Time `json:"timestamp"` +} diff --git a/internal/models/calorie_constants.go b/internal/models/calorie_constants.go new file mode 100644 index 0000000..8e1df02 --- /dev/null +++ b/internal/models/calorie_constants.go @@ -0,0 +1,109 @@ +package models + +const ( + // Gender constants + GenderMale = "male" + GenderFemale = "female" + + // Activity level constants + ActivitySedentary = "sedentary" + ActivityLightlyActive = "lightly_active" + ActivityModeratelyActive = "moderately_active" + ActivityVeryActive = "very_active" + ActivityExtremelyActive = "extremely_active" + + // Goal constants + GoalLoseWeight = "lose_weight" + GoalMaintainWeight = "maintain_weight" + GoalGainWeight = "gain_weight" + + // Formula constants + FormulaMifflin = "mifflin" + + // Activity level multipliers for TDEE calculation + ActivityMultiplierSedentary = 1.2 + ActivityMultiplierLightlyActive = 1.375 + ActivityMultiplierModeratelyActive = 1.55 + ActivityMultiplierVeryActive = 1.725 + ActivityMultiplierExtremelyActive = 1.9 + + // Goal modifiers for target calories + GoalModifierLoseWeight = -0.20 + GoalModifierMaintainWeight = 0.0 + GoalModifierGainWeight = 0.15 + + // Validation constants + ActivityLevelValidation = "required,oneof=sedentary lightly_active moderately_active very_active extremely_active" + GoalValidation = "required,oneof=lose_weight maintain_weight gain_weight" + GenderValidation = "required,oneof=male female" + AgeValidation = "required,min=15,max=120" + HeightValidation = "required,min=100,max=250" + WeightValidation = "required,min=30,max=300" + BodyFatValidation = "omitempty,min=5,max=50" +) + +// Base protein per kg by activity level (на сухую массу тела) +var BaseProteinByActivity = map[string]float64{ + ActivitySedentary: 1.2, // Здоровый активный человек + ActivityLightlyActive: 1.4, // Здоровый активный человек + ActivityModeratelyActive: 1.6, // Здоровый активный человек + ActivityVeryActive: 1.8, // Силовые тренировки + ActivityExtremelyActive: 2.0, // Силовые тренировки +} + +// Protein adjustments by goal +var ProteinAdjustmentByGoal = map[string]float64{ + GoalLoseWeight: 0.4, // При похудении: 1.6 + 0.4 = 2.0г/кг (сохранить мышцы) + GoalMaintainWeight: 0.0, // Без изменений + GoalGainWeight: 0.2, // При наборе: 1.6 + 0.2 = 1.8г/кг (набор мышц) +} + +// Fat percentages by goal +var FatPercentageByGoal = map[string]float64{ + GoalLoseWeight: 25.0, + GoalMaintainWeight: 30.0, + GoalGainWeight: 35.0, +} + +// Genders returns slice of gender values +func Genders() []string { + return []string{GenderMale, GenderFemale} +} + +// ActivityLevels returns slice of activity level values +func ActivityLevels() []string { + return []string{ + ActivitySedentary, + ActivityLightlyActive, + ActivityModeratelyActive, + ActivityVeryActive, + ActivityExtremelyActive, + } +} + +// Goals returns slice of goal values +func Goals() []string { + return []string{GoalLoseWeight, GoalMaintainWeight, GoalGainWeight} +} + +// GetActivityMultiplier returns the multiplier for given activity level +func GetActivityMultiplier(activityLevel string) float64 { + multipliers := map[string]float64{ + ActivitySedentary: ActivityMultiplierSedentary, + ActivityLightlyActive: ActivityMultiplierLightlyActive, + ActivityModeratelyActive: ActivityMultiplierModeratelyActive, + ActivityVeryActive: ActivityMultiplierVeryActive, + ActivityExtremelyActive: ActivityMultiplierExtremelyActive, + } + return multipliers[activityLevel] +} + +// GetGoalModifier returns the modifier for given goal +func GetGoalModifier(goal string) float64 { + modifiers := map[string]float64{ + GoalLoseWeight: GoalModifierLoseWeight, + GoalMaintainWeight: GoalModifierMaintainWeight, + GoalGainWeight: GoalModifierGainWeight, + } + return modifiers[goal] +} diff --git a/internal/repositories/calorie_repository.go b/internal/repositories/calorie_repository.go new file mode 100644 index 0000000..37551e2 --- /dev/null +++ b/internal/repositories/calorie_repository.go @@ -0,0 +1,137 @@ +package repositories + +import ( + "context" + "fmt" + "time" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/aleksandr/strive-api/internal/models" +) + +type CalorieRepository interface { + SaveOrUpdate(ctx context.Context, calculation *models.CalorieCalculation) error + GetByUserID(ctx context.Context, userID uuid.UUID) (*models.CalorieCalculation, error) +} + +type calorieRepository struct { + pool *pgxpool.Pool +} + +func NewCalorieRepository(pool *pgxpool.Pool) CalorieRepository { + return &calorieRepository{ + pool: pool, + } +} + +func (r *calorieRepository) SaveOrUpdate(ctx context.Context, calculation *models.CalorieCalculation) error { + query := ` + INSERT INTO calorie_calculations ( + id, user_id, gender, age, height, weight, activity_level, goal, + bmr, tdee, target_calories, formula, + protein_grams, protein_percentage, fat_grams, fat_percentage, + carbs_grams, carbs_percentage, created_at, updated_at + ) VALUES ( + $1, $2, $3, $4, $5, $6, $7, $8, + $9, $10, $11, $12, + $13, $14, $15, $16, + $17, $18, $19, $20 + ) + ON CONFLICT (user_id) DO UPDATE SET + gender = EXCLUDED.gender, + age = EXCLUDED.age, + height = EXCLUDED.height, + weight = EXCLUDED.weight, + activity_level = EXCLUDED.activity_level, + goal = EXCLUDED.goal, + bmr = EXCLUDED.bmr, + tdee = EXCLUDED.tdee, + target_calories = EXCLUDED.target_calories, + formula = EXCLUDED.formula, + protein_grams = EXCLUDED.protein_grams, + protein_percentage = EXCLUDED.protein_percentage, + fat_grams = EXCLUDED.fat_grams, + fat_percentage = EXCLUDED.fat_percentage, + carbs_grams = EXCLUDED.carbs_grams, + carbs_percentage = EXCLUDED.carbs_percentage, + updated_at = EXCLUDED.updated_at + ` + + now := time.Now() + if calculation.CreatedAt.IsZero() { + calculation.CreatedAt = now + } + calculation.UpdatedAt = now + + _, err := r.pool.Exec(ctx, query, + calculation.ID, + calculation.UserID, + calculation.Gender, + calculation.Age, + calculation.Height, + calculation.Weight, + calculation.ActivityLevel, + calculation.Goal, + calculation.BMR, + calculation.TDEE, + calculation.TargetCalories, + calculation.Formula, + calculation.ProteinGrams, + calculation.ProteinPercentage, + calculation.FatGrams, + calculation.FatPercentage, + calculation.CarbsGrams, + calculation.CarbsPercentage, + calculation.CreatedAt, + calculation.UpdatedAt, + ) + if err != nil { + return fmt.Errorf("failed to save calorie calculation: %w", err) + } + + return nil +} + +func (r *calorieRepository) GetByUserID(ctx context.Context, userID uuid.UUID) (*models.CalorieCalculation, error) { + query := ` + SELECT id, user_id, gender, age, height, weight, activity_level, goal, + bmr, tdee, target_calories, formula, + protein_grams, protein_percentage, fat_grams, fat_percentage, + carbs_grams, carbs_percentage, created_at, updated_at + FROM calorie_calculations + WHERE user_id = $1 + ORDER BY created_at DESC + LIMIT 1 + ` + + var calculation models.CalorieCalculation + err := r.pool.QueryRow(ctx, query, userID).Scan( + &calculation.ID, + &calculation.UserID, + &calculation.Gender, + &calculation.Age, + &calculation.Height, + &calculation.Weight, + &calculation.ActivityLevel, + &calculation.Goal, + &calculation.BMR, + &calculation.TDEE, + &calculation.TargetCalories, + &calculation.Formula, + &calculation.ProteinGrams, + &calculation.ProteinPercentage, + &calculation.FatGrams, + &calculation.FatPercentage, + &calculation.CarbsGrams, + &calculation.CarbsPercentage, + &calculation.CreatedAt, + &calculation.UpdatedAt, + ) + if err != nil { + return nil, fmt.Errorf("failed to get calorie calculation: %w", err) + } + + return &calculation, nil +} diff --git a/internal/repositories/calorie_repository_test.go b/internal/repositories/calorie_repository_test.go new file mode 100644 index 0000000..422ab6e --- /dev/null +++ b/internal/repositories/calorie_repository_test.go @@ -0,0 +1,227 @@ +package repositories + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/aleksandr/strive-api/internal/models" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +type mockCalorieRepository struct { + mock.Mock +} + +func (m *mockCalorieRepository) SaveOrUpdate(ctx context.Context, calculation *models.CalorieCalculation) error { + args := m.Called(ctx, calculation) + return args.Error(0) +} + +func (m *mockCalorieRepository) GetByUserID(ctx context.Context, userID uuid.UUID) (*models.CalorieCalculation, error) { + args := m.Called(ctx, userID) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*models.CalorieCalculation), args.Error(1) +} + +func TestCalorieRepository_SaveOrUpdate(t *testing.T) { + repo := &mockCalorieRepository{} + userID := uuid.New() + calculation := &models.CalorieCalculation{ + ID: uuid.New(), + UserID: userID, + Gender: models.GenderMale, + Age: 25, + Height: 175.0, + Weight: 70.0, + ActivityLevel: models.ActivityModeratelyActive, + Goal: models.GoalMaintainWeight, + BMR: 1650, + TDEE: 2558, + TargetCalories: 2558, + Formula: models.FormulaMifflin, + ProteinGrams: 196, + ProteinPercentage: 30.6, + FatGrams: 85, + FatPercentage: 30.0, + CarbsGrams: 256, + CarbsPercentage: 39.4, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + repo.On("SaveOrUpdate", mock.Anything, calculation).Return(nil) + + err := repo.SaveOrUpdate(context.Background(), calculation) + assert.NoError(t, err) + repo.AssertExpectations(t) +} + +func TestCalorieRepository_SaveOrUpdate_Error(t *testing.T) { + repo := &mockCalorieRepository{} + userID := uuid.New() + calculation := &models.CalorieCalculation{ + ID: uuid.New(), + UserID: userID, + } + + expectedError := errors.New("database error") + repo.On("SaveOrUpdate", mock.Anything, calculation).Return(expectedError) + + err := repo.SaveOrUpdate(context.Background(), calculation) + assert.Error(t, err) + assert.Equal(t, expectedError, err) + repo.AssertExpectations(t) +} + +func TestCalorieRepository_GetByUserID(t *testing.T) { + repo := &mockCalorieRepository{} + userID := uuid.New() + expectedCalculation := &models.CalorieCalculation{ + ID: uuid.New(), + UserID: userID, + Gender: models.GenderFemale, + Age: 28, + Height: 165.0, + Weight: 65.0, + ActivityLevel: models.ActivityModeratelyActive, + Goal: models.GoalLoseWeight, + BMR: 1385, + TDEE: 2147, + TargetCalories: 1718, + Formula: models.FormulaMifflin, + ProteinGrams: 120, + ProteinPercentage: 28.0, + FatGrams: 48, + FatPercentage: 25.0, + CarbsGrams: 172, + CarbsPercentage: 40.0, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + repo.On("GetByUserID", mock.Anything, userID).Return(expectedCalculation, nil) + + result, err := repo.GetByUserID(context.Background(), userID) + assert.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, expectedCalculation.ID, result.ID) + assert.Equal(t, expectedCalculation.UserID, result.UserID) + assert.Equal(t, expectedCalculation.Gender, result.Gender) + assert.Equal(t, expectedCalculation.Age, result.Age) + assert.Equal(t, expectedCalculation.Height, result.Height) + assert.Equal(t, expectedCalculation.Weight, result.Weight) + assert.Equal(t, expectedCalculation.ActivityLevel, result.ActivityLevel) + assert.Equal(t, expectedCalculation.Goal, result.Goal) + assert.Equal(t, expectedCalculation.BMR, result.BMR) + assert.Equal(t, expectedCalculation.TDEE, result.TDEE) + assert.Equal(t, expectedCalculation.TargetCalories, result.TargetCalories) + assert.Equal(t, expectedCalculation.Formula, result.Formula) + assert.Equal(t, expectedCalculation.ProteinGrams, result.ProteinGrams) + assert.Equal(t, expectedCalculation.ProteinPercentage, result.ProteinPercentage) + assert.Equal(t, expectedCalculation.FatGrams, result.FatGrams) + assert.Equal(t, expectedCalculation.FatPercentage, result.FatPercentage) + assert.Equal(t, expectedCalculation.CarbsGrams, result.CarbsGrams) + assert.Equal(t, expectedCalculation.CarbsPercentage, result.CarbsPercentage) + repo.AssertExpectations(t) +} + +func TestCalorieRepository_GetByUserID_NotFound(t *testing.T) { + repo := &mockCalorieRepository{} + nonExistentUserID := uuid.New() + + expectedError := errors.New("calculation not found") + repo.On("GetByUserID", mock.Anything, nonExistentUserID).Return(nil, expectedError) + + result, err := repo.GetByUserID(context.Background(), nonExistentUserID) + assert.Error(t, err) + assert.Nil(t, result) + assert.Equal(t, expectedError, err) + repo.AssertExpectations(t) +} + +func TestCalorieRepository_UpdateExisting(t *testing.T) { + repo := &mockCalorieRepository{} + userID := uuid.New() + + initialCalculation := &models.CalorieCalculation{ + ID: uuid.New(), + UserID: userID, + Gender: models.GenderMale, + Age: 25, + Height: 175.0, + Weight: 70.0, + ActivityLevel: models.ActivityModeratelyActive, + Goal: models.GoalMaintainWeight, + BMR: 1650, + TDEE: 2558, + TargetCalories: 2558, + Formula: models.FormulaMifflin, + ProteinGrams: 196, + ProteinPercentage: 30.6, + FatGrams: 85, + FatPercentage: 30.0, + CarbsGrams: 256, + CarbsPercentage: 39.4, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + updatedCalculation := &models.CalorieCalculation{ + ID: initialCalculation.ID, + UserID: userID, + Gender: models.GenderMale, + Age: 26, + Height: 176.0, + Weight: 72.0, + ActivityLevel: models.ActivityVeryActive, + Goal: models.GoalGainWeight, + BMR: 1680, + TDEE: 3024, + TargetCalories: 3326, + Formula: models.FormulaMifflin, + ProteinGrams: 208, + ProteinPercentage: 25.0, + FatGrams: 119, + FatPercentage: 30.0, + CarbsGrams: 357, + CarbsPercentage: 40.0, + CreatedAt: initialCalculation.CreatedAt, + UpdatedAt: time.Now(), + } + + repo.On("SaveOrUpdate", mock.Anything, initialCalculation).Return(nil) + repo.On("SaveOrUpdate", mock.Anything, updatedCalculation).Return(nil) + repo.On("GetByUserID", mock.Anything, userID).Return(updatedCalculation, nil) + + err := repo.SaveOrUpdate(context.Background(), initialCalculation) + assert.NoError(t, err) + + err = repo.SaveOrUpdate(context.Background(), updatedCalculation) + assert.NoError(t, err) + + result, err := repo.GetByUserID(context.Background(), userID) + assert.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, updatedCalculation.Age, result.Age) + assert.Equal(t, updatedCalculation.Height, result.Height) + assert.Equal(t, updatedCalculation.Weight, result.Weight) + assert.Equal(t, updatedCalculation.ActivityLevel, result.ActivityLevel) + assert.Equal(t, updatedCalculation.Goal, result.Goal) + assert.Equal(t, updatedCalculation.BMR, result.BMR) + assert.Equal(t, updatedCalculation.TDEE, result.TDEE) + assert.Equal(t, updatedCalculation.TargetCalories, result.TargetCalories) + assert.Equal(t, updatedCalculation.ProteinGrams, result.ProteinGrams) + assert.Equal(t, updatedCalculation.ProteinPercentage, result.ProteinPercentage) + assert.Equal(t, updatedCalculation.FatGrams, result.FatGrams) + assert.Equal(t, updatedCalculation.FatPercentage, result.FatPercentage) + assert.Equal(t, updatedCalculation.CarbsGrams, result.CarbsGrams) + assert.Equal(t, updatedCalculation.CarbsPercentage, result.CarbsPercentage) + assert.True(t, result.UpdatedAt.After(result.CreatedAt)) + repo.AssertExpectations(t) +} diff --git a/internal/services/calorie_service.go b/internal/services/calorie_service.go new file mode 100644 index 0000000..8466d19 --- /dev/null +++ b/internal/services/calorie_service.go @@ -0,0 +1,237 @@ +package services + +import ( + "context" + "fmt" + "math" + "time" + + "github.com/google/uuid" + + "github.com/aleksandr/strive-api/internal/models" + "github.com/aleksandr/strive-api/internal/repositories" +) + +type CalorieService interface { + CalculateCalories(ctx context.Context, userID uuid.UUID, data *models.CalorieCalculationData) (*models.CalorieResults, error) + GetLastCalculation(ctx context.Context, userID uuid.UUID) (*models.CalorieCalculationResponse, error) +} + +type calorieService struct { + calorieRepo repositories.CalorieRepository +} + +func NewCalorieService(calorieRepo repositories.CalorieRepository) CalorieService { + return &calorieService{ + calorieRepo: calorieRepo, + } +} + +func (s *calorieService) CalculateCalories( + ctx context.Context, + userID uuid.UUID, + data *models.CalorieCalculationData, +) (*models.CalorieResults, error) { + bmr := s.calculateBMRMifflin(data) + tdee := s.calculateTDEE(bmr, data.ActivityLevel) + targetCalories := s.calculateTargetCalories(tdee, data.Goal) + macros := s.calculateMacronutrients( + targetCalories, + data.ActivityLevel, + data.Goal, + data.Weight, + data.Gender, + data.Age, + data.Height, + data.BodyFatPercentage, + ) + + results := &models.CalorieResults{ + BMR: int(math.Round(bmr)), + TDEE: tdee, + TargetCalories: targetCalories, + Formula: models.FormulaMifflin, + Macros: macros, + } + + calculation := &models.CalorieCalculation{ + ID: uuid.New(), + UserID: userID, + Gender: data.Gender, + Age: data.Age, + Height: data.Height, + Weight: data.Weight, + ActivityLevel: data.ActivityLevel, + Goal: data.Goal, + BMR: results.BMR, + TDEE: results.TDEE, + TargetCalories: results.TargetCalories, + Formula: results.Formula, + ProteinGrams: macros.ProteinGrams, + ProteinPercentage: macros.ProteinPercentage, + FatGrams: macros.FatGrams, + FatPercentage: macros.FatPercentage, + CarbsGrams: macros.CarbsGrams, + CarbsPercentage: macros.CarbsPercentage, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + if err := s.calorieRepo.SaveOrUpdate(ctx, calculation); err != nil { + return nil, fmt.Errorf("failed to save calculation: %w", err) + } + + return results, nil +} + +func (s *calorieService) GetLastCalculation(ctx context.Context, userID uuid.UUID) (*models.CalorieCalculationResponse, error) { + calculation, err := s.calorieRepo.GetByUserID(ctx, userID) + if err != nil { + return nil, fmt.Errorf("failed to get calculation: %w", err) + } + + if calculation == nil { + return nil, fmt.Errorf("no calculation found for user") + } + + data := models.CalorieCalculationData{ + Gender: calculation.Gender, + Age: calculation.Age, + Height: calculation.Height, + Weight: calculation.Weight, + ActivityLevel: calculation.ActivityLevel, + Goal: calculation.Goal, + } + + results := models.CalorieResults{ + BMR: calculation.BMR, + TDEE: calculation.TDEE, + TargetCalories: calculation.TargetCalories, + Formula: calculation.Formula, + Macros: models.Macronutrients{ + ProteinGrams: calculation.ProteinGrams, + ProteinPercentage: calculation.ProteinPercentage, + FatGrams: calculation.FatGrams, + FatPercentage: calculation.FatPercentage, + CarbsGrams: calculation.CarbsGrams, + CarbsPercentage: calculation.CarbsPercentage, + }, + } + + return &models.CalorieCalculationResponse{ + Data: data, + Results: results, + Timestamp: calculation.CreatedAt, + }, nil +} + +func (s *calorieService) calculateBMRMifflin(data *models.CalorieCalculationData) float64 { + const ( + weightMultiplier = 10.0 + heightMultiplier = 6.25 + ageMultiplier = 5.0 + maleOffset = 5.0 + femaleOffset = -161.0 + ) + + base := weightMultiplier*data.Weight + heightMultiplier*data.Height - ageMultiplier*float64(data.Age) + + if data.Gender == models.GenderMale { + return base + maleOffset + } + return base + femaleOffset +} + +func (s *calorieService) calculateTDEE(bmr float64, activityLevel string) int { + return int(math.Round(bmr * models.GetActivityMultiplier(activityLevel))) +} + +func (s *calorieService) calculateTargetCalories(tdee int, goal string) int { + modifier := models.GetGoalModifier(goal) + return int(math.Round(float64(tdee) * (1 + modifier))) +} + +func (s *calorieService) calculateBodyFatPercentage(gender string, age int, weight, height float64) float64 { + bmi := weight / ((height / 100) * (height / 100)) + genderValue := 1.0 + if gender == "female" { + genderValue = 0.0 + } + + bodyFatPercentage := 1.20*bmi + 0.23*float64(age) - 10.8*genderValue - 5.4 + + // Более реалистичные ограничения для процента жира + if gender == "male" { + if bodyFatPercentage < 8.0 { + bodyFatPercentage = 8.0 + } + if bodyFatPercentage > 25.0 { + bodyFatPercentage = 25.0 + } + } else { + if bodyFatPercentage < 12.0 { + bodyFatPercentage = 12.0 + } + if bodyFatPercentage > 35.0 { + bodyFatPercentage = 35.0 + } + } + + return bodyFatPercentage +} + +func (s *calorieService) getBodyFatPercentage(provided *float64, gender string, age int, weight, height float64) float64 { + if provided != nil { + return *provided + } + return s.calculateBodyFatPercentage(gender, age, weight, height) +} + +func (s *calorieService) calculateMacronutrients( + targetCalories int, + activityLevel, goal string, + weight float64, + gender string, + age int, + height float64, + bodyFatPercentage *float64, +) models.Macronutrients { + actualBodyFatPercentage := s.getBodyFatPercentage(bodyFatPercentage, gender, age, weight, height) + leanBodyMass := weight * (1 - actualBodyFatPercentage/100) + + baseProteinPerKg := models.BaseProteinByActivity[activityLevel] + proteinAdjustment := models.ProteinAdjustmentByGoal[goal] + + proteinPerKg := baseProteinPerKg + proteinAdjustment + + proteinGrams := int(math.Round(proteinPerKg * leanBodyMass)) + proteinCalories := proteinGrams * 4 + + fatPercentage := models.FatPercentageByGoal[goal] + fatCalories := int(math.Round(float64(targetCalories) * fatPercentage / 100)) + fatGrams := int(math.Round(float64(fatCalories) / 9)) + + carbsCalories := targetCalories - proteinCalories - fatCalories + carbsGrams := int(math.Round(float64(carbsCalories) / 4)) + + if carbsGrams < 0 { + carbsGrams = 0 + carbsCalories = 0 + adjustedFatCalories := targetCalories - proteinCalories + fatGrams = int(math.Round(float64(adjustedFatCalories) / 9)) + fatCalories = fatGrams * 9 + } + + actualProteinPercentage := float64(proteinCalories) / float64(targetCalories) * 100 + actualFatPercentage := float64(fatCalories) / float64(targetCalories) * 100 + actualCarbsPercentage := float64(carbsCalories) / float64(targetCalories) * 100 + + return models.Macronutrients{ + ProteinGrams: proteinGrams, + ProteinPercentage: math.Round(actualProteinPercentage*100) / 100, + FatGrams: fatGrams, + FatPercentage: math.Round(actualFatPercentage*100) / 100, + CarbsGrams: carbsGrams, + CarbsPercentage: math.Round(actualCarbsPercentage*100) / 100, + } +} diff --git a/internal/services/calorie_service_test.go b/internal/services/calorie_service_test.go new file mode 100644 index 0000000..cd2eb1c --- /dev/null +++ b/internal/services/calorie_service_test.go @@ -0,0 +1,362 @@ +package services + +import ( + "context" + "testing" + "time" + + "github.com/aleksandr/strive-api/internal/models" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +type mockCalorieRepository struct { + mock.Mock +} + +func (m *mockCalorieRepository) SaveOrUpdate(ctx context.Context, calculation *models.CalorieCalculation) error { + args := m.Called(ctx, calculation) + return args.Error(0) +} + +func (m *mockCalorieRepository) GetByUserID(ctx context.Context, userID uuid.UUID) (*models.CalorieCalculation, error) { + args := m.Called(ctx, userID) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*models.CalorieCalculation), args.Error(1) +} + +func TestCalorieService_CalculateCalories(t *testing.T) { + tests := []struct { + name string + userID uuid.UUID + data *models.CalorieCalculationData + expectedBMR int + expectedTDEE int + expectedTarget int + }{ + { + name: "Male, 25 years, 175cm, 70kg, moderately active, maintain weight", + userID: uuid.New(), + data: &models.CalorieCalculationData{ + Gender: models.GenderMale, + Age: 25, + Height: 175.0, + Weight: 70.0, + ActivityLevel: models.ActivityModeratelyActive, + Goal: models.GoalMaintainWeight, + }, + expectedBMR: 1674, + expectedTDEE: 2594, + expectedTarget: 2594, + }, + { + name: "Female, 28 years, 165cm, 65kg, moderately active, lose weight", + userID: uuid.New(), + data: &models.CalorieCalculationData{ + Gender: models.GenderFemale, + Age: 28, + Height: 165.0, + Weight: 65.0, + ActivityLevel: models.ActivityModeratelyActive, + Goal: models.GoalLoseWeight, + }, + expectedBMR: 1380, + expectedTDEE: 2139, + expectedTarget: 1711, + }, + { + name: "Male, 30 years, 180cm, 80kg, very active, gain weight", + userID: uuid.New(), + data: &models.CalorieCalculationData{ + Gender: models.GenderMale, + Age: 30, + Height: 180.0, + Weight: 80.0, + ActivityLevel: models.ActivityVeryActive, + Goal: models.GoalGainWeight, + }, + expectedBMR: 1780, + expectedTDEE: 3071, + expectedTarget: 3532, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockRepo := new(mockCalorieRepository) + mockRepo.On("SaveOrUpdate", mock.Anything, mock.AnythingOfType("*models.CalorieCalculation")).Return(nil) + + service := NewCalorieService(mockRepo) + result, err := service.CalculateCalories(context.Background(), tt.userID, tt.data) + + assert.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, tt.expectedBMR, result.BMR) + assert.Equal(t, tt.expectedTDEE, result.TDEE) + assert.Equal(t, tt.expectedTarget, result.TargetCalories) + assert.NotNil(t, result.Macros) + assert.Greater(t, result.Macros.ProteinGrams, 0) + assert.Greater(t, result.Macros.FatGrams, 0) + assert.Greater(t, result.Macros.CarbsGrams, 0) + + mockRepo.AssertExpectations(t) + }) + } +} + +func TestCalorieService_GetLastCalculation(t *testing.T) { + userID := uuid.New() + expectedCalculation := &models.CalorieCalculation{ + ID: uuid.New(), + UserID: userID, + Gender: models.GenderMale, + Age: 25, + Height: 175.0, + Weight: 70.0, + ActivityLevel: models.ActivityModeratelyActive, + Goal: models.GoalMaintainWeight, + BMR: 1650, + TDEE: 2558, + TargetCalories: 2558, + Formula: models.FormulaMifflin, + ProteinGrams: 196, + ProteinPercentage: 30.6, + FatGrams: 85, + FatPercentage: 30.0, + CarbsGrams: 256, + CarbsPercentage: 39.4, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + mockRepo := new(mockCalorieRepository) + mockRepo.On("GetByUserID", mock.Anything, userID).Return(expectedCalculation, nil) + + service := NewCalorieService(mockRepo) + result, err := service.GetLastCalculation(context.Background(), userID) + + assert.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, expectedCalculation.Gender, result.Data.Gender) + assert.Equal(t, expectedCalculation.Gender, result.Data.Gender) + assert.Equal(t, expectedCalculation.Age, result.Data.Age) + assert.Equal(t, expectedCalculation.Height, result.Data.Height) + assert.Equal(t, expectedCalculation.Weight, result.Data.Weight) + assert.Equal(t, expectedCalculation.ActivityLevel, result.Data.ActivityLevel) + assert.Equal(t, expectedCalculation.Goal, result.Data.Goal) + assert.Equal(t, expectedCalculation.BMR, result.Results.BMR) + assert.Equal(t, expectedCalculation.TDEE, result.Results.TDEE) + assert.Equal(t, expectedCalculation.TargetCalories, result.Results.TargetCalories) + assert.Equal(t, expectedCalculation.Formula, result.Results.Formula) + assert.Equal(t, expectedCalculation.ProteinGrams, result.Results.Macros.ProteinGrams) + assert.Equal(t, expectedCalculation.ProteinPercentage, result.Results.Macros.ProteinPercentage) + assert.Equal(t, expectedCalculation.FatGrams, result.Results.Macros.FatGrams) + assert.Equal(t, expectedCalculation.FatPercentage, result.Results.Macros.FatPercentage) + assert.Equal(t, expectedCalculation.CarbsGrams, result.Results.Macros.CarbsGrams) + assert.Equal(t, expectedCalculation.CarbsPercentage, result.Results.Macros.CarbsPercentage) + assert.Equal(t, expectedCalculation.CreatedAt, result.Timestamp) + + mockRepo.AssertExpectations(t) +} + +func TestCalorieService_GetLastCalculation_NotFound(t *testing.T) { + userID := uuid.New() + + mockRepo := new(mockCalorieRepository) + mockRepo.On("GetByUserID", mock.Anything, userID).Return(nil, assert.AnError) + + service := NewCalorieService(mockRepo) + result, err := service.GetLastCalculation(context.Background(), userID) + + assert.Error(t, err) + assert.Nil(t, result) + mockRepo.AssertExpectations(t) +} + +func TestCalorieService_CalculateBMRMifflin(t *testing.T) { + service := &calorieService{} + + tests := []struct { + name string + data models.CalorieCalculationData + expected int + }{ + { + name: "Male, 25 years, 175cm, 70kg", + data: models.CalorieCalculationData{ + Gender: models.GenderMale, + Age: 25, + Height: 175.0, + Weight: 70.0, + }, + expected: 1673, + }, + { + name: "Female, 28 years, 165cm, 65kg", + data: models.CalorieCalculationData{ + Gender: models.GenderFemale, + Age: 28, + Height: 165.0, + Weight: 65.0, + }, + expected: 1380, + }, + { + name: "Male, 30 years, 180cm, 80kg", + data: models.CalorieCalculationData{ + Gender: models.GenderMale, + Age: 30, + Height: 180.0, + Weight: 80.0, + }, + expected: 1780, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := service.calculateBMRMifflin(&tt.data) + assert.Equal(t, tt.expected, int(result)) + }) + } +} + +func TestCalorieService_CalculateTDEE(t *testing.T) { + service := &calorieService{} + + tests := []struct { + name string + bmr float64 + activityLevel string + expected int + }{ + { + name: "Sedentary", + bmr: 1650.0, + activityLevel: models.ActivitySedentary, + expected: 1980, + }, + { + name: "Lightly Active", + bmr: 1650.0, + activityLevel: models.ActivityLightlyActive, + expected: 2269, + }, + { + name: "Moderately Active", + bmr: 1650.0, + activityLevel: models.ActivityModeratelyActive, + expected: 2558, + }, + { + name: "Very Active", + bmr: 1650.0, + activityLevel: models.ActivityVeryActive, + expected: 2846, + }, + { + name: "Extremely Active", + bmr: 1650.0, + activityLevel: models.ActivityExtremelyActive, + expected: 3135, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := service.calculateTDEE(tt.bmr, tt.activityLevel) + assert.Equal(t, tt.expected, int(result)) + }) + } +} + +func TestCalorieService_CalculateTargetCalories(t *testing.T) { + service := &calorieService{} + + tests := []struct { + name string + tdee float64 + goal string + expected int + }{ + { + name: "Lose Weight", + tdee: 2000.0, + goal: models.GoalLoseWeight, + expected: 1600, + }, + { + name: "Maintain Weight", + tdee: 2000.0, + goal: models.GoalMaintainWeight, + expected: 2000, + }, + { + name: "Gain Weight", + tdee: 2000.0, + goal: models.GoalGainWeight, + expected: 2300, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := service.calculateTargetCalories(int(tt.tdee), tt.goal) + assert.Equal(t, tt.expected, int(result)) + }) + } +} + +func TestCalorieService_CalculateMacronutrients(t *testing.T) { + service := &calorieService{} + + tests := []struct { + name string + targetCalories int + activityLevel string + goal string + weight float64 + expectedProtein int + expectedFat int + expectedCarbs int + }{ + { + name: "Moderately Active, Maintain Weight", + targetCalories: 2000, + activityLevel: models.ActivityModeratelyActive, + goal: models.GoalMaintainWeight, + weight: 80.0, + expectedProtein: 140, + expectedFat: 67, + expectedCarbs: 200, + }, + { + name: "Very Active, Lose Weight", + targetCalories: 1800, + activityLevel: models.ActivityVeryActive, + goal: models.GoalLoseWeight, + weight: 70.0, + expectedProtein: 144, + expectedFat: 50, + expectedCarbs: 180, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := service.calculateMacronutrients(tt.targetCalories, tt.activityLevel, tt.goal, tt.weight, "male", 30, 180.0, nil) + + assert.NotNil(t, result) + assert.Greater(t, result.ProteinGrams, 0) + assert.Greater(t, result.FatGrams, 0) + assert.Greater(t, result.CarbsGrams, 0) + + // Check that percentages add up to approximately 100% + totalPercentage := result.ProteinPercentage + result.FatPercentage + result.CarbsPercentage + assert.InDelta(t, 100.0, totalPercentage, 1.0) + }) + } +} diff --git a/internal/validation/validator.go b/internal/validation/validator.go index f2207b2..e1d0eb9 100644 --- a/internal/validation/validator.go +++ b/internal/validation/validator.go @@ -5,6 +5,8 @@ import ( "regexp" "strings" "unicode" + + "github.com/aleksandr/strive-api/internal/models" ) type ValidationError struct { @@ -14,6 +16,38 @@ type ValidationError struct { type ValidationErrors []ValidationError +type Validator struct{} + +func (v *Validator) Validate(data interface{}) error { + var errors ValidationErrors + + if d, ok := data.(models.CalorieCalculationData); ok { + if err := ValidateGender(d.Gender); err != nil { + errors = append(errors, ValidationError{Field: "gender", Message: err.Error()}) + } + if err := ValidateAge(d.Age); err != nil { + errors = append(errors, ValidationError{Field: "age", Message: err.Error()}) + } + if err := ValidateHeight(d.Height); err != nil { + errors = append(errors, ValidationError{Field: "height", Message: err.Error()}) + } + if err := ValidateWeight(d.Weight); err != nil { + errors = append(errors, ValidationError{Field: "weight", Message: err.Error()}) + } + if err := ValidateActivityLevel(d.ActivityLevel); err != nil { + errors = append(errors, ValidationError{Field: "activityLevel", Message: err.Error()}) + } + if err := ValidateGoal(d.Goal); err != nil { + errors = append(errors, ValidationError{Field: "goal", Message: err.Error()}) + } + } + + if len(errors) > 0 { + return errors + } + return nil +} + func (ve ValidationErrors) Error() string { var messages []string for _, err := range ve { @@ -106,3 +140,63 @@ func ValidateString(value, fieldName string, minLen, maxLen int) error { } return nil } + +func ValidateGender(gender string) error { + if gender == "" { + return fmt.Errorf("gender is required") + } + for _, validGender := range models.Genders() { + if gender == validGender { + return nil + } + } + return fmt.Errorf("gender must be 'male' or 'female'") +} + +func ValidateAge(age int) error { + if age < 15 { + return fmt.Errorf("age must be at least 15 years") + } + if age > 120 { + return fmt.Errorf("age must be at most 120 years") + } + return nil +} + +func ValidateHeight(height float64) error { + if height < 100 { + return fmt.Errorf("height must be at least 100 cm") + } + if height > 250 { + return fmt.Errorf("height must be at most 250 cm") + } + return nil +} + +func ValidateWeight(weight float64) error { + if weight < 30 { + return fmt.Errorf("weight must be at least 30 kg") + } + if weight > 300 { + return fmt.Errorf("weight must be at most 300 kg") + } + return nil +} + +func ValidateActivityLevel(activityLevel string) error { + for _, level := range models.ActivityLevels() { + if activityLevel == level { + return nil + } + } + return fmt.Errorf("activity level must be one of: sedentary, lightly_active, moderately_active, very_active, extremely_active") +} + +func ValidateGoal(goal string) error { + for _, validGoal := range models.Goals() { + if goal == validGoal { + return nil + } + } + return fmt.Errorf("goal must be one of: lose_weight, maintain_weight, gain_weight") +} diff --git a/migrations/000004_calorie_calculations.down.sql b/migrations/000004_calorie_calculations.down.sql new file mode 100644 index 0000000..e62badf --- /dev/null +++ b/migrations/000004_calorie_calculations.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS calorie_calculations; diff --git a/migrations/000004_calorie_calculations.up.sql b/migrations/000004_calorie_calculations.up.sql new file mode 100644 index 0000000..de16207 --- /dev/null +++ b/migrations/000004_calorie_calculations.up.sql @@ -0,0 +1,27 @@ +CREATE TABLE calorie_calculations ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id UUID NOT NULL, + gender VARCHAR(10) NOT NULL CHECK (gender IN ('male', 'female')), + age INTEGER NOT NULL CHECK (age >= 15 AND age <= 120), + height DECIMAL(5,2) NOT NULL CHECK (height >= 100 AND height <= 250), + weight DECIMAL(5,2) NOT NULL CHECK (weight >= 30 AND weight <= 300), + activity_level VARCHAR(20) NOT NULL CHECK (activity_level IN ('sedentary', 'lightly_active', 'moderately_active', 'very_active', 'extremely_active')), + goal VARCHAR(20) NOT NULL CHECK (goal IN ('lose_weight', 'maintain_weight', 'gain_weight')), + bmr INTEGER NOT NULL, + tdee INTEGER NOT NULL, + target_calories INTEGER NOT NULL, + formula VARCHAR(10) NOT NULL DEFAULT 'mifflin', + protein_grams INTEGER NOT NULL, + protein_percentage DECIMAL(5,2) NOT NULL, + fat_grams INTEGER NOT NULL, + fat_percentage DECIMAL(5,2) NOT NULL, + carbs_grams INTEGER NOT NULL, + carbs_percentage DECIMAL(5,2) NOT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, + UNIQUE(user_id) +); + +CREATE INDEX idx_calorie_calculations_user_id ON calorie_calculations(user_id); +CREATE INDEX idx_calorie_calculations_created_at ON calorie_calculations(created_at); diff --git a/server b/server index 8365d64..4b9dfe8 100755 Binary files a/server and b/server differ