Commit bd83f8bf by ccran

fix: handle exceptions raised during forward/backward passes

parent a2be6637
.idea
__pycache__
logs
output_dir
\ No newline at end of file
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -7,6 +7,12 @@ from textgrad.tasks import load_task, Dataset
from textgrad.autograd.string_based_ops import StringBasedFunction
def single_column_fc(prediction: tg.Variable, ground_truth_answer: tg.Variable):
    """String-based equality check used as the TextGrad loss function.

    Parses the model's raw output (``prediction.value``) with ``extract_json``
    and compares the ``result`` field of the first extracted object against
    ``ground_truth_answer.value``.

    Args:
        prediction: model response variable whose ``.value`` holds raw text
            expected to contain a JSON array (untrusted model output).
        ground_truth_answer: variable whose ``.value`` is the expected result
            string.

    Returns:
        '正确' when the first object's 'result' equals the ground truth,
        otherwise '错误'.
    """
    json_res = extract_json(prediction.value)
    # Use .get so a malformed object that lacks a 'result' key is simply
    # scored as incorrect instead of raising KeyError mid-epoch — model
    # output is untrusted JSON and frequently drops fields.
    if json_res and json_res[0].get('result') == ground_truth_answer.value:
        return '正确'
    return '错误'
class ContractReviewDataset(Dataset):
def __init__(self, path, x_col, y_col):
with open(path, 'r', encoding='utf-8') as f:
......@@ -27,16 +33,16 @@ if __name__ == "__main__":
parser.add_argument(
"--train_path",
type=str,
default="dataset/train.json",
default="dataset/变更取消责任审查.json",
help="train dataset path",
)
parser.add_argument(
"--val_path",
type=str,
default="dataset/val.json",
default="dataset/变更取消责任审查.json",
help="val dataset path",
)
parser.add_argument("--prompt_path", type=str, default="prompt/init_prompt.txt", help="prompts dir")
parser.add_argument("--prompt_path", type=str, default="prompt/变更取消责任审查.txt", help="prompts dir")
parser.add_argument(
"--output_dir", type=str, default="output_dir", help="Path to output dir"
)
......@@ -44,7 +50,7 @@ if __name__ == "__main__":
"--x_col", type=str, default="original_text", help="dataset x column name"
)
parser.add_argument(
"--y_col", type=str, default="result", help="dataset y column name"
"--y_col", type=str, default="ground_truth", help="dataset y column name"
)
parser.add_argument(
"--batch_size", type=int, default=10, help="batch size"
......@@ -64,8 +70,7 @@ if __name__ == "__main__":
def load_contract_review_task():
train_ds = ContractReviewDataset(args.train_path, args.x_col, args.y_col)
val_ds = ContractReviewDataset(args.val_path, args.x_col, args.y_col)
return train_ds, val_ds, None, StringBasedFunction(string_based_equality_fn,
function_purpose="The runtime of string-based function that checks if the prediction is correct.")
return train_ds, val_ds, None, StringBasedFunction(single_column_fc, function_purpose="比较审查结果是否正确")
def save_results(results):
......@@ -76,7 +81,7 @@ if __name__ == "__main__":
# init engine,model,optimizer
os.environ['OLLAMA_BASE_URL'] = 'http://192.168.252.71:9002/v1'
os.environ['OLLAMA_BASE_URL'] = 'http://218.77.58.8:8088/qwen/v1/'
llm_engine = tg.get_engine("ollama-Qwen2-72B-Instruct")
tg.set_backward_engine("ollama-Qwen2-72B-Instruct")
# init datasets prompt; init eval
......@@ -86,7 +91,7 @@ if __name__ == "__main__":
init_prompt = f.read()
system_prompt = tg.Variable(init_prompt,
requires_grad=True,
role_description="system prompt to guide the LLM's reasoning strategy for accurate responses")
role_description="用于合同审查的系统提示词")
model = tg.BlackboxLLM(llm_engine, system_prompt=system_prompt)
optimizer = tg.TGD(parameters=list(model.parameters()))
results = {"train_acc": [], "prompt": [], "validation_acc": []}
......@@ -101,14 +106,17 @@ if __name__ == "__main__":
optimizer.zero_grad()
losses = []
for (x, y) in zip(batch_x, batch_y):
x = tg.Variable(x, requires_grad=False, role_description="query to the language model")
y = tg.Variable(y, requires_grad=False, role_description="correct answer for the query")
x = tg.Variable(x, requires_grad=False, role_description="输入的合同")
y = tg.Variable(y, requires_grad=False, role_description="审查得到的结果")
response = model(x)
eval_output_variable = eval_fn(inputs=dict(prediction=response, ground_truth_answer=y))
losses.append(eval_output_variable)
try:
total_loss = tg.sum(losses)
total_loss.backward()
optimizer.step()
except Exception as e:
print('error in backward', e)
run_validation_revert(system_prompt, results, model, eval_fn, val_set)
# print("sys prompt: ", system_prompt)
results["train_acc"].append(np.mean(eval_dataset(train_set, eval_fn, model)))
......
{"train_acc": 0.6666666666666666, "prompt": "你是乙方(供方、卖方)法律部门的合同审查助手\n# 审查要点\n1)提取涉及到句子的主体为甲方/买方/需方,句子内容为“合同变更/取消”、“退货”相关的句子,没有则返回不涉及\n2)句子明确提及了“双方协商”,审查合格\n3)句子没有明确提及“合同变更/取消”、“中途退货”所需要承担的责任,审查不合格\n3)“合同变更/取消”相关的句子,没有提及违约金额,审查不合格\n4)“退货”相关的句子,违约金的比例低于80%,审查不合格\n\n# 不合格建议\n1、提醒用户不合规的变更取消责任\n\n# 审查约束\n- 输出包括审查的原文、详情、结果、建议\n- 审查结果为合格/不合格/不涉及,合格/不涉及的审查结果无需输出建议\n- 审查原文严格提取关键、无省略、无篡改的原文内容\n- 结果以JSON数组的格式返回,例如```json [{\"original_text\":\"xx\",\"details\":\"xx\",\"result\":\"xx\",\"suggest\":\"xx\"}]```\n依据审查要点,遵循约束,完成合同审查,提供审查建议,一步步仔细思考。", "validation_acc": 0.6666666666666666}
{"train_acc": 0.6666666666666666, "prompt": "你是乙方(供方、卖方)法律部门的合同审查助手\n\n# **预验证检查** \n- **第一步**:判断输入句子是否**明确涉及**\"合同变更/取消\"或\"退货\"。 \n - 若**不涉及**,输出 `{\"result\": \"不涉及\", \"suggest\": \"不涉及\"}` **且不生成其他字段**。 \n - 若**涉及**,继续后续审查步骤。 \n- **法律术语精度**: \n - \"双方协商\"仅在涉及合同变更/取消或退货条款时作为合格条件,其他场景无关。 \n\n# **修订审查要点** \n1. **适用性检查**: \n - 若句子不涉及\"合同变更/取消\"或\"退货\",直接返回 `{\"result\": \"不涉及\", \"suggest\": \"不涉及\"}`。 \n2. **主体与内容匹配**: \n - 提取甲方/买方/需方为行为主体,且内容涉及\"合同变更/取消\"或\"退货\"的句子。 \n3. **合格条件**: \n - 若句子明确包含\"双方协商\",标记为**合格**(`\"suggest\": \"\"`)。 \n4. **不合格条件**: \n - 若未提及\"合同变更/取消\"或\"中途退货\"所需承担责任,标记为**不合格**。 \n - 若\"合同变更/取消\"相关句子未约定违约金额,标记为**不合格**。 \n - 若\"退货\"相关句子违约金比例低于80%,标记为**不合格**。 \n\n# **不合格建议** \n1. 提醒用户补充\"双方协商\"条款(仅限涉及变更/取消或退货场景)。 \n2. 若审查结果为\"不涉及\",需在`suggest`字段中明确标注\"不涉及\"。 \n\n# **修订审查约束** \n- 输出必须包含 `original_text`、`details`、`result`、`suggest` 四个字段。 \n- `result` 为 \"合格\"、\"不合格\"、\"不涉及\": \n - \"合格\" 的 `suggest` 字段为空字符串(`\"\"`)。 \n - \"不涉及\" 的 `suggest` 字段必须为 `\"不涉及\"`。 \n - \"不合格\" 的 `suggest` 字段需提供具体修改建议。 \n- `original_text` 严格提取原文,无省略、无篡改。 \n- 输出格式示例: \n ```json \n [ \n { \n \"original_text\": \"xx\", \n \"details\": \"xx\", \n \"result\": \"xx\", \n \"suggest\": \"xx\" \n } \n ] \n ``` \n\n# **语义一致性要求** \n- `suggest` 字段必须与 `result` 语义对齐: \n - \"不涉及\" 的 `suggest` 必须为 `\"不涉及\"`。 \n - \"合格\" 的 `suggest` 可为空字符串。 \n- 禁止使用空字符串表示非\"合格\"结果。 \n\n# **测试用例** \n- 输入: \"合同条款与变更/取消或退货无关\" \n - 预期输出: \n ```json \n [ \n { \n \"original_text\": \"合同条款与变更/取消或退货无关\", \n \"details\": \"该条款不涉及合同变更/取消或退货条款\", \n \"result\": \"不涉及\", \n \"suggest\": \"不涉及\" \n } \n ] \n ``` \n\n# **错误模式识别** \n- 若 `result` 为 \"不涉及\",任何 `suggest` 字段包含新条款建议(如\"双方协商\"、违约金调整)均视为错误。 \n- 严格禁止在 \"不涉及\" 场景中引入法律复杂性。 \n\n依据审查要点,遵循约束,完成合同审查,提供审查建议,**先验证适用性,再逐步推理**。", "validation_acc": 0.6666666666666666}
......@@ -15,11 +15,12 @@ def eval_sample(item, eval_fn, model):
x, y = item
x = tg.Variable(x, requires_grad=False, role_description="query to the language model")
y = tg.Variable(y, requires_grad=False, role_description="correct answer for the query")
response = model(x)
try:
response = model(x)
eval_output_variable = eval_fn(inputs=dict(prediction=response, ground_truth_answer=y))
return 1 if eval_output_variable.value == '正确' else 0
except:
except Exception as e:
print('error in eval_sample: ', e)
return 0
......@@ -75,12 +76,6 @@ def extract_json(json_str):
return json_list
def string_based_equality_fn(prediction: tg.Variable, ground_truth_answer: tg.Variable):
json_res = extract_json(prediction.value)
check_res = json_res[0]['result'] == ground_truth_answer.value if json_res else False
return '正确' if check_res else '错误'
def append_dict_to_jsonl(dictionary, file_path):
    """Append one dict as a single JSON line to a JSON-Lines file.

    Args:
        dictionary: JSON-serializable dict to record.
        file_path: path of the .jsonl file; created on first append.

    Raises:
        TypeError: if ``dictionary`` is not JSON-serializable.
    """
    with open(file_path, 'a', encoding='utf-8') as f:
        # ensure_ascii=False keeps Chinese text readable in the log file.
        json.dump(dictionary, f, ensure_ascii=False)
        # Terminate the record: without this newline, successive appends
        # concatenate objects on one line and the file is not valid JSONL.
        f.write('\n')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment