HAT
HAT (Hard Attention to the Task) is a parameter-isolation-based continual learning method that learns real-valued masks, which play a role similar to attention modules, to weight each layer's output for each task. For details, see the original paper.
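As a rough illustration of the masking idea, the sketch below follows the formulation in the HAT paper: a layer's mask for a task is a = σ(s·e), where e is a learnable task embedding and s is a positive scale, and a large s pushes the mask toward binary values. The helper names (`stable_sigmoid`, `hat_mask`, `apply_mask`) are hypothetical, and plain lists stand in for tensors; this is not BeGin's implementation.

```python
import math
from typing import List


def stable_sigmoid(x: float) -> float:
    """Numerically stable logistic sigmoid (avoids overflow in math.exp)."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def hat_mask(embedding: List[float], s: float) -> List[float]:
    """Mask for one layer and one task: a_i = sigmoid(s * e_i)."""
    return [stable_sigmoid(s * e) for e in embedding]


def apply_mask(layer_output: List[float], mask: List[float]) -> List[float]:
    """Weight each unit of the layer output by its mask value (element-wise)."""
    return [h * a for h, a in zip(layer_output, mask)]


if __name__ == "__main__":
    embedding = [4.0, -4.0, 0.0]              # hypothetical task embedding, 3 units
    mask = hat_mask(embedding, s=400.0)        # large s: nearly binary mask
    print(mask)                                # [1.0, 0.0, 0.5]
    print(apply_mask([1.0, 2.0, 3.0], mask))   # [1.0, 0.0, 1.5]
```

Note that a unit whose embedding is exactly zero always gets a mask of 0.5, regardless of s.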
Node-level Problems
- class NCTaskILHATTrainer(model, scenario, optimizer_fn, loss_fn, device, **kwargs)
- afterInference(results, model, optimizer, _curr_batch, training_states)
The event function to execute some processes right after the inference step (for training). We recommend performing backpropagation in this event function.
For HAT, this event function must also apply the training techniques introduced in the original paper, such as conditioning the gradients on the masks of previous tasks.
- Parameters
results (dict) – the dictionary returned by the inference event function.
model (torch.nn.Module) – the current trained model.
optimizer (torch.optim.Optimizer) – the current optimizer.
_curr_batch (object) – the data (or minibatch) for the current iteration.
training_states (dict) – the dictionary containing the current training states.
use_mask (bool) – whether the model masks its weights.
- Returns
A dictionary containing the information from the results.
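One of the techniques from the HAT paper that fits this step is gradient conditioning: before the optimizer step, the gradient of the weight connecting input unit j to output unit i is scaled by 1 - min(a_i, a_j), where a are the cumulative masks of previous tasks, so weights that earlier tasks rely on barely change. The sketch below assumes plain nested lists for a single layer's weight gradient; `condition_gradients` is a hypothetical name, not part of BeGin.

```python
from typing import List

Matrix = List[List[float]]


def condition_gradients(grad: Matrix, prev_out_mask: List[float], prev_in_mask: List[float]) -> Matrix:
    """Scale grad[i][j] by 1 - min(prev_out_mask[i], prev_in_mask[j]).

    grad[i][j] is the gradient of the weight from input unit j to output unit i.
    prev_out_mask holds the cumulative previous-task mask of this layer's outputs;
    prev_in_mask holds the one of the previous layer's outputs (this layer's inputs).
    """
    return [
        [g * (1.0 - min(prev_out_mask[i], prev_in_mask[j])) for j, g in enumerate(row)]
        for i, row in enumerate(grad)
    ]


grad = [[1.0, 1.0], [1.0, 1.0]]
conditioned = condition_gradients(grad, prev_out_mask=[1.0, 0.0], prev_in_mask=[1.0, 0.5])
print(conditioned)  # [[0.0, 0.5], [1.0, 1.0]]
```

A weight whose both endpoints were fully used by earlier tasks receives a zero gradient, while weights attached to unused output units are updated freely.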
- inference(model, _curr_batch, training_states)
The event function to execute inference step.
For task-IL, the inference step must additionally take the task information into account. For HAT, it must also apply the task-specific (attention) masks.
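- Parameters
model (torch.nn.Module) – the current trained model.
_curr_batch (object) – the data (or minibatch) for the current iteration.
training_states (dict) – the dictionary containing the current training states.
- Returns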
A dictionary containing the inference results, such as prediction result and loss.
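In task-IL the task identity is available at inference time, and one common way to use it is to restrict the prediction to the classes that belong to that task. The sketch below shows only this restriction; `task_il_predict` and its `task_classes` argument are hypothetical and do not describe how BeGin encodes task information.

```python
from typing import Sequence


def task_il_predict(logits: Sequence[float], task_classes: Sequence[int]) -> int:
    """Return the class with the highest logit among the given task's classes."""
    return max(task_classes, key=lambda c: logits[c])


logits = [0.1, 2.0, 0.5, 3.0]
print(task_il_predict(logits, task_classes=[0, 2]))  # 2 (class 3 is outside the task)
print(task_il_predict(logits, task_classes=[1, 3]))  # 3
```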
- initTrainingStates(scenario, model, optimizer)
The event function to initialize the dictionary for storing training states (i.e., intermediate results).
- Parameters
scenario (begin.scenarios.common.BaseScenarioLoader) – the ScenarioLoader given to the trainer.
model (torch.nn.Module) – the model given to the trainer.
optimizer (torch.optim.Optimizer) – the optimizer generated from the given optimizer_fn.
- Returns
Initialized training state (dict).
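For HAT, a natural piece of training state is the cumulative mask of every layer. In the HAT paper the cumulative mask starts at zero before the first task, so no unit is protected yet. The sketch below is hypothetical: the key names, the `layer_sizes` argument, and the default s_max = 400 (the value used in the HAT paper) are assumptions, not the contents of BeGin's dictionary.

```python
from typing import Dict, List


def init_hat_training_states(layer_sizes: List[int], s_max: float = 400.0) -> Dict[str, object]:
    """Hypothetical HAT training-state dictionary.

    cumulative_masks[l][i] records how much unit i of layer l is used by the
    tasks learned so far; all zeros means nothing is protected yet.
    """
    return {
        "s_max": s_max,
        "cumulative_masks": [[0.0] * n for n in layer_sizes],
    }


states = init_hat_training_states([3, 2])
print(states["cumulative_masks"])  # [[0.0, 0.0, 0.0], [0.0, 0.0]]
```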
- processAfterTraining(task_id, curr_dataset, curr_model, curr_optimizer, curr_training_states)
The event function to execute some processes after training the current task.
In this event function, our implementation updates the masks used by HAT.
- Parameters
task_id (int) – the index of the current task.
curr_dataset (object) – the dataset for the current task.
curr_model (torch.nn.Module) – the current trained model.
curr_optimizer (torch.optim.Optimizer) – the current optimizer.
curr_training_states (dict) – the dictionary containing the current training states.
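In the HAT paper, once task t is trained, its mask is computed with the maximal scale s_max and merged into the cumulative mask by an element-wise maximum, a^{≤t} = max(a^t, a^{≤t-1}), so units used by any earlier task stay protected. The sketch below is a minimal rendering of that rule, not BeGin's code; `update_cumulative_mask` and `stable_sigmoid` are hypothetical names, and one layer is represented by plain lists.

```python
import math
from typing import List


def stable_sigmoid(x: float) -> float:
    """Numerically stable logistic sigmoid."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def update_cumulative_mask(cumulative: List[float], embedding: List[float], s_max: float) -> List[float]:
    """Compute the finished task's mask at s = s_max and merge it by element-wise maximum."""
    task_mask = [stable_sigmoid(s_max * e) for e in embedding]
    return [max(c, a) for c, a in zip(cumulative, task_mask)]


cumulative = [0.0, 1.0, 0.0]  # unit 1 was used by an earlier task
cumulative = update_cumulative_mask(cumulative, embedding=[3.0, -3.0, -3.0], s_max=400.0)
print(cumulative)  # [1.0, 1.0, 0.0]
```

Unit 1 stays protected even though the new task does not use it, because the maximum never lowers a cumulative value.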
- processBeforeTraining(task_id, curr_dataset, curr_model, curr_optimizer, curr_training_states)
The event function to execute some processes before training the current task.
- Parameters
task_id (int) – the index of the current task.
curr_dataset (object) – the dataset for the current task.
curr_model (torch.nn.Module) – the current trained model.
curr_optimizer (torch.optim.Optimizer) – the current optimizer.
curr_training_states (dict) – the dictionary containing the current training states.
- processEvalIteration(model, _curr_batch)
The event function to handle every evaluation iteration.
We need to extend the base function because the output format differs slightly from that of the base trainer.
- processTrainIteration(model, optimizer, _curr_batch, training_states)
The event function to handle every training iteration.
- Parameters
model (torch.nn.Module) – the current trained model.
optimizer (torch.optim.Optimizer) – the current optimizer.
_curr_batch (object) – the data (or minibatch) for the current iteration.
training_states (dict) – the dictionary containing the current training states.
use_mask (bool) – whether the model masks its weights.
- Returns
A dictionary containing the outcomes (stats) during the training iteration.
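The HAT paper anneals the scale s within each epoch so that masks are soft early on and approach binary values by the end: s = 1/s_max + (s_max - 1/s_max)(b - 1)/(B - 1) for batch index b = 1, ..., B. A per-iteration handler is one natural place to compute it. The sketch below implements only that schedule; `annealed_scale` is a hypothetical name, and returning s_max when there is a single batch is our own assumption, since the formula is undefined for B = 1.

```python
def annealed_scale(batch_idx: int, num_batches: int, s_max: float = 400.0) -> float:
    """HAT annealing: s grows linearly from 1/s_max (first batch) to s_max (last batch).

    batch_idx is 1-based, matching b = 1, ..., B in the HAT paper.
    """
    if not 1 <= batch_idx <= num_batches:
        raise ValueError("batch_idx must be in [1, num_batches]")
    if num_batches == 1:
        return s_max  # assumption: with a single batch, use the maximal scale
    return 1.0 / s_max + (s_max - 1.0 / s_max) * (batch_idx - 1) / (num_batches - 1)


for b in range(1, 6):
    print(b, annealed_scale(b, num_batches=5))
```

A linear schedule keeps the first batches' masks close to 0.5 everywhere (s is tiny), which lets gradients reach all embeddings before the masks harden.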