Commit 4441c13b by ccran

feat: 增加案例;修改审查提示词;

parent bd17ac4b
{ {
"python-envs.defaultEnvManager": "ms-python.python:conda", "python-envs.defaultEnvManager": "ms-python.python:conda",
"python-envs.defaultPackageManager": "ms-python.python:conda" "python-envs.defaultPackageManager": "ms-python.python:conda",
"python.defaultInterpreterPath": "/home/ccran/.conda/envs/lufa/bin/python",
"python.terminal.activateEnvironment": true
} }
\ No newline at end of file
...@@ -3,22 +3,6 @@ from dataclasses import dataclass ...@@ -3,22 +3,6 @@ from dataclasses import dataclass
# 可配置运行参数 # 可配置运行参数
use_docker = False use_docker = False
just_b_class = False
is_extract = True
use_original_text_verification = False
# @dataclass
# class LLMConfig:
# base_url: str = "http://172.21.107.45:9002/v1"
# api_key: str = "none"
# model: str = "Qwen3-32B"
#
#
# base_fastgpt_url = "http://172.21.107.45:3030"
# base_backend_url = "http://172.21.107.45:48080"
# ocr_url = 'http://172.21.107.45:8202/openapi/ocrUploadFile'
@dataclass @dataclass
class LLMConfig: class LLMConfig:
...@@ -27,32 +11,19 @@ class LLMConfig: ...@@ -27,32 +11,19 @@ class LLMConfig:
model: str = 'Qwen2-72B-Instruct' model: str = 'Qwen2-72B-Instruct'
# MAX_SINGLE_CHUNK_SIZE=100000 # MAX_SINGLE_CHUNK_SIZE=100000
# MAX_SINGLE_CHUNK_SIZE=5000 MERGE_RULE_PROMPT = False
MAX_SINGLE_CHUNK_SIZE=2000 MAX_SINGLE_CHUNK_SIZE=5000
META_KEY="META"
DEFAULT_RULESET_ID = "通用" DEFAULT_RULESET_ID = "通用"
ALL_RULESET_IDS = ["通用","借款","担保","测试","财务口","金盘","金盘简化"] ALL_RULESET_IDS = ["通用","借款","担保","财务口","金盘","金盘简化"]
FACT_DIMENSIONS = [ use_lufa = True
"当事人",
"标的",
"金额",
"支付",
"期限",
"交付",
"质量",
"知识产权",
"保密",
"违约责任",
"争议解决"
]
use_lufa = False
if use_lufa: if use_lufa:
outer_backend_url = "http://znkf.lgfzgroup.com:48081" outer_backend_url = "http://znkf.lgfzgroup.com:48081"
base_fastgpt_url = "http://192.168.252.71:18089" base_fastgpt_url = "http://192.168.252.71:18089"
base_backend_url = "http://192.168.252.71:48081" base_backend_url = "http://192.168.252.71:48081"
api_key = "fastgpt-zMavJKKgqA9jRNHLXxzXCVZx1JXxfuNkH1p2qfLhtPfMp41UvdSQvt8" api_key = "fastgpt-zMavJKKgqA9jRNHLXxzXCVZx1JXxfuNkH1p2qfLhtPfMp41UvdSQvt8"
else: else:
outer_backend_url = "http://218.77.58.8:8088" outer_backend_url = "http://218.77.58.8:48080"
base_fastgpt_url = "http://192.168.252.71:18088" base_fastgpt_url = "http://192.168.252.71:18088"
base_backend_url = "http://192.168.252.71:48080" base_backend_url = "http://192.168.252.71:48080"
api_key = "fastgpt-vLu2JHAfqwEq5FUQhvATFDK0yDS6fs804v7KwWBMyU4sRrHzh4UGl89Zpa" api_key = "fastgpt-vLu2JHAfqwEq5FUQhvATFDK0yDS6fs804v7KwWBMyU4sRrHzh4UGl89Zpa"
...@@ -78,40 +49,7 @@ LLM = { ...@@ -78,40 +49,7 @@ LLM = {
} }
doc_support_formats = [".docx", ".doc", ".wps"] doc_support_formats = [".docx", ".doc", ".wps"]
pdf_support_formats = [".txt", ".md", ".pdf"] pdf_support_formats = [".txt", ".md", ".pdf"]
# excel字段
field_mapping = {
"review": {
"id": "序号",
"key_points": "审查内容",
"original_text": "合同原文",
"details": "审查过程",
"result": "审查结果",
"suggest": "审查建议",
"chunk_id": "文件块序号",
"chunk_info": "文件块信息",
},
"category": {"text": "原文", "judge": "判断依据", "chunk_location": "原文位置"},
}
excel_widths = {"review": [5, 20, 80, 80, 20, 80, 5, 60], "category": [80, 80, 30]}
max_review_group = 5
# 销售类别判断
all_rule_sheet = ["内销或出口", "内销", "出口", "反思"]
reflection_sheet = "反思"
# 最大分片数量 # 最大分片数量
min_single_chunk_size = 2000 min_single_chunk_size = 2000
max_single_chunk_size = 20000 max_single_chunk_size = 20000
max_chunk_page = 10 max_chunk_page = 10
# 知识库读取sheet
if just_b_class:
know_sheet_name = "B类"
port = 9008
else:
if is_extract:
know_sheet_name = "提取审查"
port = 9016
else:
know_sheet_name = 0
port = 9006
# know_sheet_name = '发票审查'
reload = False
"""
{
"segment_id": "S12",
"topics": ["付款"],
"summary": {
"one_liner": "...",
"structured": ["...", "...", "..."]
},
"risks": [
{"flag":"...", "level":"M", "evidence":"...", "suggestion":"..."}
],
"dependencies": ["S11"],
"open_questions": ["..."],
"evidence_quotes": ["..."]
}
segment_id
含义:该记忆对应的合同分段唯一标识。
注意:用于前后段上下文关联,必须与分段阶段生成的 ID 一致。
topics
含义:本分段涉及的合同主题标签。
注意:从固定枚举中选择(如付款/违约/保密),用于后续记忆检索与一致性校验。
summary
summary.one_liner
含义:一句话概括本段条款的核心约定内容。
注意:只描述事实,不做风险判断。
summary.structured
含义:本段条款的结构化要点列表(3–5 条)。
注意:每条对应一个业务点,便于检索和上下文理解。
risks
含义:本分段确认存在的风险点列表。
字段说明:
flag:风险标签(简短描述问题)
level:风险等级(H / M / L)
evidence:支撑该风险判断的原文依据
suggestion:可直接落地的修改或谈判建议
注意:每条风险必须有明确证据和可执行建议。
dependencies
含义:本分段理解或适用所依赖的其他条款或附件。
注意:用于后续自动回读相关条款,防止断章取义。
open_questions
含义:本分段中未明确、需对方补充或澄清的信息。
注意:不确定内容不要写成风险结论,统一放在这里。
evidence_quotes
含义:从合同原文中摘录的关键短句,用于支撑摘要和风险判断。
注意:用于审查复核与反思重试,避免“凭空总结”。
"""
from __future__ import annotations from __future__ import annotations
import json import json
...@@ -59,7 +10,7 @@ from typing import Any, Dict, Iterable, List, Optional ...@@ -59,7 +10,7 @@ from typing import Any, Dict, Iterable, List, Optional
from utils.http_util import upload_file from utils.http_util import upload_file
from utils.doc_util import DocBase from utils.doc_util import DocBase
from core.config import FACT_DIMENSIONS from core.config import META_KEY
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -69,13 +20,14 @@ _ALLOWED_RISK_LEVELS = {"H", "M", "L",""} ...@@ -69,13 +20,14 @@ _ALLOWED_RISK_LEVELS = {"H", "M", "L",""}
@dataclass @dataclass
class RiskFinding: class Finding:
rule_title: str rule_title: str
segment_id: int segment_id: int
original_text: str original_text: str
issue: str issue: str
risk_level: str risk_level: str
suggestion: str suggestion: str
result: str = ""
def __post_init__(self) -> None: def __post_init__(self) -> None:
level = (self.risk_level or "").upper() level = (self.risk_level or "").upper()
...@@ -84,7 +36,7 @@ class RiskFinding: ...@@ -84,7 +36,7 @@ class RiskFinding:
self.risk_level = level self.risk_level = level
@classmethod @classmethod
def from_dict(cls, data: Dict) -> "RiskFinding": def from_dict(cls, data: Dict) -> "Finding":
data = data or {} data = data or {}
return cls( return cls(
rule_title=str(data.get("rule_title", "")), rule_title=str(data.get("rule_title", "")),
...@@ -93,10 +45,14 @@ class RiskFinding: ...@@ -93,10 +45,14 @@ class RiskFinding:
issue=str(data.get("issue", "")), issue=str(data.get("issue", "")),
risk_level=str(data.get("risk_level", "")), risk_level=str(data.get("risk_level", "")),
suggestion=str(data.get("suggestion", "")), suggestion=str(data.get("suggestion", "")),
result=str(data.get("result", "")),
) )
def __repr__(self): def __repr__(self):
return f"RiskFinding(rule_title={self.rule_title!r}, segment_id={self.segment_id}, issue={self.issue!r}, risk_level={self.risk_level!r})" return (
f"Finding(rule_title={self.rule_title!r}, segment_id={self.segment_id}, "
f"issue={self.issue!r}, risk_level={self.risk_level!r}, result={self.result!r})"
)
@dataclass @dataclass
...@@ -111,8 +67,8 @@ class MemoryStore: ...@@ -111,8 +67,8 @@ class MemoryStore:
self._storage_path.parent.mkdir(parents=True, exist_ok=True) self._storage_path.parent.mkdir(parents=True, exist_ok=True)
self._lock = RLock() self._lock = RLock()
self.facts: List[Dict[str, Any]] = [] self.facts: List[Dict[str, Any]] = []
self.findings: List[RiskFinding] = [] self.findings: List[Finding] = []
self.final_findings: List[RiskFinding] = [] self.final_findings: List[Finding] = []
self._load() self._load()
# ---------------------- facts ---------------------- # ---------------------- facts ----------------------
...@@ -134,45 +90,55 @@ class MemoryStore: ...@@ -134,45 +90,55 @@ class MemoryStore:
def search_facts(self, keywords: List[str]) -> List[Any]: def search_facts(self, keywords: List[str]) -> List[Any]:
keys = [str(k).strip().lower() for k in (keywords or []) if str(k).strip()] keys = [str(k).strip().lower() for k in (keywords or []) if str(k).strip()]
allowed_keys = {str(k).strip().lower() for k in FACT_DIMENSIONS if str(k).strip()} if not keys:
requested_keys = {k for k in keys if k in allowed_keys}
with self._lock:
candidates = list(self.facts)
if not requested_keys:
return [] return []
def _key_match(name: Any) -> bool:
key_name = str(name).strip().lower()
return bool(key_name) and any(k in key_name or key_name in k for k in keys)
matched_values: List[Any] = [] matched_values: List[Any] = []
for item in candidates: with self._lock:
all_facts = list(self.facts)
for item in all_facts:
if not isinstance(item, dict): if not isinstance(item, dict):
continue continue
for key, value in item.items():
normalized_key = str(key).strip().lower() for top_key, top_value in item.items():
if normalized_key in requested_keys: if _key_match(top_key):
matched_values.append(value) matched_values.append({
top_key: top_value,
META_KEY: item.get(META_KEY, {}) # include metadata if exists
})
return matched_values return matched_values
# -------------------- findings --------------------- # -------------------- findings ---------------------
def add_finding(self, finding: RiskFinding) -> RiskFinding: def add_finding(self, finding: Finding) -> Finding:
return self._add_finding(self.findings, finding) return self._add_finding(self.findings, finding)
def add_finding_from_dict(self, data: Dict) -> RiskFinding: def add_finding_from_dict(self, data: Dict) -> Finding:
return self.add_finding(RiskFinding.from_dict(data)) return self.add_finding(Finding.from_dict(data))
def add_final_finding(self, finding: RiskFinding) -> RiskFinding: def add_final_finding(self, finding: Finding) -> Finding:
return self._add_finding(self.final_findings, finding) return self._add_finding(self.final_findings, finding)
def add_final_finding_from_dict(self, data: Dict) -> RiskFinding: def add_final_finding_from_dict(self, data: Dict) -> Finding:
return self.add_final_finding(RiskFinding.from_dict(data)) return self.add_final_finding(Finding.from_dict(data))
def list_findings(self) -> List[RiskFinding]: def list_findings(self) -> List[Finding]:
return self._list_findings(self.findings) return self._list_findings(self.findings)
def list_final_findings(self) -> List[RiskFinding]: def list_final_findings(self) -> List[Finding]:
return self._list_findings(self.final_findings) return self._list_findings(self.final_findings)
def get_findings_by_segment(self, segment_id: int) -> List[RiskFinding]: def get_findings_by_segment(self, segment_id: int) -> List[Finding]:
return self._get_findings_by_segment(self.findings, segment_id) return self._get_findings_by_segment(self.findings, segment_id)
def get_final_findings_by_segment(self, segment_id: int) -> List[RiskFinding]: def get_final_findings_by_segment(self, segment_id: int) -> List[Finding]:
return self._get_findings_by_segment(self.final_findings, segment_id) return self._get_findings_by_segment(self.final_findings, segment_id)
def delete_findings_by_segment(self, segment_id: int) -> int: def delete_findings_by_segment(self, segment_id: int) -> int:
...@@ -181,23 +147,23 @@ class MemoryStore: ...@@ -181,23 +147,23 @@ class MemoryStore:
def delete_final_findings_by_segment(self, segment_id: int) -> int: def delete_final_findings_by_segment(self, segment_id: int) -> int:
return self._delete_findings_by_segment("final_findings", segment_id) return self._delete_findings_by_segment("final_findings", segment_id)
def search_findings(self, keyword: str, rule_title: Optional[str] = None, risk_level: Optional[str] = None) -> List[RiskFinding]: def search_findings(self, keyword: str, rule_title: Optional[str] = None, risk_level: Optional[str] = None) -> List[Finding]:
return self._search_findings(self.findings, keyword, rule_title, risk_level) return self._search_findings(self.findings, keyword, rule_title, risk_level)
def search_final_findings(self, keyword: str, rule_title: Optional[str] = None, risk_level: Optional[str] = None) -> List[RiskFinding]: def search_final_findings(self, keyword: str, rule_title: Optional[str] = None, risk_level: Optional[str] = None) -> List[Finding]:
return self._search_findings(self.final_findings, keyword, rule_title, risk_level) return self._search_findings(self.final_findings, keyword, rule_title, risk_level)
def _add_finding(self, target: List[RiskFinding], finding: RiskFinding) -> RiskFinding: def _add_finding(self, target: List[Finding], finding: Finding) -> Finding:
with self._lock: with self._lock:
target.append(finding) target.append(finding)
self._persist() self._persist()
return finding return finding
def _list_findings(self, target: List[RiskFinding]) -> List[RiskFinding]: def _list_findings(self, target: List[Finding]) -> List[Finding]:
with self._lock: with self._lock:
return list(target) return list(target)
def _get_findings_by_segment(self, target: List[RiskFinding], segment_id: int) -> List[RiskFinding]: def _get_findings_by_segment(self, target: List[Finding], segment_id: int) -> List[Finding]:
with self._lock: with self._lock:
return [f for f in target if f.segment_id == segment_id] return [f for f in target if f.segment_id == segment_id]
...@@ -214,11 +180,11 @@ class MemoryStore: ...@@ -214,11 +180,11 @@ class MemoryStore:
def _search_findings( def _search_findings(
self, self,
target: List[RiskFinding], target: List[Finding],
keyword: str, keyword: str,
rule_title: Optional[str] = None, rule_title: Optional[str] = None,
risk_level: Optional[str] = None, risk_level: Optional[str] = None,
) -> List[RiskFinding]: ) -> List[Finding]:
key = (keyword or "").strip().lower() key = (keyword or "").strip().lower()
with self._lock: with self._lock:
candidates = list(target) candidates = list(target)
...@@ -229,12 +195,13 @@ class MemoryStore: ...@@ -229,12 +195,13 @@ class MemoryStore:
candidates = [f for f in candidates if f.risk_level == lvl] candidates = [f for f in candidates if f.risk_level == lvl]
if not key: if not key:
return candidates return candidates
def _matches(f: RiskFinding) -> bool: def _matches(f: Finding) -> bool:
hay = " ".join([ hay = " ".join([
f.rule_title, f.rule_title,
f.original_text, f.original_text,
f.issue, f.issue,
f.suggestion, f.suggestion,
f.result,
]).lower() ]).lower()
return key in hay return key in hay
return [f for f in candidates if _matches(f)] return [f for f in candidates if _matches(f)]
...@@ -266,8 +233,8 @@ class MemoryStore: ...@@ -266,8 +233,8 @@ class MemoryStore:
data = json.loads(raw or "{}") data = json.loads(raw or "{}")
if isinstance(data, dict): if isinstance(data, dict):
self.facts = data.get("facts") or [] self.facts = data.get("facts") or []
self.findings = [RiskFinding.from_dict(item) for item in data.get("findings", []) or []] self.findings = [Finding.from_dict(item) for item in data.get("findings", []) or []]
self.final_findings = [RiskFinding.from_dict(item) for item in data.get("final_findings", []) or []] self.final_findings = [Finding.from_dict(item) for item in data.get("final_findings", []) or []]
except Exception as exc: except Exception as exc:
logger.error("Failed to load memory store: %s", exc) logger.error("Failed to load memory store: %s", exc)
...@@ -285,7 +252,7 @@ class MemoryStore: ...@@ -285,7 +252,7 @@ class MemoryStore:
with self._lock: with self._lock:
wb = Workbook() wb = Workbook()
ws_final_findings = wb.active ws_final_findings = wb.active
ws_final_findings.title = "final_findings" ws_final_findings.title = "最终结果"
finding_headers = [ finding_headers = [
("rule_title", "规则标题"), ("rule_title", "规则标题"),
...@@ -293,6 +260,7 @@ class MemoryStore: ...@@ -293,6 +260,7 @@ class MemoryStore:
("original_text", "原文"), ("original_text", "原文"),
("issue", "问题描述"), ("issue", "问题描述"),
("risk_level", "风险等级"), ("risk_level", "风险等级"),
("result", "合格性"),
("suggestion", "建议"), ("suggestion", "建议"),
] ]
# add final findings # add final findings
...@@ -303,30 +271,24 @@ class MemoryStore: ...@@ -303,30 +271,24 @@ class MemoryStore:
]) ])
# add findings # add findings
ws_findings = wb.create_sheet("findings") ws_findings = wb.create_sheet("中间结果")
ws_findings.append([label for _, label in finding_headers]) ws_findings.append([label for _, label in finding_headers])
for f in self.findings: for f in self.findings:
ws_findings.append([ ws_findings.append([
getattr(f, key, "") for key, _ in finding_headers getattr(f, key, "") for key, _ in finding_headers
]) ])
ws_facts = wb.create_sheet("facts") ws_facts = wb.create_sheet("合同事实")
if self.facts: if self.facts:
fact_keys: List[str] = sorted({k for item in self.facts for k in item.keys()}) ws_facts.append(["元信息", "事实内容"])
if "段落" in fact_keys:
fact_keys = ["段落"] + [k for k in fact_keys if k != "段落"]
ws_facts.append(fact_keys)
for item in self.facts: for item in self.facts:
row = [] if not isinstance(item, dict):
for key in fact_keys: ws_facts.append(["事实", json.dumps(item, ensure_ascii=False)])
value = item.get(key) continue
if isinstance(value, (dict, list)): meta_info = item.pop(META_KEY, None)
row.append(json.dumps(value, ensure_ascii=False)) ws_facts.append([json.dumps(meta_info, ensure_ascii=False), json.dumps(item, ensure_ascii=False)])
else:
row.append(value)
ws_facts.append(row)
else: else:
ws_facts.append(["data"]) ws_facts.append(["元信息", "事实内容"])
wb.save(output_path) wb.save(output_path)
...@@ -368,7 +330,7 @@ class MemoryStore: ...@@ -368,7 +330,7 @@ class MemoryStore:
comments: List[Dict[str, Any]] = [] comments: List[Dict[str, Any]] = []
for idx, f in enumerate(target_findings, start=1): for idx, f in enumerate(target_findings, start=1):
segment_id = int(f.segment_id or 0) segment_id = int(f.segment_id or 0)
chunk_id = max(segment_id - 1, 0) chunk_id = max(segment_id, 0)
suggest_parts = [] suggest_parts = []
if f.risk_level: if f.risk_level:
suggest_parts.append(f"风险等级:{f.risk_level}") suggest_parts.append(f"风险等级:{f.risk_level}")
...@@ -384,7 +346,7 @@ class MemoryStore: ...@@ -384,7 +346,7 @@ class MemoryStore:
"original_text": f.original_text or "", "original_text": f.original_text or "",
"details": f.issue or "", "details": f.issue or "",
"chunk_id": chunk_id, "chunk_id": chunk_id,
"result": "不合格", "result": f.result or "不合格",
"suggest": suggest_text, "suggest": suggest_text,
} }
) )
...@@ -406,14 +368,16 @@ class MemoryStore: ...@@ -406,14 +368,16 @@ class MemoryStore:
def test_export_findings_to_doc_comments(doc_path: str) -> None: def test_export_findings_to_doc_comments(doc_path: str) -> None:
store = MemoryStore() store = MemoryStore()
finding = RiskFinding( finding = Finding(
rule_title="违约责任", rule_title="违约责任",
segment_id=1, segment_id=1,
original_text="湖南麓谷发展集团有限公司", original_text="湖南麓谷发展集团有限公司",
issue="未约定违约金上限,可能导致赔偿范围过大", issue="未约定违约金上限,可能导致赔偿范围过大",
risk_level="H", risk_level="H",
suggestion="建议增加‘赔偿金额不超过合同总额的30%’", suggestion="建议增加‘赔偿金额不超过合同总额的30%’",
result="不合格",
) )
store.add_final_finding(finding) store.add_final_finding(finding)
"""测试:将 findings 作为批注写入文档并上传。""" """测试:将 findings 作为批注写入文档并上传。"""
...@@ -437,33 +401,30 @@ def test_export_findings_to_doc_comments(doc_path: str) -> None: ...@@ -437,33 +401,30 @@ def test_export_findings_to_doc_comments(doc_path: str) -> None:
def test_memory_and_export_excel(): def test_memory_and_export_excel():
# 简单示例:设置事实 -> 写入问题 -> 读取/搜索 # 简单示例:设置事实 -> 写入问题 -> 读取/搜索
store = MemoryStore() store = MemoryStore()
store.set_facts([{ store.add_facts({
"公司": {"甲方": "A 公司", "乙方": "B 公司"}, "公司": {"甲方": "A 公司", "乙方": "B 公司"},
"支付": [ "支付": {"方式": "银行转账", "期限": "验收后30日内"},
{"方式": "银行转账", "期限": "验收后30日内"} META_KEY:{
], "segment_id":1
"段落":1 }
},{ })
"纠纷": {"解决方式": "诉讼", "地址": "原告方所在地法院"}, print( store.search_facts(['支付']))
"段落":2 # finding = Finding(
}]) # rule_title="违约责任",
# segment_id=1,
finding = RiskFinding( # original_text="违约方应赔偿全部损失",
rule_title="违约责任", # issue="未约定违约金上限,可能导致赔偿范围过大",
segment_id=1, # risk_level="H",
original_text="违约方应赔偿全部损失", # suggestion="建议增加‘赔偿金额不超过合同总额的30%’",
issue="未约定违约金上限,可能导致赔偿范围过大", # )
risk_level="H", # store.add_finding(finding)
suggestion="建议增加‘赔偿金额不超过合同总额的30%’",
) # print("Facts:\n" + json.dumps(store.get_facts(), ensure_ascii=False, indent=2))
store.add_finding(finding) # hits = store.search_findings("赔偿", rule_title="违约责任")
# print("Findings search:")
print("Facts:\n" + json.dumps(store.get_facts(), ensure_ascii=False, indent=2)) # for f in hits:
hits = store.search_findings("赔偿", rule_title="违约责任") # print(json.dumps(asdict(f), ensure_ascii=False, indent=2))
print("Findings search:") # print(store.export_to_excel())
for f in hits:
print(json.dumps(asdict(f), ensure_ascii=False, indent=2))
print(store.export_to_excel())
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -4,7 +4,7 @@ from typing import Dict, List, Optional ...@@ -4,7 +4,7 @@ from typing import Dict, List, Optional
from dataclasses import asdict from dataclasses import asdict
from core.tool import ToolBase, tool, tool_func from core.tool import ToolBase, tool, tool_func
from core.memory import RiskFinding from core.memory import Finding
@tool("memory_write", "分段记忆写入") @tool("memory_write", "分段记忆写入")
...@@ -32,17 +32,19 @@ class MemoryWriteTool(ToolBase): ...@@ -32,17 +32,19 @@ class MemoryWriteTool(ToolBase):
issue = f.get("issue") or f.get("issue_description") or "" issue = f.get("issue") or f.get("issue_description") or ""
level = (f.get("level") or f.get("risk_level") or "M").upper() level = (f.get("level") or f.get("risk_level") or "M").upper()
suggestion = f.get("suggestion") or "" suggestion = f.get("suggestion") or ""
result = f.get("result") or ""
evs = list(f.get("evidence_quotes", []) or []) evs = list(f.get("evidence_quotes", []) or [])
original_text = evs[0] if evs else (f.get("original_text") or "") original_text = evs[0] if evs else (f.get("original_text") or "")
try: try:
finding_obj = RiskFinding( finding_obj = Finding(
rule_title=rule_title, rule_title=rule_title,
segment_id=int(segment_id) if str(segment_id).isdigit() else 0, segment_id=int(segment_id) if str(segment_id).isdigit() else 0,
original_text=original_text, original_text=original_text,
issue_description=issue, issue=issue,
risk_level=level, risk_level=level,
suggestion=suggestion, suggestion=suggestion,
result=result,
) )
store.add_finding(finding_obj) store.add_finding(finding_obj)
added.append(asdict(finding_obj)) added.append(asdict(finding_obj))
......
...@@ -59,7 +59,7 @@ REFLECT_USER_PROMPT = ''' ...@@ -59,7 +59,7 @@ REFLECT_USER_PROMPT = '''
【已有风险 findings】 【已有风险 findings】
{findings_json} {findings_json}
【合同事实记忆 facts】 【合同摘要事实记忆 facts】
{facts_json} {facts_json}
【合同立场】 【合同立场】
......
from __future__ import annotations from __future__ import annotations
from pathlib import Path
from typing import Dict, List, Any from typing import Dict, List, Any
import re
from core.config import ALL_RULESET_IDS, DEFAULT_RULESET_ID
from core.tool import ToolBase, tool, tool_func from core.tool import ToolBase, tool, tool_func
from core.cache import get_cached_memory from utils.excel_util import ExcelUtil
FACT_DIMENSIONS: List[str] = ["当事人", "标的", "金额", "支付", "交付", "质量", "知识产权", "保密", "违约责任", "争议解决"]
@tool("retrieve_reference", "审查参考检索") @tool("retrieve_reference", "审查参考检索")
class RetrieveReferenceTool(ToolBase): class RetrieveReferenceTool(ToolBase):
def __init__(self) -> None:
self.default_ruleset_id = DEFAULT_RULESET_ID
self.column_map = {
"id": "ID",
"title": "审查项",
"rule": "审查规则",
"level": "风险等级",
"triggers": "触发词",
"suggestion_template": "建议模板",
"case": "案例",
"summary":"摘要项"
}
rules_path = Path(__file__).resolve().parent.parent.parent / "data" / "rules.xlsx"
self.rulesets: Dict[str, List[Dict[str, Any]]] = {}
for rs_id in ALL_RULESET_IDS:
rules = ExcelUtil.load_mapped_excel(rules_path, sheet_name=rs_id, column_map=self.column_map)
self.rulesets[rs_id] = rules
@tool_func( @tool_func(
{ {
"type": "object", "type": "object",
"properties": { "properties": {
"question": {"type": "string"}, "ruleset_id": {"type": "string"},
"top_k": {"type": "int"}, "routed_rule_titles": {"type": "array", "items": {"type": "string"}},
}, },
"required": ["question"], "required": [],
} }
) )
def run(self, question: str, top_k: int = 5, conversation_id: str = "") -> Dict: def run(self, ruleset_id: str = "", routed_rule_titles: List[str] | None = None) -> Dict[str, Any]:
memory_refs = self._search_memory(question, conversation_id, top_k) target_ruleset_id = ruleset_id or self.default_ruleset_id
kb_refs = self._search_knowledge_base(question, top_k) full_rules = self.rulesets.get(target_ruleset_id) or self.rulesets.get(self.default_ruleset_id, []) or []
external_refs = self._search_external(question, top_k) if routed_rule_titles is None:
rules = full_rules
else:
title_set = {title for title in routed_rule_titles if isinstance(title, str)}
rules = [r for r in full_rules if r.get("title") in title_set]
return { return {
"memory_refs": memory_refs, "ruleset_id": target_ruleset_id,
"kb_refs": kb_refs, "rules": rules,
"external_refs": external_refs, "rule_titles": [r.get("title", "") for r in rules],
"total": len(rules),
} }
def _search_memory(self, question: str, conversation_id: str, top_k: int) -> List[Dict[str, Any]]: def summary_keywords(self, rules: List[Dict[str, Any]]) -> List[str]:
if not conversation_id: return [r.get("summary", "") for r in rules if r.get("summary")]
return []
store = get_cached_memory(conversation_id)
facts = store.get_facts()
results: List[Dict[str, Any]] = []
def _add(dim: str, payload: Any) -> None:
snippet = payload if isinstance(payload, str) else str(payload)
results.append({
"source": f"memory:{dim}",
"dimension": dim,
"snippet": snippet,
})
for dim in FACT_DIMENSIONS:
val = facts.get(dim)
if val is None:
continue
if dim in question:
_add(dim, val)
return results[:top_k]
def _search_knowledge_base(self, question: str, top_k: int) -> List[Dict[str, Any]]:
# TODO: implement KB retrieval
return []
def _search_external(self, question: str, top_k: int) -> List[Dict[str, Any]]:
# TODO: implement external retrieval (e.g., search engine)
return []
if __name__ == "__main__": if __name__ == "__main__":
tmp_memory = get_cached_memory("tmp")
tmp_memory.add_finding_from_dict({
"issue": "支付方式不明确",
"original_text": "买方应在收到货物后30天内支付全部货款。",
"risk_level": "H",
"rule_title": "支付条款审查",
})
tmp_memory.update_facts({
"支付":{
"支付方式": "银行转账",
"支付期限": "收到货物后30天内",
}
})
tool = RetrieveReferenceTool() tool = RetrieveReferenceTool()
result = tool.run( result = tool.run(ruleset_id="金盘", routed_rule_titles=None)
question="支付方式是什么?", for rule in result.get("rules", []):
top_k=3, print(f"Rule Title: {rule.get('title')}")
conversation_id="tmp", print(f"Case: {rule.get('case')}")
) print("-" * 20)
# print(result.get("total", 0))
print(result) \ No newline at end of file
\ No newline at end of file
...@@ -32,7 +32,8 @@ class LLMTool(ToolBase): ...@@ -32,7 +32,8 @@ class LLMTool(ToolBase):
def run_with_loop(self, coro): def run_with_loop(self, coro):
try: try:
return asyncio.run(coro) return asyncio.run(coro)
except RuntimeError: except RuntimeError as e:
print(f'RuntimeError in run_with_loop: {e}, trying to get event loop and run until complete.')
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
return loop.run_until_complete(coro) return loop.run_until_complete(coro)
......
...@@ -2,59 +2,75 @@ from __future__ import annotations ...@@ -2,59 +2,75 @@ from __future__ import annotations
import asyncio import asyncio
import json import json
from pathlib import Path
from typing import Dict, List, Optional from typing import Dict, List, Optional
from core.tool import tool, tool_func from core.tool import tool, tool_func
from utils.excel_util import ExcelUtil
from core.tools.segment_llm import LLMTool from core.tools.segment_llm import LLMTool
import re import re
from core.config import DEFAULT_RULESET_ID, ALL_RULESET_IDS
REVIEW_SYSTEM_PROMPT = ''' REVIEW_SYSTEM_PROMPT = '''
你是一个专业的合同分段审查智能体(SegmentReview)。 你是一个专业的合同分段审查智能体(SegmentReview)。
你的任务是:基于给定审查规则,对“当前分段”进行审查,识别其中证据充分、对当前审查立场不利的风险或缺陷,并输出可执行的修改建议。 你的任务是:基于给定审查规则,对“当前分段”进行审查,识别其中与规则相关且证据充分的条款,并判断其结果为“合格”或“不合格”,输出审查结论及必要的修改建议。
【审查范围】 【审查范围】
你只能审查以下两类问题: 你只能审查当前分段自身已经明确体现的内容。
1. 当前分段自身已经明确体现的问题,例如:对我方不利、表述不清、逻辑冲突、责任失衡、触发条件不明确、关键限制缺失等; 你只能识别以下两类结果:
2. 当前分段与【上下文记忆】之间存在的、证据充分的前后不一致或条款冲突,例如:主体、定义、金额、期限、责任、流程节点不一致等。 1. 合格条款:当前分段中存在与审查规则相关的明确表述,且该表述符合规则要求;
2. 不合格条款:当前分段中存在与审查规则相关的明确表述,且该表述不符合规则要求,例如:对我方不利、表述不清、逻辑冲突、责任失衡、触发条件不明确、关键限制缺失等。
【审查原则】 【审查原则】
- 严格基于给定的审查规则进行审查,不得脱离规则自行扩展审查标准。 - 严格基于给定的审查规则进行审查,不得脱离规则自行扩展审查标准。
- 只有在证据充分时才生成 finding。弱猜测、无充分文本支持的问题不得输出。 - 只有在证据充分时才生成 finding。弱猜测、无充分文本支持的问题或合格判断不得输出。
- 上下文记忆仅作为辅助比对材料,不能替代当前分段原文,也不能在证据不足时强行得出结论。 - 只审查当前分段原文,不得使用上下文信息补充、修正或推断当前分段含义。
- 优先识别“确定存在”的问题,不输出模糊怀疑类表述。 - 优先识别“确定成立”的合格或不合格结论,不输出模糊怀疑类表述。
【结果判定规则】
- result 只能取以下两个值之一:
- "合格":当前分段存在与规则相关的明确内容,且符合该规则要求;
- "不合格":当前分段存在与规则相关的明确内容,且不符合该规则要求。
- 如果当前分段与某条审查规则无关,或虽疑似相关但证据不足,则不得生成 finding。
【证据要求】 【证据要求】
每个 finding 都必须包含 original_text,且必须是合同原文的直接引用。 每个 finding 都必须包含 original_text,且必须是合同原文的直接引用。
【证据粒度(关键句原则)】 【证据粒度(关键句原则)】
original_text 必须满足“最小充分证据原则”: original_text 必须满足“最小充分证据原则”:
- 只引用能够证明问题成立的最小文本片段 - 只引用能够证明该判断成立的最小文本片段;
- 优先引用单句或关键子句 - 优先引用单句或关键子句
- 不得复制整段条款 - 不得复制整段条款
引用长度限制: 引用长度限制:
- 推荐:20–80 字 - 推荐:20–80 字
- 最大:120 字 - 最大:120 字
- 若一句话即可证明问题,则只允许引用该句 - 若一句话即可证明该判断,则只允许引用该句
生成 finding 时必须执行: 生成 finding 时必须执行:
Step 1:定位能够证明问题成立的关键句 Step 1:定位能够证明结果成立的关键句
Step 2:仅提取该句作为 original_text Step 2:仅提取该句作为 original_text
Step 3:再分析 issue Step 3:再判断该句对应 result="合格" 还是 result="不合格"
Step 4:输出 issue 与 suggestion
禁止: 禁止:
- 复制整段条款 - 复制整段条款
- 引用超过 3 句文本 - 引用超过 3 句文本
- 引用与 issue 无关的上下文 - 引用与该判断无关的上下文
【issue 要求】
- issue 必须说明:该条款为什么合格或为什么不合格。
- 当 result="合格" 时,issue 应说明该表述满足了什么规则要求、为什么可认定为合格。
- 当 result="不合格" 时,issue 应说明该表述违反了什么规则要求、为什么构成风险或缺陷。
- issue 必须紧扣规则和原文,不得空泛评价。
【建议要求】 【建议要求】
- suggestion 必须具体、可执行。 - suggestion 必须具体、可执行。
- 若能在当前分段内直接修正,请给出可直接替换或新增的条款措辞。 - 当 result="不合格" 时:
- 若需联动其他条款,允许给出明确修改方向和应补充的关键要素,但不得只写“建议协商”“建议完善”等空泛表述。 - 若能在当前分段内直接修正,请给出可直接替换或新增的条款措辞;
- 若无法直接改写,请给出明确修改方向和应补充的关键要素;
- 不得只写“建议协商”“建议完善”等空泛表述。
- 当 result="合格" 时:
- suggestion 应简洁填写,可写“无需修改”;
- 不得为了凑内容而提出与审查结论无关的修改建议。
【规则适用性判断】 【规则适用性判断】
在执行任何审查规则之前,你必须先判断: 在执行任何审查规则之前,你必须先判断:
...@@ -62,20 +78,18 @@ Step 3:再分析 issue ...@@ -62,20 +78,18 @@ Step 3:再分析 issue
如果当前分段与该审查规则无关,则: 如果当前分段与该审查规则无关,则:
- 不得生成 finding - 不得生成 finding
- 不得引用原文 - 不得引用原文
- 直接返回 {"findings": []}。 - 继续检查下一条规则
如果所有规则均无关或均无证据充分的结论,则返回 {"findings": []}。
【输出约束】 【输出约束】
- 严格按照指定 JSON Schema 输出。 - 严格按照指定 JSON Schema 输出。
- 不得输出任何 JSON 之外的解释性文字。 - 不得输出任何 JSON 之外的解释性文字。
- 若未发现证据充分的问题,返回 {"findings": []}。 - 若未发现证据充分的合格或不合格条款,返回 {"findings": []}。
''' '''
REVIEW_USER_PROMPT = ''' REVIEW_USER_PROMPT = '''
【当前分段文本】 【当前分段文本】
{segment_text} {segment_text}
【上下文记忆(来自已审分段)】
{context_memories_json}
【合同立场】 【合同立场】
站在 {party_role} 的立场进行审查。 站在 {party_role} 的立场进行审查。
...@@ -83,14 +97,16 @@ REVIEW_USER_PROMPT = ''' ...@@ -83,14 +97,16 @@ REVIEW_USER_PROMPT = '''
{ruleset_text} {ruleset_text}
【任务】 【任务】
请基于审查规则,审查当前分段,识别证据充分的问题,并输出可执行修改建议 请基于审查规则,仅针对当前分段文本进行审查,提取证据充分的合格条款和不合格条款,并输出审查结果
【特别要求】 【特别要求】
- 仅输出证据充分的问题。 - 仅基于当前分段原文进行判断,不得参考任何上下文、摘要或记忆信息。
- 如果问题来自与上下文记忆的冲突,必须确保冲突是明确、可由文本直接支持的。 - findings 中每一项都必须包含 result 字段,且 result 只能为 "合格" 或 "不合格"。
- findings 中的 original_text 必须为合同原文直接引用。 - 只有当前分段中存在与规则相关且证据充分的内容时,才输出 finding。
- suggestion 应尽量提供可直接落地的修改文本;若无法安全地直接改写,请给出明确的修改方向和应补充的关键要素。 - findings 中的 original_text 必须为合同原文直接引用,且应为最小充分证据片段。
- 若无问题,返回 {{"findings": []}}。 - 当 result="合格" 时,suggestion 填写“无需修改”。
- 当 result="不合格" 时,suggestion 应尽量提供可直接落地的修改文本;若无法安全地直接改写,请给出明确的修改方向和应补充的关键要素。
- 若无相关或无证据充分的结论,返回 {{"findings": []}}。
【输出要求】 【输出要求】
- 仅输出 JSON。 - 仅输出 JSON。
...@@ -100,39 +116,18 @@ REVIEW_OUTPUT_SCHEMA = ''' ...@@ -100,39 +116,18 @@ REVIEW_OUTPUT_SCHEMA = '''
{ {
"findings": [ "findings": [
{ {
"issue": "详细的风险描述,为什么该问题构成风险,需基于规则和文本解释", "rule_title": "对应审查规则标题",
"result": "合格 或 不合格",
"issue": "对该条款为何合格或为何不合格的详细说明,需基于规则与案例文本解释",
"original_text": "合同原文片段的直接引用", "original_text": "合同原文片段的直接引用",
"suggestion": "可直接替换原文或新增的条款措辞" "suggestion": "若 result=合格 则填写“无需修改”;若 result=不合格 则填写可直接替换原文或新增的条款措辞"
} }
] ]
} }
```
''' '''
LEVEL_WEIGHT = {"H": 3, "M": 2, "L": 1} LEVEL_WEIGHT = {"H": 3, "M": 2, "L": 1}
ROOT_CAUSE_RULES = {
"ACC_UNCLEAR": ["验收", "标准", "文件", "通过", "确认"],
"PAY_TRIGGER": ["付款", "支付", "触发", "条件"],
"PAY_DELAY": ["逾期", "迟延", "未按期"],
"LIAB_CAP": ["责任", "上限", "限制", "赔偿"],
"BREACH": ["违约", "违约金"],
"TERM": ["解除", "终止", "解约"],
"CONF": ["保密"],
"IP": ["知识产权", "成果", "许可"],
}
GENERIC_SUGGESTIONS = [
"建议协商", "建议完善", "建议明确",
"进一步明确", "双方协商"
]
def _root_cause(issue: str) -> str:
t = _norm(issue)
for key, kws in ROOT_CAUSE_RULES.items():
if any(k in t for k in kws):
return key
return "OTHER"
def _norm(text: str) -> str: def _norm(text: str) -> str:
if not text: if not text:
...@@ -146,29 +141,10 @@ def _norm(text: str) -> str: ...@@ -146,29 +141,10 @@ def _norm(text: str) -> str:
def _has_evidence(f: Dict) -> bool: def _has_evidence(f: Dict) -> bool:
return bool(f.get("original_text")) return bool(f.get("original_text"))
def _is_generic_suggestion(text: str) -> bool:
t = _norm(text)
return any(g in t for g in GENERIC_SUGGESTIONS)
@tool("segment_review", "合同分段审查") @tool("segment_review", "合同分段审查")
class SegmentReviewTool(LLMTool): class SegmentReviewTool(LLMTool):
def __init__(self): def __init__(self):
super().__init__(REVIEW_SYSTEM_PROMPT) super().__init__(REVIEW_SYSTEM_PROMPT)
self.default_ruleset_id = DEFAULT_RULESET_ID
self.column_map = {
"id": "ID",
"title": "审查项",
"rule": "审查规则",
"level": "风险等级",
"triggers": "触发词",
"suggestion_template": "建议模板",
}
rules_path = Path(__file__).resolve().parent.parent.parent / "data" / "rules.xlsx"
self.rulesets: Dict[str, List[Dict]] = {}
for rs_id in ALL_RULESET_IDS:
rules = ExcelUtil.load_mapped_excel(rules_path, sheet_name=rs_id, column_map=self.column_map)
self.rulesets[rs_id] = rules
@tool_func( @tool_func(
{ {
...@@ -176,23 +152,29 @@ class SegmentReviewTool(LLMTool): ...@@ -176,23 +152,29 @@ class SegmentReviewTool(LLMTool):
"properties": { "properties": {
"segment_id": {"type": "int"}, "segment_id": {"type": "int"},
"segment_text": {"type": "string"}, "segment_text": {"type": "string"},
"ruleset_id": {"type": "string"}, "rules": {"type": "array", "items": {"type": "object"}},
"routed_rule_titles": {"type": "array", "items": {"type": "string"}}, "merge_rules_prompt": {"type": "boolean"},
"party_role": {"type": "string"}, "party_role": {"type": "string"},
"context_summaries": {"type": "array"},
"context_memories": {"type": "array"}, "context_memories": {"type": "array"},
}, },
"required": ["segment_id", "segment_text", "ruleset_id", "party_role"], "required": ["segment_id", "segment_text", "rules", "party_role"],
} }
) )
def run(self, segment_id: str, segment_text: str, ruleset_id: str, party_role: str, def run(self, segment_id: str, segment_text: str, rules: List[Dict], party_role: str,
routed_rule_titles: Optional[List[str]] = None, context_memories: Optional[List[Dict]] = None) -> Dict: context_summaries: Optional[List[Dict]] = None,
full_rules = self.rulesets.get(ruleset_id) or self.rulesets.get(self.default_ruleset_id, []) or [] context_memories: Optional[List[Dict]] = None,
if routed_rule_titles is not None: merge_rules_prompt: bool = True) -> Dict:
title_set = {title for title in routed_rule_titles if isinstance(title, str)} rules = rules or []
rules = [r for r in full_rules if r.get("title") in title_set] result = self._evaluate_rules(
else: party_role,
rules = full_rules segment_id,
result = self._evaluate_rules(party_role,segment_id,segment_text,rules, context_memories) segment_text,
rules,
context_summaries,
context_memories,
merge_rules_prompt=merge_rules_prompt,
)
overall = "revise" if (result["findings"] ) else "pass" overall = "revise" if (result["findings"] ) else "pass"
...@@ -204,25 +186,45 @@ class SegmentReviewTool(LLMTool): ...@@ -204,25 +186,45 @@ class SegmentReviewTool(LLMTool):
def _stringify_rule(self, rule:Dict) -> str: def _stringify_rule(self, rule:Dict) -> str:
res = '' res = ''
res += f"审查项: {rule.get('title','')}\n" res += f"## 审查项标题\n{rule.get('title','')}\n"
res += f"审查规则: {rule.get('rule','')}\n" res += f"## 审查规则\n{rule.get('rule','')}\n"
res += f"风险等级: {rule.get('level','')}\n" res += f"## 风险等级\n{rule.get('level','')}\n"
res += f"触发词: {rule.get('triggers','')}\n" res += f"## 建议模板\n{rule.get('suggestion_template','')}\n"
res += f"建议模板: {rule.get('suggestion_template','')}\n" res += f"## 参考案例\n{rule.get('case','')}\n"
return res return res
def _build_prompt(self, party_role: str, rule: Dict, segment_id: int, segment_text: str, context_memories: Optional[List[Dict]]) -> List[Dict[str, str]]: def _build_prompt(self, party_role: str, rule: Dict, segment_id: int, segment_text: str,
context_summaries: Optional[List[Dict]], context_memories: Optional[List[Dict]]) -> List[Dict[str, str]]:
user_content = REVIEW_USER_PROMPT.format( user_content = REVIEW_USER_PROMPT.format(
segment_id=segment_id, segment_id=segment_id,
segment_text=segment_text, segment_text=segment_text,
party_role=party_role, party_role=party_role,
context_memories_json=json.dumps(context_memories or [], ensure_ascii=False), # context_facts_json=json.dumps(context_summaries or [], ensure_ascii=False),
# context_memories_json=json.dumps(context_memories or [], ensure_ascii=False),
ruleset_text=self._stringify_rule(rule) ruleset_text=self._stringify_rule(rule)
) + REVIEW_OUTPUT_SCHEMA ) + REVIEW_OUTPUT_SCHEMA
return self.build_messages(user_content) return self.build_messages(user_content)
async def _evaluate_rules_async(self, party_role: str, segment_id: int, segment_text: str, rules: List[Dict], context_memories: Optional[List[Dict]]) -> Dict[str, List[Dict]]: def _build_prompt_with_rules(self, party_role: str, rules: List[Dict], segment_id: int, segment_text: str,
msgs = [self._build_prompt(party_role,rule, segment_id, segment_text, context_memories) for rule in rules] context_summaries: Optional[List[Dict]], context_memories: Optional[List[Dict]]) -> List[Dict[str, str]]:
ruleset_text = "\n\n".join([self._stringify_rule(rule) for rule in rules])
user_content = REVIEW_USER_PROMPT.format(
segment_id=segment_id,
segment_text=segment_text,
party_role=party_role,
context_facts_json=json.dumps(context_summaries or [], ensure_ascii=False),
context_memories_json=json.dumps(context_memories or [], ensure_ascii=False),
ruleset_text=ruleset_text,
) + REVIEW_OUTPUT_SCHEMA
return self.build_messages(user_content)
async def _evaluate_rules_async(self, party_role: str, segment_id: int, segment_text: str, rules: List[Dict],
context_summaries: Optional[List[Dict]], context_memories: Optional[List[Dict]],
merge_rules_prompt: bool = False) -> Dict[str, List[Dict]]:
if merge_rules_prompt:
msgs = [self._build_prompt_with_rules(party_role, rules, segment_id, segment_text, context_summaries, context_memories)]
else:
msgs = [self._build_prompt(party_role, rule, segment_id, segment_text, context_summaries, context_memories) for rule in rules]
if not msgs: if not msgs:
return {"findings": []} return {"findings": []}
...@@ -232,11 +234,22 @@ class SegmentReviewTool(LLMTool): ...@@ -232,11 +234,22 @@ class SegmentReviewTool(LLMTool):
return {"findings": []} return {"findings": []}
all_findings: List[Dict] = [] all_findings: List[Dict] = []
if merge_rules_prompt:
data = self.parse_first_json(responses[0])
findings = data.get("findings", []) or []
for f in findings:
if "rule_title" not in f:
f["rule_title"] = ""
if "risk_level" not in f:
f["risk_level"] = ""
all_findings.extend(findings)
else:
for idx,resp in enumerate(responses): for idx,resp in enumerate(responses):
data = self.parse_first_json(resp) data = self.parse_first_json(resp)
rule_title = rules[idx].get("title","") rule_title = rules[idx].get("title","")
rule_level = rules[idx].get("level","") rule_level = rules[idx].get("level","")
findings = data.get("findings", []) or [] findings = data.get("findings", []) or []
# 为每条 finding 添加对应的 rule_title 和 risk_level
for f in findings: for f in findings:
f["rule_title"] = rule_title f["rule_title"] = rule_title
f["risk_level"] = rule_level f["risk_level"] = rule_level
...@@ -246,101 +259,53 @@ class SegmentReviewTool(LLMTool): ...@@ -246,101 +259,53 @@ class SegmentReviewTool(LLMTool):
"findings": all_findings, "findings": all_findings,
} }
def _evaluate_rules(self, party_role: str, segment_id: int, segment_text: str, rules: List[Dict], context_memories: Optional[List[Dict]]) -> Dict[str, List[Dict]]: def _evaluate_rules(self, party_role: str, segment_id: int, segment_text: str, rules: List[Dict],
context_summaries: Optional[List[Dict]], context_memories: Optional[List[Dict]],
merge_rules_prompt: bool = False) -> Dict[str, List[Dict]]:
try: try:
return asyncio.run(self._evaluate_rules_async(party_role, segment_id, segment_text, rules, context_memories)) return asyncio.run(self._evaluate_rules_async(
party_role,
segment_id,
segment_text,
rules,
context_summaries,
context_memories,
merge_rules_prompt=merge_rules_prompt,
))
except RuntimeError: except RuntimeError:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
return loop.run_until_complete(self._evaluate_rules_async(party_role, segment_id, segment_text, rules, context_memories)) return loop.run_until_complete(self._evaluate_rules_async(
party_role,
segment_id,
segment_text,
rules,
context_summaries,
context_memories,
merge_rules_prompt=merge_rules_prompt,
))
def filter(self,
segment_id: str,
review_result: Dict) -> Dict:
findings: List[Dict] = review_result.get("findings", [])
facts = review_result.get("facts", [])
missing_info: List[str] = []
# 1. 硬过滤:无证据直接丢
findings = [f for f in findings if _has_evidence(f)]
if not findings:
return {
"segment_id": segment_id,
"findings": [],
"facts": facts
}
# 2. 去重(issue 文本级)
dedup = {}
for f in findings:
k = _norm(f.get("issue", ""))
if k not in dedup:
dedup[k] = f
else:
old = dedup[k]
if LEVEL_WEIGHT.get(f["level"], 1) > LEVEL_WEIGHT.get(old["level"], 1):
dedup[k] = f
findings = list(dedup.values())
# 3. 按根因聚类
clusters = {}
for f in findings:
key = _root_cause(f.get("issue", ""))
clusters.setdefault(key, []).append(f)
# 4. 聚类合并
merged = []
for fs in clusters.values():
fs = sorted(
fs,
key=lambda x: (
LEVEL_WEIGHT.get(x.get("level", "L"), 1),
len(x.get("evidence_quotes", []))
),
reverse=True
)
best = fs[0]
merged.append({
"rule_titles": [f.get("rule_title") for f in fs],
"issue": best.get("issue"),
"level": best.get("level"),
"evidence_quotes": list(
{q for f in fs for q in f.get("evidence_quotes", [])}
),
"suggestion": next(
(f["suggestion"] for f in fs
if f.get("suggestion") and not _is_generic_suggestion(f["suggestion"])),
best.get("suggestion")
)
})
# 5. 排序 + 截断(最多 5 条)
merged.sort(
key=lambda x: (
LEVEL_WEIGHT.get(x["level"], 1),
len(x["evidence_quotes"]),
len(x["rule_titles"])
),
reverse=True
)
final_findings = merged[:5]
return {
"segment_id": segment_id,
"findings": final_findings,
"missing_info": missing_info
}
if __name__=="__main__": if __name__=="__main__":
tool = SegmentReviewTool() tool = SegmentReviewTool()
result = tool.run( result = tool.run(
segment_id=1, segment_id=1,
segment_text="本合同自双方签字盖章之日起生效,有效期为两年,期满后自动续展一年,除非一方提前30天书面通知对方终止。", segment_text="本合同自双方签字盖章之日起生效,有效期为两年,期满后自动续展一年,除非一方提前30天书面通知对方终止。",
ruleset_id="通用", rules=[
{
"title": "期限与续展条款清晰性",
"rule": "合同期限、续展触发与终止通知期限应明确,避免自动续展引发不确定性。",
"level": "M",
"suggestion_template": "明确续展条件、通知方式与生效时间。",
"case": "自动续展条款未约定通知送达标准导致争议。",
}
],
party_role="甲方", party_role="甲方",
context_summaries=[
{
"付款": {"方式": "银行转账", "期限": "验收后30日内"},
"META": {"segment_id": 0}
}
],
context_memories=[ context_memories=[
{ {
"segment_id": 0, "segment_id": 0,
......
...@@ -2,13 +2,10 @@ from __future__ import annotations ...@@ -2,13 +2,10 @@ from __future__ import annotations
import json import json
import re import re
from pathlib import Path
from typing import Dict, List, Optional from typing import Dict, List, Optional
from core.config import ALL_RULESET_IDS, DEFAULT_RULESET_ID
from core.tool import tool, tool_func from core.tool import tool, tool_func
from core.tools.segment_llm import LLMTool from core.tools.segment_llm import LLMTool
from utils.excel_util import ExcelUtil
ROUTER_SYSTEM_PROMPT = ''' ROUTER_SYSTEM_PROMPT = '''
...@@ -42,7 +39,7 @@ ROUTER_USER_PROMPT = ''' ...@@ -42,7 +39,7 @@ ROUTER_USER_PROMPT = '''
【合同立场】 【合同立场】
{party_role} {party_role}
【候选审查规则 【候选审查规则
{candidate_rules_json} {candidate_rules_json}
【任务】 【任务】
...@@ -68,20 +65,6 @@ ROUTER_OUTPUT_SCHEMA = ''' ...@@ -68,20 +65,6 @@ ROUTER_OUTPUT_SCHEMA = '''
class SegmentRuleRouterTool(LLMTool): class SegmentRuleRouterTool(LLMTool):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__(ROUTER_SYSTEM_PROMPT) super().__init__(ROUTER_SYSTEM_PROMPT)
self.default_ruleset_id = DEFAULT_RULESET_ID
self.column_map = {
"id": "ID",
"title": "审查项",
"rule": "审查规则",
"level": "风险等级",
"triggers": "触发词",
"suggestion_template": "建议模板",
}
rules_path = Path(__file__).resolve().parent.parent.parent / "data" / "rules.xlsx"
self.rulesets: Dict[str, List[Dict]] = {}
for rs_id in ALL_RULESET_IDS:
rules = ExcelUtil.load_mapped_excel(rules_path, sheet_name=rs_id, column_map=self.column_map)
self.rulesets[rs_id] = rules
@tool_func( @tool_func(
{ {
...@@ -89,22 +72,22 @@ class SegmentRuleRouterTool(LLMTool): ...@@ -89,22 +72,22 @@ class SegmentRuleRouterTool(LLMTool):
"properties": { "properties": {
"segment_id": {"type": "int"}, "segment_id": {"type": "int"},
"segment_text": {"type": "string"}, "segment_text": {"type": "string"},
"ruleset_id": {"type": "string"}, "rules": {"type": "array", "items": {"type": "object"}},
"party_role": {"type": "string"}, "party_role": {"type": "string"},
"context_memories": {"type": "array"}, "context_memories": {"type": "array"},
}, },
"required": ["segment_id", "segment_text", "ruleset_id", "party_role"], "required": ["segment_id", "segment_text", "rules", "party_role"],
} }
) )
def run( def run(
self, self,
segment_id: int, segment_id: int,
segment_text: str, segment_text: str,
ruleset_id: str, rules: List[Dict],
party_role: str, party_role: str,
context_memories: Optional[List[Dict]] = None, context_memories: Optional[List[Dict]] = None,
) -> Dict: ) -> Dict:
rules = self.rulesets.get(ruleset_id) or self.rulesets.get(self.default_ruleset_id, []) or [] rules = rules or []
routed_rules = self._route_rules( routed_rules = self._route_rules(
segment_text=segment_text, segment_text=segment_text,
rules=rules, rules=rules,
...@@ -113,7 +96,6 @@ class SegmentRuleRouterTool(LLMTool): ...@@ -113,7 +96,6 @@ class SegmentRuleRouterTool(LLMTool):
) )
return { return {
"segment_id": segment_id, "segment_id": segment_id,
"ruleset_id": ruleset_id,
"routed_rules": routed_rules, "routed_rules": routed_rules,
"routed_rule_titles": [r.get("title", "") for r in routed_rules], "routed_rule_titles": [r.get("title", "") for r in routed_rules],
} }
...@@ -121,10 +103,7 @@ class SegmentRuleRouterTool(LLMTool): ...@@ -121,10 +103,7 @@ class SegmentRuleRouterTool(LLMTool):
def _build_candidate_rules(self, rules: List[Dict]) -> List[Dict]: def _build_candidate_rules(self, rules: List[Dict]) -> List[Dict]:
return [ return [
{ {
"title": r.get("title", ""), r.get("title", ""): r.get("rule", "")
"level": r.get("level", ""),
"rule": r.get("rule", ""),
"triggers": r.get("triggers", ""),
} }
for r in rules for r in rules
if r.get("title") if r.get("title")
...@@ -223,6 +202,22 @@ class SegmentRuleRouterTool(LLMTool): ...@@ -223,6 +202,22 @@ class SegmentRuleRouterTool(LLMTool):
if __name__ == "__main__": if __name__ == "__main__":
tool = SegmentRuleRouterTool() tool = SegmentRuleRouterTool()
demo_rules = [
{
"id": "R1",
"title": "付款触发条件明确性",
"level": "H",
"rule": "付款应绑定明确触发条件和验收标准。",
"triggers": "支付,付款,验收",
},
{
"id": "R2",
"title": "违约责任对等性",
"level": "M",
"rule": "违约责任应当相对对等且违约金标准明确。",
"triggers": "违约,违约金",
},
]
demo_segment_text = ( demo_segment_text = (
"甲方应在合同签订后5个工作日内向乙方支付合同总价30%作为预付款," "甲方应在合同签订后5个工作日内向乙方支付合同总价30%作为预付款,"
"剩余70%在乙方完成交付并经甲方验收合格后30日内支付。" "剩余70%在乙方完成交付并经甲方验收合格后30日内支付。"
...@@ -232,7 +227,7 @@ if __name__ == "__main__": ...@@ -232,7 +227,7 @@ if __name__ == "__main__":
result = tool.run( result = tool.run(
segment_id=1, segment_id=1,
segment_text=demo_segment_text, segment_text=demo_segment_text,
ruleset_id="通用", rules=demo_rules,
party_role="甲方", party_role="甲方",
context_memories=[], context_memories=[],
) )
......
from __future__ import annotations from __future__ import annotations
import asyncio
import json import json
from typing import Dict, List, Optional from typing import Dict, List, Optional
from core.tool import tool, tool_func from core.tool import tool, tool_func
from core.tools.segment_llm import LLMTool from core.tools.segment_llm import LLMTool
from core.config import FACT_DIMENSIONS from core.config import META_KEY
SUMMARY_SYSTEM_PROMPT = f''' SUMMARY_SYSTEM_PROMPT = f'''
你是合同事实提取智能体(SegmentSummary)。 你是合同事实提取智能体(SegmentSummary)。
你的任务是从当前合同分段中提取“客观事实”,并按指定维度结构化输出。 你的任务是:**基于给定的审查规则,从当前合同分段中提取“与该规则直接相关的客观事实”,并结构化输出。**
【核心原则】
你必须严格围绕“规则所需信息”进行提取。
---
【事实定义】 【事实定义】
...@@ -24,37 +27,51 @@ SUMMARY_SYSTEM_PROMPT = f''' ...@@ -24,37 +27,51 @@ SUMMARY_SYSTEM_PROMPT = f'''
3. 不得补充未出现的主体、条件或数值; 3. 不得补充未出现的主体、条件或数值;
4. 允许对原文做最小结构化拆分(例如金额、比例、期限)。 4. 允许对原文做最小结构化拆分(例如金额、比例、期限)。
---
【规则驱动提取要求(关键)】
- 仅提取“该审查规则执行所需要的信息字段”
- 不得提取与该规则无关的信息(即使这些信息在文本中存在)
- 若规则未涉及某类信息,则不得输出对应字段
- 若规则涉及某字段但文本未出现,需显式标记为 "未明确"
---
【输出结构】 【输出结构】
- 输出字段:facts - 输出字段:facts
- facts 是一个对象 - facts 是一个对象
- 键为以下预设维度: - 键必须来自【规则字段定义(rule_fields)】
- 不得使用预设通用维度(如“支付/违约责任”等)
---
{", ".join(FACT_DIMENSIONS)} 【字段填充规则】
- 每个维度值必须是对象或对象列表 - 每个字段值必须是对象或对象列表
- 未出现的维度可以省略 - 不得输出字符串作为字段值
- 字段内容必须为原文的最小结构化表达
- 不得改写原文含义
【结构规则】 ---
- 仅提取对合同履行或责任具有实际意义的事实 【缺失信息处理(非常重要)】
- 不得输出字符串作为维度值,必须使用对象
- 不得输出解释、总结或风险判断
【上下文事实使用规则】 - 若规则要求的字段在当前分段未出现:
→ 必须输出该字段,并标记为:
上下文事实仅用于: "未明确"
- 避免重复提取已存在的事实
- 保持字段命名一致
不得: (用于后续审查判断)
- 使用上下文事实补充当前分段没有出现的信息
- 修改当前分段原文事实 ---
【约束】 【约束】
- 严禁编造信息 - 严禁编造信息
- 严禁推断未出现的内容 - 严禁推断未出现的内容
- 不得输出风险判断或解释
- 严格输出 JSON - 严格输出 JSON
''' '''
...@@ -62,11 +79,17 @@ SUMMARY_USER_PROMPT = ''' ...@@ -62,11 +79,17 @@ SUMMARY_USER_PROMPT = '''
【分段原文】 【分段原文】
{segment_text} {segment_text}
【上下文事实】 【规则字段定义(仅提取这些字段)】
{context_facts} {rule_fields}
【任务】
请仅提取“当前分段中,与候选审查规则直接相关的客观事实”。
仅提取当前分段中明确出现的客观事实。 【特别要求】
不得从上下文事实中补充新的信息。 - facts 的顶层 key 必须是规则 title
- 每个规则下仅保留与该规则直接相关的信息
- 若某规则在当前分段未出现关键信息,输出该规则并标记为 "未明确"
- 不得提取与规则无关的信息
输出 JSON。 输出 JSON。
''' '''
...@@ -75,8 +98,8 @@ OUTPUT_EXAMPLE = ''' ...@@ -75,8 +98,8 @@ OUTPUT_EXAMPLE = '''
```json ```json
{ {
"facts": { "facts": {
"支付": {"方式": "银行转账", "时间": "验收后30日内"}, "支付审查": {"方式": "银行转账", "时间": "验收后30日内"},
"违约责任": {"违约金比例": "合同总金额的5%"} "违约责任审查": {"违约金比例": "合同总金额的5%"}
} }
} }
``` ```
...@@ -94,28 +117,54 @@ class SegmentSummaryTool(LLMTool): ...@@ -94,28 +117,54 @@ class SegmentSummaryTool(LLMTool):
"properties": { "properties": {
"segment_id": {"type": "int"}, "segment_id": {"type": "int"},
"segment_text": {"type": "string"}, "segment_text": {"type": "string"},
"rules": {"type": "array", "items": {"type": "object"}},
"party_role": {"type": "string"}, "party_role": {"type": "string"},
"context_facts": {"type": "object"}, "context_facts": {"type": "object"},
}, },
"required": ["segment_id", "segment_text"], "required": ["segment_id", "segment_text", "rules"],
} }
) )
def run( def run(
self, self,
segment_id: int, segment_id: int,
segment_text: str, segment_text: str,
rules: List[Dict],
party_role: str = "", party_role: str = "",
context_facts: Optional[Dict] = None, context_facts: Optional[Dict] = None,
) -> Dict: ) -> Dict:
rules = rules or []
try: try:
return self.run_with_loop(self._summarize_async(segment_id, segment_text, party_role, context_facts)) return self.run_with_loop(
self._summarize_async(segment_id, segment_text, rules, party_role, context_facts)
)
except Exception: except Exception:
return {} return {}
def _build_prompt(self, segment_text: str, context_facts: Optional[Dict], party_role: str) -> List[Dict[str, str]]: def _stringify_rule(self, rules: List[Dict]) -> str:
lines = []
for r in rules:
id = r.get("id", "")
rule_text = r.get("rule", "")
lines.append(f"规则ID: {id}\n审查规则: {rule_text}\n")
return "\n".join(lines)
def _build_prompt(
self,
segment_text: str,
rules: List[Dict],
context_facts: Optional[Dict],
party_role: str,
) -> List[Dict[str, str]]:
# 获取规则字段定义
rule_fields = [
r.get("summary") for r in rules
if r.get("summary")
]
user_content = SUMMARY_USER_PROMPT.format( user_content = SUMMARY_USER_PROMPT.format(
segment_text=segment_text, segment_text=segment_text,
context_facts=json.dumps(context_facts or {}, ensure_ascii=False), # party_role=party_role,
# rules_json=self._stringify_rule(rules),
rule_fields=json.dumps(rule_fields, ensure_ascii=False),
) + OUTPUT_EXAMPLE ) + OUTPUT_EXAMPLE
return self.build_messages(user_content) return self.build_messages(user_content)
...@@ -123,30 +172,30 @@ class SegmentSummaryTool(LLMTool): ...@@ -123,30 +172,30 @@ class SegmentSummaryTool(LLMTool):
self, self,
segment_id: int, segment_id: int,
segment_text: str, segment_text: str,
rules: List[Dict],
party_role: str, party_role: str,
context_facts: Optional[Dict], context_facts: Optional[Dict],
) -> Dict: ) -> Dict:
msgs = self._build_prompt(segment_text, context_facts, party_role) msgs = self._build_prompt(segment_text, rules, context_facts, party_role)
final_facts: Dict = {}
try: try:
resp = await self.chat_async(msgs) resp = await self.chat_async(msgs)
# print("segment summary response:", resp)
data = self.parse_first_json(resp) data = self.parse_first_json(resp)
facts = data.get("facts") or {} facts = data.get("facts") or {}
except Exception: except Exception as e:
print(f'Error in segment summary for segment {segment_id}: {e}')
facts = {} facts = {}
# print(f'SegmentSummaryTool facts: {facts}') facts[META_KEY] = {
if isinstance(facts,list): "segment_id": segment_id,
final_facts['内容'] = facts }
else: return facts
final_facts = facts
final_facts['segment_id'] = segment_id
return final_facts
if __name__=='__main__': if __name__=='__main__':
tool = SegmentSummaryTool() tool = SegmentSummaryTool()
res = tool.run( res = tool.run(
segment_id=1, segment_id=1,
segment_text="甲方应于合同签订之日起30日内向乙方支付合同总金额的50%,余款在货物验收合格后30日内付清.", segment_text="甲方应于合同签订之日起30日内向乙方支付合同总金额的50%,余款在货物验收合格后30日内付清.",
rules=[{"id": "R1", "title": "付款", "rule": "付款相关", "summary": "付款方式"}],
context_facts={}, context_facts={},
) )
print(res) print(res)
\ No newline at end of file
...@@ -14,7 +14,7 @@ from utils.http_util import upload_file, fastgpt_openai_chat, download_file ...@@ -14,7 +14,7 @@ from utils.http_util import upload_file, fastgpt_openai_chat, download_file
SUFFIX='_麓发改进' SUFFIX='_麓发改进'
batch_input_dir_path = 'jp-input' batch_input_dir_path = 'jp-input'
batch_output_dir_path = 'jp-output-lufa' batch_output_dir_path = 'jp-all'
batch_size = 5 batch_size = 5
# 麓发fastgpt接口 # 麓发fastgpt接口
# url = 'http://192.168.252.71:18089/api/v1/chat/completions' # url = 'http://192.168.252.71:18089/api/v1/chat/completions'
...@@ -31,9 +31,6 @@ token = 'fastgpt-vykT6qs07g7hR4tL2MNJE6DdNCIxaQjEu3Cxw9nuTBFg8MAG3CkByvnXKxSNEyM ...@@ -31,9 +31,6 @@ token = 'fastgpt-vykT6qs07g7hR4tL2MNJE6DdNCIxaQjEu3Cxw9nuTBFg8MAG3CkByvnXKxSNEyM
def extract_url(text): def extract_url(text):
# \s * ([ ^ "\s]+?\.(?:docx?|pdf|xlsx)) # \s * ([ ^ "\s]+?\.(?:docx?|pdf|xlsx))
# 麓发正则
# excel_p, doc_p = r'导出Excel结果\s*([^"]*xlsx)', r'导出Doc结果\s*([^\" ]+?\.(?:docx?|pdf|wps))'
# 金盘正则
excel_p, doc_p = r'最终审查Excel\s*([^"]*xlsx)', r'最终审查批注\s*([^\" ]+?\.(?:docx?|pdf|wps))' excel_p, doc_p = r'最终审查Excel\s*([^"]*xlsx)', r'最终审查批注\s*([^\" ]+?\.(?:docx?|pdf|wps))'
# 使用 re.search() 查找第一个匹配项 # 使用 re.search() 查找第一个匹配项
excel_m, doc_m = re.search(excel_p, text), re.search(doc_p, text) excel_m, doc_m = re.search(excel_p, text), re.search(doc_p, text)
......
...@@ -69,6 +69,8 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None: ...@@ -69,6 +69,8 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None:
best_score = -1 best_score = -1
for idx, cand in enumerate(candidates): for idx, cand in enumerate(candidates):
ans_text = ans_text.strip() ans_text = ans_text.strip()
if cand is None or not isinstance(cand,str):
continue
cand = cand.strip() cand = cand.strip()
score = max( score = max(
fuzz.partial_ratio(ans_text, cand), fuzz.partial_ratio(ans_text, cand),
...@@ -99,6 +101,7 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None: ...@@ -99,6 +101,7 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None:
unmatched_val_count = sum(len(v) for v in unmatched_val_by_item.values()) unmatched_val_count = sum(len(v) for v in unmatched_val_by_item.values())
unmatched_answer_count = sum(len(v) for v in unmatched_answer_by_item.values()) unmatched_answer_count = sum(len(v) for v in unmatched_answer_by_item.values())
file_false_positive_rate = (unmatched_val_count / val_total) if val_total != 0 else 0
# 累加到各“审查项”的全局统计 # 累加到各“审查项”的全局统计
for it, cnt in answer_counts.items(): for it, cnt in answer_counts.items():
...@@ -112,7 +115,7 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None: ...@@ -112,7 +115,7 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None:
print('#' * 40) print('#' * 40)
print( print(
f"{val_file.name}: matched {matched_total} | val {val_total} | answer {answer_total} " f"{val_file.name}: matched {matched_total} | val {val_total} | answer {answer_total} "
f"| unmatched val {unmatched_val_count} | unmatched answer {unmatched_answer_count} | accuracy {matched_total / answer_total:.2%} | invalid_val {(unmatched_val_count / val_total) if val_total != 0 else 0:.2%}" f"| unmatched val {unmatched_val_count} | unmatched answer {unmatched_answer_count} | recall {matched_total / answer_total:.2%} | false_positive_rate {file_false_positive_rate:.2%}"
) )
for item in sorted(answer_counts): for item in sorted(answer_counts):
item_matches = matched_by_item.get(item, []) item_matches = matched_by_item.get(item, [])
...@@ -133,10 +136,10 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None: ...@@ -133,10 +136,10 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None:
for t in uv: for t in uv:
print(f" val: {t}") print(f" val: {t}")
# break # only first file for demo # break # only first file for demo
accuracy = overall_matched / overall_answer if overall_answer else 0 recall = overall_matched / overall_answer if overall_answer else 0
invalid_val = (overall_val - overall_matched) / overall_val if overall_val else 0 overall_false_positive_rate = (overall_val - overall_matched) / overall_val if overall_val else 0
print( print(
f"Overall: matched {overall_matched} | val {overall_val} | answer {overall_answer} | accuracy {accuracy:.2%} | invalid_val {invalid_val:.2%}" f"Overall: matched {overall_matched} | val {overall_val} | answer {overall_answer} | recall {recall:.2%} | false_positive_rate {overall_false_positive_rate:.2%}"
) )
# 按“审查项”的 overall 结果 # 按“审查项”的 overall 结果
...@@ -151,7 +154,7 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None: ...@@ -151,7 +154,7 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None:
u_ans = overall_item_unmatched_answer.get(it, 0) u_ans = overall_item_unmatched_answer.get(it, 0)
u_val = overall_item_unmatched_val.get(it, 0) u_val = overall_item_unmatched_val.get(it, 0)
acc = (mat / ans) if ans else 0 acc = (mat / ans) if ans else 0
invalid_val = u_val / (mat + u_val) if (mat + u_val) else 0 item_false_positive_rate = u_val / (mat + u_val) if (mat + u_val) else 0
rows_by_item.append({ rows_by_item.append({
"审查项": it, "审查项": it,
"大模型匹配上的不合格项": mat, "大模型匹配上的不合格项": mat,
...@@ -159,13 +162,13 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None: ...@@ -159,13 +162,13 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None:
"大模型其他不合格项": u_val, "大模型其他不合格项": u_val,
"大模型未匹配上的不合格项(C-B)": u_ans, "大模型未匹配上的不合格项(C-B)": u_ans,
"查全率(B/C)": acc, "查全率(B/C)": acc,
"无关审查率(D/B+D)": invalid_val, "误报率(D/B+D)": item_false_positive_rate,
}) })
print( print(
f" 审查项 {it}: matched {mat} / answer {ans} | unmatched val {u_val} | unmatched answer {u_ans} | accuracy {acc:.2%} | invalid_val {invalid_val:.2%}" f" 审查项 {it}: matched {mat} / answer {ans} | unmatched val {u_val} | unmatched answer {u_ans} | recall {acc:.2%} | false_positive_rate {item_false_positive_rate:.2%}"
) )
overall_by_item_df = pd.DataFrame(rows_by_item, columns=["审查项", "大模型匹配上的不合格项", "合同所有不合格项", "大模型其他不合格项", "大模型未匹配上的不合格项(C-B)", "查全率(B/C)", "无关审查率(D/B+D)"]) overall_by_item_df = pd.DataFrame(rows_by_item, columns=["审查项", "大模型匹配上的不合格项", "合同所有不合格项", "大模型其他不合格项", "大模型未匹配上的不合格项(C-B)", "查全率(B/C)", "误报率(D/B+D)"])
unmatched_val_total = sum(overall_item_unmatched_val.values()) unmatched_val_total = sum(overall_item_unmatched_val.values())
unmatched_answer_total = sum(overall_item_unmatched_answer.values()) unmatched_answer_total = sum(overall_item_unmatched_answer.values())
overall_invalid_rate = unmatched_val_total / (overall_matched + unmatched_val_total) if (overall_matched + unmatched_val_total) else 0 overall_invalid_rate = unmatched_val_total / (overall_matched + unmatched_val_total) if (overall_matched + unmatched_val_total) else 0
...@@ -176,10 +179,10 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None: ...@@ -176,10 +179,10 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None:
"合同所有不合格项": overall_answer, "合同所有不合格项": overall_answer,
"大模型其他不合格项": unmatched_val_total, "大模型其他不合格项": unmatched_val_total,
"大模型未匹配上的不合格项(C-B)": unmatched_answer_total, "大模型未匹配上的不合格项(C-B)": unmatched_answer_total,
"查全率(B/C)": accuracy, "查全率(B/C)": recall,
"无关审查率(D/B+D)": overall_invalid_rate, "误报率(D/B+D)": overall_invalid_rate,
} }
], columns=["审查项", "大模型匹配上的不合格项", "合同所有不合格项", "大模型其他不合格项", "大模型未匹配上的不合格项(C-B)", "查全率(B/C)", "无关审查率(D/B+D)"]) ], columns=["审查项", "大模型匹配上的不合格项", "合同所有不合格项", "大模型其他不合格项", "大模型未匹配上的不合格项(C-B)", "查全率(B/C)", "误报率(D/B+D)"])
combined_df = pd.concat([overall_by_item_df, overall_total_df], ignore_index=True) combined_df = pd.concat([overall_by_item_df, overall_total_df], ignore_index=True)
compare_dir_name = val_dir.name compare_dir_name = val_dir.name
......
...@@ -121,7 +121,7 @@ def _parse_args() -> argparse.Namespace: ...@@ -121,7 +121,7 @@ def _parse_args() -> argparse.Namespace:
parser.add_argument( parser.add_argument(
"--datasets-dir", "--datasets-dir",
type=Path, type=Path,
default=base / "results" / "jp-output", default=base / "results" / "jp-all-merge-prompt",
help="Directory containing Word files with annotations.", help="Directory containing Word files with annotations.",
) )
parser.add_argument( parser.add_argument(
...@@ -133,13 +133,13 @@ def _parse_args() -> argparse.Namespace: ...@@ -133,13 +133,13 @@ def _parse_args() -> argparse.Namespace:
parser.add_argument( parser.add_argument(
"--val-dir", "--val-dir",
type=Path, type=Path,
default=base / "results" / "jp-output-extracted", default=base / "results" / "jp-all-merge-prompt-extracted",
help="Directory to store extracted xlsx files for comparison.", help="Directory to store extracted xlsx files for comparison.",
) )
parser.add_argument( parser.add_argument(
"--strip-suffixes", "--strip-suffixes",
nargs="*", nargs="*",
default=['_人机交互'], default=['_麓发改进'],
help=( help=(
"Optional filename suffixes to strip from generated val xlsx stems before " "Optional filename suffixes to strip from generated val xlsx stems before "
"comparison, e.g. --strip-suffixes _v1 _审阅版" "comparison, e.g. --strip-suffixes _v1 _审阅版"
......
No preview for this file type
...@@ -10,15 +10,16 @@ import uvicorn ...@@ -10,15 +10,16 @@ import uvicorn
import traceback import traceback
from loguru import logger from loguru import logger
from utils.common_util import extract_url_file, format_now from utils.common_util import extract_url_file, format_now
from utils.http_util import download_file from utils.http_util import download_file
from core.cache import get_cached_doc_tool, get_cached_memory from core.cache import get_cached_doc_tool, get_cached_memory
from core.config import doc_support_formats, pdf_support_formats from core.config import doc_support_formats, pdf_support_formats,MERGE_RULE_PROMPT
from core.tools.segment_summary import SegmentSummaryTool from core.tools.segment_summary import SegmentSummaryTool
from core.tools.segment_review import SegmentReviewTool from core.tools.segment_review import SegmentReviewTool
from core.tools.segment_rule_router import SegmentRuleRouterTool from core.tools.segment_rule_router import SegmentRuleRouterTool
from core.tools.retrieve_reference import RetrieveReferenceTool
from core.tools.reflect_retry import ReflectRetryTool from core.tools.reflect_retry import ReflectRetryTool
from core.memory import RiskFinding
app = FastAPI(title="合同审查智能体", version="0.1.0") app = FastAPI(title="合同审查智能体", version="0.1.0")
TMP_DIR = Path(__file__).resolve().parent / "tmp" TMP_DIR = Path(__file__).resolve().parent / "tmp"
...@@ -26,6 +27,7 @@ TMP_DIR.mkdir(parents=True, exist_ok=True) ...@@ -26,6 +27,7 @@ TMP_DIR.mkdir(parents=True, exist_ok=True)
summary_tool = SegmentSummaryTool() summary_tool = SegmentSummaryTool()
review_tool = SegmentReviewTool() review_tool = SegmentReviewTool()
rule_router_tool = SegmentRuleRouterTool() rule_router_tool = SegmentRuleRouterTool()
reference_tool = RetrieveReferenceTool()
reflect_tool = ReflectRetryTool() reflect_tool = ReflectRetryTool()
...@@ -47,7 +49,7 @@ class DocumentParseRequest(BaseModel): ...@@ -47,7 +49,7 @@ class DocumentParseRequest(BaseModel):
class DocumentParseResponse(BaseModel): class DocumentParseResponse(BaseModel):
conversation_id: str conversation_id: str
chunk_ids: List[int] segment_ids: List[int]
ruleset_items: List[str] ruleset_items: List[str]
text: Optional[str] = None text: Optional[str] = None
file_ext: Optional[str] = None file_ext: Optional[str] = None
...@@ -78,15 +80,20 @@ async def parse_document(payload: DocumentParseRequest) -> DocumentParseResponse ...@@ -78,15 +80,20 @@ async def parse_document(payload: DocumentParseRequest) -> DocumentParseResponse
# ocr # ocr
await doc_obj.get_from_ocr() await doc_obj.get_from_ocr()
text = doc_obj.get_all_text() text = doc_obj.get_all_text()
chunk_ids = doc_obj.get_chunk_id_list() segment_ids = doc_obj.get_chunk_id_list()
# TODO: FastGPT BUG segment_ids必须从1开始,0开始会缺少第一段文本,后续需要修复
segment_ids = [idx + 1 for idx in segment_ids]
# get ruleset items # get ruleset items
ruleset_id = payload.ruleset_id or review_tool.default_ruleset_id ruleset_id = payload.ruleset_id or reference_tool.default_ruleset_id
ruleset_items = review_tool.rulesets.get(ruleset_id) or [] ruleset_items = reference_tool.run(ruleset_id=ruleset_id).get("rules", [])
ruleset_review_items = [r.get('title') for r in ruleset_items] ruleset_review_items = [
t for t in (r.get("title") for r in ruleset_items)
if isinstance(t, str) and t.strip()
]
return DocumentParseResponse( return DocumentParseResponse(
conversation_id=payload.conversation_id, conversation_id=payload.conversation_id,
text=text, text=text,
chunk_ids=chunk_ids, segment_ids=segment_ids,
ruleset_items=ruleset_review_items, ruleset_items=ruleset_review_items,
file_ext = file_ext file_ext = file_ext
) )
...@@ -97,6 +104,8 @@ class SegmentSummaryRequest(BaseModel): ...@@ -97,6 +104,8 @@ class SegmentSummaryRequest(BaseModel):
conversation_id: str conversation_id: str
segment_id: int segment_id: int
party_role: Optional[str] = "" party_role: Optional[str] = ""
ruleset_id: Optional[str] = "通用"
routed_rule_titles: Optional[List[str]] = None
file_ext: str file_ext: str
context_facts: Optional[Dict] = None context_facts: Optional[Dict] = None
...@@ -115,14 +124,21 @@ def summarize_facts(payload: SegmentSummaryRequest) -> SegmentSummaryResponse: ...@@ -115,14 +124,21 @@ def summarize_facts(payload: SegmentSummaryRequest) -> SegmentSummaryResponse:
except Exception as exc: except Exception as exc:
raise HTTPException(status_code=400, detail=f"Document tool not available: {exc}") raise HTTPException(status_code=400, detail=f"Document tool not available: {exc}")
chunk_idx = payload.segment_id - 1 # chunk_id 在 SpireWordDoc 中为 1-based segment_idx = payload.segment_id - 1
try: try:
segment_text = doc_obj.get_chunk_item(chunk_idx) segment_text = doc_obj.get_chunk_item(segment_idx)
except Exception as exc: except Exception as exc:
raise HTTPException(status_code=404, detail=f"Segment text not found for id {payload.segment_id}: {exc}. Please parse document first.") raise HTTPException(status_code=404, detail=f"Segment text not found for id {payload.segment_id}: {exc}. Please parse document first.")
ruleset_id = payload.ruleset_id or reference_tool.default_ruleset_id
rules = reference_tool.run(
ruleset_id=ruleset_id,
routed_rule_titles=payload.routed_rule_titles,
).get("rules", [])
result = summary_tool.run( result = summary_tool.run(
segment_id=payload.segment_id, segment_id=segment_idx,
segment_text=segment_text, segment_text=segment_text,
rules=rules,
party_role=payload.party_role or "", party_role=payload.party_role or "",
context_facts=payload.context_facts, context_facts=payload.context_facts,
) )
...@@ -168,20 +184,31 @@ def review_segment(payload: SegmentReviewRequest) -> SegmentReviewResponse: ...@@ -168,20 +184,31 @@ def review_segment(payload: SegmentReviewRequest) -> SegmentReviewResponse:
except Exception as exc: except Exception as exc:
raise HTTPException(status_code=400, detail=f"Document tool not available: {exc}") raise HTTPException(status_code=400, detail=f"Document tool not available: {exc}")
chunk_idx = payload.segment_id - 1 segment_idx = payload.segment_id - 1
try: try:
segment_text = doc_obj.get_chunk_item(chunk_idx) segment_text = doc_obj.get_chunk_item(segment_idx)
except Exception as exc: except Exception as exc:
raise HTTPException(status_code=404, detail=f"Segment text not found for id {payload.segment_id}: {exc}. Please parse document first.") raise HTTPException(status_code=404, detail=f"Segment text not found for id {payload.segment_id}: {exc}. Please parse document first.")
ruleset_id = payload.ruleset_id or reference_tool.default_ruleset_id
rules = reference_tool.run(
ruleset_id=ruleset_id,
routed_rule_titles=payload.routed_rule_titles,
).get("rules", [])
# 暂时不添加摘要看下结果
# summary_keywords = reference_tool.summary_keywords(rules)
# context_summaries = store.search_facts(summary_keywords)
result = review_tool.run( result = review_tool.run(
segment_id=payload.segment_id, segment_id=segment_idx,
segment_text=segment_text, segment_text=segment_text,
ruleset_id=payload.ruleset_id or "通用", rules=rules,
routed_rule_titles=payload.routed_rule_titles,
party_role=payload.party_role or "", party_role=payload.party_role or "",
# 暂时不添加摘要看下结果
# context_summaries=context_summaries,
# TODO 获取与当前审查相关的上下文记忆(如之前的审查结果、总结事实等),而非全部记忆 # TODO 获取与当前审查相关的上下文记忆(如之前的审查结果、总结事实等),而非全部记忆
context_memories=payload.context_memories, context_memories=payload.context_memories,
merge_rules_prompt=MERGE_RULE_PROMPT, # TODO 是否合并规则到同一 prompt 进行审查,当前默认合并,后续可调整为不合并以提升审查的针对性
) )
# Persist findings to memory store # Persist findings to memory store
...@@ -189,10 +216,11 @@ def review_segment(payload: SegmentReviewRequest) -> SegmentReviewResponse: ...@@ -189,10 +216,11 @@ def review_segment(payload: SegmentReviewRequest) -> SegmentReviewResponse:
try: try:
store.add_finding_from_dict({ store.add_finding_from_dict({
"rule_title": f.get("rule_title", ""), "rule_title": f.get("rule_title", ""),
"segment_id": payload.segment_id, "segment_id": segment_idx,
"original_text": f.get("original_text",''), "original_text": f.get("original_text",''),
"issue": f.get("issue", ""), "issue": f.get("issue", ""),
"risk_level": (f.get("risk_level") or f.get("level") or "").upper(), "risk_level": (f.get("risk_level") or f.get("level") or "").upper(),
"result": f.get("result", ""),
"suggestion": f.get("suggestion", ""), "suggestion": f.get("suggestion", ""),
}) })
except Exception as e: except Exception as e:
...@@ -213,17 +241,18 @@ def route_segment_rules(payload: SegmentReviewRequest) -> SegmentRuleRouterRespo ...@@ -213,17 +241,18 @@ def route_segment_rules(payload: SegmentReviewRequest) -> SegmentRuleRouterRespo
except Exception as exc: except Exception as exc:
raise HTTPException(status_code=400, detail=f"Document tool not available: {exc}") raise HTTPException(status_code=400, detail=f"Document tool not available: {exc}")
chunk_idx = payload.segment_id - 1 segment_idx = payload.segment_id - 1
try: try:
segment_text = doc_obj.get_chunk_item(chunk_idx) segment_text = doc_obj.get_chunk_item(segment_idx)
except Exception as exc: except Exception as exc:
raise HTTPException(status_code=404, detail=f"Segment text not found for id {payload.segment_id}: {exc}. Please parse document first.") raise HTTPException(status_code=404, detail=f"Segment text not found for id {payload.segment_id}: {exc}. Please parse document first.")
ruleset_id = payload.ruleset_id or review_tool.default_ruleset_id ruleset_id = payload.ruleset_id or reference_tool.default_ruleset_id
rules = reference_tool.run(ruleset_id=ruleset_id).get("rules", [])
result = rule_router_tool.run( result = rule_router_tool.run(
segment_id=payload.segment_id, segment_id=segment_idx,
segment_text=segment_text, segment_text=segment_text,
ruleset_id=ruleset_id, rules=rules,
party_role=payload.party_role or "", party_role=payload.party_role or "",
context_memories=payload.context_memories, context_memories=payload.context_memories,
) )
...@@ -252,20 +281,19 @@ class ReflectReviewResponse(BaseModel): ...@@ -252,20 +281,19 @@ class ReflectReviewResponse(BaseModel):
@app.post("/segments/review/reflect", response_model=ReflectReviewResponse) @app.post("/segments/review/reflect", response_model=ReflectReviewResponse)
def reflect_review(payload: ReflectReviewRequest) -> ReflectReviewResponse: def reflect_review(payload: ReflectReviewRequest) -> ReflectReviewResponse:
store = get_cached_memory(payload.conversation_id) store = get_cached_memory(payload.conversation_id)
ruleset_id = payload.ruleset_id or review_tool.default_ruleset_id ruleset_id = payload.ruleset_id or reference_tool.default_ruleset_id
ruleset_items = review_tool.rulesets.get(ruleset_id) or [] ruleset_items = reference_tool.run(ruleset_id=ruleset_id).get("rules", [])
rule = next((r for r in ruleset_items if r.get("title") == payload.rule_title), None) rule = next((r for r in ruleset_items if r.get("title") == payload.rule_title), None)
if not rule: if not rule:
raise HTTPException(status_code=404, detail=f"Rule not found: {payload.rule_title}") raise HTTPException(status_code=404, detail=f"Rule not found: {payload.rule_title}")
# TODO 获取与当前审查规则相关的上下文记忆(如之前的审查结果、总结事实等),而非全部记忆 summary_keywords = reference_tool.summary_keywords([rule])
# facts = store.get_facts() context_summaries_facts = store.search_facts(summary_keywords)
facts = []
# 查找审查规则对应的 findings # 查找审查规则对应的 findings
findings = [f.__dict__ for f in store.search_findings("", rule_title=payload.rule_title)] findings = [f.__dict__ for f in store.search_findings("", rule_title=payload.rule_title)]
final_findings = reflect_tool.run( final_findings = reflect_tool.run(
party_role=payload.party_role, party_role=payload.party_role,
rule=rule, rule=rule,
facts=facts, facts=context_summaries_facts,
findings=findings, findings=findings,
) )
......
...@@ -396,7 +396,7 @@ class SpirePdfDoc(DocBase): ...@@ -396,7 +396,7 @@ class SpirePdfDoc(DocBase):
return -1, None, None return -1, None, None
def get_chunk_id_list(self, step=1): def get_chunk_id_list(self, step=1):
return [idx + 1 for idx in range(0, self.get_chunk_num(), step)] return [idx for idx in range(0, self.get_chunk_num(), step)]
def get_chunk_num(self): def get_chunk_num(self):
return len(self._chunk_list) return len(self._chunk_list)
......
...@@ -700,7 +700,7 @@ class SpireWordDoc(DocBase): ...@@ -700,7 +700,7 @@ class SpireWordDoc(DocBase):
def get_chunk_id_list(self, step=1): def get_chunk_id_list(self, step=1):
self._ensure_loaded() self._ensure_loaded()
return [idx + 1 for idx in range(0, self.get_chunk_num(), step)] return [idx for idx in range(0, self.get_chunk_num(), step)]
def get_all_text(self): def get_all_text(self):
self._ensure_loaded() self._ensure_loaded()
......
...@@ -724,7 +724,7 @@ class SpireWordDoc(DocBase): ...@@ -724,7 +724,7 @@ class SpireWordDoc(DocBase):
def get_chunk_id_list(self, step=1): def get_chunk_id_list(self, step=1):
self._ensure_loaded() self._ensure_loaded()
return [idx + 1 for idx in range(0, self.get_chunk_num(), step)] return [idx for idx in range(0, self.get_chunk_num(), step)]
def get_all_text(self): def get_all_text(self):
self._ensure_loaded() self._ensure_loaded()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment