225 lines
8.2 KiB
Python
225 lines
8.2 KiB
Python
import json
import logging
import re
import time
from datetime import datetime
from typing import Any, AsyncGenerator

from fastapi import FastAPI
from fastapi.responses import Response
from fastapi.routing import APIRoute
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.types import ASGIApp, Receive, Scope, Send

from app.core.dependency import AuthControl
from app.models.admin import AuditLog, User

from .bgtask import BgTasks
|
||
|
||
|
||
class SimpleBaseMiddleware:
    """Pure-ASGI middleware skeleton with ``before_request``/``after_request`` hooks.

    Only ``http`` scopes go through the hooks; every other scope type is
    forwarded to the wrapped app untouched. ``before_request`` may return an
    ASGI callable that serves the request instead of the wrapped app; a falsy
    return value means "use the wrapped app". ``after_request`` is awaited once
    the chosen callable has finished, and is skipped if that callable raises
    (there is no try/finally around the dispatch).
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] == "http":
            request = Request(scope, receive=receive)
            handler = await self.before_request(request)
            if not handler:
                # Falsy hook result: fall back to the wrapped application.
                handler = self.app
            await handler(request.scope, request.receive, send)
            await self.after_request(request)
        else:
            # websocket / lifespan scopes bypass the hooks entirely.
            await self.app(scope, receive, send)

    async def before_request(self, request: Request):
        """Hook run before dispatch; by default the wrapped app serves the request."""
        return self.app

    async def after_request(self, request: Request):
        """Hook run after the handler completes; the default is a no-op."""
        return None
|
||
|
||
|
||
class BackGroundTaskMiddleware(SimpleBaseMiddleware):
    """Bracket each HTTP request with ``BgTasks`` setup and execution.

    ``BgTasks.init_bg_tasks_obj()`` is awaited before the request is handled, and
    ``BgTasks.execute_tasks()`` after the wrapped app has finished (the ordering
    comes from ``SimpleBaseMiddleware.__call__``). Presumably this gives each
    request its own task container — confirm against ``BgTasks``.
    """

    async def before_request(self, request):
        await BgTasks.init_bg_tasks_obj()
        # No ASGI callable is returned, so the base class dispatches to self.app.
        return None

    async def after_request(self, request):
        await BgTasks.execute_tasks()
        return None
|
||
|
||
|
||
class HttpAuditLogMiddleware(BaseHTTPMiddleware):
    """Write an ``AuditLog`` row for HTTP requests whose method is in ``methods``.

    Every request has its parsed query/body arguments stashed on
    ``request.state.request_args``. After the downstream app responds, requests
    whose method is in ``methods`` and whose path matches none of the regexes in
    ``exclude_paths`` are persisted with route metadata, caller identity, the
    (size-limited) response body and the processing time in milliseconds.
    """

    def __init__(self, app, methods: list[str], exclude_paths: list[str]):
        super().__init__(app)
        self.methods = methods
        # Regex patterns, matched case-insensitively with re.search, for paths never logged.
        self.exclude_paths = exclude_paths
        # Responses from these paths are logged with nested "response_body" fields removed,
        # so listing audit logs does not re-embed earlier logged payloads.
        self.audit_log_paths = ["/api/v1/auditlog/list"]
        self.max_body_size = 1024 * 1024  # 1 MB response body size limit for logging

    async def get_request_args(self, request: Request) -> dict:
        """Collect query parameters and, for POST/PUT/PATCH, the JSON request body.

        A JSON object body is merged into the result. Any other non-null JSON value
        (array, string, number) is stored under the ``"body"`` key; passing it to
        ``dict.update`` would raise and fail the whole request. Bodies that are not
        valid JSON are skipped.
        """
        args: dict = dict(request.query_params.items())

        if request.method not in ("POST", "PUT", "PATCH"):
            return args

        content_type = request.headers.get("content-type", "")
        if "multipart/form-data" in content_type or "application/x-www-form-urlencoded" in content_type:
            # Form data (including file uploads) is deliberately not read: consuming
            # request.form() here would prevent FastAPI from reading the body again.
            return args

        # Explicit JSON and every other content type are parsed as JSON on a
        # best-effort basis.
        try:
            body = await request.json()
        except (json.JSONDecodeError, UnicodeDecodeError):
            return args

        if isinstance(body, dict):
            args.update(body)
        elif body is not None:
            args["body"] = body
        return args

    async def get_response_body(self, request: Request, response: Response) -> Any:
        """Return the response payload to log, JSON-decoded when possible.

        Responses without a ``body`` attribute (the streaming responses produced by
        ``call_next``) are drained into memory, and ``response.body_iterator`` is
        replaced with a replay of the same chunks so the client still receives the
        full body. Bodies larger than ``max_body_size`` (by Content-Length or by
        actual size) are replaced by a placeholder dict.
        """
        too_large = {"code": 0, "msg": "Response too large to log", "data": None}

        content_length = response.headers.get("content-length")
        if content_length:
            try:
                if int(content_length) > self.max_body_size:
                    return too_large
            except ValueError:
                pass  # malformed header: fall through and measure the real body

        if hasattr(response, "body"):
            body = response.body
        else:
            body_chunks = []
            async for chunk in response.body_iterator:
                if not isinstance(chunk, bytes):
                    chunk = chunk.encode(response.charset)
                body_chunks.append(chunk)
            # The original iterator is exhausted; give the client an identical replay.
            response.body_iterator = self._async_iter(body_chunks)
            body = b"".join(body_chunks)

        # Streamed responses often have no Content-Length, so enforce the limit here too.
        if len(body) > self.max_body_size:
            return too_large

        if any(request.url.path.startswith(path) for path in self.audit_log_paths):
            try:
                data = self.lenient_json(body)
                # Keep only summary info: strip the detailed response bodies of listed logs.
                if isinstance(data, dict):
                    data.pop("response_body", None)
                    if isinstance(data.get("data"), list):
                        for item in data["data"]:
                            if isinstance(item, dict):
                                item.pop("response_body", None)
                return data
            except Exception:
                return None

        return self.lenient_json(body)

    def lenient_json(self, v: Any) -> Any:
        """Parse ``v`` as JSON if it is str/bytes holding valid JSON; otherwise return it unchanged."""
        if isinstance(v, (str, bytes)):
            try:
                return json.loads(v)
            except (ValueError, TypeError):
                pass
        return v

    def normalize_json_field(self, value: Any) -> Any:
        """Make sure a value written to a JSONField is storable.

        bytes are decoded as UTF-8 (dropping undecodable bytes); blank strings
        become None; JSON strings are parsed; other strings are wrapped as
        ``{"text": ...}``; values json.dumps rejects are converted with str().
        """
        if value is None:
            return None

        if isinstance(value, (bytes, bytearray)):
            try:
                value = value.decode("utf-8")
            except UnicodeDecodeError:
                value = value.decode("utf-8", errors="ignore")

        if isinstance(value, str):
            stripped = value.strip()
            if not stripped:
                return None
            try:
                return json.loads(stripped)
            except (ValueError, TypeError):
                # Wrap non-JSON strings in a dict so the JSONField can store them.
                return {"text": stripped}

        if isinstance(value, (dict, list, int, float, bool)):
            return value

        try:
            json.dumps(value)
            return value
        except (TypeError, ValueError):
            return str(value)

    async def _async_iter(self, items: list[bytes]) -> AsyncGenerator[bytes, None]:
        """Replay already-buffered body chunks as an async iterator."""
        for item in items:
            yield item

    async def get_request_log(self, request: Request, response: Response) -> dict:
        """Build the base AuditLog fields from the request, the response and the matched route."""
        data: dict = {"path": request.url.path, "status": response.status_code, "method": request.method}

        # Route metadata: take the first matching APIRoute, mirroring Starlette's
        # first-match routing, so a later route cannot overwrite the real handler's info.
        app: FastAPI = request.app
        for route in app.routes:
            if (
                isinstance(route, APIRoute)
                and route.path_regex.match(request.url.path)
                and request.method in route.methods
            ):
                data["module"] = ",".join(route.tags)
                data["summary"] = route.summary
                break

        # Caller identity from the "token" header. Missing tokens and any auth
        # failure are recorded as anonymous (user_id 0, empty username).
        try:
            token = request.headers.get("token")
            user_obj: User | None = await AuthControl.is_authed(token) if token else None
            data["user_id"] = user_obj.id if user_obj else 0
            data["username"] = user_obj.username if user_obj else ""
        except Exception:
            data["user_id"] = 0
            data["username"] = ""
        return data

    async def before_request(self, request: Request):
        """Parse the request arguments and stash them on ``request.state`` before dispatch."""
        request_args = await self.get_request_args(request)
        request.state.request_args = request_args

    async def after_request(self, request: Request, response: Response, process_time: int):
        """Persist an AuditLog entry for this request when its method and path are in scope.

        Returns ``response`` when an entry was attempted, otherwise None. A failure
        to write the entry is logged rather than raised. The handler has already
        run, so turning it into a 500 would misreport a completed operation.
        """
        if request.method not in self.methods:
            return None
        if any(re.search(pattern, request.url.path, re.I) is not None for pattern in self.exclude_paths):
            return None

        data: dict = await self.get_request_log(request=request, response=response)
        data["response_time"] = process_time

        request_args = getattr(request.state, "request_args", None)
        response_body = await self.get_response_body(request, response)

        data["request_args"] = self.normalize_json_field(request_args)
        data["response_body"] = self.normalize_json_field(response_body)

        try:
            await AuditLog.create(**data)
        except Exception:
            logging.getLogger(__name__).exception(
                "Failed to write audit log for %s %s", request.method, request.url.path
            )

        return response

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        """Collect request args, call the downstream app, then record the audit entry."""
        # perf_counter is monotonic, unlike wall-clock datetime.now(), so clock
        # adjustments cannot produce negative or inflated durations.
        start = time.perf_counter()
        await self.before_request(request)
        response = await call_next(request)
        process_time = int((time.perf_counter() - start) * 1000)
        await self.after_request(request, response, process_time)
        return response
|