diff --git a/Shell/eval_chair.sh b/Shell/eval_chair.sh index 70d53f0..586da9e 100755 --- a/Shell/eval_chair.sh +++ b/Shell/eval_chair.sh @@ -1,5 +1,46 @@ +#!/bin/bash +# CHAIR Evaluation Script for FarSight +# This script runs caption generation using FarSight decoding strategy + +# Configuration - Set these environment variables before running: +# MODEL_PATH: Path to the LLaVA model (default: liuhaotian/llava-v1.5-7b) +# IMAGE_FOLDER: Path to COCO val2014 images +# QUESTION_FILE: Path to CHAIR questions file (JSONL format) +# ANSWERS_FILE: Path to save generated captions (default: ./Answers/chair_captions.jsonl) + +MODEL_PATH="${MODEL_PATH:-liuhaotian/llava-v1.5-7b}" +IMAGE_FOLDER="${IMAGE_FOLDER:-}" +QUESTION_FILE="${QUESTION_FILE:-}" +ANSWERS_FILE="${ANSWERS_FILE:-./Answers/chair_captions.jsonl}" + +# Validate required paths +if [ -z "$IMAGE_FOLDER" ]; then + echo "Error: IMAGE_FOLDER environment variable is not set" + echo "Usage: IMAGE_FOLDER=/path/to/coco/val2014 QUESTION_FILE=/path/to/questions.jsonl bash eval_chair.sh" + exit 1 +fi + +if [ -z "$QUESTION_FILE" ]; then + echo "Error: QUESTION_FILE environment variable is not set" + echo "Usage: IMAGE_FOLDER=/path/to/coco/val2014 QUESTION_FILE=/path/to/questions.jsonl bash eval_chair.sh" + exit 1 +fi + +# Create output directory if needed +mkdir -p "$(dirname "$ANSWERS_FILE")" + +# Step 1: Generate captions using FarSight python eval_chair.py \ ---coco_path /root/autodl-tmp/annotations \ ---cache ./chair.pkl \ ---cap_file ./1.jsonl \ ---save_path ./Answers/eval-chair.json \ No newline at end of file + --model-path "$MODEL_PATH" \ + --image-folder "$IMAGE_FOLDER" \ + --question-file "$QUESTION_FILE" \ + --answers-file "$ANSWERS_FILE" \ + --farsight + +# Note: After generating captions, you can evaluate CHAIR metrics using: +# - External CHAIR evaluation tool (e.g., https://github.com/Maxlinn/CHAIR-metric-standalone) +# - Example command: +# python chair.py --cap_file ./Answers/chair_captions.jsonl \ +# --coco_path /path/to/coco/annotations \ +# --cache ./chair.pkl \ +# --save_path ./Answers/eval-chair.json \ No newline at end of file diff --git a/Shell/farsight_patch.py b/Shell/farsight_patch.py index 17244f1..4061056 100644 --- a/Shell/farsight_patch.py +++ b/Shell/farsight_patch.py @@ -4,11 +4,43 @@ import torch.nn.functional as F import types + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + """Apply rotary position embedding to query and key tensors.""" + # Handle different shapes of cos and sin from rotary embeddings + # Expected final shapes: cos, sin -> [bs, 1, seq_len, dim] + while cos.dim() > 2 and cos.shape[0] == 1: + cos = cos.squeeze(0) + while sin.dim() > 2 and sin.shape[0] == 1: + sin = sin.squeeze(0) + + # cos, sin should now be [seq_len, dim] or [1, seq_len, dim] + if cos.dim() == 2: + # [seq_len, dim] -> index by position_ids and add head dimension + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + elif cos.dim() == 3: + # [1, seq_len, dim] -> index by position_ids + cos = cos.squeeze(0)[position_ids].unsqueeze(1) + sin = sin.squeeze(0)[position_ids].unsqueeze(1) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + def farsight_attention_forward_v2( self, hidden_states, attention_mask=None, - position_ids=None, # 接住但不使用 + position_ids=None, past_key_value=None, # 兼容 KV 缓存接口 output_attentions=False, # 是否需要返回注意力矩阵 use_cache=False, # 兼容 generate 传参;本实现不支持 @@ -26,12 +58,29 @@ def farsight_attention_forward_v2( dtype = hidden_states.dtype device = hidden_states.device - # Q K V + # Q K V projections + # Handle both standard attention and grouped-query attention (GQA) + num_key_value_heads = getattr(self, 'num_key_value_heads', self.num_heads) + q = self.q_proj(hidden_states).view(B, L, self.num_heads, self.head_dim).transpose(1, 2) - k = self.k_proj(hidden_states).view(B, L, self.num_heads, self.head_dim).transpose(1, 2) - v = self.v_proj(hidden_states).view(B, L, self.num_heads, self.head_dim).transpose(1, 2) - - # 原始分数 + k = self.k_proj(hidden_states).view(B, L, num_key_value_heads, self.head_dim).transpose(1, 2) + v = self.v_proj(hidden_states).view(B, L, num_key_value_heads, self.head_dim).transpose(1, 2) + + # Apply rotary position embeddings (critical for LLaMA!) + if hasattr(self, 'rotary_emb'): + kv_seq_len = k.shape[-2] + cos, sin = self.rotary_emb(k, seq_len=kv_seq_len) + if position_ids is None: + position_ids = torch.arange(L, device=device).unsqueeze(0) + q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids) + + # Handle grouped query attention (GQA) - repeat KV heads if needed + if num_key_value_heads < self.num_heads: + num_key_value_groups = self.num_heads // num_key_value_heads + k = k.repeat_interleave(num_key_value_groups, dim=1) + v = v.repeat_interleave(num_key_value_groups, dim=1) + + # Compute attention scores attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim) # [B,H,L,L] # 因果下三角 C(两次用) @@ -109,7 +158,37 @@ def _derive_valid_from_attention_mask(m, B, L, device, out_dtype): # ===== padding-safe 结束 ===== # 组装 W 与 softmax + # FarSight formula: W = (QK^T/√d) ⊙ C + P + # where C is causal mask (lower triangular 1s), P is FarSight upper triangular penalties + + # Create a large negative value for masking (using -1e9 instead of -inf to avoid NaN issues) + NEG_INF = -1e9 + + # For the lower triangular (causal): keep original attention scores + # For the upper triangular (future): apply FarSight penalties (allows limited look-ahead) + # The original code: attn_scores * C zeros out upper triangular completely, + # then adds P_total which has penalties only in upper triangular. + # This effectively gives: lower_tri = original_scores, upper_tri = penalties + W = attn_scores * C + P_total + + # Apply attention mask for padding tokens if provided + # The attention_mask from HuggingFace models is typically in additive format (0 for attend, -inf for mask) + if attention_mask is not None: + # Handle different attention mask formats from HuggingFace + if attention_mask.dim() == 4 and attention_mask.shape[-2] == L and attention_mask.shape[-1] == L: + # [B, 1, L, L] format - additive mask with 0 or -inf values + W = W + attention_mask.to(dtype) + elif attention_mask.dim() == 4 and attention_mask.shape[-2] == 1: + # [B, 1, 1, L] format - broadcast over query dimension + W = W + attention_mask.to(dtype) + elif attention_mask.dim() == 2: + # [B, L] format - binary mask (1 for attend, 0 for mask) + # Convert to additive mask format [B, 1, 1, L] + expanded_mask = attention_mask.unsqueeze(1).unsqueeze(2).to(dtype) + additive_mask = (1.0 - expanded_mask) * NEG_INF + W = W + additive_mask + attn_probs = torch.softmax(W, dim=-1) * C # 输出