1 #ifndef STAN_MATH_REV_MAT_FUN_QUAD_FORM_HPP
2 #define STAN_MATH_REV_MAT_FUN_QUAD_FORM_HPP
4 #include <boost/utility/enable_if.hpp>
5 #include <boost/type_traits.hpp>
19 template <
typename TA,
int RA,
int CA,
typename TB,
int RB,
int CB>
20 class quad_form_vari_alloc :
public chainable_alloc {
22 inline void compute(
const Eigen::Matrix<double, RA, CA>& A,
23 const Eigen::Matrix<double, RB, CB>& B) {
24 Eigen::Matrix<double, CB, CB> Cd(B.transpose()*A*B);
25 for (
int j = 0; j <
C_.cols(); j++) {
26 for (
int i = 0; i <
C_.rows(); i++) {
28 C_(i, j) = var(
new vari(0.5*(Cd(i, j) + Cd(j, i)),
false));
30 C_(i, j) = var(
new vari(Cd(i, j),
false));
37 quad_form_vari_alloc(
const Eigen::Matrix<TA, RA, CA>& A,
38 const Eigen::Matrix<TB, RB, CB>& B,
39 bool symmetric =
false)
45 Eigen::Matrix<TA, RA, CA>
A_;
46 Eigen::Matrix<TB, RB, CB>
B_;
47 Eigen::Matrix<var, CB, CB>
C_;
51 template <
typename TA,
int RA,
int CA,
typename TB,
int RB,
int CB>
52 class quad_form_vari :
public vari {
54 inline void chainA(Eigen::Matrix<double, RA, CA>& A,
55 const Eigen::Matrix<double, RB, CB>& Bd,
56 const Eigen::Matrix<double, CB, CB>& adjC) {}
57 inline void chainB(Eigen::Matrix<double, RB, CB>& B,
58 const Eigen::Matrix<double, RA, CA>& Ad,
59 const Eigen::Matrix<double, RB, CB>& Bd,
60 const Eigen::Matrix<double, CB, CB>& adjC) {}
62 inline void chainA(Eigen::Matrix<var, RA, CA>& A,
63 const Eigen::Matrix<double, RB, CB>& Bd,
64 const Eigen::Matrix<double, CB, CB>& adjC) {
65 Eigen::Matrix<double, RA, CA> adjA(Bd*adjC*Bd.transpose());
66 for (
int j = 0; j < A.cols(); j++) {
67 for (
int i = 0; i < A.rows(); i++) {
68 A(i, j).vi_->adj_ += adjA(i, j);
72 inline void chainB(Eigen::Matrix<var, RB, CB>& B,
73 const Eigen::Matrix<double, RA, CA>& Ad,
74 const Eigen::Matrix<double, RB, CB>& Bd,
75 const Eigen::Matrix<double, CB, CB>& adjC) {
76 Eigen::Matrix<double, RA, CA> adjB(Ad * Bd * adjC.transpose()
77 + Ad.transpose()*Bd*adjC);
78 for (
int j = 0; j < B.cols(); j++)
79 for (
int i = 0; i < B.rows(); i++)
80 B(i, j).vi_->adj_ += adjB(i, j);
83 inline void chainAB(Eigen::Matrix<TA, RA, CA>& A,
84 Eigen::Matrix<TB, RB, CB>& B,
85 const Eigen::Matrix<double, RA, CA>& Ad,
86 const Eigen::Matrix<double, RB, CB>& Bd,
87 const Eigen::Matrix<double, CB, CB>& adjC) {
89 chainB(B, Ad, Bd, adjC);
93 quad_form_vari(
const Eigen::Matrix<TA, RA, CA>& A,
94 const Eigen::Matrix<TB, RB, CB>& B,
95 bool symmetric =
false)
98 =
new quad_form_vari_alloc<TA, RA, CA, TB, RB, CB>(A, B, symmetric);
101 virtual void chain() {
103 Eigen::Matrix<double, CB, CB> adjC(
_impl->C_.rows(),
106 for (
int j = 0; j <
_impl->C_.cols(); j++)
107 for (
int i = 0; i <
_impl->C_.rows(); i++)
108 adjC(i, j) =
_impl->C_(i, j).vi_->adj_;
115 quad_form_vari_alloc<TA, RA, CA, TB, RB, CB> *
_impl;
119 template <
typename TA,
int RA,
int CA,
typename TB,
int RB,
int CB>
121 boost::enable_if_c< boost::is_same<TA, var>::value ||
122 boost::is_same<TB, var>::value,
123 Eigen::Matrix<var, CB, CB> >::type
125 const Eigen::Matrix<TB, RB, CB>& B) {
131 quad_form_vari<TA, RA, CA, TB, RB, CB> *baseVari
132 =
new quad_form_vari<TA, RA, CA, TB, RB, CB>(A, B);
134 return baseVari->_impl->C_;
136 template <
typename TA,
int RA,
int CA,
typename TB,
int RB>
138 boost::enable_if_c< boost::is_same<TA, var>::value ||
139 boost::is_same<TB, var>::value,
142 const Eigen::Matrix<TB, RB, 1>& B) {
148 quad_form_vari<TA, RA, CA, TB, RB, 1> *baseVari
149 =
new quad_form_vari<TA, RA, CA, TB, RB, 1>(A, B);
151 return baseVari->_impl->C_(0, 0);
T value_of(const fvar< T > &v)
Return the value of the specified variable.
bool check_multiplicable(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Return true if the matrices can be multiplied.
size_t cols(const Eigen::Matrix< T, R, C > &m)
bool check_square(const char *function, const char *name, const Eigen::Matrix< T_y, Eigen::Dynamic, Eigen::Dynamic > &y)
Return true if the specified matrix is square.
Eigen::Matrix< T, CB, CB > quad_form(const Eigen::Matrix< T, RA, CA > &A, const Eigen::Matrix< T, RB, CB > &B)
Compute B^T A B.