Source code for nupic.torch.modules.sparse_weights

# ----------------------------------------------------------------------
# Numenta Platform for Intelligent Computing (NuPIC)
# Copyright (C) 2019, Numenta, Inc.  Unless you have an agreement
# with Numenta, Inc., for a separate license for this software code, the
# following terms and conditions apply:
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero Public License version 3 as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU Affero Public License for more details.
#
# You should have received a copy of the GNU Affero Public License
# along with this program.  If not, see http://www.gnu.org/licenses.
#
# http://numenta.org/licenses/
# ----------------------------------------------------------------------
import abc
import math
import warnings

import numpy as np
import torch
import torch.nn as nn


def rezero_weights(m):
    """Function used to update the weights after each epoch.

    Call using :meth:`torch.nn.Module.apply` after each epoch if required.
    For example: ``m.apply(rezero_weights)``

    :param m: HasRezeroWeights module
    """
    if isinstance(m, HasRezeroWeights):
        m.rezero_weights()
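# A minimal usage sketch (not part of the original module), assuming a model
# that contains SparseWeights layers (defined below) plus placeholder `loader`,
# `loss_fn`, and `optimizer` objects. It shows the pattern described in the
# docstring above: re-apply the zero mask via ``model.apply(rezero_weights)``
# after the weights have been updated.
def _example_rezero_loop(model, loader, loss_fn, optimizer, num_epochs=1):
    for _ in range(num_epochs):
        for x, y in loader:
            optimizer.zero_grad()
            loss_fn(model(x), y).backward()
            optimizer.step()
        # Restore the zeroed weights once the epoch's updates are done
        model.apply(rezero_weights)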
def normalize_sparse_weights(m):
    """Initialize the weights using kaiming_uniform initialization normalized
    to the number of non-zeros in the layer instead of the whole input size.

    Similar to torch.nn.Linear.reset_parameters() but applying weight
    sparsity to the input size.
    """
    if isinstance(m, SparseWeightsBase):
        _, input_size = m.module.weight.shape
        fan = int(input_size * (1.0 - m.sparsity))
        gain = nn.init.calculate_gain("leaky_relu", math.sqrt(5))
        std = gain / math.sqrt(fan)
        bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
        nn.init.uniform_(m.module.weight, -bound, bound)
        if m.module.bias is not None:
            bound = 1 / math.sqrt(fan)
            nn.init.uniform_(m.module.bias, -bound, bound)
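# A minimal usage sketch (not part of the original module): re-initialize the
# sparse layers once, right after the model is built, so the init statistics
# reflect the non-zero fan-in rather than the full input size. The layer sizes
# and sparsity values are arbitrary placeholders.
def _example_normalized_init():
    model = nn.Sequential(
        SparseWeights(nn.Linear(784, 128), sparsity=0.5),
        nn.ReLU(),
        SparseWeights(nn.Linear(128, 10), sparsity=0.5),
    )
    model.apply(normalize_sparse_weights)
    # Re-apply the masks, since re-initialization overwrites the zeroed entries
    model.apply(rezero_weights)
    return model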
class HasRezeroWeights(metaclass=abc.ABCMeta):
    @abc.abstractmethod
    def rezero_weights(self):
        """Set the previously selected weights to zero."""
        raise NotImplementedError
class SparseWeightsBase(nn.Module, HasRezeroWeights):
    """
    Base class for all Sparse Weights modules.

    :param module: The module whose weights will be sparsified
    :param weight_sparsity: Pct of weights that are NON-ZERO in the layer.
        Also equal to 1-sparsity.
        **Please note this is the first positional parameter for backwards
        compatibility**
    :param sparsity: Pct of weights that are ZERO in the layer.
        Accepts either sparsity or weight_sparsity, but not both.
    """

    def __init__(self, module, weight_sparsity=None, sparsity=None):
        super(SparseWeightsBase, self).__init__()
        assert weight_sparsity is not None or sparsity is not None
        assert not (weight_sparsity and sparsity), (
            "Both `weight_sparsity` and `sparsity` were passed as arguments, "
            "but only one of those two can be defined at a time."
        )
        if weight_sparsity is not None and sparsity is None:
            sparsity = 1.0 - weight_sparsity
            warnings.warn(
                "Parameter `weight_sparsity` is deprecated. Use `sparsity` instead.",
                DeprecationWarning,
            )

        self.module = module
        self.sparsity = sparsity
    def extra_repr(self):
        return "sparsity={}".format(self.sparsity)

    def forward(self, x):
        return self.module(x)

    @property
    def weight_sparsity(self):
        warnings.warn(
            "Parameter `weight_sparsity` is deprecated. Use `sparsity` instead.",
            DeprecationWarning,
        )
        return 1.0 - self.sparsity

    @property
    def weight(self):
        return self.module.weight

    @property
    def bias(self):
        return self.module.bias
class SparseWeights(SparseWeightsBase):
    """Enforce weight sparsity on a linear module during training.

    Sample usage:

      model = nn.Linear(784, 10)
      model = SparseWeights(model, sparsity=0.4)

    :param module: The module whose weights will be sparsified
    :param weight_sparsity: Pct of weights that are NON-ZERO in the layer.
        Also equal to 1-sparsity.
        **Please note this is the first positional parameter for backwards
        compatibility**
    :param sparsity: Pct of weights that are ZERO in the layer.
        Accepts either sparsity or weight_sparsity, but not both.
    :param allow_extremes: Allow the values sparsity=0 and sparsity=1. These
        values are often a sign that there is a bug in the configuration,
        because they lead to Identity and Zero layers, respectively, but they
        can make sense in scenarios where the mask is dynamic.
    """

    def __init__(self, module, weight_sparsity=None, sparsity=None,
                 allow_extremes=False):
        assert len(module.weight.shape) == 2, "Should resemble a nn.Linear"
        super(SparseWeights, self).__init__(
            module, weight_sparsity=weight_sparsity, sparsity=sparsity
        )

        if allow_extremes:
            assert 0 <= self.sparsity <= 1
        else:
            assert 0 < self.sparsity < 1

        # For each unit, decide which weights are going to be zero
        in_features = self.module.in_features
        out_features = self.module.out_features
        num_nz = int(round((1 - self.sparsity) * in_features))
        zero_mask = torch.ones(out_features, in_features, dtype=torch.bool)
        for out_feature in range(out_features):
            in_indices = np.random.choice(in_features, num_nz, replace=False)
            zero_mask[out_feature, in_indices] = False
        # Use float16 because pytorch distributed nccl doesn't support bools
        self.register_buffer("zero_mask", zero_mask.half())

        self.rezero_weights()
    def rezero_weights(self):
        self.module.weight.data.masked_fill_(self.zero_mask.bool(), 0)
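# A minimal sanity-check sketch (not part of the original module): wrap an
# nn.Linear in SparseWeights and confirm that, after rezero_weights(), each
# output unit keeps at most round((1 - sparsity) * in_features) non-zero
# weights. The sizes and sparsity value below are arbitrary placeholders.
def _example_sparse_linear():
    layer = SparseWeights(nn.Linear(784, 10), sparsity=0.75)
    layer.rezero_weights()
    # Each row should have at most round(0.25 * 784) = 196 non-zero entries
    nonzero_per_unit = (layer.weight != 0).sum(dim=1)
    return nonzero_per_unit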
class SparseWeights2d(SparseWeightsBase):
    """Enforce weight sparsity on CNN modules.

    Sample usage:

      model = nn.Conv2d(in_channels, out_channels, kernel_size, ...)
      model = SparseWeights2d(model, sparsity=0.4)

    :param module: The module whose weights will be sparsified
    :param weight_sparsity: Pct of weights that are NON-ZERO in the layer.
        Also equal to 1-sparsity.
        **Please note this is the first positional parameter for backwards
        compatibility**
    :param sparsity: Pct of weights that are ZERO in the layer.
        Accepts either sparsity or weight_sparsity, but not both.
    :param allow_extremes: Allow the values sparsity=0 and sparsity=1. These
        values are often a sign that there is a bug in the configuration,
        because they lead to Identity and Zero layers, respectively, but they
        can make sense in scenarios where the mask is dynamic.
    """

    def __init__(self, module, weight_sparsity=None, sparsity=None,
                 allow_extremes=False):
        assert len(module.weight.shape) == 4, "Should resemble a nn.Conv2d"
        super(SparseWeights2d, self).__init__(
            module, weight_sparsity=weight_sparsity, sparsity=sparsity
        )

        if allow_extremes:
            assert 0 <= self.sparsity <= 1
        else:
            assert 0 < self.sparsity < 1

        # For each unit, decide which weights are going to be zero
        in_channels = self.module.in_channels
        out_channels = self.module.out_channels
        kernel_size = self.module.kernel_size
        input_size = in_channels * kernel_size[0] * kernel_size[1]
        num_nz = int(round((1 - self.sparsity) * input_size))
        zero_mask = torch.ones(out_channels, input_size, dtype=torch.bool)
        for out_channel in range(out_channels):
            in_indices = np.random.choice(input_size, num_nz, replace=False)
            zero_mask[out_channel, in_indices] = False
        zero_mask = zero_mask.view(out_channels, in_channels, *kernel_size)
        # Use float16 because pytorch distributed nccl doesn't support bools
        self.register_buffer("zero_mask", zero_mask.half())

        self.rezero_weights()
    def rezero_weights(self):
        self.module.weight.data.masked_fill_(self.zero_mask.bool(), 0)
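# A minimal usage sketch (not part of the original module): SparseWeights2d
# wraps a standard nn.Conv2d, so it can be dropped in wherever the convolution
# would normally go. Channel counts, kernel size, input shape, and sparsity
# are arbitrary placeholders.
def _example_sparse_conv():
    conv = SparseWeights2d(nn.Conv2d(1, 32, kernel_size=5), sparsity=0.4)
    x = torch.randn(8, 1, 28, 28)
    y = conv(x)            # forward() simply delegates to the wrapped Conv2d
    conv.rezero_weights()  # re-apply the mask, e.g. after an optimizer step
    return y.shape         # torch.Size([8, 32, 24, 24])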