Stan Math Library  2.6.3
probability, sampling & optimization
 All Classes Namespaces Files Functions Variables Typedefs Enumerator Friends Macros
multiply.hpp
Go to the documentation of this file.
1 #ifndef STAN_MATH_REV_MAT_FUN_MULTIPLY_HPP
2 #define STAN_MATH_REV_MAT_FUN_MULTIPLY_HPP
3 
4 #include <stan/math/rev/core.hpp>
12 #include <boost/utility/enable_if.hpp>
13 #include <boost/type_traits.hpp>
14 #include <boost/math/tools/promotion.hpp>
15 
16 namespace stan {
17  namespace math {
18 
25  template <typename T1, typename T2>
26  inline typename
27  boost::enable_if_c<
28  (boost::is_scalar<T1>::value || boost::is_same<T1, var>::value)
29  && (boost::is_scalar<T2>::value || boost::is_same<T2, var>::value),
30  typename boost::math::tools::promote_args<T1, T2>::type>::type
31  multiply(const T1& v, const T2& c) {
32  return v * c;
33  }
34 
41  template<typename T1, typename T2, int R2, int C2>
42  inline Eigen::Matrix<var, R2, C2>
43  multiply(const T1& c, const Eigen::Matrix<T2, R2, C2>& m) {
44  // FIXME: pull out to eliminate overpromotion of one side
45  // move to matrix.hpp w. promotion?
46  return to_var(m) * to_var(c);
47  }
48 
55  template<typename T1, int R1, int C1, typename T2>
56  inline Eigen::Matrix<var, R1, C1>
57  multiply(const Eigen::Matrix<T1, R1, C1>& m, const T2& c) {
58  return to_var(m) * to_var(c);
59  }
60 
71  template<typename T1, int R1, int C1, typename T2, int R2, int C2>
72  inline typename
73  boost::enable_if_c< boost::is_same<T1, var>::value ||
74  boost::is_same<T2, var>::value,
75  Eigen::Matrix<var, R1, C2> >::type
76  multiply(const Eigen::Matrix<T1, R1, C1>& m1,
77  const Eigen::Matrix<T2, R2, C2>& m2) {
79  "m1", m1,
80  "m2", m2);
81  Eigen::Matrix<var, R1, C2> result(m1.rows(), m2.cols());
82  for (int i = 0; i < m1.rows(); i++) {
83  typename Eigen::Matrix<T1, R1, C1>::ConstRowXpr crow(m1.row(i));
84  for (int j = 0; j < m2.cols(); j++) {
85  typename Eigen::Matrix<T2, R2, C2>::ConstColXpr ccol(m2.col(j));
86  if (j == 0) {
87  if (i == 0) {
88  result(i, j) = var(new dot_product_vari<T1, T2>(crow, ccol));
89  } else {
90  dot_product_vari<T1, T2> *v2
91  = static_cast<dot_product_vari<T1, T2>*>(result(0, j).vi_);
92  result(i, j)
93  = var(new dot_product_vari<T1, T2>(crow, ccol, NULL, v2));
94  }
95  } else {
96  if (i == 0) {
97  dot_product_vari<T1, T2> *v1
98  = static_cast<dot_product_vari<T1, T2>*>(result(i, 0).vi_);
99  result(i, j)
100  = var(new dot_product_vari<T1, T2>(crow, ccol, v1, NULL));
101  } else /* if (i != 0 && j != 0) */ {
102  dot_product_vari<T1, T2> *v1
103  = static_cast<dot_product_vari<T1, T2>*>(result(i, 0).vi_);
104  dot_product_vari<T1, T2> *v2
105  = static_cast<dot_product_vari<T1, T2>*>(result(0, j).vi_);
106  result(i, j)
107  = var(new dot_product_vari<T1, T2>(crow, ccol, v1, v2));
108  }
109  }
110  }
111  }
112  return result;
113  }
114 
124  template <typename T1, int C1, typename T2, int R2>
125  inline typename
126  boost::enable_if_c< boost::is_same<T1, var>::value ||
127  boost::is_same<T2, var>::value, var >::type
128  multiply(const Eigen::Matrix<T1, 1, C1>& rv,
129  const Eigen::Matrix<T2, R2, 1>& v) {
130  if (rv.size() != v.size())
131  throw std::domain_error("row vector and vector must be same length "
132  "in multiply");
133  return dot_product(rv, v);
134  }
135 
136  }
137 }
138 #endif
Eigen::Matrix< fvar< T >, R1, C1 > multiply(const Eigen::Matrix< fvar< T >, R1, C1 > &m, const fvar< T > &c)
Definition: multiply.hpp:20
Independent (input) and dependent (output) variables for gradients.
Definition: var.hpp:32
var to_var(const double &x)
Converts argument to an automatic differentiation variable.
Definition: to_var.hpp:21
bool check_multiplicable(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Return true if the matrices can be multiplied.
fvar< T > dot_product(const Eigen::Matrix< fvar< T >, R1, C1 > &v1, const Eigen::Matrix< fvar< T >, R2, C2 > &v2)
Definition: dot_product.hpp:20
void domain_error(const char *function, const char *name, const T &y, const char *msg1, const char *msg2)
Throw a domain error with a consistently formatted message.

     [ Stan Home Page ] © 2011–2015, Stan Development Team.