Commit bd17ac4b by ccran

feat:添加批处理;添加benchmark;修改批注逻辑;

parent 49101664
tmp/ # Ignore everything by default
**
*.pyc # Keep directory structure visible so nested exceptions can work
**__pycache__** !*/
batch/input
batch/output # Keep Python source files
\ No newline at end of file !**/*.py
# Keep this file tracked
!.gitignore
\ No newline at end of file
...@@ -8,6 +8,7 @@ is_extract = True ...@@ -8,6 +8,7 @@ is_extract = True
use_original_text_verification = False use_original_text_verification = False
# @dataclass # @dataclass
# class LLMConfig: # class LLMConfig:
# base_url: str = "http://172.21.107.45:9002/v1" # base_url: str = "http://172.21.107.45:9002/v1"
...@@ -25,10 +26,36 @@ class LLMConfig: ...@@ -25,10 +26,36 @@ class LLMConfig:
api_key: str = "none" api_key: str = "none"
model: str = 'Qwen2-72B-Instruct' model: str = 'Qwen2-72B-Instruct'
# MAX_SINGLE_CHUNK_SIZE=100000
# MAX_SINGLE_CHUNK_SIZE=5000
MAX_SINGLE_CHUNK_SIZE=2000
outer_backend_url = "http://znkf.lgfzgroup.com:48081" DEFAULT_RULESET_ID = "通用"
base_fastgpt_url = "http://192.168.252.71:18089" ALL_RULESET_IDS = ["通用","借款","担保","测试","财务口","金盘","金盘简化"]
base_backend_url = "http://192.168.252.71:48081" FACT_DIMENSIONS = [
"当事人",
"标的",
"金额",
"支付",
"期限",
"交付",
"质量",
"知识产权",
"保密",
"违约责任",
"争议解决"
]
use_lufa = False
if use_lufa:
outer_backend_url = "http://znkf.lgfzgroup.com:48081"
base_fastgpt_url = "http://192.168.252.71:18089"
base_backend_url = "http://192.168.252.71:48081"
api_key = "fastgpt-zMavJKKgqA9jRNHLXxzXCVZx1JXxfuNkH1p2qfLhtPfMp41UvdSQvt8"
else:
outer_backend_url = "http://218.77.58.8:8088"
base_fastgpt_url = "http://192.168.252.71:18088"
base_backend_url = "http://192.168.252.71:48080"
api_key = "fastgpt-vLu2JHAfqwEq5FUQhvATFDK0yDS6fs804v7KwWBMyU4sRrHzh4UGl89Zpa"
# 项目根目录 # 项目根目录
root_path = r"E:\PycharmProject\contract_review_agent" root_path = r"E:\PycharmProject\contract_review_agent"
...@@ -46,7 +73,7 @@ LLM = { ...@@ -46,7 +73,7 @@ LLM = {
"base_tool_llm": LLMConfig(), "base_tool_llm": LLMConfig(),
"fastgpt_segment_review": LLMConfig( "fastgpt_segment_review": LLMConfig(
base_url=f"{base_fastgpt_url}/api/v1", base_url=f"{base_fastgpt_url}/api/v1",
api_key="fastgpt-zMavJKKgqA9jRNHLXxzXCVZx1JXxfuNkH1p2qfLhtPfMp41UvdSQvt8", api_key=api_key
) )
} }
doc_support_formats = [".docx", ".doc", ".wps"] doc_support_formats = [".docx", ".doc", ".wps"]
......
...@@ -59,12 +59,13 @@ from typing import Any, Dict, Iterable, List, Optional ...@@ -59,12 +59,13 @@ from typing import Any, Dict, Iterable, List, Optional
from utils.http_util import upload_file from utils.http_util import upload_file
from utils.doc_util import DocBase from utils.doc_util import DocBase
from core.config import FACT_DIMENSIONS
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_ALLOWED_RISK_LEVELS = {"H", "M", "L"} _ALLOWED_RISK_LEVELS = {"H", "M", "L",""}
@dataclass @dataclass
...@@ -94,6 +95,9 @@ class RiskFinding: ...@@ -94,6 +95,9 @@ class RiskFinding:
suggestion=str(data.get("suggestion", "")), suggestion=str(data.get("suggestion", "")),
) )
def __repr__(self):
    """Compact debug representation showing the finding's key fields."""
    fields = ", ".join([
        f"rule_title={self.rule_title!r}",
        f"segment_id={self.segment_id}",
        f"issue={self.issue!r}",
        f"risk_level={self.risk_level!r}",
    ])
    return f"RiskFinding({fields})"
@dataclass @dataclass
class MemoryStore: class MemoryStore:
...@@ -128,6 +132,24 @@ class MemoryStore: ...@@ -128,6 +132,24 @@ class MemoryStore:
with self._lock: with self._lock:
return self.facts # deep copy return self.facts # deep copy
def search_facts(self, keywords: List[str]) -> List[Any]:
    """Return fact values whose dimension key matches any requested keyword.

    Keywords are normalized (stripped, lower-cased) and filtered against the
    FACT_DIMENSIONS whitelist; keywords outside the whitelist are ignored.
    Only dict-shaped fact entries are scanned; the value of every matching
    dimension key is collected.

    :param keywords: dimension names to look up (may be None or empty).
    :return: list of matched fact values (possibly empty).
    """
    keys = [str(k).strip().lower() for k in (keywords or []) if str(k).strip()]
    allowed_keys = {str(k).strip().lower() for k in FACT_DIMENSIONS if str(k).strip()}
    requested_keys = {k for k in keys if k in allowed_keys}
    # Improvement: bail out before taking the lock / copying the fact list
    # when nothing valid was requested (the original copied facts first).
    if not requested_keys:
        return []
    with self._lock:
        # Shallow snapshot so iteration happens outside the lock.
        candidates = list(self.facts)
    matched_values: List[Any] = []
    for item in candidates:
        if not isinstance(item, dict):
            continue  # skip malformed (non-dict) fact entries
        for key, value in item.items():
            if str(key).strip().lower() in requested_keys:
                matched_values.append(value)
    return matched_values
# -------------------- findings --------------------- # -------------------- findings ---------------------
def add_finding(self, finding: RiskFinding) -> RiskFinding: def add_finding(self, finding: RiskFinding) -> RiskFinding:
return self._add_finding(self.findings, finding) return self._add_finding(self.findings, finding)
...@@ -263,7 +285,7 @@ class MemoryStore: ...@@ -263,7 +285,7 @@ class MemoryStore:
with self._lock: with self._lock:
wb = Workbook() wb = Workbook()
ws_final_findings = wb.active ws_final_findings = wb.active
ws_final_findings.title = "findings" ws_final_findings.title = "final_findings"
finding_headers = [ finding_headers = [
("rule_title", "规则标题"), ("rule_title", "规则标题"),
...@@ -273,12 +295,14 @@ class MemoryStore: ...@@ -273,12 +295,14 @@ class MemoryStore:
("risk_level", "风险等级"), ("risk_level", "风险等级"),
("suggestion", "建议"), ("suggestion", "建议"),
] ]
# add final findings
ws_final_findings.append([label for _, label in finding_headers]) ws_final_findings.append([label for _, label in finding_headers])
for f in self.final_findings: for f in self.final_findings:
ws_final_findings.append([ ws_final_findings.append([
getattr(f, key, "") for key, _ in finding_headers getattr(f, key, "") for key, _ in finding_headers
]) ])
# add findings
ws_findings = wb.create_sheet("findings") ws_findings = wb.create_sheet("findings")
ws_findings.append([label for _, label in finding_headers]) ws_findings.append([label for _, label in finding_headers])
for f in self.findings: for f in self.findings:
......
...@@ -4,7 +4,6 @@ import re ...@@ -4,7 +4,6 @@ import re
from typing import Dict, List, Optional,Tuple from typing import Dict, List, Optional,Tuple
from core.tool import ToolBase, tool, tool_func from core.tool import ToolBase, tool, tool_func
from tools.shared import store
from core.memory import MemoryStore from core.memory import MemoryStore
TOPIC_KEYWORDS = { TOPIC_KEYWORDS = {
......
...@@ -8,16 +8,54 @@ from core.tools.segment_llm import LLMTool ...@@ -8,16 +8,54 @@ from core.tools.segment_llm import LLMTool
REFLECT_SYSTEM_PROMPT = ''' REFLECT_SYSTEM_PROMPT = '''
你是一个合同审查反思智能体(ReviewReflection)。 你是一个合同审查反思智能体(ReviewReflection)。
你要基于 facts 与全文上下文,对已有 findings 进行校正后,输出【最终可交付的 findings 列表】。
要求: 你的任务不是重新发散式审查整份合同,而是基于已有 findings、facts 和全文上下文,对 findings 进行校验、去重、合并、修订,并输出最终可交付的 findings 列表。
- 严格按照输出 JSON Schema 返回结果,不得输出任何解释性文字
- 最终 findings 中每条都必须证据充分,original_text 必须是合同原文直接引用 【反思范围】
- 不得引入新的审查维度,只能基于已有 findings 的范围做合并、修订、删除或系统性总结 你只能做以下事情:
1. 删除重复 findings;
2. 删除在全文上下文下不成立或证据不足的 findings;
3. 修订表述不准确、风险程度不合适、建议不够可执行的 findings;
4. 合并指向同一原文实质问题的重复 findings;
5. 仅在多个 findings 在同一审查规则或同一风险模式下形成稳定重复时,新增少量 system/global findings。
【边界约束】
- 不得引入新的审查维度,不得脱离已有 findings 的审查范围自行扩展新问题。
- facts 和全文上下文仅用于验证、修正、去重、合并已有 findings,或支持同一风险模式下的系统性总结。
- 不得仅依据 facts 摘要或抽象概括生成 finding;最终每条 finding 仍必须有合同原文直接引用作为证据。
- global finding 只能来自多个已有 findings 的归纳,不得凭空新增。
【判定原则】
- 若多个 findings 针对同一原文实质问题,仅保留一条更准确、更完整的表述。
- 若某 finding 在全文上下文下被其他条款、定义、附件或事实明确修正、限制或抵消,则删除。
- 若某 finding 方向正确但表述、严重性或建议不准确,则修订后保留。
- 若多个 findings 在同一规则或同一风险模式下重复出现,且有至少两处原文证据支持,可合并为一条 system/global finding。
【证据要求】
- final_findings 中每条都必须证据充分。
- original_texts 必须是合同原文的直接引用,不得改写、概括或臆造。
【输出约束】
- 严格按照 JSON Schema 输出,不得输出任何解释性文字。
- 若反思后无成立 findings,返回 {"final_findings": []}。
'''
# TODO 根据需要合并的 finding 的特点
MERGE_FINDINGS_PROMPT = '''
【问题归并规则】
如果多个 findings 属于同一根本问题(root cause),
只保留一条代表性 finding。
代表性 finding 应满足:
- 原文最清晰
- 最容易定位
- 修改建议最完整
其余重复 findings 应删除。
''' '''
REFLECT_USER_PROMPT = ''' REFLECT_USER_PROMPT = '''
【输入】 【反思规则】
{rule}
【已有风险 findings】 【已有风险 findings】
{findings_json} {findings_json}
...@@ -28,12 +66,19 @@ REFLECT_USER_PROMPT = ''' ...@@ -28,12 +66,19 @@ REFLECT_USER_PROMPT = '''
站在 {party_role} 的立场进行反思审查。 站在 {party_role} 的立场进行反思审查。
【任务】 【任务】
输出反思后的最终 findings 列表(可直接用于最终审查报告): 请输出反思后的最终 findings 列表,可直接用于最终审查报告。你需要:
- 删除在全文上下文中不成立的 findings - 删除重复 findings;
- 修订表述/严重性/建议不准确的 findings - 删除在全文上下文中不成立或证据不足的 findings;
- 如需合并重复 findings,请合并成一条(保留全部原文证据引用) - 修订表述、严重性或建议不准确的 findings;
- 如可由全文结构推导出系统性风险,可新增 1~3 条 global findings(仍需原文证据) - 合并重复 findings,并保留全部关键原文证据;
- 严格按照输出 JSON Schema 返回结果,不得输出任何解释性文字 - 仅在多个 findings 指向同一规则或同一风险模式下的系统性问题时,新增少量 global findings。
【特别要求】
- 不得引入新的审查维度;
- 不得仅依据 facts 摘要生成结论;
- 每条 final finding 都必须有合同原文直接引用;
- 若无成立 findings,返回 {{"final_findings": []}};
- 仅输出 JSON。
''' '''
OUTPUT_FORMAT_SCHEMA = ''' OUTPUT_FORMAT_SCHEMA = '''
```json ```json
...@@ -41,9 +86,9 @@ OUTPUT_FORMAT_SCHEMA = ''' ...@@ -41,9 +86,9 @@ OUTPUT_FORMAT_SCHEMA = '''
"final_findings": [ "final_findings": [
{ {
"segment_id":"合同原文片段所在的段落ID", "segment_id":"合同原文片段所在的段落ID",
"issue": "详细的风险描述", "issue": "详细且准确的风险描述,为什么该问题构成风险,需基于规则和文本解释",
"original_text": "合同原文片段的直接引用", "original_text": "合同原文片段的直接引用",
"suggestion": "可直接替换原文或新增的条款措辞" "suggestion": "可直接替换原文、新增条款措辞,或明确的修改方向"
} }
] ]
} }
...@@ -71,7 +116,10 @@ class ReflectRetryTool(LLMTool): ...@@ -71,7 +116,10 @@ class ReflectRetryTool(LLMTool):
) )
def run(self, party_role: str, rule: Dict, facts: Optional[List[Dict]] = None, findings: Optional[List[Dict]] = None) -> List[Dict]: def run(self, party_role: str, rule: Dict, facts: Optional[List[Dict]] = None, findings: Optional[List[Dict]] = None) -> List[Dict]:
base_findings = self._build_findings_with_ids(findings or []) base_findings = self._build_findings_with_ids(findings or [])
if len(base_findings) == 0:
return []
user_content = REFLECT_USER_PROMPT.format( user_content = REFLECT_USER_PROMPT.format(
rule=rule.get("rule",""),
findings_json=json.dumps(base_findings, ensure_ascii=False), findings_json=json.dumps(base_findings, ensure_ascii=False),
facts_json=json.dumps(facts or [], ensure_ascii=False), facts_json=json.dumps(facts or [], ensure_ascii=False),
party_role=party_role, party_role=party_role,
......
...@@ -9,22 +9,65 @@ from core.tool import tool, tool_func ...@@ -9,22 +9,65 @@ from core.tool import tool, tool_func
from utils.excel_util import ExcelUtil from utils.excel_util import ExcelUtil
from core.tools.segment_llm import LLMTool from core.tools.segment_llm import LLMTool
import re import re
from core.config import DEFAULT_RULESET_ID, ALL_RULESET_IDS
DEFAULT_RULESET_ID = "通用"
ALL_RULESET_IDS = ["通用","借款","担保","测试"]
REVIEW_SYSTEM_PROMPT = ''' REVIEW_SYSTEM_PROMPT = '''
你是一个专业的合同分段审查智能体(SegmentReview)。 你是一个专业的合同分段审查智能体(SegmentReview)。
你的核心任务是对“当前分段”进行【法律风险识别】,并给出可落地的修改建议。
【工作职责】 你的任务是:基于给定审查规则,对“当前分段”进行审查,识别其中证据充分、对当前审查立场不利的风险或缺陷,并输出可执行的修改建议。
- 基于给定的“审查规则”,识别当前分段中确定存在的风险、逻辑矛盾或不合理之处。
- 必须通过“证据对碰”:将当前段落内容与【上下文记忆】进行比对,识别前后不一致。 【审查范围】
你只能审查以下两类问题:
1. 当前分段自身已经明确体现的问题,例如:对我方不利、表述不清、逻辑冲突、责任失衡、触发条件不明确、关键限制缺失等;
2. 当前分段与【上下文记忆】之间存在的、证据充分的前后不一致或条款冲突,例如:主体、定义、金额、期限、责任、流程节点不一致等。
【审查原则】
- 严格基于给定的审查规则进行审查,不得脱离规则自行扩展审查标准。
- 只有在证据充分时才生成 finding。弱猜测、无充分文本支持的问题不得输出。
- 上下文记忆仅作为辅助比对材料,不能替代当前分段原文,也不能在证据不足时强行得出结论。
- 优先识别“确定存在”的问题,不输出模糊怀疑类表述。
【证据要求】
每个 finding 都必须包含 original_text,且必须是合同原文的直接引用。
【证据粒度(关键句原则)】
original_text 必须满足“最小充分证据原则”:
- 只引用能够证明问题成立的最小文本片段
- 优先引用单句或关键子句
- 不得复制整段条款
引用长度限制:
- 推荐:20–80 字
- 最大:120 字
- 若一句话即可证明问题,则只允许引用该句
生成 finding 时必须执行:
Step 1:定位能够证明问题成立的关键句
Step 2:仅提取该句作为 original_text
Step 3:再分析 issue
禁止:
- 复制整段条款
- 引用超过 3 句文本
- 引用与 issue 无关的上下文
【建议要求】
- suggestion 必须具体、可执行。
- 若能在当前分段内直接修正,请给出可直接替换或新增的条款措辞。
- 若需联动其他条款,允许给出明确修改方向和应补充的关键要素,但不得只写“建议协商”“建议完善”等空泛表述。
【规则适用性判断】
在执行任何审查规则之前,你必须先判断:
当前分段是否包含与该审查规则相关的信息维度。
如果当前分段与该审查规则无关,则:
- 不得生成 finding
- 不得引用原文
- 直接返回 {"findings": []}。
【输出约束】 【输出约束】
- findings (风险发现):必须证据充分、可执行。无原文引用不生成 finding - 严格按照指定 JSON Schema 输出
- original_text (原文证据):必须是合同原文的直接引用,严禁改写、概括或臆造 - 不得输出任何 JSON 之外的解释性文字
- 严格按照输出 JSON Schema 返回结果,不得输出任何解释性文字 - 若未发现证据充分的问题,返回 {"findings": []}
''' '''
REVIEW_USER_PROMPT = ''' REVIEW_USER_PROMPT = '''
【当前分段文本】 【当前分段文本】
...@@ -39,19 +82,25 @@ REVIEW_USER_PROMPT = ''' ...@@ -39,19 +82,25 @@ REVIEW_USER_PROMPT = '''
【审查规则】 【审查规则】
{ruleset_text} {ruleset_text}
【指令】 【任务】
执行风险识别:基于规则,识别确定存在的风险,并给出直接落地的修改建议(不要使用“建议协商”等泛化词)。 请基于审查规则,审查当前分段,识别证据充分的问题,并输出可执行修改建议。
【特别要求】
- 仅输出证据充分的问题。
- 如果问题来自与上下文记忆的冲突,必须确保冲突是明确、可由文本直接支持的。
- findings 中的 original_text 必须为合同原文直接引用。
- suggestion 应尽量提供可直接落地的修改文本;若无法安全地直接改写,请给出明确的修改方向和应补充的关键要素。
- 若无问题,返回 {{"findings": []}}。
【输出要求】 【输出要求】
- 仅输出 JSON 格式。 - 仅输出 JSON。
- findings 字段中必须包含原文引用和具体的修改建议。
''' '''
REVIEW_OUTPUT_SCHEMA = ''' REVIEW_OUTPUT_SCHEMA = '''
```json ```json
{ {
"findings": [ "findings": [
{ {
"issue": "详细的风险描述", "issue": "详细的风险描述,为什么该问题构成风险,需基于规则和文本解释",
"original_text": "合同原文片段的直接引用", "original_text": "合同原文片段的直接引用",
"suggestion": "可直接替换原文或新增的条款措辞" "suggestion": "可直接替换原文或新增的条款措辞"
} }
...@@ -128,14 +177,21 @@ class SegmentReviewTool(LLMTool): ...@@ -128,14 +177,21 @@ class SegmentReviewTool(LLMTool):
"segment_id": {"type": "int"}, "segment_id": {"type": "int"},
"segment_text": {"type": "string"}, "segment_text": {"type": "string"},
"ruleset_id": {"type": "string"}, "ruleset_id": {"type": "string"},
"routed_rule_titles": {"type": "array", "items": {"type": "string"}},
"party_role": {"type": "string"}, "party_role": {"type": "string"},
"context_memories": {"type": "array"}, "context_memories": {"type": "array"},
}, },
"required": ["segment_id", "segment_text", "ruleset_id", "party_role"], "required": ["segment_id", "segment_text", "ruleset_id", "party_role"],
} }
) )
def run(self, segment_id: str, segment_text: str, ruleset_id: str, party_role: str, context_memories: Optional[List[Dict]] = None) -> Dict: def run(self, segment_id: str, segment_text: str, ruleset_id: str, party_role: str,
rules = self.rulesets.get(ruleset_id) or self.rulesets.get(self.default_ruleset_id, []) or [] routed_rule_titles: Optional[List[str]] = None, context_memories: Optional[List[Dict]] = None) -> Dict:
full_rules = self.rulesets.get(ruleset_id) or self.rulesets.get(self.default_ruleset_id, []) or []
if routed_rule_titles is not None:
title_set = {title for title in routed_rule_titles if isinstance(title, str)}
rules = [r for r in full_rules if r.get("title") in title_set]
else:
rules = full_rules
result = self._evaluate_rules(party_role,segment_id,segment_text,rules, context_memories) result = self._evaluate_rules(party_role,segment_id,segment_text,rules, context_memories)
overall = "revise" if (result["findings"] ) else "pass" overall = "revise" if (result["findings"] ) else "pass"
......
from __future__ import annotations
import json
import re
from pathlib import Path
from typing import Dict, List, Optional
from core.config import ALL_RULESET_IDS, DEFAULT_RULESET_ID
from core.tool import tool, tool_func
from core.tools.segment_llm import LLMTool
from utils.excel_util import ExcelUtil
# System prompt for the rule-routing agent: it decides only WHICH candidate
# review rules apply to the current contract segment (recall-first), and must
# not emit any risk conclusions or review suggestions itself.
ROUTER_SYSTEM_PROMPT = '''
你是合同分段规则路由智能体(SegmentRuleRouter)。
你的任务是:基于“当前分段文本”,从候选审查规则中选出“应执行审查”的规则项。
【路由目标】
- 仅做规则适配判断,不输出风险结论、不输出审查建议。
- 高召回优先:只要当前分段与规则存在明确相关性,就应路由命中。
- 若候选规则明显无关,则不要命中。
【判断依据】
- 以当前分段文本为主。
- 可参考上下文记忆辅助理解术语,但不得脱离当前分段文本做臆断。
【输出约束】
- 严格输出 JSON。
- 每个命中规则需给出简短 reason,说明该分段为何与规则相关。
- 若确实没有任何相关规则,返回 {"selected_items": []}。
'''
# User prompt template for rule routing; filled via str.format with the
# segment text, context memories (JSON), party role and candidate rules (JSON).
ROUTER_USER_PROMPT = '''
【当前分段文本】
{segment_text}
【上下文记忆】
{context_memories_json}
【合同立场】
{party_role}
【候选审查规则】
{candidate_rules_json}
【任务】
请从候选规则中选择当前分段应执行的审查项,并输出 selected_items。
'''
# Fix: the section header above previously read "【候选审查规则" with the
# closing "】" missing, which breaks the visual section framing the other
# headers establish for the LLM.
ROUTER_OUTPUT_SCHEMA = '''
```json
{
"selected_items": [
{
"title": "规则标题",
"reason": "命中原因(简短)"
}
]
}
```
'''
@tool("segment_rule_router", "分段规则路由")
class SegmentRuleRouterTool(LLMTool):
    """Route a contract segment to the subset of review rules that apply.

    All ruleset sheets are loaded from data/rules.xlsx once at construction.
    ``run`` asks the LLM to pick applicable rules for the segment; when the
    LLM selection is empty or unusable it falls back to trigger-word
    matching, and finally to conservative full recall (all rules).
    """

    def __init__(self) -> None:
        super().__init__(ROUTER_SYSTEM_PROMPT)
        self.default_ruleset_id = DEFAULT_RULESET_ID
        # Excel column -> internal field mapping for the rules workbook.
        self.column_map = {
            "id": "ID",
            "title": "审查项",
            "rule": "审查规则",
            "level": "风险等级",
            "triggers": "触发词",
            "suggestion_template": "建议模板",
        }
        rules_path = Path(__file__).resolve().parent.parent.parent / "data" / "rules.xlsx"
        # Pre-load every known ruleset so run() never touches the filesystem.
        self.rulesets: Dict[str, List[Dict]] = {}
        for rs_id in ALL_RULESET_IDS:
            self.rulesets[rs_id] = ExcelUtil.load_mapped_excel(
                rules_path, sheet_name=rs_id, column_map=self.column_map
            )

    @tool_func(
        {
            "type": "object",
            "properties": {
                "segment_id": {"type": "int"},
                "segment_text": {"type": "string"},
                "ruleset_id": {"type": "string"},
                "party_role": {"type": "string"},
                "context_memories": {"type": "array"},
            },
            "required": ["segment_id", "segment_text", "ruleset_id", "party_role"],
        }
    )
    def run(
        self,
        segment_id: int,
        segment_text: str,
        ruleset_id: str,
        party_role: str,
        context_memories: Optional[List[Dict]] = None,
    ) -> Dict:
        """Route one segment and return the matched rules plus their titles.

        Unknown ruleset ids fall back to the default ruleset.
        """
        rules = self.rulesets.get(ruleset_id) or self.rulesets.get(self.default_ruleset_id, []) or []
        routed_rules = self._route_rules(
            segment_text=segment_text,
            rules=rules,
            party_role=party_role,
            context_memories=context_memories,
        )
        return {
            "segment_id": segment_id,
            "ruleset_id": ruleset_id,
            "routed_rules": routed_rules,
            "routed_rule_titles": [r.get("title", "") for r in routed_rules],
        }

    def _build_candidate_rules(self, rules: List[Dict]) -> List[Dict]:
        """Project rules down to the fields the router LLM needs to judge."""
        return [
            {
                "title": r.get("title", ""),
                "level": r.get("level", ""),
                "rule": r.get("rule", ""),
                "triggers": r.get("triggers", ""),
            }
            for r in rules
            if r.get("title")
        ]

    def _route_rules(
        self,
        segment_text: str,
        rules: List[Dict],
        party_role: str,
        context_memories: Optional[List[Dict]],
    ) -> List[Dict]:
        """Ask the LLM which rules apply; fall back to trigger matching."""
        if not rules:
            return []
        candidates = self._build_candidate_rules(rules)
        user_content = ROUTER_USER_PROMPT.format(
            segment_text=segment_text,
            context_memories_json=json.dumps(context_memories or [], ensure_ascii=False),
            party_role=party_role,
            candidate_rules_json=json.dumps(candidates, ensure_ascii=False),
        ) + ROUTER_OUTPUT_SCHEMA
        llm_selected: List[Dict] = []
        try:
            resp = self.run_with_loop(self.chat_async(self.build_messages(user_content)))
            data = self.parse_first_json(resp)
            llm_selected = data.get("selected_items", []) or []
        except Exception:
            # LLM/parsing failure: fall through to trigger-word routing below.
            llm_selected = []
        # title -> reason, stripped; also serves as the selected-title set.
        selected_reasons = {
            str(item.get("title", "")).strip(): str(item.get("reason", "")).strip()
            for item in llm_selected
            if item.get("title")
        }
        if not selected_reasons:
            return self._fallback_route(segment_text=segment_text, rules=rules)
        # Fix: iterate the ruleset (not a Python set of titles) so the routed
        # output order is deterministic and follows the ruleset's own order.
        routed_rules: List[Dict] = []
        seen_titles: set = set()
        for rule in rules:
            title = str(rule.get("title", "")).strip()
            if not title or title in seen_titles or title not in selected_reasons:
                continue
            seen_titles.add(title)
            routed_rules.append(
                {
                    "id": rule.get("id", ""),
                    "title": rule.get("title", ""),
                    "level": rule.get("level", ""),
                    "rule": rule.get("rule", ""),
                    "triggers": rule.get("triggers", ""),
                    "reason": selected_reasons.get(title, ""),
                }
            )
        return routed_rules or self._fallback_route(segment_text=segment_text, rules=rules)

    def _fallback_route(self, segment_text: str, rules: List[Dict]) -> List[Dict]:
        """Trigger-word routing; if nothing matches, return ALL rules.

        The full-recall fallback keeps the pipeline recall-safe: a routing
        miss must never silently drop a rule from review.
        """
        text = segment_text or ""
        routed: List[Dict] = []
        for r in rules:
            triggers = self._parse_triggers(str(r.get("triggers", "")))
            if triggers and any(t in text for t in triggers):
                routed.append(
                    {
                        "id": r.get("id", ""),
                        "title": r.get("title", ""),
                        "level": r.get("level", ""),
                        "rule": r.get("rule", ""),
                        "triggers": r.get("triggers", ""),
                        "reason": "fallback: trigger matched",
                    }
                )
        # Conservative full recall when no trigger word hit either.
        if not routed:
            for r in rules:
                routed.append(
                    {
                        "id": r.get("id", ""),
                        "title": r.get("title", ""),
                        "level": r.get("level", ""),
                        "rule": r.get("rule", ""),
                        "triggers": r.get("triggers", ""),
                        "reason": "fallback: conservative full recall",
                    }
                )
        return routed

    def _parse_triggers(self, trigger_text: str) -> List[str]:
        """Split a trigger-word cell on CJK/ASCII separators into clean tokens."""
        parts = re.split(r"[,,、;;\s/|]+", trigger_text or "")
        return [p.strip() for p in parts if p.strip()]
if __name__ == "__main__":
    # Manual smoke test: route a typical payment clause through the router.
    # (Renamed the local from `tool` to avoid shadowing the imported decorator.)
    router = SegmentRuleRouterTool()
    demo_segment_text = (
        "甲方应在合同签订后5个工作日内向乙方支付合同总价30%作为预付款,"
        "剩余70%在乙方完成交付并经甲方验收合格后30日内支付。"
        "若甲方逾期付款,每逾期一日按应付未付金额的0.05%支付违约金。"
    )
    routing = router.run(
        segment_id=1,
        segment_text=demo_segment_text,
        ruleset_id="通用",
        party_role="甲方",
        context_memories=[],
    )
    print(json.dumps(routing, ensure_ascii=False, indent=2))
...@@ -7,22 +7,55 @@ from typing import Dict, List, Optional ...@@ -7,22 +7,55 @@ from typing import Dict, List, Optional
from core.tool import tool, tool_func from core.tool import tool, tool_func
from core.tools.segment_llm import LLMTool from core.tools.segment_llm import LLMTool
from core.config import FACT_DIMENSIONS
FACT_DIMENSIONS: List[str] = ["当事人", "标的", "金额", "支付", "交付", "质量", "知识产权", "保密", "违约责任", "争议解决"]
SUMMARY_SYSTEM_PROMPT = f''' SUMMARY_SYSTEM_PROMPT = f'''
你是合同事实提取智能体(SegmentSummary)。 你是合同事实提取智能体(SegmentSummary)。
仅输出“本分段的客观事实”,不做风险判断,不做主观推测。
你的任务是从当前合同分段中提取“客观事实”,并按指定维度结构化输出。
【事实定义】
事实必须满足:
1. 可以在当前分段原文中直接找到对应表述;
2. 不得对原文进行抽象、概括或推断;
3. 不得补充未出现的主体、条件或数值;
4. 允许对原文做最小结构化拆分(例如金额、比例、期限)。
【输出结构】 【输出结构】
- facts: 一个对象,键为预设维度,值为该分段出现的事实(未出现的维度可缺省或置空)。
- 维度列表:{", ".join(FACT_DIMENSIONS)}。 - 输出字段:facts
- 若原文包含多个事实,可使用列表或子对象表达,但保持紧凑、可读。 - facts 是一个对象
- 键为以下预设维度:
{", ".join(FACT_DIMENSIONS)}
- 每个维度值必须是对象或对象列表
- 未出现的维度可以省略
【结构规则】
- 仅提取对合同履行或责任具有实际意义的事实
- 不得输出字符串作为维度值,必须使用对象
- 不得输出解释、总结或风险判断
【上下文事实使用规则】
上下文事实仅用于:
- 避免重复提取已存在的事实
- 保持字段命名一致
不得:
- 使用上下文事实补充当前分段没有出现的信息
- 修改当前分段原文事实
【约束】 【约束】
- 严禁编造或改写原文未出现的信息。
- 不输出与事实无关的解释或额外文字。 - 严禁编造信息
- 严禁推断未出现的内容
- 严格输出 JSON
''' '''
SUMMARY_USER_PROMPT = ''' SUMMARY_USER_PROMPT = '''
...@@ -32,8 +65,11 @@ SUMMARY_USER_PROMPT = ''' ...@@ -32,8 +65,11 @@ SUMMARY_USER_PROMPT = '''
【上下文事实】 【上下文事实】
{context_facts} {context_facts}
请提取本段出现的客观事实,按照指定维度输出 JSON。未出现的维度可省略。 仅提取当前分段中明确出现的客观事实。
输出示例:''' 不得从上下文事实中补充新的信息。
输出 JSON。
'''
OUTPUT_EXAMPLE = ''' OUTPUT_EXAMPLE = '''
```json ```json
......
...@@ -3,7 +3,7 @@ import os ...@@ -3,7 +3,7 @@ import os
import re import re
import sys import sys
sys.path.append('..') sys.path.append('../..')
import traceback import traceback
import concurrent.futures import concurrent.futures
...@@ -12,10 +12,29 @@ from loguru import logger ...@@ -12,10 +12,29 @@ from loguru import logger
from utils.common_util import random_str from utils.common_util import random_str
from utils.http_util import upload_file, fastgpt_openai_chat, download_file from utils.http_util import upload_file, fastgpt_openai_chat, download_file
# Batch-run configuration: output filename suffix, I/O directories and the
# number of concurrent workers used by execute_batch.
SUFFIX='_麓发改进'
batch_input_dir_path = 'jp-input'
batch_output_dir_path = 'jp-output-lufa'
batch_size = 5
# Lufa FastGPT endpoint
# url = 'http://192.168.252.71:18089/api/v1/chat/completions'
# Jinpan FastGPT endpoint
url = 'http://192.168.252.71:18088/api/v1/chat/completions'
# Lufa contract-review production token
# token = 'fastgpt-ek3Z6PxI6sXgYc0jxzZ5bVGqrxwM6aVyfSmA6JVErJYBMr2KmYxrHwEUOIMSYz'
# Jinpan-migrated Lufa contract-review test token
token = 'fastgpt-vykT6qs07g7hR4tL2MNJE6DdNCIxaQjEu3Cxw9nuTBFg8MAG3CkByvnXKxSNEyMK7'
# Human-interaction test (production environment)
# token = 'fastgpt-ry4jIjgNwmNgufMr5jR0ncvJVmSS4GZl4bx2ItsNPoncdQzW9Na3IP1Xrankr'
# Post-extraction review test
# token = 'fastgpt-n74gGX5ZqLT6o1ysMBSGUTjIciswYOWDRfQ75krMkE5gDVDkpzsbz8u'
def extract_url(text): def extract_url(text):
# \s * ([ ^ "\s]+?\.(?:docx?|pdf|xlsx)) # \s * ([ ^ "\s]+?\.(?:docx?|pdf|xlsx))
excel_p, doc_p = r'导出Excel结果\s*([^"]*xlsx)', r'导出Doc结果\s*([^\" ]+?\.(?:docx?|pdf|wps))' # 麓发正则
# excel_p, doc_p = r'导出Excel结果\s*([^"]*xlsx)', r'导出Doc结果\s*([^\" ]+?\.(?:docx?|pdf|wps))'
# 金盘正则
excel_p, doc_p = r'最终审查Excel\s*([^"]*xlsx)', r'最终审查批注\s*([^\" ]+?\.(?:docx?|pdf|wps))'
# 使用 re.search() 查找第一个匹配项 # 使用 re.search() 查找第一个匹配项
excel_m, doc_m = re.search(excel_p, text), re.search(doc_p, text) excel_m, doc_m = re.search(excel_p, text), re.search(doc_p, text)
if excel_m and doc_m: if excel_m and doc_m:
...@@ -39,20 +58,17 @@ def process_single_file(file, batch_input_dir_path, batch_output_dir_path, count ...@@ -39,20 +58,17 @@ def process_single_file(file, batch_input_dir_path, batch_output_dir_path, count
# 源目标处理 # 源目标处理
original_file = f'{batch_input_dir_path}/{file}' original_file = f'{batch_input_dir_path}/{file}'
des_check_file = f'{batch_output_dir_path}/{file_name}.md' des_check_file = f'{batch_output_dir_path}/{file_name}.md'
des_excel_file = f'{batch_output_dir_path}/{file_name}.xlsx' des_excel_file = f'{batch_output_dir_path}/{file_name}{SUFFIX}.xlsx'
des_doc_file = f'{batch_output_dir_path}/{file_name}{ext_name}' des_doc_file = f'{batch_output_dir_path}/{file_name}{SUFFIX}{ext_name}'
try: try:
# 处理原文件 # 处理原文件
file_url = upload_file(original_file, input_url_to_inner=True).replace('218.77.58.8', '192.168.252.71') file_url = upload_file(original_file, input_url_to_inner=True).replace('218.77.58.8', '192.168.252.71')
model = 'Qwen2-72B-Instruct' model = 'Qwen2-72B-Instruct'
# url = 'http://218.77.58.8:8088/api/v1/chat/completions'
url = 'http://192.168.252.71:18089/api/v1/chat/completions'
# 合同审核Excel工作流处理 # 合同审核Excel工作流处理
logger.info(' 第{}个文件,处理文件: {}'.format(counter, original_file)) logger.info(' 第{}个文件,处理文件: {}'.format(counter, original_file))
# # 合同审查测试token
token = 'fastgpt-ek3Z6PxI6sXgYc0jxzZ5bVGqrxwM6aVyfSmA6JVErJYBMr2KmYxrHwEUOIMSYz' result = fastgpt_openai_chat(url, token, model, random_str(), file_url, f'测试批处理任务-{file_name}', False)
result = fastgpt_openai_chat(url, token, model, random_str(), file_url, f'0304批处理任务-{file_name}', False)
excel_url, doc_url = extract_url(result) excel_url, doc_url = extract_url(result)
if excel_url and doc_url: if excel_url and doc_url:
download_file(excel_url.replace('218.77.58.8', '192.168.252.71'), des_excel_file) download_file(excel_url.replace('218.77.58.8', '192.168.252.71'), des_excel_file)
...@@ -64,10 +80,9 @@ def process_single_file(file, batch_input_dir_path, batch_output_dir_path, count ...@@ -64,10 +80,9 @@ def process_single_file(file, batch_input_dir_path, batch_output_dir_path, count
def execute_batch(max_workers: int = 4): def execute_batch(max_workers: int = 4):
batch_input_dir_path = 'input'
batch_output_dir_path = 'output'
start_file = 1 start_file = 1
dirs = os.listdir(batch_input_dir_path) dirs = os.listdir(batch_input_dir_path)
os.makedirs(batch_output_dir_path, exist_ok=True)
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = [ futures = [
...@@ -87,4 +102,4 @@ def execute_batch(max_workers: int = 4): ...@@ -87,4 +102,4 @@ def execute_batch(max_workers: int = 4):
if __name__ == '__main__': if __name__ == '__main__':
execute_batch(5) execute_batch(batch_size)
\ No newline at end of file \ No newline at end of file
from spire.doc import *
from spire.doc.common import *
import os  # fix: os is used below but was never imported explicitly

# Strip all review comments from the benchmark Word documents so they can be
# re-annotated from a clean baseline.
benchmark_path = '/home/ccran/contract_review_agent/benchmark'
datasets_path = f'{benchmark_path}/datasets'
clean_path = f'{benchmark_path}/clean'

# Robustness: make sure the output directory exists before saving into it.
os.makedirs(clean_path, exist_ok=True)

for item in os.listdir(datasets_path):
    # Load the annotated document, drop every comment, save the clean copy.
    doc = Document()
    doc.LoadFromFile(f"{datasets_path}/{item}")
    doc.Comments.Clear()
    doc.SaveToFile(f"{clean_path}/{item}")
    doc.Close()
import argparse
from pathlib import Path
import pandas as pd
from rapidfuzz import fuzz
from contextlib import redirect_stdout, redirect_stderr
# Minimum rapidfuzz similarity score (0-100) for an answer row to count as
# matched against a model-produced row.
fuzz_score_threshold = 80
def _normalize_cell(value: object) -> str:
if pd.isna(value):
return ""
return str(value).strip()
def _load_rows(path: Path) -> list[tuple[str, str]]:
    """Read the first two columns of an xlsx file as (item, text) pairs.

    Every cell is normalized to a stripped string; rows blank in both
    columns are discarded.
    """
    df = pd.read_excel(path, dtype=str)
    if df.empty:
        return []
    pair_cols = df.iloc[:, :2].copy()
    # Column-wise Series.map avoids the FutureWarning on DataFrame.applymap.
    for col in pair_cols.columns:
        pair_cols[col] = pair_cols[col].map(_normalize_cell)
    pair_cols = pair_cols.replace("", pd.NA).dropna(how="all")
    return [tuple(row) for row in pair_cols.itertuples(index=False, name=None)]
def _compare_impl(val_dir: Path, answer_dir: Path) -> None:
    """Compare model-produced review rows against answer rows, file by file.

    For every xlsx in *val_dir* with a same-named file in *answer_dir*,
    greedily matches each answer row to the best fuzzy-matching val row of
    the same review item (consuming val rows to keep the pairing 1-to-1),
    prints per-file and per-item statistics, and finally writes an overall
    per-item summary workbook to results/合同审查结果_<val_dir_name>.xlsx.
    Output goes to stdout (callers may redirect it to a log file).
    """
    val_dir = val_dir.resolve()
    answer_dir = answer_dir.resolve()
    overall_val = overall_answer = overall_matched = 0
    # Global per-review-item accumulators across all compared files.
    overall_item_answer: dict[str, int] = {}
    overall_item_matched: dict[str, int] = {}
    overall_item_unmatched_answer: dict[str, int] = {}
    overall_item_unmatched_val: dict[str, int] = {}
    for val_file in sorted(val_dir.glob("*.xlsx")):
        answer_file = answer_dir / val_file.name
        if not answer_file.exists():
            print(f"Skip {val_file.name}: missing in answer")
            continue
        val_rows = _load_rows(val_file)
        answer_rows = _load_rows(answer_file)
        # Baseline: answer -> match val, consume val to keep 1-1, report leftover answers
        answer_counts: dict[str, int] = {}
        for item, _ in answer_rows:
            answer_counts[item] = answer_counts.get(item, 0) + 1
        # Bucket val texts by review item so matching stays within one item.
        val_buckets: dict[str, list[str]] = {}
        for item, text in val_rows:
            val_buckets.setdefault(item, []).append(text)
        matched_total = 0
        matched_by_item: dict[str, list[tuple[str, str, int]]] = {}
        unmatched_answer_by_item: dict[str, list[str]] = {}
        for item, ans_text in answer_rows:
            candidates = val_buckets.get(item, [])
            if not candidates:
                unmatched_answer_by_item.setdefault(item, []).append(ans_text)
                continue
            best_idx = -1
            best_score = -1
            for idx, cand in enumerate(candidates):
                # NOTE(review): re-strips (and rebinds) ans_text on every
                # candidate iteration; harmless but redundant.
                ans_text = ans_text.strip()
                cand = cand.strip()
                # Take the more generous of substring vs token-set similarity.
                score = max(
                    fuzz.partial_ratio(ans_text, cand),
                    fuzz.token_set_ratio(ans_text, cand)
                )
                if score > best_score:
                    best_score = score
                    best_idx = idx
            if best_score >= fuzz_score_threshold:
                matched_total += 1
                # Consume the matched val row so it can't pair twice.
                matched_val = candidates.pop(best_idx)
                matched_by_item.setdefault(item, []).append((ans_text, matched_val, best_score))
            else:
                unmatched_answer_by_item.setdefault(item, []).append(ans_text)
        # remaining vals in buckets are unmatched
        unmatched_val_by_item: dict[str, list[str]] = {
            item: texts for item, texts in val_buckets.items() if texts
        }
        val_total = len(val_rows)
        answer_total = len(answer_rows)
        overall_val += val_total
        overall_answer += answer_total
        overall_matched += matched_total
        unmatched_val_count = sum(len(v) for v in unmatched_val_by_item.values())
        unmatched_answer_count = sum(len(v) for v in unmatched_answer_by_item.values())
        # Fold this file's per-item counts into the global accumulators.
        for it, cnt in answer_counts.items():
            overall_item_answer[it] = overall_item_answer.get(it, 0) + cnt
        for it, lst in matched_by_item.items():
            overall_item_matched[it] = overall_item_matched.get(it, 0) + len(lst)
        for it, lst in unmatched_answer_by_item.items():
            overall_item_unmatched_answer[it] = overall_item_unmatched_answer.get(it, 0) + len(lst)
        for it, lst in unmatched_val_by_item.items():
            overall_item_unmatched_val[it] = overall_item_unmatched_val.get(it, 0) + len(lst)
        print('#' * 40)
        # NOTE(review): divides by answer_total without a zero guard; an
        # answer file with no rows would raise ZeroDivisionError here.
        print(
            f"{val_file.name}: matched {matched_total} | val {val_total} | answer {answer_total} "
            f"| unmatched val {unmatched_val_count} | unmatched answer {unmatched_answer_count} | accuracy {matched_total / answer_total:.2%} | invalid_val {(unmatched_val_count / val_total) if val_total != 0 else 0:.2%}"
        )
        for item in sorted(answer_counts):
            item_matches = matched_by_item.get(item, [])
            print(f" 审查项 {item}: matched {len(item_matches)} / {answer_counts[item]}")
            # (debug) dump the successfully matched pairs:
            # for ans_text, val_text, score in item_matches:
            #     print(f" {score}% | answer: {ans_text} | val: {val_text}")
            ua = unmatched_answer_by_item.get(item, [])
            if ua:
                print(f" 未匹配(answer 未被匹配){len(ua)} 条:")
                for t in ua:
                    print(f" answer: {t}")
            uv = unmatched_val_by_item.get(item, [])
            if uv:
                print(f" 未匹配(val 残留){len(uv)} 条:")
                for t in uv:
                    print(f" val: {t}")
        # break  # only first file for demo
    accuracy = overall_matched / overall_answer if overall_answer else 0
    invalid_val = (overall_val - overall_matched) / overall_val if overall_val else 0
    print(
        f"Overall: matched {overall_matched} | val {overall_val} | answer {overall_answer} | accuracy {accuracy:.2%} | invalid_val {invalid_val:.2%}"
    )
    # Overall per-review-item results, printed and exported to Excel.
    if overall_item_answer:
        print('#' * 40)
        print("Overall by item:")
        all_items = sorted(set(list(overall_item_answer.keys()) + list(overall_item_matched.keys()) + list(overall_item_unmatched_answer.keys()) + list(overall_item_unmatched_val.keys())))
        rows_by_item = []
        for it in all_items:
            ans = overall_item_answer.get(it, 0)
            mat = overall_item_matched.get(it, 0)
            u_ans = overall_item_unmatched_answer.get(it, 0)
            u_val = overall_item_unmatched_val.get(it, 0)
            acc = (mat / ans) if ans else 0
            # NOTE(review): reuses the outer name invalid_val with a different
            # formula (per-item share of leftover val rows); intentional?
            invalid_val = u_val / (mat + u_val) if (mat + u_val) else 0
            rows_by_item.append({
                "审查项": it,
                "大模型匹配上的不合格项": mat,
                "合同所有不合格项": ans,
                "大模型其他不合格项": u_val,
                "大模型未匹配上的不合格项(C-B)": u_ans,
                "查全率(B/C)": acc,
                "无关审查率(D/B+D)": invalid_val,
            })
            print(
                f" 审查项 {it}: matched {mat} / answer {ans} | unmatched val {u_val} | unmatched answer {u_ans} | accuracy {acc:.2%} | invalid_val {invalid_val:.2%}"
            )
        overall_by_item_df = pd.DataFrame(rows_by_item, columns=["审查项", "大模型匹配上的不合格项", "合同所有不合格项", "大模型其他不合格项", "大模型未匹配上的不合格项(C-B)", "查全率(B/C)", "无关审查率(D/B+D)"])
        unmatched_val_total = sum(overall_item_unmatched_val.values())
        unmatched_answer_total = sum(overall_item_unmatched_answer.values())
        overall_invalid_rate = unmatched_val_total / (overall_matched + unmatched_val_total) if (overall_matched + unmatched_val_total) else 0
        # Append a "总体" (overall) summary row below the per-item rows.
        overall_total_df = pd.DataFrame([
            {
                "审查项": "总体",
                "大模型匹配上的不合格项": overall_matched,
                "合同所有不合格项": overall_answer,
                "大模型其他不合格项": unmatched_val_total,
                "大模型未匹配上的不合格项(C-B)": unmatched_answer_total,
                "查全率(B/C)": accuracy,
                "无关审查率(D/B+D)": overall_invalid_rate,
            }
        ], columns=["审查项", "大模型匹配上的不合格项", "合同所有不合格项", "大模型其他不合格项", "大模型未匹配上的不合格项(C-B)", "查全率(B/C)", "无关审查率(D/B+D)"])
        combined_df = pd.concat([overall_by_item_df, overall_total_df], ignore_index=True)
        compare_dir_name = val_dir.name
        results_dir = Path(__file__).parent / "results"
        results_dir.mkdir(parents=True, exist_ok=True)
        output_excel = results_dir / f"合同审查结果_{compare_dir_name}.xlsx"
        with pd.ExcelWriter(output_excel, engine="openpyxl") as writer:
            combined_df.to_excel(writer, sheet_name="对比结果", index=False)
        print(f"Excel written to {output_excel}")
def compare(val_dir: Path, answer_dir: Path) -> None:
    """Run the annotation comparison, printing the report to stdout/stderr."""
    _compare_impl(val_dir=val_dir, answer_dir=answer_dir)
def compare_with_log(val_dir: Path, answer_dir: Path, log_path: Path | None = None) -> Path:
    """Run the comparison with all stdout/stderr captured into a log file.

    When *log_path* is None the log is written under ``results/`` next to this
    script, named after the val directory.

    Returns:
        The resolved path of the written log file.
    """
    val_dir = val_dir.resolve()
    if log_path is not None:
        log_path = log_path.resolve()
        log_path.parent.mkdir(parents=True, exist_ok=True)
    else:
        results_dir = Path(__file__).parent / "results"
        results_dir.mkdir(parents=True, exist_ok=True)
        log_path = results_dir / f"合同审查结果_{val_dir.name}.log"
    # Capture everything the comparison prints into the log file.
    with open(log_path, "w", encoding="utf-8") as log_file, redirect_stdout(log_file), redirect_stderr(log_file):
        _compare_impl(val_dir=val_dir, answer_dir=answer_dir)
    return log_path
def _parse_args() -> argparse.Namespace:
base = Path(__file__).parent
parser = argparse.ArgumentParser(description="Compare extracted annotations with answers.")
parser.add_argument(
"--val-dir",
type=Path,
default=base / "batch_output_0121_val",
help="Directory containing extracted val xlsx files.",
)
parser.add_argument(
"--answer-dir",
type=Path,
default=base / "审查答案",
help="Directory containing answer xlsx files.",
)
parser.add_argument(
"--log-path",
type=Path,
default=None,
help="Optional explicit log path. Defaults to results/合同审查结果_<val_dir_name>.log",
)
return parser.parse_args()
if __name__ == "__main__":
    # CLI entry point: compare extracted annotations against the answers and
    # capture the full report in a log file.
    args = _parse_args()
    final_log_path = compare_with_log(
        val_dir=args.val_dir,
        answer_dir=args.answer_dir,
        log_path=args.log_path,
    )
    print(f"Log written to {final_log_path}")
\ No newline at end of file
from __future__ import annotations
import argparse
import re
from pathlib import Path
from typing import Iterable
import pandas as pd
from spire.doc import Document
from compare_annotation import compare_with_log
# Map raw comment authors to unified review item names.
# Keys are author labels as they appear in the Word comments; values are the
# canonical review-item names used by the benchmark comparison. Authors not
# listed here pass through unchanged (see normalize_comment_author).
COMMENT_AUTHOR_MAPPING: dict[str, str] = {
    "三方货款审查":"第三方审查",
    "履行义务审查":"第三方审查",
    "违约条款审查":"违约与延期审查",
    "延期审查":"违约与延期审查"
}
# Control characters (other than \t, \n, \r) are rejected by Excel writers
# such as openpyxl; compile the pattern once instead of on every call.
_ILLEGAL_CHARS_RE = re.compile(r"[\x00-\x08\x0B-\x0C\x0E-\x1F]")


def clean_illegal(value: object) -> object:
    """Strip Excel-illegal control characters from strings.

    Args:
        value: Any cell value; only ``str`` instances are modified.

    Returns:
        The cleaned string, or *value* unchanged when it is not a string.
    """
    if isinstance(value, str):
        return _ILLEGAL_CHARS_RE.sub("", value)
    return value
def normalize_comment_author(author: str) -> str:
    """Return the unified review-item name for a raw comment author.

    Leading/trailing whitespace is dropped; unmapped authors pass through.
    """
    stripped = author.strip()
    if not stripped:
        return stripped
    return COMMENT_AUTHOR_MAPPING.get(stripped, stripped)
def extract_annotaion(
    datasets_dir: Path,
    output_dir: Path,
    strip_suffixes: list[str] | None = None,
) -> None:
    """Extract review comments from Word files to xlsx files.

    Every non-xlsx file in *datasets_dir* is loaded as a Word document; its
    comments are written to ``<output_dir>/<stem>.xlsx`` with columns
    审查项 / 合同原文 / 建议.

    Args:
        datasets_dir: Directory holding annotated Word files.
        output_dir: Directory for the generated xlsx files (created if needed).
        strip_suffixes: Optional filename suffixes removed once from each stem.
    """
    datasets_dir = datasets_dir.resolve()
    output_dir = output_dir.resolve()
    output_dir.mkdir(parents=True, exist_ok=True)
    strip_suffixes = strip_suffixes or []
    for item in sorted(datasets_dir.iterdir()):
        if item.suffix.lower() == ".xlsx" or not item.is_file():
            continue
        document = Document()
        # Fix: close the document even when extraction fails part-way, so one
        # broken file does not leak the Spire document handle.
        try:
            document.LoadFromFile(str(item))
            comments: list[dict[str, str]] = []
            for i in range(document.Comments.Count):
                comment = document.Comments[i]
                comment_text = ""
                for j in range(comment.Body.Paragraphs.Count):
                    paragraph = comment.Body.Paragraphs[j]
                    comment_text += paragraph.Text + "\n"
                comment_author = comment.Format.Author
                # Authors may be written as "<prefix>|<name>"; compare only
                # the part after the "|" separator.
                author_split_idx = comment_author.find("|")
                comment_author = (
                    comment_author[author_split_idx + 1 :]
                    if author_split_idx != -1
                    else comment_author
                )
                comment_author = normalize_comment_author(comment_author)
                comments.append(
                    {
                        "审查项": clean_illegal(comment_author),
                        "合同原文": clean_illegal(comment.OwnerParagraph.Text),
                        "建议": clean_illegal(comment_text),
                    }
                )
            df = pd.DataFrame(comments)
            clean_stem = _strip_suffix_once(item.stem, strip_suffixes)
            output_stem = clean_stem or item.stem
            output_file = output_dir / f"{output_stem}.xlsx"
            df.to_excel(output_file, index=False)
        finally:
            document.Close()
def compare_annotaion(val_dir: Path, answer_dir: Path) -> None:
    """Run the benchmark comparison on extracted annotations and report the log file."""
    written_log = compare_with_log(val_dir=val_dir, answer_dir=answer_dir)
    print(f"Compare log written to: {written_log}")
def _strip_suffix_once(stem: str, suffixes: Iterable[str]) -> str:
for suffix in suffixes:
if suffix and stem.endswith(suffix):
return stem[: -len(suffix)]
return stem
def eval(
    datasets_dir: Path,
    answer_dir: Path,
    val_dir: Path,
    strip_suffixes: list[str] | None = None,
) -> None:
    """Pipeline: extract annotations first, then compare against ground truth.

    NOTE(review): this name shadows the builtin ``eval``; kept because callers
    invoke it by this name.
    """
    suffixes = strip_suffixes or []
    extract_annotaion(
        datasets_dir=datasets_dir,
        output_dir=val_dir,
        strip_suffixes=suffixes,
    )
    compare_annotaion(val_dir=val_dir, answer_dir=answer_dir)
def _parse_args() -> argparse.Namespace:
base = Path(__file__).parent
parser = argparse.ArgumentParser(
description="Extract review comments from docs and evaluate against answers."
)
parser.add_argument(
"--datasets-dir",
type=Path,
default=base / "results" / "jp-output",
help="Directory containing Word files with annotations.",
)
parser.add_argument(
"--answer-dir",
type=Path,
default=base / "审查答案",
help="Directory containing labeled answer xlsx files.",
)
parser.add_argument(
"--val-dir",
type=Path,
default=base / "results" / "jp-output-extracted",
help="Directory to store extracted xlsx files for comparison.",
)
parser.add_argument(
"--strip-suffixes",
nargs="*",
default=['_人机交互'],
help=(
"Optional filename suffixes to strip from generated val xlsx stems before "
"comparison, e.g. --strip-suffixes _v1 _审阅版"
),
)
return parser.parse_args()
if __name__ == "__main__":
    # CLI entry point: extract Word comments into xlsx files, then benchmark
    # them against the labeled answers.
    args = _parse_args()
    eval(
        datasets_dir=args.datasets_dir,
        answer_dir=args.answer_dir,
        val_dir=args.val_dir,
        strip_suffixes=args.strip_suffixes,
    )
from __future__ import annotations
import argparse
import sys
from pathlib import Path
# Make the project root importable when this script is run directly.
PROJECT_ROOT = Path(__file__).resolve().parents[2]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
from utils.excel_util import ExcelUtil
from utils.spire_word_util import SpireWordDoc
# Extensions recognised as Excel input files and Word annotation targets.
EXCEL_SUFFIXES = {".xlsx", ".xlsm"}
WORD_SUFFIXES = {".doc", ".docx"}
# Lower-cased cell values that identify a header row in the input Excel
# (see _looks_like_header).
HEADER_KEYWORDS = {
    "key_points",
    "chunk_id",
    "original_text",
    "details",
    "suggest",
    "建议",
    "详情",
    "审查项",
    "原文",
    "合同原文",
    "文件块序号",
}
def _normalize_cell(value: object) -> str:
if value is None:
return ""
if isinstance(value, float) and value.is_integer():
return str(int(value))
return str(value).strip()
def _looks_like_header(row: list[object]) -> bool:
    """Return True when any of the first six cells matches a known header keyword.

    Matching is case-insensitive on the normalized cell text.
    """
    # Normalize each cell once (the original normalized every cell twice).
    header_cells = set()
    for cell in row[:6]:
        text = _normalize_cell(cell)
        if text:
            header_cells.add(text.lower())
    return bool(header_cells & HEADER_KEYWORDS)
def _parse_chunk_id(raw_value: object, *, chunk_id_base: int, row_number: int, excel_path: Path) -> int:
    """Parse a chunk-id cell and normalize it to a 0-based index.

    Args:
        raw_value: Raw cell value from the Excel row.
        chunk_id_base: 0 or 1 depending on how chunk ids were written.
        row_number: 1-based row number, used in error messages.
        excel_path: Source file, used in error messages.

    Raises:
        ValueError: when the cell is empty, non-numeric, or below the base.
    """
    text = _normalize_cell(raw_value)
    if not text:
        raise ValueError(f"{excel_path.name} 第 {row_number} 行缺少 chunk_id")
    try:
        # Accept "3" as well as "3.0" (Excel often stores ints as floats).
        parsed = int(float(text))
    except ValueError as exc:
        raise ValueError(
            f"{excel_path.name} 第 {row_number} 行 chunk_id 无法解析为整数: {text}"
        ) from exc
    normalized = parsed - chunk_id_base
    if normalized < 0:
        raise ValueError(
            f"{excel_path.name} 第 {row_number} 行 chunk_id 越界: {parsed}"
        )
    return normalized
def load_comments_from_excel(
    excel_path: Path,
    *,
    sheet_name: str | None = None,
    chunk_id_base: int = 1,
    skip_header: bool = True,
) -> list[dict[str, object]]:
    """Read review comments from an Excel sheet into plain dicts.

    Expected fixed column layout (0-based): 0 = key_points, 1 = chunk_id,
    2 = original_text, 3 = details, 5 = suggest; column 4 is unused here.

    Args:
        excel_path: Excel file to read.
        sheet_name: Passed through to ExcelUtil.load_excel.
        chunk_id_base: 0 or 1 depending on how chunk ids were written.
        skip_header: Auto-skip the first row when it looks like a header.

    Returns:
        One dict per usable row; every row is marked as a failed item.

    Raises:
        ValueError: via _parse_chunk_id for missing/invalid chunk ids.
    """
    rows = ExcelUtil.load_excel(excel_path, sheet_name=sheet_name, has_header=False)
    comments: list[dict[str, object]] = []
    for row_number, raw_row in enumerate(rows, start=1):
        # Defensive: skip anything that is not a list-shaped row.
        if not isinstance(raw_row, list):
            continue
        row = list(raw_row)
        # Skip fully empty rows.
        if not any(_normalize_cell(cell) for cell in row):
            continue
        if skip_header and row_number == 1 and _looks_like_header(row):
            continue
        # Pad to at least 6 cells so the fixed-position indexing below is safe.
        padded = row + [None] * max(0, 6 - len(row))
        key_points = _normalize_cell(padded[0])
        original_text = _normalize_cell(padded[2])
        details = _normalize_cell(padded[3])
        suggest = _normalize_cell(padded[5])
        # Rows with no content in the three key columns are ignored.
        if not key_points and not original_text and not suggest:
            continue
        chunk_id = _parse_chunk_id(
            padded[1],
            chunk_id_base=chunk_id_base,
            row_number=row_number,
            excel_path=excel_path,
        )
        comments.append(
            {
                "id": f"{excel_path.stem}-{row_number}",
                "key_points": key_points,
                "chunk_id": chunk_id,
                "original_text": original_text,
                "details": details,
                # Every imported comment is treated as a failed review item.
                "result": "不合格",
                "suggest": suggest,
            }
        )
    return comments
def _build_doc_index(doc_dir: Path) -> dict[str, Path]:
    """Index Word files in *doc_dir* by filename stem.

    On duplicate stems the first file in sorted order wins.
    """
    index: dict[str, Path] = {}
    for candidate in sorted(doc_dir.iterdir()):
        if candidate.is_file() and candidate.suffix.lower() in WORD_SUFFIXES:
            index.setdefault(candidate.stem, candidate)
    return index
def _match_doc_path(excel_path: Path, doc_index: dict[str, Path]) -> Path | None:
return doc_index.get(excel_path.stem)
def _close_spire_doc(doc: SpireWordDoc) -> None:
internal_doc = getattr(doc, "_doc", None)
if internal_doc is None:
return
try:
internal_doc.Close()
except Exception:
pass
def clear_comments(doc: SpireWordDoc) -> None:
    """Best-effort removal of all existing comments in *doc*; never raises."""
    inner = getattr(doc, "_doc", None)
    if inner is not None:
        try:
            inner.Comments.Clear()
        except Exception:
            # Clearing is best-effort; ignore failures from the Spire backend.
            pass
def annotate_doc_from_excel(
    excel_path: Path,
    doc_path: Path,
    output_path: Path,
    *,
    sheet_name: str | None = None,
    chunk_id_base: int = 1,
    skip_header: bool = True,
) -> int:
    """Write the comments listed in *excel_path* into the Word file *doc_path*.

    Existing comments are cleared first; new comments are grouped by chunk id
    and attached per chunk, then the annotated document is saved to
    *output_path*.

    Args:
        excel_path: Excel file describing the comments (see load_comments_from_excel).
        doc_path: Word document to annotate.
        output_path: Destination for the annotated document (parents created).
        sheet_name: Passed through to the Excel loader.
        chunk_id_base: 0- or 1-based chunk ids in the Excel file.
        skip_header: Auto-skip a header-looking first row.

    Returns:
        The number of comments written.

    Raises:
        ValueError: when a chunk_id exceeds the document's chunk count.
    """
    comments = load_comments_from_excel(
        excel_path,
        sheet_name=sheet_name,
        chunk_id_base=chunk_id_base,
        skip_header=skip_header,
    )
    doc = SpireWordDoc()
    try:
        doc.load(str(doc_path))
        clear_comments(doc)
        # Group comments by chunk so each chunk is annotated in one pass;
        # all chunk ids are validated before any comment is written.
        grouped_comments: dict[int, list[dict[str, object]]] = {}
        for comment in comments:
            chunk_id = int(comment["chunk_id"])
            if chunk_id >= doc.get_chunk_num():
                raise ValueError(
                    f"{excel_path.name} 中的 chunk_id {chunk_id + chunk_id_base} 超出文档块数量 {doc.get_chunk_num()}"
                )
            grouped_comments.setdefault(chunk_id, []).append(comment)
        for chunk_id, chunk_comments in grouped_comments.items():
            doc.add_chunk_comment(chunk_id, chunk_comments)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        doc.to_file(str(output_path), remove_prefix=True)
        return len(comments)
    finally:
        # Always release the underlying Spire document handle.
        _close_spire_doc(doc)
def batch_annotate_docs(
    excel_dir: Path,
    output_dir: Path,
    *,
    sheet_name: str | None = None,
    chunk_id_base: int = 1,
    skip_header: bool = True,
) -> None:
    """Annotate every Word file in *excel_dir* from its same-named Excel file.

    Word documents are expected in the same directory as the Excel files; an
    Excel file without a matching Word stem is skipped with a message. The
    annotated documents are written to *output_dir* under the original name.

    Args:
        excel_dir: Directory holding both the Excel comment files and the
            Word documents to annotate.
        output_dir: Destination directory for annotated Word files.
        sheet_name: Passed through to the Excel loader.
        chunk_id_base: 0- or 1-based chunk ids in the Excel files.
        skip_header: Auto-skip a header-looking first row.
    """
    excel_dir = excel_dir.resolve()
    output_dir = output_dir.resolve()
    output_dir.mkdir(parents=True, exist_ok=True)
    # Word files live alongside the Excel files; index them once by stem.
    doc_index = _build_doc_index(excel_dir)
    processed_files = 0
    skipped_files = 0
    total_comments = 0
    for excel_path in sorted(excel_dir.iterdir()):
        if not excel_path.is_file() or excel_path.suffix.lower() not in EXCEL_SUFFIXES:
            continue
        doc_path = _match_doc_path(excel_path, doc_index)
        if doc_path is None:
            skipped_files += 1
            print(f"Skip {excel_path.name}: 未找到同名 Word 文件")
            continue
        output_path = output_dir / doc_path.name
        # NOTE(review): an exception from one file aborts the whole batch —
        # confirm whether per-file error handling is desired.
        comment_count = annotate_doc_from_excel(
            excel_path=excel_path,
            doc_path=doc_path,
            output_path=output_path,
            sheet_name=sheet_name,
            chunk_id_base=chunk_id_base,
            skip_header=skip_header,
        )
        processed_files += 1
        total_comments += comment_count
        print(
            f"Processed {excel_path.name} -> {output_path.name} | comments: {comment_count}"
        )
    print(
        f"Done: processed {processed_files} files, skipped {skipped_files} files, total comments {total_comments}"
    )
def _parse_args() -> argparse.Namespace:
base = Path(__file__).parent
parser = argparse.ArgumentParser(
description="Read Excel comments in bulk and write them into Word documents."
)
parser.add_argument(
"--excel-dir",
type=Path,
required=True,
help="Directory containing Excel files.",
)
parser.add_argument(
"--output-dir",
type=Path,
default=base / "results" / "re_comment_output",
help="Directory to write annotated Word files.",
)
parser.add_argument(
"--chunk-id-base",
type=int,
choices=(0, 1),
default=0,
help="Whether chunk_id in Excel is 0-based or 1-based. Defaults to 0.",
)
parser.add_argument(
"--no-skip-header",
action="store_true",
help="Do not auto-skip the first row when it looks like a header.",
)
return parser.parse_args()
if __name__ == "__main__":
    # CLI entry point: bulk-write Excel comments back into same-named Word files.
    args = _parse_args()
    batch_annotate_docs(
        excel_dir=args.excel_dir,
        output_dir=args.output_dir,
        chunk_id_base=args.chunk_id_base,
        skip_header=not args.no_skip_header,
    )
No preview for this file type
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional from typing import Any, Dict, List, Optional
from uuid import uuid4 from uuid import uuid4
import ast
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
import uvicorn import uvicorn
import traceback import traceback
from loguru import logger
from utils.common_util import extract_url_file, format_now from utils.common_util import extract_url_file, format_now
from utils.http_util import download_file from utils.http_util import download_file
...@@ -14,6 +16,7 @@ from core.cache import get_cached_doc_tool, get_cached_memory ...@@ -14,6 +16,7 @@ from core.cache import get_cached_doc_tool, get_cached_memory
from core.config import doc_support_formats, pdf_support_formats from core.config import doc_support_formats, pdf_support_formats
from core.tools.segment_summary import SegmentSummaryTool from core.tools.segment_summary import SegmentSummaryTool
from core.tools.segment_review import SegmentReviewTool from core.tools.segment_review import SegmentReviewTool
from core.tools.segment_rule_router import SegmentRuleRouterTool
from core.tools.reflect_retry import ReflectRetryTool from core.tools.reflect_retry import ReflectRetryTool
from core.memory import RiskFinding from core.memory import RiskFinding
...@@ -22,6 +25,7 @@ TMP_DIR = Path(__file__).resolve().parent / "tmp" ...@@ -22,6 +25,7 @@ TMP_DIR = Path(__file__).resolve().parent / "tmp"
TMP_DIR.mkdir(parents=True, exist_ok=True) TMP_DIR.mkdir(parents=True, exist_ok=True)
summary_tool = SegmentSummaryTool() summary_tool = SegmentSummaryTool()
review_tool = SegmentReviewTool() review_tool = SegmentReviewTool()
rule_router_tool = SegmentRuleRouterTool()
reflect_tool = ReflectRetryTool() reflect_tool = ReflectRetryTool()
...@@ -73,7 +77,7 @@ async def parse_document(payload: DocumentParseRequest) -> DocumentParseResponse ...@@ -73,7 +77,7 @@ async def parse_document(payload: DocumentParseRequest) -> DocumentParseResponse
doc_obj.load(file_path) doc_obj.load(file_path)
# ocr # ocr
await doc_obj.get_from_ocr() await doc_obj.get_from_ocr()
# text = doc_obj.get_all_text() text = doc_obj.get_all_text()
chunk_ids = doc_obj.get_chunk_id_list() chunk_ids = doc_obj.get_chunk_id_list()
# get ruleset items # get ruleset items
ruleset_id = payload.ruleset_id or review_tool.default_ruleset_id ruleset_id = payload.ruleset_id or review_tool.default_ruleset_id
...@@ -81,7 +85,7 @@ async def parse_document(payload: DocumentParseRequest) -> DocumentParseResponse ...@@ -81,7 +85,7 @@ async def parse_document(payload: DocumentParseRequest) -> DocumentParseResponse
ruleset_review_items = [r.get('title') for r in ruleset_items] ruleset_review_items = [r.get('title') for r in ruleset_items]
return DocumentParseResponse( return DocumentParseResponse(
conversation_id=payload.conversation_id, conversation_id=payload.conversation_id,
# text=text, text=text,
chunk_ids=chunk_ids, chunk_ids=chunk_ids,
ruleset_items=ruleset_review_items, ruleset_items=ruleset_review_items,
file_ext = file_ext file_ext = file_ext
...@@ -116,12 +120,11 @@ def summarize_facts(payload: SegmentSummaryRequest) -> SegmentSummaryResponse: ...@@ -116,12 +120,11 @@ def summarize_facts(payload: SegmentSummaryRequest) -> SegmentSummaryResponse:
segment_text = doc_obj.get_chunk_item(chunk_idx) segment_text = doc_obj.get_chunk_item(chunk_idx)
except Exception as exc: except Exception as exc:
raise HTTPException(status_code=404, detail=f"Segment text not found for id {payload.segment_id}: {exc}. Please parse document first.") raise HTTPException(status_code=404, detail=f"Segment text not found for id {payload.segment_id}: {exc}. Please parse document first.")
result = summary_tool.run( result = summary_tool.run(
segment_id=payload.segment_id, segment_id=payload.segment_id,
segment_text=segment_text, segment_text=segment_text,
party_role=payload.party_role or "", party_role=payload.party_role or "",
context_facts=payload.context_facts or store.get_facts(), context_facts=payload.context_facts,
) )
store.add_facts(result) store.add_facts(result)
...@@ -138,6 +141,7 @@ class SegmentReviewRequest(BaseModel): ...@@ -138,6 +141,7 @@ class SegmentReviewRequest(BaseModel):
segment_id: int segment_id: int
party_role: Optional[str] = "" party_role: Optional[str] = ""
ruleset_id: Optional[str] = "通用" ruleset_id: Optional[str] = "通用"
routed_rule_titles: Optional[List[str]] = None
file_ext: str file_ext: str
context_memories: Optional[List[Dict]] = None context_memories: Optional[List[Dict]] = None
...@@ -148,6 +152,14 @@ class SegmentReviewResponse(BaseModel): ...@@ -148,6 +152,14 @@ class SegmentReviewResponse(BaseModel):
overall_conclusion: str overall_conclusion: str
findings: List[Dict] findings: List[Dict]
class SegmentRuleRouterResponse(BaseModel):
conversation_id: str
segment_id: int
ruleset_id: str
routed_rule_titles: List[str]
routed_rules: List[Dict]
@app.post("/segments/review/findings", response_model=SegmentReviewResponse) @app.post("/segments/review/findings", response_model=SegmentReviewResponse)
def review_segment(payload: SegmentReviewRequest) -> SegmentReviewResponse: def review_segment(payload: SegmentReviewRequest) -> SegmentReviewResponse:
store = get_cached_memory(payload.conversation_id) store = get_cached_memory(payload.conversation_id)
...@@ -166,8 +178,10 @@ def review_segment(payload: SegmentReviewRequest) -> SegmentReviewResponse: ...@@ -166,8 +178,10 @@ def review_segment(payload: SegmentReviewRequest) -> SegmentReviewResponse:
segment_id=payload.segment_id, segment_id=payload.segment_id,
segment_text=segment_text, segment_text=segment_text,
ruleset_id=payload.ruleset_id or "通用", ruleset_id=payload.ruleset_id or "通用",
routed_rule_titles=payload.routed_rule_titles,
party_role=payload.party_role or "", party_role=payload.party_role or "",
context_memories=payload.context_memories or store.get_facts(), # TODO 获取与当前审查相关的上下文记忆(如之前的审查结果、总结事实等),而非全部记忆
context_memories=payload.context_memories,
) )
# Persist findings to memory store # Persist findings to memory store
...@@ -181,9 +195,9 @@ def review_segment(payload: SegmentReviewRequest) -> SegmentReviewResponse: ...@@ -181,9 +195,9 @@ def review_segment(payload: SegmentReviewRequest) -> SegmentReviewResponse:
"risk_level": (f.get("risk_level") or f.get("level") or "").upper(), "risk_level": (f.get("risk_level") or f.get("level") or "").upper(),
"suggestion": f.get("suggestion", ""), "suggestion": f.get("suggestion", ""),
}) })
except Exception: except Exception as e:
logger.error(e)
continue continue
return SegmentReviewResponse( return SegmentReviewResponse(
conversation_id=payload.conversation_id, conversation_id=payload.conversation_id,
segment_id=payload.segment_id, segment_id=payload.segment_id,
...@@ -191,6 +205,37 @@ def review_segment(payload: SegmentReviewRequest) -> SegmentReviewResponse: ...@@ -191,6 +205,37 @@ def review_segment(payload: SegmentReviewRequest) -> SegmentReviewResponse:
findings=result.get("findings", []), findings=result.get("findings", []),
) )
@app.post("/segments/review/rule-router", response_model=SegmentRuleRouterResponse)
def route_segment_rules(payload: SegmentReviewRequest) -> SegmentRuleRouterResponse:
try:
doc_obj, _ = get_cached_doc_tool(payload.conversation_id, payload.file_ext)
except Exception as exc:
raise HTTPException(status_code=400, detail=f"Document tool not available: {exc}")
chunk_idx = payload.segment_id - 1
try:
segment_text = doc_obj.get_chunk_item(chunk_idx)
except Exception as exc:
raise HTTPException(status_code=404, detail=f"Segment text not found for id {payload.segment_id}: {exc}. Please parse document first.")
ruleset_id = payload.ruleset_id or review_tool.default_ruleset_id
result = rule_router_tool.run(
segment_id=payload.segment_id,
segment_text=segment_text,
ruleset_id=ruleset_id,
party_role=payload.party_role or "",
context_memories=payload.context_memories,
)
return SegmentRuleRouterResponse(
conversation_id=payload.conversation_id,
segment_id=payload.segment_id,
ruleset_id=ruleset_id,
routed_rule_titles=result.get("routed_rule_titles", []),
routed_rules=result.get("routed_rules", []),
)
######################################################################################################################## ########################################################################################################################
class ReflectReviewRequest(BaseModel): class ReflectReviewRequest(BaseModel):
conversation_id: str conversation_id: str
...@@ -212,7 +257,10 @@ def reflect_review(payload: ReflectReviewRequest) -> ReflectReviewResponse: ...@@ -212,7 +257,10 @@ def reflect_review(payload: ReflectReviewRequest) -> ReflectReviewResponse:
rule = next((r for r in ruleset_items if r.get("title") == payload.rule_title), None) rule = next((r for r in ruleset_items if r.get("title") == payload.rule_title), None)
if not rule: if not rule:
raise HTTPException(status_code=404, detail=f"Rule not found: {payload.rule_title}") raise HTTPException(status_code=404, detail=f"Rule not found: {payload.rule_title}")
facts = store.get_facts() # TODO 获取与当前审查规则相关的上下文记忆(如之前的审查结果、总结事实等),而非全部记忆
# facts = store.get_facts()
facts = []
# 查找审查规则对应的 findings
findings = [f.__dict__ for f in store.search_findings("", rule_title=payload.rule_title)] findings = [f.__dict__ for f in store.search_findings("", rule_title=payload.rule_title)]
final_findings = reflect_tool.run( final_findings = reflect_tool.run(
party_role=payload.party_role, party_role=payload.party_role,
...@@ -253,6 +301,35 @@ def new_conversation() -> ConversationResponse: ...@@ -253,6 +301,35 @@ def new_conversation() -> ConversationResponse:
######################################################################################################################## ########################################################################################################################
class FactsRetrieveRequest(BaseModel):
conversation_id: str
keywords: List[str] = Field(..., description="facts 检索关键字列表")
class FactsRetrieveResponse(BaseModel):
conversation_id: str
keywords: List[str]
facts: List[Any]
total: int
@app.post("/memory/facts/retrieve", response_model=FactsRetrieveResponse)
def retrieve_facts(payload: FactsRetrieveRequest) -> FactsRetrieveResponse:
keywords = [k.strip() for k in (payload.keywords or []) if isinstance(k, str) and k.strip()]
if not keywords:
raise HTTPException(status_code=400, detail="keywords cannot be empty")
store = get_cached_memory(payload.conversation_id)
matched_facts = store.search_facts(keywords)
return FactsRetrieveResponse(
conversation_id=payload.conversation_id,
keywords=keywords,
facts=matched_facts,
total=len(matched_facts),
)
########################################################################################################################
class MemoryExportRequest(BaseModel): class MemoryExportRequest(BaseModel):
conversation_id: str conversation_id: str
file_ext: str file_ext: str
...@@ -293,6 +370,11 @@ def export_memory(payload: MemoryExportRequest) -> MemoryExportResponse: ...@@ -293,6 +370,11 @@ def export_memory(payload: MemoryExportRequest) -> MemoryExportResponse:
) )
if __name__ == "__main__": if __name__ == "__main__":
from core.config import use_lufa
if use_lufa:
port = 18168
else:
port = 18169
uvicorn.run( uvicorn.run(
"main:app", host="0.0.0.0", port=18168, log_level="info", reload=False "main:app", host="0.0.0.0", port=port, log_level="info", reload=False
) )
\ No newline at end of file
from main import FactsRetrieveRequest, retrieve_facts
from core.cache import get_cached_memory
import json
def test_retrieve_facts_direct() -> None:
    """Call the facts-retrieval endpoint function directly and dump the result."""
    # Hard-coded conversation id of a previously parsed document session.
    request = FactsRetrieveRequest(
        conversation_id="fa86563cb6c649d59e32e7def16ea6b2",
        keywords=["当事人"],
    )
    response = retrieve_facts(request)
    print(json.dumps(response.facts, ensure_ascii=False, indent=4))


if __name__ == "__main__":
    test_retrieve_facts_direct()
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from core.config import MAX_SINGLE_CHUNK_SIZE
# 文档基类 # 文档基类
...@@ -8,7 +9,7 @@ class DocBase(ABC): ...@@ -8,7 +9,7 @@ class DocBase(ABC):
self._doc_path = None self._doc_path = None
self._doc_name = None self._doc_name = None
self._kwargs = kwargs self._kwargs = kwargs
self._max_single_chunk_size = kwargs.get('max_single_chunk_size', 2000) self._max_single_chunk_size = kwargs.get('max_single_chunk_size', MAX_SINGLE_CHUNK_SIZE)
@abstractmethod @abstractmethod
def load(self, doc_path): def load(self, doc_path):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment