"""Classes and functions for working with implicit sums of tensors."""
# Copyright 2025 National Technology & Engineering Solutions of Sandia,
# LLC (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the
# U.S. Government retains certain rights in this software.
from __future__ import annotations

import warnings
from copy import deepcopy
from textwrap import indent
from typing import TYPE_CHECKING, Literal

import pyttb as ttb
from pyttb.pyttb_utils import np_to_python

if TYPE_CHECKING:
    import numpy as np
class sumtensor:
    """Class for implicit sum of other tensors.

    The constituent tensors are stored in :attr:`parts`. Operations such as
    :meth:`innerprod`, :meth:`mttkrp` and :meth:`ttv` are evaluated part by
    part, so the (possibly large) dense sum is only formed when explicitly
    requested via :meth:`full`.
    """

    def __init__(
        self,
        tensors: list[ttb.tensor | ttb.sptensor | ttb.ktensor | ttb.ttensor]
        | None = None,
        copy: bool = True,
    ):
        """Create a :class:`pyttb.sumtensor` from a collection of tensors.

        Each provided tensor is explicitly retained. All provided tensors
        must have the same shape but can be combinations of types.

        Parameters
        ----------
        tensors:
            Tensor source data. ``None`` creates an empty sumtensor.
        copy:
            Whether to make a copy of provided data or just reference it.

        Raises
        ------
        AssertionError
            If ``tensors`` is not a list or its entries differ in shape.

        Examples
        --------
        Create a :class:`pyttb.sumtensor` from a dense and a sparse tensor:

        >>> T1 = ttb.tenones((3, 4, 5))
        >>> T2 = ttb.sptensor(shape=(3, 4, 5))
        >>> S = ttb.sumtensor([T1, T2])
        """
        if tensors is None:
            tensors = []
        assert isinstance(tensors, list), (
            "Collection of tensors must be provided as a list "
            f"but received: {type(tensors)}"
        )
        # An empty list makes ``tensors[1:]`` empty, so ``tensors[0]`` is never
        # evaluated and an empty sumtensor is accepted.
        assert all(tensors[0].shape == tensor_i.shape for tensor_i in tensors[1:]), (
            "All tensors must be the same shape"
        )
        if copy:
            tensors = deepcopy(tensors)
        self.parts = tensors

    @property
    def order(self) -> Literal["F"]:
        """Return the data layout of the underlying storage."""
        return "F"

    def _matches_order(self, array: np.ndarray) -> bool:
        """Check if provided array matches tensor memory layout."""
        if array.flags["C_CONTIGUOUS"] and self.order == "C":
            return True
        if array.flags["F_CONTIGUOUS"] and self.order == "F":
            return True
        return False

    def copy(self) -> sumtensor:
        """Make a deep copy of a :class:`pyttb.sumtensor`.

        Returns
        -------
        Copy of original sumtensor.

        Examples
        --------
        >>> T1 = ttb.tensor(np.ones((3, 2)))
        >>> S1 = ttb.sumtensor([T1, T1])
        >>> S2 = S1
        >>> S3 = S2.copy()
        >>> S1.parts[0][0, 0] = 3
        >>> S1.parts[0][0, 0] == S2.parts[0][0, 0]
        True
        >>> S1.parts[0][0, 0] == S3.parts[0][0, 0]
        False
        """
        return ttb.sumtensor(self.parts, copy=True)

    def __deepcopy__(self, memo):
        """Return deepcopy of this sumtensor.

        ``memo`` is not used; the parts are deep-copied by :meth:`copy`.
        """
        return self.copy()

    @property
    def shape(self) -> tuple[int, ...]:
        """Shape of a :class:`pyttb.sumtensor`; ``()`` when it has no parts."""
        if len(self.parts) == 0:
            return ()
        return self.parts[0].shape

    def __repr__(self):
        """Return string representation of the sumtensor.

        Returns
        -------
        String displaying shape and constituent parts.

        Examples
        --------
        >>> T1 = ttb.tenones((2, 2))
        >>> T2 = ttb.sptensor(shape=(2, 2))
        >>> ttb.sumtensor([T1, T2])  # doctest: +NORMALIZE_WHITESPACE
        sumtensor of shape (2, 2) with 2 parts:
        Part 0:
            tensor of shape (2, 2) with order F
            data[:, :] =
            [[1. 1.]
             [1. 1.]]
        Part 1:
            empty sparse tensor of shape (2, 2) with order F
        """
        if len(self.parts) == 0:
            return "Empty sumtensor"
        s = (
            f"sumtensor of shape {np_to_python(self.shape)} "
            f"with {len(self.parts)} parts:"
        )
        for i, part in enumerate(self.parts):
            s += f"\nPart {i}: \n"
            s += indent(str(part), prefix="\t")
        return s

    __str__ = __repr__

    @property
    def ndims(self) -> int:
        """
        Number of dimensions of the sumtensor.

        An empty sumtensor (no parts) has shape ``()`` and therefore 0
        dimensions.

        Examples
        --------
        >>> T1 = ttb.tenones((2, 2))
        >>> S = ttb.sumtensor([T1, T1])
        >>> S.ndims
        2
        >>> ttb.sumtensor().ndims
        0
        """
        # Derived from ``shape`` so the empty case does not index ``parts[0]``.
        return len(self.shape)

    def __pos__(self):
        """
        Unary plus (+) for sumtensors.

        Returns
        -------
        Copy of sumtensor.

        Examples
        --------
        >>> T = ttb.tensor(np.array([[1, 2], [3, 4]]))
        >>> S = ttb.sumtensor([T, T])
        >>> S2 = +S
        """
        return self.copy()

    def __neg__(self):
        """
        Unary minus (-) for sumtensors.

        Returns
        -------
        Copy of negated sumtensor.

        Examples
        --------
        >>> T = ttb.tensor(np.array([[1, 2], [3, 4]]))
        >>> S = ttb.sumtensor([T, T])
        >>> S2 = -S
        >>> S2.parts[0].isequal(-1 * S.parts[0])
        True
        """
        return ttb.sumtensor([-part for part in self.parts], copy=False)

    def __add__(self, other):
        """
        Binary addition (+) for sumtensors.

        Parameters
        ----------
        other: :class:`pyttb.tensor`, :class:`pyttb.sptensor`,
            :class:`pyttb.ktensor`, :class:`pyttb.ttensor`,
            :class:`pyttb.sumtensor`, or list containing tensor, sptensor,
            ktensor or ttensor objects. A sumtensor contributes all of its
            parts.

        Returns
        -------
        :class:`pyttb.sumtensor`

        Raises
        ------
        TypeError
            If ``other`` is not one of the supported types.

        Examples
        --------
        >>> T = ttb.tenones((2, 2))
        >>> S = ttb.sumtensor([T])
        >>> len(S.parts)
        1
        >>> S2 = S + T
        >>> len(S2.parts)
        2
        >>> S3 = S2 + [T, T]
        >>> len(S3.parts)
        4
        >>> S4 = S3 + S
        >>> len(S4.parts)
        5
        """
        part_types = (ttb.tensor, ttb.sptensor, ttb.ktensor, ttb.ttensor)
        # Shallow copy: the new sumtensor references (does not copy) the parts.
        updated_parts = self.parts.copy()
        if isinstance(other, part_types):
            updated_parts.append(other)
        elif isinstance(other, ttb.sumtensor):
            # A sum of sums is simply the sum of all constituent parts.
            updated_parts.extend(other.parts)
        elif isinstance(other, list) and all(
            isinstance(part, part_types) for part in other
        ):
            updated_parts.extend(other)
        else:
            raise TypeError(
                "Sumtensor only supports collections of tensor, sptensor, ktensor, "
                f"and ttensor but received: {type(other)}"
            )
        return ttb.sumtensor(updated_parts, copy=False)

    def __radd__(self, other):
        """
        Right Binary addition (+) for sumtensors.

        Addition is commutative, so this delegates to :meth:`__add__`.

        Parameters
        ----------
        other: :class:`pyttb.tensor`, :class:`pyttb.sptensor`
            :class:`pyttb.ktensor`, :class:`pyttb.ttensor`, or list
            containing those classes

        Returns
        -------
        :class:`pyttb.sumtensor`

        Examples
        --------
        >>> T = ttb.tenones((2, 2))
        >>> S = ttb.sumtensor([T])
        >>> len(S.parts)
        1
        >>> S2 = T + S
        >>> len(S2.parts)
        2
        >>> S3 = [T, T] + S2
        >>> len(S3.parts)
        4
        """
        return self.__add__(other)

    def to_tensor(self) -> ttb.tensor:
        """Return sumtensor converted to dense tensor.

        Same as :meth:`pyttb.sumtensor.full`.
        """
        return self.full()

    def full(self) -> ttb.tensor:
        """
        Convert a :class:`pyttb.sumtensor` to a :class:`pyttb.tensor`.

        Returns
        -------
        Re-assembled dense tensor.

        Examples
        --------
        >>> T = ttb.tenones((2, 2))
        >>> S = ttb.sumtensor([T, T])
        >>> print(S.full())  # doctest: +NORMALIZE_WHITESPACE
        tensor of shape (2, 2) with order F
        data[:, :] =
        [[2. 2.]
         [2. 2.]]
        <BLANKLINE>
        """
        result = self.parts[0].full()
        for part in self.parts[1:]:
            result += part
        return result

    def double(self, immutable: bool = False) -> np.ndarray:
        """
        Convert :class:`pyttb.sumtensor` to an :class:`numpy.ndarray` of doubles.

        The sum is first assembled densely via :meth:`full`.

        Parameters
        ----------
        immutable: Whether or not the returned data can be mutated. May enable
            additional optimizations.

        Examples
        --------
        >>> T = ttb.tenones((2, 2))
        >>> S = ttb.sumtensor([T, T])
        >>> S.double()
        array([[2., 2.],
               [2., 2.]])
        """
        return self.full().double(immutable)

    def innerprod(
        self, other: ttb.tensor | ttb.sptensor | ttb.ktensor | ttb.ttensor
    ) -> float:
        """Efficient inner product between a sumtensor and other `pyttb` tensors.

        The inner product is linear, so it is computed part by part and summed
        without forming the dense sum.

        Parameters
        ----------
        other:
            Tensor to take an innerproduct with.

        Examples
        --------
        >>> T1 = ttb.tensor(np.array([[1.0, 0.0], [0.0, 4.0]]))
        >>> T2 = T1.to_sptensor()
        >>> S = ttb.sumtensor([T1, T2])
        >>> T1.innerprod(T1)
        17.0
        >>> T1.innerprod(T2)
        17.0
        >>> S.innerprod(T1)
        34.0
        """
        result = self.parts[0].innerprod(other)
        for part in self.parts[1:]:
            result += part.innerprod(other)
        return result

    def mttkrp(
        self, U: ttb.ktensor | list[np.ndarray], n: int | np.integer
    ) -> np.ndarray:
        """Matricized tensor times Khatri-Rao product.

        The matrices used in the
        Khatri-Rao product are passed as a :class:`pyttb.ktensor` (where the
        factor matrices are used) or as a list of :class:`numpy.ndarray` objects.
        The product is linear, so it is evaluated for each part and summed.

        Parameters
        ----------
        U:
            Matrices to create the Khatri-Rao product.
        n:
            Mode used to matricize tensor.

        Returns
        -------
        Array containing matrix product.

        Examples
        --------
        >>> T1 = ttb.tenones((2, 2, 2))
        >>> T2 = T1.to_sptensor()
        >>> S = ttb.sumtensor([T1, T2])
        >>> U = [np.ones((2, 2))] * 3
        >>> T1.mttkrp(U, 2)
        array([[4., 4.],
               [4., 4.]])
        >>> S.mttkrp(U, 2)
        array([[8., 8.],
               [8., 8.]])
        """
        result = self.parts[0].mttkrp(U, n)
        for part in self.parts[1:]:
            # Out-of-place add: an in-place ``+=`` fails when the first part
            # yields an integer array and a later part yields floats.
            result = result + part.mttkrp(U, n)
        return result

    def ttv(
        self,
        vector: np.ndarray | list[np.ndarray],
        dims: int | np.ndarray | None = None,
        exclude_dims: int | np.ndarray | None = None,
    ) -> float | sumtensor:
        """
        Tensor times vector.

        Computes the n-mode product of every part with the vector `vector`;
        i.e., `self x_n vector`. The integer `n` specifies the dimension (or
        mode) along which the vector should be multiplied. If
        `vector.shape = (I,)`, then the sumtensor must have
        `self.shape[n] = I`. Each resulting part has one less dimension than
        `self`, as dimension `n` is removed in the multiplication.

        Multiplication with more than one vector is provided using a list of
        vectors and corresponding dimensions in the tensor to use.

        The dimensions of the tensor with which to multiply can be provided as
        `dims`, or the dimensions to exclude from `[0, ..., self.ndims - 1]`
        can be specified using `exclude_dims`.

        Parameters
        ----------
        vector:
            Vector or vectors to multiply by.
        dims:
            Dimensions to multiply against.
        exclude_dims:
            Use all dimensions but these.

        Returns
        -------
        Sumtensor containing individual products or a single sum if every
        product is a single value.

        Examples
        --------
        >>> T = ttb.tensor(np.array([[1, 2], [3, 4]]))
        >>> S = ttb.sumtensor([T, T])
        >>> T.ttv(np.ones(2), 0)
        tensor of shape (2,) with order F
        data[:] =
        [4. 6.]
        >>> S.ttv(np.ones(2), 0)  # doctest: +NORMALIZE_WHITESPACE
        sumtensor of shape (2,) with 2 parts:
        Part 0:
            tensor of shape (2,) with order F
            data[:] =
            [4. 6.]
        Part 1:
            tensor of shape (2,) with order F
            data[:] =
            [4. 6.]
        >>> T.ttv([np.ones(2), np.ones(2)])
        10.0
        >>> S.ttv([np.ones(2), np.ones(2)])
        20.0
        """
        new_parts = []
        scalar_sum = 0.0
        for part in self.parts:
            result = part.ttv(vector, dims, exclude_dims)
            if isinstance(result, float):
                scalar_sum += result
            else:
                new_parts.append(result)
        if len(new_parts) == 0:
            return scalar_sum
        # All parts share a shape, so products are expected to be either all
        # scalars or all tensors; a mix would indicate an inconsistent part.
        assert scalar_sum == 0.0
        return ttb.sumtensor(new_parts, copy=False)

    def norm(self) -> float:
        """Compatibility Interface. Just returns 0.

        The norm of the implicit sum is not computed. A warning is issued
        and ``0.0`` is returned.
        """
        warnings.warn(
            "Sumtensor doesn't actually support norm. Returning 0 for compatibility.",
            # Attribute the warning to the caller rather than this method.
            stacklevel=2,
        )
        return 0.0
if __name__ == "__main__":
    import doctest  # pragma: no cover

    import numpy  # pragma: no cover

    # The doctests above use ``np``, but at module level numpy is only imported
    # under TYPE_CHECKING, so provide it explicitly when run as a script.
    doctest.testmod(extraglobs={"np": numpy})  # pragma: no cover