Files
full-stack-doc/backend/app/services/user_service.py
2025-10-14 20:05:29 +08:00

252 lines
9.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
from datetime import datetime, timedelta
from typing import Optional

from fastapi import HTTPException, status
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from app.core.config import settings
from app.core.security import get_password_hash, verify_password, create_access_token, create_refresh_token
from app.exceptions.auth import UsernameAlreadyExistsException
from app.models.user import User
from app.schemas.auth import UserRegister, UserLogin, UserResponse
class UserService:
    """User service.

    Handles registration, login, profile updates, password changes,
    deactivation, storage accounting, token issuance and response mapping
    for ``User`` rows. Every mutating method commits on the SQLAlchemy
    session given to the constructor and rolls back when the commit fails.
    """

    # update_user() never overwrites these attributes from caller kwargs.
    # Changing the primary key in place is never a legitimate profile update.
    # Underscore-prefixed names (e.g. SQLAlchemy's ``_sa_instance_state``)
    # are rejected separately.
    # NOTE(review): ``password_hash`` can still be set through update_user,
    # bypassing change_password's current-password check — consider adding it
    # here once no caller relies on that.
    _PROTECTED_UPDATE_FIELDS = frozenset({"id"})

    # Logger named after the containing module (``__name__`` resolves to the
    # module global inside a class body).
    _logger = logging.getLogger(__name__)

    def __init__(self, db: Session):
        """Bind the service to a database session.

        Args:
            db: SQLAlchemy session used for all queries and commits.
        """
        self.db = db

    def create_user(self, user_data: UserRegister) -> User:
        """Create and persist a new user.

        Args:
            user_data: Registration payload. ``username``, ``email`` and
                ``password`` are read from it.

        Returns:
            The committed and refreshed ``User``, created with
            ``is_active=True`` and ``is_verified=False``.

        Raises:
            UsernameAlreadyExistsException: An active user already has this
                username.
            HTTPException: 400 on a database integrity violation (code
                ``USERNAME_EXISTS``, ``EMAIL_EXISTS`` or ``INTEGRITY_ERROR``).
                500 (``CREATION_FAILED``) on any other failure, including
                password hashing errors.
        """
        try:
            # get_user_by_username only sees *active* users. A username held by
            # a deactivated account is reported via the IntegrityError branch
            # below instead (presumably via a unique constraint — confirm).
            self._logger.debug("Checking if username exists: %s", user_data.username)
            if self.get_user_by_username(user_data.username):
                self._logger.debug("Username already exists: %s", user_data.username)
                raise UsernameAlreadyExistsException()

            # Emails may be shared between accounts, so there is deliberately
            # no email uniqueness check here.
            db_user = User(
                username=user_data.username,
                email=user_data.email,
                password_hash=get_password_hash(user_data.password),
                is_active=True,
                is_verified=False,
            )
            self.db.add(db_user)
            self.db.commit()
            self.db.refresh(db_user)
            return db_user
        except IntegrityError as e:
            self.db.rollback()
            # Map the constraint that fired to an API error code by inspecting
            # the driver's error text.
            constraint_text = str(e.orig)
            if "username" in constraint_text:
                code, message = "USERNAME_EXISTS", "用户名已存在"
            elif "email" in constraint_text:
                code, message = "EMAIL_EXISTS", "邮箱已被注册"
            else:
                code, message = "INTEGRITY_ERROR", "数据完整性错误"
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail={"code": code, "message": message},
            ) from e
        except (HTTPException, UsernameAlreadyExistsException):
            # Re-raise unchanged so the specific error detail is not replaced
            # by the generic 500 below.
            self.db.rollback()
            raise
        except Exception as e:
            self.db.rollback()
            self._logger.exception("User creation failed for username %s", user_data.username)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail={"code": "CREATION_FAILED", "message": "用户创建失败"},
            ) from e

    def authenticate_user(self, login_data: UserLogin) -> Optional[User]:
        """Check login credentials and record the login time.

        ``login_data.username`` may hold either a username or an email
        address. Username matches take precedence. Only active users are
        considered.

        Args:
            login_data: Login payload with ``username`` and ``password``.

        Returns:
            The authenticated ``User`` with ``last_login_at`` updated and
            committed, or ``None`` when no user matches or the password is
            wrong.
        """
        user = self.get_user_by_username(login_data.username)
        if not user:
            # NOTE(review): emails are allowed to repeat, so this returns an
            # arbitrary one of several matching users (.first() without
            # ordering) — logging in by a shared email may check the wrong
            # account's password.
            user = self.get_user_by_email(login_data.username)

        if not user or not verify_password(login_data.password, user.password_hash):
            return None

        # NOTE(review): naive UTC timestamp; datetime.utcnow() is deprecated
        # since Python 3.12 — migrate together with the column's tz handling.
        user.last_login_at = datetime.utcnow()
        try:
            self.db.commit()
        except Exception:
            # Roll back so the session stays usable; the error propagates.
            self.db.rollback()
            raise
        return user

    def get_user_by_id(self, user_id: int) -> Optional[User]:
        """Return the active user with this primary key, or ``None``."""
        # ``== True`` (not ``is True``) is required: SQLAlchemy overloads ``==``
        # to build the SQL comparison.
        return self.db.query(User).filter(User.id == user_id, User.is_active == True).first()  # noqa: E712

    def get_user_by_username(self, username: str) -> Optional[User]:
        """Return the active user with this username, or ``None``."""
        return self.db.query(User).filter(User.username == username, User.is_active == True).first()  # noqa: E712

    def get_user_by_email(self, email: str) -> Optional[User]:
        """Return the first active user with this email, or ``None``.

        Emails are not unique, so when several users share an email the row
        returned is whichever the database yields first.
        """
        return self.db.query(User).filter(User.email == email, User.is_active == True).first()  # noqa: E712

    def update_user(self, user_id: int, **kwargs) -> Optional[User]:
        """Update attributes of an active user.

        Args:
            user_id: Primary key of the user to update.
            **kwargs: Attribute names and new values. Values that are
                ``None`` and names the model lacks are ignored, so ``None``
                cannot clear a field. Private names and protected fields
                (e.g. ``id``) are ignored with a warning.

        Returns:
            The refreshed ``User``, or ``None`` when no active user has
            ``user_id``.

        Raises:
            HTTPException: 500 (``UPDATE_FAILED``) when the commit fails.
        """
        user = self.get_user_by_id(user_id)
        if not user:
            return None
        try:
            for field, value in kwargs.items():
                if field.startswith("_") or field in self._PROTECTED_UPDATE_FIELDS:
                    self._logger.warning("update_user ignored protected field %r", field)
                    continue
                if value is not None and hasattr(user, field):
                    setattr(user, field, value)
            self.db.commit()
            self.db.refresh(user)
            return user
        except Exception as e:
            self.db.rollback()
            self._logger.exception("Updating user %s failed", user_id)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail={"code": "UPDATE_FAILED", "message": "用户信息更新失败"},
            ) from e

    def change_password(self, user_id: int, current_password: str, new_password: str) -> bool:
        """Change a user's password after verifying the current one.

        Args:
            user_id: Primary key of an active user.
            current_password: Plain-text current password to verify.
            new_password: Plain-text new password; it is hashed before storing.

        Returns:
            ``True`` on success.

        Raises:
            HTTPException: 404 (``USER_NOT_FOUND``) when no active user
                matches, 400 (``INVALID_PASSWORD``) when ``current_password``
                is wrong, 500 (``PASSWORD_CHANGE_FAILED``) when hashing or the
                commit fails.
        """
        user = self.get_user_by_id(user_id)
        if not user:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail={"code": "USER_NOT_FOUND", "message": "用户不存在"},
            )
        if not verify_password(current_password, user.password_hash):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail={"code": "INVALID_PASSWORD", "message": "当前密码不正确"},
            )
        try:
            user.password_hash = get_password_hash(new_password)
            self.db.commit()
            return True
        except Exception as e:
            self.db.rollback()
            self._logger.exception("Changing password for user %s failed", user_id)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail={"code": "PASSWORD_CHANGE_FAILED", "message": "密码修改失败"},
            ) from e

    def deactivate_user(self, user_id: int) -> bool:
        """Soft-delete a user by setting ``is_active = False``.

        After this, the ``get_user_by_*`` lookups no longer return the user.

        Returns:
            ``True`` when the user was deactivated, ``False`` when no active
            user has ``user_id``.

        Raises:
            HTTPException: 500 (``DEACTIVATION_FAILED``) when the commit fails.
        """
        user = self.get_user_by_id(user_id)
        if not user:
            return False
        try:
            user.is_active = False
            self.db.commit()
            return True
        except Exception as e:
            self.db.rollback()
            self._logger.exception("Deactivating user %s failed", user_id)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail={"code": "DEACTIVATION_FAILED", "message": "用户停用失败"},
            ) from e

    def update_storage_usage(self, user_id: int, size_change: int) -> bool:
        """Adjust a user's ``storage_used`` by ``size_change``.

        Growth (``size_change > 0``) is rejected when it would exceed
        ``storage_quota``. Shrinking is always allowed, even for a user who is
        currently over quota, so that they can free space. The stored value is
        clamped at 0.

        Args:
            user_id: Primary key of an active user.
            size_change: Delta to apply; negative when space is released.
                The unit is whatever ``storage_used`` is stored in
                (presumably bytes — confirm).

        Returns:
            ``True`` on success, ``False`` when no active user has ``user_id``.

        Raises:
            HTTPException: 400 (``STORAGE_EXCEEDED``) when growth would exceed
                the quota, 500 (``STORAGE_UPDATE_FAILED``) on any other
                failure.
        """
        user = self.get_user_by_id(user_id)
        if not user:
            return False
        try:
            new_usage = user.storage_used + size_change
            # Only enforce the quota when usage grows; otherwise an over-quota
            # user (e.g. after a quota reduction) could never delete anything.
            if size_change > 0 and new_usage > user.storage_quota:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail={"code": "STORAGE_EXCEEDED", "message": "存储空间不足"},
                )
            user.storage_used = max(0, new_usage)
            self.db.commit()
            return True
        except HTTPException:
            # Quota rejection happens before any mutation; nothing to roll back.
            raise
        except Exception as e:
            self.db.rollback()
            self._logger.exception("Updating storage usage for user %s failed", user_id)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail={"code": "STORAGE_UPDATE_FAILED", "message": "存储空间更新失败"},
            ) from e

    def create_user_tokens(self, user: User) -> dict:
        """Issue an access token and a refresh token for ``user``.

        The access token carries ``sub`` (user id as string), ``username`` and
        ``email`` and expires after ``settings.JWT_EXPIRE_MINUTES``. The
        refresh token carries only ``sub`` and expires after
        ``settings.JWT_REFRESH_EXPIRE_DAYS``.

        Returns:
            A dict with ``access_token``, ``refresh_token``,
            ``token_type="bearer"`` and ``expires_in`` (access token lifetime
            in seconds).
        """
        access_token = create_access_token(
            data={"sub": str(user.id), "username": user.username, "email": user.email},
            expires_delta=timedelta(minutes=settings.JWT_EXPIRE_MINUTES),
        )
        refresh_token = create_refresh_token(
            data={"sub": str(user.id)},
            expires_delta=timedelta(days=settings.JWT_REFRESH_EXPIRE_DAYS),
        )
        return {
            "access_token": access_token,
            "refresh_token": refresh_token,
            "token_type": "bearer",
            "expires_in": settings.JWT_EXPIRE_MINUTES * 60,
        }

    def to_user_response(self, user: User) -> UserResponse:
        """Map a ``User`` model to its public ``UserResponse`` schema.

        ``password_hash`` is intentionally not copied.
        """
        return UserResponse(
            id=user.id,
            username=user.username,
            email=user.email,
            avatar_url=user.avatar_url,
            storage_quota=user.storage_quota,
            storage_used=user.storage_used,
            is_active=user.is_active,
            is_verified=user.is_verified,
            last_login_at=user.last_login_at,
            created_at=user.created_at,
        )