初次提交

This commit is contained in:
2025-10-14 20:05:29 +08:00
commit 6e4e48fdd2
673 changed files with 437006 additions and 0 deletions

0
backend/app/__init__.py Normal file
View File

View File

View File

View File

View File

@@ -0,0 +1,217 @@
from fastapi import APIRouter, Depends, HTTPException, status, Request
from sqlalchemy.orm import Session
from datetime import datetime, timedelta
from app.core.database import get_db
from app.core.security import verify_token
from app.core.token_blacklist import token_blacklist
from app.services.user_service import UserService
from app.schemas.auth import (
UserRegister, UserLogin, UserResponse, LoginResponse,
TokenResponse, TokenRefresh, ApiResponse
)
from app.dependencies.auth import get_current_user_response
from app.models.user import User
from app.exceptions.auth import UsernameAlreadyExistsException
router = APIRouter()
@router.post("/register", status_code=status.HTTP_201_CREATED)
async def register(request: Request, user_data: UserRegister, db: Session = Depends(get_db)):
"""用户注册"""
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
print(f"[{timestamp}] INFO: === 注册接口 ===")
print(f"[{timestamp}] INFO: 用户名: {user_data.username}")
print(f"[{timestamp}] INFO: 邮箱: {user_data.email}")
print(f"[{timestamp}] INFO: 密码长度: {len(user_data.password)}字符")
print(f"[{timestamp}] INFO: 确认密码长度: {len(user_data.confirm_password)}字符")
print(f"[{timestamp}] DEBUG: Starting registration process...")
try:
user_service = UserService(db)
# 创建用户
print("[DEBUG] Creating user...")
user = user_service.create_user(user_data)
print(f"[DEBUG] User created successfully with ID: {user.id}")
# 创建令牌
print("[DEBUG] Creating tokens...")
tokens = user_service.create_user_tokens(user)
print("[DEBUG] Tokens created successfully")
# 转换为响应格式
print("[DEBUG] Converting to response format...")
user_response = user_service.to_user_response(user)
print("[DEBUG] Response conversion successful")
response_data = {
"user": user_response.dict(),
"tokens": tokens
}
print("[DEBUG] Response data created successfully")
return ApiResponse(
success=True,
message="注册成功",
data=response_data
)
except UsernameAlreadyExistsException as e:
print(f"[DEBUG] UsernameAlreadyExistsException caught: {e}")
raise e
except HTTPException as e:
print(f"[DEBUG] HTTPException caught: {e}")
raise e
except Exception as e:
# 打印异常信息以便调试
import traceback
print(f"[ERROR] Unexpected error in register: {e}")
print(f"[ERROR] Exception type: {type(e)}")
traceback.print_exc()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={
"code": "REGISTRATION_FAILED",
"message": f"注册过程中发生错误: {str(e)}"
}
)
@router.post("/login", response_model=ApiResponse)
async def login(request: Request, login_data: UserLogin, db: Session = Depends(get_db)):
"""用户登录"""
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
print(f"[{timestamp}] INFO: === 登录接口 ===")
print(f"[{timestamp}] INFO: 用户名: {login_data.username}")
print(f"[{timestamp}] INFO: 密码长度: {len(login_data.password)}字符")
print(f"[{timestamp}] DEBUG: Starting authentication process...")
try:
user_service = UserService(db)
# 验证用户
user = user_service.authenticate_user(login_data)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail={
"code": "INVALID_CREDENTIALS",
"message": "用户名或密码错误"
}
)
# 创建令牌
tokens = user_service.create_user_tokens(user)
# 转换为响应格式
user_response = user_service.to_user_response(user)
return ApiResponse(
success=True,
message="登录成功",
data={
"user": user_response.dict(),
"tokens": tokens
}
)
except HTTPException as e:
raise e
except Exception as e:
print("用户登录的异常:",e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={
"code": "LOGIN_FAILED",
"message": "登录过程中发生错误"
}
)
@router.post("/refresh", response_model=ApiResponse)
async def refresh_token(token_data: TokenRefresh, db: Session = Depends(get_db)):
"""刷新访问令牌"""
try:
# 验证刷新令牌
payload = verify_token(token_data.refresh_token, "refresh")
if not payload:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail={
"code": "INVALID_REFRESH_TOKEN",
"message": "无效的刷新令牌"
}
)
# 获取用户ID
user_id = int(payload.get("sub"))
user_service = UserService(db)
user = user_service.get_user_by_id(user_id)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail={
"code": "USER_NOT_FOUND",
"message": "用户不存在"
}
)
# 创建新的访问令牌
tokens = user_service.create_user_tokens(user)
return ApiResponse(
success=True,
message="令牌刷新成功",
data={
"tokens": tokens
}
)
except HTTPException as e:
raise e
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={
"code": "TOKEN_REFRESH_FAILED",
"message": "令牌刷新过程中发生错误"
}
)
@router.get("/me", response_model=ApiResponse)
async def get_current_user_info(current_user: UserResponse = Depends(get_current_user_response)):
"""获取当前用户信息"""
return ApiResponse(
success=True,
message="获取用户信息成功",
data={
"user": current_user.dict()
}
)
@router.post("/logout", response_model=ApiResponse)
async def logout(
request: Request,
current_user: UserResponse = Depends(get_current_user_response)
):
"""用户登出"""
try:
# 从请求头中获取Authorization令牌
authorization = request.headers.get("Authorization")
if authorization and authorization.startswith("Bearer "):
token = authorization.split(" ")[1]
# 将令牌加入黑名单
token_blacklist.add_token(token)
return ApiResponse(
success=True,
message="登出成功",
data={}
)
except Exception as e:
# 即使添加令牌到黑名单失败,也返回成功,因为登出操作主要目的是让客户端删除令牌
return ApiResponse(
success=True,
message="登出成功",
data={}
)

View File

@@ -0,0 +1,382 @@
from fastapi import APIRouter, Depends, HTTPException, status, UploadFile, File, Form, Query, Request
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session
from typing import Optional, List
from datetime import datetime
from app.core.database import get_db
from app.services.file_service import FileService
from app.schemas.file import (
FileUploadRequest, FileUpdateRequest, FileSearchRequest,
FileResponse, FileListResponse, FileInfo, StorageInfo, ApiResponse,
UploadResponse, DeleteResponse, FileListRequest, FileIdRequest, StorageInfoRequest
)
from app.models.user import User
from app.dependencies.auth import get_current_user_from_headers
from app.exceptions.file import (
FileTooLargeException, StorageQuotaExceededException,
FileAlreadyExistsException, FileNotFoundException,
InvalidFileTypeException, FileUploadException, FileDeleteException
)
print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] INFO: [MODULE] Files module loaded successfully")
router = APIRouter()
@router.post("/upload", response_model=ApiResponse, status_code=status.HTTP_201_CREATED)
async def upload_file(
file: UploadFile = File(...),
description: Optional[str] = Form(None),
tags: Optional[str] = Form(None),
is_public: bool = Form(False),
current_user: User = Depends(get_current_user_from_headers),
db: Session = Depends(get_db)
):
"""上传文件"""
try:
file_service = FileService(db)
# 创建上传请求对象
upload_request = FileUploadRequest(
description=description,
tags=tags,
is_public=is_public
)
# 上传文件
db_file = file_service.upload_file(file, current_user, upload_request)
# 转换为响应格式
file_response = FileResponse(
id=db_file.id,
user_id=db_file.user_id,
filename=db_file.filename,
original_filename=db_file.original_filename,
file_size=db_file.file_size,
mime_type=db_file.mime_type,
file_hash=db_file.file_hash,
is_public=db_file.is_public,
download_count=db_file.download_count,
description=db_file.description,
tags=db_file.tags,
created_at=db_file.created_at,
updated_at=db_file.updated_at,
last_accessed_at=db_file.last_accessed_at
)
return ApiResponse(
success=True,
message="文件上传成功",
data={
"file": file_response.model_dump()
}
)
except (FileTooLargeException, StorageQuotaExceededException,
FileAlreadyExistsException, InvalidFileTypeException) as e:
raise e
except Exception as e:
raise FileUploadException(str(e))
@router.get("/list", response_model=ApiResponse)
def get_user_files(
page: int = Query(1, ge=1),
size: int = Query(10, ge=1, le=100),
current_user: User = Depends(get_current_user_from_headers),
db: Session = Depends(get_db)
):
"""获取用户文件列表"""
try:
file_service = FileService(db)
file_list = file_service.get_user_files(current_user.id, page, size)
return ApiResponse(
success=True,
message="获取文件列表成功",
data={
"files": [FileResponse(
id=file.id,
user_id=file.user_id,
filename=file.filename,
original_filename=file.original_filename,
file_size=file.file_size,
mime_type=file.mime_type,
file_hash=file.file_hash,
is_public=file.is_public,
download_count=file.download_count,
description=file.description,
tags=file.tags,
created_at=file.created_at,
updated_at=file.updated_at,
last_accessed_at=file.last_accessed_at
).model_dump() for file in file_list.files],
"pagination": {
"total": file_list.total,
"page": file_list.page,
"size": file_list.size,
"pages": file_list.pages
}
}
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={
"code": "GET_FILES_FAILED",
"message": "获取文件列表失败"
}
)
@router.post("/info", response_model=ApiResponse)
def get_file_info(
request: FileIdRequest,
db: Session = Depends(get_db)
):
"""获取文件详细信息"""
try:
file_service = FileService(db)
file_info = file_service.get_file_info(request.file_id, request.user_id)
return ApiResponse(
success=True,
message="获取文件信息成功",
data=file_info.model_dump()
)
except FileNotFoundException as e:
raise e
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={
"code": "GET_FILE_INFO_FAILED",
"message": "获取文件信息失败"
}
)
@router.post("/update", response_model=ApiResponse)
def update_file(
file_id_request: FileIdRequest,
update_request: FileUpdateRequest,
db: Session = Depends(get_db)
):
"""更新文件信息"""
try:
file_service = FileService(db)
db_file = file_service.update_file(file_id_request.file_id, file_id_request.user_id, update_request)
file_response = FileResponse(
id=db_file.id,
user_id=db_file.user_id,
filename=db_file.filename,
original_filename=db_file.original_filename,
file_size=db_file.file_size,
mime_type=db_file.mime_type,
file_hash=db_file.file_hash,
is_public=db_file.is_public,
download_count=db_file.download_count,
description=db_file.description,
tags=db_file.tags,
created_at=db_file.created_at,
updated_at=db_file.updated_at,
last_accessed_at=db_file.last_accessed_at
)
return ApiResponse(
success=True,
message="文件信息更新成功",
data={
"file": file_response.model_dump()
}
)
except FileNotFoundException as e:
raise e
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={
"code": "UPDATE_FILE_FAILED",
"message": "更新文件信息失败"
}
)
@router.post("/delete", response_model=ApiResponse)
def delete_file(
request: FileIdRequest,
db: Session = Depends(get_db)
):
"""删除文件"""
try:
file_service = FileService(db)
success = file_service.delete_file(request.file_id, request.user_id)
if success:
return ApiResponse(
success=True,
message="文件删除成功",
data={}
)
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={
"code": "DELETE_FILE_FAILED",
"message": "文件删除失败"
}
)
except FileNotFoundException as e:
raise e
except Exception as e:
raise FileDeleteException(str(e))
@router.post("/download")
def download_file(
request: FileIdRequest,
db: Session = Depends(get_db)
):
"""下载文件"""
print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] INFO: [IN] Processing download request: file_id={request.file_id}, user_id={request.user_id}")
try:
file_service = FileService(db)
print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] DEBUG: File service created")
db_file = file_service.get_file_by_id(request.file_id, request.user_id)
print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] DEBUG: File found in database: {db_file is not None}")
if not db_file:
print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] ERROR: File not found in database")
raise FileNotFoundException()
# 增加下载次数
print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] DEBUG: Incrementing download count")
file_service.increment_download_count(request.file_id)
# 确保使用绝对路径
import os
absolute_path = os.path.abspath(db_file.file_path)
print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] INFO: [PATH] File path: {absolute_path}")
# 验证文件存在
if not os.path.exists(absolute_path):
print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] ERROR: File not found on disk: {absolute_path}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={
"code": "FILE_NOT_FOUND_ON_DISK",
"message": "upload文件夹与数据库不匹配"
}
)
print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] INFO: [OK] File exists on disk")
# 对于文本文件,直接读取内容返回
if db_file.mime_type and db_file.mime_type.startswith('text/'):
print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] INFO: [TEXT] Processing text file")
with open(absolute_path, 'r', encoding='utf-8') as f:
content = f.read()
print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] INFO: [READ] Read {len(content)} characters from file")
from fastapi.responses import Response
import urllib.parse
# 对文件名进行URL编码以支持中文
encoded_filename = urllib.parse.quote(db_file.original_filename.encode('utf-8'))
print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] INFO: [SEND] Sending file: {db_file.original_filename}")
return Response(
content=content,
media_type=db_file.mime_type,
headers={
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"
}
)
else:
print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] INFO: [BINARY] Processing binary file with FileResponse")
# 对于二进制文件使用FileResponse
return FileResponse(
path=absolute_path,
filename=db_file.original_filename,
media_type=db_file.mime_type
)
except FileNotFoundException as e:
print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] ERROR: FileNotFoundException: {e}")
raise e
except Exception as e:
print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] ERROR: Download error: {e}")
import traceback
print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] ERROR: {traceback.format_exc()}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={
"code": "DOWNLOAD_FILE_FAILED",
"message": "文件下载失败"
}
)
@router.post("/test-download")
def test_download_file():
"""测试下载功能 - 直接返回5KB文件内容"""
print("[TEST_DOWNLOAD] Test endpoint called", flush=True)
return {"message": "Test endpoint is working"}
@router.post("/simple-download")
def simple_download_file():
"""简单下载测试"""
try:
print("[SIMPLE_DOWNLOAD] Simple download endpoint called", flush=True)
# 直接读取我们创建的文件
file_path = "uploads/verified_5kb_download_test.txt"
import os
if not os.path.exists(file_path):
print(f"[SIMPLE_DOWNLOAD] File not found: {file_path}", flush=True)
return {"error": "File not found"}
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
print(f"[SIMPLE_DOWNLOAD] Read {len(content)} characters", flush=True)
from fastapi.responses import Response
return Response(
content=content,
media_type="text/plain",
headers={"Content-Disposition": "attachment; filename=verified_5kb.txt"}
)
except Exception as e:
print(f"[SIMPLE_DOWNLOAD] Exception: {e}", flush=True)
import traceback
traceback.print_exc()
return {"error": str(e)}
@router.post("/storage/info", response_model=ApiResponse)
def get_storage_info(
request: StorageInfoRequest,
db: Session = Depends(get_db)
):
"""获取用户存储信息"""
try:
file_service = FileService(db)
storage_info = file_service.get_storage_info(request.user_id)
return ApiResponse(
success=True,
message="获取存储信息成功",
data=storage_info.model_dump()
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={
"code": "GET_STORAGE_INFO_FAILED",
"message": "获取存储信息失败"
}
)

View File

@@ -0,0 +1,81 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.core.config import settings
import redis
import time
router = APIRouter()
@router.get("/health")
async def health_check():
"""基础健康检查"""
return {
"success": True,
"data": {
"status": "healthy",
"service": "cloud-drive-api",
"environment": settings.ENVIRONMENT,
"timestamp": int(time.time())
},
"message": "API服务运行正常"
}
@router.get("/ready")
async def readiness_check(db: Session = Depends(get_db)):
"""就绪检查 - 检查数据库和Redis连接"""
checks = {}
# 检查数据库连接
try:
db.execute("SELECT 1")
checks["database"] = {
"status": "healthy",
"message": "数据库连接正常"
}
except Exception as e:
checks["database"] = {
"status": "unhealthy",
"message": f"数据库连接失败: {str(e)}"
}
raise HTTPException(status_code=503, detail="数据库连接失败")
# 检查Redis连接
try:
r = redis.from_url(settings.REDIS_URL)
r.ping()
checks["redis"] = {
"status": "healthy",
"message": "Redis连接正常"
}
except Exception as e:
checks["redis"] = {
"status": "unhealthy",
"message": f"Redis连接失败: {str(e)}"
}
raise HTTPException(status_code=503, detail="Redis连接失败")
# 检查文件存储
try:
import os
os.makedirs(settings.UPLOAD_DIR, exist_ok=True)
checks["storage"] = {
"status": "healthy",
"message": "文件存储正常"
}
except Exception as e:
checks["storage"] = {
"status": "unhealthy",
"message": f"文件存储失败: {str(e)}"
}
raise HTTPException(status_code=503, detail="文件存储失败")
return {
"success": True,
"data": {
"status": "ready",
"checks": checks,
"timestamp": int(time.time())
},
"message": "所有服务已就绪"
}

View File

View File

@@ -0,0 +1,54 @@
from pydantic_settings import BaseSettings
from typing import List
import os
class Settings(BaseSettings):
# 基础配置
ENVIRONMENT: str = "development"
DEBUG: bool = True
# 数据库配置
DATABASE_URL: str = "mysql+pymysql://mytest_db:mytest_db@101.126.85.76:3306/mytest_db"
# Redis配置
REDIS_URL: str = "redis://localhost:6379"
# JWT配置
JWT_SECRET_KEY: str = "your-super-secret-jwt-key-change-in-production"
JWT_ALGORITHM: str = "HS256"
JWT_EXPIRE_MINUTES: int = 30
JWT_REFRESH_EXPIRE_DAYS: int = 7
# CORS配置
ALLOWED_HOSTS: List[str] = ["*"] # 允许所有域名访问
# 文件上传配置
MAX_FILE_SIZE: int = 10 * 1024 * 1024 # 10MB
UPLOAD_DIR: str = "uploads"
ALLOWED_EXTENSIONS: List[str] = [
# 图片
".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp", ".svg",
# 文档
".pdf", ".doc", ".docx", ".xls", ".xlsx", ".ppt", ".pptx",
".txt", ".rtf", ".csv",
# 压缩文件
".zip", ".rar", ".7z", ".tar", ".gz",
# 音频
".mp3", ".wav", ".flac", ".aac", ".ogg",
# 视频
".mp4", ".avi", ".mkv", ".mov", ".wmv", ".flv",
# 代码文件
".py", ".js", ".html", ".css", ".json", ".xml", ".yaml", ".yml",
".java", ".cpp", ".c", ".h", ".cs", ".php", ".rb", ".go",
".sql", ".sh", ".bat", ".ps1", ".md", ".log"
]
# 安全配置
BCRYPT_ROUNDS: int = 12
class Config:
env_file = ".env"
case_sensitive = True
extra = "allow" # 允许额外的环境变量
settings = Settings()

View File

@@ -0,0 +1,30 @@
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from app.core.config import settings
import pymysql
# 安装pymysql作为MySQLdb的替代
pymysql.install_as_MySQLdb()
# 创建数据库引擎
engine = create_engine(
settings.DATABASE_URL,
echo=settings.DEBUG,
pool_pre_ping=True,
pool_recycle=300
)
# 创建会话工厂
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# 创建基础模型类
Base = declarative_base()
# 数据库依赖
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()

View File

@@ -0,0 +1,151 @@
from datetime import datetime, timedelta
from typing import Optional, Union
from jose import JWTError, jwt
import bcrypt
from app.core.config import settings
from app.core.token_blacklist import token_blacklist
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
"""创建访问令牌"""
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=settings.JWT_EXPIRE_MINUTES)
to_encode.update({"exp": expire, "type": "access"})
encoded_jwt = jwt.encode(to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)
return encoded_jwt
def create_refresh_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
"""创建刷新令牌"""
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(days=settings.JWT_REFRESH_EXPIRE_DAYS)
to_encode.update({"exp": expire, "type": "refresh"})
encoded_jwt = jwt.encode(to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)
return encoded_jwt
def verify_token(token: str, token_type: str = "access") -> Optional[dict]:
"""验证令牌"""
try:
# 首先检查令牌是否在黑名单中
if token_blacklist.is_blacklisted(token):
return None
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
# 检查令牌类型
if payload.get("type") != token_type:
return None
return payload
except JWTError:
return None
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""验证密码"""
try:
return bcrypt.checkpw(plain_password.encode('utf-8'), hashed_password.encode('utf-8'))
except:
return False
def get_password_hash(password: str) -> str:
"""获取密码哈希"""
# bcrypt 限制密码长度为72字节如果超过则截断
if len(password.encode('utf-8')) > 72:
password = password.encode('utf-8')[:72].decode('utf-8', errors='ignore')
salt = bcrypt.gensalt()
return bcrypt.hashpw(password.encode('utf-8'), salt).decode('utf-8')
def create_password_reset_token(email: str) -> str:
"""创建密码重置令牌"""
delta = timedelta(hours=1) # 1小时有效期
now = datetime.utcnow()
expires = now + delta
exp = expires.timestamp()
encoded_jwt = jwt.encode(
{"exp": exp, "nbf": now, "sub": email, "type": "password_reset"},
settings.JWT_SECRET_KEY,
algorithm=settings.JWT_ALGORITHM,
)
return encoded_jwt
def verify_password_reset_token(token: str) -> Optional[str]:
"""验证密码重置令牌"""
try:
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
# 检查令牌类型
if payload.get("type") != "password_reset":
return None
return payload["sub"]
except JWTError:
return None
# 密码强度验证
def validate_password_strength(password: str) -> dict:
"""验证密码强度"""
errors = []
if len(password) < 8:
errors.append("密码长度至少8位")
if len(password) > 128:
errors.append("密码长度不能超过128位")
if not any(c.islower() for c in password):
errors.append("密码必须包含至少一个小写字母")
if not any(c.isupper() for c in password):
errors.append("密码必须包含至少一个大写字母")
if not any(c.isdigit() for c in password):
errors.append("密码必须包含至少一个数字")
# 检查特殊字符
special_chars = "!@#$%^&*()_+-=[]{}|;:,.<>?"
if not any(c in special_chars for c in password):
errors.append("密码必须包含至少一个特殊字符")
return {
"is_valid": len(errors) == 0,
"errors": errors,
"strength": calculate_password_strength(password)
}
def calculate_password_strength(password: str) -> str:
"""计算密码强度"""
score = 0
# 长度评分
if len(password) >= 8:
score += 1
if len(password) >= 12:
score += 1
if len(password) >= 16:
score += 1
# 字符类型评分
if any(c.islower() for c in password):
score += 1
if any(c.isupper() for c in password):
score += 1
if any(c.isdigit() for c in password):
score += 1
if any(c in "!@#$%^&*()_+-=[]{}|;:,.<>?" for c in password):
score += 1
# 根据评分返回强度等级
if score <= 2:
return ""
elif score <= 4:
return "中等"
elif score <= 6:
return ""
else:
return "非常强"

View File

@@ -0,0 +1,46 @@
from datetime import datetime, timedelta
from typing import Dict, Optional
import threading
class TokenBlacklist:
"""简单的令牌黑名单(内存存储)"""
def __init__(self):
self._blacklisted_tokens: Dict[str, datetime] = {}
self._lock = threading.Lock()
def add_token(self, token: str, expires_at: Optional[datetime] = None):
"""添加令牌到黑名单"""
with self._lock:
# 如果没有提供过期时间默认24小时后过期
if expires_at is None:
expires_at = datetime.utcnow() + timedelta(hours=24)
self._blacklisted_tokens[token] = expires_at
def is_blacklisted(self, token: str) -> bool:
"""检查令牌是否在黑名单中"""
with self._lock:
if token not in self._blacklisted_tokens:
return False
# 检查令牌是否已过期
if datetime.utcnow() > self._blacklisted_tokens[token]:
# 清理过期的令牌
del self._blacklisted_tokens[token]
return False
return True
def cleanup_expired_tokens(self):
"""清理过期的令牌"""
with self._lock:
current_time = datetime.utcnow()
expired_tokens = [
token for token, expires_at in self._blacklisted_tokens.items()
if current_time > expires_at
]
for token in expired_tokens:
del self._blacklisted_tokens[token]
# 全局黑名单实例
token_blacklist = TokenBlacklist()

View File

View File

@@ -0,0 +1,217 @@
from fastapi import Depends, HTTPException, status, Request
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.core.security import verify_token
from app.services.user_service import UserService
from app.schemas.auth import UserResponse
from app.models.user import User
# Bearer token 认证方案
security = HTTPBearer()
async def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(security),
db: Session = Depends(get_db)
) -> User:
"""获取当前认证用户"""
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail={
"code": "INVALID_AUTHENTICATION",
"message": "无法验证凭据"
},
headers={"WWW-Authenticate": "Bearer"},
)
try:
# 验证令牌
payload = verify_token(credentials.credentials, "access")
if payload is None:
raise credentials_exception
# 获取用户ID
user_id: str = payload.get("sub")
if user_id is None:
raise credentials_exception
user_id = int(user_id)
except (ValueError, TypeError):
raise credentials_exception
# 获取用户信息
user_service = UserService(db)
user = user_service.get_user_by_id(user_id)
if user is None:
raise credentials_exception
return user
async def get_current_user_from_headers(
request: Request,
db: Session = Depends(get_db)
) -> User:
"""从请求头中获取当前认证用户支持userId和token"""
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail={
"code": "INVALID_AUTHENTICATION",
"message": "无法验证凭据"
},
)
try:
# 尝试从多种方式获取token
token = None
user_id = None
# 1. 从Authorization header获取
authorization = request.headers.get("Authorization")
if authorization and authorization.startswith("Bearer "):
token = authorization.split(" ")[1]
# 2. 从token header获取
if not token:
token = request.headers.get("token")
# 3. 从userId header获取用户ID
user_id_str = request.headers.get("userId")
if user_id_str:
try:
user_id = int(user_id_str)
except (ValueError, TypeError):
pass
# 如果没有token认证失败
if not token:
raise credentials_exception
# 验证令牌
payload = verify_token(token, "access")
if payload is None:
raise credentials_exception
# 获取token中的用户ID
token_user_id: str = payload.get("sub")
if token_user_id is None:
raise credentials_exception
token_user_id = int(token_user_id)
# 如果header中有userId验证两个ID是否一致
if user_id is not None and user_id != token_user_id:
raise credentials_exception
# 使用token中的用户ID
final_user_id = token_user_id
except (ValueError, TypeError):
raise credentials_exception
# 获取用户信息
user_service = UserService(db)
user = user_service.get_user_by_id(final_user_id)
if user is None:
raise credentials_exception
return user
async def get_current_active_user(
current_user: User = Depends(get_current_user)
) -> User:
"""获取当前活跃用户"""
if not current_user.is_active:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"code": "INACTIVE_USER",
"message": "用户账户已被禁用"
}
)
return current_user
async def get_current_active_user_from_headers(
current_user: User = Depends(get_current_user_from_headers)
) -> User:
"""从请求头获取当前活跃用户"""
if not current_user.is_active:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"code": "INACTIVE_USER",
"message": "用户账户已被禁用"
}
)
return current_user
async def get_current_verified_user(
current_user: User = Depends(get_current_active_user)
) -> User:
"""获取当前已验证用户"""
if not current_user.is_verified:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"code": "UNVERIFIED_USER",
"message": "用户账户未验证"
}
)
return current_user
async def get_current_user_response(
current_user: User = Depends(get_current_active_user),
db: Session = Depends(get_db)
) -> UserResponse:
"""获取当前用户信息(响应格式)"""
user_service = UserService(db)
return user_service.to_user_response(current_user)
async def get_current_user_response_from_headers(
current_user: User = Depends(get_current_active_user_from_headers),
db: Session = Depends(get_db)
) -> UserResponse:
"""从请求头获取当前用户信息(响应格式)"""
user_service = UserService(db)
return user_service.to_user_response(current_user)
# 可选的认证依赖项(不强制要求认证)
async def get_optional_current_user(
credentials: HTTPAuthorizationCredentials = Depends(security),
db: Session = Depends(get_db)
) -> User | None:
"""获取可选的当前用户(认证失败时不抛出异常)"""
try:
return await get_current_user(credentials, db)
except HTTPException:
return None
# 权限检查函数
def check_user_permission(required_permission: str = None):
"""检查用户权限的装饰器工厂"""
async def permission_checker(
current_user: User = Depends(get_current_active_user)
) -> User:
# 这里可以根据需要实现更复杂的权限检查逻辑
# 目前只检查用户是否活跃
if not current_user.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail={
"code": "PERMISSION_DENIED",
"message": "权限不足"
}
)
return current_user
return permission_checker
# 管理员权限检查(预留)
async def get_admin_user(
current_user: User = Depends(get_current_active_user)
) -> User:
"""获取管理员用户(预留功能)"""
# 这里可以添加管理员权限检查逻辑
# 例如检查用户是否有管理员角色
return current_user

View File

@@ -0,0 +1 @@
# Exception classes for the application

View File

@@ -0,0 +1,23 @@
from fastapi import HTTPException, status
class EmailAlreadyExistsException(HTTPException):
"""邮箱已存在异常"""
def __init__(self):
super().__init__(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"code": "EMAIL_EXISTS",
"message": "邮箱已被注册"
}
)
class UsernameAlreadyExistsException(HTTPException):
"""用户名已存在异常"""
def __init__(self):
super().__init__(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"code": "USERNAME_EXISTS",
"message": "用户名已存在"
}
)

View File

@@ -0,0 +1,92 @@
from fastapi import HTTPException, status
class FileTooLargeException(HTTPException):
"""文件过大异常"""
def __init__(self, file_size: int, max_size: int):
super().__init__(
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
detail={
"code": "FILE_TOO_LARGE",
"message": f"文件大小 {file_size} 字节超过限制 {max_size} 字节",
"file_size": file_size,
"max_size": max_size
}
)
class StorageQuotaExceededException(HTTPException):
"""存储配额超限异常"""
def __init__(self, used_space: int, quota: int, required_space: int):
super().__init__(
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
detail={
"code": "STORAGE_QUOTA_EXCEEDED",
"message": f"存储空间不足。已使用: {used_space} 字节,配额: {quota} 字节,需要: {required_space} 字节",
"used_space": used_space,
"quota": quota,
"required_space": required_space
}
)
class FileAlreadyExistsException(HTTPException):
"""文件已存在异常"""
def __init__(self, filename: str):
super().__init__(
status_code=status.HTTP_409_CONFLICT,
detail={
"code": "FILE_ALREADY_EXISTS",
"message": f"文件 '{filename}' 已存在",
"filename": filename
}
)
class FileNotFoundException(HTTPException):
"""文件未找到异常"""
def __init__(self):
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail={
"code": "FILE_NOT_FOUND",
"message": "文件不存在"
}
)
class InvalidFileTypeException(HTTPException):
"""无效文件类型异常"""
def __init__(self, file_extension: str):
super().__init__(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"code": "INVALID_FILE_TYPE",
"message": f"不支持的文件类型: {file_extension}",
"file_extension": file_extension
}
)
class FileUploadException(HTTPException):
"""文件上传异常"""
def __init__(self, message: str):
super().__init__(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={
"code": "FILE_UPLOAD_FAILED",
"message": f"文件上传失败: {message}"
}
)
class FileDeleteException(HTTPException):
"""文件删除异常"""
def __init__(self, message: str):
super().__init__(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={
"code": "FILE_DELETE_FAILED",
"message": f"文件删除失败: {message}"
}
)

View File

@@ -0,0 +1,4 @@
from .user import User
from .file import File
__all__ = ["User", "File"]

View File

@@ -0,0 +1,86 @@
from sqlalchemy import Column, Integer, String, BigInteger, DateTime, ForeignKey, Boolean, Text
from sqlalchemy.orm import relationship
from sqlalchemy.sql import func
from app.core.database import Base
class File(Base):
__tablename__ = "files"
id = Column(Integer, primary_key=True, index=True)
user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
# 文件基本信息
filename = Column(String(255), nullable=False, index=True)
original_filename = Column(String(255), nullable=False) # 用户上传时的原始文件名
file_path = Column(String(500), nullable=False) # 服务器上的存储路径
file_size = Column(BigInteger, nullable=False) # 文件大小(字节)
mime_type = Column(String(100), nullable=False) # 文件MIME类型
file_hash = Column(String(64), nullable=False, index=True) # SHA-256哈希用于去重和完整性检查
# 文件状态
is_public = Column(Boolean, default=False) # 是否公开分享
download_count = Column(BigInteger, default=0) # 下载次数
# 文件元数据
description = Column(Text, nullable=True) # 文件描述
tags = Column(Text, nullable=True) # 标签,用逗号分隔
# 时间戳
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
last_accessed_at = Column(DateTime(timezone=True), nullable=True)
# 关联关系
user = relationship("User", back_populates="files")
def __repr__(self):
return f"<File(id={self.id}, filename='{self.filename}', user_id={self.user_id})>"
def to_dict(self):
return {
"id": self.id,
"user_id": self.user_id,
"filename": self.filename,
"original_filename": self.original_filename,
"file_size": self.file_size,
"mime_type": self.mime_type,
"file_hash": self.file_hash,
"is_public": self.is_public,
"download_count": self.download_count,
"description": self.description,
"tags": self.tags,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
"last_accessed_at": self.last_accessed_at.isoformat() if self.last_accessed_at else None,
}
def get_file_extension(self) -> str:
"""获取文件扩展名"""
return self.filename.split('.')[-1].lower() if '.' in self.filename else ''
def is_image(self) -> bool:
"""判断是否为图片文件"""
return self.mime_type.startswith('image/')
def is_document(self) -> bool:
"""判断是否为文档文件"""
document_types = [
'application/pdf',
'application/msword',
'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
'application/vnd.ms-excel',
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
'application/vnd.ms-powerpoint',
'application/vnd.openxmlformats-officedocument.presentationml.presentation',
'text/plain',
'text/csv'
]
return self.mime_type in document_types
def get_size_formatted(self) -> str:
"""获取格式化的文件大小"""
for unit in ['B', 'KB', 'MB', 'GB']:
if self.file_size < 1024.0:
return f"{self.file_size:.1f} {unit}"
self.file_size /= 1024.0
return f"{self.file_size:.1f} TB"

View File

@@ -0,0 +1,59 @@
from sqlalchemy import Column, Integer, String, Boolean, DateTime, BigInteger, Text
from sqlalchemy.orm import relationship
from sqlalchemy.sql import func
from app.core.database import Base
class User(Base):
__tablename__ = "users"
id = Column(Integer, primary_key=True, index=True)
username = Column(String(50), unique=True, index=True, nullable=False)
email = Column(String(100), index=True, nullable=False)
password_hash = Column(String(255), nullable=False)
# 用户资料
avatar_url = Column(String(500), nullable=True)
# 存储配额
storage_quota = Column(BigInteger, default=104857600) # 100MB in bytes
storage_used = Column(BigInteger, default=0)
# 用户状态
is_active = Column(Boolean, default=True)
is_verified = Column(Boolean, default=False)
# 时间戳
last_login_at = Column(DateTime(timezone=True), nullable=True)
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
# 关联关系
files = relationship("File", back_populates="user")
def __repr__(self):
return f"<User(id={self.id}, username='{self.username}', email='{self.email}')>"
def to_dict(self):
return {
"id": self.id,
"username": self.username,
"email": self.email,
"avatar_url": self.avatar_url,
"storage_quota": self.storage_quota,
"storage_used": self.storage_used,
"is_active": self.is_active,
"is_verified": self.is_verified,
"last_login_at": self.last_login_at.isoformat() if self.last_login_at else None,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
}
def is_storage_available(self, required_size: int) -> bool:
"""检查是否有足够的存储空间"""
return (self.storage_used + required_size) <= self.storage_quota
def get_storage_percentage(self) -> float:
"""获取已使用存储空间的百分比"""
if self.storage_quota == 0:
return 0.0
return (self.storage_used / self.storage_quota) * 100

View File

@@ -0,0 +1,2 @@
from .auth import *
from .file import *

165
backend/app/schemas/auth.py Normal file
View File

@@ -0,0 +1,165 @@
from pydantic import BaseModel, EmailStr, Field, validator
from typing import Optional
from datetime import datetime
import re
class UserRegister(BaseModel):
"""用户注册请求模型"""
username: str = Field(..., min_length=3, max_length=50, description="用户名")
email: EmailStr = Field(..., description="邮箱地址")
password: str = Field(..., min_length=6, max_length=128, description="密码")
confirm_password: str = Field(..., min_length=6, max_length=128, description="确认密码")
@validator('username')
def validate_username(cls, v):
"""验证用户名格式"""
if not re.match(r'^[a-zA-Z0-9_]+$', v):
raise ValueError('用户名只能包含字母、数字和下划线')
if v.startswith('_') or v.endswith('_'):
raise ValueError('用户名不能以下划线开头或结尾')
return v
@validator('confirm_password')
def passwords_match(cls, v, values):
"""验证密码确认"""
if 'password' in values and v != values['password']:
raise ValueError('两次输入的密码不一致')
return v
@validator('password')
def validate_password_length(cls, v):
"""验证密码长度"""
if len(v) <= 5:
raise ValueError('密码长度必须大于5个字符')
return v
class UserLogin(BaseModel):
"""用户登录请求模型"""
username: str = Field(..., description="用户名或邮箱")
password: str = Field(..., min_length=1, description="密码")
class TokenResponse(BaseModel):
"""令牌响应模型"""
access_token: str = Field(..., description="访问令牌")
refresh_token: str = Field(..., description="刷新令牌")
token_type: str = Field(default="bearer", description="令牌类型")
expires_in: int = Field(..., description="访问令牌过期时间(秒)")
class UserResponse(BaseModel):
"""用户信息响应模型"""
id: int
username: str
email: str
avatar_url: Optional[str] = None
storage_quota: int
storage_used: int
is_active: bool
is_verified: bool
last_login_at: Optional[datetime] = None
created_at: datetime
class Config:
from_attributes = True
class LoginResponse(BaseModel):
"""登录响应模型"""
user: UserResponse = Field(..., description="用户信息")
tokens: TokenResponse = Field(..., description="令牌信息")
class TokenRefresh(BaseModel):
"""令牌刷新请求模型"""
refresh_token: str = Field(..., description="刷新令牌")
class PasswordChange(BaseModel):
"""修改密码请求模型"""
current_password: str = Field(..., description="当前密码")
new_password: str = Field(..., min_length=8, max_length=128, description="新密码")
confirm_password: str = Field(..., min_length=8, max_length=128, description="确认新密码")
@validator('confirm_password')
def passwords_match(cls, v, values):
"""验证密码确认"""
if 'new_password' in values and v != values['new_password']:
raise ValueError('密码确认不匹配')
return v
@validator('new_password')
def validate_password_strength(cls, v):
"""验证密码强度"""
errors = []
if len(v) < 8:
errors.append("密码长度至少8位")
if not any(c.islower() for c in v):
errors.append("密码必须包含至少一个小写字母")
if not any(c.isupper() for c in v):
errors.append("密码必须包含至少一个大写字母")
if not any(c.isdigit() for c in v):
errors.append("密码必须包含至少一个数字")
special_chars = "!@#$%^&*()_+-=[]{}|;:,.<>?"
if not any(c in special_chars for c in v):
errors.append("密码必须包含至少一个特殊字符")
if errors:
raise ValueError('; '.join(errors))
return v
class PasswordReset(BaseModel):
"""密码重置请求模型"""
email: EmailStr = Field(..., description="邮箱地址")
class PasswordResetConfirm(BaseModel):
"""密码重置确认模型"""
token: str = Field(..., description="重置令牌")
new_password: str = Field(..., min_length=8, max_length=128, description="新密码")
confirm_password: str = Field(..., min_length=8, max_length=128, description="确认新密码")
@validator('confirm_password')
def passwords_match(cls, v, values):
"""验证密码确认"""
if 'new_password' in values and v != values['new_password']:
raise ValueError('密码确认不匹配')
return v
@validator('new_password')
def validate_password_strength(cls, v):
"""验证密码强度"""
errors = []
if len(v) < 8:
errors.append("密码长度至少8位")
if not any(c.islower() for c in v):
errors.append("密码必须包含至少一个小写字母")
if not any(c.isupper() for c in v):
errors.append("密码必须包含至少一个大写字母")
if not any(c.isdigit() for c in v):
errors.append("密码必须包含至少一个数字")
special_chars = "!@#$%^&*()_+-=[]{}|;:,.<>?"
if not any(c in special_chars for c in v):
errors.append("密码必须包含至少一个特殊字符")
if errors:
raise ValueError('; '.join(errors))
return v
class ApiResponse(BaseModel):
"""标准API响应模型"""
success: bool = Field(..., description="操作是否成功")
message: str = Field(..., description="响应消息")
data: Optional[dict] = Field(None, description="响应数据")
error: Optional[dict] = Field(None, description="错误信息")
class Config:
json_encoders = {
datetime: lambda v: v.isoformat() if v else None
}

148
backend/app/schemas/file.py Normal file
View File

@@ -0,0 +1,148 @@
from pydantic import BaseModel, Field, validator
from typing import Optional, List
from datetime import datetime
import re
class FileUploadRequest(BaseModel):
"""文件上传请求"""
description: Optional[str] = Field(None, max_length=500, description="文件描述")
tags: Optional[str] = Field(None, max_length=200, description="文件标签,用逗号分隔")
is_public: bool = Field(False, description="是否公开分享")
@validator('tags')
def validate_tags(cls, v):
if v:
# 验证标签格式,只允许字母、数字、中文、下划线、中划线
tags = v.split(',')
for tag in tags:
tag = tag.strip()
if not re.match(r'^[\w\u4e00-\u9fa5-]+$', tag):
raise ValueError(f"标签 '{tag}' 格式不正确,只允许字母、数字、中文、下划线、中划线")
if len(tag) > 20:
raise ValueError(f"标签 '{tag}' 长度不能超过20个字符")
return v
class FileResponse(BaseModel):
"""文件响应"""
id: int
user_id: int
filename: str
original_filename: str
file_size: int
mime_type: str
file_hash: str
is_public: bool
download_count: int
description: Optional[str] = None
tags: Optional[str] = None
created_at: datetime
updated_at: datetime
last_accessed_at: Optional[datetime] = None
class Config:
from_attributes = True
class FileListResponse(BaseModel):
"""文件列表响应"""
files: List[FileResponse]
total: int
page: int
size: int
pages: int
class Config:
from_attributes = True
class FileListRequest(BaseModel):
"""文件列表请求"""
user_id: int = Field(..., description="用户ID")
page: int = Field(1, ge=1, description="页码")
size: int = Field(20, ge=1, le=100, description="每页数量")
class FileIdRequest(BaseModel):
"""文件ID请求"""
user_id: int = Field(..., description="用户ID")
file_id: int = Field(..., description="文件ID")
class StorageInfoRequest(BaseModel):
"""存储信息请求"""
user_id: int = Field(..., description="用户ID")
class FileUpdateRequest(BaseModel):
"""文件更新请求"""
description: Optional[str] = Field(None, max_length=500, description="文件描述")
tags: Optional[str] = Field(None, max_length=200, description="文件标签,用逗号分隔")
is_public: Optional[bool] = Field(None, description="是否公开分享")
@validator('tags')
def validate_tags(cls, v):
if v:
tags = v.split(',')
for tag in tags:
tag = tag.strip()
if not re.match(r'^[\w\u4e00-\u9fa5-]+$', tag):
raise ValueError(f"标签 '{tag}' 格式不正确,只允许字母、数字、中文、下划线、中划线")
if len(tag) > 20:
raise ValueError(f"标签 '{tag}' 长度不能超过20个字符")
return v
class FileSearchRequest(BaseModel):
"""文件搜索请求"""
filename: Optional[str] = Field(None, description="文件名搜索")
tags: Optional[str] = Field(None, description="标签搜索,用逗号分隔")
mime_type: Optional[str] = Field(None, description="MIME类型过滤")
is_public: Optional[bool] = Field(None, description="是否公开文件")
start_date: Optional[datetime] = Field(None, description="开始日期")
end_date: Optional[datetime] = Field(None, description="结束日期")
min_size: Optional[int] = Field(None, ge=0, description="最小文件大小(字节)")
max_size: Optional[int] = Field(None, ge=0, description="最大文件大小(字节)")
class FileInfo(BaseModel):
"""文件信息"""
id: int
filename: str
original_filename: str
file_size: int
mime_type: str
file_hash: str
is_image: bool
is_document: bool
file_extension: str
size_formatted: str
is_public: bool = False
download_count: int = 0
description: Optional[str] = None
tags: Optional[str] = None
created_at: datetime
updated_at: datetime
last_accessed_at: Optional[datetime] = None
class Config:
from_attributes = True
class UploadResponse(BaseModel):
"""文件上传响应"""
file_info: FileResponse
message: str
success: bool
class DeleteResponse(BaseModel):
"""删除文件响应"""
message: str
success: bool
class StorageInfo(BaseModel):
"""存储信息"""
total_quota: int
used_space: int
available_space: int
usage_percentage: float
file_count: int
# 通用API响应格式
class ApiResponse(BaseModel):
"""API响应"""
success: bool
message: str
data: Optional[dict] = None
code: Optional[str] = None

View File

View File

@@ -0,0 +1,418 @@
import os
import hashlib
import uuid
import logging
from typing import Optional, List, Tuple
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_, desc
from fastapi import UploadFile, HTTPException, status
from datetime import datetime
from app.models.file import File
from app.models.user import User
from app.schemas.file import (
FileUploadRequest, FileUpdateRequest, FileSearchRequest,
FileResponse, FileListResponse, StorageInfo, FileInfo
)
from app.core.config import settings
from app.exceptions.file import (
FileTooLargeException, StorageQuotaExceededException,
FileAlreadyExistsException, FileNotFoundException,
InvalidFileTypeException
)
class FileService:
def __init__(self, db: Session):
self.db = db
# 使用绝对路径确保文件保存正确
self.upload_dir = os.path.abspath(settings.UPLOAD_DIR)
self.max_file_size = settings.MAX_FILE_SIZE
self.allowed_extensions = settings.ALLOWED_EXTENSIONS
self.logger = logging.getLogger(__name__)
# 确保上传目录存在
os.makedirs(self.upload_dir, exist_ok=True)
self.logger.info(f"Upload directory (absolute): {self.upload_dir}")
self.logger.info(f"Upload directory exists: {os.path.exists(self.upload_dir)}")
self.logger.info(f"Current working directory: {os.getcwd()}")
def _calculate_file_hash(self, file_content: bytes) -> str:
"""计算文件的SHA-256哈希值"""
return hashlib.sha256(file_content).hexdigest()
def _generate_unique_filename(self, original_filename: str) -> str:
"""生成唯一的文件名"""
file_extension = os.path.splitext(original_filename)[1]
unique_id = str(uuid.uuid4())
return f"{unique_id}{file_extension}"
def _validate_file(self, file: UploadFile, user: User) -> None:
"""验证文件"""
# 检查文件大小
if hasattr(file, 'size') and file.size > self.max_file_size:
raise FileTooLargeException(file.size, self.max_file_size)
# 检查文件扩展名
if self.allowed_extensions:
file_extension = os.path.splitext(file.filename)[1].lower()
if file_extension not in self.allowed_extensions:
raise InvalidFileTypeException(file_extension)
# 检查存储配额
if hasattr(file, 'size') and not user.is_storage_available(file.size):
raise StorageQuotaExceededException(user.storage_used, user.storage_quota, file.size)
def _save_file_to_disk(self, file: UploadFile, unique_filename: str) -> Tuple[str, bytes]:
"""保存文件到磁盘"""
file_path = os.path.join(self.upload_dir, unique_filename)
# 读取文件内容
file_content = file.file.read()
# 强制输出调试信息
import sys
message = f"[CRITICAL] About to save {len(file_content)} bytes to {file_path}"
print(message, flush=True)
sys.stdout.flush()
message2 = f"[CRITICAL] Content preview: {file_content[:50] if file_content else 'EMPTY'}"
print(message2, flush=True)
sys.stdout.flush()
# 立即验证内容是否为空
if not file_content:
print("[CRITICAL] FILE CONTENT IS EMPTY!")
raise ValueError("File content is empty!")
# 使用临时文件方法确保写入成功
temp_path = file_path + '.tmp'
try:
# 先写入临时文件
with open(temp_path, "wb") as temp_file:
temp_file.write(file_content)
temp_file.flush()
os.fsync(temp_file.fileno())
# 验证临时文件
if os.path.exists(temp_path):
temp_size = os.path.getsize(temp_path)
if temp_size != len(file_content):
raise Exception(f"Temporary file size mismatch: {temp_size} != {len(file_content)}")
# 重命名为最终文件名(原子操作)
if os.name == 'nt': # Windows
if os.path.exists(file_path):
os.remove(file_path)
os.rename(temp_path, file_path)
# 最终验证
if not os.path.exists(file_path):
raise Exception("File was not created after rename")
final_size = os.path.getsize(file_path)
if final_size != len(file_content):
raise Exception(f"Final file size mismatch: {final_size} != {len(file_content)}")
except Exception as e:
# 清理临时文件
if os.path.exists(temp_path):
os.remove(temp_path)
print(f"[ERROR] File save failed: {e}")
raise Exception(f"Failed to save file: {e}")
# 重置文件指针
file.file.seek(0)
return file_path, file_content
def upload_file(self, file: UploadFile, user: User, upload_request: FileUploadRequest) -> File:
"""上传文件"""
try:
# 验证文件
self._validate_file(file, user)
# 生成唯一文件名
unique_filename = self._generate_unique_filename(file.filename)
# 保存文件到磁盘
file_path, file_content = self._save_file_to_disk(file, unique_filename)
file_size = len(file_content)
# 再次检查文件大小如果没有size属性
if file_size > self.max_file_size:
# 删除已保存的文件
if os.path.exists(file_path):
os.remove(file_path)
raise FileTooLargeException(file_size, self.max_file_size)
# 检查存储配额
if not user.is_storage_available(file_size):
# 删除已保存的文件
if os.path.exists(file_path):
os.remove(file_path)
raise StorageQuotaExceededException(user.storage_used, user.storage_quota, file_size)
# 计算文件哈希
file_hash = self._calculate_file_hash(file_content)
# 检查文件是否已存在(基于哈希值)
existing_file = self.db.query(File).filter(
and_(
File.user_id == user.id,
File.file_hash == file_hash
)
).first()
if existing_file:
# 删除刚保存的文件,因为已存在相同内容的文件
if os.path.exists(file_path):
os.remove(file_path)
raise FileAlreadyExistsException(existing_file.original_filename)
# 创建文件记录
db_file = File(
user_id=user.id,
filename=unique_filename,
original_filename=file.filename,
file_path=file_path,
file_size=file_size,
mime_type=file.content_type or 'application/octet-stream',
file_hash=file_hash,
description=upload_request.description,
tags=upload_request.tags,
is_public=upload_request.is_public
)
self.db.add(db_file)
# 更新用户存储使用量
user.storage_used += file_size
self.db.commit()
self.db.refresh(db_file)
return db_file
except Exception as e:
self.db.rollback()
# 如果保存了文件但数据库操作失败,删除文件
if 'file_path' in locals() and os.path.exists(file_path):
os.remove(file_path)
raise e
def get_user_files(self, user_id: int, page: int = 1, size: int = 20) -> FileListResponse:
"""获取用户的文件列表"""
offset = (page - 1) * size
query = self.db.query(File).filter(File.user_id == user_id)
total = query.count()
files = query.order_by(desc(File.created_at)).offset(offset).limit(size).all()
pages = (total + size - 1) // size
return FileListResponse(
files=[FileResponse(
id=file.id,
user_id=file.user_id,
filename=file.filename,
original_filename=file.original_filename,
file_size=file.file_size,
mime_type=file.mime_type,
file_hash=file.file_hash,
is_public=file.is_public,
download_count=file.download_count,
description=file.description,
tags=file.tags,
created_at=file.created_at,
updated_at=file.updated_at,
last_accessed_at=file.last_accessed_at
) for file in files],
total=total,
page=page,
size=size,
pages=pages
)
def get_file_by_id(self, file_id: int, user_id: int) -> Optional[File]:
"""根据ID获取文件"""
return self.db.query(File).filter(
and_(
File.id == file_id,
File.user_id == user_id
)
).first()
def update_file(self, file_id: int, user_id: int, update_request: FileUpdateRequest) -> Optional[File]:
"""更新文件信息"""
db_file = self.get_file_by_id(file_id, user_id)
if not db_file:
raise FileNotFoundException()
# 更新字段
if update_request.description is not None:
db_file.description = update_request.description
if update_request.tags is not None:
db_file.tags = update_request.tags
if update_request.is_public is not None:
db_file.is_public = update_request.is_public
db_file.updated_at = datetime.utcnow()
self.db.commit()
self.db.refresh(db_file)
return db_file
def delete_file(self, file_id: int, user_id: int) -> bool:
"""删除文件"""
db_file = self.get_file_by_id(file_id, user_id)
if not db_file:
raise FileNotFoundException()
try:
# 删除磁盘上的文件
if os.path.exists(db_file.file_path):
os.remove(db_file.file_path)
# 更新用户存储使用量
user = self.db.query(User).filter(User.id == user_id).first()
if user:
user.storage_used = max(0, user.storage_used - db_file.file_size)
# 删除数据库记录
self.db.delete(db_file)
self.db.commit()
return True
except Exception as e:
self.db.rollback()
raise e
def search_files(self, user_id: int, search_request: FileSearchRequest,
page: int = 1, size: int = 20) -> FileListResponse:
"""搜索文件"""
offset = (page - 1) * size
query = self.db.query(File).filter(File.user_id == user_id)
# 文件名搜索
if search_request.filename:
query = query.filter(
File.original_filename.ilike(f"%{search_request.filename}%")
)
# 标签搜索
if search_request.tags:
tag_list = [tag.strip() for tag in search_request.tags.split(',')]
tag_conditions = []
for tag in tag_list:
tag_conditions.append(File.tags.ilike(f"%{tag}%"))
if tag_conditions:
query = query.filter(or_(*tag_conditions))
# MIME类型过滤
if search_request.mime_type:
query = query.filter(File.mime_type == search_request.mime_type)
# 公开状态过滤
if search_request.is_public is not None:
query = query.filter(File.is_public == search_request.is_public)
# 日期范围过滤
if search_request.start_date:
query = query.filter(File.created_at >= search_request.start_date)
if search_request.end_date:
query = query.filter(File.created_at <= search_request.end_date)
# 文件大小范围过滤
if search_request.min_size:
query = query.filter(File.file_size >= search_request.min_size)
if search_request.max_size:
query = query.filter(File.file_size <= search_request.max_size)
total = query.count()
files = query.order_by(desc(File.created_at)).offset(offset).limit(size).all()
pages = (total + size - 1) // size
return FileListResponse(
files=[FileResponse(
id=file.id,
user_id=file.user_id,
filename=file.filename,
original_filename=file.original_filename,
file_size=file.file_size,
mime_type=file.mime_type,
file_hash=file.file_hash,
is_public=file.is_public,
download_count=file.download_count,
description=file.description,
tags=file.tags,
created_at=file.created_at,
updated_at=file.updated_at,
last_accessed_at=file.last_accessed_at
) for file in files],
total=total,
page=page,
size=size,
pages=pages
)
def get_file_info(self, file_id: int, user_id: int) -> FileInfo:
"""获取文件详细信息"""
db_file = self.get_file_by_id(file_id, user_id)
if not db_file:
raise FileNotFoundException()
# 更新最后访问时间
db_file.last_accessed_at = datetime.utcnow()
self.db.commit()
return FileInfo(
id=db_file.id,
filename=db_file.filename,
original_filename=db_file.original_filename,
file_size=db_file.file_size,
mime_type=db_file.mime_type,
file_hash=db_file.file_hash,
is_image=db_file.is_image(),
is_document=db_file.is_document(),
file_extension=db_file.get_file_extension(),
size_formatted=db_file.get_size_formatted(),
is_public=db_file.is_public,
download_count=db_file.download_count,
description=db_file.description,
tags=db_file.tags,
created_at=db_file.created_at,
updated_at=db_file.updated_at,
last_accessed_at=db_file.last_accessed_at
)
def get_storage_info(self, user_id: int) -> StorageInfo:
"""获取用户存储信息"""
user = self.db.query(User).filter(User.id == user_id).first()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="用户不存在"
)
file_count = self.db.query(File).filter(File.user_id == user_id).count()
available_space = user.storage_quota - user.storage_used
usage_percentage = user.get_storage_percentage()
return StorageInfo(
total_quota=user.storage_quota,
used_space=user.storage_used,
available_space=available_space,
usage_percentage=usage_percentage,
file_count=file_count
)
def increment_download_count(self, file_id: int) -> None:
"""增加文件下载次数"""
db_file = self.db.query(File).filter(File.id == file_id).first()
if db_file:
db_file.download_count += 1
db_file.last_accessed_at = datetime.utcnow()
self.db.commit()

View File

@@ -0,0 +1,252 @@
from sqlalchemy.orm import Session
from sqlalchemy.exc import IntegrityError
from fastapi import HTTPException, status
from typing import Optional
from datetime import datetime, timedelta
from app.models.user import User
from app.core.security import get_password_hash, verify_password, create_access_token, create_refresh_token
from app.schemas.auth import UserRegister, UserLogin, UserResponse
from app.core.config import settings
from app.exceptions.auth import UsernameAlreadyExistsException
class UserService:
"""用户服务类"""
def __init__(self, db: Session):
self.db = db
def create_user(self, user_data: UserRegister) -> User:
"""创建新用户"""
try:
# 检查用户名是否已存在
print(f"[DEBUG] Checking if username exists: {user_data.username}")
existing_user_by_username = self.get_user_by_username(user_data.username)
if existing_user_by_username:
print(f"[DEBUG] Username already exists: {existing_user_by_username}")
raise UsernameAlreadyExistsException()
print(f"[DEBUG] Username check passed")
# 邮箱允许重复,不再检查邮箱是否已存在
print(f"[DEBUG] Email uniqueness check skipped (emails can be duplicated)")
# 创建新用户
print(f"[DEBUG] About to hash password...")
try:
hashed_password = get_password_hash(user_data.password)
print(f"[DEBUG] Password hashed successfully")
except Exception as hash_error:
print(f"[ERROR] Password hashing failed: {hash_error}")
print(f"[ERROR] Error type: {type(hash_error)}")
import traceback
traceback.print_exc()
raise
print(f"[DEBUG] Creating User model...")
db_user = User(
username=user_data.username,
email=user_data.email,
password_hash=hashed_password,
is_active=True,
is_verified=False
)
print(f"[DEBUG] User model created")
self.db.add(db_user)
self.db.commit()
self.db.refresh(db_user)
return db_user
except IntegrityError as e:
self.db.rollback()
if "username" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"code": "USERNAME_EXISTS", "message": "用户名已存在"}
)
elif "email" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"code": "EMAIL_EXISTS", "message": "邮箱已被注册"}
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"code": "INTEGRITY_ERROR", "message": "数据完整性错误"}
)
except (HTTPException, UsernameAlreadyExistsException):
# 重新抛出HTTPException不要覆盖自定义错误信息
self.db.rollback()
raise
except Exception as e:
self.db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"code": "CREATION_FAILED", "message": "用户创建失败"}
)
def authenticate_user(self, login_data: UserLogin) -> Optional[User]:
"""验证用户登录"""
# 尝试通过用户名查找用户
user = self.get_user_by_username(login_data.username)
# 如果用户名找不到,尝试通过邮箱查找
if not user:
user = self.get_user_by_email(login_data.username)
# 如果找到用户且密码正确
if user and verify_password(login_data.password, user.password_hash):
# 更新最后登录时间
user.last_login_at = datetime.utcnow()
self.db.commit()
return user
return None
def get_user_by_id(self, user_id: int) -> Optional[User]:
"""根据ID获取用户"""
return self.db.query(User).filter(User.id == user_id, User.is_active == True).first()
def get_user_by_username(self, username: str) -> Optional[User]:
"""根据用户名获取用户"""
return self.db.query(User).filter(User.username == username, User.is_active == True).first()
def get_user_by_email(self, email: str) -> Optional[User]:
"""根据邮箱获取用户"""
return self.db.query(User).filter(User.email == email, User.is_active == True).first()
def update_user(self, user_id: int, **kwargs) -> Optional[User]:
"""更新用户信息"""
user = self.get_user_by_id(user_id)
if not user:
return None
try:
for key, value in kwargs.items():
if hasattr(user, key) and value is not None:
setattr(user, key, value)
self.db.commit()
self.db.refresh(user)
return user
except Exception as e:
self.db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"code": "UPDATE_FAILED", "message": "用户信息更新失败"}
)
def change_password(self, user_id: int, current_password: str, new_password: str) -> bool:
"""修改用户密码"""
user = self.get_user_by_id(user_id)
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={"code": "USER_NOT_FOUND", "message": "用户不存在"}
)
# 验证当前密码
if not verify_password(current_password, user.password_hash):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"code": "INVALID_PASSWORD", "message": "当前密码不正确"}
)
try:
# 更新密码
user.password_hash = get_password_hash(new_password)
self.db.commit()
return True
except Exception as e:
self.db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"code": "PASSWORD_CHANGE_FAILED", "message": "密码修改失败"}
)
def deactivate_user(self, user_id: int) -> bool:
"""停用用户"""
user = self.get_user_by_id(user_id)
if not user:
return False
try:
user.is_active = False
self.db.commit()
return True
except Exception as e:
self.db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"code": "DEACTIVATION_FAILED", "message": "用户停用失败"}
)
def update_storage_usage(self, user_id: int, size_change: int) -> bool:
"""更新用户存储使用量"""
user = self.get_user_by_id(user_id)
if not user:
return False
try:
new_usage = user.storage_used + size_change
# 检查存储配额
if new_usage > user.storage_quota:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"code": "STORAGE_EXCEEDED", "message": "存储空间不足"}
)
user.storage_used = max(0, new_usage) # 确保不小于0
self.db.commit()
return True
except HTTPException:
raise
except Exception as e:
self.db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"code": "STORAGE_UPDATE_FAILED", "message": "存储空间更新失败"}
)
def create_user_tokens(self, user: User) -> dict:
"""为用户创建访问令牌和刷新令牌"""
access_token_expires = timedelta(minutes=settings.JWT_EXPIRE_MINUTES)
refresh_token_expires = timedelta(days=settings.JWT_REFRESH_EXPIRE_DAYS)
access_token = create_access_token(
data={"sub": str(user.id), "username": user.username, "email": user.email},
expires_delta=access_token_expires
)
refresh_token = create_refresh_token(
data={"sub": str(user.id)},
expires_delta=refresh_token_expires
)
return {
"access_token": access_token,
"refresh_token": refresh_token,
"token_type": "bearer",
"expires_in": settings.JWT_EXPIRE_MINUTES * 60
}
def to_user_response(self, user: User) -> UserResponse:
"""将用户模型转换为响应模型"""
return UserResponse(
id=user.id,
username=user.username,
email=user.email,
avatar_url=user.avatar_url,
storage_quota=user.storage_quota,
storage_used=user.storage_used,
is_active=user.is_active,
is_verified=user.is_verified,
last_login_at=user.last_login_at,
created_at=user.created_at
)