"""Prepare Function and Gradient Handles for GCP OPT."""
# Copyright 2025 National Technology & Engineering Solutions of Sandia,
# LLC (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the
# U.S. Government retains certain rights in this software.
from __future__ import annotations
from functools import partial
from typing import TYPE_CHECKING, Callable
import numpy as np
from pyttb.decompositions.cp.gcp import handles
from pyttb.decompositions.cp.gcp.handles import Objectives
from pyttb.tensors.sparse import sptensor
if TYPE_CHECKING:
from pyttb.tensors.dense import tensor
# Signature shared by the objective and gradient handles returned from `setup`:
# a callable taking two arrays and returning an array.
function_type = Callable[[np.ndarray, np.ndarray], np.ndarray]
# Return type of `setup`: (function handle, gradient handle, lower bound).
fg_return = tuple[function_type, function_type, float]
[docs]
def setup(  # noqa: PLR0912,PLR0915
    objective: Objectives,
    data: tensor | sptensor | None = None,
    additional_parameter: float | None = None,
) -> fg_return:
    """Collect the function and gradient handles for GCP.

    Parameters
    ----------
    objective:
        Objective function to gather handles for.
    data:
        Tensor to check for consistency with desired objective function.
        When omitted, no consistency check is performed.
    additional_parameter:
        Additional constant argument provided to objective function if
        necessary (``threshold`` for HUBER, ``num_trials`` for
        NEGATIVE_BINOMIAL, ``b`` for BETA).

    Returns
    -------
    Function handle, gradient handle, and lower bound.

    Raises
    ------
    ValueError
        If ``data`` fails the check associated with ``objective``, if an
        objective that needs ``additional_parameter`` did not receive one, or
        if ``objective`` is not a recognized objective.
    """

    def check_data(is_valid: Callable[..., bool], kind: str) -> None:
        # Data is optional, so only validate when the caller supplied it.
        if data is not None and not is_valid(data):
            raise ValueError(f"{objective.name} requires a {kind} tensor")

    def check_parameter(name: str) -> None:
        if additional_parameter is None:
            raise ValueError(
                f"{objective.name} requires additional parameter for `{name}`"
            )

    # Lower bounds are always floats so the result matches ``fg_return``.
    if objective == Objectives.GAUSSIAN:
        function_handle = handles.gaussian
        gradient_handle = handles.gaussian_grad
        lower_bound = -np.inf
    elif objective == Objectives.BERNOULLI_ODDS:
        check_data(valid_binary, "binary")
        function_handle = handles.bernoulli_odds
        gradient_handle = handles.bernoulli_odds_grad
        lower_bound = 0.0
    elif objective == Objectives.BERNOULLI_LOGIT:
        check_data(valid_binary, "binary")
        function_handle = handles.bernoulli_logit
        gradient_handle = handles.bernoulli_logit_grad
        lower_bound = -np.inf
    elif objective == Objectives.POISSON:
        check_data(valid_natural, "count")
        function_handle = handles.poisson
        gradient_handle = handles.poisson_grad
        lower_bound = 0.0
    elif objective == Objectives.POISSON_LOG:
        check_data(valid_natural, "count")
        function_handle = handles.poisson_log
        gradient_handle = handles.poisson_log_grad
        lower_bound = -np.inf
    elif objective == Objectives.RAYLEIGH:
        check_data(valid_nonneg, "non-negative")
        function_handle = handles.rayleigh
        gradient_handle = handles.rayleigh_grad
        lower_bound = 0.0
    elif objective == Objectives.GAMMA:
        check_data(valid_nonneg, "non-negative")
        function_handle = handles.gamma
        gradient_handle = handles.gamma_grad
        lower_bound = 0.0
    elif objective == Objectives.HUBER:
        check_parameter("threshold")
        function_handle = partial(handles.huber, threshold=additional_parameter)
        gradient_handle = partial(handles.huber_grad, threshold=additional_parameter)
        lower_bound = -np.inf
    elif objective == Objectives.NEGATIVE_BINOMIAL:
        # Data is validated before the parameter, so a bad tensor is reported
        # first when both problems are present.
        check_data(valid_nonneg, "non-negative")
        check_parameter("num_trials")
        function_handle = partial(
            handles.negative_binomial, num_trials=additional_parameter
        )
        gradient_handle = partial(
            handles.negative_binomial_grad, num_trials=additional_parameter
        )
        lower_bound = 0.0
    elif objective == Objectives.BETA:
        check_data(valid_nonneg, "non-negative")
        check_parameter("b")
        function_handle = partial(handles.beta, b=additional_parameter)
        gradient_handle = partial(handles.beta_grad, b=additional_parameter)
        lower_bound = 0.0
    elif objective == Objectives.ZT_POISSON:
        check_data(valid_natural, "count")
        function_handle = handles.ztp
        gradient_handle = handles.ztp_grad
        lower_bound = 0.0
    else:
        raise ValueError(f" Unknown objective: {objective}")
    return function_handle, gradient_handle, lower_bound
[docs]
def valid_nonneg(data: tensor | sptensor) -> bool:
    """Check if provided data is valid non-negative tensor.

    Parameters
    ----------
    data:
        Tensor to inspect. For an sptensor the ``vals`` array is checked,
        otherwise the ``data`` array is checked.

    Returns
    -------
    True when every inspected value is greater than or equal to zero.
    """
    vals = data.vals if isinstance(data, sptensor) else data.data
    # Zero is non-negative: use ``>=`` so zero entries are not rejected.
    return bool(np.all(vals >= 0))
[docs]
def valid_binary(data: tensor | sptensor) -> bool:
    """Check if provided data is valid binary tensor.

    Parameters
    ----------
    data:
        Tensor to inspect. For an sptensor the ``vals`` array is checked,
        otherwise the ``data`` array is checked.

    Returns
    -------
    True when every ``vals`` entry of an sptensor equals 1, or when every
    ``data`` entry of any other tensor is 0 or 1.
    """
    if isinstance(data, sptensor):
        return bool(np.all(data.vals == 1))
    # Test membership directly; de-duplicating first only adds a sort and a
    # copy without changing the answer.
    return bool(np.all(np.isin(data.data, [0, 1])))
[docs]
def valid_natural(data: tensor | sptensor) -> bool:
    """Check if provided data is valid natural number tensor.

    Parameters
    ----------
    data:
        Tensor to inspect. For an sptensor the ``vals`` array is checked,
        otherwise the ``data`` array is checked.

    Returns
    -------
    True when every inspected value is a non-negative whole number.
    """
    vals = data.vals if isinstance(data, sptensor) else data.data
    # Counts must be whole numbers and must not be negative; the integrality
    # test alone would let values such as -3 through.
    return bool(np.all(vals % 1 == 0) and np.all(vals >= 0))