diff --git a/.env.example b/.env.example index 47ec962..a5733b2 100644 --- a/.env.example +++ b/.env.example @@ -5,4 +5,5 @@ MLIR_SHARED_LIBS= AST_DUMPER_BIN_PATH= PRE_VEC_BIN_PATH= VECTORIZER_BIN_PATH= -CONDA_ENV= \ No newline at end of file +CONDA_ENV= +CONFIG_FILE_PATH= \ No newline at end of file diff --git a/.gitignore b/.gitignore index c206e8f..156f1f0 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,5 @@ .neptune *__pycache__ tools/*/build + +playground/ \ No newline at end of file diff --git a/config/example.json b/config/example.json index f1c6ca7..26353f0 100644 --- a/config/example.json +++ b/config/example.json @@ -6,6 +6,7 @@ "vect_size_limit": 512, "order": [["I"], ["!", "I", "NT"], ["!", "I"], ["V", "NT"]], "interchange_mode": "enumerate", + "use_img2col": true, "exploration": ["entropy"], "init_epsilon": 0.5, "normalize_bounds": "max", diff --git a/demo.ipynb b/demo.ipynb index 894cf2b..c189698 100644 --- a/demo.ipynb +++ b/demo.ipynb @@ -2,10 +2,21 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "2e74d0c8", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# Setup environment\n", "# import os\n", @@ -82,7 +93,7 @@ ], "metadata": { "kernelspec": { - "display_name": "mlir-rl", + "display_name": "llvm-build", "language": "python", "name": "python3" }, diff --git a/docs/ENV.md b/docs/ENV.md new file mode 100644 index 0000000..b36409e --- /dev/null +++ b/docs/ENV.md @@ -0,0 +1,7 @@ +export PATH=/scratch/kb5213/resources/llvm-project/build/bin:$PATH + +export PYTHONPATH=/scratch/kb5213/resources/llvm-project/build/tools/mlir/python_packages/mlir_core:$PYTHONPATH + +export MLIR_SHARED_LIBS=/scratch/kb5213/resources/llvm-project/build/lib/libomp.so,/scratch/kb5213/resources/llvm-project/build/lib/libmlir_c_runner_utils.so,/scratch/kb5213/resources/llvm-project/build/lib/libmlir_runner_utils.so + +"order": [["I"], ["!", "I", "NT"], ["!", "I"], ["V", "NT"]], \ No newline at end of file diff --git a/evaluate.py b/evaluate.py index c531750..ef2b826 100644 --- a/evaluate.py +++ b/evaluate.py @@ -71,6 +71,7 @@ def load_main_exec_data() -> Optional[dict[str, dict[str, int]]]: # Read the files in the evaluation directory eval_files = [f for f in os.listdir(eval_dir) if f.endswith('.pt')] +eval_files = [eval_files[-1]] # Only evaluate the last model # Order files eval_files.sort(key=lambda x: int(x.split('_')[1].split('.')[0])) diff --git a/get_base.py b/get_base.py index 62a5e25..8dde829 100644 --- a/get_base.py +++ b/get_base.py @@ -11,8 +11,13 @@ if not os.path.isdir(path_to_folder): print(f"Error: {path_to_folder} is not a valid directory.") sys.exit(1) + +with open(f"{path_to_folder}/../benchmarks_split.json", 'r') as f: + benchmarks_split = json.load(f) + +train_output_data = {} +eval_output_data = {} -output_data = {} exec = Execution("") code_files = [f for f in os.listdir(path_to_folder) if f.endswith('.mlir')] @@ -28,7 +33,13 @@ except Exception as e: print(f"Failed to execute {bench_name}: {e}") et = -1 - output_data[bench_name] = et - - with open('base_exec_times.json', 'w') as f: - json.dump(output_data, f, indent=4) + + if bench_name in benchmarks_split['train']: + train_output_data[bench_name] = et + with open(f"{path_to_folder}/../execution_times_train.json", 'w') as f: + json.dump(train_output_data, f, indent=4) + + elif bench_name in benchmarks_split['eval']: + 
eval_output_data[bench_name] = et + with open(f"{path_to_folder}/../execution_times_eval.json", 'w') as f: + json.dump(eval_output_data, f, indent=4) diff --git a/iql/__init__.py b/iql/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/iql/agent.py b/iql/agent.py new file mode 100644 index 0000000..2d70f70 --- /dev/null +++ b/iql/agent.py @@ -0,0 +1,282 @@ +import copy +from typing import Dict, List, Optional, Type, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from utils.config import Config +from rl_autoschedular.actions import ActionSpace +from rl_autoschedular.observation import Observation, ObservationPart, OpFeatures, ActionHistory + +from iql.value_function import IQLValueModel +from iql.policy import IQLPolicyModel +from iql.q_functions import IQLTwinQ + +cfg = Config() + +class IQLAgent(nn.Module): + """ + IQL agent adapted to the PPO-aligned architecture and hierarchical action space. + - Uses Observation.get_parts(obs, *obs_parts) + - Shared 3x512 backbone across policy/value/Q + - Hierarchical heads (action + per-action params) + """ + def __init__(self, obs_parts=None, param_dims=None): + super().__init__() + self.obs_parts = obs_parts or [OpFeatures, ActionHistory] + + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Use config hyperparameters + self.gamma = cfg.gamma + self.tau = cfg.tau + self.beta = cfg.beta + self.alpha = cfg.alpha + + # Networks + self.value_model = IQLValueModel(self.obs_parts, tau=self.tau).to(self.device) + self.policy_model = IQLPolicyModel(self.obs_parts).to(self.device) + self.q_model = IQLTwinQ(self.obs_parts).to(self.device) + + # Target Q + self.q_target = copy.deepcopy(self.q_model).to(self.device) + for p in self.q_target.parameters(): + p.requires_grad = False + + # Optimizers with cfg.lr dict (after models are on device) + self.value_optimizer = torch.optim.Adam(self.value_model.parameters(), lr=cfg.lr["value"]) + self.q_optimizer = torch.optim.Adam(self.q_model.parameters(), lr=cfg.lr["q"]) + self.policy_optimizer = torch.optim.Adam(self.policy_model.parameters(), lr=cfg.lr["policy"]) + """ + self.policy_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + self.policy_optimizer, + T_max=600000, + eta_min=1e-5 + ) + """ + + # --------- helpers to move inputs to device ---------- + def _to_device_tensor(self, x: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + if x is None: + return None + return x.to(self.device, non_blocking=True) + + def _to_device_tensor_list( + self, + xs: Optional[List[Optional[torch.Tensor]]] + ) -> Optional[List[Optional[torch.Tensor]]]: + if xs is None: + return None + out: List[Optional[torch.Tensor]] = [] + for t in xs: + out.append(self._to_device_tensor(t) if isinstance(t, torch.Tensor) else None if t is None else t) + return out + + # ------------------------ + # Action selection (hierarchical) + # ------------------------ + @torch.no_grad() + def sample( + self, + obs: torch.Tensor, + greedy: bool = False, + eps: Optional[float] = None + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Sample hierarchical action indices using the same API style as PPO. 
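+        If eps is given, the full hierarchical action is drawn from
+        ActionSpace.uniform_distributions(obs) with probability eps (epsilon-greedy);
+        otherwise it is sampled (or taken greedily) from the current policy.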
+ Returns: + actions_index: packed hierarchical indices (ActionSpace format) + actions_log_p: log-prob of sampled action under current policy + entropies: per-head entropies (aggregated by ActionSpace) + """ + + # Build distributions from policy + dists = self.policy_model(obs) + eps_dists = ActionSpace.uniform_distributions(obs) + + # Hierarchical sample + use_uniform = (eps is not None) and (torch.rand((), device=self.device).item() < eps) + actions_index = ActionSpace.sample( + obs, + dists, + eps_dists, + uniform=use_uniform, + greedy=greedy, + ) + + # Stats for the sampled actions + actions_log_p, entropies = ActionSpace.distributions_stats( + dists, + actions_index, + eps_distributions=eps_dists if eps is not None else None, + eps=eps, + ) + return actions_index, actions_log_p, entropies + + # ------------------------ + # Value update (expectile regression using target twin-Q) + # ------------------------ + def update_value( + self, + obs: torch.Tensor, + action_idx: torch.LongTensor, + *, + param_indices: Optional[List[Optional[torch.LongTensor]]] = None, + param_values: Optional[List[Optional[torch.Tensor]]] = None, + ) -> torch.Tensor: + """ + Updates V(s) by regressing towards min(Q1, Q2) from the *target* Q network. + """ + + with torch.no_grad(): + q1_t, q2_t = self.q_target(obs, action_idx) + q_min_t = torch.min(q1_t, q2_t) # [B] + + self.value_optimizer.zero_grad(set_to_none=True) + loss_v = self.value_model.loss(obs, q_min_t) + loss_v.backward() + self.value_optimizer.step() + return loss_v + + # ------------------------ + # Q update (TD with V(s')) + # ------------------------ + def update_q( + self, + obs: torch.Tensor, + action_idx: torch.LongTensor, + rewards: torch.Tensor, + next_obs: torch.Tensor, + dones: torch.Tensor + ) -> torch.Tensor: + """ + Update twin Q networks with TD target: + target_q = r + gamma * (1 - done) * V_target(s') + If target_v is not provided, it is computed from the current value_model. + """ + + with torch.no_grad(): + target_v = self.value_model(next_obs).to(self.device) # [B] + + target_q = rewards + self.gamma * (1.0 - dones) * target_v # [B] + + self.q_optimizer.zero_grad(set_to_none=True) + + + loss_q = self.q_model.loss( + obs, + action_idx, + target_q + ) + loss_q.backward() + self.q_optimizer.step() + return loss_q + + # ------------------------ + # Policy update (advantage-weighted BC) + # ------------------------ + def update_policy( + self, + obs: torch.Tensor, + actions_index: torch.Tensor, # packed hierarchical indices (as stored by dataset) + *, + action_idx: Optional[torch.LongTensor] = None, + param_indices: Optional[List[Optional[torch.LongTensor]]] = None, + param_values: Optional[List[Optional[torch.Tensor]]] = None, + ) -> torch.Tensor: + """ + Update policy with advantage-weighted log-likelihood: + weights = exp(A / beta), A = min(Q1, Q2) - V(s) + - actions_index is used to compute log π(a|s) via ActionSpace.distributions_stats(...) + - Q needs decomposed (action_idx, param_indices/values). 
+ """ + + # 1) log π(a|s) from hierarchical distributions + dists = self.policy_model(obs) + actions_log_p, _ = ActionSpace.distributions_stats(dists, actions_index) + + # 2) advantages = Q_min(s,a) - V(s) + assert action_idx is not None, "action_idx (top-level) is required for Q evaluation" + with torch.no_grad(): + q_min = self.q_model.q_values(obs, action_idx) # [B] + v = self.value_model(obs) # [B] + advantages = q_min - v # [B] + advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) + + + # 3) loss (AWAC/IQL style) + + # 1. zero gradients + self.policy_optimizer.zero_grad(set_to_none=True) + # 2. compute loss + loss_pi = self.policy_model.loss( + actions_log_p=actions_log_p, + advantages=advantages, + beta=self.beta, + ) + + # 3. backpropagate + loss_pi.backward() + + # 4. clip gradients (to avoid instability) + torch.nn.utils.clip_grad_norm_(self.policy_model.parameters(), max_norm=5.0) + + + self.policy_optimizer.step() + # self.policy_lr_scheduler.step() + return loss_pi + + # ------------------------ + # Soft update of target Q + # ------------------------ + @torch.no_grad() + def soft_update_q_target(self): + """ + θ_target ← α θ + (1-α) θ_target + """ + for p, tp in zip(self.q_model.parameters(), self.q_target.parameters()): + tp.data.copy_(self.alpha * p.data + (1.0 - self.alpha) * tp.data) + + def update(self, batch: Tuple[torch.Tensor, ...]) -> Dict[str, float]: + """ + One full IQL update step: + 1. Update Q-functions + 2. Update value function + 3. Update policy (AWAC/IQL style) + 4. Soft update target Q + Returns dict of losses for logging. + """ + obs, actions_index, rewards, next_obs, dones = (t.to(self.device, non_blocking=True) for t in batch) + + + # ---- 1) Update Q ---- + loss_q = self.update_q( + obs=obs, + action_idx=actions_index, # top-level index + rewards=rewards, + next_obs=next_obs, + dones=dones, + ) + + # ---- 2) Update Value ---- + loss_v = self.update_value(obs, actions_index) + + + # ---- 3) Update Policy ---- + loss_pi = self.update_policy( + obs=obs, + actions_index=actions_index, + action_idx=actions_index, # required for Q evaluation + ) + + + + # ---- 4) Soft update Q target ---- + self.soft_update_q_target() + + return { + "q": float(loss_q.item()), + "policy": float(loss_pi.item()), + "value": float(loss_v.item()), + } \ No newline at end of file diff --git a/iql/config.py b/iql/config.py new file mode 100644 index 0000000..e4276e7 --- /dev/null +++ b/iql/config.py @@ -0,0 +1,159 @@ +from dotenv import load_dotenv +load_dotenv() + + +import os +from utils.singleton import Singleton +import json +from typing import Literal + + + + +class Config(metaclass=Singleton): + """Class to store and load global configuration""" + + ############## IQL specific parameters ############## + gamma : float + """Discount factor""" + tau : float + """expectile regression parameter""" + inverse_temperature : float + """Inverse temperature for advantage-weighted regression""" + alpha : float + """target smoothing coefficient""" + batch_size : int + """Batch size for training""" + learning_rate : dict[str, float] + """Learning rate for the optimizer""" + max_steps : int + """Maximum number of training steps""" + target_update_freq : int + """Frequency of target network updates""" + sparse_reward : bool + """Flag to enable sparse reward""" + + offline_data_directory : str + """The offline data directory""" + offline_data_file : str + """The offline data file""" + + + ############## Environment specific parameters ############## + 
max_num_stores_loads: int + """The maximum number of loads in the nested loops""" + max_num_loops: int + """The max number of nested loops""" + max_num_load_store_dim: int + """The max number of dimensions in load/store buffers""" + num_tile_sizes: int + """The number of tile sizes""" + vect_size_limit: int + """Vectorization size limit to prevent large sizes vectorization""" + order: list[list[str]] + """The order of actions that needs to bo followed""" + interchange_mode: Literal['enumerate', 'pointers', 'continuous'] + """The method used for interchange action""" + exploration: list[Literal['entropy', 'epsilon']] + """The exploration method""" + init_epsilon: float + """The initial epsilon value for epsilon greedy exploration""" + + normalize_bounds: Literal['none', 'max', 'log'] + """Flag to indicate if the upper bounds in the input should be normalized or not""" + + + split_ops: bool + """Flag to enable splitting operations into separate benchmarks""" + + activation: Literal["relu", "tanh"] + """The activation function to use in the network""" + + benchmarks_folder_path: str + """Path to the benchmarks folder. Can be empty if optimization mode is set to "last".""" + + bench_count: int + """Number of batches in a trajectory""" + + truncate: int + """Maximum number of steps in the schedule""" + json_file: str + """Path to the JSON file containing the benchmarks execution times.""" + eval_json_file: str + """Path to the JSON file containing the benchmarks execution times for evaluation.""" + + tags: list[str] + """List of tags to add to the neptune experiment""" + + debug: bool + """Flag to enable debug mode""" + + exec_data_file: str + """Path to the file containing the execution data""" + results_dir: str + """Path to the results directory""" + + loaded: bool + """Flag to check if the config was already loaded from JSON file or not""" + + def __init__(self): + """Initialize the default values""" + # IQL specific parameters + self.gamma = 0.99 + self.tau = 0.7 + self.inverse_temperature = 3.0 + self.alpha = 0.005 + self.batch_size = 256 + self.learning_rate = { + "value": 3e-4, + "q": 3e-4, + "policy": 3e-4 + } + self.max_steps = 1000000 + self.target_update_freq = 1 + self.sparse_reward = True + + self.offline_data_directory = "./data" + self.offline_data_file = "offline_data.npz" + + # Environment specific parameters + self.max_num_stores_loads = 2 + self.max_num_loops = 4 + self.max_num_load_store_dim = 2 + self.num_tile_sizes = 2 + self.vect_size_limit = 16 + self.order = [] + self.interchange_mode = 'continuous' + self.exploration = ['entropy', 'epsilon'] + self.init_epsilon = 1.0 + self.normalize_bounds = 'log' + self.split_ops = False + self.activation = "relu" + self.benchmarks_folder_path = "./benchmarks" + self.bench_count = 1 + self.truncate = 20 + self.json_file = "./config/exec_times.json" + self.eval_json_file = "./config/exec_times.json" + self.tags = [] + self.debug = False + self.exec_data_file = "./data/exec_data.npz" + self.results_dir = "./results" + self.loaded = False + + def load_from_json(self): + """Load the configuration from the JSON file.""" + # Open the JSON file + with open(os.getenv("OFFLINE_RL_CONFIG_FILE_PATH"), "r") as f: + config = json.load(f) + # Set the configuration values + for key, value in config.items(): + if hasattr(self, key): + setattr(self, key, value) + + def to_dict(self): + """Convert the configuration to a dictionary.""" + return self.__dict__ + + def __str__(self): + """Convert the configuration to a string.""" + return 
str(self.to_dict()) \ No newline at end of file diff --git a/iql/policy.py b/iql/policy.py new file mode 100644 index 0000000..5c27a43 --- /dev/null +++ b/iql/policy.py @@ -0,0 +1,82 @@ +import torch +import torch.nn as nn +from torch.distributions import Distribution +from typing import Optional, List, Type +from rl_autoschedular import config as cfg +from rl_autoschedular.actions import ActionSpace, Interchange +from rl_autoschedular.observation import Observation, ObservationPart + +# Match PPO’s activation config +ACTIVATION = nn.ReLU if cfg.activation == "relu" else nn.Tanh + + +class IQLPolicyModel(nn.Module): + """ + IQL policy network, sharing architecture with PPO’s PolicyModel. + - Backbone: 3×512 MLP with ACTIVATION() + - Heads: one for action selection + one per action’s parameterization + - Output: list[Distribution], via ActionSpace.distributions + - Loss: BC loss with advantage-weighted log-likelihood (AWAC / IQL style) + """ + + def __init__(self, obs_parts: List[Type[ObservationPart]]): + super().__init__() + self.obs_parts = obs_parts + + + # Shared encoder + in_size = sum(part.size() for part in obs_parts) + self.backbone = nn.Sequential( + nn.Linear(in_size, 512), + ACTIVATION(), + nn.Linear(512, 512), + ACTIVATION(), + nn.Linear(512, 512), + ACTIVATION(), + ) + + # One head for action choice + one for each action’s params + output_sizes = [ActionSpace.size()] + [ + action.network_output_size() for action in ActionSpace.supported_actions + ] + self.heads_attributes = [f"head_{i}" for i in range(len(output_sizes))] + + for head_attr, output_size in zip(self.heads_attributes, output_sizes): + if not output_size: + setattr(self, head_attr, None) + continue + + head = nn.Linear(512, output_size) + if cfg.new_architecture: + head = nn.Sequential(nn.Linear(512, 512), ACTIVATION(), head) + setattr(self, head_attr, head) + + def forward(self, obs: torch.Tensor) -> List[Optional[Distribution]]: + """ + Forward pass: produce a Distribution object per action head. + """ + embedded = self.backbone(Observation.get_parts(obs, *self.obs_parts)) + heads: List[Optional[nn.Module]] = [getattr(self, attr) for attr in self.heads_attributes] + actions_logits = [head(embedded) if head else None for head in heads] + return ActionSpace.distributions(obs, *actions_logits) + + def loss( + self, + actions_log_p: torch.Tensor, + advantages: torch.Tensor, + beta: float = 1.0, + ) -> torch.Tensor: + """ + Advantage-weighted behavioral cloning (AWAC) / IQL policy loss. + Args: + obs: Observations [B, ...] + actions_log_p: log π(a|s) from this policy, evaluated at dataset actions + advantages: Advantage estimates A(s,a) from IQL (Q - V) + beta: Temperature scaling (larger beta = more deterministic) + Returns: + Scalar loss tensor. 
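+        Concretely: weights = clamp(exp(advantages / beta), max=100) and
+        loss = -mean(weights * actions_log_p), i.e. the advantage weights are
+        clipped before weighting the log-likelihood for numerical stability.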
+ """ + # Weights = exp(A / beta), clipped for stability + weights = torch.exp(advantages / beta).clamp(max=100.0) + loss = -(weights * actions_log_p).mean() + return loss \ No newline at end of file diff --git a/iql/q_functions.py b/iql/q_functions.py new file mode 100644 index 0000000..881713a --- /dev/null +++ b/iql/q_functions.py @@ -0,0 +1,281 @@ +from dotenv import load_dotenv +load_dotenv(override=True) + + +import torch +import torch.nn as nn +from typing import List, Optional, Tuple, Type + +# reuse same Observation/ActionSpace imports you have in repo +from rl_autoschedular.observation import Observation, ObservationPart, OpFeatures, ActionHistory +from rl_autoschedular.actions import ActionSpace +from rl_autoschedular.state import OperationState +from rl_autoschedular import config as cfg + +# Keep same activation used elsewhere +ACTIVATION = nn.ReLU # replace with your actual ACTIVATION if different + + +class _DiscreteQHead(nn.Module): + """A head that outputs Q-values for each discrete option (flat logits -> Q-values).""" + def __init__(self, in_dim: int, out_dim: int): + super().__init__() + # single linear mapping from embedding -> out_dim Q-values + self.net = nn.Linear(in_dim, out_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # returns [B, out_dim] + return self.net(x) + + +class _TwinHiearchicalQNetwork(nn.Module): + """ + One Q network: + - backbone: embed observation -> [B, embed_dim] + - head_action: Q for selecting each action type [B, |A|] + - param_heads: for each action, a head producing Q-values for that action's network_output_size + NOTE: For multi-slot actions (e.g. Tiling), param_head outputs are flat: a concatenation + of per-slot Q-values. We'll reshape inside q_contribs when computing param-contribution. 
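+    Example: a multi-slot parameter head with params_size=S and classes_per_slot=C emits
+    S*C values per sample; q_contribs reshapes them to [N, S, C], gathers one Q-value
+    per slot and sums over slots.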
+ """ + def __init__(self, obs_parts: List[Type[ObservationPart]], embed_dim: int = 512): + super().__init__() + self.obs_parts = obs_parts + in_size = sum(p.size() for p in obs_parts) + + # Backbone + self.backbone = nn.Sequential( + nn.Linear(in_size, embed_dim), + ACTIVATION(), + nn.Linear(embed_dim, embed_dim), + ACTIVATION(), + nn.Linear(embed_dim, embed_dim), + ACTIVATION(), + ) + + # Action selection Q head + self.head_action = _DiscreteQHead(embed_dim, ActionSpace.size()) + + # Parameter heads (one per supported action) + self.param_heads = nn.ModuleList() + # Save meta per action for convenient reshaping later + self._param_meta: List[Optional[dict]] = [] + for action_cls in ActionSpace.supported_actions: + out_dim = action_cls.network_output_size() + params_size = action_cls.params_size() + if out_dim and out_dim > 0: + head = _DiscreteQHead(embed_dim, out_dim) + self.param_heads.append(head) + # if multi-slot, compute classes_per_slot = out_dim // params_size (integer) + if params_size > 0: + classes_per_slot = None + if params_size > 0: + classes_per_slot = out_dim // params_size if params_size != 0 else None + self._param_meta.append({ + "params_size": params_size, # number of slots for this action + "out_dim": out_dim, + "classes_per_slot": classes_per_slot # number of classes per slot (None if single slot) + # Example : Interchange -> params_size=1 , out_dim= 7 , classes_per_slot= 7 , 7 loop choices for the current interchange + }) + else: + self._param_meta.append({"params_size": 0, "out_dim": out_dim, "classes_per_slot": None}) + else: + # no parameters -> placeholder (we'll treat as None) + # ( NT | V ) + self.param_heads.append(nn.Identity()) + self._param_meta.append(None) + + def forward( + self, + obs: torch.Tensor, + action_idx: torch.LongTensor, + param_indices: Optional[List[Optional[torch.LongTensor]]] = None + ) -> torch.Tensor: + emb = self._embed(obs) + return self.q_contribs(emb, action_idx, param_indices) + + def _embed(self, obs: torch.Tensor) -> torch.Tensor: + parts = Observation.get_parts(obs, *self.obs_parts) # returns [B, in_size] + return self.backbone(parts) # [B, embed_dim] + + def q_contribs( + self, + emb: torch.Tensor, + action_idx: torch.LongTensor, # [B] + param_indices: Optional[List[Optional[torch.LongTensor]]] = None # list of length B + ) -> torch.Tensor: + """ + Compute Q(s, a, params) as: Q_action(s)[a] + Q_params(s, a, params). 
+ - emb: [B, embed_dim] + - action_idx: [B] integers in [0..|A|-1] + - param_indices: list of length B (each either None or [params_size]) + Returns: + q_total: [B] - scalar Q for each sample + """ + B = emb.size(0) + device = emb.device + + # ---- top-level action contribution ---- + act_qs = self.head_action(emb) # [B, |A|] + act_q = act_qs.gather(1, action_idx.view(-1, 1)).squeeze(1) # [B] + + # ---- parameter contribution ---- + param_q = torch.zeros(B, device=device) + + # group samples by chosen action to do batched head computation + for k, head in enumerate(self.param_heads): + meta = self._param_meta[k] + if isinstance(head, nn.Identity) or (meta is None): + continue + + # mask = all samples where chosen action == k + mask = (action_idx == k) + if not mask.any(): + continue + + # get embeddings and their chosen param indices for this action + emb_masked = emb[mask] + head_out = head(emb_masked) # [N_mask, out_dim_k] + + psize = meta["params_size"] # + out_dim = meta["out_dim"] + cps = meta["classes_per_slot"] + + # collect just the param indices for the masked samples + masked_params = [param_indices[i] for i in range(B) if mask[i]] # [N_mask, Optional[LongTensor] of consistet size psize] + # they should all be not None if this action has params + assert all((p is not None) for p in masked_params) or psize == 0 + + if psize == 0: + continue + + if psize == 1: + # single-slot + idx_tensor = torch.stack([p.view(-1)[0] for p in masked_params]).view(-1, 1) # [N_mask, 1] + q_k = head_out.gather(1, idx_tensor).squeeze(1) # [N_mask] + param_q[mask] = q_k + else: + # multi-slot + assert cps is not None and cps > 0, "classes_per_slot unknown for multi-slot head" + + reshaped = head_out.view(-1, psize, cps) # [N_mask, psize, cps] + idx_tensor = torch.stack(masked_params).long() # [N_mask, psize] + idx_exp = idx_tensor.unsqueeze(-1) # [N_mask, psize, 1] + gathered = torch.gather(reshaped, dim=2, index=idx_exp).squeeze(-1) # [N_mask, psize] + q_k = gathered.sum(dim=1) # [N_mask] + param_q[mask] = q_k + + return act_q + param_q + + + + + +# ---------------- hierarchical double Q network ---------------- +class IQLTwinQ(nn.Module): + """ + Top-level twin Q network that matches style of PolicyModel: + - builds two Q-branches (Q1 and Q2), each using _TwinHiearchicalQNetwork + - provides helpers to split flat action-tensor into action_idx + param slices + """ + def __init__(self, obs_parts: List[Type[ObservationPart]], embed_dim: int = 512): + super().__init__() + self.obs_parts = obs_parts + # instantiate two Q heads (Q1, Q2) + self.q1 = _TwinHiearchicalQNetwork(obs_parts, embed_dim=embed_dim) + self.q2 = _TwinHiearchicalQNetwork(obs_parts, embed_dim=embed_dim) + + + def forward(self, obs: torch.Tensor, index: torch.LongTensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Compute Q1(s, a, params) and Q2(s, a, params) for a batch. + - obs: observation tensor [B, ...] + - index: flat index tensor from ActionSpace.sample() -> [B, 1 + sum(params)] + Returns: + (q1_vals, q2_vals) each shaped [B] + """ + action_idx, params_list = self._split_action_tensor(index) + q1_vals = self.q1(obs, action_idx, params_list) + q2_vals = self.q2(obs, action_idx, params_list) + return q1_vals, q2_vals + + def q_values(self, obs: torch.Tensor, index: torch.LongTensor) -> torch.Tensor: + """ + Compute min(Q1(s, a, params), Q2(s, a, params)) for a batch. + - obs: observation tensor [B, ...] 
+ - index: flat index tensor from ActionSpace.sample() -> [B, 1 + sum(params)] + Returns: + q_vals shaped [B] + """ + action_idx, params_list = self._split_action_tensor(index) + q1_vals = self.q1(obs, action_idx, params_list) + q2_vals = self.q2(obs, action_idx, params_list) + return torch.min(q1_vals, q2_vals) + + + def loss(self, obs: torch.Tensor, index: torch.LongTensor, target_q: torch.Tensor) -> torch.Tensor: + """ + Compute MSE loss between Q1, Q2 and target_q. + - obs: observation tensor [B, ...] + - index: flat index tensor from ActionSpace.sample() -> [B, 1 + sum(params)] + - target_q: target Q-values [B] + Returns: + scalar loss + """ + q1_vals, q2_vals = self.forward(obs, index) # each [B] + loss_fn = nn.MSELoss() + loss1 = loss_fn(q1_vals, target_q) + loss2 = loss_fn(q2_vals, target_q) + return loss1 + loss2 + + @staticmethod + def _split_action_tensor(index: torch.LongTensor) -> Tuple[torch.LongTensor, List[Optional[torch.LongTensor]]]: + """ + Split the `index` tensor returned by ActionSpace.sample() into: + - action_idx: [B] + - params: list of length B, each either None (no params) or LongTensor [params_size] for that action + """ + B = index.size(0) + device = index.device + + action_idx = index[:, 0].long() # [B] + cum = ActionSpace.cumulative_params_sizes() + + params: List[Optional[torch.LongTensor]] = [] + + for i in range(B): + a_idx = action_idx[i].item() + action_type = ActionSpace.supported_actions[a_idx] + if action_type.params_size() == 0: + params.append(None) + else: + start, end = cum[a_idx], cum[a_idx + 1] + # extract just that sample's params for its chosen action + params.append(index[i, start:end].long()) + + return action_idx, params + + + +def main(): + model = IQLTwinQ([OpFeatures, ActionHistory]) + + _model = _TwinHiearchicalQNetwork([OpFeatures, ActionHistory]) + + x = torch.tensor([[2, 1, 5, 7, 0, 0, 0, 0, 3, 4, 2, 0, 0, 0, 0, 2]]).float() + + + action_idx , param_idx = model._split_action_tensor(x) + + + obs = torch.zeros([1, 2152]) + + + q = _model(obs, action_idx, param_idx) + + print("Q-values:", q) + + + +if __name__ == "__main__": + + main() \ No newline at end of file diff --git a/iql/singleton.py b/iql/singleton.py new file mode 100644 index 0000000..252d3aa --- /dev/null +++ b/iql/singleton.py @@ -0,0 +1,8 @@ +class Singleton(type): + """Meta class to create a singleton instance of a class""" + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] \ No newline at end of file diff --git a/iql/value_function.py b/iql/value_function.py new file mode 100644 index 0000000..ae524b0 --- /dev/null +++ b/iql/value_function.py @@ -0,0 +1,69 @@ +import torch +import torch.nn as nn +from typing import List, Type +from rl_autoschedular import config as cfg +from rl_autoschedular.observation import Observation, ObservationPart + + +ACTIVATION = nn.ReLU + + +class IQLValueModel(nn.Module): + """ + IQL Value function with the SAME encoder/MLP layout as PPO's ValueModel: + Linear(sum(obs_parts)->512) -> ACT -> 512 -> ACT -> 512 -> ACT -> 1 + - Input: full Observation tensor, then sliced via Observation.get_parts to match PPO. + - Output: V(s) as shape [B], same squeeze(-1) behavior as PPO. + - Loss: Expectile regression with parameter tau (IQL). 
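+    The expectile weight is |tau - 1(u < 0)| with u = Q(s,a) - V(s), so for tau > 0.5
+    the value network is penalized more for underestimating Q than for overestimating it.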
+ """ + + def __init__( + self, + obs_parts: List[Type[ObservationPart]], + tau: float = 0.7, + ): + super().__init__() + self.obs_parts = obs_parts + self.tau = cfg.tau # consider wiring this from cfg (e.g., cfg.iql.tau) if you keep hyperparams in config + + in_size = sum(part.size() for part in obs_parts) + self.network = nn.Sequential( + nn.Linear(in_size, 512), + ACTIVATION(), + nn.Linear(512, 512), + ACTIVATION(), + nn.Linear(512, 512), + ACTIVATION(), + nn.Linear(512, 1), + ) + + def forward(self, obs: torch.Tensor) -> torch.Tensor: + """ + Args: + obs: full Observation tensor (like in PPO) + Returns: + V(s) as [B] + """ + x = Observation.get_parts(obs, *self.obs_parts) + return self.network(x).squeeze(-1) # [B] + + @torch.no_grad() + def v(self, obs: torch.Tensor) -> torch.Tensor: + """Convenience alias often used in IQL codepaths.""" + return self.forward(obs) + + def loss(self, obs: torch.Tensor, q_values: torch.Tensor) -> torch.Tensor: + """ + Expectile regression loss: minimize E[ w_tau(u) * u^2 ], u = Q(s,a) - V(s) + Args: + obs: full Observation tensor for states [B, ...] (same as PPO input) + q_values: [B] or [B,1] tensor with target Q(s,a) (DETACHED upstream in IQL) + """ + v = self.forward(obs) # [B] + q = q_values.squeeze(-1) # [B] + diff = q - v # u + + # weight = |tau - 1(u < 0)| + # same as: tau if u >= 0 else (1 - tau) + weight = torch.abs(self.tau - (diff < 0).float()) + return (weight * diff.pow(2)).mean() \ No newline at end of file diff --git a/neptune_sync.py b/neptune_sync.py index 99785ee..a97ad7c 100644 --- a/neptune_sync.py +++ b/neptune_sync.py @@ -1,6 +1,4 @@ -# Load environment variables -from dotenv import load_dotenv -load_dotenv(override=True) +from utils.keys import NEPTUNE_TOKEN, NEPTUNE_PROJECT import neptune from neptune import Run @@ -33,7 +31,8 @@ with open(os.path.join(run_path, 'tags'), 'r') as f: tags = f.read().splitlines() neptune_run = neptune.init_run( - project=os.getenv('NEPTUNE_PROJECT'), + project=NEPTUNE_PROJECT, + api_token=NEPTUNE_TOKEN, tags=tags, ) neptune_runs[run] = neptune_run diff --git a/rl_autoschedular/actions/__init__.py b/rl_autoschedular/actions/__init__.py index 208c1d3..5bea362 100644 --- a/rl_autoschedular/actions/__init__.py +++ b/rl_autoschedular/actions/__init__.py @@ -6,6 +6,7 @@ from .tiled_fusion import TiledFusion from .interchange import Interchange from .vectorization import Vectorization +from .img2col import Img2Col from rl_autoschedular.state import OperationState import torch from torch.distributions import Distribution, Categorical @@ -21,7 +22,8 @@ class ActionSpace: TiledParallelization, TiledFusion, Interchange, - Vectorization + Vectorization, + Img2Col ] @classmethod diff --git a/rl_autoschedular/actions/img2col.py b/rl_autoschedular/actions/img2col.py new file mode 100644 index 0000000..1775f43 --- /dev/null +++ b/rl_autoschedular/actions/img2col.py @@ -0,0 +1,33 @@ +from utils.config import Config +from .base import Action +from rl_autoschedular.transforms import transform_img2col +from rl_autoschedular.state import OperationFeatures, OperationState, OperationType +from typing import Callable, Optional + + +class Img2Col(Action): + """Class representing Img2Col action""" + + symbol = 'I2C' + parameters: None + + def __init__( + self, + state: Optional[OperationState] = None, + **extras + ): + super().__init__( + state, + **extras + ) + + def _apply_ready(self, code): + original_code = code + try: + return transform_img2col(code, self.operation_tag) + except Exception: + return original_code + + 
@classmethod + def is_allowed(cls, state: OperationState): + return state.operation_features.operation_type == OperationType.Conv diff --git a/rl_autoschedular/benchmarks.py b/rl_autoschedular/benchmarks.py index b19671a..2194d13 100644 --- a/rl_autoschedular/benchmarks.py +++ b/rl_autoschedular/benchmarks.py @@ -36,7 +36,7 @@ def __init__(self, is_training: bool = True): modified = False bench_code = benchmark_data.code for op_tag in benchmark_data.operation_tags: - if 'conv_2d' not in benchmark_data.operations[op_tag].operation_name: + if 'conv_2d' not in benchmark_data.operations[op_tag].operation_name or not cfg.use_img2col: continue bench_code = transform_img2col(bench_code, op_tag) modified = True diff --git a/rl_autoschedular/env.py b/rl_autoschedular/env.py index 019c9e2..150ae71 100644 --- a/rl_autoschedular/env.py +++ b/rl_autoschedular/env.py @@ -2,13 +2,14 @@ from rl_autoschedular.benchmarks import Benchmarks from typing import Optional from rl_autoschedular.execution import Execution -from rl_autoschedular.actions import Action, TiledFusion +from rl_autoschedular.actions import Action, TiledFusion, Img2Col from utils.log import print_error from utils.config import Config import random import math import traceback +from rl_autoschedular.state import extract_bench_features_from_code class Env: """RL Environment class""" @@ -48,6 +49,7 @@ def step(self, state: OperationState, action: Action) -> OperationState: bool: A flag indicating if the operation is done. Optional[float]: The speedup (if the operation is executed successfully) for logging purposes. """ + # Copy the current state to introduce the changes throughout the function next_state = state.copy() @@ -249,6 +251,21 @@ def __update_state_infos(self, state: OperationState, action: Action): # In case of fusion we need to update the producer features as well if isinstance(action, TiledFusion): action.update_producer_features(state, self.benchmark_data) + + # In case of Img2Col, we need to update the benchmark data as whole + if isinstance(action, Img2Col): + self.benchmark_data = extract_bench_features_from_code(self.benchmark_data.bench_name, action.apply(self.benchmark_data.code), self.benchmark_data.root_exec_time) + i2c_state = self.__init_op_state(-1) + + state.bench_idx = i2c_state.bench_idx + state.bench_name = i2c_state.bench_name + state.operation_tag = i2c_state.operation_tag + state.original_operation_features = i2c_state.original_operation_features + state.operation_features = i2c_state.operation_features + state.producer_tag = i2c_state.producer_tag + state.producer_operand_idx = i2c_state.producer_operand_idx + state.producer_features = i2c_state.producer_features + state.terminal = i2c_state.terminal # Get updated operation features state.operation_features = action.update_features(state.operation_features) diff --git a/rl_autoschedular/ppo.py b/rl_autoschedular/ppo.py index 9a73803..a46946f 100644 --- a/rl_autoschedular/ppo.py +++ b/rl_autoschedular/ppo.py @@ -19,7 +19,7 @@ from typing import Optional -def collect_trajectory(data: Benchmarks, model: Model, step: int): +def collect_trajectory(data: Benchmarks, model: Model, step: int) -> TrajectoryData: """Collect a trajectory using the model and the environment. 
Args: @@ -34,7 +34,7 @@ def collect_trajectory(data: Benchmarks, model: Model, step: int): dm = DaskManager() fl = FileLogger() exe = Execution() - cfg = Config() + cfg = Config() eps = None if 'epsilon' in cfg.exploration: @@ -362,7 +362,7 @@ def evaluate_benchmarks(model: Model, data: Benchmarks): def __execute_states(state: OperationState, exec_data_file: str, benchs: Benchmarks, main_exec_data: Optional[dict[str, dict[str, int]]]): - print(f"Handling bench: {state.bench_name}...", end=' ', file=sys.stderr) + print(f"Handling bench: {state.bench_name} with {[[action.__str__() for action in transformation] for transformation in state.transformation_history]}", end=' ', file=sys.stderr) worker_start = time() Execution(exec_data_file, main_exec_data) diff --git a/rl_autoschedular/state.py b/rl_autoschedular/state.py index f2ed3d6..dbfb04a 100644 --- a/rl_autoschedular/state.py +++ b/rl_autoschedular/state.py @@ -8,6 +8,8 @@ from utils.config import Config from utils.log import print_error +from utils.keys import AST_DUMPER_BIN_PATH + if TYPE_CHECKING: from rl_autoschedular.actions.base import Action @@ -44,6 +46,16 @@ class NestedLoopFeatures: def copy(self): """Copy the current NestedLoopFeatures object.""" return NestedLoopFeatures(self.arg, self.lower_bound, self.upper_bound, self.step, self.iterator_type) + + def to_dict(self): + """Convert the NestedLoopFeatures to a dictionary.""" + return { + "arg": self.arg, + "lower_bound": self.lower_bound, + "upper_bound": self.upper_bound, + "step": self.step, + "iterator_type": self.iterator_type.value + } @dataclass @@ -84,6 +96,23 @@ def copy(self): self.vectorizable, self.pre_actions.copy() ) + + def to_dict(self): + """Convert the OperationFeatures to a dictionary.""" + return { + "operation_name": self.operation_name, + "operation_type": self.operation_type.value, + "op_count": self.op_count, + "load_data": self.load_data, + "store_data": self.store_data, + "nested_loops": [ + loop.to_dict() for loop in self.nested_loops + ], + "producers": self.producers, + "consumers": self.consumers, + "vectorizable": self.vectorizable, + "pre_actions": [action.symbol for action in self.pre_actions] + } @dataclass @@ -109,6 +138,18 @@ def copy(self): {tag: op.copy() for tag, op in self.operations.items()}, self.root_exec_time ) + + def to_dict(self): + """Convert the BenchmarkFeatures to a dictionary.""" + return { + "bench_name": self.bench_name, + "code": self.code, + "operation_tags": self.operation_tags, + "operations": { + tag: op.to_dict() for tag, op in self.operations.items() + }, + "root_exec_time": self.root_exec_time + } @dataclass @@ -173,6 +214,23 @@ def copy(self): self.terminal ) + def to_dict(self): + """Convert the OperationState to a dictionary.""" + return { + "bench_idx": self.bench_idx, + "bench_name": self.bench_name, + "operation_tag": self.operation_tag, + "original_operation_features": self.original_operation_features.to_dict(), + "operation_features": self.operation_features.to_dict(), + "producer_tag": self.producer_tag, + "producer_operand_idx": self.producer_operand_idx, + "producer_features": self.producer_features.to_dict() if self.producer_features is not None else None, + "transformation_history": [ + [action.__repr__() for action in seq] for seq in self.transformation_history + ], + "terminal": self.terminal + } + def extract_bench_features_from_code(bench_name: str, code: str, root_execution_time: int): """Extract benchmark features from the given code. 
@@ -187,7 +245,7 @@ def extract_bench_features_from_code(bench_name: str, code: str, root_execution_ BenchmarkFeatures: the extracted benchmark features """ result = subprocess.run( - f'{os.getenv("AST_DUMPER_BIN_PATH")} -', + f'{AST_DUMPER_BIN_PATH} -', shell=True, input=code.encode('utf-8'), stdout=subprocess.PIPE, @@ -211,7 +269,7 @@ def extract_bench_features_from_file(bench_name: str, file_path: str, root_execu BenchmarkFeatures: the extracted benchmark features """ result = subprocess.run( - f'{os.getenv("AST_DUMPER_BIN_PATH")} {file_path}', + f'{AST_DUMPER_BIN_PATH} {file_path}', shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE diff --git a/rl_autoschedular/trajectory.py b/rl_autoschedular/trajectory.py index 57a9033..303ae2c 100644 --- a/rl_autoschedular/trajectory.py +++ b/rl_autoschedular/trajectory.py @@ -255,6 +255,7 @@ def update_attributes(self, model: Model): """ start = time() + actions_old_log_p, values, _ = model(self.obs.to(device), self.actions_index.to(device)) self.actions_old_log_p, self.values = actions_old_log_p.cpu(), values.cpu() @@ -404,3 +405,12 @@ def reset(self): self.actions_bev_log_p.clear() self.rewards.clear() self.done.clear() + + def __len__(self) -> int: + """Get the length of the collected trajectory. + + Returns: + int: The length of the collected trajectory. + """ + assert len(self.num_loops) == len(self.actions_index) == len(self.obs) == len(self.next_obs) == len(self.actions_bev_log_p) == len(self.rewards) == len(self.done) + return len(self.obs) \ No newline at end of file diff --git a/rl_autoschedular/transforms.py b/rl_autoschedular/transforms.py index e59451c..89791cd 100644 --- a/rl_autoschedular/transforms.py +++ b/rl_autoschedular/transforms.py @@ -366,7 +366,6 @@ def transform_bufferize_and_lower_v(code: str): Args: code (str): The code to apply the transformation to. - operation_tag (str): The tag of the operation to apply the transformation to. Returns: str: The code after applying the transformation. 
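The to_dict() helpers added above make the feature dataclasses JSON-serializable end to end. A minimal sketch of chaining them with extract_bench_features_from_code (the benchmark name, .mlir path, and execution time below are placeholders, and AST_DUMPER_BIN_PATH must point at a built ast_dumper binary):

import json
from rl_autoschedular.state import extract_bench_features_from_code

# Placeholder inputs, for illustration only.
with open("benchmarks/example_bench.mlir") as f:
    code = f.read()
bench = extract_bench_features_from_code("example_bench", code, 0)

# BenchmarkFeatures.to_dict() recursively serializes each operation's nested loops,
# load/store data, producers/consumers and pre_actions into plain JSON.
with open("playground/example_bench_features.json", "w") as f:  # playground/ is gitignored above
    json.dump(bench.to_dict(), f, indent=4)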
diff --git a/scripts/neptune-sync.sh b/scripts/neptune-sync.sh index e46d1ce..6cf75a5 100644 --- a/scripts/neptune-sync.sh +++ b/scripts/neptune-sync.sh @@ -2,23 +2,24 @@ # Define the resource requirements here using #SBATCH +# SBATCH -j neptune_sync #SBATCH -p compute #SBATCH --nodes=1 #SBATCH -c 4 #SBATCH --mem=16G #SBATCH -t 07-00 #SBATCH -o logs/neptune/%j.out -#SBATCH --mail-type=FAIL,TIME_LIMIT -#SBATCH --mail-user=mt5383@nyu.edu +#SBATCH --mail-type=ALL +#SBATCH --mail-user=kb5213@nyu.edu -# Resource requiremenmt commands end here +# Resource requirements end here -#Add the lines for running your code/application +# Add the lines for running your code/application module load miniconda-nobashrc eval "$(conda shell.bash hook)" # Activate any environments if required -conda activate testenv +conda activate llvm-build # Execute the code -python $SCRATCH/MLIR-RL/neptune_sync.py +python $SCRATCH/workspace/MLIR-RL/neptune_sync.py diff --git a/tests/cuda.py b/tests/cuda.py new file mode 100644 index 0000000..f4afc62 --- /dev/null +++ b/tests/cuda.py @@ -0,0 +1,48 @@ +import torch +import platform +import subprocess + +def check_cuda(): + # Check if CUDA is available + cuda_available = torch.cuda.is_available() + print(f"CUDA Available: {cuda_available}") + + if cuda_available: + # Number of GPUs available + num_gpus = torch.cuda.device_count() + print(f"Number of GPUs: {num_gpus}") + + # CUDA version from torch + print(f"PyTorch CUDA Version: {torch.version.cuda}") + + # Information about each GPU + for i in range(num_gpus): + print(f"\nDevice {i}:") + print(f" Name: {torch.cuda.get_device_name(i)}") + print(f" Total Memory: {torch.cuda.get_device_properties(i).total_memory / (1024 ** 3):.2f} GB") + print(f" Memory Allocated: {torch.cuda.memory_allocated(i) / (1024 ** 3):.2f} GB") + print(f" Memory Cached: {torch.cuda.memory_reserved(i) / (1024 ** 3):.2f} GB") + print(f" Compute Capability: {torch.cuda.get_device_capability(i)}") + print(f" Device Properties: {torch.cuda.get_device_properties(i)}") + + # Driver version + print(f"Driver Version: {torch.version.cuda}") + + # Run nvidia-smi command to get additional GPU details + try: + nvidia_smi_output = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, text=True) + print(f"\nNVIDIA-SMI Output:\n{nvidia_smi_output.stdout}") + except FileNotFoundError: + print("nvidia-smi command not found. 
Ensure NVIDIA drivers are installed correctly.") + + else: + print("No CUDA-enabled GPU detected.") + + # Print system and Python environment information + print("\nSystem and Environment Information:") + print(f" Operating System: {platform.system()} {platform.release()} ({platform.version()})") + print(f" Python Version: {platform.python_version()}") + print(f" PyTorch Version: {torch.__version__}") + +if __name__ == "__main__": + check_cuda() \ No newline at end of file diff --git a/train.py b/train.py index 1c3b85f..c75be67 100644 --- a/train.py +++ b/train.py @@ -98,6 +98,16 @@ def load_main_exec_data() -> Optional[dict[str, dict[str, int]]]: # Collect trajectory using the model trajectory = collect_trajectory(train_data, model, step) + + # print the sizes of the trajectory components for debugging + print_info( + f"Trajectory collected: " + f"Obs size: {trajectory.obs.size()}, " + f"Next Obs size: {trajectory.next_obs.size()}, " + f"Actions size: {trajectory.actions_index.size()}, " + f"Rewards size: {trajectory.rewards.size()}, ", + flush=True + ) # Extend trajectory with previous trajectory if cfg.reuse_experience != 'none': @@ -117,7 +127,7 @@ def load_main_exec_data() -> Optional[dict[str, dict[str, int]]]: ppo_update(trajectory, model, optimizer) # Save the model - if (step + 1) % 5 == 0: + if (step + 1) % cfg.save_model_every == 0: torch.save( model.state_dict(), os.path.join( @@ -126,7 +136,7 @@ def load_main_exec_data() -> Optional[dict[str, dict[str, int]]]: ) ) - if (step + 1) % 100 == 0: + if (step + 1) % cfg.evaluate_every == 0: print_info('- Evaluating benchmarks -') evaluate_benchmarks(model, eval_data) @@ -138,6 +148,6 @@ def load_main_exec_data() -> Optional[dict[str, dict[str, int]]]: elapsed_dlt = timedelta(seconds=int(elapsed)) eta_dlt = timedelta(seconds=int(eta)) -if (step + 1) % 100 != 0: +if (step + 1) % cfg.evaluate_every != 0: print_info('- Evaluating benchmarks -') evaluate_benchmarks(model, eval_data) diff --git a/train_iql_offline.py b/train_iql_offline.py new file mode 100644 index 0000000..f1f6c38 --- /dev/null +++ b/train_iql_offline.py @@ -0,0 +1,208 @@ +import dotenv + +dotenv.load_dotenv() + +import os +import time +import torch +import numpy as np + +from rl_autoschedular import config as cfg, file_logger as fl +from utils.config import Config +from utils.file_logger import FileLogger +from rl_autoschedular.actions import ActionSpace +from rl_autoschedular.env import Env +from iql.agent import IQLAgent +from utils.data_collector import OfflineDataset +from rl_autoschedular.observation import Observation,OpFeatures, ActionHistory + +from tqdm import trange + +device = torch.device("cpu") + +cfg = Config() +fl = FileLogger() + +def load_dataset(): + """Load offline dataset from OfflineDataset singleton.""" + dataset = OfflineDataset( + save_dir=cfg.offline_data_save_dir, + fname=cfg.offline_data_file + ).load() + + if not dataset: + raise FileNotFoundError(f"Offline dataset not found: {cfg.offline_data_file}") + + states = torch.tensor(dataset["obs"], dtype=torch.float32) + actions = torch.tensor(dataset["actions"], dtype=torch.long) + rewards = torch.tensor(dataset["rewards"], dtype=torch.float32) + next_states = torch.tensor(dataset["next_obs"], dtype=torch.float32) + dones = torch.tensor(dataset["dones"], dtype=torch.float32) + + return states, actions, rewards, next_states, dones + + +@torch.no_grad() +def evaluate_benchmarks(model: IQLAgent, env: Env, step: int): + """Evaluta a the model on the evaluation environment. 
+ Args: + model (Model): The policy/value model. + env (Env): The environment. + step (int): Current training step. + Returns: + env_time (float): Time spent in environment steps. + """ + + + env_time = 0.0 # Time spent in environment steps + + eps = None + + + # store rewards and entropies to log average for the model accross the benchmarks later + all_speedups = [] + all_entropies = [] + + + for _ in trange(cfg.bench_count, desc='Trajectory'): + + t0 = time.perf_counter() + state = env.reset() + env_time += time.perf_counter() - t0 + bench_done = False + speedup = None + + # store rewards and entropies to log average for the current benchmark later + bench_rewards, bench_entropies = [], [] + + bench_name = state.bench_name + + + while not bench_done: + obs = Observation.from_state(state) + + # Sample action and log-prob from *current policy* + action_index, action_log_p, entropy = model.sample(obs.to(device)) + assert action_index.size(0) == 1 and action_log_p.size(0) == 1 + action = ActionSpace.action_by_index(action_index[0], state) + + # Step environment + t0 = time.perf_counter() + next_state, reward, op_done, speedup = env.step(state, action) + env_time += time.perf_counter() - t0 + next_obs = Observation.from_state(next_state) + + + if op_done: + t0 = time.perf_counter() + next_state, bench_done = env.get_next_op_state(next_state) + env_time += time.perf_counter() - t0 + + + # Accumulate metrics + bench_rewards.append(reward) + bench_entropies.append(entropy.item()) + state = next_state + + # === Per-benchmark logging === + mean_reward = float(np.mean(bench_rewards)) if bench_rewards else 0.0 + mean_entropy = float(np.mean(bench_entropies)) if bench_entropies else 0.0 + + all_speedups.append(speedup) + all_entropies.extend(bench_entropies) + + + bench_metrics = { + "mean_reward": mean_reward, + "mean_entropy": mean_entropy, + "final_speedup": speedup if speedup is not None else 0.0, + } + + fl.log_scalars(f"eval/{bench_name}", bench_metrics, step) + + print( + f"\033[92m\n- Eval Bench: {bench_name}\n" + f"- Mean Reward: {mean_reward:.4f}\n" + f"- Mean Entropy: {mean_entropy:.4f}\n" + f"- Final Speedup: {speedup if speedup is not None else 0.0:.4f}\033[0m" + ) + + + # === Global logging (across all benchmarks) === + if all_speedups: + fl.log_scalar("eval/average_speedup", float(np.mean(all_speedups)), step) + if all_entropies: + fl.log_scalar("eval/average_entropy", float(np.mean(all_entropies)), step) + + return env_time + +def train_iql(): + # Load offline dataset + print(f"Loading offline dataset from {cfg.offline_data_file} ...") + states, actions, rewards, next_states, dones = load_dataset() + dataset_size = states.shape[0] + print(f"Dataset loaded: {dataset_size} transitions") + + # Initialize IQL agent + agent = IQLAgent(cfg,device,obs_parts=[OpFeatures, ActionHistory]) + + eval_env = Env(is_training=False,run_name=cfg.run_name) + + + print("Starting IQL training ...") + start_time = time.time() + + step = 0 + iql_trange = trange(cfg.max_steps, desc="IQL Training",dynamic_ncols=True) + for step in iql_trange: + # Sample a random batch + idxs = np.random.randint(0, dataset_size, size=cfg.batch_size) + batch = ( + states[idxs].to(device), + actions[idxs].to(device), + rewards[idxs].to(device), + next_states[idxs].to(device), + dones[idxs].to(device), + ) + + losses = agent.update(batch) + + + # Only log occasionally to reduce disk I/O + if step % 50 == 0: + fl.log_scalars("train", losses, step) + + if (step +1) % 100 == 0: + elapsed = time.time() - start_time + 
iql_trange.set_postfix({ + "Value Loss": f"{losses['value']:.4f}", + "Q Loss": f"{losses['q']:.4f}", + "Policy Loss": f"{losses['policy']:.4f}", + "Elapsed": f"{elapsed:.2f}s" + }) + + # Evaluate the agent on benchmarks every 1000 steps + if (step + 1) % 1000 == 0: + print("Evaluating on benchmarks ...") + eval_start = time.time() + env_time = evaluate_benchmarks(agent, eval_env, step) + eval_time = time.time() - eval_start + print(f"Evaluation completed in {eval_time:.2f} seconds (env time: {env_time:.2f} seconds)") + fl.flush() + + + + if (step+1) % 2000 == 0 and step > 0: + ckpt_path = os.path.join(cfg.results_dir, f"iql_step_{step}.pt") + os.makedirs(cfg.results_dir, exist_ok=True) + torch.save(agent.state_dict(), ckpt_path) + print(f"Checkpoint saved: {ckpt_path}") + + + + total_time = time.time() - start_time + print(f"Training finished in {total_time:.2f} seconds.") + + +if __name__ == "__main__": + train_iql() \ No newline at end of file diff --git a/train_iql_online.py b/train_iql_online.py new file mode 100644 index 0000000..3e7a75c --- /dev/null +++ b/train_iql_online.py @@ -0,0 +1,340 @@ +import os +import time +import torch +import numpy as np +from tqdm import trange + +import dotenv +dotenv.load_dotenv() + +from rl_autoschedular import config as cfg, file_logger as fl +from rl_autoschedular.env import Env +from rl_autoschedular.actions import ActionSpace +from rl_autoschedular.observation import Observation, OpFeatures, ActionHistory +from iql.iql_agent import IQLAgent +from utils.data_collector import OfflineDataset + + +device = torch.device("cpu") + + +MAX_STEPS_OFFLINE_ONLINE_RATIO = 100_000 # steps over which to decay online-offline ratio + +UPDATE_ITERS = 3 # gradient updates per batch + +def load_offline_dataset(): + """Load offline dataset for warm-starting replay buffer.""" + dataset = OfflineDataset( + save_dir=cfg.offline_data_save_dir, + fname=cfg.offline_data_file + ).load() + + if not dataset: + raise FileNotFoundError(f"Offline dataset not found: {cfg.offline_data_file}") + + states = torch.tensor(dataset["obs"], dtype=torch.float32) + actions = torch.tensor(dataset["actions"], dtype=torch.long) + rewards = torch.tensor(dataset["rewards"], dtype=torch.float32) + next_states = torch.tensor(dataset["next_obs"], dtype=torch.float32) + dones = torch.tensor(dataset["dones"], dtype=torch.float32) + + return states, actions, rewards, next_states, dones + + +@torch.no_grad() +def evaluate_benchmarks(model: IQLAgent, env: Env, step: int): + """Evaluate model performance across all benchmarks.""" + env_time = 0.0 + eps = None + all_speedups, all_entropies = [], [] + + for _ in trange(cfg.bench_count, desc="Eval Trajectory", leave=False): + t0 = time.perf_counter() + state = env.reset() + env_time += time.perf_counter() - t0 + bench_done, speedup = False, None + bench_rewards, bench_entropies = [], [] + bench_name = state.bench_name + + while not bench_done: + obs = Observation.from_state(state) + action_index, action_log_p, entropy = model.sample(obs.to(device), eps=eps) + action = ActionSpace.action_by_index(action_index[0], state) + + t0 = time.perf_counter() + next_state, reward, op_done, speedup = env.step(state, action) + env_time += time.perf_counter() - t0 + + if op_done: + t0 = time.perf_counter() + next_state, bench_done = env.get_next_op_state(next_state) + env_time += time.perf_counter() - t0 + + bench_rewards.append(reward) + bench_entropies.append(entropy.item()) + state = next_state + + # per-benchmark logs + fl.log_scalars(f"eval/{bench_name}", { + 
"mean_reward": float(np.mean(bench_rewards)) if bench_rewards else 0.0, + "mean_entropy": float(np.mean(bench_entropies)) if bench_entropies else 0.0, + "final_speedup": speedup if speedup is not None else 0.0, + }, step) + + all_speedups.append(speedup) + all_entropies.extend(bench_entropies) + + # global logs + if all_speedups: + fl.log_scalar("eval/average_speedup", float(np.mean(all_speedups)), step) + if all_entropies: + fl.log_scalar("eval/average_entropy", float(np.mean(all_entropies)), step) + + return env_time + + +def collect_warmup_data(agent, env, buffer, warmup_steps=5000): + """ + Collects a fixed number of online transitions before training begins. + This stabilizes early learning by pre-filling the online buffer. + """ + print(f"\n[ Warmup Phase ] Collecting {warmup_steps} online transitions before training...\n") + state = env.reset() + + progress = trange(warmup_steps, desc="Warmup Collection", dynamic_ncols=True) + + + for _ in progress: + obs = Observation.from_state(state) + action_index, _, _ = agent.sample(obs.to(device), eps=None) + action = ActionSpace.action_by_index(action_index[0], state) + + next_state, reward, op_done, _ = env.step(state, action) + next_obs = Observation.from_state(next_state) + + # Handle operation completion + if op_done: + next_state, done = env.get_next_op_state(next_state) + else: + done = False + + # Store in online buffer + buffer.add_online( + obs.to(device), + action_index, + torch.tensor(reward, dtype=torch.float32, device=device), + next_obs.to(device), + torch.tensor(done, dtype=torch.float32, device=device), + ) + + state = next_state if not done else env.reset() + + + print(f"[Warmup Complete] Collected {len(buffer.online_buffer)} online transitions.\n") + + +def get_epsilon(step, eps_start=0.3, eps_end=0.05, decay_steps=100_000): + """ + Linearly decays epsilon from eps_start → eps_end over decay_steps. + Used for epsilon-greedy exploration during online interaction. 
+ """ + if step >= decay_steps: + return eps_end + decay = (eps_start - eps_end) * (1 - step / decay_steps) + return eps_end + decay + + +class ReplayBuffer: + """Simple replay buffer mixing offline + online data.""" + def __init__(self, max_size=100000): + self.states, self.actions, self.rewards, self.next_states, self.dones = [], [], [], [], [] + self.max_size = max_size + + def add(self, s, a, r, ns, d): + if len(self.states) >= self.max_size: + # drop oldest + self.states.pop(0) + self.actions.pop(0) + self.rewards.pop(0) + self.next_states.pop(0) + self.dones.pop(0) + + self.states.append(s.squeeze(0)) + self.actions.append(a.squeeze(0)) + self.rewards.append(r.squeeze(0)) + self.next_states.append(ns.squeeze(0)) + self.dones.append(d) + + def sample(self, batch_size): + idxs = np.random.randint(0, len(self.states), size=batch_size) + return ( + torch.stack([self.states[i] for i in idxs]), + torch.stack([self.actions[i] for i in idxs]), + torch.stack([self.rewards[i] for i in idxs]), + torch.stack([self.next_states[i] for i in idxs]), + torch.stack([self.dones[i] for i in idxs]), + ) + + def __len__(self): + return len(self.states) + +class DualReplayBuffer: + """Manages separate buffers for offline and online data, + with controlled sampling ratio that decays over time.""" + def __init__(self, offline_buffer_max=100_000, online_buffer_max=100_000): + self.offline_buffer = ReplayBuffer(max_size=offline_buffer_max) + self.online_buffer = ReplayBuffer(max_size=online_buffer_max) + + def add_offline(self, s, a, r, ns, d): + self.offline_buffer.add(s, a, r, ns, d) + + def add_online(self, s, a, r, ns, d): + self.online_buffer.add(s, a, r, ns, d) + + def sample(self, batch_size, step, max_steps, online_start_ratio=0.8, online_end_ratio=0.2): + """Sample from both buffers with ratio decaying linearly over training.""" + # Compute current online ratio + progress = min(step / max_steps, 1.0) + online_ratio = online_start_ratio - (online_start_ratio - online_end_ratio) * progress + online_batch = int(batch_size * online_ratio) + offline_batch = batch_size - online_batch + + # Safety check: if one buffer is too small, compensate from the other + online_batch = min(online_batch, len(self.online_buffer)) + offline_batch = batch_size - online_batch + if len(self.offline_buffer) < offline_batch: + offline_batch = len(self.offline_buffer) + online_batch = batch_size - offline_batch + + # Sample + if online_batch > 0: + online_samples = self.online_buffer.sample(online_batch) + else: + online_samples = None + if offline_batch > 0: + offline_samples = self.offline_buffer.sample(offline_batch) + else: + offline_samples = None + + # Merge samples + def merge_tensors(t1, t2): + if t1 is None: return t2 + if t2 is None: return t1 + return torch.cat([t1, t2], dim=0) + + return tuple( + merge_tensors(o, f) + for o, f in zip( + online_samples if online_samples else (None,)*5, + offline_samples if offline_samples else (None,)*5 + ) + ) + + def __len__(self): + return len(self.offline_buffer) + len(self.online_buffer) + + +def hybrid_finetune(): + # === Load pretrained agent === + agent = IQLAgent(cfg, device, obs_parts=[OpFeatures, ActionHistory]) + ckpt_path = "./offline_iql_results_1/iql_step_97999.pt" + if os.path.exists(ckpt_path): + agent.load_state_dict(torch.load(ckpt_path, map_location=device)) + print(f"Loaded pretrained checkpoint: {ckpt_path}") + else: + raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}") + + # === Init Replay Buffer with offline data === + + buffer = 
DualReplayBuffer(offline_buffer_max=100_000, online_buffer_max=100_000)
+
+    states, actions, rewards, next_states, dones = load_offline_dataset()
+    for s, a, r, ns, d in zip(states, actions, rewards, next_states, dones):
+        buffer.add_offline(s, a, r, ns, d)
+    print(f"Offline Replay buffer initialized with {len(buffer)} offline samples")
+
+    # environments
+    train_env = Env(is_training=True, run_name=cfg.run_name)
+    eval_env = Env(is_training=False, run_name=cfg.run_name)
+
+    print("Starting HYBRID fine-tuning (offline + online)...")
+    start_time = time.time()
+    state = train_env.reset()
+
+
+    # Warmup Phase (modular)
+    collect_warmup_data(agent, train_env, buffer, warmup_steps=2000)
+
+    hybrid_trange = trange(cfg.max_steps, desc="Hybrid Fine-tuning", dynamic_ncols=True)
+    for step in hybrid_trange:
+        # reset benchmark
+        state = train_env.reset()
+        done = False
+
+        while not done:
+            # current obs
+            obs = Observation.from_state(state)
+
+            # agent picks action
+            eps = get_epsilon(step)
+            action_index, _, _ = agent.sample(obs.to(device), eps=eps)
+            action = ActionSpace.action_by_index(action_index[0], state)
+
+            # env step
+            next_state, reward, op_done, _ = train_env.step(state, action)
+
+            # build next_obs BEFORE advancing benchmark
+            next_obs = Observation.from_state(next_state)
+
+            # if op finished, advance to next op or benchmark end
+            if op_done:
+                next_state, done = train_env.get_next_op_state(next_state)
+
+            # push freshly collected transition to the online buffer
+            buffer.add_online(
+                obs.to(device),
+                action_index,
+                torch.tensor(reward, dtype=torch.float32, device=device),
+                next_obs.to(device),
+                torch.tensor(done, dtype=torch.float32, device=device),
+            )
+
+            # move forward
+            state = next_state
+
+        # after each benchmark, run UPDATE_ITERS gradient updates on a mixed offline/online batch
+        batch = buffer.sample(cfg.batch_size, step, MAX_STEPS_OFFLINE_ONLINE_RATIO)
+
+        for _ in range(UPDATE_ITERS):
+            losses = agent.update(batch)
+
+        if (step + 1) % 100 == 0:
+            fl.log_scalars("hybrid_train", losses, step)
+            elapsed = time.time() - start_time
+            hybrid_trange.set_postfix({
+                "Value Loss": f"{losses['value']:.4f}",
+                "Q Loss": f"{losses['q']:.4f}",
+                "Policy Loss": f"{losses['policy']:.4f}",
+                "Elapsed": f"{elapsed:.2f}s"
+            })
+
+        if (step + 1) % 5000 == 0:
+            print("Evaluating on benchmarks ...")
+            eval_start = time.time()
+            env_time = evaluate_benchmarks(agent, eval_env, step)
+            print(f"Evaluation done in {time.time() - eval_start:.2f}s (env time: {env_time:.2f}s)")
+            fl.flush()
+
+        if (step + 1) % 5000 == 0:
+            os.makedirs(cfg.results_dir, exist_ok=True)
+            save_path = os.path.join(cfg.results_dir, f"iql_hybrid_step_{step}.pt")
+            torch.save(agent.state_dict(), save_path)
+            print(f"Checkpoint saved: {save_path}")
+
+        state = next_state
+
+    print(f"Hybrid fine-tuning finished in {time.time() - start_time:.2f} seconds.")
+
+
+if __name__ == "__main__":
+    hybrid_finetune()
diff --git a/train_ppo.py b/train_ppo.py
new file mode 100644
index 0000000..017521b
--- /dev/null
+++ b/train_ppo.py
@@ -0,0 +1,109 @@
+# Load environment variables
+from dotenv import load_dotenv
+load_dotenv(override=True)
+
+
+import torch
+import os
+from typing import Optional
+from utils.log import print_info, print_success
+
+# Import environment
+from rl_autoschedular.env import Env
+
+# config, file_logger, device
+from rl_autoschedular import config as cfg, file_logger as fl, device
+
+# Import RL components
+from rl_autoschedular.model import HiearchyModel as Model
+from rl_autoschedular.trajectory import TrajectoryData
+from rl_autoschedular.ppo import (
+    collect_trajectory,
+
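+    # rollout collection (returns the trajectory and time spent in the env),
+    # value-head fitting, the PPO policy update, and periodic benchmark
+    # evaluation -- all used in the training loop below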
ppo_update, + value_update, + evaluate_benchmarks +) + +import time +torch.set_grad_enabled(False) +torch.set_num_threads(int(os.getenv("OMP_NUM_THREADS", "4"))) + +if cfg.debug: + torch.autograd.set_detect_anomaly(True) + +print_info(f"Config: {cfg}") +print_success(f'Logging to: {fl.run_dir}') + +# Set environments + +# run_name for /tmp/ path +env = Env(is_training=True,run_name="online_ppo") +eval_env = Env(is_training=False,run_name="online_ppo") +print_success(f"Environments initialized: {env.tmp_file}") + +# Set model +model = Model().to(device) +optimizer = torch.optim.Adam( + model.parameters(), + lr=3e-4 +) +print_success("Model initialized") + +train_start = time.perf_counter() +total_env_time = 0.0 +total_eval_time = 0.0 + +# Start training +for step in range(cfg.nb_iterations): + print_info(f"- Main Loop {step + 1}/{cfg.nb_iterations} ({100 * (step + 1) / cfg.nb_iterations:.2f}%)") + trajectory , env_time = collect_trajectory( + model, + env, + step, + ) + total_env_time += env_time + + + + # Fit value model to trajectory rewards + if cfg.value_epochs > 0: + value_update( + trajectory, + model, + optimizer, + step + ) + + ppo_update( + trajectory, + model, + optimizer, + step + ) + + if (step + 1) % 50 == 0: + torch.save( + model.state_dict(), + os.path.join( + env.tmp_file.replace('.mlir', ''), + f'model_{step}.pth' + ) + ) + + if (step + 1) % 50 == 0: + start_eval = time.perf_counter() + print_info('- Evaluating benchmark -') + eval_time = evaluate_benchmarks( + model, + eval_env, + step + ) + end_eval = time.perf_counter() + total_eval_time += end_eval - start_eval +train_end = time.perf_counter() +total_train_time = train_end - train_start - total_eval_time +print_success(f"- Training completed in {total_train_time:.2f} seconds") +print_success(f"- Evaluation completed in {total_eval_time:.2f} seconds") +print_success(f"- Total environment time: {total_env_time:.2f} seconds") +print_success(f"- Total eval env time: {total_eval_time:.2f} seconds") +print_success(f"- Percentage of time in environment: {100 * total_env_time / total_train_time:.2f}%") \ No newline at end of file diff --git a/utils/config.py b/utils/config.py index dc7113b..0612d28 100644 --- a/utils/config.py +++ b/utils/config.py @@ -4,6 +4,8 @@ import json import os +from utils.keys import CONFIG_FILE_PATH + class Config(metaclass=Singleton): """Class to store and load global configuration""" @@ -22,6 +24,8 @@ class Config(metaclass=Singleton): """The order of actions that needs to bo followed""" interchange_mode: Literal['enumerate', 'pointers', 'continuous'] """The method used for interchange action""" + use_img2col: bool + """Whether to use img2col transformation on conv2d operations""" exploration: list[Literal['entropy', 'epsilon']] """The exploration method""" init_epsilon: float @@ -70,13 +74,17 @@ class Config(metaclass=Singleton): """Path to the file containing the execution data""" results_dir: str """Path to the results directory""" + save_model_every: int + """Number of iterations between saving the model""" + evaluate_every: int + """Number of iterations between evaluations""" def __init__(self): """Load the configuration from the JSON file or get existing instance if any. 
""" # Open the JSON file - with open(os.getenv("CONFIG_FILE_PATH"), "r") as f: + with open(CONFIG_FILE_PATH, "r") as f: config_data: dict[str, Any] = json.load(f) for element, element_t in self.__annotations__.items(): diff --git a/utils/dask_manager.py b/utils/dask_manager.py index 5de6a68..f1daa8e 100644 --- a/utils/dask_manager.py +++ b/utils/dask_manager.py @@ -11,6 +11,8 @@ from .log import print_alert, print_error, print_info, print_success import os +from utils.keys import CONDA_ENV + if TYPE_CHECKING: from rl_autoschedular.benchmarks import Benchmarks from dask_jobqueue.slurm import SLURMJob @@ -44,7 +46,7 @@ def __init__(self): job_script_prologue=[ 'module load miniconda-nobashrc', 'eval "$(conda shell.bash hook)"', - f'conda activate {os.getenv("CONDA_ENV")}', + f'conda activate {CONDA_ENV}', 'export OMP_NUM_THREADS=12', ], scheduler_options={ diff --git a/utils/data_collector.py b/utils/data_collector.py new file mode 100644 index 0000000..d2ddc2c --- /dev/null +++ b/utils/data_collector.py @@ -0,0 +1,74 @@ +import os +import numpy as np +import torch +from utils.singleton import Singleton +from utils.log import print_success + +class OfflineDataset(metaclass=Singleton): + """Singleton class to collect and store trajectories for offline RL """ + + def __init__(self, save_dir: str = "offline_data", fname: str = "dataset.npz"): + """ + Args: + save_dir (str): Directory to store dataset. + fname (str): Dataset filename. + """ + self.save_dir = save_dir + self.fname = fname + os.makedirs(self.save_dir, exist_ok=True) + + self.buffer = [] # in-memory buffer for efficiency + self.file_path = os.path.join(self.save_dir, self.fname) + + def add_transition(self, obs, action, next_obs, reward, done): + """Add one transition to buffer.""" + self.buffer.append({ + "obs": obs.squeeze(0).cpu().numpy() if torch.is_tensor(obs) else obs, + "action": action.detach().cpu().numpy().squeeze(0) if torch.is_tensor(action) else np.array(action).squeeze(0), + "next_obs": next_obs.squeeze(0).cpu().numpy() if torch.is_tensor(next_obs) else next_obs, + "reward": float(reward), + "done": bool(done), + }) + + def add_trajectory(self, trajectory): + """Add a full trajectory (list of transitions).""" + self.buffer.extend(trajectory) + + def flush(self): + """Save buffer to disk as npz and clear it.""" + if not self.buffer: + return + + # Convert buffer to arrays + obs = np.array([t["obs"] for t in self.buffer], dtype=np.float32) + actions = np.array([t["action"] for t in self.buffer], dtype=np.int64) + next_obs = np.array([t["next_obs"] for t in self.buffer], dtype=np.float32) + rewards = np.array([t["reward"] for t in self.buffer], dtype=np.float32) + dones = np.array([t["done"] for t in self.buffer], dtype=np.bool_) + + if os.path.exists(self.file_path): + # If file exists, append to it + old = np.load(self.file_path) + obs = np.concatenate([old["obs"], obs], axis=0) + actions = np.concatenate([old["actions"], actions], axis=0) + next_obs = np.concatenate([old["next_obs"], next_obs], axis=0) + rewards = np.concatenate([old["rewards"], rewards], axis=0) + dones = np.concatenate([old["dones"], dones], axis=0) + + np.savez_compressed( + self.file_path, + obs=obs, + actions=actions, + next_obs=next_obs, + rewards=rewards, + dones=dones, + ) + + print_success(f"[OfflineDataset] Flushed {len(self.buffer)} transitions -> {self.file_path}") + self.buffer.clear() + + def load(self, mmap_mode=None): + """Load dataset from disk as dict of numpy arrays (optionally memory-mapped).""" + if not 
os.path.exists(self.file_path): + return {} + return np.load(self.file_path, mmap_mode=mmap_mode) \ No newline at end of file diff --git a/utils/keys.py b/utils/keys.py new file mode 100644 index 0000000..43980ce --- /dev/null +++ b/utils/keys.py @@ -0,0 +1,61 @@ +import os + +from dotenv import load_dotenv + +load_dotenv() + +VERBOSE = False + +CONDA_ENV = os.getenv("CONDA_ENV") +if CONDA_ENV is None: + raise ValueError("CONDA_ENV environment variable is not set.") +elif VERBOSE: + print(f"Using conda environment: {CONDA_ENV}") + +NEPTUNE_PROJECT = os.getenv("NEPTUNE_PROJECT") +if NEPTUNE_PROJECT is None: + raise ValueError("NEPTUNE_PROJECT environment variable is not set.") +elif VERBOSE: + print(f"Using Neptune project: {NEPTUNE_PROJECT}") + +NEPTUNE_TOKEN = os.getenv("NEPTUNE_TOKEN") +if NEPTUNE_TOKEN is None: + raise ValueError("NEPTUNE_TOKEN environment variable is not set.") +elif VERBOSE: + print("Neptune token is set.") + +LLVM_BUILD_PATH = os.getenv("LLVM_BUILD_PATH") +if LLVM_BUILD_PATH is None: + raise ValueError("LLVM_BUILD_PATH environment variable is not set.") +elif VERBOSE: + print(f"Using LLVM build path: {LLVM_BUILD_PATH}") + +MLIR_SHARED_LIBS = os.getenv("MLIR_SHARED_LIBS") +if MLIR_SHARED_LIBS is None: + raise ValueError("MLIR_SHARED_LIBS environment variable is not set.") +elif VERBOSE: + print(f"Using MLIR shared libs: {MLIR_SHARED_LIBS}") + +AST_DUMPER_BIN_PATH = os.getenv("AST_DUMPER_BIN_PATH") +if AST_DUMPER_BIN_PATH is None: + raise ValueError("AST_DUMPER_BIN_PATH environment variable is not set.") +elif VERBOSE: + print(f"Using AST dumper bin path: {AST_DUMPER_BIN_PATH}") + +VECTORIZER_BIN_PATH = os.getenv("VECTORIZER_BIN_PATH") +if VECTORIZER_BIN_PATH is None: + raise ValueError("VECTORIZER_BIN_PATH environment variable is not set.") +elif VERBOSE: + print(f"Using vectorizer bin path: {VECTORIZER_BIN_PATH}") + +PRE_VEC_BIN_PATH = os.getenv("PRE_VEC_BIN_PATH") +if PRE_VEC_BIN_PATH is None: + raise ValueError("PRE_VEC_BIN_PATH environment variable is not set.") +elif VERBOSE: + print(f"Using pre-vectorizer bin path: {PRE_VEC_BIN_PATH}") + +CONFIG_FILE_PATH = os.getenv("CONFIG_FILE_PATH") +if CONFIG_FILE_PATH is None: + raise ValueError("CONFIG_FILE_PATH environment variable is not set.") +elif VERBOSE: + print(f"Using config file path: {CONFIG_FILE_PATH}") \ No newline at end of file
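
As a quick reference for the new utils/data_collector.py module, the snippet below sketches the OfflineDataset round trip (record transitions, flush to disk, reload for offline training). It is illustrative only: the observation shapes, action encoding, and dummy values are placeholders rather than values taken from the repository, and it assumes the project's .env is configured so the utils imports resolve.

# Illustrative sketch only -- shapes and dummy values below are placeholders.
import numpy as np

from utils.data_collector import OfflineDataset

ds = OfflineDataset(save_dir="offline_data", fname="dataset.npz")

# Record a few fake transitions; real observations come from Observation.from_state.
for t in range(3):
    obs = np.zeros(8, dtype=np.float32)
    next_obs = np.ones(8, dtype=np.float32)
    ds.add_transition(obs, np.array([t]), next_obs, reward=0.1 * t, done=(t == 2))

# Append everything to offline_data/dataset.npz and clear the in-memory buffer.
ds.flush()

# Reload as a dict-like NpzFile with keys obs/actions/next_obs/rewards/dones --
# the same keys that load_offline_dataset() in train_iql_online.py reads.
data = ds.load()
print(data["obs"].shape, data["actions"].shape)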