初次提交
This commit is contained in:
0
backend/app/core/__init__.py
Normal file
0
backend/app/core/__init__.py
Normal file
54
backend/app/core/config.py
Normal file
54
backend/app/core/config.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
from typing import List
|
||||
import os
|
||||
|
||||
class Settings(BaseSettings):
|
||||
# 基础配置
|
||||
ENVIRONMENT: str = "development"
|
||||
DEBUG: bool = True
|
||||
|
||||
# 数据库配置
|
||||
DATABASE_URL: str = "mysql+pymysql://mytest_db:mytest_db@101.126.85.76:3306/mytest_db"
|
||||
|
||||
# Redis配置
|
||||
REDIS_URL: str = "redis://localhost:6379"
|
||||
|
||||
# JWT配置
|
||||
JWT_SECRET_KEY: str = "your-super-secret-jwt-key-change-in-production"
|
||||
JWT_ALGORITHM: str = "HS256"
|
||||
JWT_EXPIRE_MINUTES: int = 30
|
||||
JWT_REFRESH_EXPIRE_DAYS: int = 7
|
||||
|
||||
# CORS配置
|
||||
ALLOWED_HOSTS: List[str] = ["*"] # 允许所有域名访问
|
||||
|
||||
# 文件上传配置
|
||||
MAX_FILE_SIZE: int = 10 * 1024 * 1024 # 10MB
|
||||
UPLOAD_DIR: str = "uploads"
|
||||
ALLOWED_EXTENSIONS: List[str] = [
|
||||
# 图片
|
||||
".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp", ".svg",
|
||||
# 文档
|
||||
".pdf", ".doc", ".docx", ".xls", ".xlsx", ".ppt", ".pptx",
|
||||
".txt", ".rtf", ".csv",
|
||||
# 压缩文件
|
||||
".zip", ".rar", ".7z", ".tar", ".gz",
|
||||
# 音频
|
||||
".mp3", ".wav", ".flac", ".aac", ".ogg",
|
||||
# 视频
|
||||
".mp4", ".avi", ".mkv", ".mov", ".wmv", ".flv",
|
||||
# 代码文件
|
||||
".py", ".js", ".html", ".css", ".json", ".xml", ".yaml", ".yml",
|
||||
".java", ".cpp", ".c", ".h", ".cs", ".php", ".rb", ".go",
|
||||
".sql", ".sh", ".bat", ".ps1", ".md", ".log"
|
||||
]
|
||||
|
||||
# 安全配置
|
||||
BCRYPT_ROUNDS: int = 12
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
case_sensitive = True
|
||||
extra = "allow" # 允许额外的环境变量
|
||||
|
||||
settings = Settings()
|
||||
30
backend/app/core/database.py
Normal file
30
backend/app/core/database.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from app.core.config import settings
|
||||
import pymysql
|
||||
|
||||
# 安装pymysql作为MySQLdb的替代
|
||||
pymysql.install_as_MySQLdb()
|
||||
|
||||
# 创建数据库引擎
|
||||
engine = create_engine(
|
||||
settings.DATABASE_URL,
|
||||
echo=settings.DEBUG,
|
||||
pool_pre_ping=True,
|
||||
pool_recycle=300
|
||||
)
|
||||
|
||||
# 创建会话工厂
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
# 创建基础模型类
|
||||
Base = declarative_base()
|
||||
|
||||
# 数据库依赖
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
151
backend/app/core/security.py
Normal file
151
backend/app/core/security.py
Normal file
@@ -0,0 +1,151 @@
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Union
|
||||
from jose import JWTError, jwt
|
||||
import bcrypt
|
||||
from app.core.config import settings
|
||||
from app.core.token_blacklist import token_blacklist
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""创建访问令牌"""
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(minutes=settings.JWT_EXPIRE_MINUTES)
|
||||
|
||||
to_encode.update({"exp": expire, "type": "access"})
|
||||
encoded_jwt = jwt.encode(to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
def create_refresh_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""创建刷新令牌"""
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(days=settings.JWT_REFRESH_EXPIRE_DAYS)
|
||||
|
||||
to_encode.update({"exp": expire, "type": "refresh"})
|
||||
encoded_jwt = jwt.encode(to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
def verify_token(token: str, token_type: str = "access") -> Optional[dict]:
|
||||
"""验证令牌"""
|
||||
try:
|
||||
# 首先检查令牌是否在黑名单中
|
||||
if token_blacklist.is_blacklisted(token):
|
||||
return None
|
||||
|
||||
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
|
||||
|
||||
# 检查令牌类型
|
||||
if payload.get("type") != token_type:
|
||||
return None
|
||||
|
||||
return payload
|
||||
except JWTError:
|
||||
return None
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""验证密码"""
|
||||
try:
|
||||
return bcrypt.checkpw(plain_password.encode('utf-8'), hashed_password.encode('utf-8'))
|
||||
except:
|
||||
return False
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
"""获取密码哈希"""
|
||||
# bcrypt 限制密码长度为72字节,如果超过则截断
|
||||
if len(password.encode('utf-8')) > 72:
|
||||
password = password.encode('utf-8')[:72].decode('utf-8', errors='ignore')
|
||||
salt = bcrypt.gensalt()
|
||||
return bcrypt.hashpw(password.encode('utf-8'), salt).decode('utf-8')
|
||||
|
||||
def create_password_reset_token(email: str) -> str:
|
||||
"""创建密码重置令牌"""
|
||||
delta = timedelta(hours=1) # 1小时有效期
|
||||
now = datetime.utcnow()
|
||||
expires = now + delta
|
||||
exp = expires.timestamp()
|
||||
encoded_jwt = jwt.encode(
|
||||
{"exp": exp, "nbf": now, "sub": email, "type": "password_reset"},
|
||||
settings.JWT_SECRET_KEY,
|
||||
algorithm=settings.JWT_ALGORITHM,
|
||||
)
|
||||
return encoded_jwt
|
||||
|
||||
def verify_password_reset_token(token: str) -> Optional[str]:
|
||||
"""验证密码重置令牌"""
|
||||
try:
|
||||
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
|
||||
|
||||
# 检查令牌类型
|
||||
if payload.get("type") != "password_reset":
|
||||
return None
|
||||
|
||||
return payload["sub"]
|
||||
except JWTError:
|
||||
return None
|
||||
|
||||
# 密码强度验证
|
||||
def validate_password_strength(password: str) -> dict:
|
||||
"""验证密码强度"""
|
||||
errors = []
|
||||
|
||||
if len(password) < 8:
|
||||
errors.append("密码长度至少8位")
|
||||
|
||||
if len(password) > 128:
|
||||
errors.append("密码长度不能超过128位")
|
||||
|
||||
if not any(c.islower() for c in password):
|
||||
errors.append("密码必须包含至少一个小写字母")
|
||||
|
||||
if not any(c.isupper() for c in password):
|
||||
errors.append("密码必须包含至少一个大写字母")
|
||||
|
||||
if not any(c.isdigit() for c in password):
|
||||
errors.append("密码必须包含至少一个数字")
|
||||
|
||||
# 检查特殊字符
|
||||
special_chars = "!@#$%^&*()_+-=[]{}|;:,.<>?"
|
||||
if not any(c in special_chars for c in password):
|
||||
errors.append("密码必须包含至少一个特殊字符")
|
||||
|
||||
return {
|
||||
"is_valid": len(errors) == 0,
|
||||
"errors": errors,
|
||||
"strength": calculate_password_strength(password)
|
||||
}
|
||||
|
||||
def calculate_password_strength(password: str) -> str:
|
||||
"""计算密码强度"""
|
||||
score = 0
|
||||
|
||||
# 长度评分
|
||||
if len(password) >= 8:
|
||||
score += 1
|
||||
if len(password) >= 12:
|
||||
score += 1
|
||||
if len(password) >= 16:
|
||||
score += 1
|
||||
|
||||
# 字符类型评分
|
||||
if any(c.islower() for c in password):
|
||||
score += 1
|
||||
if any(c.isupper() for c in password):
|
||||
score += 1
|
||||
if any(c.isdigit() for c in password):
|
||||
score += 1
|
||||
if any(c in "!@#$%^&*()_+-=[]{}|;:,.<>?" for c in password):
|
||||
score += 1
|
||||
|
||||
# 根据评分返回强度等级
|
||||
if score <= 2:
|
||||
return "弱"
|
||||
elif score <= 4:
|
||||
return "中等"
|
||||
elif score <= 6:
|
||||
return "强"
|
||||
else:
|
||||
return "非常强"
|
||||
46
backend/app/core/token_blacklist.py
Normal file
46
backend/app/core/token_blacklist.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Optional
|
||||
import threading
|
||||
|
||||
class TokenBlacklist:
|
||||
"""简单的令牌黑名单(内存存储)"""
|
||||
|
||||
def __init__(self):
|
||||
self._blacklisted_tokens: Dict[str, datetime] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def add_token(self, token: str, expires_at: Optional[datetime] = None):
|
||||
"""添加令牌到黑名单"""
|
||||
with self._lock:
|
||||
# 如果没有提供过期时间,默认24小时后过期
|
||||
if expires_at is None:
|
||||
expires_at = datetime.utcnow() + timedelta(hours=24)
|
||||
self._blacklisted_tokens[token] = expires_at
|
||||
|
||||
def is_blacklisted(self, token: str) -> bool:
|
||||
"""检查令牌是否在黑名单中"""
|
||||
with self._lock:
|
||||
if token not in self._blacklisted_tokens:
|
||||
return False
|
||||
|
||||
# 检查令牌是否已过期
|
||||
if datetime.utcnow() > self._blacklisted_tokens[token]:
|
||||
# 清理过期的令牌
|
||||
del self._blacklisted_tokens[token]
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def cleanup_expired_tokens(self):
|
||||
"""清理过期的令牌"""
|
||||
with self._lock:
|
||||
current_time = datetime.utcnow()
|
||||
expired_tokens = [
|
||||
token for token, expires_at in self._blacklisted_tokens.items()
|
||||
if current_time > expires_at
|
||||
]
|
||||
for token in expired_tokens:
|
||||
del self._blacklisted_tokens[token]
|
||||
|
||||
# 全局黑名单实例
|
||||
token_blacklist = TokenBlacklist()
|
||||
Reference in New Issue
Block a user