54 changes: 48 additions & 6 deletions examples/rec_gsm8k/README.md
@@ -83,11 +83,10 @@ algorithm:
epsilon_high: 0.2
clip_mode: "one-side"
weight: "none"
temp: 1.0
regularizer: "none"
regularizer_coef: 0.0
advantage_fn_args:
std_normalize: false
kl_loss_fn_args:
kl_coef: 0.0
```

**REC-OneSide-IS:**
@@ -100,11 +99,42 @@ algorithm:
epsilon_high: 0.2
clip_mode: "one-side"
weight: "importance_sampling"
temp: 1.0
regularizer: "none"
regularizer_coef: 0.0
advantage_fn_args:
std_normalize: false
kl_loss_fn_args:
kl_coef: 0.0
```
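
With `weight: "importance_sampling"` the advantages are rescaled by the detached token-level probability ratio before they enter the loss, whereas `weight: "none"` uses them unchanged. A minimal sketch of that reweighting with toy tensors (shapes and values are illustrative, not the repository's API):

```
import torch

logprob = torch.tensor([[-0.9, -1.2]])      # current policy log-probs, [batch, seq]
old_logprob = torch.tensor([[-1.0, -1.0]])  # behaviour policy log-probs
advantages = torch.tensor([[0.5, 0.5]])

ratio = torch.exp(logprob - old_logprob).detach()  # token-level probability ratio
weighted_advantages = advantages * ratio            # importance-sampling weighting
```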

**REC-GSPO-NoIS:**

```
algorithm:
algorithm_type: rec
policy_loss_fn_args:
epsilon_low: 3e-4
epsilon_high: 4e-4
clip_mode: "gspo-one-side"
weight: "none"
advantage_fn_args:
std_normalize: false
kl_loss_fn_args:
kl_coef: 0.0
```

**REC-GSPO-IS:**

```
algorithm:
algorithm_type: rec
policy_loss_fn_args:
epsilon_low: 3e-4
epsilon_high: 4e-4
clip_mode: "gspo-one-side"
weight: "gspo_importance_sampling"
advantage_fn_args:
std_normalize: false
kl_loss_fn_args:
kl_coef: 0.0
```
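
The `gspo-one-side` variants clip on a sequence-level ratio, the exponentiated mean of the per-token log-ratios, rather than on each token's own ratio, which is why their `epsilon_low`/`epsilon_high` are orders of magnitude smaller than in the token-level variants. A minimal sketch of that ratio, assuming `[batch_size, seq_len]` inputs (function and argument names here are illustrative):

```
import torch

def sequence_level_ratio(logprob, old_logprob, action_mask):
    """GSPO-style ratio: exponentiated mean per-token log-ratio, broadcast over tokens."""
    log_diff = (logprob - old_logprob) * action_mask
    mean_log_diff = log_diff.sum(dim=-1) / action_mask.sum(dim=-1).clamp(min=1)  # [batch]
    return torch.exp(mean_log_diff).unsqueeze(-1).detach()  # [batch, 1]
```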

**REC-TwoSide-IS:**
@@ -122,6 +152,8 @@ algorithm:
regularizer_coef: 0.0
advantage_fn_args:
std_normalize: false
kl_loss_fn_args:
kl_coef: 0.0
```

**REC-Ring-NoIS:**
@@ -141,6 +173,8 @@ algorithm:
regularizer_coef: 0.0
advantage_fn_args:
std_normalize: false
kl_loss_fn_args:
kl_coef: 0.0
```

### REP family
@@ -159,6 +193,8 @@ algorithm:
regularizer_coef: 0.1
advantage_fn_args:
std_normalize: false
kl_loss_fn_args:
kl_coef: 0.0
```


@@ -174,6 +210,8 @@ algorithm:
regularizer_coef: 0.1
advantage_fn_args:
std_normalize: false
kl_loss_fn_args:
kl_coef: 0.0
```

### RED family
@@ -191,6 +229,8 @@ algorithm:
advantage_fn_args:
std_normalize: false
drop: "balance"
kl_loss_fn_args:
kl_coef: 0.0
```


@@ -206,6 +246,8 @@ algorithm:
temp: 1.0
advantage_fn_args:
std_normalize: false
kl_loss_fn_args:
kl_coef: 0.0
```

## Citation
7 changes: 3 additions & 4 deletions examples/rec_gsm8k/gsm8k.yaml
@@ -16,16 +16,15 @@ algorithm:
epsilon_high: 0.2
clip_mode: "one-side"
weight: "none"
temp: 1.0
regularizer: "none"
regularizer_coef: 0.0
advantage_fn_args:
std_normalize: false
kl_loss_fn_args:
kl_coef: 0.0
cluster:
node_num: 1
gpu_per_node: 8
buffer:
total_steps: 100
total_steps: 160
batch_size: 96
explorer_input:
taskset:
4 changes: 2 additions & 2 deletions trinity/algorithm/algorithm.py
@@ -434,8 +434,8 @@ def default_config(cls) -> Dict:
"policy_loss_fn": "rec",
"advantage_fn": "rec",
"kl_penalty_fn": "none",
"kl_loss_fn": "none",
"entropy_loss_fn": "none",
"kl_loss_fn": "k2",
"entropy_loss_fn": "default",
}


9 changes: 9 additions & 0 deletions trinity/algorithm/policy_loss_fn/gspo_policy_loss.py
@@ -30,6 +30,15 @@ def __init__(
if _clip_range_high is None:
raise ValueError("Either clip_range or clip_range_high must be specified.")
self.clip_range_high = _clip_range_high

if loss_agg_mode != "seq-mean-token-mean":
from trinity.utils.log import get_logger

logger = get_logger(__name__)
logger.warning(
f"The original GSPO paper requires loss_agg_mode to be 'seq-mean-token-mean', but the current setting is '{loss_agg_mode}'."
)
# loss_agg_mode = "seq-mean-token-mean"
self.loss_agg_mode = loss_agg_mode

def __call__( # type: ignore
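
The warning above refers to `seq-mean-token-mean` aggregation, which averages the per-token losses over the valid tokens of each sequence first and then averages those per-sequence means over the batch. A simplified stand-in for that aggregation mode (not the repository's actual helper signature):

```
import torch

def seq_mean_token_mean(values, mask):
    """Mean over valid tokens within each sequence, then mean over sequences."""
    per_seq = (values * mask).sum(dim=-1) / mask.sum(dim=-1).clamp(min=1)  # [batch]
    return per_seq.mean()
```
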
30 changes: 26 additions & 4 deletions trinity/algorithm/policy_loss_fn/rec_policy_loss.py
@@ -6,7 +6,7 @@
import torch

from trinity.algorithm.policy_loss_fn.policy_loss_fn import PolicyLossFn
from trinity.algorithm.utils import masked_mean
from trinity.algorithm.utils import aggregate_loss, masked_mean


class RECPolicyLossFn(PolicyLossFn):
@@ -41,13 +41,15 @@ def __init__(
assert self.clip_mode in [
"none",
"one-side",
"gspo-one-side",
"two-side",
"ring",
], f"Invalid clip_mode: {self.clip_mode}"
self.weight = weight
assert self.weight in [
"none",
"importance_sampling",
"gspo_importance_sampling",
"advantage",
], f"Invalid weight: {self.weight}"

@@ -71,8 +73,8 @@ def __call__( # type: ignore
**kwargs,
) -> Tuple[torch.Tensor, Dict]:
"""Calculate REC loss."""
# token-wise
ratio = torch.exp(logprob - old_logprob).detach()

ratio = torch.exp(logprob - old_logprob).detach() # token-wise prob ratio

# clipping
if self.clip_mode == "two-side":
@@ -81,6 +83,16 @@ def __call__( # type: ignore
is_in_range = (ratio <= (1 + self.epsilon_high)) * (advantages >= 0) + (
advantages <= 0
) * (ratio >= (1 - self.epsilon_low))
elif self.clip_mode == "gspo-one-side":
mean_log_prob_diff = masked_mean(
logprob - old_logprob, action_mask, axis=-1
).detach() # [batch_size]
normalized_seq_ratio = torch.exp(mean_log_prob_diff).unsqueeze(-1) # [batch_size, 1]
is_in_range = (normalized_seq_ratio <= (1 + self.epsilon_high)) * (advantages >= 0) + (
normalized_seq_ratio >= (1 - self.epsilon_low)
) * (
advantages <= 0
) # [batch_size, seq_len]
elif self.clip_mode == "ring":
is_in_range = (
(ratio >= (1 - self.epsilon_low)) * (ratio <= (1 + self.epsilon_high))
@@ -93,6 +105,8 @@ def __call__( # type: ignore

if self.weight == "importance_sampling":
advantages = advantages * ratio # importance sampling
elif self.weight == "gspo_importance_sampling":
advantages = advantages * normalized_seq_ratio
elif self.weight == "advantage":
weight = torch.exp(advantages / self.temp)
advantages = advantages * weight # advantage weighting (unnormalized version)
@@ -107,7 +121,15 @@ def __call__( # type: ignore
regularizer_losses = self.regularizer_coef * (logprob - old_logprob).square()
pg_losses = pg_losses + regularizer_losses

pg_loss = masked_mean(pg_losses, action_mask)
if self.clip_mode == "gspo-one-side":
# [EXPERIMENTAL] specialized for gspo-style rec variant for now
pg_loss = aggregate_loss(
values=pg_losses,
mask=action_mask,
loss_agg_mode="seq-mean-token-mean",
)
else:
pg_loss = masked_mean(pg_losses, action_mask)

pg_clipfrac = masked_mean(is_clipped_mask.float(), action_mask)
metrics = {
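
Putting these pieces together, the `gspo-one-side` branch keeps a token in the gradient only while the sequence-level ratio has not crossed the clip boundary in the direction the advantage pushes it. A toy check of that mask using explicit boolean operators (the values are made up for illustration):

```
import torch

seq_ratio = torch.tensor([[1.0005], [0.9990], [1.0002]])  # sequence-level ratios, [batch, 1]
advantages = torch.tensor([[1.0], [-1.0], [1.0]])          # per-sequence advantages
epsilon_low, epsilon_high = 3e-4, 4e-4

keep_pos = (seq_ratio <= 1 + epsilon_high) & (advantages >= 0)  # positive-advantage side
keep_neg = (seq_ratio >= 1 - epsilon_low) & (advantages <= 0)   # negative-advantage side
is_in_range = keep_pos | keep_neg
print(is_in_range.squeeze(-1))  # tensor([False, False,  True])
```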