Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions examples/regression_selection_synthetic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright (c) 2023-present, SUSTech-ML.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#


import torch
import torch.nn as nn
import torch.optim as optim

from examples.regression_cqr_synthetic import prepare_dataset
from torchcp.regression.utils import build_regression_model
from torchcp.selection.score import RES
from torchcp.selection.selector import ConformalSelector
from torchcp.selection.testing_correction import BH_procedure


# get dataloader
train_loader, cal_loader, test_loader = prepare_dataset(train_ratio=0.4, cal_ratio=0.2, batch_size=128)

# build regression model
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = build_regression_model("NonLinearNet")(next(iter(train_loader))[0].shape[1], 1, 64, 0.5).to(device)

# train model
epochs = 100
criterion = nn.MSELoss()
lr = 0.01
optimizer = optim.Adam(model.parameters(), lr=lr)

model.train()
# BUG FIX: the original code made a single pass over the training data and
# left `epochs` unused; train for the configured number of epochs instead.
for _ in range(epochs):
    for tmp_x, tmp_y in train_loader:
        outputs = model(tmp_x.to(device))
        loss = criterion(outputs, tmp_y.reshape(-1, 1).to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Conformal Selection: screen test points whose (unobserved) outcome is
# hoped to exceed the user-specified threshold (here, 5 for every point).
thresholds = torch.ones(len(test_loader.dataset)) * 5

selector = ConformalSelector(score_function=RES(), testing_correction=BH_procedure(), model=model)
selector.calibrate(cal_loader)
print(selector.select(test_loader, thresholds))
1 change: 0 additions & 1 deletion torchcp/classification/loss/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,4 @@
from .cd import CDLoss
from .conftr import ConfTrLoss
from .confts import ConfTSLoss
from .uncertainty_aware import UncertaintyAwareLoss
from .scpo import SCPOLoss
1 change: 0 additions & 1 deletion torchcp/classification/trainer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,5 @@
from .confts_trainer import ConfTSTrainer
from .model_zoo import TemperatureScalingModel
from .ts_trainer import TSTrainer
from .ua_trainer import UncertaintyAwareTrainer
from .ordinal_trainer import OrdinalTrainer
from .scpo_trainer import SCPOTrainer
2 changes: 1 addition & 1 deletion torchcp/regression/predictor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
from .ensemble import EnsemblePredictor
from .split import SplitPredictor
from .agaci import AgACIPredictor
from .cpd import ConformalPredictiveDistribution
from .cpd import ConformalPredictiveDistribution
2 changes: 1 addition & 1 deletion torchcp/regression/score/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
from .cqrm import CQRM
from .cqrr import CQRR
from .r2ccp import R2CCP
from .sign import Sign
from .sign import Sign
Empty file added torchcp/selection/__init__.py
Empty file.
10 changes: 10 additions & 0 deletions torchcp/selection/score/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) 2023-present, SUSTech-ML.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#


from .clip import CLIP
from .res import RES
21 changes: 21 additions & 0 deletions torchcp/selection/score/clip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright (c) 2023-present, SUSTech-ML.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#


import torch
from torchcp.regression.score.base import BaseScore


class CLIP(BaseScore):
    """
    CLIP (clipped) score for conformal selection (Jin et al., 2023); only
    applies to binary classification, where ``y_truth`` takes values in {0, 1}.

    paper: https://arxiv.org/pdf/2210.01408
    """

    def __call__(self, predicts, y_truth, M=100):
        """
        Compute the clipped conformity score ``M * y - mu(x)``.

        Args:
            predicts (torch.Tensor): model outputs mu(x), shape (n,) or (n, 1).
            y_truth (torch.Tensor): binary labels (calibration time) or the
                hypothesized values / thresholds (test time).
            M (int, optional): clipping constant; should dominate the range of
                ``predicts``. Default is 100.

        Returns:
            torch.Tensor: a 1D tensor of scores.
        """
        if len(predicts.shape) == 2:
            predicts = predicts.squeeze().view(-1)
        # BUG FIX: `torch.max(predicts, 0)` returns a (values, indices) named
        # tuple, so `M * torch.max(predicts, 0)` raised a TypeError.  The
        # clipped score of the paper is M * 1{y > c} - mu(x), which for binary
        # labels with c = 0 reduces to M * y - mu(x).
        # NOTE(review): verify against the reference implementation that the
        # intended cutoff is c = 0.
        return M * y_truth - predicts
20 changes: 20 additions & 0 deletions torchcp/selection/score/res.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright (c) 2023-present, SUSTech-ML.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#


from torchcp.regression.score.base import BaseScore


class RES(BaseScore):
    """
    RES (residual) score for conformal selection (Jin et al., 2023).

    paper: https://arxiv.org/pdf/2210.01408
    """

    def __call__(self, predicts, y_truth):
        # Flatten (n, 1)-shaped model outputs to (n,) so the residual
        # broadcasts cleanly against a 1D target tensor.
        if predicts.dim() == 2:
            predicts = predicts.flatten()
        return y_truth - predicts
8 changes: 8 additions & 0 deletions torchcp/selection/selector/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Copyright (c) 2023-present, SUSTech-ML.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

from .conformal_selector import ConformalSelector
100 changes: 100 additions & 0 deletions torchcp/selection/selector/conformal_selector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright (c) 2023-present, SUSTech-ML.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#


import torch

from torchcp.regression.predictor.split import SplitPredictor
from torchcp.selection.utils.metrics import Metrics


class ConformalSelector(SplitPredictor):
    """
    Conformal Selection:
    a screening procedure that aims to select candidates whose unobserved outcomes exceed user-specified values.

    Args:
        score_function (torchcp.selection.score): A class that implements the score function (e.g. RES, CLIP).
        testing_correction (callable): A multiple-testing correction procedure (e.g. BH_procedure) that maps a
            tensor of p-values and a target FDR level to the indices of rejected hypotheses.
        model (torch.nn.Module): A trained PyTorch model producing point predictions.
        alpha (float, optional): The target false discovery rate (FDR) level. Default is 0.1.
        device (torch.device, optional): The device on which the model is located. Default is None.

    Reference:
        Paper: Selection by Prediction with Conformal p-values (Jin et al., 2023)
        Link: https://arxiv.org/pdf/2210.01408
        Github: https://github.com/ying531/conformal-selection
    """

    def __init__(self, score_function, testing_correction, model, alpha=0.1, device=None):
        super().__init__(score_function, model, alpha, device)
        self.testing_correction = testing_correction
        self._metric = Metrics()

    def calibrate(self, cal_dataloader):
        """
        Compute and cache conformity scores on the calibration set.

        Args:
            cal_dataloader (DataLoader): DataLoader yielding (inputs, labels) calibration batches.
        """
        self._model.eval()
        predicts_list, y_truth_list = [], []
        with torch.no_grad():
            for tmp_x, tmp_labels in cal_dataloader:
                tmp_x, tmp_labels = tmp_x.to(self._device), tmp_labels.to(self._device)
                tmp_predicts = self._model(tmp_x).detach()
                predicts_list.append(tmp_predicts)
                y_truth_list.append(tmp_labels)

        predicts = torch.cat(predicts_list).float().to(self._device)
        y_truth = torch.cat(y_truth_list).to(self._device)
        self.cal_scores = self.score_function(predicts, y_truth)

    def select(self, data_loader, thresholds):
        """
        Evaluate the performance of conformal selection on a test dataset by calculating false discovery proportion
        (FDP) and power of the selection set.

        Args:
            data_loader (DataLoader): The DataLoader providing the test data batches.
            thresholds (torch.Tensor): A tensor of user-defined thresholds, one per test point.

        Returns:
            dict: A dictionary containing:
                - "false_discovery_proportion": The FDP of the selection set.
                - "power": The power of the selection set.

        Example::

            >>> eval_results = selector.select(test_loader, thresholds)
            >>> print(eval_results)
        """
        self._model.eval()
        y_truth_list = []
        predicts_list = []
        with torch.no_grad():
            for examples in data_loader:
                tmp_x, tmp_labels = examples[0].to(self._device), examples[1].to(self._device)
                tmp_predicts = self._model(tmp_x).detach()
                predicts_list.append(tmp_predicts)
                y_truth_list.append(tmp_labels)
        predicts = torch.cat(predicts_list).float().to(self._device)
        y_truth = torch.cat(y_truth_list).to(self._device)
        # BUG FIX: user-supplied thresholds typically live on the CPU while the
        # predictions are on the model's device; align them first.
        thresholds = thresholds.to(self._device)
        # Test scores are computed at the hypothesized values (the thresholds).
        scores = self.score_function(predicts, thresholds)

        n_cal, n_test = self.cal_scores.shape[0], scores.shape[0]

        # Compute conformal p-values with uniform tie-breaking; the "+ 1" in
        # count_tie counts the test point itself among the ties.
        # BUG FIX: draw the tie-breaking uniforms on the same device as the
        # scores; `torch.rand(n_test)` defaulted to CPU and crashed on CUDA.
        u = torch.rand(n_test, device=self._device)
        count_less = (self.cal_scores.view(1, n_cal) < scores.view(n_test, 1)).sum(dim=1)
        count_tie = (self.cal_scores.view(1, n_cal) == scores.view(n_test, 1)).sum(dim=1) + 1
        p_values = (count_less + count_tie * u) / (n_cal + 1)

        # Multiple-testing correction (e.g. BH) controls the FDR at self.alpha.
        indices = self.testing_correction(p_values, self.alpha)

        # Evaluation
        res_dict = {"false_discovery_proportion": self._metric("false_discovery_proportion")(y_truth, thresholds,
                                                                                             indices),
                    "power": self._metric("power")(y_truth, thresholds, indices)}
        return res_dict
10 changes: 10 additions & 0 deletions torchcp/selection/testing_correction/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) 2023-present, SUSTech-ML.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#


from .base import Base
from .bh_procedure import BH_procedure
27 changes: 27 additions & 0 deletions torchcp/selection/testing_correction/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright (c) 2023-present, SUSTech-ML.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#


from abc import ABCMeta, abstractmethod
import torch
from tqdm import tqdm

from torchcp.utils.common import get_device


# BUG FIX: the original declared `__metaclass__ = ABCMeta` inside the class
# body, which is Python 2 syntax and is silently ignored in Python 3 — the
# @abstractmethod decorator was therefore never enforced.  Use the Python 3
# `metaclass=` keyword instead.
class Base(metaclass=ABCMeta):
    """
    Abstract base class for all multiple testing correction algorithms.

    Subclasses must implement :meth:`__call__`, which maps a tensor of
    p-values and a target FDR level to the indices of rejected hypotheses.
    """

    def __init__(self) -> None:
        pass

    @abstractmethod
    def __call__(self, p_values, alpha):
        """Return the indices of the hypotheses rejected at level ``alpha``."""
        raise NotImplementedError
49 changes: 49 additions & 0 deletions torchcp/selection/testing_correction/bh_procedure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright (c) 2023-present, SUSTech-ML.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#


import torch

from torchcp.regression.score.base import BaseScore
from torchcp.selection.testing_correction.base import Base


# CONSISTENCY FIX: inherit the package's own multiple-testing-correction ABC
# (torchcp.selection.testing_correction.base.Base) instead of the regression
# score base class, which models conformity scores, not corrections.
class BH_procedure(Base):
    """
    Benjamini-Hochberg (BH) procedure:
    finds a p-value threshold from a list of p-values to determine which null hypotheses to reject, given a target
    FDR level 'alpha'.

    References:
        Paper: Controlling the False Discovery Rate: A Practical and Powerful Approach to Multiple Testing
        (Benjamini and Hochberg, 1995)
        Link: https://www.jstor.org/stable/2346101
    """

    def __init__(self):
        super().__init__()

    def __call__(self, p_values, alpha):
        """
        Apply the Benjamini-Hochberg procedure.

        Args:
            p_values (torch.Tensor): A 1D tensor of p-values.
            alpha (float): The desired False Discovery Rate (FDR) level (e.g., 0.1).

        Returns:
            torch.Tensor: A 1D tensor of indices corresponding to the p-values (hypotheses) that are rejected.
        """
        p_values_sorted, _ = torch.sort(p_values)
        n_test = p_values_sorted.shape[0]

        # k* = max { k : p_(k) <= k * alpha / n }; reject the k* smallest p-values.
        k_range = torch.arange(1, n_test + 1, device=p_values_sorted.device)
        thresholds = k_range * alpha / n_test
        mask = p_values_sorted <= thresholds
        k_star = torch.max(torch.where(mask, k_range, torch.zeros_like(k_range))) if mask.any() else 0
        threshold = (k_star * alpha / n_test) if k_star > 0 else 0
        # BUG FIX: `.squeeze()` produced a 0-dim tensor when exactly one
        # hypothesis was rejected; `as_tuple=True` indexing always yields a
        # 1D tensor, matching the documented return contract.
        indices = torch.nonzero(p_values <= threshold, as_tuple=True)[0]

        return indices
Empty file.
64 changes: 64 additions & 0 deletions torchcp/selection/utils/metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright (c) 2023-present, SUSTech-ML.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

from typing import Any

import torch

from torchcp.utils.registry import Registry

METRICS_REGISTRY_REGRESSION = Registry("METRICS")


@METRICS_REGISTRY_REGRESSION.register()
def false_discovery_proportion(y_truth, thresholds, indices):
    """
    Compute the false discovery proportion (the proportion of false discoveries among all selected points) of the
    selection set.

    Args:
        y_truth (torch.Tensor): A tensor of ground truth values.
        thresholds (torch.Tensor): Tensor of user-defined thresholds.
        indices (torch.Tensor): A tensor containing the indices of selected points.

    Returns:
        float: The false discovery proportion of the selection set (0.0 when nothing is selected).
    """
    # A 0-dim index tensor (single selection) cannot index a batch dimension;
    # promote it to 1D first.
    if indices.dim() == 0:
        indices = indices.unsqueeze(0)

    n_selected = indices.shape[-1]
    if n_selected == 0:
        return 0.0
    # A selection is a false discovery when its true outcome does not exceed
    # the user-specified threshold.
    false_positives = torch.sum(y_truth[indices] <= thresholds[indices])
    return (false_positives / n_selected).item()


@METRICS_REGISTRY_REGRESSION.register()
def power(y_truth, thresholds, indices):
    """
    Compute the power (the proportion of desirable points that are correctly selected) of the selection set.

    Args:
        y_truth (torch.Tensor): A tensor of ground truth values.
        thresholds (torch.Tensor): Tensor of user-defined thresholds.
        indices (torch.Tensor): A tensor containing the indices of selected points.

    Returns:
        float: The power of the selection set (0.0 when no point exceeds its threshold).
    """
    # A 0-dim index tensor (single selection) cannot index a batch dimension;
    # promote it to 1D first.
    if indices.dim() == 0:
        indices = indices.unsqueeze(0)

    n_desirable = torch.sum(y_truth > thresholds)
    # BUG FIX: the original divided unconditionally, which yields NaN when no
    # point exceeds its threshold; define power as 0.0 in that case.
    if n_desirable == 0:
        return 0.0
    true_positives = torch.sum(y_truth[indices] > thresholds[indices])
    return (true_positives / n_desirable).item()


class Metrics:
    """Lookup helper that resolves a metric name to its registered function."""

    def __call__(self, metric) -> Any:
        if metric in METRICS_REGISTRY_REGRESSION.registered_names():
            return METRICS_REGISTRY_REGRESSION.get(metric)
        raise NameError(f"The metric: {metric} is not defined in TorchCP.")
Empty file added torchcp/utils/metrics.py
Empty file.