FastAPI中间件包装响应报文

作者:zhangyunlong 发布时间: 2025-10-24 阅读量:6 评论数:0

在Java中用习惯了Spring的ControllerAdvice对响应报文进行统一包装, 转到Python的FastAPI框架也可以做到类似的包装:

1. 定义错误码枚举类

# error_code_config.py
from enum import Enum

class ErrorCode(Enum):
    USER_NOT_FOUND = (1001, "用户不存在")
    INVALID_TOKEN = (1002, "无效的令牌")
    PERMISSION_DENIED = (1003, "没有访问权限")
    PARAMS_MISSING = (1004, "缺少必要参数")
    SYSTEM_ERROR = (9999, "系统内部错误")

    def __init__(self, code: int, msg: str):
        self._code = code
        self._msg = msg

    @property
    def code(self) -> int:
        return self._code

    @property
    def msg(self) -> str:
        return self._msg

2. 定义响应报文包装中间件

# wrap.py
import json
import base64
import logging
from typing import Callable, Iterable

from fastapi import Request
from fastapi.responses import Response, JSONResponse, StreamingResponse
from config.error_code_config import ErrorCode

logger = logging.getLogger("uvicorn.error")


class BusinessException(Exception):
    """业务异常类"""

    def __init__(self, error_code: ErrorCode):
        """
        业务异常:必须使用 ErrorCode 枚举类进行构造
        """
        if not isinstance(error_code, ErrorCode):
            raise ValueError("BusinessException 必须使用 ErrorCode 枚举")
        self.code = error_code.code
        self.msg = error_code.msg
        super().__init__(self.msg)


def native_response(func: Callable):
    """
    装饰器:标记该 endpoint 的响应报文不被封装
    用法:
      @app.get("/raw")
      @native_response
      async def raw(): ...
    """
    setattr(func, "__native_response__", True)
    return func


def create_wrap_middleware(exclude_paths: Iterable[str] = None):
    """
    返回一个可用于 app.middleware("http") 的函数式中间件
    """
    exclude = set(exclude_paths or ["/docs", "/redoc", "/openapi.json", "/swagger", "/health"])

    async def wrap_middleware(request: Request, call_next):
        print("开始进行报文包装")
        # 如果请求路径在排除列表,直接走下游(不做封装)
        if any(request.url.path.startswith(p) for p in exclude):
            print("静态资源, 无需报文包装")
            return await call_next(request)

        # 先调用下游,routing 会填充 request.scope["endpoint"]
        response = await call_next(request)

        # 如果路由函数标记了 native_response,直接返回原始 response(不做任何读取/封装)
        endpoint = request.scope.get("endpoint")
        if endpoint and getattr(endpoint, "__native_response__", False):
            print("标记为原始响应, 无需报文包装")
            return response

        # 流式响应或 SSE 不做封装(保持原样)
        content_type = (response.headers.get("content-type") or "").lower()
        if isinstance(response, StreamingResponse) or "text/event-stream" in content_type:
            print("流式响应, 无需报文包装")
            return response

        # 读取原始 body(多数 Response 提供 body_iterator)
        body_bytes = b""
        try:
            async for chunk in response.body_iterator:  # type: ignore
                body_bytes += chunk
        except Exception:
            # 兜底处理:一些 Response 没有 body_iterator,可以直接取 .body
            try:
                raw = getattr(response, "body", None)
                if raw is None:
                    body_bytes = b""
                elif isinstance(raw, bytes):
                    body_bytes = raw
                elif isinstance(raw, str):
                    body_bytes = raw.encode(getattr(response, "charset", "utf-8") or "utf-8")
                else:
                    body_bytes = str(raw).encode("utf-8")
            except Exception:
                body_bytes = b""

        # 解析 text/json
        body_text = None
        body_obj = None
        if body_bytes:
            try:
                body_text = body_bytes.decode(response.charset or "utf-8")
            except Exception:
                body_text = None

            if body_text is not None:
                try:
                    body_obj = json.loads(body_text)
                except Exception:
                    body_obj = None

        # 如果已经是统一结构 {code,msg,...} 则直接返回(不二次封装)
        if isinstance(body_obj, dict) and {"code", "msg"}.issubset(body_obj.keys()):
            print("已是包装结构, 无需再次包装")
            # 复制 headers 并移除可能不匹配的 header,避免 h11 的 Content-Length 错误
            new_headers = dict(response.headers)
            for h in list(new_headers.keys()):
                if h.lower() in ("content-length", "transfer-encoding"):
                    new_headers.pop(h, None)
            new_headers["content-type"] = "application/json; charset=utf-8"

            return Response(
                content=json.dumps(body_obj, ensure_ascii=False),
                status_code=response.status_code,
                headers=new_headers,
                media_type="application/json"
            )

        # 构造统一响应体
        def success_body(data):
            return {
                "code": 200,
                "msg": "success",
                "data": data
            }

        if body_obj is not None:
            wrapped = success_body(body_obj)
        elif body_text is not None:
            wrapped = success_body(body_text)
        else:
            wrapped = success_body({"__base64_bytes": base64.b64encode(body_bytes).decode()})
        content = json.dumps(wrapped, ensure_ascii=False)

        # 复制并清理 headers:必须移除 Content-Length/Transfer-Encoding 等,避免与封装后的新 body 长度冲突
        new_headers = dict(response.headers)
        for h in list(new_headers.keys()):
            if h.lower() in ("content-length", "transfer-encoding"):
                new_headers.pop(h, None)

        new_headers["content-type"] = "application/json; charset=utf-8"
        print("包装完成")
        return Response(content=content, status_code=response.status_code, headers=new_headers,
                        media_type="application/json")

    # 返回构建好的函数式中间件
    return wrap_middleware


def register_exception_handlers(app):
    """
    异常处理器, 处理业务异常和系统异常, 包装为统一格式
    :param app: 应用
    :return: {"code":错误码, "msg":错误信息, "data":None}
    """

    @app.exception_handler(BusinessException)
    async def business_exc_handler(request: Request, exc: BusinessException):
        """
        业务异常处理
        :param request: http请求
        :param exc: 异常信息
        :return: {"code":错误码, "msg":错误信息, "data":None}
        """
        logger.error(f"发生业务异常, 异常信息: {exc.msg}", exc_info=True)
        body = {
            "code": exc.code,
            "msg": exc.msg,
            "data": None
        }
        return JSONResponse(content=body, status_code=200)

    @app.exception_handler(Exception)
    async def general_exc_handler(request: Request, exc: Exception):
        """
        兜底处理其他异常
        :param request: http请求
        :param exc: 异常信息
        :return: {"code":1000, "msg":系统异常, "data":None}
        """
        logger.exception(f"发生了系统异常, 异常信息: {exc}")
        body = {
            "code": 1000,
            "msg": "系统异常",
            "data": None
        }
        return JSONResponse(content=body, status_code=200)

3. 定义中间件配置类

# middleware_config.py
from middleware.wrap import create_wrap_middleware, register_exception_handlers
from fastapi.middleware.cors import CORSMiddleware


def setup_middleware(app):
    """
    配置各种中间件, 包括跨域, 响应报文封装, 鉴权等
    :param app: 应用
    :return: None
    """
    # 配置跨域资源共享
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"]
    )
    # 注册全局异常处理
    register_exception_handlers(app)

    # 配置应用程序中间件, 按添加顺序倒序执行
    app.middleware("http")(create_wrap_middleware(exclude_paths=["/docs", "/redoc", "/openapi.json", "/swagger", "/health"]))
    # 下面继续添加其他中间件

4. 包装报文使用示例

from fastapi import APIRouter

from config.error_code_config import ErrorCode
from middleware.wrap import BusinessException, native_response

router = APIRouter(tags=["业务1相关接口"], prefix="/demo1")


@router.get("/test1")
async def test():
    # 正常业务, 会被包装
    print("业务1收到请求")
    return "ok"


@router.get("test2")
async def test2():
    # 抛出异常, 会被包装
    raise BusinessException(ErrorCode.INVALID_TOKEN)


@native_response
@router.get("test3")
async def test3():
    # 使用了装饰器标识, 不会被包装
    return "test2"

5. 在主程序中添加中间件配置类

# main.py
import uvicorn
from fastapi import FastAPI
from router import demo1_router
from config.middleware_config import setup_middleware

app = FastAPI()

# 配置中间件
setup_middleware(app)

# 添加路由
app.include_router(demo1_router.router)

@app.get("/health")
async def root():
    return {"ok"}


if __name__ == "__main__":
    uvicorn.run("main:app", host="127.0.0.1", port=8000, reload=True)

评论