diff --git a/examples/rec_gsm8k/README.md b/examples/rec_gsm8k/README.md
index b3a5673463..af1e71b086 100644
--- a/examples/rec_gsm8k/README.md
+++ b/examples/rec_gsm8k/README.md
@@ -83,11 +83,10 @@ algorithm:
     epsilon_high: 0.2
     clip_mode: "one-side"
     weight: "none"
-    temp: 1.0
-    regularizer: "none"
-    regularizer_coef: 0.0
   advantage_fn_args:
     std_normalize: false
+  kl_loss_fn_args:
+    kl_coef: 0.0
 ```

 **REC-OneSide-IS:**
@@ -100,11 +99,42 @@ algorithm:
     epsilon_high: 0.2
     clip_mode: "one-side"
     weight: "importance_sampling"
-    temp: 1.0
-    regularizer: "none"
-    regularizer_coef: 0.0
   advantage_fn_args:
     std_normalize: false
+  kl_loss_fn_args:
+    kl_coef: 0.0
+```
+
+**REC-GSPO-NoIS:**
+
+```
+algorithm:
+  algorithm_type: rec
+  policy_loss_fn_args:
+    epsilon_low: 3e-4
+    epsilon_high: 4e-4
+    clip_mode: "gspo-one-side"
+    weight: "none"
+  advantage_fn_args:
+    std_normalize: false
+  kl_loss_fn_args:
+    kl_coef: 0.0
+```
+
+**REC-GSPO-IS:**
+
+```
+algorithm:
+  algorithm_type: rec
+  policy_loss_fn_args:
+    epsilon_low: 3e-4
+    epsilon_high: 4e-4
+    clip_mode: "gspo-one-side"
+    weight: "gspo_importance_sampling"
+  advantage_fn_args:
+    std_normalize: false
+  kl_loss_fn_args:
+    kl_coef: 0.0
 ```

 **REC-TwoSide-IS:**
@@ -122,6 +152,8 @@ algorithm:
     regularizer_coef: 0.0
   advantage_fn_args:
     std_normalize: false
+  kl_loss_fn_args:
+    kl_coef: 0.0
 ```

 **REC-Ring-NoIS:**
@@ -141,6 +173,8 @@ algorithm:
     regularizer_coef: 0.0
   advantage_fn_args:
     std_normalize: false
+  kl_loss_fn_args:
+    kl_coef: 0.0
 ```

 ### REP family
@@ -159,6 +193,8 @@ algorithm:
     regularizer_coef: 0.1
   advantage_fn_args:
     std_normalize: false
+  kl_loss_fn_args:
+    kl_coef: 0.0
 ```


@@ -174,6 +210,8 @@ algorithm:
     regularizer_coef: 0.1
   advantage_fn_args:
     std_normalize: false
+  kl_loss_fn_args:
+    kl_coef: 0.0
 ```

 ### RED family
@@ -191,6 +229,8 @@ algorithm:
   advantage_fn_args:
     std_normalize: false
     drop: "balance"
+  kl_loss_fn_args:
+    kl_coef: 0.0
 ```


@@ -206,6 +246,8 @@ algorithm:
     temp: 1.0
   advantage_fn_args:
     std_normalize: false
+  kl_loss_fn_args:
+    kl_coef: 0.0
 ```

 ## Citation
diff --git a/examples/rec_gsm8k/gsm8k.yaml b/examples/rec_gsm8k/gsm8k.yaml
index c0dd6a3f07..209b0dfe05 100644
--- a/examples/rec_gsm8k/gsm8k.yaml
+++ b/examples/rec_gsm8k/gsm8k.yaml
@@ -16,16 +16,15 @@ algorithm:
     epsilon_high: 0.2
     clip_mode: "one-side"
     weight: "none"
-    temp: 1.0
-    regularizer: "none"
-    regularizer_coef: 0.0
   advantage_fn_args:
     std_normalize: false
+  kl_loss_fn_args:
+    kl_coef: 0.0
 cluster:
   node_num: 1
   gpu_per_node: 8
 buffer:
-  total_steps: 100
+  total_steps: 160
   batch_size: 96
   explorer_input:
     taskset:
diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py
index cd8fce5f35..ebe5978a07 100644
--- a/trinity/algorithm/algorithm.py
+++ b/trinity/algorithm/algorithm.py
@@ -434,8 +434,8 @@ def default_config(cls) -> Dict:
             "policy_loss_fn": "rec",
             "advantage_fn": "rec",
             "kl_penalty_fn": "none",
-            "kl_loss_fn": "none",
-            "entropy_loss_fn": "none",
+            "kl_loss_fn": "k2",
+            "entropy_loss_fn": "default",
         }


diff --git a/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py b/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py
index 5b30a6a6af..7527603e03 100644
--- a/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py
+++ b/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py
@@ -30,6 +30,15 @@ def __init__(
         if _clip_range_high is None:
             raise ValueError("Either clip_range or clip_range_high must be specified.")
         self.clip_range_high = _clip_range_high
+
+        if loss_agg_mode != "seq-mean-token-mean":
+            from trinity.utils.log import get_logger
+
+            logger = get_logger(__name__)
+            logger.warning(
+                f"The original GSPO paper requires loss_agg_mode to be 'seq-mean-token-mean', but the current setting is '{loss_agg_mode}'."
+            )
+            # loss_agg_mode = "seq-mean-token-mean"
         self.loss_agg_mode = loss_agg_mode

     def __call__(  # type: ignore
diff --git a/trinity/algorithm/policy_loss_fn/rec_policy_loss.py b/trinity/algorithm/policy_loss_fn/rec_policy_loss.py
index 5d7c547e62..6f00ac86c1 100644
--- a/trinity/algorithm/policy_loss_fn/rec_policy_loss.py
+++ b/trinity/algorithm/policy_loss_fn/rec_policy_loss.py
@@ -6,7 +6,7 @@
 import torch

 from trinity.algorithm.policy_loss_fn.policy_loss_fn import PolicyLossFn
-from trinity.algorithm.utils import masked_mean
+from trinity.algorithm.utils import aggregate_loss, masked_mean


 class RECPolicyLossFn(PolicyLossFn):
@@ -41,6 +41,7 @@ def __init__(
         assert self.clip_mode in [
             "none",
             "one-side",
+            "gspo-one-side",
             "two-side",
             "ring",
         ], f"Invalid clip_mode: {self.clip_mode}"
@@ -48,6 +49,7 @@ def __init__(
         assert self.weight in [
             "none",
             "importance_sampling",
+            "gspo_importance_sampling",
             "advantage",
         ], f"Invalid weight: {self.weight}"

@@ -71,8 +73,8 @@ def __call__(  # type: ignore
         **kwargs,
     ) -> Tuple[torch.Tensor, Dict]:
         """Calculate REC loss."""
-        # token-wise
-        ratio = torch.exp(logprob - old_logprob).detach()
+
+        ratio = torch.exp(logprob - old_logprob).detach()  # token-wise prob ratio

         # clipping
         if self.clip_mode == "two-side":
@@ -81,6 +83,16 @@ def __call__(  # type: ignore
             is_in_range = (ratio <= (1 + self.epsilon_high)) * (advantages >= 0) + (
                 advantages <= 0
             ) * (ratio >= (1 - self.epsilon_low))
+        elif self.clip_mode == "gspo-one-side":
+            mean_log_prob_diff = masked_mean(
+                logprob - old_logprob, action_mask, axis=-1
+            ).detach()  # [batch_size]
+            normalized_seq_ratio = torch.exp(mean_log_prob_diff).unsqueeze(-1)  # [batch_size, 1]
+            is_in_range = (normalized_seq_ratio <= (1 + self.epsilon_high)) * (advantages >= 0) + (
+                normalized_seq_ratio >= (1 - self.epsilon_low)
+            ) * (
+                advantages <= 0
+            )  # [batch_size, seq_len]
         elif self.clip_mode == "ring":
             is_in_range = (
                 (ratio >= (1 - self.epsilon_low)) * (ratio <= (1 + self.epsilon_high))
@@ -93,6 +105,8 @@ def __call__(  # type: ignore

         if self.weight == "importance_sampling":
             advantages = advantages * ratio  # importance sampling
+        elif self.weight == "gspo_importance_sampling":
+            advantages = advantages * normalized_seq_ratio
         elif self.weight == "advantage":
             weight = torch.exp(advantages / self.temp)
             advantages = advantages * weight  # advantage weighting (unnormalized version)
@@ -107,7 +121,15 @@ def __call__(  # type: ignore
             regularizer_losses = self.regularizer_coef * (logprob - old_logprob).square()
             pg_losses = pg_losses + regularizer_losses

-        pg_loss = masked_mean(pg_losses, action_mask)
+        if self.clip_mode == "gspo-one-side":
+            # [EXPERIMENTAL] specialized for gspo-style rec variant for now
+            pg_loss = aggregate_loss(
+                values=pg_losses,
+                mask=action_mask,
+                loss_agg_mode="seq-mean-token-mean",
+            )
+        else:
+            pg_loss = masked_mean(pg_losses, action_mask)
         pg_clipfrac = masked_mean(is_clipped_mask.float(), action_mask)

         metrics = {
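For reference, here is a minimal standalone sketch (not part of the patch) of the sequence-level ratio behind the new `gspo-one-side` clip mode and `gspo_importance_sampling` weight: the ratio is the exponential of the masked mean of per-token log-probability differences, i.e. a length-normalized (geometric-mean) sequence importance ratio, broadcast back over tokens for the one-sided clip test. The toy `masked_mean` helper and the random tensors below are illustrative stand-ins, not the project's actual utilities or rollout data.

```
import torch


def masked_mean(values, mask, axis=-1):
    # Toy stand-in: mean over positions where mask == 1 along `axis`.
    return (values * mask).sum(axis) / mask.sum(axis).clamp(min=1)


batch, seq = 2, 5
logprob = torch.randn(batch, seq)       # new policy log-probs per token
old_logprob = torch.randn(batch, seq)   # rollout policy log-probs per token
action_mask = torch.ones(batch, seq)
advantages = torch.randn(batch, seq)
epsilon_low, epsilon_high = 3e-4, 4e-4  # values used in the REC-GSPO configs above

# Sequence-level ratio: exp of the mean token log-ratio (the geometric mean of
# the token-wise ratios), broadcast back over the sequence dimension.
mean_log_prob_diff = masked_mean(logprob - old_logprob, action_mask, axis=-1).detach()
normalized_seq_ratio = torch.exp(mean_log_prob_diff).unsqueeze(-1)  # [batch, 1]

# One-sided clip: a token is kept if the sequence ratio has not already moved
# past the trust region in the direction its advantage is pushing.
is_in_range = (normalized_seq_ratio <= 1 + epsilon_high) * (advantages >= 0) + (
    normalized_seq_ratio >= 1 - epsilon_low
) * (advantages <= 0)
print(is_in_range.float().mean())  # fraction of tokens left unclipped
```

Note that in the patch `normalized_seq_ratio` is only computed inside the `gspo-one-side` branch, so the `gspo_importance_sampling` weight is paired with that clip mode in the README configs above.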