Commit 9fa3fd96 by ccran

feat: add PDF processing

parent 53a63f97
tmp/
*.pyc
**__pycache__**
\ No newline at end of file
from utils.spire_word_util import SpireWordDoc
from utils.spire_pdf_util import SpirePdfDoc
from utils.doc_util import DocBase
from functools import lru_cache
from typing import Optional, Tuple
from core.memory import MemoryStore
from core.config import pdf_support_formats

MAX_CACHE = 128


def _normalize_file_ext(file_ext: str) -> str:
    if not file_ext:
        raise ValueError("file_ext is required")
    ext = file_ext.strip().lower()
    if not ext.startswith("."):
        ext = f".{ext}"
    return ext


@lru_cache(maxsize=MAX_CACHE)
def get_cached_doc_tool(conversation_id: str, file_ext: str) -> Tuple[DocBase, str]:
    ext = _normalize_file_ext(file_ext)
    if ext in pdf_support_formats:
        return SpirePdfDoc(), ext
    return SpireWordDoc(), ext


@lru_cache(maxsize=MAX_CACHE)
def get_cached_memory(conversation_id: str) -> MemoryStore:
......
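For reference, a minimal usage sketch (not part of the commit) of how the updated cache is expected to be called; it assumes `pdf_support_formats` includes `".pdf"`:

```python
# Minimal sketch of the extension-based dispatch (assumes pdf_support_formats includes ".pdf").
from core.cache import get_cached_doc_tool

doc_obj, ext = get_cached_doc_tool("conv-001", "pdf")    # normalized to ".pdf" -> SpirePdfDoc
word_obj, _ = get_cached_doc_tool("conv-001", ".docx")   # falls through to SpireWordDoc
assert ext == ".pdf"
# lru_cache keys on the (conversation_id, file_ext) pair, so the same conversation
# can hold one cached Word tool and one cached PDF tool at the same time.
```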
...@@ -29,7 +29,6 @@ class LLMConfig:
outer_backend_url = "http://218.77.58.8:48081"
base_fastgpt_url = "http://192.168.252.71:18089"
base_backend_url = "http://192.168.252.71:48081"
# Project root directory
root_path = r"E:\PycharmProject\contract_review_agent"
...@@ -70,7 +69,6 @@ field_mapping = {
excel_widths = {"review": [5, 20, 80, 80, 20, 80, 5, 60], "category": [80, 80, 30]}
max_review_group = 5
# Sales category decision (domestic vs. export)
all_rule_sheet = ["内销或出口", "内销", "出口", "反思"]
reflection_sheet = "反思"
# Maximum number of chunks
......
...@@ -58,6 +58,7 @@ from threading import RLock
from typing import Any, Dict, Iterable, List, Optional
from utils.http_util import upload_file
from utils.doc_util import DocBase

logger = logging.getLogger(__name__)
...@@ -107,6 +108,7 @@ class MemoryStore:
        self._lock = RLock()
        self.facts: List[Dict[str, Any]] = []
        self.findings: List[RiskFinding] = []
        self.final_findings: List[RiskFinding] = []
        self._load()

    # ---------------------- facts ----------------------
...@@ -128,35 +130,76 @@ class MemoryStore:
    # -------------------- findings ---------------------
    def add_finding(self, finding: RiskFinding) -> RiskFinding:
        return self._add_finding(self.findings, finding)

    def add_finding_from_dict(self, data: Dict) -> RiskFinding:
        return self.add_finding(RiskFinding.from_dict(data))

    def add_final_finding(self, finding: RiskFinding) -> RiskFinding:
        return self._add_finding(self.final_findings, finding)

    def add_final_finding_from_dict(self, data: Dict) -> RiskFinding:
        return self.add_final_finding(RiskFinding.from_dict(data))

    def list_findings(self) -> List[RiskFinding]:
        return self._list_findings(self.findings)

    def list_final_findings(self) -> List[RiskFinding]:
        return self._list_findings(self.final_findings)

    def get_findings_by_segment(self, segment_id: int) -> List[RiskFinding]:
        return self._get_findings_by_segment(self.findings, segment_id)

    def get_final_findings_by_segment(self, segment_id: int) -> List[RiskFinding]:
        return self._get_findings_by_segment(self.final_findings, segment_id)

    def delete_findings_by_segment(self, segment_id: int) -> int:
        return self._delete_findings_by_segment("findings", segment_id)

    def delete_final_findings_by_segment(self, segment_id: int) -> int:
        return self._delete_findings_by_segment("final_findings", segment_id)

    def search_findings(self, keyword: str, rule_title: Optional[str] = None, risk_level: Optional[str] = None) -> List[RiskFinding]:
        return self._search_findings(self.findings, keyword, rule_title, risk_level)

    def search_final_findings(self, keyword: str, rule_title: Optional[str] = None, risk_level: Optional[str] = None) -> List[RiskFinding]:
        return self._search_findings(self.final_findings, keyword, rule_title, risk_level)

    def _add_finding(self, target: List[RiskFinding], finding: RiskFinding) -> RiskFinding:
        with self._lock:
            target.append(finding)
            self._persist()
            return finding

    def _list_findings(self, target: List[RiskFinding]) -> List[RiskFinding]:
        with self._lock:
            return list(target)

    def _get_findings_by_segment(self, target: List[RiskFinding], segment_id: int) -> List[RiskFinding]:
        with self._lock:
            return [f for f in target if f.segment_id == segment_id]

    def _delete_findings_by_segment(self, attr_name: str, segment_id: int) -> int:
        with self._lock:
            current = getattr(self, attr_name)
            before = len(current)
            updated = [f for f in current if f.segment_id != segment_id]
            setattr(self, attr_name, updated)
            removed = before - len(updated)
            if removed:
                self._persist()
            return removed

    def _search_findings(
        self,
        target: List[RiskFinding],
        keyword: str,
        rule_title: Optional[str] = None,
        risk_level: Optional[str] = None,
    ) -> List[RiskFinding]:
        key = (keyword or "").strip().lower()
        with self._lock:
            candidates = list(target)
        if rule_title:
            candidates = [f for f in candidates if (f.rule_title or "").lower() == rule_title.strip().lower()]
        if risk_level:
...@@ -179,12 +222,14 @@ class MemoryStore:
        with self._lock:
            self.facts.clear()
            self.findings.clear()
            self.final_findings.clear()
            self._persist()

    def _persist(self) -> None:
        payload = {
            "facts": self.facts,
            "findings": [asdict(f) for f in self.findings],
            "final_findings": [asdict(f) for f in self.final_findings],
        }
        try:
            self._storage_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
...@@ -200,6 +245,7 @@ class MemoryStore:
            if isinstance(data, dict):
                self.facts = data.get("facts") or []
                self.findings = [RiskFinding.from_dict(item) for item in data.get("findings", []) or []]
                self.final_findings = [RiskFinding.from_dict(item) for item in data.get("final_findings", []) or []]
        except Exception as exc:
            logger.error("Failed to load memory store: %s", exc)
...@@ -233,6 +279,13 @@ class MemoryStore:
                getattr(f, key, "") for key, _ in finding_headers
            ])

        ws_final_findings = wb.create_sheet("final_findings")
        ws_final_findings.append([label for _, label in finding_headers])
        for f in self.final_findings:
            ws_final_findings.append([
                getattr(f, key, "") for key, _ in finding_headers
            ])

        ws_facts = wb.create_sheet("facts")
        if self.facts:
            fact_keys: List[str] = sorted({k for item in self.facts for k in item.keys()})
...@@ -263,8 +316,97 @@ class MemoryStore:
        return res
def export_findings_to_doc_comments(
self,
doc_obj: DocBase,
file_name: Optional[str] = None,
remove_prefix: bool = False,
) -> Dict[str, Any]:
"""Add all findings as comments to a document, upload, then delete the local file."""
if doc_obj is None:
raise ValueError("doc_obj is required")
        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
doc_name = getattr(doc_obj, "_doc_name", "") or ""
suffix = Path(doc_name).suffix or ".docx"
name = file_name or f"findings_{ts}{suffix}"
if not Path(name).suffix:
name = f"{name}{suffix}"
output_path = Path(__file__).resolve().parent.parent / "tmp" / name
with self._lock:
comments: List[Dict[str, Any]] = []
for idx, f in enumerate(self.findings, start=1):
            # To export final_findings instead, iterate over that list:
            # for idx, f in enumerate(self.final_findings, start=1):
segment_id = int(f.segment_id or 0)
chunk_id = max(segment_id - 1, 0)
suggest_parts = []
if f.risk_level:
suggest_parts.append(f"风险等级:{f.risk_level}")
if f.issue:
suggest_parts.append(f"问题:{f.issue}")
if f.suggestion:
suggest_parts.append(f"建议:{f.suggestion}")
suggest_text = "\n".join(suggest_parts).strip()
comments.append(
{
"id": str(idx),
"key_points": f.rule_title or "风险提示",
"original_text": f.original_text or "",
"details": f.issue or "",
"chunk_id": chunk_id,
"result": "不合格",
"suggest": suggest_text,
}
)
if comments:
doc_obj.add_chunk_comment(0, comments)
doc_obj.to_file(str(output_path), remove_prefix=remove_prefix)
try:
res = upload_file(str(output_path))
finally:
try:
output_path.unlink()
except Exception:
logger.warning("Failed to delete temp doc: %s", output_path)
return res
def test_export_findings_to_doc_comments(doc_path: str) -> None:
store = MemoryStore()
finding = RiskFinding(
rule_title="违约责任",
segment_id=1,
original_text="湖南麓谷发展集团有限公司",
issue="未约定违约金上限,可能导致赔偿范围过大",
risk_level="H",
suggestion="建议增加‘赔偿金额不超过合同总额的30%’",
)
store.add_final_finding(finding)
"""测试:将 findings 作为批注写入文档并上传。"""
if not doc_path:
print("doc_path 为空,跳过批注导出测试")
return
if not Path(doc_path).exists():
print(f"文件不存在,跳过批注导出测试: {doc_path}")
return
try:
from utils.spire_word_util import SpireWordDoc
except Exception as exc:
print(f"加载 SpireWordDoc 失败,跳过批注导出测试: {exc}")
return
doc = SpireWordDoc()
doc.load(doc_path)
res = store.export_findings_to_doc_comments(doc)
print("Export doc comments:")
print(json.dumps(res, ensure_ascii=False, indent=2))
def test_memory_and_export_excel():
    # Simple walkthrough: set facts -> record findings -> read/search
    store = MemoryStore()
    store.set_facts([{
...@@ -295,3 +437,8 @@ if __name__ == "__main__":
        print(json.dumps(asdict(f), ensure_ascii=False, indent=2))
    print(store.export_to_excel())
if __name__ == "__main__":
# test_export_findings_to_doc_comments("/home/ccran/lufa-contract/tmp/股份转让协议.docx")
test_memory_and_export_excel()
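For orientation, a minimal sketch (not part of the commit) of the new `final_findings` flow in `MemoryStore`: per-segment review results accumulate in `findings`, the reflection pass writes its merged result into `final_findings`, and both lists are persisted and exported. Values below are illustrative.

```python
# Minimal sketch of the final_findings flow (values are illustrative).
from core.memory import MemoryStore

store = MemoryStore()
# Per-segment review results go into the working list...
store.add_finding_from_dict({
    "rule_title": "付款条款审查",
    "segment_id": 3,
    "original_text": "甲方在验收合格后30日内支付剩余价款。",
    "issue": "未约定逾期付款违约金",
    "risk_level": "M",
    "suggestion": "增加逾期付款违约金条款",
})
# ...and the reflection pass stores its confirmed/merged result separately.
store.add_final_finding_from_dict({
    "rule_title": "付款条款审查",
    "segment_id": 3,
    "original_text": "甲方在验收合格后30日内支付剩余价款。",
    "issue": "未约定逾期付款违约金,且无付款逾期的救济措施",
    "risk_level": "M",
    "suggestion": "增加逾期付款违约金条款",
})
print(len(store.list_findings()), len(store.list_final_findings()))
print(store.export_to_excel())  # the workbook now carries a separate "final_findings" sheet
```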
from __future__ import annotations

import json
from typing import Dict, List, Optional, Any

from core.tool import tool, tool_func
from core.tools.segment_llm import LLMTool
REFLECT_SYSTEM_PROMPT = '''
你是一个合同审查反思智能体(ReviewReflection)。
你的任务是:基于 facts 与全文上下文,
对已有 findings 输出修改操作。
你只能对 findings 执行以下四种操作:
- keep:确认该风险结论成立且无需修改
- update:修改一个已有风险
- add:新增一个由全文结构推导出的风险
- remove:删除一个不成立的风险
【输出约束】
- 输出必须是 JSON
- 每条操作必须仅包含字段:op、id、findings
- findings 在 add / update 时必须是完整 finding
- 严格按照输出 JSON Schema 返回结果,不得输出任何解释性文字
'''
REFLECT_USER_PROMPT = '''
【输入】
【已有风险 findings】
{findings_json}
【合同事实记忆 facts】
{facts_json}
【合同立场】
站在 {party_role} 的立场进行反思审查。
【反思原则】
- 不得引入新的审查维度
- 所有判断必须有 facts 或合同原文证据支持
- 仅对已有 findings 进行增add、改update、删remove、保留keep操作
- 若风险在全文上下文中不成立,必须 remove
- 若风险成立但表述、严重性或建议不准确,必须 update
- 若由多个 findings 或全文结构推导出新的系统性风险,可 add
【输出要求】
- 输出必须是 JSON
- 每条操作必须仅包含字段:op、id、findings
- findings 在 add / update 时必须是完整 finding
- 严格按照输出 JSON Schema 返回结果,不得输出任何解释性文字
'''
OUTPUT_FORMAT_SCHEMA = '''
```json
{
"operations": [
{
"op": "keep | add | update | remove",
"id": "string | null",
"findings": "object | null"
}
]
}
```
'''
@tool("reflect_retry", "反思重试质量闸")
class ReflectRetryTool(LLMTool):
def __init__(self) -> None:
super().__init__(REFLECT_SYSTEM_PROMPT)
@tool("reflect_retry", "反思重试质量闸(MVP)")
class ReflectRetryTool(ToolBase):
@tool_func( @tool_func(
{ {
"type": "object", "type": "object",
"properties": { "properties": {
"segment_id": {"type": "string"}, "party_role": {"type": "string"},
"ruleset_id": {"type": "string"}, "rule": {"type": "object"},
"context_memories": {"type": "array"}, "facts": {"type": "array"},
"review_output": {"type": "object"}, "findings": {"type": "array"},
}, },
"required": ["segment_id", "ruleset_id", "review_output"], "required": ["party_role", "rule", "facts", "findings"],
} }
) )
def run(self, segment_id: str, ruleset_id: str, context_memories: Optional[List[Dict]] = None, review_output: Dict = None) -> Dict: def run(self, party_role: str, rule: Dict, facts: Optional[List[Dict]] = None, findings: Optional[List[Dict]] = None) -> List[Dict]:
problems: List[str] = [] base_findings = self._build_findings_with_ids(findings or [])
actions: List[Dict] = [] user_content = REFLECT_USER_PROMPT.format(
findings = (review_output or {}).get("findings", []) findings_json=json.dumps(base_findings, ensure_ascii=False),
missings = (review_output or {}).get("missing_info", []) facts_json=json.dumps(facts or [], ensure_ascii=False),
party_role=party_role,
for f in findings: ) + OUTPUT_FORMAT_SCHEMA
evs = f.get("evidence_quotes", []) or [] messages = self.build_messages(user_content)
sugg = (f.get("suggestion") or "").strip()
if not evs: try:
problems.append("缺少对关键结论的原文引用") resp = self.run_with_loop(self.chat_async(messages))
if not sugg: data = self.parse_first_json(resp)
problems.append("建议条款措辞仍偏泛,未给可替换文本") except Exception:
data = {}
needs_evidence = any(m.get("type") == "evidence" for m in missings)
needs_knowledge = any(m.get("type") == "knowledge" for m in missings) operations = data.get("operations", []) or []
status = "pass" final_findings = self._apply_operations(base_findings, operations)
if needs_evidence: return final_findings
actions.append({"action": "FullRead", "target": segment_id, "need_focus": [m.get("need") for m in missings if m.get("type") == "evidence"]})
if needs_knowledge: def _build_findings_with_ids(self, findings: List[Dict]) -> List[Dict[str, Any]]:
actions.append({"action": "RetrieveReference", "topic": [m.get("topic") for m in missings if m.get("type") == "knowledge"]}) res: List[Dict[str, Any]] = []
if problems or actions: for idx, f in enumerate(findings):
status = "revise" fid = f.get("id") or f.get("finding_id") or f.get("_id") or f"f_{idx+1}"
item = dict(f)
retry_budget_left = int((review_output or {}).get("retry_budget_left", 1)) item["id"] = str(fid)
res.append(item)
return { return res
"segment_id": segment_id,
"status": status, def _apply_operations(self, base_findings: List[Dict[str, Any]], operations: List[Dict]) -> List[Dict[str, Any]]:
"problems": list(dict.fromkeys(problems)), by_id = {str(f.get("id")): dict(f) for f in base_findings if f.get("id") is not None}
"revise_instructions": actions, added: List[Dict[str, Any]] = []
"retry_budget_left": max(0, retry_budget_left - (1 if status == "revise" else 0)),
} for op in operations:
action = (op.get("op") or "").strip().lower()
target_id = op.get("id")
payload = op.get("findings")
if action == "keep":
if target_id is not None and str(target_id) in by_id:
continue
elif action == "remove":
if target_id is not None:
by_id.pop(str(target_id), None)
elif action == "update":
if target_id is not None and isinstance(payload, dict):
payload = dict(payload)
payload["id"] = str(target_id)
by_id[str(target_id)] = payload
elif action == "add":
if isinstance(payload, dict):
added.append(dict(payload))
merged = list(by_id.values()) + added
return merged
if __name__ == "__main__":
tool = ReflectRetryTool()
res = tool.run(
party_role="甲方",
rule={"title":"付款与验收", "segment_id": 3},
facts=[
{"segment_id": 3, "支付": {"比例": "30%预付款", "期限": "验收后30日内"}},
{"segment_id": 3, "验收": {"标准": "双方书面确认", "时限": "7个工作日"}},
],
findings=[
{
"rule_title": "付款条款审查",
"segment_id": 3,
"original_text": "甲方在验收合格后30日内支付剩余价款。",
"issue": "未约定逾期付款违约金,约束不足",
"risk_level": "M",
"suggestion": "增加逾期付款违约金条款,如逾期每日按未付款的0.05%计收"
},
{
"rule_title": "验收条款审查",
"segment_id": 3,
"original_text": "乙方交付后甲方7个工作日内完成验收。",
"issue": "未明确逾期未验收的后果",
"risk_level": "L",
"suggestion": "补充逾期未验收视为通过或约定明确责任"
},
]
)
print(res)
\ No newline at end of file
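For context, a hedged illustration (not from the repo) of the `operations` payload that `_apply_operations` consumes. The ids follow the `f_1`/`f_2` scheme that `_build_findings_with_ids` assigns; the finding bodies are placeholders.

```python
# Hypothetical reflection output. "keep" leaves f_1 untouched, "update" replaces
# f_2 wholesale, and "add" appends a new finding derived from whole-document context.
example_operations = {
    "operations": [
        {"op": "keep", "id": "f_1", "findings": None},
        {"op": "update", "id": "f_2", "findings": {
            "rule_title": "验收条款审查", "segment_id": 3,
            "original_text": "乙方交付后甲方7个工作日内完成验收。",
            "issue": "未明确逾期未验收的后果", "risk_level": "M",
            "suggestion": "约定逾期未验收视为通过",
        }},
        {"op": "add", "id": None, "findings": {
            "rule_title": "付款与验收联动", "segment_id": 3,
            "original_text": "", "issue": "付款节点未与验收结果挂钩",
            "risk_level": "L", "suggestion": "将付款条件与验收通过绑定",
        }},
    ]
}
```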
...@@ -9,8 +9,8 @@ from utils.openai_util import OpenAITool
from core.config import LLM, MAX_WORKERS


class LLMTool(ToolBase):
    """LLM-backed processor: builds prompts, calls LLM, parses JSON."""

    def __init__(self, system_prompt: str, llm_key: str = "fastgpt_segment_review") -> None:
        super().__init__()
......
...@@ -7,10 +7,10 @@ from typing import Dict, List, Optional
from core.tool import tool, tool_func
from utils.excel_util import ExcelUtil
from core.tools.segment_llm import LLMTool
import re


REVIEW_SYSTEM_PROMPT = '''
你是一个专业的合同分段审查智能体(SegmentReview)。
你的核心任务是对“当前分段”进行【法律风险识别】,并给出可落地的修改建议。
...@@ -100,10 +100,10 @@ def _is_generic_suggestion(text: str) -> bool:
    return any(g in t for g in GENERIC_SUGGESTIONS)


@tool("segment_review", "合同分段审查")
class SegmentReviewTool(LLMTool):
    def __init__(self):
        super().__init__(REVIEW_SYSTEM_PROMPT)
        self.rule_version = "通用"
        self.column_map = {
            "id": "ID",
            "title": "审查项",
...@@ -278,7 +278,7 @@ if __name__=="__main__":
    result = tool.run(
        segment_id=1,
        segment_text="本合同自双方签字盖章之日起生效,有效期为两年,期满后自动续展一年,除非一方提前30天书面通知对方终止。",
        ruleset_id="通用",
        party_role="甲方",
        context_memories=[
            {
......
...@@ -6,7 +6,7 @@ import json
from typing import Dict, List, Optional

from core.tool import tool, tool_func
from core.tools.segment_llm import LLMTool

FACT_DIMENSIONS: List[str] = ["当事人", "标的", "金额", "支付", "交付", "质量", "知识产权", "保密", "违约责任", "争议解决"]
...@@ -48,7 +48,7 @@ OUTPUT_EXAMPLE = '''

@tool("segment_summary", "分段事实提取")
class SegmentSummaryTool(LLMTool):
    def __init__(self) -> None:
        super().__init__(SUMMARY_SYSTEM_PROMPT)
......
No preview for this file type
...@@ -6,13 +6,15 @@ from uuid import uuid4
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
import uvicorn
import traceback

from utils.common_util import extract_url_file, format_now
from utils.http_util import download_file
from core.cache import get_cached_doc_tool, get_cached_memory
from core.config import doc_support_formats, pdf_support_formats
from core.tools.segment_summary import SegmentSummaryTool
from core.tools.segment_review import SegmentReviewTool
from core.tools.reflect_retry import ReflectRetryTool
from core.memory import RiskFinding

app = FastAPI(title="合同审查智能体", version="0.1.0")
...@@ -20,25 +22,31 @@ TMP_DIR = Path(__file__).resolve().parent / "tmp"
TMP_DIR.mkdir(parents=True, exist_ok=True)

summary_tool = SegmentSummaryTool()
review_tool = SegmentReviewTool()
reflect_tool = ReflectRetryTool()

########################################################################################################################


class DocumentParseRequest(BaseModel):
    conversation_id: str
    urls: List[str] = Field(..., description="File download url")
    file_ext: Optional[str] = None
    ruleset_id: Optional[str] = "通用"


class DocumentParseResponse(BaseModel):
    conversation_id: str
    chunk_ids: List[int]
    ruleset_items: List[str]
    text: Optional[str] = None
    file_ext: Optional[str] = None


@app.post("/documents/parse", response_model=DocumentParseResponse)
async def parse_document(payload: DocumentParseRequest) -> DocumentParseResponse:
    if not payload.urls:
        raise HTTPException(status_code=400, detail="No URLs provided")
    try:
        support_formats = list(dict.fromkeys(doc_support_formats + pdf_support_formats))
        filename = extract_url_file(payload.urls[0], support_formats)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Failed to parse url: {exc}")
...@@ -47,13 +55,27 @@ def parse_document(payload: DocumentParseRequest) -> DocumentParseResponse:
        download_file(payload.urls[0], file_path)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Download failed: {exc}")
    # get doc tool
    file_ext = payload.file_ext or Path(filename).suffix
    try:
        doc_obj, _ = get_cached_doc_tool(payload.conversation_id, file_ext)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Document tool not available: {exc}")
    doc_obj.load(file_path)
    # ocr
    await doc_obj.get_from_ocr()
    # text = doc_obj.get_all_text()
    chunk_ids = doc_obj.get_chunk_id_list()
    # get ruleset items
    ruleset_id = payload.ruleset_id or review_tool.rule_version
    ruleset_items = review_tool.rulesets.get(ruleset_id) or []
    ruleset_review_items = [r.get('title') for r in ruleset_items]
    return DocumentParseResponse(
        conversation_id=payload.conversation_id,
        # text=text,
        chunk_ids=chunk_ids,
        ruleset_items=ruleset_review_items,
        file_ext=file_ext,
    )
########################################################################################################################
...@@ -62,6 +84,7 @@ class SegmentSummaryRequest(BaseModel):
    conversation_id: str
    segment_id: int
    party_role: Optional[str] = ""
    file_ext: str
    context_facts: Optional[Dict] = None
...@@ -75,7 +98,7 @@ class SegmentSummaryResponse(BaseModel):
def summarize_facts(payload: SegmentSummaryRequest) -> SegmentSummaryResponse:
    store = get_cached_memory(payload.conversation_id)
    try:
        doc_obj, _ = get_cached_doc_tool(payload.conversation_id, payload.file_ext)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Document tool not available: {exc}")
...@@ -105,7 +128,8 @@ class SegmentReviewRequest(BaseModel):
    conversation_id: str
    segment_id: int
    party_role: Optional[str] = ""
    ruleset_id: Optional[str] = "通用"
    file_ext: str
    context_memories: Optional[List[Dict]] = None
...@@ -119,7 +143,7 @@ class SegmentReviewResponse(BaseModel):
def review_segment(payload: SegmentReviewRequest) -> SegmentReviewResponse:
    store = get_cached_memory(payload.conversation_id)
    try:
        doc_obj, _ = get_cached_doc_tool(payload.conversation_id, payload.file_ext)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Document tool not available: {exc}")
...@@ -132,7 +156,7 @@ def review_segment(payload: SegmentReviewRequest) -> SegmentReviewResponse:
    result = review_tool.run(
        segment_id=payload.segment_id,
        segment_text=segment_text,
        ruleset_id=payload.ruleset_id or "通用",
        party_role=payload.party_role or "",
        context_memories=payload.context_memories or store.get_facts(),
    )
...@@ -159,6 +183,57 @@ def review_segment(payload: SegmentReviewRequest) -> SegmentReviewResponse:
    )

########################################################################################################################
class ReflectReviewRequest(BaseModel):
conversation_id: str
party_role: str
ruleset_id: Optional[str] = "通用"
rule_title: str
class ReflectReviewResponse(BaseModel):
conversation_id: str
rule_title: str
findings: List[Dict]
@app.post("/segments/review/reflect", response_model=ReflectReviewResponse)
def reflect_review(payload: ReflectReviewRequest) -> ReflectReviewResponse:
store = get_cached_memory(payload.conversation_id)
ruleset_id = payload.ruleset_id or review_tool.rule_version
ruleset_items = review_tool.rulesets.get(ruleset_id) or []
rule = next((r for r in ruleset_items if r.get("title") == payload.rule_title), None)
if not rule:
raise HTTPException(status_code=404, detail=f"Rule not found: {payload.rule_title}")
facts = store.get_facts()
findings = [f.__dict__ for f in store.list_findings()]
final_findings = reflect_tool.run(
party_role=payload.party_role,
rule=rule,
facts=facts,
findings=findings,
)
for f in final_findings or []:
try:
store.add_final_finding_from_dict({
"rule_title": f.get("rule_title", ""),
"segment_id": f.get("segment_id", 0),
"original_text": f.get("original_text", ""),
"issue": f.get("issue", ""),
"risk_level": (f.get("risk_level") or f.get("level") or "").upper(),
"suggestion": f.get("suggestion", ""),
})
except Exception:
continue
return ReflectReviewResponse(
conversation_id=payload.conversation_id,
rule_title=payload.rule_title,
findings=final_findings or [],
)
########################################################################################################################
class ConversationResponse(BaseModel):
    conversation_id: str
    created_at: str
...@@ -173,44 +248,41 @@ def new_conversation() -> ConversationResponse:
class MemoryExportRequest(BaseModel):
    conversation_id: str
    file_ext: str
    file_name: Optional[str] = None


class MemoryExportResponse(BaseModel):
    conversation_id: str
    excel_url: str
    doc_url: str


@app.post("/memory/export", response_model=MemoryExportResponse)
def export_memory(payload: MemoryExportRequest) -> MemoryExportResponse:
    store = get_cached_memory(payload.conversation_id)
    try:
        doc_obj, _ = get_cached_doc_tool(payload.conversation_id, payload.file_ext)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Document tool not available: {exc}")
    try:
        excel_res = store.export_to_excel(file_name=payload.file_name)
    except ImportError as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Export failed: {exc}")
    try:
        doc_res = store.export_findings_to_doc_comments(doc_obj)
    except Exception as exc:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Export doc comments failed: {exc}")

    return MemoryExportResponse(
        conversation_id=payload.conversation_id,
        excel_url=excel_res,
        doc_url=doc_res,
    )


if __name__ == "__main__":
    uvicorn.run(
......
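For reference, a hedged client-side sketch of the reworked endpoints (field names come from the request models above; the base URL, conversation id, file URL, and rule title are placeholders):

```python
# Illustrative client calls against the updated API (all values are placeholders).
import requests

BASE = "http://localhost:8000"  # assumed local dev address, not from the repo

# Parsing now accepts PDFs, runs OCR when needed, and returns the rule titles
# for the chosen ruleset instead of the full document text.
parse = requests.post(f"{BASE}/documents/parse", json={
    "conversation_id": "conv-001",
    "urls": ["http://example.com/contract.pdf"],
    "file_ext": ".pdf",
    "ruleset_id": "通用",
}).json()

# Reflection runs over the stored findings for one rule title and persists the
# merged result as final_findings.
reflect = requests.post(f"{BASE}/segments/review/reflect", json={
    "conversation_id": "conv-001",
    "party_role": "甲方",
    "ruleset_id": "通用",
    "rule_title": "付款与验收",
}).json()

# Export now needs file_ext so the cached doc tool can be recovered, and returns
# both an Excel URL and an annotated-document URL.
export = requests.post(f"{BASE}/memory/export", json={
    "conversation_id": "conv-001",
    "file_ext": ".pdf",
}).json()
print(export["excel_url"], export["doc_url"])
```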
import json_repair
import random
import re
from datetime import datetime
from typing import Dict, List
from loguru import logger

from core.config import max_chunk_page, min_single_chunk_size, max_single_chunk_size
...@@ -51,14 +50,14 @@ def extract_json(json_str:str) -> List[Dict]:
        # Strip control characters
        s = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', s)
        try:
            obj = json_repair.loads(s, strict=False)
            if isinstance(obj, list):
                out_list.extend(obj)
            else:
                out_list.append(obj)
            return True
        except Exception as e:
            logger.error(f"JSON解析失败: {e} | 内容片段: {s}")
            return False

    results = []
......
...@@ -53,7 +53,7 @@ def fastgpt_openai_chat(url, token, model, chat_id, file_url, text, stream=True)
    return rsp


def upload_file(path, input_url_to_inner=True, output_url_to_inner=False) -> str:
    # Log in to obtain a token
    login_data = {"username": "admin", "password": "admin@jpai.com"}
    login_url = f"{base_backend_url}/admin-api/system/auth/login"
...@@ -109,5 +109,5 @@ def url_replace_fastgpt(origin: str):
if __name__ == "__main__":
    # d = '/home/ccran/file.docx'
    d = "/home/ccran/lufa-contract/tmp/default.json"
    print(upload_file(d))
import fitz
from urllib import parse
from urllib.parse import urlparse
import os
import asyncio
import aiohttp
from aiohttp import ClientSession
from utils.http_util import url_replace_fastgpt, download_file
from utils.common_util import random_str
from loguru import logger
import json
ocr_url = 'http://192.168.252.71:8202/openapi/ocrUploadFile'
class OCRUtil:
def __init__(self):
pass
    # Async OCR request
async def ocr_requests_async(self, session, file_path):
        # File uploads typically use multipart/form-data encoding
# file_path = 'file/1.webp'
async with session.post(ocr_url, data={'file': open(file_path, 'rb')}) as response:
# logger.info(f'开始请求:{file_path}')
rsp = await response.text()
# logger.info(f'{file_path}结束请求:{json.loads(rsp)["msg"]}')
return rsp, file_path
    # Fire async requests at the OCR image endpoint
async def ocr_image_async(self, path_list):
timeout = aiohttp.ClientTimeout(total=600)
connector = aiohttp.TCPConnector(limit=1)
async with ClientSession(connector=connector, timeout=timeout) as session:
tasks = [self.ocr_requests_async(session, file_path) for file_path in path_list]
responses = await asyncio.gather(*tasks)
res_dict = {}
for response_text, file_path in responses:
rsp_json = json.loads(response_text)
if 'data' not in rsp_json:
logger.error(f'ocr_image_async {file_path} error:{rsp_json["msg"]}')
continue
else:
content = rsp_json['data']['strRes']
# logger.info(f'ocr_image_async append {file_path} success. content:{content[:100]}')
# add to dict
page_num = int(self.get_pdf_2_img_page_num(file_path))
res_dict[page_num] = rsp_json['data']['strRes']
        # Sort the results by page number
logger.info(f'ocr_image_async finish. all pages:{len(res_dict)}')
sorted_values = [res_dict[key] for key in sorted(res_dict)]
return sorted_values
def set_pdf_2_img_page(self, path, page_idx):
return f'{path}_{page_idx + 1}.png'
def get_pdf_2_img_page_num(self, path):
split_path = path.split('_')
# for example : 'xx_10.png' ==> get 10
return split_path[-1][:-4]
def pdf_2_img(self, path, zoom_x=1, zoom_y=1):
        # Open the PDF file
pdf = fitz.open(path)
pdf_list = []
        # Read the PDF page by page
for pg in range(0, pdf.page_count):
page = pdf[pg]
            # Set the scale and rotation factors
trans = fitz.Matrix(zoom_x, zoom_y)
pm = page.get_pixmap(matrix=trans, alpha=False)
# img save
dest_png = self.set_pdf_2_img_page(path, pg)
pm.save(dest_png)
pdf_list.append(dest_png)
pdf.close()
return pdf_list
def ocr_download_path(self, url):
logger.info(f'ocr url:{url}')
        # Rewrite the url
url = url_replace_fastgpt(url)
        # Parse the url
url_parsed = urlparse(url)
query_dict = parse.parse_qs(url_parsed.query)
if 'filename' in query_dict:
filename = query_dict.get('filename')[0]
else:
filename = f'{random_str()}.pdf'
        # Fetch the file content
dest_path = f'ocr/{filename}'
download_file(url, dest_path)
return dest_path
async def ocr_result_pdf(self, dest_path):
        # Parse the returned result
pdf_list = self.pdf_2_img(dest_path)
# logger.info(f'pdf 生成完毕:{pdf_list}')
result = await self.ocr_image_async(pdf_list)
        # Delete the temporary image files
for pdf in pdf_list:
if os.path.exists(pdf):
os.remove(pdf)
return result
if __name__ == '__main__':
ocr_util = OCRUtil()
result = asyncio.run(ocr_util.ocr_result_pdf('/home/ccran/lufa-contract/tmp/财03 2023年国科京东方评估合同.pdf'))
print(f'len(result):{len(result)}')
print(result)
import asyncio
from core.config import root_path
from utils.doc_util import DocBase
from spire.pdf import PdfDocument, PdfTextExtractOptions, PdfTextExtractor, License, PdfTextFinder, TextFindParameter, \
PdfTextMarkupAnnotation, PdfRGBColor, Color, PdfPolyLineAnnotation, PointF, RectangleF, PdfPopupAnnotation, \
PdfPopupIcon
from loguru import logger
from rapidfuzz import process
from utils.ocr_util import OCRUtil
import copy
from utils.common_util import adjust_single_chunk_size
import re
import os
ocr_thresholds = 1000
def is_messy_text(text: str,
min_chars=40,
chinese_ratio_thresh=0.20,
printable_ratio_thresh=0.70,
symbol_ratio_thresh=0.30,
longest_non_word_run_thresh=10,
english_word_density_thresh=0.03):
"""
判断单页或一段 text 是否 '乱码'。如果返回 True 表示该 text 被判为乱码(需要 OCR)。
参数可调整:
- min_chars: 少于这个字符数直接认为质量差(比如抽出来是碎行)
- chinese_ratio_thresh: 中文字符占比阈值(若低于且其它判据也成立则可能乱码)
- printable_ratio_thresh: 可打印字符占比(太低说明很多不可见/非标准字符)
- symbol_ratio_thresh: 标点/特殊符号占比阈值(高于则更像乱码)
- longest_non_word_run_thresh: 最长连续特殊符号串长度阈值(如 "~~##@@@... ")
- english_word_density_thresh: 英文单词密度(单词数/行数或单词数/字符数近似);若很低且无中文,则文本不可读
"""
if not text:
return True
text_len = len(text)
if text_len < min_chars:
# 太短的片段通常是拆分后的乱码或空白
return True
# 统计中文、字母/数字、可打印、空白
chinese_count = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
alnum_count = sum(1 for c in text if c.isalnum())
printable_count = sum(1 for c in text if c.isprintable())
space_count = sum(1 for c in text if c.isspace())
# 标点/特殊符号(既不是中文也不是字母数字也不是空白,视为符号)
symbol_count = sum(1 for c in text if not (('\u4e00' <= c <= '\u9fff') or c.isalnum() or c.isspace()))
chinese_ratio = chinese_count / text_len
printable_ratio = printable_count / text_len
symbol_ratio = symbol_count / text_len
# 连续的非字/非中文/非数字串(比如 "~~~##@@@")
non_word_runs = re.findall(r'[^0-9A-Za-z\u4e00-\u9fff\s]+', text)
longest_non_word_run = max((len(s) for s in non_word_runs), default=0)
# 英文单词密度(粗略):单词数 / 总字符数,单词定义为至少两个字母
english_words = re.findall(r'\b[a-zA-Z]{2,}\b', text)
english_word_density = len(english_words) / max(1, text_len)
# 启发式判断:多条规则投票(其中任意一条强指标触发即可)
# 强指标:非常高的符号占比、非常长的连续符号、几乎不可打印
if printable_ratio < printable_ratio_thresh * 0.6:
return True
if symbol_ratio > max(0.5, symbol_ratio_thresh):
return True
if longest_non_word_run >= longest_non_word_run_thresh * 1.5:
return True
# 组合判断:如果中文占比很低且英文单词密度也很低,并且符号占比高 -> 乱码
if chinese_ratio < chinese_ratio_thresh and english_word_density < english_word_density_thresh and symbol_ratio > symbol_ratio_thresh:
return True
# 如果中文占比极低且打印字符不多,也判为乱码
if chinese_ratio < (chinese_ratio_thresh * 0.5) and printable_ratio < printable_ratio_thresh:
return True
# 其它情况视为可读
return False
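# Rough sanity check (illustrative strings, not from the repo): ordinary contract
# prose should be judged readable, while symbol-dominated output should not.
#   is_messy_text("本合同自双方签字盖章之日起生效,有效期为两年。" * 5)  -> False expected
#   is_messy_text("~~##@@@%%^^&&**" * 20)                               -> True expected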
class SpirePdfDoc(DocBase):
def get_chunk_location(self, chunk_idx):
chunk_item = self._chunk_list[chunk_idx]
return '、'.join(str(page + 1) for page in chunk_item['page_idx'])
    # Adjust the chunk size
def adjust_chunk_size(self):
all_text_len = len(self.get_all_text())
self._max_single_chunk_size = adjust_single_chunk_size(all_text_len)
logger.info(f'SpirePdfDoc adjust _max_single_chunk_size to {self._max_single_chunk_size}')
self.reset()
return self._max_single_chunk_size
async def get_from_ocr(self):
"""
新策略:
- 先用 Spire 的提取结果(self._chunk_list 已由 reset/init_chunks 填充)
- 对每一页单独用 is_messy_text 判断
- 如果被判为乱码的页占比超过 page_messy_ratio_thresh(比如 0.4),则触发 OCR(按页OCR)
- 若被判为乱码的页数量少(但某些页极差),也可触发 OCR(按需要)
"""
# 先从当前 chunk_list 拼出每页文本(init_chunks 保证 chunk_list 基于 PdfTextExtractor)
# 注意 chunk_list 中 page 是每个 chunk 的 page 文本列表,我们要逐页统计
all_pages = []
for chunk in self._chunk_list:
# chunk['page'] 是字符串列表,对应 chunk['page_idx'] 顺序一致
all_pages.extend(chunk.get('page', []))
# 如果没有任何页面文本(极端情况),强制 OCR
if len(all_pages) == 0:
logger.info('Spire extract returned 0 pages text, forcing OCR.')
do_ocr = True
else:
# 针对每页判断是否为乱码
page_messy_flags = []
for i, page_text in enumerate(all_pages):
messy = is_messy_text(page_text)
page_messy_flags.append(messy)
logger.debug(f'page {i + 1}: len={len(page_text)} messy={messy}')
messy_count = sum(1 for f in page_messy_flags if f)
total_pages = len(all_pages)
messy_ratio = messy_count / total_pages
logger.info(f'spire extract pages: {total_pages}, messy pages: {messy_count}, ratio: {messy_ratio:.2f}')
# 触发 OCR 的条件(可调):
# 1) 若多数页为乱码 -> OCR
# 2) 若至少一页非常差(例如 is_messy_text 在内部检测到极端情况) -> OCR
page_messy_ratio_thresh = 0.40 # 如果超过 40% 页被判为乱码则 OCR
# 若某页文本长度非常短(如 <20)且很多页属于这种短页 -> OCR
many_short_pages = sum(1 for p in all_pages if len(p) < 20) / total_pages > 0.3
do_ocr = (messy_ratio > page_messy_ratio_thresh) or many_short_pages
# 作为后备:如果全文字符数低于 ocr_thresholds(原逻辑),同样 OCR
all_text = '\n'.join(all_pages)
if len(all_text) < ocr_thresholds:
do_ocr = True
if do_ocr:
page_result_list = await self.ocr_util.ocr_result_pdf(self._doc_path)
logger.info(f'ocr success: len[{len(page_result_list)}]')
# 用 OCR 结果重新构建 chunk_list(与原逻辑一致)
chunk_list = []
chunk_item = {
'page_objs': [],
'page': [],
'page_idx': []
}
cur_text_len = 0
for i in range(0, len(page_result_list)):
text = page_result_list[i]
if cur_text_len + len(text) > self._max_single_chunk_size and len(chunk_item['page']):
chunk_list.append(copy.deepcopy(chunk_item))
chunk_item = {
'page_objs': [],
'page': [],
'page_idx': []
}
cur_text_len = 0
chunk_item['page'].append(text)
chunk_item['page_idx'].append(i)
cur_text_len += len(text)
if len(chunk_item['page']):
chunk_list.append(copy.deepcopy(chunk_item))
self._chunk_list = chunk_list
else:
logger.info('Spire extract judged readable — skip OCR.')
return ''
def __init__(self, **kwargs):
super(SpirePdfDoc, self).__init__(**kwargs)
self.ocr_util = OCRUtil()
def load(self,doc_path):
self._doc_path = doc_path
self._doc_name = os.path.basename(doc_path)
self._doc = PdfDocument()
self._doc.LoadFromFile(self._doc_path)
self._chunk_list = self.init_chunks()
def init_chunks(self):
chunk_list = []
extract_options = PdfTextExtractOptions()
extract_options.IsExtractAllText = True
chunk_item = {
'page_objs': [],
'page': [],
'page_idx': []
}
cur_text_len = 0
for i in range(0, self._doc.Pages.Count):
page = self._doc.Pages[i]
text = PdfTextExtractor(page).ExtractText(extract_options)
if cur_text_len + len(text) > self._max_single_chunk_size and len(chunk_item['page_objs']):
chunk_list.append(chunk_item)
chunk_item = {
'page_objs': [],
'page': [],
'page_idx': []
}
cur_text_len = 0
chunk_item['page_objs'].append(page)
chunk_item['page'].append(text)
chunk_item['page_idx'].append(i)
cur_text_len += len(text)
if len(chunk_item['page_objs']):
chunk_list.append(chunk_item)
return chunk_list
def get_chunk_item(self, chunk_id):
chunk = self._chunk_list[chunk_id]
return '\n'.join([page for page in chunk['page']])
def get_chunk_info(self, chunk_id):
chunk = self._chunk_list[chunk_id]
from_location = f'[页面{chunk["page_idx"][0] + 1}]'
to_location = f'[页面{chunk["page_idx"][-1] + 1}]'
chunk_content_tips = '[' + chunk['page'][0][:20] + ']...到...[' + chunk['page'][-1][-20:] + ']'
return f'文件块id: {chunk_id + 1}\n文件块位置: 从{from_location}到{to_location}\n文件块简述: {chunk_content_tips}\n'
def format_comment_author(self, comment):
return '{}|{}'.format(str(comment['id']), comment['key_points'])
def add_text_markup_anno(self, page_obj, text_fragment, author, review_suggest):
        # Iterate over the text bounds
        for j in range(len(text_fragment.Bounds)):
            # Get the bounding rectangle
            rect = text_fragment.Bounds[j]
            # Create a text markup annotation
            annotation = PdfTextMarkupAnnotation(author, review_suggest, rect)
            annotation.Name = author
            annotation.TextMarkupColor = PdfRGBColor(Color.get_Green())
            # Add the annotation to the page's annotation collection
            page_obj.AnnotationsWidget.Add(annotation)

    # TODO: this kind of annotation does not seem to be deletable
def add_poly_line_anno(self, page_obj, text_fragment, author, review_suggest):
        # Get the text bounds
        rect = text_fragment.Bounds[0]
        # Get the coordinates of the text bounds for placing the annotation
        left = rect.Left
        top = rect.Top
        right = rect.Right
        bottom = rect.Bottom
        polyLineAnnotation = PdfPolyLineAnnotation(page_obj,
                                                   [PointF(left, top), PointF(right, top), PointF(right, bottom),
                                                    PointF(left, bottom), PointF(left, top)])
        polyLineAnnotation.Name = author
        # Custom annotation content
        polyLineAnnotation.Text = review_suggest
        # Add the annotation to the document's annotation collection
        page_obj.AnnotationsWidget.Add(polyLineAnnotation)

    # TODO: this annotation also does not seem to be deletable
def add_popup_anno(self, page_obj, text_fragment, author, review_suggest):
        # Get the text bounds
        rect = text_fragment.Bounds[0]
        # Get the coordinates of the text bounds for placing the annotation
        right = rect.Right
        top = rect.Top
        # Create the popup annotation
        rectangle = RectangleF(right, top - 15., 15., 15.)
        popupAnnotation = PdfPopupAnnotation(rectangle)
        # Custom annotation content
        popupAnnotation.Name = author
        popupAnnotation.Text = review_suggest
        # Set the annotation's icon and color
        popupAnnotation.Icon = PdfPopupIcon.Comment
        popupAnnotation.Color = PdfRGBColor(Color.get_Red())
        # Add the annotation to the document's annotation collection
page_obj.AnnotationsWidget.Add(popupAnnotation)
def add_anno(self, chunk, find_key, author, review_suggest):
accurate_find = False
for i in range(0, len(chunk['page_objs'])):
page_obj = chunk['page_objs'][i]
finder = PdfTextFinder(page_obj)
finder.Options.Parameter = TextFindParameter.WholeWord
fragments = finder.Find(find_key)
            # No exact match on this page
if fragments is None or len(fragments) == 0:
pass
else:
textFragment = fragments[0]
# logger.info(f'[add_anno markup] {find_key} {author} {review_suggest}')
# textFragment.HighLight(Color.get_Yellow())
# self.add_text_markup_anno(page_obj, textFragment, author, review_suggest)
# self.add_poly_line_anno(page_obj, textFragment, author, review_suggest)
self.add_popup_anno(page_obj, textFragment, author, review_suggest)
accurate_find = True
return accurate_find
def add_chunk_comment(self, chunk_id, comments):
# logger.info(f'add_chunk_comment: {chunk_id} {comments}')
chunk = self._chunk_list[chunk_id]
for comment in comments:
if comment['result'] != '不合格':
continue
review_suggest = comment.get('suggest', '')
author = self.format_comment_author(comment)
find_key = comment['original_text'].strip() if comment['original_text'].strip() else comment['key_points']
            # Try an exact match first
accurate_find = self.add_anno(chunk, find_key, author, review_suggest)
            # Fall back to fuzzy matching
if not accurate_find:
sub_chunk_texts = []
for page in chunk['page']:
sub_chunk_texts.extend([item.strip() for item in page.split('\n') if len(item.strip())])
top_match = process.extract(find_key, sub_chunk_texts, limit=1)
sub_chunk_text, score, sub_chunk_idx = top_match[0]
# logger.info(f'add_chunk_comment find_key:{find_key} fuzz match: {sub_chunk_text}')
                # Annotate using the fuzzy-matched text
self.add_anno(chunk, sub_chunk_text, author, review_suggest)
def print_all_anno(self):
for j in range(0, self._doc.Pages.Count):
page_obj = self._doc.Pages.get_Item(j)
for i in range(page_obj.AnnotationsWidget.Count):
anno = page_obj.AnnotationsWidget.get_Item(i)
name = ''
try:
name = anno.Name
except:
name = ''
logger.info(f'anno批注:{j + 1}页 第{i + 1}个 [{name}]:[{anno.Text}]')
# page_obj.AnnotationsWidget.Clear()
def find_anno_idx(self, page_obj, author):
for i in range(page_obj.AnnotationsWidget.Count):
anno = page_obj.AnnotationsWidget.get_Item(i)
anno_name = ''
try:
anno_name = anno.Name
except Exception as e:
pass
# logger.info(f'{type(anno)} {anno_name}')
if anno_name == author:
# logger.info(f'find_anno_idx:{author} at {i}')
return i
return -1
def edit_chunk_comment(self, comments):
for comment in comments:
review_answer = comment['result']
if review_answer == '不合格':
idx, page, delete_comment = self.delete_single_chunk_comment_once(comment)
                # No existing annotation to replace; add a new one
if idx == -1:
self.add_chunk_comment(comment['chunk_id'], [comment])
                # An existing annotation was removed; edit it and re-insert
else:
review_suggest = comment.get('suggest', '')
all_pages = []
all_delete_comment = []
all_idx = []
while idx != -1:
delete_comment.Text = review_suggest
all_pages.append(page)
all_delete_comment.append(delete_comment)
all_idx.append(idx)
idx, page, delete_comment = self.delete_single_chunk_comment_once(comment)
for page, delete_comment, idx in zip(all_pages, all_delete_comment, all_idx):
page.AnnotationsWidget.Insert(idx, delete_comment)
            # If the item passed review, delete its annotation
else:
self.delete_single_chunk_comment(comment)
def delete_chunk_comment(self, comments):
for comment in comments:
self.delete_single_chunk_comment(comment)
def delete_single_chunk_comment(self, comment):
while True:
idx, _, _ = self.delete_single_chunk_comment_once(comment)
if idx == -1:
break
    # Delete a single annotation occurrence
def delete_single_chunk_comment_once(self, comment):
author = self.format_comment_author(comment)
chunk_id = comment['chunk_id']
chunk = self._chunk_list[chunk_id]
for page_obj in chunk['page_objs']:
idx = self.find_anno_idx(page_obj, author)
if idx != -1:
# logger.info(f'delete_single_chunk_comment_once:{author} {comment}')
delete_comment = page_obj.AnnotationsWidget.get_Item(idx)
page_obj.AnnotationsWidget.RemoveAt(idx)
# logger.info(page_obj.AnnotationsWidget.Contains(delete_comment))
return idx, page_obj, delete_comment
return -1, None, None
def get_chunk_id_list(self, step=1):
return [idx + 1 for idx in range(0, self.get_chunk_num(), step)]
def get_chunk_num(self):
return len(self._chunk_list)
def get_all_text(self):
return '\n'.join([page for chunk in self._chunk_list for page in chunk['page']])
def to_file(self, path, **kwargs):
self._doc.SaveToFile(path)
def re_comment_by_memory(self, comments):
self._doc.Close()
self.reset()
for comment in comments:
self.add_chunk_comment(comment['chunk_id'], [comment])
def release(self):
self._doc.Close()
super().release()
if __name__ == '__main__':
tool = SpirePdfDoc()
tool.load('/home/ccran/lufa-contract/tmp/财03 2023年国科京东方评估合同.pdf')
print('', tool.get_all_text())
print(asyncio.run(tool.get_from_ocr()))
print(tool.get_chunk_id_list())
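As a rough end-to-end sketch (not from the repo; the file path and comment values are placeholders), this is how SpirePdfDoc is meant to be driven: load a PDF, OCR it only when extraction is messy, attach popup annotations for failing review items, and save the annotated copy.

```python
# Hypothetical usage of SpirePdfDoc; paths and comment content are placeholders.
import asyncio
from utils.spire_pdf_util import SpirePdfDoc

doc = SpirePdfDoc()
doc.load("tmp/contract.pdf")          # placeholder path
asyncio.run(doc.get_from_ocr())       # falls back to OCR only for messy pages
doc.add_chunk_comment(0, [{
    "id": "1",
    "key_points": "付款条款审查",
    "original_text": "甲方在验收合格后30日内支付剩余价款。",
    "details": "",
    "chunk_id": 0,
    "result": "不合格",                # only failing items are annotated
    "suggest": "建议补充逾期付款违约金条款",
}])
doc.to_file("tmp/contract_annotated.pdf")
```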