diff --git a/tgrag/experiments/gnn_experiments/weak_supervision_experiment.py b/tgrag/experiments/gnn_experiments/weak_supervision_experiment.py index cd7a3b68..deba68a1 100644 --- a/tgrag/experiments/gnn_experiments/weak_supervision_experiment.py +++ b/tgrag/experiments/gnn_experiments/weak_supervision_experiment.py @@ -6,6 +6,7 @@ import pandas as pd import torch from torch_geometric.loader import NeighborLoader +from torch_geometric.utils import degree from tqdm import tqdm from tgrag.dataset.temporal_dataset import TemporalDataset @@ -16,6 +17,10 @@ from tgrag.utils.logger import setup_logging from tgrag.utils.matching import reverse_domain from tgrag.utils.path import get_root_dir, get_scratch +from tgrag.utils.plot import ( + plot_neighbor_degree_distribution, + plot_neighbor_distribution, +) from tgrag.utils.seed import seed_everything parser = argparse.ArgumentParser( @@ -34,6 +39,7 @@ def run_weak_supervision_forward( model_arguments: ModelArguments, dataset: TemporalDataset, weight_directory: Path, + target: str, ) -> None: root = get_root_dir() phishing_dict: Dict[str, str] = { @@ -42,6 +48,13 @@ def run_weak_supervision_forward( 'PhishTank': 'data/phishing_data/cc_dec_2024_phishtank_domains.csv', } data = dataset[0] + + src, dst = data.edge_index + logging.info(f'Src, dst degrees loaded.') + + out_degree = degree(src, num_nodes=data.num_nodes, dtype=torch.long) + in_degree = degree(dst, num_nodes=data.num_nodes, dtype=torch.long) + device = f'cuda:{model_arguments.device}' if torch.cuda.is_available() else 'cpu' device = torch.device(device) logging.info(f'Device found: {device}') @@ -72,8 +85,8 @@ def run_weak_supervision_forward( phishing_loader = NeighborLoader( data, input_nodes=phishing_indices, - num_neighbors=[30, 30, 30], - batch_size=1024, + num_neighbors=model_arguments.num_neighbors, + batch_size=model_arguments.batch_size, shuffle=False, ) logging.info( @@ -82,17 +95,54 @@ def run_weak_supervision_forward( num_nodes = data.num_nodes all_preds = torch.zeros(num_nodes, 1) + neighbor_preds = [] + neighbor_nodes = set() with torch.no_grad(): for batch in tqdm(phishing_loader, desc=f'{dataset_name} batch'): batch = batch.to(device) preds = model(batch.x, batch.edge_index) seed_nodes = batch.n_id[: batch.batch_size] + + pred_neighbors = preds[batch.batch_size :] + neighbor_preds.append(pred_neighbors.cpu()) + neighbor_nodes.update(batch.n_id[batch.batch_size :].tolist()) + all_preds[seed_nodes] = preds[: batch.batch_size].cpu() + neighbor_preds = torch.cat(neighbor_preds, dim=0) + neighbor_nodes = torch.tensor(list(neighbor_nodes), dtype=torch.long) + + neighbor_in_degree = in_degree[neighbor_nodes] + logging.info(f'Size of in-degree tensor: {neighbor_in_degree.size()}') + logging.info(f'Sample of in-degree: {neighbor_in_degree[:10]}') + neighbor_out_degree = out_degree[neighbor_nodes] + logging.info(f'Size of out-degree tensor: {neighbor_out_degree.size()}') + logging.info(f'Sample of out-degree: {neighbor_out_degree[:10]}') + + plot_neighbor_distribution( + neighbor_preds=neighbor_preds, + dataset_name=dataset_name, + model_name=model_arguments.model, + target=target, + ) + plot_neighbor_degree_distribution( + neighbor_degree=neighbor_in_degree, + dataset_name=dataset_name, + model_name=model_arguments.model, + target=target, + degree='In-degree', + ) + plot_neighbor_degree_distribution( + neighbor_degree=neighbor_out_degree, + dataset_name=dataset_name, + model_name=model_arguments.model, + target=target, + degree='Out-degree', + ) + logging.info(f'Saving distribution of {dataset_name}') preds = all_preds[phishing_indices] logging.info(f'Number of predictions: {preds.size()}') - logging.info(f'Predictions: {preds}') for threshold in [0.1, 0.3, 0.5]: upper = dataset_name == 'IP2Location' accuracy = get_accuracy(preds, threshold=threshold, upper=upper) @@ -145,7 +195,7 @@ def main() -> None: encoding=encoding_dict, seed=meta_args.global_seed, processed_dir=f'{scratch}/{meta_args.processed_location}', - ) # Map to .to_cpu() + ) logging.info('In-Memory Dataset loaded.') weight_directory = ( root / cast(str, meta_args.weights_directory) / f'{meta_args.target_col}' @@ -157,6 +207,7 @@ def main() -> None: experiment_arg.model_args, dataset, weight_directory, + target=meta_args.target_col, ) diff --git a/tgrag/utils/plot.py b/tgrag/utils/plot.py index bc14d762..10c51884 100644 --- a/tgrag/utils/plot.py +++ b/tgrag/utils/plot.py @@ -1015,3 +1015,59 @@ def plot_pred_target_distributions_histogram( plt.tight_layout() plt.savefig(save_path, bbox_inches='tight', pad_inches=0.1) plt.close() + + +def plot_neighbor_distribution( + neighbor_preds: Tensor, dataset_name: str, model_name: str, target: str +) -> None: + root = get_root_dir() + save_dir = root / 'results' / 'plots' / model_name / 'distribution' / target + save_dir.mkdir(parents=True, exist_ok=True) + save_path = save_dir / f'{dataset_name}_neighbor_pred_distribution.png' + plt.figure(figsize=(6, 4)) + plt.hist(neighbor_preds.numpy(), bins=20, range=(0, 1), edgecolor='black') + plt.title(f'Predicted Label Distribution (Neighbors) — {dataset_name}') + plt.xlabel('Predicted label (0, 1)') + plt.ylabel('Frequency') + plt.grid(alpha=0.3) + plt.tight_layout() + plt.savefig(save_path) + plt.close() + + +def plot_neighbor_degree_distribution( + neighbor_degree: Tensor, + dataset_name: str, + model_name: str, + target: str, + degree: str, +) -> None: + root = get_root_dir() + save_dir = root / 'results' / 'plots' / model_name / 'distribution' / target + save_dir.mkdir(parents=True, exist_ok=True) + save_path = save_dir / f'{dataset_name}_neighbor_{degree}_degree_distribution.png' + plt.figure(figsize=(6, 4)) + + deg = neighbor_degree + deg = deg[deg > 0] + + unique_deg, counts = torch.unique(deg, return_counts=True) + + sorted_idx = torch.argsort(unique_deg) + unique_deg = unique_deg[sorted_idx] + counts = counts[sorted_idx] + + plt.bar( + unique_deg.numpy(), + counts.numpy(), + width=0.8, + edgecolor='black', + align='center', + ) + plt.title(f'{degree} Distribution (Neighbors) — {dataset_name}') + plt.xlabel(f'{degree}') + plt.ylabel('Frequency') + plt.grid(alpha=0.3) + plt.tight_layout() + plt.savefig(save_path) + plt.close()