From 33a6b27867b4617576b15ef994e06f6baa64fe56 Mon Sep 17 00:00:00 2001 From: taivu1998 <46636857+taivu1998@users.noreply.github.com> Date: Sun, 10 May 2026 03:49:19 -0700 Subject: [PATCH] Support Tinker LoRA config flags --- skyrl/backends/backend.py | 9 +- skyrl/backends/jax.py | 24 +++- skyrl/backends/skyrl_train_backend.py | 80 +++++++++-- skyrl/backends/skyrl_train_lora.py | 131 +++++++++++++++++++ skyrl/tinker/api.py | 47 ++++++- skyrl/tinker/engine.py | 7 +- skyrl/tinker/types.py | 1 + skyrl/train/config/config.py | 4 +- tests/tinker/test_api_validation.py | 92 +++++++++++++ tests/tinker/test_skyrl_train_lora_config.py | 77 +++++++++++ 10 files changed, 446 insertions(+), 26 deletions(-) create mode 100644 skyrl/backends/skyrl_train_lora.py create mode 100644 tests/tinker/test_skyrl_train_lora_config.py diff --git a/skyrl/backends/backend.py b/skyrl/backends/backend.py index d4cbd11a1a..ddadd587f7 100644 --- a/skyrl/backends/backend.py +++ b/skyrl/backends/backend.py @@ -43,7 +43,13 @@ def __init__(self, base_model: str, config: BaseModel): pass @abstractmethod - def create_model(self, model_id: str, lora_config: types.LoraConfig, model_role: str = "policy") -> None: + def create_model( + self, + model_id: str, + lora_config: types.LoraConfig, + model_role: str = "policy", + seed_was_provided: bool = True, + ) -> None: """Create a new model in the backend. Creates optimizer and configures LoRA adapter. @@ -52,6 +58,7 @@ def create_model(self, model_id: str, lora_config: types.LoraConfig, model_role: model_id: The model identifier lora_config: LoRA configuration with rank and alpha model_role: Logical role for the model (e.g. policy or critic) + seed_was_provided: Whether the client explicitly set the seed. """ pass diff --git a/skyrl/backends/jax.py b/skyrl/backends/jax.py index 787f25e8b0..6b04f67b8a 100644 --- a/skyrl/backends/jax.py +++ b/skyrl/backends/jax.py @@ -546,7 +546,13 @@ def has_model(self, model_id: str) -> bool: """Check if a model is registered with the backend.""" return model_id in self.models - def create_model(self, model_id: str, lora_config: types.LoraConfig, model_role: str = "policy") -> None: + def create_model( + self, + model_id: str, + lora_config: types.LoraConfig, + model_role: str = "policy", + seed_was_provided: bool = True, + ) -> None: """Create a new model in the backend. Creates optimizer and configures LoRA adapter. Allocates adapter_index internally. @@ -1109,8 +1115,20 @@ def serialize(k, v): ) return getattr(super(), method)(**kwargs) - def create_model(self, model_id: str, lora_config: types.LoraConfig, model_role: str = "policy") -> None: - self._broadcast_and_call("create_model", model_id=model_id, lora_config=lora_config, model_role=model_role) + def create_model( + self, + model_id: str, + lora_config: types.LoraConfig, + model_role: str = "policy", + seed_was_provided: bool = True, + ) -> None: + self._broadcast_and_call( + "create_model", + model_id=model_id, + lora_config=lora_config, + model_role=model_role, + seed_was_provided=seed_was_provided, + ) def forward_backward(self, prepared_batch: types.PreparedModelPassBatch): return self._broadcast_and_call("forward_backward", prepared_batch=prepared_batch) diff --git a/skyrl/backends/skyrl_train_backend.py b/skyrl/backends/skyrl_train_backend.py index ebf612f575..775b3be2fd 100644 --- a/skyrl/backends/skyrl_train_backend.py +++ b/skyrl/backends/skyrl_train_backend.py @@ -28,6 +28,10 @@ ) from skyrl.backends.skyrl_train.workers.worker import PPORayActorGroup from skyrl.backends.skyrl_train.workers.worker_dispatch import WorkerDispatch +from skyrl.backends.skyrl_train_lora import ( + resolve_skyrl_train_lora_config, + skyrl_train_lora_signature, +) from skyrl.env_vars import _SKYRL_USE_NEW_INFERENCE, SKYRL_RAY_PG_TIMEOUT_IN_S from skyrl.tinker import types from skyrl.train.config import SkyRLTrainConfig, get_config_as_yaml_str @@ -98,8 +102,18 @@ def _build_skyrl_train_config( # Apply LoRA configuration if lora_config is not None and lora_config.rank > 0: + cfg.trainer.seed = int(lora_config.seed) + lora_type = cfg.trainer.policy.megatron_config.lora_config.lora_type + resolved_lora = resolve_skyrl_train_lora_config( + lora_config, + strategy=cfg.trainer.strategy, + lora_type=lora_type, + pipeline_parallel_size=cfg.trainer.policy.megatron_config.pipeline_model_parallel_size, + ) cfg.trainer.policy.model.lora.rank = lora_config.rank cfg.trainer.policy.model.lora.alpha = int(lora_config.alpha) + cfg.trainer.policy.model.lora.target_modules = resolved_lora.target_modules + cfg.trainer.policy.model.lora.exclude_modules = resolved_lora.exclude_modules logger.info("SkyRL-Train config:\n%s", get_config_as_yaml_str(cfg)) return cfg @@ -133,6 +147,7 @@ def __init__(self, base_model: str, config: SkyRLTrainBackendOverrides): # Captured at first LoRA create_model; subsequent create_models must # match this signature exactly. None when no LoRA model is registered. self._base_lora_signature: tuple | None = None + self._base_lora_seed: int | None = None # New inference infrastructure self._server_groups: list = [] @@ -287,8 +302,18 @@ def _build_critic(self, CriticWorker, lora_config: types.LoraConfig) -> None: num_policy_gpus == num_critic_gpus ), "num_policy_gpus and num_critic_gpus must be the same when colocating policy and critic model" - cfg.trainer.critic.model.lora.rank = lora_config.rank - cfg.trainer.critic.model.lora.alpha = int(lora_config.alpha) + if lora_config is not None and lora_config.rank > 0: + cfg.trainer.seed = int(lora_config.seed) + resolved_lora = resolve_skyrl_train_lora_config( + lora_config, + strategy=cfg.trainer.strategy, + lora_type=cfg.trainer.policy.megatron_config.lora_config.lora_type, + pipeline_parallel_size=cfg.trainer.policy.megatron_config.pipeline_model_parallel_size, + ) + cfg.trainer.critic.model.lora.rank = lora_config.rank + cfg.trainer.critic.model.lora.alpha = int(lora_config.alpha) + cfg.trainer.critic.model.lora.target_modules = resolved_lora.target_modules + cfg.trainer.critic.model.lora.exclude_modules = resolved_lora.exclude_modules critic_model = PPORayActorGroup( cfg.trainer, cfg.trainer.placement.critic_num_nodes, @@ -356,14 +381,28 @@ def _ensure_inference_engines(self): self._inference_engines_initialized = True def _lora_signature_from(self, lora_config: types.LoraConfig) -> tuple: - # Tinker's public LoraConfig only exposes rank + alpha (plus - # seed/train_attn/train_mlp/train_unembed) - pending support https://github.com/NovaSky-AI/SkyRL/issues/1632. - # Equality across adapters therefore reduces to (rank, alpha); the worker-side - # AdapterStore additionally verifies parallel-state equality via - # its own LoraSignature. - return (int(lora_config.rank), int(lora_config.alpha)) - - def create_model(self, model_id: str, lora_config: types.LoraConfig, model_role: str = "policy") -> None: + if self._cfg is not None: + strategy = self._cfg.trainer.strategy + lora_type = self._cfg.trainer.policy.megatron_config.lora_config.lora_type + pipeline_parallel_size = self._cfg.trainer.policy.megatron_config.pipeline_model_parallel_size + else: + strategy = self.config.strategy + lora_type = "lora" + pipeline_parallel_size = 1 + return skyrl_train_lora_signature( + lora_config, + strategy=strategy, + lora_type=lora_type, + pipeline_parallel_size=pipeline_parallel_size, + ) + + def create_model( + self, + model_id: str, + lora_config: types.LoraConfig, + model_role: str = "policy", + seed_was_provided: bool = True, + ) -> None: if model_id in self._model_ids_to_role: raise ValueError(f"Model '{model_id}' already exists") @@ -389,10 +428,17 @@ def create_model(self, model_id: str, lora_config: types.LoraConfig, model_role: if new_signature != self._base_lora_signature: raise ValueError( f"LoRA signature mismatch for model '{model_id}': " - f"got (rank, alpha)={new_signature}, " + f"got {new_signature}, " f"first adapter registered with {self._base_lora_signature}. " - "Multi-LoRA with the SkyRLTrainBackend requires identical (rank, alpha) across all " - "adapters; target_modules is fixed server-side." + "Multi-LoRA with the SkyRLTrainBackend requires identical " + "(rank, alpha, target_modules, exclude_modules, lora_type) across all adapters." + ) + if seed_was_provided and self._base_lora_seed is not None and int(lora_config.seed) != self._base_lora_seed: + raise ValueError( + f"LoRA seed mismatch for model '{model_id}': got seed={lora_config.seed}, " + f"first adapter registered with seed={self._base_lora_seed}. " + "SkyRLTrainBackend additional adapters are initialized from the first pristine adapter, " + "so explicit seeds must match." ) self._dispatch.register_adapter("policy", model_id) self._model_ids_to_role[model_id] = model_role @@ -402,7 +448,11 @@ def create_model(self, model_id: str, lora_config: types.LoraConfig, model_role: # First-time setup OR critic creation (existing path). if model_role == "policy": - self._cfg = _build_skyrl_train_config(self.base_model, self.config, lora_config) + self._cfg = _build_skyrl_train_config( + self.base_model, + self.config, + lora_config, + ) if not ray.is_initialized(): logger.info("Initializing Ray with runtime environment") @@ -425,6 +475,7 @@ def create_model(self, model_id: str, lora_config: types.LoraConfig, model_role: self._build_policy(PolicyWorker, model_id=model_id) if is_lora: self._base_lora_signature = self._lora_signature_from(lora_config) + self._base_lora_seed = int(lora_config.seed) elif model_role == "critic": if model_role in self._model_ids_to_role.values(): raise ValueError(f"SkyRLTrainBackend already has a '{model_role}' model") @@ -496,6 +547,7 @@ def delete_model(self, model_id: str) -> None: self._renderer = None self._colocate_pg = None self._base_lora_signature = None + self._base_lora_seed = None logger.info(f"Successfully deleted model {model_id}") def _to_training_batch(self, prepared_batch: types.PreparedModelPassBatch, role: str) -> TrainingInputBatch: diff --git a/skyrl/backends/skyrl_train_lora.py b/skyrl/backends/skyrl_train_lora.py new file mode 100644 index 0000000000..f9e203210b --- /dev/null +++ b/skyrl/backends/skyrl_train_lora.py @@ -0,0 +1,131 @@ +"""Translate Tinker LoRA options into SkyRL-Train LoRA target modules.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Iterable + +from skyrl.tinker import types + +FSDP_ATTN_TARGET_MODULES = ( + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "query_key_value", + "attn.c_attn", + "attn.c_proj", +) +FSDP_MLP_TARGET_MODULES = ( + "gate_proj", + "up_proj", + "down_proj", + "fc1", + "fc2", + "c_fc", + "mlp.c_proj", +) +FSDP_UNEMBED_TARGET_MODULES = ( + "lm_head", + "embed_out", + "output_projection", +) + +MEGATRON_LORA_ATTN_TARGET_MODULES = ("linear_qkv", "linear_proj") +MEGATRON_LORA_MLP_TARGET_MODULES = ("linear_fc1", "linear_fc2") +MEGATRON_CANONICAL_LORA_ATTN_TARGET_MODULES = ("linear_q", "linear_k", "linear_v", "linear_proj") +MEGATRON_CANONICAL_LORA_MLP_TARGET_MODULES = ("linear_fc1_up", "linear_fc1_gate", "linear_fc2") +MEGATRON_UNEMBED_TARGET_MODULES = ("output_layer",) + + +@dataclass(frozen=True) +class ResolvedSkyRLTrainLoraConfig: + target_modules: str | list[str] + exclude_modules: list[str] | None = None + + +def normalize_lora_targets(target_modules: str | Iterable[str]) -> tuple[str, ...]: + if isinstance(target_modules, str): + return (target_modules,) + return tuple(target_modules) + + +def _dedupe_targets(target_modules: Iterable[str]) -> list[str]: + return list(dict.fromkeys(target_modules)) + + +def _validate_train_targets(lora_config: types.LoraConfig) -> None: + if lora_config.rank > 0 and not (lora_config.train_attn or lora_config.train_mlp or lora_config.train_unembed): + raise ValueError("At least one of train_attn, train_mlp, or train_unembed must be true for LoRA rank > 0") + + +def resolve_skyrl_train_lora_config( + lora_config: types.LoraConfig, + strategy: str, + lora_type: str = "lora", + pipeline_parallel_size: int = 1, +) -> ResolvedSkyRLTrainLoraConfig: + """Resolve Tinker LoRA train flags to the target module surface SkyRL-Train expects.""" + + _validate_train_targets(lora_config) + if lora_config.rank <= 0: + return ResolvedSkyRLTrainLoraConfig(target_modules="all-linear") + if lora_config.train_attn and lora_config.train_mlp and not lora_config.train_unembed: + return ResolvedSkyRLTrainLoraConfig(target_modules="all-linear") + + if strategy in ("fsdp", "fsdp2"): + target_modules: list[str] = [] + if lora_config.train_attn: + target_modules.extend(FSDP_ATTN_TARGET_MODULES) + if lora_config.train_mlp: + target_modules.extend(FSDP_MLP_TARGET_MODULES) + if lora_config.train_unembed: + target_modules.extend(FSDP_UNEMBED_TARGET_MODULES) + return ResolvedSkyRLTrainLoraConfig(target_modules=_dedupe_targets(target_modules)) + + if strategy == "megatron": + if lora_config.train_unembed and pipeline_parallel_size > 1: + raise ValueError( + "train_unembed=True is not supported for the Megatron SkyRL-Train backend when " + "pipeline_model_parallel_size > 1 because output_layer only exists on the final pipeline stage" + ) + if lora_type == "canonical_lora": + attn_targets = MEGATRON_CANONICAL_LORA_ATTN_TARGET_MODULES + mlp_targets = MEGATRON_CANONICAL_LORA_MLP_TARGET_MODULES + elif lora_type == "lora": + attn_targets = MEGATRON_LORA_ATTN_TARGET_MODULES + mlp_targets = MEGATRON_LORA_MLP_TARGET_MODULES + else: + raise ValueError(f"Unsupported Megatron LoRA type: {lora_type!r}") + + target_modules = [] + if lora_config.train_attn: + target_modules.extend(attn_targets) + if lora_config.train_mlp: + target_modules.extend(mlp_targets) + if lora_config.train_unembed: + target_modules.extend(MEGATRON_UNEMBED_TARGET_MODULES) + return ResolvedSkyRLTrainLoraConfig(target_modules=_dedupe_targets(target_modules)) + + raise ValueError(f"Unsupported SkyRL-Train strategy for Tinker LoRA config: {strategy!r}") + + +def skyrl_train_lora_signature( + lora_config: types.LoraConfig, + strategy: str, + lora_type: str = "lora", + pipeline_parallel_size: int = 1, +) -> tuple: + resolved = resolve_skyrl_train_lora_config( + lora_config, + strategy=strategy, + lora_type=lora_type, + pipeline_parallel_size=pipeline_parallel_size, + ) + return ( + int(lora_config.rank), + int(lora_config.alpha), + normalize_lora_targets(resolved.target_modules), + tuple(resolved.exclude_modules or ()), + lora_type if strategy == "megatron" else strategy, + ) diff --git a/skyrl/tinker/api.py b/skyrl/tinker/api.py index 884368773e..df9c93f859 100644 --- a/skyrl/tinker/api.py +++ b/skyrl/tinker/api.py @@ -244,6 +244,15 @@ class LoRAConfig(BaseModel): seed: int | None = Field( default=None, description="Seed for LoRA weight initialization. If None, a random seed is used." ) + train_unembed: bool = True + train_mlp: bool = True + train_attn: bool = True + + @model_validator(mode="after") + def validate_train_targets(self) -> "LoRAConfig": + if self.rank > 0 and not (self.train_attn or self.train_mlp or self.train_unembed): + raise ValueError("At least one of train_attn, train_mlp, or train_unembed must be true for LoRA rank > 0") + return self class CreateModelRequest(BaseModel): @@ -283,6 +292,16 @@ class ModelInfoResponse(BaseModel): model_data: ModelData +def _to_api_lora_config(lora_config: types.LoraConfig) -> LoRAConfig: + return LoRAConfig( + rank=lora_config.rank, + seed=lora_config.seed, + train_unembed=lora_config.train_unembed, + train_mlp=lora_config.train_mlp, + train_attn=lora_config.train_attn, + ) + + class Checkpoint(BaseModel): checkpoint_id: str checkpoint_type: Literal["training", "sampler"] @@ -677,6 +696,9 @@ class WeightsInfoResponse(BaseModel): base_model: str is_lora: bool lora_rank: int | None = None + train_unembed: bool | None = None + train_mlp: bool | None = None + train_attn: bool | None = None class ClientConfigResponse(BaseModel): @@ -757,13 +779,25 @@ async def create_model(request: CreateModelRequest, session: AsyncSession = Depe # alpha = 32 seems to be the tinker default (see https://thinkingmachines.ai/blog/lora/) # Generate a random seed if not provided - seed = request.lora_config.seed if request.lora_config.seed is not None else random.randint(0, 2**31 - 1) - lora_config = types.LoraConfig(rank=request.lora_config.rank, alpha=32.0, seed=seed) + seed_was_provided = request.lora_config.seed is not None + seed = request.lora_config.seed if seed_was_provided else random.randint(0, 2**31 - 1) + lora_config = types.LoraConfig( + rank=request.lora_config.rank, + alpha=32.0, + seed=seed, + train_unembed=request.lora_config.train_unembed, + train_mlp=request.lora_config.train_mlp, + train_attn=request.lora_config.train_attn, + ) request_id = await create_future( session=session, request_type=types.RequestType.CREATE_MODEL, model_id=model_id, - request_data=types.CreateModelInput(lora_config=lora_config, model_role=request.model_role), + request_data=types.CreateModelInput( + lora_config=lora_config, + model_role=request.model_role, + seed_was_provided=seed_was_provided, + ), ) model_db = ModelDB( @@ -781,7 +815,7 @@ async def create_model(request: CreateModelRequest, session: AsyncSession = Depe return CreateModelResponse( model_id=model_id, base_model=request.base_model, - lora_config=request.lora_config, + lora_config=_to_api_lora_config(lora_config), status="created", request_id=str(request_id), ) @@ -823,7 +857,7 @@ async def get_model_info(request: GetInfoRequest, session: AsyncSession = Depend lora_config = types.LoraConfig.model_validate(model.lora_config) model_data = ModelData( - base_model=model.base_model, lora_config=LoRAConfig(rank=lora_config.rank), model_name=model.base_model + base_model=model.base_model, lora_config=_to_api_lora_config(lora_config), model_name=model.base_model ) return ModelInfoResponse(model_id=model.model_id, status=model.status, model_data=model_data) @@ -1295,6 +1329,9 @@ async def get_weights_info(request: WeightsInfoRequest, req: Request, session: A base_model=model.base_model, is_lora=is_lora, lora_rank=lora_config.rank, + train_unembed=lora_config.train_unembed if is_lora else None, + train_mlp=lora_config.train_mlp if is_lora else None, + train_attn=lora_config.train_attn if is_lora else None, ) diff --git a/skyrl/tinker/engine.py b/skyrl/tinker/engine.py index 72058c2056..ed8793c9ab 100644 --- a/skyrl/tinker/engine.py +++ b/skyrl/tinker/engine.py @@ -445,7 +445,12 @@ def find_single_requests(self, session: Session) -> dict[str, tuple[str, types.R def process_create_model(self, model_id: str, request_data: types.CreateModelInput) -> types.CreateModelOutput: """Create and initialize a model.""" # Create model in backend (allocates adapter_index, creates optimizer, and configures adapter) - self.backend.create_model(model_id, request_data.lora_config, model_role=request_data.model_role) + self.backend.create_model( + model_id, + request_data.lora_config, + model_role=request_data.model_role, + seed_was_provided=request_data.seed_was_provided, + ) logger.info(f"Created LoRA model {model_id}") diff --git a/skyrl/tinker/types.py b/skyrl/tinker/types.py index 0d7c334c22..a07d92166c 100644 --- a/skyrl/tinker/types.py +++ b/skyrl/tinker/types.py @@ -75,6 +75,7 @@ class LoraConfig(BaseModel): class CreateModelInput(BaseModel): lora_config: LoraConfig model_role: str = "policy" + seed_was_provided: bool = True class CreateModelOutput(BaseModel): diff --git a/skyrl/train/config/config.py b/skyrl/train/config/config.py index f28d997ba4..a07bd96a04 100644 --- a/skyrl/train/config/config.py +++ b/skyrl/train/config/config.py @@ -56,8 +56,8 @@ class SkyRLLoraConfig(BaseConfig): alpha: int = 16 dropout: float = 0.0 lora_sync_path: str = "/tmp/skyrl_lora_sync" - target_modules: str = "all-linear" - exclude_modules: Optional[str] = None + target_modules: str | List[str] = "all-linear" + exclude_modules: Optional[str | List[str]] = None init_method: str = "kaiming" """For FSDP, corresponds to ``init_lora_weights`` in PEFT. For Megatron, used for ``lora_A_init_method``; supports "xavier", "normal", "kaiming", "zero".""" diff --git a/tests/tinker/test_api_validation.py b/tests/tinker/test_api_validation.py index a436519601..c1ab10362d 100644 --- a/tests/tinker/test_api_validation.py +++ b/tests/tinker/test_api_validation.py @@ -2,8 +2,12 @@ import pytest from pydantic import TypeAdapter, ValidationError +from sqlalchemy.ext.asyncio import create_async_engine +from sqlmodel import SQLModel, select +from sqlmodel.ext.asyncio.session import AsyncSession from skyrl.tinker import api, types +from skyrl.tinker.db_models import FutureDB, ModelDB, SessionDB _B64_PNG = base64.b64encode(b"\x89PNG").decode() @@ -27,6 +31,94 @@ def test_forward_backward_input_accepts_ppo_threshold_keys(): assert req.loss_fn_config == {"clip_low_threshold": 0.9, "clip_high_threshold": 1.1} +def test_lora_config_accepts_tinker_train_flags(): + cfg = api.LoRAConfig(rank=8, seed=123, train_attn=False, train_mlp=True, train_unembed=True) + + assert cfg.seed == 123 + assert cfg.train_attn is False + assert cfg.train_mlp is True + assert cfg.train_unembed is True + + +def test_api_lora_config_defaults_match_tinker_sdk(): + api_cfg = api.LoRAConfig(rank=8) + + assert api_cfg.train_attn is True + assert api_cfg.train_mlp is True + assert api_cfg.train_unembed is True + + +def test_lora_config_rejects_rank_with_no_train_targets(): + with pytest.raises(ValidationError, match="At least one"): + api.LoRAConfig(rank=8, train_attn=False, train_mlp=False, train_unembed=False) + + +def test_weights_info_response_accepts_lora_train_flags(): + response = api.WeightsInfoResponse( + base_model="base", + is_lora=True, + lora_rank=8, + train_attn=False, + train_mlp=True, + train_unembed=True, + ) + + assert response.train_attn is False + assert response.train_mlp is True + assert response.train_unembed is True + + +@pytest.mark.asyncio +async def test_create_model_persists_lora_flags_and_seed_provenance(): + engine = create_async_engine("sqlite+aiosqlite:///:memory:") + try: + async with engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) + + async with AsyncSession(engine) as session: + session.add(SessionDB(session_id="session", sdk_version="test")) + await session.commit() + + response = await api.create_model( + api.CreateModelRequest( + session_id="session", + base_model="base", + lora_config=api.LoRAConfig( + rank=8, + train_attn=False, + train_mlp=True, + train_unembed=True, + ), + ), + session, + ) + + assert response.lora_config is not None + response_lora_config = response.lora_config + assert response_lora_config == api.LoRAConfig( + rank=8, + seed=response_lora_config.seed, + train_attn=False, + train_mlp=True, + train_unembed=True, + ) + assert response_lora_config.seed is not None + + future = (await session.exec(select(FutureDB))).one() + queued_input = types.CreateModelInput.model_validate(future.request_data) + assert queued_input.seed_was_provided is False + assert queued_input.lora_config.seed == response_lora_config.seed + assert queued_input.lora_config.train_attn is False + assert queued_input.lora_config.train_mlp is True + assert queued_input.lora_config.train_unembed is True + + model = (await session.exec(select(ModelDB))).one() + stored_lora = types.LoraConfig.model_validate(model.lora_config) + assert stored_lora == queued_input.lora_config + finally: + await engine.dispose() + + def test_forward_backward_input_accepts_ppo_value_clip(): req = api.ForwardBackwardInput( data=[_make_datum()], diff --git a/tests/tinker/test_skyrl_train_lora_config.py b/tests/tinker/test_skyrl_train_lora_config.py new file mode 100644 index 0000000000..42e76eaffe --- /dev/null +++ b/tests/tinker/test_skyrl_train_lora_config.py @@ -0,0 +1,77 @@ +import pytest + +from skyrl.backends.skyrl_train_lora import ( + resolve_skyrl_train_lora_config, + skyrl_train_lora_signature, +) +from skyrl.tinker import types + + +def test_resolves_fsdp_lora_train_flags_to_target_modules(): + cfg = types.LoraConfig( + rank=8, + alpha=32.0, + seed=123, + train_attn=True, + train_mlp=False, + train_unembed=True, + ) + + resolved = resolve_skyrl_train_lora_config(cfg, strategy="fsdp2") + + assert resolved.target_modules == [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "query_key_value", + "attn.c_attn", + "attn.c_proj", + "lm_head", + "embed_out", + "output_projection", + ] + + +def test_resolves_legacy_attn_mlp_surface_to_all_linear(): + cfg = types.LoraConfig(rank=8, alpha=32.0, seed=123, train_attn=True, train_mlp=True, train_unembed=False) + + resolved = resolve_skyrl_train_lora_config(cfg, strategy="fsdp2") + + assert resolved.target_modules == "all-linear" + + +def test_resolves_megatron_canonical_lora_train_flags_to_target_modules(): + cfg = types.LoraConfig( + rank=8, + alpha=32.0, + seed=123, + train_attn=False, + train_mlp=True, + train_unembed=True, + ) + + resolved = resolve_skyrl_train_lora_config(cfg, strategy="megatron", lora_type="canonical_lora") + + assert resolved.target_modules == [ + "linear_fc1_up", + "linear_fc1_gate", + "linear_fc2", + "output_layer", + ] + + +def test_skyrl_train_lora_signature_includes_trainable_surface(): + cfg_a = types.LoraConfig(rank=8, alpha=32.0, seed=1, train_attn=True, train_mlp=True, train_unembed=False) + cfg_b = types.LoraConfig(rank=8, alpha=32.0, seed=1, train_attn=True, train_mlp=True, train_unembed=True) + + assert skyrl_train_lora_signature(cfg_a, strategy="fsdp2") != skyrl_train_lora_signature( + cfg_b, strategy="fsdp2" + ) + + +def test_megatron_rejects_train_unembed_with_pipeline_parallelism(): + cfg = types.LoraConfig(rank=8, alpha=32.0, seed=123, train_unembed=True) + + with pytest.raises(ValueError, match="pipeline_model_parallel_size"): + resolve_skyrl_train_lora_config(cfg, strategy="megatron", pipeline_parallel_size=2)