From 34cece66704503cdde879c38388ea7e58b6276ee Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Tue, 6 Jan 2026 18:00:19 +0800 Subject: [PATCH 1/2] 1. Remove `base_model` in docs. 2. Add `compute_score` for `learn_to_ask`. --- .../source/tutorial/example_tinker_backend.md | 3 - .../tutorial/example_tinker_backend.md | 3 - examples/learn_to_ask/README.md | 1 + .../data_prepare/3_rollout_then_evaluate.py | 75 ++++++++++++++++--- examples/tinker/README.md | 3 - examples/tinker/tinker.yaml | 1 - 6 files changed, 66 insertions(+), 20 deletions(-) diff --git a/docs/sphinx_doc/source/tutorial/example_tinker_backend.md b/docs/sphinx_doc/source/tutorial/example_tinker_backend.md index a1a3db5061..4e9dc0e7bc 100644 --- a/docs/sphinx_doc/source/tutorial/example_tinker_backend.md +++ b/docs/sphinx_doc/source/tutorial/example_tinker_backend.md @@ -23,7 +23,6 @@ Configure the Tinker backend in your YAML configuration file by setting the `mod model: tinker: enable: true - base_model: null rank: 32 seed: null train_mlp: true @@ -35,7 +34,6 @@ model: - **`tinker`**: Tinker-specific configuration section. **Important**: When Tinker is enabled, any LoRA configuration settings (`model.lora_configs`) will be ignored. - **`enable`**: Whether to activate the Tinker backend. Default: `false` - - **`base_model`**: Path to the base model for Tinker. If not specified (`null`), it defaults to the `model_path` defined elsewhere in your config - **`rank`**: The LoRA rank that controls the size of the adaptation matrices. Default: `32` - **`seed`**: Random seed for reproducible Tinker operations. If not specified (`null`), no specific seed is set - **`train_mlp`**: Whether to train the MLP (feed-forward) layers. Default: `true` @@ -94,7 +92,6 @@ model: custom_chat_template: "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {{- \"<|eot_id|>\" }}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n" tinker: enable: true - base_model: meta-llama/Llama-3.2-3B cluster: node_num: 1 gpu_per_node: 8 diff --git a/docs/sphinx_doc/source_zh/tutorial/example_tinker_backend.md b/docs/sphinx_doc/source_zh/tutorial/example_tinker_backend.md index a56d6eb671..cc1ea5966c 100644 --- a/docs/sphinx_doc/source_zh/tutorial/example_tinker_backend.md +++ b/docs/sphinx_doc/source_zh/tutorial/example_tinker_backend.md @@ -23,7 +23,6 @@ ray start --head model: tinker: enable: true - base_model: null rank: 32 seed: null train_mlp: true @@ -35,7 +34,6 @@ model: - **`tinker`**:Tinker 专用配置部分。**注意**:启用 Tinker 后,所有 LoRA 配置(`model.lora_configs`)将被忽略。 - **`enable`**:是否启用 Tinker 后端。默认值:`false` - - **`base_model`**:Tinker 的基础模型路径。如果未指定(`null`),则默认为配置中其他位置的 `model_path` - **`rank`**:LoRA 的秩,控制适应矩阵的大小。默认值:`32` - **`seed`**:Tinker 操作的随机种子。未指定(`null`)时不设定特定种子 - **`train_mlp`**:是否训练 MLP(前馈)层。默认值:`true` @@ -93,7 +91,6 @@ model: custom_chat_template: "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if 
strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {{- \"<|eot_id|>\" }}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n" tinker: enable: true - base_model: meta-llama/Llama-3.2-3B cluster: node_num: 1 gpu_per_node: 8 diff --git a/examples/learn_to_ask/README.md b/examples/learn_to_ask/README.md index 1b300fe228..7b72b774aa 100644 --- a/examples/learn_to_ask/README.md +++ b/examples/learn_to_ask/README.md @@ -1,6 +1,7 @@ # Learn2Ask: Getting Started This guide demonstrates how to train a proactive LLM using the **Learn2Ask** framework from [Grounded in Reality: Learning and Deploying Proactive LLM from Offline Logs](https://arxiv.org/abs/2510.25441). + **Hardware requirement**: ≥32 H20 (or equivalent) GPUs for full-scale reproduction. 
All relevant files are located under `examples/learn_to_ask/`: diff --git a/examples/learn_to_ask/data_prepare/3_rollout_then_evaluate.py b/examples/learn_to_ask/data_prepare/3_rollout_then_evaluate.py index 60d7755d8c..f2e5538191 100644 --- a/examples/learn_to_ask/data_prepare/3_rollout_then_evaluate.py +++ b/examples/learn_to_ask/data_prepare/3_rollout_then_evaluate.py @@ -6,6 +6,7 @@ import copy import gc import json +import math import os import re import time @@ -14,9 +15,6 @@ from transformers import AutoTokenizer from vllm import LLM, SamplingParams -from trinity.common.constants import PLUGIN_DIRS_ENV_VAR -from trinity.utils.plugin_loader import load_plugins - def init_llm(model_path): tokenizer = AutoTokenizer.from_pretrained(model_path) @@ -37,9 +35,15 @@ def init_llm(model_path): def rollout(llm, tokenizer, sampling_params, input_file_path, output_file_path, rollout_repeat=3): - from examples.learn_to_ask.workflow.prompt_learn2ask import ( - rollout_prompt_med as rollout_prompt, + import importlib + + spec = importlib.util.spec_from_file_location( + "prompt_learn2ask", + os.path.join(os.path.dirname(__file__), "..", "workflow", "prompt_learn2ask.py"), ) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + rollout_prompt = module.rollout_prompt_med with open(input_file_path, "r") as lines: sample_list = [json.loads(line.strip()) for line in lines] @@ -70,9 +74,15 @@ def rollout(llm, tokenizer, sampling_params, input_file_path, output_file_path, def eval_sample(llm, tokenizer, sampling_params, input_file_path, output_file_path): - from examples.learn_to_ask.workflow.prompt_learn2ask import ( - reward_prompt_med as grader_prompt, + import importlib + + spec = importlib.util.spec_from_file_location( + "prompt_learn2ask", + os.path.join(os.path.dirname(__file__), "..", "workflow", "prompt_learn2ask.py"), ) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + grader_prompt = module.reward_prompt_med print(f"input_file_path: {input_file_path}") print(f"output_file_path: {output_file_path}") @@ -156,6 +166,53 @@ def msg2str(msg_list): print("\n======================\n") +def compute_score(input_file_path): + with open(input_file_path, "r") as lines: + sample_list = [json.loads(line.strip()) for line in lines] + continue_count, continue_content_score, continue_content_full = 0, 0, 0 + continue_decision_score = 0 + stop_count, stop_decision_score = 0, 0 + total_reward, total_format = 0, 0 + continue_count_correct, continue_content_score_correct, continue_content_full_correct = 0, 0, 0 + for sample in sample_list: + for rollout, grade in zip(sample["rollouts"], sample["grades"]): + if math.isnan(grade["content_score"]) or math.isnan(grade["format_score"]): + continue + if sample["decision_truth"] == "continue": + continue_count += 1 + continue_content_score += grade["content_score"] + continue_content_full += 1 if grade["content_score"] == 1 else 0 + continue_decision_score += grade["action_score"] + if "" not in rollout: + continue_count_correct += 1 + continue_content_score_correct += grade["content_score"] + continue_content_full_correct += 1 if grade["content_score"] == 1 else 0 + + else: + stop_count += 1 + stop_decision_score += grade["action_score"] + total_reward += ( + grade["action_score"] * (1 + 2 * grade["content_score"]) + grade["format_score"] + ) + total_format += grade["format_score"] + + result = { + "ave_continue_content": continue_content_score / continue_count, + "win_continue_content": 
continue_content_full / continue_count, + "ave_continue_content if correct": continue_content_score_correct / continue_count_correct, + "win_continue_content if correct": continue_content_full_correct / continue_count_correct, + "ave_continue_decision": continue_decision_score / continue_count, + "ave_stop_decision": stop_decision_score / stop_count, + "ave_total_decision": (continue_decision_score + stop_decision_score) + / (continue_count + stop_count), + "ave_total_format": total_format / (continue_count + stop_count), + "ave_total_reward": total_reward / (continue_count + stop_count), + } + + print(f"total count: {continue_count + stop_count}") + print(json.dumps(result, ensure_ascii=False, indent=4)) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--rollout_repeat", type=int, default=3) @@ -177,9 +234,6 @@ def msg2str(msg_list): args = parser.parse_args() - os.environ[PLUGIN_DIRS_ENV_VAR] = os.path.join(os.path.dirname(__file__), "..", "workflow") - load_plugins() - # rollout stage llm, tokenizer, sampling_params = init_llm(args.eval_model_path) rollout( @@ -197,3 +251,4 @@ def msg2str(msg_list): # eval stage llm2, tokenizer2, sampling_params2 = init_llm(args.grader_model_path) eval_sample(llm2, tokenizer2, sampling_params2, args.rollout_file_path, args.eval_file_path) + compute_score(args.eval_file_path) diff --git a/examples/tinker/README.md b/examples/tinker/README.md index 2c55fdf294..395cc57328 100644 --- a/examples/tinker/README.md +++ b/examples/tinker/README.md @@ -19,7 +19,6 @@ Configure the Tinker backend in your YAML configuration file by setting the `mod model: tinker: enable: true - base_model: null rank: 32 seed: null train_mlp: true @@ -31,7 +30,6 @@ model: - **`tinker`**: Tinker-specific configuration section. **Important**: When Tinker is enabled, any LoRA configuration settings (`model.lora_configs`) will be ignored. - **`enable`**: Whether to activate the Tinker backend. Default: `false` - - **`base_model`**: Path to the base model for Tinker. If not specified (`null`), it defaults to the `model_path` defined elsewhere in your config - **`rank`**: The LoRA rank that controls the size of the adaptation matrices. Default: `32` - **`seed`**: Random seed for reproducible Tinker operations. If not specified (`null`), no specific seed is set - **`train_mlp`**: Whether to train the MLP (feed-forward) layers. Default: `true` @@ -90,7 +88,6 @@ model: custom_chat_template: "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. 
#}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {{- \"<|eot_id|>\" }}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n" tinker: enable: true - base_model: meta-llama/Llama-3.2-3B cluster: node_num: 1 gpu_per_node: 8 diff --git a/examples/tinker/tinker.yaml b/examples/tinker/tinker.yaml index 744357e745..93812e9f79 100644 --- a/examples/tinker/tinker.yaml +++ b/examples/tinker/tinker.yaml @@ -19,7 +19,6 @@ model: max_response_tokens: 2048 tinker: enable: true - base_model: Qwen/Qwen3-4B-Instruct-2507 buffer: batch_size: 96 total_epochs: 1 From ca7457941d4f181fdc5c3265cde9542168722445 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Tue, 6 Jan 2026 20:18:23 +0800 Subject: [PATCH 2/2] apply suggestions from gemini --- .../data_prepare/3_rollout_then_evaluate.py | 63 ++++++++++--------- 1 file changed, 33 insertions(+), 30 deletions(-) diff --git a/examples/learn_to_ask/data_prepare/3_rollout_then_evaluate.py b/examples/learn_to_ask/data_prepare/3_rollout_then_evaluate.py index f2e5538191..cfbb650379 100644 --- a/examples/learn_to_ask/data_prepare/3_rollout_then_evaluate.py +++ b/examples/learn_to_ask/data_prepare/3_rollout_then_evaluate.py @@ -5,6 +5,7 @@ import argparse import copy import gc +import importlib import json import math import os @@ -15,6 +16,13 @@ from transformers import AutoTokenizer from vllm import LLM, SamplingParams +spec = importlib.util.spec_from_file_location( + "prompt_learn2ask", + os.path.join(os.path.dirname(__file__), "..", "workflow", "prompt_learn2ask.py"), +) +prompt_module = importlib.util.module_from_spec(spec) +spec.loader.exec_module(prompt_module) + def init_llm(model_path): tokenizer = AutoTokenizer.from_pretrained(model_path) @@ -35,15 +43,7 @@ def init_llm(model_path): def rollout(llm, tokenizer, sampling_params, input_file_path, output_file_path, rollout_repeat=3): - import importlib - - spec = importlib.util.spec_from_file_location( - "prompt_learn2ask", - os.path.join(os.path.dirname(__file__), "..", "workflow", "prompt_learn2ask.py"), - ) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - rollout_prompt = module.rollout_prompt_med + rollout_prompt = prompt_module.rollout_prompt_med with open(input_file_path, "r") as lines: sample_list = [json.loads(line.strip()) for line in lines] @@ 
-74,15 +74,7 @@ def rollout(llm, tokenizer, sampling_params, input_file_path, output_file_path, def eval_sample(llm, tokenizer, sampling_params, input_file_path, output_file_path): - import importlib - - spec = importlib.util.spec_from_file_location( - "prompt_learn2ask", - os.path.join(os.path.dirname(__file__), "..", "workflow", "prompt_learn2ask.py"), - ) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - grader_prompt = module.reward_prompt_med + grader_prompt = prompt_module.reward_prompt_med print(f"input_file_path: {input_file_path}") print(f"output_file_path: {output_file_path}") @@ -196,20 +188,31 @@ def compute_score(input_file_path): ) total_format += grade["format_score"] + total_count = continue_count + stop_count result = { - "ave_continue_content": continue_content_score / continue_count, - "win_continue_content": continue_content_full / continue_count, - "ave_continue_content if correct": continue_content_score_correct / continue_count_correct, - "win_continue_content if correct": continue_content_full_correct / continue_count_correct, - "ave_continue_decision": continue_decision_score / continue_count, - "ave_stop_decision": stop_decision_score / stop_count, - "ave_total_decision": (continue_decision_score + stop_decision_score) - / (continue_count + stop_count), - "ave_total_format": total_format / (continue_count + stop_count), - "ave_total_reward": total_reward / (continue_count + stop_count), + "ave_continue_content": continue_content_score / continue_count if continue_count else 0.0, + "win_continue_content": continue_content_full / continue_count if continue_count else 0.0, + "ave_continue_content if correct": ( + continue_content_score_correct / continue_count_correct + if continue_count_correct + else 0.0 + ), + "win_continue_content if correct": ( + continue_content_full_correct / continue_count_correct + if continue_count_correct + else 0.0 + ), + "ave_continue_decision": ( + continue_decision_score / continue_count if continue_count else 0.0 + ), + "ave_stop_decision": stop_decision_score / stop_count if stop_count else 0.0, + "ave_total_decision": ( + (continue_decision_score + stop_decision_score) / total_count if total_count else 0.0 + ), + "ave_total_format": total_format / total_count if total_count else 0.0, + "ave_total_reward": total_reward / total_count if total_count else 0.0, } - - print(f"total count: {continue_count + stop_count}") + print(f"total count: {total_count}") print(json.dumps(result, ensure_ascii=False, indent=4))
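
A side note on the zero-denominator guards introduced in the second patch: the repeated `x / y if y else 0.0` expressions in `compute_score` could be factored into one small helper. Below is a minimal, self-contained sketch; the `safe_div` helper and the toy `grades` data are illustrative assumptions and not part of this patch, but the reward formula and the NaN filtering mirror what `compute_score` does.

import json
import math


def safe_div(numerator: float, denominator: float) -> float:
    """Return numerator / denominator, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


# Toy grades in the same shape compute_score reads from the eval file;
# the values here are made up purely for illustration.
grades = [
    {"action_score": 1.0, "content_score": 1.0, "format_score": 1.0},
    {"action_score": 0.0, "content_score": float("nan"), "format_score": 1.0},
]

# Skip NaN grades, then aggregate with the same reward formula used in
# compute_score: action_score * (1 + 2 * content_score) + format_score.
valid = [g for g in grades if not math.isnan(g["content_score"])]
total_reward = sum(
    g["action_score"] * (1 + 2 * g["content_score"]) + g["format_score"] for g in valid
)

print(json.dumps({"ave_total_reward": safe_div(total_reward, len(valid))}, indent=4))

Keeping the guard in a single helper would shorten the `result` dictionary in `compute_score` while preserving the same 0.0 fallback when a bucket (continue, stop, or correct-continue) is empty.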