Commit 9fa3fd96 by ccran

feat: 增加PDF处理;

parent 53a63f97
tmp/
*.pyc
__pycache__/
\ No newline at end of file
from utils.spire_word_util import SpireWordDoc from utils.spire_word_util import SpireWordDoc
from utils.spire_pdf_util import SpirePdfDoc
from utils.doc_util import DocBase from utils.doc_util import DocBase
from functools import lru_cache from functools import lru_cache
from typing import Optional from typing import Optional, Tuple
from core.memory import MemoryStore from core.memory import MemoryStore
from core.config import pdf_support_formats
MAX_CACHE = 128 MAX_CACHE = 128
def _normalize_file_ext(file_ext: str) -> str:
if not file_ext:
raise ValueError("file_ext is required")
ext = file_ext.strip().lower()
if not ext.startswith("."):
ext = f".{ext}"
return ext
@lru_cache(maxsize=MAX_CACHE)
def get_cached_doc_tool(conversation_id: str, file_ext: str) -> Tuple[DocBase, str]:
    """Return a (document handler, normalized extension) pair for a conversation.

    Results are memoized per (conversation_id, file_ext) pair, so repeated
    requests within the same conversation reuse one handler instance.
    ``conversation_id`` is unused in the body — it exists purely as part of
    the cache key.

    Raises:
        ValueError: if ``file_ext`` is empty (raised by ``_normalize_file_ext``).
    """
    ext = _normalize_file_ext(file_ext)
    # PDF extensions get the Spire PDF backend; everything else falls back
    # to the Word backend (the pre-PDF default).
    if ext in pdf_support_formats:
        return SpirePdfDoc(), ext
    return SpireWordDoc(), ext
@lru_cache(maxsize=MAX_CACHE) @lru_cache(maxsize=MAX_CACHE)
def get_cached_memory(conversation_id: str) -> MemoryStore: def get_cached_memory(conversation_id: str) -> MemoryStore:
......
...@@ -29,7 +29,6 @@ class LLMConfig: ...@@ -29,7 +29,6 @@ class LLMConfig:
outer_backend_url = "http://218.77.58.8:48081" outer_backend_url = "http://218.77.58.8:48081"
base_fastgpt_url = "http://192.168.252.71:18089" base_fastgpt_url = "http://192.168.252.71:18089"
base_backend_url = "http://192.168.252.71:48081" base_backend_url = "http://192.168.252.71:48081"
ocr_url = 'http://192.168.252.71:8202/openapi/ocrUploadFile'
# 项目根目录 # 项目根目录
root_path = r"E:\PycharmProject\contract_review_agent" root_path = r"E:\PycharmProject\contract_review_agent"
...@@ -70,7 +69,6 @@ field_mapping = { ...@@ -70,7 +69,6 @@ field_mapping = {
excel_widths = {"review": [5, 20, 80, 80, 20, 80, 5, 60], "category": [80, 80, 30]} excel_widths = {"review": [5, 20, 80, 80, 20, 80, 5, 60], "category": [80, 80, 30]}
max_review_group = 5 max_review_group = 5
# 销售类别判断 # 销售类别判断
ocr_thresholds = 1000
all_rule_sheet = ["内销或出口", "内销", "出口", "反思"] all_rule_sheet = ["内销或出口", "内销", "出口", "反思"]
reflection_sheet = "反思" reflection_sheet = "反思"
# 最大分片数量 # 最大分片数量
......
...@@ -58,6 +58,7 @@ from threading import RLock ...@@ -58,6 +58,7 @@ from threading import RLock
from typing import Any, Dict, Iterable, List, Optional from typing import Any, Dict, Iterable, List, Optional
from utils.http_util import upload_file from utils.http_util import upload_file
from utils.doc_util import DocBase
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -107,6 +108,7 @@ class MemoryStore: ...@@ -107,6 +108,7 @@ class MemoryStore:
self._lock = RLock() self._lock = RLock()
self.facts: List[Dict[str, Any]] = [] self.facts: List[Dict[str, Any]] = []
self.findings: List[RiskFinding] = [] self.findings: List[RiskFinding] = []
self.final_findings: List[RiskFinding] = []
self._load() self._load()
# ---------------------- facts ---------------------- # ---------------------- facts ----------------------
...@@ -128,35 +130,76 @@ class MemoryStore: ...@@ -128,35 +130,76 @@ class MemoryStore:
# -------------------- findings --------------------- # -------------------- findings ---------------------
def add_finding(self, finding: RiskFinding) -> RiskFinding: def add_finding(self, finding: RiskFinding) -> RiskFinding:
with self._lock: return self._add_finding(self.findings, finding)
self.findings.append(finding)
self._persist()
return finding
def add_finding_from_dict(self, data: Dict) -> RiskFinding: def add_finding_from_dict(self, data: Dict) -> RiskFinding:
return self.add_finding(RiskFinding.from_dict(data)) return self.add_finding(RiskFinding.from_dict(data))
# "Final" findings are the post-reflection results; they mirror the primary
# findings API and share the same locked/persisted backing store.
def add_final_finding(self, finding: RiskFinding) -> RiskFinding:
"""Append *finding* to the final-findings list, persist, and return it."""
return self._add_finding(self.final_findings, finding)
def add_final_finding_from_dict(self, data: Dict) -> RiskFinding:
"""Build a RiskFinding from *data* and record it as a final finding."""
return self.add_final_finding(RiskFinding.from_dict(data))
def list_findings(self) -> List[RiskFinding]: def list_findings(self) -> List[RiskFinding]:
with self._lock: return self._list_findings(self.findings)
return list(self.findings)
def list_final_findings(self) -> List[RiskFinding]:
"""Return a snapshot copy of the final (post-reflection) findings."""
return self._list_findings(self.final_findings)
def get_findings_by_segment(self, segment_id: int) -> List[RiskFinding]: def get_findings_by_segment(self, segment_id: int) -> List[RiskFinding]:
with self._lock: return self._get_findings_by_segment(self.findings, segment_id)
return [f for f in self.findings if f.segment_id == segment_id]
def get_final_findings_by_segment(self, segment_id: int) -> List[RiskFinding]:
"""Return final findings whose ``segment_id`` equals *segment_id*."""
return self._get_findings_by_segment(self.final_findings, segment_id)
def delete_findings_by_segment(self, segment_id: int) -> int: def delete_findings_by_segment(self, segment_id: int) -> int:
return self._delete_findings_by_segment("findings", segment_id)
def delete_final_findings_by_segment(self, segment_id: int) -> int:
return self._delete_findings_by_segment("final_findings", segment_id)
def search_findings(self, keyword: str, rule_title: Optional[str] = None, risk_level: Optional[str] = None) -> List[RiskFinding]:
"""Search primary findings by keyword, optionally filtered by rule title and risk level."""
return self._search_findings(self.findings, keyword, rule_title, risk_level)
def search_final_findings(self, keyword: str, rule_title: Optional[str] = None, risk_level: Optional[str] = None) -> List[RiskFinding]:
"""Search final (post-reflection) findings with the same filters as ``search_findings``."""
return self._search_findings(self.final_findings, keyword, rule_title, risk_level)
def _add_finding(self, target: List[RiskFinding], finding: RiskFinding) -> RiskFinding:
"""Append *finding* to *target* under the store lock, then persist to disk."""
with self._lock:
target.append(finding)
# Persist inside the lock so the on-disk snapshot matches in-memory state.
self._persist()
return finding
def _list_findings(self, target: List[RiskFinding]) -> List[RiskFinding]:
"""Return a shallow copy of *target*, taken under the store lock."""
with self._lock:
return list(target)
def _get_findings_by_segment(self, target: List[RiskFinding], segment_id: int) -> List[RiskFinding]:
    """Return the findings in *target* whose ``segment_id`` equals *segment_id*."""
    with self._lock:
        return [f for f in target if f.segment_id == segment_id]

def _delete_findings_by_segment(self, attr_name: str, segment_id: int) -> int:
    """Remove all findings for *segment_id* from the list attribute *attr_name*.

    ``attr_name`` is ``"findings"`` or ``"final_findings"``.  The attribute is
    rebound to a new list rather than mutated in place, so snapshots handed out
    earlier stay valid.  Persists only when something was actually removed.

    Returns:
        The number of findings removed.
    """
    with self._lock:
        current = getattr(self, attr_name)
        before = len(current)
        updated = [f for f in current if f.segment_id != segment_id]
        setattr(self, attr_name, updated)
        removed = before - len(updated)
        if removed:
            self._persist()
        return removed
def search_findings(self, keyword: str, rule_title: Optional[str] = None, risk_level: Optional[str] = None) -> List[RiskFinding]: def _search_findings(
self,
target: List[RiskFinding],
keyword: str,
rule_title: Optional[str] = None,
risk_level: Optional[str] = None,
) -> List[RiskFinding]:
key = (keyword or "").strip().lower() key = (keyword or "").strip().lower()
with self._lock: with self._lock:
candidates = list(self.findings) candidates = list(target)
if rule_title: if rule_title:
candidates = [f for f in candidates if (f.rule_title or "").lower() == rule_title.strip().lower()] candidates = [f for f in candidates if (f.rule_title or "").lower() == rule_title.strip().lower()]
if risk_level: if risk_level:
...@@ -179,12 +222,14 @@ class MemoryStore: ...@@ -179,12 +222,14 @@ class MemoryStore:
with self._lock: with self._lock:
self.facts.clear() self.facts.clear()
self.findings.clear() self.findings.clear()
self.final_findings.clear()
self._persist() self._persist()
def _persist(self) -> None: def _persist(self) -> None:
payload = { payload = {
"facts": self.facts, "facts": self.facts,
"findings": [asdict(f) for f in self.findings], "findings": [asdict(f) for f in self.findings],
"final_findings": [asdict(f) for f in self.final_findings],
} }
try: try:
self._storage_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8") self._storage_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
...@@ -200,6 +245,7 @@ class MemoryStore: ...@@ -200,6 +245,7 @@ class MemoryStore:
if isinstance(data, dict): if isinstance(data, dict):
self.facts = data.get("facts") or [] self.facts = data.get("facts") or []
self.findings = [RiskFinding.from_dict(item) for item in data.get("findings", []) or []] self.findings = [RiskFinding.from_dict(item) for item in data.get("findings", []) or []]
self.final_findings = [RiskFinding.from_dict(item) for item in data.get("final_findings", []) or []]
except Exception as exc: except Exception as exc:
logger.error("Failed to load memory store: %s", exc) logger.error("Failed to load memory store: %s", exc)
...@@ -233,6 +279,13 @@ class MemoryStore: ...@@ -233,6 +279,13 @@ class MemoryStore:
getattr(f, key, "") for key, _ in finding_headers getattr(f, key, "") for key, _ in finding_headers
]) ])
ws_final_findings = wb.create_sheet("final_findings")
ws_final_findings.append([label for _, label in finding_headers])
for f in self.final_findings:
ws_final_findings.append([
getattr(f, key, "") for key, _ in finding_headers
])
ws_facts = wb.create_sheet("facts") ws_facts = wb.create_sheet("facts")
if self.facts: if self.facts:
fact_keys: List[str] = sorted({k for item in self.facts for k in item.keys()}) fact_keys: List[str] = sorted({k for item in self.facts for k in item.keys()})
...@@ -263,8 +316,97 @@ class MemoryStore: ...@@ -263,8 +316,97 @@ class MemoryStore:
return res return res
def export_findings_to_doc_comments(
    self,
    doc_obj: DocBase,
    file_name: Optional[str] = None,
    remove_prefix: bool = False,
) -> Dict[str, Any]:
    """Add all findings as comments to a document, upload, then delete the local file.

    Args:
        doc_obj: loaded document handler; must support ``add_chunk_comment`` and
            ``to_file`` (as SpireWordDoc/SpirePdfDoc do in this project).
        file_name: optional output name; defaults to ``findings_<timestamp><suffix>``.
        remove_prefix: forwarded to ``doc_obj.to_file``.

    Returns:
        The upload service's response dict.

    Raises:
        ValueError: if *doc_obj* is None.
    """
    if doc_obj is None:
        raise ValueError("doc_obj is required")

    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    # Keep the source document's extension so Word/PDF round-trip correctly.
    doc_name = getattr(doc_obj, "_doc_name", "") or ""
    suffix = Path(doc_name).suffix or ".docx"
    name = file_name or f"findings_{ts}{suffix}"
    if not Path(name).suffix:
        name = f"{name}{suffix}"
    output_path = Path(__file__).resolve().parent.parent / "tmp" / name

    with self._lock:
        comments: List[Dict[str, Any]] = []
        # NOTE(review): exports the primary findings; iterate
        # self.final_findings here to export post-reflection results instead.
        for idx, f in enumerate(self.findings, start=1):
            segment_id = int(f.segment_id or 0)
            # Comments are addressed by 0-based chunk id; segment ids are 1-based.
            chunk_id = max(segment_id - 1, 0)
            suggest_parts = []
            if f.risk_level:
                suggest_parts.append(f"风险等级:{f.risk_level}")
            if f.issue:
                suggest_parts.append(f"问题:{f.issue}")
            if f.suggestion:
                suggest_parts.append(f"建议:{f.suggestion}")
            suggest_text = "\n".join(suggest_parts).strip()
            comments.append(
                {
                    "id": str(idx),
                    "key_points": f.rule_title or "风险提示",
                    "original_text": f.original_text or "",
                    "details": f.issue or "",
                    "chunk_id": chunk_id,
                    "result": "不合格",
                    "suggest": suggest_text,
                }
            )
        if comments:
            doc_obj.add_chunk_comment(0, comments)
        doc_obj.to_file(str(output_path), remove_prefix=remove_prefix)

    # Upload outside failure-prone cleanup; always try to remove the temp file.
    try:
        res = upload_file(str(output_path))
    finally:
        try:
            output_path.unlink()
        except Exception:
            logger.warning("Failed to delete temp doc: %s", output_path)
    return res
def test_export_findings_to_doc_comments(doc_path: str) -> None:
    """Smoke test: write stored findings as document comments and upload the result.

    Skips (with a message) when *doc_path* is empty, missing on disk, or the
    SpireWordDoc backend cannot be imported.
    """
    # Fix: the original placed this docstring mid-function (a no-op string
    # statement) and built the store before the skip guards, persisting state
    # even when the test was skipped.
    if not doc_path:
        print("doc_path 为空,跳过批注导出测试")
        return
    if not Path(doc_path).exists():
        print(f"文件不存在,跳过批注导出测试: {doc_path}")
        return
    try:
        from utils.spire_word_util import SpireWordDoc
    except Exception as exc:
        print(f"加载 SpireWordDoc 失败,跳过批注导出测试: {exc}")
        return

    store = MemoryStore()
    # Seed one deterministic finding so the exported document is non-empty.
    finding = RiskFinding(
        rule_title="违约责任",
        segment_id=1,
        original_text="湖南麓谷发展集团有限公司",
        issue="未约定违约金上限,可能导致赔偿范围过大",
        risk_level="H",
        suggestion="建议增加‘赔偿金额不超过合同总额的30%’",
    )
    store.add_final_finding(finding)

    doc = SpireWordDoc()
    doc.load(doc_path)
    res = store.export_findings_to_doc_comments(doc)
    print("Export doc comments:")
    print(json.dumps(res, ensure_ascii=False, indent=2))
def test_memory_and_export_excel():
# 简单示例:设置事实 -> 写入问题 -> 读取/搜索 # 简单示例:设置事实 -> 写入问题 -> 读取/搜索
store = MemoryStore() store = MemoryStore()
store.set_facts([{ store.set_facts([{
...@@ -295,3 +437,8 @@ if __name__ == "__main__": ...@@ -295,3 +437,8 @@ if __name__ == "__main__":
print(json.dumps(asdict(f), ensure_ascii=False, indent=2)) print(json.dumps(asdict(f), ensure_ascii=False, indent=2))
print(store.export_to_excel()) print(store.export_to_excel())
# Manual entry point: run the in-memory store demo; the doc-comment export
# test is kept here (commented out) for local runs against a real file.
if __name__ == "__main__":
# test_export_findings_to_doc_comments("/home/ccran/lufa-contract/tmp/股份转让协议.docx")
test_memory_and_export_excel()
from __future__ import annotations from __future__ import annotations
from typing import Dict, List, Optional import json
from typing import Dict, List, Optional, Any
from core.tool import ToolBase, tool, tool_func from core.tool import tool, tool_func
from core.tools.segment_llm import LLMTool
# System prompt for the reflection agent: restricts it to keep/update/add/remove
# operations over existing findings and JSON-only output.
REFLECT_SYSTEM_PROMPT = '''
你是一个合同审查反思智能体(ReviewReflection)。
你的任务是:基于 facts 与全文上下文,
对已有 findings 输出修改操作。
你只能对 findings 执行以下四种操作:
- keep:确认该风险结论成立且无需修改
- update:修改一个已有风险
- add:新增一个由全文结构推导出的风险
- remove:删除一个不成立的风险
【输出约束】
- 输出必须是 JSON
- 每条操作必须仅包含字段:op、id、findings
- findings 在 add / update 时必须是完整 finding
- 严格按照输出 JSON Schema 返回结果,不得输出任何解释性文字
'''
# User prompt template; str.format placeholders: {findings_json}, {facts_json},
# {party_role}. OUTPUT_FORMAT_SCHEMA is appended AFTER formatting, so its
# literal braces are never interpreted by str.format.
REFLECT_USER_PROMPT = '''
【输入】
【已有风险 findings】
{findings_json}
【合同事实记忆 facts】
{facts_json}
【合同立场】
站在 {party_role} 的立场进行反思审查。
【反思原则】
- 不得引入新的审查维度
- 所有判断必须有 facts 或合同原文证据支持
- 仅对已有 findings 进行增add、改update、删remove、保留keep操作
- 若风险在全文上下文中不成立,必须 remove
- 若风险成立但表述、严重性或建议不准确,必须 update
- 若由多个 findings 或全文结构推导出新的系统性风险,可 add
【输出要求】
- 输出必须是 JSON
- 每条操作必须仅包含字段:op、id、findings
- findings 在 add / update 时必须是完整 finding
- 严格按照输出 JSON Schema 返回结果,不得输出任何解释性文字
'''
# JSON shape the model must return: a list of {op, id, findings} operations.
OUTPUT_FORMAT_SCHEMA = '''
```json
{
"operations": [
{
"op": "keep | add | update | remove",
"id": "string | null",
"findings": "object | null"
}
]
}
```
'''
@tool("reflect_retry", "反思重试质量闸")
class ReflectRetryTool(LLMTool):
def __init__(self) -> None:
super().__init__(REFLECT_SYSTEM_PROMPT)
@tool("reflect_retry", "反思重试质量闸(MVP)")
class ReflectRetryTool(ToolBase):
@tool_func( @tool_func(
{ {
"type": "object", "type": "object",
"properties": { "properties": {
"segment_id": {"type": "string"}, "party_role": {"type": "string"},
"ruleset_id": {"type": "string"}, "rule": {"type": "object"},
"context_memories": {"type": "array"}, "facts": {"type": "array"},
"review_output": {"type": "object"}, "findings": {"type": "array"},
}, },
"required": ["segment_id", "ruleset_id", "review_output"], "required": ["party_role", "rule", "facts", "findings"],
} }
) )
def run(self, segment_id: str, ruleset_id: str, context_memories: Optional[List[Dict]] = None, review_output: Dict = None) -> Dict: def run(self, party_role: str, rule: Dict, facts: Optional[List[Dict]] = None, findings: Optional[List[Dict]] = None) -> List[Dict]:
problems: List[str] = [] base_findings = self._build_findings_with_ids(findings or [])
actions: List[Dict] = [] user_content = REFLECT_USER_PROMPT.format(
findings = (review_output or {}).get("findings", []) findings_json=json.dumps(base_findings, ensure_ascii=False),
missings = (review_output or {}).get("missing_info", []) facts_json=json.dumps(facts or [], ensure_ascii=False),
party_role=party_role,
for f in findings: ) + OUTPUT_FORMAT_SCHEMA
evs = f.get("evidence_quotes", []) or [] messages = self.build_messages(user_content)
sugg = (f.get("suggestion") or "").strip()
if not evs: try:
problems.append("缺少对关键结论的原文引用") resp = self.run_with_loop(self.chat_async(messages))
if not sugg: data = self.parse_first_json(resp)
problems.append("建议条款措辞仍偏泛,未给可替换文本") except Exception:
data = {}
needs_evidence = any(m.get("type") == "evidence" for m in missings)
needs_knowledge = any(m.get("type") == "knowledge" for m in missings) operations = data.get("operations", []) or []
status = "pass" final_findings = self._apply_operations(base_findings, operations)
if needs_evidence: return final_findings
actions.append({"action": "FullRead", "target": segment_id, "need_focus": [m.get("need") for m in missings if m.get("type") == "evidence"]})
if needs_knowledge: def _build_findings_with_ids(self, findings: List[Dict]) -> List[Dict[str, Any]]:
actions.append({"action": "RetrieveReference", "topic": [m.get("topic") for m in missings if m.get("type") == "knowledge"]}) res: List[Dict[str, Any]] = []
if problems or actions: for idx, f in enumerate(findings):
status = "revise" fid = f.get("id") or f.get("finding_id") or f.get("_id") or f"f_{idx+1}"
item = dict(f)
retry_budget_left = int((review_output or {}).get("retry_budget_left", 1)) item["id"] = str(fid)
res.append(item)
return { return res
"segment_id": segment_id,
"status": status, def _apply_operations(self, base_findings: List[Dict[str, Any]], operations: List[Dict]) -> List[Dict[str, Any]]:
"problems": list(dict.fromkeys(problems)), by_id = {str(f.get("id")): dict(f) for f in base_findings if f.get("id") is not None}
"revise_instructions": actions, added: List[Dict[str, Any]] = []
"retry_budget_left": max(0, retry_budget_left - (1 if status == "revise" else 0)),
} for op in operations:
action = (op.get("op") or "").strip().lower()
target_id = op.get("id")
payload = op.get("findings")
if action == "keep":
if target_id is not None and str(target_id) in by_id:
continue
elif action == "remove":
if target_id is not None:
by_id.pop(str(target_id), None)
elif action == "update":
if target_id is not None and isinstance(payload, dict):
payload = dict(payload)
payload["id"] = str(target_id)
by_id[str(target_id)] = payload
elif action == "add":
if isinstance(payload, dict):
added.append(dict(payload))
merged = list(by_id.values()) + added
return merged
# Manual smoke test: exercises the reflection tool with a hand-built rule,
# fact memory, and two seed findings (requires a live LLM backend to run).
if __name__ == "__main__":
tool = ReflectRetryTool()
res = tool.run(
party_role="甲方",
rule={"title":"付款与验收", "segment_id": 3},
facts=[
{"segment_id": 3, "支付": {"比例": "30%预付款", "期限": "验收后30日内"}},
{"segment_id": 3, "验收": {"标准": "双方书面确认", "时限": "7个工作日"}},
],
findings=[
{
"rule_title": "付款条款审查",
"segment_id": 3,
"original_text": "甲方在验收合格后30日内支付剩余价款。",
"issue": "未约定逾期付款违约金,约束不足",
"risk_level": "M",
"suggestion": "增加逾期付款违约金条款,如逾期每日按未付款的0.05%计收"
},
{
"rule_title": "验收条款审查",
"segment_id": 3,
"original_text": "乙方交付后甲方7个工作日内完成验收。",
"issue": "未明确逾期未验收的后果",
"risk_level": "L",
"suggestion": "补充逾期未验收视为通过或约定明确责任"
},
]
)
print(res)
\ No newline at end of file
...@@ -9,8 +9,8 @@ from utils.openai_util import OpenAITool ...@@ -9,8 +9,8 @@ from utils.openai_util import OpenAITool
from core.config import LLM, MAX_WORKERS from core.config import LLM, MAX_WORKERS
class SegmentLLMBase(ToolBase): class LLMTool(ToolBase):
"""LLM-backed segment processor: builds prompts, calls LLM, parses JSON.""" """LLM-backed processor: builds prompts, calls LLM, parses JSON."""
def __init__(self, system_prompt: str, llm_key: str = "fastgpt_segment_review") -> None: def __init__(self, system_prompt: str, llm_key: str = "fastgpt_segment_review") -> None:
super().__init__() super().__init__()
......
...@@ -7,10 +7,10 @@ from typing import Dict, List, Optional ...@@ -7,10 +7,10 @@ from typing import Dict, List, Optional
from core.tool import tool, tool_func from core.tool import tool, tool_func
from utils.excel_util import ExcelUtil from utils.excel_util import ExcelUtil
from core.tools.segment_llm import SegmentLLMBase from core.tools.segment_llm import LLMTool
import re import re
REVIEW_SYSTEM_PROMPT = f''' REVIEW_SYSTEM_PROMPT = '''
你是一个专业的合同分段审查智能体(SegmentReview)。 你是一个专业的合同分段审查智能体(SegmentReview)。
你的核心任务是对“当前分段”进行【法律风险识别】,并给出可落地的修改建议。 你的核心任务是对“当前分段”进行【法律风险识别】,并给出可落地的修改建议。
...@@ -100,10 +100,10 @@ def _is_generic_suggestion(text: str) -> bool: ...@@ -100,10 +100,10 @@ def _is_generic_suggestion(text: str) -> bool:
return any(g in t for g in GENERIC_SUGGESTIONS) return any(g in t for g in GENERIC_SUGGESTIONS)
@tool("segment_review", "合同分段审查") @tool("segment_review", "合同分段审查")
class SegmentReviewTool(SegmentLLMBase): class SegmentReviewTool(LLMTool):
def __init__(self): def __init__(self):
super().__init__(REVIEW_SYSTEM_PROMPT) super().__init__(REVIEW_SYSTEM_PROMPT)
self.rule_version = "rule-v1" self.rule_version = "通用"
self.column_map = { self.column_map = {
"id": "ID", "id": "ID",
"title": "审查项", "title": "审查项",
...@@ -278,7 +278,7 @@ if __name__=="__main__": ...@@ -278,7 +278,7 @@ if __name__=="__main__":
result = tool.run( result = tool.run(
segment_id=1, segment_id=1,
segment_text="本合同自双方签字盖章之日起生效,有效期为两年,期满后自动续展一年,除非一方提前30天书面通知对方终止。", segment_text="本合同自双方签字盖章之日起生效,有效期为两年,期满后自动续展一年,除非一方提前30天书面通知对方终止。",
ruleset_id="rule-v1", ruleset_id="通用",
party_role="甲方", party_role="甲方",
context_memories=[ context_memories=[
{ {
......
...@@ -6,7 +6,7 @@ import json ...@@ -6,7 +6,7 @@ import json
from typing import Dict, List, Optional from typing import Dict, List, Optional
from core.tool import tool, tool_func from core.tool import tool, tool_func
from core.tools.segment_llm import SegmentLLMBase from core.tools.segment_llm import LLMTool
FACT_DIMENSIONS: List[str] = ["当事人", "标的", "金额", "支付", "交付", "质量", "知识产权", "保密", "违约责任", "争议解决"] FACT_DIMENSIONS: List[str] = ["当事人", "标的", "金额", "支付", "交付", "质量", "知识产权", "保密", "违约责任", "争议解决"]
...@@ -48,7 +48,7 @@ OUTPUT_EXAMPLE = ''' ...@@ -48,7 +48,7 @@ OUTPUT_EXAMPLE = '''
@tool("segment_summary", "分段事实提取") @tool("segment_summary", "分段事实提取")
class SegmentSummaryTool(SegmentLLMBase): class SegmentSummaryTool(LLMTool):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__(SUMMARY_SYSTEM_PROMPT) super().__init__(SUMMARY_SYSTEM_PROMPT)
......
No preview for this file type
...@@ -6,13 +6,15 @@ from uuid import uuid4 ...@@ -6,13 +6,15 @@ from uuid import uuid4
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
import uvicorn import uvicorn
import traceback
from utils.common_util import extract_url_file, format_now from utils.common_util import extract_url_file, format_now
from utils.http_util import download_file from utils.http_util import download_file
from core.cache import get_cached_doc_tool, get_cached_memory from core.cache import get_cached_doc_tool, get_cached_memory
from core.config import doc_support_formats from core.config import doc_support_formats, pdf_support_formats
from core.tools.segment_summary import SegmentSummaryTool from core.tools.segment_summary import SegmentSummaryTool
from core.tools.segment_review import SegmentReviewTool from core.tools.segment_review import SegmentReviewTool
from core.tools.reflect_retry import ReflectRetryTool
from core.memory import RiskFinding from core.memory import RiskFinding
app = FastAPI(title="合同审查智能体", version="0.1.0") app = FastAPI(title="合同审查智能体", version="0.1.0")
...@@ -20,25 +22,31 @@ TMP_DIR = Path(__file__).resolve().parent / "tmp" ...@@ -20,25 +22,31 @@ TMP_DIR = Path(__file__).resolve().parent / "tmp"
TMP_DIR.mkdir(parents=True, exist_ok=True) TMP_DIR.mkdir(parents=True, exist_ok=True)
summary_tool = SegmentSummaryTool() summary_tool = SegmentSummaryTool()
review_tool = SegmentReviewTool() review_tool = SegmentReviewTool()
reflect_tool = ReflectRetryTool()
######################################################################################################################## ########################################################################################################################
class DocumentParseRequest(BaseModel): class DocumentParseRequest(BaseModel):
conversation_id: str conversation_id: str
urls: List[str] = Field(..., description="File download url") urls: List[str] = Field(..., description="File download url")
file_ext: Optional[str] = None
ruleset_id: Optional[str] = "通用"
class DocumentParseResponse(BaseModel): class DocumentParseResponse(BaseModel):
conversation_id: str conversation_id: str
text: str
chunk_ids: List[int] chunk_ids: List[int]
ruleset_items: List[str]
text: Optional[str] = None
file_ext: Optional[str] = None
@app.post("/documents/parse", response_model=DocumentParseResponse) @app.post("/documents/parse", response_model=DocumentParseResponse)
def parse_document(payload: DocumentParseRequest) -> DocumentParseResponse: async def parse_document(payload: DocumentParseRequest) -> DocumentParseResponse:
if not payload.urls: if not payload.urls:
raise HTTPException(status_code=400, detail="No URLs provided") raise HTTPException(status_code=400, detail="No URLs provided")
try: try:
filename = extract_url_file(payload.urls[0], doc_support_formats) support_formats = list(dict.fromkeys(doc_support_formats + pdf_support_formats))
filename = extract_url_file(payload.urls[0], support_formats)
except Exception as exc: except Exception as exc:
raise HTTPException(status_code=400, detail=f"Failed to parse url: {exc}") raise HTTPException(status_code=400, detail=f"Failed to parse url: {exc}")
...@@ -47,13 +55,27 @@ def parse_document(payload: DocumentParseRequest) -> DocumentParseResponse: ...@@ -47,13 +55,27 @@ def parse_document(payload: DocumentParseRequest) -> DocumentParseResponse:
download_file(payload.urls[0], file_path) download_file(payload.urls[0], file_path)
except Exception as exc: except Exception as exc:
raise HTTPException(status_code=500, detail=f"Download failed: {exc}") raise HTTPException(status_code=500, detail=f"Download failed: {exc}")
# get doc tool
doc_obj = get_cached_doc_tool(payload.conversation_id) file_ext = payload.file_ext or Path(filename).suffix
try:
doc_obj, _ = get_cached_doc_tool(payload.conversation_id, file_ext)
except Exception as exc:
raise HTTPException(status_code=400, detail=f"Document tool not available: {exc}")
doc_obj.load(file_path) doc_obj.load(file_path)
text = doc_obj.get_all_text() # ocr
await doc_obj.get_from_ocr()
# text = doc_obj.get_all_text()
chunk_ids = doc_obj.get_chunk_id_list() chunk_ids = doc_obj.get_chunk_id_list()
# get ruleset items
ruleset_id = payload.ruleset_id or review_tool.rule_version
ruleset_items = review_tool.rulesets.get(ruleset_id) or []
ruleset_review_items = [r.get('title') for r in ruleset_items]
return DocumentParseResponse( return DocumentParseResponse(
conversation_id=payload.conversation_id, text=text,chunk_ids=chunk_ids conversation_id=payload.conversation_id,
# text=text,
chunk_ids=chunk_ids,
ruleset_items=ruleset_review_items,
file_ext = file_ext
) )
######################################################################################################################## ########################################################################################################################
...@@ -62,6 +84,7 @@ class SegmentSummaryRequest(BaseModel): ...@@ -62,6 +84,7 @@ class SegmentSummaryRequest(BaseModel):
conversation_id: str conversation_id: str
segment_id: int segment_id: int
party_role: Optional[str] = "" party_role: Optional[str] = ""
file_ext: str
context_facts: Optional[Dict] = None context_facts: Optional[Dict] = None
...@@ -75,7 +98,7 @@ class SegmentSummaryResponse(BaseModel): ...@@ -75,7 +98,7 @@ class SegmentSummaryResponse(BaseModel):
def summarize_facts(payload: SegmentSummaryRequest) -> SegmentSummaryResponse: def summarize_facts(payload: SegmentSummaryRequest) -> SegmentSummaryResponse:
store = get_cached_memory(payload.conversation_id) store = get_cached_memory(payload.conversation_id)
try: try:
doc_obj = get_cached_doc_tool(payload.conversation_id) doc_obj, _ = get_cached_doc_tool(payload.conversation_id, payload.file_ext)
except Exception as exc: except Exception as exc:
raise HTTPException(status_code=400, detail=f"Document tool not available: {exc}") raise HTTPException(status_code=400, detail=f"Document tool not available: {exc}")
...@@ -105,7 +128,8 @@ class SegmentReviewRequest(BaseModel): ...@@ -105,7 +128,8 @@ class SegmentReviewRequest(BaseModel):
conversation_id: str conversation_id: str
segment_id: int segment_id: int
party_role: Optional[str] = "" party_role: Optional[str] = ""
ruleset_id: Optional[str] = "rule-v1" ruleset_id: Optional[str] = "通用"
file_ext: str
context_memories: Optional[List[Dict]] = None context_memories: Optional[List[Dict]] = None
...@@ -119,7 +143,7 @@ class SegmentReviewResponse(BaseModel): ...@@ -119,7 +143,7 @@ class SegmentReviewResponse(BaseModel):
def review_segment(payload: SegmentReviewRequest) -> SegmentReviewResponse: def review_segment(payload: SegmentReviewRequest) -> SegmentReviewResponse:
store = get_cached_memory(payload.conversation_id) store = get_cached_memory(payload.conversation_id)
try: try:
doc_obj = get_cached_doc_tool(payload.conversation_id) doc_obj, _ = get_cached_doc_tool(payload.conversation_id, payload.file_ext)
except Exception as exc: except Exception as exc:
raise HTTPException(status_code=400, detail=f"Document tool not available: {exc}") raise HTTPException(status_code=400, detail=f"Document tool not available: {exc}")
...@@ -132,7 +156,7 @@ def review_segment(payload: SegmentReviewRequest) -> SegmentReviewResponse: ...@@ -132,7 +156,7 @@ def review_segment(payload: SegmentReviewRequest) -> SegmentReviewResponse:
result = review_tool.run( result = review_tool.run(
segment_id=payload.segment_id, segment_id=payload.segment_id,
segment_text=segment_text, segment_text=segment_text,
ruleset_id=payload.ruleset_id or "rule-v1", ruleset_id=payload.ruleset_id or "通用",
party_role=payload.party_role or "", party_role=payload.party_role or "",
context_memories=payload.context_memories or store.get_facts(), context_memories=payload.context_memories or store.get_facts(),
) )
...@@ -159,6 +183,57 @@ def review_segment(payload: SegmentReviewRequest) -> SegmentReviewResponse: ...@@ -159,6 +183,57 @@ def review_segment(payload: SegmentReviewRequest) -> SegmentReviewResponse:
) )
######################################################################################################################## ########################################################################################################################
class ReflectReviewRequest(BaseModel):
"""Request body for POST /segments/review/reflect."""
conversation_id: str
party_role: str
# Ruleset name; defaults to the generic ("通用") ruleset.
ruleset_id: Optional[str] = "通用"
rule_title: str
class ReflectReviewResponse(BaseModel):
"""Response body: the reflected (final) findings for one rule."""
conversation_id: str
rule_title: str
findings: List[Dict]
@app.post("/segments/review/reflect", response_model=ReflectReviewResponse)
def reflect_review(payload: ReflectReviewRequest) -> ReflectReviewResponse:
"""Run the reflection pass over stored findings for one rule and persist results."""
store = get_cached_memory(payload.conversation_id)
# Resolve the ruleset and locate the single rule being reflected on.
ruleset_id = payload.ruleset_id or review_tool.rule_version
ruleset_items = review_tool.rulesets.get(ruleset_id) or []
rule = next((r for r in ruleset_items if r.get("title") == payload.rule_title), None)
if not rule:
raise HTTPException(status_code=404, detail=f"Rule not found: {payload.rule_title}")
facts = store.get_facts()
findings = [f.__dict__ for f in store.list_findings()]
final_findings = reflect_tool.run(
party_role=payload.party_role,
rule=rule,
facts=facts,
findings=findings,
)
# Best effort: record each reflected finding as a "final" finding; skip
# malformed entries rather than failing the whole request.
for f in final_findings or []:
try:
store.add_final_finding_from_dict({
"rule_title": f.get("rule_title", ""),
"segment_id": f.get("segment_id", 0),
"original_text": f.get("original_text", ""),
"issue": f.get("issue", ""),
# Accept either "risk_level" or legacy "level"; normalize to upper case.
"risk_level": (f.get("risk_level") or f.get("level") or "").upper(),
"suggestion": f.get("suggestion", ""),
})
except Exception:
continue
return ReflectReviewResponse(
conversation_id=payload.conversation_id,
rule_title=payload.rule_title,
findings=final_findings or [],
)
########################################################################################################################
class ConversationResponse(BaseModel): class ConversationResponse(BaseModel):
conversation_id: str conversation_id: str
created_at: str created_at: str
...@@ -173,44 +248,41 @@ def new_conversation() -> ConversationResponse: ...@@ -173,44 +248,41 @@ def new_conversation() -> ConversationResponse:
class MemoryExportRequest(BaseModel): class MemoryExportRequest(BaseModel):
conversation_id: str conversation_id: str
file_ext: str
file_name: Optional[str] = None file_name: Optional[str] = None
class MemoryExportResponse(BaseModel): class MemoryExportResponse(BaseModel):
conversation_id: str conversation_id: str
url: str excel_url: str
data: Dict doc_url: str
@app.post("/memory/export", response_model=MemoryExportResponse)
def export_memory(payload: MemoryExportRequest) -> MemoryExportResponse:
    """Export a conversation's review findings as an Excel report plus an
    annotated copy of the source document.

    Raises:
        HTTPException: 400 when no document tool can be resolved for
            ``payload.file_ext``; 500 when either export step fails.
    """
    store = get_cached_memory(payload.conversation_id)
    try:
        # The extension picks the PDF or Word implementation (see get_cached_doc_tool).
        doc_obj, _ = get_cached_doc_tool(payload.conversation_id, payload.file_ext)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Document tool not available: {exc}")
    try:
        excel_res = store.export_to_excel(file_name=payload.file_name)
    except ImportError as exc:
        # Missing optional Excel dependency surfaces as-is.
        raise HTTPException(status_code=500, detail=str(exc))
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Export failed: {exc}")
    try:
        doc_res = store.export_findings_to_doc_comments(doc_obj)
    except Exception as exc:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Export doc comments failed: {exc}")
    return MemoryExportResponse(
        conversation_id=payload.conversation_id,
        excel_url=excel_res,
        doc_url=doc_res,
    )
if __name__ == "__main__": if __name__ == "__main__":
uvicorn.run( uvicorn.run(
......
import json import json_repair
import random import random
import re import re
from datetime import datetime from datetime import datetime
from typing import Dict, List from typing import Dict, List
from loguru import logger from loguru import logger
from core.config import max_chunk_page, min_single_chunk_size, max_single_chunk_size from core.config import max_chunk_page, min_single_chunk_size, max_single_chunk_size
...@@ -51,14 +50,14 @@ def extract_json(json_str:str) -> List[Dict]: ...@@ -51,14 +50,14 @@ def extract_json(json_str:str) -> List[Dict]:
# 清理控制字符 # 清理控制字符
s = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', s) s = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', s)
try: try:
obj = json.loads(s, strict=False) obj = json_repair.loads(s, strict=False)
if isinstance(obj, list): if isinstance(obj, list):
out_list.extend(obj) out_list.extend(obj)
else: else:
out_list.append(obj) out_list.append(obj)
return True return True
except Exception as e: except Exception as e:
logger.error(f"JSON解析失败: {e}") logger.error(f"JSON解析失败: {e} | 内容片段: {s}")
return False return False
results = [] results = []
......
...@@ -53,7 +53,7 @@ def fastgpt_openai_chat(url, token, model, chat_id, file_url, text, stream=True) ...@@ -53,7 +53,7 @@ def fastgpt_openai_chat(url, token, model, chat_id, file_url, text, stream=True)
return rsp return rsp
def upload_file(path, input_url_to_inner=True, output_url_to_inner=False): def upload_file(path, input_url_to_inner=True, output_url_to_inner=False) -> str:
# 登录获取token # 登录获取token
login_data = {"username": "admin", "password": "admin@jpai.com"} login_data = {"username": "admin", "password": "admin@jpai.com"}
login_url = f"{base_backend_url}/admin-api/system/auth/login" login_url = f"{base_backend_url}/admin-api/system/auth/login"
...@@ -109,5 +109,5 @@ def url_replace_fastgpt(origin: str): ...@@ -109,5 +109,5 @@ def url_replace_fastgpt(origin: str):
if __name__ == "__main__":
    # Manual smoke test: upload a local file and print the returned URL.
    # d = '/home/ccran/file.docx'
    d = "/home/ccran/lufa-contract/tmp/default.json"
    print(upload_file(d))
import fitz
from urllib import parse
from urllib.parse import urlparse
import os
import asyncio
import aiohttp
from aiohttp import ClientSession
from utils.http_util import url_replace_fastgpt, download_file
from utils.common_util import random_str
from loguru import logger
import json
ocr_url = 'http://192.168.252.71:8202/openapi/ocrUploadFile'
class OCRUtil:
    """OCR a PDF end to end: download it, rasterize each page to a PNG,
    OCR the pages concurrently, and return the page texts in document order."""

    def __init__(self):
        pass

    async def ocr_requests_async(self, session, file_path):
        """POST one image file to the OCR service; returns (response_text, file_path)."""
        # File uploads use multipart/form-data. The file handle is closed once the
        # request completes (the previous version leaked it via a bare open()).
        with open(file_path, 'rb') as fh:
            async with session.post(ocr_url, data={'file': fh}) as response:
                rsp = await response.text()
        return rsp, file_path

    async def ocr_image_async(self, path_list):
        """OCR every image in path_list and return the texts sorted by page number."""
        timeout = aiohttp.ClientTimeout(total=600)
        # limit=1 serializes connections so the OCR service is not overloaded.
        connector = aiohttp.TCPConnector(limit=1)
        async with ClientSession(connector=connector, timeout=timeout) as session:
            tasks = [self.ocr_requests_async(session, file_path) for file_path in path_list]
            responses = await asyncio.gather(*tasks)
            res_dict = {}
            for response_text, file_path in responses:
                rsp_json = json.loads(response_text)
                if 'data' not in rsp_json:
                    logger.error(f'ocr_image_async {file_path} error:{rsp_json["msg"]}')
                    continue
                # Page number is encoded in the image filename (see set_pdf_2_img_page).
                page_num = int(self.get_pdf_2_img_page_num(file_path))
                res_dict[page_num] = rsp_json['data']['strRes']
            logger.info(f'ocr_image_async finish. all pages:{len(res_dict)}')
            # Sort by page number so output follows document order.
            return [res_dict[key] for key in sorted(res_dict)]

    def set_pdf_2_img_page(self, path, page_idx):
        """Build the PNG path for 0-based page page_idx; pages are numbered from 1."""
        return f'{path}_{page_idx + 1}.png'

    def get_pdf_2_img_page_num(self, path):
        """Recover the page number from a path built by set_pdf_2_img_page.

        For example 'xx_10.png' -> '10'.
        """
        return path.split('_')[-1][:-4]

    def pdf_2_img(self, path, zoom_x=1, zoom_y=1):
        """Render every page of the PDF at path to a PNG; returns PNG paths in page order."""
        pdf = fitz.open(path)
        pdf_list = []
        try:
            for pg in range(pdf.page_count):
                page = pdf[pg]
                # The zoom matrix controls render resolution.
                trans = fitz.Matrix(zoom_x, zoom_y)
                pm = page.get_pixmap(matrix=trans, alpha=False)
                dest_png = self.set_pdf_2_img_page(path, pg)
                pm.save(dest_png)
                pdf_list.append(dest_png)
        finally:
            # Close the document even if a page fails to render.
            pdf.close()
        return pdf_list

    def ocr_download_path(self, url):
        """Download the file behind url into 'ocr/' and return the local path."""
        logger.info(f'ocr url:{url}')
        # Rewrite external FastGPT URLs to their internal equivalent.
        url = url_replace_fastgpt(url)
        url_parsed = urlparse(url)
        query_dict = parse.parse_qs(url_parsed.query)
        if 'filename' in query_dict:
            filename = query_dict.get('filename')[0]
        else:
            filename = f'{random_str()}.pdf'
        # FIX: use the extracted filename in the destination path (it was
        # previously computed but never interpolated).
        dest_path = f'ocr/{filename}'
        download_file(url, dest_path)
        return dest_path

    async def ocr_result_pdf(self, dest_path):
        """OCR every page of the PDF at dest_path; temporary page PNGs are removed."""
        pdf_list = self.pdf_2_img(dest_path)
        try:
            result = await self.ocr_image_async(pdf_list)
        finally:
            # Always clean up the rendered page images, even on OCR failure.
            for pdf in pdf_list:
                if os.path.exists(pdf):
                    os.remove(pdf)
        return result
if __name__ == '__main__':
    # Manual smoke test: OCR a sample contract PDF and show the page count.
    ocr_helper = OCRUtil()
    pages = asyncio.run(
        ocr_helper.ocr_result_pdf('/home/ccran/lufa-contract/tmp/财03 2023年国科京东方评估合同.pdf')
    )
    print(f'len(result):{len(pages)}')
    print(pages)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment