diff --git a/examples/link_prediction/configs/e2e_het_dblp_sup_task_config.yaml b/examples/link_prediction/configs/e2e_het_dblp_sup_task_config.yaml index 8531fd081..36a0e7eae 100644 --- a/examples/link_prediction/configs/e2e_het_dblp_sup_task_config.yaml +++ b/examples/link_prediction/configs/e2e_het_dblp_sup_task_config.yaml @@ -48,6 +48,8 @@ trainerConfig: ("paper", "to", "author"): [15, 15], ("author", "to", "paper"): [20, 20] } + tensorboard_resource_name: "projects/87123883529/locations/us-central1/tensorboards/2426122984222621696" + tensorboard_experiment_name: "gigl-oss-examples" command: python -m examples.link_prediction.heterogeneous_training inferencerConfig: inferencerArgs: @@ -63,6 +65,8 @@ inferencerConfig: ("paper", "to", "author"): [15, 15], ("author", "to", "paper"): [20, 20] } + tensorboard_resource_name: "projects/87123883529/locations/us-central1/tensorboards/2426122984222621696" + tensorboard_experiment_name: "gigl-oss-examples" inferenceBatchSize: 512 command: python -m examples.link_prediction.heterogeneous_inference sharedConfig: diff --git a/examples/link_prediction/configs/e2e_hom_cora_sup_task_config.yaml b/examples/link_prediction/configs/e2e_hom_cora_sup_task_config.yaml index 606f13c29..6297faeb2 100644 --- a/examples/link_prediction/configs/e2e_hom_cora_sup_task_config.yaml +++ b/examples/link_prediction/configs/e2e_hom_cora_sup_task_config.yaml @@ -17,12 +17,16 @@ trainerConfig: # Example argument to trainer log_every_n_batch: "50" # Frequency in which we log batch information num_neighbors: "[10, 10]" # Fanout per hop, specified as a string representation of a list for the homogeneous use case + tensorboard_resource_name: "projects/87123883529/locations/us-central1/tensorboards/2426122984222621696" + tensorboard_experiment_name: "gigl-oss-examples" command: python -m examples.link_prediction.homogeneous_training inferencerConfig: inferencerArgs: # Example argument to inferencer log_every_n_batch: "50" # Frequency in which we log batch information num_neighbors: "[10, 10]" # Fanout per hop, specified as a string representation of a list for the homogeneous use case + tensorboard_resource_name: "projects/87123883529/locations/us-central1/tensorboards/2426122984222621696" + tensorboard_experiment_name: "gigl-oss-examples" inferenceBatchSize: 512 command: python -m examples.link_prediction.homogeneous_inference sharedConfig: diff --git a/examples/link_prediction/graph_store/configs/e2e_het_dblp_sup_gs_task_config.yaml b/examples/link_prediction/graph_store/configs/e2e_het_dblp_sup_gs_task_config.yaml index 7c23186c7..324b83192 100644 --- a/examples/link_prediction/graph_store/configs/e2e_het_dblp_sup_gs_task_config.yaml +++ b/examples/link_prediction/graph_store/configs/e2e_het_dblp_sup_gs_task_config.yaml @@ -48,6 +48,8 @@ trainerConfig: ("paper", "to", "author"): [15, 15], ("author", "to", "paper"): [20, 20] } + tensorboard_resource_name: "projects/87123883529/locations/us-central1/tensorboards/2426122984222621696" + tensorboard_experiment_name: "gigl-oss-examples" command: python -m examples.link_prediction.graph_store.heterogeneous_training graphStoreStorageConfig: command: python -m examples.link_prediction.graph_store.storage_main @@ -87,6 +89,8 @@ inferencerConfig: ("paper", "to", "author"): [15, 15], ("author", "to", "paper"): [20, 20] } + tensorboard_resource_name: "projects/87123883529/locations/us-central1/tensorboards/2426122984222621696" + tensorboard_experiment_name: "gigl-oss-examples" inferenceBatchSize: 512 command: python -m examples.link_prediction.graph_store.heterogeneous_inference graphStoreStorageConfig: diff --git a/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_task_config.yaml b/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_task_config.yaml index 2283a2f91..4d1000484 100644 --- a/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_task_config.yaml +++ b/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_task_config.yaml @@ -20,6 +20,8 @@ trainerConfig: # Example argument to trainer log_every_n_batch: "50" # Frequency in which we log batch information num_neighbors: "[10, 10]" # Fanout per hop, specified as a string representation of a list for the homogeneous use case + tensorboard_resource_name: "projects/87123883529/locations/us-central1/tensorboards/2426122984222621696" + tensorboard_experiment_name: "gigl-oss-examples" command: python -m examples.link_prediction.graph_store.homogeneous_training graphStoreStorageConfig: command: python -m examples.link_prediction.graph_store.storage_main @@ -47,6 +49,8 @@ inferencerConfig: # Example argument to inferencer log_every_n_batch: "50" # Frequency in which we log batch information num_neighbors: "[10, 10]" # Fanout per hop, specified as a string representation of a list for the homogeneous use case + tensorboard_resource_name: "projects/87123883529/locations/us-central1/tensorboards/2426122984222621696" + tensorboard_experiment_name: "gigl-oss-examples" inferenceBatchSize: 512 command: python -m examples.link_prediction.graph_store.homogeneous_inference graphStoreStorageConfig: diff --git a/examples/link_prediction/graph_store/heterogeneous_inference.py b/examples/link_prediction/graph_store/heterogeneous_inference.py index b85e6b638..a8ee180cd 100644 --- a/examples/link_prediction/graph_store/heterogeneous_inference.py +++ b/examples/link_prediction/graph_store/heterogeneous_inference.py @@ -86,7 +86,7 @@ import sys import time from dataclasses import dataclass -from typing import Union +from typing import Optional, Union import torch import torch.distributed @@ -114,6 +114,7 @@ from gigl.src.common.utils.model import load_state_dict_from_uri from gigl.src.inference.lib.assets import InferenceAssets from gigl.utils.sampling import parse_fanout +from gigl.utils.tensorboard_writer import TensorBoardWriter logger = Logger() @@ -185,6 +186,11 @@ class InferenceProcessArgs: sampling_worker_shared_channel_size: str log_every_n_batch: int + # TensorBoard + job_name: str + tensorboard_resource_name: Optional[str] + tensorboard_experiment_name: Optional[str] + @torch.no_grad() def _inference_process( @@ -213,6 +219,13 @@ def _inference_process( ) # Set the device for the current process. Without this, NCCL will fail when multiple GPUs are available. rank = args.machine_rank * args.local_world_size + local_rank + is_chief_process = args.machine_rank == 0 and local_rank == 0 + tensorboard_writer = TensorBoardWriter.create( + resource_name=args.tensorboard_resource_name, + experiment_name=args.tensorboard_experiment_name, + experiment_run_name=args.job_name, + enabled=is_chief_process, + ) world_size = args.machine_world_size * args.local_world_size # Note: This is a *critical* step in Graph Store mode. It initializes the connection to the storage cluster. # If this is not done, the dataloader will not be able to sample from the graph store and will crash. @@ -343,6 +356,11 @@ def _inference_process( f"Among them, data loading took {cumulative_data_loading_time:.2f} seconds." f"and model inference took {cumulative_inference_time:.2f} seconds." ) + batches_per_sec = args.log_every_n_batch / max(time.time() - t, 1e-9) + tensorboard_writer.log( + {"Inference/throughput_batches_per_sec": batches_per_sec}, + step=batch_idx, + ) t = time.time() cumulative_data_loading_time = 0 cumulative_inference_time = 0 @@ -370,6 +388,7 @@ def _inference_process( torch.distributed.barrier() data_loader.shutdown() + tensorboard_writer.close() shutdown_compute_process() gc.collect() @@ -383,12 +402,20 @@ def _inference_process( def _run_example_inference( job_name: str, task_config_uri: str, + tensorboard_resource_name: Optional[str], + tensorboard_experiment_name: Optional[str], ) -> None: """ Runs an example inference pipeline using GiGL Orchestration. Args: job_name (str): Name of current job task_config_uri (str): Path to frozen GBMLConfigPbWrapper + tensorboard_resource_name (Optional[str]): Vertex AI Tensorboard + resource name. When set together with ``tensorboard_experiment_name``, + the chief rank streams scalar metrics to Vertex AI Experiments. + tensorboard_experiment_name (Optional[str]): Vertex AI + TensorboardExperiment name. Required when ``tensorboard_resource_name`` + is set. """ # All machines run this logic to connect together, and return a distributed context with: # - the (GCP) internal IP address of the rank 0 machine, which will be used for building RPC connections. @@ -546,6 +573,9 @@ def _run_example_inference( sampling_workers_per_inference_process=sampling_workers_per_inference_process, sampling_worker_shared_channel_size=sampling_worker_shared_channel_size, log_every_n_batch=log_every_n_batch, + job_name=job_name, + tensorboard_resource_name=tensorboard_resource_name, + tensorboard_experiment_name=tensorboard_experiment_name, ) logger.info( f"Rank {cluster_info.compute_node_rank} started inference process for node type {inference_node_type} with {num_inference_processes_per_machine} processes\nargs: {inference_args}" @@ -602,16 +632,38 @@ def _run_example_inference( help="Inference job name", ) parser.add_argument("--task_config_uri", type=str, help="Gbml config uri") + parser.add_argument( + "--tensorboard_resource_name", + type=str, + help=( + "Optional Vertex AI Tensorboard resource name. When set together " + "with --tensorboard_experiment_name, the chief rank streams " + "scalar metrics to Vertex AI Experiments." + ), + required=False, + default=None, + ) + parser.add_argument( + "--tensorboard_experiment_name", + type=str, + help=( + "Optional Vertex AI TensorboardExperiment name. Required when " + "--tensorboard_resource_name is set." + ), + required=False, + default=None, + ) # We use parse_known_args instead of parse_args since we only need job_name and task_config_uri for distributed inference args, unused_args = parser.parse_known_args() logger.info(f"Args: {args}, Unused arguments: {unused_args}") flush() - # We only need `job_name` and `task_config_uri` for running inference _run_example_inference( job_name=args.job_name, task_config_uri=args.task_config_uri, + tensorboard_resource_name=args.tensorboard_resource_name, + tensorboard_experiment_name=args.tensorboard_experiment_name, ) except Exception as e: sys.stderr.write(f"Error: {e}\n") diff --git a/examples/link_prediction/graph_store/heterogeneous_training.py b/examples/link_prediction/graph_store/heterogeneous_training.py index 2f04ea9f7..ace53be88 100644 --- a/examples/link_prediction/graph_store/heterogeneous_training.py +++ b/examples/link_prediction/graph_store/heterogeneous_training.py @@ -115,6 +115,7 @@ from gigl.src.common.utils.model import load_state_dict_from_uri, save_state_dict from gigl.utils.iterator import InfiniteIterator from gigl.utils.sampling import parse_fanout +from gigl.utils.tensorboard_writer import TensorBoardWriter logger = Logger() @@ -423,6 +424,11 @@ class TrainingProcessArgs: log_every_n_batch: int should_skip_training: bool + # TensorBoard + job_name: str + tensorboard_resource_name: Optional[str] + tensorboard_experiment_name: Optional[str] + def _training_process( local_rank: int, @@ -460,11 +466,20 @@ def _training_process( torch.cuda.set_device(device) print(f"---Rank {rank} training process set device {device}") + is_chief_process = args.cluster_info.compute_node_rank == 0 and local_rank == 0 + tensorboard_writer = TensorBoardWriter.create( + resource_name=args.tensorboard_resource_name, + experiment_name=args.tensorboard_experiment_name, + experiment_run_name=args.job_name, + enabled=is_chief_process, + ) + loss_fn = RetrievalLoss( loss=torch.nn.CrossEntropyLoss(reduction="mean"), temperature=0.07, remove_accidental_hits=True, ) + batch_idx = 0 if not args.should_skip_training: train_main_loader, train_random_negative_loader = _setup_dataloaders( @@ -525,7 +540,6 @@ def _training_process( # Entering the training loop training_start_time = time.time() - batch_idx = 0 avg_train_loss = 0.0 last_n_batch_avg_loss: list[float] = [] last_n_batch_time: list[float] = [] @@ -567,6 +581,7 @@ def _training_process( if ( batch_idx % args.log_every_n_batch == 0 or batch_idx < 10 ): # Log the first 10 batches to ensure the model is initialized correctly + mean_train_loss = statistics.mean(last_n_batch_avg_loss) print( f"rank={rank}, batch={batch_idx}, latest local train_loss={loss:.6f}" ) @@ -577,15 +592,16 @@ def _training_process( ) last_n_batch_time.clear() print( - f"rank={rank}, latest avg_train_loss={avg_train_loss:.6f}, last {args.log_every_n_batch} mean(avg_train_loss)={statistics.mean(last_n_batch_avg_loss):.6f}" + f"rank={rank}, latest avg_train_loss={avg_train_loss:.6f}, last {args.log_every_n_batch} mean(avg_train_loss)={mean_train_loss:.6f}" ) + tensorboard_writer.log({"Loss/train": mean_train_loss}, step=batch_idx) last_n_batch_avg_loss.clear() flush() if batch_idx % args.val_every_n_batch == 0: print(f"rank={rank}, batch={batch_idx}, validating...") model.eval() - _run_validation_loops( + global_avg_val_loss = _run_validation_loops( model=model, main_loader=val_main_loader_iter, random_negative_loader=val_random_negative_loader_iter, @@ -596,6 +612,9 @@ def _training_process( log_every_n_batch=args.log_every_n_batch, num_batches=num_val_batches_per_process, ) + tensorboard_writer.log( + {"Loss/val": global_avg_val_loss}, step=batch_idx + ) model.train() else: print(f"rank={rank} ended training early - no break condition was met") @@ -674,6 +693,7 @@ def _training_process( device=device, log_every_n_batch=args.log_every_n_batch, ) + tensorboard_writer.log({"Loss/test": global_avg_test_loss}, step=batch_idx) # Memory cleanup and waiting for all processes to finish if torch.cuda.is_available(): @@ -702,6 +722,8 @@ def _training_process( ) flush() + tensorboard_writer.close() + # Graph store mode cleanup: shutdown the compute process connection to the storage cluster. shutdown_compute_process() gc.collect() @@ -814,11 +836,22 @@ def _run_validation_loops( def _run_example_training( task_config_uri: str, + job_name: str, + tensorboard_resource_name: Optional[str], + tensorboard_experiment_name: Optional[str], ): """ Runs an example training + testing loop using GiGL Orchestration in graph store mode. Args: task_config_uri (str): Path to YAML-serialized GbmlConfig proto. + job_name (str): Unique launch identifier used as the TensorBoard + ``ExperimentRun`` name on the chief rank. + tensorboard_resource_name (Optional[str]): Vertex AI Tensorboard + resource name. When set together with ``tensorboard_experiment_name``, + the chief rank streams scalar metrics to Vertex AI Experiments. + tensorboard_experiment_name (Optional[str]): Vertex AI + TensorboardExperiment name. Required when ``tensorboard_resource_name`` + is set. """ program_start_time = time.time() mp.set_start_method("spawn") @@ -966,6 +999,9 @@ def _run_example_training( val_every_n_batch=val_every_n_batch, log_every_n_batch=log_every_n_batch, should_skip_training=should_skip_training, + job_name=job_name, + tensorboard_resource_name=tensorboard_resource_name, + tensorboard_experiment_name=tensorboard_experiment_name, ) torch.multiprocessing.spawn( @@ -985,11 +1021,40 @@ def _run_example_training( parser = argparse.ArgumentParser( description="Arguments for distributed model training on VertexAI (graph store mode)" ) + parser.add_argument( + "--job_name", + type=str, + help="Training job name; used as the TensorBoard ExperimentRun name", + ) parser.add_argument("--task_config_uri", type=str, help="Gbml config uri") + parser.add_argument( + "--tensorboard_resource_name", + type=str, + help=( + "Optional Vertex AI Tensorboard resource name. When set together " + "with --tensorboard_experiment_name, the chief rank streams " + "scalar metrics to Vertex AI Experiments." + ), + required=False, + default=None, + ) + parser.add_argument( + "--tensorboard_experiment_name", + type=str, + help=( + "Optional Vertex AI TensorboardExperiment name. Required when " + "--tensorboard_resource_name is set." + ), + required=False, + default=None, + ) args, unused_args = parser.parse_known_args() print(f"Unused arguments: {unused_args}") _run_example_training( task_config_uri=args.task_config_uri, + job_name=args.job_name, + tensorboard_resource_name=args.tensorboard_resource_name, + tensorboard_experiment_name=args.tensorboard_experiment_name, ) diff --git a/examples/link_prediction/graph_store/homogeneous_inference.py b/examples/link_prediction/graph_store/homogeneous_inference.py index 34bc2672e..f102eff8c 100644 --- a/examples/link_prediction/graph_store/homogeneous_inference.py +++ b/examples/link_prediction/graph_store/homogeneous_inference.py @@ -87,7 +87,7 @@ import sys import time from dataclasses import dataclass -from typing import Union +from typing import Optional, Union import torch import torch.multiprocessing as mp @@ -111,6 +111,7 @@ from gigl.src.common.utils.model import load_state_dict_from_uri from gigl.src.inference.lib.assets import InferenceAssets from gigl.utils.sampling import parse_fanout +from gigl.utils.tensorboard_writer import TensorBoardWriter logger = Logger() @@ -175,6 +176,11 @@ class InferenceProcessArgs: inference_node_type: NodeType gbml_config_pb_wrapper: GbmlConfigPbWrapper + # TensorBoard + job_name: str + tensorboard_resource_name: Optional[str] + tensorboard_experiment_name: Optional[str] + @torch.no_grad() def _inference_process( @@ -212,6 +218,13 @@ def _inference_process( local_rank, ) rank = torch.distributed.get_rank() + is_chief_process = args.cluster_info.compute_node_rank == 0 and local_rank == 0 + tensorboard_writer = TensorBoardWriter.create( + resource_name=args.tensorboard_resource_name, + experiment_name=args.tensorboard_experiment_name, + experiment_run_name=args.job_name, + enabled=is_chief_process, + ) world_size = torch.distributed.get_world_size() logger.info( f"Local rank {local_rank} in machine {args.cluster_info.compute_node_rank} has rank {rank}/{world_size} and using device {device} for inference" @@ -329,6 +342,11 @@ def _inference_process( f"Among them, data loading took {cumulative_data_loading_time:.2f} seconds " f"and model inference took {cumulative_inference_time:.2f} seconds." ) + batches_per_sec = args.log_every_n_batch / max(time.time() - t, 1e-9) + tensorboard_writer.log( + {"Inference/throughput_batches_per_sec": batches_per_sec}, + step=batch_idx, + ) # We don't see logs for graph store mode for whatever reason. # TOOD(#442): Revert this once the GCP issues are resolved. sys.stdout.flush() @@ -361,6 +379,7 @@ def _inference_process( torch.distributed.barrier() data_loader.shutdown() + tensorboard_writer.close() gc.collect() logger.info( @@ -394,12 +413,20 @@ def _inference_process( def _run_example_inference( job_name: str, task_config_uri: str, + tensorboard_resource_name: Optional[str], + tensorboard_experiment_name: Optional[str], ) -> None: """ Runs an example inference pipeline using GiGL Orchestration. Args: job_name (str): Name of current job task_config_uri (str): Path to frozen GBMLConfigPbWrapper + tensorboard_resource_name (Optional[str]): Vertex AI Tensorboard + resource name. When set together with ``tensorboard_experiment_name``, + the chief rank streams scalar metrics to Vertex AI Experiments. + tensorboard_experiment_name (Optional[str]): Vertex AI + TensorboardExperiment name. Required when ``tensorboard_resource_name`` + is set. """ program_start_time = time.time() @@ -533,6 +560,9 @@ def _run_example_inference( log_every_n_batch=log_every_n_batch, inference_node_type=graph_metadata.homogeneous_node_type, gbml_config_pb_wrapper=gbml_config_pb_wrapper, + job_name=job_name, + tensorboard_resource_name=tensorboard_resource_name, + tensorboard_experiment_name=tensorboard_experiment_name, ) mp.spawn( fn=_inference_process, @@ -560,12 +590,34 @@ def _run_example_inference( help="Inference job name", ) parser.add_argument("--task_config_uri", type=str, help="Gbml config uri") + parser.add_argument( + "--tensorboard_resource_name", + type=str, + help=( + "Optional Vertex AI Tensorboard resource name. When set together " + "with --tensorboard_experiment_name, the chief rank streams " + "scalar metrics to Vertex AI Experiments." + ), + required=False, + default=None, + ) + parser.add_argument( + "--tensorboard_experiment_name", + type=str, + help=( + "Optional Vertex AI TensorboardExperiment name. Required when " + "--tensorboard_resource_name is set." + ), + required=False, + default=None, + ) # We use parse_known_args instead of parse_args since we only need job_name and task_config_uri for distributed inference args, unused_args = parser.parse_known_args() logger.info(f"Unused arguments: {unused_args}") - # We only need `job_name` and `task_config_uri` for running inference _run_example_inference( job_name=args.job_name, task_config_uri=args.task_config_uri, + tensorboard_resource_name=args.tensorboard_resource_name, + tensorboard_experiment_name=args.tensorboard_experiment_name, ) diff --git a/examples/link_prediction/graph_store/homogeneous_training.py b/examples/link_prediction/graph_store/homogeneous_training.py index 04340f99a..6c81191ea 100644 --- a/examples/link_prediction/graph_store/homogeneous_training.py +++ b/examples/link_prediction/graph_store/homogeneous_training.py @@ -159,6 +159,7 @@ from gigl.src.common.utils.model import load_state_dict_from_uri, save_state_dict from gigl.utils.iterator import InfiniteIterator from gigl.utils.sampling import parse_fanout +from gigl.utils.tensorboard_writer import TensorBoardWriter logger = Logger() @@ -415,6 +416,11 @@ class TrainingProcessArgs: log_every_n_batch: int should_skip_training: bool + # TensorBoard + job_name: str + tensorboard_resource_name: Optional[str] + tensorboard_experiment_name: Optional[str] + def _training_process( local_rank: int, @@ -451,11 +457,20 @@ def _training_process( torch.cuda.set_device(device) logger.info(f"---Rank {rank} training process set device {device}") + is_chief_process = args.cluster_info.compute_node_rank == 0 and local_rank == 0 + tensorboard_writer = TensorBoardWriter.create( + resource_name=args.tensorboard_resource_name, + experiment_name=args.tensorboard_experiment_name, + experiment_run_name=args.job_name, + enabled=is_chief_process, + ) + loss_fn = RetrievalLoss( loss=torch.nn.CrossEntropyLoss(reduction="mean"), temperature=0.07, remove_accidental_hits=True, ) + batch_idx = 0 if not args.should_skip_training: train_main_loader, train_random_negative_loader = _setup_dataloaders( @@ -517,7 +532,6 @@ def _training_process( # Entering the training loop training_start_time = time.time() - batch_idx = 0 avg_train_loss = 0.0 last_n_batch_avg_loss: list[float] = [] last_n_batch_time: list[float] = [] @@ -555,6 +569,7 @@ def _training_process( batch_start = time.time() batch_idx += 1 if batch_idx % args.log_every_n_batch == 0: + mean_train_loss = statistics.mean(last_n_batch_avg_loss) logger.info( f"rank={rank}, batch={batch_idx}, latest local train_loss={loss:.6f}" ) @@ -565,15 +580,16 @@ def _training_process( ) last_n_batch_time.clear() logger.info( - f"rank={rank}, latest avg_train_loss={avg_train_loss:.6f}, last {args.log_every_n_batch} mean(avg_train_loss)={statistics.mean(last_n_batch_avg_loss):.6f}" + f"rank={rank}, latest avg_train_loss={avg_train_loss:.6f}, last {args.log_every_n_batch} mean(avg_train_loss)={mean_train_loss:.6f}" ) + tensorboard_writer.log({"Loss/train": mean_train_loss}, step=batch_idx) last_n_batch_avg_loss.clear() flush() if batch_idx % args.val_every_n_batch == 0: logger.info(f"rank={rank}, batch={batch_idx}, validating...") model.eval() - _run_validation_loops( + global_avg_val_loss = _run_validation_loops( model=model, main_loader=val_main_loader_iter, random_negative_loader=val_random_negative_loader_iter, @@ -582,6 +598,9 @@ def _training_process( log_every_n_batch=args.log_every_n_batch, num_batches=num_val_batches_per_process, ) + tensorboard_writer.log( + {"Loss/val": global_avg_val_loss}, step=batch_idx + ) model.train() logger.info(f"---Rank {rank} finished training") @@ -657,6 +676,7 @@ def _training_process( device=device, log_every_n_batch=args.log_every_n_batch, ) + tensorboard_writer.log({"Loss/test": global_avg_test_loss}, step=batch_idx) # Memory cleanup and waiting for all processes to finish if torch.cuda.is_available(): @@ -685,6 +705,8 @@ def _training_process( ) flush() + tensorboard_writer.close() + # Graph store mode cleanup: shutdown the compute process connection to the storage cluster. shutdown_compute_process() gc.collect() @@ -805,11 +827,22 @@ def _run_validation_loops( def _run_example_training( task_config_uri: str, + job_name: str, + tensorboard_resource_name: Optional[str], + tensorboard_experiment_name: Optional[str], ): """ Runs an example training + testing loop using GiGL Orchestration in graph store mode. Args: task_config_uri (str): Path to YAML-serialized GbmlConfig proto. + job_name (str): Unique launch identifier used as the TensorBoard + ``ExperimentRun`` name on the chief rank. + tensorboard_resource_name (Optional[str]): Vertex AI Tensorboard + resource name. When set together with ``tensorboard_experiment_name``, + the chief rank streams scalar metrics to Vertex AI Experiments. + tensorboard_experiment_name (Optional[str]): Vertex AI + TensorboardExperiment name. Required when ``tensorboard_resource_name`` + is set. """ program_start_time = time.time() mp.set_start_method("spawn") @@ -943,6 +976,9 @@ def _run_example_training( val_every_n_batch=val_every_n_batch, log_every_n_batch=log_every_n_batch, should_skip_training=should_skip_training, + job_name=job_name, + tensorboard_resource_name=tensorboard_resource_name, + tensorboard_experiment_name=tensorboard_experiment_name, ) torch.multiprocessing.spawn( @@ -962,11 +998,40 @@ def _run_example_training( parser = argparse.ArgumentParser( description="Arguments for distributed model training on VertexAI (graph store mode)" ) + parser.add_argument( + "--job_name", + type=str, + help="Training job name; used as the TensorBoard ExperimentRun name", + ) parser.add_argument("--task_config_uri", type=str, help="Gbml config uri") + parser.add_argument( + "--tensorboard_resource_name", + type=str, + help=( + "Optional Vertex AI Tensorboard resource name. When set together " + "with --tensorboard_experiment_name, the chief rank streams " + "scalar metrics to Vertex AI Experiments." + ), + required=False, + default=None, + ) + parser.add_argument( + "--tensorboard_experiment_name", + type=str, + help=( + "Optional Vertex AI TensorboardExperiment name. Required when " + "--tensorboard_resource_name is set." + ), + required=False, + default=None, + ) args, unused_args = parser.parse_known_args() logger.info(f"Unused arguments: {unused_args}") _run_example_training( task_config_uri=args.task_config_uri, + job_name=args.job_name, + tensorboard_resource_name=args.tensorboard_resource_name, + tensorboard_experiment_name=args.tensorboard_experiment_name, ) diff --git a/examples/link_prediction/heterogeneous_inference.py b/examples/link_prediction/heterogeneous_inference.py index 9aeda018f..90d58ed34 100644 --- a/examples/link_prediction/heterogeneous_inference.py +++ b/examples/link_prediction/heterogeneous_inference.py @@ -45,6 +45,7 @@ from gigl.src.common.utils.model import load_state_dict_from_uri from gigl.src.inference.lib.assets import InferenceAssets from gigl.utils.sampling import parse_fanout +from gigl.utils.tensorboard_writer import TensorBoardWriter logger = Logger() @@ -110,6 +111,11 @@ class InferenceProcessArgs: sampling_worker_shared_channel_size: str log_every_n_batch: int + # TensorBoard + job_name: str + tensorboard_resource_name: Optional[str] + tensorboard_experiment_name: Optional[str] + @torch.no_grad() def _inference_process( @@ -133,6 +139,13 @@ def _inference_process( local_process_rank=local_rank, ) # The device is automatically inferred based off the local process rank and the available devices rank = args.machine_rank * args.local_world_size + local_rank + is_chief_process = args.machine_rank == 0 and local_rank == 0 + tensorboard_writer = TensorBoardWriter.create( + resource_name=args.tensorboard_resource_name, + experiment_name=args.tensorboard_experiment_name, + experiment_run_name=args.job_name, + enabled=is_chief_process, + ) world_size = args.machine_world_size * args.local_world_size if torch.cuda.is_available(): torch.cuda.set_device( @@ -257,6 +270,11 @@ def _inference_process( f"Among them, data loading took {cumulative_data_loading_time:.2f} seconds." f"and model inference took {cumulative_inference_time:.2f} seconds." ) + batches_per_sec = args.log_every_n_batch / max(time.time() - t, 1e-9) + tensorboard_writer.log( + {"Inference/throughput_batches_per_sec": batches_per_sec}, + step=batch_idx, + ) t = time.time() cumulative_data_loading_time = 0 cumulative_inference_time = 0 @@ -283,6 +301,7 @@ def _inference_process( torch.distributed.barrier() data_loader.shutdown() + tensorboard_writer.close() gc.collect() logger.info( @@ -293,12 +312,20 @@ def _inference_process( def _run_example_inference( job_name: str, task_config_uri: str, + tensorboard_resource_name: Optional[str], + tensorboard_experiment_name: Optional[str], ) -> None: """ Runs an example inference pipeline using GiGL Orchestration. Args: job_name (str): Name of current job task_config_uri (str): Path to frozen GBMLConfigPbWrapper + tensorboard_resource_name (Optional[str]): Vertex AI Tensorboard + resource name. When set together with ``tensorboard_experiment_name``, + the chief rank streams scalar metrics to Vertex AI Experiments. + tensorboard_experiment_name (Optional[str]): Vertex AI + TensorboardExperiment name. Required when ``tensorboard_resource_name`` + is set. """ # All machines run this logic to connect together, and return a distributed context with: # - the (GCP) internal IP address of the rank 0 machine, which will be used for building RPC connections. @@ -459,6 +486,9 @@ def _run_example_inference( sampling_workers_per_inference_process=sampling_workers_per_inference_process, sampling_worker_shared_channel_size=sampling_worker_shared_channel_size, log_every_n_batch=log_every_n_batch, + job_name=job_name, + tensorboard_resource_name=tensorboard_resource_name, + tensorboard_experiment_name=tensorboard_experiment_name, ) mp.spawn( @@ -507,13 +537,35 @@ def _run_example_inference( help="Inference job name", ) parser.add_argument("--task_config_uri", type=str, help="Gbml config uri") + parser.add_argument( + "--tensorboard_resource_name", + type=str, + help=( + "Optional Vertex AI Tensorboard resource name. When set together " + "with --tensorboard_experiment_name, the chief rank streams " + "scalar metrics to Vertex AI Experiments." + ), + required=False, + default=None, + ) + parser.add_argument( + "--tensorboard_experiment_name", + type=str, + help=( + "Optional Vertex AI TensorboardExperiment name. Required when " + "--tensorboard_resource_name is set." + ), + required=False, + default=None, + ) # We use parse_known_args instead of parse_args since we only need job_name and task_config_uri for distributed inference args, unused_args = parser.parse_known_args() logger.info(f"Unused arguments: {unused_args}") - # We only need `job_name` and `task_config_uri` for running inference _run_example_inference( job_name=args.job_name, task_config_uri=args.task_config_uri, + tensorboard_resource_name=args.tensorboard_resource_name, + tensorboard_experiment_name=args.tensorboard_experiment_name, ) diff --git a/examples/link_prediction/heterogeneous_training.py b/examples/link_prediction/heterogeneous_training.py index f0d58ca5e..352eec867 100644 --- a/examples/link_prediction/heterogeneous_training.py +++ b/examples/link_prediction/heterogeneous_training.py @@ -65,6 +65,7 @@ from gigl.src.common.utils.model import load_state_dict_from_uri, save_state_dict from gigl.utils.iterator import InfiniteIterator from gigl.utils.sampling import parse_fanout +from gigl.utils.tensorboard_writer import TensorBoardWriter logger = Logger() @@ -368,6 +369,11 @@ class TrainingProcessArgs: log_every_n_batch: int should_skip_training: bool + # TensorBoard + job_name: str + tensorboard_resource_name: Optional[str] + tensorboard_experiment_name: Optional[str] + def _training_process( local_rank: int, @@ -400,11 +406,21 @@ def _training_process( if torch.cuda.is_available(): torch.cuda.set_device(device) logger.info(f"---Rank {rank} training process set device {device}") + + is_chief_process = args.machine_rank == 0 and local_rank == 0 + tensorboard_writer = TensorBoardWriter.create( + resource_name=args.tensorboard_resource_name, + experiment_name=args.tensorboard_experiment_name, + experiment_run_name=args.job_name, + enabled=is_chief_process, + ) + loss_fn = RetrievalLoss( loss=torch.nn.CrossEntropyLoss(reduction="mean"), temperature=0.07, remove_accidental_hits=True, ) + batch_idx = 0 if not args.should_skip_training: train_main_loader, train_random_negative_loader = _setup_dataloaders( @@ -469,7 +485,6 @@ def _training_process( # Entering the training loop training_start_time = time.time() - batch_idx = 0 avg_train_loss = 0.0 last_n_batch_avg_loss: list[float] = [] last_n_batch_time: list[float] = [] @@ -509,6 +524,7 @@ def _training_process( batch_start = time.time() batch_idx += 1 if batch_idx % args.log_every_n_batch == 0: + mean_train_loss = statistics.mean(last_n_batch_avg_loss) logger.info( f"rank={rank}, batch={batch_idx}, latest local train_loss={loss:.6f}" ) @@ -521,14 +537,15 @@ def _training_process( last_n_batch_time.clear() # log the global average training loss logger.info( - f"rank={rank}, latest avg_train_loss={avg_train_loss:.6f}, last {args.log_every_n_batch} mean(avg_train_loss)={statistics.mean(last_n_batch_avg_loss):.6f}" + f"rank={rank}, latest avg_train_loss={avg_train_loss:.6f}, last {args.log_every_n_batch} mean(avg_train_loss)={mean_train_loss:.6f}" ) + tensorboard_writer.log({"Loss/train": mean_train_loss}, step=batch_idx) last_n_batch_avg_loss.clear() if batch_idx % args.val_every_n_batch == 0: logger.info(f"rank={rank}, batch={batch_idx}, validating...") model.eval() - _run_validation_loops( + global_avg_val_loss = _run_validation_loops( model=model, main_loader=val_main_loader_iter, random_negative_loader=val_random_negative_loader_iter, @@ -538,6 +555,9 @@ def _training_process( log_every_n_batch=args.log_every_n_batch, num_batches=num_val_batches_per_process, ) + tensorboard_writer.log( + {"Loss/val": global_avg_val_loss}, step=batch_idx + ) model.train() logger.info(f"---Rank {rank} finished training") @@ -619,6 +639,7 @@ def _training_process( device=device, log_every_n_batch=args.log_every_n_batch, ) + tensorboard_writer.log({"Loss/test": global_avg_test_loss}, step=batch_idx) # Memory cleanup and waiting for all processes to finish if torch.cuda.is_available(): @@ -649,6 +670,8 @@ def _training_process( f"---Rank {rank} finished testing in {time.time() - testing_start_time:.3f} seconds" ) + tensorboard_writer.close() + torch.distributed.destroy_process_group() @@ -747,11 +770,22 @@ def _run_validation_loops( def _run_example_training( task_config_uri: str, + job_name: str, + tensorboard_resource_name: Optional[str], + tensorboard_experiment_name: Optional[str], ): """ Runs an example training + testing loop using GiGL Orchestration. Args: task_config_uri (str): Path to YAML-serialized GbmlConfig proto. + job_name (str): Unique launch identifier used as the TensorBoard + ``ExperimentRun`` name on the chief rank. + tensorboard_resource_name (Optional[str]): Vertex AI Tensorboard + resource name. When set together with ``tensorboard_experiment_name``, + the chief rank streams scalar metrics to Vertex AI Experiments. + tensorboard_experiment_name (Optional[str]): Vertex AI + TensorboardExperiment name. Required when ``tensorboard_resource_name`` + is set. """ start_time = time.time() mp.set_start_method("spawn") @@ -923,6 +957,9 @@ def _run_example_training( val_every_n_batch=val_every_n_batch, log_every_n_batch=log_every_n_batch, should_skip_training=should_skip_training, + job_name=job_name, + tensorboard_resource_name=tensorboard_resource_name, + tensorboard_experiment_name=tensorboard_experiment_name, ) torch.multiprocessing.spawn( @@ -938,13 +975,41 @@ def _run_example_training( parser = argparse.ArgumentParser( description="Arguments for distributed model training on VertexAI" ) + parser.add_argument( + "--job_name", + type=str, + help="Training job name; used as the TensorBoard ExperimentRun name", + ) parser.add_argument("--task_config_uri", type=str, help="Gbml config uri") + parser.add_argument( + "--tensorboard_resource_name", + type=str, + help=( + "Optional Vertex AI Tensorboard resource name. When set together " + "with --tensorboard_experiment_name, the chief rank streams " + "scalar metrics to Vertex AI Experiments." + ), + required=False, + default=None, + ) + parser.add_argument( + "--tensorboard_experiment_name", + type=str, + help=( + "Optional Vertex AI TensorboardExperiment name. Required when " + "--tensorboard_resource_name is set." + ), + required=False, + default=None, + ) # We use parse_known_args instead of parse_args since we only need job_name and task_config_uri for distributed trainer args, unused_args = parser.parse_known_args() logger.info(f"Unused arguments: {unused_args}") - # We only need `task_config_uri` for running trainer _run_example_training( task_config_uri=args.task_config_uri, + job_name=args.job_name, + tensorboard_resource_name=args.tensorboard_resource_name, + tensorboard_experiment_name=args.tensorboard_experiment_name, ) diff --git a/examples/link_prediction/homogeneous_inference.py b/examples/link_prediction/homogeneous_inference.py index 38d8ba3c1..d2416fb40 100644 --- a/examples/link_prediction/homogeneous_inference.py +++ b/examples/link_prediction/homogeneous_inference.py @@ -23,7 +23,7 @@ import gc import time from dataclasses import dataclass -from typing import Union +from typing import Optional, Union import torch import torch.multiprocessing as mp @@ -44,6 +44,7 @@ from gigl.src.common.utils.model import load_state_dict_from_uri from gigl.src.inference.lib.assets import InferenceAssets from gigl.utils.sampling import parse_fanout +from gigl.utils.tensorboard_writer import TensorBoardWriter logger = Logger() @@ -113,6 +114,11 @@ class InferenceProcessArgs: sampling_worker_shared_channel_size: str log_every_n_batch: int + # TensorBoard + job_name: str + tensorboard_resource_name: Optional[str] + tensorboard_experiment_name: Optional[str] + @torch.no_grad() def _inference_process( @@ -142,6 +148,13 @@ def _inference_process( ) torch.cuda.set_device(device) rank = args.machine_rank * args.local_world_size + local_rank + is_chief_process = args.machine_rank == 0 and local_rank == 0 + tensorboard_writer = TensorBoardWriter.create( + resource_name=args.tensorboard_resource_name, + experiment_name=args.tensorboard_experiment_name, + experiment_run_name=args.job_name, + enabled=is_chief_process, + ) world_size = args.machine_world_size * args.local_world_size logger.info( f"Local rank {local_rank} in machine {args.machine_rank} has rank {rank}/{world_size} and using device {device} for inference" @@ -250,6 +263,11 @@ def _inference_process( f"Among them, data loading took {cumulative_data_loading_time:.2f} seconds " f"and model inference took {cumulative_inference_time:.2f} seconds." ) + batches_per_sec = args.log_every_n_batch / max(time.time() - t, 1e-9) + tensorboard_writer.log( + {"Inference/throughput_batches_per_sec": batches_per_sec}, + step=batch_idx, + ) t = time.time() cumulative_data_loading_time = 0 cumulative_inference_time = 0 @@ -274,6 +292,7 @@ def _inference_process( torch.distributed.barrier() data_loader.shutdown() + tensorboard_writer.close() gc.collect() logger.info( @@ -284,12 +303,20 @@ def _inference_process( def _run_example_inference( job_name: str, task_config_uri: str, + tensorboard_resource_name: Optional[str], + tensorboard_experiment_name: Optional[str], ) -> None: """ Runs an example inference pipeline using GiGL Orchestration. Args: job_name (str): Name of current job task_config_uri (str): Path to frozen GBMLConfigPbWrapper + tensorboard_resource_name (Optional[str]): Vertex AI Tensorboard + resource name. When set together with ``tensorboard_experiment_name``, + the chief rank streams scalar metrics to Vertex AI Experiments. + tensorboard_experiment_name (Optional[str]): Vertex AI + TensorboardExperiment name. Required when ``tensorboard_resource_name`` + is set. """ program_start_time = time.time() @@ -424,6 +451,9 @@ def _run_example_inference( sampling_workers_per_inference_process=sampling_workers_per_inference_process, sampling_worker_shared_channel_size=sampling_worker_shared_channel_size, log_every_n_batch=log_every_n_batch, + job_name=job_name, + tensorboard_resource_name=tensorboard_resource_name, + tensorboard_experiment_name=tensorboard_experiment_name, ) mp.spawn( @@ -466,13 +496,35 @@ def _run_example_inference( help="Inference job name", ) parser.add_argument("--task_config_uri", type=str, help="Gbml config uri") + parser.add_argument( + "--tensorboard_resource_name", + type=str, + help=( + "Optional Vertex AI Tensorboard resource name. When set together " + "with --tensorboard_experiment_name, the chief rank streams " + "scalar metrics to Vertex AI Experiments." + ), + required=False, + default=None, + ) + parser.add_argument( + "--tensorboard_experiment_name", + type=str, + help=( + "Optional Vertex AI TensorboardExperiment name. Required when " + "--tensorboard_resource_name is set." + ), + required=False, + default=None, + ) # We use parse_known_args instead of parse_args since we only need job_name and task_config_uri for distributed inference args, unused_args = parser.parse_known_args() logger.info(f"Unused arguments: {unused_args}") - # We only need `job_name` and `task_config_uri` for running inference _run_example_inference( job_name=args.job_name, task_config_uri=args.task_config_uri, + tensorboard_resource_name=args.tensorboard_resource_name, + tensorboard_experiment_name=args.tensorboard_experiment_name, ) diff --git a/examples/link_prediction/homogeneous_training.py b/examples/link_prediction/homogeneous_training.py index b95a77489..426f4e92d 100644 --- a/examples/link_prediction/homogeneous_training.py +++ b/examples/link_prediction/homogeneous_training.py @@ -61,6 +61,7 @@ from gigl.types.graph import to_homogeneous from gigl.utils.iterator import InfiniteIterator from gigl.utils.sampling import parse_fanout +from gigl.utils.tensorboard_writer import TensorBoardWriter logger = Logger() @@ -326,6 +327,11 @@ class TrainingProcessArgs: log_every_n_batch: int should_skip_training: bool + # TensorBoard + job_name: str + tensorboard_resource_name: Optional[str] + tensorboard_experiment_name: Optional[str] + def _training_process( local_rank: int, @@ -360,11 +366,20 @@ def _training_process( logger.info(f"---Rank {rank} training process group initialized") + is_chief_process = args.machine_rank == 0 and local_rank == 0 + tensorboard_writer = TensorBoardWriter.create( + resource_name=args.tensorboard_resource_name, + experiment_name=args.tensorboard_experiment_name, + experiment_run_name=args.job_name, + enabled=is_chief_process, + ) + loss_fn = RetrievalLoss( loss=torch.nn.CrossEntropyLoss(reduction="mean"), temperature=0.07, remove_accidental_hits=True, ) + batch_idx = 0 if not args.should_skip_training: train_main_loader, train_random_negative_loader = _setup_dataloaders( @@ -429,7 +444,6 @@ def _training_process( # Entering the training loop training_start_time = time.time() - batch_idx = 0 avg_train_loss = 0.0 last_n_batch_avg_loss: list[float] = [] last_n_batch_time: list[float] = [] @@ -468,6 +482,7 @@ def _training_process( batch_start = time.time() batch_idx += 1 if batch_idx % args.log_every_n_batch == 0: + mean_train_loss = statistics.mean(last_n_batch_avg_loss) logger.info( f"rank={rank}, batch={batch_idx}, latest local train_loss={loss:.6f}" ) @@ -480,14 +495,15 @@ def _training_process( last_n_batch_time.clear() # log the global average training loss logger.info( - f"rank={rank}, latest avg_train_loss={avg_train_loss:.6f}, last {args.log_every_n_batch} mean(avg_train_loss)={statistics.mean(last_n_batch_avg_loss):.6f}" + f"rank={rank}, latest avg_train_loss={avg_train_loss:.6f}, last {args.log_every_n_batch} mean(avg_train_loss)={mean_train_loss:.6f}" ) + tensorboard_writer.log({"Loss/train": mean_train_loss}, step=batch_idx) last_n_batch_avg_loss.clear() if batch_idx % args.val_every_n_batch == 0: logger.info(f"rank={rank}, batch={batch_idx}, validating...") model.eval() - _run_validation_loops( + global_avg_val_loss = _run_validation_loops( model=model, main_loader=val_main_loader_iter, random_negative_loader=val_random_negative_loader_iter, @@ -496,6 +512,9 @@ def _training_process( log_every_n_batch=args.log_every_n_batch, num_batches=num_val_batches_per_process, ) + tensorboard_writer.log( + {"Loss/val": global_avg_val_loss}, step=batch_idx + ) model.train() logger.info(f"---Rank {rank} finished training") @@ -573,6 +592,7 @@ def _training_process( device=device, log_every_n_batch=args.log_every_n_batch, ) + tensorboard_writer.log({"Loss/test": global_avg_test_loss}, step=batch_idx) # Memory cleanup and waiting for all processes to finish if torch.cuda.is_available(): @@ -603,6 +623,8 @@ def _training_process( f"---Rank {rank} finished testing in {time.time() - testing_start_time:.3f} seconds" ) + tensorboard_writer.close() + torch.distributed.destroy_process_group() @@ -697,11 +719,23 @@ def _run_validation_loops( def _run_example_training( task_config_uri: str, + job_name: str, + tensorboard_resource_name: Optional[str], + tensorboard_experiment_name: Optional[str], ): """ Runs an example training + testing loop using GiGL Orchestration. Args: task_config_uri (str): Path to YAML-serialized GbmlConfig proto. + job_name (str): Unique launch identifier used as the TensorBoard + ``ExperimentRun`` name on the chief rank. + tensorboard_resource_name (Optional[str]): Vertex AI Tensorboard + resource name (``projects/.../locations/.../tensorboards/``). + When set together with ``tensorboard_experiment_name``, the chief + rank streams scalar metrics to Vertex AI Experiments. + tensorboard_experiment_name (Optional[str]): Vertex AI + TensorboardExperiment name. Required when ``tensorboard_resource_name`` + is set. """ start_time = time.time() mp.set_start_method("spawn") @@ -849,6 +883,9 @@ def _run_example_training( val_every_n_batch=val_every_n_batch, log_every_n_batch=log_every_n_batch, should_skip_training=should_skip_training, + job_name=job_name, + tensorboard_resource_name=tensorboard_resource_name, + tensorboard_experiment_name=tensorboard_experiment_name, ) torch.multiprocessing.spawn( @@ -864,13 +901,41 @@ def _run_example_training( parser = argparse.ArgumentParser( description="Arguments for distributed model training on VertexAI" ) + parser.add_argument( + "--job_name", + type=str, + help="Training job name; used as the TensorBoard ExperimentRun name", + ) parser.add_argument("--task_config_uri", type=str, help="Gbml config uri") + parser.add_argument( + "--tensorboard_resource_name", + type=str, + help=( + "Optional Vertex AI Tensorboard resource name. When set together " + "with --tensorboard_experiment_name, the chief rank streams " + "scalar metrics to Vertex AI Experiments." + ), + required=False, + default=None, + ) + parser.add_argument( + "--tensorboard_experiment_name", + type=str, + help=( + "Optional Vertex AI TensorboardExperiment name. Required when " + "--tensorboard_resource_name is set." + ), + required=False, + default=None, + ) # We use parse_known_args instead of parse_args since we only need job_name and task_config_uri for distributed trainer args, unused_args = parser.parse_known_args() logger.info(f"Unused arguments: {unused_args}") - # We only need `task_config_uri` for running trainer _run_example_training( task_config_uri=args.task_config_uri, + job_name=args.job_name, + tensorboard_resource_name=args.tensorboard_resource_name, + tensorboard_experiment_name=args.tensorboard_experiment_name, ) diff --git a/gigl/utils/tensorboard_writer.py b/gigl/utils/tensorboard_writer.py new file mode 100644 index 000000000..200ce3e7a --- /dev/null +++ b/gigl/utils/tensorboard_writer.py @@ -0,0 +1,208 @@ +"""TensorBoard writer for GiGL training and inference entrypoints. + +Writes scalars to Vertex AI's TensorboardService via the synchronous +``aiplatform.log_time_series_metrics`` API. The writer attaches to a Vertex +AI ``Experiment`` + ``ExperimentRun`` whose backing ``Tensorboard`` resource +the caller supplies explicitly. + +Vertex AI TensorBoard data model: + Tensorboard -> TensorboardExperiment -> TensorboardRun -> TensorboardTimeSeries + https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-overview + +Configuration is plumbed through the trainer/inferencer's argparse interface +(typically populated from ``GbmlConfig.trainerConfig.trainerArgs`` or +``inferencerConfig.inferencerArgs``), not through env vars or proto fields on +``GiglResourceConfig``. Construct the writer with +:meth:`TensorBoardWriter.create` and let chief / non-chief ranks share the +same call sites: + + >>> is_chief_process = args.machine_rank == 0 and local_rank == 0 + >>> with TensorBoardWriter.create( + ... resource_name=args.tensorboard_resource_name, + ... experiment_name=args.tensorboard_experiment_name, + ... experiment_run_name=args.job_name, + ... enabled=is_chief_process, + ... ) as tb: + ... tb.log({"Loss/train": loss}, step=batch_idx) +""" + +import re +from typing import Final, Optional + +from google.cloud import aiplatform + +from gigl.common.logger import Logger + +logger = Logger() + +# Vertex AI Tensorboard resource name format. +_TENSORBOARD_RESOURCE_NAME_PATTERN: Final[re.Pattern[str]] = re.compile( + r"^projects/(?P[^/]+)" + r"/locations/(?P[^/]+)" + r"/tensorboards/(?P[^/]+)$" +) + + +class TensorBoardWriter: + """Writes scalar metrics to a Vertex AI ``ExperimentRun``. + + No-ops when disabled, so callers never see ``Optional[TensorBoardWriter]`` + plumbing across chief / non-chief ranks. + + Each :meth:`log` call issues a synchronous ``WriteTensorboardRunData`` RPC + via ``aiplatform.log_time_series_metrics``. On first sight of any new + metric key the SDK also issues a ``CreateTensorboardTimeSeries`` RPC. + Failures propagate to the caller rather than being absorbed in a + background uploader thread. + """ + + def __init__(self, *, active: bool) -> None: + """Initialize the writer. + + Callers should use :meth:`create` rather than constructing directly. + + Args: + active: When ``False``, the writer is a no-op (no SDK calls). + When ``True``, :meth:`create` has already called + ``aiplatform.init`` and ``aiplatform.start_run`` on this + process. + """ + self._active = active + self._closed = False + + @classmethod + def create( + cls, + *, + resource_name: Optional[str], + experiment_name: Optional[str], + experiment_run_name: str, + enabled: bool, + ) -> "TensorBoardWriter": + """Construct a writer from explicit configuration. + + When ``enabled`` is ``False`` (non-chief ranks), returns a no-op + writer without touching the aiplatform SDK regardless of the other + arguments. + + When ``enabled`` is ``True``, all three of ``resource_name``, + ``experiment_name``, and ``experiment_run_name`` must be non-empty. + Missing any of them raises ``RuntimeError`` so config gaps surface + immediately. ``resource_name`` must additionally match + ``projects/.../locations/.../tensorboards/...``. + + Side effects when ``enabled`` is ``True`` and all args are valid: + + - Calls ``aiplatform.init(project=..., location=..., experiment=..., + experiment_tensorboard=...)`` with project + location parsed from + ``resource_name``. + - Calls ``aiplatform.start_run(experiment_run_name, resume=False)``. + Callers are expected to pass a launch-unique run name (typically + the trainer's ``job_name``). + - Logs the human-readable TensorBoard UI URL so engineers can find + the cross-job experiment page from trainer stdout. + + Args: + resource_name: Fully-qualified Vertex AI ``Tensorboard`` resource + name (``projects/.../locations/.../tensorboards/``). + experiment_name: Vertex AI ``TensorboardExperiment`` ID under + ``resource_name``. Multiple jobs that share this value + surface as comparable runs on a single TensorBoard page. + experiment_run_name: Vertex AI ``TensorboardRun`` ID under + ``experiment_name``. Must be unique per launch (use + ``args.job_name``). + enabled: Whether this caller is responsible for writing events + (typically ``is_chief_process``). + + Returns: + A ``TensorBoardWriter`` — real if ``enabled``, no-op otherwise. + + Raises: + RuntimeError: ``enabled`` is True and any required argument is + missing. + ValueError: ``resource_name`` doesn't match the Vertex AI + Tensorboard resource-name format. + """ + if not enabled: + return cls(active=False) + + missing = [ + name + for name, value in ( + ("resource_name", resource_name), + ("experiment_name", experiment_name), + ("experiment_run_name", experiment_run_name), + ) + if not value + ] + if missing: + raise RuntimeError( + "TensorBoardWriter.create(enabled=True) requires " + f"{', '.join(missing)} to be set. The trainer/inferencer " + "entrypoint plumbs these through argparse from " + "GbmlConfig.trainerArgs / inferencerArgs." + ) + + assert resource_name is not None # narrowed by the missing check above + assert experiment_name is not None + assert experiment_run_name is not None + match = _TENSORBOARD_RESOURCE_NAME_PATTERN.match(resource_name) + if not match: + raise ValueError( + f"resource_name {resource_name!r} does not match " + "projects/.../locations/.../tensorboards/...; pass the " + "Tensorboard resource name from GCP, not the display name." + ) + + aiplatform.init( + project=match["project"], + location=match["location"], + experiment=experiment_name, + experiment_tensorboard=resource_name, + ) + aiplatform.start_run(experiment_run_name, resume=False) + experiment_url = ( + f"https://{match['location']}.tensorboard.googleusercontent.com/experiment/" + f"projects+{match['project']}" + f"+locations+{match['location']}" + f"+tensorboards+{match['tensorboard_id']}" + f"+experiments+{experiment_name}" + ) + logger.info( + f"View TensorBoard (cross-job comparison, experiment={experiment_name!r}): " + f"{experiment_url}" + ) + return cls(active=True) + + def log(self, metrics: dict[str, float], step: int) -> None: + """Write each metric scalar at ``step`` via Vertex AI Experiments. + + No-ops when the writer is inactive or already closed. All entries + in ``metrics`` are written under the hood in a single + ``WriteTensorboardRunData`` RPC. + + Args: + metrics: Mapping of TensorBoard tag to scalar value. All entries + are written at the same step. + step: TensorBoard step for the data points. + """ + if not self._active or self._closed: + return + aiplatform.log_time_series_metrics(metrics, step=step) + + def close(self) -> None: + """End the backing ``ExperimentRun``. + + Idempotent; safe to call multiple times and on no-op writers. + """ + if self._closed: + return + if self._active: + aiplatform.end_run() + self._closed = True + + def __enter__(self) -> "TensorBoardWriter": + return self + + def __exit__(self, *_exc: object) -> None: + self.close() diff --git a/tests/unit/utils/tensorboard_writer_test.py b/tests/unit/utils/tensorboard_writer_test.py new file mode 100644 index 000000000..7a65513e8 --- /dev/null +++ b/tests/unit/utils/tensorboard_writer_test.py @@ -0,0 +1,205 @@ +"""Unit tests for gigl.utils.tensorboard_writer.""" + +from unittest.mock import patch + +from absl.testing import absltest + +from gigl.utils import tensorboard_writer as tensorboard_writer_module +from gigl.utils.tensorboard_writer import TensorBoardWriter +from tests.test_assets.test_case import TestCase + +_TB_RESOURCE = "projects/my-project/locations/us-central1/tensorboards/42" +_EXPERIMENT = "my-experiment" +_RUN = "my-job-name-20260507-120000" + + +class TestTensorBoardWriter(TestCase): + """Tests for the TensorBoardWriter class.""" + + def test_create_returns_noop_when_disabled(self) -> None: + """Disabled (non-chief) writers must not touch aiplatform at all.""" + with ( + patch("google.cloud.aiplatform.init") as mock_init, + patch("google.cloud.aiplatform.start_run") as mock_start_run, + patch("google.cloud.aiplatform.log_time_series_metrics") as mock_log, + patch("google.cloud.aiplatform.end_run") as mock_end, + ): + writer = TensorBoardWriter.create( + resource_name=None, + experiment_name=None, + experiment_run_name=_RUN, + enabled=False, + ) + writer.log({"Loss/train": 1.0}, step=0) + writer.close() + + mock_init.assert_not_called() + mock_start_run.assert_not_called() + mock_log.assert_not_called() + mock_end.assert_not_called() + + def test_create_initializes_aiplatform_and_starts_run(self) -> None: + with ( + patch("google.cloud.aiplatform.init") as mock_init, + patch("google.cloud.aiplatform.start_run") as mock_start_run, + ): + TensorBoardWriter.create( + resource_name=_TB_RESOURCE, + experiment_name=_EXPERIMENT, + experiment_run_name=_RUN, + enabled=True, + ) + + mock_init.assert_called_once_with( + project="my-project", + location="us-central1", + experiment=_EXPERIMENT, + experiment_tensorboard=_TB_RESOURCE, + ) + mock_start_run.assert_called_once_with(_RUN, resume=False) + + def test_create_raises_when_enabled_and_resource_name_missing(self) -> None: + with ( + patch("google.cloud.aiplatform.init") as mock_init, + patch("google.cloud.aiplatform.start_run") as mock_start_run, + ): + with self.assertRaises(RuntimeError) as ctx: + TensorBoardWriter.create( + resource_name=None, + experiment_name=_EXPERIMENT, + experiment_run_name=_RUN, + enabled=True, + ) + + self.assertIn("resource_name", str(ctx.exception)) + mock_init.assert_not_called() + mock_start_run.assert_not_called() + + def test_create_raises_when_enabled_and_experiment_name_missing(self) -> None: + with ( + patch("google.cloud.aiplatform.init") as mock_init, + patch("google.cloud.aiplatform.start_run") as mock_start_run, + ): + with self.assertRaises(RuntimeError) as ctx: + TensorBoardWriter.create( + resource_name=_TB_RESOURCE, + experiment_name=None, + experiment_run_name=_RUN, + enabled=True, + ) + + self.assertIn("experiment_name", str(ctx.exception)) + mock_init.assert_not_called() + mock_start_run.assert_not_called() + + def test_create_raises_when_enabled_and_run_name_missing(self) -> None: + with ( + patch("google.cloud.aiplatform.init") as mock_init, + patch("google.cloud.aiplatform.start_run") as mock_start_run, + ): + with self.assertRaises(RuntimeError) as ctx: + TensorBoardWriter.create( + resource_name=_TB_RESOURCE, + experiment_name=_EXPERIMENT, + experiment_run_name="", + enabled=True, + ) + + self.assertIn("experiment_run_name", str(ctx.exception)) + mock_init.assert_not_called() + mock_start_run.assert_not_called() + + def test_create_raises_on_invalid_resource_name(self) -> None: + with ( + patch("google.cloud.aiplatform.init") as mock_init, + patch("google.cloud.aiplatform.start_run") as mock_start_run, + ): + with self.assertRaises(ValueError) as ctx: + TensorBoardWriter.create( + resource_name="not-a-valid-resource-name", + experiment_name=_EXPERIMENT, + experiment_run_name=_RUN, + enabled=True, + ) + + self.assertIn("resource_name", str(ctx.exception)) + mock_init.assert_not_called() + mock_start_run.assert_not_called() + + def test_create_logs_named_experiment_url_on_start(self) -> None: + """The named-experiment URL is logged so engineers can find the TB + page from trainer stdout. + """ + with ( + patch("google.cloud.aiplatform.init"), + patch("google.cloud.aiplatform.start_run"), + patch.object(tensorboard_writer_module.logger, "info") as mock_info, + ): + TensorBoardWriter.create( + resource_name=_TB_RESOURCE, + experiment_name=_EXPERIMENT, + experiment_run_name=_RUN, + enabled=True, + ) + + url_logs = [ + call.args[0] + for call in mock_info.call_args_list + if "View TensorBoard" in call.args[0] + ] + self.assertEqual(len(url_logs), 1) + self.assertIn(_EXPERIMENT, url_logs[0]) + self.assertIn("tensorboards+42", url_logs[0]) + self.assertIn("us-central1", url_logs[0]) + + def test_log_forwards_to_log_time_series_metrics(self) -> None: + with patch("google.cloud.aiplatform.log_time_series_metrics") as mock_log: + writer = TensorBoardWriter(active=True) + writer.log({"Loss/train": 1.5, "Loss/val": 2.0}, step=10) + + mock_log.assert_called_once_with({"Loss/train": 1.5, "Loss/val": 2.0}, step=10) + + def test_log_is_noop_when_inactive(self) -> None: + with patch("google.cloud.aiplatform.log_time_series_metrics") as mock_log: + writer = TensorBoardWriter(active=False) + writer.log({"Loss/train": 1.0}, step=0) + + mock_log.assert_not_called() + + def test_log_is_noop_after_close(self) -> None: + with ( + patch("google.cloud.aiplatform.end_run"), + patch("google.cloud.aiplatform.log_time_series_metrics") as mock_log, + ): + writer = TensorBoardWriter(active=True) + writer.close() + writer.log({"Loss/train": 1.0}, step=0) + + mock_log.assert_not_called() + + def test_context_manager_ends_run(self) -> None: + with patch("google.cloud.aiplatform.end_run") as mock_end: + with TensorBoardWriter(active=True): + pass + + mock_end.assert_called_once_with() + + def test_close_is_idempotent(self) -> None: + with patch("google.cloud.aiplatform.end_run") as mock_end: + writer = TensorBoardWriter(active=True) + writer.close() + writer.close() + + mock_end.assert_called_once_with() + + def test_close_on_inactive_writer_does_not_raise(self) -> None: + with patch("google.cloud.aiplatform.end_run") as mock_end: + writer = TensorBoardWriter(active=False) + writer.close() + writer.close() # Idempotent on no-op writer. + + mock_end.assert_not_called() + + +if __name__ == "__main__": + absltest.main()