Commit 9fa3fd96 by ccran

feat: 增加PDF处理;

parent 53a63f97
tmp/
\ No newline at end of file
tmp/
*.pyc
**__pycache__**
\ No newline at end of file
from utils.spire_word_util import SpireWordDoc
from utils.spire_pdf_util import SpirePdfDoc
from utils.doc_util import DocBase
from functools import lru_cache
from typing import Optional
from typing import Optional, Tuple
from core.memory import MemoryStore
from core.config import pdf_support_formats
MAX_CACHE = 128
def _normalize_file_ext(file_ext: str) -> str:
if not file_ext:
raise ValueError("file_ext is required")
ext = file_ext.strip().lower()
if not ext.startswith("."):
ext = f".{ext}"
return ext
@lru_cache(maxsize=MAX_CACHE)
def get_cached_doc_tool(conversation_id: str) -> Optional[DocBase]:
return SpireWordDoc()
@lru_cache(maxsize=MAX_CACHE)
def get_cached_doc_tool(conversation_id: str, file_ext: str) -> Tuple[DocBase, str]:
    """Return (and cache) the document tool for a conversation.

    The result is memoized per (conversation_id, file_ext): /documents/parse
    loads a file into the returned tool and later segment endpoints expect to
    get the *same* loaded instance back. The previous revision dropped the
    @lru_cache decorator when the file_ext parameter was added, so every call
    produced a fresh, empty document object.

    Args:
        conversation_id: conversation the document belongs to (cache key part).
        file_ext: raw extension, normalized via _normalize_file_ext.

    Returns:
        (doc tool, normalized extension); PDF extensions get a SpirePdfDoc,
        everything else a SpireWordDoc.
    """
    ext = _normalize_file_ext(file_ext)
    if ext in pdf_support_formats:
        return SpirePdfDoc(), ext
    return SpireWordDoc(), ext
@lru_cache(maxsize=MAX_CACHE)
def get_cached_memory(conversation_id: str) -> MemoryStore:
......
......@@ -29,7 +29,6 @@ class LLMConfig:
outer_backend_url = "http://218.77.58.8:48081"
base_fastgpt_url = "http://192.168.252.71:18089"
base_backend_url = "http://192.168.252.71:48081"
ocr_url = 'http://192.168.252.71:8202/openapi/ocrUploadFile'
# 项目根目录
root_path = r"E:\PycharmProject\contract_review_agent"
......@@ -70,7 +69,6 @@ field_mapping = {
excel_widths = {"review": [5, 20, 80, 80, 20, 80, 5, 60], "category": [80, 80, 30]}
max_review_group = 5
# 销售类别判断
ocr_thresholds = 1000
all_rule_sheet = ["内销或出口", "内销", "出口", "反思"]
reflection_sheet = "反思"
# 最大分片数量
......
......@@ -58,6 +58,7 @@ from threading import RLock
from typing import Any, Dict, Iterable, List, Optional
from utils.http_util import upload_file
from utils.doc_util import DocBase
logger = logging.getLogger(__name__)
......@@ -107,6 +108,7 @@ class MemoryStore:
self._lock = RLock()
self.facts: List[Dict[str, Any]] = []
self.findings: List[RiskFinding] = []
self.final_findings: List[RiskFinding] = []
self._load()
# ---------------------- facts ----------------------
......@@ -128,35 +130,76 @@ class MemoryStore:
# -------------------- findings ---------------------
def add_finding(self, finding: RiskFinding) -> RiskFinding:
    """Append *finding* to the interim findings list, persist, and return it.

    The previous revision carried both the old inline append/persist body and
    the new delegation line (unreachable merge residue); only the delegation
    is kept.
    """
    return self._add_finding(self.findings, finding)
def add_finding_from_dict(self, data: Dict) -> RiskFinding:
    # Convenience wrapper: build a RiskFinding from a plain dict, then store it.
    return self.add_finding(RiskFinding.from_dict(data))

def add_final_finding(self, finding: RiskFinding) -> RiskFinding:
    # Same as add_finding, but targets the post-reflection (final) list.
    return self._add_finding(self.final_findings, finding)

def add_final_finding_from_dict(self, data: Dict) -> RiskFinding:
    # Dict-based variant of add_final_finding.
    return self.add_final_finding(RiskFinding.from_dict(data))
def list_findings(self) -> List[RiskFinding]:
with self._lock:
return list(self.findings)
return self._list_findings(self.findings)
def list_final_findings(self) -> List[RiskFinding]:
return self._list_findings(self.final_findings)
def get_findings_by_segment(self, segment_id: int) -> List[RiskFinding]:
with self._lock:
return [f for f in self.findings if f.segment_id == segment_id]
return self._get_findings_by_segment(self.findings, segment_id)
def get_final_findings_by_segment(self, segment_id: int) -> List[RiskFinding]:
    # Final-findings variant of get_findings_by_segment.
    return self._get_findings_by_segment(self.final_findings, segment_id)

def delete_findings_by_segment(self, segment_id: int) -> int:
    # Remove all interim findings for a segment; returns the removed count.
    return self._delete_findings_by_segment("findings", segment_id)

def delete_final_findings_by_segment(self, segment_id: int) -> int:
    # Remove all final findings for a segment; returns the removed count.
    return self._delete_findings_by_segment("final_findings", segment_id)

def search_findings(self, keyword: str, rule_title: Optional[str] = None, risk_level: Optional[str] = None) -> List[RiskFinding]:
    # Keyword/filter search over the interim findings (see _search_findings).
    return self._search_findings(self.findings, keyword, rule_title, risk_level)

def search_final_findings(self, keyword: str, rule_title: Optional[str] = None, risk_level: Optional[str] = None) -> List[RiskFinding]:
    # Keyword/filter search over the final findings.
    return self._search_findings(self.final_findings, keyword, rule_title, risk_level)
def _add_finding(self, target: List[RiskFinding], finding: RiskFinding) -> RiskFinding:
with self._lock:
target.append(finding)
self._persist()
return finding
def _list_findings(self, target: List[RiskFinding]) -> List[RiskFinding]:
with self._lock:
before = len(self.findings)
self.findings = [f for f in self.findings if f.segment_id != segment_id]
removed = before - len(self.findings)
return list(target)
def _get_findings_by_segment(self, target: List[RiskFinding], segment_id: int) -> List[RiskFinding]:
with self._lock:
return [f for f in target if f.segment_id == segment_id]
def _delete_findings_by_segment(self, attr_name: str, segment_id: int) -> int:
with self._lock:
current = getattr(self, attr_name)
before = len(current)
updated = [f for f in current if f.segment_id != segment_id]
setattr(self, attr_name, updated)
removed = before - len(updated)
if removed:
self._persist()
return removed
def search_findings(self, keyword: str, rule_title: Optional[str] = None, risk_level: Optional[str] = None) -> List[RiskFinding]:
def _search_findings(
self,
target: List[RiskFinding],
keyword: str,
rule_title: Optional[str] = None,
risk_level: Optional[str] = None,
) -> List[RiskFinding]:
key = (keyword or "").strip().lower()
with self._lock:
candidates = list(self.findings)
candidates = list(target)
if rule_title:
candidates = [f for f in candidates if (f.rule_title or "").lower() == rule_title.strip().lower()]
if risk_level:
......@@ -179,12 +222,14 @@ class MemoryStore:
with self._lock:
self.facts.clear()
self.findings.clear()
self.final_findings.clear()
self._persist()
def _persist(self) -> None:
payload = {
"facts": self.facts,
"findings": [asdict(f) for f in self.findings],
"final_findings": [asdict(f) for f in self.final_findings],
}
try:
self._storage_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
......@@ -200,6 +245,7 @@ class MemoryStore:
if isinstance(data, dict):
self.facts = data.get("facts") or []
self.findings = [RiskFinding.from_dict(item) for item in data.get("findings", []) or []]
self.final_findings = [RiskFinding.from_dict(item) for item in data.get("final_findings", []) or []]
except Exception as exc:
logger.error("Failed to load memory store: %s", exc)
......@@ -233,6 +279,13 @@ class MemoryStore:
getattr(f, key, "") for key, _ in finding_headers
])
ws_final_findings = wb.create_sheet("final_findings")
ws_final_findings.append([label for _, label in finding_headers])
for f in self.final_findings:
ws_final_findings.append([
getattr(f, key, "") for key, _ in finding_headers
])
ws_facts = wb.create_sheet("facts")
if self.facts:
fact_keys: List[str] = sorted({k for item in self.facts for k in item.keys()})
......@@ -263,8 +316,97 @@ class MemoryStore:
return res
def export_findings_to_doc_comments(
    self,
    doc_obj: DocBase,
    file_name: Optional[str] = None,
    remove_prefix: bool = False,
) -> Dict[str, Any]:
    """Add all findings as comments to a document, upload, then delete the local file.

    Args:
        doc_obj: loaded document tool exposing add_chunk_comment/to_file.
        file_name: optional output name; a timestamped default is used
            otherwise, and the source document's suffix is appended if the
            given name has none.
        remove_prefix: forwarded to doc_obj.to_file.

    Returns:
        The upload service response.

    Raises:
        ValueError: if *doc_obj* is None.
    """
    if doc_obj is None:
        raise ValueError("doc_obj is required")
    # NOTE(review): a stray `if __name__ == "__main__":` line (diff-merge
    # residue) sat in the middle of this method and has been removed.
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    doc_name = getattr(doc_obj, "_doc_name", "") or ""
    suffix = Path(doc_name).suffix or ".docx"
    name = file_name or f"findings_{ts}{suffix}"
    if not Path(name).suffix:
        name = f"{name}{suffix}"
    output_path = Path(__file__).resolve().parent.parent / "tmp" / name
    with self._lock:
        comments: List[Dict[str, Any]] = []
        # To export final findings instead, iterate self.final_findings here.
        for idx, f in enumerate(self.findings, start=1):
            segment_id = int(f.segment_id or 0)
            # Comment anchors use 0-based chunk ids; segment ids are 1-based.
            chunk_id = max(segment_id - 1, 0)
            suggest_parts = []
            if f.risk_level:
                suggest_parts.append(f"风险等级:{f.risk_level}")
            if f.issue:
                suggest_parts.append(f"问题:{f.issue}")
            if f.suggestion:
                suggest_parts.append(f"建议:{f.suggestion}")
            suggest_text = "\n".join(suggest_parts).strip()
            comments.append(
                {
                    "id": str(idx),
                    "key_points": f.rule_title or "风险提示",
                    "original_text": f.original_text or "",
                    "details": f.issue or "",
                    "chunk_id": chunk_id,
                    "result": "不合格",
                    "suggest": suggest_text,
                }
            )
        if comments:
            doc_obj.add_chunk_comment(0, comments)
        doc_obj.to_file(str(output_path), remove_prefix=remove_prefix)
    try:
        res = upload_file(str(output_path))
    finally:
        # Best-effort cleanup of the temp file even if the upload fails.
        try:
            output_path.unlink()
        except Exception:
            logger.warning("Failed to delete temp doc: %s", output_path)
    return res
def test_export_findings_to_doc_comments(doc_path: str) -> None:
    """Manual test: write findings into a document as comments and upload it.

    The docstring previously sat mid-function (a no-op string statement after
    the store setup); guard clauses now run before any store/finding is built.
    """
    if not doc_path:
        print("doc_path 为空,跳过批注导出测试")
        return
    if not Path(doc_path).exists():
        print(f"文件不存在,跳过批注导出测试: {doc_path}")
        return
    try:
        from utils.spire_word_util import SpireWordDoc
    except Exception as exc:
        print(f"加载 SpireWordDoc 失败,跳过批注导出测试: {exc}")
        return
    store = MemoryStore()
    finding = RiskFinding(
        rule_title="违约责任",
        segment_id=1,
        original_text="湖南麓谷发展集团有限公司",
        issue="未约定违约金上限,可能导致赔偿范围过大",
        risk_level="H",
        suggestion="建议增加‘赔偿金额不超过合同总额的30%’",
    )
    store.add_final_finding(finding)
    doc = SpireWordDoc()
    doc.load(doc_path)
    res = store.export_findings_to_doc_comments(doc)
    print("Export doc comments:")
    print(json.dumps(res, ensure_ascii=False, indent=2))
def test_memory_and_export_excel():
# 简单示例:设置事实 -> 写入问题 -> 读取/搜索
store = MemoryStore()
store.set_facts([{
......@@ -295,3 +437,8 @@ if __name__ == "__main__":
print(json.dumps(asdict(f), ensure_ascii=False, indent=2))
print(store.export_to_excel())
if __name__ == "__main__":
# test_export_findings_to_doc_comments("/home/ccran/lufa-contract/tmp/股份转让协议.docx")
test_memory_and_export_excel()
from __future__ import annotations
from typing import Dict, List, Optional
from core.tool import ToolBase, tool, tool_func
@tool("reflect_retry", "反思重试质量闸(MVP)")
class ReflectRetryTool(ToolBase):
@tool_func(
{
"type": "object",
"properties": {
"segment_id": {"type": "string"},
"ruleset_id": {"type": "string"},
"context_memories": {"type": "array"},
"review_output": {"type": "object"},
},
"required": ["segment_id", "ruleset_id", "review_output"],
}
)
def run(self, segment_id: str, ruleset_id: str, context_memories: Optional[List[Dict]] = None, review_output: Dict = None) -> Dict:
problems: List[str] = []
actions: List[Dict] = []
findings = (review_output or {}).get("findings", [])
missings = (review_output or {}).get("missing_info", [])
for f in findings:
evs = f.get("evidence_quotes", []) or []
sugg = (f.get("suggestion") or "").strip()
if not evs:
problems.append("缺少对关键结论的原文引用")
if not sugg:
problems.append("建议条款措辞仍偏泛,未给可替换文本")
needs_evidence = any(m.get("type") == "evidence" for m in missings)
needs_knowledge = any(m.get("type") == "knowledge" for m in missings)
status = "pass"
if needs_evidence:
actions.append({"action": "FullRead", "target": segment_id, "need_focus": [m.get("need") for m in missings if m.get("type") == "evidence"]})
if needs_knowledge:
actions.append({"action": "RetrieveReference", "topic": [m.get("topic") for m in missings if m.get("type") == "knowledge"]})
if problems or actions:
status = "revise"
retry_budget_left = int((review_output or {}).get("retry_budget_left", 1))
return {
"segment_id": segment_id,
"status": status,
"problems": list(dict.fromkeys(problems)),
"revise_instructions": actions,
"retry_budget_left": max(0, retry_budget_left - (1 if status == "revise" else 0)),
}
import json
from typing import Dict, List, Optional, Any
from core.tool import tool, tool_func
from core.tools.segment_llm import LLMTool
REFLECT_SYSTEM_PROMPT = '''
你是一个合同审查反思智能体(ReviewReflection)。
你的任务是:基于 facts 与全文上下文,
对已有 findings 输出修改操作。
你只能对 findings 执行以下四种操作:
- keep:确认该风险结论成立且无需修改
- update:修改一个已有风险
- add:新增一个由全文结构推导出的风险
- remove:删除一个不成立的风险
【输出约束】
- 输出必须是 JSON
- 每条操作必须仅包含字段:op、id、findings
- findings 在 add / update 时必须是完整 finding
- 严格按照输出 JSON Schema 返回结果,不得输出任何解释性文字
'''
REFLECT_USER_PROMPT = '''
【输入】
【已有风险 findings】
{findings_json}
【合同事实记忆 facts】
{facts_json}
【合同立场】
站在 {party_role} 的立场进行反思审查。
【反思原则】
- 不得引入新的审查维度
- 所有判断必须有 facts 或合同原文证据支持
- 仅对已有 findings 进行增add、改update、删remove、保留keep操作
- 若风险在全文上下文中不成立,必须 remove
- 若风险成立但表述、严重性或建议不准确,必须 update
- 若由多个 findings 或全文结构推导出新的系统性风险,可 add
【输出要求】
- 输出必须是 JSON
- 每条操作必须仅包含字段:op、id、findings
- findings 在 add / update 时必须是完整 finding
- 严格按照输出 JSON Schema 返回结果,不得输出任何解释性文字
'''
OUTPUT_FORMAT_SCHEMA = '''
```json
{
"operations": [
{
"op": "keep | add | update | remove",
"id": "string | null",
"findings": "object | null"
}
]
}
```
'''
@tool("reflect_retry", "反思重试质量闸")
class ReflectRetryTool(LLMTool):
def __init__(self) -> None:
super().__init__(REFLECT_SYSTEM_PROMPT)
@tool_func(
{
"type": "object",
"properties": {
"party_role": {"type": "string"},
"rule": {"type": "object"},
"facts": {"type": "array"},
"findings": {"type": "array"},
},
"required": ["party_role", "rule", "facts", "findings"],
}
)
def run(self, party_role: str, rule: Dict, facts: Optional[List[Dict]] = None, findings: Optional[List[Dict]] = None) -> List[Dict]:
base_findings = self._build_findings_with_ids(findings or [])
user_content = REFLECT_USER_PROMPT.format(
findings_json=json.dumps(base_findings, ensure_ascii=False),
facts_json=json.dumps(facts or [], ensure_ascii=False),
party_role=party_role,
) + OUTPUT_FORMAT_SCHEMA
messages = self.build_messages(user_content)
try:
resp = self.run_with_loop(self.chat_async(messages))
data = self.parse_first_json(resp)
except Exception:
data = {}
operations = data.get("operations", []) or []
final_findings = self._apply_operations(base_findings, operations)
return final_findings
def _build_findings_with_ids(self, findings: List[Dict]) -> List[Dict[str, Any]]:
res: List[Dict[str, Any]] = []
for idx, f in enumerate(findings):
fid = f.get("id") or f.get("finding_id") or f.get("_id") or f"f_{idx+1}"
item = dict(f)
item["id"] = str(fid)
res.append(item)
return res
def _apply_operations(self, base_findings: List[Dict[str, Any]], operations: List[Dict]) -> List[Dict[str, Any]]:
by_id = {str(f.get("id")): dict(f) for f in base_findings if f.get("id") is not None}
added: List[Dict[str, Any]] = []
for op in operations:
action = (op.get("op") or "").strip().lower()
target_id = op.get("id")
payload = op.get("findings")
if action == "keep":
if target_id is not None and str(target_id) in by_id:
continue
elif action == "remove":
if target_id is not None:
by_id.pop(str(target_id), None)
elif action == "update":
if target_id is not None and isinstance(payload, dict):
payload = dict(payload)
payload["id"] = str(target_id)
by_id[str(target_id)] = payload
elif action == "add":
if isinstance(payload, dict):
added.append(dict(payload))
merged = list(by_id.values()) + added
return merged
if __name__ == "__main__":
tool = ReflectRetryTool()
res = tool.run(
party_role="甲方",
rule={"title":"付款与验收", "segment_id": 3},
facts=[
{"segment_id": 3, "支付": {"比例": "30%预付款", "期限": "验收后30日内"}},
{"segment_id": 3, "验收": {"标准": "双方书面确认", "时限": "7个工作日"}},
],
findings=[
{
"rule_title": "付款条款审查",
"segment_id": 3,
"original_text": "甲方在验收合格后30日内支付剩余价款。",
"issue": "未约定逾期付款违约金,约束不足",
"risk_level": "M",
"suggestion": "增加逾期付款违约金条款,如逾期每日按未付款的0.05%计收"
},
{
"rule_title": "验收条款审查",
"segment_id": 3,
"original_text": "乙方交付后甲方7个工作日内完成验收。",
"issue": "未明确逾期未验收的后果",
"risk_level": "L",
"suggestion": "补充逾期未验收视为通过或约定明确责任"
},
]
)
print(res)
\ No newline at end of file
......@@ -9,8 +9,8 @@ from utils.openai_util import OpenAITool
from core.config import LLM, MAX_WORKERS
class SegmentLLMBase(ToolBase):
"""LLM-backed segment processor: builds prompts, calls LLM, parses JSON."""
class LLMTool(ToolBase):
"""LLM-backed processor: builds prompts, calls LLM, parses JSON."""
def __init__(self, system_prompt: str, llm_key: str = "fastgpt_segment_review") -> None:
super().__init__()
......
......@@ -7,10 +7,10 @@ from typing import Dict, List, Optional
from core.tool import tool, tool_func
from utils.excel_util import ExcelUtil
from core.tools.segment_llm import SegmentLLMBase
from core.tools.segment_llm import LLMTool
import re
REVIEW_SYSTEM_PROMPT = f'''
REVIEW_SYSTEM_PROMPT = '''
你是一个专业的合同分段审查智能体(SegmentReview)。
你的核心任务是对“当前分段”进行【法律风险识别】,并给出可落地的修改建议。
......@@ -100,10 +100,10 @@ def _is_generic_suggestion(text: str) -> bool:
return any(g in t for g in GENERIC_SUGGESTIONS)
@tool("segment_review", "合同分段审查")
class SegmentReviewTool(SegmentLLMBase):
class SegmentReviewTool(LLMTool):
def __init__(self):
super().__init__(REVIEW_SYSTEM_PROMPT)
self.rule_version = "rule-v1"
self.rule_version = "通用"
self.column_map = {
"id": "ID",
"title": "审查项",
......@@ -278,7 +278,7 @@ if __name__=="__main__":
result = tool.run(
segment_id=1,
segment_text="本合同自双方签字盖章之日起生效,有效期为两年,期满后自动续展一年,除非一方提前30天书面通知对方终止。",
ruleset_id="rule-v1",
ruleset_id="通用",
party_role="甲方",
context_memories=[
{
......
......@@ -6,7 +6,7 @@ import json
from typing import Dict, List, Optional
from core.tool import tool, tool_func
from core.tools.segment_llm import SegmentLLMBase
from core.tools.segment_llm import LLMTool
FACT_DIMENSIONS: List[str] = ["当事人", "标的", "金额", "支付", "交付", "质量", "知识产权", "保密", "违约责任", "争议解决"]
......@@ -48,7 +48,7 @@ OUTPUT_EXAMPLE = '''
@tool("segment_summary", "分段事实提取")
class SegmentSummaryTool(SegmentLLMBase):
class SegmentSummaryTool(LLMTool):
def __init__(self) -> None:
super().__init__(SUMMARY_SYSTEM_PROMPT)
......
No preview for this file type
......@@ -6,13 +6,15 @@ from uuid import uuid4
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
import uvicorn
import traceback
from utils.common_util import extract_url_file, format_now
from utils.http_util import download_file
from core.cache import get_cached_doc_tool, get_cached_memory
from core.config import doc_support_formats
from core.config import doc_support_formats, pdf_support_formats
from core.tools.segment_summary import SegmentSummaryTool
from core.tools.segment_review import SegmentReviewTool
from core.tools.reflect_retry import ReflectRetryTool
from core.memory import RiskFinding
app = FastAPI(title="合同审查智能体", version="0.1.0")
......@@ -20,25 +22,31 @@ TMP_DIR = Path(__file__).resolve().parent / "tmp"
TMP_DIR.mkdir(parents=True, exist_ok=True)
summary_tool = SegmentSummaryTool()
review_tool = SegmentReviewTool()
reflect_tool = ReflectRetryTool()
########################################################################################################################
class DocumentParseRequest(BaseModel):
    # Conversation the uploaded document belongs to (used as the cache key).
    conversation_id: str
    urls: List[str] = Field(..., description="File download url")
    # Optional explicit extension; falls back to the URL filename's suffix.
    file_ext: Optional[str] = None
    # Ruleset used to list review items; defaults to the generic ruleset.
    ruleset_id: Optional[str] = "通用"
class DocumentParseResponse(BaseModel):
    """Response for /documents/parse.

    The duplicate ``text: str`` field (merge residue: the old required field
    next to the newer optional one) was removed — pydantic would otherwise
    silently keep only the later definition anyway.
    """
    conversation_id: str
    chunk_ids: List[int]
    ruleset_items: List[str]
    # Full document text is no longer returned by default.
    text: Optional[str] = None
    file_ext: Optional[str] = None
@app.post("/documents/parse", response_model=DocumentParseResponse)
def parse_document(payload: DocumentParseRequest) -> DocumentParseResponse:
async def parse_document(payload: DocumentParseRequest) -> DocumentParseResponse:
if not payload.urls:
raise HTTPException(status_code=400, detail="No URLs provided")
try:
filename = extract_url_file(payload.urls[0], doc_support_formats)
support_formats = list(dict.fromkeys(doc_support_formats + pdf_support_formats))
filename = extract_url_file(payload.urls[0], support_formats)
except Exception as exc:
raise HTTPException(status_code=400, detail=f"Failed to parse url: {exc}")
......@@ -47,13 +55,27 @@ def parse_document(payload: DocumentParseRequest) -> DocumentParseResponse:
download_file(payload.urls[0], file_path)
except Exception as exc:
raise HTTPException(status_code=500, detail=f"Download failed: {exc}")
doc_obj = get_cached_doc_tool(payload.conversation_id)
# get doc tool
file_ext = payload.file_ext or Path(filename).suffix
try:
doc_obj, _ = get_cached_doc_tool(payload.conversation_id, file_ext)
except Exception as exc:
raise HTTPException(status_code=400, detail=f"Document tool not available: {exc}")
doc_obj.load(file_path)
text = doc_obj.get_all_text()
# ocr
await doc_obj.get_from_ocr()
# text = doc_obj.get_all_text()
chunk_ids = doc_obj.get_chunk_id_list()
# get ruleset items
ruleset_id = payload.ruleset_id or review_tool.rule_version
ruleset_items = review_tool.rulesets.get(ruleset_id) or []
ruleset_review_items = [r.get('title') for r in ruleset_items]
return DocumentParseResponse(
conversation_id=payload.conversation_id, text=text,chunk_ids=chunk_ids
conversation_id=payload.conversation_id,
# text=text,
chunk_ids=chunk_ids,
ruleset_items=ruleset_review_items,
file_ext = file_ext
)
########################################################################################################################
......@@ -62,6 +84,7 @@ class SegmentSummaryRequest(BaseModel):
conversation_id: str
segment_id: int
party_role: Optional[str] = ""
file_ext: str
context_facts: Optional[Dict] = None
......@@ -75,7 +98,7 @@ class SegmentSummaryResponse(BaseModel):
def summarize_facts(payload: SegmentSummaryRequest) -> SegmentSummaryResponse:
store = get_cached_memory(payload.conversation_id)
try:
doc_obj = get_cached_doc_tool(payload.conversation_id)
doc_obj, _ = get_cached_doc_tool(payload.conversation_id, payload.file_ext)
except Exception as exc:
raise HTTPException(status_code=400, detail=f"Document tool not available: {exc}")
......@@ -105,7 +128,8 @@ class SegmentReviewRequest(BaseModel):
conversation_id: str
segment_id: int
party_role: Optional[str] = ""
ruleset_id: Optional[str] = "rule-v1"
ruleset_id: Optional[str] = "通用"
file_ext: str
context_memories: Optional[List[Dict]] = None
......@@ -119,7 +143,7 @@ class SegmentReviewResponse(BaseModel):
def review_segment(payload: SegmentReviewRequest) -> SegmentReviewResponse:
store = get_cached_memory(payload.conversation_id)
try:
doc_obj = get_cached_doc_tool(payload.conversation_id)
doc_obj, _ = get_cached_doc_tool(payload.conversation_id, payload.file_ext)
except Exception as exc:
raise HTTPException(status_code=400, detail=f"Document tool not available: {exc}")
......@@ -132,7 +156,7 @@ def review_segment(payload: SegmentReviewRequest) -> SegmentReviewResponse:
result = review_tool.run(
segment_id=payload.segment_id,
segment_text=segment_text,
ruleset_id=payload.ruleset_id or "rule-v1",
ruleset_id=payload.ruleset_id or "通用",
party_role=payload.party_role or "",
context_memories=payload.context_memories or store.get_facts(),
)
......@@ -159,6 +183,57 @@ def review_segment(payload: SegmentReviewRequest) -> SegmentReviewResponse:
)
########################################################################################################################
class ReflectReviewRequest(BaseModel):
    # Conversation whose memory store holds the findings to reflect on.
    conversation_id: str
    # Contract standpoint (e.g. 甲方/乙方) the reflection argues from.
    party_role: str
    ruleset_id: Optional[str] = "通用"
    # Title of the rule (within the ruleset) to reflect against.
    rule_title: str


class ReflectReviewResponse(BaseModel):
    conversation_id: str
    rule_title: str
    # Final findings (plain dicts) after keep/update/add/remove reflection.
    findings: List[Dict]
@app.post("/segments/review/reflect", response_model=ReflectReviewResponse)
def reflect_review(payload: ReflectReviewRequest) -> ReflectReviewResponse:
store = get_cached_memory(payload.conversation_id)
ruleset_id = payload.ruleset_id or review_tool.rule_version
ruleset_items = review_tool.rulesets.get(ruleset_id) or []
rule = next((r for r in ruleset_items if r.get("title") == payload.rule_title), None)
if not rule:
raise HTTPException(status_code=404, detail=f"Rule not found: {payload.rule_title}")
facts = store.get_facts()
findings = [f.__dict__ for f in store.list_findings()]
final_findings = reflect_tool.run(
party_role=payload.party_role,
rule=rule,
facts=facts,
findings=findings,
)
for f in final_findings or []:
try:
store.add_final_finding_from_dict({
"rule_title": f.get("rule_title", ""),
"segment_id": f.get("segment_id", 0),
"original_text": f.get("original_text", ""),
"issue": f.get("issue", ""),
"risk_level": (f.get("risk_level") or f.get("level") or "").upper(),
"suggestion": f.get("suggestion", ""),
})
except Exception:
continue
return ReflectReviewResponse(
conversation_id=payload.conversation_id,
rule_title=payload.rule_title,
findings=final_findings or [],
)
########################################################################################################################
class ConversationResponse(BaseModel):
conversation_id: str
created_at: str
......@@ -173,44 +248,41 @@ def new_conversation() -> ConversationResponse:
class MemoryExportRequest(BaseModel):
conversation_id: str
file_ext: str
file_name: Optional[str] = None
class MemoryExportResponse(BaseModel):
conversation_id: str
url: str
data: Dict
excel_url: str
doc_url: str
@app.post("/memory/export", response_model=MemoryExportResponse)
def export_memory(payload: MemoryExportRequest) -> MemoryExportResponse:
store = get_cached_memory(payload.conversation_id)
try:
res = store.export_to_excel(file_name=payload.file_name)
doc_obj, _ = get_cached_doc_tool(payload.conversation_id, payload.file_ext)
except Exception as exc:
raise HTTPException(status_code=400, detail=f"Document tool not available: {exc}")
try:
excel_res = store.export_to_excel(file_name=payload.file_name)
except ImportError as exc:
raise HTTPException(status_code=500, detail=str(exc))
except Exception as exc:
raise HTTPException(status_code=500, detail=f"Export failed: {exc}")
url = ""
if isinstance(res, str):
url = res
elif isinstance(res, dict):
for key in [
"url",
"file_url",
"fileUrl",
"link",
"downloadUrl",
"path",
"filePath",
]:
val = res.get(key)
if isinstance(val, str) and val:
url = val
break
return MemoryExportResponse(conversation_id=payload.conversation_id, url=url, data=res if isinstance(res, dict) else {"result": res})
try:
doc_res = store.export_findings_to_doc_comments(doc_obj)
except Exception as exc:
traceback.print_exc()
raise HTTPException(status_code=500, detail=f"Export doc comments failed: {exc}")
return MemoryExportResponse(
conversation_id=payload.conversation_id,
excel_url=excel_res,
doc_url=doc_res,
)
if __name__ == "__main__":
uvicorn.run(
......
import json
import json_repair
import random
import re
from datetime import datetime
from typing import Dict, List
from loguru import logger
from core.config import max_chunk_page, min_single_chunk_size, max_single_chunk_size
......@@ -51,14 +50,14 @@ def extract_json(json_str:str) -> List[Dict]:
# 清理控制字符
s = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', s)
try:
obj = json.loads(s, strict=False)
obj = json_repair.loads(s, strict=False)
if isinstance(obj, list):
out_list.extend(obj)
else:
out_list.append(obj)
return True
except Exception as e:
logger.error(f"JSON解析失败: {e}")
logger.error(f"JSON解析失败: {e} | 内容片段: {s}")
return False
results = []
......
......@@ -53,7 +53,7 @@ def fastgpt_openai_chat(url, token, model, chat_id, file_url, text, stream=True)
return rsp
def upload_file(path, input_url_to_inner=True, output_url_to_inner=False):
def upload_file(path, input_url_to_inner=True, output_url_to_inner=False) -> str:
# 登录获取token
login_data = {"username": "admin", "password": "admin@jpai.com"}
login_url = f"{base_backend_url}/admin-api/system/auth/login"
......@@ -109,5 +109,5 @@ def url_replace_fastgpt(origin: str):
if __name__ == "__main__":
# d = '/home/ccran/file.docx'
d = "file.docx"
print(os.path.basename(d))
d = "/home/ccran/lufa-contract/tmp/default.json"
print(upload_file(d))
import fitz
from urllib import parse
from urllib.parse import urlparse
import os
import asyncio
import aiohttp
from aiohttp import ClientSession
from utils.http_util import url_replace_fastgpt, download_file
from utils.common_util import random_str
from loguru import logger
import json
ocr_url = 'http://192.168.252.71:8202/openapi/ocrUploadFile'
class OCRUtil:
    """OCR helper: rasterizes PDFs to page images, OCRs them concurrently,
    and reassembles the recognized text in page order."""

    def __init__(self):
        pass

    async def ocr_requests_async(self, session, file_path):
        """POST one image file to the OCR service; return (response_text, file_path).

        Uses multipart/form-data. The file handle is now closed via a context
        manager — the previous revision leaked it (``open(...)`` inline).
        """
        with open(file_path, 'rb') as fh:
            async with session.post(ocr_url, data={'file': fh}) as response:
                rsp = await response.text()
        return rsp, file_path

    async def ocr_image_async(self, path_list):
        """OCR every image in *path_list*; return the texts sorted by page number."""
        timeout = aiohttp.ClientTimeout(total=600)
        # limit=1 serializes the requests despite gather(), throttling the OCR service.
        connector = aiohttp.TCPConnector(limit=1)
        async with ClientSession(connector=connector, timeout=timeout) as session:
            tasks = [self.ocr_requests_async(session, file_path) for file_path in path_list]
            responses = await asyncio.gather(*tasks)
            res_dict = {}
            for response_text, file_path in responses:
                rsp_json = json.loads(response_text)
                if 'data' not in rsp_json:
                    logger.error(f'ocr_image_async {file_path} error:{rsp_json["msg"]}')
                    continue
                # Map page number -> recognized text so pages can be reordered.
                page_num = int(self.get_pdf_2_img_page_num(file_path))
                res_dict[page_num] = rsp_json['data']['strRes']
            logger.info(f'ocr_image_async finish. all pages:{len(res_dict)}')
            # Sort by page number before returning.
            return [res_dict[key] for key in sorted(res_dict)]

    def set_pdf_2_img_page(self, path, page_idx):
        """Image path for 0-based *page_idx* of *path* (file names are 1-based)."""
        return f'{path}_{page_idx + 1}.png'

    def get_pdf_2_img_page_num(self, path):
        """Inverse of set_pdf_2_img_page: 'xx_10.png' -> '10'."""
        split_path = path.split('_')
        return split_path[-1][:-4]

    def pdf_2_img(self, path, zoom_x=1, zoom_y=1):
        """Render each page of the PDF at *path* to a PNG; return the image paths."""
        pdf = fitz.open(path)
        pdf_list = []
        try:
            for pg in range(pdf.page_count):
                page = pdf[pg]
                # Scale matrix (no rotation).
                trans = fitz.Matrix(zoom_x, zoom_y)
                pm = page.get_pixmap(matrix=trans, alpha=False)
                dest_png = self.set_pdf_2_img_page(path, pg)
                pm.save(dest_png)
                pdf_list.append(dest_png)
        finally:
            # Always release the document handle, even on a render failure.
            pdf.close()
        return pdf_list

    def ocr_download_path(self, url):
        """Download the file behind *url* into the ocr/ work dir; return its path."""
        logger.info(f'ocr url:{url}')
        # Rewrite the URL for the internal fastgpt host.
        url = url_replace_fastgpt(url)
        url_parsed = urlparse(url)
        query_dict = parse.parse_qs(url_parsed.query)
        if 'filename' in query_dict:
            filename = query_dict.get('filename')[0]
        else:
            filename = f'{random_str()}.pdf'
        # BUG FIX: the destination previously hard-coded a garbled literal and
        # ignored the filename computed above.
        dest_path = f'ocr/{filename}'
        download_file(url, dest_path)
        return dest_path

    async def ocr_result_pdf(self, dest_path):
        """OCR a whole PDF: rasterize, OCR the pages, always clean up temp images."""
        pdf_list = self.pdf_2_img(dest_path)
        try:
            result = await self.ocr_image_async(pdf_list)
        finally:
            # Remove the rendered page images even if OCR fails.
            for pdf in pdf_list:
                if os.path.exists(pdf):
                    os.remove(pdf)
        return result
if __name__ == '__main__':
ocr_util = OCRUtil()
result = asyncio.run(ocr_util.ocr_result_pdf('/home/ccran/lufa-contract/tmp/财03 2023年国科京东方评估合同.pdf'))
print(f'len(result):{len(result)}')
print(result)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment