import json
import re
from datetime import datetime
from typing import Any, AsyncGenerator

from fastapi import FastAPI
from fastapi.responses import Response
from fastapi.routing import APIRoute
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.types import ASGIApp, Receive, Scope, Send

from app.core.dependency import AuthControl
from app.models.admin import AuditLog, User

from .bgtask import BgTasks


class SimpleBaseMiddleware:
    """Minimal pure-ASGI middleware with ``before_request``/``after_request`` hooks.

    Subclasses override the hooks. ``before_request`` may return an ASGI
    callable to handle the request instead of the wrapped app; any falsy
    return value falls back to the wrapped app.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        # Only HTTP traffic goes through the hooks; everything else is
        # forwarded to the wrapped app untouched.
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        request = Request(scope, receive=receive)

        response = await self.before_request(request) or self.app
        await response(request.scope, request.receive, send)
        # NOTE: after_request is not reached if the call above raises.
        await self.after_request(request)

    async def before_request(self, request: Request):
        """Hook run before the request is handled; returns the handler to call."""
        return self.app

    async def after_request(self, request: Request):
        """Hook run after the response has been sent; does nothing by default."""
        return None


class BackGroundTaskMiddleware(SimpleBaseMiddleware):
    """Calls ``BgTasks`` hooks around every HTTP request."""

    async def before_request(self, request):
        # Returns None implicitly, so the base class falls back to self.app.
        # NOTE(review): presumably prepares a per-request task container;
        # confirm against BgTasks.
        await BgTasks.init_bg_tasks_obj()

    async def after_request(self, request):
        # NOTE(review): presumably runs the tasks queued during the request;
        # confirm against BgTasks.
        await BgTasks.execute_tasks()


class HttpAuditLogMiddleware(BaseHTTPMiddleware):
    """Writes an ``AuditLog`` row for requests whose method is in ``methods``.

    Args:
        app: The wrapped ASGI application.
        methods: HTTP methods that should be audited.
        exclude_paths: Regular expressions (matched case-insensitively with
            ``re.search``) for request paths that must not be audited.
    """

    def __init__(self, app, methods: list[str], exclude_paths: list[str]):
        super().__init__(app)
        self.methods = methods
        self.exclude_paths = exclude_paths
        # Responses of these endpoints get their "response_body" fields
        # stripped before being logged (see get_response_body).
        self.audit_log_paths = ["/api/v1/auditlog/list"]
        self.max_body_size = 1024 * 1024  # 1MB response body size limit

    async def get_request_args(self, request: Request) -> dict:
        """Collect the query parameters and, where possible, the JSON body.

        Form and multipart bodies are deliberately not read here. For other
        content types the body is parsed as JSON on a best-effort basis.

        Returns:
            A dict of query parameters merged with the JSON body when the
            body is a JSON object. A non-object, non-null JSON body (array,
            string, number, ...) is stored under the ``"body"`` key instead.
        """
        args: dict = dict(request.query_params.items())

        if request.method not in ("POST", "PUT", "PATCH"):
            return args

        content_type = request.headers.get("content-type", "")
        if "multipart/form-data" in content_type or "application/x-www-form-urlencoded" in content_type:
            # Do not consume request.form() in the middleware (file uploads
            # included): doing so would stop FastAPI from reading the body.
            return args

        # Declared JSON and unknown content types are handled the same way:
        # try to parse as JSON and silently skip the body when that fails.
        try:
            body = await request.json()
        except (json.JSONDecodeError, UnicodeDecodeError):
            return args

        if isinstance(body, dict):
            args.update(body)
        elif body is not None:
            # dict.update() raises TypeError/ValueError for lists and scalars,
            # which would otherwise fail the whole request.
            args["body"] = body
        return args

    async def get_response_body(self, request: Request, response: Response) -> Any:
        """Return the response payload to log, parsed as JSON when possible.

        Responses larger than ``max_body_size`` are replaced with a short
        placeholder dict. For the paths in ``audit_log_paths`` the
        ``response_body`` fields are removed from the payload, and ``None``
        is returned if that clean-up fails.
        """
        too_large = {"code": 0, "msg": "Response too large to log", "data": None}

        content_length = response.headers.get("content-length")
        try:
            if content_length and int(content_length) > self.max_body_size:
                return too_large
        except ValueError:
            # Malformed header: fall through and rely on the actual size.
            pass

        if hasattr(response, "body"):
            body = response.body
        else:
            body_chunks = []
            async for chunk in response.body_iterator:
                if not isinstance(chunk, bytes):
                    chunk = chunk.encode(response.charset)
                body_chunks.append(chunk)

            # The iterator is exhausted now; replace it so the response can
            # still be sent to the client.
            response.body_iterator = self._async_iter(body_chunks)
            body = b"".join(body_chunks)

        # Also enforce the limit when no Content-Length header was present.
        if len(body) > self.max_body_size:
            return too_large

        if any(request.url.path.startswith(path) for path in self.audit_log_paths):
            try:
                data = self.lenient_json(body)
                # Keep only the basic information and drop the detailed
                # response content.
                if isinstance(data, dict):
                    data.pop("response_body", None)
                    if "data" in data and isinstance(data["data"], list):
                        for item in data["data"]:
                            item.pop("response_body", None)
                return data
            except Exception:
                return None

        return self.lenient_json(body)

    def lenient_json(self, v: Any) -> Any:
        """Parse ``v`` as JSON if it is ``str``/``bytes``; otherwise, or on failure, return it unchanged."""
        if isinstance(v, (str, bytes)):
            try:
                return json.loads(v)
            except (ValueError, TypeError):
                pass
        return v

    def normalize_json_field(self, value: Any) -> Any:
        """Make sure the value written to a JSONField is valid.

        Bytes are decoded as UTF-8 with undecodable bytes dropped. Strings
        are parsed as JSON; a blank string becomes ``None`` and a non-JSON
        string is wrapped as ``{"text": ...}``. Any other value that
        ``json.dumps`` rejects is converted with ``str()``.
        """
        if value is None:
            return None

        if isinstance(value, (bytes, bytearray)):
            value = value.decode("utf-8", errors="ignore")

        if isinstance(value, str):
            stripped = value.strip()
            if not stripped:
                return None
            try:
                return json.loads(stripped)
            except (ValueError, TypeError):
                # Wrap non-JSON strings in a dict so that the JSONField can
                # store them correctly.
                return {"text": stripped}

        if isinstance(value, (dict, list, int, float, bool)):
            return value

        try:
            json.dumps(value)
            return value
        except (TypeError, ValueError):
            return str(value)

    async def _async_iter(self, items: list[bytes]) -> AsyncGenerator[bytes, None]:
        """Yield ``items`` one by one as an async generator."""
        for item in items:
            yield item

    async def get_request_log(self, request: Request, response: Response) -> dict:
        """Build the audit-log record for a request/response pair.

        Returns:
            A dict with ``path``, ``status``, ``method``, ``user_id`` and
            ``username``, plus ``module`` and ``summary`` when a route
            matches the request.
        """
        data: dict = {"path": request.url.path, "status": response.status_code, "method": request.method}

        # Route information. There is no early exit, so the last matching
        # route wins.
        app: FastAPI = request.app
        for route in app.routes:
            if (
                isinstance(route, APIRoute)
                and route.path_regex.match(request.url.path)
                and request.method in route.methods
            ):
                data["module"] = ",".join(route.tags)
                data["summary"] = route.summary

        # User information; any failure falls back to an anonymous record.
        try:
            token = request.headers.get("token")
            user_obj = None
            if token:
                user_obj = await AuthControl.is_authed(token)
            data["user_id"] = user_obj.id if user_obj else 0
            data["username"] = user_obj.username if user_obj else ""
        except Exception:
            data["user_id"] = 0
            data["username"] = ""
        return data

    async def before_request(self, request: Request):
        """Capture the request arguments on ``request.state.request_args``."""
        request_args = await self.get_request_args(request)
        request.state.request_args = request_args

    async def after_request(self, request: Request, response: Response, process_time: int):
        """Persist an audit-log row unless the method or path is not audited.

        Args:
            request: The handled request.
            response: The response produced for it.
            process_time: Handling time in milliseconds.

        Returns:
            ``None`` when the path matches an exclusion pattern, otherwise
            the response.
        """
        if request.method in self.methods:
            for path in self.exclude_paths:
                if re.search(path, request.url.path, re.I) is not None:
                    return
            data: dict = await self.get_request_log(request=request, response=response)
            data["response_time"] = process_time

            request_args = getattr(request.state, "request_args", None)
            response_body = await self.get_response_body(request, response)
            data["request_args"] = self.normalize_json_field(request_args)
            data["response_body"] = self.normalize_json_field(response_body)
            await AuditLog.create(**data)

        return response

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        """Run the request through the app, timing it and auditing the result."""
        start_time: datetime = datetime.now()
        await self.before_request(request)
        response = await call_next(request)
        end_time: datetime = datetime.now()
        # Elapsed wall-clock time in whole milliseconds.
        process_time = int((end_time.timestamp() - start_time.timestamp()) * 1000)
        await self.after_request(request, response, process_time)
        return response