初次提交
This commit is contained in:
0
backend/app/__init__.py
Normal file
0
backend/app/__init__.py
Normal file
0
backend/app/api/__init__.py
Normal file
0
backend/app/api/__init__.py
Normal file
0
backend/app/api/v1/__init__.py
Normal file
0
backend/app/api/v1/__init__.py
Normal file
0
backend/app/api/v1/endpoints/__init__.py
Normal file
0
backend/app/api/v1/endpoints/__init__.py
Normal file
217
backend/app/api/v1/endpoints/auth.py
Normal file
217
backend/app/api/v1/endpoints/auth.py
Normal file
@@ -0,0 +1,217 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from sqlalchemy.orm import Session
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.security import verify_token
|
||||
from app.core.token_blacklist import token_blacklist
|
||||
from app.services.user_service import UserService
|
||||
from app.schemas.auth import (
|
||||
UserRegister, UserLogin, UserResponse, LoginResponse,
|
||||
TokenResponse, TokenRefresh, ApiResponse
|
||||
)
|
||||
from app.dependencies.auth import get_current_user_response
|
||||
from app.models.user import User
|
||||
from app.exceptions.auth import UsernameAlreadyExistsException
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/register", status_code=status.HTTP_201_CREATED)
|
||||
async def register(request: Request, user_data: UserRegister, db: Session = Depends(get_db)):
|
||||
"""用户注册"""
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
print(f"[{timestamp}] INFO: === 注册接口 ===")
|
||||
print(f"[{timestamp}] INFO: 用户名: {user_data.username}")
|
||||
print(f"[{timestamp}] INFO: 邮箱: {user_data.email}")
|
||||
print(f"[{timestamp}] INFO: 密码长度: {len(user_data.password)}字符")
|
||||
print(f"[{timestamp}] INFO: 确认密码长度: {len(user_data.confirm_password)}字符")
|
||||
print(f"[{timestamp}] DEBUG: Starting registration process...")
|
||||
|
||||
try:
|
||||
user_service = UserService(db)
|
||||
|
||||
# 创建用户
|
||||
print("[DEBUG] Creating user...")
|
||||
user = user_service.create_user(user_data)
|
||||
print(f"[DEBUG] User created successfully with ID: {user.id}")
|
||||
|
||||
# 创建令牌
|
||||
print("[DEBUG] Creating tokens...")
|
||||
tokens = user_service.create_user_tokens(user)
|
||||
print("[DEBUG] Tokens created successfully")
|
||||
|
||||
# 转换为响应格式
|
||||
print("[DEBUG] Converting to response format...")
|
||||
user_response = user_service.to_user_response(user)
|
||||
print("[DEBUG] Response conversion successful")
|
||||
|
||||
response_data = {
|
||||
"user": user_response.dict(),
|
||||
"tokens": tokens
|
||||
}
|
||||
print("[DEBUG] Response data created successfully")
|
||||
|
||||
return ApiResponse(
|
||||
success=True,
|
||||
message="注册成功",
|
||||
data=response_data
|
||||
)
|
||||
|
||||
except UsernameAlreadyExistsException as e:
|
||||
print(f"[DEBUG] UsernameAlreadyExistsException caught: {e}")
|
||||
raise e
|
||||
except HTTPException as e:
|
||||
print(f"[DEBUG] HTTPException caught: {e}")
|
||||
raise e
|
||||
except Exception as e:
|
||||
# 打印异常信息以便调试
|
||||
import traceback
|
||||
print(f"[ERROR] Unexpected error in register: {e}")
|
||||
print(f"[ERROR] Exception type: {type(e)}")
|
||||
traceback.print_exc()
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={
|
||||
"code": "REGISTRATION_FAILED",
|
||||
"message": f"注册过程中发生错误: {str(e)}"
|
||||
}
|
||||
)
|
||||
|
||||
@router.post("/login", response_model=ApiResponse)
|
||||
async def login(request: Request, login_data: UserLogin, db: Session = Depends(get_db)):
|
||||
"""用户登录"""
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
print(f"[{timestamp}] INFO: === 登录接口 ===")
|
||||
print(f"[{timestamp}] INFO: 用户名: {login_data.username}")
|
||||
print(f"[{timestamp}] INFO: 密码长度: {len(login_data.password)}字符")
|
||||
print(f"[{timestamp}] DEBUG: Starting authentication process...")
|
||||
|
||||
try:
|
||||
user_service = UserService(db)
|
||||
# 验证用户
|
||||
user = user_service.authenticate_user(login_data)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail={
|
||||
"code": "INVALID_CREDENTIALS",
|
||||
"message": "用户名或密码错误"
|
||||
}
|
||||
)
|
||||
|
||||
# 创建令牌
|
||||
tokens = user_service.create_user_tokens(user)
|
||||
|
||||
# 转换为响应格式
|
||||
user_response = user_service.to_user_response(user)
|
||||
|
||||
return ApiResponse(
|
||||
success=True,
|
||||
message="登录成功",
|
||||
data={
|
||||
"user": user_response.dict(),
|
||||
"tokens": tokens
|
||||
}
|
||||
)
|
||||
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
print("用户登录的异常:",e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={
|
||||
"code": "LOGIN_FAILED",
|
||||
"message": "登录过程中发生错误"
|
||||
}
|
||||
)
|
||||
|
||||
@router.post("/refresh", response_model=ApiResponse)
|
||||
async def refresh_token(token_data: TokenRefresh, db: Session = Depends(get_db)):
|
||||
"""刷新访问令牌"""
|
||||
try:
|
||||
# 验证刷新令牌
|
||||
payload = verify_token(token_data.refresh_token, "refresh")
|
||||
if not payload:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail={
|
||||
"code": "INVALID_REFRESH_TOKEN",
|
||||
"message": "无效的刷新令牌"
|
||||
}
|
||||
)
|
||||
|
||||
# 获取用户ID
|
||||
user_id = int(payload.get("sub"))
|
||||
user_service = UserService(db)
|
||||
user = user_service.get_user_by_id(user_id)
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail={
|
||||
"code": "USER_NOT_FOUND",
|
||||
"message": "用户不存在"
|
||||
}
|
||||
)
|
||||
|
||||
# 创建新的访问令牌
|
||||
tokens = user_service.create_user_tokens(user)
|
||||
|
||||
return ApiResponse(
|
||||
success=True,
|
||||
message="令牌刷新成功",
|
||||
data={
|
||||
"tokens": tokens
|
||||
}
|
||||
)
|
||||
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={
|
||||
"code": "TOKEN_REFRESH_FAILED",
|
||||
"message": "令牌刷新过程中发生错误"
|
||||
}
|
||||
)
|
||||
|
||||
@router.get("/me", response_model=ApiResponse)
|
||||
async def get_current_user_info(current_user: UserResponse = Depends(get_current_user_response)):
|
||||
"""获取当前用户信息"""
|
||||
return ApiResponse(
|
||||
success=True,
|
||||
message="获取用户信息成功",
|
||||
data={
|
||||
"user": current_user.dict()
|
||||
}
|
||||
)
|
||||
|
||||
@router.post("/logout", response_model=ApiResponse)
|
||||
async def logout(
|
||||
request: Request,
|
||||
current_user: UserResponse = Depends(get_current_user_response)
|
||||
):
|
||||
"""用户登出"""
|
||||
try:
|
||||
# 从请求头中获取Authorization令牌
|
||||
authorization = request.headers.get("Authorization")
|
||||
if authorization and authorization.startswith("Bearer "):
|
||||
token = authorization.split(" ")[1]
|
||||
# 将令牌加入黑名单
|
||||
token_blacklist.add_token(token)
|
||||
|
||||
return ApiResponse(
|
||||
success=True,
|
||||
message="登出成功",
|
||||
data={}
|
||||
)
|
||||
except Exception as e:
|
||||
# 即使添加令牌到黑名单失败,也返回成功,因为登出操作主要目的是让客户端删除令牌
|
||||
return ApiResponse(
|
||||
success=True,
|
||||
message="登出成功",
|
||||
data={}
|
||||
)
|
||||
382
backend/app/api/v1/endpoints/files.py
Normal file
382
backend/app/api/v1/endpoints/files.py
Normal file
@@ -0,0 +1,382 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, UploadFile, File, Form, Query, Request
|
||||
from fastapi.responses import FileResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.services.file_service import FileService
|
||||
from app.schemas.file import (
|
||||
FileUploadRequest, FileUpdateRequest, FileSearchRequest,
|
||||
FileResponse, FileListResponse, FileInfo, StorageInfo, ApiResponse,
|
||||
UploadResponse, DeleteResponse, FileListRequest, FileIdRequest, StorageInfoRequest
|
||||
)
|
||||
from app.models.user import User
|
||||
from app.dependencies.auth import get_current_user_from_headers
|
||||
from app.exceptions.file import (
|
||||
FileTooLargeException, StorageQuotaExceededException,
|
||||
FileAlreadyExistsException, FileNotFoundException,
|
||||
InvalidFileTypeException, FileUploadException, FileDeleteException
|
||||
)
|
||||
|
||||
print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] INFO: [MODULE] Files module loaded successfully")
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/upload", response_model=ApiResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def upload_file(
|
||||
file: UploadFile = File(...),
|
||||
description: Optional[str] = Form(None),
|
||||
tags: Optional[str] = Form(None),
|
||||
is_public: bool = Form(False),
|
||||
current_user: User = Depends(get_current_user_from_headers),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""上传文件"""
|
||||
try:
|
||||
file_service = FileService(db)
|
||||
|
||||
# 创建上传请求对象
|
||||
upload_request = FileUploadRequest(
|
||||
description=description,
|
||||
tags=tags,
|
||||
is_public=is_public
|
||||
)
|
||||
|
||||
# 上传文件
|
||||
db_file = file_service.upload_file(file, current_user, upload_request)
|
||||
|
||||
# 转换为响应格式
|
||||
file_response = FileResponse(
|
||||
id=db_file.id,
|
||||
user_id=db_file.user_id,
|
||||
filename=db_file.filename,
|
||||
original_filename=db_file.original_filename,
|
||||
file_size=db_file.file_size,
|
||||
mime_type=db_file.mime_type,
|
||||
file_hash=db_file.file_hash,
|
||||
is_public=db_file.is_public,
|
||||
download_count=db_file.download_count,
|
||||
description=db_file.description,
|
||||
tags=db_file.tags,
|
||||
created_at=db_file.created_at,
|
||||
updated_at=db_file.updated_at,
|
||||
last_accessed_at=db_file.last_accessed_at
|
||||
)
|
||||
|
||||
return ApiResponse(
|
||||
success=True,
|
||||
message="文件上传成功",
|
||||
data={
|
||||
"file": file_response.model_dump()
|
||||
}
|
||||
)
|
||||
|
||||
except (FileTooLargeException, StorageQuotaExceededException,
|
||||
FileAlreadyExistsException, InvalidFileTypeException) as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise FileUploadException(str(e))
|
||||
|
||||
@router.get("/list", response_model=ApiResponse)
|
||||
def get_user_files(
|
||||
page: int = Query(1, ge=1),
|
||||
size: int = Query(10, ge=1, le=100),
|
||||
current_user: User = Depends(get_current_user_from_headers),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取用户文件列表"""
|
||||
try:
|
||||
file_service = FileService(db)
|
||||
file_list = file_service.get_user_files(current_user.id, page, size)
|
||||
|
||||
return ApiResponse(
|
||||
success=True,
|
||||
message="获取文件列表成功",
|
||||
data={
|
||||
"files": [FileResponse(
|
||||
id=file.id,
|
||||
user_id=file.user_id,
|
||||
filename=file.filename,
|
||||
original_filename=file.original_filename,
|
||||
file_size=file.file_size,
|
||||
mime_type=file.mime_type,
|
||||
file_hash=file.file_hash,
|
||||
is_public=file.is_public,
|
||||
download_count=file.download_count,
|
||||
description=file.description,
|
||||
tags=file.tags,
|
||||
created_at=file.created_at,
|
||||
updated_at=file.updated_at,
|
||||
last_accessed_at=file.last_accessed_at
|
||||
).model_dump() for file in file_list.files],
|
||||
"pagination": {
|
||||
"total": file_list.total,
|
||||
"page": file_list.page,
|
||||
"size": file_list.size,
|
||||
"pages": file_list.pages
|
||||
}
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={
|
||||
"code": "GET_FILES_FAILED",
|
||||
"message": "获取文件列表失败"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.post("/info", response_model=ApiResponse)
|
||||
def get_file_info(
|
||||
request: FileIdRequest,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取文件详细信息"""
|
||||
try:
|
||||
file_service = FileService(db)
|
||||
file_info = file_service.get_file_info(request.file_id, request.user_id)
|
||||
|
||||
return ApiResponse(
|
||||
success=True,
|
||||
message="获取文件信息成功",
|
||||
data=file_info.model_dump()
|
||||
)
|
||||
except FileNotFoundException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={
|
||||
"code": "GET_FILE_INFO_FAILED",
|
||||
"message": "获取文件信息失败"
|
||||
}
|
||||
)
|
||||
|
||||
@router.post("/update", response_model=ApiResponse)
|
||||
def update_file(
|
||||
file_id_request: FileIdRequest,
|
||||
update_request: FileUpdateRequest,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""更新文件信息"""
|
||||
try:
|
||||
file_service = FileService(db)
|
||||
db_file = file_service.update_file(file_id_request.file_id, file_id_request.user_id, update_request)
|
||||
|
||||
file_response = FileResponse(
|
||||
id=db_file.id,
|
||||
user_id=db_file.user_id,
|
||||
filename=db_file.filename,
|
||||
original_filename=db_file.original_filename,
|
||||
file_size=db_file.file_size,
|
||||
mime_type=db_file.mime_type,
|
||||
file_hash=db_file.file_hash,
|
||||
is_public=db_file.is_public,
|
||||
download_count=db_file.download_count,
|
||||
description=db_file.description,
|
||||
tags=db_file.tags,
|
||||
created_at=db_file.created_at,
|
||||
updated_at=db_file.updated_at,
|
||||
last_accessed_at=db_file.last_accessed_at
|
||||
)
|
||||
|
||||
return ApiResponse(
|
||||
success=True,
|
||||
message="文件信息更新成功",
|
||||
data={
|
||||
"file": file_response.model_dump()
|
||||
}
|
||||
)
|
||||
except FileNotFoundException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={
|
||||
"code": "UPDATE_FILE_FAILED",
|
||||
"message": "更新文件信息失败"
|
||||
}
|
||||
)
|
||||
|
||||
@router.post("/delete", response_model=ApiResponse)
|
||||
def delete_file(
|
||||
request: FileIdRequest,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""删除文件"""
|
||||
try:
|
||||
file_service = FileService(db)
|
||||
success = file_service.delete_file(request.file_id, request.user_id)
|
||||
|
||||
if success:
|
||||
return ApiResponse(
|
||||
success=True,
|
||||
message="文件删除成功",
|
||||
data={}
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={
|
||||
"code": "DELETE_FILE_FAILED",
|
||||
"message": "文件删除失败"
|
||||
}
|
||||
)
|
||||
except FileNotFoundException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise FileDeleteException(str(e))
|
||||
|
||||
@router.post("/download")
|
||||
def download_file(
|
||||
request: FileIdRequest,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""下载文件"""
|
||||
print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] INFO: [IN] Processing download request: file_id={request.file_id}, user_id={request.user_id}")
|
||||
|
||||
try:
|
||||
file_service = FileService(db)
|
||||
print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] DEBUG: File service created")
|
||||
|
||||
db_file = file_service.get_file_by_id(request.file_id, request.user_id)
|
||||
print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] DEBUG: File found in database: {db_file is not None}")
|
||||
|
||||
if not db_file:
|
||||
print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] ERROR: File not found in database")
|
||||
raise FileNotFoundException()
|
||||
|
||||
# 增加下载次数
|
||||
print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] DEBUG: Incrementing download count")
|
||||
file_service.increment_download_count(request.file_id)
|
||||
|
||||
# 确保使用绝对路径
|
||||
import os
|
||||
absolute_path = os.path.abspath(db_file.file_path)
|
||||
print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] INFO: [PATH] File path: {absolute_path}")
|
||||
|
||||
# 验证文件存在
|
||||
if not os.path.exists(absolute_path):
|
||||
print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] ERROR: File not found on disk: {absolute_path}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail={
|
||||
"code": "FILE_NOT_FOUND_ON_DISK",
|
||||
"message": "upload文件夹与数据库不匹配"
|
||||
}
|
||||
)
|
||||
|
||||
print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] INFO: [OK] File exists on disk")
|
||||
|
||||
# 对于文本文件,直接读取内容返回
|
||||
if db_file.mime_type and db_file.mime_type.startswith('text/'):
|
||||
print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] INFO: [TEXT] Processing text file")
|
||||
|
||||
with open(absolute_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
|
||||
print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] INFO: [READ] Read {len(content)} characters from file")
|
||||
|
||||
from fastapi.responses import Response
|
||||
import urllib.parse
|
||||
# 对文件名进行URL编码以支持中文
|
||||
encoded_filename = urllib.parse.quote(db_file.original_filename.encode('utf-8'))
|
||||
print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] INFO: [SEND] Sending file: {db_file.original_filename}")
|
||||
|
||||
return Response(
|
||||
content=content,
|
||||
media_type=db_file.mime_type,
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"
|
||||
}
|
||||
)
|
||||
else:
|
||||
print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] INFO: [BINARY] Processing binary file with FileResponse")
|
||||
|
||||
# 对于二进制文件,使用FileResponse
|
||||
return FileResponse(
|
||||
path=absolute_path,
|
||||
filename=db_file.original_filename,
|
||||
media_type=db_file.mime_type
|
||||
)
|
||||
|
||||
except FileNotFoundException as e:
|
||||
print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] ERROR: FileNotFoundException: {e}")
|
||||
raise e
|
||||
except Exception as e:
|
||||
print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] ERROR: Download error: {e}")
|
||||
import traceback
|
||||
print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] ERROR: {traceback.format_exc()}")
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={
|
||||
"code": "DOWNLOAD_FILE_FAILED",
|
||||
"message": "文件下载失败"
|
||||
}
|
||||
)
|
||||
|
||||
@router.post("/test-download")
|
||||
def test_download_file():
|
||||
"""测试下载功能 - 直接返回5KB文件内容"""
|
||||
print("[TEST_DOWNLOAD] Test endpoint called", flush=True)
|
||||
return {"message": "Test endpoint is working"}
|
||||
|
||||
@router.post("/simple-download")
|
||||
def simple_download_file():
|
||||
"""简单下载测试"""
|
||||
try:
|
||||
print("[SIMPLE_DOWNLOAD] Simple download endpoint called", flush=True)
|
||||
|
||||
# 直接读取我们创建的文件
|
||||
file_path = "uploads/verified_5kb_download_test.txt"
|
||||
|
||||
import os
|
||||
if not os.path.exists(file_path):
|
||||
print(f"[SIMPLE_DOWNLOAD] File not found: {file_path}", flush=True)
|
||||
return {"error": "File not found"}
|
||||
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
|
||||
print(f"[SIMPLE_DOWNLOAD] Read {len(content)} characters", flush=True)
|
||||
|
||||
from fastapi.responses import Response
|
||||
return Response(
|
||||
content=content,
|
||||
media_type="text/plain",
|
||||
headers={"Content-Disposition": "attachment; filename=verified_5kb.txt"}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"[SIMPLE_DOWNLOAD] Exception: {e}", flush=True)
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return {"error": str(e)}
|
||||
|
||||
@router.post("/storage/info", response_model=ApiResponse)
|
||||
def get_storage_info(
|
||||
request: StorageInfoRequest,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取用户存储信息"""
|
||||
try:
|
||||
file_service = FileService(db)
|
||||
storage_info = file_service.get_storage_info(request.user_id)
|
||||
|
||||
return ApiResponse(
|
||||
success=True,
|
||||
message="获取存储信息成功",
|
||||
data=storage_info.model_dump()
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={
|
||||
"code": "GET_STORAGE_INFO_FAILED",
|
||||
"message": "获取存储信息失败"
|
||||
}
|
||||
)
|
||||
|
||||
81
backend/app/api/v1/endpoints/health.py
Normal file
81
backend/app/api/v1/endpoints/health.py
Normal file
@@ -0,0 +1,81 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from app.core.database import get_db
|
||||
from app.core.config import settings
|
||||
import redis
|
||||
import time
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/health")
|
||||
async def health_check():
|
||||
"""基础健康检查"""
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
"status": "healthy",
|
||||
"service": "cloud-drive-api",
|
||||
"environment": settings.ENVIRONMENT,
|
||||
"timestamp": int(time.time())
|
||||
},
|
||||
"message": "API服务运行正常"
|
||||
}
|
||||
|
||||
@router.get("/ready")
|
||||
async def readiness_check(db: Session = Depends(get_db)):
|
||||
"""就绪检查 - 检查数据库和Redis连接"""
|
||||
checks = {}
|
||||
|
||||
# 检查数据库连接
|
||||
try:
|
||||
db.execute("SELECT 1")
|
||||
checks["database"] = {
|
||||
"status": "healthy",
|
||||
"message": "数据库连接正常"
|
||||
}
|
||||
except Exception as e:
|
||||
checks["database"] = {
|
||||
"status": "unhealthy",
|
||||
"message": f"数据库连接失败: {str(e)}"
|
||||
}
|
||||
raise HTTPException(status_code=503, detail="数据库连接失败")
|
||||
|
||||
# 检查Redis连接
|
||||
try:
|
||||
r = redis.from_url(settings.REDIS_URL)
|
||||
r.ping()
|
||||
checks["redis"] = {
|
||||
"status": "healthy",
|
||||
"message": "Redis连接正常"
|
||||
}
|
||||
except Exception as e:
|
||||
checks["redis"] = {
|
||||
"status": "unhealthy",
|
||||
"message": f"Redis连接失败: {str(e)}"
|
||||
}
|
||||
raise HTTPException(status_code=503, detail="Redis连接失败")
|
||||
|
||||
# 检查文件存储
|
||||
try:
|
||||
import os
|
||||
os.makedirs(settings.UPLOAD_DIR, exist_ok=True)
|
||||
checks["storage"] = {
|
||||
"status": "healthy",
|
||||
"message": "文件存储正常"
|
||||
}
|
||||
except Exception as e:
|
||||
checks["storage"] = {
|
||||
"status": "unhealthy",
|
||||
"message": f"文件存储失败: {str(e)}"
|
||||
}
|
||||
raise HTTPException(status_code=503, detail="文件存储失败")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
"status": "ready",
|
||||
"checks": checks,
|
||||
"timestamp": int(time.time())
|
||||
},
|
||||
"message": "所有服务已就绪"
|
||||
}
|
||||
0
backend/app/core/__init__.py
Normal file
0
backend/app/core/__init__.py
Normal file
54
backend/app/core/config.py
Normal file
54
backend/app/core/config.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
from typing import List
|
||||
import os
|
||||
|
||||
class Settings(BaseSettings):
|
||||
# 基础配置
|
||||
ENVIRONMENT: str = "development"
|
||||
DEBUG: bool = True
|
||||
|
||||
# 数据库配置
|
||||
DATABASE_URL: str = "mysql+pymysql://mytest_db:mytest_db@101.126.85.76:3306/mytest_db"
|
||||
|
||||
# Redis配置
|
||||
REDIS_URL: str = "redis://localhost:6379"
|
||||
|
||||
# JWT配置
|
||||
JWT_SECRET_KEY: str = "your-super-secret-jwt-key-change-in-production"
|
||||
JWT_ALGORITHM: str = "HS256"
|
||||
JWT_EXPIRE_MINUTES: int = 30
|
||||
JWT_REFRESH_EXPIRE_DAYS: int = 7
|
||||
|
||||
# CORS配置
|
||||
ALLOWED_HOSTS: List[str] = ["*"] # 允许所有域名访问
|
||||
|
||||
# 文件上传配置
|
||||
MAX_FILE_SIZE: int = 10 * 1024 * 1024 # 10MB
|
||||
UPLOAD_DIR: str = "uploads"
|
||||
ALLOWED_EXTENSIONS: List[str] = [
|
||||
# 图片
|
||||
".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp", ".svg",
|
||||
# 文档
|
||||
".pdf", ".doc", ".docx", ".xls", ".xlsx", ".ppt", ".pptx",
|
||||
".txt", ".rtf", ".csv",
|
||||
# 压缩文件
|
||||
".zip", ".rar", ".7z", ".tar", ".gz",
|
||||
# 音频
|
||||
".mp3", ".wav", ".flac", ".aac", ".ogg",
|
||||
# 视频
|
||||
".mp4", ".avi", ".mkv", ".mov", ".wmv", ".flv",
|
||||
# 代码文件
|
||||
".py", ".js", ".html", ".css", ".json", ".xml", ".yaml", ".yml",
|
||||
".java", ".cpp", ".c", ".h", ".cs", ".php", ".rb", ".go",
|
||||
".sql", ".sh", ".bat", ".ps1", ".md", ".log"
|
||||
]
|
||||
|
||||
# 安全配置
|
||||
BCRYPT_ROUNDS: int = 12
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
case_sensitive = True
|
||||
extra = "allow" # 允许额外的环境变量
|
||||
|
||||
settings = Settings()
|
||||
30
backend/app/core/database.py
Normal file
30
backend/app/core/database.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from app.core.config import settings
|
||||
import pymysql
|
||||
|
||||
# 安装pymysql作为MySQLdb的替代
|
||||
pymysql.install_as_MySQLdb()
|
||||
|
||||
# 创建数据库引擎
|
||||
engine = create_engine(
|
||||
settings.DATABASE_URL,
|
||||
echo=settings.DEBUG,
|
||||
pool_pre_ping=True,
|
||||
pool_recycle=300
|
||||
)
|
||||
|
||||
# 创建会话工厂
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
# 创建基础模型类
|
||||
Base = declarative_base()
|
||||
|
||||
# 数据库依赖
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
151
backend/app/core/security.py
Normal file
151
backend/app/core/security.py
Normal file
@@ -0,0 +1,151 @@
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Union
|
||||
from jose import JWTError, jwt
|
||||
import bcrypt
|
||||
from app.core.config import settings
|
||||
from app.core.token_blacklist import token_blacklist
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""创建访问令牌"""
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(minutes=settings.JWT_EXPIRE_MINUTES)
|
||||
|
||||
to_encode.update({"exp": expire, "type": "access"})
|
||||
encoded_jwt = jwt.encode(to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
def create_refresh_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""创建刷新令牌"""
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(days=settings.JWT_REFRESH_EXPIRE_DAYS)
|
||||
|
||||
to_encode.update({"exp": expire, "type": "refresh"})
|
||||
encoded_jwt = jwt.encode(to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
def verify_token(token: str, token_type: str = "access") -> Optional[dict]:
|
||||
"""验证令牌"""
|
||||
try:
|
||||
# 首先检查令牌是否在黑名单中
|
||||
if token_blacklist.is_blacklisted(token):
|
||||
return None
|
||||
|
||||
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
|
||||
|
||||
# 检查令牌类型
|
||||
if payload.get("type") != token_type:
|
||||
return None
|
||||
|
||||
return payload
|
||||
except JWTError:
|
||||
return None
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""验证密码"""
|
||||
try:
|
||||
return bcrypt.checkpw(plain_password.encode('utf-8'), hashed_password.encode('utf-8'))
|
||||
except:
|
||||
return False
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
"""获取密码哈希"""
|
||||
# bcrypt 限制密码长度为72字节,如果超过则截断
|
||||
if len(password.encode('utf-8')) > 72:
|
||||
password = password.encode('utf-8')[:72].decode('utf-8', errors='ignore')
|
||||
salt = bcrypt.gensalt()
|
||||
return bcrypt.hashpw(password.encode('utf-8'), salt).decode('utf-8')
|
||||
|
||||
def create_password_reset_token(email: str) -> str:
|
||||
"""创建密码重置令牌"""
|
||||
delta = timedelta(hours=1) # 1小时有效期
|
||||
now = datetime.utcnow()
|
||||
expires = now + delta
|
||||
exp = expires.timestamp()
|
||||
encoded_jwt = jwt.encode(
|
||||
{"exp": exp, "nbf": now, "sub": email, "type": "password_reset"},
|
||||
settings.JWT_SECRET_KEY,
|
||||
algorithm=settings.JWT_ALGORITHM,
|
||||
)
|
||||
return encoded_jwt
|
||||
|
||||
def verify_password_reset_token(token: str) -> Optional[str]:
|
||||
"""验证密码重置令牌"""
|
||||
try:
|
||||
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
|
||||
|
||||
# 检查令牌类型
|
||||
if payload.get("type") != "password_reset":
|
||||
return None
|
||||
|
||||
return payload["sub"]
|
||||
except JWTError:
|
||||
return None
|
||||
|
||||
# 密码强度验证
|
||||
def validate_password_strength(password: str) -> dict:
|
||||
"""验证密码强度"""
|
||||
errors = []
|
||||
|
||||
if len(password) < 8:
|
||||
errors.append("密码长度至少8位")
|
||||
|
||||
if len(password) > 128:
|
||||
errors.append("密码长度不能超过128位")
|
||||
|
||||
if not any(c.islower() for c in password):
|
||||
errors.append("密码必须包含至少一个小写字母")
|
||||
|
||||
if not any(c.isupper() for c in password):
|
||||
errors.append("密码必须包含至少一个大写字母")
|
||||
|
||||
if not any(c.isdigit() for c in password):
|
||||
errors.append("密码必须包含至少一个数字")
|
||||
|
||||
# 检查特殊字符
|
||||
special_chars = "!@#$%^&*()_+-=[]{}|;:,.<>?"
|
||||
if not any(c in special_chars for c in password):
|
||||
errors.append("密码必须包含至少一个特殊字符")
|
||||
|
||||
return {
|
||||
"is_valid": len(errors) == 0,
|
||||
"errors": errors,
|
||||
"strength": calculate_password_strength(password)
|
||||
}
|
||||
|
||||
def calculate_password_strength(password: str) -> str:
|
||||
"""计算密码强度"""
|
||||
score = 0
|
||||
|
||||
# 长度评分
|
||||
if len(password) >= 8:
|
||||
score += 1
|
||||
if len(password) >= 12:
|
||||
score += 1
|
||||
if len(password) >= 16:
|
||||
score += 1
|
||||
|
||||
# 字符类型评分
|
||||
if any(c.islower() for c in password):
|
||||
score += 1
|
||||
if any(c.isupper() for c in password):
|
||||
score += 1
|
||||
if any(c.isdigit() for c in password):
|
||||
score += 1
|
||||
if any(c in "!@#$%^&*()_+-=[]{}|;:,.<>?" for c in password):
|
||||
score += 1
|
||||
|
||||
# 根据评分返回强度等级
|
||||
if score <= 2:
|
||||
return "弱"
|
||||
elif score <= 4:
|
||||
return "中等"
|
||||
elif score <= 6:
|
||||
return "强"
|
||||
else:
|
||||
return "非常强"
|
||||
46
backend/app/core/token_blacklist.py
Normal file
46
backend/app/core/token_blacklist.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Optional
|
||||
import threading
|
||||
|
||||
class TokenBlacklist:
|
||||
"""简单的令牌黑名单(内存存储)"""
|
||||
|
||||
def __init__(self):
|
||||
self._blacklisted_tokens: Dict[str, datetime] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def add_token(self, token: str, expires_at: Optional[datetime] = None):
|
||||
"""添加令牌到黑名单"""
|
||||
with self._lock:
|
||||
# 如果没有提供过期时间,默认24小时后过期
|
||||
if expires_at is None:
|
||||
expires_at = datetime.utcnow() + timedelta(hours=24)
|
||||
self._blacklisted_tokens[token] = expires_at
|
||||
|
||||
def is_blacklisted(self, token: str) -> bool:
|
||||
"""检查令牌是否在黑名单中"""
|
||||
with self._lock:
|
||||
if token not in self._blacklisted_tokens:
|
||||
return False
|
||||
|
||||
# 检查令牌是否已过期
|
||||
if datetime.utcnow() > self._blacklisted_tokens[token]:
|
||||
# 清理过期的令牌
|
||||
del self._blacklisted_tokens[token]
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def cleanup_expired_tokens(self):
|
||||
"""清理过期的令牌"""
|
||||
with self._lock:
|
||||
current_time = datetime.utcnow()
|
||||
expired_tokens = [
|
||||
token for token, expires_at in self._blacklisted_tokens.items()
|
||||
if current_time > expires_at
|
||||
]
|
||||
for token in expired_tokens:
|
||||
del self._blacklisted_tokens[token]
|
||||
|
||||
# 全局黑名单实例
|
||||
token_blacklist = TokenBlacklist()
|
||||
0
backend/app/dependencies/__init__.py
Normal file
0
backend/app/dependencies/__init__.py
Normal file
217
backend/app/dependencies/auth.py
Normal file
217
backend/app/dependencies/auth.py
Normal file
@@ -0,0 +1,217 @@
|
||||
from fastapi import Depends, HTTPException, status, Request
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.security import verify_token
|
||||
from app.services.user_service import UserService
|
||||
from app.schemas.auth import UserResponse
|
||||
from app.models.user import User
|
||||
|
||||
# Bearer token 认证方案
|
||||
security = HTTPBearer()
|
||||
|
||||
async def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
db: Session = Depends(get_db)
|
||||
) -> User:
|
||||
"""获取当前认证用户"""
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail={
|
||||
"code": "INVALID_AUTHENTICATION",
|
||||
"message": "无法验证凭据"
|
||||
},
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
try:
|
||||
# 验证令牌
|
||||
payload = verify_token(credentials.credentials, "access")
|
||||
if payload is None:
|
||||
raise credentials_exception
|
||||
|
||||
# 获取用户ID
|
||||
user_id: str = payload.get("sub")
|
||||
if user_id is None:
|
||||
raise credentials_exception
|
||||
|
||||
user_id = int(user_id)
|
||||
|
||||
except (ValueError, TypeError):
|
||||
raise credentials_exception
|
||||
|
||||
# 获取用户信息
|
||||
user_service = UserService(db)
|
||||
user = user_service.get_user_by_id(user_id)
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
|
||||
return user
|
||||
|
||||
async def get_current_user_from_headers(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db)
|
||||
) -> User:
|
||||
"""从请求头中获取当前认证用户(支持userId和token)"""
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail={
|
||||
"code": "INVALID_AUTHENTICATION",
|
||||
"message": "无法验证凭据"
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
# 尝试从多种方式获取token
|
||||
token = None
|
||||
user_id = None
|
||||
|
||||
# 1. 从Authorization header获取
|
||||
authorization = request.headers.get("Authorization")
|
||||
if authorization and authorization.startswith("Bearer "):
|
||||
token = authorization.split(" ")[1]
|
||||
|
||||
# 2. 从token header获取
|
||||
if not token:
|
||||
token = request.headers.get("token")
|
||||
|
||||
# 3. 从userId header获取用户ID
|
||||
user_id_str = request.headers.get("userId")
|
||||
if user_id_str:
|
||||
try:
|
||||
user_id = int(user_id_str)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
# 如果没有token,认证失败
|
||||
if not token:
|
||||
raise credentials_exception
|
||||
|
||||
# 验证令牌
|
||||
payload = verify_token(token, "access")
|
||||
if payload is None:
|
||||
raise credentials_exception
|
||||
|
||||
# 获取token中的用户ID
|
||||
token_user_id: str = payload.get("sub")
|
||||
if token_user_id is None:
|
||||
raise credentials_exception
|
||||
|
||||
token_user_id = int(token_user_id)
|
||||
|
||||
# 如果header中有userId,验证两个ID是否一致
|
||||
if user_id is not None and user_id != token_user_id:
|
||||
raise credentials_exception
|
||||
|
||||
# 使用token中的用户ID
|
||||
final_user_id = token_user_id
|
||||
|
||||
except (ValueError, TypeError):
|
||||
raise credentials_exception
|
||||
|
||||
# 获取用户信息
|
||||
user_service = UserService(db)
|
||||
user = user_service.get_user_by_id(final_user_id)
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
|
||||
return user
|
||||
|
||||
async def get_current_active_user(
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> User:
|
||||
"""获取当前活跃用户"""
|
||||
if not current_user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"code": "INACTIVE_USER",
|
||||
"message": "用户账户已被禁用"
|
||||
}
|
||||
)
|
||||
return current_user
|
||||
|
||||
async def get_current_active_user_from_headers(
|
||||
current_user: User = Depends(get_current_user_from_headers)
|
||||
) -> User:
|
||||
"""从请求头获取当前活跃用户"""
|
||||
if not current_user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"code": "INACTIVE_USER",
|
||||
"message": "用户账户已被禁用"
|
||||
}
|
||||
)
|
||||
return current_user
|
||||
|
||||
async def get_current_verified_user(
|
||||
current_user: User = Depends(get_current_active_user)
|
||||
) -> User:
|
||||
"""获取当前已验证用户"""
|
||||
if not current_user.is_verified:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"code": "UNVERIFIED_USER",
|
||||
"message": "用户账户未验证"
|
||||
}
|
||||
)
|
||||
return current_user
|
||||
|
||||
async def get_current_user_response(
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> UserResponse:
|
||||
"""获取当前用户信息(响应格式)"""
|
||||
user_service = UserService(db)
|
||||
return user_service.to_user_response(current_user)
|
||||
|
||||
async def get_current_user_response_from_headers(
|
||||
current_user: User = Depends(get_current_active_user_from_headers),
|
||||
db: Session = Depends(get_db)
|
||||
) -> UserResponse:
|
||||
"""从请求头获取当前用户信息(响应格式)"""
|
||||
user_service = UserService(db)
|
||||
return user_service.to_user_response(current_user)
|
||||
|
||||
# 可选的认证依赖项(不强制要求认证)
|
||||
async def get_optional_current_user(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
db: Session = Depends(get_db)
|
||||
) -> User | None:
|
||||
"""获取可选的当前用户(认证失败时不抛出异常)"""
|
||||
try:
|
||||
return await get_current_user(credentials, db)
|
||||
except HTTPException:
|
||||
return None
|
||||
|
||||
# 权限检查函数
|
||||
def check_user_permission(required_permission: str = None):
|
||||
"""检查用户权限的装饰器工厂"""
|
||||
async def permission_checker(
|
||||
current_user: User = Depends(get_current_active_user)
|
||||
) -> User:
|
||||
# 这里可以根据需要实现更复杂的权限检查逻辑
|
||||
# 目前只检查用户是否活跃
|
||||
if not current_user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail={
|
||||
"code": "PERMISSION_DENIED",
|
||||
"message": "权限不足"
|
||||
}
|
||||
)
|
||||
return current_user
|
||||
|
||||
return permission_checker
|
||||
|
||||
# 管理员权限检查(预留)
|
||||
async def get_admin_user(
|
||||
current_user: User = Depends(get_current_active_user)
|
||||
) -> User:
|
||||
"""获取管理员用户(预留功能)"""
|
||||
# 这里可以添加管理员权限检查逻辑
|
||||
# 例如检查用户是否有管理员角色
|
||||
return current_user
|
||||
1
backend/app/exceptions/__init__.py
Normal file
1
backend/app/exceptions/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Exception classes for the application
|
||||
23
backend/app/exceptions/auth.py
Normal file
23
backend/app/exceptions/auth.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
class EmailAlreadyExistsException(HTTPException):
|
||||
"""邮箱已存在异常"""
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"code": "EMAIL_EXISTS",
|
||||
"message": "邮箱已被注册"
|
||||
}
|
||||
)
|
||||
|
||||
class UsernameAlreadyExistsException(HTTPException):
|
||||
"""用户名已存在异常"""
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"code": "USERNAME_EXISTS",
|
||||
"message": "用户名已存在"
|
||||
}
|
||||
)
|
||||
92
backend/app/exceptions/file.py
Normal file
92
backend/app/exceptions/file.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
|
||||
class FileTooLargeException(HTTPException):
|
||||
"""文件过大异常"""
|
||||
def __init__(self, file_size: int, max_size: int):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
|
||||
detail={
|
||||
"code": "FILE_TOO_LARGE",
|
||||
"message": f"文件大小 {file_size} 字节超过限制 {max_size} 字节",
|
||||
"file_size": file_size,
|
||||
"max_size": max_size
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class StorageQuotaExceededException(HTTPException):
|
||||
"""存储配额超限异常"""
|
||||
def __init__(self, used_space: int, quota: int, required_space: int):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
|
||||
detail={
|
||||
"code": "STORAGE_QUOTA_EXCEEDED",
|
||||
"message": f"存储空间不足。已使用: {used_space} 字节,配额: {quota} 字节,需要: {required_space} 字节",
|
||||
"used_space": used_space,
|
||||
"quota": quota,
|
||||
"required_space": required_space
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class FileAlreadyExistsException(HTTPException):
|
||||
"""文件已存在异常"""
|
||||
def __init__(self, filename: str):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail={
|
||||
"code": "FILE_ALREADY_EXISTS",
|
||||
"message": f"文件 '{filename}' 已存在",
|
||||
"filename": filename
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class FileNotFoundException(HTTPException):
|
||||
"""文件未找到异常"""
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail={
|
||||
"code": "FILE_NOT_FOUND",
|
||||
"message": "文件不存在"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class InvalidFileTypeException(HTTPException):
|
||||
"""无效文件类型异常"""
|
||||
def __init__(self, file_extension: str):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"code": "INVALID_FILE_TYPE",
|
||||
"message": f"不支持的文件类型: {file_extension}",
|
||||
"file_extension": file_extension
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class FileUploadException(HTTPException):
|
||||
"""文件上传异常"""
|
||||
def __init__(self, message: str):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={
|
||||
"code": "FILE_UPLOAD_FAILED",
|
||||
"message": f"文件上传失败: {message}"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class FileDeleteException(HTTPException):
|
||||
"""文件删除异常"""
|
||||
def __init__(self, message: str):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={
|
||||
"code": "FILE_DELETE_FAILED",
|
||||
"message": f"文件删除失败: {message}"
|
||||
}
|
||||
)
|
||||
4
backend/app/models/__init__.py
Normal file
4
backend/app/models/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .user import User
|
||||
from .file import File
|
||||
|
||||
__all__ = ["User", "File"]
|
||||
86
backend/app/models/file.py
Normal file
86
backend/app/models/file.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from sqlalchemy import Column, Integer, String, BigInteger, DateTime, ForeignKey, Boolean, Text
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql import func
|
||||
from app.core.database import Base
|
||||
|
||||
class File(Base):
|
||||
__tablename__ = "files"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
|
||||
|
||||
# 文件基本信息
|
||||
filename = Column(String(255), nullable=False, index=True)
|
||||
original_filename = Column(String(255), nullable=False) # 用户上传时的原始文件名
|
||||
file_path = Column(String(500), nullable=False) # 服务器上的存储路径
|
||||
file_size = Column(BigInteger, nullable=False) # 文件大小(字节)
|
||||
mime_type = Column(String(100), nullable=False) # 文件MIME类型
|
||||
file_hash = Column(String(64), nullable=False, index=True) # SHA-256哈希,用于去重和完整性检查
|
||||
|
||||
# 文件状态
|
||||
is_public = Column(Boolean, default=False) # 是否公开分享
|
||||
download_count = Column(BigInteger, default=0) # 下载次数
|
||||
|
||||
# 文件元数据
|
||||
description = Column(Text, nullable=True) # 文件描述
|
||||
tags = Column(Text, nullable=True) # 标签,用逗号分隔
|
||||
|
||||
# 时间戳
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
|
||||
last_accessed_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# 关联关系
|
||||
user = relationship("User", back_populates="files")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<File(id={self.id}, filename='{self.filename}', user_id={self.user_id})>"
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"id": self.id,
|
||||
"user_id": self.user_id,
|
||||
"filename": self.filename,
|
||||
"original_filename": self.original_filename,
|
||||
"file_size": self.file_size,
|
||||
"mime_type": self.mime_type,
|
||||
"file_hash": self.file_hash,
|
||||
"is_public": self.is_public,
|
||||
"download_count": self.download_count,
|
||||
"description": self.description,
|
||||
"tags": self.tags,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
||||
"last_accessed_at": self.last_accessed_at.isoformat() if self.last_accessed_at else None,
|
||||
}
|
||||
|
||||
def get_file_extension(self) -> str:
|
||||
"""获取文件扩展名"""
|
||||
return self.filename.split('.')[-1].lower() if '.' in self.filename else ''
|
||||
|
||||
def is_image(self) -> bool:
|
||||
"""判断是否为图片文件"""
|
||||
return self.mime_type.startswith('image/')
|
||||
|
||||
def is_document(self) -> bool:
|
||||
"""判断是否为文档文件"""
|
||||
document_types = [
|
||||
'application/pdf',
|
||||
'application/msword',
|
||||
'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
|
||||
'application/vnd.ms-excel',
|
||||
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
|
||||
'application/vnd.ms-powerpoint',
|
||||
'application/vnd.openxmlformats-officedocument.presentationml.presentation',
|
||||
'text/plain',
|
||||
'text/csv'
|
||||
]
|
||||
return self.mime_type in document_types
|
||||
|
||||
def get_size_formatted(self) -> str:
|
||||
"""获取格式化的文件大小"""
|
||||
for unit in ['B', 'KB', 'MB', 'GB']:
|
||||
if self.file_size < 1024.0:
|
||||
return f"{self.file_size:.1f} {unit}"
|
||||
self.file_size /= 1024.0
|
||||
return f"{self.file_size:.1f} TB"
|
||||
59
backend/app/models/user.py
Normal file
59
backend/app/models/user.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from sqlalchemy import Column, Integer, String, Boolean, DateTime, BigInteger, Text
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql import func
|
||||
from app.core.database import Base
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
username = Column(String(50), unique=True, index=True, nullable=False)
|
||||
email = Column(String(100), index=True, nullable=False)
|
||||
password_hash = Column(String(255), nullable=False)
|
||||
|
||||
# 用户资料
|
||||
avatar_url = Column(String(500), nullable=True)
|
||||
|
||||
# 存储配额
|
||||
storage_quota = Column(BigInteger, default=104857600) # 100MB in bytes
|
||||
storage_used = Column(BigInteger, default=0)
|
||||
|
||||
# 用户状态
|
||||
is_active = Column(Boolean, default=True)
|
||||
is_verified = Column(Boolean, default=False)
|
||||
|
||||
# 时间戳
|
||||
last_login_at = Column(DateTime(timezone=True), nullable=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
|
||||
|
||||
# 关联关系
|
||||
files = relationship("File", back_populates="user")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<User(id={self.id}, username='{self.username}', email='{self.email}')>"
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"id": self.id,
|
||||
"username": self.username,
|
||||
"email": self.email,
|
||||
"avatar_url": self.avatar_url,
|
||||
"storage_quota": self.storage_quota,
|
||||
"storage_used": self.storage_used,
|
||||
"is_active": self.is_active,
|
||||
"is_verified": self.is_verified,
|
||||
"last_login_at": self.last_login_at.isoformat() if self.last_login_at else None,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
||||
}
|
||||
|
||||
def is_storage_available(self, required_size: int) -> bool:
|
||||
"""检查是否有足够的存储空间"""
|
||||
return (self.storage_used + required_size) <= self.storage_quota
|
||||
|
||||
def get_storage_percentage(self) -> float:
|
||||
"""获取已使用存储空间的百分比"""
|
||||
if self.storage_quota == 0:
|
||||
return 0.0
|
||||
return (self.storage_used / self.storage_quota) * 100
|
||||
2
backend/app/schemas/__init__.py
Normal file
2
backend/app/schemas/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .auth import *
|
||||
from .file import *
|
||||
165
backend/app/schemas/auth.py
Normal file
165
backend/app/schemas/auth.py
Normal file
@@ -0,0 +1,165 @@
|
||||
from pydantic import BaseModel, EmailStr, Field, validator
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
import re
|
||||
|
||||
class UserRegister(BaseModel):
|
||||
"""用户注册请求模型"""
|
||||
username: str = Field(..., min_length=3, max_length=50, description="用户名")
|
||||
email: EmailStr = Field(..., description="邮箱地址")
|
||||
password: str = Field(..., min_length=6, max_length=128, description="密码")
|
||||
confirm_password: str = Field(..., min_length=6, max_length=128, description="确认密码")
|
||||
|
||||
@validator('username')
|
||||
def validate_username(cls, v):
|
||||
"""验证用户名格式"""
|
||||
if not re.match(r'^[a-zA-Z0-9_]+$', v):
|
||||
raise ValueError('用户名只能包含字母、数字和下划线')
|
||||
if v.startswith('_') or v.endswith('_'):
|
||||
raise ValueError('用户名不能以下划线开头或结尾')
|
||||
return v
|
||||
|
||||
@validator('confirm_password')
|
||||
def passwords_match(cls, v, values):
|
||||
"""验证密码确认"""
|
||||
if 'password' in values and v != values['password']:
|
||||
raise ValueError('两次输入的密码不一致')
|
||||
return v
|
||||
|
||||
@validator('password')
|
||||
def validate_password_length(cls, v):
|
||||
"""验证密码长度"""
|
||||
if len(v) <= 5:
|
||||
raise ValueError('密码长度必须大于5个字符')
|
||||
return v
|
||||
|
||||
class UserLogin(BaseModel):
|
||||
"""用户登录请求模型"""
|
||||
username: str = Field(..., description="用户名或邮箱")
|
||||
password: str = Field(..., min_length=1, description="密码")
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
"""令牌响应模型"""
|
||||
access_token: str = Field(..., description="访问令牌")
|
||||
refresh_token: str = Field(..., description="刷新令牌")
|
||||
token_type: str = Field(default="bearer", description="令牌类型")
|
||||
expires_in: int = Field(..., description="访问令牌过期时间(秒)")
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
"""用户信息响应模型"""
|
||||
id: int
|
||||
username: str
|
||||
email: str
|
||||
avatar_url: Optional[str] = None
|
||||
storage_quota: int
|
||||
storage_used: int
|
||||
is_active: bool
|
||||
is_verified: bool
|
||||
last_login_at: Optional[datetime] = None
|
||||
created_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class LoginResponse(BaseModel):
|
||||
"""登录响应模型"""
|
||||
user: UserResponse = Field(..., description="用户信息")
|
||||
tokens: TokenResponse = Field(..., description="令牌信息")
|
||||
|
||||
class TokenRefresh(BaseModel):
|
||||
"""令牌刷新请求模型"""
|
||||
refresh_token: str = Field(..., description="刷新令牌")
|
||||
|
||||
class PasswordChange(BaseModel):
|
||||
"""修改密码请求模型"""
|
||||
current_password: str = Field(..., description="当前密码")
|
||||
new_password: str = Field(..., min_length=8, max_length=128, description="新密码")
|
||||
confirm_password: str = Field(..., min_length=8, max_length=128, description="确认新密码")
|
||||
|
||||
@validator('confirm_password')
|
||||
def passwords_match(cls, v, values):
|
||||
"""验证密码确认"""
|
||||
if 'new_password' in values and v != values['new_password']:
|
||||
raise ValueError('密码确认不匹配')
|
||||
return v
|
||||
|
||||
@validator('new_password')
|
||||
def validate_password_strength(cls, v):
|
||||
"""验证密码强度"""
|
||||
errors = []
|
||||
|
||||
if len(v) < 8:
|
||||
errors.append("密码长度至少8位")
|
||||
|
||||
if not any(c.islower() for c in v):
|
||||
errors.append("密码必须包含至少一个小写字母")
|
||||
|
||||
if not any(c.isupper() for c in v):
|
||||
errors.append("密码必须包含至少一个大写字母")
|
||||
|
||||
if not any(c.isdigit() for c in v):
|
||||
errors.append("密码必须包含至少一个数字")
|
||||
|
||||
special_chars = "!@#$%^&*()_+-=[]{}|;:,.<>?"
|
||||
if not any(c in special_chars for c in v):
|
||||
errors.append("密码必须包含至少一个特殊字符")
|
||||
|
||||
if errors:
|
||||
raise ValueError('; '.join(errors))
|
||||
|
||||
return v
|
||||
|
||||
class PasswordReset(BaseModel):
|
||||
"""密码重置请求模型"""
|
||||
email: EmailStr = Field(..., description="邮箱地址")
|
||||
|
||||
class PasswordResetConfirm(BaseModel):
|
||||
"""密码重置确认模型"""
|
||||
token: str = Field(..., description="重置令牌")
|
||||
new_password: str = Field(..., min_length=8, max_length=128, description="新密码")
|
||||
confirm_password: str = Field(..., min_length=8, max_length=128, description="确认新密码")
|
||||
|
||||
@validator('confirm_password')
|
||||
def passwords_match(cls, v, values):
|
||||
"""验证密码确认"""
|
||||
if 'new_password' in values and v != values['new_password']:
|
||||
raise ValueError('密码确认不匹配')
|
||||
return v
|
||||
|
||||
@validator('new_password')
|
||||
def validate_password_strength(cls, v):
|
||||
"""验证密码强度"""
|
||||
errors = []
|
||||
|
||||
if len(v) < 8:
|
||||
errors.append("密码长度至少8位")
|
||||
|
||||
if not any(c.islower() for c in v):
|
||||
errors.append("密码必须包含至少一个小写字母")
|
||||
|
||||
if not any(c.isupper() for c in v):
|
||||
errors.append("密码必须包含至少一个大写字母")
|
||||
|
||||
if not any(c.isdigit() for c in v):
|
||||
errors.append("密码必须包含至少一个数字")
|
||||
|
||||
special_chars = "!@#$%^&*()_+-=[]{}|;:,.<>?"
|
||||
if not any(c in special_chars for c in v):
|
||||
errors.append("密码必须包含至少一个特殊字符")
|
||||
|
||||
if errors:
|
||||
raise ValueError('; '.join(errors))
|
||||
|
||||
return v
|
||||
|
||||
class ApiResponse(BaseModel):
|
||||
"""标准API响应模型"""
|
||||
success: bool = Field(..., description="操作是否成功")
|
||||
message: str = Field(..., description="响应消息")
|
||||
data: Optional[dict] = Field(None, description="响应数据")
|
||||
error: Optional[dict] = Field(None, description="错误信息")
|
||||
|
||||
class Config:
|
||||
json_encoders = {
|
||||
datetime: lambda v: v.isoformat() if v else None
|
||||
}
|
||||
148
backend/app/schemas/file.py
Normal file
148
backend/app/schemas/file.py
Normal file
@@ -0,0 +1,148 @@
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
import re
|
||||
|
||||
class FileUploadRequest(BaseModel):
|
||||
"""文件上传请求"""
|
||||
description: Optional[str] = Field(None, max_length=500, description="文件描述")
|
||||
tags: Optional[str] = Field(None, max_length=200, description="文件标签,用逗号分隔")
|
||||
is_public: bool = Field(False, description="是否公开分享")
|
||||
|
||||
@validator('tags')
|
||||
def validate_tags(cls, v):
|
||||
if v:
|
||||
# 验证标签格式,只允许字母、数字、中文、下划线、中划线
|
||||
tags = v.split(',')
|
||||
for tag in tags:
|
||||
tag = tag.strip()
|
||||
if not re.match(r'^[\w\u4e00-\u9fa5-]+$', tag):
|
||||
raise ValueError(f"标签 '{tag}' 格式不正确,只允许字母、数字、中文、下划线、中划线")
|
||||
if len(tag) > 20:
|
||||
raise ValueError(f"标签 '{tag}' 长度不能超过20个字符")
|
||||
return v
|
||||
|
||||
class FileResponse(BaseModel):
|
||||
"""文件响应"""
|
||||
id: int
|
||||
user_id: int
|
||||
filename: str
|
||||
original_filename: str
|
||||
file_size: int
|
||||
mime_type: str
|
||||
file_hash: str
|
||||
is_public: bool
|
||||
download_count: int
|
||||
description: Optional[str] = None
|
||||
tags: Optional[str] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
last_accessed_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class FileListResponse(BaseModel):
|
||||
"""文件列表响应"""
|
||||
files: List[FileResponse]
|
||||
total: int
|
||||
page: int
|
||||
size: int
|
||||
pages: int
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class FileListRequest(BaseModel):
|
||||
"""文件列表请求"""
|
||||
user_id: int = Field(..., description="用户ID")
|
||||
page: int = Field(1, ge=1, description="页码")
|
||||
size: int = Field(20, ge=1, le=100, description="每页数量")
|
||||
|
||||
class FileIdRequest(BaseModel):
|
||||
"""文件ID请求"""
|
||||
user_id: int = Field(..., description="用户ID")
|
||||
file_id: int = Field(..., description="文件ID")
|
||||
|
||||
class StorageInfoRequest(BaseModel):
|
||||
"""存储信息请求"""
|
||||
user_id: int = Field(..., description="用户ID")
|
||||
|
||||
class FileUpdateRequest(BaseModel):
|
||||
"""文件更新请求"""
|
||||
description: Optional[str] = Field(None, max_length=500, description="文件描述")
|
||||
tags: Optional[str] = Field(None, max_length=200, description="文件标签,用逗号分隔")
|
||||
is_public: Optional[bool] = Field(None, description="是否公开分享")
|
||||
|
||||
@validator('tags')
|
||||
def validate_tags(cls, v):
|
||||
if v:
|
||||
tags = v.split(',')
|
||||
for tag in tags:
|
||||
tag = tag.strip()
|
||||
if not re.match(r'^[\w\u4e00-\u9fa5-]+$', tag):
|
||||
raise ValueError(f"标签 '{tag}' 格式不正确,只允许字母、数字、中文、下划线、中划线")
|
||||
if len(tag) > 20:
|
||||
raise ValueError(f"标签 '{tag}' 长度不能超过20个字符")
|
||||
return v
|
||||
|
||||
class FileSearchRequest(BaseModel):
|
||||
"""文件搜索请求"""
|
||||
filename: Optional[str] = Field(None, description="文件名搜索")
|
||||
tags: Optional[str] = Field(None, description="标签搜索,用逗号分隔")
|
||||
mime_type: Optional[str] = Field(None, description="MIME类型过滤")
|
||||
is_public: Optional[bool] = Field(None, description="是否公开文件")
|
||||
start_date: Optional[datetime] = Field(None, description="开始日期")
|
||||
end_date: Optional[datetime] = Field(None, description="结束日期")
|
||||
min_size: Optional[int] = Field(None, ge=0, description="最小文件大小(字节)")
|
||||
max_size: Optional[int] = Field(None, ge=0, description="最大文件大小(字节)")
|
||||
|
||||
class FileInfo(BaseModel):
|
||||
"""文件信息"""
|
||||
id: int
|
||||
filename: str
|
||||
original_filename: str
|
||||
file_size: int
|
||||
mime_type: str
|
||||
file_hash: str
|
||||
is_image: bool
|
||||
is_document: bool
|
||||
file_extension: str
|
||||
size_formatted: str
|
||||
is_public: bool = False
|
||||
download_count: int = 0
|
||||
description: Optional[str] = None
|
||||
tags: Optional[str] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
last_accessed_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class UploadResponse(BaseModel):
|
||||
"""文件上传响应"""
|
||||
file_info: FileResponse
|
||||
message: str
|
||||
success: bool
|
||||
|
||||
class DeleteResponse(BaseModel):
|
||||
"""删除文件响应"""
|
||||
message: str
|
||||
success: bool
|
||||
|
||||
class StorageInfo(BaseModel):
|
||||
"""存储信息"""
|
||||
total_quota: int
|
||||
used_space: int
|
||||
available_space: int
|
||||
usage_percentage: float
|
||||
file_count: int
|
||||
|
||||
# 通用API响应格式
|
||||
class ApiResponse(BaseModel):
|
||||
"""API响应"""
|
||||
success: bool
|
||||
message: str
|
||||
data: Optional[dict] = None
|
||||
code: Optional[str] = None
|
||||
0
backend/app/services/__init__.py
Normal file
0
backend/app/services/__init__.py
Normal file
418
backend/app/services/file_service.py
Normal file
418
backend/app/services/file_service.py
Normal file
@@ -0,0 +1,418 @@
|
||||
import os
|
||||
import hashlib
|
||||
import uuid
|
||||
import logging
|
||||
from typing import Optional, List, Tuple
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_, or_, desc
|
||||
from fastapi import UploadFile, HTTPException, status
|
||||
from datetime import datetime
|
||||
|
||||
from app.models.file import File
|
||||
from app.models.user import User
|
||||
from app.schemas.file import (
|
||||
FileUploadRequest, FileUpdateRequest, FileSearchRequest,
|
||||
FileResponse, FileListResponse, StorageInfo, FileInfo
|
||||
)
|
||||
from app.core.config import settings
|
||||
from app.exceptions.file import (
|
||||
FileTooLargeException, StorageQuotaExceededException,
|
||||
FileAlreadyExistsException, FileNotFoundException,
|
||||
InvalidFileTypeException
|
||||
)
|
||||
|
||||
class FileService:
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
# 使用绝对路径确保文件保存正确
|
||||
self.upload_dir = os.path.abspath(settings.UPLOAD_DIR)
|
||||
self.max_file_size = settings.MAX_FILE_SIZE
|
||||
self.allowed_extensions = settings.ALLOWED_EXTENSIONS
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# 确保上传目录存在
|
||||
os.makedirs(self.upload_dir, exist_ok=True)
|
||||
self.logger.info(f"Upload directory (absolute): {self.upload_dir}")
|
||||
self.logger.info(f"Upload directory exists: {os.path.exists(self.upload_dir)}")
|
||||
self.logger.info(f"Current working directory: {os.getcwd()}")
|
||||
|
||||
def _calculate_file_hash(self, file_content: bytes) -> str:
|
||||
"""计算文件的SHA-256哈希值"""
|
||||
return hashlib.sha256(file_content).hexdigest()
|
||||
|
||||
def _generate_unique_filename(self, original_filename: str) -> str:
|
||||
"""生成唯一的文件名"""
|
||||
file_extension = os.path.splitext(original_filename)[1]
|
||||
unique_id = str(uuid.uuid4())
|
||||
return f"{unique_id}{file_extension}"
|
||||
|
||||
def _validate_file(self, file: UploadFile, user: User) -> None:
|
||||
"""验证文件"""
|
||||
# 检查文件大小
|
||||
if hasattr(file, 'size') and file.size > self.max_file_size:
|
||||
raise FileTooLargeException(file.size, self.max_file_size)
|
||||
|
||||
# 检查文件扩展名
|
||||
if self.allowed_extensions:
|
||||
file_extension = os.path.splitext(file.filename)[1].lower()
|
||||
if file_extension not in self.allowed_extensions:
|
||||
raise InvalidFileTypeException(file_extension)
|
||||
|
||||
# 检查存储配额
|
||||
if hasattr(file, 'size') and not user.is_storage_available(file.size):
|
||||
raise StorageQuotaExceededException(user.storage_used, user.storage_quota, file.size)
|
||||
|
||||
def _save_file_to_disk(self, file: UploadFile, unique_filename: str) -> Tuple[str, bytes]:
|
||||
"""保存文件到磁盘"""
|
||||
file_path = os.path.join(self.upload_dir, unique_filename)
|
||||
|
||||
# 读取文件内容
|
||||
file_content = file.file.read()
|
||||
|
||||
# 强制输出调试信息
|
||||
import sys
|
||||
message = f"[CRITICAL] About to save {len(file_content)} bytes to {file_path}"
|
||||
print(message, flush=True)
|
||||
sys.stdout.flush()
|
||||
message2 = f"[CRITICAL] Content preview: {file_content[:50] if file_content else 'EMPTY'}"
|
||||
print(message2, flush=True)
|
||||
sys.stdout.flush()
|
||||
|
||||
# 立即验证内容是否为空
|
||||
if not file_content:
|
||||
print("[CRITICAL] FILE CONTENT IS EMPTY!")
|
||||
raise ValueError("File content is empty!")
|
||||
|
||||
# 使用临时文件方法确保写入成功
|
||||
temp_path = file_path + '.tmp'
|
||||
try:
|
||||
# 先写入临时文件
|
||||
with open(temp_path, "wb") as temp_file:
|
||||
temp_file.write(file_content)
|
||||
temp_file.flush()
|
||||
os.fsync(temp_file.fileno())
|
||||
|
||||
# 验证临时文件
|
||||
if os.path.exists(temp_path):
|
||||
temp_size = os.path.getsize(temp_path)
|
||||
if temp_size != len(file_content):
|
||||
raise Exception(f"Temporary file size mismatch: {temp_size} != {len(file_content)}")
|
||||
|
||||
# 重命名为最终文件名(原子操作)
|
||||
if os.name == 'nt': # Windows
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
os.rename(temp_path, file_path)
|
||||
|
||||
# 最终验证
|
||||
if not os.path.exists(file_path):
|
||||
raise Exception("File was not created after rename")
|
||||
|
||||
final_size = os.path.getsize(file_path)
|
||||
if final_size != len(file_content):
|
||||
raise Exception(f"Final file size mismatch: {final_size} != {len(file_content)}")
|
||||
|
||||
except Exception as e:
|
||||
# 清理临时文件
|
||||
if os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
print(f"[ERROR] File save failed: {e}")
|
||||
raise Exception(f"Failed to save file: {e}")
|
||||
|
||||
# 重置文件指针
|
||||
file.file.seek(0)
|
||||
|
||||
return file_path, file_content
|
||||
|
||||
def upload_file(self, file: UploadFile, user: User, upload_request: FileUploadRequest) -> File:
|
||||
"""上传文件"""
|
||||
try:
|
||||
# 验证文件
|
||||
self._validate_file(file, user)
|
||||
|
||||
# 生成唯一文件名
|
||||
unique_filename = self._generate_unique_filename(file.filename)
|
||||
|
||||
# 保存文件到磁盘
|
||||
file_path, file_content = self._save_file_to_disk(file, unique_filename)
|
||||
file_size = len(file_content)
|
||||
|
||||
# 再次检查文件大小(如果没有size属性)
|
||||
if file_size > self.max_file_size:
|
||||
# 删除已保存的文件
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
raise FileTooLargeException(file_size, self.max_file_size)
|
||||
|
||||
# 检查存储配额
|
||||
if not user.is_storage_available(file_size):
|
||||
# 删除已保存的文件
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
raise StorageQuotaExceededException(user.storage_used, user.storage_quota, file_size)
|
||||
|
||||
# 计算文件哈希
|
||||
file_hash = self._calculate_file_hash(file_content)
|
||||
|
||||
# 检查文件是否已存在(基于哈希值)
|
||||
existing_file = self.db.query(File).filter(
|
||||
and_(
|
||||
File.user_id == user.id,
|
||||
File.file_hash == file_hash
|
||||
)
|
||||
).first()
|
||||
|
||||
if existing_file:
|
||||
# 删除刚保存的文件,因为已存在相同内容的文件
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
raise FileAlreadyExistsException(existing_file.original_filename)
|
||||
|
||||
# 创建文件记录
|
||||
db_file = File(
|
||||
user_id=user.id,
|
||||
filename=unique_filename,
|
||||
original_filename=file.filename,
|
||||
file_path=file_path,
|
||||
file_size=file_size,
|
||||
mime_type=file.content_type or 'application/octet-stream',
|
||||
file_hash=file_hash,
|
||||
description=upload_request.description,
|
||||
tags=upload_request.tags,
|
||||
is_public=upload_request.is_public
|
||||
)
|
||||
|
||||
self.db.add(db_file)
|
||||
|
||||
# 更新用户存储使用量
|
||||
user.storage_used += file_size
|
||||
|
||||
self.db.commit()
|
||||
self.db.refresh(db_file)
|
||||
|
||||
return db_file
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
# 如果保存了文件但数据库操作失败,删除文件
|
||||
if 'file_path' in locals() and os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
raise e
|
||||
|
||||
def get_user_files(self, user_id: int, page: int = 1, size: int = 20) -> FileListResponse:
|
||||
"""获取用户的文件列表"""
|
||||
offset = (page - 1) * size
|
||||
|
||||
query = self.db.query(File).filter(File.user_id == user_id)
|
||||
|
||||
total = query.count()
|
||||
files = query.order_by(desc(File.created_at)).offset(offset).limit(size).all()
|
||||
|
||||
pages = (total + size - 1) // size
|
||||
|
||||
return FileListResponse(
|
||||
files=[FileResponse(
|
||||
id=file.id,
|
||||
user_id=file.user_id,
|
||||
filename=file.filename,
|
||||
original_filename=file.original_filename,
|
||||
file_size=file.file_size,
|
||||
mime_type=file.mime_type,
|
||||
file_hash=file.file_hash,
|
||||
is_public=file.is_public,
|
||||
download_count=file.download_count,
|
||||
description=file.description,
|
||||
tags=file.tags,
|
||||
created_at=file.created_at,
|
||||
updated_at=file.updated_at,
|
||||
last_accessed_at=file.last_accessed_at
|
||||
) for file in files],
|
||||
total=total,
|
||||
page=page,
|
||||
size=size,
|
||||
pages=pages
|
||||
)
|
||||
|
||||
def get_file_by_id(self, file_id: int, user_id: int) -> Optional[File]:
|
||||
"""根据ID获取文件"""
|
||||
return self.db.query(File).filter(
|
||||
and_(
|
||||
File.id == file_id,
|
||||
File.user_id == user_id
|
||||
)
|
||||
).first()
|
||||
|
||||
def update_file(self, file_id: int, user_id: int, update_request: FileUpdateRequest) -> Optional[File]:
|
||||
"""更新文件信息"""
|
||||
db_file = self.get_file_by_id(file_id, user_id)
|
||||
if not db_file:
|
||||
raise FileNotFoundException()
|
||||
|
||||
# 更新字段
|
||||
if update_request.description is not None:
|
||||
db_file.description = update_request.description
|
||||
if update_request.tags is not None:
|
||||
db_file.tags = update_request.tags
|
||||
if update_request.is_public is not None:
|
||||
db_file.is_public = update_request.is_public
|
||||
|
||||
db_file.updated_at = datetime.utcnow()
|
||||
|
||||
self.db.commit()
|
||||
self.db.refresh(db_file)
|
||||
|
||||
return db_file
|
||||
|
||||
def delete_file(self, file_id: int, user_id: int) -> bool:
|
||||
"""删除文件"""
|
||||
db_file = self.get_file_by_id(file_id, user_id)
|
||||
if not db_file:
|
||||
raise FileNotFoundException()
|
||||
|
||||
try:
|
||||
# 删除磁盘上的文件
|
||||
if os.path.exists(db_file.file_path):
|
||||
os.remove(db_file.file_path)
|
||||
|
||||
# 更新用户存储使用量
|
||||
user = self.db.query(User).filter(User.id == user_id).first()
|
||||
if user:
|
||||
user.storage_used = max(0, user.storage_used - db_file.file_size)
|
||||
|
||||
# 删除数据库记录
|
||||
self.db.delete(db_file)
|
||||
self.db.commit()
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
raise e
|
||||
|
||||
def search_files(self, user_id: int, search_request: FileSearchRequest,
|
||||
page: int = 1, size: int = 20) -> FileListResponse:
|
||||
"""搜索文件"""
|
||||
offset = (page - 1) * size
|
||||
|
||||
query = self.db.query(File).filter(File.user_id == user_id)
|
||||
|
||||
# 文件名搜索
|
||||
if search_request.filename:
|
||||
query = query.filter(
|
||||
File.original_filename.ilike(f"%{search_request.filename}%")
|
||||
)
|
||||
|
||||
# 标签搜索
|
||||
if search_request.tags:
|
||||
tag_list = [tag.strip() for tag in search_request.tags.split(',')]
|
||||
tag_conditions = []
|
||||
for tag in tag_list:
|
||||
tag_conditions.append(File.tags.ilike(f"%{tag}%"))
|
||||
if tag_conditions:
|
||||
query = query.filter(or_(*tag_conditions))
|
||||
|
||||
# MIME类型过滤
|
||||
if search_request.mime_type:
|
||||
query = query.filter(File.mime_type == search_request.mime_type)
|
||||
|
||||
# 公开状态过滤
|
||||
if search_request.is_public is not None:
|
||||
query = query.filter(File.is_public == search_request.is_public)
|
||||
|
||||
# 日期范围过滤
|
||||
if search_request.start_date:
|
||||
query = query.filter(File.created_at >= search_request.start_date)
|
||||
if search_request.end_date:
|
||||
query = query.filter(File.created_at <= search_request.end_date)
|
||||
|
||||
# 文件大小范围过滤
|
||||
if search_request.min_size:
|
||||
query = query.filter(File.file_size >= search_request.min_size)
|
||||
if search_request.max_size:
|
||||
query = query.filter(File.file_size <= search_request.max_size)
|
||||
|
||||
total = query.count()
|
||||
files = query.order_by(desc(File.created_at)).offset(offset).limit(size).all()
|
||||
pages = (total + size - 1) // size
|
||||
|
||||
return FileListResponse(
|
||||
files=[FileResponse(
|
||||
id=file.id,
|
||||
user_id=file.user_id,
|
||||
filename=file.filename,
|
||||
original_filename=file.original_filename,
|
||||
file_size=file.file_size,
|
||||
mime_type=file.mime_type,
|
||||
file_hash=file.file_hash,
|
||||
is_public=file.is_public,
|
||||
download_count=file.download_count,
|
||||
description=file.description,
|
||||
tags=file.tags,
|
||||
created_at=file.created_at,
|
||||
updated_at=file.updated_at,
|
||||
last_accessed_at=file.last_accessed_at
|
||||
) for file in files],
|
||||
total=total,
|
||||
page=page,
|
||||
size=size,
|
||||
pages=pages
|
||||
)
|
||||
|
||||
def get_file_info(self, file_id: int, user_id: int) -> FileInfo:
|
||||
"""获取文件详细信息"""
|
||||
db_file = self.get_file_by_id(file_id, user_id)
|
||||
if not db_file:
|
||||
raise FileNotFoundException()
|
||||
|
||||
# 更新最后访问时间
|
||||
db_file.last_accessed_at = datetime.utcnow()
|
||||
self.db.commit()
|
||||
|
||||
return FileInfo(
|
||||
id=db_file.id,
|
||||
filename=db_file.filename,
|
||||
original_filename=db_file.original_filename,
|
||||
file_size=db_file.file_size,
|
||||
mime_type=db_file.mime_type,
|
||||
file_hash=db_file.file_hash,
|
||||
is_image=db_file.is_image(),
|
||||
is_document=db_file.is_document(),
|
||||
file_extension=db_file.get_file_extension(),
|
||||
size_formatted=db_file.get_size_formatted(),
|
||||
is_public=db_file.is_public,
|
||||
download_count=db_file.download_count,
|
||||
description=db_file.description,
|
||||
tags=db_file.tags,
|
||||
created_at=db_file.created_at,
|
||||
updated_at=db_file.updated_at,
|
||||
last_accessed_at=db_file.last_accessed_at
|
||||
)
|
||||
|
||||
def get_storage_info(self, user_id: int) -> StorageInfo:
|
||||
"""获取用户存储信息"""
|
||||
user = self.db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="用户不存在"
|
||||
)
|
||||
|
||||
file_count = self.db.query(File).filter(File.user_id == user_id).count()
|
||||
available_space = user.storage_quota - user.storage_used
|
||||
usage_percentage = user.get_storage_percentage()
|
||||
|
||||
return StorageInfo(
|
||||
total_quota=user.storage_quota,
|
||||
used_space=user.storage_used,
|
||||
available_space=available_space,
|
||||
usage_percentage=usage_percentage,
|
||||
file_count=file_count
|
||||
)
|
||||
|
||||
def increment_download_count(self, file_id: int) -> None:
|
||||
"""增加文件下载次数"""
|
||||
db_file = self.db.query(File).filter(File.id == file_id).first()
|
||||
if db_file:
|
||||
db_file.download_count += 1
|
||||
db_file.last_accessed_at = datetime.utcnow()
|
||||
self.db.commit()
|
||||
252
backend/app/services/user_service.py
Normal file
252
backend/app/services/user_service.py
Normal file
@@ -0,0 +1,252 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from fastapi import HTTPException, status
|
||||
from typing import Optional
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from app.models.user import User
|
||||
from app.core.security import get_password_hash, verify_password, create_access_token, create_refresh_token
|
||||
from app.schemas.auth import UserRegister, UserLogin, UserResponse
|
||||
from app.core.config import settings
|
||||
from app.exceptions.auth import UsernameAlreadyExistsException
|
||||
|
||||
class UserService:
|
||||
"""用户服务类"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def create_user(self, user_data: UserRegister) -> User:
|
||||
"""创建新用户"""
|
||||
try:
|
||||
# 检查用户名是否已存在
|
||||
print(f"[DEBUG] Checking if username exists: {user_data.username}")
|
||||
existing_user_by_username = self.get_user_by_username(user_data.username)
|
||||
if existing_user_by_username:
|
||||
print(f"[DEBUG] Username already exists: {existing_user_by_username}")
|
||||
raise UsernameAlreadyExistsException()
|
||||
print(f"[DEBUG] Username check passed")
|
||||
|
||||
# 邮箱允许重复,不再检查邮箱是否已存在
|
||||
print(f"[DEBUG] Email uniqueness check skipped (emails can be duplicated)")
|
||||
|
||||
# 创建新用户
|
||||
print(f"[DEBUG] About to hash password...")
|
||||
try:
|
||||
hashed_password = get_password_hash(user_data.password)
|
||||
print(f"[DEBUG] Password hashed successfully")
|
||||
except Exception as hash_error:
|
||||
print(f"[ERROR] Password hashing failed: {hash_error}")
|
||||
print(f"[ERROR] Error type: {type(hash_error)}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
print(f"[DEBUG] Creating User model...")
|
||||
db_user = User(
|
||||
username=user_data.username,
|
||||
email=user_data.email,
|
||||
password_hash=hashed_password,
|
||||
is_active=True,
|
||||
is_verified=False
|
||||
)
|
||||
print(f"[DEBUG] User model created")
|
||||
|
||||
self.db.add(db_user)
|
||||
self.db.commit()
|
||||
self.db.refresh(db_user)
|
||||
|
||||
return db_user
|
||||
|
||||
except IntegrityError as e:
|
||||
self.db.rollback()
|
||||
if "username" in str(e.orig):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"code": "USERNAME_EXISTS", "message": "用户名已存在"}
|
||||
)
|
||||
elif "email" in str(e.orig):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"code": "EMAIL_EXISTS", "message": "邮箱已被注册"}
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"code": "INTEGRITY_ERROR", "message": "数据完整性错误"}
|
||||
)
|
||||
except (HTTPException, UsernameAlreadyExistsException):
|
||||
# 重新抛出HTTPException(不要覆盖自定义错误信息)
|
||||
self.db.rollback()
|
||||
raise
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={"code": "CREATION_FAILED", "message": "用户创建失败"}
|
||||
)
|
||||
|
||||
def authenticate_user(self, login_data: UserLogin) -> Optional[User]:
|
||||
"""验证用户登录"""
|
||||
# 尝试通过用户名查找用户
|
||||
user = self.get_user_by_username(login_data.username)
|
||||
|
||||
# 如果用户名找不到,尝试通过邮箱查找
|
||||
if not user:
|
||||
user = self.get_user_by_email(login_data.username)
|
||||
|
||||
# 如果找到用户且密码正确
|
||||
if user and verify_password(login_data.password, user.password_hash):
|
||||
# 更新最后登录时间
|
||||
user.last_login_at = datetime.utcnow()
|
||||
self.db.commit()
|
||||
return user
|
||||
|
||||
return None
|
||||
|
||||
def get_user_by_id(self, user_id: int) -> Optional[User]:
|
||||
"""根据ID获取用户"""
|
||||
return self.db.query(User).filter(User.id == user_id, User.is_active == True).first()
|
||||
|
||||
def get_user_by_username(self, username: str) -> Optional[User]:
|
||||
"""根据用户名获取用户"""
|
||||
return self.db.query(User).filter(User.username == username, User.is_active == True).first()
|
||||
|
||||
def get_user_by_email(self, email: str) -> Optional[User]:
|
||||
"""根据邮箱获取用户"""
|
||||
return self.db.query(User).filter(User.email == email, User.is_active == True).first()
|
||||
|
||||
def update_user(self, user_id: int, **kwargs) -> Optional[User]:
|
||||
"""更新用户信息"""
|
||||
user = self.get_user_by_id(user_id)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
try:
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(user, key) and value is not None:
|
||||
setattr(user, key, value)
|
||||
|
||||
self.db.commit()
|
||||
self.db.refresh(user)
|
||||
return user
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={"code": "UPDATE_FAILED", "message": "用户信息更新失败"}
|
||||
)
|
||||
|
||||
def change_password(self, user_id: int, current_password: str, new_password: str) -> bool:
|
||||
"""修改用户密码"""
|
||||
user = self.get_user_by_id(user_id)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail={"code": "USER_NOT_FOUND", "message": "用户不存在"}
|
||||
)
|
||||
|
||||
# 验证当前密码
|
||||
if not verify_password(current_password, user.password_hash):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"code": "INVALID_PASSWORD", "message": "当前密码不正确"}
|
||||
)
|
||||
|
||||
try:
|
||||
# 更新密码
|
||||
user.password_hash = get_password_hash(new_password)
|
||||
self.db.commit()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={"code": "PASSWORD_CHANGE_FAILED", "message": "密码修改失败"}
|
||||
)
|
||||
|
||||
def deactivate_user(self, user_id: int) -> bool:
|
||||
"""停用用户"""
|
||||
user = self.get_user_by_id(user_id)
|
||||
if not user:
|
||||
return False
|
||||
|
||||
try:
|
||||
user.is_active = False
|
||||
self.db.commit()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={"code": "DEACTIVATION_FAILED", "message": "用户停用失败"}
|
||||
)
|
||||
|
||||
def update_storage_usage(self, user_id: int, size_change: int) -> bool:
|
||||
"""更新用户存储使用量"""
|
||||
user = self.get_user_by_id(user_id)
|
||||
if not user:
|
||||
return False
|
||||
|
||||
try:
|
||||
new_usage = user.storage_used + size_change
|
||||
|
||||
# 检查存储配额
|
||||
if new_usage > user.storage_quota:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"code": "STORAGE_EXCEEDED", "message": "存储空间不足"}
|
||||
)
|
||||
|
||||
user.storage_used = max(0, new_usage) # 确保不小于0
|
||||
self.db.commit()
|
||||
return True
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={"code": "STORAGE_UPDATE_FAILED", "message": "存储空间更新失败"}
|
||||
)
|
||||
|
||||
def create_user_tokens(self, user: User) -> dict:
|
||||
"""为用户创建访问令牌和刷新令牌"""
|
||||
access_token_expires = timedelta(minutes=settings.JWT_EXPIRE_MINUTES)
|
||||
refresh_token_expires = timedelta(days=settings.JWT_REFRESH_EXPIRE_DAYS)
|
||||
|
||||
access_token = create_access_token(
|
||||
data={"sub": str(user.id), "username": user.username, "email": user.email},
|
||||
expires_delta=access_token_expires
|
||||
)
|
||||
|
||||
refresh_token = create_refresh_token(
|
||||
data={"sub": str(user.id)},
|
||||
expires_delta=refresh_token_expires
|
||||
)
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"token_type": "bearer",
|
||||
"expires_in": settings.JWT_EXPIRE_MINUTES * 60
|
||||
}
|
||||
|
||||
def to_user_response(self, user: User) -> UserResponse:
|
||||
"""将用户模型转换为响应模型"""
|
||||
return UserResponse(
|
||||
id=user.id,
|
||||
username=user.username,
|
||||
email=user.email,
|
||||
avatar_url=user.avatar_url,
|
||||
storage_quota=user.storage_quota,
|
||||
storage_used=user.storage_used,
|
||||
is_active=user.is_active,
|
||||
is_verified=user.is_verified,
|
||||
last_login_at=user.last_login_at,
|
||||
created_at=user.created_at
|
||||
)
|
||||
Reference in New Issue
Block a user