import shutil
from collections.abc import Mapping
from pathlib import Path

from aerich import Command
from fastapi import FastAPI
from fastapi.middleware import Middleware
from fastapi.middleware.cors import CORSMiddleware
from tortoise.expressions import Q

from app.api import api_router
from app.controllers.api import api_controller
from app.controllers.user import UserCreate, user_controller
from app.core.exceptions import (
    DoesNotExist,
    DoesNotExistHandle,
    HTTPException,
    HttpExcHandle,
    IntegrityError,
    IntegrityHandle,
    RequestValidationError,
    RequestValidationHandle,
    ResponseValidationError,
    ResponseValidationHandle,
)
from app.log import logger
from app.models.admin import Api, Menu, Role
from app.models.invoice import Invoice, PaymentReceipt
from app.schemas.menus import MenuType
from app.settings.config import settings

from .middlewares import BackGroundTaskMiddleware, HttpAuditLogMiddleware

# Columns that must be covered together by one unique index on
# `valuation_calculation_steps` (see `_ensure_unique_index`).
_UNIQUE_STEP_COLUMNS = ("valuation_id", "formula_code")


def make_middlewares():
    """Build the application's middleware stack.

    Returns:
        A list of `Middleware` entries: CORS (configured from settings), the
        background-task middleware, and the HTTP audit-log middleware.
    """
    return [
        Middleware(
            CORSMiddleware,
            allow_origins=settings.CORS_ORIGINS,
            allow_credentials=settings.CORS_ALLOW_CREDENTIALS,
            allow_methods=settings.CORS_ALLOW_METHODS,
            allow_headers=settings.CORS_ALLOW_HEADERS,
        ),
        Middleware(BackGroundTaskMiddleware),
        Middleware(
            HttpAuditLogMiddleware,
            methods=["GET", "POST", "PUT", "DELETE"],
            exclude_paths=[
                "/api/v1/base/access_token",
                "/docs",
                "/openapi.json",
                "/static",  # exclude static-file paths from audit logging
            ],
        ),
    ]


def register_exceptions(app: FastAPI):
    """Register the project's exception handlers on *app*."""
    app.add_exception_handler(DoesNotExist, DoesNotExistHandle)
    app.add_exception_handler(HTTPException, HttpExcHandle)
    app.add_exception_handler(IntegrityError, IntegrityHandle)
    app.add_exception_handler(RequestValidationError, RequestValidationHandle)
    app.add_exception_handler(ResponseValidationError, ResponseValidationHandle)


def register_routers(app: FastAPI, prefix: str = "/api"):
    """Mount the top-level API router on *app* under *prefix*."""
    app.include_router(api_router, prefix=prefix)


async def init_superuser():
    """Create the default `admin` superuser when no user exists yet."""
    if await user_controller.model.exists():
        return
    # NOTE(review): well-known default credentials; make sure they are changed
    # (or the account disabled) outside of development.
    await user_controller.create_user(
        UserCreate(
            username="admin",
            email="admin@admin.com",
            password="123456",
            is_active=True,
            is_superuser=True,
        )
    )


async def _create_catalog(*, name, path, order, icon, redirect, children):
    """Create a top-level CATALOG menu and bulk-create its MENU children.

    Args:
        name, path, order, icon, redirect: fields of the catalog entry, which is
            created with ``parent_id=0`` and ``component="Layout"``.
        children: ``(name, path, icon)`` tuples. Each child's ``order`` is its
            1-based position in this sequence and its ``component`` is
            ``f"{path}/{child_path}"``.
    """
    parent = await Menu.create(
        menu_type=MenuType.CATALOG,
        name=name,
        path=path,
        order=order,
        parent_id=0,
        icon=icon,
        is_hidden=False,
        component="Layout",
        keepalive=False,
        redirect=redirect,
    )
    await Menu.bulk_create(
        [
            Menu(
                menu_type=MenuType.MENU,
                name=child_name,
                path=child_path,
                order=position,
                parent_id=parent.id,
                icon=child_icon,
                is_hidden=False,
                component=f"{path}/{child_path}",
                keepalive=False,
            )
            for position, (child_name, child_path, child_icon) in enumerate(children, start=1)
        ]
    )


async def init_menus():
    """Seed the default menu tree when the menu table is empty.

    Creation order (and therefore the generated ids) is: system catalog and its
    children, system-data catalog and its children, the standalone top-level
    menu, then the transaction catalog and its children.
    """
    if await Menu.exists():
        return

    # System management
    await _create_catalog(
        name="系统管理",
        path="/system",
        order=1,
        icon="carbon:gui-management",
        redirect="/system/user",
        children=[
            ("用户管理", "user", "material-symbols:person-outline-rounded"),
            ("角色管理", "role", "carbon:user-role"),
            ("菜单管理", "menu", "material-symbols:list-alt-outline"),
            ("API管理", "api", "ant-design:api-outlined"),
            ("部门管理", "dept", "mingcute:department-line"),
            ("审计日志", "auditlog", "ph:clipboard-text-bold"),
        ],
    )

    # System data management
    await _create_catalog(
        name="系统数据",
        path="/data",
        order=2,
        icon="carbon:data-base",
        redirect="/data/industry",
        children=[
            ("行业修正", "industry", "carbon:industry"),
            ("政策匹配", "policy", "carbon:policy"),
            ("ESG关联", "esg", "carbon:earth-southeast-asia"),
            ("行业基准", "index", "carbon:chart-line"),
        ],
    )

    # NOTE(review): shares order=2 with the "系统数据" catalog above; confirm
    # the intended sidebar ordering.
    await Menu.create(
        menu_type=MenuType.MENU,
        name="一级菜单",
        path="/top-menu",
        order=2,
        parent_id=0,
        icon="material-symbols:featured-play-list-outline",
        is_hidden=False,
        component="/top-menu",
        keepalive=False,
        redirect="",
    )

    # Transaction management
    await _create_catalog(
        name="交易管理",
        path="/transaction",
        order=3,
        icon="carbon:wallet",
        redirect="/transaction/invoice",
        children=[
            ("发票管理", "invoice", "mdi:file-document-outline"),
            ("交易记录", "receipts", "mdi:receipt-text-outline"),
        ],
    )


async def init_apis():
    """Refresh the API table via `api_controller.refresh_api()`."""
    await api_controller.refresh_api()


def _basic_api_filter() -> Q:
    """Filter for the APIs granted to the ordinary-user role: all GET
    endpoints plus everything tagged "基础模块"."""
    return Q(method__in=["GET"]) | Q(tags="基础模块")


async def _grant_missing_apis(role: Role, apis) -> None:
    """Attach to *role* every API in *apis* it is not yet bound to.

    APIs are matched on ``(method, path)`` rather than on object identity.
    """
    current_keys = {(api.method, api.path) for api in await role.apis.all()}
    missing = [api for api in apis if (api.method, api.path) not in current_keys]
    if missing:
        await role.apis.add(*missing)


async def sync_role_api_bindings():
    """Bring role/API bindings up to date (best-effort).

    The "管理员" role is granted every API; the "普通用户" role is granted the
    basic APIs (see `_basic_api_filter`). Only missing bindings are added;
    nothing is removed. Failures are logged and do not abort start-up.
    """
    try:
        admin_role = await Role.filter(name="管理员").first()
        if admin_role:
            await _grant_missing_apis(admin_role, await Api.all())

        user_role = await Role.filter(name="普通用户").first()
        if user_role:
            await _grant_missing_apis(user_role, await Api.filter(_basic_api_filter()))
    except Exception as e:
        # Deliberately best-effort, but no longer silent.
        logger.warning(f"Failed to sync role API bindings: {e}")


async def _get_default_connection(init_if_missing: bool = False):
    """Return the Tortoise connection configured as the models' default.

    Args:
        init_if_missing: when the connection lookup fails, initialise Tortoise
            from ``settings.TORTOISE_ORM`` and look it up again instead of
            propagating the error.
    """
    from tortoise import Tortoise, connections

    conn_alias = settings.TORTOISE_ORM["apps"]["models"]["default_connection"]
    try:
        return connections.get(conn_alias)
    except Exception:
        if not init_if_missing:
            raise
        await Tortoise.init(config=settings.TORTOISE_ORM)
        return connections.get(conn_alias)


def _row_value(row, key: str, index: int, default=None):
    """Read one column from a raw query row.

    Depending on the backend/driver, raw rows may be mappings keyed by column
    name or positional sequences; both are handled here.
    """
    if isinstance(row, Mapping):
        return row.get(key, default)
    return row[index] if len(row) > index else default


def _has_unique_step_index(rows) -> bool:
    """Return True if one unique index in *rows* covers both
    ``valuation_id`` and ``formula_code``.

    *rows* are ``SHOW INDEX`` result rows. A composite index yields one row per
    column, so rows are grouped by ``Key_name`` before checking coverage.
    Positional fallback indices follow the SHOW INDEX column order
    (Non_unique=1, Key_name=2, Column_name=4).
    """
    index_groups = {}
    for row in rows:
        non_unique = _row_value(row, "Non_unique", 1, 1)
        column_name = _row_value(row, "Column_name", 4, "")
        if non_unique == 0 and column_name in _UNIQUE_STEP_COLUMNS:
            key_name = _row_value(row, "Key_name", 2, "")
            index_groups.setdefault(key_name, set()).add(column_name)

    for key_name, columns in index_groups.items():
        if columns.issuperset(_UNIQUE_STEP_COLUMNS):
            logger.debug(f"Found unique index: {key_name} on (valuation_id, formula_code)")
            return True
    return False


async def _ensure_unique_index():
    """Ensure ``valuation_calculation_steps`` has a unique index on
    ``(valuation_id, formula_code)``.

    Does nothing when the table does not exist. When the index is missing,
    duplicate rows are deleted first (the row with the smallest ``id`` per
    pair is kept; rows with NULL ``formula_code`` are left alone), then
    ``uidx_valuation_formula`` is created. Never raises ``Exception``: all
    failures are logged as warnings.
    """
    try:
        conn = await _get_default_connection()

        result = await conn.execute_query("SHOW TABLES LIKE 'valuation_calculation_steps'")
        if not result or not result[1]:
            logger.info("Table valuation_calculation_steps does not exist, skipping index check")
            return

        index_result = await conn.execute_query(
            "SHOW INDEX FROM `valuation_calculation_steps` "
            "WHERE Non_unique = 0 AND Column_name IN ('valuation_id', 'formula_code')"
        )
        # execute_query returns (row_count, rows); only the rows matter here.
        if index_result and len(index_result) > 1 and _has_unique_step_index(index_result[1]):
            logger.debug("Unique index on (valuation_id, formula_code) already exists")
            return

        logger.warning(
            "Unique index on (valuation_id, formula_code) not found, attempting to create..."
        )
        try:
            # Remove duplicates first, otherwise CREATE UNIQUE INDEX would fail.
            await conn.execute_query(
                """
                DELETE t1 FROM `valuation_calculation_steps` t1
                INNER JOIN `valuation_calculation_steps` t2
                WHERE t1.id > t2.id
                  AND t1.valuation_id = t2.valuation_id
                  AND t1.formula_code = t2.formula_code
                  AND t1.formula_code IS NOT NULL
                """
            )
            logger.info("Cleaned up duplicate records")

            await conn.execute_query(
                """
                CREATE UNIQUE INDEX `uidx_valuation_formula`
                ON `valuation_calculation_steps` (`valuation_id`, `formula_code`)
                """
            )
            logger.info("Created unique index on (valuation_id, formula_code)")
        except Exception as idx_err:
            error_str = str(idx_err).lower()
            if "duplicate key name" in error_str or "already exists" in error_str:
                logger.info("Unique index already exists (different name)")
            else:
                logger.warning(f"Failed to create unique index: {idx_err}")
    except Exception as e:
        logger.warning(f"Failed to ensure unique index: {e}")


def _remove_conflicting_migrations():
    """Delete ``*update*.py`` files under ``migrations/models``.

    Aerich prompts interactively when a migration file it is about to generate
    already exists; removing these files up front keeps start-up
    non-interactive.
    """
    # NOTE(review): this deletes EVERY file matching the pattern, including
    # migrations that may already be committed/applied; confirm this is intended.
    migrations_dir = Path("migrations/models")
    if not migrations_dir.exists():
        return
    for migration_file in migrations_dir.glob("*update*.py"):
        logger.info(f"Removing conflicting migration file: {migration_file.name}")
        migration_file.unlink()


async def _recover_from_upgrade_error(command: Command, error: Exception) -> bool:
    """Try to recover from a failed ``command.upgrade``.

    Returns:
        True if *error* matched a known case and recovery ran; False if the
        caller should re-raise *error*.

    Raises:
        Whatever schema generation or table recreation raises (after logging).
    """
    # Assumes Aerich's `command.init()` has already initialised Tortoise.
    from tortoise import Tortoise

    error_msg = str(error)
    lowered = error_msg.lower()

    # Missing table: let Tortoise create the schema, then retry the upgrade.
    if "doesn't exist" in lowered or ("table" in lowered and "valuation_calculation_steps" in error_msg):
        logger.warning(f"Table not found during upgrade: {error_msg}, generating schemas first...")
        try:
            # safe=True: tables that already exist are skipped.
            await Tortoise.generate_schemas(safe=True)
        except Exception as gen_err:
            logger.error(f"Failed to generate schemas: {gen_err}")
            raise
        logger.info("Tables generated successfully, retrying upgrade...")
        try:
            await command.upgrade(run_in_transaction=True)
        except Exception as upgrade_err:
            # Probably a migration-file problem; keep starting up.
            logger.warning(
                f"Upgrade still failed after generating schemas: {upgrade_err}, continuing anyway..."
            )
        else:
            # Mirror the normal success path in init_db.
            await _ensure_unique_index()
        return True

    # Duplicate column: the migration was already applied; just verify the index.
    if "duplicate column name" in lowered:
        logger.warning(
            f"Duplicate column detected during upgrade: {error_msg}, "
            "skipping migration step and ensuring schema integrity..."
        )
        await _ensure_unique_index()
        return True

    # Duplicate index: drop and regenerate the table.
    if "duplicate key" in lowered:
        logger.warning(f"Duplicate index detected: {error_msg}, dropping and recreating table...")
        # NOTE(review): DROP TABLE discards every row in valuation_calculation_steps.
        try:
            conn = await _get_default_connection(init_if_missing=True)
            await conn.execute_query("DROP TABLE IF EXISTS `valuation_calculation_steps`")
            logger.info("Dropped valuation_calculation_steps table")
            # safe=True so only missing tables (i.e. the one just dropped) are created.
            await Tortoise.generate_schemas(safe=True)
            logger.info("Table regenerated successfully with correct unique index")
        except Exception as recreate_err:
            logger.error(f"Failed to recreate table: {recreate_err}")
            raise
        return True

    return False


async def init_db():
    """Initialise the database schema with Aerich and apply pending migrations.

    Steps: ``init_db``/``init`` Aerich, remove auto-generated ``*update*``
    migration files, run ``migrate`` (failures other than a missing model
    history are logged and ignored), then ``upgrade``. Known upgrade failures
    are handled by `_recover_from_upgrade_error`; unknown ones are re-raised.
    """
    command = Command(tortoise_config=settings.TORTOISE_ORM)
    try:
        await command.init_db(safe=True)
    except FileExistsError:
        pass
    await command.init()

    _remove_conflicting_migrations()

    try:
        await command.migrate()
    except AttributeError:
        logger.warning("unable to retrieve model history from database, model history will be created from scratch")
        shutil.rmtree("migrations")
        await command.init_db(safe=True)
    except Exception as e:
        # A failed migrate should not block applying existing migrations.
        logger.warning(f"Migrate failed: {e}, continuing with upgrade...")

    try:
        await command.upgrade(run_in_transaction=True)
    except Exception as e:
        if not await _recover_from_upgrade_error(command, e):
            raise
    else:
        await _ensure_unique_index()


async def init_roles():
    """Create the default "管理员" and "普通用户" roles when no role exists.

    The admin role gets every API; both roles get every menu; the ordinary
    user role gets the basic APIs (see `_basic_api_filter`).
    """
    if await Role.exists():
        return

    admin_role = await Role.create(name="管理员", desc="管理员角色")
    user_role = await Role.create(name="普通用户", desc="普通用户角色")

    all_apis = await Api.all()
    await admin_role.apis.add(*all_apis)

    all_menus = await Menu.all()
    await admin_role.menus.add(*all_menus)
    await user_role.menus.add(*all_menus)

    basic_apis = await Api.filter(_basic_api_filter())
    await user_role.apis.add(*basic_apis)


async def init_demo_transactions():
    """Seed demo invoices and payment receipts for development.

    Only runs when ``settings.DEBUG`` is true and no `PaymentReceipt` exists
    yet, so production data is never touched. Creates five `Invoice` rows,
    then one receipt per invoice plus a second receipt for every
    even-numbered invoice. Receipt "A" is marked verified only for invoices
    whose status is ``"invoiced"``.
    """
    if not settings.DEBUG:
        return
    if await PaymentReceipt.exists():
        return

    demo_payloads = [
        {
            "ticket_type": "electronic",
            "invoice_type": "normal",
            "phone": "13800000001",
            "email": "demo1@example.com",
            "company_name": "演示科技有限公司",
            "tax_number": "91310000MA1DEMO01",
            "register_address": "上海市浦东新区演示路 100 号",
            "register_phone": "021-88880001",
            "bank_name": "招商银行上海分行",
            "bank_account": "6214830000000001",
            "status": "pending",
            "wechat": "demo_wechat_01",
        },
        {
            "ticket_type": "paper",
            "invoice_type": "special",
            "phone": "13800000002",
            "email": "demo2@example.com",
            "company_name": "示例信息技术股份有限公司",
            "tax_number": "91310000MA1DEMO02",
            "register_address": "北京市海淀区知春路 66 号",
            "register_phone": "010-66660002",
            "bank_name": "中国银行北京分行",
            "bank_account": "6216610000000002",
            "status": "invoiced",
            "wechat": "demo_wechat_02",
        },
        {
            "ticket_type": "electronic",
            "invoice_type": "special",
            "phone": "13800000003",
            "email": "demo3@example.com",
            "company_name": "华夏制造有限公司",
            "tax_number": "91310000MA1DEMO03",
            "register_address": "广州市天河区高新大道 8 号",
            "register_phone": "020-77770003",
            "bank_name": "建设银行广州分行",
            "bank_account": "6227000000000003",
            "status": "rejected",
            "wechat": "demo_wechat_03",
        },
        {
            "ticket_type": "paper",
            "invoice_type": "normal",
            "phone": "13800000004",
            "email": "demo4@example.com",
            "company_name": "泰岳网络科技有限公司",
            "tax_number": "91310000MA1DEMO04",
            "register_address": "杭州市滨江区科技大道 1 号",
            "register_phone": "0571-55550004",
            "bank_name": "农业银行杭州分行",
            "bank_account": "6228480000000004",
            "status": "refunded",
            "wechat": "demo_wechat_04",
        },
        {
            "ticket_type": "electronic",
            "invoice_type": "normal",
            "phone": "13800000005",
            "email": "demo5@example.com",
            "company_name": "星云数据有限公司",
            "tax_number": "91310000MA1DEMO05",
            "register_address": "成都市高新区软件园 9 号楼",
            "register_phone": "028-33330005",
            "bank_name": "工商银行成都分行",
            "bank_account": "6222020000000005",
            "status": "pending",
            "wechat": "demo_wechat_05",
        },
    ]

    demo_invoices = []
    for payload in demo_payloads:
        demo_invoices.append(await Invoice.create(**payload))

    for idx, inv in enumerate(demo_invoices, start=1):
        await PaymentReceipt.create(
            invoice=inv,
            url=f"https://example.com/demo-receipt-{idx}-a.png",
            note="DEMO 凭证 A",
            verified=(inv.status == "invoiced"),
        )
        if idx % 2 == 0:
            await PaymentReceipt.create(
                invoice=inv,
                url=f"https://example.com/demo-receipt-{idx}-b.png",
                note="DEMO 凭证 B",
                verified=False,
            )


async def init_data():
    """Run all start-up initialisation steps in dependency order."""
    await init_db()
    await init_superuser()
    await init_menus()
    await init_apis()
    await init_roles()
    await sync_role_api_bindings()
    await init_demo_transactions()