Commit 4441c13b by ccran

feat: 增加案例;修改审查提示词;

parent bd17ac4b
{ {
"python-envs.defaultEnvManager": "ms-python.python:conda", "python-envs.defaultEnvManager": "ms-python.python:conda",
"python-envs.defaultPackageManager": "ms-python.python:conda" "python-envs.defaultPackageManager": "ms-python.python:conda",
"python.defaultInterpreterPath": "/home/ccran/.conda/envs/lufa/bin/python",
"python.terminal.activateEnvironment": true
} }
\ No newline at end of file
...@@ -3,22 +3,6 @@ from dataclasses import dataclass ...@@ -3,22 +3,6 @@ from dataclasses import dataclass
# 可配置运行参数 # 可配置运行参数
use_docker = False use_docker = False
just_b_class = False
is_extract = True
use_original_text_verification = False
# @dataclass
# class LLMConfig:
# base_url: str = "http://172.21.107.45:9002/v1"
# api_key: str = "none"
# model: str = "Qwen3-32B"
#
#
# base_fastgpt_url = "http://172.21.107.45:3030"
# base_backend_url = "http://172.21.107.45:48080"
# ocr_url = 'http://172.21.107.45:8202/openapi/ocrUploadFile'
@dataclass @dataclass
class LLMConfig: class LLMConfig:
...@@ -27,32 +11,19 @@ class LLMConfig: ...@@ -27,32 +11,19 @@ class LLMConfig:
model: str = 'Qwen2-72B-Instruct' model: str = 'Qwen2-72B-Instruct'
# MAX_SINGLE_CHUNK_SIZE=100000 # MAX_SINGLE_CHUNK_SIZE=100000
# MAX_SINGLE_CHUNK_SIZE=5000 MERGE_RULE_PROMPT = False
MAX_SINGLE_CHUNK_SIZE=2000 MAX_SINGLE_CHUNK_SIZE=5000
META_KEY="META"
DEFAULT_RULESET_ID = "通用" DEFAULT_RULESET_ID = "通用"
ALL_RULESET_IDS = ["通用","借款","担保","测试","财务口","金盘","金盘简化"] ALL_RULESET_IDS = ["通用","借款","担保","财务口","金盘","金盘简化"]
FACT_DIMENSIONS = [ use_lufa = True
"当事人",
"标的",
"金额",
"支付",
"期限",
"交付",
"质量",
"知识产权",
"保密",
"违约责任",
"争议解决"
]
use_lufa = False
if use_lufa: if use_lufa:
outer_backend_url = "http://znkf.lgfzgroup.com:48081" outer_backend_url = "http://znkf.lgfzgroup.com:48081"
base_fastgpt_url = "http://192.168.252.71:18089" base_fastgpt_url = "http://192.168.252.71:18089"
base_backend_url = "http://192.168.252.71:48081" base_backend_url = "http://192.168.252.71:48081"
api_key = "fastgpt-zMavJKKgqA9jRNHLXxzXCVZx1JXxfuNkH1p2qfLhtPfMp41UvdSQvt8" api_key = "fastgpt-zMavJKKgqA9jRNHLXxzXCVZx1JXxfuNkH1p2qfLhtPfMp41UvdSQvt8"
else: else:
outer_backend_url = "http://218.77.58.8:8088" outer_backend_url = "http://218.77.58.8:48080"
base_fastgpt_url = "http://192.168.252.71:18088" base_fastgpt_url = "http://192.168.252.71:18088"
base_backend_url = "http://192.168.252.71:48080" base_backend_url = "http://192.168.252.71:48080"
api_key = "fastgpt-vLu2JHAfqwEq5FUQhvATFDK0yDS6fs804v7KwWBMyU4sRrHzh4UGl89Zpa" api_key = "fastgpt-vLu2JHAfqwEq5FUQhvATFDK0yDS6fs804v7KwWBMyU4sRrHzh4UGl89Zpa"
...@@ -78,40 +49,7 @@ LLM = { ...@@ -78,40 +49,7 @@ LLM = {
} }
doc_support_formats = [".docx", ".doc", ".wps"] doc_support_formats = [".docx", ".doc", ".wps"]
pdf_support_formats = [".txt", ".md", ".pdf"] pdf_support_formats = [".txt", ".md", ".pdf"]
# excel字段
field_mapping = {
"review": {
"id": "序号",
"key_points": "审查内容",
"original_text": "合同原文",
"details": "审查过程",
"result": "审查结果",
"suggest": "审查建议",
"chunk_id": "文件块序号",
"chunk_info": "文件块信息",
},
"category": {"text": "原文", "judge": "判断依据", "chunk_location": "原文位置"},
}
excel_widths = {"review": [5, 20, 80, 80, 20, 80, 5, 60], "category": [80, 80, 30]}
max_review_group = 5
# 销售类别判断
all_rule_sheet = ["内销或出口", "内销", "出口", "反思"]
reflection_sheet = "反思"
# 最大分片数量 # 最大分片数量
min_single_chunk_size = 2000 min_single_chunk_size = 2000
max_single_chunk_size = 20000 max_single_chunk_size = 20000
max_chunk_page = 10 max_chunk_page = 10
# 知识库读取sheet
if just_b_class:
know_sheet_name = "B类"
port = 9008
else:
if is_extract:
know_sheet_name = "提取审查"
port = 9016
else:
know_sheet_name = 0
port = 9006
# know_sheet_name = '发票审查'
reload = False
...@@ -4,7 +4,7 @@ from typing import Dict, List, Optional ...@@ -4,7 +4,7 @@ from typing import Dict, List, Optional
from dataclasses import asdict from dataclasses import asdict
from core.tool import ToolBase, tool, tool_func from core.tool import ToolBase, tool, tool_func
from core.memory import RiskFinding from core.memory import Finding
@tool("memory_write", "分段记忆写入") @tool("memory_write", "分段记忆写入")
...@@ -32,17 +32,19 @@ class MemoryWriteTool(ToolBase): ...@@ -32,17 +32,19 @@ class MemoryWriteTool(ToolBase):
issue = f.get("issue") or f.get("issue_description") or "" issue = f.get("issue") or f.get("issue_description") or ""
level = (f.get("level") or f.get("risk_level") or "M").upper() level = (f.get("level") or f.get("risk_level") or "M").upper()
suggestion = f.get("suggestion") or "" suggestion = f.get("suggestion") or ""
result = f.get("result") or ""
evs = list(f.get("evidence_quotes", []) or []) evs = list(f.get("evidence_quotes", []) or [])
original_text = evs[0] if evs else (f.get("original_text") or "") original_text = evs[0] if evs else (f.get("original_text") or "")
try: try:
finding_obj = RiskFinding( finding_obj = Finding(
rule_title=rule_title, rule_title=rule_title,
segment_id=int(segment_id) if str(segment_id).isdigit() else 0, segment_id=int(segment_id) if str(segment_id).isdigit() else 0,
original_text=original_text, original_text=original_text,
issue_description=issue, issue=issue,
risk_level=level, risk_level=level,
suggestion=suggestion, suggestion=suggestion,
result=result,
) )
store.add_finding(finding_obj) store.add_finding(finding_obj)
added.append(asdict(finding_obj)) added.append(asdict(finding_obj))
......
...@@ -59,7 +59,7 @@ REFLECT_USER_PROMPT = ''' ...@@ -59,7 +59,7 @@ REFLECT_USER_PROMPT = '''
【已有风险 findings】 【已有风险 findings】
{findings_json} {findings_json}
【合同事实记忆 facts】 【合同摘要事实记忆 facts】
{facts_json} {facts_json}
【合同立场】 【合同立场】
......
from __future__ import annotations from __future__ import annotations
from pathlib import Path
from typing import Dict, List, Any from typing import Dict, List, Any
import re
from core.config import ALL_RULESET_IDS, DEFAULT_RULESET_ID
from core.tool import ToolBase, tool, tool_func from core.tool import ToolBase, tool, tool_func
from core.cache import get_cached_memory from utils.excel_util import ExcelUtil
FACT_DIMENSIONS: List[str] = ["当事人", "标的", "金额", "支付", "交付", "质量", "知识产权", "保密", "违约责任", "争议解决"]
@tool("retrieve_reference", "审查参考检索") @tool("retrieve_reference", "审查参考检索")
class RetrieveReferenceTool(ToolBase): class RetrieveReferenceTool(ToolBase):
def __init__(self) -> None:
self.default_ruleset_id = DEFAULT_RULESET_ID
self.column_map = {
"id": "ID",
"title": "审查项",
"rule": "审查规则",
"level": "风险等级",
"triggers": "触发词",
"suggestion_template": "建议模板",
"case": "案例",
"summary":"摘要项"
}
rules_path = Path(__file__).resolve().parent.parent.parent / "data" / "rules.xlsx"
self.rulesets: Dict[str, List[Dict[str, Any]]] = {}
for rs_id in ALL_RULESET_IDS:
rules = ExcelUtil.load_mapped_excel(rules_path, sheet_name=rs_id, column_map=self.column_map)
self.rulesets[rs_id] = rules
@tool_func( @tool_func(
{ {
"type": "object", "type": "object",
"properties": { "properties": {
"question": {"type": "string"}, "ruleset_id": {"type": "string"},
"top_k": {"type": "int"}, "routed_rule_titles": {"type": "array", "items": {"type": "string"}},
}, },
"required": ["question"], "required": [],
} }
) )
def run(self, question: str, top_k: int = 5, conversation_id: str = "") -> Dict: def run(self, ruleset_id: str = "", routed_rule_titles: List[str] | None = None) -> Dict[str, Any]:
memory_refs = self._search_memory(question, conversation_id, top_k) target_ruleset_id = ruleset_id or self.default_ruleset_id
kb_refs = self._search_knowledge_base(question, top_k) full_rules = self.rulesets.get(target_ruleset_id) or self.rulesets.get(self.default_ruleset_id, []) or []
external_refs = self._search_external(question, top_k) if routed_rule_titles is None:
rules = full_rules
else:
title_set = {title for title in routed_rule_titles if isinstance(title, str)}
rules = [r for r in full_rules if r.get("title") in title_set]
return { return {
"memory_refs": memory_refs, "ruleset_id": target_ruleset_id,
"kb_refs": kb_refs, "rules": rules,
"external_refs": external_refs, "rule_titles": [r.get("title", "") for r in rules],
"total": len(rules),
} }
def _search_memory(self, question: str, conversation_id: str, top_k: int) -> List[Dict[str, Any]]: def summary_keywords(self, rules: List[Dict[str, Any]]) -> List[str]:
if not conversation_id: return [r.get("summary", "") for r in rules if r.get("summary")]
return []
store = get_cached_memory(conversation_id)
facts = store.get_facts()
results: List[Dict[str, Any]] = []
def _add(dim: str, payload: Any) -> None:
snippet = payload if isinstance(payload, str) else str(payload)
results.append({
"source": f"memory:{dim}",
"dimension": dim,
"snippet": snippet,
})
for dim in FACT_DIMENSIONS:
val = facts.get(dim)
if val is None:
continue
if dim in question:
_add(dim, val)
return results[:top_k]
def _search_knowledge_base(self, question: str, top_k: int) -> List[Dict[str, Any]]:
# TODO: implement KB retrieval
return []
def _search_external(self, question: str, top_k: int) -> List[Dict[str, Any]]:
# TODO: implement external retrieval (e.g., search engine)
return []
if __name__ == "__main__": if __name__ == "__main__":
tmp_memory = get_cached_memory("tmp")
tmp_memory.add_finding_from_dict({
"issue": "支付方式不明确",
"original_text": "买方应在收到货物后30天内支付全部货款。",
"risk_level": "H",
"rule_title": "支付条款审查",
})
tmp_memory.update_facts({
"支付":{
"支付方式": "银行转账",
"支付期限": "收到货物后30天内",
}
})
tool = RetrieveReferenceTool() tool = RetrieveReferenceTool()
result = tool.run( result = tool.run(ruleset_id="金盘", routed_rule_titles=None)
question="支付方式是什么?", for rule in result.get("rules", []):
top_k=3, print(f"Rule Title: {rule.get('title')}")
conversation_id="tmp", print(f"Case: {rule.get('case')}")
) print("-" * 20)
# print(result.get("total", 0))
print(result) \ No newline at end of file
\ No newline at end of file
...@@ -32,7 +32,8 @@ class LLMTool(ToolBase): ...@@ -32,7 +32,8 @@ class LLMTool(ToolBase):
def run_with_loop(self, coro): def run_with_loop(self, coro):
try: try:
return asyncio.run(coro) return asyncio.run(coro)
except RuntimeError: except RuntimeError as e:
print(f'RuntimeError in run_with_loop: {e}, trying to get event loop and run until complete.')
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
return loop.run_until_complete(coro) return loop.run_until_complete(coro)
......
...@@ -2,13 +2,10 @@ from __future__ import annotations ...@@ -2,13 +2,10 @@ from __future__ import annotations
import json import json
import re import re
from pathlib import Path
from typing import Dict, List, Optional from typing import Dict, List, Optional
from core.config import ALL_RULESET_IDS, DEFAULT_RULESET_ID
from core.tool import tool, tool_func from core.tool import tool, tool_func
from core.tools.segment_llm import LLMTool from core.tools.segment_llm import LLMTool
from utils.excel_util import ExcelUtil
ROUTER_SYSTEM_PROMPT = ''' ROUTER_SYSTEM_PROMPT = '''
...@@ -42,7 +39,7 @@ ROUTER_USER_PROMPT = ''' ...@@ -42,7 +39,7 @@ ROUTER_USER_PROMPT = '''
【合同立场】 【合同立场】
{party_role} {party_role}
【候选审查规则 【候选审查规则
{candidate_rules_json} {candidate_rules_json}
【任务】 【任务】
...@@ -68,20 +65,6 @@ ROUTER_OUTPUT_SCHEMA = ''' ...@@ -68,20 +65,6 @@ ROUTER_OUTPUT_SCHEMA = '''
class SegmentRuleRouterTool(LLMTool): class SegmentRuleRouterTool(LLMTool):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__(ROUTER_SYSTEM_PROMPT) super().__init__(ROUTER_SYSTEM_PROMPT)
self.default_ruleset_id = DEFAULT_RULESET_ID
self.column_map = {
"id": "ID",
"title": "审查项",
"rule": "审查规则",
"level": "风险等级",
"triggers": "触发词",
"suggestion_template": "建议模板",
}
rules_path = Path(__file__).resolve().parent.parent.parent / "data" / "rules.xlsx"
self.rulesets: Dict[str, List[Dict]] = {}
for rs_id in ALL_RULESET_IDS:
rules = ExcelUtil.load_mapped_excel(rules_path, sheet_name=rs_id, column_map=self.column_map)
self.rulesets[rs_id] = rules
@tool_func( @tool_func(
{ {
...@@ -89,22 +72,22 @@ class SegmentRuleRouterTool(LLMTool): ...@@ -89,22 +72,22 @@ class SegmentRuleRouterTool(LLMTool):
"properties": { "properties": {
"segment_id": {"type": "int"}, "segment_id": {"type": "int"},
"segment_text": {"type": "string"}, "segment_text": {"type": "string"},
"ruleset_id": {"type": "string"}, "rules": {"type": "array", "items": {"type": "object"}},
"party_role": {"type": "string"}, "party_role": {"type": "string"},
"context_memories": {"type": "array"}, "context_memories": {"type": "array"},
}, },
"required": ["segment_id", "segment_text", "ruleset_id", "party_role"], "required": ["segment_id", "segment_text", "rules", "party_role"],
} }
) )
def run( def run(
self, self,
segment_id: int, segment_id: int,
segment_text: str, segment_text: str,
ruleset_id: str, rules: List[Dict],
party_role: str, party_role: str,
context_memories: Optional[List[Dict]] = None, context_memories: Optional[List[Dict]] = None,
) -> Dict: ) -> Dict:
rules = self.rulesets.get(ruleset_id) or self.rulesets.get(self.default_ruleset_id, []) or [] rules = rules or []
routed_rules = self._route_rules( routed_rules = self._route_rules(
segment_text=segment_text, segment_text=segment_text,
rules=rules, rules=rules,
...@@ -113,7 +96,6 @@ class SegmentRuleRouterTool(LLMTool): ...@@ -113,7 +96,6 @@ class SegmentRuleRouterTool(LLMTool):
) )
return { return {
"segment_id": segment_id, "segment_id": segment_id,
"ruleset_id": ruleset_id,
"routed_rules": routed_rules, "routed_rules": routed_rules,
"routed_rule_titles": [r.get("title", "") for r in routed_rules], "routed_rule_titles": [r.get("title", "") for r in routed_rules],
} }
...@@ -121,10 +103,7 @@ class SegmentRuleRouterTool(LLMTool): ...@@ -121,10 +103,7 @@ class SegmentRuleRouterTool(LLMTool):
def _build_candidate_rules(self, rules: List[Dict]) -> List[Dict]: def _build_candidate_rules(self, rules: List[Dict]) -> List[Dict]:
return [ return [
{ {
"title": r.get("title", ""), r.get("title", ""): r.get("rule", "")
"level": r.get("level", ""),
"rule": r.get("rule", ""),
"triggers": r.get("triggers", ""),
} }
for r in rules for r in rules
if r.get("title") if r.get("title")
...@@ -223,6 +202,22 @@ class SegmentRuleRouterTool(LLMTool): ...@@ -223,6 +202,22 @@ class SegmentRuleRouterTool(LLMTool):
if __name__ == "__main__": if __name__ == "__main__":
tool = SegmentRuleRouterTool() tool = SegmentRuleRouterTool()
demo_rules = [
{
"id": "R1",
"title": "付款触发条件明确性",
"level": "H",
"rule": "付款应绑定明确触发条件和验收标准。",
"triggers": "支付,付款,验收",
},
{
"id": "R2",
"title": "违约责任对等性",
"level": "M",
"rule": "违约责任应当相对对等且违约金标准明确。",
"triggers": "违约,违约金",
},
]
demo_segment_text = ( demo_segment_text = (
"甲方应在合同签订后5个工作日内向乙方支付合同总价30%作为预付款," "甲方应在合同签订后5个工作日内向乙方支付合同总价30%作为预付款,"
"剩余70%在乙方完成交付并经甲方验收合格后30日内支付。" "剩余70%在乙方完成交付并经甲方验收合格后30日内支付。"
...@@ -232,7 +227,7 @@ if __name__ == "__main__": ...@@ -232,7 +227,7 @@ if __name__ == "__main__":
result = tool.run( result = tool.run(
segment_id=1, segment_id=1,
segment_text=demo_segment_text, segment_text=demo_segment_text,
ruleset_id="通用", rules=demo_rules,
party_role="甲方", party_role="甲方",
context_memories=[], context_memories=[],
) )
......
from __future__ import annotations from __future__ import annotations
import asyncio
import json import json
from typing import Dict, List, Optional from typing import Dict, List, Optional
from core.tool import tool, tool_func from core.tool import tool, tool_func
from core.tools.segment_llm import LLMTool from core.tools.segment_llm import LLMTool
from core.config import FACT_DIMENSIONS from core.config import META_KEY
SUMMARY_SYSTEM_PROMPT = f''' SUMMARY_SYSTEM_PROMPT = f'''
你是合同事实提取智能体(SegmentSummary)。 你是合同事实提取智能体(SegmentSummary)。
你的任务是从当前合同分段中提取“客观事实”,并按指定维度结构化输出。 你的任务是:**基于给定的审查规则,从当前合同分段中提取“与该规则直接相关的客观事实”,并结构化输出。**
【核心原则】
你必须严格围绕“规则所需信息”进行提取。
---
【事实定义】 【事实定义】
...@@ -24,37 +27,51 @@ SUMMARY_SYSTEM_PROMPT = f''' ...@@ -24,37 +27,51 @@ SUMMARY_SYSTEM_PROMPT = f'''
3. 不得补充未出现的主体、条件或数值; 3. 不得补充未出现的主体、条件或数值;
4. 允许对原文做最小结构化拆分(例如金额、比例、期限)。 4. 允许对原文做最小结构化拆分(例如金额、比例、期限)。
---
【规则驱动提取要求(关键)】
- 仅提取“该审查规则执行所需要的信息字段”
- 不得提取与该规则无关的信息(即使这些信息在文本中存在)
- 若规则未涉及某类信息,则不得输出对应字段
- 若规则涉及某字段但文本未出现,需显式标记为 "未明确"
---
【输出结构】 【输出结构】
- 输出字段:facts - 输出字段:facts
- facts 是一个对象 - facts 是一个对象
- 键为以下预设维度: - 键必须来自【规则字段定义(rule_fields)】
- 不得使用预设通用维度(如“支付/违约责任”等)
---
{", ".join(FACT_DIMENSIONS)} 【字段填充规则】
- 每个维度值必须是对象或对象列表 - 每个字段值必须是对象或对象列表
- 未出现的维度可以省略 - 不得输出字符串作为字段值
- 字段内容必须为原文的最小结构化表达
- 不得改写原文含义
【结构规则】 ---
- 仅提取对合同履行或责任具有实际意义的事实 【缺失信息处理(非常重要)】
- 不得输出字符串作为维度值,必须使用对象
- 不得输出解释、总结或风险判断
【上下文事实使用规则】 - 若规则要求的字段在当前分段未出现:
→ 必须输出该字段,并标记为:
上下文事实仅用于: "未明确"
- 避免重复提取已存在的事实
- 保持字段命名一致
不得: (用于后续审查判断)
- 使用上下文事实补充当前分段没有出现的信息
- 修改当前分段原文事实 ---
【约束】 【约束】
- 严禁编造信息 - 严禁编造信息
- 严禁推断未出现的内容 - 严禁推断未出现的内容
- 不得输出风险判断或解释
- 严格输出 JSON - 严格输出 JSON
''' '''
...@@ -62,11 +79,17 @@ SUMMARY_USER_PROMPT = ''' ...@@ -62,11 +79,17 @@ SUMMARY_USER_PROMPT = '''
【分段原文】 【分段原文】
{segment_text} {segment_text}
【上下文事实】 【规则字段定义(仅提取这些字段)】
{context_facts} {rule_fields}
【任务】
请仅提取“当前分段中,与候选审查规则直接相关的客观事实”。
仅提取当前分段中明确出现的客观事实。 【特别要求】
不得从上下文事实中补充新的信息。 - facts 的顶层 key 必须是规则 title
- 每个规则下仅保留与该规则直接相关的信息
- 若某规则在当前分段未出现关键信息,输出该规则并标记为 "未明确"
- 不得提取与规则无关的信息
输出 JSON。 输出 JSON。
''' '''
...@@ -75,8 +98,8 @@ OUTPUT_EXAMPLE = ''' ...@@ -75,8 +98,8 @@ OUTPUT_EXAMPLE = '''
```json ```json
{ {
"facts": { "facts": {
"支付": {"方式": "银行转账", "时间": "验收后30日内"}, "支付审查": {"方式": "银行转账", "时间": "验收后30日内"},
"违约责任": {"违约金比例": "合同总金额的5%"} "违约责任审查": {"违约金比例": "合同总金额的5%"}
} }
} }
``` ```
...@@ -94,28 +117,54 @@ class SegmentSummaryTool(LLMTool): ...@@ -94,28 +117,54 @@ class SegmentSummaryTool(LLMTool):
"properties": { "properties": {
"segment_id": {"type": "int"}, "segment_id": {"type": "int"},
"segment_text": {"type": "string"}, "segment_text": {"type": "string"},
"rules": {"type": "array", "items": {"type": "object"}},
"party_role": {"type": "string"}, "party_role": {"type": "string"},
"context_facts": {"type": "object"}, "context_facts": {"type": "object"},
}, },
"required": ["segment_id", "segment_text"], "required": ["segment_id", "segment_text", "rules"],
} }
) )
def run( def run(
self, self,
segment_id: int, segment_id: int,
segment_text: str, segment_text: str,
rules: List[Dict],
party_role: str = "", party_role: str = "",
context_facts: Optional[Dict] = None, context_facts: Optional[Dict] = None,
) -> Dict: ) -> Dict:
rules = rules or []
try: try:
return self.run_with_loop(self._summarize_async(segment_id, segment_text, party_role, context_facts)) return self.run_with_loop(
self._summarize_async(segment_id, segment_text, rules, party_role, context_facts)
)
except Exception: except Exception:
return {} return {}
def _build_prompt(self, segment_text: str, context_facts: Optional[Dict], party_role: str) -> List[Dict[str, str]]: def _stringify_rule(self, rules: List[Dict]) -> str:
lines = []
for r in rules:
id = r.get("id", "")
rule_text = r.get("rule", "")
lines.append(f"规则ID: {id}\n审查规则: {rule_text}\n")
return "\n".join(lines)
def _build_prompt(
self,
segment_text: str,
rules: List[Dict],
context_facts: Optional[Dict],
party_role: str,
) -> List[Dict[str, str]]:
# 获取规则字段定义
rule_fields = [
r.get("summary") for r in rules
if r.get("summary")
]
user_content = SUMMARY_USER_PROMPT.format( user_content = SUMMARY_USER_PROMPT.format(
segment_text=segment_text, segment_text=segment_text,
context_facts=json.dumps(context_facts or {}, ensure_ascii=False), # party_role=party_role,
# rules_json=self._stringify_rule(rules),
rule_fields=json.dumps(rule_fields, ensure_ascii=False),
) + OUTPUT_EXAMPLE ) + OUTPUT_EXAMPLE
return self.build_messages(user_content) return self.build_messages(user_content)
...@@ -123,30 +172,30 @@ class SegmentSummaryTool(LLMTool): ...@@ -123,30 +172,30 @@ class SegmentSummaryTool(LLMTool):
self, self,
segment_id: int, segment_id: int,
segment_text: str, segment_text: str,
rules: List[Dict],
party_role: str, party_role: str,
context_facts: Optional[Dict], context_facts: Optional[Dict],
) -> Dict: ) -> Dict:
msgs = self._build_prompt(segment_text, context_facts, party_role) msgs = self._build_prompt(segment_text, rules, context_facts, party_role)
final_facts: Dict = {}
try: try:
resp = await self.chat_async(msgs) resp = await self.chat_async(msgs)
# print("segment summary response:", resp)
data = self.parse_first_json(resp) data = self.parse_first_json(resp)
facts = data.get("facts") or {} facts = data.get("facts") or {}
except Exception: except Exception as e:
print(f'Error in segment summary for segment {segment_id}: {e}')
facts = {} facts = {}
# print(f'SegmentSummaryTool facts: {facts}') facts[META_KEY] = {
if isinstance(facts,list): "segment_id": segment_id,
final_facts['内容'] = facts }
else: return facts
final_facts = facts
final_facts['segment_id'] = segment_id
return final_facts
if __name__=='__main__': if __name__=='__main__':
tool = SegmentSummaryTool() tool = SegmentSummaryTool()
res = tool.run( res = tool.run(
segment_id=1, segment_id=1,
segment_text="甲方应于合同签订之日起30日内向乙方支付合同总金额的50%,余款在货物验收合格后30日内付清.", segment_text="甲方应于合同签订之日起30日内向乙方支付合同总金额的50%,余款在货物验收合格后30日内付清.",
rules=[{"id": "R1", "title": "付款", "rule": "付款相关", "summary": "付款方式"}],
context_facts={}, context_facts={},
) )
print(res) print(res)
\ No newline at end of file
...@@ -14,7 +14,7 @@ from utils.http_util import upload_file, fastgpt_openai_chat, download_file ...@@ -14,7 +14,7 @@ from utils.http_util import upload_file, fastgpt_openai_chat, download_file
SUFFIX='_麓发改进' SUFFIX='_麓发改进'
batch_input_dir_path = 'jp-input' batch_input_dir_path = 'jp-input'
batch_output_dir_path = 'jp-output-lufa' batch_output_dir_path = 'jp-all'
batch_size = 5 batch_size = 5
# 麓发fastgpt接口 # 麓发fastgpt接口
# url = 'http://192.168.252.71:18089/api/v1/chat/completions' # url = 'http://192.168.252.71:18089/api/v1/chat/completions'
...@@ -31,9 +31,6 @@ token = 'fastgpt-vykT6qs07g7hR4tL2MNJE6DdNCIxaQjEu3Cxw9nuTBFg8MAG3CkByvnXKxSNEyM ...@@ -31,9 +31,6 @@ token = 'fastgpt-vykT6qs07g7hR4tL2MNJE6DdNCIxaQjEu3Cxw9nuTBFg8MAG3CkByvnXKxSNEyM
def extract_url(text): def extract_url(text):
# \s * ([ ^ "\s]+?\.(?:docx?|pdf|xlsx)) # \s * ([ ^ "\s]+?\.(?:docx?|pdf|xlsx))
# 麓发正则
# excel_p, doc_p = r'导出Excel结果\s*([^"]*xlsx)', r'导出Doc结果\s*([^\" ]+?\.(?:docx?|pdf|wps))'
# 金盘正则
excel_p, doc_p = r'最终审查Excel\s*([^"]*xlsx)', r'最终审查批注\s*([^\" ]+?\.(?:docx?|pdf|wps))' excel_p, doc_p = r'最终审查Excel\s*([^"]*xlsx)', r'最终审查批注\s*([^\" ]+?\.(?:docx?|pdf|wps))'
# 使用 re.search() 查找第一个匹配项 # 使用 re.search() 查找第一个匹配项
excel_m, doc_m = re.search(excel_p, text), re.search(doc_p, text) excel_m, doc_m = re.search(excel_p, text), re.search(doc_p, text)
......
...@@ -69,6 +69,8 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None: ...@@ -69,6 +69,8 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None:
best_score = -1 best_score = -1
for idx, cand in enumerate(candidates): for idx, cand in enumerate(candidates):
ans_text = ans_text.strip() ans_text = ans_text.strip()
if cand is None or not isinstance(cand,str):
continue
cand = cand.strip() cand = cand.strip()
score = max( score = max(
fuzz.partial_ratio(ans_text, cand), fuzz.partial_ratio(ans_text, cand),
...@@ -99,6 +101,7 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None: ...@@ -99,6 +101,7 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None:
unmatched_val_count = sum(len(v) for v in unmatched_val_by_item.values()) unmatched_val_count = sum(len(v) for v in unmatched_val_by_item.values())
unmatched_answer_count = sum(len(v) for v in unmatched_answer_by_item.values()) unmatched_answer_count = sum(len(v) for v in unmatched_answer_by_item.values())
file_false_positive_rate = (unmatched_val_count / val_total) if val_total != 0 else 0
# 累加到各“审查项”的全局统计 # 累加到各“审查项”的全局统计
for it, cnt in answer_counts.items(): for it, cnt in answer_counts.items():
...@@ -112,7 +115,7 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None: ...@@ -112,7 +115,7 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None:
print('#' * 40) print('#' * 40)
print( print(
f"{val_file.name}: matched {matched_total} | val {val_total} | answer {answer_total} " f"{val_file.name}: matched {matched_total} | val {val_total} | answer {answer_total} "
f"| unmatched val {unmatched_val_count} | unmatched answer {unmatched_answer_count} | accuracy {matched_total / answer_total:.2%} | invalid_val {(unmatched_val_count / val_total) if val_total != 0 else 0:.2%}" f"| unmatched val {unmatched_val_count} | unmatched answer {unmatched_answer_count} | recall {matched_total / answer_total:.2%} | false_positive_rate {file_false_positive_rate:.2%}"
) )
for item in sorted(answer_counts): for item in sorted(answer_counts):
item_matches = matched_by_item.get(item, []) item_matches = matched_by_item.get(item, [])
...@@ -133,10 +136,10 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None: ...@@ -133,10 +136,10 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None:
for t in uv: for t in uv:
print(f" val: {t}") print(f" val: {t}")
# break # only first file for demo # break # only first file for demo
accuracy = overall_matched / overall_answer if overall_answer else 0 recall = overall_matched / overall_answer if overall_answer else 0
invalid_val = (overall_val - overall_matched) / overall_val if overall_val else 0 overall_false_positive_rate = (overall_val - overall_matched) / overall_val if overall_val else 0
print( print(
f"Overall: matched {overall_matched} | val {overall_val} | answer {overall_answer} | accuracy {accuracy:.2%} | invalid_val {invalid_val:.2%}" f"Overall: matched {overall_matched} | val {overall_val} | answer {overall_answer} | recall {recall:.2%} | false_positive_rate {overall_false_positive_rate:.2%}"
) )
# 按“审查项”的 overall 结果 # 按“审查项”的 overall 结果
...@@ -151,7 +154,7 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None: ...@@ -151,7 +154,7 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None:
u_ans = overall_item_unmatched_answer.get(it, 0) u_ans = overall_item_unmatched_answer.get(it, 0)
u_val = overall_item_unmatched_val.get(it, 0) u_val = overall_item_unmatched_val.get(it, 0)
acc = (mat / ans) if ans else 0 acc = (mat / ans) if ans else 0
invalid_val = u_val / (mat + u_val) if (mat + u_val) else 0 item_false_positive_rate = u_val / (mat + u_val) if (mat + u_val) else 0
rows_by_item.append({ rows_by_item.append({
"审查项": it, "审查项": it,
"大模型匹配上的不合格项": mat, "大模型匹配上的不合格项": mat,
...@@ -159,13 +162,13 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None: ...@@ -159,13 +162,13 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None:
"大模型其他不合格项": u_val, "大模型其他不合格项": u_val,
"大模型未匹配上的不合格项(C-B)": u_ans, "大模型未匹配上的不合格项(C-B)": u_ans,
"查全率(B/C)": acc, "查全率(B/C)": acc,
"无关审查率(D/B+D)": invalid_val, "误报率(D/B+D)": item_false_positive_rate,
}) })
print( print(
f" 审查项 {it}: matched {mat} / answer {ans} | unmatched val {u_val} | unmatched answer {u_ans} | accuracy {acc:.2%} | invalid_val {invalid_val:.2%}" f" 审查项 {it}: matched {mat} / answer {ans} | unmatched val {u_val} | unmatched answer {u_ans} | recall {acc:.2%} | false_positive_rate {item_false_positive_rate:.2%}"
) )
overall_by_item_df = pd.DataFrame(rows_by_item, columns=["审查项", "大模型匹配上的不合格项", "合同所有不合格项", "大模型其他不合格项", "大模型未匹配上的不合格项(C-B)", "查全率(B/C)", "无关审查率(D/B+D)"]) overall_by_item_df = pd.DataFrame(rows_by_item, columns=["审查项", "大模型匹配上的不合格项", "合同所有不合格项", "大模型其他不合格项", "大模型未匹配上的不合格项(C-B)", "查全率(B/C)", "误报率(D/B+D)"])
unmatched_val_total = sum(overall_item_unmatched_val.values()) unmatched_val_total = sum(overall_item_unmatched_val.values())
unmatched_answer_total = sum(overall_item_unmatched_answer.values()) unmatched_answer_total = sum(overall_item_unmatched_answer.values())
overall_invalid_rate = unmatched_val_total / (overall_matched + unmatched_val_total) if (overall_matched + unmatched_val_total) else 0 overall_invalid_rate = unmatched_val_total / (overall_matched + unmatched_val_total) if (overall_matched + unmatched_val_total) else 0
...@@ -176,10 +179,10 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None: ...@@ -176,10 +179,10 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None:
"合同所有不合格项": overall_answer, "合同所有不合格项": overall_answer,
"大模型其他不合格项": unmatched_val_total, "大模型其他不合格项": unmatched_val_total,
"大模型未匹配上的不合格项(C-B)": unmatched_answer_total, "大模型未匹配上的不合格项(C-B)": unmatched_answer_total,
"查全率(B/C)": accuracy, "查全率(B/C)": recall,
"无关审查率(D/B+D)": overall_invalid_rate, "误报率(D/B+D)": overall_invalid_rate,
} }
], columns=["审查项", "大模型匹配上的不合格项", "合同所有不合格项", "大模型其他不合格项", "大模型未匹配上的不合格项(C-B)", "查全率(B/C)", "无关审查率(D/B+D)"]) ], columns=["审查项", "大模型匹配上的不合格项", "合同所有不合格项", "大模型其他不合格项", "大模型未匹配上的不合格项(C-B)", "查全率(B/C)", "误报率(D/B+D)"])
combined_df = pd.concat([overall_by_item_df, overall_total_df], ignore_index=True) combined_df = pd.concat([overall_by_item_df, overall_total_df], ignore_index=True)
compare_dir_name = val_dir.name compare_dir_name = val_dir.name
......
...@@ -121,7 +121,7 @@ def _parse_args() -> argparse.Namespace: ...@@ -121,7 +121,7 @@ def _parse_args() -> argparse.Namespace:
parser.add_argument( parser.add_argument(
"--datasets-dir", "--datasets-dir",
type=Path, type=Path,
default=base / "results" / "jp-output", default=base / "results" / "jp-all-merge-prompt",
help="Directory containing Word files with annotations.", help="Directory containing Word files with annotations.",
) )
parser.add_argument( parser.add_argument(
...@@ -133,13 +133,13 @@ def _parse_args() -> argparse.Namespace: ...@@ -133,13 +133,13 @@ def _parse_args() -> argparse.Namespace:
parser.add_argument( parser.add_argument(
"--val-dir", "--val-dir",
type=Path, type=Path,
default=base / "results" / "jp-output-extracted", default=base / "results" / "jp-all-merge-prompt-extracted",
help="Directory to store extracted xlsx files for comparison.", help="Directory to store extracted xlsx files for comparison.",
) )
parser.add_argument( parser.add_argument(
"--strip-suffixes", "--strip-suffixes",
nargs="*", nargs="*",
default=['_人机交互'], default=['_麓发改进'],
help=( help=(
"Optional filename suffixes to strip from generated val xlsx stems before " "Optional filename suffixes to strip from generated val xlsx stems before "
"comparison, e.g. --strip-suffixes _v1 _审阅版" "comparison, e.g. --strip-suffixes _v1 _审阅版"
......
No preview for this file type
...@@ -10,15 +10,16 @@ import uvicorn ...@@ -10,15 +10,16 @@ import uvicorn
import traceback import traceback
from loguru import logger from loguru import logger
from utils.common_util import extract_url_file, format_now from utils.common_util import extract_url_file, format_now
from utils.http_util import download_file from utils.http_util import download_file
from core.cache import get_cached_doc_tool, get_cached_memory from core.cache import get_cached_doc_tool, get_cached_memory
from core.config import doc_support_formats, pdf_support_formats from core.config import doc_support_formats, pdf_support_formats,MERGE_RULE_PROMPT
from core.tools.segment_summary import SegmentSummaryTool from core.tools.segment_summary import SegmentSummaryTool
from core.tools.segment_review import SegmentReviewTool from core.tools.segment_review import SegmentReviewTool
from core.tools.segment_rule_router import SegmentRuleRouterTool from core.tools.segment_rule_router import SegmentRuleRouterTool
from core.tools.retrieve_reference import RetrieveReferenceTool
from core.tools.reflect_retry import ReflectRetryTool from core.tools.reflect_retry import ReflectRetryTool
from core.memory import RiskFinding
app = FastAPI(title="合同审查智能体", version="0.1.0") app = FastAPI(title="合同审查智能体", version="0.1.0")
TMP_DIR = Path(__file__).resolve().parent / "tmp" TMP_DIR = Path(__file__).resolve().parent / "tmp"
...@@ -26,6 +27,7 @@ TMP_DIR.mkdir(parents=True, exist_ok=True) ...@@ -26,6 +27,7 @@ TMP_DIR.mkdir(parents=True, exist_ok=True)
summary_tool = SegmentSummaryTool() summary_tool = SegmentSummaryTool()
review_tool = SegmentReviewTool() review_tool = SegmentReviewTool()
rule_router_tool = SegmentRuleRouterTool() rule_router_tool = SegmentRuleRouterTool()
reference_tool = RetrieveReferenceTool()
reflect_tool = ReflectRetryTool() reflect_tool = ReflectRetryTool()
...@@ -47,7 +49,7 @@ class DocumentParseRequest(BaseModel): ...@@ -47,7 +49,7 @@ class DocumentParseRequest(BaseModel):
class DocumentParseResponse(BaseModel): class DocumentParseResponse(BaseModel):
conversation_id: str conversation_id: str
chunk_ids: List[int] segment_ids: List[int]
ruleset_items: List[str] ruleset_items: List[str]
text: Optional[str] = None text: Optional[str] = None
file_ext: Optional[str] = None file_ext: Optional[str] = None
...@@ -78,15 +80,20 @@ async def parse_document(payload: DocumentParseRequest) -> DocumentParseResponse ...@@ -78,15 +80,20 @@ async def parse_document(payload: DocumentParseRequest) -> DocumentParseResponse
# ocr # ocr
await doc_obj.get_from_ocr() await doc_obj.get_from_ocr()
text = doc_obj.get_all_text() text = doc_obj.get_all_text()
chunk_ids = doc_obj.get_chunk_id_list() segment_ids = doc_obj.get_chunk_id_list()
# TODO: FastGPT BUG segment_ids必须从1开始,0开始会缺少第一段文本,后续需要修复
segment_ids = [idx + 1 for idx in segment_ids]
# get ruleset items # get ruleset items
ruleset_id = payload.ruleset_id or review_tool.default_ruleset_id ruleset_id = payload.ruleset_id or reference_tool.default_ruleset_id
ruleset_items = review_tool.rulesets.get(ruleset_id) or [] ruleset_items = reference_tool.run(ruleset_id=ruleset_id).get("rules", [])
ruleset_review_items = [r.get('title') for r in ruleset_items] ruleset_review_items = [
t for t in (r.get("title") for r in ruleset_items)
if isinstance(t, str) and t.strip()
]
return DocumentParseResponse( return DocumentParseResponse(
conversation_id=payload.conversation_id, conversation_id=payload.conversation_id,
text=text, text=text,
chunk_ids=chunk_ids, segment_ids=segment_ids,
ruleset_items=ruleset_review_items, ruleset_items=ruleset_review_items,
file_ext = file_ext file_ext = file_ext
) )
...@@ -97,6 +104,8 @@ class SegmentSummaryRequest(BaseModel): ...@@ -97,6 +104,8 @@ class SegmentSummaryRequest(BaseModel):
conversation_id: str conversation_id: str
segment_id: int segment_id: int
party_role: Optional[str] = "" party_role: Optional[str] = ""
ruleset_id: Optional[str] = "通用"
routed_rule_titles: Optional[List[str]] = None
file_ext: str file_ext: str
context_facts: Optional[Dict] = None context_facts: Optional[Dict] = None
...@@ -115,14 +124,21 @@ def summarize_facts(payload: SegmentSummaryRequest) -> SegmentSummaryResponse: ...@@ -115,14 +124,21 @@ def summarize_facts(payload: SegmentSummaryRequest) -> SegmentSummaryResponse:
except Exception as exc: except Exception as exc:
raise HTTPException(status_code=400, detail=f"Document tool not available: {exc}") raise HTTPException(status_code=400, detail=f"Document tool not available: {exc}")
chunk_idx = payload.segment_id - 1 # chunk_id 在 SpireWordDoc 中为 1-based segment_idx = payload.segment_id - 1
try: try:
segment_text = doc_obj.get_chunk_item(chunk_idx) segment_text = doc_obj.get_chunk_item(segment_idx)
except Exception as exc: except Exception as exc:
raise HTTPException(status_code=404, detail=f"Segment text not found for id {payload.segment_id}: {exc}. Please parse document first.") raise HTTPException(status_code=404, detail=f"Segment text not found for id {payload.segment_id}: {exc}. Please parse document first.")
ruleset_id = payload.ruleset_id or reference_tool.default_ruleset_id
rules = reference_tool.run(
ruleset_id=ruleset_id,
routed_rule_titles=payload.routed_rule_titles,
).get("rules", [])
result = summary_tool.run( result = summary_tool.run(
segment_id=payload.segment_id, segment_id=segment_idx,
segment_text=segment_text, segment_text=segment_text,
rules=rules,
party_role=payload.party_role or "", party_role=payload.party_role or "",
context_facts=payload.context_facts, context_facts=payload.context_facts,
) )
...@@ -168,20 +184,31 @@ def review_segment(payload: SegmentReviewRequest) -> SegmentReviewResponse: ...@@ -168,20 +184,31 @@ def review_segment(payload: SegmentReviewRequest) -> SegmentReviewResponse:
except Exception as exc: except Exception as exc:
raise HTTPException(status_code=400, detail=f"Document tool not available: {exc}") raise HTTPException(status_code=400, detail=f"Document tool not available: {exc}")
chunk_idx = payload.segment_id - 1 segment_idx = payload.segment_id - 1
try: try:
segment_text = doc_obj.get_chunk_item(chunk_idx) segment_text = doc_obj.get_chunk_item(segment_idx)
except Exception as exc: except Exception as exc:
raise HTTPException(status_code=404, detail=f"Segment text not found for id {payload.segment_id}: {exc}. Please parse document first.") raise HTTPException(status_code=404, detail=f"Segment text not found for id {payload.segment_id}: {exc}. Please parse document first.")
ruleset_id = payload.ruleset_id or reference_tool.default_ruleset_id
rules = reference_tool.run(
ruleset_id=ruleset_id,
routed_rule_titles=payload.routed_rule_titles,
).get("rules", [])
# 暂时不添加摘要看下结果
# summary_keywords = reference_tool.summary_keywords(rules)
# context_summaries = store.search_facts(summary_keywords)
result = review_tool.run( result = review_tool.run(
segment_id=payload.segment_id, segment_id=segment_idx,
segment_text=segment_text, segment_text=segment_text,
ruleset_id=payload.ruleset_id or "通用", rules=rules,
routed_rule_titles=payload.routed_rule_titles,
party_role=payload.party_role or "", party_role=payload.party_role or "",
# 暂时不添加摘要看下结果
# context_summaries=context_summaries,
# TODO 获取与当前审查相关的上下文记忆(如之前的审查结果、总结事实等),而非全部记忆 # TODO 获取与当前审查相关的上下文记忆(如之前的审查结果、总结事实等),而非全部记忆
context_memories=payload.context_memories, context_memories=payload.context_memories,
merge_rules_prompt=MERGE_RULE_PROMPT, # TODO 是否合并规则到同一 prompt 进行审查,当前默认合并,后续可调整为不合并以提升审查的针对性
) )
# Persist findings to memory store # Persist findings to memory store
...@@ -189,10 +216,11 @@ def review_segment(payload: SegmentReviewRequest) -> SegmentReviewResponse: ...@@ -189,10 +216,11 @@ def review_segment(payload: SegmentReviewRequest) -> SegmentReviewResponse:
try: try:
store.add_finding_from_dict({ store.add_finding_from_dict({
"rule_title": f.get("rule_title", ""), "rule_title": f.get("rule_title", ""),
"segment_id": payload.segment_id, "segment_id": segment_idx,
"original_text": f.get("original_text",''), "original_text": f.get("original_text",''),
"issue": f.get("issue", ""), "issue": f.get("issue", ""),
"risk_level": (f.get("risk_level") or f.get("level") or "").upper(), "risk_level": (f.get("risk_level") or f.get("level") or "").upper(),
"result": f.get("result", ""),
"suggestion": f.get("suggestion", ""), "suggestion": f.get("suggestion", ""),
}) })
except Exception as e: except Exception as e:
...@@ -213,17 +241,18 @@ def route_segment_rules(payload: SegmentReviewRequest) -> SegmentRuleRouterRespo ...@@ -213,17 +241,18 @@ def route_segment_rules(payload: SegmentReviewRequest) -> SegmentRuleRouterRespo
except Exception as exc: except Exception as exc:
raise HTTPException(status_code=400, detail=f"Document tool not available: {exc}") raise HTTPException(status_code=400, detail=f"Document tool not available: {exc}")
chunk_idx = payload.segment_id - 1 segment_idx = payload.segment_id - 1
try: try:
segment_text = doc_obj.get_chunk_item(chunk_idx) segment_text = doc_obj.get_chunk_item(segment_idx)
except Exception as exc: except Exception as exc:
raise HTTPException(status_code=404, detail=f"Segment text not found for id {payload.segment_id}: {exc}. Please parse document first.") raise HTTPException(status_code=404, detail=f"Segment text not found for id {payload.segment_id}: {exc}. Please parse document first.")
ruleset_id = payload.ruleset_id or review_tool.default_ruleset_id ruleset_id = payload.ruleset_id or reference_tool.default_ruleset_id
rules = reference_tool.run(ruleset_id=ruleset_id).get("rules", [])
result = rule_router_tool.run( result = rule_router_tool.run(
segment_id=payload.segment_id, segment_id=segment_idx,
segment_text=segment_text, segment_text=segment_text,
ruleset_id=ruleset_id, rules=rules,
party_role=payload.party_role or "", party_role=payload.party_role or "",
context_memories=payload.context_memories, context_memories=payload.context_memories,
) )
...@@ -252,20 +281,19 @@ class ReflectReviewResponse(BaseModel): ...@@ -252,20 +281,19 @@ class ReflectReviewResponse(BaseModel):
@app.post("/segments/review/reflect", response_model=ReflectReviewResponse) @app.post("/segments/review/reflect", response_model=ReflectReviewResponse)
def reflect_review(payload: ReflectReviewRequest) -> ReflectReviewResponse: def reflect_review(payload: ReflectReviewRequest) -> ReflectReviewResponse:
store = get_cached_memory(payload.conversation_id) store = get_cached_memory(payload.conversation_id)
ruleset_id = payload.ruleset_id or review_tool.default_ruleset_id ruleset_id = payload.ruleset_id or reference_tool.default_ruleset_id
ruleset_items = review_tool.rulesets.get(ruleset_id) or [] ruleset_items = reference_tool.run(ruleset_id=ruleset_id).get("rules", [])
rule = next((r for r in ruleset_items if r.get("title") == payload.rule_title), None) rule = next((r for r in ruleset_items if r.get("title") == payload.rule_title), None)
if not rule: if not rule:
raise HTTPException(status_code=404, detail=f"Rule not found: {payload.rule_title}") raise HTTPException(status_code=404, detail=f"Rule not found: {payload.rule_title}")
# TODO 获取与当前审查规则相关的上下文记忆(如之前的审查结果、总结事实等),而非全部记忆 summary_keywords = reference_tool.summary_keywords([rule])
# facts = store.get_facts() context_summaries_facts = store.search_facts(summary_keywords)
facts = []
# 查找审查规则对应的 findings # 查找审查规则对应的 findings
findings = [f.__dict__ for f in store.search_findings("", rule_title=payload.rule_title)] findings = [f.__dict__ for f in store.search_findings("", rule_title=payload.rule_title)]
final_findings = reflect_tool.run( final_findings = reflect_tool.run(
party_role=payload.party_role, party_role=payload.party_role,
rule=rule, rule=rule,
facts=facts, facts=context_summaries_facts,
findings=findings, findings=findings,
) )
......
...@@ -396,7 +396,7 @@ class SpirePdfDoc(DocBase): ...@@ -396,7 +396,7 @@ class SpirePdfDoc(DocBase):
return -1, None, None return -1, None, None
def get_chunk_id_list(self, step=1): def get_chunk_id_list(self, step=1):
return [idx + 1 for idx in range(0, self.get_chunk_num(), step)] return [idx for idx in range(0, self.get_chunk_num(), step)]
def get_chunk_num(self): def get_chunk_num(self):
return len(self._chunk_list) return len(self._chunk_list)
......
...@@ -700,7 +700,7 @@ class SpireWordDoc(DocBase): ...@@ -700,7 +700,7 @@ class SpireWordDoc(DocBase):
def get_chunk_id_list(self, step=1): def get_chunk_id_list(self, step=1):
self._ensure_loaded() self._ensure_loaded()
return [idx + 1 for idx in range(0, self.get_chunk_num(), step)] return [idx for idx in range(0, self.get_chunk_num(), step)]
def get_all_text(self): def get_all_text(self):
self._ensure_loaded() self._ensure_loaded()
......
...@@ -724,7 +724,7 @@ class SpireWordDoc(DocBase): ...@@ -724,7 +724,7 @@ class SpireWordDoc(DocBase):
def get_chunk_id_list(self, step=1): def get_chunk_id_list(self, step=1):
self._ensure_loaded() self._ensure_loaded()
return [idx + 1 for idx in range(0, self.get_chunk_num(), step)] return [idx for idx in range(0, self.get_chunk_num(), step)]
def get_all_text(self): def get_all_text(self):
self._ensure_loaded() self._ensure_loaded()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment