296 lines
10 KiB
Python
296 lines
10 KiB
Python
import pytest
|
|
import json
|
|
from httpx import AsyncClient
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy.orm import sessionmaker
|
|
|
|
from app.main import app
|
|
from app.core.database import get_db, Base
|
|
from app.core.config import settings
|
|
from app.models.user import User
|
|
|
|
# Test database configuration: a file-backed SQLite database at
# ./test_auth.db (resolved relative to the current working directory).
SQLALCHEMY_DATABASE_URL = "sqlite:///./test_auth.db"

# sqlite3 connections by default refuse to be used from any thread other than
# the one that created them; check_same_thread=False lifts that restriction.
engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    connect_args=dict(check_same_thread=False),
)

# Session factory bound to the test engine; no autoflush, no autocommit.
TestingSessionLocal = sessionmaker(
    bind=engine,
    autoflush=False,
    autocommit=False,
)
|
|
|
|
@pytest.fixture(scope="function")
def db_session():
    """Yield a SQLAlchemy session backed by a freshly created schema.

    Every table registered on ``Base.metadata`` is created before the test
    runs; afterwards the session is closed and all those tables are dropped,
    so each test starts from an empty database.
    """
    Base.metadata.create_all(bind=engine)
    sess = TestingSessionLocal()
    try:
        yield sess
    finally:
        # Teardown: release the session first, then wipe the schema.
        sess.close()
        Base.metadata.drop_all(bind=engine)
|
|
|
|
@pytest.fixture(scope="function")
def client(db_session):
    """Provide an httpx ``AsyncClient`` that calls the FastAPI app in-process.

    The app's ``get_db`` dependency is overridden so that every request made
    during a test uses the same ``db_session`` the test itself receives.
    Overrides are cleared on teardown so they do not leak into other tests.

    Yields:
        AsyncClient: client with base URL ``http://test``.
    """
    # Local import: ASGITransport is httpx's supported way to drive an ASGI
    # app without a network; the ``AsyncClient(app=...)`` shortcut was
    # deprecated in httpx 0.27 and removed in 0.28.
    from httpx import ASGITransport

    def override_get_db():
        # Hand out the shared test session. Its lifecycle (close + schema
        # drop) is owned by the ``db_session`` fixture, so nothing is
        # cleaned up here.
        yield db_session

    app.dependency_overrides[get_db] = override_get_db

    # AsyncClient implements only the *async* context-manager protocol, so
    # the previous ``with AsyncClient(...)`` raised before any test could run.
    # This fixture is synchronous (it is consumed by async tests marked with
    # pytest.mark.asyncio), so the client is created without being entered.
    # ASGITransport calls the app in-process and opens no sockets.
    # NOTE(review): if pytest-asyncio's fixture decorator is adopted, an async
    # fixture using ``async with`` would additionally close the client.
    ac = AsyncClient(transport=ASGITransport(app=app), base_url="http://test")
    try:
        yield ac
    finally:
        app.dependency_overrides.clear()
|
|
|
|
@pytest.fixture
def test_user_data():
    """Return a valid registration payload for a test user.

    ``password`` and ``confirm_password`` carry the same value.
    """
    password = "TestPass123!"
    return dict(
        username="testuser123",
        email="test123@example.com",
        password=password,
        confirm_password=password,
    )
|
|
|
|
@pytest.mark.asyncio
async def test_register_user_success(client: AsyncClient, test_user_data):
    """Registering a new user returns 201 with the user profile and a token pair."""
    resp = await client.post("/api/v1/auth/register", json=test_user_data)
    assert resp.status_code == 201

    body = resp.json()
    assert body["success"] is True

    payload = body["data"]
    assert "user" in payload
    assert "tokens" in payload

    user, tokens = payload["user"], payload["tokens"]
    assert user["username"] == test_user_data["username"]
    assert user["email"] == test_user_data["email"]
    for token_key in ("access_token", "refresh_token"):
        assert token_key in tokens
|
|
|
|
@pytest.mark.asyncio
async def test_register_user_username_exists(client: AsyncClient, test_user_data):
    """Registering the same username twice is rejected with USERNAME_EXISTS (400)."""
    # Register the user once. Fail fast if this setup step breaks, so a
    # registration regression is not misreported as a duplicate-check failure.
    first = await client.post("/api/v1/auth/register", json=test_user_data)
    assert first.status_code == 201, first.text

    # Register again with the identical payload (same username).
    response = await client.post("/api/v1/auth/register", json=test_user_data)

    assert response.status_code == 400
    data = response.json()
    assert data["detail"]["code"] == "USERNAME_EXISTS"
|
|
|
|
@pytest.mark.asyncio
async def test_register_user_email_exists(client: AsyncClient, test_user_data):
    """Reusing an email under a different username is rejected with EMAIL_EXISTS (400)."""
    # Register the user once. Fail fast if this setup step breaks, so a
    # registration regression is not misreported as a duplicate-check failure.
    first = await client.post("/api/v1/auth/register", json=test_user_data)
    assert first.status_code == 201, first.text

    # Same email, different username, so only the email uniqueness check
    # should trip.
    duplicate_email_data = test_user_data.copy()
    duplicate_email_data["username"] = "differentuser"

    response = await client.post("/api/v1/auth/register", json=duplicate_email_data)

    assert response.status_code == 400
    data = response.json()
    assert data["detail"]["code"] == "EMAIL_EXISTS"
|
|
|
|
@pytest.mark.asyncio
async def test_register_user_invalid_email(client: AsyncClient, test_user_data):
    """A malformed email address is rejected by request validation."""
    payload = {**test_user_data, "email": "invalid-email"}

    resp = await client.post("/api/v1/auth/register", json=payload)

    assert resp.status_code == 422  # Validation error
|
|
|
|
@pytest.mark.asyncio
async def test_register_user_weak_password(client: AsyncClient, test_user_data):
    """A too-weak password (confirmed identically) is rejected by validation."""
    weak = "123"
    payload = {**test_user_data, "password": weak, "confirm_password": weak}

    resp = await client.post("/api/v1/auth/register", json=payload)

    assert resp.status_code == 422  # Validation error
|
|
|
|
@pytest.mark.asyncio
async def test_register_user_password_mismatch(client: AsyncClient, test_user_data):
    """A confirm_password that differs from password is rejected by validation."""
    payload = {**test_user_data, "confirm_password": "DifferentPass123!"}

    resp = await client.post("/api/v1/auth/register", json=payload)

    assert resp.status_code == 422  # Validation error
|
|
|
|
@pytest.mark.asyncio
async def test_login_user_success(client: AsyncClient, test_user_data):
    """Logging in with username and password returns 200 with user and tokens."""
    # Register the user first. Fail fast here so a registration failure is
    # not misreported as a login failure.
    register_response = await client.post("/api/v1/auth/register", json=test_user_data)
    assert register_response.status_code == 201, register_response.text

    # Log in with the username.
    login_data = {
        "username": test_user_data["username"],
        "password": test_user_data["password"],
    }
    response = await client.post("/api/v1/auth/login", json=login_data)

    assert response.status_code == 200
    data = response.json()
    assert data["success"] is True
    assert "user" in data["data"]
    assert "tokens" in data["data"]
    assert data["data"]["user"]["username"] == test_user_data["username"]
|
|
|
|
@pytest.mark.asyncio
async def test_login_user_with_email(client: AsyncClient, test_user_data):
    """The login endpoint accepts the email address in the ``username`` field."""
    # Register the user first. Fail fast here so a registration failure is
    # not misreported as a login failure.
    register_response = await client.post("/api/v1/auth/register", json=test_user_data)
    assert register_response.status_code == 201, register_response.text

    # Log in using the email as the identifier.
    login_data = {
        "username": test_user_data["email"],
        "password": test_user_data["password"],
    }
    response = await client.post("/api/v1/auth/login", json=login_data)

    assert response.status_code == 200
    data = response.json()
    assert data["success"] is True
    # The email must resolve to the account that was just registered.
    assert data["data"]["user"]["username"] == test_user_data["username"]
|
|
|
|
@pytest.mark.asyncio
async def test_login_user_invalid_credentials(client: AsyncClient, test_user_data):
    """A wrong password for an existing user yields 401 INVALID_CREDENTIALS."""
    # Register the user first, and verify it worked. A nonexistent user also
    # produces 401 INVALID_CREDENTIALS, so without this check the test could
    # pass without ever exercising the wrong-password path.
    register_response = await client.post("/api/v1/auth/register", json=test_user_data)
    assert register_response.status_code == 201, register_response.text

    # Log in with a wrong password.
    login_data = {
        "username": test_user_data["username"],
        "password": "WrongPassword123!",
    }
    response = await client.post("/api/v1/auth/login", json=login_data)

    assert response.status_code == 401
    data = response.json()
    assert data["detail"]["code"] == "INVALID_CREDENTIALS"
|
|
|
|
@pytest.mark.asyncio
async def test_login_user_not_exists(client: AsyncClient):
    """Logging in as an unknown user yields 401 INVALID_CREDENTIALS."""
    resp = await client.post(
        "/api/v1/auth/login",
        json={"username": "nonexistentuser", "password": "SomePassword123!"},
    )

    assert resp.status_code == 401
    error_code = resp.json()["detail"]["code"]
    assert error_code == "INVALID_CREDENTIALS"
|
|
|
|
@pytest.mark.asyncio
async def test_get_current_user_info(client: AsyncClient, test_user_data):
    """GET /me with a valid access token returns the authenticated user."""
    # Register the user first. Check the status before indexing into the
    # body, so a failure reports the server's response instead of an opaque
    # KeyError/TypeError.
    register_response = await client.post("/api/v1/auth/register", json=test_user_data)
    assert register_response.status_code == 201, register_response.text
    access_token = register_response.json()["data"]["tokens"]["access_token"]

    # Fetch the current user's info.
    headers = {"Authorization": f"Bearer {access_token}"}
    response = await client.get("/api/v1/auth/me", headers=headers)

    assert response.status_code == 200
    data = response.json()
    assert data["success"] is True
    assert data["data"]["user"]["username"] == test_user_data["username"]
|
|
|
|
@pytest.mark.asyncio
async def test_get_current_user_info_invalid_token(client: AsyncClient):
    """GET /me with a garbage bearer token is rejected with 401."""
    resp = await client.get(
        "/api/v1/auth/me",
        headers={"Authorization": "Bearer invalid_token"},
    )

    assert resp.status_code == 401
|
|
|
|
@pytest.mark.asyncio
async def test_get_current_user_info_no_token(client: AsyncClient):
    """GET /me without an Authorization header is rejected."""
    # NOTE(review): 403 matches FastAPI's HTTPBearer behaviour for missing
    # credentials in the version this was written against; verify this
    # expectation when upgrading FastAPI, as newer releases may report 401.
    resp = await client.get("/api/v1/auth/me")

    assert resp.status_code == 403  # HTTPBearer returns 403 for missing credentials
|
|
|
|
@pytest.mark.asyncio
async def test_refresh_token_success(client: AsyncClient, test_user_data):
    """Exchanging a valid refresh token returns a new access/refresh token pair."""
    # Register the user first. Check the status before indexing into the
    # body, so a failure reports the server's response instead of an opaque
    # KeyError/TypeError.
    register_response = await client.post("/api/v1/auth/register", json=test_user_data)
    assert register_response.status_code == 201, register_response.text
    refresh_token = register_response.json()["data"]["tokens"]["refresh_token"]

    # Refresh the tokens.
    refresh_data = {"refresh_token": refresh_token}
    response = await client.post("/api/v1/auth/refresh", json=refresh_data)

    assert response.status_code == 200
    data = response.json()
    assert data["success"] is True
    assert "tokens" in data["data"]
    assert "access_token" in data["data"]["tokens"]
    assert "refresh_token" in data["data"]["tokens"]
|
|
|
|
@pytest.mark.asyncio
async def test_refresh_token_invalid(client: AsyncClient):
    """An unrecognised refresh token yields 401 INVALID_REFRESH_TOKEN."""
    resp = await client.post(
        "/api/v1/auth/refresh",
        json={"refresh_token": "invalid_refresh_token"},
    )

    assert resp.status_code == 401
    error_code = resp.json()["detail"]["code"]
    assert error_code == "INVALID_REFRESH_TOKEN"
|
|
|
|
@pytest.mark.asyncio
async def test_logout_success(client: AsyncClient, test_user_data):
    """Logging out with a valid access token succeeds with the expected message."""
    # Register the user first. Check the status before indexing into the
    # body, so a failure reports the server's response instead of an opaque
    # KeyError/TypeError.
    register_response = await client.post("/api/v1/auth/register", json=test_user_data)
    assert register_response.status_code == 201, register_response.text
    access_token = register_response.json()["data"]["tokens"]["access_token"]

    # Log out.
    headers = {"Authorization": f"Bearer {access_token}"}
    response = await client.post("/api/v1/auth/logout", headers=headers)

    assert response.status_code == 200
    data = response.json()
    assert data["success"] is True
    # Expected server message ("logout successful").
    assert data["message"] == "登出成功"
|
|
|
|
@pytest.mark.asyncio
async def test_register_user_invalid_username(client: AsyncClient, test_user_data):
    """A username that is too short is rejected by validation."""
    payload = {**test_user_data, "username": "us"}  # too short

    resp = await client.post("/api/v1/auth/register", json=payload)

    assert resp.status_code == 422  # Validation error
|
|
|
|
@pytest.mark.asyncio
async def test_register_user_username_with_special_chars(client: AsyncClient, test_user_data):
    """A username containing special characters (here '@') is rejected by validation."""
    payload = {**test_user_data, "username": "user@name"}  # contains a special character

    resp = await client.post("/api/v1/auth/register", json=payload)

    assert resp.status_code == 422  # Validation error
|
|
|
|
if __name__ == "__main__":
    # pytest.main returns the session exit code. Propagate it, so that running
    # this file directly reports failures to the shell/CI instead of always
    # exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))