import argparse
import json
import os
from util import *

import textgrad as tg
from textgrad.tasks import load_task, Dataset
from textgrad.autograd.string_based_ops import StringBasedFunction

# Parsed command-line arguments; stays None until the __main__ block assigns it.
args = None


def single_column_fc(prediction: tg.Variable, ground_truth_answer: tg.Variable):
    """Compare a predicted review result with the ground truth and describe the outcome.

    Args:
        prediction: Variable whose ``value`` is the raw model output. It is
            passed to ``extract_json``; the ``result`` field of the first
            extracted item is taken as the predicted result.
        ground_truth_answer: Variable whose ``value`` is a JSON object string
            containing the ``result`` and ``detail`` keys.

    Returns:
        A feedback string (in Chinese) stating whether the prediction matched,
        followed by the ground-truth ``result`` and ``detail``.

    Raises:
        json.JSONDecodeError: If the ground-truth value is not valid JSON.
        KeyError: If the ground truth lacks ``result`` or ``detail``.
    """
    json_res = extract_json(prediction.value)
    ground_truth = json.loads(ground_truth_answer.value)
    y = ground_truth['result']
    advice = ground_truth['detail']
    try:
        check_res = bool(json_res) and json_res[0]['result'] == y
    except (KeyError, IndexError, TypeError):
        # The model output parsed into an unexpected shape (e.g. a dict, a
        # list of non-dicts, or an item without 'result'). Count it as a wrong
        # prediction rather than aborting the whole evaluation/training run.
        check_res = False
    res_str = '正确' if check_res else '错误'
    return f'你的判断结果为[{res_str}],真实的审查结果result为[{y}],真实的审查详情detail为[{advice}]'


class ContractReviewDataset(Dataset):
    """Dataset backed by a JSON file that holds an indexable collection of records.

    Each item is an ``(x, y)`` pair: ``x`` is the record's input column and
    ``y`` is a JSON string containing the selected target columns.
    """

    def __init__(self, path, x_col, y_col):
        """Load the records and remember which columns to expose.

        Args:
            path: Path of a UTF-8 encoded JSON file. Its parsed content is
                indexed by position in ``__getitem__``.
            x_col: Key of the input column in each record.
            y_col: Comma-separated keys of the target columns, e.g.
                ``"result,detail"``. Whitespace around names and empty
                entries (such as a trailing comma) are ignored.

        Raises:
            ValueError: If ``y_col`` contains no column name.
        """
        with open(path, 'r', encoding='utf-8') as f:
            self.data_list = json.load(f)
        self.x_col = x_col
        # Strip each name so "result, detail" does not yield the key " detail",
        # which would only surface later as a KeyError in __getitem__.
        self.y_cols = [name.strip() for name in y_col.split(',') if name.strip()]
        if not self.y_cols:
            raise ValueError(f'y_col must name at least one column, got {y_col!r}')

    def __getitem__(self, index):
        """Return ``(x, y_json)`` for the record at ``index``.

        ``y_json`` is a JSON object string mapping each target column to the
        record's value, serialized without ASCII escaping.
        """
        row = self.data_list[index]
        all_y = {col: row[col] for col in self.y_cols}
        return row[self.x_col], json.dumps(all_y, ensure_ascii=False)

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.data_list)


if __name__ == "__main__":
    # Script entry point: optimize a system prompt with TextGrad against a
    # review dataset, appending metrics and the current prompt to a JSONL file
    # after the initial evaluation and after every training step.
    #
    # NOTE(review): get_now, append_dict_to_jsonl, eval_dataset,
    # run_validation_revert, np and tqdm are not defined in this file;
    # presumably they are provided by `from util import *` -- confirm.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--train_path",
        type=str,
        default="dataset/发票审查.json",
        help="train dataset path",
    )
    parser.add_argument(
        "--val_path",
        type=str,
        default="dataset/发票审查.json",
        help="val dataset path",
    )
    parser.add_argument("--prompt_path", type=str, default="prompt/发票审查.txt", help="prompts dir")
    parser.add_argument(
        "--output_dir", type=str, default="output_dir", help="Path to output dir"
    )
    parser.add_argument(
        "--x_col", type=str, default="original_text", help="dataset x column name"
    )
    parser.add_argument(
        "--y_cols", type=str, default="result,detail", help="dataset y column name"
    )
    parser.add_argument(
        "--batch_size", type=int, default=10, help="batch size"
    )
    parser.add_argument(
        "--epoch", type=int, default=10, help="epoch"
    )
    args = parser.parse_args()
    # Create the output directory. exist_ok=True avoids the race between an
    # existence check and the creation when the directory appears in between.
    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    # One result file per run, named from get_now().
    output_path = os.path.join(output_dir, f'{get_now()}.jsonl')
    print(output_path)


    def load_contract_review_task():
        """Build the train/validation datasets and the evaluation function.

        Returns:
            A 4-tuple ``(train_ds, val_ds, None, eval_fn)``; the third slot is
            always None and ``eval_fn`` wraps ``single_column_fc`` in a
            ``StringBasedFunction``.
        """
        train_ds = ContractReviewDataset(args.train_path, args.x_col, args.y_cols)
        val_ds = ContractReviewDataset(args.val_path, args.x_col, args.y_cols)
        return train_ds, val_ds, None, StringBasedFunction(single_column_fc, function_purpose="比较审查结果是否正确")


    def save_results(results):
        """Append the most recent entry of every results series to the output file.

        Args:
            results: Mapping from series name to a non-empty list; only the
                last element of each list is written.
        """
        output_dict = {}
        for k, v in results.items():
            output_dict[k] = v[-1]
        append_dict_to_jsonl(output_dict, output_path)


    # Initialize the engine, model and optimizer.
    # This assignment overrides any OLLAMA_BASE_URL already set in the environment.
    # Alternative endpoint kept for reference:
    # os.environ['OLLAMA_BASE_URL'] = 'http://218.77.58.8:8088/qwen/v1/'
    os.environ['OLLAMA_BASE_URL'] = 'http://192.168.252.71:9002/v1/'
    llm_engine = tg.get_engine("ollama-Qwen2-72B-Instruct")
    tg.set_backward_engine("ollama-Qwen2-72B-Instruct")
    # Load the datasets, the evaluation function and the initial prompt.
    train_set, val_set, _, eval_fn = load_contract_review_task()
    train_loader = tg.tasks.DataLoader(train_set, batch_size=args.batch_size, shuffle=True)
    with open(args.prompt_path, 'r', encoding='utf-8') as f:
        init_prompt = f.read()
    # The system prompt is the only variable created with requires_grad=True.
    system_prompt = tg.Variable(init_prompt,
                                requires_grad=True,
                                role_description="用于合同审查的系统提示词")
    model = tg.BlackboxLLM(llm_engine, system_prompt=system_prompt)
    optimizer = tg.TGD(parameters=list(model.parameters()))
    # Record the baseline (step 0) metrics and prompt before any update.
    results = {"train_acc": [], "prompt": [], "validation_acc": []}
    results["train_acc"].append(np.mean(eval_dataset(train_set, eval_fn, model)))
    results["validation_acc"].append(np.mean(eval_dataset(val_set, eval_fn, model)))
    results["prompt"].append(system_prompt.get_value())
    save_results(results)
    # Training loop: one optimizer step per batch.
    for epoch in range(args.epoch):
        for steps, (batch_x, batch_y) in enumerate((pbar := tqdm(train_loader, position=0))):
            pbar.set_description(f"Training step {steps}. Epoch {epoch}")
            optimizer.zero_grad()
            losses = []
            # Forward pass: one textual evaluation ("loss") per sample.
            for (x, y) in zip(batch_x, batch_y):
                x = tg.Variable(x, requires_grad=False, role_description="输入的合同")
                y = tg.Variable(y, requires_grad=False, role_description="审查得到的结果")
                response = model(x)
                eval_output_variable = eval_fn(inputs=dict(prediction=response, ground_truth_answer=y))
                losses.append(eval_output_variable)
            # Best effort: a failure in backward/step is reported and the run
            # continues with the next batch instead of aborting.
            try:
                total_loss = tg.sum(losses)
                total_loss.backward()
                optimizer.step()
            except Exception as e:
                print('error in backward', e)
            # NOTE(review): presumably re-validates the updated prompt and
            # reverts it when validation gets worse -- confirm in util.
            run_validation_revert(system_prompt, results, model, eval_fn, val_set)
            # Re-evaluate on both splits and persist this step's metrics and prompt.
            results["train_acc"].append(np.mean(eval_dataset(train_set, eval_fn, model)))
            results["validation_acc"].append(np.mean(eval_dataset(val_set, eval_fn, model)))
            results["prompt"].append(system_prompt.get_value())
            save_results(results)
