From fdb09fef19ace39d259df895ce442f549f816fb7 Mon Sep 17 00:00:00 2001 From: Oucherif Ouail Date: Sat, 11 Oct 2025 12:54:39 +0100 Subject: [PATCH 1/5] added offline data collector ,iql offline pretraining implementation, iql online finetuning --- .gitignore | 15 + README.md | 178 +-- config/.gitignore | 8 +- config/example.json | 37 - create_dataset.py | 48 + dask-logs/.gitignore | 2 - data.py | 71 ++ data/all/.gitignore | 2 - data/debug/.gitignore | 2 - data/features/.gitignore | 2 - data/lqcd/bench/.gitignore | 4 +- data/lqcd/control/execution_times.json | 4 +- data/lqcd/control/test_AB_1.mlir | 90 +- data/lqcd/control/test_AB_1.mlir.npy | Bin data/lqcd/control/test_AB_1.mlir.npz | Bin data/multi/.gitignore | 2 - data/nn-old/.gitignore | 2 - data/nn/.gitignore | 3 - data/nn/gen/.gitignore | 3 - data/nn/gen/data_generation_random.py | 1100 ----------------- data/nn/gen/example.yaml | 63 - data/polybench/.gitignore | 3 - data/polybench/gen/.gitignore | 1 - data/polybench/gen/2mm.mlir.bench | 45 - data/polybench/gen/2mm_gen.py | 120 -- data/polybench/gen/3mm.mlir.bench | 56 - data/polybench/gen/3mm_gen.py | 121 -- data/polybench/gen/fdtd.mlir.bench | 68 - data/polybench/gen/fdtd_gen.py | 129 -- data/polybench/gen/floyd.mlir.bench | 17 - data/polybench/gen/floyd_gen.py | 114 -- data/polybench/gen/gemm.mlir.bench | 27 - data/polybench/gen/gemm_gen.py | 120 -- data/polybench/gen/jacobi.mlir.bench | 45 - data/polybench/gen/jacobi_gen.py | 122 -- data/polybench/gen/seidel.mlir.bench | 58 - data/polybench/gen/seidel_gen.py | 124 -- demo.ipynb | 38 +- demo.py | 66 +- eval.py | 47 + evaluate.py | 134 +- filelog_clean.py | 46 +- fill_db.py | 166 +-- gen.py | 686 +++++----- init_env.py | 42 + iql/__init__.py | 0 iql/iql_agent.py | 286 +++++ iql/iql_agent_device.py | 282 +++++ iql/iql_config.py | 159 +++ iql/policy.py | 84 ++ iql/q_functions.py | 282 +++++ iql/singleton.py | 8 + iql/value_function.py | 71 ++ iql_online.py | 230 ++++ models/.gitignore | 6 +- neptune_sync.py | 144 +-- requirements.txt | 12 +- results/.gitignore | 2 - rl_autoschedular/__init__.py | 29 +- rl_autoschedular/actions/__init__.py | 486 ++++---- rl_autoschedular/actions/base.py | 471 ++++--- rl_autoschedular/actions/interchange.py | 516 ++++---- rl_autoschedular/actions/no_transformation.py | 35 +- rl_autoschedular/actions/tiled_fusion.py | 54 +- .../actions/tiled_parallelization.py | 62 +- rl_autoschedular/actions/tiling.py | 258 ++-- rl_autoschedular/actions/vectorization.py | 171 +-- rl_autoschedular/benchmarks.py | 54 - rl_autoschedular/env.py | 679 +++++----- rl_autoschedular/evaluation.py | 348 ++++++ rl_autoschedular/execution.py | 2 +- rl_autoschedular/model.py | 442 ++++--- rl_autoschedular/observation.py | 472 ++++--- rl_autoschedular/ppo.py | 752 +++++------ rl_autoschedular/state.py | 629 +++++----- rl_autoschedular/trajectory.py | 590 +++------ rl_autoschedular/transforms.py | 1021 ++++++++------- scripts/.gitignore | 6 - scripts/neptune-sync.sh | 23 - scripts/train_example.sh | 28 - setup.sh | 104 ++ test.py | 11 + tests/__init__.py | 0 tests/benchmarks/matmul.mlir | 8 + tests/inference.py | 75 ++ tests/test_action.py | 132 ++ tests/test_execution.py | 63 + tests/test_model.py | 69 ++ tests/test_state.py | 101 ++ tests/tmp.mlir | 15 + tests/transform_test.py | 101 ++ tmp-debug/.gitignore | 5 + {logs => tmp-debug/exec}/.gitignore | 6 +- train.py | 214 ++-- train_iql.py | 209 ++++ utils/config.py | 387 ++++-- utils/dask_manager.py | 158 --- utils/data_collector.py | 74 ++ utils/file_logger.py | 105 +- utils/log.py | 69 +- utils/singleton.py | 16 +- viz.py | 24 + viz_online.py | 62 + 103 files changed, 7586 insertions(+), 7147 deletions(-) mode change 100644 => 100755 README.md mode change 100644 => 100755 config/.gitignore delete mode 100644 config/example.json create mode 100755 create_dataset.py delete mode 100644 dask-logs/.gitignore create mode 100644 data.py delete mode 100644 data/all/.gitignore delete mode 100644 data/debug/.gitignore delete mode 100644 data/features/.gitignore mode change 100644 => 100755 data/lqcd/bench/.gitignore mode change 100644 => 100755 data/lqcd/control/execution_times.json mode change 100644 => 100755 data/lqcd/control/test_AB_1.mlir mode change 100644 => 100755 data/lqcd/control/test_AB_1.mlir.npy mode change 100644 => 100755 data/lqcd/control/test_AB_1.mlir.npz delete mode 100644 data/multi/.gitignore delete mode 100644 data/nn-old/.gitignore delete mode 100644 data/nn/.gitignore delete mode 100644 data/nn/gen/.gitignore delete mode 100644 data/nn/gen/data_generation_random.py delete mode 100644 data/nn/gen/example.yaml delete mode 100644 data/polybench/.gitignore delete mode 100644 data/polybench/gen/.gitignore delete mode 100644 data/polybench/gen/2mm.mlir.bench delete mode 100644 data/polybench/gen/2mm_gen.py delete mode 100644 data/polybench/gen/3mm.mlir.bench delete mode 100644 data/polybench/gen/3mm_gen.py delete mode 100644 data/polybench/gen/fdtd.mlir.bench delete mode 100644 data/polybench/gen/fdtd_gen.py delete mode 100644 data/polybench/gen/floyd.mlir.bench delete mode 100644 data/polybench/gen/floyd_gen.py delete mode 100644 data/polybench/gen/gemm.mlir.bench delete mode 100644 data/polybench/gen/gemm_gen.py delete mode 100644 data/polybench/gen/jacobi.mlir.bench delete mode 100644 data/polybench/gen/jacobi_gen.py delete mode 100644 data/polybench/gen/seidel.mlir.bench delete mode 100644 data/polybench/gen/seidel_gen.py mode change 100644 => 100755 demo.ipynb mode change 100644 => 100755 demo.py create mode 100644 eval.py mode change 100644 => 100755 evaluate.py mode change 100644 => 100755 filelog_clean.py mode change 100644 => 100755 fill_db.py mode change 100644 => 100755 gen.py create mode 100644 init_env.py create mode 100644 iql/__init__.py create mode 100755 iql/iql_agent.py create mode 100644 iql/iql_agent_device.py create mode 100755 iql/iql_config.py create mode 100755 iql/policy.py create mode 100755 iql/q_functions.py create mode 100755 iql/singleton.py create mode 100755 iql/value_function.py create mode 100644 iql_online.py mode change 100644 => 100755 models/.gitignore mode change 100644 => 100755 neptune_sync.py mode change 100644 => 100755 requirements.txt delete mode 100644 results/.gitignore mode change 100644 => 100755 rl_autoschedular/__init__.py mode change 100644 => 100755 rl_autoschedular/actions/__init__.py mode change 100644 => 100755 rl_autoschedular/actions/base.py mode change 100644 => 100755 rl_autoschedular/actions/interchange.py mode change 100644 => 100755 rl_autoschedular/actions/no_transformation.py mode change 100644 => 100755 rl_autoschedular/actions/tiled_fusion.py mode change 100644 => 100755 rl_autoschedular/actions/tiled_parallelization.py mode change 100644 => 100755 rl_autoschedular/actions/tiling.py mode change 100644 => 100755 rl_autoschedular/actions/vectorization.py delete mode 100644 rl_autoschedular/benchmarks.py mode change 100644 => 100755 rl_autoschedular/env.py create mode 100755 rl_autoschedular/evaluation.py mode change 100644 => 100755 rl_autoschedular/model.py mode change 100644 => 100755 rl_autoschedular/observation.py mode change 100644 => 100755 rl_autoschedular/ppo.py mode change 100644 => 100755 rl_autoschedular/state.py mode change 100644 => 100755 rl_autoschedular/trajectory.py mode change 100644 => 100755 rl_autoschedular/transforms.py delete mode 100644 scripts/.gitignore delete mode 100644 scripts/neptune-sync.sh delete mode 100644 scripts/train_example.sh create mode 100755 setup.sh create mode 100644 test.py create mode 100644 tests/__init__.py create mode 100644 tests/benchmarks/matmul.mlir create mode 100644 tests/inference.py create mode 100644 tests/test_action.py create mode 100644 tests/test_execution.py create mode 100644 tests/test_model.py create mode 100644 tests/test_state.py create mode 100644 tests/tmp.mlir create mode 100644 tests/transform_test.py create mode 100755 tmp-debug/.gitignore rename {logs => tmp-debug/exec}/.gitignore (95%) mode change 100644 => 100755 create mode 100644 train_iql.py mode change 100644 => 100755 utils/config.py delete mode 100644 utils/dask_manager.py create mode 100644 utils/data_collector.py mode change 100644 => 100755 utils/file_logger.py mode change 100644 => 100755 utils/log.py mode change 100644 => 100755 utils/singleton.py create mode 100644 viz.py create mode 100644 viz_online.py diff --git a/.gitignore b/.gitignore index c206e8f..3566c9e 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,18 @@ .neptune *__pycache__ tools/*/build +.environment +*.csv +*.png +mlir-venv/ +iql_results/ +llvm-project/ +*.json +*.log +checkpoints/ +cache/ +data/ +results/ +offline_iql_adv_norm_gradclip_cosine_scheduler/ +offline_dataset/ +tmp/* \ No newline at end of file diff --git a/README.md b/README.md old mode 100644 new mode 100755 index 701f1a1..310321e --- a/README.md +++ b/README.md @@ -1,90 +1,90 @@ -## Getting Started -This is an example of how you may give instructions on setting up your project locally. -To get a local copy up and running follow these simple example steps. -### Prerequisites: -###### Required -1) [CMake](https://cmake.org/): version 3.20 or greater. -2) [Ninja](https://ninja-build.org/). -3) [Gcc](https://gcc.gnu.org/) : version 13.2. -4) [Gxx]: version 13.2. -5) [LLD](https://lld.llvm.org/). -6) [Python](https://www.python.org/downloads/): version 3.11 or greater. -### Setup -#### 1. Building MLIR : -```sh -git clone --depth 1 -b release/19.x https://github.com/llvm/llvm-project.git -mkdir llvm-project/build -cd llvm-project/build -cmake -S llvm -B build -G Ninja -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_PROJECTS="mlir;clang;openmp" \ --DLLVM_BUILD_EXAMPLES=ON -DLLVM_TARGETS_TO_BUILD=X86 -DLLVM_ENABLE_ASSERTIONS=ON \ --DCMAKE_C_COMPILER=gcc -DCMAKE_CXX_COMPILER=g++ -DLLVM_ENABLE_LLD=ON -DMLIR_ENABLE_BINDINGS_PYTHON=ON - -cmake --build . --target check-mlir -``` -#### 2. Install python requirements : -```sh -pip install -r requirements.txt -``` -#### 3. Setup environment variables : -Change llvm related variables according to your llvm-project folder path. -```env -NEPTUNE_PROJECT= -NEPTUNE_TOKEN= -LLVM_BUILD_PATH=llvm-project/build -MLIR_SHARED_LIBS=llvm-project/build/lib/libomp.so,llvm-project/build/lib/libmlir_c_runner_utils.so,llvm-project/build/lib/libmlir_runner_utils.so -AST_DUMPER_BIN_PATH=tools/ast_dumper/build/bin/AstDumper -VECTORIZER_BIN_PATH=tools/vectorizer/build/bin/Vectorizer -``` -### Documentation -#### 1. Jobs -For running jobs using slurm script examples are provided in the `scripts/` folder. -#### 2. Configuration -Configuring the model on a specific case can be done by setting a JSON config file containing all required settings. Configuration JSON file examples are provided in the `config/` folder. -The following JSON content is an example of a config file: -```json -{ - "max_num_stores_loads": 7, - "max_num_loops": 7, - "max_num_load_store_dim": 7, - "num_tile_sizes": 7, - "num_transformations": 6, - "vect_size_limit": 2048, - "use_bindings": false, - "use_vectorizer": false, - "data_format": "json", - "optimization_mode": "last", - "benchmarks_folder_path": "", - "len_trajectory": 64, - "ppo_batch_size": 64, - "nb_iterations": 10000, - "ppo_epochs": 4, - "entropy_coef": 0.01, - "lr": 0.001, - "truncate": 5, - "json_file": "data/nn/train_operations.json", - "tags": ["nn"], - "logging": true -} -``` -The following list describes every required setting in a configuration file. -- `max_num_stores_loads (int)`: The maximum number of loads in the nested loops. -- `max_num_loops (int)`: The max number of nested loops. -- `max_num_load_store_dim (int)`: The max number of dimensions in load/store buffers. -- `num_tile_sizes (int)`: The number of possible tile sizes for a loop. -- `num_transformations (int)`: The number of transformations. -- `vect_size_limit (int)`: Vectorization size limit to prevent large sizes vectorization. -- `use_bindings (bool)`: Flag to enable using python bindings for execution, if False, the execution will be done using the command line. Default is False. -- `use_vectorizer (bool)`: Flag to enable using the vectorizer C++ program for vectorization, if False, vectorization is done using transform dialect directly. Default is False. -- `data_format (Literal["json", "mlir"])`: The format of the data, can be either "json" or "mlir". "json" mode reads json files containing benchmark features, "mlir" mode reads mlir code files directly and extract features from it using AST dumper. Default is "json". -- `optimization_mode (Literal["last", "all"])`: The optimization mode to use, "last" will optimize only the last operation, "all" will optimize all operations in the code. Default is "last". -- `benchmarks_folder_path (str)`: Path to the benchmarks folder. Can be empty if data format is set to "json". -- `len_trajectory (int)`: Length of the trajectory used for PPO. -- `ppo_batch_size (int)`: Batch size for PPO. -- `nb_iterations (int)`: Number of training iterations. -- `ppo_epochs (int)`: Number of epochs for PPO. -- `entropy_coef (float)`: Entropy coefficient. -- `lr (float)`: Learning rate. -- `truncate (int)`: Maximum number of steps of a schedule for an operation. -- `json_file (str)`: Path to the JSON file containing the benchmarks code and features if data format is set to "json". Otherwise, it should contain original execution times for every benchmark in the benchmark folder. -- `tags (list[str])`: List of tags to add to the neptune experiment. +## Getting Started +This is an example of how you may give instructions on setting up your project locally. +To get a local copy up and running follow these simple example steps. +### Prerequisites: +###### Required +1) [CMake](https://cmake.org/): version 3.20 or greater. +2) [Ninja](https://ninja-build.org/). +3) [Gcc](https://gcc.gnu.org/) : version 13.2. +4) [Gxx]: version 13.2. +5) [LLD](https://lld.llvm.org/). +6) [Python](https://www.python.org/downloads/): version 3.11 or greater. +### Setup +#### 1. Building MLIR : +```sh +git clone --depth 1 -b release/19.x https://github.com/llvm/llvm-project.git +mkdir llvm-project/build +cd llvm-project/build +cmake -S llvm -B build -G Ninja -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_PROJECTS="mlir;clang;openmp" \ +-DLLVM_BUILD_EXAMPLES=ON -DLLVM_TARGETS_TO_BUILD=X86 -DLLVM_ENABLE_ASSERTIONS=ON \ +-DCMAKE_C_COMPILER=gcc -DCMAKE_CXX_COMPILER=g++ -DLLVM_ENABLE_LLD=ON -DMLIR_ENABLE_BINDINGS_PYTHON=ON + +cmake --build . --target check-mlir +``` +#### 2. Install python requirements : +```sh +pip install -r requirements.txt +``` +#### 3. Setup environment variables : +Change llvm related variables according to your llvm-project folder path. +```env +NEPTUNE_PROJECT= +NEPTUNE_TOKEN= +LLVM_BUILD_PATH=llvm-project/build +MLIR_SHARED_LIBS=llvm-project/build/lib/libomp.so,llvm-project/build/lib/libmlir_c_runner_utils.so,llvm-project/build/lib/libmlir_runner_utils.so +AST_DUMPER_BIN_PATH=tools/ast_dumper/build/bin/AstDumper +VECTORIZER_BIN_PATH=tools/vectorizer/build/bin/Vectorizer +``` +### Documentation +#### 1. Jobs +For running jobs using slurm script examples are provided in the `scripts/` folder. +#### 2. Configuration +Configuring the model on a specific case can be done by setting a JSON config file containing all required settings. Configuration JSON file examples are provided in the `config/` folder. +The following JSON content is an example of a config file: +```json +{ + "max_num_stores_loads": 7, + "max_num_loops": 7, + "max_num_load_store_dim": 7, + "num_tile_sizes": 7, + "num_transformations": 6, + "vect_size_limit": 2048, + "use_bindings": false, + "use_vectorizer": false, + "data_format": "json", + "optimization_mode": "last", + "benchmarks_folder_path": "", + "len_trajectory": 64, + "ppo_batch_size": 64, + "nb_iterations": 10000, + "ppo_epochs": 4, + "entropy_coef": 0.01, + "lr": 0.001, + "truncate": 5, + "json_file": "data/nn/train_operations.json", + "tags": ["nn"], + "logging": true +} +``` +The following list describes every required setting in a configuration file. +- `max_num_stores_loads (int)`: The maximum number of loads in the nested loops. +- `max_num_loops (int)`: The max number of nested loops. +- `max_num_load_store_dim (int)`: The max number of dimensions in load/store buffers. +- `num_tile_sizes (int)`: The number of possible tile sizes for a loop. +- `num_transformations (int)`: The number of transformations. +- `vect_size_limit (int)`: Vectorization size limit to prevent large sizes vectorization. +- `use_bindings (bool)`: Flag to enable using python bindings for execution, if False, the execution will be done using the command line. Default is False. +- `use_vectorizer (bool)`: Flag to enable using the vectorizer C++ program for vectorization, if False, vectorization is done using transform dialect directly. Default is False. +- `data_format (Literal["json", "mlir"])`: The format of the data, can be either "json" or "mlir". "json" mode reads json files containing benchmark features, "mlir" mode reads mlir code files directly and extract features from it using AST dumper. Default is "json". +- `optimization_mode (Literal["last", "all"])`: The optimization mode to use, "last" will optimize only the last operation, "all" will optimize all operations in the code. Default is "last". +- `benchmarks_folder_path (str)`: Path to the benchmarks folder. Can be empty if data format is set to "json". +- `len_trajectory (int)`: Length of the trajectory used for PPO. +- `ppo_batch_size (int)`: Batch size for PPO. +- `nb_iterations (int)`: Number of training iterations. +- `ppo_epochs (int)`: Number of epochs for PPO. +- `entropy_coef (float)`: Entropy coefficient. +- `lr (float)`: Learning rate. +- `truncate (int)`: Maximum number of steps of a schedule for an operation. +- `json_file (str)`: Path to the JSON file containing the benchmarks code and features if data format is set to "json". Otherwise, it should contain original execution times for every benchmark in the benchmark folder. +- `tags (list[str])`: List of tags to add to the neptune experiment. - `logging (bool)`: Flag to enable logging to neptune. \ No newline at end of file diff --git a/config/.gitignore b/config/.gitignore old mode 100644 new mode 100755 index 80ce6c7..3619be8 --- a/config/.gitignore +++ b/config/.gitignore @@ -1,5 +1,5 @@ -# Ignore everything in this directory -* -# Except these files -!.gitignore +# Ignore everything in this directory +* +# Except these files +!.gitignore !example.json \ No newline at end of file diff --git a/config/example.json b/config/example.json deleted file mode 100644 index d5053ce..0000000 --- a/config/example.json +++ /dev/null @@ -1,37 +0,0 @@ -{ - "max_num_stores_loads": 7, - "max_num_loops": 7, - "max_num_load_store_dim": 7, - "num_tile_sizes": 7, - "vect_size_limit": 2048, - "order": [["TP"], ["V", "NT"]], - "interchange_mode": "enumerate", - "exploration": ["entropy"], - "init_epsilon": 0.1, - "new_architecture": false, - "normalize_bounds": "max", - "normalize_adv": "standard", - "sparse_reward": true, - "split_ops": false, - "reuse_experience": "none", - "activation": "relu", - "benchmarks_folder_path": "", - "bench_count": 20, - "replay_count": 10, - "nb_iterations": 10000, - "ppo_epochs": 4, - "ppo_batch_size": 4, - "value_epochs": 32, - "value_batch_size": 32, - "value_coef": 0.5, - "value_clip": false, - "entropy_coef": 0.01, - "lr": 0.001, - "truncate": 5, - "json_file": "data/nn/train_operations.json", - "eval_json_file": "data/nn/eval_operations.json", - "tags": ["nn"], - "debug": false, - "main_exec_data_file": "", - "results_dir": "results" -} \ No newline at end of file diff --git a/create_dataset.py b/create_dataset.py new file mode 100755 index 0000000..1cf05cc --- /dev/null +++ b/create_dataset.py @@ -0,0 +1,48 @@ +import os +from dotenv import load_dotenv +import json +import pathlib +from rl_autoschedular.execution import Execution +from utils.config import Config + +load_dotenv(override=True) + +config = Config() +cache_file = "cache/execution.json" +exec = Execution(exec_data_file=cache_file) + +train_operations = {} +# eval_operations = {} + +for benchmark in os.listdir(config.benchmarks_folder_path): + benchmark_name = benchmark.split('.')[0] + mlir_code_path = f"data/matmul/online_data/{benchmark}" + mlir_code = pathlib.Path(mlir_code_path).read_text() + time_ns, success, cache_miss = exec.execute_code(mlir_code, benchmark_name, seq=[]) + + train_operations[benchmark_name] = time_ns + # eval_operations[benchmark_name] = time_ns + + print(f"Benchmark: {benchmark_name}") + print(f"Execution time: {time_ns} ns") + print(f"Success: {success}, Cache miss: {cache_miss}") + print("-" * 40) + +# --- helper function to append safely --- +def append_json(file_path, new_data): + if os.path.exists(file_path): + with open(file_path, 'r') as f: + try: + data = json.load(f) + except json.JSONDecodeError: + data = {} + else: + data = {} + # update old with new + data.update(new_data) + with open(file_path, 'w') as f: + json.dump(data, f, indent=4) + +# append instead of overwrite +append_json(config.json_file, train_operations) +# append_json(config.eval_json_file, eval_operations) diff --git a/dask-logs/.gitignore b/dask-logs/.gitignore deleted file mode 100644 index c96a04f..0000000 --- a/dask-logs/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -* -!.gitignore \ No newline at end of file diff --git a/data.py b/data.py new file mode 100644 index 0000000..e93440b --- /dev/null +++ b/data.py @@ -0,0 +1,71 @@ +import os + +# full list from D:\data_matmul (given by user) +all_files = [ +"matmul_1024_1024_128.mlir","matmul_1024_1024_256.mlir","matmul_1024_128_1024.mlir","matmul_1024_128_128.mlir", +"matmul_1024_128_2048.mlir","matmul_1024_128_256.mlir","matmul_1024_128_512.mlir","matmul_1024_128_768.mlir", +"matmul_1024_1536_128.mlir","matmul_1024_2048_128.mlir","matmul_1024_256_1024.mlir","matmul_1024_256_1536.mlir", +"matmul_1024_256_256.mlir","matmul_1024_256_512.mlir","matmul_1024_256_768.mlir","matmul_1024_3072_128.mlir", +"matmul_1024_512_128.mlir","matmul_1024_512_256.mlir","matmul_1024_512_512.mlir","matmul_1024_768_256.mlir", +"matmul_128_1024_1024.mlir","matmul_128_1024_128.mlir","matmul_128_1024_1536.mlir","matmul_128_1024_256.mlir", +"matmul_128_1024_512.mlir","matmul_128_1024_768.mlir","matmul_128_128_1024.mlir","matmul_128_128_128.mlir", +"matmul_128_128_1536.mlir","matmul_128_128_2048.mlir","matmul_128_128_3072.mlir","matmul_128_128_512.mlir", +"matmul_128_128_768.mlir","matmul_128_1536_1024.mlir","matmul_128_1536_128.mlir","matmul_128_1536_256.mlir", +"matmul_128_1536_512.mlir","matmul_128_1536_768.mlir","matmul_128_2048_1024.mlir","matmul_128_2048_128.mlir", +"matmul_128_2048_1536.mlir","matmul_128_2048_256.mlir","matmul_128_2048_512.mlir","matmul_128_2048_768.mlir", +"matmul_128_256_1024.mlir","matmul_128_256_128.mlir","matmul_128_256_2048.mlir","matmul_128_256_3072.mlir", +"matmul_128_256_768.mlir","matmul_128_3072_128.mlir","matmul_128_3072_256.mlir","matmul_128_3072_512.mlir", +"matmul_128_3072_768.mlir","matmul_128_512_1024.mlir","matmul_128_512_128.mlir","matmul_128_512_1536.mlir", +"matmul_128_512_2048.mlir","matmul_128_512_256.mlir","matmul_128_512_3072.mlir","matmul_128_512_512.mlir", +"matmul_128_768_1024.mlir","matmul_128_768_128.mlir","matmul_128_768_1536.mlir","matmul_128_768_256.mlir", +"matmul_128_768_3072.mlir","matmul_128_768_512.mlir","matmul_128_768_768.mlir","matmul_1536_1024_128.mlir", +"matmul_1536_128_128.mlir","matmul_1536_128_1536.mlir","matmul_1536_128_512.mlir","matmul_1536_128_768.mlir", +"matmul_1536_1536_128.mlir","matmul_1536_256_1024.mlir","matmul_1536_256_128.mlir","matmul_1536_256_256.mlir", +"matmul_1536_256_512.mlir","matmul_1536_256_768.mlir","matmul_1536_512_128.mlir","matmul_1536_512_256.mlir", +"matmul_1536_768_256.mlir","matmul_2048_128_1024.mlir","matmul_2048_128_128.mlir","matmul_2048_128_256.mlir", +"matmul_2048_128_512.mlir","matmul_2048_128_768.mlir","matmul_2048_256_128.mlir","matmul_2048_256_512.mlir", +"matmul_2048_256_768.mlir","matmul_2048_512_128.mlir","matmul_2048_512_256.mlir","matmul_2048_768_128.mlir", +"matmul_256_1024_1024.mlir","matmul_256_1024_128.mlir","matmul_256_1024_1536.mlir","matmul_256_1024_256.mlir", +"matmul_256_1024_512.mlir","matmul_256_1024_768.mlir","matmul_256_1280_1000.mlir","matmul_256_128_128.mlir", +"matmul_256_128_1536.mlir","matmul_256_128_2048.mlir","matmul_256_128_256.mlir","matmul_256_128_3072.mlir", +"matmul_256_128_512.mlir","matmul_256_128_768.mlir","matmul_256_1408_1000.mlir","matmul_256_1536_1000.mlir", +"matmul_256_1536_128.mlir","matmul_256_1536_256.mlir","matmul_256_1536_4096.mlir","matmul_256_1536_512.mlir", +"matmul_256_1536_768.mlir","matmul_256_2048_1000.mlir","matmul_256_2048_128.mlir","matmul_256_2048_2048.mlir", +"matmul_256_2048_256.mlir","matmul_256_2048_512.mlir","matmul_256_256_1024.mlir","matmul_256_256_128.mlir", +"matmul_256_256_1536.mlir","matmul_256_256_256.mlir","matmul_256_256_512.mlir","matmul_256_256_768.mlir", +"matmul_256_3072_128.mlir","matmul_256_4096_1024.mlir","matmul_256_512_1024.mlir","matmul_256_512_128.mlir", +"matmul_256_512_1536.mlir","matmul_256_512_2048.mlir","matmul_256_512_256.mlir","matmul_256_512_3072.mlir", +"matmul_256_512_512.mlir","matmul_256_512_768.mlir","matmul_256_768_1024.mlir","matmul_256_768_128.mlir", +"matmul_256_768_1536.mlir","matmul_256_768_2.mlir","matmul_256_768_256.mlir","matmul_256_768_3072.mlir", +"matmul_256_768_512.mlir","matmul_256_768_768.mlir","matmul_3072_128_128.mlir","matmul_3072_128_256.mlir", +"matmul_3072_128_512.mlir","matmul_3072_256_128.mlir","matmul_3072_256_256.mlir","matmul_3072_512_128.mlir", +"matmul_3072_512_256.mlir","matmul_3072_768_128.mlir","matmul_512_1024_128.mlir","matmul_512_1024_256.mlir", +"matmul_512_1024_512.mlir","matmul_512_128_1024.mlir","matmul_512_128_128.mlir","matmul_512_128_1536.mlir", +"matmul_512_128_2048.mlir","matmul_512_128_256.mlir","matmul_512_128_3072.mlir","matmul_512_128_512.mlir", +"matmul_512_128_768.mlir","matmul_512_1536_128.mlir","matmul_512_1536_256.mlir","matmul_512_2048_128.mlir", +"matmul_512_256_1024.mlir","matmul_512_256_128.mlir","matmul_512_256_1536.mlir","matmul_512_256_2048.mlir", +"matmul_512_256_256.mlir","matmul_512_256_512.mlir","matmul_512_256_768.mlir","matmul_512_512_1024.mlir", +"matmul_512_512_128.mlir","matmul_512_512_256.mlir","matmul_512_512_512.mlir","matmul_512_512_768.mlir", +"matmul_512_768_128.mlir","matmul_512_768_256.mlir","matmul_512_768_512.mlir","matmul_512_768_768.mlir", +"matmul_768_1024_128.mlir","matmul_768_128_1536.mlir","matmul_768_128_256.mlir","matmul_768_128_3072.mlir", +"matmul_768_128_512.mlir","matmul_768_128_768.mlir","matmul_768_1536_128.mlir","matmul_768_2048_128.mlir", +"matmul_768_2048_256.mlir","matmul_768_256_1024.mlir","matmul_768_256_128.mlir","matmul_768_256_1536.mlir", +"matmul_768_256_2048.mlir","matmul_768_256_256.mlir","matmul_768_256_768.mlir","matmul_768_3072_128.mlir", +"matmul_768_512_128.mlir","matmul_768_512_256.mlir","matmul_768_512_768.mlir","matmul_768_768_128.mlir", +"matmul_768_768_256.mlir","matmul_768_768_512.mlir" +] + +# used files subset +used_files = [ +"matmul_256_1024_1024.mlir","matmul_256_1280_1000.mlir","matmul_256_1408_1000.mlir","matmul_256_1536_1000.mlir", +"matmul_256_1536_4096.mlir","matmul_256_2048_1000.mlir","matmul_256_2048_2048.mlir","matmul_256_256_128.mlir", +"matmul_256_256_512.mlir","matmul_256_4096_1024.mlir","matmul_256_512_1024.mlir","matmul_256_768_2.mlir", +"matmul_256_768_3072.mlir","matmul_256_768_768.mlir" +] + +unused_files = (set(all_files) - set(used_files)) +unused_files = [f for f in unused_files if f.endswith('.mlir') and f.startswith('matmul_')] + +unused_files = sorted(unused_files) + +print(unused_files[:10]) \ No newline at end of file diff --git a/data/all/.gitignore b/data/all/.gitignore deleted file mode 100644 index d6b7ef3..0000000 --- a/data/all/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -* -!.gitignore diff --git a/data/debug/.gitignore b/data/debug/.gitignore deleted file mode 100644 index c96a04f..0000000 --- a/data/debug/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -* -!.gitignore \ No newline at end of file diff --git a/data/features/.gitignore b/data/features/.gitignore deleted file mode 100644 index d6b7ef3..0000000 --- a/data/features/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -* -!.gitignore diff --git a/data/lqcd/bench/.gitignore b/data/lqcd/bench/.gitignore old mode 100644 new mode 100755 index d6b7ef3..005717e --- a/data/lqcd/bench/.gitignore +++ b/data/lqcd/bench/.gitignore @@ -1,2 +1,2 @@ -* -!.gitignore +* +!.gitignore diff --git a/data/lqcd/control/execution_times.json b/data/lqcd/control/execution_times.json old mode 100644 new mode 100755 index f7806b5..ffd7c84 --- a/data/lqcd/control/execution_times.json +++ b/data/lqcd/control/execution_times.json @@ -1,3 +1,3 @@ -{ - "test_AB_1": 2620914018 +{ + "test_AB_1": 2620914018 } \ No newline at end of file diff --git a/data/lqcd/control/test_AB_1.mlir b/data/lqcd/control/test_AB_1.mlir old mode 100644 new mode 100755 index 39f4758..3c95ed3 --- a/data/lqcd/control/test_AB_1.mlir +++ b/data/lqcd/control/test_AB_1.mlir @@ -1,45 +1,45 @@ -func.func private @nanoTime() -> i64 attributes { llvm.emit_c_interface } -func.func @main(%B_28: memref<1024x1024xf64>, %A_30: memref<1024x1024xf64>, %output_24: memref<1024x1024xf64>) -> i64 attributes { llvm.emit_c_interface } { - %t0 = func.call @nanoTime() : () -> i64 - %7 = memref.alloc() : memref<1024x1024x1xf64> - linalg.generic {indexing_maps = [affine_map<(o_0_20, o_1_22, complex_26)->(o_0_20, o_1_22, complex_26)>], iterator_types = ["parallel", "parallel", "parallel"]} outs(%7: memref<1024x1024x1xf64>) { - ^bb0(%8: f64): - %1 = arith.constant 0.0 : f64 - linalg.yield %1 : f64 - } - %9 = memref.alloc() : memref<1024xf64> - linalg.generic {indexing_maps = [affine_map<(o_0_20, o_1_22, complex_26, sum_36_0)->(sum_36_0)>, affine_map<(o_0_20, o_1_22, complex_26, sum_36_0)->(o_0_20, sum_36_0)>, affine_map<(o_0_20, o_1_22, complex_26, sum_36_0)->(sum_36_0, o_1_22)>, affine_map<(o_0_20, o_1_22, complex_26, sum_36_0)->(o_0_20, sum_36_0)>, affine_map<(o_0_20, o_1_22, complex_26, sum_36_0)->(sum_36_0, o_1_22)>, affine_map<(o_0_20, o_1_22, complex_26, sum_36_0)->(o_0_20, o_1_22, complex_26)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%9, %A_30, %B_28, %A_30, %B_28: memref<1024xf64>, memref<1024x1024xf64>, memref<1024x1024xf64>, memref<1024x1024xf64>, memref<1024x1024xf64>) outs(%7: memref<1024x1024x1xf64>) { - ^bb0(%10: f64, %13: f64, %37: f64, %30: f64, %39: f64, %11: f64): - %4 = linalg.index 0 : index - %5 = linalg.index 1 : index - %6 = linalg.index 2 : index - %12 = linalg.index 3 : index - %27 = arith.constant 1 : index - %32 = arith.minsi %6, %27 : index - %17 = arith.constant 0 : index - %33 = arith.maxsi %32, %17 : index - %26 = arith.mulf %13, %37 fastmath : f64 - %22 = arith.constant 0.0 : f64 - %18 = arith.subf %26, %22 fastmath : f64 - %29 = arith.constant 0.0 : f64 - %24 = arith.mulf %30, %29 fastmath : f64 - %15 = arith.constant 0.0 : f64 - %14 = arith.mulf %15, %39 fastmath : f64 - %19 = arith.addf %24, %14 fastmath : f64 - %36 = arith.constant 0 : index - %21 = arith.cmpi eq, %33, %36 : index - %38 = arith.select %21, %18, %19 : f64 - %25 = arith.addf %11, %38 fastmath : f64 - linalg.yield %25 : f64 - } - %41 = memref.collapse_shape %7 [[0], [1, 2]] : memref<1024x1024x1xf64> into memref<1024x1024xf64> - linalg.generic {indexing_maps = [affine_map<(o_0_20, o_1_22)->(o_0_20, o_1_22)>, affine_map<(o_0_20, o_1_22)->(o_0_20, o_1_22)>], iterator_types = ["parallel", "parallel"]} ins(%41: memref<1024x1024xf64>) outs(%output_24: memref<1024x1024xf64>) { - ^bb0(%43: f64, %42: f64): - %2 = linalg.index 0 : index - %3 = linalg.index 1 : index - linalg.yield %43 : f64 - } - %t1 = func.call @nanoTime() : () -> (i64) - %t2 = arith.subi %t1, %t0 : i64 - return %t2 : i64 -} +func.func private @nanoTime() -> i64 attributes { llvm.emit_c_interface } +func.func @main(%B_28: memref<1024x1024xf64>, %A_30: memref<1024x1024xf64>, %output_24: memref<1024x1024xf64>) -> i64 attributes { llvm.emit_c_interface } { + %t0 = func.call @nanoTime() : () -> i64 + %7 = memref.alloc() : memref<1024x1024x1xf64> + linalg.generic {indexing_maps = [affine_map<(o_0_20, o_1_22, complex_26)->(o_0_20, o_1_22, complex_26)>], iterator_types = ["parallel", "parallel", "parallel"]} outs(%7: memref<1024x1024x1xf64>) { + ^bb0(%8: f64): + %1 = arith.constant 0.0 : f64 + linalg.yield %1 : f64 + } + %9 = memref.alloc() : memref<1024xf64> + linalg.generic {indexing_maps = [affine_map<(o_0_20, o_1_22, complex_26, sum_36_0)->(sum_36_0)>, affine_map<(o_0_20, o_1_22, complex_26, sum_36_0)->(o_0_20, sum_36_0)>, affine_map<(o_0_20, o_1_22, complex_26, sum_36_0)->(sum_36_0, o_1_22)>, affine_map<(o_0_20, o_1_22, complex_26, sum_36_0)->(o_0_20, sum_36_0)>, affine_map<(o_0_20, o_1_22, complex_26, sum_36_0)->(sum_36_0, o_1_22)>, affine_map<(o_0_20, o_1_22, complex_26, sum_36_0)->(o_0_20, o_1_22, complex_26)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%9, %A_30, %B_28, %A_30, %B_28: memref<1024xf64>, memref<1024x1024xf64>, memref<1024x1024xf64>, memref<1024x1024xf64>, memref<1024x1024xf64>) outs(%7: memref<1024x1024x1xf64>) { + ^bb0(%10: f64, %13: f64, %37: f64, %30: f64, %39: f64, %11: f64): + %4 = linalg.index 0 : index + %5 = linalg.index 1 : index + %6 = linalg.index 2 : index + %12 = linalg.index 3 : index + %27 = arith.constant 1 : index + %32 = arith.minsi %6, %27 : index + %17 = arith.constant 0 : index + %33 = arith.maxsi %32, %17 : index + %26 = arith.mulf %13, %37 fastmath : f64 + %22 = arith.constant 0.0 : f64 + %18 = arith.subf %26, %22 fastmath : f64 + %29 = arith.constant 0.0 : f64 + %24 = arith.mulf %30, %29 fastmath : f64 + %15 = arith.constant 0.0 : f64 + %14 = arith.mulf %15, %39 fastmath : f64 + %19 = arith.addf %24, %14 fastmath : f64 + %36 = arith.constant 0 : index + %21 = arith.cmpi eq, %33, %36 : index + %38 = arith.select %21, %18, %19 : f64 + %25 = arith.addf %11, %38 fastmath : f64 + linalg.yield %25 : f64 + } + %41 = memref.collapse_shape %7 [[0], [1, 2]] : memref<1024x1024x1xf64> into memref<1024x1024xf64> + linalg.generic {indexing_maps = [affine_map<(o_0_20, o_1_22)->(o_0_20, o_1_22)>, affine_map<(o_0_20, o_1_22)->(o_0_20, o_1_22)>], iterator_types = ["parallel", "parallel"]} ins(%41: memref<1024x1024xf64>) outs(%output_24: memref<1024x1024xf64>) { + ^bb0(%43: f64, %42: f64): + %2 = linalg.index 0 : index + %3 = linalg.index 1 : index + linalg.yield %43 : f64 + } + %t1 = func.call @nanoTime() : () -> (i64) + %t2 = arith.subi %t1, %t0 : i64 + return %t2 : i64 +} diff --git a/data/lqcd/control/test_AB_1.mlir.npy b/data/lqcd/control/test_AB_1.mlir.npy old mode 100644 new mode 100755 diff --git a/data/lqcd/control/test_AB_1.mlir.npz b/data/lqcd/control/test_AB_1.mlir.npz old mode 100644 new mode 100755 diff --git a/data/multi/.gitignore b/data/multi/.gitignore deleted file mode 100644 index d6b7ef3..0000000 --- a/data/multi/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -* -!.gitignore diff --git a/data/nn-old/.gitignore b/data/nn-old/.gitignore deleted file mode 100644 index c96a04f..0000000 --- a/data/nn-old/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -* -!.gitignore \ No newline at end of file diff --git a/data/nn/.gitignore b/data/nn/.gitignore deleted file mode 100644 index ca76327..0000000 --- a/data/nn/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -* -!gen -!.gitignore diff --git a/data/nn/gen/.gitignore b/data/nn/gen/.gitignore deleted file mode 100644 index 4e7cefe..0000000 --- a/data/nn/gen/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -!data_generation_random.py -!example.yaml -!.gitignore diff --git a/data/nn/gen/data_generation_random.py b/data/nn/gen/data_generation_random.py deleted file mode 100644 index 68190da..0000000 --- a/data/nn/gen/data_generation_random.py +++ /dev/null @@ -1,1100 +0,0 @@ -from random import randint, choice, shuffle, random -from tqdm import trange -import json -import numpy as np -import yaml -import argparse -import re -import os -import sys -from mlir.ir import Context, Module -from mlir.execution_engine import ExecutionEngine, ctypes -from mlir.runtime import get_ranked_memref_descriptor -from mlir.passmanager import PassManager -from typing import Callable -from stopit import ThreadingTimeout, TimeoutException - - -def remove_duplicate_args(args, shapes): - args_shapes = list(zip(args, shapes)) - seen = set() - result = [] - for item in args_shapes: - if item not in seen: - seen.add(item) - result.append(item) - - args = [x for (x, _) in result] - shapes = [x for (_, x) in result] - return args, shapes - - -pass_pipeline = """builtin.module( - loop-invariant-code-motion, - canonicalize, - convert-vector-to-scf, - convert-linalg-to-loops, - buffer-deallocation-pipeline, - convert-bufferization-to-memref, - scf-forall-to-parallel, - convert-scf-to-openmp, - expand-strided-metadata, - finalize-memref-to-llvm, - convert-scf-to-cf, - lower-affine, - - convert-openmp-to-llvm, - convert-vector-to-llvm, - convert-math-to-llvm, - convert-func-to-llvm, - convert-index-to-llvm, - convert-arith-to-llvm, - convert-cf-to-llvm, - - reconcile-unrealized-casts, - canonicalize, - cse -)""" - -tensor_pass_pipeline = """builtin.module( - loop-invariant-code-motion, - canonicalize, - eliminate-empty-tensors, - empty-tensor-to-alloc-tensor, - one-shot-bufferize{ - bufferize-function-boundaries - function-boundary-type-conversion=identity-layout-map - }, - convert-vector-to-scf, - convert-linalg-to-loops, - buffer-deallocation-pipeline, - convert-bufferization-to-memref, - scf-forall-to-parallel, - convert-scf-to-openmp, - expand-strided-metadata, - finalize-memref-to-llvm, - convert-scf-to-cf, - lower-affine, - - convert-openmp-to-llvm, - convert-vector-to-llvm, - convert-math-to-llvm, - finalize-memref-to-llvm, - convert-func-to-llvm, - convert-index-to-llvm, - convert-arith-to-llvm, - convert-cf-to-llvm, - - reconcile-unrealized-casts, - canonicalize, - cse -)""" - -tmp_file = 'temp.mlir' - -BATCH_SIZES = [] -SIZES = [] -HEIGHTS = [] -CHANNELS = [] -KERNELS = [] -DILATIONS = [] -STRIDES = [] - - -def choice_topped(choices, max_value): - trials_left = 50 - n = choice(choices) - while not (n <= max_value) and trials_left != 0: - n = choice(choices) - trials_left -= 1 - - if trials_left == 0: - return None - return n - - -def add(): - # SHAPE = "x".join([str(choice(HEIGHTS)) for _ in range(randint(1, 3))]) - matmul_size = 10 ** 9 - while matmul_size >= (10 ** 9): - shape_int = [choice(HEIGHTS) for _ in range(4)] - matmul_size = shape_int[0] * shape_int[1] * shape_int[2] * shape_int[3] - SHAPE = "x".join(map(str, shape_int)) - shape_str = SHAPE.replace('x', '_') - bench_name = f"add_{shape_str}" - return f"linalg.add ins(%arg0, %arg1: memref<{SHAPE}xf64>, memref<{SHAPE}xf64>) outs(%arg2: memref<{SHAPE}xf64>)", bench_name - - -def add_nn(): - B = choice(BATCH_SIZES) - N = choice(HEIGHTS) - operation = f""" - linalg.generic {{indexing_maps = [#map2, #map4, #map2], iterator_types = ["parallel", "parallel"]}} ins(%44, %10 : tensor<{B}x{N}xf32>, tensor<{N}xf32>) outs(%42 : tensor<{B}x{N}xf32>) {{ - ^bb0(%in: f32, %in_1: f32, %out: f32): - %46 = arith.addf %in, %in_1 : f32 - linalg.yield %46 : f32 - }} - """.strip() - return operation - - -def sub(): - SHAPE = "x".join([str(choice(HEIGHTS)) for _ in range(randint(1, 4))]) - return f"linalg.sub ins(%arg0, %arg1: tensor<{SHAPE}xf32>, tensor<{SHAPE}xf32>) outs(%arg2: tensor<{SHAPE}xf32>) -> tensor<{SHAPE}xf32>" - - -def max(): - SHAPE = "x".join([str(choice(HEIGHTS)) for _ in range(randint(1, 4))]) - return f"linalg.max ins(%arg0, %arg1: tensor<{SHAPE}xf32>, tensor<{SHAPE}xf32>) outs(%arg2: tensor<{SHAPE}xf32>) -> tensor<{SHAPE}xf32>" - - -def mul(): - SHAPE = "x".join([str(choice(HEIGHTS)) for _ in range(randint(1, 4))]) - return f"linalg.mul ins(%arg0, %arg1: tensor<{SHAPE}xf32>, tensor<{SHAPE}xf32>) outs(%arg2: tensor<{SHAPE}xf32>) -> tensor<{SHAPE}xf32>" - - -def abs(): - SHAPE = "x".join([str(choice(HEIGHTS)) for _ in range(randint(1, 4))]) - return f"linalg.abs ins(%arg0: tensor<{SHAPE}xf32>) outs(%arg2: tensor<{SHAPE}xf32>) -> tensor<{SHAPE}xf32>" - - -def ceil(): - SHAPE = "x".join([str(choice(HEIGHTS)) for _ in range(randint(1, 4))]) - return f"linalg.ceil ins(%arg0 : tensor<{SHAPE}xf32>) outs(%arg1: tensor<{SHAPE}xf32>) -> tensor<{SHAPE}xf32>" - - -def copy_(): - SHAPE = "x".join([str(choice(HEIGHTS)) for _ in range(randint(1, 4))]) - return f"linalg.copy ins(%arg0 : tensor<{SHAPE}xf32>) outs(%arg1: tensor<{SHAPE}xf32>) -> tensor<{SHAPE}xf32>" - - -def fill(): - SHAPE = "x".join([str(choice(HEIGHTS)) for _ in range(randint(1, 4))]) - return f"linalg.fill ins(%arg0 : f32) outs(%arg1: tensor<{SHAPE}xf32>) -> tensor<{SHAPE}xf32>" - - -def transpose(): - L = randint(1, 5) - - permutation = list(range(L)) - shuffle(permutation) - - SHAPE1 = [choice(HEIGHTS) for _ in range(L)] - - SHAPE2 = [] - for i in range(L): - SHAPE2.append(SHAPE1[permutation[i]]) - - SHAPE1 = "x".join(map(str, SHAPE1)) - SHAPE2 = "x".join(map(str, SHAPE2)) - - return f"linalg.transpose ins(%input:tensor<{SHAPE1}xf32>) outs(%init:tensor<{SHAPE2}xf32>) permutation = {permutation}" - - -def batch_matmul(): - B = choice(BATCH_SIZES) - N = choice(HEIGHTS) - K = choice(HEIGHTS) - M = choice(HEIGHTS) - return f"linalg.batch_matmul ins(%arg0, %arg1 : tensor<{B}x{N}x{K}xf32>, tensor<{B}x{K}x{M}xf32>) outs(%arg2 : tensor<{B}x{N}x{M}xf32>) -> tensor<{B}x{N}x{M}xf32>" - - -def batch_matmul_transpose_a(): - B = choice(BATCH_SIZES) - N = choice(HEIGHTS) - K = choice(HEIGHTS) - M = choice(HEIGHTS) - return f"linalg.batch_matmul_transpose_a ins(%arg0, %arg1: tensor<{B}x{K}x{N}xf32>, tensor<{B}x{K}x{M}xf32>) outs(%arg2: tensor<{B}x{N}x{M}xf32>) -> tensor<{B}x{N}x{M}xf32>" - - -def batch_matmul_transpose_b(): - B = choice(BATCH_SIZES) - N = choice(HEIGHTS) - K = choice(HEIGHTS) - M = choice(HEIGHTS) - return f"linalg.batch_matmul_transpose_b ins(%arg0, %arg1 : tensor<{B}x{N}x{K}xf32>, tensor<{B}x{M}x{K}xf32>) outs(%arg2: tensor<{B}x{N}x{M}xf32>) -> tensor<{B}x{N}x{M}xf32>" - - -def batch_reduce_matmul(): - B = choice(BATCH_SIZES) - N = choice(HEIGHTS) - K = choice(HEIGHTS) - M = choice(HEIGHTS) - return f"linalg.batch_reduce_matmul ins(%arg0, %arg1 : tensor<{B}x{N}x{K}xf32>, tensor<{B}x{K}x{M}xf32>) outs(%arg2: tensor<{N}x{M}xf32>) -> tensor<{N}x{M}xf32>" - - -def matmul(): - matmul_size = 10 ** 9 - while matmul_size >= (10 ** 9): - N = choice(SIZES) - K = choice(SIZES) - M = choice(SIZES) - matmul_size = N * K * M - bench_name = f"matmul_{N}_{K}_{M}" - return f"linalg.matmul ins(%arg0, %arg1 : memref<{N}x{K}xf64>, memref<{K}x{M}xf64>) outs(%arg2 : memref<{N}x{M}xf64>)", bench_name - - -def matmul_transpose_a(): - N = choice(HEIGHTS) - K = choice(HEIGHTS) - M = choice(HEIGHTS) - return f"linalg.matmul_transpose_a ins(%arg0, %arg1: tensor<{K}x{N}xf32>, tensor<{K}x{M}xf32>) outs(%arg2: tensor<{N}x{M}xf32>) -> tensor<{N}x{M}xf32>" - - -def matmul_transpose_b(): - N = choice(HEIGHTS) - K = choice(HEIGHTS) - M = choice(HEIGHTS) - return f"linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<{N}x{K}xf32>, tensor<{M}x{K}xf32>) outs(%arg2: tensor<{N}x{M}xf32>) -> tensor<{N}x{M}xf32>" - - -def conv_1d(): - N = choice(HEIGHTS) - F = choice_topped(KERNELS, N) - N_ = N - F + 1 - return f"linalg.conv_1d ins(%input, %filter : tensor<{N}xf32>, tensor<{F}xf32>) outs(%output : tensor<{N_}xf32>) -> tensor<{N_}xf32>" - - -def conv_1d_ncw_fcw(): - # INPUT: NCW1 - # KERNL: FCW2 - # OUTPUT: (N, F, W1-W2+1) - - N = choice(BATCH_SIZES) - C = choice(CHANNELS) - W1 = choice(HEIGHTS) - - dilation = choice(DILATIONS) - stride = choice(STRIDES) - padding = 0 - - F = choice(CHANNELS) - W2 = choice_topped(KERNELS, (W1 + 2 * padding - 1) // dilation - 1) - - W3 = ((W1 + 2 * padding - dilation * (W2 - 1) - 1) // stride) + 1 - - return f"linalg.conv_1d_ncw_fcw {{dilations = dense<{dilation}> : tensor<1xi64>, strides = dense<{stride}> : tensor<1xi64>}} ins (%input, %filter: tensor<{N}x{C}x{W1}xf32>, tensor<{F}x{C}x{W2}xf32>) outs (%init: tensor<{N}x{F}x{W3}xf32>) -> tensor<{N}x{F}x{W3}xf32>" - - -def conv_1d_nwc_wcf(): - # INPUT: NWC - # KERNL: WCF - # OUTPUT: (N, W1-W2+1, F) - - N = choice(BATCH_SIZES) - C = choice(CHANNELS) - W1 = choice(HEIGHTS) - - dilation = choice(DILATIONS) - stride = choice(STRIDES) - padding = 0 - - F = choice(CHANNELS) - W2 = choice_topped(KERNELS, (W1 + 2 * padding - 1) // dilation - 1) - - W3 = ((W1 + 2 * padding - dilation * (W2 - 1) - 1) // stride) + 1 - - return f"linalg.conv_1d_nwc_wcf {{dilations = dense<{dilation}> : tensor<1xi64>, strides = dense<{stride}> : tensor<1xi64>}} ins (%input, %filter: tensor<{N}x{W1}x{C}xf32>, tensor<{W2}x{C}x{F}xf32>) outs (%init: tensor<{N}x{W3}x{F}xf32>) -> tensor<{N}x{W3}x{F}xf32>" - - -def conv_2d(): - H, W = choice(HEIGHTS), choice(HEIGHTS) - - F1 = F2 = choice_topped(KERNELS, min(H - 2, W - 2)) - - H_ = H - F1 + 1 - W_ = W - F2 + 1 - - return f"linalg.conv_2d ins(%input, %filter: tensor<{H}x{W}xi32>, tensor<{F1}x{F2}xi32>) outs(%output: tensor<{H_}x{W_}xi32>) -> tensor<{H_}x{W_}xi32>" - - -def conv_2d_nchw_fchw(): - # INPUT: NCHW - # KERNL: FCHW - # OUTPUT: (N, F, H', W') - - matmul_size = 10 ** 9 - while matmul_size >= (10 ** 9): - N = choice(BATCH_SIZES) - C = choice(CHANNELS) - H = choice(HEIGHTS) - # W = choice(HEIGHTS) - W = H - - dilation = choice(DILATIONS) - stride = choice(STRIDES) - padding = 0 - - F = choice(CHANNELS) - KH = KW = choice_topped(KERNELS, (min(H, W) + 2 * padding - 1) // dilation - 1) - - H_ = ((H + 2 * padding - dilation * (KH - 1) - 1) // stride) + 1 - W_ = ((W + 2 * padding - dilation * (KW - 1) - 1) // stride) + 1 - - matmul_size = N * F * H_ * W_ * C * KH * KW - - bench_name = f"conv_2d_nchw_fchw_{N}_{C}_{H}_{W}_{F}_{KH}_{KW}_{H_}_{W_}" - - return f"linalg.conv_2d_nchw_fchw {{dilations = dense<{dilation}> : tensor<2xi64>, strides = dense<{stride}> : tensor<2xi64>}} ins (%input, %filter: tensor<{N}x{C}x{H}x{W}xf64>, tensor<{F}x{C}x{KH}x{KW}xf64>) outs (%init: tensor<{N}x{F}x{H_}x{W_}xf64>) -> tensor<{N}x{F}x{H_}x{W_}xf64>", bench_name - - -def conv_2d_ngchw_fgchw(): - # INPUT: NCHW - # KERNL: FCHW - # OUTPUT: (N, F, H', W') - - N = choice(BATCH_SIZES) - G = choice(BATCH_SIZES) - C = choice(CHANNELS) - H = choice(HEIGHTS) - W = choice(HEIGHTS) - - dilation = choice(DILATIONS) - stride = choice(STRIDES) - padding = 0 - - F = choice(CHANNELS) - KH = KW = choice_topped(KERNELS, (min(H, W) + 2 * padding - 1) // dilation - 1) - - W_ = ((W + 2 * padding - dilation * (KW - 1) - 1) // stride) + 1 - H_ = ((H + 2 * padding - dilation * (KH - 1) - 1) // stride) + 1 - - return f"linalg.conv_2d_ngchw_fgchw {{dilations = dense<{dilation}> : tensor<2xi64>, strides = dense<{stride}> : tensor<2xi64>}} ins (%input, %filter: tensor<{N}x{G}x{C}x{H}x{W}xf32>, tensor<{G}x{F}x{C}x{KH}x{KW}xf32>) outs (%init: tensor<{N}x{G}x{F}x{H_}x{W_}xf32>) -> tensor<{N}x{G}x{F}x{H_}x{W_}xf32>" - - -def conv_2d_nhwc_fhwc(): - N = choice(BATCH_SIZES) - C = choice(CHANNELS) - H = choice(HEIGHTS) - W = choice(HEIGHTS) - - dilation = choice(DILATIONS) - stride = choice(STRIDES) - padding = 0 - - F = choice(CHANNELS) - KH = KW = choice_topped(KERNELS, (min(H, W) + 2 * padding - 1) // dilation - 1) - - W_ = ((W + 2 * padding - dilation * (KW - 1) - 1) // stride) + 1 - H_ = ((H + 2 * padding - dilation * (KH - 1) - 1) // stride) + 1 - - return f"linalg.conv_2d_nhwc_fhwc {{dilations = dense<{dilation}> : tensor<2xi64>, strides = dense<{stride}> : tensor<2xi64>}} ins (%input, %filter: tensor<{N}x{H}x{W}x{C}xf32>, tensor<{F}x{KH}x{KW}x{C}xf32>) outs (%init: tensor<{N}x{H_}x{W_}x{F}xf32>) -> tensor<{N}x{H_}x{W_}x{F}xf32>" - - -def conv_2d_nhwc_hwcf(): - N = choice(BATCH_SIZES) - C = choice(CHANNELS) - H = choice(HEIGHTS) - W = choice(HEIGHTS) - - dilation = choice(DILATIONS) - stride = choice(STRIDES) - padding = 0 - - F = choice(CHANNELS) - KH = KW = choice_topped(KERNELS, (min(H, W) + 2 * padding - 1) // dilation - 1) - - W_ = ((W + 2 * padding - dilation * (KW - 1) - 1) // stride) + 1 - H_ = ((H + 2 * padding - dilation * (KH - 1) - 1) // stride) + 1 - - return f"linalg.conv_2d_nhwc_hwcf {{dilations = dense<{dilation}> : tensor<2xi64>, strides = dense<{stride}> : tensor<2xi64>}} ins (%input, %filter: tensor<{N}x{H}x{W}x{C}xf32>, tensor<{KH}x{KW}x{C}x{F}xf32>) outs (%init: tensor<{N}x{H_}x{W_}x{F}xf32>) -> tensor<{N}x{H_}x{W_}x{F}xf32>" - - -def conv_3d(): - H, W, D = choice(HEIGHTS), choice(HEIGHTS), choice(HEIGHTS) - - F = choice_topped(KERNELS, min(H, W, D) - 2) - - H_ = H - F + 1 - W_ = W - F + 1 - D_ = D - F + 1 - - return f"linalg.conv_3d ins(%input, %filter: tensor<{H}x{W}x{D}xf32>, tensor<{F}x{F}x{F}xf32>) outs(%output: tensor<{H_}x{W_}x{D_}xf32>) -> tensor<{H_}x{W_}x{D_}xf32>" - - -def conv_3d_ncdhw_fcdhw(): - # INPUT: NCHW - # KERNL: FCHW - # OUTPUT: (N, F, H', W') - - N = choice(BATCH_SIZES) - C = choice(CHANNELS) - H = choice(HEIGHTS) - W = choice(HEIGHTS) - D = choice(HEIGHTS) - - dilation = choice(DILATIONS) - stride = choice(STRIDES) - padding = 0 - - F = choice(CHANNELS) - KH = KW = KD = choice_topped( - KERNELS, (min(H, W, D) + 2 * padding - 1) // dilation - 1 - ) - - W_ = ((W + 2 * padding - dilation * (KW - 1) - 1) // stride) + 1 - H_ = ((H + 2 * padding - dilation * (KH - 1) - 1) // stride) + 1 - D_ = ((D + 2 * padding - dilation * (KD - 1) - 1) // stride) + 1 - - return f"linalg.conv_3d_ncdhw_fcdhw {{dilations = dense<{dilation}> : tensor<3xi64>, strides = dense<{stride}> : tensor<3xi64>}} ins (%input, %filter: tensor<{N}x{C}x{H}x{W}x{D}xf32>, tensor<{F}x{C}x{KH}x{KW}x{KD}xf32>) outs (%init: tensor<{N}x{F}x{H_}x{W_}x{D_}xf32>) -> tensor<{N}x{F}x{H_}x{W_}x{D_}xf32>" - - -def depthwise_conv_1d_ncw_cw(): - N = choice(BATCH_SIZES) - C = choice(CHANNELS) - W = choice(HEIGHTS) - - dilation = choice(DILATIONS) - stride = choice(STRIDES) - padding = 0 - - K = choice_topped(KERNELS, (W + 2 * padding - 1) // dilation - 1) - - W_ = ((W + 2 * padding - dilation * (K - 1) - 1) // stride) + 1 - - return f"linalg.depthwise_conv_1d_ncw_cw {{dilations = dense<{dilation}> : tensor<1xi64>, strides = dense<{stride}> : tensor<1xi64>}} ins (%input, %filter: tensor<{N}x{C}x{W}xf32>, tensor<{C}x{K}xf32>) outs (%init: tensor<{N}x{C}x{W_}xf32>) -> tensor<{N}x{C}x{W_}xf32>" - - -def depthwise_conv_1d_nwc_wc(): - N = choice(BATCH_SIZES) - C = choice(CHANNELS) - W = choice(HEIGHTS) - - dilation = choice(DILATIONS) - stride = choice(STRIDES) - padding = 0 - - K = choice_topped(KERNELS, (W + 2 * padding - 1) // dilation - 1) - - W_ = ((W + 2 * padding - dilation * (K - 1) - 1) // stride) + 1 - - return f"linalg.depthwise_conv_1d_nwc_wc {{dilations = dense<{dilation}> : tensor<1xi64>, strides = dense<{stride}> : tensor<1xi64>}} ins (%input, %filter: tensor<{N}x{W}x{C}xf32>, tensor<{K}x{C}xf32>) outs (%init: tensor<{N}x{W_}x{C}xf32>) -> tensor<{N}x{W_}x{C}xf32>" - - -def depthwise_conv_1d_nwc_wcm(): - N = choice(BATCH_SIZES) - C = choice(CHANNELS) - M = choice(CHANNELS) - W = choice(HEIGHTS) - - dilation = choice(DILATIONS) - stride = choice(STRIDES) - padding = 0 - - K = choice_topped(KERNELS, (W + 2 * padding - 1) // dilation - 1) - - W_ = ((W + 2 * padding - dilation * (K - 1) - 1) // stride) + 1 - - return f"linalg.depthwise_conv_1d_nwc_wcm {{dilations = dense<{dilation}> : tensor<1xi64>, strides = dense<{stride}> : tensor<1xi64>}} ins (%input, %filter: tensor<{N}x{W}x{C}xf32>, tensor<{K}x{C}x{M}xf32>) outs (%init: tensor<{N}x{W_}x{C}x{M}xf32>) -> tensor<{N}x{W_}x{C}x{M}xf32>" - - -def depthwise_conv_2d_nchw_chw(): - N = choice(BATCH_SIZES) - C = choice(CHANNELS) - H = choice(HEIGHTS) - W = choice(HEIGHTS) - - dilation = choice(DILATIONS) - stride = choice(STRIDES) - padding = 0 - - K = choice_topped(KERNELS, (min(H, W) + 2 * padding - 1) // dilation - 1) - - H_ = ((H + 2 * padding - dilation * (K - 1) - 1) // stride) + 1 - W_ = ((W + 2 * padding - dilation * (K - 1) - 1) // stride) + 1 - - return f"linalg.depthwise_conv_2d_nchw_chw {{dilations = dense<{dilation}> : tensor<2xi64>, strides = dense<{stride}> : tensor<2xi64>}} ins (%input, %filter: tensor<{N}x{C}x{H}x{W}xf32>, tensor<{C}x{K}x{K}xf32>) outs (%init: tensor<{N}x{C}x{H_}x{W_}xf32>) -> tensor<{N}x{C}x{H_}x{W_}xf32>" - - -def depthwise_conv_2d_nhwc_hwc(): - N = choice(BATCH_SIZES) - C = choice(CHANNELS) - H = choice(HEIGHTS) - W = choice(HEIGHTS) - - dilation = choice(DILATIONS) - stride = choice(STRIDES) - padding = 0 - - K = choice_topped(KERNELS, (min(H, W) + 2 * padding - 1) // dilation - 1) - - H_ = ((H + 2 * padding - dilation * (K - 1) - 1) // stride) + 1 - W_ = ((W + 2 * padding - dilation * (K - 1) - 1) // stride) + 1 - - return f"linalg.depthwise_conv_2d_nhwc_hwc {{dilations = dense<{dilation}> : tensor<2xi64>, strides = dense<{stride}> : tensor<2xi64>}} ins (%input, %filter: tensor<{N}x{H}x{W}x{C}xf32>, tensor<{K}x{K}x{C}xf32>) outs (%init: tensor<{N}x{H_}x{W_}x{C}xf32>) -> tensor<{N}x{H_}x{W_}x{C}xf32>" - - -def depthwise_conv_2d_nhwc_hwcm(): - N = choice(BATCH_SIZES) - C = choice(CHANNELS) - M = choice(CHANNELS) - H = choice(HEIGHTS) - W = choice(HEIGHTS) - - dilation = choice(DILATIONS) - stride = choice(STRIDES) - padding = 0 - - K = choice_topped(KERNELS, (min(H, W) + 2 * padding - 1) // dilation - 1) - - H_ = ((H + 2 * padding - dilation * (K - 1) - 1) // stride) + 1 - W_ = ((W + 2 * padding - dilation * (K - 1) - 1) // stride) + 1 - - return f"linalg.depthwise_conv_2d_nhwc_hwcm {{dilations = dense<{dilation}> : tensor<2xi64>, strides = dense<{stride}> : tensor<2xi64>}} ins (%input, %filter: tensor<{N}x{H}x{W}x{C}xf32>, tensor<{K}x{K}x{C}x{M}xf32>) outs (%init: tensor<{N}x{H_}x{W_}x{C}x{M}xf32>) -> tensor<{N}x{H_}x{W_}x{C}x{M}xf32>" - - -def depthwise_conv_3d_ncdhw_cdhw(): - N = choice(BATCH_SIZES) - C = choice(CHANNELS) - H = choice(HEIGHTS) - W = choice(HEIGHTS) - D = choice(HEIGHTS) - - dilation = choice(DILATIONS) - stride = choice(STRIDES) - padding = 0 - - K = choice_topped(KERNELS, (min(H, W, D) + 2 * padding - 1) // dilation - 1) - - H_ = ((H + 2 * padding - dilation * (K - 1) - 1) // stride) + 1 - W_ = ((W + 2 * padding - dilation * (K - 1) - 1) // stride) + 1 - D_ = ((D + 2 * padding - dilation * (K - 1) - 1) // stride) + 1 - - return f"linalg.depthwise_conv_3d_ncdhw_cdhw {{dilations = dense<{dilation}> : tensor<3xi64>, strides = dense<{stride}> : tensor<3xi64>}} ins (%input, %filter: tensor<{N}x{C}x{D}x{H}x{W}xf32>, tensor<{C}x{K}x{K}x{K}xf32>) outs (%init: tensor<{N}x{C}x{D_}x{H_}x{W_}xf32>) -> tensor<{N}x{C}x{D_}x{H_}x{W_}xf32>" - - -def depthwise_conv_3d_ndhwc_dhwc(): - N = choice(BATCH_SIZES) - C = choice(CHANNELS) - H = choice(HEIGHTS) - W = choice(HEIGHTS) - D = choice(HEIGHTS) - - dilation = choice(DILATIONS) - stride = choice(STRIDES) - padding = 0 - - K = choice_topped(KERNELS, (min(H, W, D) + 2 * padding - 1) // dilation - 1) - - H_ = ((H + 2 * padding - dilation * (K - 1) - 1) // stride) + 1 - W_ = ((W + 2 * padding - dilation * (K - 1) - 1) // stride) + 1 - D_ = ((D + 2 * padding - dilation * (K - 1) - 1) // stride) + 1 - - return f"linalg.depthwise_conv_3d_ndhwc_dhwc {{dilations = dense<{dilation}> : tensor<3xi64>, strides = dense<{stride}> : tensor<3xi64>}} ins (%input, %filter: tensor<{N}x{D}x{H}x{W}x{C}xf32>, tensor<{K}x{K}x{K}x{C}xf32>) outs (%init: tensor<{N}x{D_}x{H_}x{W_}x{C}xf32>) -> tensor<{N}x{D_}x{H_}x{W_}x{C}xf32>" - - -def depthwise_conv_3d_ndhwc_dhwcm(): - N = choice(BATCH_SIZES) - C = choice(CHANNELS) - M = choice(CHANNELS) - H = choice(HEIGHTS) - W = choice(HEIGHTS) - D = choice(HEIGHTS) - - dilation = choice(DILATIONS) - stride = choice(STRIDES) - padding = 0 - - K = choice_topped(KERNELS, (min(H, W, D) + 2 * padding - 1) // dilation - 1) - - H_ = ((H + 2 * padding - dilation * (K - 1) - 1) // stride) + 1 - W_ = ((W + 2 * padding - dilation * (K - 1) - 1) // stride) + 1 - D_ = ((D + 2 * padding - dilation * (K - 1) - 1) // stride) + 1 - - return f"linalg.depthwise_conv_3d_ndhwc_dhwcm {{dilations = dense<{dilation}> : tensor<3xi64>, strides = dense<{stride}> : tensor<3xi64>}} ins (%input, %filter: tensor<{N}x{D}x{H}x{W}x{C}xf32>, tensor<{K}x{K}x{K}x{C}x{M}xf32>) outs (%init: tensor<{N}x{D_}x{H_}x{W_}x{C}x{M}xf32>) -> tensor<{N}x{D_}x{H_}x{W_}x{C}x{M}xf32>" - - -def pooling_nchw_max(): - matmul_size = 10 ** 9 - - while matmul_size >= (10 ** 9): - N = choice(BATCH_SIZES) - C = choice(CHANNELS) - H = choice(HEIGHTS) - W = H - - dilation = choice(DILATIONS) - stride = 2 - - K = choice_topped(KERNELS, (min(H, W) - 1) // dilation - 1) - - H_ = (H - dilation * (K - 1) - 1) // stride + 1 - W_ = (W - dilation * (K - 1) - 1) // stride + 1 - - matmul_size = N * C * H_ * W_ * K * K - - bench_name = f"pooling_nchw_max_{N}_{C}_{H}_{W}_{K}_{H_}_{W_}" - - return f"linalg.pooling_nchw_max {{dilations = dense<{dilation}> : tensor<2xi64>, strides = dense<{stride}> : tensor<2xi64>}} ins (%input, %filter: tensor<{N}x{C}x{H}x{W}xf64>, tensor<{K}x{K}xf64>) outs (%init: tensor<{N}x{C}x{H_}x{W_}xf64>) -> tensor<{N}x{C}x{H_}x{W_}xf64>", bench_name - - -def pooling_nchw_sum(): - N = choice(BATCH_SIZES) - C = choice(CHANNELS) - H = choice(HEIGHTS) - W = choice(HEIGHTS) - - dilation = choice(DILATIONS) - stride = choice(STRIDES) - - K = choice_topped(KERNELS, (min(H, W) - 1) // dilation - 1) - - H_ = (H - dilation * (K - 1) - 1) // stride + 1 - W_ = (W - dilation * (K - 1) - 1) // stride + 1 - - return f"linalg.pooling_nchw_sum {{dilations = dense<{dilation}> : tensor<2xi64>, strides = dense<{stride}> : tensor<2xi64>}} ins (%input, %filter: tensor<{N}x{C}x{H}x{W}xf32>, tensor<{K}x{K}xf32>) outs (%init: tensor<{N}x{C}x{H_}x{W_}xf32>) -> tensor<{N}x{C}x{H_}x{W_}xf32>" - - -def pooling_ncw_max(): - N = choice(BATCH_SIZES) - C = choice(CHANNELS) - W = choice(HEIGHTS) - - dilation = choice(DILATIONS) - stride = choice(STRIDES) - - K = choice_topped(KERNELS, (W - 1) // dilation - 1) - - W_ = (W - dilation * (K - 1) - 1) // stride + 1 - - return f"linalg.pooling_ncw_max {{dilations = dense<{dilation}> : tensor<1xi64>, strides = dense<{stride}> : tensor<1xi64>}} ins (%input, %filter: tensor<{N}x{C}x{W}xf32>, tensor<{K}xf32>) outs (%init: tensor<{N}x{C}x{W_}xf32>) -> tensor<{N}x{C}x{W_}xf32>" - - -def pooling_ncw_sum(): - N = choice(BATCH_SIZES) - C = choice(CHANNELS) - W = choice(HEIGHTS) - - dilation = choice(DILATIONS) - stride = choice(STRIDES) - - K = choice_topped(KERNELS, (W - 1) // dilation - 1) - - W_ = (W - dilation * (K - 1) - 1) // stride + 1 - - return f"linalg.pooling_ncw_sum {{dilations = dense<{dilation}> : tensor<1xi64>, strides = dense<{stride}> : tensor<1xi64>}} ins (%input, %filter: tensor<{N}x{C}x{W}xf32>, tensor<{K}xf32>) outs (%init: tensor<{N}x{C}x{W_}xf32>) -> tensor<{N}x{C}x{W_}xf32>" - - -def pooling_ndhwc_max(): - N = choice(BATCH_SIZES) - C = choice(CHANNELS) - D = choice(HEIGHTS) - H = choice(HEIGHTS) - W = choice(HEIGHTS) - - dilation = choice(DILATIONS) - stride = choice(STRIDES) - - K = choice_topped(KERNELS, (min(H, W, D) - 1) // dilation - 1) - - D_ = (D - dilation * (K - 1) - 1) // stride + 1 - H_ = (H - dilation * (K - 1) - 1) // stride + 1 - W_ = (W - dilation * (K - 1) - 1) // stride + 1 - - return f"linalg.pooling_ndhwc_max {{dilations = dense<{dilation}> : tensor<3xi64>, strides = dense<{stride}> : tensor<3xi64>}} ins (%input, %filter: tensor<{N}x{D}x{H}x{W}x{C}xf32>, tensor<{K}x{K}x{K}xf32>) outs (%init: tensor<{N}x{D_}x{H_}x{W_}x{C}xf32>) -> tensor<{N}x{D_}x{H_}x{W_}x{C}xf32>" - - -def pooling_ndhwc_min(): - N = choice(BATCH_SIZES) - C = choice(CHANNELS) - D = choice(HEIGHTS) - H = choice(HEIGHTS) - W = choice(HEIGHTS) - - dilation = choice(DILATIONS) - stride = choice(STRIDES) - - K = choice_topped(KERNELS, (min(H, W, D) - 1) // dilation - 1) - - D_ = (D - dilation * (K - 1) - 1) // stride + 1 - H_ = (H - dilation * (K - 1) - 1) // stride + 1 - W_ = (W - dilation * (K - 1) - 1) // stride + 1 - - return f"linalg.pooling_ndhwc_min {{dilations = dense<{dilation}> : tensor<3xi64>, strides = dense<{stride}> : tensor<3xi64>}} ins (%input, %filter: tensor<{N}x{D}x{H}x{W}x{C}xf32>, tensor<{K}x{K}x{K}xf32>) outs (%init: tensor<{N}x{D_}x{H_}x{W_}x{C}xf32>) -> tensor<{N}x{D_}x{H_}x{W_}x{C}xf32>" - - -def pooling_ndhwc_sum(): - N = choice(BATCH_SIZES) - C = choice(CHANNELS) - D = choice(HEIGHTS) - H = choice(HEIGHTS) - W = choice(HEIGHTS) - - dilation = choice(DILATIONS) - stride = choice(STRIDES) - - K = choice_topped(KERNELS, (min(H, W, D) - 1) // dilation - 1) - - D_ = (D - dilation * (K - 1) - 1) // stride + 1 - H_ = (H - dilation * (K - 1) - 1) // stride + 1 - W_ = (W - dilation * (K - 1) - 1) // stride + 1 - - return f"linalg.pooling_ndhwc_sum {{dilations = dense<{dilation}> : tensor<3xi64>, strides = dense<{stride}> : tensor<3xi64>}} ins (%input, %filter: tensor<{N}x{D}x{H}x{W}x{C}xf32>, tensor<{K}x{K}x{K}xf32>) outs (%init: tensor<{N}x{D_}x{H_}x{W_}x{C}xf32>) -> tensor<{N}x{D_}x{H_}x{W_}x{C}xf32>" - - -def pooling_nhwc_max(): - N = choice(BATCH_SIZES) - C = choice(CHANNELS) - H = choice(HEIGHTS) - W = choice(HEIGHTS) - - dilation = choice(DILATIONS) - stride = choice(STRIDES) - - K = choice_topped(KERNELS, (min(H, W) - 1) // dilation - 1) - - H_ = (H - dilation * (K - 1) - 1) // stride + 1 - W_ = (W - dilation * (K - 1) - 1) // stride + 1 - - return f"linalg.pooling_nhwc_max {{dilations = dense<{dilation}> : tensor<2xi64>, strides = dense<{stride}> : tensor<2xi64>}} ins (%input, %filter: tensor<{N}x{H}x{W}x{C}xf32>, tensor<{K}x{K}xf32>) outs (%init: tensor<{N}x{H_}x{W_}x{C}xf32>) -> tensor<{N}x{H_}x{W_}x{C}xf32>" - - -def pooling_nhwc_min(): - N = choice(BATCH_SIZES) - C = choice(CHANNELS) - H = choice(HEIGHTS) - W = choice(HEIGHTS) - - dilation = choice(DILATIONS) - stride = choice(STRIDES) - - K = choice_topped(KERNELS, (min(H, W) - 1) // dilation - 1) - - H_ = (H - dilation * (K - 1) - 1) // stride + 1 - W_ = (W - dilation * (K - 1) - 1) // stride + 1 - - return f"linalg.pooling_nhwc_min {{dilations = dense<{dilation}> : tensor<2xi64>, strides = dense<{stride}> : tensor<2xi64>}} ins (%input, %filter: tensor<{N}x{H}x{W}x{C}xf32>, tensor<{K}x{K}xf32>) outs (%init: tensor<{N}x{H_}x{W_}x{C}xf32>) -> tensor<{N}x{H_}x{W_}x{C}xf32>" - - -def pooling_nhwc_sum(): - N = choice(BATCH_SIZES) - C = choice(CHANNELS) - H = choice(HEIGHTS) - W = choice(HEIGHTS) - - dilation = choice(DILATIONS) - stride = choice(STRIDES) - - K = choice_topped(KERNELS, (min(H, W) - 1) // dilation - 1) - - H_ = (H - dilation * (K - 1) - 1) // stride + 1 - W_ = (W - dilation * (K - 1) - 1) // stride + 1 - - return f"linalg.pooling_nhwc_sum {{dilations = dense<{dilation}> : tensor<2xi64>, strides = dense<{stride}> : tensor<2xi64>}} ins (%input, %filter: tensor<{N}x{H}x{W}x{C}xf32>, tensor<{K}x{K}xf32>) outs (%init: tensor<{N}x{H_}x{W_}x{C}xf32>) -> tensor<{N}x{H_}x{W_}x{C}xf32>" - - -def pooling_nwc_max(): - N = choice(BATCH_SIZES) - C = choice(CHANNELS) - W = choice(HEIGHTS) - - dilation = choice(DILATIONS) - stride = choice(STRIDES) - - K = choice_topped(KERNELS, (W - 1) // dilation - 1) - - W_ = (W - dilation * (K - 1) - 1) // stride + 1 - - return f"linalg.pooling_nwc_max {{dilations = dense<{dilation}> : tensor<1xi64>, strides = dense<{stride}> : tensor<1xi64>}} ins (%input, %filter: tensor<{N}x{W}x{C}xf32>, tensor<{K}xf32>) outs (%init: tensor<{N}x{W_}x{C}xf32>) -> tensor<{N}x{W_}x{C}xf32>" - - -def pooling_nwc_sum(): - N = choice(BATCH_SIZES) - C = choice(CHANNELS) - W = choice(HEIGHTS) - - dilation = choice(DILATIONS) - stride = choice(STRIDES) - - K = choice_topped(KERNELS, (W - 1) // dilation - 1) - - W_ = (W - dilation * (K - 1) - 1) // stride + 1 - - return f"linalg.pooling_nwc_sum {{dilations = dense<{dilation}> : tensor<1xi64>, strides = dense<{stride}> : tensor<1xi64>}} ins (%input, %filter: tensor<{N}x{W}x{C}xf32>, tensor<{K}xf32>) outs (%init: tensor<{N}x{W_}x{C}xf32>) -> tensor<{N}x{W_}x{C}xf32>" - - -def relu(): - - matmul_size = 10 ** 9 - while matmul_size >= (10 ** 9): - if random() < 0.25: - - N = choice(BATCH_SIZES) - S = choice(SIZES) - SHAPE = f"{N}x{S}" - - relu_operation = ( - f"linalg.generic {{indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = [\"parallel\", \"parallel\"]}} ins(%arg0: memref<{SHAPE}xf64>) outs(%arg1: memref<{SHAPE}xf64>) {{\n" - f" ^bb0(%in: f64, %out: f64):\n" - f" %cst_1 = arith.constant 0.0: f64\n" - f" %46 = arith.cmpf ugt, %in, %cst_1 : f64\n" - f" %47 = arith.select %46, %in, %cst_1 : f64\n" - f" linalg.yield %47 : f64\n" - f"}}\n" - ) - - matmul_size = N * S - - else: - - N = choice(BATCH_SIZES) - C = choice(CHANNELS) - W = choice(HEIGHTS) - - SHAPE = f"{N}x{C}x{W}x{W}" - - relu_operation = ( - f"linalg.generic {{indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = [\"parallel\", \"parallel\", \"parallel\", \"parallel\"]}} ins(%arg0 : memref<{SHAPE}xf64>) outs(%arg1 : memref<{SHAPE}xf64>) {{\n" - f" ^bb0(%in: f64, %out: f64):\n" - f" %cst_1 = arith.constant 0.0 : f64\n" - f" %90 = arith.cmpf ugt, %in, %cst_1 : f64\n" - f" %91 = arith.select %90, %in, %cst_1 : f64\n" - f" linalg.yield %91 : f64\n" - f"}}\n" - ) - - matmul_size = N * C * W * W - - shape_str = SHAPE.replace('x', '_') - bench_name = f"relu_{shape_str}" - - return relu_operation, bench_name - - -LINALG_OPERATION_GENERATORS = { - "add": add, - "add_nn": add_nn, - "sub": sub, - "max": max, - "mul": mul, - "abs": abs, - "ceil": ceil, - "copy": copy_, - "fill": fill, - "transpose": transpose, - "batch_matmul": batch_matmul, - "batch_matmul_transpose_a": batch_matmul_transpose_a, - "batch_matmul_transpose_b": batch_matmul_transpose_b, - "batch_reduce_matmul": batch_reduce_matmul, - "matmul": matmul, - "matmul_transpose_a": matmul_transpose_a, - "matmul_transpose_b": matmul_transpose_b, - "conv_1d": conv_1d, - "conv_1d_ncw_fcw": conv_1d_ncw_fcw, - "conv_1d_nwc_wcf": conv_1d_nwc_wcf, - "conv_2d": conv_2d, - "conv_2d_nchw_fchw": conv_2d_nchw_fchw, - "conv_2d_ngchw_fgchw": conv_2d_ngchw_fgchw, - "conv_2d_nhwc_fhwc": conv_2d_nhwc_fhwc, - "conv_2d_nhwc_hwcf": conv_2d_nhwc_hwcf, - "conv_3d": conv_3d, - "conv_3d_ncdhw_fcdhw": conv_3d_ncdhw_fcdhw, - "depthwise_conv_1d_ncw_cw": depthwise_conv_1d_ncw_cw, - "depthwise_conv_1d_nwc_wc": depthwise_conv_1d_nwc_wc, - "depthwise_conv_1d_nwc_wcm": depthwise_conv_1d_nwc_wcm, - "depthwise_conv_2d_nchw_chw": depthwise_conv_2d_nchw_chw, - "depthwise_conv_2d_nhwc_hwc": depthwise_conv_2d_nhwc_hwc, - "depthwise_conv_2d_nhwc_hwcm": depthwise_conv_2d_nhwc_hwcm, - "depthwise_conv_3d_ncdhw_cdhw": depthwise_conv_3d_ncdhw_cdhw, - "depthwise_conv_3d_ndhwc_dhwc": depthwise_conv_3d_ndhwc_dhwc, - "depthwise_conv_3d_ndhwc_dhwcm": depthwise_conv_3d_ndhwc_dhwcm, - "pooling_nchw_max": pooling_nchw_max, - "pooling_nchw_sum": pooling_nchw_sum, - "pooling_ncw_max": pooling_ncw_max, - "pooling_ncw_sum": pooling_ncw_sum, - "pooling_ndhwc_max": pooling_ndhwc_max, - "pooling_ndhwc_min": pooling_ndhwc_min, - "pooling_ndhwc_sum": pooling_ndhwc_sum, - "pooling_nhwc_max": pooling_nhwc_max, - "pooling_nhwc_min": pooling_nhwc_min, - "pooling_nhwc_sum": pooling_nhwc_sum, - "pooling_nwc_max": pooling_nwc_max, - "pooling_nwc_sum": pooling_nwc_sum, - "relu": relu, -} - - -def extract_args(operation) -> tuple[list[str], list[str]]: - ins_outs_pattern = r"(?:ins|outs)\s*\(([^())]+)\)" - fields = re.findall(ins_outs_pattern, operation) - - args, shapes = [], [] - for field in fields: - args_field, shapes_field = field.split(':') - args += args_field.split(',') - shapes += shapes_field.split(',') - - args = [arg.strip()[1:] for arg in args] - shapes = [shape.strip() for shape in shapes] - - return remove_duplicate_args(args, shapes) - - -def extract_main_args(code) -> tuple[list[str], list[str]]: - main_pattern = r"func.func @main\(([^)]+)\)" - main_params = re.search(main_pattern, code).group(1) - main_params = main_params.split(',') - main_args = [arg.split(':')[0].strip() for arg in main_params] - main_shapes = [arg.split(':')[1].strip() for arg in main_params] - - return remove_duplicate_args(main_args, main_shapes) - - -def transform_img2col(code: str): - code = code.strip() - - transform_dilaect_code = ( - "module attributes {transform.with_named_sequence} {\n" - " transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {\n" - " %op_operation = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op\n" - " %a, %b = transform.structured.convert_conv2d_to_img2col %op_operation : (!transform.any_op) -> (!transform.any_op, !transform.any_op)\n" - " transform.yield\n" - " }\n" - "}\n" - ) - - code = code + transform_dilaect_code - - with open(tmp_file, "w") as file: - file.write(code) - - result = os.popen( - f"{os.getenv('LLVM_BUILD_PATH')}/bin/mlir-opt {tmp_file} -transform-interpreter -one-shot-bufferize='bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map' -finalizing-bufferize -canonicalize -test-transform-dialect-erase-schedule", - ).read() - - result = result.replace("module {\n", "", 1) - result = ''.join(result.rsplit('\n}\n', 1)) - result = re.sub(r"module attributes \{transform.with_named_sequence\} \{\s+\}", "", result) - - return result - - -if __name__ == '__main__': - - parser = argparse.ArgumentParser(description='Process an input file and save to an output file.') - parser.add_argument('--input_file', type=str, help='The input file to be processed.') - - args = parser.parse_args() - - with open(args.input_file, 'r') as file: - config = yaml.safe_load(file) - - # Set the shapes of the operations - BATCH_SIZES.extend(config['SHAPES']['BATCH_SIZES']) # Used by all the operations - HEIGHTS.extend(config['SHAPES']['HEIGHTS']) # Used by operations on images - CHANNELS.extend(config['SHAPES']['CHANNELS']) # Used by operations on images - KERNELS.extend(config['SHAPES']['KERNELS']) # Used by operations on images - DILATIONS.extend(config['SHAPES']['DILATIONS']) # Used by operations on images - STRIDES.extend(config['SHAPES']['STRIDES']) # Used by operations on images - SIZES.extend(config['SHAPES']['SIZES']) # Used on other operations like matmul, add, etc... - - operations_config: dict[str, tuple[Callable[[], tuple[str, str]], int]] = { - operation_name: (LINALG_OPERATION_GENERATORS[operation_name], amount) for operation_name, amount in config['OPERATIONS'].items() if amount > 0 - } - - print("Expected Bench Count:", sum(amount for _, (_, amount) in operations_config.items())) - - with open('execution_times.json', 'r') as file: - execution_times: dict[str, int] = json.load(file) - - for operation_name, (generator, amount) in operations_config.items(): - # Iterate the specified number of times ('amount') for the current operation - tqdm_iter = trange(amount, desc=operation_name, file=sys.stdout) - for _ in tqdm_iter: - generated = False - while not generated: - raw_operation, bench_name = generator() # Generate the raw operation using the provided generator function - bench_output = f'../{bench_name}.mlir' - if bench_name in execution_times: - print('Already exists, Bench:', bench_name, file=sys.stderr) - continue - - args, shapes = extract_args(raw_operation) - is_memref = 'memref' in shapes[-1] - is_conv = 'conv' in bench_name - main_args = [f"%{arg}" for arg in args] - main_params = [f"%{arg}: {shape}" for arg, shape in zip(args, shapes)] - out_shape = shapes[-1] - - if not is_memref and not is_conv: - code = ( - f'func.func private @nanoTime() -> i64 attributes {{ llvm.emit_c_interface }}\n' - f'func.func @compute({", ".join(main_params)}) -> ({out_shape}, i64) attributes {{ llvm.emit_c_interface }} {{\n' - f' %t0 = func.call @nanoTime() : () -> i64\n' - f' %0 = {raw_operation}\n' - f' %t1 = func.call @nanoTime() : () -> i64\n' - f' %t2 = arith.subi %t1, %t0 : i64\n' - f' return %0, %t2 : {out_shape}, i64\n' - f'}}\n' - f'func.func @main({", ".join(main_params)}) -> i64 attributes {{ llvm.emit_c_interface }} {{\n' - f' %0, %1 = func.call @compute({", ".join([f"%{arg}" for arg in args])}) : ({", ".join(shapes)}) -> ({out_shape}, i64)\n' - f' return %1 : i64\n' - f'}}\n' - ) - else: - code = ( - f'func.func private @nanoTime() -> i64 attributes {{ llvm.emit_c_interface }}\n' - f'func.func @main({", ".join(main_params)}) -> i64 attributes {{ llvm.emit_c_interface }} {{\n' - f' %t0 = func.call @nanoTime() : () -> i64\n' - f" {'' if is_memref else '%0 = '}{raw_operation}\n" - f' %t1 = func.call @nanoTime() : () -> i64\n' - f' %t2 = arith.subi %t1, %t0 : i64\n' - f' return %t2 : i64\n' - f'}}\n' - ) - - if is_conv: - code = transform_img2col(code) - args, shapes = extract_main_args(code) - assert all('memref' in shape for shape in shapes) - - with Context(): - module = Module.parse(code) - pm = PassManager.parse(tensor_pass_pipeline) - pm.run(module.operation) - execution_engine = ExecutionEngine( - module, - shared_libs=os.getenv("MLIR_SHARED_LIBS", "").split(","), - ) - - args_count = len(args) - inputs: dict[str, np.ndarray] = {} - for i, (arg, shape) in enumerate(zip(args, shapes)): - *np_shape, dtype = shape.replace('memref<', '').replace('tensor<', '').replace('>', '').split('x') - assert dtype == 'f64', f'expects f64, got {dtype}' - np_shape = list(map(int, np_shape)) - inputs[arg] = np.zeros(np_shape) - - c_args = [] - for arg_name in args: - c_args.append(ctypes.pointer(ctypes.pointer( - get_ranked_memref_descriptor(inputs[arg_name]) - ))) - delta_arg = (ctypes.c_int64 * 1)(0) - c_args.append(delta_arg) - - try: - with ThreadingTimeout(5, False) as to_ctx: - execution_engine.invoke("main", *c_args) - execution_engine.invoke("main", *c_args) - except Exception as e: - if e is TimeoutException: - print('Timeout, Bench:', bench_name, file=sys.stderr) - else: - print(f"Failed, Bench: {bench_name}, error: {e}", file=sys.stderr) - continue - - exec_time = delta_arg[0] - if exec_time >= (1 * 10**9): - continue - - with open(bench_output, 'w') as file: - file.write(code) - - execution_times[bench_name] = exec_time - with open('execution_times.json', 'w') as file: - json.dump(execution_times, file, indent=4) - - tqdm_iter.set_postfix({'time': exec_time}) - - generated = True diff --git a/data/nn/gen/example.yaml b/data/nn/gen/example.yaml deleted file mode 100644 index cbaee10..0000000 --- a/data/nn/gen/example.yaml +++ /dev/null @@ -1,63 +0,0 @@ -SHAPES: - BATCH_SIZES: [128, 256] - SIZES: [768, 1024, 256, 1536, 2048, 512, 128, 3072] - HEIGHTS: [14, 28, 56, 7, 112, 15, 120, 150, 130, 240, 224, 228] - CHANNELS: [32, 256, 128, 192, 512, 64, 96, 48, 288, 240, 384] - KERNELS: [1, 3, 7] - DILATIONS: [1] - STRIDES: [1, 2] - -OPERATIONS: - # add: 248 - add: 0 - sub: 0 - max: 0 - mul: 0 - abs: 0 - ceil: 0 - copy: 0 - fill: 0 - transpose: 0 - batch_matmul: 0 - batch_matmul_transpose_a: 0 - batch_matmul_transpose_b: 0 - batch_reduce_matmul: 0 - # matmul: 175 - matmul: 0 - matmul_transpose_a: 0 - matmul_transpose_b: 0 - conv_1d: 0 - conv_1d_ncw_fcw: 0 - conv_1d_nwc_wcf: 0 - conv_2d: 0 - # conv_2d_nchw_fchw: 72 - conv_2d_nchw_fchw: 0 - conv_2d_ngchw_fgchw: 0 - conv_2d_nhwc_fhwc: 0 - conv_2d_nhwc_hwcf: 0 - conv_3d: 0 - conv_3d_ncdhw_fcdhw: 0 - depthwise_conv_1d_ncw_cw: 0 - depthwise_conv_1d_nwc_wc: 0 - depthwise_conv_1d_nwc_wcm: 0 - depthwise_conv_2d_nchw_chw: 0 - depthwise_conv_2d_nhwc_hwc: 0 - depthwise_conv_2d_nhwc_hwcm: 0 - depthwise_conv_3d_ncdhw_cdhw: 0 - depthwise_conv_3d_ndhwc_dhwc: 0 - depthwise_conv_3d_ndhwc_dhwcm: 0 - # pooling_nchw_max: 200 - pooling_nchw_max: 0 - pooling_nchw_sum: 0 - pooling_ncw_max: 0 - pooling_ncw_sum: 0 - pooling_ndhwc_max: 0 - pooling_ndhwc_min: 0 - pooling_ndhwc_sum: 0 - pooling_nhwc_max: 0 - pooling_nhwc_min: 0 - pooling_nhwc_sum: 0 - pooling_nwc_max: 0 - pooling_nwc_sum: 0 - # relu: 133 - relu: 100 \ No newline at end of file diff --git a/data/polybench/.gitignore b/data/polybench/.gitignore deleted file mode 100644 index 3b34d83..0000000 --- a/data/polybench/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -* -!gen/ -!.gitignore \ No newline at end of file diff --git a/data/polybench/gen/.gitignore b/data/polybench/gen/.gitignore deleted file mode 100644 index 74e5b1d..0000000 --- a/data/polybench/gen/.gitignore +++ /dev/null @@ -1 +0,0 @@ -!* \ No newline at end of file diff --git a/data/polybench/gen/2mm.mlir.bench b/data/polybench/gen/2mm.mlir.bench deleted file mode 100644 index a6707b5..0000000 --- a/data/polybench/gen/2mm.mlir.bench +++ /dev/null @@ -1,45 +0,0 @@ -func.func private @nanoTime() -> i64 attributes { llvm.emit_c_interface } -func.func @main(%A: memref, %B: memref, %C: memref, %D: memref, %alpha: memref, %beta: memref, %output: memref) -> i64 attributes { llvm.emit_c_interface } { - %alpha_v = memref.load %alpha[] : memref - %beta_v = memref.load %beta[] : memref - %t0 = func.call @nanoTime() : () -> i64 - %tmp = memref.alloc() : memref - linalg.generic { - indexing_maps = [affine_map<(d0, d1)->(d0, d1)>], - iterator_types = ["parallel", "parallel"] - } outs(%tmp: memref) { - ^bb0(%0: f64): - %1 = arith.constant 0.0 : f64 - linalg.yield %1 : f64 - } - linalg.generic { - indexing_maps = [affine_map<(d0, d1, d2)->(d0, d2)>, affine_map<(d0, d1, d2)->(d2, d1)>, affine_map<(d0, d1, d2)->(d0, d1)>], - iterator_types = ["parallel", "parallel", "reduction"] - } ins(%A, %B: memref, memref) outs(%tmp: memref) { - ^bb0(%2: f64, %3: f64, %4: f64): - %5 = arith.mulf %2, %3 fastmath : f64 - %6 = arith.mulf %alpha_v, %5 fastmath : f64 - %7 = arith.addf %4, %6 fastmath : f64 - linalg.yield %7 : f64 - } - linalg.generic { - indexing_maps = [affine_map<(d0, d1)->(d0, d1)>, affine_map<(d0, d1)->(d0, d1)>], - iterator_types = ["parallel", "parallel"] - } ins(%D: memref) outs(%output: memref) { - ^bb0(%8: f64, %9: f64): - %10 = arith.mulf %beta_v, %8 fastmath : f64 - linalg.yield %10 : f64 - } - linalg.generic { - indexing_maps = [affine_map<(d0, d1, d2)->(d0, d2)>, affine_map<(d0, d1, d2)->(d2, d1)>, affine_map<(d0, d1, d2)->(d0, d1)>], - iterator_types = ["parallel", "parallel", "reduction"] - } ins(%tmp, %C: memref, memref) outs(%output: memref) { - ^bb0(%11: f64, %12: f64, %13: f64): - %14 = arith.mulf %11, %12 fastmath : f64 - %15 = arith.addf %13, %14 fastmath : f64 - linalg.yield %15 : f64 - } - %t1 = func.call @nanoTime() : () -> i64 - %t2 = arith.subi %t1, %t0 : i64 - return %t2 : i64 -} diff --git a/data/polybench/gen/2mm_gen.py b/data/polybench/gen/2mm_gen.py deleted file mode 100644 index 9a1d78e..0000000 --- a/data/polybench/gen/2mm_gen.py +++ /dev/null @@ -1,120 +0,0 @@ -import numpy as np -from mlir.ir import Context, Module -from mlir.execution_engine import ExecutionEngine, ctypes -from mlir.runtime import get_ranked_memref_descriptor -from mlir.passmanager import PassManager -import os -import sys -import json - -base_name = "2mm" -bench_file = f"{base_name}.mlir.bench" -order = ['A', 'B', 'C', 'D', 'alpha', 'beta', 'output'] - -pass_pipeline = """builtin.module( - loop-invariant-code-motion, - canonicalize, - convert-vector-to-scf, - convert-linalg-to-loops, - buffer-deallocation-pipeline, - scf-forall-to-parallel, - convert-scf-to-openmp, - expand-strided-metadata, - finalize-memref-to-llvm, - convert-scf-to-cf, - lower-affine, - - convert-openmp-to-llvm, - convert-vector-to-llvm, - convert-math-to-llvm, - convert-func-to-llvm, - convert-index-to-llvm, - convert-arith-to-llvm, - convert-cf-to-llvm, - - reconcile-unrealized-casts, - canonicalize, - cse -)""" - -with open(bench_file, "r") as f: - base_code = f.read() - -execution_times = {} - -for i in range(5, 13): - MATRIX_SIZE = 2 ** i - bench_name = f"{base_name}_{MATRIX_SIZE}" - bench_output = f"../{bench_name}.mlir" - - params = { - "NI": MATRIX_SIZE, - "NJ": MATRIX_SIZE, - "NK": MATRIX_SIZE, - "NL": MATRIX_SIZE, - } - - inputs = { - 'A': np.random.rand(params['NI'], params['NK']) * 100, - 'B': np.random.rand(params['NK'], params['NJ']) * 100, - 'C': np.random.rand(params['NJ'], params['NL']) * 100, - 'D': np.random.rand(params['NI'], params['NL']) * 100, - 'alpha': np.random.rand(1), - 'beta': np.random.rand(1), - 'output': np.zeros((params['NI'], params['NL'])), - } - expected = inputs['alpha'] * inputs['A'] @ inputs['B'] @ inputs['C'] + inputs['beta'] * inputs['D'] - np.savez(f"{bench_output}.npz", **inputs) - - code = base_code - for key, value in params.items(): - code = code.replace(key, str(value)) - - with Context(): - module = Module.parse(code) - pm = PassManager.parse(pass_pipeline) - pm.run(module.operation) - execution_engine = ExecutionEngine( - module, - shared_libs=os.getenv("MLIR_SHARED_LIBS", "").split(","), - ) - - args = [] - for arg_name in order: - args.append(ctypes.pointer(ctypes.pointer( - get_ranked_memref_descriptor(inputs[arg_name]) - ))) - - delta_arg = (ctypes.c_int64 * 1)(0) - args.append(delta_arg) - - try: - execution_engine.invoke("main", *args) - execution_engine.invoke("main", *args) - except Exception as e: - print("Benchmark failed:", bench_name, e, file=sys.stderr) - os.remove(f"{bench_output}.npz") - continue - - exec_time = delta_arg[0] - if exec_time >= (1 * 10**9): - os.remove(f"{bench_output}.npz") - break - - actual = inputs[order[-1]] - assertion = np.allclose(actual, expected) - if not assertion: - print("Assertion failed:", bench_name, file=sys.stderr) - os.remove(f"{bench_output}.npz") - continue - - with open(bench_output, "w") as f: - f.write(code) - np.save(f"{bench_output}.npy", expected) - execution_times[bench_name] = exec_time - -with open('../execution_times.json', 'r') as f: - data: dict = json.load(f) -data.update(execution_times) -with open('../execution_times.json', 'w') as f: - json.dump(data, f, indent=4) diff --git a/data/polybench/gen/3mm.mlir.bench b/data/polybench/gen/3mm.mlir.bench deleted file mode 100644 index 9b21dfd..0000000 --- a/data/polybench/gen/3mm.mlir.bench +++ /dev/null @@ -1,56 +0,0 @@ -func.func private @nanoTime() -> i64 attributes { llvm.emit_c_interface } -func.func @main(%A: memref, %B: memref, %C: memref, %D: memref, %E: memref, %F: memref, %output: memref) -> i64 attributes { llvm.emit_c_interface } { - %t0 = func.call @nanoTime() : () -> i64 - %c0.0_f64 = arith.constant 0.0 : f64 - linalg.generic { - indexing_maps = [affine_map<(d0, d1)->(d0, d1)>], - iterator_types = ["parallel", "parallel"] - } outs(%E: memref) { - ^bb0(%arg0: f64): - linalg.yield %c0.0_f64 : f64 - } - linalg.generic { - indexing_maps = [affine_map<(d0, d1, d2)->(d0, d2)>, affine_map<(d0, d1, d2)->(d2, d1)>, affine_map<(d0, d1, d2)->(d0, d1)>], - iterator_types = ["parallel", "parallel", "reduction"] - } ins(%A, %B: memref, memref) outs(%E: memref) { - ^bb0(%arg0: f64, %arg1: f64, %arg2: f64): - %0 = arith.mulf %arg0, %arg1 fastmath : f64 - %1 = arith.addf %arg2, %0 fastmath : f64 - linalg.yield %1 : f64 - } - linalg.generic { - indexing_maps = [affine_map<(d0, d1)->(d0, d1)>], - iterator_types = ["parallel", "parallel"] - } outs(%F: memref) { - ^bb0(%arg0: f64): - linalg.yield %c0.0_f64 : f64 - } - linalg.generic { - indexing_maps = [affine_map<(d0, d1, d2)->(d0, d2)>, affine_map<(d0, d1, d2)->(d2, d1)>, affine_map<(d0, d1, d2)->(d0, d1)>], - iterator_types = ["parallel", "parallel", "reduction"] - } ins(%C, %D: memref, memref) outs(%F: memref) { - ^bb0(%arg0: f64, %arg1: f64, %arg2: f64): - %0 = arith.mulf %arg0, %arg1 fastmath : f64 - %1 = arith.addf %arg2, %0 fastmath : f64 - linalg.yield %1 : f64 - } - linalg.generic { - indexing_maps = [affine_map<(d0, d1)->(d0, d1)>], - iterator_types = ["parallel", "parallel"] - } outs(%output: memref) { - ^bb0(%arg0: f64): - linalg.yield %c0.0_f64 : f64 - } - linalg.generic { - indexing_maps = [affine_map<(d0, d1, d2)->(d0, d2)>, affine_map<(d0, d1, d2)->(d2, d1)>, affine_map<(d0, d1, d2)->(d0, d1)>], - iterator_types = ["parallel", "parallel", "reduction"] - } ins(%E, %F: memref, memref) outs(%output: memref) { - ^bb0(%arg0: f64, %arg1: f64, %arg2: f64): - %0 = arith.mulf %arg0, %arg1 fastmath : f64 - %1 = arith.addf %arg2, %0 fastmath : f64 - linalg.yield %1 : f64 - } - %t1 = func.call @nanoTime() : () -> i64 - %t2 = arith.subi %t1, %t0 : i64 - return %t2 : i64 -} diff --git a/data/polybench/gen/3mm_gen.py b/data/polybench/gen/3mm_gen.py deleted file mode 100644 index 97aa034..0000000 --- a/data/polybench/gen/3mm_gen.py +++ /dev/null @@ -1,121 +0,0 @@ -import numpy as np -from mlir.ir import Context, Module -from mlir.execution_engine import ExecutionEngine, ctypes -from mlir.runtime import get_ranked_memref_descriptor -from mlir.passmanager import PassManager -import os -import sys -import json - -base_name = "3mm" -bench_file = f"{base_name}.mlir.bench" -order = ['A', 'B', 'C', 'D', 'E', 'F', 'output'] - -pass_pipeline = """builtin.module( - loop-invariant-code-motion, - canonicalize, - convert-vector-to-scf, - convert-linalg-to-loops, - buffer-deallocation-pipeline, - scf-forall-to-parallel, - convert-scf-to-openmp, - expand-strided-metadata, - finalize-memref-to-llvm, - convert-scf-to-cf, - lower-affine, - - convert-openmp-to-llvm, - convert-vector-to-llvm, - convert-math-to-llvm, - convert-func-to-llvm, - convert-index-to-llvm, - convert-arith-to-llvm, - convert-cf-to-llvm, - - reconcile-unrealized-casts, - canonicalize, - cse -)""" - -with open(bench_file, "r") as f: - base_code = f.read() - -execution_times = {} - -for i in range(5, 13): - MATRIX_SIZE = 2 ** i - bench_name = f"{base_name}_{MATRIX_SIZE}" - bench_output = f"../{bench_name}.mlir" - - params = { - "NI": MATRIX_SIZE, - "NJ": MATRIX_SIZE, - "NK": MATRIX_SIZE, - "NL": MATRIX_SIZE, - "NM": MATRIX_SIZE, - } - - inputs = { - 'A': np.random.rand(params['NI'], params['NK']) * 100, - 'B': np.random.rand(params['NK'], params['NJ']) * 100, - 'C': np.random.rand(params['NJ'], params['NM']) * 100, - 'D': np.random.rand(params['NM'], params['NL']) * 100, - 'E': np.zeros((params['NI'], params['NJ'])), - 'F': np.zeros((params['NJ'], params['NL'])), - 'output': np.zeros((params['NI'], params['NL'])), - } - expected = (inputs['A'] @ inputs['B']) @ (inputs['C'] @ inputs['D']) - np.savez(f"{bench_output}.npz", **inputs) - - code = base_code - for key, value in params.items(): - code = code.replace(key, str(value)) - - with Context(): - module = Module.parse(code) - pm = PassManager.parse(pass_pipeline) - pm.run(module.operation) - execution_engine = ExecutionEngine( - module, - shared_libs=os.getenv("MLIR_SHARED_LIBS", "").split(","), - ) - - args = [] - for arg_name in order: - args.append(ctypes.pointer(ctypes.pointer( - get_ranked_memref_descriptor(inputs[arg_name]) - ))) - - delta_arg = (ctypes.c_int64 * 1)(0) - args.append(delta_arg) - - try: - execution_engine.invoke("main", *args) - execution_engine.invoke("main", *args) - except Exception as e: - print("Benchmark failed:", bench_name, e, file=sys.stderr) - os.remove(f"{bench_output}.npz") - continue - - exec_time = delta_arg[0] - if exec_time >= (1 * 10**9): - os.remove(f"{bench_output}.npz") - break - - actual = inputs[order[-1]] - assertion = np.allclose(actual, expected) - if not assertion: - print("Assertion failed:", bench_name, file=sys.stderr) - os.remove(f"{bench_output}.npz") - continue - - with open(bench_output, "w") as f: - f.write(code) - np.save(f"{bench_output}.npy", expected) - execution_times[bench_name] = exec_time - -with open('../execution_times.json', 'r') as f: - data: dict = json.load(f) -data.update(execution_times) -with open('../execution_times.json', 'w') as f: - json.dump(data, f, indent=4) diff --git a/data/polybench/gen/fdtd.mlir.bench b/data/polybench/gen/fdtd.mlir.bench deleted file mode 100644 index 965adee..0000000 --- a/data/polybench/gen/fdtd.mlir.bench +++ /dev/null @@ -1,68 +0,0 @@ -func.func private @nanoTime() -> i64 attributes { llvm.emit_c_interface } -func.func @main(%EX: memref, %EY: memref, %HZ: memref) -> i64 attributes { llvm.emit_c_interface } { - %t0 = func.call @nanoTime() : () -> i64 - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c0.5_f64 = arith.constant 0.5 : f64 - %c0.7_f64 = arith.constant 0.7 : f64 - %cTMAX = arith.constant TMAX : index - scf.for %0 = %c0 to %cTMAX step %c1 { - %EY_i0 = memref.subview %EY[0, 0][1, NY0][1, 1] : memref to memref<1xNY0xf64> - %EY_j = memref.collapse_shape %EY_i0 [[0, 1]] : memref<1xNY0xf64> into memref - linalg.generic { - indexing_maps = [affine_map<(d0)->(d0)>], - iterator_types = ["parallel"] - } outs (%EY_j: memref) { - ^bb0(%arg0: f64): - %1 = arith.index_cast %0 : index to i64 - %2 = arith.sitofp %1 : i64 to f64 - linalg.yield %2 : f64 - } - %HZ_i_1 = memref.subview %HZ[0, 0][NX1, NY0][1, 1] : memref to memref - %HZ_io_1 = memref.subview %HZ[1, 0][NX1, NY0][1, 1] : memref to memref> - %EY_io_1 = memref.subview %EY[1, 0][NX1, NY0][1, 1] : memref to memref> - linalg.generic { - indexing_maps = [affine_map<(d0, d1)->(d0, d1)>, affine_map<(d0, d1)->(d0, d1)>, affine_map<(d0, d1)->(d0, d1)>], - iterator_types = ["parallel", "parallel"] - } ins(%HZ_i_1, %HZ_io_1: memref, memref>) outs(%EY_io_1: memref>) { - ^bb0(%arg0: f64, %arg1: f64, %arg2: f64): - %1 = arith.subf %arg1, %arg0 fastmath : f64 - %2 = arith.mulf %c0.5_f64, %1 fastmath : f64 - %3 = arith.subf %arg2, %2 fastmath : f64 - linalg.yield %3 : f64 - } - %HZ_j_1 = memref.subview %HZ[0, 0][NX0, NY1][1, 1] : memref to memref> - %HZ_jo_1 = memref.subview %HZ[0, 1][NX0, NY1][1, 1] : memref to memref> - %EX_jo_1 = memref.subview %EX[0, 1][NX0, NY1][1, 1] : memref to memref> - linalg.generic { - indexing_maps = [affine_map<(d0, d1)->(d0, d1)>, affine_map<(d0, d1)->(d0, d1)>, affine_map<(d0, d1)->(d0, d1)>], - iterator_types = ["parallel", "parallel"] - } ins(%HZ_j_1, %HZ_jo_1: memref>, memref>) outs(%EX_jo_1: memref>) { - ^bb0(%arg0: f64, %arg1: f64, %arg2: f64): - %1 = arith.subf %arg1, %arg0 fastmath : f64 - %2 = arith.mulf %c0.5_f64, %1 fastmath : f64 - %3 = arith.subf %arg2, %2 fastmath : f64 - linalg.yield %3 : f64 - } - %HZ_ij_1 = memref.subview %HZ[0, 0][NX1, NY1][1, 1] : memref to memref> - %EX_ij_1 = memref.subview %EX[0, 0][NX1, NY1][1, 1] : memref to memref> - %EY_ij_1 = memref.subview %EY[0, 0][NX1, NY1][1, 1] : memref to memref> - %EX_ijo_1 = memref.subview %EX[0, 1][NX1, NY1][1, 1] : memref to memref> - %EY_ioj_1 = memref.subview %EY[1, 0][NX1, NY1][1, 1] : memref to memref> - linalg.generic { - indexing_maps = [affine_map<(d0, d1)->(d0, d1)>, affine_map<(d0, d1)->(d0, d1)>, affine_map<(d0, d1)->(d0, d1)>, affine_map<(d0, d1)->(d0, d1)>, affine_map<(d0, d1)->(d0, d1)>], - iterator_types = ["parallel", "parallel"] - } ins(%EX_ij_1, %EY_ij_1, %EX_ijo_1, %EY_ioj_1: memref>, memref>, memref>, memref>) outs(%HZ_ij_1: memref>) { - ^bb0(%arg0: f64, %arg1: f64, %arg2: f64, %arg3: f64, %arg4: f64): - %1 = arith.subf %arg2, %arg0 fastmath : f64 - %2 = arith.addf %1, %arg3 fastmath : f64 - %3 = arith.subf %2, %arg1 fastmath : f64 - %4 = arith.mulf %c0.7_f64, %3 fastmath : f64 - %5 = arith.subf %arg4, %4 fastmath : f64 - linalg.yield %5 : f64 - } - } - %t1 = func.call @nanoTime() : () -> i64 - %t2 = arith.subi %t1, %t0 : i64 - return %t2 : i64 -} diff --git a/data/polybench/gen/fdtd_gen.py b/data/polybench/gen/fdtd_gen.py deleted file mode 100644 index 459f240..0000000 --- a/data/polybench/gen/fdtd_gen.py +++ /dev/null @@ -1,129 +0,0 @@ -import numpy as np -from mlir.ir import Context, Module -from mlir.execution_engine import ExecutionEngine, ctypes -from mlir.runtime import get_ranked_memref_descriptor -from mlir.passmanager import PassManager -import os -import sys -import json - -base_name = "fdtd" -bench_file = f"{base_name}.mlir.bench" -order = ['EX', 'EY', 'HZ'] - -pass_pipeline = """builtin.module( - loop-invariant-code-motion, - canonicalize, - convert-vector-to-scf, - convert-linalg-to-loops, - buffer-deallocation-pipeline, - scf-forall-to-parallel, - convert-scf-to-openmp, - expand-strided-metadata, - finalize-memref-to-llvm, - convert-scf-to-cf, - lower-affine, - - convert-openmp-to-llvm, - convert-vector-to-llvm, - convert-math-to-llvm, - convert-func-to-llvm, - convert-index-to-llvm, - convert-arith-to-llvm, - convert-cf-to-llvm, - - reconcile-unrealized-casts, - canonicalize, - cse -)""" - -with open(bench_file, "r") as f: - base_code = f.read() - -execution_times = {} - -for TMAX in [2, 10, 20, 50, 100]: - for i in range(5, 13): - MATRIX_SIZE = 2 ** i - bench_name = f"{base_name}_{MATRIX_SIZE}_{TMAX}" - bench_output = f"../{bench_name}.mlir" - - params = { - "NX0": MATRIX_SIZE, - "NY0": MATRIX_SIZE, - "NX1": MATRIX_SIZE - 1, - "NY1": MATRIX_SIZE - 1, - "TMAX": TMAX, - } - - inputs = { - 'EX': np.random.rand(MATRIX_SIZE, MATRIX_SIZE) * 100, - 'EY': np.random.rand(MATRIX_SIZE, MATRIX_SIZE) * 100, - 'HZ': np.random.rand(MATRIX_SIZE, MATRIX_SIZE) * 100, - } - EY = inputs['EY'].copy() - EX = inputs['EX'].copy() - HZ = inputs['HZ'].copy() - for _ in range(2): - for t in range(TMAX): - EY[0, :] = t - EY[1:, :] = EY[1:, :] - 0.5 * (HZ[1:, :] - HZ[:-1, :]) - EX[:, 1:] = EX[:, 1:] - 0.5 * (HZ[:, 1:] - HZ[:, :-1]) - HZ[:-1, :-1] = HZ[:-1, :-1] - 0.7 * ( - EX[:-1, 1:] - EX[:-1, :-1] + EY[1:, :-1] - EY[:-1, :-1] - ) - expected = HZ - np.savez(f"{bench_output}.npz", **inputs) - - code = base_code - for key, value in params.items(): - code = code.replace(key, str(value)) - - with Context(): - module = Module.parse(code) - pm = PassManager.parse(pass_pipeline) - pm.run(module.operation) - execution_engine = ExecutionEngine( - module, - shared_libs=os.getenv("MLIR_SHARED_LIBS", "").split(","), - ) - - args = [] - for arg_name in order: - args.append(ctypes.pointer(ctypes.pointer( - get_ranked_memref_descriptor(inputs[arg_name]) - ))) - - delta_arg = (ctypes.c_int64 * 1)(0) - args.append(delta_arg) - - try: - execution_engine.invoke("main", *args) - execution_engine.invoke("main", *args) - except Exception as e: - print("Benchmark failed:", bench_name, e, file=sys.stderr) - os.remove(f"{bench_output}.npz") - continue - - exec_time = delta_arg[0] - if exec_time >= (1 * 10**9): - os.remove(f"{bench_output}.npz") - break - - actual = inputs[order[-1]] - assertion = np.allclose(actual, expected) - if not assertion: - print("Assertion failed:", bench_name, file=sys.stderr) - os.remove(f"{bench_output}.npz") - continue - - with open(bench_output, "w") as f: - f.write(code) - np.save(f"{bench_output}.npy", expected) - execution_times[bench_name] = exec_time - -with open('../execution_times.json', 'r') as f: - data: dict = json.load(f) -data.update(execution_times) -with open('../execution_times.json', 'w') as f: - json.dump(data, f, indent=4) diff --git a/data/polybench/gen/floyd.mlir.bench b/data/polybench/gen/floyd.mlir.bench deleted file mode 100644 index cbf3e20..0000000 --- a/data/polybench/gen/floyd.mlir.bench +++ /dev/null @@ -1,17 +0,0 @@ -func.func private @nanoTime() -> i64 attributes { llvm.emit_c_interface } -func.func @main(%P: memref) -> i64 attributes { llvm.emit_c_interface } { - %t0 = func.call @nanoTime() : () -> i64 - linalg.generic { - indexing_maps = [affine_map<(d0, d1, d2)->(d1, d0)>, affine_map<(d0, d1, d2)->(d0, d2)>, affine_map<(d0, d1, d2)->(d1, d2)>], - iterator_types = ["reduction", "parallel", "parallel"] - } ins(%P, %P: memref, memref) outs(%P: memref) { - ^bb0(%arg0: f64, %arg1: f64, %arg2: f64): - %1 = arith.addf %arg0, %arg1 fastmath : f64 - %2 = arith.cmpf olt, %arg2, %1 fastmath : f64 - %3 = arith.select %2, %arg2, %1 : f64 - linalg.yield %3 : f64 - } - %t1 = func.call @nanoTime() : () -> i64 - %t2 = arith.subi %t1, %t0 : i64 - return %t2 : i64 -} diff --git a/data/polybench/gen/floyd_gen.py b/data/polybench/gen/floyd_gen.py deleted file mode 100644 index 14312ec..0000000 --- a/data/polybench/gen/floyd_gen.py +++ /dev/null @@ -1,114 +0,0 @@ -import numpy as np -from mlir.ir import Context, Module -from mlir.execution_engine import ExecutionEngine, ctypes -from mlir.runtime import get_ranked_memref_descriptor -from mlir.passmanager import PassManager -import os -import sys -import json - -base_name = "floyd" -bench_file = f"{base_name}.mlir.bench" -order = ['P'] - -pass_pipeline = """builtin.module( - loop-invariant-code-motion, - canonicalize, - convert-vector-to-scf, - convert-linalg-to-loops, - buffer-deallocation-pipeline, - scf-forall-to-parallel, - convert-scf-to-openmp, - expand-strided-metadata, - finalize-memref-to-llvm, - convert-scf-to-cf, - lower-affine, - - convert-openmp-to-llvm, - convert-vector-to-llvm, - convert-math-to-llvm, - convert-func-to-llvm, - convert-index-to-llvm, - convert-arith-to-llvm, - convert-cf-to-llvm, - - reconcile-unrealized-casts, - canonicalize, - cse -)""" - -with open(bench_file, "r") as f: - base_code = f.read() - -execution_times = {} - -for i in range(5, 13): - MATRIX_SIZE = 2 ** i - bench_name = f"{base_name}_{MATRIX_SIZE}" - bench_output = f"../{bench_name}.mlir" - - params = { - "N": MATRIX_SIZE, - } - - inputs = { - 'P': np.random.rand(MATRIX_SIZE, MATRIX_SIZE) * 100, - } - P = inputs['P'].copy() - for k in range(MATRIX_SIZE): - P[:, :] = np.minimum(P[:, :], P[:, k].reshape(-1, 1) + P[k, :].reshape(1, -1)) - expected = P - np.savez(f"{bench_output}.npz", **inputs) - - code = base_code - for key, value in params.items(): - code = code.replace(key, str(value)) - - with Context(): - module = Module.parse(code) - pm = PassManager.parse(pass_pipeline) - pm.run(module.operation) - execution_engine = ExecutionEngine( - module, - shared_libs=os.getenv("MLIR_SHARED_LIBS", "").split(","), - ) - - args = [] - for arg_name in order: - args.append(ctypes.pointer(ctypes.pointer( - get_ranked_memref_descriptor(inputs[arg_name]) - ))) - - delta_arg = (ctypes.c_int64 * 1)(0) - args.append(delta_arg) - - try: - execution_engine.invoke("main", *args) - execution_engine.invoke("main", *args) - except Exception as e: - print("Benchmark failed:", bench_name, e, file=sys.stderr) - os.remove(f"{bench_output}.npz") - continue - - exec_time = delta_arg[0] - if exec_time >= (1 * 10**9): - os.remove(f"{bench_output}.npz") - break - - actual = inputs[order[-1]] - assertion = np.allclose(actual, expected) - if not assertion: - print("Assertion failed:", bench_name, file=sys.stderr) - os.remove(f"{bench_output}.npz") - continue - - with open(bench_output, "w") as f: - f.write(code) - np.save(f"{bench_output}.npy", expected) - execution_times[bench_name] = exec_time - -with open('../execution_times.json', 'r') as f: - data: dict = json.load(f) -data.update(execution_times) -with open('../execution_times.json', 'w') as f: - json.dump(data, f, indent=4) diff --git a/data/polybench/gen/gemm.mlir.bench b/data/polybench/gen/gemm.mlir.bench deleted file mode 100644 index 7641e9c..0000000 --- a/data/polybench/gen/gemm.mlir.bench +++ /dev/null @@ -1,27 +0,0 @@ -func.func private @nanoTime() -> i64 attributes { llvm.emit_c_interface } -func.func @main(%A: memref, %B: memref, %alpha: memref, %beta: memref, %C: memref) -> i64 attributes { llvm.emit_c_interface } { - %alpha_v = memref.load %alpha[] : memref - %beta_v = memref.load %beta[] : memref - %t0 = func.call @nanoTime() : () -> i64 - linalg.generic { - indexing_maps = [affine_map<(d0, d1)->(d0, d1)>], - iterator_types = ["parallel", "parallel"] - } outs(%C: memref) { - ^bb0(%arg0: f64): - %0 = arith.mulf %beta_v, %arg0 fastmath : f64 - linalg.yield %0 : f64 - } - linalg.generic { - indexing_maps = [affine_map<(d0, d1, d2)->(d0, d2)>, affine_map<(d0, d1, d2)->(d2, d1)>, affine_map<(d0, d1, d2)->(d0, d1)>], - iterator_types = ["parallel", "parallel", "reduction"] - } ins(%A, %B: memref, memref) outs(%C: memref) { - ^bb0(%arg0: f64, %arg1: f64, %arg2: f64): - %0 = arith.mulf %arg0, %arg1 fastmath : f64 - %1 = arith.mulf %alpha_v, %0 fastmath : f64 - %2 = arith.addf %arg2, %1 fastmath : f64 - linalg.yield %2 : f64 - } - %t1 = func.call @nanoTime() : () -> i64 - %t2 = arith.subi %t1, %t0 : i64 - return %t2 : i64 -} diff --git a/data/polybench/gen/gemm_gen.py b/data/polybench/gen/gemm_gen.py deleted file mode 100644 index 96ebc12..0000000 --- a/data/polybench/gen/gemm_gen.py +++ /dev/null @@ -1,120 +0,0 @@ -import numpy as np -from mlir.ir import Context, Module -from mlir.execution_engine import ExecutionEngine, ctypes -from mlir.runtime import get_ranked_memref_descriptor -from mlir.passmanager import PassManager -import os -import sys -import json - -base_name = "gemm" -bench_file = f"{base_name}.mlir.bench" -order = ['A', 'B', 'alpha', 'beta', 'C'] - -pass_pipeline = """builtin.module( - loop-invariant-code-motion, - canonicalize, - convert-vector-to-scf, - convert-linalg-to-loops, - buffer-deallocation-pipeline, - scf-forall-to-parallel, - convert-scf-to-openmp, - expand-strided-metadata, - finalize-memref-to-llvm, - convert-scf-to-cf, - lower-affine, - - convert-openmp-to-llvm, - convert-vector-to-llvm, - convert-math-to-llvm, - convert-func-to-llvm, - convert-index-to-llvm, - convert-arith-to-llvm, - convert-cf-to-llvm, - - reconcile-unrealized-casts, - canonicalize, - cse -)""" - -with open(bench_file, "r") as f: - base_code = f.read() - -execution_times = {} - -for i in range(5, 13): - MATRIX_SIZE = 2 ** i - bench_name = f"{base_name}_{MATRIX_SIZE}" - bench_output = f"../{bench_name}.mlir" - - params = { - "NI": MATRIX_SIZE, - "NJ": MATRIX_SIZE, - "NK": MATRIX_SIZE, - } - - inputs = { - 'A': np.random.rand(params['NI'], params['NK']) * 100, - 'B': np.random.rand(params['NK'], params['NJ']) * 100, - 'alpha': np.random.rand(1), - 'beta': np.random.rand(1), - 'C': np.random.rand(params['NI'], params['NJ']) * 100, - } - C = inputs['C'].copy() - for _ in range(2): - C = inputs['alpha'] * inputs['A'] @ inputs['B'] + inputs['beta'] * C - expected = C - np.savez(f"{bench_output}.npz", **inputs) - - code = base_code - for key, value in params.items(): - code = code.replace(key, str(value)) - - with Context(): - module = Module.parse(code) - pm = PassManager.parse(pass_pipeline) - pm.run(module.operation) - execution_engine = ExecutionEngine( - module, - shared_libs=os.getenv("MLIR_SHARED_LIBS", "").split(","), - ) - - args = [] - for arg_name in order: - args.append(ctypes.pointer(ctypes.pointer( - get_ranked_memref_descriptor(inputs[arg_name]) - ))) - - delta_arg = (ctypes.c_int64 * 1)(0) - args.append(delta_arg) - - try: - execution_engine.invoke("main", *args) - execution_engine.invoke("main", *args) - except Exception as e: - print("Benchmark failed:", bench_name, e, file=sys.stderr) - os.remove(f"{bench_output}.npz") - continue - - exec_time = delta_arg[0] - if exec_time >= (1 * 10**9): - os.remove(f"{bench_output}.npz") - break - - actual = inputs[order[-1]] - assertion = np.allclose(actual, expected) - if not assertion: - print("Assertion failed:", bench_name, file=sys.stderr) - os.remove(f"{bench_output}.npz") - continue - - with open(bench_output, "w") as f: - f.write(code) - np.save(f"{bench_output}.npy", expected) - execution_times[bench_name] = exec_time - -with open('../execution_times.json', 'r') as f: - data: dict = json.load(f) -data.update(execution_times) -with open('../execution_times.json', 'w') as f: - json.dump(data, f, indent=4) diff --git a/data/polybench/gen/jacobi.mlir.bench b/data/polybench/gen/jacobi.mlir.bench deleted file mode 100644 index 945b1b0..0000000 --- a/data/polybench/gen/jacobi.mlir.bench +++ /dev/null @@ -1,45 +0,0 @@ -func.func private @nanoTime() -> i64 attributes { llvm.emit_c_interface } -func.func @main(%B: memref, %A: memref) -> i64 attributes { llvm.emit_c_interface } { - %t0 = func.call @nanoTime() : () -> i64 - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c0.2_f64 = arith.constant 0.2 : f64 - %cTSTEPS = arith.constant TSTEPS : index - %A_ijo_2 = memref.subview %A[0, 1][N_2, N_2][1, 1] : memref to memref> - %A_ioj_2 = memref.subview %A[1, 0][N_2, N_2][1, 1] : memref to memref> - %A_iojo_2 = memref.subview %A[1, 1][N_2, N_2][1, 1] : memref to memref> - %A_iojoo_2 = memref.subview %A[1, 2][N_2, N_2][1, 1] : memref to memref> - %A_ioojo_2 = memref.subview %A[2, 1][N_2, N_2][1, 1] : memref to memref> - %B_iojo_2 = memref.subview %B[1, 1][N_2, N_2][1, 1] : memref to memref> - scf.for %0 = %c0 to %cTSTEPS step %c1 { - linalg.generic { - indexing_maps = [ - affine_map<(d0, d1)->(d0, d1)>, - affine_map<(d0, d1)->(d0, d1)>, - affine_map<(d0, d1)->(d0, d1)>, - affine_map<(d0, d1)->(d0, d1)>, - affine_map<(d0, d1)->(d0, d1)>, - affine_map<(d0, d1)->(d0, d1)> - ], - iterator_types = ["parallel", "parallel"] - } ins(%A_iojo_2, %A_ioj_2, %A_iojoo_2, %A_ioojo_2, %A_ijo_2: memref>, memref>, memref>, memref>, memref>) outs(%B_iojo_2: memref>) { - ^bb0(%arg0: f64, %arg1: f64, %arg2: f64, %arg3: f64, %arg4: f64, %arg5: f64): - %1 = arith.addf %arg0, %arg1 fastmath : f64 - %2 = arith.addf %1, %arg2 fastmath : f64 - %3 = arith.addf %2, %arg3 fastmath : f64 - %4 = arith.addf %3, %arg4 fastmath : f64 - %5 = arith.mulf %c0.2_f64, %4 fastmath : f64 - linalg.yield %5 : f64 - } - linalg.generic { - indexing_maps = [affine_map<(d0, d1)->(d0, d1)>, affine_map<(d0, d1)->(d0, d1)>], - iterator_types = ["parallel", "parallel"] - } ins(%B_iojo_2: memref>) outs(%A_iojo_2: memref>) { - ^bb0(%arg0: f64, %arg1: f64): - linalg.yield %arg0 : f64 - } - } - %t1 = func.call @nanoTime() : () -> i64 - %t2 = arith.subi %t1, %t0 : i64 - return %t2 : i64 -} diff --git a/data/polybench/gen/jacobi_gen.py b/data/polybench/gen/jacobi_gen.py deleted file mode 100644 index 2873359..0000000 --- a/data/polybench/gen/jacobi_gen.py +++ /dev/null @@ -1,122 +0,0 @@ -import numpy as np -from mlir.ir import Context, Module -from mlir.execution_engine import ExecutionEngine, ctypes -from mlir.runtime import get_ranked_memref_descriptor -from mlir.passmanager import PassManager -import os -import sys -import json - -base_name = "jacobi" -bench_file = f"{base_name}.mlir.bench" -order = ['B', 'A'] - -pass_pipeline = """builtin.module( - loop-invariant-code-motion, - canonicalize, - convert-vector-to-scf, - convert-linalg-to-loops, - buffer-deallocation-pipeline, - scf-forall-to-parallel, - convert-scf-to-openmp, - expand-strided-metadata, - finalize-memref-to-llvm, - convert-scf-to-cf, - lower-affine, - - convert-openmp-to-llvm, - convert-vector-to-llvm, - convert-math-to-llvm, - convert-func-to-llvm, - convert-index-to-llvm, - convert-arith-to-llvm, - convert-cf-to-llvm, - - reconcile-unrealized-casts, - canonicalize, - cse -)""" - -with open(bench_file, "r") as f: - base_code = f.read() - -execution_times = {} - -for TSTEPS in [2, 10, 20, 50, 100]: - for i in range(5, 13): - MATRIX_SIZE = 2 ** i - bench_name = f"{base_name}_{MATRIX_SIZE}_{TSTEPS}" - bench_output = f"../{bench_name}.mlir" - - params = { - "N0": MATRIX_SIZE, - "N_2": MATRIX_SIZE - 2, - "N1": MATRIX_SIZE + 1, - "N2": MATRIX_SIZE + 2, - "N4": MATRIX_SIZE * 2 + 1, - "TSTEPS": TSTEPS, - } - - inputs = { - 'A': np.random.rand(MATRIX_SIZE, MATRIX_SIZE) * 100, - 'B': np.zeros((MATRIX_SIZE, MATRIX_SIZE)), - } - A = inputs['A'].copy() - for _ in range(2): - for _ in range(TSTEPS): - A[1:-1, 1:-1] = 0.2 * (A[1:-1, 1:-1] + A[1:-1, :-2] + A[1:-1, 2:] + A[:-2, 1:-1] + A[2:, 1:-1]) - expected = A - np.savez(f"{bench_output}.npz", **inputs) - - code = base_code - for key, value in params.items(): - code = code.replace(key, str(value)) - - with Context(): - module = Module.parse(code) - pm = PassManager.parse(pass_pipeline) - pm.run(module.operation) - execution_engine = ExecutionEngine( - module, - shared_libs=os.getenv("MLIR_SHARED_LIBS", "").split(","), - ) - - args = [] - for arg_name in order: - args.append(ctypes.pointer(ctypes.pointer( - get_ranked_memref_descriptor(inputs[arg_name]) - ))) - - delta_arg = (ctypes.c_int64 * 1)(0) - args.append(delta_arg) - - try: - execution_engine.invoke("main", *args) - execution_engine.invoke("main", *args) - except Exception as e: - print("Benchmark failed:", bench_name, e, file=sys.stderr) - os.remove(f"{bench_output}.npz") - continue - - exec_time = delta_arg[0] - if exec_time >= (1 * 10**9): - os.remove(f"{bench_output}.npz") - break - - actual = inputs[order[-1]] - assertion = np.allclose(actual, expected) - if not assertion: - print("Assertion failed:", bench_name, file=sys.stderr) - os.remove(f"{bench_output}.npz") - continue - - with open(bench_output, "w") as f: - f.write(code) - np.save(f"{bench_output}.npy", expected) - execution_times[bench_name] = exec_time - -with open('../execution_times.json', 'r') as f: - data: dict = json.load(f) -data.update(execution_times) -with open('../execution_times.json', 'w') as f: - json.dump(data, f, indent=4) diff --git a/data/polybench/gen/seidel.mlir.bench b/data/polybench/gen/seidel.mlir.bench deleted file mode 100644 index 9a3764b..0000000 --- a/data/polybench/gen/seidel.mlir.bench +++ /dev/null @@ -1,58 +0,0 @@ -func.func private @nanoTime() -> i64 attributes { llvm.emit_c_interface } -func.func @main(%B: memref, %A: memref) -> i64 attributes { llvm.emit_c_interface } { - %t0 = func.call @nanoTime() : () -> i64 - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c9.0_f64 = arith.constant 9.0 : f64 - %cTSTEPS = arith.constant TSTEPS : index - %A_ij_2 = memref.subview %A[0, 0][N_2, N_2][1, 1] : memref to memref> - %A_ijo_2 = memref.subview %A[0, 1][N_2, N_2][1, 1] : memref to memref> - %A_ijoo_2 = memref.subview %A[0, 2][N_2, N_2][1, 1] : memref to memref> - %A_ioj_2 = memref.subview %A[1, 0][N_2, N_2][1, 1] : memref to memref> - %A_iojo_2 = memref.subview %A[1, 1][N_2, N_2][1, 1] : memref to memref> - %A_iojoo_2 = memref.subview %A[1, 2][N_2, N_2][1, 1] : memref to memref> - %A_iooj_2 = memref.subview %A[2, 0][N_2, N_2][1, 1] : memref to memref> - %A_ioojo_2 = memref.subview %A[2, 1][N_2, N_2][1, 1] : memref to memref> - %A_ioojoo_2 = memref.subview %A[2, 2][N_2, N_2][1, 1] : memref to memref> - %B_iojo_2 = memref.subview %B[1, 1][N_2, N_2][1, 1] : memref to memref> - scf.for %0 = %c0 to %cTSTEPS step %c1 { - linalg.generic { - indexing_maps = [ - affine_map<(d0, d1)->(d0, d1)>, - affine_map<(d0, d1)->(d0, d1)>, - affine_map<(d0, d1)->(d0, d1)>, - affine_map<(d0, d1)->(d0, d1)>, - affine_map<(d0, d1)->(d0, d1)>, - affine_map<(d0, d1)->(d0, d1)>, - affine_map<(d0, d1)->(d0, d1)>, - affine_map<(d0, d1)->(d0, d1)>, - affine_map<(d0, d1)->(d0, d1)>, - affine_map<(d0, d1)->(d0, d1)> - ], - iterator_types = ["parallel", "parallel"] - } ins(%A_ij_2, %A_ijo_2, %A_ijoo_2, %A_ioj_2, %A_iojo_2, %A_iojoo_2, %A_iooj_2, %A_ioojo_2, %A_ioojoo_2: memref>, memref>, memref>, memref>, memref>, memref>, memref>, memref>, memref>) - outs(%B_iojo_2: memref>) { - ^bb0(%arg0: f64, %arg1: f64, %arg2: f64, %arg3: f64, %arg4: f64, %arg5: f64, %arg6: f64, %arg7: f64, %arg8: f64, %arg9: f64): - %1 = arith.addf %arg0, %arg1 fastmath : f64 - %2 = arith.addf %1, %arg2 fastmath : f64 - %3 = arith.addf %2, %arg3 fastmath : f64 - %4 = arith.addf %3, %arg4 fastmath : f64 - %5 = arith.addf %4, %arg5 fastmath : f64 - %6 = arith.addf %5, %arg6 fastmath : f64 - %7 = arith.addf %6, %arg7 fastmath : f64 - %8 = arith.addf %7, %arg8 fastmath : f64 - %9 = arith.divf %8, %c9.0_f64 fastmath : f64 - linalg.yield %9 : f64 - } - linalg.generic { - indexing_maps = [affine_map<(d0, d1)->(d0, d1)>, affine_map<(d0, d1)->(d0, d1)>], - iterator_types = ["parallel", "parallel"] - } ins(%B_iojo_2: memref>) outs(%A_iojo_2: memref>) { - ^bb0(%arg0: f64, %arg1: f64): - linalg.yield %arg0 : f64 - } - } - %t1 = func.call @nanoTime() : () -> i64 - %t2 = arith.subi %t1, %t0 : i64 - return %t2 : i64 -} diff --git a/data/polybench/gen/seidel_gen.py b/data/polybench/gen/seidel_gen.py deleted file mode 100644 index 2118c70..0000000 --- a/data/polybench/gen/seidel_gen.py +++ /dev/null @@ -1,124 +0,0 @@ -import numpy as np -from mlir.ir import Context, Module -from mlir.execution_engine import ExecutionEngine, ctypes -from mlir.runtime import get_ranked_memref_descriptor -from mlir.passmanager import PassManager -import os -import sys -import json - -base_name = "seidel" -bench_file = f"{base_name}.mlir.bench" -order = ['B', 'A'] - -pass_pipeline = """builtin.module( - loop-invariant-code-motion, - canonicalize, - convert-vector-to-scf, - convert-linalg-to-loops, - buffer-deallocation-pipeline, - scf-forall-to-parallel, - convert-scf-to-openmp, - expand-strided-metadata, - finalize-memref-to-llvm, - convert-scf-to-cf, - lower-affine, - - convert-openmp-to-llvm, - convert-vector-to-llvm, - convert-math-to-llvm, - convert-func-to-llvm, - convert-index-to-llvm, - convert-arith-to-llvm, - convert-cf-to-llvm, - - reconcile-unrealized-casts, - canonicalize, - cse -)""" - -with open(bench_file, "r") as f: - base_code = f.read() - -execution_times = {} - -for TSTEPS in [2, 10, 20, 50, 100]: - for i in range(5, 13): - MATRIX_SIZE = 2 ** i - bench_name = f"{base_name}_{MATRIX_SIZE}_{TSTEPS}" - bench_output = f"../{bench_name}.mlir" - - params = { - "N0": MATRIX_SIZE, - "N_2": MATRIX_SIZE - 2, - "N1": MATRIX_SIZE + 1, - 'N2': MATRIX_SIZE + 2, - 'N3': MATRIX_SIZE * 2, - 'N4': MATRIX_SIZE * 2 + 1, - 'N5': MATRIX_SIZE * 2 + 2, - "TSTEPS": TSTEPS, - } - - inputs = { - 'A': np.random.rand(MATRIX_SIZE, MATRIX_SIZE) * 100, - 'B': np.zeros((MATRIX_SIZE, MATRIX_SIZE)), - } - A = inputs['A'].copy() - for _ in range(2): - for _ in range(TSTEPS): - A[1:-1, 1:-1] = (A[:-2, :-2] + A[:-2, 1:-1] + A[:-2, 2:] + A[1:-1, :-2] + A[1:-1, 1:-1] + A[1:-1, 2:] + A[2:, :-2] + A[2:, 1:-1] + A[2:, 2:]) / 9.0 - expected = A - np.savez(f"{bench_output}.npz", **inputs) - - code = base_code - for key, value in params.items(): - code = code.replace(key, str(value)) - - with Context(): - module = Module.parse(code) - pm = PassManager.parse(pass_pipeline) - pm.run(module.operation) - execution_engine = ExecutionEngine( - module, - shared_libs=os.getenv("MLIR_SHARED_LIBS", "").split(","), - ) - - args = [] - for arg_name in order: - args.append(ctypes.pointer(ctypes.pointer( - get_ranked_memref_descriptor(inputs[arg_name]) - ))) - - delta_arg = (ctypes.c_int64 * 1)(0) - args.append(delta_arg) - - try: - execution_engine.invoke("main", *args) - execution_engine.invoke("main", *args) - except Exception as e: - print("Benchmark failed:", bench_name, e, file=sys.stderr) - os.remove(f"{bench_output}.npz") - continue - - exec_time = delta_arg[0] - if exec_time >= (1 * 10**9): - os.remove(f"{bench_output}.npz") - break - - actual = inputs[order[-1]] - assertion = np.allclose(actual, expected) - if not assertion: - print("Assertion failed:", bench_name, file=sys.stderr) - os.remove(f"{bench_output}.npz") - continue - - with open(bench_output, "w") as f: - f.write(code) - np.save(f"{bench_output}.npy", expected) - execution_times[bench_name] = exec_time - -with open('../execution_times.json', 'r') as f: - data: dict = json.load(f) -data.update(execution_times) -with open('../execution_times.json', 'w') as f: - json.dump(data, f, indent=4) diff --git a/demo.ipynb b/demo.ipynb old mode 100644 new mode 100755 index 894cf2b..8d74a9a --- a/demo.ipynb +++ b/demo.ipynb @@ -2,24 +2,48 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "2e74d0c8", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/bin/bash: line 1: .environment: No such file or directory\n" + ] + } + ], "source": [ "# Setup environment\n", - "# import os\n", + "import os\n", "from dotenv import load_dotenv\n", "load_dotenv(override=True)\n", - "# os.chdir(os.path.dirname(os.getcwd()))" + "os.chdir(os.path.dirname(os.getcwd()))\n", + "\n", + "!source .environment" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "29ec62e8", "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "ModuleNotFoundError", + "evalue": "No module named 'mlir'", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mModuleNotFoundError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[5]\u001b[39m\u001b[32m, line 3\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;66;03m# Import modules\u001b[39;00m\n\u001b[32m 2\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtorch\u001b[39;00m\n\u001b[32m----> \u001b[39m\u001b[32m3\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mrl_autoschedular\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01menv\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m Env\n\u001b[32m 4\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mrl_autoschedular\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mmodel\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m HiearchyModel \u001b[38;5;28;01mas\u001b[39;00m Model\n\u001b[32m 5\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mrl_autoschedular\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mppo\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m evaluate_benchmark\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/MLIR-RL/rl_autoschedular/env.py:4\u001b[39m\n\u001b[32m 2\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mrl_autoschedular\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mstate\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m OperationState, BenchmarkFeatures, extract_bench_features_from_file\n\u001b[32m 3\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtyping\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m Optional\n\u001b[32m----> \u001b[39m\u001b[32m4\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mrl_autoschedular\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mevaluation\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m evaluate_code\n\u001b[32m 5\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mrl_autoschedular\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mactions\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m Action\n\u001b[32m 6\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mutils\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mlog\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m print_error\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/MLIR-RL/rl_autoschedular/evaluation.py:3\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mos\u001b[39;00m\n\u001b[32m 2\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnumpy\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnp\u001b[39;00m\n\u001b[32m----> \u001b[39m\u001b[32m3\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mmlir\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mir\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m Context, Module\n\u001b[32m 4\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mmlir\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mexecution_engine\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m ExecutionEngine, ctypes\n\u001b[32m 5\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mmlir\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mruntime\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m get_ranked_memref_descriptor\n", + "\u001b[31mModuleNotFoundError\u001b[39m: No module named 'mlir'" + ] + } + ], "source": [ "# Import modules\n", "import torch\n", @@ -82,7 +106,7 @@ ], "metadata": { "kernelspec": { - "display_name": "mlir-rl", + "display_name": "mlir-venv (3.11.13)", "language": "python", "name": "python3" }, diff --git a/demo.py b/demo.py old mode 100644 new mode 100755 index 5722d80..45ec519 --- a/demo.py +++ b/demo.py @@ -1,33 +1,33 @@ - -# Setup environment -from dotenv import load_dotenv -load_dotenv(override=True) - - -# Import modules -import torch -from rl_autoschedular.env import Env -from rl_autoschedular.model import HiearchyModel as Model -from rl_autoschedular.ppo import evaluate_benchmark - - -# Configure torch -torch.set_grad_enabled(False) -torch.set_num_threads(4) - - -# Instantiate the environment -eval_env = Env(is_training=False) - - -# Load the model -model_path = "models/model.pth" -model = Model() -model.load_state_dict(torch.load(model_path, weights_only=True)) - - -# Evaluate the model -evaluate_benchmark( - model, - eval_env, -) + +# Setup environment +from dotenv import load_dotenv +load_dotenv(override=True) + + +# Import modules +import torch +from rl_autoschedular.env import Env +from rl_autoschedular.model import HiearchyModel as Model +from rl_autoschedular.ppo import evaluate_benchmark + + +# Configure torch +torch.set_grad_enabled(False) +torch.set_num_threads(4) + + +# Instantiate the environment +eval_env = Env(is_training=False) + + +# Load the model +model_path = "models/model.pth" +model = Model() +model.load_state_dict(torch.load(model_path, weights_only=True)) + + +# Evaluate the model +evaluate_benchmark( + model, + eval_env, +) diff --git a/eval.py b/eval.py new file mode 100644 index 0000000..2941309 --- /dev/null +++ b/eval.py @@ -0,0 +1,47 @@ +# Load environment variables +from dotenv import load_dotenv +load_dotenv(override=True) + + +import torch +import os +from typing import Optional +from utils.log import print_info, print_success + +# Import environment +from rl_autoschedular.env import Env + +# config, file_logger, device +from rl_autoschedular import config as cfg, file_logger as fl, device +from rl_autoschedular.ppo import evaluate_benchmarks + +# Import RL components +from rl_autoschedular.model import HiearchyModel as Model +import time + +torch.set_grad_enabled(False) +torch.set_num_threads(int(os.getenv("OMP_NUM_THREADS", "4"))) + + + +print_info(f"Config: {cfg}") +print_success(f'Logging to: {fl.run_dir}') + +# Set environments +eval_env = Env(is_training=False,run_name="ppo_online_eval") +print_success(f"Environments initialized: {eval_env.tmp_file}") + +# Set model +model_chkpt = "./checkpoints/model.pth" +model = Model().to(device) +checkpoint = torch.load(model_chkpt, map_location="cpu") +model.load_state_dict(checkpoint, strict=False) # allow partial load +model.eval() + +env_time = evaluate_benchmarks( + model, + eval_env, + step=1 +) + +print(env_time) \ No newline at end of file diff --git a/evaluate.py b/evaluate.py old mode 100644 new mode 100755 index d2bfb87..8886372 --- a/evaluate.py +++ b/evaluate.py @@ -1,84 +1,50 @@ -# Load environment variables -import os -from dotenv import load_dotenv -load_dotenv(override=True) -load_dotenv('.env.debug') - -# Import modules -import torch -from utils.dask_manager import DaskManager -from utils.file_logger import FileLogger -from rl_autoschedular.model import HiearchyModel as Model -from rl_autoschedular import config as cfg, device -from rl_autoschedular.ppo import evaluate_benchmarks -from rl_autoschedular.benchmarks import Benchmarks -from utils.log import print_info, print_success -from time import time -import random -import string -import json -import datetime - - -# Initialize dask in order to allocate jobs -dm = DaskManager() - -# Setup torch -torch.set_grad_enabled(False) -torch.set_num_threads(4) - -# Load data -eval_data = dm.load_eval_data(Benchmarks(is_training=False)) - -# Prepare logging -fl = FileLogger() -print_info(f"Config: {cfg}") -print_success(f'Logging to: {fl.run_dir}') - -# Prepare the temporary execution database -random_str = ''.join(random.choices(string.ascii_letters + string.digits, k=10)) -tmp_exec_data_file = f'tmp-debug/exec/{random_str}.json' if cfg.debug else f'tmp/exec/{random_str}.json' -if not os.path.exists(tmp_exec_data_file): - os.makedirs(os.path.dirname(tmp_exec_data_file), exist_ok=True) - with open(tmp_exec_data_file, "w") as file: - json.dump({}, file) -print_info(f"Temporary execution data saved to: {tmp_exec_data_file}") - -if cfg.exec_data_file: - print_info(f"Global execution data located in: {cfg.exec_data_file}") - -# Initiate model -model = Model().to(device) -print_success("Model initialized") - -# Start evaluation -eval_dir = os.getenv('EVAL_DIR') -if eval_dir is None: - raise ValueError("EVAL_DIR environment variable is not set.") -eval_dir = os.path.abspath(eval_dir) - -# Read the files in the evaluation directory -eval_files = [f for f in os.listdir(eval_dir) if f.endswith('.pt')] - -# Order files -eval_files.sort(key=lambda x: int(x.split('_')[1].split('.')[0])) - -time_ms = 0 -eta = 0 -models_count = len(eval_files) -for step, model_file in enumerate(eval_files): - print_info(f"- Evaluation {step + 1}/{models_count} ({100 * (step + 1) / models_count:.2f}%) ({time_ms}ms) < ({eta})") - - main_start = time() - - model_path = os.path.join(eval_dir, model_file) - if not os.path.exists(model_path): - print_info(f"Model file {model_path} does not exist. Skipping.") - continue - model.load_state_dict(torch.load(model_path, weights_only=True)) - - evaluate_benchmarks(model, eval_data, tmp_exec_data_file) - - main_end = time() - time_ms = int((main_end - main_start) * 1000) - eta = datetime.timedelta(seconds=time_ms * (models_count - step - 1) / 1000) +# Load environment variables +from dotenv import load_dotenv +load_dotenv(override=True) + +# Import modules +from rl_autoschedular.env import Env +from rl_autoschedular.model import HiearchyModel as Model +import torch +import os +from tqdm import tqdm +from rl_autoschedular import config as cfg +from rl_autoschedular import file_logger as fl +from utils.log import print_info, print_success +from rl_autoschedular.ppo import evaluate_benchmark + +torch.set_grad_enabled(False) +torch.set_num_threads(4) + +print_info(f"Config: {cfg}") +print_success(f'Logging to: {fl.run_dir}') + +# Set environments +eval_env = Env(is_training=False) +print_success(f"Environments initialized: {eval_env.tmp_file}") + +# Start training +eval_dir = os.getenv('EVAL_DIR') +if eval_dir is None: + raise ValueError("EVAL_DIR environment variable is not set.") +eval_dir = os.path.abspath(eval_dir) + +# Read the files in the evaluation directory +eval_files = [f for f in os.listdir(eval_dir) if f.endswith('.pth')] + +# Order files +eval_files.sort(key=lambda x: int(x.split('_')[1].split('.')[0])) + +files_tqdm = tqdm(eval_files, desc='Evaluating models') +for model_file in files_tqdm: + files_tqdm.set_postfix_str(f"Evaluating {model_file}") + model = Model() + model_path = os.path.join(eval_dir, model_file) + if not os.path.exists(model_path): + print_info(f"Model file {model_path} does not exist. Skipping.") + continue + model.load_state_dict(torch.load(model_path, weights_only=True)) + evaluate_benchmark( + model, + eval_env, + ) diff --git a/filelog_clean.py b/filelog_clean.py old mode 100644 new mode 100755 index f9d948e..a8c6b83 --- a/filelog_clean.py +++ b/filelog_clean.py @@ -1,23 +1,23 @@ -import os - - -results_dir = 'results' -with open(os.path.join(results_dir, 'synced_ids'), 'r') as f: - synced_ids = [int(id) for id in f.readlines() if id.strip()] - -current_runs = [d for d in os.listdir(results_dir) if d.startswith('run_') and int(d.split('_')[1]) in synced_ids] - -if not current_runs: - print('No leftover runs to clean') - exit() - -print(f'Cleaning runs: {current_runs}') -for run in current_runs: - run_path = os.path.join(results_dir, run) - for root, _, filenames in os.walk(run_path, topdown=False): - for filename in filenames: - os.remove(os.path.join(root, filename)) - os.rmdir(root) - -with open(os.path.join(results_dir, 'synced_ids'), 'w') as f: - pass +import os + + +results_dir = 'results' +with open(os.path.join(results_dir, 'synced_ids'), 'r') as f: + synced_ids = [int(id) for id in f.readlines() if id.strip()] + +current_runs = [d for d in os.listdir(results_dir) if d.startswith('run_') and int(d.split('_')[1]) in synced_ids] + +if not current_runs: + print('No leftover runs to clean') + exit() + +print(f'Cleaning runs: {current_runs}') +for run in current_runs: + run_path = os.path.join(results_dir, run) + for root, _, filenames in os.walk(run_path, topdown=False): + for filename in filenames: + os.remove(os.path.join(root, filename)) + os.rmdir(root) + +with open(os.path.join(results_dir, 'synced_ids'), 'w') as f: + pass diff --git a/fill_db.py b/fill_db.py old mode 100644 new mode 100755 index 8038106..588a944 --- a/fill_db.py +++ b/fill_db.py @@ -1,83 +1,83 @@ -from rl_autoschedular.env import Env -from rl_autoschedular.model import apply_masks, extract_masks, indices_to_raw_actions -import torch -import math -from torch.distributions import Categorical, Distribution, Uniform -from rl_autoschedular import config as cfg -from tqdm import tqdm - - -N = cfg.num_transformations -L = cfg.max_num_loops -TS = cfg.num_tile_sizes -match cfg.interchange_mode: - case 'enumerate': - interchange_mask = 3 * L - 6 - case 'pointers': - interchange_mask = L - case 'continuous': - interchange_mask = 0 -action_mask_size = N + 2 * L * (TS + 1) + interchange_mask - - -def create_uniform_distributions(obs: torch.Tensor, num_loops: list[int]) -> tuple[Distribution, Distribution, Distribution, Distribution]: - """Create uniform distributions for the actions. - - Args: - obs (torch.Tensor): The input tensor. - - Returns: - tuple[Distribution, Distribution, Distribution, Distribution]: The uniform distributions for the transformations, parallelizations, tilings, and interchanges. - """ - batch_size = obs.shape[0] - action_mask = obs[:, -(action_mask_size):].bool() - - transformation_logits = torch.zeros((batch_size, N), dtype=torch.float32) - parallelization_logits = torch.zeros((batch_size, L, TS + 1), dtype=torch.float32) - tiling_logits = torch.zeros((batch_size, L, TS + 1), dtype=torch.float32) - match cfg.interchange_mode: - case 'enumerate': - interchange_logits = torch.zeros((batch_size, 3 * L - 6), dtype=torch.float32) - case 'pointers': - interchange_logits = torch.zeros((batch_size, L), dtype=torch.float32) - case 'continuous': - interchange_logits = torch.zeros((batch_size, 1), dtype=torch.float32) - - # Apply masks on logits - transformation_logits, parallelization_logits, tiling_logits, interchange_logits = apply_masks(transformation_logits, parallelization_logits, tiling_logits, interchange_logits, *extract_masks(action_mask)) - - # Create distributions with the masked probabilities - transformation_dist = Categorical(logits=transformation_logits) - parallelization_dist = Categorical(logits=parallelization_logits) - tiling_dist = Categorical(logits=tiling_logits) - if cfg.interchange_mode != 'continuous': - interchange_dist = Categorical(logits=interchange_logits) - else: - total_count = torch.tensor([math.factorial(loops) for loops in num_loops], dtype=torch.float64) - interchange_dist = Uniform(0.0, total_count) - - return transformation_dist, parallelization_dist, tiling_dist, interchange_dist - - -if __name__ == "__main__": - env = Env(is_training=True) - print(f"Environments initialized: {env.tmp_file}") - - pbar = tqdm(unit="bench") - while True: - state, obs = env.reset() - bench_done = False - while not bench_done: - num_loops = len(state.operation_features.nested_loops) - transformation_eps_dist, parallelization_eps_dist, tiling_eps_dist, interchange_eps_dist = create_uniform_distributions(obs, [num_loops]) - transformation_index = transformation_eps_dist.sample() - parallelization_index = parallelization_eps_dist.sample() - tiling_index = tiling_eps_dist.sample() - interchange_index = interchange_eps_dist.sample().long() - actions = indices_to_raw_actions(transformation_index, parallelization_index, tiling_index, interchange_index, [num_loops]) - next_state, next_obs, _, op_done, _ = env.step(state, actions[0]) - if op_done: - next_state, next_obs, bench_done = env.get_next_op_state(next_state) - state = next_state - obs = next_obs - pbar.update(1) +from rl_autoschedular.env import Env +from rl_autoschedular.model import apply_masks, extract_masks, indices_to_raw_actions +import torch +import math +from torch.distributions import Categorical, Distribution, Uniform +from rl_autoschedular import config as cfg +from tqdm import tqdm + + +N = cfg.num_transformations +L = cfg.max_num_loops +TS = cfg.num_tile_sizes +match cfg.interchange_mode: + case 'enumerate': + interchange_mask = 3 * L - 6 + case 'pointers': + interchange_mask = L + case 'continuous': + interchange_mask = 0 +action_mask_size = N + 2 * L * (TS + 1) + interchange_mask + + +def create_uniform_distributions(obs: torch.Tensor, num_loops: list[int]) -> tuple[Distribution, Distribution, Distribution, Distribution]: + """Create uniform distributions for the actions. + + Args: + obs (torch.Tensor): The input tensor. + + Returns: + tuple[Distribution, Distribution, Distribution, Distribution]: The uniform distributions for the transformations, parallelizations, tilings, and interchanges. + """ + batch_size = obs.shape[0] + action_mask = obs[:, -(action_mask_size):].bool() + + transformation_logits = torch.zeros((batch_size, N), dtype=torch.float32) + parallelization_logits = torch.zeros((batch_size, L, TS + 1), dtype=torch.float32) + tiling_logits = torch.zeros((batch_size, L, TS + 1), dtype=torch.float32) + match cfg.interchange_mode: + case 'enumerate': + interchange_logits = torch.zeros((batch_size, 3 * L - 6), dtype=torch.float32) + case 'pointers': + interchange_logits = torch.zeros((batch_size, L), dtype=torch.float32) + case 'continuous': + interchange_logits = torch.zeros((batch_size, 1), dtype=torch.float32) + + # Apply masks on logits + transformation_logits, parallelization_logits, tiling_logits, interchange_logits = apply_masks(transformation_logits, parallelization_logits, tiling_logits, interchange_logits, *extract_masks(action_mask)) + + # Create distributions with the masked probabilities + transformation_dist = Categorical(logits=transformation_logits) + parallelization_dist = Categorical(logits=parallelization_logits) + tiling_dist = Categorical(logits=tiling_logits) + if cfg.interchange_mode != 'continuous': + interchange_dist = Categorical(logits=interchange_logits) + else: + total_count = torch.tensor([math.factorial(loops) for loops in num_loops], dtype=torch.float64) + interchange_dist = Uniform(0.0, total_count) + + return transformation_dist, parallelization_dist, tiling_dist, interchange_dist + + +if __name__ == "__main__": + env = Env(is_training=True) + print(f"Environments initialized: {env.tmp_file}") + + pbar = tqdm(unit="bench") + while True: + state, obs = env.reset() + bench_done = False + while not bench_done: + num_loops = len(state.operation_features.nested_loops) + transformation_eps_dist, parallelization_eps_dist, tiling_eps_dist, interchange_eps_dist = create_uniform_distributions(obs, [num_loops]) + transformation_index = transformation_eps_dist.sample() + parallelization_index = parallelization_eps_dist.sample() + tiling_index = tiling_eps_dist.sample() + interchange_index = interchange_eps_dist.sample().long() + actions = indices_to_raw_actions(transformation_index, parallelization_index, tiling_index, interchange_index, [num_loops]) + next_state, next_obs, _, op_done, _ = env.step(state, actions[0]) + if op_done: + next_state, next_obs, bench_done = env.get_next_op_state(next_state) + state = next_state + obs = next_obs + pbar.update(1) diff --git a/gen.py b/gen.py old mode 100644 new mode 100755 index 2685d04..acd2e9d --- a/gen.py +++ b/gen.py @@ -1,343 +1,343 @@ -from rl_autoschedular import config as cfg -from rl_autoschedular.state import OperationFeatures, NestedLoopFeatures -import random -import re -import math -import os -import sys -import json -from tqdm import trange -import numpy as np -from mlir.ir import Context, Module -from mlir.execution_engine import ExecutionEngine, ctypes -from mlir.runtime import get_ranked_memref_descriptor -from mlir.passmanager import PassManager - - -output_dir = 'data/features' -inputs: dict[str, np.ndarray] = {} - - -pass_pipeline = """builtin.module( - loop-invariant-code-motion, - canonicalize, - convert-vector-to-scf, - convert-linalg-to-loops, - buffer-deallocation-pipeline, - convert-bufferization-to-memref, - scf-forall-to-parallel, - convert-scf-to-openmp, - expand-strided-metadata, - finalize-memref-to-llvm, - convert-scf-to-cf, - lower-affine, - - convert-openmp-to-llvm, - convert-vector-to-llvm, - convert-math-to-llvm, - convert-func-to-llvm, - convert-index-to-llvm, - convert-arith-to-llvm, - convert-cf-to-llvm, - - reconcile-unrealized-casts, - canonicalize, - cse -)""" - - -def gen_features() -> OperationFeatures: - # Nested loops - num_loops = random.randint(1, cfg.max_num_loops) - reduction_count = random.randint(0, min(3, num_loops - 1)) - iterator_types = ['parallel'] * (num_loops - reduction_count) + ['reduction'] * reduction_count - max_iterations = 10 ** 9 - max_per_loop = math.ceil(max_iterations ** (1 / num_loops)) - iterations = max_iterations - while iterations >= max_iterations: - iterations = 1 - upper_bounds = [] - for _ in range(num_loops): - upper_bound = random.randint(2, min(max_per_loop * 2, 4096)) - upper_bounds.append(upper_bound) - iterations *= upper_bound - nested_loops = [ - NestedLoopFeatures( - arg=f'd{i}', - lower_bound=0, - upper_bound=upper_bounds[i], - step=1, - iterator_type=iterator_types[i] - ) - for i in range(num_loops) - ] - - # Operators count - total_op_count = 0 - while total_op_count == 0: - op_count = { - '+': random.randint(0, 10), - '-': random.randint(0, 10), - '*': random.randint(0, 10), - '/': 0, # TODO: Figure out how to handle division - 'exp': random.randint(0, 2), - } - total_op_count = sum(op_count.values()) - - # Load data - max_load_size = 2 ** 24 - num_loads = random.randint(1, cfg.max_num_stores_loads) - load_data: list[list[str]] = [] - per_loop = max(math.ceil(iterations ** (1 / num_loops)), 2) - max_dim = math.ceil(math.log(max_load_size) / math.log(per_loop)) - args_dict = {loop.arg: loop.upper_bound for loop in nested_loops} - unseen_args = set(args_dict.keys()) - for _ in range(num_loads - 1): - load_size = max_load_size - while load_size >= max_load_size: - dims_count = random.randint(1, min(cfg.max_num_load_store_dim, max_dim)) - zeros_count = random.randint(max(0, dims_count - num_loops), dims_count) - load_args = random.sample(list(args_dict.keys()) + ['0'], dims_count, counts=[1] * num_loops + [zeros_count]) - load_size = 1 - for arg in load_args: - if arg == '0': - load_size *= 5 - else: - load_size *= args_dict[arg] - load_data.append(load_args) - for arg in load_args: - unseen_args.discard(arg) - if unseen_args: - load_data.append(list(unseen_args)) - - # Store data - p_args = [loop.arg for loop in nested_loops if loop.iterator_type == 'parallel'] - random.shuffle(p_args) - store_data = p_args - - return OperationFeatures( - raw_operation='', - operation_type='generic', - op_count=op_count, - load_data=load_data, - store_data=store_data, - nested_loops=nested_loops, - vectorizable=True - ) - - -def create_params(op_features: OperationFeatures) -> tuple[list[str], list[str]]: - params = [] - shapes = [] - args_dict = {loop.arg: loop.upper_bound for loop in op_features.nested_loops} - - # Load params - for i, load in enumerate(op_features.load_data): - shape: list[int] = [] - for arg in load: - if arg == '0': - shape.append(random.randint(1, 5)) - continue - shape.append(args_dict[arg]) - # inputs[f'arg{i}'] = np.random.rand(*shape) * 100 - inputs[f'arg{i}'] = np.empty(shape) - params.append(f'%arg{i}') - shapes.append(f"memref<{'x'.join(map(str, shape))}xf64>") - - # Store param - shape = [] - for arg in op_features.store_data: - if arg == '0': - shape.append(random.randint(1, 5)) - continue - shape.append(args_dict[arg]) - # inputs[f'arg{len(params)}'] = np.zeros(shape) - inputs[f'arg{len(params)}'] = np.empty(shape) - params.append(f'%arg{len(params)}') - shapes.append(f"memref<{'x'.join(map(str, shape))}xf64>") - - return params, shapes - - -def create_raw_operation(op_features: OperationFeatures, params: list[str], shapes: list[str]) -> str: - # Affine maps - base_dims = ', '.join([loop.arg for loop in op_features.nested_loops]) - affine_maps = [] - for load in op_features.load_data: - affine_maps.append(f"affine_map<({base_dims}) -> ({', '.join(load)})>") - affine_maps.append(f"affine_map<({base_dims}) -> ({', '.join(op_features.store_data)})>") - affine_maps_attr = f"[{', '.join(affine_maps)}]" - - # Iterators - iterators = ', '.join([f'"{loop.iterator_type}"' for loop in op_features.nested_loops]) - iterators_attr = f'[{iterators}]' - - # Inputs / Outputs - ins = f"ins({', '.join(params[:-1])}: {', '.join(shapes[:-1])})" - outs = f"outs({params[-1]}: {shapes[-1]})" - - code = f"linalg.generic {{indexing_maps={affine_maps_attr}, iterator_types={iterators_attr}}} {ins} {outs} {{\n" - block_args = [f"%in_{i}: f64" for i in range(len(op_features.load_data))] + ["%out: f64"] - code += f"^bb0({', '.join(block_args)}):\n" - - # Linalg body - block_params = [arg.split(':')[0] for arg in block_args] - unused_block_params = set(block_params.copy()) - created_args: set[str] = set() - tmp_count = 0 - op_count_copy = {op: count for op, count in op_features.op_count.items() if count > 0} - assert all(op_count_copy.values()) - total_op_count = sum(op_count_copy.values()) - for _ in range(total_op_count): - op = random.choice(list(op_count_copy.keys())) - if op == 'exp': - if len(unused_block_params) > 0: - operands = random.sample(list(unused_block_params), 1) - unused_block_params.difference_update(operands) - else: - operands = random.sample(list(created_args) + block_params, 1) - else: - if len(unused_block_params) > 1: - operands = random.sample(list(unused_block_params), 2) - unused_block_params.difference_update(operands) - elif len(unused_block_params) == 1: - operands = [unused_block_params.pop()] - unused_block_params = set() - operands += random.sample(list(created_args) + block_params, 1) - else: - operands = random.sample(list(created_args) + block_params, 2) - - result = f"%{tmp_count}" - tmp_count += 1 - created_args.add(result) - match op: - case '+': - code += f"{result} = arith.addf {operands[0]}, {operands[1]} fastmath : f64\n" - case '-': - code += f"{result} = arith.subf {operands[0]}, {operands[1]} fastmath : f64\n" - case '*': - code += f"{result} = arith.mulf {operands[0]}, {operands[1]} fastmath : f64\n" - case '/': - code += f"{result} = arith.divf {operands[0]}, {operands[1]} fastmath : f64\n" - case 'exp': - code += f"{result} = math.exp {operands[0]} fastmath : f64\n" - - op_count_copy[op] -= 1 - if op_count_copy[op] == 0: - del op_count_copy[op] - - assert sum(op_count_copy.values()) == 0 - - code += f"linalg.yield {result} : f64\n" - code += "}\n" - - return code - - -def formatMLIRCode(code: str) -> str: - """Util function that format the MLIR code by adding indents. - - Args: - code (str): the MLIR code - - Returns: - str: the formatted MLIR code - """ - lines = re.sub(r'\n+', '\n', code).split('\n') - result = '' - indent = 0 - for line in lines: - if len(line) > 0: - if line[0] == '}': - if indent > 0: - indent -= 1 - else: - indent = 0 - - result += indent * ' ' + line + '\n' - - if len(line) > 0: - if line[-1] == '{': - indent += 1 - - return result - - -def gen_full_code() -> str: - op_features = gen_features() - - params, shapes = create_params(op_features) - main_params = [f'{param}: {shape}' for param, shape in zip(params, shapes)] - - raw_operation = create_raw_operation(op_features, params, shapes) - - code = ( - f'func.func private @nanoTime() -> i64 attributes {{ llvm.emit_c_interface }}\n' - f'func.func @main({", ".join(main_params)}) -> i64 attributes {{ llvm.emit_c_interface }} {{\n' - f'%t0 = func.call @nanoTime() : () -> i64\n' - f'{raw_operation}\n' - f'%t1 = func.call @nanoTime() : () -> i64\n' - f'%t2 = arith.subi %t1, %t0 : i64\n' - f'return %t2 : i64\n' - f'}}\n' - ) - - code = formatMLIRCode(code) - - return code - - -if __name__ == '__main__': - with open('execution_times.json', 'r') as file: - execution_times: dict[str, int] = json.load(file) - last_count = max([int(k.split('_')[-1]) for k in execution_times.keys()]) + 1 - for i in trange(last_count, 10000, desc='Generating benchmarks', unit='bench'): - bench_generated = False - while not bench_generated: - bench_name = f'generic_{i}' - bench_output = os.path.join(output_dir, f"{bench_name}.mlir") - - inputs = {} - code = gen_full_code() - with Context(): - module = Module.parse(code) - pm = PassManager.parse(pass_pipeline) - pm.run(module.operation) - execution_engine = ExecutionEngine( - module, - shared_libs=os.getenv("MLIR_SHARED_LIBS", "").split(","), - ) - arg_names = sorted(inputs.keys()) - # np.savez(f"{bench_output}.npz", **inputs) - - c_args = [] - for arg_name in arg_names: - c_args.append(ctypes.pointer(ctypes.pointer( - get_ranked_memref_descriptor(inputs[arg_name]) - ))) - delta_arg = (ctypes.c_int64 * 1)(0) - c_args.append(delta_arg) - - try: - execution_engine.invoke("main", *c_args) - execution_engine.invoke("main", *c_args) - except Exception as e: - print(f"Failed, Bench: {bench_name}, error: {e}", file=sys.stderr) - # os.remove(f'{bench_output}.npz') - continue - - exec_time = delta_arg[0] - if exec_time >= (1 * 10**9): - # os.remove(f'{bench_output}.npz') - continue - - with open(bench_output, 'w') as f: - f.write(code) - # expected = inputs[arg_names[-1]] - # np.save(f"{bench_output}.npy", expected) - - execution_times[bench_name] = exec_time - with open('execution_times.json', 'w') as file: - json.dump(execution_times, file, indent=4) - - bench_generated = True +from rl_autoschedular import config as cfg +from rl_autoschedular.state import OperationFeatures, NestedLoopFeatures +import random +import re +import math +import os +import sys +import json +from tqdm import trange +import numpy as np +from mlir.ir import Context, Module +from mlir.execution_engine import ExecutionEngine, ctypes +from mlir.runtime import get_ranked_memref_descriptor +from mlir.passmanager import PassManager + + +output_dir = 'data/features' +inputs: dict[str, np.ndarray] = {} + + +pass_pipeline = """builtin.module( + loop-invariant-code-motion, + canonicalize, + convert-vector-to-scf, + convert-linalg-to-loops, + buffer-deallocation-pipeline, + convert-bufferization-to-memref, + scf-forall-to-parallel, + convert-scf-to-openmp, + expand-strided-metadata, + finalize-memref-to-llvm, + convert-scf-to-cf, + lower-affine, + + convert-openmp-to-llvm, + convert-vector-to-llvm, + convert-math-to-llvm, + convert-func-to-llvm, + convert-index-to-llvm, + convert-arith-to-llvm, + convert-cf-to-llvm, + + reconcile-unrealized-casts, + canonicalize, + cse +)""" + + +def gen_features() -> OperationFeatures: + # Nested loops + num_loops = random.randint(1, cfg.max_num_loops) + reduction_count = random.randint(0, min(3, num_loops - 1)) + iterator_types = ['parallel'] * (num_loops - reduction_count) + ['reduction'] * reduction_count + max_iterations = 10 ** 9 + max_per_loop = math.ceil(max_iterations ** (1 / num_loops)) + iterations = max_iterations + while iterations >= max_iterations: + iterations = 1 + upper_bounds = [] + for _ in range(num_loops): + upper_bound = random.randint(2, min(max_per_loop * 2, 4096)) + upper_bounds.append(upper_bound) + iterations *= upper_bound + nested_loops = [ + NestedLoopFeatures( + arg=f'd{i}', + lower_bound=0, + upper_bound=upper_bounds[i], + step=1, + iterator_type=iterator_types[i] + ) + for i in range(num_loops) + ] + + # Operators count + total_op_count = 0 + while total_op_count == 0: + op_count = { + '+': random.randint(0, 10), + '-': random.randint(0, 10), + '*': random.randint(0, 10), + '/': 0, # TODO: Figure out how to handle division + 'exp': random.randint(0, 2), + } + total_op_count = sum(op_count.values()) + + # Load data + max_load_size = 2 ** 24 + num_loads = random.randint(1, cfg.max_num_stores_loads) + load_data: list[list[str]] = [] + per_loop = max(math.ceil(iterations ** (1 / num_loops)), 2) + max_dim = math.ceil(math.log(max_load_size) / math.log(per_loop)) + args_dict = {loop.arg: loop.upper_bound for loop in nested_loops} + unseen_args = set(args_dict.keys()) + for _ in range(num_loads - 1): + load_size = max_load_size + while load_size >= max_load_size: + dims_count = random.randint(1, min(cfg.max_num_load_store_dim, max_dim)) + zeros_count = random.randint(max(0, dims_count - num_loops), dims_count) + load_args = random.sample(list(args_dict.keys()) + ['0'], dims_count, counts=[1] * num_loops + [zeros_count]) + load_size = 1 + for arg in load_args: + if arg == '0': + load_size *= 5 + else: + load_size *= args_dict[arg] + load_data.append(load_args) + for arg in load_args: + unseen_args.discard(arg) + if unseen_args: + load_data.append(list(unseen_args)) + + # Store data + p_args = [loop.arg for loop in nested_loops if loop.iterator_type == 'parallel'] + random.shuffle(p_args) + store_data = p_args + + return OperationFeatures( + raw_operation='', + operation_type='generic', + op_count=op_count, + load_data=load_data, + store_data=store_data, + nested_loops=nested_loops, + vectorizable=True + ) + + +def create_params(op_features: OperationFeatures) -> tuple[list[str], list[str]]: + params = [] + shapes = [] + args_dict = {loop.arg: loop.upper_bound for loop in op_features.nested_loops} + + # Load params + for i, load in enumerate(op_features.load_data): + shape: list[int] = [] + for arg in load: + if arg == '0': + shape.append(random.randint(1, 5)) + continue + shape.append(args_dict[arg]) + # inputs[f'arg{i}'] = np.random.rand(*shape) * 100 + inputs[f'arg{i}'] = np.empty(shape) + params.append(f'%arg{i}') + shapes.append(f"memref<{'x'.join(map(str, shape))}xf64>") + + # Store param + shape = [] + for arg in op_features.store_data: + if arg == '0': + shape.append(random.randint(1, 5)) + continue + shape.append(args_dict[arg]) + # inputs[f'arg{len(params)}'] = np.zeros(shape) + inputs[f'arg{len(params)}'] = np.empty(shape) + params.append(f'%arg{len(params)}') + shapes.append(f"memref<{'x'.join(map(str, shape))}xf64>") + + return params, shapes + + +def create_raw_operation(op_features: OperationFeatures, params: list[str], shapes: list[str]) -> str: + # Affine maps + base_dims = ', '.join([loop.arg for loop in op_features.nested_loops]) + affine_maps = [] + for load in op_features.load_data: + affine_maps.append(f"affine_map<({base_dims}) -> ({', '.join(load)})>") + affine_maps.append(f"affine_map<({base_dims}) -> ({', '.join(op_features.store_data)})>") + affine_maps_attr = f"[{', '.join(affine_maps)}]" + + # Iterators + iterators = ', '.join([f'"{loop.iterator_type}"' for loop in op_features.nested_loops]) + iterators_attr = f'[{iterators}]' + + # Inputs / Outputs + ins = f"ins({', '.join(params[:-1])}: {', '.join(shapes[:-1])})" + outs = f"outs({params[-1]}: {shapes[-1]})" + + code = f"linalg.generic {{indexing_maps={affine_maps_attr}, iterator_types={iterators_attr}}} {ins} {outs} {{\n" + block_args = [f"%in_{i}: f64" for i in range(len(op_features.load_data))] + ["%out: f64"] + code += f"^bb0({', '.join(block_args)}):\n" + + # Linalg body + block_params = [arg.split(':')[0] for arg in block_args] + unused_block_params = set(block_params.copy()) + created_args: set[str] = set() + tmp_count = 0 + op_count_copy = {op: count for op, count in op_features.op_count.items() if count > 0} + assert all(op_count_copy.values()) + total_op_count = sum(op_count_copy.values()) + for _ in range(total_op_count): + op = random.choice(list(op_count_copy.keys())) + if op == 'exp': + if len(unused_block_params) > 0: + operands = random.sample(list(unused_block_params), 1) + unused_block_params.difference_update(operands) + else: + operands = random.sample(list(created_args) + block_params, 1) + else: + if len(unused_block_params) > 1: + operands = random.sample(list(unused_block_params), 2) + unused_block_params.difference_update(operands) + elif len(unused_block_params) == 1: + operands = [unused_block_params.pop()] + unused_block_params = set() + operands += random.sample(list(created_args) + block_params, 1) + else: + operands = random.sample(list(created_args) + block_params, 2) + + result = f"%{tmp_count}" + tmp_count += 1 + created_args.add(result) + match op: + case '+': + code += f"{result} = arith.addf {operands[0]}, {operands[1]} fastmath : f64\n" + case '-': + code += f"{result} = arith.subf {operands[0]}, {operands[1]} fastmath : f64\n" + case '*': + code += f"{result} = arith.mulf {operands[0]}, {operands[1]} fastmath : f64\n" + case '/': + code += f"{result} = arith.divf {operands[0]}, {operands[1]} fastmath : f64\n" + case 'exp': + code += f"{result} = math.exp {operands[0]} fastmath : f64\n" + + op_count_copy[op] -= 1 + if op_count_copy[op] == 0: + del op_count_copy[op] + + assert sum(op_count_copy.values()) == 0 + + code += f"linalg.yield {result} : f64\n" + code += "}\n" + + return code + + +def formatMLIRCode(code: str) -> str: + """Util function that format the MLIR code by adding indents. + + Args: + code (str): the MLIR code + + Returns: + str: the formatted MLIR code + """ + lines = re.sub(r'\n+', '\n', code).split('\n') + result = '' + indent = 0 + for line in lines: + if len(line) > 0: + if line[0] == '}': + if indent > 0: + indent -= 1 + else: + indent = 0 + + result += indent * ' ' + line + '\n' + + if len(line) > 0: + if line[-1] == '{': + indent += 1 + + return result + + +def gen_full_code() -> str: + op_features = gen_features() + + params, shapes = create_params(op_features) + main_params = [f'{param}: {shape}' for param, shape in zip(params, shapes)] + + raw_operation = create_raw_operation(op_features, params, shapes) + + code = ( + f'func.func private @nanoTime() -> i64 attributes {{ llvm.emit_c_interface }}\n' + f'func.func @main({", ".join(main_params)}) -> i64 attributes {{ llvm.emit_c_interface }} {{\n' + f'%t0 = func.call @nanoTime() : () -> i64\n' + f'{raw_operation}\n' + f'%t1 = func.call @nanoTime() : () -> i64\n' + f'%t2 = arith.subi %t1, %t0 : i64\n' + f'return %t2 : i64\n' + f'}}\n' + ) + + code = formatMLIRCode(code) + + return code + + +if __name__ == '__main__': + with open('execution_times.json', 'r') as file: + execution_times: dict[str, int] = json.load(file) + last_count = max([int(k.split('_')[-1]) for k in execution_times.keys()]) + 1 + for i in trange(last_count, 10000, desc='Generating benchmarks', unit='bench'): + bench_generated = False + while not bench_generated: + bench_name = f'generic_{i}' + bench_output = os.path.join(output_dir, f"{bench_name}.mlir") + + inputs = {} + code = gen_full_code() + with Context(): + module = Module.parse(code) + pm = PassManager.parse(pass_pipeline) + pm.run(module.operation) + execution_engine = ExecutionEngine( + module, + shared_libs=os.getenv("MLIR_SHARED_LIBS", "").split(","), + ) + arg_names = sorted(inputs.keys()) + # np.savez(f"{bench_output}.npz", **inputs) + + c_args = [] + for arg_name in arg_names: + c_args.append(ctypes.pointer(ctypes.pointer( + get_ranked_memref_descriptor(inputs[arg_name]) + ))) + delta_arg = (ctypes.c_int64 * 1)(0) + c_args.append(delta_arg) + + try: + execution_engine.invoke("main", *c_args) + execution_engine.invoke("main", *c_args) + except Exception as e: + print(f"Failed, Bench: {bench_name}, error: {e}", file=sys.stderr) + # os.remove(f'{bench_output}.npz') + continue + + exec_time = delta_arg[0] + if exec_time >= (1 * 10**9): + # os.remove(f'{bench_output}.npz') + continue + + with open(bench_output, 'w') as f: + f.write(code) + # expected = inputs[arg_names[-1]] + # np.save(f"{bench_output}.npy", expected) + + execution_times[bench_name] = exec_time + with open('execution_times.json', 'w') as file: + json.dump(execution_times, file, indent=4) + + bench_generated = True diff --git a/init_env.py b/init_env.py new file mode 100644 index 0000000..494c5a9 --- /dev/null +++ b/init_env.py @@ -0,0 +1,42 @@ +# Load environment variables +from dotenv import load_dotenv +load_dotenv(override=True) + + +import torch +import os +from typing import Optional +from utils.log import print_info, print_success + +# Import environment +from rl_autoschedular.env import Env + +# config, file_logger, device +from rl_autoschedular import config as cfg, file_logger as fl, device + +# Import RL components +from rl_autoschedular.model import HiearchyModel as Model +from rl_autoschedular.trajectory import TrajectoryData +from rl_autoschedular.ppo import ( + collect_trajectory, + ppo_update, + value_update, + evaluate_benchmark +) + + +torch.set_grad_enabled(False) +torch.set_num_threads(int(os.getenv("OMP_NUM_THREADS", "4"))) +if cfg.debug: + torch.autograd.set_detect_anomaly(True) + +print_info(f"Config: {cfg}") +print_success(f'Logging to: {fl.run_dir}') + +# Set environments +env = Env(is_training=True) +env.save_benchmarks_data_to_json("my_benchmarks.json") + +eval_env = Env(is_training=False, tmp_file=env.tmp_file) + +print_success(f"Environments initialized: {env.tmp_file}") \ No newline at end of file diff --git a/iql/__init__.py b/iql/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/iql/iql_agent.py b/iql/iql_agent.py new file mode 100755 index 0000000..dacc6b4 --- /dev/null +++ b/iql/iql_agent.py @@ -0,0 +1,286 @@ +import copy +from typing import Dict, List, Optional, Type, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from utils.config import Config +from rl_autoschedular.actions import ActionSpace +from rl_autoschedular.observation import Observation, ObservationPart, OpFeatures, ActionHistory + +# ---- bring in the updated models we defined earlier ---- +from iql.value_function import IQLValueModel +from iql.policy import IQLPolicyModel +from iql.q_functions import IQLTwinQ + + +class IQLAgent(nn.Module): + """ + IQL agent adapted to the PPO-aligned architecture and hierarchical action space. + - Uses Observation.get_parts(obs, *obs_parts) + - Shared 3×512 backbone across policy/value/Q + - Hierarchical heads (action + per-action params) + """ + def __init__(self, cfg: Config, device: Union[torch.device, str], obs_parts=None, param_dims=None): + super().__init__() + self.obs_parts = obs_parts or [OpFeatures, ActionHistory] + + # ---- device handling ---- + self.device = torch.device(device) if not isinstance(device, torch.device) else device + + # Use config hyperparameters + self.gamma = cfg.gamma + self.tau = cfg.tau + self.beta = cfg.beta + self.alpha = cfg.alpha + + # Networks (move to device) + self.value_model = IQLValueModel(self.obs_parts, tau=self.tau).to(self.device) + self.policy_model = IQLPolicyModel(self.obs_parts).to(self.device) + self.q_model = IQLTwinQ(self.obs_parts).to(self.device) + + # Target Q + self.q_target = copy.deepcopy(self.q_model).to(self.device) + for p in self.q_target.parameters(): + p.requires_grad = False + + # Optimizers with cfg.lr dict (after models are on device) + self.value_optimizer = torch.optim.Adam(self.value_model.parameters(), lr=cfg.lr["value"]) + self.q_optimizer = torch.optim.Adam(self.q_model.parameters(), lr=cfg.lr["q"]) + self.policy_optimizer = torch.optim.Adam(self.policy_model.parameters(), lr=cfg.lr["policy"]) + self.policy_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + self.policy_optimizer, + T_max=600000, + eta_min=1e-5 + ) + + # --------- helpers to move inputs to device ---------- + def _to_device_tensor(self, x: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + if x is None: + return None + return x.to(self.device, non_blocking=True) + + def _to_device_tensor_list( + self, + xs: Optional[List[Optional[torch.Tensor]]] + ) -> Optional[List[Optional[torch.Tensor]]]: + if xs is None: + return None + out: List[Optional[torch.Tensor]] = [] + for t in xs: + out.append(self._to_device_tensor(t) if isinstance(t, torch.Tensor) else None if t is None else t) + return out + + # ------------------------ + # Action selection (hierarchical) + # ------------------------ + @torch.no_grad() + def sample( + self, + obs: torch.Tensor, + greedy: bool = False, + eps: Optional[float] = None + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Sample hierarchical action indices using the same API style as PPO. + Returns: + actions_index: packed hierarchical indices (ActionSpace format) + actions_log_p: log-prob of sampled action under current policy + entropies: per-head entropies (aggregated by ActionSpace) + """ + + # Build distributions from policy + dists = self.policy_model(obs) + eps_dists = ActionSpace.uniform_distributions(obs) + + # Hierarchical sample + use_uniform = (eps is not None) and (torch.rand((), device=self.device).item() < eps) + actions_index = ActionSpace.sample( + obs, + dists, + eps_dists, + uniform=use_uniform, + greedy=greedy, + ) + + # Stats for the sampled actions + actions_log_p, entropies = ActionSpace.distributions_stats( + dists, + actions_index, + eps_distributions=eps_dists if eps is not None else None, + eps=eps, + ) + return actions_index, actions_log_p, entropies + + # ------------------------ + # Value update (expectile regression using target twin-Q) + # ------------------------ + def update_value( + self, + obs: torch.Tensor, + action_idx: torch.LongTensor, + *, + param_indices: Optional[List[Optional[torch.LongTensor]]] = None, + param_values: Optional[List[Optional[torch.Tensor]]] = None, + ) -> torch.Tensor: + """ + Updates V(s) by regressing towards min(Q1, Q2) from the *target* Q network. + """ + + with torch.no_grad(): + q1_t, q2_t = self.q_target(obs, action_idx) + q_min_t = torch.min(q1_t, q2_t) # [B] + + self.value_optimizer.zero_grad(set_to_none=True) + loss_v = self.value_model.loss(obs, q_min_t) + loss_v.backward() + self.value_optimizer.step() + return loss_v + + # ------------------------ + # Q update (TD with V(s')) + # ------------------------ + def update_q( + self, + obs: torch.Tensor, + action_idx: torch.LongTensor, + rewards: torch.Tensor, + next_obs: torch.Tensor, + dones: torch.Tensor + ) -> torch.Tensor: + """ + Update twin Q networks with TD target: + target_q = r + gamma * (1 - done) * V_target(s') + If target_v is not provided, it is computed from the current value_model. + """ + + + with torch.no_grad(): + target_v = self.value_model(next_obs).to(self.device) # [B] + + target_q = rewards + self.gamma * (1.0 - dones) * target_v # [B] + + self.q_optimizer.zero_grad(set_to_none=True) + + + loss_q = self.q_model.loss( + obs, + action_idx, + target_q + ) + loss_q.backward() + self.q_optimizer.step() + return loss_q + + # ------------------------ + # Policy update (advantage-weighted BC) + # ------------------------ + def update_policy( + self, + obs: torch.Tensor, + actions_index: torch.Tensor, # packed hierarchical indices (as stored by dataset) + *, + action_idx: Optional[torch.LongTensor] = None, + param_indices: Optional[List[Optional[torch.LongTensor]]] = None, + param_values: Optional[List[Optional[torch.Tensor]]] = None, + ) -> torch.Tensor: + """ + Update policy with advantage-weighted log-likelihood: + weights = exp(A / beta), A = min(Q1, Q2) - V(s) + + - actions_index is used to compute log π(a|s) via ActionSpace.distributions_stats(...) + - Q needs decomposed (action_idx, param_indices/values). + """ + + # 1) log π(a|s) from hierarchical distributions + dists = self.policy_model(obs) + actions_log_p, _ = ActionSpace.distributions_stats(dists, actions_index) + + # 2) advantages = Q_min(s,a) - V(s) + assert action_idx is not None, "action_idx (top-level) is required for Q evaluation" + with torch.no_grad(): + q_min = self.q_model.q_values(obs, action_idx) # [B] + v = self.value_model(obs) # [B] + advantages = q_min - v # [B] + advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) + + + # 3) loss (AWAC/IQL style) + + # 1. zero gradients + self.policy_optimizer.zero_grad(set_to_none=True) + # 2. compute loss + loss_pi = self.policy_model.loss( + actions_log_p=actions_log_p, + advantages=advantages, + beta=self.beta, + ) + + # 3. backpropagate + loss_pi.backward() + + # 4. clip gradients (to avoid instability) + torch.nn.utils.clip_grad_norm_(self.policy_model.parameters(), max_norm=5.0) + + + self.policy_optimizer.step() + self.policy_lr_scheduler.step() + return loss_pi + + # ------------------------ + # Soft update of target Q + # ------------------------ + @torch.no_grad() + def soft_update_q_target(self): + """ + θ_target ← α θ + (1-α) θ_target + """ + for p, tp in zip(self.q_model.parameters(), self.q_target.parameters()): + tp.data.copy_(self.alpha * p.data + (1.0 - self.alpha) * tp.data) + + def update(self, batch: Tuple[torch.Tensor, ...]) -> Dict[str, float]: + """ + One full IQL update step: + 1. Update Q-functions + 2. Update value function + 3. Update policy (AWAC/IQL style) + 4. Soft update target Q + Returns dict of losses for logging. + """ + obs, actions_index, rewards, next_obs, dones = (t.to(self.device, non_blocking=True) for t in batch) + + + # ---- 1) Update Q ---- + loss_q = self.update_q( + obs=obs, + action_idx=actions_index, # top-level index + rewards=rewards, + next_obs=next_obs, + dones=dones, + ) + + # ---- 2) Update Value ---- + loss_v = self.update_value(obs, actions_index) + + + # ---- 3) Update Policy ---- + loss_pi = self.update_policy( + obs=obs, + actions_index=actions_index, + action_idx=actions_index, # required for Q evaluation + ) + + + + # ---- 4) Soft update Q target ---- + self.soft_update_q_target() + + return { + "q": float(loss_q.item()), + "policy": float(loss_pi.item()), + "value": float(loss_v.item()), + } + + + diff --git a/iql/iql_agent_device.py b/iql/iql_agent_device.py new file mode 100644 index 0000000..61c9b47 --- /dev/null +++ b/iql/iql_agent_device.py @@ -0,0 +1,282 @@ +import copy +from typing import Dict, List, Optional, Type, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from iql.iql_config import Config +from utils.config import Config +from rl_autoschedular.actions import ActionSpace +from rl_autoschedular.observation import Observation, ObservationPart, OpFeatures, ActionHistory + +# ---- bring in the updated models we defined earlier ---- +from iql.value_function import IQLValueModel +from iql.policy import IQLPolicyModel +from iql.q_functions import IQLTwinQ + + +class IQLAgent(nn.Module): + """ + IQL agent adapted to the PPO-aligned architecture and hierarchical action space. + - Uses Observation.get_parts(obs, *obs_parts) + - Shared 3×512 backbone across policy/value/Q + - Hierarchical heads (action + per-action params) + """ + def __init__(self, cfg: Config, device: Union[torch.device, str], obs_parts=None, param_dims=None): + super().__init__() + self.obs_parts = obs_parts or [OpFeatures, ActionHistory] + + # ---- device handling ---- + self.device = torch.device(device) if not isinstance(device, torch.device) else device + + # Use config hyperparameters + self.gamma = cfg.gamma + self.tau = cfg.tau + self.beta = cfg.beta + self.alpha = cfg.alpha + + # Networks (move to device) + self.value_model = IQLValueModel(self.obs_parts, tau=self.tau).to(self.device) + self.policy_model = IQLPolicyModel(self.obs_parts).to(self.device) + self.q_model = IQLTwinQ(self.obs_parts).to(self.device) + + # Target Q + self.q_target = copy.deepcopy(self.q_model).to(self.device) + for p in self.q_target.parameters(): + p.requires_grad = False + + # Optimizers with cfg.lr dict (after models are on device) + self.value_optimizer = torch.optim.Adam(self.value_model.parameters(), lr=cfg.lr["value"]) + self.q_optimizer = torch.optim.Adam(self.q_model.parameters(), lr=cfg.lr["q"]) + self.policy_optimizer = torch.optim.Adam(self.policy_model.parameters(), lr=cfg.lr["policy"]) + + # --------- helpers to move inputs to device ---------- + def _to_device_tensor(self, x: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + if x is None: + return None + return x.to(self.device, non_blocking=True) + + def _to_device_tensor_list( + self, + xs: Optional[List[Optional[torch.Tensor]]] + ) -> Optional[List[Optional[torch.Tensor]]]: + if xs is None: + return None + out: List[Optional[torch.Tensor]] = [] + for t in xs: + out.append(self._to_device_tensor(t) if isinstance(t, torch.Tensor) else None if t is None else t) + return out + + # ------------------------ + # Action selection (hierarchical) + # ------------------------ + @torch.no_grad() + def sample( + self, + obs: torch.Tensor, + greedy: bool = False, + eps: Optional[float] = None + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Sample hierarchical action indices using the same API style as PPO. + Returns: + actions_index: packed hierarchical indices (ActionSpace format) + actions_log_p: log-prob of sampled action under current policy + entropies: per-head entropies (aggregated by ActionSpace) + """ + obs = self._to_device_tensor(obs) + + # Build distributions from policy + dists = self.policy_model(obs) + eps_dists = ActionSpace.uniform_distributions(obs) + + # Hierarchical sample + use_uniform = (eps is not None) and (torch.rand((), device=self.device).item() < eps) + actions_index = ActionSpace.sample( + obs, + dists, + eps_dists, + uniform=use_uniform, + greedy=greedy, + ) + + # Stats for the sampled actions + actions_log_p, entropies = ActionSpace.distributions_stats( + dists, + actions_index, + eps_distributions=eps_dists if eps is not None else None, + eps=eps, + ) + return actions_index, actions_log_p, entropies + + # ------------------------ + # Value update (expectile regression using target twin-Q) + # ------------------------ + def update_value( + self, + obs: torch.Tensor, + action_idx: torch.LongTensor, + *, + param_indices: Optional[List[Optional[torch.LongTensor]]] = None, + param_values: Optional[List[Optional[torch.Tensor]]] = None, + ) -> torch.Tensor: + """ + Updates V(s) by regressing towards min(Q1, Q2) from the *target* Q network. + """ + obs = self._to_device_tensor(obs) + action_idx = self._to_device_tensor(action_idx) + param_indices = self._to_device_tensor_list(param_indices) + param_values = self._to_device_tensor_list(param_values) + + with torch.no_grad(): + q1_t, q2_t = self.q_target(obs, action_idx) + q_min_t = torch.min(q1_t, q2_t) # [B] + + loss_v = self.value_model.loss(obs, q_min_t) + assert self.value_optimizer is not None, "value_optimizer is not set" + self.value_optimizer.zero_grad(set_to_none=True) + loss_v.backward() + self.value_optimizer.step() + return loss_v + + # ------------------------ + # Q update (TD with V(s')) + # ------------------------ + def update_q( + self, + obs: torch.Tensor, + action_idx: torch.LongTensor, + rewards: torch.Tensor, + next_obs: torch.Tensor, + dones: torch.Tensor + ) -> torch.Tensor: + """ + Update twin Q networks with TD target: + target_q = r + gamma * (1 - done) * V_target(s') + If target_v is not provided, it is computed from the current value_model. + """ + obs = self._to_device_tensor(obs) + next_obs = self._to_device_tensor(next_obs) + action_idx = self._to_device_tensor(action_idx) + rewards = self._to_device_tensor(rewards) + dones = self._to_device_tensor(dones) + + with torch.no_grad(): + target_v = self.value_model(next_obs).to(self.device) # [B] + + target_q = rewards + self.gamma * (1.0 - dones) * target_v # [B] + + loss_q = self.q_model.loss( + obs, + action_idx, + target_q + ) + assert self.q_optimizer is not None, "q_optimizer is not set" + self.q_optimizer.zero_grad(set_to_none=True) + loss_q.backward() + self.q_optimizer.step() + return loss_q + + # ------------------------ + # Policy update (advantage-weighted BC) + # ------------------------ + def update_policy( + self, + obs: torch.Tensor, + actions_index: torch.Tensor, # packed hierarchical indices (as stored by dataset) + *, + action_idx: Optional[torch.LongTensor] = None, + param_indices: Optional[List[Optional[torch.LongTensor]]] = None, + param_values: Optional[List[Optional[torch.Tensor]]] = None, + ) -> torch.Tensor: + """ + Update policy with advantage-weighted log-likelihood: + weights = exp(A / beta), A = min(Q1, Q2) - V(s) + + - actions_index is used to compute log π(a|s) via ActionSpace.distributions_stats(...) + - Q needs decomposed (action_idx, param_indices/values). + """ + obs = self._to_device_tensor(obs) + actions_index = self._to_device_tensor(actions_index) + action_idx = self._to_device_tensor(action_idx) if action_idx is not None else None + param_indices = self._to_device_tensor_list(param_indices) + param_values = self._to_device_tensor_list(param_values) + + # 1) log π(a|s) from hierarchical distributions + dists = self.policy_model(obs) + actions_log_p, _ = ActionSpace.distributions_stats(dists, actions_index) + + # 2) advantages = Q_min(s,a) - V(s) + assert action_idx is not None, "action_idx (top-level) is required for Q evaluation" + with torch.no_grad(): + q_min = self.q_model.q_values(obs, action_idx) # [B] + v = self.value_model(obs) # [B] + advantages = q_min - v # [B] + + # 3) loss (AWAC/IQL style) + loss_pi = self.policy_model.loss( + actions_log_p=actions_log_p, + advantages=advantages, + beta=self.beta, + ) + + assert self.policy_optimizer is not None, "policy_optimizer is not set" + self.policy_optimizer.zero_grad(set_to_none=True) + loss_pi.backward() + self.policy_optimizer.step() + return loss_pi + + # ------------------------ + # Soft update of target Q + # ------------------------ + @torch.no_grad() + def soft_update_q_target(self): + """ + θ_target ← α θ + (1-α) θ_target + """ + for p, tp in zip(self.q_model.parameters(), self.q_target.parameters()): + tp.data.copy_(self.alpha * p.data + (1.0 - self.alpha) * tp.data) + + def update(self, batch: Tuple[torch.Tensor, ...]) -> Dict[str, float]: + """ + One full IQL update step: + 1. Update Q-functions + 2. Update value function + 3. Update policy (AWAC/IQL style) + 4. Soft update target Q + Returns dict of losses for logging. + """ + # Ensure whole batch is on device + obs, actions_index, rewards, next_obs, dones = (t.to(self.device, non_blocking=True) for t in batch) + + # ---- 1) Update Q ---- + loss_q = self.update_q( + obs=obs, + action_idx=actions_index, # top-level index + rewards=rewards, + next_obs=next_obs, + dones=dones, + ) + + # ---- 2) Update Value ---- + loss_v = self.update_value(obs, actions_index) + + + # ---- 3) Update Policy ---- + loss_pi = self.update_policy( + obs=obs, + actions_index=actions_index, + action_idx=actions_index, # required for Q evaluation + ) + + + + # ---- 4) Soft update Q target ---- + self.soft_update_q_target() + + return { + "q": float(loss_q.item()), + "policy": float(loss_pi.item()), + "value": float(loss_v.item()), + } diff --git a/iql/iql_config.py b/iql/iql_config.py new file mode 100755 index 0000000..c6ec39a --- /dev/null +++ b/iql/iql_config.py @@ -0,0 +1,159 @@ +from dotenv import load_dotenv +load_dotenv() + + +import os +from utils.singleton import Singleton +import json +from typing import Literal + + + + +class Config(metaclass=Singleton): + """Class to store and load global configuration""" + + ############## IQL specific parameters ############## + gamma : float + """Discount factor""" + tau : float + """expectile regression parameter""" + inverse_temperature : float + """Inverse temperature for advantage-weighted regression""" + alpha : float + """target smoothing coefficient""" + batch_size : int + """Batch size for training""" + learning_rate : dict[str, float] + """Learning rate for the optimizer""" + max_steps : int + """Maximum number of training steps""" + target_update_freq : int + """Frequency of target network updates""" + sparse_reward : bool + """Flag to enable sparse reward""" + + offline_data_directory : str + """The offline data directory""" + offline_data_file : str + """The offline data file""" + + + ############## Environment specific parameters ############## + max_num_stores_loads: int + """The maximum number of loads in the nested loops""" + max_num_loops: int + """The max number of nested loops""" + max_num_load_store_dim: int + """The max number of dimensions in load/store buffers""" + num_tile_sizes: int + """The number of tile sizes""" + vect_size_limit: int + """Vectorization size limit to prevent large sizes vectorization""" + order: list[list[str]] + """The order of actions that needs to bo followed""" + interchange_mode: Literal['enumerate', 'pointers', 'continuous'] + """The method used for interchange action""" + exploration: list[Literal['entropy', 'epsilon']] + """The exploration method""" + init_epsilon: float + """The initial epsilon value for epsilon greedy exploration""" + + normalize_bounds: Literal['none', 'max', 'log'] + """Flag to indicate if the upper bounds in the input should be normalized or not""" + + + split_ops: bool + """Flag to enable splitting operations into separate benchmarks""" + + activation: Literal["relu", "tanh"] + """The activation function to use in the network""" + + benchmarks_folder_path: str + """Path to the benchmarks folder. Can be empty if optimization mode is set to "last".""" + + bench_count: int + """Number of batches in a trajectory""" + + truncate: int + """Maximum number of steps in the schedule""" + json_file: str + """Path to the JSON file containing the benchmarks execution times.""" + eval_json_file: str + """Path to the JSON file containing the benchmarks execution times for evaluation.""" + + tags: list[str] + """List of tags to add to the neptune experiment""" + + debug: bool + """Flag to enable debug mode""" + + exec_data_file: str + """Path to the file containing the execution data""" + results_dir: str + """Path to the results directory""" + + loaded: bool + """Flag to check if the config was already loaded from JSON file or not""" + + def __init__(self): + """Initialize the default values""" + # IQL specific parameters + self.gamma = 0.99 + self.tau = 0.7 + self.inverse_temperature = 3.0 + self.alpha = 0.005 + self.batch_size = 256 + self.learning_rate = { + "value": 3e-4, + "q": 3e-4, + "policy": 3e-4 + } + self.max_steps = 1000000 + self.target_update_freq = 1 + self.sparse_reward = True + + self.offline_data_directory = "./data" + self.offline_data_file = "offline_data.npz" + + # Environment specific parameters + self.max_num_stores_loads = 2 + self.max_num_loops = 4 + self.max_num_load_store_dim = 2 + self.num_tile_sizes = 2 + self.vect_size_limit = 16 + self.order = [] + self.interchange_mode = 'continuous' + self.exploration = ['entropy', 'epsilon'] + self.init_epsilon = 1.0 + self.normalize_bounds = 'log' + self.split_ops = False + self.activation = "relu" + self.benchmarks_folder_path = "./benchmarks" + self.bench_count = 1 + self.truncate = 20 + self.json_file = "./config/exec_times.json" + self.eval_json_file = "./config/exec_times.json" + self.tags = [] + self.debug = False + self.exec_data_file = "./data/exec_data.npz" + self.results_dir = "./results" + self.loaded = False + + def load_from_json(self): + """Load the configuration from the JSON file.""" + # Open the JSON file + with open(os.getenv("OFFLINE_RL_CONFIG_FILE_PATH"), "r") as f: + config = json.load(f) + # Set the configuration values + for key, value in config.items(): + if hasattr(self, key): + setattr(self, key, value) + + def to_dict(self): + """Convert the configuration to a dictionary.""" + return self.__dict__ + + def __str__(self): + """Convert the configuration to a string.""" + return str(self.to_dict()) \ No newline at end of file diff --git a/iql/policy.py b/iql/policy.py new file mode 100755 index 0000000..237acb7 --- /dev/null +++ b/iql/policy.py @@ -0,0 +1,84 @@ +import torch +import torch.nn as nn +from torch.distributions import Distribution +from typing import Optional, List, Type +from rl_autoschedular import config as cfg +from rl_autoschedular.actions import ActionSpace, Interchange +from rl_autoschedular.observation import Observation, ObservationPart + +# Match PPO’s activation config +ACTIVATION = nn.ReLU if cfg.activation == "relu" else nn.Tanh + + +class IQLPolicyModel(nn.Module): + """ + IQL policy network, sharing architecture with PPO’s PolicyModel. + - Backbone: 3×512 MLP with ACTIVATION() + - Heads: one for action selection + one per action’s parameterization + - Output: list[Distribution], via ActionSpace.distributions + - Loss: BC loss with advantage-weighted log-likelihood (AWAC / IQL style) + """ + + def __init__(self, obs_parts: List[Type[ObservationPart]]): + super().__init__() + self.obs_parts = obs_parts + + + # Shared encoder + in_size = sum(part.size() for part in obs_parts) + self.backbone = nn.Sequential( + nn.Linear(in_size, 512), + ACTIVATION(), + nn.Linear(512, 512), + ACTIVATION(), + nn.Linear(512, 512), + ACTIVATION(), + ) + + # One head for action choice + one for each action’s params + output_sizes = [ActionSpace.size()] + [ + action.network_output_size() for action in ActionSpace.supported_actions + ] + self.heads_attributes = [f"head_{i}" for i in range(len(output_sizes))] + + for head_attr, output_size in zip(self.heads_attributes, output_sizes): + if not output_size: + setattr(self, head_attr, None) + continue + + head = nn.Linear(512, output_size) + if cfg.new_architecture: + head = nn.Sequential(nn.Linear(512, 512), ACTIVATION(), head) + setattr(self, head_attr, head) + + def forward(self, obs: torch.Tensor) -> List[Optional[Distribution]]: + """ + Forward pass: produce a Distribution object per action head. + """ + embedded = self.backbone(Observation.get_parts(obs, *self.obs_parts)) + heads: List[Optional[nn.Module]] = [getattr(self, attr) for attr in self.heads_attributes] + actions_logits = [head(embedded) if head else None for head in heads] + return ActionSpace.distributions(obs, *actions_logits) + + def loss( + self, + actions_log_p: torch.Tensor, + advantages: torch.Tensor, + beta: float = 1.0, + ) -> torch.Tensor: + """ + Advantage-weighted behavioral cloning (AWAC) / IQL policy loss. + + Args: + obs: Observations [B, ...] + actions_log_p: log π(a|s) from this policy, evaluated at dataset actions + advantages: Advantage estimates A(s,a) from IQL (Q - V) + beta: Temperature scaling (larger beta = more deterministic) + + Returns: + Scalar loss tensor. + """ + # Weights = exp(A / beta), clipped for stability + weights = torch.exp(advantages / beta).clamp(max=100.0) + loss = -(weights * actions_log_p).mean() + return loss diff --git a/iql/q_functions.py b/iql/q_functions.py new file mode 100755 index 0000000..a3c6bbd --- /dev/null +++ b/iql/q_functions.py @@ -0,0 +1,282 @@ +from dotenv import load_dotenv +load_dotenv(override=True) + + +import torch +import torch.nn as nn +from typing import List, Optional, Tuple, Type + +# reuse same Observation/ActionSpace imports you have in repo +from rl_autoschedular.observation import Observation, ObservationPart, OpFeatures, ActionHistory +from rl_autoschedular.actions import ActionSpace +from rl_autoschedular.state import OperationState +from rl_autoschedular import config as cfg + +# Keep same activation used elsewhere +ACTIVATION = nn.ReLU # replace with your actual ACTIVATION if different + + +class _DiscreteQHead(nn.Module): + """A head that outputs Q-values for each discrete option (flat logits -> Q-values).""" + def __init__(self, in_dim: int, out_dim: int): + super().__init__() + # single linear mapping from embedding -> out_dim Q-values + self.net = nn.Linear(in_dim, out_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # returns [B, out_dim] + return self.net(x) + + +class _TwinHiearchicalQNetwork(nn.Module): + """ + One Q network: + - backbone: embed observation -> [B, embed_dim] + - head_action: Q for selecting each action type [B, |A|] + - param_heads: for each action, a head producing Q-values for that action's network_output_size + NOTE: For multi-slot actions (e.g. Tiling), param_head outputs are flat: a concatenation + of per-slot Q-values. We'll reshape inside q_contribs when computing param-contribution. + """ + def __init__(self, obs_parts: List[Type[ObservationPart]], embed_dim: int = 512): + super().__init__() + self.obs_parts = obs_parts + in_size = sum(p.size() for p in obs_parts) + + # Backbone + self.backbone = nn.Sequential( + nn.Linear(in_size, embed_dim), + ACTIVATION(), + nn.Linear(embed_dim, embed_dim), + ACTIVATION(), + nn.Linear(embed_dim, embed_dim), + ACTIVATION(), + ) + + # Action selection Q head + self.head_action = _DiscreteQHead(embed_dim, ActionSpace.size()) + + # Parameter heads (one per supported action) + self.param_heads = nn.ModuleList() + # Save meta per action for convenient reshaping later + self._param_meta: List[Optional[dict]] = [] + for action_cls in ActionSpace.supported_actions: + out_dim = action_cls.network_output_size() + params_size = action_cls.params_size() + if out_dim and out_dim > 0: + head = _DiscreteQHead(embed_dim, out_dim) + self.param_heads.append(head) + # if multi-slot, compute classes_per_slot = out_dim // params_size (integer) + if params_size > 0: + classes_per_slot = None + if params_size > 0: + classes_per_slot = out_dim // params_size if params_size != 0 else None + self._param_meta.append({ + "params_size": params_size, # number of slots for this action + "out_dim": out_dim, + "classes_per_slot": classes_per_slot # number of classes per slot (None if single slot) + # Example : Interchange -> params_size=1 , out_dim= 7 , classes_per_slot= 7 , 7 loop choices for the current interchange + }) + else: + self._param_meta.append({"params_size": 0, "out_dim": out_dim, "classes_per_slot": None}) + else: + # no parameters -> placeholder (we'll treat as None) + # ( NT | V ) + self.param_heads.append(nn.Identity()) + self._param_meta.append(None) + + def forward( + self, + obs: torch.Tensor, + action_idx: torch.LongTensor, + param_indices: Optional[List[Optional[torch.LongTensor]]] = None + ) -> torch.Tensor: + emb = self._embed(obs) + return self.q_contribs(emb, action_idx, param_indices) + + def _embed(self, obs: torch.Tensor) -> torch.Tensor: + parts = Observation.get_parts(obs, *self.obs_parts) # returns [B, in_size] + return self.backbone(parts) # [B, embed_dim] + + def q_contribs( + self, + emb: torch.Tensor, + action_idx: torch.LongTensor, # [B] + param_indices: Optional[List[Optional[torch.LongTensor]]] = None # list of length B + ) -> torch.Tensor: + """ + Compute Q(s, a, params) as: Q_action(s)[a] + Q_params(s, a, params). + - emb: [B, embed_dim] + - action_idx: [B] integers in [0..|A|-1] + - param_indices: list of length B (each either None or [params_size]) + + Returns: + q_total: [B] - scalar Q for each sample + """ + B = emb.size(0) + device = emb.device + + # ---- top-level action contribution ---- + act_qs = self.head_action(emb) # [B, |A|] + act_q = act_qs.gather(1, action_idx.view(-1, 1)).squeeze(1) # [B] + + # ---- parameter contribution ---- + param_q = torch.zeros(B, device=device) + + # group samples by chosen action to do batched head computation + for k, head in enumerate(self.param_heads): + meta = self._param_meta[k] + if isinstance(head, nn.Identity) or (meta is None): + continue + + # mask = all samples where chosen action == k + mask = (action_idx == k) + if not mask.any(): + continue + + # get embeddings and their chosen param indices for this action + emb_masked = emb[mask] + head_out = head(emb_masked) # [N_mask, out_dim_k] + + psize = meta["params_size"] # + out_dim = meta["out_dim"] + cps = meta["classes_per_slot"] + + # collect just the param indices for the masked samples + masked_params = [param_indices[i] for i in range(B) if mask[i]] # [N_mask, Optional[LongTensor] of consistet size psize] + # they should all be not None if this action has params + assert all((p is not None) for p in masked_params) or psize == 0 + + if psize == 0: + continue + + if psize == 1: + # single-slot + idx_tensor = torch.stack([p.view(-1)[0] for p in masked_params]).view(-1, 1) # [N_mask, 1] + q_k = head_out.gather(1, idx_tensor).squeeze(1) # [N_mask] + param_q[mask] = q_k + else: + # multi-slot + assert cps is not None and cps > 0, "classes_per_slot unknown for multi-slot head" + + reshaped = head_out.view(-1, psize, cps) # [N_mask, psize, cps] + idx_tensor = torch.stack(masked_params).long() # [N_mask, psize] + idx_exp = idx_tensor.unsqueeze(-1) # [N_mask, psize, 1] + gathered = torch.gather(reshaped, dim=2, index=idx_exp).squeeze(-1) # [N_mask, psize] + q_k = gathered.sum(dim=1) # [N_mask] + param_q[mask] = q_k + + return act_q + param_q + + + + + +# ---------------- hierarchical double Q network ---------------- +class IQLTwinQ(nn.Module): + """ + Top-level twin Q network that matches style of PolicyModel: + - builds two Q-branches (Q1 and Q2), each using _TwinHiearchicalQNetwork + - provides helpers to split flat action-tensor into action_idx + param slices + """ + def __init__(self, obs_parts: List[Type[ObservationPart]], embed_dim: int = 512): + super().__init__() + self.obs_parts = obs_parts + # instantiate two Q heads (Q1, Q2) + self.q1 = _TwinHiearchicalQNetwork(obs_parts, embed_dim=embed_dim) + self.q2 = _TwinHiearchicalQNetwork(obs_parts, embed_dim=embed_dim) + + + def forward(self, obs: torch.Tensor, index: torch.LongTensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Compute Q1(s, a, params) and Q2(s, a, params) for a batch. + - obs: observation tensor [B, ...] + - index: flat index tensor from ActionSpace.sample() -> [B, 1 + sum(params)] + Returns: + (q1_vals, q2_vals) each shaped [B] + """ + action_idx, params_list = self._split_action_tensor(index) + q1_vals = self.q1(obs, action_idx, params_list) + q2_vals = self.q2(obs, action_idx, params_list) + return q1_vals, q2_vals + + def q_values(self, obs: torch.Tensor, index: torch.LongTensor) -> torch.Tensor: + """ + Compute min(Q1(s, a, params), Q2(s, a, params)) for a batch. + - obs: observation tensor [B, ...] + - index: flat index tensor from ActionSpace.sample() -> [B, 1 + sum(params)] + Returns: + q_vals shaped [B] + """ + action_idx, params_list = self._split_action_tensor(index) + q1_vals = self.q1(obs, action_idx, params_list) + q2_vals = self.q2(obs, action_idx, params_list) + return torch.min(q1_vals, q2_vals) + + + def loss(self, obs: torch.Tensor, index: torch.LongTensor, target_q: torch.Tensor) -> torch.Tensor: + """ + Compute MSE loss between Q1, Q2 and target_q. + - obs: observation tensor [B, ...] + - index: flat index tensor from ActionSpace.sample() -> [B, 1 + sum(params)] + - target_q: target Q-values [B] + Returns: + scalar loss + """ + q1_vals, q2_vals = self.forward(obs, index) # each [B] + loss_fn = nn.MSELoss() + loss1 = loss_fn(q1_vals, target_q) + loss2 = loss_fn(q2_vals, target_q) + return loss1 + loss2 + + @staticmethod + def _split_action_tensor(index: torch.LongTensor) -> Tuple[torch.LongTensor, List[Optional[torch.LongTensor]]]: + """ + Split the `index` tensor returned by ActionSpace.sample() into: + - action_idx: [B] + - params: list of length B, each either None (no params) or LongTensor [params_size] for that action + """ + B = index.size(0) + device = index.device + + action_idx = index[:, 0].long() # [B] + cum = ActionSpace.cumulative_params_sizes() + + params: List[Optional[torch.LongTensor]] = [] + + for i in range(B): + a_idx = action_idx[i].item() + action_type = ActionSpace.supported_actions[a_idx] + if action_type.params_size() == 0: + params.append(None) + else: + start, end = cum[a_idx], cum[a_idx + 1] + # extract just that sample's params for its chosen action + params.append(index[i, start:end].long()) + + return action_idx, params + + + +def main(): + model = IQLTwinQ([OpFeatures, ActionHistory]) + + _model = _TwinHiearchicalQNetwork([OpFeatures, ActionHistory]) + + x = torch.tensor([[2, 1, 5, 7, 0, 0, 0, 0, 3, 4, 2, 0, 0, 0, 0, 2]]).float() + + + action_idx , param_idx = model._split_action_tensor(x) + + + obs = torch.zeros([1, 2152]) + + + q = _model(obs, action_idx, param_idx) + + print("Q-values:", q) + + + +if __name__ == "__main__": + + main() \ No newline at end of file diff --git a/iql/singleton.py b/iql/singleton.py new file mode 100755 index 0000000..acc0a41 --- /dev/null +++ b/iql/singleton.py @@ -0,0 +1,8 @@ +class Singleton(type): + """Meta class to create a singleton instance of a class""" + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] \ No newline at end of file diff --git a/iql/value_function.py b/iql/value_function.py new file mode 100755 index 0000000..d12deb7 --- /dev/null +++ b/iql/value_function.py @@ -0,0 +1,71 @@ +import torch +import torch.nn as nn +from typing import List, Type +from rl_autoschedular import config as cfg +from rl_autoschedular.observation import Observation, ObservationPart + + +ACTIVATION = nn.ReLU + + +class IQLValueModel(nn.Module): + """ + IQL Value function with the SAME encoder/MLP layout as PPO's ValueModel: + Linear(sum(obs_parts)->512) -> ACT -> 512 -> ACT -> 512 -> ACT -> 1 + + - Input: full Observation tensor, then sliced via Observation.get_parts to match PPO. + - Output: V(s) as shape [B], same squeeze(-1) behavior as PPO. + - Loss: Expectile regression with parameter tau (IQL). + """ + + def __init__( + self, + obs_parts: List[Type[ObservationPart]], + tau: float = 0.7, + ): + super().__init__() + self.obs_parts = obs_parts + self.tau = cfg.tau # consider wiring this from cfg (e.g., cfg.iql.tau) if you keep hyperparams in config + + in_size = sum(part.size() for part in obs_parts) + self.network = nn.Sequential( + nn.Linear(in_size, 512), + ACTIVATION(), + nn.Linear(512, 512), + ACTIVATION(), + nn.Linear(512, 512), + ACTIVATION(), + nn.Linear(512, 1), + ) + + def forward(self, obs: torch.Tensor) -> torch.Tensor: + """ + Args: + obs: full Observation tensor (like in PPO) + Returns: + V(s) as [B] + """ + x = Observation.get_parts(obs, *self.obs_parts) + return self.network(x).squeeze(-1) # [B] + + @torch.no_grad() + def v(self, obs: torch.Tensor) -> torch.Tensor: + """Convenience alias often used in IQL codepaths.""" + return self.forward(obs) + + def loss(self, obs: torch.Tensor, q_values: torch.Tensor) -> torch.Tensor: + """ + Expectile regression loss: minimize E[ w_tau(u) * u^2 ], u = Q(s,a) - V(s) + + Args: + obs: full Observation tensor for states [B, ...] (same as PPO input) + q_values: [B] or [B,1] tensor with target Q(s,a) (DETACHED upstream in IQL) + """ + v = self.forward(obs) # [B] + q = q_values.squeeze(-1) # [B] + diff = q - v # u + + # weight = |tau - 1(u < 0)| + # same as: tau if u >= 0 else (1 - tau) + weight = torch.abs(self.tau - (diff < 0).float()) + return (weight * diff.pow(2)).mean() diff --git a/iql_online.py b/iql_online.py new file mode 100644 index 0000000..c3d1234 --- /dev/null +++ b/iql_online.py @@ -0,0 +1,230 @@ +import os +import time +import torch +import numpy as np +from tqdm import trange + +import dotenv +dotenv.load_dotenv() + +from rl_autoschedular import config as cfg, file_logger as fl +from rl_autoschedular.env import Env +from rl_autoschedular.actions import ActionSpace +from rl_autoschedular.observation import Observation, OpFeatures, ActionHistory +from iql.iql_agent import IQLAgent +from utils.data_collector import OfflineDataset + + +device = torch.device("cpu") + + +def load_offline_dataset(): + """Load offline dataset for warm-starting replay buffer.""" + dataset = OfflineDataset( + save_dir=cfg.offline_data_save_dir, + fname=cfg.offline_data_file + ).load() + + if not dataset: + raise FileNotFoundError(f"Offline dataset not found: {cfg.offline_data_file}") + + states = torch.tensor(dataset["obs"], dtype=torch.float32) + actions = torch.tensor(dataset["actions"], dtype=torch.long) + rewards = torch.tensor(dataset["rewards"], dtype=torch.float32) + next_states = torch.tensor(dataset["next_obs"], dtype=torch.float32) + dones = torch.tensor(dataset["dones"], dtype=torch.float32) + + return states, actions, rewards, next_states, dones + + +@torch.no_grad() +def evaluate_benchmarks(model: IQLAgent, env: Env, step: int): + """Evaluate model performance across all benchmarks.""" + env_time = 0.0 + eps = None + all_speedups, all_entropies = [], [] + + for _ in trange(cfg.bench_count, desc="Eval Trajectory", leave=False): + t0 = time.perf_counter() + state = env.reset() + env_time += time.perf_counter() - t0 + bench_done, speedup = False, None + bench_rewards, bench_entropies = [], [] + bench_name = state.bench_name + + while not bench_done: + obs = Observation.from_state(state) + action_index, action_log_p, entropy = model.sample(obs.to(device), eps=eps) + action = ActionSpace.action_by_index(action_index[0], state) + + t0 = time.perf_counter() + next_state, reward, op_done, speedup = env.step(state, action) + env_time += time.perf_counter() - t0 + + if op_done: + t0 = time.perf_counter() + next_state, bench_done = env.get_next_op_state(next_state) + env_time += time.perf_counter() - t0 + + bench_rewards.append(reward) + bench_entropies.append(entropy.item()) + state = next_state + + # per-benchmark logs + fl.log_scalars(f"eval/{bench_name}", { + "mean_reward": float(np.mean(bench_rewards)) if bench_rewards else 0.0, + "mean_entropy": float(np.mean(bench_entropies)) if bench_entropies else 0.0, + "final_speedup": speedup if speedup is not None else 0.0, + }, step) + + all_speedups.append(speedup) + all_entropies.extend(bench_entropies) + + # global logs + if all_speedups: + fl.log_scalar("eval/average_speedup", float(np.mean(all_speedups)), step) + if all_entropies: + fl.log_scalar("eval/average_entropy", float(np.mean(all_entropies)), step) + + return env_time + + +class ReplayBuffer: + """Simple replay buffer mixing offline + online data.""" + def __init__(self, max_size=1000000): + self.states, self.actions, self.rewards, self.next_states, self.dones = [], [], [], [], [] + self.max_size = max_size + + def add(self, s, a, r, ns, d): + if len(self.states) >= self.max_size: + # drop oldest + self.states.pop(0) + self.actions.pop(0) + self.rewards.pop(0) + self.next_states.pop(0) + self.dones.pop(0) + + self.states.append(s.squeeze(0)) + self.actions.append(a.squeeze(0)) + self.rewards.append(r.squeeze(0)) + self.next_states.append(ns.squeeze(0)) + self.dones.append(d) + + def sample(self, batch_size): + idxs = np.random.randint(0, len(self.states), size=batch_size) + return ( + torch.stack([self.states[i] for i in idxs]), + torch.stack([self.actions[i] for i in idxs]), + torch.stack([self.rewards[i] for i in idxs]), + torch.stack([self.next_states[i] for i in idxs]), + torch.stack([self.dones[i] for i in idxs]), + ) + + def __len__(self): + return len(self.states) + + +def hybrid_finetune(): + # === Load pretrained agent === + agent = IQLAgent(cfg, device, obs_parts=[OpFeatures, ActionHistory]) + ckpt_path = "./iql_results/iql_step_17999.pt" + if os.path.exists(ckpt_path): + agent.load_state_dict(torch.load(ckpt_path, map_location=device)) + print(f"Loaded pretrained checkpoint: {ckpt_path}") + else: + raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}") + + # === Init Replay Buffer with offline data === + buffer = ReplayBuffer(max_size=200000) + states, actions, rewards, next_states, dones = load_offline_dataset() + for s, a, r, ns, d in zip(states, actions, rewards, next_states, dones): + buffer.add(s, a, r, ns, d) + print(f"Replay buffer initialized with {len(buffer)} offline samples") + + # environments + train_env = Env(is_training=True, run_name=cfg.run_name) + eval_env = Env(is_training=False, run_name=cfg.run_name) + + print("Starting HYBRID fine-tuning (offline + online)...") + start_time = time.time() + state = train_env.reset() + + hybrid_trange = trange(cfg.max_steps, desc="Hybrid Fine-tuning", dynamic_ncols=True) + for step in hybrid_trange: + # reset benchmark + state = train_env.reset() + done = False + + while not done: + # current obs + obs = Observation.from_state(state) + + # agent picks action + action_index, _, _ = agent.sample(obs.to(device), eps=None) + action = ActionSpace.action_by_index(action_index[0], state) + + # env step + next_state, reward, op_done, _ = train_env.step(state, action) + + # build next_obs BEFORE advancing benchmark + next_obs = Observation.from_state(next_state) + + # if op finished, advance to next op or benchmark end + if op_done: + next_state, done = train_env.get_next_op_state(next_state) + + # push transition to replay buffer + buffer.add( + obs.to(device), + action_index, + torch.tensor(reward, dtype=torch.float32, device=device), + next_obs.to(device), + torch.tensor(done, dtype=torch.float32, device=device), + ) + + # move forward + state = next_state + + # after benchmark, do 1 gradient update + batch = buffer.sample(cfg.batch_size) + losses = agent.update(batch) + + # logging + if step % 50 == 0: + fl.log_scalars("hybrid_train", losses, step) + + + + # logging + if step % 50 == 0: + fl.log_scalars("hybrid_train", losses, step) + + if (step + 1) % 100 == 0: + elapsed = time.time() - start_time + hybrid_trange.set_postfix({ + "Value Loss": f"{losses['value']:.4f}", + "Q Loss": f"{losses['q']:.4f}", + "Policy Loss": f"{losses['policy']:.4f}", + "Elapsed": f"{elapsed:.2f}s" + }) + + if (step + 1) % 5000 == 0: + print("Evaluating on benchmarks ...") + eval_start = time.time() + env_time = evaluate_benchmarks(agent, eval_env, step) + print(f"Evaluation done in {time.time() - eval_start:.2f}s (env time: {env_time:.2f}s)") + fl.flush() + + if (step + 1) % 2000 == 0: + os.makedirs(cfg.results_dir, exist_ok=True) + save_path = os.path.join(cfg.results_dir, f"iql_hybrid_step_{step}.pt") + torch.save(agent.state_dict(), save_path) + print(f"Checkpoint saved: {save_path}") + + state = next_state + + print(f"Hybrid fine-tuning finished in {time.time() - start_time:.2f} seconds.") + + +if __name__ == "__main__": + hybrid_finetune() diff --git a/models/.gitignore b/models/.gitignore old mode 100644 new mode 100755 index 86d0cb2..44c5ea8 --- a/models/.gitignore +++ b/models/.gitignore @@ -1,4 +1,4 @@ -# Ignore everything in this directory -* -# Except this file +# Ignore everything in this directory +* +# Except this file !.gitignore \ No newline at end of file diff --git a/neptune_sync.py b/neptune_sync.py old mode 100644 new mode 100755 index 99785ee..c8f1efa --- a/neptune_sync.py +++ b/neptune_sync.py @@ -1,72 +1,72 @@ -# Load environment variables -from dotenv import load_dotenv -load_dotenv(override=True) - -import neptune -from neptune import Run -import os -import time -import signal - -results_dir = 'results' -ids_file = os.path.join(results_dir, 'synced_ids') -if os.path.exists(ids_file): - with open(ids_file, 'r') as f: - synced_ids = [int(id) for id in f.readlines() if id.strip()] -else: - synced_ids = [] - -current_runs = [d for d in os.listdir(results_dir) if d.startswith('run_') and int(d.split('_')[1]) not in synced_ids] - -if not current_runs: - print('No new runs to sync') - exit() -print(f'Syncing runs: {current_runs}') - -with open(ids_file, 'a') as f: - f.write('\n'.join(run.split('_')[1] for run in current_runs)) - f.write('\n') - -neptune_runs: dict[str, Run] = {} -for run in current_runs: - run_path = os.path.join(results_dir, run) - with open(os.path.join(run_path, 'tags'), 'r') as f: - tags = f.read().splitlines() - neptune_run = neptune.init_run( - project=os.getenv('NEPTUNE_PROJECT'), - tags=tags, - ) - neptune_runs[run] = neptune_run - -runs_counters: dict[str, dict[str, int]] = {run: {} for run in current_runs} - - -def kill_handler(signum, frame): - print('Killing...') - for runs in neptune_runs.values(): - runs.stop() - exit() - - -signal.signal(signal.SIGTERM, kill_handler) - -while True: - print('Syncing...') - for run in current_runs: - neptune_run = neptune_runs[run] - run_path = os.path.join(results_dir, run, 'logs') - files: list[str] = [] - for root, _, filenames in os.walk(run_path): - relative_root = root.replace(run_path, '') - relative_root = relative_root[1:] if relative_root.startswith('/') else relative_root - for filename in filenames: - files.append(os.path.join(relative_root, filename) if relative_root else filename) - for file in files: - if file not in runs_counters[run]: - runs_counters[run][file] = 0 - read_idx = runs_counters[run][file] - with open(os.path.join(run_path, file), 'r') as f: - values = [float(line) for line in f.readlines()] - neptune_run[file].extend(values[read_idx:]) - runs_counters[run][file] = len(values) - time.sleep(60) +# Load environment variables +from dotenv import load_dotenv +load_dotenv(override=True) + +import neptune +from neptune import Run +import os +import time + +results_dir = 'results' +with open(os.path.join(results_dir, 'synced_ids'), 'r') as f: + synced_ids = [int(id) for id in f.readlines() if id.strip()] + +current_runs = [d for d in os.listdir(results_dir) if d.startswith('run_') and int(d.split('_')[1]) not in synced_ids] + +if not current_runs: + print('No new runs to sync') + exit() +print(f'Syncing runs: {current_runs}') + +with open(os.path.join(results_dir, 'synced_ids'), 'a') as f: + f.write('\n'.join(run.split('_')[1] for run in current_runs)) + f.write('\n') + +neptune_runs: dict[str, Run] = {} +for run in current_runs: + run_path = os.path.join(results_dir, run) + with open(os.path.join(run_path, 'tags'), 'r') as f: + tags = f.read().splitlines() + neptune_run = neptune.init_run( + project=os.getenv('NEPTUNE_PROJECT'), + api_token=os.getenv('NEPTUNE_TOKEN'), + tags=tags, + ) + neptune_runs[run] = neptune_run + +runs_counters: dict[str, dict[str, int]] = {run: {} for run in current_runs} + + +def kill_handler(signum, frame): + print('Killing...') + for runs in neptune_runs.values(): + runs.stop() + exit() + + +if __name__ == '__main__': + signal.signal(signal.SIGINT, kill_handler) + signal.signal(signal.SIGTERM, kill_handler) + + while True: + print('Syncing...') + for run in current_runs: + neptune_run = neptune_runs[run] + run_path = os.path.join(results_dir, run) + files: list[str] = [] + for root, _, filenames in os.walk(run_path): + relative_root = root.replace(run_path, '') + relative_root = relative_root[1:] if relative_root.startswith('/') else relative_root + for filename in filenames: + files.append(os.path.join(relative_root, filename) if relative_root else filename) + for file in files: + if file == 'tags': + continue + if file not in runs_counters[run]: + runs_counters[run][file] = 0 + read_idx = runs_counters[run][file] + with open(os.path.join(run_path, file), 'r') as f: + values = [float(line) for line in f.readlines()] + neptune_run[file].extend(values[read_idx:]) + runs_counters[run][file] = len(values) + time.sleep(60) diff --git a/requirements.txt b/requirements.txt old mode 100644 new mode 100755 index 15a4965..0114d2d --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,5 @@ -numpy -python-dotenv -torch -neptune -tqdm -dask-jobqueue -typeguard +numpy +python-dotenv +torch +neptune +tqdm \ No newline at end of file diff --git a/results/.gitignore b/results/.gitignore deleted file mode 100644 index c96a04f..0000000 --- a/results/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -* -!.gitignore \ No newline at end of file diff --git a/rl_autoschedular/__init__.py b/rl_autoschedular/__init__.py old mode 100644 new mode 100755 index f0ff98e..aa99fa7 --- a/rl_autoschedular/__init__.py +++ b/rl_autoschedular/__init__.py @@ -1,3 +1,26 @@ -import torch - -device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") +from utils.config import Config +from utils.file_logger import TensorBoardLogger +import torch + +device = torch.device("cuda") + +# Load global configuration +config = Config() +if not config.loaded: + config.load_from_json() + +# Pass run_name explicitly (from config or CLI) +file_logger = TensorBoardLogger( + log_dir=config.results_dir, + run_name=config.run_name, + tags=['iql'] + config.tags +) + +offline_data_collector = None + +if config.collect_offline_data: + from utils.data_collector import OfflineDataset + offline_data_collector = OfflineDataset( + save_dir=config.offline_data_save_dir, + fname=config.offline_data_file + ) \ No newline at end of file diff --git a/rl_autoschedular/actions/__init__.py b/rl_autoschedular/actions/__init__.py old mode 100644 new mode 100755 index c5f45eb..b9bf1ff --- a/rl_autoschedular/actions/__init__.py +++ b/rl_autoschedular/actions/__init__.py @@ -1,244 +1,242 @@ -from utils.config import Config -from .base import Action -from .no_transformation import NoTransformation -from .tiling import Tiling -from .tiled_parallelization import TiledParallelization -from .tiled_fusion import TiledFusion -from .interchange import Interchange -from .vectorization import Vectorization -from rl_autoschedular.state import OperationState -import torch -from torch.distributions import Distribution, Categorical -from typing import Optional - - -class ActionSpace: - """Class holding information about the action space""" - - supported_actions: list[type[Action]] = [ - NoTransformation, - Tiling, - TiledParallelization, - TiledFusion, - Interchange, - Vectorization - ] - - @classmethod - def size(cls): - return len(cls.supported_actions) - - @classmethod - def cumulative_params_sizes(cls): - sizes: list[int] = [1] - for trans in cls.supported_actions: - sizes.append(sizes[-1] + trans.params_size()) - return sizes - - @classmethod - def cumulative_mask_sizes(cls): - sizes: list[int] = [cls.size()] - for trans in cls.supported_actions: - sizes.append(sizes[-1] + trans.mask_size()) - return sizes - - @classmethod - def cumulative_history_sizes(cls): - sizes: list[int] = [0] - for trans in cls.supported_actions: - sizes.append(sizes[-1] + trans.history_size()) - return sizes - - @classmethod - def action_by_index(cls, index: torch.Tensor, state: OperationState) -> Action: - action_idx = int(index[0].item()) - action_type = cls.supported_actions[action_idx] - if not action_type.params_size(): - return action_type(state) - - cum_sizes = cls.cumulative_params_sizes() - params = index[cum_sizes[action_idx]:cum_sizes[action_idx + 1]].long().tolist() - return action_type(params, state) - - @classmethod - def action_number(cls, action_type: type[Action]) -> int: - return cls.supported_actions.index(action_type) - - @classmethod - def action_type_by_symbol(cls, symbol: str) -> type[Action]: - for action in cls.supported_actions: - if action.symbol == symbol: - return action - - raise ValueError(f"action symbol '{symbol}' not supported") - - @classmethod - def action_number_by_symbol(cls, symbol: str) -> int: - return cls.action_number(cls.action_type_by_symbol(symbol)) - - @classmethod - def action_mask(cls, state: OperationState) -> torch.Tensor: - cfg = Config() - mask = torch.zeros(cls.size(), dtype=torch.bool) - - def allow_action(a: type[Action]): - if a.is_allowed(state): - mask[cls.action_number(a)] = True - - def allow_all(): - for action in cls.supported_actions: - allow_action(action) - - # If state is terminal don't allow any further actions - if not state.terminal: - if Interchange.incomplete_interchange(state): - # Special case where interchange isn't complete yet - mask[cls.action_number(Interchange)] = True - elif cfg.order: - # Enforce order if provided - if state.step_count >= len(cfg.order): - raise Exception("actions order must be ended with a terminal action") - if not cfg.order[state.step_count]: - # If at current step nothing is specified, allow everything - allow_all() - for s in cfg.order[state.step_count]: - allow_action(cls.action_type_by_symbol(s)) - else: - # If none of the above applies, allow everything - allow_all() - - # Check that there is at least one action allowed - if not mask.any(): - raise Exception(f"no actions allowed for the current state at step {state.step_count}") - - for action in cls.supported_actions: - action_mask = action.action_mask(state) - if action_mask is None: - continue - - mask = torch.cat((mask, action_mask)) - - return mask - - @classmethod - def action_history(cls, state: OperationState) -> torch.Tensor: - history = [] - for action in cls.supported_actions: - action_history = action.action_history(state) - if action_history is None: - continue - history.append(action_history) - if not history: - return torch.tensor([]) - - return torch.cat(history) - - @classmethod - def distributions(cls, obs: torch.Tensor, selection_logits: torch.Tensor, *actions_logits: Optional[torch.Tensor]) -> list[Optional[Distribution]]: - """Create a list of distributions for the actions based on the logits. - - Args: - obs (torch.Tensor): Observation tensor. - selection_logits (torch.Tensor): Logits for action selection. - *actions_logits (torch.Tensor): Logits for each action's parameters. - - Returns: - list[Distribution]: List of distributions for each action. - """ - from rl_autoschedular.observation import Observation, ActionMask - - actions_mask = Observation.get_part(obs, ActionMask).bool() - dists_list: list[Optional[Distribution]] = [ - Categorical(logits=selection_logits.where(actions_mask[:, :cls.size()], -torch.inf)) - ] - cum_sizes = cls.cumulative_mask_sizes() - for i, action in enumerate(cls.supported_actions): - if not action.mask_size(): - dists_list.append(None) - continue - - assert actions_logits[i] is not None, f"action '{action.symbol}' must have logits" - masked_logits = actions_logits[i].where(actions_mask[:, cum_sizes[i]:cum_sizes[i + 1]], -torch.inf) - dists_list.append(action.distribution(masked_logits)) - - return dists_list - - @classmethod - def uniform_distributions(cls, obs: torch.Tensor) -> list[Optional[Distribution]]: - """Create a list of uniform distributions for the actions based on the observation. - - Args: - obs (torch.Tensor): Observation tensor. - - Returns: - list[Distribution]: List of distributions for each action. - """ - from rl_autoschedular.observation import Observation, ActionMask, NumLoops - - actions_mask = Observation.get_part(obs, ActionMask).bool() - num_loops = Observation.get_part(obs, NumLoops) - selection_mask = actions_mask[:, :cls.size()] - dists_list: list[Optional[Distribution]] = [ - Categorical(logits=torch.zeros_like(selection_mask).where(selection_mask, -torch.inf)) - ] - cum_sizes = cls.cumulative_mask_sizes() - for i, action in enumerate(cls.supported_actions): - if not action.mask_size(): - dists_list.append(None) - continue - - action_mask = actions_mask[:, cum_sizes[i]:cum_sizes[i + 1]] - logits = torch.zeros_like(action_mask).where(action_mask, -torch.inf) - dists_list.append(action.uniform_distribution(logits, num_loops)) - - return dists_list - - @classmethod - def distributions_stats(cls, distributions: list[Optional[Distribution]], index: torch.Tensor, eps_distributions: Optional[list[Optional[Distribution]]] = None, eps: Optional[float] = None) -> tuple[torch.Tensor, torch.Tensor]: - eps_distributions: list[Optional[Distribution]] = eps_distributions or [None] * len(distributions) - - selection_index = index[:, 0] - selection_dist = distributions[0] - selection_eps_dist = eps_distributions[0] - selection_log_p = selection_dist.log_prob(selection_index) - if eps is not None: - selection_eps_log_p = selection_eps_dist.log_prob(selection_index) - selection_log_p = (selection_log_p.exp() * (1 - eps) + selection_eps_log_p.exp() * eps).log() - - cum_sizes = cls.cumulative_params_sizes() - actions_log_p, entropies = selection_log_p, selection_dist.entropy() - for i, (action, dist, eps_dist) in enumerate(zip(cls.supported_actions, distributions[1:], eps_distributions[1:])): - if dist is None: - continue - - action_index = index[:, cum_sizes[i]:cum_sizes[i + 1]] - action_log_p, entropy = action.distribution_stats(dist, action_index, eps_dist, eps) - actions_log_p[selection_index == i] += action_log_p[selection_index == i] - entropies[selection_index == i] += entropy[selection_index == i] - - return actions_log_p, entropies - - @classmethod - def sample(cls, obs: torch.Tensor, distributions: list[Optional[Distribution]], eps_distributions: list[Optional[Distribution]], uniform: bool = False, greedy: bool = False) -> torch.Tensor: - assert not uniform or not greedy, "can't sample uniformly and greedily at once" - from rl_autoschedular.observation import Observation, NumLoops - - num_loops = Observation.get_part(obs, NumLoops) - selection_dist = distributions[0] - selection_eps_dist = eps_distributions[0] - - if greedy: - selection_index = selection_dist.probs.argmax(-1) - elif uniform: - selection_index = selection_eps_dist.sample() - else: - selection_index = selection_dist.sample() - - index = selection_index.unsqueeze(-1) - for action, dist, eps_dist in zip(cls.supported_actions, distributions[1:], eps_distributions[1:]): - if dist is None: - continue - - index = torch.cat((index, action.sample(dist, eps_dist, num_loops, uniform, greedy)), dim=1) - - return index +from .base import Action +from .no_transformation import NoTransformation +from .tiling import Tiling +from .tiled_parallelization import TiledParallelization +from .tiled_fusion import TiledFusion +from .interchange import Interchange +from .vectorization import Vectorization +from rl_autoschedular import config as cfg +from rl_autoschedular.state import OperationState +import torch +from torch.distributions import Distribution, Categorical +from typing import Optional + + +class ActionSpace: + """Class holding information about the action space""" + + supported_actions: list[type[Action]] = [ + NoTransformation, + Tiling, + TiledParallelization, + Interchange, + Vectorization + ] + + @classmethod + def size(cls): + return len(cls.supported_actions) + + @classmethod + def cumulative_params_sizes(cls): + sizes: list[int] = [1] + for trans in cls.supported_actions: + sizes.append(sizes[-1] + trans.params_size()) + return sizes + + @classmethod + def cumulative_mask_sizes(cls): + sizes: list[int] = [cls.size()] + for trans in cls.supported_actions: + sizes.append(sizes[-1] + trans.mask_size()) + return sizes + + @classmethod + def cumulative_history_sizes(cls): + sizes: list[int] = [0] + for trans in cls.supported_actions: + sizes.append(sizes[-1] + trans.history_size()) + return sizes + + @classmethod + def action_by_index(cls, index: torch.Tensor, state: OperationState) -> Action: + action_idx = int(index[0].item()) + action_type = cls.supported_actions[action_idx] + if not action_type.params_size(): + return action_type() + + cum_sizes = cls.cumulative_params_sizes() + params = index[cum_sizes[action_idx]:cum_sizes[action_idx + 1]].long().tolist() + return action_type(params, state) + + @classmethod + def action_number(cls, action_type: type[Action]) -> int: + return cls.supported_actions.index(action_type) + + @classmethod + def action_type_by_symbol(cls, symbol: str) -> type[Action]: + for action in cls.supported_actions: + if action.symbol == symbol: + return action + + raise ValueError(f"action symbol '{symbol}' not supported") + + @classmethod + def action_number_by_symbol(cls, symbol: str) -> int: + return cls.action_number(cls.action_type_by_symbol(symbol)) + + @classmethod + def action_mask(cls, state: OperationState) -> torch.Tensor: + mask = torch.zeros(cls.size(), dtype=torch.bool) + + def allow_action(a: type[Action]): + if a.is_allowed(state): + mask[cls.action_number(a)] = True + + def allow_all(): + for action in cls.supported_actions: + allow_action(action) + + # If state is terminal don't allow any further actions + if not state.terminal: + if Interchange.incomplete_interchange(state): + # Special case where interchange isn't complete yet + mask[cls.action_number(Interchange)] = True + elif cfg.order: + # Enforce order if provided + if state.step_count >= len(cfg.order): + raise Exception("actions order must be ended with a terminal action") + if not cfg.order[state.step_count]: + # If at current step nothing is specified, allow everything + allow_all() + for s in cfg.order[state.step_count]: + allow_action(cls.action_type_by_symbol(s)) + else: + # If none of the above applies, allow everything + allow_all() + + # Check that there is at least one action allowed + if not mask.any(): + raise Exception("no actions allowed in the current state") + + for action in cls.supported_actions: + action_mask = action.action_mask(state) + if action_mask is None: + continue + + mask = torch.cat((mask, action_mask)) + + return mask + + @classmethod + def action_history(cls, state: OperationState) -> torch.Tensor: + history = [] + for action in cls.supported_actions: + action_history = action.action_history(state) + if action_history is None: + continue + history.append(action_history) + if not history: + return torch.tensor([]) + + return torch.cat(history) + + @classmethod + def distributions(cls, obs: torch.Tensor, selection_logits: torch.Tensor, *actions_logits: Optional[torch.Tensor]) -> list[Optional[Distribution]]: + """Create a list of distributions for the actions based on the logits. + + Args: + obs (torch.Tensor): Observation tensor. + selection_logits (torch.Tensor): Logits for action selection. + *actions_logits (torch.Tensor): Logits for each action's parameters. + + Returns: + list[Distribution]: List of distributions for each action. + """ + from rl_autoschedular.observation import Observation, ActionMask + + actions_mask = Observation.get_part(obs, ActionMask).bool() + dists_list: list[Optional[Distribution]] = [ + Categorical(logits=selection_logits.where(actions_mask[:, :cls.size()], -torch.inf)) + ] + cum_sizes = cls.cumulative_mask_sizes() + for i, action in enumerate(cls.supported_actions): + if not action.mask_size(): + dists_list.append(None) + continue + + assert actions_logits[i] is not None, f"action '{action.symbol}' must have logits" + masked_logits = actions_logits[i].where(actions_mask[:, cum_sizes[i]:cum_sizes[i + 1]], -torch.inf) + dists_list.append(action.distribution(masked_logits)) + + return dists_list + + @classmethod + def uniform_distributions(cls, obs: torch.Tensor) -> list[Optional[Distribution]]: + """Create a list of uniform distributions for the actions based on the observation. + + Args: + obs (torch.Tensor): Observation tensor. + + Returns: + list[Distribution]: List of distributions for each action. + """ + from rl_autoschedular.observation import Observation, ActionMask, NumLoops + + actions_mask = Observation.get_part(obs, ActionMask).bool() + num_loops = Observation.get_part(obs, NumLoops) + selection_mask = actions_mask[:, :cls.size()] + dists_list: list[Optional[Distribution]] = [ + Categorical(logits=torch.zeros_like(selection_mask).where(selection_mask, -torch.inf)) + ] + cum_sizes = cls.cumulative_mask_sizes() + for i, action in enumerate(cls.supported_actions): + if not action.mask_size(): + dists_list.append(None) + continue + + action_mask = actions_mask[:, cum_sizes[i]:cum_sizes[i + 1]] + logits = torch.zeros_like(action_mask).where(action_mask, -torch.inf) + dists_list.append(action.uniform_distribution(logits, num_loops)) + + return dists_list + + @classmethod + def distributions_stats(cls, distributions: list[Optional[Distribution]], index: torch.Tensor, eps_distributions: Optional[list[Optional[Distribution]]] = None, eps: Optional[float] = None) -> tuple[torch.Tensor, torch.Tensor]: + eps_distributions: list[Optional[Distribution]] = eps_distributions or [None] * len(distributions) + + selection_index = index[:, 0] + selection_dist = distributions[0] + selection_eps_dist = eps_distributions[0] + selection_log_p = selection_dist.log_prob(selection_index) + if eps is not None: + selection_eps_log_p = selection_eps_dist.log_prob(selection_index) + selection_log_p = (selection_log_p.exp() * (1 - eps) + selection_eps_log_p.exp() * eps).log() + + cum_sizes = cls.cumulative_params_sizes() + actions_log_p, entropies = selection_log_p, selection_dist.entropy() + for i, (action, dist, eps_dist) in enumerate(zip(cls.supported_actions, distributions[1:], eps_distributions[1:])): + if dist is None: + continue + + action_index = index[:, cum_sizes[i]:cum_sizes[i + 1]] + action_log_p, entropy = action.distribution_stats(dist, action_index, eps_dist, eps) + actions_log_p[selection_index == i] += action_log_p[selection_index == i] + entropies[selection_index == i] += entropy[selection_index == i] + + return actions_log_p, entropies + + @classmethod + def sample(cls, obs: torch.Tensor, distributions: list[Optional[Distribution]], eps_distributions: list[Optional[Distribution]], uniform: bool = False, greedy: bool = False) -> torch.Tensor: + assert not uniform or not greedy, "can't sample uniformly and greedily at once" + from rl_autoschedular.observation import Observation, NumLoops + + num_loops = Observation.get_part(obs, NumLoops) + selection_dist = distributions[0] + selection_eps_dist = eps_distributions[0] + + if greedy: + selection_index = selection_dist.probs.argmax(-1) + elif uniform: + selection_index = selection_eps_dist.sample() + else: + selection_index = selection_dist.sample() + + index = selection_index.unsqueeze(-1) + for action, dist, eps_dist in zip(cls.supported_actions, distributions[1:], eps_distributions[1:]): + if dist is None: + continue + + index = torch.cat((index, action.sample(dist, eps_dist, num_loops, uniform, greedy)), dim=1) + + return index diff --git a/rl_autoschedular/actions/base.py b/rl_autoschedular/actions/base.py old mode 100644 new mode 100755 index 9993a8f..786b07e --- a/rl_autoschedular/actions/base.py +++ b/rl_autoschedular/actions/base.py @@ -1,254 +1,217 @@ -from typing import Optional, overload, Union, Any -from rl_autoschedular.state import OperationState, OperationFeatures -from utils.log import print_error -import torch -from torch.distributions import Distribution - - -class Action: - """Base action class""" - - symbol: str - - operation_tag: str - parameters: Optional[list[int]] - extras: dict[str, Any] - - # --- defaults --- - ready: bool = True - terminal: bool = False - sub_actions: list['Action'] = [] - - @overload - def __init__(self, operation_tag: str, **extras): - """Initialize action without parameters""" - ... - - @overload - def __init__(self, state: OperationState, **extras): - """Initialize action dependent on state but without parameters - - Args: - state (OperationState): current state to apply the action on - """ - ... - - @overload - def __init__(self, parameters: list[int], operation_tag: str, **extras): - """Initialize action with parameters - - Args: - parameters (list[int]): list of parameters for the action - """ - ... - - @overload - def __init__(self, parameters: list[int], state: OperationState, **extras): - """Initialize action with unprocessed parameters - - Args: - parameters (list[int]): list of parameters for the action - state (OperationState): current state to apply the action on - """ - ... - - def __init__( - self, - arg1: Optional[Union[OperationState, list[int]]] = None, - arg2: Optional[OperationState] = None, - operation_tag: Optional[str] = None, - **extras - ): - if isinstance(arg1, OperationState): - parameters = None - state = arg1 - else: - parameters = arg1 - state = arg2 - if (state is None) == (operation_tag is None): - raise ValueError("Either state or operation tag must be provided and not both") - if state: - operation_tag = state.operation_tag - self.operation_tag = operation_tag - self.parameters = parameters - self.extras = {'operation_tag': operation_tag, **extras} - - def __repr__(self) -> str: - """String representation of the action with extra params""" - params_list = list(map(str, self.parameters)) if self.parameters else [] - params_list.extend(f'{k} = {v}' for k, v in self.extras.items()) - - return f"{self.__class__.__name__}({', '.join(params_list)})" - - def __str__(self) -> str: - """String representation of the action""" - return f"{self.symbol}({','.join(map(str, self.parameters)) if self.parameters else ''})" - - @classmethod - def params_size(cls) -> int: - """Return the size of the parameters in the index for this action type - - Returns: - int: size of the parameters for this action type - """ - return 0 - - @classmethod - def network_output_size(cls) -> int: - """Return the size of the network output for this action type - - Returns: - int: size of the network output for this action type - """ - return 0 - - @classmethod - def mask_size(cls) -> int: - """Return the size of the mask for this action type - - Returns: - int: size of the mask for this action type - """ - return cls.network_output_size() - - @classmethod - def history_size(cls) -> int: - """Return the size of the history for this action type - - Returns: - int: size of the history for this action type - """ - return 0 - - @classmethod - def is_allowed(cls, state: OperationState) -> bool: - """Check if this action type is allowed in the current state - - Args: - state (OperationState): current state to check the action on - - Returns: - bool: True if the action is allowed, False otherwise - """ - return True - - @classmethod - def action_mask(cls, state: OperationState) -> Optional[torch.Tensor]: - """Return the action mask for this action type in the current state - - Args: - state (OperationState): current state to check the action on - - Returns: - Optional[torch.Tensor]: action mask for this action type, or None if not applicable - """ - return None - - @classmethod - def action_history(cls, state: OperationState) -> Optional[torch.Tensor]: - """Return the action history for this action type in the current state - - Args: - state (OperationState): current state to check the action on - - Returns: - Optional[torch.Tensor]: action history for this action type, or None if not applicable - """ - return None - - @classmethod - def distribution(cls, logits: torch.Tensor) -> Distribution: - """Create a distribution for this action type based on the logits - - Args: - logits (torch.Tensor): Logits for the action selection. - - Returns: - Distribution: A distribution object for this action type. - """ - raise NotImplementedError - - @classmethod - def uniform_distribution(cls, logits: torch.Tensor, num_loops: torch.Tensor) -> Distribution: - """Create a uniform distribution for this action type based on the logits and number of loops - - Args: - logits (torch.Tensor): Logits for the action selection. - num_loops (torch.Tensor): Number of loops in the operation state. - - Returns: - Distribution: A uniform distribution object for this action type. - """ - return cls.distribution(logits) - - @classmethod - def distribution_stats(cls, distribution: Distribution, index: torch.Tensor, eps_distribution: Optional[Distribution], eps: Optional[float] = None) -> tuple[torch.Tensor, torch.Tensor]: - """Calculate the log probabilities and entropies for the distribution - - Args: - distribution (Distribution): The distribution to calculate stats for. - eps_distribution (Distribution): The epsilon distribution for exploration. - index (torch.Tensor): The params index. - eps (Optional[float]): Epsilon value for exploration. Defaults to None. - - Returns: - tuple[torch.Tensor, torch.Tensor]: Log probabilities and entropies. - """ - raise NotImplementedError - - @classmethod - def sample(cls, distribution: Distribution, eps_distribution: Distribution, num_loops: torch.Tensor, uniform: bool, greedy: bool) -> torch.Tensor: - """Sample an action based on the distribution - - Args: - distribution (Distribution): The distribution to sample from. - eps_distribution (Distribution): The epsilon distribution for exploration. - num_loops (torch.Tensor): Number of loops in the operation state. - uniform (bool): Whether to sample uniformly. - greedy (bool): Whether to sample greedily. - - Returns: - torch.Tensor: Sampled action index. - """ - raise NotImplementedError - - def apply(self, code: str) -> tuple[str, bool]: - """Apply action on the current code - - Args: - code (str): current code to apply the action on - - Returns: - tuple[str, bool]: the new transformed code and a flag that determines if the action was successful - """ - if not self.ready: - return code, True - - try: - transformed_code = self._apply_ready(code) - return transformed_code, True - except Exception as e: - print_error(f"Error applying action {self}: {e}") - return '', False - - def _apply_ready(self, code: str) -> str: - """Apply action that is guarenteed to be ready on the current state - - Args: - state (OperationState): current state to apply the action on - - Returns: - tuple[str, bool]: the new transformed code and a flag that determines if the action was successful - """ - raise NotImplementedError - - def update_features(self, operation_features: OperationFeatures) -> OperationFeatures: - """Update the operation features based on the action - - Args: - operation_features (OperationFeatures): The operation features to update. - - Returns: - OperationFeatures: The updated operation features. - """ - return operation_features +from typing import Optional, overload +from rl_autoschedular.state import OperationState, OperationFeatures +from utils.log import print_error +import torch +from torch.distributions import Distribution + + +class Action: + """Base action class""" + + symbol: str + parameters: Optional[list[int]] + ready: bool = True + terminal: bool = False + + @overload + def __init__(self): + """Initialize action without parameters""" + ... + + @overload + def __init__(self, parameters: list[int]): + """Initialize action with parameters + + Args: + parameters (list[int]): list of parameters for the action + """ + ... + + @overload + def __init__(self, parameters: list[int], state: OperationState): + """Initialize action with unprocessed parameters + + Args: + parameters (list[int]): list of parameters for the action + state (OperationState): current state to apply the action on + """ + ... + + def __init__(self, parameters: Optional[list[int]] = None, *_): + self.parameters = parameters + + def __repr__(self) -> str: + """String representation of the action""" + return f"{self.symbol}({','.join(map(str, self.parameters)) if self.parameters else ''})" + + def __str__(self) -> str: + """String representation of the action""" + return self.__repr__() + + @classmethod + def params_size(cls) -> int: + """Return the size of the parameters in the index for this action type + + Returns: + int: size of the parameters for this action type + """ + return 0 + + @classmethod + def network_output_size(cls) -> int: + """Return the size of the network output for this action type + + Returns: + int: size of the network output for this action type + """ + return 0 + + @classmethod + def mask_size(cls) -> int: + """Return the size of the mask for this action type + + Returns: + int: size of the mask for this action type + """ + return cls.network_output_size() + + @classmethod + def history_size(cls) -> int: + """Return the size of the history for this action type + + Returns: + int: size of the history for this action type + """ + return 0 + + @classmethod + def is_allowed(cls, state: OperationState) -> bool: + """Check if this action type is allowed in the current state + + Args: + state (OperationState): current state to check the action on + + Returns: + bool: True if the action is allowed, False otherwise + """ + return True + + @classmethod + def action_mask(cls, state: OperationState) -> Optional[torch.Tensor]: + """Return the action mask for this action type in the current state + + Args: + state (OperationState): current state to check the action on + + Returns: + Optional[torch.Tensor]: action mask for this action type, or None if not applicable + """ + return None + + @classmethod + def action_history(cls, state: OperationState) -> Optional[torch.Tensor]: + """Return the action history for this action type in the current state + + Args: + state (OperationState): current state to check the action on + + Returns: + Optional[torch.Tensor]: action history for this action type, or None if not applicable + """ + return None + + @classmethod + def distribution(cls, logits: torch.Tensor) -> Distribution: + """Create a distribution for this action type based on the logits + + Args: + logits (torch.Tensor): Logits for the action selection. + + Returns: + Distribution: A distribution object for this action type. + """ + raise NotImplementedError + + @classmethod + def uniform_distribution(cls, logits: torch.Tensor, num_loops: torch.Tensor) -> Distribution: + """Create a uniform distribution for this action type based on the logits and number of loops + + Args: + logits (torch.Tensor): Logits for the action selection. + num_loops (torch.Tensor): Number of loops in the operation state. + + Returns: + Distribution: A uniform distribution object for this action type. + """ + return cls.distribution(logits) + + @classmethod + def distribution_stats(cls, distribution: Distribution, index: torch.Tensor, eps_distribution: Optional[Distribution], eps: Optional[float] = None) -> tuple[torch.Tensor, torch.Tensor]: + """Calculate the log probabilities and entropies for the distribution + + Args: + distribution (Distribution): The distribution to calculate stats for. + eps_distribution (Distribution): The epsilon distribution for exploration. + index (torch.Tensor): The params index. + eps (Optional[float]): Epsilon value for exploration. Defaults to None. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Log probabilities and entropies. + """ + raise NotImplementedError + + @classmethod + def sample(cls, distribution: Distribution, eps_distribution: Distribution, num_loops: torch.Tensor, uniform: bool, greedy: bool) -> torch.Tensor: + """Sample an action based on the distribution + + Args: + distribution (Distribution): The distribution to sample from. + eps_distribution (Distribution): The epsilon distribution for exploration. + num_loops (torch.Tensor): Number of loops in the operation state. + uniform (bool): Whether to sample uniformly. + greedy (bool): Whether to sample greedily. + + Returns: + torch.Tensor: Sampled action index. + """ + raise NotImplementedError + + def apply(self, state: OperationState) -> tuple[str, bool]: + """Apply action on the current state + + Args: + state (OperationState): current state to apply the action on + + Returns: + tuple[str, bool]: the new transformed code and a flag that determines if the action was successful + """ + if self.ready: + assert self.is_allowed(state), "Operation isn't allowed for this state" + try: + return self._apply_ready(state) + except Exception as e: + print_error(f"Error applying action {self}: {e}") + return '', False + return state.transformed_code, True + + def _apply_ready(self, state: OperationState) -> tuple[str, bool]: + """Apply action that is guarenteed to be ready on the current state + + Args: + state (OperationState): current state to apply the action on + + Returns: + tuple[str, bool]: the new transformed code and a flag that determines if the action was successful + """ + raise NotImplementedError + + def update_features(self, operation_features: OperationFeatures) -> OperationFeatures: + """Update the operation features based on the action + + Args: + operation_features (OperationFeatures): The operation features to update. + + Returns: + OperationFeatures: The updated operation features. + """ + return operation_features diff --git a/rl_autoschedular/actions/interchange.py b/rl_autoschedular/actions/interchange.py old mode 100644 new mode 100755 index abad2ab..deeb8a8 --- a/rl_autoschedular/actions/interchange.py +++ b/rl_autoschedular/actions/interchange.py @@ -1,256 +1,260 @@ -from utils.config import Config -from .base import Action -from rl_autoschedular.state import OperationState, OperationType -from rl_autoschedular.transforms import transform_interchange -from typing import Optional -from enum import Enum -from utils.log import print_error -import torch -from torch.distributions import Categorical, Normal, Uniform -import math - - -class InterchangeMethod(Enum): - EnumeratedCandidates = 'enumerate' - LevelsPointers = 'pointers' - ContinuousEncoding = 'continuous' - - -class Interchange(Action): - """Class representing Interchange action""" - - symbol = 'I' - - parameters: list[int] - - # --- constants --- - method = InterchangeMethod(Config().interchange_mode) - log_std = torch.nn.Parameter(torch.zeros(1)) - - def __init__(self, parameters: list[int], state: Optional[OperationState] = None, **extras): - if state: - # Case where state is provided -> Parameters need processing - - assert len(parameters) == 1, 'uncompatible parameters for constructor call' - parameter = parameters[0] - num_loops = len(state.operation_features.nested_loops) - match Interchange.method: - case InterchangeMethod.EnumeratedCandidates: - parameters = self.__get_candidates(num_loops)[parameter] - case InterchangeMethod.ContinuousEncoding: - parameters = self.__decode_continuous(parameter, num_loops) - case InterchangeMethod.LevelsPointers: - old_action = self.incomplete_interchange(state) - if old_action: - perm_buffer = old_action.parameters - else: - perm_buffer = [] - - assert parameter not in perm_buffer, 'repitition detected in permutation' - parameters = perm_buffer + [parameter] - assert len(parameters) <= num_loops, 'interchange parameter exceeds number of loops' - if len(parameters) < num_loops: - self.ready = False - super().__init__(parameters, state, **extras) - - @classmethod - def params_size(cls): - return 1 - - @classmethod - def network_output_size(cls): - match cls.method: - case InterchangeMethod.EnumeratedCandidates: - return 3 * Config().max_num_loops - 6 - case InterchangeMethod.LevelsPointers: - return Config().max_num_loops - case InterchangeMethod.ContinuousEncoding: - return 1 - - @classmethod - def history_size(cls): - return Config().truncate * Config().max_num_loops * Config().max_num_loops - - @classmethod - def action_mask(cls, state): - L = Config().max_num_loops - I_BEGIN_2C = L - 1 - I_BEGIN_3C = I_BEGIN_2C + L - 2 - - num_loops = len(state.operation_features.nested_loops) - mask = torch.ones(cls.mask_size(), dtype=torch.bool) - match cls.method: - case InterchangeMethod.ContinuousEncoding: - pass - case InterchangeMethod.EnumeratedCandidates: - if num_loops == 1: - mask[1:] = False - else: - mask[num_loops - 1:I_BEGIN_2C] = False - mask[I_BEGIN_2C + num_loops - 2:I_BEGIN_3C] = False - mask[I_BEGIN_3C + max(num_loops - 3, 0):] = False - case InterchangeMethod.LevelsPointers: - mask[num_loops:] = False - old_action = cls.incomplete_interchange(state) - if old_action: - for param in old_action.parameters: - mask[param] = False - - return mask - - @classmethod - def action_history(cls, state): - history = torch.zeros((Config().truncate, Config().max_num_loops, Config().max_num_loops)) - for i, action in enumerate(state.transformation_history[0]): - if not isinstance(action, Interchange): - continue - - for j, param in enumerate(action.parameters): - history[i, j, param] = 1 - - return history.reshape(-1) - - @classmethod - def distribution(cls, logits): - match cls.method: - case InterchangeMethod.EnumeratedCandidates | InterchangeMethod.LevelsPointers: - return Categorical(logits=logits) - case InterchangeMethod.ContinuousEncoding: - logit = logits.squeeze(-1) - return Normal(logit, cls.log_std.clamp(-1, 1).exp()) - - @classmethod - def uniform_distribution(cls, logits, num_loops): - match cls.method: - case InterchangeMethod.EnumeratedCandidates | InterchangeMethod.LevelsPointers: - return Categorical(logits=logits) - case InterchangeMethod.ContinuousEncoding: - total_count = (num_loops + 1).lgamma().exp() - return Uniform(0.0, total_count) - - @classmethod - def distribution_stats(cls, distribution, index, eps_distribution, eps=None): - index = index.squeeze(-1) - if isinstance(distribution, Normal): - # Special case in Normal distribution we need to consider all - # the interval [i,i+1), so we use log CDF instead of log P - log_p = (distribution.cdf(index + 1) - distribution.cdf(index) + 1e-8).log() - else: - log_p = distribution.log_prob(index) - - if eps is not None: - eps_log_p = eps_distribution.log_prob(index) - log_p = (log_p.exp() * (1 - eps) + eps_log_p.exp() * eps).log() - - entropy = distribution.entropy() - - return log_p, entropy - - @classmethod - def sample(cls, distribution, eps_distribution, num_loops, uniform, greedy): - if greedy: - if cls.method == InterchangeMethod.ContinuousEncoding: - index = distribution.mean.long() - else: - index = distribution.probs.argmax(-1) - elif uniform: - index = eps_distribution.sample().long() - else: - index = distribution.sample().long() - - if cls.method == InterchangeMethod.ContinuousEncoding: - total_count = (num_loops + 1).lgamma().exp().long() - index = index.clamp(torch.zeros_like(total_count).long(), total_count - 1) - - return index.unsqueeze(-1) - - def _apply_ready(self, code): - return transform_interchange(code, self.operation_tag, self.parameters) - - def update_features(self, operation_features): - if not self.ready: - return operation_features - - new_operation_features = operation_features.copy() - for i, j in enumerate(self.parameters): - new_operation_features.nested_loops[i] = operation_features.nested_loops[j] - - # In case an interchange was applied to pooling or conv, vectorization is no longer possible - if operation_features.operation_type in [OperationType.Pooling, OperationType.Conv] and self.parameters != list(range(len(self.parameters))): - new_operation_features.vectorizable = False - - return new_operation_features - - @staticmethod - def __decode_continuous(parameter: int, num_loops: int) -> list[int]: - """Decode the interchange parameter to get the loop permutation. - - Args: - parameter (int): The interchange parameter. - num_loops (int): The number of loops in the operation. - - Returns: - list[int]: The loop permutation. - """ - x = parameter - n = num_loops - if x >= math.factorial(n): - print_error(f"Invalid interchange parameter: {x}") - x = math.factorial(n) - 1 - - # Convert x to factorial number - fact_x = '0' - q = x - d = 2 - while q > 0: - r = q % d - q = q // d - fact_x = str(r) + fact_x - d += 1 - - # Ensure to get exactly n digits - fact_x = fact_x.zfill(n)[-n:] - - # Decode factorial number following Lehmer code - nl = list(map(int, fact_x)) - for i in range(len(nl) - 2, -1, -1): - for j in range(i + 1, len(nl)): - if nl[j] >= nl[i]: - nl[j] += 1 - - return nl - - @staticmethod - def __get_candidates(num_loops: int) -> list[list[int]]: - """Get all 1c 2c 3c possible interchanges for `num_loops` - - Args: - num_loops (int): The number of loops in the operation. - - Returns: - list[tuple]: The list of all possible interchanges. - """ - - interchanges = [] - for c in [1, 2, 3]: - level_interchanges = [] - for _ in range(Config().max_num_loops - c): - level_interchanges.append(list(range(num_loops))) - for i in range(num_loops - c): - params = list(range(num_loops)) - params[i], params[i + c] = params[i + c], params[i] - level_interchanges[i] = params - interchanges += level_interchanges - return interchanges - - @classmethod - def incomplete_interchange(cls, state: OperationState) -> Optional['Interchange']: - if state.step_count >= len(state.transformation_history[0]): - return None - - old_action = state.transformation_history[0][state.step_count] - if not isinstance(old_action, Interchange): - return None - - assert not old_action.ready, 'expected previous interchange to be incomplete' - return old_action +from .base import Action +from rl_autoschedular import config as cfg +from rl_autoschedular.state import OperationState, OperationType +from rl_autoschedular.transforms import transform_dialect_interchange +from typing import Optional +from enum import Enum +from utils.log import print_error +import torch +from torch.distributions import Categorical, Normal, Uniform +import math + + +class InterchangeMethod(Enum): + EnumeratedCandidates = 'enumerate' + LevelsPointers = 'pointers' + ContinuousEncoding = 'continuous' + + +class Interchange(Action): + """Class representing Interchange action""" + + symbol = 'I' + method = InterchangeMethod(cfg.interchange_mode) + parameters: list[int] + log_std: Optional[torch.Tensor] = None + + def __init__(self, parameters: list[int], state: Optional[OperationState] = None): + if state: + assert len(parameters) == 1, 'uncompatible parameters for constructor call' + parameter = parameters[0] + num_loops = len(state.operation_features.nested_loops) + match Interchange.method: + case InterchangeMethod.EnumeratedCandidates: + parameters = self.__get_candidates(num_loops)[parameter] + case InterchangeMethod.ContinuousEncoding: + parameters = self.__decode_continuous(parameter, num_loops) + case InterchangeMethod.LevelsPointers: + old_action = self.incomplete_interchange(state) + if old_action: + perm_buffer = old_action.parameters + else: + perm_buffer = [] + + assert parameter not in perm_buffer, 'repitition detected in permutation' + parameters = perm_buffer + [parameter] + assert len(parameters) <= num_loops, 'interchange parameter exceeds number of loops' + if len(parameters) < num_loops: + self.ready = False + + super().__init__(parameters) + + @classmethod + def params_size(cls): + return 1 + + @classmethod + def network_output_size(cls): + match cls.method: + case InterchangeMethod.EnumeratedCandidates: + return 3 * cfg.max_num_loops - 6 + case InterchangeMethod.LevelsPointers: + return cfg.max_num_loops + case InterchangeMethod.ContinuousEncoding: + return 1 + + @classmethod + def history_size(cls): + return cfg.truncate * cfg.max_num_loops * cfg.max_num_loops + + @classmethod + def action_mask(cls, state): + L = cfg.max_num_loops + I_BEGIN_2C = L - 1 + I_BEGIN_3C = I_BEGIN_2C + L - 2 + + num_loops = len(state.operation_features.nested_loops) + mask = torch.ones(cls.mask_size(), dtype=torch.bool) + match cls.method: + case InterchangeMethod.ContinuousEncoding: + pass + case InterchangeMethod.EnumeratedCandidates: + if num_loops == 1: + mask[1:] = False + else: + mask[num_loops - 1:I_BEGIN_2C] = False + mask[I_BEGIN_2C + num_loops - 2:I_BEGIN_3C] = False + mask[I_BEGIN_3C + max(num_loops - 3, 0):] = False + case InterchangeMethod.LevelsPointers: + mask[num_loops:] = False + old_action = cls.incomplete_interchange(state) + if old_action: + for param in old_action.parameters: + mask[param] = False + + return mask + + @classmethod + def action_history(cls, state): + history = torch.zeros((cfg.truncate, cfg.max_num_loops, cfg.max_num_loops)) + for i, action in enumerate(state.transformation_history[0]): + if not isinstance(action, Interchange): + continue + + for j, param in enumerate(action.parameters): + history[i, j, param] = 1 + + return history.reshape(-1) + + @classmethod + def distribution(cls, logits): + match cls.method: + case InterchangeMethod.EnumeratedCandidates | InterchangeMethod.LevelsPointers: + return Categorical(logits=logits) + case InterchangeMethod.ContinuousEncoding: + logit = logits.squeeze(-1) + assert cls.log_std is not None, 'log_std must be set for continuous encoding' + return Normal(logit, cls.log_std.clamp(-1, 1).exp()) + + @classmethod + def uniform_distribution(cls, logits, num_loops): + match cls.method: + case InterchangeMethod.EnumeratedCandidates | InterchangeMethod.LevelsPointers: + return Categorical(logits=logits) + case InterchangeMethod.ContinuousEncoding: + total_count = (num_loops + 1).lgamma().exp() + return Uniform(0.0, total_count) + + @classmethod + def distribution_stats(cls, distribution, index, eps_distribution, eps=None): + index = index.squeeze(-1) + if isinstance(distribution, Normal): + # Special case in Normal distribution we need to consider all + # the interval [i,i+1), so we use log CDF instead of log P + log_p = (distribution.cdf(index + 1) - distribution.cdf(index) + 1e-8).log() + else: + log_p = distribution.log_prob(index) + + if eps is not None: + eps_log_p = eps_distribution.log_prob(index) + log_p = (log_p.exp() * (1 - eps) + eps_log_p.exp() * eps).log() + + entropy = distribution.entropy() + + return log_p, entropy + + @classmethod + def sample(cls, distribution, eps_distribution, num_loops, uniform, greedy): + if greedy: + if cls.method == InterchangeMethod.ContinuousEncoding: + index = distribution.mean.long() + else: + index = distribution.probs.argmax(-1) + elif uniform: + index = eps_distribution.sample().long() + else: + index = distribution.sample().long() + + if cls.method == InterchangeMethod.ContinuousEncoding: + total_count = (num_loops + 1).lgamma().exp().long() + index = index.clamp(torch.zeros_like(total_count).long(), total_count - 1) + + return index.unsqueeze(-1) + + def _apply_ready(self, state): + new_code = transform_dialect_interchange( + state.transformed_code, + state.operation_tag, + self.parameters, + state.tmp_file + ) + + return new_code, bool(new_code) + + def update_features(self, operation_features): + if not self.ready: + return operation_features + + new_operation_features = operation_features.copy() + for i, j in enumerate(self.parameters): + new_operation_features.nested_loops[i] = operation_features.nested_loops[j] + + # In case an interchange was applied to pooling, vectorization is no longer possible + if operation_features.operation_type == OperationType.Pooling and self.parameters != list(range(len(self.parameters))): + new_operation_features.vectorizable = False + + return new_operation_features + + @staticmethod + def __decode_continuous(parameter: int, num_loops: int) -> list[int]: + """Decode the interchange parameter to get the loop permutation. + + Args: + parameter (int): The interchange parameter. + num_loops (int): The number of loops in the operation. + + Returns: + list[int]: The loop permutation. + """ + x = parameter + n = num_loops + if x >= math.factorial(n): + print_error(f"Invalid interchange parameter: {x}") + x = math.factorial(n) - 1 + + # Convert x to factorial number + fact_x = '0' + q = x + d = 2 + while q > 0: + r = q % d + q = q // d + fact_x = str(r) + fact_x + d += 1 + + # Ensure to get exactly n digits + fact_x = fact_x.zfill(n)[-n:] + + # Decode factorial number following Lehmer code + nl = list(map(int, fact_x)) + for i in range(len(nl) - 2, -1, -1): + for j in range(i + 1, len(nl)): + if nl[j] >= nl[i]: + nl[j] += 1 + + return nl + + @staticmethod + def __get_candidates(num_loops: int) -> list[list[int]]: + """Get all 1c 2c 3c possible interchanges for `num_loops` + + Args: + num_loops (int): The number of loops in the operation. + + Returns: + list[tuple]: The list of all possible interchanges. + """ + + interchanges = [] + for c in [1, 2, 3]: + level_interchanges = [] + for _ in range(cfg.max_num_loops - c): + level_interchanges.append(list(range(num_loops))) + for i in range(num_loops - c): + params = list(range(num_loops)) + params[i], params[i + c] = params[i + c], params[i] + level_interchanges[i] = params + interchanges += level_interchanges + return interchanges + + @classmethod + def incomplete_interchange(cls, state: OperationState) -> Optional['Interchange']: + if state.step_count >= len(state.transformation_history[0]): + return None + + old_action = state.transformation_history[0][state.step_count] + if not isinstance(old_action, Interchange): + return None + + assert not old_action.ready, 'expected previous interchange to be incomplete' + return old_action diff --git a/rl_autoschedular/actions/no_transformation.py b/rl_autoschedular/actions/no_transformation.py old mode 100644 new mode 100755 index 30ad243..c25e6f9 --- a/rl_autoschedular/actions/no_transformation.py +++ b/rl_autoschedular/actions/no_transformation.py @@ -1,20 +1,15 @@ -from typing import Optional -from rl_autoschedular.state import OperationState -from .base import Action - - -class NoTransformation(Action): - """Class representing No Transformation""" - - symbol = 'NT' - - parameters: None - - # --- constants --- - terminal = True - - def __init__(self, state: Optional[OperationState] = None, **extras): - super().__init__(state, **extras) - - def _apply_ready(self, code): - return code +from .base import Action + + +class NoTransformation(Action): + """Class representing No Transformation""" + + symbol = 'NT' + parameters: None + terminal = True + + def __init__(self): + super().__init__() + + def _apply_ready(self, state): + return state.transformed_code, True diff --git a/rl_autoschedular/actions/tiled_fusion.py b/rl_autoschedular/actions/tiled_fusion.py old mode 100644 new mode 100755 index d38012b..a4507c7 --- a/rl_autoschedular/actions/tiled_fusion.py +++ b/rl_autoschedular/actions/tiled_fusion.py @@ -1,47 +1,7 @@ -from .tiled_parallelization import TiledParallelization -from rl_autoschedular.transforms import transform_TF -from rl_autoschedular.state import OperationState -from typing import Optional - - -class TiledFusion(TiledParallelization): - """Class representing Tiled Fusion action""" - - symbol = 'TPF' - - # --- extras --- - producer_tag: str - - def __init__( - self, - parameters: list[int], - state: Optional[OperationState] = None, - producer_tag: Optional[str] = None, - **extras - ): - if (state is None) == (producer_tag is None): - raise ValueError("Either state or producer tag must be provided and not both") - if state: - producer_tag = state.producer_tag - super().__init__(parameters, state, producer_tag=producer_tag, **extras) - - self.producer_tag = producer_tag - - def __str__(self): - return f"{self.symbol}({self.producer_tag};{','.join(map(str, self.parameters))})" - - @classmethod - def is_allowed(cls, state): - already_fused = any(isinstance(action, cls) for action in state.transformation_history[0]) - has_producers = state.producer_tag is not None - - return has_producers and not already_fused - - def _apply_ready(self, code): - return transform_TF( - code, - self.operation_tag, - self.producer_tag, - self.tiling_params, - self.parallel_params, - ) +from .tiling import Tiling + + +class TiledFusion(Tiling): + """Class representing Tiled Fusion action""" + + symbol = 'TF' diff --git a/rl_autoschedular/actions/tiled_parallelization.py b/rl_autoschedular/actions/tiled_parallelization.py old mode 100644 new mode 100755 index d6a51d3..adb5ef6 --- a/rl_autoschedular/actions/tiled_parallelization.py +++ b/rl_autoschedular/actions/tiled_parallelization.py @@ -1,39 +1,23 @@ -from .tiling import Tiling, Optional -from rl_autoschedular.transforms import transform_tile, transform_TP -from rl_autoschedular.state import OperationState, IteratorType - - -class TiledParallelization(Tiling): - """Class representing Tiled Parallelization action""" - - symbol = 'TP' - - # --- extras --- - parallel_params: list[int] - tiling_params: list[int] - - def __init__( - self, - parameters: list[int], - state: Optional[OperationState] = None, - iterators: Optional[list[str]] = None, - **extras - ): - if (state is None) == (iterators is None): - raise ValueError("Either state or iterators must be provided and not both") - if state: - iterators = [loop.iterator_type.value for loop in state.operation_features.nested_loops] - super().__init__(parameters, state, iterators=iterators, **extras) - - self.parallel_params = [ - 0 if iterator == IteratorType.Reduction.value - else param for param, iterator in zip(self.parameters, iterators) - ] - self.tiling_params = [ - param if iterator == IteratorType.Reduction.value - else 0 for param, iterator in zip(self.parameters, iterators) - ] - - def _apply_ready(self, code: str): - p_code = transform_TP(code, self.operation_tag, self.parallel_params) - return transform_tile(p_code, self.operation_tag, self.tiling_params) +from .tiling import Tiling +from rl_autoschedular.transforms import transform_dialect_tile, transform_dialect_TP +from rl_autoschedular.state import OperationState, IteratorType + + +class TiledParallelization(Tiling): + """Class representing Tiled Parallelization action""" + + symbol = 'TP' + + def _apply_ready(self, state: OperationState): + parallel_params = [ + 0 if state.operation_features.nested_loops[i].iterator_type == IteratorType.Reduction + else param for i, param in enumerate(self.parameters) + ] + tiling_params = [ + param if state.operation_features.nested_loops[i].iterator_type == IteratorType.Reduction + else 0 for i, param in enumerate(self.parameters) + ] + new_code = transform_dialect_TP(state.transformed_code, state.operation_tag, parallel_params, state.tmp_file) + new_code = transform_dialect_tile(new_code, state.operation_tag, tiling_params, state.tmp_file) + + return new_code, bool(new_code) diff --git a/rl_autoschedular/actions/tiling.py b/rl_autoschedular/actions/tiling.py old mode 100644 new mode 100755 index 8a2d112..ae48f6f --- a/rl_autoschedular/actions/tiling.py +++ b/rl_autoschedular/actions/tiling.py @@ -1,127 +1,131 @@ -from rl_autoschedular.state import OperationState -from rl_autoschedular.transforms import transform_tile -from typing import Optional - -from utils.config import Config -from .base import Action -import torch -import math -from torch.distributions import Categorical - - -class Tiling(Action): - """Class representing Tiling action""" - - symbol = 'T' - - parameters: list[int] - - def __init__(self, parameters: list[int], state: Optional[OperationState] = None, **extras): - if state: - # Case where state is provided -> Parameters need processing - - tile_sizes = [] - for param, loop in zip(parameters, state.operation_features.nested_loops): - if param == 0: - tile_sizes.append(0) - else: - ts = 2 ** (param - 1) - assert loop.upper_bound % ts == 0 and loop.upper_bound != ts, \ - f'Tiling parameter {param} is not a factor of loop upper bound {loop.upper_bound}' - tile_sizes.append(ts) - parameters = tile_sizes - super().__init__(parameters, state, **extras) - - @classmethod - def params_size(cls): - return Config().max_num_loops - - @classmethod - def network_output_size(cls): - return Config().max_num_loops * (Config().num_tile_sizes + 1) - - @classmethod - def history_size(cls): - return Config().truncate * Config().max_num_loops * (Config().num_tile_sizes + 1) - - @classmethod - def action_mask(cls, state: OperationState): - mask = torch.zeros((Config().max_num_loops, Config().num_tile_sizes + 1), dtype=torch.bool) - mask[:, 0] = True - for i, loop in enumerate(state.operation_features.nested_loops): - ts_count = cls.__get_tiles_count(loop.upper_bound) - mask[i, :ts_count] = True - - return mask.reshape(-1) - - @classmethod - def action_history(cls, state): - history = torch.zeros((Config().truncate, Config().max_num_loops, Config().num_tile_sizes + 1)) - for i, action in enumerate(state.transformation_history[0]): - if not isinstance(action, Tiling): - continue - - for j, param in enumerate(action.parameters): - if param == 0: - history[i, j, 0] = 1 - else: - ts_index = int(math.log2(param)) - history[i, j, ts_index + 1] = 1 - - return history.reshape(-1) - - @classmethod - def distribution(cls, logits): - logits = logits.reshape(-1, Config().max_num_loops, Config().num_tile_sizes + 1) - return Categorical(logits=logits) - - @classmethod - def distribution_stats(cls, distribution, index, eps_distribution, eps=None): - log_p = distribution.log_prob(index).sum(-1) - - if eps is not None: - eps_log_p = eps_distribution.log_prob(index).sum(-1) - log_p = (log_p.exp() * (1 - eps) + eps_log_p.exp() * eps).log() - - entropy = distribution.entropy().sum(-1) - - return log_p, entropy - - @classmethod - def sample(cls, distribution, eps_distribution, num_loops, uniform, greedy): - if greedy: - index = distribution.probs.argmax(-1) - elif uniform: - index = eps_distribution.sample() - else: - index = distribution.sample() - - return index - - def _apply_ready(self, code): - return transform_tile(code, self.operation_tag, self.parameters) - - def update_features(self, operation_features): - new_operation_features = operation_features.copy() - for nested_loop, tile_size in zip(new_operation_features.nested_loops, self.parameters): - if tile_size == 0: - continue - nested_loop.upper_bound = tile_size - - return new_operation_features - - @staticmethod - def __get_tiles_count(ub: int) -> int: - """Get the number of tiling candidates for a given loop upper bound. - - Args: - ub (int): The loop upper bound. - - Returns: - int: The number of candidates. - """ - for i in range(Config().num_tile_sizes): - ts = 2 ** i - if ub % ts != 0 or ub == ts: - return i + 1 - return Config().num_tile_sizes + 1 +from rl_autoschedular import config as cfg +from rl_autoschedular.state import OperationState +from rl_autoschedular.transforms import transform_dialect_tile +from typing import Optional +from .base import Action +import torch +import math +from torch.distributions import Categorical + + +class Tiling(Action): + """Class representing Tiling action""" + + symbol = 'T' + parameters: list[int] + + def __init__(self, parameters: list[int], state: Optional[OperationState] = None): + if state: + tile_sizes = [] + for param, loop in zip(parameters, state.operation_features.nested_loops): + if param == 0: + tile_sizes.append(0) + else: + ts = 2 ** (param - 1) + assert loop.upper_bound % ts == 0, \ + f'Tiling parameter {param} is not a factor of loop upper bound {loop.upper_bound}' + tile_sizes.append(ts) + parameters = tile_sizes + + super().__init__(parameters) + + @classmethod + def params_size(cls): + return cfg.max_num_loops + + @classmethod + def network_output_size(cls): + return cfg.max_num_loops * (cfg.num_tile_sizes + 1) + + @classmethod + def history_size(cls): + return cfg.truncate * cfg.max_num_loops * (cfg.num_tile_sizes + 1) + + @classmethod + def action_mask(cls, state: OperationState): + mask = torch.zeros((cfg.max_num_loops, cfg.num_tile_sizes + 1), dtype=torch.bool) + mask[:, 0] = True + for i, loop in enumerate(state.operation_features.nested_loops): + ts_count = cls.__get_tiles_count(loop.upper_bound) + mask[i, :ts_count] = True + + return mask.reshape(-1) + + @classmethod + def action_history(cls, state): + history = torch.zeros((cfg.truncate, cfg.max_num_loops, cfg.num_tile_sizes + 1)) + for i, action in enumerate(state.transformation_history[0]): + if not isinstance(action, Tiling): + continue + + for j, param in enumerate(action.parameters): + if param == 0: + history[i, j, 0] = 1 + else: + ts_index = int(math.log2(param)) + history[i, j, ts_index + 1] = 1 + + return history.reshape(-1) + + @classmethod + def distribution(cls, logits): + logits = logits.reshape(-1, cfg.max_num_loops, cfg.num_tile_sizes + 1) + return Categorical(logits=logits) + + @classmethod + def distribution_stats(cls, distribution, index, eps_distribution, eps=None): + log_p = distribution.log_prob(index).sum(-1) + + if eps is not None: + eps_log_p = eps_distribution.log_prob(index).sum(-1) + log_p = (log_p.exp() * (1 - eps) + eps_log_p.exp() * eps).log() + + entropy = distribution.entropy().sum(-1) + + return log_p, entropy + + @classmethod + def sample(cls, distribution, eps_distribution, num_loops, uniform, greedy): + if greedy: + index = distribution.probs.argmax(-1) + elif uniform: + index = eps_distribution.sample() + else: + index = distribution.sample() + + return index + + def _apply_ready(self, state): + new_code = transform_dialect_tile( + state.transformed_code, + state.operation_tag, + self.parameters, + state.tmp_file + ) + + return new_code, bool(new_code) + + def update_features(self, operation_features): + new_operation_features = operation_features.copy() + for nested_loop, tile_size in zip(new_operation_features.nested_loops, self.parameters): + if tile_size == 0: + continue + nested_loop.upper_bound = tile_size + + return new_operation_features + + @staticmethod + def __get_tiles_count(ub: int) -> int: + """Get the number of tiling candidates for a given loop upper bound. + + Args: + ub (int): The loop upper bound. + + Returns: + int: The number of candidates. + """ + for i in range(cfg.num_tile_sizes): + ts = 2 ** i + if ub % ts != 0: + return i + 1 + return cfg.num_tile_sizes + 1 diff --git a/rl_autoschedular/actions/vectorization.py b/rl_autoschedular/actions/vectorization.py old mode 100644 new mode 100755 index 87bed4f..697fdc0 --- a/rl_autoschedular/actions/vectorization.py +++ b/rl_autoschedular/actions/vectorization.py @@ -1,120 +1,51 @@ -from utils.config import Config -from .base import Action -from rl_autoschedular.transforms import ( - transform_vectorize, transform_tile, - transform_decompose, transform_transpose_conv_2d -) -from rl_autoschedular.state import OperationFeatures, OperationState, OperationType -from typing import Callable, Optional - - -class Vectorization(Action): - """Class representing Vectorization action""" - - symbol = 'V' - parameters: None - - # --- constants --- - terminal = True - - # --- extras --- - preprocessing: list[Callable[[str], str]] - - def __init__( - self, - state: Optional[OperationState] = None, - requires_transpose: Optional[bool] = None, - requires_decompose: Optional[bool] = None, - decompose_tile_sizes: Optional[list[int]] = None, - **extras - ): - args_is_none = [ - requires_transpose is None, - requires_decompose is None, - decompose_tile_sizes is None - ] - if (state is None) in args_is_none: - raise ValueError("Either state or preprocessing attributes must be provided and not both") - if state: - op_feats = state.operation_features.copy() - - if op_feats.operation_type not in [OperationType.Pooling, OperationType.Conv]: - requires_transpose, requires_decompose, decompose_tile_sizes = False, False, [] - else: - if requires_transpose := self.__requires_transpose(op_feats): - op_feats.operation_name = 'linalg.conv_2d_nhwc_hwcf' - decompose_tile_sizes = [] - if requires_decompose := self.__requires_decompose(op_feats): - decompose_tile_sizes = self.__decompose_tile_sizes(op_feats) - super().__init__( - state, - requires_transpose=requires_transpose, - requires_decompose=requires_decompose, - decompose_tile_sizes=decompose_tile_sizes, - **extras - ) - - self.preprocessing = [] - if requires_transpose: - self.preprocessing.append(lambda c: transform_transpose_conv_2d(c, self.operation_tag)) - if requires_decompose: - self.preprocessing.append(lambda c: transform_tile(c, self.operation_tag, decompose_tile_sizes)) - self.preprocessing.append(lambda c: transform_decompose(c, self.operation_tag)) - - @classmethod - def is_allowed(cls, state): - if not state.operation_features.vectorizable: - return False - - op_iter_space = 1 - for nested_loop in state.operation_features.nested_loops: - op_iter_space *= nested_loop.upper_bound - return op_iter_space <= Config().vect_size_limit - - def _apply_ready(self, code): - for pre in self.preprocessing: - code = pre(code) - - return transform_vectorize(code, self.operation_tag) - - @classmethod - def __requires_transpose(cls, operation_features: OperationFeatures) -> bool: - return operation_features.operation_name == 'linalg.conv_2d_nhwc_fhwc' - - @classmethod - def __requires_decompose(cls, operation_features: OperationFeatures) -> bool: - """a.k.a is a two dimensional conv interface op""" - - if 'conv_2d' in operation_features.operation_name: - return True - - if operation_features.operation_type == OperationType.Pooling and len(operation_features.nested_loops) >= 6: - return True - - return False - - @classmethod - def __decompose_tile_sizes(cls, operation_features: OperationFeatures) -> list[int]: - tile_sizes = [0 for _ in operation_features.nested_loops] - - oh = None - if operation_features.operation_name == 'linalg.conv_2d': - oh = 0 - elif '_nhwc_' in operation_features.operation_name: - oh = 1 - elif '_nchw_' in operation_features.operation_name: - oh = 2 - - kh = None - if operation_features.operation_name == 'linalg.conv_2d': - kh = 2 - elif '_fchw' in operation_features.operation_name: - kh = 5 - elif '_hwc' in operation_features.operation_name or operation_features.operation_type == OperationType.Pooling: - kh = 4 - - if oh is not None and kh is not None: - tile_sizes[oh] = 1 - tile_sizes[kh] = 1 - - return tile_sizes +from .base import Action +from rl_autoschedular import config as cfg +from rl_autoschedular.transforms import transform_dialect_vectorize, transform_dialect_tile, apply_conv2d_decomposition +from rl_autoschedular.state import OperationState, OperationType + + +class Vectorization(Action): + """Class representing Vectorization action""" + + symbol = 'V' + parameters: None + terminal = True + + def __init__(self): + super().__init__() + + @classmethod + def is_allowed(cls, state): + if not state.operation_features.vectorizable: + return False + + op_iter_space = 1 + for nested_loop in state.operation_features.nested_loops: + op_iter_space *= nested_loop.upper_bound + return op_iter_space <= cfg.vect_size_limit + + def _apply_ready(self, state): + code = state.transformed_code + + # Decompose pooling operation to make it vectorizable + if state.operation_features.operation_type == OperationType.Pooling: + code, decomposed = self.__decompose_pooling(state) + if not decomposed: + raise Exception("Pooling decomposition not successful") + + new_code = transform_dialect_vectorize(code, state.operation_tag, state.tmp_file) + + return new_code, bool(new_code) + + @staticmethod + def __decompose_pooling(state: OperationState) -> tuple[str, bool]: + assert len(state.operation_features.nested_loops) == 6 + + # Tile the pooling operation for decomposition + tile_sizes = [0, 0, 1, 0, 1, 0] + new_code = transform_dialect_tile(state.transformed_code, state.operation_tag, tile_sizes, state.tmp_file) + + # Apply the decomposition + new_code = apply_conv2d_decomposition(new_code, state.operation_tag, state.tmp_file) + + return new_code, bool(new_code) diff --git a/rl_autoschedular/benchmarks.py b/rl_autoschedular/benchmarks.py deleted file mode 100644 index 08998a6..0000000 --- a/rl_autoschedular/benchmarks.py +++ /dev/null @@ -1,54 +0,0 @@ -from rl_autoschedular.state import BenchmarkFeatures, extract_bench_features_from_file -from utils.config import Config -import json -from tqdm import tqdm -import os - - -class Benchmarks: - """A class that holds benchmarks data""" - - data: list[BenchmarkFeatures] - - def __init__(self, is_training: bool = True): - """Load benchmarks - - Args: - is_training (bool): Whether to load train or evaluation set - """ - cfg = Config() - # Load benchmark names and execution times from json file - bench_json_file = cfg.json_file - - # If we are in evaluation mode, use the evaluation json file if provided - if cfg.eval_json_file and not is_training: - bench_json_file = cfg.eval_json_file - - with open(bench_json_file) as file: - benchmarks_json: dict[str, int] = json.load(file) - - # Build benchmark features - self.data = [] - for bench_name, root_exec_time in tqdm(benchmarks_json.items(), desc="Extracting benchmark features", unit="bench"): - bench_file = os.path.join(cfg.benchmarks_folder_path, bench_name + ".mlir") - benchmark_data = extract_bench_features_from_file(bench_name, bench_file, root_exec_time) - - if cfg.split_ops and is_training and len(benchmark_data.operation_tags) > 1 and 'lqcd' in benchmark_data.bench_name: - # Split LQCD benchmarks into multiple single operations - # TODO: Improve with operatine-wise timing - # TODO: Convert LQCD to tensor-based to elliminate all of this - for tag in benchmark_data.operation_tags: - # Create a new benchmark data with only the current operation - new_bench_data = benchmark_data.copy() - new_bench_data.bench_name = f"{benchmark_data.bench_name}_{tag}" - new_bench_data.operation_tags = [tag] - new_bench_data.operations = {tag: new_bench_data.operations[tag]} - self.data.append(new_bench_data) - else: - self.data.append(benchmark_data) - - def __len__(self): - return len(self.data) - - def __getitem__(self, idx: int): - return self.data[idx] diff --git a/rl_autoschedular/env.py b/rl_autoschedular/env.py old mode 100644 new mode 100755 index 128939b..3c9b087 --- a/rl_autoschedular/env.py +++ b/rl_autoschedular/env.py @@ -1,281 +1,398 @@ -from rl_autoschedular.state import OperationState, BenchmarkFeatures, extract_bench_features_from_code -from rl_autoschedular.benchmarks import Benchmarks -from typing import Optional -from rl_autoschedular.execution import Execution -from rl_autoschedular.actions import Action, TiledFusion -from utils.log import print_error -from utils.config import Config -import random -import math -import traceback - - -class Env: - """RL Environment class""" - - bench_idx: int - """Index of the selected benchmark""" - benchmark_data: BenchmarkFeatures - """Features of the selected benchmark""" - - def reset(self, benchs: Benchmarks, bench_idx: Optional[int] = None) -> OperationState: - """Reset the environment. - - Args: - bench_idx (Optional[int]): The index of the benchmark to set the environement to. If None, a random benchmark is selected. Defaults to None. - - Returns: - OperationState: The initial state of the environment. - """ - # Get the benchmark - if bench_idx is None: - bench_idx = random.randint(0, len(benchs) - 1) - self.bench_idx = bench_idx - self.benchmark_data = benchs[bench_idx].copy() - - return self.__init_op_state(-1) - - def step(self, state: OperationState, action: Action) -> OperationState: - """Take a step in the environment. - - Args: - state (OperationState): The current state. - action (Action): The action to take. - - Returns: - OperationState: The new state. - float: The reward of the action. - bool: A flag indicating if the operation is done. - Optional[float]: The speedup (if the operation is executed successfully) for logging purposes. - """ - # Copy the current state to introduce the changes throughout the function - next_state = state.copy() - - # Update the state infos to reflect the transformation - self.__update_state_infos(next_state, action) - - # Check is state is terminal - next_state.terminal = action.terminal or next_state.step_count == Config().truncate - - return next_state - - def get_next_op_state(self, state: OperationState) -> Optional[OperationState]: - """Get the state that represents the next operation (None if benchmark is done). - - Args: - state (OperationState): The current state. - - Returns: - Optional[OperationState]: The next state. If None then bench is done. - """ - # Reset to another benchmark if the current benchmark is done (reached first operation) - if self.__bench_is_done(state): - return None - - # Build a new state that points to the next operation - next_state = self.__init_op_state(self.__current_op_index(state) - 1) - - # Keep track of the transformation history - next_state.transformation_history += state.transformation_history - - return next_state - - def apply_and_run_sequence(self, seq: list[list[Action]]) -> tuple[list[float], float, Optional[int], bool]: - transformed_code, rewards = self.__apply_sequence(self.benchmark_data.code, seq) - - # Evaluate the code (since the operation is done) - try: - new_exec_time, exec_succeeded, cache_miss = Execution().execute_code(transformed_code, self.benchmark_data.bench_name, seq) - if not exec_succeeded: - raise Exception("Incorrect results") - except Exception as e: - print_error(f"\n\nError while evaluating the code: {e}") - print_error("Exception type:", type(e).__name__) - print_error("Call stack:", traceback.format_exc()) - print_error("Bench:", self.benchmark_data.bench_name) - print_error("Transformations:", seq) - new_exec_time = None - exec_succeeded = False - cache_miss = True - - # The reward will take into consideration whether execution succeeded or not - rewards[-1] = self.__action_reward(True, exec_succeeded, new_exec_time, self.benchmark_data.root_exec_time) - speedup = (self.benchmark_data.root_exec_time / new_exec_time) if new_exec_time is not None else 1.0 - - return rewards, speedup, new_exec_time, cache_miss - - def __init_op_state(self, operation_idx: int) -> OperationState: - """Create a new operation state. - - Args: - operation_idx (int): The operation index. - - Returns: - OperationState: The new operation state. - torch.Tensor: The observation vector of the new operation state. - """ - operation_tag = self.benchmark_data.operation_tags[operation_idx] - operation_features = self.benchmark_data.operations[operation_tag] - - producer_tag = None - producer_features = None - if operation_features.producers: - producer_tag = operation_features.producers[-1] - producer_features = self.benchmark_data.operations[producer_tag] - - state = OperationState( - bench_idx=self.bench_idx, - bench_name=self.benchmark_data.bench_name, - operation_tag=operation_tag, - original_operation_features=operation_features.copy(), - operation_features=operation_features.copy(), - producer_tag=producer_tag, - producer_features=producer_features.copy() if producer_features else None, - step_count=0, - transformation_history=[[]], - terminal=False, - ) - - return state - - def __current_op_index(self, state: OperationState) -> int: - """Get the index of the current operation. - - Args: - state (OperationState): The current state. - - Returns: - int: The index of the current operation. - """ - return self.benchmark_data.operation_tags.index(state.operation_tag) - - def __bench_is_done(self, state: OperationState) -> bool: - """Check if the benchmark is done. - - Args: - state (OperationState): The current state. - - Returns: - bool: A flag indicating if the benchmark is done. - """ - return self.__current_op_index(state) == 0 - - def __action_reward(self, trans_succeeded: bool, exec_succeeded: Optional[bool] = None, new_exec_time: Optional[int] = None, old_exec_time: Optional[int] = None) -> float: - """Get the reward of the action based on the transformation and execution results. - - Args: - trans_succeeded (bool): A flag indicating if the transformation was successful. - exec_succeeded (Optional[bool]): A flag indicating if the execution was successful. (required if trans succeeded) - new_exec_time (Optional[float]): The execution time after transformation. (required if exec succeeded) - old_exec_time (Optional[float]): The original execution time. (required if exec succeeded) - - Returns: - float: The reward of the action. - """ - if not trans_succeeded: - return -5.0 - - assert exec_succeeded is not None - if not exec_succeeded: - return -20.0 - - assert new_exec_time is not None and old_exec_time is not None - return self.__speedup_reward(new_exec_time, old_exec_time) - - def __speedup_reward(self, new: int, old: int) -> float: - """Get the reward based on the speedup. - - Args: - new (float): The new execution time. - old (float): The old execution time. - - Returns: - float: The calculated reward. - """ - - # if old < new * 2: - # return math.log(old / (new * 2)) - # else: - # return old / (new * 2) - 1 - return math.log10(old / new) - - def __update_state_infos(self, state: OperationState, action: Action): - """Update state infos after applying a transformation. - - Notes: Updated fields are: - - operation_features (to reflect the transformation) - - transformation_history - - step_count - - producers features in case of fusion - (currently it's done by updating bench features, this should be changed after) - - Args: - state (OperationState): The current state. - action (Action): The action taken. - - Returns: - OperationState: The updated state. - """ - # Get updated operation features - state.operation_features = action.update_features(state.operation_features) - - # Record action - if state.step_count < len(state.transformation_history[0]): - # Case where the last action should be replaced - previous_action = state.transformation_history[0][state.step_count] - assert not previous_action.ready, f"Expected action {previous_action} not to be ready" - - action.sub_actions = previous_action.sub_actions + [previous_action] - state.transformation_history[0][state.step_count] = action - else: - state.transformation_history[0].append(action) - - # In case of fusion we need to update the producer features as well - # TODO: Maybe we can do this without actually applying the actions - if isinstance(action, TiledFusion): - fused_code, _ = self.__apply_sequence(self.benchmark_data.code, state.transformation_history) - new_bench_features = extract_bench_features_from_code('', fused_code, 0) - self.benchmark_data.operation_tags = new_bench_features.operation_tags - self.benchmark_data.operations = new_bench_features.operations - - # Increase count only if action was applied - if action.ready: - state.step_count += 1 - - def __apply_sequence(self, code: str, seq: list[list[Action]]) -> tuple[str, list[float]]: - """Apply the sequence of actions to the state's code. - - Args: - code (str): code to apply the actions to. - seq (list[Action]): the sequence of actions to apply. - - Returns: - tuple[str, list[float]]: the resulting code and rewards received for each action in the sequence. - """ - rewards: list[float] = [] - transformed_code = code - for op_seq in reversed(seq): - op_seq_already_failed = False - for action in op_seq: - # We need to assign the same reward to all sub actions - rewards_count = len(action.sub_actions) + 1 - - if op_seq_already_failed: - rewards.extend([rewards[-1]] * rewards_count) - continue - - # Attempt to apply the transformation to the code - # - If the transformation fails: punish the agent, reset the code, and mark the operation as done - new_transformed_code, trans_succeeded = action.apply(transformed_code) - if not trans_succeeded: - print_error("Transformation Failed:", action) - rewards.extend([self.__action_reward(trans_succeeded)] * rewards_count) - op_seq_already_failed = True - continue - - # Update transformed code - transformed_code = new_transformed_code - - rewards.extend([0.0] * rewards_count) - - return transformed_code, rewards +from rl_autoschedular import config as cfg +from rl_autoschedular.state import OperationState, BenchmarkFeatures, extract_bench_features_from_file +from typing import Optional +from rl_autoschedular.evaluation import evaluate_code +from rl_autoschedular.actions import Action +from utils.log import print_error +from tqdm import tqdm +import random +import string +import json +import os +import math +from enum import Enum + + +class Env: + """RL Environment class""" + + benchmarks_data: list[BenchmarkFeatures] + """Lists for each benchmark the benchmark's name and its features.""" + tmp_file: str + """The temporary file to store the intermediate representations.""" + is_training: bool + """Flag indicating if the environment is in training mode or evaluation mode.""" + + __bench_index: int + """The index of the current benchmark.""" + + def __init__(self, is_training: bool = True, tmp_file: Optional[str] = None,run_name: Optional[str] = None): + """Initialize the environment. + + Args: + tmp_file (Optional[str]): The temporary file to store the intermediate representations. Defaults to None. + """ + self.is_training = is_training + + # Generate a random file to be used in order to apply the transformations and evaluate the code + if tmp_file is None: + if run_name is not None: + random_str = run_name + else: + random_str = ''.join(random.choices(string.ascii_letters + string.digits, k=10)) + tmp_file = f"tmp-debug/{random_str}.mlir" if cfg.debug else f"tmp/{random_str}.mlir" + with open(tmp_file, "w") as file: + file.write("") + os.makedirs(tmp_file.replace(".mlir", ""), exist_ok=True) + self.tmp_file = tmp_file + + # Load benchmark names and execution times from json file + bench_json_file = cfg.json_file + + # If we are in evaluation mode, use the evaluation json file if provided + if cfg.eval_json_file and not is_training: + bench_json_file = cfg.eval_json_file + + with open(bench_json_file) as file: + benchmarks_json: dict[str, int] = json.load(file) + + # Build benchmark features + self.benchmarks_data = [] + for bench_name, root_exec_time in tqdm(benchmarks_json.items(), desc="Extracting benchmark features", unit="bench"): + bench_file = os.path.join(cfg.benchmarks_folder_path, bench_name + ".mlir") + benchmark_data = extract_bench_features_from_file(bench_name, bench_file, root_exec_time) + + if cfg.split_ops and is_training and len(benchmark_data.operation_tags) > 1: + # Split benchmarks with more than one operation into multiple benchmarks + for tag in benchmark_data.operation_tags: + # Create a new benchmark data with only the current operation + new_bench_data = benchmark_data.copy() + new_bench_data.bench_name = f"{benchmark_data.bench_name}_{tag}" + new_bench_data.operation_tags = [tag] + new_bench_data.operations = {tag: new_bench_data.operations[tag]} + self.benchmarks_data.append(new_bench_data) + else: + self.benchmarks_data.append(benchmark_data) + + def save_benchmarks_data_to_json(self, output_file: str = "benchmarks_data.json"): + """Save benchmarks_data to JSON file by converting dataclasses to dictionaries.""" + + def dataclass_to_dict(obj): + """Recursively convert dataclass objects to dictionaries.""" + if hasattr(obj, '__dataclass_fields__'): + # It's a dataclass + result = {} + for field_name, field_value in obj.__dict__.items(): + result[field_name] = dataclass_to_dict(field_value) + return result + elif isinstance(obj, list): + return [dataclass_to_dict(item) for item in obj] + elif isinstance(obj, dict): + return {key: dataclass_to_dict(value) for key, value in obj.items()} + elif isinstance(obj, Enum): + return obj.value + else: + return obj + + # Convert benchmarks_data to serializable format + serializable_data = [] + for benchmark in self.benchmarks_data: + serializable_data.append(dataclass_to_dict(benchmark)) + + # Save to JSON file + with open(output_file, 'w') as file: + json.dump(serializable_data, file, indent=2) + + print(f"Benchmarks data saved to {output_file}") + + + + def reset(self, bench_idx: Optional[int] = None) -> OperationState: + """Reset the environment. + + Args: + bench_idx (Optional[int]): The index of the benchmark to set the environement to. If None, a random benchmark is selected. Defaults to None. + + Returns: + OperationState: The initial state of the environment. + torch.Tensor: The observation vector of the initial state. + """ + # Get the benchmark + if bench_idx is not None: + self.__bench_index = bench_idx + else: + self.__bench_index = random.randint(0, len(self.benchmarks_data) - 1) + + return self.__init_op_state(-1) + + def step(self, state: OperationState, action: Action) -> tuple[OperationState, float, bool, Optional[float]]: + """Take a step in the environment. + + Args: + state (OperationState): The current state. + action (Action): The action to take. + + Returns: + OperationState: The new state. + float: The reward of the action. + bool: A flag indicating if the operation is done. + Optional[float]: The speedup (if the operation is executed successfully) for logging purposes. + """ + + # TODO: Add logic of calculating reward based on sparse or dense reward + # sparsity logic to be updated + # When sparse reward is False, reward is given after the end of each transformation based on the speedup + + + # Copy the current state to introduce the changes throughout the function + next_state = state.copy() + + # Attempt to apply the transformation to the code + # - If the transformation fails: punish the agent, reset the code, and mark the operation as done + new_transformed_code, trans_succeeded = action.apply(next_state) + if not trans_succeeded: + print_error("Transformation Failed:", action) + reward = self.__action_reward(trans_succeeded) + self.__remove_invalid_trans(next_state) + return next_state, reward, True, 1.0 + + # Register the new code (transformation succeeded) + next_state.transformed_code = new_transformed_code + + # Update the state infos to reflect the transformation + self.__update_state_infos(next_state, action) + + # The operation is done if: + # - The transformation is terminal + # - Maximum number of steps is reached + op_done = action.terminal or next_state.step_count == cfg.truncate + + # If the operation is not done, return the updated state with a reward of 0 + if not op_done: + return next_state, 0.0, False, None + + # Mark the state as terminal + next_state.terminal = True + + # Evaluate the code (since the operation is done) + try: + new_exec_time, exec_succeeded = evaluate_code(next_state, self.__current_bench_data) + if isinstance(exec_succeeded, Exception): + raise exec_succeeded + if not exec_succeeded or new_exec_time is None: + raise Exception("Execution failed") + except Exception as e: + print_error(f"\n\nError while evaluating the code: {e}") + print_error("Exception type:", type(e).__name__) + print_error("Call stack:", e.__traceback__) + print_error("Bench:", next_state.bench_name) + print_error("Transformations:", next_state.transformation_history) + exec_succeeded = False + new_exec_time = None + + # Next state and reward will take into consideration whether execution succeeded or not + # i.e: if execution failed: punish the agent, reset the code, and mark the operation as done + if cfg.sparse_reward: + # Sparse reward: reward is given only if the benchmark is done + # and it's calculated compared to the root execution time + if self.__bench_is_done(next_state): + reward = self.__action_reward(trans_succeeded, exec_succeeded, new_exec_time, self.__current_bench_data.root_exec_time) + else: + reward = 0.0 + else: + reward = self.__action_reward(trans_succeeded, exec_succeeded, new_exec_time, next_state.exec_time) + speedup = (self.__current_bench_data.root_exec_time / new_exec_time) if new_exec_time is not None else 1.0 + + # Update the state infos to reflect the execution + self.__update_state_exec_infos(next_state, new_exec_time) + + return next_state, reward, True, speedup + + def get_next_op_state(self, state: OperationState) -> tuple[Optional[OperationState], bool]: + """Get the state that represents the next operation (can be from another benchmark). + + Args: + state (OperationState): The current state. + + Returns: + OperationState: The next state. + bool: Flag indicating if the benchmark is done. + """ + # Reset to another benchmark if the current benchmark is done (reached first operation) + if self.__bench_is_done(state): + return None, True + + # Build a new state that points to the next operation + new_op_index = self.__current_op_index(state) - 1 + new_op_tag = self.__current_bench_data.operation_tags[new_op_index] + new_op_features = self.__current_bench_data.operations[new_op_tag] + next_state = OperationState( + bench_name=state.bench_name, + operation_tag=new_op_tag, # New operation tag + operation_features=new_op_features, # New operation features + validated_code=state.validated_code, + transformed_code=state.transformed_code, + step_count=0, # Reset step count + exec_time=state.exec_time, + transformation_history=[[]] + state.transformation_history, # Start new sequence + tmp_file=self.tmp_file, + terminal=False, + ) + + return next_state, False + + def __init_op_state(self, operation_idx: int) -> OperationState: + """Create a new operation state. + + Args: + operation_idx (int): The operation index. + + Returns: + OperationState: The new operation state. + torch.Tensor: The observation vector of the new operation state. + """ + operation_tag = self.__current_bench_data.operation_tags[operation_idx] + operation_features = self.__current_bench_data.operations[operation_tag] + + state = OperationState( + bench_name=self.__current_bench_data.bench_name, + operation_tag=operation_tag, + operation_features=operation_features.copy(), + validated_code=self.__current_bench_data.code, + transformed_code=self.__current_bench_data.code, + step_count=0, + exec_time=self.__current_bench_data.root_exec_time, + transformation_history=[[]], + tmp_file=self.tmp_file, + terminal=False, + ) + + return state + + @property + def __current_bench_data(self) -> BenchmarkFeatures: + """Get the current benchmark data. + + Returns: + BenchmarkFeatures: The current benchmark data. + """ + return self.benchmarks_data[self.__bench_index] + + def __current_op_index(self, state: OperationState) -> int: + """Get the index of the current operation. + + Args: + state (OperationState): The current state. + + Returns: + int: The index of the current operation. + """ + return self.__current_bench_data.operation_tags.index(state.operation_tag) + + def __bench_is_done(self, state: OperationState) -> bool: + """Check if the benchmark is done. + + Args: + state (OperationState): The current state. + + Returns: + bool: A flag indicating if the benchmark is done. + """ + return self.__current_op_index(state) == 0 + + def __action_reward(self, trans_succeeded: bool, exec_succeeded: Optional[bool] = None, new_exec_time: Optional[int] = None, old_exec_time: Optional[int] = None) -> float: + """Get the reward of the action based on the transformation and execution results. + + Args: + trans_succeeded (bool): A flag indicating if the transformation was successful. + exec_succeeded (Optional[bool]): A flag indicating if the execution was successful. (required if trans succeeded) + new_exec_time (Optional[float]): The execution time after transformation. (required if exec succeeded) + old_exec_time (Optional[float]): The original execution time. (required if exec succeeded) + + Returns: + float: The reward of the action. + """ + if not trans_succeeded: + return -5.0 + + assert exec_succeeded is not None + if not exec_succeeded: + return -20.0 + + assert new_exec_time is not None and old_exec_time is not None + return self.__speedup_reward(new_exec_time, old_exec_time) + + def __speedup_reward(self, new: int, old: int) -> float: + """Get the reward based on the speedup. + + Args: + new (float): The new execution time. + old (float): The old execution time. + + Returns: + float: The calculated reward. + """ + + # if old < new * 2: + # return math.log(old / (new * 2)) + # else: + # return old / (new * 2) - 1 + return math.log10(old / new) + + def __update_state_infos(self, state: OperationState, action: Action): + """Update state infos after applying a transformation. + + Notes: Updated fields are: + - operation_features (to reflect the transformation) + - transformation_history + - step _count + + Args: + state (OperationState): The current state. + action (Action): The action taken. + + Returns: + OperationState: The updated state. + """ + # Get updated operation features + state.operation_features = action.update_features(state.operation_features) + + # Record action + if state.step_count < len(state.transformation_history[0]): + # Case where the last action should be replaced + state.transformation_history[0][state.step_count] = action + else: + state.transformation_history[0].append(action) + + # Increase count only if action was applied + if action.ready: + state.step_count += 1 + + def __update_state_exec_infos(self, state: OperationState, new_exec_time: Optional[int]): + """Update the state execution infos after evaluating the code. + + Args: + state (OperationState): The current state. + new_exec_time (Optional[int]): The new execution time. + """ + # If the execution failed, reset the transformation sequence + if new_exec_time is None: + self.__remove_invalid_trans(state) + return + + # Mark the code as validated + state.validated_code = state.transformed_code + + # Update the execution time + state.exec_time = new_exec_time + + def __remove_invalid_trans(self, state: OperationState): + """Remove the latest invalid transformations and reset the transformation sequence. + + Args: + state (OperationState): The current state. + """ + # Reset the code to the last validated code + state.transformed_code = state.validated_code + + state.transformation_history[0] = [] diff --git a/rl_autoschedular/evaluation.py b/rl_autoschedular/evaluation.py new file mode 100755 index 0000000..2b4f7d5 --- /dev/null +++ b/rl_autoschedular/evaluation.py @@ -0,0 +1,348 @@ +import os +import numpy as np +from mlir.ir import Context, Module +from mlir.execution_engine import ExecutionEngine, ctypes +from mlir.runtime import get_ranked_memref_descriptor +from mlir.passmanager import PassManager +from typing import Union, Optional +import multiprocessing +from rl_autoschedular import config as cfg +from rl_autoschedular.state import OperationState, BenchmarkFeatures +from utils.log import print_alert +from statistics import median +import json +import re + + +def evaluate_code(state: OperationState, bench_data: BenchmarkFeatures) -> tuple[Optional[int], Union[Exception, bool]]: + """Evaluates the given MLIR code with a timeout. + + Args: + state (OperationState): The operation state to evaluate. + bench_data (BenchmarkFeatures): The benchmark features data. + timeout (Optional[float]): The timeout in seconds. + + Returns: + Optional[float]: the execution time in seconds. + Union[Exception, bool]: the assertion result or an exception if an error occurred. + """ + tmp_folder, tmp_file = state.tmp_file.split('/') + tmp_folder = os.path.join(tmp_folder, 'exec') + tmp_file = tmp_file.replace('.mlir', '.json') + tmp_exec_file = os.path.join(tmp_folder, tmp_file) + if not os.path.exists(tmp_exec_file): + os.makedirs(os.path.dirname(tmp_exec_file), exist_ok=True) + with open(tmp_exec_file, "w") as file: + json.dump({}, file) + + code_cache_key = __get_code_cache_key(state, bench_data) + cache_exec_time = __check_execution_cache(state.bench_name, code_cache_key, tmp_exec_file) + if cache_exec_time is not None: + return cache_exec_time, True + # print_alert('Cache miss') + + real_exec_time, success = evaluate_code_with_bindings(state.transformed_code) + + if success and real_exec_time is not None: + __update_execution_cache(state.bench_name, code_cache_key, real_exec_time, tmp_exec_file) + + return real_exec_time, success + + +# ================================== Evaluation Functions (Python Bindings) ================================== + +def evaluate_code_with_bindings(code: str) -> tuple[Optional[int], Union[Exception, bool]]: + """Lowers and runs the given MLIR code using Python bindings, then returns the execution time and assertion + result (if the executed code returns the correct result). + + Args: + code (str): The MLIR code to run. + + Returns: + Optional[float]: the execution time in seconds. + bool: the assertion result. + """ + pass_pipeline = """builtin.module( + loop-invariant-code-motion, + canonicalize, + + eliminate-empty-tensors, + empty-tensor-to-alloc-tensor, + one-shot-bufferize{ + bufferize-function-boundaries + function-boundary-type-conversion=identity-layout-map + }, + + convert-linalg-to-loops, + canonicalize, + buffer-deallocation-pipeline, + convert-bufferization-to-memref, + scf-forall-to-parallel, + convert-scf-to-openmp, + expand-strided-metadata, + finalize-memref-to-llvm, + convert-scf-to-cf, + lower-affine, + + convert-openmp-to-llvm, + convert-vector-to-llvm, + convert-math-to-llvm, + finalize-memref-to-llvm, + convert-func-to-llvm, + convert-index-to-llvm, + convert-arith-to-llvm, + convert-cf-to-llvm, + + reconcile-unrealized-casts, + canonicalize, + cse + )""" + + with Context(): + module = Module.parse(code) + pm = PassManager.parse(pass_pipeline) + pm.run(module.operation) + execution_engine = ExecutionEngine( + module, + opt_level=3, + shared_libs=os.getenv("MLIR_SHARED_LIBS", "").split(","), + ) + + inputs = __create_inputs(code) + + args = [] + for input_arg in inputs: + args.append(ctypes.pointer(ctypes.pointer( + get_ranked_memref_descriptor(input_arg) + ))) + + delta_arg = (ctypes.c_int64 * 1)(0) + args.append(delta_arg) + + times = [] + try: + for _ in range(5): + execution_engine.invoke("main", *args) + times.append(delta_arg[0]) + except Exception as e: + return None, e + + return median(times), True + + +def evaluate_code_with_bindings_wrapper(code: str, exec_times, assertions): + """Wrapper function for evaluate_code_with_bindings to be used in multiprocessing. + + Args: + code (str): The MLIR code to run. + function_name (str): The name of the function to run. + exec_times (list): A list to store the execution times. + assertions (list): A list to store the assertion results + """ + exec_time, assertion = evaluate_code_with_bindings(code) + exec_times.append(exec_time) + assertions.append(assertion) + + +def evaluate_code_with_bindings_and_timeout(code: str, timeout: Optional[float]) -> tuple[Optional[int], Union[Exception, bool]]: + """Evaluates the given MLIR code using Python bindings with a timeout. + + Args: + code (str): The MLIR code to run. + function_name (str): The name of the function to run. + timeout (Optional[float]): The timeout in seconds. + + Returns: + Optional[float]: the execution time in seconds. + bool: the assertion result. + """ + manager = multiprocessing.Manager() + exec_times = manager.list() + assertions = manager.list() + process = multiprocessing.Process(target=evaluate_code_with_bindings_wrapper, args=(code, exec_times, assertions)) + process.start() + process.join(timeout) + + if process.is_alive(): + # The function is still running, terminate the process + process.terminate() + process.join() + + return None, False + else: + # The function completed within the timeout + return exec_times[0], assertions[0] + + +# ================================== Evaluation Functions (MLIR CPU Runner) ================================== + +def evaluate_code_with_cmd(code: str, tmp_file_path: str) -> tuple[Optional[int], bool]: + """Lowers and runs the given MLIR code using MLIR opt and MLIR CPU Runner, then returns the execution time and assertion. + + Args: + code (str): The MLIR code to run. + tmp_file_path (str): The temporary file path to write the MLIR code. + + Returns: + Optional[float]: the execution time in seconds. + bool: the assertion result. + """ + command_1 = f"{os.getenv('LLVM_BUILD_PATH')}/bin/mlir-opt -loop-invariant-code-motion -canonicalize -eliminate-empty-tensors -empty-tensor-to-alloc-tensor -one-shot-bufferize='bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map' -convert-vector-to-scf -convert-linalg-to-loops -buffer-deallocation-pipeline -convert-bufferization-to-memref -scf-forall-to-parallel -convert-scf-to-openmp -expand-strided-metadata -finalize-memref-to-llvm -convert-scf-to-cf -lower-affine -convert-openmp-to-llvm -convert-vector-to-llvm -convert-math-to-llvm -finalize-memref-to-llvm -convert-func-to-llvm -convert-index-to-llvm -convert-arith-to-llvm -convert-cf-to-llvm -reconcile-unrealized-casts -canonicalize -cse" + command_2 = f"{os.getenv('LLVM_BUILD_PATH')}/bin/mlir-cpu-runner -e main -entry-point-result=void -shared-libs={os.getenv('LLVM_BUILD_PATH')}/lib/libmlir_runner_utils.so,{os.getenv('LLVM_BUILD_PATH')}/lib/libmlir_c_runner_utils.so,{os.getenv('LLVM_BUILD_PATH')}/lib/libomp.so" + + with open(tmp_file_path, "w") as file: + file.write(code) + + out = os.popen(f"""{command_1} {tmp_file_path} | {command_2} /dev/stdin""").read() + + if out: + return int(out.strip().split('\n')[-1]), True + else: + return None, False + + +def evaluate_code_with_cmd_wrapper(code: str, tmp_file_path: str, exec_times, assertions): + """Wrapper function for evaluate_code_with_cmd to be used in multiprocessing. + + Args: + code (str): The MLIR code to run. + tmp_file_path (str): The temporary file path to write the MLIR code. + exec_times (list): A list to store the execution times. + assertions (list): A list to store the assertion results + """ + exec_time, assertion = evaluate_code_with_cmd(code, tmp_file_path) + exec_times.append(exec_time) + assertions.append(assertion) + + +def evaluate_code_with_cmd_and_timeout(code: str, tmp_file_path: str, timeout: Optional[float]) -> tuple[Optional[int], bool]: + """Evaluates the given MLIR code using MLIR opt and MLIR CPU Runner with a timeout. + + Args: + code (str): The MLIR code to run. + tmp_file_path (str): The temporary file path to write the MLIR code. + timeout (Optional[float]): The timeout in seconds. + + Returns: + Optional[float]: the execution time in seconds. + bool: the assertion result. + """ + manager = multiprocessing.Manager() + exec_times = manager.list() + assertions = manager.list() + process = multiprocessing.Process(target=evaluate_code_with_cmd_wrapper, args=(code, tmp_file_path, exec_times, assertions)) + process.start() + process.join(timeout) + + if process.is_alive(): + # The function is still running, terminate the process + process.terminate() + process.join() + + return None, False + else: + # The function completed within the timeout + return exec_times[0], assertions[0] + + +def __check_execution_cache(bench_name: str, cache_key: str, tmp_exec_file: str) -> Optional[int]: + """Check the execution cache for the given operation state. + + Args: + cache_key (str): The cache key to check. + + Returns: + Optional[int]: the execution time in nanoseconds if the operation is found in the cache, otherwise None. + """ + # Start by checking the main execution cache file + if cfg.exec_data_file: + try: + with open(cfg.exec_data_file, "r") as file: + data = json.load(file) + + if bench_name in data and cache_key in data[bench_name]: + return int(data[bench_name][cache_key]) + except Exception: + pass + + # If no hit in the main cache file, check the temporary cache file + with open(tmp_exec_file, "r") as file: + data = json.load(file) + + if bench_name in data and cache_key in data[bench_name]: + return int(data[bench_name][cache_key]) + + # No hit in both cache files + return None + + +def __update_execution_cache(bench_name: str, cache_key: str, exec_time: int, tmp_exec_file: str): + """Update the temp execution cache with the given operation state. + + Args: + cache_key (str): The cache key to update. + exec_time (int): The execution time in nanoseconds. + """ + with open(tmp_exec_file, "r") as file: + data = json.load(file) + + if bench_name not in data: + data[bench_name] = {} + + if cache_key in data[bench_name]: + print_alert("Unexpected hit", data[bench_name][cache_key], exec_time) + return + data[bench_name][cache_key] = exec_time + + with open(tmp_exec_file, "w") as file: + json.dump(data, file, indent=4) + + +def __get_code_cache_key(state: OperationState, bench_data: BenchmarkFeatures) -> str: + """Get the code cache key for the given operation state. + + Args: + state (OperationState): The operation state to get the code cache key. + bench_data (BenchmarkFeatures): The benchmark features data. + + Returns: + str: the code cache key. + """ + ops_codes = [''] * len(bench_data.operation_tags) + for i, seq in enumerate(reversed(state.transformation_history)): + ops_codes[i] = ''.join(map(str, seq)) + + return '|'.join(reversed(ops_codes)) + + +def __create_inputs(code) -> list[np.ndarray]: + main_pattern = r"func.func @main\(([^)]+)\)" + main_params = re.search(main_pattern, code).group(1) + main_shapes = [arg.split(':')[1].strip() for arg in main_params.split(',')] + + inputs: list[np.ndarray] = [] + for shape in main_shapes: + assert shape.startswith('memref<') or shape.startswith('tensor<'), f'unexpected shape {shape}' + *np_shape, dtype = shape.replace('memref<', '').replace('tensor<', '').replace('>', '').split('x') + assert dtype[0] in ['f', 'i'] and dtype[1:] in ['32', '64'], f'unexpected dtype {dtype}' + match dtype[0]: + case 'f': + match dtype[1:]: + case '32': + np_dtype = np.float32 + case '64': + np_dtype = np.float64 + case 'i': + match dtype[1:]: + case '32': + np_dtype = np.int32 + case '64': + np_dtype = np.int64 + np_shape = list(map(int, np_shape)) + # if len(np_shape) > 0: + # inputs.append((np.random.rand(*np_shape) * 100).astype(np_dtype)) + # else: + # inputs.append(np.array(np.random.rand() * 100, dtype=np_dtype)) + inputs.append(np.zeros(np_shape, dtype=np_dtype)) + + return inputs diff --git a/rl_autoschedular/execution.py b/rl_autoschedular/execution.py index 6a43daa..7adab84 100644 --- a/rl_autoschedular/execution.py +++ b/rl_autoschedular/execution.py @@ -219,4 +219,4 @@ def __create_inputs(self, code) -> list[np.ndarray]: # inputs.append(np.array(np.random.rand() * 100, dtype=np_dtype)) inputs.append(np.zeros(np_shape, dtype=np_dtype)) - return inputs + return inputs \ No newline at end of file diff --git a/rl_autoschedular/model.py b/rl_autoschedular/model.py old mode 100644 new mode 100755 index a4d1baa..571d250 --- a/rl_autoschedular/model.py +++ b/rl_autoschedular/model.py @@ -1,230 +1,212 @@ -import torch -import torch.nn as nn -from torch.distributions import Distribution -from typing import Optional -from rl_autoschedular.actions import ActionSpace, Interchange -from rl_autoschedular.observation import OpFeatures, ActionHistory, ProducerOpFeatures, Observation -from utils.config import Config - - -ACTIVATION = nn.ReLU if Config().activation == 'relu' else nn.Tanh - - -class HiearchyModel(nn.Module): - """Hierarchical reinforcement learning model for MLIR code optimization.""" - def __init__(self): - """Initialize the model.""" - super(HiearchyModel, self).__init__() - - self.policy_model = PolicyModel() - self.value_model = ValueModel() - - def __call__(self, obs: torch.Tensor, actions_index: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - return super().__call__(obs, actions_index) - - def forward(self, obs: torch.Tensor, actions_index: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Forward pass of the model. - - Args: - obs (torch.Tensor): The input tensor. - actions_index (torch.Tensor): The list of actions. - - Returns: - tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: The logits of the transformations, parallelizations, tilings, and interchanges. - """ - actions_log_p, entropies = ActionSpace.distributions_stats(self.policy_model(obs), actions_index) - - values = self.value_model(obs) - - return actions_log_p, values, entropies - - def sample(self, obs: torch.Tensor, greedy: bool = False, eps: Optional[float] = None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Sample an action from the model. - - Args: - obs (torch.Tensor): The input tensor. - greedy (bool): Whether to sample greedily. - eps (Optional[float]): Epsilon value for exploration. Defaults to None. - - Returns: - torch.Tensor: Sampled actions index. - torch.Tensor: actions log probability. - torch.Tensor: resulting entropy. - """ - assert not greedy or eps is None, 'Cannot be greedy and explore at the same time.' - - # Model feedforward - distributions = self.policy_model(obs) - eps_distributions = ActionSpace.uniform_distributions(obs) - actions_index = ActionSpace.sample( - obs, - distributions, - eps_distributions, - uniform=eps is not None and torch.rand(1).item() < eps, - greedy=greedy - ) - actions_log_p, entropies = ActionSpace.distributions_stats( - distributions, - actions_index, - eps_distributions=eps_distributions if eps is not None else None, - eps=eps - ) - - return actions_index, actions_log_p, entropies - - -class ValueModel(nn.Module): - """Value model for MLIR code optimization.""" - def __init__(self): - """Initialize the model.""" - super(ValueModel, self).__init__() - - self.lstm = LSTMEmbedding() - - self.network = nn.Sequential( - nn.Linear(self.lstm.output_size, 512), - ACTIVATION(), - nn.Linear(512, 512), - ACTIVATION(), - nn.Linear(512, 512), - ACTIVATION(), - nn.Linear(512, 1), - ) - - def __call__(self, obs: torch.Tensor) -> torch.Tensor: - return super().__call__(obs) - - def forward(self, obs: torch.Tensor) -> torch.Tensor: - """Forward pass of the model. - - Args: - obs (torch.Tensor): The input tensor. - - Returns: - torch.Tensor: The value tensor. - """ - return self.network(self.lstm(obs)).squeeze(-1) - - def loss(self, new_values: torch.Tensor, values: torch.Tensor, returns: torch.Tensor) -> torch.Tensor: - """Calculate the value loss. - - Args: - new_values (torch.Tensor): The new value tensor. - values (torch.Tensor): The value tensor. - returns (torch.Tensor): The returns tensor. - - Returns: - torch.Tensor: The value loss. - """ - if Config().value_clip: - vclip = values + torch.clamp(new_values - values, -0.2, 0.2) - vloss1 = (returns - vclip).pow(2) - vloss2 = (returns - new_values).pow(2) - return torch.max(vloss1, vloss2).mean() - return (returns - new_values).pow(2).mean() - - -class PolicyModel(nn.Module): - """Policy model for MLIR code optimization.""" - def __init__(self): - """Initialize the model.""" - super(PolicyModel, self).__init__() - - self.log_std = Interchange.log_std - - self.lstm = LSTMEmbedding() - - self.backbone = nn.Sequential( - nn.Linear(self.lstm.output_size, 512), - ACTIVATION(), - nn.Linear(512, 512), - ACTIVATION(), - nn.Linear(512, 512), - ACTIVATION(), - ) - - output_sizes = [ActionSpace.size()] + [action.network_output_size() for action in ActionSpace.supported_actions] - self.heads = nn.ModuleList() - for output_size in output_sizes: - if not output_size: - self.heads.append(None) - continue - head = nn.Linear(512, output_size) - if Config().new_architecture: - head = nn.Sequential( - nn.Linear(512, 512), - ACTIVATION(), - head - ) - self.heads.append(head) - - def __call__(self, obs: torch.Tensor) -> list[Optional[Distribution]]: - return super().__call__(obs) - - def forward(self, obs: torch.Tensor) -> list[Optional[Distribution]]: - """Forward pass of the model. - - Args: - obs (torch.Tensor): The input tensor. - - Returns: - tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: The logits of the transformations, parallelizations, tilings, and interchanges. - """ - embedded = self.backbone(self.lstm(obs)) - actions_logits = [head(embedded) if head else None for head in self.heads] - - return ActionSpace.distributions(obs, *actions_logits) - - def loss(self, actions_log_p: torch.Tensor, actions_bev_log_p: torch.Tensor, off_policy_rates: torch.Tensor, advantages: torch.Tensor, clip_range: float = 0.2) -> tuple[torch.Tensor, torch.Tensor]: - """Calculate the policy loss. - - Args: - new_actions_log_p (torch.Tensor): The log probabilities of the new actions. - actions_bev_log_p (torch.Tensor): The log probabilities of the actions under the behavior policy. - off_policy_rates (torch.Tensor): The rate between the old policy and the behavioral (mu) policy. - advantages (torch.Tensor): The advantages of the actions. - clip_range (float): The clipping range for the policy loss. - - Returns: - torch.Tensor: The policy loss. - float: The ratio clip fraction (for logging purposes) - """ - ratios = torch.exp(torch.clamp(actions_log_p - actions_bev_log_p, -80.0, 80.0)) - surr1 = ratios * advantages - surr2 = torch.clamp(ratios, (1 - clip_range) * off_policy_rates, (1 + clip_range) * off_policy_rates) * advantages - clip_frac = (torch.abs((ratios / off_policy_rates - 1)) > clip_range).float().mean() - return - torch.min(surr1, surr2).mean(), clip_frac - - -class LSTMEmbedding(nn.Module): - def __init__(self): - super(LSTMEmbedding, self).__init__() - - embedding_size = 411 - - self.output_size = embedding_size + ActionHistory.size() - - self.embedding = nn.Sequential( - nn.Linear(OpFeatures.size(), 512), - nn.ELU(), - nn.Dropout(0.225), - nn.Linear(512, 512), - nn.ELU(), - nn.Dropout(0.225), - ) - - self.lstm = nn.LSTM(512, embedding_size) - - def __call__(self, obs: torch.Tensor) -> torch.Tensor: - return super().__call__(obs) - - def forward(self, obs: torch.Tensor) -> torch.Tensor: - consumer_feats = Observation.get_part(obs, OpFeatures) - producer_feats = Observation.get_part(obs, ProducerOpFeatures) - - consumer_embeddings = self.embedding(consumer_feats).unsqueeze(0) - producer_embeddings = self.embedding(producer_feats).unsqueeze(0) - - _, (final_hidden, _) = self.lstm(torch.cat((consumer_embeddings, producer_embeddings))) - - return torch.cat((final_hidden.squeeze(0), Observation.get_part(obs, ActionHistory)), 1) +import torch +import torch.nn as nn +from torch.distributions import Distribution +from typing import Optional +from rl_autoschedular import config as cfg +from rl_autoschedular.actions import ActionSpace, Interchange +from rl_autoschedular.observation import OpFeatures, ActionHistory, Observation, ObservationPart + + +ACTIVATION = nn.ReLU if cfg.activation == 'relu' else nn.Tanh + + +class HiearchyModel(nn.Module): + """Hierarchical reinforcement learning model for MLIR code optimization.""" + def __init__(self): + """Initialize the model.""" + super(HiearchyModel, self).__init__() + + self.policy_model = PolicyModel([OpFeatures, ActionHistory]) + self.value_model = ValueModel([OpFeatures, ActionHistory]) + + def __call__(self, obs: torch.Tensor, actions_index: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return super().__call__(obs, actions_index) + + def forward(self, obs: torch.Tensor, actions_index: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Forward pass of the model. + + Args: + obs (torch.Tensor): The input tensor. + actions_index (torch.Tensor): The list of actions. + + Returns: + tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: The logits of the transformations, parallelizations, tilings, and interchanges. + """ + actions_log_p, entropies = ActionSpace.distributions_stats(self.policy_model(obs), actions_index) + + values = self.value_model(obs) + + return actions_log_p, values, entropies + + def sample(self, obs: torch.Tensor, greedy: bool = False, eps: Optional[float] = None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Sample an action from the model. + + Args: + obs (torch.Tensor): The input tensor. + greedy (bool): Whether to sample greedily. + eps (Optional[float]): Epsilon value for exploration. Defaults to None. + + Returns: + torch.Tensor: Sampled actions index. + torch.Tensor: actions log probability. + torch.Tensor: resulting entropy. + """ + assert not greedy or eps is None, 'Cannot be greedy and explore at the same time.' + + # Model feedforward + distributions = self.policy_model(obs) + eps_distributions = ActionSpace.uniform_distributions(obs) + actions_index = ActionSpace.sample( + obs, + distributions, + eps_distributions, + uniform=eps is not None and torch.rand(1).item() < eps, + greedy=greedy + ) + actions_log_p, entropies = ActionSpace.distributions_stats( + distributions, + actions_index, + eps_distributions=eps_distributions if eps is not None else None, + eps=eps + ) + + return actions_index, actions_log_p, entropies + + +class ValueModel(nn.Module): + """Value model for MLIR code optimization.""" + def __init__(self, obs_parts: list[type[ObservationPart]]): + """Initialize the model. + + Args: + obs_parts (list[type[ObservationPart]]): List of observation parts to be used in the model. + """ + super(ValueModel, self).__init__() + + self.obs_parts = obs_parts + self.network = nn.Sequential( + nn.Linear(sum(part.size() for part in obs_parts), 512), + ACTIVATION(), + nn.Linear(512, 512), + ACTIVATION(), + nn.Linear(512, 512), + ACTIVATION(), + nn.Linear(512, 1), + ) + + def __call__(self, obs: torch.Tensor) -> torch.Tensor: + return super().__call__(obs) + + def forward(self, obs: torch.Tensor) -> torch.Tensor: + """Forward pass of the model. + + Args: + obs (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The value tensor. + """ + return self.network(Observation.get_parts(obs, *self.obs_parts)).squeeze(-1) + + def loss(self, new_values: torch.Tensor, values: torch.Tensor, returns: torch.Tensor) -> torch.Tensor: + """Calculate the value loss. + + Args: + new_values (torch.Tensor): The new value tensor. + values (torch.Tensor): The value tensor. + returns (torch.Tensor): The returns tensor. + + Returns: + torch.Tensor: The value loss. + """ + if cfg.value_clip: + vclip = values + torch.clamp(new_values - values, -0.2, 0.2) + vloss1 = (returns - vclip).pow(2) + vloss2 = (returns - new_values).pow(2) + return torch.max(vloss1, vloss2).mean() + return (returns - new_values).pow(2).mean() + + +class PolicyModel(nn.Module): + """Policy model for MLIR code optimization.""" + def __init__(self, obs_parts: list[type[ObservationPart]]): + """Initialize the model. + + Args: + obs_parts (list[type[ObservationPart]]): List of observation parts to be used in the model. + """ + super(PolicyModel, self).__init__() + + self.obs_parts = obs_parts + Interchange.log_std = nn.Parameter(torch.zeros(1)) + + self.backbone = nn.Sequential( + nn.Linear(sum(part.size() for part in obs_parts), 512), + ACTIVATION(), + nn.Linear(512, 512), + ACTIVATION(), + nn.Linear(512, 512), + ACTIVATION(), + ) + + output_sizes = [ActionSpace.size()] + [action.network_output_size() for action in ActionSpace.supported_actions] + self.heads_attributes = [f'head_{i}' for i in range(len(output_sizes))] + + for head_attr, output_size in zip(self.heads_attributes, output_sizes): + if not output_size: + setattr(self, head_attr, None) + continue + + head = nn.Linear(512, output_size) + if cfg.new_architecture: + head = nn.Sequential( + nn.Linear(512, 512), + ACTIVATION(), + head + ) + setattr(self, head_attr, head) + + def __call__(self, obs: torch.Tensor,) -> list[Optional[Distribution]]: + return super().__call__(obs) + + def forward(self, obs: torch.Tensor) -> list[Optional[Distribution]]: + """Forward pass of the model. + + Args: + obs (torch.Tensor): The input tensor. + + Returns: + tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: The logits of the transformations, parallelizations, tilings, and interchanges. + + """ + embedded = self.backbone(Observation.get_parts(obs, *self.obs_parts)) + heads: list[Optional[nn.Module]] = [getattr(self, attr) for attr in self.heads_attributes] + actions_logits = [head(embedded) if head else None for head in heads] + + return ActionSpace.distributions(obs, *actions_logits) + + def loss(self, actions_log_p: torch.Tensor, actions_old_log_p: torch.Tensor, advantages: torch.Tensor, clip_range: float = 0.2) -> tuple[torch.Tensor, torch.Tensor]: + """Calculate the policy loss. + + Args: + new_actions_log_p (torch.Tensor): The log probabilities of the new actions. + actions_old_log_p (torch.Tensor): The log probabilities of the actions under the behavior policy. + advantages (torch.Tensor): The advantages of the actions. + clip_range (float): The clipping range for the policy loss. + + Returns: + torch.Tensor: The policy loss. + float: The ratio clip fraction (for logging purposes) + """ + # Importance sampling ratio + ratios = torch.exp(actions_log_p - actions_old_log_p) + + # PPO surrogate objective + surr1 = ratios * advantages + surr2 = torch.clamp(ratios, 1.0 - clip_range, 1.0 + clip_range) * advantages + policy_loss = -torch.min(surr1, surr2).mean() + + # Fraction of samples where clipping applied + clip_frac = (torch.abs(ratios - 1.0) > clip_range).float().mean() + + return policy_loss, clip_frac \ No newline at end of file diff --git a/rl_autoschedular/observation.py b/rl_autoschedular/observation.py old mode 100644 new mode 100755 index 2fd0c79..8e31642 --- a/rl_autoschedular/observation.py +++ b/rl_autoschedular/observation.py @@ -1,245 +1,227 @@ -from rl_autoschedular.actions import ActionSpace -from rl_autoschedular.state import OperationState, OperationType, IteratorType, OperationFeatures -import torch -import math - -from utils.config import Config - -L = Config().max_num_loops -LSD = Config().max_num_load_store_dim -LS = Config().max_num_stores_loads - - -class ObservationPart: - @classmethod - def size(cls) -> int: - raise NotImplementedError - - @classmethod - def from_state(cls, state: OperationState) -> torch.Tensor: - """Create the observation part from the current state.""" - raise NotImplementedError - - -class OpFeatures(ObservationPart): - """Class representing operation features in the observation""" - - arith_ops = ['+', '-', '*', '/', 'exp'] - - @classmethod - def size(cls) -> int: - return len(OperationType) + L + L + 1 + LS * LSD * L + LSD * L + len(cls.arith_ops) - - @classmethod - def from_state(cls, state: OperationState) -> torch.Tensor: - return cls._from_features(state.original_operation_features) - - @classmethod - def _from_features(cls, op_features: OperationFeatures) -> torch.Tensor: - indices = [nested_loop.arg for nested_loop in op_features.nested_loops] - indices_dim = {arg: i for (i, arg) in enumerate(indices)} - - # Operation type - op_type = torch.tensor([op_features.operation_type == ot for ot in OperationType]) - - # Nested loop features: (upper bounds, iterator types) - nested_loops = torch.zeros(L) - iterator_types = torch.zeros(L) - for i, nested_loop in enumerate(op_features.nested_loops): - if i == L: - break - ub = nested_loop.upper_bound - match Config().normalize_bounds: - case 'max': - ub = ub / 4096 - case 'log': - ub = math.log2(ub) - nested_loops[i] = ub - iterator_types[i] = nested_loop.iterator_type == IteratorType.Parallel - - # Vectorizable - vectorizable = torch.tensor([op_features.vectorizable]) - - # load access matrices: - load_access_matrices = torch.zeros((LS, LSD, L)) - - for load_i, load in enumerate(op_features.load_data): - if load_i == LS: - break - dimensions_terms = [cls.__formula_str_to_list(term) for term in load] - for m, dimension_term in enumerate(dimensions_terms): - if m == LSD: - break - for index, factor in dimension_term: - if index not in indices_dim: - continue - n = indices_dim[index] - if n >= L: - continue - load_access_matrices[load_i, m, n] = factor - - # store access matrices: - store_access_matrices = torch.zeros((LSD, L)) - - dimensions_terms = [cls.__formula_str_to_list(term) for term in op_features.store_data] - for m, dimension_term in enumerate(dimensions_terms): - if m == LSD: - break - for index, factor in dimension_term: - if index not in indices_dim: - continue - n = indices_dim[index] - if n >= L: - continue - store_access_matrices[m, n] = factor - - # Operations count: - operations_count = torch.tensor([op_features.op_count[s] for s in cls.arith_ops]) - - feature_vector = torch.cat(( - op_type, - nested_loops, - iterator_types, - vectorizable, - load_access_matrices.reshape(-1), - store_access_matrices.reshape(-1), - operations_count - )) - - return feature_vector - - @staticmethod - def __formula_str_to_list(formula: str) -> list[tuple[str, int]]: - """Turns assignement formula to a list of (index, factor) - Example: - formula = "%x1 - %x2 + %x3 * 5 - %x5 * 3" - return [('%x1', 1), ('%x2', -1), ('%x3', 5), ('%x5', -3)] - - Args: - formula (str): the formula as a string input - - Returns: - list: list of (index, factor) pairs - """ - formula = formula + ' +' - terms = formula.split(' ') - - running_factor = 1 - running_term = None - - save = [] - - for term in terms: - - if term.startswith('%'): - running_term = term - elif term == '+': - save.append((running_term, running_factor)) - running_factor = 1 - elif term == '-': - save.append((running_term, running_factor)) - running_factor = -1 - elif term.isnumeric(): - running_factor *= int(term) - - if save[0][0] is None: - save = save[1:] - - return save - - -class ProducerOpFeatures(OpFeatures): - @classmethod - def from_state(cls, state: OperationState) -> torch.Tensor: - if state.producer_features: - return cls._from_features(state.producer_features) - - return torch.zeros(cls.size()) - - -class ActionHistory(ObservationPart): - """Class representing action history in the observation""" - - @classmethod - def size(cls) -> int: - return ActionSpace.cumulative_history_sizes()[-1] - - @classmethod - def from_state(cls, state: OperationState) -> torch.Tensor: - return ActionSpace.action_history(state) - - -class ActionMask(ObservationPart): - """Class representing action mask in the observation""" - - @classmethod - def size(cls) -> int: - return ActionSpace.cumulative_mask_sizes()[-1] - - @classmethod - def from_state(cls, state: OperationState) -> torch.Tensor: - return ActionSpace.action_mask(state) - - -class NumLoops(ObservationPart): - """Class representing number of loops in the observation""" - - @classmethod - def size(cls) -> int: - return 1 - - @classmethod - def from_state(cls, state: OperationState) -> torch.Tensor: - return torch.tensor([len(state.operation_features.nested_loops)]) - - -class Observation: - """Class to manage creation and use of observations""" - - parts: list[type[ObservationPart]] = [ - OpFeatures, - ProducerOpFeatures, - ActionHistory, - NumLoops, - ActionMask - ] - - @classmethod - def cumulative_sizes(cls) -> list[int]: - """Get cumulative sizes of all observation parts.""" - sizes = [0] - for part in cls.parts: - sizes.append(sizes[-1] + part.size()) - return sizes - - @classmethod - def part_number(cls, part: type[ObservationPart]) -> int: - """Get the index of a part in the observation.""" - return cls.parts.index(part) - - @classmethod - def get_part(cls, obs: torch.Tensor, part: type[ObservationPart], squeeze: bool = True) -> torch.Tensor: - """Get a specific part of the observation.""" - part_idx = cls.part_number(part) - cum_sizes = cls.cumulative_sizes() - start = cum_sizes[part_idx] - if part.size() == 1 and squeeze: - return obs[:, start] - end = cum_sizes[part_idx + 1] - return obs[:, start:end] - - @classmethod - def get_parts(cls, obs: torch.Tensor, *parts: type[ObservationPart]) -> torch.Tensor: - """Get multiple parts of the observation in a single tensor.""" - return torch.cat([cls.get_part(obs, part, False) for part in parts], dim=1) - - @classmethod - def from_state(cls, state: OperationState) -> torch.Tensor: - """Create the full observation from the current state.""" - obs_parts = [part.from_state(state) for part in cls.parts] - return torch.cat(obs_parts).unsqueeze(0) - - @classmethod - def from_states(cls, states: list[OperationState]) -> torch.Tensor: - """Create the full observation for all the states.""" - return torch.cat([cls.from_state(s) for s in states]) +from rl_autoschedular import config as cfg +from rl_autoschedular.actions import ActionSpace +from rl_autoschedular.state import OperationState, OperationType, IteratorType +import torch +import math + +L = cfg.max_num_loops +LSD = cfg.max_num_load_store_dim +LS = cfg.max_num_stores_loads + + +class ObservationPart: + @classmethod + def size(cls) -> int: + raise NotImplementedError + + @classmethod + def from_state(cls, state: OperationState) -> torch.Tensor: + """Create the observation part from the current state.""" + raise NotImplementedError + + +class OpFeatures(ObservationPart): + """Class representing operation features in the observation""" + + arith_ops = ['+', '-', '*', '/', 'exp'] + + @classmethod + def size(cls) -> int: + return len(OperationType) + L + L + 1 + LS * LSD * L + LSD * L + len(cls.arith_ops) + + @classmethod + def from_state(cls, state: OperationState) -> torch.Tensor: + op_features = state.operation_features + + indices = [nested_loop.arg for nested_loop in op_features.nested_loops] + indices_dim = {arg: i for (i, arg) in enumerate(indices)} + + # Operation type + op_type = torch.tensor([op_features.operation_type == ot for ot in OperationType]) + + # Nested loop features: (upper bounds, iterator types) + nested_loops = torch.zeros(L) + iterator_types = torch.zeros(L) + for i, nested_loop in enumerate(op_features.nested_loops): + if i == L: + break + ub = nested_loop.upper_bound + match cfg.normalize_bounds: + case 'max': + ub = ub / 4096 + case 'log': + ub = math.log2(ub) + nested_loops[i] = ub + iterator_types[i] = nested_loop.iterator_type == IteratorType.Parallel + + # Vectorizable + vectorizable = torch.tensor([op_features.vectorizable]) + + # load access matrices: + load_access_matrices = torch.zeros((LS, LSD, L)) + + for load_i, load in enumerate(op_features.load_data): + if load_i == LS: + break + dimensions_terms = [cls.__formula_str_to_list(term) for term in load] + for m, dimension_term in enumerate(dimensions_terms): + if m == LSD: + break + for index, factor in dimension_term: + if index not in indices_dim: + continue + n = indices_dim[index] + if n >= L: + continue + load_access_matrices[load_i, m, n] = factor + + # store access matrices: + store_access_matrices = torch.zeros((LSD, L)) + + dimensions_terms = [cls.__formula_str_to_list(term) for term in op_features.store_data] + for m, dimension_term in enumerate(dimensions_terms): + if m == LSD: + break + for index, factor in dimension_term: + if index not in indices_dim: + continue + n = indices_dim[index] + if n >= L: + continue + store_access_matrices[m, n] = factor + + # Operations count: + operations_count = torch.tensor([op_features.op_count[s] for s in cls.arith_ops]) + + feature_vector = torch.cat(( + op_type, + nested_loops, + iterator_types, + vectorizable, + load_access_matrices.reshape(-1), + store_access_matrices.reshape(-1), + operations_count + )) + + return feature_vector + + @staticmethod + def __formula_str_to_list(formula: str) -> list[tuple[str, int]]: + """Turns assignement formula to a list of (index, factor) + Example: + formula = "%x1 - %x2 + %x3 * 5 - %x5 * 3" + return [('%x1', 1), ('%x2', -1), ('%x3', 5), ('%x5', -3)] + + Args: + formula (str): the formula as a string input + + Returns: + list: list of (index, factor) pairs + """ + formula = formula + ' +' + terms = formula.split(' ') + + running_factor = 1 + running_term = None + + save = [] + + for term in terms: + + if term.startswith('%'): + running_term = term + elif term == '+': + save.append((running_term, running_factor)) + running_factor = 1 + elif term == '-': + save.append((running_term, running_factor)) + running_factor = -1 + elif term.isnumeric(): + running_factor *= int(term) + + if save[0][0] is None: + save = save[1:] + + return save + + +class ActionHistory(ObservationPart): + """Class representing action history in the observation""" + + @classmethod + def size(cls) -> int: + return ActionSpace.cumulative_history_sizes()[-1] + + @classmethod + def from_state(cls, state: OperationState) -> torch.Tensor: + return ActionSpace.action_history(state) + + +class ActionMask(ObservationPart): + """Class representing action mask in the observation""" + + @classmethod + def size(cls) -> int: + return ActionSpace.cumulative_mask_sizes()[-1] + + @classmethod + def from_state(cls, state: OperationState) -> torch.Tensor: + return ActionSpace.action_mask(state) + + +class NumLoops(ObservationPart): + """Class representing number of loops in the observation""" + + @classmethod + def size(cls) -> int: + return 1 + + @classmethod + def from_state(cls, state: OperationState) -> torch.Tensor: + return torch.tensor([len(state.operation_features.nested_loops)]) + + +class Observation: + """Class to manage creation and use of observations""" + + parts: list[type[ObservationPart]] = [ + OpFeatures, + ActionHistory, + NumLoops, + ActionMask + ] + + @classmethod + def cumulative_sizes(cls) -> list[int]: + """Get cumulative sizes of all observation parts.""" + sizes = [0] + for part in cls.parts: + sizes.append(sizes[-1] + part.size()) + return sizes + + @classmethod + def part_number(cls, part: type[ObservationPart]) -> int: + """Get the index of a part in the observation.""" + return cls.parts.index(part) + + @classmethod + def get_part(cls, obs: torch.Tensor, part: type[ObservationPart], squeeze: bool = True) -> torch.Tensor: + """Get a specific part of the observation.""" + part_idx = cls.part_number(part) + cum_sizes = cls.cumulative_sizes() + start = cum_sizes[part_idx] + if part.size() == 1 and squeeze: + return obs[:, start] + end = cum_sizes[part_idx + 1] + return obs[:, start:end] + + @classmethod + def get_parts(cls, obs: torch.Tensor, *parts: type[ObservationPart]) -> torch.Tensor: + """Get multiple parts of the observation in a single tensor.""" + return torch.cat([cls.get_part(obs, part, False) for part in parts], dim=1) + + @classmethod + def from_state(cls, state: OperationState) -> torch.Tensor: + """Create the full observation from the current state.""" + obs_parts = [part.from_state(state) for part in cls.parts] + return torch.cat(obs_parts).unsqueeze(0) diff --git a/rl_autoschedular/ppo.py b/rl_autoschedular/ppo.py old mode 100644 new mode 100755 index 9010d1e..7fc6e6e --- a/rl_autoschedular/ppo.py +++ b/rl_autoschedular/ppo.py @@ -1,354 +1,398 @@ -from statistics import mean -import torch -from rl_autoschedular.env import Env -from rl_autoschedular.model import HiearchyModel as Model -from rl_autoschedular.state import OperationState -from rl_autoschedular.trajectory import TrajectoryCollector, TrajectoryData -from rl_autoschedular.observation import Observation, NumLoops -from rl_autoschedular.actions import ActionSpace -from rl_autoschedular.benchmarks import Benchmarks -from rl_autoschedular.execution import Execution -from rl_autoschedular import device -from utils.config import Config -from utils.file_logger import FileLogger -from utils.log import print_error, print_info, print_success -from utils.dask_manager import DaskManager -from tqdm import trange -from time import time -from typing import Optional - - -def collect_trajectory(data: Benchmarks, model: Model, step: int): - """Collect a trajectory using the model and the environment. - - Args: - model (MyModel): The model to use. - env (Env): The environment to use. - step (int): The current step of the main loop - tmp_exec_data_file (str): The path to the temporary execution data file. - - Returns: - TrejectoryData: The collected trajectory. - """ - dm = DaskManager() - fl = FileLogger() - exe = Execution() - cfg = Config() - - eps = None - if 'epsilon' in cfg.exploration: - ratio = step / cfg.nb_iterations - final_eps = 0.001 - eps = final_eps + (cfg.init_epsilon - final_eps) * (1 - ratio) - - print_info(f"Trajectory collection using {dm.num_workers} workers...", end=' ') - traj_start = time() - - # Prepare benchmarks to explore - indices = torch.randperm(len(data))[:cfg.bench_count].long().tolist() - if len(indices) < cfg.bench_count: - indices = (indices * cfg.bench_count)[:cfg.bench_count] - envs: list[Env] = [] - states: list[OperationState] = [] - observations: list[torch.Tensor] = [] - tcs: list[TrajectoryCollector] = [] - for idx in indices: - env = Env() - state = env.reset(data, idx) - envs.append(env) - states.append(state) - observations.append(Observation.from_state(state)) - tcs.append(TrajectoryCollector()) - - while (active_states := [(i, s) for i, s in enumerate(states) if not s.terminal]): - # Sample states that are not terminal yet - obss = torch.cat([observations[i] for i, _ in active_states]) - actions_index, actions_bev_log_p, entropies = model.sample(obss.to(device), eps=eps) - fl['train/entropy'].extend(entropies.tolist()) - - # Record data and update states - for (i, state), obs, action_index, action_bev_log_p in zip(active_states, obss, actions_index, actions_bev_log_p): - obs = obs.unsqueeze(0) - - # Get action and use it to get next state - action = ActionSpace.action_by_index(action_index, state) - states[i] = next_state = envs[i].step(state, action) - observations[i] = next_obs = Observation.from_state(next_state) - - # If the benchmark is not done yet, keep next operation state instead - done = False - if next_state.terminal: - next_op_state = envs[i].get_next_op_state(next_state) - if next_op_state is not None: - states[i] = next_op_state - observations[i] = Observation.from_state(next_op_state) - else: - done = True - - # Record available data - tcs[i].append(( - Observation.get_part(obs, NumLoops).long().item(), - action_index.unsqueeze(0), - obs, - next_obs, - action_bev_log_p.item(), - 0.0, # This will be filled after execution - done - )) - - traj_end_sampling = time() - sampling_time_ms = int((traj_end_sampling - traj_start) * 1000) - - results = dm.map_states(__execute_states, states, training=True) - all_rewards, all_speedups, all_exec_times, cache_misses, worker_times = tuple(zip(*results)) - cache_miss_rate = mean(cache_misses) * 100 - sequential_time = sum(worker_times) - new_cache_data: dict[str, dict[str, int]] = {} - for tc, state, rewards, speedup, exec_time in zip(tcs, states, all_rewards, all_speedups, all_exec_times): - # Update trajectory - tc.rewards = rewards - # Log metrics - fl['train/reward'].extend(rewards) - fl['train/final_speedup'].append(speedup) - # Get new cache data - if exec_time is not None: - cache_key = exe.get_code_cache_key(state.transformation_history) - if state.bench_name not in new_cache_data: - new_cache_data[state.bench_name] = {} - new_cache_data[state.bench_name][cache_key] = exec_time - - tc = sum(tcs, TrajectoryCollector()) - exe.update_execution_cache(new_cache_data) - - traj_end = time() - exec_time_ms = int((traj_end - traj_end_sampling) * 1000) - distribted_speedup = sequential_time / exec_time_ms - time_ms = int((traj_end - traj_start) * 1000) - print_info( - ( - f"{time_ms}ms" - f", sampling: {sampling_time_ms}ms" - f", exec: {exec_time_ms}ms" - f", speedup: {distribted_speedup:.2f}x" - f", cache miss rate: {cache_miss_rate:.2f}%" - ), add_label=False - ) - - return tc.to_trajectory() - - -def ppo_update(trajectory: TrajectoryData, model: Model, optimizer: torch.optim.Optimizer): - """Update the model using PPO. - - Args: - trajectory (TrajectoryData): The trajectory to use. - model (Model): The model to update. - optimizer (torch.optim.Optimizer): The optimizer to use. - - Returns: - float: The average loss. - """ - fl = FileLogger() - cfg = Config() - - trajectory.update_attributes(model) - data_loader = trajectory.loader(cfg.ppo_batch_size, 1) - - ppo_trange = trange(cfg.ppo_epochs, desc='PPO Epochs') - for _ in ppo_trange: - for batch in data_loader: - batch: list[torch.Tensor] = [e.to(device, non_blocking=True) for e in batch] - ( - _, - actions_index, - obs, - _, - actions_bev_log_p, - _, _, - values, - _, - actions_old_log_p, - off_policy_rates, - returns, - advantages, - ) = batch - max_abs_adv = advantages.abs().max() - if cfg.normalize_adv == 'standard' and advantages.size(0) > 1: - advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) - elif cfg.normalize_adv == 'max-abs' and max_abs_adv > 0: - advantages = advantages / max_abs_adv - - with torch.enable_grad(): - actions_log_p, new_values, entropies = model(obs, actions_index) - - policy_loss, clip_frac = model.policy_model.loss(actions_log_p, actions_bev_log_p, off_policy_rates, advantages) - loss = policy_loss - - if cfg.value_epochs == 0: - value_loss = model.value_model.loss(new_values, values, returns) - loss += cfg.value_coef * value_loss - - if 'entropy' in cfg.exploration: - entropy_loss = -entropies.mean() - loss += cfg.entropy_coef * entropy_loss - - approx_kl = (actions_old_log_p - actions_log_p).pow(2).mean() / 2 - - optimizer.zero_grad() - try: - loss.backward() - clip_factor = torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) - optimizer.step() - except Exception as e: - print_error(f'Error during PPO update: {e}') - - # Logging - ppo_trange.set_postfix({ - 'loss': loss.item(), - 'policy_loss': policy_loss.item(), - 'value_loss': value_loss.item() if cfg.value_epochs == 0 else None - }) - fl['train_ppo/policy_loss'].append(policy_loss.item()) - fl['train_ppo/clip_frac'].append(clip_frac.item()) - fl['train_ppo/clip_factor'].append(clip_factor.item()) - fl['train_ppo/approx_kl'].append(approx_kl.item()) - if cfg.value_epochs == 0: - fl['train_ppo/value_loss'].append(value_loss.item()) - if 'entropy' in cfg.exploration: - fl['train_ppo/entropy_loss'].append(entropy_loss.item()) - - -def value_update(trajectory: TrajectoryData, model: Model, optimizer: torch.optim.Optimizer): - """Update the value model using the trajectory. - - Args: - trajectory (Trajectory): The trajectory to use. - model (Model): The model to update. - optimizer (torch.optim.Optimizer): The optimizer to use. - """ - fl = FileLogger() - cfg = Config() - - trajectory.update_attributes(model) - data_loader = trajectory.loader(cfg.value_batch_size, 1) - - value_trange = trange(cfg.value_epochs, desc='Value Epochs') - for _ in value_trange: - for batch in data_loader: - batch: list[torch.Tensor] = [e.to(device, non_blocking=True) for e in batch] - ( - _, _, - obs, - _, _, _, _, - values, - _, _, _, - returns, - _, - ) = batch - with torch.enable_grad(): - new_values = model.value_model(obs) - - loss = model.value_model.loss(new_values, values, returns) - - optimizer.zero_grad() - try: - loss.backward() - clip_factor = torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) - optimizer.step() - except Exception as e: - print_error(f'Error during Value update: {e}') - - # Logging - value_trange.set_postfix({'loss': loss.item()}) - fl['train_value/loss'].append(loss.item()) - fl['train_value/clip_factor'].append(clip_factor.item()) - - -def evaluate_benchmarks(model: Model, data: Benchmarks): - """Evaluate the benchmark using the model. - - Args: - model (Model): The model to use. - env (Env): The environment to use. - tmp_exec_data_file (str): The path to the temporary execution data file. - """ - dm = DaskManager() - fl = FileLogger() - exe = Execution() - - print_info("Evaluation started...") - eval_start = time() - - # Prepare benchmarks to explore - indices = range(len(data)) - envs: list[Env] = [] - states: list[OperationState] = [] - observations: list[torch.Tensor] = [] - for idx in indices: - env = Env() - state = env.reset(data, idx) - envs.append(env) - states.append(state) - observations.append(Observation.from_state(state)) - - while (active_states := [(i, s) for i, s in enumerate(states) if not s.terminal]): - # Sample states that are not terminal yet - obss = torch.cat([observations[i] for i, _ in active_states]) - actions_index, _, entropies = model.sample(obss.to(device), greedy=True) - fl['eval/entropy'].extend(entropies.tolist()) - - # Record data and update states - for (i, state), obs, action_index in zip(active_states, obss, actions_index): - obs = obs.unsqueeze(0) - - # Get action and use it to get next state - action = ActionSpace.action_by_index(action_index, state) - states[i] = next_state = envs[i].step(state, action) - observations[i] = Observation.from_state(next_state) - - # If the benchmark is not done yet, keep next operation state instead - if next_state.terminal: - next_op_state = envs[i].get_next_op_state(next_state) - if next_op_state is not None: - states[i] = next_op_state - observations[i] = Observation.from_state(next_op_state) - - results = dm.map_states(__execute_states, states, training=False) - all_rewards, all_speedups, all_exec_times, _, _ = tuple(zip(*results)) - new_cache_data: dict[str, dict[str, int]] = {} - for state, rewards, speedup, exec_time in zip(states, all_rewards, all_speedups, all_exec_times): - fl['eval/reward'].extend(rewards) - fl['eval/cumulative_reward'].append(sum(rewards)) - fl['eval/final_speedup'].append(speedup) - if exec_time is not None: - fl[f'eval/exec_time/{state.bench_name}'].append(exec_time) - fl[f'eval/speedup/{state.bench_name}'].append(speedup) - cache_key = exe.get_code_cache_key(state.transformation_history) - if state.bench_name not in new_cache_data: - new_cache_data[state.bench_name] = {} - new_cache_data[state.bench_name][cache_key] = exec_time - - print_success("Bench:", state.bench_name) - print_info(state.transformation_history) - - if len(all_speedups) > 0: - fl['eval/average_speedup'].append(sum(all_speedups) / len(all_speedups)) - exe.update_execution_cache(new_cache_data) - - eval_end = time() - time_ms = int((eval_end - eval_start) * 1000) - print_info(f"Evaluation time: {time_ms}ms") - - -def __execute_states(state: OperationState, exec_data_file: str, benchs: Benchmarks, main_exec_data: Optional[dict[str, dict[str, int]]]): - worker_start = time() - - Execution(exec_data_file, main_exec_data) - env = Env() - env.reset(benchs, state.bench_idx) - rewards, speedup, new_exec_time, cache_miss = env.apply_and_run_sequence(state.transformation_history) - - worker_end = time() - worker_time_ms = int((worker_end - worker_start) * 1000) - - return rewards, speedup, new_exec_time, cache_miss, worker_time_ms +import numpy as np +import torch +from rl_autoschedular.env import Env +from rl_autoschedular.model import HiearchyModel as Model +from rl_autoschedular.trajectory import TrajectoryCollector, TrajectoryData +from rl_autoschedular.observation import Observation, NumLoops +from rl_autoschedular.actions import ActionSpace +from rl_autoschedular import config as cfg +from rl_autoschedular import file_logger as fl , offline_data_collector +from rl_autoschedular import device +from utils.log import print_error +from tqdm import trange +import time + + +def collect_trajectory(model: Model, env: Env, step: int): + """Collect a trajectory using the model and the environment (on-policy for PPO). + + Args: + model (Model): The policy/value model. + env (Env): The environment. + step (int): Current training step. + + Returns: + TrajectoryData: The collected trajectory. + """ + tc = TrajectoryCollector() + + env_time = 0.0 # Time spent in environment steps + + eps = None + if 'epsilon' in cfg.exploration: # optional exploration schedule + ratio = step / cfg.nb_iterations + final_eps = 0.001 + eps = final_eps + (cfg.init_epsilon - final_eps) * (1 - ratio) + + # store rewards and entropies to log average for the model accross the benchmarks later + all_speedups = [] + all_entropies = [] + + for _ in trange(cfg.bench_count, desc='Trajectory'): + + t0 = time.perf_counter() + state = env.reset() + env_time += time.perf_counter() - t0 + bench_done = False + speedup = None + + # store rewards and entropies to log average for the current benchmark later + bench_rewards, bench_entropies = [], [] + + bench_name = state.bench_name + + + while not bench_done: + obs = Observation.from_state(state) + + # Sample action and log-prob from *current policy* + action_index, action_log_p, entropy = model.sample(obs.to(device), eps=eps) + + assert action_index.size(0) == 1 and action_log_p.size(0) == 1 + action = ActionSpace.action_by_index(action_index[0], state) + + # Step environment + t0 = time.perf_counter() + next_state, reward, op_done, speedup = env.step(state, action) + env_time += time.perf_counter() - t0 + next_obs = Observation.from_state(next_state) + + + if op_done: + t0 = time.perf_counter() + next_state, bench_done = env.get_next_op_state(next_state) + env_time += time.perf_counter() - t0 + + + tc.append(( + Observation.get_part(obs, NumLoops).long().item(), + action_index, + obs, + next_obs, + reward, + bench_done, + )) + + if cfg.collect_offline_data: + offline_data_collector.add_transition( + obs, + action_index, + next_obs, + reward, + bench_done + ) + + + # Accumulate metrics + bench_rewards.append(reward) + bench_entropies.append(entropy.item()) + state = next_state + + # === Per-benchmark logging === + mean_reward = float(np.mean(bench_rewards)) if bench_rewards else 0.0 + mean_entropy = float(np.mean(bench_entropies)) if bench_entropies else 0.0 + + all_speedups.append(speedup) + all_entropies.extend(bench_entropies) + + + bench_metrics = { + "mean_reward": mean_reward, + "mean_entropy": mean_entropy, + "final_speedup": speedup if speedup is not None else 0.0, + } + + fl.log_scalars(f"train/{bench_name}", bench_metrics, step) + + # === Global logging (across all benchmarks) === + if all_speedups: + fl.log_scalar("train/average_speedup", float(np.mean(all_speedups)), step) + if all_entropies: + fl.log_scalar("train/average_entropy", float(np.mean(all_entropies)), step) + + if cfg.collect_offline_data: + offline_data_collector.flush() + + return tc.to_trajectory() , env_time + + +def ppo_update(trajectory: TrajectoryData, model: Model, optimizer: torch.optim.Optimizer,step): + """Update the model using PPO (on-policy). + + Args: + trajectory (TrajectoryData): The trajectory to use. + model (Model): The model to update. + optimizer (torch.optim.Optimizer): The optimizer to use. + + Returns: + float: The average loss across updates. + """ + trajectory.update_attributes(model) + + avg_loss = 0.0 + total_steps = 0 + + + metrics_accum = { + "policy_loss": 0.0, + "clip_frac": 0.0, + "clip_factor": 0.0, + "approx_kl": 0.0, + "value_loss": 0.0, + "mean_entropy": 0.0, + } + metrics_count = 0 + + ppo_trange = trange(cfg.ppo_epochs, desc='PPO Epochs') + for _ in ppo_trange: + for batch in trajectory.loader(cfg.ppo_batch_size, shuffle=True): + + batch = [e.to(device, non_blocking=True) for e in batch] + ( + _, + actions_index, + obs, + _, + _, + _, + values, + _, + actions_old_log_p, + returns, + advantages, + ) = batch + + # Normalize advantages if configured + max_abs_adv = advantages.abs().max() + if cfg.normalize_adv == 'standard' and advantages.size(0) > 1: + advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) + elif cfg.normalize_adv == 'max-abs' and max_abs_adv > 0: + advantages = advantages / max_abs_adv + + with torch.enable_grad(): + # Forward pass through policy and value + actions_log_p, new_values, entropy = model(obs, actions_index) + + # PPO policy loss (clipped surrogate) + policy_loss, clip_frac = model.policy_model.loss( + actions_log_p, actions_old_log_p, advantages + ) + loss = policy_loss + + # Value loss + if cfg.value_epochs == 0: # update value jointly + value_loss = model.value_model.loss(new_values, values, returns) + loss += cfg.value_coef * value_loss + + # Entropy bonus + if 'entropy' in cfg.exploration: + entropy_loss = -entropy.mean() + loss += cfg.entropy_coef * entropy_loss + + # KL estimate + approx_kl = (actions_old_log_p - actions_log_p).pow(2).mean() / 2 + + # Gradient step + optimizer.zero_grad() + try: + loss.backward() + clip_factor = torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) + optimizer.step() + except Exception as e: + print_error(f'Error during PPO update: {e}') + continue + + # Logging + avg_loss += loss.item() * advantages.size(0) + total_steps += advantages.size(0) + + ppo_trange.set_postfix({ + 'loss': loss.item(), + 'policy_loss': policy_loss.item(), + 'value_loss': value_loss.item() if cfg.value_epochs == 0 else None + }) + + # Accumulate weighted averages + + bs = advantages.size(0) + avg_loss += loss.item() * bs + total_steps += bs + + metrics_accum["policy_loss"] += policy_loss.item() * bs + metrics_accum["clip_frac"] += clip_frac.item() * bs + metrics_accum["clip_factor"] += clip_factor.item() * bs + metrics_accum["approx_kl"] += approx_kl.item() * bs + if cfg.value_epochs == 0: + metrics_accum["value_loss"] += value_loss.item() * bs + if 'entropy' in cfg.exploration: + metrics_accum["mean_entropy"] -= entropy_loss.item() * bs + + metrics_count += bs + + # final averaging + final_metrics = {k: (v / metrics_count) for k, v in metrics_accum.items() if v != 0} + fl.log_scalars("PPO_Training", final_metrics, step) + + + + +def value_update(trajectory: TrajectoryData, model: Model, optimizer: torch.optim.Optimizer,step): + """Update the value model using the trajectory. + + Args: + trajectory (Trajectory): The trajectory to use. + model (Model): The model to update. + optimizer (torch.optim.Optimizer): The optimizer to use. + """ + trajectory.update_attributes(model) + + metrics_accum = { + "loss": 0.0, + "clip_factor": 0.0, + } + total_steps = 0 + + + value_trange = trange(cfg.value_epochs, desc='Value Epochs') + for _ in value_trange: + for batch in trajectory.loader(cfg.value_batch_size, shuffle=True): + + batch = [e.to(device, non_blocking=True) for e in batch] + ( + _, _, + obs, + _, _, # next_obs, rewards + _, # done + values, + _, _, # next_values, actions_old_log_p + returns, + _, # advantages + ) = batch + with torch.enable_grad(): + new_values = model.value_model(obs) + + loss = model.value_model.loss(new_values, values, returns) + + optimizer.zero_grad() + try: + loss.backward() + clip_factor = torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) + optimizer.step() + except Exception as e: + print_error(f'Error during Value update: {e}') + + # Accumulate weighted averages + bs = obs.size(0) + metrics_accum["loss"] += loss.item() * bs + metrics_accum["clip_factor"] += clip_factor.item() * bs + total_steps += bs + # === Final averaging === + if total_steps > 0: + final_metrics = {k: v / total_steps for k, v in metrics_accum.items()} + fl.log_scalars("Value_Training", final_metrics, step) + +@torch.no_grad() +def evaluate_benchmarks(model: Model, env: Env, step: int): + """Evaluta a the model on the evaluation environment. + + Args: + model (Model): The policy/value model. + env (Env): The environment. + step (int): Current training step. + + Returns: + env_time (float): Time spent in environment steps. + """ + + + env_time = 0.0 # Time spent in environment steps + + eps = None + + + # store rewards and entropies to log average for the model accross the benchmarks later + all_speedups = [] + all_entropies = [] + + + for _ in trange(cfg.bench_count, desc='Trajectory'): + + t0 = time.perf_counter() + state = env.reset() + env_time += time.perf_counter() - t0 + bench_done = False + speedup = None + + # store rewards and entropies to log average for the current benchmark later + bench_rewards, bench_entropies = [], [] + + bench_name = state.bench_name + + + while not bench_done: + obs = Observation.from_state(state) + + # Sample action and log-prob from *current policy* + action_index, action_log_p, entropy = model.sample(obs.to(device), eps=eps) + assert action_index.size(0) == 1 and action_log_p.size(0) == 1 + action = ActionSpace.action_by_index(action_index[0], state) + + # Step environment + t0 = time.perf_counter() + next_state, reward, op_done, speedup = env.step(state, action) + env_time += time.perf_counter() - t0 + next_obs = Observation.from_state(next_state) + + + if op_done: + t0 = time.perf_counter() + next_state, bench_done = env.get_next_op_state(next_state) + env_time += time.perf_counter() - t0 + + + # Accumulate metrics + bench_rewards.append(reward) + bench_entropies.append(entropy.item()) + state = next_state + + # === Per-benchmark logging === + mean_reward = float(np.mean(bench_rewards)) if bench_rewards else 0.0 + mean_entropy = float(np.mean(bench_entropies)) if bench_entropies else 0.0 + + all_speedups.append(speedup) + all_entropies.extend(bench_entropies) + + + bench_metrics = { + "mean_reward": mean_reward, + "mean_entropy": mean_entropy, + "final_speedup": speedup if speedup is not None else 0.0, + } + + fl.log_scalars(f"eval/{bench_name}", bench_metrics, step) + + print( + f"\033[92m\n- Eval Bench: {bench_name}\n" + f"- Mean Reward: {mean_reward:.4f}\n" + f"- Mean Entropy: {mean_entropy:.4f}\n" + f"- Final Speedup: {speedup if speedup is not None else 0.0:.4f}\033[0m" + ) + + + # === Global logging (across all benchmarks) === + if all_speedups: + fl.log_scalar("eval/average_speedup", float(np.mean(all_speedups)), step) + if all_entropies: + fl.log_scalar("eval/average_entropy", float(np.mean(all_entropies)), step) + + return env_time \ No newline at end of file diff --git a/rl_autoschedular/state.py b/rl_autoschedular/state.py old mode 100644 new mode 100755 index a730a61..527fc44 --- a/rl_autoschedular/state.py +++ b/rl_autoschedular/state.py @@ -1,322 +1,307 @@ -from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional -from enum import Enum -import re -import os -import subprocess - -from utils.config import Config -from utils.log import print_error - -if TYPE_CHECKING: - from rl_autoschedular.actions.base import Action - - -class OperationType(Enum): - Generic = 'generic' - Matmul = 'matmul' - Conv = 'conv' - Pooling = 'pooling' - Add = 'add' - - unknown = '' - - -class IteratorType(Enum): - Parallel = 'parallel' - Reduction = 'reduction' - - -@dataclass -class NestedLoopFeatures: - """Dataclass to store the nested loops features data.""" - arg: str - """The argument representing the loop iterator.""" - lower_bound: int - """The lower bound of the loop.""" - upper_bound: int - """The upper bound of the loop.""" - step: int - """The loop step.""" - iterator_type: IteratorType - """The type of the loop iterator.""" - - def copy(self): - """Copy the current NestedLoopFeatures object.""" - return NestedLoopFeatures(self.arg, self.lower_bound, self.upper_bound, self.step, self.iterator_type) - - -@dataclass -class OperationFeatures: - """Dataclass to store the operation features data.""" - operation_name: str - """The name of the mlir operation.""" - operation_type: OperationType - """The type of the operation.""" - op_count: dict[str, int] - """Number of arithmetic operations in the operation.""" - load_data: list[list[str]] - """List of load accesses where each load is represented by the list of access arguments.""" - store_data: list[str] - """List of store accesses where each store is represented by the list of access arguments.""" - nested_loops: list[NestedLoopFeatures] - """List of nested loops where each loop is represented by the NestedLoopFeatures dataclass.""" - producers: list[str] - """List of tags of operations that are consumed by the current operation""" - vectorizable: bool - """Flag to indicate if the operation is vectorizable.""" - - def copy(self): - """Copy the current OperationFeatures object.""" - return OperationFeatures( - self.operation_name, - self.operation_type, - self.op_count.copy(), - [load.copy() for load in self.load_data], - self.store_data.copy(), - [loop.copy() for loop in self.nested_loops], - self.producers.copy(), - self.vectorizable - ) - - -@dataclass -class BenchmarkFeatures: - """Dataclass to store the benchmark features data.""" - bench_name: str - """The benchmark's name.""" - code: str - """The MLIR code of the benchmark.""" - operation_tags: list[str] - """List of operation tags.""" - operations: dict[str, OperationFeatures] - """List of operations where each operation is represented by the OperationFeatures dataclass.""" - root_exec_time: int - """Execution time of the benchmark in nanoseconds without any transformation.""" - - def copy(self): - """Copy the current BenchmarkFeatures object.""" - return BenchmarkFeatures( - self.bench_name, - self.code, - self.operation_tags.copy(), - {tag: op.copy() for tag, op in self.operations.items()}, - self.root_exec_time - ) - - -@dataclass -class OperationState: - bench_idx: int - """The benchmark's index.""" - bench_name: str - """The benchmark's name.""" - operation_tag: str - """Tag used to identify the operation in the MLIR code.""" - original_operation_features: OperationFeatures - """Features of the operation that will be kept always unchanged.""" - operation_features: OperationFeatures - """Features of the operation.""" - producer_tag: Optional[str] - """Tag that identifies the selected producer""" - producer_features: Optional[OperationFeatures] - """Features of the selected producer""" - step_count: int - """The current step in the list of transformations applied to the operation.""" - transformation_history: list[list['Action']] - """List of transformations with their parameters applied to the operation.""" - terminal: bool - """Flag that determines if the state is terminal""" - - def copy(self): - """Copy the current OperationState object.""" - return OperationState( - self.bench_idx, - self.bench_name, - self.operation_tag, - self.original_operation_features.copy(), - self.operation_features.copy(), - self.producer_tag, - self.producer_features.copy() if self.producer_features is not None else None, - self.step_count, - [seq.copy() for seq in self.transformation_history], - self.terminal - ) - - -def extract_bench_features_from_code(bench_name: str, code: str, root_execution_time: int): - """Extract benchmark features from the given code. - - Args: - bench_name (str): the benchmark name - code (str): the code to extract features from - root_execution_time (int): the root execution time - execution_time (int): the execution time - - Returns: - BenchmarkFeatures: the extracted benchmark features - """ - result = subprocess.run( - f'{os.getenv("AST_DUMPER_BIN_PATH")} -', - shell=True, - input=code.encode('utf-8'), - stdout=subprocess.PIPE, - stderr=subprocess.PIPE - ) - raw_ast_info = result.stdout.decode('utf-8') - - return __extract_bench_features_from_ast_result(bench_name, raw_ast_info, root_execution_time) - - -def extract_bench_features_from_file(bench_name: str, file_path: str, root_execution_time: int): - """Extract benchmark features from the code in the file. - - Args: - bench_name (str): the benchmark name - file_path (str): the file path - root_execution_time (int): the root execution time - execution_time (int): the execution time - - Returns: - BenchmarkFeatures: the extracted benchmark features - """ - result = subprocess.run( - f'{os.getenv("AST_DUMPER_BIN_PATH")} {file_path}', - shell=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE - ) - raw_ast_info = result.stdout.decode('utf-8') - - return __extract_bench_features_from_ast_result(bench_name, raw_ast_info, root_execution_time) - - -def __extract_bench_features_from_ast_result(bench_name: str, raw_ast_info: str, root_execution_time: int): - """Extracts benchmark features from the code's AST result and execution time. - - Args: - bench_name (str): the benchmark name - raw_ast_info (str): the raw AST information - root_execution_time (int): the root execution time - execution_time (int): the execution time - - Returns: - BenchmarkFeatures: extracted benchmark features - """ - cfg = Config() - - info, full_code = raw_ast_info.split("########################################") - operations_lines, graph_str = info.split('#BEGIN_GRAPH') - - operations_blocks = operations_lines.split('#START_OPERATION') - operations_blocks = [block.strip() for block in operations_blocks if block] - - ops_tags = [] - operations: dict[str, OperationFeatures] = {} - for operation_block in operations_blocks: - rest, operation_tag = operation_block.split("#START_TAG") - operation_tag = operation_tag.strip().split("\n")[0] - log_info = f"- Bench: {bench_name} - Operation: {operation_tag}" - - operation_name, rest = rest.split("#START_VECTORIZABLE") - operation_type = __get_operation_type(operation_name) - if operation_type is None: - print_error(log_info) - print_error("Unsupported operation type:", operation_name) - continue - - nested_loops = [] - op_count = {} - load_data: list[list[str]] = [] - store_data: list[str] = [] - - vectorizable_str, rest = rest.split("#START_NESTED_LOOPS") - assert vectorizable_str.strip() in ["true", "false"], f"Vectorizable string is not valid: {vectorizable_str}" - vectorizable = vectorizable_str.strip() == "true" - - nested_loops_str, rest = rest.split("#START_LOAD_DATA") - for nested_loop_str in nested_loops_str.strip().split("\n"): - if not nested_loop_str: - continue - arg, low, high, step, iter = nested_loop_str.strip().split(" ") - nested_loops.append(NestedLoopFeatures( - arg=f'%{arg}', - lower_bound=int(low), - upper_bound=int(high), - step=int(step), - iterator_type=IteratorType(iter) - )) - if len(nested_loops) > cfg.max_num_loops: - print_error(log_info) - print_error(f"Number of loops {len(nested_loops)} is not supported") - continue - - loads_data_str, rest = rest.split("#START_STORE_DATA") - loads_data_str = re.sub(r'd\d+', lambda m: f'%{m.group()}', loads_data_str) - for load_data_str in loads_data_str.strip().split("\n"): - if not load_data_str: - continue - load_data.append(load_data_str.split(", ")) - if any(len(load) > cfg.max_num_load_store_dim for load in load_data): - print_error(log_info) - print_error(f"Number of load dims {len(load_data[-1])} is not supported") - continue - if len(load_data) > cfg.max_num_stores_loads: - # We ignore this overflow, because there are many cases with a huge number of loads - load_data = load_data[:cfg.max_num_stores_loads] - - store_data_str, ops_count_str = rest.split("#START_OP_COUNT") - store_data_str = re.sub(r'd\d+', lambda m: f'%{m.group()}', store_data_str) - store_data_list = store_data_str.strip().split("\n") - assert len(store_data_list) == 1, f"Store data list is not of length 1: {store_data_list}" - store_data = store_data_list[0].split(", ") - if len(store_data) > cfg.max_num_load_store_dim: - print_error(log_info) - print_error(f"Number of store dims {len(store_data)} is not supported") - continue - - for op_count_str in ops_count_str.strip().split("\n"): - op, count = op_count_str.strip().split(" ") - op_count[op] = int(count) - - ops_tags.append(operation_tag) - operations[operation_tag] = OperationFeatures( - operation_name=operation_name, - operation_type=operation_type, - op_count=op_count, - load_data=load_data, - store_data=store_data, - nested_loops=nested_loops, - producers=[], - vectorizable=vectorizable - ) - - # Extracte Producer/Consumer features - graph_str = graph_str.replace("#END_GRAPH", "") - graph_lines = [(line.split(' --> ')[0], line.split(' --> ')[1]) for line in graph_str.strip().split("\n") if line] - - for producer, consumer in graph_lines: - operations[consumer].producers.append(producer) - - return BenchmarkFeatures( - bench_name=bench_name, - code=full_code, - operation_tags=ops_tags, - operations=operations, - root_exec_time=root_execution_time, - ) - - -def __get_operation_type(operation_name: str) -> Optional[OperationType]: - """Get the operation type from the operation name. - - Args: - operation_name (str): The operation name. - - Returns: - Optional[OperationType]: The operation type or None if not found. - """ - for operation_type in OperationType: - if operation_type.value and operation_type.value in operation_name: - return operation_type - return OperationType.unknown +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional +from enum import Enum +from rl_autoschedular import config as cfg +import re +import os +import subprocess + +from utils.log import print_error + +if TYPE_CHECKING: + from rl_autoschedular.actions.base import Action + + +class OperationType(Enum): + Generic = 'generic' + Matmul = 'matmul' + Conv2D = 'conv_2d' + Pooling = 'pooling' + Add = 'add' + + +class IteratorType(Enum): + Parallel = 'parallel' + Reduction = 'reduction' + + +@dataclass +class NestedLoopFeatures: + """Dataclass to store the nested loops features data.""" + arg: str + """The argument representing the loop iterator.""" + lower_bound: int + """The lower bound of the loop.""" + upper_bound: int + """The upper bound of the loop.""" + step: int + """The loop step.""" + iterator_type: IteratorType + """The type of the loop iterator.""" + + def copy(self): + """Copy the current NestedLoopFeatures object.""" + return NestedLoopFeatures(self.arg, self.lower_bound, self.upper_bound, self.step, self.iterator_type) + + +@dataclass +class OperationFeatures: + """Dataclass to store the operation features data.""" + raw_operation: str + """The raw operation string without wrapping or transformations.""" + operation_type: OperationType + """The type of the operation (generic, matmul, conv2d, ...).""" + op_count: dict[str, int] + """Number of arithmetic operations in the operation.""" + load_data: list[list[str]] + """List of load accesses where each load is represented by the list of access arguments.""" + store_data: list[str] + """List of store accesses where each store is represented by the list of access arguments.""" + nested_loops: list[NestedLoopFeatures] + """List of nested loops where each loop is represented by the NestedLoopFeatures dataclass.""" + vectorizable: bool + """Flag to indicate if the operation is vectorizable.""" + + def copy(self): + """Copy the current OperationFeatures object.""" + return OperationFeatures( + self.raw_operation, + self.operation_type, + self.op_count.copy(), + [load.copy() for load in self.load_data], + self.store_data.copy(), + [loop.copy() for loop in self.nested_loops], + self.vectorizable + ) + + +@dataclass +class BenchmarkFeatures: + """Dataclass to store the benchmark features data.""" + bench_name: str + """The benchmark's name.""" + code: str + """The MLIR code of the benchmark.""" + operation_tags: list[str] + """List of operation tags.""" + operations: dict[str, OperationFeatures] + """List of operations where each operation is represented by the OperationFeatures dataclass.""" + root_exec_time: int + """Execution time of the benchmark in nanoseconds without any transformation.""" + + def copy(self): + """Copy the current BenchmarkFeatures object.""" + return BenchmarkFeatures( + self.bench_name, + self.code, + self.operation_tags.copy(), + {tag: op.copy() for tag, op in self.operations.items()}, + self.root_exec_time + ) + + +@dataclass +class OperationState: + bench_name: str + """The benchmark's name.""" + operation_tag: str + """Tag used to identify the operation in the MLIR code.""" + operation_features: OperationFeatures + """Features of the operation.""" + validated_code: str + """The latest validated benchmark code (if not in inference, this will always be the original code).""" + transformed_code: str + """The operation string with wrapping and transformations.""" + step_count: int + """The current step in the list of transformations applied to the operation.""" + exec_time: int + """Execution time of the operation in nanoseconds.""" + transformation_history: list[list['Action']] + """List of transformations with their parameters applied to the operation.""" + tmp_file: str + """Temporary file to store the MLIR code.""" + terminal: bool + """Flag that determines if the state is terminal""" + + def copy(self): + """Copy the current OperationState object.""" + return OperationState( + self.bench_name, + self.operation_tag, + self.operation_features.copy(), + self.validated_code, + self.transformed_code, + self.step_count, + self.exec_time, + [seq.copy() for seq in self.transformation_history], + self.tmp_file, + self.terminal + ) + + +def extract_bench_features_from_code(bench_name: str, code: str, root_execution_time: int): + """Extract benchmark features from the given code. + + Args: + bench_name (str): the benchmark name + code (str): the code to extract features from + root_execution_time (int): the root execution time + execution_time (int): the execution time + + Returns: + BenchmarkFeatures: the extracted benchmark features + """ + result = subprocess.run( + f'{os.getenv("AST_DUMPER_BIN_PATH")} -', + shell=True, + input=code.encode('utf-8'), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE + ) + raw_ast_info = result.stdout.decode('utf-8') + + return __extract_bench_features_from_ast_result(bench_name, raw_ast_info, root_execution_time) + + +def extract_bench_features_from_file(bench_name: str, file_path: str, root_execution_time: int): + """Extract benchmark features from the code in the file. + + Args: + bench_name (str): the benchmark name + file_path (str): the file path + root_execution_time (int): the root execution time + execution_time (int): the execution time + + Returns: + BenchmarkFeatures: the extracted benchmark features + """ + result = subprocess.run( + f'{os.getenv("AST_DUMPER_BIN_PATH")} {file_path}', + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE + ) + raw_ast_info = result.stdout.decode('utf-8') + + return __extract_bench_features_from_ast_result(bench_name, raw_ast_info, root_execution_time) + + +def __extract_bench_features_from_ast_result(bench_name: str, raw_ast_info: str, root_execution_time: int): + """Extracts benchmark features from the code's AST result and execution time. + + Args: + bench_name (str): the benchmark name + raw_ast_info (str): the raw AST information + root_execution_time (int): the root execution time + execution_time (int): the execution time + + Returns: + BenchmarkFeatures: extracted benchmark features + """ + info, full_code = raw_ast_info.split("########################################") + operations_lines, _ = info.split('#BEGIN_GRAPH') + + operations_blocks = operations_lines.split('#START_OPERATION') + operations_blocks = [block.strip() for block in operations_blocks if block] + + ops_tags = [] + operations = {} + for operation_block in operations_blocks: + rest, operation_tag = operation_block.split("#START_TAG") + operation_tag = operation_tag.strip().split("\n")[0] + log_info = f"- Bench: {bench_name} - Operation: {operation_tag}" + + raw_operation, rest = rest.split("#START_VECTORIZABLE") + operation_type = __get_operation_type(raw_operation) + if operation_type is None: + print_error(log_info) + print_error("Unsupported operation type:", raw_operation.split("\n")[0]) + continue + + nested_loops = [] + op_count = {} + load_data: list[list[str]] = [] + store_data: list[str] = [] + + vectorizable_str, rest = rest.split("#START_NESTED_LOOPS") + assert vectorizable_str.strip() in ["true", "false"], f"Vectorizable string is not valid: {vectorizable_str}" + vectorizable = vectorizable_str.strip() == "true" + + nested_loops_str, rest = rest.split("#START_LOAD_DATA") + for nested_loop_str in nested_loops_str.strip().split("\n"): + if not nested_loop_str: + continue + arg, low, high, step, iter = nested_loop_str.strip().split(" ") + nested_loops.append(NestedLoopFeatures( + arg=f'%{arg}', + lower_bound=int(low), + upper_bound=int(high), + step=int(step), + iterator_type=IteratorType(iter) + )) + if len(nested_loops) > cfg.max_num_loops: + print_error(log_info) + print_error(f"Number of loops {len(nested_loops)} is not supported") + continue + + loads_data_str, rest = rest.split("#START_STORE_DATA") + loads_data_str = re.sub(r'd\d+', lambda m: f'%{m.group()}', loads_data_str) + for load_data_str in loads_data_str.strip().split("\n"): + if not load_data_str: + continue + load_data.append(load_data_str.split(", ")) + if any(len(load) > cfg.max_num_load_store_dim for load in load_data): + print_error(log_info) + print_error(f"Number of load dims {len(load_data[-1])} is not supported") + continue + if len(load_data) > cfg.max_num_stores_loads: + # We ignore this overflow, because there are many cases with a huge number of loads + load_data = load_data[:cfg.max_num_stores_loads] + + store_data_str, ops_count_str = rest.split("#START_OP_COUNT") + store_data_str = re.sub(r'd\d+', lambda m: f'%{m.group()}', store_data_str) + store_data_list = store_data_str.strip().split("\n") + assert len(store_data_list) == 1, f"Store data list is not of length 1: {store_data_list}" + store_data = store_data_list[0].split(", ") + if len(store_data) > cfg.max_num_load_store_dim: + print_error(log_info) + print_error(f"Number of store dims {len(store_data)} is not supported") + continue + + for op_count_str in ops_count_str.strip().split("\n"): + op, count = op_count_str.strip().split(" ") + op_count[op] = int(count) + + ops_tags.append(operation_tag) + operations[operation_tag] = OperationFeatures( + raw_operation=raw_operation, + operation_type=operation_type, + op_count=op_count, + load_data=load_data, + store_data=store_data, + nested_loops=nested_loops, + vectorizable=vectorizable + ) + + return BenchmarkFeatures( + bench_name=bench_name, + code=full_code, + operation_tags=ops_tags, + operations=operations, + root_exec_time=root_execution_time, + ) + + +def __get_operation_type(raw_operation: str) -> Optional[OperationType]: + """Get the operation type from the raw operation string. + + Args: + raw_operation (str): The raw operation string. + + Returns: + Optional[OperationType]: The operation type or None if not found. + """ + for operation_type in OperationType: + if f'linalg.{operation_type.value}' in raw_operation: + return operation_type + return None diff --git a/rl_autoschedular/trajectory.py b/rl_autoschedular/trajectory.py old mode 100644 new mode 100755 index 50e61be..4a049c6 --- a/rl_autoschedular/trajectory.py +++ b/rl_autoschedular/trajectory.py @@ -1,403 +1,187 @@ -import torch -from torch.utils.data import Dataset, DataLoader, Sampler, RandomSampler -from typing import Iterator -from rl_autoschedular import device -from rl_autoschedular.model import HiearchyModel as Model -from time import time -from utils.config import Config -from utils.log import print_info - - -T_timestep = tuple[ - int, # num_loops - torch.Tensor, # action_index - torch.Tensor, # obs - torch.Tensor, # next_obs - float, # action_bev_log_p - float, # reward - bool, # done -] - -DYNAMIC_ATTRS = ['values', 'next_values', 'actions_old_log_p', 'off_policy_rates', 'returns', 'advantages'] - - -class TopKAdvantageSampler(Sampler[int]): - """ - A Sampler that yields a random permutation of the indices - corresponding to the top-K highest advantage values in a trajectory. - - Args: - data_source (TrajectoryData): The dataset, which must have a `get_all_advantages` method. - num_samples (int): The maximum number of samples to take batches from. - """ - def __init__(self, data_source: 'TrajectoryData', num_samples: int): - self.data_source = data_source - self.num_samples = num_samples - - # Get all advantage values from the dataset - advantages = self.data_source.advantages - - # Ensure we don't request more samples than available - self.num_samples = min(self.num_samples, advantages.size(0)) - - _, self.top_k_indices = torch.topk(advantages.abs(), k=self.num_samples) - - def __iter__(self) -> Iterator[int]: - """ - Returns an iterator over shuffled indices of the top-k samples. - This is called by the DataLoader at the start of each epoch. - """ - # Shuffle the top-k indices to ensure random order - shuffled_indices = self.top_k_indices[torch.randperm(self.num_samples)] - - # Yield the indices one by one - yield from shuffled_indices.tolist() - - def __len__(self) -> int: - """The total number of samples to be drawn.""" - return self.num_samples - - -class TrajectoryData(Dataset): - """Dataset to store the trajectory data. - - Args: - num_loops (torch.Tensor): Number of loops in the trajectory. - actions_index (torch.Tensor): Actions indices in the trajectory. - obs (torch.Tensor): Observations in the trajectory. - next_obs (torch.Tensor): Observations of next states in the trajectory. - actions_bev_log_p (torch.Tensor): Action log probabilities following behavioral policy. - rewards (torch.Tensor): Rewards in the trajectory. - done (torch.Tensor): Done flags in the trajectory. - """ - sizes: list[int] - """Sizes of all the included trajectories""" - - num_loops: torch.Tensor - """Number of loops in the trajectory.""" - actions_index: torch.Tensor - """Actions in the trajectory.""" - obs: torch.Tensor - """Observations in the trajectory""" - next_obs: torch.Tensor - """Observations of next states in the trajectory.""" - actions_bev_log_p: torch.Tensor - """Action log probabilities following behavioral policy in the trajectory.""" - rewards: torch.Tensor - """Rewards in the trajectory.""" - done: torch.Tensor - """Done flags in the trajectory.""" - - values: torch.Tensor - """Values of actions in the trajectory.""" - next_values: torch.Tensor - """Values of actions in the trajectory with one additional step (shifted to one step in the future).""" - actions_old_log_p: torch.Tensor - """Action log probabilities following old policy in the trajectory.""" - off_policy_rates: torch.Tensor - """Off-policy rates (rho) for the current policy.""" - returns: torch.Tensor - """Returns in the trajectory.""" - advantages: torch.Tensor - """Advantages in the trajectory.""" - - def __init__( - self, - num_loops: torch.Tensor, - actions_index: torch.Tensor, - obs: torch.Tensor, - next_obs: torch.Tensor, - actions_bev_log_p: torch.Tensor, - rewards: torch.Tensor, - done: torch.Tensor - ): - self.num_loops = num_loops - self.actions_index = actions_index - self.obs = obs - self.next_obs = next_obs - self.actions_bev_log_p = actions_bev_log_p - self.rewards = rewards - self.done = done - - self.sizes = [len(self)] - - def __len__(self) -> int: - """Get the length of the trajectory. - - Returns: - int: The length of the trajectory. - """ - return self.obs.size(0) - - def __getitem__(self, idx: int): - """Get a single timestep from the trajectory. - - Args: - idx (int): Index of the timestep to retrieve. - - Returns: - tuple: A tuple containing the timestep data. - """ - return ( - self.num_loops[idx], - self.actions_index[idx], - self.obs[idx], - self.next_obs[idx], - self.actions_bev_log_p[idx], - self.rewards[idx], - self.done[idx], - - self.values[idx], - self.next_values[idx], - self.actions_old_log_p[idx], - self.off_policy_rates[idx], - self.returns[idx], - self.advantages[idx], - ) - - def __add__(self, other: 'TrajectoryData'): - """Concatenate this trajectory with another. - - Args: - other (TrajectoryData): The other trajectory to concatenate with - - Returns: - TrajectoryData: The trajectory containing both - """ - self_other_sizes = self.sizes + other.sizes - - # Truncate to 10 trajectories - self_other_sizes = self_other_sizes[-Config().replay_count:] - start = - sum(self_other_sizes) - assert len(self_other_sizes) <= Config().replay_count - - self_other = TrajectoryData( - torch.cat((self.num_loops, other.num_loops))[start:], - torch.cat((self.actions_index, other.actions_index))[start:], - torch.cat((self.obs, other.obs))[start:], - torch.cat((self.next_obs, other.next_obs))[start:], - torch.cat((self.actions_bev_log_p, other.actions_bev_log_p))[start:], - torch.cat((self.rewards, other.rewards))[start:], - torch.cat((self.done, other.done))[start:], - ) - for attr in DYNAMIC_ATTRS: - if hasattr(self, attr) and hasattr(other, attr): - self_val = getattr(self, attr) - other_val = getattr(other, attr) - assert isinstance(self_val, torch.Tensor) and isinstance(other_val, torch.Tensor) - setattr(self_other, attr, torch.cat(self_val, other_val)[start:]) - - self_other.sizes = self_other_sizes - - assert len(self_other) == sum(self_other_sizes) - - return self_other - - def loader(self, batch_size: int, num_trajectories: int): - """Create a DataLoader for the trajectory. - - Args: - batch_size (int): Batch size for the DataLoader. - num_samples (int): Maximum number of samples to take batches from. - - Returns: - DataLoader: The DataLoader for the trajectory. - """ - num_samples = sum(self.sizes[-num_trajectories:]) - match Config().reuse_experience: - case 'topk': - sampler = TopKAdvantageSampler(self, num_samples) - case 'random': - sampler = RandomSampler(self, num_samples=num_samples) - case 'none': - sampler = None - - return DataLoader( - self, - batch_size=batch_size, - shuffle=sampler is None, - sampler=sampler, - pin_memory=device.type != 'cpu', - ) - - def copy(self) -> 'TrajectoryData': - """Copy the trajectory. - - Returns: - TrajectoryData: The copied trajectory. - """ - self_copy = TrajectoryData( - num_loops=self.num_loops.clone(), - actions_index=self.actions_index.clone(), - obs=self.obs.clone(), - next_obs=self.next_obs.clone(), - actions_bev_log_p=self.actions_bev_log_p.clone(), - rewards=self.rewards.clone(), - done=self.done.clone(), - ) - for attr in DYNAMIC_ATTRS: - if hasattr(self, attr): - attr_val = getattr(self, attr) - assert isinstance(attr_val, torch.Tensor) - setattr(self_copy, attr, attr_val.clone()) - - self_copy.sizes = self.sizes.copy() - - return self_copy - - def update_attributes(self, model: Model): - """Update the attributes of the trajectory following the new model. - - Args: - model (Model): The model to use for updating the attributes. - """ - start = time() - actions_old_log_p, values, _ = model(self.obs.to(device), self.actions_index) - next_values = model.value_model(self.next_obs.to(device)) - - self.actions_old_log_p, self.values, self.next_values = actions_old_log_p.cpu(), values.cpu(), next_values.cpu() - - self.__compute_rho() - self.__compute_returns() - self.__compute_gae() - end = time() - time_ms = int((end - start) * 1000) - print_info(f"Updated {len(self)} attributes in {time_ms}ms") - - def __compute_rho(self) -> torch.Tensor: - """Compute the off-policy rate (rho) for the current policy. - - Returns: - torch.Tensor: The off-policy rate. - """ - if 'epsilon' not in Config().exploration and Config().reuse_experience == 'none': - self.off_policy_rates = torch.ones_like(self.actions_bev_log_p) - return - - self.off_policy_rates = torch.exp(torch.clamp(self.actions_old_log_p - self.actions_bev_log_p, -80.0, 80.0)) - - def __compute_returns(self, gamma: float = 0.99) -> torch.Tensor: - """Compute the returns. - - Args: - done (torch.Tensor): done flags. - rewards (torch.Tensor): rewards. - gamma (float): discount factor. Defaults to 1. - - Returns: - torch.Tensor: returns. - """ - self.returns = torch.zeros(len(self), dtype=torch.float32) - last_return = 0 - - for t in reversed(range(len(self))): - mask = ~self.done[t] - last_return = last_return * mask - - last_return = self.values[t] + (self.rewards[t] + gamma * last_return - self.values[t]) * self.off_policy_rates[t].clamp_max(1) - - self.returns[t] = last_return - - def __compute_gae(self, gamma: float = 0.99, lambda_: float = 0.95) -> torch.Tensor: - """Compute the Generalized Advantage Estimation. - - Args: - gamma (float): discount factor. - lambda_ (float): GAE factor. - - Returns: - torch.Tensor: advantages. - torch.Tensor: returns. - """ - self.advantages = torch.zeros(len(self), dtype=torch.float32) - last_advantage = 0 - - for t in reversed(range(len(self))): - mask = ~self.done[t] - last_value = self.next_values[t] * mask - last_advantage = last_advantage * mask - - delta = self.rewards[t] + gamma * last_value - self.values[t] - last_advantage = delta + gamma * lambda_ * last_advantage - - self.advantages[t] = last_advantage - - -class TrajectoryCollector: - """Class that appends timestep data to a trajectory.""" - - num_loops: list[int] - """Number of loops in the trajectory.""" - actions_index: list[torch.Tensor] - """Actions in the trajectory.""" - obs: list[torch.Tensor] - """Observations in the trajectory.""" - next_obs: list[torch.Tensor] - """Observations of next states in the trajectory.""" - actions_bev_log_p: list[float] - """Action log probabilities following behavioral policy in the trajectory.""" - rewards: list[float] - """Rewards in the trajectory.""" - done: list[bool] - """Done flags in the trajectory.""" - - def __init__(self): - """Initialize the trajectory collector.""" - self.num_loops = [] - self.actions_index = [] - self.obs = [] - self.next_obs = [] - self.actions_bev_log_p = [] - self.rewards = [] - self.done = [] - - def __add__(self, other: 'TrajectoryCollector'): - self.num_loops.extend(other.num_loops) - self.actions_index.extend(other.actions_index) - self.obs.extend(other.obs) - self.next_obs.extend(other.next_obs) - self.actions_bev_log_p.extend(other.actions_bev_log_p) - self.rewards.extend(other.rewards) - self.done.extend(other.done) - - return self - - def append(self, timestep: T_timestep): - """Append a single timestep to the trajectory. - - Args: - timestep (T_timestep): The timestep data to append. - """ - self.num_loops.append(timestep[0]) - self.actions_index.append(timestep[1]) - self.obs.append(timestep[2]) - self.next_obs.append(timestep[3]) - self.actions_bev_log_p.append(timestep[4]) - self.rewards.append(timestep[5]) - self.done.append(timestep[6]) - - def to_trajectory(self) -> TrajectoryData: - """Convert the collected data to a TrajectoryData object. - - Returns: - TrajectoryData: The trajectory containing all collected data. - """ - return TrajectoryData( - num_loops=torch.tensor(self.num_loops, dtype=torch.int64), - actions_index=torch.cat(self.actions_index), - obs=torch.cat(self.obs), - next_obs=torch.cat(self.next_obs), - actions_bev_log_p=torch.tensor(self.actions_bev_log_p, dtype=torch.float32), - rewards=torch.tensor(self.rewards, dtype=torch.float32), - done=torch.tensor(self.done, dtype=torch.bool), - ) - - def reset(self): - """Reset the trajectory collector.""" - self.num_loops.clear() - self.actions_index.clear() - self.obs.clear() - self.next_obs.clear() - self.actions_bev_log_p.clear() - self.rewards.clear() - self.done.clear() +import torch +from torch.utils.data import Dataset, DataLoader +from typing import Iterator +from rl_autoschedular import device +from rl_autoschedular.model import HiearchyModel as Model + + +T_timestep = tuple[ + int, # num_loops + torch.Tensor, # action_index + torch.Tensor, # obs + torch.Tensor, # next_obs + float, # reward + bool, # done +] + +# Only PPO-relevant attributes now +DYNAMIC_ATTRS = ['values', 'next_values', 'actions_old_log_p', 'returns', 'advantages'] + + +class TrajectoryData(Dataset): + """On-policy trajectory dataset for PPO.""" + + num_loops: torch.Tensor + actions_index: torch.Tensor + obs: torch.Tensor + next_obs: torch.Tensor + rewards: torch.Tensor + done: torch.Tensor + + values: torch.Tensor + next_values: torch.Tensor + actions_old_log_p: torch.Tensor + returns: torch.Tensor + advantages: torch.Tensor + + def __init__( + self, + num_loops: torch.Tensor, + actions_index: torch.Tensor, + obs: torch.Tensor, + next_obs: torch.Tensor, + rewards: torch.Tensor, + done: torch.Tensor + ): + self.num_loops = num_loops + self.actions_index = actions_index + self.obs = obs + self.next_obs = next_obs + self.rewards = rewards + self.done = done + + def __len__(self) -> int: + return self.obs.size(0) + + def __getitem__(self, idx: int): + return ( + self.num_loops[idx], + self.actions_index[idx], + self.obs[idx], + self.next_obs[idx], + self.rewards[idx], + self.done[idx], + self.values[idx], + self.next_values[idx], + self.actions_old_log_p[idx], + self.returns[idx], + self.advantages[idx], + ) + + def __add__(self, other: 'TrajectoryData'): + """Concatenate with another trajectory.""" + self_other = TrajectoryData( + torch.cat((self.num_loops, other.num_loops)), + torch.cat((self.actions_index, other.actions_index)), + torch.cat((self.obs, other.obs)), + torch.cat((self.next_obs, other.next_obs)), + torch.cat((self.rewards, other.rewards)), + torch.cat((self.done, other.done)), + ) + for attr in DYNAMIC_ATTRS: + if hasattr(self, attr) and hasattr(other, attr): + self_val = getattr(self, attr) + other_val = getattr(other, attr) + assert isinstance(self_val, torch.Tensor) and isinstance(other_val, torch.Tensor) + setattr(self_other, attr, torch.cat((self_val, other_val))) + return self_other + + def loader(self, batch_size: int, shuffle: bool = True): + """Create DataLoader for PPO training (uniform sampling).""" + return DataLoader( + self, + batch_size=batch_size, + shuffle=shuffle, + num_workers=0, + pin_memory=False, + ) + + def copy(self) -> 'TrajectoryData': + """Copy the trajectory.""" + self_copy = TrajectoryData( + num_loops=self.num_loops.clone(), + actions_index=self.actions_index.clone(), + obs=self.obs.clone(), + next_obs=self.next_obs.clone(), + rewards=self.rewards.clone(), + done=self.done.clone(), + ) + for attr in DYNAMIC_ATTRS: + if hasattr(self, attr): + attr_val = getattr(self, attr) + assert isinstance(attr_val, torch.Tensor) + setattr(self_copy, attr, attr_val.clone()) + return self_copy + + def update_attributes(self, model: Model): + """Update log-probs, values, returns, and advantages with the current model.""" + actions_old_log_p, values, _ = model(self.obs.to(device), self.actions_index) + next_values = model.value_model(self.next_obs.to(device)) + + self.actions_old_log_p = actions_old_log_p.cpu() + self.values = values.cpu() + self.next_values = next_values.cpu() + + self.__compute_returns() + self.__compute_gae() + + def __compute_returns(self, gamma: float = 0.99): + """Compute discounted returns (standard PPO style).""" + self.returns = torch.zeros(len(self), dtype=torch.float32) + last_return = 0.0 + for t in reversed(range(len(self))): + mask = 1.0 - float(self.done[t]) + last_return = self.rewards[t] + gamma * last_return * mask + self.returns[t] = last_return + + def __compute_gae(self, gamma: float = 0.99, lambda_: float = 0.95): + """Compute Generalized Advantage Estimation (GAE).""" + self.advantages = torch.zeros(len(self), dtype=torch.float32) + last_advantage = 0.0 + for t in reversed(range(len(self))): + mask = 1.0 - float(self.done[t]) + last_value = self.next_values[t] * mask + last_advantage = last_advantage * mask + + delta = self.rewards[t] + gamma * last_value - self.values[t] + last_advantage = delta + gamma * lambda_ * last_advantage + + self.advantages[t] = last_advantage + + +class TrajectoryCollector: + """Collect timesteps into a trajectory for PPO.""" + + def __init__(self): + self.num_loops = [] + self.actions_index = [] + self.obs = [] + self.next_obs = [] + self.rewards = [] + self.done = [] + + def append(self, timestep: T_timestep): + self.num_loops.append(timestep[0]) + self.actions_index.append(timestep[1]) + self.obs.append(timestep[2]) + self.next_obs.append(timestep[3]) + self.rewards.append(timestep[4]) + self.done.append(timestep[5]) + + def to_trajectory(self) -> TrajectoryData: + return TrajectoryData( + num_loops=torch.tensor(self.num_loops, dtype=torch.int64), + actions_index=torch.cat(self.actions_index), + obs=torch.cat(self.obs), + next_obs=torch.cat(self.next_obs), + rewards=torch.tensor(self.rewards, dtype=torch.float32), + done=torch.tensor(self.done, dtype=torch.bool), + ) + + def reset(self): + self.num_loops.clear() + self.actions_index.clear() + self.obs.clear() + self.next_obs.clear() + self.rewards.clear() + self.done.clear() diff --git a/rl_autoschedular/transforms.py b/rl_autoschedular/transforms.py old mode 100644 new mode 100755 index a2478a2..65f93d0 --- a/rl_autoschedular/transforms.py +++ b/rl_autoschedular/transforms.py @@ -1,426 +1,595 @@ -import os -import subprocess -from mlir.ir import Context, Module -from mlir.dialects.transform import interpreter -from mlir.passmanager import PassManager - - -def transform_TP(code: str, operation_tag: str, tiling_sizes: list[int]): - """Apply the tiling and parallelization transformation to the specified operation in the given code. - - Args: - code (str): The code to apply the transformation to. - operation_tag (str): The tag of the operation to apply the transformation to. - tiling_sizes (list[int]): The tiling size to apply. - - Returns: - str: The code after applying the transformation. - """ - # If tiling sizes are all zeros, means no tiling is needed - if all([a == 0 for a in tiling_sizes]): - return code - - # Add full transform dialect code into the main code - transform_code = ( - f'\nmodule attributes {{transform.with_named_sequence}} {{\n' - f' transform.named_sequence @__transform_main(%arg1: !transform.any_op {{transform.readonly}}) {{\n' - f' %op_{operation_tag} = transform.structured.match attributes{{tag = "{operation_tag}"}} in %arg1 : (!transform.any_op) -> !transform.any_op\n' - f' %op_tiled_{operation_tag}, %forall_{operation_tag} = transform.structured.tile_using_forall %op_{operation_tag} tile_sizes {str(tiling_sizes)} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)\n' - f' transform.yield\n' - f' }}\n' - f'}}' - ) - - return __run_transform_code(code, transform_code) - - -def transform_tile(code: str, operation_tag: str, tiling_sizes: list[int]): - """Apply the tiling transformation to the specified operation in the given code. - - Args: - code (str): The code to apply the transformation to. - operation_tag (str): The tag of the operation to apply the transformation to. - tiling_sizes (list[int]): The tiling size to apply. - - Returns: - str: The code after applying the transformation. - """ - # If tiling sizes are all zeros, means no tiling is needed - if all([a == 0 for a in tiling_sizes]): - return code - - n_loops = sum([s != 0 for s in tiling_sizes]) - r = ', '.join(['!transform.any_op'] * n_loops) - assert n_loops > 0, "No loops to tile" - - transform_code = ( - f'\nmodule attributes {{transform.with_named_sequence}} {{\n' - f' transform.named_sequence @__transform_main(%arg1: !transform.any_op {{transform.readonly}}) {{\n' - f' %op_{operation_tag} = transform.structured.match attributes{{tag = "{operation_tag}"}} in %arg1 : (!transform.any_op) -> !transform.any_op\n' - f' %tiled_op_{operation_tag}, %loops:{n_loops} = transform.structured.tile_using_for %op_{operation_tag} tile_sizes {str(tiling_sizes)} : (!transform.any_op) -> (!transform.any_op, {r})\n' - f' transform.yield\n' - f' }}\n' - f'}}\n' - ) - - return __run_transform_code(code, transform_code) - - -def transform_interchange(code: str, operation_tag: str, interchange_list: list[int]): - """Apply the interchange transformation to the specified operation in the given code. - - Args: - code (str): The code to apply the transformation to. - operation_tag (str): The tag of the operation to apply the transformation to. - interchange_list (list[int]): The interchange list to apply. - - Returns: - str: The code after applying the transformation. - """ - # If the permutation list is same as the identity permutation, means no interchange is needed - if interchange_list == list(range(len(interchange_list))): - return code - - transform_code = ( - f'module attributes {{transform.with_named_sequence}} {{\n' - f' transform.named_sequence @__transform_main(%arg1: !transform.any_op {{transform.readonly}}) {{\n' - f' %op_{operation_tag} = transform.structured.match attributes{{tag = "{operation_tag}"}} in %arg1 : (!transform.any_op) -> !transform.any_op\n' - f' %gen_op_{operation_tag} = transform.structured.generalize %op_{operation_tag} : (!transform.any_op) -> !transform.any_op\n' - f' %interchanged_op = transform.structured.interchange %gen_op_{operation_tag} iterator_interchange = {str(interchange_list)} : (!transform.any_op) -> !transform.any_op\n' - f' %interchanged_tag = transform.param.constant "{operation_tag}" -> !transform.any_param\n' - f' transform.annotate %interchanged_op "tag" = %interchanged_tag : !transform.any_op, !transform.any_param\n' - f' transform.yield\n' - f' }}\n' - f'}}\n' - ) - - return __run_transform_code(code, transform_code) - - -def transform_vectorize_img2col(code: str, operation_tag: str): - """Apply the vectorization transformation with img2col to the specified operation in the given code. - - Args: - code (str): The code to apply the transformation to. - operation_tag (str): The tag of the operation to apply the transformation to. - - Returns: - str: The code after applying the transformation. - """ - transform_code = f""" -module attributes {{transform.with_named_sequence}} {{ -transform.named_sequence @__transform_main(%variant_op: !transform.any_op {{transform.readonly}}) -{{ - - // %conv_gen_2 = transform.structured.match attributes{{tag = "{operation_tag}"}} in %variant_op : (!transform.any_op) -> !transform.any_op - // %forall_op = transform.get_parent_op %conv_gen_2: (!transform.any_op) -> !transform.any_op - - %forall_op = transform.structured.match ops{{["scf.forall"]}} in %variant_op : (!transform.any_op) -> !transform.any_op - - - - %producer = transform.structured.match attributes{{tag = "img2col_producer"}} in %variant_op : (!transform.any_op) -> !transform.any_op - transform.structured.fuse_into_containing_op %producer into %forall_op : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - - %fb = transform.structured.match ops{{["func.func"]}} in %variant_op : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %fb {{ - transform.apply_patterns.canonicalization - }} : !transform.any_op - transform.apply_cse to %fb : !transform.any_op - - - %original_fill = transform.structured.match ops{{["linalg.fill"]}} in %variant_op : (!transform.any_op) -> !transform.any_op - transform.structured.fuse_into_containing_op %original_fill into %forall_op : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - - %fb1 = transform.structured.match ops{{["func.func"]}} in %variant_op : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %fb1 {{ - transform.apply_patterns.canonicalization - }} : !transform.any_op - transform.apply_cse to %fb1 : !transform.any_op - - - - %func = transform.structured.match ops{{["func.func"]}} in %variant_op - : (!transform.any_op) -> !transform.any_op - %func_0 = transform.structured.vectorize_children_and_apply_patterns %func {{vectorize_padding}} - : (!transform.any_op) -> (!transform.any_op) - - // Step 4. Vector backend - // ====================================================== - %f = transform.structured.match ops{{["func.func"]}} in %variant_op - : (!transform.any_op) -> !transform.any_op - - transform.apply_patterns to %f {{ - transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct" - transform.apply_patterns.vector.transfer_permutation_patterns - transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerparallel" - transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer" - transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true - transform.apply_patterns.vector.lower_transfer max_transfer_rank = 1 - transform.apply_patterns.vector.lower_shape_cast - transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_1d" - transform.apply_patterns.canonicalization - }} : !transform.any_op - - - - transform.yield -}} -}} -""" - - return __run_transform_code(code, transform_code) - - -def transform_vectorize_children(code: str): - """Apply the vectorization transformation to the specified operation in the given code. - - Args: - code (str): The code to apply the transformation to. - operation_tag (str): The tag of the operation to apply the transformation to. - - Returns: - str: The code after applying the transformation. - """ - transform_code = """ - module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) - { - %forall_op = transform.structured.match ops{["scf.forall"]} in %variant_op : (!transform.any_op) -> !transform.any_op - - %original_fill = transform.structured.match ops{["linalg.fill"]} in %variant_op : (!transform.any_op) -> !transform.any_op - transform.structured.fuse_into_containing_op %original_fill into %forall_op : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - - %func = transform.structured.match ops{["func.func"]} in %variant_op: (!transform.any_op) -> !transform.any_op - %func_0 = transform.structured.vectorize_children_and_apply_patterns %func {vectorize_padding}: (!transform.any_op) -> (!transform.any_op) - - transform.yield - } - }""" - - return __run_transform_code(code, transform_code) - - -def transform_vectorize_with_vectorizer(code: str, operation_tag: str): - """Apply the vectorization transformation with vectorizer to the specified operation in the given code. - - Args: - code (str): The code to apply the transformation to. - operation_tag (str): The tag of the operation to apply the transformation to. - - Returns: - str: The code after applying the transformation. - """ - vect_code_process = subprocess.run( - f'{os.getenv("VECTORIZER_BIN_PATH")} - {operation_tag}', - shell=True, - input=code.encode('utf-8'), - stdout=subprocess.PIPE, - stderr=subprocess.PIPE - ) - vect_code = vect_code_process.stdout.decode('utf-8') - - if vect_code_process.returncode != 0: - raise Exception(vect_code_process.stderr.decode('utf-8')) - - return vect_code - - -def transform_vectorize(code: str, operation_tag: str): - """Apply the vectorization transformation with vectorizer to the specified operation in the given code. - - Args: - code (str): The code to apply the transformation to. - operation_tag (str): The tag of the operation to apply the transformation to. - - Returns: - str: The code after applying the transformation. - """ - transform_code = f""" - module attributes {{transform.with_named_sequence}} {{ - transform.named_sequence @__transform_main(%arg0: !transform.any_op {{transform.readonly}}) {{ - %op_{operation_tag} = transform.structured.match attributes{{tag = "{operation_tag}"}} in %arg0 : (!transform.any_op) -> !transform.any_op - transform.structured.vectorize %op_{operation_tag} : !transform.any_op - - %f = transform.structured.match ops{{["func.func"]}} in %arg0 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %f {{ - transform.apply_patterns.vector.transfer_permutation_patterns - transform.apply_patterns.vector.reduction_to_contract - transform.apply_patterns.canonicalization - transform.apply_patterns.tensor.fold_tensor_subset_ops_into_vector_transfers - }} : !transform.any_op - transform.yield - }} - }}""" - - return __run_transform_code(code, transform_code) - - -def transform_img2col(code: str, operation_tag: str): - """Apply the img2col transformation to the specified operation in the given code. - - Args: - code (str): The code to apply the transformation to. - operation_tag (str): The tag of the operation to apply the transformation to. - - Returns: - str: The code after applying the transformation. - """ - transform_code = f""" -module attributes {{transform.with_named_sequence}} {{ - transform.named_sequence @__transform_main(%arg1: !transform.any_op {{transform.readonly}}) {{ - %op_operation = transform.structured.match attributes{{tag = "{operation_tag}"}} in %arg1 : (!transform.any_op) -> !transform.any_op - - %a, %b = transform.structured.convert_conv2d_to_img2col %op_operation : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - %a_tag = transform.param.constant "img2col_producer" -> !transform.any_param - transform.annotate %a "tag" = %a_tag : !transform.any_op, !transform.any_param - - %matmul_op = transform.get_producer_of_operand %b[0]: (!transform.any_op) -> !transform.any_op - %matmul_op_tag = transform.param.constant "{operation_tag}" -> !transform.any_param - transform.annotate %matmul_op "tag" = %matmul_op_tag : !transform.any_op, !transform.any_param - - transform.yield - }} -}}""" - - return __run_transform_code(code, transform_code) - - -def transform_TF(code: str, consumer_tag: str, producer_tag: str, tiling_sizes: list[int], parallel_sizes: list[int]): - """Apply the tiling and fusion transformation to the specified operation in the given code. - - Args: - code (str): The code to apply the transformation to. - consumer_tag (str): The tag of the operation to apply the transformation to. - producer_tag (str): the tag of the producer to fuse with - tiling_sizes (list[int]): The tiling size to apply. - parallel_sizes (list[int]): The parallel size to apply. - - Returns: - str: The code after applying the transformation. - """ - # If parallel sizes are all zeros, means no fusion will be done - if all([a == 0 for a in parallel_sizes]): - return code - - n_for_loops = sum([s != 0 for s in tiling_sizes]) - r = ', '.join(['!transform.any_op'] * n_for_loops) - tile_transform = f"transform.structured.tile_using_for %tiled_op_{consumer_tag} tile_sizes {str(tiling_sizes)} : (!transform.any_op) -> (!transform.any_op, {r})" - - transform_code = ( - f'\nmodule attributes {{transform.with_named_sequence}} {{\n' - f' transform.named_sequence @__transform_main(%arg1: !transform.any_op {{transform.readonly}}) {{\n' - f' %op_{consumer_tag} = transform.structured.match attributes{{tag = "{consumer_tag}"}} in %arg1 : (!transform.any_op) -> !transform.any_op\n' - f' %tiled_op_{consumer_tag}, %forall_op_{consumer_tag} = transform.structured.tile_using_forall %op_{consumer_tag} tile_sizes {str(parallel_sizes)} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)\n' - f" {tile_transform if n_for_loops > 0 else ''}\n" - f' %op_{producer_tag} = transform.structured.match attributes{{tag = "{producer_tag}"}} in %arg1 : (!transform.any_op) -> !transform.any_op\n' - f' %fused, %containing = transform.structured.fuse_into_containing_op %op_{producer_tag} into %forall_op_{consumer_tag} : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)\n' - f' %fused_tag = transform.param.constant "{producer_tag}_{consumer_tag}" -> !transform.any_param\n' - f' transform.annotate %fused "tag" = %fused_tag : !transform.any_op, !transform.any_param\n' - f' transform.yield\n' - f' }}\n' - f'}}\n' - ) - - return __run_transform_code(code, transform_code) - - -def transform_decompose(code: str, operation_tag: str): - """Apply the decomposition transformation to the specified operation in the given code. - - Args: - code (str): The code to apply the transformation to. - operation_tag (str): The tag of the operation to apply the transformation to. - - Returns: - str: The code after applying the transformation. - """ - transform_code = f""" - module attributes {{transform.with_named_sequence}} {{ - transform.named_sequence @__transform_main(%arg1: !transform.any_op {{transform.readonly}}) {{ - %conv = transform.structured.match attributes{{tag = "{operation_tag}"}} in %arg1 : (!transform.any_op) -> !transform.any_op - %decomposed = transform.structured.decompose %conv: (!transform.any_op) -> !transform.any_op - %decomposed_tag = transform.param.constant "{operation_tag}" -> !transform.any_param - transform.annotate %decomposed "tag" = %decomposed_tag : !transform.any_op, !transform.any_param - transform.yield - }} - }}""" - - return __run_transform_code(code, transform_code) - - -def transform_transpose_conv_2d(code: str, operation_tag: str): - """Apply the Conv2D transpose transformation to the specified operation in the given code. - - Args: - code (str): The code to apply the transformation to. - operation_tag (str): The tag of the operation to apply the transformation to. - - Returns: - str: The code after applying the transformation. - """ - transform_code = f""" - module attributes {{transform.with_named_sequence}} {{ - transform.named_sequence @__transform_main(%arg1: !transform.any_op {{transform.readonly}}) {{ - %conv = transform.structured.match attributes{{tag = "{operation_tag}"}} in %arg1 : (!transform.any_op) -> !transform.any_op - %transposed = transform.structured.transpose_conv2d %conv : (!transform.any_op) -> !transform.any_op - %transposed_tag = transform.param.constant "{operation_tag}" -> !transform.any_param - transform.annotate %transposed "tag" = %transposed_tag : !transform.any_op, !transform.any_param - transform.yield - }} - }}""" - - return __run_transform_code(code, transform_code) - - -def transform_bufferize_and_lower_v(code: str): - """Apply the vectorization transformation with vectorizer to the specified operation in the given code. - - Args: - code (str): The code to apply the transformation to. - operation_tag (str): The tag of the operation to apply the transformation to. - - Returns: - str: The code after applying the transformation. - """ - transform_code = """ - module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.consumed}) { - %all_loops = transform.structured.match interface{LoopLikeInterface} in %arg0 : (!transform.any_op) -> !transform.any_op - transform.apply_licm to %all_loops : !transform.any_op - %f1 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %f1 {transform.apply_patterns.canonicalization} : !transform.any_op - transform.structured.eliminate_empty_tensors %arg0 : !transform.any_op - %empty = transform.structured.match ops{["tensor.empty"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %empty1 = transform.cast %empty : !transform.any_op to !transform.op<"tensor.empty"> - transform.bufferization.empty_tensor_to_alloc_tensor %empty1 : (!transform.op<"tensor.empty">) -> !transform.op<"bufferization.alloc_tensor"> - %arg1 = transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap} %arg0 {bufferize_function_boundaries = true} : (!transform.any_op) -> !transform.any_op - - %f = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %f { - transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct" - transform.apply_patterns.vector.transfer_permutation_patterns - transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerparallel" - transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer" - transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true - transform.apply_patterns.vector.lower_transfer max_transfer_rank = 1 - transform.apply_patterns.vector.lower_shape_cast - transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_1d" - transform.apply_patterns.canonicalization - } : !transform.any_op - transform.yield - } - }""" - - return __run_transform_code(code, transform_code) - - -def __run_transform_code(code: str, transform_code: str): - with Context(): - module = Module.parse(code) - t_module = Module.parse(transform_code) - pm = PassManager.parse("builtin.module(canonicalize)") - interpreter.apply_named_sequence(module, t_module.body.operations[0], t_module) - pm.run(module.operation) - - return str(module) +import os +import re +import subprocess +from utils.log import print_error + + +from mlir.ir import Context, Module +from mlir.dialects.transform import interpreter +from mlir.passmanager import PassManager + +# ====================================== Transform dialect functions ====================================== + +def transform_dialect_TP(code: str, operation_tag: str, tiling_sizes: list[int], tmp_file_path: str): + """Apply the tiling and parallelization transformation to the specified operation in the given code. + + Args: + code (str): The code to apply the transformation to. + operation_tag (str): The tag of the operation to apply the transformation to. + tiling_size (list[int]): The tiling size to apply. + tmp_file_path (str): The path to the temporary file to write the code to. + + Returns: + str: The code after applying the transformation. + """ + if not code: + return code + + # If tiling sizes are all zeros, means no tiling is needed + if all([a == 0 for a in tiling_sizes]): + return code + + code = code.strip() + + # Add full transform dialect code into the main code + transform_dialect_code = ( + f'\nmodule attributes {{transform.with_named_sequence}} {{\n' + f' transform.named_sequence @__transform_main(%arg1: !transform.any_op {{transform.readonly}}) {{\n' + f' %op_{operation_tag} = transform.structured.match attributes{{tag = "{operation_tag}"}} in %arg1 : (!transform.any_op) -> !transform.any_op\n' + f' %op_tiled_{operation_tag}, %forall_{operation_tag} = transform.structured.tile_using_forall %op_{operation_tag} tile_sizes {str(tiling_sizes)} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)\n' + f' transform.yield\n' + f' }}\n' + f'}}' + ) + code = code + transform_dialect_code + + with open(tmp_file_path, "w") as file: + file.write(code) + + result = os.popen( + f"{os.getenv('LLVM_BUILD_PATH')}/bin/mlir-opt {tmp_file_path} -transform-interpreter -canonicalize -test-transform-dialect-erase-schedule", + ).read() + + result = result.replace("module {\n", "", 1) + result = ''.join(result.rsplit('\n}\n', 1)) + result = re.sub(r"module attributes \{transform.with_named_sequence\} \{\s+\}", "", result) + + return result + + +def transform_dialect_tile(code: str, operation_tag: str, tiling_size: list[int], tmp_file_path: str): + """Apply the tiling transformation to the specified operation in the given code. + + Args: + code (str): The code to apply the transformation to. + operation_tag (str): The tag of the operation to apply the transformation to. + tiling_size (list[int]): The tiling size to apply. + tmp_file_path (str): The path to the temporary file to write the code to. + + Returns: + str: The code after applying the transformation. + """ + if not code: + return code + + # If tiling sizes are all zeros, means no tiling is needed + if all([a == 0 for a in tiling_size]): + return code + + code = code.strip() + n_loops = sum([s != 0 for s in tiling_size]) + r = ', '.join(['!transform.any_op'] * n_loops) + assert n_loops > 0, "No loops to tile" + + transform_dilaect_code = ( + f'\nmodule attributes {{transform.with_named_sequence}} {{\n' + f' transform.named_sequence @__transform_main(%arg1: !transform.any_op {{transform.readonly}}) {{\n' + f' %op_{operation_tag} = transform.structured.match attributes{{tag = "{operation_tag}"}} in %arg1 : (!transform.any_op) -> !transform.any_op\n' + f' %tiled_op_{operation_tag}, %loops:{n_loops} = transform.structured.tile_using_for %op_{operation_tag} tile_sizes {str(tiling_size)} : (!transform.any_op) -> (!transform.any_op, {r})\n' + f' transform.yield\n' + f' }}\n' + f'}}\n' + ) + + code = code + transform_dilaect_code + '\n' + + with open(tmp_file_path, "w") as file: + file.write(code) + + result = os.popen( + f"{os.getenv('LLVM_BUILD_PATH')}/bin/mlir-opt {tmp_file_path} -transform-interpreter -canonicalize -test-transform-dialect-erase-schedule", + ).read() + + result = result.replace("module {\n", "", 1) + result = ''.join(result.rsplit('\n}\n', 1)) + result = re.sub(r"module attributes \{transform.with_named_sequence\} \{\s+\}", "", result) + + return result + + +def transform_dialect_interchange(code: str, operation_tag: str, interchange_list: list[int], tmp_file_path: str): + """Apply the interchange transformation to the specified operation in the given code. + + Args: + code (str): The code to apply the transformation to. + operation_tag (str): The tag of the operation to apply the transformation to. + interchange_list (list[int]): The interchange list to apply. + tmp_file_path (str): The path to the temporary file to write the code to. + + Returns: + str: The code after applying the transformation. + """ + if not code: + return code + + # If the permutation list is same as the identity permutation, means no interchange is needed + if interchange_list == list(range(len(interchange_list))): + return code + + code = code.strip() + + transform_dilaect_code = ( + f'module attributes {{transform.with_named_sequence}} {{\n' + f' transform.named_sequence @__transform_main(%arg1: !transform.any_op {{transform.readonly}}) {{\n' + f' %op_{operation_tag} = transform.structured.match attributes{{tag = "{operation_tag}"}} in %arg1 : (!transform.any_op) -> !transform.any_op\n' + f' %gen_op_{operation_tag} = transform.structured.generalize %op_{operation_tag} : (!transform.any_op) -> !transform.any_op\n' + f' %interchanged_op = transform.structured.interchange %gen_op_{operation_tag} iterator_interchange = {str(interchange_list)} : (!transform.any_op) -> !transform.any_op\n' + f' %interchanged_tag = transform.param.constant "{operation_tag}" -> !transform.any_param\n' + f' transform.annotate %interchanged_op "tag" = %interchanged_tag : !transform.any_op, !transform.any_param\n' + f' transform.yield\n' + f' }}\n' + f'}}\n' + ) + + code = code + transform_dilaect_code + '\n' + + with open(tmp_file_path, "w") as file: + file.write(code) + + result = os.popen( + f"{os.getenv('LLVM_BUILD_PATH')}/bin/mlir-opt {tmp_file_path} -transform-interpreter -canonicalize -test-transform-dialect-erase-schedule", + ).read() + + result = result.replace("module {\n", "", 1) + result = ''.join(result.rsplit('\n}\n', 1)) + result = re.sub(r"module attributes \{transform.with_named_sequence\} \{\s+\}", "", result) + + return result + + +def transform_dialect_vectorize_img2col(code: str, operation_tag: str, tmp_file_path: str): + """Apply the vectorization transformation with img2col to the specified operation in the given code. + + Args: + code (str): The code to apply the transformation to. + operation_tag (str): The tag of the operation to apply the transformation to. + tmp_file_path (str): The path to the temporary file to write the code to. + + Returns: + str: The code after applying the transformation. + """ + if not code: + return code + + code = code.strip() + + transform_dialect_code = f""" +module attributes {{transform.with_named_sequence}} {{ +transform.named_sequence @__transform_main(%variant_op: !transform.any_op {{transform.readonly}}) +{{ + + // %conv_gen_2 = transform.structured.match attributes{{tag = "{operation_tag}"}} in %variant_op : (!transform.any_op) -> !transform.any_op + // %forall_op = transform.get_parent_op %conv_gen_2: (!transform.any_op) -> !transform.any_op + + %forall_op = transform.structured.match ops{{["scf.forall"]}} in %variant_op : (!transform.any_op) -> !transform.any_op + + + + %producer = transform.structured.match attributes{{tag = "img2col_producer"}} in %variant_op : (!transform.any_op) -> !transform.any_op + transform.structured.fuse_into_containing_op %producer into %forall_op : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + + %fb = transform.structured.match ops{{["func.func"]}} in %variant_op : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %fb {{ + transform.apply_patterns.canonicalization + }} : !transform.any_op + transform.apply_cse to %fb : !transform.any_op + + + %original_fill = transform.structured.match ops{{["linalg.fill"]}} in %variant_op : (!transform.any_op) -> !transform.any_op + transform.structured.fuse_into_containing_op %original_fill into %forall_op : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + + %fb1 = transform.structured.match ops{{["func.func"]}} in %variant_op : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %fb1 {{ + transform.apply_patterns.canonicalization + }} : !transform.any_op + transform.apply_cse to %fb1 : !transform.any_op + + + + %func = transform.structured.match ops{{["func.func"]}} in %variant_op + : (!transform.any_op) -> !transform.any_op + %func_0 = transform.structured.vectorize_children_and_apply_patterns %func {{vectorize_padding}} + : (!transform.any_op) -> (!transform.any_op) + + // Step 4. Vector backend + // ====================================================== + %f = transform.structured.match ops{{["func.func"]}} in %variant_op + : (!transform.any_op) -> !transform.any_op + + transform.apply_patterns to %f {{ + transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct" + transform.apply_patterns.vector.transfer_permutation_patterns + transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerparallel" + transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer" + transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true + transform.apply_patterns.vector.lower_transfer max_transfer_rank = 1 + transform.apply_patterns.vector.lower_shape_cast + transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_1d" + transform.apply_patterns.canonicalization + }} : !transform.any_op + + + + transform.yield +}} +}} +""".strip() + + code = code + '\n' + transform_dialect_code + '\n' + + with open(tmp_file_path, "w") as file: + file.write(code) + + result = os.popen( + f"{os.getenv('LLVM_BUILD_PATH')}/bin/mlir-opt {tmp_file_path} -transform-interpreter -canonicalize -test-transform-dialect-erase-schedule", + ).read() + + result = result.replace("module {\n", "", 1) + result = ''.join(result.rsplit('\n}\n', 1)) + result = re.sub(r"module attributes \{transform.with_named_sequence\} \{\s+\}", "", result) + + return result + + +def transform_dialect_vectorize_children(code: str, operation_tag: str, tmp_file_path: str): + """Apply the vectorization transformation to the specified operation in the given code. + + Args: + code (str): The code to apply the transformation to. + operation_tag (str): The tag of the operation to apply the transformation to. + tmp_file_path (str): The path to the temporary file to write the code to. + + Returns: + str: The code after applying the transformation. + """ + if not code: + return code + + code = code.strip() + + transform_dialect_code = """ + module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) + { + %forall_op = transform.structured.match ops{["scf.forall"]} in %variant_op : (!transform.any_op) -> !transform.any_op + + %original_fill = transform.structured.match ops{["linalg.fill"]} in %variant_op : (!transform.any_op) -> !transform.any_op + transform.structured.fuse_into_containing_op %original_fill into %forall_op : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + + %func = transform.structured.match ops{["func.func"]} in %variant_op: (!transform.any_op) -> !transform.any_op + %func_0 = transform.structured.vectorize_children_and_apply_patterns %func {vectorize_padding}: (!transform.any_op) -> (!transform.any_op) + + transform.apply_patterns to %func_0 { + transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct" + transform.apply_patterns.vector.transfer_permutation_patterns + transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerparallel" + transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer" + transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true + transform.apply_patterns.vector.lower_transfer max_transfer_rank = 1 + transform.apply_patterns.vector.lower_shape_cast + transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_1d" + transform.apply_patterns.canonicalization + } : !transform.any_op + + transform.yield + } + }""".strip() + + code = code + '\n' + transform_dialect_code + '\n' + + with open(tmp_file_path, "w") as file: + file.write(code) + + result = os.popen( + f"{os.getenv('LLVM_BUILD_PATH')}/bin/mlir-opt {tmp_file_path} -transform-interpreter -canonicalize -test-transform-dialect-erase-schedule", + ).read() + + result = result.replace("module {\n", "", 1) + result = ''.join(result.rsplit('\n}\n', 1)) + result = re.sub(r"module attributes \{transform.with_named_sequence\} \{\s+\}", "", result) + + return result + + +def transform_dialect_vectorize_with_vectorizer(code: str, operation_tag: str, tmp_file_path: str): + """Apply the vectorization transformation with vectorizer to the specified operation in the given code. + + Args: + code (str): The code to apply the transformation to. + operation_tag (str): The tag of the operation to apply the transformation to. + tmp_file_path (str): The path to the temporary file to write the code to. + + Returns: + str: The code after applying the transformation. + """ + if not code: + return code + + code = code.strip() + + vect_code_process = subprocess.run( + f'{os.getenv("VECTORIZER_BIN_PATH")} - {operation_tag}', + shell=True, + input=code.encode('utf-8'), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE + ) + vect_code = vect_code_process.stdout.decode('utf-8') + + if vect_code_process.returncode != 0: + print_error(f"Vectorizer failed with error: {vect_code_process.stderr.decode('utf-8')}") + return '' + + # If vectorizer succeeded apply vectorization patterns else return empty string + if vect_code: + transform_dialect_code = """ + module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) { + %f = transform.structured.match ops{[\"func.func\"]} in %variant_op : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %f { + transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct" + transform.apply_patterns.vector.transfer_permutation_patterns + transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerparallel" + transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer" + transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true + transform.apply_patterns.vector.lower_transfer max_transfer_rank = 1 + transform.apply_patterns.vector.lower_shape_cast + transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_1d" + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.yield + } + }""".strip() + + full_code = vect_code + '\n' + transform_dialect_code + '\n' + + with open(tmp_file_path, "w") as file: + file.write(full_code) + + result = os.popen( + f"{os.getenv('LLVM_BUILD_PATH')}/bin/mlir-opt {tmp_file_path} -transform-interpreter -canonicalize -test-transform-dialect-erase-schedule", + ).read() + + result = result.replace("module {\n", "", 1) + result = ''.join(result.rsplit('\n}\n', 1)) + result = re.sub(r"module attributes \{transform.with_named_sequence\} \{\s+\}", "", result) + + return result + else: + return '' + + +def transform_dialect_vectorize(code: str, operation_tag: str, tmp_file_path: str): + """Apply the vectorization transformation with vectorizer to the specified operation in the given code. + + Args: + code (str): The code to apply the transformation to. + operation_tag (str): The tag of the operation to apply the transformation to. + tmp_file_path (str): The path to the temporary file to write the code to. + + Returns: + str: The code after applying the transformation. + """ + if not code: + return code + + code = code.strip() + + transform_dialect_code = f""" + module attributes {{transform.with_named_sequence}} {{ + transform.named_sequence @__transform_main(%arg0: !transform.any_op {{transform.consumed}}) {{ + %all_loops = transform.structured.match interface{{LoopLikeInterface}} in %arg0 : (!transform.any_op) -> !transform.any_op + transform.apply_licm to %all_loops : !transform.any_op + %f1 = transform.structured.match ops{{["func.func"]}} in %arg0 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %f1 {{transform.apply_patterns.canonicalization}} : !transform.any_op + transform.structured.eliminate_empty_tensors %arg0 : !transform.any_op + %empty = transform.structured.match ops{{["tensor.empty"]}} in %arg0 : (!transform.any_op) -> !transform.any_op + %empty1 = transform.cast %empty : !transform.any_op to !transform.op<"tensor.empty"> + transform.bufferization.empty_tensor_to_alloc_tensor %empty1 : (!transform.op<"tensor.empty">) -> !transform.op<"bufferization.alloc_tensor"> + %arg1 = transform.bufferization.one_shot_bufferize layout{{IdentityLayoutMap}} %arg0 {{bufferize_function_boundaries = true}} : (!transform.any_op) -> !transform.any_op + + %op_{operation_tag} = transform.structured.match attributes{{tag = "{operation_tag}"}} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.structured.vectorize %op_{operation_tag} : !transform.any_op + + %f = transform.structured.match ops{{["func.func"]}} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %f {{ + transform.apply_patterns.vector.transfer_permutation_patterns + transform.apply_patterns.vector.reduction_to_contract + transform.apply_patterns.canonicalization + transform.apply_patterns.tensor.fold_tensor_subset_ops_into_vector_transfers + }} : !transform.any_op + + transform.apply_patterns to %f {{ + transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct" + transform.apply_patterns.vector.transfer_permutation_patterns + transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerparallel" + transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer" + transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true + transform.apply_patterns.vector.lower_transfer max_transfer_rank = 1 + transform.apply_patterns.vector.lower_shape_cast + transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_1d" + transform.apply_patterns.canonicalization + }} : !transform.any_op + transform.yield + }} + }}""".strip() + + full_code = code + '\n' + transform_dialect_code + '\n' + + with open(tmp_file_path, "w") as file: + file.write(full_code) + + result = os.popen( + f"{os.getenv('LLVM_BUILD_PATH')}/bin/mlir-opt {tmp_file_path} -transform-interpreter -canonicalize -test-transform-dialect-erase-schedule", + ).read() + + result = result.replace("module {\n", "", 1) + result = ''.join(result.rsplit('\n}\n', 1)) + result = re.sub(r"module attributes \{transform.with_named_sequence\} \{\s+\}", "", result) + + return result + + +def transform_dialect_img2col(code: str, operation_tag: str, tmp_file_path: str): + """Apply the img2col transformation to the specified operation in the given code. + + Args: + code (str): The code to apply the transformation to. + operation_tag (str): The tag of the operation to apply the transformation to. + tmp_file_path (str): The path to the temporary file to write the code to. + + Returns: + str: The code after applying the transformation. + """ + if not code: + return code + + code = code.strip() + + transform_dilaect_code = f""" +module attributes {{transform.with_named_sequence}} {{ + transform.named_sequence @__transform_main(%arg1: !transform.any_op {{transform.readonly}}) {{ + %op_operation = transform.structured.match attributes{{tag = "{operation_tag}"}} in %arg1 : (!transform.any_op) -> !transform.any_op + + %a, %b = transform.structured.convert_conv2d_to_img2col %op_operation : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %a_tag = transform.param.constant "img2col_producer" -> !transform.any_param + transform.annotate %a "tag" = %a_tag : !transform.any_op, !transform.any_param + + %matmul_op = transform.get_producer_of_operand %b[0]: (!transform.any_op) -> !transform.any_op + %matmul_op_tag = transform.param.constant "{operation_tag}" -> !transform.any_param + transform.annotate %matmul_op "tag" = %matmul_op_tag : !transform.any_op, !transform.any_param + + transform.yield + }} +}}""".strip() + + code = code + transform_dilaect_code + + with open(tmp_file_path, "w") as file: + file.write(code) + + result = os.popen( + f"{os.getenv('LLVM_BUILD_PATH')}/bin/mlir-opt {tmp_file_path} -transform-interpreter -canonicalize -test-transform-dialect-erase-schedule", + ).read() + + result = result.replace("module {\n", "", 1) + result = ''.join(result.rsplit('\n}\n', 1)) + result = re.sub(r"module attributes \{transform.with_named_sequence\} \{\s+\}", "", result) + + return result + + +# ========================================= Other functions ========================================= + +def apply_conv2d_decomposition(code: str, operation_tag: str, tmp_file_path: str): + """Apply the Conv2D decomposition transformation to the specified operation in the given code. + + Args: + code (str): The code to apply the transformation to. + operation_tag (str): The tag of the operation to apply the transformation to. + tmp_file_path (str): The path to the temporary file to write the code to. + + Returns: + str: The code after applying the transformation. + """ + if not code: + return code + + code = code.strip() + transform_dialect_code = f""" + module attributes {{transform.with_named_sequence}} {{ + transform.named_sequence @__transform_main(%arg1: !transform.any_op {{transform.readonly}}) {{ + %conv = transform.structured.match attributes{{tag = "{operation_tag}"}} in %arg1 : (!transform.any_op) -> !transform.any_op + %decomposed = transform.structured.decompose %conv: (!transform.any_op) -> !transform.any_op + %decomposed_tag = transform.param.constant "{operation_tag}" -> !transform.any_param + transform.annotate %decomposed "tag" = %decomposed_tag : !transform.any_op, !transform.any_param + transform.yield + }} + }}""" + + code = code + '\n' + transform_dialect_code + '\n' + + with open(tmp_file_path, "w") as file: + file.write(code) + + result = os.popen( + f"{os.getenv('LLVM_BUILD_PATH')}/bin/mlir-opt {tmp_file_path} -transform-interpreter -canonicalize -test-transform-dialect-erase-schedule", + ).read() + + result = result.replace("module {\n", "", 1) + result = ''.join(result.rsplit('\n}\n', 1)) + result = re.sub(r"module attributes \{transform.with_named_sequence\} \{\s+\}", "", result) + + return result + + +def transform_bufferize_and_lower_v(code: str): + """Apply the vectorization transformation with vectorizer to the specified operation in the given code. + + Args: + code (str): The code to apply the transformation to. + operation_tag (str): The tag of the operation to apply the transformation to. + + Returns: + str: The code after applying the transformation. + """ + transform_code = """ + module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.consumed}) { + %all_loops = transform.structured.match interface{LoopLikeInterface} in %arg0 : (!transform.any_op) -> !transform.any_op + transform.apply_licm to %all_loops : !transform.any_op + %f1 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %f1 {transform.apply_patterns.canonicalization} : !transform.any_op + transform.structured.eliminate_empty_tensors %arg0 : !transform.any_op + %empty = transform.structured.match ops{["tensor.empty"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %empty1 = transform.cast %empty : !transform.any_op to !transform.op<"tensor.empty"> + transform.bufferization.empty_tensor_to_alloc_tensor %empty1 : (!transform.op<"tensor.empty">) -> !transform.op<"bufferization.alloc_tensor"> + %arg1 = transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap} %arg0 {bufferize_function_boundaries = true} : (!transform.any_op) -> !transform.any_op + + %f = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %f { + transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct" + transform.apply_patterns.vector.transfer_permutation_patterns + transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerparallel" + transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer" + transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true + transform.apply_patterns.vector.lower_transfer max_transfer_rank = 1 + transform.apply_patterns.vector.lower_shape_cast + transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_1d" + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.yield + } + }""" + + return __run_transform_code(code, transform_code) + +def __run_transform_code(code: str, transform_code: str): + with Context(): + module = Module.parse(code) + t_module = Module.parse(transform_code) + pm = PassManager.parse("builtin.module(canonicalize)") + interpreter.apply_named_sequence(module, t_module.body.operations[0], t_module) + pm.run(module.operation) + + return str(module) \ No newline at end of file diff --git a/scripts/.gitignore b/scripts/.gitignore deleted file mode 100644 index d177d9b..0000000 --- a/scripts/.gitignore +++ /dev/null @@ -1,6 +0,0 @@ -# Ignore everything in this directory -* -# Except these files -!.gitignore -!train_example.sh -!neptune-sync.sh diff --git a/scripts/neptune-sync.sh b/scripts/neptune-sync.sh deleted file mode 100644 index ac7c490..0000000 --- a/scripts/neptune-sync.sh +++ /dev/null @@ -1,23 +0,0 @@ -#!/bin/bash - -# Define the resource requirements here using #SBATCH - -#SBATCH -p compute -#SBATCH --nodes=1 -#SBATCH -c 4 -#SBATCH --mem=16G -#SBATCH -t 07-00 -#SBATCH --mail-type=FAIL,TIME_LIMIT -#SBATCH --mail-user=mt5383@nyu.edu - -# Resource requiremenmt commands end here - -#Add the lines for running your code/application -module load miniconda-nobashrc -eval "$(conda shell.bash hook)" - -# Activate any environments if required -conda activate testenv - -# Execute the code -python $SCRATCH/MLIR-RL/neptune_sync.py diff --git a/scripts/train_example.sh b/scripts/train_example.sh deleted file mode 100644 index 54e544d..0000000 --- a/scripts/train_example.sh +++ /dev/null @@ -1,28 +0,0 @@ -#!/bin/bash - -# Define the resource requirements here using #SBATCH - -#SBATCH -p compute -#SBATCH --reservation=c2 -#SBATCH --exclusive -#SBATCH --nodes=1 -#SBATCH -c 28 -#SBATCH --mem=64G -#SBATCH -t 07-00 -#SBATCH -o /scratch/$NYU_NET_ID/MLIR-RL/logs/train.out -#SBATCH -e /scratch/$NYU_NET_ID/MLIR-RL/logs/train.err - -# Resource requiremenmt commands end here - -#Add the lines for running your code/application -module load miniconda-nobashrc -eval "$(conda shell.bash hook)" - -# Activate any environments if required -conda activate $CONDA_ENV_NAME - -# Set config file path -export OMP_NUM_THREADS=12 -export CONFIG_FILE_PATH=config/example.json -# Execute the code -python /scratch/$NYU_NET_ID/MLIR-RL/train.py diff --git a/setup.sh b/setup.sh new file mode 100755 index 0000000..1ff3055 --- /dev/null +++ b/setup.sh @@ -0,0 +1,104 @@ +#!/bin/bash + +# Set project root +PROJECT_ROOT=$(pwd) +VENV_PATH="${PROJECT_ROOT}/mlir-venv" + +# Step 0: Create and activate venv +python3.11 -m venv ${VENV_PATH} +source ${VENV_PATH}/bin/activate + +# Upgrade pip in venv +pip install --upgrade pip + +# Step 1: Install project Python requirements in venv +pip install -r requirements.txt + +# Step 2: Clone and build MLIR in venv +git clone --depth 1 -b release/19.x https://github.com/llvm/llvm-project.git +cd llvm-project +sudo mkdir build +sudo cd build +sudo cmake -S ../llvm -G Ninja -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_PROJECTS="mlir;clang;openmp" \ +-DLLVM_BUILD_EXAMPLES=ON -DLLVM_TARGETS_TO_BUILD=X86 -DLLVM_ENABLE_ASSERTIONS=ON \ +-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DLLVM_ENABLE_LLD=ON -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ +-DPython3_EXECUTABLE=${VENV_PATH}/bin/python +sudo cmake --build . --target check-mlir +sudo cmake --build . --target check-mlir-python # Extra test for bindings +cd ${PROJECT_ROOT} + +# Step 2.1: In case OMP didn't build +cd llvm-project/build +sudo ninja omp + + +# Step 3: Install MLIR Python binding requirements in venv +cd llvm-project/mlir/python +sudo pip install -r requirements.txt # This includes NumPy +cd ${PROJECT_ROOT} + + +# For MLIR specific project + +# Step 4: Build AstDumper if directory exists +if [ -d "tools/ast_dumper" ]; then + cd tools/ast_dumper + mkdir build + cd build + sudo cmake -G Ninja -DCMAKE_BUILD_TYPE=Release \ + -DLLVM_DIR=${PROJECT_ROOT}/llvm-project/build/lib/cmake/llvm \ + -DMLIR_DIR=${PROJECT_ROOT}/llvm-project/build/lib/cmake/mlir \ + -DPython3_EXECUTABLE=${VENV_PATH}/bin/python \ + .. + sudo cmake --build . + cd ${PROJECT_ROOT} +else + echo "Warning: tools/ast_dumper not found. Skipping build. Ensure project repo is cloned correctly." +fi + +# Step 5: Build Vectorizer if directory exists +if [ -d "tools/vectorizer" ]; then + cd tools/vectorizer + mkdir build + cd build + sudo cmake -G Ninja -DCMAKE_BUILD_TYPE=Release \ + -DLLVM_DIR=${PROJECT_ROOT}/llvm-project/build/lib/cmake/llvm \ + -DMLIR_DIR=${PROJECT_ROOT}/llvm-project/build/lib/cmake/mlir \ + -DPython3_EXECUTABLE=${VENV_PATH}/bin/python \ + .. + sudo cmake --build . + cd ${PROJECT_ROOT} +else + echo "Warning: tools/vectorizer not found. Skipping build. Ensure project repo is cloned correctly." +fi + +# Step 6: Create .env file with venv activation +cat << EOF > .env +# Activate virtual environment +source ${PROJECT_ROOT}/mlir-venv/bin/activate + +# Add MLIR binaries to PATH +export PATH=${PROJECT_ROOT}/llvm-project/build/bin:$PATH + +# Add MLIR Python bindings to PYTHONPATH +export PYTHONPATH=${PROJECT_ROOT}/llvm-project/build/tools/mlir/python_packages/mlir_core:$PYTHONPATH + + +export NEPTUNE_PROJECT="" + +export NEPTUNE_TOKEN="" + +export LLVM_BUILD_PATH=${PROJECT_ROOT}/llvm-project/build + +export MLIR_SHARED_LIBS=${PROJECT_ROOT}/llvm-project/build/lib/libomp.so,/home/ouail/nyuad-internship/llvm-project/build/lib/libmlir_c_runner_utils.so,/home/ouail/nyuad-internship/llvm-project/build/lib/libmlir_runner_utils.so + +export AST_DUMPER_BIN_PATH=${PROJECT_ROOT}/tools/ast_dumper/build/bin/AstDumper + +export VECTORIZER_BIN_PATH=${PROJECT_ROOT}/tools/vectorizer/build/bin/Vectorizer + +EOF + +echo "Setup complete. Source the env with: source .env (this activates the venv too)." +echo "Edit .env for Neptune if needed." +echo "If tools were skipped, clone the project repo and re-run." +deactivate # Deactivate venv after script (source .env will reactivate) \ No newline at end of file diff --git a/test.py b/test.py new file mode 100644 index 0000000..ae80a2b --- /dev/null +++ b/test.py @@ -0,0 +1,11 @@ +from torch.utils.tensorboard import SummaryWriter +from collections import defaultdict +from utils.data_collector import OfflineDataset + + +dt = OfflineDataset(save_dir="offline_dataset",fname="offline_dataset_online_ppo.npz") + + +a = dt.load() + +print(a[''][100:150]) \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/benchmarks/matmul.mlir b/tests/benchmarks/matmul.mlir new file mode 100644 index 0000000..ffbeb30 --- /dev/null +++ b/tests/benchmarks/matmul.mlir @@ -0,0 +1,8 @@ +func.func private @nanoTime() -> i64 attributes { llvm.emit_c_interface } +func.func @main(%arg0: memref<256x128xf64>, %arg1: memref<128x256xf64>, %arg2: memref<256x256xf64>) -> i64 attributes { llvm.emit_c_interface } { + %t0 = func.call @nanoTime() : () -> i64 + linalg.matmul ins(%arg0, %arg1 : memref<256x128xf64>, memref<128x256xf64>) outs(%arg2 : memref<256x256xf64>) + %t1 = func.call @nanoTime() : () -> i64 + %t2 = arith.subi %t1, %t0 : i64 + return %t2 : i64 +} diff --git a/tests/inference.py b/tests/inference.py new file mode 100644 index 0000000..abf9763 --- /dev/null +++ b/tests/inference.py @@ -0,0 +1,75 @@ +from dotenv import load_dotenv +load_dotenv(override=True) + +import os +import torch +import numpy as np + +from rl_autoschedular.env import Env +from rl_autoschedular.observation import Observation +from rl_autoschedular.actions import ActionSpace +from rl_autoschedular.model import HiearchyModel + + +def run_inference(model_ckpt: str, bench_idx: int = 10, repeat: int = 5): + # === Create evaluation environment === + env = Env(is_training=False) + state = env.reset(bench_idx=bench_idx) + + # === Load model === + model = HiearchyModel() + checkpoint = torch.load(model_ckpt, map_location="cpu") + model.load_state_dict(checkpoint, strict=False) # allow partial load + model.eval() + + print(f"Loaded checkpoint: {model_ckpt}") + print(f"Running inference on benchmark index {bench_idx}") + + seq = [] + bench_done = False + total_reward = 0.0 + acc_values = [] + + while not bench_done: + obs = Observation.from_state(state) + + with torch.no_grad(): + dists = model.policy_model(obs) + action_index = ActionSpace.sample(obs, dists, dists, greedy=True) + + action = ActionSpace.action_by_index(action_index[0], state) + seq.append(action) + + # repeat execution for stability + tmp_rewards, tmp_accs = [], [] + for _ in range(repeat): + next_state, reward, op_done, acc = env.step(state, action) + tmp_rewards.append(reward) + if acc is not None: + tmp_accs.append(acc) + + # use median to avoid outliers + reward = np.median(tmp_rewards) + acc = np.median(tmp_accs) if tmp_accs else None + + total_reward += reward + + if op_done: + next_state, bench_done = env.get_next_op_state(next_state) + + state = next_state + + if acc is not None: + acc_values.append(acc) + + avg_acc = np.median(acc_values) if acc_values else None + + print("\n=== Inference finished ===") + print("Sequence of actions:", seq) + print("Total reward:", total_reward) + print("Median acceleration:", avg_acc) + + +if __name__ == "__main__": + model_ckpt = "./tests/checkpoints/model.pth" # adjust path + run_inference(model_ckpt, bench_idx=0, repeat=5) diff --git a/tests/test_action.py b/tests/test_action.py new file mode 100644 index 0000000..ecdbfee --- /dev/null +++ b/tests/test_action.py @@ -0,0 +1,132 @@ +# test_action_space.py +from dotenv import load_dotenv +load_dotenv(override=True) + +import os +import torch + +from rl_autoschedular.state import ( + extract_bench_features_from_file, + OperationState, +) +from rl_autoschedular.observation import ( + Observation, + OpFeatures, + ActionHistory, + NumLoops, + ActionMask, +) +from rl_autoschedular.actions import ActionSpace + +from rl_autoschedular.actions.tiled_parallelization import TiledParallelization + + + + + +def main(): + bench_name = "matmul" + file_path = "./tests/benchmarks/matmul.mlir" + root_exec_time = 1000 # dummy baseline + + print("=== Extract Benchmark Features ===") + bench_features = extract_bench_features_from_file( + bench_name, file_path, root_exec_time + ) + first_tag = bench_features.operation_tags[0] + first_op = bench_features.operations[first_tag] + + state = OperationState( + bench_name=bench_features.bench_name, + operation_tag=first_tag, + operation_features=first_op, + validated_code=bench_features.code, + transformed_code=first_op.raw_operation, + step_count=0, + exec_time=bench_features.root_exec_time, + transformation_history=[[]], + tmp_file="tmp.mlir", + terminal=False + ) + + print("\n=== ActionSpace Basic Info ===") + print("Supported actions:", [a.__name__ for a in ActionSpace.supported_actions]) + print("Size:", ActionSpace.size()) + print("Cumulative param sizes:", ActionSpace.cumulative_params_sizes()) + print("Cumulative mask sizes:", ActionSpace.cumulative_mask_sizes()) + print("Cumulative history sizes:", ActionSpace.cumulative_history_sizes()) + + print("\n=== Action Lookup ===") + for i, act in enumerate(ActionSpace.supported_actions): + print(f"Index {i}: {act.__name__}, number={ActionSpace.action_number(act)}, symbol={act.symbol}") + + # Try lookup by symbol + sym = ActionSpace.supported_actions[0].symbol + print(f"Symbol '{sym}' resolves to:", ActionSpace.action_type_by_symbol(sym).__name__) + print(f"Symbol '{sym}' has number:", ActionSpace.action_number_by_symbol(sym)) + + print("\n=== Action By Index ===") + cum_sizes = ActionSpace.cumulative_params_sizes() + for i, act in enumerate(ActionSpace.supported_actions): + # construct a dummy index tensor with selection + params + index = torch.zeros(cum_sizes[-1], dtype=torch.long) + index[0] = i # select action + action = ActionSpace.action_by_index(index, state) + print(f"Constructed action {action} from index {i}") + + print("\n=== Action Mask ===") + mask = ActionSpace.action_mask(state) + print("Action mask:", mask.tolist(), " length:", len(mask)) + + ''' + # what does the class say about the current state? + print("\n=== TP Mask ===") + tp_action_mask = TiledParallelization.action_mask(state) + print("TP mask:", tp_action_mask.tolist()) + ''' + + + + + print("\n=== Action History ===") + history = ActionSpace.action_history(state) + print("History tensor:", history.tolist() if history.numel() else "empty") + + print("\n=== Observation Integration ===") + obs = Observation.from_state(state) + action_mask = Observation.get_part(obs, ActionMask) + print("ActionMask from obs shape:", action_mask.shape) + + print("\n=== Distributions ===") + selection_logits = torch.randn(1, ActionSpace.size()) + actions_logits = [torch.randn(1, size) if size > 0 else None + for size in [a.mask_size() for a in ActionSpace.supported_actions]] + + print('Actions logits sizes:', [logits.shape if logits is not None else None for logits in actions_logits]) + dists = ActionSpace.distributions(obs, selection_logits, *actions_logits) + + print("Number of distributions:", len(dists)) + print('Dists sizes:', dists.shape) + + print("\n=== Uniform Distributions ===") + uniform_dists = ActionSpace.uniform_distributions(obs) + print("Number of uniform dists:", len(uniform_dists)) + + print("\n=== Distributions Stats ===") + index = ActionSpace.sample(obs, dists, uniform_dists) # sample an index + logp, entropy = ActionSpace.distributions_stats(dists, index, uniform_dists, eps=0.1) + print("Log prob:", logp) + print("Entropy:", entropy) + + print("\n=== Sampling ===") + sample1 = ActionSpace.sample(obs, dists, uniform_dists, greedy=True) + sample2 = ActionSpace.sample(obs, dists, uniform_dists, uniform=True) + sample3 = ActionSpace.sample(obs, dists, uniform_dists) + print("Greedy sample:", sample1.tolist()) + print("Uniform sample:", sample2.tolist()) + print("Random sample:", sample3.tolist()) + + + +if __name__ == "__main__": + main() diff --git a/tests/test_execution.py b/tests/test_execution.py new file mode 100644 index 0000000..dbb3f4a --- /dev/null +++ b/tests/test_execution.py @@ -0,0 +1,63 @@ +# test_execution.py +from dotenv import load_dotenv +load_dotenv(override=True) + +import os +import json +from rl_autoschedular.execution import Execution +from rl_autoschedular.actions import Action + +def main(): + bench_name = "matmul" + mlir_file_path = "./tests/benchmarks/matmul.mlir" # path to your mlir file + tmp_exec_file = "./tests/tmp_exec_data.json" # temporary cache file + + # Ensure cache file exists + if not os.path.exists(tmp_exec_file): + with open(tmp_exec_file, "w") as f: + json.dump({}, f) + + # Read MLIR code + with open(mlir_file_path, "r") as f: + mlir_code = f.read() + + print("=== Initializing Execution Manager ===") + exec_manager = Execution(tmp_exec_file) + + # Example transformation sequence (empty for testing) + seq = [[]] # No transformations applied + + print("\n=== Executing MLIR Code ===") + exec_time, success, cache_miss = exec_manager.execute_code( + mlir_code, + bench_name, + seq + ) + + print(f"Execution time: {exec_time} ns") + print(f"Success: {success}") + print(f"Cache miss: {cache_miss}") + + ''' + print("\n=== Running again to test cache ===") + exec_time_cached, success_cached, cache_miss_cached = exec_manager.execute_code( + mlir_code, + bench_name, + seq + ) + + print(f"Cached execution time: {exec_time_cached} ns") + print(f"Success: {success_cached}") + print(f"Cache miss: {cache_miss_cached} (should be False)") + + print("\n=== Updating cache with dummy data ===") + dummy_data = {bench_name: {exec_manager.get_code_cache_key(seq): exec_time}} + exec_manager.update_execution_cache(dummy_data) + print("Cache updated successfully.") + ''' + +if __name__ == "__main__": + if not os.path.exists("./tests/benchmarks/matmul.mlir"): + print("ERROR: MLIR benchmark file not found. Please place matmul.mlir in ./tests/benchmarks/") + else: + main() diff --git a/tests/test_model.py b/tests/test_model.py new file mode 100644 index 0000000..ac1fa0b --- /dev/null +++ b/tests/test_model.py @@ -0,0 +1,69 @@ +# test_model.py +from dotenv import load_dotenv +load_dotenv(override=True) + +import os +import torch + +from rl_autoschedular.observation import Observation, OpFeatures, ActionHistory, NumLoops, ActionMask +from rl_autoschedular.state import extract_bench_features_from_file, OperationState +from rl_autoschedular.model import PolicyModel # adjust if path is different +from rl_autoschedular.actions import ActionSpace + + +def main(): + bench_name = "matmul" + file_path = "./tests/benchmarks/matmul.mlir" # path to your mlir file + root_exec_time = 1000 + + print("=== Load benchmark and make state ===") + bench_features = extract_bench_features_from_file( + bench_name, file_path, root_exec_time + ) + first_tag = bench_features.operation_tags[0] + first_op = bench_features.operations[first_tag] + + state = OperationState( + bench_name=bench_name, + operation_tag=first_tag, + operation_features=first_op, + validated_code=bench_features.code, + transformed_code=first_op.raw_operation, + step_count=0, + exec_time=bench_features.root_exec_time, + transformation_history=[[]], + tmp_file="tmp.mlir", + terminal=False, + ) + + obs = Observation.from_state(state) + + print("\n=== Build PolicyModel ===") + obs_parts = [OpFeatures, ActionHistory] + policy = PolicyModel(obs_parts) + + # random weights are already in place by default initialization + dists = policy(obs) + + print("\n=== Policy Outputs ===") + print("Number of distributions:", len(dists)) + for i, dist in enumerate(dists): + if dist is None: + print(f" Head {i}: None (no params)") + else: + print(f" Head {i}: dist type={type(dist).__name__}, batch shape={dist.batch_shape}, event shape={dist.event_shape}") + + print("\n=== Sample from Policy ===") + index = ActionSpace.sample(obs, dists, dists) # reuse same dists as eps_dists + + # action by index + action_index = ActionSpace.action_by_index(index[0],state) + print("Sampled index:", index) + print("Sampled action:", action_index) + + + + + +if __name__ == "__main__": + main() diff --git a/tests/test_state.py b/tests/test_state.py new file mode 100644 index 0000000..b6e5cab --- /dev/null +++ b/tests/test_state.py @@ -0,0 +1,101 @@ +# test_state.py +from dotenv import load_dotenv +load_dotenv(override=True) + +import os +import torch +from rl_autoschedular.state import ( + extract_bench_features_from_file, + OperationState, +) +from rl_autoschedular.observation import ( + Observation, + OpFeatures, + ActionHistory, + NumLoops, + ActionMask, +) + + +def main(): + bench_name = "matmul" + file_path = "./tests/benchmarks/matmul.mlir" # path to your mlir file + root_exec_time = 1000 # dummy baseline time (ns) + + print("=== Extracting Benchmark Features ===") + bench_features = extract_bench_features_from_file( + bench_name, file_path, root_exec_time + ) + print("Benchmark name:", bench_features.bench_name) + print("Root exec time:", bench_features.root_exec_time) + print("Operation tags:", bench_features.operation_tags) + print("Number of operations:", len(bench_features.operations)) + + for tag, op in bench_features.operations.items(): + print(f"\n--- Operation: {tag} ---") + print("Type:", op.operation_type) + print("Vectorizable:", op.vectorizable) + print("Op counts:", op.op_count) + print("Load data:", op.load_data) + print("Store data:", op.store_data) + print("Nested loops:") + for loop in op.nested_loops: + print(f" {loop.arg} from {loop.lower_bound} to {loop.upper_bound} " + f"step {loop.step} [{loop.iterator_type}]") + + print("\n=== Testing OperationState ===") + # Just pick the first operation + first_tag = bench_features.operation_tags[0] + first_op = bench_features.operations[first_tag] + + op_state = OperationState( + bench_name=bench_features.bench_name, + operation_tag=first_tag, + operation_features=first_op, + validated_code=bench_features.code, + transformed_code=first_op.raw_operation, + step_count=0, + exec_time=bench_features.root_exec_time, + transformation_history=[[]], + tmp_file="tmp.mlir", + terminal=False + ) + + print("OperationState created:") + print(" Bench:", op_state.bench_name) + print(" Tag:", op_state.operation_tag) + print(" Exec time:", op_state.exec_time) + print(" Terminal:", op_state.terminal) + + print("\n=== Copy test ===") + op_state_copy = op_state.copy() + print("Copied state same tag?", op_state_copy.operation_tag == op_state.operation_tag) + + print("\n=== Observation Tests ===") + obs = Observation.from_state(op_state) + print("Observation shape:", obs.shape) + print("Total observation size (expected):", Observation.cumulative_sizes()[-1]) + + # Extract each part + op_features = Observation.get_part(obs, OpFeatures) + action_hist = Observation.get_part(obs, ActionHistory) + num_loops = Observation.get_part(obs, NumLoops) + action_mask = Observation.get_part(obs, ActionMask) + + print(" OpFeatures shape:", op_features.shape) + print(" ActionHistory shape:", action_hist.shape) + print(" NumLoops:", num_loops.item() if isinstance(num_loops, torch.Tensor) else num_loops) + print(" ActionMask shape:", action_mask.shape) + + # Check consistency + combined = Observation.get_parts(obs, OpFeatures, ActionHistory, NumLoops, ActionMask) + print(" Combined parts shape:", combined.shape) + print(" Matches full obs?", combined.shape[1] == obs.shape[1]) + + +if __name__ == "__main__": + # Ensure AST_DUMPER_BIN_PATH is set + if "AST_DUMPER_BIN_PATH" not in os.environ: + print("ERROR: Please set AST_DUMPER_BIN_PATH to your ast_dumper binary path.") + else: + main() diff --git a/tests/tmp.mlir b/tests/tmp.mlir new file mode 100644 index 0000000..a7e3d9a --- /dev/null +++ b/tests/tmp.mlir @@ -0,0 +1,15 @@ +func.func private @nanoTime() -> i64 attributes { llvm.emit_c_interface } +func.func @main(%arg0: memref<256x128xf64>, %arg1: memref<128x256xf64>, %arg2: memref<256x256xf64>) -> i64 attributes { llvm.emit_c_interface } { + %t0 = func.call @nanoTime() : () -> i64 + linalg.matmul ins(%arg0, %arg1 : memref<256x128xf64>, memref<128x256xf64>) outs(%arg2 : memref<256x256xf64>) + %t1 = func.call @nanoTime() : () -> i64 + %t2 = arith.subi %t1, %t0 : i64 + return %t2 : i64 +} +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %op_operation_0 = transform.structured.match attributes{tag = "operation_0"} in %arg1 : (!transform.any_op) -> !transform.any_op + %op_tiled_operation_0, %forall_operation_0 = transform.structured.tile_using_forall %op_operation_0 tile_sizes [4, 4] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} \ No newline at end of file diff --git a/tests/transform_test.py b/tests/transform_test.py new file mode 100644 index 0000000..2e43b83 --- /dev/null +++ b/tests/transform_test.py @@ -0,0 +1,101 @@ +# tests/test_transforms.py +from dotenv import load_dotenv +load_dotenv(override=True) + +import os +from rl_autoschedular.transforms import ( + transform_dialect_TP, + transform_dialect_tile, + transform_dialect_interchange, + transform_dialect_vectorize, +) + +from rl_autoschedular.state import ( + extract_bench_features_from_file, + OperationState, +) +from rl_autoschedular.observation import ( + Observation, + OpFeatures, + ActionHistory, + NumLoops, + ActionMask, +) + +def main(): + file_path = "./tests/benchmarks/matmul.mlir" + tmp_file = "./tests/tmp.mlir" + + if not os.path.exists(file_path): + raise FileNotFoundError(f"Benchmark file not found: {file_path}") + + with open(file_path, "r") as f: + code = f.read() + + bench_name = "matmul" + root_exec_time = 1000 # dummy baseline time (ns) + '''''''''''''' + bench_features = extract_bench_features_from_file( + bench_name, file_path, root_exec_time + ) + print("Benchmark name:", bench_features.bench_name) + print("Root exec time:", bench_features.root_exec_time) + print("Operation tags:", bench_features.operation_tags) + print("Number of operations:", len(bench_features.operations)) + + for tag, op in bench_features.operations.items(): + print(f"\n--- Operation: {tag} ---") + print("Type:", op.operation_type) + print("Vectorizable:", op.vectorizable) + print("Op counts:", op.op_count) + print("Load data:", op.load_data) + print("Store data:", op.store_data) + print("Nested loops:") + for loop in op.nested_loops: + print(f" {loop.arg} from {loop.lower_bound} to {loop.upper_bound} " + f"step {loop.step} [{loop.iterator_type}]") + + print("\n=== Testing OperationState ===") + # Just pick the first operation + first_tag = bench_features.operation_tags[0] + first_op = bench_features.operations[first_tag] + + + + op_tag = first_tag # must match the tag used in your MLIR benchmark + + print("\n=== Original Code ===") + print(code) + + # 1. Tiling + Parallelization + transformed_tp = transform_dialect_TP( + code, op_tag, tiling_sizes=[4, 4], tmp_file_path=tmp_file + ) + print("\n=== After TP (Tiling + Parallelization) ===") + print(transformed_tp) + + ''' + # 2. Tiling (for-loops) + transformed_tile = transform_dialect_tile( + transformed_tp, op_tag, tiling_size=[2, 2], tmp_file_path=tmp_file + ) + print("\n=== After Tile ===") + print(transformed_tile) + + # 3. Interchange + transformed_interchange = transform_dialect_interchange( + transformed_tile, op_tag, interchange_list=[1, 0], tmp_file_path=tmp_file + ) + print("\n=== After Interchange ===") + print(transformed_interchange) + + # 4. Vectorize + transformed_vectorize = transform_dialect_vectorize( + transformed_interchange, op_tag, tmp_file_path=tmp_file + ) + print("\n=== After Vectorize ===") + print(transformed_vectorize) + ''' + +if __name__ == "__main__": + main() diff --git a/tmp-debug/.gitignore b/tmp-debug/.gitignore new file mode 100755 index 0000000..99e28de --- /dev/null +++ b/tmp-debug/.gitignore @@ -0,0 +1,5 @@ +# Ignore everything in this directory +* +# Except this file +!exec/ +!.gitignore \ No newline at end of file diff --git a/logs/.gitignore b/tmp-debug/exec/.gitignore old mode 100644 new mode 100755 similarity index 95% rename from logs/.gitignore rename to tmp-debug/exec/.gitignore index 86d0cb2..44c5ea8 --- a/logs/.gitignore +++ b/tmp-debug/exec/.gitignore @@ -1,4 +1,4 @@ -# Ignore everything in this directory -* -# Except this file +# Ignore everything in this directory +* +# Except this file !.gitignore \ No newline at end of file diff --git a/train.py b/train.py index c95786d..d1726fb 100644 --- a/train.py +++ b/train.py @@ -1,105 +1,109 @@ -# Load environment variables -import os -from dotenv import load_dotenv - -load_dotenv(override=True) -load_dotenv('.env.debug') - -# Import modules -import torch -from rl_autoschedular.execution import Execution -from rl_autoschedular.model import HiearchyModel as Model -from rl_autoschedular import device -from rl_autoschedular.trajectory import TrajectoryData -from rl_autoschedular.ppo import collect_trajectory, ppo_update, value_update, evaluate_benchmarks -from utils.log import print_info, print_success -from utils.config import Config -from utils.dask_manager import DaskManager -from utils.file_logger import FileLogger -from typing import Optional -from time import time -import datetime - - -# Initialize singleton classes -cfg = Config() -fl = FileLogger() -dm = DaskManager() - -# Load data to workers -train_data = dm.load_train_data() -eval_data = dm.load_eval_data() -main_exec_data = dm.load_main_exec_data() - -# Initialize execution singleton -Execution(fl.exec_data_file, main_exec_data) - -print_info(f"Config: {cfg}") -print_success(f'Logging to: {fl.run_dir}') -if cfg.main_exec_data_file: - print_info(f"Global execution data located in: {cfg.main_exec_data_file}") - -# Setup torch -torch.set_grad_enabled(False) -torch.set_num_threads(4) -if cfg.debug: - torch.autograd.set_detect_anomaly(True) - -# Initiate model -model = Model().to(device) -optimizer = torch.optim.Adam( - model.parameters(), - lr=cfg.lr -) -print_success("Model initialized") - -# Start training -old_trajectory: Optional[TrajectoryData] = None -time_ms = 0 -eta = 0 -for step in range(cfg.nb_iterations): - print_info(f"- Main Loop {step + 1}/{cfg.nb_iterations} ({100 * (step + 1) / cfg.nb_iterations:.2f}%) ({time_ms}ms) < ({eta})") - - main_start = time() - - # Collect trajectory using the model - trajectory = collect_trajectory(train_data, model, step) - - # Extend trajectory with previous trajectory - reuse_start = time() - if cfg.reuse_experience != 'none': - if old_trajectory is not None: - trajectory = old_trajectory + trajectory - old_trajectory = trajectory.copy() - reuse_end = time() - reuse_time_ms = int((reuse_end - reuse_start) * 1000) - print_info(f"Reuse time: {reuse_time_ms}ms") - - # Fit value model to trajectory rewards - if cfg.value_epochs > 0: - value_update(trajectory, model, optimizer) - - # Update policy model with PPO - ppo_update(trajectory, model, optimizer) - - # Save the model - if (step + 1) % 5 == 0: - torch.save( - model.state_dict(), - os.path.join( - fl.models_dir, - f'model_{step}.pt' - ) - ) - - if (step + 1) % 100 == 0: - print_info('- Evaluating benchmarks -') - evaluate_benchmarks(model, eval_data) - - main_end = time() - time_ms = int((main_end - main_start) * 1000) - eta = datetime.timedelta(seconds=time_ms * (cfg.nb_iterations - step - 1) / 1000) - -if (step + 1) % 100 != 0: - print_info('- Evaluating benchmarks -') - evaluate_benchmarks(model, eval_data) +# Load environment variables +from dotenv import load_dotenv +load_dotenv(override=True) + + +import torch +import os +from typing import Optional +from utils.log import print_info, print_success + +# Import environment +from rl_autoschedular.env import Env + +# config, file_logger, device +from rl_autoschedular import config as cfg, file_logger as fl, device + +# Import RL components +from rl_autoschedular.model import HiearchyModel as Model +from rl_autoschedular.trajectory import TrajectoryData +from rl_autoschedular.ppo import ( + collect_trajectory, + ppo_update, + value_update, + evaluate_benchmarks +) + +import time +torch.set_grad_enabled(False) +torch.set_num_threads(int(os.getenv("OMP_NUM_THREADS", "4"))) + +if cfg.debug: + torch.autograd.set_detect_anomaly(True) + +print_info(f"Config: {cfg}") +print_success(f'Logging to: {fl.run_dir}') + +# Set environments + +# run_name for /tmp/ path +env = Env(is_training=True,run_name="online_ppo_data_collection") +eval_env = Env(is_training=False,run_name="online_ppo_data_collection") +print_success(f"Environments initialized: {env.tmp_file}") + +# Set model +model = Model().to(device) +optimizer = torch.optim.Adam( + model.parameters(), + lr=cfg.lr +) +print_success("Model initialized") + +train_start = time.perf_counter() +total_env_time = 0.0 +total_eval_time = 0.0 + +# Start training +for step in range(cfg.nb_iterations): + print_info(f"- Main Loop {step + 1}/{cfg.nb_iterations} ({100 * (step + 1) / cfg.nb_iterations:.2f}%)") + trajectory , env_time = collect_trajectory( + model, + env, + step, + ) + total_env_time += env_time + + + + # Fit value model to trajectory rewards + if cfg.value_epochs > 0: + value_update( + trajectory, + model, + optimizer, + step + ) + + ppo_update( + trajectory, + model, + optimizer, + step + ) + + if (step + 1) % 50 == 0: + torch.save( + model.state_dict(), + os.path.join( + env.tmp_file.replace('.mlir', ''), + f'model_{step}.pth' + ) + ) + + if (step + 1) % 50 == 0: + start_eval = time.perf_counter() + print_info('- Evaluating benchmark -') + eval_time = evaluate_benchmarks( + model, + eval_env, + step + ) + end_eval = time.perf_counter() + total_eval_time += end_eval - start_eval +train_end = time.perf_counter() +total_train_time = train_end - train_start - total_eval_time +print_success(f"- Training completed in {total_train_time:.2f} seconds") +print_success(f"- Evaluation completed in {total_eval_time:.2f} seconds") +print_success(f"- Total environment time: {total_env_time:.2f} seconds") +print_success(f"- Total eval env time: {total_eval_time:.2f} seconds") +print_success(f"- Percentage of time in environment: {100 * total_env_time / total_train_time:.2f}%") \ No newline at end of file diff --git a/train_iql.py b/train_iql.py new file mode 100644 index 0000000..825a742 --- /dev/null +++ b/train_iql.py @@ -0,0 +1,209 @@ +import dotenv + +from utils.log import print_info + +dotenv.load_dotenv() + +import os +import time +import torch +import numpy as np + +from rl_autoschedular import config as cfg, file_logger as fl +from rl_autoschedular.actions import ActionSpace +from rl_autoschedular.env import Env +from iql.iql_agent import IQLAgent +from utils.data_collector import OfflineDataset +from rl_autoschedular.observation import Observation,OpFeatures, ActionHistory + +from tqdm import trange + +device = torch.device("cpu") + + + +def load_dataset(): + """Load offline dataset from OfflineDataset singleton.""" + dataset = OfflineDataset( + save_dir=cfg.offline_data_save_dir, + fname=cfg.offline_data_file + ).load() + + if not dataset: + raise FileNotFoundError(f"Offline dataset not found: {cfg.offline_data_file}") + + states = torch.tensor(dataset["obs"], dtype=torch.float32) + actions = torch.tensor(dataset["actions"], dtype=torch.long) + rewards = torch.tensor(dataset["rewards"], dtype=torch.float32) + next_states = torch.tensor(dataset["next_obs"], dtype=torch.float32) + dones = torch.tensor(dataset["dones"], dtype=torch.float32) + + return states, actions, rewards, next_states, dones + + +@torch.no_grad() +def evaluate_benchmarks(model: IQLAgent, env: Env, step: int): + """Evaluta a the model on the evaluation environment. + + Args: + model (Model): The policy/value model. + env (Env): The environment. + step (int): Current training step. + + Returns: + env_time (float): Time spent in environment steps. + """ + + + env_time = 0.0 # Time spent in environment steps + + eps = None + + + # store rewards and entropies to log average for the model accross the benchmarks later + all_speedups = [] + all_entropies = [] + + + for _ in trange(cfg.bench_count, desc='Trajectory'): + + t0 = time.perf_counter() + state = env.reset() + env_time += time.perf_counter() - t0 + bench_done = False + speedup = None + + # store rewards and entropies to log average for the current benchmark later + bench_rewards, bench_entropies = [], [] + + bench_name = state.bench_name + + + while not bench_done: + obs = Observation.from_state(state) + + # Sample action and log-prob from *current policy* + action_index, action_log_p, entropy = model.sample(obs.to(device), greedy=True) + assert action_index.size(0) == 1 and action_log_p.size(0) == 1 + action = ActionSpace.action_by_index(action_index[0], state) + + # Step environment + t0 = time.perf_counter() + next_state, reward, op_done, speedup = env.step(state, action) + env_time += time.perf_counter() - t0 + next_obs = Observation.from_state(next_state) + + + if op_done: + t0 = time.perf_counter() + next_state, bench_done = env.get_next_op_state(next_state) + env_time += time.perf_counter() - t0 + + + # Accumulate metrics + bench_rewards.append(reward) + bench_entropies.append(entropy.item()) + state = next_state + + # === Per-benchmark logging === + mean_reward = float(np.mean(bench_rewards)) if bench_rewards else 0.0 + mean_entropy = float(np.mean(bench_entropies)) if bench_entropies else 0.0 + + all_speedups.append(speedup) + all_entropies.extend(bench_entropies) + + + bench_metrics = { + "mean_reward": mean_reward, + "mean_entropy": mean_entropy, + "final_speedup": speedup if speedup is not None else 0.0, + } + + fl.log_scalars(f"eval/{bench_name}", bench_metrics, step) + + print( + f"\033[92m\n- Eval Bench: {bench_name}\n" + f"- Mean Reward: {mean_reward:.4f}\n" + f"- Mean Entropy: {mean_entropy:.4f}\n" + f"- Final Speedup: {speedup if speedup is not None else 0.0:.4f}\033[0m" + ) + + + # === Global logging (across all benchmarks) === + if all_speedups: + fl.log_scalar("eval/average_speedup", float(np.mean(all_speedups)), step) + if all_entropies: + fl.log_scalar("eval/average_entropy", float(np.mean(all_entropies)), step) + + return env_time + +def train_iql(): + # Load offline dataset + print(f"Loading offline dataset from {cfg.offline_data_file} ...") + states, actions, rewards, next_states, dones = load_dataset() + dataset_size = states.shape[0] + print(f"Dataset loaded: {dataset_size} transitions") + + # Initialize IQL agent + agent = IQLAgent(cfg,device,obs_parts=[OpFeatures, ActionHistory]) + + eval_env = Env(is_training=False,run_name=cfg.run_name) + + + print("Starting IQL training ...") + start_time = time.time() + + step = 0 + iql_trange = trange(cfg.max_steps, desc="IQL Training",dynamic_ncols=True) + for step in iql_trange: + # Sample a random batch + idxs = np.random.randint(0, dataset_size, size=cfg.batch_size) + batch = ( + states[idxs].to(device), + actions[idxs].to(device), + rewards[idxs].to(device), + next_states[idxs].to(device), + dones[idxs].to(device), + ) + + losses = agent.update(batch) + + + # Only log occasionally to reduce disk I/O + if step % 50 == 0: + fl.log_scalars("train", losses, step) + + if (step +1) % 100 == 0: + elapsed = time.time() - start_time + iql_trange.set_postfix({ + "Value Loss": f"{losses['value']:.4f}", + "Q Loss": f"{losses['q']:.4f}", + "Policy Loss": f"{losses['policy']:.4f}", + "Elapsed": f"{elapsed:.2f}s" + }) + + # Evaluate the agent on benchmarks every 1000 steps + if (step + 1) % 1000 == 0: + print("Evaluating on benchmarks ...") + eval_start = time.time() + env_time = evaluate_benchmarks(agent, eval_env, step) + eval_time = time.time() - eval_start + print(f"Evaluation completed in {eval_time:.2f} seconds (env time: {env_time:.2f} seconds)") + fl.flush() + + + + if (step+1) % 2000 == 0 and step > 0: + ckpt_path = os.path.join(cfg.results_dir, f"iql_step_{step}.pt") + os.makedirs(cfg.results_dir, exist_ok=True) + torch.save(agent.state_dict(), ckpt_path) + print(f"Checkpoint saved: {ckpt_path}") + + + + total_time = time.time() - start_time + print(f"Training finished in {total_time:.2f} seconds.") + + +if __name__ == "__main__": + train_iql() diff --git a/utils/config.py b/utils/config.py old mode 100644 new mode 100755 index 539a3e8..da14f7d --- a/utils/config.py +++ b/utils/config.py @@ -1,103 +1,284 @@ -from typing import Literal, Any -from utils.singleton import Singleton -from typeguard import check_type, CollectionCheckStrategy -import json -import os - - -class Config(metaclass=Singleton): - """Class to store and load global configuration""" - - max_num_stores_loads: int - """The maximum number of loads in the nested loops""" - max_num_loops: int - """The max number of nested loops""" - max_num_load_store_dim: int - """The max number of dimensions in load/store buffers""" - num_tile_sizes: int - """The number of tile sizes""" - vect_size_limit: int - """Vectorization size limit to prevent large sizes vectorization""" - order: list[list[str]] - """The order of actions that needs to bo followed""" - interchange_mode: Literal['enumerate', 'pointers', 'continuous'] - """The method used for interchange action""" - exploration: list[Literal['entropy', 'epsilon']] - """The exploration method""" - init_epsilon: float - """The initial epsilon value for epsilon greedy exploration""" - new_architecture: bool - """Flag to indicate if the new architecture should be used or not""" - normalize_bounds: Literal['none', 'max', 'log'] - """Flag to indicate if the upper bounds in the input should be normalized or not""" - normalize_adv: Literal['none', 'standard', 'max-abs'] - """The advantage normalization method""" - sparse_reward: bool - """Flag to enable sparse reward""" - split_ops: bool - """Flag to enable splitting operations into separate benchmarks""" - reuse_experience: Literal['none', 'random', 'topk'] - """Strategy for experience replay""" - activation: Literal["relu", "tanh"] - """The activation function to use in the network""" - benchmarks_folder_path: str - """Path to the benchmarks folder. Can be empty if optimization mode is set to "last".""" - bench_count: int - """Number of batches in a trajectory""" - replay_count: int - """Number of trajectories to keep in the replay buffer""" - nb_iterations: int - """Number of iterations""" - ppo_epochs: int - """Number of epochs for PPO""" - ppo_batch_size: int - """Batch size for PPO""" - value_epochs: int - """Number of epochs for value update""" - value_batch_size: int - """Batch size for value update""" - value_coef: float - """Value coefficient""" - value_clip: bool - """Clip value loss or not""" - entropy_coef: float - """Entropy coefficient""" - lr: float - """Learning rate""" - truncate: int - """Maximum number of steps in the schedule""" - json_file: str - """Path to the JSON file containing the benchmarks execution times.""" - eval_json_file: str - """Path to the JSON file containing the benchmarks execution times for evaluation.""" - tags: list[str] - """List of tags to add to the neptune experiment""" - debug: bool - """Flag to enable debug mode""" - main_exec_data_file: str - """Path to the file containing the execution data""" - results_dir: str - """Path to the results directory""" - - def __init__(self): - """Load the configuration from the JSON file - or get existing instance if any. - """ - # Open the JSON file - with open(os.getenv("CONFIG_FILE_PATH"), "r") as f: - config_data: dict[str, Any] = json.load(f) - - for element, element_t in self.__annotations__.items(): - if element not in config_data: - raise ValueError(f"{element} is missing from the config file") - - element_v = check_type(config_data[element], element_t, collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS) - setattr(self, element, element_v) - - def to_dict(self): - """Convert the configuration to a dictionary.""" - return {k: self.__dict__[k] for k in self.__annotations__} - - def __str__(self): - """Convert the configuration to a string.""" - return str(self.to_dict()) +import os +from utils.singleton import Singleton +import json +from typing import Literal +from dotenv import load_dotenv + +load_dotenv(override=True) + + + +class Config(metaclass=Singleton): + """Class to store and load global configuration""" + max_num_stores_loads: int + """The maximum number of loads in the nested loops""" + max_num_loops: int + """The max number of nested loops""" + max_num_load_store_dim: int + """The max number of dimensions in load/store buffers""" + num_tile_sizes: int + """The number of tile sizes""" + vect_size_limit: int + """Vectorization size limit to prevent large sizes vectorization""" + order: list[list[str]] + """The order of actions that needs to bo followed""" + interchange_mode: Literal['enumerate', 'pointers', 'continuous'] + """The method used for interchange action""" + exploration: list[Literal['entropy', 'epsilon']] + """The exploration method""" + init_epsilon: float + """The initial epsilon value for epsilon greedy exploration""" + new_architecture: bool + """Flag to indicate if the new architecture should be used or not""" + normalize_bounds: Literal['none', 'max', 'log'] + """Flag to indicate if the upper bounds in the input should be normalized or not""" + normalize_adv: Literal['none', 'standard', 'max-abs'] + """The advantage normalization method""" + sparse_reward: bool + """Flag to enable sparse reward""" + split_ops: bool + """Flag to enable splitting operations into separate benchmarks""" + reuse_experience: bool + """Flag to enable reusing experience""" + activation: Literal["relu", "tanh"] + """The activation function to use in the network""" + benchmarks_folder_path: str + """Path to the benchmarks folder. Can be empty if optimization mode is set to "last".""" + bench_count: int + """Number of batches in a trajectory""" + nb_iterations: int + """Number of iterations""" + ppo_epochs: int + """Number of epochs for PPO""" + ppo_batch_size: int + """Batch size for PPO""" + value_epochs: int + """Number of epochs for value update""" + value_batch_size: int + """Batch size for value update""" + value_coef: float + """Value coefficient""" + value_clip: bool + """Clip value loss or not""" + entropy_coef: float + """Entropy coefficient""" + lr: float + """Learning rate""" + truncate: int + """Maximum number of steps in the schedule""" + json_file: str + """Path to the JSON file containing the benchmarks execution times.""" + eval_json_file: str + """Path to the JSON file containing the benchmarks execution times for evaluation.""" + tags: list[str] + """List of tags to add to the neptune experiment""" + debug: bool + """Flag to enable debug mode""" + exec_data_file: str + """Path to the file containing the execution data""" + results_dir: str + """Path to the results directory""" + run_name: str + """Name of the current run/experiment""" + collect_offline_data : bool + """Flag to indicate if offline data should be collected""" + offline_data_save_dir:str + """Directory to save offline data if collection is enabled""" + offline_data_file:str + """Filename for the offline dataset""" + + loaded: bool + """Flag to check if the config was already loaded from JSON file or not""" + + ''' New parameters for IQL ''' + gamma: float + """Discount factor for future rewards""" + tau: float + """Expectile parameter for value function update""" + beta: float + """Temperature parameter for policy update""" + alpha: float + """Polyak averaging factor for target network updates""" + batch_size: int + """Batch size for training updates""" + max_steps: int + """Maximum number of training steps for IQL""" + target_update_freq: int + """Frequency of target network updates""" + lr : dict + """Learning rates for different components: value, q, policy""" + + + + def __init__(self): + """Initialize the default values""" + self.max_num_stores_loads = 7 + self.max_num_loops = 7 + self.max_num_load_store_dim = 7 + self.num_tile_sizes = 7 + self.vect_size_limit = 512 + self.order = [] + self.interchange_mode = "enumerate" + self.exploration = ["entropy"] + self.init_epsilon = 0.1 + self.new_architecture = False + self.normalize_bounds = 'max' + self.normalize_adv = 'standard' + self.sparse_reward = True + self.split_ops = False + self.reuse_experience = False + self.activation = "relu" + self.benchmarks_folder_path = "" + self.bench_count = 20 + self.nb_iterations = 10000 + self.ppo_epochs = 4 + self.ppo_batch_size = 4 + self.value_epochs = 32 + self.value_batch_size = 32 + self.value_coef = 0.5 + self.value_clip = False + self.entropy_coef = 0.01 + self.lr = 0.001 + self.truncate = 5 + self.json_file = "" + self.eval_json_file = "" + self.tags = [] + self.debug = False + self.exec_data_file = "" + self.results_dir = "results" + self.run_name = "default_run" + self.loaded = False + self.collect_offline_data = False + self.offline_data_save_dir = "offline_data" + self.offline_data_file = "offline_dataset.npz" + + self.gamma = 0.99 + self.tau = 0.7 + self.beta = 3.0 + self.alpha = 0.005 + self.batch_size = 256 + self.lr = { + "value": 3e-4, + "q": 3e-4, + "policy": 3e-4 + } + self.max_steps = 100000 + self.target_update_freq = 1 + + + def load_from_json(self): + """Load the configuration from the JSON file.""" + # Open the JSON file + with open(os.getenv("CONFIG_FILE_PATH"), "r") as f: + config = json.load(f) + # Set the configuration values + self.max_num_stores_loads = config["max_num_stores_loads"] + self.max_num_loops = config["max_num_loops"] + self.max_num_load_store_dim = config["max_num_load_store_dim"] + self.num_tile_sizes = config["num_tile_sizes"] + self.vect_size_limit = config["vect_size_limit"] + self.order = config["order"] + self.interchange_mode = config["interchange_mode"] + self.exploration = config["exploration"] + self.init_epsilon = config["init_epsilon"] + self.new_architecture = config["new_architecture"] + self.normalize_bounds = config["normalize_bounds"] + self.normalize_adv = config["normalize_adv"] + self.sparse_reward = config["sparse_reward"] + self.split_ops = config["split_ops"] + self.reuse_experience = config["reuse_experience"] + self.activation = config["activation"] + self.benchmarks_folder_path = config["benchmarks_folder_path"] + self.bench_count = config["bench_count"] + self.nb_iterations = config["nb_iterations"] + self.ppo_epochs = config["ppo_epochs"] + self.ppo_batch_size = config["ppo_batch_size"] + self.value_epochs = config["value_epochs"] + self.value_batch_size = config["value_batch_size"] + self.value_coef = config["value_coef"] + self.value_clip = config["value_clip"] + self.entropy_coef = config["entropy_coef"] + self.lr = config["lr"] + self.truncate = config["truncate"] + self.json_file = config["json_file"] + self.eval_json_file = config["eval_json_file"] + self.tags = config["tags"] + self.debug = config["debug"] + self.main_exec_data_file = config["main_exec_data_file"] + self.results_dir = config["results_dir"] + self.run_name = config.get("run_name", "default_run") + # Set loaded flag + self.loaded = True + self.collect_offline_data = config.get("collect_offline_data", False) + self.offline_data_save_dir = config.get("offline_data_save_dir", "offline_data") + self.offline_data_file = config.get("offline_data_file", "offline_dataset.npz") + + ''' New parameters for IQL ''' + self.gamma = config.get("gamma", 0.99) + self.tau = config.get("tau", 0.7) + self.beta = config.get("inverse_temperature", 3.0) + self.alpha = config.get("alpha", 0.005) + self.batch_size = config.get("batch_size", 256) + self.lr = config.get("learning_rate", { + "value": 3e-4, + "q": 3e-4, + "policy": 3e-4 + }) + self.max_steps = config.get("max_steps", 100000) + self.target_update_freq = config.get("target_update_freq", 1) + + + def to_dict(self): + """Convert the configuration to a dictionary.""" + return { + "max_num_stores_loads": self.max_num_stores_loads, + "max_num_loops": self.max_num_loops, + "max_num_load_store_dim": self.max_num_load_store_dim, + "num_tile_sizes": self.num_tile_sizes, + "vect_size_limit": self.vect_size_limit, + "order": self.order, + "interchange_mode": self.interchange_mode, + "exploration": self.exploration, + "init_epsilon": self.init_epsilon, + "new_architecture": self.new_architecture, + "normalize_bounds": self.normalize_bounds, + "normalize_adv": self.normalize_adv, + "sparse_reward": self.sparse_reward, + "split_ops": self.split_ops, + "reuse_experience": self.reuse_experience, + "activation": self.activation, + "benchmarks_folder_path": self.benchmarks_folder_path, + "bench_count": self.bench_count, + "nb_iterations": self.nb_iterations, + "ppo_epochs": self.ppo_epochs, + "ppo_batch_size": self.ppo_batch_size, + "value_epochs": self.value_epochs, + "value_batch_size": self.value_batch_size, + "value_coef": self.value_coef, + "value_clip": self.value_clip, + "entropy_coef": self.entropy_coef, + "lr": self.lr, + "truncate": self.truncate, + "json_file": self.json_file, + "eval_json_file": self.eval_json_file, + "tags": self.tags, + "debug": self.debug, + "exec_data_file": self.exec_data_file, + "results_dir": self.results_dir, + "run_name": self.run_name, + "collect_offline_data": self.collect_offline_data, + "offline_data_save_dir": self.offline_data_save_dir, + "offline_data_file": self.offline_data_file, + "gamma": self.gamma, + "tau": self.tau, + "beta": self.beta, + "alpha": self.alpha, + "batch_size": self.batch_size, + "max_steps": self.max_steps, + "target_update_freq": self.target_update_freq, + "learning_rate": self.lr + } + + def __str__(self): + """Convert the configuration to a string.""" + return str(self.to_dict()) diff --git a/utils/dask_manager.py b/utils/dask_manager.py deleted file mode 100644 index e3b8ddd..0000000 --- a/utils/dask_manager.py +++ /dev/null @@ -1,158 +0,0 @@ -from typing import TYPE_CHECKING, Callable, Optional, TypeVar - -from typeguard import check_type -from distributed import Future, as_completed, get_worker, progress -from rl_autoschedular.benchmarks import Benchmarks -from dask.distributed import Client -from dask_jobqueue import SLURMCluster - -from utils.file_logger import FileLogger -from .singleton import Singleton -from .log import print_info -from .config import Config -import json -import os - -if TYPE_CHECKING: - from rl_autoschedular.state import OperationState - -T = TypeVar('T') - - -class DaskManager(metaclass=Singleton): - def __init__(self): - enable_dashboard = Config().debug - cluster = SLURMCluster( - job_name='dask', - queue='compute', - cores=28, - processes=1, - nanny=False, - memory='100GB', - walltime='7-00', - job_extra_directives=[ - '--reservation=c2', - '--nodes=1', - '--exclusive', - ], - worker_extra_args=['--resources', 'single_task_slot=1'], - log_directory='dask-logs', - job_script_prologue=[ - 'module load miniconda-nobashrc', - 'eval "$(conda shell.bash hook)"', - f'conda activate {os.getenv("CONDA_ENV")}', - 'export OMP_NUM_THREADS=12', - ], - scheduler_options={ - 'dashboard': enable_dashboard - } - ) - - num_nodes_to_use = int(os.environ["DASK_NODES"]) - print_info(f"Requesting {num_nodes_to_use} nodes for Dask workers...") - cluster.scale(jobs=num_nodes_to_use) - - client = Client(cluster) - print_info("Dask client connected!", f" Dashboard at: {client.dashboard_link}" if enable_dashboard else "") - - self.cluster = cluster - self.client = client - self.workers_names: list[str] = list(cluster.workers.keys()) - self.num_workers = len(cluster.workers) - - def map_states( - self, - func: Callable[['OperationState', str, 'Benchmarks', Optional[dict[str, dict[str, int]]]], T], - states: list['OperationState'], - training: bool, - ) -> list[T]: - # Provide worker data to the function via a wrapper - def func_wrapper(s: 'OperationState', e: str, idx: int): - worker = get_worker() - benchs = check_type(worker.data[f'__load_train_data_{worker.name}' if training else f'__load_eval_data_{worker.name}'], Benchmarks) - main_exec_data = check_type(worker.data[f'__load_main_exec_data_{worker.name}'], Optional[dict[str, dict[str, int]]]) - return idx, func(s, e, benchs, main_exec_data) - func_wrapper.__name__ = func.__name__ + '_wrapper' - - # Prepare states for submission - states_count = len(states) - ordered_states = list(zip(range(states_count), states)) - results: list[T] = [None] * states_count - future_to_worker: dict[Future, str] = {} - - # Submit first states to each worker - initial_states_count = min(states_count, self.num_workers) - for i in range(initial_states_count): - worker_name = self.workers_names[i] - idx, state = ordered_states.pop(0) - future = self.client.submit( - func_wrapper, - state, FileLogger().exec_data_file, idx, - workers=worker_name, - resources={'single_task_slot': 1} - ) - future_to_worker[future] = worker_name - - # Process futures as they finish - ac = as_completed(future_to_worker.keys(), with_results=True) - for future, indexed_result in ac: - future: Future - indexed_result: tuple[int, T] - - idx, result = indexed_result - results[idx] = result - freed_worker = future_to_worker.pop(future) - - # If there are still remaining states submit them - if ordered_states: - new_idx, new_state = ordered_states.pop(0) - new_future = self.client.submit( - func_wrapper, - new_state, FileLogger().exec_data_file, new_idx, - workers=freed_worker, - resources={'single_task_slot': 1} - ) - future_to_worker[new_future] = freed_worker - - # Include the new future in the queue - ac.add(new_future) - - return results - - def load_train_data(self): - def __load_train_data(): - return Benchmarks() - self.remote_train_data: list[Future] = [] - for worker_name in self.workers_names: - self.remote_train_data.append(self.client.submit(__load_train_data, workers=worker_name, key=f'__load_train_data_{worker_name}')) - print_info("Loading train benchmarks to workers...") - progress(self.remote_train_data) - return __load_train_data() - - def load_eval_data(self): - def __load_eval_data(): - return Benchmarks(is_training=False) - self.remote_eval_data: list[Future] = [] - for worker_name in self.workers_names: - self.remote_eval_data.append(self.client.submit(__load_eval_data, workers=worker_name, key=f'__load_eval_data_{worker_name}')) - print_info("Loading eval benchmarks to workers...") - progress(self.remote_eval_data) - return __load_eval_data() - - def load_main_exec_data(self): - def __load_main_exec_data(): - main_exec_data: Optional[dict[str, dict[str, int]]] = None - if Config().main_exec_data_file: - with open(Config().main_exec_data_file) as f: - main_exec_data = json.load(f) - return main_exec_data - self.remote_main_exec_data: list[Future] = [] - for worker_name in self.workers_names: - self.remote_main_exec_data.append(self.client.submit(__load_main_exec_data, workers=worker_name, key=f'__load_main_exec_data_{worker_name}')) - print_info("Loading main exec data to workers...") - progress(self.remote_main_exec_data) - return __load_main_exec_data() - - def close(self): - self.client.close() - self.cluster.close() diff --git a/utils/data_collector.py b/utils/data_collector.py new file mode 100644 index 0000000..da1caf7 --- /dev/null +++ b/utils/data_collector.py @@ -0,0 +1,74 @@ +import os +import numpy as np +import torch +from utils.singleton import Singleton +from utils.log import print_success + +class OfflineDataset(metaclass=Singleton): + """Singleton class to collect and store trajectories for offline RL """ + + def __init__(self, save_dir: str = "offline_data", fname: str = "dataset.npz"): + """ + Args: + save_dir (str): Directory to store dataset. + fname (str): Dataset filename. + """ + self.save_dir = save_dir + self.fname = fname + os.makedirs(self.save_dir, exist_ok=True) + + self.buffer = [] # in-memory buffer for efficiency + self.file_path = os.path.join(self.save_dir, self.fname) + + def add_transition(self, obs, action, next_obs, reward, done): + """Add one transition to buffer.""" + self.buffer.append({ + "obs": obs.squeeze(0).cpu().numpy() if torch.is_tensor(obs) else obs, + "action": action.detach().cpu().numpy().squeeze(0) if torch.is_tensor(action) else np.array(action).squeeze(0), + "next_obs": next_obs.squeeze(0).cpu().numpy() if torch.is_tensor(next_obs) else next_obs, + "reward": float(reward), + "done": bool(done), + }) + + def add_trajectory(self, trajectory): + """Add a full trajectory (list of transitions).""" + self.buffer.extend(trajectory) + + def flush(self): + """Save buffer to disk as npz and clear it.""" + if not self.buffer: + return + + # Convert buffer to arrays + obs = np.array([t["obs"] for t in self.buffer], dtype=np.float32) + actions = np.array([t["action"] for t in self.buffer], dtype=np.int64) + next_obs = np.array([t["next_obs"] for t in self.buffer], dtype=np.float32) + rewards = np.array([t["reward"] for t in self.buffer], dtype=np.float32) + dones = np.array([t["done"] for t in self.buffer], dtype=np.bool_) + + if os.path.exists(self.file_path): + # If file exists, append to it + old = np.load(self.file_path) + obs = np.concatenate([old["obs"], obs], axis=0) + actions = np.concatenate([old["actions"], actions], axis=0) + next_obs = np.concatenate([old["next_obs"], next_obs], axis=0) + rewards = np.concatenate([old["rewards"], rewards], axis=0) + dones = np.concatenate([old["dones"], dones], axis=0) + + np.savez_compressed( + self.file_path, + obs=obs, + actions=actions, + next_obs=next_obs, + rewards=rewards, + dones=dones, + ) + + print_success(f"[OfflineDataset] Flushed {len(self.buffer)} transitions -> {self.file_path}") + self.buffer.clear() + + def load(self, mmap_mode=None): + """Load dataset from disk as dict of numpy arrays (optionally memory-mapped).""" + if not os.path.exists(self.file_path): + return {} + return np.load(self.file_path, mmap_mode=mmap_mode) diff --git a/utils/file_logger.py b/utils/file_logger.py old mode 100644 new mode 100755 index 8cf670b..0e15e01 --- a/utils/file_logger.py +++ b/utils/file_logger.py @@ -1,63 +1,42 @@ -from utils.singleton import Singleton -from utils.config import Config -import json -import os - - -class FileLogger(metaclass=Singleton): - """Class to log results to files""" - def __init__(self): - cfg = Config() - tags = ['ppo'] + cfg.tags - - # Create run dir - dir_path = cfg.results_dir - subdir_ids = sorted([int(d.split('_')[-1]) for d in os.listdir(dir_path) if d.startswith('run_')]) - run_id = subdir_ids[-1] + 1 if subdir_ids else 0 - self.run_dir = os.path.join(dir_path, f'run_{run_id}') - os.makedirs(self.run_dir, exist_ok=True) - - # Create tags file - tags_file = os.path.join(self.run_dir, 'tags') - with open(tags_file, 'w') as f: - f.write('\n'.join(tags)) - f.write('\n') - - # Create exec data file - self.exec_data_file = os.path.join(self.run_dir, 'exec_data.json') - with open(self.exec_data_file, "w") as f: - json.dump({}, f) - - # Create logs dir - self.logs_dir = os.path.join(self.run_dir, 'logs') - os.makedirs(self.logs_dir, exist_ok=True) - - # Create models dir - self.models_dir = os.path.join(self.run_dir, 'models') - os.makedirs(self.models_dir, exist_ok=True) - - # Init files dict - self.files_dict: dict[str, FileInstance] = {} - - def __getitem__(self, path: str): - if path not in self.files_dict: - full_path = os.path.join(self.logs_dir, path) - os.makedirs(os.path.dirname(full_path), exist_ok=True) - assert not os.path.exists(full_path), f"File {path} already exists" - self.files_dict[path] = FileInstance(full_path) - return self.files_dict[path] - - -class FileInstance: - def __init__(self, path: str): - self.path = path - - def append(self, data): - with open(self.path, 'a') as f: - f.write(str(data)) - f.write('\n') - - def extend(self, data: list): - with open(self.path, 'a') as f: - f.write('\n'.join(map(str, data))) - f.write('\n') +from utils.singleton import Singleton +import os +from torch.utils.tensorboard import SummaryWriter + + +class TensorBoardLogger(metaclass=Singleton): + """Logger using TensorBoard for training metrics and results.""" + + def __init__(self, log_dir: str, run_name: str, tags: list[str] = None): + """ + Args: + log_dir (str): Base directory for logs (e.g. "logs"). + run_name (str): Custom run name (instead of auto run_0). + tags (list[str], optional): Tags or metadata for this run. + """ + self.run_dir = os.path.join(log_dir, run_name) + os.makedirs(self.run_dir, exist_ok=True) + + # Initialize TensorBoard writer + self.writer = SummaryWriter(log_dir=self.run_dir) + + # Save tags to a file for reproducibility + if tags: + with open(os.path.join(self.run_dir, "tags.txt"), "w") as f: + f.write("\n".join(tags) + "\n") + + def log_scalar(self, name: str, value: float, step: int): + """Log a scalar value to TensorBoard.""" + self.writer.add_scalar(name, value, step) + + def log_scalars(self, main_tag: str, tag_scalar_dict: dict, step: int): + """Log multiple scalars under a main tag (TensorBoard grouping).""" + self.writer.add_scalars(main_tag, tag_scalar_dict, step) + + def flush(self): + """Flush events to disk (useful if crashing).""" + self.writer.flush() + + def close(self): + """Close the TensorBoard writer.""" + self.writer.close() + diff --git a/utils/log.py b/utils/log.py old mode 100644 new mode 100755 index a024410..bdd680e --- a/utils/log.py +++ b/utils/log.py @@ -1,37 +1,32 @@ -import random -import string -import sys -from dask.distributed import print - - -def generate_random_string(): - """Generate a random string of length 10""" - return ''.join(random.choices(string.ascii_letters + string.digits, k=10)) - - -def print_info(*args, add_label: bool = True, **kwargs): - """Prints an information message""" - message = ' '.join(map(str, args)) - label = '[INFO]\t ' if add_label else '' - print(f"\033[94m{label}{message}\033[0m", **kwargs) - - -def print_success(*args, add_label: bool = True, **kwargs): - """Prints a success message""" - message = ' '.join(map(str, args)) - label = '[SUCCESS]\t ' if add_label else '' - print(f"\033[92m{label}{message}\033[0m", **kwargs) - - -def print_alert(*args, add_label: bool = True, **kwargs): - """Prints an alert message""" - message = ' '.join(map(str, args)) - label = '[ALERT]\t ' if add_label else '' - print(f"\033[93m{label}{message}\033[0m", file=sys.stderr, **kwargs) - - -def print_error(*args, add_label: bool = True, **kwargs): - """Prints an error message""" - message = ' '.join(map(str, args)) - label = '[ERROR]\t ' if add_label else '' - print(f"\033[91m{label}{message}\033[0m", file=sys.stderr, **kwargs) +import random +import string +import sys + + +def generate_random_string(): + """Generate a random string of length 10""" + return ''.join(random.choices(string.ascii_letters + string.digits, k=10)) + + +def print_info(*args): + """Prints an information message""" + message = ' '.join(map(str, args)) + print(f"\033[94m[INFO]\t {message}\033[0m") + + +def print_success(*args): + """Prints a success message""" + message = ' '.join(map(str, args)) + print(f"\033[92m[SUCCESS]\t {message}\033[0m") + + +def print_alert(*args): + """Prints an alert message""" + message = ' '.join(map(str, args)) + print(f"\033[93m[ALERT]\t {message}\033[0m", file=sys.stderr) + + +def print_error(*args): + """Prints an error message""" + message = ' '.join(map(str, args)) + print(f"\033[91m[ERROR]\t {message}\033[0m", file=sys.stderr) diff --git a/utils/singleton.py b/utils/singleton.py old mode 100644 new mode 100755 index 9172757..2522aa2 --- a/utils/singleton.py +++ b/utils/singleton.py @@ -1,8 +1,8 @@ -class Singleton(type): - """Meta class to create a singleton instance of a class""" - _instances = {} - - def __call__(cls, *args, **kwargs): - if cls not in cls._instances: - cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) - return cls._instances[cls] +class Singleton(type): + """Meta class to create a singleton instance of a class""" + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] diff --git a/viz.py b/viz.py new file mode 100644 index 0000000..cea31a8 --- /dev/null +++ b/viz.py @@ -0,0 +1,24 @@ +import pandas as pd +import matplotlib.pyplot as plt + +# Load the CSV file +df = pd.read_csv("./comparaison.csv", sep=";") + +# Pivot to have algorithms as columns +pivot_df = df.pivot(index="metric", columns="algorithm", values="score") + +# Sort metrics alphabetically for consistency (optional) +pivot_df = pivot_df.sort_index() + +# Plot horizontal bars +ax = pivot_df.plot(kind="barh", figsize=(10, 7)) +plt.xlabel("Score") +plt.ylabel("Benchmark / Metric") +plt.title("Comparison of PPO vs Offline IQL across Benchmarks") +plt.legend(title="Algorithm") +plt.tight_layout() + +# Save as PNG +plt.savefig("ppo_vs_iql_comparison.png") + +print("Plot saved as ppo_vs_iql_comparison.png") diff --git a/viz_online.py b/viz_online.py new file mode 100644 index 0000000..1789afe --- /dev/null +++ b/viz_online.py @@ -0,0 +1,62 @@ +import pandas as pd +import matplotlib.pyplot as plt + +# Your online dataset benchmarks +online_data = [ + "matmul_256_768_2", + "matmul_256_768_3072", + "matmul_256_2048_2048", + "matmul_256_256_512", + "matmul_256_1024_1024", + "matmul_256_1536_1000", + "matmul_256_256_128", + "matmul_256_512_1024", + "matmul_256_1536_4096", + "matmul_256_1408_1000", + "matmul_256_1280_1000", + "matmul_256_768_768", + "matmul_256_2048_1000", + "matmul_256_4096_1024", + "matmul_256_128_256", + "matmul_1024_128_768", + "matmul_1024_2048_128", + "matmul_1024_128_256", + "matmul_1024_128_512", + "matmul_1024_1024_128", + "matmul_1024_1024_256", + "matmul_1024_128_1024", + "matmul_1024_1536_128", + "matmul_1024_128_128", + "matmul_1024_128_2048", +] + +# Load your CSV file +df = pd.read_csv("online_iql.csv", sep=";") + +# Exclude "average_speedup" from bar chart (optional, keep only benchmarks) +benchmarks_df = df[df["metric"] != "average_speedup"] + +# Assign colors depending on online/offline +colors = [ + "blue" if metric in online_data else "red" + for metric in benchmarks_df["metric"] +] + +# Plot horizontal bar chart +plt.figure(figsize=(10, 6)) +plt.barh(benchmarks_df["metric"], benchmarks_df["score"], color=colors) + +plt.xlabel("Score") +plt.ylabel("Benchmark") +plt.title("Online finetuning of IQL") + +# Add legend manually +import matplotlib.patches as mpatches +blue_patch = mpatches.Patch(color="blue", label="Online data") +red_patch = mpatches.Patch(color="red", label="Offline data") +plt.legend(handles=[blue_patch, red_patch]) + +plt.tight_layout() +plt.savefig("online_offline_iql.png") + +print("Plot saved as online_offline_iql.png") From e7d22d44560a0bd50240b11e6320e0cd6d9d015c Mon Sep 17 00:00:00 2001 From: Oucherif Ouail Date: Sat, 11 Oct 2025 13:06:05 +0100 Subject: [PATCH 2/5] added example for config.json in readme --- README.md | 127 +++++++++++++++++++++--------------------------------- 1 file changed, 48 insertions(+), 79 deletions(-) diff --git a/README.md b/README.md index 310321e..d9a6e09 100755 --- a/README.md +++ b/README.md @@ -1,90 +1,59 @@ -## Getting Started -This is an example of how you may give instructions on setting up your project locally. -To get a local copy up and running follow these simple example steps. -### Prerequisites: -###### Required -1) [CMake](https://cmake.org/): version 3.20 or greater. -2) [Ninja](https://ninja-build.org/). -3) [Gcc](https://gcc.gnu.org/) : version 13.2. -4) [Gxx]: version 13.2. -5) [LLD](https://lld.llvm.org/). -6) [Python](https://www.python.org/downloads/): version 3.11 or greater. -### Setup -#### 1. Building MLIR : -```sh -git clone --depth 1 -b release/19.x https://github.com/llvm/llvm-project.git -mkdir llvm-project/build -cd llvm-project/build -cmake -S llvm -B build -G Ninja -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_PROJECTS="mlir;clang;openmp" \ --DLLVM_BUILD_EXAMPLES=ON -DLLVM_TARGETS_TO_BUILD=X86 -DLLVM_ENABLE_ASSERTIONS=ON \ --DCMAKE_C_COMPILER=gcc -DCMAKE_CXX_COMPILER=g++ -DLLVM_ENABLE_LLD=ON -DMLIR_ENABLE_BINDINGS_PYTHON=ON +# IQL for MLIR-RL -cmake --build . --target check-mlir +Example for `config.json` : ``` -#### 2. Install python requirements : -```sh -pip install -r requirements.txt -``` -#### 3. Setup environment variables : -Change llvm related variables according to your llvm-project folder path. -```env -NEPTUNE_PROJECT= -NEPTUNE_TOKEN= -LLVM_BUILD_PATH=llvm-project/build -MLIR_SHARED_LIBS=llvm-project/build/lib/libomp.so,llvm-project/build/lib/libmlir_c_runner_utils.so,llvm-project/build/lib/libmlir_runner_utils.so -AST_DUMPER_BIN_PATH=tools/ast_dumper/build/bin/AstDumper -VECTORIZER_BIN_PATH=tools/vectorizer/build/bin/Vectorizer -``` -### Documentation -#### 1. Jobs -For running jobs using slurm script examples are provided in the `scripts/` folder. -#### 2. Configuration -Configuring the model on a specific case can be done by setting a JSON config file containing all required settings. Configuration JSON file examples are provided in the `config/` folder. -The following JSON content is an example of a config file: -```json { "max_num_stores_loads": 7, "max_num_loops": 7, "max_num_load_store_dim": 7, "num_tile_sizes": 7, - "num_transformations": 6, "vect_size_limit": 2048, - "use_bindings": false, - "use_vectorizer": false, - "data_format": "json", - "optimization_mode": "last", - "benchmarks_folder_path": "", - "len_trajectory": 64, - "ppo_batch_size": 64, - "nb_iterations": 10000, + "order": [["I"],["TP"],["T"],["V","NT"]], + "interchange_mode": "pointers", + "exploration": ["entropy"], + "init_epsilon": 0.1, + "new_architecture": false, + "normalize_bounds": "max", + "normalize_adv": "standard", + "sparse_reward": true, + "split_ops": true, + "reuse_experience": "none", + "activation": "relu", + "benchmarks_folder_path": "data/matmul/code/", + "bench_count": 8, + "replay_count": 10, + "nb_iterations": 1200, "ppo_epochs": 4, + "ppo_batch_size": 32, + "value_epochs": 4, + "value_batch_size": 32, + "value_coef": 0.5, + "value_clip": true, "entropy_coef": 0.01, - "lr": 0.001, - "truncate": 5, - "json_file": "data/nn/train_operations.json", - "tags": ["nn"], - "logging": true + "lr": 3e-4, + "truncate": 10, + "json_file": "data/matmul/train_operations.json", + "eval_json_file": "data/matmul/eval_operations.json", + "tags": ["matmul"], + "debug": false, + "main_exec_data_file": "cache/execution.json", + "results_dir": "offline_iql_adv_norm_gradclip_cosine_scheduler", + "run_name": "offline_iql_adv_norm_gradclip_cosine_scheduler", + "collect_offline_data": false, + "offline_data_save_dir": "offline_dataset", + "offline_data_file": "offline_dataset_online_ppo.npz", + + "gamma": 0.99, + "tau": 0.9, + "inverse_temperature":3.0, + "alpha": 0.005, + "batch_size": 256, + "learning_rate": { + "value": 3e-4, + "q": 3e-4, + "policy": 1e-4 + }, + "max_steps": 1000000, + "target_update_freq": 1 } -``` -The following list describes every required setting in a configuration file. -- `max_num_stores_loads (int)`: The maximum number of loads in the nested loops. -- `max_num_loops (int)`: The max number of nested loops. -- `max_num_load_store_dim (int)`: The max number of dimensions in load/store buffers. -- `num_tile_sizes (int)`: The number of possible tile sizes for a loop. -- `num_transformations (int)`: The number of transformations. -- `vect_size_limit (int)`: Vectorization size limit to prevent large sizes vectorization. -- `use_bindings (bool)`: Flag to enable using python bindings for execution, if False, the execution will be done using the command line. Default is False. -- `use_vectorizer (bool)`: Flag to enable using the vectorizer C++ program for vectorization, if False, vectorization is done using transform dialect directly. Default is False. -- `data_format (Literal["json", "mlir"])`: The format of the data, can be either "json" or "mlir". "json" mode reads json files containing benchmark features, "mlir" mode reads mlir code files directly and extract features from it using AST dumper. Default is "json". -- `optimization_mode (Literal["last", "all"])`: The optimization mode to use, "last" will optimize only the last operation, "all" will optimize all operations in the code. Default is "last". -- `benchmarks_folder_path (str)`: Path to the benchmarks folder. Can be empty if data format is set to "json". -- `len_trajectory (int)`: Length of the trajectory used for PPO. -- `ppo_batch_size (int)`: Batch size for PPO. -- `nb_iterations (int)`: Number of training iterations. -- `ppo_epochs (int)`: Number of epochs for PPO. -- `entropy_coef (float)`: Entropy coefficient. -- `lr (float)`: Learning rate. -- `truncate (int)`: Maximum number of steps of a schedule for an operation. -- `json_file (str)`: Path to the JSON file containing the benchmarks code and features if data format is set to "json". Otherwise, it should contain original execution times for every benchmark in the benchmark folder. -- `tags (list[str])`: List of tags to add to the neptune experiment. -- `logging (bool)`: Flag to enable logging to neptune. \ No newline at end of file +``` \ No newline at end of file From 7904f7107588847c7543a13be87376ac01598576 Mon Sep 17 00:00:00 2001 From: Oucherif Ouail Date: Sun, 19 Oct 2025 10:48:52 +0100 Subject: [PATCH 3/5] updated iql online finetuning --- .gitignore | 8 +- iql/iql_agent.py | 11 +- iql/iql_agent_device.py | 282 ---------------------------------------- iql_online.py | 152 +++++++++++++++++++--- neptune_sync.py | 72 ---------- train.py | 6 +- train_iql.py | 2 +- viz_online.py | 105 ++++++++------- 8 files changed, 205 insertions(+), 433 deletions(-) delete mode 100644 iql/iql_agent_device.py delete mode 100755 neptune_sync.py diff --git a/.gitignore b/.gitignore index 3566c9e..69f932d 100644 --- a/.gitignore +++ b/.gitignore @@ -18,4 +18,10 @@ data/ results/ offline_iql_adv_norm_gradclip_cosine_scheduler/ offline_dataset/ -tmp/* \ No newline at end of file +tmp/* + +iql_online_fine_tuning/ +offline_iql_results_1/ +online_finetuning_iql/ +online_iql/ +online_ppo/ \ No newline at end of file diff --git a/iql/iql_agent.py b/iql/iql_agent.py index dacc6b4..3be7acd 100755 --- a/iql/iql_agent.py +++ b/iql/iql_agent.py @@ -49,11 +49,13 @@ def __init__(self, cfg: Config, device: Union[torch.device, str], obs_parts=None self.value_optimizer = torch.optim.Adam(self.value_model.parameters(), lr=cfg.lr["value"]) self.q_optimizer = torch.optim.Adam(self.q_model.parameters(), lr=cfg.lr["q"]) self.policy_optimizer = torch.optim.Adam(self.policy_model.parameters(), lr=cfg.lr["policy"]) + """ self.policy_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( self.policy_optimizer, T_max=600000, eta_min=1e-5 - ) + ) + """ # --------- helpers to move inputs to device ---------- def _to_device_tensor(self, x: Optional[torch.Tensor]) -> Optional[torch.Tensor]: @@ -225,7 +227,7 @@ def update_policy( self.policy_optimizer.step() - self.policy_lr_scheduler.step() + # self.policy_lr_scheduler.step() return loss_pi # ------------------------ @@ -280,7 +282,4 @@ def update(self, batch: Tuple[torch.Tensor, ...]) -> Dict[str, float]: "q": float(loss_q.item()), "policy": float(loss_pi.item()), "value": float(loss_v.item()), - } - - - + } \ No newline at end of file diff --git a/iql/iql_agent_device.py b/iql/iql_agent_device.py deleted file mode 100644 index 61c9b47..0000000 --- a/iql/iql_agent_device.py +++ /dev/null @@ -1,282 +0,0 @@ -import copy -from typing import Dict, List, Optional, Type, Tuple, Union - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from iql.iql_config import Config -from utils.config import Config -from rl_autoschedular.actions import ActionSpace -from rl_autoschedular.observation import Observation, ObservationPart, OpFeatures, ActionHistory - -# ---- bring in the updated models we defined earlier ---- -from iql.value_function import IQLValueModel -from iql.policy import IQLPolicyModel -from iql.q_functions import IQLTwinQ - - -class IQLAgent(nn.Module): - """ - IQL agent adapted to the PPO-aligned architecture and hierarchical action space. - - Uses Observation.get_parts(obs, *obs_parts) - - Shared 3×512 backbone across policy/value/Q - - Hierarchical heads (action + per-action params) - """ - def __init__(self, cfg: Config, device: Union[torch.device, str], obs_parts=None, param_dims=None): - super().__init__() - self.obs_parts = obs_parts or [OpFeatures, ActionHistory] - - # ---- device handling ---- - self.device = torch.device(device) if not isinstance(device, torch.device) else device - - # Use config hyperparameters - self.gamma = cfg.gamma - self.tau = cfg.tau - self.beta = cfg.beta - self.alpha = cfg.alpha - - # Networks (move to device) - self.value_model = IQLValueModel(self.obs_parts, tau=self.tau).to(self.device) - self.policy_model = IQLPolicyModel(self.obs_parts).to(self.device) - self.q_model = IQLTwinQ(self.obs_parts).to(self.device) - - # Target Q - self.q_target = copy.deepcopy(self.q_model).to(self.device) - for p in self.q_target.parameters(): - p.requires_grad = False - - # Optimizers with cfg.lr dict (after models are on device) - self.value_optimizer = torch.optim.Adam(self.value_model.parameters(), lr=cfg.lr["value"]) - self.q_optimizer = torch.optim.Adam(self.q_model.parameters(), lr=cfg.lr["q"]) - self.policy_optimizer = torch.optim.Adam(self.policy_model.parameters(), lr=cfg.lr["policy"]) - - # --------- helpers to move inputs to device ---------- - def _to_device_tensor(self, x: Optional[torch.Tensor]) -> Optional[torch.Tensor]: - if x is None: - return None - return x.to(self.device, non_blocking=True) - - def _to_device_tensor_list( - self, - xs: Optional[List[Optional[torch.Tensor]]] - ) -> Optional[List[Optional[torch.Tensor]]]: - if xs is None: - return None - out: List[Optional[torch.Tensor]] = [] - for t in xs: - out.append(self._to_device_tensor(t) if isinstance(t, torch.Tensor) else None if t is None else t) - return out - - # ------------------------ - # Action selection (hierarchical) - # ------------------------ - @torch.no_grad() - def sample( - self, - obs: torch.Tensor, - greedy: bool = False, - eps: Optional[float] = None - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Sample hierarchical action indices using the same API style as PPO. - Returns: - actions_index: packed hierarchical indices (ActionSpace format) - actions_log_p: log-prob of sampled action under current policy - entropies: per-head entropies (aggregated by ActionSpace) - """ - obs = self._to_device_tensor(obs) - - # Build distributions from policy - dists = self.policy_model(obs) - eps_dists = ActionSpace.uniform_distributions(obs) - - # Hierarchical sample - use_uniform = (eps is not None) and (torch.rand((), device=self.device).item() < eps) - actions_index = ActionSpace.sample( - obs, - dists, - eps_dists, - uniform=use_uniform, - greedy=greedy, - ) - - # Stats for the sampled actions - actions_log_p, entropies = ActionSpace.distributions_stats( - dists, - actions_index, - eps_distributions=eps_dists if eps is not None else None, - eps=eps, - ) - return actions_index, actions_log_p, entropies - - # ------------------------ - # Value update (expectile regression using target twin-Q) - # ------------------------ - def update_value( - self, - obs: torch.Tensor, - action_idx: torch.LongTensor, - *, - param_indices: Optional[List[Optional[torch.LongTensor]]] = None, - param_values: Optional[List[Optional[torch.Tensor]]] = None, - ) -> torch.Tensor: - """ - Updates V(s) by regressing towards min(Q1, Q2) from the *target* Q network. - """ - obs = self._to_device_tensor(obs) - action_idx = self._to_device_tensor(action_idx) - param_indices = self._to_device_tensor_list(param_indices) - param_values = self._to_device_tensor_list(param_values) - - with torch.no_grad(): - q1_t, q2_t = self.q_target(obs, action_idx) - q_min_t = torch.min(q1_t, q2_t) # [B] - - loss_v = self.value_model.loss(obs, q_min_t) - assert self.value_optimizer is not None, "value_optimizer is not set" - self.value_optimizer.zero_grad(set_to_none=True) - loss_v.backward() - self.value_optimizer.step() - return loss_v - - # ------------------------ - # Q update (TD with V(s')) - # ------------------------ - def update_q( - self, - obs: torch.Tensor, - action_idx: torch.LongTensor, - rewards: torch.Tensor, - next_obs: torch.Tensor, - dones: torch.Tensor - ) -> torch.Tensor: - """ - Update twin Q networks with TD target: - target_q = r + gamma * (1 - done) * V_target(s') - If target_v is not provided, it is computed from the current value_model. - """ - obs = self._to_device_tensor(obs) - next_obs = self._to_device_tensor(next_obs) - action_idx = self._to_device_tensor(action_idx) - rewards = self._to_device_tensor(rewards) - dones = self._to_device_tensor(dones) - - with torch.no_grad(): - target_v = self.value_model(next_obs).to(self.device) # [B] - - target_q = rewards + self.gamma * (1.0 - dones) * target_v # [B] - - loss_q = self.q_model.loss( - obs, - action_idx, - target_q - ) - assert self.q_optimizer is not None, "q_optimizer is not set" - self.q_optimizer.zero_grad(set_to_none=True) - loss_q.backward() - self.q_optimizer.step() - return loss_q - - # ------------------------ - # Policy update (advantage-weighted BC) - # ------------------------ - def update_policy( - self, - obs: torch.Tensor, - actions_index: torch.Tensor, # packed hierarchical indices (as stored by dataset) - *, - action_idx: Optional[torch.LongTensor] = None, - param_indices: Optional[List[Optional[torch.LongTensor]]] = None, - param_values: Optional[List[Optional[torch.Tensor]]] = None, - ) -> torch.Tensor: - """ - Update policy with advantage-weighted log-likelihood: - weights = exp(A / beta), A = min(Q1, Q2) - V(s) - - - actions_index is used to compute log π(a|s) via ActionSpace.distributions_stats(...) - - Q needs decomposed (action_idx, param_indices/values). - """ - obs = self._to_device_tensor(obs) - actions_index = self._to_device_tensor(actions_index) - action_idx = self._to_device_tensor(action_idx) if action_idx is not None else None - param_indices = self._to_device_tensor_list(param_indices) - param_values = self._to_device_tensor_list(param_values) - - # 1) log π(a|s) from hierarchical distributions - dists = self.policy_model(obs) - actions_log_p, _ = ActionSpace.distributions_stats(dists, actions_index) - - # 2) advantages = Q_min(s,a) - V(s) - assert action_idx is not None, "action_idx (top-level) is required for Q evaluation" - with torch.no_grad(): - q_min = self.q_model.q_values(obs, action_idx) # [B] - v = self.value_model(obs) # [B] - advantages = q_min - v # [B] - - # 3) loss (AWAC/IQL style) - loss_pi = self.policy_model.loss( - actions_log_p=actions_log_p, - advantages=advantages, - beta=self.beta, - ) - - assert self.policy_optimizer is not None, "policy_optimizer is not set" - self.policy_optimizer.zero_grad(set_to_none=True) - loss_pi.backward() - self.policy_optimizer.step() - return loss_pi - - # ------------------------ - # Soft update of target Q - # ------------------------ - @torch.no_grad() - def soft_update_q_target(self): - """ - θ_target ← α θ + (1-α) θ_target - """ - for p, tp in zip(self.q_model.parameters(), self.q_target.parameters()): - tp.data.copy_(self.alpha * p.data + (1.0 - self.alpha) * tp.data) - - def update(self, batch: Tuple[torch.Tensor, ...]) -> Dict[str, float]: - """ - One full IQL update step: - 1. Update Q-functions - 2. Update value function - 3. Update policy (AWAC/IQL style) - 4. Soft update target Q - Returns dict of losses for logging. - """ - # Ensure whole batch is on device - obs, actions_index, rewards, next_obs, dones = (t.to(self.device, non_blocking=True) for t in batch) - - # ---- 1) Update Q ---- - loss_q = self.update_q( - obs=obs, - action_idx=actions_index, # top-level index - rewards=rewards, - next_obs=next_obs, - dones=dones, - ) - - # ---- 2) Update Value ---- - loss_v = self.update_value(obs, actions_index) - - - # ---- 3) Update Policy ---- - loss_pi = self.update_policy( - obs=obs, - actions_index=actions_index, - action_idx=actions_index, # required for Q evaluation - ) - - - - # ---- 4) Soft update Q target ---- - self.soft_update_q_target() - - return { - "q": float(loss_q.item()), - "policy": float(loss_pi.item()), - "value": float(loss_v.item()), - } diff --git a/iql_online.py b/iql_online.py index c3d1234..5b142a2 100644 --- a/iql_online.py +++ b/iql_online.py @@ -18,6 +18,10 @@ device = torch.device("cpu") +MAX_STEPS_OFFLINE_ONLINE_RATIO = 100_000 # steps over which to decay online-offline ratio + +UPDATE_ITERS = 3 # gradient updates per batch + def load_offline_dataset(): """Load offline dataset for warm-starting replay buffer.""" dataset = OfflineDataset( @@ -89,9 +93,60 @@ def evaluate_benchmarks(model: IQLAgent, env: Env, step: int): return env_time +def collect_warmup_data(agent, env, buffer, warmup_steps=5000): + """ + Collects a fixed number of online transitions before training begins. + This stabilizes early learning by pre-filling the online buffer. + """ + print(f"\n[ Warmup Phase ] Collecting {warmup_steps} online transitions before training...\n") + state = env.reset() + + progress = trange(warmup_steps, desc="Warmup Collection", dynamic_ncols=True) + + + for _ in progress: + obs = Observation.from_state(state) + action_index, _, _ = agent.sample(obs.to(device), eps=None) + action = ActionSpace.action_by_index(action_index[0], state) + + next_state, reward, op_done, _ = env.step(state, action) + next_obs = Observation.from_state(next_state) + + # Handle operation completion + if op_done: + next_state, done = env.get_next_op_state(next_state) + else: + done = False + + # Store in online buffer + buffer.add_online( + obs.to(device), + action_index, + torch.tensor(reward, dtype=torch.float32, device=device), + next_obs.to(device), + torch.tensor(done, dtype=torch.float32, device=device), + ) + + state = next_state if not done else env.reset() + + + print(f"[Warmup Complete] Collected {len(buffer.online_buffer)} online transitions.\n") + + +def get_epsilon(step, eps_start=0.3, eps_end=0.05, decay_steps=100_000): + """ + Linearly decays epsilon from eps_start → eps_end over decay_steps. + Used for epsilon-greedy exploration during online interaction. + """ + if step >= decay_steps: + return eps_end + decay = (eps_start - eps_end) * (1 - step / decay_steps) + return eps_end + decay + + class ReplayBuffer: """Simple replay buffer mixing offline + online data.""" - def __init__(self, max_size=1000000): + def __init__(self, max_size=100000): self.states, self.actions, self.rewards, self.next_states, self.dones = [], [], [], [], [] self.max_size = max_size @@ -123,11 +178,66 @@ def sample(self, batch_size): def __len__(self): return len(self.states) +class DualReplayBuffer: + """Manages separate buffers for offline and online data, + with controlled sampling ratio that decays over time.""" + def __init__(self, offline_buffer_max=100_000, online_buffer_max=100_000): + self.offline_buffer = ReplayBuffer(max_size=offline_buffer_max) + self.online_buffer = ReplayBuffer(max_size=online_buffer_max) + + def add_offline(self, s, a, r, ns, d): + self.offline_buffer.add(s, a, r, ns, d) + + def add_online(self, s, a, r, ns, d): + self.online_buffer.add(s, a, r, ns, d) + + def sample(self, batch_size, step, max_steps, online_start_ratio=0.8, online_end_ratio=0.2): + """Sample from both buffers with ratio decaying linearly over training.""" + # Compute current online ratio + progress = min(step / max_steps, 1.0) + online_ratio = online_start_ratio - (online_start_ratio - online_end_ratio) * progress + online_batch = int(batch_size * online_ratio) + offline_batch = batch_size - online_batch + + # Safety check: if one buffer is too small, compensate from the other + online_batch = min(online_batch, len(self.online_buffer)) + offline_batch = batch_size - online_batch + if len(self.offline_buffer) < offline_batch: + offline_batch = len(self.offline_buffer) + online_batch = batch_size - offline_batch + + # Sample + if online_batch > 0: + online_samples = self.online_buffer.sample(online_batch) + else: + online_samples = None + if offline_batch > 0: + offline_samples = self.offline_buffer.sample(offline_batch) + else: + offline_samples = None + + # Merge samples + def merge_tensors(t1, t2): + if t1 is None: return t2 + if t2 is None: return t1 + return torch.cat([t1, t2], dim=0) + + return tuple( + merge_tensors(o, f) + for o, f in zip( + online_samples if online_samples else (None,)*5, + offline_samples if offline_samples else (None,)*5 + ) + ) + + def __len__(self): + return len(self.offline_buffer) + len(self.online_buffer) + def hybrid_finetune(): # === Load pretrained agent === agent = IQLAgent(cfg, device, obs_parts=[OpFeatures, ActionHistory]) - ckpt_path = "./iql_results/iql_step_17999.pt" + ckpt_path = "./offline_iql_results_1/iql_step_97999.pt" if os.path.exists(ckpt_path): agent.load_state_dict(torch.load(ckpt_path, map_location=device)) print(f"Loaded pretrained checkpoint: {ckpt_path}") @@ -135,11 +245,13 @@ def hybrid_finetune(): raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}") # === Init Replay Buffer with offline data === - buffer = ReplayBuffer(max_size=200000) + + buffer = DualReplayBuffer(offline_buffer_max=100_000, online_buffer_max=100_000) + states, actions, rewards, next_states, dones = load_offline_dataset() for s, a, r, ns, d in zip(states, actions, rewards, next_states, dones): - buffer.add(s, a, r, ns, d) - print(f"Replay buffer initialized with {len(buffer)} offline samples") + buffer.add_offline(s, a, r, ns, d) + print(f"Offline Replay buffer initialized with {len(buffer)} offline samples") # environments train_env = Env(is_training=True, run_name=cfg.run_name) @@ -148,19 +260,24 @@ def hybrid_finetune(): print("Starting HYBRID fine-tuning (offline + online)...") start_time = time.time() state = train_env.reset() + + + # Warmup Phase (modular) + collect_warmup_data(agent, train_env, buffer, warmup_steps=2000) hybrid_trange = trange(cfg.max_steps, desc="Hybrid Fine-tuning", dynamic_ncols=True) for step in hybrid_trange: # reset benchmark state = train_env.reset() done = False - + while not done: # current obs obs = Observation.from_state(state) # agent picks action - action_index, _, _ = agent.sample(obs.to(device), eps=None) + eps = get_epsilon(step) + action_index, _, _ = agent.sample(obs.to(device), eps=eps) action = ActionSpace.action_by_index(action_index[0], state) # env step @@ -174,7 +291,7 @@ def hybrid_finetune(): next_state, done = train_env.get_next_op_state(next_state) # push transition to replay buffer - buffer.add( + buffer.add_offline( obs.to(device), action_index, torch.tensor(reward, dtype=torch.float32, device=device), @@ -186,20 +303,13 @@ def hybrid_finetune(): state = next_state # after benchmark, do 1 gradient update - batch = buffer.sample(cfg.batch_size) - losses = agent.update(batch) - - # logging - if step % 50 == 0: - fl.log_scalars("hybrid_train", losses, step) - - + batch = buffer.sample(cfg.batch_size,step, MAX_STEPS_OFFLINE_ONLINE_RATIO) - # logging - if step % 50 == 0: - fl.log_scalars("hybrid_train", losses, step) + for _ in range(UPDATE_ITERS): + losses = agent.update(batch) if (step + 1) % 100 == 0: + fl.log_scalars("hybrid_train", losses, step) elapsed = time.time() - start_time hybrid_trange.set_postfix({ "Value Loss": f"{losses['value']:.4f}", @@ -215,7 +325,7 @@ def hybrid_finetune(): print(f"Evaluation done in {time.time() - eval_start:.2f}s (env time: {env_time:.2f}s)") fl.flush() - if (step + 1) % 2000 == 0: + if (step + 1) % 5000 == 0: os.makedirs(cfg.results_dir, exist_ok=True) save_path = os.path.join(cfg.results_dir, f"iql_hybrid_step_{step}.pt") torch.save(agent.state_dict(), save_path) @@ -227,4 +337,4 @@ def hybrid_finetune(): if __name__ == "__main__": - hybrid_finetune() + hybrid_finetune() \ No newline at end of file diff --git a/neptune_sync.py b/neptune_sync.py deleted file mode 100755 index c8f1efa..0000000 --- a/neptune_sync.py +++ /dev/null @@ -1,72 +0,0 @@ -# Load environment variables -from dotenv import load_dotenv -load_dotenv(override=True) - -import neptune -from neptune import Run -import os -import time - -results_dir = 'results' -with open(os.path.join(results_dir, 'synced_ids'), 'r') as f: - synced_ids = [int(id) for id in f.readlines() if id.strip()] - -current_runs = [d for d in os.listdir(results_dir) if d.startswith('run_') and int(d.split('_')[1]) not in synced_ids] - -if not current_runs: - print('No new runs to sync') - exit() -print(f'Syncing runs: {current_runs}') - -with open(os.path.join(results_dir, 'synced_ids'), 'a') as f: - f.write('\n'.join(run.split('_')[1] for run in current_runs)) - f.write('\n') - -neptune_runs: dict[str, Run] = {} -for run in current_runs: - run_path = os.path.join(results_dir, run) - with open(os.path.join(run_path, 'tags'), 'r') as f: - tags = f.read().splitlines() - neptune_run = neptune.init_run( - project=os.getenv('NEPTUNE_PROJECT'), - api_token=os.getenv('NEPTUNE_TOKEN'), - tags=tags, - ) - neptune_runs[run] = neptune_run - -runs_counters: dict[str, dict[str, int]] = {run: {} for run in current_runs} - - -def kill_handler(signum, frame): - print('Killing...') - for runs in neptune_runs.values(): - runs.stop() - exit() - - -if __name__ == '__main__': - signal.signal(signal.SIGINT, kill_handler) - signal.signal(signal.SIGTERM, kill_handler) - - while True: - print('Syncing...') - for run in current_runs: - neptune_run = neptune_runs[run] - run_path = os.path.join(results_dir, run) - files: list[str] = [] - for root, _, filenames in os.walk(run_path): - relative_root = root.replace(run_path, '') - relative_root = relative_root[1:] if relative_root.startswith('/') else relative_root - for filename in filenames: - files.append(os.path.join(relative_root, filename) if relative_root else filename) - for file in files: - if file == 'tags': - continue - if file not in runs_counters[run]: - runs_counters[run][file] = 0 - read_idx = runs_counters[run][file] - with open(os.path.join(run_path, file), 'r') as f: - values = [float(line) for line in f.readlines()] - neptune_run[file].extend(values[read_idx:]) - runs_counters[run][file] = len(values) - time.sleep(60) diff --git a/train.py b/train.py index d1726fb..ca99892 100644 --- a/train.py +++ b/train.py @@ -37,15 +37,15 @@ # Set environments # run_name for /tmp/ path -env = Env(is_training=True,run_name="online_ppo_data_collection") -eval_env = Env(is_training=False,run_name="online_ppo_data_collection") +env = Env(is_training=True,run_name="online_ppo") +eval_env = Env(is_training=False,run_name="online_ppo") print_success(f"Environments initialized: {env.tmp_file}") # Set model model = Model().to(device) optimizer = torch.optim.Adam( model.parameters(), - lr=cfg.lr + lr=3e-4 ) print_success("Model initialized") diff --git a/train_iql.py b/train_iql.py index 825a742..087a017 100644 --- a/train_iql.py +++ b/train_iql.py @@ -83,7 +83,7 @@ def evaluate_benchmarks(model: IQLAgent, env: Env, step: int): obs = Observation.from_state(state) # Sample action and log-prob from *current policy* - action_index, action_log_p, entropy = model.sample(obs.to(device), greedy=True) + action_index, action_log_p, entropy = model.sample(obs.to(device)) assert action_index.size(0) == 1 and action_log_p.size(0) == 1 action = ActionSpace.action_by_index(action_index[0], state) diff --git a/viz_online.py b/viz_online.py index 1789afe..74fdb30 100644 --- a/viz_online.py +++ b/viz_online.py @@ -1,62 +1,73 @@ import pandas as pd import matplotlib.pyplot as plt +import matplotlib.patches as mpatches + +# Define which benchmarks belong to which dataset +online_data = { + "offline": [ + "matmul_256_768_3072", + "matmul_256_2048_2048", + "matmul_256_256_512", + "matmul_256_512_1024", + ], + "online": [ + "matmul_1024_128_768", + "matmul_1024_128_512", + "matmul_1024_1024_128", + "matmul_1024_128_1024", + "matmul_1024_128_2048", + ], +} -# Your online dataset benchmarks -online_data = [ - "matmul_256_768_2", - "matmul_256_768_3072", - "matmul_256_2048_2048", - "matmul_256_256_512", - "matmul_256_1024_1024", - "matmul_256_1536_1000", - "matmul_256_256_128", - "matmul_256_512_1024", - "matmul_256_1536_4096", - "matmul_256_1408_1000", - "matmul_256_1280_1000", - "matmul_256_768_768", - "matmul_256_2048_1000", - "matmul_256_4096_1024", - "matmul_256_128_256", - "matmul_1024_128_768", - "matmul_1024_2048_128", - "matmul_1024_128_256", - "matmul_1024_128_512", - "matmul_1024_1024_128", - "matmul_1024_1024_256", - "matmul_1024_128_1024", - "matmul_1024_1536_128", - "matmul_1024_128_128", - "matmul_1024_128_2048", -] - -# Load your CSV file +# Load CSV df = pd.read_csv("online_iql.csv", sep=";") -# Exclude "average_speedup" from bar chart (optional, keep only benchmarks) -benchmarks_df = df[df["metric"] != "average_speedup"] +# Filter out "average_speedup" rows +df = df[df["metric"] != "average_speedup"] + +# Pivot table for easier comparison +pivot_df = df.pivot(index="metric", columns="algorithm", values="score") + +# Determine dataset (online/offline) for each metric +def get_dataset(metric): + if metric in online_data["online"]: + return "online" + elif metric in online_data["offline"]: + return "offline" + else: + return "unknown" -# Assign colors depending on online/offline -colors = [ - "blue" if metric in online_data else "red" - for metric in benchmarks_df["metric"] -] +pivot_df["dataset"] = pivot_df.index.map(get_dataset) -# Plot horizontal bar chart -plt.figure(figsize=(10, 6)) -plt.barh(benchmarks_df["metric"], benchmarks_df["score"], color=colors) +# Assign colors based on dataset +color_map = {"online": "blue", "offline": "red", "unknown": "gray"} +colors = pivot_df["dataset"].map(color_map) -plt.xlabel("Score") +# Plot grouped horizontal bars +ax = pivot_df[["Online Finetuned IQL", "PPO"]].plot.barh( + figsize=(10, 6), + color=["#1f77b4", "#ff7f0e"], + edgecolor="black" +) + +# Apply y-labels and color backgrounds per dataset type +for i, (dataset, metric) in enumerate(zip(pivot_df["dataset"], pivot_df.index)): + ax.get_yticklabels()[i].set_color(color_map[dataset]) + +plt.xlabel("Speedup") plt.ylabel("Benchmark") -plt.title("Online finetuning of IQL") +plt.title("Online Finetuned IQL vs PPO — Benchmark Speedup Comparison") -# Add legend manually -import matplotlib.patches as mpatches +# Create legend blue_patch = mpatches.Patch(color="blue", label="Online data") red_patch = mpatches.Patch(color="red", label="Offline data") -plt.legend(handles=[blue_patch, red_patch]) +orange_patch = mpatches.Patch(color="#ff7f0e", label="PPO") +blue_bar_patch = mpatches.Patch(color="#1f77b4", label="IQL") + +plt.legend(handles=[blue_bar_patch, orange_patch, blue_patch, red_patch], loc="best") plt.tight_layout() -plt.savefig("online_offline_iql.png") +plt.savefig("online_offline_iql_vs_ppo.png", dpi=300) +plt.show() -print("Plot saved as online_offline_iql.png") +print("✅ Plot saved as online_offline_iql_vs_ppo.png") From 24ddcdcd3255c4dcad69634f7f814d67bb641a87 Mon Sep 17 00:00:00 2001 From: Oucherif Ouail Date: Wed, 22 Oct 2025 12:00:37 +0100 Subject: [PATCH 4/5] cleaned repo --- .gitignore | 4 ++- viz.py | 24 ----------------- viz_online.py | 73 --------------------------------------------------- 3 files changed, 3 insertions(+), 98 deletions(-) delete mode 100644 viz.py delete mode 100644 viz_online.py diff --git a/.gitignore b/.gitignore index 69f932d..07a43ce 100644 --- a/.gitignore +++ b/.gitignore @@ -24,4 +24,6 @@ iql_online_fine_tuning/ offline_iql_results_1/ online_finetuning_iql/ online_iql/ -online_ppo/ \ No newline at end of file +online_ppo/ + +viz/ \ No newline at end of file diff --git a/viz.py b/viz.py deleted file mode 100644 index cea31a8..0000000 --- a/viz.py +++ /dev/null @@ -1,24 +0,0 @@ -import pandas as pd -import matplotlib.pyplot as plt - -# Load the CSV file -df = pd.read_csv("./comparaison.csv", sep=";") - -# Pivot to have algorithms as columns -pivot_df = df.pivot(index="metric", columns="algorithm", values="score") - -# Sort metrics alphabetically for consistency (optional) -pivot_df = pivot_df.sort_index() - -# Plot horizontal bars -ax = pivot_df.plot(kind="barh", figsize=(10, 7)) -plt.xlabel("Score") -plt.ylabel("Benchmark / Metric") -plt.title("Comparison of PPO vs Offline IQL across Benchmarks") -plt.legend(title="Algorithm") -plt.tight_layout() - -# Save as PNG -plt.savefig("ppo_vs_iql_comparison.png") - -print("Plot saved as ppo_vs_iql_comparison.png") diff --git a/viz_online.py b/viz_online.py deleted file mode 100644 index 74fdb30..0000000 --- a/viz_online.py +++ /dev/null @@ -1,73 +0,0 @@ -import pandas as pd -import matplotlib.pyplot as plt -import matplotlib.patches as mpatches - -# Define which benchmarks belong to which dataset -online_data = { - "offline": [ - "matmul_256_768_3072", - "matmul_256_2048_2048", - "matmul_256_256_512", - "matmul_256_512_1024", - ], - "online": [ - "matmul_1024_128_768", - "matmul_1024_128_512", - "matmul_1024_1024_128", - "matmul_1024_128_1024", - "matmul_1024_128_2048", - ], -} - -# Load CSV -df = pd.read_csv("online_iql.csv", sep=";") - -# Filter out "average_speedup" rows -df = df[df["metric"] != "average_speedup"] - -# Pivot table for easier comparison -pivot_df = df.pivot(index="metric", columns="algorithm", values="score") - -# Determine dataset (online/offline) for each metric -def get_dataset(metric): - if metric in online_data["online"]: - return "online" - elif metric in online_data["offline"]: - return "offline" - else: - return "unknown" - -pivot_df["dataset"] = pivot_df.index.map(get_dataset) - -# Assign colors based on dataset -color_map = {"online": "blue", "offline": "red", "unknown": "gray"} -colors = pivot_df["dataset"].map(color_map) - -# Plot grouped horizontal bars -ax = pivot_df[["Online Finetuned IQL", "PPO"]].plot.barh( - figsize=(10, 6), - color=["#1f77b4", "#ff7f0e"], - edgecolor="black" -) - -# Apply y-labels and color backgrounds per dataset type -for i, (dataset, metric) in enumerate(zip(pivot_df["dataset"], pivot_df.index)): - ax.get_yticklabels()[i].set_color(color_map[dataset]) - -plt.xlabel("Speedup") -plt.ylabel("Benchmark") -plt.title("Online Finetuned IQL vs PPO — Benchmark Speedup Comparison") - -# Create legend -blue_patch = mpatches.Patch(color="blue", label="Online data") -red_patch = mpatches.Patch(color="red", label="Offline data") -orange_patch = mpatches.Patch(color="#ff7f0e", label="PPO") -blue_bar_patch = mpatches.Patch(color="#1f77b4", label="IQL") - -plt.legend(handles=[blue_bar_patch, orange_patch, blue_patch, red_patch], loc="best") - -plt.tight_layout() -plt.savefig("online_offline_iql_vs_ppo.png", dpi=300) -plt.show() - -print("✅ Plot saved as online_offline_iql_vs_ppo.png") From 6f5dd5309750d96b058cdae22cd73ad6c249d359 Mon Sep 17 00:00:00 2001 From: Oucherif Ouail Date: Wed, 22 Oct 2025 12:04:43 +0100 Subject: [PATCH 5/5] cleaned repo --- data.py | 71 ------ demo.ipynb | 128 ---------- demo.py | 33 --- eval.py | 47 ---- evaluate.py | 50 ---- filelog_clean.py | 23 -- fill_db.py | 83 ------- gen.py | 343 --------------------------- init_env.py | 42 ---- test.py | 11 - train_iql.py => train_iql_offline.py | 0 iql_online.py => train_iql_online.py | 0 train.py => train_ppo.py | 0 13 files changed, 831 deletions(-) delete mode 100644 data.py delete mode 100755 demo.ipynb delete mode 100755 demo.py delete mode 100644 eval.py delete mode 100755 evaluate.py delete mode 100755 filelog_clean.py delete mode 100755 fill_db.py delete mode 100755 gen.py delete mode 100644 init_env.py delete mode 100644 test.py rename train_iql.py => train_iql_offline.py (100%) rename iql_online.py => train_iql_online.py (100%) rename train.py => train_ppo.py (100%) diff --git a/data.py b/data.py deleted file mode 100644 index e93440b..0000000 --- a/data.py +++ /dev/null @@ -1,71 +0,0 @@ -import os - -# full list from D:\data_matmul (given by user) -all_files = [ -"matmul_1024_1024_128.mlir","matmul_1024_1024_256.mlir","matmul_1024_128_1024.mlir","matmul_1024_128_128.mlir", -"matmul_1024_128_2048.mlir","matmul_1024_128_256.mlir","matmul_1024_128_512.mlir","matmul_1024_128_768.mlir", -"matmul_1024_1536_128.mlir","matmul_1024_2048_128.mlir","matmul_1024_256_1024.mlir","matmul_1024_256_1536.mlir", -"matmul_1024_256_256.mlir","matmul_1024_256_512.mlir","matmul_1024_256_768.mlir","matmul_1024_3072_128.mlir", -"matmul_1024_512_128.mlir","matmul_1024_512_256.mlir","matmul_1024_512_512.mlir","matmul_1024_768_256.mlir", -"matmul_128_1024_1024.mlir","matmul_128_1024_128.mlir","matmul_128_1024_1536.mlir","matmul_128_1024_256.mlir", -"matmul_128_1024_512.mlir","matmul_128_1024_768.mlir","matmul_128_128_1024.mlir","matmul_128_128_128.mlir", -"matmul_128_128_1536.mlir","matmul_128_128_2048.mlir","matmul_128_128_3072.mlir","matmul_128_128_512.mlir", -"matmul_128_128_768.mlir","matmul_128_1536_1024.mlir","matmul_128_1536_128.mlir","matmul_128_1536_256.mlir", -"matmul_128_1536_512.mlir","matmul_128_1536_768.mlir","matmul_128_2048_1024.mlir","matmul_128_2048_128.mlir", -"matmul_128_2048_1536.mlir","matmul_128_2048_256.mlir","matmul_128_2048_512.mlir","matmul_128_2048_768.mlir", -"matmul_128_256_1024.mlir","matmul_128_256_128.mlir","matmul_128_256_2048.mlir","matmul_128_256_3072.mlir", -"matmul_128_256_768.mlir","matmul_128_3072_128.mlir","matmul_128_3072_256.mlir","matmul_128_3072_512.mlir", -"matmul_128_3072_768.mlir","matmul_128_512_1024.mlir","matmul_128_512_128.mlir","matmul_128_512_1536.mlir", -"matmul_128_512_2048.mlir","matmul_128_512_256.mlir","matmul_128_512_3072.mlir","matmul_128_512_512.mlir", -"matmul_128_768_1024.mlir","matmul_128_768_128.mlir","matmul_128_768_1536.mlir","matmul_128_768_256.mlir", -"matmul_128_768_3072.mlir","matmul_128_768_512.mlir","matmul_128_768_768.mlir","matmul_1536_1024_128.mlir", -"matmul_1536_128_128.mlir","matmul_1536_128_1536.mlir","matmul_1536_128_512.mlir","matmul_1536_128_768.mlir", -"matmul_1536_1536_128.mlir","matmul_1536_256_1024.mlir","matmul_1536_256_128.mlir","matmul_1536_256_256.mlir", -"matmul_1536_256_512.mlir","matmul_1536_256_768.mlir","matmul_1536_512_128.mlir","matmul_1536_512_256.mlir", -"matmul_1536_768_256.mlir","matmul_2048_128_1024.mlir","matmul_2048_128_128.mlir","matmul_2048_128_256.mlir", -"matmul_2048_128_512.mlir","matmul_2048_128_768.mlir","matmul_2048_256_128.mlir","matmul_2048_256_512.mlir", -"matmul_2048_256_768.mlir","matmul_2048_512_128.mlir","matmul_2048_512_256.mlir","matmul_2048_768_128.mlir", -"matmul_256_1024_1024.mlir","matmul_256_1024_128.mlir","matmul_256_1024_1536.mlir","matmul_256_1024_256.mlir", -"matmul_256_1024_512.mlir","matmul_256_1024_768.mlir","matmul_256_1280_1000.mlir","matmul_256_128_128.mlir", -"matmul_256_128_1536.mlir","matmul_256_128_2048.mlir","matmul_256_128_256.mlir","matmul_256_128_3072.mlir", -"matmul_256_128_512.mlir","matmul_256_128_768.mlir","matmul_256_1408_1000.mlir","matmul_256_1536_1000.mlir", -"matmul_256_1536_128.mlir","matmul_256_1536_256.mlir","matmul_256_1536_4096.mlir","matmul_256_1536_512.mlir", -"matmul_256_1536_768.mlir","matmul_256_2048_1000.mlir","matmul_256_2048_128.mlir","matmul_256_2048_2048.mlir", -"matmul_256_2048_256.mlir","matmul_256_2048_512.mlir","matmul_256_256_1024.mlir","matmul_256_256_128.mlir", -"matmul_256_256_1536.mlir","matmul_256_256_256.mlir","matmul_256_256_512.mlir","matmul_256_256_768.mlir", -"matmul_256_3072_128.mlir","matmul_256_4096_1024.mlir","matmul_256_512_1024.mlir","matmul_256_512_128.mlir", -"matmul_256_512_1536.mlir","matmul_256_512_2048.mlir","matmul_256_512_256.mlir","matmul_256_512_3072.mlir", -"matmul_256_512_512.mlir","matmul_256_512_768.mlir","matmul_256_768_1024.mlir","matmul_256_768_128.mlir", -"matmul_256_768_1536.mlir","matmul_256_768_2.mlir","matmul_256_768_256.mlir","matmul_256_768_3072.mlir", -"matmul_256_768_512.mlir","matmul_256_768_768.mlir","matmul_3072_128_128.mlir","matmul_3072_128_256.mlir", -"matmul_3072_128_512.mlir","matmul_3072_256_128.mlir","matmul_3072_256_256.mlir","matmul_3072_512_128.mlir", -"matmul_3072_512_256.mlir","matmul_3072_768_128.mlir","matmul_512_1024_128.mlir","matmul_512_1024_256.mlir", -"matmul_512_1024_512.mlir","matmul_512_128_1024.mlir","matmul_512_128_128.mlir","matmul_512_128_1536.mlir", -"matmul_512_128_2048.mlir","matmul_512_128_256.mlir","matmul_512_128_3072.mlir","matmul_512_128_512.mlir", -"matmul_512_128_768.mlir","matmul_512_1536_128.mlir","matmul_512_1536_256.mlir","matmul_512_2048_128.mlir", -"matmul_512_256_1024.mlir","matmul_512_256_128.mlir","matmul_512_256_1536.mlir","matmul_512_256_2048.mlir", -"matmul_512_256_256.mlir","matmul_512_256_512.mlir","matmul_512_256_768.mlir","matmul_512_512_1024.mlir", -"matmul_512_512_128.mlir","matmul_512_512_256.mlir","matmul_512_512_512.mlir","matmul_512_512_768.mlir", -"matmul_512_768_128.mlir","matmul_512_768_256.mlir","matmul_512_768_512.mlir","matmul_512_768_768.mlir", -"matmul_768_1024_128.mlir","matmul_768_128_1536.mlir","matmul_768_128_256.mlir","matmul_768_128_3072.mlir", -"matmul_768_128_512.mlir","matmul_768_128_768.mlir","matmul_768_1536_128.mlir","matmul_768_2048_128.mlir", -"matmul_768_2048_256.mlir","matmul_768_256_1024.mlir","matmul_768_256_128.mlir","matmul_768_256_1536.mlir", -"matmul_768_256_2048.mlir","matmul_768_256_256.mlir","matmul_768_256_768.mlir","matmul_768_3072_128.mlir", -"matmul_768_512_128.mlir","matmul_768_512_256.mlir","matmul_768_512_768.mlir","matmul_768_768_128.mlir", -"matmul_768_768_256.mlir","matmul_768_768_512.mlir" -] - -# used files subset -used_files = [ -"matmul_256_1024_1024.mlir","matmul_256_1280_1000.mlir","matmul_256_1408_1000.mlir","matmul_256_1536_1000.mlir", -"matmul_256_1536_4096.mlir","matmul_256_2048_1000.mlir","matmul_256_2048_2048.mlir","matmul_256_256_128.mlir", -"matmul_256_256_512.mlir","matmul_256_4096_1024.mlir","matmul_256_512_1024.mlir","matmul_256_768_2.mlir", -"matmul_256_768_3072.mlir","matmul_256_768_768.mlir" -] - -unused_files = (set(all_files) - set(used_files)) -unused_files = [f for f in unused_files if f.endswith('.mlir') and f.startswith('matmul_')] - -unused_files = sorted(unused_files) - -print(unused_files[:10]) \ No newline at end of file diff --git a/demo.ipynb b/demo.ipynb deleted file mode 100755 index 8d74a9a..0000000 --- a/demo.ipynb +++ /dev/null @@ -1,128 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 4, - "id": "2e74d0c8", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/bin/bash: line 1: .environment: No such file or directory\n" - ] - } - ], - "source": [ - "# Setup environment\n", - "import os\n", - "from dotenv import load_dotenv\n", - "load_dotenv(override=True)\n", - "os.chdir(os.path.dirname(os.getcwd()))\n", - "\n", - "!source .environment" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "29ec62e8", - "metadata": {}, - "outputs": [ - { - "ename": "ModuleNotFoundError", - "evalue": "No module named 'mlir'", - "output_type": "error", - "traceback": [ - "\u001b[31m---------------------------------------------------------------------------\u001b[39m", - "\u001b[31mModuleNotFoundError\u001b[39m Traceback (most recent call last)", - "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[5]\u001b[39m\u001b[32m, line 3\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;66;03m# Import modules\u001b[39;00m\n\u001b[32m 2\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtorch\u001b[39;00m\n\u001b[32m----> \u001b[39m\u001b[32m3\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mrl_autoschedular\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01menv\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m Env\n\u001b[32m 4\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mrl_autoschedular\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mmodel\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m HiearchyModel \u001b[38;5;28;01mas\u001b[39;00m Model\n\u001b[32m 5\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mrl_autoschedular\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mppo\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m evaluate_benchmark\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/MLIR-RL/rl_autoschedular/env.py:4\u001b[39m\n\u001b[32m 2\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mrl_autoschedular\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mstate\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m OperationState, BenchmarkFeatures, extract_bench_features_from_file\n\u001b[32m 3\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtyping\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m Optional\n\u001b[32m----> \u001b[39m\u001b[32m4\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mrl_autoschedular\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mevaluation\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m evaluate_code\n\u001b[32m 5\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mrl_autoschedular\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mactions\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m Action\n\u001b[32m 6\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mutils\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mlog\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m print_error\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/MLIR-RL/rl_autoschedular/evaluation.py:3\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mos\u001b[39;00m\n\u001b[32m 2\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnumpy\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnp\u001b[39;00m\n\u001b[32m----> \u001b[39m\u001b[32m3\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mmlir\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mir\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m Context, Module\n\u001b[32m 4\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mmlir\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mexecution_engine\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m ExecutionEngine, ctypes\n\u001b[32m 5\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mmlir\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mruntime\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m get_ranked_memref_descriptor\n", - "\u001b[31mModuleNotFoundError\u001b[39m: No module named 'mlir'" - ] - } - ], - "source": [ - "# Import modules\n", - "import torch\n", - "from rl_autoschedular.env import Env\n", - "from rl_autoschedular.model import HiearchyModel as Model\n", - "from rl_autoschedular.ppo import evaluate_benchmark" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "27867e09", - "metadata": {}, - "outputs": [], - "source": [ - "# Configure torch\n", - "device = torch.device(\"cpu\")\n", - "torch.set_grad_enabled(False)\n", - "torch.set_num_threads(4)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "240832d2", - "metadata": {}, - "outputs": [], - "source": [ - "# Initiate environment\n", - "eval_env = Env(is_training=False)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6157cb8a", - "metadata": {}, - "outputs": [], - "source": [ - "# Load the model\n", - "model_path = \"models/model.pth\"\n", - "model = Model()\n", - "model.load_state_dict(torch.load(model_path, weights_only=True))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "45ab965f", - "metadata": {}, - "outputs": [], - "source": [ - "evaluate_benchmark(\n", - " model,\n", - " eval_env,\n", - " device,\n", - ")" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "mlir-venv (3.11.13)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.13" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/demo.py b/demo.py deleted file mode 100755 index 45ec519..0000000 --- a/demo.py +++ /dev/null @@ -1,33 +0,0 @@ - -# Setup environment -from dotenv import load_dotenv -load_dotenv(override=True) - - -# Import modules -import torch -from rl_autoschedular.env import Env -from rl_autoschedular.model import HiearchyModel as Model -from rl_autoschedular.ppo import evaluate_benchmark - - -# Configure torch -torch.set_grad_enabled(False) -torch.set_num_threads(4) - - -# Instantiate the environment -eval_env = Env(is_training=False) - - -# Load the model -model_path = "models/model.pth" -model = Model() -model.load_state_dict(torch.load(model_path, weights_only=True)) - - -# Evaluate the model -evaluate_benchmark( - model, - eval_env, -) diff --git a/eval.py b/eval.py deleted file mode 100644 index 2941309..0000000 --- a/eval.py +++ /dev/null @@ -1,47 +0,0 @@ -# Load environment variables -from dotenv import load_dotenv -load_dotenv(override=True) - - -import torch -import os -from typing import Optional -from utils.log import print_info, print_success - -# Import environment -from rl_autoschedular.env import Env - -# config, file_logger, device -from rl_autoschedular import config as cfg, file_logger as fl, device -from rl_autoschedular.ppo import evaluate_benchmarks - -# Import RL components -from rl_autoschedular.model import HiearchyModel as Model -import time - -torch.set_grad_enabled(False) -torch.set_num_threads(int(os.getenv("OMP_NUM_THREADS", "4"))) - - - -print_info(f"Config: {cfg}") -print_success(f'Logging to: {fl.run_dir}') - -# Set environments -eval_env = Env(is_training=False,run_name="ppo_online_eval") -print_success(f"Environments initialized: {eval_env.tmp_file}") - -# Set model -model_chkpt = "./checkpoints/model.pth" -model = Model().to(device) -checkpoint = torch.load(model_chkpt, map_location="cpu") -model.load_state_dict(checkpoint, strict=False) # allow partial load -model.eval() - -env_time = evaluate_benchmarks( - model, - eval_env, - step=1 -) - -print(env_time) \ No newline at end of file diff --git a/evaluate.py b/evaluate.py deleted file mode 100755 index 8886372..0000000 --- a/evaluate.py +++ /dev/null @@ -1,50 +0,0 @@ -# Load environment variables -from dotenv import load_dotenv -load_dotenv(override=True) - -# Import modules -from rl_autoschedular.env import Env -from rl_autoschedular.model import HiearchyModel as Model -import torch -import os -from tqdm import tqdm -from rl_autoschedular import config as cfg -from rl_autoschedular import file_logger as fl -from utils.log import print_info, print_success -from rl_autoschedular.ppo import evaluate_benchmark - -torch.set_grad_enabled(False) -torch.set_num_threads(4) - -print_info(f"Config: {cfg}") -print_success(f'Logging to: {fl.run_dir}') - -# Set environments -eval_env = Env(is_training=False) -print_success(f"Environments initialized: {eval_env.tmp_file}") - -# Start training -eval_dir = os.getenv('EVAL_DIR') -if eval_dir is None: - raise ValueError("EVAL_DIR environment variable is not set.") -eval_dir = os.path.abspath(eval_dir) - -# Read the files in the evaluation directory -eval_files = [f for f in os.listdir(eval_dir) if f.endswith('.pth')] - -# Order files -eval_files.sort(key=lambda x: int(x.split('_')[1].split('.')[0])) - -files_tqdm = tqdm(eval_files, desc='Evaluating models') -for model_file in files_tqdm: - files_tqdm.set_postfix_str(f"Evaluating {model_file}") - model = Model() - model_path = os.path.join(eval_dir, model_file) - if not os.path.exists(model_path): - print_info(f"Model file {model_path} does not exist. Skipping.") - continue - model.load_state_dict(torch.load(model_path, weights_only=True)) - evaluate_benchmark( - model, - eval_env, - ) diff --git a/filelog_clean.py b/filelog_clean.py deleted file mode 100755 index a8c6b83..0000000 --- a/filelog_clean.py +++ /dev/null @@ -1,23 +0,0 @@ -import os - - -results_dir = 'results' -with open(os.path.join(results_dir, 'synced_ids'), 'r') as f: - synced_ids = [int(id) for id in f.readlines() if id.strip()] - -current_runs = [d for d in os.listdir(results_dir) if d.startswith('run_') and int(d.split('_')[1]) in synced_ids] - -if not current_runs: - print('No leftover runs to clean') - exit() - -print(f'Cleaning runs: {current_runs}') -for run in current_runs: - run_path = os.path.join(results_dir, run) - for root, _, filenames in os.walk(run_path, topdown=False): - for filename in filenames: - os.remove(os.path.join(root, filename)) - os.rmdir(root) - -with open(os.path.join(results_dir, 'synced_ids'), 'w') as f: - pass diff --git a/fill_db.py b/fill_db.py deleted file mode 100755 index 588a944..0000000 --- a/fill_db.py +++ /dev/null @@ -1,83 +0,0 @@ -from rl_autoschedular.env import Env -from rl_autoschedular.model import apply_masks, extract_masks, indices_to_raw_actions -import torch -import math -from torch.distributions import Categorical, Distribution, Uniform -from rl_autoschedular import config as cfg -from tqdm import tqdm - - -N = cfg.num_transformations -L = cfg.max_num_loops -TS = cfg.num_tile_sizes -match cfg.interchange_mode: - case 'enumerate': - interchange_mask = 3 * L - 6 - case 'pointers': - interchange_mask = L - case 'continuous': - interchange_mask = 0 -action_mask_size = N + 2 * L * (TS + 1) + interchange_mask - - -def create_uniform_distributions(obs: torch.Tensor, num_loops: list[int]) -> tuple[Distribution, Distribution, Distribution, Distribution]: - """Create uniform distributions for the actions. - - Args: - obs (torch.Tensor): The input tensor. - - Returns: - tuple[Distribution, Distribution, Distribution, Distribution]: The uniform distributions for the transformations, parallelizations, tilings, and interchanges. - """ - batch_size = obs.shape[0] - action_mask = obs[:, -(action_mask_size):].bool() - - transformation_logits = torch.zeros((batch_size, N), dtype=torch.float32) - parallelization_logits = torch.zeros((batch_size, L, TS + 1), dtype=torch.float32) - tiling_logits = torch.zeros((batch_size, L, TS + 1), dtype=torch.float32) - match cfg.interchange_mode: - case 'enumerate': - interchange_logits = torch.zeros((batch_size, 3 * L - 6), dtype=torch.float32) - case 'pointers': - interchange_logits = torch.zeros((batch_size, L), dtype=torch.float32) - case 'continuous': - interchange_logits = torch.zeros((batch_size, 1), dtype=torch.float32) - - # Apply masks on logits - transformation_logits, parallelization_logits, tiling_logits, interchange_logits = apply_masks(transformation_logits, parallelization_logits, tiling_logits, interchange_logits, *extract_masks(action_mask)) - - # Create distributions with the masked probabilities - transformation_dist = Categorical(logits=transformation_logits) - parallelization_dist = Categorical(logits=parallelization_logits) - tiling_dist = Categorical(logits=tiling_logits) - if cfg.interchange_mode != 'continuous': - interchange_dist = Categorical(logits=interchange_logits) - else: - total_count = torch.tensor([math.factorial(loops) for loops in num_loops], dtype=torch.float64) - interchange_dist = Uniform(0.0, total_count) - - return transformation_dist, parallelization_dist, tiling_dist, interchange_dist - - -if __name__ == "__main__": - env = Env(is_training=True) - print(f"Environments initialized: {env.tmp_file}") - - pbar = tqdm(unit="bench") - while True: - state, obs = env.reset() - bench_done = False - while not bench_done: - num_loops = len(state.operation_features.nested_loops) - transformation_eps_dist, parallelization_eps_dist, tiling_eps_dist, interchange_eps_dist = create_uniform_distributions(obs, [num_loops]) - transformation_index = transformation_eps_dist.sample() - parallelization_index = parallelization_eps_dist.sample() - tiling_index = tiling_eps_dist.sample() - interchange_index = interchange_eps_dist.sample().long() - actions = indices_to_raw_actions(transformation_index, parallelization_index, tiling_index, interchange_index, [num_loops]) - next_state, next_obs, _, op_done, _ = env.step(state, actions[0]) - if op_done: - next_state, next_obs, bench_done = env.get_next_op_state(next_state) - state = next_state - obs = next_obs - pbar.update(1) diff --git a/gen.py b/gen.py deleted file mode 100755 index acd2e9d..0000000 --- a/gen.py +++ /dev/null @@ -1,343 +0,0 @@ -from rl_autoschedular import config as cfg -from rl_autoschedular.state import OperationFeatures, NestedLoopFeatures -import random -import re -import math -import os -import sys -import json -from tqdm import trange -import numpy as np -from mlir.ir import Context, Module -from mlir.execution_engine import ExecutionEngine, ctypes -from mlir.runtime import get_ranked_memref_descriptor -from mlir.passmanager import PassManager - - -output_dir = 'data/features' -inputs: dict[str, np.ndarray] = {} - - -pass_pipeline = """builtin.module( - loop-invariant-code-motion, - canonicalize, - convert-vector-to-scf, - convert-linalg-to-loops, - buffer-deallocation-pipeline, - convert-bufferization-to-memref, - scf-forall-to-parallel, - convert-scf-to-openmp, - expand-strided-metadata, - finalize-memref-to-llvm, - convert-scf-to-cf, - lower-affine, - - convert-openmp-to-llvm, - convert-vector-to-llvm, - convert-math-to-llvm, - convert-func-to-llvm, - convert-index-to-llvm, - convert-arith-to-llvm, - convert-cf-to-llvm, - - reconcile-unrealized-casts, - canonicalize, - cse -)""" - - -def gen_features() -> OperationFeatures: - # Nested loops - num_loops = random.randint(1, cfg.max_num_loops) - reduction_count = random.randint(0, min(3, num_loops - 1)) - iterator_types = ['parallel'] * (num_loops - reduction_count) + ['reduction'] * reduction_count - max_iterations = 10 ** 9 - max_per_loop = math.ceil(max_iterations ** (1 / num_loops)) - iterations = max_iterations - while iterations >= max_iterations: - iterations = 1 - upper_bounds = [] - for _ in range(num_loops): - upper_bound = random.randint(2, min(max_per_loop * 2, 4096)) - upper_bounds.append(upper_bound) - iterations *= upper_bound - nested_loops = [ - NestedLoopFeatures( - arg=f'd{i}', - lower_bound=0, - upper_bound=upper_bounds[i], - step=1, - iterator_type=iterator_types[i] - ) - for i in range(num_loops) - ] - - # Operators count - total_op_count = 0 - while total_op_count == 0: - op_count = { - '+': random.randint(0, 10), - '-': random.randint(0, 10), - '*': random.randint(0, 10), - '/': 0, # TODO: Figure out how to handle division - 'exp': random.randint(0, 2), - } - total_op_count = sum(op_count.values()) - - # Load data - max_load_size = 2 ** 24 - num_loads = random.randint(1, cfg.max_num_stores_loads) - load_data: list[list[str]] = [] - per_loop = max(math.ceil(iterations ** (1 / num_loops)), 2) - max_dim = math.ceil(math.log(max_load_size) / math.log(per_loop)) - args_dict = {loop.arg: loop.upper_bound for loop in nested_loops} - unseen_args = set(args_dict.keys()) - for _ in range(num_loads - 1): - load_size = max_load_size - while load_size >= max_load_size: - dims_count = random.randint(1, min(cfg.max_num_load_store_dim, max_dim)) - zeros_count = random.randint(max(0, dims_count - num_loops), dims_count) - load_args = random.sample(list(args_dict.keys()) + ['0'], dims_count, counts=[1] * num_loops + [zeros_count]) - load_size = 1 - for arg in load_args: - if arg == '0': - load_size *= 5 - else: - load_size *= args_dict[arg] - load_data.append(load_args) - for arg in load_args: - unseen_args.discard(arg) - if unseen_args: - load_data.append(list(unseen_args)) - - # Store data - p_args = [loop.arg for loop in nested_loops if loop.iterator_type == 'parallel'] - random.shuffle(p_args) - store_data = p_args - - return OperationFeatures( - raw_operation='', - operation_type='generic', - op_count=op_count, - load_data=load_data, - store_data=store_data, - nested_loops=nested_loops, - vectorizable=True - ) - - -def create_params(op_features: OperationFeatures) -> tuple[list[str], list[str]]: - params = [] - shapes = [] - args_dict = {loop.arg: loop.upper_bound for loop in op_features.nested_loops} - - # Load params - for i, load in enumerate(op_features.load_data): - shape: list[int] = [] - for arg in load: - if arg == '0': - shape.append(random.randint(1, 5)) - continue - shape.append(args_dict[arg]) - # inputs[f'arg{i}'] = np.random.rand(*shape) * 100 - inputs[f'arg{i}'] = np.empty(shape) - params.append(f'%arg{i}') - shapes.append(f"memref<{'x'.join(map(str, shape))}xf64>") - - # Store param - shape = [] - for arg in op_features.store_data: - if arg == '0': - shape.append(random.randint(1, 5)) - continue - shape.append(args_dict[arg]) - # inputs[f'arg{len(params)}'] = np.zeros(shape) - inputs[f'arg{len(params)}'] = np.empty(shape) - params.append(f'%arg{len(params)}') - shapes.append(f"memref<{'x'.join(map(str, shape))}xf64>") - - return params, shapes - - -def create_raw_operation(op_features: OperationFeatures, params: list[str], shapes: list[str]) -> str: - # Affine maps - base_dims = ', '.join([loop.arg for loop in op_features.nested_loops]) - affine_maps = [] - for load in op_features.load_data: - affine_maps.append(f"affine_map<({base_dims}) -> ({', '.join(load)})>") - affine_maps.append(f"affine_map<({base_dims}) -> ({', '.join(op_features.store_data)})>") - affine_maps_attr = f"[{', '.join(affine_maps)}]" - - # Iterators - iterators = ', '.join([f'"{loop.iterator_type}"' for loop in op_features.nested_loops]) - iterators_attr = f'[{iterators}]' - - # Inputs / Outputs - ins = f"ins({', '.join(params[:-1])}: {', '.join(shapes[:-1])})" - outs = f"outs({params[-1]}: {shapes[-1]})" - - code = f"linalg.generic {{indexing_maps={affine_maps_attr}, iterator_types={iterators_attr}}} {ins} {outs} {{\n" - block_args = [f"%in_{i}: f64" for i in range(len(op_features.load_data))] + ["%out: f64"] - code += f"^bb0({', '.join(block_args)}):\n" - - # Linalg body - block_params = [arg.split(':')[0] for arg in block_args] - unused_block_params = set(block_params.copy()) - created_args: set[str] = set() - tmp_count = 0 - op_count_copy = {op: count for op, count in op_features.op_count.items() if count > 0} - assert all(op_count_copy.values()) - total_op_count = sum(op_count_copy.values()) - for _ in range(total_op_count): - op = random.choice(list(op_count_copy.keys())) - if op == 'exp': - if len(unused_block_params) > 0: - operands = random.sample(list(unused_block_params), 1) - unused_block_params.difference_update(operands) - else: - operands = random.sample(list(created_args) + block_params, 1) - else: - if len(unused_block_params) > 1: - operands = random.sample(list(unused_block_params), 2) - unused_block_params.difference_update(operands) - elif len(unused_block_params) == 1: - operands = [unused_block_params.pop()] - unused_block_params = set() - operands += random.sample(list(created_args) + block_params, 1) - else: - operands = random.sample(list(created_args) + block_params, 2) - - result = f"%{tmp_count}" - tmp_count += 1 - created_args.add(result) - match op: - case '+': - code += f"{result} = arith.addf {operands[0]}, {operands[1]} fastmath : f64\n" - case '-': - code += f"{result} = arith.subf {operands[0]}, {operands[1]} fastmath : f64\n" - case '*': - code += f"{result} = arith.mulf {operands[0]}, {operands[1]} fastmath : f64\n" - case '/': - code += f"{result} = arith.divf {operands[0]}, {operands[1]} fastmath : f64\n" - case 'exp': - code += f"{result} = math.exp {operands[0]} fastmath : f64\n" - - op_count_copy[op] -= 1 - if op_count_copy[op] == 0: - del op_count_copy[op] - - assert sum(op_count_copy.values()) == 0 - - code += f"linalg.yield {result} : f64\n" - code += "}\n" - - return code - - -def formatMLIRCode(code: str) -> str: - """Util function that format the MLIR code by adding indents. - - Args: - code (str): the MLIR code - - Returns: - str: the formatted MLIR code - """ - lines = re.sub(r'\n+', '\n', code).split('\n') - result = '' - indent = 0 - for line in lines: - if len(line) > 0: - if line[0] == '}': - if indent > 0: - indent -= 1 - else: - indent = 0 - - result += indent * ' ' + line + '\n' - - if len(line) > 0: - if line[-1] == '{': - indent += 1 - - return result - - -def gen_full_code() -> str: - op_features = gen_features() - - params, shapes = create_params(op_features) - main_params = [f'{param}: {shape}' for param, shape in zip(params, shapes)] - - raw_operation = create_raw_operation(op_features, params, shapes) - - code = ( - f'func.func private @nanoTime() -> i64 attributes {{ llvm.emit_c_interface }}\n' - f'func.func @main({", ".join(main_params)}) -> i64 attributes {{ llvm.emit_c_interface }} {{\n' - f'%t0 = func.call @nanoTime() : () -> i64\n' - f'{raw_operation}\n' - f'%t1 = func.call @nanoTime() : () -> i64\n' - f'%t2 = arith.subi %t1, %t0 : i64\n' - f'return %t2 : i64\n' - f'}}\n' - ) - - code = formatMLIRCode(code) - - return code - - -if __name__ == '__main__': - with open('execution_times.json', 'r') as file: - execution_times: dict[str, int] = json.load(file) - last_count = max([int(k.split('_')[-1]) for k in execution_times.keys()]) + 1 - for i in trange(last_count, 10000, desc='Generating benchmarks', unit='bench'): - bench_generated = False - while not bench_generated: - bench_name = f'generic_{i}' - bench_output = os.path.join(output_dir, f"{bench_name}.mlir") - - inputs = {} - code = gen_full_code() - with Context(): - module = Module.parse(code) - pm = PassManager.parse(pass_pipeline) - pm.run(module.operation) - execution_engine = ExecutionEngine( - module, - shared_libs=os.getenv("MLIR_SHARED_LIBS", "").split(","), - ) - arg_names = sorted(inputs.keys()) - # np.savez(f"{bench_output}.npz", **inputs) - - c_args = [] - for arg_name in arg_names: - c_args.append(ctypes.pointer(ctypes.pointer( - get_ranked_memref_descriptor(inputs[arg_name]) - ))) - delta_arg = (ctypes.c_int64 * 1)(0) - c_args.append(delta_arg) - - try: - execution_engine.invoke("main", *c_args) - execution_engine.invoke("main", *c_args) - except Exception as e: - print(f"Failed, Bench: {bench_name}, error: {e}", file=sys.stderr) - # os.remove(f'{bench_output}.npz') - continue - - exec_time = delta_arg[0] - if exec_time >= (1 * 10**9): - # os.remove(f'{bench_output}.npz') - continue - - with open(bench_output, 'w') as f: - f.write(code) - # expected = inputs[arg_names[-1]] - # np.save(f"{bench_output}.npy", expected) - - execution_times[bench_name] = exec_time - with open('execution_times.json', 'w') as file: - json.dump(execution_times, file, indent=4) - - bench_generated = True diff --git a/init_env.py b/init_env.py deleted file mode 100644 index 494c5a9..0000000 --- a/init_env.py +++ /dev/null @@ -1,42 +0,0 @@ -# Load environment variables -from dotenv import load_dotenv -load_dotenv(override=True) - - -import torch -import os -from typing import Optional -from utils.log import print_info, print_success - -# Import environment -from rl_autoschedular.env import Env - -# config, file_logger, device -from rl_autoschedular import config as cfg, file_logger as fl, device - -# Import RL components -from rl_autoschedular.model import HiearchyModel as Model -from rl_autoschedular.trajectory import TrajectoryData -from rl_autoschedular.ppo import ( - collect_trajectory, - ppo_update, - value_update, - evaluate_benchmark -) - - -torch.set_grad_enabled(False) -torch.set_num_threads(int(os.getenv("OMP_NUM_THREADS", "4"))) -if cfg.debug: - torch.autograd.set_detect_anomaly(True) - -print_info(f"Config: {cfg}") -print_success(f'Logging to: {fl.run_dir}') - -# Set environments -env = Env(is_training=True) -env.save_benchmarks_data_to_json("my_benchmarks.json") - -eval_env = Env(is_training=False, tmp_file=env.tmp_file) - -print_success(f"Environments initialized: {env.tmp_file}") \ No newline at end of file diff --git a/test.py b/test.py deleted file mode 100644 index ae80a2b..0000000 --- a/test.py +++ /dev/null @@ -1,11 +0,0 @@ -from torch.utils.tensorboard import SummaryWriter -from collections import defaultdict -from utils.data_collector import OfflineDataset - - -dt = OfflineDataset(save_dir="offline_dataset",fname="offline_dataset_online_ppo.npz") - - -a = dt.load() - -print(a[''][100:150]) \ No newline at end of file diff --git a/train_iql.py b/train_iql_offline.py similarity index 100% rename from train_iql.py rename to train_iql_offline.py diff --git a/iql_online.py b/train_iql_online.py similarity index 100% rename from iql_online.py rename to train_iql_online.py diff --git a/train.py b/train_ppo.py similarity index 100% rename from train.py rename to train_ppo.py