Graph Neural Networks
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
import numpy as np
import pandas as pd
import warnings
warnings.filterwarnings('ignore')
import os
import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)
# The torch-scatter/torch-sparse wheels must be built for the installed PyTorch version, so that version is passed in the URL
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
import torch_geometric
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
import networkx as nx
from tqdm.notebook import tqdm
2.0.0.dev20230212
Introduction to Graph Data
Graphs are a general language for describing complex systems of interconnected entities. A graph models data as a set of objects, called nodes (or vertices), together with the relationships between them, called edges. Graph data is therefore often referred to as network data: a collection of points connected by links.
A node (vertex) is a point in the graph, while an edge is a connection that joins two nodes. Graphs can represent data from many domains, including biology, physics, social science, and chemistry.
Types of Graphs
Graphs can be grouped into several categories: directed/undirected, weighted/binary, and homogeneous/heterogeneous graphs.
- Directed/undirected graphs: in a directed graph every edge has a direction (from a source node to a target node), while in an undirected graph the edges have no direction.
- Weighted/binary graphs: in a weighted graph each edge is assigned a value (a weight), while in a binary (unweighted) graph the edges carry no value; an edge is simply present or absent.
- Homogeneous/heterogeneous graphs: in a homogeneous graph all nodes and edges are of the same type (e.g. a friendship graph), while in a heterogeneous graph the nodes and/or edges are of different types (e.g. a knowledge graph).
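These categories map directly onto how a graph is stored in code. The sketch below uses NumPy and a hypothetical three-node graph to show the coordinate (COO) edge-list layout that PyG uses for `edge_index`: one column per edge, with the source node in the first row and the target node in the second. An undirected graph is stored by listing every edge in both directions, and a weighted graph adds one value per edge. The helpers `make_undirected` and `dense_adjacency` are hypothetical names for this illustration, not PyG functions.

```python
import numpy as np

# Hypothetical 3-node graph. Each column of edge_index is one edge: (source, target).
directed_edge_index = np.array([[0, 1],
                                [1, 2]])  # edges 0->1 and 1->2
edge_weight = np.array([0.5, 2.0])       # weighted graph: one value per edge (omit for a binary graph)


def make_undirected(edge_index):
    """Add the reverse of every edge and drop duplicate (source, target) pairs."""
    both = np.concatenate([edge_index, edge_index[::-1]], axis=1)
    return np.unique(both, axis=1)  # unique columns, sorted lexicographically


def dense_adjacency(edge_index, num_nodes, weights=None):
    """Adjacency matrix with 1 per edge (binary graph) or the edge weight (weighted graph)."""
    adj = np.zeros((num_nodes, num_nodes))
    values = np.ones(edge_index.shape[1]) if weights is None else weights
    adj[edge_index[0], edge_index[1]] = values
    return adj


undirected_edge_index = make_undirected(directed_edge_index)
print(undirected_edge_index)                               # 4 stored edges for 2 undirected connections
print(dense_adjacency(directed_edge_index, 3, edge_weight))  # weighted, not symmetric
print(dense_adjacency(undirected_edge_index, 3))             # binary, symmetric
```

Storing both directions is why PyG reports 156 edges for the Karate Club graph used below, which has 78 undirected connections.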
Traditional graph analysis relies on algorithms such as graph search, clustering, and spanning-tree algorithms. A major downside of these methods is that they require prior knowledge of the graph before they can be applied.
Because graph data has an irregular structure (each node can have a different number of neighbors, and the nodes have no natural ordering), traditional machine learning models, which expect fixed-size inputs, cannot interpret it properly. This motivated the development of Graph Neural Networks (GNNs), a branch of deep learning concerned with learning from graph-structured data.
Exploring graph data with PyTorch Geometric
Now we will explore graphs and graph neural networks using the PyTorch Geometric (PyG) package, which we installed and imported above.
Let’s create a function that will help us visualize graph data.
def visualize_graph(G, color):
    plt.figure(figsize=(7,7))
    plt.xticks([])
    plt.yticks([])
    nx.draw_networkx(G, pos=nx.spring_layout(G, seed=42), with_labels=False,
                     node_color=color, cmap="Set2")
    plt.show()
PyG provides access to many benchmark graph datasets. In this notebook, we use Zachary’s karate club network.
Zachary’s karate club is a classic example of social relationships within a small group. The dataset records the interactions between club members outside of the club. It also documents a conflict between the instructor, Mr. Hi, and the club president, John, over the price of lessons. In the end, half of the members formed a new club around Mr. Hi, while the others either stayed at the old club or gave up karate. Note that PyG’s version of the dataset labels each node with one of four classes obtained via modularity-based clustering, rather than with the two factions of the split.
from torch_geometric.datasets import KarateClub
dataset = KarateClub()
Now that we have imported a graph dataset, let’s look at some of its properties. We first look at properties at the level of the dataset, and then select a graph from the dataset and explore its properties.
print('Dataset properties')
print('==============================================================')
print(f'Dataset: {dataset}')  # Name of the dataset
print(f'Number of graphs in the dataset: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')  # Number of features per node
print(f'Number of classes: {dataset.num_classes}')  # Number of classes a node can be assigned to
# Since the dataset contains a single graph, we select it and explore its properties
data = dataset[0]
print('Graph properties')
print('==============================================================')
# Gather some statistics about the graph.
print(f'Number of nodes: {data.num_nodes}')  # Number of nodes in the graph
print(f'Number of edges: {data.num_edges}')  # Number of stored edges; each undirected edge is stored in both directions
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')  # Average number of neighbors per node
print(f'Contains isolated nodes: {data.has_isolated_nodes()}')  # Does the graph contain nodes without any edges?
print(f'Contains self-loops: {data.has_self_loops()}')  # Does the graph contain edges from a node to itself?
print(f'Is undirected: {data.is_undirected()}')  # Is the graph undirected?
Dataset properties
==============================================================
Dataset: KarateClub()
Number of graphs in the dataset: 1
Number of features: 34
Number of classes: 4
Graph properties
==============================================================
Number of nodes: 34
Number of edges: 156
Average node degree: 4.59
Contains isolated nodes: False
Contains self-loops: False
Is undirected: True
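The average degree printed above is simply `num_edges / num_nodes` because PyG stores each undirected edge twice: the degrees of all nodes sum to twice the number of undirected edges, which is exactly the number of stored edges (here 156 / 34 ≈ 4.59). A minimal NumPy sketch on a hypothetical four-node graph makes this concrete; `node_degrees` is an illustrative helper, and with real PyG data `torch_geometric.utils.degree(data.edge_index[0], data.num_nodes)` computes the same per-node counts.

```python
import numpy as np


def node_degrees(edge_index, num_nodes):
    """Count the stored edges leaving each node.

    For an undirected graph stored in both directions (as PyG does),
    this is each node's ordinary degree.
    """
    return np.bincount(edge_index[0], minlength=num_nodes)


# Hypothetical graph: triangle 0-1-2 plus node 3 attached to node 2.
pairs = [(0, 1), (1, 2), (0, 2), (2, 3)]
src = [u for u, v in pairs] + [v for u, v in pairs]  # both directions
dst = [v for u, v in pairs] + [u for u, v in pairs]
edge_index = np.array([src, dst])
num_nodes = 4

deg = node_degrees(edge_index, num_nodes)
print(deg)                              # [2 2 3 1]
print(edge_index.shape[1] / num_nodes)  # 8 stored edges / 4 nodes = 2.0
print(deg.mean())                       # 2.0, the same average degree
```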
Now let’s visualize the graph using the function we created earlier. First, we convert the PyG graph to a NetworkX graph.
from torch_geometric.utils import to_networkx
G = to_networkx(data, to_undirected=True)
visualize_graph(G, color=data.y)
Implementing a Graph Neural Network
In this notebook, we use a simple GNN built from Graph Convolutional Network (GCN) layers.
- GCN is a specific type of GNN that uses convolutional operations to propagate information between nodes in a graph. 
- GCNs leverage a localized aggregation of neighboring node features to update the representations of the nodes. 
- GCNs are based on the convolutional operation commonly used in image processing, adapted to the graph domain. 
- The layers in a GCN typically apply a graph convolution operation followed by non-linear activation functions. 
- GCNs have been successful in tasks such as node classification, where nodes are labeled based on their features and graph structure. 
Our GNN is defined by stacking three graph convolution layers, which corresponds to aggregating 3-hop neighborhood information around each node (all nodes up to 3 “hops” away). In addition, the GCNConv layers reduce the node feature dimensionality to 2, i.e., 34→4→4→2. Each GCNConv layer is followed by a tanh non-linearity. Finally, a linear layer acts as a classifier that maps each node to one of the 4 classes.
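Each GCNConv layer applies the propagation rule of Kipf and Welling, \(H' = \hat{D}^{-1/2}\hat{A}\hat{D}^{-1/2} H W\), where \(\hat{A} = A + I\) is the adjacency matrix with self-loops and \(\hat{D}\) is its diagonal degree matrix (PyG’s layer also adds a learnable bias, omitted here). The NumPy sketch below stacks three such layers with tanh, using the 34→4→4→2 sizes from the model, on a hypothetical five-node path graph with random weights. It illustrates the computation only and is not PyG’s implementation; `gcn_layer` and `gcn_forward` are illustrative names.

```python
import numpy as np


def gcn_layer(adj, h, weight):
    """One graph convolution: D^-1/2 (A + I) D^-1/2 H W (no bias term)."""
    a_hat = adj + np.eye(adj.shape[0])        # self-loops keep each node's own features
    d_inv_sqrt = 1.0 / np.sqrt(a_hat.sum(axis=1))
    norm_adj = d_inv_sqrt[:, None] * a_hat * d_inv_sqrt[None, :]  # symmetric normalization
    return norm_adj @ h @ weight              # aggregate neighbors, then project


def gcn_forward(adj, x, weights):
    """Stack GCN layers, each followed by tanh, as in the model above."""
    h = x
    for w in weights:
        h = np.tanh(gcn_layer(adj, h, w))
    return h


# Hypothetical undirected path graph 0-1-2-3-4 (symmetric adjacency matrix).
num_nodes = 5
adj = np.zeros((num_nodes, num_nodes))
for u in range(num_nodes - 1):
    adj[u, u + 1] = adj[u + 1, u] = 1.0

x = np.eye(num_nodes, 34)                     # toy one-hot features with 34 columns
dims = [34, 4, 4, 2]                          # 34 -> 4 -> 4 -> 2
rng = np.random.default_rng(0)
weights = [rng.normal(scale=0.5, size=(i, o)) for i, o in zip(dims[:-1], dims[1:])]

h = gcn_forward(adj, x, weights)
print(h.shape)                                # (5, 2): one 2-D embedding per node
```

Because each layer only mixes features between direct neighbors, node 0’s final embedding depends on nodes 0 to 3 but not on node 4, which is four hops away; this is the 3-hop receptive field described above.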
import torch
from torch.nn import Linear
from torch_geometric.nn import GCNConv
class GCN(torch.nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = GCNConv(dataset.num_features, 4)
        self.conv2 = GCNConv(4, 4)
        self.conv3 = GCNConv(4, 2)
        self.classifier = Linear(2, dataset.num_classes)
    def forward(self, x, edge_index):
        h = self.conv1(x, edge_index)
        h = h.tanh()
        h = self.conv2(h, edge_index)
        h = h.tanh()
        h = self.conv3(h, edge_index)
        h = h.tanh()  # Final GNN embedding space.
        
        # Apply a final (linear) classifier.
        out = self.classifier(h)
        return out, h
model = GCN()
print(model)
GCN(
  (conv1): GCNConv(34, 4)
  (conv2): GCNConv(4, 4)
  (conv3): GCNConv(4, 2)
  (classifier): Linear(in_features=2, out_features=4, bias=True)
)
Training the model
To train the network, we use CrossEntropyLoss as the loss function and Adam as the optimizer.
model = GCN()
criterion = torch.nn.CrossEntropyLoss()  #Initialize the CrossEntropyLoss function.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # Initialize the Adam optimizer.
def train(data):
    optimizer.zero_grad()  # Clear gradients.
    out, h = model(data.x, data.edge_index)  # Perform a single forward pass.
    loss = criterion(out[data.train_mask], data.y[data.train_mask])  # Compute the loss solely based on the training nodes.
    loss.backward()  # Derive gradients.
    optimizer.step()  # Update parameters based on gradients.
    return loss, h
for epoch in range(401):
    loss, h = train(data)
    print(f'Epoch: {epoch}, Loss: {loss.item()}')  # .item() extracts the Python float from the scalar loss tensor
Epoch: 0, Loss: 1.414404273033142
Epoch: 1, Loss: 1.4079036712646484
Epoch: 2, Loss: 1.4017865657806396
Epoch: 3, Loss: 1.3959277868270874
Epoch: 4, Loss: 1.3902204036712646
Epoch: 5, Loss: 1.3845514059066772
Epoch: 6, Loss: 1.3788156509399414
Epoch: 7, Loss: 1.372924566268921
Epoch: 8, Loss: 1.3668041229248047
Epoch: 9, Loss: 1.3603920936584473
Epoch: 10, Loss: 1.3536322116851807
Epoch: 11, Loss: 1.346463680267334
Epoch: 12, Loss: 1.3388220071792603
Epoch: 13, Loss: 1.3306517601013184
Epoch: 14, Loss: 1.321907877922058
Epoch: 15, Loss: 1.3125460147857666
Epoch: 16, Loss: 1.3025131225585938
Epoch: 17, Loss: 1.2917399406433105
Epoch: 18, Loss: 1.2801496982574463
Epoch: 19, Loss: 1.267681360244751
Epoch: 20, Loss: 1.2542929649353027
Epoch: 21, Loss: 1.2399559020996094
Epoch: 22, Loss: 1.224645733833313
Epoch: 23, Loss: 1.2083479166030884
Epoch: 24, Loss: 1.1910830736160278
Epoch: 25, Loss: 1.172907829284668
Epoch: 26, Loss: 1.1538856029510498
Epoch: 27, Loss: 1.1340937614440918
Epoch: 28, Loss: 1.113678216934204
Epoch: 29, Loss: 1.0928146839141846
Epoch: 30, Loss: 1.071663737297058
Epoch: 31, Loss: 1.0504322052001953
Epoch: 32, Loss: 1.0293234586715698
Epoch: 33, Loss: 1.0085136890411377
Epoch: 34, Loss: 0.9882092475891113
Epoch: 35, Loss: 0.9685617685317993
Epoch: 36, Loss: 0.9497098326683044
Epoch: 37, Loss: 0.9317660927772522
Epoch: 38, Loss: 0.9147973656654358
Epoch: 39, Loss: 0.8988679051399231
Epoch: 40, Loss: 0.8839972615242004
Epoch: 41, Loss: 0.8701907992362976
Epoch: 42, Loss: 0.8574356436729431
Epoch: 43, Loss: 0.8456954956054688
Epoch: 44, Loss: 0.8349326252937317
Epoch: 45, Loss: 0.8250938057899475
Epoch: 46, Loss: 0.8161190748214722
Epoch: 47, Loss: 0.8079493045806885
Epoch: 48, Loss: 0.8005207180976868
Epoch: 49, Loss: 0.7937697768211365
Epoch: 50, Loss: 0.7876371145248413
Epoch: 51, Loss: 0.78206467628479
Epoch: 52, Loss: 0.7769970893859863
Epoch: 53, Loss: 0.7723835706710815
Epoch: 54, Loss: 0.7681776881217957
Epoch: 55, Loss: 0.7643368244171143
Epoch: 56, Loss: 0.7608219981193542
Epoch: 57, Loss: 0.7575987577438354
Epoch: 58, Loss: 0.7546370625495911
Epoch: 59, Loss: 0.7519098520278931
Epoch: 60, Loss: 0.7493933439254761
Epoch: 61, Loss: 0.7470666170120239
Epoch: 62, Loss: 0.7449116706848145
Epoch: 63, Loss: 0.7429123520851135
Epoch: 64, Loss: 0.7410544157028198
Epoch: 65, Loss: 0.7393249869346619
Epoch: 66, Loss: 0.7377125024795532
Epoch: 67, Loss: 0.7362067699432373
Epoch: 68, Loss: 0.7347980737686157
Epoch: 69, Loss: 0.7334779500961304
Epoch: 70, Loss: 0.7322388291358948
Epoch: 71, Loss: 0.7310730814933777
Epoch: 72, Loss: 0.7299746870994568
Epoch: 73, Loss: 0.7289379239082336
Epoch: 74, Loss: 0.727957546710968
Epoch: 75, Loss: 0.7270289659500122
Epoch: 76, Loss: 0.7261483669281006
Epoch: 77, Loss: 0.7253118753433228
Epoch: 78, Loss: 0.7245160341262817
Epoch: 79, Loss: 0.7237582206726074
Epoch: 80, Loss: 0.7230353355407715
Epoch: 81, Loss: 0.7223449945449829
Epoch: 82, Loss: 0.7216848134994507
Epoch: 83, Loss: 0.7210525274276733
Epoch: 84, Loss: 0.7204459309577942
Epoch: 85, Loss: 0.7198634743690491
Epoch: 86, Loss: 0.7193030118942261
Epoch: 87, Loss: 0.7187634110450745
Epoch: 88, Loss: 0.7182430028915405
Epoch: 89, Loss: 0.7177403569221497
Epoch: 90, Loss: 0.7172543406486511
Epoch: 91, Loss: 0.716783881187439
Epoch: 92, Loss: 0.7163279056549072
Epoch: 93, Loss: 0.7158852815628052
Epoch: 94, Loss: 0.7154552340507507
Epoch: 95, Loss: 0.7150367498397827
Epoch: 96, Loss: 0.7146288156509399
Epoch: 97, Loss: 0.7142307758331299
Epoch: 98, Loss: 0.713841438293457
Epoch: 99, Loss: 0.7134604454040527
Epoch: 100, Loss: 0.7130865454673767
Epoch: 101, Loss: 0.7127190828323364
Epoch: 102, Loss: 0.7123571038246155
Epoch: 103, Loss: 0.7120000720024109
Epoch: 104, Loss: 0.7116469144821167
Epoch: 105, Loss: 0.7112966179847717
Epoch: 106, Loss: 0.7109485864639282
Epoch: 107, Loss: 0.7106014490127563
Epoch: 108, Loss: 0.7102541327476501
Epoch: 109, Loss: 0.7099055051803589
Epoch: 110, Loss: 0.7095539569854736
Epoch: 111, Loss: 0.7091981768608093
Epoch: 112, Loss: 0.7088361978530884
Epoch: 113, Loss: 0.7084661722183228
Epoch: 114, Loss: 0.708085834980011
Epoch: 115, Loss: 0.7076926827430725
Epoch: 116, Loss: 0.7072839736938477
Epoch: 117, Loss: 0.7068567276000977
Epoch: 118, Loss: 0.7064076662063599
Epoch: 119, Loss: 0.7059327960014343
Epoch: 120, Loss: 0.7054286599159241
Epoch: 121, Loss: 0.7048908472061157
Epoch: 122, Loss: 0.7043149471282959
Epoch: 123, Loss: 0.7036959528923035
Epoch: 124, Loss: 0.7030287981033325
Epoch: 125, Loss: 0.7023081183433533
Epoch: 126, Loss: 0.7015291452407837
Epoch: 127, Loss: 0.700687050819397
Epoch: 128, Loss: 0.6997777819633484
Epoch: 129, Loss: 0.6987966895103455
Epoch: 130, Loss: 0.6977395415306091
Epoch: 131, Loss: 0.696602463722229
Epoch: 132, Loss: 0.695381760597229
Epoch: 133, Loss: 0.6940730810165405
Epoch: 134, Loss: 0.6926735043525696
Epoch: 135, Loss: 0.6911810040473938
Epoch: 136, Loss: 0.6895937323570251
Epoch: 137, Loss: 0.6879096031188965
Epoch: 138, Loss: 0.6861276030540466
Epoch: 139, Loss: 0.6842458248138428
Epoch: 140, Loss: 0.6822646856307983
Epoch: 141, Loss: 0.6801841855049133
Epoch: 142, Loss: 0.678005039691925
Epoch: 143, Loss: 0.6757267713546753
Epoch: 144, Loss: 0.6733517646789551
Epoch: 145, Loss: 0.670880913734436
Epoch: 146, Loss: 0.6683160066604614
Epoch: 147, Loss: 0.6656591892242432
Epoch: 148, Loss: 0.662912905216217
Epoch: 149, Loss: 0.6600789427757263
Epoch: 150, Loss: 0.6571605801582336
Epoch: 151, Loss: 0.6541601419448853
Epoch: 152, Loss: 0.6510804891586304
Epoch: 153, Loss: 0.647924542427063
Epoch: 154, Loss: 0.6446951627731323
Epoch: 155, Loss: 0.6413953304290771
Epoch: 156, Loss: 0.6380279660224915
Epoch: 157, Loss: 0.6346004009246826
Epoch: 158, Loss: 0.6311676502227783
Epoch: 159, Loss: 0.6280736923217773
Epoch: 160, Loss: 0.6247416138648987
Epoch: 161, Loss: 0.6204184293746948
Epoch: 162, Loss: 0.6175937056541443
Epoch: 163, Loss: 0.6130721569061279
Epoch: 164, Loss: 0.6099883913993835
Epoch: 165, Loss: 0.6056156754493713
Epoch: 166, Loss: 0.6024593710899353
Epoch: 167, Loss: 0.5980790853500366
Epoch: 168, Loss: 0.5947374105453491
Epoch: 169, Loss: 0.5904760360717773
Epoch: 170, Loss: 0.5870139598846436
Epoch: 171, Loss: 0.582822859287262
Epoch: 172, Loss: 0.5791879892349243
Epoch: 173, Loss: 0.5751645565032959
Epoch: 174, Loss: 0.5713585019111633
Epoch: 175, Loss: 0.567484438419342
Epoch: 176, Loss: 0.5635411739349365
Epoch: 177, Loss: 0.5597836375236511
Epoch: 178, Loss: 0.5557853579521179
Epoch: 179, Loss: 0.5520825386047363
Epoch: 180, Loss: 0.5481284856796265
Epoch: 181, Loss: 0.5443800091743469
Epoch: 182, Loss: 0.5405675172805786
Epoch: 183, Loss: 0.5367638468742371
Epoch: 184, Loss: 0.5330841541290283
Epoch: 185, Loss: 0.5293054580688477
Epoch: 186, Loss: 0.5256626605987549
Epoch: 187, Loss: 0.5220222473144531
Epoch: 188, Loss: 0.5183926820755005
Epoch: 189, Loss: 0.5148745775222778
Epoch: 190, Loss: 0.5113506317138672
Epoch: 191, Loss: 0.5078821182250977
Epoch: 192, Loss: 0.5045010447502136
Epoch: 193, Loss: 0.5011410117149353
Epoch: 194, Loss: 0.49784329533576965
Epoch: 195, Loss: 0.4946274161338806
Epoch: 196, Loss: 0.49145352840423584
Epoch: 197, Loss: 0.4883362650871277
Epoch: 198, Loss: 0.4852979779243469
Epoch: 199, Loss: 0.4823189973831177
Epoch: 200, Loss: 0.47939279675483704
Epoch: 201, Loss: 0.4765353798866272
Epoch: 202, Loss: 0.47374674677848816
Epoch: 203, Loss: 0.47101518511772156
Epoch: 204, Loss: 0.46834152936935425
Epoch: 205, Loss: 0.46573275327682495
Epoch: 206, Loss: 0.46318715810775757
Epoch: 207, Loss: 0.4606989622116089
Epoch: 208, Loss: 0.4582676887512207
Epoch: 209, Loss: 0.45589637756347656
Epoch: 210, Loss: 0.45358556509017944
Epoch: 211, Loss: 0.45133209228515625
Epoch: 212, Loss: 0.44913429021835327
Epoch: 213, Loss: 0.4469928443431854
Epoch: 214, Loss: 0.4449085295200348
Epoch: 215, Loss: 0.44288063049316406
Epoch: 216, Loss: 0.44090723991394043
Epoch: 217, Loss: 0.4389871656894684
Epoch: 218, Loss: 0.43712010979652405
Epoch: 219, Loss: 0.43530601263046265
Epoch: 220, Loss: 0.43354326486587524
Epoch: 221, Loss: 0.4318305552005768
Epoch: 222, Loss: 0.43016621470451355
Epoch: 223, Loss: 0.42854925990104675
Epoch: 224, Loss: 0.42697861790657043
Epoch: 225, Loss: 0.4254530668258667
Epoch: 226, Loss: 0.42397117614746094
Epoch: 227, Loss: 0.4225310683250427
Epoch: 228, Loss: 0.42113178968429565
Epoch: 229, Loss: 0.41977158188819885
Epoch: 230, Loss: 0.4184495806694031
Epoch: 231, Loss: 0.41716426610946655
Epoch: 232, Loss: 0.41591429710388184
Epoch: 233, Loss: 0.4146982431411743
Epoch: 234, Loss: 0.41351503133773804
Epoch: 235, Loss: 0.4123631417751312
Epoch: 236, Loss: 0.41124171018600464
Epoch: 237, Loss: 0.4101494550704956
Epoch: 238, Loss: 0.4090851843357086
Epoch: 239, Loss: 0.40804779529571533
Epoch: 240, Loss: 0.40703630447387695
Epoch: 241, Loss: 0.4060494899749756
Epoch: 242, Loss: 0.4050863981246948
Epoch: 243, Loss: 0.4041459560394287
Epoch: 244, Loss: 0.40322721004486084
Epoch: 245, Loss: 0.40232914686203003
Epoch: 246, Loss: 0.4014507234096527
Epoch: 247, Loss: 0.4005908966064453
Epoch: 248, Loss: 0.39974868297576904
Epoch: 249, Loss: 0.39892318844795227
Epoch: 250, Loss: 0.3981132507324219
Epoch: 251, Loss: 0.3973178267478943
Epoch: 252, Loss: 0.3965359330177307
Epoch: 253, Loss: 0.39576637744903564
Epoch: 254, Loss: 0.3950079679489136
Epoch: 255, Loss: 0.39425957202911377
Epoch: 256, Loss: 0.39351969957351685
Epoch: 257, Loss: 0.3927871584892273
Epoch: 258, Loss: 0.39206013083457947
Epoch: 259, Loss: 0.3913373351097107
Epoch: 260, Loss: 0.39061659574508667
Epoch: 261, Loss: 0.38989609479904175
Epoch: 262, Loss: 0.38917356729507446
Epoch: 263, Loss: 0.388446569442749
Epoch: 264, Loss: 0.3877125084400177
Epoch: 265, Loss: 0.3869687616825104
Epoch: 266, Loss: 0.3862120509147644
Epoch: 267, Loss: 0.385439395904541
Epoch: 268, Loss: 0.38464784622192383
Epoch: 269, Loss: 0.38383424282073975
Epoch: 270, Loss: 0.38299548625946045
Epoch: 271, Loss: 0.38212850689888
Epoch: 272, Loss: 0.3812297582626343
Epoch: 273, Loss: 0.3802952170372009
Epoch: 274, Loss: 0.37932053208351135
Epoch: 275, Loss: 0.3783015012741089
Epoch: 276, Loss: 0.3772348165512085
Epoch: 277, Loss: 0.3761177659034729
Epoch: 278, Loss: 0.3749470114707947
Epoch: 279, Loss: 0.37371814250946045
Epoch: 280, Loss: 0.3724268674850464
Epoch: 281, Loss: 0.3710700273513794
Epoch: 282, Loss: 0.36964452266693115
Epoch: 283, Loss: 0.36814582347869873
Epoch: 284, Loss: 0.3665703237056732
Epoch: 285, Loss: 0.36491674184799194
Epoch: 286, Loss: 0.3631839156150818
Epoch: 287, Loss: 0.36136889457702637
Epoch: 288, Loss: 0.35947149991989136
Epoch: 289, Loss: 0.35749104619026184
Epoch: 290, Loss: 0.3554253578186035
Epoch: 291, Loss: 0.35327666997909546
Epoch: 292, Loss: 0.3510452210903168
Epoch: 293, Loss: 0.3487322926521301
Epoch: 294, Loss: 0.34633979201316833
Epoch: 295, Loss: 0.3438686728477478
Epoch: 296, Loss: 0.34132301807403564
Epoch: 297, Loss: 0.33870524168014526
Epoch: 298, Loss: 0.3360193967819214
Epoch: 299, Loss: 0.33326831459999084
Epoch: 300, Loss: 0.3304571807384491
Epoch: 301, Loss: 0.3275907337665558
Epoch: 302, Loss: 0.32467275857925415
Epoch: 303, Loss: 0.3217098116874695
Epoch: 304, Loss: 0.3187068700790405
Epoch: 305, Loss: 0.3156689405441284
Epoch: 306, Loss: 0.3126027584075928
Epoch: 307, Loss: 0.3095141649246216
Epoch: 308, Loss: 0.30640846490859985
Epoch: 309, Loss: 0.3032921552658081
Epoch: 310, Loss: 0.3001706600189209
Epoch: 311, Loss: 0.2970506548881531
Epoch: 312, Loss: 0.2939375638961792
Epoch: 313, Loss: 0.2908394932746887
Epoch: 314, Loss: 0.28776970505714417
Epoch: 315, Loss: 0.28478091955184937
Epoch: 316, Loss: 0.28216126561164856
Epoch: 317, Loss: 0.2806890606880188
Epoch: 318, Loss: 0.281312495470047
Epoch: 319, Loss: 0.27288466691970825
Epoch: 320, Loss: 0.27456799149513245
Epoch: 321, Loss: 0.27207839488983154
Epoch: 322, Loss: 0.2669985592365265
Epoch: 323, Loss: 0.26659494638442993
Epoch: 324, Loss: 0.25929903984069824
Epoch: 325, Loss: 0.2601289749145508
Epoch: 326, Loss: 0.254558801651001
Epoch: 327, Loss: 0.2547817528247833
Epoch: 328, Loss: 0.249167338013649
Epoch: 329, Loss: 0.24909116327762604
Epoch: 330, Loss: 0.2447076141834259
Epoch: 331, Loss: 0.24408449232578278
Epoch: 332, Loss: 0.2399887889623642
Epoch: 333, Loss: 0.23915280401706696
Epoch: 334, Loss: 0.2357933521270752
Epoch: 335, Loss: 0.2346203625202179
Epoch: 336, Loss: 0.231528177857399
Epoch: 337, Loss: 0.2302950769662857
Epoch: 338, Loss: 0.22760164737701416
Epoch: 339, Loss: 0.22624777257442474
Epoch: 340, Loss: 0.22371965646743774
Epoch: 341, Loss: 0.22239263355731964
Epoch: 342, Loss: 0.22007593512535095
Epoch: 343, Loss: 0.2187439650297165
Epoch: 344, Loss: 0.21652215719223022
Epoch: 345, Loss: 0.21523401141166687
Epoch: 346, Loss: 0.21314474940299988
Epoch: 347, Loss: 0.21188393235206604
Epoch: 348, Loss: 0.20987601578235626
Epoch: 349, Loss: 0.20864510536193848
Epoch: 350, Loss: 0.20675095915794373
Epoch: 351, Loss: 0.2055433839559555
Epoch: 352, Loss: 0.20373988151550293
Epoch: 353, Loss: 0.20253685116767883
Epoch: 354, Loss: 0.20084065198898315
Epoch: 355, Loss: 0.1996387541294098
Epoch: 356, Loss: 0.1980428397655487
Epoch: 357, Loss: 0.19682282209396362
Epoch: 358, Loss: 0.1953372359275818
Epoch: 359, Loss: 0.1940995752811432
Epoch: 360, Loss: 0.1927228569984436
Epoch: 361, Loss: 0.19146019220352173
Epoch: 362, Loss: 0.19018493592739105
Epoch: 363, Loss: 0.18890531361103058
Epoch: 364, Loss: 0.1877131164073944
Epoch: 365, Loss: 0.1864350140094757
Epoch: 366, Loss: 0.1853000968694687
Epoch: 367, Loss: 0.18404754996299744
Epoch: 368, Loss: 0.18294095993041992
Epoch: 369, Loss: 0.18173813819885254
Epoch: 370, Loss: 0.18063828349113464
Epoch: 371, Loss: 0.17949633300304413
Epoch: 372, Loss: 0.17839308083057404
Epoch: 373, Loss: 0.17731067538261414
Epoch: 374, Loss: 0.17621085047721863
Epoch: 375, Loss: 0.1751726269721985
Epoch: 376, Loss: 0.17409324645996094
Epoch: 377, Loss: 0.1730785220861435
Epoch: 378, Loss: 0.17203523218631744
Epoch: 379, Loss: 0.1710285097360611
Epoch: 380, Loss: 0.170027494430542
Epoch: 381, Loss: 0.1690281182527542
Epoch: 382, Loss: 0.16806179285049438
Epoch: 383, Loss: 0.16708001494407654
Epoch: 384, Loss: 0.1661340892314911
Epoch: 385, Loss: 0.16518047451972961
Epoch: 386, Loss: 0.16424526274204254
Epoch: 387, Loss: 0.16332247853279114
Epoch: 388, Loss: 0.16239924728870392
Epoch: 389, Loss: 0.16149979829788208
Epoch: 390, Loss: 0.16059643030166626
Epoch: 391, Loss: 0.15971118211746216
Epoch: 392, Loss: 0.15883225202560425
Epoch: 393, Loss: 0.1579587608575821
Epoch: 394, Loss: 0.15710139274597168
Epoch: 395, Loss: 0.15624427795410156
Epoch: 396, Loss: 0.15540161728858948
Epoch: 397, Loss: 0.154564768075943
Epoch: 398, Loss: 0.15373432636260986
Epoch: 399, Loss: 0.15291598439216614
Epoch: 400, Loss: 0.1521005630493164
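The training loop computes the loss only on nodes selected by `data.train_mask` (in PyG’s KarateClub, a single labelled node per class), which makes this a semi-supervised node classification problem. The NumPy sketch below shows what the masked loss amounts to: `CrossEntropyLoss` with its default mean reduction is a log-softmax followed by the negative log-likelihood, averaged over the selected rows. The `masked_accuracy` helper is a hypothetical addition for checking predictions on the unlabelled nodes, and the logits, labels and mask are made-up values.

```python
import numpy as np


def masked_cross_entropy(logits, labels, mask):
    """Mean cross-entropy over rows where mask is True (log-softmax + NLL)."""
    z = logits[mask]
    y = labels[mask]
    z = z - z.max(axis=1, keepdims=True)  # subtract the row max for numerical stability
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(y)), y].mean()


def masked_accuracy(logits, labels, mask):
    """Fraction of rows where mask is True whose arg-max class equals the label."""
    return (logits[mask].argmax(axis=1) == labels[mask]).mean()


# Made-up outputs for 6 nodes and 4 classes; only nodes 0 and 3 are training nodes.
logits = np.array([[2.0, 0.1, 0.1, 0.1],
                   [0.2, 1.5, 0.1, 0.3],
                   [0.1, 0.2, 0.3, 1.0],
                   [0.0, 0.0, 3.0, 0.0],
                   [1.0, 1.0, 1.0, 1.0],
                   [0.5, 0.2, 0.1, 2.5]])
labels = np.array([0, 1, 3, 2, 2, 3])
train_mask = np.array([True, False, False, True, False, False])

print(masked_cross_entropy(logits, labels, train_mask))  # uses nodes 0 and 3 only
print(masked_accuracy(logits, labels, ~train_mask))      # 0.75 on the 4 unlabelled nodes
```

Changing the outputs of unlabelled nodes leaves the loss unchanged, which is why a low training loss says little about the nodes outside the mask.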
There is much more analysis that can be done with GNNs on the Karate Club dataset. See here for examples. Our goal in this lecture was to use it as a way to introduce the implementation and use of GNNs within the PyTorch framework.
We now turn to examples using simulated open data from the CMS experiment at the LHC. The algorithms’ inputs are features of the reconstructed charged particles in a jet and of the secondary vertices associated with them. By describing the jet shower as a combination of particle-to-particle and particle-to-vertex interactions, the model learns a jet representation that is optimized for the classification task.
We first show an example using a model called Deep Sets and then compare it to an Interaction Model GNN.
Deep Sets
We start by looking at Deep Sets networks in PyTorch. The architecture is based on the paper Deep Sets (Zaheer et al., 2017).
!pip install wget
import wget
!pip install -U PyYAML
!pip install uproot
!pip install awkward
!pip install mplhep
import yaml
# WGET for colab
if not os.path.exists("definitions_lorentz.yml"):
    url = "https://raw.githubusercontent.com/jmduarte/iaifi-summer-school/main/book/definitions_lorentz.yml"
    definitionsFile = wget.download(url)
with open("definitions_lorentz.yml") as file:
    # The FullLoader parameter handles the conversion from YAML
    # scalar values to the Python dictionary format
    definitions = yaml.load(file, Loader=yaml.FullLoader)
features = definitions["features"]
spectators = definitions["spectators"]
labels = definitions["labels"]
nfeatures = definitions["nfeatures"]
nspectators = definitions["nspectators"]
nlabels = definitions["nlabels"]
ntracks = definitions["ntracks"]
Data Loader#
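Next we define the dataset loader. The `DeepSetsDataset` class, downloaded from the course repository if it is not already present, reads the jets from the ROOT file and zero-pads each jet's track list to `ntracks` entries (the `npad` argument).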
# If in colab
if not os.path.exists("DeepSetsDataset.py"):
    urlDSD = "https://raw.githubusercontent.com/jmduarte/iaifi-summer-school/main/book/DeepSetsDataset.py"
    DSD = wget.download(urlDSD)
if not os.path.exists("utils.py"):
    urlUtils = "https://raw.githubusercontent.com/jmduarte/iaifi-summer-school/main/book/utils.py"
    utils = wget.download(urlUtils)
from DeepSetsDataset import DeepSetsDataset
# For colab
import os.path
if not os.path.exists("ntuple_merged_90.root"):
    urlFILE = "http://opendata.cern.ch/eos/opendata/cms/datascience/HiggsToBBNtupleProducerTool/HiggsToBBNTuple_HiggsToBB_QCD_RunII_13TeV_MC/train/ntuple_merged_90.root"
    dataFILE = wget.download(urlFILE)
train_files = ["ntuple_merged_90.root"]
train_generator = DeepSetsDataset(
    features,
    labels,
    spectators,
    start_event=0,
    stop_event=10000,
    npad=ntracks,
    file_names=train_files,
)
train_generator.process()
test_generator = DeepSetsDataset(
    features,
    labels,
    spectators,
    start_event=10001,
    stop_event=14001,
    npad=ntracks,
    file_names=train_files,
)
test_generator.process()
Deep Sets Network#
Deep Sets models are designed to be explicitly permutation invariant. At their core they are composed of two networks, \(\phi\) and \(\rho\), such that the total network \(f\) is given by
\[
f(\mathcal{X}) = \rho\left(\sum_{\mathbf{x}_i \in \mathcal{X}} \phi(\mathbf{x}_i)\right),
\]
where \(\mathbf{x}_i\) are the features for the \(i\)-th element in the input set \(\mathcal{X}\).
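We will define a DeepSets model that takes as input up to 60 tracks per jet, zero-padded to a fixed length, with 6 features per track (the Lorentz-vector-based inputs from definitions_lorentz.yml). The \(\phi\) network is applied to each track, and the per-track outputs are averaged (mean pooling, which is permutation invariant just like the sum) before being passed to \(\rho\).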
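The permutation invariance of the Deep Sets construction can be checked numerically. The NumPy sketch below uses small random weights as stand-ins for \(\phi\) and \(\rho\) (hypothetical toy networks, not the trained model) and mean pooling, which is what the model below uses via `scatter_mean`; the mean is order independent for the same reason the sum is.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy weights: phi maps each 6-feature element to 16 dimensions and
# rho maps the pooled 16-dimensional vector to 2 outputs (sizes mirror the model below).
W_phi = rng.normal(size=(6, 16))
W_rho = rng.normal(size=(16, 2))


def toy_phi(x):
    # Applied independently to every element (row) of the set: linear map + ReLU.
    return np.maximum(x @ W_phi, 0.0)


def toy_rho(h):
    return h @ W_rho


def deep_set(X):
    # X has shape (n_elements, 6); pooling over axis 0 discards the element order.
    return toy_rho(toy_phi(X).mean(axis=0))


X = rng.normal(size=(10, 6))
X_shuffled = X[rng.permutation(len(X))]
print(np.allclose(deep_set(X), deep_set(X_shuffled)))  # True
```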
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import (
    Sequential as Seq,
    Linear as Lin,
    ReLU,
    BatchNorm1d,
    AvgPool1d,
    Sigmoid,
    Conv1d,
)
from torch_scatter import scatter_mean
# ntracks = 60
inputs = 6
hidden1 = 64
hidden2 = 32
hidden3 = 16
classify1 = 50
outputs = 2
class DeepSets(torch.nn.Module):
    def __init__(self):
        super(DeepSets, self).__init__()
        self.phi = Seq(
            Conv1d(inputs, hidden1, 1),
            BatchNorm1d(hidden1),
            ReLU(),
            Conv1d(hidden1, hidden2, 1),
            BatchNorm1d(hidden2),
            ReLU(),
            Conv1d(hidden2, hidden3, 1),
            BatchNorm1d(hidden3),
            ReLU(),
        )
        self.rho = Seq(
            Lin(hidden3, classify1),
            BatchNorm1d(classify1),
            ReLU(),
            Lin(classify1, outputs),
            Sigmoid(),
        )
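    def forward(self, x):
        # x has shape (batch, inputs, ntracks); phi is applied to each track independently
        out = self.phi(x)
        # Mean over the track dimension (permutation-invariant pooling); the index is
        # built on the same device as the activations so this also works on GPU
        index = torch.zeros(out.size(-1), dtype=torch.long, device=out.device)
        out = scatter_mean(out, index, dim=-1)
        # Drop only the pooled axis so a batch of size 1 keeps its batch dimension
        return self.rho(torch.squeeze(out, dim=-1))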
model = DeepSets()
print(model)
print("----------")
print({l: model.state_dict()[l].shape for l in model.state_dict()})
model = DeepSets().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
DeepSets(
  (phi): Sequential(
    (0): Conv1d(6, 64, kernel_size=(1,), stride=(1,))
    (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv1d(64, 32, kernel_size=(1,), stride=(1,))
    (4): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Conv1d(32, 16, kernel_size=(1,), stride=(1,))
    (7): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
  )
  (rho): Sequential(
    (0): Linear(in_features=16, out_features=50, bias=True)
    (1): BatchNorm1d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=50, out_features=2, bias=True)
    (4): Sigmoid()
  )
)
----------
{'phi.0.weight': torch.Size([64, 6, 1]), 'phi.0.bias': torch.Size([64]), 'phi.1.weight': torch.Size([64]), 'phi.1.bias': torch.Size([64]), 'phi.1.running_mean': torch.Size([64]), 'phi.1.running_var': torch.Size([64]), 'phi.1.num_batches_tracked': torch.Size([]), 'phi.3.weight': torch.Size([32, 64, 1]), 'phi.3.bias': torch.Size([32]), 'phi.4.weight': torch.Size([32]), 'phi.4.bias': torch.Size([32]), 'phi.4.running_mean': torch.Size([32]), 'phi.4.running_var': torch.Size([32]), 'phi.4.num_batches_tracked': torch.Size([]), 'phi.6.weight': torch.Size([16, 32, 1]), 'phi.6.bias': torch.Size([16]), 'phi.7.weight': torch.Size([16]), 'phi.7.bias': torch.Size([16]), 'phi.7.running_mean': torch.Size([16]), 'phi.7.running_var': torch.Size([16]), 'phi.7.num_batches_tracked': torch.Size([]), 'rho.0.weight': torch.Size([50, 16]), 'rho.0.bias': torch.Size([50]), 'rho.1.weight': torch.Size([50]), 'rho.1.bias': torch.Size([50]), 'rho.1.running_mean': torch.Size([50]), 'rho.1.running_var': torch.Size([50]), 'rho.1.num_batches_tracked': torch.Size([]), 'rho.3.weight': torch.Size([2, 50]), 'rho.3.bias': torch.Size([2])}
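The \(\phi\) network above is built from `Conv1d` layers with `kernel_size=1`. On an input of shape (batch, features, tracks), such a layer applies the same linear map to every track independently, which is exactly the per-element \(\phi(\mathbf{x}_i)\) of the Deep Sets formula. The NumPy sketch below re-derives this equivalence with random weights; it illustrates the arithmetic and is not PyTorch's implementation.

```python
import numpy as np

rng = np.random.default_rng(1)
n_batch, c_in, c_out, n_tracks = 4, 6, 64, 60

# Same (batch, features, tracks) layout as the DeepSets input.
x_demo = rng.normal(size=(n_batch, c_in, n_tracks))
# Conv1d weight of shape (c_out, c_in, 1) with the kernel axis dropped, plus a bias.
W = rng.normal(size=(c_out, c_in))
b = rng.normal(size=c_out)

# Kernel size 1: the output at track t depends only on the input at track t.
conv_out = np.einsum("oc,bct->bot", W, x_demo) + b[None, :, None]

# The same computation written as one shared linear layer applied track by track.
loop_out = np.stack([x_demo[:, :, t] @ W.T + b for t in range(n_tracks)], axis=-1)

print(conv_out.shape)  # (4, 64, 60)
print(np.allclose(conv_out, loop_out))  # True
```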
Define the training loop#
@torch.no_grad()
def test(model, loader, total, batch_size, leave=False):
    model.eval()
    xentropy = nn.CrossEntropyLoss(reduction="mean")
    sum_loss = 0.0
    t = tqdm(enumerate(loader), total=total / batch_size, leave=leave)
    for i, data in t:
        x = data[0].to(device)
        y = data[1].to(device)
        y = torch.argmax(y, dim=1)
        batch_output = model(x)
        batch_loss_item = xentropy(batch_output, y).item()
        sum_loss += batch_loss_item
        t.set_description("loss = %.5f" % (batch_loss_item))
        t.refresh()  # to show immediately the update
    return sum_loss / (i + 1)
def train(model, optimizer, loader, total, batch_size, leave=False):
    model.train()
    xentropy = nn.CrossEntropyLoss(reduction="mean")
    sum_loss = 0.0
    t = tqdm(enumerate(loader), total=total / batch_size, leave=leave)
    for i, data in t:
        x = data[0].to(device)
        y = data[1].to(device)
        y = torch.argmax(y, dim=1)
        optimizer.zero_grad()
        batch_output = model(x)
        batch_loss = xentropy(batch_output, y)
        batch_loss.backward()
        batch_loss_item = batch_loss.item()
        t.set_description("loss = %.5f" % batch_loss_item)
        t.refresh()  # to show immediately the update
        sum_loss += batch_loss_item
        optimizer.step()
    return sum_loss / (i + 1)
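Both loops use `nn.CrossEntropyLoss`, which treats the model output as unnormalized scores (logits), applies a log-softmax, and takes the negative log-probability of the target class; the one-hot labels are first turned into class indices with `torch.argmax`. A plain-Python sketch of that computation (the helper names are hypothetical):

```python
import math


def cross_entropy(logits, one_hot):
    # Same step as the training loop: one-hot label -> class index via argmax.
    target = max(range(len(one_hot)), key=lambda k: one_hot[k])
    # log-sum-exp computed stably by subtracting the largest logit.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    # Negative log-softmax of the target class.
    return log_sum_exp - logits[target]


def mean_cross_entropy(batch_logits, batch_labels):
    # reduction="mean": average the per-example losses over the batch.
    losses = [cross_entropy(z, y) for z, y in zip(batch_logits, batch_labels)]
    return sum(losses) / len(losses)


print(round(cross_entropy([0.0, 0.0], [0, 1]), 4))  # 0.6931, i.e. log(2)
```

One consequence for the DeepSets model: its `rho` network ends in `Sigmoid()`, so both scores passed to the loss lie in (0, 1) and differ by less than 1. The per-jet loss therefore cannot fall below \(\log(1 + e^{-1}) \approx 0.313\), even for a perfectly classified jet, whereas the interaction network later in the notebook outputs unbounded scores.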
Define training, validation, testing data generators#
from torch.utils.data import ConcatDataset
train_generator_data = ConcatDataset(train_generator.datas)
test_generator_data = ConcatDataset(test_generator.datas)
from torch.utils.data import random_split, DataLoader
torch.manual_seed(0)
valid_frac = 0.20
train_length = len(train_generator_data)
valid_num = int(valid_frac * train_length)
batch_size = 32
train_dataset, valid_dataset = random_split(
    train_generator_data, [train_length - valid_num, valid_num]
)
def collate(items):
    l = sum(items, [])
    return Batch.from_data_list(l)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# train_loader.collate_fn = collate
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
# valid_loader.collate_fn = collate
test_loader = DataLoader(test_generator_data, batch_size=batch_size, shuffle=False)
# test_loader.collate_fn = collate
train_samples = len(train_dataset)
valid_samples = len(valid_dataset)
test_samples = len(test_generator_data)
print(train_length)
print(train_samples)
print(valid_samples)
print(test_samples)
9387
7510
1877
3751
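The printed counts follow from the split arithmetic: `int()` truncates \(0.2 \times 9387 = 1877.4\) to 1877 validation jets, leaving 7510 for training. A small helper that reproduces it (hypothetical, not part of the notebook; the same rule applies to the graph dataset later on):

```python
def split_sizes(n_total, valid_frac=0.20):
    # Mirrors the notebook: int() truncates, and the remainder goes to training.
    valid_num = int(valid_frac * n_total)
    return n_total - valid_num, valid_num


print(split_sizes(9387))  # (7510, 1877)
```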
Train#
import os.path as osp
n_epochs = 30
stale_epochs = 0
best_valid_loss = 99999
patience = 5
t = tqdm(range(0, n_epochs))
for epoch in t:
    loss = train(
        model,
        optimizer,
        train_loader,
        train_samples,
        batch_size,
        leave=bool(epoch == n_epochs - 1),
    )
    valid_loss = test(
        model,
        valid_loader,
        valid_samples,
        batch_size,
        leave=bool(epoch == n_epochs - 1),
    )
    print("Epoch: {:02d}, Training Loss:   {:.4f}".format(epoch, loss))
    print("           Validation Loss: {:.4f}".format(valid_loss))
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        modpath = osp.join("deepsets_best.pth")
        print("New best model saved to:", modpath)
        torch.save(model.state_dict(), modpath)
        stale_epochs = 0
    else:
        print("Stale epoch")
        stale_epochs += 1
    if stale_epochs >= patience:
        print("Early stopping after %i stale epochs" % patience)
        break
Epoch: 00, Training Loss:   0.4487
           Validation Loss: 0.4485
New best model saved to: deepsets_best.pth
Epoch: 01, Training Loss:   0.4403
           Validation Loss: 0.4482
New best model saved to: deepsets_best.pth
Epoch: 02, Training Loss:   0.4411
           Validation Loss: 0.4481
New best model saved to: deepsets_best.pth
Epoch: 03, Training Loss:   0.4401
           Validation Loss: 0.4489
Stale epoch
Epoch: 04, Training Loss:   0.4398
           Validation Loss: 0.4489
Stale epoch
Epoch: 05, Training Loss:   0.4398
           Validation Loss: 0.4473
New best model saved to: deepsets_best.pth
Epoch: 06, Training Loss:   0.4397
           Validation Loss: 0.4446
New best model saved to: deepsets_best.pth
Epoch: 07, Training Loss:   0.4391
           Validation Loss: 0.4441
New best model saved to: deepsets_best.pth
Epoch: 08, Training Loss:   0.4397
           Validation Loss: 0.4489
Stale epoch
Epoch: 09, Training Loss:   0.4389
           Validation Loss: 0.4489
Stale epoch
Epoch: 10, Training Loss:   0.4387
           Validation Loss: 0.4460
Stale epoch
Epoch: 11, Training Loss:   0.4393
           Validation Loss: 0.4458
Stale epoch
Epoch: 12, Training Loss:   0.4386
           Validation Loss: 0.4442
Stale epoch
Early stopping after 5 stale epochs
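The early-stopping rule in the loop above counts an epoch as stale whenever the validation loss does not strictly improve on the best value so far, and stops once `patience` consecutive stale epochs have accumulated. Replaying the logged validation losses through a stand-alone version of that rule (a sketch; `early_stopping_epoch` is a hypothetical helper) shows which epoch the saved `deepsets_best.pth` comes from:

```python
def early_stopping_epoch(valid_losses, patience=5):
    """Replay the notebook's early-stopping rule on a list of validation losses.

    Returns (best_epoch, stop_epoch); stop_epoch is None if training ran to the end.
    """
    best_loss = float("inf")
    best_epoch = None
    stale_epochs = 0
    for epoch, loss in enumerate(valid_losses):
        if loss < best_loss:
            # Strict improvement: new best model, reset the stale counter.
            best_loss, best_epoch = loss, epoch
            stale_epochs = 0
        else:
            stale_epochs += 1
        if stale_epochs >= patience:
            return best_epoch, epoch
    return best_epoch, None


# Validation losses copied from the DeepSets log above.
deepsets_valid = [0.4485, 0.4482, 0.4481, 0.4489, 0.4489, 0.4473, 0.4446,
                  0.4441, 0.4489, 0.4489, 0.4460, 0.4458, 0.4442]
print(early_stopping_epoch(deepsets_valid))  # (7, 12)
```

The checkpoint on disk is therefore from epoch 7, and the stop after epoch 12 matches the log.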
Evaluate on Test Data#
# In case you need to load the model from a pth file
# Trained on 4 vectors (as above in notebook)
# urlPTH = "https://raw.githubusercontent.com/jmduarte/iaifi-summer-school/main/book/deepsets_best_4vec.pth"
# pthFile = wget.download(urlPTH)
# model.load_state_dict(torch.load("deepsets_best_4vec.pth"))
# Trained on all possible inputs (a different configuration)
# urlPTH = "https://raw.githubusercontent.com/jmduarte/iaifi-summer-school/main/book/deepsets_best_AllTraining.pth"
# pthFile = wget.download(urlPTH)
# model.load_state_dict(torch.load("deepsets_best_AllTraining.pth"))
model.eval()
t = tqdm(enumerate(test_loader), total=test_samples / batch_size)
y_test = []
y_predict = []
track_pt = []
for i, data in t:
    x = data[0].to(device)
    y = data[1].to(device)
    track_pt.append(x[:, 0, 0].cpu().numpy())
    batch_output = model(x)
    y_predict.append(batch_output.detach().cpu().numpy())
    y_test.append(y.cpu().numpy())
track_pt = np.concatenate(track_pt)
y_test = np.concatenate(y_test)
y_predict = np.concatenate(y_predict)
_, bins, _ = plt.hist(track_pt[y_test[:, 1] == 1], bins=50, label="sig", histtype="step")
_, bins, _ = plt.hist(track_pt[y_test[:, 1] == 0], bins=bins, label="bkg", histtype="step")
plt.legend()
plt.semilogy()
from sklearn.metrics import roc_curve, auc
import mplhep as hep
plt.style.use(hep.style.ROOT)
# create ROC curves
fpr_deepset, tpr_deepset, threshold_deepset = roc_curve(y_test[:, 1], y_predict[:, 1])
with open("deepset_roc.npy", "wb") as f:
    np.save(f, fpr_deepset)
    np.save(f, tpr_deepset)
    np.save(f, threshold_deepset)
# plot ROC curves
plt.figure()
plt.plot(
    tpr_deepset,
    fpr_deepset,
    lw=2.5,
    label="DeepSet, AUC = {:.1f}%".format(auc(fpr_deepset, tpr_deepset) * 100),
)
plt.xlabel(r"True positive rate")
plt.ylabel(r"False positive rate")
plt.ylim(0.001, 1)
plt.xlim(0, 1)
plt.grid(True)
plt.legend(loc="upper left")
plt.show()
 
AUC and ROC Curves#
To better understand this curve, let us define the confusion matrix, which counts the four possible outcomes of a binary classifier: true positives (TP), false negatives (FN), false positives (FP), and true negatives (TN).
True Positive Rate (TPR) is a synonym for recall/sensitivity and is defined as \(\mathrm{TPR} = \frac{TP}{TP + FN}\).
TPR tells us what proportion of the positive class was correctly classified. A simple example would be determining what proportion of the people who are actually sick were correctly detected by the model.
False Negative Rate (FNR) is defined as \(\mathrm{FNR} = \frac{FN}{TP + FN} = 1 - \mathrm{TPR}\).
FNR tells us what proportion of the positive class was incorrectly classified by the classifier. A higher TPR and a lower FNR are desirable since we want to classify the positive class correctly.
True Negative Rate (TNR) is a synonym for specificity and is defined as \(\mathrm{TNR} = \frac{TN}{TN + FP}\).
Specificity tells us what proportion of the negative class was correctly classified. Continuing the sensitivity example, specificity is the proportion of healthy people who were correctly identified as healthy by the model.
False Positive Rate (FPR) is defined as \(\mathrm{FPR} = \frac{FP}{TN + FP} = 1 - \mathrm{TNR}\).
FPR tells us what proportion of the negative class was incorrectly classified by the classifier. A higher TNR and a lower FPR are desirable since we want to classify the negative class correctly.
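The four rates can be computed directly from the confusion-matrix counts. A minimal plain-Python sketch for 0/1 labels (1 = positive class), using made-up toy labels:

```python
def confusion_counts(y_true, y_pred):
    # Count the four outcomes for 0/1 labels.
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    return tp, fn, fp, tn


def rates(y_true, y_pred):
    tp, fn, fp, tn = confusion_counts(y_true, y_pred)
    return {
        "TPR": tp / (tp + fn),  # sensitivity / recall
        "FNR": fn / (tp + fn),  # = 1 - TPR
        "TNR": tn / (tn + fp),  # specificity
        "FPR": fp / (tn + fp),  # = 1 - TNR
    }


y_true = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
y_pred = [1, 1, 1, 0, 1, 0, 0, 0, 0, 0]
print(rates(y_true, y_pred))
# {'TPR': 0.75, 'FNR': 0.25, 'TNR': 0.8333333333333334, 'FPR': 0.16666666666666666}
```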
A ROC curve (receiver operating characteristic curve) is a graph showing the performance of a classification model at all classification thresholds. This curve plots two parameters:
- True Positive Rate 
- False Positive Rate 
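An ROC curve plots TPR vs. FPR at different classification thresholds. Lowering the classification threshold classifies more items as positive, thus increasing both false positives and true positives. The following figure shows a typical ROC curve. (The ROC plots in this notebook follow the convention common in particle physics and swap the axes: TPR, the signal efficiency, is on the x-axis and FPR, the background efficiency, is on the y-axis.)
AUC stands for “Area under the ROC Curve.” That is, AUC measures the entire two-dimensional area underneath the ROC curve, i.e. the integral of TPR with respect to FPR from (0,0) to (1,1). AUC provides an aggregate measure of performance across all possible classification thresholds: a perfect classifier has AUC = 1, while random guessing gives AUC = 0.5.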
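The threshold sweep behind an ROC curve, and the area under it, can be sketched in a few lines of plain Python. This is a simplified stand-in for what `sklearn.metrics.roc_curve` and `sklearn.metrics.auc` compute in the notebook, not their actual implementation: each distinct score is used as a threshold, and the area is integrated with the trapezoidal rule.

```python
def roc_points(y_true, scores):
    """Sweep the threshold from high to low; predict positive when score >= threshold."""
    n_pos = sum(y_true)
    n_neg = len(y_true) - n_pos
    fpr, tpr = [0.0], [0.0]  # threshold above every score: nothing is called positive
    for thr in sorted(set(scores), reverse=True):
        tp = sum(1 for y, s in zip(y_true, scores) if y == 1 and s >= thr)
        fp = sum(1 for y, s in zip(y_true, scores) if y == 0 and s >= thr)
        fpr.append(fp / n_neg)
        tpr.append(tp / n_pos)
    return fpr, tpr


def trapezoid_auc(x, y):
    # Area under y(x) by the trapezoidal rule; x must be sorted in increasing order.
    return sum((x[i] - x[i - 1]) * (y[i] + y[i - 1]) / 2 for i in range(1, len(x)))


y_true = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]
fpr, tpr = roc_points(y_true, scores)
print(fpr, tpr)  # [0.0, 0.0, 0.5, 0.5, 1.0] [0.0, 0.5, 0.5, 1.0, 1.0]
print(trapezoid_auc(fpr, tpr))  # 0.75
```

Although the notebook's plots put `tpr` on the x-axis, the AUC in the legend is still computed as `auc(fpr, tpr)`, i.e. the area under TPR as a function of FPR.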
Interaction Network#
Now we will look at graph neural networks using the PyTorch Geometric library: https://pytorch-geometric.readthedocs.io/. See [] for more details.
import yaml
import os.path
# WGET for colab
if not os.path.exists("definitions.yml"):
    url = "https://raw.githubusercontent.com/jmduarte/iaifi-summer-school/main/book/definitions.yml"
    definitionsFile = wget.download(url)
with open("definitions.yml") as file:
    # The FullLoader parameter handles the conversion from YAML
    # scalar values to the Python dictionary format
    definitions = yaml.load(file, Loader=yaml.FullLoader)
# You can test with using only 4-vectors by using:
# if not os.path.exists("definitions_lorentz.yml"):
#    url = "https://raw.githubusercontent.com/jmduarte/iaifi-summer-school/main/book/definitions_lorentz.yml"
#    definitionsFile = wget.download(url)
# with open('definitions_lorentz.yml') as file:
#    # The FullLoader parameter handles the conversion from YAML
#    # scalar values to the Python dictionary format
#    definitions = yaml.load(file, Loader=yaml.FullLoader)
features = definitions["features"]
spectators = definitions["spectators"]
labels = definitions["labels"]
nfeatures = definitions["nfeatures"]
nspectators = definitions["nspectators"]
nlabels = definitions["nlabels"]
ntracks = definitions["ntracks"]
Graph Datasets#
Next we define the graph dataset. We do this in a separate class, `GraphDataset`, following this example: https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html#creating-larger-datasets
Formally, a graph is represented by a triplet \(\mathcal G = (\mathbf{u}, V, E)\), consisting of a graph-level, or global, feature vector \(\mathbf{u}\), a set of \(N^v\) nodes \(V\), and a set of \(N^e\) edges \(E\). The nodes are given by \(V = \{\mathbf{v}_i\}_{i=1:N^v}\), where \(\mathbf{v}_i\) represents the \(i\)th node’s attributes. The edges connect pairs of nodes and are given by \(E = \{\left(\mathbf{e}_k, r_k, s_k\right)\}_{k=1:N^e}\), where \(\mathbf{e}_k\) represents the \(k\)th edge’s attributes, and \(r_k\) and \(s_k\) are the indices of the “receiver” and “sender” nodes, respectively, connected by the \(k\)th edge (from the sender node to the receiver node). The receiver and sender index vectors are an alternative way of encoding the directed adjacency matrix.
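The sender/receiver index vectors and a dense adjacency matrix carry the same information. The plain-Python sketch below converts between them; the fully connected 3-node graph is only an illustrative assumption, not necessarily how `GraphDataset` connects the tracks. In PyTorch Geometric these two lists are stacked into the `edge_index` tensor of shape (2, \(N^e\)), with senders (sources) in the first row and receivers (targets) in the second under the default message-flow direction.

```python
def adjacency_to_edge_index(adj):
    """Convert a dense 0/1 adjacency matrix to (senders, receivers) index lists.

    Convention assumed here: adj[s][r] == 1 means a directed edge s -> r.
    """
    senders, receivers = [], []
    for s, row in enumerate(adj):
        for r, connected in enumerate(row):
            if connected:
                senders.append(s)
                receivers.append(r)
    return senders, receivers


def edge_index_to_adjacency(senders, receivers, n_nodes):
    # Inverse mapping: mark every (sender, receiver) pair in a dense matrix.
    adj = [[0] * n_nodes for _ in range(n_nodes)]
    for s, r in zip(senders, receivers):
        adj[s][r] = 1
    return adj


# Illustrative choice: a fully connected graph on 3 nodes without self-loops.
n = 3
adj = [[int(i != j) for j in range(n)] for i in range(n)]
senders, receivers = adjacency_to_edge_index(adj)
print(senders)  # [0, 0, 1, 1, 2, 2]
print(receivers)  # [1, 2, 0, 2, 0, 1]
```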
 
# If in colab
if not os.path.exists("GraphDataset.py"):
    urlDSD = "https://raw.githubusercontent.com/jmduarte/iaifi-summer-school/main/book/GraphDataset.py"
    DSD = wget.download(urlDSD)
if not os.path.exists("utils.py"):
    urlUtils = "https://raw.githubusercontent.com/jmduarte/iaifi-summer-school/main/book/utils.py"
    utils = wget.download(urlUtils)
from GraphDataset import GraphDataset
# For Colab
if not os.path.exists("ntuple_merged_90.root"):
    urlFILE = "http://opendata.cern.ch/eos/opendata/cms/datascience/HiggsToBBNtupleProducerTool/HiggsToBBNTuple_HiggsToBB_QCD_RunII_13TeV_MC/train/ntuple_merged_90.root"
    dataFILE = wget.download(urlFILE)
file_names = ["ntuple_merged_90.root"]
graph_dataset = GraphDataset(
    "gdata_train",
    features,
    labels,
    spectators,
    start_event=0,
    stop_event=8000,
    n_events_merge=1,
    file_names=file_names,
)
test_dataset = GraphDataset(
    "gdata_test",
    features,
    labels,
    spectators,
    start_event=8001,
    stop_event=10001,
    n_events_merge=1,
    file_names=file_names,
)
Graph Neural Network#
Here, we recapitulate the “graph network” (GN) formalism described in this paper, which generalizes various GNNs and other similar methods.
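GNs are graph-to-graph mappings, whose output graphs have the same node and edge structure as the input. Formally, a GN block contains three “update” functions, \(\phi\), and three “aggregation” functions, \(\rho\). The stages of processing in a single GN block are:
\[
\begin{aligned}
\mathbf{e}'_k &= \phi^e\left(\mathbf{e}_k, \mathbf{v}_{r_k}, \mathbf{v}_{s_k}, \mathbf{u}\right) && \text{(edge update)}\\
\bar{\mathbf{e}}'_i &= \rho^{e \rightarrow v}\left(E'_i\right) && \text{(edge aggregation)}\\
\mathbf{v}'_i &= \phi^v\left(\bar{\mathbf{e}}'_i, \mathbf{v}_i, \mathbf{u}\right) && \text{(node update)}\\
\bar{\mathbf{e}}' &= \rho^{e \rightarrow u}\left(E'\right) && \text{(global edge aggregation)}\\
\bar{\mathbf{v}}' &= \rho^{v \rightarrow u}\left(V'\right) && \text{(global node aggregation)}\\
\mathbf{u}' &= \phi^u\left(\bar{\mathbf{e}}', \bar{\mathbf{v}}', \mathbf{u}\right) && \text{(global update)}
\end{aligned}
\]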
where \(E'_i = \left\{\left(\mathbf{e}'_k, r_k, s_k \right)\right\}_{r_k=i,\; k=1:N^e}\) contains the updated edge features for edges whose receiver node is the \(i^\text{th}\) node, \(E' = \bigcup_i E_i' = \left\{\left(\mathbf{e}'_k, r_k, s_k \right)\right\}_{k=1:N^e}\) is the set of updated edges, and \(V'=\left\{\mathbf{v}'_i\right\}_{i=1:N^v}\) is the set of updated nodes.
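A single GN block can be sketched directly from these equations. The NumPy version below is illustrative only: it uses the mean for all three \(\rho\) functions (matching the `scatter_mean` calls in the interaction network below) and toy update functions that simply add their inputs; both are assumptions made for the demo. A node with no incoming edges gets a zero aggregate, which matches the zero that `scatter_mean` produces for an index that receives no entries.

```python
import numpy as np


def gn_block(v, e, u, senders, receivers, phi_e, phi_v, phi_u):
    """One graph-network block using the mean for every aggregation (an illustrative choice).

    v: (N_v, d) node features, e: (N_e, d) edge features, u: (d,) global features.
    """
    n_nodes = v.shape[0]
    receivers = np.asarray(receivers)
    # Edge update: e'_k = phi_e(e_k, v_{r_k}, v_{s_k}, u)
    e_new = np.stack(
        [phi_e(e[k], v[receivers[k]], v[senders[k]], u) for k in range(len(e))]
    )
    # Edge aggregation: mean of the updated edges whose receiver is node i
    e_bar = np.zeros((n_nodes, e_new.shape[1]))
    for i in range(n_nodes):
        incoming = e_new[receivers == i]
        if len(incoming) > 0:
            e_bar[i] = incoming.mean(axis=0)
    # Node update: v'_i = phi_v(e_bar'_i, v_i, u)
    v_new = np.stack([phi_v(e_bar[i], v[i], u) for i in range(n_nodes)])
    # Global aggregations over all updated edges and nodes, then the global update
    u_new = phi_u(e_new.mean(axis=0), v_new.mean(axis=0), u)
    return v_new, e_new, u_new


# Hypothetical update functions that simply add their inputs.
def phi_e_sum(e_k, v_r, v_s, u_g):
    return e_k + v_r + v_s + u_g


def phi_v_sum(e_bar_i, v_i, u_g):
    return e_bar_i + v_i + u_g


def phi_u_sum(e_bar, v_bar, u_g):
    return e_bar + v_bar + u_g


# Toy graph: 3 nodes with 1-dimensional features and edges 0 -> 1 and 1 -> 2.
v = np.array([[1.0], [2.0], [3.0]])
e = np.zeros((2, 1))
u = np.zeros(1)
senders, receivers = [0, 1], [1, 2]
v_new, e_new, u_new = gn_block(v, e, u, senders, receivers, phi_e_sum, phi_v_sum, phi_u_sum)
# By hand: e' = [3, 5], v' = [1, 5, 8], u' = mean(e') + mean(v') = 4 + 14/3
print(e_new.ravel(), v_new.ravel(), u_new)
```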
 
We will define an interaction network model similar to this paper, but just modeling the particle-particle interactions. It will take as input all of the tracks (with 48 features) without truncating or zero-padding. Another modification is the use of batch normalization [] layers to improve the stability of the training.
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.nn import EdgeConv, global_mean_pool
from torch.nn import Sequential as Seq, Linear as Lin, ReLU, BatchNorm1d
from torch_scatter import scatter_mean
from torch_geometric.nn import MetaLayer
inputs = 48
hidden = 128
outputs = 2
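class EdgeBlock(torch.nn.Module):
    # Edge update phi^e: builds an edge feature from the two nodes the edge connects
    def __init__(self):
        super(EdgeBlock, self).__init__()
        self.edge_mlp = Seq(
            Lin(inputs * 2, hidden), BatchNorm1d(hidden), ReLU(), Lin(hidden, hidden)
        )
    def forward(self, src, dest, edge_attr, u, batch):
        # src/dest: features of the sender/receiver node of each edge;
        # edge_attr and u are unused (the input graph has no edge or global features)
        out = torch.cat([src, dest], 1)
        return self.edge_mlp(out)
class NodeBlock(torch.nn.Module):
    # Edge aggregation rho^{e->v} followed by the node update phi^v
    def __init__(self):
        super(NodeBlock, self).__init__()
        self.node_mlp_1 = Seq(
            Lin(inputs + hidden, hidden),
            BatchNorm1d(hidden),
            ReLU(),
            Lin(hidden, hidden),
        )
        self.node_mlp_2 = Seq(
            Lin(inputs + hidden, hidden),
            BatchNorm1d(hidden),
            ReLU(),
            Lin(hidden, hidden),
        )
    def forward(self, x, edge_index, edge_attr, u, batch):
        row, col = edge_index  # row: sender indices, col: receiver indices
        # Message for each edge from its sender node and the updated edge features
        out = torch.cat([x[row], edge_attr], dim=1)
        out = self.node_mlp_1(out)
        # Average the messages arriving at each receiver node
        out = scatter_mean(out, col, dim=0, dim_size=x.size(0))
        # Combine each node's own features with its aggregated messages
        out = torch.cat([x, out], dim=1)
        return self.node_mlp_2(out)
class GlobalBlock(torch.nn.Module):
    # Node aggregation rho^{v->u} followed by the global update phi^u (the 2 class scores)
    def __init__(self):
        super(GlobalBlock, self).__init__()
        self.global_mlp = Seq(
            Lin(hidden, hidden), BatchNorm1d(hidden), ReLU(), Lin(hidden, outputs)
        )
    def forward(self, x, edge_index, edge_attr, u, batch):
        # Average the updated node features of each graph (jet) in the batch
        out = scatter_mean(x, batch, dim=0)
        return self.global_mlp(out)
class InteractionNetwork(torch.nn.Module):
    def __init__(self):
        super(InteractionNetwork, self).__init__()
        # MetaLayer runs the edge, node and global blocks in sequence
        self.interactionnetwork = MetaLayer(EdgeBlock(), NodeBlock(), GlobalBlock())
        # Normalize the raw input features before message passing
        self.bn = BatchNorm1d(inputs)
    def forward(self, x, edge_index, batch):
        x = self.bn(x)
        # No edge or global input features are used, hence the two None arguments
        x, edge_attr, u = self.interactionnetwork(x, edge_index, None, None, batch)
        # The graph-level output u holds the classification scores
        return u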
model = InteractionNetwork().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
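Both `NodeBlock` and `GlobalBlock` rely on `scatter_mean` with `dim=0`: rows of the source tensor that share an index are averaged into one output row. In `NodeBlock` the index is `col` (the receiver of each edge), so messages are averaged per receiving node; in `GlobalBlock` it is `batch` (the graph each node belongs to), so node features are averaged per jet. A NumPy stand-in (a sketch of the semantics; `scatter_mean_np` is a hypothetical helper, not the `torch_scatter` implementation):

```python
import numpy as np


def scatter_mean_np(src, index, dim_size=None):
    """Mean of the rows of `src` grouped by `index` (like scatter_mean with dim=0).

    Groups that receive no rows are left at zero.
    """
    index = np.asarray(index)
    if dim_size is None:
        dim_size = int(index.max()) + 1
    sums = np.zeros((dim_size,) + src.shape[1:])
    counts = np.zeros(dim_size)
    np.add.at(sums, index, src)  # unbuffered add, so repeated indices accumulate
    np.add.at(counts, index, 1)
    counts = np.maximum(counts, 1)  # avoid division by zero for empty groups
    return sums / counts.reshape((-1,) + (1,) * (src.ndim - 1))


# One message per edge, averaged onto receiver nodes (the role of `col` in NodeBlock).
messages = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
col = [0, 0, 2]
print(scatter_mean_np(messages, col, dim_size=4))
# [[2. 3.]
#  [0. 0.]
#  [5. 6.]
#  [0. 0.]]

# Node features pooled per graph with a batch vector (the role of `batch` in GlobalBlock).
node_x = np.array([[1.0], [3.0], [10.0]])
print(scatter_mean_np(node_x, [0, 0, 1]))
```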
Define the training loop#
@torch.no_grad()
def test(model, loader, total, batch_size, leave=False):
    model.eval()
    xentropy = nn.CrossEntropyLoss(reduction="mean")
    sum_loss = 0.0
    t = tqdm(enumerate(loader), total=total / batch_size, leave=leave)
    for i, data in t:
        data = data.to(device)
        y = torch.argmax(data.y, dim=1)
        batch_output = model(data.x, data.edge_index, data.batch)
        batch_loss_item = xentropy(batch_output, y).item()
        sum_loss += batch_loss_item
        t.set_description("loss = %.5f" % (batch_loss_item))
        t.refresh()  # to show immediately the update
    return sum_loss / (i + 1)
def train(model, optimizer, loader, total, batch_size, leave=False):
    model.train()
    xentropy = nn.CrossEntropyLoss(reduction="mean")
    sum_loss = 0.0
    t = tqdm(enumerate(loader), total=total / batch_size, leave=leave)
    for i, data in t:
        data = data.to(device)
        y = torch.argmax(data.y, dim=1)
        optimizer.zero_grad()
        batch_output = model(data.x, data.edge_index, data.batch)
        batch_loss = xentropy(batch_output, y)
        batch_loss.backward()
        batch_loss_item = batch_loss.item()
        t.set_description("loss = %.5f" % batch_loss_item)
        t.refresh()  # to show immediately the update
        sum_loss += batch_loss_item
        optimizer.step()
    return sum_loss / (i + 1)
Define training, validation, testing data generators#
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataListLoader
from torch.utils.data import random_split
def collate(items):
    l = sum(items, [])
    return Batch.from_data_list(l)
torch.manual_seed(0)
valid_frac = 0.20
full_length = len(graph_dataset)
valid_num = int(valid_frac * full_length)
batch_size = 32
train_dataset, valid_dataset = random_split(
    graph_dataset, [full_length - valid_num, valid_num]
)
train_loader = DataListLoader(
    train_dataset, batch_size=batch_size, pin_memory=True, shuffle=True
)
train_loader.collate_fn = collate
valid_loader = DataListLoader(
    valid_dataset, batch_size=batch_size, pin_memory=True, shuffle=False
)
valid_loader.collate_fn = collate
test_loader = DataListLoader(
    test_dataset, batch_size=batch_size, pin_memory=True, shuffle=False
)
test_loader.collate_fn = collate
train_samples = len(train_dataset)
valid_samples = len(valid_dataset)
test_samples = len(test_dataset)
print(full_length)
print(train_samples)
print(valid_samples)
print(test_samples)
7501
6001
1500
1886
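The custom `collate` function first flattens the list of per-item lists with `sum(items, [])` (this assumes each `GraphDataset` item is a list of `Data` objects, which is what the flattening implies) and then calls `Batch.from_data_list`. Batching in PyTorch Geometric merges the graphs into one large disconnected graph: node features are concatenated, each graph's `edge_index` is shifted by the number of nodes that precede it, and a `batch` vector records which graph each node came from. The NumPy sketch below (plain dicts stand in for `Data` objects) imitates that bookkeeping:

```python
import numpy as np


def batch_graphs(graphs):
    """Merge graphs into one disconnected graph, imitating PyG-style batching.

    Each graph is a dict with "x" of shape (N_i, F) and "edge_index" of shape (2, E_i).
    """
    xs, edge_indices, batch_vec = [], [], []
    offset = 0
    for graph_id, g in enumerate(graphs):
        n_nodes = g["x"].shape[0]
        xs.append(g["x"])
        # Shift node indices so they point into the concatenated node array.
        edge_indices.append(g["edge_index"] + offset)
        # Record which graph each node belongs to.
        batch_vec.extend([graph_id] * n_nodes)
        offset += n_nodes
    return np.concatenate(xs), np.concatenate(edge_indices, axis=1), np.array(batch_vec)


# Two dataset items, each a list holding one graph; sum(items, []) flattens them as collate() does.
items = [
    [{"x": np.ones((2, 3)), "edge_index": np.array([[0, 1], [1, 0]])}],
    [{"x": np.zeros((3, 3)), "edge_index": np.array([[0, 2], [2, 1]])}],
]
graphs = sum(items, [])
x_all, edge_index_all, batch_vec = batch_graphs(graphs)
print(x_all.shape)  # (5, 3)
print(edge_index_all.tolist())  # [[0, 1, 2, 4], [1, 0, 4, 3]]
print(batch_vec.tolist())  # [0, 0, 1, 1, 1]
```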
Train#
import os.path as osp
n_epochs = 10
stale_epochs = 0
best_valid_loss = 99999
patience = 5
t = tqdm(range(0, n_epochs))
for epoch in t:
    loss = train(
        model,
        optimizer,
        train_loader,
        train_samples,
        batch_size,
        leave=bool(epoch == n_epochs - 1),
    )
    valid_loss = test(
        model,
        valid_loader,
        valid_samples,
        batch_size,
        leave=bool(epoch == n_epochs - 1),
    )
    print("Epoch: {:02d}, Training Loss:   {:.4f}".format(epoch, loss))
    print("           Validation Loss: {:.4f}".format(valid_loss))
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        modpath = osp.join("interactionnetwork_best.pth")
        print("New best model saved to:", modpath)
        torch.save(model.state_dict(), modpath)
        stale_epochs = 0
    else:
        print("Stale epoch")
        stale_epochs += 1
    if stale_epochs >= patience:
        print("Early stopping after %i stale epochs" % patience)
        break
Epoch: 00, Training Loss:   0.2306
           Validation Loss: 0.1614
New best model saved to: interactionnetwork_best.pth
Epoch: 01, Training Loss:   0.1688
           Validation Loss: 0.1328
New best model saved to: interactionnetwork_best.pth
Epoch: 02, Training Loss:   0.1638
           Validation Loss: 0.1307
New best model saved to: interactionnetwork_best.pth
Epoch: 03, Training Loss:   0.1482
           Validation Loss: 0.1558
Stale epoch
Epoch: 04, Training Loss:   0.1479
           Validation Loss: 0.1069
New best model saved to: interactionnetwork_best.pth
Epoch: 05, Training Loss:   0.1415
           Validation Loss: 0.1417
Stale epoch
Epoch: 06, Training Loss:   0.1480
           Validation Loss: 0.1338
Stale epoch
Epoch: 07, Training Loss:   0.1387
           Validation Loss: 0.1198
Stale epoch
Epoch: 08, Training Loss:   0.1394
           Validation Loss: 0.1224
Stale epoch
Epoch: 09, Training Loss:   0.1290
           Validation Loss: 0.1045
New best model saved to: interactionnetwork_best.pth
Evaluate on Test Data#
# In case you need to load the model from a pth file
# Trained on all possible inputs (as above in notebook)
# urlPTH = "https://raw.githubusercontent.com/jmduarte/iaifi-summer-school/main/book/interactionnetwork_best_Aug1_AllTraining.pth"
# pthFile = wget.download(urlPTH)
# model.load_state_dict(torch.load("interactionnetwork_best_Aug1_AllTraining.pth"))
# Trained on 4 vector input (different setup)
# urlPTH = "https://raw.githubusercontent.com/jmduarte/iaifi-summer-school/main/book/interactionnetwork_best_Aug1_4vec.pth"
# pthFile = wget.download(urlPTH)
# model.load_state_dict(torch.load("interactionnetwork_best_Aug1_4vec.pth"))
model.eval()
t = tqdm(enumerate(test_loader), total=test_samples / batch_size)
y_test = []
y_predict = []
for i, data in t:
    data = data.to(device)
    batch_output = model(data.x, data.edge_index, data.batch)
    y_predict.append(batch_output.detach().cpu().numpy())
    y_test.append(data.y.cpu().numpy())
y_test = np.concatenate(y_test)
y_predict = np.concatenate(y_predict)
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt
import mplhep as hep
plt.style.use(hep.style.ROOT)
# create ROC curves
fpr_gnn, tpr_gnn, threshold_gnn = roc_curve(y_test[:, 1], y_predict[:, 1])
with open("gnn_roc.npy", "wb") as f:
    np.save(f, fpr_gnn)
    np.save(f, tpr_gnn)
    np.save(f, threshold_gnn)
# For colab: download the DeepSets ROC arrays if they were not produced earlier in this notebook
if not os.path.exists("deepset_roc.npy"):
    urlROC = "https://raw.githubusercontent.com/jmduarte/iaifi-summer-school/main/book/deepset_roc.npy"
    rocFile = wget.download(urlROC)
with open("deepset_roc.npy", "rb") as f:
    fpr_deepset = np.load(f)
    tpr_deepset = np.load(f)
    threshold_deepset = np.load(f)
# plot ROC curves
plt.figure()
plt.plot(
    tpr_deepset,
    fpr_deepset,
    lw=2.5,
    label="DeepSet, AUC = {:.1f}%".format(auc(fpr_deepset, tpr_deepset) * 100),
)
plt.plot(
    tpr_gnn,
    fpr_gnn,
    lw=2.5,
    label="GNN, AUC = {:.1f}%".format(auc(fpr_gnn, tpr_gnn) * 100),
)
plt.xlabel(r"True positive rate")
plt.ylabel(r"False positive rate")
plt.semilogy()
plt.ylim(0.001, 1)
plt.xlim(0, 1)
plt.grid(True)
plt.legend(loc="upper left")
plt.show()
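Both ROC files hold three arrays written back to back with repeated `np.save` calls on one open file, so they must be read back with the same number of `np.load` calls in the same order. A self-contained check of that pattern, using an in-memory buffer in place of `deepset_roc.npy`/`gnn_roc.npy` and made-up array values:

```python
import io

import numpy as np

demo_fpr = np.array([0.0, 0.1, 1.0])
demo_tpr = np.array([0.0, 0.6, 1.0])
demo_thresholds = np.array([np.inf, 0.7, 0.0])

# In-memory stand-in for open("gnn_roc.npy", "wb"); each np.save appends one array.
buf = io.BytesIO()
np.save(buf, demo_fpr)
np.save(buf, demo_tpr)
np.save(buf, demo_thresholds)

# Reading back: each np.load consumes exactly one array, in the order it was written.
buf.seek(0)
fpr_loaded = np.load(buf)
tpr_loaded = np.load(buf)
thr_loaded = np.load(buf)
print(np.array_equal(demo_tpr, tpr_loaded))  # True
```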
 
Acknowledgments#
- Initial version: Mark Neubauer 
© Copyright 2024
 
    
  
  
 
  


