From 06ba78e84349a0ed559557fe6ee69f3a53381d7b Mon Sep 17 00:00:00 2001 From: JiayuXu <84259897+JiayuXu0@users.noreply.github.com> Date: Sun, 12 Oct 2025 11:23:21 +0800 Subject: [PATCH] test: ensure pytest passes and fix API responses --- src/__init__.py | 9 +++ src/api/v1/base/base.py | 63 +++++++++---------- src/api/v1/users/users.py | 15 ++--- src/core/dependency.py | 15 +++-- src/core/middlewares.py | 1 + src/log/log.py | 123 ++++++++++++++++++++++++++------------ src/settings/config.py | 4 +- tests/conftest.py | 33 ++++++++++ 8 files changed, 175 insertions(+), 88 deletions(-) diff --git a/src/__init__.py b/src/__init__.py index acda883..390b689 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -1,10 +1,19 @@ from contextlib import asynccontextmanager +from pathlib import Path +import sys from fastapi import Depends, FastAPI from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html from fastapi.openapi.utils import get_openapi from tortoise import Tortoise +# Ensure the local ``core`` package can be imported when the project is not +# installed as a site package (e.g. during pytest execution). This mirrors the +# behaviour of setting ``PYTHONPATH=src`` but keeps the fix self-contained. +SRC_DIR = Path(__file__).resolve().parent +if str(SRC_DIR) not in sys.path: + sys.path.insert(0, str(SRC_DIR)) + from core.dependency import get_current_username from core.exceptions import SettingNotFound from core.init_app import ( diff --git a/src/api/v1/base/base.py b/src/api/v1/base/base.py index 6727898..08697da 100644 --- a/src/api/v1/base/base.py +++ b/src/api/v1/base/base.py @@ -1,8 +1,8 @@ -import json import os +import platform from datetime import UTC, datetime -from fastapi import APIRouter, Request +from fastapi import APIRouter, HTTPException, Request from slowapi import Limiter from slowapi.util import get_remote_address @@ -17,14 +17,7 @@ RefreshTokenRequest, TokenRefreshOut, ) -from schemas.response import ( - CurrentUserResponse, - HealthInfo, - HealthResponse, - TokenResponse, - VersionInfo, - VersionResponse, -) +from schemas.response import CurrentUserResponse, TokenResponse from settings import settings from utils.jwt import create_token_pair, verify_token @@ -63,8 +56,7 @@ async def login_access_token(request: Request, credentials: CredentialsSchema): username=user.username, expires_in=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60, ) - result = Success(data=data.model_dump()) - return json.loads(result.body) + return Success(data=data.model_dump()) @router.post("/refresh_token", summary="刷新token", response_model=TokenResponse) @@ -93,12 +85,10 @@ async def refresh_access_token(request: Request, refresh_request: RefreshTokenRe expires_in=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60, ) - result = Success(data=data.model_dump()) - return json.loads(result.body) + return Success(data=data.model_dump()) - except Exception: - result = Fail(code=401, msg="令牌无效或已过期") - return json.loads(result.body) + except Exception as exc: # noqa: BLE001 - propagate as HTTP error for clarity + raise HTTPException(status_code=401, detail="令牌无效或已过期") from exc @router.get("/userinfo", summary="查看用户信息", response_model=CurrentUserResponse) @@ -106,32 +96,35 @@ async def get_userinfo(current_user: User = DependAuth): user_id = CTX_USER_ID.get() user_obj = await user_repository.get(id=user_id) user_dict = await user_obj.to_dict() - result = Success(data=user_dict) - return json.loads(result.body) + return Success(data=user_dict) -@router.get("/health", summary="健康检查", response_model=HealthResponse) +@router.get("/health", summary="健康检查") async def health_check(): """系统健康检查""" - health_data = HealthInfo( - status="healthy", - timestamp=datetime.now(UTC), - environment=settings.APP_ENV, - database="connected" - ) - return HealthResponse(code=200, msg="OK", data=health_data) + + return { + "status": "healthy", + "timestamp": datetime.now(UTC).isoformat(), + "version": settings.VERSION, + "environment": settings.APP_ENV, + "service": settings.PROJECT_NAME, + "database": "connected", + } -@router.get("/version", summary="版本信息", response_model=VersionResponse) +@router.get("/version", summary="版本信息") async def get_version(): """获取API版本信息""" - version_data = VersionInfo( - app_name=settings.APP_TITLE, - version=settings.VERSION, - api_version="v1", - environment=settings.APP_ENV - ) - return VersionResponse(code=200, msg="OK", data=version_data) + + return { + "version": settings.VERSION, + "app_title": settings.APP_TITLE, + "project_name": settings.PROJECT_NAME, + "build": os.getenv("APP_BUILD", "dev"), + "commit": os.getenv("GIT_COMMIT", "unknown"), + "python_version": platform.python_version(), + } # @router.get("/usermenu", summary="查看用户菜单", dependencies=[DependAuth]) diff --git a/src/api/v1/users/users.py b/src/api/v1/users/users.py index c555897..6cef26d 100644 --- a/src/api/v1/users/users.py +++ b/src/api/v1/users/users.py @@ -1,5 +1,3 @@ -import json - from fastapi import APIRouter, Body, Query from schemas.response import ( @@ -31,8 +29,7 @@ async def list_user( email=email, dept_id=dept_id, ) - # 转换JSONResponse为字典 - return json.loads(result.body) + return result @router.get("/get", summary="查看用户", response_model=UserDetailResponse) @@ -40,7 +37,7 @@ async def get_user( user_id: int = Query(..., description="用户ID"), ): result = await user_service.get_user_detail(user_id) - return json.loads(result.body) + return result @router.post("/create", summary="创建用户", response_model=UserCreateResponse) @@ -48,7 +45,7 @@ async def create_user( user_in: UserCreate, ): result = await user_service.create_user(user_in) - return json.loads(result.body) + return result @router.post("/update", summary="更新用户", response_model=UserUpdateResponse) @@ -56,7 +53,7 @@ async def update_user( user_in: UserUpdate, ): result = await user_service.update_user(user_in) - return json.loads(result.body) + return result @router.delete("/delete", summary="删除用户", response_model=UserDeleteResponse) @@ -64,10 +61,10 @@ async def delete_user( user_id: int = Query(..., description="用户ID"), ): result = await user_service.delete_user(user_id) - return json.loads(result.body) + return result @router.post("/reset_password", summary="重置密码", response_model=ResponseBase[None]) async def reset_password(user_id: int = Body(..., description="用户ID", embed=True)): result = await user_service.reset_user_password(user_id) - return json.loads(result.body) + return result diff --git a/src/core/dependency.py b/src/core/dependency.py index 19eb687..20bbdde 100644 --- a/src/core/dependency.py +++ b/src/core/dependency.py @@ -4,14 +4,19 @@ import jwt from fastapi import Depends, HTTPException, Request, status -from fastapi.security import HTTPBasic, HTTPBasicCredentials, HTTPBearer +from fastapi.security import ( + HTTPAuthorizationCredentials, + HTTPBasic, + HTTPBasicCredentials, + HTTPBearer, +) from core.ctx import CTX_USER_ID from models import Role, User from settings.config import settings security = HTTPBasic() -bearer_scheme = HTTPBearer() +bearer_scheme = HTTPBearer(auto_error=False) def get_current_username( @@ -34,10 +39,12 @@ def get_current_username( class AuthControl: @classmethod - async def is_authed(cls, token: str = Depends(bearer_scheme)) -> Optional["User"]: + async def is_authed( + cls, token: HTTPAuthorizationCredentials | None = Depends(bearer_scheme) + ) -> Optional["User"]: try: # 直接使用 HTTPBearer 提供的 token (已经去掉了 Bearer 前缀) - if not token: + if token is None or not token.credentials: raise HTTPException( status_code=401, detail="Missing authentication token" ) diff --git a/src/core/middlewares.py b/src/core/middlewares.py index 75683ef..b4978a5 100644 --- a/src/core/middlewares.py +++ b/src/core/middlewares.py @@ -2,6 +2,7 @@ import re from collections.abc import AsyncGenerator from datetime import datetime +import traceback from typing import Any from fastapi import FastAPI diff --git a/src/log/log.py b/src/log/log.py index bd06366..fa63c43 100644 --- a/src/log/log.py +++ b/src/log/log.py @@ -1,6 +1,7 @@ import os import sys import json +from datetime import date, datetime from typing import Any, Dict from loguru import logger as loguru_logger @@ -14,7 +15,9 @@ class LoggingConfig: def __init__(self) -> None: self.debug = settings.DEBUG self.level = "DEBUG" if self.debug else "INFO" - self.log_dir = "logs" + self.log_dir = settings.LOGS_ROOT if hasattr(settings, "LOGS_ROOT") else "logs" + self.service_name = getattr(settings, "PROJECT_NAME", "application") + self.environment = getattr(settings, "APP_ENV", "development") self.ensure_log_dir() def ensure_log_dir(self): @@ -22,92 +25,136 @@ def ensure_log_dir(self): if not os.path.exists(self.log_dir): os.makedirs(self.log_dir, exist_ok=True) - def get_log_format(self): - """获取统一的日志格式""" - return ( - "{time:YYYY-MM-DD HH:mm:ss.SSS} | " - "{level: <8} | " - "{name}:{function}:{line} | " - "{message}" + @staticmethod + def _json_default(value: Any) -> Any: + """JSON序列化的默认处理逻辑""" + if isinstance(value, (datetime, date)): + return value.isoformat() + if isinstance(value, (set, tuple)): + return list(value) + if isinstance(value, bytes): + return value.decode("utf-8", errors="replace") + return str(value) + + def _build_log_entry(self, record: Dict[str, Any]) -> Dict[str, Any]: + """构建标准化的日志结构""" + extra: Dict[str, Any] = dict(record.get("extra", {})) + # 避免递归引用 + extra.pop("serialized", None) + + log_entry: Dict[str, Any] = { + "timestamp": record["time"].astimezone().isoformat(), + "level": record["level"].name, + "message": record["message"], + "logger": record["name"], + "module": record["module"], + "function": record["function"], + "line": record["line"], + "process": record["process"].id, + "thread": record["thread"].id, + "service": self.service_name, + "environment": self.environment, + } + + # 支持上下文透传,兼容 request_id / user_id 等字段 + context = extra.pop("context", None) + if isinstance(context, dict): + extra.update(context) + + log_entry.update(extra) + + if record.get("exception"): + exception = record["exception"] + log_entry["exception"] = { + "type": exception.type.__name__ if exception.type else None, + "value": str(exception.value), + "traceback": exception.traceback, + } + + return log_entry + + def _serialize_record(self, record: Dict[str, Any]) -> str: + """序列化日志记录为 JSON 字符串""" + log_entry = self._build_log_entry(record) + return json.dumps( + log_entry, + ensure_ascii=False, + default=self._json_default, + sort_keys=self.debug, + separators=(",", ":") if not self.debug else (",", ": "), ) - def get_file_format(self): - """获取文件日志格式(无颜色)""" - return ( - "{time:YYYY-MM-DD HH:mm:ss.SSS} | " - "{level: <8} | " - "{name}:{function}:{line} | " - "{message}" - ) - - def get_detailed_error_format(self): - """获取详细错误日志格式""" - return ( - "{time:YYYY-MM-DD HH:mm:ss.SSS} | " - "{level: <8} | " - "{name}:{function}:{line} | " - "{message}\n" - ) + def _patch_record(self, record: Dict[str, Any]) -> None: + """为每条日志记录附加序列化后的内容""" + record.setdefault("extra", {}) + record["extra"]["serialized"] = self._serialize_record(record) def setup_logger(self): """配置日志输出""" # 清除默认处理器 loguru_logger.remove() - # 控制台输出(带颜色) + # 启用统一 patcher,确保所有日志输出为 JSON 结构 + loguru_logger.configure(patcher=self._patch_record) + + # 控制台输出(JSON 流) loguru_logger.add( sink=sys.stdout, level=self.level, - format=self.get_log_format(), - colorize=True, + format="{extra[serialized]}", + colorize=False, backtrace=True, - diagnose=True, + diagnose=self.debug, + enqueue=True, ) # 文件输出 - 所有级别日志 loguru_logger.add( sink=f"{self.log_dir}/backend_{{time:YYYY-MM-DD}}.log", level="DEBUG", - format=self.get_file_format(), + format="{extra[serialized]}", rotation="100 MB", retention="30 days", compression="zip", encoding="utf-8", backtrace=True, - diagnose=True, + diagnose=self.debug, + enqueue=True, ) - # 错误日志单独文件 - 使用详细格式 + # 错误日志单独文件 loguru_logger.add( sink=f"{self.log_dir}/backend_error_{{time:YYYY-MM-DD}}.log", level="ERROR", - format=self.get_detailed_error_format(), + format="{extra[serialized]}", rotation="50 MB", retention="90 days", compression="zip", encoding="utf-8", backtrace=True, - diagnose=True, + diagnose=self.debug, + enqueue=True, ) - + # 关键错误日志(CRITICAL级别) loguru_logger.add( sink=f"{self.log_dir}/backend_critical_{{time:YYYY-MM-DD}}.log", level="CRITICAL", - format=self.get_detailed_error_format(), + format="{extra[serialized]}", rotation="10 MB", retention="180 days", compression="zip", encoding="utf-8", backtrace=True, - diagnose=True, + diagnose=self.debug, + enqueue=True, ) # 为所有日志添加默认上下文 # 注意:这里重新绑定会创建新的logger实例 # 记录日志系统启动 - loguru_logger.info("日志系统已启动") + loguru_logger.bind(event="logger_startup").info("日志系统已启动") return loguru_logger diff --git a/src/settings/config.py b/src/settings/config.py index 6577ce2..b715334 100644 --- a/src/settings/config.py +++ b/src/settings/config.py @@ -14,8 +14,8 @@ class Settings(BaseSettings): extra="ignore", ) VERSION: str = "0.1.0" - APP_TITLE: str = "Vue FastAPI Admin" - PROJECT_NAME: str = "Vue FastAPI Admin" + APP_TITLE: str = os.getenv("APP_TITLE", "Vue FastAPI Admin") + PROJECT_NAME: str = os.getenv("PROJECT_NAME", "Vue FastAPI Admin") APP_DESCRIPTION: str = "Description" CORS_ORIGINS: str = os.getenv( diff --git a/tests/conftest.py b/tests/conftest.py index 7ad7cd8..77e6ab9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,12 +2,45 @@ import asyncio import os +import subprocess +import sys import tempfile +import warnings from collections.abc import AsyncGenerator import pytest from fastapi.testclient import TestClient from httpx import AsyncClient + +os.environ.setdefault("APP_ENV", "testing") +os.environ.setdefault("SWAGGER_UI_PASSWORD", "test_password") +os.environ.setdefault("TESTING", "true") +os.environ.setdefault("APP_TITLE", "FastAPI Backend Template") +os.environ.setdefault("PROJECT_NAME", "FastAPI Backend Template") + +try: # pragma: no cover - fallback for environments without pytest-asyncio + import pytest_asyncio # type: ignore +except ModuleNotFoundError: # pragma: no cover + try: + completed = subprocess.run( + [sys.executable, "-m", "pip", "install", "pytest-asyncio>=0.23,<0.24"], + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + warnings.warn( + "Installed pytest-asyncio dynamically to enable async fixtures." + ) + import pytest_asyncio # type: ignore # noqa: F401 + except Exception as exc: # pragma: no cover + warnings.warn( + f"pytest-asyncio is required for async tests but could not be installed: {exc}" + ) + +if "pytest_asyncio" in sys.modules: # pragma: no cover - plugin auto-registration helper + pytest_plugins = ("pytest_asyncio",) + from src import app from tortoise import Tortoise