graphtoolbox.training.trainer¶
Classes

EarlyStopping – Implements early stopping to terminate training when validation loss stops improving.

Trainer – Train, validate, and evaluate a graph neural network model on temporal graph datasets.
- class graphtoolbox.training.trainer.EarlyStopping(patience=10, min_delta=0.0)[source]¶
Bases: object

Implements early stopping to terminate training when validation loss stops improving.
- Parameters:
patience (int, default=10) – Number of epochs with no improvement after which training will be stopped.
min_delta (float, default=0.0) – Minimum change in validation loss to qualify as an improvement.
- counter¶
Number of consecutive epochs without improvement.
- Type:
int
- best_loss¶
Lowest recorded validation loss.
- Type:
float
- early_stop¶
Whether the early stopping condition has been met.
- Type:
bool
Examples
>>> stopper = EarlyStopping(patience=5, min_delta=0.01)
>>> for epoch in range(100):
...     val_loss = compute_validation_loss()
...     stopper(val_loss)
...     if stopper.early_stop:
...         print("Stopped early at epoch", epoch)
...         break
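The documented parameters and attributes imply an update rule along the following lines. This is a minimal sketch of that rule, assuming a strict "improve by at least min_delta" test; it is not the package's source.

class EarlyStoppingSketch:
    """Sketch of the stopping rule implied by the documented attributes."""

    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0               # consecutive epochs without improvement
        self.best_loss = float("inf")  # lowest validation loss seen so far
        self.early_stop = False        # set once patience is exhausted

    def __call__(self, val_loss):
        if val_loss < self.best_loss - self.min_delta:
            # Improvement: record the new best and reset the streak.
            self.best_loss = val_loss
            self.counter = 0
        else:
            # No improvement: extend the streak and check patience.
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True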
- class graphtoolbox.training.trainer.Trainer(model, dataset_train, dataset_val, dataset_test, batch_size, model_kwargs: Dict | None = None, reconcile: bool = True, **kwargs)[source]¶
Bases: object

Train, validate, and evaluate a graph neural network model on temporal graph datasets.
This class handles the full training loop, including:
- batched training with PyTorch Geometric loaders,
- validation and early stopping,
- checkpointing and loss tracking,
- optional attention collection and saving,
- optional per-group outputs for additive models (these sum to the final prediction),
- inference and hierarchical reconciliation (MinT).
- Parameters:
model (torch.nn.Module) – Graph neural network model to train.
dataset_train (GraphDataset) – Training dataset.
dataset_val (GraphDataset) – Validation dataset.
dataset_test (GraphDataset) – Test dataset.
batch_size (int) – Number of graph samples per batch.
model_kwargs (dict, optional) – Dictionary of model hyperparameters (loaded from config if None).
reconcile (bool, default=True) – Whether to apply MinT reconciliation to predictions.
**kwargs – Optional keyword-only arguments:
- edge_index (torch.Tensor[2, E]) – Graph connectivity in COO format.
- edge_weight (torch.Tensor[E] or None) – Optional edge weights.
- return_attention (bool) – Collect attention during validation/test.
- return_group_outputs (bool) – Ask additive models to return group-wise contributions.
- lam_reg (float) – Graph smoothness regularizer weight (pairwise node-prediction L2).
- is_trained¶
Whether the model has been trained.
- Type:
bool
- train_loader, val_loader, test_loader¶
Data loaders for the training, validation, and test sets.
- Type:
PyGDataLoader
- saving_directory¶
Path to saved model checkpoints.
- Type:
str
- S, G, P¶
Matrices for hierarchical MinT reconciliation (see the sketch after this attribute list).
- Type:
torch.Tensor
- nodes¶
Node names (order used by the model).
- Type:
list[str]
- num_nodes¶
Number of base nodes.
- Type:
int
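For context on the S, G, P attributes: in the standard MinT formulation, coherent forecasts are obtained by projecting base forecasts as ỹ = S·G·ŷ, where S is the summing matrix and G = (SᵀW⁻¹S)⁻¹SᵀW⁻¹ for a forecast-error covariance W. The sketch below assumes that standard form; how graphtoolbox factors the projection across its S, G, and P attributes is not specified here, and mint_reconcile is an illustrative name, not part of the package.

import torch

def mint_reconcile(base_preds, S, W):
    # base_preds : [num_series, T] base forecasts for every hierarchy level
    # S          : [num_series, num_base] summing matrix
    # W          : [num_series, num_series] forecast-error covariance
    W_inv = torch.linalg.inv(W)
    # MinT projection: G = (S^T W^-1 S)^-1 S^T W^-1
    G = torch.linalg.inv(S.T @ W_inv @ S) @ S.T @ W_inv
    # Reconciled forecasts are coherent: aggregates equal sums of parts.
    return S @ G @ base_preds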
Notes
When the model class name is 'AdditiveGraphModel' and return_group_outputs is not set, Trainer will auto-enable group output collection.
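An end-to-end usage sketch based on the signatures documented on this page; the model and the GraphDataset instances are assumed to be built elsewhere with the wider graphtoolbox API.

from graphtoolbox.training.trainer import Trainer

trainer = Trainer(
    model=model,                  # a torch.nn.Module GNN (assumed built elsewhere)
    dataset_train=dataset_train,  # GraphDataset instances (assumed built elsewhere)
    dataset_val=dataset_val,
    dataset_test=dataset_test,
    batch_size=32,
    reconcile=True,               # apply MinT reconciliation to predictions
)

# With return_attention and return_group_outputs both unset, train()
# returns (preds, targets, edge_index, edge_weight); see train() below.
preds, targets, edge_index, edge_weight = trainer.train(
    num_epochs=50,
    patience=20,
    plot_loss=True,
)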
- train(**kwargs) → Tuple[source]¶
Train the model and optionally evaluate during training.
Supports early stopping, checkpoint saving, attention visualization and (for additive models) returning per-group outputs.
- Parameters:
num_epochs (int, optional) – Number of epochs to train (default: from model_kwargs).
optimizer (torch.optim.Optimizer, optional) – Optimizer (default: Adam with model_kwargs['lr']).
patience (int, default=20) – Early stopping patience.
min_delta (float, default=0.0) – Minimum delta to count as improvement.
force_training (bool, default=False) – If True, retrains even if checkpoint exists.
saving_directory (str, optional) – Folder to store model weights.
plot_loss (bool, optional) – If True, plots training/validation curves.
dynamic_graph (bool, optional) – Enable dynamic adjacency matrix updates per epoch.
save (bool, optional) – If True, saves attention maps.
- Returns:
- If return_attention is False and return_group_outputs is False:
(preds, targets, edge_index, edge_weight)
- If return_attention is True:
(preds, targets, edge_index, attention_mats)
- If return_group_outputs is True and return_attention is False:
(preds, targets, edge_index, edge_weight, group_outputs)
- If both are True:
(preds, targets, edge_index, attention_mats, group_outputs)
Where:
- preds : torch.Tensor[num_nodes, T] in original units (reconciled if enabled)
- targets : torch.Tensor[num_nodes, T] in original units
- attention_mats : dict[str, list[torch.Tensor]] if collected
- group_outputs : dict[str, torch.Tensor[num_nodes, T]] if collected
- Return type:
Tuple
Notes
Applies graph smoothness regularization weighted by lam_reg.
Uses early stopping with best-checkpoint saving.
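The notes above describe the regularizer only as a pairwise node-prediction L2 penalty weighted by lam_reg. A minimal sketch under that reading; smoothness_penalty is an illustrative name, not the trainer's internal function.

import torch

def smoothness_penalty(preds, edge_index, edge_weight=None, lam_reg=0.0):
    # preds      : [num_nodes, horizon] node predictions
    # edge_index : [2, E] source/target node indices
    src, dst = edge_index
    diff = preds[src] - preds[dst]  # pairwise prediction differences per edge
    sq = (diff ** 2).sum(dim=-1)    # squared L2 norm per edge
    if edge_weight is not None:
        sq = sq * edge_weight       # optionally weight each edge's contribution
    return lam_reg * sq.mean()      # added to the training loss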