diff --git a/generation.py b/generation.py index d71db6e..533c79a 100644 --- a/generation.py +++ b/generation.py @@ -57,6 +57,14 @@ def _apply_decay_mask(self, logits: torch.Tensor, decay_mask: torch.Tensor) -> t probabilities = probabilities / probabilities.sum(dim=-1) return probabilities + def _apply_decay_mask_logits(self, logits: torch.Tensor, decay_mask: torch.Tensor) -> torch.Tensor: + """Applies decay to a tensor of logits in the log space""" + decay_mask = torch.exp(- decay_mask * self.decay_constant) + decay_mask = torch.max(decay_mask, torch.tensor([self.epsilon], device=decay_mask.device)) + log_decay_mask = torch.log(decay_mask) + logits += log_decay_mask + return logits + def _generate_decay_mask(self, logits_regular: torch.FloatTensor, logits_biased_list: List[torch.FloatTensor]) -> torch.Tensor: """Computes the alpha values (see paper) for each token and stores them in a mask tensor""" p_regular = logits_regular.softmax(dim=-1)