Source code for begin.utils.pretraining

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import copy
import math
import dgl
from dgl.utils import expand_as_pair
import dgl.function as fn

[docs]class PretrainingMethod(nn.Module): r""" Base framework for implementing pretraining methods. Arguments: encoder (torch.nn.Module): Pytorch model for pretraining. """
[docs] class PretrainIterator: r""" Base itearator for pretraining iterations. This class assumes full-batch training. """ def __init__(self, inputs, device): self.inputs = inputs self.count = 0 self.device = device def __iter__(self): return self def __next__(self): if self.count == 0: self.count = 1 return self.inputs.to(self.device) else: raise StopIteration
def __init__(self, encoder): super().__init__() self.encoder = encoder self.loss_fn = None
[docs] def iterator(self, inputs, device): """ Return iterator for the given input dataset. Args: inputs (object): the input graph dataset. device (str): target GPU device. Returns: An iterator for pretraining epoch. """ return self.PretrainIterator(inputs, device)
[docs] def inference(self, inputs): """ Return iterator for the given input dataset. Implementing this function is mandatory to operate the pretraining procedure. Args: inputs (object): the input sample drawn from iterator. Returns: a scalar which represents loss for pretraining. """ raise NotImplementedError
[docs] def update(self): """ This function is called when the best checkpoint needs to be updated. The default implementation stores the current `state_dict` of the model in `self.best_checkpoint`. """ self.best_checkpoint = copy.deepcopy(self.encoder.state_dict())
[docs] def processAfterTraining(self, original_model): """ This function is called once when the trainer concludes pretraining. The default implementation initializes the model using the saved best checkpoint (spec., `self.best_checkpoint`) before the main training begins. """ original_model.load_state_dict(self.best_checkpoint)
[docs]class DGI(PretrainingMethod): r""" An implementation of DGI for node-level and link-level problems. This code was implemented based on the official implementation by authors. For the details, see the `original paper <https://arxiv.org/pdf/1809.10341>`_. Arguments: encoder (torch.nn.Module): Pytorch model for pretraining. """
[docs] class Discriminator(nn.Module): def __init__(self, n_hidden): super().__init__() self.weight = nn.Parameter(torch.Tensor(n_hidden, n_hidden)) self.reset_parameters() def uniform(self, size, tensor): bound = 1.0 / math.sqrt(size) if tensor is not None: tensor.data.uniform_(-bound, bound) def reset_parameters(self): size = self.weight.size(0) self.uniform(size, self.weight) def forward(self, features, summary): features = torch.matmul(features, torch.matmul(self.weight, summary)) return features
def __init__(self, encoder, link_level=True): super().__init__(encoder) print("PRETRAINING_METHOD: DGI") self.discriminator = self.Discriminator(encoder.n_hidden) self.loss_fn = nn.BCEWithLogitsLoss() self.link_level = link_level
[docs] def inference(self, inputs): graph, features = inputs, inputs.ndata['feat'] if self.link_level: srcs, dsts = graph.edges() positive = self.encoder.forward_without_classifier(graph, features, srcs, dsts) perm = torch.randperm(graph.number_of_nodes()).to(features.device) negative = self.encoder.forward_without_classifier(graph, features[perm], srcs, dsts) else: positive = self.encoder.forward_without_classifier(graph, features) perm = torch.randperm(graph.number_of_nodes()).to(features.device) negative = self.encoder.forward_without_classifier(graph, features[perm]) summary = torch.sigmoid(positive.mean(dim=0)) positive = self.discriminator(positive, summary) negative = self.discriminator(negative, summary) l1 = self.loss_fn(positive, torch.ones_like(positive)) l2 = self.loss_fn(negative, torch.zeros_like(negative)) return l1 + l2
[docs]class DGISubgraphCL(PretrainingMethod): r""" An implementation of GraphCL (utilized with DGI, subgraph augmentation) for node-level and link-level problems. This code was implemented based on the official implementation by authors. For the details, see the `original paper <https://proceedings.neurips.cc/paper_files/paper/2020/file/3fe230348e9a12c13120749e3f9fa4cd-Paper.pdf>`_. Arguments: encoder (torch.nn.Module): Pytorch model for pretraining. """
[docs] class Discriminator(nn.Module): def __init__(self, n_hidden): super().__init__() self.weight = nn.Parameter(torch.Tensor(n_hidden, n_hidden)) self.reset_parameters() def uniform(self, size, tensor): bound = 1.0 / math.sqrt(size) if tensor is not None: tensor.data.uniform_(-bound, bound) def reset_parameters(self): size = self.weight.size(0) self.uniform(size, self.weight) def forward(self, features, summary): features = torch.matmul(features, torch.matmul(self.weight, summary)) return features
def __init__(self, encoder, link_level=True): super().__init__(encoder) print("PRETRAINING_METHOD: GraphCL (Node/Link)") self.discriminator = self.Discriminator(encoder.n_hidden) self.loss_fn = nn.BCEWithLogitsLoss() self.link_level = link_level def do_augmentation(self, graph): device = graph.ndata['feat'].device srcs, dsts = graph.edges() n_nodes = graph.num_nodes() with torch.no_grad(): with graph.local_scope(): root = torch.randint(n_nodes, (1,)) init_feat = n_nodes * torch.ones(n_nodes).to(device) init_feat[root.item()] = 0. graph.ndata['hop'] = init_feat graph.edata['w'] = torch.ones(graph.num_edges()).to(device) graph.edata['w'][srcs == dsts] = 0. cnt = (graph.ndata['hop'] < (n_nodes - 0.5)).sum().item() hop_cnt = 0 while cnt <= n_nodes * 0.8: graph.update_all(fn.u_add_e('hop', 'w', 'm'), fn.min('m', 'hop')) new_cnt = (graph.ndata['hop'] < (n_nodes - 0.5)).sum().item() hop_cnt += 1 if cnt == new_cnt: graph.ndata['hop'][(torch.arange(n_nodes).to(device)[graph.ndata['hop'] > (n_nodes - 0.5)])[torch.randperm(n_nodes - cnt)[0]]] = hop_cnt new_cnt += 1 cnt = new_cnt target_nodes = torch.argsort(graph.ndata['hop'].long() * n_nodes + torch.randperm(n_nodes).to(device), dim=0)[int(n_nodes * 0.8):] new_graph = copy.deepcopy(graph) return dgl.remove_nodes(new_graph, target_nodes)
[docs] def inference(self, inputs): graph, features = inputs, inputs.ndata['feat'] if self.link_level: srcs, dsts = graph.edges() positive = self.encoder.forward_without_classifier(graph, features, srcs, dsts) perm = torch.randperm(graph.number_of_nodes()).to(features.device) negative = self.encoder.forward_without_classifier(graph, features[perm], srcs, dsts) aug1_graph = self.do_augmentation(graph) aug1_srcs, aug1_dsts = aug1_graph.edges() aug1 = self.encoder.forward_without_classifier(aug1_graph, aug1_graph.ndata['feat'], aug1_srcs, aug1_dsts) aug2_graph = self.do_augmentation(graph) aug2_srcs, aug2_dsts = aug2_graph.edges() aug2 = self.encoder.forward_without_classifier(aug2_graph, aug2_graph.ndata['feat'], aug2_srcs, aug2_dsts) else: positive = self.encoder.forward_without_classifier(graph, features) perm = torch.randperm(graph.number_of_nodes()).to(features.device) negative = self.encoder.forward_without_classifier(graph, features[perm]) aug1_graph = self.do_augmentation(graph) aug1 = self.encoder.forward_without_classifier(aug1_graph, aug1_graph.ndata['feat']) aug2_graph = self.do_augmentation(graph) aug2 = self.encoder.forward_without_classifier(aug2_graph, aug2_graph.ndata['feat']) summary_aug1 = torch.sigmoid(aug1.mean(dim=0)) summary_aug2 = torch.sigmoid(aug2.mean(dim=0)) pos_logit1 = self.discriminator(positive, summary_aug1) neg_logit1 = self.discriminator(negative, summary_aug1) pos_logit2 = self.discriminator(positive, summary_aug2) neg_logit2 = self.discriminator(negative, summary_aug2) aug1_loss = self.loss_fn(pos_logit1, torch.ones_like(pos_logit1)) + self.loss_fn(neg_logit1, torch.zeros_like(neg_logit1)) aug2_loss = self.loss_fn(pos_logit2, torch.ones_like(pos_logit2)) + self.loss_fn(neg_logit2, torch.zeros_like(neg_logit2)) return aug1_loss + aug2_loss
[docs]class SubgraphCL(PretrainingMethod): r""" An implementation of GraphCL (subgraph augmentation) for graph-level problems. This code was implemented based on the official implementation by authors. For the details, see the `original paper <https://proceedings.neurips.cc/paper_files/paper/2020/file/3fe230348e9a12c13120749e3f9fa4cd-Paper.pdf>`_. Arguments: encoder (torch.nn.Module): Pytorch model for pretraining. """ def __init__(self, encoder): super().__init__(encoder) print("PRETRAINING_ALGO: GraphCL (Graph)") def do_augmentation(self, graph): device = graph.ndata['feat'].device srcs, dsts = graph.edges() n_nodes = graph.num_nodes() with torch.no_grad(): with graph.local_scope(): root = torch.randint(n_nodes, (1,)) init_feat = n_nodes * torch.ones(n_nodes).to(device) init_feat[root.item()] = 0. graph.ndata['hop'] = init_feat if (srcs == dsts).sum() == 0: graph = graph.add_self_loop() srcs, dsts = graph.edges() graph.edata['w'] = torch.ones(graph.num_edges()).to(device) graph.edata['w'][srcs == dsts] = 0. cnt = (graph.ndata['hop'] < (n_nodes - 0.5)).sum().item() hop_cnt = 0 while cnt <= n_nodes * 0.8: graph.update_all(fn.u_add_e('hop', 'w', 'm'), fn.min('m', 'hop')) new_cnt = (graph.ndata['hop'] < (n_nodes - 0.5)).sum().item() hop_cnt += 1 if cnt == new_cnt: graph.ndata['hop'][(torch.arange(n_nodes).to(device)[graph.ndata['hop'] > (n_nodes - 0.5)])[torch.randperm(n_nodes - cnt)[0]]] = hop_cnt new_cnt += 1 cnt = new_cnt target_nodes = torch.argsort(graph.ndata['hop'].long() * n_nodes + torch.randperm(n_nodes).to(device), dim=0)[int(n_nodes * 0.8):] new_graph = copy.deepcopy(graph) return dgl.remove_nodes(new_graph, target_nodes.to(torch.int32))
[docs] def inference(self, inputs): graphs = inputs aug_graphs = dgl.batch(list(map(self.do_augmentation, dgl.unbatch(graphs)))) _, original_outputs = self.encoder(graphs, graphs.ndata['feat'] if 'feat' in graphs.ndata else None, edge_attr = graphs.edata['feat'] if 'feat' in graphs.edata else None, edge_weight = graphs.edata['weight'] if 'weight' in graphs.edata else None, get_intermediate_outputs=True) _, aug_outputs = self.encoder(aug_graphs, aug_graphs.ndata['feat'] if 'feat' in aug_graphs.ndata else None, edge_attr = aug_graphs.edata['feat'] if 'feat' in aug_graphs.edata else None, edge_weight = aug_graphs.edata['weight'] if 'weight' in aug_graphs.edata else None, get_intermediate_outputs=True) # compute cl loss neg_score = torch.logsumexp(aug_outputs[-1] @ original_outputs[-1].t(), dim=-1).mean() pos_score = torch.sum(aug_outputs[-1] * original_outputs[-1], dim=-1).mean() loss = -pos_score + neg_score return loss
[docs]class LightGCL(PretrainingMethod): r""" An implementation of LightGCL for node-level and link-level problems. This code was implemented based on the official implementation by authors. Note that this method only supports bipartite graphs. For the details, see the `original paper <https://arxiv.org/pdf/2302.08191>`_. Arguments: encoder (torch.nn.Module): Pytorch model for pretraining. """
[docs] class PretrainIterator(PretrainingMethod.PretrainIterator): def __init__(self, inputs, batch_size, samples, device): super().__init__(inputs, device) self.samples = samples[torch.randperm(samples.shape[0])] self.batch_size = batch_size def __iter__(self): return self def __next__(self): if self.count * self.batch_size >= self.samples.shape[0]: raise StopIteration else: self.count += 1 return (self.inputs.to(self.device), self.samples[(self.count-1) * self.batch_size:self.count * self.batch_size].to(self.device))
def __init__(self, encoder, link_level=True, bipartite=False): super().__init__(encoder) print("PRETRAINING_ALGO: LightGCL (Bipartite Graph Only)") self.link_level = link_level self.bipartite = bipartite self.batch_size = 4096 if self.link_level: self.target_module = self.encoder.gcn else: self.target_module = self.encoder self.svd_u, self.svd_s, self.svd_v = None, None, None def exact_forward(self, graph, feats): final_h = 0. h = feats h = self.target_module.dropout(h) for i in range(self.target_module.n_layers): conv = self.target_module.convs[i](graph, h) h = conv h = self.target_module.norms[i](h) h = self.target_module.activation(h) final_h = final_h + h h = self.target_module.dropout(h) return final_h def approx_conv(self, conv, graph, feat): with graph.local_scope(): feat_src, feat_dst = expand_as_pair(feat, graph) degs = graph.out_degrees().to(feat_src).clamp(min=1) norm = torch.pow(degs, -0.5) shp = norm.shape + (1,) * (feat_src.dim() - 1) norm = torch.reshape(norm, shp) feat_src = feat_src * norm weight = conv.weight if conv._in_feats > conv._out_feats: if weight is not None: feat_src = torch.matmul(feat_src, weight) if self.bipartite: updated_feat = torch.cat((self.svd_u.to(feat_src.device) @ ((self.svd_v @ torch.diag(self.svd_s)).t().to(feat_src.device) @ feat_src[self.num_srcs:]), self.svd_v.to(feat_src.device) @ ((self.svd_u @ torch.diag(self.svd_s)).t().to(feat_src.device) @ feat_src[:self.num_srcs])), dim=0) else: updated_feat = self.svd_u.to(feat_src.device) @ ((self.svd_v @ torch.diag(self.svd_s)).t().to(feat_src.device) @ feat_src) rst = updated_feat + feat_src else: # aggregate first then mult W if self.bipartite: updated_feat = torch.cat((self.svd_u.to(feat_src.device) @ ((self.svd_v @ torch.diag(self.svd_s)).t().to(feat_src.device) @ feat_src[self.num_srcs:]), self.svd_v.to(feat_src.device) @ ((self.svd_u @ torch.diag(self.svd_s)).t().to(feat_src.device) @ feat_src[:self.num_srcs])), dim=0) else: updated_feat = self.svd_u.to(feat_src.device) @ ((self.svd_v @ torch.diag(self.svd_s)).t().to(feat_src.device) @ feat_src) rst = updated_feat + feat_src if weight is not None: rst = torch.matmul(rst, weight) degs = graph.in_degrees().to(feat_dst).clamp(min=1) norm = torch.pow(degs, -0.5) shp = norm.shape + (1,) * (feat_dst.dim() - 1) norm = torch.reshape(norm, shp) rst = rst * norm return rst def approx_forward(self, graph, feats): final_h = 0. h = feats h = self.target_module.dropout(h) for i in range(self.target_module.n_layers): conv = self.approx_conv(self.target_module.convs[i], graph, h) h = conv h = self.target_module.norms[i](h) h = self.target_module.activation(h) final_h = final_h + h h = self.target_module.dropout(h) return final_h
[docs] def iterator(self, inputs, device): graph, features = inputs, inputs.ndata['feat'] if self.svd_s is None: if self.link_level and self.bipartite: srcs, dsts = graph.edges() valid = srcs < dsts srcs, dsts = srcs[valid], dsts[valid] self.srcs = srcs self.dsts = dsts # self.srcs = srcs # self.dsts = dsts self.num_srcs = srcs.max() + 1 self.num_dsts = dsts.max() + 1 - self.num_srcs self.svd_u, self.svd_s, self.svd_v = torch.svd_lowrank(torch.sparse_coo_tensor([srcs.tolist(), (dsts - self.num_srcs).tolist()], torch.ones_like(srcs).float().tolist()), q=5) else: srcs, dsts = graph.edges() self.srcs = srcs self.dsts = dsts valid = srcs != dsts srcs, dsts = srcs[valid], dsts[valid] self.svd_u, self.svd_s, self.svd_v = torch.svd_lowrank(torch.sparse_coo_tensor([srcs.tolist(), dsts.tolist()], torch.ones_like(srcs).float().tolist(), (graph.num_nodes(), graph.num_nodes())), q=5) if self.bipartite and self.link_level: return self.PretrainIterator(inputs, self.batch_size, torch.stack((self.srcs, self.dsts), dim=-1), device) else: return self.PretrainIterator(inputs, self.batch_size, torch.arange(graph.num_nodes()).to(self.srcs.device), device)
[docs] def inference(self, inputs): graph, features, targets = inputs[0], inputs[0].ndata['feat'], inputs[1] exact_outs = self.exact_forward(graph, features) approx_outs = self.approx_forward(graph, features) if self.link_level and self.bipartite: target_srcs, target_dsts = targets[:, 0], targets[:, 1] neg_score = torch.logsumexp(approx_outs[target_srcs] @ exact_outs[:self.num_srcs].t(), dim=-1).mean() neg_score = neg_score + torch.logsumexp(approx_outs[target_dsts] @ exact_outs[self.num_srcs:].t(), dim=-1).mean() pos_score = torch.sum(approx_outs[target_srcs] * exact_outs[target_srcs], dim=-1).mean() pos_score = pos_score + torch.sum(approx_outs[target_dsts] * exact_outs[target_dsts], dim=-1).mean() loss = -pos_score + neg_score else: neg_score = torch.logsumexp(approx_outs[targets] @ exact_outs.t(), dim=-1).mean() pos_score = torch.sum(approx_outs[targets] * exact_outs[targets], dim=-1).mean() loss = -pos_score + neg_score return loss
[docs]class InfoGraph(PretrainingMethod): r""" An implementation of InfoGraph for graph-level problems. This code was implemented based on the official implementation by authors. For the details, see the `original paper <https://arxiv.org/pdf/1908.01000>`_. Arguments: encoder (torch.nn.Module): Pytorch model for pretraining. """
[docs] class FeedforwardNetwork(nn.Module): def __init__(self, in_dim, hid_dim, out_dim): super().__init__() self.block = nn.Sequential( nn.Linear(in_dim, hid_dim), nn.ReLU(), nn.Linear(hid_dim, hid_dim), nn.ReLU(), nn.Linear(hid_dim, out_dim) ) self.jump_con = nn.Linear(in_dim, out_dim) def forward(self, feat): block_out = self.block(feat) jump_out = self.jump_con(feat) out = block_out + jump_out return out
def get_positive_expectation(self, p_samples, average=True): log_2 = math.log(2.0) Ep = log_2 - F.softplus(-p_samples) if average: return Ep.mean() else: return Ep def get_negative_expectation(self, q_samples, average=True): log_2 = math.log(2.0) Eq = F.softplus(-q_samples) + q_samples - log_2 if average: return Eq.mean() else: return Eq def local_global_loss_(self, l_enc, g_enc, graph_id): num_graphs = g_enc.shape[0] num_nodes = l_enc.shape[0] device = g_enc.device pos_mask = torch.zeros((num_nodes, num_graphs)).to(device) neg_mask = torch.ones((num_nodes, num_graphs)).to(device) for nodeidx, graphidx in enumerate(graph_id.tolist()): pos_mask[nodeidx][graphidx] = 1.0 neg_mask[nodeidx][graphidx] = 0.0 res = torch.mm(l_enc, g_enc.t()) E_pos = self.get_positive_expectation(res * pos_mask, average=False).sum() E_pos = E_pos / num_nodes E_neg = self.get_negative_expectation(res * neg_mask, average=False).sum() E_neg = E_neg / (num_nodes * (num_graphs - 1)) return E_neg - E_pos def __init__(self, encoder): super().__init__(encoder) print("PRETRAINING_ALGO: InfoGraph") self.local_d = self.FeedforwardNetwork(encoder.n_hidden * encoder.n_layers, encoder.n_hidden, encoder.n_hidden // (1 << encoder.n_mlp_layers))
[docs] def inference(self, inputs): graphs = inputs _, intermediate_outputs = self.encoder(graphs, graphs.ndata['feat'] if 'feat' in graphs.ndata else None, edge_attr = graphs.edata['feat'] if 'feat' in graphs.edata else None, edge_weight = graphs.edata['weight'] if 'weight' in graphs.edata else None, get_intermediate_outputs=True) global_h = intermediate_outputs[-1] local_h = self.local_d(torch.cat(intermediate_outputs[:-1], dim=-1)) graph_id = torch.cat([(torch.ones(_num, dtype=torch.long) * i) for i, _num in enumerate(graphs.batch_num_nodes().tolist())], dim=-1) loss = self.local_global_loss_(local_h, global_h, graph_id) return loss