class CgRagService:
    """Facade for the CG_RAG retrieve -> rerank -> constrained-generation pipeline.

    Combines a scope-aware retriever ``manager``, an optional API reranker and an
    OpenAI-compatible chat-completions endpoint used for constrained generation.
    Retrieval payloads are memoized in a ``RetrieveCache`` keyed by
    (normalized scope, query, topk).
    """

    # Innermost ``{...}`` spans (no nested braces) in free-form model output.
    # Used as a fallback when the strict index parser rejects the raw content.
    _JSON_OBJECT_PATTERN = re.compile(r"\{[^{}]*\}")

    def __init__(self, *, config: CgRagConfig, manager: Any, api_reranker: Any, args: Any) -> None:
        """Store collaborators and create the retrieval cache and generation HTTP session.

        Args:
            config: Service configuration (cache sizing, generation endpoint/model,
                timeouts, context limits, ...).
            manager: Object exposing ``health()`` and ``get_retriever(scope)``;
                ``normalize_scope(scope)`` is used when present.
            api_reranker: Optional reranker with an ``enabled`` flag and
                ``rerank(query, docs, topk=...)``; may be ``None``.
            args: Namespace of runtime options read via ``getattr`` (e.g.
                ``store_topk``, ``rerank_api_model``, ``embedding_api_*``).
        """
        self.config = config
        self.manager = manager
        self.api_reranker = api_reranker
        self.args = args
        self.retrieve_cache = RetrieveCache(CacheConfig(config.cache_max_entries, config.cache_ttl_seconds))
        # Reused across generation calls so connections can be pooled; release with close().
        self.generation_session = requests.Session()

    def close(self) -> None:
        """Release the pooled HTTP connections held by the generation session."""
        self.generation_session.close()

    def health(self) -> dict[str, Any]:
        """Return a status snapshot of scopes, profiles, config issues, generation and rerank settings.

        Profile data from ``manager.health()`` is passed through the
        ``_sanitize_*_for_health`` helpers before being exposed.
        """
        manager_health = self.manager.health()
        return _with_request_id({
            "ok": True,
            "service": "CG_RAG",
            "http_prefix": self.config.http_prefix,
            "mcp_path": self.config.mcp_path,
            "available_scopes": manager_health.get("available_scopes", []),
            "default_scope": manager_health.get("default_scope", ""),
            "active_scope": manager_health.get("active_scope"),
            "active_profile": _sanitize_profile_for_health(manager_health.get("active_profile")),
            "profiles": _sanitize_profiles_for_health(manager_health.get("profiles")),
            "config_issues": [asdict(issue) for issue in validate_cg_rag_config(self.config)],
            "generation": {
                "configured": self.config.generation_configured,
                "endpoint_configured": bool(self.config.generation_endpoint),
                # Only the host is reported, not the full endpoint URL.
                "endpoint_host": endpoint_host(self.config.generation_endpoint),
                "model": self.config.generation_model,
                "response_format_json": self.config.generation_response_format_json,
                "enable_thinking": self.config.generation_enable_thinking,
                "timeout_seconds": self.config.timeout_seconds,
                "max_context_docs": self.config.max_context_docs,
                "default_max_items": self.config.default_max_items,
            },
            "rerank": {
                "api_enabled": self._api_rerank_enabled(),
                "model": str(getattr(self.args, "rerank_api_model", "") or ""),
            },
        })

    def profiles(self) -> dict[str, Any]:
        """Return a shallow copy of the manager's ``profiles`` mapping (empty dict if absent)."""
        return dict(self.manager.health().get("profiles") or {})

    def retrieve_rerank(self, request: CgRetrieveRerankRequest) -> dict[str, Any]:
        """Retrieve, deduplicate, optionally API-rerank and hydrate documents for ``request.query``.

        Results are cached per (normalized scope, query, topk). A cache hit returns a
        shallow copy of the stored payload with ``cache.hit`` set to ``True``; the
        stored entry itself is never mutated.

        Returns:
            Payload with ``retrieval_docs``, ``num_docs``, ``topk``, ``cache`` and a
            ``pipeline`` section (scope, rerank backend, timings in ms), wrapped by
            ``_with_request_id``.
        """
        topk = int(request.topk or getattr(self.args, "store_topk", 78))
        normalized_scope = self._normalize_manager_scope(request.scope)
        cached = self.retrieve_cache.get(normalized_scope, request.query, topk)
        if cached is not None:
            # Copy rather than mutate: the cache may hand back the very dict it stores.
            return _with_request_id({**cached, "cache": {"hit": True}})

        active = self.manager.get_retriever(normalized_scope)
        retriever = active.retriever
        started_at = time.perf_counter()
        # Reset so a timing left over from a previous call is not reported for this one.
        setattr(retriever, "_last_rerank_time_ms", None)
        docs, scores = self._search_local(retriever, request.query, topk)
        local_search_time_ms = (time.perf_counter() - started_at) * 1000.0
        # Presumably populated by the retriever when its built-in reranker ran; None otherwise.
        rerank_time_ms = getattr(retriever, "_last_rerank_time_ms", None)
        docs, scores = dedup_docs_with_scores(docs, scores)

        api_rerank_enabled = self._api_rerank_enabled()
        if api_rerank_enabled:
            rerank_started_at = time.perf_counter()
            docs, scores = self.api_reranker.rerank(request.query, docs, topk=topk)
            rerank_time_ms = (time.perf_counter() - rerank_started_at) * 1000.0
        docs = hydrate_retrieval_docs_from_corpus(trim_retrieval_docs(docs, scores, store_topk=topk))
        search_total_time_ms = (time.perf_counter() - started_at) * 1000.0

        # With the API reranker, retrieval time is the local search span only. Otherwise
        # subtract any built-in rerank time from the total (which also covers dedup/trim/hydrate).
        retrieval_time_ms = local_search_time_ms
        if not api_rerank_enabled and rerank_time_ms is not None:
            retrieval_time_ms = max(search_total_time_ms - float(rerank_time_ms), 0.0)

        payload = {
            "query": request.query,
            "retrieval_docs": docs,
            "num_docs": len(docs),
            "topk": topk,
            "cache": {"hit": False},
            "pipeline": {
                "requested_scope": request.scope or "",
                "active_scope": active.scope,
                "scope_switched": active.scope_switched,
                "profile": self._legacy_profile(active.profile),
                "merge_method": str(getattr(retriever, "merge_method", "")),
                "rerank_enabled": api_rerank_enabled or getattr(retriever, "reranker", None) is not None,
                "api_rerank_enabled": api_rerank_enabled,
                "rerank_backend": self._rerank_backend(retriever),
                "rerank_model_path": self._rerank_model_path(retriever),
                "embedding_api": self._embedding_api_status(retriever),
                "api_embedding_encoder_applied": bool(getattr(retriever, "_embedding_api_encoder_applied", False)),
                "retrieval_time_ms": retrieval_time_ms,
                "rerank_time_ms": rerank_time_ms,
                "search_total_time_ms": search_total_time_ms,
            },
        }
        # Store a copy so request-scoped fields added to the returned dict never leak into the cache.
        self.retrieve_cache.set(normalized_scope, request.query, topk, dict(payload))
        return _with_request_id(payload)

    def _search_local(self, retriever: Any, query: str, topk: int) -> tuple[list[Any], list[Any]]:
        """Run one local retrieval for ``query`` and return ``(docs, scores)`` as lists.

        Prefers ``batch_search_single_call`` (with a batch of one) when the retriever
        provides it; otherwise falls back to ``search``.
        """
        # The same text is supplied under every key; the keys presumably select the
        # dense / BM25 / default channels of a hybrid retriever.
        query_payload = {"dense": query, "bm25": query, "default": query}
        if hasattr(retriever, "batch_search_single_call"):
            docs_batch, scores_batch = retriever.batch_search_single_call([query_payload], num=topk, return_score=True)
            return list(docs_batch[0]), list(scores_batch[0])
        docs, scores = retriever.search(query_payload, num=topk, return_score=True)
        return list(docs), list(scores)

    def warmup(self, scopes: list[str] | None = None) -> dict[str, Any]:
        """Load the retriever for each requested scope, in order.

        When ``scopes`` is ``None`` the manager's available scopes (or its default
        scope) are used. Because each ``get_retriever`` call may switch the active
        scope, only entries whose active scope equals the final one are marked
        ``resident``.
        """
        requested_scopes = self._warmup_requested_scopes(scopes)
        warmed: list[dict[str, Any]] = []
        for requested_scope in requested_scopes:
            normalized_scope = self._normalize_warmup_scope(requested_scope)
            active = self.manager.get_retriever(normalized_scope)
            active_scope = str(getattr(active, "scope", "") or normalized_scope)
            warmed.append(
                {
                    "requested_scope": str(requested_scope or ""),
                    "normalized_scope": normalized_scope,
                    "active_scope": active_scope,
                    "warmed": True,
                    "resident": False,
                    "scope_switched": bool(getattr(active, "scope_switched", False)),
                }
            )
        final_active_scope = warmed[-1]["active_scope"] if warmed else None
        for item in warmed:
            item["resident"] = item["active_scope"] == final_active_scope
        return _with_request_id(
            {
                "ok": True,
                "requested_scopes": [str(scope or "") for scope in requested_scopes],
                "warmed_scopes": [item["active_scope"] for item in warmed],
                "final_active_scope": final_active_scope,
                "warmed": warmed,
            }
        )

    def constrained_generate(self, request: CgConstrainedGenerateRequest) -> dict[str, Any]:
        """Ask the generation endpoint to pick supporting documents by 1-based index.

        At most ``config.max_context_docs`` retrieval docs are sent. The model's output
        is parsed into indices, and the chosen docs are formatted as citations.

        Returns:
            A success payload (``ok=True`` with ``pred_indices``, ``pred_items``,
            ``answer_text``, ``thinking``, ``usage`` and, if requested, debug fields),
            or a ``_generation_error`` payload. Errors cover: not configured, timeout,
            request failure, invalid response, or no index selected although context
            docs were present.
        """
        if not self.config.generation_configured:
            return self._generation_error(
                "generation_not_configured",
                "CG_RAG 生成接口未配置,请设置 CG_RAG_GENERATION_ENDPOINT 和 CG_RAG_GENERATION_MODEL。",
            )
        context_docs = list(request.retrieval_docs[: self.config.max_context_docs])
        messages = build_rag_messages(
            request.query,
            context_docs,
            max_items=request.max_items,
            context_topk=self.config.max_context_docs,
            answer_mode="selected_indices",
        )
        # Deterministic decoding settings.
        payload: dict[str, Any] = {
            "model": self.config.generation_model,
            "messages": messages,
            "temperature": 0,
            "top_p": 1.0,
            "presence_penalty": 0.0,
            "max_tokens": CONSTRAINED_GENERATION_MAX_TOKENS,
        }
        if self.config.generation_response_format_json:
            payload["response_format"] = {"type": "json_object"}
        headers = {"Content-Type": "application/json"}
        if self.config.api_key:
            headers["Authorization"] = f"Bearer {self.config.api_key}"
        try:
            response = self.generation_session.post(
                self.config.generation_endpoint,
                headers=headers,
                json=payload,
                timeout=self.config.timeout_seconds,
            )
            response.raise_for_status()
        except requests.Timeout:
            return self._generation_error(
                "generation_timeout",
                "CG_RAG 生成请求超时,请稍后重试。",
            )
        except requests.RequestException:
            return self._generation_error(
                "generation_request_failed",
                "CG_RAG 生成请求失败,请检查生成接口网络、鉴权或服务状态。",
            )
        try:
            body = response.json()
            message = self._extract_response_message(body)
            raw_content = self._extract_message_content(message)
            if not raw_content:
                raise ValueError("empty generation content")
            pred_indices = self._parse_generation_indices(raw_content, request.max_items, len(context_docs))
        except (ValueError, json.JSONDecodeError):
            return self._generation_error(
                "generation_response_invalid",
                "CG_RAG 生成接口返回格式无效,请检查 OpenAI-compatible 响应结构和 JSON 输出。",
            )
        if context_docs and not pred_indices:
            return self._generation_error(
                "no_stable_article",
                "当前在线候选中未稳定选出明确条文。",
            )
        # Indices are 1-based positions into context_docs.
        pred_items = [format_display_doc_citation(context_docs[index - 1]) for index in pred_indices]
        result = {
            "ok": True,
            "generation_mode": "constrained",
            "answer_text": self._answer_text(pred_items),
            "pred_indices": pred_indices,
            "pred_items": pred_items,
            "pred_raw": raw_content,
            "thinking": self._extract_reasoning_content(body),
            "usage": self._extract_usage(body),
        }
        if request.include_debug:
            result.update(
                {
                    "prompt": messages,
                    "retrieval_result": context_docs,
                    "endpoint": self.config.generation_endpoint,
                    "model": self.config.generation_model,
                }
            )
        return _with_request_id(result)

    def rag(self, request: CgRagRequest) -> dict[str, Any]:
        """Run retrieval and then constrained generation for one query.

        Generation is skipped, and a ``no_stable_article`` error is reported instead,
        when retrieval returns no documents.
        """
        retrieval = self.retrieve_rerank(
            CgRetrieveRerankRequest(query=request.query, scope=request.scope, topk=request.topk)
        )
        retrieval_docs = list(retrieval.get("retrieval_docs") or [])
        if retrieval_docs:
            generation = self.constrained_generate(
                CgConstrainedGenerateRequest(
                    query=request.query,
                    retrieval_docs=retrieval_docs,
                    max_items=request.max_items,
                    include_debug=request.include_debug,
                )
            )
        else:
            generation = self._generation_error(
                "no_stable_article",
                "当前在线候选中未稳定选出明确条文。",
            )
        return _with_request_id({"query": request.query, "retrieval": retrieval, "generation": generation})

    def _api_rerank_enabled(self) -> bool:
        """Return True when an API reranker is attached and its ``enabled`` attribute is truthy."""
        return bool(self.api_reranker is not None and getattr(self.api_reranker, "enabled", False))

    def _warmup_requested_scopes(self, scopes: list[str] | None) -> list[str]:
        """Resolve which scopes to warm.

        Priority: explicit ``scopes``; else the manager's available scopes; else its
        default scope; else none.
        """
        if scopes is not None:
            return list(scopes)
        health = self.manager.health()
        available_scopes = health.get("available_scopes") or []
        if available_scopes:
            return [str(scope) for scope in available_scopes]
        default_scope = str(health.get("default_scope") or "")
        return [default_scope] if default_scope else []

    def _normalize_warmup_scope(self, scope: str | None) -> str:
        """Normalize a warmup scope exactly like a request scope."""
        return self._normalize_manager_scope(scope)

    def _normalize_manager_scope(self, scope: str | None) -> str:
        """Normalize ``scope``.

        Uses the manager's ``normalize_scope`` when it exposes one; otherwise falls
        back to the module-level ``normalize_scope`` with the manager's default scope.
        """
        normalizer = getattr(self.manager, "normalize_scope", None)
        if callable(normalizer):
            normalized = normalizer(scope)
            return str(normalized or "")
        return normalize_scope(scope, self._manager_default_scope())

    def _manager_default_scope(self) -> str:
        """Return the manager's normalized default scope, or ``"usual"`` if health() fails."""
        try:
            health = self.manager.health()
        except Exception:
            # Best effort: a failing health probe must not block scope normalization.
            return "usual"
        return normalize_scope(str(health.get("default_scope") or ""), "usual")

    def _generation_error(self, code: str, message: str) -> dict[str, Any]:
        """Build a failure payload carrying an error code/message and empty predictions."""
        return _with_request_id({
            "ok": False,
            "error": {"code": code, "message": message, "details": {}},
            "pred_indices": [],
            "pred_items": [],
            "pred_raw": "",
        })

    def _parse_generation_indices(self, raw_content: str, max_items: int, max_index: int) -> list[int]:
        """Extract selected 1-based document indices from the model output.

        Tries the strict ``parse_generation_indices`` first. If that fails, scans
        brace-delimited JSON objects from last to first and uses the first one whose
        ``indices`` value is a list.

        Raises:
            ValueError: if no usable ``indices`` object is found.
        """
        try:
            return parse_generation_indices(raw_content, max_items, max_index)
        except (ValueError, json.JSONDecodeError):
            pass
        candidates = list(self._JSON_OBJECT_PATTERN.finditer(str(raw_content or "")))
        # Later objects win: models often restate their final answer at the end.
        for match in reversed(candidates):
            try:
                payload = json.loads(match.group(0))
            except json.JSONDecodeError:
                continue
            if not isinstance(payload, dict) or not isinstance(payload.get("indices"), list):
                continue
            return self._filter_generation_indices(payload["indices"], max_items, max_index)
        raise ValueError("模型输出中未找到有效 indices JSON 对象")

    def _filter_generation_indices(self, indices: list[Any], max_items: int, max_index: int) -> list[int]:
        """Normalize raw model indices into unique, in-range 1-based positions.

        Accepts ints, decimal-digit strings and ``{"index": <int|str>}`` objects.
        Drops anything else, values outside ``1..max_index`` and duplicates, and
        keeps at most ``max_items`` values in first-seen order.
        """
        if max_items <= 0:
            # Previously one index slipped through because the limit was checked after appending.
            return []
        selected: list[int] = []
        seen: set[int] = set()
        for item in indices:
            value = self._coerce_index(item.get("index") if isinstance(item, dict) else item)
            if value is None or value < 1 or value > max_index or value in seen:
                continue
            seen.add(value)
            selected.append(value)
            if len(selected) >= max_items:
                break
        return selected

    @staticmethod
    def _coerce_index(raw: Any) -> int | None:
        """Return ``raw`` as an int if it is a non-bool int or a decimal-digit string, else ``None``."""
        # bool is a subclass of int; a JSON ``true`` must not be read as index 1.
        if isinstance(raw, bool):
            return None
        if isinstance(raw, int):
            return raw
        if isinstance(raw, str):
            text = raw.strip()
            # isdecimal (not isdigit): int() raises on digit-like characters such as "²".
            if text.isdecimal():
                return int(text)
        return None

    def _extract_response_message(self, body: Any) -> dict[str, Any]:
        """Return ``body["choices"][0]["message"]`` from an OpenAI-style response.

        Raises:
            ValueError: if any level of that path is missing or has the wrong type.
        """
        if not isinstance(body, dict):
            raise ValueError("generation response body is not an object")
        choices = body.get("choices")
        if not isinstance(choices, list) or not choices:
            raise ValueError("generation response choices is invalid")
        choice = choices[0]
        if not isinstance(choice, dict):
            raise ValueError("generation response choice is invalid")
        message = choice.get("message")
        if not isinstance(message, dict):
            raise ValueError("generation response message is invalid")
        return message

    def _extract_message_content(self, message: dict[str, Any]) -> str:
        """Return the stripped ``content``, falling back to ``reasoning_content`` when content is empty."""
        content = str(message.get("content") or "").strip()
        return content or str(message.get("reasoning_content") or "").strip()

    def _extract_reasoning_content(self, body: dict[str, Any]) -> str:
        """Return the stripped ``reasoning_content`` of the first choice's message ("" if absent)."""
        message = self._extract_response_message(body)
        return str(message.get("reasoning_content") or "").strip()

    def _extract_usage(self, body: dict[str, Any]) -> dict[str, int]:
        """Return prompt/completion/total token counts, defaulting missing or invalid values to 0."""
        usage = body.get("usage") or {}
        if not isinstance(usage, dict):
            usage = {}
        return {
            "prompt_tokens": self._coerce_token_count(usage.get("prompt_tokens")),
            "completion_tokens": self._coerce_token_count(usage.get("completion_tokens")),
            "total_tokens": self._coerce_token_count(usage.get("total_tokens")),
        }

    def _coerce_token_count(self, value: Any) -> int:
        """Convert ``value`` to a non-negative int, or return 0 if it is not int-convertible."""
        try:
            return max(int(value), 0)
        except (TypeError, ValueError):
            return 0

    def _answer_text(self, pred_items: list[str]) -> str:
        """Render the user-facing answer: a bulleted list of citations, or a fallback message."""
        if not pred_items:
            return "当前候选中未稳定选出明确条文。"
        return "我根据检索与候选筛选结果,整理出以下最相关条文:\n" + "\n".join(
            f"- {item}" for item in pred_items
        )

    def _legacy_profile(self, profile: Any) -> dict[str, Any]:
        """Project selected profile attributes into a flat dict, using "" for missing ones."""
        return {
            "label": getattr(profile, "label", ""),
            "corpus_root": getattr(profile, "corpus_root", ""),
            "index_root": getattr(profile, "index_root", ""),
            "bm25_index_root": getattr(profile, "bm25_index_root", ""),
            "retrieval_method": getattr(profile, "retrieval_method", ""),
            "index_file_name": getattr(profile, "index_file_name", ""),
        }

    def _rerank_backend(self, retriever: Any | None) -> str | None:
        """Name the active rerank backend: "api", "vllm", "transformers", or None if no reranker."""
        if self._api_rerank_enabled():
            return "api"
        if retriever is None:
            return None
        reranker = getattr(retriever, "reranker", None)
        if reranker is None:
            return None
        # Inferred from the reranker's class name.
        class_name = reranker.__class__.__name__.lower()
        return "vllm" if "vllm" in class_name else "transformers"

    def _rerank_model_path(self, retriever: Any | None) -> str:
        """Return the rerank model identifier (API model name or local model path), or ""."""
        if self._api_rerank_enabled():
            return str(getattr(self.args, "rerank_api_model", "") or "")
        if retriever is None:
            return ""
        reranker = getattr(retriever, "reranker", None)
        if reranker is None:
            return ""
        return str(getattr(reranker, "reranker_model_path", "") or getattr(self.args, "rerank_model_path", "") or "")

    def _embedding_api_status(self, retriever: Any | None) -> dict[str, Any]:
        """Describe the embedding-API configuration alongside the dense index dimension."""
        # A dimension of 0 or unset is reported as None ("unknown").
        api_dimension = int(getattr(self.args, "embedding_api_dimension", 0) or 0) or None
        return describe_embedding_api_status(
            endpoint=getattr(self.args, "embedding_api_endpoint", ""),
            model=getattr(self.args, "embedding_api_model", ""),
            api_dimension=api_dimension,
            index_dimension=self._dense_index_dimension(retriever),
        )

    def _dense_index_dimension(self, retriever: Any | None) -> int | None:
        """Return the first truthy ``index.d`` among the retriever's sub-retrievers (or itself), else None."""
        if retriever is None:
            return None
        # Composite retrievers expose ``retriever_list``; a plain retriever is checked directly.
        retriever_list = list(getattr(retriever, "retriever_list", []) or [retriever])
        for item in retriever_list:
            dimension = getattr(getattr(item, "index", None), "d", None)
            if dimension:
                return int(dimension)
        return None