from __future__ import annotations from datetime import datetime from typing import Any, Dict, Iterable, Optional, Type from app.models.company import ( BaseCompanyModel, BossCompany, CompanyCleaningQueue, QcwyCompany, ZhilianCompany, ) COMPANY_SOURCES = {"boss", "qcwy", "zhilian"} QUEUE_TERMINAL_STATUSES = {"done", "failed"} def normalize_company_id(source: str, company_id: str) -> str: value = str(company_id or "").strip() if source == "qcwy" and value.lower().startswith("co") and value[2:].isdigit(): return value[2:] return value def _pick_first(data: dict[str, Any], *keys: str) -> Optional[Any]: for key in keys: value = data.get(key) if value not in (None, ""): return value return None def _nested_get(data: dict[str, Any], *path: str) -> Any: current: Any = data for key in path: if not isinstance(current, dict): return None current = current.get(key) return current def _clean_text(value: Any) -> Optional[str]: if value is None: return None text = str(value).strip() return text or None def _model_for_source(source: str) -> Type[BaseCompanyModel]: mapping: dict[str, Type[BaseCompanyModel]] = { "boss": BossCompany, "qcwy": QcwyCompany, "zhilian": ZhilianCompany, } if source not in mapping: raise ValueError(f"unsupported source: {source}") return mapping[source] def _extract_boss_fields(raw: dict[str, Any], company_id: str) -> dict[str, Any]: payload = raw.get("zpData") if isinstance(raw.get("zpData"), dict) else raw brand = payload.get("brandComInfoVO") or {} company_full = payload.get("companyFullInfoVO") or {} return { "source_company_id": normalize_company_id("boss", company_id or _pick_first(brand, "encryptBrandId", "brandId")), "company_name": _clean_text( _pick_first(payload, "name") or _pick_first(company_full, "name", "brandName") or _pick_first(brand, "brandName") ) or "", "company_type": _clean_text(_pick_first(company_full, "typeName") or _pick_first(brand, "brandIndustry")), "industry": _clean_text(_pick_first(brand, "industryName") or _pick_first(company_full, "industry")), "company_size": _clean_text(_pick_first(brand, "scaleName") or _pick_first(company_full, "scaleName")), "financing_stage": _clean_text(_pick_first(brand, "stageName") or _pick_first(company_full, "stageName")), "city": _clean_text(_pick_first(company_full, "cityName", "city")), "address": _clean_text(_pick_first(company_full, "address", "addressInfo")), "website": _clean_text(_pick_first(company_full, "website")), "logo_url": _clean_text(_pick_first(company_full, "logo", "brandLogo") or _pick_first(brand, "logo", "brandLogo")), "description": _clean_text( _pick_first(company_full, "introduce", "introduction", "companyDesc") or _pick_first(brand, "introduce") ), } def _extract_qcwy_fields(raw: dict[str, Any], company_id: str) -> dict[str, Any]: financing = raw.get("financingStage") or {} coinfo = raw.get("coinfo") if isinstance(raw.get("coinfo"), dict) else {} return { "source_company_id": normalize_company_id( "qcwy", company_id or _pick_first(raw, "companyId", "coId") or _nested_get(raw, "coinfo", "coid"), ), "company_name": _clean_text( _pick_first(raw, "companyName", "fullCompanyName", "companyNameEn") or _pick_first(coinfo, "coname", "brandName") ) or "", "company_type": _clean_text(_pick_first(raw, "companyTypeString", "orgTypeName") or _pick_first(coinfo, "cotype")), "industry": _clean_text( _pick_first(raw, "industryName", "companyIndustryType1Str") or _pick_first(coinfo, "indtype1", "indtype2", "coIndustryText") ), "company_size": _clean_text( _pick_first(raw, "companySizeString", "companySize", "orgSizeName") or _pick_first(coinfo, "cosize") ), "financing_stage": _clean_text(_pick_first(financing, "name") or _pick_first(raw, "financingStageName")), "city": _clean_text(_pick_first(raw, "cityName", "jobAreaString", "workCity") or _pick_first(coinfo, "areaString")), "address": _clean_text( _pick_first(raw, "address", "location") or _nested_get(raw, "workLocation", "workAddress") or _pick_first(coinfo, "caddr") ), "website": _clean_text(_pick_first(raw, "companyUrl", "companyHref") or _pick_first(coinfo, "webUrl")), "logo_url": _clean_text(_pick_first(raw, "companyLogo") or _pick_first(coinfo, "logourl")), "description": _clean_text( _pick_first(raw, "companyDesc", "company_desc", "description") or _nested_get(raw, "campusRootOrgInfo", "description") or _pick_first(coinfo, "coinfo") ), } def _extract_zhilian_fields(raw: dict[str, Any], company_id: str) -> dict[str, Any]: data = raw.get("data") if isinstance(raw.get("data"), dict) else raw company_base = data.get("companyBase") or {} detailed_company = data.get("detailedCompany") or {} return { "source_company_id": normalize_company_id( "zhilian", company_id or _pick_first(company_base, "companyNumber", "number") or _pick_first(detailed_company, "companyNumber", "number"), ), "company_name": _clean_text(_pick_first(company_base, "companyName") or _pick_first(data, "companyName")) or "", "company_type": _clean_text( _pick_first(company_base, "companyTypeName", "companyType") or _pick_first(detailed_company, "companyTypeName") ), "industry": _clean_text(_pick_first(company_base, "industryName") or _pick_first(detailed_company, "industryName")), "company_size": _clean_text( _pick_first(company_base, "companySize", "companySizeString") or _pick_first(detailed_company, "companySize") ), "financing_stage": _clean_text( _pick_first(company_base, "financingStageName") or _nested_get(company_base, "financingStage", "name") or _nested_get(detailed_company, "financingStage", "name") ), "city": _clean_text(_pick_first(company_base, "cityName") or _pick_first(detailed_company, "cityName")), "address": _clean_text(_pick_first(company_base, "address") or _pick_first(detailed_company, "address")), "website": _clean_text(_pick_first(company_base, "companyUrl", "website")), "logo_url": _clean_text(_pick_first(company_base, "logoUrl", "companyLogo")), "description": _clean_text( _pick_first(company_base, "companyDescWithHtml", "companyDesc") or _pick_first(detailed_company, "companyDescription", "companyDesc") ), } def extract_company_fields(source: str, raw: dict[str, Any], company_id: str) -> dict[str, Any]: if source == "boss": return _extract_boss_fields(raw, company_id) if source == "qcwy": return _extract_qcwy_fields(raw, company_id) if source == "zhilian": return _extract_zhilian_fields(raw, company_id) raise ValueError(f"unsupported source: {source}") class CompanyStorageService: @staticmethod def company_model(source: str) -> Type[BaseCompanyModel]: return _model_for_source(source) async def get_existing_company_ids(self, source: str, company_ids: Iterable[str]) -> set[str]: normalized_ids = [normalize_company_id(source, item) for item in company_ids if item] if not normalized_ids: return set() model = self.company_model(source) rows = await model.filter(source_company_id__in=normalized_ids).values_list("source_company_id", flat=True) return set(rows) async def get_all_company_ids(self, source: str) -> set[str]: """获取该平台所有已入库的公司 ID(用于 ClickHouse 查询排除)""" model = self.company_model(source) rows = await model.all().values_list("source_company_id", flat=True) return set(rows) async def get_existing_queue_ids(self, source: str, company_ids: Iterable[str]) -> set[str]: normalized_ids = [normalize_company_id(source, item) for item in company_ids if item] if not normalized_ids: return set() rows = await CompanyCleaningQueue.filter(source=source, company_id__in=normalized_ids).values_list("company_id", flat=True) return set(rows) async def enqueue_company(self, source: str, company_id: str, company_name: str = "") -> tuple[CompanyCleaningQueue, bool]: normalized_id = normalize_company_id(source, company_id) defaults = { "company_name": company_name or "", "status": "pending", "error_msg": "", "retry_count": 0, "started_at": None, "finished_at": None, "jobs_fetched": 0, "jobs_stored": 0, "jobs_duplicate": 0, "jobs_failed": 0, "jobs_error_msg": "", } queue, created = await CompanyCleaningQueue.get_or_create( source=source, company_id=normalized_id, defaults=defaults, ) if not created and company_name and queue.company_name != company_name: queue.company_name = company_name await queue.save(update_fields=["company_name", "updated_at"]) return queue, created async def enqueue_companies(self, source: str, companies: Iterable[dict[str, str]]) -> int: created_count = 0 for item in companies: _, created = await self.enqueue_company( source=source, company_id=item.get("company_id", ""), company_name=item.get("company_name", "") or "", ) if created: created_count += 1 return created_count async def get_company_record(self, source: str, company_id: str) -> Optional[BaseCompanyModel]: normalized_id = normalize_company_id(source, company_id) model = self.company_model(source) return await model.get_or_none(source_company_id=normalized_id) async def upsert_company( self, source: str, raw_data: dict[str, Any], *, company_id: Optional[str] = None, ) -> dict[str, Any]: normalized_id = normalize_company_id(source, company_id or "") fields = extract_company_fields(source, raw_data, normalized_id) normalized_id = fields["source_company_id"] if not normalized_id: raise ValueError(f"missing normalized company id for source={source}") if not fields["company_name"]: raise ValueError(f"missing company name for source={source} company_id={normalized_id}") model = self.company_model(source) record = await model.get_or_none(source_company_id=normalized_id) now = datetime.now() payload = { **fields, "raw_json": raw_data, "last_crawled_at": now, } if record: for key, value in payload.items(): setattr(record, key, value) await record.save() created = False else: record = await model.create( **payload, first_crawled_at=now, ) created = True return { "success": True, "created": created, "company_id": normalized_id, "company_name": record.company_name, "data_summary": { "source": source, "company_id": normalized_id, "company_name": record.company_name, "created": created, }, "record": record, } async def mark_queue_processing(self, queue: CompanyCleaningQueue) -> None: queue.status = "processing" queue.error_msg = "" queue.started_at = datetime.now() queue.finished_at = None queue.jobs_fetched = 0 queue.jobs_stored = 0 queue.jobs_duplicate = 0 queue.jobs_failed = 0 queue.jobs_error_msg = "" await queue.save( update_fields=[ "status", "error_msg", "started_at", "finished_at", "jobs_fetched", "jobs_stored", "jobs_duplicate", "jobs_failed", "jobs_error_msg", "updated_at", ] ) async def mark_queue_result( self, queue: CompanyCleaningQueue, *, status: str, error_msg: str = "", increment_retry: bool = False, jobs_summary: Optional[dict[str, Any]] = None, ) -> None: queue.status = status queue.error_msg = error_msg or "" queue.finished_at = datetime.now() if jobs_summary: queue.jobs_fetched = int(jobs_summary.get("jobs_fetched") or 0) queue.jobs_stored = int(jobs_summary.get("stored_success") or 0) queue.jobs_duplicate = int(jobs_summary.get("duplicate") or 0) queue.jobs_failed = int(jobs_summary.get("failed") or 0) queue.jobs_error_msg = jobs_summary.get("error") or "" if increment_retry: queue.retry_count += 1 await queue.save( update_fields=[ "company_name", "status", "error_msg", "retry_count", "finished_at", "jobs_fetched", "jobs_stored", "jobs_duplicate", "jobs_failed", "jobs_error_msg", "updated_at", ] ) company_storage = CompanyStorageService()