1 #ifndef STAN_MATH_REV_MAT_FUN_TRACE_QUAD_FORM_HPP 2 #define STAN_MATH_REV_MAT_FUN_TRACE_QUAD_FORM_HPP 4 #include <boost/utility/enable_if.hpp> 5 #include <boost/type_traits.hpp> 19 template <
typename TA,
int RA,
int CA,
typename TB,
int RB,
int CB>
20 class trace_quad_form_vari_alloc :
public chainable_alloc {
22 trace_quad_form_vari_alloc(
const Eigen::Matrix<TA, RA, CA>& A,
23 const Eigen::Matrix<TB, RB, CB>& B)
32 Eigen::Matrix<TA, RA, CA>
A_;
33 Eigen::Matrix<TB, RB, CB>
B_;
36 template <
typename TA,
int RA,
int CA,
typename TB,
int RB,
int CB>
37 class trace_quad_form_vari :
public vari {
39 static inline void chainA(Eigen::Matrix<double, RA, CA>& A,
40 const Eigen::Matrix<double, RB, CB>& Bd,
42 static inline void chainB(Eigen::Matrix<double, RB, CB>& B,
43 const Eigen::Matrix<double, RA, CA>& Ad,
44 const Eigen::Matrix<double, RB, CB>& Bd,
47 static inline void chainA(Eigen::Matrix<var, RA, CA>& A,
48 const Eigen::Matrix<double, RB, CB>& Bd,
50 Eigen::Matrix<double, RA, CA> adjA(adjC*Bd*Bd.transpose());
51 for (
int j = 0; j < A.cols(); j++)
52 for (
int i = 0; i < A.rows(); i++)
53 A(i, j).vi_->adj_ += adjA(i, j);
55 static inline void chainB(Eigen::Matrix<var, RB, CB>& B,
56 const Eigen::Matrix<double, RA, CA>& Ad,
57 const Eigen::Matrix<double, RB, CB>& Bd,
59 Eigen::Matrix<double, RA, CA> adjB(adjC*(Ad + Ad.transpose())*Bd);
60 for (
int j = 0; j < B.cols(); j++)
61 for (
int i = 0; i < B.rows(); i++)
62 B(i, j).vi_->adj_ += adjB(i, j);
65 inline void chainAB(Eigen::Matrix<TA, RA, CA>& A,
66 Eigen::Matrix<TB, RB, CB>& B,
67 const Eigen::Matrix<double, RA, CA>& Ad,
68 const Eigen::Matrix<double, RB, CB>& Bd,
71 chainB(B, Ad, Bd, adjC);
77 (trace_quad_form_vari_alloc<TA, RA, CA, TB, RB, CB> *impl)
78 : vari(impl->compute()),
impl_(impl) { }
80 virtual void chain() {
86 trace_quad_form_vari_alloc<TA, RA, CA, TB, RB, CB> *
impl_;
90 template <
typename TA,
int RA,
int CA,
typename TB,
int RB,
int CB>
92 boost::enable_if_c< boost::is_same<TA, var>::value ||
93 boost::is_same<TB, var>::value,
96 const Eigen::Matrix<TB, RB, CB>& B) {
102 trace_quad_form_vari_alloc<TA, RA, CA, TB, RB, CB> *baseVari
103 =
new trace_quad_form_vari_alloc<TA, RA, CA, TB, RB, CB>(A, B);
105 return var(
new trace_quad_form_vari<TA, RA, CA, TB, RB, CB>(baseVari));
fvar< T > trace_quad_form(const Eigen::Matrix< fvar< T >, RA, CA > &A, const Eigen::Matrix< fvar< T >, RB, CB > &B)
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Independent (input) and dependent (output) variables for gradients.
void check_multiplicable(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the matrices can be multiplied.
void check_square(const char *function, const char *name, const Eigen::Matrix< T_y, Eigen::Dynamic, Eigen::Dynamic > &y)
Check if the specified matrix is square.