Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
fd584ea
adapt tau2 bench
Mar 16, 2026
0dc261c
Merge branch 'master_center' into tau2_dev
Mar 19, 2026
3b19d79
tau2 pass@k
Mar 19, 2026
5ede58d
tau2 summarizer
Mar 19, 2026
d607066
tau2 summarizer
Mar 19, 2026
7e5e0b9
tau2 summarizer
Mar 20, 2026
50c4472
ignore json load exception
Mar 23, 2026
f5c8569
fix json load
Mar 23, 2026
5066f44
add requirement
Mar 24, 2026
c1d7aaf
Apply suggestion from @gemini-code-assist[bot]
SJTUyh Mar 24, 2026
d1fee46
Apply suggestion from @gemini-code-assist[bot]
SJTUyh Mar 24, 2026
7c98424
review fix
Mar 24, 2026
fa420e2
Merge branch 'tau2_dev' of https://github.com/SJTUyh/benchmark into t…
Mar 24, 2026
7ccb966
review fix
Mar 24, 2026
314ce7b
tau2 add tag limit
Mar 24, 2026
ed190fb
patch the user input
Mar 24, 2026
4f3e057
summarizer add total count
Mar 24, 2026
fbd59b3
summarizer add total count
Mar 24, 2026
93b4c13
summarizer add total count
Mar 24, 2026
100bbe3
summarizer add total count
Mar 24, 2026
e1062c2
fix weight
Mar 25, 2026
eba8f67
fix weight
Mar 25, 2026
97a7cb5
fix weight
Mar 25, 2026
72d83eb
tau2 fix
Apr 7, 2026
fb22a36
tau2 fix
Apr 9, 2026
d38a3fe
hat fix
Apr 9, 2026
2c0979c
add en docs
Apr 9, 2026
c5e7efe
add en docs
Apr 9, 2026
0510b46
merge fix
Apr 13, 2026
46bc121
merge fix
Apr 13, 2026
5f85793
merge fix
Apr 13, 2026
599f7b0
merge fix
Apr 13, 2026
a3661dd
merge fix
Apr 13, 2026
d88859f
merge fix
Apr 13, 2026
8f87172
merge fix
Apr 13, 2026
a4fdba1
add UT for tau2 bench
Apr 13, 2026
b5d420e
add UT for tau2 bench
Apr 13, 2026
f955871
add UT for tau2 bench
Apr 13, 2026
abbb924
add UT for tau2 bench
Apr 13, 2026
4dc4150
add UT for tau2 bench
Apr 13, 2026
0f4fd1f
add UT for tau2 bench
Apr 13, 2026
e645453
tau2 bench UT mock dependencies
Apr 13, 2026
38b4423
tau2 bench UT mock dependencies
Apr 13, 2026
b101d2e
tau2 bench UT mock dependencies
Apr 13, 2026
c1830fc
tau2 bench UT mock dependencies
Apr 13, 2026
4931123
tau2 bench UT mock dependencies
Apr 13, 2026
b3cab39
add new tau2 bench UT
Apr 13, 2026
7ad463e
add new tau2 bench UT
Apr 13, 2026
547e1ea
add new tau2 bench UT
Apr 13, 2026
d3a061d
add new tau2 bench UT
Apr 13, 2026
b2e78eb
add new tau2 bench UT
Apr 13, 2026
8624231
add new tau2 bench UT
Apr 13, 2026
06502d7
add new tau2 bench UT
Apr 13, 2026
1c8c061
add new tau2 bench UT
Apr 13, 2026
0f1b4e5
add new tau2 bench UT
Apr 13, 2026
9631cf5
add new tau2 bench UT
Apr 13, 2026
c777a76
add new tau2 bench UT
Apr 13, 2026
a38cba1
add new tau2 bench UT
Apr 13, 2026
7ae9a49
add new tau2 bench UT
Apr 13, 2026
f7ff051
delete unused dep
Apr 14, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions ais_bench/benchmark/cli/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from ais_bench.benchmark.cli.utils import fill_model_path_if_datasets_need, fill_test_range_use_num_prompts, recur_convert_config_type

class CustomConfigChecker:
MODEL_REQUIRED_FIELDS = ['type', 'abbr', 'attr']
DATASET_REQUIRED_FIELDS = ['type', 'abbr']
MODEL_REQUIRED_FIELDS = ['abbr']
DATASET_REQUIRED_FIELDS = ['abbr']
SUMMARIZER_REQUIRED_FIELDS = ['attr']

def __init__(self, config, file_path):
Expand Down Expand Up @@ -106,6 +106,8 @@ def load_config(self, workflow):

def _fill_dataset_configs(self):
for dataset_cfg in self.cfg["datasets"]:
if dataset_cfg.get("infer_cfg", None) is None:
continue
fill_test_range_use_num_prompts(self.cfg["cli_args"].get("num_prompts"), dataset_cfg)
fill_model_path_if_datasets_need(self.cfg["models"][0], dataset_cfg)
retriever_cfg = dataset_cfg["infer_cfg"]["retriever"]
Expand Down
2 changes: 2 additions & 0 deletions ais_bench/benchmark/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@


def get_config_type(obj) -> "str | None":
    """Return the dotted type path that identifies ``obj``.

    Args:
        obj: ``None``, a string that is already a type path, or an object
            exposing ``__module__`` and ``__name__`` (for example a class
            or a function).

    Returns:
        ``None`` when ``obj`` is ``None``; ``obj`` itself when it is a
        string; otherwise the string ``"<obj.__module__>.<obj.__name__>"``.

    Raises:
        AttributeError: If ``obj`` is neither ``None`` nor a string and
            does not provide ``__module__`` or ``__name__``.
    """
    # The return annotation is a quoted string on purpose: it states that
    # None is a possible result without requiring a new ``typing`` import
    # and without being evaluated at definition time.
    if obj is None:
        return None
    if isinstance(obj, str):
        return obj
    return f"{obj.__module__}.{obj.__name__}"
Expand Down
47 changes: 40 additions & 7 deletions ais_bench/benchmark/cli/workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ais_bench.benchmark.partitioners import NaivePartitioner
from ais_bench.benchmark.runners import LocalRunner
from ais_bench.benchmark.tasks import OpenICLEvalTask, OpenICLApiInferTask, OpenICLInferTask
from ais_bench.benchmark.tasks.base import EmptyTask
from ais_bench.benchmark.summarizers import DefaultSummarizer, DefaultPerfSummarizer
from ais_bench.benchmark.calculators import DefaultPerfMetricCalculator
from ais_bench.benchmark.cli.utils import clear_repeat_tasks
Expand All @@ -26,6 +27,7 @@
class BaseWorker(ABC):
def __init__(self, args) -> None:
self.args = args
self.skip = False

@abstractmethod
def update_cfg(self, cfg: ConfigDict) -> None:
Expand All @@ -39,13 +41,21 @@ def do_work(self, cfg: ConfigDict):


class Infer(BaseWorker):
def update_cfg(self, cfg: ConfigDict) -> None:
def update_cfg(self, cfg: ConfigDict) -> ConfigDict:
def get_task_type() -> str:
if cfg["models"][0]["attr"] == "service":
return OpenICLApiInferTask
else:
return OpenICLInferTask

custom_infer = cfg.get("infer")
custom_task = None
if custom_infer:
custom_task = custom_infer.get("runner", {}).get("task", {}).get("type")
if custom_task == EmptyTask:
self.skip = True
return cfg

def update_new_infer_cfg(new_cfg: ConfigDict) -> None:
runner_cfg = new_cfg['infer']['runner']
runner_cfg['max_num_workers'] = self.args.max_num_workers
Expand All @@ -54,12 +64,16 @@ def update_new_infer_cfg(new_cfg: ConfigDict) -> None:

if cfg.get('infer'):
new_cfg = dict(infer=cfg.infer)
if not new_cfg["infer"].get("partitioner"):
new_cfg["infer"]["partitioner"] = dict(type=NaivePartitioner)
if new_cfg["infer"].get("runner") and new_cfg["infer"]["runner"].get("type") is None:
new_cfg["infer"]["runner"]["type"] = LocalRunner
else:
new_cfg = dict(
infer=dict(
partitioner=dict(type=NaivePartitioner),
runner=dict(
task=dict(type=get_task_type()),
task=dict(type=custom_task if custom_task else get_task_type()),
type=LocalRunner,
),
),
Expand All @@ -70,6 +84,9 @@ def update_new_infer_cfg(new_cfg: ConfigDict) -> None:
return cfg

def do_work(self, cfg: ConfigDict):
if self.skip:
logger.info("EmptyTask is selected, skip inference.")
return
partitioner = PARTITIONERS.build(cfg.infer.partitioner)
logger.info("Starting inference tasks...")
tasks = partitioner(cfg)
Expand Down Expand Up @@ -123,7 +140,7 @@ def __init__(self, args) -> None:
super().__init__(args)
self.judge_model_type = None

def update_cfg(self, cfg: ConfigDict) -> None:
def update_cfg(self, cfg: ConfigDict) -> ConfigDict:
for dataset_cfg in cfg["datasets"]:
judge_infer_cfg = dataset_cfg.get("judge_infer_cfg")
if judge_infer_cfg:
Expand Down Expand Up @@ -280,7 +297,15 @@ def _result_post_process(self, tasks, cfg: ConfigDict):


class Eval(BaseWorker):
def update_cfg(self, cfg: ConfigDict) -> None:
def update_cfg(self, cfg: ConfigDict) -> ConfigDict:
custom_eval = cfg.get("eval")
custom_task = None
if custom_eval:
custom_task = custom_eval.get("runner", {}).get("task", {}).get("type")
if custom_task == EmptyTask:
self.skip = True
return cfg

def update_eval_cfg(new_cfg: ConfigDict) -> None:
runner_cfg = new_cfg['eval']['runner']
runner_cfg['max_num_workers'] = self.args.max_num_workers
Expand All @@ -291,22 +316,30 @@ def update_eval_cfg(new_cfg: ConfigDict) -> None:

if cfg.get('eval'):
new_cfg = dict(eval=cfg.eval)
if not new_cfg["eval"].get("partitioner"):
new_cfg["eval"]["partitioner"] = dict(type=NaivePartitioner)
if new_cfg["eval"].get("runner") and new_cfg["eval"]["runner"].get("type") is None:
new_cfg["eval"]["runner"]["type"] = LocalRunner
else:
new_cfg = dict(
eval=dict(
partitioner=dict(type=NaivePartitioner),
runner=dict(
type=LocalRunner,
task=dict(type=OpenICLEvalTask),
),
))
task=dict(type=custom_task if custom_task else OpenICLEvalTask),
),
)
)

update_eval_cfg(new_cfg)
cfg.merge_from_dict(new_cfg)
cfg.eval.partitioner["out_dir"] = osp.join(cfg["work_dir"], "results/")
return cfg

def do_work(self, cfg: ConfigDict):
if self.skip:
logger.info("EmptyTask is selected, skip evaluation.")
return
partitioner = PARTITIONERS.build(cfg.eval.partitioner)
logger.info("Starting evaluation tasks...")
self._cfg_pre_process(cfg)
Expand Down
4 changes: 2 additions & 2 deletions ais_bench/benchmark/partitioners/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ def _check_task_cfg(self, tasks):
filtered_tasks = []
for task in tasks:
mode = task.get("cli_args", {}).get("mode")
dataset_type = task["datasets"][0][0]["type"]
model_type = task["models"][0]["type"]
dataset_type = task["datasets"][0][0].get("type", None)
model_type = task["models"][0].get("type", None)
if mode not in ["perf", "perf_viz"] and dataset_type in ONLY_PERF_DATASETS:
self.logger.warning(
f"'{dataset_type}' can only be used for performance evaluation, "
Expand Down
4 changes: 2 additions & 2 deletions ais_bench/benchmark/partitioners/naive.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class NaivePartitioner(BasePartitioner):
"""

def __init__(self,
out_dir: str,
out_dir: str = '',
n: int = 1,
keep_keys: Optional[List[str]] = None):
super().__init__(out_dir=out_dir, keep_keys=keep_keys)
Expand All @@ -33,7 +33,7 @@ def partition(self,
model_dataset_combinations: List[Dict[str,
List[ConfigDict]]],
work_dir: str,
out_dir: str,
out_dir: str = '',
add_cfg: Dict = {}) -> List[Dict]:
"""Partition model-dataset pairs into tasks. Each task is defined as a
dict and will run independently as a unit. Its structure is as
Expand Down
2 changes: 1 addition & 1 deletion ais_bench/benchmark/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def register_module(

PARTITIONERS = Registry('partitioner', locations=get_locations('partitioners'))
RUNNERS = Registry('runner', locations=get_locations('runners'))
TASKS = Registry('task', locations=get_locations('tasks'))
TASKS = Registry('task', locations=get_locations('tasks') + get_locations('tasks.custom_tasks'))
MODELS = Registry('model', locations=get_locations('models'))
# TODO: LOAD_DATASET -> DATASETS
LOAD_DATASET = Registry('load_dataset', locations=get_locations('datasets'))
Expand Down
30 changes: 26 additions & 4 deletions ais_bench/benchmark/summarizers/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,29 +286,51 @@ def _format_table(self, parsed_results, dataset_metrics, dataset_eval_mode, requ
elif isinstance(item, (list, tuple)):
summarizer_dataset_abbrs.append((item[0], item[1]))

has_total_count = False
for dataset_abbr in dataset_metrics:
if 'total_count' in dataset_metrics[dataset_abbr]:
has_total_count = True
break

table = []
header = ['dataset', 'version', 'metric', 'mode'] + self.model_abbrs
if has_total_count:
header = ['dataset', 'version', 'metric', 'mode', 'total_count'] + self.model_abbrs
table.append(header)
for dataset_abbr, metric in summarizer_dataset_abbrs:
if dataset_abbr not in dataset_metrics:
if not skip_all_slash:
table.append([dataset_abbr, '-', '-', '-'] + ['-'] * len(self.model_abbrs))
if has_total_count:
table.append([dataset_abbr, '-', '-', '-', '-'] + ['-'] * len(self.model_abbrs))
else:
table.append([dataset_abbr, '-', '-', '-'] + ['-'] * len(self.model_abbrs))
continue
if metric is None:
metric = dataset_metrics[dataset_abbr][0]
elif metric in dataset_metrics[dataset_abbr]:
pass
else:
if not skip_all_slash:
table.append([dataset_abbr, '-', '-', '-'] + ['-'] * len(self.model_abbrs))
if has_total_count:
table.append([dataset_abbr, '-', '-', '-', '-'] + ['-'] * len(self.model_abbrs))
else:
table.append([dataset_abbr, '-', '-', '-'] + ['-'] * len(self.model_abbrs))
continue

total_count_value = '/'
if 'total_count' in dataset_metrics[dataset_abbr]:
first_model_abbr = self.model_abbrs[0]
if dataset_abbr in parsed_results[first_model_abbr] and 'total_count' in parsed_results[first_model_abbr][dataset_abbr]:
total_count_value = str(int(parsed_results[first_model_abbr][dataset_abbr]['total_count']))

row = [dataset_abbr, prompt_version.get(dataset_abbr, '-'), metric, dataset_eval_mode.get(dataset_abbr, '-')]
if has_total_count:
row.append(total_count_value)
for model_abbr in self.model_abbrs:
if dataset_abbr in parsed_results[model_abbr]:
row.append('{:.02f}'.format(parsed_results[model_abbr][dataset_abbr][metric]))
correct_count = parsed_results[model_abbr][dataset_abbr].pop('correct_count', None)
total_count = parsed_results[model_abbr][dataset_abbr].pop('total_count', None)
correct_count = parsed_results[model_abbr][dataset_abbr].get('correct_count', None)
total_count = parsed_results[model_abbr][dataset_abbr].get('total_count', None)
if correct_count is not None and total_count is not None:
row[-1] = str(row[-1]) + f' ({correct_count}/{total_count})'
else:
Expand Down
8 changes: 8 additions & 0 deletions ais_bench/benchmark/tasks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,14 @@ def get_output_paths(self, file_extension: str = "json") -> List[str]:
return output_paths


class EmptyTask(BaseTask):
    """Task that performs no work and has no launch command."""

    def run(self) -> None:
        """Do nothing and return ``None``."""
        return None

    def get_command(self, cfg_path, template) -> str:
        """Return an empty command string; both arguments are ignored."""
        command = ""
        return command


class TaskStateManager:
def __init__(self, tmp_path: str, task_name: str, is_debug: bool, refresh_interval: int = 0.5):
self.logger = AISLogger()
Expand Down
Empty file.
Loading
Loading