graphtoolbox.models.gnn
Classes

| Class | Description |
| --- | --- |
| ConvAdapter | |
| myGNN | |
| AdditiveGraphModel | Additive Graph Model (GAM-like GNN). |
| GCNEncoder | Simple 2-layer Graph Convolutional Network encoder. |
| VariationalGNNEncoder | Variational Graph Encoder producing mean and log-standard-deviation embeddings for Variational Graph Autoencoders (VGAE) or graph-based latent models. |
- class graphtoolbox.models.gnn.ConvAdapter(conv_class, in_dim, hidden_channels, heads=1, base_kwargs=None)[source]
Bases: Module
- forward(x, edge_index, x0=None, **tensors)[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
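The adapter's exact semantics are not documented above; based only on the constructor and forward signatures, a speculative usage sketch (graph, shapes, and the adapter's return value are illustrative assumptions, and the roles of x0 and **tensors are not documented here):

import torch
from torch_geometric.nn import GATConv
from graphtoolbox.models.gnn import ConvAdapter

# Toy graph: 4 nodes with 8 features each, 4 directed edges.
x = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])

# Presumably wraps a PyG convolution class behind a uniform call
# signature so different convolution families are interchangeable.
adapter = ConvAdapter(GATConv, in_dim=8, hidden_channels=16, heads=2)
out = adapter(x, edge_index)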
- class graphtoolbox.models.gnn.myGNN(in_channels: int, num_layers: int, hidden_channels: int, out_channels: int, conv_class=<class 'torch_geometric.nn.conv.gat_conv.GATConv'>, conv_kwargs=None, heads=1, **kwargs)[source]
Bases: Module
- forward(x, edge_index, edge_weight=None, edge_attr=None, return_attention=False, **kwargs)[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
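A minimal usage sketch, assuming the constructor and forward behave as their signatures suggest (the toy graph and output shape are illustrative assumptions):

import torch
from torch_geometric.nn import GATConv
from graphtoolbox.models.gnn import myGNN

# Toy graph: 5 nodes in a cycle, 8 input features each.
x = torch.randn(5, 8)
edge_index = torch.tensor([[0, 1, 2, 3, 4],
                           [1, 2, 3, 4, 0]])

# Three GAT layers, 4 attention heads, scalar output per node.
model = myGNN(in_channels=8, num_layers=3, hidden_channels=32,
              out_channels=1, conv_class=GATConv, heads=4)
y = model(x, edge_index)  # expected per-node outputs, shape [5, 1]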
- class graphtoolbox.models.gnn.AdditiveGraphModel(feature_group_dims: dict, num_layers: int, hidden_channels: int, out_channels: int, conv_class=None, conv_kwargs=None, **kwargs)[source]
Bases: Module
Additive Graph Model (GAM-like GNN).
This model builds per-feature-group encoders and then applies one shared GNN backbone over a block-diagonal expansion of the original graph (one disjoint copy per feature group). The final per-node prediction is the average of the group-wise outputs plus a learnable bias.
Mathematical form
\(h_g = \mathrm{Encoder}_g(x_g)\), \(y_g = \mathrm{GNN}(h_g, \mathcal{G})\), \(\hat{y} = b + \frac{1}{|G|} \sum_{g \in G} y_g\).
Key points
- A single shared backbone over disjoint copies improves stability (see the sketch after this list).
- Avoids repeated backbone calls, preventing exploding activations.
- Provides additive interpretability via group_outputs.
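The block-diagonal expansion can be pictured as follows. This is an illustrative sketch of the idea only, not the model's internal code; the helper name is hypothetical:

import torch

def block_diagonal_expansion(edge_index, num_nodes, num_groups):
    # Stack num_groups disjoint copies of the graph by offsetting
    # node indices; one GNN pass then covers every feature group.
    copies = [edge_index + g * num_nodes for g in range(num_groups)]
    return torch.cat(copies, dim=1)  # shape [2, E * num_groups]

# A 3-node cycle replicated for 2 feature groups:
edge_index = torch.tensor([[0, 1, 2],
                           [1, 2, 0]])
expanded = block_diagonal_expansion(edge_index, num_nodes=3, num_groups=2)
# Copy 0 uses nodes 0-2, copy 1 uses nodes 3-5; no edges cross copies.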
- Parameters:
feature_group_dims (dict[str, int]) – Mapping from group name to input feature dimension.
num_layers (int) – Number of GNN layers in the shared backbone.
hidden_channels (int) – Hidden dimension used by the encoders and the backbone.
out_channels (int) – Output dimension per node.
conv_class (type, optional) – PyG convolution class (default: GCNConv).
conv_kwargs (dict, optional) – Extra keyword arguments forwarded to the convolution constructor.
- group_names
Ordered list of feature group names.
- Type: list[str]
- num_groups
Number of feature groups.
- Type: int
- encoders
Per-group MLP encoders producing hidden representations.
- Type: nn.ModuleDict
- bias
Learnable scalar bias added to predictions.
- Type: torch.nn.Parameter
- group_index_map
Slices locating each group inside the concatenated feature tensor.
- Type: dict[str, slice]
- forward(x, edge_index, edge_weight=None, mask=None, return_attention: bool = False, return_group_outputs: bool = False, **kwargs)[source]
Forward pass.
The fast path (return_attention=False) constructs a block-diagonal graph and performs one shared GNN pass. The slow path (return_attention=True) iterates per group to collect attention maps.
- Parameters:
x (torch.Tensor) – Node features concatenated by group, shape [N, F_total].
edge_index (torch.Tensor) – Graph connectivity, shape [2, E].
edge_weight (torch.Tensor | None) – Optional edge weights, shape [E].
mask (Any | None) – Unused placeholder (kept for trainer compatibility).
return_attention (bool) – If True, returns per-group attention statistics.
return_group_outputs (bool) – If True, also returns per-group node outputs.
- Returns:
y_hat: Tensor of shape [N, out_channels] (always returned).
group_outputs (dict[str, Tensor]) if return_group_outputs is True.
attention_per_group (dict[str, dict]) if return_attention is True.
- Return type:
torch.Tensor | tuple
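A usage sketch under assumed group names and shapes (the "weather"/"calendar" groups and the toy graph are illustrative only):

import torch
from graphtoolbox.models.gnn import AdditiveGraphModel

# Two hypothetical feature groups: 4 "weather" and 3 "calendar" columns,
# concatenated in this order along the feature axis (F_total = 7).
model = AdditiveGraphModel(feature_group_dims={'weather': 4, 'calendar': 3},
                           num_layers=2, hidden_channels=16, out_channels=1)

x = torch.randn(10, 7)                      # N = 10 nodes
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 4]])

y_hat, group_outputs = model(x, edge_index, return_group_outputs=True)
# group_outputs['weather'] isolates that group's additive contribution.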
- class graphtoolbox.models.gnn.GCNEncoder(in_channels, out_channels)[source]
Bases: Module
Simple 2-layer Graph Convolutional Network encoder.
This encoder maps node features into a compact latent space using two GCNConv layers and a ReLU activation.
- Parameters:
in_channels (int) – Number of input node features.
out_channels (int) – Dimension of the latent embedding space.
Examples
>>> enc = GCNEncoder(in_channels=32, out_channels=16)
>>> z = enc(x, edge_index)
>>> z.shape
torch.Size([N, 16])
- forward(x, edge_index)[source]
Forward pass through the GCN encoder.
- Parameters:
x (torch.Tensor) – Node feature matrix of shape [N, F].
edge_index (torch.LongTensor) – Graph connectivity in COO format.
- Returns:
Latent node representations of shape [N, out_channels].
- Return type:
torch.Tensor
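For orientation, the behavior described above (two GCNConv layers and a ReLU) corresponds roughly to the following sketch; the class name and the intermediate width of 2 * out_channels are assumptions, not necessarily the library's choices:

import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCNEncoderSketch(nn.Module):
    # Two GCNConv layers with a ReLU in between, as the docstring
    # describes; the hidden width (2 * out_channels) is assumed.
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, 2 * out_channels)
        self.conv2 = GCNConv(2 * out_channels, out_channels)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        return self.conv2(x, edge_index)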
- class graphtoolbox.models.gnn.VariationalGNNEncoder(in_channels, out_channels, conv='gcn')[source]
Bases: Module
Variational Graph Encoder producing mean and log-standard-deviation embeddings for Variational Graph Autoencoders (VGAE) or other graph-based latent models.
Supports both GCN and GraphSAGE convolutions.
- Parameters:
in_channels (int) – Number of input node features.
out_channels (int) – Latent embedding dimension.
conv ({'gcn', 'sage'}, default='gcn') – Type of graph convolution to use.
- conv1
First convolution layer (shared by both the mu and logstd branches).
- Type: torch_geometric.nn.MessagePassing
- conv_mu
Convolution layer producing mean embeddings.
- Type: torch_geometric.nn.MessagePassing
- conv_logstd
Convolution layer producing log-standard-deviation embeddings.
- Type: torch_geometric.nn.MessagePassing
Examples
>>> enc = VariationalGNNEncoder(in_channels=32, out_channels=16, conv='sage')
>>> mu, logstd = enc(x, edge_index)
>>> mu.shape, logstd.shape
(torch.Size([N, 16]), torch.Size([N, 16]))
- forward(x, edge_index)[source]
Compute the latent mean and log-standard-deviation representations.
- Parameters:
x (torch.Tensor) – Node feature matrix of shape [N, F].
edge_index (torch.LongTensor) – Graph connectivity in COO format.
- Returns:
Mean and log-standard-deviation tensors, each of shape [N, out_channels].
- Return type:
tuple of torch.Tensor
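Because the encoder returns a (mu, logstd) pair, it should plug directly into PyG's torch_geometric.nn.VGAE wrapper; a minimal loss sketch (the toy graph is illustrative, and scaling the KL term by 1/N follows the common VGAE convention rather than anything mandated here):

import torch
from torch_geometric.nn import VGAE
from graphtoolbox.models.gnn import VariationalGNNEncoder

x = torch.randn(6, 32)
edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],
                           [1, 2, 3, 4, 5, 0]])

model = VGAE(VariationalGNNEncoder(in_channels=32, out_channels=16, conv='sage'))
z = model.encode(x, edge_index)          # reparameterized sample, shape [6, 16]
loss = model.recon_loss(z, edge_index) + model.kl_loss() / x.size(0)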