一个 FastAPI 项目的 conftest.py 示例,主要包含配置文件定义、测试数据库初始化、数据库 session 创建、HTTP 客户端创建等 fixture。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
@pytest_asyncio.fixture(scope="session", autouse=True)
async def application():
    """Yield the ``app`` object exported by ``app.main``.

    The import happens inside the fixture body, i.e. only when the fixture
    is set up, not when this module is imported.
    """
    from app.main import app as main_app

    yield main_app


@pytest_asyncio.fixture(scope="session", autouse=True)
async def app_settings():
    """Load the ``.env`` file and yield the application settings object.

    The ``.env`` file is looked up three directory levels above this file
    and passed to ``load_dotenv`` with ``override=True``.
    """
    # Walk up three directory levels starting from this file's path.
    base_path = __file__
    for _ in range(3):
        base_path = os.path.dirname(base_path)

    dotenv_file = f"{base_path}/.env"
    load_dotenv(dotenv_file, override=True)  # TODO: uses .env for the time being

    # Imported only after the .env file has been loaded -- presumably so the
    # settings object picks up those values (confirm against app.config).
    from app.config import settings as loaded_settings

    yield loaded_settings


@pytest_asyncio.fixture(scope="session", autouse=True)
async def project_root_path(request):
    """Yield the ``rootpath`` attribute of the pytest config.

    Args:
        request: The pytest fixture request; only ``request.config`` is read.
    """
    pytest_config = request.config
    yield pytest_config.rootpath


@pytest_asyncio.fixture(scope="session", autouse=True)
async def api_prefix(app_settings):
    """Yield the API prefix, read from ``app_settings.API_PREFIX``.

    Args:
        app_settings: The settings object provided by the ``app_settings``
            fixture.
    """
    prefix = app_settings.API_PREFIX
    yield prefix


@pytest_asyncio.fixture(scope="session", autouse=True)
async def setup_testing_db(app_settings):
    """Create the test database for the session and drop it afterwards.

    The test database is named ``pytest_`` + the configured database name;
    ``app_settings.DB_DATABASE`` is overwritten in place with that name. A
    database that already exists under the resulting URI is dropped before
    the new one is created.

    TODO: per the original note, sqlalchemy_utils does not support async
    operation, so the synchronous helpers are used here for now.

    Args:
        app_settings: The settings object provided by the ``app_settings``
            fixture; its ``DB_DATABASE`` attribute is modified.
    """
    configured_name = app_settings.DB_DATABASE
    app_settings.DB_DATABASE = "pytest_" + configured_name

    # The URI is read only after the database name has been replaced.
    test_db_url = app_settings.MARIADB_DATABASE_URI.unicode_string()

    # Remove a database that may be left over under the same URI.
    stale_db_present = database_exists(test_db_url)
    if stale_db_present:
        drop_database(test_db_url)
    create_database(test_db_url)

    yield

    drop_database(test_db_url)


@pytest_asyncio.fixture(scope="session")
async def async_engine(app_settings):
    """Build the async database engine from the application settings.

    Args:
        app_settings: The settings object provided by the ``app_settings``
            fixture; supplies the database URI, the echo flag and the pool
            sizing options.

    Returns:
        The object returned by ``create_async_engine``.
    """
    database_url = app_settings.AIO_MARIADB_DATABASE_URI.unicode_string()
    engine_options = {
        "echo": app_settings.DB_ENABLE_ECHO,
        "pool_size": app_settings.DB_POOL_SIZE,
        "max_overflow": app_settings.DB_POOL_OVERFLOW,
        "future": True,
    }
    return create_async_engine(database_url, **engine_options)


@pytest_asyncio.fixture(scope="function")
async def async_session(async_engine) -> AsyncGenerator[Session, None]:
    """Yield a database session backed by freshly created tables.

    All tables registered in ``SQLModel.metadata`` are created before the
    session is handed out and dropped again after the test; finally the
    engine is disposed.

    Args:
        async_engine: The async engine provided by the ``async_engine``
            fixture; used for the DDL statements and bound to the session.

    Yields:
        A ``Session`` instance created with ``expire_on_commit=False``.
    """
    session_factory = sessionmaker(
        async_engine, class_=Session, expire_on_commit=False
    )

    async with async_engine.begin() as conn:
        await conn.run_sync(SQLModel.metadata.create_all)

    # The session is closed before the tables are dropped: drop_all runs on
    # a separate connection, and a transaction still held open by the session
    # could otherwise block the DROP TABLE statements (e.g. through table
    # metadata locks).
    async with session_factory() as s:
        yield s

    async with async_engine.begin() as conn:
        await conn.run_sync(SQLModel.metadata.drop_all)

    await async_engine.dispose()


@pytest_asyncio.fixture(scope="function")
async def async_client(async_session, app_settings):
    """Yield an HTTP client that talks to the app in-process.

    While the client is in use, the app's ``get_session`` dependency is
    overridden to return the test's ``async_session``. The override is
    undone on teardown so it cannot leak into tests that do not use this
    fixture.

    Args:
        async_session: The session provided by the ``async_session`` fixture.
        app_settings: The settings object provided by the ``app_settings``
            fixture; ``PORT`` is used to build the client's base URL.

    Yields:
        An ``AsyncClient`` using ``ASGITransport`` bound to the app.
    """
    from app.main import app
    from app.models import get_session

    overrides = app.dependency_overrides
    # Remember whatever was registered before so teardown restores it exactly.
    had_previous = get_session in overrides
    previous = overrides.get(get_session)

    overrides[get_session] = lambda: async_session
    try:
        async with AsyncClient(
            transport=ASGITransport(app=app),
            base_url=f"http://localhost:{app_settings.PORT}",
        ) as client:
            yield client
    finally:
        if had_previous:
            overrides[get_session] = previous
        else:
            overrides.pop(get_session, None)


@pytest_asyncio.fixture(scope="function")
async def setup_redis_cache(app_settings):
    """Keep the redis cache context open for the duration of a test.

    The async context manager returned by ``build_redis_cache(app_settings)``
    is entered before the test and exited after it; nothing is yielded.

    Args:
        app_settings: The settings object provided by the ``app_settings``
            fixture, passed through to ``build_redis_cache``.
    """
    from ..models.redis import build_redis_cache

    cache_context = build_redis_cache(app_settings)
    async with cache_context:
        yield