Files
full-stack-doc/backend/app/services/file_service.py
2025-10-14 20:05:29 +08:00

418 lines
15 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import hashlib
import uuid
import logging
from typing import Optional, List, Tuple
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_, desc
from fastapi import UploadFile, HTTPException, status
from datetime import datetime
from app.models.file import File
from app.models.user import User
from app.schemas.file import (
FileUploadRequest, FileUpdateRequest, FileSearchRequest,
FileResponse, FileListResponse, StorageInfo, FileInfo
)
from app.core.config import settings
from app.exceptions.file import (
FileTooLargeException, StorageQuotaExceededException,
FileAlreadyExistsException, FileNotFoundException,
InvalidFileTypeException
)
class FileService:
    """Upload, list, search, update and delete user files.

    File contents are stored under ``settings.UPLOAD_DIR``; metadata lives in
    the ``File`` table and per-user usage is tracked on ``User.storage_used``.
    """

    def __init__(self, db: Session):
        """Bind the service to a database session and prepare the upload dir.

        Args:
            db: SQLAlchemy session used for every query and commit.
        """
        self.db = db
        # Resolve to an absolute path so later writes do not depend on the
        # process working directory.
        self.upload_dir = os.path.abspath(settings.UPLOAD_DIR)
        self.max_file_size = settings.MAX_FILE_SIZE
        self.allowed_extensions = settings.ALLOWED_EXTENSIONS
        self.logger = logging.getLogger(__name__)

        os.makedirs(self.upload_dir, exist_ok=True)
        self.logger.debug(
            "Upload directory: %s (cwd: %s)", self.upload_dir, os.getcwd()
        )

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _calculate_file_hash(self, file_content: bytes) -> str:
        """Return the hexadecimal SHA-256 digest of ``file_content``."""
        return hashlib.sha256(file_content).hexdigest()

    def _generate_unique_filename(self, original_filename: str) -> str:
        """Return ``<uuid4><ext>``, keeping the extension of the original name.

        A missing (``None``/empty) original name yields a bare UUID.
        """
        file_extension = os.path.splitext(original_filename or "")[1]
        return f"{uuid.uuid4()}{file_extension}"

    def _remove_quietly(self, path: str) -> None:
        """Delete ``path`` if it exists; log instead of raising on failure.

        Used on cleanup paths where a secondary error must not hide the
        primary one.
        """
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
        except OSError as exc:
            self.logger.warning("Could not remove %s: %s", path, exc)

    @staticmethod
    def _split_tags(raw_tags: str) -> List[str]:
        """Split a comma-separated tag string, dropping blank entries."""
        return [tag.strip() for tag in raw_tags.split(',') if tag.strip()]

    @staticmethod
    def _page_window(page: int, size: int) -> Tuple[int, int, int]:
        """Clamp paging input and return ``(page, size, offset)``.

        ``page`` and ``size`` are raised to at least 1 so the offset is never
        negative and the page-count division never divides by zero.
        """
        page = max(page, 1)
        size = max(size, 1)
        return page, size, (page - 1) * size

    def _to_file_response(self, file: File) -> FileResponse:
        """Copy the listed columns of a ``File`` row into a ``FileResponse``."""
        return FileResponse(
            id=file.id,
            user_id=file.user_id,
            filename=file.filename,
            original_filename=file.original_filename,
            file_size=file.file_size,
            mime_type=file.mime_type,
            file_hash=file.file_hash,
            is_public=file.is_public,
            download_count=file.download_count,
            description=file.description,
            tags=file.tags,
            created_at=file.created_at,
            updated_at=file.updated_at,
            last_accessed_at=file.last_accessed_at,
        )

    def _paginate(self, query, page: int, size: int) -> FileListResponse:
        """Run ``query`` newest-first for one page and wrap the result."""
        page, size, offset = self._page_window(page, size)
        total = query.count()
        files = (
            query.order_by(desc(File.created_at))
            .offset(offset)
            .limit(size)
            .all()
        )
        return FileListResponse(
            files=[self._to_file_response(file) for file in files],
            total=total,
            page=page,
            size=size,
            pages=(total + size - 1) // size,  # ceiling division
        )

    def _validate_file(self, file: UploadFile, user: User) -> None:
        """Check size limit, extension allow-list and storage quota.

        Size-based checks are skipped when the upload does not report a size
        (attribute missing or ``None``); ``upload_file`` repeats them once
        the content length is known.

        Raises:
            FileTooLargeException: reported size exceeds ``max_file_size``.
            InvalidFileTypeException: extension is not in the allow-list.
            StorageQuotaExceededException: the user lacks free quota.
        """
        size = getattr(file, 'size', None)

        if size is not None and size > self.max_file_size:
            raise FileTooLargeException(size, self.max_file_size)

        if self.allowed_extensions:
            file_extension = os.path.splitext(file.filename or "")[1].lower()
            if file_extension not in self.allowed_extensions:
                raise InvalidFileTypeException(file_extension)

        if size is not None and not user.is_storage_available(size):
            raise StorageQuotaExceededException(
                user.storage_used, user.storage_quota, size
            )

    def _save_file_to_disk(self, file: UploadFile, unique_filename: str) -> Tuple[str, bytes]:
        """Write the upload to ``upload_dir`` and return ``(path, content)``.

        The content is written to ``<path>.tmp``, flushed and fsynced, size
        checked, then moved into place with ``os.replace`` so a partially
        written file is never visible under the final name.

        Raises:
            ValueError: the upload is empty.
            Exception: any write/verify failure, re-raised as
                ``"Failed to save file: ..."`` with the cause chained.
        """
        file_path = os.path.join(self.upload_dir, unique_filename)
        file_content = file.file.read()

        if not file_content:
            raise ValueError("File content is empty!")

        expected_size = len(file_content)
        # Log the size only; the content itself is user data.
        self.logger.debug("Saving %d bytes to %s", expected_size, file_path)

        temp_path = file_path + '.tmp'
        try:
            with open(temp_path, "wb") as temp_file:
                temp_file.write(file_content)
                temp_file.flush()
                os.fsync(temp_file.fileno())

            temp_size = os.path.getsize(temp_path)
            if temp_size != expected_size:
                raise OSError(
                    f"Temporary file size mismatch: {temp_size} != {expected_size}"
                )

            # os.replace overwrites an existing target on every platform, so
            # no separate Windows remove-then-rename step is needed.
            os.replace(temp_path, file_path)

            final_size = os.path.getsize(file_path)
            if final_size != expected_size:
                raise OSError(
                    f"Final file size mismatch: {final_size} != {expected_size}"
                )
        except Exception as e:
            # Leave nothing behind, whichever step failed.
            self._remove_quietly(temp_path)
            self._remove_quietly(file_path)
            self.logger.error("File save failed: %s", e)
            raise Exception(f"Failed to save file: {e}") from e

        # Rewind so the caller can read the stream again.
        file.file.seek(0)
        return file_path, file_content

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def upload_file(self, file: UploadFile, user: User, upload_request: FileUploadRequest) -> File:
        """Validate, store and register an uploaded file.

        Args:
            file: the incoming upload.
            user: owner of the file; ``storage_used`` is increased on success.
            upload_request: description, tags and visibility for the record.

        Returns:
            The persisted ``File`` row.

        Raises:
            FileTooLargeException, InvalidFileTypeException,
            StorageQuotaExceededException: validation failures.
            FileAlreadyExistsException: the user already owns a file with
                the same SHA-256 hash.
            Exception: disk or database errors, re-raised after rollback.
        """
        file_path = None
        committed = False
        try:
            self._validate_file(file, user)
            unique_filename = self._generate_unique_filename(file.filename)
            file_path, file_content = self._save_file_to_disk(file, unique_filename)
            file_size = len(file_content)

            # Re-check with the real length: the upload may not have
            # reported a size during validation.
            if file_size > self.max_file_size:
                raise FileTooLargeException(file_size, self.max_file_size)
            if not user.is_storage_available(file_size):
                raise StorageQuotaExceededException(
                    user.storage_used, user.storage_quota, file_size
                )

            file_hash = self._calculate_file_hash(file_content)

            # Per-user de-duplication by content hash.
            existing_file = self.db.query(File).filter(
                and_(
                    File.user_id == user.id,
                    File.file_hash == file_hash
                )
            ).first()
            if existing_file:
                raise FileAlreadyExistsException(existing_file.original_filename)

            db_file = File(
                user_id=user.id,
                filename=unique_filename,
                original_filename=file.filename,
                file_path=file_path,
                file_size=file_size,
                mime_type=file.content_type or 'application/octet-stream',
                file_hash=file_hash,
                description=upload_request.description,
                tags=upload_request.tags,
                is_public=upload_request.is_public
            )
            self.db.add(db_file)
            user.storage_used += file_size
            self.db.commit()
            committed = True
            self.db.refresh(db_file)
            return db_file
        except Exception:
            self.db.rollback()
            # Remove the stored file only while no committed record refers
            # to it (every failure before the commit lands here).
            if file_path is not None and not committed:
                self._remove_quietly(file_path)
            raise

    def get_user_files(self, user_id: int, page: int = 1, size: int = 20) -> FileListResponse:
        """Return one page of the user's files, newest first.

        ``page`` and ``size`` below 1 are treated as 1.
        """
        query = self.db.query(File).filter(File.user_id == user_id)
        return self._paginate(query, page, size)

    def get_file_by_id(self, file_id: int, user_id: int) -> Optional[File]:
        """Return the file with ``file_id`` owned by ``user_id``, or ``None``."""
        return self.db.query(File).filter(
            and_(
                File.id == file_id,
                File.user_id == user_id
            )
        ).first()

    def update_file(self, file_id: int, user_id: int, update_request: FileUpdateRequest) -> Optional[File]:
        """Apply the non-``None`` fields of ``update_request`` to a file.

        Raises:
            FileNotFoundException: no such file for this user.
        """
        db_file = self.get_file_by_id(file_id, user_id)
        if not db_file:
            raise FileNotFoundException()

        if update_request.description is not None:
            db_file.description = update_request.description
        if update_request.tags is not None:
            db_file.tags = update_request.tags
        if update_request.is_public is not None:
            db_file.is_public = update_request.is_public

        db_file.updated_at = datetime.utcnow()
        self.db.commit()
        self.db.refresh(db_file)
        return db_file

    def delete_file(self, file_id: int, user_id: int) -> bool:
        """Delete a file record, release its quota, then remove it from disk.

        The database change is committed first: if the commit fails the file
        on disk is still intact and nothing is lost. Removing the disk file
        afterwards is best-effort; a failure is logged, not raised.

        Returns:
            ``True`` once the record has been deleted.

        Raises:
            FileNotFoundException: no such file for this user.
            Exception: database errors, re-raised after rollback.
        """
        db_file = self.get_file_by_id(file_id, user_id)
        if not db_file:
            raise FileNotFoundException()

        # Read the path now; the instance is gone after the commit.
        file_path = db_file.file_path
        try:
            user = self.db.query(User).filter(User.id == user_id).first()
            if user:
                user.storage_used = max(0, user.storage_used - db_file.file_size)
            self.db.delete(db_file)
            self.db.commit()
        except Exception:
            self.db.rollback()
            raise

        self._remove_quietly(file_path)
        return True

    def search_files(self, user_id: int, search_request: FileSearchRequest,
                     page: int = 1, size: int = 20) -> FileListResponse:
        """Return one page of the user's files matching ``search_request``.

        Supported filters: file-name substring, any-of comma-separated tags,
        exact MIME type, public flag, creation-date range and size range.
        ``page`` and ``size`` below 1 are treated as 1.
        """
        query = self.db.query(File).filter(File.user_id == user_id)

        # NOTE(review): '%' and '_' in user input act as LIKE wildcards in
        # the two ilike filters below - confirm whether that is intended.
        if search_request.filename:
            query = query.filter(
                File.original_filename.ilike(f"%{search_request.filename}%")
            )

        if search_request.tags:
            # Blank fragments (e.g. from a trailing comma) are dropped; an
            # empty tag would become '%%' and match every tagged file.
            tags = self._split_tags(search_request.tags)
            if tags:
                query = query.filter(
                    or_(*(File.tags.ilike(f"%{tag}%") for tag in tags))
                )

        if search_request.mime_type:
            query = query.filter(File.mime_type == search_request.mime_type)

        if search_request.is_public is not None:
            query = query.filter(File.is_public == search_request.is_public)

        if search_request.start_date:
            query = query.filter(File.created_at >= search_request.start_date)
        if search_request.end_date:
            query = query.filter(File.created_at <= search_request.end_date)

        if search_request.min_size:
            query = query.filter(File.file_size >= search_request.min_size)
        if search_request.max_size:
            query = query.filter(File.file_size <= search_request.max_size)

        return self._paginate(query, page, size)

    def get_file_info(self, file_id: int, user_id: int) -> FileInfo:
        """Return detailed info for a file and stamp its last-access time.

        Raises:
            FileNotFoundException: no such file for this user.
        """
        db_file = self.get_file_by_id(file_id, user_id)
        if not db_file:
            raise FileNotFoundException()

        db_file.last_accessed_at = datetime.utcnow()
        self.db.commit()

        return FileInfo(
            id=db_file.id,
            filename=db_file.filename,
            original_filename=db_file.original_filename,
            file_size=db_file.file_size,
            mime_type=db_file.mime_type,
            file_hash=db_file.file_hash,
            is_image=db_file.is_image(),
            is_document=db_file.is_document(),
            file_extension=db_file.get_file_extension(),
            size_formatted=db_file.get_size_formatted(),
            is_public=db_file.is_public,
            download_count=db_file.download_count,
            description=db_file.description,
            tags=db_file.tags,
            created_at=db_file.created_at,
            updated_at=db_file.updated_at,
            last_accessed_at=db_file.last_accessed_at
        )

    def get_storage_info(self, user_id: int) -> StorageInfo:
        """Return quota, usage and file count for a user.

        Raises:
            HTTPException: 404 when the user does not exist.
        """
        user = self.db.query(User).filter(User.id == user_id).first()
        if not user:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="用户不存在"
            )

        file_count = self.db.query(File).filter(File.user_id == user_id).count()

        return StorageInfo(
            total_quota=user.storage_quota,
            used_space=user.storage_used,
            available_space=user.storage_quota - user.storage_used,
            usage_percentage=user.get_storage_percentage(),
            file_count=file_count
        )

    def increment_download_count(self, file_id: int) -> None:
        """Add one to a file's download counter and stamp its access time.

        Does nothing when ``file_id`` does not exist.
        """
        db_file = self.db.query(File).filter(File.id == file_id).first()
        if db_file:
            db_file.download_count += 1
            db_file.last_accessed_at = datetime.utcnow()
            self.db.commit()