Commit bd17ac4b by ccran

feat:添加批处理;添加benchmark;修改批注逻辑;

parent 49101664
tmp/
# Ignore everything by default
**
*.pyc
**__pycache__**
batch/input
batch/output
\ No newline at end of file
# Keep directory structure visible so nested exceptions can work
!*/
# Keep Python source files
!**/*.py
# Keep this file tracked
!.gitignore
\ No newline at end of file
......@@ -8,6 +8,7 @@ is_extract = True
use_original_text_verification = False
# @dataclass
# class LLMConfig:
# base_url: str = "http://172.21.107.45:9002/v1"
......@@ -25,10 +26,36 @@ class LLMConfig:
api_key: str = "none"
model: str = 'Qwen2-72B-Instruct'
# MAX_SINGLE_CHUNK_SIZE=100000
# MAX_SINGLE_CHUNK_SIZE=5000
MAX_SINGLE_CHUNK_SIZE=2000
outer_backend_url = "http://znkf.lgfzgroup.com:48081"
base_fastgpt_url = "http://192.168.252.71:18089"
base_backend_url = "http://192.168.252.71:48081"
DEFAULT_RULESET_ID = "通用"
ALL_RULESET_IDS = ["通用","借款","担保","测试","财务口","金盘","金盘简化"]
# Fact dimensions used to structure extracted contract facts; fact-dict keys
# are validated against this whitelist (e.g. by MemoryStore.search_facts).
FACT_DIMENSIONS = [
    "当事人",
    "标的",
    "金额",
    "支付",
    "期限",
    "交付",
    "质量",
    "知识产权",
    "保密",
    "违约责任",
    "争议解决"
]

# Deployment switch: True targets the "lufa" endpoints, False the test ones.
# NOTE(review): API keys are hardcoded here — move them to environment
# variables or a secrets store before this config leaves internal use.
use_lufa = False
if use_lufa:
    outer_backend_url = "http://znkf.lgfzgroup.com:48081"
    base_fastgpt_url = "http://192.168.252.71:18089"
    base_backend_url = "http://192.168.252.71:48081"
    api_key = "fastgpt-zMavJKKgqA9jRNHLXxzXCVZx1JXxfuNkH1p2qfLhtPfMp41UvdSQvt8"
else:
    outer_backend_url = "http://218.77.58.8:8088"
    base_fastgpt_url = "http://192.168.252.71:18088"
    base_backend_url = "http://192.168.252.71:48080"
    api_key = "fastgpt-vLu2JHAfqwEq5FUQhvATFDK0yDS6fs804v7KwWBMyU4sRrHzh4UGl89Zpa"
# 项目根目录
root_path = r"E:\PycharmProject\contract_review_agent"
......@@ -46,7 +73,7 @@ LLM = {
"base_tool_llm": LLMConfig(),
"fastgpt_segment_review": LLMConfig(
base_url=f"{base_fastgpt_url}/api/v1",
api_key="fastgpt-zMavJKKgqA9jRNHLXxzXCVZx1JXxfuNkH1p2qfLhtPfMp41UvdSQvt8",
api_key=api_key
)
}
doc_support_formats = [".docx", ".doc", ".wps"]
......
......@@ -59,12 +59,13 @@ from typing import Any, Dict, Iterable, List, Optional
from utils.http_util import upload_file
from utils.doc_util import DocBase
from core.config import FACT_DIMENSIONS
logger = logging.getLogger(__name__)
_ALLOWED_RISK_LEVELS = {"H", "M", "L"}
_ALLOWED_RISK_LEVELS = {"H", "M", "L",""}
@dataclass
......@@ -93,6 +94,9 @@ class RiskFinding:
risk_level=str(data.get("risk_level", "")),
suggestion=str(data.get("suggestion", "")),
)
def __repr__(self):
    # Compact debug representation showing only the identifying fields.
    return (
        "RiskFinding(rule_title=%r, segment_id=%s, issue=%r, risk_level=%r)"
        % (self.rule_title, self.segment_id, self.issue, self.risk_level)
    )
@dataclass
......@@ -128,6 +132,24 @@ class MemoryStore:
with self._lock:
return self.facts # deep copy
def search_facts(self, keywords: List[str]) -> List[Any]:
    """Return fact values whose dimension key matches one of *keywords*.

    Keywords are normalized (trimmed, lower-cased) and filtered against the
    FACT_DIMENSIONS whitelist; unknown keywords are ignored. Matching values
    are returned in stored order. Non-dict fact entries are skipped.

    Fix: the empty-request early return used to happen *after* acquiring the
    lock and snapshotting ``self.facts``; it now happens first, so callers
    with no valid keywords never contend on the lock.
    """
    keys = [str(k).strip().lower() for k in (keywords or []) if str(k).strip()]
    allowed_keys = {str(k).strip().lower() for k in FACT_DIMENSIONS if str(k).strip()}
    requested_keys = {k for k in keys if k in allowed_keys}
    if not requested_keys:
        return []
    with self._lock:
        # Shallow snapshot so iteration happens outside the lock.
        candidates = list(self.facts)
    matched_values: List[Any] = []
    for item in candidates:
        if not isinstance(item, dict):
            continue
        for key, value in item.items():
            normalized_key = str(key).strip().lower()
            if normalized_key in requested_keys:
                matched_values.append(value)
    return matched_values
# -------------------- findings ---------------------
def add_finding(self, finding: RiskFinding) -> RiskFinding:
return self._add_finding(self.findings, finding)
......@@ -263,7 +285,7 @@ class MemoryStore:
with self._lock:
wb = Workbook()
ws_final_findings = wb.active
ws_final_findings.title = "findings"
ws_final_findings.title = "final_findings"
finding_headers = [
("rule_title", "规则标题"),
......@@ -273,12 +295,14 @@ class MemoryStore:
("risk_level", "风险等级"),
("suggestion", "建议"),
]
# add final findings
ws_final_findings.append([label for _, label in finding_headers])
for f in self.final_findings:
ws_final_findings.append([
getattr(f, key, "") for key, _ in finding_headers
])
# add findings
ws_findings = wb.create_sheet("findings")
ws_findings.append([label for _, label in finding_headers])
for f in self.findings:
......
......@@ -4,7 +4,6 @@ import re
from typing import Dict, List, Optional,Tuple
from core.tool import ToolBase, tool, tool_func
from tools.shared import store
from core.memory import MemoryStore
TOPIC_KEYWORDS = {
......
......@@ -8,16 +8,54 @@ from core.tools.segment_llm import LLMTool
REFLECT_SYSTEM_PROMPT = '''
你是一个合同审查反思智能体(ReviewReflection)。
你要基于 facts 与全文上下文,对已有 findings 进行校正后,输出【最终可交付的 findings 列表】。
要求:
- 严格按照输出 JSON Schema 返回结果,不得输出任何解释性文字
- 最终 findings 中每条都必须证据充分,original_text 必须是合同原文直接引用
- 不得引入新的审查维度,只能基于已有 findings 的范围做合并、修订、删除或系统性总结
你的任务不是重新发散式审查整份合同,而是基于已有 findings、facts 和全文上下文,对 findings 进行校验、去重、合并、修订,并输出最终可交付的 findings 列表。
【反思范围】
你只能做以下事情:
1. 删除重复 findings;
2. 删除在全文上下文下不成立或证据不足的 findings;
3. 修订表述不准确、风险程度不合适、建议不够可执行的 findings;
4. 合并指向同一原文实质问题的重复 findings;
5. 仅在多个 findings 在同一审查规则或同一风险模式下形成稳定重复时,新增少量 system/global findings。
【边界约束】
- 不得引入新的审查维度,不得脱离已有 findings 的审查范围自行扩展新问题。
- facts 和全文上下文仅用于验证、修正、去重、合并已有 findings,或支持同一风险模式下的系统性总结。
- 不得仅依据 facts 摘要或抽象概括生成 finding;最终每条 finding 仍必须有合同原文直接引用作为证据。
- global finding 只能来自多个已有 findings 的归纳,不得凭空新增。
【判定原则】
- 若多个 findings 针对同一原文实质问题,仅保留一条更准确、更完整的表述。
- 若某 finding 在全文上下文下被其他条款、定义、附件或事实明确修正、限制或抵消,则删除。
- 若某 finding 方向正确但表述、严重性或建议不准确,则修订后保留。
- 若多个 findings 在同一规则或同一风险模式下重复出现,且有至少两处原文证据支持,可合并为一条 system/global finding。
【证据要求】
- final_findings 中每条都必须证据充分。
- original_texts 必须是合同原文的直接引用,不得改写、概括或臆造。
【输出约束】
- 严格按照 JSON Schema 输出,不得输出任何解释性文字。
- 若反思后无成立 findings,返回 {"final_findings": []}。
'''
# TODO 根据需要合并的 finding 的特点
MERGE_FINDINGS_PROMPT = '''
【问题归并规则】
如果多个 findings 属于同一根本问题(root cause),
只保留一条代表性 finding。
代表性 finding 应满足:
- 原文最清晰
- 最容易定位
- 修改建议最完整
其余重复 findings 应删除。
'''
REFLECT_USER_PROMPT = '''
【输入】
【反思规则】
{rule}
【已有风险 findings】
{findings_json}
......@@ -28,12 +66,19 @@ REFLECT_USER_PROMPT = '''
站在 {party_role} 的立场进行反思审查。
【任务】
输出反思后的最终 findings 列表(可直接用于最终审查报告):
- 删除在全文上下文中不成立的 findings
- 修订表述/严重性/建议不准确的 findings
- 如需合并重复 findings,请合并成一条(保留全部原文证据引用)
- 如可由全文结构推导出系统性风险,可新增 1~3 条 global findings(仍需原文证据)
- 严格按照输出 JSON Schema 返回结果,不得输出任何解释性文字
请输出反思后的最终 findings 列表,可直接用于最终审查报告。你需要:
- 删除重复 findings;
- 删除在全文上下文中不成立或证据不足的 findings;
- 修订表述、严重性或建议不准确的 findings;
- 合并重复 findings,并保留全部关键原文证据;
- 仅在多个 findings 指向同一规则或同一风险模式下的系统性问题时,新增少量 global findings。
【特别要求】
- 不得引入新的审查维度;
- 不得仅依据 facts 摘要生成结论;
- 每条 final finding 都必须有合同原文直接引用;
- 若无成立 findings,返回 {{"final_findings": []}};
- 仅输出 JSON。
'''
OUTPUT_FORMAT_SCHEMA = '''
```json
......@@ -41,9 +86,9 @@ OUTPUT_FORMAT_SCHEMA = '''
"final_findings": [
{
"segment_id":"合同原文片段所在的段落ID",
"issue": "详细的风险描述",
"issue": "详细且准确的风险描述,为什么该问题构成风险,需基于规则和文本解释",
"original_text": "合同原文片段的直接引用",
"suggestion": "可直接替换原文或新增的条款措辞"
"suggestion": "可直接替换原文、新增条款措辞,或明确的修改方向"
}
]
}
......@@ -71,7 +116,10 @@ class ReflectRetryTool(LLMTool):
)
def run(self, party_role: str, rule: Dict, facts: Optional[List[Dict]] = None, findings: Optional[List[Dict]] = None) -> List[Dict]:
base_findings = self._build_findings_with_ids(findings or [])
if len(base_findings) == 0:
return []
user_content = REFLECT_USER_PROMPT.format(
rule=rule.get("rule",""),
findings_json=json.dumps(base_findings, ensure_ascii=False),
facts_json=json.dumps(facts or [], ensure_ascii=False),
party_role=party_role,
......
......@@ -9,22 +9,65 @@ from core.tool import tool, tool_func
from utils.excel_util import ExcelUtil
from core.tools.segment_llm import LLMTool
import re
DEFAULT_RULESET_ID = "通用"
ALL_RULESET_IDS = ["通用","借款","担保","测试"]
from core.config import DEFAULT_RULESET_ID, ALL_RULESET_IDS
REVIEW_SYSTEM_PROMPT = '''
你是一个专业的合同分段审查智能体(SegmentReview)。
你的核心任务是对“当前分段”进行【法律风险识别】,并给出可落地的修改建议。
【工作职责】
- 基于给定的“审查规则”,识别当前分段中确定存在的风险、逻辑矛盾或不合理之处。
- 必须通过“证据对碰”:将当前段落内容与【上下文记忆】进行比对,识别前后不一致。
你的任务是:基于给定审查规则,对“当前分段”进行审查,识别其中证据充分、对当前审查立场不利的风险或缺陷,并输出可执行的修改建议。
【审查范围】
你只能审查以下两类问题:
1. 当前分段自身已经明确体现的问题,例如:对我方不利、表述不清、逻辑冲突、责任失衡、触发条件不明确、关键限制缺失等;
2. 当前分段与【上下文记忆】之间存在的、证据充分的前后不一致或条款冲突,例如:主体、定义、金额、期限、责任、流程节点不一致等。
【审查原则】
- 严格基于给定的审查规则进行审查,不得脱离规则自行扩展审查标准。
- 只有在证据充分时才生成 finding。弱猜测、无充分文本支持的问题不得输出。
- 上下文记忆仅作为辅助比对材料,不能替代当前分段原文,也不能在证据不足时强行得出结论。
- 优先识别“确定存在”的问题,不输出模糊怀疑类表述。
【证据要求】
每个 finding 都必须包含 original_text,且必须是合同原文的直接引用。
【证据粒度(关键句原则)】
original_text 必须满足“最小充分证据原则”:
- 只引用能够证明问题成立的最小文本片段
- 优先引用单句或关键子句
- 不得复制整段条款
引用长度限制:
- 推荐:20–80 字
- 最大:120 字
- 若一句话即可证明问题,则只允许引用该句
生成 finding 时必须执行:
Step 1:定位能够证明问题成立的关键句
Step 2:仅提取该句作为 original_text
Step 3:再分析 issue
禁止:
- 复制整段条款
- 引用超过 3 句文本
- 引用与 issue 无关的上下文
【建议要求】
- suggestion 必须具体、可执行。
- 若能在当前分段内直接修正,请给出可直接替换或新增的条款措辞。
- 若需联动其他条款,允许给出明确修改方向和应补充的关键要素,但不得只写“建议协商”“建议完善”等空泛表述。
【规则适用性判断】
在执行任何审查规则之前,你必须先判断:
当前分段是否包含与该审查规则相关的信息维度。
如果当前分段与该审查规则无关,则:
- 不得生成 finding
- 不得引用原文
- 直接返回 {"findings": []}。
【输出约束】
- findings (风险发现):必须证据充分、可执行。无原文引用不生成 finding
- original_text (原文证据):必须是合同原文的直接引用,严禁改写、概括或臆造
- 严格按照输出 JSON Schema 返回结果,不得输出任何解释性文字
- 严格按照指定 JSON Schema 输出
- 不得输出任何 JSON 之外的解释性文字
- 若未发现证据充分的问题,返回 {"findings": []}
'''
REVIEW_USER_PROMPT = '''
【当前分段文本】
......@@ -39,19 +82,25 @@ REVIEW_USER_PROMPT = '''
【审查规则】
{ruleset_text}
【指令】
执行风险识别:基于规则,识别确定存在的风险,并给出直接落地的修改建议(不要使用“建议协商”等泛化词)。
【任务】
请基于审查规则,审查当前分段,识别证据充分的问题,并输出可执行修改建议。
【特别要求】
- 仅输出证据充分的问题。
- 如果问题来自与上下文记忆的冲突,必须确保冲突是明确、可由文本直接支持的。
- findings 中的 original_text 必须为合同原文直接引用。
- suggestion 应尽量提供可直接落地的修改文本;若无法安全地直接改写,请给出明确的修改方向和应补充的关键要素。
- 若无问题,返回 {{"findings": []}}。
【输出要求】
- 仅输出 JSON 格式。
- findings 字段中必须包含原文引用和具体的修改建议。
- 仅输出 JSON。
'''
REVIEW_OUTPUT_SCHEMA = '''
```json
{
"findings": [
{
"issue": "详细的风险描述",
"issue": "详细的风险描述,为什么该问题构成风险,需基于规则和文本解释",
"original_text": "合同原文片段的直接引用",
"suggestion": "可直接替换原文或新增的条款措辞"
}
......@@ -128,14 +177,21 @@ class SegmentReviewTool(LLMTool):
"segment_id": {"type": "int"},
"segment_text": {"type": "string"},
"ruleset_id": {"type": "string"},
"routed_rule_titles": {"type": "array", "items": {"type": "string"}},
"party_role": {"type": "string"},
"context_memories": {"type": "array"},
},
"required": ["segment_id", "segment_text", "ruleset_id", "party_role"],
}
)
def run(self, segment_id: str, segment_text: str, ruleset_id: str, party_role: str, context_memories: Optional[List[Dict]] = None) -> Dict:
rules = self.rulesets.get(ruleset_id) or self.rulesets.get(self.default_ruleset_id, []) or []
def run(self, segment_id: str, segment_text: str, ruleset_id: str, party_role: str,
routed_rule_titles: Optional[List[str]] = None, context_memories: Optional[List[Dict]] = None) -> Dict:
full_rules = self.rulesets.get(ruleset_id) or self.rulesets.get(self.default_ruleset_id, []) or []
if routed_rule_titles is not None:
title_set = {title for title in routed_rule_titles if isinstance(title, str)}
rules = [r for r in full_rules if r.get("title") in title_set]
else:
rules = full_rules
result = self._evaluate_rules(party_role,segment_id,segment_text,rules, context_memories)
overall = "revise" if (result["findings"] ) else "pass"
......
from __future__ import annotations
import json
import re
from pathlib import Path
from typing import Dict, List, Optional
from core.config import ALL_RULESET_IDS, DEFAULT_RULESET_ID
from core.tool import tool, tool_func
from core.tools.segment_llm import LLMTool
from utils.excel_util import ExcelUtil
ROUTER_SYSTEM_PROMPT = '''
你是合同分段规则路由智能体(SegmentRuleRouter)。
你的任务是:基于“当前分段文本”,从候选审查规则中选出“应执行审查”的规则项。
【路由目标】
- 仅做规则适配判断,不输出风险结论、不输出审查建议。
- 高召回优先:只要当前分段与规则存在明确相关性,就应路由命中。
- 若候选规则明显无关,则不要命中。
【判断依据】
- 以当前分段文本为主。
- 可参考上下文记忆辅助理解术语,但不得脱离当前分段文本做臆断。
【输出约束】
- 严格输出 JSON。
- 每个命中规则需给出简短 reason,说明该分段为何与规则相关。
- 若确实没有任何相关规则,返回 {"selected_items": []}。
'''
# User-turn prompt for the rule router. Placeholders are filled via
# str.format in _route_rules. Fix: the "候选审查规则" section header was
# missing its closing "】" bracket.
ROUTER_USER_PROMPT = '''
【当前分段文本】
{segment_text}
【上下文记忆】
{context_memories_json}
【合同立场】
{party_role}
【候选审查规则】
{candidate_rules_json}
【任务】
请从候选规则中选择当前分段应执行的审查项,并输出 selected_items。
'''
ROUTER_OUTPUT_SCHEMA = '''
```json
{
"selected_items": [
{
"title": "规则标题",
"reason": "命中原因(简短)"
}
]
}
```
'''
@tool("segment_rule_router", "分段规则路由")
class SegmentRuleRouterTool(LLMTool):
    """Route a contract segment to the subset of review rules worth running.

    Loads every ruleset sheet from data/rules.xlsx at construction time,
    asks the LLM which candidate rules are relevant to the segment, and
    falls back to trigger-word matching (and finally to the full ruleset)
    when the LLM selects nothing usable.
    """

    def __init__(self) -> None:
        super().__init__(ROUTER_SYSTEM_PROMPT)
        self.default_ruleset_id = DEFAULT_RULESET_ID
        # Rule-dict field -> Excel column header mapping used by the loader.
        self.column_map = {
            "id": "ID",
            "title": "审查项",
            "rule": "审查规则",
            "level": "风险等级",
            "triggers": "触发词",
            "suggestion_template": "建议模板",
        }
        rules_path = Path(__file__).resolve().parent.parent.parent / "data" / "rules.xlsx"
        self.rulesets: Dict[str, List[Dict]] = {}
        # One worksheet per ruleset id; a missing sheet raises at startup,
        # which fails fast on a broken rules workbook.
        for rs_id in ALL_RULESET_IDS:
            rules = ExcelUtil.load_mapped_excel(rules_path, sheet_name=rs_id, column_map=self.column_map)
            self.rulesets[rs_id] = rules

    @tool_func(
        {
            "type": "object",
            "properties": {
                # NOTE(review): "int" is not a standard JSON-Schema type
                # ("integer" is) — kept to match the sibling tools' schemas;
                # confirm the schema consumer accepts it.
                "segment_id": {"type": "int"},
                "segment_text": {"type": "string"},
                "ruleset_id": {"type": "string"},
                "party_role": {"type": "string"},
                "context_memories": {"type": "array"},
            },
            "required": ["segment_id", "segment_text", "ruleset_id", "party_role"],
        }
    )
    def run(
        self,
        segment_id: int,
        segment_text: str,
        ruleset_id: str,
        party_role: str,
        context_memories: Optional[List[Dict]] = None,
    ) -> Dict:
        """Route *segment_text* against the ruleset named by *ruleset_id*.

        Returns a dict with segment_id, ruleset_id, the routed rule dicts,
        and their titles (used downstream by the segment review tool).
        """
        # Unknown ruleset ids fall back to the default ruleset.
        rules = self.rulesets.get(ruleset_id) or self.rulesets.get(self.default_ruleset_id, []) or []
        routed_rules = self._route_rules(
            segment_text=segment_text,
            rules=rules,
            party_role=party_role,
            context_memories=context_memories,
        )
        return {
            "segment_id": segment_id,
            "ruleset_id": ruleset_id,
            "routed_rules": routed_rules,
            "routed_rule_titles": [r.get("title", "") for r in routed_rules],
        }

    def _build_candidate_rules(self, rules: List[Dict]) -> List[Dict]:
        # Slim projection of each rule shown to the LLM; rows without a
        # title are dropped because the title is the join key later on.
        return [
            {
                "title": r.get("title", ""),
                "level": r.get("level", ""),
                "rule": r.get("rule", ""),
                "triggers": r.get("triggers", ""),
            }
            for r in rules
            if r.get("title")
        ]

    def _route_rules(
        self,
        segment_text: str,
        rules: List[Dict],
        party_role: str,
        context_memories: Optional[List[Dict]],
    ) -> List[Dict]:
        """Ask the LLM which rules apply; fall back to trigger matching."""
        if not rules:
            return []
        candidates = self._build_candidate_rules(rules)
        user_content = ROUTER_USER_PROMPT.format(
            segment_text=segment_text,
            context_memories_json=json.dumps(context_memories or [], ensure_ascii=False),
            party_role=party_role,
            candidate_rules_json=json.dumps(candidates, ensure_ascii=False),
        ) + ROUTER_OUTPUT_SCHEMA
        llm_selected: List[Dict] = []
        # Best effort: any LLM or parse failure degrades to the fallback path.
        try:
            resp = self.run_with_loop(self.chat_async(self.build_messages(user_content)))
            data = self.parse_first_json(resp)
            llm_selected = data.get("selected_items", []) or []
        except Exception:
            llm_selected = []
        selected_titles = {str(item.get("title", "")).strip() for item in llm_selected if item.get("title")}
        selected_reasons = {
            str(item.get("title", "")).strip(): str(item.get("reason", "")).strip()
            for item in llm_selected
            if item.get("title")
        }
        if not selected_titles:
            return self._fallback_route(segment_text=segment_text, rules=rules)
        title_to_rule = {str(r.get("title", "")).strip(): r for r in rules if r.get("title")}
        routed_rules: List[Dict] = []
        # NOTE(review): iterating a set makes the output order
        # non-deterministic across runs — confirm downstream consumers do
        # not depend on rule order.
        for title in selected_titles:
            rule = title_to_rule.get(title)
            if not rule:
                continue
            routed_rules.append(
                {
                    "id": rule.get("id", ""),
                    "title": rule.get("title", ""),
                    "level": rule.get("level", ""),
                    "rule": rule.get("rule", ""),
                    "triggers": rule.get("triggers", ""),
                    "reason": selected_reasons.get(title, ""),
                }
            )
        # LLM titles that match nothing loaded -> fall back rather than
        # silently returning an empty routing.
        return routed_rules or self._fallback_route(segment_text=segment_text, rules=rules)

    def _fallback_route(self, segment_text: str, rules: List[Dict]) -> List[Dict]:
        """Trigger-word routing used when the LLM selection is unusable."""
        text = segment_text or ""
        routed: List[Dict] = []
        for r in rules:
            triggers = self._parse_triggers(str(r.get("triggers", "")))
            if triggers and any(t in text for t in triggers):
                routed.append(
                    {
                        "id": r.get("id", ""),
                        "title": r.get("title", ""),
                        "level": r.get("level", ""),
                        "rule": r.get("rule", ""),
                        "triggers": r.get("triggers", ""),
                        "reason": "fallback: trigger matched",
                    }
                )
        # Last resort: if no trigger hit either, return every rule so recall
        # is preserved (never silently skip review of a segment).
        if not routed:
            for r in rules:
                routed.append(
                    {
                        "id": r.get("id", ""),
                        "title": r.get("title", ""),
                        "level": r.get("level", ""),
                        "rule": r.get("rule", ""),
                        "triggers": r.get("triggers", ""),
                        "reason": "fallback: conservative full recall",
                    }
                )
        return routed

    def _parse_triggers(self, trigger_text: str) -> List[str]:
        # Split on common ASCII/CJK delimiters and whitespace.
        parts = re.split(r"[,,、;;\s/|]+", trigger_text or "")
        return [p.strip() for p in parts if p.strip()]
if __name__ == "__main__":
    # Smoke test: route one payment clause against the generic ruleset.
    # Fix: the local variable was named `tool`, shadowing the imported
    # @tool decorator in this module; renamed to `router`.
    router = SegmentRuleRouterTool()
    demo_segment_text = (
        "甲方应在合同签订后5个工作日内向乙方支付合同总价30%作为预付款,"
        "剩余70%在乙方完成交付并经甲方验收合格后30日内支付。"
        "若甲方逾期付款,每逾期一日按应付未付金额的0.05%支付违约金。"
    )
    result = router.run(
        segment_id=1,
        segment_text=demo_segment_text,
        ruleset_id="通用",
        party_role="甲方",
        context_memories=[],
    )
    print(json.dumps(result, ensure_ascii=False, indent=2))
......@@ -7,22 +7,55 @@ from typing import Dict, List, Optional
from core.tool import tool, tool_func
from core.tools.segment_llm import LLMTool
from core.config import FACT_DIMENSIONS
FACT_DIMENSIONS: List[str] = ["当事人", "标的", "金额", "支付", "交付", "质量", "知识产权", "保密", "违约责任", "争议解决"]
SUMMARY_SYSTEM_PROMPT = f'''
你是合同事实提取智能体(SegmentSummary)。
仅输出“本分段的客观事实”,不做风险判断,不做主观推测。
你的任务是从当前合同分段中提取“客观事实”,并按指定维度结构化输出。
【事实定义】
事实必须满足:
1. 可以在当前分段原文中直接找到对应表述;
2. 不得对原文进行抽象、概括或推断;
3. 不得补充未出现的主体、条件或数值;
4. 允许对原文做最小结构化拆分(例如金额、比例、期限)。
【输出结构】
- facts: 一个对象,键为预设维度,值为该分段出现的事实(未出现的维度可缺省或置空)。
- 维度列表:{", ".join(FACT_DIMENSIONS)}。
- 若原文包含多个事实,可使用列表或子对象表达,但保持紧凑、可读。
- 输出字段:facts
- facts 是一个对象
- 键为以下预设维度:
{", ".join(FACT_DIMENSIONS)}
- 每个维度值必须是对象或对象列表
- 未出现的维度可以省略
【结构规则】
- 仅提取对合同履行或责任具有实际意义的事实
- 不得输出字符串作为维度值,必须使用对象
- 不得输出解释、总结或风险判断
【上下文事实使用规则】
上下文事实仅用于:
- 避免重复提取已存在的事实
- 保持字段命名一致
不得:
- 使用上下文事实补充当前分段没有出现的信息
- 修改当前分段原文事实
【约束】
- 严禁编造或改写原文未出现的信息。
- 不输出与事实无关的解释或额外文字。
- 严禁编造信息
- 严禁推断未出现的内容
- 严格输出 JSON
'''
SUMMARY_USER_PROMPT = '''
......@@ -32,8 +65,11 @@ SUMMARY_USER_PROMPT = '''
【上下文事实】
{context_facts}
请提取本段出现的客观事实,按照指定维度输出 JSON。未出现的维度可省略。
输出示例:'''
仅提取当前分段中明确出现的客观事实。
不得从上下文事实中补充新的信息。
输出 JSON。
'''
OUTPUT_EXAMPLE = '''
```json
......
......@@ -3,7 +3,7 @@ import os
import re
import sys
sys.path.append('..')
sys.path.append('../..')
import traceback
import concurrent.futures
......@@ -12,10 +12,29 @@ from loguru import logger
from utils.common_util import random_str
from utils.http_util import upload_file, fastgpt_openai_chat, download_file
SUFFIX='_麓发改进'
batch_input_dir_path = 'jp-input'
batch_output_dir_path = 'jp-output-lufa'
batch_size = 5
# 麓发fastgpt接口
# url = 'http://192.168.252.71:18089/api/v1/chat/completions'
# 金盘fastgpt接口
url = 'http://192.168.252.71:18088/api/v1/chat/completions'
# 麓发合同审查生产token
# token = 'fastgpt-ek3Z6PxI6sXgYc0jxzZ5bVGqrxwM6aVyfSmA6JVErJYBMr2KmYxrHwEUOIMSYz'
# 金盘迁移麓发合同审查测试token
token = 'fastgpt-vykT6qs07g7hR4tL2MNJE6DdNCIxaQjEu3Cxw9nuTBFg8MAG3CkByvnXKxSNEyMK7'
# 人机交互测试(生产环境)
# token = 'fastgpt-ry4jIjgNwmNgufMr5jR0ncvJVmSS4GZl4bx2ItsNPoncdQzW9Na3IP1Xrankr'
# 提取后审查测试
# token = 'fastgpt-n74gGX5ZqLT6o1ysMBSGUTjIciswYOWDRfQ75krMkE5gDVDkpzsbz8u'
def extract_url(text):
# \s * ([ ^ "\s]+?\.(?:docx?|pdf|xlsx))
excel_p, doc_p = r'导出Excel结果\s*([^"]*xlsx)', r'导出Doc结果\s*([^\" ]+?\.(?:docx?|pdf|wps))'
# \s * ([ ^ "\s]+?\.(?:docx?|pdf|xlsx))
# 麓发正则
# excel_p, doc_p = r'导出Excel结果\s*([^"]*xlsx)', r'导出Doc结果\s*([^\" ]+?\.(?:docx?|pdf|wps))'
# 金盘正则
excel_p, doc_p = r'最终审查Excel\s*([^"]*xlsx)', r'最终审查批注\s*([^\" ]+?\.(?:docx?|pdf|wps))'
# 使用 re.search() 查找第一个匹配项
excel_m, doc_m = re.search(excel_p, text), re.search(doc_p, text)
if excel_m and doc_m:
......@@ -39,20 +58,17 @@ def process_single_file(file, batch_input_dir_path, batch_output_dir_path, count
# 源目标处理
original_file = f'{batch_input_dir_path}/{file}'
des_check_file = f'{batch_output_dir_path}/{file_name}.md'
des_excel_file = f'{batch_output_dir_path}/{file_name}.xlsx'
des_doc_file = f'{batch_output_dir_path}/{file_name}{ext_name}'
des_excel_file = f'{batch_output_dir_path}/{file_name}{SUFFIX}.xlsx'
des_doc_file = f'{batch_output_dir_path}/{file_name}{SUFFIX}{ext_name}'
try:
# 处理原文件
file_url = upload_file(original_file, input_url_to_inner=True).replace('218.77.58.8', '192.168.252.71')
model = 'Qwen2-72B-Instruct'
# url = 'http://218.77.58.8:8088/api/v1/chat/completions'
url = 'http://192.168.252.71:18089/api/v1/chat/completions'
# 合同审核Excel工作流处理
logger.info(' 第{}个文件,处理文件: {}'.format(counter, original_file))
# # 合同审查测试token
token = 'fastgpt-ek3Z6PxI6sXgYc0jxzZ5bVGqrxwM6aVyfSmA6JVErJYBMr2KmYxrHwEUOIMSYz'
result = fastgpt_openai_chat(url, token, model, random_str(), file_url, f'0304批处理任务-{file_name}', False)
result = fastgpt_openai_chat(url, token, model, random_str(), file_url, f'测试批处理任务-{file_name}', False)
excel_url, doc_url = extract_url(result)
if excel_url and doc_url:
download_file(excel_url.replace('218.77.58.8', '192.168.252.71'), des_excel_file)
......@@ -64,10 +80,9 @@ def process_single_file(file, batch_input_dir_path, batch_output_dir_path, count
def execute_batch(max_workers: int = 4):
batch_input_dir_path = 'input'
batch_output_dir_path = 'output'
start_file = 1
dirs = os.listdir(batch_input_dir_path)
os.makedirs(batch_output_dir_path, exist_ok=True)
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = [
......@@ -87,4 +102,4 @@ def execute_batch(max_workers: int = 4):
if __name__ == '__main__':
execute_batch(5)
\ No newline at end of file
execute_batch(batch_size)
\ No newline at end of file
import os

from spire.doc import *
from spire.doc.common import *
# Strip all review comments from each benchmark Word document, saving the
# cleaned copies into `clean/` so they can be re-reviewed from scratch.
# Fix: `os` was used without being imported, and the output directory was
# assumed to exist.
benchmark_path = '/home/ccran/contract_review_agent/benchmark'
datasets_path = f'{benchmark_path}/datasets'
clean_path = f'{benchmark_path}/clean'
os.makedirs(clean_path, exist_ok=True)  # ensure the output dir exists before saving
items = os.listdir(datasets_path)
for item in items:
    # Load the annotated document, drop every comment, save under clean/.
    doc = Document()
    doc.LoadFromFile(f"{datasets_path}/{item}")
    doc.Comments.Clear()
    doc.SaveToFile(f"{clean_path}/{item}")
    doc.Close()
import argparse
from pathlib import Path
import pandas as pd
from rapidfuzz import fuzz
from contextlib import redirect_stdout, redirect_stderr
fuzz_score_threshold = 80
def _normalize_cell(value: object) -> str:
if pd.isna(value):
return ""
return str(value).strip()
def _load_rows(path: Path) -> list[tuple[str, str]]:
    """Read the first two columns of *path* as normalized (item, text) pairs.

    Cells are cleaned via _normalize_cell; rows whose first two cells are
    both empty are dropped.
    """
    frame = pd.read_excel(path, dtype=str)
    if frame.empty:
        return []
    pair_cols = frame.iloc[:, :2].copy()
    # Per-column Series.map avoids the FutureWarning of DataFrame.applymap.
    for column in pair_cols.columns:
        pair_cols[column] = pair_cols[column].map(_normalize_cell)
    pair_cols = pair_cols.replace("", pd.NA).dropna(how="all")
    return [tuple(row) for row in pair_cols.itertuples(index=False, name=None)]
def _compare_impl(val_dir: Path, answer_dir: Path) -> None:
    """Compare model-output (val) xlsx files against answer xlsx files.

    For every *.xlsx in *val_dir* that has a same-named file in *answer_dir*,
    fuzzy-match answer rows to val rows within the same review item
    (审查项), consuming each val row at most once, print per-file and
    per-item statistics, and finally write an overall summary workbook
    under results/. Rows are (审查项, text) pairs from the first two
    columns of each sheet; a match requires the same item and a fuzzy
    score >= fuzz_score_threshold.
    """
    val_dir = val_dir.resolve()
    answer_dir = answer_dir.resolve()
    overall_val = overall_answer = overall_matched = 0
    # Global per-review-item (审查项) tallies accumulated across all files.
    overall_item_answer: dict[str, int] = {}
    overall_item_matched: dict[str, int] = {}
    overall_item_unmatched_answer: dict[str, int] = {}
    overall_item_unmatched_val: dict[str, int] = {}
    for val_file in sorted(val_dir.glob("*.xlsx")):
        answer_file = answer_dir / val_file.name
        if not answer_file.exists():
            print(f"Skip {val_file.name}: missing in answer")
            continue
        val_rows = _load_rows(val_file)
        answer_rows = _load_rows(answer_file)
        # Baseline: answer -> match val, consume val to keep 1-1, report leftover answers
        answer_counts: dict[str, int] = {}
        for item, _ in answer_rows:
            answer_counts[item] = answer_counts.get(item, 0) + 1
        # Bucket val texts by review item so matching stays item-scoped.
        val_buckets: dict[str, list[str]] = {}
        for item, text in val_rows:
            val_buckets.setdefault(item, []).append(text)
        matched_total = 0
        matched_by_item: dict[str, list[tuple[str, str, int]]] = {}
        unmatched_answer_by_item: dict[str, list[str]] = {}
        for item, ans_text in answer_rows:
            candidates = val_buckets.get(item, [])
            if not candidates:
                unmatched_answer_by_item.setdefault(item, []).append(ans_text)
                continue
            # Greedy best-match: keep the highest-scoring candidate.
            best_idx = -1
            best_score = -1
            for idx, cand in enumerate(candidates):
                ans_text = ans_text.strip()
                cand = cand.strip()
                score = max(
                    fuzz.partial_ratio(ans_text, cand),
                    fuzz.token_set_ratio(ans_text, cand)
                )
                if score > best_score:
                    best_score = score
                    best_idx = idx
            if best_score >= fuzz_score_threshold:
                matched_total += 1
                # pop() consumes the val row so it cannot match twice (1-1).
                matched_val = candidates.pop(best_idx)
                matched_by_item.setdefault(item, []).append((ans_text, matched_val, best_score))
            else:
                unmatched_answer_by_item.setdefault(item, []).append(ans_text)
        # remaining vals in buckets are unmatched
        unmatched_val_by_item: dict[str, list[str]] = {
            item: texts for item, texts in val_buckets.items() if texts
        }
        val_total = len(val_rows)
        answer_total = len(answer_rows)
        overall_val += val_total
        overall_answer += answer_total
        overall_matched += matched_total
        unmatched_val_count = sum(len(v) for v in unmatched_val_by_item.values())
        unmatched_answer_count = sum(len(v) for v in unmatched_answer_by_item.values())
        # Fold this file's per-item tallies into the global ones.
        for it, cnt in answer_counts.items():
            overall_item_answer[it] = overall_item_answer.get(it, 0) + cnt
        for it, lst in matched_by_item.items():
            overall_item_matched[it] = overall_item_matched.get(it, 0) + len(lst)
        for it, lst in unmatched_answer_by_item.items():
            overall_item_unmatched_answer[it] = overall_item_unmatched_answer.get(it, 0) + len(lst)
        for it, lst in unmatched_val_by_item.items():
            overall_item_unmatched_val[it] = overall_item_unmatched_val.get(it, 0) + len(lst)
        print('#' * 40)
        # NOTE(review): this raises ZeroDivisionError when answer_total == 0
        # (an answer sheet with no rows) — guard it like the val_total
        # division on the same line.
        print(
            f"{val_file.name}: matched {matched_total} | val {val_total} | answer {answer_total} "
            f"| unmatched val {unmatched_val_count} | unmatched answer {unmatched_answer_count} | accuracy {matched_total / answer_total:.2%} | invalid_val {(unmatched_val_count / val_total) if val_total != 0 else 0:.2%}"
        )
        for item in sorted(answer_counts):
            item_matches = matched_by_item.get(item, [])
            print(f" 审查项 {item}: matched {len(item_matches)} / {answer_counts[item]}")
            # Matched pairs (kept for debugging, intentionally disabled):
            # for ans_text, val_text, score in item_matches:
            #     print(f" {score}% | answer: {ans_text} | val: {val_text}")
            ua = unmatched_answer_by_item.get(item, [])
            if ua:
                print(f" 未匹配(answer 未被匹配){len(ua)} 条:")
                for t in ua:
                    print(f" answer: {t}")
            uv = unmatched_val_by_item.get(item, [])
            if uv:
                print(f" 未匹配(val 残留){len(uv)} 条:")
                for t in uv:
                    print(f" val: {t}")
        # break # only first file for demo
    accuracy = overall_matched / overall_answer if overall_answer else 0
    invalid_val = (overall_val - overall_matched) / overall_val if overall_val else 0
    print(
        f"Overall: matched {overall_matched} | val {overall_val} | answer {overall_answer} | accuracy {accuracy:.2%} | invalid_val {invalid_val:.2%}"
    )
    # Overall results grouped by review item; the Excel summary is only
    # written when at least one item was seen.
    if overall_item_answer:
        print('#' * 40)
        print("Overall by item:")
        all_items = sorted(set(list(overall_item_answer.keys()) + list(overall_item_matched.keys()) + list(overall_item_unmatched_answer.keys()) + list(overall_item_unmatched_val.keys())))
        rows_by_item = []
        for it in all_items:
            ans = overall_item_answer.get(it, 0)
            mat = overall_item_matched.get(it, 0)
            u_ans = overall_item_unmatched_answer.get(it, 0)
            u_val = overall_item_unmatched_val.get(it, 0)
            acc = (mat / ans) if ans else 0
            # Share of model findings for this item that matched nothing.
            invalid_val = u_val / (mat + u_val) if (mat + u_val) else 0
            rows_by_item.append({
                "审查项": it,
                "大模型匹配上的不合格项": mat,
                "合同所有不合格项": ans,
                "大模型其他不合格项": u_val,
                "大模型未匹配上的不合格项(C-B)": u_ans,
                "查全率(B/C)": acc,
                "无关审查率(D/B+D)": invalid_val,
            })
            print(
                f" 审查项 {it}: matched {mat} / answer {ans} | unmatched val {u_val} | unmatched answer {u_ans} | accuracy {acc:.2%} | invalid_val {invalid_val:.2%}"
            )
        overall_by_item_df = pd.DataFrame(rows_by_item, columns=["审查项", "大模型匹配上的不合格项", "合同所有不合格项", "大模型其他不合格项", "大模型未匹配上的不合格项(C-B)", "查全率(B/C)", "无关审查率(D/B+D)"])
        unmatched_val_total = sum(overall_item_unmatched_val.values())
        unmatched_answer_total = sum(overall_item_unmatched_answer.values())
        overall_invalid_rate = unmatched_val_total / (overall_matched + unmatched_val_total) if (overall_matched + unmatched_val_total) else 0
        overall_total_df = pd.DataFrame([
            {
                "审查项": "总体",
                "大模型匹配上的不合格项": overall_matched,
                "合同所有不合格项": overall_answer,
                "大模型其他不合格项": unmatched_val_total,
                "大模型未匹配上的不合格项(C-B)": unmatched_answer_total,
                "查全率(B/C)": accuracy,
                "无关审查率(D/B+D)": overall_invalid_rate,
            }
        ], columns=["审查项", "大模型匹配上的不合格项", "合同所有不合格项", "大模型其他不合格项", "大模型未匹配上的不合格项(C-B)", "查全率(B/C)", "无关审查率(D/B+D)"])
        combined_df = pd.concat([overall_by_item_df, overall_total_df], ignore_index=True)
        compare_dir_name = val_dir.name
        results_dir = Path(__file__).parent / "results"
        results_dir.mkdir(parents=True, exist_ok=True)
        output_excel = results_dir / f"合同审查结果_{compare_dir_name}.xlsx"
        with pd.ExcelWriter(output_excel, engine="openpyxl") as writer:
            combined_df.to_excel(writer, sheet_name="对比结果", index=False)
        print(f"Excel written to {output_excel}")
def compare(val_dir: Path, answer_dir: Path) -> None:
    """Public entry point: run the comparison, printing to stdout."""
    _compare_impl(val_dir, answer_dir)
def compare_with_log(val_dir: Path, answer_dir: Path, log_path: Path | None = None) -> Path:
    """Run the comparison with stdout/stderr redirected into a log file.

    When *log_path* is omitted, the log goes to
    results/合同审查结果_<val_dir_name>.log next to this script.
    Returns the log path actually used.
    """
    val_dir = val_dir.resolve()
    if log_path is not None:
        log_path = log_path.resolve()
        log_path.parent.mkdir(parents=True, exist_ok=True)
    else:
        default_dir = Path(__file__).parent / "results"
        default_dir.mkdir(parents=True, exist_ok=True)
        log_path = default_dir / f"合同审查结果_{val_dir.name}.log"
    log_file = open(log_path, "w", encoding="utf-8")
    with log_file, redirect_stdout(log_file), redirect_stderr(log_file):
        _compare_impl(val_dir=val_dir, answer_dir=answer_dir)
    return log_path
def _parse_args() -> argparse.Namespace:
base = Path(__file__).parent
parser = argparse.ArgumentParser(description="Compare extracted annotations with answers.")
parser.add_argument(
"--val-dir",
type=Path,
default=base / "batch_output_0121_val",
help="Directory containing extracted val xlsx files.",
)
parser.add_argument(
"--answer-dir",
type=Path,
default=base / "审查答案",
help="Directory containing answer xlsx files.",
)
parser.add_argument(
"--log-path",
type=Path,
default=None,
help="Optional explicit log path. Defaults to results/合同审查结果_<val_dir_name>.log",
)
return parser.parse_args()
if __name__ == "__main__":
    # CLI entry: compare val vs answer directories, teeing output to a log.
    cli_args = _parse_args()
    final_log_path = compare_with_log(
        val_dir=cli_args.val_dir,
        answer_dir=cli_args.answer_dir,
        log_path=cli_args.log_path,
    )
    print(f"Log written to {final_log_path}")
\ No newline at end of file
from __future__ import annotations
import argparse
import re
from pathlib import Path
from typing import Iterable
import pandas as pd
from spire.doc import Document
from compare_annotation import compare_with_log
# Map raw comment authors to unified review item names.
# Word comments carry several historical labels for the same review item;
# normalize_comment_author funnels them through this table so benchmark
# buckets line up with the answer sheets.
COMMENT_AUTHOR_MAPPING: dict[str, str] = {
    "三方货款审查":"第三方审查",
    "履行义务审查":"第三方审查",
    "违约条款审查":"违约与延期审查",
    "延期审查":"违约与延期审查"
}
def clean_illegal(value: object) -> object:
    """Strip ASCII control characters (illegal in xlsx cells) from strings.

    Non-string inputs are returned unchanged. Tab, LF and CR are kept.
    """
    if not isinstance(value, str):
        return value
    return re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F]", "", value)
def normalize_comment_author(author: str) -> str:
    """Trim the author label and map historical names to the unified one."""
    stripped = author.strip()
    if not stripped:
        return stripped
    return COMMENT_AUTHOR_MAPPING.get(stripped, stripped)
def extract_annotaion(
    datasets_dir: Path,
    output_dir: Path,
    strip_suffixes: list[str] | None = None,
) -> None:
    """Extract review comments from Word files to xlsx files.

    For every non-xlsx file in *datasets_dir*, read its Word comments via
    Spire.Doc and write one xlsx per document into *output_dir* with the
    columns 审查项 / 合同原文 / 建议. *strip_suffixes* lets callers drop a
    known filename suffix (e.g. "_人机交互") from the output stem.
    NOTE(review): the trailing "annotaion" typo is kept — callers in this
    repo import the function under this name.
    """
    datasets_dir = datasets_dir.resolve()
    output_dir = output_dir.resolve()
    output_dir.mkdir(parents=True, exist_ok=True)
    strip_suffixes = strip_suffixes or []
    for item in sorted(datasets_dir.iterdir()):
        # Skip previously generated spreadsheets and sub-directories.
        if item.suffix.lower() == ".xlsx" or not item.is_file():
            continue
        document = Document()
        document.LoadFromFile(str(item))
        comments: list[dict[str, str]] = []
        for i in range(document.Comments.Count):
            comment = document.Comments[i]
            # Concatenate all paragraphs of the comment body.
            comment_text = ""
            for j in range(comment.Body.Paragraphs.Count):
                paragraph = comment.Body.Paragraphs[j]
                comment_text += paragraph.Text + "\n"
            comment_author = comment.Format.Author
            # Authors may be prefixed like "prefix|name"; compare only the
            # part after the "|" separator.
            author_split_idx = comment_author.find("|")
            comment_author = (
                comment_author[author_split_idx + 1 :]
                if author_split_idx != -1
                else comment_author
            )
            comment_author = normalize_comment_author(comment_author)
            comments.append(
                {
                    "审查项": clean_illegal(comment_author),
                    "合同原文": clean_illegal(comment.OwnerParagraph.Text),
                    "建议": clean_illegal(comment_text),
                }
            )
        df = pd.DataFrame(comments)
        clean_stem = _strip_suffix_once(item.stem, strip_suffixes)
        output_stem = clean_stem or item.stem
        output_file = output_dir / f"{output_stem}.xlsx"
        df.to_excel(output_file, index=False)
        document.Close()
def compare_annotaion(val_dir: Path, answer_dir: Path) -> None:
    """Run benchmark comparison on extracted annotations."""
    written = compare_with_log(val_dir=val_dir, answer_dir=answer_dir)
    print(f"Compare log written to: {written}")
def _strip_suffix_once(stem: str, suffixes: Iterable[str]) -> str:
for suffix in suffixes:
if suffix and stem.endswith(suffix):
return stem[: -len(suffix)]
return stem
def eval(
    datasets_dir: Path,
    answer_dir: Path,
    val_dir: Path,
    strip_suffixes: list[str] | None = None,
) -> None:
    """Pipeline: extract annotations first, then compare against ground truth.

    NOTE(review): the name shadows the builtin ``eval`` inside this module;
    kept unchanged for backward compatibility with existing callers.
    """
    extract_annotaion(
        datasets_dir=datasets_dir,
        output_dir=val_dir,
        strip_suffixes=strip_suffixes or [],
    )
    compare_annotaion(val_dir=val_dir, answer_dir=answer_dir)
def _parse_args() -> argparse.Namespace:
    """Define and parse the CLI options for the extract-then-compare pipeline."""
    script_dir = Path(__file__).parent
    parser = argparse.ArgumentParser(
        description="Extract review comments from docs and evaluate against answers."
    )
    parser.add_argument(
        "--datasets-dir",
        type=Path,
        default=script_dir / "results" / "jp-output",
        help="Directory containing Word files with annotations.",
    )
    parser.add_argument(
        "--answer-dir",
        type=Path,
        default=script_dir / "审查答案",
        help="Directory containing labeled answer xlsx files.",
    )
    parser.add_argument(
        "--val-dir",
        type=Path,
        default=script_dir / "results" / "jp-output-extracted",
        help="Directory to store extracted xlsx files for comparison.",
    )
    parser.add_argument(
        "--strip-suffixes",
        nargs="*",
        default=['_人机交互'],
        help=(
            "Optional filename suffixes to strip from generated val xlsx stems before "
            "comparison, e.g. --strip-suffixes _v1 _审阅版"
        ),
    )
    return parser.parse_args()
if __name__ == "__main__":
    # CLI entry point: extract comments from the dataset docs into val_dir,
    # then run the benchmark comparison against the labelled answers.
    args = _parse_args()
    eval(
        datasets_dir=args.datasets_dir,
        answer_dir=args.answer_dir,
        val_dir=args.val_dir,
        strip_suffixes=args.strip_suffixes,
    )
from __future__ import annotations
import argparse
import sys
from pathlib import Path
PROJECT_ROOT = Path(__file__).resolve().parents[2]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from utils.excel_util import ExcelUtil
from utils.spire_word_util import SpireWordDoc
# File extensions treated as Excel comment sources.
EXCEL_SUFFIXES = {".xlsx", ".xlsm"}
# File extensions treated as Word documents to annotate.
WORD_SUFFIXES = {".doc", ".docx"}
# Lower-cased cell values that identify a header row in the comment sheets;
# mixes the English column ids with their Chinese counterparts.
HEADER_KEYWORDS = {
    "key_points",
    "chunk_id",
    "original_text",
    "details",
    "suggest",
    "建议",
    "详情",
    "审查项",
    "原文",
    "合同原文",
    "文件块序号",
}
def _normalize_cell(value: object) -> str:
if value is None:
return ""
if isinstance(value, float) and value.is_integer():
return str(int(value))
return str(value).strip()
def _looks_like_header(row: list[object]) -> bool:
    """Return True when the first six cells of *row* contain a known header label."""
    labels = set()
    for cell in row[:6]:
        text = _normalize_cell(cell)
        if text:
            labels.add(text.lower())
    # Any overlap with the known header vocabulary marks this as a header row.
    return not labels.isdisjoint(HEADER_KEYWORDS)
def _parse_chunk_id(raw_value: object, *, chunk_id_base: int, row_number: int, excel_path: Path) -> int:
    """Parse a cell value into a zero-based chunk index.

    Raises ValueError when the cell is empty, not numeric, or falls below
    zero after subtracting *chunk_id_base*.
    """
    text = _normalize_cell(raw_value)
    if not text:
        raise ValueError(f"{excel_path.name} 第 {row_number} 行缺少 chunk_id")
    try:
        # Go through float first so "3.0"-style cells still parse.
        parsed = int(float(text))
    except ValueError as exc:
        raise ValueError(
            f"{excel_path.name} 第 {row_number} 行 chunk_id 无法解析为整数: {text}"
        ) from exc
    zero_based = parsed - chunk_id_base
    if zero_based < 0:
        raise ValueError(
            f"{excel_path.name} 第 {row_number} 行 chunk_id 越界: {parsed}"
        )
    return zero_based
def load_comments_from_excel(
    excel_path: Path,
    *,
    sheet_name: str | None = None,
    chunk_id_base: int = 1,
    skip_header: bool = True,
) -> list[dict[str, object]]:
    """Read review comments from one Excel sheet into review-record dicts.

    Expected column layout (0-based): 0 key_points, 1 chunk_id,
    2 original_text, 3 details, 5 suggest. Blank rows are skipped, as is the
    first row when *skip_header* is set and it looks like a header.
    chunk_id values are normalised to zero-based via *chunk_id_base*.
    """
    raw_rows = ExcelUtil.load_excel(excel_path, sheet_name=sheet_name, has_header=False)
    records: list[dict[str, object]] = []
    for row_no, raw in enumerate(raw_rows, start=1):
        if not isinstance(raw, list):
            continue
        cells = list(raw)
        if not any(_normalize_cell(cell) for cell in cells):
            continue  # fully blank row
        if skip_header and row_no == 1 and _looks_like_header(cells):
            continue
        # Pad short rows so the fixed column positions below are always valid.
        if len(cells) < 6:
            cells = cells + [None] * (6 - len(cells))
        key_points = _normalize_cell(cells[0])
        original_text = _normalize_cell(cells[2])
        details = _normalize_cell(cells[3])
        suggest = _normalize_cell(cells[5])
        if not (key_points or original_text or suggest):
            continue
        chunk_id = _parse_chunk_id(
            cells[1],
            chunk_id_base=chunk_id_base,
            row_number=row_no,
            excel_path=excel_path,
        )
        records.append(
            {
                "id": f"{excel_path.stem}-{row_no}",
                "key_points": key_points,
                "chunk_id": chunk_id,
                "original_text": original_text,
                "details": details,
                "result": "不合格",
                "suggest": suggest,
            }
        )
    return records
def _build_doc_index(doc_dir: Path) -> dict[str, Path]:
    """Map file stem -> Word document in *doc_dir* (first match in sorted order wins)."""
    index: dict[str, Path] = {}
    for candidate in sorted(doc_dir.iterdir()):
        if not candidate.is_file():
            continue
        if candidate.suffix.lower() not in WORD_SUFFIXES:
            continue
        if candidate.stem not in index:
            index[candidate.stem] = candidate
    return index
def _match_doc_path(excel_path: Path, doc_index: dict[str, Path]) -> Path | None:
return doc_index.get(excel_path.stem)
def _close_spire_doc(doc: SpireWordDoc) -> None:
    """Best-effort close of the wrapped Spire document.

    Reaches into the wrapper's private ``_doc`` attribute; any failure while
    closing is swallowed so cleanup never masks the caller's result.
    """
    inner = getattr(doc, "_doc", None)
    if inner is not None:
        try:
            inner.Close()
        except Exception:
            pass
def clear_comments(doc: SpireWordDoc) -> None:
    """Remove every existing comment from the wrapped document, ignoring failures.

    A clear failure is non-fatal: annotation can still proceed with the old
    comments left in place.
    """
    inner = getattr(doc, "_doc", None)
    if inner is None:
        return
    try:
        inner.Comments.Clear()
    except Exception:
        pass
def annotate_doc_from_excel(
    excel_path: Path,
    doc_path: Path,
    output_path: Path,
    *,
    sheet_name: str | None = None,
    chunk_id_base: int = 1,
    skip_header: bool = True,
) -> int:
    """Write the comments listed in *excel_path* into *doc_path* and save to *output_path*.

    Existing comments in the document are cleared first, then the Excel
    comments are grouped by chunk and inserted chunk by chunk.

    Returns:
        The number of comments loaded from the Excel file.
    Raises:
        ValueError: when a comment's chunk_id exceeds the document's chunk count.
    """
    comments = load_comments_from_excel(
        excel_path,
        sheet_name=sheet_name,
        chunk_id_base=chunk_id_base,
        skip_header=skip_header,
    )
    doc = SpireWordDoc()
    try:
        doc.load(str(doc_path))
        # Start from a clean slate so re-runs don't accumulate duplicates.
        clear_comments(doc)
        grouped_comments: dict[int, list[dict[str, object]]] = {}
        for comment in comments:
            chunk_id = int(comment["chunk_id"])
            if chunk_id >= doc.get_chunk_num():
                # Report the id in the caller's original numbering scheme.
                raise ValueError(
                    f"{excel_path.name} 中的 chunk_id {chunk_id + chunk_id_base} 超出文档块数量 {doc.get_chunk_num()}"
                )
            grouped_comments.setdefault(chunk_id, []).append(comment)
        for chunk_id, chunk_comments in grouped_comments.items():
            doc.add_chunk_comment(chunk_id, chunk_comments)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        doc.to_file(str(output_path), remove_prefix=True)
        return len(comments)
    finally:
        # Always release the Spire document, even when validation above raised.
        _close_spire_doc(doc)
def batch_annotate_docs(
    excel_dir: Path,
    output_dir: Path,
    *,
    sheet_name: str | None = None,
    chunk_id_base: int = 1,
    skip_header: bool = True,
) -> None:
    """Annotate every Word file in *excel_dir* from its same-stem Excel sibling.

    Both the Excel comment files and the Word documents are expected to live
    in *excel_dir*; pairing is by identical file stem. Annotated copies are
    written to *output_dir* under the Word file's original name. Excel files
    without a matching Word document are skipped with a message. A summary
    line is printed at the end.
    """
    excel_dir = excel_dir.resolve()
    output_dir = output_dir.resolve()
    output_dir.mkdir(parents=True, exist_ok=True)
    # NOTE(review): the Word index is built from excel_dir too — docs and
    # spreadsheets are assumed co-located; confirm with the batch layout.
    doc_index = _build_doc_index(excel_dir)
    processed_files = 0
    skipped_files = 0
    total_comments = 0
    for excel_path in sorted(excel_dir.iterdir()):
        if not excel_path.is_file() or excel_path.suffix.lower() not in EXCEL_SUFFIXES:
            continue
        doc_path = _match_doc_path(excel_path, doc_index)
        if doc_path is None:
            skipped_files += 1
            print(f"Skip {excel_path.name}: 未找到同名 Word 文件")
            continue
        output_path = output_dir / doc_path.name
        comment_count = annotate_doc_from_excel(
            excel_path=excel_path,
            doc_path=doc_path,
            output_path=output_path,
            sheet_name=sheet_name,
            chunk_id_base=chunk_id_base,
            skip_header=skip_header,
        )
        processed_files += 1
        total_comments += comment_count
        print(
            f"Processed {excel_path.name} -> {output_path.name} | comments: {comment_count}"
        )
    print(
        f"Done: processed {processed_files} files, skipped {skipped_files} files, total comments {total_comments}"
    )
def _parse_args() -> argparse.Namespace:
    """Define and parse the CLI options for the batch re-annotation script."""
    script_dir = Path(__file__).parent
    parser = argparse.ArgumentParser(
        description="Read Excel comments in bulk and write them into Word documents."
    )
    parser.add_argument(
        "--excel-dir",
        type=Path,
        required=True,
        help="Directory containing Excel files.",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=script_dir / "results" / "re_comment_output",
        help="Directory to write annotated Word files.",
    )
    parser.add_argument(
        "--chunk-id-base",
        type=int,
        choices=(0, 1),
        default=0,
        help="Whether chunk_id in Excel is 0-based or 1-based. Defaults to 0.",
    )
    parser.add_argument(
        "--no-skip-header",
        action="store_true",
        help="Do not auto-skip the first row when it looks like a header.",
    )
    return parser.parse_args()
if __name__ == "__main__":
    # CLI entry point: pair Excel comment files with same-stem Word documents
    # and write the annotated copies to the output directory.
    args = _parse_args()
    batch_annotate_docs(
        excel_dir=args.excel_dir,
        output_dir=args.output_dir,
        chunk_id_base=args.chunk_id_base,
        skip_header=not args.no_skip_header,
    )
No preview for this file type
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional
from uuid import uuid4
import ast
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
import uvicorn
import traceback
from loguru import logger
from utils.common_util import extract_url_file, format_now
from utils.http_util import download_file
......@@ -14,6 +16,7 @@ from core.cache import get_cached_doc_tool, get_cached_memory
from core.config import doc_support_formats, pdf_support_formats
from core.tools.segment_summary import SegmentSummaryTool
from core.tools.segment_review import SegmentReviewTool
from core.tools.segment_rule_router import SegmentRuleRouterTool
from core.tools.reflect_retry import ReflectRetryTool
from core.memory import RiskFinding
......@@ -22,6 +25,7 @@ TMP_DIR = Path(__file__).resolve().parent / "tmp"
TMP_DIR.mkdir(parents=True, exist_ok=True)
summary_tool = SegmentSummaryTool()
review_tool = SegmentReviewTool()
rule_router_tool = SegmentRuleRouterTool()
reflect_tool = ReflectRetryTool()
......@@ -73,7 +77,7 @@ async def parse_document(payload: DocumentParseRequest) -> DocumentParseResponse
doc_obj.load(file_path)
# ocr
await doc_obj.get_from_ocr()
# text = doc_obj.get_all_text()
text = doc_obj.get_all_text()
chunk_ids = doc_obj.get_chunk_id_list()
# get ruleset items
ruleset_id = payload.ruleset_id or review_tool.default_ruleset_id
......@@ -81,7 +85,7 @@ async def parse_document(payload: DocumentParseRequest) -> DocumentParseResponse
ruleset_review_items = [r.get('title') for r in ruleset_items]
return DocumentParseResponse(
conversation_id=payload.conversation_id,
# text=text,
text=text,
chunk_ids=chunk_ids,
ruleset_items=ruleset_review_items,
file_ext = file_ext
......@@ -116,12 +120,11 @@ def summarize_facts(payload: SegmentSummaryRequest) -> SegmentSummaryResponse:
segment_text = doc_obj.get_chunk_item(chunk_idx)
except Exception as exc:
raise HTTPException(status_code=404, detail=f"Segment text not found for id {payload.segment_id}: {exc}. Please parse document first.")
result = summary_tool.run(
segment_id=payload.segment_id,
segment_text=segment_text,
party_role=payload.party_role or "",
context_facts=payload.context_facts or store.get_facts(),
context_facts=payload.context_facts,
)
store.add_facts(result)
......@@ -138,6 +141,7 @@ class SegmentReviewRequest(BaseModel):
segment_id: int
party_role: Optional[str] = ""
ruleset_id: Optional[str] = "通用"
routed_rule_titles: Optional[List[str]] = None
file_ext: str
context_memories: Optional[List[Dict]] = None
......@@ -148,6 +152,14 @@ class SegmentReviewResponse(BaseModel):
overall_conclusion: str
findings: List[Dict]
# Response payload of POST /segments/review/rule-router.
class SegmentRuleRouterResponse(BaseModel):
    conversation_id: str
    segment_id: int
    # The ruleset that the routing was performed against.
    ruleset_id: str
    # Titles of the rules selected as relevant to this segment.
    routed_rule_titles: List[str]
    # Full rule dicts corresponding to the titles above.
    routed_rules: List[Dict]
@app.post("/segments/review/findings", response_model=SegmentReviewResponse)
def review_segment(payload: SegmentReviewRequest) -> SegmentReviewResponse:
store = get_cached_memory(payload.conversation_id)
......@@ -166,8 +178,10 @@ def review_segment(payload: SegmentReviewRequest) -> SegmentReviewResponse:
segment_id=payload.segment_id,
segment_text=segment_text,
ruleset_id=payload.ruleset_id or "通用",
routed_rule_titles=payload.routed_rule_titles,
party_role=payload.party_role or "",
context_memories=payload.context_memories or store.get_facts(),
# TODO 获取与当前审查相关的上下文记忆(如之前的审查结果、总结事实等),而非全部记忆
context_memories=payload.context_memories,
)
# Persist findings to memory store
......@@ -181,9 +195,9 @@ def review_segment(payload: SegmentReviewRequest) -> SegmentReviewResponse:
"risk_level": (f.get("risk_level") or f.get("level") or "").upper(),
"suggestion": f.get("suggestion", ""),
})
except Exception:
except Exception as e:
logger.error(e)
continue
return SegmentReviewResponse(
conversation_id=payload.conversation_id,
segment_id=payload.segment_id,
......@@ -191,6 +205,37 @@ def review_segment(payload: SegmentReviewRequest) -> SegmentReviewResponse:
findings=result.get("findings", []),
)
@app.post("/segments/review/rule-router", response_model=SegmentRuleRouterResponse)
def route_segment_rules(payload: SegmentReviewRequest) -> SegmentRuleRouterResponse:
    """Select which review rules apply to one document segment.

    Requires the document to have been parsed already (the cached doc tool
    and its chunks must exist for this conversation); otherwise responds
    400/404 respectively.
    """
    try:
        doc_obj, _ = get_cached_doc_tool(payload.conversation_id, payload.file_ext)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Document tool not available: {exc}")
    # segment_id is 1-based in the API; chunk storage is 0-based.
    chunk_idx = payload.segment_id - 1
    try:
        segment_text = doc_obj.get_chunk_item(chunk_idx)
    except Exception as exc:
        raise HTTPException(status_code=404, detail=f"Segment text not found for id {payload.segment_id}: {exc}. Please parse document first.")
    ruleset_id = payload.ruleset_id or review_tool.default_ruleset_id
    result = rule_router_tool.run(
        segment_id=payload.segment_id,
        segment_text=segment_text,
        ruleset_id=ruleset_id,
        party_role=payload.party_role or "",
        context_memories=payload.context_memories,
    )
    return SegmentRuleRouterResponse(
        conversation_id=payload.conversation_id,
        segment_id=payload.segment_id,
        ruleset_id=ruleset_id,
        routed_rule_titles=result.get("routed_rule_titles", []),
        routed_rules=result.get("routed_rules", []),
    )
########################################################################################################################
class ReflectReviewRequest(BaseModel):
conversation_id: str
......@@ -212,7 +257,10 @@ def reflect_review(payload: ReflectReviewRequest) -> ReflectReviewResponse:
rule = next((r for r in ruleset_items if r.get("title") == payload.rule_title), None)
if not rule:
raise HTTPException(status_code=404, detail=f"Rule not found: {payload.rule_title}")
facts = store.get_facts()
# TODO 获取与当前审查规则相关的上下文记忆(如之前的审查结果、总结事实等),而非全部记忆
# facts = store.get_facts()
facts = []
# 查找审查规则对应的 findings
findings = [f.__dict__ for f in store.search_findings("", rule_title=payload.rule_title)]
final_findings = reflect_tool.run(
party_role=payload.party_role,
......@@ -253,6 +301,35 @@ def new_conversation() -> ConversationResponse:
########################################################################################################################
# Request payload of POST /memory/facts/retrieve.
class FactsRetrieveRequest(BaseModel):
    conversation_id: str
    # Keyword list used to search the conversation's fact memory.
    keywords: List[str] = Field(..., description="facts 检索关键字列表")
# Response payload of POST /memory/facts/retrieve.
class FactsRetrieveResponse(BaseModel):
    conversation_id: str
    # The cleaned keywords that were actually searched.
    keywords: List[str]
    # Facts matched in memory; element type depends on the memory store.
    facts: List[Any]
    # Number of matched facts (len(facts)).
    total: int
@app.post("/memory/facts/retrieve", response_model=FactsRetrieveResponse)
def retrieve_facts(payload: FactsRetrieveRequest) -> FactsRetrieveResponse:
    """Search the conversation's fact memory with the supplied keyword list.

    Keywords are stripped and non-string/blank entries dropped; an empty
    result after cleaning is rejected with HTTP 400.
    """
    cleaned_keywords = [
        kw.strip()
        for kw in (payload.keywords or [])
        if isinstance(kw, str) and kw.strip()
    ]
    if not cleaned_keywords:
        raise HTTPException(status_code=400, detail="keywords cannot be empty")
    memory = get_cached_memory(payload.conversation_id)
    facts = memory.search_facts(cleaned_keywords)
    return FactsRetrieveResponse(
        conversation_id=payload.conversation_id,
        keywords=cleaned_keywords,
        facts=facts,
        total=len(facts),
    )
########################################################################################################################
class MemoryExportRequest(BaseModel):
conversation_id: str
file_ext: str
......@@ -293,6 +370,11 @@ def export_memory(payload: MemoryExportRequest) -> MemoryExportResponse:
)
if __name__ == "__main__":
from core.config import use_lufa
if use_lufa:
port = 18168
else:
port = 18169
uvicorn.run(
"main:app", host="0.0.0.0", port=18168, log_level="info", reload=False
"main:app", host="0.0.0.0", port=port, log_level="info", reload=False
)
\ No newline at end of file
from main import FactsRetrieveRequest, retrieve_facts
from core.cache import get_cached_memory
import json
def test_retrieve_facts_direct() -> None:
    """Smoke-test the facts-retrieve handler against a known cached conversation."""
    conv_id = "fa86563cb6c649d59e32e7def16ea6b2"
    request = FactsRetrieveRequest(
        conversation_id=conv_id,
        keywords=["当事人"],
    )
    response = retrieve_facts(request)
    print(json.dumps(response.facts, ensure_ascii=False, indent=4))
if __name__ == "__main__":
    # Manual smoke test; requires the referenced conversation to exist in cache.
    test_retrieve_facts_direct()
import os
from abc import ABC, abstractmethod
from core.config import MAX_SINGLE_CHUNK_SIZE
# 文档基类
......@@ -8,7 +9,7 @@ class DocBase(ABC):
self._doc_path = None
self._doc_name = None
self._kwargs = kwargs
self._max_single_chunk_size = kwargs.get('max_single_chunk_size', 2000)
self._max_single_chunk_size = kwargs.get('max_single_chunk_size', MAX_SINGLE_CHUNK_SIZE)
@abstractmethod
def load(self, doc_path):
......
from spire.doc import Document, Paragraph, Table, Comment, CommentMark, CommentMarkType
import json
from loguru import logger
import re
from thefuzz import fuzz
......@@ -150,147 +149,127 @@ def process_string(s):
return max(middle_parts, key=len, default="")
def build_mapping(original: str):
    """Collapse whitespace runs in *original* to single spaces, with an index map.

    Returns ``(normalized, mapping)`` where ``mapping[i]`` is the index in
    *original* of the character that produced ``normalized[i]``; each inserted
    separator space maps to the start index of the word that follows it.
    """
    chars: list = []
    index_map: list = []
    for match in re.finditer(r"\S+", original):
        start = match.start()
        if chars:
            # Separator between words: point it at the next word's start.
            chars.append(" ")
            index_map.append(start)
        for offset, ch in enumerate(match.group()):
            chars.append(ch)
            index_map.append(start + offset)
    return "".join(chars), index_map
def _score_target_against_query(target_text: str, query_text: str):
"""对单个候选文本与查询文本打分,并返回最适合落批注的匹配片段。"""
if not target_text or not query_text:
return None, 0
if query_text in target_text:
return query_text, 100
def extract_match(big_text: str, small_text: str, threshold=20):
"""
简化版文本匹配函数
核心逻辑:优先整个文本块匹配,次优子句匹配
"""
# 1. 精确匹配整个文本块
if small_text in big_text:
return small_text, 100
# 2. 整个文本块模糊匹配
full_score = fuzz.ratio(big_text, small_text)
if full_score >= threshold:
return big_text, full_score
# 3. 子句匹配(简单分割)
best_score = 0
best_clause = None
# 简单分割:按句号、分号、逗号分割
for clause in big_text.replace("。", ";").replace(",", ";").split(";"):
if not clause.strip():
continue
# partial_ratio 负责召回,ratio 负责精度;组合分用于排序
def _combined_score(text_a: str, text_b: str):
ratio_score = fuzz.ratio(text_a, text_b)
partial_score = fuzz.partial_ratio(text_a, text_b)
combined = int(round(0.4 * ratio_score + 0.6 * partial_score))
return combined
best_text = target_text
best_score = _combined_score(target_text, query_text)
clause_score = fuzz.ratio(clause, small_text)
for clause in target_text.replace("。", ";").replace(",", ";").split(";"):
clause = clause.strip()
if not clause:
continue
clause_score = _combined_score(clause, query_text)
if clause_score > best_score:
best_score = clause_score
best_clause = clause
# 4. 返回最佳匹配
if best_score >= threshold:
return best_clause, best_score
# 5. 无有效匹配
return None, max(full_score, best_score)
best_text = clause
return best_text, best_score
def find_best_match(sub_chunks, comment):
"""
在给定的文本块中查找与原始评论最匹配的文本
参数:
sub_chunks -- 包含Text属性的对象列表
comment -- 包含"original_text"的字典
def _build_narrowed_queries(text: str, min_len=12):
"""对文本做一步缩窄,生成下一轮候选。"""
if not text:
return []
返回:
best_match -- 匹配度最高的文本
best_score -- 最高匹配度
all_results -- 所有匹配结果列表(匹配文本, 相似度)
"""
all_results = [] # 存储所有(匹配文本, 相似度)的元组
best_match = None # 存储最佳匹配的结果
best_score = -1 # 存储最高相似度(初始化为-1)
text = text.strip()
if len(text) <= min_len:
return []
# print(f"开始处理评论: {comment['original_text'][:30]}...") # 显示简化的原始评论
next_queries = []
cut = max(1, len(text) // 8)
for obj in sub_chunks:
if isinstance(obj, Paragraph):
left_cut = text[cut:]
right_cut = text[:-cut]
center_cut = text[cut:-cut] if len(text) > 2 * cut else ""
target_text = obj.Text
original_text = comment["original_text"]
match_text, score = extract_match(target_text, original_text)
for item in (left_cut, right_cut, center_cut):
item = item.strip()
if len(item) >= min_len:
next_queries.append(item)
# 打印当前结果(保持原格式)
# print("匹配到:\n", match_text)
# print("相似度:", score)
simplified = process_string(text)
if simplified and len(simplified) >= min_len:
next_queries.append(simplified.strip())
# 存储所有结果
all_results.append((match_text, score))
parts = [p.strip() for p in re.split(r"[。;;,,\n]", text) if p.strip()]
if len(parts) > 1:
longest_part = max(parts, key=len)
if len(longest_part) >= min_len:
next_queries.append(longest_part)
# 更新最佳匹配 - 只更新分数更高的结果
if score > best_score:
best_match = match_text
best_score = score
if len(parts) > 2:
mid_join = "".join(parts[1:-1]).strip()
if len(mid_join) >= min_len:
next_queries.append(mid_join)
# 打印最终的最佳匹配结果
# print("\n" + "=" * 40)
# print("\n处理完成 - 最佳匹配结果:")
# print("匹配到:\n", best_match)
# print("相似度:", best_score)
# print("=" * 40 + "\n")
deduped = []
seen = set()
for item in next_queries:
if item not in seen:
seen.add(item)
deduped.append(item)
return deduped
return best_match, best_score
def _find_best_match_in_texts(target_texts, original_text):
"""在候选文本列表中查找与 original_text 最相近的一条(支持递进缩窄查询)。"""
if not target_texts or not original_text:
return None, -1
def table_contract(target_texts, comment):
"""
在给定的文本块中查找与原始评论最匹配的文本
best_match = None
best_score = -1
参数:
sub_chunks -- 待对比文本
comment -- 包含"original_text"的字典
beam_size = 5
max_rounds = 8
min_query_len = 12
返回:
best_match -- 匹配度最高的文本
best_score -- 最高匹配度
all_results -- 所有匹配结果列表(匹配文本, 相似度)
"""
all_results = [] # 存储所有(匹配文本, 相似度)的元组
best_match = None # 存储最佳匹配的结果
best_score = -1 # 存储最高相似度(初始化为-1)
active_queries = [original_text.strip()]
seen_queries = set(active_queries)
# print(f"开始处理评论: {comment['original_text'][:30]}...") # 显示简化的原始评论
for _ in range(max_rounds):
if not active_queries:
break
original_text = comment["original_text"]
for target_text in target_texts:
query_best_scores = []
match_text, score = extract_match(target_text, original_text)
for query in active_queries:
local_best = -1
for target_text in target_texts:
match_text, score = _score_target_against_query(target_text, query)
if score > best_score:
best_match = match_text
best_score = score
if score > local_best:
local_best = score
query_best_scores.append((query, local_best))
# 打印当前结果(保持原格式)
# print("匹配到:\n", match_text)
# print("相似度:", score)
if best_score >= 100:
break
# 存储所有结果
all_results.append((match_text, score))
query_best_scores.sort(key=lambda x: x[1], reverse=True)
top_queries = [q for q, _ in query_best_scores[:beam_size]]
# 更新最佳匹配 - 只更新分数更高的结果
if score > best_score:
best_match = match_text
best_score = score
next_queries = []
for query in top_queries:
for narrowed in _build_narrowed_queries(query, min_len=min_query_len):
if narrowed not in seen_queries:
seen_queries.add(narrowed)
next_queries.append(narrowed)
# 打印最终的最佳匹配结果
# print("\n" + "=" * 40)
# print("\n处理完成 - 最佳匹配结果:")
# print("匹配到:\n", best_match)
# print("相似度:", best_score)
# print("=" * 40 + "\n")
active_queries = next_queries
return best_match, best_score
......@@ -390,10 +369,12 @@ class SpireWordDoc(DocBase):
paragraph_text = cell.Paragraphs.get_Item(para_idx).Text
cell_content += paragraph_text
cell_list.append(cell_content)
table_data += "|" + "|".join(cell_list) + "|"
table_data += "\n"
# table_data += "|" + "|".join(cell_list) + "|"
# table_data += "\n"
table_data += ' '.join(cell_list) + '\n'
if i == 0:
table_data += "|" + "|".join(["--- " for _ in cell_list]) + "|\n"
# table_data += "|" + "|".join(["--- " for _ in cell_list]) + "|\n"
table_data= ' '.join(cell_list) + '\n'
return table_data
def get_chunk_info(self, chunk_id):
......@@ -435,6 +416,21 @@ class SpireWordDoc(DocBase):
def format_comment_author(self, comment):
return "{}|{}".format(str(comment["id"]), comment["key_points"])
def _decorate_author_with_match_type(self, author, match_type):
if match_type == "exact":
return f"(精确){author}"
if match_type == "fuzzy":
return f"(模糊){author}"
return author
def _normalize_author_prefix(self, author):
if not author:
return author
for prefix in ("(精确)", "(模糊)"):
if author.startswith(prefix):
return author[len(prefix) :]
return author
def remove_comment_prefix(
self,
):
......@@ -445,55 +441,130 @@ class SpireWordDoc(DocBase):
if len(split_author) == 2:
current_comment.Format.Author = comment_author.split("|")[1]
# 根据text_selection批注
def set_comment_by_text_selection(self, text_sel, author, comment_content):
if text_sel is None:
def _insert_comment_by_text_range(self, text_range, author, comment_content):
if text_range is None:
return False
# 将找到的文本作为文本范围,并获取其所属的段落
range = text_sel.GetAsOneRange()
paragraph = range.OwnerParagraph
paragraph = text_range.OwnerParagraph
if paragraph is None:
return False
# 创建一个评论对象并设置评论的内容和作者
comment = Comment(self._doc)
comment.Body.AddParagraph().Text = comment_content
comment.Format.Author = author
logger.info(author)
# 将评论添加到段落中
paragraph.ChildObjects.Insert(
paragraph.ChildObjects.IndexOf(range) + 1, comment
paragraph.ChildObjects.IndexOf(text_range) + 1, comment
)
# 创建评论起始标记和结束标记,并将它们设置为创建的评论的起始标记和结束标记
commentStart = CommentMark(self._doc, CommentMarkType.CommentStart)
commentEnd = CommentMark(self._doc, CommentMarkType.CommentEnd)
commentStart.CommentId = comment.Format.CommentId
commentEnd.CommentId = comment.Format.CommentId
# 在找到的文本之前和之后插入创建的评论起始和结束标记
comment_start = CommentMark(self._doc, CommentMarkType.CommentStart)
comment_end = CommentMark(self._doc, CommentMarkType.CommentEnd)
comment_start.CommentId = comment.Format.CommentId
comment_end.CommentId = comment.Format.CommentId
paragraph.ChildObjects.Insert(
paragraph.ChildObjects.IndexOf(range), commentStart
paragraph.ChildObjects.IndexOf(text_range), comment_start
)
paragraph.ChildObjects.Insert(
paragraph.ChildObjects.IndexOf(range) + 1, commentEnd
paragraph.ChildObjects.IndexOf(text_range) + 1, comment_end
)
return True
# 根据段落批注
def set_comment_by_paragraph(self, paragraph, author, comment_content):
comment = Comment(self._doc)
comment.Body.AddParagraph().Text = comment_content
# 设置注释的作者
comment.Format.Author = author
paragraph.ChildObjects.Add(comment)
# 创建注释开始标记和结束标记,并将它们设置为创建的注释的开始和结束标记
commentStart = CommentMark(self._doc, CommentMarkType.CommentStart)
commentEnd = CommentMark(self._doc, CommentMarkType.CommentEnd)
commentStart.CommentId = comment.Format.CommentId
commentEnd.CommentId = comment.Format.CommentId
# 在段落结尾插入注释开始标记和结束标记
# paragraph.ChildObjects.Add(commentStart)
paragraph.ChildObjects.Add(commentEnd)
# 也可以考虑在段落开始处插入标记
paragraph.ChildObjects.Insert(0, commentStart)
def _update_comment_content(self, comment_idx, suggest):
self._doc.Comments.get_Item(comment_idx).Body.Paragraphs.get_Item(0).Text = suggest
def _try_add_comment_in_paragraphs(self, paragraphs, target_text, author, suggest):
if not target_text:
return False
for paragraph in paragraphs:
text_sel = paragraph.Find(target_text, False, True)
if text_sel and self.set_comment_by_text_selection(text_sel, author, suggest):
return True
return False
def _try_add_comment_by_exact(self, sub_chunks, find_key, author, suggest):
for obj in sub_chunks:
if isinstance(obj, Paragraph):
try:
text_sel = obj.Find(find_key, False, True)
if text_sel and self.set_comment_by_text_selection(
text_sel, author, suggest
):
return True
except Exception as e:
print(f"段落批注添加失败: {str(e)}")
elif isinstance(obj, Table):
try:
if self.add_table_comment(obj, find_key, suggest, author):
return True
except Exception as e:
print(f"表格批注添加失败: {str(e)}")
return False
def _try_add_comment_by_fuzzy(self, sub_chunks, comment, author, suggest):
original_text = comment.get("original_text", "")
candidates = []
# 段落与表格同权:统一加入候选池,按最高分排序后尝试落批注
for order, obj in enumerate(sub_chunks):
if isinstance(obj, Paragraph):
match_text, score = _find_best_match_in_texts([obj.Text], original_text)
candidates.append(
{
"kind": "paragraph",
"obj": obj,
"match_text": match_text,
"score": score,
"order": order,
}
)
elif isinstance(obj, Table):
table_data = extract_table_cells_text(obj)
match_text, score = _find_best_match_in_texts(table_data, original_text)
candidates.append(
{
"kind": "table",
"obj": obj,
"match_text": match_text,
"score": score,
"order": order,
}
)
candidates = [
item
for item in candidates
if item.get("match_text") and item.get("score", -1) >= 0
]
candidates.sort(key=lambda x: (-x["score"], x["order"]))
for item in candidates:
match_text = item["match_text"]
processed_text = process_string(match_text) if match_text else ""
if item["kind"] == "paragraph":
paragraph = item["obj"]
if self._try_add_comment_in_paragraphs(
[paragraph], match_text, author, suggest
):
return True
if self._try_add_comment_in_paragraphs(
[paragraph], processed_text, author, suggest
):
return True
else:
table = item["obj"]
if self.add_table_comment(table, match_text, suggest, author):
return True
if processed_text and self.add_table_comment(
table, processed_text, suggest, author
):
return True
return False
# 根据text_selection批注
def set_comment_by_text_selection(self, text_sel, author, comment_content):
if text_sel is None:
return False
text_range = text_sel.GetAsOneRange()
return self._insert_comment_by_text_range(text_range, author, comment_content)
# 设置chunk批注
def add_table_comment(
......@@ -518,47 +589,15 @@ class SpireWordDoc(DocBase):
# 在段落中查找目标文本
selection = para.Find(target_text, False, True)
if selection:
# 获取文本范围
text_range = selection.GetAsOneRange()
if text_range is None:
continue
# 获取所属段落
paragraph = text_range.OwnerParagraph
if paragraph is None:
continue
# 创建一个评论对象并设置评论的内容和作者
comment = Comment(self._doc)
comment.Body.AddParagraph().Text = comment_text
comment.Format.Author = author
# 将评论添加到段落中
paragraph.ChildObjects.Insert(
paragraph.ChildObjects.IndexOf(text_range) + 1, comment
)
# 创建评论起始标记和结束标记
commentStart = CommentMark(
self._doc, CommentMarkType.CommentStart
)
commentEnd = CommentMark(self._doc, CommentMarkType.CommentEnd)
commentStart.CommentId = comment.Format.CommentId
commentEnd.CommentId = comment.Format.CommentId
# 在找到的文本之前和之后插入创建的评论起始和结束标记
paragraph.ChildObjects.Insert(
paragraph.ChildObjects.IndexOf(text_range), commentStart
)
paragraph.ChildObjects.Insert(
paragraph.ChildObjects.IndexOf(text_range) + 1, commentEnd
)
added = True
# print(f"表格批注添加成功: '{target_text[:20]}...'")
# 添加成功后跳出内层循环
break
if self._insert_comment_by_text_range(
text_range, author, comment_text
):
added = True
# print(f"表格批注添加成功: '{target_text[:20]}...'")
# 添加成功后跳出内层循环
break
# 如果已经添加,跳出单元格循环
if added:
......@@ -574,16 +613,14 @@ class SpireWordDoc(DocBase):
"""
为chunk添加批注(保证每条评论只批注一次)
"""
if chunk_id is not None:
sub_chunks = self.get_sub_chunks(chunk_id)
for comment in comments:
if comment.get("result") != "不合格":
continue
# update chunk_id
chunk_id = comment.get("chunk_id", -1)
if chunk_id is not None and chunk_id != -1:
sub_chunks = self.get_sub_chunks(chunk_id)
comment_chunk_id = comment.get("chunk_id", -1)
# 优先使用comments里提供的chunk_id,如果没有或无效则使用外部传入的chunk_id,如果都没有则异常处理
sub_chunks = self.get_sub_chunks(comment_chunk_id) if comment_chunk_id != -1 \
and comment_chunk_id < self.get_chunk_num() else self.get_sub_chunks(chunk_id)
author = self.format_comment_author(comment)
suggest = comment.get("suggest", "")
find_key = comment["original_text"].strip() or comment["key_points"]
......@@ -592,79 +629,22 @@ class SpireWordDoc(DocBase):
existing_comment_idx = self.find_comment(author)
if existing_comment_idx is not None:
# 已存在批注,则更新内容
self._doc.Comments.get_Item(
existing_comment_idx
).Body.Paragraphs.get_Item(0).Text = suggest
self._update_comment_content(existing_comment_idx, suggest)
# print(f"批注已存在,更新内容: '{find_key[:20]}...'")
continue
matched = False
exact_author = self._decorate_author_with_match_type(author, "exact")
fuzzy_author = self._decorate_author_with_match_type(author, "fuzzy")
# ---------- 1. 精确匹配(段落 + 表格) ----------
for obj in sub_chunks:
if isinstance(obj, Paragraph):
try:
text_sel = obj.Find(find_key, False, True)
if text_sel and self.set_comment_by_text_selection(
text_sel, author, suggest
):
# print(f"段落批注添加成功: '{find_key[:20]}...'")
matched = True
break
except Exception as e:
print(f"段落批注添加失败: {str(e)}")
elif isinstance(obj, Table):
try:
if self.add_table_comment(obj, find_key, suggest, author):
# print(f"表格批注添加成功: '{find_key[:20]}...'")
matched = True
break
except Exception as e:
print(f"表格批注添加失败: {str(e)}")
matched = self._try_add_comment_by_exact(
sub_chunks, find_key, exact_author, suggest
)
# ---------- 2. 模糊匹配 ----------
if not matched:
try:
paragraphs_only = [
obj for obj in sub_chunks if isinstance(obj, Paragraph)
]
match_text, _ = find_best_match(paragraphs_only, comment)
if match_text:
for obj in paragraphs_only:
text_sel = obj.Find(match_text, False, True)
if text_sel and self.set_comment_by_text_selection(
text_sel, author, suggest
):
# print(f"模糊批注添加成功: '{match_text[:20]}...'")
matched = True
break
if not matched:
processed_text = process_string(match_text)
for obj in paragraphs_only:
text_sel = obj.Find(processed_text, False, True)
if text_sel and self.set_comment_by_text_selection(
text_sel, author, suggest
):
# print(f"处理后批注添加成功: '{processed_text[:20]}...'")
matched = True
break
# 表格模糊匹配(仅段落模糊匹配失败才跑)
if not matched:
for obj in sub_chunks:
if isinstance(obj, Table):
table_data = extract_table_cells_text(obj)
best_table_match, _ = table_contract(
table_data, comment
)
if best_table_match and self.add_table_comment(
obj, best_table_match, suggest, author
):
# print(f"表格批注添加成功: '{best_table_match[:20]}...'")
matched = True
break
matched = self._try_add_comment_by_fuzzy(
sub_chunks, comment, fuzzy_author, suggest
)
except Exception as e:
print(f"模糊匹配失败: {str(e)}")
......@@ -674,10 +654,11 @@ class SpireWordDoc(DocBase):
# 根据作者名称查找批注
def find_comment(self, author):
normalized_author = self._normalize_author_prefix(author)
for i in range(self._doc.Comments.Count):
current_comment = self._doc.Comments.get_Item(i)
comment_author = current_comment.Format.Author
if comment_author == author:
comment_author = self._normalize_author_prefix(current_comment.Format.Author)
if comment_author == normalized_author:
return i
return None
......@@ -711,9 +692,7 @@ class SpireWordDoc(DocBase):
# 不合格,更新或新增
suggest = comment.get("suggest", "")
if existing_comment_idx is not None:
self._doc.Comments.get_Item(
existing_comment_idx
).Body.Paragraphs.get_Item(0).Text = suggest
self._update_comment_content(existing_comment_idx, suggest)
# print(f"更新已有批注: '{author}'")
else:
# chunk_id要从comment中获取
......@@ -747,22 +726,21 @@ class SpireWordDoc(DocBase):
if __name__ == "__main__":
doc = SpireWordDoc()
doc.load(
r"/home/ccran/lufa-contract/datasets/2-1.合同审核-近三年审核前的合同文件/20230101 麓谷发展视频合同书20230209(1).doc"
r"/home/ccran/lufa-contract/demo/今麦郎合同审核.docx"
)
print(doc._doc_name)
# print(doc.get_chunk_info(4))
doc.add_chunk_comment(
0,
[
{
"id": "1",
"key_points": "主体资格审查",
"original_text": "湖南麓谷发展集团有限公司",
"key_points": "日期审查",
"original_text": "承诺",
"details": "1111",
"chunk_id": 1,
"chunk_id": 0,
"result": "不合格",
"suggest": "这是测试建议",
}
],
)
doc.to_file("test.docx", True)
doc.to_file("/home/ccran/lufa-contract/demo/今麦郎合同审核_test.docx", True)
\ No newline at end of file
from spire.doc import Document, Paragraph, Table, Comment, CommentMark, CommentMarkType
import json
from loguru import logger
import re
from thefuzz import fuzz
from utils.doc_util import DocBase
from utils.common_util import adjust_single_chunk_size
import os
def extract_table_cells_text(table, joiner="\n"):
    """
    Extract the text of every cell from a Spire.Doc Table object and return a
    flat, row-major list:
        ["r0c0_text", "r0c1_text", "r1c0_text", ...]
    joiner: separator used when joining multiple paragraphs (or nested-table
    rows) inside one cell (default: newline).
    Note: text is NOT cleaned or stripped — original formatting is preserved.
    """
    def _para_text(para):
        # Prefer para.Text (kept verbatim); otherwise collect Text-like
        # fields from para.ChildObjects as a fallback.
        try:
            if hasattr(para, "Text"):
                return para.Text if para.Text is not None else ""
        except Exception:
            pass
        parts = []
        try:
            for idx in range(para.ChildObjects.Count):
                obj = para.ChildObjects[idx]
                if hasattr(obj, "Text"):
                    parts.append(obj.Text if obj.Text is not None else "")
        except Exception:
            pass
        return "".join(parts)

    def _extract_cell_text(cell):
        parts = []
        # Collect all paragraph texts of the cell (verbatim, no strip).
        try:
            for p_idx in range(cell.Paragraphs.Count):
                para = cell.Paragraphs[p_idx]
                parts.append(_para_text(para))
        except Exception:
            pass
        # Handle nested tables (if any): merge each nested row into one
        # string and append it row by row to parts.
        try:
            if hasattr(cell, "Tables") and cell.Tables.Count > 0:
                for t_idx in range(cell.Tables.Count):
                    nested = cell.Tables[t_idx]
                    nested_rows = []
                    for nr in range(nested.Rows.Count):
                        nested_row_cells = []
                        for nc in range(nested.Rows[nr].Cells.Count):
                            try:
                                # Join all paragraphs of the nested cell with
                                # joiner (kept verbatim).
                                nc_parts = []
                                for np_idx in range(
                                    nested.Rows[nr].Cells[nc].Paragraphs.Count
                                ):
                                    nc_parts.append(
                                        _para_text(
                                            nested.Rows[nr].Cells[nc].Paragraphs[np_idx]
                                        )
                                    )
                                nested_row_cells.append(joiner.join(nc_parts))
                            except Exception:
                                nested_row_cells.append("")
                        nested_rows.append(joiner.join(nested_row_cells))
                    parts.append(joiner.join(nested_rows))
            else:
                # Sometimes nested tables live in cell.ChildObjects instead;
                # handle that layout for compatibility (objects exposing a
                # non-None Rows attribute are treated as tables).
                try:
                    for idx in range(cell.ChildObjects.Count):
                        ch = cell.ChildObjects[idx]
                        if hasattr(ch, "Rows") and getattr(ch, "Rows") is not None:
                            nested = ch
                            nested_rows = []
                            for nr in range(nested.Rows.Count):
                                nested_row_cells = []
                                for nc in range(nested.Rows[nr].Cells.Count):
                                    try:
                                        nc_parts = []
                                        for np_idx in range(
                                            nested.Rows[nr].Cells[nc].Paragraphs.Count
                                        ):
                                            nc_parts.append(
                                                _para_text(
                                                    nested.Rows[nr]
                                                    .Cells[nc]
                                                    .Paragraphs[np_idx]
                                                )
                                            )
                                        nested_row_cells.append(joiner.join(nc_parts))
                                    except Exception:
                                        nested_row_cells.append("")
                                nested_rows.append(joiner.join(nested_row_cells))
                            parts.append(joiner.join(nested_rows))
                except Exception:
                    pass
        except Exception:
            pass
        # Join the collected fragments into the cell's final string
        # (no trimming/cleaning applied).
        return joiner.join(parts)

    flat = []
    for r in range(table.Rows.Count):
        row = table.Rows[r]
        for c in range(row.Cells.Count):
            cell = row.Cells[c]
            cell_text = _extract_cell_text(cell)
            # Keep verbatim; an empty cell yields an empty string.
            flat.append(cell_text)
    return flat
def process_string(s):
    """Pick the most representative single line from a multi-line string.

    - No newline: return *s* unchanged.
    - Exactly one newline: return the longer of the two halves
      (the first half on a tie).
    - Two or more newlines: return the longest line strictly between the
      first and last lines; if no such middle lines exist, fall back to the
      longest non-empty line overall (or "" when all lines are empty).
    """
    segments = s.split("\n")

    # Single line: nothing to choose between.
    if len(segments) == 1:
        return s

    # Two lines: keep the longer half, preferring the first on a tie.
    if len(segments) == 2:
        head, tail = segments
        return head if len(head) >= len(tail) else tail

    # Three or more lines: favour the interior (headers/footers are often
    # the first and last lines and rarely useful as an anchor).
    interior = segments[1:-1] if len(segments) > 2 else []
    if not interior:
        non_empty = [seg for seg in segments if seg]
        return max(non_empty, key=len) if non_empty else ""
    return max(interior, key=len, default="")
def build_mapping(original: str):
    """Build a whitespace-normalized copy of *original* plus an index map.

    Every run of whitespace collapses to a single space. Returns a tuple
    ``(normalized, mapping)`` where ``mapping[i]`` is the index in
    *original* of the character behind ``normalized[i]``; each separator
    space is mapped to the start of the word that follows it.
    """
    chars: list = []
    index_map: list = []
    for match in re.finditer(r"\S+", original):
        start = match.start()
        if chars:
            # One synthetic space between consecutive words, anchored to
            # the start of the upcoming word.
            chars.append(" ")
            index_map.append(start)
        for offset, ch in enumerate(match.group()):
            chars.append(ch)
            index_map.append(start + offset)
    return "".join(chars), index_map
def extract_match(big_text: str, small_text: str, threshold=20):
    """Locate the portion of *big_text* that best matches *small_text*.

    Strategy, in order of preference:
      1. exact substring            -> (small_text, 100)
      2. whole-block fuzzy ratio    -> (big_text, score) if score >= threshold
      3. best clause (split on 。 , ;) -> (clause, score) if score >= threshold
    Otherwise returns (None, best_score_seen).
    """
    # 1. Exact containment wins outright.
    if small_text in big_text:
        return small_text, 100

    # 2. Fuzzy-compare the whole block.
    whole_score = fuzz.ratio(big_text, small_text)
    if whole_score >= threshold:
        return big_text, whole_score

    # 3. Clause-level comparison: unify the Chinese punctuation into a
    #    single separator, then score each non-blank clause.
    top_score, top_clause = 0, None
    unified = big_text.replace("。", ";").replace(",", ";")
    for clause in unified.split(";"):
        if not clause.strip():
            continue
        clause_score = fuzz.ratio(clause, small_text)
        if clause_score > top_score:
            top_score, top_clause = clause_score, clause

    # 4. Accept the best clause only above the threshold.
    if top_score >= threshold:
        return top_clause, top_score

    # 5. No usable match; report the best score seen.
    return None, max(whole_score, top_score)
def find_best_match(sub_chunks, comment):
    """Scan paragraphs for the text closest to ``comment["original_text"]``.

    Parameters:
        sub_chunks -- document child objects; only Paragraph instances are
                      considered, everything else is skipped
        comment    -- dict carrying the reference text under "original_text"
    Returns:
        (best_match, best_score) -- the highest-scoring matched text and its
        similarity score; best_match is None when no paragraph matched.
    """
    reference = comment["original_text"]
    best_match, best_score = None, -1
    for candidate in sub_chunks:
        if not isinstance(candidate, Paragraph):
            continue
        matched_text, score = extract_match(candidate.Text, reference)
        # Keep only strictly better scores, so earlier paragraphs win ties.
        if score > best_score:
            best_match, best_score = matched_text, score
    return best_match, best_score
def table_contract(target_texts, comment):
    """Find the string in *target_texts* closest to ``comment["original_text"]``.

    Parameters:
        target_texts -- flat list of candidate strings (e.g. table cell texts)
        comment      -- dict carrying the reference text under "original_text"
    Returns:
        (best_match, best_score) -- highest-scoring matched text and its
        similarity score; best_match is None when nothing matched.
    """
    reference = comment["original_text"]
    best_match, best_score = None, -1
    for candidate in target_texts:
        matched_text, score = extract_match(candidate, reference)
        # Strictly-greater comparison: the first candidate wins on ties.
        if score > best_score:
            best_match, best_score = matched_text, score
    return best_match, best_score
# Spire.Doc-backed document parsing and annotation.
class SpireWordDoc(DocBase):
    """Word document wrapper built on Spire.Doc.

    Splits the document into size-bounded text chunks (paragraphs + tables),
    and adds/updates/deletes review comments whose author field encodes the
    review item as "<id>|<key_points>".
    """

    def load(self, doc_path, **kwargs):
        """Load the document at *doc_path* and pre-compute its chunk list."""
        # License.SetLicenseFileFullPath(f"{root_path}/license.elic.python.xml")
        self._doc_path = doc_path
        self._doc_name = os.path.basename(doc_path)
        self._doc = Document()
        self._doc.LoadFromFile(doc_path)
        self._chunk_list = self._resolve_doc_chunk()
        return self

    def _ensure_loaded(self):
        # Guard: most operations require load() to have run first.
        if not self._doc:
            raise RuntimeError("Document not loaded. Call load() first.")

    def adjust_chunk_size(self):
        """Recompute the max chunk size from total text length and re-chunk."""
        self._ensure_loaded()
        all_text_len = len(self.get_all_text())
        self._max_single_chunk_size = adjust_single_chunk_size(all_text_len)
        logger.info(
            f"SpireWordDoc adjust _max_single_chunk_size to {self._max_single_chunk_size}"
        )
        self._chunk_list = self._resolve_doc_chunk()
        return self._max_single_chunk_size

    async def get_from_ocr(self):
        # OCR extraction is not implemented for this backend.
        pass

    # Split the document into chunks.
    def _resolve_doc_chunk(self):
        """Walk every section's child objects and pack their text into chunks
        no larger than self._max_single_chunk_size, recording each child's
        (section_idx, section_child_idx) location alongside the text."""
        self._ensure_loaded()
        chunk_list = []
        # Text accumulated for the current chunk.
        single_chunk = ""
        # Location entries for the current chunk.
        single_chunk_location = []
        # Iterate over every section.
        for section_idx in range(self._doc.Sections.Count):
            current_section = self._doc.Sections.get_Item(section_idx)
            # Iterate over each child object in the section body.
            for section_child_idx in range(current_section.Body.ChildObjects.Count):
                # Fetch the child object.
                child_obj = current_section.Body.ChildObjects.get_Item(
                    section_child_idx
                )
                # Paragraph handling.
                current_child_text = ""
                if isinstance(child_obj, Paragraph):
                    paragraph = child_obj
                    current_child_text = paragraph.Text
                # Table handling.
                elif isinstance(child_obj, Table):
                    table = child_obj
                    current_child_text = self._resolve_table(table)
                # Skip any other non-text child object.
                else:
                    continue
                # Flush the current chunk before it would overflow.
                if (
                    len(single_chunk) + len(current_child_text)
                    > self._max_single_chunk_size
                ):
                    chunk_list.append(
                        {
                            "chunk_content": single_chunk,
                            "chunk_location": single_chunk_location,
                        }
                    )
                    single_chunk = ""
                    single_chunk_location = []
                single_chunk += current_child_text + "\n"
                single_chunk_location.append(
                    {"section_idx": section_idx, "section_child_idx": section_child_idx}
                )
        if len(single_chunk):
            chunk_list.append(
                {"chunk_content": single_chunk, "chunk_location": single_chunk_location}
            )
        return chunk_list

    # Render a table as text (was markdown; now space-joined rows).
    def _resolve_table(self, table):
        table_data = ""
        for i in range(0, table.Rows.Count):
            # Iterate over the row's cells.
            cell_list = []
            for j in range(0, table.Rows.get_Item(i).Cells.Count):
                # Fetch each cell.
                cell = table.Rows.get_Item(i).Cells.get_Item(j)
                cell_content = ""
                for para_idx in range(cell.Paragraphs.Count):
                    paragraph_text = cell.Paragraphs.get_Item(para_idx).Text
                    cell_content += paragraph_text
                cell_list.append(cell_content)
            # table_data += "|" + "|".join(cell_list) + "|"
            # table_data += "\n"
            table_data += ' '.join(cell_list) + '\n'
            if i == 0:
                # table_data += "|" + "|".join(["--- " for _ in cell_list]) + "|\n"
                # NOTE(review): this reassignment is a no-op — at i == 0,
                # table_data already equals exactly this value; it looks like
                # leftover from the markdown-header version commented above.
                table_data= ' '.join(cell_list) + '\n'
        return table_data

    def get_chunk_info(self, chunk_id):
        """Return a human-readable summary (position range + excerpt) of a
        chunk; the text is user-facing and intentionally in Chinese."""
        chunk = self._chunk_list[chunk_id]
        chunk_content = chunk["chunk_content"]
        chunk_location = chunk["chunk_location"]
        from_location = f"[第{chunk_location[0]['section_idx'] + 1}节的第{chunk_location[0]['section_child_idx'] + 1}段落]"
        to_location = f"[第{chunk_location[-1]['section_idx'] + 1}节的第{chunk_location[-1]['section_child_idx'] + 1}段落]"
        chunk_content_tips = (
            "[" + chunk_content[:20] + "]...到...[" + chunk_content[-20:] + "]"
        )
        return f"文件块id: {chunk_id + 1}\n文件块位置: 从{from_location}到{to_location}\n文件块简述: {chunk_content_tips}\n"

    def get_chunk_location(self, chunk_id):
        """Alias of get_chunk_info (kept for the existing interface)."""
        return self.get_chunk_info(chunk_id)

    def get_chunk_num(self):
        """Return the number of chunks."""
        self._ensure_loaded()
        return len(self._chunk_list)

    def get_chunk_item(self, chunk_id):
        """Return the raw text of chunk *chunk_id* (0-based)."""
        self._ensure_loaded()
        return self._chunk_list[chunk_id]["chunk_content"]

    # Fetch the document child objects referenced by a chunk's locations.
    def get_sub_chunks(self, chunk_id):
        if chunk_id >= len(self._chunk_list):
            logger.error(f"get_sub_chunks_error:{chunk_id}")
            return []
        chunk = self._chunk_list[chunk_id]
        chunk_locations = chunk["chunk_location"]
        return [
            self._doc.Sections.get_Item(loc["section_idx"]).Body.ChildObjects.get_Item(
                loc["section_child_idx"]
            )
            for loc in chunk_locations
        ]

    def format_comment_author(self, comment):
        """Build the comment author key as "<id>|<key_points>"."""
        return "{}|{}".format(str(comment["id"]), comment["key_points"])

    def remove_comment_prefix(
        self,
    ):
        """Strip the "id|" prefix from every comment author, leaving only the
        key_points part (run before exporting the reviewed document)."""
        for i in range(self._doc.Comments.Count):
            current_comment = self._doc.Comments.get_Item(i)
            comment_author = current_comment.Format.Author
            split_author = comment_author.split("|")
            if len(split_author) == 2:
                current_comment.Format.Author = comment_author.split("|")[1]

    # Annotate via a TextSelection.
    def set_comment_by_text_selection(self, text_sel, author, comment_content):
        """Wrap the found text range in a new comment; returns success."""
        if text_sel is None:
            return False
        # Treat the found text as a single range and get its owning paragraph.
        range = text_sel.GetAsOneRange()  # NOTE: shadows the builtin `range`
        paragraph = range.OwnerParagraph
        if paragraph is None:
            return False
        # Create the comment object and set its content and author.
        comment = Comment(self._doc)
        comment.Body.AddParagraph().Text = comment_content
        comment.Format.Author = author
        # logger.info(author)
        # Insert the comment into the paragraph right after the range.
        paragraph.ChildObjects.Insert(
            paragraph.ChildObjects.IndexOf(range) + 1, comment
        )
        # Create start/end marks bound to the new comment's id.
        commentStart = CommentMark(self._doc, CommentMarkType.CommentStart)
        commentEnd = CommentMark(self._doc, CommentMarkType.CommentEnd)
        commentStart.CommentId = comment.Format.CommentId
        commentEnd.CommentId = comment.Format.CommentId
        # Insert the marks just before and just after the found text.
        paragraph.ChildObjects.Insert(
            paragraph.ChildObjects.IndexOf(range), commentStart
        )
        paragraph.ChildObjects.Insert(
            paragraph.ChildObjects.IndexOf(range) + 1, commentEnd
        )
        return True

    # Annotate a whole paragraph.
    def set_comment_by_paragraph(self, paragraph, author, comment_content):
        comment = Comment(self._doc)
        comment.Body.AddParagraph().Text = comment_content
        # Set the comment author.
        comment.Format.Author = author
        paragraph.ChildObjects.Add(comment)
        # Create start/end marks bound to the new comment.
        commentStart = CommentMark(self._doc, CommentMarkType.CommentStart)
        commentEnd = CommentMark(self._doc, CommentMarkType.CommentEnd)
        commentStart.CommentId = comment.Format.CommentId
        commentEnd.CommentId = comment.Format.CommentId
        # Append the end mark at the end of the paragraph...
        # paragraph.ChildObjects.Add(commentStart)
        paragraph.ChildObjects.Add(commentEnd)
        # ...and insert the start mark at the beginning.
        paragraph.ChildObjects.Insert(0, commentStart)

    def add_table_comment(
        self, table, target_text, comment_text, author="审阅助手", initials="AI"
    ):
        """
        Add a comment at the first cell paragraph of *table* containing
        *target_text*; returns True when a comment was added.
        (Note: *initials* is currently unused.)
        """
        added = False
        # Walk every cell of the table.
        for i in range(table.Rows.Count):
            row = table.Rows[i]
            for j in range(row.Cells.Count):
                cell = row.Cells[j]
                # Walk the paragraphs within the cell.
                for k in range(cell.Paragraphs.Count):
                    para = cell.Paragraphs[k]
                    # Search the paragraph for the target text.
                    selection = para.Find(target_text, False, True)
                    if selection:
                        # Obtain the text range.
                        text_range = selection.GetAsOneRange()
                        if text_range is None:
                            continue
                        # And its owning paragraph.
                        paragraph = text_range.OwnerParagraph
                        if paragraph is None:
                            continue
                        # Create the comment with its content and author.
                        comment = Comment(self._doc)
                        comment.Body.AddParagraph().Text = comment_text
                        comment.Format.Author = author
                        # Insert the comment right after the range.
                        paragraph.ChildObjects.Insert(
                            paragraph.ChildObjects.IndexOf(text_range) + 1, comment
                        )
                        # Create the comment start/end marks.
                        commentStart = CommentMark(
                            self._doc, CommentMarkType.CommentStart
                        )
                        commentEnd = CommentMark(self._doc, CommentMarkType.CommentEnd)
                        commentStart.CommentId = comment.Format.CommentId
                        commentEnd.CommentId = comment.Format.CommentId
                        # Insert the marks before and after the found text.
                        paragraph.ChildObjects.Insert(
                            paragraph.ChildObjects.IndexOf(text_range), commentStart
                        )
                        paragraph.ChildObjects.Insert(
                            paragraph.ChildObjects.IndexOf(text_range) + 1, commentEnd
                        )
                        added = True
                        # print(f"表格批注添加成功: '{target_text[:20]}...'")
                        # Stop scanning this cell's paragraphs.
                        break
                # Already added: stop scanning cells.
                if added:
                    break
            # Already added: stop scanning rows.
            if added:
                break
        return added

    def add_chunk_comment(self, chunk_id, comments):
        """
        Add comments to a chunk (each review comment is annotated at most
        once; only comments whose result is 不合格/"failed" are processed).
        """
        if chunk_id is not None:
            sub_chunks = self.get_sub_chunks(chunk_id)
        for comment in comments:
            if comment.get("result") != "不合格":
                continue
            # A per-comment chunk_id overrides the argument.
            chunk_id = comment.get("chunk_id", -1)
            if chunk_id is not None and chunk_id != -1:
                sub_chunks = self.get_sub_chunks(chunk_id)
            author = self.format_comment_author(comment)
            suggest = comment.get("suggest", "")
            # Anchor text: original_text, falling back to key_points.
            find_key = comment["original_text"].strip() or comment["key_points"]
            # First check whether a comment by this author already exists.
            existing_comment_idx = self.find_comment(author)
            if existing_comment_idx is not None:
                # Comment exists: update its text instead of re-annotating.
                self._doc.Comments.get_Item(
                    existing_comment_idx
                ).Body.Paragraphs.get_Item(0).Text = suggest
                # print(f"批注已存在,更新内容: '{find_key[:20]}...'")
                continue
            matched = False
            # ---------- 1. Exact match (paragraphs + tables) ----------
            for obj in sub_chunks:
                if isinstance(obj, Paragraph):
                    try:
                        text_sel = obj.Find(find_key, False, True)
                        if text_sel and self.set_comment_by_text_selection(
                            text_sel, author, suggest
                        ):
                            # print(f"段落批注添加成功: '{find_key[:20]}...'")
                            matched = True
                            # Annotate only the first hit.
                            break
                    except Exception as e:
                        print(f"段落批注添加失败: {str(e)}")
                elif isinstance(obj, Table):
                    try:
                        if self.add_table_comment(obj, find_key, suggest, author):
                            # Annotate only the first matching table.
                            matched = True
                            break
                    except Exception as e:
                        print(f"表格批注添加失败: {str(e)}")
            # ---------- 2. Fuzzy match ----------
            if not matched:
                try:
                    paragraphs_only = [
                        obj for obj in sub_chunks if isinstance(obj, Paragraph)
                    ]
                    match_text, _ = find_best_match(paragraphs_only, comment)
                    if match_text:
                        for obj in paragraphs_only:
                            text_sel = obj.Find(match_text, False, True)
                            if text_sel and self.set_comment_by_text_selection(
                                text_sel, author, suggest
                            ):
                                # print(f"模糊批注添加成功: '{match_text[:20]}...'")
                                matched = True
                                break
                        if not matched:
                            # Retry with a single-line reduction of the match.
                            processed_text = process_string(match_text)
                            for obj in paragraphs_only:
                                text_sel = obj.Find(processed_text, False, True)
                                if text_sel and self.set_comment_by_text_selection(
                                    text_sel, author, suggest
                                ):
                                    # print(f"处理后批注添加成功: '{processed_text[:20]}...'")
                                    matched = True
                                    break
                    # Table fuzzy match (only runs when paragraph fuzzy
                    # matching failed).
                    if not matched:
                        for obj in sub_chunks:
                            if isinstance(obj, Table):
                                table_data = extract_table_cells_text(obj)
                                best_table_match, _ = table_contract(
                                    table_data, comment
                                )
                                if best_table_match and self.add_table_comment(
                                    obj, best_table_match, suggest, author
                                ):
                                    # print(f"表格批注添加成功: '{best_table_match[:20]}...'")
                                    matched = True
                                    break
                except Exception as e:
                    print(f"模糊匹配失败: {str(e)}")
            # ---------- 3. All matching failed ----------
            if not matched:
                logger.error(f"未找到可批注位置: '{find_key[:20]}...'")

    # Find a comment index by author name.
    def find_comment(self, author):
        """Return the index of the first comment whose author equals
        *author*, or None when absent.

        NOTE(review): this compares the raw author string; once
        remove_comment_prefix() has stripped the "id|" prefix, a lookup with
        format_comment_author() output will no longer match — confirm this
        is intended for re-processed documents.
        """
        for i in range(self._doc.Comments.Count):
            current_comment = self._doc.Comments.get_Item(i)
            comment_author = current_comment.Format.Author
            if comment_author == author:
                return i
        return None

    def delete_chunk_comment(self, comments):
        """
        Delete the document comment authored by each comment's formatted
        author key, when present.
        """
        for comment in comments:
            author = self.format_comment_author(comment)
            author_comment_idx = self.find_comment(author)
            if author_comment_idx is not None:
                self._doc.Comments.RemoveAt(author_comment_idx)
                print(f"删除批注: '{author}'")

    def edit_chunk_comment(self, comments):
        """
        Edit chunk comments: delete ones now marked 合格 ("pass"), update
        the text of existing ones, and add new ones when absent.
        """
        for comment in comments:
            author = self.format_comment_author(comment)
            review_answer = comment["result"]
            existing_comment_idx = self.find_comment(author)
            if review_answer == "合格":
                # Passed: drop any existing comment.
                if existing_comment_idx is not None:
                    self._doc.Comments.RemoveAt(existing_comment_idx)
                    # print(f"已删除合格批注: '{author}'")
            else:
                # Failed: update in place or add a new comment.
                suggest = comment.get("suggest", "")
                if existing_comment_idx is not None:
                    self._doc.Comments.get_Item(
                        existing_comment_idx
                    ).Body.Paragraphs.get_Item(0).Text = suggest
                    # print(f"更新已有批注: '{author}'")
                else:
                    # The (1-based) chunk_id comes from the comment itself.
                    self.add_chunk_comment(comment["chunk_id"] - 1, [comment])

    def get_chunk_id_list(self, step=1):
        """Return 1-based chunk ids sampled every *step* chunks."""
        self._ensure_loaded()
        return [idx + 1 for idx in range(0, self.get_chunk_num(), step)]

    def get_all_text(self):
        """Return the document's full plain text."""
        self._ensure_loaded()
        return self._doc.GetText()

    def to_file(self, path, remove_prefix=False):
        """Save to *path*; optionally strip "id|" author prefixes first."""
        self._ensure_loaded()
        if remove_prefix:
            self.remove_comment_prefix()
        self._doc.SaveToFile(path)

    def release(self):
        # Close the underlying document handle before the base cleanup.
        if self._doc:
            self._doc.Close()
        super().release()

    def __del__(self):
        # Intentionally no automatic release on GC; call release() explicitly.
        pass
        # self.release()
if __name__ == "__main__":
    # Manual smoke test: load a contract, attach one failing-review comment
    # to chunk 0, and save the annotated copy (paths are developer-local).
    doc = SpireWordDoc()
    doc.load(
        r"/home/ccran/lufa-contract/demo/今麦郎合同审核.docx"
    )
    print(doc._doc_name)
    doc.add_chunk_comment(
        0,
        [
            {
                "id": "1",
                "key_points": "日期审查",
                "original_text": "承诺",
                "details": "1111",
                "chunk_id": 0,
                "result": "不合格",
                "suggest": "这是测试建议",
            }
        ],
    )
    doc.to_file("/home/ccran/lufa-contract/demo/今麦郎合同审核_test.docx", True)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment