Commit 264f20f5 by ccran

feat: add merged_all.json

parent 2bfc27eb
{
    "python-envs.defaultEnvManager": "ms-python.python:conda",
    "python-envs.defaultPackageManager": "ms-python.python:conda",
    "python-envs.pythonProjects": []
}
\ No newline at end of file
This source diff could not be displayed because it is too large. You can view the blob instead.
from typing import Any
def group_rows_by_review_item_with_dedup(
    rows: list[dict[str, Any]],
) -> dict[str, list[dict[str, Any]]]:
    """Group rows by their ``review_item`` field, dropping duplicate texts.

    Rows that share a review item are collected into one bucket.  Within a
    bucket, a row is kept only the first time its ``original_text`` value
    appears (compared after ``strip()`` when it is a string; compared as-is
    otherwise, including ``None`` for missing text).  Rows whose
    ``review_item`` is ``None`` fall under the ``"<None>"`` bucket; other
    values are stringified and stripped to form the bucket key.  The
    returned mapping is ordered by ascending bucket key.
    """
    buckets: dict[str, list[dict[str, Any]]] = {}
    seen_texts: dict[str, set[Any]] = {}
    for entry in rows:
        item = entry.get("review_item")
        group_key = "<None>" if item is None else str(item).strip()
        # setdefault creates the bucket and its dedup set on first sight.
        bucket = buckets.setdefault(group_key, [])
        seen = seen_texts.setdefault(group_key, set())
        text = entry.get("original_text")
        text_key = text.strip() if isinstance(text, str) else text
        if text_key not in seen:
            seen.add(text_key)
            bucket.append(entry)
    # Rebuild in sorted-key order so callers get deterministic grouping.
    return {key: buckets[key] for key in sorted(buckets)}
if __name__ == "__main__":
    # Quick demo: two rows with the same review item and identical original
    # text — the grouping helper should keep only the first of the pair.
    sample_rows = [
        {
            "review_item": "合同主体",
            "original_text": "甲方:XXX公司",
        },
        {
            "review_item": "合同主体",
            "original_text": "甲方:XXX公司",
        },
    ]
    result = group_rows_by_review_item_with_dedup(sample_rows)
    print(result)
...@@ -118,7 +118,7 @@ def apply_review_item_mapping(rows: list[dict[str, Any]], mapping: dict[str, Any ...@@ -118,7 +118,7 @@ def apply_review_item_mapping(rows: list[dict[str, Any]], mapping: dict[str, Any
row["review_item"] = review_item_text row["review_item"] = review_item_text
mapped_value = mapping.get(review_item_text) mapped_value = mapping.get(review_item_text)
if mapped_value is not None: if mapped_value is not None:
row["prompt"] = mapped_value row["prompt"] = mapped_value.replace('&#10;', '\n')
else: else:
unmatched_review_items.add(review_item_text) unmatched_review_items.add(review_item_text)
...@@ -194,7 +194,7 @@ def filter_rows_by_llm_result( ...@@ -194,7 +194,7 @@ def filter_rows_by_llm_result(
max_retries: int = 10, max_retries: int = 10,
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
kept_rows: list[dict[str, Any]] = [] kept_rows: list[dict[str, Any]] = []
for index, row in enumerate(rows, start=1): for index, row in enumerate(rows, start=1):
system_prompt = row.get("prompt") system_prompt = row.get("prompt")
user_text = row.get("original_text") user_text = row.get("original_text")
...@@ -248,7 +248,7 @@ def main() -> None: ...@@ -248,7 +248,7 @@ def main() -> None:
parser.add_argument("--json-dir", default="datasets/json", help="JSON目录路径") parser.add_argument("--json-dir", default="datasets/json", help="JSON目录路径")
parser.add_argument("--sheet", default=0, help="sheet名称或索引,默认0") parser.add_argument("--sheet", default=0, help="sheet名称或索引,默认0")
parser.add_argument("--model", default=model, help="OpenAI模型名称") parser.add_argument("--model", default=model, help="OpenAI模型名称")
parser.add_argument("--max-retries", type=int, default=10, help="每条数据最大重试次数") parser.add_argument("--max-retries", type=int, default=1, help="每条数据最大重试次数")
parser.add_argument("--openai-base-url", default=base_url, help="OpenAI接口基础URL") parser.add_argument("--openai-base-url", default=base_url, help="OpenAI接口基础URL")
parser.add_argument("--out", default="datasets/merged/merged_all.json", help="输出JSON文件路径") parser.add_argument("--out", default="datasets/merged/merged_all.json", help="输出JSON文件路径")
args = parser.parse_args() args = parser.parse_args()
...@@ -269,13 +269,16 @@ def main() -> None: ...@@ -269,13 +269,16 @@ def main() -> None:
client = OpenAI(api_key=api_key, base_url=args.openai_base_url) client = OpenAI(api_key=api_key, base_url=args.openai_base_url)
# rows = rows[:2] # rows = rows[:2]
rows = filter_rows_by_llm_result( rows = filter_rows_by_llm_result(
rows=rows, rows=rows,
client=client, client=client,
model=args.model, model=args.model,
max_retries=max(1, args.max_retries), max_retries=max(1, args.max_retries),
) )
grouped_rows = group_rows_by_review_item_with_dedup(rows) grouped_rows = group_rows_by_review_item_with_dedup(rows)
out_path = Path(args.out) out_path = Path(args.out)
out_path.parent.mkdir(parents=True, exist_ok=True) out_path.parent.mkdir(parents=True, exist_ok=True)
......
...@@ -43,7 +43,9 @@ def read_json_lists_from_dir(json_dir: str | Path) -> list[dict[str, Any]]: ...@@ -43,7 +43,9 @@ def read_json_lists_from_dir(json_dir: str | Path) -> list[dict[str, Any]]:
rows = read_json_list(file_path) rows = read_json_list(file_path)
for row in rows: for row in rows:
row['review_item'] = file_path.name.split('.')[0] row['review_item'] = file_path.name.split('.')[0]
row["__source_file__"] = file_path.name row['__source_file__'] = file_path.name
if 'ground_truth' in row and row['ground_truth'] == '不涉及':
row['ground_truth'] = '合格'
all_rows.extend(rows) all_rows.extend(rows)
return all_rows return all_rows
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment