import textgrad as tg
import re
import json
import numpy as np
import concurrent
from tqdm import tqdm
from datetime import datetime


def eval_sample(item, eval_fn, model):
    """Score a single (query, reference answer) pair as 1 (correct) or 0.

    The query is sent through ``model`` and the resulting prediction is
    handed to ``eval_fn`` together with the reference answer.  The sample
    counts as correct only when the evaluator's output text contains the
    fixed "judged correct" marker string checked below.

    Args:
        item: A two-element iterable ``(query, correct_answer)``.
        eval_fn: Callable invoked as ``eval_fn(inputs={...})`` with the keys
            ``prediction`` and ``ground_truth_answer``; its result must
            expose a ``.value`` that supports the ``in`` check.
        model: Callable that maps the query variable to a response.

    Returns:
        ``1`` if the marker is found in the evaluator output, otherwise
        ``0``.  Any exception raised by the model or the evaluator is
        printed and scored as ``0`` rather than propagated.
    """
    question, reference = item
    query_var = tg.Variable(
        question,
        requires_grad=False,
        role_description="query to the language model",
    )
    truth_var = tg.Variable(
        reference,
        requires_grad=False,
        role_description="correct answer for the query",
    )
    try:
        prediction = model(query_var)
        verdict = eval_fn(
            inputs={"prediction": prediction, "ground_truth_answer": truth_var}
        )
        if '你的判断结果为[正确]' in verdict.value:
            return 1
        return 0
    except Exception as e:
        # Best-effort scoring: a failed model/evaluator call counts as wrong.
        print('error in eval_sample: ', e)
        return 0


def run_validation_revert(system_prompt: tg.Variable, results, model, eval_fn, val_set):
    """Validate the current prompt and roll it back if accuracy got worse.

    The mean accuracy over ``val_set`` is compared with the last entry of
    ``results["validation_acc"]``.  When the new score is strictly lower,
    ``system_prompt`` is reset to the last entry of ``results["prompt"]``
    and the previous score is recorded again; otherwise the new score is
    recorded.

    Args:
        system_prompt: Variable whose value is restored on a regression.
        results: Dict with non-empty ``"validation_acc"`` and ``"prompt"``
            lists.  Only ``"validation_acc"`` is appended to here.
            NOTE(review): presumably the caller appends to ``"prompt"`` -
            confirm.
        model: Model forwarded to ``eval_dataset``.
        eval_fn: Evaluation function forwarded to ``eval_dataset``.
        val_set: Validation samples forwarded to ``eval_dataset``.
    """
    candidate_acc = np.mean(eval_dataset(val_set, eval_fn, model))
    baseline_acc = np.mean(results["validation_acc"][-1])
    print("val_performance: ", candidate_acc)
    print("previous_performance: ", baseline_acc)
    last_prompt = results["prompt"][-1]

    if candidate_acc < baseline_acc:
        # Regression: discard the new prompt and keep the old score.
        system_prompt.set_value(last_prompt)
        recorded_acc = baseline_acc
    else:
        recorded_acc = candidate_acc

    results["validation_acc"].append(recorded_acc)


def eval_dataset(test_set, eval_fn, model, max_samples: int = None):
    """Evaluate samples from ``test_set`` concurrently and return their scores.

    Each sample is scored by ``eval_sample`` on a pool of 5 worker threads.
    A progress bar shows the running mean accuracy.

    Args:
        test_set: Iterable of samples accepted by ``eval_sample``.
        eval_fn: Evaluation function forwarded to ``eval_sample``.
        model: Model forwarded to ``eval_sample``.
        max_samples: Maximum number of samples to evaluate, taken from the
            start of ``test_set``.  ``None`` evaluates every sample; values
            of zero or below evaluate none.

    Returns:
        List of per-sample scores in completion order (not input order).
        The list is empty when no sample was evaluated.
    """
    # Imported here so the submodule is guaranteed to be loaded: a bare
    # ``import concurrent`` does not import ``concurrent.futures``.
    import concurrent.futures
    from itertools import islice

    # islice(…, None) takes everything; negative limits are clamped because
    # islice rejects them.
    limit = None if max_samples is None else max(max_samples, 0)

    accuracy_list = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
        futures = [
            executor.submit(eval_sample, sample, eval_fn, model)
            for sample in islice(test_set, limit)
        ]
        tqdm_loader = tqdm(
            concurrent.futures.as_completed(futures),
            total=len(futures),
            position=0,
        )
        for future in tqdm_loader:
            accuracy_list.append(future.result())
            tqdm_loader.set_description(f"Accuracy: {np.mean(accuracy_list)}")
    return accuracy_list


def extract_json(json_str):
    """Collect every JSON value found in ```json fenced blocks of a string.

    Each fenced block is stripped and parsed.  A block that parses to a list
    contributes its elements individually; any other parsed value is added
    as a single item.  Blocks that fail to parse are reported via ``print``
    and skipped.

    Args:
        json_str: Text that may contain zero or more ```json ... ``` blocks.

    Returns:
        A flat list of the parsed values, in order of appearance.
    """
    fence_re = re.compile(r'```json([\s\S]*?)```', re.DOTALL)
    collected = []
    for hit in fence_re.finditer(json_str):
        # Strip any leading/trailing whitespace around the payload.
        clean_json_str = hit.group(1).strip()
        try:
            parsed = json.loads(clean_json_str)
        except json.JSONDecodeError as e:
            print(f"发现了一个无法解析的JSON字符串: {clean_json_str} {e}")
            continue
        if isinstance(parsed, list):
            collected.extend(parsed)
        else:
            collected.append(parsed)
    return collected


def append_dict_to_jsonl(dictionary, file_path):
    """Append ``dictionary`` to ``file_path`` as one UTF-8 JSON line.

    Non-ASCII characters are written as-is (``ensure_ascii=False``).  The
    file is opened in append mode, so it is created if it does not exist.

    Args:
        dictionary: JSON-serializable object to write.
        file_path: Path of the JSONL file to append to.

    Raises:
        TypeError: If ``dictionary`` contains a value that is not
            JSON-serializable.  Nothing is written in that case.
    """
    # Serialize before touching the file: json.dump streams chunks into the
    # file, so a non-serializable value would leave a truncated record
    # (without a newline) behind and corrupt the JSONL file.
    line = json.dumps(dictionary, ensure_ascii=False)
    with open(file_path, 'a', encoding='utf-8') as f:
        f.write(line + '\n')


def get_now():
    """Return the current local time formatted as ``YYYYMMDD-HHMMSS``."""
    return datetime.now().strftime("%Y%m%d-%H%M%S")


if __name__ == '__main__':
    # Manual check for append_dict_to_jsonl (uncomment to run):
    #   append_dict_to_jsonl({'a': 'zz'}, 'test.jsonl')
    #   append_dict_to_jsonl({'a': 'ff'}, 'test.jsonl')
    # Smoke check: print the current timestamp string.
    print(get_now())
