import logging
from sqlalchemy.orm import Session
from sqlalchemy.exc import IntegrityError
from fastapi import HTTPException, status
from typing import Optional
from datetime import datetime, timedelta
from app.models.user import User
from app.core.security import get_password_hash, verify_password, create_access_token, create_refresh_token
from app.schemas.auth import UserRegister, UserLogin, UserResponse
from app.core.config import settings
from app.exceptions.auth import UsernameAlreadyExistsException

logger = logging.getLogger(__name__)


class UserService:
    """User service: persistence, authentication and token issuance for users.

    Every method works on the SQLAlchemy session given to the constructor and
    commits (or rolls back) its own changes.
    """

    def __init__(self, db: Session):
        self.db = db

    def create_user(self, user_data: UserRegister) -> User:
        """Create and persist a new user.

        Usernames must be unique among active users; email addresses may be
        shared by several accounts.

        Returns:
            The newly created, refreshed ``User``.

        Raises:
            UsernameAlreadyExistsException: an active user already has this username.
            HTTPException: 400 when the database reports an integrity error,
                500 for any other failure (including password hashing).
        """
        try:
            if self.get_user_by_username(user_data.username):
                logger.debug("Registration rejected: username already exists")
                raise UsernameAlreadyExistsException()

            # Emails are allowed to be duplicated, so there is deliberately no
            # email uniqueness check here.
            hashed_password = get_password_hash(user_data.password)

            db_user = User(
                username=user_data.username,
                email=user_data.email,
                password_hash=hashed_password,
                is_active=True,
                is_verified=False,
            )
            self.db.add(db_user)
            self.db.commit()
            self.db.refresh(db_user)
            return db_user
        except IntegrityError as e:
            self.db.rollback()
            # The lookup above only sees *active* users, so a clash with an
            # inactive account can still surface here (presumably via a DB
            # unique constraint on username).
            message = str(e.orig)
            if "username" in message:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail={"code": "USERNAME_EXISTS", "message": "用户名已存在"},
                ) from e
            if "email" in message:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail={"code": "EMAIL_EXISTS", "message": "邮箱已被注册"},
                ) from e
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail={"code": "INTEGRITY_ERROR", "message": "数据完整性错误"},
            ) from e
        except (HTTPException, UsernameAlreadyExistsException):
            # Re-raise domain/HTTP errors unchanged so their custom detail survives.
            self.db.rollback()
            raise
        except Exception as e:
            self.db.rollback()
            logger.exception("User creation failed")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail={"code": "CREATION_FAILED", "message": "用户创建失败"},
            ) from e

    def authenticate_user(self, login_data: UserLogin) -> Optional[User]:
        """Verify a login attempt, matching by username first, then by email.

        On success the user's ``last_login_at`` is set and committed.

        Returns:
            The authenticated ``User``, or ``None`` when no active user matches
            or the password is wrong.
        """
        user = self.get_user_by_username(login_data.username)
        if not user:
            # NOTE(review): emails may be duplicated, so this checks the password
            # only against whichever active user ``.first()`` returns; other
            # accounts sharing that email cannot log in by email. Confirm intent.
            user = self.get_user_by_email(login_data.username)

        if not user or not verify_password(login_data.password, user.password_hash):
            return None

        # NOTE(review): naive UTC timestamp; presumably the column is naive too.
        user.last_login_at = datetime.utcnow()
        try:
            self.db.commit()
        except Exception:
            # Leave the session usable for the caller, then propagate unchanged.
            self.db.rollback()
            raise
        return user

    def get_user_by_id(self, user_id: int) -> Optional[User]:
        """Return the active user with this id, or ``None``."""
        # ``== True`` (not ``is True``) is required to build a SQL expression.
        return self.db.query(User).filter(User.id == user_id, User.is_active == True).first()  # noqa: E712

    def get_user_by_username(self, username: str) -> Optional[User]:
        """Return the active user with this username, or ``None``."""
        return self.db.query(User).filter(User.username == username, User.is_active == True).first()  # noqa: E712

    def get_user_by_email(self, email: str) -> Optional[User]:
        """Return the first active user with this email, or ``None``."""
        return self.db.query(User).filter(User.email == email, User.is_active == True).first()  # noqa: E712

    def update_user(self, user_id: int, **kwargs) -> Optional[User]:
        """Set the given attributes on an active user and commit.

        Keyword arguments whose value is ``None`` or whose name is not an
        attribute of ``User`` are ignored.

        Returns:
            The refreshed ``User``, or ``None`` if no active user has ``user_id``.

        Raises:
            HTTPException: 500 if the update cannot be committed.
        """
        # NOTE(review): any existing attribute (e.g. ``password_hash``,
        # ``is_active``) can be set here; callers must whitelist user input.
        user = self.get_user_by_id(user_id)
        if not user:
            return None

        try:
            for key, value in kwargs.items():
                if hasattr(user, key) and value is not None:
                    setattr(user, key, value)
            self.db.commit()
            self.db.refresh(user)
            return user
        except Exception as e:
            self.db.rollback()
            logger.exception("User update failed")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail={"code": "UPDATE_FAILED", "message": "用户信息更新失败"},
            ) from e

    def change_password(self, user_id: int, current_password: str, new_password: str) -> bool:
        """Replace a user's password after verifying the current one.

        Returns:
            ``True`` on success.

        Raises:
            HTTPException: 404 if the user is not found/active, 400 if
                ``current_password`` is wrong, 500 if saving fails.
        """
        user = self.get_user_by_id(user_id)
        if not user:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail={"code": "USER_NOT_FOUND", "message": "用户不存在"},
            )

        if not verify_password(current_password, user.password_hash):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail={"code": "INVALID_PASSWORD", "message": "当前密码不正确"},
            )

        try:
            user.password_hash = get_password_hash(new_password)
            self.db.commit()
            return True
        except Exception as e:
            self.db.rollback()
            logger.exception("Password change failed")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail={"code": "PASSWORD_CHANGE_FAILED", "message": "密码修改失败"},
            ) from e

    def deactivate_user(self, user_id: int) -> bool:
        """Mark an active user as inactive.

        Returns:
            ``True`` on success, ``False`` if no active user has ``user_id``.

        Raises:
            HTTPException: 500 if the change cannot be committed.
        """
        user = self.get_user_by_id(user_id)
        if not user:
            return False

        try:
            user.is_active = False
            self.db.commit()
            return True
        except Exception as e:
            self.db.rollback()
            logger.exception("User deactivation failed")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail={"code": "DEACTIVATION_FAILED", "message": "用户停用失败"},
            ) from e

    def update_storage_usage(self, user_id: int, size_change: int) -> bool:
        """Adjust a user's ``storage_used`` by ``size_change`` (may be negative).

        Only increases are checked against ``storage_quota``, so a user who is
        already over quota can still free space. The result is clamped at 0.

        Returns:
            ``True`` on success, ``False`` if no active user has ``user_id``.

        Raises:
            HTTPException: 400 if an increase would exceed the quota,
                500 if the change cannot be committed.
        """
        user = self.get_user_by_id(user_id)
        if not user:
            return False

        try:
            new_usage = user.storage_used + size_change
            # Rejecting decreases here would lock over-quota users out of
            # ever releasing space, so enforce the quota on growth only.
            if size_change > 0 and new_usage > user.storage_quota:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail={"code": "STORAGE_EXCEEDED", "message": "存储空间不足"},
                )
            user.storage_used = max(0, new_usage)
            self.db.commit()
            return True
        except HTTPException:
            raise
        except Exception as e:
            self.db.rollback()
            logger.exception("Storage usage update failed")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail={"code": "STORAGE_UPDATE_FAILED", "message": "存储空间更新失败"},
            ) from e

    def create_user_tokens(self, user: User) -> dict:
        """Issue an access token and a refresh token for ``user``.

        Returns:
            A dict with ``access_token``, ``refresh_token``, ``token_type``
            (``"bearer"``) and ``expires_in`` (access-token lifetime in seconds).
        """
        access_token_expires = timedelta(minutes=settings.JWT_EXPIRE_MINUTES)
        refresh_token_expires = timedelta(days=settings.JWT_REFRESH_EXPIRE_DAYS)

        access_token = create_access_token(
            data={"sub": str(user.id), "username": user.username, "email": user.email},
            expires_delta=access_token_expires,
        )
        refresh_token = create_refresh_token(
            data={"sub": str(user.id)},
            expires_delta=refresh_token_expires,
        )

        return {
            "access_token": access_token,
            "refresh_token": refresh_token,
            "token_type": "bearer",
            "expires_in": settings.JWT_EXPIRE_MINUTES * 60,
        }

    def to_user_response(self, user: User) -> UserResponse:
        """Convert a ``User`` model into the public ``UserResponse`` schema."""
        return UserResponse(
            id=user.id,
            username=user.username,
            email=user.email,
            avatar_url=user.avatar_url,
            storage_quota=user.storage_quota,
            storage_used=user.storage_used,
            is_active=user.is_active,
            is_verified=user.is_verified,
            last_login_at=user.last_login_at,
            created_at=user.created_at,
        )