初次提交

This commit is contained in:
2025-10-14 20:05:29 +08:00
commit 6e4e48fdd2
673 changed files with 437006 additions and 0 deletions

View File

View File

@@ -0,0 +1,54 @@
from pydantic_settings import BaseSettings
from typing import List
import os
class Settings(BaseSettings):
# 基础配置
ENVIRONMENT: str = "development"
DEBUG: bool = True
# 数据库配置
DATABASE_URL: str = "mysql+pymysql://mytest_db:mytest_db@101.126.85.76:3306/mytest_db"
# Redis配置
REDIS_URL: str = "redis://localhost:6379"
# JWT配置
JWT_SECRET_KEY: str = "your-super-secret-jwt-key-change-in-production"
JWT_ALGORITHM: str = "HS256"
JWT_EXPIRE_MINUTES: int = 30
JWT_REFRESH_EXPIRE_DAYS: int = 7
# CORS配置
ALLOWED_HOSTS: List[str] = ["*"] # 允许所有域名访问
# 文件上传配置
MAX_FILE_SIZE: int = 10 * 1024 * 1024 # 10MB
UPLOAD_DIR: str = "uploads"
ALLOWED_EXTENSIONS: List[str] = [
# 图片
".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp", ".svg",
# 文档
".pdf", ".doc", ".docx", ".xls", ".xlsx", ".ppt", ".pptx",
".txt", ".rtf", ".csv",
# 压缩文件
".zip", ".rar", ".7z", ".tar", ".gz",
# 音频
".mp3", ".wav", ".flac", ".aac", ".ogg",
# 视频
".mp4", ".avi", ".mkv", ".mov", ".wmv", ".flv",
# 代码文件
".py", ".js", ".html", ".css", ".json", ".xml", ".yaml", ".yml",
".java", ".cpp", ".c", ".h", ".cs", ".php", ".rb", ".go",
".sql", ".sh", ".bat", ".ps1", ".md", ".log"
]
# 安全配置
BCRYPT_ROUNDS: int = 12
class Config:
env_file = ".env"
case_sensitive = True
extra = "allow" # 允许额外的环境变量
settings = Settings()

View File

@@ -0,0 +1,30 @@
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from app.core.config import settings
import pymysql
# 安装pymysql作为MySQLdb的替代
pymysql.install_as_MySQLdb()
# 创建数据库引擎
engine = create_engine(
settings.DATABASE_URL,
echo=settings.DEBUG,
pool_pre_ping=True,
pool_recycle=300
)
# 创建会话工厂
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# 创建基础模型类
Base = declarative_base()
# 数据库依赖
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()

View File

@@ -0,0 +1,151 @@
from datetime import datetime, timedelta
from typing import Optional, Union
from jose import JWTError, jwt
import bcrypt
from app.core.config import settings
from app.core.token_blacklist import token_blacklist
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
"""创建访问令牌"""
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=settings.JWT_EXPIRE_MINUTES)
to_encode.update({"exp": expire, "type": "access"})
encoded_jwt = jwt.encode(to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)
return encoded_jwt
def create_refresh_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
"""创建刷新令牌"""
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(days=settings.JWT_REFRESH_EXPIRE_DAYS)
to_encode.update({"exp": expire, "type": "refresh"})
encoded_jwt = jwt.encode(to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)
return encoded_jwt
def verify_token(token: str, token_type: str = "access") -> Optional[dict]:
"""验证令牌"""
try:
# 首先检查令牌是否在黑名单中
if token_blacklist.is_blacklisted(token):
return None
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
# 检查令牌类型
if payload.get("type") != token_type:
return None
return payload
except JWTError:
return None
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""验证密码"""
try:
return bcrypt.checkpw(plain_password.encode('utf-8'), hashed_password.encode('utf-8'))
except:
return False
def get_password_hash(password: str) -> str:
"""获取密码哈希"""
# bcrypt 限制密码长度为72字节如果超过则截断
if len(password.encode('utf-8')) > 72:
password = password.encode('utf-8')[:72].decode('utf-8', errors='ignore')
salt = bcrypt.gensalt()
return bcrypt.hashpw(password.encode('utf-8'), salt).decode('utf-8')
def create_password_reset_token(email: str) -> str:
"""创建密码重置令牌"""
delta = timedelta(hours=1) # 1小时有效期
now = datetime.utcnow()
expires = now + delta
exp = expires.timestamp()
encoded_jwt = jwt.encode(
{"exp": exp, "nbf": now, "sub": email, "type": "password_reset"},
settings.JWT_SECRET_KEY,
algorithm=settings.JWT_ALGORITHM,
)
return encoded_jwt
def verify_password_reset_token(token: str) -> Optional[str]:
"""验证密码重置令牌"""
try:
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
# 检查令牌类型
if payload.get("type") != "password_reset":
return None
return payload["sub"]
except JWTError:
return None
# 密码强度验证
def validate_password_strength(password: str) -> dict:
"""验证密码强度"""
errors = []
if len(password) < 8:
errors.append("密码长度至少8位")
if len(password) > 128:
errors.append("密码长度不能超过128位")
if not any(c.islower() for c in password):
errors.append("密码必须包含至少一个小写字母")
if not any(c.isupper() for c in password):
errors.append("密码必须包含至少一个大写字母")
if not any(c.isdigit() for c in password):
errors.append("密码必须包含至少一个数字")
# 检查特殊字符
special_chars = "!@#$%^&*()_+-=[]{}|;:,.<>?"
if not any(c in special_chars for c in password):
errors.append("密码必须包含至少一个特殊字符")
return {
"is_valid": len(errors) == 0,
"errors": errors,
"strength": calculate_password_strength(password)
}
def calculate_password_strength(password: str) -> str:
"""计算密码强度"""
score = 0
# 长度评分
if len(password) >= 8:
score += 1
if len(password) >= 12:
score += 1
if len(password) >= 16:
score += 1
# 字符类型评分
if any(c.islower() for c in password):
score += 1
if any(c.isupper() for c in password):
score += 1
if any(c.isdigit() for c in password):
score += 1
if any(c in "!@#$%^&*()_+-=[]{}|;:,.<>?" for c in password):
score += 1
# 根据评分返回强度等级
if score <= 2:
return ""
elif score <= 4:
return "中等"
elif score <= 6:
return ""
else:
return "非常强"

View File

@@ -0,0 +1,46 @@
from datetime import datetime, timedelta
from typing import Dict, Optional
import threading
class TokenBlacklist:
"""简单的令牌黑名单(内存存储)"""
def __init__(self):
self._blacklisted_tokens: Dict[str, datetime] = {}
self._lock = threading.Lock()
def add_token(self, token: str, expires_at: Optional[datetime] = None):
"""添加令牌到黑名单"""
with self._lock:
# 如果没有提供过期时间默认24小时后过期
if expires_at is None:
expires_at = datetime.utcnow() + timedelta(hours=24)
self._blacklisted_tokens[token] = expires_at
def is_blacklisted(self, token: str) -> bool:
"""检查令牌是否在黑名单中"""
with self._lock:
if token not in self._blacklisted_tokens:
return False
# 检查令牌是否已过期
if datetime.utcnow() > self._blacklisted_tokens[token]:
# 清理过期的令牌
del self._blacklisted_tokens[token]
return False
return True
def cleanup_expired_tokens(self):
"""清理过期的令牌"""
with self._lock:
current_time = datetime.utcnow()
expired_tokens = [
token for token, expires_at in self._blacklisted_tokens.items()
if current_time > expires_at
]
for token in expired_tokens:
del self._blacklisted_tokens[token]
# 全局黑名单实例
token_blacklist = TokenBlacklist()