Source code for nupic.torch.compatibility
# ----------------------------------------------------------------------
# Numenta Platform for Intelligent Computing (NuPIC)
# Copyright (C) 2020, Numenta, Inc. Unless you have an agreement
# with Numenta, Inc., for a separate license for this software code, the
# following terms and conditions apply:
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero Public License version 3 as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU Affero Public License for more details.
#
# You should have received a copy of the GNU Affero Public License
# along with this program. If not, see http://www.gnu.org/licenses.
#
# http://numenta.org/licenses/
# ----------------------------------------------------------------------
from collections import OrderedDict
import torch
[docs]def upgrade_to_masked_sparseweights(state_dict):
"""
Returns a new state dict with any "zero_weights" tensors converted to
"zero_mask" tensors. (The "zero_weights" was a list of indices of zeroes in
the weight tensor.)
"""
upgraded = []
for name, tensor in state_dict.items():
if "zero_weights" in name:
weight_name = name.replace("zero_weights", "module.weight")
zero_mask = torch.zeros(state_dict[weight_name].shape,
device=tensor.device)
if tensor.shape[0] == 2:
# Assume this is the standard previous format of SparseWeights
# and SparseWeights2d
zero_mask.view(zero_mask.shape[0], -1)[tuple(tensor)] = 1
else:
# Assume the tensor is a valid index list for the weight shape
zero_mask[tuple(tensor)] = 1
upgraded.append((name.replace("zero_weights", "zero_mask"),
zero_mask))
else:
upgraded.append((name, tensor))
return OrderedDict(upgraded)