Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 45 additions & 4 deletions Shell/eval_chair.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,46 @@
#!/bin/bash
# CHAIR Evaluation Script for FarSight
# This script runs caption generation using FarSight decoding strategy

# Configuration - Set these environment variables before running:
# MODEL_PATH: Path to the LLaVA model (default: liuhaotian/llava-v1.5-7b)
# IMAGE_FOLDER: Path to COCO val2014 images
# QUESTION_FILE: Path to CHAIR questions file (JSONL format)
# ANSWERS_FILE: Path to save generated captions (default: ./Answers/chair_captions.jsonl)

MODEL_PATH="${MODEL_PATH:-liuhaotian/llava-v1.5-7b}"
IMAGE_FOLDER="${IMAGE_FOLDER:-}"
QUESTION_FILE="${QUESTION_FILE:-}"
ANSWERS_FILE="${ANSWERS_FILE:-./Answers/chair_captions.jsonl}"

# Validate required paths
if [ -z "$IMAGE_FOLDER" ]; then
echo "Error: IMAGE_FOLDER environment variable is not set"
echo "Usage: IMAGE_FOLDER=/path/to/coco/val2014 QUESTION_FILE=/path/to/questions.jsonl bash eval_chair.sh"
exit 1
fi

if [ -z "$QUESTION_FILE" ]; then
echo "Error: QUESTION_FILE environment variable is not set"
echo "Usage: IMAGE_FOLDER=/path/to/coco/val2014 QUESTION_FILE=/path/to/questions.jsonl bash eval_chair.sh"
exit 1
fi

# Create output directory if needed
mkdir -p "$(dirname "$ANSWERS_FILE")"

# Step 1: Generate captions using FarSight
python eval_chair.py \
--coco_path /root/autodl-tmp/annotations \
--cache ./chair.pkl \
--cap_file ./1.jsonl \
--save_path ./Answers/eval-chair.json
--model-path "$MODEL_PATH" \
--image-folder "$IMAGE_FOLDER" \
--question-file "$QUESTION_FILE" \
--answers-file "$ANSWERS_FILE" \
--farsight

# Note: After generating captions, you can evaluate CHAIR metrics using:
# - External CHAIR evaluation tool (e.g., https://github.com/Maxlinn/CHAIR-metric-standalone)
# - Example command:
# python chair.py --cap_file ./Answers/chair_captions.jsonl \
# --coco_path /path/to/coco/annotations \
# --cache ./chair.pkl \
# --save_path ./Answers/eval-chair.json
91 changes: 85 additions & 6 deletions Shell/farsight_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,43 @@
import torch.nn.functional as F
import types


def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
"""Apply rotary position embedding to query and key tensors."""
# Handle different shapes of cos and sin from rotary embeddings
# Expected final shapes: cos, sin -> [bs, 1, seq_len, dim]
while cos.dim() > 2 and cos.shape[0] == 1:
cos = cos.squeeze(0)
while sin.dim() > 2 and sin.shape[0] == 1:
sin = sin.squeeze(0)

# cos, sin should now be [seq_len, dim] or [1, seq_len, dim]
if cos.dim() == 2:
# [seq_len, dim] -> index by position_ids and add head dimension
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
elif cos.dim() == 3:
# [1, seq_len, dim] -> index by position_ids
cos = cos.squeeze(0)[position_ids].unsqueeze(1)
sin = sin.squeeze(0)[position_ids].unsqueeze(1)

q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed


def farsight_attention_forward_v2(
self,
hidden_states,
attention_mask=None,
position_ids=None, # 接住但不使用
position_ids=None,
past_key_value=None, # 兼容 KV 缓存接口
output_attentions=False, # 是否需要返回注意力矩阵
use_cache=False, # 兼容 generate 传参;本实现不支持
Expand All @@ -26,12 +58,29 @@ def farsight_attention_forward_v2(
dtype = hidden_states.dtype
device = hidden_states.device

# Q K V
# Q K V projections
# Handle both standard attention and grouped-query attention (GQA)
num_key_value_heads = getattr(self, 'num_key_value_heads', self.num_heads)

q = self.q_proj(hidden_states).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
k = self.k_proj(hidden_states).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(hidden_states).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)

# 原始分数
k = self.k_proj(hidden_states).view(B, L, num_key_value_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(hidden_states).view(B, L, num_key_value_heads, self.head_dim).transpose(1, 2)

# Apply rotary position embeddings (critical for LLaMA!)
if hasattr(self, 'rotary_emb'):
kv_seq_len = k.shape[-2]
cos, sin = self.rotary_emb(k, seq_len=kv_seq_len)
if position_ids is None:
position_ids = torch.arange(L, device=device).unsqueeze(0)
q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)

# Handle grouped query attention (GQA) - repeat KV heads if needed
if num_key_value_heads < self.num_heads:
num_key_value_groups = self.num_heads // num_key_value_heads
k = k.repeat_interleave(num_key_value_groups, dim=1)
v = v.repeat_interleave(num_key_value_groups, dim=1)

# Compute attention scores
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim) # [B,H,L,L]

# 因果下三角 C(两次用)
Expand Down Expand Up @@ -109,7 +158,37 @@ def _derive_valid_from_attention_mask(m, B, L, device, out_dtype):
# ===== padding-safe 结束 =====

# 组装 W 与 softmax
# FarSight formula: W = (QK^T/√d) ⊙ C + P
# where C is causal mask (lower triangular 1s), P is FarSight upper triangular penalties

# Create a large negative value for masking (using -1e9 instead of -inf to avoid NaN issues)
NEG_INF = -1e9

# For the lower triangular (causal): keep original attention scores
# For the upper triangular (future): apply FarSight penalties (allows limited look-ahead)
# The original code: attn_scores * C zeros out upper triangular completely,
# then adds P_total which has penalties only in upper triangular.
# This effectively gives: lower_tri = original_scores, upper_tri = penalties

W = attn_scores * C + P_total

# Apply attention mask for padding tokens if provided
# The attention_mask from HuggingFace models is typically in additive format (0 for attend, -inf for mask)
if attention_mask is not None:
# Handle different attention mask formats from HuggingFace
if attention_mask.dim() == 4 and attention_mask.shape[-2] == L and attention_mask.shape[-1] == L:
# [B, 1, L, L] format - additive mask with 0 or -inf values
W = W + attention_mask.to(dtype)
elif attention_mask.dim() == 4 and attention_mask.shape[-2] == 1:
# [B, 1, 1, L] format - broadcast over query dimension
W = W + attention_mask.to(dtype)
elif attention_mask.dim() == 2:
# [B, L] format - binary mask (1 for attend, 0 for mask)
# Convert to additive mask format [B, 1, 1, L]
expanded_mask = attention_mask.unsqueeze(1).unsqueeze(2).to(dtype)
additive_mask = (1.0 - expanded_mask) * NEG_INF
W = W + additive_mask

attn_probs = torch.softmax(W, dim=-1) * C

# 输出
Expand Down