Commit 53a63f97 by ccran

Initial commit

parents
tmp/
\ No newline at end of file
{
// 使用 IntelliSense 了解相关属性。
// 悬停以查看现有属性的描述。
// 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python 调试程序: 当前文件",
"type": "debugpy",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
"env": {
"PYTHONPATH": "${workspaceFolder}"
}
}
]
}
\ No newline at end of file
from utils.spire_word_util import SpireWordDoc
from utils.doc_util import DocBase
from functools import lru_cache
from typing import Optional
from core.memory import MemoryStore
MAX_CACHE = 128
@lru_cache(maxsize=MAX_CACHE)
def get_cached_doc_tool(conversation_id: str) -> Optional[DocBase]:
    """Return a per-conversation Word-document tool, cached by conversation_id.

    The argument is used only as the LRU cache key; each distinct
    conversation_id gets (and keeps, until eviction) its own SpireWordDoc
    instance. At most MAX_CACHE instances are kept alive.
    """
    return SpireWordDoc()
@lru_cache(maxsize=MAX_CACHE)
def get_cached_memory(conversation_id: str) -> MemoryStore:
    """Return the per-conversation MemoryStore, cached by conversation_id.

    Backed by tmp/memory_store_<conversation_id>.json. The LRU cache keeps at
    most MAX_CACHE stores alive; an evicted conversation gets a fresh store
    that re-hydrates itself from its JSON file on construction.
    """
    return MemoryStore(f'memory_store_{conversation_id}.json')
\ No newline at end of file
import platform
from dataclasses import dataclass
# Configurable runtime flags
use_docker = False  # when True, root_path is forced to the container path /app
just_b_class = False  # restrict the knowledge base to the "B类" sheet
is_extract = True  # use the extraction-review sheet/port instead of the default
use_original_text_verification = False

# Previous (intranet) endpoint configuration, kept for reference:
# @dataclass
# class LLMConfig:
#     base_url: str = "http://172.21.107.45:9002/v1"
#     api_key: str = "none"
#     model: str = "Qwen3-32B"
#
#
# base_fastgpt_url = "http://172.21.107.45:3030"
# base_backend_url = "http://172.21.107.45:48080"
# ocr_url = 'http://172.21.107.45:8202/openapi/ocrUploadFile'


@dataclass
class LLMConfig:
    """Connection settings for an OpenAI-compatible LLM endpoint."""
    base_url: str = "http://192.168.252.71:9002/v1"
    api_key: str = "none"
    model: str = 'Qwen2-72B-Instruct'


# Service endpoints
outer_backend_url = "http://218.77.58.8:48081"
base_fastgpt_url = "http://192.168.252.71:18089"
base_backend_url = "http://192.168.252.71:48081"
ocr_url = 'http://192.168.252.71:8202/openapi/ocrUploadFile'

# Project root directory (platform-specific defaults)
root_path = r"E:\PycharmProject\contract_review_agent"
system = platform.system()
if system == "Linux":
    # root_path = "/data/home/ccran/contract_review_agent"
    root_path = '/home/ccran/contract_review_agent'
elif system == "Darwin":
    root_path = "/Users/chenran/PycharmProjects/contract_review_agent"
# Docker settings: inside the container the project lives at /app
if use_docker:
    root_path = '/app'

MAX_WORKERS = 20  # worker count passed to OpenAITool for concurrent LLM calls

# Named LLM endpoint configurations, looked up by key elsewhere
LLM = {
    "base_tool_llm": LLMConfig(),
    "fastgpt_segment_review": LLMConfig(
        base_url=f"{base_fastgpt_url}/api/v1",
        api_key="fastgpt-zMavJKKgqA9jRNHLXxzXCVZx1JXxfuNkH1p2qfLhtPfMp41UvdSQvt8",
    )
}

# Supported input formats
doc_support_formats = [".docx", ".doc", ".wps"]
# NOTE(review): name says "pdf" but the list also covers plain-text formats — confirm
pdf_support_formats = [".txt", ".md", ".pdf"]

# Excel column headers: internal field name -> Chinese column title
field_mapping = {
    "review": {
        "id": "序号",
        "key_points": "审查内容",
        "original_text": "合同原文",
        "details": "审查过程",
        "result": "审查结果",
        "suggest": "审查建议",
        "chunk_id": "文件块序号",
        "chunk_info": "文件块信息",
    },
    "category": {"text": "原文", "judge": "判断依据", "chunk_location": "原文位置"},
}
excel_widths = {"review": [5, 20, 80, 80, 20, 80, 5, 60], "category": [80, 80, 30]}
max_review_group = 5

# Sales-category (domestic vs. export) judgement
ocr_thresholds = 1000  # NOTE(review): presumably a character-count threshold for OCR — confirm
all_rule_sheet = ["内销或出口", "内销", "出口", "反思"]
reflection_sheet = "反思"

# Chunking limits
min_single_chunk_size = 2000
max_single_chunk_size = 20000
max_chunk_page = 10

# Knowledge-base sheet selection (also picks the service port)
if just_b_class:
    know_sheet_name = "B类"
    port = 9008
else:
    if is_extract:
        know_sheet_name = "提取审查"
        port = 9016
    else:
        know_sheet_name = 0  # 0 = first sheet
        port = 9006
# know_sheet_name = '发票审查'
reload = False  # NOTE(review): presumably a server auto-reload flag — confirm
"""
{
"segment_id": "S12",
"topics": ["付款"],
"summary": {
"one_liner": "...",
"structured": ["...", "...", "..."]
},
"risks": [
{"flag":"...", "level":"M", "evidence":"...", "suggestion":"..."}
],
"dependencies": ["S11"],
"open_questions": ["..."],
"evidence_quotes": ["..."]
}
segment_id
含义:该记忆对应的合同分段唯一标识。
注意:用于前后段上下文关联,必须与分段阶段生成的 ID 一致。
topics
含义:本分段涉及的合同主题标签。
注意:从固定枚举中选择(如付款/违约/保密),用于后续记忆检索与一致性校验。
summary
summary.one_liner
含义:一句话概括本段条款的核心约定内容。
注意:只描述事实,不做风险判断。
summary.structured
含义:本段条款的结构化要点列表(3–5 条)。
注意:每条对应一个业务点,便于检索和上下文理解。
risks
含义:本分段确认存在的风险点列表。
字段说明:
flag:风险标签(简短描述问题)
level:风险等级(H / M / L)
evidence:支撑该风险判断的原文依据
suggestion:可直接落地的修改或谈判建议
注意:每条风险必须有明确证据和可执行建议。
dependencies
含义:本分段理解或适用所依赖的其他条款或附件。
注意:用于后续自动回读相关条款,防止断章取义。
open_questions
含义:本分段中未明确、需对方补充或澄清的信息。
注意:不确定内容不要写成风险结论,统一放在这里。
evidence_quotes
含义:从合同原文中摘录的关键短句,用于支撑摘要和风险判断。
注意:用于审查复核与反思重试,避免“凭空总结”。
"""
from __future__ import annotations
import json
import logging
from dataclasses import asdict, dataclass, field
from datetime import datetime
from pathlib import Path
from threading import RLock
from typing import Any, Dict, Iterable, List, Optional
from utils.http_util import upload_file
logger = logging.getLogger(__name__)
_ALLOWED_RISK_LEVELS = {"H", "M", "L"}
@dataclass
class RiskFinding:
    """One contract-review risk finding tied to a segment."""

    rule_title: str
    segment_id: int
    original_text: str
    issue: str
    risk_level: str
    suggestion: str

    def __post_init__(self) -> None:
        # Normalize the level to upper case, then validate it.
        normalized = (self.risk_level or "").upper()
        if normalized not in _ALLOWED_RISK_LEVELS:
            raise ValueError(f"risk_level must be one of {_ALLOWED_RISK_LEVELS}, got {self.risk_level}")
        self.risk_level = normalized

    @classmethod
    def from_dict(cls, data: Dict) -> "RiskFinding":
        """Build a finding from a raw dict, coercing types and filling defaults."""
        source = data or {}
        return cls(
            rule_title=str(source.get("rule_title", "")),
            segment_id=int(source.get("segment_id", 0) or 0),
            original_text=str(source.get("original_text", "")),
            issue=str(source.get("issue", "")),
            risk_level=str(source.get("risk_level", "")),
            suggestion=str(source.get("suggestion", "")),
        )
@dataclass
class MemoryStore:
    """Simplified memory store for contract facts and risk findings.

    Thread-safe via an RLock, persisted as JSON under <project>/tmp/.
    NOTE(review): the explicit __init__ below suppresses the dataclass-
    generated one, so the @dataclass decorator and the class-level
    `storage_name` field are effectively inert — confirm whether they
    should be removed.
    """
    # Annotated Optional[Path] but assigned a str; unused by the explicit __init__.
    storage_name: Optional[Path] = 'default.json'

    def __init__(self,storage_name:str = 'default.json') -> None:
        # Resolve <project_root>/tmp/<storage_name> relative to this file.
        self._storage_path = Path(__file__).resolve().parent.parent / "tmp" / storage_name  # type: ignore[arg-type]
        self._storage_path.parent.mkdir(parents=True, exist_ok=True)
        self._lock = RLock()  # reentrant so public methods can nest safely
        self.facts: List[Dict[str, Any]] = []
        self.findings: List[RiskFinding] = []
        self._load()  # hydrate from an existing JSON file, if any

    # ---------------------- facts ----------------------
    def set_facts(self, facts: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Replace the whole facts list (None becomes []) and persist."""
        with self._lock:
            self.facts = facts or []
            self._persist()
            return self.facts

    def add_facts(self, partial: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Append one fact dict, persist, and return the full facts list."""
        with self._lock:
            self.facts.append(partial)
            self._persist()
            return self.facts

    def get_facts(self) -> List[Dict[str, Any]]:
        """Return the facts list.

        NOTE(review): despite the previous "deep copy" comment, this returns
        the internal list itself — callers can mutate shared state.
        """
        with self._lock:
            return self.facts

    # -------------------- findings ---------------------
    def add_finding(self, finding: RiskFinding) -> RiskFinding:
        """Append one finding, persist, and return it unchanged."""
        with self._lock:
            self.findings.append(finding)
            self._persist()
            return finding

    def add_finding_from_dict(self, data: Dict) -> RiskFinding:
        """Convert a raw dict via RiskFinding.from_dict, then store it."""
        return self.add_finding(RiskFinding.from_dict(data))

    def list_findings(self) -> List[RiskFinding]:
        """Return a shallow copy of all findings."""
        with self._lock:
            return list(self.findings)

    def get_findings_by_segment(self, segment_id: int) -> List[RiskFinding]:
        """Return findings whose segment_id matches exactly."""
        with self._lock:
            return [f for f in self.findings if f.segment_id == segment_id]

    def delete_findings_by_segment(self, segment_id: int) -> int:
        """Remove all findings for one segment; return how many were removed."""
        with self._lock:
            before = len(self.findings)
            self.findings = [f for f in self.findings if f.segment_id != segment_id]
            removed = before - len(self.findings)
            if removed:  # skip the disk write when nothing changed
                self._persist()
            return removed

    def search_findings(self, keyword: str, rule_title: Optional[str] = None, risk_level: Optional[str] = None) -> List[RiskFinding]:
        """Filter findings by exact rule_title / risk_level, then by keyword.

        The keyword is matched case-insensitively against the concatenation of
        rule_title, original_text, issue and suggestion. An empty keyword
        returns all candidates that passed the field filters.
        """
        key = (keyword or "").strip().lower()
        with self._lock:
            candidates = list(self.findings)
        if rule_title:
            candidates = [f for f in candidates if (f.rule_title or "").lower() == rule_title.strip().lower()]
        if risk_level:
            lvl = risk_level.strip().upper()
            candidates = [f for f in candidates if f.risk_level == lvl]
        if not key:
            return candidates

        def _matches(f: RiskFinding) -> bool:
            # Substring search over the searchable text fields.
            hay = " ".join([
                f.rule_title,
                f.original_text,
                f.issue,
                f.suggestion,
            ]).lower()
            return key in hay

        return [f for f in candidates if _matches(f)]

    # ------------------- housekeeping ------------------
    def clear(self) -> None:
        """Drop all facts and findings and persist the empty state."""
        with self._lock:
            self.facts.clear()
            self.findings.clear()
            self._persist()

    def _persist(self) -> None:
        """Best-effort write of the current state to the JSON file (never raises)."""
        payload = {
            "facts": self.facts,
            "findings": [asdict(f) for f in self.findings],
        }
        try:
            self._storage_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
        except Exception as exc:
            logger.error("Failed to persist memory store: %s", exc)

    def _load(self) -> None:
        """Best-effort load from the JSON file; leaves state empty on error."""
        try:
            if not self._storage_path.exists():
                return
            raw = self._storage_path.read_text(encoding="utf-8")
            data = json.loads(raw or "{}")
            if isinstance(data, dict):
                self.facts = data.get("facts") or []
                self.findings = [RiskFinding.from_dict(item) for item in data.get("findings", []) or []]
        except Exception as exc:
            logger.error("Failed to load memory store: %s", exc)

    def export_to_excel(self, file_name: Optional[str] = None) -> Dict[str, Any]:
        """Export findings and facts to Excel, upload, then delete the local file.

        Returns the upload_file() response.
        Raises ImportError when openpyxl is not installed.
        """
        try:
            from openpyxl import Workbook  # type: ignore
        except ImportError as exc:
            raise ImportError("openpyxl is required for export_to_excel; install via 'pip install openpyxl'") from exc
        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        name = file_name or f"memory_export_{ts}.xlsx"
        output_path = Path(__file__).resolve().parent.parent / "tmp" / name
        with self._lock:
            wb = Workbook()
            ws_findings = wb.active
            ws_findings.title = "findings"
            # (attribute name, Chinese column header) pairs for the findings sheet
            finding_headers = [
                ("rule_title", "规则标题"),
                ("segment_id", "分段ID"),
                ("original_text", "原文"),
                ("issue", "问题描述"),
                ("risk_level", "风险等级"),
                ("suggestion", "建议"),
            ]
            ws_findings.append([label for _, label in finding_headers])
            for f in self.findings:
                ws_findings.append([
                    getattr(f, key, "") for key, _ in finding_headers
                ])
            ws_facts = wb.create_sheet("facts")
            if self.facts:
                # Union of keys across all fact dicts; put "段落" (segment) first.
                fact_keys: List[str] = sorted({k for item in self.facts for k in item.keys()})
                if "段落" in fact_keys:
                    fact_keys = ["段落"] + [k for k in fact_keys if k != "段落"]
                ws_facts.append(fact_keys)
                for item in self.facts:
                    row = []
                    for key in fact_keys:
                        value = item.get(key)
                        # Nested structures are serialized as compact JSON strings.
                        if isinstance(value, (dict, list)):
                            row.append(json.dumps(value, ensure_ascii=False))
                        else:
                            row.append(value)
                    ws_facts.append(row)
            else:
                ws_facts.append(["data"])
            wb.save(output_path)
        try:
            res = upload_file(str(output_path))
        finally:
            # Always remove the temp workbook, even when the upload fails.
            try:
                output_path.unlink()
            except Exception:
                logger.warning("Failed to delete temp excel: %s", output_path)
        return res
if __name__ == "__main__":
    # Demo: set facts -> record a finding -> read/search -> export to Excel
    store = MemoryStore()
    store.set_facts([{
        "公司": {"甲方": "A 公司", "乙方": "B 公司"},
        "支付": [
            {"方式": "银行转账", "期限": "验收后30日内"}
        ],
        "段落":1
    },{
        "纠纷": {"解决方式": "诉讼", "地址": "原告方所在地法院"},
        "段落":2
    }])
    finding = RiskFinding(
        rule_title="违约责任",
        segment_id=1,
        original_text="违约方应赔偿全部损失",
        issue="未约定违约金上限,可能导致赔偿范围过大",
        risk_level="H",
        suggestion="建议增加‘赔偿金额不超过合同总额的30%’",
    )
    store.add_finding(finding)
    print("Facts:\n" + json.dumps(store.get_facts(), ensure_ascii=False, indent=2))
    # Search within findings already filtered to the "违约责任" rule.
    hits = store.search_findings("赔偿", rule_title="违约责任")
    print("Findings search:")
    for f in hits:
        print(json.dumps(asdict(f), ensure_ascii=False, indent=2))
    print(store.export_to_excel())
from abc import ABC
from functools import wraps
import inspect
# Class decorator that registers a class as a tool
def tool(name, description):
    """Mark a class as a tool by attaching _name, _description and _is_tool."""
    def mark(cls):
        for attr, value in (("_name", name), ("_description", description), ("_is_tool", True)):
            setattr(cls, attr, value)
        return cls
    return mark
# Method decorator that attaches an LLM parameter schema to a tool method
def tool_func(param_description):
    """Attach `param_description` to the decorated function as _param_description.

    The previous implementation wrapped the function in a pass-through
    wrapper that only forwarded the call; the wrapper added no behavior,
    so the function is now returned directly (same external contract:
    callers get a callable carrying _param_description).
    """
    def decorator(func):
        func._param_description = param_description
        return func
    return decorator
# Base class for all tools
class ToolBase(ABC):
    """Common tool base: execute() dispatches to the @tool_func-marked method."""

    _name = 'tool_base'
    _description = 'tool_description'
    _is_tool = False

    def execute(self, *args, **kwargs):
        """Invoke the first method carrying a _param_description attribute."""
        for _, method in inspect.getmembers(self, predicate=inspect.ismethod):
            # Unwrap the bound method so the marker attribute is visible.
            target = getattr(method, '__func__', method)
            if hasattr(target, '_param_description'):
                return target(self, *args, **kwargs)
        raise Exception('没有找到@tool_func装饰的方法')
from core.tool import ToolBase, tool, tool_func
# Addition tool (minimal demo of the @tool / @tool_func machinery)
@tool('add', 'add tool')
class AddTool(ToolBase):
    """Example tool: adds two numbers via ToolBase.execute() dispatch."""
    @tool_func({
        "type": "object",
        "properties": {
            "a": {
                # NOTE(review): "int" is not a standard JSON Schema type
                # ("integer" is) — confirm what consumes these schemas.
                "type": "int",
                "description": "operator 1"
            },
            "b": {
                "type": "int",
                "description": "operator 2"
            }
        },
        "required": [
            "a",
            "b"
        ]
    })
    def add(self, a, b):
        return a + b
if __name__ == '__main__':
    # NOTE: rebinding `tool` here shadows the imported @tool decorator.
    tool = AddTool()
    print(AddTool._name)
    print(tool.execute(1, 2))
\ No newline at end of file
from __future__ import annotations
from typing import Dict, List, Optional
from core.tool import ToolBase, tool, tool_func
from tools.shared import store, limit_list
@tool("full_read", "分段原文读取/摘录")
class FullReadTool(ToolBase):
    """Read back a segment's stored text and keyword-matched snippets.

    NOTE(review): this relies on `store.get(segment_id)` returning an object
    with `.evidence_quotes` and `.summary.structured` — that matches the older
    memory schema, not the simplified core.memory.MemoryStore (which exposes
    facts/findings only). Confirm what tools.shared.store actually provides.
    """
    @tool_func(
        {
            "type": "object",
            "properties": {
                "target": {"type": "object"},
                "need_focus": {"type": "array"},
                "return_snippets": {"type": "boolean"},
            },
            "required": ["target"],
        }
    )
    def run(self, target: Dict, need_focus: Optional[List[str]] = None, return_snippets: bool = True) -> Dict:
        """Return the segment's full text plus snippets containing focus keywords."""
        # Accept either {"id": ...} or {"segment_id": ...} as the target shape.
        seg_id = (target or {}).get("id") or (target or {}).get("segment_id") or ""
        mem = store.get(seg_id) if seg_id else None
        full_text = " ".join(mem.evidence_quotes) if mem and mem.evidence_quotes else ""
        if not full_text:
            # Placeholder message returned when no original text was stored.
            full_text = "(原文未存储,建议在调用端提供全文或将原文写入 MemoryStore)"
        focus = need_focus or []
        snippets: List[str] = []
        if return_snippets and mem:
            # Prefer verbatim evidence quotes; fall back to structured summary lines.
            for kw in focus:
                for q in mem.evidence_quotes:
                    if kw and kw in q:
                        snippets.append(q)
            if not snippets:
                for kw in focus:
                    for s in mem.summary.structured:
                        if kw and kw in s:
                            snippets.append(s)
        # Deduplicate preserving order, cap at 5 snippets.
        snippets = limit_list(list(dict.fromkeys(snippets)), 5)
        return {
            "target": {"type": "segment", "id": seg_id},
            "full_text": full_text,
            "highlighted_snippets": snippets,
        }
from __future__ import annotations
import re
from typing import Dict, List, Optional,Tuple
from core.tool import ToolBase, tool, tool_func
from tools.shared import store
from core.memory import MemoryStore
# Topic -> trigger keywords used for rule-based topic derivation.
TOPIC_KEYWORDS = {
    "付款": ["付款", "支付", "价款", "结算", "发票", "对账", "款项", "账期"],
    "验收": ["验收", "确认", "签收", "测试", "交付确认", "验收单", "验收报告"],
    "交付": ["交付", "交货", "交付物", "交付时间", "里程碑", "上线", "交接"],
    "违约": ["违约", "违约金", "赔偿", "损失", "承担责任", "罚金", "扣款"],
    "解除": ["解除", "终止", "解约", "撤销"],
    "责任限制": ["责任上限", "间接损失", "免责", "赔偿上限", "责任限制", "不可抗力"],
    "保密": ["保密", "保密信息", "披露", "泄露", "保密期限", "例外"],
    "知识产权": ["知识产权", "著作权", "专利", "商标", "成果", "源代码", "许可"],
    "争议解决": ["争议", "仲裁", "诉讼", "管辖", "适用法律", "法院"],
}
# Default number of memories returned by MemoryReadTool.run
DEFAULT_TOP_K = 6
@tool("memory_read", "分段记忆读取")
class MemoryReadTool(ToolBase):
    """Retrieve context memories relevant to the current contract segment.

    Candidates come from `store.list()`, are filtered by segment ids and
    rule-derived topics, then ranked by keyword overlap with `query`.
    NOTE(review): the store is expected to expose `.list()` of records with
    `.topics` / `.summary` / `.risks` / `.to_dict()` (the older memory
    schema), not the simplified core.memory.MemoryStore — confirm which
    store type callers pass in.
    """

    def parse_neighbor_ids(self, segment_id: int) -> List[int]:
        """Return up to two preceding segment ids (segment_id-1, segment_id-2)."""
        neighbor_ids: List[int] = []
        if segment_id >= 1:
            neighbor_ids.append(segment_id - 1)
        if segment_id >= 2:
            neighbor_ids.append(segment_id - 2)
        return neighbor_ids

    def normalize_text(self, s: str) -> str:
        """Collapse whitespace runs to single spaces and strip the ends."""
        return re.sub(r"\s+", " ", s.strip())

    def derive_topics_rule_based(self, segment_text: str, max_topics: int = 3) -> List[str]:
        """Derive up to `max_topics` topics by counting keyword hits per topic."""
        text = self.normalize_text(segment_text)
        scores: List[Tuple[str, int]] = []
        for topic, kws in TOPIC_KEYWORDS.items():
            hit = sum(1 for kw in kws if kw in text)
            if hit > 0:
                scores.append((topic, hit))
        scores.sort(key=lambda x: x[1], reverse=True)
        return [t for t, _ in scores[:max_topics]]

    # BUG FIX: this schema previously decorated parse_neighbor_ids, so
    # ToolBase.execute() dispatched to the wrong method — it describes run().
    @tool_func(
        {
            "type": "object",
            "properties": {
                "segment_ids": {"type": "array"},
                "topics": {"type": "array"},
                "query": {"type": "string"},
                "top_k": {"type": "int"},
                "mode": {"type": "string"},
            },
            "required": ["top_k"],
        }
    )
    def run(self, store: MemoryStore, segment_id: int, segment_text: str,
            top_k: int = DEFAULT_TOP_K, segment_ids: Optional[List[int]] = None,
            query: str = "") -> Dict:
        """Rank stored memories for a segment and return the top_k records.

        BUG FIX: `segment_ids` and `query` were referenced in the body but
        never defined (NameError at runtime). They are now optional
        parameters; `segment_ids` falls back to the neighbor ids, which were
        previously computed but never used.
        """
        neighbor_ids = self.parse_neighbor_ids(segment_id)
        topics = self.derive_topics_rule_based(segment_text)
        if not topics:
            topics = ["其他"]
        candidates = store.list()
        # Default to the two preceding segments when no explicit ids given.
        ids = set(segment_ids if segment_ids is not None else neighbor_ids)
        if ids:
            candidates = [c for c in candidates if c.segment_id in ids]
        if topics:
            candidates = [c for c in candidates if set(c.topics).intersection(set(topics))]
        key = (query or "").strip().lower()
        scored: List[Dict] = []
        for c in candidates:
            # Searchable haystack built from every textual field of the memory.
            hay = " ".join([
                c.summary.one_liner,
                " ".join(c.summary.structured),
                " ".join(c.topics),
                " ".join(c.dependencies),
                " ".join(c.evidence_quotes),
                " ".join([r.flag + " " + r.evidence + " " + r.suggestion for r in c.risks]),
            ]).lower()
            score = 0.0
            if key:
                # Fraction of query tokens found in the haystack.
                tokens = [t for t in re.split(r"\s+", key) if t]
                score = sum(1.0 for t in tokens if t in hay) / max(1, len(tokens))
            rec = c.to_dict()
            rec["relevance"] = round(score, 2)
            rec["reason"] = "关键词匹配" if key else "主题/ID过滤"
            scored.append(rec)
        scored.sort(key=lambda x: x.get("relevance", 0.0), reverse=True)
        return {"memories": scored[:top_k]}
from __future__ import annotations

from dataclasses import asdict
from typing import Dict, List, Optional

from core.memory import RiskFinding
from core.tool import ToolBase, tool, tool_func
from tools.shared import store
@tool("memory_write", "分段记忆写入")
class MemoryWriteTool(ToolBase):
    """Convert a segment review output into RiskFinding records and store them."""

    @tool_func(
        {
            "type": "object",
            "properties": {
                "segment_id": {"type": "string"},
                "topics": {"type": "array"},
                "review_output": {"type": "object"},
                "dependencies": {"type": "array"},
                "write_policy": {"type": "object"},
            },
            "required": ["segment_id", "topics", "review_output"],
        }
    )
    def run(self, segment_id: str, topics: List[str], review_output: Dict, dependencies: Optional[List[str]] = None, write_policy: Optional[Dict] = None) -> Dict:
        """Persist findings from `review_output` into the shared MemoryStore.

        Returns the memory id and the list of findings actually written.
        BUG FIXES: RiskFinding was constructed with `issue_description=`,
        a field that does not exist (the dataclass field is `issue`), so the
        TypeError was swallowed and every finding was silently skipped; also
        `store` was never imported in this module (NameError).
        """
        findings = review_output.get("findings", [])
        added: List[Dict] = []
        for f in findings:
            rule_title = f.get("rule_title") or ""
            # Accept both "issue" and the legacy "issue_description" keys.
            issue = f.get("issue") or f.get("issue_description") or ""
            level = (f.get("level") or f.get("risk_level") or "M").upper()
            suggestion = f.get("suggestion") or ""
            evs = list(f.get("evidence_quotes", []) or [])
            original_text = evs[0] if evs else (f.get("original_text") or "")
            try:
                finding_obj = RiskFinding(
                    rule_title=rule_title,
                    segment_id=int(segment_id) if str(segment_id).isdigit() else 0,
                    original_text=original_text,
                    issue=issue,
                    risk_level=level,
                    suggestion=suggestion,
                )
                store.add_finding(finding_obj)
                added.append(asdict(finding_obj))
            except Exception:
                # Conservative: skip entries that fail validation/storage.
                continue
        memory_id = "m_" + segment_id
        return {
            "memory_id": memory_id,
            "findings": added,
        }
from __future__ import annotations
from typing import Dict, List, Optional
from core.tool import ToolBase, tool, tool_func
@tool("reflect_retry", "反思重试质量闸(MVP)")
class ReflectRetryTool(ToolBase):
    """Quality gate for a segment review: decide "pass" vs "revise".

    Flags findings lacking evidence quotes or concrete suggestions, and turns
    missing_info entries into follow-up tool actions.
    """

    @tool_func(
        {
            "type": "object",
            "properties": {
                "segment_id": {"type": "string"},
                "ruleset_id": {"type": "string"},
                "context_memories": {"type": "array"},
                "review_output": {"type": "object"},
            },
            "required": ["segment_id", "ruleset_id", "review_output"],
        }
    )
    def run(self, segment_id: str, ruleset_id: str, context_memories: Optional[List[Dict]] = None, review_output: Dict = None) -> Dict:
        """Evaluate review quality and propose revise actions when needed."""
        output = review_output or {}
        problems: List[str] = []
        actions: List[Dict] = []
        # Per-finding quality checks: evidence present, suggestion concrete.
        for finding in output.get("findings", []):
            if not (finding.get("evidence_quotes", []) or []):
                problems.append("缺少对关键结论的原文引用")
            if not (finding.get("suggestion") or "").strip():
                problems.append("建议条款措辞仍偏泛,未给可替换文本")
        missings = output.get("missing_info", [])
        evidence_needs = [m.get("need") for m in missings if m.get("type") == "evidence"]
        knowledge_topics = [m.get("topic") for m in missings if m.get("type") == "knowledge"]
        if evidence_needs:
            actions.append({"action": "FullRead", "target": segment_id, "need_focus": evidence_needs})
        if knowledge_topics:
            actions.append({"action": "RetrieveReference", "topic": knowledge_topics})
        status = "revise" if (problems or actions) else "pass"
        retry_budget_left = int(output.get("retry_budget_left", 1))
        return {
            "segment_id": segment_id,
            "status": status,
            "problems": list(dict.fromkeys(problems)),  # dedupe, keep order
            "revise_instructions": actions,
            "retry_budget_left": max(0, retry_budget_left - (1 if status == "revise" else 0)),
        }
import importlib
import inspect
import pkgutil
from pathlib import Path
from typing import Dict, Optional, Type
from core.tool import ToolBase
# Registry that maps tool names to their classes
_tool_registry: Dict[str, Type[ToolBase]] = {}
def _iter_tool_classes():
    """Yield every @tool-decorated ToolBase subclass found under this package.

    Walks all modules with prefix "tools." and imports each; modules whose
    dotted name ends in ".registry" or ".contract_tools" are skipped.
    """
    package_dir = Path(__file__).parent
    for module_info in pkgutil.walk_packages([str(package_dir)], prefix="tools."):
        if module_info.name.endswith(".registry") or module_info.name.endswith(".contract_tools"):
            continue
        module = importlib.import_module(module_info.name)
        for _, obj in inspect.getmembers(module, inspect.isclass):
            # A class counts as a tool only when marked by @tool AND derived
            # from ToolBase.
            if getattr(obj, "_is_tool", False) and issubclass(obj, ToolBase):
                yield obj
def build_registry(force_refresh: bool = False) -> Dict[str, Type[ToolBase]]:
    """Populate (or reuse) the module-level tool registry and return it."""
    if not _tool_registry or force_refresh:
        _tool_registry.clear()
        for tool_cls in _iter_tool_classes():
            _tool_registry[tool_cls._name] = tool_cls
    return _tool_registry
def get_tool(name: str) -> Optional[Type[ToolBase]]:
    """Look up a registered tool class by name; None when absent."""
    return _tool_registry.get(name, None)
def available_tools() -> Dict[str, str]:
    """Map each registered tool name to its description string."""
    catalog: Dict[str, str] = {}
    for tool_name, tool_cls in _tool_registry.items():
        catalog[tool_name] = tool_cls._description
    return catalog
# Build registry on import so get_tool()/available_tools() work immediately.
build_registry()

if __name__ == "__main__":
    # Quick listing of everything the walker discovered.
    for tool_name, tool_desc in available_tools().items():
        print(f"Tool: {tool_name}, Description: {tool_desc}")
\ No newline at end of file
from __future__ import annotations
from typing import Dict, List, Any
from core.tool import ToolBase, tool, tool_func
from core.cache import get_cached_memory
# Fact dimensions used to index memory facts (also defined in segment_summary).
FACT_DIMENSIONS: List[str] = ["当事人", "标的", "金额", "支付", "交付", "质量", "知识产权", "保密", "违约责任", "争议解决"]


@tool("retrieve_reference", "审查参考检索")
class RetrieveReferenceTool(ToolBase):
    """Aggregate review references from memory, knowledge base and external sources."""

    @tool_func(
        {
            "type": "object",
            "properties": {
                "question": {"type": "string"},
                "top_k": {"type": "int"},
            },
            "required": ["question"],
        }
    )
    def run(self, question: str, top_k: int = 5, conversation_id: str = "") -> Dict:
        """Collect up to top_k references per source for `question`."""
        memory_refs = self._search_memory(question, conversation_id, top_k)
        kb_refs = self._search_knowledge_base(question, top_k)
        external_refs = self._search_external(question, top_k)
        return {
            "memory_refs": memory_refs,
            "kb_refs": kb_refs,
            "external_refs": external_refs,
        }

    def _search_memory(self, question: str, conversation_id: str, top_k: int) -> List[Dict[str, Any]]:
        """Scan cached conversation facts for dimensions named in the question."""
        if not conversation_id:
            return []
        store = get_cached_memory(conversation_id)
        facts = store.get_facts()
        results: List[Dict[str, Any]] = []

        def _add(dim: str, payload: Any) -> None:
            snippet = payload if isinstance(payload, str) else str(payload)
            results.append({
                "source": f"memory:{dim}",
                "dimension": dim,
                "snippet": snippet,
            })

        # BUG FIX: MemoryStore.get_facts() returns a LIST of per-segment fact
        # dicts, not one dict, so the previous `facts.get(dim)` raised
        # AttributeError. Iterate every fact dict and pull out the dimensions
        # the question mentions.
        for fact in facts:
            if not isinstance(fact, dict):
                continue
            for dim in FACT_DIMENSIONS:
                val = fact.get(dim)
                if val is None:
                    continue
                if dim in question:
                    _add(dim, val)
        return results[:top_k]

    def _search_knowledge_base(self, question: str, top_k: int) -> List[Dict[str, Any]]:
        # TODO: implement KB retrieval
        return []

    def _search_external(self, question: str, top_k: int) -> List[Dict[str, Any]]:
        # TODO: implement external retrieval (e.g., search engine)
        return []
if __name__ == "__main__":
    # Demo: seed a cached store, then query it through the tool.
    tmp_memory = get_cached_memory("tmp")
    tmp_memory.add_finding_from_dict({
        "issue": "支付方式不明确",
        "original_text": "买方应在收到货物后30天内支付全部货款。",
        "risk_level": "H",
        "rule_title": "支付条款审查",
    })
    # BUG FIX: MemoryStore has no update_facts(); add_facts() appends one
    # fact dict, which is also the shape _search_memory expects.
    tmp_memory.add_facts({
        "支付":{
            "支付方式": "银行转账",
            "支付期限": "收到货物后30天内",
        }
    })
    tool = RetrieveReferenceTool()
    result = tool.run(
        question="支付方式是什么?",
        top_k=3,
        conversation_id="tmp",
    )
    print(result)
\ No newline at end of file
from __future__ import annotations
import asyncio
from typing import Any, Dict, List
from core.tool import ToolBase
from utils.common_util import extract_json
from utils.openai_util import OpenAITool
from core.config import LLM, MAX_WORKERS
class SegmentLLMBase(ToolBase):
    """LLM-backed segment processor: builds prompts, calls LLM, parses JSON."""

    def __init__(self, system_prompt: str, llm_key: str = "fastgpt_segment_review") -> None:
        """Bind a system prompt and an LLM endpoint picked from the LLM config map."""
        super().__init__()
        self.system_prompt = system_prompt
        self.llm = OpenAITool(LLM[llm_key], max_workers=MAX_WORKERS)

    def build_messages(self, user_content: str) -> List[Dict[str, str]]:
        """Build a standard two-message chat: system prompt + user content."""
        return [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": user_content},
        ]

    async def chat_async(self, messages: List[Dict[str, str]]):
        """Single chat completion via the configured LLM."""
        return await self.llm.chat(messages)

    async def chat_batch_async(self, messages_list: List[List[Dict[str, str]]]):
        """Batched chat completions (one per message list)."""
        return await self.llm.mul_chat(messages_list)

    def run_with_loop(self, coro):
        """Run a coroutine to completion from synchronous code.

        NOTE(review): the get_event_loop() fallback is deprecated in modern
        Python and cannot run inside an already-running loop — confirm the
        intended call contexts.
        """
        try:
            return asyncio.run(coro)
        except RuntimeError:
            loop = asyncio.get_event_loop()
            return loop.run_until_complete(coro)

    def parse_first_json(self, resp: str) -> Dict[str, Any]:
        """Return the first JSON object found in `resp`, or {} on any failure.

        # assumes extract_json returns a list of parsed objects — TODO confirm
        """
        try:
            data = extract_json(resp)
            return data[0] if data else {}
        except Exception:
            return {}
from __future__ import annotations
import asyncio
import json
from pathlib import Path
from typing import Dict, List, Optional
from core.tool import tool, tool_func
from utils.excel_util import ExcelUtil
from core.tools.segment_llm import SegmentLLMBase
import re
REVIEW_SYSTEM_PROMPT = f'''
你是一个专业的合同分段审查智能体(SegmentReview)。
你的核心任务是对“当前分段”进行【法律风险识别】,并给出可落地的修改建议。
【工作职责】
- 基于给定的“审查规则”,识别当前分段中确定存在的风险、逻辑矛盾或不合理之处。
- 必须通过“证据对碰”:将当前段落内容与【上下文记忆】进行比对,识别前后不一致。
【输出约束】
- findings (风险发现):必须证据充分、可执行。无原文引用不生成 finding。
- original_text (原文证据):必须是合同原文的直接引用,严禁改写、概括或臆造。
- 严格按照输出 JSON Schema 返回结果,不得输出任何解释性文字。
'''
REVIEW_USER_PROMPT = '''
【当前分段文本】
{segment_text}
【上下文记忆(来自已审分段)】
{context_memories_json}
【合同立场】
站在 {party_role} 的立场进行审查。
【审查规则】
{ruleset_text}
【指令】
执行风险识别:基于规则,识别确定存在的风险,并给出直接落地的修改建议(不要使用“建议协商”等泛化词)。
【输出要求】
- 仅输出 JSON 格式。
- findings 字段中必须包含原文引用和具体的修改建议。
'''
REVIEW_OUTPUT_SCHEMA = '''
```json
{
"findings": [
{
"issue": "详细的风险描述",
"original_text": "合同原文片段的直接引用",
"suggestion": "可直接替换原文或新增的条款措辞"
}
]
}
```
'''
LEVEL_WEIGHT = {"H": 3, "M": 2, "L": 1}
ROOT_CAUSE_RULES = {
"ACC_UNCLEAR": ["验收", "标准", "文件", "通过", "确认"],
"PAY_TRIGGER": ["付款", "支付", "触发", "条件"],
"PAY_DELAY": ["逾期", "迟延", "未按期"],
"LIAB_CAP": ["责任", "上限", "限制", "赔偿"],
"BREACH": ["违约", "违约金"],
"TERM": ["解除", "终止", "解约"],
"CONF": ["保密"],
"IP": ["知识产权", "成果", "许可"],
}
GENERIC_SUGGESTIONS = [
"建议协商", "建议完善", "建议明确",
"进一步明确", "双方协商"
]
def _root_cause(issue: str) -> str:
t = _norm(issue)
for key, kws in ROOT_CAUSE_RULES.items():
if any(k in t for k in kws):
return key
return "OTHER"
def _norm(text: str) -> str:
if not text:
return ""
text = text.lower()
text = re.sub(r"\s+", "", text)
text = re.sub(r"[,。;:、()()“”\"']", "", text)
return text
def _has_evidence(f: Dict) -> bool:
return bool(f.get("original_text"))
def _is_generic_suggestion(text: str) -> bool:
t = _norm(text)
return any(g in t for g in GENERIC_SUGGESTIONS)
@tool("segment_review", "合同分段审查")
class SegmentReviewTool(SegmentLLMBase):
    """Review one contract segment against Excel-loaded rules via the LLM.

    Rules are loaded once in __init__ from data/rules.xlsx (sheet named after
    rule_version); run() fans out one LLM call per rule and merges the
    findings into a single result.
    """

    def __init__(self):
        super().__init__(REVIEW_SYSTEM_PROMPT)
        self.rule_version = "rule-v1"
        # Internal field name -> Chinese column header of the rules sheet.
        self.column_map = {
            "id": "ID",
            "title": "审查项",
            "rule": "审查规则",
            "level": "风险等级",
            "triggers": "触发词",
            "suggestion_template": "建议模板",
        }
        rules_path = Path(__file__).resolve().parent.parent.parent / "data" / "rules.xlsx"
        rules = ExcelUtil.load_mapped_excel(rules_path, sheet_name=self.rule_version, column_map=self.column_map)
        self.rulesets: Dict[str, List[Dict]] = {self.rule_version: rules} if rules else {}

    @tool_func(
        {
            "type": "object",
            "properties": {
                "segment_id": {"type": "int"},
                "segment_text": {"type": "string"},
                "ruleset_id": {"type": "string"},
                "party_role": {"type": "string"},
                "context_memories": {"type": "array"},
            },
            "required": ["segment_id", "segment_text", "ruleset_id", "party_role"],
        }
    )
    def run(self, segment_id: str, segment_text: str, ruleset_id: str, party_role: str, context_memories: Optional[List[Dict]] = None) -> Dict:
        """Review one segment; return segment_id, overall_conclusion, findings.

        NOTE(review): annotated `segment_id: str` but the schema declares int
        and the demo passes an int — confirm the intended type.
        """
        # Fall back to the default ruleset when the requested id is unknown.
        rules = self.rulesets.get(ruleset_id) or self.rulesets.get(self.rule_version, []) or []
        result = self._evaluate_rules(party_role,segment_id,segment_text,rules, context_memories)
        overall = "revise" if (result["findings"] ) else "pass"
        return {
            "segment_id": segment_id,
            "overall_conclusion": overall,
            "findings": result["findings"],
        }

    def _stringify_rule(self, rule:Dict) -> str:
        """Render one rule dict as the prompt's rule-text block."""
        res = ''
        res += f"审查项: {rule.get('title','')}\n"
        res += f"审查规则: {rule.get('rule','')}\n"
        res += f"风险等级: {rule.get('level','')}\n"
        res += f"触发词: {rule.get('triggers','')}\n"
        res += f"建议模板: {rule.get('suggestion_template','')}\n"
        return res

    def _build_prompt(self, party_role: str, rule: Dict, segment_id: int, segment_text: str, context_memories: Optional[List[Dict]]) -> List[Dict[str, str]]:
        """Build the chat messages for one (segment, rule) pair.

        NOTE(review): segment_id is passed to format() but the template has no
        {segment_id} placeholder — str.format ignores the extra kwarg.
        """
        user_content = REVIEW_USER_PROMPT.format(
            segment_id=segment_id,
            segment_text=segment_text,
            party_role=party_role,
            context_memories_json=json.dumps(context_memories or [], ensure_ascii=False),
            ruleset_text=self._stringify_rule(rule)
        ) + REVIEW_OUTPUT_SCHEMA
        return self.build_messages(user_content)

    async def _evaluate_rules_async(self, party_role: str, segment_id: int, segment_text: str, rules: List[Dict], context_memories: Optional[List[Dict]]) -> Dict[str, List[Dict]]:
        """Fan out one LLM call per rule; tag each finding with its rule metadata.

        Any batch failure degrades to an empty findings list (best-effort).
        """
        msgs = [self._build_prompt(party_role,rule, segment_id, segment_text, context_memories) for rule in rules]
        if not msgs:
            return {"findings": []}
        try:
            responses = await self.chat_batch_async(msgs)
        except Exception:
            return {"findings": []}
        all_findings: List[Dict] = []
        for idx,resp in enumerate(responses):
            data = self.parse_first_json(resp)
            rule_title = rules[idx].get("title","")
            rule_level = rules[idx].get("level","")
            findings = data.get("findings", []) or []
            # Stamp each LLM finding with the rule it came from.
            for f in findings:
                f["rule_title"] = rule_title
                f["risk_level"] = rule_level
            all_findings.extend(findings)
        return {
            "findings": all_findings,
        }

    def _evaluate_rules(self, party_role: str, segment_id: int, segment_text: str, rules: List[Dict], context_memories: Optional[List[Dict]]) -> Dict[str, List[Dict]]:
        """Synchronous wrapper around _evaluate_rules_async.

        NOTE(review): duplicates SegmentLLMBase.run_with_loop — consider reusing it.
        """
        try:
            return asyncio.run(self._evaluate_rules_async(party_role, segment_id, segment_text, rules, context_memories))
        except RuntimeError:
            loop = asyncio.get_event_loop()
            return loop.run_until_complete(self._evaluate_rules_async(party_role, segment_id, segment_text, rules, context_memories))

    def filter(self,
               segment_id: str,
               review_result: Dict) -> Dict:
        """Post-process findings: drop unevidenced, dedupe, cluster, rank, cap at 5.

        NOTE(review): this method reads f["level"] and f.get("evidence_quotes")
        but run()/_evaluate_rules_async produce findings keyed "risk_level" and
        "original_text" — the dedup step would raise KeyError on "level" and
        evidence lists would always be empty. Confirm the expected input shape
        before enabling (the __main__ call is commented out).
        """
        findings: List[Dict] = review_result.get("findings", [])
        facts = review_result.get("facts", [])
        missing_info: List[str] = []
        # 1. Hard filter: drop findings with no original-text evidence.
        findings = [f for f in findings if _has_evidence(f)]
        if not findings:
            return {
                "segment_id": segment_id,
                "findings": [],
                "facts": facts
            }
        # 2. Dedupe by normalized issue text, keeping the higher-severity one.
        dedup = {}
        for f in findings:
            k = _norm(f.get("issue", ""))
            if k not in dedup:
                dedup[k] = f
            else:
                old = dedup[k]
                if LEVEL_WEIGHT.get(f["level"], 1) > LEVEL_WEIGHT.get(old["level"], 1):
                    dedup[k] = f
        findings = list(dedup.values())
        # 3. Cluster by root cause.
        clusters = {}
        for f in findings:
            key = _root_cause(f.get("issue", ""))
            clusters.setdefault(key, []).append(f)
        # 4. Merge each cluster: keep the best finding's issue/level, union the
        #    evidence, and prefer the first non-generic suggestion.
        merged = []
        for fs in clusters.values():
            fs = sorted(
                fs,
                key=lambda x: (
                    LEVEL_WEIGHT.get(x.get("level", "L"), 1),
                    len(x.get("evidence_quotes", []))
                ),
                reverse=True
            )
            best = fs[0]
            merged.append({
                "rule_titles": [f.get("rule_title") for f in fs],
                "issue": best.get("issue"),
                "level": best.get("level"),
                "evidence_quotes": list(
                    {q for f in fs for q in f.get("evidence_quotes", [])}
                ),
                "suggestion": next(
                    (f["suggestion"] for f in fs
                     if f.get("suggestion") and not _is_generic_suggestion(f["suggestion"])),
                    best.get("suggestion")
                )
            })
        # 5. Rank by severity, evidence count, rule coverage; cap at 5 findings.
        merged.sort(
            key=lambda x: (
                LEVEL_WEIGHT.get(x["level"], 1),
                len(x["evidence_quotes"]),
                len(x["rule_titles"])
            ),
            reverse=True
        )
        final_findings = merged[:5]
        return {
            "segment_id": segment_id,
            "findings": final_findings,
            "missing_info": missing_info
        }
if __name__=="__main__":
    # Demo: review one segment with the default ruleset and print the JSON result.
    tool = SegmentReviewTool()
    result = tool.run(
        segment_id=1,
        segment_text="本合同自双方签字盖章之日起生效,有效期为两年,期满后自动续展一年,除非一方提前30天书面通知对方终止。",
        ruleset_id="rule-v1",
        party_role="甲方",
        context_memories=[
            {
                "segment_id": 0,
                "overall_conclusion": "pass",
                "findings": [],
                "missing_info": []
            }
        ]
    )
    print(json.dumps(result, ensure_ascii=False, indent=2))
    # filter_result = tool.filter(1,result)
    # print(json.dumps(filter_result, ensure_ascii=False, indent=2))
\ No newline at end of file
from __future__ import annotations
import asyncio
import json
from typing import Dict, List, Optional
from core.tool import tool, tool_func
from core.tools.segment_llm import SegmentLLMBase
FACT_DIMENSIONS: List[str] = ["当事人", "标的", "金额", "支付", "交付", "质量", "知识产权", "保密", "违约责任", "争议解决"]
SUMMARY_SYSTEM_PROMPT = f'''
你是合同事实提取智能体(SegmentSummary)。
仅输出“本分段的客观事实”,不做风险判断,不做主观推测。
【输出结构】
- facts: 一个对象,键为预设维度,值为该分段出现的事实(未出现的维度可缺省或置空)。
- 维度列表:{", ".join(FACT_DIMENSIONS)}。
- 若原文包含多个事实,可使用列表或子对象表达,但保持紧凑、可读。
【约束】
- 严禁编造或改写原文未出现的信息。
- 不输出与事实无关的解释或额外文字。
'''
SUMMARY_USER_PROMPT = '''
【分段原文】
{segment_text}
【上下文事实】
{context_facts}
请提取本段出现的客观事实,按照指定维度输出 JSON。未出现的维度可省略。
输出示例:'''
OUTPUT_EXAMPLE = '''
```json
{
"facts": {
"支付": {"方式": "银行转账", "时间": "验收后30日内"},
"违约责任": {"违约金比例": "合同总金额的5%"}
}
}
```
'''
@tool("segment_summary", "分段事实提取")
class SegmentSummaryTool(SegmentLLMBase):
    """Extract objective per-dimension facts from one contract segment via the LLM."""

    def __init__(self) -> None:
        super().__init__(SUMMARY_SYSTEM_PROMPT)

    @tool_func(
        {
            "type": "object",
            "properties": {
                "segment_id": {"type": "int"},
                "segment_text": {"type": "string"},
                "party_role": {"type": "string"},
                "context_facts": {"type": "object"},
            },
            "required": ["segment_id", "segment_text"],
        }
    )
    def run(
        self,
        segment_id: int,
        segment_text: str,
        party_role: str = "",
        context_facts: Optional[Dict] = None,
    ) -> Dict:
        """Return the extracted facts dict (includes segment_id), or {} on failure.

        NOTE(review): every exception is swallowed and {} returned — callers
        cannot distinguish "no facts" from "extraction failed".
        """
        try:
            return self.run_with_loop(self._summarize_async(segment_id, segment_text, party_role, context_facts))
        except Exception:
            return {}

    def _build_prompt(self, segment_text: str, context_facts: Optional[Dict], party_role: str) -> List[Dict[str, str]]:
        """Build the chat messages for one segment.

        NOTE(review): party_role is accepted but never used in the template —
        confirm whether the prompt should include it.
        """
        user_content = SUMMARY_USER_PROMPT.format(
            segment_text=segment_text,
            context_facts=json.dumps(context_facts or {}, ensure_ascii=False),
        ) + OUTPUT_EXAMPLE
        return self.build_messages(user_content)

    async def _summarize_async(
        self,
        segment_id: int,
        segment_text: str,
        party_role: str,
        context_facts: Optional[Dict],
    ) -> Dict:
        """Call the LLM once and normalize the `facts` payload to a dict."""
        msgs = self._build_prompt(segment_text, context_facts, party_role)
        final_facts: Dict = {}
        try:
            resp = await self.chat_async(msgs)
            data = self.parse_first_json(resp)
            facts = data.get("facts") or {}
        except Exception:
            facts = {}
        # print(f'SegmentSummaryTool facts: {facts}')
        # The LLM sometimes returns a list; wrap it under a single key.
        if isinstance(facts,list):
            final_facts['内容'] = facts
        else:
            final_facts = facts
        final_facts['segment_id'] = segment_id
        return final_facts
if __name__=='__main__':
    # Manual smoke test: extract payment facts from a single clause.
    tool = SegmentSummaryTool()
    res = tool.run(
        segment_id=1,
        segment_text="甲方应于合同签订之日起30日内向乙方支付合同总金额的50%,余款在货物验收合格后30日内付清.",
        context_facts={},
    )
    print(res)
\ No newline at end of file
File added
import numpy as np
def calculate_grpo_advantages(rewards, epsilon=1e-8):
    """Normalize a group's rewards into GRPO advantages.

    Subtracts the group mean and divides by the group standard deviation
    (plus ``epsilon`` for numerical stability), so the returned advantages
    have approximately zero mean and unit standard deviation.

    :param rewards: list or array of reward values for one sampled group
    :param epsilon: stabilizer added to the std to avoid division by zero
    :return: numpy array of normalized advantage values
    """
    reward_arr = np.array(rewards)
    group_mean = np.mean(reward_arr)
    group_std = np.std(reward_arr)
    return (reward_arr - group_mean) / (group_std + epsilon)
# Demo: identical rewards produce all-zero advantages (std is ~0, epsilon
# prevents division by zero).
your_rewards = [1.1,1.1,1.1,1.1]
advantages = calculate_grpo_advantages(your_rewards)
print(f"原始奖励值: {your_rewards}")
print(f"GRPO 优势值: {advantages.round(4)}")
\ No newline at end of file
from spire.doc import *
from spire.doc.common import *
from pathlib import Path
# Walk every file in the dataset folder, extract its text with Spire.Doc, and
# report per-file and average character counts (unreadable files are skipped).
folder = Path("/home/ccran/lufa-contract/datasets/2-1.合同审核-近三年审核前的合同文件")  # change this to point at another directory
lengths = []
for p in folder.iterdir():
    if not p.is_file():
        continue
    try:
        doc = Document()
        doc.LoadFromFile(str(p))
        text = doc.GetText()
        length = len(text)
        print(f"{p.name}: {length}")
        lengths.append(length)
    except Exception as e:
        print(f"跳过文件 {p.name}: {e}")
if lengths:
    avg = sum(lengths) / len(lengths)
    print(f"平均长度: {avg:.2f}(共 {len(lengths)} 个文件)")
else:
    print("没有可处理的文件。")
\ No newline at end of file
import json
# Scratch check: the "urls" field arrives as a JSON-encoded *string* (note the
# escaped quotes), not a parsed list — callers must json.loads it before use.
d = {
    "urls": "[\"/api/common/file/read/大模型角色扮演综述总结.docx?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJidWNrZXROYW1lIjoiY2hhdCIsInRlYW1JZCI6IjM2MzU2MjY1NjYzMDYyMzQzMDM2NjQzNCIsInVpZCI6IjY0NjQ2MTM2MzU2MzM5MzUzNzMwMzU2MiIsImZpbGVJZCI6IjY5NjhhOGU4M2YyMDEzMWI1MjdjMDdjZSIsImV4cCI6MTc2OTA3MTQ2NCwiaWF0IjoxNzY4NDY2NjY0fQ.B_W_CyzpaW__Hb8mrN-Xn5M2FOg73zFN4KoKFnSEQzs\"]",
    "conversation_id": "38188df54cbe4884b2b6eb0b30dec898"
}
print(d['urls'])
\ No newline at end of file
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional
from uuid import uuid4
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
import uvicorn
from utils.common_util import extract_url_file, format_now
from utils.http_util import download_file
from core.cache import get_cached_doc_tool, get_cached_memory
from core.config import doc_support_formats
from core.tools.segment_summary import SegmentSummaryTool
from core.tools.segment_review import SegmentReviewTool
from core.memory import RiskFinding
app = FastAPI(title="合同审查智能体", version="0.1.0")
# Downloaded documents are staged under ./tmp next to this module.
TMP_DIR = Path(__file__).resolve().parent / "tmp"
TMP_DIR.mkdir(parents=True, exist_ok=True)
# Module-level singletons shared by all request handlers.
summary_tool = SegmentSummaryTool()
review_tool = SegmentReviewTool()
########################################################################################################################
class DocumentParseRequest(BaseModel):
    """Request body for POST /documents/parse."""
    conversation_id: str
    urls: List[str] = Field(..., description="File download url")
class DocumentParseResponse(BaseModel):
    """Parsed document text plus the ids of its chunks."""
    conversation_id: str
    text: str
    chunk_ids: List[int]
@app.post("/documents/parse", response_model=DocumentParseResponse)
def parse_document(payload: DocumentParseRequest) -> DocumentParseResponse:
if not payload.urls:
raise HTTPException(status_code=400, detail="No URLs provided")
try:
filename = extract_url_file(payload.urls[0], doc_support_formats)
except Exception as exc:
raise HTTPException(status_code=400, detail=f"Failed to parse url: {exc}")
file_path = str(TMP_DIR / filename)
try:
download_file(payload.urls[0], file_path)
except Exception as exc:
raise HTTPException(status_code=500, detail=f"Download failed: {exc}")
doc_obj = get_cached_doc_tool(payload.conversation_id)
doc_obj.load(file_path)
text = doc_obj.get_all_text()
chunk_ids = doc_obj.get_chunk_id_list()
return DocumentParseResponse(
conversation_id=payload.conversation_id, text=text,chunk_ids=chunk_ids
)
########################################################################################################################
class SegmentSummaryRequest(BaseModel):
    """Request body for POST /segments/summary/facts."""
    conversation_id: str
    segment_id: int
    party_role: Optional[str] = ""
    context_facts: Optional[Dict] = None  # falls back to stored facts when None
class SegmentSummaryResponse(BaseModel):
    """Facts extracted from one document segment."""
    conversation_id: str
    segment_id: int
    facts: Dict
@app.post("/segments/summary/facts", response_model=SegmentSummaryResponse)
def summarize_facts(payload: SegmentSummaryRequest) -> SegmentSummaryResponse:
store = get_cached_memory(payload.conversation_id)
try:
doc_obj = get_cached_doc_tool(payload.conversation_id)
except Exception as exc:
raise HTTPException(status_code=400, detail=f"Document tool not available: {exc}")
chunk_idx = payload.segment_id - 1 # chunk_id 在 SpireWordDoc 中为 1-based
try:
segment_text = doc_obj.get_chunk_item(chunk_idx)
except Exception as exc:
raise HTTPException(status_code=404, detail=f"Segment text not found for id {payload.segment_id}: {exc}. Please parse document first.")
result = summary_tool.run(
segment_id=payload.segment_id,
segment_text=segment_text,
party_role=payload.party_role or "",
context_facts=payload.context_facts or store.get_facts(),
)
store.add_facts(result)
return SegmentSummaryResponse(
conversation_id=payload.conversation_id,
segment_id=payload.segment_id,
facts=result,
)
########################################################################################################################
class SegmentReviewRequest(BaseModel):
    """Request body for POST /segments/review/findings."""
    conversation_id: str
    segment_id: int
    party_role: Optional[str] = ""
    ruleset_id: Optional[str] = "rule-v1"
    context_memories: Optional[List[Dict]] = None  # falls back to stored facts when None
class SegmentReviewResponse(BaseModel):
    """Risk-review result for one document segment."""
    conversation_id: str
    segment_id: int
    overall_conclusion: str
    findings: List[Dict]
@app.post("/segments/review/findings", response_model=SegmentReviewResponse)
def review_segment(payload: SegmentReviewRequest) -> SegmentReviewResponse:
store = get_cached_memory(payload.conversation_id)
try:
doc_obj = get_cached_doc_tool(payload.conversation_id)
except Exception as exc:
raise HTTPException(status_code=400, detail=f"Document tool not available: {exc}")
chunk_idx = payload.segment_id - 1
try:
segment_text = doc_obj.get_chunk_item(chunk_idx)
except Exception as exc:
raise HTTPException(status_code=404, detail=f"Segment text not found for id {payload.segment_id}: {exc}. Please parse document first.")
result = review_tool.run(
segment_id=payload.segment_id,
segment_text=segment_text,
ruleset_id=payload.ruleset_id or "rule-v1",
party_role=payload.party_role or "",
context_memories=payload.context_memories or store.get_facts(),
)
# Persist findings to memory store
for f in result.get("findings", []) or []:
try:
store.add_finding_from_dict({
"rule_title": f.get("rule_title", ""),
"segment_id": payload.segment_id,
"original_text": f.get("original_text",''),
"issue": f.get("issue", ""),
"risk_level": (f.get("risk_level") or f.get("level") or "").upper(),
"suggestion": f.get("suggestion", ""),
})
except Exception:
continue
return SegmentReviewResponse(
conversation_id=payload.conversation_id,
segment_id=payload.segment_id,
overall_conclusion=result.get("overall_conclusion", ""),
findings=result.get("findings", []),
)
########################################################################################################################
class ConversationResponse(BaseModel):
    """A freshly minted conversation id plus its creation timestamp."""
    conversation_id: str
    created_at: str
@app.post("/conversations/new", response_model=ConversationResponse)
def new_conversation() -> ConversationResponse:
conversation_id = uuid4().hex
created_at = format_now()
return ConversationResponse(conversation_id=conversation_id, created_at=created_at)
########################################################################################################################
class MemoryExportRequest(BaseModel):
    """Request body for POST /memory/export."""
    conversation_id: str
    file_name: Optional[str] = None  # exporter picks a default name when None
class MemoryExportResponse(BaseModel):
    """Export result: the resolved download URL plus the raw exporter payload."""
    conversation_id: str
    url: str
    data: Dict
@app.post("/memory/export", response_model=MemoryExportResponse)
def export_memory(payload: MemoryExportRequest) -> MemoryExportResponse:
store = get_cached_memory(payload.conversation_id)
try:
res = store.export_to_excel(file_name=payload.file_name)
except ImportError as exc:
raise HTTPException(status_code=500, detail=str(exc))
except Exception as exc:
raise HTTPException(status_code=500, detail=f"Export failed: {exc}")
url = ""
if isinstance(res, str):
url = res
elif isinstance(res, dict):
for key in [
"url",
"file_url",
"fileUrl",
"link",
"downloadUrl",
"path",
"filePath",
]:
val = res.get(key)
if isinstance(val, str) and val:
url = val
break
return MemoryExportResponse(conversation_id=payload.conversation_id, url=url, data=res if isinstance(res, dict) else {"result": res})
if __name__ == "__main__":
uvicorn.run(
"main:app", host="0.0.0.0", port=18168, log_level="info", reload=False
)
\ No newline at end of file
import json
import random
import re
from datetime import datetime
from typing import Dict, List
from loguru import logger
from core.config import max_chunk_page, min_single_chunk_size, max_single_chunk_size
def random_str(l=5):
    """Return a random lowercase ASCII string of length ``l``.

    Uses ``random.choices`` (sampling with replacement), so any length is
    supported; the previous ``random.sample`` version raised ValueError for
    l > 26 and could never repeat a letter.
    """
    return ''.join(random.choices('abcdefghijklmnopqrstuvwxyz', k=l))
def format_now():
    """Return the current local time as a 'YYYY-MM-DD HH:MM:SS' string."""
    current = datetime.now()
    return current.strftime("%Y-%m-%d %H:%M:%S")
# Extract the file name (with a supported extension) from a download URL.
def extract_url_file(url, support_formats):
    """Return the first file name in ``url`` ending in one of ``support_formats``.

    The name may contain CJK characters, ASCII word characters, digits,
    parentheses (half/full width) and dashes.

    Raises:
        Exception: when no supported file name is found in the URL.
    """
    # Bug fix: the extension is now re.escape()d, so the '.' in e.g. ".docx"
    # matches a literal dot instead of any character (previously "abcdocx"
    # would match ".docx").
    pattern = '|'.join(
        r'[\u4e00-\u9fa5()()0-9\w-]+' + re.escape(fmt) for fmt in support_formats
    )
    search_result = re.search(pattern, url)
    if search_result:
        return search_result.group()
    raise Exception(f'{support_formats} not found in url:{url}')
# Recompute the per-chunk size so a document yields roughly max_chunk_page chunks.
def adjust_single_chunk_size(all_text_len):
    """Return the chunk size for a document of ``all_text_len`` characters,
    clamped to [min_single_chunk_size, max_single_chunk_size]."""
    desired_chunk_size = all_text_len // max_chunk_page
    return max(min_single_chunk_size, min(desired_chunk_size, max_single_chunk_size))
# Extract JSON objects from a (possibly fenced) string.
def extract_json(json_str:str) -> List[Dict]:
    """Extract a list of JSON objects from a string.

    Preference order:
    - ```json ... ``` fenced code blocks
    - the whole string parsed as JSON
    - plain ``` ... ``` fenced code blocks
    - any brace/bracket-delimited fragment (heuristic, best effort)
    Parsed arrays are flattened into the returned list.
    """
    def _try_parse_to_list(candidate: str, out_list: list) -> bool:
        # Parse one candidate string; append/extend into out_list on success.
        s = (candidate or '').strip()
        if not s:
            return False
        # Strip control characters that would break json.loads
        s = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', s)
        try:
            obj = json.loads(s, strict=False)
            if isinstance(obj, list):
                out_list.extend(obj)
            else:
                out_list.append(obj)
            return True
        except Exception as e:
            logger.error(f"JSON解析失败: {e}")
            return False
    results = []
    # 1. ```json ... ``` fenced blocks
    fenced_json_pattern = r'```json([\s\S]*?)```'
    for match in re.findall(fenced_json_pattern, json_str or '', re.DOTALL):
        _try_parse_to_list(match, results)
    if results:
        return results
    # 2. The whole string as JSON
    if _try_parse_to_list(json_str or '', results):
        return results
    # 3. Plain ``` ... ``` fenced blocks
    fenced_any_pattern = r'```([\s\S]*?)```'
    for match in re.findall(fenced_any_pattern, json_str or '', re.DOTALL):
        if _try_parse_to_list(match, results):
            return results
    # 4. Any brace/bracket fragment (best effort). NOTE(review): the pattern is
    # non-greedy, so deeply nested objects may be truncated at the first '}'.
    bracket_pattern = r'(\{[\s\S]*?\}|\[[\s\S]*?\])'
    for match in re.findall(bracket_pattern, json_str or '', re.DOTALL):
        _try_parse_to_list(match, results)
    return results
def remove_duplicates_by_key(data_list, key):
    """Drop items whose ``key`` value is contained in a longer kept value.

    Items are scanned longest-value-first, so an item survives only when its
    value is not a substring of any previously kept value. The result is
    ordered by descending value length.
    """
    kept_items = []
    kept_values = []
    longest_first = sorted(
        data_list,
        key=lambda item: len(item.get(key, "")),
        reverse=True,
    )
    for item in longest_first:
        value = item.get(key, "")
        # Skip values already covered by a longer kept value.
        if any(value in existing for existing in kept_values):
            continue
        kept_values.append(value)
        kept_items.append(item)
    return kept_items
def extract_drop_json_part(json_str):
    """Return ``json_str`` with every ```json fenced block removed, stripped.

    Bug fix: the original passed ``re.DOTALL`` positionally where ``re.sub``
    expects ``count``, silently limiting replacement to 16 blocks. The flag
    is now passed via ``flags=`` (it is redundant for this pattern, which
    uses ``[\\s\\S]``, but kept for clarity).
    """
    json_pattern = r'```json([\s\S]*?)```'
    non_json_content = re.sub(json_pattern, '', json_str, flags=re.DOTALL)
    return non_json_content.strip()
def group_chunk_by_len(chunk_list: List[Dict], key: str, chunk_len: int) -> List[List[Dict]]:
    """Greedily pack dicts into groups whose ``key`` text totals at most ``chunk_len``.

    An open group is flushed just before adding an item that would push its
    accumulated text length over the limit; a single oversized item still
    forms its own group. Input order is preserved.
    """
    groups: List[List[Dict]] = []
    current_group: List[Dict] = []
    accumulated = 0
    for item in chunk_list:
        # Missing keys contribute zero length instead of raising.
        item_len = len(item.get(key, ""))
        # Flush when this item would overflow the open group (never flush an
        # empty group, so an oversized single item is still accepted).
        if current_group and accumulated + item_len > chunk_len:
            groups.append(current_group)
            current_group = []
            accumulated = 0
        current_group.append(item)
        accumulated += item_len
    if current_group:
        groups.append(current_group)
    return groups
if __name__ == '__main__':
    # Quick check: a fenced ```json block with no surrounding newlines parses.
    json_str = '```json{"segment_id": "seg-001"}```'
    print(extract_json(json_str))
    pass
import os
from abc import ABC, abstractmethod
# Base class for document backends.
class DocBase(ABC):
    """Abstract interface for loading a document, splitting it into chunks,
    and managing per-chunk comments/annotations."""
    def __init__(self, **kwargs):
        self._doc_path = None
        self._doc_name = None
        self._kwargs = kwargs
        # Upper bound on characters per chunk; subclasses may recompute it.
        self._max_single_chunk_size = kwargs.get('max_single_chunk_size', 2000)
    @abstractmethod
    def load(self, doc_path):
        """Load the file at ``doc_path`` and initialise internal state."""
        pass
    @abstractmethod
    def adjust_chunk_size(self):
        """Recompute the per-chunk size for the loaded document."""
        pass
    @abstractmethod
    async def get_from_ocr(self):
        """Fetch document content via OCR (backend-specific)."""
        pass
    # Return the text of the chunk identified by chunk_id.
    @abstractmethod
    def get_chunk_item(self, chunk_id):
        pass
    # Return a human-readable description of the chunk.
    @abstractmethod
    def get_chunk_info(self, chunk_id):
        pass
    @abstractmethod
    def get_chunk_location(self, chunk_id):
        """Return location information for the chunk."""
        pass
    # Add a new comment/annotation to a chunk.
    @abstractmethod
    def add_chunk_comment(self, chunk_id, comments):
        pass
    # Edit an existing chunk comment.
    @abstractmethod
    def edit_chunk_comment(self, comments):
        pass
    @abstractmethod
    def delete_chunk_comment(self, comments):
        """Delete the given chunk comments."""
        pass
    @abstractmethod
    def get_chunk_id_list(self, step=1):
        """Return the list of chunk ids (optionally strided by ``step``)."""
        pass
    # Return the number of chunks in the document.
    @abstractmethod
    def get_chunk_num(self):
        pass
    # Return the full document text.
    @abstractmethod
    def get_all_text(self):
        pass
    # Persist the document to ``path`` (no-op by default).
    def to_file(self, path, **kwargs):
        pass
    def release(self):
        """Delete the backing file from disk, if it still exists."""
        if self._doc_path and os.path.exists(self._doc_path):
            os.remove(self._doc_path)
from __future__ import annotations
from pathlib import Path
from typing import Dict, List, Optional, Union
class ExcelLoadError(Exception):
    """Raised when Excel loading fails (missing file, unreadable workbook,
    or openpyxl unavailable)."""
class ExcelUtil:
    """Helper for reading Excel sheets (xlsx/xlsm) using openpyxl.

    Workbooks are opened in read-only mode and are now explicitly closed
    after use: read-only mode keeps the underlying file handle open until
    ``Workbook.close()`` is called, so the original code leaked handles.
    """
    def __init__(self, file_path: Union[str, Path]):
        self.file_path = Path(file_path)
    @staticmethod
    def _import_openpyxl():
        """Import openpyxl lazily so the dependency is only required on use."""
        try:
            import openpyxl  # type: ignore
        except ImportError as exc:
            raise ExcelLoadError("openpyxl is required. Install via 'pip install openpyxl'") from exc
        return openpyxl
    def _ensure_exists(self) -> None:
        """Raise ExcelLoadError when the target file is missing."""
        if not self.file_path.exists():
            raise ExcelLoadError(f"File not found: {self.file_path}")
    def load(
        self,
        sheet_name: Optional[str] = None,
        has_header: bool = True,
    ) -> List[Union[Dict[str, object], List[object]]]:
        """
        Load data from the Excel file into a list.
        Args:
            sheet_name: Sheet to read; defaults to the first sheet.
            has_header: When True, return a list of dict rows keyed by header row; otherwise return list of lists.
        Raises:
            ExcelLoadError: When the file is missing, not readable, or openpyxl is unavailable.
        """
        self._ensure_exists()
        openpyxl = self._import_openpyxl()
        try:
            wb = openpyxl.load_workbook(self.file_path, data_only=True, read_only=True)
        except Exception as exc:
            raise ExcelLoadError(f"Failed to open Excel file: {exc}") from exc
        try:
            ws = wb[sheet_name] if sheet_name else wb.active
            rows = list(ws.iter_rows(values_only=True))
        finally:
            # Fix: read-only workbooks hold the file handle open until closed.
            wb.close()
        if not rows:
            return []
        if not has_header:
            return [list(row) for row in rows]
        headers = [str(h).strip() if h is not None else "" for h in rows[0]]
        data_rows = rows[1:]
        result: List[Dict[str, object]] = []
        for row in data_rows:
            # Rows wider than the header get synthetic "colN" keys.
            row_dict = {headers[i] if i < len(headers) else f"col{i}": row[i] for i in range(len(row))}
            result.append(row_dict)
        return result
    def list_sheets(self) -> List[str]:
        """Return available sheet names for the current file."""
        self._ensure_exists()
        openpyxl = self._import_openpyxl()
        try:
            wb = openpyxl.load_workbook(self.file_path, read_only=True)
            try:
                return wb.sheetnames
            finally:
                # Fix: release the read-only file handle.
                wb.close()
        except Exception as exc:
            raise ExcelLoadError(f"Failed to read sheet names: {exc}") from exc
    def find_value_by_column(
        self,
        key_column: str,
        key_value: object,
        value_column: str,
        sheet_name: Optional[str] = None,
    ) -> Optional[object]:
        """
        Return the first value in column `value_column` where column `key_column` equals `key_value`.
        Both column arguments should match header names when `has_header=True`.
        """
        rows = self.load(sheet_name=sheet_name, has_header=True)
        for row in rows:
            if row.get(key_column) == key_value:
                return row.get(value_column)
        return None
    def map_rows(self, sheet_name: Optional[str], column_map: Dict[str, str]) -> List[Dict[str, object]]:
        """
        Load rows as dicts (header->value) and remap keys using column_map.
        column_map: {new_key: header_name_in_excel}
        Unmapped headers are ignored; missing headers yield None.
        """
        rows = self.load(sheet_name=sheet_name, has_header=True)
        mapped: List[Dict[str, object]] = []
        for row in rows:
            mapped_row = {new_key: row.get(header) for new_key, header in column_map.items()}
            mapped.append(mapped_row)
        return mapped
    @classmethod
    def load_excel(
        cls,
        file_path: Union[str, Path],
        sheet_name: Optional[str] = None,
        has_header: bool = True,
    ) -> List[Union[Dict[str, object], List[object]]]:
        """Convenience classmethod wrapper for one-off loads."""
        return cls(file_path).load(sheet_name=sheet_name, has_header=has_header)
    @classmethod
    def list_excel_sheets(cls, file_path: Union[str, Path]) -> List[str]:
        """Convenience classmethod wrapper to list sheet names for a path."""
        return cls(file_path).list_sheets()
    @classmethod
    def find_value_by_column_excel(
        cls,
        file_path: Union[str, Path],
        key_column: str,
        key_value: object,
        value_column: str,
        sheet_name: Optional[str] = None,
    ) -> Optional[object]:
        """Convenience wrapper for find_value_by_column on a file path."""
        return cls(file_path).find_value_by_column(
            key_column=key_column,
            key_value=key_value,
            value_column=value_column,
            sheet_name=sheet_name,
        )
    @classmethod
    def load_mapped_excel(
        cls,
        file_path: Union[str, Path],
        sheet_name: Optional[str],
        column_map: Dict[str, str],
    ) -> List[Dict[str, object]]:
        """Convenience wrapper to load rows and remap headers to new keys."""
        return cls(file_path).map_rows(sheet_name=sheet_name, column_map=column_map)
if __name__ == "__main__":
# 一次性调用
rows = ExcelUtil.load_excel("data/rules.xlsx", sheet_name=None, has_header=False)
print(rows)
import json
import os
import requests
from loguru import logger
from requests_toolbelt import MultipartEncoder
from core.config import base_fastgpt_url, base_backend_url, outer_backend_url
def fastgpt_openai_chat(url, token, model, chat_id, file_url, text, stream=True):
    """Call a FastGPT OpenAI-compatible chat endpoint with one file plus one
    text message and return the assistant's full reply text.

    stream=True consumes SSE chunks until "data: [DONE]" appears;
    stream=False parses a single JSON response body.
    """
    data = {
        "chatId": chat_id,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "file_url", "name": "文件", "url": file_url},
                    {
                        "type": "text",
                        "text": text,
                    },
                ],
            }
        ],
        "model": model,
        "stream": stream,
    }
    headers = {"Content-Type": "application/json", "Authorization": f"Bearer {token}"}
    # NOTE(review): requests timeouts are (connect, read) in *seconds* —
    # (60000, 60000) is ~16.7 hours, effectively unbounded; confirm whether
    # milliseconds were intended.
    response = requests.post(
        url=url,
        headers=headers,
        data=json.dumps(data),
        stream=True,
        timeout=(60000, 60000),
    )
    rsp = ""
    if stream:
        for chunk in response.iter_content(8192):
            try:
                chunk = chunk.decode("utf-8")
                if "data: [DONE]" in chunk:
                    break
                # logger.info(chunk)
                # chunk[6:] strips the leading "data: " SSE prefix.
                stream_rsp = json.loads(chunk[6:])
                rsp += stream_rsp.get("choices")[0]["delta"]["content"]
            except Exception as e:
                # Best-effort: malformed chunks are logged and skipped.
                logger.error(f"error decode:{e} chunk:{chunk}")
    else:
        # print(response.text)
        json_rsp = json.loads(response.text)
        rsp = json_rsp.get("choices")[0]["message"]["content"]
    return rsp
def upload_file(path, input_url_to_inner=True, output_url_to_inner=False):
    """Upload a local file to the backend and return its access URL.

    Logs in with the built-in admin account to obtain a token, then posts the
    file as multipart form data.

    Fixes vs. the original:
    - a failed login now raises immediately instead of logging and then
      crashing later with ``NameError: token``;
    - the uploaded file handle is closed via a context manager instead of
      being leaked.

    ``input_url_to_inner`` / ``output_url_to_inner`` are unused here and kept
    for interface compatibility.

    Raises:
        Exception: on login failure or when the upload response has no data.
    """
    # --- login to obtain an access token ---
    login_data = {"username": "admin", "password": "admin@jpai.com"}
    login_url = f"{base_backend_url}/admin-api/system/auth/login"
    response = requests.post(
        url=login_url,
        headers={"Content-Type": "application/json"},
        data=json.dumps(login_data),
    )
    try:
        token = json.loads(response.text).get("data").get("accessToken")
    except Exception as e:
        logger.error(f"后端登录异常:{e}")
        raise Exception(f"后端登录异常:{e}") from e
    # logger.info(f"上传Token: {token}")
    # --- upload the file ---
    upload_url = f"{base_backend_url}/admin-api/infra/file/upload"
    with open(path, "rb") as file_handle:
        m = MultipartEncoder(fields={"file": (os.path.basename(path), file_handle)})
        response = requests.post(
            url=upload_url,
            headers={"Content-Type": m.content_type, "Authorization": token},
            data=m,
        )
    res = json.loads(response.text).get("data")
    if res:
        return res
    raise Exception(f"上传{path}失败 Response text: {response.text}")
# Download a URL to a local path.
def download_file(url, path, input_url_to_inner=True):
    """Download ``url`` to local file ``path``.

    Relative URLs are prefixed with the FastGPT base URL; outer backend hosts
    are rewritten to the inner backend host. Failures are logged, not raised
    (best-effort semantics preserved from the original).
    """
    # Bug fix: the old check only matched "http:", so absolute https:// URLs
    # were wrongly treated as relative and prefixed with the FastGPT base.
    if not url.startswith(("http://", "https://")):
        url = base_fastgpt_url + url
    url = url.replace(outer_backend_url, base_backend_url)
    logger.info(f"url准备下载:{url}")
    # Issue the HTTP request for the file content.
    response = requests.get(url)
    if response.status_code == 200:
        # Write the response body to the local path.
        with open(path, "wb") as f:
            f.write(response.content)
        logger.info(f"{url}文件下载成功,保存到{path}")
    else:
        logger.error(f"{url}文件下载失败. HTTP Status Code: {response.status_code}")
def url_replace_fastgpt(origin: str):
    """Prefix relative URLs with the FastGPT base URL; absolute URLs pass through.

    Bug fix: https:// URLs are now also recognized as absolute — the old
    ``startswith("http:")`` check would wrongly prefix them.
    """
    if not origin.startswith(("http://", "https://")):
        origin = base_fastgpt_url + origin
    return origin
if __name__ == "__main__":
# d = '/home/ccran/file.docx'
d = "file.docx"
print(os.path.basename(d))
from loguru import logger
from openai import AsyncOpenAI
from dataclasses import dataclass
from tenacity import retry, stop_after_attempt, stop_after_delay, wait_fixed
import asyncio
@dataclass
class LLMConfig:
    """Connection settings for an OpenAI-compatible chat endpoint."""
    base_url: str
    api_key: str
    model: str
class OpenAITool:
    """Thin async wrapper around an OpenAI-compatible chat endpoint with
    bounded-concurrency fan-out via ``mul_chat``."""
    def __init__(self, llm_config: LLMConfig, max_workers: int = 5):
        # max_workers bounds concurrent in-flight requests in mul_chat().
        self.max_workers = max_workers
        self.llm_config = llm_config
        self.client = AsyncOpenAI(
            base_url=llm_config.base_url, api_key=llm_config.api_key
        )
    # NOTE(review): stop_after_attempt(1) means the call is never retried —
    # confirm whether retries were actually intended here.
    @retry(stop=stop_after_delay(600) | stop_after_attempt(1), wait=wait_fixed(1))
    async def chat(self, msg, tools=None):
        """Send one chat completion request.

        Without tools: a leading system message is stripped from ``msg`` and
        forwarded via extra_body['variables']['system'] (FastGPT-style), and
        the assistant's text content is returned.
        With tools: returns the raw tool_calls from the first choice.
        """
        if tools is None:
            extra_body = None
            if msg[0]['role'] == 'system':
                extra_body = {
                    'variables': {
                        'system': msg[0]['content']
                    }
                }
                msg = msg[1:]
            response = await self.client.chat.completions.create(
                model=self.llm_config.model,
                messages=msg,
                extra_body=extra_body
            )
            content = response.choices[0].message.content
            # Extracted but currently unused; kept for debugging/inspection.
            reasoning_content = response.choices[0].message.model_extra.get(
                "reasoning_content", ""
            )
            return content
        else:
            response = await self.client.chat.completions.create(
                model=self.llm_config.model,
                messages=msg,
                tools=tools,
                tool_choice="auto",
            )
            return response.choices[0].message.tool_calls
    async def mul_chat(self, msgs, tools=None):
        """Fan out ``chat`` over many message lists, limited to
        ``max_workers`` concurrent requests; results keep input order."""
        sem = asyncio.Semaphore(self.max_workers)
        async def _wrapped(m):
            async with sem:
                return await self.chat(m, tools)
        return await asyncio.gather(*[_wrapped(m) for m in msgs])
\ No newline at end of file
from spire.doc import Document, Paragraph, Table, Comment, CommentMark, CommentMarkType
import json
from loguru import logger
import re
from thefuzz import fuzz
from utils.doc_util import DocBase
from utils.common_util import adjust_single_chunk_size
import os
def extract_table_cells_text(table, joiner="\n"):
    """
    Extract the text of every cell in a Spire.Doc Table, returned as a flat,
    row-major list: ["r0c0_text", "r0c1_text", "r1c0_text", ...].
    joiner: separator used to join multiple paragraphs (or nested-table rows)
    inside a cell (defaults to newline).
    Note: text is returned verbatim — no cleaning or stripping is applied.
    """
    def _para_text(para):
        # Prefer para.Text (verbatim); otherwise collect Text-like fields from
        # para.ChildObjects.
        try:
            if hasattr(para, "Text"):
                return para.Text if para.Text is not None else ""
        except Exception:
            pass
        parts = []
        try:
            for idx in range(para.ChildObjects.Count):
                obj = para.ChildObjects[idx]
                if hasattr(obj, "Text"):
                    parts.append(obj.Text if obj.Text is not None else "")
        except Exception:
            pass
        return "".join(parts)
    def _extract_cell_text(cell):
        parts = []
        # Collect every paragraph's text inside the cell (verbatim, no strip).
        try:
            for p_idx in range(cell.Paragraphs.Count):
                para = cell.Paragraphs[p_idx]
                parts.append(_para_text(para))
        except Exception:
            pass
        # Handle nested tables (if any): each nested row is merged into one
        # string and appended to parts row by row.
        try:
            if hasattr(cell, "Tables") and cell.Tables.Count > 0:
                for t_idx in range(cell.Tables.Count):
                    nested = cell.Tables[t_idx]
                    nested_rows = []
                    for nr in range(nested.Rows.Count):
                        nested_row_cells = []
                        for nc in range(nested.Rows[nr].Cells.Count):
                            try:
                                # Join all paragraphs of the nested cell with
                                # joiner (kept verbatim).
                                nc_parts = []
                                for np_idx in range(
                                    nested.Rows[nr].Cells[nc].Paragraphs.Count
                                ):
                                    nc_parts.append(
                                        _para_text(
                                            nested.Rows[nr].Cells[nc].Paragraphs[np_idx]
                                        )
                                    )
                                nested_row_cells.append(joiner.join(nc_parts))
                            except Exception:
                                nested_row_cells.append("")
                        nested_rows.append(joiner.join(nested_row_cells))
                    parts.append(joiner.join(nested_rows))
            else:
                # Nested tables sometimes live in cell.ChildObjects instead;
                # handle that case for compatibility.
                try:
                    for idx in range(cell.ChildObjects.Count):
                        ch = cell.ChildObjects[idx]
                        if hasattr(ch, "Rows") and getattr(ch, "Rows") is not None:
                            nested = ch
                            nested_rows = []
                            for nr in range(nested.Rows.Count):
                                nested_row_cells = []
                                for nc in range(nested.Rows[nr].Cells.Count):
                                    try:
                                        nc_parts = []
                                        for np_idx in range(
                                            nested.Rows[nr].Cells[nc].Paragraphs.Count
                                        ):
                                            nc_parts.append(
                                                _para_text(
                                                    nested.Rows[nr]
                                                    .Cells[nc]
                                                    .Paragraphs[np_idx]
                                                )
                                            )
                                        nested_row_cells.append(joiner.join(nc_parts))
                                    except Exception:
                                        nested_row_cells.append("")
                                nested_rows.append(joiner.join(nested_row_cells))
                            parts.append(joiner.join(nested_rows))
                except Exception:
                    pass
        except Exception:
            pass
        # Join the collected fragments into the final cell string (no trim/clean).
        return joiner.join(parts)
    flat = []
    for r in range(table.Rows.Count):
        row = table.Rows[r]
        for c in range(row.Cells.Count):
            cell = row.Cells[c]
            cell_text = _extract_cell_text(cell)
            # Empty cells yield "" — kept as-is.
            flat.append(cell_text)
    return flat
def process_string(s):
    """Pick a representative line from a possibly multi-line string.

    - no newline: return ``s`` unchanged
    - exactly one newline: return the longer half (the first half wins ties)
    - several newlines: return the longest interior line; if there are no
      interior lines, fall back to the longest non-empty line ("" when every
      line is empty)
    """
    newline_total = s.count("\n")
    if newline_total == 0:
        return s
    if newline_total == 1:
        first_half, second_half = s.split("\n", 1)
        return first_half if len(first_half) >= len(second_half) else second_half
    pieces = s.split("\n")
    # Interior lines exclude the first and last pieces.
    interior = pieces[1:-1] if len(pieces) > 2 else []
    if interior:
        return max(interior, key=len, default="")
    non_empty = [piece for piece in pieces if piece]
    return max(non_empty, key=len) if non_empty else ""
def build_mapping(original: str):
    """Build a whitespace-normalized copy of ``original`` plus an index map.

    Runs of whitespace collapse into single spaces. ``mapping[i]`` is the
    index in ``original`` of the character behind position ``i`` of the
    normalized text; inserted separator spaces map to the start of the word
    that follows them.
    """
    normalized_chars = []
    index_map = []
    for word_match in re.finditer(r"\S+", original):
        word_start = word_match.start()
        if normalized_chars:
            normalized_chars.append(" ")
            index_map.append(word_start)  # separator points at the next word
        for offset, ch in enumerate(word_match.group()):
            normalized_chars.append(ch)
            index_map.append(word_start + offset)
    return "".join(normalized_chars), index_map
def extract_match(big_text: str, small_text: str, threshold=20):
    """
    Simplified text matcher.
    Strategy: exact containment first, then whole-block fuzzy match, then the
    best-scoring clause after a naive split. Returns (matched_text_or_None,
    score in 0-100).
    """
    # 1. Exact containment of the whole block
    if small_text in big_text:
        return small_text, 100
    # 2. Fuzzy match against the whole block
    full_score = fuzz.ratio(big_text, small_text)
    if full_score >= threshold:
        return big_text, full_score
    # 3. Clause-level matching (naive split)
    best_score = 0
    best_clause = None
    # Split on full-width period/comma by mapping them to the semicolon first
    for clause in big_text.replace("。", ";").replace(",", ";").split(";"):
        if not clause.strip():
            continue
        clause_score = fuzz.ratio(clause, small_text)
        if clause_score > best_score:
            best_score = clause_score
            best_clause = clause
    # 4. Best clause wins if it clears the (quite permissive) threshold
    if best_score >= threshold:
        return best_clause, best_score
    # 5. No acceptable match
    return None, max(full_score, best_score)
def find_best_match(sub_chunks, comment):
    """
    Find the text best matching a review comment within the given chunks.
    Args:
        sub_chunks: document child objects; only Paragraph instances are scanned
        comment: dict containing "original_text"
    Returns:
        (best_match, best_score) — the highest-scoring matched text and its
        similarity. (all_results is collected internally but not returned.)
    """
    all_results = []  # every (matched_text, score) pair, kept for debugging
    best_match = None  # best matching text so far
    best_score = -1  # highest similarity so far (-1 = none yet)
    print(f"开始处理评论: {comment['original_text'][:30]}...")  # 显示简化的原始评论
    for obj in sub_chunks:
        if isinstance(obj, Paragraph):
            target_text = obj.Text
            original_text = comment["original_text"]
            match_text, score = extract_match(target_text, original_text)
            # print("匹配到:\n", match_text)
            # print("相似度:", score)
            # Record every result
            all_results.append((match_text, score))
            # Keep only strictly better scores
            if score > best_score:
                best_match = match_text
                best_score = score
    # print("\n" + "=" * 40)
    # print("\n处理完成 - 最佳匹配结果:")
    # print("匹配到:\n", best_match)
    # print("相似度:", best_score)
    # print("=" * 40 + "\n")
    return best_match, best_score
def table_contract(target_texts, comment):
    """
    Find the text best matching a review comment within a list of strings
    (typically table cell texts).
    Args:
        target_texts: candidate strings to compare against
        comment: dict containing "original_text"
    Returns:
        (best_match, best_score) — the highest-scoring matched text and its
        similarity. (all_results is collected internally but not returned.)
    """
    all_results = []  # every (matched_text, score) pair, kept for debugging
    best_match = None  # best matching text so far
    best_score = -1  # highest similarity so far (-1 = none yet)
    print(f"开始处理评论: {comment['original_text'][:30]}...")  # 显示简化的原始评论
    original_text = comment["original_text"]
    for target_text in target_texts:
        match_text, score = extract_match(target_text, original_text)
        # print("匹配到:\n", match_text)
        # print("相似度:", score)
        # Record every result
        all_results.append((match_text, score))
        # Keep only strictly better scores
        if score > best_score:
            best_match = match_text
            best_score = score
    # print("\n" + "=" * 40)
    # print("\n处理完成 - 最佳匹配结果:")
    # print("匹配到:\n", best_match)
    # print("相似度:", best_score)
    # print("=" * 40 + "\n")
    return best_match, best_score
# spire doc解析
class SpireWordDoc(DocBase):
    def load(self, doc_path, **kwargs):
        """Load a Word document from ``doc_path`` and pre-split it into chunks."""
        # License.SetLicenseFileFullPath(f"{root_path}/license.elic.python.xml")
        self._doc_path = doc_path
        self._doc_name = os.path.basename(doc_path)
        self._doc = Document()
        self._doc.LoadFromFile(doc_path)
        self._chunk_list = self._resolve_doc_chunk()
        return self
    def _ensure_loaded(self):
        """Raise if load() has not been called yet."""
        if not self._doc:
            raise RuntimeError("Document not loaded. Call load() first.")
    def adjust_chunk_size(self):
        """Recompute the per-chunk size from the total text length and
        re-split the document; returns the new chunk size."""
        self._ensure_loaded()
        all_text_len = len(self.get_all_text())
        self._max_single_chunk_size = adjust_single_chunk_size(all_text_len)
        logger.info(
            f"SpireWordDoc adjust _max_single_chunk_size to {self._max_single_chunk_size}"
        )
        self._chunk_list = self._resolve_doc_chunk()
        return self._max_single_chunk_size
    async def get_from_ocr(self):
        """OCR fallback is not implemented for this backend."""
        pass
    # Split the document into chunks.
    def _resolve_doc_chunk(self):
        """Walk every section's child objects (paragraphs and tables) and pack
        their text into chunks of at most ``_max_single_chunk_size`` chars,
        recording each chunk's (section_idx, section_child_idx) locations."""
        self._ensure_loaded()
        chunk_list = []
        # Text accumulated for the current chunk
        single_chunk = ""
        # Location entries for the current chunk
        single_chunk_location = []
        # Iterate over document sections
        for section_idx in range(self._doc.Sections.Count):
            current_section = self._doc.Sections.get_Item(section_idx)
            # Iterate over the section body's child objects
            for section_child_idx in range(current_section.Body.ChildObjects.Count):
                # Fetch the child object
                child_obj = current_section.Body.ChildObjects.get_Item(
                    section_child_idx
                )
                # Paragraph handling
                current_child_text = ""
                if isinstance(child_obj, Paragraph):
                    paragraph = child_obj
                    current_child_text = paragraph.Text
                # Table handling (rendered as markdown)
                elif isinstance(child_obj, Table):
                    table = child_obj
                    current_child_text = self._resolve_table(table)
                # Skip other non-text child objects
                else:
                    continue
                # Flush the current chunk before it would overflow
                if (
                    len(single_chunk) + len(current_child_text)
                    > self._max_single_chunk_size
                ):
                    chunk_list.append(
                        {
                            "chunk_content": single_chunk,
                            "chunk_location": single_chunk_location,
                        }
                    )
                    single_chunk = ""
                    single_chunk_location = []
                single_chunk += current_child_text + "\n"
                single_chunk_location.append(
                    {"section_idx": section_idx, "section_child_idx": section_child_idx}
                )
        # Flush any remaining partial chunk
        if len(single_chunk):
            chunk_list.append(
                {"chunk_content": single_chunk, "chunk_location": single_chunk_location}
            )
        return chunk_list
    # Render a table as a markdown pipe-table string.
    def _resolve_table(self, table):
        """Convert a Spire.Doc Table to markdown; the first row is treated as
        the header and is followed by a separator row."""
        table_data = ""
        for i in range(0, table.Rows.Count):
            # Walk the row's cells
            cell_list = []
            for j in range(0, table.Rows.get_Item(i).Cells.Count):
                # Concatenate all paragraph texts within the cell
                cell = table.Rows.get_Item(i).Cells.get_Item(j)
                cell_content = ""
                for para_idx in range(cell.Paragraphs.Count):
                    paragraph_text = cell.Paragraphs.get_Item(para_idx).Text
                    cell_content += paragraph_text
                cell_list.append(cell_content)
            table_data += "|" + "|".join(cell_list) + "|"
            table_data += "\n"
            if i == 0:
                table_data += "|" + "|".join(["--- " for _ in cell_list]) + "|\n"
        return table_data
def get_chunk_info(self, chunk_id):
chunk = self._chunk_list[chunk_id]
chunk_content = chunk["chunk_content"]
chunk_location = chunk["chunk_location"]
from_location = f"[第{chunk_location[0]['section_idx'] + 1}节的第{chunk_location[0]['section_child_idx'] + 1}段落]"
to_location = f"[第{chunk_location[-1]['section_idx'] + 1}节的第{chunk_location[-1]['section_child_idx'] + 1}段落]"
chunk_content_tips = (
"[" + chunk_content[:20] + "]...到...[" + chunk_content[-20:] + "]"
)
return f"文件块id: {chunk_id + 1}\n文件块位置: 从{from_location}到{to_location}\n文件块简述: {chunk_content_tips}\n"
    def get_chunk_location(self, chunk_id):
        # NOTE(review): despite the name this returns the full chunk
        # description from get_chunk_info (which includes the location),
        # not a bare location value.
        """Alias of :meth:`get_chunk_info`."""
        return self.get_chunk_info(chunk_id)
def get_chunk_num(self):
self._ensure_loaded()
return len(self._chunk_list)
def get_chunk_item(self, chunk_id):
self._ensure_loaded()
return self._chunk_list[chunk_id]["chunk_content"]
# 根据locations获取数据
def get_sub_chunks(self, chunk_id):
if chunk_id >= len(self._chunk_list):
logger.error(f"get_sub_chunks_error:{chunk_id}")
return []
chunk = self._chunk_list[chunk_id]
chunk_locations = chunk["chunk_location"]
return [
self._doc.Sections.get_Item(loc["section_idx"]).Body.ChildObjects.get_Item(
loc["section_child_idx"]
)
for loc in chunk_locations
]
def format_comment_author(self, comment):
return "{}|{}".format(str(comment["id"]), comment["key_points"])
def remove_comment_prefix(
self,
):
for i in range(self._doc.Comments.Count):
current_comment = self._doc.Comments.get_Item(i)
comment_author = current_comment.Format.Author
split_author = comment_author.split("|")
if len(split_author) == 2:
current_comment.Format.Author = comment_author.split("|")[1]
# 根据text_selection批注
def set_comment_by_text_selection(self, text_sel, author, comment_content):
if text_sel is None:
return False
# 将找到的文本作为文本范围,并获取其所属的段落
range = text_sel.GetAsOneRange()
paragraph = range.OwnerParagraph
if paragraph is None:
return False
# 创建一个评论对象并设置评论的内容和作者
comment = Comment(self._doc)
comment.Body.AddParagraph().Text = comment_content
comment.Format.Author = author
logger.info(author)
# 将评论添加到段落中
paragraph.ChildObjects.Insert(
paragraph.ChildObjects.IndexOf(range) + 1, comment
)
# 创建评论起始标记和结束标记,并将它们设置为创建的评论的起始标记和结束标记
commentStart = CommentMark(self._doc, CommentMarkType.CommentStart)
commentEnd = CommentMark(self._doc, CommentMarkType.CommentEnd)
commentStart.CommentId = comment.Format.CommentId
commentEnd.CommentId = comment.Format.CommentId
# 在找到的文本之前和之后插入创建的评论起始和结束标记
paragraph.ChildObjects.Insert(
paragraph.ChildObjects.IndexOf(range), commentStart
)
paragraph.ChildObjects.Insert(
paragraph.ChildObjects.IndexOf(range) + 1, commentEnd
)
return True
    # Add a comment anchored to a whole paragraph.
    def set_comment_by_paragraph(self, paragraph, author, comment_content):
        """Append a comment (content + author) to *paragraph*, bracketing the whole paragraph with comment marks."""
        comment = Comment(self._doc)
        comment.Body.AddParagraph().Text = comment_content
        # Set the comment's author.
        comment.Format.Author = author
        paragraph.ChildObjects.Add(comment)
        # Create start/end marks bound to this comment's id.
        commentStart = CommentMark(self._doc, CommentMarkType.CommentStart)
        commentEnd = CommentMark(self._doc, CommentMarkType.CommentEnd)
        commentStart.CommentId = comment.Format.CommentId
        commentEnd.CommentId = comment.Format.CommentId
        # End mark is appended at the end of the paragraph...
        # paragraph.ChildObjects.Add(commentStart)
        paragraph.ChildObjects.Add(commentEnd)
        # ...and the start mark is inserted at the very beginning.
        paragraph.ChildObjects.Insert(0, commentStart)
# 设置chunk批注
def add_table_comment(
self, table, target_text, comment_text, author="审阅助手", initials="AI"
):
"""
在表格中添加批注
返回是否成功添加
"""
added = False
# 遍历表格所有单元格
for i in range(table.Rows.Count):
row = table.Rows[i]
for j in range(row.Cells.Count):
cell = row.Cells[j]
# 遍历单元格中的段落
for k in range(cell.Paragraphs.Count):
para = cell.Paragraphs[k]
# 在段落中查找目标文本
selection = para.Find(target_text, False, True)
if selection:
# 获取文本范围
text_range = selection.GetAsOneRange()
if text_range is None:
continue
# 获取所属段落
paragraph = text_range.OwnerParagraph
if paragraph is None:
continue
# 创建一个评论对象并设置评论的内容和作者
comment = Comment(self._doc)
comment.Body.AddParagraph().Text = comment_text
comment.Format.Author = author
# 将评论添加到段落中
paragraph.ChildObjects.Insert(
paragraph.ChildObjects.IndexOf(text_range) + 1, comment
)
# 创建评论起始标记和结束标记
commentStart = CommentMark(
self._doc, CommentMarkType.CommentStart
)
commentEnd = CommentMark(self._doc, CommentMarkType.CommentEnd)
commentStart.CommentId = comment.Format.CommentId
commentEnd.CommentId = comment.Format.CommentId
# 在找到的文本之前和之后插入创建的评论起始和结束标记
paragraph.ChildObjects.Insert(
paragraph.ChildObjects.IndexOf(text_range), commentStart
)
paragraph.ChildObjects.Insert(
paragraph.ChildObjects.IndexOf(text_range) + 1, commentEnd
)
added = True
print(f"表格批注添加成功: '{target_text[:20]}...'")
# 添加成功后跳出内层循环
break
# 如果已经添加,跳出单元格循环
if added:
break
# 如果已经添加,跳出行循环
if added:
break
return added
def add_chunk_comment(self, chunk_id, comments):
"""
为chunk添加批注(保证每条评论只批注一次)
"""
if chunk_id is not None:
sub_chunks = self.get_sub_chunks(chunk_id)
for comment in comments:
if comment.get("result") != "不合格":
continue
# update chunk_id
chunk_id = comment.get("chunk_id", -1)
if chunk_id is not None and chunk_id != -1:
sub_chunks = self.get_sub_chunks(chunk_id)
author = self.format_comment_author(comment)
suggest = comment.get("suggest", "")
find_key = comment["original_text"].strip() or comment["key_points"]
# 先检查是否已经有批注
existing_comment_idx = self.find_comment(author)
if existing_comment_idx is not None:
# 已存在批注,则更新内容
self._doc.Comments.get_Item(
existing_comment_idx
).Body.Paragraphs.get_Item(0).Text = suggest
print(f"批注已存在,更新内容: '{find_key[:20]}...'")
continue
matched = False
# ---------- 1. 精确匹配(段落 + 表格) ----------
for obj in sub_chunks:
if isinstance(obj, Paragraph):
try:
text_sel = obj.Find(find_key, False, True)
if text_sel and self.set_comment_by_text_selection(
text_sel, author, suggest
):
# print(f"段落批注添加成功: '{find_key[:20]}...'")
matched = True
break
except Exception as e:
print(f"段落批注添加失败: {str(e)}")
elif isinstance(obj, Table):
try:
if self.add_table_comment(obj, find_key, suggest, author):
print(f"表格批注添加成功: '{find_key[:20]}...'")
matched = True
break
except Exception as e:
print(f"表格批注添加失败: {str(e)}")
# ---------- 2. 模糊匹配 ----------
if not matched:
try:
paragraphs_only = [
obj for obj in sub_chunks if isinstance(obj, Paragraph)
]
match_text, _ = find_best_match(paragraphs_only, comment)
if match_text:
for obj in paragraphs_only:
text_sel = obj.Find(match_text, False, True)
if text_sel and self.set_comment_by_text_selection(
text_sel, author, suggest
):
print(f"模糊批注添加成功: '{match_text[:20]}...'")
matched = True
break
if not matched:
processed_text = process_string(match_text)
for obj in paragraphs_only:
text_sel = obj.Find(processed_text, False, True)
if text_sel and self.set_comment_by_text_selection(
text_sel, author, suggest
):
print(
f"处理后批注添加成功: '{processed_text[:20]}...'"
)
matched = True
break
# 表格模糊匹配(仅段落模糊匹配失败才跑)
if not matched:
for obj in sub_chunks:
if isinstance(obj, Table):
table_data = extract_table_cells_text(obj)
best_table_match, _ = table_contract(
table_data, comment
)
if best_table_match and self.add_table_comment(
obj, best_table_match, suggest, author
):
print(
f"表格批注添加成功: '{best_table_match[:20]}...'"
)
matched = True
break
except Exception as e:
print(f"模糊匹配失败: {str(e)}")
# ---------- 3. 匹配最终失败 ----------
if not matched:
logger.error(f"未找到可批注位置: '{find_key[:20]}...'")
# 根据作者名称查找批注
def find_comment(self, author):
for i in range(self._doc.Comments.Count):
current_comment = self._doc.Comments.get_Item(i)
comment_author = current_comment.Format.Author
if comment_author == author:
return i
return None
def delete_chunk_comment(self, comments):
"""
删除指定作者批注
"""
for comment in comments:
author = self.format_comment_author(comment)
author_comment_idx = self.find_comment(author)
if author_comment_idx is not None:
self._doc.Comments.RemoveAt(author_comment_idx)
print(f"删除批注: '{author}'")
    def edit_chunk_comment(self, comments):
        """
        Reconcile document comments with review results: delete comments whose
        item is now "合格" (passed), update the text of comments that already
        exist, and create new comments for failed items that have none yet.
        """
        for comment in comments:
            author = self.format_comment_author(comment)
            review_answer = comment["result"]
            existing_comment_idx = self.find_comment(author)
            if review_answer == "合格":
                # Item passed review: drop its comment if present.
                if existing_comment_idx is not None:
                    self._doc.Comments.RemoveAt(existing_comment_idx)
                    print(f"已删除合格批注: '{author}'")
            else:
                # Item failed: update the existing comment or add a new one.
                suggest = comment.get("suggest", "")
                if existing_comment_idx is not None:
                    self._doc.Comments.get_Item(
                        existing_comment_idx
                    ).Body.Paragraphs.get_Item(0).Text = suggest
                    print(f"更新已有批注: '{author}'")
                else:
                    # The comment's chunk_id is 1-based here; convert to the
                    # 0-based index add_chunk_comment expects.
                    self.add_chunk_comment(comment["chunk_id"] - 1, [comment])
def get_chunk_id_list(self, step=1):
self._ensure_loaded()
return [idx + 1 for idx in range(0, self.get_chunk_num(), step)]
def get_all_text(self):
self._ensure_loaded()
return self._doc.GetText()
def to_file(self, path, remove_prefix=False):
self._ensure_loaded()
if remove_prefix:
self.remove_comment_prefix()
self._doc.SaveToFile(path)
    def release(self):
        """Close the underlying document and release base-class resources."""
        # Close the file handle if a document is loaded.
        if self._doc:
            self._doc.Close()
        super().release()
    def __del__(self):
        # Intentionally a no-op: callers must invoke release() explicitly.
        # NOTE(review): the commented call suggests finalizer-time release was
        # deliberately disabled — confirm before re-enabling.
        pass
        # self.release()
# Manual smoke test: load a sample contract, add one failing-review comment
# to the first chunk, and save the annotated copy.  Paths are specific to a
# developer machine.
if __name__ == "__main__":
    doc = SpireWordDoc()
    doc.load(
        r"/home/ccran/lufa-contract/datasets/2-1.合同审核-近三年审核前的合同文件/20230101 麓谷发展视频合同书20230209(1).doc"
    )
    print(doc._doc_name)
    # print(doc.get_chunk_info(4))
    doc.add_chunk_comment(
        0,
        [
            {
                "id": "1",
                "key_points": "主体资格审查",
                "original_text": "湖南麓谷发展集团有限公司",
                "details": "1111",
                "chunk_id": 1,
                "result": "不合格",
                "suggest": "这是测试建议",
            }
        ],
    )
    doc.to_file("test.docx", True)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment