418 lines
15 KiB
Python
418 lines
15 KiB
Python
import os
|
||
import hashlib
|
||
import uuid
|
||
import logging
|
||
from typing import Optional, List, Tuple
|
||
from sqlalchemy.orm import Session
|
||
from sqlalchemy import and_, or_, desc
|
||
from fastapi import UploadFile, HTTPException, status
|
||
from datetime import datetime
|
||
|
||
from app.models.file import File
|
||
from app.models.user import User
|
||
from app.schemas.file import (
|
||
FileUploadRequest, FileUpdateRequest, FileSearchRequest,
|
||
FileResponse, FileListResponse, StorageInfo, FileInfo
|
||
)
|
||
from app.core.config import settings
|
||
from app.exceptions.file import (
|
||
FileTooLargeException, StorageQuotaExceededException,
|
||
FileAlreadyExistsException, FileNotFoundException,
|
||
InvalidFileTypeException
|
||
)
|
||
|
||
class FileService:
    """Business logic for storing, listing, searching and deleting user files.

    Uploaded content is written under ``settings.UPLOAD_DIR`` using a random
    UUID-based file name. Metadata (original name, size, SHA-256 hash, tags,
    visibility, ...) is stored in the ``File`` table, and the owner's
    ``storage_used`` counter is adjusted on upload and delete.
    """

    # Chunk size used when draining an oversized upload only to measure it.
    _DRAIN_CHUNK_SIZE = 1024 * 1024
    # Escape character used for LIKE patterns built from user input.
    _LIKE_ESCAPE = "/"

    def __init__(self, db: Session):
        """Bind the service to a database session and ensure the upload dir exists.

        Args:
            db: SQLAlchemy session used for all queries and commits.
        """
        self.db = db
        # Absolute path so stored file paths do not depend on the process's
        # working directory when they are read back later.
        self.upload_dir = os.path.abspath(settings.UPLOAD_DIR)
        self.max_file_size = settings.MAX_FILE_SIZE
        self.allowed_extensions = settings.ALLOWED_EXTENSIONS
        self.logger = logging.getLogger(__name__)

        os.makedirs(self.upload_dir, exist_ok=True)
        # Diagnostics only: logged at DEBUG so that constructing the service
        # repeatedly does not flood INFO-level logs.
        self.logger.debug("Upload directory (absolute): %s", self.upload_dir)
        self.logger.debug("Current working directory: %s", os.getcwd())

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _calculate_file_hash(self, file_content: bytes) -> str:
        """Return the hex SHA-256 digest of *file_content*."""
        return hashlib.sha256(file_content).hexdigest()

    def _generate_unique_filename(self, original_filename: str) -> str:
        """Return ``<uuid4><ext>``, keeping the original file's extension (if any)."""
        file_extension = os.path.splitext(original_filename or "")[1]
        return f"{uuid.uuid4()}{file_extension}"

    @classmethod
    def _escape_like(cls, value: str) -> str:
        """Escape LIKE wildcards in *value* so it is matched literally.

        Must be paired with ``escape=cls._LIKE_ESCAPE`` in the ``ilike`` call.
        """
        esc = cls._LIKE_ESCAPE
        return (
            value.replace(esc, esc + esc)
            .replace("%", esc + "%")
            .replace("_", esc + "_")
        )

    def _discard_file(self, path: Optional[str]) -> None:
        """Best-effort removal of *path* from disk; failures are logged, never raised.

        Used on error/cleanup paths so a failed removal cannot mask the
        original error.
        """
        if not path:
            return
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
        except OSError:
            self.logger.warning("Could not remove file %s", path, exc_info=True)

    def _validate_file(self, file: UploadFile, user: User) -> None:
        """Reject an upload early, using only metadata available before reading it.

        Raises:
            FileTooLargeException: the declared size exceeds ``max_file_size``.
            InvalidFileTypeException: the extension is not in
                ``allowed_extensions`` (checked only when that setting is non-empty).
            StorageQuotaExceededException: the declared size does not fit in
                the user's remaining quota.
        """
        # ``size`` may be absent or None when the size is not known up front;
        # in that case size and quota are checked again after reading.
        declared_size = getattr(file, "size", None)

        if declared_size is not None and declared_size > self.max_file_size:
            raise FileTooLargeException(declared_size, self.max_file_size)

        if self.allowed_extensions:
            file_extension = os.path.splitext(file.filename or "")[1].lower()
            if file_extension not in self.allowed_extensions:
                raise InvalidFileTypeException(file_extension)

        if declared_size is not None and not user.is_storage_available(declared_size):
            raise StorageQuotaExceededException(
                user.storage_used, user.storage_quota, declared_size
            )

    def _read_upload_content(self, file: UploadFile) -> bytes:
        """Read the upload body without buffering more than ``max_file_size + 1`` bytes.

        Raises:
            FileTooLargeException: the body is larger than ``max_file_size``.
                The reported size is the full body length.
            ValueError: the body is empty.
        """
        # Read one byte past the limit, so "exactly at the limit" and "over the
        # limit" can be told apart without loading an arbitrarily large body.
        content = file.file.read(self.max_file_size + 1)
        if len(content) > self.max_file_size:
            # Count the remainder in fixed-size chunks (discarding them) so the
            # exception can still report the real size.
            remaining = sum(
                len(chunk)
                for chunk in iter(lambda: file.file.read(self._DRAIN_CHUNK_SIZE), b"")
            )
            raise FileTooLargeException(len(content) + remaining, self.max_file_size)
        if not content:
            self.logger.error("Uploaded file content is empty")
            raise ValueError("File content is empty!")
        return content

    def _save_file_to_disk(self, file: UploadFile, unique_filename: str) -> Tuple[str, bytes]:
        """Write the upload into ``upload_dir`` and return ``(file_path, content)``.

        The content is written to ``<file_path>.tmp``, fsync'ed, then moved into
        place with ``os.replace``, so a partially written file never appears
        under the final name. On success the upload's file pointer is rewound
        to the start.

        Raises:
            FileTooLargeException: the content exceeds ``max_file_size``
                (nothing is written to disk in that case).
            ValueError: the upload is empty.
            Exception: any failure while writing. Neither the temp file nor
                the final file is left behind.
        """
        file_path = os.path.join(self.upload_dir, unique_filename)
        file_content = self._read_upload_content(file)
        self.logger.debug("Saving %d bytes to %s", len(file_content), file_path)

        temp_path = file_path + ".tmp"
        try:
            with open(temp_path, "wb") as temp_file:
                temp_file.write(file_content)
                temp_file.flush()
                os.fsync(temp_file.fileno())
            # Unlike os.rename, os.replace also overwrites an existing target
            # on Windows; on POSIX the rename is atomic.
            os.replace(temp_path, file_path)

            final_size = os.path.getsize(file_path)
            if final_size != len(file_content):
                raise Exception(
                    f"Final file size mismatch: {final_size} != {len(file_content)}"
                )
        except Exception as e:
            # Clean up both paths: if the move already happened, the final file
            # would otherwise be orphaned (the caller never learns its path).
            self._discard_file(temp_path)
            self._discard_file(file_path)
            self.logger.error("File save failed: %s", e)
            raise Exception(f"Failed to save file: {e}") from e

        file.file.seek(0)
        return file_path, file_content

    @staticmethod
    def _to_file_response(db_file: File) -> FileResponse:
        """Map a ``File`` row to its ``FileResponse`` schema."""
        return FileResponse(
            id=db_file.id,
            user_id=db_file.user_id,
            filename=db_file.filename,
            original_filename=db_file.original_filename,
            file_size=db_file.file_size,
            mime_type=db_file.mime_type,
            file_hash=db_file.file_hash,
            is_public=db_file.is_public,
            download_count=db_file.download_count,
            description=db_file.description,
            tags=db_file.tags,
            created_at=db_file.created_at,
            updated_at=db_file.updated_at,
            last_accessed_at=db_file.last_accessed_at,
        )

    def _paginate(self, query, page: int, size: int) -> FileListResponse:
        """Return one page of *query*, newest first, with pagination metadata.

        ``page`` is 1-based; values below 1 are treated as the first page.
        A non-positive ``size`` yields an empty page with ``pages == 0``,
        rather than a negative OFFSET or a division by zero.
        """
        limit = max(size, 0)
        offset = max(page - 1, 0) * limit

        total = query.count()
        files = query.order_by(desc(File.created_at)).offset(offset).limit(limit).all()
        pages = (total + limit - 1) // limit if limit else 0

        return FileListResponse(
            files=[self._to_file_response(f) for f in files],
            total=total,
            page=page,
            size=size,
            pages=pages,
        )

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def upload_file(self, file: UploadFile, user: User, upload_request: FileUploadRequest) -> File:
        """Store an uploaded file for *user* and create its ``File`` record.

        Steps:
        1. Validate the declared metadata.
        2. Write the content to disk (enforcing ``max_file_size``).
        3. Re-check the quota against the real size.
        4. Reject duplicates (same SHA-256 for the same user).
        5. Insert the record and add the size to ``user.storage_used`` in one commit.

        Returns:
            The persisted, refreshed ``File`` row.

        Raises:
            FileTooLargeException, InvalidFileTypeException,
            StorageQuotaExceededException, FileAlreadyExistsException,
            ValueError (empty upload), or any save/database error. On any of
            these, the session is rolled back and the written file is removed.
        """
        file_path: Optional[str] = None
        try:
            self._validate_file(file, user)

            unique_filename = self._generate_unique_filename(file.filename)
            file_path, file_content = self._save_file_to_disk(file, unique_filename)
            file_size = len(file_content)

            # The declared size may have been unknown during validation, so
            # check the quota again against the bytes actually received.
            if not user.is_storage_available(file_size):
                raise StorageQuotaExceededException(
                    user.storage_used, user.storage_quota, file_size
                )

            file_hash = self._calculate_file_hash(file_content)

            # Content-based de-duplication, scoped to this user.
            existing_file = self.db.query(File).filter(
                and_(
                    File.user_id == user.id,
                    File.file_hash == file_hash,
                )
            ).first()
            if existing_file:
                raise FileAlreadyExistsException(existing_file.original_filename)

            db_file = File(
                user_id=user.id,
                filename=unique_filename,
                original_filename=file.filename,
                file_path=file_path,
                file_size=file_size,
                mime_type=file.content_type or 'application/octet-stream',
                file_hash=file_hash,
                description=upload_request.description,
                tags=upload_request.tags,
                is_public=upload_request.is_public,
            )
            self.db.add(db_file)
            user.storage_used += file_size
            self.db.commit()
        except Exception:
            self.db.rollback()
            # Any failure before the commit (quota, duplicate, DB error) must
            # not leave an unreferenced file in the upload directory.
            self._discard_file(file_path)
            raise

        # Outside the try: the record is committed, so a refresh failure must
        # not delete the file it references.
        self.db.refresh(db_file)
        return db_file

    def get_user_files(self, user_id: int, page: int = 1, size: int = 20) -> FileListResponse:
        """Return one page of *user_id*'s files, newest first."""
        query = self.db.query(File).filter(File.user_id == user_id)
        return self._paginate(query, page, size)

    def get_file_by_id(self, file_id: int, user_id: int) -> Optional[File]:
        """Return file *file_id* if it belongs to *user_id*, else ``None``."""
        return self.db.query(File).filter(
            and_(
                File.id == file_id,
                File.user_id == user_id,
            )
        ).first()

    def update_file(self, file_id: int, user_id: int, update_request: FileUpdateRequest) -> Optional[File]:
        """Apply the non-None fields of *update_request* to the user's file.

        Raises:
            FileNotFoundException: no such file for this user.
        """
        db_file = self.get_file_by_id(file_id, user_id)
        if not db_file:
            raise FileNotFoundException()

        # Only overwrite fields the caller actually supplied.
        if update_request.description is not None:
            db_file.description = update_request.description
        if update_request.tags is not None:
            db_file.tags = update_request.tags
        if update_request.is_public is not None:
            db_file.is_public = update_request.is_public

        db_file.updated_at = datetime.utcnow()

        self.db.commit()
        self.db.refresh(db_file)
        return db_file

    def delete_file(self, file_id: int, user_id: int) -> bool:
        """Delete the user's file record, release its storage, then remove it from disk.

        The database change is committed first and the disk file is removed
        afterwards. A failed commit therefore never leaves a record pointing
        at an already-deleted file. A failed disk removal is logged and only
        leaves an orphaned file behind.

        Returns:
            ``True`` on success.

        Raises:
            FileNotFoundException: no such file for this user.
        """
        db_file = self.get_file_by_id(file_id, user_id)
        if not db_file:
            raise FileNotFoundException()

        # Capture before the commit; the deleted row's attributes are not
        # safe to read afterwards.
        file_path = db_file.file_path
        try:
            user = self.db.query(User).filter(User.id == user_id).first()
            if user:
                user.storage_used = max(0, user.storage_used - db_file.file_size)
            self.db.delete(db_file)
            self.db.commit()
        except Exception:
            self.db.rollback()
            raise

        self._discard_file(file_path)
        return True

    def search_files(self, user_id: int, search_request: FileSearchRequest,
                     page: int = 1, size: int = 20) -> FileListResponse:
        """Search *user_id*'s files using the filters set in *search_request*.

        - ``filename``: case-insensitive substring match on the original name.
          LIKE wildcards in the input are matched literally.
        - ``tags``: comma-separated; a file matches if any tag is a
          case-insensitive substring of its tags. Blank entries are ignored.
        - ``mime_type``, ``is_public``: exact match.
        - ``start_date``/``end_date``, ``min_size``/``max_size``: inclusive ranges.
        """
        query = self.db.query(File).filter(File.user_id == user_id)
        esc = self._LIKE_ESCAPE

        if search_request.filename:
            pattern = f"%{self._escape_like(search_request.filename)}%"
            query = query.filter(File.original_filename.ilike(pattern, escape=esc))

        if search_request.tags:
            # Drop blank entries ("a,,b", trailing comma): an empty tag would
            # become '%%' and match every file.
            tag_list = [t.strip() for t in search_request.tags.split(',') if t.strip()]
            if tag_list:
                query = query.filter(or_(*[
                    File.tags.ilike(f"%{self._escape_like(tag)}%", escape=esc)
                    for tag in tag_list
                ]))

        if search_request.mime_type:
            query = query.filter(File.mime_type == search_request.mime_type)

        if search_request.is_public is not None:
            query = query.filter(File.is_public == search_request.is_public)

        if search_request.start_date is not None:
            query = query.filter(File.created_at >= search_request.start_date)
        if search_request.end_date is not None:
            query = query.filter(File.created_at <= search_request.end_date)

        # `is not None` so that an explicit 0 bound is honoured.
        if search_request.min_size is not None:
            query = query.filter(File.file_size >= search_request.min_size)
        if search_request.max_size is not None:
            query = query.filter(File.file_size <= search_request.max_size)

        return self._paginate(query, page, size)

    def get_file_info(self, file_id: int, user_id: int) -> FileInfo:
        """Return detailed info for the user's file and record the access time.

        Raises:
            FileNotFoundException: no such file for this user.
        """
        db_file = self.get_file_by_id(file_id, user_id)
        if not db_file:
            raise FileNotFoundException()

        db_file.last_accessed_at = datetime.utcnow()
        self.db.commit()

        return FileInfo(
            id=db_file.id,
            filename=db_file.filename,
            original_filename=db_file.original_filename,
            file_size=db_file.file_size,
            mime_type=db_file.mime_type,
            file_hash=db_file.file_hash,
            is_image=db_file.is_image(),
            is_document=db_file.is_document(),
            file_extension=db_file.get_file_extension(),
            size_formatted=db_file.get_size_formatted(),
            is_public=db_file.is_public,
            download_count=db_file.download_count,
            description=db_file.description,
            tags=db_file.tags,
            created_at=db_file.created_at,
            updated_at=db_file.updated_at,
            last_accessed_at=db_file.last_accessed_at,
        )

    def get_storage_info(self, user_id: int) -> StorageInfo:
        """Return quota, usage and file-count information for *user_id*.

        Raises:
            HTTPException: 404 if the user does not exist.
        """
        user = self.db.query(User).filter(User.id == user_id).first()
        if not user:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="用户不存在"
            )

        file_count = self.db.query(File).filter(File.user_id == user_id).count()
        # Clamp at 0: usage can exceed the quota (e.g. if the quota is lowered).
        available_space = max(0, user.storage_quota - user.storage_used)

        return StorageInfo(
            total_quota=user.storage_quota,
            used_space=user.storage_used,
            available_space=available_space,
            usage_percentage=user.get_storage_percentage(),
            file_count=file_count,
        )

    def increment_download_count(self, file_id: int) -> None:
        """Increment the file's download counter and record the access time.

        Does nothing if no file with *file_id* exists.
        """
        db_file = self.db.query(File).filter(File.id == file_id).first()
        if db_file:
            # NOTE(review): read-modify-write in Python; concurrent downloads
            # may lose increments. Consider `File.download_count + 1` if that matters.
            db_file.download_count += 1
            db_file.last_accessed_at = datetime.utcnow()
            self.db.commit()