from fastapi import Depends, HTTPException, status, Request
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy.orm import Session

from app.core.database import get_db
from app.core.security import verify_token
from app.services.user_service import UserService
from app.schemas.auth import UserResponse
from app.models.user import User

# Bearer token authentication scheme. With the default ``auto_error=True``
# FastAPI rejects a request that carries no (or non-Bearer) Authorization
# credentials before any dependency below that uses it is executed.
security = HTTPBearer()

# Same scheme, but it yields ``None`` instead of rejecting the request when
# bearer credentials are absent. Required for genuinely optional auth.
optional_security = HTTPBearer(auto_error=False)


def _credentials_exception() -> HTTPException:
    """Build the 401 error raised whenever credentials cannot be validated."""
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail={
            "code": "INVALID_AUTHENTICATION",
            "message": "无法验证凭据"
        },
        # RFC 7235 requires a WWW-Authenticate challenge on every 401 response.
        headers={"WWW-Authenticate": "Bearer"},
    )


def _user_id_from_token(token: str) -> int:
    """Validate an access token and return the integer user id from its ``sub`` claim.

    Raises:
        HTTPException: 401 if ``verify_token`` returns ``None``, the payload has
            no ``sub`` claim, or the claim cannot be converted to ``int``.
    """
    try:
        payload = verify_token(token, "access")
        if payload is None:
            raise _credentials_exception()
        subject = payload.get("sub")
        if subject is None:
            raise _credentials_exception()
        return int(subject)
    except (ValueError, TypeError):
        # Malformed token contents are reported as a plain auth failure.
        raise _credentials_exception() from None


def _load_user(db: Session, user_id: int) -> User:
    """Fetch the user with ``user_id``; raise the 401 error if it does not exist."""
    user = UserService(db).get_user_by_id(user_id)
    if user is None:
        raise _credentials_exception()
    return user


def _ensure_active(user: User) -> User:
    """Return ``user`` unchanged, or raise 400 ``INACTIVE_USER`` if it is disabled."""
    if not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={
                "code": "INACTIVE_USER",
                "message": "用户账户已被禁用"
            }
        )
    return user


def _extract_header_token(request: Request) -> str | None:
    """Return the access token from the request headers, or ``None`` if absent.

    Lookup order:
      1. ``Authorization: Bearer <token>`` (scheme matched case-insensitively);
      2. a raw ``token`` header.
    """
    authorization = request.headers.get("Authorization")
    if authorization:
        scheme, _, param = authorization.partition(" ")
        bearer_token = param.strip()
        if scheme.lower() == "bearer" and bearer_token:
            return bearer_token
    # Fall back to the custom ``token`` header; treat an empty value as missing.
    return request.headers.get("token") or None


async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: Session = Depends(get_db)
) -> User:
    """Return the user identified by the ``Authorization: Bearer`` access token.

    Raises:
        HTTPException: 401 if the token is invalid or the user does not exist.
    """
    user_id = _user_id_from_token(credentials.credentials)
    return _load_user(db, user_id)


async def get_current_user_from_headers(
    request: Request,
    db: Session = Depends(get_db)
) -> User:
    """Return the authenticated user, reading the token from request headers.

    The token is taken from ``Authorization: Bearer`` or, failing that, from a
    ``token`` header. If a ``userId`` header is also supplied and parses as an
    integer, it must match the token's ``sub`` claim. The token is always the
    authoritative source of the user id.

    Raises:
        HTTPException: 401 if no token is present, the token is invalid, the
            ``userId`` header disagrees with the token, or the user is missing.
    """
    token = _extract_header_token(request)
    if not token:
        raise _credentials_exception()

    claimed_user_id = None
    raw_user_id = request.headers.get("userId")
    if raw_user_id:
        try:
            claimed_user_id = int(raw_user_id)
        except ValueError:
            # NOTE(review): a non-integer userId header is ignored rather than
            # rejected (existing behaviour); the token still decides the user.
            claimed_user_id = None

    token_user_id = _user_id_from_token(token)
    if claimed_user_id is not None and claimed_user_id != token_user_id:
        raise _credentials_exception()

    return _load_user(db, token_user_id)


async def get_current_active_user(
    current_user: User = Depends(get_current_user)
) -> User:
    """Return the bearer-authenticated user, rejecting disabled accounts (400)."""
    return _ensure_active(current_user)


async def get_current_active_user_from_headers(
    current_user: User = Depends(get_current_user_from_headers)
) -> User:
    """Return the header-authenticated user, rejecting disabled accounts (400)."""
    return _ensure_active(current_user)


async def get_current_verified_user(
    current_user: User = Depends(get_current_active_user)
) -> User:
    """Return the active user, rejecting accounts whose ``is_verified`` is false (400)."""
    if not current_user.is_verified:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={
                "code": "UNVERIFIED_USER",
                "message": "用户账户未验证"
            }
        )
    return current_user


async def get_current_user_response(
    current_user: User = Depends(get_current_active_user),
    db: Session = Depends(get_db)
) -> UserResponse:
    """Return the bearer-authenticated active user serialized as ``UserResponse``."""
    user_service = UserService(db)
    return user_service.to_user_response(current_user)


async def get_current_user_response_from_headers(
    current_user: User = Depends(get_current_active_user_from_headers),
    db: Session = Depends(get_db)
) -> UserResponse:
    """Return the header-authenticated active user serialized as ``UserResponse``."""
    user_service = UserService(db)
    return user_service.to_user_response(current_user)


# Optional authentication dependencies (authentication is not enforced).
async def get_optional_current_user(
    credentials: HTTPAuthorizationCredentials | None = Depends(optional_security),
    db: Session = Depends(get_db)
) -> User | None:
    """Return the current user, or ``None`` when not (validly) authenticated.

    Uses ``optional_security`` (``auto_error=False``) so that requests without
    bearer credentials reach this function instead of being rejected upfront.
    """
    if credentials is None:
        return None
    try:
        return await get_current_user(credentials, db)
    except HTTPException:
        return None


# Permission checks
def check_user_permission(required_permission: str | None = None):
    """Factory returning a dependency that checks the current user's permissions.

    Args:
        required_permission: Reserved for future fine-grained checks; it is
            currently not consulted.

    Returns:
        An async dependency that resolves to the active current user.
    """
    async def permission_checker(
        current_user: User = Depends(get_current_active_user)
    ) -> User:
        # TODO: implement real permission logic using ``required_permission``.
        # For now only activity is checked. This branch is defensive: the
        # ``get_current_active_user`` dependency already rejects inactive users.
        if not current_user.is_active:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail={
                    "code": "PERMISSION_DENIED",
                    "message": "权限不足"
                }
            )
        return current_user

    return permission_checker


# Administrator check (placeholder)
async def get_admin_user(
    current_user: User = Depends(get_current_active_user)
) -> User:
    """Return the current active user as an administrator (placeholder).

    TODO: no admin-role check is performed yet; any active user is accepted.
    """
    return current_user