Commit 9fa3fd96 by ccran

feat: add PDF processing

parent 53a63f97
tmp/
\ No newline at end of file
tmp/
*.pyc
**__pycache__**
\ No newline at end of file
from utils.spire_word_util import SpireWordDoc
from utils.spire_pdf_util import SpirePdfDoc
from utils.doc_util import DocBase
from functools import lru_cache
from typing import Optional
from typing import Optional, Tuple
from core.memory import MemoryStore
from core.config import pdf_support_formats
MAX_CACHE = 128
def _normalize_file_ext(file_ext: str) -> str:
if not file_ext:
raise ValueError("file_ext is required")
ext = file_ext.strip().lower()
if not ext.startswith("."):
ext = f".{ext}"
return ext
@lru_cache(maxsize=MAX_CACHE)
def get_cached_doc_tool(conversation_id: str) -> Optional[DocBase]:
return SpireWordDoc()
def get_cached_doc_tool(conversation_id: str, file_ext: str) -> Tuple[DocBase, str]:
ext = _normalize_file_ext(file_ext)
if ext in pdf_support_formats:
return SpirePdfDoc(), ext
return SpireWordDoc(), ext
@lru_cache(maxsize=MAX_CACHE)
def get_cached_memory(conversation_id: str) -> MemoryStore:
......
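A quick sketch of how the extension-based dispatch above is expected to behave (illustrative only; it assumes ".pdf" is listed in core.config.pdf_support_formats and that both Spire tools can be constructed locally):

```python
from core.cache import _normalize_file_ext, get_cached_doc_tool

print(_normalize_file_ext("PDF"))                    # ".pdf": lower-cased, leading dot added
tool, ext = get_cached_doc_tool("conv-1", "pdf")     # -> (SpirePdfDoc instance, ".pdf")
tool, ext = get_cached_doc_tool("conv-1", ".docx")   # -> (SpireWordDoc instance, ".docx")
```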
......@@ -29,7 +29,6 @@ class LLMConfig:
outer_backend_url = "http://218.77.58.8:48081"
base_fastgpt_url = "http://192.168.252.71:18089"
base_backend_url = "http://192.168.252.71:48081"
ocr_url = 'http://192.168.252.71:8202/openapi/ocrUploadFile'
# 项目根目录
root_path = r"E:\PycharmProject\contract_review_agent"
......@@ -70,7 +69,6 @@ field_mapping = {
excel_widths = {"review": [5, 20, 80, 80, 20, 80, 5, 60], "category": [80, 80, 30]}
max_review_group = 5
# 销售类别判断
ocr_thresholds = 1000
all_rule_sheet = ["内销或出口", "内销", "出口", "反思"]
reflection_sheet = "反思"
# 最大分片数量
......
......@@ -58,6 +58,7 @@ from threading import RLock
from typing import Any, Dict, Iterable, List, Optional
from utils.http_util import upload_file
from utils.doc_util import DocBase
logger = logging.getLogger(__name__)
......@@ -107,6 +108,7 @@ class MemoryStore:
self._lock = RLock()
self.facts: List[Dict[str, Any]] = []
self.findings: List[RiskFinding] = []
self.final_findings: List[RiskFinding] = []
self._load()
# ---------------------- facts ----------------------
......@@ -128,35 +130,76 @@ class MemoryStore:
# -------------------- findings ---------------------
def add_finding(self, finding: RiskFinding) -> RiskFinding:
with self._lock:
self.findings.append(finding)
self._persist()
return finding
return self._add_finding(self.findings, finding)
def add_finding_from_dict(self, data: Dict) -> RiskFinding:
return self.add_finding(RiskFinding.from_dict(data))
def add_final_finding(self, finding: RiskFinding) -> RiskFinding:
return self._add_finding(self.final_findings, finding)
def add_final_finding_from_dict(self, data: Dict) -> RiskFinding:
return self.add_final_finding(RiskFinding.from_dict(data))
def list_findings(self) -> List[RiskFinding]:
with self._lock:
return list(self.findings)
return self._list_findings(self.findings)
def list_final_findings(self) -> List[RiskFinding]:
return self._list_findings(self.final_findings)
def get_findings_by_segment(self, segment_id: int) -> List[RiskFinding]:
with self._lock:
return [f for f in self.findings if f.segment_id == segment_id]
return self._get_findings_by_segment(self.findings, segment_id)
def get_final_findings_by_segment(self, segment_id: int) -> List[RiskFinding]:
return self._get_findings_by_segment(self.final_findings, segment_id)
def delete_findings_by_segment(self, segment_id: int) -> int:
return self._delete_findings_by_segment("findings", segment_id)
def delete_final_findings_by_segment(self, segment_id: int) -> int:
return self._delete_findings_by_segment("final_findings", segment_id)
def search_findings(self, keyword: str, rule_title: Optional[str] = None, risk_level: Optional[str] = None) -> List[RiskFinding]:
return self._search_findings(self.findings, keyword, rule_title, risk_level)
def search_final_findings(self, keyword: str, rule_title: Optional[str] = None, risk_level: Optional[str] = None) -> List[RiskFinding]:
return self._search_findings(self.final_findings, keyword, rule_title, risk_level)
def _add_finding(self, target: List[RiskFinding], finding: RiskFinding) -> RiskFinding:
with self._lock:
target.append(finding)
self._persist()
return finding
def _list_findings(self, target: List[RiskFinding]) -> List[RiskFinding]:
with self._lock:
before = len(self.findings)
self.findings = [f for f in self.findings if f.segment_id != segment_id]
removed = before - len(self.findings)
return list(target)
def _get_findings_by_segment(self, target: List[RiskFinding], segment_id: int) -> List[RiskFinding]:
with self._lock:
return [f for f in target if f.segment_id == segment_id]
def _delete_findings_by_segment(self, attr_name: str, segment_id: int) -> int:
with self._lock:
current = getattr(self, attr_name)
before = len(current)
updated = [f for f in current if f.segment_id != segment_id]
setattr(self, attr_name, updated)
removed = before - len(updated)
if removed:
self._persist()
return removed
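To make the new findings/final_findings split concrete, a minimal usage sketch (field names follow the RiskFinding used in the tests further down; values are invented):

```python
store = MemoryStore()
draft = store.add_finding_from_dict({
    "rule_title": "付款条款审查", "segment_id": 3,
    "original_text": "验收合格后30日内支付", "issue": "未约定逾期付款违约金",
    "risk_level": "M", "suggestion": "增加逾期付款违约金条款",
})
store.add_final_finding(draft)                    # the reviewed copy is kept in a separate list
assert store.get_final_findings_by_segment(3)     # both lists support per-segment lookup
store.delete_findings_by_segment(3)               # removes from findings only; final_findings keep theirs
```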
def search_findings(self, keyword: str, rule_title: Optional[str] = None, risk_level: Optional[str] = None) -> List[RiskFinding]:
def _search_findings(
self,
target: List[RiskFinding],
keyword: str,
rule_title: Optional[str] = None,
risk_level: Optional[str] = None,
) -> List[RiskFinding]:
key = (keyword or "").strip().lower()
with self._lock:
candidates = list(self.findings)
candidates = list(target)
if rule_title:
candidates = [f for f in candidates if (f.rule_title or "").lower() == rule_title.strip().lower()]
if risk_level:
......@@ -179,12 +222,14 @@ class MemoryStore:
with self._lock:
self.facts.clear()
self.findings.clear()
self.final_findings.clear()
self._persist()
def _persist(self) -> None:
payload = {
"facts": self.facts,
"findings": [asdict(f) for f in self.findings],
"final_findings": [asdict(f) for f in self.final_findings],
}
try:
self._storage_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
......@@ -200,6 +245,7 @@ class MemoryStore:
if isinstance(data, dict):
self.facts = data.get("facts") or []
self.findings = [RiskFinding.from_dict(item) for item in data.get("findings", []) or []]
self.final_findings = [RiskFinding.from_dict(item) for item in data.get("final_findings", []) or []]
except Exception as exc:
logger.error("Failed to load memory store: %s", exc)
......@@ -233,6 +279,13 @@ class MemoryStore:
getattr(f, key, "") for key, _ in finding_headers
])
ws_final_findings = wb.create_sheet("final_findings")
ws_final_findings.append([label for _, label in finding_headers])
for f in self.final_findings:
ws_final_findings.append([
getattr(f, key, "") for key, _ in finding_headers
])
ws_facts = wb.create_sheet("facts")
if self.facts:
fact_keys: List[str] = sorted({k for item in self.facts for k in item.keys()})
......@@ -263,8 +316,97 @@ class MemoryStore:
return res
def export_findings_to_doc_comments(
self,
doc_obj: DocBase,
file_name: Optional[str] = None,
remove_prefix: bool = False,
) -> Dict[str, Any]:
"""Add all findings as comments to a document, upload, then delete the local file."""
if doc_obj is None:
raise ValueError("doc_obj is required")
if __name__ == "__main__":
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
doc_name = getattr(doc_obj, "_doc_name", "") or ""
suffix = Path(doc_name).suffix or ".docx"
name = file_name or f"findings_{ts}{suffix}"
if not Path(name).suffix:
name = f"{name}{suffix}"
output_path = Path(__file__).resolve().parent.parent / "tmp" / name
with self._lock:
comments: List[Dict[str, Any]] = []
for idx, f in enumerate(self.findings, start=1):
# 导出final_findings
# for idx, f in enumerate(self.final_findings, start=1):
segment_id = int(f.segment_id or 0)
chunk_id = max(segment_id - 1, 0)
suggest_parts = []
if f.risk_level:
suggest_parts.append(f"风险等级:{f.risk_level}")
if f.issue:
suggest_parts.append(f"问题:{f.issue}")
if f.suggestion:
suggest_parts.append(f"建议:{f.suggestion}")
suggest_text = "\n".join(suggest_parts).strip()
comments.append(
{
"id": str(idx),
"key_points": f.rule_title or "风险提示",
"original_text": f.original_text or "",
"details": f.issue or "",
"chunk_id": chunk_id,
"result": "不合格",
"suggest": suggest_text,
}
)
if comments:
doc_obj.add_chunk_comment(0, comments)
doc_obj.to_file(str(output_path), remove_prefix=remove_prefix)
try:
res = upload_file(str(output_path))
finally:
try:
output_path.unlink()
except Exception:
logger.warning("Failed to delete temp doc: %s", output_path)
return res
def test_export_findings_to_doc_comments(doc_path: str) -> None:
    """Test: write findings into the document as comments and upload the result."""
    store = MemoryStore()
finding = RiskFinding(
rule_title="违约责任",
segment_id=1,
original_text="湖南麓谷发展集团有限公司",
issue="未约定违约金上限,可能导致赔偿范围过大",
risk_level="H",
suggestion="建议增加‘赔偿金额不超过合同总额的30%’",
)
store.add_final_finding(finding)
"""测试:将 findings 作为批注写入文档并上传。"""
if not doc_path:
print("doc_path 为空,跳过批注导出测试")
return
if not Path(doc_path).exists():
print(f"文件不存在,跳过批注导出测试: {doc_path}")
return
try:
from utils.spire_word_util import SpireWordDoc
except Exception as exc:
print(f"加载 SpireWordDoc 失败,跳过批注导出测试: {exc}")
return
doc = SpireWordDoc()
doc.load(doc_path)
res = store.export_findings_to_doc_comments(doc)
print("Export doc comments:")
print(json.dumps(res, ensure_ascii=False, indent=2))
def test_memory_and_export_excel():
# 简单示例:设置事实 -> 写入问题 -> 读取/搜索
store = MemoryStore()
store.set_facts([{
......@@ -295,3 +437,8 @@ if __name__ == "__main__":
print(json.dumps(asdict(f), ensure_ascii=False, indent=2))
print(store.export_to_excel())
if __name__ == "__main__":
# test_export_findings_to_doc_comments("/home/ccran/lufa-contract/tmp/股份转让协议.docx")
test_memory_and_export_excel()
from __future__ import annotations
from typing import Dict, List, Optional
from core.tool import ToolBase, tool, tool_func
@tool("reflect_retry", "反思重试质量闸(MVP)")
class ReflectRetryTool(ToolBase):
@tool_func(
{
"type": "object",
"properties": {
"segment_id": {"type": "string"},
"ruleset_id": {"type": "string"},
"context_memories": {"type": "array"},
"review_output": {"type": "object"},
},
"required": ["segment_id", "ruleset_id", "review_output"],
}
)
def run(self, segment_id: str, ruleset_id: str, context_memories: Optional[List[Dict]] = None, review_output: Dict = None) -> Dict:
problems: List[str] = []
actions: List[Dict] = []
findings = (review_output or {}).get("findings", [])
missings = (review_output or {}).get("missing_info", [])
for f in findings:
evs = f.get("evidence_quotes", []) or []
sugg = (f.get("suggestion") or "").strip()
if not evs:
problems.append("缺少对关键结论的原文引用")
if not sugg:
problems.append("建议条款措辞仍偏泛,未给可替换文本")
needs_evidence = any(m.get("type") == "evidence" for m in missings)
needs_knowledge = any(m.get("type") == "knowledge" for m in missings)
status = "pass"
if needs_evidence:
actions.append({"action": "FullRead", "target": segment_id, "need_focus": [m.get("need") for m in missings if m.get("type") == "evidence"]})
if needs_knowledge:
actions.append({"action": "RetrieveReference", "topic": [m.get("topic") for m in missings if m.get("type") == "knowledge"]})
if problems or actions:
status = "revise"
retry_budget_left = int((review_output or {}).get("retry_budget_left", 1))
return {
"segment_id": segment_id,
"status": status,
"problems": list(dict.fromkeys(problems)),
"revise_instructions": actions,
"retry_budget_left": max(0, retry_budget_left - (1 if status == "revise" else 0)),
}
import json
from typing import Dict, List, Optional, Any
from core.tool import tool, tool_func
from core.tools.segment_llm import LLMTool
REFLECT_SYSTEM_PROMPT = '''
你是一个合同审查反思智能体(ReviewReflection)。
你的任务是:基于 facts 与全文上下文,
对已有 findings 输出修改操作。
你只能对 findings 执行以下四种操作:
- keep:确认该风险结论成立且无需修改
- update:修改一个已有风险
- add:新增一个由全文结构推导出的风险
- remove:删除一个不成立的风险
【输出约束】
- 输出必须是 JSON
- 每条操作必须仅包含字段:op、id、findings
- findings 在 add / update 时必须是完整 finding
- 严格按照输出 JSON Schema 返回结果,不得输出任何解释性文字
'''
REFLECT_USER_PROMPT = '''
【输入】
【已有风险 findings】
{findings_json}
【合同事实记忆 facts】
{facts_json}
【合同立场】
站在 {party_role} 的立场进行反思审查。
【反思原则】
- 不得引入新的审查维度
- 所有判断必须有 facts 或合同原文证据支持
- 仅对已有 findings 进行增add、改update、删remove、保留keep操作
- 若风险在全文上下文中不成立,必须 remove
- 若风险成立但表述、严重性或建议不准确,必须 update
- 若由多个 findings 或全文结构推导出新的系统性风险,可 add
【输出要求】
- 输出必须是 JSON
- 每条操作必须仅包含字段:op、id、findings
- findings 在 add / update 时必须是完整 finding
- 严格按照输出 JSON Schema 返回结果,不得输出任何解释性文字
'''
OUTPUT_FORMAT_SCHEMA = '''
```json
{
"operations": [
{
"op": "keep | add | update | remove",
"id": "string | null",
"findings": "object | null"
}
]
}
```
'''
@tool("reflect_retry", "反思重试质量闸")
class ReflectRetryTool(LLMTool):
def __init__(self) -> None:
super().__init__(REFLECT_SYSTEM_PROMPT)
@tool_func(
{
"type": "object",
"properties": {
"party_role": {"type": "string"},
"rule": {"type": "object"},
"facts": {"type": "array"},
"findings": {"type": "array"},
},
"required": ["party_role", "rule", "facts", "findings"],
}
)
def run(self, party_role: str, rule: Dict, facts: Optional[List[Dict]] = None, findings: Optional[List[Dict]] = None) -> List[Dict]:
base_findings = self._build_findings_with_ids(findings or [])
user_content = REFLECT_USER_PROMPT.format(
findings_json=json.dumps(base_findings, ensure_ascii=False),
facts_json=json.dumps(facts or [], ensure_ascii=False),
party_role=party_role,
) + OUTPUT_FORMAT_SCHEMA
messages = self.build_messages(user_content)
try:
resp = self.run_with_loop(self.chat_async(messages))
data = self.parse_first_json(resp)
except Exception:
data = {}
operations = data.get("operations", []) or []
final_findings = self._apply_operations(base_findings, operations)
return final_findings
def _build_findings_with_ids(self, findings: List[Dict]) -> List[Dict[str, Any]]:
res: List[Dict[str, Any]] = []
for idx, f in enumerate(findings):
fid = f.get("id") or f.get("finding_id") or f.get("_id") or f"f_{idx+1}"
item = dict(f)
item["id"] = str(fid)
res.append(item)
return res
def _apply_operations(self, base_findings: List[Dict[str, Any]], operations: List[Dict]) -> List[Dict[str, Any]]:
by_id = {str(f.get("id")): dict(f) for f in base_findings if f.get("id") is not None}
added: List[Dict[str, Any]] = []
for op in operations:
action = (op.get("op") or "").strip().lower()
target_id = op.get("id")
payload = op.get("findings")
if action == "keep":
if target_id is not None and str(target_id) in by_id:
continue
elif action == "remove":
if target_id is not None:
by_id.pop(str(target_id), None)
elif action == "update":
if target_id is not None and isinstance(payload, dict):
payload = dict(payload)
payload["id"] = str(target_id)
by_id[str(target_id)] = payload
elif action == "add":
if isinstance(payload, dict):
added.append(dict(payload))
merged = list(by_id.values()) + added
return merged
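A self-contained illustration of how `_apply_operations` merges the model's operations (finding contents are invented; the tool is constructed the same way as in the `__main__` block below):

```python
base = [
    {"id": "f_1", "issue": "未约定违约金上限", "risk_level": "H"},
    {"id": "f_2", "issue": "验收期限过短", "risk_level": "L"},
]
ops = [
    {"op": "update", "id": "f_1", "findings": {"issue": "未约定违约金上限", "risk_level": "M"}},
    {"op": "remove", "id": "f_2", "findings": None},
    {"op": "add", "id": None, "findings": {"issue": "付款与验收条款互相矛盾", "risk_level": "H"}},
]
merged = ReflectRetryTool()._apply_operations(base, ops)
# update keeps the id and replaces the entry, remove drops f_2, add appends the new finding:
# [{'issue': '未约定违约金上限', 'risk_level': 'M', 'id': 'f_1'},
#  {'issue': '付款与验收条款互相矛盾', 'risk_level': 'H'}]
```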
if __name__ == "__main__":
tool = ReflectRetryTool()
res = tool.run(
party_role="甲方",
rule={"title":"付款与验收", "segment_id": 3},
facts=[
{"segment_id": 3, "支付": {"比例": "30%预付款", "期限": "验收后30日内"}},
{"segment_id": 3, "验收": {"标准": "双方书面确认", "时限": "7个工作日"}},
],
findings=[
{
"rule_title": "付款条款审查",
"segment_id": 3,
"original_text": "甲方在验收合格后30日内支付剩余价款。",
"issue": "未约定逾期付款违约金,约束不足",
"risk_level": "M",
"suggestion": "增加逾期付款违约金条款,如逾期每日按未付款的0.05%计收"
},
{
"rule_title": "验收条款审查",
"segment_id": 3,
"original_text": "乙方交付后甲方7个工作日内完成验收。",
"issue": "未明确逾期未验收的后果",
"risk_level": "L",
"suggestion": "补充逾期未验收视为通过或约定明确责任"
},
]
)
print(res)
\ No newline at end of file
......@@ -9,8 +9,8 @@ from utils.openai_util import OpenAITool
from core.config import LLM, MAX_WORKERS
class SegmentLLMBase(ToolBase):
"""LLM-backed segment processor: builds prompts, calls LLM, parses JSON."""
class LLMTool(ToolBase):
"""LLM-backed processor: builds prompts, calls LLM, parses JSON."""
def __init__(self, system_prompt: str, llm_key: str = "fastgpt_segment_review") -> None:
super().__init__()
......
......@@ -7,10 +7,10 @@ from typing import Dict, List, Optional
from core.tool import tool, tool_func
from utils.excel_util import ExcelUtil
from core.tools.segment_llm import SegmentLLMBase
from core.tools.segment_llm import LLMTool
import re
REVIEW_SYSTEM_PROMPT = f'''
REVIEW_SYSTEM_PROMPT = '''
你是一个专业的合同分段审查智能体(SegmentReview)。
你的核心任务是对“当前分段”进行【法律风险识别】,并给出可落地的修改建议。
......@@ -100,10 +100,10 @@ def _is_generic_suggestion(text: str) -> bool:
return any(g in t for g in GENERIC_SUGGESTIONS)
@tool("segment_review", "合同分段审查")
class SegmentReviewTool(SegmentLLMBase):
class SegmentReviewTool(LLMTool):
def __init__(self):
super().__init__(REVIEW_SYSTEM_PROMPT)
self.rule_version = "rule-v1"
self.rule_version = "通用"
self.column_map = {
"id": "ID",
"title": "审查项",
......@@ -278,7 +278,7 @@ if __name__=="__main__":
result = tool.run(
segment_id=1,
segment_text="本合同自双方签字盖章之日起生效,有效期为两年,期满后自动续展一年,除非一方提前30天书面通知对方终止。",
ruleset_id="rule-v1",
ruleset_id="通用",
party_role="甲方",
context_memories=[
{
......
......@@ -6,7 +6,7 @@ import json
from typing import Dict, List, Optional
from core.tool import tool, tool_func
from core.tools.segment_llm import SegmentLLMBase
from core.tools.segment_llm import LLMTool
FACT_DIMENSIONS: List[str] = ["当事人", "标的", "金额", "支付", "交付", "质量", "知识产权", "保密", "违约责任", "争议解决"]
......@@ -48,7 +48,7 @@ OUTPUT_EXAMPLE = '''
@tool("segment_summary", "分段事实提取")
class SegmentSummaryTool(SegmentLLMBase):
class SegmentSummaryTool(LLMTool):
def __init__(self) -> None:
super().__init__(SUMMARY_SYSTEM_PROMPT)
......
No preview for this file type
......@@ -6,13 +6,15 @@ from uuid import uuid4
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
import uvicorn
import traceback
from utils.common_util import extract_url_file, format_now
from utils.http_util import download_file
from core.cache import get_cached_doc_tool, get_cached_memory
from core.config import doc_support_formats
from core.config import doc_support_formats, pdf_support_formats
from core.tools.segment_summary import SegmentSummaryTool
from core.tools.segment_review import SegmentReviewTool
from core.tools.reflect_retry import ReflectRetryTool
from core.memory import RiskFinding
app = FastAPI(title="合同审查智能体", version="0.1.0")
......@@ -20,25 +22,31 @@ TMP_DIR = Path(__file__).resolve().parent / "tmp"
TMP_DIR.mkdir(parents=True, exist_ok=True)
summary_tool = SegmentSummaryTool()
review_tool = SegmentReviewTool()
reflect_tool = ReflectRetryTool()
########################################################################################################################
class DocumentParseRequest(BaseModel):
conversation_id: str
urls: List[str] = Field(..., description="File download url")
file_ext: Optional[str] = None
ruleset_id: Optional[str] = "通用"
class DocumentParseResponse(BaseModel):
conversation_id: str
text: str
chunk_ids: List[int]
ruleset_items: List[str]
text: Optional[str] = None
file_ext: Optional[str] = None
@app.post("/documents/parse", response_model=DocumentParseResponse)
def parse_document(payload: DocumentParseRequest) -> DocumentParseResponse:
async def parse_document(payload: DocumentParseRequest) -> DocumentParseResponse:
if not payload.urls:
raise HTTPException(status_code=400, detail="No URLs provided")
try:
filename = extract_url_file(payload.urls[0], doc_support_formats)
support_formats = list(dict.fromkeys(doc_support_formats + pdf_support_formats))
filename = extract_url_file(payload.urls[0], support_formats)
except Exception as exc:
raise HTTPException(status_code=400, detail=f"Failed to parse url: {exc}")
......@@ -47,13 +55,27 @@ def parse_document(payload: DocumentParseRequest) -> DocumentParseResponse:
download_file(payload.urls[0], file_path)
except Exception as exc:
raise HTTPException(status_code=500, detail=f"Download failed: {exc}")
doc_obj = get_cached_doc_tool(payload.conversation_id)
# get doc tool
file_ext = payload.file_ext or Path(filename).suffix
try:
doc_obj, _ = get_cached_doc_tool(payload.conversation_id, file_ext)
except Exception as exc:
raise HTTPException(status_code=400, detail=f"Document tool not available: {exc}")
doc_obj.load(file_path)
text = doc_obj.get_all_text()
# ocr
await doc_obj.get_from_ocr()
# text = doc_obj.get_all_text()
chunk_ids = doc_obj.get_chunk_id_list()
# get ruleset items
ruleset_id = payload.ruleset_id or review_tool.rule_version
ruleset_items = review_tool.rulesets.get(ruleset_id) or []
ruleset_review_items = [r.get('title') for r in ruleset_items]
return DocumentParseResponse(
conversation_id=payload.conversation_id, text=text,chunk_ids=chunk_ids
conversation_id=payload.conversation_id,
# text=text,
chunk_ids=chunk_ids,
ruleset_items=ruleset_review_items,
file_ext = file_ext
)
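For reference, a hypothetical client call against the updated parse contract (the base URL is an assumption; field names follow DocumentParseRequest/DocumentParseResponse above):

```python
import requests

rsp = requests.post("http://localhost:8000/documents/parse", json={
    "conversation_id": "demo-1",
    "urls": ["http://example.com/contract.pdf"],   # hypothetical download URL
    "file_ext": ".pdf",
    "ruleset_id": "通用",
})
print(rsp.json())  # {"conversation_id": ..., "chunk_ids": [...], "ruleset_items": [...], "file_ext": ".pdf"}
```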
########################################################################################################################
......@@ -62,6 +84,7 @@ class SegmentSummaryRequest(BaseModel):
conversation_id: str
segment_id: int
party_role: Optional[str] = ""
file_ext: str
context_facts: Optional[Dict] = None
......@@ -75,7 +98,7 @@ class SegmentSummaryResponse(BaseModel):
def summarize_facts(payload: SegmentSummaryRequest) -> SegmentSummaryResponse:
store = get_cached_memory(payload.conversation_id)
try:
doc_obj = get_cached_doc_tool(payload.conversation_id)
doc_obj, _ = get_cached_doc_tool(payload.conversation_id, payload.file_ext)
except Exception as exc:
raise HTTPException(status_code=400, detail=f"Document tool not available: {exc}")
......@@ -105,7 +128,8 @@ class SegmentReviewRequest(BaseModel):
conversation_id: str
segment_id: int
party_role: Optional[str] = ""
ruleset_id: Optional[str] = "rule-v1"
ruleset_id: Optional[str] = "通用"
file_ext: str
context_memories: Optional[List[Dict]] = None
......@@ -119,7 +143,7 @@ class SegmentReviewResponse(BaseModel):
def review_segment(payload: SegmentReviewRequest) -> SegmentReviewResponse:
store = get_cached_memory(payload.conversation_id)
try:
doc_obj = get_cached_doc_tool(payload.conversation_id)
doc_obj, _ = get_cached_doc_tool(payload.conversation_id, payload.file_ext)
except Exception as exc:
raise HTTPException(status_code=400, detail=f"Document tool not available: {exc}")
......@@ -132,7 +156,7 @@ def review_segment(payload: SegmentReviewRequest) -> SegmentReviewResponse:
result = review_tool.run(
segment_id=payload.segment_id,
segment_text=segment_text,
ruleset_id=payload.ruleset_id or "rule-v1",
ruleset_id=payload.ruleset_id or "通用",
party_role=payload.party_role or "",
context_memories=payload.context_memories or store.get_facts(),
)
......@@ -159,6 +183,57 @@ def review_segment(payload: SegmentReviewRequest) -> SegmentReviewResponse:
)
########################################################################################################################
class ReflectReviewRequest(BaseModel):
conversation_id: str
party_role: str
ruleset_id: Optional[str] = "通用"
rule_title: str
class ReflectReviewResponse(BaseModel):
conversation_id: str
rule_title: str
findings: List[Dict]
@app.post("/segments/review/reflect", response_model=ReflectReviewResponse)
def reflect_review(payload: ReflectReviewRequest) -> ReflectReviewResponse:
store = get_cached_memory(payload.conversation_id)
ruleset_id = payload.ruleset_id or review_tool.rule_version
ruleset_items = review_tool.rulesets.get(ruleset_id) or []
rule = next((r for r in ruleset_items if r.get("title") == payload.rule_title), None)
if not rule:
raise HTTPException(status_code=404, detail=f"Rule not found: {payload.rule_title}")
facts = store.get_facts()
findings = [f.__dict__ for f in store.list_findings()]
final_findings = reflect_tool.run(
party_role=payload.party_role,
rule=rule,
facts=facts,
findings=findings,
)
for f in final_findings or []:
try:
store.add_final_finding_from_dict({
"rule_title": f.get("rule_title", ""),
"segment_id": f.get("segment_id", 0),
"original_text": f.get("original_text", ""),
"issue": f.get("issue", ""),
"risk_level": (f.get("risk_level") or f.get("level") or "").upper(),
"suggestion": f.get("suggestion", ""),
})
except Exception:
continue
return ReflectReviewResponse(
conversation_id=payload.conversation_id,
rule_title=payload.rule_title,
findings=final_findings or [],
)
########################################################################################################################
class ConversationResponse(BaseModel):
conversation_id: str
created_at: str
......@@ -173,44 +248,41 @@ def new_conversation() -> ConversationResponse:
class MemoryExportRequest(BaseModel):
conversation_id: str
file_ext: str
file_name: Optional[str] = None
class MemoryExportResponse(BaseModel):
conversation_id: str
url: str
data: Dict
excel_url: str
doc_url: str
@app.post("/memory/export", response_model=MemoryExportResponse)
def export_memory(payload: MemoryExportRequest) -> MemoryExportResponse:
store = get_cached_memory(payload.conversation_id)
try:
res = store.export_to_excel(file_name=payload.file_name)
doc_obj, _ = get_cached_doc_tool(payload.conversation_id, payload.file_ext)
except Exception as exc:
raise HTTPException(status_code=400, detail=f"Document tool not available: {exc}")
try:
excel_res = store.export_to_excel(file_name=payload.file_name)
except ImportError as exc:
raise HTTPException(status_code=500, detail=str(exc))
except Exception as exc:
raise HTTPException(status_code=500, detail=f"Export failed: {exc}")
url = ""
if isinstance(res, str):
url = res
elif isinstance(res, dict):
for key in [
"url",
"file_url",
"fileUrl",
"link",
"downloadUrl",
"path",
"filePath",
]:
val = res.get(key)
if isinstance(val, str) and val:
url = val
break
return MemoryExportResponse(conversation_id=payload.conversation_id, url=url, data=res if isinstance(res, dict) else {"result": res})
try:
doc_res = store.export_findings_to_doc_comments(doc_obj)
except Exception as exc:
traceback.print_exc()
raise HTTPException(status_code=500, detail=f"Export doc comments failed: {exc}")
return MemoryExportResponse(
conversation_id=payload.conversation_id,
excel_url=excel_res,
doc_url=doc_res,
)
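Likewise, a hypothetical call to the reworked export endpoint, which now returns both an Excel link and an annotated-document link:

```python
import requests

rsp = requests.post("http://localhost:8000/memory/export", json={   # base URL is an assumption
    "conversation_id": "demo-1",
    "file_ext": ".pdf",
    "file_name": "review_result",
})
print(rsp.json()["excel_url"], rsp.json()["doc_url"])
```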
if __name__ == "__main__":
uvicorn.run(
......
import json
import json_repair
import random
import re
from datetime import datetime
from typing import Dict, List
from loguru import logger
from core.config import max_chunk_page, min_single_chunk_size, max_single_chunk_size
......@@ -51,14 +50,14 @@ def extract_json(json_str:str) -> List[Dict]:
# 清理控制字符
s = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', s)
try:
obj = json.loads(s, strict=False)
            obj = json_repair.loads(s)  # unlike json.loads, json_repair.loads takes no strict argument
if isinstance(obj, list):
out_list.extend(obj)
else:
out_list.append(obj)
return True
except Exception as e:
logger.error(f"JSON解析失败: {e}")
logger.error(f"JSON解析失败: {e} | 内容片段: {s}")
return False
results = []
......
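The switch to `json_repair` is meant to tolerate the slightly malformed JSON that LLMs often emit; a minimal sketch of the difference (example string invented):

```python
import json_repair

broken = '{"findings": [{"rule_title": "付款条款审查", "risk_level": "M", }'   # truncated, trailing comma
# json.loads(broken) would raise JSONDecodeError; json_repair makes a best-effort repair instead
print(json_repair.loads(broken))
```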
......@@ -53,7 +53,7 @@ def fastgpt_openai_chat(url, token, model, chat_id, file_url, text, stream=True)
return rsp
def upload_file(path, input_url_to_inner=True, output_url_to_inner=False):
def upload_file(path, input_url_to_inner=True, output_url_to_inner=False) -> str:
# 登录获取token
login_data = {"username": "admin", "password": "admin@jpai.com"}
login_url = f"{base_backend_url}/admin-api/system/auth/login"
......@@ -109,5 +109,5 @@ def url_replace_fastgpt(origin: str):
if __name__ == "__main__":
# d = '/home/ccran/file.docx'
d = "file.docx"
print(os.path.basename(d))
d = "/home/ccran/lufa-contract/tmp/default.json"
print(upload_file(d))
import fitz
from urllib import parse
from urllib.parse import urlparse
import os
import asyncio
import aiohttp
from aiohttp import ClientSession
from utils.http_util import url_replace_fastgpt, download_file
from utils.common_util import random_str
from loguru import logger
import json
ocr_url = 'http://192.168.252.71:8202/openapi/ocrUploadFile'
class OCRUtil:
def __init__(self):
pass
# ocr异步接口
async def ocr_requests_async(self, session, file_path):
# 文件上传通常使用multipart/form-data编码格式
# file_path = 'file/1.webp'
        # open the file in a context manager so the handle is closed once the upload finishes
        with open(file_path, 'rb') as f:
            async with session.post(ocr_url, data={'file': f}) as response:
                # logger.info(f'开始请求:{file_path}')
                rsp = await response.text()
                # logger.info(f'{file_path}结束请求:{json.loads(rsp)["msg"]}')
                return rsp, file_path
# 异步请求ocr image接口
async def ocr_image_async(self, path_list):
timeout = aiohttp.ClientTimeout(total=600)
connector = aiohttp.TCPConnector(limit=1)
async with ClientSession(connector=connector, timeout=timeout) as session:
tasks = [self.ocr_requests_async(session, file_path) for file_path in path_list]
responses = await asyncio.gather(*tasks)
res_dict = {}
for response_text, file_path in responses:
rsp_json = json.loads(response_text)
if 'data' not in rsp_json:
logger.error(f'ocr_image_async {file_path} error:{rsp_json["msg"]}')
continue
else:
content = rsp_json['data']['strRes']
# logger.info(f'ocr_image_async append {file_path} success. content:{content[:100]}')
# add to dict
page_num = int(self.get_pdf_2_img_page_num(file_path))
res_dict[page_num] = rsp_json['data']['strRes']
# 根据页数排序
logger.info(f'ocr_image_async finish. all pages:{len(res_dict)}')
sorted_values = [res_dict[key] for key in sorted(res_dict)]
return sorted_values
def set_pdf_2_img_page(self, path, page_idx):
return f'{path}_{page_idx + 1}.png'
def get_pdf_2_img_page_num(self, path):
split_path = path.split('_')
# for example : 'xx_10.png' ==> get 10
return split_path[-1][:-4]
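A tiny check of the page-index naming round trip that lets OCR results be re-ordered by page (illustrative path):

```python
u = OCRUtil()
png = u.set_pdf_2_img_page("tmp/contract.pdf", 9)   # -> "tmp/contract.pdf_10.png" (1-based page number)
assert u.get_pdf_2_img_page_num(png) == "10"        # parsed back as a string, cast to int by the caller
```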
def pdf_2_img(self, path, zoom_x=1, zoom_y=1):
# 打开PDF文件
pdf = fitz.open(path)
pdf_list = []
# 逐页读取PDF
for pg in range(0, pdf.page_count):
page = pdf[pg]
# 设置缩放和旋转系数
trans = fitz.Matrix(zoom_x, zoom_y)
pm = page.get_pixmap(matrix=trans, alpha=False)
# img save
dest_png = self.set_pdf_2_img_page(path, pg)
pm.save(dest_png)
pdf_list.append(dest_png)
pdf.close()
return pdf_list
def ocr_download_path(self, url):
logger.info(f'ocr url:{url}')
# url替换
url = url_replace_fastgpt(url)
# url分析
url_parsed = urlparse(url)
query_dict = parse.parse_qs(url_parsed.query)
if 'filename' in query_dict:
filename = query_dict.get('filename')[0]
else:
filename = f'{random_str()}.pdf'
# 获取文件内容
dest_path = f'ocr/{filename}'
download_file(url, dest_path)
return dest_path
async def ocr_result_pdf(self, dest_path):
# 解析返回结果
pdf_list = self.pdf_2_img(dest_path)
# logger.info(f'pdf 生成完毕:{pdf_list}')
result = await self.ocr_image_async(pdf_list)
# 删除文件
for pdf in pdf_list:
if os.path.exists(pdf):
os.remove(pdf)
return result
if __name__ == '__main__':
ocr_util = OCRUtil()
result = asyncio.run(ocr_util.ocr_result_pdf('/home/ccran/lufa-contract/tmp/财03 2023年国科京东方评估合同.pdf'))
print(f'len(result):{len(result)}')
print(result)
import asyncio
from core.config import root_path
from utils.doc_util import DocBase
from spire.pdf import PdfDocument, PdfTextExtractOptions, PdfTextExtractor, License, PdfTextFinder, TextFindParameter, \
PdfTextMarkupAnnotation, PdfRGBColor, Color, PdfPolyLineAnnotation, PointF, RectangleF, PdfPopupAnnotation, \
PdfPopupIcon
from loguru import logger
from rapidfuzz import process
from utils.ocr_util import OCRUtil
import copy
from utils.common_util import adjust_single_chunk_size
import re
import os
ocr_thresholds = 1000
def is_messy_text(text: str,
min_chars=40,
chinese_ratio_thresh=0.20,
printable_ratio_thresh=0.70,
symbol_ratio_thresh=0.30,
longest_non_word_run_thresh=10,
english_word_density_thresh=0.03):
"""
判断单页或一段 text 是否 '乱码'。如果返回 True 表示该 text 被判为乱码(需要 OCR)。
参数可调整:
- min_chars: 少于这个字符数直接认为质量差(比如抽出来是碎行)
- chinese_ratio_thresh: 中文字符占比阈值(若低于且其它判据也成立则可能乱码)
- printable_ratio_thresh: 可打印字符占比(太低说明很多不可见/非标准字符)
- symbol_ratio_thresh: 标点/特殊符号占比阈值(高于则更像乱码)
- longest_non_word_run_thresh: 最长连续特殊符号串长度阈值(如 "~~##@@@... ")
- english_word_density_thresh: 英文单词密度(单词数/行数或单词数/字符数近似);若很低且无中文,则文本不可读
"""
if not text:
return True
text_len = len(text)
if text_len < min_chars:
# 太短的片段通常是拆分后的乱码或空白
return True
# 统计中文、字母/数字、可打印、空白
chinese_count = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
alnum_count = sum(1 for c in text if c.isalnum())
printable_count = sum(1 for c in text if c.isprintable())
space_count = sum(1 for c in text if c.isspace())
# 标点/特殊符号(既不是中文也不是字母数字也不是空白,视为符号)
symbol_count = sum(1 for c in text if not (('\u4e00' <= c <= '\u9fff') or c.isalnum() or c.isspace()))
chinese_ratio = chinese_count / text_len
printable_ratio = printable_count / text_len
symbol_ratio = symbol_count / text_len
# 连续的非字/非中文/非数字串(比如 "~~~##@@@")
non_word_runs = re.findall(r'[^0-9A-Za-z\u4e00-\u9fff\s]+', text)
longest_non_word_run = max((len(s) for s in non_word_runs), default=0)
# 英文单词密度(粗略):单词数 / 总字符数,单词定义为至少两个字母
english_words = re.findall(r'\b[a-zA-Z]{2,}\b', text)
english_word_density = len(english_words) / max(1, text_len)
# 启发式判断:多条规则投票(其中任意一条强指标触发即可)
# 强指标:非常高的符号占比、非常长的连续符号、几乎不可打印
if printable_ratio < printable_ratio_thresh * 0.6:
return True
if symbol_ratio > max(0.5, symbol_ratio_thresh):
return True
if longest_non_word_run >= longest_non_word_run_thresh * 1.5:
return True
# 组合判断:如果中文占比很低且英文单词密度也很低,并且符号占比高 -> 乱码
if chinese_ratio < chinese_ratio_thresh and english_word_density < english_word_density_thresh and symbol_ratio > symbol_ratio_thresh:
return True
# 如果中文占比极低且打印字符不多,也判为乱码
if chinese_ratio < (chinese_ratio_thresh * 0.5) and printable_ratio < printable_ratio_thresh:
return True
# 其它情况视为可读
return False
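A few illustrative inputs showing how the heuristic behaves with its default thresholds:

```python
assert is_messy_text("") is True                    # empty extraction: treat as garbled, needs OCR
assert is_messy_text("~~##@@@" * 20) is True        # pure symbol runs: high symbol ratio, garbled
readable = "本合同由甲方与乙方于2023年签订,双方就合同标的、价款及交付方式达成如下约定。" * 3
assert is_messy_text(readable) is False             # mostly Chinese, printable, few symbols: readable
```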
class SpirePdfDoc(DocBase):
def get_chunk_location(self, chunk_idx):
chunk_item = self._chunk_list[chunk_idx]
return '、'.join(str(page + 1) for page in chunk_item['page_idx'])
# 调整块大小
def adjust_chunk_size(self):
all_text_len = len(self.get_all_text())
self._max_single_chunk_size = adjust_single_chunk_size(all_text_len)
logger.info(f'SpirePdfDoc adjust _max_single_chunk_size to {self._max_single_chunk_size}')
self.reset()
return self._max_single_chunk_size
async def get_from_ocr(self):
"""
新策略:
- 先用 Spire 的提取结果(self._chunk_list 已由 reset/init_chunks 填充)
- 对每一页单独用 is_messy_text 判断
- 如果被判为乱码的页占比超过 page_messy_ratio_thresh(比如 0.4),则触发 OCR(按页OCR)
- 若被判为乱码的页数量少(但某些页极差),也可触发 OCR(按需要)
"""
# 先从当前 chunk_list 拼出每页文本(init_chunks 保证 chunk_list 基于 PdfTextExtractor)
# 注意 chunk_list 中 page 是每个 chunk 的 page 文本列表,我们要逐页统计
all_pages = []
for chunk in self._chunk_list:
# chunk['page'] 是字符串列表,对应 chunk['page_idx'] 顺序一致
all_pages.extend(chunk.get('page', []))
# 如果没有任何页面文本(极端情况),强制 OCR
if len(all_pages) == 0:
logger.info('Spire extract returned 0 pages text, forcing OCR.')
do_ocr = True
else:
# 针对每页判断是否为乱码
page_messy_flags = []
for i, page_text in enumerate(all_pages):
messy = is_messy_text(page_text)
page_messy_flags.append(messy)
logger.debug(f'page {i + 1}: len={len(page_text)} messy={messy}')
messy_count = sum(1 for f in page_messy_flags if f)
total_pages = len(all_pages)
messy_ratio = messy_count / total_pages
logger.info(f'spire extract pages: {total_pages}, messy pages: {messy_count}, ratio: {messy_ratio:.2f}')
# 触发 OCR 的条件(可调):
# 1) 若多数页为乱码 -> OCR
# 2) 若至少一页非常差(例如 is_messy_text 在内部检测到极端情况) -> OCR
page_messy_ratio_thresh = 0.40 # 如果超过 40% 页被判为乱码则 OCR
# 若某页文本长度非常短(如 <20)且很多页属于这种短页 -> OCR
many_short_pages = sum(1 for p in all_pages if len(p) < 20) / total_pages > 0.3
do_ocr = (messy_ratio > page_messy_ratio_thresh) or many_short_pages
# 作为后备:如果全文字符数低于 ocr_thresholds(原逻辑),同样 OCR
all_text = '\n'.join(all_pages)
if len(all_text) < ocr_thresholds:
do_ocr = True
if do_ocr:
page_result_list = await self.ocr_util.ocr_result_pdf(self._doc_path)
logger.info(f'ocr success: len[{len(page_result_list)}]')
# 用 OCR 结果重新构建 chunk_list(与原逻辑一致)
chunk_list = []
chunk_item = {
'page_objs': [],
'page': [],
'page_idx': []
}
cur_text_len = 0
for i in range(0, len(page_result_list)):
text = page_result_list[i]
if cur_text_len + len(text) > self._max_single_chunk_size and len(chunk_item['page']):
chunk_list.append(copy.deepcopy(chunk_item))
chunk_item = {
'page_objs': [],
'page': [],
'page_idx': []
}
cur_text_len = 0
chunk_item['page'].append(text)
chunk_item['page_idx'].append(i)
cur_text_len += len(text)
if len(chunk_item['page']):
chunk_list.append(copy.deepcopy(chunk_item))
self._chunk_list = chunk_list
else:
logger.info('Spire extract judged readable — skip OCR.')
return ''
def __init__(self, **kwargs):
super(SpirePdfDoc, self).__init__(**kwargs)
self.ocr_util = OCRUtil()
def load(self,doc_path):
self._doc_path = doc_path
self._doc_name = os.path.basename(doc_path)
self._doc = PdfDocument()
self._doc.LoadFromFile(self._doc_path)
self._chunk_list = self.init_chunks()
def init_chunks(self):
chunk_list = []
extract_options = PdfTextExtractOptions()
extract_options.IsExtractAllText = True
chunk_item = {
'page_objs': [],
'page': [],
'page_idx': []
}
cur_text_len = 0
for i in range(0, self._doc.Pages.Count):
page = self._doc.Pages[i]
text = PdfTextExtractor(page).ExtractText(extract_options)
if cur_text_len + len(text) > self._max_single_chunk_size and len(chunk_item['page_objs']):
chunk_list.append(chunk_item)
chunk_item = {
'page_objs': [],
'page': [],
'page_idx': []
}
cur_text_len = 0
chunk_item['page_objs'].append(page)
chunk_item['page'].append(text)
chunk_item['page_idx'].append(i)
cur_text_len += len(text)
if len(chunk_item['page_objs']):
chunk_list.append(chunk_item)
return chunk_list
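The page-packing rule used both here and in the OCR branch of `get_from_ocr`, restated standalone with made-up page lengths (hypothetical helper, not part of the class):

```python
def pack_pages(page_lengths, max_size):
    """Pack consecutive page indices into chunks whose total text length stays within max_size."""
    chunks, current, cur_len = [], [], 0
    for idx, n in enumerate(page_lengths):
        if cur_len + n > max_size and current:
            chunks.append(current)
            current, cur_len = [], 0
        current.append(idx)
        cur_len += n
    if current:
        chunks.append(current)
    return chunks

print(pack_pages([800, 900, 400, 1200], 2000))  # -> [[0, 1], [2, 3]]
```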
def get_chunk_item(self, chunk_id):
chunk = self._chunk_list[chunk_id]
return '\n'.join([page for page in chunk['page']])
def get_chunk_info(self, chunk_id):
chunk = self._chunk_list[chunk_id]
from_location = f'[页面{chunk["page_idx"][0] + 1}]'
to_location = f'[页面{chunk["page_idx"][-1] + 1}]'
chunk_content_tips = '[' + chunk['page'][0][:20] + ']...到...[' + chunk['page'][-1][-20:] + ']'
return f'文件块id: {chunk_id + 1}\n文件块位置: 从{from_location}到{to_location}\n文件块简述: {chunk_content_tips}\n'
def format_comment_author(self, comment):
return '{}|{}'.format(str(comment['id']), comment['key_points'])
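The author string doubles as the annotation's identity, which `find_anno_idx` and the delete/edit paths match against later; for example:

```python
doc = SpirePdfDoc()
print(doc.format_comment_author({"id": 7, "key_points": "付款条款审查"}))   # -> "7|付款条款审查"
```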
def add_text_markup_anno(self, page_obj, text_fragment, author, review_suggest):
# 遍历文本边界
for j in range(len(text_fragment.Bounds)):
# 获取边界
rect = text_fragment.Bounds[j]
# 创建文本标记注释
annotation = PdfTextMarkupAnnotation(author, review_suggest, rect)
annotation.Name = author
annotation.TextMarkupColor = PdfRGBColor(Color.get_Green())
# 将注释添加到注释集合中
page_obj.AnnotationsWidget.Add(annotation)
# TODO 好像注释删不掉
def add_poly_line_anno(self, page_obj, text_fragment, author, review_suggest):
# 获取文本边界
rect = text_fragment.Bounds[0]
# 获取文本边界的坐标以添加注释
left = rect.Left
top = rect.Top
right = rect.Right
bottom = rect.Bottom
polyLineAnnotation = PdfPolyLineAnnotation(page_obj,
[PointF(left, top), PointF(right, top), PointF(right, bottom),
PointF(left, bottom), PointF(left, top)])
polyLineAnnotation.Name = author
# 自定义注释内容
polyLineAnnotation.Text = review_suggest
# 将注释添加到文档的注释集合中
page_obj.AnnotationsWidget.Add(polyLineAnnotation)
# TODO 批注好像删不掉
def add_popup_anno(self, page_obj, text_fragment, author, review_suggest):
# 获取文本边界
rect = text_fragment.Bounds[0]
# 获取文本边界的坐标以添加注释
right = rect.Right
top = rect.Top
# 创建注释
rectangle = RectangleF(right, top - 15., 15., 15.)
popupAnnotation = PdfPopupAnnotation(rectangle)
# 自定义注释内容
popupAnnotation.Name = author
popupAnnotation.Text = review_suggest
# 设置注释的图标和颜色
popupAnnotation.Icon = PdfPopupIcon.Comment
popupAnnotation.Color = PdfRGBColor(Color.get_Red())
# 将注释添加到文档的注释集合中
page_obj.AnnotationsWidget.Add(popupAnnotation)
def add_anno(self, chunk, find_key, author, review_suggest):
accurate_find = False
for i in range(0, len(chunk['page_objs'])):
page_obj = chunk['page_objs'][i]
finder = PdfTextFinder(page_obj)
finder.Options.Parameter = TextFindParameter.WholeWord
fragments = finder.Find(find_key)
# 未精确匹配
if fragments is None or len(fragments) == 0:
pass
else:
textFragment = fragments[0]
# logger.info(f'[add_anno markup] {find_key} {author} {review_suggest}')
# textFragment.HighLight(Color.get_Yellow())
# self.add_text_markup_anno(page_obj, textFragment, author, review_suggest)
# self.add_poly_line_anno(page_obj, textFragment, author, review_suggest)
self.add_popup_anno(page_obj, textFragment, author, review_suggest)
accurate_find = True
return accurate_find
def add_chunk_comment(self, chunk_id, comments):
# logger.info(f'add_chunk_comment: {chunk_id} {comments}')
chunk = self._chunk_list[chunk_id]
for comment in comments:
if comment['result'] != '不合格':
continue
review_suggest = comment.get('suggest', '')
author = self.format_comment_author(comment)
find_key = comment['original_text'].strip() if comment['original_text'].strip() else comment['key_points']
# 先进行精确查找
accurate_find = self.add_anno(chunk, find_key, author, review_suggest)
# 模糊查找
if not accurate_find:
sub_chunk_texts = []
for page in chunk['page']:
sub_chunk_texts.extend([item.strip() for item in page.split('\n') if len(item.strip())])
top_match = process.extract(find_key, sub_chunk_texts, limit=1)
sub_chunk_text, score, sub_chunk_idx = top_match[0]
# logger.info(f'add_chunk_comment find_key:{find_key} fuzz match: {sub_chunk_text}')
                # 再进行模糊查找
self.add_anno(chunk, sub_chunk_text, author, review_suggest)
def print_all_anno(self):
for j in range(0, self._doc.Pages.Count):
page_obj = self._doc.Pages.get_Item(j)
for i in range(page_obj.AnnotationsWidget.Count):
anno = page_obj.AnnotationsWidget.get_Item(i)
name = ''
try:
name = anno.Name
except:
name = ''
logger.info(f'anno批注:{j + 1}页 第{i + 1}个 [{name}]:[{anno.Text}]')
# page_obj.AnnotationsWidget.Clear()
def find_anno_idx(self, page_obj, author):
for i in range(page_obj.AnnotationsWidget.Count):
anno = page_obj.AnnotationsWidget.get_Item(i)
anno_name = ''
try:
anno_name = anno.Name
except Exception as e:
pass
# logger.info(f'{type(anno)} {anno_name}')
if anno_name == author:
# logger.info(f'find_anno_idx:{author} at {i}')
return i
return -1
def edit_chunk_comment(self, comments):
for comment in comments:
review_answer = comment['result']
if review_answer == '不合格':
idx, page, delete_comment = self.delete_single_chunk_comment_once(comment)
                # 不存在删除项,新增
if idx == -1:
self.add_chunk_comment(comment['chunk_id'], [comment])
# 存在删除项,编辑
else:
review_suggest = comment.get('suggest', '')
all_pages = []
all_delete_comment = []
all_idx = []
while idx != -1:
delete_comment.Text = review_suggest
all_pages.append(page)
all_delete_comment.append(delete_comment)
all_idx.append(idx)
idx, page, delete_comment = self.delete_single_chunk_comment_once(comment)
for page, delete_comment, idx in zip(all_pages, all_delete_comment, all_idx):
page.AnnotationsWidget.Insert(idx, delete_comment)
# 合格则删除
else:
self.delete_single_chunk_comment(comment)
def delete_chunk_comment(self, comments):
for comment in comments:
self.delete_single_chunk_comment(comment)
def delete_single_chunk_comment(self, comment):
while True:
idx, _, _ = self.delete_single_chunk_comment_once(comment)
if idx == -1:
break
# 删除一次批注
def delete_single_chunk_comment_once(self, comment):
author = self.format_comment_author(comment)
chunk_id = comment['chunk_id']
chunk = self._chunk_list[chunk_id]
for page_obj in chunk['page_objs']:
idx = self.find_anno_idx(page_obj, author)
if idx != -1:
# logger.info(f'delete_single_chunk_comment_once:{author} {comment}')
delete_comment = page_obj.AnnotationsWidget.get_Item(idx)
page_obj.AnnotationsWidget.RemoveAt(idx)
# logger.info(page_obj.AnnotationsWidget.Contains(delete_comment))
return idx, page_obj, delete_comment
return -1, None, None
def get_chunk_id_list(self, step=1):
return [idx + 1 for idx in range(0, self.get_chunk_num(), step)]
def get_chunk_num(self):
return len(self._chunk_list)
def get_all_text(self):
return '\n'.join([page for chunk in self._chunk_list for page in chunk['page']])
def to_file(self, path, **kwargs):
self._doc.SaveToFile(path)
def re_comment_by_memory(self, comments):
self._doc.Close()
self.reset()
for comment in comments:
self.add_chunk_comment(comment['chunk_id'], [comment])
def release(self):
self._doc.Close()
super().release()
if __name__ == '__main__':
tool = SpirePdfDoc()
tool.load('/home/ccran/lufa-contract/tmp/财03 2023年国科京东方评估合同.pdf')
print('', tool.get_all_text())
print(asyncio.run(tool.get_from_ocr()))
print(tool.get_chunk_id_list())