
一个真实场景:刚接手一个FastAPI项目,打开代码库,UserRepository、ProductRepository、OrderRepository……每个文件都在重复同样的save、get、update、delete逻辑。复制粘贴了8次之后,我开始怀疑人生——我们真的需要为每个数据表写一遍相同的代码吗?
如果你也有同样的困惑,今天这篇文章会给你一个答案。我将带你用Python泛型和SQLAlchemy,实现一个类型安全、可扩展、可复用的通用仓库模式,让你从此告别重复的CRUD代码。
在大多数FastAPI或SQLAlchemy项目中,仓库层(Repository)长这样:
classUserRepository:
def__init__(self, session: AsyncSession):
self._session = session
asyncdefsave(self, user: User) -> User:
model = UserModel(name=user.name, email=user.email)
self._session.add(model)
await self._session.flush()
await self._session.refresh(model)
return self._to_entity(model)
asyncdefget(self, user_id: UUID) -> User | None:
result = await self._session.scalar(
select(UserModel).where(UserModel.id == user_id)
)
return self._to_entity(result) if result elseNone
# ... 更多方法
然后你创建ProductRepository——复制粘贴。OrderRepository——再次复制粘贴。
每个仓库都包含:
唯一变化的只有三样东西:
User)UserModel)⚠️ 注意:这种重复代码是“复制粘贴综合症”的典型表现,90%的团队在这里踩坑——当业务逻辑需要修改时,你要在8个仓库里改8遍,漏改一个就是Bug。
一个设计良好的通用仓库应该做到:
下面是一份生产级的实现代码。
首先,需要一个所有领域实体共享的基类,保证统一的结构:
from dataclasses import dataclass, field
from datetime import datetime, timezone
from uuid import UUID
@dataclass(kw_only=True)
classEntityBase:
id: UUID | None = None
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
classDatabaseException(Exception):
"""数据库操作异常的统一包装"""
pass
from enum import StrEnum
classOrdering(StrEnum):
"""排序方向,类型安全"""
asc = "asc"
desc = "desc"
这是整个模式的核心。我把它拆成两部分讲解,但你可以直接复制使用。
from abc import ABC, abstractmethod
from typing import Any, Generic, List, TypeVar
import sqlalchemy
from sqlalchemy import asc, desc, func, select
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
# 假设你的Base类在这里定义
from .... import Base
from domain.value_objects.ordering import Ordering
from domain.entities.base import EntityBase
from domain.exceptions.common import DatabaseException
Entity = TypeVar("Entity", bound=EntityBase)
SqlAlchemyModel = TypeVar("SqlAlchemyModel", bound=Base)
classSqlAlchemyAbstractRepository(ABC, Generic[Entity, SqlAlchemyModel]):
# 子类必须指定具体的ORM模型类
model: type[SqlAlchemyModel]
def__init__(self, session: AsyncSession) -> None:
self._session = session
asyncdefsave(self, entity: Entity) -> Entity:
"""保存实体,返回包含数据库生成字段(如ID)的完整实体"""
model = self._entity_to_model(entity)
self._session.add(model)
await self._session.flush()
await self._session.refresh(model)
return self._model_to_entity(model)
asyncdefupdate(
self,
fields_to_update: dict[str, Any],
**filters,
) -> int:
"""根据过滤条件更新字段,返回受影响的行数"""
try:
filter_conditions = self._get_filters(**filters)
query = (
sqlalchemy.update(self.model)
.where(*filter_conditions)
.values(fields_to_update)
)
result = await self._session.execute(query)
await self._session.flush()
return result.rowcount # type: ignore[attr-defined]
except IntegrityError as exception:
await self._session.rollback()
raise exception
except SQLAlchemyError as exception:
await self._session.rollback()
raise DatabaseException from exception
asyncdeflist_all(
self,
page: int = 1,
limit: int = 10,
order_by: str = "created_at",
ordering: Ordering = Ordering.asc,
**filters,
) -> List[Entity]:
"""分页列表查询,支持排序和过滤"""
query = select(self.model)
filter_conditions = self._get_filters(**filters)
query = query.where(*filter_conditions)
# 排序
query = query.order_by(
self._get_order_expression(order_by=order_by, ordering=ordering)
)
# 分页
offset = (page - 1) * limit
query = query.offset(offset).limit(limit)
result = await self._session.execute(query)
models = result.scalars().all()
return [self._model_to_entity(model) for model in models]
asyncdefget(
self,
**filters,
) -> Entity | None:
"""根据过滤条件获取单个实体"""
query = select(self.model)
filter_conditions = self._get_filters(**filters)
query = query.where(*filter_conditions)
model = await self._session.scalar(query)
return self._model_to_entity(model) if model elseNone
asyncdefexists(
self,
**filters,
) -> bool:
"""检查是否存在满足条件的记录"""
query = select(self.model)
filter_conditions = self._get_filters(**filters)
query = query.where(*filter_conditions)
result = await self._session.scalar(query)
return result isnotNone
asyncdefdelete(
self,
**filters,
) -> int:
"""根据过滤条件删除记录,返回删除的行数"""
try:
query = sqlalchemy.delete(self.model)
filter_conditions = self._get_filters(**filters)
query = query.where(*filter_conditions)
result = await self._session.execute(query)
await self._session.flush()
return result.rowcount # type: ignore[attr-defined]
except SQLAlchemyError as e:
await self._session.rollback()
raise DatabaseException from e
asyncdefcount(
self,
**filters,
) -> int:
"""统计满足条件的记录数"""
filter_conditions = self._get_filters(**filters)
return (
await self._session.scalar(
select(func.count()).select_from(self.model).where(*filter_conditions)
)
or0
)
@staticmethod
@abstractmethod
def_model_to_entity(model: SqlAlchemyModel) -> Entity:
"""将ORM模型转换为领域实体——子类必须实现"""
raise NotImplementedError("Subclasses must implement _model_to_entity")
@staticmethod
@abstractmethod
def_entity_to_model(entity: Entity) -> SqlAlchemyModel:
"""将领域实体转换为ORM模型——子类必须实现"""
raise NotImplementedError("Subclasses must implement _entity_to_model")
@abstractmethod
def_get_filters(self, **filters) -> List[Any]:
"""将业务层过滤条件转换为SQLAlchemy查询条件——子类可重写"""
return []
@staticmethod
def_get_order_expression(
order_by: str, ordering: Ordering
) -> sqlalchemy.UnaryExpression[str]:
"""生成排序表达式"""
if ordering == Ordering.asc:
return asc(order_by)
return desc(order_by)
如果上面这段代码让你有点晕,我用一个类比帮你理清:
泛型就像订餐平台的模板:
Entity = TypeVar("Entity", bound=EntityBase) —— 这就像“我要一份饭”,但具体是盖浇饭还是炒饭,后面再定Model = TypeVar("Model", bound=Base) —— 这就像“我要一个餐具”,具体是碗还是盘子,也后面再定SqlAlchemyAbstractRepository[Entity, Model] —— 这个组合就像“我要一份(某种饭)搭配(某种餐具)的套餐”当你创建具体仓库时:
classUserRepository(SqlAlchemyAbstractRepository[User, UserModel]):
...
就相当于说:“我要一份User饭装在UserModel餐具里。”
IDE现在就能准确知道:
save() 接收User,返回User_model_to_entity() 必须把UserModel映射成UserUser有效的字段⚠️ 关键点:Python虽然是动态语言,但通过类型提示和泛型,你可以获得编译时类型检查的能力。这在多人协作时,能避免无数“不小心传错参数”的Bug。
现在创建一个用户仓库,你会发现只需要写三件事:
model类classSqlAlchemyUserRepository(
SqlAlchemyAbstractRepository[User, UserModel],
):
model = UserModel
def_entity_to_model(self, entity: User) -> UserModel:
model = UserModel(
name=entity.name,
email=entity.email,
role=entity.role,
)
# 如果实体已有ID(更新场景),保持ID
if entity.id:
model.id = entity.id
return model
def_model_to_entity(self, model: UserModel) -> User:
return User(
id=model.id,
name=model.name,
email=model.email,
role=model.role,
created_at=model.created_at,
updated_at=model.updated_at,
)
def_get_filters(self, **filters):
"""支持三种过滤条件:id、email、role"""
conditions = []
if"id_filter"in filters:
conditions.append(UserModel.id == filters["id_filter"])
if"email_filter"in filters:
conditions.append(UserModel.email == filters["email_filter"])
if"role_filter"in filters:
conditions.append(UserModel.role == filters["role_filter"])
return conditions
看到没? 整个仓库就这么点代码。
你的仓库只需要关注领域特有的逻辑。
_get_filters这么重要?它让你的查询API既干净又灵活:
# 查询管理员
admins = await user_repo.list_all(
role_filter="admin",
page=1,
limit=20
)
# 按邮箱查找单个用户
user = await user_repo.get(email_filter="john@example.com")
# 检查用户是否存在
exists = await user_repo.exists(email_filter="john@example.com")
不需要为每个查询写单独的SQL,所有过滤条件统一通过_get_filters转换为查询条件。
需要处理特定业务的数据库错误?只需覆盖方法:
classSqlAlchemyUserRepository(...):
# ... 前面的代码
asyncdefsave(self, entity: User) -> User:
try:
returnawait super().save(entity)
except IntegrityError as e:
await self._session.rollback()
# 检查是否是邮箱重复
if"ix_users_email"in str(e):
raise UserAlreadyExistsError(entity.email)
raise
⚠️ 注意:这里的关键是
await self._session.rollback()——忘记回滚会让session处于异常状态,后续操作都会失败。这是90%的人踩过的坑。
通用仓库不代表不能添加特定查询:
classSqlAlchemyUserRepository(...):
# ... 前面的代码
asyncdefget_by_email(self, email: str) -> User | None:
"""按邮箱获取用户(业务常用)"""
returnawait self.get(email_filter=email)
asyncdefget_active_admins(self) -> List[User]:
"""获取活跃管理员(业务特定)"""
returnawait self.list_all(
role_filter="admin",
status_filter="active"
)
通用 ≠ 限制,而是从强大的基础上开始。
在重构一个中等规模的FastAPI项目后,数据是这样的:
核心洞察:这种模式不仅减少了代码量,更重要的是——逻辑集中在一处,修改一次生效全局,Bug率显著下降。
1. DRY原则落地
写一次,修一次,处处生效。
2. 一致性保障
所有仓库行为统一,新人上手零学习成本。
3. 类型安全
告别Any和随意传递的字典,IDE能给你准确的代码补全。
4. 可测试性
测试一次基类,所有仓库都得到测试覆盖。
5. 可维护性
想加软删除?在基类改一次,所有仓库自动支持。
6. 灵活性
需要特殊行为?覆盖方法即可,基类不限制你。
从复制粘贴8个仓库,到用泛型基类一行行抽象出来,这个过程让我意识到一件事:
好的抽象不是炫技,而是当你需要修改代码时,发现只需要改一个地方。
通用仓库模式在Python生态中并不算新,但它结合async、SQLAlchemy和泛型后,能给你的代码质量带来质的飞跃。下次你再新建一个实体时,不用再写那300行CRUD,只需30行映射和过滤逻辑。
如果你正在维护一个数据访问层臃肿的项目,建议逐个仓库迁移,而不是一次性全量替换。先迁移一个非核心的仓库,验证无误后再逐步推进。
_get_filters统一入口你在项目中有没有遇到过类似的重复代码困扰?如果让你设计一个通用仓库,你会加入哪些额外的通用方法(比如批量操作、乐观锁)?欢迎在评论区分享你的思路。

长按👇关注- 数据STUDIO -设为星标,干货速递
