1#ifndef DUNE_FEM_THREADPARTITIONER_HH 
    2#define DUNE_FEM_THREADPARTITIONER_HH 
   11#include <dune/fem/gridpart/common/capabilities.hh> 
   14#include <dune/alugrid/3d/alugrid.hh> 
   20template < class GridPartImp >
 
   21class ThreadPartitioner
 
   25  typedef GridPartImp GridPartType;
 
   26  typedef typename GridPartType :: GridType GridType;
 
   27  typedef typename GridType :: Traits :: LocalIdSet LocalIdSetType;
 
   28  typedef typename LocalIdSetType :: IdType IdType;
 
   30  typedef typename GridPartType :: IndexSetType  IndexSetType;
 
   31  typedef typename IndexSetType :: IndexType IndexType;
 
   34  typedef ALU3DSPACE LoadBalancer LoadBalancerType;
 
   35  typedef typename LoadBalancerType :: DataBase DataBaseType;
 
   38  typedef ALU3DSPACE MpAccessLocal MPAccessInterfaceType;
 
   41  typedef ALU3DSPACE MpAccessSerial  MPAccessImplType;
 
   43  mutable MPAccessImplType  mpAccess_;
 
   47  const GridPartType& gridPart_;
 
   48  const IndexSetType& indexSet_;
 
   50  typedef typename GridPartType :: template Codim<0> :: EntityType         EntityType;
 
   53  const double ldbOver_ ;
 
   54  const double ldbUnder_;
 
   55  const double cutOffFactor_;
 
   60  std::vector< int > index_;
 
   63  std::vector< int > partition_;
 
   66  enum Method { recursive = 0, 
 
   75  ThreadPartitioner( 
const GridPartType& gridPart,
 
   77                     const double cutOffFactor = 1.0 )
 
   81    , indexSet_( gridPart_.indexSet() )
 
   84    , cutOffFactor_( std::max( double(1.0 - cutOffFactor), 0.0 ) )
 
   86    , graphSize_( pSize_ )
 
   87    , index_( indexSet_.size( 0 ), -1 )
 
   90    calculateGraph( gridPart_ );
 
   97  int getIndex( 
const size_t idx )
 
   99    assert( idx < index_.size() );
 
  100    if( index_[ idx ] < 0 ) index_[ idx ] = indexCounter_ ++ ;
 
  101    return index_[ idx ] ;
 
  105  int getIndex( 
const size_t idx )
 const 
  107    assert( idx < index_.size() );
 
  108    return index_[ idx ] ;
 
  112  int getIndex( 
const EntityType& entity )
 
  114    return getIndex( indexSet_.index( entity ) );
 
  118  int getIndex( 
const EntityType& entity )
 const 
  120    return getIndex( indexSet_.index( entity ) );
 
  123  void calculateGraph( 
const GridPartType& gridPart )
 
  126    typedef typename GridPartType :: template 
Codim< 0 > :: IteratorType Iterator;
 
  127    const Iterator end = gridPart.template end<0> ();
 
  128    const int cutOff = cutOffFactor_ * (indexSet_.size( 0 ) / pSize_) ;
 
  130    for(Iterator it = gridPart.template begin<0> (); it != end; ++it )
 
  132      const EntityType& entity = *it;
 
  134      ldbUpdateVertex ( entity, cutOff,
 
  135                        gridPart.ibegin( entity ),
 
  136                        gridPart.iend( entity ),
 
  141  template <
class IntersectionIteratorType>
 
  142  void ldbUpdateVertex ( 
const EntityType & entity,
 
  144                         const IntersectionIteratorType& ibegin,
 
  145                         const IntersectionIteratorType& iend,
 
  148    const int index = getIndex( entity );
 
  149    int weight = (index >= cutOff) ? 1 : 8; 
 
  152      if( Fem::GridPartCapabilities::hasGrid< GridPartType >::v )
 
  155        const int mxl = gridPart_.grid().maxLevel();
 
  156        if( mxl > entity.level() && ! entity.isLeaf() )
 
  158          typedef typename EntityType :: HierarchicIterator HierIt;
 
  159          const HierIt endit = entity.hend( mxl );
 
  160          for(HierIt it = entity.hbegin( mxl ); it != endit; ++it)
 
  165      db.vertexUpdate( 
typename LoadBalancerType::GraphVertex( index, weight ) );
 
  170    updateFaces( entity, ibegin, iend, weight, db );
 
  173  template <
class IntersectionIteratorType>
 
  174  void updateFaces(
const EntityType& en,
 
  175                   IntersectionIteratorType nit,
 
  176                   const IntersectionIteratorType endit,
 
  180    for( ; nit != endit; ++nit )
 
  182      typedef typename IntersectionIteratorType :: Intersection IntersectionType;
 
  183      const IntersectionType& intersection = *nit;
 
  184      if(intersection.neighbor())
 
  186        EntityType nb = intersection.outside();
 
  189          const int eid = getIndex( en );
 
  190          const int nid = getIndex( nb );
 
  197            typedef typename LoadBalancerType :: GraphEdge GraphEdge;
 
  198            db.edgeUpdate ( GraphEdge ( eid, nid, weight, -1, -1 ) );
 
  212  bool serialPartition( 
const Method method = recursive )
 
  218      if( graphSize_ <= pSize_ )
 
  220        partition_.resize( graphSize_ );
 
  221        for( 
int i=0; i<graphSize_; ++ i )
 
  227        if( method == recursive )
 
  228          partition_ = db_.repartition( mpAccess_, DataBaseType :: METIS_PartGraphRecursive, pSize_ );
 
  229        else if( method == kway )
 
  230          partition_ = db_.repartition( mpAccess_, DataBaseType :: METIS_PartGraphKway, pSize_ );
 
  231        else if( method == sfc )
 
  233          partition_ = db_.repartition( mpAccess_, DataBaseType :: ALUGRID_SpaceFillingCurveSerial, pSize_ );
 
  236          DUNE_THROW(InvalidStateException,
"ThreadPartitioner::serialPartition: wrong method");
 
  237        assert( 
int(partition_.size()) >= graphSize_ );
 
  249      return partition_.size() > 0;
 
  253      partition_.resize( indexSet_.size( 0 ) );
 
  254      for( 
size_t i =0; i<partition_.size(); ++i )
 
  260  std::set < int, std::less < int > > scan()
 const 
  265  int getRank( 
const EntityType& entity )
 const 
  268    assert( (
int) partition_.size() > getIndex( entity ) );
 
  269    return partition_[ getIndex( entity ) ];
 
  272  bool validEntity( 
const EntityType& entity, 
const int rank )
 const 
  274    return getRank( entity ) == rank;
 
  282#warning "DUNE-ALUGrid Partitioner not available" 
#define DUNE_THROW(E,...)
Definition: exceptions.hh:314
 
constexpr auto max
Function object that returns the greater of the given values.
Definition: hybridutilities.hh:485
 
Helpers for dealing with MPI.
 
Dune namespace.
Definition: alignedallocator.hh:13
 
constexpr std::integral_constant< std::size_t, sizeof...(II)> size(std::integer_sequence< T, II... >)
Return the size of the sequence.
Definition: integersequence.hh:75