From 73b2c096ace0dd7fe1c87aa8d1c048595878ee7f Mon Sep 17 00:00:00 2001 From: BrouthenKamel Date: Sun, 21 Dec 2025 16:27:36 +0400 Subject: [PATCH 1/3] Img2Col as an action; Offline RL (Ouail contribution); LLM Action space PoC --- .env.example | 3 +- .gitignore | 3 + config/example.json | 1 + demo.ipynb | 17 +- evaluate.py | 1 + get_base.py | 22 +- iql/__init__.py | 0 iql/agent.py | 282 ++++++++++++++ iql/config.py | 159 ++++++++ iql/policy.py | 82 ++++ iql/q_functions.py | 281 ++++++++++++++ iql/singleton.py | 8 + iql/value_function.py | 69 ++++ llm_action/.env.example | 6 + ...v_2d_nchw_fchw_128_32_7_7_256_1_1_7_7.mlir | 10 + .../data/matmul/matmul_128_256_128.mlir | 10 + llm_action/log/test/test_20251221_141044.txt | 224 +++++++++++ llm_action/src/__init__.py | 0 llm_action/src/agent.py | 178 +++++++++ llm_action/src/config.py | 10 + llm_action/src/keys.py | 9 + llm_action/src/llm.py | 12 + llm_action/src/prompt.py | 300 +++++++++++++++ llm_action/src/tools.py | 103 +++++ llm_action/src/utils/cache.py | 0 llm_action/src/utils/log.py | 10 + llm_action/src/utils/processing.py | 16 + llm_action/src/utils/transformation.py | 253 ++++++++++++ neptune_sync.py | 7 +- requirements.txt | 3 + rl_autoschedular/actions/__init__.py | 4 +- rl_autoschedular/actions/img2col.py | 33 ++ rl_autoschedular/benchmarks.py | 5 +- rl_autoschedular/env.py | 19 +- rl_autoschedular/ppo.py | 6 +- rl_autoschedular/state.py | 64 +++- rl_autoschedular/trajectory.py | 10 + rl_autoschedular/transforms.py | 1 - run_llm_playground.py | 7 + run_mlir_agent.py | 123 ++++++ scripts/neptune-sync.sh | 12 +- tests/cuda.py | 48 +++ train.py | 16 +- train_iql_offline.py | 208 ++++++++++ train_iql_online.py | 340 +++++++++++++++++ train_ppo.py | 109 ++++++ utils/config.py | 10 +- utils/dask_manager.py | 5 +- utils/data_collector.py | 74 ++++ utils/data_utils.py | 359 ++++++++++++++++++ utils/keys.py | 61 +++ 51 files changed, 3561 insertions(+), 32 deletions(-) create mode 100644 iql/__init__.py create mode 100644 iql/agent.py create mode 100644 iql/config.py create mode 100644 iql/policy.py create mode 100644 iql/q_functions.py create mode 100644 iql/singleton.py create mode 100644 iql/value_function.py create mode 100644 llm_action/.env.example create mode 100644 llm_action/data/conv2d/conv_2d_nchw_fchw_128_32_7_7_256_1_1_7_7.mlir create mode 100644 llm_action/data/matmul/matmul_128_256_128.mlir create mode 100644 llm_action/log/test/test_20251221_141044.txt create mode 100644 llm_action/src/__init__.py create mode 100644 llm_action/src/agent.py create mode 100644 llm_action/src/config.py create mode 100644 llm_action/src/keys.py create mode 100644 llm_action/src/llm.py create mode 100644 llm_action/src/prompt.py create mode 100644 llm_action/src/tools.py create mode 100644 llm_action/src/utils/cache.py create mode 100644 llm_action/src/utils/log.py create mode 100644 llm_action/src/utils/processing.py create mode 100644 llm_action/src/utils/transformation.py create mode 100644 rl_autoschedular/actions/img2col.py create mode 100644 run_llm_playground.py create mode 100644 run_mlir_agent.py create mode 100644 tests/cuda.py create mode 100644 train_iql_offline.py create mode 100644 train_iql_online.py create mode 100644 train_ppo.py create mode 100644 utils/data_collector.py create mode 100644 utils/data_utils.py create mode 100644 utils/keys.py diff --git a/.env.example b/.env.example index 47ec962..a5733b2 100644 --- a/.env.example +++ b/.env.example @@ -5,4 +5,5 @@ MLIR_SHARED_LIBS= AST_DUMPER_BIN_PATH= PRE_VEC_BIN_PATH= 
VECTORIZER_BIN_PATH= -CONDA_ENV= \ No newline at end of file +CONDA_ENV= +CONFIG_FILE_PATH= \ No newline at end of file diff --git a/.gitignore b/.gitignore index c206e8f..70742a0 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,6 @@ .neptune *__pycache__ tools/*/build + +docs/ +playground/ \ No newline at end of file diff --git a/config/example.json b/config/example.json index f1c6ca7..26353f0 100644 --- a/config/example.json +++ b/config/example.json @@ -6,6 +6,7 @@ "vect_size_limit": 512, "order": [["I"], ["!", "I", "NT"], ["!", "I"], ["V", "NT"]], "interchange_mode": "enumerate", + "use_img2col": true, "exploration": ["entropy"], "init_epsilon": 0.5, "normalize_bounds": "max", diff --git a/demo.ipynb b/demo.ipynb index 894cf2b..c189698 100644 --- a/demo.ipynb +++ b/demo.ipynb @@ -2,10 +2,21 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "2e74d0c8", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# Setup environment\n", "# import os\n", @@ -82,7 +93,7 @@ ], "metadata": { "kernelspec": { - "display_name": "mlir-rl", + "display_name": "llvm-build", "language": "python", "name": "python3" }, diff --git a/evaluate.py b/evaluate.py index c531750..06886ac 100644 --- a/evaluate.py +++ b/evaluate.py @@ -71,6 +71,7 @@ def load_main_exec_data() -> Optional[dict[str, dict[str, int]]]: # Read the files in the evaluation directory eval_files = [f for f in os.listdir(eval_dir) if f.endswith('.pt')] +# eval_files = [eval_files[-1]] # Only evaluate the last model # Order files eval_files.sort(key=lambda x: int(x.split('_')[1].split('.')[0])) diff --git a/get_base.py b/get_base.py index 62a5e25..215b790 100644 --- a/get_base.py +++ b/get_base.py @@ -11,8 +11,13 @@ if not os.path.isdir(path_to_folder): print(f"Error: {path_to_folder} is not a valid directory.") sys.exit(1) + +with open(f"{path_to_folder}/../benchmarks_split.json", 'r') as f: + benchmarks_split = json.load(f) + +train_output_data = {} +eval_output_data = {} -output_data = {} exec = Execution("") code_files = [f for f in os.listdir(path_to_folder) if f.endswith('.mlir')] @@ -28,7 +33,14 @@ except Exception as e: print(f"Failed to execute {bench_name}: {e}") et = -1 - output_data[bench_name] = et - - with open('base_exec_times.json', 'w') as f: - json.dump(output_data, f, indent=4) + + if bench_name in benchmarks_split['train']: + + train_output_data[bench_name] = et + with open(f"{path_to_folder}/../execution_times_train.json", 'w') as f: + json.dump(train_output_data, f, indent=4) + + elif bench_name in benchmarks_split['eval']: + eval_output_data[bench_name] = et + with open(f"{path_to_folder}/../execution_times_eval.json", 'w') as f: + json.dump(eval_output_data, f, indent=4) diff --git a/iql/__init__.py b/iql/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/iql/agent.py b/iql/agent.py new file mode 100644 index 0000000..2d70f70 --- /dev/null +++ b/iql/agent.py @@ -0,0 +1,282 @@ +import copy +from typing import Dict, List, Optional, Type, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from utils.config import Config +from rl_autoschedular.actions import ActionSpace +from rl_autoschedular.observation import Observation, ObservationPart, OpFeatures, ActionHistory + +from iql.value_function import IQLValueModel +from iql.policy import IQLPolicyModel +from iql.q_functions import 
IQLTwinQ + +cfg = Config() + +class IQLAgent(nn.Module): + """ + IQL agent adapted to the PPO-aligned architecture and hierarchical action space. + - Uses Observation.get_parts(obs, *obs_parts) + - Shared 3x512 backbone across policy/value/Q + - Hierarchical heads (action + per-action params) + """ + def __init__(self, obs_parts=None, param_dims=None): + super().__init__() + self.obs_parts = obs_parts or [OpFeatures, ActionHistory] + + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Use config hyperparameters + self.gamma = cfg.gamma + self.tau = cfg.tau + self.beta = cfg.beta + self.alpha = cfg.alpha + + # Networks + self.value_model = IQLValueModel(self.obs_parts, tau=self.tau).to(self.device) + self.policy_model = IQLPolicyModel(self.obs_parts).to(self.device) + self.q_model = IQLTwinQ(self.obs_parts).to(self.device) + + # Target Q + self.q_target = copy.deepcopy(self.q_model).to(self.device) + for p in self.q_target.parameters(): + p.requires_grad = False + + # Optimizers with cfg.lr dict (after models are on device) + self.value_optimizer = torch.optim.Adam(self.value_model.parameters(), lr=cfg.lr["value"]) + self.q_optimizer = torch.optim.Adam(self.q_model.parameters(), lr=cfg.lr["q"]) + self.policy_optimizer = torch.optim.Adam(self.policy_model.parameters(), lr=cfg.lr["policy"]) + """ + self.policy_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + self.policy_optimizer, + T_max=600000, + eta_min=1e-5 + ) + """ + + # --------- helpers to move inputs to device ---------- + def _to_device_tensor(self, x: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + if x is None: + return None + return x.to(self.device, non_blocking=True) + + def _to_device_tensor_list( + self, + xs: Optional[List[Optional[torch.Tensor]]] + ) -> Optional[List[Optional[torch.Tensor]]]: + if xs is None: + return None + out: List[Optional[torch.Tensor]] = [] + for t in xs: + out.append(self._to_device_tensor(t) if isinstance(t, torch.Tensor) else None if t is None else t) + return out + + # ------------------------ + # Action selection (hierarchical) + # ------------------------ + @torch.no_grad() + def sample( + self, + obs: torch.Tensor, + greedy: bool = False, + eps: Optional[float] = None + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Sample hierarchical action indices using the same API style as PPO. 
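+        With probability eps (when provided), the indices are drawn from uniform
+        distributions over the action space instead of the current policy
+        (epsilon-greedy exploration).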
+ Returns: + actions_index: packed hierarchical indices (ActionSpace format) + actions_log_p: log-prob of sampled action under current policy + entropies: per-head entropies (aggregated by ActionSpace) + """ + + # Build distributions from policy + dists = self.policy_model(obs) + eps_dists = ActionSpace.uniform_distributions(obs) + + # Hierarchical sample + use_uniform = (eps is not None) and (torch.rand((), device=self.device).item() < eps) + actions_index = ActionSpace.sample( + obs, + dists, + eps_dists, + uniform=use_uniform, + greedy=greedy, + ) + + # Stats for the sampled actions + actions_log_p, entropies = ActionSpace.distributions_stats( + dists, + actions_index, + eps_distributions=eps_dists if eps is not None else None, + eps=eps, + ) + return actions_index, actions_log_p, entropies + + # ------------------------ + # Value update (expectile regression using target twin-Q) + # ------------------------ + def update_value( + self, + obs: torch.Tensor, + action_idx: torch.LongTensor, + *, + param_indices: Optional[List[Optional[torch.LongTensor]]] = None, + param_values: Optional[List[Optional[torch.Tensor]]] = None, + ) -> torch.Tensor: + """ + Updates V(s) by regressing towards min(Q1, Q2) from the *target* Q network. + """ + + with torch.no_grad(): + q1_t, q2_t = self.q_target(obs, action_idx) + q_min_t = torch.min(q1_t, q2_t) # [B] + + self.value_optimizer.zero_grad(set_to_none=True) + loss_v = self.value_model.loss(obs, q_min_t) + loss_v.backward() + self.value_optimizer.step() + return loss_v + + # ------------------------ + # Q update (TD with V(s')) + # ------------------------ + def update_q( + self, + obs: torch.Tensor, + action_idx: torch.LongTensor, + rewards: torch.Tensor, + next_obs: torch.Tensor, + dones: torch.Tensor + ) -> torch.Tensor: + """ + Update twin Q networks with TD target: + target_q = r + gamma * (1 - done) * V_target(s') + If target_v is not provided, it is computed from the current value_model. + """ + + with torch.no_grad(): + target_v = self.value_model(next_obs).to(self.device) # [B] + + target_q = rewards + self.gamma * (1.0 - dones) * target_v # [B] + + self.q_optimizer.zero_grad(set_to_none=True) + + + loss_q = self.q_model.loss( + obs, + action_idx, + target_q + ) + loss_q.backward() + self.q_optimizer.step() + return loss_q + + # ------------------------ + # Policy update (advantage-weighted BC) + # ------------------------ + def update_policy( + self, + obs: torch.Tensor, + actions_index: torch.Tensor, # packed hierarchical indices (as stored by dataset) + *, + action_idx: Optional[torch.LongTensor] = None, + param_indices: Optional[List[Optional[torch.LongTensor]]] = None, + param_values: Optional[List[Optional[torch.Tensor]]] = None, + ) -> torch.Tensor: + """ + Update policy with advantage-weighted log-likelihood: + weights = exp(A / beta), A = min(Q1, Q2) - V(s) + - actions_index is used to compute log π(a|s) via ActionSpace.distributions_stats(...) + - Q needs decomposed (action_idx, param_indices/values). 
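+        Advantages are standardized (zero mean, unit variance) before exponentiation,
+        and policy gradients are clipped (max_norm=5.0) before the optimizer step.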
+ """ + + # 1) log π(a|s) from hierarchical distributions + dists = self.policy_model(obs) + actions_log_p, _ = ActionSpace.distributions_stats(dists, actions_index) + + # 2) advantages = Q_min(s,a) - V(s) + assert action_idx is not None, "action_idx (top-level) is required for Q evaluation" + with torch.no_grad(): + q_min = self.q_model.q_values(obs, action_idx) # [B] + v = self.value_model(obs) # [B] + advantages = q_min - v # [B] + advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) + + + # 3) loss (AWAC/IQL style) + + # 1. zero gradients + self.policy_optimizer.zero_grad(set_to_none=True) + # 2. compute loss + loss_pi = self.policy_model.loss( + actions_log_p=actions_log_p, + advantages=advantages, + beta=self.beta, + ) + + # 3. backpropagate + loss_pi.backward() + + # 4. clip gradients (to avoid instability) + torch.nn.utils.clip_grad_norm_(self.policy_model.parameters(), max_norm=5.0) + + + self.policy_optimizer.step() + # self.policy_lr_scheduler.step() + return loss_pi + + # ------------------------ + # Soft update of target Q + # ------------------------ + @torch.no_grad() + def soft_update_q_target(self): + """ + θ_target ← α θ + (1-α) θ_target + """ + for p, tp in zip(self.q_model.parameters(), self.q_target.parameters()): + tp.data.copy_(self.alpha * p.data + (1.0 - self.alpha) * tp.data) + + def update(self, batch: Tuple[torch.Tensor, ...]) -> Dict[str, float]: + """ + One full IQL update step: + 1. Update Q-functions + 2. Update value function + 3. Update policy (AWAC/IQL style) + 4. Soft update target Q + Returns dict of losses for logging. + """ + obs, actions_index, rewards, next_obs, dones = (t.to(self.device, non_blocking=True) for t in batch) + + + # ---- 1) Update Q ---- + loss_q = self.update_q( + obs=obs, + action_idx=actions_index, # top-level index + rewards=rewards, + next_obs=next_obs, + dones=dones, + ) + + # ---- 2) Update Value ---- + loss_v = self.update_value(obs, actions_index) + + + # ---- 3) Update Policy ---- + loss_pi = self.update_policy( + obs=obs, + actions_index=actions_index, + action_idx=actions_index, # required for Q evaluation + ) + + + + # ---- 4) Soft update Q target ---- + self.soft_update_q_target() + + return { + "q": float(loss_q.item()), + "policy": float(loss_pi.item()), + "value": float(loss_v.item()), + } \ No newline at end of file diff --git a/iql/config.py b/iql/config.py new file mode 100644 index 0000000..e4276e7 --- /dev/null +++ b/iql/config.py @@ -0,0 +1,159 @@ +from dotenv import load_dotenv +load_dotenv() + + +import os +from utils.singleton import Singleton +import json +from typing import Literal + + + + +class Config(metaclass=Singleton): + """Class to store and load global configuration""" + + ############## IQL specific parameters ############## + gamma : float + """Discount factor""" + tau : float + """expectile regression parameter""" + inverse_temperature : float + """Inverse temperature for advantage-weighted regression""" + alpha : float + """target smoothing coefficient""" + batch_size : int + """Batch size for training""" + learning_rate : dict[str, float] + """Learning rate for the optimizer""" + max_steps : int + """Maximum number of training steps""" + target_update_freq : int + """Frequency of target network updates""" + sparse_reward : bool + """Flag to enable sparse reward""" + + offline_data_directory : str + """The offline data directory""" + offline_data_file : str + """The offline data file""" + + + ############## Environment specific parameters ############## + 
max_num_stores_loads: int + """The maximum number of loads in the nested loops""" + max_num_loops: int + """The max number of nested loops""" + max_num_load_store_dim: int + """The max number of dimensions in load/store buffers""" + num_tile_sizes: int + """The number of tile sizes""" + vect_size_limit: int + """Vectorization size limit to prevent large sizes vectorization""" + order: list[list[str]] + """The order of actions that needs to bo followed""" + interchange_mode: Literal['enumerate', 'pointers', 'continuous'] + """The method used for interchange action""" + exploration: list[Literal['entropy', 'epsilon']] + """The exploration method""" + init_epsilon: float + """The initial epsilon value for epsilon greedy exploration""" + + normalize_bounds: Literal['none', 'max', 'log'] + """Flag to indicate if the upper bounds in the input should be normalized or not""" + + + split_ops: bool + """Flag to enable splitting operations into separate benchmarks""" + + activation: Literal["relu", "tanh"] + """The activation function to use in the network""" + + benchmarks_folder_path: str + """Path to the benchmarks folder. Can be empty if optimization mode is set to "last".""" + + bench_count: int + """Number of batches in a trajectory""" + + truncate: int + """Maximum number of steps in the schedule""" + json_file: str + """Path to the JSON file containing the benchmarks execution times.""" + eval_json_file: str + """Path to the JSON file containing the benchmarks execution times for evaluation.""" + + tags: list[str] + """List of tags to add to the neptune experiment""" + + debug: bool + """Flag to enable debug mode""" + + exec_data_file: str + """Path to the file containing the execution data""" + results_dir: str + """Path to the results directory""" + + loaded: bool + """Flag to check if the config was already loaded from JSON file or not""" + + def __init__(self): + """Initialize the default values""" + # IQL specific parameters + self.gamma = 0.99 + self.tau = 0.7 + self.inverse_temperature = 3.0 + self.alpha = 0.005 + self.batch_size = 256 + self.learning_rate = { + "value": 3e-4, + "q": 3e-4, + "policy": 3e-4 + } + self.max_steps = 1000000 + self.target_update_freq = 1 + self.sparse_reward = True + + self.offline_data_directory = "./data" + self.offline_data_file = "offline_data.npz" + + # Environment specific parameters + self.max_num_stores_loads = 2 + self.max_num_loops = 4 + self.max_num_load_store_dim = 2 + self.num_tile_sizes = 2 + self.vect_size_limit = 16 + self.order = [] + self.interchange_mode = 'continuous' + self.exploration = ['entropy', 'epsilon'] + self.init_epsilon = 1.0 + self.normalize_bounds = 'log' + self.split_ops = False + self.activation = "relu" + self.benchmarks_folder_path = "./benchmarks" + self.bench_count = 1 + self.truncate = 20 + self.json_file = "./config/exec_times.json" + self.eval_json_file = "./config/exec_times.json" + self.tags = [] + self.debug = False + self.exec_data_file = "./data/exec_data.npz" + self.results_dir = "./results" + self.loaded = False + + def load_from_json(self): + """Load the configuration from the JSON file.""" + # Open the JSON file + with open(os.getenv("OFFLINE_RL_CONFIG_FILE_PATH"), "r") as f: + config = json.load(f) + # Set the configuration values + for key, value in config.items(): + if hasattr(self, key): + setattr(self, key, value) + + def to_dict(self): + """Convert the configuration to a dictionary.""" + return self.__dict__ + + def __str__(self): + """Convert the configuration to a string.""" + return 
str(self.to_dict()) \ No newline at end of file diff --git a/iql/policy.py b/iql/policy.py new file mode 100644 index 0000000..5c27a43 --- /dev/null +++ b/iql/policy.py @@ -0,0 +1,82 @@ +import torch +import torch.nn as nn +from torch.distributions import Distribution +from typing import Optional, List, Type +from rl_autoschedular import config as cfg +from rl_autoschedular.actions import ActionSpace, Interchange +from rl_autoschedular.observation import Observation, ObservationPart + +# Match PPO’s activation config +ACTIVATION = nn.ReLU if cfg.activation == "relu" else nn.Tanh + + +class IQLPolicyModel(nn.Module): + """ + IQL policy network, sharing architecture with PPO’s PolicyModel. + - Backbone: 3×512 MLP with ACTIVATION() + - Heads: one for action selection + one per action’s parameterization + - Output: list[Distribution], via ActionSpace.distributions + - Loss: BC loss with advantage-weighted log-likelihood (AWAC / IQL style) + """ + + def __init__(self, obs_parts: List[Type[ObservationPart]]): + super().__init__() + self.obs_parts = obs_parts + + + # Shared encoder + in_size = sum(part.size() for part in obs_parts) + self.backbone = nn.Sequential( + nn.Linear(in_size, 512), + ACTIVATION(), + nn.Linear(512, 512), + ACTIVATION(), + nn.Linear(512, 512), + ACTIVATION(), + ) + + # One head for action choice + one for each action’s params + output_sizes = [ActionSpace.size()] + [ + action.network_output_size() for action in ActionSpace.supported_actions + ] + self.heads_attributes = [f"head_{i}" for i in range(len(output_sizes))] + + for head_attr, output_size in zip(self.heads_attributes, output_sizes): + if not output_size: + setattr(self, head_attr, None) + continue + + head = nn.Linear(512, output_size) + if cfg.new_architecture: + head = nn.Sequential(nn.Linear(512, 512), ACTIVATION(), head) + setattr(self, head_attr, head) + + def forward(self, obs: torch.Tensor) -> List[Optional[Distribution]]: + """ + Forward pass: produce a Distribution object per action head. + """ + embedded = self.backbone(Observation.get_parts(obs, *self.obs_parts)) + heads: List[Optional[nn.Module]] = [getattr(self, attr) for attr in self.heads_attributes] + actions_logits = [head(embedded) if head else None for head in heads] + return ActionSpace.distributions(obs, *actions_logits) + + def loss( + self, + actions_log_p: torch.Tensor, + advantages: torch.Tensor, + beta: float = 1.0, + ) -> torch.Tensor: + """ + Advantage-weighted behavioral cloning (AWAC) / IQL policy loss. + Args: + obs: Observations [B, ...] + actions_log_p: log π(a|s) from this policy, evaluated at dataset actions + advantages: Advantage estimates A(s,a) from IQL (Q - V) + beta: Temperature scaling (larger beta = more deterministic) + Returns: + Scalar loss tensor. 
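+        Note:
+            The exponential weights exp(A / beta) are clamped (max 100.0) before
+            averaging, for numerical stability.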
+ """ + # Weights = exp(A / beta), clipped for stability + weights = torch.exp(advantages / beta).clamp(max=100.0) + loss = -(weights * actions_log_p).mean() + return loss \ No newline at end of file diff --git a/iql/q_functions.py b/iql/q_functions.py new file mode 100644 index 0000000..881713a --- /dev/null +++ b/iql/q_functions.py @@ -0,0 +1,281 @@ +from dotenv import load_dotenv +load_dotenv(override=True) + + +import torch +import torch.nn as nn +from typing import List, Optional, Tuple, Type + +# reuse same Observation/ActionSpace imports you have in repo +from rl_autoschedular.observation import Observation, ObservationPart, OpFeatures, ActionHistory +from rl_autoschedular.actions import ActionSpace +from rl_autoschedular.state import OperationState +from rl_autoschedular import config as cfg + +# Keep same activation used elsewhere +ACTIVATION = nn.ReLU # replace with your actual ACTIVATION if different + + +class _DiscreteQHead(nn.Module): + """A head that outputs Q-values for each discrete option (flat logits -> Q-values).""" + def __init__(self, in_dim: int, out_dim: int): + super().__init__() + # single linear mapping from embedding -> out_dim Q-values + self.net = nn.Linear(in_dim, out_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # returns [B, out_dim] + return self.net(x) + + +class _TwinHiearchicalQNetwork(nn.Module): + """ + One Q network: + - backbone: embed observation -> [B, embed_dim] + - head_action: Q for selecting each action type [B, |A|] + - param_heads: for each action, a head producing Q-values for that action's network_output_size + NOTE: For multi-slot actions (e.g. Tiling), param_head outputs are flat: a concatenation + of per-slot Q-values. We'll reshape inside q_contribs when computing param-contribution. 
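+    Example (hypothetical sizes): an action with params_size=4 slots and 8 classes per
+    slot has out_dim=32; q_contribs reshapes the flat head output to [B, 4, 8] before
+    gathering the chosen class per slot.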
+ """ + def __init__(self, obs_parts: List[Type[ObservationPart]], embed_dim: int = 512): + super().__init__() + self.obs_parts = obs_parts + in_size = sum(p.size() for p in obs_parts) + + # Backbone + self.backbone = nn.Sequential( + nn.Linear(in_size, embed_dim), + ACTIVATION(), + nn.Linear(embed_dim, embed_dim), + ACTIVATION(), + nn.Linear(embed_dim, embed_dim), + ACTIVATION(), + ) + + # Action selection Q head + self.head_action = _DiscreteQHead(embed_dim, ActionSpace.size()) + + # Parameter heads (one per supported action) + self.param_heads = nn.ModuleList() + # Save meta per action for convenient reshaping later + self._param_meta: List[Optional[dict]] = [] + for action_cls in ActionSpace.supported_actions: + out_dim = action_cls.network_output_size() + params_size = action_cls.params_size() + if out_dim and out_dim > 0: + head = _DiscreteQHead(embed_dim, out_dim) + self.param_heads.append(head) + # if multi-slot, compute classes_per_slot = out_dim // params_size (integer) + if params_size > 0: + classes_per_slot = None + if params_size > 0: + classes_per_slot = out_dim // params_size if params_size != 0 else None + self._param_meta.append({ + "params_size": params_size, # number of slots for this action + "out_dim": out_dim, + "classes_per_slot": classes_per_slot # number of classes per slot (None if single slot) + # Example : Interchange -> params_size=1 , out_dim= 7 , classes_per_slot= 7 , 7 loop choices for the current interchange + }) + else: + self._param_meta.append({"params_size": 0, "out_dim": out_dim, "classes_per_slot": None}) + else: + # no parameters -> placeholder (we'll treat as None) + # ( NT | V ) + self.param_heads.append(nn.Identity()) + self._param_meta.append(None) + + def forward( + self, + obs: torch.Tensor, + action_idx: torch.LongTensor, + param_indices: Optional[List[Optional[torch.LongTensor]]] = None + ) -> torch.Tensor: + emb = self._embed(obs) + return self.q_contribs(emb, action_idx, param_indices) + + def _embed(self, obs: torch.Tensor) -> torch.Tensor: + parts = Observation.get_parts(obs, *self.obs_parts) # returns [B, in_size] + return self.backbone(parts) # [B, embed_dim] + + def q_contribs( + self, + emb: torch.Tensor, + action_idx: torch.LongTensor, # [B] + param_indices: Optional[List[Optional[torch.LongTensor]]] = None # list of length B + ) -> torch.Tensor: + """ + Compute Q(s, a, params) as: Q_action(s)[a] + Q_params(s, a, params). 
+ - emb: [B, embed_dim] + - action_idx: [B] integers in [0..|A|-1] + - param_indices: list of length B (each either None or [params_size]) + Returns: + q_total: [B] - scalar Q for each sample + """ + B = emb.size(0) + device = emb.device + + # ---- top-level action contribution ---- + act_qs = self.head_action(emb) # [B, |A|] + act_q = act_qs.gather(1, action_idx.view(-1, 1)).squeeze(1) # [B] + + # ---- parameter contribution ---- + param_q = torch.zeros(B, device=device) + + # group samples by chosen action to do batched head computation + for k, head in enumerate(self.param_heads): + meta = self._param_meta[k] + if isinstance(head, nn.Identity) or (meta is None): + continue + + # mask = all samples where chosen action == k + mask = (action_idx == k) + if not mask.any(): + continue + + # get embeddings and their chosen param indices for this action + emb_masked = emb[mask] + head_out = head(emb_masked) # [N_mask, out_dim_k] + + psize = meta["params_size"] # + out_dim = meta["out_dim"] + cps = meta["classes_per_slot"] + + # collect just the param indices for the masked samples + masked_params = [param_indices[i] for i in range(B) if mask[i]] # [N_mask, Optional[LongTensor] of consistet size psize] + # they should all be not None if this action has params + assert all((p is not None) for p in masked_params) or psize == 0 + + if psize == 0: + continue + + if psize == 1: + # single-slot + idx_tensor = torch.stack([p.view(-1)[0] for p in masked_params]).view(-1, 1) # [N_mask, 1] + q_k = head_out.gather(1, idx_tensor).squeeze(1) # [N_mask] + param_q[mask] = q_k + else: + # multi-slot + assert cps is not None and cps > 0, "classes_per_slot unknown for multi-slot head" + + reshaped = head_out.view(-1, psize, cps) # [N_mask, psize, cps] + idx_tensor = torch.stack(masked_params).long() # [N_mask, psize] + idx_exp = idx_tensor.unsqueeze(-1) # [N_mask, psize, 1] + gathered = torch.gather(reshaped, dim=2, index=idx_exp).squeeze(-1) # [N_mask, psize] + q_k = gathered.sum(dim=1) # [N_mask] + param_q[mask] = q_k + + return act_q + param_q + + + + + +# ---------------- hierarchical double Q network ---------------- +class IQLTwinQ(nn.Module): + """ + Top-level twin Q network that matches style of PolicyModel: + - builds two Q-branches (Q1 and Q2), each using _TwinHiearchicalQNetwork + - provides helpers to split flat action-tensor into action_idx + param slices + """ + def __init__(self, obs_parts: List[Type[ObservationPart]], embed_dim: int = 512): + super().__init__() + self.obs_parts = obs_parts + # instantiate two Q heads (Q1, Q2) + self.q1 = _TwinHiearchicalQNetwork(obs_parts, embed_dim=embed_dim) + self.q2 = _TwinHiearchicalQNetwork(obs_parts, embed_dim=embed_dim) + + + def forward(self, obs: torch.Tensor, index: torch.LongTensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Compute Q1(s, a, params) and Q2(s, a, params) for a batch. + - obs: observation tensor [B, ...] + - index: flat index tensor from ActionSpace.sample() -> [B, 1 + sum(params)] + Returns: + (q1_vals, q2_vals) each shaped [B] + """ + action_idx, params_list = self._split_action_tensor(index) + q1_vals = self.q1(obs, action_idx, params_list) + q2_vals = self.q2(obs, action_idx, params_list) + return q1_vals, q2_vals + + def q_values(self, obs: torch.Tensor, index: torch.LongTensor) -> torch.Tensor: + """ + Compute min(Q1(s, a, params), Q2(s, a, params)) for a batch. + - obs: observation tensor [B, ...] 
+ - index: flat index tensor from ActionSpace.sample() -> [B, 1 + sum(params)] + Returns: + q_vals shaped [B] + """ + action_idx, params_list = self._split_action_tensor(index) + q1_vals = self.q1(obs, action_idx, params_list) + q2_vals = self.q2(obs, action_idx, params_list) + return torch.min(q1_vals, q2_vals) + + + def loss(self, obs: torch.Tensor, index: torch.LongTensor, target_q: torch.Tensor) -> torch.Tensor: + """ + Compute MSE loss between Q1, Q2 and target_q. + - obs: observation tensor [B, ...] + - index: flat index tensor from ActionSpace.sample() -> [B, 1 + sum(params)] + - target_q: target Q-values [B] + Returns: + scalar loss + """ + q1_vals, q2_vals = self.forward(obs, index) # each [B] + loss_fn = nn.MSELoss() + loss1 = loss_fn(q1_vals, target_q) + loss2 = loss_fn(q2_vals, target_q) + return loss1 + loss2 + + @staticmethod + def _split_action_tensor(index: torch.LongTensor) -> Tuple[torch.LongTensor, List[Optional[torch.LongTensor]]]: + """ + Split the `index` tensor returned by ActionSpace.sample() into: + - action_idx: [B] + - params: list of length B, each either None (no params) or LongTensor [params_size] for that action + """ + B = index.size(0) + device = index.device + + action_idx = index[:, 0].long() # [B] + cum = ActionSpace.cumulative_params_sizes() + + params: List[Optional[torch.LongTensor]] = [] + + for i in range(B): + a_idx = action_idx[i].item() + action_type = ActionSpace.supported_actions[a_idx] + if action_type.params_size() == 0: + params.append(None) + else: + start, end = cum[a_idx], cum[a_idx + 1] + # extract just that sample's params for its chosen action + params.append(index[i, start:end].long()) + + return action_idx, params + + + +def main(): + model = IQLTwinQ([OpFeatures, ActionHistory]) + + _model = _TwinHiearchicalQNetwork([OpFeatures, ActionHistory]) + + x = torch.tensor([[2, 1, 5, 7, 0, 0, 0, 0, 3, 4, 2, 0, 0, 0, 0, 2]]).float() + + + action_idx , param_idx = model._split_action_tensor(x) + + + obs = torch.zeros([1, 2152]) + + + q = _model(obs, action_idx, param_idx) + + print("Q-values:", q) + + + +if __name__ == "__main__": + + main() \ No newline at end of file diff --git a/iql/singleton.py b/iql/singleton.py new file mode 100644 index 0000000..252d3aa --- /dev/null +++ b/iql/singleton.py @@ -0,0 +1,8 @@ +class Singleton(type): + """Meta class to create a singleton instance of a class""" + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] \ No newline at end of file diff --git a/iql/value_function.py b/iql/value_function.py new file mode 100644 index 0000000..ae524b0 --- /dev/null +++ b/iql/value_function.py @@ -0,0 +1,69 @@ +import torch +import torch.nn as nn +from typing import List, Type +from rl_autoschedular import config as cfg +from rl_autoschedular.observation import Observation, ObservationPart + + +ACTIVATION = nn.ReLU + + +class IQLValueModel(nn.Module): + """ + IQL Value function with the SAME encoder/MLP layout as PPO's ValueModel: + Linear(sum(obs_parts)->512) -> ACT -> 512 -> ACT -> 512 -> ACT -> 1 + - Input: full Observation tensor, then sliced via Observation.get_parts to match PPO. + - Output: V(s) as shape [B], same squeeze(-1) behavior as PPO. + - Loss: Expectile regression with parameter tau (IQL). 
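+    With tau = 0.7, positive residuals (Q > V) are weighted 0.7 and negative residuals
+    0.3, so V(s) is pushed toward an upper expectile of Q(s, a) over dataset actions.
+
+    Minimal usage sketch (assumes obs_parts such as OpFeatures / ActionHistory from
+    rl_autoschedular.observation, and a detached q_min target):
+        value_model = IQLValueModel([OpFeatures, ActionHistory], tau=0.7)
+        v = value_model(obs)                 # [B]
+        loss = value_model.loss(obs, q_min)  # expectile regression toward q_min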
+ """ + + def __init__( + self, + obs_parts: List[Type[ObservationPart]], + tau: float = 0.7, + ): + super().__init__() + self.obs_parts = obs_parts + self.tau = cfg.tau # consider wiring this from cfg (e.g., cfg.iql.tau) if you keep hyperparams in config + + in_size = sum(part.size() for part in obs_parts) + self.network = nn.Sequential( + nn.Linear(in_size, 512), + ACTIVATION(), + nn.Linear(512, 512), + ACTIVATION(), + nn.Linear(512, 512), + ACTIVATION(), + nn.Linear(512, 1), + ) + + def forward(self, obs: torch.Tensor) -> torch.Tensor: + """ + Args: + obs: full Observation tensor (like in PPO) + Returns: + V(s) as [B] + """ + x = Observation.get_parts(obs, *self.obs_parts) + return self.network(x).squeeze(-1) # [B] + + @torch.no_grad() + def v(self, obs: torch.Tensor) -> torch.Tensor: + """Convenience alias often used in IQL codepaths.""" + return self.forward(obs) + + def loss(self, obs: torch.Tensor, q_values: torch.Tensor) -> torch.Tensor: + """ + Expectile regression loss: minimize E[ w_tau(u) * u^2 ], u = Q(s,a) - V(s) + Args: + obs: full Observation tensor for states [B, ...] (same as PPO input) + q_values: [B] or [B,1] tensor with target Q(s,a) (DETACHED upstream in IQL) + """ + v = self.forward(obs) # [B] + q = q_values.squeeze(-1) # [B] + diff = q - v # u + + # weight = |tau - 1(u < 0)| + # same as: tau if u >= 0 else (1 - tau) + weight = torch.abs(self.tau - (diff < 0).float()) + return (weight * diff.pow(2)).mean() \ No newline at end of file diff --git a/llm_action/.env.example b/llm_action/.env.example new file mode 100644 index 0000000..6e429ca --- /dev/null +++ b/llm_action/.env.example @@ -0,0 +1,6 @@ +# LLM +ANTHROPIC_API_KEY = "sk-ant-api03-..." + +# MLIR +MLIR_SHARED_LIBS=/path/to/llvm-project/build/lib/libomp.so,/path/to/llvm-project/build/lib/libmlir_c_runner_utils.so,/path/to/llvm-project/build/lib/libmlir_runner_utils.so +AST_DUMPER_BIN_PATH=/path/to/MLIR-RL/tools/ast_dumper/build/bin/AstDumper diff --git a/llm_action/data/conv2d/conv_2d_nchw_fchw_128_32_7_7_256_1_1_7_7.mlir b/llm_action/data/conv2d/conv_2d_nchw_fchw_128_32_7_7_256_1_1_7_7.mlir new file mode 100644 index 0000000..ed926d0 --- /dev/null +++ b/llm_action/data/conv2d/conv_2d_nchw_fchw_128_32_7_7_256_1_1_7_7.mlir @@ -0,0 +1,10 @@ +module { + func.func private @nanoTime() -> i64 attributes {llvm.emit_c_interface} + func.func @main(%arg0: tensor<128x32x7x7xf64>, %arg1: tensor<256x32x1x1xf64>, %arg2: tensor<128x256x7x7xf64>) -> (tensor<128x256x7x7xf64>, i64) attributes {llvm.emit_c_interface} { + %0 = call @nanoTime() : () -> i64 + %1 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<128x32x7x7xf64>, tensor<256x32x1x1xf64>) outs(%arg2 : tensor<128x256x7x7xf64>) -> tensor<128x256x7x7xf64> + %2 = call @nanoTime() : () -> i64 + %3 = arith.subi %2, %0 : i64 + return %1, %3 : tensor<128x256x7x7xf64>, i64 + } +} diff --git a/llm_action/data/matmul/matmul_128_256_128.mlir b/llm_action/data/matmul/matmul_128_256_128.mlir new file mode 100644 index 0000000..1db3597 --- /dev/null +++ b/llm_action/data/matmul/matmul_128_256_128.mlir @@ -0,0 +1,10 @@ +module { + func.func private @nanoTime() -> i64 attributes {llvm.emit_c_interface} + func.func @main(%arg0: tensor<128x256xf64>, %arg1: tensor<256x128xf64>, %arg2: tensor<128x128xf64>) -> (tensor<128x128xf64>, i64) attributes {llvm.emit_c_interface} { + %0 = call @nanoTime() : () -> i64 + %1 = linalg.matmul ins(%arg0, %arg1 : tensor<128x256xf64>, tensor<256x128xf64>) outs(%arg2 : 
tensor<128x128xf64>) -> tensor<128x128xf64> + %2 = call @nanoTime() : () -> i64 + %3 = arith.subi %2, %0 : i64 + return %1, %3 : tensor<128x128xf64>, i64 + } +} diff --git a/llm_action/log/test/test_20251221_141044.txt b/llm_action/log/test/test_20251221_141044.txt new file mode 100644 index 0000000..e98355b --- /dev/null +++ b/llm_action/log/test/test_20251221_141044.txt @@ -0,0 +1,224 @@ +I'll explore the RL action space by starting with a baseline benchmark, then systematically discovering transformations. Let me begin with Phase 1: establishing the base performance. + +## Phase 1: Baseline Benchmark + +First, let me execute the base code without any transformations:Great! The base code executes correctly (assertion passed) with **5,566,048 nanoseconds** baseline execution time. + +--- + +## Phase 2: Testing Action 1 - `tile_using_for` + +Now let me test the first transformation: **tiling with sequential for loops** using tile sizes [32, 32]:Excellent! The transformation succeeded. Now let me execute this transformed code:Perfect! The tiled code executes correctly. Now let me measure the speedup:--- + +## Phase 3: Testing Action 2 - `generalize` + `interchange` + +Now let me test a composition: first generalize the matmul to generic form, then try interchange:Great! The generalize transformation worked and converted the named matmul to a generic linalg operation. Let me execute it:Good! Now let me try combining generalize with tile_using_for to see how they compose. Let me re-tag after generalize and then apply tiling:Perfect! The composition works. Let me execute this to verify correctness:Excellent! Now let me measure this speedup:--- + +## Phase 4: Testing Action 3 - Different Tiling Configurations + +Let me test different tile sizes to understand the parameter space:--- + +# RL Action Space Discovery Summary + +## Baseline Performance +- **Base Code Execution Time**: 5,566,048 ns +- **Assertion Status**: ✅ PASSED (all transformations produce correct results) + +--- + +## Discovered Actions + +### ACTION 1: `tile_using_for` +**Status**: ✅ VALID ACTION + +**Description**: Sequential tiling using `scf.for` loops. Divides the iteration space into tiles and creates nested loops to iterate over tiles, then dispatches smaller tiled operations. + +**Parameters**: +- `tile_sizes`: List[int], length = number of input dimensions + - Each element: [0, dimension_size] + - 0 = skip this dimension (no tiling) + - Non-zero = tile size for that dimension + - Does NOT need to divide evenly (last tile may be smaller) + - Example: [32, 32, 0] tiles first two dimensions by 32 + +**Valid Parameter Ranges**: +| Parameter | Min | Max | Notes | +|-----------|-----|-----|-------| +| tile_size_0 | 0 | 128 | 0=no tiling, 1-128 valid | +| tile_size_1 | 0 | 128 | 0=no tiling, 1-128 valid | +| tile_size_2 | 0 | 256 | K dimension (reduction), often 0 | + +**Preconditions**: +- Operation must be a structured linalg operation (linalg.matmul, linalg.generic, etc.) 
+- Operation must have matching tag: `{tag = "operation_0"}` + +**Postconditions**: +- Returns tuple: `(tiled_op, loop_1, loop_2, ...)` + - Number of loops = number of non-zero tile sizes +- Original tag is **LOST** on tiled_op +- Introduces `scf.for` loops with `iter_args` for functional control flow +- Creates `tensor.extract_slice` and `tensor.insert_slice` operations +- Transformation is **idempotent** with same parameters + +**Performance Impact** (on matmul 128x256x128): +| Tile Size | Execution Time | Speedup | +|-----------|---|---| +| [32, 32, 0] | 7,329,471 ns | 0.76x ⚠️ | +| [64, 64, 0] | 5,647,546 ns | 0.99x ≈ | +| Base (no tile) | 5,566,048 ns | 1.0x baseline | + +**Composability**: +- ✅ Can follow: `generalize`, initial state +- ✅ Can precede: another `tile_using_for` (nested tiling), `generalize` again +- ⚠️ Requires re-tagging before next `structured.match` (but works without re-tag when chaining directly) +- ✅ Can compose with: `generalize → tile_using_for` in sequence + +**Example Code**: +```mlir +%op = transform.structured.match attributes{tag = "operation_0"} in %arg0 : (!transform.any_op) -> !transform.any_op +%tiled, %loops:2 = transform.structured.tile_using_for %op tile_sizes [32, 32, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) +``` + +**RL Action Space Entry**: +```json +{ + "name": "tile_using_for", + "parameters": { + "tile_size_dim0": {"type": "int", "range": [0, 128], "default": 32}, + "tile_size_dim1": {"type": "int", "range": [0, 128], "default": 32}, + "tile_size_dim2": {"type": "int", "range": [0, 256], "default": 0} + }, + "preconditions": ["has_structured_op", "has_operation_tag"], + "returns": "tuple(tiled_op, loops[])", + "postcondition_note": "loses tag on tiled_op", + "composability": "can_chain_with_tile, generalize" +} +``` + +--- + +### ACTION 2: `generalize` +**Status**: ✅ VALID ACTION + +**Description**: Converts a named structured operation (e.g., `linalg.matmul`, `linalg.conv`) into a generic `linalg.generic` form with explicit indexing maps and blocks. + +**Parameters**: +- None (no tunable parameters) + +**Preconditions**: +- Operation must be a named structured operation +- Supports: `linalg.matmul`, `linalg.add`, `linalg.generic` (already generic), etc. 
+ +**Postconditions**: +- Converts named op to `linalg.generic` with: + - Explicit `indexing_maps` (affine maps showing tensor access patterns) + - `iterator_types` array (["parallel", "parallel", "reduction"]) + - Explicit `^bb0` block with computation +- **Tag is LOST** after generalize +- Enables operations that require generic form (e.g., `interchange`) + +**Performance Impact** (on matmul 128x256x128): +| Scenario | Execution Time | Speedup | +|----------|---|---| +| Generalize only | 7,363,424 ns | 0.76x ⚠️ | +| Generalize + Tile [32,32,0] | 5,629,696 ns | 0.99x ≈ | + +**Composability**: +- ✅ Can follow: Initial state (before any transforms) +- ✅ Can precede: `tile_using_for`, `interchange` (requires generalize) +- ⚠️ Composition: `generalize → tile_using_for` works perfectly +- ❌ Not compatible: `interchange` on named ops (must generalize first) + +**Example Code**: +```mlir +%op = transform.structured.match attributes{tag = "operation_0"} in %arg0 : (!transform.any_op) -> !transform.any_op +%generalized = transform.structured.generalize %op : (!transform.any_op) -> !transform.any_op +``` + +**RL Action Space Entry**: +```json +{ + "name": "generalize", + "parameters": {}, + "preconditions": ["has_structured_op"], + "returns": "generic_op", + "postcondition_note": "loses_tag, converts_to_generic_form", + "required_for": ["interchange"], + "composability": "enables_subsequent_generic_transforms" +} +``` + +--- + +## Composition Patterns Discovered + +### ✅ Pattern 1: Direct Tiling +``` +[Initial Code] → tile_using_for[32,32,0] → [4×4 tiles] +Status: VALID +Performance: 0.76x (slightly slower, bufferization overhead) +Use case: Learning tiling effect, preparation for other transforms +``` + +### ✅ Pattern 2: Generalize then Tile +``` +[Initial Code] → generalize → tile_using_for[32,32,0] → [Tiled generic ops] +Status: VALID +Performance: 0.99x (near-baseline) +Use case: Foundation for more complex optimization chains +``` + +### ✅ Pattern 3: Multi-size Tiling Parameter Study +``` +tile_sizes: [32,32,0] vs [64,64,0] +Impact: + [32,32,0]: 0.76x (more overhead from extra loop iterations) + [64,64,0]: 0.99x (fewer loop iterations, less overhead) +Lesson: Larger tiles can reduce loop overhead +``` + +--- + +## Constraints & Failure Modes Discovered + +| Constraint | Description | Impact | +|-----------|-----------|--------| +| **Tag Loss on Transform** | After any transform, original tag is lost from output. Chaining requires re-annotation. | Affects RL state tracking | +| **Direct Chaining** | Can apply tile twice without explicit re-tagging (transform infrastructure handles it) | Not documented in spec | +| **Generalize Prerequisite** | `interchange` requires generalize first; fails with named ops | Hard ordering constraint | +| **Dimension Ordering** | 3D iteration space (2 parallel + 1 reduction). Reduction dim (k) often 0. | Parameter constraint | +| **K-Dimension Tiling** | Setting tile_sizes[2] > 0 is valid but creates extra loops on reduction axis | Valid but affects semantics | + +--- + +## Unexplored Actions (TODOs for Full Space) + +1. **tile_using_forall** - Parallel tiling variant +2. **interchange** - Loop reordering (blocked by generalize precondition) +3. **pad** - Add padding for alignment +4. **vectorize** - Convert to vector operations +5. **decompose** - Break ops into simpler forms +6. **split** / **split_reduction** - Other tiling variants +7. 
**fuse_into_containing_op** - Operation fusion + +--- + +## Key Metrics + +| Metric | Value | +|--------|-------| +| Total Actions Validated | 2 | +| Successful Transforms | 3/3 | +| Composition Patterns | 3 | +| Parameter Configurations Tested | 4 | +| Assertion Pass Rate | 100% | + +--- + +## Recommendations for RL Integration + +1. **State Representation**: Must track whether operation is "tagged" vs "lost tag" +2. **Action Ordering**: `generalize` enables `interchange`; build dependency graph +3. **Parameter Tuning**: Tile sizes [32-64] seem reasonable; explore [8-128] range systematically +4. **Reward Signal**: Currently tilng adds overhead; focus on compositions that improve performance +5. **Next Phase**: Explore `interchange` + composition chains once generalize is in pipeline \ No newline at end of file diff --git a/llm_action/src/__init__.py b/llm_action/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/llm_action/src/agent.py b/llm_action/src/agent.py new file mode 100644 index 0000000..08555bb --- /dev/null +++ b/llm_action/src/agent.py @@ -0,0 +1,178 @@ +import json +from typing import AsyncGenerator, Optional + +from agno.agent import Agent, RunResponseEvent +from agno.playground import Playground, PlaygroundSettings + +from llm_action.src.llm import get_claude_llm +from llm_action.src.tools import transform_code, execute_code, measure_speedup +from llm_action.src.prompt import SYSTEM_INSTRUCTIONS +from llm_action.src.config import NUM_HISTORY_RUNS + +from llm_action.src.utils.log import logger + +class MLIR_LLM_Agent: + def __init__(self): + self.name = "MLIR LLM Agent" + self.description = "An agent that assists with MLIR automatic code optimization." + self.model = get_claude_llm() + self.agent = Agent( + name=self.name, + description=self.description, + model=self.model, + instructions=SYSTEM_INSTRUCTIONS, + tools=[measure_speedup, transform_code, execute_code], + memory=None, + storage=None, + add_history_to_messages=True, + num_history_runs=NUM_HISTORY_RUNS, + show_tool_calls=True, + markdown=True, + ) + + self.tool_execution = False + +class AgentWrapper: + def __init__(self): + self.mlir_llm_agent = MLIR_LLM_Agent() + # logger.info("[Agent] MLIR LLM Agent initialized") + + async def run_stream(self, benchmark_code: str) -> AsyncGenerator[str, None]: + """ + Run the Agno agents for a benchmark_code + """ + + response_stream = await self.mlir_llm_agent.agent.arun( + message=benchmark_code, + stream=True, + stream_intermediate_steps=True, + ) + + agent_response = "" + + async for event in response_stream: + output = self._format_event(event) + if output: + if output.get('type') == 'content': + agent_response += output['content'] + yield output + + # logger.info(f"[Agent]: Response: {agent_response}") + + def _format_event(self, event: RunResponseEvent) -> Optional[str]: + """ + Converts an agent event into a string for streaming/yielding. 
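+        Handled events are returned as small dicts (type / status / content fields);
+        only unhandled events fall back to a plain string.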
+ """ + match event.event: + + # Run Events + case "RunStarted": + return { + "type": "run", + "status": "started" + } + case "RunResponseContent": + if not self.mlir_llm_agent.tool_execution: + return { + "type": "content", + "content": event.content + } + case "RunCompleted": + return { + "type": "run", + "status": "completed" + } + case "RunError": + return { + "type": "error", + "error": f"{event.content}" + } + case "RunCanceled": + return { + "type": "run", + "status": f"canceled" + } + case "RunPaused": + return { + "type": "run", + "status": "paused" + } + case "RunContinued": + return { + "type": "run", + "status": "continued" + } + + # Tool Events + case "ToolCallStarted": + + self.mlir_llm_agent.tool_execution = True + return { + "type": "tool", + "status": "started", + "name": event.tool.tool_name, + "arguments": event.tool.tool_args, + } + case "ToolCallCompleted": + self.mlir_llm_agent.tool_execution = False + + try: + tool_result = json.loads(event.tool.result) + except json.JSONDecodeError: + tool_result = event.tool.result + + return { + "type": "tool", + "status": "completed", + "name": event.tool.tool_name, + "result": tool_result, + } + + # Reasoning Events + case "ReasoningStarted": + return { + "type": "reasoning", + "status": "started", + } + case "ReasoningStep": + return { + "type": "reasoning", + "status": "step", + "content": event.content + } + case "ReasoningCompleted": + return { + "type": "reasoning", + "status": "completed", + "content": event.content + } + + # Memory Events + case "MemoryUpdateStarted": + return { + "type": "memory_update", + "status": "started", + } + case "MemoryUpdateCompleted": + return { + "type": "memory_update", + "status": "completed", + "content": event.content + } + + # Default case + case _: + return f"Unhandled event: {event.event}" + + def run_playground(self) -> None: + """ + Run the agent playground server. 
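+        (See run_llm_playground.py at the repository root, added in this patch.)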
+ """ + + playground = Playground( + agents=[self.mlir_llm_agent.agent], + settings=PlaygroundSettings(env="dev") + ) + app = playground.get_app() + + playground.serve(app) diff --git a/llm_action/src/config.py b/llm_action/src/config.py new file mode 100644 index 0000000..a2b9df5 --- /dev/null +++ b/llm_action/src/config.py @@ -0,0 +1,10 @@ +# LLM +CLAUDE_LLM_MODEL = "claude-haiku-4-5" +CLAUDE_LLM_TEMPERATURE = 1.0 + +# Execution +CODE_TRANSFORM_TIMEOUT = 10 # seconds +CODE_EXECUTION_TIMEOUT = 10 # seconds + +# History +NUM_HISTORY_RUNS = 10 \ No newline at end of file diff --git a/llm_action/src/keys.py b/llm_action/src/keys.py new file mode 100644 index 0000000..bad3f28 --- /dev/null +++ b/llm_action/src/keys.py @@ -0,0 +1,9 @@ +import os +import dotenv + +dotenv.load_dotenv() + +ANTHROPIC_API_KEY = os.getenv('ANTHROPIC_API_KEY') + +MLIR_SHARED_LIBS = os.getenv("MLIR_SHARED_LIBS") +AST_DUMPER_BIN_PATH = os.getenv("AST_DUMPER_BIN_PATH") diff --git a/llm_action/src/llm.py b/llm_action/src/llm.py new file mode 100644 index 0000000..c49fc17 --- /dev/null +++ b/llm_action/src/llm.py @@ -0,0 +1,12 @@ +from agno.models.anthropic import Claude + +from llm_action.src.keys import ANTHROPIC_API_KEY +from llm_action.src.config import CLAUDE_LLM_MODEL, CLAUDE_LLM_TEMPERATURE + +def get_claude_llm(): + llm = Claude( + id=CLAUDE_LLM_MODEL, + temperature=CLAUDE_LLM_TEMPERATURE, + api_key=ANTHROPIC_API_KEY, + ) + return llm diff --git a/llm_action/src/prompt.py b/llm_action/src/prompt.py new file mode 100644 index 0000000..6cfedc0 --- /dev/null +++ b/llm_action/src/prompt.py @@ -0,0 +1,300 @@ + +SYSTEM_INSTRUCTIONS = """ +You are an MLIR transformation explorer agent. Your mission is to systematically discover, implement, and parametrize MLIR transformations to build a comprehensive action space for reinforcement learning. + +## Your Purpose + +You are NOT trying to optimize code. You are **exploring what transformations are possible** in MLIR and **how they can be parametrized**. Think of yourself as a cartographer mapping uncharted territory - every valid transformation you discover expands the action space that an RL agent can later use for optimization. + +## Your Capabilities + +You have access to two tools: + +1. **transform_code**: Apply MLIR transform dialect sequences to modify code structure +2. **execute_code**: Compile and run code to verify transformation validity (correctness matters, speed doesn't yet) +3. **measure_speedup**: Measure speedup ratio between base and transformed code execution times + +## Your Mission: Build the RL Action Space + +For each transformation you explore, document: + +1. **Action Name**: What is this transformation called? +2. **Parameters**: What knobs can be tuned? (tile sizes, axis indices, boolean flags, etc.) +3. **Parameter Ranges**: What values are valid? What are the constraints? +4. **Preconditions**: When can this action be applied? (what operation types, what state must the IR be in?) +5. **Postconditions**: What does the IR look like after? (what new operations appear, what tags need tracking?) +6. **Composability**: Can this action chain with others? What must come before/after? 
+ +## Transformation Categories to Explore + +### Category 1: Tiling Variants +- `tile_using_for` - Sequential tiling +- `tile_using_forall` - Parallel tiling +- Multi-level tiling (nested applications) +- Partial tiling (some dimensions only) + +**Parameters to discover:** +- Tile sizes (per dimension) +- Which dimensions to tile +- Number of tiling levels + +### Category 2: Loop Transformations +- `interchange` - Reorder loop dimensions +- `peel` - Handle loop remainders +- Loop unrolling (if available) + +**Parameters to discover:** +- Permutation orderings +- Peel factors +- Unroll factors + +### Category 3: Data Layout +- `pad` - Add padding for alignment +- `pack` / `unpack` - Data layout transformations +- `hoist_pad` - Move padding operations + +**Parameters to discover:** +- Padding amounts +- Pack tile sizes +- Hoist levels + +### Category 4: Operation Transformations +- `generalize` - Convert named ops to generic form +- `decompose` - Break complex ops into simpler ones +- `vectorize` - Enable SIMD operations +- `lower_to_loops` - Convert to explicit loops + +**Parameters to discover:** +- Vector widths +- Decomposition strategies + +### Category 5: Fusion & Composition +- `fuse_into_containing_op` - Combine operations +- `fuse` - Fuse producer/consumer + +**Parameters to discover:** +- Which operations to fuse +- Fusion ordering + +### Category 6: Lowering & Conversion +- Various lowering passes +- Dialect conversions + +## CRITICAL: MLIR Transform Dialect Syntax + +### Working with Tagged Operations + +**IMPORTANT**: Input code has operations tagged with `{tag = "operation_0"}`. Use these tags to match operations. + +**Basic Pattern:** +```mlir +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %op = transform.structured.match attributes{tag = "operation_0"} in %arg0 : (!transform.any_op) -> !transform.any_op + + // Apply transformation here + + transform.yield + } +} +``` + +### Re-tagging Pattern (Essential for Chaining) + +When chaining transformations, re-tag after each step: + +```mlir +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + // Match and transform + %op = transform.structured.match attributes{tag = "operation_0"} in %arg0 : (!transform.any_op) -> !transform.any_op + %tiled, %loops:2 = transform.structured.tile_using_for %op tile_sizes [32, 32, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + + // Re-tag for next transformation + %tag = transform.param.constant "operation_0" -> !transform.any_param + transform.annotate %tiled "tag" = %tag : !transform.any_op, !transform.any_param + + // Now can match again + %op2 = transform.structured.match attributes{tag = "operation_0"} in %arg0 : (!transform.any_op) -> !transform.any_op + transform.structured.vectorize %op2 : !transform.any_op + + transform.yield + } +} +``` + +**SYNTAX RULES:** +- Match using `attributes{tag = "operation_N"}` +- Re-annotate after transformations for chaining +- `tile_using_for` with N non-zero sizes returns N+1 results +- Bind all results: `%tiled, %loops:2` +- Don't forget re-tagging when chaining + +## Exploration Strategy + +### Phase 1: Enumerate Individual Actions +For each transformation type: +1. Try it in isolation +2. Record if it succeeds or fails +3. Document parameter constraints discovered +4. 
Note preconditions (what IR state is needed) + +### Phase 2: Parameter Space Mapping +For each working transformation: +1. Try different parameter values +2. Find valid ranges (what causes failures?) +3. Identify discrete vs continuous parameters +4. Note parameter dependencies + +### Phase 3: Composability Testing +1. Try pairs of transformations +2. Document which orderings are valid +3. Find required intermediate steps +4. Map the composition graph + +### Phase 4: Edge Cases & Constraints +1. What happens at boundary conditions? +2. What operation types support which transforms? +3. What are the failure modes? + +## Output Format: Action Space Documentation + +Structure your exploration as action discovery: + +``` +=== ACTION DISCOVERY LOG === + +--- Action: tile_using_for --- +Status: VALID ACTION + +Parameters: + - tile_sizes: List[int], length = num_dimensions + - Valid range: 1 to dimension_size (or 0 to skip dimension) + - Must divide evenly OR peeling handles remainder + +Preconditions: + - Target must be a structured (linalg) operation + - Operation must have the matched tag + +Postconditions: + - Returns: (tiled_op, loop1, loop2, ...) - one loop per non-zero tile + - Original tag is LOST - must re-annotate tiled_op + - Creates nested scf.for loops around tiled operation + +Composability: + - Can follow: (initial state), pad, generalize + - Can precede: vectorize, another tile, interchange + - Requires re-tagging before next structured.match + +Example (verified working): +```mlir +%op = transform.structured.match attributes{tag = "operation_0"} in %arg0 : (!transform.any_op) -> !transform.any_op +%tiled, %loops:2 = transform.structured.tile_using_for %op tile_sizes [32, 32, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) +``` + +RL Action Space Entry: +{ + "name": "tile_using_for", + "parameters": { + "tile_size_0": {"type": "int", "range": [1, 256], "default": 32}, + "tile_size_1": {"type": "int", "range": [1, 256], "default": 32}, + "tile_size_2": {"type": "int", "range": [0, 256], "default": 0} + }, + "preconditions": ["is_structured_op", "has_tag"], + "requires_retag": true +} + +--- + +--- Action: interchange --- +Status: VALID ACTION (requires generalize first) + +Parameters: + - iterator_interchange: List[int], permutation of [0, 1, ..., n-1] + - Must be valid permutation + - Length = number of iterator dimensions + +Preconditions: + - Must apply `generalize` first (named ops like matmul don't support directly) + - Operation must be generic linalg + +[... continue for each action discovered ...] + +--- + +=== FAILED TRANSFORMATIONS (Constraints Discovered) === + +--- Attempted: interchange on linalg.matmul directly --- +Status: FAILED +Error: "expects a GenericOp" +Lesson: Must generalize before interchange +Constraint: interchange.precondition += "is_generic_op" + +--- + +=== COMPOSITION PATTERNS DISCOVERED === + +Pattern: generalize → interchange → tile → vectorize +Status: Valid composition +Notes: Interchange changes loop order, then tile the reordered loops + +Pattern: tile → tile (nested) +Status: Valid - creates hierarchical tiling +Notes: Each level needs re-tagging + +Pattern: vectorize → tile +Status: Invalid ordering +Notes: Vectorize should come after tiling + +--- + +=== ACTION SPACE SUMMARY === + +Confirmed Actions: +1. tile_using_for (params: tile_sizes[]) +2. tile_using_forall (params: tile_sizes[]) +3. generalize (params: none) +4. interchange (params: permutation[], requires: generalize) +5. 
vectorize (params: none, requires: usually after tiling) +[...] + +Discovered Constraints: +- interchange requires generalize as precondition +- vectorize typically needs tiling first for effectiveness +- All structured transforms lose tags → re-annotation required +[...] + +Unexplored (TODO): +- pad: not yet tested +- decompose: not yet tested +- fusion operations: not yet tested +[...] +``` + +## Key Questions to Answer + +For building the RL action space: + +1. **What actions exist?** (enumerate all transform dialect operations) +2. **What are their parameters?** (continuous, discrete, categorical?) +3. **What are valid parameter ranges?** (min, max, constraints) +4. **What preconditions exist?** (required IR state, prior transforms) +5. **What postconditions result?** (how does IR change, what needs tracking) +6. **How do actions compose?** (valid orderings, required intermediates) +7. **What are the failure modes?** (invalid parameters, wrong preconditions) + +## Mindset + +- **Breadth over depth**: Try many different transformations rather than perfecting one +- **Failure is data**: A failed transformation teaches you constraints +- **Parameters are key**: An action without known parameters can't be used by RL +- **Composability matters**: RL will chain actions, so document what chains work +- **Correctness over speed**: Verify transforms produce valid IR (execution returns True), don't worry about performance yet + +## Remember + +You're building the foundation for RL-driven optimization. Every action you discover and parametrize becomes a tool the RL agent can use. Be systematic, document everything, and explore widely! + +# Initial PoC +Just try 2 to 3 actions only, just to debug the process cost-effectively. Start first benchmarking the base code without any transformations. (Avoid vectorization for now) +""" diff --git a/llm_action/src/tools.py b/llm_action/src/tools.py new file mode 100644 index 0000000..724aae2 --- /dev/null +++ b/llm_action/src/tools.py @@ -0,0 +1,103 @@ +from llm_action.src.utils.transformation import transform_bufferize_and_lower_v, execute_bufferized_code, run_transform_code + +from agno.tools import tool + +@tool( + name="measure_speedup", + description=""" + Measures the speedup achieved by MLIR transformations. + + This tool compares the execution time of base code against transformed code + to calculate the performance improvement factor. + + Use this when you need to: + - Quantify optimization effectiveness + - Compare performance before and after transformations + - Calculate speedup ratios + - Evaluate transformation impact + + Args: + base_execution_time: Execution time of original code in nanoseconds + execution_time: Execution time of transformed code in nanoseconds + + Returns: + The speedup ratio as a float (base_time / transformed_time) + """, + show_result=True, + stop_after_tool_call=False +) +def measure_speedup(base_execution_time: float, execution_time: float) -> float: + return base_execution_time / execution_time + +@tool( + name="transform_code", + description=""" + Applies MLIR transformations to the given code using custom transformation scripts. + + This tool takes base MLIR code and applies user-defined transformations to it, + allowing for optimization passes, dialect conversions, or other code modifications. 
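+
+    The transformation script must parse as an MLIR transform-dialect module whose
+    first operation is the transform.named_sequence to apply (e.g. @__transform_main,
+    as shown in the system instructions); it is run through the transform interpreter.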
+ + Use this when you need to: + - Apply specific MLIR transformation passes to code + - Test different optimization strategies + - Convert between MLIR dialects + - Modify MLIR operations programmatically + + Args: + code: The base MLIR code to transform + transformation_code: The transformation script/pass to apply + + Returns: + The transformed MLIR code as a string + """, + show_result=True, + stop_after_tool_call=False +) +def transform_code(code: str, transformation_code: str) -> str: + return run_transform_code(code, transformation_code) + +@tool( + name="execute_code", + description=""" + Executes MLIR code and measures its performance with assertion validation. + + This tool compiles and runs MLIR code through a bufferization and lowering pipeline, + then executes it to measure real execution time and verify correctness through assertions. + + Use this when you need to: + - Benchmark MLIR code performance + - Verify that transformations maintain correctness + - Measure execution time in nanoseconds + - Validate code functionality through assertions + + The code goes through: + 1. Bufferization (converts tensor operations to memref) + 2. Lowering (converts high-level dialects to lower-level representations) + 3. Execution with timing and assertion checking + + Args: + code: The MLIR code to execute + + Returns: + tuple[int, bool]: (execution time in nanoseconds, assertion success/failure) + - First element: Real execution time measured in nanoseconds + - Second element: True if all assertions passed, False otherwise + """, + show_result=True, + stop_after_tool_call=False +) +def execute_code(code: str) -> tuple[int, bool]: + """Evaluates the given MLIR code with a timeout. + + Args: + state (OperationState): The operation state to evaluate. + tmp_exec_data_file (str): The path to the temporary execution data file. + + Returns: + tuple[int, bool]: (execution time in nanoseconds, assertion result) + """ + + bufferized_code = transform_bufferize_and_lower_v(code) + real_exec_time, success = execute_bufferized_code(bufferized_code) + return real_exec_time, success + diff --git a/llm_action/src/utils/cache.py b/llm_action/src/utils/cache.py new file mode 100644 index 0000000..e69de29 diff --git a/llm_action/src/utils/log.py b/llm_action/src/utils/log.py new file mode 100644 index 0000000..a38f3fe --- /dev/null +++ b/llm_action/src/utils/log.py @@ -0,0 +1,10 @@ +import logging + +logger = logging.getLogger("MLIR-LLM") + +logger.setLevel(logging.INFO) + +handler = logging.StreamHandler() +formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +handler.setFormatter(formatter) +logger.addHandler(handler) diff --git a/llm_action/src/utils/processing.py b/llm_action/src/utils/processing.py new file mode 100644 index 0000000..80523a7 --- /dev/null +++ b/llm_action/src/utils/processing.py @@ -0,0 +1,16 @@ +from rl_autoschedular.state import extract_bench_features_from_code + +def preprocess_code(code: str) -> str: + """Preprocess the given MLIR code to extract benchmark features. + + Args: + code (str): The MLIR code as a string. + + Returns: + str: The preprocessed MLIR code. 
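+
+    Note:
+        The code is passed through the external AST dumper (via
+        extract_bench_features_from_code), so the returned code may differ from the
+        raw input (e.g. operation tags such as {tag = "operation_0"} may be added).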
+ """ + # Extract benchmark features (this may include transformations) + bench_features = extract_bench_features_from_code("", code, 0) + + # Return the (possibly modified) code + return bench_features.code diff --git a/llm_action/src/utils/transformation.py b/llm_action/src/utils/transformation.py new file mode 100644 index 0000000..4c72e15 --- /dev/null +++ b/llm_action/src/utils/transformation.py @@ -0,0 +1,253 @@ +from typing import Optional +import numpy as np +import ctypes.util + +from mlir.ir import Context, Module, MemRefType, IntegerType, F64Type, F32Type +from mlir.passmanager import PassManager +from mlir.execution_engine import ExecutionEngine +from mlir.dialects.func import FuncOp +from mlir.runtime import get_ranked_memref_descriptor, make_nd_memref_descriptor, as_ctype, ranked_memref_to_numpy + +from mlir.ir import Context, Module +from mlir.dialects.transform import interpreter +from utils.bindings_process import BindingsProcess + +from llm_action.src.keys import MLIR_SHARED_LIBS +from llm_action.src.config import CODE_TRANSFORM_TIMEOUT, CODE_EXECUTION_TIMEOUT + +def free_pointer(ptr: ctypes.c_void_p): + # Find the C standard library + libc_path = ctypes.util.find_library('c') + if not libc_path: + raise RuntimeError("C standard library not found.") + libc = ctypes.CDLL(libc_path) + + # Define the signature for free + free = libc.free + free.argtypes = [ctypes.c_void_p] + free.restype = None + + # Call free + free(ptr) + +def convert_to_args(inputs: list[np.ndarray], outputs_structure: ctypes.Structure): + args: list[ctypes._Pointer[ctypes._Pointer[ctypes.Structure]]] = [] + args.append(ctypes.pointer(ctypes.pointer(outputs_structure))) + for in_arr in inputs: + args.append(ctypes.pointer(ctypes.pointer( + get_ranked_memref_descriptor(in_arr) + ))) + return args + +def create_params(module: Module): + def __get_dtype(memref_type: MemRefType): + et = memref_type.element_type + match et: + case F32Type(): + np_dtype = np.float32 + case F64Type(): + np_dtype = np.float64 + case IntegerType(): + match et.width: + case 32: + np_dtype = np.int32 + case 64: + np_dtype = np.int64 + case _: + raise Exception(f'unexpected element type {et}') + case _: + raise Exception(f'unexpected element type {et}') + return np_dtype + + # Get the main function + main_func = next(op for op in module.body.operations if isinstance(op, FuncOp) and (op.name.value == 'main')) + + # Create input params + inputs: list[np.ndarray] = [] + for input_type in main_func.type.inputs: + assert isinstance(input_type, MemRefType), f'unexpected input type {input_type}' + in_arr = np.zeros(input_type.shape, dtype=__get_dtype(input_type)) + inputs.append(in_arr) + + # Create results arg + res_types = main_func.type.results + + exec_time_type = res_types[-1] + if not (isinstance(exec_time_type, IntegerType) and exec_time_type.width == 64): + raise Exception(f'unexpected exec time type {exec_time_type}') + + out_fields: list[tuple[str, type[ctypes.Structure]]] = [] + for i, out_type in enumerate(res_types[:-1]): + assert isinstance(out_type, MemRefType), f'unexpected output type {out_type}' + descriptor_type = make_nd_memref_descriptor(out_type.rank, as_ctype(__get_dtype(out_type))) + out_fields.append((f'out_{i}', descriptor_type)) + + class OutputsStructure(ctypes.Structure): + _fields_ = [ + *out_fields, + ("delta", ctypes.c_int64) + ] + delta: int + + def get_results(self): + res: list[np.ndarray] = [] + for field_name, _ in out_fields: + out_array = ranked_memref_to_numpy([getattr(self, field_name)]) + 
res.append(out_array.copy()) + return res + + def free_outputs(self): + for field_name, mem_desc_T in out_fields: + memref_descriptor: ctypes.Structure = getattr(self, field_name) + allocated_ptr: Optional[ctypes.c_longlong] = getattr(memref_descriptor, 'allocated', None) + + if allocated_ptr: + address = ctypes.cast(allocated_ptr, ctypes.c_void_p) + if address.value: + free_pointer(address) + setattr(self, field_name, mem_desc_T()) + + outputs_structure = OutputsStructure() + for i, (field_name, field_type) in enumerate(out_fields): + out_arg = field_type() + setattr(outputs_structure, field_name, out_arg) + + return inputs, outputs_structure + +def run_transform_code(code: str, transform_code: str, timeout: int = CODE_TRANSFORM_TIMEOUT) -> str: + """Applies an MLIR transform sequence to the given code. + + Args: + code (str): The MLIR code to transform. + transform_code (str): The MLIR transform dialect code to apply. + timeout (int, optional): Maximum time for transformation in seconds. Defaults to CODE_TRANSFORM_TIMEOUT. + + Returns: + str: The transformed MLIR code as a string. + """ + + def transform_bind_call(): + with Context(): + module = Module.parse(code) + t_module = Module.parse(transform_code) + interpreter.apply_named_sequence(module, t_module.body.operations[0], t_module) + + return str(module) + + return BindingsProcess.call(transform_bind_call, timeout=timeout) + +def transform_bufferize_and_lower_v(code: str): + """Apply the vectorization transformation with vectorizer to the specified operation in the given code. + + Args: + code (str): The code to apply the transformation to. + + Returns: + str: The code after applying the transformation. + """ + transform_code = """ + module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.consumed}) { + %all_loops = transform.structured.match interface{LoopLikeInterface} in %arg0 : (!transform.any_op) -> !transform.any_op + transform.apply_licm to %all_loops : !transform.any_op + + transform.structured.eliminate_empty_tensors %arg0 : !transform.any_op + %empty = transform.structured.match ops{["tensor.empty"]} in %arg0 : (!transform.any_op) -> !transform.op<"tensor.empty"> + transform.bufferization.empty_tensor_to_alloc_tensor %empty : (!transform.op<"tensor.empty">) -> !transform.op<"bufferization.alloc_tensor"> + + %f0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %f0 { + transform.apply_patterns.vector.transfer_permutation_patterns + transform.apply_patterns.vector.reduction_to_contract + } : !transform.any_op + transform.apply_patterns to %f0 { + transform.apply_patterns.canonicalization + transform.apply_patterns.tensor.fold_tensor_subset_ops_into_vector_transfers + } : !transform.any_op + + %arg1 = transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap} %arg0 {bufferize_function_boundaries = true} : (!transform.any_op) -> !transform.any_op + + %f1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %f1 { + transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct" + transform.apply_patterns.vector.transfer_permutation_patterns + transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerparallel" + transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy" + 
transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true + transform.apply_patterns.vector.lower_transfer max_transfer_rank = 1 + transform.apply_patterns.vector.lower_shape_cast + transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_1d" + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.yield + } + }""" + + return run_transform_code(code, transform_code) + +def execute_bufferized_code(code: str, timeout: int = CODE_EXECUTION_TIMEOUT) -> tuple[int, bool]: + """Lowers and runs the given MLIR code using Python bindings, then returns the execution time and assertion + result (if the executed code returns the correct result). + + Args: + code (str): The MLIR code to run. + timeout (int): The maximum time to allow for code execution in seconds. + + Returns: + int: the execution time in seconds. + bool: the assertion result. + """ + + def execute_bind_call(): + pass_pipeline = """builtin.module( + canonicalize, + buffer-deallocation-pipeline, + convert-bufferization-to-memref, + convert-linalg-to-loops, + scf-forall-to-parallel, + convert-scf-to-openmp, + expand-strided-metadata, + finalize-memref-to-llvm, + convert-scf-to-cf, + lower-affine, + + convert-openmp-to-llvm, + convert-vector-to-llvm, + convert-math-to-llvm, + convert-math-to-libm, + finalize-memref-to-llvm, + convert-func-to-llvm, + convert-index-to-llvm, + convert-arith-to-llvm, + convert-cf-to-llvm, + + reconcile-unrealized-casts, + canonicalize, + cse + )""" + + with Context(): + module = Module.parse(code) + pm = PassManager.parse(pass_pipeline) + + inputs, outs_struct = create_params(module) + args = convert_to_args(inputs, outs_struct) + + pm.run(module.operation) + execution_engine = ExecutionEngine( + module, + opt_level=3, + shared_libs=MLIR_SHARED_LIBS.split(","), + ) + + try: + for _ in range(2): + execution_engine.invoke("main", *args) + # If output tensors are needed call `get_results` before `free_outputs` + outs_struct.free_outputs() + finally: + outs_struct.free_outputs() + + return outs_struct.delta, True + + return BindingsProcess.call(execute_bind_call, timeout=timeout) diff --git a/neptune_sync.py b/neptune_sync.py index 99785ee..a97ad7c 100644 --- a/neptune_sync.py +++ b/neptune_sync.py @@ -1,6 +1,4 @@ -# Load environment variables -from dotenv import load_dotenv -load_dotenv(override=True) +from utils.keys import NEPTUNE_TOKEN, NEPTUNE_PROJECT import neptune from neptune import Run @@ -33,7 +31,8 @@ with open(os.path.join(run_path, 'tags'), 'r') as f: tags = f.read().splitlines() neptune_run = neptune.init_run( - project=os.getenv('NEPTUNE_PROJECT'), + project=NEPTUNE_PROJECT, + api_token=NEPTUNE_TOKEN, tags=tags, ) neptune_runs[run] = neptune_run diff --git a/requirements.txt b/requirements.txt index 15a4965..5409d52 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,6 @@ neptune tqdm dask-jobqueue typeguard +anthropic +agno==1.7.11 +fastapi \ No newline at end of file diff --git a/rl_autoschedular/actions/__init__.py b/rl_autoschedular/actions/__init__.py index 208c1d3..5bea362 100644 --- a/rl_autoschedular/actions/__init__.py +++ b/rl_autoschedular/actions/__init__.py @@ -6,6 +6,7 @@ from .tiled_fusion import TiledFusion from .interchange import Interchange from .vectorization import Vectorization +from .img2col import Img2Col from rl_autoschedular.state import OperationState import torch from torch.distributions import Distribution, Categorical @@ -21,7 +22,8 @@ class ActionSpace: TiledParallelization, 
TiledFusion, Interchange, - Vectorization + Vectorization, + Img2Col ] @classmethod diff --git a/rl_autoschedular/actions/img2col.py b/rl_autoschedular/actions/img2col.py new file mode 100644 index 0000000..1775f43 --- /dev/null +++ b/rl_autoschedular/actions/img2col.py @@ -0,0 +1,33 @@ +from utils.config import Config +from .base import Action +from rl_autoschedular.transforms import transform_img2col +from rl_autoschedular.state import OperationFeatures, OperationState, OperationType +from typing import Callable, Optional + + +class Img2Col(Action): + """Class representing Img2Col action""" + + symbol = 'I2C' + parameters: None + + def __init__( + self, + state: Optional[OperationState] = None, + **extras + ): + super().__init__( + state, + **extras + ) + + def _apply_ready(self, code): + original_code = code + try: + return transform_img2col(code, self.operation_tag) + except Exception: + return original_code + + @classmethod + def is_allowed(cls, state: OperationState): + return state.operation_features.operation_type == OperationType.Conv diff --git a/rl_autoschedular/benchmarks.py b/rl_autoschedular/benchmarks.py index b19671a..c9827f4 100644 --- a/rl_autoschedular/benchmarks.py +++ b/rl_autoschedular/benchmarks.py @@ -35,10 +35,13 @@ def __init__(self, is_training: bool = True): benchmark_data = extract_bench_features_from_file(bench_name, bench_file, root_exec_time) modified = False bench_code = benchmark_data.code + print("BEFORE: ", bench_code) for op_tag in benchmark_data.operation_tags: - if 'conv_2d' not in benchmark_data.operations[op_tag].operation_name: + print(op_tag) + if 'conv_2d' not in benchmark_data.operations[op_tag].operation_name or not cfg.use_img2col: continue bench_code = transform_img2col(bench_code, op_tag) + print("AFTER: ", bench_code) modified = True if modified: benchmark_data = extract_bench_features_from_code(bench_name, bench_code, root_exec_time) diff --git a/rl_autoschedular/env.py b/rl_autoschedular/env.py index 019c9e2..150ae71 100644 --- a/rl_autoschedular/env.py +++ b/rl_autoschedular/env.py @@ -2,13 +2,14 @@ from rl_autoschedular.benchmarks import Benchmarks from typing import Optional from rl_autoschedular.execution import Execution -from rl_autoschedular.actions import Action, TiledFusion +from rl_autoschedular.actions import Action, TiledFusion, Img2Col from utils.log import print_error from utils.config import Config import random import math import traceback +from rl_autoschedular.state import extract_bench_features_from_code class Env: """RL Environment class""" @@ -48,6 +49,7 @@ def step(self, state: OperationState, action: Action) -> OperationState: bool: A flag indicating if the operation is done. Optional[float]: The speedup (if the operation is executed successfully) for logging purposes. 
""" + # Copy the current state to introduce the changes throughout the function next_state = state.copy() @@ -249,6 +251,21 @@ def __update_state_infos(self, state: OperationState, action: Action): # In case of fusion we need to update the producer features as well if isinstance(action, TiledFusion): action.update_producer_features(state, self.benchmark_data) + + # In case of Img2Col, we need to update the benchmark data as whole + if isinstance(action, Img2Col): + self.benchmark_data = extract_bench_features_from_code(self.benchmark_data.bench_name, action.apply(self.benchmark_data.code), self.benchmark_data.root_exec_time) + i2c_state = self.__init_op_state(-1) + + state.bench_idx = i2c_state.bench_idx + state.bench_name = i2c_state.bench_name + state.operation_tag = i2c_state.operation_tag + state.original_operation_features = i2c_state.original_operation_features + state.operation_features = i2c_state.operation_features + state.producer_tag = i2c_state.producer_tag + state.producer_operand_idx = i2c_state.producer_operand_idx + state.producer_features = i2c_state.producer_features + state.terminal = i2c_state.terminal # Get updated operation features state.operation_features = action.update_features(state.operation_features) diff --git a/rl_autoschedular/ppo.py b/rl_autoschedular/ppo.py index 9a73803..a46946f 100644 --- a/rl_autoschedular/ppo.py +++ b/rl_autoschedular/ppo.py @@ -19,7 +19,7 @@ from typing import Optional -def collect_trajectory(data: Benchmarks, model: Model, step: int): +def collect_trajectory(data: Benchmarks, model: Model, step: int) -> TrajectoryData: """Collect a trajectory using the model and the environment. Args: @@ -34,7 +34,7 @@ def collect_trajectory(data: Benchmarks, model: Model, step: int): dm = DaskManager() fl = FileLogger() exe = Execution() - cfg = Config() + cfg = Config() eps = None if 'epsilon' in cfg.exploration: @@ -362,7 +362,7 @@ def evaluate_benchmarks(model: Model, data: Benchmarks): def __execute_states(state: OperationState, exec_data_file: str, benchs: Benchmarks, main_exec_data: Optional[dict[str, dict[str, int]]]): - print(f"Handling bench: {state.bench_name}...", end=' ', file=sys.stderr) + print(f"Handling bench: {state.bench_name} with {[[action.__str__() for action in transformation] for transformation in state.transformation_history]}", end=' ', file=sys.stderr) worker_start = time() Execution(exec_data_file, main_exec_data) diff --git a/rl_autoschedular/state.py b/rl_autoschedular/state.py index f2ed3d6..c99ec2a 100644 --- a/rl_autoschedular/state.py +++ b/rl_autoschedular/state.py @@ -8,6 +8,8 @@ from utils.config import Config from utils.log import print_error +from utils.keys import AST_DUMPER_BIN_PATH + if TYPE_CHECKING: from rl_autoschedular.actions.base import Action @@ -44,6 +46,16 @@ class NestedLoopFeatures: def copy(self): """Copy the current NestedLoopFeatures object.""" return NestedLoopFeatures(self.arg, self.lower_bound, self.upper_bound, self.step, self.iterator_type) + + def to_dict(self): + """Convert the NestedLoopFeatures to a dictionary.""" + return { + "arg": self.arg, + "lower_bound": self.lower_bound, + "upper_bound": self.upper_bound, + "step": self.step, + "iterator_type": self.iterator_type.value + } @dataclass @@ -84,6 +96,23 @@ def copy(self): self.vectorizable, self.pre_actions.copy() ) + + def to_dict(self): + """Convert the OperationFeatures to a dictionary.""" + return { + "operation_name": self.operation_name, + "operation_type": self.operation_type.value, + "op_count": self.op_count, + "load_data": 
self.load_data, + "store_data": self.store_data, + "nested_loops": [ + loop.to_dict() for loop in self.nested_loops + ], + "producers": self.producers, + "consumers": self.consumers, + "vectorizable": self.vectorizable, + "pre_actions": [action.symbol for action in self.pre_actions] + } @dataclass @@ -109,6 +138,18 @@ def copy(self): {tag: op.copy() for tag, op in self.operations.items()}, self.root_exec_time ) + + def to_dict(self): + """Convert the BenchmarkFeatures to a dictionary.""" + return { + "bench_name": self.bench_name, + "code": self.code, + "operation_tags": self.operation_tags, + "operations": { + tag: op.to_dict() for tag, op in self.operations.items() + }, + "root_exec_time": self.root_exec_time + } @dataclass @@ -173,6 +214,23 @@ def copy(self): self.terminal ) + def to_dict(self): + """Convert the OperationState to a dictionary.""" + return { + "bench_idx": self.bench_idx, + "bench_name": self.bench_name, + "operation_tag": self.operation_tag, + "original_operation_features": self.original_operation_features.to_dict(), + "operation_features": self.operation_features.to_dict(), + "producer_tag": self.producer_tag, + "producer_operand_idx": self.producer_operand_idx, + "producer_features": self.producer_features.to_dict() if self.producer_features is not None else None, + "transformation_history": [ + [action.__repr__() for action in seq] for seq in self.transformation_history + ], + "terminal": self.terminal + } + def extract_bench_features_from_code(bench_name: str, code: str, root_execution_time: int): """Extract benchmark features from the given code. @@ -187,7 +245,7 @@ def extract_bench_features_from_code(bench_name: str, code: str, root_execution_ BenchmarkFeatures: the extracted benchmark features """ result = subprocess.run( - f'{os.getenv("AST_DUMPER_BIN_PATH")} -', + f'{AST_DUMPER_BIN_PATH} -', shell=True, input=code.encode('utf-8'), stdout=subprocess.PIPE, @@ -211,7 +269,7 @@ def extract_bench_features_from_file(bench_name: str, file_path: str, root_execu BenchmarkFeatures: the extracted benchmark features """ result = subprocess.run( - f'{os.getenv("AST_DUMPER_BIN_PATH")} {file_path}', + f'{AST_DUMPER_BIN_PATH} {file_path}', shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE @@ -348,5 +406,7 @@ def __get_operation_type(operation_name: str): """ for operation_type in OperationType: if operation_type.value and operation_type.value in operation_name: + if operation_type.value == "conv" and ( "op0" in operation_name or "op1" in operation_name or "i2c" in operation_name): + return OperationType.unknown return operation_type return OperationType.unknown diff --git a/rl_autoschedular/trajectory.py b/rl_autoschedular/trajectory.py index 57a9033..303ae2c 100644 --- a/rl_autoschedular/trajectory.py +++ b/rl_autoschedular/trajectory.py @@ -255,6 +255,7 @@ def update_attributes(self, model: Model): """ start = time() + actions_old_log_p, values, _ = model(self.obs.to(device), self.actions_index.to(device)) self.actions_old_log_p, self.values = actions_old_log_p.cpu(), values.cpu() @@ -404,3 +405,12 @@ def reset(self): self.actions_bev_log_p.clear() self.rewards.clear() self.done.clear() + + def __len__(self) -> int: + """Get the length of the collected trajectory. + + Returns: + int: The length of the collected trajectory. 
+ """ + assert len(self.num_loops) == len(self.actions_index) == len(self.obs) == len(self.next_obs) == len(self.actions_bev_log_p) == len(self.rewards) == len(self.done) + return len(self.obs) \ No newline at end of file diff --git a/rl_autoschedular/transforms.py b/rl_autoschedular/transforms.py index e59451c..89791cd 100644 --- a/rl_autoschedular/transforms.py +++ b/rl_autoschedular/transforms.py @@ -366,7 +366,6 @@ def transform_bufferize_and_lower_v(code: str): Args: code (str): The code to apply the transformation to. - operation_tag (str): The tag of the operation to apply the transformation to. Returns: str: The code after applying the transformation. diff --git a/run_llm_playground.py b/run_llm_playground.py new file mode 100644 index 0000000..7b446a3 --- /dev/null +++ b/run_llm_playground.py @@ -0,0 +1,7 @@ +import asyncio + +from llm_action.src.agent import AgentWrapper + +if __name__ == "__main__": + agent = AgentWrapper() + agent.run_playground() diff --git a/run_mlir_agent.py b/run_mlir_agent.py new file mode 100644 index 0000000..323579a --- /dev/null +++ b/run_mlir_agent.py @@ -0,0 +1,123 @@ +import asyncio +from datetime import datetime + +from llm_action.src.utils.processing import preprocess_code +from llm_action.src.agent import AgentWrapper + +async def test_agent(): + """Test the MLIR LLM Agent with a matmul benchmark.""" + + # Load the MLIR code + with open("llm_action/data/matmul/matmul_128_256_128.mlir", "r") as f: + mlir_code = f.read() + mlir_code = preprocess_code(mlir_code) + + print("=" * 80) + print("MLIR LLM AGENT TEST") + print("=" * 80) + print(f"\n📄 Input Code:") + print("-" * 80) + print(mlir_code) + print("-" * 80) + print("\n🤖 Agent Starting...\n") + print("=" * 80) + + # Initialize agent + agent = AgentWrapper() + + # Track metrics + transformations = [] + executions = [] + agent_text = [] + + # Run the agent and stream output + async for event in agent.run_stream(mlir_code): + if event is None: + continue + + event_type = event.get('type') + + if event_type == 'run': + status = event.get('status') + if status == 'started': + print("\n🚀 RUN STARTED\n") + elif status == 'completed': + print("\n✅ RUN COMPLETED\n") + elif status == 'error': + print(f"\n❌ ERROR: {event.get('error')}\n") + + elif event_type == 'content': + content = event.get('content', '') + agent_text.append(content) + print(content, end='', flush=True) + + elif event_type == 'tool': + status = event.get('status') + tool_name = event.get('name') + + if status == 'started': + print(f"\n\n🔧 TOOL CALL: {tool_name}") + print("-" * 80) + + args = event.get('arguments', {}) + if tool_name == 'transform_code': + print("📝 Transformation Code:") + print(args.get('transformation_code', '')) + elif tool_name == 'execute_code': + code = args.get('code', '') + print(f"⚙️ Executing code ({len(code)} chars)") + print("-" * 80) + + elif status == 'completed': + result = event.get('result') + print(f"\n✓ Tool Result: {result}") + + if tool_name == 'transform_code': + transformations.append({ + 'result': result + }) + elif tool_name == 'execute_code': + if isinstance(result, (list, tuple)) and len(result) == 2: + exec_time, success = result + executions.append({ + 'time_ns': exec_time, + 'success': success + }) + print(f" ⏱️ Execution Time: {exec_time:,} ns") + print(f" {'✅' if success else '❌'} Correctness: {success}") + print("-" * 80) + + # Summary + print("\n" + "=" * 80) + print("📊 SUMMARY") + print("=" * 80) + print(f"\n🔧 Total Transformations: {len(transformations)}") + print(f"⚙️ Total Executions: 
{len(executions)}") + + if executions: + print(f"\n⏱️ Execution Times:") + for i, exec_data in enumerate(executions, 1): + time_ns = exec_data['time_ns'] + success = exec_data['success'] + status_icon = '✅' if success else '❌' + print(f" {i}. {time_ns:,} ns {status_icon}") + + # Calculate speedup if we have baseline + if len(executions) > 1: + baseline = executions[0]['time_ns'] + best = min(e['time_ns'] for e in executions if e['success']) + speedup = baseline / best if best > 0 else 0 + print(f"\n🚀 Best Speedup: {speedup:.2f}x") + + print("\n" + "=" * 80) + print("Full agent response saved to output.") + print("=" * 80) + + # Optionally save full output + date_format = datetime.now().strftime("%Y%m%d_%H%M%S") + with open(f"llm_action/log/test/test_{date_format}.txt", "w") as f: + f.write("".join(agent_text)) + print(f"\n💾 Full output saved to: llm_action/log/test/test_{date_format}.txt\n") + +if __name__ == "__main__": + asyncio.run(test_agent()) \ No newline at end of file diff --git a/scripts/neptune-sync.sh b/scripts/neptune-sync.sh index e46d1ce..30de37f 100644 --- a/scripts/neptune-sync.sh +++ b/scripts/neptune-sync.sh @@ -2,23 +2,25 @@ # Define the resource requirements here using #SBATCH +# SBATCH -j neptune_sync #SBATCH -p compute +#SBATCH --reservation=c2 #SBATCH --nodes=1 #SBATCH -c 4 #SBATCH --mem=16G #SBATCH -t 07-00 #SBATCH -o logs/neptune/%j.out -#SBATCH --mail-type=FAIL,TIME_LIMIT -#SBATCH --mail-user=mt5383@nyu.edu +#SBATCH --mail-type=ALL +#SBATCH --mail-user=kb5213@nyu.edu -# Resource requiremenmt commands end here +# Resource requirements end here -#Add the lines for running your code/application +# Add the lines for running your code/application module load miniconda-nobashrc eval "$(conda shell.bash hook)" # Activate any environments if required -conda activate testenv +conda activate llvm-build # Execute the code python $SCRATCH/MLIR-RL/neptune_sync.py diff --git a/tests/cuda.py b/tests/cuda.py new file mode 100644 index 0000000..f4afc62 --- /dev/null +++ b/tests/cuda.py @@ -0,0 +1,48 @@ +import torch +import platform +import subprocess + +def check_cuda(): + # Check if CUDA is available + cuda_available = torch.cuda.is_available() + print(f"CUDA Available: {cuda_available}") + + if cuda_available: + # Number of GPUs available + num_gpus = torch.cuda.device_count() + print(f"Number of GPUs: {num_gpus}") + + # CUDA version from torch + print(f"PyTorch CUDA Version: {torch.version.cuda}") + + # Information about each GPU + for i in range(num_gpus): + print(f"\nDevice {i}:") + print(f" Name: {torch.cuda.get_device_name(i)}") + print(f" Total Memory: {torch.cuda.get_device_properties(i).total_memory / (1024 ** 3):.2f} GB") + print(f" Memory Allocated: {torch.cuda.memory_allocated(i) / (1024 ** 3):.2f} GB") + print(f" Memory Cached: {torch.cuda.memory_reserved(i) / (1024 ** 3):.2f} GB") + print(f" Compute Capability: {torch.cuda.get_device_capability(i)}") + print(f" Device Properties: {torch.cuda.get_device_properties(i)}") + + # Driver version + print(f"Driver Version: {torch.version.cuda}") + + # Run nvidia-smi command to get additional GPU details + try: + nvidia_smi_output = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, text=True) + print(f"\nNVIDIA-SMI Output:\n{nvidia_smi_output.stdout}") + except FileNotFoundError: + print("nvidia-smi command not found. 
Ensure NVIDIA drivers are installed correctly.") + + else: + print("No CUDA-enabled GPU detected.") + + # Print system and Python environment information + print("\nSystem and Environment Information:") + print(f" Operating System: {platform.system()} {platform.release()} ({platform.version()})") + print(f" Python Version: {platform.python_version()}") + print(f" PyTorch Version: {torch.__version__}") + +if __name__ == "__main__": + check_cuda() \ No newline at end of file diff --git a/train.py b/train.py index 1c3b85f..d79f10c 100644 --- a/train.py +++ b/train.py @@ -98,6 +98,16 @@ def load_main_exec_data() -> Optional[dict[str, dict[str, int]]]: # Collect trajectory using the model trajectory = collect_trajectory(train_data, model, step) + + # print the sizes of the trajectory components for debugging + # print_info( + # f"Trajectory collected: " + # f"Obs size: {trajectory.obs.size()}, " + # f"Next Obs size: {trajectory.next_obs.size()}, " + # f"Actions size: {trajectory.actions_index.size()}, " + # f"Rewards size: {trajectory.rewards.size()}, ", + # flush=True + # ) # Extend trajectory with previous trajectory if cfg.reuse_experience != 'none': @@ -117,7 +127,7 @@ def load_main_exec_data() -> Optional[dict[str, dict[str, int]]]: ppo_update(trajectory, model, optimizer) # Save the model - if (step + 1) % 5 == 0: + if (step + 1) % cfg.save_model_every == 0: torch.save( model.state_dict(), os.path.join( @@ -126,7 +136,7 @@ def load_main_exec_data() -> Optional[dict[str, dict[str, int]]]: ) ) - if (step + 1) % 100 == 0: + if (step + 1) % cfg.evaluate_every == 0: print_info('- Evaluating benchmarks -') evaluate_benchmarks(model, eval_data) @@ -138,6 +148,6 @@ def load_main_exec_data() -> Optional[dict[str, dict[str, int]]]: elapsed_dlt = timedelta(seconds=int(elapsed)) eta_dlt = timedelta(seconds=int(eta)) -if (step + 1) % 100 != 0: +if (step + 1) % cfg.evaluate_every != 0: print_info('- Evaluating benchmarks -') evaluate_benchmarks(model, eval_data) diff --git a/train_iql_offline.py b/train_iql_offline.py new file mode 100644 index 0000000..f1f6c38 --- /dev/null +++ b/train_iql_offline.py @@ -0,0 +1,208 @@ +import dotenv + +dotenv.load_dotenv() + +import os +import time +import torch +import numpy as np + +from rl_autoschedular import config as cfg, file_logger as fl +from utils.config import Config +from utils.file_logger import FileLogger +from rl_autoschedular.actions import ActionSpace +from rl_autoschedular.env import Env +from iql.agent import IQLAgent +from utils.data_collector import OfflineDataset +from rl_autoschedular.observation import Observation,OpFeatures, ActionHistory + +from tqdm import trange + +device = torch.device("cpu") + +cfg = Config() +fl = FileLogger() + +def load_dataset(): + """Load offline dataset from OfflineDataset singleton.""" + dataset = OfflineDataset( + save_dir=cfg.offline_data_save_dir, + fname=cfg.offline_data_file + ).load() + + if not dataset: + raise FileNotFoundError(f"Offline dataset not found: {cfg.offline_data_file}") + + states = torch.tensor(dataset["obs"], dtype=torch.float32) + actions = torch.tensor(dataset["actions"], dtype=torch.long) + rewards = torch.tensor(dataset["rewards"], dtype=torch.float32) + next_states = torch.tensor(dataset["next_obs"], dtype=torch.float32) + dones = torch.tensor(dataset["dones"], dtype=torch.float32) + + return states, actions, rewards, next_states, dones + + +@torch.no_grad() +def evaluate_benchmarks(model: IQLAgent, env: Env, step: int): + """Evaluta a the model on the evaluation environment. 
+ Args: + model (Model): The policy/value model. + env (Env): The environment. + step (int): Current training step. + Returns: + env_time (float): Time spent in environment steps. + """ + + + env_time = 0.0 # Time spent in environment steps + + eps = None + + + # store rewards and entropies to log average for the model accross the benchmarks later + all_speedups = [] + all_entropies = [] + + + for _ in trange(cfg.bench_count, desc='Trajectory'): + + t0 = time.perf_counter() + state = env.reset() + env_time += time.perf_counter() - t0 + bench_done = False + speedup = None + + # store rewards and entropies to log average for the current benchmark later + bench_rewards, bench_entropies = [], [] + + bench_name = state.bench_name + + + while not bench_done: + obs = Observation.from_state(state) + + # Sample action and log-prob from *current policy* + action_index, action_log_p, entropy = model.sample(obs.to(device)) + assert action_index.size(0) == 1 and action_log_p.size(0) == 1 + action = ActionSpace.action_by_index(action_index[0], state) + + # Step environment + t0 = time.perf_counter() + next_state, reward, op_done, speedup = env.step(state, action) + env_time += time.perf_counter() - t0 + next_obs = Observation.from_state(next_state) + + + if op_done: + t0 = time.perf_counter() + next_state, bench_done = env.get_next_op_state(next_state) + env_time += time.perf_counter() - t0 + + + # Accumulate metrics + bench_rewards.append(reward) + bench_entropies.append(entropy.item()) + state = next_state + + # === Per-benchmark logging === + mean_reward = float(np.mean(bench_rewards)) if bench_rewards else 0.0 + mean_entropy = float(np.mean(bench_entropies)) if bench_entropies else 0.0 + + all_speedups.append(speedup) + all_entropies.extend(bench_entropies) + + + bench_metrics = { + "mean_reward": mean_reward, + "mean_entropy": mean_entropy, + "final_speedup": speedup if speedup is not None else 0.0, + } + + fl.log_scalars(f"eval/{bench_name}", bench_metrics, step) + + print( + f"\033[92m\n- Eval Bench: {bench_name}\n" + f"- Mean Reward: {mean_reward:.4f}\n" + f"- Mean Entropy: {mean_entropy:.4f}\n" + f"- Final Speedup: {speedup if speedup is not None else 0.0:.4f}\033[0m" + ) + + + # === Global logging (across all benchmarks) === + if all_speedups: + fl.log_scalar("eval/average_speedup", float(np.mean(all_speedups)), step) + if all_entropies: + fl.log_scalar("eval/average_entropy", float(np.mean(all_entropies)), step) + + return env_time + +def train_iql(): + # Load offline dataset + print(f"Loading offline dataset from {cfg.offline_data_file} ...") + states, actions, rewards, next_states, dones = load_dataset() + dataset_size = states.shape[0] + print(f"Dataset loaded: {dataset_size} transitions") + + # Initialize IQL agent + agent = IQLAgent(cfg,device,obs_parts=[OpFeatures, ActionHistory]) + + eval_env = Env(is_training=False,run_name=cfg.run_name) + + + print("Starting IQL training ...") + start_time = time.time() + + step = 0 + iql_trange = trange(cfg.max_steps, desc="IQL Training",dynamic_ncols=True) + for step in iql_trange: + # Sample a random batch + idxs = np.random.randint(0, dataset_size, size=cfg.batch_size) + batch = ( + states[idxs].to(device), + actions[idxs].to(device), + rewards[idxs].to(device), + next_states[idxs].to(device), + dones[idxs].to(device), + ) + + losses = agent.update(batch) + + + # Only log occasionally to reduce disk I/O + if step % 50 == 0: + fl.log_scalars("train", losses, step) + + if (step +1) % 100 == 0: + elapsed = time.time() - start_time + 
iql_trange.set_postfix({ + "Value Loss": f"{losses['value']:.4f}", + "Q Loss": f"{losses['q']:.4f}", + "Policy Loss": f"{losses['policy']:.4f}", + "Elapsed": f"{elapsed:.2f}s" + }) + + # Evaluate the agent on benchmarks every 1000 steps + if (step + 1) % 1000 == 0: + print("Evaluating on benchmarks ...") + eval_start = time.time() + env_time = evaluate_benchmarks(agent, eval_env, step) + eval_time = time.time() - eval_start + print(f"Evaluation completed in {eval_time:.2f} seconds (env time: {env_time:.2f} seconds)") + fl.flush() + + + + if (step+1) % 2000 == 0 and step > 0: + ckpt_path = os.path.join(cfg.results_dir, f"iql_step_{step}.pt") + os.makedirs(cfg.results_dir, exist_ok=True) + torch.save(agent.state_dict(), ckpt_path) + print(f"Checkpoint saved: {ckpt_path}") + + + + total_time = time.time() - start_time + print(f"Training finished in {total_time:.2f} seconds.") + + +if __name__ == "__main__": + train_iql() \ No newline at end of file diff --git a/train_iql_online.py b/train_iql_online.py new file mode 100644 index 0000000..3e7a75c --- /dev/null +++ b/train_iql_online.py @@ -0,0 +1,340 @@ +import os +import time +import torch +import numpy as np +from tqdm import trange + +import dotenv +dotenv.load_dotenv() + +from rl_autoschedular import config as cfg, file_logger as fl +from rl_autoschedular.env import Env +from rl_autoschedular.actions import ActionSpace +from rl_autoschedular.observation import Observation, OpFeatures, ActionHistory +from iql.iql_agent import IQLAgent +from utils.data_collector import OfflineDataset + + +device = torch.device("cpu") + + +MAX_STEPS_OFFLINE_ONLINE_RATIO = 100_000 # steps over which to decay online-offline ratio + +UPDATE_ITERS = 3 # gradient updates per batch + +def load_offline_dataset(): + """Load offline dataset for warm-starting replay buffer.""" + dataset = OfflineDataset( + save_dir=cfg.offline_data_save_dir, + fname=cfg.offline_data_file + ).load() + + if not dataset: + raise FileNotFoundError(f"Offline dataset not found: {cfg.offline_data_file}") + + states = torch.tensor(dataset["obs"], dtype=torch.float32) + actions = torch.tensor(dataset["actions"], dtype=torch.long) + rewards = torch.tensor(dataset["rewards"], dtype=torch.float32) + next_states = torch.tensor(dataset["next_obs"], dtype=torch.float32) + dones = torch.tensor(dataset["dones"], dtype=torch.float32) + + return states, actions, rewards, next_states, dones + + +@torch.no_grad() +def evaluate_benchmarks(model: IQLAgent, env: Env, step: int): + """Evaluate model performance across all benchmarks.""" + env_time = 0.0 + eps = None + all_speedups, all_entropies = [], [] + + for _ in trange(cfg.bench_count, desc="Eval Trajectory", leave=False): + t0 = time.perf_counter() + state = env.reset() + env_time += time.perf_counter() - t0 + bench_done, speedup = False, None + bench_rewards, bench_entropies = [], [] + bench_name = state.bench_name + + while not bench_done: + obs = Observation.from_state(state) + action_index, action_log_p, entropy = model.sample(obs.to(device), eps=eps) + action = ActionSpace.action_by_index(action_index[0], state) + + t0 = time.perf_counter() + next_state, reward, op_done, speedup = env.step(state, action) + env_time += time.perf_counter() - t0 + + if op_done: + t0 = time.perf_counter() + next_state, bench_done = env.get_next_op_state(next_state) + env_time += time.perf_counter() - t0 + + bench_rewards.append(reward) + bench_entropies.append(entropy.item()) + state = next_state + + # per-benchmark logs + fl.log_scalars(f"eval/{bench_name}", { + 
"mean_reward": float(np.mean(bench_rewards)) if bench_rewards else 0.0, + "mean_entropy": float(np.mean(bench_entropies)) if bench_entropies else 0.0, + "final_speedup": speedup if speedup is not None else 0.0, + }, step) + + all_speedups.append(speedup) + all_entropies.extend(bench_entropies) + + # global logs + if all_speedups: + fl.log_scalar("eval/average_speedup", float(np.mean(all_speedups)), step) + if all_entropies: + fl.log_scalar("eval/average_entropy", float(np.mean(all_entropies)), step) + + return env_time + + +def collect_warmup_data(agent, env, buffer, warmup_steps=5000): + """ + Collects a fixed number of online transitions before training begins. + This stabilizes early learning by pre-filling the online buffer. + """ + print(f"\n[ Warmup Phase ] Collecting {warmup_steps} online transitions before training...\n") + state = env.reset() + + progress = trange(warmup_steps, desc="Warmup Collection", dynamic_ncols=True) + + + for _ in progress: + obs = Observation.from_state(state) + action_index, _, _ = agent.sample(obs.to(device), eps=None) + action = ActionSpace.action_by_index(action_index[0], state) + + next_state, reward, op_done, _ = env.step(state, action) + next_obs = Observation.from_state(next_state) + + # Handle operation completion + if op_done: + next_state, done = env.get_next_op_state(next_state) + else: + done = False + + # Store in online buffer + buffer.add_online( + obs.to(device), + action_index, + torch.tensor(reward, dtype=torch.float32, device=device), + next_obs.to(device), + torch.tensor(done, dtype=torch.float32, device=device), + ) + + state = next_state if not done else env.reset() + + + print(f"[Warmup Complete] Collected {len(buffer.online_buffer)} online transitions.\n") + + +def get_epsilon(step, eps_start=0.3, eps_end=0.05, decay_steps=100_000): + """ + Linearly decays epsilon from eps_start → eps_end over decay_steps. + Used for epsilon-greedy exploration during online interaction. 
+ """ + if step >= decay_steps: + return eps_end + decay = (eps_start - eps_end) * (1 - step / decay_steps) + return eps_end + decay + + +class ReplayBuffer: + """Simple replay buffer mixing offline + online data.""" + def __init__(self, max_size=100000): + self.states, self.actions, self.rewards, self.next_states, self.dones = [], [], [], [], [] + self.max_size = max_size + + def add(self, s, a, r, ns, d): + if len(self.states) >= self.max_size: + # drop oldest + self.states.pop(0) + self.actions.pop(0) + self.rewards.pop(0) + self.next_states.pop(0) + self.dones.pop(0) + + self.states.append(s.squeeze(0)) + self.actions.append(a.squeeze(0)) + self.rewards.append(r.squeeze(0)) + self.next_states.append(ns.squeeze(0)) + self.dones.append(d) + + def sample(self, batch_size): + idxs = np.random.randint(0, len(self.states), size=batch_size) + return ( + torch.stack([self.states[i] for i in idxs]), + torch.stack([self.actions[i] for i in idxs]), + torch.stack([self.rewards[i] for i in idxs]), + torch.stack([self.next_states[i] for i in idxs]), + torch.stack([self.dones[i] for i in idxs]), + ) + + def __len__(self): + return len(self.states) + +class DualReplayBuffer: + """Manages separate buffers for offline and online data, + with controlled sampling ratio that decays over time.""" + def __init__(self, offline_buffer_max=100_000, online_buffer_max=100_000): + self.offline_buffer = ReplayBuffer(max_size=offline_buffer_max) + self.online_buffer = ReplayBuffer(max_size=online_buffer_max) + + def add_offline(self, s, a, r, ns, d): + self.offline_buffer.add(s, a, r, ns, d) + + def add_online(self, s, a, r, ns, d): + self.online_buffer.add(s, a, r, ns, d) + + def sample(self, batch_size, step, max_steps, online_start_ratio=0.8, online_end_ratio=0.2): + """Sample from both buffers with ratio decaying linearly over training.""" + # Compute current online ratio + progress = min(step / max_steps, 1.0) + online_ratio = online_start_ratio - (online_start_ratio - online_end_ratio) * progress + online_batch = int(batch_size * online_ratio) + offline_batch = batch_size - online_batch + + # Safety check: if one buffer is too small, compensate from the other + online_batch = min(online_batch, len(self.online_buffer)) + offline_batch = batch_size - online_batch + if len(self.offline_buffer) < offline_batch: + offline_batch = len(self.offline_buffer) + online_batch = batch_size - offline_batch + + # Sample + if online_batch > 0: + online_samples = self.online_buffer.sample(online_batch) + else: + online_samples = None + if offline_batch > 0: + offline_samples = self.offline_buffer.sample(offline_batch) + else: + offline_samples = None + + # Merge samples + def merge_tensors(t1, t2): + if t1 is None: return t2 + if t2 is None: return t1 + return torch.cat([t1, t2], dim=0) + + return tuple( + merge_tensors(o, f) + for o, f in zip( + online_samples if online_samples else (None,)*5, + offline_samples if offline_samples else (None,)*5 + ) + ) + + def __len__(self): + return len(self.offline_buffer) + len(self.online_buffer) + + +def hybrid_finetune(): + # === Load pretrained agent === + agent = IQLAgent(cfg, device, obs_parts=[OpFeatures, ActionHistory]) + ckpt_path = "./offline_iql_results_1/iql_step_97999.pt" + if os.path.exists(ckpt_path): + agent.load_state_dict(torch.load(ckpt_path, map_location=device)) + print(f"Loaded pretrained checkpoint: {ckpt_path}") + else: + raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}") + + # === Init Replay Buffer with offline data === + + buffer = 
DualReplayBuffer(offline_buffer_max=100_000, online_buffer_max=100_000) + + states, actions, rewards, next_states, dones = load_offline_dataset() + for s, a, r, ns, d in zip(states, actions, rewards, next_states, dones): + buffer.add_offline(s, a, r, ns, d) + print(f"Offline Replay buffer initialized with {len(buffer)} offline samples") + + # environments + train_env = Env(is_training=True, run_name=cfg.run_name) + eval_env = Env(is_training=False, run_name=cfg.run_name) + + print("Starting HYBRID fine-tuning (offline + online)...") + start_time = time.time() + state = train_env.reset() + + + # Warmup Phase (modular) + collect_warmup_data(agent, train_env, buffer, warmup_steps=2000) + + hybrid_trange = trange(cfg.max_steps, desc="Hybrid Fine-tuning", dynamic_ncols=True) + for step in hybrid_trange: + # reset benchmark + state = train_env.reset() + done = False + + while not done: + # current obs + obs = Observation.from_state(state) + + # agent picks action + eps = get_epsilon(step) + action_index, _, _ = agent.sample(obs.to(device), eps=eps) + action = ActionSpace.action_by_index(action_index[0], state) + + # env step + next_state, reward, op_done, _ = train_env.step(state, action) + + # build next_obs BEFORE advancing benchmark + next_obs = Observation.from_state(next_state) + + # if op finished, advance to next op or benchmark end + if op_done: + next_state, done = train_env.get_next_op_state(next_state) + + # push transition to replay buffer + buffer.add_offline( + obs.to(device), + action_index, + torch.tensor(reward, dtype=torch.float32, device=device), + next_obs.to(device), + torch.tensor(done, dtype=torch.float32, device=device), + ) + + # move forward + state = next_state + + # after benchmark, do 1 gradient update + batch = buffer.sample(cfg.batch_size,step, MAX_STEPS_OFFLINE_ONLINE_RATIO) + + for _ in range(UPDATE_ITERS): + losses = agent.update(batch) + + if (step + 1) % 100 == 0: + fl.log_scalars("hybrid_train", losses, step) + elapsed = time.time() - start_time + hybrid_trange.set_postfix({ + "Value Loss": f"{losses['value']:.4f}", + "Q Loss": f"{losses['q']:.4f}", + "Policy Loss": f"{losses['policy']:.4f}", + "Elapsed": f"{elapsed:.2f}s" + }) + + if (step + 1) % 5000 == 0: + print("Evaluating on benchmarks ...") + eval_start = time.time() + env_time = evaluate_benchmarks(agent, eval_env, step) + print(f"Evaluation done in {time.time() - eval_start:.2f}s (env time: {env_time:.2f}s)") + fl.flush() + + if (step + 1) % 5000 == 0: + os.makedirs(cfg.results_dir, exist_ok=True) + save_path = os.path.join(cfg.results_dir, f"iql_hybrid_step_{step}.pt") + torch.save(agent.state_dict(), save_path) + print(f"Checkpoint saved: {save_path}") + + state = next_state + + print(f"Hybrid fine-tuning finished in {time.time() - start_time:.2f} seconds.") + + +if __name__ == "__main__": + hybrid_finetune() diff --git a/train_ppo.py b/train_ppo.py new file mode 100644 index 0000000..017521b --- /dev/null +++ b/train_ppo.py @@ -0,0 +1,109 @@ +# Load environment variables +from dotenv import load_dotenv +load_dotenv(override=True) + + +import torch +import os +from typing import Optional +from utils.log import print_info, print_success + +# Import environment +from rl_autoschedular.env import Env + +# config, file_logger, device +from rl_autoschedular import config as cfg, file_logger as fl, device + +# Import RL components +from rl_autoschedular.model import HiearchyModel as Model +from rl_autoschedular.trajectory import TrajectoryData +from rl_autoschedular.ppo import ( + collect_trajectory, + 
ppo_update, + value_update, + evaluate_benchmarks +) + +import time +torch.set_grad_enabled(False) +torch.set_num_threads(int(os.getenv("OMP_NUM_THREADS", "4"))) + +if cfg.debug: + torch.autograd.set_detect_anomaly(True) + +print_info(f"Config: {cfg}") +print_success(f'Logging to: {fl.run_dir}') + +# Set environments + +# run_name for /tmp/ path +env = Env(is_training=True,run_name="online_ppo") +eval_env = Env(is_training=False,run_name="online_ppo") +print_success(f"Environments initialized: {env.tmp_file}") + +# Set model +model = Model().to(device) +optimizer = torch.optim.Adam( + model.parameters(), + lr=3e-4 +) +print_success("Model initialized") + +train_start = time.perf_counter() +total_env_time = 0.0 +total_eval_time = 0.0 + +# Start training +for step in range(cfg.nb_iterations): + print_info(f"- Main Loop {step + 1}/{cfg.nb_iterations} ({100 * (step + 1) / cfg.nb_iterations:.2f}%)") + trajectory , env_time = collect_trajectory( + model, + env, + step, + ) + total_env_time += env_time + + + + # Fit value model to trajectory rewards + if cfg.value_epochs > 0: + value_update( + trajectory, + model, + optimizer, + step + ) + + ppo_update( + trajectory, + model, + optimizer, + step + ) + + if (step + 1) % 50 == 0: + torch.save( + model.state_dict(), + os.path.join( + env.tmp_file.replace('.mlir', ''), + f'model_{step}.pth' + ) + ) + + if (step + 1) % 50 == 0: + start_eval = time.perf_counter() + print_info('- Evaluating benchmark -') + eval_time = evaluate_benchmarks( + model, + eval_env, + step + ) + end_eval = time.perf_counter() + total_eval_time += end_eval - start_eval +train_end = time.perf_counter() +total_train_time = train_end - train_start - total_eval_time +print_success(f"- Training completed in {total_train_time:.2f} seconds") +print_success(f"- Evaluation completed in {total_eval_time:.2f} seconds") +print_success(f"- Total environment time: {total_env_time:.2f} seconds") +print_success(f"- Total eval env time: {total_eval_time:.2f} seconds") +print_success(f"- Percentage of time in environment: {100 * total_env_time / total_train_time:.2f}%") \ No newline at end of file diff --git a/utils/config.py b/utils/config.py index dc7113b..0612d28 100644 --- a/utils/config.py +++ b/utils/config.py @@ -4,6 +4,8 @@ import json import os +from utils.keys import CONFIG_FILE_PATH + class Config(metaclass=Singleton): """Class to store and load global configuration""" @@ -22,6 +24,8 @@ class Config(metaclass=Singleton): """The order of actions that needs to bo followed""" interchange_mode: Literal['enumerate', 'pointers', 'continuous'] """The method used for interchange action""" + use_img2col: bool + """Whether to use img2col transformation on conv2d operations""" exploration: list[Literal['entropy', 'epsilon']] """The exploration method""" init_epsilon: float @@ -70,13 +74,17 @@ class Config(metaclass=Singleton): """Path to the file containing the execution data""" results_dir: str """Path to the results directory""" + save_model_every: int + """Number of iterations between saving the model""" + evaluate_every: int + """Number of iterations between evaluations""" def __init__(self): """Load the configuration from the JSON file or get existing instance if any. 
""" # Open the JSON file - with open(os.getenv("CONFIG_FILE_PATH"), "r") as f: + with open(CONFIG_FILE_PATH, "r") as f: config_data: dict[str, Any] = json.load(f) for element, element_t in self.__annotations__.items(): diff --git a/utils/dask_manager.py b/utils/dask_manager.py index 5de6a68..43825f0 100644 --- a/utils/dask_manager.py +++ b/utils/dask_manager.py @@ -11,6 +11,8 @@ from .log import print_alert, print_error, print_info, print_success import os +from utils.keys import CONDA_ENV + if TYPE_CHECKING: from rl_autoschedular.benchmarks import Benchmarks from dask_jobqueue.slurm import SLURMJob @@ -36,6 +38,7 @@ def __init__(self): memory='100GB', walltime='7-00', job_extra_directives=[ + '--reservation=c2', '--nodes=1', '--exclusive', ], @@ -44,7 +47,7 @@ def __init__(self): job_script_prologue=[ 'module load miniconda-nobashrc', 'eval "$(conda shell.bash hook)"', - f'conda activate {os.getenv("CONDA_ENV")}', + f'conda activate {CONDA_ENV}', 'export OMP_NUM_THREADS=12', ], scheduler_options={ diff --git a/utils/data_collector.py b/utils/data_collector.py new file mode 100644 index 0000000..d2ddc2c --- /dev/null +++ b/utils/data_collector.py @@ -0,0 +1,74 @@ +import os +import numpy as np +import torch +from utils.singleton import Singleton +from utils.log import print_success + +class OfflineDataset(metaclass=Singleton): + """Singleton class to collect and store trajectories for offline RL """ + + def __init__(self, save_dir: str = "offline_data", fname: str = "dataset.npz"): + """ + Args: + save_dir (str): Directory to store dataset. + fname (str): Dataset filename. + """ + self.save_dir = save_dir + self.fname = fname + os.makedirs(self.save_dir, exist_ok=True) + + self.buffer = [] # in-memory buffer for efficiency + self.file_path = os.path.join(self.save_dir, self.fname) + + def add_transition(self, obs, action, next_obs, reward, done): + """Add one transition to buffer.""" + self.buffer.append({ + "obs": obs.squeeze(0).cpu().numpy() if torch.is_tensor(obs) else obs, + "action": action.detach().cpu().numpy().squeeze(0) if torch.is_tensor(action) else np.array(action).squeeze(0), + "next_obs": next_obs.squeeze(0).cpu().numpy() if torch.is_tensor(next_obs) else next_obs, + "reward": float(reward), + "done": bool(done), + }) + + def add_trajectory(self, trajectory): + """Add a full trajectory (list of transitions).""" + self.buffer.extend(trajectory) + + def flush(self): + """Save buffer to disk as npz and clear it.""" + if not self.buffer: + return + + # Convert buffer to arrays + obs = np.array([t["obs"] for t in self.buffer], dtype=np.float32) + actions = np.array([t["action"] for t in self.buffer], dtype=np.int64) + next_obs = np.array([t["next_obs"] for t in self.buffer], dtype=np.float32) + rewards = np.array([t["reward"] for t in self.buffer], dtype=np.float32) + dones = np.array([t["done"] for t in self.buffer], dtype=np.bool_) + + if os.path.exists(self.file_path): + # If file exists, append to it + old = np.load(self.file_path) + obs = np.concatenate([old["obs"], obs], axis=0) + actions = np.concatenate([old["actions"], actions], axis=0) + next_obs = np.concatenate([old["next_obs"], next_obs], axis=0) + rewards = np.concatenate([old["rewards"], rewards], axis=0) + dones = np.concatenate([old["dones"], dones], axis=0) + + np.savez_compressed( + self.file_path, + obs=obs, + actions=actions, + next_obs=next_obs, + rewards=rewards, + dones=dones, + ) + + print_success(f"[OfflineDataset] Flushed {len(self.buffer)} transitions -> {self.file_path}") + self.buffer.clear() + + def 
load(self, mmap_mode=None): + """Load dataset from disk as dict of numpy arrays (optionally memory-mapped).""" + if not os.path.exists(self.file_path): + return {} + return np.load(self.file_path, mmap_mode=mmap_mode) \ No newline at end of file diff --git a/utils/data_utils.py b/utils/data_utils.py new file mode 100644 index 0000000..dcd6e7e --- /dev/null +++ b/utils/data_utils.py @@ -0,0 +1,359 @@ +from tqdm import tqdm +import sys +import os +import json + +import re +from typing import List + +from rl_autoschedular.state import extract_bench_features_from_file, extract_bench_features_from_code +from rl_autoschedular.transforms import transform_img2col + +def split_convolution_operations(img2col_conv_code: str) -> List[str]: + """ + Deterministic splitter that: + - preserves original SSA names when safe, + - detects how many #map lines belong to op0 by scanning op0 body, + - keeps tensor.collapse_shape lines in op1 (but avoids SSA redefinition by + mapping conflicting function-arg names to the collapse source), + - includes tensor.expand_shape (if present) and makes op1 return the expanded SSA/type, + - removes empty lines when parsing. + + Returns [mlir_op0_text, mlir_op1_text]. + """ + # remove empty lines immediately + lines = [ln for ln in img2col_conv_code.splitlines() if ln.strip()] + + # 1) collect contiguous maps + maps: List[str] = [] + idx = 0 + while idx < len(lines) and lines[idx].strip().startswith("#map"): + maps.append(lines[idx]) + idx += 1 + + # store map names in order for detection + map_names: List[str] = [] + for mline in maps: + mm = re.match(r'^\s*(#map\d*|\#map)\b', mline) + map_names.append(mm.group(1) if mm else "") + + # 2) find the linalg.generic lines that have the tags + op0_line_idx = next(i for i, ln in enumerate(lines) if 'tag = "operation_0"' in ln) + op1_line_idx = next(i for i, ln in enumerate(lines) if 'tag = "operation_1"' in ln) + + # 3) helper: extract balanced brace block starting idx + def extract_block_from(start_idx: int) -> List[str]: + depth = 0 + block = [] + for ln in lines[start_idx:]: + depth += ln.count("{") + depth -= ln.count("}") + block.append(ln) + if depth == 0: + break + return block + + op0_block = extract_block_from(op0_line_idx) + op1_block = extract_block_from(op1_line_idx) + + op0_text = "\n".join(op0_block) + op1_text = "\n".join(op1_block) + + # 4) determine which maps are referenced by op0 (dynamic) + referenced_maps_in_op0 = set(re.findall(r"#map\d*", op0_text)) + if "#map" in op0_text: + referenced_maps_in_op0.add("#map") + + last_ref_idx = -1 + for i, name in enumerate(map_names): + if name and name in referenced_maps_in_op0: + last_ref_idx = i + + if last_ref_idx >= 0: + maps_op0 = maps[: last_ref_idx + 1] + maps_op1 = maps[last_ref_idx + 1 :] + else: + fallback = min(8, len(maps)) + maps_op0 = maps[:fallback] + maps_op1 = maps[fallback:] + + # 5) find any tensor.collapse_shape lines (they appear before op0/op1) + # Parse them into collapse_infos so we can reason about 'result' and 'src'. 
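+    #    A hypothetical example of a line this step matches (SSA names and shapes are illustrative only):
+    #      %collapsed = tensor.collapse_shape %arg1 [[0], [1, 2, 3]] : tensor<256x32x1x1xf32> into tensor<256x32xf32>
+    #    which is parsed into result=%collapsed, src=%arg1,
+    #    src_type=tensor<256x32x1x1xf32>, out_type=tensor<256x32xf32>.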
+ collapse_infos: List[dict] = [] + for j in range(idx, op1_line_idx): + ln = lines[j] + if "tensor.collapse_shape" in ln: + # full parse (preferred) + m = re.match( + r'\s*(%[A-Za-z0-9_]+)\s*=\s*tensor\.collapse_shape\s+(%[A-Za-z0-9_]+).*:\s*(tensor<[^>]+>)\s+into\s+(tensor<[^>]+>)', + ln) + if m: + collapse_infos.append({ + 'result': m.group(1), + 'src': m.group(2), + 'src_type': m.group(3), + 'out_type': m.group(4), + 'line': ln.rstrip() + }) + else: + # looser parse: try only names + mm = re.match(r'\s*(%[A-Za-z0-9_]+)\s*=\s*tensor\.collapse_shape\s+(%[A-Za-z0-9_]+).*', ln) + if mm: + collapse_infos.append({ + 'result': mm.group(1), + 'src': mm.group(2), + 'src_type': None, + 'out_type': None, + 'line': ln.rstrip() + }) + + # 6) find any tensor.expand_shape lines after op1 that consume op1 result + expand_info = None + + def find_linalg_result_ssa(block_text: str) -> str: + for ln in block_text.splitlines(): + m = re.match(r"\s*(%[0-9]+)\s*=\s*linalg\.generic", ln) + if m: + return m.group(1) + return None + + linalg_ssa = find_linalg_result_ssa(op1_text) + for j in range(op1_line_idx + len(op1_block), len(lines)): + ln = lines[j] + if "tensor.expand_shape" in ln and linalg_ssa and linalg_ssa in ln: + m = re.match( + r'\s*(%[A-Za-z0-9_]+)\s*=\s*tensor\.expand_shape\s+(' + re.escape(linalg_ssa) + + r').*:\s*(tensor<[^>]+>)\s+into\s+(tensor<[^>]+>)', ln) + if m: + expand_info = { + 'result': m.group(1), + 'src': m.group(2), + 'src_type': m.group(3), + 'out_type': m.group(4), + 'line': ln.rstrip() + } + else: + mm = re.match(r'\s*(%[A-Za-z0-9_]+)\s*=\s*tensor\.expand_shape\s+([%A-Za-z0-9_]+).*', ln) + if mm: + expand_info = {'result': mm.group(1), 'line': ln.rstrip()} + break + + # 7) parse original function args (preserve order and names) + func_line = next((ln for ln in lines if "func.func @main" in ln), "") + arg_match = re.search(r"@main\((.*?)\)\s*->", func_line) + original_arg_list: List[str] = [] + if arg_match: + original_args = arg_match.group(1).strip() + original_arg_list = [a.strip() for a in original_args.split(",") if a.strip()] + + # 8) Build op0 argument list (prefer original %arg0) + op0_args = [a for a in original_arg_list if a.split(":")[0].strip() == "%arg0"] + if not op0_args: + m_tex = re.search(r"tensor<[^>]+>", op0_text) + img_type = m_tex.group(0) if m_tex else "tensor" + op0_args = [f"%arg0: {img_type}"] + + # op0 body: include any tensor.empty() lines nearest before op0 (keeps naming) + tensor_empty_line = None + for j in range(op0_line_idx - 1, -1, -1): + if "tensor.empty" in lines[j]: + tensor_empty_line = lines[j].rstrip() + break + op0_setup_lines = [tensor_empty_line] if tensor_empty_line else [] + op0_text_block = "\n".join(op0_setup_lines + op0_block) + + # 9) Build op1 argument list carefully while avoiding SSA redefinition: + # - if original args include a collapse-result name, *replace* that arg with + # the collapse *source* (preserve source type when available). This allows us + # to keep the collapse line in op1 and avoid redefinition while keeping naming traceable. + # - preserve original arg order as much as possible. + # parse op1 ins(...) 
line to learn im2col inputs order and types + op1_ins_line = next((ln for ln in op1_block if "ins(" in ln), "") + in_names: List[str] = [] + in_types: List[str] = [] + m_ins = re.search(r"ins\((.*?)\)", op1_ins_line) + if m_ins: + inside = m_ins.group(1) + if ":" in inside: + names_part, types_part = inside.split(":", 1) + in_names = re.findall(r"%[A-Za-z0-9_]+", names_part) + in_types = re.findall(r"tensor<[^>]+>", types_part) + else: + in_names = re.findall(r"%[A-Za-z0-9_]+", inside) + + # Map collapse result -> src for quick lookup + collapse_result_to_src = {ci['result']: ci for ci in collapse_infos} + + # Build op1_arg_list starting from original_arg_list but remap any arg that equals a collapse result + op1_arg_list: List[str] = [] + seen_args = set() + for orig in original_arg_list: + name = orig.split(":")[0].strip() + # If this original argument *is* a collapse result, replace it with the collapse source + if name in collapse_result_to_src: + ci = collapse_result_to_src[name] + src_name = ci['src'] + # try to preserve original arg type for the src if it's in original_arg_list + src_orig_entry = next((a for a in original_arg_list if a.split(":")[0].strip() == src_name), None) + if src_orig_entry: + entry = src_orig_entry + else: + # fallback to parsed src_type if present + if ci.get('src_type'): + entry = f"{src_name}: {ci['src_type']}" + else: + entry = f"{src_name}: tensor" + # append only if not already added + if entry not in op1_arg_list: + op1_arg_list.append(entry) + seen_args.add(entry.split(":")[0].strip()) + else: + # normal case: keep the original arg if it's relevant (we'll filter later) + if orig not in op1_arg_list: + op1_arg_list.append(orig) + seen_args.add(name) + + # Now ensure the im2col inputs from ins(...) are present in op1_arg_list: + for idx_name, nm in enumerate(in_names): + # if nm is already in op1_arg_list (by SSA name), skip + if any(nm == (a.split(":")[0].strip()) for a in op1_arg_list): + continue + t = in_types[idx_name] if idx_name < len(in_types) else "tensor" + op1_arg_list.append(f"{nm}: {t}") + + # If we still ended up empty (very unlikely), fallback to original args containing %arg* + if not op1_arg_list: + op1_arg_list = [a for a in original_arg_list if a.startswith("%arg")] + + # Deduplicate preserving order + final_op1_arg_list: List[str] = [] + seen = set() + for a in op1_arg_list: + name = a.split(":")[0].strip() + if name not in seen: + final_op1_arg_list.append(a) + seen.add(name) + op1_arg_list = final_op1_arg_list + + # 10) Build op1 body text: include collapse lines first (preserve exact text), + # then the op1 linalg block, then the expand line (if any). + # Because we replaced collapse-arg names with their sources in signature above, + # including collapse lines verbatim is safe (no redefinition). + op1_body_parts = [] + for ci in collapse_infos: + # If, for some reason, collapse result name is still present as an argument name, + # we must avoid redefining the arg. But our remap step above should prevent that. + res_name = ci['result'] + arg_names_set = {a.split(":")[0].strip() for a in op1_arg_list} + if res_name in arg_names_set: + # As a safety net: rename the collapse result inside the line to a fresh SSA and + # update any subsequent references. (This should rarely occur due to remapping.) 
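+            # (Hypothetical example: an argument named %collapsed would be rewritten to
+            #  %collapsed_from_collapse here, and all later uses of it updated to match.)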
+ base = res_name.lstrip("%") + new_name = f"%{base}_from_collapse" + collapse_line = ci['line'].replace(res_name, new_name, 1) + op1_body_parts.append(collapse_line) + # also replace uses of res_name inside op1_text with new_name + op1_text = re.sub(rf'(? str: + joined = "\n".join(block_lines) + m = re.search(r"->\s*(tensor<[^>]+>)", joined) + return m.group(1) if m else "tensor" + + if expand_info: + op1_result_ssa = expand_info['result'] + op1_result_type = expand_info.get('out_type') or extract_linalg_result_type(op1_block) + else: + op1_result_ssa = find_linalg_result_ssa(op1_text) or "%result" + op1_result_type = extract_linalg_result_type(op1_block) + + # for op0, extract its result SSA and type + op0_result_ssa = find_linalg_result_ssa(op0_text) or "%result" + op0_result_type = extract_linalg_result_type(op0_block) + + # 12) Build final MLIR module texts (op0 and op1) preserving naming + def build_module(maps: List[str], arg_list: List[str], result_type: str, body_text: str, result_ssa: str) -> str: + arg_sig = "(" + ", ".join(arg_list) + ")" + ret_sig = "(" + result_type + ", i64)" + return ( + "\n".join(maps) + "\n\n" + "module {\n" + " func.func private @nanoTime() -> i64 attributes {llvm.emit_c_interface}\n\n" + f" func.func @main{arg_sig}\n" + f" -> {ret_sig} attributes {{llvm.emit_c_interface}} {{\n\n" + " %0 = call @nanoTime() : () -> i64\n" + f"{body_text}\n" + " %4 = call @nanoTime() : () -> i64\n" + " %5 = arith.subi %4, %0 : i64\n\n" + f" return {result_ssa}, %5 : {result_type}, i64\n" + " }\n" + "}\n" + ) + + mlir_op0 = build_module(maps_op0, op0_args, op0_result_type, op0_text_block, op0_result_ssa) + mlir_op1 = build_module(maps_op1, op1_arg_list, op1_result_type, op1_body_text, op1_result_ssa) + + return [mlir_op0, mlir_op1] + +if __name__ == "__main__": + if len(sys.argv) != 2: + print("Usage: python utils/data_utils.py ") + sys.exit(1) + path_to_folder = sys.argv[1] + if not os.path.isdir(path_to_folder): + print(f"Error: {path_to_folder} is not a valid directory.") + sys.exit(1) + + with open(f"{path_to_folder}/../benchmarks_split.json", 'r') as f: + benchmarks_split = json.load(f) + + code_files = [f for f in os.listdir(path_to_folder) if f.endswith('.mlir')] + files_tqdm = tqdm(code_files, unit='file') + for code_file in files_tqdm: + bench_name = code_file.replace('.mlir', '') + files_tqdm.set_postfix_str(bench_name) + full_path = os.path.join(path_to_folder, code_file) + with open(full_path, 'r') as f: + bench_features = extract_bench_features_from_file(bench_name, full_path, 0) + # print(f"EXTRACT FEATUERES FROM FILE: \n", bench_features.code) + code_i2c = transform_img2col(bench_features.code, "operation_0") + # print(f"I2C Code: \n", code_i2c) + bench_features_i2c = extract_bench_features_from_code(bench_name, code_i2c, 0) + # print(f"EXTRACT FEATUERES FROM CODE: \n", bench_features_i2c.code) + with open(f"{path_to_folder}/{bench_name}_i2c.mlir", 'w') as f: + f.write(bench_features_i2c.code) + + try: + op0, op1 = split_convolution_operations(bench_features_i2c.code) + + with open(f"{path_to_folder}/{bench_name}_op0.mlir", 'w') as f: + f.write(op0) + with open(f"{path_to_folder}/{bench_name}_op1.mlir", 'w') as f: + f.write(op1) + + if bench_name in benchmarks_split['train']: + benchmarks_split['train'].append(f"{bench_name}_i2c") + benchmarks_split['train'].append(f"{bench_name}_op0") + benchmarks_split['train'].append(f"{bench_name}_op1") + else: + benchmarks_split['eval'].append(f"{bench_name}_i2c") + benchmarks_split['eval'].append(f"{bench_name}_op0") + 
benchmarks_split['eval'].append(f"{bench_name}_op1") + + except Exception as e: + print(f"Failed to split convolution code `{bench_name}`: {e}") + + with open(f"{path_to_folder}/../benchmarks_split.json", 'w') as f: + json.dump(benchmarks_split, f, indent=4) + \ No newline at end of file diff --git a/utils/keys.py b/utils/keys.py new file mode 100644 index 0000000..43980ce --- /dev/null +++ b/utils/keys.py @@ -0,0 +1,61 @@ +import os + +from dotenv import load_dotenv + +load_dotenv() + +VERBOSE = False + +CONDA_ENV = os.getenv("CONDA_ENV") +if CONDA_ENV is None: + raise ValueError("CONDA_ENV environment variable is not set.") +elif VERBOSE: + print(f"Using conda environment: {CONDA_ENV}") + +NEPTUNE_PROJECT = os.getenv("NEPTUNE_PROJECT") +if NEPTUNE_PROJECT is None: + raise ValueError("NEPTUNE_PROJECT environment variable is not set.") +elif VERBOSE: + print(f"Using Neptune project: {NEPTUNE_PROJECT}") + +NEPTUNE_TOKEN = os.getenv("NEPTUNE_TOKEN") +if NEPTUNE_TOKEN is None: + raise ValueError("NEPTUNE_TOKEN environment variable is not set.") +elif VERBOSE: + print("Neptune token is set.") + +LLVM_BUILD_PATH = os.getenv("LLVM_BUILD_PATH") +if LLVM_BUILD_PATH is None: + raise ValueError("LLVM_BUILD_PATH environment variable is not set.") +elif VERBOSE: + print(f"Using LLVM build path: {LLVM_BUILD_PATH}") + +MLIR_SHARED_LIBS = os.getenv("MLIR_SHARED_LIBS") +if MLIR_SHARED_LIBS is None: + raise ValueError("MLIR_SHARED_LIBS environment variable is not set.") +elif VERBOSE: + print(f"Using MLIR shared libs: {MLIR_SHARED_LIBS}") + +AST_DUMPER_BIN_PATH = os.getenv("AST_DUMPER_BIN_PATH") +if AST_DUMPER_BIN_PATH is None: + raise ValueError("AST_DUMPER_BIN_PATH environment variable is not set.") +elif VERBOSE: + print(f"Using AST dumper bin path: {AST_DUMPER_BIN_PATH}") + +VECTORIZER_BIN_PATH = os.getenv("VECTORIZER_BIN_PATH") +if VECTORIZER_BIN_PATH is None: + raise ValueError("VECTORIZER_BIN_PATH environment variable is not set.") +elif VERBOSE: + print(f"Using vectorizer bin path: {VECTORIZER_BIN_PATH}") + +PRE_VEC_BIN_PATH = os.getenv("PRE_VEC_BIN_PATH") +if PRE_VEC_BIN_PATH is None: + raise ValueError("PRE_VEC_BIN_PATH environment variable is not set.") +elif VERBOSE: + print(f"Using pre-vectorizer bin path: {PRE_VEC_BIN_PATH}") + +CONFIG_FILE_PATH = os.getenv("CONFIG_FILE_PATH") +if CONFIG_FILE_PATH is None: + raise ValueError("CONFIG_FILE_PATH environment variable is not set.") +elif VERBOSE: + print(f"Using config file path: {CONFIG_FILE_PATH}") \ No newline at end of file From 9bedd55f49de2aee344f44b8fb0ba8d380c9b9a3 Mon Sep 17 00:00:00 2001 From: Mohammed Tirichine <48598170+mohph197@users.noreply.github.com> Date: Sun, 21 Dec 2025 13:48:48 +0100 Subject: [PATCH 2/3] Delete scripts/neptune-sync.sh --- scripts/neptune-sync.sh | 26 -------------------------- 1 file changed, 26 deletions(-) delete mode 100644 scripts/neptune-sync.sh diff --git a/scripts/neptune-sync.sh b/scripts/neptune-sync.sh deleted file mode 100644 index 30de37f..0000000 --- a/scripts/neptune-sync.sh +++ /dev/null @@ -1,26 +0,0 @@ -#!/bin/bash - -# Define the resource requirements here using #SBATCH - -# SBATCH -j neptune_sync -#SBATCH -p compute -#SBATCH --reservation=c2 -#SBATCH --nodes=1 -#SBATCH -c 4 -#SBATCH --mem=16G -#SBATCH -t 07-00 -#SBATCH -o logs/neptune/%j.out -#SBATCH --mail-type=ALL -#SBATCH --mail-user=kb5213@nyu.edu - -# Resource requirements end here - -# Add the lines for running your code/application -module load miniconda-nobashrc -eval "$(conda shell.bash hook)" - -# Activate any environments if 
required -conda activate llvm-build - -# Execute the code -python $SCRATCH/MLIR-RL/neptune_sync.py From 4df0bd6d4890b48a16d78fe93365441699bd14ae Mon Sep 17 00:00:00 2001 From: Mohammed Tirichine <48598170+mohph197@users.noreply.github.com> Date: Sun, 21 Dec 2025 13:50:06 +0100 Subject: [PATCH 3/3] Delete tests/cuda.py --- tests/cuda.py | 48 ------------------------------------------------ 1 file changed, 48 deletions(-) delete mode 100644 tests/cuda.py diff --git a/tests/cuda.py b/tests/cuda.py deleted file mode 100644 index f4afc62..0000000 --- a/tests/cuda.py +++ /dev/null @@ -1,48 +0,0 @@ -import torch -import platform -import subprocess - -def check_cuda(): - # Check if CUDA is available - cuda_available = torch.cuda.is_available() - print(f"CUDA Available: {cuda_available}") - - if cuda_available: - # Number of GPUs available - num_gpus = torch.cuda.device_count() - print(f"Number of GPUs: {num_gpus}") - - # CUDA version from torch - print(f"PyTorch CUDA Version: {torch.version.cuda}") - - # Information about each GPU - for i in range(num_gpus): - print(f"\nDevice {i}:") - print(f" Name: {torch.cuda.get_device_name(i)}") - print(f" Total Memory: {torch.cuda.get_device_properties(i).total_memory / (1024 ** 3):.2f} GB") - print(f" Memory Allocated: {torch.cuda.memory_allocated(i) / (1024 ** 3):.2f} GB") - print(f" Memory Cached: {torch.cuda.memory_reserved(i) / (1024 ** 3):.2f} GB") - print(f" Compute Capability: {torch.cuda.get_device_capability(i)}") - print(f" Device Properties: {torch.cuda.get_device_properties(i)}") - - # Driver version - print(f"Driver Version: {torch.version.cuda}") - - # Run nvidia-smi command to get additional GPU details - try: - nvidia_smi_output = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, text=True) - print(f"\nNVIDIA-SMI Output:\n{nvidia_smi_output.stdout}") - except FileNotFoundError: - print("nvidia-smi command not found. Ensure NVIDIA drivers are installed correctly.") - - else: - print("No CUDA-enabled GPU detected.") - - # Print system and Python environment information - print("\nSystem and Environment Information:") - print(f" Operating System: {platform.system()} {platform.release()} ({platform.version()})") - print(f" Python Version: {platform.python_version()}") - print(f" PyTorch Version: {torch.__version__}") - -if __name__ == "__main__": - check_cuda() \ No newline at end of file
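Note on the offline data path: the hybrid loop in train_iql_online.py fills the DualReplayBuffer from load_offline_dataset(), whose definition lies outside the hunks shown here. Below is a minimal sketch, assuming the loader simply reads back the arrays written by OfflineDataset.flush() in utils/data_collector.py; the function name, default path, and tensor dtypes are assumptions, only the array keys (obs, actions, next_obs, rewards, dones) come from flush(), and the actual loader may instead rebuild Observation objects before feeding them to the agent.

    import numpy as np
    import torch

    def load_offline_dataset(path="offline_data/dataset.npz", device="cpu"):
        """Sketch: read back the arrays saved by OfflineDataset.flush() (assumed API)."""
        data = np.load(path)
        to_t = lambda arr, dt: torch.as_tensor(arr, dtype=dt, device=device)
        return (
            to_t(data["obs"], torch.float32),       # states
            to_t(data["actions"], torch.int64),     # action indices
            to_t(data["rewards"], torch.float32),   # rewards
            to_t(data["next_obs"], torch.float32),  # next states
            to_t(data["dones"], torch.float32),     # episode-termination flags
        )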