From 17f2a40b5e72fa8230e0c262832303d2d8faba88 Mon Sep 17 00:00:00 2001
From: yanxi-chen
Date: Thu, 13 Nov 2025 11:55:10 +0800
Subject: [PATCH 1/6] Add GSPO-style REC variant

---
 examples/rec_gsm8k/README.md                  | 28 +++++++++++++++++
 examples/rec_gsm8k/gsm8k.yaml                 |  5 +---
 trinity/algorithm/algorithm.py                |  4 +--
 .../policy_loss_fn/gspo_policy_loss.py        |  7 +++++
 .../policy_loss_fn/rec_policy_loss.py         | 30 ++++++++++++++++---
 5 files changed, 64 insertions(+), 10 deletions(-)

diff --git a/examples/rec_gsm8k/README.md b/examples/rec_gsm8k/README.md
index b3a5673463..8bf7eba948 100644
--- a/examples/rec_gsm8k/README.md
+++ b/examples/rec_gsm8k/README.md
@@ -107,6 +107,34 @@ algorithm:
     std_normalize: false
 ```
 
+**REC-GSPO-NoIS:**
+
+```
+algorithm:
+  algorithm_type: rec
+  policy_loss_fn_args:
+    epsilon_low: 3e-4
+    epsilon_high: 4e-4
+    clip_mode: "gspo-one-side"
+    weight: "none"
+  advantage_fn_args:
+    std_normalize: false
+```
+
+**REC-GSPO-IS:**
+
+```
+algorithm:
+  algorithm_type: rec
+  policy_loss_fn_args:
+    epsilon_low: 3e-4
+    epsilon_high: 4e-4
+    clip_mode: "gspo-one-side"
+    weight: "gspo_importance_sampling"
+  advantage_fn_args:
+    std_normalize: false
+```
+
 **REC-TwoSide-IS:**
 
 ```
diff --git a/examples/rec_gsm8k/gsm8k.yaml b/examples/rec_gsm8k/gsm8k.yaml
index c0dd6a3f07..0923d06b29 100644
--- a/examples/rec_gsm8k/gsm8k.yaml
+++ b/examples/rec_gsm8k/gsm8k.yaml
@@ -16,16 +16,13 @@ algorithm:
     epsilon_high: 0.2
     clip_mode: "one-side"
     weight: "none"
-    temp: 1.0
-    regularizer: "none"
-    regularizer_coef: 0.0
   advantage_fn_args:
     std_normalize: false
 cluster:
   node_num: 1
   gpu_per_node: 8
 buffer:
-  total_steps: 100
+  total_steps: 160
   batch_size: 96
   explorer_input:
     taskset:
diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py
index 4384da7b8e..a143dcac1a 100644
--- a/trinity/algorithm/algorithm.py
+++ b/trinity/algorithm/algorithm.py
@@ -358,8 +358,8 @@ def default_config(cls) -> Dict:
             "policy_loss_fn": "rec",
             "advantage_fn": "rec",
             "kl_penalty_fn": "none",
-            "kl_loss_fn": "none",
-            "entropy_loss_fn": "none",
+            "kl_loss_fn": "k2",
+            "entropy_loss_fn": "default",
         }
diff --git a/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py b/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py
index 6ebd5af243..92f83cedc1 100644
--- a/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py
+++ b/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py
@@ -9,6 +9,9 @@
 
 from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
 from trinity.algorithm.utils import masked_loss, masked_mean
+from trinity.utils.log import get_logger
+
+logger = get_logger(__name__)
 
 
 @POLICY_LOSS_FN.register_module("gspo")
@@ -31,6 +34,10 @@ def __init__(
         if _clip_range_high is None:
             raise ValueError("Either clip_range or clip_range_high must be specified.")
         self.clip_range_high = _clip_range_high
+
+        if (loss_agg_mode is not None) and (loss_agg_mode != "seq-mean-token-mean"):
+            logger.warning("GSPO requires loss_agg_mode == 'seq-mean-token-mean'.")
+            loss_agg_mode = "seq-mean-token-mean"
         self.loss_agg_mode = loss_agg_mode
 
     def __call__(  # type: ignore
diff --git a/trinity/algorithm/policy_loss_fn/rec_policy_loss.py b/trinity/algorithm/policy_loss_fn/rec_policy_loss.py
index ec0e623c9f..71135c085a 100644
--- a/trinity/algorithm/policy_loss_fn/rec_policy_loss.py
+++ b/trinity/algorithm/policy_loss_fn/rec_policy_loss.py
@@ -6,7 +6,7 @@
 import torch
 
 from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
-from trinity.algorithm.utils import masked_mean
+from trinity.algorithm.utils import masked_loss, masked_mean
 
 
 @POLICY_LOSS_FN.register_module("rec")
@@ -42,6 +42,7 @@ def __init__(
         assert self.clip_mode in [
             "none",
             "one-side",
+            "gspo-one-side",
             "two-side",
             "ring",
         ], f"Invalid clip_mode: {self.clip_mode}"
@@ -49,6 +50,7 @@
         assert self.weight in [
             "none",
             "importance_sampling",
+            "gspo_importance_sampling",
             "advantage",
         ], f"Invalid weight: {self.weight}"
 
@@ -72,8 +74,8 @@ def __call__(  # type: ignore
         **kwargs,
     ) -> Tuple[torch.Tensor, Dict]:
         """Calculate REC loss."""
-        # token-wise
-        ratio = torch.exp(logprob - old_logprob).detach()
+
+        ratio = torch.exp(logprob - old_logprob).detach()  # token-wise prob ratio
 
         # clipping
         if self.clip_mode == "two-side":
@@ -82,6 +84,16 @@ def __call__(  # type: ignore
             is_in_range = (ratio <= (1 + self.epsilon_high)) * (advantages >= 0) + (
                 advantages <= 0
             ) * (ratio >= (1 - self.epsilon_low))
+        elif self.clip_mode == "gspo-one-side":
+            mean_log_prob_diff = masked_mean(
+                logprob - old_logprob, action_mask, axis=-1
+            ).detach()  # [batch_size]
+            normalized_seq_ratio = torch.exp(mean_log_prob_diff).unsqueeze(-1)  # [batch_size, 1]
+            is_in_range = (normalized_seq_ratio <= (1 + self.epsilon_high)) * (advantages >= 0) + (
+                normalized_seq_ratio >= (1 - self.epsilon_low)
+            ) * (
+                advantages <= 0
+            )  # [batch_size, seq_len]
         elif self.clip_mode == "ring":
             is_in_range = (
                 (ratio >= (1 - self.epsilon_low)) * (ratio <= (1 + self.epsilon_high))
@@ -94,6 +106,8 @@ def __call__(  # type: ignore
 
         if self.weight == "importance_sampling":
             advantages = advantages * ratio  # importance sampling
+        elif self.weight == "gspo_importance_sampling":
+            advantages = advantages * normalized_seq_ratio
         elif self.weight == "advantage":
             weight = torch.exp(advantages / self.temp)
             advantages = advantages * weight  # advantage weighting (unnormalized version)
@@ -108,7 +122,15 @@ def __call__(  # type: ignore
             regularizer_losses = self.regularizer_coef * (logprob - old_logprob).square()
             pg_losses = pg_losses + regularizer_losses
 
-        pg_loss = masked_mean(pg_losses, action_mask)
+        if self.clip_mode == "gspo-one-side":
+            # [EXPERIMENTAL] specialized for gspo-style rec variant for now
+            pg_loss = masked_loss(
+                values=pg_losses,
+                mask=action_mask,
+                loss_agg_mode="seq-mean-token-mean",
+            )
+        else:
+            pg_loss = masked_mean(pg_losses, action_mask)
         pg_clipfrac = masked_mean(is_clipped_mask.float(), action_mask)
 
         metrics = {

From 5a658f8f4e68efa10bc8d6e46faa456d496b3f88 Mon Sep 17 00:00:00 2001
From: yanxi-chen
Date: Thu, 13 Nov 2025 12:23:43 +0800
Subject: [PATCH 2/6] Set kl_coef = 0 explicitly for REC configs

---
 examples/rec_gsm8k/README.md  | 26 ++++++++++++++++++++------
 examples/rec_gsm8k/gsm8k.yaml |  2 ++
 2 files changed, 22 insertions(+), 6 deletions(-)

diff --git a/examples/rec_gsm8k/README.md b/examples/rec_gsm8k/README.md
index 8bf7eba948..af1e71b086 100644
--- a/examples/rec_gsm8k/README.md
+++ b/examples/rec_gsm8k/README.md
@@ -83,11 +83,10 @@ algorithm:
     epsilon_high: 0.2
     clip_mode: "one-side"
     weight: "none"
-    temp: 1.0
-    regularizer: "none"
-    regularizer_coef: 0.0
   advantage_fn_args:
     std_normalize: false
+  kl_loss_fn_args:
+    kl_coef: 0.0
 ```
 
 **REC-OneSide-IS:**
@@ -100,11 +99,10 @@ algorithm:
     epsilon_high: 0.2
     clip_mode: "one-side"
     weight: "importance_sampling"
-    temp: 1.0
-    regularizer: "none"
-    regularizer_coef: 0.0
   advantage_fn_args:
     std_normalize: false
+  kl_loss_fn_args:
+    kl_coef: 0.0
 ```
 
 **REC-GSPO-NoIS:**
@@ -119,6 +117,8 @@ algorithm:
     weight: "none"
   advantage_fn_args:
     std_normalize: false
+  kl_loss_fn_args:
+    kl_coef: 0.0
 ```
 
 **REC-GSPO-IS:**
@@ -133,6 +133,8 @@ algorithm:
     weight: "gspo_importance_sampling"
   advantage_fn_args:
     std_normalize: false
+  kl_loss_fn_args:
+    kl_coef: 0.0
 ```
 
 **REC-TwoSide-IS:**
@@ -150,6 +152,8 @@ algorithm:
     regularizer_coef: 0.0
   advantage_fn_args:
     std_normalize: false
+  kl_loss_fn_args:
+    kl_coef: 0.0
 ```
 
 **REC-Ring-NoIS:**
@@ -169,6 +173,8 @@ algorithm:
     regularizer_coef: 0.0
   advantage_fn_args:
     std_normalize: false
+  kl_loss_fn_args:
+    kl_coef: 0.0
 ```
 
 ### REP family
@@ -187,6 +193,8 @@ algorithm:
     regularizer_coef: 0.1
   advantage_fn_args:
     std_normalize: false
+  kl_loss_fn_args:
+    kl_coef: 0.0
 ```
 
 
@@ -202,6 +210,8 @@ algorithm:
     regularizer_coef: 0.1
   advantage_fn_args:
     std_normalize: false
+  kl_loss_fn_args:
+    kl_coef: 0.0
 ```
 
 ### RED family
@@ -219,6 +229,8 @@ algorithm:
   advantage_fn_args:
     std_normalize: false
     drop: "balance"
+  kl_loss_fn_args:
+    kl_coef: 0.0
 ```
 
 
@@ -234,6 +246,8 @@ algorithm:
     temp: 1.0
   advantage_fn_args:
     std_normalize: false
+  kl_loss_fn_args:
+    kl_coef: 0.0
 ```
 
 ## Citation
diff --git a/examples/rec_gsm8k/gsm8k.yaml b/examples/rec_gsm8k/gsm8k.yaml
index 0923d06b29..209b0dfe05 100644
--- a/examples/rec_gsm8k/gsm8k.yaml
+++ b/examples/rec_gsm8k/gsm8k.yaml
@@ -18,6 +18,8 @@ algorithm:
     weight: "none"
   advantage_fn_args:
     std_normalize: false
+  kl_loss_fn_args:
+    kl_coef: 0.0
 cluster:
   node_num: 1
   gpu_per_node: 8

From 2457e75f9b6976629f5421f54108923a134c5950 Mon Sep 17 00:00:00 2001
From: yanxi-chen
Date: Thu, 13 Nov 2025 13:32:33 +0800
Subject: [PATCH 3/6] Fix comment

---
 trinity/algorithm/policy_loss_fn/gspo_policy_loss.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py b/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py
index 92f83cedc1..116da8e41a 100644
--- a/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py
+++ b/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py
@@ -35,7 +35,7 @@ def __init__(
             raise ValueError("Either clip_range or clip_range_high must be specified.")
         self.clip_range_high = _clip_range_high
 
-        if (loss_agg_mode is not None) and (loss_agg_mode != "seq-mean-token-mean"):
+        if loss_agg_mode != "seq-mean-token-mean":
             logger.warning("GSPO requires loss_agg_mode == 'seq-mean-token-mean'.")
             loss_agg_mode = "seq-mean-token-mean"
         self.loss_agg_mode = loss_agg_mode

From 542e4ac31096371501fdfc5360249e98138951a9 Mon Sep 17 00:00:00 2001
From: yanxi-chen
Date: Fri, 26 Dec 2025 17:52:46 +0800
Subject: [PATCH 4/6] Fix pre-commit and rec

---
 trinity/algorithm/policy_loss_fn/gspo_policy_loss.py | 6 ++++--
 trinity/algorithm/policy_loss_fn/rec_policy_loss.py  | 4 ++--
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py b/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py
index 3e4e2b23c2..f65197387f 100644
--- a/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py
+++ b/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py
@@ -9,11 +9,11 @@
 
 from trinity.algorithm.policy_loss_fn.policy_loss_fn import PolicyLossFn
 from trinity.algorithm.utils import aggregate_loss, masked_mean
-
 from trinity.utils.log import get_logger
 
 logger = get_logger(__name__)
 
+
 class GSPOLossFn(PolicyLossFn):
     def __init__(
         self,
@@ -35,7 +35,9 @@ def __init__(
         self.clip_range_high = _clip_range_high
 
         if loss_agg_mode != "seq-mean-token-mean":
-            logger.warning(f"The original GSPO paper requires loss_agg_mode to be 'seq-mean-token-mean', but the current setting is '{loss_agg_mode}'.")
+            logger.warning(
+                f"The original GSPO paper requires loss_agg_mode to be 'seq-mean-token-mean', but the current setting is '{loss_agg_mode}'."
+            )
             # loss_agg_mode = "seq-mean-token-mean"
         self.loss_agg_mode = loss_agg_mode
 
diff --git a/trinity/algorithm/policy_loss_fn/rec_policy_loss.py b/trinity/algorithm/policy_loss_fn/rec_policy_loss.py
index 680340bf5d..6f00ac86c1 100644
--- a/trinity/algorithm/policy_loss_fn/rec_policy_loss.py
+++ b/trinity/algorithm/policy_loss_fn/rec_policy_loss.py
@@ -6,7 +6,7 @@
 import torch
 
 from trinity.algorithm.policy_loss_fn.policy_loss_fn import PolicyLossFn
-from trinity.algorithm.utils import masked_mean
+from trinity.algorithm.utils import aggregate_loss, masked_mean
 
 
 class RECPolicyLossFn(PolicyLossFn):
@@ -123,7 +123,7 @@ def __call__(  # type: ignore
 
         if self.clip_mode == "gspo-one-side":
             # [EXPERIMENTAL] specialized for gspo-style rec variant for now
-            pg_loss = masked_loss(
+            pg_loss = aggregate_loss(
                 values=pg_losses,
                 mask=action_mask,
                 loss_agg_mode="seq-mean-token-mean",

From be4e78e30ec143c7d76cfe229c98886c2d949a88 Mon Sep 17 00:00:00 2001
From: yanxi-chen
Date: Mon, 29 Dec 2025 18:25:05 +0800
Subject: [PATCH 5/6] Fix logger issue

---
 trinity/algorithm/policy_loss_fn/gspo_policy_loss.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py b/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py
index f65197387f..8bc620d4d6 100644
--- a/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py
+++ b/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py
@@ -11,8 +11,6 @@
 from trinity.algorithm.utils import aggregate_loss, masked_mean
 from trinity.utils.log import get_logger
 
-logger = get_logger(__name__)
-
 
 class GSPOLossFn(PolicyLossFn):
     def __init__(
@@ -35,6 +33,7 @@ def __init__(
         self.clip_range_high = _clip_range_high
 
         if loss_agg_mode != "seq-mean-token-mean":
+            logger = get_logger(__name__)
             logger.warning(
                 f"The original GSPO paper requires loss_agg_mode to be 'seq-mean-token-mean', but the current setting is '{loss_agg_mode}'."
             )

From e2df33540046d6ebb8d7a1645009d14afa6773d6 Mon Sep 17 00:00:00 2001
From: yanxi-chen
Date: Mon, 29 Dec 2025 18:29:10 +0800
Subject: [PATCH 6/6] Move logger import

---
 trinity/algorithm/policy_loss_fn/gspo_policy_loss.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py b/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py
index 8bc620d4d6..7527603e03 100644
--- a/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py
+++ b/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py
@@ -9,7 +9,6 @@
 
 from trinity.algorithm.policy_loss_fn.policy_loss_fn import PolicyLossFn
 from trinity.algorithm.utils import aggregate_loss, masked_mean
-from trinity.utils.log import get_logger
 
 
 class GSPOLossFn(PolicyLossFn):
@@ -33,6 +32,8 @@ def __init__(
         self.clip_range_high = _clip_range_high
 
         if loss_agg_mode != "seq-mean-token-mean":
+            from trinity.utils.log import get_logger
+
            logger = get_logger(__name__)
            logger.warning(
                f"The original GSPO paper requires loss_agg_mode to be 'seq-mean-token-mean', but the current setting is '{loss_agg_mode}'."
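
Note (not part of the patches above): the core of the new `gspo-one-side` mode in PATCH 1 is a sequence-level ratio, the exponential of the masked mean of per-token log-ratios, which is the geometric mean of the token-level ratios, that is, the length-normalized sequence importance ratio from GSPO. The sketch below restates that computation standalone, under stated assumptions: `masked_mean` here is a local stand-in assumed to match `trinity.algorithm.utils.masked_mean` (mask-weighted mean along `axis`, with the mask given as a 0/1 float tensor), and the epsilon defaults mirror the README configs.

```
import torch


def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: int = -1) -> torch.Tensor:
    # Mask-weighted mean along `axis`; local stand-in for
    # trinity.algorithm.utils.masked_mean (mask is a 0/1 float tensor).
    return (values * mask).sum(axis) / mask.sum(axis).clamp(min=1e-8)


def gspo_one_side(logprob, old_logprob, action_mask, advantages,
                  epsilon_low=3e-4, epsilon_high=4e-4):
    # Sequence-level ratio: exp of the per-sequence mean log-ratio, i.e. the
    # geometric mean of token-level ratios (length-normalized, GSPO-style).
    mean_log_prob_diff = masked_mean(logprob - old_logprob, action_mask).detach()  # [B]
    normalized_seq_ratio = torch.exp(mean_log_prob_diff).unsqueeze(-1)  # [B, 1]

    # One-sided clip: positive-advantage tokens stay in range only while the
    # sequence ratio is at most 1 + epsilon_high; negative-advantage tokens
    # only while it is at least 1 - epsilon_low.
    is_in_range = (normalized_seq_ratio <= 1 + epsilon_high) * (advantages >= 0) + (
        normalized_seq_ratio >= 1 - epsilon_low
    ) * (advantages <= 0)  # [B, T], broadcast over the token dimension
    return is_in_range, normalized_seq_ratio


# Tiny smoke test with random tensors.
B, T = 2, 5
logprob, old_logprob = torch.randn(B, T), torch.randn(B, T)
mask, adv = torch.ones(B, T), torch.randn(B, T)
in_range, seq_ratio = gspo_one_side(logprob, old_logprob, mask, adv)
```

Because the ratio is computed per sequence and broadcast, one clipping decision applies to all tokens of a response, which is also why PATCH 1 pairs this mode with `seq-mean-token-mean` loss aggregation.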