graphtoolbox.optim.optimizer
Classes
Optimizer: Hyperparameter optimizer for graph neural networks using Optuna.
- class graphtoolbox.optim.optimizer.Optimizer(model, dataset_train: GraphDataset, dataset_val: GraphDataset, out_channels: int = 48, optim_kwargs: Dict = None, **kwargs)
Bases: object
Hyperparameter optimizer for graph neural networks using Optuna.
The Optimizer class automates hyperparameter tuning of GNN models on training and validation datasets. It supports structured logging, pruning of poor trials, and optional dashboard visualization via optuna-dashboard.
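The docs do not say which pruner Optimizer configures. A common choice in Optuna is median stopping, where a trial is cut short if its intermediate validation score is worse than the median of earlier trials at the same epoch. The sketch below implements a simplified version of that rule in plain Python to show the idea; `should_prune` and the `history` layout are hypothetical, and Optuna's own `MedianPruner` refines it (for example, it compares the trial's best intermediate value and supports warm-up steps).

```python
from statistics import median
from typing import List


def should_prune(step: int, value: float, history: List[List[float]],
                 n_startup_trials: int = 5, minimize: bool = True) -> bool:
    """Simplified median-stopping rule (hypothetical helper).

    `history` holds one list of intermediate values per earlier trial,
    indexed by step (e.g. epoch). Prune when `value` is worse than the
    median of the earlier trials' values at the same step.
    """
    # Only compare against earlier trials that actually reached this step.
    peers = [trial[step] for trial in history if len(trial) > step]
    # Too few reference trials: never prune (a warm-up period).
    if len(peers) < n_startup_trials:
        return False
    threshold = median(peers)
    # "Worse" means larger when minimizing, smaller when maximizing.
    return value > threshold if minimize else value < threshold
```

With five earlier trials whose step-1 losses are 0.8, 0.9, 0.85, 0.7 and 1.0, the median is 0.85, so a trial reporting 1.05 at step 1 would be pruned while one reporting 0.75 would continue.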
- Parameters:
model (type[torch.nn.Module]) – GNN model class (a torch.nn.Module subclass, not an instance) to be optimized.
dataset_train (GraphDataset) – Training dataset.
dataset_val (GraphDataset) – Validation dataset.
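out_channels (int, optional) – Default is 48.
optim_kwargs (dict, optional) – Search space definition for hyperparameters, mapping each name to a (low, high) range. Example:
{"hidden_channels": (32, 128), "num_layers": (2, 5), "lr": (1e-4, 1e-2)}. Loaded from the configuration folder if not provided.
num_epochs (int, optional) – Number of epochs to train each trial. Default is 200. Not part of the explicit signature; passed as a keyword argument through **kwargs.
conv_class (torch_geometric.nn.MessagePassing subclass, optional) – Convolution class to use in the model. Default is GATv2Conv. Not part of the explicit signature; passed as a keyword argument through **kwargs.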
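The docs show the shape of optim_kwargs but not how each range becomes a sampled value. One plausible mapping onto Optuna's trial API is sketched below, assuming integer bounds go to `trial.suggest_int` and float bounds to `trial.suggest_float`, with log-scale sampling when a range spans at least one order of magnitude. `suggest_from_space` and that log-scale rule are hypothetical, not the class's confirmed behavior; the function only relies on the `suggest_int(name, low, high)` and `suggest_float(name, low, high, log=...)` methods that an Optuna `Trial` provides.

```python
from typing import Any, Dict, Tuple, Union

Number = Union[int, float]


def suggest_from_space(trial: Any,
                       space: Dict[str, Tuple[Number, Number]]) -> Dict[str, Number]:
    """Turn a {name: (low, high)} search space into concrete trial values.

    Hypothetical convention: int bounds -> trial.suggest_int,
    float bounds -> trial.suggest_float, on a log scale when
    the range covers at least one decade (e.g. learning rates).
    """
    params: Dict[str, Number] = {}
    for name, (low, high) in space.items():
        if isinstance(low, int) and isinstance(high, int):
            # Discrete sizes such as hidden_channels or num_layers.
            params[name] = trial.suggest_int(name, low, high)
        else:
            use_log = low > 0 and high / low >= 10
            params[name] = trial.suggest_float(name, float(low), float(high), log=use_log)
    return params
```

Inside an Optuna objective this would be called as `suggest_from_space(trial, {"hidden_channels": (32, 128), "num_layers": (2, 5), "lr": (1e-4, 1e-2)})`, sampling `lr` on a log scale and the other two as integers.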
- study
Optuna study object containing all trials and results.
- Type:
optuna.Study
- storage
In-memory storage backend for optimization results.
- Type:
optuna.storages.InMemoryStorage
- is_optimized
Whether the optimization process has been executed.
- Type:
bool
- logger
Logger instance for progress and diagnostic output.
- Type:
logging.Logger
Examples
>>> from graphtoolbox.optim.optimizer import Optimizer
>>> opt = Optimizer(model=myGNN, dataset_train=train_set, dataset_val=val_set)
>>> opt.optimize(n_trials=30)
>>> opt.run_on_server()  # visualize results in optuna-dashboard
- optimize(**kwargs)
Run the Optuna optimization loop.
- Parameters:
study_name (str, optional) – Name for the Optuna study. Default is derived from the model class.
n_trials (int, default=100) – Number of hyperparameter trials to perform.
direction ({'minimize', 'maximize'}, default='minimize') – Optimization direction for the objective function.
timeout (int, optional) – Maximum runtime in seconds.
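In Optuna, n_trials and timeout act as two independent stopping conditions: the loop ends when either limit is reached, and the timeout is checked between trials rather than interrupting a running one. The toy loop below illustrates how n_trials, direction and timeout interact. It deliberately swaps Optuna's sampler for uniform random search so it runs with the standard library alone; `run_trials` is a hypothetical helper, not part of graphtoolbox.

```python
import random
import time
from typing import Callable, Dict, Optional, Tuple


def run_trials(objective: Callable[[Dict[str, float]], float],
               space: Dict[str, Tuple[float, float]],
               n_trials: int = 100,
               direction: str = "minimize",
               timeout: Optional[float] = None,
               seed: int = 0) -> Tuple[Dict[str, float], Optional[float], int]:
    """Toy trial loop using uniform random search (stand-in for Optuna).

    Stops after n_trials or once `timeout` seconds have elapsed,
    whichever comes first, and keeps the best value according to
    `direction`. Returns (best_params, best_value, completed_trials).
    """
    if direction not in ("minimize", "maximize"):
        raise ValueError("direction must be 'minimize' or 'maximize'")
    rng = random.Random(seed)
    if direction == "minimize":
        better = lambda new, old: new < old
    else:
        better = lambda new, old: new > old
    start = time.monotonic()
    best_params: Dict[str, float] = {}
    best_value: Optional[float] = None
    completed = 0
    for _ in range(n_trials):
        # The time budget is checked before each trial, never mid-trial.
        if timeout is not None and time.monotonic() - start >= timeout:
            break
        params = {name: rng.uniform(lo, hi) for name, (lo, hi) in space.items()}
        value = objective(params)
        completed += 1
        if best_value is None or better(value, best_value):
            best_params, best_value = params, value
    return best_params, best_value, completed
```

With `timeout=0` no trial runs at all, whereas `timeout=None` lets all n_trials complete; a real study adds smarter sampling and pruning on top of this control flow.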
Side Effects
- Saves best parameters and statistics to ./results_optim_<ConvClass>/.
- Logs progress to ./logs/optimization_<date>.log.
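The docs name the output folder and log file but not the file inside the folder or the date format. The sketch below shows one way to produce those two locations with the standard library, assuming an ISO date (YYYY-MM-DD) and a JSON file called best_params.json; both choices, and the helper names `output_paths`, `save_best_params` and `attach_log_file`, are hypothetical and may differ from what Optimizer actually writes.

```python
import json
import logging
from datetime import date
from pathlib import Path
from typing import Any, Dict, Tuple


def output_paths(conv_class: type, day: date,
                 root: Path = Path(".")) -> Tuple[Path, Path]:
    """Build results_optim_<ConvClass>/ and logs/optimization_<date>.log.

    The ISO date format is an assumption.
    """
    results_dir = root / f"results_optim_{conv_class.__name__}"
    log_file = root / "logs" / f"optimization_{day.isoformat()}.log"
    return results_dir, log_file


def save_best_params(results_dir: Path, best_params: Dict[str, Any],
                     filename: str = "best_params.json") -> Path:
    """Dump the best hyperparameters as JSON (hypothetical file name)."""
    results_dir.mkdir(parents=True, exist_ok=True)
    path = results_dir / filename
    path.write_text(json.dumps(best_params, indent=2))
    return path


def attach_log_file(logger: logging.Logger, log_file: Path) -> logging.FileHandler:
    """Mirror the logger's records into the dated log file."""
    log_file.parent.mkdir(parents=True, exist_ok=True)
    handler = logging.FileHandler(log_file)
    handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
    logger.addHandler(handler)
    return handler
```

For the default convolution, `output_paths(GATv2Conv, date(2024, 1, 2))` would give `results_optim_GATv2Conv` and `logs/optimization_2024-01-02.log`.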
Notes
Uses in-memory storage by default but can be extended for database-backed studies if persistence is needed.
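Optimizer is not documented as accepting a storage argument, so the following is plain Optuna usage rather than part of this class's API. It shows how a study can be persisted by passing a SQLite URL as `storage` to `optuna.create_study`, with `load_if_exists=True` so a second run resumes the same study instead of failing on the duplicate name. The helper names and the database path are hypothetical; Optuna is imported inside the function so the URL helper works without it installed.

```python
from pathlib import Path
from typing import Any


def sqlite_url(db_path: Path) -> str:
    """SQLAlchemy-style URL that Optuna accepts as a `storage` argument."""
    return f"sqlite:///{db_path}"


def create_persistent_study(study_name: str, db_path: Path,
                            direction: str = "minimize") -> Any:
    """Create, or reopen, an Optuna study stored in a SQLite file."""
    import optuna  # lazy import: only needed when a study is created

    return optuna.create_study(
        study_name=study_name,
        storage=sqlite_url(db_path),
        direction=direction,
        load_if_exists=True,  # resume the same study across runs
    )
```

A study stored this way survives the Python process and can be browsed afterwards with `optuna-dashboard sqlite:///optuna.db` (for a database file named optuna.db).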