Commit d26c53e1 by ccran

feat: add readme.md

parent 5f18aa67
...@@ -7,5 +7,7 @@ ...@@ -7,5 +7,7 @@
# Keep Python source files # Keep Python source files
!**/*.py !**/*.py
!README.md
# Keep this file tracked # Keep this file tracked
!.gitignore !.gitignore
\ No newline at end of file
# 合同审查智能体 (Contract Review Agent)
一个基于 FastAPI 和大型语言模型 (LLM) 的智能合同审查系统,能够自动分析合同条款、识别风险并提供审查建议。
## 📋 项目概述
本项目是一个合同审查智能体,通过以下流程实现合同的自动化审查:
1. **文档解析** - 支持多种格式的合同文档解析
2. **分段处理** - 将合同按规则智能分段
3. **事实提取** - 从每个分段中提取与审查规则相关的客观事实
4. **规则审查** - 基于预设规则对提取的事实进行审查
5. **风险复核** - 对审查结果进行反思和复核
6. **结果合并** - 合并所有分段审查结果生成最终报告
## 🏗️ 项目结构
```
lufa-contract/
├── main.py # FastAPI 主应用入口
├── test.py # 测试脚本
├── core/ # 核心业务逻辑
│ ├── cache.py # 缓存管理
│ ├── config.py # 配置管理
│ ├── memory.py # 记忆/状态管理
│ ├── tool.py # 工具基类
│ └── tools/ # 具体工具实现
│ ├── segment_summary.py # 分段事实提取
│ ├── segment_review.py # 分段规则审查
│ ├── segment_rule_router.py # 规则路由
│ ├── retrieve_reference.py # 参考检索
│ ├── reflect_retry.py # 反思重试
│ └── segment_merger.py # 结果合并
├── data/ # 数据文件
│ ├── rules.xlsx # 审查规则表
│ ├── batch/ # 批量处理数据
│ └── benchmark/ # 基准测试数据
├── utils/ # 工具函数
│ ├── common_util.py # 通用工具
│ ├── http_util.py # HTTP 工具
│ └── doc_util.py # 文档工具
├── demo/ # 演示文件
├── tmp/ # 临时文件
└── .vscode/ # VSCode 配置
```
## 🔧 技术栈
- **后端框架**: FastAPI
- **LLM 服务**: Qwen2-72B-Instruct (可配置)
- **文档处理**: 支持 PDF、Word 等多种格式
- **日志**: Loguru
- **数据验证**: Pydantic
## 📦 核心功能
### 1. 分段事实提取 (SegmentSummary)
基于审查规则从合同分段中提取客观事实,确保:
- 事实可在原文中直接找到
- 不做抽象、概括或推断
- 不补充未出现的主体、条件或数值
### 2. 分段规则审查 (SegmentReview)
对提取的事实进行规则匹配和风险分析,输出:
- 风险等级 (H/M/L)
- 审查结论
- 修改建议
### 3. 反思重试 (ReflectRetry)
对审查结果进行自我反思,识别潜在问题并重试。
### 4. 结果合并 (SegmentMerger)
合并所有分段的审查结果,生成完整的审查报告。
## ⚙️ 配置说明
在 `core/config.py` 中可配置:
```python
# LLM 配置
LLMConfig:
base_url: "http://192.168.252.71:9002/v1"
model: "Qwen2-72B-Instruct"
# 审查规则集
ALL_RULESET_IDS = ["通用", "借款", "担保", "财务口", "金盘", "金盘简化"]
# 分段大小控制
MAX_SINGLE_CHUNK_SIZE = 5000
```
## 🚀 快速开始
### 1. 安装依赖
```bash
pip install fastapi uvicorn pydantic loguru openpyxl
```
### 2. 启动服务
```bash
python main.py
```
服务将在 `http://localhost:8000` 启动。
### 3. API 端点
- `POST /sleep` - 测试端点
- `POST /document/parse` - 解析合同文档
- `POST /contract/review` - 执行合同审查
- `GET /contract/{conversation_id}/result` - 获取审查结果
## 📝 使用示例
### 提交合同审查请求
```python
import requests
# 上传合同文档
response = requests.post(
"http://localhost:8000/document/parse",
json={
"conversation_id": "unique-conversation-id",
"file_url": "http://example.com/contract.pdf",
"ruleset_id": "通用"
}
)
# 获取审查结果
result = requests.get(
f"http://localhost:8000/contract/{response.json()['conversation_id']}/result"
)
```
## 🔐 安全说明
- API Key 配置在 `core/config.py`
- 支持内外网环境切换 (`use_lufa` 参数)
- 临时文件自动清理
## 📊 数据格式
### 审查结果结构
```json
{
"conversation_id": "xxx",
"findings": [
{
"segment_id": "seg_001",
"rule_id": "rule_001",
"risk_level": "H",
"fact": "提取的事实",
"conclusion": "审查结论",
"suggestion": "修改建议"
}
]
}
```
## 🛠️ 开发指南
### 添加新的审查规则
1. 在 `data/rules.xlsx` 中添加新规则
2. 更新 `core/config.py` 中的规则集配置
3. 重启服务
### 自定义 LLM 模型
修改 `core/config.py` 中的 `LLMConfig`:
```python
LLMConfig:
base_url: "你的 LLM 服务地址"
model: "你的模型名称"
```
## 📄 许可证
内部使用,保留所有权利。
## 👥 维护者
- 开发团队
## 📞 联系方式
如有问题,请联系项目维护团队。
...@@ -17,7 +17,7 @@ from core.config import META_KEY ...@@ -17,7 +17,7 @@ from core.config import META_KEY
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_ALLOWED_RISK_LEVELS = {"H", "M", "L",""} _ALLOWED_RISK_LEVELS = {"H", "M", "L", ""}
FINDING_KEY_REVIEW = "review" FINDING_KEY_REVIEW = "review"
FINDING_KEY_REFLECT = "reflect" FINDING_KEY_REFLECT = "reflect"
FINDING_KEY_MERGE = "merge" FINDING_KEY_MERGE = "merge"
...@@ -44,7 +44,9 @@ class Finding: ...@@ -44,7 +44,9 @@ class Finding:
def __post_init__(self) -> None: def __post_init__(self) -> None:
level = (self.risk_level or "").upper() level = (self.risk_level or "").upper()
if level not in _ALLOWED_RISK_LEVELS: if level not in _ALLOWED_RISK_LEVELS:
raise ValueError(f"risk_level must be one of {_ALLOWED_RISK_LEVELS}, got {self.risk_level}") raise ValueError(
f"risk_level must be one of {_ALLOWED_RISK_LEVELS}, got {self.risk_level}"
)
self.risk_level = level self.risk_level = level
@classmethod @classmethod
...@@ -72,10 +74,9 @@ class Finding: ...@@ -72,10 +74,9 @@ class Finding:
class MemoryStore: class MemoryStore:
"""简化的记忆存储:合同事实 facts 与问题 findings。线程安全并支持 JSON 持久化。""" """简化的记忆存储:合同事实 facts 与问题 findings。线程安全并支持 JSON 持久化。"""
storage_name: Optional[Path] = 'default.json' storage_name: Optional[Path] = "default.json"
def __init__(self,storage_name:str = 'default.json') -> None: def __init__(self, storage_name: str = "default.json") -> None:
self._storage_path = Path(__file__).resolve().parent.parent / "tmp" / storage_name # type: ignore[arg-type] self._storage_path = Path(__file__).resolve().parent.parent / "tmp" / storage_name # type: ignore[arg-type]
self._storage_path.parent.mkdir(parents=True, exist_ok=True) self._storage_path.parent.mkdir(parents=True, exist_ok=True)
self._lock = RLock() self._lock = RLock()
...@@ -119,12 +120,14 @@ class MemoryStore: ...@@ -119,12 +120,14 @@ class MemoryStore:
for top_key, top_value in item.items(): for top_key, top_value in item.items():
if _key_match(top_key): if _key_match(top_key):
matched_values.append({ matched_values.append(
{
top_key: top_value, top_key: top_value,
META_KEY: item.get(META_KEY, {}) # include metadata if exists META_KEY: item.get(
}) META_KEY, {}
), # include metadata if exists
}
)
return matched_values return matched_values
...@@ -141,8 +144,16 @@ class MemoryStore: ...@@ -141,8 +144,16 @@ class MemoryStore:
def delete_findings_by_segment(self, key: str, segment_id: int) -> int: def delete_findings_by_segment(self, key: str, segment_id: int) -> int:
return self._delete_findings_by_segment(key, segment_id) return self._delete_findings_by_segment(key, segment_id)
def search_findings(self, key: str, keyword: str, rule_title: Optional[str] = None, risk_level: Optional[str] = None) -> List[Finding]: def search_findings(
return self._search_findings(self._get_findings_bucket(key), keyword, rule_title, risk_level) self,
key: str,
keyword: str,
rule_title: Optional[str] = None,
risk_level: Optional[str] = None,
) -> List[Finding]:
return self._search_findings(
self._get_findings_bucket(key), keyword, rule_title, risk_level
)
def list_findings_grouped(self) -> Dict[str, List[Finding]]: def list_findings_grouped(self) -> Dict[str, List[Finding]]:
with self._lock: with self._lock:
...@@ -166,7 +177,9 @@ class MemoryStore: ...@@ -166,7 +177,9 @@ class MemoryStore:
with self._lock: with self._lock:
return list(target) return list(target)
def _get_findings_by_segment(self, target: List[Finding], segment_id: int) -> List[Finding]: def _get_findings_by_segment(
self, target: List[Finding], segment_id: int
) -> List[Finding]:
with self._lock: with self._lock:
return [f for f in target if f.segment_id == segment_id] return [f for f in target if f.segment_id == segment_id]
...@@ -192,21 +205,29 @@ class MemoryStore: ...@@ -192,21 +205,29 @@ class MemoryStore:
with self._lock: with self._lock:
candidates = list(target) candidates = list(target)
if rule_title: if rule_title:
candidates = [f for f in candidates if (f.rule_title or "").lower() == rule_title.strip().lower()] candidates = [
f
for f in candidates
if (f.rule_title or "").lower() == rule_title.strip().lower()
]
if risk_level: if risk_level:
lvl = risk_level.strip().upper() lvl = risk_level.strip().upper()
candidates = [f for f in candidates if f.risk_level == lvl] candidates = [f for f in candidates if f.risk_level == lvl]
if not key: if not key:
return candidates return candidates
def _matches(f: Finding) -> bool: def _matches(f: Finding) -> bool:
hay = " ".join([ hay = " ".join(
[
f.rule_title, f.rule_title,
f.original_text, f.original_text,
f.issue, f.issue,
f.suggestion, f.suggestion,
f.result, f.result,
]).lower() ]
).lower()
return key in hay return key in hay
return [f for f in candidates if _matches(f)] return [f for f in candidates if _matches(f)]
# ------------------- housekeeping ------------------ # ------------------- housekeeping ------------------
...@@ -225,7 +246,9 @@ class MemoryStore: ...@@ -225,7 +246,9 @@ class MemoryStore:
}, },
} }
try: try:
self._storage_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8") self._storage_path.write_text(
json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8"
)
except Exception as exc: except Exception as exc:
logger.error("Failed to persist memory store: %s", exc) logger.error("Failed to persist memory store: %s", exc)
...@@ -243,7 +266,9 @@ class MemoryStore: ...@@ -243,7 +266,9 @@ class MemoryStore:
if isinstance(loaded_findings, dict): if isinstance(loaded_findings, dict):
for key, items in loaded_findings.items(): for key, items in loaded_findings.items():
normalized_key = self._normalize_finding_key(str(key)) normalized_key = self._normalize_finding_key(str(key))
findings_map[normalized_key] = [Finding.from_dict(item) for item in (items or [])] findings_map[normalized_key] = [
Finding.from_dict(item) for item in (items or [])
]
self.findings = findings_map self.findings = findings_map
needs_persist = False needs_persist = False
...@@ -262,7 +287,9 @@ class MemoryStore: ...@@ -262,7 +287,9 @@ class MemoryStore:
try: try:
from openpyxl import Workbook # type: ignore from openpyxl import Workbook # type: ignore
except ImportError as exc: except ImportError as exc:
raise ImportError("openpyxl is required for export_to_excel; install via 'pip install openpyxl'") from exc raise ImportError(
"openpyxl is required for export_to_excel; install via 'pip install openpyxl'"
) from exc
ts = datetime.now().strftime("%Y%m%d_%H%M%S") ts = datetime.now().strftime("%Y%m%d_%H%M%S")
name = file_name or f"memory_export_{ts}.xlsx" name = file_name or f"memory_export_{ts}.xlsx"
...@@ -285,21 +312,34 @@ class MemoryStore: ...@@ -285,21 +312,34 @@ class MemoryStore:
if grouped_items: if grouped_items:
first_key, first_values = grouped_items[0] first_key, first_values = grouped_items[0]
ws_first = wb.active ws_first = wb.active
first_sheet_name = _FINDING_KEY_SHEET_NAMES.get(self._normalize_finding_key(first_key), first_key) first_sheet_name = _FINDING_KEY_SHEET_NAMES.get(
self._normalize_finding_key(first_key), first_key
)
ws_first.title = self._safe_sheet_name(first_sheet_name) ws_first.title = self._safe_sheet_name(first_sheet_name)
ws_first.append([label for _, label in finding_headers]) ws_first.append([label for _, label in finding_headers])
for f in first_values: for f in first_values:
ws_first.append([getattr(f, key, "") for key, _ in finding_headers]) ws_first.append([getattr(f, key, "") for key, _ in finding_headers])
for key, values in grouped_items[1:]: for key, values in grouped_items[1:]:
sheet_name = _FINDING_KEY_SHEET_NAMES.get(self._normalize_finding_key(key), key) sheet_name = _FINDING_KEY_SHEET_NAMES.get(
self._normalize_finding_key(key), key
)
ws = wb.create_sheet(self._safe_sheet_name(sheet_name)) ws = wb.create_sheet(self._safe_sheet_name(sheet_name))
ws.append([label for _, label in finding_headers]) ws.append([label for _, label in finding_headers])
for f in values: for f in values:
ws.append([getattr(f, item_key, "") for item_key, _ in finding_headers]) ws.append(
[
getattr(f, item_key, "")
for item_key, _ in finding_headers
]
)
else: else:
ws_empty = wb.active ws_empty = wb.active
ws_empty.title = self._safe_sheet_name(_FINDING_KEY_SHEET_NAMES.get(_DEFAULT_REVIEW_KEY, _DEFAULT_REVIEW_KEY)) ws_empty.title = self._safe_sheet_name(
_FINDING_KEY_SHEET_NAMES.get(
_DEFAULT_REVIEW_KEY, _DEFAULT_REVIEW_KEY
)
)
ws_empty.append([label for _, label in finding_headers]) ws_empty.append([label for _, label in finding_headers])
ws_facts = wb.create_sheet("合同事实") ws_facts = wb.create_sheet("合同事实")
...@@ -310,7 +350,12 @@ class MemoryStore: ...@@ -310,7 +350,12 @@ class MemoryStore:
ws_facts.append(["事实", json.dumps(item, ensure_ascii=False)]) ws_facts.append(["事实", json.dumps(item, ensure_ascii=False)])
continue continue
meta_info = item.get(META_KEY, None) meta_info = item.get(META_KEY, None)
ws_facts.append([json.dumps(meta_info, ensure_ascii=False), json.dumps(item, ensure_ascii=False)]) ws_facts.append(
[
json.dumps(meta_info, ensure_ascii=False),
json.dumps(item, ensure_ascii=False),
]
)
else: else:
ws_facts.append(["元信息", "事实内容"]) ws_facts.append(["元信息", "事实内容"])
...@@ -435,16 +480,17 @@ def test_export_findings_to_doc_comments(doc_path: str) -> None: ...@@ -435,16 +480,17 @@ def test_export_findings_to_doc_comments(doc_path: str) -> None:
print("Export doc comments:") print("Export doc comments:")
print(json.dumps(res, ensure_ascii=False, indent=2)) print(json.dumps(res, ensure_ascii=False, indent=2))
def test_memory_and_export_excel(): def test_memory_and_export_excel():
# 简单示例:设置事实 -> 写入问题 -> 读取/搜索 # 简单示例:设置事实 -> 写入问题 -> 读取/搜索
store = MemoryStore() store = MemoryStore()
store.add_facts({ store.add_facts(
{
"公司": {"甲方": "A 公司", "乙方": "B 公司"}, "公司": {"甲方": "A 公司", "乙方": "B 公司"},
"支付": {"方式": "银行转账", "期限": "验收后30日内"}, "支付": {"方式": "银行转账", "期限": "验收后30日内"},
META_KEY:{ META_KEY: {"segment_id": 1},
"segment_id":1
} }
}) )
# print( store.search_facts(['支付'])) # print( store.search_facts(['支付']))
finding1 = Finding( finding1 = Finding(
rule_title="违约责任", rule_title="违约责任",
...@@ -477,4 +523,3 @@ def test_memory_and_export_excel(): ...@@ -477,4 +523,3 @@ def test_memory_and_export_excel():
if __name__ == "__main__": if __name__ == "__main__":
# test_export_findings_to_doc_comments("/home/ccran/lufa-contract/tmp/股份转让协议.docx") # test_export_findings_to_doc_comments("/home/ccran/lufa-contract/tmp/股份转让协议.docx")
test_memory_and_export_excel() test_memory_and_export_excel()
...@@ -9,7 +9,6 @@ from core.tool import ToolBase, tool, tool_func ...@@ -9,7 +9,6 @@ from core.tool import ToolBase, tool, tool_func
from utils.excel_util import ExcelUtil from utils.excel_util import ExcelUtil
@tool("retrieve_reference", "审查参考检索") @tool("retrieve_reference", "审查参考检索")
class RetrieveReferenceTool(ToolBase): class RetrieveReferenceTool(ToolBase):
def __init__(self) -> None: def __init__(self) -> None:
...@@ -22,12 +21,16 @@ class RetrieveReferenceTool(ToolBase): ...@@ -22,12 +21,16 @@ class RetrieveReferenceTool(ToolBase):
"triggers": "触发词", "triggers": "触发词",
"suggestion_template": "建议模板", "suggestion_template": "建议模板",
"case": "案例", "case": "案例",
"summary":"摘要项" "summary": "摘要项",
} }
rules_path = Path(__file__).resolve().parent.parent.parent / "data" / "rules.xlsx" rules_path = (
Path(__file__).resolve().parent.parent.parent / "data" / "rules.xlsx"
)
self.rulesets: Dict[str, List[Dict[str, Any]]] = {} self.rulesets: Dict[str, List[Dict[str, Any]]] = {}
for rs_id in ALL_RULESET_IDS: for rs_id in ALL_RULESET_IDS:
rules = ExcelUtil.load_mapped_excel(rules_path, sheet_name=rs_id, column_map=self.column_map) rules = ExcelUtil.load_mapped_excel(
rules_path, sheet_name=rs_id, column_map=self.column_map
)
self.rulesets[rs_id] = rules self.rulesets[rs_id] = rules
@tool_func( @tool_func(
...@@ -40,13 +43,21 @@ class RetrieveReferenceTool(ToolBase): ...@@ -40,13 +43,21 @@ class RetrieveReferenceTool(ToolBase):
"required": [], "required": [],
} }
) )
def run(self, ruleset_id: str = "", routed_rule_titles: List[str] | None = None) -> Dict[str, Any]: def run(
self, ruleset_id: str = "", routed_rule_titles: List[str] | None = None
) -> Dict[str, Any]:
target_ruleset_id = ruleset_id or self.default_ruleset_id target_ruleset_id = ruleset_id or self.default_ruleset_id
full_rules = self.rulesets.get(target_ruleset_id) or self.rulesets.get(self.default_ruleset_id, []) or [] full_rules = (
self.rulesets.get(target_ruleset_id)
or self.rulesets.get(self.default_ruleset_id, [])
or []
)
if routed_rule_titles is None: if routed_rule_titles is None:
rules = full_rules rules = full_rules
else: else:
title_set = {title for title in routed_rule_titles if isinstance(title, str)} title_set = {
title for title in routed_rule_titles if isinstance(title, str)
}
rules = [r for r in full_rules if r.get("title") in title_set] rules = [r for r in full_rules if r.get("title") in title_set]
return { return {
...@@ -59,6 +70,7 @@ class RetrieveReferenceTool(ToolBase): ...@@ -59,6 +70,7 @@ class RetrieveReferenceTool(ToolBase):
def summary_keywords(self, rules: List[Dict[str, Any]]) -> List[str]: def summary_keywords(self, rules: List[Dict[str, Any]]) -> List[str]:
return [r.get("summary", "") for r in rules if r.get("summary")] return [r.get("summary", "") for r in rules if r.get("summary")]
if __name__ == "__main__": if __name__ == "__main__":
tool = RetrieveReferenceTool() tool = RetrieveReferenceTool()
result = tool.run(ruleset_id="金盘", routed_rule_titles=None) result = tool.run(ruleset_id="金盘", routed_rule_titles=None)
......
...@@ -3,7 +3,7 @@ import os ...@@ -3,7 +3,7 @@ import os
import re import re
import sys import sys
sys.path.append('../..') sys.path.append("../..")
import traceback import traceback
import concurrent.futures import concurrent.futures
...@@ -12,21 +12,21 @@ from loguru import logger ...@@ -12,21 +12,21 @@ from loguru import logger
from utils.common_util import random_str from utils.common_util import random_str
from utils.http_util import upload_file, fastgpt_openai_chat, download_file from utils.http_util import upload_file, fastgpt_openai_chat, download_file
# SUFFIX='_麓发迁移' SUFFIX = "_麓发迁移"
# batch_input_dir_path = 'jp-input' batch_input_dir_path = "jp-input"
# batch_output_dir_path = 'jp-output-lufa-new' batch_output_dir_path = "jp-output-lufa-new"
SUFFIX='_麓发' # SUFFIX = "_麓发"
batch_input_dir_path = 'lufa-input' # batch_input_dir_path = "lufa-input"
batch_output_dir_path = 'lufa-output' # batch_output_dir_path = "lufa-output"
batch_size = 5 batch_size = 5
# 麓发fastgpt接口 # 麓发fastgpt接口
url = 'http://192.168.252.71:18089/api/v1/chat/completions' # url = "http://192.168.252.71:18089/api/v1/chat/completions"
# 金盘fastgpt接口 # 金盘fastgpt接口
# url = 'http://192.168.252.71:18088/api/v1/chat/completions' url = "http://192.168.252.71:18088/api/v1/chat/completions"
# 麓发合同审查生产token # 麓发合同审查生产token
token = 'fastgpt-ek3Z6PxI6sXgYc0jxzZ5bVGqrxwM6aVyfSmA6JVErJYBMr2KmYxrHwEUOIMSYz' # token = "fastgpt-ek3Z6PxI6sXgYc0jxzZ5bVGqrxwM6aVyfSmA6JVErJYBMr2KmYxrHwEUOIMSYz"
# 金盘迁移麓发合同审查测试token # 金盘迁移麓发合同审查测试token
# token = 'fastgpt-vykT6qs07g7hR4tL2MNJE6DdNCIxaQjEu3Cxw9nuTBFg8MAG3CkByvnXKxSNEyMK7' token = "fastgpt-vykT6qs07g7hR4tL2MNJE6DdNCIxaQjEu3Cxw9nuTBFg8MAG3CkByvnXKxSNEyMK7"
# 人机交互测试(测试环境) # 人机交互测试(测试环境)
# token = 'fastgpt-p189K5zoTX5wjp0dBybFCwsbWm3juIwlJxt2wTGyiaOWOANI5Y10pKEZzyt' # token = 'fastgpt-p189K5zoTX5wjp0dBybFCwsbWm3juIwlJxt2wTGyiaOWOANI5Y10pKEZzyt'
# 人机交互测试(生产环境) # 人机交互测试(生产环境)
...@@ -34,9 +34,13 @@ token = 'fastgpt-ek3Z6PxI6sXgYc0jxzZ5bVGqrxwM6aVyfSmA6JVErJYBMr2KmYxrHwEUOIMSYz' ...@@ -34,9 +34,13 @@ token = 'fastgpt-ek3Z6PxI6sXgYc0jxzZ5bVGqrxwM6aVyfSmA6JVErJYBMr2KmYxrHwEUOIMSYz'
# 提取后审查测试 # 提取后审查测试
# token = 'fastgpt-n74gGX5ZqLT6o1ysMBSGUTjIciswYOWDRfQ75krMkE5gDVDkpzsbz8u' # token = 'fastgpt-n74gGX5ZqLT6o1ysMBSGUTjIciswYOWDRfQ75krMkE5gDVDkpzsbz8u'
def extract_url(text): def extract_url(text):
# \s * ([ ^ "\s]+?\.(?:docx?|pdf|xlsx)) # \s * ([ ^ "\s]+?\.(?:docx?|pdf|xlsx))
excel_p, doc_p = r'最终审查Excel\s*([^"]*xlsx)', r'最终审查批注\s*([^\" ]+?\.(?:docx?|pdf|wps))' excel_p, doc_p = (
r'最终审查Excel\s*([^"]*xlsx)',
r"最终审查批注\s*([^\" ]+?\.(?:docx?|pdf|wps))",
)
# 使用 re.search() 查找第一个匹配项 # 使用 re.search() 查找第一个匹配项
excel_m, doc_m = re.search(excel_p, text), re.search(doc_p, text) excel_m, doc_m = re.search(excel_p, text), re.search(doc_p, text)
if excel_m and doc_m: if excel_m and doc_m:
...@@ -46,7 +50,9 @@ def extract_url(text): ...@@ -46,7 +50,9 @@ def extract_url(text):
return None, None return None, None
def process_single_file(file, batch_input_dir_path, batch_output_dir_path, counter, start_file): def process_single_file(
file, batch_input_dir_path, batch_output_dir_path, counter, start_file
):
""" """
单文件处理逻辑,可被线程池并发调用 单文件处理逻辑,可被线程池并发调用
""" """
...@@ -55,29 +61,45 @@ def process_single_file(file, batch_input_dir_path, batch_output_dir_path, count ...@@ -55,29 +61,45 @@ def process_single_file(file, batch_input_dir_path, batch_output_dir_path, count
return return
# 提取文件前缀 # 提取文件前缀
file_name = file[:file.rfind('.')] file_name = file[: file.rfind(".")]
ext_name = file[file.rfind('.'):] ext_name = file[file.rfind(".") :]
# 源目标处理 # 源目标处理
original_file = f'{batch_input_dir_path}/{file}' original_file = f"{batch_input_dir_path}/{file}"
des_check_file = f'{batch_output_dir_path}/{file_name}.md' des_check_file = f"{batch_output_dir_path}/{file_name}.md"
des_excel_file = f'{batch_output_dir_path}/{file_name}{SUFFIX}.xlsx' des_excel_file = f"{batch_output_dir_path}/{file_name}{SUFFIX}.xlsx"
des_doc_file = f'{batch_output_dir_path}/{file_name}{SUFFIX}{ext_name}' des_doc_file = f"{batch_output_dir_path}/{file_name}{SUFFIX}{ext_name}"
try: try:
# 处理原文件 # 处理原文件
file_url = upload_file(original_file, input_url_to_inner=True).replace('218.77.58.8', '192.168.252.71') file_url = upload_file(original_file, input_url_to_inner=True).replace(
model = 'Qwen2-72B-Instruct' "218.77.58.8", "192.168.252.71"
)
model = "Qwen2-72B-Instruct"
# 合同审核Excel工作流处理 # 合同审核Excel工作流处理
logger.info(' 第{}个文件,处理文件: {}'.format(counter, original_file)) logger.info(" 第{}个文件,处理文件: {}".format(counter, original_file))
result = fastgpt_openai_chat(url, token, model, random_str(), file_url, f'测试批处理任务-{file_name}', False) result = fastgpt_openai_chat(
url,
token,
model,
random_str(),
file_url,
f"测试批处理任务-{file_name}",
False,
)
excel_url, doc_url = extract_url(result) excel_url, doc_url = extract_url(result)
if excel_url and doc_url: if excel_url and doc_url:
download_file(excel_url.replace('218.77.58.8', '192.168.252.71'), des_excel_file) download_file(
download_file(doc_url.replace('218.77.58.8', '192.168.252.71'), des_doc_file) excel_url.replace("218.77.58.8", "192.168.252.71"), des_excel_file
logger.info(f'第{counter}个文件下载:{excel_url}到{des_excel_file} {des_doc_file}') )
download_file(
doc_url.replace("218.77.58.8", "192.168.252.71"), des_doc_file
)
logger.info(
f"第{counter}个文件下载:{excel_url}到{des_excel_file} {des_doc_file}"
)
except Exception as e: except Exception as e:
logger.error(f'{original_file} 处理异常 第{counter}个文件: {e}') logger.error(f"{original_file} 处理异常 第{counter}个文件: {e}")
logger.error(traceback.print_exc()) logger.error(traceback.print_exc())
...@@ -103,5 +125,5 @@ def execute_batch(max_workers: int = 4): ...@@ -103,5 +125,5 @@ def execute_batch(max_workers: int = 4):
f.result() f.result()
if __name__ == '__main__': if __name__ == "__main__":
execute_batch(batch_size) execute_batch(batch_size)
...@@ -6,6 +6,7 @@ from contextlib import redirect_stdout, redirect_stderr ...@@ -6,6 +6,7 @@ from contextlib import redirect_stdout, redirect_stderr
fuzz_score_threshold = 80 fuzz_score_threshold = 80
def _normalize_cell(value: object) -> str: def _normalize_cell(value: object) -> str:
if pd.isna(value): if pd.isna(value):
return "" return ""
...@@ -69,12 +70,12 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None: ...@@ -69,12 +70,12 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None:
best_score = -1 best_score = -1
for idx, cand in enumerate(candidates): for idx, cand in enumerate(candidates):
ans_text = ans_text.strip() ans_text = ans_text.strip()
if cand is None or not isinstance(cand,str): if cand is None or not isinstance(cand, str):
continue continue
cand = cand.strip() cand = cand.strip()
score = max( score = max(
fuzz.partial_ratio(ans_text, cand), fuzz.partial_ratio(ans_text, cand),
fuzz.token_set_ratio(ans_text, cand) fuzz.token_set_ratio(ans_text, cand),
) )
if score > best_score: if score > best_score:
best_score = score best_score = score
...@@ -83,7 +84,9 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None: ...@@ -83,7 +84,9 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None:
if best_score >= fuzz_score_threshold: if best_score >= fuzz_score_threshold:
matched_total += 1 matched_total += 1
matched_val = candidates.pop(best_idx) matched_val = candidates.pop(best_idx)
matched_by_item.setdefault(item, []).append((ans_text, matched_val, best_score)) matched_by_item.setdefault(item, []).append(
(ans_text, matched_val, best_score)
)
else: else:
unmatched_answer_by_item.setdefault(item, []).append(ans_text) unmatched_answer_by_item.setdefault(item, []).append(ans_text)
...@@ -103,8 +106,14 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None: ...@@ -103,8 +106,14 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None:
unmatched_answer_count = sum(len(v) for v in unmatched_answer_by_item.values()) unmatched_answer_count = sum(len(v) for v in unmatched_answer_by_item.values())
file_precision = (matched_total / val_total) if val_total != 0 else 0 file_precision = (matched_total / val_total) if val_total != 0 else 0
file_recall = (matched_total / answer_total) if answer_total != 0 else 0 file_recall = (matched_total / answer_total) if answer_total != 0 else 0
file_f1 = (2 * file_precision * file_recall / (file_precision + file_recall)) if (file_precision + file_recall) else 0 file_f1 = (
file_false_positive_rate = (unmatched_val_count / val_total) if val_total != 0 else 0 (2 * file_precision * file_recall / (file_precision + file_recall))
if (file_precision + file_recall)
else 0
)
file_false_positive_rate = (
(unmatched_val_count / val_total) if val_total != 0 else 0
)
# 累加到各“审查项”的全局统计 # 累加到各“审查项”的全局统计
for it, cnt in answer_counts.items(): for it, cnt in answer_counts.items():
...@@ -112,19 +121,28 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None: ...@@ -112,19 +121,28 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None:
for it, lst in matched_by_item.items(): for it, lst in matched_by_item.items():
overall_item_matched[it] = overall_item_matched.get(it, 0) + len(lst) overall_item_matched[it] = overall_item_matched.get(it, 0) + len(lst)
for it, lst in unmatched_answer_by_item.items(): for it, lst in unmatched_answer_by_item.items():
overall_item_unmatched_answer[it] = overall_item_unmatched_answer.get(it, 0) + len(lst) overall_item_unmatched_answer[it] = overall_item_unmatched_answer.get(
it, 0
) + len(lst)
for it, lst in unmatched_val_by_item.items(): for it, lst in unmatched_val_by_item.items():
overall_item_unmatched_val[it] = overall_item_unmatched_val.get(it, 0) + len(lst) overall_item_unmatched_val[it] = overall_item_unmatched_val.get(
print('#' * 40) it, 0
) + len(lst)
print("#" * 40)
print( print(
f"{val_file.name}: matched {matched_total} | val {val_total} | answer {answer_total} " f"{val_file.name}: matched {matched_total} | val {val_total} | answer {answer_total} "
f"| unmatched val {unmatched_val_count} | unmatched answer {unmatched_answer_count} | precision {file_precision:.2%} | recall {file_recall:.2%} | f1 {file_f1:.2%} | false_positive_rate {file_false_positive_rate:.2%}" f"| unmatched val {unmatched_val_count} | unmatched answer {unmatched_answer_count} | precision {file_precision:.2%} | recall {file_recall:.2%} | f1 {file_f1:.2%} | false_positive_rate {file_false_positive_rate:.2%}"
) )
import json import json
print(f'unmatched_val_by_item: {json.dumps(unmatched_val_by_item, ensure_ascii=False, indent=2)}')
print(
f"unmatched_val_by_item: {json.dumps(unmatched_val_by_item, ensure_ascii=False, indent=2)}"
)
for item in sorted(answer_counts): for item in sorted(answer_counts):
item_matches = matched_by_item.get(item, []) item_matches = matched_by_item.get(item, [])
print(f" 审查项 {item}: matched {len(item_matches)} / {answer_counts[item]}") print(
f" 审查项 {item}: matched {len(item_matches)} / {answer_counts[item]}"
)
# 匹配成功的结果 # 匹配成功的结果
# for ans_text, val_text, score in item_matches: # for ans_text, val_text, score in item_matches:
# print(f" {score}% | answer: {ans_text} | val: {val_text}") # print(f" {score}% | answer: {ans_text} | val: {val_text}")
...@@ -144,16 +162,25 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None: ...@@ -144,16 +162,25 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None:
precision = overall_matched / overall_val if overall_val else 0 precision = overall_matched / overall_val if overall_val else 0
recall = overall_matched / overall_answer if overall_answer else 0 recall = overall_matched / overall_answer if overall_answer else 0
f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0 f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0
overall_false_positive_rate = (overall_val - overall_matched) / overall_val if overall_val else 0 overall_false_positive_rate = (
(overall_val - overall_matched) / overall_val if overall_val else 0
)
print( print(
f"Overall: matched {overall_matched} | val {overall_val} | answer {overall_answer} | precision {precision:.2%} | recall {recall:.2%} | f1 {f1:.2%}" f"Overall: matched {overall_matched} | val {overall_val} | answer {overall_answer} | precision {precision:.2%} | recall {recall:.2%} | f1 {f1:.2%}"
) )
# 按“审查项”的 overall 结果 # 按“审查项”的 overall 结果
if overall_item_answer: if overall_item_answer:
print('#' * 40) print("#" * 40)
print("Overall by item:") print("Overall by item:")
all_items = sorted(set(list(overall_item_answer.keys()) + list(overall_item_matched.keys()) + list(overall_item_unmatched_answer.keys()) + list(overall_item_unmatched_val.keys()))) all_items = sorted(
set(
list(overall_item_answer.keys())
+ list(overall_item_matched.keys())
+ list(overall_item_unmatched_answer.keys())
+ list(overall_item_unmatched_val.keys())
)
)
rows_by_item = [] rows_by_item = []
for it in all_items: for it in all_items:
ans = overall_item_answer.get(it, 0) ans = overall_item_answer.get(it, 0)
...@@ -162,9 +189,14 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None: ...@@ -162,9 +189,14 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None:
u_val = overall_item_unmatched_val.get(it, 0) u_val = overall_item_unmatched_val.get(it, 0)
item_precision = (mat / (mat + u_val)) if (mat + u_val) else 0 item_precision = (mat / (mat + u_val)) if (mat + u_val) else 0
acc = (mat / ans) if ans else 0 acc = (mat / ans) if ans else 0
item_f1 = (2 * item_precision * acc / (item_precision + acc)) if (item_precision + acc) else 0 item_f1 = (
(2 * item_precision * acc / (item_precision + acc))
if (item_precision + acc)
else 0
)
item_false_positive_rate = u_val / (mat + u_val) if (mat + u_val) else 0 item_false_positive_rate = u_val / (mat + u_val) if (mat + u_val) else 0
rows_by_item.append({ rows_by_item.append(
{
"审查项": it, "审查项": it,
"大模型匹配上的不合格项": mat, "大模型匹配上的不合格项": mat,
"合同所有不合格项": ans, "合同所有不合格项": ans,
...@@ -174,18 +206,45 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None: ...@@ -174,18 +206,45 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None:
"查全率(B/C)": acc, "查全率(B/C)": acc,
"F1": item_f1, "F1": item_f1,
"误报率(D/B+D)": item_false_positive_rate, "误报率(D/B+D)": item_false_positive_rate,
}) }
)
print( print(
f" 审查项 {it}: matched {mat} / answer {ans} | unmatched val {u_val} | unmatched answer {u_ans} | precision {item_precision:.2%} | recall {acc:.2%} | f1 {item_f1:.2%}" f" 审查项 {it}: matched {mat} / answer {ans} | unmatched val {u_val} | unmatched answer {u_ans} | precision {item_precision:.2%} | recall {acc:.2%} | f1 {item_f1:.2%}"
) )
overall_by_item_df = pd.DataFrame(rows_by_item, columns=["审查项", "大模型匹配上的不合格项", "合同所有不合格项", "大模型其他不合格项", "大模型未匹配上的不合格项(C-B)", "查准率(B/B+D)", "查全率(B/C)", "F1", "误报率(D/B+D)"]) overall_by_item_df = pd.DataFrame(
rows_by_item,
columns=[
"审查项",
"大模型匹配上的不合格项",
"合同所有不合格项",
"大模型其他不合格项",
"大模型未匹配上的不合格项(C-B)",
"查准率(B/B+D)",
"查全率(B/C)",
"F1",
"误报率(D/B+D)",
],
)
unmatched_val_total = sum(overall_item_unmatched_val.values()) unmatched_val_total = sum(overall_item_unmatched_val.values())
unmatched_answer_total = sum(overall_item_unmatched_answer.values()) unmatched_answer_total = sum(overall_item_unmatched_answer.values())
overall_precision = overall_matched / (overall_matched + unmatched_val_total) if (overall_matched + unmatched_val_total) else 0 overall_precision = (
overall_f1 = (2 * overall_precision * recall / (overall_precision + recall)) if (overall_precision + recall) else 0 overall_matched / (overall_matched + unmatched_val_total)
overall_invalid_rate = unmatched_val_total / (overall_matched + unmatched_val_total) if (overall_matched + unmatched_val_total) else 0 if (overall_matched + unmatched_val_total)
overall_total_df = pd.DataFrame([ else 0
)
overall_f1 = (
(2 * overall_precision * recall / (overall_precision + recall))
if (overall_precision + recall)
else 0
)
overall_invalid_rate = (
unmatched_val_total / (overall_matched + unmatched_val_total)
if (overall_matched + unmatched_val_total)
else 0
)
overall_total_df = pd.DataFrame(
[
{ {
"审查项": "总体", "审查项": "总体",
"大模型匹配上的不合格项": overall_matched, "大模型匹配上的不合格项": overall_matched,
...@@ -197,8 +256,22 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None: ...@@ -197,8 +256,22 @@ def _compare_impl(val_dir: Path, answer_dir: Path) -> None:
"F1": overall_f1, "F1": overall_f1,
"误报率(D/B+D)": overall_invalid_rate, "误报率(D/B+D)": overall_invalid_rate,
} }
], columns=["审查项", "大模型匹配上的不合格项", "合同所有不合格项", "大模型其他不合格项", "大模型未匹配上的不合格项(C-B)", "查准率(B/B+D)", "查全率(B/C)", "F1", "误报率(D/B+D)"]) ],
combined_df = pd.concat([overall_by_item_df, overall_total_df], ignore_index=True) columns=[
"审查项",
"大模型匹配上的不合格项",
"合同所有不合格项",
"大模型其他不合格项",
"大模型未匹配上的不合格项(C-B)",
"查准率(B/B+D)",
"查全率(B/C)",
"F1",
"误报率(D/B+D)",
],
)
combined_df = pd.concat(
[overall_by_item_df, overall_total_df], ignore_index=True
)
compare_dir_name = val_dir.name compare_dir_name = val_dir.name
results_dir = Path(__file__).parent / "results" results_dir = Path(__file__).parent / "results"
...@@ -213,7 +286,9 @@ def compare(val_dir: Path, answer_dir: Path) -> None: ...@@ -213,7 +286,9 @@ def compare(val_dir: Path, answer_dir: Path) -> None:
_compare_impl(val_dir=val_dir, answer_dir=answer_dir) _compare_impl(val_dir=val_dir, answer_dir=answer_dir)
def compare_with_log(val_dir: Path, answer_dir: Path, log_path: Path | None = None) -> Path: def compare_with_log(
val_dir: Path, answer_dir: Path, log_path: Path | None = None
) -> Path:
val_dir = val_dir.resolve() val_dir = val_dir.resolve()
if log_path is None: if log_path is None:
results_dir = Path(__file__).parent / "results" results_dir = Path(__file__).parent / "results"
...@@ -223,7 +298,9 @@ def compare_with_log(val_dir: Path, answer_dir: Path, log_path: Path | None = No ...@@ -223,7 +298,9 @@ def compare_with_log(val_dir: Path, answer_dir: Path, log_path: Path | None = No
log_path = log_path.resolve() log_path = log_path.resolve()
log_path.parent.mkdir(parents=True, exist_ok=True) log_path.parent.mkdir(parents=True, exist_ok=True)
with open(log_path, "w", encoding="utf-8") as f, redirect_stdout(f), redirect_stderr(f): with open(log_path, "w", encoding="utf-8") as f, redirect_stdout(
f
), redirect_stderr(f):
_compare_impl(val_dir=val_dir, answer_dir=answer_dir) _compare_impl(val_dir=val_dir, answer_dir=answer_dir)
return log_path return log_path
...@@ -231,7 +308,9 @@ def compare_with_log(val_dir: Path, answer_dir: Path, log_path: Path | None = No ...@@ -231,7 +308,9 @@ def compare_with_log(val_dir: Path, answer_dir: Path, log_path: Path | None = No
def _parse_args() -> argparse.Namespace: def _parse_args() -> argparse.Namespace:
base = Path(__file__).parent base = Path(__file__).parent
parser = argparse.ArgumentParser(description="Compare extracted annotations with answers.") parser = argparse.ArgumentParser(
description="Compare extracted annotations with answers."
)
parser.add_argument( parser.add_argument(
"--val-dir", "--val-dir",
type=Path, type=Path,
...@@ -252,6 +331,7 @@ def _parse_args() -> argparse.Namespace: ...@@ -252,6 +331,7 @@ def _parse_args() -> argparse.Namespace:
) )
return parser.parse_args() return parser.parse_args()
if __name__ == "__main__": if __name__ == "__main__":
args = _parse_args() args = _parse_args()
final_log_path = compare_with_log( final_log_path = compare_with_log(
......
...@@ -121,7 +121,7 @@ def _parse_args() -> argparse.Namespace: ...@@ -121,7 +121,7 @@ def _parse_args() -> argparse.Namespace:
parser.add_argument( parser.add_argument(
"--datasets-dir", "--datasets-dir",
type=Path, type=Path,
default=base / "results" / "jp-output-renji", default=base / "results" / "jp-output-lufa",
help="Directory containing Word files with annotations.", help="Directory containing Word files with annotations.",
) )
parser.add_argument( parser.add_argument(
...@@ -133,7 +133,7 @@ def _parse_args() -> argparse.Namespace: ...@@ -133,7 +133,7 @@ def _parse_args() -> argparse.Namespace:
parser.add_argument( parser.add_argument(
"--val-dir", "--val-dir",
type=Path, type=Path,
default=base / "results" / "jp-output-renji-extracted", default=base / "results" / "jp-output-lufa-extracted",
help="Directory to store extracted xlsx files for comparison.", help="Directory to store extracted xlsx files for comparison.",
) )
parser.add_argument( parser.add_argument(
......
No preview for this file type
...@@ -9,7 +9,9 @@ class DocBase(ABC): ...@@ -9,7 +9,9 @@ class DocBase(ABC):
self._doc_path = None self._doc_path = None
self._doc_name = None self._doc_name = None
self._kwargs = kwargs self._kwargs = kwargs
self._max_single_chunk_size = kwargs.get('max_single_chunk_size', MAX_SINGLE_CHUNK_SIZE) self._max_single_chunk_size = kwargs.get(
"max_single_chunk_size", MAX_SINGLE_CHUNK_SIZE
)
@abstractmethod @abstractmethod
def load(self, doc_path): def load(self, doc_path):
......
...@@ -509,10 +509,10 @@ class SpireWordDoc(DocBase): ...@@ -509,10 +509,10 @@ class SpireWordDoc(DocBase):
cell_list.append(cell_content) cell_list.append(cell_content)
# table_data += "|" + "|".join(cell_list) + "|" # table_data += "|" + "|".join(cell_list) + "|"
# table_data += "\n" # table_data += "\n"
table_data += ' '.join(cell_list) + '\n' table_data += " ".join(cell_list) + "\n"
if i == 0: if i == 0:
# table_data += "|" + "|".join(["--- " for _ in cell_list]) + "|\n" # table_data += "|" + "|".join(["--- " for _ in cell_list]) + "|\n"
table_data= ' '.join(cell_list) + '\n' table_data = " ".join(cell_list) + "\n"
return table_data return table_data
def get_chunk_info(self, chunk_id): def get_chunk_info(self, chunk_id):
...@@ -608,14 +608,18 @@ class SpireWordDoc(DocBase): ...@@ -608,14 +608,18 @@ class SpireWordDoc(DocBase):
return True return True
def _update_comment_content(self, comment_idx, suggest): def _update_comment_content(self, comment_idx, suggest):
self._doc.Comments.get_Item(comment_idx).Body.Paragraphs.get_Item(0).Text = suggest self._doc.Comments.get_Item(comment_idx).Body.Paragraphs.get_Item(
0
).Text = suggest
def _try_add_comment_in_paragraphs(self, paragraphs, target_text, author, suggest): def _try_add_comment_in_paragraphs(self, paragraphs, target_text, author, suggest):
if not target_text: if not target_text:
return False return False
for paragraph in paragraphs: for paragraph in paragraphs:
text_sel = paragraph.Find(target_text, False, True) text_sel = paragraph.Find(target_text, False, True)
if text_sel and self.set_comment_by_text_selection(text_sel, author, suggest): if text_sel and self.set_comment_by_text_selection(
text_sel, author, suggest
):
return True return True
return False return False
...@@ -767,8 +771,11 @@ class SpireWordDoc(DocBase): ...@@ -767,8 +771,11 @@ class SpireWordDoc(DocBase):
# update chunk_id # update chunk_id
comment_chunk_id = comment.get("chunk_id", -1) comment_chunk_id = comment.get("chunk_id", -1)
# 优先使用comments里提供的chunk_id,如果没有或无效则使用外部传入的chunk_id,如果都没有则异常处理 # 优先使用comments里提供的chunk_id,如果没有或无效则使用外部传入的chunk_id,如果都没有则异常处理
sub_chunks = self.get_sub_chunks(comment_chunk_id) if comment_chunk_id != -1 \ sub_chunks = (
and comment_chunk_id < self.get_chunk_num() else self.get_sub_chunks(chunk_id) self.get_sub_chunks(comment_chunk_id)
if comment_chunk_id != -1 and comment_chunk_id < self.get_chunk_num()
else self.get_sub_chunks(chunk_id)
)
author = self.format_comment_author(comment) author = self.format_comment_author(comment)
suggest = comment.get("suggest", "") suggest = comment.get("suggest", "")
find_key = comment["original_text"].strip() or comment["key_points"] find_key = comment["original_text"].strip() or comment["key_points"]
...@@ -808,7 +815,9 @@ class SpireWordDoc(DocBase): ...@@ -808,7 +815,9 @@ class SpireWordDoc(DocBase):
normalized_author = self._normalize_author_prefix(author) normalized_author = self._normalize_author_prefix(author)
for i in range(self._doc.Comments.Count): for i in range(self._doc.Comments.Count):
current_comment = self._doc.Comments.get_Item(i) current_comment = self._doc.Comments.get_Item(i)
comment_author = self._normalize_author_prefix(current_comment.Format.Author) comment_author = self._normalize_author_prefix(
current_comment.Format.Author
)
if comment_author == normalized_author: if comment_author == normalized_author:
return i return i
return None return None
...@@ -876,9 +885,7 @@ class SpireWordDoc(DocBase): ...@@ -876,9 +885,7 @@ class SpireWordDoc(DocBase):
if __name__ == "__main__": if __name__ == "__main__":
doc = SpireWordDoc() doc = SpireWordDoc()
doc.load( doc.load(r"/home/ccran/lufa-contract/demo/今麦郎合同审核.docx")
r"/home/ccran/lufa-contract/demo/今麦郎合同审核.docx"
)
print(doc._doc_name) print(doc._doc_name)
print("附件2《技术协议》" in doc.get_all_text()) print("附件2《技术协议》" in doc.get_all_text())
# doc.add_chunk_comment( # doc.add_chunk_comment(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment