import argparse
import os
from util import *

import textgrad as tg
from textgrad.tasks import load_task, Dataset
from textgrad.autograd.string_based_ops import StringBasedFunction


class ContractReviewDataset(Dataset):
    """Contract-review examples loaded eagerly from a JSON file.

    Each item is exposed as an ``(x, y)`` pair read from the ``x_col`` and
    ``y_col`` keys of the record at that position. The file is presumably a
    JSON array of objects — TODO confirm against the dataset files.
    """

    def __init__(self, path, x_col, y_col):
        # The whole file is parsed up front; items are then served from memory.
        with open(path, encoding='utf-8') as fp:
            records = json.load(fp)
        self.data_list = records
        self.x_col, self.y_col = x_col, y_col

    def __getitem__(self, index):
        """Return the ``(x, y)`` pair for the record at ``index``."""
        record = self.data_list[index]
        x_value = record[self.x_col]
        y_value = record[self.y_col]
        return x_value, y_value

    def __len__(self):
        """Return how many records were loaded."""
        return len(self.data_list)


if __name__ == "__main__":
    # Optimize a contract-review system prompt with TextGrad:
    # 1. Score the initial prompt.
    # 2. For each training batch, backpropagate the string-equality loss
    #    and let TGD rewrite the prompt.
    # 3. After every step, log train/val accuracy and the current prompt
    #    to a timestamped JSONL file.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--train_path",
        type=str,
        default="dataset/train.json",
        help="train dataset path",
    )
    parser.add_argument(
        "--val_path",
        type=str,
        default="dataset/val.json",
        help="val dataset path",
    )
    parser.add_argument(
        "--prompt_path",
        type=str,
        default="prompt/init_prompt.txt",
        help="path to the file holding the initial system prompt",
    )
    parser.add_argument(
        "--output_dir", type=str, default="output_dir", help="Path to output dir"
    )
    parser.add_argument(
        "--x_col", type=str, default="original_text", help="dataset x column name"
    )
    parser.add_argument(
        "--y_col", type=str, default="result", help="dataset y column name"
    )
    parser.add_argument(
        "--batch_size", type=int, default=10, help="batch size"
    )
    parser.add_argument(
        "--epoch", type=int, default=3, help="epoch"
    )
    parser.add_argument(
        "--engine",
        type=str,
        default="ollama-Qwen2-72B-Instruct",
        help="TextGrad engine name, used for both the forward model and the backward engine",
    )
    parser.add_argument(
        "--ollama_base_url",
        type=str,
        # Respect an already-exported OLLAMA_BASE_URL instead of clobbering it.
        default=os.environ.get("OLLAMA_BASE_URL", "http://192.168.252.71:9002/v1"),
        help="base URL of the Ollama server (defaults to $OLLAMA_BASE_URL when set)",
    )
    args = parser.parse_args()

    # Create the output dir. exist_ok avoids the exists()/makedirs() race.
    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f'{get_now()}.jsonl')
    print(output_path)

    def load_contract_review_task():
        """Build (train_ds, val_ds, test_ds, eval_fn) for the contract-review task.

        There is no test split, so the third element is always None.
        eval_fn wraps util's string_based_equality_fn as a TextGrad op.
        """
        train_ds = ContractReviewDataset(args.train_path, args.x_col, args.y_col)
        val_ds = ContractReviewDataset(args.val_path, args.x_col, args.y_col)
        return train_ds, val_ds, None, StringBasedFunction(string_based_equality_fn,
                                                           function_purpose="The runtime of string-based function that checks if the prediction is correct.")

    def save_results(results):
        """Append the most recent entry of every tracked series as one JSONL record."""
        output_dict = {k: v[-1] for k, v in results.items()}
        append_dict_to_jsonl(output_dict, output_path)

    # Init engine, model, optimizer.
    # The URL is exported before the engine is created. Presumably textgrad's
    # ollama engine reads OLLAMA_BASE_URL from the environment — TODO confirm.
    os.environ['OLLAMA_BASE_URL'] = args.ollama_base_url
    llm_engine = tg.get_engine(args.engine)
    tg.set_backward_engine(args.engine)

    # Init datasets, prompt and eval function.
    train_set, val_set, _, eval_fn = load_contract_review_task()
    train_loader = tg.tasks.DataLoader(train_set, batch_size=args.batch_size, shuffle=True)
    with open(args.prompt_path, 'r', encoding='utf-8') as f:
        init_prompt = f.read()
    # The system prompt is the only trainable variable; inputs/labels are frozen.
    system_prompt = tg.Variable(init_prompt,
                                requires_grad=True,
                                role_description="system prompt to guide the LLM's reasoning strategy for accurate responses")
    model = tg.BlackboxLLM(llm_engine, system_prompt=system_prompt)
    optimizer = tg.TGD(parameters=list(model.parameters()))
    results = {"train_acc": [], "prompt": [], "validation_acc": []}

    def evaluate_and_save():
        """Score the current prompt on train and val, record it, and flush to disk."""
        results["train_acc"].append(np.mean(eval_dataset(train_set, eval_fn, model)))
        results["validation_acc"].append(np.mean(eval_dataset(val_set, eval_fn, model)))
        results["prompt"].append(system_prompt.get_value())
        save_results(results)

    # Baseline scores for the initial prompt.
    evaluate_and_save()

    # Backward: optimize the prompt batch by batch.
    for epoch in range(args.epoch):
        for steps, (batch_x, batch_y) in enumerate((pbar := tqdm(train_loader, position=0))):
            pbar.set_description(f"Training step {steps}. Epoch {epoch}")
            optimizer.zero_grad()
            losses = []
            for (x, y) in zip(batch_x, batch_y):
                x = tg.Variable(x, requires_grad=False, role_description="query to the language model")
                y = tg.Variable(y, requires_grad=False, role_description="correct answer for the query")
                response = model(x)
                eval_output_variable = eval_fn(inputs=dict(prediction=response, ground_truth_answer=y))
                losses.append(eval_output_variable)
            total_loss = tg.sum(losses)
            total_loss.backward()
            optimizer.step()
            # Helper from util. Presumably reverts system_prompt when validation
            # accuracy drops — confirm in util.
            # NOTE(review): if it also appends to results["validation_acc"] (as
            # the TextGrad tutorial version does), evaluate_and_save() re-scores
            # the validation set, appending a second entry per step. Confirm and
            # drop one of the two passes.
            run_validation_revert(system_prompt, results, model, eval_fn, val_set)
            evaluate_and_save()
