import os
from util import *

# Base URL of the OpenAI-compatible endpoint used by the "ollama-..." engine
# requested further down. The assignment runs before `import textgrad`;
# presumably the library reads the variable at import time or when the engine
# is created — TODO confirm where it is consumed.
# NOTE(review): this unconditionally overwrites any OLLAMA_BASE_URL already
# present in the environment and hard-codes a private-network address.
os.environ['OLLAMA_BASE_URL'] = 'http://192.168.252.71:9002/v1'

import textgrad as tg
from textgrad.tasks import load_task, Dataset
from textgrad.autograd.string_based_ops import StringBasedFunction


class ContractReviewDataset(Dataset):
    """Tiny in-memory dataset of contract clauses with their review verdicts.

    Every record is a dict with the keys ``original_text``, ``details``,
    ``result`` and ``suggest``. Indexing exposes only the clause text and the
    expected review result; the other two fields are stored but not returned.
    """

    def __init__(self):
        """Populate ``self.datas`` with the hard-coded sample records."""
        field_names = ('original_text', 'details', 'result', 'suggest')
        # One tuple per record, values listed in the same order as field_names.
        raw_records = (
            ('8.1.2.6 向甲方支付违约基数30%的违约金。', '句子主语为乙方', '不涉及', ''),
            ('8.1.2.2 要求乙方全额退还甲方已支付的预付款项。', '句子主语为乙方', '不涉及', ''),
            ('8.1.2.2 2.13.3.5 甲方有权取消该笔订单/合同，乙方需支付违约基数30%的违约金。', '', '不合格', ''),
        )
        self.datas = [dict(zip(field_names, values)) for values in raw_records]

    def __getitem__(self, index):
        """Return ``(original_text, result)`` for the record at ``index``."""
        record = self.datas[index]
        clause_text, expected_result = record["original_text"], record["result"]
        return clause_text, expected_result

    def __len__(self):
        """Return the number of stored records."""
        return len(self.datas)


def load_contract_review_task():
    """Build the contract-review task: dataset splits plus an evaluation function.

    Returns:
        A 4-tuple ``(train_set, val_set, test_set, eval_fn)``. All three splits
        are the *same* ``ContractReviewDataset`` instance, and ``eval_fn`` is a
        ``StringBasedFunction`` wrapping ``string_based_equality_fn``.
    """
    dataset = ContractReviewDataset()
    # string_based_equality_fn is not defined in the visible part of this file;
    # presumably it is provided by `from util import *` — verify.
    equality_eval = StringBasedFunction(
        string_based_equality_fn,
        function_purpose="The runtime of string-based function that checks if the prediction is correct.",
    )
    return dataset, dataset, dataset, equality_eval


# init engine
# The same engine name is used both for forward calls (llm_engine) and as the
# engine textgrad uses for the backward pass.
llm_engine = tg.get_engine("ollama-Qwen2-72B-Instruct")
tg.set_backward_engine("ollama-Qwen2-72B-Instruct")

# init datasets
# load_contract_review_task() returns one dataset object for every split, so
# train_set and val_set are the same data here; the third (test) split is
# discarded.
train_set, val_set, _, eval_fn = load_contract_review_task()
# batch_size=3 equals the number of hard-coded records in the dataset, so each
# pass over the loader presumably yields a single shuffled batch — confirm
# against the DataLoader implementation.
train_loader = tg.tasks.DataLoader(train_set, batch_size=3, shuffle=True)

# prediction
# Initial system prompt. The text is sent to the model at runtime and is kept
# verbatim (in Chinese). In summary it: casts the model as a contract-review
# assistant for Party B's (supplier/seller) legal department; lists review
# criteria for sentences about contract change/cancellation and returns; and
# asks for the answer as a JSON array of objects with the keys original_text,
# details, result and suggest.
# NOTE(review): the criteria list contains two consecutive items numbered
# "3）" — confirm whether the numbering should be 1-5.
# NOTE(review): the prompt requests a JSON array, while ContractReviewDataset
# yields a bare result label as ground truth; confirm that
# string_based_equality_fn extracts the `result` field before comparing.
init_prompt = '''
你是乙方（供方、卖方）法律部门的合同审查助手
# 审查要点
1）提取涉及到句子的主体为甲方/买方/需方，句子内容为“合同变更/取消”、“退货”相关的句子，没有则返回不涉及
2）句子明确提及了“双方协商”，审查合格
3）句子没有明确提及“合同变更/取消”、“中途退货”所需要承担的责任，审查不合格
3）“合同变更/取消”相关的句子，没有提及违约金额，审查不合格
4）“退货”相关的句子，违约金的比例低于80%，审查不合格

# 不合格建议
1、提醒用户不合规的变更取消责任

# 审查约束
- 输出包括审查的原文、详情、结果、建议
- 审查结果为合格/不合格/不涉及，合格/不涉及的审查结果无需输出建议
- 审查原文严格提取关键、无省略、无篡改的原文内容
- 结果以JSON数组的格式返回,例如```json [{"original_text":"xx","details":"xx","result":"xx","suggest":"xx"}]```
依据审查要点，遵循约束，完成合同审查，提供审查建议，一步步仔细思考。
'''
# requires_grad=True marks the prompt as the variable to be optimised; the
# role_description is the text describing this variable's role to textgrad.
system_prompt = tg.Variable(init_prompt,
                            requires_grad=True,
                            role_description="system prompt to guide the LLM's reasoning strategy for accurate responses")

# Wrap the engine together with the trainable system prompt in a callable model.
model = tg.BlackboxLLM(llm_engine, system_prompt=system_prompt)
# TGD optimizer over whatever parameters the model exposes (presumably just
# system_prompt — confirm against BlackboxLLM.parameters()).
optimizer = tg.TGD(parameters=list(model.parameters()))
# Run history: each list receives one entry per evaluation / recorded prompt.
results = {"train_acc": [], "prompt": [], "validation_acc": []}
# Baseline measurements taken with the un-optimised prompt.
# eval_dataset is not defined in the visible part of this file; presumably it
# comes from `from util import *` — its return shape is not shown here.
results["train_acc"].append(eval_dataset(train_set, eval_fn, model))
results["validation_acc"].append(eval_dataset(val_set, eval_fn, model))
results["prompt"].append(system_prompt.get_value())

# Backpropagation: optimise the system prompt for 3 epochs over train_loader.
for epoch in range(3):
    # tqdm is not imported explicitly in this file; presumably it is provided
    # by `from util import *`.
    for steps, (batch_x, batch_y) in enumerate((pbar := tqdm(train_loader, position=0))):
        pbar.set_description(f"Training step {steps}. Epoch {epoch}")
        optimizer.zero_grad()
        losses = []
        for (x, y) in zip(batch_x, batch_y):
            # Wrap the raw clause text and its label as non-trainable variables.
            x = tg.Variable(x, requires_grad=False, role_description="query to the language model")
            y = tg.Variable(y, requires_grad=False, role_description="correct answer for the query")
            response = model(x)
            # Per-sample evaluation of the prediction against the label.
            eval_output_variable = eval_fn(inputs=dict(prediction=response, ground_truth_answer=y))
            losses.append(eval_output_variable)
        # Combine the per-sample evaluations, backpropagate through the
        # combined variable, then let the optimizer update the prompt.
        total_loss = tg.sum(losses)
        total_loss.backward()
        optimizer.step()
        # run_validation_revert is defined outside the visible part of this
        # file (presumably in util); judging by its name and arguments it
        # validates the new prompt and may restore the previous one — confirm.
        run_validation_revert(system_prompt, results, model, eval_fn, val_set)
        print("sys prompt: ", system_prompt)
        # NOTE(review): if run_validation_revert already appends to
        # results["validation_acc"], the next line records a second entry per
        # step and re-runs a full validation pass — verify against util.
        # Also note that results["train_acc"] is never updated inside the loop.
        results["validation_acc"].append(eval_dataset(val_set, eval_fn, model))
        results["prompt"].append(system_prompt.get_value())
