252 lines
9.3 KiB
Python
252 lines
9.3 KiB
Python
import logging
from datetime import datetime, timedelta, timezone
from typing import Optional

from fastapi import HTTPException, status
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from app.core.config import settings
from app.core.security import get_password_hash, verify_password, create_access_token, create_refresh_token
from app.exceptions.auth import UsernameAlreadyExistsException
from app.models.user import User
from app.schemas.auth import UserRegister, UserLogin, UserResponse


class UserService:
    """User service: registration, authentication and account management.

    Every method works through the SQLAlchemy session passed to the
    constructor. Methods that write commit on success and roll the session
    back when the write fails.
    """

    # Class-level logger; replaces the previous ad-hoc print() debugging.
    _logger = logging.getLogger(__name__)

    def __init__(self, db: Session):
        """Bind the service to a database session.

        Args:
            db: SQLAlchemy session used for every query and commit.
        """
        self.db = db

    def create_user(self, user_data: UserRegister) -> User:
        """Create and persist a new user.

        Only the username is checked for uniqueness up front; emails are
        allowed to be shared by several accounts.

        Args:
            user_data: Registration payload (username, email, password).

        Returns:
            The newly created user, refreshed from the database.

        Raises:
            UsernameAlreadyExistsException: An active user already has this
                username.
            HTTPException: 400 when the database reports an integrity
                violation; 500 for any other failure.
        """
        try:
            self._logger.debug("Checking if username exists: %s", user_data.username)
            if self.get_user_by_username(user_data.username):
                self._logger.debug("Username already exists: %s", user_data.username)
                raise UsernameAlreadyExistsException()

            # Emails may repeat, so there is deliberately no email uniqueness check.
            # A hashing failure falls through to the generic handler below,
            # which logs the traceback once.
            hashed_password = get_password_hash(user_data.password)

            db_user = User(
                username=user_data.username,
                email=user_data.email,
                password_hash=hashed_password,
                is_active=True,
                is_verified=False,
            )
            self.db.add(db_user)
            self.db.commit()
            self.db.refresh(db_user)
            return db_user

        except IntegrityError as e:
            self.db.rollback()
            # Map the violated constraint (by name in the DB error text) to an API code.
            db_message = str(e.orig)
            if "username" in db_message:
                detail = {"code": "USERNAME_EXISTS", "message": "用户名已存在"}
            elif "email" in db_message:
                detail = {"code": "EMAIL_EXISTS", "message": "邮箱已被注册"}
            else:
                detail = {"code": "INTEGRITY_ERROR", "message": "数据完整性错误"}
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=detail,
            ) from e

        except (HTTPException, UsernameAlreadyExistsException):
            # Re-raise our own errors unchanged so their custom detail is not overwritten.
            self.db.rollback()
            raise

        except Exception as e:
            self.db.rollback()
            self._logger.exception("User creation failed for username %s", user_data.username)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail={"code": "CREATION_FAILED", "message": "用户创建失败"},
            ) from e

    def authenticate_user(self, login_data: UserLogin) -> Optional[User]:
        """Authenticate by username (or email) and password.

        ``login_data.username`` is first matched against active usernames.
        If none matches, it is treated as an email address. Emails may be
        shared by several accounts, so every active account with that email
        is tried, lowest id first, instead of an arbitrary single row.

        On success, ``last_login_at`` is set to the current UTC time (naive,
        matching the previous ``datetime.utcnow()`` value) and committed.

        Args:
            login_data: Login payload; ``username`` may hold a username or an email.

        Returns:
            The authenticated user, or ``None`` when no active account
            matches the credentials.
        """
        user = self.get_user_by_username(login_data.username)
        if user is not None:
            candidates = [user]
        else:
            candidates = (
                self.db.query(User)
                .filter(User.email == login_data.username, User.is_active == True)  # noqa: E712
                .order_by(User.id)
                .all()
            )

        for candidate in candidates:
            if verify_password(login_data.password, candidate.password_hash):
                # Same naive-UTC value as the deprecated datetime.utcnow().
                candidate.last_login_at = datetime.now(timezone.utc).replace(tzinfo=None)
                try:
                    self.db.commit()
                except Exception:
                    # Keep the session usable for the caller before propagating.
                    self.db.rollback()
                    raise
                return candidate

        return None

    def get_user_by_id(self, user_id: int) -> Optional[User]:
        """Return the active user with this id, or ``None``."""
        return self.db.query(User).filter(User.id == user_id, User.is_active == True).first()  # noqa: E712

    def get_user_by_username(self, username: str) -> Optional[User]:
        """Return the active user with this username, or ``None``."""
        return self.db.query(User).filter(User.username == username, User.is_active == True).first()  # noqa: E712

    def get_user_by_email(self, email: str) -> Optional[User]:
        """Return one active user with this email, or ``None``.

        NOTE(review): emails are not unique, and this returns whichever
        matching row the database yields first.
        """
        return self.db.query(User).filter(User.email == email, User.is_active == True).first()  # noqa: E712

    def update_user(self, user_id: int, **kwargs) -> Optional[User]:
        """Set attributes on an active user and commit.

        Each keyword names a ``User`` attribute. A keyword is skipped when
        the user has no such attribute or when its value is ``None`` (so
        this method cannot clear a field).

        NOTE(review): any existing attribute can be set this way, including
        sensitive ones such as ``password_hash`` or ``is_active``. Callers
        must pass only the fields they intend to change.

        Returns:
            The refreshed user, or ``None`` if no active user has ``user_id``.

        Raises:
            HTTPException: 500 if applying or committing the update fails.
        """
        user = self.get_user_by_id(user_id)
        if not user:
            return None

        try:
            for key, value in kwargs.items():
                if hasattr(user, key) and value is not None:
                    setattr(user, key, value)
            self.db.commit()
            self.db.refresh(user)
            return user
        except Exception as e:
            self.db.rollback()
            self._logger.exception("Failed to update user %s", user_id)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail={"code": "UPDATE_FAILED", "message": "用户信息更新失败"},
            ) from e

    def change_password(self, user_id: int, current_password: str, new_password: str) -> bool:
        """Change a user's password after verifying the current one.

        Returns:
            ``True`` once the new hash has been committed.

        Raises:
            HTTPException: 404 if no active user has ``user_id``; 400 if
                ``current_password`` is wrong; 500 if hashing or the commit
                fails.
        """
        user = self.get_user_by_id(user_id)
        if not user:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail={"code": "USER_NOT_FOUND", "message": "用户不存在"},
            )

        if not verify_password(current_password, user.password_hash):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail={"code": "INVALID_PASSWORD", "message": "当前密码不正确"},
            )

        try:
            user.password_hash = get_password_hash(new_password)
            self.db.commit()
            return True
        except Exception as e:
            self.db.rollback()
            self._logger.exception("Failed to change password for user %s", user_id)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail={"code": "PASSWORD_CHANGE_FAILED", "message": "密码修改失败"},
            ) from e

    def deactivate_user(self, user_id: int) -> bool:
        """Soft-delete a user by setting ``is_active`` to ``False``.

        Returns:
            ``True`` on success, ``False`` if no active user has ``user_id``.

        Raises:
            HTTPException: 500 if the commit fails.
        """
        user = self.get_user_by_id(user_id)
        if not user:
            return False

        try:
            user.is_active = False
            self.db.commit()
            return True
        except Exception as e:
            self.db.rollback()
            self._logger.exception("Failed to deactivate user %s", user_id)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail={"code": "DEACTIVATION_FAILED", "message": "用户停用失败"},
            ) from e

    def update_storage_usage(self, user_id: int, size_change: int) -> bool:
        """Adjust a user's ``storage_used`` by ``size_change``.

        ``size_change`` is assumed to be in the same unit as
        ``storage_quota`` (TODO confirm). Growth that would push usage past
        the quota is rejected. Releases (``size_change <= 0``) are always
        recorded, even when the user is already over a lowered quota. The
        stored value is clamped at 0.

        Returns:
            ``True`` on success, ``False`` if no active user has ``user_id``.

        Raises:
            HTTPException: 400 ``STORAGE_EXCEEDED`` when an increase exceeds
                the quota; 500 if the update fails.
        """
        user = self.get_user_by_id(user_id)
        if not user:
            return False

        try:
            new_usage = user.storage_used + size_change

            # Only growth is checked; otherwise a user over quota could never free space.
            if size_change > 0 and new_usage > user.storage_quota:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail={"code": "STORAGE_EXCEEDED", "message": "存储空间不足"},
                )

            user.storage_used = max(0, new_usage)  # never store a negative usage
            self.db.commit()
            return True

        except HTTPException:
            # Quota rejection happens before any change, so there is nothing to roll back.
            raise
        except Exception as e:
            self.db.rollback()
            self._logger.exception("Failed to update storage usage for user %s", user_id)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail={"code": "STORAGE_UPDATE_FAILED", "message": "存储空间更新失败"},
            ) from e

    def create_user_tokens(self, user: User) -> dict:
        """Issue an access token and a refresh token for ``user``.

        The access token carries the user's id (as ``sub``), username and
        email. The refresh token carries only ``sub``. Lifetimes come from
        ``settings.JWT_EXPIRE_MINUTES`` and ``settings.JWT_REFRESH_EXPIRE_DAYS``.

        Returns:
            A dict with ``access_token``, ``refresh_token``, ``token_type``
            (``"bearer"``) and ``expires_in`` (access-token lifetime in seconds).
        """
        access_token_expires = timedelta(minutes=settings.JWT_EXPIRE_MINUTES)
        refresh_token_expires = timedelta(days=settings.JWT_REFRESH_EXPIRE_DAYS)

        access_token = create_access_token(
            data={"sub": str(user.id), "username": user.username, "email": user.email},
            expires_delta=access_token_expires,
        )
        refresh_token = create_refresh_token(
            data={"sub": str(user.id)},
            expires_delta=refresh_token_expires,
        )

        return {
            "access_token": access_token,
            "refresh_token": refresh_token,
            "token_type": "bearer",
            "expires_in": settings.JWT_EXPIRE_MINUTES * 60,
        }

    def to_user_response(self, user: User) -> UserResponse:
        """Copy the public fields of a ``User`` row into a ``UserResponse``."""
        return UserResponse(
            id=user.id,
            username=user.username,
            email=user.email,
            avatar_url=user.avatar_url,
            storage_quota=user.storage_quota,
            storage_used=user.storage_used,
            is_active=user.is_active,
            is_verified=user.is_verified,
            last_login_at=user.last_login_at,
            created_at=user.created_at,
        )