Commit 37faa701 by ccran

Initial commit

parents
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="PublishConfigData" remoteFilesAllowedToDisappearOnAutoupload="false">
<serverData>
<paths name="ccran@117.157.192.95:33333 password">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="天水aidemo">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="天水coast">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="天水练手verl">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
</serverData>
</component>
</project>
\ No newline at end of file
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
<option name="ignoredPackages">
<value>
<list size="70">
<item index="0" class="java.lang.String" itemvalue="azure_storage" />
<item index="1" class="java.lang.String" itemvalue="onnxruntime" />
<item index="2" class="java.lang.String" itemvalue="torch" />
<item index="3" class="java.lang.String" itemvalue="openai-whisper" />
<item index="4" class="java.lang.String" itemvalue="torchaudio" />
<item index="5" class="java.lang.String" itemvalue="kaldialign" />
<item index="6" class="java.lang.String" itemvalue="tiktoken" />
<item index="7" class="java.lang.String" itemvalue="whisperspeech" />
<item index="8" class="java.lang.String" itemvalue="faster-whisper" />
<item index="9" class="java.lang.String" itemvalue="braceexpand" />
<item index="10" class="java.lang.String" itemvalue="chromadb" />
<item index="11" class="java.lang.String" itemvalue="httpx" />
<item index="12" class="java.lang.String" itemvalue="alembic" />
<item index="13" class="java.lang.String" itemvalue="rebyte-langchain" />
<item index="14" class="java.lang.String" itemvalue="emoji" />
<item index="15" class="java.lang.String" itemvalue="pgvector" />
<item index="16" class="java.lang.String" itemvalue="SQLAlchemy" />
<item index="17" class="java.lang.String" itemvalue="psycopg2-binary" />
<item index="18" class="java.lang.String" itemvalue="python-dotenv" />
<item index="19" class="java.lang.String" itemvalue="firebase_admin" />
<item index="20" class="java.lang.String" itemvalue="numpy" />
<item index="21" class="java.lang.String" itemvalue="edge-tts" />
<item index="22" class="java.lang.String" itemvalue="aioconsole" />
<item index="23" class="java.lang.String" itemvalue="llama_index" />
<item index="24" class="java.lang.String" itemvalue="langchain" />
<item index="25" class="java.lang.String" itemvalue="starlette" />
<item index="26" class="java.lang.String" itemvalue="anthropic" />
<item index="27" class="java.lang.String" itemvalue="google-cloud-speech" />
<item index="28" class="java.lang.String" itemvalue="beautifulsoup4" />
<item index="29" class="java.lang.String" itemvalue="SpeechRecognition" />
<item index="30" class="java.lang.String" itemvalue="pydantic" />
<item index="31" class="java.lang.String" itemvalue="faster_whisper" />
<item index="32" class="java.lang.String" itemvalue="pytest" />
<item index="33" class="java.lang.String" itemvalue="readerwriterlock" />
<item index="34" class="java.lang.String" itemvalue="pypdf" />
<item index="35" class="java.lang.String" itemvalue="pyaudio" />
<item index="36" class="java.lang.String" itemvalue="openai" />
<item index="37" class="java.lang.String" itemvalue="fastapi" />
<item index="38" class="java.lang.String" itemvalue="twilio" />
<item index="39" class="java.lang.String" itemvalue="transformers" />
<item index="40" class="java.lang.String" itemvalue="chonkie" />
<item index="41" class="java.lang.String" itemvalue="fitz" />
<item index="42" class="java.lang.String" itemvalue="tenacity" />
<item index="43" class="java.lang.String" itemvalue="pymupdf" />
<item index="44" class="java.lang.String" itemvalue="streamlit" />
<item index="45" class="java.lang.String" itemvalue="loguru" />
<item index="46" class="java.lang.String" itemvalue="Requests" />
<item index="47" class="java.lang.String" itemvalue="requests_toolbelt" />
<item index="48" class="java.lang.String" itemvalue="pandas" />
<item index="49" class="java.lang.String" itemvalue="pdf2docx" />
<item index="50" class="java.lang.String" itemvalue="python_docx" />
<item index="51" class="java.lang.String" itemvalue="cn2an" />
<item index="52" class="java.lang.String" itemvalue="pdfminer.six" />
<item index="53" class="java.lang.String" itemvalue="qwen_agent" />
<item index="54" class="java.lang.String" itemvalue="aiohttp" />
<item index="55" class="java.lang.String" itemvalue="uvicorn" />
<item index="56" class="java.lang.String" itemvalue="openpyxl" />
<item index="57" class="java.lang.String" itemvalue="torchdata" />
<item index="58" class="java.lang.String" itemvalue="pre-commit" />
<item index="59" class="java.lang.String" itemvalue="flash-attn" />
<item index="60" class="java.lang.String" itemvalue="ray" />
<item index="61" class="java.lang.String" itemvalue="pybind11" />
<item index="62" class="java.lang.String" itemvalue="hydra-core" />
<item index="63" class="java.lang.String" itemvalue="liger-kernel" />
<item index="64" class="java.lang.String" itemvalue="peft" />
<item index="65" class="java.lang.String" itemvalue="wandb" />
<item index="66" class="java.lang.String" itemvalue="tensordict" />
<item index="67" class="java.lang.String" itemvalue="codetiming" />
<item index="68" class="java.lang.String" itemvalue="pylatexenc" />
<item index="69" class="java.lang.String" itemvalue="thefuzz" />
</list>
</value>
</option>
</inspection_tool>
</profile>
</component>
\ No newline at end of file
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Black">
<option name="sdkName" value="D:\Anaconda" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="D:\Anaconda" project-jdk-type="Python SDK" />
</project>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/auto-prompt.iml" filepath="$PROJECT_DIR$/.idea/auto-prompt.iml" />
</modules>
</component>
</project>
\ No newline at end of file
[
{
"original_text": "8.1.2.6 向甲方支付违约基数30%的违约金。",
"details": "句子主语为乙方",
"result": "不涉及",
"suggest": ""
},
{
"original_text": "8.1.2.2 要求乙方全额退还甲方已支付的预付款项。",
"details": "句子主语为乙方",
"result": "不涉及",
"suggest": ""
},
{
"original_text": "8.1.2.2 2.13.3.5 甲方有权取消该笔订单/合同,乙方需支付违约基数30%的违约金。",
"details": "",
"result": "不合格",
"suggest": ""
}
]
\ No newline at end of file
[
{
"original_text": "8.1.2.6 向甲方支付违约基数30%的违约金。",
"details": "句子主语为乙方",
"result": "不涉及",
"suggest": ""
},
{
"original_text": "8.1.2.2 要求乙方全额退还甲方已支付的预付款项。",
"details": "句子主语为乙方",
"result": "不涉及",
"suggest": ""
},
{
"original_text": "8.1.2.2 2.13.3.5 甲方有权取消该笔订单/合同,乙方需支付违约基数30%的违约金。",
"details": "",
"result": "不合格",
"suggest": ""
}
]
\ No newline at end of file
import json

# Sample contract-review records used to preview the pretty-printed JSON
# output format expected from the review model.
json_list = [
    {
        'original_text': '8.1.2.6 向甲方支付违约基数30%的违约金。',
        'details': '句子主语为乙方',
        'result': '不涉及',
        'suggest': '',
    },
    {
        'original_text': '8.1.2.2 要求乙方全额退还甲方已支付的预付款项。',
        'details': '句子主语为乙方',
        'result': '不涉及',
        'suggest': '',
    },
    {
        'original_text': '8.1.2.2 2.13.3.5 甲方有权取消该笔订单/合同,乙方需支付违约基数30%的违约金。',
        'details': '',
        'result': '不合格',
        'suggest': '',
    },
]

# Keep CJK characters unescaped (ensure_ascii=False) and indent for humans.
print(json.dumps(json_list, ensure_ascii=False, indent=4))
This source diff could not be displayed because it is too large. You can view the blob instead.
import os

# Shared helpers: eval_dataset, run_validation_revert, string_based_equality_fn,
# plus re-exported modules (np, tqdm, json, ...).
from util import *

# NOTE(review): set before textgrad creates its engine so the client targets
# the local Ollama (OpenAI-compatible) endpoint — verify ordering matters for
# the installed textgrad version.
os.environ['OLLAMA_BASE_URL'] = 'http://192.168.252.71:9002/v1'

import textgrad as tg
from textgrad.tasks import load_task, Dataset
from textgrad.autograd.string_based_ops import StringBasedFunction
class ContractReviewDataset(Dataset):
    """Tiny in-memory dataset of contract-review demo samples.

    Each sample pairs a contract sentence (``original_text``) with its
    expected review verdict (``result``).
    """

    def __init__(self):
        # Hard-coded demo records; 'result' is the ground-truth label.
        self.datas = [
            {
                'original_text': '8.1.2.6 向甲方支付违约基数30%的违约金。',
                'details': '句子主语为乙方',
                'result': '不涉及',
                'suggest': '',
            },
            {
                'original_text': '8.1.2.2 要求乙方全额退还甲方已支付的预付款项。',
                'details': '句子主语为乙方',
                'result': '不涉及',
                'suggest': '',
            },
            {
                'original_text': '8.1.2.2 2.13.3.5 甲方有权取消该笔订单/合同,乙方需支付违约基数30%的违约金。',
                'details': '',
                'result': '不合格',
                'suggest': '',
            },
        ]

    def __getitem__(self, index):
        # Yield an (input text, expected verdict) pair for one sample.
        sample = self.datas[index]
        return sample["original_text"], sample["result"]

    def __len__(self):
        return len(self.datas)
def load_contract_review_task():
    """Build (train, val, test, eval_fn) for the contract-review demo.

    The same in-memory dataset serves as all three splits; ``eval_fn`` wraps
    a string-equality check as a textgrad string-based function.
    """
    dataset = ContractReviewDataset()
    eval_fn = StringBasedFunction(
        string_based_equality_fn,
        function_purpose="The runtime of string-based function that checks if the prediction is correct.")
    return dataset, dataset, dataset, eval_fn
# init engine
# The same Ollama-served Qwen2 model drives both the forward pass and
# textgrad's backward (textual-gradient) pass.
llm_engine = tg.get_engine("ollama-Qwen2-72B-Instruct")
tg.set_backward_engine("ollama-Qwen2-72B-Instruct")

# init datasets
# Demo task: train/val/test all share the same 3-sample in-memory dataset.
train_set, val_set, _, eval_fn = load_contract_review_task()
train_loader = tg.tasks.DataLoader(train_set, batch_size=3, shuffle=True)

# prediction
# NOTE(review): the Chinese prompt text below is runtime data consumed by the
# model — it must not be edited or translated.
init_prompt = '''
你是乙方(供方、卖方)法律部门的合同审查助手
# 审查要点
1)提取涉及到句子的主体为甲方/买方/需方,句子内容为“合同变更/取消”、“退货”相关的句子,没有则返回不涉及
2)句子明确提及了“双方协商”,审查合格
3)句子没有明确提及“合同变更/取消”、“中途退货”所需要承担的责任,审查不合格
3)“合同变更/取消”相关的句子,没有提及违约金额,审查不合格
4)“退货”相关的句子,违约金的比例低于80%,审查不合格

# 不合格建议
1、提醒用户不合规的变更取消责任

# 审查约束
- 输出包括审查的原文、详情、结果、建议
- 审查结果为合格/不合格/不涉及,合格/不涉及的审查结果无需输出建议
- 审查原文严格提取关键、无省略、无篡改的原文内容
- 结果以JSON数组的格式返回,例如```json [{"original_text":"xx","details":"xx","result":"xx","suggest":"xx"}]```
依据审查要点,遵循约束,完成合同审查,提供审查建议,一步步仔细思考。
'''
# The system prompt is the trainable "parameter" textgrad optimizes.
system_prompt = tg.Variable(init_prompt,
                            requires_grad=True,
                            role_description="system prompt to guide the LLM's reasoning strategy for accurate responses")
model = tg.BlackboxLLM(llm_engine, system_prompt=system_prompt)
# TGD ("Textual Gradient Descent") rewrites the prompt from textual gradients.
optimizer = tg.TGD(parameters=list(model.parameters()))

# History of metrics and prompts; index 0 is the pre-training baseline.
results = {"train_acc": [], "prompt": [], "validation_acc": []}
results["train_acc"].append(eval_dataset(train_set, eval_fn, model))
results["validation_acc"].append(eval_dataset(val_set, eval_fn, model))
results["prompt"].append(system_prompt.get_value())

# Backpropagation: optimize the system prompt over the training batches.
for epoch in range(3):
    for steps, (batch_x, batch_y) in enumerate((pbar := tqdm(train_loader, position=0))):
        pbar.set_description(f"Training step {steps}. Epoch {epoch}")
        optimizer.zero_grad()
        losses = []
        for (x, y) in zip(batch_x, batch_y):
            x = tg.Variable(x, requires_grad=False, role_description="query to the language model")
            y = tg.Variable(y, requires_grad=False, role_description="correct answer for the query")
            response = model(x)
            # The "loss" is the textual evaluation of prediction vs. truth.
            eval_output_variable = eval_fn(inputs=dict(prediction=response, ground_truth_answer=y))
            losses.append(eval_output_variable)
        total_loss = tg.sum(losses)
        total_loss.backward()
        optimizer.step()
        # Roll the prompt back if validation accuracy regressed this step.
        run_validation_revert(system_prompt, results, model, eval_fn, val_set)
    # Per-epoch summary: current prompt and validation accuracy.
    print("sys prompt: ", system_prompt)
    results["validation_acc"].append(eval_dataset(val_set, eval_fn, model))
    results["prompt"].append(system_prompt.get_value())
import argparse
import os

# Shared helpers: eval_dataset, run_validation_revert, string_based_equality_fn,
# append_dict_to_jsonl, get_now, plus re-exported modules (np, tqdm, json, ...).
from util import *

import textgrad as tg
from textgrad.tasks import load_task, Dataset
from textgrad.autograd.string_based_ops import StringBasedFunction
class ContractReviewDataset(Dataset):
    """Contract-review dataset backed by a JSON file on disk.

    The file must contain a JSON array of objects; ``x_col`` and ``y_col``
    name the input-text and label fields of each record.
    """

    def __init__(self, path, x_col, y_col):
        # Load the whole JSON array into memory up front.
        with open(path, 'r', encoding='utf-8') as f:
            self.data_list = json.load(f)
        self.x_col = x_col
        self.y_col = y_col

    def __getitem__(self, index):
        # Yield the (input, label) pair of one record.
        record = self.data_list[index]
        return record[self.x_col], record[self.y_col]

    def __len__(self):
        return len(self.data_list)
if __name__ == "__main__":
    # ------------------------------------------------------------------
    # CLI arguments
    # ------------------------------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--train_path",
        type=str,
        default="dataset/train.json",
        help="train dataset path",
    )
    parser.add_argument(
        "--val_path",
        type=str,
        default="dataset/val.json",
        help="val dataset path",
    )
    parser.add_argument("--prompt_path", type=str, default="prompt/init_prompt.txt", help="prompts dir")
    parser.add_argument(
        "--output_dir", type=str, default="output_dir", help="Path to output dir"
    )
    parser.add_argument(
        "--x_col", type=str, default="original_text", help="dataset x column name"
    )
    parser.add_argument(
        "--y_col", type=str, default="result", help="dataset y column name"
    )
    parser.add_argument(
        "--batch_size", type=int, default=10, help="batch size"
    )
    parser.add_argument(
        "--epoch", type=int, default=3, help="epoch"
    )
    args = parser.parse_args()

    # create output dir
    output_dir = os.path.join(args.output_dir)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    # One timestamped .jsonl result file per run.
    output_path = os.path.join(output_dir, f'{get_now()}.jsonl')
    print(output_path)

    def load_contract_review_task():
        # Build train/val datasets from the JSON files; no test split (None).
        train_ds = ContractReviewDataset(args.train_path, args.x_col, args.y_col)
        val_ds = ContractReviewDataset(args.val_path, args.x_col, args.y_col)
        return train_ds, val_ds, None, StringBasedFunction(string_based_equality_fn,
                                                           function_purpose="The runtime of string-based function that checks if the prediction is correct.")

    def save_results(results):
        # Persist only the latest entry of each metric as one JSONL record.
        output_dict = {}
        for k, v in results.items():
            output_dict[k] = v[-1]
        append_dict_to_jsonl(output_dict, output_path)

    # init engine,model,optimizer
    # NOTE(review): must be set before the engine is created so the client
    # targets the local Ollama (OpenAI-compatible) endpoint.
    os.environ['OLLAMA_BASE_URL'] = 'http://192.168.252.71:9002/v1'
    llm_engine = tg.get_engine("ollama-Qwen2-72B-Instruct")
    tg.set_backward_engine("ollama-Qwen2-72B-Instruct")

    # init datasets prompt; init eval
    train_set, val_set, _, eval_fn = load_contract_review_task()
    train_loader = tg.tasks.DataLoader(train_set, batch_size=args.batch_size, shuffle=True)
    with open(args.prompt_path, 'r', encoding='utf-8') as f:
        init_prompt = f.read()
    # The system prompt is the trainable "parameter" textgrad optimizes.
    system_prompt = tg.Variable(init_prompt,
                                requires_grad=True,
                                role_description="system prompt to guide the LLM's reasoning strategy for accurate responses")
    model = tg.BlackboxLLM(llm_engine, system_prompt=system_prompt)
    optimizer = tg.TGD(parameters=list(model.parameters()))

    # Baseline metrics before any optimization step; checkpoint immediately.
    results = {"train_acc": [], "prompt": [], "validation_acc": []}
    results["train_acc"].append(np.mean(eval_dataset(train_set, eval_fn, model)))
    results["validation_acc"].append(np.mean(eval_dataset(val_set, eval_fn, model)))
    results["prompt"].append(system_prompt.get_value())
    save_results(results)

    # backward
    for epoch in range(args.epoch):
        for steps, (batch_x, batch_y) in enumerate((pbar := tqdm(train_loader, position=0))):
            pbar.set_description(f"Training step {steps}. Epoch {epoch}")
            optimizer.zero_grad()
            losses = []
            for (x, y) in zip(batch_x, batch_y):
                x = tg.Variable(x, requires_grad=False, role_description="query to the language model")
                y = tg.Variable(y, requires_grad=False, role_description="correct answer for the query")
                response = model(x)
                # The "loss" is the textual evaluation of prediction vs. truth.
                eval_output_variable = eval_fn(inputs=dict(prediction=response, ground_truth_answer=y))
                losses.append(eval_output_variable)
            total_loss = tg.sum(losses)
            total_loss.backward()
            optimizer.step()
            # Roll the prompt back if validation accuracy regressed this step.
            run_validation_revert(system_prompt, results, model, eval_fn, val_set)
        # print("sys prompt: ", system_prompt)
        # Per-epoch summary metrics and checkpoint.
        results["train_acc"].append(np.mean(eval_dataset(train_set, eval_fn, model)))
        results["validation_acc"].append(np.mean(eval_dataset(val_set, eval_fn, model)))
        results["prompt"].append(system_prompt.get_value())
        save_results(results)
{"train_acc": 0.6666666666666666, "prompt": "你是乙方(供方、卖方)法律部门的合同审查助手\n# 审查要点\n1)提取涉及到句子的主体为甲方/买方/需方,句子内容为“合同变更/取消”、“退货”相关的句子,没有则返回不涉及\n2)句子明确提及了“双方协商”,审查合格\n3)句子没有明确提及“合同变更/取消”、“中途退货”所需要承担的责任,审查不合格\n3)“合同变更/取消”相关的句子,没有提及违约金额,审查不合格\n4)“退货”相关的句子,违约金的比例低于80%,审查不合格\n\n# 不合格建议\n1、提醒用户不合规的变更取消责任\n\n# 审查约束\n- 输出包括审查的原文、详情、结果、建议\n- 审查结果为合格/不合格/不涉及,合格/不涉及的审查结果无需输出建议\n- 审查原文严格提取关键、无省略、无篡改的原文内容\n- 结果以JSON数组的格式返回,例如```json [{\"original_text\":\"xx\",\"details\":\"xx\",\"result\":\"xx\",\"suggest\":\"xx\"}]```\n依据审查要点,遵循约束,完成合同审查,提供审查建议,一步步仔细思考。", "validation_acc": 0.6666666666666666}
你是乙方(供方、卖方)法律部门的合同审查助手
# 审查要点
1)提取涉及到句子的主体为甲方/买方/需方,句子内容为“合同变更/取消”、“退货”相关的句子,没有则返回不涉及
2)句子明确提及了“双方协商”,审查合格
3)句子没有明确提及“合同变更/取消”、“中途退货”所需要承担的责任,审查不合格
4)“合同变更/取消”相关的句子,没有提及违约金额,审查不合格
5)“退货”相关的句子,违约金的比例低于80%,审查不合格
# 不合格建议
1、提醒用户不合规的变更取消责任
# 审查约束
- 输出包括审查的原文、详情、结果、建议
- 审查结果为合格/不合格/不涉及,合格/不涉及的审查结果无需输出建议
- 审查原文严格提取关键、无省略、无篡改的原文内容
- 结果以JSON数组的格式返回,例如```json [{"original_text":"xx","details":"xx","result":"xx","suggest":"xx"}]```
依据审查要点,遵循约束,完成合同审查,提供审查建议,一步步仔细思考。
\ No newline at end of file
# Standard library
import re
import json
# FIX: `import concurrent` alone does not import the `futures` submodule;
# `concurrent.futures.ThreadPoolExecutor` would raise AttributeError unless
# some other module happened to import it first.
import concurrent.futures
from datetime import datetime

# Third-party
import numpy as np
from tqdm import tqdm

import textgrad as tg
def eval_sample(item, eval_fn, model):
    """Evaluate a single (query, answer) pair against the model.

    Args:
        item: tuple of (query text, ground-truth answer text).
        eval_fn: textgrad function judging prediction vs. ground truth;
            returns a Variable whose value is '正确' when correct.
        model: callable textgrad model producing a response Variable.

    Returns:
        1 if the response is judged correct, else 0. Evaluation failures
        count as incorrect instead of aborting the whole run.
    """
    x, y = item
    x = tg.Variable(x, requires_grad=False, role_description="query to the language model")
    y = tg.Variable(y, requires_grad=False, role_description="correct answer for the query")
    response = model(x)
    try:
        eval_output_variable = eval_fn(inputs=dict(prediction=response, ground_truth_answer=y))
        return 1 if eval_output_variable.value == '正确' else 0
    except Exception:
        # FIX: was a bare `except:` which also swallowed SystemExit and
        # KeyboardInterrupt; any evaluation error still scores as incorrect.
        return 0
def run_validation_revert(system_prompt: tg.Variable, results, model, eval_fn, val_set):
    """Validate the current prompt and roll it back if it regressed.

    Compares validation accuracy against the last recorded value; when the
    new prompt is worse, restores the previous prompt text. Always appends
    the (possibly reverted) accuracy to ``results["validation_acc"]``.
    """
    current_acc = np.mean(eval_dataset(val_set, eval_fn, model))
    previous_acc = np.mean(results["validation_acc"][-1])
    print("val_performance: ", current_acc)
    print("previous_performance: ", previous_acc)
    last_prompt = results["prompt"][-1]
    if current_acc < previous_acc:
        # New prompt is worse on validation: restore the previous one and
        # keep the previous accuracy as the running best.
        system_prompt.set_value(last_prompt)
        current_acc = previous_acc
    results["validation_acc"].append(current_acc)
def eval_dataset(test_set, eval_fn, model, max_samples: int = None):
    """Score up to ``max_samples`` items of ``test_set`` concurrently.

    Each item is evaluated via :func:`eval_sample` on a small thread pool.
    Returns a list of 0/1 correctness flags in completion order (not
    dataset order).
    """
    limit = len(test_set) if max_samples is None else max_samples
    accuracy_list = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
        # Submit at most `limit` evaluation jobs.
        pending = []
        for sample in test_set:
            pending.append(executor.submit(eval_sample, sample, eval_fn, model))
            if len(pending) >= limit:
                break
        # Collect results as they finish, showing a running accuracy.
        progress = tqdm(concurrent.futures.as_completed(pending), total=len(pending), position=0)
        for future in progress:
            accuracy_list.append(future.result())
            progress.set_description(f"Accuracy: {np.mean(accuracy_list)}")
    return accuracy_list
def extract_json(json_str):
    """Extract JSON payload(s) embedded in an LLM response.

    Scans for fenced blocks of the form ```json ... ``` and parses each one.
    List payloads are flattened into the result; single objects are appended.
    Generalization: when no fenced block is present, the whole string is
    tried as raw JSON (models do not always emit code fences); on failure
    the result is an empty list, matching the previous behavior.

    Returns:
        A (possibly empty) list of parsed JSON values.
    """
    json_pattern = r'```json([\s\S]*?)```'
    matches = re.findall(json_pattern, json_str, re.DOTALL)
    if not matches:
        # Fallback: accept un-fenced raw JSON; stay silent on failure so
        # plain-text responses do not spam the log.
        try:
            json_obj = json.loads(json_str.strip())
        except json.JSONDecodeError:
            return []
        return json_obj if isinstance(json_obj, list) else [json_obj]
    json_list = []
    for match in matches:
        # Strip surrounding whitespace before parsing.
        clean_json_str = match.strip()
        try:
            json_obj = json.loads(clean_json_str)
            if isinstance(json_obj, list):
                json_list += json_obj
            else:
                json_list.append(json_obj)
        except json.JSONDecodeError as e:
            # One bad block should not discard the others.
            print(f"发现了一个无法解析的JSON字符串: {clean_json_str} {e}")
    return json_list
def string_based_equality_fn(prediction: tg.Variable, ground_truth_answer: tg.Variable):
    """Judge a prediction by comparing its first extracted 'result' field.

    Returns the string '正确' (correct) when the first JSON object parsed
    from the prediction has a 'result' equal to the ground truth, otherwise
    '错误' (wrong) — including when no JSON could be extracted at all.
    """
    parsed = extract_json(prediction.value)
    if not parsed:
        return '错误'
    return '正确' if parsed[0]['result'] == ground_truth_answer.value else '错误'
def append_dict_to_jsonl(dictionary, file_path):
    """Append one dict as a single JSON line (UTF-8, CJK unescaped).

    Opens ``file_path`` in append mode so repeated calls build a JSONL log;
    flushes immediately so partial runs still leave readable output.
    """
    line = json.dumps(dictionary, ensure_ascii=False)
    with open(file_path, 'a', encoding='utf-8') as out:
        out.write(line)
        out.write('\n')
        out.flush()
def get_now():
    """Return the current local time as 'YYYYMMDD-HHMMSS' (filename-safe)."""
    return datetime.now().strftime("%Y%m%d-%H%M%S")
if __name__ == '__main__':
    # Ad-hoc smoke check when running this module directly.
    # (Removed commented-out append_dict_to_jsonl experiments and a
    # redundant trailing `pass`.)
    print(get_now())
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment