Skip to content

Update ANDMask Problem #13

@khan-yin

Description

@khan-yin

Hello, author. I would like to add ANDMask to the benchmark, but I ran into a problem when running it on the LSA64 dataset. Could you please check whether the ANDMask code below is correct and how to resolve the error on LSA64? The other datasets I tested did not produce this error.

Code

class ANDMask(ERM):
    """
    AND-Mask: Learning Explanations that are Hard to Vary
    [https://arxiv.org/abs/2009.00329]

    Implementation adapted from
    [https://github.com/gibipara92/learning-explanations-hard-to-vary]

    Gradients are averaged across training domains, but each parameter
    component is zeroed out unless enough domains agree on its sign
    (mean sign magnitude >= tau), so only directions that are consistent
    across environments drive the update.
    """

    def __init__(self, model, dataset, optimizer, hparams):
        """
        Args:
            model: trainable torch model (exposes .parameters()).
            dataset: dataset wrapper exposing get_next_batch(),
                get_nb_training_domains() and split_tensor_by_domains().
            optimizer: torch optimizer over model.parameters().
            hparams: dict of hyperparameters; must contain 'tau' in [0, 1].
        """
        super(ANDMask, self).__init__(model, dataset, optimizer, hparams)

        # Agreement threshold: fraction of domains whose gradient signs
        # must agree before a parameter component is allowed to update.
        self.tau = self.hparams['tau']

    def mask_grads(self, tau, gradients, params):
        """Write AND-masked, averaged gradients into each param.grad.

        Args:
            tau: agreement threshold in [0, 1].
            gradients: list (one entry per parameter) of lists of
                per-domain gradient tensors.
            params: iterable of model parameters aligned with `gradients`.
        """
        for param, grads in zip(params, gradients):
            # Shape: (n_domains, *param.shape)
            grads = torch.stack(grads, dim=0)
            grad_signs = torch.sign(grads)
            # Keep a component only when the mean sign across domains is
            # strong enough, i.e. enough domains point the same way.
            # BUG FIX: use the `tau` argument; the original read self.tau
            # and silently ignored the parameter.
            mask = torch.mean(grad_signs, dim=0).abs() >= tau
            mask = mask.to(torch.float32)
            avg_grad = torch.mean(grads, dim=0)

            # Rescale so surviving components keep comparable overall
            # magnitude regardless of how many components were masked.
            mask_t = (mask.sum() / mask.numel())
            param.grad = mask * avg_grad
            param.grad *= (1. / (1e-10 + mask_t))

    def update(self):
        """One training step: per-domain losses -> AND-masked gradient step."""
        X, Y = self.dataset.get_next_batch()

        out, out_features = self.predict(X)
        n_domains = self.dataset.get_nb_training_domains()
        out, labels = self.dataset.split_tensor_by_domains(out, Y, n_domains)

        # Compute loss for each environment (domain), summed over time steps.
        env_losses = torch.zeros(out.shape[0]).to(self.device)
        for i in range(out.shape[0]):
            for t_idx in range(out.shape[2]):     # Number of time steps
                env_losses[i] += F.cross_entropy(out[i, :, t_idx, :], labels[i, :, t_idx])

        # Materialize the parameter list once so gradients stay aligned
        # with parameters across the loops below.
        params = list(self.model.parameters())

        # Compute gradients for each environment.
        param_gradients = [[] for _ in params]
        for env_loss in env_losses:
            # BUG FIX: allow_unused=True. On some datasets (e.g. LSA64) a
            # subset of parameters may not appear in the graph of a given
            # environment's loss; without this flag autograd.grad raises
            # "One of the differentiated Tensors appears to not have been
            # used in the graph". Unused parameters get zero gradients so
            # the sign-agreement mask still sees one entry per domain.
            env_grads = autograd.grad(env_loss, params,
                                      retain_graph=True, allow_unused=True)
            for param, grads, env_grad in zip(params, param_gradients, env_grads):
                if env_grad is None:
                    env_grad = torch.zeros_like(param)
                grads.append(env_grad)

        # Back propagate with the AND-masked gradients.
        self.optimizer.zero_grad()
        self.mask_grads(self.tau, param_gradients, params)
        self.optimizer.step()

Error for LSA64

image

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions