from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from constrainedmf.nmf.metrics import Beta_divergence


def _mu_update(param, pos, gamma, l1_reg, l2_reg):
"""
    Perform a multiplicative update of param (W or H) in place
Parameters
----------
param: tensor
Weights or components
pos: tensor
positive denominator (mu)
    gamma: float
        Exponent for the multiplicative factor, derived from beta in the
        beta-divergence loss (1 when 1 <= beta <= 2)
l1_reg: float
L1 regularization
l2_reg: float
L2 regularization
    Returns
    -------
    None
        ``param`` is modified in place.
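
    Notes
    -----
    This implements the standard multiplicative-update step: with
    ``neg = relu(pos - grad)`` as the numerator, the update is
    ``param <- param * (neg / pos) ** gamma``, with optional L1/L2 terms
    added to the denominator ``pos``.

    Examples
    --------
    A minimal sketch on a toy parameter (values illustrative):

    >>> w = nn.Parameter(torch.tensor([[1.0, 2.0]]))
    >>> (w * torch.tensor([[0.5, -0.5]])).sum().backward()
    >>> with torch.no_grad():
    ...     _mu_update(w, torch.ones_like(w), 1.0, 0.0, 0.0)
    >>> w.data  # w * relu(1 - grad) / 1
    tensor([[0.5000, 3.0000]])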
"""
if isinstance(param, nn.ParameterList):
# Handle no gradients in fixed components
grad = torch.cat(
[x.grad if x.requires_grad else torch.zeros_like(x) for x in param]
)
elif param.grad is None:
return
else:
grad = param.grad
# prevent negative terms and zero division
multiplier = F.relu(pos - grad, inplace=True)
    if (pos == 0).any():
pos.add_(1e-7)
if l1_reg > 0:
pos.add_(l1_reg)
if l2_reg > 0:
if isinstance(param, nn.ParameterList):
            reg_param = torch.cat(list(param))
else:
reg_param = param
if pos.shape != reg_param.shape:
pos = pos + l2_reg * reg_param
else:
pos.add_(l2_reg * reg_param)
multiplier.div_(pos)
if gamma != 1:
multiplier.pow_(gamma)
if isinstance(param, nn.ParameterList):
for i, sub_param in enumerate(param):
sub_param.mul_(multiplier[i, :])
else:
param.mul_(multiplier)


class NMFBase(nn.Module):
def __init__(
self,
W_shape,
H_shape,
n_components,
*,
initial_components=None,
fix_components=(),
initial_weights=None,
fix_weights=(),
device=None,
**kwargs
):
"""
Base class for setting up NMF
Parameters
----------
W_shape: tuple of int
Shape of the weights matrix
H_shape: tuple of int
Shape of the components matrix
n_components: int
Number of components in the factorization
        initial_components: tuple of torch.Tensor
            Initial components for the factorization, each of shape (1, *H_shape[1:])
fix_components: tuple of bool
Corresponding directive to fix each component in the factorization.
            The components are ordered, and the default behavior is to allow a component to vary.
            E.g. (True, False, True) for a 4-component factorization will fix the first and third
            components, while the second and fourth vary (the fourth by default).
        initial_weights: tuple of torch.Tensor
            Initial weights for the factorization, each of shape (1, *W_shape[1:])
fix_weights: tuple of bool
Corresponding directive to fix each weight in the factorization.
device: str, torch.device, None
Device for matrix factorization to proceed on. Defaults to cpu.
kwargs: dict
Keyword arguments for torch.nn.Module
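
        Examples
        --------
        A minimal sketch using the concrete ``NMF`` subclass below, fixing the
        first of two components (shapes illustrative):

        >>> comps = (torch.ones(1, 4), torch.rand(1, 4))
        >>> model = NMF((3, 4), 2, initial_components=comps, fix_components=(True, False))
        >>> model.H.shape
        torch.Size([2, 4])
        >>> model.H_list[0].requires_grad
        False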
"""
super().__init__()
        # replace non-positive entries with a small positive constant
        self.fix_neg = nn.Threshold(0.0, 1e-8)
self.rank = n_components
if device is None:
self.device = torch.device("cpu")
else:
self.device = torch.device(device)
if initial_weights is not None:
w_list = [nn.Parameter(weight) for weight in initial_weights] + [
nn.Parameter(torch.rand(1, *W_shape[1:]))
for _ in range(W_shape[0] - len(initial_weights))
]
else:
w_list = [
nn.Parameter(torch.rand(1, *W_shape[1:])) for _ in range(W_shape[0])
]
if fix_weights:
for i in range(len(fix_weights)):
w_list[i].requires_grad = not fix_weights[i]
self.W_list = nn.ParameterList(w_list).to(device)
if initial_components is not None:
h_list = [nn.Parameter(component) for component in initial_components] + [
nn.Parameter(torch.rand(1, *H_shape[1:]))
for _ in range(H_shape[0] - len(initial_components))
]
else:
h_list = [
nn.Parameter(torch.rand(1, *H_shape[1:])) for _ in range(H_shape[0])
]
if fix_components:
for i in range(len(fix_components)):
h_list[i].requires_grad = not fix_components[i]
self.H_list = nn.ParameterList(h_list).to(device)
    @property
    def H(self):
        return torch.cat(list(self.H_list))

    @property
    def W(self):
        return torch.cat(list(self.W_list))
def loss(self, X, beta=2):
with torch.no_grad():
WH = self.reconstruct(self.H, self.W)
return Beta_divergence(self.fix_neg(WH), X, beta)
def forward(self, H=None, W=None):
if H is None:
H = self.H
if W is None:
W = self.W
return self.reconstruct(H, W)
    def reconstruct(self, H, W):
"""
Method for reconstructing the approximate input matrix from the components and weights
Parameters
----------
H: torch.Tensor
Components matrix
W: torch.Tensor
Weights matrix
        Returns
        -------
        torch.Tensor
            Approximate reconstruction of the input matrix
"""
raise NotImplementedError
    def get_W_positive(self, WH, beta, H_sum) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
        Get the positive denominator and/or H sum for the multiplicative W update
Parameters
----------
        WH: torch.Tensor
            Reconstruction of the input matrix (in the simple case, the matrix product W @ H)
beta: int, float
Value for beta divergence
H_sum: torch.Tensor, None
Sum over components matrix to use in denominator of update. If unknown or not required use None.
        Returns
        -------
        tuple of torch.Tensor, torch.Tensor or None
            Positive denominator for the W update, and the updated H_sum (None if unused)
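
        Notes
        -----
        For reference, the plain ``NMF`` subclass below computes this as
        ``H.sum(1)`` (broadcast over examples) when ``beta == 1``, and as
        ``(WH ** (beta - 1)) @ H.T`` otherwise (the power is skipped when
        ``beta == 2``).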
"""
raise NotImplementedError
    def get_H_positive(self, WH, beta, W_sum) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Get the positive denominator and/or W sum for multiplicative H update
Parameters
----------
        WH: torch.Tensor
            Reconstruction of the input matrix (in the simple case, the matrix product W @ H)
beta: int, float
Value for beta divergence
W_sum: torch.Tensor, None
Sum over weights matrix to use in denominator of update. If unknown or not required use None.
        Returns
        -------
        tuple of torch.Tensor, torch.Tensor or None
            Positive denominator for the H update, and the updated W_sum (None if unused)
"""
raise NotImplementedError
def fit(
self,
X,
update_W=True,
update_H=True,
beta=1,
tol=1e-5,
max_iter=200,
alpha=0,
l1_ratio=0,
):
"""
        Fit the weights (W) and components (H) to the dataset X.
Parameters
----------
X: torch.Tensor
Tensor of the dataset to fit, shape (m_examples, n_features)
update_W: bool
Override on updating weights matrix
update_H: bool
Override on updating components matrix
beta: float
Value for beta divergence
tol: float
Change in loss tolerance for exiting optimization loop
max_iter: int
Maximum number of iterations to consider for optimization loop
alpha: float
Amount of regularization for the mu update
        l1_ratio: float
            Fraction of the regularization that is L1 (the remainder is L2)
        Returns
        -------
        list of float
            Scaled beta-divergence loss at each iteration
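
        Examples
        --------
        A minimal sketch with the concrete ``NMF`` subclass (data illustrative):

        >>> X = torch.rand(10, 6)
        >>> model = NMF(X.shape, n_components=3)
        >>> losses = model.fit(X, beta=2, max_iter=50)
        >>> model.W.shape, model.H.shape
        (torch.Size([10, 3]), torch.Size([3, 6]))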
"""
X = X.type(torch.float).to(self.device)
X = self.fix_neg(X)
if beta < 1:
gamma = 1 / (2 - beta)
elif beta > 2:
gamma = 1 / (beta - 1)
else:
gamma = 1
l1_reg = alpha * l1_ratio
l2_reg = alpha * (1 - l1_ratio)
loss_scale = torch.prod(torch.tensor(X.shape)).float()
losses = []
H_sum, W_sum = None, None
if max_iter < 1:
raise ValueError("Maximum number of iterations must be at least 1.")
for n_iter in range(max_iter):
# W update
if update_W and any([x.requires_grad for x in self.W_list]):
self.zero_grad()
WH = self.reconstruct(self.H.detach(), self.W)
loss = Beta_divergence(self.fix_neg(WH), X, beta)
loss.backward()
with torch.no_grad():
positive_comps, H_sum = self.get_W_positive(WH, beta, H_sum)
_mu_update(self.W_list, positive_comps, gamma, l1_reg, l2_reg)
W_sum = None
# H update
if update_H and any([x.requires_grad for x in self.H_list]):
self.zero_grad()
WH = self.reconstruct(self.H, self.W.detach())
loss = Beta_divergence(self.fix_neg(WH), X, beta)
loss.backward()
with torch.no_grad():
positive_comps, W_sum = self.get_H_positive(WH, beta, W_sum)
_mu_update(self.H_list, positive_comps, gamma, l1_reg, l2_reg)
H_sum = None
            # report the loss normalized by the number of elements in X
            loss = loss.div_(loss_scale).item()
            if n_iter == 0:
loss_init = loss
elif (previous_loss - loss) / loss_init < tol: # noqa: F821
break
previous_loss = loss # noqa:F841
losses.append(loss)
return losses
    def fit_transform(self, *args, **kwargs):
        """Fit the model to the data and return the learned weights W."""
        self.fit(*args, **kwargs)
        return self.W


class NMF(NMFBase):
def __init__(
self,
X_shape,
n_components,
*,
initial_components=None,
fix_components=(),
initial_weights=None,
fix_weights=(),
device=None,
**kwargs
):
"""
        Standard NMF, with optional constraints, constructed from the input matrix shape.
        W is (m_examples, n_components).
        H is (n_components, n_features).
        W @ H gives the reconstruction of X.
Parameters
----------
X_shape: tuple
Tuple of ints describing shape of input matrix
n_components: int
Number of desired components for the matrix factorization
initial_components: tuple of torch.Tensor
Initial components for the factorization. Shape (1, n_features)
fix_components: tuple of bool
Corresponding directive to fix each component in the factorization.
            The components are ordered, and the default behavior is to allow a component to vary.
            E.g. (True, False, True) for a 4-component factorization will fix the first and third
            components, while the second and fourth vary (the fourth by default).
        initial_weights: tuple of torch.Tensor
            Initial weights for the factorization. Shape (1, n_components)
fix_weights: tuple of bool
Corresponding directive to fix each weight in the factorization.
device: str, torch.device, None
Device for matrix factorization to proceed on. Defaults to cpu.
kwargs: dict
kwargs for torch.nn.Module
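
        Examples
        --------
        A minimal sketch (data illustrative):

        >>> X = torch.rand(8, 5)
        >>> weights = NMF(X.shape, n_components=2).fit_transform(X, beta=2)
        >>> weights.shape
        torch.Size([8, 2])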
"""
self.m_examples, self.n_features = X_shape
super().__init__(
(self.m_examples, n_components),
(n_components, self.n_features),
n_components,
initial_components=initial_components,
initial_weights=initial_weights,
fix_weights=fix_weights,
fix_components=fix_components,
device=device,
**kwargs
)
    def reconstruct(self, H, W):
"""
        Reconstruct the approximate input matrix as the matrix product of the weights and components
Parameters
----------
H: torch.Tensor
Components matrix
W: torch.Tensor
Weights matrix
Returns
-------
torch.Tensor
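
        Examples
        --------
        A minimal sketch (shapes illustrative):

        >>> model = NMF((2, 3), n_components=2)
        >>> model.reconstruct(model.H, model.W).shape
        torch.Size([2, 3])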
"""
return W @ H
    def get_W_positive(self, WH, beta, H_sum):
        H = self.H
        if beta == 1:
            # KL divergence: denominator is the sum over H, broadcast across examples
            if H_sum is None:
                H_sum = H.sum(1)
            denominator = H_sum[None, :]
        else:
            if beta != 2:
                WH = WH.pow(beta - 1)
            # Frobenius (beta == 2) and general case: (WH ** (beta - 1)) @ H.T
            WHHt = WH @ H.t()
            denominator = WHHt
        return denominator, H_sum
    def get_H_positive(self, WH, beta, W_sum):
        W = self.W
        if beta == 1:
            # KL divergence: denominator is the sum over W, broadcast across features
            if W_sum is None:
                W_sum = W.sum(0)  # shape (n_components,)
            denominator = W_sum[:, None]
        else:
            if beta != 2:
                WH = WH.pow(beta - 1)
            # Frobenius (beta == 2) and general case: W.T @ (WH ** (beta - 1))
            WtWH = W.t() @ WH
            denominator = WtWH
        return denominator, W_sum
def sort(self):
raise NotImplementedError


class NMFD(NMFBase):
"""
Deconvolutional NMF
    W is (m_examples, n_components, T), where T is the width of the convolution kernel
    H is (n_components, n_features - T + 1)
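
    Examples
    --------
    A minimal sketch (shapes illustrative):

    >>> model = NMFD((5, 8), n_components=2, T=3)
    >>> model.W.shape, model.H.shape
    (torch.Size([5, 2, 3]), torch.Size([2, 6]))
    >>> model.reconstruct(model.H, model.W).shape
    torch.Size([5, 8])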
"""
def __init__(self, X_shape, n_components, T=1, **kwargs):
self.m_examples, self.n_features = X_shape
self.pad_size = T - 1
super().__init__(
(self.m_examples, n_components, T),
(n_components, self.n_features - T + 1),
n_components,
**kwargs
)
    def reconstruct(self, H, W):
        # 1D convolution of H with the time-flipped kernels in W, equivalent to a
        # sum of time-shifted W[:, :, t] @ H terms
        return F.conv1d(H[None, :], W.flip(2), padding=self.pad_size)[0]
def get_W_positive(self, WH, beta, H_sum):
H = self.H
if beta == 1:
if H_sum is None:
H_sum = H.sum(1)
denominator = H_sum[None, :, None]
else:
if beta != 2:
WH = WH.pow(beta - 1)
WHHt = F.conv1d(WH[:, None], H[:, None])
denominator = WHHt
return denominator, H_sum
def get_H_positive(self, WH, beta, W_sum):
W = self.W
if beta == 1:
if W_sum is None:
W_sum = W.sum((0, 2))
denominator = W_sum[:, None]
else:
if beta != 2:
WH = WH.pow(beta - 1)
WtWH = F.conv1d(WH[None, :], W.transpose(0, 1))[0]
denominator = WtWH
return denominator, W_sum
def sort(self):
raise NotImplementedError