[mpich2-commits] r6743 - in mpich2/trunk/src: include mpi/coll

thakur at mcs.anl.gov thakur at mcs.anl.gov
Tue May 25 14:41:49 CDT 2010


Author: thakur
Date: 2010-05-25 14:41:49 -0500 (Tue, 25 May 2010)
New Revision: 6743

Modified:
   mpich2/trunk/src/include/mpiimpl.h
   mpich2/trunk/src/mpi/coll/red_scat_block.c
Log:
updated Reduce_scatter_block to use optimized Reduce_scatter algorithms

Modified: mpich2/trunk/src/include/mpiimpl.h
===================================================================
--- mpich2/trunk/src/include/mpiimpl.h	2010-05-25 19:39:49 UTC (rev 6742)
+++ mpich2/trunk/src/include/mpiimpl.h	2010-05-25 19:41:49 UTC (rev 6743)
@@ -1730,6 +1730,8 @@
                       MPID_Comm *);
     int (*Reduce_scatter) (void*, void*, int *, MPI_Datatype, MPI_Op, 
                            MPID_Comm *);
+    int (*Reduce_scatter_block) (void*, void*, int, MPI_Datatype, MPI_Op, 
+                           MPID_Comm *);
     int (*Scan) (void*, void*, int, MPI_Datatype, MPI_Op, MPID_Comm * );
     int (*Exscan) (void*, void*, int, MPI_Datatype, MPI_Op, MPID_Comm * );
     
@@ -3198,6 +3200,7 @@
 #define MPIR_ALLTOALLW_TAG            25
 #define MPIR_TOPO_A_TAG               26
 #define MPIR_TOPO_B_TAG               27
+/* tag for the point-to-point messages exchanged inside MPIR_Reduce_scatter_block */
+#define MPIR_REDUCE_SCATTER_BLOCK_TAG 28
 
 /* These functions are used in the implementation of collective
    operations. They are wrappers around MPID send/recv functions. They do

Modified: mpich2/trunk/src/mpi/coll/red_scat_block.c
===================================================================
--- mpich2/trunk/src/mpi/coll/red_scat_block.c	2010-05-25 19:39:49 UTC (rev 6742)
+++ mpich2/trunk/src/mpi/coll/red_scat_block.c	2010-05-25 19:41:49 UTC (rev 6743)
@@ -1,10 +1,16 @@
-/* -*- Mode: C; c-basic-offset:4 ; indent-tabs-mode:nil ; -*- */
+/* -*- Mode: C; c-basic-offset:4 ; -*- */
 /*
  *
- *  (C) 2009 by Argonne National Laboratory.
+ *  (C) 2010 by Argonne National Laboratory.
  *      See COPYRIGHT in top-level directory.
  */
 
+
+/* This implementation of MPI_Reduce_scatter_block was obtained by taking
+   the implementation of MPI_Reduce_scatter from red_scat.c and replacing 
+   recvcnts[i] with recvcount everywhere. */
+
+
 #include "mpiimpl.h"
 
 /* -- Begin Profiling Symbol Block for routine MPI_Reduce_scatter_block */
@@ -22,26 +28,1087 @@
 #ifndef MPICH_MPI_FROM_PMPI
 #undef MPI_Reduce_scatter_block
 #define MPI_Reduce_scatter_block PMPI_Reduce_scatter_block
-/* any utility functions should go here, usually prefixed with PMPI_LOCAL to
- * correctly handle weak symbols and the profiling interface */
+
+
+/* Implements the "mirror permutation" of "bits" bits of an integer "x": the
+   low-order "bits" bits are reversed and all higher-order bits are copied
+   unchanged.
+
+   positions 76543210, bits==3 yields 76543012.
+
+   All bit manipulation is done in unsigned int so that the shifts stay well
+   defined (a signed "1 << 31" is undefined behavior); "bits" must be smaller
+   than the width of unsigned int.
+
+   This function could/should be moved to a common utility location for use in
+   other collectives as well. */
+ATTRIBUTE((const)) /* tells the compiler that this func only depends on its args
+                      and may be optimized much more aggressively, similar to "pure" */
+static inline int mirror_permutation(unsigned int x, int bits)
+{
+    /* mask for the low-order bits that get reversed; the bits above it are
+       copied as-is */
+    unsigned int low_mask = (0x1u << bits) - 1u;
+    unsigned int retval = x & ~low_mask;
+    int i;
+
+    for (i = 0; i < bits; ++i) {
+        unsigned int bitval = (x >> i) & 0x1u; /* 0x1 or 0x0 */
+        retval |= bitval << ((bits - i) - 1);
+    }
+
+    return (int) retval;
+}
+
+/* FIXME should we be checking the op_errno here? */
+/* call_uop(in_, inout_, count_, datatype_): invoke the reduction function
+   "uop" on the buffers in_ and inout_.  Following the MPI_User_function
+   (invec, inoutvec, len, datatype) convention, inout_ is the operand that
+   receives the result.
+   count_ and datatype_ must be lvalues: the C path passes their addresses. */
+#ifdef HAVE_CXX_BINDING
+/* NOTE: assumes 'uop' is the operator function pointer and
+   that 'is_cxx_uop' is a boolean indicating the obvious (the op was
+   created through the C++ bindings and must be called via cxx_call_op_fn) */
+#define call_uop(in_, inout_, count_, datatype_)                                     \
+do {                                                                                 \
+    if (is_cxx_uop) {                                                                \
+        (*MPIR_Process.cxx_call_op_fn)((in_), (inout_), (count_), (datatype_), uop); \
+    }                                                                                \
+    else {                                                                           \
+        (*uop)((in_), (inout_), &(count_), &(datatype_));                            \
+    }                                                                                \
+} while (0)
+
+#else
+/* without C++ bindings only the C/Fortran function pointer "uop" is needed */
+#define call_uop(in_, inout_, count_, datatype_)      \
+    (*uop)((in_), (inout_), &(count_), &(datatype_))
 #endif
 
+/* Implements the reduce-scatter butterfly algorithm described in J. L. Traff's
+ * "An Improved Algorithm for (Non-commutative) Reduce-Scatter with an Application"
+ * from EuroPVM/MPI 2005.  This function currently only implements support for
+ * the power-of-2 case (enforced with an assertion below).
+ *
+ * Each process contributes comm_size blocks of recvcount elements (taken from
+ * recvbuf when sendbuf is MPI_IN_PLACE).  The blocks are first reordered with
+ * mirror_permutation() so that, after log2(comm_size) rounds of recursive
+ * halving, the block left at recv_offset is the reduction of block "rank";
+ * that block is copied into recvbuf.  In every round the operands are
+ * combined in rank order, so non-commutative ops are handled correctly.
+ *
+ * tmp_buf0/tmp_buf1 form a double buffer; buf0_was_inout records which one
+ * currently holds the accumulated result.
+ *
+ * Returns MPI_SUCCESS or the MPI error code of the first failing step. */
 #undef FUNCNAME
+#define FUNCNAME MPIR_Reduce_scatter_block_noncomm
+#undef FCNAME
+#define FCNAME MPIU_QUOTE(FUNCNAME)
+static int MPIR_Reduce_scatter_block_noncomm (
+    void *sendbuf,
+    void *recvbuf,
+    int recvcount,
+    MPI_Datatype datatype,
+    MPI_Op op,
+    MPID_Comm *comm_ptr )
+{
+    int mpi_errno = MPI_SUCCESS;
+    int comm_size = comm_ptr->local_size;
+    int rank = comm_ptr->rank;
+    int pof2;
+    int log2_comm_size;
+    int i, k;
+    int recv_offset, send_offset;
+    int block_size, total_count, size;
+    MPI_Aint true_extent, true_lb;
+    int buf0_was_inout;
+    void *tmp_buf0;
+    void *tmp_buf1;
+    void *result_ptr;
+    MPI_Comm comm = comm_ptr->handle;
+    MPI_User_function *uop;
+    MPID_Op *op_ptr;
+#ifdef HAVE_CXX_BINDING
+    int is_cxx_uop = 0;
+#endif
+    MPIU_CHKLMEM_DECL(3);
+
+    /* assumes nesting is handled by the caller right now, may not be true in the future */
+    mpi_errno = NMPI_Type_get_true_extent(datatype, &true_lb, &true_extent);
+    if (mpi_errno) MPIU_ERR_POP(mpi_errno);
+
+    /* Look up the function implementing "op".  Commutativity does not matter
+       here because operands are always combined in rank order. */
+    if (HANDLE_GET_KIND(op) == HANDLE_KIND_BUILTIN) {
+        /* get the function by indexing into the op table */
+        uop = MPIR_Op_table[op%16 - 1];
+    }
+    else {
+        MPID_Op_get_ptr(op, op_ptr);
+
+#ifdef HAVE_CXX_BINDING
+        if (op_ptr->language == MPID_LANG_CXX) {
+            uop = (MPI_User_function *) op_ptr->function.c_function;
+            is_cxx_uop = 1;
+        }
+        else
+#endif
+        if (op_ptr->language == MPID_LANG_C)
+            uop = (MPI_User_function *) op_ptr->function.c_function;
+        else
+            uop = (MPI_User_function *) op_ptr->function.f77_function;
+    }
+
+    /* smallest power of two >= comm_size, and its base-2 logarithm */
+    pof2 = 1;
+    log2_comm_size = 0;
+    while (pof2 < comm_size) {
+        pof2 <<= 1;
+        ++log2_comm_size;
+    }
+
+    /* begin error checking */
+    MPIU_Assert(pof2 == comm_size); /* FIXME this version only works for power of 2 procs */
+    /* end error checking */
+
+    /* size of a block (count of datatype per block, NOT bytes per block) */
+    block_size = recvcount;
+    total_count = block_size * comm_size;
+
+    MPIU_CHKLMEM_MALLOC(tmp_buf0, void *, true_extent * total_count, mpi_errno, "tmp_buf0");
+    MPIU_CHKLMEM_MALLOC(tmp_buf1, void *, true_extent * total_count, mpi_errno, "tmp_buf1");
+    /* adjust for potential negative lower bound in datatype */
+    tmp_buf0 = (void *)((char*)tmp_buf0 - true_lb);
+    tmp_buf1 = (void *)((char*)tmp_buf1 - true_lb);
+
+    /* Copy our send data to tmp_buf0.  We do this one block at a time and
+       permute the blocks as we go according to the mirror permutation. */
+    for (i = 0; i < comm_size; ++i) {
+        mpi_errno = MPIR_Localcopy((char *)(sendbuf == MPI_IN_PLACE ? recvbuf : sendbuf) + (i * true_extent * block_size), block_size, datatype,
+                                   (char *)tmp_buf0 + (mirror_permutation(i, log2_comm_size) * true_extent * block_size), block_size, datatype);
+        if (mpi_errno) MPIU_ERR_POP(mpi_errno);
+    }
+    buf0_was_inout = 1;
+
+    send_offset = 0;
+    recv_offset = 0;
+    size = total_count;
+    for (k = 0; k < log2_comm_size; ++k) {
+        /* use a double-buffering scheme to avoid local copies */
+        char *incoming_data = (buf0_was_inout ? tmp_buf1 : tmp_buf0);
+        char *outgoing_data = (buf0_was_inout ? tmp_buf0 : tmp_buf1);
+        int peer = rank ^ (0x1 << k);
+        size /= 2;
+
+        if (rank > peer) {
+            /* we have the higher rank: send top half, recv bottom half */
+            recv_offset += size;
+        }
+        else {
+            /* we have the lower rank: recv top half, send bottom half */
+            send_offset += size;
+        }
+
+        mpi_errno = MPIC_Sendrecv(outgoing_data + send_offset*true_extent,
+                                  size, datatype, peer, MPIR_REDUCE_SCATTER_BLOCK_TAG,
+                                  incoming_data + recv_offset*true_extent,
+                                  size, datatype, peer, MPIR_REDUCE_SCATTER_BLOCK_TAG,
+                                  comm, MPI_STATUS_IGNORE);
+        /* a failed exchange must not be reduced or silently overwritten by a
+           later round */
+        if (mpi_errno) MPIU_ERR_POP(mpi_errno);
+
+        /* always perform the reduction at recv_offset, the data at send_offset
+           is now our peer's responsibility */
+        if (rank > peer) {
+            /* higher ranked value so need to call op(received_data, my_data);
+               the result lands in outgoing_data, so the buffer roles stay as
+               they are */
+            call_uop(incoming_data + recv_offset*true_extent,
+                     outgoing_data + recv_offset*true_extent,
+                     size, datatype);
+        }
+        else {
+            /* lower ranked value so need to call op(my_data, received_data);
+               the result lands in incoming_data, so the buffers swap roles */
+            call_uop(outgoing_data + recv_offset*true_extent,
+                     incoming_data + recv_offset*true_extent,
+                     size, datatype);
+            buf0_was_inout = !buf0_was_inout;
+        }
+
+        /* the next round of send/recv needs to happen within the block (of size
+           "size") that we just received and reduced */
+        send_offset = recv_offset;
+    }
+
+    MPIU_Assert(size == recvcount);
+
+    /* copy the reduced data to the recvbuf */
+    result_ptr = (char *)(buf0_was_inout ? tmp_buf0 : tmp_buf1) + recv_offset * true_extent;
+    mpi_errno = MPIR_Localcopy(result_ptr, size, datatype,
+                               recvbuf, size, datatype);
+    if (mpi_errno) MPIU_ERR_POP(mpi_errno);
+fn_exit:
+    MPIU_CHKLMEM_FREEALL();
+    return mpi_errno;
+fn_fail:
+    goto fn_exit;
+}
+
+/* This is the default implementation of reduce_scatter_block. The algorithm is:
+
+   Algorithm: MPI_Reduce_scatter
+
+   If the operation is commutative, for short and medium-size
+   messages, we use a recursive-halving
+   algorithm in which the first p/2 processes send the second n/2 data
+   to their counterparts in the other half and receive the first n/2
+   data from them. This procedure continues recursively, halving the
+   data communicated at each step, for a total of lgp steps. If the
+   number of processes is not a power-of-two, we convert it to the
+   nearest lower power-of-two by having the first few even-numbered
+   processes send their data to the neighboring odd-numbered process
+   at (rank+1). Those odd-numbered processes compute the result for
+   their left neighbor as well in the recursive halving algorithm, and
+   then at the end send the result back to the processes that didn't
+   participate.
+   Therefore, if p is a power-of-two,
+   Cost = lgp.alpha + n.((p-1)/p).beta + n.((p-1)/p).gamma
+   If p is not a power-of-two,
+   Cost = (floor(lgp)+2).alpha + n.(1+(p-1+n)/p).beta + n.(1+(p-1)/p).gamma
+   The above cost in the non power-of-two case is approximate because
+   there is some imbalance in the amount of work each process does
+   because some processes do the work of their neighbors as well.
+
+   For commutative operations and very long messages
+   we use a pairwise exchange algorithm similar to
+   the one used in MPI_Alltoall. At step i, each process sends n/p
+   amount of data to (rank+i) and receives n/p amount of data from
+   (rank-i).
+   Cost = (p-1).alpha + n.((p-1)/p).beta + n.((p-1)/p).gamma
+
+
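+   If the operation is not commutative and the number of processes is a
+   power of two, we use the butterfly algorithm from J. L. Traff's
+   "An Improved Algorithm for (Non-commutative) Reduce-Scatter with an
+   Application" (MPIR_Reduce_scatter_block_noncomm).
+
+   If the operation is not commutative and the number of processes is not
+   a power of two, we use a recursive doubling algorithm, which
+   takes lgp steps. At step 1, processes exchange (n-n/p) amount of
+   data; at step 2, (n-2n/p) amount of data; at step 3, (n-4n/p)
+   amount of data, and so forth.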
+
+   Cost = lgp.alpha + n.(lgp-(p-1)/p).beta + n.(lgp-(p-1)/p).gamma
+
+   Possible improvements:
+
+   End Algorithm: MPI_Reduce_scatter
+*/
+
+#undef FUNCNAME
+#define FUNCNAME MPIR_Reduce_scatter_block
+#undef FCNAME
+#define FCNAME MPIU_QUOTE(FUNCNAME)
+/* begin:nested */
+/* not declared static because a machine-specific function may call this one in some cases */
+int MPIR_Reduce_scatter_block ( 
+    void *sendbuf, 
+    void *recvbuf, 
+    int recvcount, 
+    MPI_Datatype datatype, 
+    MPI_Op op, 
+    MPID_Comm *comm_ptr )
+{
+    int   rank, comm_size, i;
+    MPI_Aint extent, true_extent, true_lb; 
+    int  *disps;
+    void *tmp_recvbuf, *tmp_results;
+    int   mpi_errno = MPI_SUCCESS;
+    int type_size, dis[2], blklens[2], total_count, nbytes, src, dst;
+    int mask, dst_tree_root, my_tree_root, j, k;
+    int *newcnts, *newdisps, rem, newdst, send_idx, recv_idx,
+        last_idx, send_cnt, recv_cnt;
+    int pof2, old_i, newrank, received;
+    MPI_Datatype sendtype, recvtype;
+    int nprocs_completed, tmp_mask, tree_root, is_commutative;
+    MPI_User_function *uop;
+    MPID_Op *op_ptr;
+    MPI_Comm comm;
+    MPIU_THREADPRIV_DECL;
+#ifdef HAVE_CXX_BINDING
+    int is_cxx_uop = 0;
+#endif
+    MPIU_CHKLMEM_DECL(5);
+
+    comm = comm_ptr->handle;
+    comm_size = comm_ptr->local_size;
+    rank = comm_ptr->rank;
+
+    /* set op_errno to 0. stored in perthread structure */
+    MPIU_THREADPRIV_GET;
+    MPIU_THREADPRIV_FIELD(op_errno) = 0;
+
+    MPIR_Nest_incr();
+
+    if (recvcount == 0) {
+        goto fn_exit;
+    }
+
+    MPID_Datatype_get_extent_macro(datatype, extent);
+    mpi_errno = NMPI_Type_get_true_extent(datatype, &true_lb,
+                                          &true_extent);  
+    if (mpi_errno) MPIU_ERR_POP(mpi_errno);
+    
+    if (HANDLE_GET_KIND(op) == HANDLE_KIND_BUILTIN) {
+        is_commutative = 1;
+        /* get the function by indexing into the op table */
+        uop = MPIR_Op_table[op%16 - 1];
+    }
+    else {
+        MPID_Op_get_ptr(op, op_ptr);
+        if (op_ptr->kind == MPID_OP_USER_NONCOMMUTE)
+            is_commutative = 0;
+        else
+            is_commutative = 1;
+
+#ifdef HAVE_CXX_BINDING            
+	if (op_ptr->language == MPID_LANG_CXX) {
+	    uop = (MPI_User_function *) op_ptr->function.c_function;
+	    is_cxx_uop = 1;
+	}
+	else
+#endif
+        if ((op_ptr->language == MPID_LANG_C))
+            uop = (MPI_User_function *) op_ptr->function.c_function;
+        else
+            uop = (MPI_User_function *) op_ptr->function.f77_function;
+    }
+
+    MPIU_CHKLMEM_MALLOC(disps, int *, comm_size * sizeof(int), mpi_errno, "disps");
+
+    total_count = comm_size*recvcount;
+    for (i=0; i<comm_size; i++) {
+        disps[i] = i*recvcount;
+    }
+    
+    MPID_Datatype_get_size_macro(datatype, type_size);
+    nbytes = total_count * type_size;
+    
+    /* check if multiple threads are calling this collective function */
+    MPIDU_ERR_CHECK_MULTIPLE_THREADS_ENTER( comm_ptr );
+
+    /* total_count*extent eventually gets malloced. it isn't added to
+     * a user-passed in buffer */
+    MPID_Ensure_Aint_fits_in_pointer(total_count * MPIR_MAX(true_extent, extent));
+
+    if ((is_commutative) && (nbytes < MPIR_REDSCAT_COMMUTATIVE_LONG_MSG)) {
+        /* commutative and short. use recursive halving algorithm */
+
+        /* allocate temp. buffer to receive incoming data */
+        MPIU_CHKLMEM_MALLOC(tmp_recvbuf, void *, total_count*(MPIR_MAX(true_extent,extent)), mpi_errno, "tmp_recvbuf");
+        /* adjust for potential negative lower bound in datatype */
+        tmp_recvbuf = (void *)((char*)tmp_recvbuf - true_lb);
+            
+        /* need to allocate another temporary buffer to accumulate
+           results because recvbuf may not be big enough */
+        MPIU_CHKLMEM_MALLOC(tmp_results, void *, total_count*(MPIR_MAX(true_extent,extent)), mpi_errno, "tmp_results");
+        /* adjust for potential negative lower bound in datatype */
+        tmp_results = (void *)((char*)tmp_results - true_lb);
+        
+        /* copy sendbuf into tmp_results */
+        if (sendbuf != MPI_IN_PLACE)
+            mpi_errno = MPIR_Localcopy(sendbuf, total_count, datatype,
+                                       tmp_results, total_count, datatype);
+        else
+            mpi_errno = MPIR_Localcopy(recvbuf, total_count, datatype,
+                                       tmp_results, total_count, datatype);
+        
+        if (mpi_errno) MPIU_ERR_POP(mpi_errno);
+
+        pof2 = 1;
+        while (pof2 <= comm_size) pof2 <<= 1;
+        pof2 >>=1;
+
+        rem = comm_size - pof2;
+
+        /* In the non-power-of-two case, all even-numbered
+           processes of rank < 2*rem send their data to
+           (rank+1). These even-numbered processes no longer
+           participate in the algorithm until the very end. The
+           remaining processes form a nice power-of-two. */
+
+        if (rank < 2*rem) {
+            if (rank % 2 == 0) { /* even */
+                mpi_errno = MPIC_Send(tmp_results, total_count, 
+                                      datatype, rank+1,
+                                      MPIR_REDUCE_SCATTER_BLOCK_TAG, comm);
+		/* --BEGIN ERROR HANDLING-- */
+                if (mpi_errno)
+		{
+		    mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0);
+		    return mpi_errno;
+		}
+		/* --END ERROR HANDLING-- */
+                
+                /* temporarily set the rank to -1 so that this
+                   process does not pariticipate in recursive
+                   doubling */
+                newrank = -1; 
+            }
+            else { /* odd */
+                mpi_errno = MPIC_Recv(tmp_recvbuf, total_count, 
+                                      datatype, rank-1,
+                                      MPIR_REDUCE_SCATTER_BLOCK_TAG, comm,
+                                      MPI_STATUS_IGNORE);
+		/* --BEGIN ERROR HANDLING-- */
+                if (mpi_errno)
+		{
+		    mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0);
+		    return mpi_errno;
+		}
+		/* --END ERROR HANDLING-- */
+                
+                /* do the reduction on received data. since the
+                   ordering is right, it doesn't matter whether
+                   the operation is commutative or not. */
+#ifdef HAVE_CXX_BINDING
+                if (is_cxx_uop) {
+                    (*MPIR_Process.cxx_call_op_fn)( tmp_recvbuf, tmp_results, 
+                                                    total_count,
+                                                    datatype,
+                                                    uop ); 
+                }
+                else 
+#endif
+                    (*uop)(tmp_recvbuf, tmp_results, &total_count, &datatype);
+                
+                /* change the rank */
+                newrank = rank / 2;
+            }
+        }
+        else  /* rank >= 2*rem */
+            newrank = rank - rem;
+
+        if (newrank != -1) {
+            /* recalculate the recvcnts and disps arrays because the
+               even-numbered processes who no longer participate will
+               have their result calculated by the process to their
+               right (rank+1). */
+
+            MPIU_CHKLMEM_MALLOC(newcnts, int *, pof2*sizeof(int), mpi_errno, "newcnts");
+            MPIU_CHKLMEM_MALLOC(newdisps, int *, pof2*sizeof(int), mpi_errno, "newdisps");
+            
+            for (i=0; i<pof2; i++) {
+                /* what does i map to in the old ranking? */
+                old_i = (i < rem) ? i*2 + 1 : i + rem;
+                if (old_i < 2*rem) {
+                    /* This process has to also do its left neighbor's
+                       work */
+                    newcnts[i] = 2 * recvcount;
+                }
+                else
+                    newcnts[i] = recvcount;
+            }
+            
+            newdisps[0] = 0;
+            for (i=1; i<pof2; i++)
+                newdisps[i] = newdisps[i-1] + newcnts[i-1];
+
+            mask = pof2 >> 1;
+            send_idx = recv_idx = 0;
+            last_idx = pof2;
+            while (mask > 0) {
+                newdst = newrank ^ mask;
+                /* find real rank of dest */
+                dst = (newdst < rem) ? newdst*2 + 1 : newdst + rem;
+                
+                send_cnt = recv_cnt = 0;
+                if (newrank < newdst) {
+                    send_idx = recv_idx + mask;
+                    for (i=send_idx; i<last_idx; i++)
+                        send_cnt += newcnts[i];
+                    for (i=recv_idx; i<send_idx; i++)
+                        recv_cnt += newcnts[i];
+                }
+                else {
+                    recv_idx = send_idx + mask;
+                    for (i=send_idx; i<recv_idx; i++)
+                        send_cnt += newcnts[i];
+                    for (i=recv_idx; i<last_idx; i++)
+                        recv_cnt += newcnts[i];
+                }
+                
+/*                    printf("Rank %d, send_idx %d, recv_idx %d, send_cnt %d, recv_cnt %d, last_idx %d\n", newrank, send_idx, recv_idx,
+                      send_cnt, recv_cnt, last_idx);
+*/
+                /* Send data from tmp_results. Recv into tmp_recvbuf */ 
+                if ((send_cnt != 0) && (recv_cnt != 0)) 
+                    mpi_errno = MPIC_Sendrecv((char *) tmp_results +
+                                          newdisps[send_idx]*extent,
+                                          send_cnt, datatype,  
+                                          dst, MPIR_REDUCE_SCATTER_BLOCK_TAG, 
+                                          (char *) tmp_recvbuf +
+                                          newdisps[recv_idx]*extent,
+                                          recv_cnt, datatype, dst,
+                                          MPIR_REDUCE_SCATTER_BLOCK_TAG, comm,
+                                          MPI_STATUS_IGNORE); 
+                else if ((send_cnt == 0) && (recv_cnt != 0))
+                    mpi_errno = MPIC_Recv((char *) tmp_recvbuf +
+                                          newdisps[recv_idx]*extent,
+                                          recv_cnt, datatype, dst,
+                                          MPIR_REDUCE_SCATTER_BLOCK_TAG, comm,
+                                          MPI_STATUS_IGNORE);
+                else if ((recv_cnt == 0) && (send_cnt != 0))
+                    mpi_errno = MPIC_Send((char *) tmp_results +
+                                          newdisps[send_idx]*extent,
+                                          send_cnt, datatype,  
+                                          dst, MPIR_REDUCE_SCATTER_BLOCK_TAG,
+                                          comm);  
+
+                if (mpi_errno) MPIU_ERR_POP(mpi_errno);
+                
+                /* tmp_recvbuf contains data received in this step.
+                   tmp_results contains data accumulated so far */
+                
+                if (recv_cnt) {
+#ifdef HAVE_CXX_BINDING
+                    if (is_cxx_uop) {
+                        (*MPIR_Process.cxx_call_op_fn)((char *) tmp_recvbuf +
+                                                   newdisps[recv_idx]*extent,
+                                                   (char *) tmp_results + 
+                                                   newdisps[recv_idx]*extent, 
+                                                   recv_cnt, datatype, uop);
+                    }
+                    else 
+#endif
+                        (*uop)((char *) tmp_recvbuf + newdisps[recv_idx]*extent,
+                             (char *) tmp_results + newdisps[recv_idx]*extent, 
+                               &recv_cnt, &datatype);
+                }
+
+                /* update send_idx for next iteration */
+                send_idx = recv_idx;
+                last_idx = recv_idx + mask;
+                mask >>= 1;
+            }
+
+            /* copy this process's result from tmp_results to recvbuf */
+            mpi_errno = MPIR_Localcopy((char *)tmp_results +
+                                       disps[rank]*extent, 
+                                       recvcount, datatype, recvbuf,
+                                       recvcount, datatype);
+            if (mpi_errno) MPIU_ERR_POP(mpi_errno);
+        }
+
+        /* In the non-power-of-two case, all odd-numbered
+           processes of rank < 2*rem send to (rank-1) the result they
+           calculated for that process */
+        if (rank < 2*rem) {
+            if (rank % 2) { /* odd */
+                mpi_errno = MPIC_Send((char *) tmp_results +
+                                      disps[rank-1]*extent, recvcount,
+                                      datatype, rank-1,
+                                      MPIR_REDUCE_SCATTER_BLOCK_TAG, comm);
+            }
+            else  {   /* even */
+                mpi_errno = MPIC_Recv(recvbuf, recvcount,
+                                      datatype, rank+1,
+                                      MPIR_REDUCE_SCATTER_BLOCK_TAG, comm,
+                                      MPI_STATUS_IGNORE); 
+            }
+            if (mpi_errno) MPIU_ERR_POP(mpi_errno);
+        }
+    }
+    
+    if (is_commutative && (nbytes >= MPIR_REDSCAT_COMMUTATIVE_LONG_MSG)) {
+
+        /* commutative and long message, or noncommutative and long message.
+           use (p-1) pairwise exchanges */ 
+        
+        if (sendbuf != MPI_IN_PLACE) {
+            /* copy local data into recvbuf */
+            mpi_errno = MPIR_Localcopy(((char *)sendbuf+disps[rank]*extent),
+                                       recvcount, datatype, recvbuf,
+                                       recvcount, datatype);
+            if (mpi_errno) MPIU_ERR_POP(mpi_errno);
+        }
+        
+        /* allocate temporary buffer to store incoming data */
+        MPIU_CHKLMEM_MALLOC(tmp_recvbuf, void *, recvcount*(MPIR_MAX(true_extent,extent))+1, mpi_errno, "tmp_recvbuf");
+        /* adjust for potential negative lower bound in datatype */
+        tmp_recvbuf = (void *)((char*)tmp_recvbuf - true_lb);
+        
+        for (i=1; i<comm_size; i++) {
+            src = (rank - i + comm_size) % comm_size;
+            dst = (rank + i) % comm_size;
+            
+            /* send the data that dst needs. recv data that this process
+               needs from src into tmp_recvbuf */
+            if (sendbuf != MPI_IN_PLACE) 
+                mpi_errno = MPIC_Sendrecv(((char *)sendbuf+disps[dst]*extent), 
+                                          recvcount, datatype, dst,
+                                          MPIR_REDUCE_SCATTER_BLOCK_TAG, tmp_recvbuf,
+                                          recvcount, datatype, src,
+                                          MPIR_REDUCE_SCATTER_BLOCK_TAG, comm,
+                                          MPI_STATUS_IGNORE);
+            else
+                mpi_errno = MPIC_Sendrecv(((char *)recvbuf+disps[dst]*extent), 
+                                          recvcount, datatype, dst,
+                                          MPIR_REDUCE_SCATTER_BLOCK_TAG, tmp_recvbuf,
+                                          recvcount, datatype, src,
+                                          MPIR_REDUCE_SCATTER_BLOCK_TAG, comm,
+                                          MPI_STATUS_IGNORE);
+            
+            if (mpi_errno) MPIU_ERR_POP(mpi_errno);
+            
+            if (is_commutative || (src < rank)) {
+                if (sendbuf != MPI_IN_PLACE) {
+#ifdef HAVE_CXX_BINDING
+                    if (is_cxx_uop) {
+                        (*MPIR_Process.cxx_call_op_fn)(tmp_recvbuf, 
+                                                       recvbuf, 
+                                                       recvcount, 
+                                                       datatype, uop );
+                    }
+                    else 
+#endif
+                        (*uop)(tmp_recvbuf, recvbuf, &recvcount, 
+                               &datatype); 
+                }
+                else {
+#ifdef HAVE_CXX_BINDING
+                    if (is_cxx_uop) {
+                        (*MPIR_Process.cxx_call_op_fn)( tmp_recvbuf, 
+                                                        ((char *)recvbuf+disps[rank]*extent), 
+                                                        recvcount, datatype, uop ); 
+                    }
+                    else 
+#endif
+                        (*uop)(tmp_recvbuf, ((char *)recvbuf+disps[rank]*extent), 
+                               &recvcount, &datatype); 
+                    /* we can't store the result at the beginning of
+                       recvbuf right here because there is useful data
+                       there that other process/processes need. at the
+                       end, we will copy back the result to the
+                       beginning of recvbuf. */
+                }
+            }
+            else {
+                if (sendbuf != MPI_IN_PLACE) {
+#ifdef HAVE_CXX_BINDING
+                    if (is_cxx_uop) {
+                        (*MPIR_Process.cxx_call_op_fn)( recvbuf, 
+                                                        tmp_recvbuf, 
+                                                        recvcount, 
+                                                        datatype, uop );
+                    }
+                    else 
+#endif
+                        (*uop)(recvbuf, tmp_recvbuf, &recvcount, &datatype); 
+                    /* copy result back into recvbuf */
+                    mpi_errno = MPIR_Localcopy(tmp_recvbuf, recvcount, 
+                                               datatype, recvbuf,
+                                               recvcount, datatype); 
+                }
+                else {
+#ifdef HAVE_CXX_BINDING
+                    if (is_cxx_uop) {
+                        (*MPIR_Process.cxx_call_op_fn)( 
+                            ((char *)recvbuf+disps[rank]*extent),
+                            tmp_recvbuf, recvcount, datatype, uop );   
+                        
+                    }
+                    else 
+#endif
+                        (*uop)(((char *)recvbuf+disps[rank]*extent),
+                               tmp_recvbuf, &recvcount, &datatype);   
+                    /* copy result back into recvbuf */
+                    mpi_errno = MPIR_Localcopy(tmp_recvbuf, recvcount, 
+                                               datatype, 
+                                               ((char *)recvbuf +
+                                                disps[rank]*extent), 
+                                               recvcount, datatype); 
+                }
+                if (mpi_errno) MPIU_ERR_POP(mpi_errno);
+            }
+        }
+        
+        /* if MPI_IN_PLACE, move output data to the beginning of
+           recvbuf. already done for rank 0. */
+        if ((sendbuf == MPI_IN_PLACE) && (rank != 0)) {
+            mpi_errno = MPIR_Localcopy(((char *)recvbuf +
+                                        disps[rank]*extent),  
+                                       recvcount, datatype, 
+                                       recvbuf, 
+                                       recvcount, datatype); 
+            if (mpi_errno) MPIU_ERR_POP(mpi_errno);
+        }
+    }
+    
+    if (!is_commutative) {
+
+        /* power of two check */
+        if (!(comm_size & (comm_size - 1))) {
+            /* noncommutative, pof2 size */
+            mpi_errno = MPIR_Reduce_scatter_block_noncomm(sendbuf, recvbuf, recvcount, datatype, op, comm_ptr);
+            if (mpi_errno) MPIU_ERR_POP(mpi_errno);
+        }
+        else {
+            /* noncommutative and non-pof2, use recursive doubling. */
+
+            /* need to allocate temporary buffer to receive incoming data*/
+            MPIU_CHKLMEM_MALLOC(tmp_recvbuf, void *, total_count*(MPIR_MAX(true_extent,extent)), mpi_errno, "tmp_recvbuf");
+            /* adjust for potential negative lower bound in datatype */
+            tmp_recvbuf = (void *)((char*)tmp_recvbuf - true_lb);
+
+            /* need to allocate another temporary buffer to accumulate
+               results */
+            MPIU_CHKLMEM_MALLOC(tmp_results, void *, total_count*(MPIR_MAX(true_extent,extent)), mpi_errno, "tmp_results");
+            /* adjust for potential negative lower bound in datatype */
+            tmp_results = (void *)((char*)tmp_results - true_lb);
+
+            /* copy sendbuf into tmp_results */
+            if (sendbuf != MPI_IN_PLACE)
+                mpi_errno = MPIR_Localcopy(sendbuf, total_count, datatype,
+                                           tmp_results, total_count, datatype);
+            else
+                mpi_errno = MPIR_Localcopy(recvbuf, total_count, datatype,
+                                           tmp_results, total_count, datatype);
+
+            if (mpi_errno) MPIU_ERR_POP(mpi_errno);
+
+            mask = 0x1;
+            i = 0;
+            while (mask < comm_size) {
+                dst = rank ^ mask;
+
+                dst_tree_root = dst >> i;
+                dst_tree_root <<= i;
+
+                my_tree_root = rank >> i;
+                my_tree_root <<= i;
+
+                /* At step 1, processes exchange (n-n/p) amount of
+                   data; at step 2, (n-2n/p) amount of data; at step 3, (n-4n/p)
+                   amount of data, and so forth. We use derived datatypes for this.
+
+                   At each step, a process does not need to send data
+                   indexed from my_tree_root to
+                   my_tree_root+mask-1. Similarly, a process won't receive
+                   data indexed from dst_tree_root to dst_tree_root+mask-1. */
+
+                /* calculate sendtype */
+                blklens[0] = blklens[1] = 0;
+                for (j=0; j<my_tree_root; j++)
+                    blklens[0] += recvcount;
+                for (j=my_tree_root+mask; j<comm_size; j++)
+                    blklens[1] += recvcount;
+
+                dis[0] = 0;
+                dis[1] = blklens[0];
+                for (j=my_tree_root; (j<my_tree_root+mask) && (j<comm_size); j++)
+                    dis[1] += recvcount;
+
+                NMPI_Type_indexed(2, blklens, dis, datatype, &sendtype);
+                NMPI_Type_commit(&sendtype);
+
+                /* calculate recvtype */
+                blklens[0] = blklens[1] = 0;
+                for (j=0; j<dst_tree_root && j<comm_size; j++)
+                    blklens[0] += recvcount;
+                for (j=dst_tree_root+mask; j<comm_size; j++)
+                    blklens[1] += recvcount;
+
+                dis[0] = 0;
+                dis[1] = blklens[0];
+                for (j=dst_tree_root; (j<dst_tree_root+mask) && (j<comm_size); j++)
+                    dis[1] += recvcount;
+
+                NMPI_Type_indexed(2, blklens, dis, datatype, &recvtype);
+                NMPI_Type_commit(&recvtype);
+
+                received = 0;
+                if (dst < comm_size) {
+                    /* tmp_results contains data to be sent in each step. Data is
+                       received in tmp_recvbuf and then accumulated into
+                       tmp_results. accumulation is done later below.   */ 
+
+                    mpi_errno = MPIC_Sendrecv(tmp_results, 1, sendtype, dst,
+                                              MPIR_REDUCE_SCATTER_BLOCK_TAG, 
+                                              tmp_recvbuf, 1, recvtype, dst,
+                                              MPIR_REDUCE_SCATTER_BLOCK_TAG, comm,
+                                              MPI_STATUS_IGNORE); 
+                    received = 1;
+                    if (mpi_errno) MPIU_ERR_POP(mpi_errno);
+                }
+
+                /* if some processes in this process's subtree in this step
+                   did not have any destination process to communicate with
+                   because of non-power-of-two, we need to send them the
+                   result. We use a logarithmic recursive-halfing algorithm
+                   for this. */
+
+                if (dst_tree_root + mask > comm_size) {
+                    nprocs_completed = comm_size - my_tree_root - mask;
+                    /* nprocs_completed is the number of processes in this
+                       subtree that have all the data. Send data to others
+                       in a tree fashion. First find root of current tree
+                       that is being divided into two. k is the number of
+                       least-significant bits in this process's rank that
+                       must be zeroed out to find the rank of the root */ 
+                    j = mask;
+                    k = 0;
+                    while (j) {
+                        j >>= 1;
+                        k++;
+                    }
+                    k--;
+
+                    tmp_mask = mask >> 1;
+                    while (tmp_mask) {
+                        dst = rank ^ tmp_mask;
+
+                        tree_root = rank >> k;
+                        tree_root <<= k;
+
+                        /* send only if this proc has data and destination
+                           doesn't have data. at any step, multiple processes
+                           can send if they have the data */
+                        if ((dst > rank) && 
+                            (rank < tree_root + nprocs_completed)
+                            && (dst >= tree_root + nprocs_completed)) {
+                            /* send the current result */
+                            mpi_errno = MPIC_Send(tmp_recvbuf, 1, recvtype,
+                                                  dst, MPIR_REDUCE_SCATTER_BLOCK_TAG,
+                                                  comm);  
+                            if (mpi_errno) MPIU_ERR_POP(mpi_errno);
+                        }
+                        /* recv only if this proc. doesn't have data and sender
+                           has data */
+                        else if ((dst < rank) && 
+                                 (dst < tree_root + nprocs_completed) &&
+                                 (rank >= tree_root + nprocs_completed)) {
+                            mpi_errno = MPIC_Recv(tmp_recvbuf, 1, recvtype, dst,
+                                                  MPIR_REDUCE_SCATTER_BLOCK_TAG,
+                                                  comm, MPI_STATUS_IGNORE); 
+                            received = 1;
+                            if (mpi_errno) MPIU_ERR_POP(mpi_errno);
+                        }
+                        tmp_mask >>= 1;
+                        k--;
+                    }
+                }
+
+                /* The following reduction is done here instead of after 
+                   the MPIC_Sendrecv or MPIC_Recv above. This is
+                   because to do it above, in the noncommutative 
+                   case, we would need an extra temp buffer so as not to
+                   overwrite temp_recvbuf, because temp_recvbuf may have
+                   to be communicated to other processes in the
+                   non-power-of-two case. To avoid that extra allocation,
+                   we do the reduce here. */
+                if (received) {
+                    if (is_commutative || (dst_tree_root < my_tree_root)) {
+#ifdef HAVE_CXX_BINDING
+                        if (is_cxx_uop) {
+                            (*MPIR_Process.cxx_call_op_fn)( tmp_recvbuf, 
+                                                            tmp_results, blklens[0],
+                                                            datatype, uop); 
+                            (*MPIR_Process.cxx_call_op_fn)( 
+                                ((char *)tmp_recvbuf + dis[1]*extent),
+                                ((char *)tmp_results + dis[1]*extent),
+                                blklens[1], datatype, uop ); 
+                        }
+                        else
+#endif
+                        {
+                            (*uop)(tmp_recvbuf, tmp_results, &blklens[0],
+                                   &datatype); 
+                            (*uop)(((char *)tmp_recvbuf + dis[1]*extent),
+                                   ((char *)tmp_results + dis[1]*extent),
+                                   &blklens[1], &datatype); 
+                        }
+                    }
+                    else {
+#ifdef HAVE_CXX_BINDING
+                        if (is_cxx_uop) {
+                            (*MPIR_Process.cxx_call_op_fn)( tmp_results, 
+                                                            tmp_recvbuf, blklens[0],
+                                                            datatype, uop ); 
+                            (*MPIR_Process.cxx_call_op_fn)( 
+                                ((char *)tmp_results + dis[1]*extent),
+                                ((char *)tmp_recvbuf + dis[1]*extent),
+                                blklens[1], datatype, uop ); 
+                        }
+                        else 
+#endif
+                        {
+                            (*uop)(tmp_results, tmp_recvbuf, &blklens[0],
+                                   &datatype); 
+                            (*uop)(((char *)tmp_results + dis[1]*extent),
+                                   ((char *)tmp_recvbuf + dis[1]*extent),
+                                   &blklens[1], &datatype); 
+                        }
+                        /* copy result back into tmp_results */
+                        mpi_errno = MPIR_Localcopy(tmp_recvbuf, 1, recvtype, 
+                                                   tmp_results, 1, recvtype);
+                        if (mpi_errno) MPIU_ERR_POP(mpi_errno);
+                    }
+                }
+
+                NMPI_Type_free(&sendtype);
+                NMPI_Type_free(&recvtype);
+
+                mask <<= 1;
+                i++;
+            }
+
+            /* now copy final results from tmp_results to recvbuf */
+            mpi_errno = MPIR_Localcopy(((char *)tmp_results+disps[rank]*extent),
+                                       recvcount, datatype, recvbuf,
+                                       recvcount, datatype); 
+            if (mpi_errno) MPIU_ERR_POP(mpi_errno);
+        }
+    }
+
+fn_exit:
+    MPIU_CHKLMEM_FREEALL();
+
+    MPIR_Nest_decr();
+    /* check if multiple threads are calling this collective function */
+    MPIDU_ERR_CHECK_MULTIPLE_THREADS_EXIT( comm_ptr );
+
+    if (MPIU_THREADPRIV_FIELD(op_errno)) 
+	mpi_errno = MPIU_THREADPRIV_FIELD(op_errno);
+
+    return (mpi_errno);
+fn_fail:
+    goto fn_exit;
+}
+/* end:nested */
+
+#undef FUNCNAME
+#define FUNCNAME MPIR_Reduce_scatter_block_inter
+#undef FCNAME
+#define FCNAME MPIU_QUOTE(FUNCNAME)
+/* begin:nested */
+/* not declared static because a machine-specific function may call this one in some cases */
+/* Intercommunicator Reduce_scatter_block.
+
+   Rank 0 of each group first obtains the reduction (with op) of the
+   other group's send buffers through two intercommunicator reduces:
+   the low (left) group is the receiving side of the first reduce and the
+   sending side of the second, and the high (right) group does the
+   opposite, so the calls in the two groups pair up. Rank 0 of each group
+   then scatters recvcount elements to every process of its local group
+   with MPIR_Scatter over the local intracommunicator.
+
+   sendbuf   - local contribution of local_size*recvcount elements
+               (NOTE(review): MPI_IN_PLACE is rejected for intercommunicators
+               by the MPI_Reduce_scatter_block error checks, not here)
+   recvbuf   - receives this process's block of recvcount elements
+   recvcount - number of elements in each process's result block
+   datatype, op - as passed to MPI_Reduce_scatter_block
+   comm_ptr  - the intercommunicator
+
+   Returns MPI_SUCCESS or an MPI error code. The temporary buffer that
+   rank 0 allocates is released on every return path.
+*/
+int MPIR_Reduce_scatter_block_inter ( 
+    void *sendbuf, 
+    void *recvbuf, 
+    int recvcount, 
+    MPI_Datatype datatype, 
+    MPI_Op op, 
+    MPID_Comm *comm_ptr )
+{
+    int rank, root, local_size, total_count;
+    int mpi_errno = MPI_SUCCESS;
+    MPI_Aint true_extent, true_lb = 0, extent;
+    void *tmp_buf = NULL;       /* lower-bound-adjusted view of tmp_buf_base */
+    void *tmp_buf_base = NULL;  /* pointer returned by MPIU_Malloc; freed at exit */
+    MPID_Comm *newcomm_ptr = NULL;
+
+    rank = comm_ptr->rank;
+    local_size = comm_ptr->local_size;
+
+    /* the reduced vector holds one recvcount-sized block per local process */
+    total_count = local_size * recvcount;
+
+    if (rank == 0) {
+        /* In each group, rank 0 allocates a temp. buffer for the 
+           reduce */
+        mpi_errno = NMPI_Type_get_true_extent(datatype, &true_lb,
+                                              &true_extent);  
+        if (mpi_errno) MPIU_ERR_POP(mpi_errno);
+        MPID_Datatype_get_extent_macro(datatype, extent);
+
+        tmp_buf_base = MPIU_Malloc(total_count*(MPIR_MAX(extent,true_extent)));
+	/* --BEGIN ERROR HANDLING-- */
+        if (!tmp_buf_base) {
+            mpi_errno = MPIR_Err_create_code( MPI_SUCCESS, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**nomem", 0 );
+            goto fn_fail;
+        }
+	/* --END ERROR HANDLING-- */
+        /* adjust for potential negative lower bound in datatype */
+        tmp_buf = (void *)((char*)tmp_buf_base - true_lb);
+    }
+
+    /* The two groups issue the reduces in opposite order so that each
+       call in one group matches the corresponding call in the other. */
+    if (comm_ptr->is_low_group) {
+        /* reduce from right group to rank 0 of left group */
+        root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL;
+        mpi_errno = MPIR_Reduce_inter(sendbuf, tmp_buf, total_count, datatype, op,
+                                      root, comm_ptr);  
+        if (mpi_errno) MPIU_ERR_POP(mpi_errno);
+
+        /* reduce from left group to rank 0 of right group */
+        root = 0;
+        mpi_errno = MPIR_Reduce_inter(sendbuf, tmp_buf, total_count, datatype, op,
+                                      root, comm_ptr);  
+        if (mpi_errno) MPIU_ERR_POP(mpi_errno);
+    }
+    else {
+        /* reduce from right group to rank 0 of left group */
+        root = 0;
+        mpi_errno = MPIR_Reduce_inter(sendbuf, tmp_buf, total_count, datatype, op,
+                                      root, comm_ptr);  
+        if (mpi_errno) MPIU_ERR_POP(mpi_errno);
+
+        /* reduce from left group to rank 0 of right group */
+        root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL;
+        mpi_errno = MPIR_Reduce_inter(sendbuf, tmp_buf, total_count, datatype, op,
+                                      root, comm_ptr);  
+        if (mpi_errno) MPIU_ERR_POP(mpi_errno);
+    }
+
+    /* Get the local intracommunicator */
+    if (!comm_ptr->local_comm) {
+        mpi_errno = MPIR_Setup_intercomm_localcomm( comm_ptr );
+        if (mpi_errno) MPIU_ERR_POP(mpi_errno);
+    }
+
+    newcomm_ptr = comm_ptr->local_comm;
+
+    /* rank 0 hands each local process its recvcount-element block */
+    mpi_errno = MPIR_Scatter(tmp_buf, recvcount, datatype, recvbuf,
+                             recvcount, datatype, 0, newcomm_ptr);
+    if (mpi_errno) MPIU_ERR_POP(mpi_errno);
+
+fn_exit:
+    /* only rank 0 allocated; this runs on success and on every error path */
+    if (tmp_buf_base != NULL) {
+        MPIU_Free(tmp_buf_base);
+    }
+    return mpi_errno;
+fn_fail:
+    goto fn_exit;
+}
+/* end:nested */
+#endif
+
+#undef FUNCNAME
 #define FUNCNAME MPI_Reduce_scatter_block
 #undef FCNAME
 #define FCNAME MPIU_QUOTE(FUNCNAME)
 /*@
+
 MPI_Reduce_scatter_block - Combines values and scatters the results
 
 Input Parameters:
-+ sendbuf - starting address of send buffer (choice)
++ sendbuf - starting address of send buffer (choice) 
 . recvcount - element count per block (non-negative integer)
-. datatype - datatype of elements of send and receive buffers (handle)
-. op - operation (handle)
-- comm - communicator (handle)
+. datatype - datatype of elements of send and receive buffers (handle) 
+. op - operation (handle) 
+- comm - communicator (handle) 
 
 Output Parameter:
-. recvbuf - starting address of receive buffer (choice)
+. recvbuf - starting address of receive buffer (choice) 
 
 .N ThreadSafe
 
@@ -58,20 +1125,16 @@
 .N MPI_ERR_OP
 .N MPI_ERR_BUFFER_ALIAS
 @*/
-int MPI_Reduce_scatter_block(void *sendbuf, void *recvbuf, int recvcount,
-                             MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
+int MPI_Reduce_scatter_block(void *sendbuf, void *recvbuf, int recvcount, 
+		       MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
 {
     int mpi_errno = MPI_SUCCESS;
     MPID_Comm *comm_ptr = NULL;
-    MPID_Comm *scatter_comm_ptr = NULL;
-    void *tmp_buf;
-    MPI_Aint extent, true_extent, true_lb;
     MPIU_THREADPRIV_DECL;
-    MPIU_CHKLMEM_DECL(1);
     MPID_MPI_STATE_DECL(MPID_STATE_MPI_REDUCE_SCATTER_BLOCK);
 
     MPIR_ERRTEST_INITIALIZED_ORDIE();
-
+    
     MPIU_THREAD_CS_ENTER(ALLFUNC,);
     MPID_MPI_COLL_FUNC_ENTER(MPID_STATE_MPI_REDUCE_SCATTER_BLOCK);
 
@@ -80,55 +1143,55 @@
     {
         MPID_BEGIN_ERROR_CHECKS;
         {
-            MPIR_ERRTEST_COMM(comm, mpi_errno);
+	    MPIR_ERRTEST_COMM(comm, mpi_errno);
             if (mpi_errno != MPI_SUCCESS) goto fn_fail;
-        }
+	}
         MPID_END_ERROR_CHECKS;
     }
 #   endif /* HAVE_ERROR_CHECKING */
 
     /* Convert MPI object handles to object pointers */
-    MPID_Comm_get_ptr(comm, comm_ptr);
+    MPID_Comm_get_ptr( comm, comm_ptr );
 
     /* Validate parameters and objects (post conversion) */
 #   ifdef HAVE_ERROR_CHECKING
     {
         MPID_BEGIN_ERROR_CHECKS;
         {
-            MPID_Datatype *datatype_ptr = NULL;
+	    MPID_Datatype *datatype_ptr = NULL;
             MPID_Op *op_ptr = NULL;
-            int size;
-
-            MPID_Comm_valid_ptr(comm_ptr, mpi_errno);
+	    
+            MPID_Comm_valid_ptr( comm_ptr, mpi_errno );
             if (mpi_errno != MPI_SUCCESS) goto fn_fail;
 
-            size = comm_ptr->local_size;
+            MPIR_ERRTEST_COUNT(recvcount,mpi_errno);
 
-            MPIR_ERRTEST_DATATYPE(datatype, "datatype", mpi_errno);
+	    MPIR_ERRTEST_DATATYPE(datatype, "datatype", mpi_errno);
             if (HANDLE_GET_KIND(datatype) != HANDLE_KIND_BUILTIN) {
                 MPID_Datatype_get_ptr(datatype, datatype_ptr);
-                MPID_Datatype_valid_ptr(datatype_ptr, mpi_errno);
-                MPID_Datatype_committed_ptr(datatype_ptr, mpi_errno);
+                MPID_Datatype_valid_ptr( datatype_ptr, mpi_errno );
+                MPID_Datatype_committed_ptr( datatype_ptr, mpi_errno );
             }
 
             MPIR_ERRTEST_RECVBUF_INPLACE(recvbuf, recvcount, mpi_errno);
-            if (comm_ptr->comm_kind == MPID_INTERCOMM)
+	    if (comm_ptr->comm_kind == MPID_INTERCOMM) 
                 MPIR_ERRTEST_SENDBUF_INPLACE(sendbuf, recvcount, mpi_errno);
 
             MPIR_ERRTEST_USERBUFFER(recvbuf,recvcount,datatype,mpi_errno);
-            MPIR_ERRTEST_USERBUFFER(sendbuf,recvcount*size,datatype,mpi_errno);
+            MPIR_ERRTEST_USERBUFFER(sendbuf,recvcount,datatype,mpi_errno); 
 
-            MPIR_ERRTEST_OP(op, mpi_errno);
+	    MPIR_ERRTEST_OP(op, mpi_errno);
 
             if (mpi_errno != MPI_SUCCESS) goto fn_fail;
             if (HANDLE_GET_KIND(op) != HANDLE_KIND_BUILTIN) {
                 MPID_Op_get_ptr(op, op_ptr);
-                MPID_Op_valid_ptr(op_ptr, mpi_errno);
+                MPID_Op_valid_ptr( op_ptr, mpi_errno );
             }
             if (HANDLE_GET_KIND(op) == HANDLE_KIND_BUILTIN) {
-                mpi_errno = (*MPIR_Op_check_dtype_table[op%16 - 1])(datatype);
+                mpi_errno = 
+                    ( * MPIR_Op_check_dtype_table[op%16 - 1] )(datatype); 
             }
-            if (mpi_errno) MPIU_ERR_POP(mpi_errno);
+            if (mpi_errno != MPI_SUCCESS) goto fn_fail;
         }
         MPID_END_ERROR_CHECKS;
     }
@@ -136,41 +1199,36 @@
 
     /* ... body of routine ...  */
 
-    /* Use a naive implementation for now (reduce followed by scatter).
-     *
-     * FIXME We should adapt one or more of the existing MPI_Reduce_scatter
-     * algorithms to work here as well. */
-    MPIU_THREADPRIV_GET;
-    MPIR_Nest_incr();
+    if (comm_ptr->coll_fns != NULL && comm_ptr->coll_fns->Reduce_scatter_block != NULL)
+    {
+	mpi_errno = comm_ptr->coll_fns->Reduce_scatter_block(sendbuf, recvbuf,
+                                                       recvcount, datatype, 
+                                                       op, comm_ptr);
+    }
+    else
+    {
+	MPIU_THREADPRIV_GET;
 
-    MPID_Datatype_get_extent_macro(datatype, extent);
-    mpi_errno = NMPI_Type_get_true_extent(datatype, &true_lb, &true_extent);
-    if (mpi_errno) MPIU_ERR_POP(mpi_errno);
+	MPIR_Nest_incr();
+        if (comm_ptr->comm_kind == MPID_INTRACOMM) 
+            /* intracommunicator */
+            mpi_errno = MPIR_Reduce_scatter_block(sendbuf, recvbuf,
+                                            recvcount, datatype, 
+                                            op, comm_ptr);
+        else {
+            /* intercommunicator */
+            mpi_errno = MPIR_Reduce_scatter_block_inter(sendbuf, recvbuf,
+                                                  recvcount, datatype, 
+                                                  op, comm_ptr); 
+        }
+	MPIR_Nest_decr();
+    }
 
-    MPIU_CHKLMEM_MALLOC(tmp_buf, void *, true_extent * recvcount * comm_ptr->local_size, mpi_errno, "tmp_buf");
-    tmp_buf = (void *)((char*)tmp_buf - true_lb);
+    if (mpi_errno != MPI_SUCCESS) goto fn_fail;
 
-    mpi_errno = NMPI_Reduce((sendbuf == MPI_IN_PLACE ? recvbuf : sendbuf), tmp_buf,
-                            (recvcount * comm_ptr->local_size), datatype, op,
-                            0/*root*/, comm);
-    if (mpi_errno) MPIU_ERR_POP(mpi_errno);
-
-    scatter_comm_ptr = comm_ptr;
-    if (comm_ptr->comm_kind == MPID_INTERCOMM) {
-        /* Get the local intracommunicator */
-        if (!comm_ptr->local_comm)
-            MPIR_Setup_intercomm_localcomm(comm_ptr);
-        scatter_comm_ptr = comm_ptr->local_comm;
-    }
-    mpi_errno = NMPI_Scatter(tmp_buf, recvcount, datatype,
-                             recvbuf, recvcount, datatype,
-                             0/*root*/, scatter_comm_ptr->handle);
-    if (mpi_errno) MPIU_ERR_POP(mpi_errno);
-
     /* ... end of body of routine ... */
+    
   fn_exit:
-    MPIU_CHKLMEM_FREEALL();
-    MPIR_Nest_decr();
     MPID_MPI_COLL_FUNC_EXIT(MPID_STATE_MPI_REDUCE_SCATTER_BLOCK);
     MPIU_THREAD_CS_EXIT(ALLFUNC,);
     return mpi_errno;
@@ -179,14 +1237,12 @@
     /* --BEGIN ERROR HANDLING-- */
 #   ifdef HAVE_ERROR_CHECKING
     {
-        mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__,
-                                         MPI_ERR_OTHER, "**mpi_reduce_scatter_block",
-                                         "**mpi_reduce_scatter_block %p %p %d %D %O %C",
-                                         sendbuf, recvbuf, recvcount, datatype, op, comm);
+	mpi_errno = MPIR_Err_create_code(
+	    mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**mpi_reduce_scatter_block",
+	    "**mpi_reduce_scatter_block %p %p %d %D %O %C", sendbuf, recvbuf, recvcount, datatype, op, comm);
     }
 #   endif
     mpi_errno = MPIR_Err_return_comm( comm_ptr, FCNAME, mpi_errno );
     goto fn_exit;
     /* --END ERROR HANDLING-- */
 }
-



More information about the mpich2-commits mailing list