1 #ifndef STAN_MATH_REV_SCAL_FUN_FMA_HPP
2 #define STAN_MATH_REV_SCAL_FUN_FMA_HPP
7 #include <boost/math/special_functions/fpclassify.hpp>
16 class fma_vvv_vari :
public op_vvv_vari {
18 fma_vvv_vari(vari* avi, vari* bvi, vari* cvi) :
19 op_vvv_vari(::
fma(avi->val_, bvi->val_, cvi->val_),
26 avi_->adj_ = std::numeric_limits<double>::quiet_NaN();
27 bvi_->adj_ = std::numeric_limits<double>::quiet_NaN();
28 cvi_->adj_ = std::numeric_limits<double>::quiet_NaN();
30 avi_->adj_ += adj_ * bvi_->val_;
31 bvi_->adj_ += adj_ * avi_->val_;
37 class fma_vvd_vari :
public op_vvd_vari {
39 fma_vvd_vari(vari* avi, vari* bvi,
double c) :
40 op_vvd_vari(::
fma(avi->val_, bvi->val_, c),
47 avi_->adj_ = std::numeric_limits<double>::quiet_NaN();
48 bvi_->adj_ = std::numeric_limits<double>::quiet_NaN();
50 avi_->adj_ += adj_ * bvi_->val_;
51 bvi_->adj_ += adj_ * avi_->val_;
56 class fma_vdv_vari :
public op_vdv_vari {
58 fma_vdv_vari(vari* avi,
double b, vari* cvi) :
59 op_vdv_vari(::
fma(avi->val_ , b, cvi->val_),
66 avi_->adj_ = std::numeric_limits<double>::quiet_NaN();
67 cvi_->adj_ = std::numeric_limits<double>::quiet_NaN();
69 avi_->adj_ += adj_ * bd_;
75 class fma_vdd_vari :
public op_vdd_vari {
77 fma_vdd_vari(vari* avi,
double b,
double c) :
78 op_vdd_vari(::
fma(avi->val_ , b, c),
85 avi_->adj_ = std::numeric_limits<double>::quiet_NaN();
87 avi_->adj_ += adj_ * bd_;
91 class fma_ddv_vari :
public op_ddv_vari {
93 fma_ddv_vari(
double a,
double b, vari* cvi) :
94 op_ddv_vari(::
fma(a, b, cvi->val_),
101 cvi_->adj_ = std::numeric_limits<double>::quiet_NaN();
157 return var(
new fma_vvd_vari(a.
vi_, b.
vi_, c));
182 return var(
new fma_vdv_vari(a.
vi_, b, c.
vi_));
205 return var(
new fma_vdd_vari(a.
vi_, b, c));
228 return var(
new fma_vdd_vari(b.
vi_, a, c));
251 return var(
new fma_ddv_vari(a, b, c.
vi_));
276 return var(
new fma_vdv_vari(b.
vi_, a, c.
vi_));
Independent (input) and dependent (output) variables for gradients.
bool isnan(const stan::math::var &v)
Checks if the given number is NaN.
fvar< typename stan::return_type< T1, T2, T3 >::type > fma(const fvar< T1 > &x1, const fvar< T2 > &x2, const fvar< T3 > &x3)
The fused multiply-add operation (C99).
vari * vi_
Pointer to the implementation of this variable.