import os
import hashlib
import uuid
import logging
from datetime import datetime
from typing import Optional, List, Tuple

from fastapi import UploadFile, HTTPException, status
from sqlalchemy import and_, or_, desc
from sqlalchemy.orm import Session

from app.core.config import settings
from app.exceptions.file import (
    FileTooLargeException,
    StorageQuotaExceededException,
    FileAlreadyExistsException,
    FileNotFoundException,
    InvalidFileTypeException,
)
from app.models.file import File
from app.models.user import User
from app.schemas.file import (
    FileUploadRequest,
    FileUpdateRequest,
    FileSearchRequest,
    FileResponse,
    FileListResponse,
    StorageInfo,
    FileInfo,
)


class FileService:
    """Service layer for storing, listing, searching and deleting user files.

    File bytes are written under ``settings.UPLOAD_DIR`` with a random
    UUID-based name; metadata is stored in the ``File`` table, and the owning
    user's ``storage_used`` counter is adjusted on upload and deletion.
    """

    # Escape character for user text embedded in LIKE/ILIKE patterns.
    # "/" is used rather than a backslash because backslash has its own
    # meaning inside string literals on some backends (e.g. MySQL).
    _LIKE_ESCAPE = "/"

    def __init__(self, db: Session):
        self.db = db
        # Resolve to an absolute path once, so the stored location does not
        # change if the process working directory changes later.
        self.upload_dir = os.path.abspath(settings.UPLOAD_DIR)
        self.max_file_size = settings.MAX_FILE_SIZE
        self.allowed_extensions = settings.ALLOWED_EXTENSIONS
        self.logger = logging.getLogger(__name__)

        # Make sure the upload directory exists.
        os.makedirs(self.upload_dir, exist_ok=True)
        # Debug level: this constructor presumably runs per request, so
        # info-level logging here would be noise.
        self.logger.debug("Upload directory (absolute): %s", self.upload_dir)
        self.logger.debug("Current working directory: %s", os.getcwd())

    # ------------------------------------------------------------------ #
    # Small helpers
    # ------------------------------------------------------------------ #

    def _calculate_file_hash(self, file_content: bytes) -> str:
        """Return the hex-encoded SHA-256 digest of ``file_content``."""
        return hashlib.sha256(file_content).hexdigest()

    def _generate_unique_filename(self, original_filename: Optional[str]) -> str:
        """Return a storage name made of a random UUID4 plus the original extension.

        A missing ``original_filename`` yields a name without extension.
        """
        file_extension = os.path.splitext(original_filename or "")[1]
        return f"{uuid.uuid4()}{file_extension}"

    @staticmethod
    def _upload_size(file: UploadFile) -> Optional[int]:
        """Return the size declared by the upload, or ``None`` if it is unknown.

        ``UploadFile.size`` can be absent or ``None``; comparing ``None``
        with an int would raise ``TypeError``, so callers must check for it.
        """
        size = getattr(file, "size", None)
        return size if isinstance(size, int) else None

    def _remove_quietly(self, path: str) -> None:
        """Best-effort delete used during cleanup; failures are logged, never raised."""
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
        except OSError:
            self.logger.warning("Could not remove file %s", path, exc_info=True)

    @classmethod
    def _contains_pattern(cls, text: str) -> str:
        """Build a ``%text%`` LIKE pattern in which ``text`` matches literally.

        ``%``, ``_`` and the escape character itself are escaped, so user input
        such as ``"50%"`` is not treated as a wildcard. Use together with
        ``escape=cls._LIKE_ESCAPE``.
        """
        esc = cls._LIKE_ESCAPE
        escaped = (
            text.replace(esc, esc + esc)
            .replace("%", esc + "%")
            .replace("_", esc + "_")
        )
        return f"%{escaped}%"

    @staticmethod
    def _to_file_response(file: File) -> FileResponse:
        """Map a ``File`` ORM row to its API response schema."""
        return FileResponse(
            id=file.id,
            user_id=file.user_id,
            filename=file.filename,
            original_filename=file.original_filename,
            file_size=file.file_size,
            mime_type=file.mime_type,
            file_hash=file.file_hash,
            is_public=file.is_public,
            download_count=file.download_count,
            description=file.description,
            tags=file.tags,
            created_at=file.created_at,
            updated_at=file.updated_at,
            last_accessed_at=file.last_accessed_at,
        )

    def _paginate(self, query, page: int, size: int) -> FileListResponse:
        """Return one page of a ``File`` query, newest first.

        Args:
            query: SQLAlchemy query over ``File``.
            page: 1-based page number; values below 1 are treated as 1 so the
                OFFSET never goes negative.
            size: page size; values below 1 are treated as 1 so the page
                count never divides by zero.
        """
        page = max(1, page)
        size = max(1, size)

        total = query.count()
        files = (
            query.order_by(desc(File.created_at))
            .offset((page - 1) * size)
            .limit(size)
            .all()
        )
        return FileListResponse(
            files=[self._to_file_response(f) for f in files],
            total=total,
            page=page,
            size=size,
            pages=(total + size - 1) // size,
        )

    # ------------------------------------------------------------------ #
    # Upload pipeline
    # ------------------------------------------------------------------ #

    def _validate_file(self, file: UploadFile, user: User) -> None:
        """Reject an upload early, using only metadata available before reading it.

        Size-based checks are skipped when the upload does not declare a size;
        ``upload_file`` re-checks with the real byte count afterwards.

        Raises:
            FileTooLargeException: declared size exceeds ``max_file_size``.
            InvalidFileTypeException: extension is not in ``allowed_extensions``
                (only enforced when that setting is non-empty).
            StorageQuotaExceededException: declared size does not fit in the
                user's remaining quota.
        """
        size = self._upload_size(file)

        if size is not None and size > self.max_file_size:
            raise FileTooLargeException(size, self.max_file_size)

        if self.allowed_extensions:
            file_extension = os.path.splitext(file.filename or "")[1].lower()
            if file_extension not in self.allowed_extensions:
                raise InvalidFileTypeException(file_extension)

        if size is not None and not user.is_storage_available(size):
            raise StorageQuotaExceededException(user.storage_used, user.storage_quota, size)

    def _read_upload(self, file: UploadFile) -> bytes:
        """Read the remaining upload stream into memory, then rewind it.

        Raises:
            ValueError: the upload contains no bytes.
        """
        file_content = file.file.read()
        # Rewind so any later consumer of ``file`` sees the whole stream.
        file.file.seek(0)
        self.logger.debug("Read %d bytes from upload %r", len(file_content), file.filename)
        if not file_content:
            raise ValueError("File content is empty!")
        return file_content

    def _write_atomically(self, file_path: str, file_content: bytes) -> None:
        """Write ``file_content`` to ``file_path`` through a temp file.

        The bytes go to ``<file_path>.tmp`` first, are fsync'ed and
        size-checked, and only then moved into place. The temp file is
        removed if anything fails.

        Raises:
            Exception: ``"Failed to save file: ..."`` wrapping the underlying
                error (kept as the exception type callers already see).
        """
        temp_path = file_path + ".tmp"
        try:
            with open(temp_path, "wb") as temp_file:
                temp_file.write(file_content)
                temp_file.flush()
                os.fsync(temp_file.fileno())

            temp_size = os.path.getsize(temp_path)
            if temp_size != len(file_content):
                raise OSError(
                    f"Temporary file size mismatch: {temp_size} != {len(file_content)}"
                )

            # os.replace overwrites an existing target on all platforms
            # (os.rename fails on Windows in that case) and is atomic on POSIX.
            os.replace(temp_path, file_path)
        except Exception as e:
            self._remove_quietly(temp_path)
            self.logger.error("File save failed for %s: %s", file_path, e)
            raise Exception(f"Failed to save file: {e}") from e

    def _save_file_to_disk(self, file: UploadFile, unique_filename: str) -> Tuple[str, bytes]:
        """Read ``file`` and store it under ``upload_dir`` as ``unique_filename``.

        Returns:
            ``(absolute_file_path, file_content)``.

        Raises:
            ValueError: the upload is empty.
            Exception: writing to disk failed.
        """
        file_content = self._read_upload(file)
        file_path = os.path.join(self.upload_dir, unique_filename)
        self._write_atomically(file_path, file_content)
        return file_path, file_content

    def upload_file(self, file: UploadFile, user: User, upload_request: FileUploadRequest) -> File:
        """Validate, store and register an uploaded file for ``user``.

        All rejection checks (size, type, quota, duplicate content) run before
        anything is written to disk. If a database step fails before the
        commit, the session is rolled back and the stored file is removed.

        Returns:
            The newly created, refreshed ``File`` row.

        Raises:
            FileTooLargeException, InvalidFileTypeException,
            StorageQuotaExceededException, FileAlreadyExistsException:
                the upload was rejected.
            ValueError: the upload is empty.
            Exception: writing the file to disk failed.
        """
        file_path: Optional[str] = None
        try:
            self._validate_file(file, user)

            file_content = self._read_upload(file)
            file_size = len(file_content)

            # Re-check using the real byte count; the declared size may be
            # missing or wrong.
            if file_size > self.max_file_size:
                raise FileTooLargeException(file_size, self.max_file_size)
            if not user.is_storage_available(file_size):
                raise StorageQuotaExceededException(user.storage_used, user.storage_quota, file_size)

            # Per-user de-duplication by content hash.
            file_hash = self._calculate_file_hash(file_content)
            existing_file = self.db.query(File).filter(
                and_(
                    File.user_id == user.id,
                    File.file_hash == file_hash,
                )
            ).first()
            if existing_file:
                raise FileAlreadyExistsException(existing_file.original_filename)

            unique_filename = self._generate_unique_filename(file.filename)
            file_path = os.path.join(self.upload_dir, unique_filename)
            self._write_atomically(file_path, file_content)

            db_file = File(
                user_id=user.id,
                filename=unique_filename,
                original_filename=file.filename,
                file_path=file_path,
                file_size=file_size,
                mime_type=file.content_type or 'application/octet-stream',
                file_hash=file_hash,
                description=upload_request.description,
                tags=upload_request.tags,
                is_public=upload_request.is_public,
            )
            self.db.add(db_file)
            user.storage_used += file_size
            self.db.commit()
        except Exception:
            self.db.rollback()
            # Nothing was committed, so the bytes (if written) are orphaned.
            if file_path is not None:
                self._remove_quietly(file_path)
            raise

        # Outside the try: once committed, the file belongs to a DB row and
        # must not be deleted just because the refresh fails.
        self.db.refresh(db_file)
        return db_file

    # ------------------------------------------------------------------ #
    # Queries and mutations
    # ------------------------------------------------------------------ #

    def get_user_files(self, user_id: int, page: int = 1, size: int = 20) -> FileListResponse:
        """Return one page of ``user_id``'s files, newest first."""
        query = self.db.query(File).filter(File.user_id == user_id)
        return self._paginate(query, page, size)

    def get_file_by_id(self, file_id: int, user_id: int) -> Optional[File]:
        """Return the file with ``file_id`` if it belongs to ``user_id``, else ``None``."""
        return self.db.query(File).filter(
            and_(
                File.id == file_id,
                File.user_id == user_id,
            )
        ).first()

    def update_file(self, file_id: int, user_id: int, update_request: FileUpdateRequest) -> Optional[File]:
        """Apply the non-``None`` fields of ``update_request`` to a user's file.

        Raises:
            FileNotFoundException: no such file for this user.
        """
        db_file = self.get_file_by_id(file_id, user_id)
        if not db_file:
            raise FileNotFoundException()

        if update_request.description is not None:
            db_file.description = update_request.description
        if update_request.tags is not None:
            db_file.tags = update_request.tags
        if update_request.is_public is not None:
            db_file.is_public = update_request.is_public

        db_file.updated_at = datetime.utcnow()
        self.db.commit()
        self.db.refresh(db_file)
        return db_file

    def delete_file(self, file_id: int, user_id: int) -> bool:
        """Delete a user's file record, release its quota, then remove its bytes.

        The database change is committed first and the disk file is removed
        afterwards on a best-effort basis. A failed commit therefore never
        leaves a row pointing at a missing file; the worst case is an
        orphaned file on disk, which is logged.

        Returns:
            ``True`` on success.

        Raises:
            FileNotFoundException: no such file for this user.
        """
        db_file = self.get_file_by_id(file_id, user_id)
        if not db_file:
            raise FileNotFoundException()

        # Capture before commit: the instance is deleted/expired afterwards.
        file_path = db_file.file_path
        try:
            user = self.db.query(User).filter(User.id == user_id).first()
            if user:
                user.storage_used = max(0, user.storage_used - db_file.file_size)
            self.db.delete(db_file)
            self.db.commit()
        except Exception:
            self.db.rollback()
            raise

        if file_path:
            self._remove_quietly(file_path)
        return True

    def search_files(self, user_id: int, search_request: FileSearchRequest, page: int = 1, size: int = 20) -> FileListResponse:
        """Search a user's files and return one page, newest first.

        Filters (all optional, combined with AND):
          * ``filename`` – case-insensitive substring of the original name;
          * ``tags`` – comma-separated; matches if any non-blank tag is a
            case-insensitive substring of the stored tags;
          * ``mime_type`` – exact match;
          * ``is_public`` – exact match;
          * ``start_date`` / ``end_date`` – inclusive bounds on ``created_at``;
          * ``min_size`` / ``max_size`` – inclusive bounds on ``file_size``.

        User text is matched literally: ``%`` and ``_`` are escaped.
        """
        query = self.db.query(File).filter(File.user_id == user_id)

        if search_request.filename:
            query = query.filter(
                File.original_filename.ilike(
                    self._contains_pattern(search_request.filename),
                    escape=self._LIKE_ESCAPE,
                )
            )

        if search_request.tags:
            tag_list = [tag.strip() for tag in search_request.tags.split(',')]
            # Blank entries ("a,,b", trailing comma) are skipped: they would
            # turn into "%%" and match every tagged file.
            tag_conditions = [
                File.tags.ilike(self._contains_pattern(tag), escape=self._LIKE_ESCAPE)
                for tag in tag_list
                if tag
            ]
            if tag_conditions:
                query = query.filter(or_(*tag_conditions))

        if search_request.mime_type:
            query = query.filter(File.mime_type == search_request.mime_type)

        if search_request.is_public is not None:
            query = query.filter(File.is_public == search_request.is_public)

        if search_request.start_date:
            query = query.filter(File.created_at >= search_request.start_date)
        if search_request.end_date:
            query = query.filter(File.created_at <= search_request.end_date)

        # "is not None" rather than truthiness so that a bound of 0 is honoured.
        if search_request.min_size is not None:
            query = query.filter(File.file_size >= search_request.min_size)
        if search_request.max_size is not None:
            query = query.filter(File.file_size <= search_request.max_size)

        return self._paginate(query, page, size)

    def get_file_info(self, file_id: int, user_id: int) -> FileInfo:
        """Return detailed info for a user's file and record the access time.

        Raises:
            FileNotFoundException: no such file for this user.
        """
        db_file = self.get_file_by_id(file_id, user_id)
        if not db_file:
            raise FileNotFoundException()

        db_file.last_accessed_at = datetime.utcnow()
        self.db.commit()

        return FileInfo(
            id=db_file.id,
            filename=db_file.filename,
            original_filename=db_file.original_filename,
            file_size=db_file.file_size,
            mime_type=db_file.mime_type,
            file_hash=db_file.file_hash,
            is_image=db_file.is_image(),
            is_document=db_file.is_document(),
            file_extension=db_file.get_file_extension(),
            size_formatted=db_file.get_size_formatted(),
            is_public=db_file.is_public,
            download_count=db_file.download_count,
            description=db_file.description,
            tags=db_file.tags,
            created_at=db_file.created_at,
            updated_at=db_file.updated_at,
            last_accessed_at=db_file.last_accessed_at,
        )

    def get_storage_info(self, user_id: int) -> StorageInfo:
        """Return quota, usage and file count for ``user_id``.

        Raises:
            HTTPException: 404 if the user does not exist.
        """
        user = self.db.query(User).filter(User.id == user_id).first()
        if not user:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="用户不存在"
            )

        file_count = self.db.query(File).filter(File.user_id == user_id).count()
        # Clamp: usage can exceed a quota that was lowered after the fact,
        # and negative free space is meaningless to clients.
        available_space = max(0, user.storage_quota - user.storage_used)

        return StorageInfo(
            total_quota=user.storage_quota,
            used_space=user.storage_used,
            available_space=available_space,
            usage_percentage=user.get_storage_percentage(),
            file_count=file_count,
        )

    def increment_download_count(self, file_id: int) -> None:
        """Bump a file's download counter and access time; no-op if the id is unknown."""
        db_file = self.db.query(File).filter(File.id == file_id).first()
        if db_file:
            db_file.download_count += 1
            db_file.last_accessed_at = datetime.utcnow()
            self.db.commit()