# Source code for nupic.torch.modules.prunable_sparse_weights
# ----------------------------------------------------------------------
# Numenta Platform for Intelligent Computing (NuPIC)
# Copyright (C) 2020, Numenta, Inc. Unless you have an agreement
# with Numenta, Inc., for a separate license for this software code, the
# following terms and conditions apply:
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero Public License version 3 as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU Affero Public License for more details.
#
# You should have received a copy of the GNU Affero Public License
# along with this program. If not, see http://www.gnu.org/licenses.
#
# http://numenta.org/licenses/
# ----------------------------------------------------------------------
from .sparse_weights import SparseWeights, SparseWeights2d
class PrunableSparseWeightBase(object):
    """
    Enable easy setting and getting of the off-mask that defines which
    weights are zero.

    This is a mixin: it reads and writes ``self.zero_mask`` (a tensor whose
    truthy entries mark weights that are off) and keeps ``self.sparsity``
    in sync with it whenever the mask is replaced.
    """

    @property
    def off_mask(self):
        """
        Gets the value of `zero_mask` in bool format. Thus one may call
        ```
        self.weight[~self.off_mask]  # returns weights that are currently on
        ```
        """
        return self.zero_mask.bool()

    @off_mask.setter
    def off_mask(self, mask):
        """
        Sets the values of `zero_mask`, updating self.sparsity to reflect the
        sparsity of the new mask.

        :param mask: tensor that can be assigned into ``zero_mask[:]`` (same
            shape or broadcastable to it); truthy entries mark weights to
            switch off.

        If the copy into ``zero_mask`` fails (e.g. incompatible shape), the
        exception propagates and ``sparsity`` is left unchanged.
        """
        # Copy first so a failing assignment cannot leave ``sparsity``
        # describing a mask that was never stored.
        self.zero_mask[:] = mask
        # Count from the stored mask viewed as bool rather than from the raw
        # argument, so non-boolean inputs (e.g. 2.0) and broadcast masks are
        # counted exactly as they ended up in ``zero_mask``.
        stored = self.zero_mask.bool()
        self.sparsity = stored.sum().item() / stored.numel()
class PrunableSparseWeights(SparseWeights, PrunableSparseWeightBase):
    """
    Sparse-weight wrapper for a linear module whose zeroed weights can be
    reassigned after construction.

    Read or assign the ``off_mask`` property to inspect or replace the set
    of weights that are held at zero.
    """

    def __init__(self, module, weight_sparsity=None, sparsity=None):
        # allow_extremes=True is always forwarded; its exact meaning lives in
        # SparseWeights (presumably it permits sparsity of 0 or 1 - confirm).
        super().__init__(
            module,
            allow_extremes=True,
            sparsity=sparsity,
            weight_sparsity=weight_sparsity,
        )
class PrunableSparseWeights2d(SparseWeights2d, PrunableSparseWeightBase):
    """
    Sparse-weight wrapper for a CNN module whose zeroed weights can be
    reassigned after construction.

    Read or assign the ``off_mask`` property to inspect or replace the set
    of weights that are held at zero.
    """

    def __init__(self, module, weight_sparsity=None, sparsity=None):
        # allow_extremes=True is always forwarded; its exact meaning lives in
        # SparseWeights2d (presumably it permits sparsity of 0 or 1 - confirm).
        super().__init__(
            module,
            allow_extremes=True,
            sparsity=sparsity,
            weight_sparsity=weight_sparsity,
        )