初次提交
This commit is contained in:
0
backend/app/services/__init__.py
Normal file
0
backend/app/services/__init__.py
Normal file
418
backend/app/services/file_service.py
Normal file
418
backend/app/services/file_service.py
Normal file
@@ -0,0 +1,418 @@
|
||||
import os
|
||||
import hashlib
|
||||
import uuid
|
||||
import logging
|
||||
from typing import Optional, List, Tuple
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_, or_, desc
|
||||
from fastapi import UploadFile, HTTPException, status
|
||||
from datetime import datetime
|
||||
|
||||
from app.models.file import File
|
||||
from app.models.user import User
|
||||
from app.schemas.file import (
|
||||
FileUploadRequest, FileUpdateRequest, FileSearchRequest,
|
||||
FileResponse, FileListResponse, StorageInfo, FileInfo
|
||||
)
|
||||
from app.core.config import settings
|
||||
from app.exceptions.file import (
|
||||
FileTooLargeException, StorageQuotaExceededException,
|
||||
FileAlreadyExistsException, FileNotFoundException,
|
||||
InvalidFileTypeException
|
||||
)
|
||||
|
||||
class FileService:
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
# 使用绝对路径确保文件保存正确
|
||||
self.upload_dir = os.path.abspath(settings.UPLOAD_DIR)
|
||||
self.max_file_size = settings.MAX_FILE_SIZE
|
||||
self.allowed_extensions = settings.ALLOWED_EXTENSIONS
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# 确保上传目录存在
|
||||
os.makedirs(self.upload_dir, exist_ok=True)
|
||||
self.logger.info(f"Upload directory (absolute): {self.upload_dir}")
|
||||
self.logger.info(f"Upload directory exists: {os.path.exists(self.upload_dir)}")
|
||||
self.logger.info(f"Current working directory: {os.getcwd()}")
|
||||
|
||||
def _calculate_file_hash(self, file_content: bytes) -> str:
|
||||
"""计算文件的SHA-256哈希值"""
|
||||
return hashlib.sha256(file_content).hexdigest()
|
||||
|
||||
def _generate_unique_filename(self, original_filename: str) -> str:
|
||||
"""生成唯一的文件名"""
|
||||
file_extension = os.path.splitext(original_filename)[1]
|
||||
unique_id = str(uuid.uuid4())
|
||||
return f"{unique_id}{file_extension}"
|
||||
|
||||
def _validate_file(self, file: UploadFile, user: User) -> None:
|
||||
"""验证文件"""
|
||||
# 检查文件大小
|
||||
if hasattr(file, 'size') and file.size > self.max_file_size:
|
||||
raise FileTooLargeException(file.size, self.max_file_size)
|
||||
|
||||
# 检查文件扩展名
|
||||
if self.allowed_extensions:
|
||||
file_extension = os.path.splitext(file.filename)[1].lower()
|
||||
if file_extension not in self.allowed_extensions:
|
||||
raise InvalidFileTypeException(file_extension)
|
||||
|
||||
# 检查存储配额
|
||||
if hasattr(file, 'size') and not user.is_storage_available(file.size):
|
||||
raise StorageQuotaExceededException(user.storage_used, user.storage_quota, file.size)
|
||||
|
||||
def _save_file_to_disk(self, file: UploadFile, unique_filename: str) -> Tuple[str, bytes]:
|
||||
"""保存文件到磁盘"""
|
||||
file_path = os.path.join(self.upload_dir, unique_filename)
|
||||
|
||||
# 读取文件内容
|
||||
file_content = file.file.read()
|
||||
|
||||
# 强制输出调试信息
|
||||
import sys
|
||||
message = f"[CRITICAL] About to save {len(file_content)} bytes to {file_path}"
|
||||
print(message, flush=True)
|
||||
sys.stdout.flush()
|
||||
message2 = f"[CRITICAL] Content preview: {file_content[:50] if file_content else 'EMPTY'}"
|
||||
print(message2, flush=True)
|
||||
sys.stdout.flush()
|
||||
|
||||
# 立即验证内容是否为空
|
||||
if not file_content:
|
||||
print("[CRITICAL] FILE CONTENT IS EMPTY!")
|
||||
raise ValueError("File content is empty!")
|
||||
|
||||
# 使用临时文件方法确保写入成功
|
||||
temp_path = file_path + '.tmp'
|
||||
try:
|
||||
# 先写入临时文件
|
||||
with open(temp_path, "wb") as temp_file:
|
||||
temp_file.write(file_content)
|
||||
temp_file.flush()
|
||||
os.fsync(temp_file.fileno())
|
||||
|
||||
# 验证临时文件
|
||||
if os.path.exists(temp_path):
|
||||
temp_size = os.path.getsize(temp_path)
|
||||
if temp_size != len(file_content):
|
||||
raise Exception(f"Temporary file size mismatch: {temp_size} != {len(file_content)}")
|
||||
|
||||
# 重命名为最终文件名(原子操作)
|
||||
if os.name == 'nt': # Windows
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
os.rename(temp_path, file_path)
|
||||
|
||||
# 最终验证
|
||||
if not os.path.exists(file_path):
|
||||
raise Exception("File was not created after rename")
|
||||
|
||||
final_size = os.path.getsize(file_path)
|
||||
if final_size != len(file_content):
|
||||
raise Exception(f"Final file size mismatch: {final_size} != {len(file_content)}")
|
||||
|
||||
except Exception as e:
|
||||
# 清理临时文件
|
||||
if os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
print(f"[ERROR] File save failed: {e}")
|
||||
raise Exception(f"Failed to save file: {e}")
|
||||
|
||||
# 重置文件指针
|
||||
file.file.seek(0)
|
||||
|
||||
return file_path, file_content
|
||||
|
||||
def upload_file(self, file: UploadFile, user: User, upload_request: FileUploadRequest) -> File:
|
||||
"""上传文件"""
|
||||
try:
|
||||
# 验证文件
|
||||
self._validate_file(file, user)
|
||||
|
||||
# 生成唯一文件名
|
||||
unique_filename = self._generate_unique_filename(file.filename)
|
||||
|
||||
# 保存文件到磁盘
|
||||
file_path, file_content = self._save_file_to_disk(file, unique_filename)
|
||||
file_size = len(file_content)
|
||||
|
||||
# 再次检查文件大小(如果没有size属性)
|
||||
if file_size > self.max_file_size:
|
||||
# 删除已保存的文件
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
raise FileTooLargeException(file_size, self.max_file_size)
|
||||
|
||||
# 检查存储配额
|
||||
if not user.is_storage_available(file_size):
|
||||
# 删除已保存的文件
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
raise StorageQuotaExceededException(user.storage_used, user.storage_quota, file_size)
|
||||
|
||||
# 计算文件哈希
|
||||
file_hash = self._calculate_file_hash(file_content)
|
||||
|
||||
# 检查文件是否已存在(基于哈希值)
|
||||
existing_file = self.db.query(File).filter(
|
||||
and_(
|
||||
File.user_id == user.id,
|
||||
File.file_hash == file_hash
|
||||
)
|
||||
).first()
|
||||
|
||||
if existing_file:
|
||||
# 删除刚保存的文件,因为已存在相同内容的文件
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
raise FileAlreadyExistsException(existing_file.original_filename)
|
||||
|
||||
# 创建文件记录
|
||||
db_file = File(
|
||||
user_id=user.id,
|
||||
filename=unique_filename,
|
||||
original_filename=file.filename,
|
||||
file_path=file_path,
|
||||
file_size=file_size,
|
||||
mime_type=file.content_type or 'application/octet-stream',
|
||||
file_hash=file_hash,
|
||||
description=upload_request.description,
|
||||
tags=upload_request.tags,
|
||||
is_public=upload_request.is_public
|
||||
)
|
||||
|
||||
self.db.add(db_file)
|
||||
|
||||
# 更新用户存储使用量
|
||||
user.storage_used += file_size
|
||||
|
||||
self.db.commit()
|
||||
self.db.refresh(db_file)
|
||||
|
||||
return db_file
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
# 如果保存了文件但数据库操作失败,删除文件
|
||||
if 'file_path' in locals() and os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
raise e
|
||||
|
||||
def get_user_files(self, user_id: int, page: int = 1, size: int = 20) -> FileListResponse:
|
||||
"""获取用户的文件列表"""
|
||||
offset = (page - 1) * size
|
||||
|
||||
query = self.db.query(File).filter(File.user_id == user_id)
|
||||
|
||||
total = query.count()
|
||||
files = query.order_by(desc(File.created_at)).offset(offset).limit(size).all()
|
||||
|
||||
pages = (total + size - 1) // size
|
||||
|
||||
return FileListResponse(
|
||||
files=[FileResponse(
|
||||
id=file.id,
|
||||
user_id=file.user_id,
|
||||
filename=file.filename,
|
||||
original_filename=file.original_filename,
|
||||
file_size=file.file_size,
|
||||
mime_type=file.mime_type,
|
||||
file_hash=file.file_hash,
|
||||
is_public=file.is_public,
|
||||
download_count=file.download_count,
|
||||
description=file.description,
|
||||
tags=file.tags,
|
||||
created_at=file.created_at,
|
||||
updated_at=file.updated_at,
|
||||
last_accessed_at=file.last_accessed_at
|
||||
) for file in files],
|
||||
total=total,
|
||||
page=page,
|
||||
size=size,
|
||||
pages=pages
|
||||
)
|
||||
|
||||
def get_file_by_id(self, file_id: int, user_id: int) -> Optional[File]:
|
||||
"""根据ID获取文件"""
|
||||
return self.db.query(File).filter(
|
||||
and_(
|
||||
File.id == file_id,
|
||||
File.user_id == user_id
|
||||
)
|
||||
).first()
|
||||
|
||||
def update_file(self, file_id: int, user_id: int, update_request: FileUpdateRequest) -> Optional[File]:
|
||||
"""更新文件信息"""
|
||||
db_file = self.get_file_by_id(file_id, user_id)
|
||||
if not db_file:
|
||||
raise FileNotFoundException()
|
||||
|
||||
# 更新字段
|
||||
if update_request.description is not None:
|
||||
db_file.description = update_request.description
|
||||
if update_request.tags is not None:
|
||||
db_file.tags = update_request.tags
|
||||
if update_request.is_public is not None:
|
||||
db_file.is_public = update_request.is_public
|
||||
|
||||
db_file.updated_at = datetime.utcnow()
|
||||
|
||||
self.db.commit()
|
||||
self.db.refresh(db_file)
|
||||
|
||||
return db_file
|
||||
|
||||
def delete_file(self, file_id: int, user_id: int) -> bool:
|
||||
"""删除文件"""
|
||||
db_file = self.get_file_by_id(file_id, user_id)
|
||||
if not db_file:
|
||||
raise FileNotFoundException()
|
||||
|
||||
try:
|
||||
# 删除磁盘上的文件
|
||||
if os.path.exists(db_file.file_path):
|
||||
os.remove(db_file.file_path)
|
||||
|
||||
# 更新用户存储使用量
|
||||
user = self.db.query(User).filter(User.id == user_id).first()
|
||||
if user:
|
||||
user.storage_used = max(0, user.storage_used - db_file.file_size)
|
||||
|
||||
# 删除数据库记录
|
||||
self.db.delete(db_file)
|
||||
self.db.commit()
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
raise e
|
||||
|
||||
def search_files(self, user_id: int, search_request: FileSearchRequest,
|
||||
page: int = 1, size: int = 20) -> FileListResponse:
|
||||
"""搜索文件"""
|
||||
offset = (page - 1) * size
|
||||
|
||||
query = self.db.query(File).filter(File.user_id == user_id)
|
||||
|
||||
# 文件名搜索
|
||||
if search_request.filename:
|
||||
query = query.filter(
|
||||
File.original_filename.ilike(f"%{search_request.filename}%")
|
||||
)
|
||||
|
||||
# 标签搜索
|
||||
if search_request.tags:
|
||||
tag_list = [tag.strip() for tag in search_request.tags.split(',')]
|
||||
tag_conditions = []
|
||||
for tag in tag_list:
|
||||
tag_conditions.append(File.tags.ilike(f"%{tag}%"))
|
||||
if tag_conditions:
|
||||
query = query.filter(or_(*tag_conditions))
|
||||
|
||||
# MIME类型过滤
|
||||
if search_request.mime_type:
|
||||
query = query.filter(File.mime_type == search_request.mime_type)
|
||||
|
||||
# 公开状态过滤
|
||||
if search_request.is_public is not None:
|
||||
query = query.filter(File.is_public == search_request.is_public)
|
||||
|
||||
# 日期范围过滤
|
||||
if search_request.start_date:
|
||||
query = query.filter(File.created_at >= search_request.start_date)
|
||||
if search_request.end_date:
|
||||
query = query.filter(File.created_at <= search_request.end_date)
|
||||
|
||||
# 文件大小范围过滤
|
||||
if search_request.min_size:
|
||||
query = query.filter(File.file_size >= search_request.min_size)
|
||||
if search_request.max_size:
|
||||
query = query.filter(File.file_size <= search_request.max_size)
|
||||
|
||||
total = query.count()
|
||||
files = query.order_by(desc(File.created_at)).offset(offset).limit(size).all()
|
||||
pages = (total + size - 1) // size
|
||||
|
||||
return FileListResponse(
|
||||
files=[FileResponse(
|
||||
id=file.id,
|
||||
user_id=file.user_id,
|
||||
filename=file.filename,
|
||||
original_filename=file.original_filename,
|
||||
file_size=file.file_size,
|
||||
mime_type=file.mime_type,
|
||||
file_hash=file.file_hash,
|
||||
is_public=file.is_public,
|
||||
download_count=file.download_count,
|
||||
description=file.description,
|
||||
tags=file.tags,
|
||||
created_at=file.created_at,
|
||||
updated_at=file.updated_at,
|
||||
last_accessed_at=file.last_accessed_at
|
||||
) for file in files],
|
||||
total=total,
|
||||
page=page,
|
||||
size=size,
|
||||
pages=pages
|
||||
)
|
||||
|
||||
def get_file_info(self, file_id: int, user_id: int) -> FileInfo:
|
||||
"""获取文件详细信息"""
|
||||
db_file = self.get_file_by_id(file_id, user_id)
|
||||
if not db_file:
|
||||
raise FileNotFoundException()
|
||||
|
||||
# 更新最后访问时间
|
||||
db_file.last_accessed_at = datetime.utcnow()
|
||||
self.db.commit()
|
||||
|
||||
return FileInfo(
|
||||
id=db_file.id,
|
||||
filename=db_file.filename,
|
||||
original_filename=db_file.original_filename,
|
||||
file_size=db_file.file_size,
|
||||
mime_type=db_file.mime_type,
|
||||
file_hash=db_file.file_hash,
|
||||
is_image=db_file.is_image(),
|
||||
is_document=db_file.is_document(),
|
||||
file_extension=db_file.get_file_extension(),
|
||||
size_formatted=db_file.get_size_formatted(),
|
||||
is_public=db_file.is_public,
|
||||
download_count=db_file.download_count,
|
||||
description=db_file.description,
|
||||
tags=db_file.tags,
|
||||
created_at=db_file.created_at,
|
||||
updated_at=db_file.updated_at,
|
||||
last_accessed_at=db_file.last_accessed_at
|
||||
)
|
||||
|
||||
def get_storage_info(self, user_id: int) -> StorageInfo:
|
||||
"""获取用户存储信息"""
|
||||
user = self.db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="用户不存在"
|
||||
)
|
||||
|
||||
file_count = self.db.query(File).filter(File.user_id == user_id).count()
|
||||
available_space = user.storage_quota - user.storage_used
|
||||
usage_percentage = user.get_storage_percentage()
|
||||
|
||||
return StorageInfo(
|
||||
total_quota=user.storage_quota,
|
||||
used_space=user.storage_used,
|
||||
available_space=available_space,
|
||||
usage_percentage=usage_percentage,
|
||||
file_count=file_count
|
||||
)
|
||||
|
||||
def increment_download_count(self, file_id: int) -> None:
|
||||
"""增加文件下载次数"""
|
||||
db_file = self.db.query(File).filter(File.id == file_id).first()
|
||||
if db_file:
|
||||
db_file.download_count += 1
|
||||
db_file.last_accessed_at = datetime.utcnow()
|
||||
self.db.commit()
|
||||
252
backend/app/services/user_service.py
Normal file
252
backend/app/services/user_service.py
Normal file
@@ -0,0 +1,252 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from fastapi import HTTPException, status
|
||||
from typing import Optional
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from app.models.user import User
|
||||
from app.core.security import get_password_hash, verify_password, create_access_token, create_refresh_token
|
||||
from app.schemas.auth import UserRegister, UserLogin, UserResponse
|
||||
from app.core.config import settings
|
||||
from app.exceptions.auth import UsernameAlreadyExistsException
|
||||
|
||||
class UserService:
|
||||
"""用户服务类"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def create_user(self, user_data: UserRegister) -> User:
|
||||
"""创建新用户"""
|
||||
try:
|
||||
# 检查用户名是否已存在
|
||||
print(f"[DEBUG] Checking if username exists: {user_data.username}")
|
||||
existing_user_by_username = self.get_user_by_username(user_data.username)
|
||||
if existing_user_by_username:
|
||||
print(f"[DEBUG] Username already exists: {existing_user_by_username}")
|
||||
raise UsernameAlreadyExistsException()
|
||||
print(f"[DEBUG] Username check passed")
|
||||
|
||||
# 邮箱允许重复,不再检查邮箱是否已存在
|
||||
print(f"[DEBUG] Email uniqueness check skipped (emails can be duplicated)")
|
||||
|
||||
# 创建新用户
|
||||
print(f"[DEBUG] About to hash password...")
|
||||
try:
|
||||
hashed_password = get_password_hash(user_data.password)
|
||||
print(f"[DEBUG] Password hashed successfully")
|
||||
except Exception as hash_error:
|
||||
print(f"[ERROR] Password hashing failed: {hash_error}")
|
||||
print(f"[ERROR] Error type: {type(hash_error)}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
print(f"[DEBUG] Creating User model...")
|
||||
db_user = User(
|
||||
username=user_data.username,
|
||||
email=user_data.email,
|
||||
password_hash=hashed_password,
|
||||
is_active=True,
|
||||
is_verified=False
|
||||
)
|
||||
print(f"[DEBUG] User model created")
|
||||
|
||||
self.db.add(db_user)
|
||||
self.db.commit()
|
||||
self.db.refresh(db_user)
|
||||
|
||||
return db_user
|
||||
|
||||
except IntegrityError as e:
|
||||
self.db.rollback()
|
||||
if "username" in str(e.orig):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"code": "USERNAME_EXISTS", "message": "用户名已存在"}
|
||||
)
|
||||
elif "email" in str(e.orig):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"code": "EMAIL_EXISTS", "message": "邮箱已被注册"}
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"code": "INTEGRITY_ERROR", "message": "数据完整性错误"}
|
||||
)
|
||||
except (HTTPException, UsernameAlreadyExistsException):
|
||||
# 重新抛出HTTPException(不要覆盖自定义错误信息)
|
||||
self.db.rollback()
|
||||
raise
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={"code": "CREATION_FAILED", "message": "用户创建失败"}
|
||||
)
|
||||
|
||||
def authenticate_user(self, login_data: UserLogin) -> Optional[User]:
|
||||
"""验证用户登录"""
|
||||
# 尝试通过用户名查找用户
|
||||
user = self.get_user_by_username(login_data.username)
|
||||
|
||||
# 如果用户名找不到,尝试通过邮箱查找
|
||||
if not user:
|
||||
user = self.get_user_by_email(login_data.username)
|
||||
|
||||
# 如果找到用户且密码正确
|
||||
if user and verify_password(login_data.password, user.password_hash):
|
||||
# 更新最后登录时间
|
||||
user.last_login_at = datetime.utcnow()
|
||||
self.db.commit()
|
||||
return user
|
||||
|
||||
return None
|
||||
|
||||
def get_user_by_id(self, user_id: int) -> Optional[User]:
|
||||
"""根据ID获取用户"""
|
||||
return self.db.query(User).filter(User.id == user_id, User.is_active == True).first()
|
||||
|
||||
def get_user_by_username(self, username: str) -> Optional[User]:
|
||||
"""根据用户名获取用户"""
|
||||
return self.db.query(User).filter(User.username == username, User.is_active == True).first()
|
||||
|
||||
def get_user_by_email(self, email: str) -> Optional[User]:
|
||||
"""根据邮箱获取用户"""
|
||||
return self.db.query(User).filter(User.email == email, User.is_active == True).first()
|
||||
|
||||
def update_user(self, user_id: int, **kwargs) -> Optional[User]:
|
||||
"""更新用户信息"""
|
||||
user = self.get_user_by_id(user_id)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
try:
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(user, key) and value is not None:
|
||||
setattr(user, key, value)
|
||||
|
||||
self.db.commit()
|
||||
self.db.refresh(user)
|
||||
return user
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={"code": "UPDATE_FAILED", "message": "用户信息更新失败"}
|
||||
)
|
||||
|
||||
def change_password(self, user_id: int, current_password: str, new_password: str) -> bool:
|
||||
"""修改用户密码"""
|
||||
user = self.get_user_by_id(user_id)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail={"code": "USER_NOT_FOUND", "message": "用户不存在"}
|
||||
)
|
||||
|
||||
# 验证当前密码
|
||||
if not verify_password(current_password, user.password_hash):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"code": "INVALID_PASSWORD", "message": "当前密码不正确"}
|
||||
)
|
||||
|
||||
try:
|
||||
# 更新密码
|
||||
user.password_hash = get_password_hash(new_password)
|
||||
self.db.commit()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={"code": "PASSWORD_CHANGE_FAILED", "message": "密码修改失败"}
|
||||
)
|
||||
|
||||
def deactivate_user(self, user_id: int) -> bool:
|
||||
"""停用用户"""
|
||||
user = self.get_user_by_id(user_id)
|
||||
if not user:
|
||||
return False
|
||||
|
||||
try:
|
||||
user.is_active = False
|
||||
self.db.commit()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={"code": "DEACTIVATION_FAILED", "message": "用户停用失败"}
|
||||
)
|
||||
|
||||
def update_storage_usage(self, user_id: int, size_change: int) -> bool:
|
||||
"""更新用户存储使用量"""
|
||||
user = self.get_user_by_id(user_id)
|
||||
if not user:
|
||||
return False
|
||||
|
||||
try:
|
||||
new_usage = user.storage_used + size_change
|
||||
|
||||
# 检查存储配额
|
||||
if new_usage > user.storage_quota:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"code": "STORAGE_EXCEEDED", "message": "存储空间不足"}
|
||||
)
|
||||
|
||||
user.storage_used = max(0, new_usage) # 确保不小于0
|
||||
self.db.commit()
|
||||
return True
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={"code": "STORAGE_UPDATE_FAILED", "message": "存储空间更新失败"}
|
||||
)
|
||||
|
||||
def create_user_tokens(self, user: User) -> dict:
|
||||
"""为用户创建访问令牌和刷新令牌"""
|
||||
access_token_expires = timedelta(minutes=settings.JWT_EXPIRE_MINUTES)
|
||||
refresh_token_expires = timedelta(days=settings.JWT_REFRESH_EXPIRE_DAYS)
|
||||
|
||||
access_token = create_access_token(
|
||||
data={"sub": str(user.id), "username": user.username, "email": user.email},
|
||||
expires_delta=access_token_expires
|
||||
)
|
||||
|
||||
refresh_token = create_refresh_token(
|
||||
data={"sub": str(user.id)},
|
||||
expires_delta=refresh_token_expires
|
||||
)
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"token_type": "bearer",
|
||||
"expires_in": settings.JWT_EXPIRE_MINUTES * 60
|
||||
}
|
||||
|
||||
def to_user_response(self, user: User) -> UserResponse:
|
||||
"""将用户模型转换为响应模型"""
|
||||
return UserResponse(
|
||||
id=user.id,
|
||||
username=user.username,
|
||||
email=user.email,
|
||||
avatar_url=user.avatar_url,
|
||||
storage_quota=user.storage_quota,
|
||||
storage_used=user.storage_used,
|
||||
is_active=user.is_active,
|
||||
is_verified=user.is_verified,
|
||||
last_login_at=user.last_login_at,
|
||||
created_at=user.created_at
|
||||
)
|
||||
Reference in New Issue
Block a user