Commit 53a63f97 by ccran

Initial commit

parents
tmp/
\ No newline at end of file
{
    // Use IntelliSense to learn about possible attributes.
    // Hover to view descriptions of existing attributes.
    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Python 调试程序: 当前文件",
            "type": "debugpy",
            "request": "launch",
            "program": "${file}",
            "console": "integratedTerminal",
            "cwd": "${workspaceFolder}",
            // Make project-root imports (utils.*, core.*, tools.*) resolvable.
            "env": {
                "PYTHONPATH": "${workspaceFolder}"
            }
        }
    ]
}
\ No newline at end of file
from utils.spire_word_util import SpireWordDoc
from utils.doc_util import DocBase
from functools import lru_cache
from typing import Optional
from core.memory import MemoryStore
# Upper bound on distinct conversation ids kept alive by the caches below.
MAX_CACHE = 128


@lru_cache(maxsize=MAX_CACHE)
def get_cached_doc_tool(conversation_id: str) -> Optional[DocBase]:
    """Return the document tool for a conversation.

    ``conversation_id`` is used only as the lru_cache key, so repeated calls
    with the same id share one SpireWordDoc instance (and its loaded state).
    """
    return SpireWordDoc()
@lru_cache(maxsize=MAX_CACHE)
def get_cached_memory(conversation_id: str) -> MemoryStore:
    """Return the per-conversation MemoryStore, persisted as
    ``memory_store_<conversation_id>.json`` under the project tmp directory."""
    return MemoryStore(f'memory_store_{conversation_id}.json')
\ No newline at end of file
import platform
from dataclasses import dataclass
# Configurable runtime parameters
use_docker = False                      # True when running inside the Docker image (root_path becomes /app)
just_b_class = False                    # restrict the knowledge base to the "B类" sheet
is_extract = True                       # use the extraction-review sheet/port instead of the default
use_original_text_verification = False  # verify findings against the original contract text
# @dataclass
# class LLMConfig:
#     base_url: str = "http://172.21.107.45:9002/v1"
#     api_key: str = "none"
#     model: str = "Qwen3-32B"
#
#
# base_fastgpt_url = "http://172.21.107.45:3030"
# base_backend_url = "http://172.21.107.45:48080"
# ocr_url = 'http://172.21.107.45:8202/openapi/ocrUploadFile'
@dataclass
class LLMConfig:
    """Connection settings for an OpenAI-compatible chat endpoint."""
    base_url: str = "http://192.168.252.71:9002/v1"
    api_key: str = "none"
    model: str = 'Qwen2-72B-Instruct'


# Service endpoints used across the project.
outer_backend_url = "http://218.77.58.8:48081"  # backend reachable from outside the LAN
base_fastgpt_url = "http://192.168.252.71:18089"
base_backend_url = "http://192.168.252.71:48081"
ocr_url = 'http://192.168.252.71:8202/openapi/ocrUploadFile'
# Project root directory (Windows default; overridden per platform below).
root_path = r"E:\PycharmProject\contract_review_agent"
system = platform.system()
if system == "Linux":
    # root_path = "/data/home/ccran/contract_review_agent"
    root_path = '/home/ccran/contract_review_agent'
elif system == "Darwin":
    root_path = "/Users/chenran/PycharmProjects/contract_review_agent"
# Docker override: the image mounts the project at /app.
if use_docker:
    root_path = '/app'
# Max concurrent LLM requests (see utils.openai_util.OpenAITool).
MAX_WORKERS = 20
# Named LLM configurations; "fastgpt_segment_review" routes through the FastGPT gateway.
LLM = {
    "base_tool_llm": LLMConfig(),
    "fastgpt_segment_review": LLMConfig(
        base_url=f"{base_fastgpt_url}/api/v1",
        api_key="fastgpt-zMavJKKgqA9jRNHLXxzXCVZx1JXxfuNkH1p2qfLhtPfMp41UvdSQvt8",
    )
}
# File extensions handled by the Word pipeline vs. the PDF/text pipeline.
doc_support_formats = [".docx", ".doc", ".wps"]
pdf_support_formats = [".txt", ".md", ".pdf"]
# Excel export: internal field name -> column header, per export type.
field_mapping = {
    "review": {
        "id": "序号",
        "key_points": "审查内容",
        "original_text": "合同原文",
        "details": "审查过程",
        "result": "审查结果",
        "suggest": "审查建议",
        "chunk_id": "文件块序号",
        "chunk_info": "文件块信息",
    },
    "category": {"text": "原文", "judge": "判断依据", "chunk_location": "原文位置"},
}
# Column widths matching the field order above.
excel_widths = {"review": [5, 20, 80, 80, 20, 80, 5, 60], "category": [80, 80, 30]}
max_review_group = 5
# Sales-category classification: run OCR when the extracted text is shorter than this.
ocr_thresholds = 1000
all_rule_sheet = ["内销或出口", "内销", "出口", "反思"]
reflection_sheet = "反思"
# Chunking limits (characters / pages per chunk).
min_single_chunk_size = 2000
max_single_chunk_size = 20000
max_chunk_page = 10
# Knowledge-base sheet selection — also picks the service port.
if just_b_class:
    know_sheet_name = "B类"
    port = 9008
else:
    if is_extract:
        know_sheet_name = "提取审查"
        port = 9016
    else:
        know_sheet_name = 0
        port = 9006
# know_sheet_name = '发票审查'
reload = False
"""
{
"segment_id": "S12",
"topics": ["付款"],
"summary": {
"one_liner": "...",
"structured": ["...", "...", "..."]
},
"risks": [
{"flag":"...", "level":"M", "evidence":"...", "suggestion":"..."}
],
"dependencies": ["S11"],
"open_questions": ["..."],
"evidence_quotes": ["..."]
}
segment_id
含义:该记忆对应的合同分段唯一标识。
注意:用于前后段上下文关联,必须与分段阶段生成的 ID 一致。
topics
含义:本分段涉及的合同主题标签。
注意:从固定枚举中选择(如付款/违约/保密),用于后续记忆检索与一致性校验。
summary
summary.one_liner
含义:一句话概括本段条款的核心约定内容。
注意:只描述事实,不做风险判断。
summary.structured
含义:本段条款的结构化要点列表(3–5 条)。
注意:每条对应一个业务点,便于检索和上下文理解。
risks
含义:本分段确认存在的风险点列表。
字段说明:
flag:风险标签(简短描述问题)
level:风险等级(H / M / L)
evidence:支撑该风险判断的原文依据
suggestion:可直接落地的修改或谈判建议
注意:每条风险必须有明确证据和可执行建议。
dependencies
含义:本分段理解或适用所依赖的其他条款或附件。
注意:用于后续自动回读相关条款,防止断章取义。
open_questions
含义:本分段中未明确、需对方补充或澄清的信息。
注意:不确定内容不要写成风险结论,统一放在这里。
evidence_quotes
含义:从合同原文中摘录的关键短句,用于支撑摘要和风险判断。
注意:用于审查复核与反思重试,避免“凭空总结”。
"""
from __future__ import annotations
import json
import logging
from dataclasses import asdict, dataclass, field
from datetime import datetime
from pathlib import Path
from threading import RLock
from typing import Any, Dict, Iterable, List, Optional
from utils.http_util import upload_file
logger = logging.getLogger(__name__)
# Risk levels a finding may carry: High / Medium / Low.
_ALLOWED_RISK_LEVELS = {"H", "M", "L"}


@dataclass
class RiskFinding:
    """One risk identified in a contract segment during review."""

    rule_title: str      # title of the review rule that fired
    segment_id: int      # id of the contract segment the finding belongs to
    original_text: str   # contract excerpt the finding is based on
    issue: str           # description of the problem
    risk_level: str      # normalized to upper-case "H" / "M" / "L"
    suggestion: str      # actionable fix or negotiation advice

    def __post_init__(self) -> None:
        """Normalize ``risk_level`` to upper case and validate it."""
        normalized = (self.risk_level or "").upper()
        if normalized not in _ALLOWED_RISK_LEVELS:
            raise ValueError(f"risk_level must be one of {_ALLOWED_RISK_LEVELS}, got {self.risk_level}")
        self.risk_level = normalized

    @classmethod
    def from_dict(cls, data: Dict) -> "RiskFinding":
        """Build a finding from a (possibly sparse) dict payload."""
        payload = data or {}
        raw_segment = payload.get("segment_id", 0) or 0
        return cls(
            rule_title=str(payload.get("rule_title", "")),
            segment_id=int(raw_segment),
            original_text=str(payload.get("original_text", "")),
            issue=str(payload.get("issue", "")),
            risk_level=str(payload.get("risk_level", "")),
            suggestion=str(payload.get("suggestion", "")),
        )
@dataclass
class MemoryStore:
    """Simplified memory store: contract facts and risk findings.

    Thread-safe via an RLock; persisted as JSON under ``<project>/tmp``.
    NOTE(review): the class defines its own ``__init__``, so ``@dataclass``
    does not generate a constructor here and ``storage_name`` below acts as
    documentation rather than a real dataclass field.
    """
    # Name of the JSON file backing this store (see __init__).
    storage_name: Optional[Path] = 'default.json'

    def __init__(self,storage_name:str = 'default.json') -> None:
        """Open (or create) the store backed by ``tmp/<storage_name>``."""
        self._storage_path = Path(__file__).resolve().parent.parent / "tmp" / storage_name  # type: ignore[arg-type]
        self._storage_path.parent.mkdir(parents=True, exist_ok=True)
        self._lock = RLock()
        self.facts: List[Dict[str, Any]] = []
        self.findings: List[RiskFinding] = []
        self._load()

    # ---------------------- facts ----------------------
    def set_facts(self, facts: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Replace all stored facts and persist."""
        with self._lock:
            self.facts = facts or []
            self._persist()
            return self.facts

    def add_facts(self, partial: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Append one fact dict (typically one segment's facts) and persist."""
        with self._lock:
            self.facts.append(partial)
            self._persist()
            return self.facts

    def get_facts(self) -> List[Dict[str, Any]]:
        """Return the stored facts."""
        with self._lock:
            return self.facts  # NOTE(review): returns the internal list, not a copy — callers must not mutate

    # -------------------- findings ---------------------
    def add_finding(self, finding: RiskFinding) -> RiskFinding:
        """Append one validated finding and persist."""
        with self._lock:
            self.findings.append(finding)
            self._persist()
            return finding

    def add_finding_from_dict(self, data: Dict) -> RiskFinding:
        """Validate ``data`` via RiskFinding.from_dict and store it."""
        return self.add_finding(RiskFinding.from_dict(data))

    def list_findings(self) -> List[RiskFinding]:
        """Return a shallow copy of all findings."""
        with self._lock:
            return list(self.findings)

    def get_findings_by_segment(self, segment_id: int) -> List[RiskFinding]:
        """Return the findings recorded for one segment."""
        with self._lock:
            return [f for f in self.findings if f.segment_id == segment_id]

    def delete_findings_by_segment(self, segment_id: int) -> int:
        """Drop all findings for ``segment_id``; returns how many were removed."""
        with self._lock:
            before = len(self.findings)
            self.findings = [f for f in self.findings if f.segment_id != segment_id]
            removed = before - len(self.findings)
            if removed:
                self._persist()
            return removed

    def search_findings(self, keyword: str, rule_title: Optional[str] = None, risk_level: Optional[str] = None) -> List[RiskFinding]:
        """Filter findings by rule title / risk level, then by keyword substring."""
        key = (keyword or "").strip().lower()
        with self._lock:
            candidates = list(self.findings)
            if rule_title:
                candidates = [f for f in candidates if (f.rule_title or "").lower() == rule_title.strip().lower()]
            if risk_level:
                lvl = risk_level.strip().upper()
                candidates = [f for f in candidates if f.risk_level == lvl]
            if not key:
                return candidates

            def _matches(f: RiskFinding) -> bool:
                # Case-insensitive substring match over all textual fields.
                hay = " ".join([
                    f.rule_title,
                    f.original_text,
                    f.issue,
                    f.suggestion,
                ]).lower()
                return key in hay
            return [f for f in candidates if _matches(f)]

    # ------------------- housekeeping ------------------
    def clear(self) -> None:
        """Remove all facts and findings and persist the empty state."""
        with self._lock:
            self.facts.clear()
            self.findings.clear()
            self._persist()

    def _persist(self) -> None:
        # Best-effort write; failures are logged, never raised.
        payload = {
            "facts": self.facts,
            "findings": [asdict(f) for f in self.findings],
        }
        try:
            self._storage_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
        except Exception as exc:
            logger.error("Failed to persist memory store: %s", exc)

    def _load(self) -> None:
        # Best-effort read at construction; missing/corrupt files leave the store empty.
        try:
            if not self._storage_path.exists():
                return
            raw = self._storage_path.read_text(encoding="utf-8")
            data = json.loads(raw or "{}")
            if isinstance(data, dict):
                self.facts = data.get("facts") or []
                self.findings = [RiskFinding.from_dict(item) for item in data.get("findings", []) or []]
        except Exception as exc:
            logger.error("Failed to load memory store: %s", exc)

    def export_to_excel(self, file_name: Optional[str] = None) -> Dict[str, Any]:
        """Export findings and facts to Excel, upload, then delete the local file."""
        try:
            from openpyxl import Workbook  # type: ignore
        except ImportError as exc:
            raise ImportError("openpyxl is required for export_to_excel; install via 'pip install openpyxl'") from exc
        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        name = file_name or f"memory_export_{ts}.xlsx"
        output_path = Path(__file__).resolve().parent.parent / "tmp" / name
        with self._lock:
            wb = Workbook()
            ws_findings = wb.active
            ws_findings.title = "findings"
            # (attribute name, column header) pairs for the findings sheet.
            finding_headers = [
                ("rule_title", "规则标题"),
                ("segment_id", "分段ID"),
                ("original_text", "原文"),
                ("issue", "问题描述"),
                ("risk_level", "风险等级"),
                ("suggestion", "建议"),
            ]
            ws_findings.append([label for _, label in finding_headers])
            for f in self.findings:
                ws_findings.append([
                    getattr(f, key, "") for key, _ in finding_headers
                ])
            ws_facts = wb.create_sheet("facts")
            if self.facts:
                # Union of keys across all fact dicts, segment column ("段落") first.
                fact_keys: List[str] = sorted({k for item in self.facts for k in item.keys()})
                if "段落" in fact_keys:
                    fact_keys = ["段落"] + [k for k in fact_keys if k != "段落"]
                ws_facts.append(fact_keys)
                for item in self.facts:
                    row = []
                    for key in fact_keys:
                        value = item.get(key)
                        if isinstance(value, (dict, list)):
                            # Nested structures serialized as compact JSON strings.
                            row.append(json.dumps(value, ensure_ascii=False))
                        else:
                            row.append(value)
                    ws_facts.append(row)
            else:
                ws_facts.append(["data"])
            wb.save(output_path)
        try:
            res = upload_file(str(output_path))
        finally:
            # Always clean up the temp file, even if the upload fails.
            try:
                output_path.unlink()
            except Exception:
                logger.warning("Failed to delete temp excel: %s", output_path)
        return res
if __name__ == "__main__":
    # Quick demo: set facts -> record a finding -> read/search -> export.
    store = MemoryStore()
    store.set_facts([{
        "公司": {"甲方": "A 公司", "乙方": "B 公司"},
        "支付": [
            {"方式": "银行转账", "期限": "验收后30日内"}
        ],
        "段落":1
    },{
        "纠纷": {"解决方式": "诉讼", "地址": "原告方所在地法院"},
        "段落":2
    }])
    finding = RiskFinding(
        rule_title="违约责任",
        segment_id=1,
        original_text="违约方应赔偿全部损失",
        issue="未约定违约金上限,可能导致赔偿范围过大",
        risk_level="H",
        suggestion="建议增加‘赔偿金额不超过合同总额的30%’",
    )
    store.add_finding(finding)
    print("Facts:\n" + json.dumps(store.get_facts(), ensure_ascii=False, indent=2))
    hits = store.search_findings("赔偿", rule_title="违约责任")
    print("Findings search:")
    for f in hits:
        print(json.dumps(asdict(f), ensure_ascii=False, indent=2))
    # Uploads via utils.http_util.upload_file and removes the local xlsx.
    print(store.export_to_excel())
from abc import ABC
from functools import wraps
import inspect
# Class decorator that marks a class as a discoverable tool.
def tool(name, description):
    """Attach tool metadata to a class.

    Sets ``_name`` and ``_description`` and flags the class with
    ``_is_tool = True`` so the registry can discover it.
    """
    def decorator(cls):
        for attr, value in (("_name", name), ("_description", description), ("_is_tool", True)):
            setattr(cls, attr, value)
        return cls
    return decorator
# Method decorator that marks a tool's entry point and records its parameter schema.
def tool_func(param_description):
    """Attach ``param_description`` (a JSON-schema-like dict) to the decorated
    function so ``ToolBase.execute`` can locate and dispatch to it.

    The previous implementation wrapped the function in a pass-through
    ``wrapper`` that called it unchanged; since ``functools.wraps`` copied the
    attribute onto the wrapper anyway, the extra call level added nothing —
    the attribute is now set directly on the function.
    """
    def decorator(func):
        func._param_description = param_description
        return func
    return decorator
# Base class all tools derive from.
class ToolBase(ABC):
    """Common tool base: carries registry metadata and dispatches ``execute``
    calls to the single method decorated with ``@tool_func``."""
    _name = 'tool_base'
    _description = 'tool_description'
    _is_tool = False

    def execute(self, *args, **kwargs):
        """Invoke the first bound method carrying ``_param_description``."""
        for _, bound in inspect.getmembers(self, predicate=inspect.ismethod):
            target = getattr(bound, '__func__', bound)
            if hasattr(target, '_param_description'):
                return target(self, *args, **kwargs)
        raise Exception('没有找到@tool_func装饰的方法')
from core.tool import ToolBase, tool, tool_func
# Addition tool (minimal sample ToolBase implementation)
@tool('add', 'add tool')
class AddTool(ToolBase):
    """Example tool: adds two numbers via the ``@tool_func`` entry point."""
    @tool_func({
        "type": "object",
        "properties": {
            "a": {
                # NOTE(review): standard JSON Schema spells this type
                # "integer"; "int" is non-standard — confirm what the
                # schema consumer expects.
                "type": "int",
                "description": "operator 1"
            },
            "b": {
                "type": "int",
                "description": "operator 2"
            }
        },
        "required": [
            "a",
            "b"
        ]
    })
    def add(self, a, b):
        """Return ``a + b``; reachable through ``ToolBase.execute``."""
        return a + b
if __name__ == '__main__':
    # Smoke test: decorator metadata plus dispatch through execute().
    tool = AddTool()
    print(AddTool._name)
    print(tool.execute(1, 2))
\ No newline at end of file
from __future__ import annotations
from typing import Dict, List, Optional
from core.tool import ToolBase, tool, tool_func
from tools.shared import store, limit_list
@tool("full_read", "分段原文读取/摘录")
class FullReadTool(ToolBase):
    """Read back a segment's stored text and highlight focus-keyword snippets.

    NOTE(review): relies on ``tools.shared.store`` exposing ``get(segment_id)``
    returning an object with ``evidence_quotes`` and ``summary.structured``.
    That is not the ``core.memory.MemoryStore`` API — confirm the shared
    store's actual type.
    """
    @tool_func(
        {
            "type": "object",
            "properties": {
                "target": {"type": "object"},
                "need_focus": {"type": "array"},
                "return_snippets": {"type": "boolean"},
            },
            "required": ["target"],
        }
    )
    def run(self, target: Dict, need_focus: Optional[List[str]] = None, return_snippets: bool = True) -> Dict:
        """Return the segment's full text plus up to 5 deduplicated snippets
        that contain the requested focus keywords."""
        # Accept either {"id": ...} or {"segment_id": ...} as the target shape.
        seg_id = (target or {}).get("id") or (target or {}).get("segment_id") or ""
        mem = store.get(seg_id) if seg_id else None
        full_text = " ".join(mem.evidence_quotes) if mem and mem.evidence_quotes else ""
        if not full_text:
            full_text = "(原文未存储,建议在调用端提供全文或将原文写入 MemoryStore)"
        focus = need_focus or []
        snippets: List[str] = []
        if return_snippets and mem:
            # First pass: quote-level matches; fall back to structured summary lines.
            for kw in focus:
                for q in mem.evidence_quotes:
                    if kw and kw in q:
                        snippets.append(q)
            if not snippets:
                for kw in focus:
                    for s in mem.summary.structured:
                        if kw and kw in s:
                            snippets.append(s)
        # Deduplicate while preserving order, then cap at 5 snippets.
        snippets = limit_list(list(dict.fromkeys(snippets)), 5)
        return {
            "target": {"type": "segment", "id": seg_id},
            "full_text": full_text,
            "highlighted_snippets": snippets,
        }
from __future__ import annotations
import re
from typing import Dict, List, Optional,Tuple
from core.tool import ToolBase, tool, tool_func
from tools.shared import store
from core.memory import MemoryStore
# Keyword lexicon used for rule-based topic tagging of contract segments.
TOPIC_KEYWORDS = {
    "付款": ["付款", "支付", "价款", "结算", "发票", "对账", "款项", "账期"],
    "验收": ["验收", "确认", "签收", "测试", "交付确认", "验收单", "验收报告"],
    "交付": ["交付", "交货", "交付物", "交付时间", "里程碑", "上线", "交接"],
    "违约": ["违约", "违约金", "赔偿", "损失", "承担责任", "罚金", "扣款"],
    "解除": ["解除", "终止", "解约", "撤销"],
    "责任限制": ["责任上限", "间接损失", "免责", "赔偿上限", "责任限制", "不可抗力"],
    "保密": ["保密", "保密信息", "披露", "泄露", "保密期限", "例外"],
    "知识产权": ["知识产权", "著作权", "专利", "商标", "成果", "源代码", "许可"],
    "争议解决": ["争议", "仲裁", "诉讼", "管辖", "适用法律", "法院"],
}
# Default number of memories returned by MemoryReadTool.run.
DEFAULT_TOP_K = 6
@tool("memory_read", "分段记忆读取")
class MemoryReadTool(ToolBase):
    """Retrieve previously stored segment memories relevant to the current segment.

    NOTE(review): several issues below look like an unfinished migration —
    see the inline notes in ``run``.
    """
    # NOTE(review): this @tool_func decorates parse_neighbor_ids, not run, so
    # ToolBase.execute dispatches to parse_neighbor_ids even though the schema
    # clearly describes run's parameters. The decorator probably belongs on run.
    @tool_func(
        {
            "type": "object",
            "properties": {
                "segment_ids": {"type": "array"},
                "topics": {"type": "array"},
                "query": {"type": "string"},
                "top_k": {"type": "int"},
                "mode": {"type": "string"},
            },
            "required": ["top_k"],
        }
    )
    def parse_neighbor_ids(self, segment_id: int) -> List[int]:
        """Return up to two preceding segment ids (segment_id-1, segment_id-2)."""
        neighbor_ids = []
        if segment_id >= 1:
            neighbor_ids.append(segment_id - 1)
        if segment_id >= 2:
            neighbor_ids.append(segment_id - 2)
        return neighbor_ids

    def normalize_text(self, s: str) -> str:
        """Collapse whitespace runs to single spaces and trim the ends."""
        return re.sub(r"\s+", " ", s.strip())

    def derive_topics_rule_based(self, segment_text: str, max_topics: int = 3) -> List[str]:
        """Score TOPIC_KEYWORDS hits in the text; return the top-scoring topics."""
        text = self.normalize_text(segment_text)
        scores: List[Tuple[str, int]] = []
        for topic, kws in TOPIC_KEYWORDS.items():
            hit = sum(1 for kw in kws if kw in text)
            if hit > 0:
                scores.append((topic, hit))
        scores.sort(key=lambda x: x[1], reverse=True)
        return [t for t, _ in scores[:max_topics]]

    def run(self, store: MemoryStore, segment_id: int, segment_text: str,
            top_k: int = DEFAULT_TOP_K) -> Dict:
        """Rank stored segment memories by keyword relevance; return the top_k.

        NOTE(review): this method cannot run as written — ``segment_ids`` and
        ``query`` below are undefined names (NameError at first use);
        ``store.list()`` does not exist on core.memory.MemoryStore; and the
        ``c.summary`` / ``c.topics`` / ``c.risks`` record shape matches neither
        the facts dicts nor RiskFinding. Appears to target an older memory
        model — confirm intent before fixing.
        """
        neighbor_ids = self.parse_neighbor_ids(segment_id)
        topics = self.derive_topics_rule_based(segment_text)
        if not topics:
            topics = ["其他"]
        candidates = store.list()
        ids = set(segment_ids or [])  # NOTE(review): undefined name — ``neighbor_ids`` was likely intended
        if ids:
            candidates = [c for c in candidates if c.segment_id in ids]
        if topics:
            candidates = [c for c in candidates if set(c.topics).intersection(set(topics))]
        key = (query or "").strip().lower()  # NOTE(review): undefined name — no ``query`` parameter exists
        scored: List[Dict] = []
        for c in candidates:
            # Flatten every textual field into one lowercase haystack.
            hay = " ".join([
                c.summary.one_liner,
                " ".join(c.summary.structured),
                " ".join(c.topics),
                " ".join(c.dependencies),
                " ".join(c.evidence_quotes),
                " ".join([r.flag + " " + r.evidence + " " + r.suggestion for r in c.risks]),
            ]).lower()
            score = 0.0
            if key:
                # Fraction of query tokens found in the haystack.
                tokens = [t for t in re.split(r"\s+", key) if t]
                score = sum(1.0 for t in tokens if t in hay) / max(1, len(tokens))
            rec = c.to_dict()
            rec["relevance"] = round(score, 2)
            rec["reason"] = "关键词匹配" if key else "主题/ID过滤"
            scored.append(rec)
        scored.sort(key=lambda x: x.get("relevance", 0.0), reverse=True)
        return {"memories": scored[:top_k]}
from __future__ import annotations
from typing import Dict, List, Optional
from dataclasses import asdict
from core.tool import ToolBase, tool, tool_func
from core.memory import RiskFinding
@tool("memory_write", "分段记忆写入")
class MemoryWriteTool(ToolBase):
    """Convert a segment review output into RiskFinding records and persist them."""
    @tool_func(
        {
            "type": "object",
            "properties": {
                "segment_id": {"type": "string"},
                "topics": {"type": "array"},
                "review_output": {"type": "object"},
                "dependencies": {"type": "array"},
                "write_policy": {"type": "object"},
            },
            "required": ["segment_id", "topics", "review_output"],
        }
    )
    def run(self, segment_id: str, topics: List[str], review_output: Dict, dependencies: Optional[List[str]] = None, write_policy: Optional[Dict] = None) -> Dict:
        """Store each finding of ``review_output`` in the shared memory store.

        Returns ``{"memory_id": ..., "findings": [...]}`` listing the findings
        actually written; malformed entries are skipped.
        """
        # BUG FIX: ``store`` was never imported in this module, so every call
        # raised NameError inside the try below and nothing was stored. Import
        # the shared store the same way the sibling tools do.
        from tools.shared import store
        findings = review_output.get("findings", [])
        added: List[Dict] = []
        for f in findings:
            rule_title = f.get("rule_title") or ""
            issue = f.get("issue") or f.get("issue_description") or ""
            level = (f.get("level") or f.get("risk_level") or "M").upper()
            suggestion = f.get("suggestion") or ""
            evs = list(f.get("evidence_quotes", []) or [])
            # Prefer the first evidence quote as the original text.
            original_text = evs[0] if evs else (f.get("original_text") or "")
            try:
                finding_obj = RiskFinding(
                    rule_title=rule_title,
                    segment_id=int(segment_id) if str(segment_id).isdigit() else 0,
                    original_text=original_text,
                    # BUG FIX: the dataclass field is named ``issue``; the old
                    # ``issue_description=`` keyword raised TypeError, which the
                    # except below silently swallowed — no finding was ever stored.
                    issue=issue,
                    risk_level=level,
                    suggestion=suggestion,
                )
                store.add_finding(finding_obj)
                added.append(asdict(finding_obj))
            except Exception:
                # Conservative: skip entries that fail validation.
                continue
        memory_id = "m_" + segment_id
        return {
            "memory_id": memory_id,
            "findings": added,
        }
from __future__ import annotations
from typing import Dict, List, Optional
from core.tool import ToolBase, tool, tool_func
@tool("reflect_retry", "反思重试质量闸(MVP)")
class ReflectRetryTool(ToolBase):
    """Quality gate: inspect one segment's review output, decide pass vs.
    revise, and emit follow-up actions plus a decremented retry budget."""
    @tool_func(
        {
            "type": "object",
            "properties": {
                "segment_id": {"type": "string"},
                "ruleset_id": {"type": "string"},
                "context_memories": {"type": "array"},
                "review_output": {"type": "object"},
            },
            "required": ["segment_id", "ruleset_id", "review_output"],
        }
    )
    def run(self, segment_id: str, ruleset_id: str, context_memories: Optional[List[Dict]] = None, review_output: Dict = None) -> Dict:
        """Check findings for missing evidence/suggestions and schedule fixes."""
        output = review_output or {}
        findings = output.get("findings", [])
        missings = output.get("missing_info", [])

        # Collect quality problems per finding (deduplicated on return).
        problems: List[str] = []
        for finding in findings:
            if not (finding.get("evidence_quotes", []) or []):
                problems.append("缺少对关键结论的原文引用")
            if not (finding.get("suggestion") or "").strip():
                problems.append("建议条款措辞仍偏泛,未给可替换文本")

        # Schedule follow-up actions for declared information gaps.
        actions: List[Dict] = []
        evidence_needs = [m.get("need") for m in missings if m.get("type") == "evidence"]
        knowledge_topics = [m.get("topic") for m in missings if m.get("type") == "knowledge"]
        if evidence_needs:
            actions.append({"action": "FullRead", "target": segment_id, "need_focus": evidence_needs})
        if knowledge_topics:
            actions.append({"action": "RetrieveReference", "topic": knowledge_topics})

        status = "revise" if (problems or actions) else "pass"
        retry_budget_left = int(output.get("retry_budget_left", 1))
        return {
            "segment_id": segment_id,
            "status": status,
            "problems": list(dict.fromkeys(problems)),
            "revise_instructions": actions,
            "retry_budget_left": max(0, retry_budget_left - (1 if status == "revise" else 0)),
        }
import importlib
import inspect
import pkgutil
from pathlib import Path
from typing import Dict, Optional, Type
from core.tool import ToolBase
# Registry that maps tool names to their classes
_tool_registry: Dict[str, Type[ToolBase]] = {}
def _iter_tool_classes():
    """Yield every ``@tool``-decorated ToolBase subclass in the ``tools`` package.

    Walks all submodules (importing each one as a side effect); the
    ``registry`` and ``contract_tools`` modules themselves are excluded.
    """
    package_dir = Path(__file__).parent
    for module_info in pkgutil.walk_packages([str(package_dir)], prefix="tools."):
        if module_info.name.endswith(".registry") or module_info.name.endswith(".contract_tools"):
            continue
        module = importlib.import_module(module_info.name)
        for _, obj in inspect.getmembers(module, inspect.isclass):
            # _is_tool is set by the @tool decorator; guard with issubclass
            # so unrelated classes with the attribute are skipped.
            if getattr(obj, "_is_tool", False) and issubclass(obj, ToolBase):
                yield obj
def build_registry(force_refresh: bool = False) -> Dict[str, Type[ToolBase]]:
    """Populate the module-level registry; idempotent unless force_refresh."""
    if _tool_registry and not force_refresh:
        return _tool_registry
    _tool_registry.clear()
    for cls in _iter_tool_classes():
        _tool_registry[cls._name] = cls
    return _tool_registry
def get_tool(name: str) -> Optional[Type[ToolBase]]:
    """Look up a registered tool class by name (None when absent)."""
    return _tool_registry.get(name)


def available_tools() -> Dict[str, str]:
    """Map registered tool names to their descriptions."""
    return {name: cls._description for name, cls in _tool_registry.items()}
# Build registry on import
build_registry()

if __name__ == "__main__":
    # List everything that was discovered.
    for tool_name, tool_desc in available_tools().items():
        print(f"Tool: {tool_name}, Description: {tool_desc}")
\ No newline at end of file
from __future__ import annotations
from typing import Dict, List, Any
from core.tool import ToolBase, tool, tool_func
from core.cache import get_cached_memory
# Fact dimensions recognized in segment summaries / memory retrieval.
FACT_DIMENSIONS: List[str] = ["当事人", "标的", "金额", "支付", "交付", "质量", "知识产权", "保密", "违约责任", "争议解决"]


@tool("retrieve_reference", "审查参考检索")
class RetrieveReferenceTool(ToolBase):
    """Gather review references from conversation memory, the knowledge base,
    and external sources (the latter two are not implemented yet)."""
    @tool_func(
        {
            "type": "object",
            "properties": {
                "question": {"type": "string"},
                "top_k": {"type": "int"},
            },
            "required": ["question"],
        }
    )
    def run(self, question: str, top_k: int = 5, conversation_id: str = "") -> Dict:
        """Return ``{"memory_refs", "kb_refs", "external_refs"}`` for the question."""
        memory_refs = self._search_memory(question, conversation_id, top_k)
        kb_refs = self._search_knowledge_base(question, top_k)
        external_refs = self._search_external(question, top_k)
        return {
            "memory_refs": memory_refs,
            "kb_refs": kb_refs,
            "external_refs": external_refs,
        }

    def _search_memory(self, question: str, conversation_id: str, top_k: int) -> List[Dict[str, Any]]:
        """Scan the conversation's stored facts for dimensions named in the question."""
        if not conversation_id:
            return []
        store = get_cached_memory(conversation_id)
        facts = store.get_facts()
        results: List[Dict[str, Any]] = []

        def _add(dim: str, payload: Any) -> None:
            # Stringify non-string payloads (dicts/lists) for the snippet field.
            snippet = payload if isinstance(payload, str) else str(payload)
            results.append({
                "source": f"memory:{dim}",
                "dimension": dim,
                "snippet": snippet,
            })

        # BUG FIX: MemoryStore.get_facts() returns a *list* of per-segment fact
        # dicts, not a single mapping, so the old ``facts.get(dim)`` raised
        # AttributeError. Iterate each segment's dict and probe the known
        # dimensions mentioned in the question instead.
        for item in facts:
            if not isinstance(item, dict):
                continue
            for dim in FACT_DIMENSIONS:
                if dim not in question:
                    continue
                val = item.get(dim)
                if val is None:
                    continue
                _add(dim, val)
        return results[:top_k]

    def _search_knowledge_base(self, question: str, top_k: int) -> List[Dict[str, Any]]:
        # TODO: implement KB retrieval
        return []

    def _search_external(self, question: str, top_k: int) -> List[Dict[str, Any]]:
        # TODO: implement external retrieval (e.g., search engine)
        return []
if __name__ == "__main__":
    # Demo: seed the "tmp" conversation memory, then query it.
    tmp_memory = get_cached_memory("tmp")
    tmp_memory.add_finding_from_dict({
        "issue": "支付方式不明确",
        "original_text": "买方应在收到货物后30天内支付全部货款。",
        "risk_level": "H",
        "rule_title": "支付条款审查",
    })
    # BUG FIX: MemoryStore has no ``update_facts`` method (AttributeError);
    # ``add_facts`` appends one fact dict, matching how facts are stored
    # elsewhere in the pipeline.
    tmp_memory.add_facts({
        "支付":{
            "支付方式": "银行转账",
            "支付期限": "收到货物后30天内",
        }
    })
    tool = RetrieveReferenceTool()
    result = tool.run(
        question="支付方式是什么?",
        top_k=3,
        conversation_id="tmp",
    )
    print(result)
\ No newline at end of file
from __future__ import annotations
import asyncio
from typing import Any, Dict, List
from core.tool import ToolBase
from utils.common_util import extract_json
from utils.openai_util import OpenAITool
from core.config import LLM, MAX_WORKERS
class SegmentLLMBase(ToolBase):
    """LLM-backed segment processor: builds prompts, calls LLM, parses JSON."""
    def __init__(self, system_prompt: str, llm_key: str = "fastgpt_segment_review") -> None:
        # ``llm_key`` selects a named LLMConfig entry from core.config.LLM.
        super().__init__()
        self.system_prompt = system_prompt
        self.llm = OpenAITool(LLM[llm_key], max_workers=MAX_WORKERS)

    def build_messages(self, user_content: str) -> List[Dict[str, str]]:
        """Wrap user content in the standard system+user chat message pair."""
        return [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": user_content},
        ]

    async def chat_async(self, messages: List[Dict[str, str]]):
        """Single chat completion call."""
        return await self.llm.chat(messages)

    async def chat_batch_async(self, messages_list: List[List[Dict[str, str]]]):
        """Concurrent chat completions for a batch of message lists."""
        return await self.llm.mul_chat(messages_list)

    def run_with_loop(self, coro):
        """Run a coroutine from synchronous code.

        Falls back to the current event loop when ``asyncio.run`` raises
        RuntimeError. NOTE(review): ``get_event_loop`` inside an already
        running loop also fails on modern Python — confirm the call contexts
        this fallback is meant to cover.
        """
        try:
            return asyncio.run(coro)
        except RuntimeError:
            loop = asyncio.get_event_loop()
            return loop.run_until_complete(coro)

    def parse_first_json(self, resp: str) -> Dict[str, Any]:
        """Return the first JSON object found in the LLM response, or {}."""
        try:
            data = extract_json(resp)
            return data[0] if data else {}
        except Exception:
            # Deliberate best-effort: malformed responses yield an empty dict.
            return {}
from __future__ import annotations
import asyncio
import json
from typing import Dict, List, Optional
from core.tool import tool, tool_func
from core.tools.segment_llm import SegmentLLMBase
FACT_DIMENSIONS: List[str] = ["当事人", "标的", "金额", "支付", "交付", "质量", "知识产权", "保密", "违约责任", "争议解决"]
SUMMARY_SYSTEM_PROMPT = f'''
你是合同事实提取智能体(SegmentSummary)。
仅输出“本分段的客观事实”,不做风险判断,不做主观推测。
【输出结构】
- facts: 一个对象,键为预设维度,值为该分段出现的事实(未出现的维度可缺省或置空)。
- 维度列表:{", ".join(FACT_DIMENSIONS)}。
- 若原文包含多个事实,可使用列表或子对象表达,但保持紧凑、可读。
【约束】
- 严禁编造或改写原文未出现的信息。
- 不输出与事实无关的解释或额外文字。
'''
SUMMARY_USER_PROMPT = '''
【分段原文】
{segment_text}
【上下文事实】
{context_facts}
请提取本段出现的客观事实,按照指定维度输出 JSON。未出现的维度可省略。
输出示例:'''
OUTPUT_EXAMPLE = '''
```json
{
"facts": {
"支付": {"方式": "银行转账", "时间": "验收后30日内"},
"违约责任": {"违约金比例": "合同总金额的5%"}
}
}
```
'''
@tool("segment_summary", "分段事实提取")
class SegmentSummaryTool(SegmentLLMBase):
    """Extract objective facts from a single contract segment via the LLM."""
    def __init__(self) -> None:
        super().__init__(SUMMARY_SYSTEM_PROMPT)

    @tool_func(
        {
            "type": "object",
            "properties": {
                "segment_id": {"type": "int"},
                "segment_text": {"type": "string"},
                "party_role": {"type": "string"},
                "context_facts": {"type": "object"},
            },
            "required": ["segment_id", "segment_text"],
        }
    )
    def run(
        self,
        segment_id: int,
        segment_text: str,
        party_role: str = "",
        context_facts: Optional[Dict] = None,
    ) -> Dict:
        """Synchronous entry point; returns the fact dict, or {} on any failure."""
        try:
            return self.run_with_loop(self._summarize_async(segment_id, segment_text, party_role, context_facts))
        except Exception:
            # Deliberate best-effort: summarization failures yield no facts.
            return {}

    def _build_prompt(self, segment_text: str, context_facts: Optional[Dict], party_role: str) -> List[Dict[str, str]]:
        """Fill the user prompt with segment text and accumulated context facts."""
        # NOTE(review): party_role is accepted but never used here — confirm intent.
        user_content = SUMMARY_USER_PROMPT.format(
            segment_text=segment_text,
            context_facts=json.dumps(context_facts or {}, ensure_ascii=False),
        ) + OUTPUT_EXAMPLE
        return self.build_messages(user_content)

    async def _summarize_async(
        self,
        segment_id: int,
        segment_text: str,
        party_role: str,
        context_facts: Optional[Dict],
    ) -> Dict:
        """Call the LLM, parse its JSON reply, and tag the facts with segment_id."""
        msgs = self._build_prompt(segment_text, context_facts, party_role)
        final_facts: Dict = {}
        try:
            resp = await self.chat_async(msgs)
            data = self.parse_first_json(resp)
            facts = data.get("facts") or {}
        except Exception:
            facts = {}
        # print(f'SegmentSummaryTool facts: {facts}')
        # A bare list from the model is wrapped under the "内容" (content) key.
        if isinstance(facts,list):
            final_facts['内容'] = facts
        else:
            final_facts = facts
        final_facts['segment_id'] = segment_id
        return final_facts
if __name__=='__main__':
    # Smoke test against the live LLM endpoint.
    tool = SegmentSummaryTool()
    res = tool.run(
        segment_id=1,
        segment_text="甲方应于合同签订之日起30日内向乙方支付合同总金额的50%,余款在货物验收合格后30日内付清.",
        context_facts={},
    )
    print(res)
\ No newline at end of file
File added
import numpy as np
def calculate_grpo_advantages(rewards, epsilon=1e-8):
    """Compute GRPO group-relative advantage values.

    Normalizes the rewards of one sampling group to zero mean and unit
    standard deviation; ``epsilon`` guards the division when all rewards in
    the group are identical (std == 0).

    :param rewards: list or array of rewards for one group of samples
    :param epsilon: stability constant added to the std denominator
    :return: numpy array of normalized advantages
    """
    group = np.array(rewards)
    centered = group - np.mean(group)
    return centered / (np.std(group) + epsilon)
# Example data: identical rewards collapse to (near-)zero advantages.
your_rewards = [1.1,1.1,1.1,1.1]
advantages = calculate_grpo_advantages(your_rewards)
print(f"原始奖励值: {your_rewards}")
print(f"GRPO 优势值: {advantages.round(4)}")
\ No newline at end of file
from spire.doc import *
from spire.doc.common import *
from pathlib import Path
# Walk a folder of contract files and report the average extracted text length
# (replaces the earlier single-file reading logic).
folder = Path("/home/ccran/lufa-contract/datasets/2-1.合同审核-近三年审核前的合同文件")  # change here for a custom directory
lengths = []
for p in folder.iterdir():
    if not p.is_file():
        continue
    try:
        # Spire.Doc handles .doc/.docx; text length is measured on GetText().
        doc = Document()
        doc.LoadFromFile(str(p))
        text = doc.GetText()
        length = len(text)
        print(f"{p.name}: {length}")
        lengths.append(length)
    except Exception as e:
        # Unreadable/unsupported files are reported and skipped.
        print(f"跳过文件 {p.name}: {e}")
if lengths:
    avg = sum(lengths) / len(lengths)
    print(f"平均长度: {avg:.2f}(共 {len(lengths)} 个文件)")
else:
    print("没有可处理的文件。")
\ No newline at end of file
import json
# Scratch check: inspect the "urls" field of a sample request payload
# (note: "urls" is a JSON-encoded string, not a list).
d = {
    "urls": "[\"/api/common/file/read/大模型角色扮演综述总结.docx?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJidWNrZXROYW1lIjoiY2hhdCIsInRlYW1JZCI6IjM2MzU2MjY1NjYzMDYyMzQzMDM2NjQzNCIsInVpZCI6IjY0NjQ2MTM2MzU2MzM5MzUzNzMwMzU2MiIsImZpbGVJZCI6IjY5NjhhOGU4M2YyMDEzMWI1MjdjMDdjZSIsImV4cCI6MTc2OTA3MTQ2NCwiaWF0IjoxNzY4NDY2NjY0fQ.B_W_CyzpaW__Hb8mrN-Xn5M2FOg73zFN4KoKFnSEQzs\"]",
    "conversation_id": "38188df54cbe4884b2b6eb0b30dec898"
}
print(d['urls'])
\ No newline at end of file
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional
from uuid import uuid4
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
import uvicorn
from utils.common_util import extract_url_file, format_now
from utils.http_util import download_file
from core.cache import get_cached_doc_tool, get_cached_memory
from core.config import doc_support_formats
from core.tools.segment_summary import SegmentSummaryTool
from core.tools.segment_review import SegmentReviewTool
from core.memory import RiskFinding
app = FastAPI(title="合同审查智能体", version="0.1.0")
# Scratch directory for downloaded contract files.
TMP_DIR = Path(__file__).resolve().parent / "tmp"
TMP_DIR.mkdir(parents=True, exist_ok=True)
# Shared tool instances reused across requests.
summary_tool = SegmentSummaryTool()
review_tool = SegmentReviewTool()
########################################################################################################################
class DocumentParseRequest(BaseModel):
    # Conversation scoping key: selects the cached doc tool / memory store.
    conversation_id: str
    urls: List[str] = Field(..., description="File download url")


class DocumentParseResponse(BaseModel):
    conversation_id: str
    text: str             # full extracted document text
    chunk_ids: List[int]  # ids of the parsed chunks (1-based per the doc tool)
@app.post("/documents/parse", response_model=DocumentParseResponse)
def parse_document(payload: DocumentParseRequest) -> DocumentParseResponse:
    """Download the first URL, parse it with the conversation's cached doc
    tool, and return the full text plus chunk ids.

    Raises 400 on missing/unparsable URLs and 500 on download failure.
    """
    if not payload.urls:
        raise HTTPException(status_code=400, detail="No URLs provided")
    try:
        filename = extract_url_file(payload.urls[0], doc_support_formats)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Failed to parse url: {exc}")
    file_path = str(TMP_DIR / filename)
    try:
        download_file(payload.urls[0], file_path)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Download failed: {exc}")
    # Cached per conversation so the segment endpoints reuse the parsed document.
    doc_obj = get_cached_doc_tool(payload.conversation_id)
    doc_obj.load(file_path)
    text = doc_obj.get_all_text()
    chunk_ids = doc_obj.get_chunk_id_list()
    return DocumentParseResponse(
        conversation_id=payload.conversation_id, text=text,chunk_ids=chunk_ids
    )
########################################################################################################################
class SegmentSummaryRequest(BaseModel):
    conversation_id: str
    segment_id: int                        # 1-based chunk id from /documents/parse
    party_role: Optional[str] = ""         # forwarded to the summary tool
    context_facts: Optional[Dict] = None   # override; defaults to stored facts


class SegmentSummaryResponse(BaseModel):
    conversation_id: str
    segment_id: int
    facts: Dict  # facts extracted for this segment (includes "segment_id")
@app.post("/segments/summary/facts", response_model=SegmentSummaryResponse)
def summarize_facts(payload: SegmentSummaryRequest) -> SegmentSummaryResponse:
    """Extract facts for one segment and append them to the conversation memory.

    Requires /documents/parse to have been called first for this conversation.
    """
    store = get_cached_memory(payload.conversation_id)
    try:
        doc_obj = get_cached_doc_tool(payload.conversation_id)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Document tool not available: {exc}")
    chunk_idx = payload.segment_id - 1  # chunk_id is 1-based in SpireWordDoc
    try:
        segment_text = doc_obj.get_chunk_item(chunk_idx)
    except Exception as exc:
        raise HTTPException(status_code=404, detail=f"Segment text not found for id {payload.segment_id}: {exc}. Please parse document first.")
    # Prior segments' facts are used as context unless the caller overrides them.
    result = summary_tool.run(
        segment_id=payload.segment_id,
        segment_text=segment_text,
        party_role=payload.party_role or "",
        context_facts=payload.context_facts or store.get_facts(),
    )
    store.add_facts(result)
    return SegmentSummaryResponse(
        conversation_id=payload.conversation_id,
        segment_id=payload.segment_id,
        facts=result,
    )
########################################################################################################################
class SegmentReviewRequest(BaseModel):
    conversation_id: str
    segment_id: int                               # 1-based chunk id from /documents/parse
    party_role: Optional[str] = ""                # forwarded to the review tool
    ruleset_id: Optional[str] = "rule-v1"         # review rule-set identifier
    context_memories: Optional[List[Dict]] = None # override; defaults to stored facts


class SegmentReviewResponse(BaseModel):
    conversation_id: str
    segment_id: int
    overall_conclusion: str
    findings: List[Dict]  # raw findings as returned by the review tool
@app.post("/segments/review/findings", response_model=SegmentReviewResponse)
def review_segment(payload: SegmentReviewRequest) -> SegmentReviewResponse:
    """Run the review tool on one segment and persist its findings to memory.

    Requires /documents/parse to have been called first for this conversation.
    """
    store = get_cached_memory(payload.conversation_id)
    try:
        doc_obj = get_cached_doc_tool(payload.conversation_id)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Document tool not available: {exc}")
    chunk_idx = payload.segment_id - 1  # chunk_id is 1-based in SpireWordDoc
    try:
        segment_text = doc_obj.get_chunk_item(chunk_idx)
    except Exception as exc:
        raise HTTPException(status_code=404, detail=f"Segment text not found for id {payload.segment_id}: {exc}. Please parse document first.")
    result = review_tool.run(
        segment_id=payload.segment_id,
        segment_text=segment_text,
        ruleset_id=payload.ruleset_id or "rule-v1",
        party_role=payload.party_role or "",
        context_memories=payload.context_memories or store.get_facts(),
    )
    # Persist findings to memory store
    for f in result.get("findings", []) or []:
        try:
            store.add_finding_from_dict({
                "rule_title": f.get("rule_title", ""),
                "segment_id": payload.segment_id,
                "original_text": f.get("original_text",''),
                "issue": f.get("issue", ""),
                # Tolerate both "risk_level" and "level" keys from the tool output.
                "risk_level": (f.get("risk_level") or f.get("level") or "").upper(),
                "suggestion": f.get("suggestion", ""),
            })
        except Exception:
            # Findings failing validation (e.g. bad risk_level) are skipped, not fatal.
            continue
    return SegmentReviewResponse(
        conversation_id=payload.conversation_id,
        segment_id=payload.segment_id,
        overall_conclusion=result.get("overall_conclusion", ""),
        findings=result.get("findings", []),
    )
########################################################################################################################
class ConversationResponse(BaseModel):
    """Response for POST /conversations/new."""
    conversation_id: str  # random hex id generated server-side (uuid4().hex)
    created_at: str  # creation timestamp string produced by format_now()
@app.post("/conversations/new", response_model=ConversationResponse)
def new_conversation() -> ConversationResponse:
    """Mint a fresh conversation id and return it with a creation timestamp."""
    return ConversationResponse(
        conversation_id=uuid4().hex,
        created_at=format_now(),
    )
########################################################################################################################
class MemoryExportRequest(BaseModel):
    """Request body for POST /memory/export."""
    conversation_id: str
    file_name: Optional[str] = None  # optional output name passed to export_to_excel
class MemoryExportResponse(BaseModel):
    """Response for POST /memory/export."""
    conversation_id: str
    url: str  # best-effort download URL extracted from the export result ("" if none found)
    data: Dict  # raw export result (wrapped as {"result": ...} when not a dict)
@app.post("/memory/export", response_model=MemoryExportResponse)
def export_memory(payload: MemoryExportRequest) -> MemoryExportResponse:
    """Export the conversation's memory to Excel and surface a download URL."""
    memory = get_cached_memory(payload.conversation_id)
    try:
        exported = memory.export_to_excel(file_name=payload.file_name)
    except ImportError as err:
        raise HTTPException(status_code=500, detail=str(err))
    except Exception as err:
        raise HTTPException(status_code=500, detail=f"Export failed: {err}")
    # The exporter may return either a bare URL string or a dict carrying
    # the URL under one of several known keys.
    download_url = ""
    if isinstance(exported, str):
        download_url = exported
    elif isinstance(exported, dict):
        url_keys = ("url", "file_url", "fileUrl", "link", "downloadUrl", "path", "filePath")
        download_url = next(
            (exported[k] for k in url_keys if isinstance(exported.get(k), str) and exported.get(k)),
            "",
        )
    return MemoryExportResponse(
        conversation_id=payload.conversation_id,
        url=download_url,
        data=exported if isinstance(exported, dict) else {"result": exported},
    )
if __name__ == "__main__":
    # Serve the API on all interfaces; auto-reload disabled.
    uvicorn.run(
        "main:app", host="0.0.0.0", port=18168, log_level="info", reload=False
    )
\ No newline at end of file
import json
import random
import re
from datetime import datetime
from typing import Dict, List
from loguru import logger
from core.config import max_chunk_page, min_single_chunk_size, max_single_chunk_size
def random_str(l=5):
    """Return a random lowercase ASCII string of length *l*.

    Fix: the original used ``random.sample`` (sampling without replacement),
    which raises ValueError for l > 26 and can never repeat a character.
    ``random.choices`` supports any length and draws uniformly per position.
    """
    return ''.join(random.choices('abcdefghijklmnopqrstuvwxyz', k=l))
def format_now():
    """Return the current local time as a 'YYYY-MM-DD HH:MM:SS' string."""
    return f"{datetime.now():%Y-%m-%d %H:%M:%S}"
# Extract a file name from a URL
def extract_url_file(url, support_formats):
    """Return the first file name in *url* ending with one of *support_formats*.

    File names may contain CJK characters, parentheses, digits, word chars
    and hyphens. Fix: the extension is now regex-escaped, so '.docx' only
    matches a literal dot (previously '.' matched any character); also
    renamed the loop variable to avoid shadowing the builtin ``format``.

    Raises:
        Exception: when no supported file name is found in the URL.
    """
    pattern = '|'.join(
        r'[\u4e00-\u9fa5()()0-9\w-]+' + re.escape(fmt) for fmt in support_formats
    )
    search_result = re.search(pattern, url)
    if search_result:
        return search_result.group()
    else:
        raise Exception(f'{support_formats} not found in url:{url}')
# Derive the per-chunk size from the total text length
def adjust_single_chunk_size(all_text_len):
    """Clamp total_len // max_chunk_page into [min_single_chunk_size, max_single_chunk_size]."""
    desired = all_text_len // max_chunk_page
    if desired < min_single_chunk_size:
        return min_single_chunk_size
    if desired > max_single_chunk_size:
        return max_single_chunk_size
    return desired
# Extract JSON objects from a string
def extract_json(json_str: str) -> List[Dict]:
    """Extract a list of JSON objects from *json_str*.

    Strategies are tried in order, stopping at the first that yields results:
    1. fenced ```json ... ``` code blocks
    2. the whole string parsed as JSON
    3. plain fenced ``` ... ``` code blocks
    4. any brace/bracket fragment (heuristic, best effort)
    Top-level JSON arrays are flattened into the returned list.
    """
    parsed: List = []
    def _accumulate(candidate: str) -> bool:
        text = (candidate or '').strip()
        if not text:
            return False
        # strip control characters that break json.loads
        text = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', text)
        try:
            loaded = json.loads(text, strict=False)
        except Exception as e:
            logger.error(f"JSON解析失败: {e}")
            return False
        if isinstance(loaded, list):
            parsed.extend(loaded)
        else:
            parsed.append(loaded)
        return True
    source = json_str or ''
    # 1) fenced ```json blocks
    for fragment in re.findall(r'```json([\s\S]*?)```', source, re.DOTALL):
        _accumulate(fragment)
    if parsed:
        return parsed
    # 2) whole string as JSON
    if _accumulate(source):
        return parsed
    # 3) any fenced block
    for fragment in re.findall(r'```([\s\S]*?)```', source, re.DOTALL):
        if _accumulate(fragment):
            return parsed
    # 4) heuristic: brace/bracket fragments
    for fragment in re.findall(r'(\{[\s\S]*?\}|\[[\s\S]*?\])', source, re.DOTALL):
        _accumulate(fragment)
    return parsed
def remove_duplicates_by_key(data_list, key):
    """Drop dicts whose *key* value is a substring of a longer kept value.

    Items are visited longest-value first, so the longest variant of any
    overlapping string survives; relative order among kept items follows
    that length-descending order.
    """
    by_length_desc = sorted(
        data_list,
        key=lambda item: len(item.get(key, "")),
        reverse=True,
    )
    kept = []
    kept_values = []
    for item in by_length_desc:
        candidate = item.get(key, "")
        # keep only if not contained in any already-kept (longer) value
        if all(candidate not in existing for existing in kept_values):
            kept_values.append(candidate)
            kept.append(item)
    return kept
def extract_drop_json_part(json_str):
    """Return *json_str* with every fenced ```json block removed, stripped.

    Fix: the original passed ``re.DOTALL`` as ``re.sub``'s fourth positional
    argument, which is *count* (so substitutions were capped at 16), not
    *flags*. No flag is needed at all: the character class in the pattern
    already matches newlines.
    """
    json_pattern = r'```json([\s\S]*?)```'
    non_json_content = re.sub(json_pattern, '', json_str)
    return non_json_content.strip()
def group_chunk_by_len(chunk_list: List[Dict], key: str, chunk_len: int) -> List[List[Dict]]:
    """Greedily pack dicts into groups whose summed len(d[key]) fits chunk_len.

    Order is preserved. A single item longer than *chunk_len* still forms
    its own group (groups are only flushed when non-empty).
    """
    groups: List[List[Dict]] = []
    current: List[Dict] = []
    current_len = 0
    for item in chunk_list:
        # .get(key, "") guards against dicts missing the key
        item_len = len(item.get(key, ""))
        # flush the current group when this item would overflow it
        if current and current_len + item_len > chunk_len:
            groups.append(current)
            current = []
            current_len = 0
        current.append(item)
        current_len += item_len
    # flush any trailing partial group
    if current:
        groups.append(current)
    return groups
if __name__ == '__main__':
    # Smoke test: a fenced JSON object should parse into a one-element list.
    json_str = '```json{"segment_id": "seg-001"}```'
    print(extract_json(json_str))
    pass
import os
from abc import ABC, abstractmethod
# Document base class
class DocBase(ABC):
    """Abstract interface for chunked document backends.

    Implementations (e.g. SpireWordDoc) load a document, split it into
    chunks addressed by chunk_id, and support per-chunk comments
    (annotations).
    """
    def __init__(self, **kwargs):
        # Path/name are populated by load(); release() removes _doc_path from disk.
        self._doc_path = None
        self._doc_name = None
        self._kwargs = kwargs
        # Upper bound on a single chunk's text length (default 2000 chars).
        self._max_single_chunk_size = kwargs.get('max_single_chunk_size', 2000)
    @abstractmethod
    def load(self, doc_path):
        """Load the file content and initialize internal state."""
        pass
    @abstractmethod
    def adjust_chunk_size(self):
        """Recompute the chunking granularity."""
        pass
    @abstractmethod
    async def get_from_ocr(self):
        """Populate document content asynchronously (presumably via the OCR service — confirm in implementations)."""
        pass
    # Fetch a document chunk's text by chunk_id
    @abstractmethod
    def get_chunk_item(self, chunk_id):
        pass
    # Fetch metadata about a chunk
    @abstractmethod
    def get_chunk_info(self, chunk_id):
        pass
    @abstractmethod
    def get_chunk_location(self, chunk_id):
        pass
    # Add comments (annotations) to a chunk
    @abstractmethod
    def add_chunk_comment(self, chunk_id, comments):
        pass
    # Edit existing chunk comments
    @abstractmethod
    def edit_chunk_comment(self, comments):
        pass
    @abstractmethod
    def delete_chunk_comment(self, comments):
        pass
    @abstractmethod
    def get_chunk_id_list(self, step=1):
        pass
    # Number of chunks in the document
    @abstractmethod
    def get_chunk_num(self):
        pass
    # Full document text (the original Chinese comment said "chunk size" — stale)
    @abstractmethod
    def get_all_text(self):
        pass
    # Save the document to *path* (no-op in the base class; optional override)
    def to_file(self, path, **kwargs):
        pass
    def release(self):
        """Delete the backing file from disk, if present."""
        if self._doc_path and os.path.exists(self._doc_path):
            os.remove(self._doc_path)
from __future__ import annotations
from pathlib import Path
from typing import Dict, List, Optional, Union
class ExcelLoadError(Exception):
    """Raised when Excel loading fails (missing file, unreadable workbook, or missing openpyxl)."""
class ExcelUtil:
    """Helper for reading Excel sheets (xlsx/xlsm) using openpyxl.

    openpyxl is imported lazily so the module is importable without it;
    failures surface as ExcelLoadError.
    """
    def __init__(self, file_path: Union[str, Path]):
        self.file_path = Path(file_path)
    @staticmethod
    def _import_openpyxl():
        # Lazy import: missing openpyxl only fails at first use, as a domain error.
        try:
            import openpyxl  # type: ignore
        except ImportError as exc:
            raise ExcelLoadError("openpyxl is required. Install via 'pip install openpyxl'") from exc
        return openpyxl
    def _ensure_exists(self) -> None:
        # Fail fast with a domain error before touching openpyxl.
        if not self.file_path.exists():
            raise ExcelLoadError(f"File not found: {self.file_path}")
    def load(
        self,
        sheet_name: Optional[str] = None,
        has_header: bool = True,
    ) -> List[Union[Dict[str, object], List[object]]]:
        """
        Load data from the Excel file into a list.
        Args:
            sheet_name: Sheet to read; defaults to the first sheet.
            has_header: When True, return a list of dict rows keyed by header row; otherwise return list of lists.
        Raises:
            ExcelLoadError: When the file is missing, not readable, or openpyxl is unavailable.
        """
        self._ensure_exists()
        openpyxl = self._import_openpyxl()
        try:
            wb = openpyxl.load_workbook(self.file_path, data_only=True, read_only=True)
        except Exception as exc:
            raise ExcelLoadError(f"Failed to open Excel file: {exc}") from exc
        try:
            ws = wb[sheet_name] if sheet_name else wb.active
            rows = list(ws.iter_rows(values_only=True))
        finally:
            # Fix: read-only workbooks keep the underlying file handle open
            # until close() is called; the original leaked it.
            wb.close()
        if not rows:
            return []
        if not has_header:
            return [list(row) for row in rows]
        headers = [str(h).strip() if h is not None else "" for h in rows[0]]
        data_rows = rows[1:]
        result: List[Dict[str, object]] = []
        for row in data_rows:
            # Cells beyond the header width get synthetic "colN" keys.
            row_dict = {headers[i] if i < len(headers) else f"col{i}": row[i] for i in range(len(row))}
            result.append(row_dict)
        return result
    def list_sheets(self) -> List[str]:
        """Return available sheet names for the current file."""
        self._ensure_exists()
        openpyxl = self._import_openpyxl()
        try:
            wb = openpyxl.load_workbook(self.file_path, read_only=True)
        except Exception as exc:
            raise ExcelLoadError(f"Failed to read sheet names: {exc}") from exc
        try:
            return wb.sheetnames
        finally:
            wb.close()  # see load(): avoid leaking the read-only file handle
    def find_value_by_column(
        self,
        key_column: str,
        key_value: object,
        value_column: str,
        sheet_name: Optional[str] = None,
    ) -> Optional[object]:
        """
        Return the first value in column `value_column` where column `key_column` equals `key_value`.
        Both column arguments should match header names when `has_header=True`.
        Returns None when no row matches.
        """
        rows = self.load(sheet_name=sheet_name, has_header=True)
        for row in rows:
            if row.get(key_column) == key_value:
                return row.get(value_column)
        return None
    def map_rows(self, sheet_name: Optional[str], column_map: Dict[str, str]) -> List[Dict[str, object]]:
        """
        Load rows as dicts (header->value) and remap keys using column_map.
        column_map: {new_key: header_name_in_excel}
        Unmapped headers are ignored; missing headers yield None.
        """
        rows = self.load(sheet_name=sheet_name, has_header=True)
        mapped: List[Dict[str, object]] = []
        for row in rows:
            mapped_row = {new_key: row.get(header) for new_key, header in column_map.items()}
            mapped.append(mapped_row)
        return mapped
    @classmethod
    def load_excel(
        cls,
        file_path: Union[str, Path],
        sheet_name: Optional[str] = None,
        has_header: bool = True,
    ) -> List[Union[Dict[str, object], List[object]]]:
        """Convenience classmethod wrapper for one-off loads."""
        return cls(file_path).load(sheet_name=sheet_name, has_header=has_header)
    @classmethod
    def list_excel_sheets(cls, file_path: Union[str, Path]) -> List[str]:
        """Convenience classmethod wrapper to list sheet names for a path."""
        return cls(file_path).list_sheets()
    @classmethod
    def find_value_by_column_excel(
        cls,
        file_path: Union[str, Path],
        key_column: str,
        key_value: object,
        value_column: str,
        sheet_name: Optional[str] = None,
    ) -> Optional[object]:
        """Convenience wrapper for find_value_by_column on a file path."""
        return cls(file_path).find_value_by_column(
            key_column=key_column,
            key_value=key_value,
            value_column=value_column,
            sheet_name=sheet_name,
        )
    @classmethod
    def load_mapped_excel(
        cls,
        file_path: Union[str, Path],
        sheet_name: Optional[str],
        column_map: Dict[str, str],
    ) -> List[Dict[str, object]]:
        """Convenience wrapper to load rows and remap headers to new keys."""
        return cls(file_path).map_rows(sheet_name=sheet_name, column_map=column_map)
if __name__ == "__main__":
    # One-off load: raw rows (list of lists) from the first sheet.
    rows = ExcelUtil.load_excel("data/rules.xlsx", sheet_name=None, has_header=False)
    print(rows)
import json
import os
import requests
from loguru import logger
from requests_toolbelt import MultipartEncoder
from core.config import base_fastgpt_url, base_backend_url, outer_backend_url
def fastgpt_openai_chat(url, token, model, chat_id, file_url, text, stream=True):
    """Call a FastGPT OpenAI-compatible chat endpoint with a file URL + text prompt.

    Args:
        url: Full chat-completions endpoint URL.
        token: Bearer token placed in the Authorization header.
        model: Model name forwarded in the request body.
        chat_id: FastGPT chat session id ("chatId").
        file_url: URL attached as a "file_url" content part named "文件".
        text: The user's text prompt.
        stream: When True, consume the SSE stream and concatenate the deltas;
            otherwise parse the single JSON response.

    Returns:
        The assistant's full reply text (possibly partial if stream chunks
        failed to decode; see the except branch below).
    """
    data = {
        "chatId": chat_id,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "file_url", "name": "文件", "url": file_url},
                    {
                        "type": "text",
                        "text": text,
                    },
                ],
            }
        ],
        "model": model,
        "stream": stream,
    }
    headers = {"Content-Type": "application/json", "Authorization": f"Bearer {token}"}
    # NOTE(review): requests timeouts are (connect, read) in SECONDS;
    # 60000 s is ~16.7 hours — this looks like it was meant to be
    # milliseconds. Confirm the intended value.
    response = requests.post(
        url=url,
        headers=headers,
        data=json.dumps(data),
        stream=True,
        timeout=(60000, 60000),
    )
    rsp = ""
    if stream:
        # NOTE(review): this assumes each 8 KiB chunk holds exactly one
        # complete "data: ..." SSE event; events split across chunk
        # boundaries fail to decode and are dropped by the except below.
        for chunk in response.iter_content(8192):
            try:
                chunk = chunk.decode("utf-8")
                if "data: [DONE]" in chunk:
                    break
                # logger.info(chunk)
                # chunk[6:] strips the leading "data: " SSE prefix.
                stream_rsp = json.loads(chunk[6:])
                rsp += stream_rsp.get("choices")[0]["delta"]["content"]
            except Exception as e:
                logger.error(f"error decode:{e} chunk:{chunk}")
    else:
        # print(response.text)
        json_rsp = json.loads(response.text)
        rsp = json_rsp.get("choices")[0]["message"]["content"]
    return rsp
def upload_file(path, input_url_to_inner=True, output_url_to_inner=False):
    """Upload the file at *path* to the backend file service.

    Logs in as the admin user to obtain an access token, then posts the
    file as multipart form data.

    Returns:
        The "data" payload of the upload response (the stored file's URL/info).

    Raises:
        Exception: when login or the upload fails.
    """
    # Log in to obtain an access token
    login_data = {"username": "admin", "password": "admin@jpai.com"}
    login_url = f"{base_backend_url}/admin-api/system/auth/login"
    response = requests.post(
        url=login_url,
        headers={"Content-Type": "application/json"},
        data=json.dumps(login_data),
    )
    try:
        token = json.loads(response.text).get("data").get("accessToken")
    except Exception as e:
        logger.error(f"后端登录异常:{e}")
        # Fix: the original fell through here with `token` undefined and
        # crashed later with a confusing NameError; fail fast instead.
        raise Exception(f"后端登录失败 Response text: {response.text}") from e
    # logger.info(f"上传Token: {token}")
    # Upload the file; fix: close the file handle (the original leaked it).
    upload_url = f"{base_backend_url}/admin-api/infra/file/upload"
    with open(path, "rb") as fh:
        m = MultipartEncoder(fields={"file": (os.path.basename(path), fh)})
        response = requests.post(
            url=upload_url,
            headers={"Content-Type": m.content_type, "Authorization": token},
            data=m,
        )
    res = json.loads(response.text).get("data")
    if res:
        return res
    else:
        raise Exception(f"上传{path}失败 Response text: {response.text}")
# Download url to a local path
def download_file(url, path, input_url_to_inner=True):
    """Download *url* and save the body to the local file *path*.

    Relative URLs are prefixed with the FastGPT base URL; outer backend
    hosts are rewritten to the inner backend address. Failures are logged,
    not raised.
    """
    # Fix: the original `startswith("http:")` test treated https:// URLs as
    # relative and wrongly prepended the FastGPT base URL to them.
    if not url.startswith(("http://", "https://")):
        url = base_fastgpt_url + url
    url = url.replace(outer_backend_url, base_backend_url)
    logger.info(f"url准备下载:{url}")
    # Send an HTTP request to the URL (bounded so a dead server cannot hang forever)
    response = requests.get(url, timeout=600)
    # Ensure the request succeeded
    if response.status_code == 200:
        # Write the response body to the local file
        with open(path, "wb") as f:
            f.write(response.content)
        logger.info(f"{url}文件下载成功,保存到{path}")
    else:
        logger.error(f"{url}文件下载失败. HTTP Status Code: {response.status_code}")
def url_replace_fastgpt(origin: str):
    """Prefix a relative URL with the FastGPT base URL; pass absolute URLs through.

    Fix: the original `startswith("http:")` check treated https:// URLs as
    relative and wrongly prepended the base URL to them.
    """
    if not origin.startswith(("http://", "https://")):
        origin = base_fastgpt_url + origin
    return origin
if __name__ == "__main__":
    # d = '/home/ccran/file.docx'
    # Sanity check of the basename extraction used by upload_file.
    d = "file.docx"
    print(os.path.basename(d))
from loguru import logger
from openai import AsyncOpenAI
from dataclasses import dataclass
from tenacity import retry, stop_after_attempt, stop_after_delay, wait_fixed
import asyncio
@dataclass
class LLMConfig:
    """Connection settings for an OpenAI-compatible chat endpoint."""
    base_url: str
    api_key: str
    model: str
class OpenAITool:
    """Async chat helper around an OpenAI-compatible endpoint with bounded concurrency."""
    def __init__(self, llm_config: LLMConfig, max_workers: int = 5):
        # max_workers caps concurrent requests issued by mul_chat()
        self.max_workers = max_workers
        self.llm_config = llm_config
        self.client = AsyncOpenAI(
            base_url=llm_config.base_url, api_key=llm_config.api_key
        )
    # NOTE(review): stop_after_attempt(1) OR-ed with stop_after_delay(600)
    # stops after the first attempt, so this decorator never actually
    # retries — confirm whether more attempts were intended.
    @retry(stop=stop_after_delay(600) | stop_after_attempt(1), wait=wait_fixed(1))
    async def chat(self, msg, tools=None):
        """Send one chat request.

        Without *tools*: a leading system message is moved into
        ``extra_body.variables.system`` and removed from the message list,
        and the assistant's text content is returned. With *tools*: the
        model's tool_calls are returned (tool_choice="auto").
        """
        if tools is None:
            extra_body = None
            if msg[0]['role'] == 'system':
                extra_body = {
                    'variables': {
                        'system': msg[0]['content']
                    }
                }
                # Fix: only strip the first message when it really was the
                # system message; the original stripped it unconditionally,
                # silently dropping the first user message otherwise.
                msg = msg[1:]
            response = await self.client.chat.completions.create(
                model=self.llm_config.model,
                messages=msg,
                extra_body=extra_body
            )
            # Fix: the original read model_extra["reasoning_content"] into an
            # unused local, which crashes when model_extra is None; dropped.
            return response.choices[0].message.content
        else:
            response = await self.client.chat.completions.create(
                model=self.llm_config.model,
                messages=msg,
                tools=tools,
                tool_choice="auto",
            )
            return response.choices[0].message.tool_calls
    async def mul_chat(self, msgs, tools=None):
        """Run chat() for every message list in *msgs*, at most max_workers concurrently."""
        sem = asyncio.Semaphore(self.max_workers)
        async def _wrapped(m):
            async with sem:
                return await self.chat(m, tools)
        return await asyncio.gather(*[_wrapped(m) for m in msgs])
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment