Commit 4441c13b by ccran

feat: 增加案例;修改审查提示词;

parent bd17ac4b
{
"python-envs.defaultEnvManager": "ms-python.python:conda",
"python-envs.defaultPackageManager": "ms-python.python:conda"
"python-envs.defaultPackageManager": "ms-python.python:conda",
"python.defaultInterpreterPath": "/home/ccran/.conda/envs/lufa/bin/python",
"python.terminal.activateEnvironment": true
}
\ No newline at end of file
......@@ -3,22 +3,6 @@ from dataclasses import dataclass
# 可配置运行参数
use_docker = False
just_b_class = False
is_extract = True
use_original_text_verification = False
# @dataclass
# class LLMConfig:
# base_url: str = "http://172.21.107.45:9002/v1"
# api_key: str = "none"
# model: str = "Qwen3-32B"
#
#
# base_fastgpt_url = "http://172.21.107.45:3030"
# base_backend_url = "http://172.21.107.45:48080"
# ocr_url = 'http://172.21.107.45:8202/openapi/ocrUploadFile'
@dataclass
class LLMConfig:
......@@ -27,32 +11,19 @@ class LLMConfig:
model: str = 'Qwen2-72B-Instruct'
# MAX_SINGLE_CHUNK_SIZE=100000
# MAX_SINGLE_CHUNK_SIZE=5000
MAX_SINGLE_CHUNK_SIZE=2000
MERGE_RULE_PROMPT = False
MAX_SINGLE_CHUNK_SIZE=5000
META_KEY="META"
DEFAULT_RULESET_ID = "通用"
ALL_RULESET_IDS = ["通用","借款","担保","测试","财务口","金盘","金盘简化"]
FACT_DIMENSIONS = [
"当事人",
"标的",
"金额",
"支付",
"期限",
"交付",
"质量",
"知识产权",
"保密",
"违约责任",
"争议解决"
]
use_lufa = False
ALL_RULESET_IDS = ["通用","借款","担保","财务口","金盘","金盘简化"]
use_lufa = True
if use_lufa:
outer_backend_url = "http://znkf.lgfzgroup.com:48081"
base_fastgpt_url = "http://192.168.252.71:18089"
base_backend_url = "http://192.168.252.71:48081"
api_key = "fastgpt-zMavJKKgqA9jRNHLXxzXCVZx1JXxfuNkH1p2qfLhtPfMp41UvdSQvt8"
else:
outer_backend_url = "http://218.77.58.8:8088"
outer_backend_url = "http://218.77.58.8:48080"
base_fastgpt_url = "http://192.168.252.71:18088"
base_backend_url = "http://192.168.252.71:48080"
api_key = "fastgpt-vLu2JHAfqwEq5FUQhvATFDK0yDS6fs804v7KwWBMyU4sRrHzh4UGl89Zpa"
......@@ -78,40 +49,7 @@ LLM = {
}
doc_support_formats = [".docx", ".doc", ".wps"]
pdf_support_formats = [".txt", ".md", ".pdf"]
# excel字段
field_mapping = {
"review": {
"id": "序号",
"key_points": "审查内容",
"original_text": "合同原文",
"details": "审查过程",
"result": "审查结果",
"suggest": "审查建议",
"chunk_id": "文件块序号",
"chunk_info": "文件块信息",
},
"category": {"text": "原文", "judge": "判断依据", "chunk_location": "原文位置"},
}
excel_widths = {"review": [5, 20, 80, 80, 20, 80, 5, 60], "category": [80, 80, 30]}
max_review_group = 5
# 销售类别判断
all_rule_sheet = ["内销或出口", "内销", "出口", "反思"]
reflection_sheet = "反思"
# 最大分片数量
min_single_chunk_size = 2000
max_single_chunk_size = 20000
max_chunk_page = 10
# 知识库读取sheet
if just_b_class:
know_sheet_name = "B类"
port = 9008
else:
if is_extract:
know_sheet_name = "提取审查"
port = 9016
else:
know_sheet_name = 0
port = 9006
# know_sheet_name = '发票审查'
reload = False
......@@ -4,7 +4,7 @@ from typing import Dict, List, Optional
from dataclasses import asdict
from core.tool import ToolBase, tool, tool_func
from core.memory import RiskFinding
from core.memory import Finding
@tool("memory_write", "分段记忆写入")
......@@ -32,17 +32,19 @@ class MemoryWriteTool(ToolBase):
issue = f.get("issue") or f.get("issue_description") or ""
level = (f.get("level") or f.get("risk_level") or "M").upper()
suggestion = f.get("suggestion") or ""
result = f.get("result") or ""
evs = list(f.get("evidence_quotes", []) or [])
original_text = evs[0] if evs else (f.get("original_text") or "")
try:
finding_obj = RiskFinding(
finding_obj = Finding(
rule_title=rule_title,
segment_id=int(segment_id) if str(segment_id).isdigit() else 0,
original_text=original_text,
issue_description=issue,
issue=issue,
risk_level=level,
suggestion=suggestion,
result=result,
)
store.add_finding(finding_obj)
added.append(asdict(finding_obj))
......
......@@ -59,7 +59,7 @@ REFLECT_USER_PROMPT = '''
【已有风险 findings】
{findings_json}
【合同事实记忆 facts】
【合同摘要事实记忆 facts】
{facts_json}
【合同立场】
......
from __future__ import annotations
from pathlib import Path
from typing import Dict, List, Any
import re
from core.config import ALL_RULESET_IDS, DEFAULT_RULESET_ID
from core.tool import ToolBase, tool, tool_func
from core.cache import get_cached_memory
from utils.excel_util import ExcelUtil
FACT_DIMENSIONS: List[str] = ["当事人", "标的", "金额", "支付", "交付", "质量", "知识产权", "保密", "违约责任", "争议解决"]
@tool("retrieve_reference", "审查参考检索")
class RetrieveReferenceTool(ToolBase):
def __init__(self) -> None:
self.default_ruleset_id = DEFAULT_RULESET_ID
self.column_map = {
"id": "ID",
"title": "审查项",
"rule": "审查规则",
"level": "风险等级",
"triggers": "触发词",
"suggestion_template": "建议模板",
"case": "案例",
"summary":"摘要项"
}
rules_path = Path(__file__).resolve().parent.parent.parent / "data" / "rules.xlsx"
self.rulesets: Dict[str, List[Dict[str, Any]]] = {}
for rs_id in ALL_RULESET_IDS:
rules = ExcelUtil.load_mapped_excel(rules_path, sheet_name=rs_id, column_map=self.column_map)
self.rulesets[rs_id] = rules
@tool_func(
{
"type": "object",
"properties": {
"question": {"type": "string"},
"top_k": {"type": "int"},
"ruleset_id": {"type": "string"},
"routed_rule_titles": {"type": "array", "items": {"type": "string"}},
},
"required": ["question"],
"required": [],
}
)
def run(self, question: str, top_k: int = 5, conversation_id: str = "") -> Dict:
memory_refs = self._search_memory(question, conversation_id, top_k)
kb_refs = self._search_knowledge_base(question, top_k)
external_refs = self._search_external(question, top_k)
def run(self, ruleset_id: str = "", routed_rule_titles: List[str] | None = None) -> Dict[str, Any]:
target_ruleset_id = ruleset_id or self.default_ruleset_id
full_rules = self.rulesets.get(target_ruleset_id) or self.rulesets.get(self.default_ruleset_id, []) or []
if routed_rule_titles is None:
rules = full_rules
else:
title_set = {title for title in routed_rule_titles if isinstance(title, str)}
rules = [r for r in full_rules if r.get("title") in title_set]
return {
"memory_refs": memory_refs,
"kb_refs": kb_refs,
"external_refs": external_refs,
"ruleset_id": target_ruleset_id,
"rules": rules,
"rule_titles": [r.get("title", "") for r in rules],
"total": len(rules),
}
def _search_memory(self, question: str, conversation_id: str, top_k: int) -> List[Dict[str, Any]]:
if not conversation_id:
return []
store = get_cached_memory(conversation_id)
facts = store.get_facts()
results: List[Dict[str, Any]] = []
def _add(dim: str, payload: Any) -> None:
snippet = payload if isinstance(payload, str) else str(payload)
results.append({
"source": f"memory:{dim}",
"dimension": dim,
"snippet": snippet,
})
for dim in FACT_DIMENSIONS:
val = facts.get(dim)
if val is None:
continue
if dim in question:
_add(dim, val)
return results[:top_k]
def _search_knowledge_base(self, question: str, top_k: int) -> List[Dict[str, Any]]:
# TODO: implement KB retrieval
return []
def _search_external(self, question: str, top_k: int) -> List[Dict[str, Any]]:
# TODO: implement external retrieval (e.g., search engine)
return []
def summary_keywords(self, rules: List[Dict[str, Any]]) -> List[str]:
return [r.get("summary", "") for r in rules if r.get("summary")]
if __name__ == "__main__":
tmp_memory = get_cached_memory("tmp")
tmp_memory.add_finding_from_dict({
"issue": "支付方式不明确",
"original_text": "买方应在收到货物后30天内支付全部货款。",
"risk_level": "H",
"rule_title": "支付条款审查",
})
tmp_memory.update_facts({
"支付":{
"支付方式": "银行转账",
"支付期限": "收到货物后30天内",
}
})
tool = RetrieveReferenceTool()
result = tool.run(
question="支付方式是什么?",
top_k=3,
conversation_id="tmp",
)
print(result)
\ No newline at end of file
result = tool.run(ruleset_id="金盘", routed_rule_titles=None)
for rule in result.get("rules", []):
print(f"Rule Title: {rule.get('title')}")
print(f"Case: {rule.get('case')}")
print("-" * 20)
# print(result.get("total", 0))
\ No newline at end of file
......@@ -32,7 +32,8 @@ class LLMTool(ToolBase):
def run_with_loop(self, coro):
try:
return asyncio.run(coro)
except RuntimeError:
except RuntimeError as e:
print(f'RuntimeError in run_with_loop: {e}, trying to get event loop and run until complete.')
loop = asyncio.get_event_loop()
return loop.run_until_complete(coro)
......
......@@ -2,13 +2,10 @@ from __future__ import annotations
import json
import re
from pathlib import Path
from typing import Dict, List, Optional
from core.config import ALL_RULESET_IDS, DEFAULT_RULESET_ID
from core.tool import tool, tool_func
from core.tools.segment_llm import LLMTool
from utils.excel_util import ExcelUtil
ROUTER_SYSTEM_PROMPT = '''
......@@ -42,7 +39,7 @@ ROUTER_USER_PROMPT = '''
【合同立场】
{party_role}
【候选审查规则
【候选审查规则
{candidate_rules_json}
【任务】
......@@ -68,20 +65,6 @@ ROUTER_OUTPUT_SCHEMA = '''
class SegmentRuleRouterTool(LLMTool):
def __init__(self) -> None:
super().__init__(ROUTER_SYSTEM_PROMPT)
self.default_ruleset_id = DEFAULT_RULESET_ID
self.column_map = {
"id": "ID",
"title": "审查项",
"rule": "审查规则",
"level": "风险等级",
"triggers": "触发词",
"suggestion_template": "建议模板",
}
rules_path = Path(__file__).resolve().parent.parent.parent / "data" / "rules.xlsx"
self.rulesets: Dict[str, List[Dict]] = {}
for rs_id in ALL_RULESET_IDS:
rules = ExcelUtil.load_mapped_excel(rules_path, sheet_name=rs_id, column_map=self.column_map)
self.rulesets[rs_id] = rules
@tool_func(
{
......@@ -89,22 +72,22 @@ class SegmentRuleRouterTool(LLMTool):
"properties": {
"segment_id": {"type": "int"},
"segment_text": {"type": "string"},
"ruleset_id": {"type": "string"},
"rules": {"type": "array", "items": {"type": "object"}},
"party_role": {"type": "string"},
"context_memories": {"type": "array"},
},
"required": ["segment_id", "segment_text", "ruleset_id", "party_role"],
"required": ["segment_id", "segment_text", "rules", "party_role"],
}
)
def run(
self,
segment_id: int,
segment_text: str,
ruleset_id: str,
rules: List[Dict],
party_role: str,
context_memories: Optional[List[Dict]] = None,
) -> Dict:
rules = self.rulesets.get(ruleset_id) or self.rulesets.get(self.default_ruleset_id, []) or []
rules = rules or []
routed_rules = self._route_rules(
segment_text=segment_text,
rules=rules,
......@@ -113,7 +96,6 @@ class SegmentRuleRouterTool(LLMTool):
)
return {
"segment_id": segment_id,
"ruleset_id": ruleset_id,
"routed_rules": routed_rules,
"routed_rule_titles": [r.get("title", "") for r in routed_rules],
}
......@@ -121,10 +103,7 @@ class SegmentRuleRouterTool(LLMTool):
def _build_candidate_rules(self, rules: List[Dict]) -> List[Dict]:
return [
{
"title": r.get("title", ""),
"level": r.get("level", ""),
"rule": r.get("rule", ""),
"triggers": r.get("triggers", ""),
r.get("title", ""): r.get("rule", "")
}
for r in rules
if r.get("title")
......@@ -223,6 +202,22 @@ class SegmentRuleRouterTool(LLMTool):
if __name__ == "__main__":
tool = SegmentRuleRouterTool()
demo_rules = [
{
"id": "R1",
"title": "付款触发条件明确性",
"level": "H",
"rule": "付款应绑定明确触发条件和验收标准。",
"triggers": "支付,付款,验收",
},
{
"id": "R2",
"title": "违约责任对等性",
"level": "M",
"rule": "违约责任应当相对对等且违约金标准明确。",
"triggers": "违约,违约金",
},
]
demo_segment_text = (
"甲方应在合同签订后5个工作日内向乙方支付合同总价30%作为预付款,"
"剩余70%在乙方完成交付并经甲方验收合格后30日内支付。"
......@@ -232,7 +227,7 @@ if __name__ == "__main__":
result = tool.run(
segment_id=1,
segment_text=demo_segment_text,
ruleset_id="通用",
rules=demo_rules,
party_role="甲方",
context_memories=[],
)
......
from __future__ import annotations
import asyncio
import json
from typing import Dict, List, Optional
from core.tool import tool, tool_func
from core.tools.segment_llm import LLMTool
from core.config import FACT_DIMENSIONS
from core.config import META_KEY
SUMMARY_SYSTEM_PROMPT = f'''
你是合同事实提取智能体(SegmentSummary)。
你的任务是从当前合同分段中提取“客观事实”,并按指定维度结构化输出。
你的任务是:**基于给定的审查规则,从当前合同分段中提取“与该规则直接相关的客观事实”,并结构化输出。**
【核心原则】
你必须严格围绕“规则所需信息”进行提取。
---
【事实定义】
......@@ -24,37 +27,51 @@ SUMMARY_SYSTEM_PROMPT = f'''
3. 不得补充未出现的主体、条件或数值;
4. 允许对原文做最小结构化拆分(例如金额、比例、期限)。
---
【规则驱动提取要求(关键)】
- 仅提取“该审查规则执行所需要的信息字段”
- 不得提取与该规则无关的信息(即使这些信息在文本中存在)
- 若规则未涉及某类信息,则不得输出对应字段
- 若规则涉及某字段但文本未出现,需显式标记为 "未明确"
---
【输出结构】
- 输出字段:facts
- facts 是一个对象
- 键为以下预设维度:
- 键必须来自【规则字段定义(rule_fields)】
- 不得使用预设通用维度(如“支付/违约责任”等)
---
{", ".join(FACT_DIMENSIONS)}
【字段填充规则】
- 每个维度值必须是对象或对象列表
- 未出现的维度可以省略
- 每个字段值必须是对象或对象列表
- 不得输出字符串作为字段值
- 字段内容必须为原文的最小结构化表达
- 不得改写原文含义
【结构规则】
---
- 仅提取对合同履行或责任具有实际意义的事实
- 不得输出字符串作为维度值,必须使用对象
- 不得输出解释、总结或风险判断
【缺失信息处理(非常重要)】
【上下文事实使用规则】
- 若规则要求的字段在当前分段未出现:
→ 必须输出该字段,并标记为:
上下文事实仅用于:
- 避免重复提取已存在的事实
- 保持字段命名一致
"未明确"
不得:
- 使用上下文事实补充当前分段没有出现的信息
- 修改当前分段原文事实
(用于后续审查判断)
---
【约束】
- 严禁编造信息
- 严禁推断未出现的内容
- 不得输出风险判断或解释
- 严格输出 JSON
'''
......@@ -62,11 +79,17 @@ SUMMARY_USER_PROMPT = '''
【分段原文】
{segment_text}
【上下文事实】
{context_facts}
【规则字段定义(仅提取这些字段)】
{rule_fields}
【任务】
请仅提取“当前分段中,与候选审查规则直接相关的客观事实”。
仅提取当前分段中明确出现的客观事实。
不得从上下文事实中补充新的信息。
【特别要求】
- facts 的顶层 key 必须是规则 title
- 每个规则下仅保留与该规则直接相关的信息
- 若某规则在当前分段未出现关键信息,输出该规则并标记为 "未明确"
- 不得提取与规则无关的信息
输出 JSON。
'''
......@@ -75,8 +98,8 @@ OUTPUT_EXAMPLE = '''
```json
{
"facts": {
"支付": {"方式": "银行转账", "时间": "验收后30日内"},
"违约责任": {"违约金比例": "合同总金额的5%"}
"支付审查": {"方式": "银行转账", "时间": "验收后30日内"},
"违约责任审查": {"违约金比例": "合同总金额的5%"}
}
}
```
......@@ -94,28 +117,54 @@ class SegmentSummaryTool(LLMTool):
"properties": {
"segment_id": {"type": "int"},
"segment_text": {"type": "string"},
"rules": {"type": "array", "items": {"type": "object"}},
"party_role": {"type": "string"},
"context_facts": {"type": "object"},
},
"required": ["segment_id", "segment_text"],
"required": ["segment_id", "segment_text", "rules"],
}
)
def run(
self,
segment_id: int,
segment_text: str,
rules: List[Dict],
party_role: str = "",
context_facts: Optional[Dict] = None,
) -> Dict:
rules = rules or []
try:
return self.run_with_loop(self._summarize_async(segment_id, segment_text, party_role, context_facts))
return self.run_with_loop(
self._summarize_async(segment_id, segment_text, rules, party_role, context_facts)
)
except Exception:
return {}
def _build_prompt(self, segment_text: str, context_facts: Optional[Dict], party_role: str) -> List[Dict[str, str]]:
def _stringify_rule(self, rules: List[Dict]) -> str:
lines = []
for r in rules:
id = r.get("id", "")
rule_text = r.get("rule", "")
lines.append(f"规则ID: {id}\n审查规则: {rule_text}\n")
return "\n".join(lines)
def _build_prompt(
self,
segment_text: str,
rules: List[Dict],
context_facts: Optional[Dict],
party_role: str,
) -> List[Dict[str, str]]:
# 获取规则字段定义
rule_fields = [
r.get("summary") for r in rules
if r.get("summary")
]
user_content = SUMMARY_USER_PROMPT.format(
segment_text=segment_text,
context_facts=json.dumps(context_facts or {}, ensure_ascii=False),
# party_role=party_role,
# rules_json=self._stringify_rule(rules),
rule_fields=json.dumps(rule_fields, ensure_ascii=False),
) + OUTPUT_EXAMPLE
return self.build_messages(user_content)
......@@ -123,30 +172,30 @@ class SegmentSummaryTool(LLMTool):
self,
segment_id: int,
segment_text: str,
rules: List[Dict],
party_role: str,
context_facts: Optional[Dict],
) -> Dict:
msgs = self._build_prompt(segment_text, context_facts, party_role)
final_facts: Dict = {}
msgs = self._build_prompt(segment_text, rules, context_facts, party_role)
try:
resp = await self.chat_async(msgs)
# print("segment summary response:", resp)
data = self.parse_first_json(resp)
facts = data.get("facts") or {}
except Exception:
except Exception as e:
print(f'Error in segment summary for segment {segment_id}: {e}')
facts = {}
# print(f'SegmentSummaryTool facts: {facts}')
if isinstance(facts,list):
final_facts['内容'] = facts
else:
final_facts = facts
final_facts['segment_id'] = segment_id
return final_facts
facts[META_KEY] = {
"segment_id": segment_id,
}
return facts
if __name__=='__main__':
tool = SegmentSummaryTool()
res = tool.run(
segment_id=1,
segment_text="甲方应于合同签订之日起30日内向乙方支付合同总金额的50%,余款在货物验收合格后30日内付清.",
rules=[{"id": "R1", "title": "付款", "rule": "付款相关", "summary": "付款方式"}],
context_facts={},
)
print(res)
\ No newline at end of file
......@@ -14,7 +14,7 @@ from utils.http_util import upload_file, fastgpt_openai_chat, download_file
SUFFIX='_麓发改进'
batch_input_dir_path = 'jp-input'
batch_output_dir_path = 'jp-output-lufa'
batch_output_dir_path = 'jp-all'
batch_size = 5
# 麓发fastgpt接口
# url = 'http://192.168.252.71:18089/api/v1/chat/completions'
......@@ -31,9 +31,6 @@ token = 'fastgpt-vykT6qs07g7hR4tL2MNJE6DdNCIxaQjEu3Cxw9nuTBFg8MAG3CkByvnXKxSNEyM
def extract_url(text):
# \s * ([ ^ "\s]+?\.(?:docx?|pdf|xlsx))
# 麓发正则
# excel_p, doc_p = r'导出Excel结果\s*([^"]*xlsx)', r'导出Doc结果\s*([^\" ]+?\.(?:docx?|pdf|wps))'
# 金盘正则
excel_p, doc_p = r'最终审查Excel\s*([^"]*xlsx)', r'最终审查批注\s*([^\" ]+?\.(?:docx?|pdf|wps))'
# 使用 re.search() 查找第一个匹配项
excel_m, doc_m = re.search(excel_p, text), re.search(doc_p, text)
......
......@@ -69,6 +69,8 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None:
best_score = -1
for idx, cand in enumerate(candidates):
ans_text = ans_text.strip()
if cand is None or not isinstance(cand,str):
continue
cand = cand.strip()
score = max(
fuzz.partial_ratio(ans_text, cand),
......@@ -99,6 +101,7 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None:
unmatched_val_count = sum(len(v) for v in unmatched_val_by_item.values())
unmatched_answer_count = sum(len(v) for v in unmatched_answer_by_item.values())
file_false_positive_rate = (unmatched_val_count / val_total) if val_total != 0 else 0
# 累加到各“审查项”的全局统计
for it, cnt in answer_counts.items():
......@@ -112,7 +115,7 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None:
print('#' * 40)
print(
f"{val_file.name}: matched {matched_total} | val {val_total} | answer {answer_total} "
f"| unmatched val {unmatched_val_count} | unmatched answer {unmatched_answer_count} | accuracy {matched_total / answer_total:.2%} | invalid_val {(unmatched_val_count / val_total) if val_total != 0 else 0:.2%}"
f"| unmatched val {unmatched_val_count} | unmatched answer {unmatched_answer_count} | recall {matched_total / answer_total:.2%} | false_positive_rate {file_false_positive_rate:.2%}"
)
for item in sorted(answer_counts):
item_matches = matched_by_item.get(item, [])
......@@ -133,10 +136,10 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None:
for t in uv:
print(f" val: {t}")
# break # only first file for demo
accuracy = overall_matched / overall_answer if overall_answer else 0
invalid_val = (overall_val - overall_matched) / overall_val if overall_val else 0
recall = overall_matched / overall_answer if overall_answer else 0
overall_false_positive_rate = (overall_val - overall_matched) / overall_val if overall_val else 0
print(
f"Overall: matched {overall_matched} | val {overall_val} | answer {overall_answer} | accuracy {accuracy:.2%} | invalid_val {invalid_val:.2%}"
f"Overall: matched {overall_matched} | val {overall_val} | answer {overall_answer} | recall {recall:.2%} | false_positive_rate {overall_false_positive_rate:.2%}"
)
# 按“审查项”的 overall 结果
......@@ -151,7 +154,7 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None:
u_ans = overall_item_unmatched_answer.get(it, 0)
u_val = overall_item_unmatched_val.get(it, 0)
acc = (mat / ans) if ans else 0
invalid_val = u_val / (mat + u_val) if (mat + u_val) else 0
item_false_positive_rate = u_val / (mat + u_val) if (mat + u_val) else 0
rows_by_item.append({
"审查项": it,
"大模型匹配上的不合格项": mat,
......@@ -159,13 +162,13 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None:
"大模型其他不合格项": u_val,
"大模型未匹配上的不合格项(C-B)": u_ans,
"查全率(B/C)": acc,
"无关审查率(D/B+D)": invalid_val,
"误报率(D/B+D)": item_false_positive_rate,
})
print(
f" 审查项 {it}: matched {mat} / answer {ans} | unmatched val {u_val} | unmatched answer {u_ans} | accuracy {acc:.2%} | invalid_val {invalid_val:.2%}"
f" 审查项 {it}: matched {mat} / answer {ans} | unmatched val {u_val} | unmatched answer {u_ans} | recall {acc:.2%} | false_positive_rate {item_false_positive_rate:.2%}"
)
overall_by_item_df = pd.DataFrame(rows_by_item, columns=["审查项", "大模型匹配上的不合格项", "合同所有不合格项", "大模型其他不合格项", "大模型未匹配上的不合格项(C-B)", "查全率(B/C)", "无关审查率(D/B+D)"])
overall_by_item_df = pd.DataFrame(rows_by_item, columns=["审查项", "大模型匹配上的不合格项", "合同所有不合格项", "大模型其他不合格项", "大模型未匹配上的不合格项(C-B)", "查全率(B/C)", "误报率(D/B+D)"])
unmatched_val_total = sum(overall_item_unmatched_val.values())
unmatched_answer_total = sum(overall_item_unmatched_answer.values())
overall_invalid_rate = unmatched_val_total / (overall_matched + unmatched_val_total) if (overall_matched + unmatched_val_total) else 0
......@@ -176,10 +179,10 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None:
"合同所有不合格项": overall_answer,
"大模型其他不合格项": unmatched_val_total,
"大模型未匹配上的不合格项(C-B)": unmatched_answer_total,
"查全率(B/C)": accuracy,
"无关审查率(D/B+D)": overall_invalid_rate,
"查全率(B/C)": recall,
"误报率(D/B+D)": overall_invalid_rate,
}
], columns=["审查项", "大模型匹配上的不合格项", "合同所有不合格项", "大模型其他不合格项", "大模型未匹配上的不合格项(C-B)", "查全率(B/C)", "无关审查率(D/B+D)"])
], columns=["审查项", "大模型匹配上的不合格项", "合同所有不合格项", "大模型其他不合格项", "大模型未匹配上的不合格项(C-B)", "查全率(B/C)", "误报率(D/B+D)"])
combined_df = pd.concat([overall_by_item_df, overall_total_df], ignore_index=True)
compare_dir_name = val_dir.name
......
......@@ -121,7 +121,7 @@ def _parse_args() -> argparse.Namespace:
parser.add_argument(
"--datasets-dir",
type=Path,
default=base / "results" / "jp-output",
default=base / "results" / "jp-all-merge-prompt",
help="Directory containing Word files with annotations.",
)
parser.add_argument(
......@@ -133,13 +133,13 @@ def _parse_args() -> argparse.Namespace:
parser.add_argument(
"--val-dir",
type=Path,
default=base / "results" / "jp-output-extracted",
default=base / "results" / "jp-all-merge-prompt-extracted",
help="Directory to store extracted xlsx files for comparison.",
)
parser.add_argument(
"--strip-suffixes",
nargs="*",
default=['_人机交互'],
default=['_麓发改进'],
help=(
"Optional filename suffixes to strip from generated val xlsx stems before "
"comparison, e.g. --strip-suffixes _v1 _审阅版"
......
No preview for this file type
......@@ -10,15 +10,16 @@ import uvicorn
import traceback
from loguru import logger
from utils.common_util import extract_url_file, format_now
from utils.http_util import download_file
from core.cache import get_cached_doc_tool, get_cached_memory
from core.config import doc_support_formats, pdf_support_formats
from core.config import doc_support_formats, pdf_support_formats,MERGE_RULE_PROMPT
from core.tools.segment_summary import SegmentSummaryTool
from core.tools.segment_review import SegmentReviewTool
from core.tools.segment_rule_router import SegmentRuleRouterTool
from core.tools.retrieve_reference import RetrieveReferenceTool
from core.tools.reflect_retry import ReflectRetryTool
from core.memory import RiskFinding
app = FastAPI(title="合同审查智能体", version="0.1.0")
TMP_DIR = Path(__file__).resolve().parent / "tmp"
......@@ -26,6 +27,7 @@ TMP_DIR.mkdir(parents=True, exist_ok=True)
summary_tool = SegmentSummaryTool()
review_tool = SegmentReviewTool()
rule_router_tool = SegmentRuleRouterTool()
reference_tool = RetrieveReferenceTool()
reflect_tool = ReflectRetryTool()
......@@ -47,7 +49,7 @@ class DocumentParseRequest(BaseModel):
class DocumentParseResponse(BaseModel):
conversation_id: str
chunk_ids: List[int]
segment_ids: List[int]
ruleset_items: List[str]
text: Optional[str] = None
file_ext: Optional[str] = None
......@@ -78,15 +80,20 @@ async def parse_document(payload: DocumentParseRequest) -> DocumentParseResponse
# ocr
await doc_obj.get_from_ocr()
text = doc_obj.get_all_text()
chunk_ids = doc_obj.get_chunk_id_list()
segment_ids = doc_obj.get_chunk_id_list()
# TODO: FastGPT BUG segment_ids必须从1开始,0开始会缺少第一段文本,后续需要修复
segment_ids = [idx + 1 for idx in segment_ids]
# get ruleset items
ruleset_id = payload.ruleset_id or review_tool.default_ruleset_id
ruleset_items = review_tool.rulesets.get(ruleset_id) or []
ruleset_review_items = [r.get('title') for r in ruleset_items]
ruleset_id = payload.ruleset_id or reference_tool.default_ruleset_id
ruleset_items = reference_tool.run(ruleset_id=ruleset_id).get("rules", [])
ruleset_review_items = [
t for t in (r.get("title") for r in ruleset_items)
if isinstance(t, str) and t.strip()
]
return DocumentParseResponse(
conversation_id=payload.conversation_id,
text=text,
chunk_ids=chunk_ids,
segment_ids=segment_ids,
ruleset_items=ruleset_review_items,
file_ext = file_ext
)
......@@ -97,6 +104,8 @@ class SegmentSummaryRequest(BaseModel):
conversation_id: str
segment_id: int
party_role: Optional[str] = ""
ruleset_id: Optional[str] = "通用"
routed_rule_titles: Optional[List[str]] = None
file_ext: str
context_facts: Optional[Dict] = None
......@@ -115,14 +124,21 @@ def summarize_facts(payload: SegmentSummaryRequest) -> SegmentSummaryResponse:
except Exception as exc:
raise HTTPException(status_code=400, detail=f"Document tool not available: {exc}")
chunk_idx = payload.segment_id - 1 # chunk_id 在 SpireWordDoc 中为 1-based
segment_idx = payload.segment_id - 1
try:
segment_text = doc_obj.get_chunk_item(chunk_idx)
segment_text = doc_obj.get_chunk_item(segment_idx)
except Exception as exc:
raise HTTPException(status_code=404, detail=f"Segment text not found for id {payload.segment_id}: {exc}. Please parse document first.")
ruleset_id = payload.ruleset_id or reference_tool.default_ruleset_id
rules = reference_tool.run(
ruleset_id=ruleset_id,
routed_rule_titles=payload.routed_rule_titles,
).get("rules", [])
result = summary_tool.run(
segment_id=payload.segment_id,
segment_id=segment_idx,
segment_text=segment_text,
rules=rules,
party_role=payload.party_role or "",
context_facts=payload.context_facts,
)
......@@ -168,20 +184,31 @@ def review_segment(payload: SegmentReviewRequest) -> SegmentReviewResponse:
except Exception as exc:
raise HTTPException(status_code=400, detail=f"Document tool not available: {exc}")
chunk_idx = payload.segment_id - 1
segment_idx = payload.segment_id - 1
try:
segment_text = doc_obj.get_chunk_item(chunk_idx)
segment_text = doc_obj.get_chunk_item(segment_idx)
except Exception as exc:
raise HTTPException(status_code=404, detail=f"Segment text not found for id {payload.segment_id}: {exc}. Please parse document first.")
ruleset_id = payload.ruleset_id or reference_tool.default_ruleset_id
rules = reference_tool.run(
ruleset_id=ruleset_id,
routed_rule_titles=payload.routed_rule_titles,
).get("rules", [])
# 暂时不添加摘要看下结果
# summary_keywords = reference_tool.summary_keywords(rules)
# context_summaries = store.search_facts(summary_keywords)
result = review_tool.run(
segment_id=payload.segment_id,
segment_id=segment_idx,
segment_text=segment_text,
ruleset_id=payload.ruleset_id or "通用",
routed_rule_titles=payload.routed_rule_titles,
rules=rules,
party_role=payload.party_role or "",
# 暂时不添加摘要看下结果
# context_summaries=context_summaries,
# TODO 获取与当前审查相关的上下文记忆(如之前的审查结果、总结事实等),而非全部记忆
context_memories=payload.context_memories,
merge_rules_prompt=MERGE_RULE_PROMPT, # TODO 是否合并规则到同一 prompt 进行审查,当前默认合并,后续可调整为不合并以提升审查的针对性
)
# Persist findings to memory store
......@@ -189,10 +216,11 @@ def review_segment(payload: SegmentReviewRequest) -> SegmentReviewResponse:
try:
store.add_finding_from_dict({
"rule_title": f.get("rule_title", ""),
"segment_id": payload.segment_id,
"segment_id": segment_idx,
"original_text": f.get("original_text",''),
"issue": f.get("issue", ""),
"risk_level": (f.get("risk_level") or f.get("level") or "").upper(),
"result": f.get("result", ""),
"suggestion": f.get("suggestion", ""),
})
except Exception as e:
......@@ -213,17 +241,18 @@ def route_segment_rules(payload: SegmentReviewRequest) -> SegmentRuleRouterRespo
except Exception as exc:
raise HTTPException(status_code=400, detail=f"Document tool not available: {exc}")
chunk_idx = payload.segment_id - 1
segment_idx = payload.segment_id - 1
try:
segment_text = doc_obj.get_chunk_item(chunk_idx)
segment_text = doc_obj.get_chunk_item(segment_idx)
except Exception as exc:
raise HTTPException(status_code=404, detail=f"Segment text not found for id {payload.segment_id}: {exc}. Please parse document first.")
ruleset_id = payload.ruleset_id or review_tool.default_ruleset_id
ruleset_id = payload.ruleset_id or reference_tool.default_ruleset_id
rules = reference_tool.run(ruleset_id=ruleset_id).get("rules", [])
result = rule_router_tool.run(
segment_id=payload.segment_id,
segment_id=segment_idx,
segment_text=segment_text,
ruleset_id=ruleset_id,
rules=rules,
party_role=payload.party_role or "",
context_memories=payload.context_memories,
)
......@@ -252,20 +281,19 @@ class ReflectReviewResponse(BaseModel):
@app.post("/segments/review/reflect", response_model=ReflectReviewResponse)
def reflect_review(payload: ReflectReviewRequest) -> ReflectReviewResponse:
store = get_cached_memory(payload.conversation_id)
ruleset_id = payload.ruleset_id or review_tool.default_ruleset_id
ruleset_items = review_tool.rulesets.get(ruleset_id) or []
ruleset_id = payload.ruleset_id or reference_tool.default_ruleset_id
ruleset_items = reference_tool.run(ruleset_id=ruleset_id).get("rules", [])
rule = next((r for r in ruleset_items if r.get("title") == payload.rule_title), None)
if not rule:
raise HTTPException(status_code=404, detail=f"Rule not found: {payload.rule_title}")
# TODO 获取与当前审查规则相关的上下文记忆(如之前的审查结果、总结事实等),而非全部记忆
# facts = store.get_facts()
facts = []
summary_keywords = reference_tool.summary_keywords([rule])
context_summaries_facts = store.search_facts(summary_keywords)
# 查找审查规则对应的 findings
findings = [f.__dict__ for f in store.search_findings("", rule_title=payload.rule_title)]
final_findings = reflect_tool.run(
party_role=payload.party_role,
rule=rule,
facts=facts,
facts=context_summaries_facts,
findings=findings,
)
......
......@@ -396,7 +396,7 @@ class SpirePdfDoc(DocBase):
return -1, None, None
def get_chunk_id_list(self, step=1):
return [idx + 1 for idx in range(0, self.get_chunk_num(), step)]
return [idx for idx in range(0, self.get_chunk_num(), step)]
def get_chunk_num(self):
return len(self._chunk_list)
......
......@@ -700,7 +700,7 @@ class SpireWordDoc(DocBase):
def get_chunk_id_list(self, step=1):
self._ensure_loaded()
return [idx + 1 for idx in range(0, self.get_chunk_num(), step)]
return [idx for idx in range(0, self.get_chunk_num(), step)]
def get_all_text(self):
self._ensure_loaded()
......
......@@ -724,7 +724,7 @@ class SpireWordDoc(DocBase):
def get_chunk_id_list(self, step=1):
self._ensure_loaded()
return [idx + 1 for idx in range(0, self.get_chunk_num(), step)]
return [idx for idx in range(0, self.get_chunk_num(), step)]
def get_all_text(self):
self._ensure_loaded()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment