Pretraining

Our framework BeGin supports pretraining and provides pre-implemented self-supervised learning (SSL) methods.

class DGI(*args: Any, **kwargs: Any)

An implementation of Deep Graph Infomax (DGI) for node-level and link-level problems. This code is based on the authors' official implementation. For details, see the original paper.

Parameters

encoder (torch.nn.Module) – PyTorch model for pretraining.

class Discriminator(*args: Any, **kwargs: Any)

inference(inputs)

Compute the pretraining loss for the given input sample. Implementing this function is mandatory for the pretraining procedure to run.

Parameters

inputs (object) – the input sample drawn from the iterator.

Returns

A scalar representing the pretraining loss.

class DGISubgraphCL(*args: Any, **kwargs: Any)

An implementation of GraphCL (combined with DGI, using subgraph augmentation) for node-level and link-level problems. This code is based on the authors' official implementation. For details, see the original paper.

Parameters

encoder (torch.nn.Module) – PyTorch model for pretraining.

class Discriminator(*args: Any, **kwargs: Any)

inference(inputs)

Compute the pretraining loss for the given input sample. Implementing this function is mandatory for the pretraining procedure to run.

Parameters

inputs (object) – the input sample drawn from the iterator.

Returns

A scalar representing the pretraining loss.

class InfoGraph(*args: Any, **kwargs: Any)

An implementation of InfoGraph for graph-level problems. This code is based on the authors' official implementation. For details, see the original paper.

Parameters

encoder (torch.nn.Module) – PyTorch model for pretraining.

class FeedforwardNetwork(*args: Any, **kwargs: Any)

inference(inputs)

Compute the pretraining loss for the given input sample. Implementing this function is mandatory for the pretraining procedure to run.

Parameters

inputs (object) – the input sample drawn from the iterator.

Returns

A scalar representing the pretraining loss.

class LightGCL(*args: Any, **kwargs: Any)

An implementation of LightGCL for node-level and link-level problems. This code is based on the authors' official implementation. Note that this method supports only bipartite graphs. For details, see the original paper.

Parameters

encoder (torch.nn.Module) – PyTorch model for pretraining.

class PretrainIterator(inputs, batch_size, samples, device)

inference(inputs)

Compute the pretraining loss for the given input sample. Implementing this function is mandatory for the pretraining procedure to run.

Parameters

inputs (object) – the input sample drawn from the iterator.

Returns

A scalar representing the pretraining loss.

iterator(inputs, device)

Return an iterator over the given input dataset.

Parameters
  • inputs (object) – the input graph dataset.

  • device (str) – target GPU device.

Returns

An iterator for one pretraining epoch.

class PretrainingMethod(*args: Any, **kwargs: Any)

Base framework for implementing pretraining methods.

Parameters

encoder (torch.nn.Module) – PyTorch model for pretraining.

class PretrainIterator(inputs, device)

Base iterator for pretraining iterations. This class assumes full-batch training.

inference(inputs)

Compute the pretraining loss for the given input sample. Implementing this function is mandatory for the pretraining procedure to run.

Parameters

inputs (object) – the input sample drawn from the iterator.

Returns

A scalar representing the pretraining loss.
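
As a rough illustration of the contract described above (one sample in, one scalar loss out), here is a minimal sketch in plain Python. ToyPretrainingMethod, its callable encoder, and the mean-squared toy objective are all hypothetical stand-ins and not BeGin code; an actual subclass would presumably run its torch.nn.Module encoder and return a tensor-valued loss.

```python
class ToyPretrainingMethod:
    """Hypothetical stand-in for a PretrainingMethod subclass (not BeGin code)."""

    def __init__(self, encoder):
        # `encoder` plays the role of the torch.nn.Module; here it is any callable.
        self.encoder = encoder

    def inference(self, inputs):
        # Map one sample drawn from the iterator to a single scalar loss.
        # Toy objective: mean squared difference between encoded and raw values.
        encoded = [self.encoder(x) for x in inputs]
        return sum((e - x) ** 2 for e, x in zip(encoded, inputs)) / len(inputs)


method = ToyPretrainingMethod(encoder=lambda x: 0.5 * x)
loss = method.inference([2.0, 4.0])  # one scalar per call
```

The point of the sketch is only the shape of the interface: whatever the self-supervised objective is, inference reduces a sample to one number that a training loop can minimize.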

iterator(inputs, device)

Return an iterator over the given input dataset.

Parameters
  • inputs (object) – the input graph dataset.

  • device (str) – target GPU device.

Returns

An iterator for one pretraining epoch.
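
Since the base iterator assumes full-batch training, one plausible reading is that each epoch yields the whole dataset exactly once. The sketch below illustrates that idea in plain Python. FullBatchIterator and the module-level iterator function are hypothetical names written for this example, and the string "graphs" stand in for real graph objects; BeGin's own implementation may differ.

```python
class FullBatchIterator:
    """Hypothetical full-batch iterator: one step per epoch (not BeGin code)."""

    def __init__(self, inputs, device):
        self.inputs = inputs
        self.device = device  # kept only to mirror the documented signature

    def __iter__(self):
        # Full-batch training: the whole dataset is the single batch of the epoch.
        yield self.inputs


def iterator(inputs, device):
    # Mirrors the documented signature: build a fresh iterator for one epoch.
    return FullBatchIterator(inputs, device)


batches = list(iterator(["graph_a", "graph_b"], device="cpu"))
```

A mini-batch variant would yield several smaller chunks per epoch instead of a single one; the consumer loop stays the same either way.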

processAfterTraining(original_model)

This function is called once when the trainer concludes pretraining. The default implementation initializes the model from the saved best checkpoint (specifically, self.best_checkpoint) before the main training begins.

update()

This function is called when the best checkpoint needs to be updated. The default implementation stores the current state_dict of the model in self.best_checkpoint.
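
The interplay between update() and processAfterTraining() described above can be sketched as follows. This is a toy illustration in plain Python, not BeGin's implementation: ToyModel and ToyCheckpointing are hypothetical, and the use of copy.deepcopy to snapshot the state is an assumption made so that later training steps cannot alter the saved checkpoint.

```python
import copy


class ToyModel:
    """Hypothetical model exposing state_dict/load_state_dict like a torch module."""

    def __init__(self):
        self.weights = {"w": 0.0}

    def state_dict(self):
        return self.weights

    def load_state_dict(self, state):
        self.weights = copy.deepcopy(state)


class ToyCheckpointing:
    """Hypothetical stand-in for the default checkpoint hooks (not BeGin code)."""

    def __init__(self, encoder):
        self.encoder = encoder
        self.best_checkpoint = None

    def update(self):
        # Snapshot the current parameters; deepcopy (an assumption) keeps the
        # saved copy isolated from later in-place changes.
        self.best_checkpoint = copy.deepcopy(self.encoder.state_dict())

    def processAfterTraining(self, original_model):
        # Restore the best snapshot into the model used for the main training.
        original_model.load_state_dict(self.best_checkpoint)


encoder = ToyModel()
method = ToyCheckpointing(encoder)
encoder.weights["w"] = 1.0  # pretend this epoch gave the best pretraining loss
method.update()
encoder.weights["w"] = 5.0  # later, worse epochs keep changing the weights
target = ToyModel()
method.processAfterTraining(target)  # target starts from the best snapshot
```

In this sketch the main-training model ends up with the parameters from the best epoch rather than the last one, which is the behavior the two hooks are documented to provide by default.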

class SubgraphCL(*args: Any, **kwargs: Any)

An implementation of GraphCL (with subgraph augmentation) for graph-level problems. This code is based on the authors' official implementation. For details, see the original paper.

Parameters

encoder (torch.nn.Module) – PyTorch model for pretraining.

inference(inputs)

Compute the pretraining loss for the given input sample. Implementing this function is mandatory for the pretraining procedure to run.

Parameters

inputs (object) – the input sample drawn from the iterator.

Returns

A scalar representing the pretraining loss.