10 #ifndef EIGEN_BLASUTIL_H
11 #define EIGEN_BLASUTIL_H
// Template parameter list for a declaration whose name is not visible in this
// extract. Parameters: lhs/rhs scalar types, the index type, a data mapper
// type, two integer sizes (mr, nr) and one conjugation flag per operand, both
// defaulting to false.
// NOTE(review): presumably the forward declaration of the GEBP micro-kernel -- confirm against the full header.
21 template<
typename LhsScalar,
typename RhsScalar,
typename Index,
typename DataMapper,
int mr,
int nr,
bool ConjugateLhs=false,
bool ConjugateRhs=false>
// Second parameter list (declared name not visible): scalar, index and data
// mapper types, an integer nr, a storage order, and Conjugate / PanelMode
// flags that default to false.
// NOTE(review): presumably the rhs packing routine -- confirm.
24 template<
typename Scalar,
typename Index,
typename DataMapper,
int nr,
int StorageOrder,
bool Conjugate = false,
bool PanelMode=false>
// Third parameter list (declared name not visible): same shape as the previous
// one but with two integer pack sizes (Pack1, Pack2) instead of nr.
// NOTE(review): presumably the lhs packing routine -- confirm.
27 template<
typename Scalar,
typename Index,
typename DataMapper,
int Pack1,
int Pack2,
int StorageOrder,
bool Conjugate = false,
bool PanelMode = false>
// Forward declaration of general_matrix_matrix_product. The opening of the
// template parameter list is not visible in this extract; the visible
// parameters give, for each of lhs and rhs, the scalar type, storage order and
// conjugation flag, followed by the storage order and inner stride of the
// result.
32 typename LhsScalar,
int LhsStorageOrder,
bool ConjugateLhs,
33 typename RhsScalar,
int RhsStorageOrder,
bool ConjugateRhs,
34 int ResStorageOrder,
int ResInnerStride>
35 struct general_matrix_matrix_product;
// Forward declaration of general_matrix_vector_product.
// Template parameters: the index type; for the lhs its scalar type, mapper
// type, storage order and conjugation flag; for the rhs its scalar type,
// mapper type and conjugation flag; and a Version selector that defaults to
// Specialized.
37 template<
typename Index,
38 typename LhsScalar,
typename LhsMapper,
int LhsStorageOrder,
bool ConjugateLhs,
39 typename RhsScalar,
typename RhsMapper,
bool ConjugateRhs,
int Version=Specialized>
40 struct general_matrix_vector_product;
// conj_if<Conjugate>: functor selected at compile time by a boolean flag.
// Only the two explicit specializations below provide a definition.
// NOTE(review): the type T used by the members is not declared in the visible
// lines; presumably each member is a template over T -- confirm.
43 template<
bool Conjugate>
struct conj_if;
// Conjugate == true: operator() returns numext::conj(x) and pconj returns
// internal::pconj(x); both return by value.
45 template<>
struct conj_if<true> {
47 inline T operator()(
const T& x)
const {
return numext::conj(x); }
49 inline T pconj(
const T& x)
const {
return internal::pconj(x); }
// Conjugate == false: both members are identity functions and hand back the
// argument as a const reference (no copy is made).
52 template<>
struct conj_if<false> {
54 inline const T& operator()(
const T& x)
const {
return x; }
56 inline const T& pconj(
const T& x)
const {
return x; }
// Generic product helper with optional conjugation of either operand.
// NOTE(review): the struct head is not visible in this extract; presumably this
// is the primary conj_helper template that the specializations below refine.
60 template<
typename LhsScalar,
typename RhsScalar,
bool ConjLhs,
bool ConjRhs>
// Result type of multiplying an LhsScalar by an RhsScalar.
63 typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar>::ReturnType Scalar;
// Multiply-add: returns c + pmul(x, y), with the sum formed through padd.
65 EIGEN_STRONG_INLINE Scalar pmadd(
const LhsScalar& x,
const RhsScalar& y,
const Scalar& c)
const
66 {
return padd(c, pmul(x,y)); }
// Product of x and y, each passed through conj_if so that it is conjugated
// only when the corresponding flag (ConjLhs / ConjRhs) is true.
68 EIGEN_STRONG_INLINE Scalar pmul(
const LhsScalar& x,
const RhsScalar& y)
const
69 {
return conj_if<ConjLhs>()(x) * conj_if<ConjRhs>()(y); }
// Specialization for identical operand types with no conjugation requested:
// both members simply forward to internal::pmadd / internal::pmul.
72 template<
typename Scalar>
struct conj_helper<Scalar,Scalar,false,false>
// Returns internal::pmadd(x, y, c).
74 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar pmadd(
const Scalar& x,
const Scalar& y,
const Scalar& c)
const {
return internal::pmadd(x,y,c); }
// Returns internal::pmul(x, y).
75 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar pmul(
const Scalar& x,
const Scalar& y)
const {
return internal::pmul(x,y); }
// Specialization for std::complex operands with only the rhs conjugated
// (ConjLhs == false, ConjRhs == true): computes x * conj(y).
78 template<
typename RealScalar>
struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, false,true>
80 typedef std::complex<RealScalar> Scalar;
// Returns c + pmul(x, y).
81 EIGEN_STRONG_INLINE Scalar pmadd(
const Scalar& x,
const Scalar& y,
const Scalar& c)
const
82 {
return c + pmul(x,y); }
// x * conj(y) written out on the real and imaginary parts:
//   real = re(x)*re(y) + im(x)*im(y)
//   imag = im(x)*re(y) - re(x)*im(y)
84 EIGEN_STRONG_INLINE Scalar pmul(
const Scalar& x,
const Scalar& y)
const
85 {
return Scalar(numext::real(x)*numext::real(y) + numext::imag(x)*numext::imag(y), numext::imag(x)*numext::real(y) - numext::real(x)*numext::imag(y)); }
// Specialization for std::complex operands with only the lhs conjugated
// (ConjLhs == true, ConjRhs == false): computes conj(x) * y.
88 template<
typename RealScalar>
struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, true,false>
90 typedef std::complex<RealScalar> Scalar;
// Returns c + pmul(x, y).
91 EIGEN_STRONG_INLINE Scalar pmadd(
const Scalar& x,
const Scalar& y,
const Scalar& c)
const
92 {
return c + pmul(x,y); }
// conj(x) * y written out on the real and imaginary parts:
//   real = re(x)*re(y) + im(x)*im(y)
//   imag = re(x)*im(y) - im(x)*re(y)
94 EIGEN_STRONG_INLINE Scalar pmul(
const Scalar& x,
const Scalar& y)
const
95 {
return Scalar(numext::real(x)*numext::real(y) + numext::imag(x)*numext::imag(y), numext::real(x)*numext::imag(y) - numext::imag(x)*numext::real(y)); }
// Specialization for std::complex operands with both sides conjugated
// (ConjLhs == true, ConjRhs == true): computes conj(x) * conj(y), which is
// the conjugate of x * y.
98 template<
typename RealScalar>
struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, true,true>
100 typedef std::complex<RealScalar> Scalar;
// Returns c + pmul(x, y).
101 EIGEN_STRONG_INLINE Scalar pmadd(
const Scalar& x,
const Scalar& y,
const Scalar& c)
const
102 {
return c + pmul(x,y); }
// conj(x) * conj(y) written out on the real and imaginary parts:
//   real =   re(x)*re(y) - im(x)*im(y)
//   imag = -(re(x)*im(y) + im(x)*re(y))
104 EIGEN_STRONG_INLINE Scalar pmul(
const Scalar& x,
const Scalar& y)
const
105 {
return Scalar(numext::real(x)*numext::real(y) - numext::imag(x)*numext::imag(y), - numext::real(x)*numext::imag(y) - numext::imag(x)*numext::real(y)); }
// Specialization for a complex lhs and a real rhs. Only the lhs can be
// conjugated (flag Conj); the rhs flag is fixed to false.
108 template<
typename RealScalar,
bool Conj>
struct conj_helper<std::complex<RealScalar>, RealScalar, Conj,false>
110 typedef std::complex<RealScalar> Scalar;
// Returns c + pmul(x, y), with the sum formed through padd.
111 EIGEN_STRONG_INLINE Scalar pmadd(
const Scalar& x,
const RealScalar& y,
const Scalar& c)
const
112 {
return padd(c, pmul(x,y)); }
// Returns x (conjugated when Conj is true) multiplied by the real y.
113 EIGEN_STRONG_INLINE Scalar pmul(
const Scalar& x,
const RealScalar& y)
const
114 {
return conj_if<Conj>()(x)*y; }
// Specialization for a real lhs and a complex rhs. Only the rhs can be
// conjugated (flag Conj); the lhs flag is fixed to false.
117 template<
typename RealScalar,
bool Conj>
struct conj_helper<RealScalar, std::complex<RealScalar>, false,Conj>
119 typedef std::complex<RealScalar> Scalar;
// Returns c + pmul(x, y), with the sum formed through padd.
120 EIGEN_STRONG_INLINE Scalar pmadd(
const RealScalar& x,
const Scalar& y,
const Scalar& c)
const
121 {
return padd(c, pmul(x,y)); }
// Returns the real x multiplied by y (conjugated when Conj is true).
122 EIGEN_STRONG_INLINE Scalar pmul(
const RealScalar& x,
const Scalar& y)
const
123 {
return x*conj_if<Conj>()(y); }
// get_factor<From, To>::run converts a scalar factor from type From to type To.
// Generic case: an explicit conversion To(x).
126 template<
typename From,
typename To>
struct get_factor {
127 EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE To run(
const From& x) {
return To(x); }
// Specialization used when the target type is the real type associated with
// Scalar (NumTraits<Scalar>::Real): returns numext::real(x) instead of
// performing a conversion.
130 template<
typename Scalar>
struct get_factor<Scalar,typename NumTraits<Scalar>::Real> {
132 static EIGEN_STRONG_INLINE
typename NumTraits<Scalar>::Real run(
const Scalar& x) {
return numext::real(x); }
// BlasVectorMapper: thin accessor around a raw Scalar pointer, indexed by a
// single Index. Some member bodies and closing braces are not visible in this
// extract.
136 template<
typename Scalar,
typename Index>
137 class BlasVectorMapper {
// Stores the data pointer; nothing else is initialized here.
139 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasVectorMapper(Scalar *data) : m_data(data) {}
// Element access by value (body not visible in this extract).
141 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(
Index i)
const {
// Loads a Packet starting at m_data + i via ploadt, using the alignment
// assumption given by AlignmentType.
144 template <
typename Packet,
int AlignmentType>
145 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet load(
Index i)
const {
146 return ploadt<Packet, AlignmentType>(m_data + i);
// Returns true when the address of element i is a multiple of sizeof(Packet).
149 template <
typename Packet>
150 EIGEN_DEVICE_FUNC
bool aligned(
Index i)
const {
151 return (UIntPtr(m_data+i)%
sizeof(Packet))==0;
// Forward declaration of BlasLinearMapper. Incr is the compile-time element
// increment and defaults to 1.
158 template<
typename Scalar,
typename Index,
int AlignmentType,
int Incr=1>
159 class BlasLinearMapper;
// Members of a linear mapper that addresses elements contiguously
// (m_data + i). The class head is not visible in this extract.
// NOTE(review): presumably the Incr == 1 specialization of BlasLinearMapper,
// given the three-parameter list and the incr==1 assertion -- confirm.
161 template<
typename Scalar,
typename Index,
int AlignmentType>
164 typedef typename packet_traits<Scalar>::type Packet;
165 typedef typename packet_traits<Scalar>::half HalfPacket;
// incr is accepted for interface compatibility and only checked: it must be 1.
// EIGEN_ONLY_USED_FOR_DEBUG marks it as used when the assertion is compiled out.
167 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasLinearMapper(Scalar *data,
Index incr=1)
170 EIGEN_ONLY_USED_FOR_DEBUG(incr);
171 eigen_assert(incr==1);
// Issues a prefetch for the address of element i.
// NOTE(review): i is declared int here while every other accessor takes Index -- confirm this is intended.
174 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
void prefetch(
int i)
const {
175 internal::prefetch(&
operator()(i));
// Element access by reference (body not visible in this extract).
178 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar& operator()(
Index i)
const {
// Loads a full packet starting at element i.
182 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(
Index i)
const {
183 return ploadt<Packet, AlignmentType>(m_data + i);
// Loads a half-size packet starting at element i.
186 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(
Index i)
const {
187 return ploadt<HalfPacket, AlignmentType>(m_data + i);
// Stores packet p starting at element i.
190 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
void storePacket(
Index i,
const Packet &p)
const {
191 pstoret<Scalar, Packet, AlignmentType>(m_data + i, p);
// Forward declaration of blas_data_mapper. AlignmentType defaults to Unaligned
// and the compile-time element increment Incr defaults to 1.
199 template<
typename Scalar,
typename Index,
int StorageOrder,
int AlignmentType = Unaligned,
int Incr = 1>
200 class blas_data_mapper;
// Members of a 2D data mapper over a raw pointer with a single outer stride
// and unit inner increment. The class head is not visible in this extract.
// NOTE(review): presumably the Incr == 1 specialization of blas_data_mapper,
// given the four-parameter list and the incr==1 assertion -- confirm.
202 template<
typename Scalar,
typename Index,
int StorageOrder,
int AlignmentType>
206 typedef typename packet_traits<Scalar>::type Packet;
207 typedef typename packet_traits<Scalar>::half HalfPacket;
209 typedef BlasLinearMapper<Scalar, Index, AlignmentType> LinearMapper;
210 typedef BlasVectorMapper<Scalar, Index> VectorMapper;
// Stores data and stride. incr is only validated (must be 1);
// EIGEN_ONLY_USED_FOR_DEBUG marks it as used when the assertion is compiled out.
212 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data,
Index stride,
Index incr=1)
213 : m_data(data), m_stride(stride)
215 EIGEN_ONLY_USED_FOR_DEBUG(incr);
216 eigen_assert(incr==1);
// Returns a mapper of the same type rooted at element (i, j) with the same
// stride (the function name and parameter list are not visible in this extract).
219 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType>
221 return blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType>(&
operator()(i, j), m_stride);
// Returns a LinearMapper rooted at element (i, j).
224 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(
Index i,
Index j)
const {
225 return LinearMapper(&
operator()(i, j));
// Returns a VectorMapper rooted at element (i, j).
228 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(
Index i,
Index j)
const {
229 return VectorMapper(&
operator()(i, j));
// Element (i, j): offset j + i*stride when StorageOrder is RowMajor,
// otherwise i + j*stride.
234 EIGEN_ALWAYS_INLINE Scalar& operator()(
Index i,
Index j)
const {
235 return m_data[StorageOrder==
RowMajor ? j + i*m_stride : i + j*m_stride];
// Loads a full packet starting at element (i, j).
238 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(
Index i,
Index j)
const {
239 return ploadt<Packet, AlignmentType>(&
operator()(i, j));
// Loads a half-size packet starting at element (i, j).
242 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(
Index i,
Index j)
const {
243 return ploadt<HalfPacket, AlignmentType>(&
operator()(i, j));
// Scatters the coefficients of p starting at element (i, j), with m_stride
// passed to pscatter as the step between consecutive coefficients.
246 template<
typename SubPacket>
247 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
void scatterPacket(
Index i,
Index j,
const SubPacket &p)
const {
248 pscatter<Scalar, SubPacket>(&
operator()(i, j), p, m_stride);
// Gathers a SubPacket starting at element (i, j), with m_stride passed to
// pgather as the step between consecutive coefficients.
251 template<
typename SubPacket>
252 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SubPacket gatherPacket(
Index i,
Index j)
const {
253 return pgather<Scalar, SubPacket>(&
operator()(i, j), m_stride);
// Accessors for the stored stride and data pointer.
// NOTE(review): the top-level const on the by-value return type of stride() has no effect for callers.
256 EIGEN_DEVICE_FUNC
const Index stride()
const {
return m_stride; }
257 EIGEN_DEVICE_FUNC
const Scalar* data()
const {
return m_data; }
// First tests whether m_data is not a multiple of sizeof(Scalar); the body of
// that branch is not visible in this extract. Otherwise the result is
// delegated to internal::first_default_aligned(m_data, size).
259 EIGEN_DEVICE_FUNC
Index firstAligned(
Index size)
const {
260 if (UIntPtr(m_data)%
sizeof(Scalar)) {
263 return internal::first_default_aligned(m_data, size);
// Base pointer of the mapped data and the stride between consecutive outer slices.
267 Scalar* EIGEN_RESTRICT m_data;
268 const Index m_stride;
// BlasLinearMapper, general increment: element i lives at m_data[i * incr].
// Because elements are not contiguous, packet loads and stores go through
// pgather / pscatter with the increment as step.
274 template<
typename Scalar,
typename Index,
int AlignmentType,
int Incr>
275 class BlasLinearMapper
278 typedef typename packet_traits<Scalar>::type Packet;
279 typedef typename packet_traits<Scalar>::half HalfPacket;
// Stores the data pointer and the runtime increment.
281 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasLinearMapper(Scalar *data,
Index incr) : m_data(data), m_incr(incr) {}
// Issues a prefetch for the address of element i.
// NOTE(review): i is declared int here while operator() takes Index -- confirm this is intended.
283 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
void prefetch(
int i)
const {
284 internal::prefetch(&
operator()(i));
// Reference to element i, located i * increment elements past m_data.
287 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar& operator()(
Index i)
const {
288 return m_data[i*m_incr.value()];
// Gathers a packet starting at element i, stepping by the increment.
291 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(
Index i)
const {
292 return pgather<Scalar,Packet>(m_data + i*m_incr.value(), m_incr.value());
// Scatters packet p starting at element i, stepping by the increment.
295 template<
typename PacketType>
296 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
void storePacket(
Index i,
const PacketType &p)
const {
297 pscatter<Scalar, PacketType>(m_data + i*m_incr.value(), p, m_incr.value());
// Increment held in a variable_if_dynamic<Index, Incr> wrapper, read via value().
302 const internal::variable_if_dynamic<Index,Incr> m_incr;
// blas_data_mapper, general increment: a 2D mapper over a raw pointer with an
// outer stride (m_stride) and an inner element increment (m_incr).
305 template<
typename Scalar,
typename Index,
int StorageOrder,
int AlignmentType,
int Incr>
306 class blas_data_mapper
309 typedef typename packet_traits<Scalar>::type Packet;
310 typedef typename packet_traits<Scalar>::half HalfPacket;
312 typedef BlasLinearMapper<Scalar, Index, AlignmentType,Incr> LinearMapper;
// Stores the data pointer, the stride and the increment.
314 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data,
Index stride,
Index incr) : m_data(data), m_stride(stride), m_incr(incr) {}
// Returns a mapper rooted at element (i, j) with the same stride and increment
// (the function name and parameter list are not visible in this extract).
316 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper
318 return blas_data_mapper(&
operator()(i, j), m_stride, m_incr.value());
// Returns a LinearMapper rooted at element (i, j) with the same increment.
321 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(
Index i,
Index j)
const {
322 return LinearMapper(&
operator()(i, j), m_incr.value());
// Element (i, j): offset j*incr + i*stride when StorageOrder is RowMajor,
// otherwise i*incr + j*stride.
326 EIGEN_ALWAYS_INLINE Scalar& operator()(
Index i,
Index j)
const {
327 return m_data[StorageOrder==
RowMajor ? j*m_incr.value() + i*m_stride : i*m_incr.value() + j*m_stride];
// Gathers a full packet starting at (i, j), stepping by the increment.
330 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(
Index i,
Index j)
const {
331 return pgather<Scalar,Packet>(&
operator()(i, j),m_incr.value());
// Same gather for an arbitrary packet type PacketT. AlignmentT is not used in
// the visible body.
334 template <
typename PacketT,
int AlignmentT>
335 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(
Index i,
Index j)
const {
336 return pgather<Scalar,PacketT>(&
operator()(i, j),m_incr.value());
// Scatters p starting at (i, j), with m_stride (not the increment) as the step.
339 template<
typename SubPacket>
340 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
void scatterPacket(
Index i,
Index j,
const SubPacket &p)
const {
341 pscatter<Scalar, SubPacket>(&
operator()(i, j), p, m_stride);
// Gathers a SubPacket starting at (i, j), with m_stride (not the increment) as the step.
344 template<
typename SubPacket>
345 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SubPacket gatherPacket(
Index i,
Index j)
const {
346 return pgather<Scalar, SubPacket>(&
operator()(i, j), m_stride);
// Base pointer, outer stride, and inner increment (read via value()).
350 Scalar* EIGEN_RESTRICT m_data;
351 const Index m_stride;
352 const internal::variable_if_dynamic<Index,Incr> m_incr;
// const_blas_data_mapper: read-only variant obtained by deriving from
// blas_data_mapper instantiated on `const Scalar`.
356 template<
typename Scalar,
typename Index,
int StorageOrder>
357 class const_blas_data_mapper :
public blas_data_mapper<const Scalar, Index, StorageOrder> {
// Forwards the data pointer and stride to the base-class constructor.
359 EIGEN_ALWAYS_INLINE const_blas_data_mapper(
const Scalar *data,
Index stride) : blas_data_mapper<const Scalar,
Index, StorageOrder>(data, stride) {}
// Returns a const mapper rooted at element (i, j) with the same stride.
361 EIGEN_ALWAYS_INLINE const_blas_data_mapper<Scalar, Index, StorageOrder> getSubMapper(
Index i,
Index j)
const {
362 return const_blas_data_mapper<Scalar, Index, StorageOrder>(&(this->
operator()(i, j)), this->m_stride);
// blas_traits<XprType>, primary template: default traits for an expression that
// is used as-is, with no transposition, no conjugation and a scalar factor of 1.
// Several lines of the enum and of the conditional typedef are not visible in
// this extract.
370 template<
typename XprType>
struct blas_traits
372 typedef typename traits<XprType>::Scalar Scalar;
373 typedef const XprType& ExtractType;
374 typedef XprType _ExtractType;
376 IsComplex = NumTraits<Scalar>::IsComplex,
377 IsTransposed =
false,
378 NeedToConjugate =
false,
// Tail of a compile-time condition: the expression is a vector at compile time
// or has a compile-time inner stride of 1.
// NOTE(review): presumably part of the definition of HasUsableDirectAccess -- confirm.
380 && (
bool(XprType::IsVectorAtCompileTime)
381 || int(inner_stride_at_compile_time<XprType>::ret) == 1)
// DirectLinearAccessType is chosen by HasUsableDirectAccess; only the
// _ExtractType::PlainObject alternative is visible here.
384 typedef typename conditional<bool(HasUsableDirectAccess),
386 typename _ExtractType::PlainObject
387 >::type DirectLinearAccessType;
// The expression itself is the extracted operand.
388 static inline ExtractType extract(
const XprType& x) {
return x; }
// No scaling is attached to a plain expression, so the factor is 1.
389 static inline const Scalar extractScalarFactor(
const XprType&) {
return Scalar(1); }
// blas_traits for a conjugated expression (CwiseUnaryOp with
// scalar_conjugate_op): inherits the traits of the nested expression and
// records the conjugation in NeedToConjugate.
393 template<
typename Scalar,
typename NestedXpr>
394 struct blas_traits<CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> >
395 : blas_traits<NestedXpr>
397 typedef blas_traits<NestedXpr> Base;
398 typedef CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> XprType;
399 typedef typename Base::ExtractType ExtractType;
// If the nested expression already needs conjugation the two cancel (0);
// otherwise conjugation is needed exactly when Scalar is complex.
402 IsComplex = NumTraits<Scalar>::IsComplex,
403 NeedToConjugate = Base::NeedToConjugate ? 0 : IsComplex
// Skips the conjugate wrapper and extracts from the nested expression.
405 static inline ExtractType extract(
const XprType& x) {
return Base::extract(x.nestedExpression()); }
// The scalar factor of the nested expression, conjugated.
406 static inline Scalar extractScalarFactor(
const XprType& x) {
return conj(Base::extractScalarFactor(x.nestedExpression())); }
// blas_traits for (constant * expression): a product whose lhs is a
// scalar_constant_op nullary expression. Inherits the traits of the rhs
// expression and folds the constant into the scalar factor.
410 template<
typename Scalar,
typename NestedXpr,
typename Plain>
411 struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain>, NestedXpr> >
412 : blas_traits<NestedXpr>
414 typedef blas_traits<NestedXpr> Base;
415 typedef CwiseBinaryOp<scalar_product_op<Scalar>,
const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain>, NestedXpr> XprType;
416 typedef typename Base::ExtractType ExtractType;
// The operand is extracted from the rhs (the non-constant side).
417 static inline ExtractType extract(
const XprType& x) {
return Base::extract(x.rhs()); }
// Constant stored in the lhs functor (m_other) times the rhs's own factor.
418 static inline Scalar extractScalarFactor(
const XprType& x)
419 {
return x.lhs().functor().m_other * Base::extractScalarFactor(x.rhs()); }
// blas_traits for (expression * constant): a product whose rhs is a
// scalar_constant_op nullary expression. Inherits the traits of the lhs
// expression and folds the constant into the scalar factor.
421 template<
typename Scalar,
typename NestedXpr,
typename Plain>
422 struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, NestedXpr, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain> > >
423 : blas_traits<NestedXpr>
425 typedef blas_traits<NestedXpr> Base;
426 typedef CwiseBinaryOp<scalar_product_op<Scalar>, NestedXpr,
const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain> > XprType;
427 typedef typename Base::ExtractType ExtractType;
// The operand is extracted from the lhs (the non-constant side).
428 static inline ExtractType extract(
const XprType& x) {
return Base::extract(x.lhs()); }
// The lhs's own factor times the constant stored in the rhs functor (m_other).
429 static inline Scalar extractScalarFactor(
const XprType& x)
430 {
return Base::extractScalarFactor(x.lhs()) * x.rhs().functor().m_other; }
// blas_traits for (constant * constant): both operands are scalar_constant_op
// nullary expressions. It inherits the traits of the first constant expression;
// the body is not visible in this extract.
// NOTE(review): presumably present to disambiguate between the two
// constant-product specializations, both of which match this type -- confirm.
432 template<
typename Scalar,
typename Plain1,
typename Plain2>
433 struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain1>,
434 const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain2> > >
435 : blas_traits<CwiseNullaryOp<scalar_constant_op<Scalar>,Plain1> >
// blas_traits for a negated expression (CwiseUnaryOp with scalar_opposite_op):
// inherits the traits of the nested expression and moves the sign into the
// scalar factor.
439 template<
typename Scalar,
typename NestedXpr>
440 struct blas_traits<CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> >
441 : blas_traits<NestedXpr>
443 typedef blas_traits<NestedXpr> Base;
444 typedef CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> XprType;
445 typedef typename Base::ExtractType ExtractType;
// Skips the negation wrapper and extracts from the nested expression.
446 static inline ExtractType extract(
const XprType& x) {
return Base::extract(x.nestedExpression()); }
// The nested expression's factor with its sign flipped.
447 static inline Scalar extractScalarFactor(
const XprType& x)
448 {
return - Base::extractScalarFactor(x.nestedExpression()); }
// blas_traits for Transpose<NestedXpr>: inherits the traits of the nested
// expression, flips IsTransposed, and re-wraps the extracted operand in a
// Transpose.
452 template<
typename NestedXpr>
453 struct blas_traits<Transpose<NestedXpr> >
454 : blas_traits<NestedXpr>
456 typedef typename NestedXpr::Scalar Scalar;
457 typedef blas_traits<NestedXpr> Base;
458 typedef Transpose<NestedXpr> XprType;
// Both extract typedefs name the transpose of the base's extracted type.
459 typedef Transpose<const typename Base::_ExtractType> ExtractType;
460 typedef Transpose<const typename Base::_ExtractType> _ExtractType;
// Chosen by Base::HasUsableDirectAccess; only the ExtractType::PlainObject
// alternative is visible in this extract.
461 typedef typename conditional<bool(Base::HasUsableDirectAccess),
463 typename ExtractType::PlainObject
464 >::type DirectLinearAccessType;
// Transposing an already-transposed operand cancels out.
466 IsTransposed = Base::IsTransposed ? 0 : 1
// Extracts from the nested expression and wraps the result in ExtractType.
468 static inline ExtractType extract(
const XprType& x) {
return ExtractType(Base::extract(x.nestedExpression())); }
// Transposition leaves the scalar factor unchanged.
469 static inline Scalar extractScalarFactor(
const XprType& x) {
return Base::extractScalarFactor(x.nestedExpression()); }
// Specialization of blas_traits for const-qualified types. Its template header,
// base class and body are not visible in this extract.
// NOTE(review): presumably forwards to blas_traits<T> -- confirm against the full header.
473 struct blas_traits<const T>
// extract_data_selector<T, HasUsableDirectAccess>: compile-time switch used to
// obtain a raw data pointer from an expression. The flag defaults to
// blas_traits<T>::HasUsableDirectAccess.
477 template<typename T, bool HasUsableDirectAccess=blas_traits<T>::HasUsableDirectAccess>
478 struct extract_data_selector {
// Default case: returns data() of the operand extracted through blas_traits<T>.
479 static const typename T::Scalar* run(
const T& m)
481 return blas_traits<T>::extract(m).data();
// Flag == false: no pointer is available, so a null pointer is returned.
// NOTE(review): this overload returns a non-const Scalar* while the primary returns const Scalar* -- confirm the mismatch is intended.
486 struct extract_data_selector<T,false> {
487 static typename T::Scalar* run(
const T&) {
return 0; }
// Returns the result of extract_data_selector<T>::run(m). With the selector's
// default flag this is either the extracted operand's data() pointer or a null
// pointer, depending on blas_traits<T>::HasUsableDirectAccess.
490 template<
typename T>
const typename T::Scalar* extract_data(
const T& m)
492 return extract_data_selector<T>::run(m);
499 #endif // EIGEN_BLASUTIL_H