Dune-Fufem 2.11-git
Loading...
Searching...
No Matches
numpybackend.hh
Go to the documentation of this file.
1// -*- tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 2 -*-
2// vi: set et ts=4 sw=2 sts=2:
3
4// SPDX-FileCopyrightText: Copyright © DUNE-FUFEM Project contributors, see file AUTHORS.md
5// SPDX-License-Identifier: LicenseRef-GPL-2.0-only-with-DUNE-exception OR LGPL-3.0-or-later
6
7#ifndef DUNE_FUFEM_BACKENDS_NUMPYBACKEND_HH
8#define DUNE_FUFEM_BACKENDS_NUMPYBACKEND_HH
9
10#include <cstddef>
11#include <type_traits>
12#include <utility>
13#include <algorithm>
14#include <array>
15#include <tuple>
16
18
19#include <dune/assembler/backends/flatvector.hh>
20#include <dune/assembler/backends/flatmatrix.hh>
21#include <dune/assembler/backends/localpattern.hh>
22#include <dune/assembler/utility.hh>
23
24#define DUNE_FUFEM_DISABLE_HEADER_DEPRECATION
26#undef DUNE_FUFEM_DISABLE_HEADER_DEPRECATION
27
28#include <dune/python/pybind11/numpy.h>
29
30
31
32namespace Dune::Fufem {
33
40template<class T>
42{
43
44public:
45
46 using Vector = pybind11::array_t<T>;
47
49 vector_(&vector)
50 {}
51
52 template<class SizeInfo>
53 void resize(const SizeInfo& sizeInfo)
54 {
55 *vector_ = pybind11::array_t<T>(sizeInfo.size());
56 }
57
58 template<class MultiIndex>
59 decltype(auto) operator[](const MultiIndex& index) const
60 {
61 auto vector_u = vector_->unchecked();
62 return vector_u[index];
63 }
64
65 template<class MultiIndex>
66 decltype(auto) operator[](const MultiIndex& index)
67 {
68 auto vector_mu = vector_->mutable_unchecked();
69 return vector_mu[index];
70 }
71
72 template<typename Value>
73 void operator= (const Value& value)
74 {
75 auto size = vector_->shape(0);
76 auto vector_mu = vector_->mutable_unchecked();
77 for(auto&& i : Dune::range(size))
78 vector_mu(i) = value;
79 }
80
82 {
83 vector() = other.vector();
84 }
85
86 const Vector& vector() const
87 {
88 return *vector_;
89 }
90
92 {
93 return *vector_;
94 }
95
96 // Extended interface of a Dune::Assembler vector-backend
97
98 using value_type = T;
99 using LocalVector = Dune::Assembler::FlatVector<value_type>;
100
101 template <class LocalView>
102 void scatter (const LocalView& localView, const LocalVector& localVector)
103 {
104 for (std::size_t i = 0; i < localView.size(); ++i)
105 (*this)[localView.index(i)] += localVector[i];
106 }
107
108 void setZero ()
109 {
110 (*this) = 0;
111 }
112
113private:
114 Vector* vector_;
115};
116
117
118
130template<class T=double>
132{
134 using field_type = T;
135
136 pybind11::array_t<T> entries;
137 pybind11::array_t<size_type> colIndices;
138 pybind11::array_t<size_type> rowPtrs;
139
141
142 auto asTuple()
143 {
145 }
146
147 auto asTuple() const
148 {
150 }
151
152};
153
154
155
156namespace Impl {
157
158 // This implements the pattern builder for NumPyCSRMatrix<T>
159 template<class T>
160 class NumPyCSRMatrixPatternBuilder
161 {
162 public:
163
164 using Matrix = NumPyCSRMatrix<T>;
165 using LocalPattern = Dune::Assembler::LocalPatternMatrix;
166 using size_type = typename Matrix::size_type;
167
168 NumPyCSRMatrixPatternBuilder(Matrix& matrix) :
169 matrix_(matrix)
170 {}
171
172 template<class RowSizeInfo, class ColSizeInfo>
173 void resize(const RowSizeInfo& rowSizeInfo, const ColSizeInfo& colSizeInfo)
174 {
175 indices_.resize(rowSizeInfo.size(), colSizeInfo.size());
176 }
177
178 void setupMatrix()
179 {
180 size_type rows = indices_.rows();
181 size_type cols = indices_.cols();
182 matrix_.shape = {rows, cols};
183
184 // Compute number of nonzeros
185 std::size_t nnz = 0;
186 for(auto row : Dune::range(rows))
187 nnz += indices_.rowsize(row);
188
189 // Allocate data
190 matrix_.entries = pybind11::array_t<T>(nnz);
191 matrix_.colIndices = pybind11::array_t<size_type>(nnz);
192 matrix_.rowPtrs = pybind11::array_t<size_type>(rows+1);
193
194 auto entries_mu = matrix_.entries.mutable_unchecked();
195 auto colIndices_mu = matrix_.colIndices.mutable_unchecked();
196 auto rowPtrs_mu = matrix_.rowPtrs.mutable_unchecked();
197
198 // Write column indices
199 size_type next = 0;
200 for(auto row : Dune::range(rows))
201 {
202 rowPtrs_mu(row) = next;
203 std::visit([&](const auto& colIndicesOfRow) {
204 for(auto&& colIndex : colIndicesOfRow)
205 {
206 colIndices_mu(next) = colIndex;
207 ++next;
208 }
209 }, indices_.columnIndices(row));
210 }
211 rowPtrs_mu(rows) = next;
212
213 // Zero initialize entries
214 for(auto i : Dune::range(nnz))
215 entries_mu(i) = 0;
216 }
217
218 template <class RowLocalView, class ColLocalView>
219 void scatter (const RowLocalView& rowLocalView, const ColLocalView& colLocalView,
220 const LocalPattern& localPattern)
221 {
222 auto rowIndices = Dune::Assembler::Impl::localIndices(rowLocalView.tree());
223 auto colIndices = Dune::Assembler::Impl::localIndices(colLocalView.tree());
224 localPattern.visitEntries(rowIndices, colIndices, [&](const auto& i, const auto&j) {
225 add(rowLocalView.index(i), colLocalView.index(j));
226 });
227 }
228
229 template <class F>
230 void forEachEntry (F&& f) const
231 {
232 for (auto row : Dune::range(indices_.rows()))
233 {
234 std::visit([&](const auto& columnIndices) {
235 for (auto col : columnIndices)
236 f(std::size_t(row), std::size_t(col));
237 }, indices_.columnIndices(row));
238 }
239 }
240
241 // Old dune-fufem interface
242
243 template <class RowIndex, class ColIndex>
244 void add (const RowIndex& rowIndex, const ColIndex& colIndex)
245 {
246 indices_.add(rowIndex[0], colIndex[0]);
247 }
248
249 template<class RowIndex, class ColIndex>
250 void insertEntry(const RowIndex& rowIndex, const ColIndex& colIndex)
251 {
252 this->add(rowIndex, colIndex);
253 }
254
255 private:
256 Dune::MatrixIndexSet indices_;
257 Matrix& matrix_;
258 };
259
260}
261
262
263
264
271template<class T=double>
273{
274public:
275
276 using value_type = T;
277 using LocalMatrix = Dune::Assembler::FlatMatrix<value_type>;
278 using PatternBuilder = Impl::NumPyCSRMatrixPatternBuilder<T>;
279 using LocalPattern = typename PatternBuilder::LocalPattern;
280
282
286
288 {
289 return {*matrix_};
290 }
291
292 template<class RowIndex, class ColIndex>
293 auto nnzIndex(const RowIndex& row, const ColIndex& col) const
294 {
295 auto colIndices_u = matrix_->colIndices.unchecked();
296 auto rowPtrs_u = matrix_->rowPtrs.unchecked();
297 auto ptrRange = Dune::range(rowPtrs_u(row), rowPtrs_u(row+1));
298 auto colIndicesOfRow = Dune::transformedRangeView(ptrRange, [&](auto i) -> decltype(auto) { return colIndices_u(i); });
299 auto k = std::lower_bound(colIndicesOfRow.begin(), colIndicesOfRow.end(), col) - colIndicesOfRow.begin();
300 return k+rowPtrs_u(row);
301 }
302
303 template<class RowIndex, class ColIndex>
304 const value_type& operator()(const RowIndex& row, const ColIndex& col) const
305 {
306 auto entries_u = matrix_->entries.unchecked();
307 return entries_u(nnzIndex(row, col));
308 }
309
310 template<class RowIndex, class ColIndex>
311 value_type& operator()(const RowIndex& row, const ColIndex& col)
312 {
313 auto entries_mu = matrix_->entries.mutable_unchecked();
314 return entries_mu(nnzIndex(row, col));
315 }
316
317 template <class RowLocalView, class ColLocalView>
318 void scatter (const RowLocalView& rowLocalView, const ColLocalView& colLocalView,
319 const LocalMatrix& localMatrix, const LocalPattern& localPattern)
320 {
321 auto rowIndices = Dune::Assembler::Impl::localIndices(rowLocalView.tree());
322 auto colIndices = Dune::Assembler::Impl::localIndices(colLocalView.tree());
323 localPattern.visitEntries(rowIndices, colIndices, [&](const auto& i, const auto&j) {
324 (*this)(rowLocalView.index(i), colLocalView.index(j)) += localMatrix[i][j];
325 });
326 }
327
328 void setZero ()
329 {
330 assign(0);
331 }
332
333 // Old dune-fufem interface
334
335 using Entry = T;
336
337 const Matrix& matrix() const
338 {
339 return *matrix_;
340 }
341
343 {
344 return *matrix_;
345 }
346
347 template<class Value>
348 void assign(const Value& value)
349 {
350 auto nnz = matrix_->entries.shape(0);
351 auto entries_mu = matrix_->entries.mutable_unchecked();
352 for(auto&& i : Dune::range(nnz))
353 entries_mu(i) = value;
354 }
355
356protected:
357
359};
360
361
362
371template<class T>
372class
373[[deprecated("This class is deprecated and will be removed after 2.11.")]]
375 : public Impl::NumPyCSRMatrixPatternBuilder<T>
376{};
377
378} // namespace Dune::Fufem
379
380
381
382#endif // DUNE_FUFEM_BACKENDS_NUMPYBACKEND_HH
Col col
Dune::BCRSMatrix< FieldMatrix< T, n, m >, TA > Matrix
BCRSMatrix< FieldMatrix< T, n, m >, A >::size_type size_type
auto rows(Matrix const &matrix)
auto cols(Matrix const &matrix)
int size() const
auto transformedRangeView(R &&range, F &&f)
static constexpr IntegralRange< std::decay_t< T > > range(T &&from, U &&to) noexcept
size_type rowIndex() const
std::ptrdiff_t index() const
size_t() const
void add(const GlobalIndex &global)
STL namespace.
Definition dunefunctionsboundaryfunctionalassembler.hh:29
Helper class for building matrix pattern.
Definition matrixbuilder.hh:51
Implementation of the VectorBackend concept for numpy vectors.
Definition numpybackend.hh:42
void resize(const SizeInfo &sizeInfo)
Definition numpybackend.hh:53
void setZero()
Definition numpybackend.hh:108
Dune::Assembler::FlatVector< value_type > LocalVector
Definition numpybackend.hh:99
T value_type
Definition numpybackend.hh:98
const Vector & vector() const
Definition numpybackend.hh:86
NumPyVectorBackend(Vector &vector)
Definition numpybackend.hh:48
void operator=(const Value &value)
Definition numpybackend.hh:73
void scatter(const LocalView &localView, const LocalVector &localVector)
Definition numpybackend.hh:102
pybind11::array_t< T > Vector
Definition numpybackend.hh:46
Vector & vector()
Definition numpybackend.hh:91
Struct providing raw storage for a flat CSR matrix using numpy arrays.
Definition numpybackend.hh:132
pybind11::array_t< size_type > rowPtrs
Definition numpybackend.hh:138
T field_type
Definition numpybackend.hh:134
std::array< size_type, 2 > shape
Definition numpybackend.hh:140
auto asTuple()
Definition numpybackend.hh:142
pybind11::array_t< T > entries
Definition numpybackend.hh:136
std::size_t size_type
Definition numpybackend.hh:133
auto asTuple() const
Definition numpybackend.hh:147
pybind11::array_t< size_type > colIndices
Definition numpybackend.hh:137
Implementation of the MatrixBackend concept NumPyCSRMatrix.
Definition numpybackend.hh:273
NumPyCSRMatrixBackend(Matrix &matrix)
Definition numpybackend.hh:283
const value_type & operator()(const RowIndex &row, const ColIndex &col) const
Definition numpybackend.hh:304
Matrix & matrix()
Definition numpybackend.hh:342
T value_type
Definition numpybackend.hh:276
value_type & operator()(const RowIndex &row, const ColIndex &col)
Definition numpybackend.hh:311
Impl::NumPyCSRMatrixPatternBuilder< T > PatternBuilder
Definition numpybackend.hh:278
Dune::Assembler::FlatMatrix< value_type > LocalMatrix
Definition numpybackend.hh:277
void assign(const Value &value)
Definition numpybackend.hh:348
void setZero()
Definition numpybackend.hh:328
typename PatternBuilder::LocalPattern LocalPattern
Definition numpybackend.hh:279
T Entry
Definition numpybackend.hh:335
auto nnzIndex(const RowIndex &row, const ColIndex &col) const
Definition numpybackend.hh:293
Matrix * matrix_
Definition numpybackend.hh:358
PatternBuilder patternBuilder()
Definition numpybackend.hh:287
void scatter(const RowLocalView &rowLocalView, const ColLocalView &colLocalView, const LocalMatrix &localMatrix, const LocalPattern &localPattern)
Definition numpybackend.hh:318
const Matrix & matrix() const
Definition numpybackend.hh:337
T lower_bound(T... args)
T next(T... args)
T visit(T... args)