Source code for ugnn.conformal

from typing import Literal

import numpy as np
import torch
from torch import Tensor
from torch_geometric.data import Data


def get_prediction_sets(
    output: Tensor,
    data: Data,
    calib_mask: Tensor | np.ndarray,
    test_mask: Tensor | np.ndarray,
    score_function: Literal["APS", "RAPS", "SAPS", "THR"] = "APS",
    alpha: float = 0.1,
    kreg: int = 1,
):
    """
    Computes prediction sets for a given model's output using conformal prediction.

    This function uses data points from the calibration set to compute a
    non-conformity score, which measures how "strange" a data-label pair is.
    It then calculates a quantile of the non-conformity scores and uses it to
    form prediction sets for the test set.

    Args:
        output : Tensor
            The model's output, typically softmax probabilities or logits, for
            all nodes or samples.
        data : Data
            The dataset object containing features, labels, and other
            graph-related information.
        calib_mask : Tensor
            A boolean mask indicating which samples belong to the calibration set.
        test_mask : Tensor
            A boolean mask indicating which samples belong to the test set.
        score_function : Literal["APS", "RAPS", "SAPS", "THR"], optional
            The scoring function to use for conformal prediction. Options include:
            - "APS": Adaptive Prediction Sets.
            - "RAPS": Regularized Adaptive Prediction Sets.
            - "SAPS": Smoothed Adaptive Prediction Sets.
            - "THR": Threshold-based Prediction Sets.
            Default is "APS".
        alpha : float, optional
            The miscoverage level for conformal prediction. Determines the
            coverage level (e.g., alpha=0.1 corresponds to 90% coverage).
            Default is 0.1.
        kreg : int, optional
            Regularization parameter used by the "RAPS" scoring function.
            Default is 1.

    Returns:
        prediction_sets : np.ndarray
            A boolean array of shape (num_test_samples, num_classes). Entry
            [i, j] is True if class j is included in the prediction set of the
            i-th test sample.

    Notes:
        - For the "RAPS" and "SAPS" scoring functions, the calibration set is
          further split into a validation set (20%) and a calibration set (80%)
          to tune the score's hyperparameter.
        - The function assumes that the calibration and test masks are disjoint.

    References:
        - "Conformal Prediction for Reliable Machine Learning: Theory and
          Applications" (2023), https://arxiv.org/pdf/2310.06430
    """
    if isinstance(calib_mask, np.ndarray):
        calib_mask = torch.tensor(calib_mask, dtype=torch.bool)
    if isinstance(test_mask, np.ndarray):
        test_mask = torch.tensor(test_mask, dtype=torch.bool)

    # Some scores require the choice of a hyperparameter.
    # Following https://arxiv.org/pdf/2310.06430, split the calibration set into
    # a validation part and a calibration part to choose the hyperparameter.
    # The above authors do this with a 20:80 split.
    initial_mask_size = torch.sum(calib_mask).item()

    if score_function in ["RAPS", "SAPS"]:
        # Split the calibration set into a calibration and a validation set
        n_calib = torch.sum(calib_mask).item()
        n_valid = int(0.2 * n_calib)

        # Get indices of the calibration set and shuffle them
        calib_indices = np.where(calib_mask)[0]
        np.random.shuffle(calib_indices)

        # Split into validation and calibration indices
        calib_valid_indices = calib_indices[:n_valid]
        calib_indices = calib_indices[n_valid:]

        # Create new masks
        calib_valid_mask = torch.zeros_like(calib_mask, dtype=torch.bool)
        calib_valid_mask[calib_valid_indices] = True
        calib_mask = torch.zeros_like(calib_mask, dtype=torch.bool)
        calib_mask[calib_indices] = True

        # Compute softmax probabilities
        n_calib = calib_mask.sum().item()
        n_calib_valid = calib_valid_mask.sum().item()
        smx = torch.nn.Softmax(dim=1)
        calib_heuristic = smx(output[calib_mask])
        calib_valid_heuristic = smx(output[calib_valid_mask])
        test_heuristic = smx(output[test_mask]).detach().numpy()

        assert (
            torch.sum(calib_mask).item() < initial_mask_size
        ), "Calibration mask not reduced"
    else:
        # Compute softmax probabilities
        n_calib = calib_mask.sum().item()
        smx = torch.nn.Softmax(dim=1)
        calib_heuristic = smx(output[calib_mask])
        test_heuristic = smx(output[test_mask]).detach().numpy()

    if score_function == "APS":
        calib_scores = (
            APS_scores(probs=calib_heuristic, label=data.y[calib_mask]).detach().numpy()
        )
    elif score_function == "THR":
        calib_scores = (
            THR_scores(probs=calib_heuristic, label=data.y[calib_mask]).detach().numpy()
        )
    elif score_function == "RAPS":
        # Tune the penalty on the validation split by minimising average set size
        pen_to_try = np.array([0, 0.00001, 0.0001, 0.001, 0.01, 0.1, 0.5])
        best_param = 0
        best_size = np.unique(data.y[calib_valid_mask]).shape[0]
        for pen in pen_to_try:
            calib_valid_scores = (
                RAPS_scores(
                    probs=calib_valid_heuristic,
                    label=data.y[calib_valid_mask],
                    penalty=pen,
                    kreg=kreg,
                )
                .detach()
                .numpy()
            )
            # Evaluate the resulting average prediction-set size
            qhat = np.quantile(
                calib_valid_scores,
                np.ceil((n_calib_valid + 1) * (1 - alpha)) / n_calib_valid,
                method="higher",
            )
            test_pi = test_heuristic.argsort(1)[:, ::-1]
            test_srt = np.take_along_axis(test_heuristic, test_pi, axis=1).cumsum(
                axis=1
            )
            pred_sets = np.take_along_axis(
                test_srt <= qhat, test_pi.argsort(axis=1), axis=1
            )
            avg_size = np.mean(np.sum(pred_sets, axis=1))
            if avg_size < best_size:
                best_param = pen
                best_size = avg_size

        calib_scores = (
            RAPS_scores(
                probs=calib_heuristic,
                label=data.y[calib_mask],
                penalty=best_param,
                kreg=kreg,
            )
            .detach()
            .numpy()
        )
    elif score_function == "SAPS":
        # Tune the weight on the validation split by minimising average set size
        wt_to_try = np.array(
            [0.02, 0.05, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.7, 0.9, 2, 5]
        )
        best_param = wt_to_try[0]
        best_size = np.unique(data.y[calib_valid_mask]).shape[0]
        for wt in wt_to_try:
            calib_valid_scores = (
                SAPS_scores(
                    probs=calib_valid_heuristic,
                    label=data.y[calib_valid_mask],
                    weight=wt,
                )
                .detach()
                .numpy()
            )
            # Evaluate the resulting average prediction-set size
            qhat = np.quantile(
                calib_valid_scores,
                np.ceil((n_calib_valid + 1) * (1 - alpha)) / n_calib_valid,
                method="higher",
            )
            test_pi = test_heuristic.argsort(1)[:, ::-1]
            test_srt = np.take_along_axis(test_heuristic, test_pi, axis=1).cumsum(
                axis=1
            )
            pred_sets = np.take_along_axis(
                test_srt <= qhat, test_pi.argsort(axis=1), axis=1
            )
            avg_size = np.mean(np.sum(pred_sets, axis=1))
            if avg_size < best_size:
                best_param = wt
                best_size = avg_size

        calib_scores = (
            SAPS_scores(
                probs=calib_heuristic,
                label=data.y[calib_mask],
                weight=best_param,
            )
            .detach()
            .numpy()
        )
    else:
        raise ValueError(f"Unknown method: {score_function}")

    # Get the score quantile
    qhat_quantile = np.ceil((n_calib + 1) * (1 - alpha)) / n_calib
    if qhat_quantile > 1:
        raise ValueError(
            "Specified quantile is larger than 1. Either increase the number of "
            "calibration data points or increase alpha."
        )
    qhat = np.quantile(calib_scores, qhat_quantile, method="higher")

    # Form prediction sets from the sorted cumulative probabilities and the
    # calibrated threshold qhat
    test_pi = test_heuristic.argsort(1)[:, ::-1]
    test_srt = np.take_along_axis(test_heuristic, test_pi, axis=1).cumsum(axis=1)
    pred_sets = np.take_along_axis(test_srt <= qhat, test_pi.argsort(axis=1), axis=1)

    return pred_sets
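# ----------------------------------------------------------------------------
# Calling-convention sketch for get_prediction_sets (illustrative comment only;
# it assumes `output` is a [num_nodes, num_classes] tensor of logits from a
# trained model and that `data.y` holds integer class labels):
#
#   pred_sets = get_prediction_sets(
#       output, data, calib_mask, test_mask, score_function="APS", alpha=0.1
#   )
#   # pred_sets is a boolean (num_test, num_classes) array; row i marks the
#   # classes included in the prediction set of the i-th test sample.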
def _sort_sum(probs):
    # ordered: the probabilities sorted in descending order
    # indices: the original class index of each sorted probability
    # cumsum: the cumulative sum of the sorted probabilities
    ordered, indices = torch.sort(probs, dim=-1, descending=True)
    cumsum = torch.cumsum(ordered, dim=-1)
    return indices, ordered, cumsum
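# Small worked example of _sort_sum (illustrative comment only). For a single
# row of probabilities [0.1, 0.6, 0.3]:
#   ordered = [0.6, 0.3, 0.1]   # probabilities sorted in descending order
#   indices = [1, 2, 0]         # original class index of each sorted entry
#   cumsum  = [0.6, 0.9, 1.0]   # running total of the sorted probabilities
#
#   indices, ordered, cumsum = _sort_sum(torch.tensor([[0.1, 0.6, 0.3]]))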
def APS_scores(probs, label):
    """Adaptive Prediction Sets (APS) non-conformity scores of the true labels."""
    indices, ordered, cumsum = _sort_sum(probs)
    # U is drawn but unused: the randomised variants are kept as comments and
    # the deterministic scores are used below
    U = torch.rand(indices.shape[0], device=probs.device)
    # Position (rank) of the true label in each row of sorted probabilities
    idx = torch.where(indices == label.view(-1, 1))
    # scores_first_rank = U * cumsum[idx]
    scores_first_rank = cumsum[idx]
    idx_minus_one = (idx[0], idx[1] - 1)
    # scores_usual = U * ordered[idx] + cumsum[idx_minus_one]
    scores_usual = ordered[idx] + cumsum[idx_minus_one]
    return torch.where(idx[1] == 0, scores_first_rank, scores_usual)
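# Worked example for APS_scores (illustrative comment only), using the
# deterministic variant implemented above. For a row [0.5, 0.3, 0.2]:
#   true label 0 (rank 0): score = cumulative probability up to the label = 0.5
#   true label 1 (rank 1): score = ordered[rank] + cumsum[rank - 1] = 0.3 + 0.5 = 0.8
#   true label 2 (rank 2): score = 0.2 + 0.8 = 1.0
#
#   probs = torch.tensor([[0.5, 0.3, 0.2]])
#   APS_scores(probs, torch.tensor([1]))  # -> tensor([0.8000])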
def RAPS_scores(probs, label, penalty, kreg):
    """Regularized Adaptive Prediction Sets (RAPS) non-conformity scores of the true labels."""
    indices, ordered, cumsum = _sort_sum(probs)
    # U is drawn but unused: the randomised variants are kept as comments and
    # the deterministic scores are used below
    U = torch.rand(indices.shape[0], device=probs.device)
    idx = torch.where(indices == label.view(-1, 1))
    # Rank-based regularisation: labels ranked below the kreg-th position pick
    # up a linear penalty on top of the APS score
    reg = torch.maximum(penalty * (idx[1] + 1 - kreg), torch.tensor(0).to(probs.device))
    # scores_first_rank = U * ordered[idx] + reg
    scores_first_rank = ordered[idx] + reg
    idx_minus_one = (idx[0], idx[1] - 1)
    # scores_usual = U * ordered[idx] + cumsum[idx_minus_one] + reg
    scores_usual = ordered[idx] + cumsum[idx_minus_one] + reg
    return torch.where(idx[1] == 0, scores_first_rank, scores_usual)
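# Worked example for RAPS_scores (illustrative comment only). With penalty=0.1
# and kreg=1, for a row [0.5, 0.3, 0.2]:
#   true label 1 (rank 1): reg = max(0.1 * (2 - 1), 0) = 0.1
#                          score = 0.3 + 0.5 + 0.1 = 0.9
#   true label 2 (rank 2): reg = max(0.1 * (3 - 1), 0) = 0.2
#                          score = 0.2 + 0.8 + 0.2 = 1.2
#
#   probs = torch.tensor([[0.5, 0.3, 0.2]])
#   RAPS_scores(probs, torch.tensor([2]), penalty=0.1, kreg=1)  # -> tensor([1.2000])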
def THR_scores(probs, label):
    """Threshold (THR) non-conformity scores: one minus the probability of the true label."""
    return 1 - probs[torch.arange(probs.shape[0], device=probs.device), label]
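# Worked example for THR_scores (illustrative comment only): the score is one
# minus the probability assigned to the true label, so for a row [0.5, 0.3, 0.2]
# with true label 1 the score is 1 - 0.3 = 0.7.
#
#   THR_scores(torch.tensor([[0.5, 0.3, 0.2]]), torch.tensor([1]))  # -> tensor([0.7000])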
def SAPS_scores(probs, label, weight):
    """SAPS non-conformity scores of the true labels, weighted by label rank."""
    if weight <= 0:
        raise ValueError("The parameter 'weight' must be a positive value.")
    indices, ordered, cumsum = _sort_sum(probs)
    # The randomised variants are kept as comments; the deterministic scores
    # are used below
    # U = torch.rand(indices.shape[0], device=probs.device)
    idx = torch.where(indices == label.view(-1, 1))
    # scores_first_rank = U * cumsum[idx]
    scores_first_rank = cumsum[idx]
    # scores_usual = weight * (idx[1] - U) + ordered[:, 0]
    scores_usual = weight * idx[1] + ordered[:, 0]
    return torch.where(idx[1] == 0, scores_first_rank, scores_usual)
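# ----------------------------------------------------------------------------
# Minimal end-to-end sketch (illustrative only, not part of the module API).
# It assumes the model output is a [num_nodes, num_classes] tensor of logits
# and that data.y holds integer class labels; the random tensors below stand
# in for a trained GNN and a real graph.
if __name__ == "__main__":
    num_nodes, num_classes = 200, 4
    gen = torch.Generator().manual_seed(0)
    output = torch.randn(num_nodes, num_classes, generator=gen)  # surrogate logits
    data = Data(y=torch.randint(0, num_classes, (num_nodes,), generator=gen))

    # Disjoint calibration / test masks
    calib_mask = torch.zeros(num_nodes, dtype=torch.bool)
    test_mask = torch.zeros(num_nodes, dtype=torch.bool)
    calib_mask[:100] = True
    test_mask[100:] = True

    for score in ("APS", "RAPS", "SAPS", "THR"):
        pred_sets = get_prediction_sets(
            output, data, calib_mask, test_mask, score_function=score, alpha=0.1
        )
        labels = data.y[test_mask].numpy()
        # Empirical coverage: fraction of test samples whose true label is in the set
        coverage = pred_sets[np.arange(pred_sets.shape[0]), labels].mean()
        avg_size = pred_sets.sum(axis=1).mean()
        print(f"{score}: empirical coverage={coverage:.3f}, avg set size={avg_size:.2f}")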