Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ trainerConfig:
("paper", "to", "author"): [15, 15],
("author", "to", "paper"): [20, 20]
}
tensorboard_resource_name: "projects/87123883529/locations/us-central1/tensorboards/2426122984222621696"
tensorboard_experiment_name: "gigl-oss-examples"
command: python -m examples.link_prediction.heterogeneous_training
inferencerConfig:
inferencerArgs:
Expand All @@ -63,6 +65,8 @@ inferencerConfig:
("paper", "to", "author"): [15, 15],
("author", "to", "paper"): [20, 20]
}
tensorboard_resource_name: "projects/87123883529/locations/us-central1/tensorboards/2426122984222621696"
tensorboard_experiment_name: "gigl-oss-examples"
inferenceBatchSize: 512
command: python -m examples.link_prediction.heterogeneous_inference
sharedConfig:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,16 @@ trainerConfig:
# Example argument to trainer
log_every_n_batch: "50" # Frequency in which we log batch information
num_neighbors: "[10, 10]" # Fanout per hop, specified as a string representation of a list for the homogeneous use case
tensorboard_resource_name: "projects/87123883529/locations/us-central1/tensorboards/2426122984222621696"
tensorboard_experiment_name: "gigl-oss-examples"
command: python -m examples.link_prediction.homogeneous_training
inferencerConfig:
inferencerArgs:
# Example argument to inferencer
log_every_n_batch: "50" # Frequency in which we log batch information
num_neighbors: "[10, 10]" # Fanout per hop, specified as a string representation of a list for the homogeneous use case
tensorboard_resource_name: "projects/87123883529/locations/us-central1/tensorboards/2426122984222621696"
tensorboard_experiment_name: "gigl-oss-examples"
inferenceBatchSize: 512
command: python -m examples.link_prediction.homogeneous_inference
sharedConfig:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ trainerConfig:
("paper", "to", "author"): [15, 15],
("author", "to", "paper"): [20, 20]
}
tensorboard_resource_name: "projects/87123883529/locations/us-central1/tensorboards/2426122984222621696"
tensorboard_experiment_name: "gigl-oss-examples"
command: python -m examples.link_prediction.graph_store.heterogeneous_training
graphStoreStorageConfig:
command: python -m examples.link_prediction.graph_store.storage_main
Expand Down Expand Up @@ -87,6 +89,8 @@ inferencerConfig:
("paper", "to", "author"): [15, 15],
("author", "to", "paper"): [20, 20]
}
tensorboard_resource_name: "projects/87123883529/locations/us-central1/tensorboards/2426122984222621696"
tensorboard_experiment_name: "gigl-oss-examples"
inferenceBatchSize: 512
command: python -m examples.link_prediction.graph_store.heterogeneous_inference
graphStoreStorageConfig:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ trainerConfig:
# Example argument to trainer
log_every_n_batch: "50" # Frequency in which we log batch information
num_neighbors: "[10, 10]" # Fanout per hop, specified as a string representation of a list for the homogeneous use case
tensorboard_resource_name: "projects/87123883529/locations/us-central1/tensorboards/2426122984222621696"
tensorboard_experiment_name: "gigl-oss-examples"
command: python -m examples.link_prediction.graph_store.homogeneous_training
graphStoreStorageConfig:
command: python -m examples.link_prediction.graph_store.storage_main
Expand Down Expand Up @@ -47,6 +49,8 @@ inferencerConfig:
# Example argument to inferencer
log_every_n_batch: "50" # Frequency in which we log batch information
num_neighbors: "[10, 10]" # Fanout per hop, specified as a string representation of a list for the homogeneous use case
tensorboard_resource_name: "projects/87123883529/locations/us-central1/tensorboards/2426122984222621696"
tensorboard_experiment_name: "gigl-oss-examples"
inferenceBatchSize: 512
command: python -m examples.link_prediction.graph_store.homogeneous_inference
graphStoreStorageConfig:
Expand Down
56 changes: 54 additions & 2 deletions examples/link_prediction/graph_store/heterogeneous_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@
import sys
import time
from dataclasses import dataclass
from typing import Union
from typing import Optional, Union

import torch
import torch.distributed
Expand Down Expand Up @@ -114,6 +114,7 @@
from gigl.src.common.utils.model import load_state_dict_from_uri
from gigl.src.inference.lib.assets import InferenceAssets
from gigl.utils.sampling import parse_fanout
from gigl.utils.tensorboard_writer import TensorBoardWriter

logger = Logger()

Expand Down Expand Up @@ -185,6 +186,11 @@ class InferenceProcessArgs:
sampling_worker_shared_channel_size: str
log_every_n_batch: int

# TensorBoard
job_name: str
tensorboard_resource_name: Optional[str]
tensorboard_experiment_name: Optional[str]


@torch.no_grad()
def _inference_process(
Expand Down Expand Up @@ -213,6 +219,13 @@ def _inference_process(
) # Set the device for the current process. Without this, NCCL will fail when multiple GPUs are available.

rank = args.machine_rank * args.local_world_size + local_rank
is_chief_process = args.machine_rank == 0 and local_rank == 0
tensorboard_writer = TensorBoardWriter.create(
resource_name=args.tensorboard_resource_name,
experiment_name=args.tensorboard_experiment_name,
experiment_run_name=args.job_name,
enabled=is_chief_process,
)
world_size = args.machine_world_size * args.local_world_size
# Note: This is a *critical* step in Graph Store mode. It initializes the connection to the storage cluster.
# If this is not done, the dataloader will not be able to sample from the graph store and will crash.
Expand Down Expand Up @@ -343,6 +356,11 @@ def _inference_process(
f"Among them, data loading took {cumulative_data_loading_time:.2f} seconds."
f"and model inference took {cumulative_inference_time:.2f} seconds."
)
batches_per_sec = args.log_every_n_batch / max(time.time() - t, 1e-9)
tensorboard_writer.log(
{"Inference/throughput_batches_per_sec": batches_per_sec},
step=batch_idx,
)
t = time.time()
cumulative_data_loading_time = 0
cumulative_inference_time = 0
Expand Down Expand Up @@ -370,6 +388,7 @@ def _inference_process(
torch.distributed.barrier()

data_loader.shutdown()
tensorboard_writer.close()
shutdown_compute_process()
gc.collect()

Expand All @@ -383,12 +402,20 @@ def _inference_process(
def _run_example_inference(
job_name: str,
task_config_uri: str,
tensorboard_resource_name: Optional[str],
tensorboard_experiment_name: Optional[str],
) -> None:
"""
Runs an example inference pipeline using GiGL Orchestration.
Args:
job_name (str): Name of current job
task_config_uri (str): Path to frozen GBMLConfigPbWrapper
tensorboard_resource_name (Optional[str]): Vertex AI Tensorboard
resource name. When set together with ``tensorboard_experiment_name``,
the chief rank streams scalar metrics to Vertex AI Experiments.
tensorboard_experiment_name (Optional[str]): Vertex AI
TensorboardExperiment name. Required when ``tensorboard_resource_name``
is set.
"""
# All machines run this logic to connect together, and return a distributed context with:
# - the (GCP) internal IP address of the rank 0 machine, which will be used for building RPC connections.
Expand Down Expand Up @@ -546,6 +573,9 @@ def _run_example_inference(
sampling_workers_per_inference_process=sampling_workers_per_inference_process,
sampling_worker_shared_channel_size=sampling_worker_shared_channel_size,
log_every_n_batch=log_every_n_batch,
job_name=job_name,
tensorboard_resource_name=tensorboard_resource_name,
tensorboard_experiment_name=tensorboard_experiment_name,
)
logger.info(
f"Rank {cluster_info.compute_node_rank} started inference process for node type {inference_node_type} with {num_inference_processes_per_machine} processes\nargs: {inference_args}"
Expand Down Expand Up @@ -602,16 +632,38 @@ def _run_example_inference(
help="Inference job name",
)
parser.add_argument("--task_config_uri", type=str, help="Gbml config uri")
parser.add_argument(
"--tensorboard_resource_name",
type=str,
help=(
"Optional Vertex AI Tensorboard resource name. When set together "
"with --tensorboard_experiment_name, the chief rank streams "
"scalar metrics to Vertex AI Experiments."
),
required=False,
default=None,
)
parser.add_argument(
"--tensorboard_experiment_name",
type=str,
help=(
"Optional Vertex AI TensorboardExperiment name. Required when "
"--tensorboard_resource_name is set."
),
required=False,
default=None,
)

# We use parse_known_args instead of parse_args since we only need job_name and task_config_uri for distributed inference
args, unused_args = parser.parse_known_args()
logger.info(f"Args: {args}, Unused arguments: {unused_args}")
flush()

# We only need `job_name` and `task_config_uri` for running inference
_run_example_inference(
job_name=args.job_name,
task_config_uri=args.task_config_uri,
tensorboard_resource_name=args.tensorboard_resource_name,
tensorboard_experiment_name=args.tensorboard_experiment_name,
)
except Exception as e:
sys.stderr.write(f"Error: {e}\n")
Expand Down
71 changes: 68 additions & 3 deletions examples/link_prediction/graph_store/heterogeneous_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@
from gigl.src.common.utils.model import load_state_dict_from_uri, save_state_dict
from gigl.utils.iterator import InfiniteIterator
from gigl.utils.sampling import parse_fanout
from gigl.utils.tensorboard_writer import TensorBoardWriter

logger = Logger()

Expand Down Expand Up @@ -423,6 +424,11 @@ class TrainingProcessArgs:
log_every_n_batch: int
should_skip_training: bool

# TensorBoard
job_name: str
tensorboard_resource_name: Optional[str]
tensorboard_experiment_name: Optional[str]


def _training_process(
local_rank: int,
Expand Down Expand Up @@ -460,11 +466,20 @@ def _training_process(
torch.cuda.set_device(device)
print(f"---Rank {rank} training process set device {device}")

is_chief_process = args.cluster_info.compute_node_rank == 0 and local_rank == 0
tensorboard_writer = TensorBoardWriter.create(
resource_name=args.tensorboard_resource_name,
experiment_name=args.tensorboard_experiment_name,
experiment_run_name=args.job_name,
enabled=is_chief_process,
)

loss_fn = RetrievalLoss(
loss=torch.nn.CrossEntropyLoss(reduction="mean"),
temperature=0.07,
remove_accidental_hits=True,
)
batch_idx = 0

if not args.should_skip_training:
train_main_loader, train_random_negative_loader = _setup_dataloaders(
Expand Down Expand Up @@ -525,7 +540,6 @@ def _training_process(

# Entering the training loop
training_start_time = time.time()
batch_idx = 0
avg_train_loss = 0.0
last_n_batch_avg_loss: list[float] = []
last_n_batch_time: list[float] = []
Expand Down Expand Up @@ -567,6 +581,7 @@ def _training_process(
if (
batch_idx % args.log_every_n_batch == 0 or batch_idx < 10
): # Log the first 10 batches to ensure the model is initialized correctly
mean_train_loss = statistics.mean(last_n_batch_avg_loss)
print(
f"rank={rank}, batch={batch_idx}, latest local train_loss={loss:.6f}"
)
Expand All @@ -577,15 +592,16 @@ def _training_process(
)
last_n_batch_time.clear()
print(
f"rank={rank}, latest avg_train_loss={avg_train_loss:.6f}, last {args.log_every_n_batch} mean(avg_train_loss)={statistics.mean(last_n_batch_avg_loss):.6f}"
f"rank={rank}, latest avg_train_loss={avg_train_loss:.6f}, last {args.log_every_n_batch} mean(avg_train_loss)={mean_train_loss:.6f}"
)
tensorboard_writer.log({"Loss/train": mean_train_loss}, step=batch_idx)
last_n_batch_avg_loss.clear()
flush()

if batch_idx % args.val_every_n_batch == 0:
print(f"rank={rank}, batch={batch_idx}, validating...")
model.eval()
_run_validation_loops(
global_avg_val_loss = _run_validation_loops(
model=model,
main_loader=val_main_loader_iter,
random_negative_loader=val_random_negative_loader_iter,
Expand All @@ -596,6 +612,9 @@ def _training_process(
log_every_n_batch=args.log_every_n_batch,
num_batches=num_val_batches_per_process,
)
tensorboard_writer.log(
{"Loss/val": global_avg_val_loss}, step=batch_idx
)
model.train()
else:
print(f"rank={rank} ended training early - no break condition was met")
Expand Down Expand Up @@ -674,6 +693,7 @@ def _training_process(
device=device,
log_every_n_batch=args.log_every_n_batch,
)
tensorboard_writer.log({"Loss/test": global_avg_test_loss}, step=batch_idx)

# Memory cleanup and waiting for all processes to finish
if torch.cuda.is_available():
Expand Down Expand Up @@ -702,6 +722,8 @@ def _training_process(
)
flush()

tensorboard_writer.close()

# Graph store mode cleanup: shutdown the compute process connection to the storage cluster.
shutdown_compute_process()
gc.collect()
Expand Down Expand Up @@ -814,11 +836,22 @@ def _run_validation_loops(

def _run_example_training(
task_config_uri: str,
job_name: str,
tensorboard_resource_name: Optional[str],
tensorboard_experiment_name: Optional[str],
):
"""
Runs an example training + testing loop using GiGL Orchestration in graph store mode.
Args:
task_config_uri (str): Path to YAML-serialized GbmlConfig proto.
job_name (str): Unique launch identifier used as the TensorBoard
``ExperimentRun`` name on the chief rank.
tensorboard_resource_name (Optional[str]): Vertex AI Tensorboard
resource name. When set together with ``tensorboard_experiment_name``,
the chief rank streams scalar metrics to Vertex AI Experiments.
tensorboard_experiment_name (Optional[str]): Vertex AI
TensorboardExperiment name. Required when ``tensorboard_resource_name``
is set.
"""
program_start_time = time.time()
mp.set_start_method("spawn")
Expand Down Expand Up @@ -966,6 +999,9 @@ def _run_example_training(
val_every_n_batch=val_every_n_batch,
log_every_n_batch=log_every_n_batch,
should_skip_training=should_skip_training,
job_name=job_name,
tensorboard_resource_name=tensorboard_resource_name,
tensorboard_experiment_name=tensorboard_experiment_name,
)

torch.multiprocessing.spawn(
Expand All @@ -985,11 +1021,40 @@ def _run_example_training(
parser = argparse.ArgumentParser(
description="Arguments for distributed model training on VertexAI (graph store mode)"
)
parser.add_argument(
"--job_name",
type=str,
help="Training job name; used as the TensorBoard ExperimentRun name",
)
parser.add_argument("--task_config_uri", type=str, help="Gbml config uri")
parser.add_argument(
"--tensorboard_resource_name",
type=str,
help=(
"Optional Vertex AI Tensorboard resource name. When set together "
"with --tensorboard_experiment_name, the chief rank streams "
"scalar metrics to Vertex AI Experiments."
),
required=False,
default=None,
)
parser.add_argument(
"--tensorboard_experiment_name",
type=str,
help=(
"Optional Vertex AI TensorboardExperiment name. Required when "
"--tensorboard_resource_name is set."
),
required=False,
default=None,
)

args, unused_args = parser.parse_known_args()
print(f"Unused arguments: {unused_args}")

_run_example_training(
task_config_uri=args.task_config_uri,
job_name=args.job_name,
tensorboard_resource_name=args.tensorboard_resource_name,
tensorboard_experiment_name=args.tensorboard_experiment_name,
)
Loading