Commit 6b4d3476 by ccran

feat: add merger that consolidates findings sharing the same original text

parent 6b313617
@@ -16,7 +16,7 @@ MAX_SINGLE_CHUNK_SIZE=5000
META_KEY="META"
DEFAULT_RULESET_ID = "通用"
ALL_RULESET_IDS = ["通用","借款","担保","财务口","金盘","金盘简化"]
use_lufa = False
use_lufa = True
if use_lufa:
outer_backend_url = "http://znkf.lgfzgroup.com:48081"
base_fastgpt_url = "http://192.168.252.71:18089"
@@ -2,11 +2,12 @@ from __future__ import annotations
import json
import logging
from dataclasses import asdict, dataclass, field
from dataclasses import asdict, dataclass
from datetime import datetime
from pathlib import Path
from threading import RLock
from typing import Any, Dict, Iterable, List, Optional
from typing import Any, Dict, List, Optional
from uuid import uuid4
from utils.http_util import upload_file
from utils.doc_util import DocBase
@@ -17,6 +18,16 @@ logger = logging.getLogger(__name__)
_ALLOWED_RISK_LEVELS = {"H", "M", "L", ""}
FINDING_KEY_REVIEW = "review"
FINDING_KEY_REFLECT = "reflect"
FINDING_KEY_MERGE = "merge"
_DEFAULT_REVIEW_KEY = FINDING_KEY_REVIEW
_FINDING_KEY_SHEET_NAMES = {
FINDING_KEY_REVIEW: "审查结果",
FINDING_KEY_REFLECT: "复核结果",
FINDING_KEY_MERGE: "合并结果",
}
@dataclass
@@ -27,6 +38,7 @@ class Finding:
issue: str
risk_level: str
suggestion: str
id: str = ""
result: str = ""
def __post_init__(self) -> None:
@@ -39,6 +51,7 @@ class Finding:
def from_dict(cls, data: Dict) -> "Finding":
data = data or {}
return cls(
id=str(data.get("id", "")),
rule_title=str(data.get("rule_title", "")),
segment_id=int(data.get("segment_id", 0) or 0),
original_text=str(data.get("original_text", "")),
@@ -50,7 +63,7 @@
def __repr__(self):
return (
f"Finding(rule_title={self.rule_title!r}, segment_id={self.segment_id}, "
f"Finding(id={self.id!r}, rule_title={self.rule_title!r}, segment_id={self.segment_id}, "
f"issue={self.issue!r}, risk_level={self.risk_level!r}, result={self.result!r})"
)
@@ -67,8 +80,7 @@ class MemoryStore:
self._storage_path.parent.mkdir(parents=True, exist_ok=True)
self._lock = RLock()
self.facts: List[Dict[str, Any]] = []
self.findings: List[Finding] = []
self.final_findings: List[Finding] = []
self.findings: Dict[str, List[Finding]] = {}
self._load()
# ---------------------- facts ----------------------
@@ -117,48 +129,39 @@ class MemoryStore:
return matched_values
# -------------------- findings ---------------------
def add_finding(self, finding: Finding) -> Finding:
return self._add_finding(self.findings, finding)
def add_finding_from_dict(self, data: Dict) -> Finding:
return self.add_finding(Finding.from_dict(data))
def add_final_finding(self, finding: Finding) -> Finding:
return self._add_finding(self.final_findings, finding)
def add_final_finding_from_dict(self, data: Dict) -> Finding:
return self.add_final_finding(Finding.from_dict(data))
def list_findings(self) -> List[Finding]:
return self._list_findings(self.findings)
def list_final_findings(self) -> List[Finding]:
return self._list_findings(self.final_findings)
def get_findings_by_segment(self, segment_id: int) -> List[Finding]:
return self._get_findings_by_segment(self.findings, segment_id)
def add_finding(self, key: str, finding: Finding) -> Finding:
return self._add_finding(key, finding)
def get_final_findings_by_segment(self, segment_id: int) -> List[Finding]:
return self._get_findings_by_segment(self.final_findings, segment_id)
def list_findings(self, key: str) -> List[Finding]:
return self._list_findings(self._get_findings_bucket(key))
def delete_findings_by_segment(self, segment_id: int) -> int:
return self._delete_findings_by_segment("findings", segment_id)
def get_findings_by_segment(self, key: str, segment_id: int) -> List[Finding]:
return self._get_findings_by_segment(self._get_findings_bucket(key), segment_id)
def delete_final_findings_by_segment(self, segment_id: int) -> int:
return self._delete_findings_by_segment("final_findings", segment_id)
def delete_findings_by_segment(self, key: str, segment_id: int) -> int:
return self._delete_findings_by_segment(key, segment_id)
def search_findings(self, keyword: str, rule_title: Optional[str] = None, risk_level: Optional[str] = None) -> List[Finding]:
return self._search_findings(self.findings, keyword, rule_title, risk_level)
def search_findings(self, key: str, keyword: str, rule_title: Optional[str] = None, risk_level: Optional[str] = None) -> List[Finding]:
return self._search_findings(self._get_findings_bucket(key), keyword, rule_title, risk_level)
def search_final_findings(self, keyword: str, rule_title: Optional[str] = None, risk_level: Optional[str] = None) -> List[Finding]:
return self._search_findings(self.final_findings, keyword, rule_title, risk_level)
def list_findings_grouped(self) -> Dict[str, List[Finding]]:
with self._lock:
return {k: list(v) for k, v in self.findings.items()}
def _add_finding(self, target: List[Finding], finding: Finding) -> Finding:
def _add_finding(self, key: str, finding: Finding) -> Finding:
with self._lock:
target.append(finding)
finding_key = self._normalize_finding_key(key)
if not finding.id:
finding.id = uuid4().hex
bucket = self.findings.setdefault(finding_key, [])
bucket.append(finding)
self._persist()
return finding
def _get_findings_bucket(self, key: str) -> List[Finding]:
finding_key = self._normalize_finding_key(key)
return self.findings.setdefault(finding_key, [])
def _list_findings(self, target: List[Finding]) -> List[Finding]:
with self._lock:
return list(target)
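For reference, a minimal usage sketch of the new keyed findings API (a sketch only: the `MemoryStore` constructor arguments and module path are assumptions; keys are normalized to lowercase, with `FINDING_KEY_REVIEW` / `FINDING_KEY_REFLECT` / `FINDING_KEY_MERGE` as the conventional buckets):

```python
from core.memory import Finding, FINDING_KEY_REVIEW, MemoryStore  # module path assumed

store = MemoryStore("demo-conversation")  # constructor arguments assumed
f = Finding(
    rule_title="违约责任", segment_id=1, original_text="违约方应赔偿全部损失",
    issue="未约定违约金上限", risk_level="H", suggestion="建议增加赔偿上限",
)
store.add_finding(FINDING_KEY_REVIEW, f)              # assigns f.id when empty
hits = store.get_findings_by_segment(FINDING_KEY_REVIEW, 1)
grouped = store.list_findings_grouped()               # {"review": [f], ...}
```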
@@ -167,12 +170,12 @@ class MemoryStore:
with self._lock:
return [f for f in target if f.segment_id == segment_id]
def _delete_findings_by_segment(self, attr_name: str, segment_id: int) -> int:
def _delete_findings_by_segment(self, key: str, segment_id: int) -> int:
with self._lock:
current = getattr(self, attr_name)
finding_key = self._normalize_finding_key(key)
current = self.findings.setdefault(finding_key, [])
before = len(current)
updated = [f for f in current if f.segment_id != segment_id]
setattr(self, attr_name, updated)
self.findings[finding_key] = updated
removed = before - len(updated)
if removed:
self._persist()
@@ -211,14 +214,15 @@ class MemoryStore:
with self._lock:
self.facts.clear()
self.findings.clear()
self.final_findings.clear()
self._persist()
def _persist(self) -> None:
payload = {
"facts": self.facts,
"findings": [asdict(f) for f in self.findings],
"final_findings": [asdict(f) for f in self.final_findings],
"findings": {
key: [asdict(f) for f in values]
for key, values in self.findings.items()
},
}
try:
self._storage_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
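For orientation, a sketch of the on-disk payload `_persist` now writes: findings are nested under their normalized keys instead of the former flat `findings`/`final_findings` lists (all values below are illustrative):

```python
payload = {
    "facts": [{"META": {"segment_id": 1}, "付款方式": "..."}],
    "findings": {
        "review":  [{"id": "9f1c...", "rule_title": "违约责任", "segment_id": 1,
                     "original_text": "...", "issue": "...", "risk_level": "H",
                     "suggestion": "...", "result": "不合格"}],
        "reflect": [],
        "merge":   [],
    },
}
```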
@@ -233,8 +237,23 @@ class MemoryStore:
data = json.loads(raw or "{}")
if isinstance(data, dict):
self.facts = data.get("facts") or []
self.findings = [Finding.from_dict(item) for item in data.get("findings", []) or []]
self.final_findings = [Finding.from_dict(item) for item in data.get("final_findings", []) or []]
loaded_findings = data.get("findings", {})
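# Older stores persisted "findings"/"final_findings" as flat lists; that legacy shape is not migrated here, so such entries are dropped on load.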
findings_map: Dict[str, List[Finding]] = {}
if isinstance(loaded_findings, dict):
for key, items in loaded_findings.items():
normalized_key = self._normalize_finding_key(str(key))
findings_map[normalized_key] = [Finding.from_dict(item) for item in (items or [])]
self.findings = findings_map
needs_persist = False
for bucket in self.findings.values():
for finding in bucket:
if not finding.id:
finding.id = uuid4().hex
needs_persist = True
if needs_persist:
self._persist()
except Exception as exc:
logger.error("Failed to load memory store: %s", exc)
@@ -251,10 +270,8 @@ class MemoryStore:
with self._lock:
wb = Workbook()
ws_final_findings = wb.active
ws_final_findings.title = "最终结果"
finding_headers = [
("id", "ID"),
("rule_title", "规则标题"),
("segment_id", "分段ID"),
("original_text", "原文"),
@@ -263,20 +280,27 @@
("result", "合格性"),
("suggestion", "建议"),
]
# add final findings
ws_final_findings.append([label for _, label in finding_headers])
for f in self.final_findings:
ws_final_findings.append([
getattr(f, key, "") for key, _ in finding_headers
])
# add findings
ws_findings = wb.create_sheet("中间结果")
ws_findings.append([label for _, label in finding_headers])
for f in self.findings:
ws_findings.append([
getattr(f, key, "") for key, _ in finding_headers
])
grouped_items = list(self.findings.items())
if grouped_items:
first_key, first_values = grouped_items[0]
ws_first = wb.active
first_sheet_name = _FINDING_KEY_SHEET_NAMES.get(self._normalize_finding_key(first_key), first_key)
ws_first.title = self._safe_sheet_name(first_sheet_name)
ws_first.append([label for _, label in finding_headers])
for f in first_values:
ws_first.append([getattr(f, key, "") for key, _ in finding_headers])
for key, values in grouped_items[1:]:
sheet_name = _FINDING_KEY_SHEET_NAMES.get(self._normalize_finding_key(key), key)
ws = wb.create_sheet(self._safe_sheet_name(sheet_name))
ws.append([label for _, label in finding_headers])
for f in values:
ws.append([getattr(f, item_key, "") for item_key, _ in finding_headers])
else:
ws_empty = wb.active
ws_empty.title = self._safe_sheet_name(_FINDING_KEY_SHEET_NAMES.get(_DEFAULT_REVIEW_KEY, _DEFAULT_REVIEW_KEY))
ws_empty.append([label for _, label in finding_headers])
ws_facts = wb.create_sheet("合同事实")
if self.facts:
@@ -285,7 +309,7 @@ class MemoryStore:
if not isinstance(item, dict):
ws_facts.append(["事实", json.dumps(item, ensure_ascii=False)])
continue
meta_info = item.pop(META_KEY, None)
meta_info = item.get(META_KEY)
fact_body = {k: v for k, v in item.items() if k != META_KEY}  # keep META out of the fact column without mutating the stored dict
ws_facts.append([json.dumps(meta_info, ensure_ascii=False), json.dumps(fact_body, ensure_ascii=False)])
else:
ws_facts.append(["元信息", "事实内容"])
@@ -307,7 +331,7 @@ class MemoryStore:
doc_obj: DocBase,
file_name: Optional[str] = None,
remove_prefix: bool = False,
export_final: bool = False,
finding_key: str = _DEFAULT_REVIEW_KEY,
) -> Dict[str, Any]:
"""Add all findings as comments to a document, upload, then delete the local file."""
if doc_obj is None:
@@ -321,12 +345,10 @@ class MemoryStore:
name = f"{name}{suffix}"
output_path = Path(__file__).resolve().parent.parent / "tmp" / name
if export_final:
target_findings = self.final_findings
else:
target_findings = self.findings
target_key = self._normalize_finding_key(finding_key)
with self._lock:
target_findings = list(self._get_findings_bucket(target_key))
comments: List[Dict[str, Any]] = []
for idx, f in enumerate(target_findings, start=1):
segment_id = int(f.segment_id or 0)
@@ -366,6 +388,21 @@ class MemoryStore:
return res
@staticmethod
def _safe_sheet_name(name: str) -> str:
# Excel sheet names cannot exceed 31 chars or include certain symbols.
safe = (name or _DEFAULT_REVIEW_KEY).strip() or _DEFAULT_REVIEW_KEY
for ch in [":", "\\", "/", "?", "*", "[", "]"]:
safe = safe.replace(ch, "_")
return safe[:31]
@staticmethod
def _normalize_finding_key(key: str) -> str:
normalized = (key or "").strip().lower()
if not normalized:
return _DEFAULT_REVIEW_KEY
return normalized
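A quick illustration of the two helpers (expected values per the code above):

```python
MemoryStore._safe_sheet_name("merge/结果?")     # -> 'merge_结果_'
MemoryStore._safe_sheet_name("")                # -> 'review' (default key fallback)
MemoryStore._normalize_finding_key(" Review ")  # -> 'review'
MemoryStore._normalize_finding_key("")          # -> 'review'
```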
def test_export_findings_to_doc_comments(doc_path: str) -> None:
@@ -379,7 +416,7 @@ def test_export_findings_to_doc_comments(doc_path: str) -> None:
suggestion="建议增加‘赔偿金额不超过合同总额的30%’",
result="不合格",
)
store.add_final_finding(finding)
store.add_finding(FINDING_KEY_REFLECT, finding)
"""测试:将 findings 作为批注写入文档并上传。"""
if not doc_path:
print("doc_path 为空,跳过批注导出测试")
@@ -408,23 +445,33 @@ def test_memory_and_export_excel():
"segment_id":1
}
})
print( store.search_facts(['支付']))
# finding = Finding(
# rule_title="违约责任",
# segment_id=1,
# original_text="违约方应赔偿全部损失",
# issue="未约定违约金上限,可能导致赔偿范围过大",
# risk_level="H",
# suggestion="建议增加‘赔偿金额不超过合同总额的30%’",
# )
# store.add_finding(finding)
# print( store.search_facts(['支付']))
finding1 = Finding(
rule_title="违约责任",
segment_id=1,
original_text="违约方应赔偿全部损失",
issue="未约定违约金上限,可能导致赔偿范围过大",
risk_level="H",
suggestion="建议增加‘赔偿金额不超过合同总额的30%’",
)
finding2 = Finding(
rule_title="违约责任",
segment_id=2,
original_text="违约方应赔偿全部损失",
issue="未约定违约金上限,可能导致赔偿范围过大",
risk_level="H",
suggestion="建议增加‘赔偿金额不超过合同总额的30%’",
)
store.add_finding(FINDING_KEY_REVIEW, finding1)
store.add_finding(FINDING_KEY_REFLECT, finding2)
print(store.get_findings_by_segment(FINDING_KEY_REVIEW, 1))
# print("Facts:\n" + json.dumps(store.get_facts(), ensure_ascii=False, indent=2))
# hits = store.search_findings("赔偿", rule_title="违约责任")
# print("Findings search:")
# for f in hits:
# print(json.dumps(asdict(f), ensure_ascii=False, indent=2))
# print(store.export_to_excel())
print(store.export_to_excel())
if __name__ == "__main__":
@@ -28,6 +28,7 @@ class MemoryWriteTool(ToolBase):
added: List[Dict] = []
for f in findings:
finding_id = str(f.get("id") or "")
rule_title = f.get("rule_title") or ""
issue = f.get("issue") or f.get("issue_description") or ""
level = (f.get("level") or f.get("risk_level") or "M").upper()
@@ -44,6 +45,7 @@ class MemoryWriteTool(ToolBase):
issue=issue,
risk_level=level,
suggestion=suggestion,
id=finding_id,
result=result,
)
store.add_finding(FINDING_KEY_REVIEW, finding_obj)  # bucket key now required; FINDING_KEY_REVIEW assumed imported from core.memory
from __future__ import annotations
import json
from typing import Any, Dict, List
from core.tool import tool, tool_func
from core.tools.segment_llm import LLMTool
from loguru import logger
MERGER_SYSTEM_PROMPT = '''
你是合同审查结果合并智能体(SegmentMerger)。
你的任务是:接收一组 findings,按 original_text 分组后合并。
【规则】
1. 以“文本重叠关系”分组:
- 完全相同:可合并。
- 存在公共子句(如一方是另一方子串,或两句共享连续核心片段):可合并。
- 无明显公共子句、语义独立:不可合并。
2. 每个分组最终只保留 1 条 finding。
3. 对于“公共子句合并”的分组,合并后的 original_text 取“并集文本”:
- 若 A 是 B 的子串,取 B。
- 若 A、B 各有新增片段且围绕同一事实,拼接为不重复、语义通顺的一条完整原文。
4. 完全不相关的句子必须保留为不同分组。
【分组示例】
- 句子1:"A:提交数据后15日内开票并付款。"
- 句子2:"B:提交数据后15日内开票并付款。另由子公司承担服务费。"
- 句子3:"甲方:xx公司"
应分为两组:
- 分组1(句子1+句子2):"提交数据后15日内开票并付款。另由子公司承担服务费。"
- 分组2(句子3):"甲方:xx公司"
【合并要求】
- 同步合并 issue 和 suggestion,兼顾组内要点
- 保留关键信息,不要机械拼接
【字段约束】
- 输出字段固定为:rule_title, segment_id, original_text, issue, risk_level, suggestion, result
- risk_level 仅允许 H/M/L/空字符串
- result 仅允许 合格/不合格/空字符串
- original_text 必须与该组合的原文一致
【输出要求】
- 严格输出 JSON
- 顶层结构必须是:{"findings": [...]}
'''
MERGER_USER_PROMPT = '''
【输入 findings】
{findings_json}
【任务】
请按“文本重叠关系(完全相同或存在公共子句)”分组后合并,并返回合并结果。
若同组文本可互补,请将 original_text 扩展为覆盖组内信息的并集文本;无关文本必须分到不同组。
输出 JSON。
'''
OUTPUT_EXAMPLE = '''
```json
{
"findings": [
{
"rule_title": "付款条款完整性",
"segment_id": 3,
"original_text": "甲方应于验收后30日内支付合同款项",
"issue": "付款期限明确,但未约定逾期违约责任,风险控制不足。",
"risk_level": "M",
"suggestion": "补充约定逾期付款违约金标准及计算方式。",
"result": "不合格"
}
]
}
```
'''
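For intuition, a deterministic approximation of the grouping rule the prompt describes (texts merge when identical or when one contains the other); the actual merge is delegated to the LLM, so this is only a sketch:

```python
from typing import Any, Dict, List

def group_by_overlap(findings: List[Dict[str, Any]]) -> List[List[Dict[str, Any]]]:
    """Group findings whose original_text values are identical or substrings of each other."""
    groups: List[List[Dict[str, Any]]] = []
    for f in findings:
        text = f.get("original_text", "")
        for group in groups:
            rep = group[0].get("original_text", "")
            # Same text, or one text contains the other: same group.
            if text and rep and (text in rep or rep in text):
                group.append(f)
                break
        else:
            groups.append([f])
    return groups
```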
def _as_dict(item: Any) -> Dict[str, Any]:
if isinstance(item, dict):
return dict(item)
if hasattr(item, "__dict__"):
return dict(getattr(item, "__dict__"))
return {}
def _normalize_finding(raw: Dict[str, Any]) -> Dict[str, Any]:
risk_level = str(raw.get("risk_level", "") or "").upper()
if risk_level not in {"H", "M", "L", ""}:
risk_level = ""
result = str(raw.get("result", "") or "")
if result not in {"合格", "不合格", ""}:
result = ""
return {
"rule_title": str(raw.get("rule_title", "") or ""),
"segment_id": int(raw.get("segment_id", 0) or 0),
"original_text": str(raw.get("original_text", "") or ""),
"issue": str(raw.get("issue", "") or ""),
"risk_level": risk_level,
"suggestion": str(raw.get("suggestion", "") or ""),
"result": result,
}
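For example, values outside the allowed vocabularies collapse to empty strings:

```python
_normalize_finding({"rule_title": "付款条款", "risk_level": "X", "result": "存疑"})
# -> {"rule_title": "付款条款", "segment_id": 0, "original_text": "",
#     "issue": "", "risk_level": "", "suggestion": "", "result": ""}
```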
@tool("segment_merger", "同证据 findings 合并")
class SegmentMergerTool(LLMTool):
def __init__(self) -> None:
super().__init__(MERGER_SYSTEM_PROMPT)
@tool_func(
{
"type": "object",
"properties": {
"findings": {"type": "array", "items": {"type": "object"}},
},
"required": ["findings"],
}
)
def run(self, findings: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]:
normalized_input = [_normalize_finding(_as_dict(item)) for item in (findings or [])]
if not normalized_input:
return {"findings": []}
msgs = self._build_prompt(normalized_input)
try:
resp = self.run_with_loop(self.chat_async(msgs))
data = self.parse_first_json(resp)
raw_findings = data.get("findings") or data.get("merged_findings") or []
if not isinstance(raw_findings, list):
raw_findings = []
normalized_output = [_normalize_finding(_as_dict(item)) for item in raw_findings]
return {"findings": normalized_output}
except Exception as e:
logger.error(f"SegmentMergerTool run error: {e}")
return {"findings": normalized_input}
def _build_prompt(self, findings: List[Dict[str, Any]]) -> List[Dict[str, str]]:
user_content = MERGER_USER_PROMPT.format(
findings_json=json.dumps(findings, ensure_ascii=False, indent=2)
) + OUTPUT_EXAMPLE
return self.build_messages(user_content)
if __name__ == "__main__":
tool = SegmentMergerTool()
sample = [
{
"rule_title": "支付时间审查",
"segment_id": 0,
"original_text": "本协议约定的服务内容全部履行完毕经甲方认可,在乙方提交检测数据后15个工作日内,乙方须向甲方提供相应数额的正规发票,甲方一次性支付合同总金额的100%给乙方。",
"issue": "付款条件缺乏实质把控。条款将付款绑定于提交数据,未明确'经甲方认可'的验收标准、期限及异议机制,且未设置质保金,存在验收流于形式即需全额付款的风险。",
"risk_level": "H",
"suggestion": "修改为:'乙方提交报告后,甲方在X个工作日内验收。验收合格且收到发票后15个工作日内支付95%;剩余5%作为质保金,满X个月无异议后无息支付。若验收不合格,甲方有权拒付并要求整改。'",
"result": "不合格"
},
{
"rule_title": "发票审查",
"segment_id": 0,
"original_text": "在乙方提交检测数据后15个工作日内,乙方须向甲方提供相应数额的正规发票,甲方一次性支付合同总金额的100%给乙方。甲方指定由其全资子公司长沙高新控股集团有限公司(简称高新控股)承担并支付本合同约定的检查服务费,乙方向高新控股开具相应金额的增值税专用发票。",
"issue": "缺失发票税率约定。条款明确了发票类型和开具时间,但未约定适用税率,违反审查规则,可能导致后续开票金额争议或税务合规风险。",
"risk_level": "H",
"suggestion": "补充税率约定。建议在'乙方向高新控股开具相应金额的增值税专用发票'后补充:'(税率:6%)'或根据实际业务类型补充具体税率数值。",
"result": "不合格"
},
{
"rule_title": "主体审查",
"segment_id": 0,
"original_text": "委托方(甲方): 湖南麓谷发展集团有限公司... 甲方指定由其全资子公司长沙高新控股集团有限公司(简称高新控股)承担并支付本合同约定的检查服务费... 签章处:甲方:长沙高新控股集团有限公司",
"issue": "签约主体不一致。首部甲方为'湖南麓谷发展集团有限公司',但签章处及付款义务主体变更为'长沙高新控股集团有限公司',且未明确授权委托或变更确认条款,存在主体混同及履约风险。",
"risk_level": "H",
"suggestion": "统一合同主体名称。若确由子公司履约,应将首部及正文甲方统一修改为'长沙高新控股集团有限公司';若由母公司签约,应在签章处由母公司盖章,并补充'指定子公司代为履行付款义务'条款。",
"result": "不合格"
}]
print(json.dumps(tool.run(sample), ensure_ascii=False, indent=2))
@@ -14,7 +14,7 @@ from utils.http_util import upload_file, fastgpt_openai_chat, download_file
SUFFIX='_麓发改进'
batch_input_dir_path = 'jp-input'
batch_output_dir_path = 'jp-all'
batch_output_dir_path = 'jp-output-lufa-simple'
batch_size = 5
# 麓发 FastGPT endpoint
# url = 'http://192.168.252.71:18089/api/v1/chat/completions'
@@ -121,7 +121,7 @@ def _parse_args() -> argparse.Namespace:
parser.add_argument(
"--datasets-dir",
type=Path,
default=base / "results" / "jp-all-merge-prompt",
default=base / "results" / "jp-output-renji",
help="Directory containing Word files with annotations.",
)
parser.add_argument(
@@ -133,13 +133,13 @@ def _parse_args() -> argparse.Namespace:
parser.add_argument(
"--val-dir",
type=Path,
default=base / "results" / "jp-all-merge-prompt-extracted",
default=base / "results" / "jp-output-renji-extracted",
help="Directory to store extracted xlsx files for comparison.",
)
parser.add_argument(
"--strip-suffixes",
nargs="*",
default=['_麓发改进'],
default=['_麓发改进','_人机交互'],
help=(
"Optional filename suffixes to strip from generated val xlsx stems before "
"comparison, e.g. --strip-suffixes _v1 _审阅版"
@@ -20,6 +20,9 @@ from core.tools.segment_review import SegmentReviewTool
from core.tools.segment_rule_router import SegmentRuleRouterTool
from core.tools.retrieve_reference import RetrieveReferenceTool
from core.tools.reflect_retry import ReflectRetryTool
from core.tools.segment_merger import SegmentMergerTool
from core.memory import Finding
from core.memory import FINDING_KEY_MERGE, FINDING_KEY_REFLECT, FINDING_KEY_REVIEW
app = FastAPI(title="合同审查智能体", version="0.1.0")
TMP_DIR = Path(__file__).resolve().parent / "tmp"
@@ -29,6 +32,7 @@ review_tool = SegmentReviewTool()
rule_router_tool = SegmentRuleRouterTool()
reference_tool = RetrieveReferenceTool()
reflect_tool = ReflectRetryTool()
merger_tool = SegmentMergerTool()
@app.post("/sleep")
@@ -214,7 +218,9 @@ def review_segment(payload: SegmentReviewRequest) -> SegmentReviewResponse:
# Persist findings to memory store
for f in result.get("findings", []) or []:
try:
store.add_finding_from_dict({
store.add_finding(
FINDING_KEY_REVIEW,
Finding.from_dict({
"rule_title": f.get("rule_title", ""),
"segment_id": segment_idx,
"original_text": f.get("original_text",''),
@@ -223,6 +229,7 @@ def review_segment(payload: SegmentReviewRequest) -> SegmentReviewResponse:
"result": f.get("result", ""),
"suggestion": f.get("suggestion", ""),
})
)
except Exception as e:
logger.error(e)
continue
@@ -278,6 +285,19 @@ class ReflectReviewResponse(BaseModel):
findings: List[Dict]
class MergerRequest(BaseModel):
conversation_id: str
segment_id: int
finding_key: Optional[str] = FINDING_KEY_REVIEW
class MergerResponse(BaseModel):
conversation_id: str
segment_id: int
finding_key: str
merged_findings: List[Dict]
@app.post("/segments/review/reflect", response_model=ReflectReviewResponse)
def reflect_review(payload: ReflectReviewRequest) -> ReflectReviewResponse:
store = get_cached_memory(payload.conversation_id)
@@ -289,7 +309,7 @@ def reflect_review(payload: ReflectReviewRequest) -> ReflectReviewResponse:
summary_keywords = reference_tool.summary_keywords([rule])
context_summaries_facts = store.search_facts(summary_keywords)
# Look up the findings recorded for this review rule
findings = [f.__dict__ for f in store.search_findings("", rule_title=payload.rule_title)]
findings = [f.__dict__ for f in store.search_findings(FINDING_KEY_REVIEW, "", rule_title=payload.rule_title)]
final_findings = reflect_tool.run(
party_role=payload.party_role,
rule=rule,
@@ -299,7 +319,9 @@ def reflect_review(payload: ReflectReviewRequest) -> ReflectReviewResponse:
for f in final_findings or []:
try:
store.add_final_finding_from_dict({
store.add_finding(
FINDING_KEY_REFLECT,
Finding.from_dict({
"rule_title": f.get("rule_title", ""),
"segment_id": f.get("segment_id", 0),
"original_text": f.get("original_text", ""),
@@ -308,7 +330,7 @@ def reflect_review(payload: ReflectReviewRequest) -> ReflectReviewResponse:
"suggestion": f.get("suggestion", ""),
"result": f.get("result", "")
})
# print(f'len(store) final_findings:{len(store.final_findings)}')
)
except Exception:
continue
return ReflectReviewResponse(
@@ -317,6 +339,46 @@ def reflect_review(payload: ReflectReviewRequest) -> ReflectReviewResponse:
findings=final_findings or [],
)
@app.post("/segments/review/merger", response_model=MergerResponse)
def merge_segment_findings(payload: MergerRequest) -> MergerResponse:
store = get_cached_memory(payload.conversation_id)
source_key = payload.finding_key or FINDING_KEY_REVIEW
target_segment_id = payload.segment_id - 1
segment_findings = store.get_findings_by_segment(source_key, target_segment_id)
unqualified_findings = [f for f in segment_findings if (f.result or "").strip() == "不合格"]
merged_result = merger_tool.run([f.__dict__ for f in unqualified_findings])
merged_findings = merged_result.get("findings", []) or []
for f in merged_findings:
try:
store.add_finding(
FINDING_KEY_MERGE,
Finding.from_dict(
{
"id": f.get("id", ""),
"rule_title": f.get("rule_title", ""),
"segment_id": target_segment_id,
"original_text": f.get("original_text", ""),
"issue": f.get("issue", ""),
"risk_level": (f.get("risk_level") or f.get("level") or "").upper(),
"suggestion": f.get("suggestion", ""),
"result": f.get("result", ""),
}
)
)
except Exception as e:
logger.error(e)
continue
return MergerResponse(
conversation_id=payload.conversation_id,
segment_id=target_segment_id,
finding_key=source_key,
merged_findings=merged_findings,
)
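A minimal client sketch for the new endpoint (host, port, and conversation_id below are placeholders):

```python
import requests

resp = requests.post(
    "http://localhost:8000/segments/review/merger",  # host/port assumed
    json={
        "conversation_id": "demo-conversation",
        "segment_id": 1,          # the handler merges findings of segment_id - 1
        "finding_key": "review",  # optional; defaults to FINDING_KEY_REVIEW
    },
    timeout=60,
)
print(resp.json()["merged_findings"])
```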
########################################################################################################################
class ConversationResponse(BaseModel):
conversation_id: str
@@ -363,7 +425,7 @@ class MemoryExportRequest(BaseModel):
conversation_id: str
file_ext: str
file_name: Optional[str] = None
export_final: Optional[bool] = False
finding_key: Optional[str] = FINDING_KEY_REVIEW
class MemoryExportResponse(BaseModel):
@@ -387,7 +449,7 @@ def export_memory(payload: MemoryExportRequest) -> MemoryExportResponse:
raise HTTPException(status_code=500, detail=f"Export failed: {exc}")
try:
doc_res = store.export_findings_to_doc_comments(doc_obj,export_final=payload.export_final or False)
doc_res = store.export_findings_to_doc_comments(doc_obj, finding_key=payload.finding_key or FINDING_KEY_REVIEW)
except Exception as exc:
traceback.print_exc()
raise HTTPException(status_code=500, detail=f"Export doc comments failed: {exc}")