Files
full-stack-doc/backend/app/dependencies/auth.py
2025-10-14 20:05:29 +08:00

217 lines
6.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastapi import Depends, HTTPException, status, Request
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.core.security import verify_token
from app.services.user_service import UserService
from app.schemas.auth import UserResponse
from app.models.user import User
# Bearer-token authentication scheme used by get_current_user.
# auto_error=True is HTTPBearer's default: FastAPI rejects a request that has no
# usable "Authorization: Bearer ..." header before the dependency body runs.
security = HTTPBearer(auto_error=True)
async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: Session = Depends(get_db)
) -> User:
    """Resolve the authenticated user from an ``Authorization: Bearer`` token.

    Args:
        credentials: Bearer credentials extracted by the ``security`` scheme.
        db: Database session used to load the user.

    Returns:
        The ``User`` whose id is the token's ``sub`` claim.

    Raises:
        HTTPException: 401 ``INVALID_AUTHENTICATION`` (with a
            ``WWW-Authenticate: Bearer`` challenge) when the token fails
            verification, has no ``sub`` claim, the claim is not an integer,
            or no user with that id exists.
    """
    # NOTE(review): UserService is called synchronously (no await) with a sync
    # SQLAlchemy Session inside an `async def`, so the lookup presumably blocks
    # the event loop. A plain `def` would let FastAPI run it in a threadpool,
    # but callers that `await` this function would then need updating.
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail={
            "code": "INVALID_AUTHENTICATION",
            "message": "无法验证凭据"
        },
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        # verify_token stays inside the try so that a ValueError/TypeError it
        # raises is also reported as a 401 rather than a server error.
        payload = verify_token(credentials.credentials, "access")
        if payload is None:
            raise credentials_exception
        sub = payload.get("sub")
        if sub is None:
            raise credentials_exception
        # The subject claim is expected to be the numeric user id.
        user_id = int(sub)
    except (ValueError, TypeError):
        # The parsing error is an implementation detail; drop it from the
        # exception chain and return the generic 401.
        raise credentials_exception from None

    user = UserService(db).get_user_by_id(user_id)
    if user is None:
        raise credentials_exception
    return user
async def get_current_user_from_headers(
    request: Request,
    db: Session = Depends(get_db)
) -> User:
    """Authenticate the current user from request headers.

    The access token is read from ``Authorization: Bearer <token>`` (scheme
    matched case-insensitively) or, as a fallback, from a raw ``token`` header.
    An optional ``userId`` header may also be sent. When it parses as an
    integer, it must equal the token's ``sub`` claim. Values that do not parse
    are ignored. The user id is always taken from the token, never from the
    header.

    Args:
        request: Incoming request whose headers are inspected.
        db: Database session used to load the user.

    Returns:
        The ``User`` identified by the token's ``sub`` claim.

    Raises:
        HTTPException: 401 ``INVALID_AUTHENTICATION`` (with a
            ``WWW-Authenticate: Bearer`` challenge) when no token is supplied,
            the token fails verification, its ``sub`` claim is missing or not
            an integer, the ``userId`` header disagrees with it, or no such
            user exists.
    """
    # NOTE(review): as in get_current_user, the UserService lookup is a sync
    # call inside an `async def` and presumably blocks the event loop.
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail={
            "code": "INVALID_AUTHENTICATION",
            "message": "无法验证凭据"
        },
        # A 401 must carry a WWW-Authenticate challenge (RFC 7235). This also
        # matches the exception raised by get_current_user.
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        # 1. "Authorization: Bearer <token>". The auth scheme is
        #    case-insensitive, and surrounding whitespace is tolerated.
        token = None
        authorization = request.headers.get("Authorization")
        if authorization:
            scheme, _, param = authorization.partition(" ")
            if scheme.lower() == "bearer":
                token = param.strip()
        # 2. Fall back to a bare "token" header.
        if not token:
            token = request.headers.get("token")

        # 3. Optional "userId" header, used only as a consistency check.
        #    Unparseable values are deliberately ignored rather than rejected.
        header_user_id = None
        user_id_str = request.headers.get("userId")
        if user_id_str:
            try:
                header_user_id = int(user_id_str)
            except ValueError:
                pass

        if not token:
            raise credentials_exception

        # verify_token stays inside the outer try so a ValueError/TypeError
        # it raises is reported as a 401, as before.
        payload = verify_token(token, "access")
        if payload is None:
            raise credentials_exception
        sub = payload.get("sub")
        if sub is None:
            raise credentials_exception
        token_user_id = int(sub)

        # A client-supplied userId that disagrees with the token is rejected.
        if header_user_id is not None and header_user_id != token_user_id:
            raise credentials_exception
    except (ValueError, TypeError):
        raise credentials_exception from None

    user = UserService(db).get_user_by_id(token_user_id)
    if user is None:
        raise credentials_exception
    return user
async def get_current_active_user(
    current_user: User = Depends(get_current_user)
) -> User:
    """Return the Bearer-authenticated user, rejecting disabled accounts.

    Args:
        current_user: User resolved by ``get_current_user``.

    Returns:
        ``current_user`` unchanged when ``is_active`` is truthy.

    Raises:
        HTTPException: 400 ``INACTIVE_USER`` when the account is disabled.
    """
    if current_user.is_active:
        return current_user
    inactive_detail = {
        "code": "INACTIVE_USER",
        "message": "用户账户已被禁用",
    }
    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail=inactive_detail,
    )
async def get_current_active_user_from_headers(
    current_user: User = Depends(get_current_user_from_headers)
) -> User:
    """Return the header-authenticated user, rejecting disabled accounts.

    Args:
        current_user: User resolved by ``get_current_user_from_headers``.

    Returns:
        ``current_user`` unchanged when ``is_active`` is truthy.

    Raises:
        HTTPException: 400 ``INACTIVE_USER`` when the account is disabled.
    """
    if current_user.is_active:
        return current_user
    inactive_detail = {
        "code": "INACTIVE_USER",
        "message": "用户账户已被禁用",
    }
    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail=inactive_detail,
    )
async def get_current_verified_user(
    current_user: User = Depends(get_current_active_user)
) -> User:
    """Return the current active user, rejecting accounts that are not verified.

    Args:
        current_user: Active user resolved by ``get_current_active_user``.

    Returns:
        ``current_user`` unchanged when ``is_verified`` is truthy.

    Raises:
        HTTPException: 400 ``UNVERIFIED_USER`` when the account is unverified.
    """
    if current_user.is_verified:
        return current_user
    unverified_detail = {
        "code": "UNVERIFIED_USER",
        "message": "用户账户未验证",
    }
    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail=unverified_detail,
    )
async def get_current_user_response(
    current_user: User = Depends(get_current_active_user),
    db: Session = Depends(get_db)
) -> UserResponse:
    """Return the Bearer-authenticated active user serialized as ``UserResponse``.

    Args:
        current_user: Active user resolved by ``get_current_active_user``.
        db: Database session handed to ``UserService``.
    """
    return UserService(db).to_user_response(current_user)
async def get_current_user_response_from_headers(
    current_user: User = Depends(get_current_active_user_from_headers),
    db: Session = Depends(get_db)
) -> UserResponse:
    """Return the header-authenticated active user serialized as ``UserResponse``.

    Args:
        current_user: Active user resolved by
            ``get_current_active_user_from_headers``.
        db: Database session handed to ``UserService``.
    """
    return UserService(db).to_user_response(current_user)
# Optional authentication dependencies (authentication is not enforced).
# This needs its own scheme with auto_error=False. The shared `security`
# scheme uses auto_error=True, which makes FastAPI reject a request with no
# Bearer header before this dependency runs, so "optional" would be mandatory.
_optional_security = HTTPBearer(auto_error=False)


async def get_optional_current_user(
    credentials: HTTPAuthorizationCredentials | None = Depends(_optional_security),
    db: Session = Depends(get_db)
) -> User | None:
    """Return the authenticated user, or ``None`` for anonymous/invalid requests.

    Never raises for authentication problems. A missing Bearer header yields
    ``None``, and any ``HTTPException`` from ``get_current_user`` (invalid
    token, unknown user) is also turned into ``None``.

    Args:
        credentials: Bearer credentials, or ``None`` when the request carries
            no usable ``Authorization: Bearer`` header.
        db: Database session passed through to ``get_current_user``.

    Returns:
        The authenticated ``User``, or ``None``.
    """
    if credentials is None:
        return None
    try:
        return await get_current_user(credentials, db)
    except HTTPException:
        return None
# Permission-check helpers
def check_user_permission(required_permission: str | None = None):
    """Build a FastAPI dependency that checks the current user's permission.

    The result is a dependency callable for use with ``Depends(...)``, not a
    decorator.

    Args:
        required_permission: Permission name to require. Currently ignored,
            because no permission lookup is implemented yet.

    Returns:
        An async dependency that resolves to the current active ``User``.
    """
    async def permission_checker(
        current_user: User = Depends(get_current_active_user)
    ) -> User:
        # TODO: enforce `required_permission`; for now only activity is checked.
        # get_current_active_user already rejects inactive users with a 400, so
        # this 403 branch is only reached if that dependency is replaced
        # (e.g. via dependency overrides). It is kept as a defensive check.
        if not current_user.is_active:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail={
                    "code": "PERMISSION_DENIED",
                    "message": "权限不足"
                }
            )
        return current_user

    return permission_checker
# Admin permission check (reserved for future use)
async def get_admin_user(
    current_user: User = Depends(get_current_active_user)
) -> User:
    """Return the current user for admin-only endpoints (placeholder).

    WARNING: no admin/role check is implemented. Any active user passes this
    dependency, so it does not currently protect admin-only functionality.

    Args:
        current_user: Active user resolved by ``get_current_active_user``.

    Returns:
        ``current_user`` unchanged.
    """
    # TODO: reject users without an admin role (e.g. raise a 403).
    admin_user = current_user
    return admin_user