初次提交

This commit is contained in:
2025-10-14 20:05:29 +08:00
commit 6e4e48fdd2
673 changed files with 437006 additions and 0 deletions

View File

@@ -0,0 +1,418 @@
import os
import hashlib
import uuid
import logging
from typing import Optional, List, Tuple
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_, desc
from fastapi import UploadFile, HTTPException, status
from datetime import datetime
from app.models.file import File
from app.models.user import User
from app.schemas.file import (
FileUploadRequest, FileUpdateRequest, FileSearchRequest,
FileResponse, FileListResponse, StorageInfo, FileInfo
)
from app.core.config import settings
from app.exceptions.file import (
FileTooLargeException, StorageQuotaExceededException,
FileAlreadyExistsException, FileNotFoundException,
InvalidFileTypeException
)
class FileService:
def __init__(self, db: Session):
self.db = db
# 使用绝对路径确保文件保存正确
self.upload_dir = os.path.abspath(settings.UPLOAD_DIR)
self.max_file_size = settings.MAX_FILE_SIZE
self.allowed_extensions = settings.ALLOWED_EXTENSIONS
self.logger = logging.getLogger(__name__)
# 确保上传目录存在
os.makedirs(self.upload_dir, exist_ok=True)
self.logger.info(f"Upload directory (absolute): {self.upload_dir}")
self.logger.info(f"Upload directory exists: {os.path.exists(self.upload_dir)}")
self.logger.info(f"Current working directory: {os.getcwd()}")
def _calculate_file_hash(self, file_content: bytes) -> str:
"""计算文件的SHA-256哈希值"""
return hashlib.sha256(file_content).hexdigest()
def _generate_unique_filename(self, original_filename: str) -> str:
"""生成唯一的文件名"""
file_extension = os.path.splitext(original_filename)[1]
unique_id = str(uuid.uuid4())
return f"{unique_id}{file_extension}"
def _validate_file(self, file: UploadFile, user: User) -> None:
"""验证文件"""
# 检查文件大小
if hasattr(file, 'size') and file.size > self.max_file_size:
raise FileTooLargeException(file.size, self.max_file_size)
# 检查文件扩展名
if self.allowed_extensions:
file_extension = os.path.splitext(file.filename)[1].lower()
if file_extension not in self.allowed_extensions:
raise InvalidFileTypeException(file_extension)
# 检查存储配额
if hasattr(file, 'size') and not user.is_storage_available(file.size):
raise StorageQuotaExceededException(user.storage_used, user.storage_quota, file.size)
def _save_file_to_disk(self, file: UploadFile, unique_filename: str) -> Tuple[str, bytes]:
"""保存文件到磁盘"""
file_path = os.path.join(self.upload_dir, unique_filename)
# 读取文件内容
file_content = file.file.read()
# 强制输出调试信息
import sys
message = f"[CRITICAL] About to save {len(file_content)} bytes to {file_path}"
print(message, flush=True)
sys.stdout.flush()
message2 = f"[CRITICAL] Content preview: {file_content[:50] if file_content else 'EMPTY'}"
print(message2, flush=True)
sys.stdout.flush()
# 立即验证内容是否为空
if not file_content:
print("[CRITICAL] FILE CONTENT IS EMPTY!")
raise ValueError("File content is empty!")
# 使用临时文件方法确保写入成功
temp_path = file_path + '.tmp'
try:
# 先写入临时文件
with open(temp_path, "wb") as temp_file:
temp_file.write(file_content)
temp_file.flush()
os.fsync(temp_file.fileno())
# 验证临时文件
if os.path.exists(temp_path):
temp_size = os.path.getsize(temp_path)
if temp_size != len(file_content):
raise Exception(f"Temporary file size mismatch: {temp_size} != {len(file_content)}")
# 重命名为最终文件名(原子操作)
if os.name == 'nt': # Windows
if os.path.exists(file_path):
os.remove(file_path)
os.rename(temp_path, file_path)
# 最终验证
if not os.path.exists(file_path):
raise Exception("File was not created after rename")
final_size = os.path.getsize(file_path)
if final_size != len(file_content):
raise Exception(f"Final file size mismatch: {final_size} != {len(file_content)}")
except Exception as e:
# 清理临时文件
if os.path.exists(temp_path):
os.remove(temp_path)
print(f"[ERROR] File save failed: {e}")
raise Exception(f"Failed to save file: {e}")
# 重置文件指针
file.file.seek(0)
return file_path, file_content
def upload_file(self, file: UploadFile, user: User, upload_request: FileUploadRequest) -> File:
"""上传文件"""
try:
# 验证文件
self._validate_file(file, user)
# 生成唯一文件名
unique_filename = self._generate_unique_filename(file.filename)
# 保存文件到磁盘
file_path, file_content = self._save_file_to_disk(file, unique_filename)
file_size = len(file_content)
# 再次检查文件大小如果没有size属性
if file_size > self.max_file_size:
# 删除已保存的文件
if os.path.exists(file_path):
os.remove(file_path)
raise FileTooLargeException(file_size, self.max_file_size)
# 检查存储配额
if not user.is_storage_available(file_size):
# 删除已保存的文件
if os.path.exists(file_path):
os.remove(file_path)
raise StorageQuotaExceededException(user.storage_used, user.storage_quota, file_size)
# 计算文件哈希
file_hash = self._calculate_file_hash(file_content)
# 检查文件是否已存在(基于哈希值)
existing_file = self.db.query(File).filter(
and_(
File.user_id == user.id,
File.file_hash == file_hash
)
).first()
if existing_file:
# 删除刚保存的文件,因为已存在相同内容的文件
if os.path.exists(file_path):
os.remove(file_path)
raise FileAlreadyExistsException(existing_file.original_filename)
# 创建文件记录
db_file = File(
user_id=user.id,
filename=unique_filename,
original_filename=file.filename,
file_path=file_path,
file_size=file_size,
mime_type=file.content_type or 'application/octet-stream',
file_hash=file_hash,
description=upload_request.description,
tags=upload_request.tags,
is_public=upload_request.is_public
)
self.db.add(db_file)
# 更新用户存储使用量
user.storage_used += file_size
self.db.commit()
self.db.refresh(db_file)
return db_file
except Exception as e:
self.db.rollback()
# 如果保存了文件但数据库操作失败,删除文件
if 'file_path' in locals() and os.path.exists(file_path):
os.remove(file_path)
raise e
def get_user_files(self, user_id: int, page: int = 1, size: int = 20) -> FileListResponse:
"""获取用户的文件列表"""
offset = (page - 1) * size
query = self.db.query(File).filter(File.user_id == user_id)
total = query.count()
files = query.order_by(desc(File.created_at)).offset(offset).limit(size).all()
pages = (total + size - 1) // size
return FileListResponse(
files=[FileResponse(
id=file.id,
user_id=file.user_id,
filename=file.filename,
original_filename=file.original_filename,
file_size=file.file_size,
mime_type=file.mime_type,
file_hash=file.file_hash,
is_public=file.is_public,
download_count=file.download_count,
description=file.description,
tags=file.tags,
created_at=file.created_at,
updated_at=file.updated_at,
last_accessed_at=file.last_accessed_at
) for file in files],
total=total,
page=page,
size=size,
pages=pages
)
def get_file_by_id(self, file_id: int, user_id: int) -> Optional[File]:
"""根据ID获取文件"""
return self.db.query(File).filter(
and_(
File.id == file_id,
File.user_id == user_id
)
).first()
def update_file(self, file_id: int, user_id: int, update_request: FileUpdateRequest) -> Optional[File]:
"""更新文件信息"""
db_file = self.get_file_by_id(file_id, user_id)
if not db_file:
raise FileNotFoundException()
# 更新字段
if update_request.description is not None:
db_file.description = update_request.description
if update_request.tags is not None:
db_file.tags = update_request.tags
if update_request.is_public is not None:
db_file.is_public = update_request.is_public
db_file.updated_at = datetime.utcnow()
self.db.commit()
self.db.refresh(db_file)
return db_file
def delete_file(self, file_id: int, user_id: int) -> bool:
"""删除文件"""
db_file = self.get_file_by_id(file_id, user_id)
if not db_file:
raise FileNotFoundException()
try:
# 删除磁盘上的文件
if os.path.exists(db_file.file_path):
os.remove(db_file.file_path)
# 更新用户存储使用量
user = self.db.query(User).filter(User.id == user_id).first()
if user:
user.storage_used = max(0, user.storage_used - db_file.file_size)
# 删除数据库记录
self.db.delete(db_file)
self.db.commit()
return True
except Exception as e:
self.db.rollback()
raise e
def search_files(self, user_id: int, search_request: FileSearchRequest,
page: int = 1, size: int = 20) -> FileListResponse:
"""搜索文件"""
offset = (page - 1) * size
query = self.db.query(File).filter(File.user_id == user_id)
# 文件名搜索
if search_request.filename:
query = query.filter(
File.original_filename.ilike(f"%{search_request.filename}%")
)
# 标签搜索
if search_request.tags:
tag_list = [tag.strip() for tag in search_request.tags.split(',')]
tag_conditions = []
for tag in tag_list:
tag_conditions.append(File.tags.ilike(f"%{tag}%"))
if tag_conditions:
query = query.filter(or_(*tag_conditions))
# MIME类型过滤
if search_request.mime_type:
query = query.filter(File.mime_type == search_request.mime_type)
# 公开状态过滤
if search_request.is_public is not None:
query = query.filter(File.is_public == search_request.is_public)
# 日期范围过滤
if search_request.start_date:
query = query.filter(File.created_at >= search_request.start_date)
if search_request.end_date:
query = query.filter(File.created_at <= search_request.end_date)
# 文件大小范围过滤
if search_request.min_size:
query = query.filter(File.file_size >= search_request.min_size)
if search_request.max_size:
query = query.filter(File.file_size <= search_request.max_size)
total = query.count()
files = query.order_by(desc(File.created_at)).offset(offset).limit(size).all()
pages = (total + size - 1) // size
return FileListResponse(
files=[FileResponse(
id=file.id,
user_id=file.user_id,
filename=file.filename,
original_filename=file.original_filename,
file_size=file.file_size,
mime_type=file.mime_type,
file_hash=file.file_hash,
is_public=file.is_public,
download_count=file.download_count,
description=file.description,
tags=file.tags,
created_at=file.created_at,
updated_at=file.updated_at,
last_accessed_at=file.last_accessed_at
) for file in files],
total=total,
page=page,
size=size,
pages=pages
)
def get_file_info(self, file_id: int, user_id: int) -> FileInfo:
"""获取文件详细信息"""
db_file = self.get_file_by_id(file_id, user_id)
if not db_file:
raise FileNotFoundException()
# 更新最后访问时间
db_file.last_accessed_at = datetime.utcnow()
self.db.commit()
return FileInfo(
id=db_file.id,
filename=db_file.filename,
original_filename=db_file.original_filename,
file_size=db_file.file_size,
mime_type=db_file.mime_type,
file_hash=db_file.file_hash,
is_image=db_file.is_image(),
is_document=db_file.is_document(),
file_extension=db_file.get_file_extension(),
size_formatted=db_file.get_size_formatted(),
is_public=db_file.is_public,
download_count=db_file.download_count,
description=db_file.description,
tags=db_file.tags,
created_at=db_file.created_at,
updated_at=db_file.updated_at,
last_accessed_at=db_file.last_accessed_at
)
def get_storage_info(self, user_id: int) -> StorageInfo:
"""获取用户存储信息"""
user = self.db.query(User).filter(User.id == user_id).first()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="用户不存在"
)
file_count = self.db.query(File).filter(File.user_id == user_id).count()
available_space = user.storage_quota - user.storage_used
usage_percentage = user.get_storage_percentage()
return StorageInfo(
total_quota=user.storage_quota,
used_space=user.storage_used,
available_space=available_space,
usage_percentage=usage_percentage,
file_count=file_count
)
def increment_download_count(self, file_id: int) -> None:
"""增加文件下载次数"""
db_file = self.db.query(File).filter(File.id == file_id).first()
if db_file:
db_file.download_count += 1
db_file.last_accessed_at = datetime.utcnow()
self.db.commit()