guzhi/app/core/middlewares.py
邹方成 48b93fdddb feat(valuation): 扩展非遗资产评估模型并完善相关功能
- 在valuation模型中新增多个评估字段,包括稀缺等级、市场活动时间等
- 完善用户端输出模型,确保所有字段正确序列化
- 修复文件上传返回URL缺少BASE_URL的问题
- 更新Docker镜像版本至v1.2
- 添加静态文件路径到中间件排除列表
- 优化估值评估创建接口,自动关联当前用户ID
2025-10-10 08:55:17 +08:00

224 lines
8.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import logging
import re
import time
from datetime import datetime
from typing import Any, AsyncGenerator

from fastapi import FastAPI
from fastapi.responses import Response
from fastapi.routing import APIRoute
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.types import ASGIApp, Receive, Scope, Send

from app.core.dependency import AuthControl
from app.models.admin import AuditLog, User

from .bgtask import BgTasks
class SimpleBaseMiddleware:
    """Minimal pure-ASGI middleware exposing ``before_request``/``after_request`` hooks.

    For HTTP scopes, ``before_request`` is awaited first. If it returns a
    truthy ASGI callable, that callable handles the request. Otherwise the
    wrapped app does. ``after_request`` is awaited only after the handler
    returns without raising. Non-HTTP scopes are passed straight through
    without invoking either hook.
    """

    def __init__(self, app: ASGIApp) -> None:
        # Downstream ASGI application wrapped by this middleware.
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        is_http = scope["type"] == "http"
        if is_http:
            request = Request(scope, receive=receive)
            handler = await self.before_request(request)
            if not handler:
                # A falsy hook result means "dispatch to the wrapped app".
                handler = self.app
            await handler(request.scope, request.receive, send)
            await self.after_request(request)
        else:
            await self.app(scope, receive, send)

    async def before_request(self, request: Request):
        """Hook run before dispatch; by default routes to the wrapped app."""
        return self.app

    async def after_request(self, request: Request):
        """Hook run after the handler completes; no-op by default."""
        return None
class BackGroundTaskMiddleware(SimpleBaseMiddleware):
    """Hook ``BgTasks`` setup and execution around each HTTP request."""

    async def before_request(self, request):
        # Prepare the BgTasks holder (presumably per-request state; confirm in
        # .bgtask). Returning None makes SimpleBaseMiddleware fall back to the
        # wrapped app rather than treating a hook result as the handler.
        await BgTasks.init_bg_tasks_obj()
        return None

    async def after_request(self, request):
        # Run the queued tasks. SimpleBaseMiddleware awaits this only after
        # the handler returned without raising.
        await BgTasks.execute_tasks()
        return None
class HttpAuditLogMiddleware(BaseHTTPMiddleware):
    """Persist an ``AuditLog`` row for each in-scope HTTP request.

    A request is logged when its method is in ``methods`` and its path matches
    none of the regexes in ``exclude_paths`` (``re.search``, case-insensitive).
    Each row carries path, method, status, matched route tags/summary, user,
    elapsed milliseconds, parsed request arguments and the response body.
    """

    def __init__(self, app, methods: list[str], exclude_paths: list[str]):
        super().__init__(app)
        self.methods = methods
        # Regex patterns; a path matching any of them is never logged.
        self.exclude_paths = exclude_paths
        # Audit-log listing endpoints: nested "response_body" fields are
        # stripped so browsing the log does not copy old bodies into new rows.
        self.audit_log_paths = ["/api/v1/auditlog/list"]
        self.max_body_size = 1024 * 1024  # 1MB cap on response bodies buffered for logging

    async def get_request_args(self, request: Request) -> dict:
        """Collect query parameters plus, for POST/PUT/PATCH, the JSON body.

        Form and multipart bodies are deliberately not read. Consuming
        ``request.form()`` here would leave the body unreadable for FastAPI.
        A JSON object body is merged into the result. Any other JSON value
        (array, string, number, null) is stored under ``"body"``. Bodies that
        are not valid JSON are ignored.
        """
        args = dict(request.query_params.items())

        if request.method not in ("POST", "PUT", "PATCH"):
            return args

        content_type = request.headers.get("content-type", "").lower()
        if "multipart/form-data" in content_type or "application/x-www-form-urlencoded" in content_type:
            return args

        # JSON and unknown content types alike: parse as JSON, best effort.
        try:
            body = await request.json()
        except (json.JSONDecodeError, UnicodeDecodeError):
            return args

        if isinstance(body, dict):
            args.update(body)
        else:
            # dict.update() raises on arrays/scalars, which used to abort the
            # whole request; keep the value under its own key instead.
            args["body"] = body
        return args

    async def get_response_body(self, request: Request, response: Response) -> Any:
        """Return the response body for logging while keeping the response sendable.

        If Content-Length exceeds ``max_body_size``, or a streamed body grows
        past it while being read, a placeholder dict is returned instead. For
        streamed responses, the chunks already read are re-queued on
        ``response.body_iterator`` ahead of any unread remainder, so the
        client still receives the full body.
        """
        too_large = {"code": 0, "msg": "Response too large to log", "data": None}

        content_length = response.headers.get("content-length")
        if content_length and int(content_length) > self.max_body_size:
            return too_large

        if hasattr(response, "body"):
            body = response.body
        else:
            source = response.body_iterator
            body_chunks: list[bytes] = []
            size = 0
            truncated = False
            async for chunk in source:
                if not isinstance(chunk, bytes):
                    chunk = chunk.encode(response.charset)
                body_chunks.append(chunk)
                size += len(chunk)
                if size > self.max_body_size:
                    # Stop buffering; the remainder streams straight through.
                    truncated = True
                    break
            response.body_iterator = self._async_iter(body_chunks, source if truncated else None)
            if truncated:
                return too_large
            body = b"".join(body_chunks)

        if not any(request.url.path.startswith(path) for path in self.audit_log_paths):
            return self.lenient_json(body)

        # Audit-log listing: keep only the basic info and drop nested response bodies.
        try:
            data = self.lenient_json(body)
            if isinstance(data, dict):
                data.pop("response_body", None)
                items = data.get("data")
                if isinstance(items, list):
                    for item in items:
                        if isinstance(item, dict):
                            item.pop("response_body", None)
            return data
        except Exception:
            return None

    def lenient_json(self, v: Any) -> Any:
        """Decode ``v`` as JSON if it is str/bytes holding valid JSON; otherwise return it unchanged."""
        if isinstance(v, (str, bytes)):
            try:
                return json.loads(v)
            except (ValueError, TypeError):
                pass
        return v

    def normalize_json_field(self, value: Any) -> Any:
        """Coerce ``value`` into something safe to store in a JSONField.

        Bytes are decoded as UTF-8, dropping undecodable bytes. Blank strings
        become None. JSON text is parsed, and other strings are kept stripped.
        Values ``json.dumps`` cannot serialize are stringified.
        """
        if value is None:
            return None

        if isinstance(value, (bytes, bytearray)):
            try:
                value = value.decode("utf-8")
            except UnicodeDecodeError:
                value = value.decode("utf-8", errors="ignore")

        if isinstance(value, str):
            stripped = value.strip()
            if not stripped:
                return None
            try:
                return json.loads(stripped)
            except (ValueError, TypeError):
                return stripped

        if isinstance(value, (dict, list, int, float, bool)):
            return value

        try:
            json.dumps(value)
            return value
        except (TypeError, ValueError):
            return str(value)

    async def _async_iter(self, items: list[bytes], rest: Any = None) -> AsyncGenerator[bytes, None]:
        """Yield ``items``, then everything left in the async iterator ``rest`` (if given)."""
        for item in items:
            yield item
        if rest is not None:
            async for item in rest:
                yield item

    async def get_request_log(self, request: Request, response: Response) -> dict:
        """Build the base audit-log fields from the request and response."""
        data: dict = {"path": request.url.path, "status": response.status_code, "method": request.method}

        # Route metadata: the router dispatches to the first matching route,
        # so stop at the first match (later matches used to overwrite it).
        app: FastAPI = request.app
        for route in app.routes:
            if (
                isinstance(route, APIRoute)
                and route.path_regex.match(request.url.path)
                and request.method in route.methods
            ):
                data["module"] = ",".join(route.tags)
                data["summary"] = route.summary
                break

        # User info taken from the "token" header. Any failure while resolving
        # it is recorded as an anonymous user (id 0, empty name).
        try:
            token = request.headers.get("token")
            user_obj = None
            if token:
                user_obj: User = await AuthControl.is_authed(token)
            data["user_id"] = user_obj.id if user_obj else 0
            data["username"] = user_obj.username if user_obj else ""
        except Exception:
            data["user_id"] = 0
            data["username"] = ""
        return data

    async def before_request(self, request: Request):
        """Parse request arguments up front and stash them on ``request.state``."""
        request_args = await self.get_request_args(request)
        request.state.request_args = request_args

    async def after_request(self, request: Request, response: Response, process_time: int):
        """Write an ``AuditLog`` row for an in-scope request.

        Returns ``response`` when a row was attempted and None when the request
        is out of scope. A failed database write is logged rather than raised,
        because the handler has already run.
        """
        if request.method not in self.methods:
            return None
        if any(re.search(path, request.url.path, re.I) is not None for path in self.exclude_paths):
            return None

        data: dict = await self.get_request_log(request=request, response=response)
        data["response_time"] = process_time
        request_args = getattr(request.state, "request_args", None)
        response_body = await self.get_response_body(request, response)
        data["request_args"] = self.normalize_json_field(request_args)
        data["response_body"] = self.normalize_json_field(response_body)
        try:
            await AuditLog.create(**data)
        except Exception:
            # Raising here would turn an already-handled request into a 500.
            logging.getLogger(__name__).exception(
                "Failed to write audit log for %s %s", request.method, request.url.path
            )
        return response

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        """Capture request args, run the app, then audit-log the exchange."""
        started = time.perf_counter()
        await self.before_request(request)
        response = await call_next(request)
        # perf_counter is monotonic. Unlike datetime.now(), it cannot yield
        # negative durations when the system clock is adjusted.
        process_time = int((time.perf_counter() - started) * 1000)
        await self.after_request(request, response, process_time)
        return response