From 119b95972e8bd6249ec1ca158469edea6efba7d3 Mon Sep 17 00:00:00 2001 From: Arnav Kapoor Date: Mon, 2 Feb 2026 16:55:10 +0530 Subject: [PATCH 1/7] Refactor progress bar backend to allow user choice and decouple from rich - Create progress utility module with configurable backends (rich, tqdm, none) - Update config default to 'simple' and fix metadata - Make rich an optional dependency - Update all hard-coded rich imports to use the utility - Handle progress bar selection in tabular_model callbacks --- pyproject.toml | 2 +- src/pytorch_tabular/categorical_encoders.py | 5 +- src/pytorch_tabular/config/config.py | 4 +- src/pytorch_tabular/feature_extractor.py | 4 +- src/pytorch_tabular/tabular_model.py | 6 +- src/pytorch_tabular/tabular_model_sweep.py | 6 +- src/pytorch_tabular/tabular_model_tuner.py | 4 +- src/pytorch_tabular/utils/progress.py | 74 +++++++++++++++++++++ 8 files changed, 91 insertions(+), 14 deletions(-) create mode 100644 src/pytorch_tabular/utils/progress.py diff --git a/pyproject.toml b/pyproject.toml index c8debec0..94df992f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,6 @@ dependencies = [ "omegaconf>=2.3.0", "torchmetrics>=0.10.0,<1.9.0", "einops>=0.6.0,<0.9.0", - "rich>=11.0.0", "scikit-base", ] @@ -73,6 +72,7 @@ extra = [ "kaleido>=0.2.0,<0.3.0", "captum>=0.5.0,<0.8.0", "pytorch-tabnet<4.2", + "rich>=11.0.0", ] notebooks = [ diff --git a/src/pytorch_tabular/categorical_encoders.py b/src/pytorch_tabular/categorical_encoders.py index b3d7a1ee..e7f8d2e7 100644 --- a/src/pytorch_tabular/categorical_encoders.py +++ b/src/pytorch_tabular/categorical_encoders.py @@ -12,7 +12,7 @@ import pickle import numpy as np -from rich.progress import track +from pytorch_tabular.utils.progress import get_progress_tracker from sklearn.base import BaseEstimator, TransformerMixin from pytorch_tabular.utils import get_logger @@ -234,10 +234,9 @@ def transform(self, X: DataFrame, y=None) -> DataFrame: assert all(c in X.columns for c in self.cols) X_encoded = X.copy(deep=True) - for col, mapping in track( + for col, mapping in get_progress_tracker("rich")( self._mapping.items(), description="Encoding the data...", - total=len(self._mapping.values()), ): for dim in range(mapping[self.NAN_CATEGORY].shape[0]): X_encoded.loc[:, f"{col}_embed_dim_{dim}"] = ( diff --git a/src/pytorch_tabular/config/config.py b/src/pytorch_tabular/config/config.py index ad7bda85..6ddc775f 100644 --- a/src/pytorch_tabular/config/config.py +++ b/src/pytorch_tabular/config/config.py @@ -356,7 +356,7 @@ class TrainerConfig: track_grad_norm (int): Track and Log Gradient Norms in the logger. -1 by default means no tracking. 1 for the L1 norm, 2 for L2 norm, etc. - progress_bar (str): Progress bar type. Can be one of: `none`, `simple`, `rich`. Defaults to `rich`. + progress_bar (str): Progress bar type. Can be one of: `none`, `simple`, `rich`. Defaults to `simple`. precision (str): Precision of the model. Defaults to `32`. See https://lightning.ai/docs/pytorch/stable/common/trainer.html#precision @@ -541,7 +541,7 @@ class TrainerConfig: ) progress_bar: str = field( default="simple", - metadata={"help": "Progress bar type. Can be one of: `none`, `simple`, `rich`. Defaults to `rich`."}, + metadata={"help": "Progress bar type. Can be one of: `none`, `simple`, `rich`. Defaults to `simple`."}, ) precision: str = field( default="32", diff --git a/src/pytorch_tabular/feature_extractor.py b/src/pytorch_tabular/feature_extractor.py index 424a03e4..2ff64bda 100644 --- a/src/pytorch_tabular/feature_extractor.py +++ b/src/pytorch_tabular/feature_extractor.py @@ -4,7 +4,7 @@ from collections import defaultdict import pandas as pd -from rich.progress import track +from pytorch_tabular.utils.progress import get_progress_tracker from sklearn.base import BaseEstimator, TransformerMixin from pytorch_tabular.models import NODEModel, TabNetModel @@ -65,7 +65,7 @@ def transform(self, X: pd.DataFrame, y=None) -> pd.DataFrame: self.tabular_model.model.eval() inference_dataloader = self.tabular_model.datamodule.prepare_inference_dataloader(X_encoded) logits_predictions = defaultdict(list) - for batch in track(inference_dataloader, description="Generating Features..."): + for batch in get_progress_tracker("rich")(inference_dataloader, description="Generating Features..."): for k, v in batch.items(): if isinstance(v, list) and (len(v) == 0): # Skipping empty list diff --git a/src/pytorch_tabular/tabular_model.py b/src/pytorch_tabular/tabular_model.py index 5c402d1a..fd95d03c 100644 --- a/src/pytorch_tabular/tabular_model.py +++ b/src/pytorch_tabular/tabular_model.py @@ -25,7 +25,7 @@ from omegaconf.dictconfig import DictConfig from pandas import DataFrame from pytorch_lightning import seed_everything -from pytorch_lightning.callbacks import RichProgressBar +from pytorch_lightning.callbacks import RichProgressBar, TQDMProgressBar from pytorch_lightning.callbacks.gradient_accumulation_scheduler import ( GradientAccumulationScheduler, ) @@ -321,6 +321,10 @@ def _prepare_callbacks(self, callbacks=None) -> List: self.config.enable_checkpointing = False if self.config.progress_bar == "rich" and self.config.trainer_kwargs.get("enable_progress_bar", True): callbacks.append(RichProgressBar()) + elif self.config.progress_bar == "simple" and self.config.trainer_kwargs.get("enable_progress_bar", True): + callbacks.append(TQDMProgressBar()) + elif self.config.progress_bar == "none": + self.config.trainer_kwargs["enable_progress_bar"] = False if self.verbose: logger.debug(f"Callbacks used: {callbacks}") return callbacks diff --git a/src/pytorch_tabular/tabular_model_sweep.py b/src/pytorch_tabular/tabular_model_sweep.py index fc97140e..15cd43df 100644 --- a/src/pytorch_tabular/tabular_model_sweep.py +++ b/src/pytorch_tabular/tabular_model_sweep.py @@ -6,7 +6,7 @@ import numpy as np import pandas as pd -from rich.progress import Progress, track +from pytorch_tabular.utils.progress import get_progress_context, get_progress_tracker from skbase.utils.dependencies import _check_soft_dependencies from pytorch_tabular import TabularModel, models @@ -321,8 +321,8 @@ def _init_tabular_model(m): best_model = None is_lower_better = rank_metric[1] == "lower_is_better" best_score = 1e9 if is_lower_better else -1e9 - it = track(model_list, description="Sweeping Models") if progress_bar else model_list - ctx = Progress() if progress_bar else nullcontext() + it = get_progress_tracker("rich" if progress_bar else "none")(model_list, description="Sweeping Models") + ctx = get_progress_context("rich" if progress_bar else "none") with ctx as progress: if progress_bar: task_p = progress.add_task("Sweeping Models", total=len(model_list)) diff --git a/src/pytorch_tabular/tabular_model_tuner.py b/src/pytorch_tabular/tabular_model_tuner.py index d199d1fb..d4142486 100644 --- a/src/pytorch_tabular/tabular_model_tuner.py +++ b/src/pytorch_tabular/tabular_model_tuner.py @@ -13,7 +13,7 @@ import pandas as pd from omegaconf.dictconfig import DictConfig from pandas import DataFrame -from rich.progress import Progress +from pytorch_tabular.utils.progress import get_progress_context from sklearn.model_selection import BaseCrossValidator, ParameterGrid, ParameterSampler from pytorch_tabular.config import ( @@ -255,7 +255,7 @@ def tune( verbose_tabular_model = self.tabular_model_init_kwargs.pop("verbose", False) - with Progress() as progress: + with get_progress_context("rich" if progress_bar else "none") as progress: model_config_iterator = range(len(self.model_config)) if progress_bar: model_config_iterator = progress.track( diff --git a/src/pytorch_tabular/utils/progress.py b/src/pytorch_tabular/utils/progress.py new file mode 100644 index 00000000..9428537e --- /dev/null +++ b/src/pytorch_tabular/utils/progress.py @@ -0,0 +1,74 @@ +"""Progress bar utilities for PyTorch Tabular.""" + +from contextlib import nullcontext +from functools import partial +from typing import Any, Callable, Iterator, Optional + + +class DummyProgress: + """A dummy progress class that mimics rich.Progress but does nothing.""" + + def add_task(self, *args, **kwargs): + return None + + def update(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, *args): + pass + + def track(self, iterable, *args, **kwargs): + return iterable + + +def get_progress_tracker(backend: str = "rich", description: Optional[str] = None) -> Callable[[Iterator], Iterator]: + """Get a progress tracker function based on the backend. + + Args: + backend: The progress bar backend. Can be 'rich', 'tqdm', or 'none'. + description: Description for the progress bar. + + Returns: + A function that takes an iterable and returns an iterator with progress tracking. + """ + if backend == "rich": + try: + from rich.progress import track + return partial(track, description=description) if description else track + except ImportError: + # Fallback to none if rich is not available + return lambda it: it + elif backend == "tqdm": + try: + from tqdm.auto import tqdm + return partial(tqdm, desc=description) if description else tqdm + except ImportError: + return lambda it: it + else: # none + return lambda it: it + + +def get_progress_context(backend: str = "rich"): + """Get a progress context manager based on the backend. + + Args: + backend: The progress bar backend. Can be 'rich', 'tqdm', or 'none'. + + Returns: + A context manager for progress tracking that has a track method. + """ + if backend == "rich": + try: + from rich.progress import Progress + return Progress() + except ImportError: + return DummyProgress() + elif backend == "tqdm": + # tqdm doesn't have a context manager like rich's Progress + # For now, return DummyProgress + return DummyProgress() + else: + return DummyProgress() \ No newline at end of file From 3b31e4633686bd97a42d36ee88ff4ad43474b804 Mon Sep 17 00:00:00 2001 From: Arnav Kapoor Date: Wed, 4 Feb 2026 01:17:25 +0530 Subject: [PATCH 2/7] Fix hard-coded rich usage in progress tracking - Change internal utilities to use 'none' backend to avoid rich dependency - Update sweep and tuner to use 'simple' instead of 'rich' when progress enabled - Refactor predict method to use progress utility instead of manual backend checks - Ensure all progress usage is configurable and doesn't force rich dependency --- src/pytorch_tabular/categorical_encoders.py | 2 +- src/pytorch_tabular/feature_extractor.py | 2 +- src/pytorch_tabular/tabular_model.py | 17 ++++------------- src/pytorch_tabular/tabular_model_sweep.py | 4 ++-- src/pytorch_tabular/tabular_model_tuner.py | 2 +- 5 files changed, 9 insertions(+), 18 deletions(-) diff --git a/src/pytorch_tabular/categorical_encoders.py b/src/pytorch_tabular/categorical_encoders.py index e7f8d2e7..00a41735 100644 --- a/src/pytorch_tabular/categorical_encoders.py +++ b/src/pytorch_tabular/categorical_encoders.py @@ -234,7 +234,7 @@ def transform(self, X: DataFrame, y=None) -> DataFrame: assert all(c in X.columns for c in self.cols) X_encoded = X.copy(deep=True) - for col, mapping in get_progress_tracker("rich")( + for col, mapping in get_progress_tracker("none")( self._mapping.items(), description="Encoding the data...", ): diff --git a/src/pytorch_tabular/feature_extractor.py b/src/pytorch_tabular/feature_extractor.py index 2ff64bda..af51bdba 100644 --- a/src/pytorch_tabular/feature_extractor.py +++ b/src/pytorch_tabular/feature_extractor.py @@ -65,7 +65,7 @@ def transform(self, X: pd.DataFrame, y=None) -> pd.DataFrame: self.tabular_model.model.eval() inference_dataloader = self.tabular_model.datamodule.prepare_inference_dataloader(X_encoded) logits_predictions = defaultdict(list) - for batch in get_progress_tracker("rich")(inference_dataloader, description="Generating Features..."): + for batch in get_progress_tracker("none")(inference_dataloader, description="Generating Features..."): for k, v in batch.items(): if isinstance(v, list) and (len(v) == 0): # Skipping empty list diff --git a/src/pytorch_tabular/tabular_model.py b/src/pytorch_tabular/tabular_model.py index fd95d03c..bb927992 100644 --- a/src/pytorch_tabular/tabular_model.py +++ b/src/pytorch_tabular/tabular_model.py @@ -1234,13 +1234,13 @@ def _generate_predictions( quantiles, n_samples, ret_logits, - progress_bar, + progress_tracker, is_probabilistic, ): point_predictions = [] quantile_predictions = [] logits_predictions = defaultdict(list) - for batch in progress_bar(inference_dataloader): + for batch in progress_tracker(inference_dataloader): for k, v in batch.items(): if isinstance(v, list) and (len(v) == 0): continue # Skipping empty list @@ -1377,23 +1377,14 @@ def _predict( inference_dataloader = self.datamodule.prepare_inference_dataloader(test) is_probabilistic = hasattr(model.hparams, "_probabilistic") and model.hparams._probabilistic - if progress_bar == "rich": - from rich.progress import track - - progress_bar = partial(track, description="Generating Predictions...") - elif progress_bar == "tqdm": - from tqdm.auto import tqdm - - progress_bar = partial(tqdm, description="Generating Predictions...") - else: - progress_bar = lambda it: it # E731 + progress_tracker = get_progress_tracker(progress_bar or "none", description="Generating Predictions...") point_predictions, quantile_predictions, logits_predictions = self._generate_predictions( model, inference_dataloader, quantiles, n_samples, ret_logits, - progress_bar, + progress_tracker, is_probabilistic, ) pred_df = self._format_predicitons( diff --git a/src/pytorch_tabular/tabular_model_sweep.py b/src/pytorch_tabular/tabular_model_sweep.py index 15cd43df..5b654675 100644 --- a/src/pytorch_tabular/tabular_model_sweep.py +++ b/src/pytorch_tabular/tabular_model_sweep.py @@ -321,8 +321,8 @@ def _init_tabular_model(m): best_model = None is_lower_better = rank_metric[1] == "lower_is_better" best_score = 1e9 if is_lower_better else -1e9 - it = get_progress_tracker("rich" if progress_bar else "none")(model_list, description="Sweeping Models") - ctx = get_progress_context("rich" if progress_bar else "none") + it = get_progress_tracker("simple" if progress_bar else "none")(model_list, description="Sweeping Models") + ctx = get_progress_context("simple" if progress_bar else "none") with ctx as progress: if progress_bar: task_p = progress.add_task("Sweeping Models", total=len(model_list)) diff --git a/src/pytorch_tabular/tabular_model_tuner.py b/src/pytorch_tabular/tabular_model_tuner.py index d4142486..35984739 100644 --- a/src/pytorch_tabular/tabular_model_tuner.py +++ b/src/pytorch_tabular/tabular_model_tuner.py @@ -255,7 +255,7 @@ def tune( verbose_tabular_model = self.tabular_model_init_kwargs.pop("verbose", False) - with get_progress_context("rich" if progress_bar else "none") as progress: + with get_progress_context("simple" if progress_bar else "none") as progress: model_config_iterator = range(len(self.model_config)) if progress_bar: model_config_iterator = progress.track( From 2d55d22f2f0a43ba229951ed9f30738a0b6d067f Mon Sep 17 00:00:00 2001 From: Arnav Kapoor Date: Wed, 4 Feb 2026 01:19:29 +0530 Subject: [PATCH 3/7] Change progress utility defaults to 'none' to avoid rich dependency - Update get_progress_tracker and get_progress_context defaults from 'rich' to 'none' - Ensures no hard-coded rich usage in utility functions - Rich is now truly optional with graceful fallback --- src/pytorch_tabular/utils/progress.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pytorch_tabular/utils/progress.py b/src/pytorch_tabular/utils/progress.py index 9428537e..07fd209f 100644 --- a/src/pytorch_tabular/utils/progress.py +++ b/src/pytorch_tabular/utils/progress.py @@ -24,7 +24,7 @@ def track(self, iterable, *args, **kwargs): return iterable -def get_progress_tracker(backend: str = "rich", description: Optional[str] = None) -> Callable[[Iterator], Iterator]: +def get_progress_tracker(backend: str = "none", description: Optional[str] = None) -> Callable[[Iterator], Iterator]: """Get a progress tracker function based on the backend. Args: @@ -51,7 +51,7 @@ def get_progress_tracker(backend: str = "rich", description: Optional[str] = Non return lambda it: it -def get_progress_context(backend: str = "rich"): +def get_progress_context(backend: str = "none"): """Get a progress context manager based on the backend. Args: From b3c213b4ccff1a28d2e52813ce4566b5ee1fd458 Mon Sep 17 00:00:00 2001 From: Arnav Kapoor Date: Wed, 4 Feb 2026 01:20:06 +0530 Subject: [PATCH 4/7] Fix lambda functions to accept keyword arguments - Update fallback lambda functions to accept **kwargs to handle description and other parameters - Prevents TypeError when progress tracker is called with keyword arguments --- src/pytorch_tabular/utils/progress.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/pytorch_tabular/utils/progress.py b/src/pytorch_tabular/utils/progress.py index 07fd209f..14dc2e70 100644 --- a/src/pytorch_tabular/utils/progress.py +++ b/src/pytorch_tabular/utils/progress.py @@ -40,15 +40,15 @@ def get_progress_tracker(backend: str = "none", description: Optional[str] = Non return partial(track, description=description) if description else track except ImportError: # Fallback to none if rich is not available - return lambda it: it + return lambda it, **kwargs: it elif backend == "tqdm": try: from tqdm.auto import tqdm return partial(tqdm, desc=description) if description else tqdm except ImportError: - return lambda it: it + return lambda it, **kwargs: it else: # none - return lambda it: it + return lambda it, **kwargs: it def get_progress_context(backend: str = "none"): From 927262f36f3e64a691b1f4fc6469d038cbef5984 Mon Sep 17 00:00:00 2001 From: Arnav Kapoor Date: Wed, 4 Feb 2026 01:23:15 +0530 Subject: [PATCH 5/7] Add get_progress_bar_callback utility to abstract PyTorch Lightning callback selection - Create get_progress_bar_callback() function for choosing progress bar callbacks - Refactor _prepare_callbacks to use the utility instead of manual if/elif logic - Further decouples progress bar handling from direct Lightning imports --- src/pytorch_tabular/tabular_model.py | 12 ++++++----- src/pytorch_tabular/utils/progress.py | 31 +++++++++++++++------------ 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/src/pytorch_tabular/tabular_model.py b/src/pytorch_tabular/tabular_model.py index bb927992..0124cb4a 100644 --- a/src/pytorch_tabular/tabular_model.py +++ b/src/pytorch_tabular/tabular_model.py @@ -25,7 +25,7 @@ from omegaconf.dictconfig import DictConfig from pandas import DataFrame from pytorch_lightning import seed_everything -from pytorch_lightning.callbacks import RichProgressBar, TQDMProgressBar +from pytorch_tabular.utils.progress import get_progress_bar_callback from pytorch_lightning.callbacks.gradient_accumulation_scheduler import ( GradientAccumulationScheduler, ) @@ -319,10 +319,12 @@ def _prepare_callbacks(self, callbacks=None) -> List: self.config.enable_checkpointing = True else: self.config.enable_checkpointing = False - if self.config.progress_bar == "rich" and self.config.trainer_kwargs.get("enable_progress_bar", True): - callbacks.append(RichProgressBar()) - elif self.config.progress_bar == "simple" and self.config.trainer_kwargs.get("enable_progress_bar", True): - callbacks.append(TQDMProgressBar()) + progress_callback = get_progress_bar_callback( + self.config.progress_bar, + self.config.trainer_kwargs.get("enable_progress_bar", True) + ) + if progress_callback is not None: + callbacks.append(progress_callback) elif self.config.progress_bar == "none": self.config.trainer_kwargs["enable_progress_bar"] = False if self.verbose: diff --git a/src/pytorch_tabular/utils/progress.py b/src/pytorch_tabular/utils/progress.py index 14dc2e70..23b5bb31 100644 --- a/src/pytorch_tabular/utils/progress.py +++ b/src/pytorch_tabular/utils/progress.py @@ -51,24 +51,27 @@ def get_progress_tracker(backend: str = "none", description: Optional[str] = Non return lambda it, **kwargs: it -def get_progress_context(backend: str = "none"): - """Get a progress context manager based on the backend. +def get_progress_bar_callback(backend: str = "simple", enable_progress_bar: bool = True): + """Get the appropriate PyTorch Lightning progress bar callback based on backend. Args: - backend: The progress bar backend. Can be 'rich', 'tqdm', or 'none'. + backend: The progress bar backend. Can be 'rich', 'simple', or 'none'. + enable_progress_bar: Whether progress bar is enabled in trainer kwargs. Returns: - A context manager for progress tracking that has a track method. + A PyTorch Lightning callback or None. """ - if backend == "rich": + if backend == "rich" and enable_progress_bar: try: - from rich.progress import Progress - return Progress() + from pytorch_lightning.callbacks import RichProgressBar + return RichProgressBar() except ImportError: - return DummyProgress() - elif backend == "tqdm": - # tqdm doesn't have a context manager like rich's Progress - # For now, return DummyProgress - return DummyProgress() - else: - return DummyProgress() \ No newline at end of file + return None + elif backend == "simple" and enable_progress_bar: + try: + from pytorch_lightning.callbacks import TQDMProgressBar + return TQDMProgressBar() + except ImportError: + return None + else: # none + return None \ No newline at end of file From baaa564bab85b58d96afec3743d9c068b42bc172 Mon Sep 17 00:00:00 2001 From: Arnav Kapoor Date: Sat, 14 Feb 2026 12:38:56 +0530 Subject: [PATCH 6/7] bug fix --- src/pytorch_tabular/tabular_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_tabular/tabular_model.py b/src/pytorch_tabular/tabular_model.py index 0124cb4a..6938e924 100644 --- a/src/pytorch_tabular/tabular_model.py +++ b/src/pytorch_tabular/tabular_model.py @@ -1498,7 +1498,7 @@ def add_noise(module, input, output): ret_logits, include_input_features=False, device=device, - progress_bar=progress_bar or "None", + progress_bar=progress_bar or "none", ) pred_idx = pred_df.index if self.config.task == "classification": From d19e14bd8fc5028e97daed46f47558d96bf7cad4 Mon Sep 17 00:00:00 2001 From: Arnav Kapoor Date: Mon, 16 Feb 2026 21:43:54 +0530 Subject: [PATCH 7/7] Added back the get_progress_context function to progress.py --- src/pytorch_tabular/utils/progress.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/src/pytorch_tabular/utils/progress.py b/src/pytorch_tabular/utils/progress.py index 23b5bb31..f098deb4 100644 --- a/src/pytorch_tabular/utils/progress.py +++ b/src/pytorch_tabular/utils/progress.py @@ -51,6 +51,29 @@ def get_progress_tracker(backend: str = "none", description: Optional[str] = Non return lambda it, **kwargs: it +def get_progress_context(backend: str = "none"): + """Get a progress context manager based on the backend. + + Args: + backend: The progress bar backend. Can be 'rich', 'tqdm', or 'none'. + + Returns: + A context manager for progress tracking that has a track method. + """ + if backend == "rich": + try: + from rich.progress import Progress + return Progress() + except ImportError: + return DummyProgress() + elif backend == "tqdm": + # tqdm doesn't have a context manager like rich's Progress + # For now, return DummyProgress + return DummyProgress() + else: + return DummyProgress() + + def get_progress_bar_callback(backend: str = "simple", enable_progress_bar: bool = True): """Get the appropriate PyTorch Lightning progress bar callback based on backend.