7#ifndef DUNE_FUFEM_BACKENDS_NUMPYBACKEND_HH
8#define DUNE_FUFEM_BACKENDS_NUMPYBACKEND_HH
19#include <dune/assembler/backends/flatvector.hh>
20#include <dune/assembler/backends/flatmatrix.hh>
21#include <dune/assembler/backends/localpattern.hh>
22#include <dune/assembler/utility.hh>
24#define DUNE_FUFEM_DISABLE_HEADER_DEPRECATION
26#undef DUNE_FUFEM_DISABLE_HEADER_DEPRECATION
28#include <dune/python/pybind11/numpy.h>
52 template<
class SizeInfo>
53 void resize(
const SizeInfo& sizeInfo)
55 *vector_ = pybind11::array_t<T>(sizeInfo.size());
58 template<
class MultiIndex>
61 auto vector_u = vector_->unchecked();
62 return vector_u[
index];
65 template<
class MultiIndex>
68 auto vector_mu = vector_->mutable_unchecked();
69 return vector_mu[
index];
72 template<
typename Value>
75 auto size = vector_->shape(0);
76 auto vector_mu = vector_->mutable_unchecked();
101 template <
class LocalView>
105 (*
this)[localView.index(i)] += localVector[i];
130template<
class T=
double>
160 class NumPyCSRMatrixPatternBuilder
164 using Matrix = NumPyCSRMatrix<T>;
165 using LocalPattern = Dune::Assembler::LocalPatternMatrix;
168 NumPyCSRMatrixPatternBuilder(
Matrix& matrix) :
172 template<
class RowSizeInfo,
class ColSizeInfo>
173 void resize(
const RowSizeInfo& rowSizeInfo,
const ColSizeInfo& colSizeInfo)
175 indices_.resize(rowSizeInfo.size(), colSizeInfo.size());
187 nnz += indices_.rowsize(row);
190 matrix_.entries = pybind11::array_t<T>(nnz);
191 matrix_.colIndices = pybind11::array_t<size_type>(nnz);
192 matrix_.rowPtrs = pybind11::array_t<size_type>(rows+1);
194 auto entries_mu = matrix_.entries.mutable_unchecked();
195 auto colIndices_mu = matrix_.colIndices.mutable_unchecked();
196 auto rowPtrs_mu = matrix_.rowPtrs.mutable_unchecked();
202 rowPtrs_mu(row) =
next;
204 for(
auto&& colIndex : colIndicesOfRow)
206 colIndices_mu(next) = colIndex;
209 }, indices_.columnIndices(row));
211 rowPtrs_mu(rows) =
next;
218 template <
class RowLocalView,
class ColLocalView>
219 void scatter (
const RowLocalView& rowLocalView,
const ColLocalView& colLocalView,
220 const LocalPattern& localPattern)
222 auto rowIndices = Dune::Assembler::Impl::localIndices(rowLocalView.tree());
223 auto colIndices = Dune::Assembler::Impl::localIndices(colLocalView.tree());
224 localPattern.visitEntries(rowIndices, colIndices, [&](
const auto& i,
const auto&j) {
225 add(rowLocalView.index(i), colLocalView.index(j));
230 void forEachEntry (F&& f)
const
235 for (
auto col : columnIndices)
237 }, indices_.columnIndices(row));
243 template <
class RowIndex,
class ColIndex>
244 void add (
const RowIndex&
rowIndex,
const ColIndex& colIndex)
246 indices_.add(
rowIndex[0], colIndex[0]);
249 template<
class RowIndex,
class ColIndex>
250 void insertEntry(
const RowIndex&
rowIndex,
const ColIndex& colIndex)
271template<
class T=
double>
292 template<
class RowIndex,
class ColIndex>
293 auto nnzIndex(
const RowIndex& row,
const ColIndex& col)
const
297 auto ptrRange =
Dune::range(rowPtrs_u(row), rowPtrs_u(row+1));
299 auto k =
std::lower_bound(colIndicesOfRow.begin(), colIndicesOfRow.end(),
col) - colIndicesOfRow.begin();
300 return k+rowPtrs_u(row);
303 template<
class RowIndex,
class ColIndex>
310 template<
class RowIndex,
class ColIndex>
317 template <
class RowLocalView,
class ColLocalView>
318 void scatter (
const RowLocalView& rowLocalView,
const ColLocalView& colLocalView,
321 auto rowIndices = Dune::Assembler::Impl::localIndices(rowLocalView.tree());
322 auto colIndices = Dune::Assembler::Impl::localIndices(colLocalView.tree());
323 localPattern.visitEntries(rowIndices, colIndices, [&](
const auto& i,
const auto&j) {
324 (*this)(rowLocalView.index(i), colLocalView.index(j)) += localMatrix[i][j];
347 template<
class Value>
353 entries_mu(i) = value;
373[[deprecated(
"This class is deprecated and will be removed after 2.11.")]]
375 :
public Impl::NumPyCSRMatrixPatternBuilder<T>
Dune::BCRSMatrix< FieldMatrix< T, n, m >, TA > Matrix
BCRSMatrix< FieldMatrix< T, n, m >, A >::size_type size_type
auto rows(Matrix const &matrix)
auto cols(Matrix const &matrix)
auto transformedRangeView(R &&range, F &&f)
static constexpr IntegralRange< std::decay_t< T > > range(T &&from, U &&to) noexcept
size_type rowIndex() const
std::ptrdiff_t index() const
void add(const GlobalIndex &global)
Definition dunefunctionsboundaryfunctionalassembler.hh:29
Helper class for building matrix pattern.
Definition matrixbuilder.hh:51
Implementation of the VectorBackend concept for numpy vectors.
Definition numpybackend.hh:42
void resize(const SizeInfo &sizeInfo)
Definition numpybackend.hh:53
void setZero()
Definition numpybackend.hh:108
Dune::Assembler::FlatVector< value_type > LocalVector
Definition numpybackend.hh:99
T value_type
Definition numpybackend.hh:98
const Vector & vector() const
Definition numpybackend.hh:86
NumPyVectorBackend(Vector &vector)
Definition numpybackend.hh:48
void operator=(const Value &value)
Definition numpybackend.hh:73
void scatter(const LocalView &localView, const LocalVector &localVector)
Definition numpybackend.hh:102
pybind11::array_t< T > Vector
Definition numpybackend.hh:46
Vector & vector()
Definition numpybackend.hh:91
Struct providing raw storage for a flat CSR matrix using numpy arrays.
Definition numpybackend.hh:132
pybind11::array_t< size_type > rowPtrs
Definition numpybackend.hh:138
T field_type
Definition numpybackend.hh:134
std::array< size_type, 2 > shape
Definition numpybackend.hh:140
auto asTuple()
Definition numpybackend.hh:142
pybind11::array_t< T > entries
Definition numpybackend.hh:136
std::size_t size_type
Definition numpybackend.hh:133
auto asTuple() const
Definition numpybackend.hh:147
pybind11::array_t< size_type > colIndices
Definition numpybackend.hh:137
Implementation of the MatrixBackend concept NumPyCSRMatrix.
Definition numpybackend.hh:273
NumPyCSRMatrixBackend(Matrix &matrix)
Definition numpybackend.hh:283
const value_type & operator()(const RowIndex &row, const ColIndex &col) const
Definition numpybackend.hh:304
Matrix & matrix()
Definition numpybackend.hh:342
T value_type
Definition numpybackend.hh:276
value_type & operator()(const RowIndex &row, const ColIndex &col)
Definition numpybackend.hh:311
Impl::NumPyCSRMatrixPatternBuilder< T > PatternBuilder
Definition numpybackend.hh:278
Dune::Assembler::FlatMatrix< value_type > LocalMatrix
Definition numpybackend.hh:277
void assign(const Value &value)
Definition numpybackend.hh:348
void setZero()
Definition numpybackend.hh:328
typename PatternBuilder::LocalPattern LocalPattern
Definition numpybackend.hh:279
T Entry
Definition numpybackend.hh:335
auto nnzIndex(const RowIndex &row, const ColIndex &col) const
Definition numpybackend.hh:293
Matrix * matrix_
Definition numpybackend.hh:358
PatternBuilder patternBuilder()
Definition numpybackend.hh:287
void scatter(const RowLocalView &rowLocalView, const ColLocalView &colLocalView, const LocalMatrix &localMatrix, const LocalPattern &localPattern)
Definition numpybackend.hh:318
const Matrix & matrix() const
Definition numpybackend.hh:337