[mpich2-commits] r3924 - mpich2/trunk/src/mpi/coll

balaji at mcs.anl.gov balaji at mcs.anl.gov
Tue Mar 3 13:11:53 CST 2009


Author: balaji
Date: 2009-03-03 13:11:53 -0600 (Tue, 03 Mar 2009)
New Revision: 3924

Modified:
   mpich2/trunk/src/mpi/coll/gatherv.c
   mpich2/trunk/src/mpi/coll/scatterv.c
Log:
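Change the root side of Gatherv/Scatterv to use non-blocking communication (MPIC_Irecv/MPIC_Isend completed by a single Waitall), merging the intracommunicator and intercommunicator root code paths. Reviewed by thakur.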

Modified: mpich2/trunk/src/mpi/coll/gatherv.c
===================================================================
--- mpich2/trunk/src/mpi/coll/gatherv.c	2009-03-03 17:44:40 UTC (rev 3923)
+++ mpich2/trunk/src/mpi/coll/gatherv.c	2009-03-03 19:11:53 UTC (rev 3924)
@@ -57,99 +57,69 @@
     int        mpi_errno = MPI_SUCCESS;
     MPI_Comm comm;
     MPI_Aint       extent;
-    int            i;
+    int            i, reqs;
+    MPI_Request *reqarray;
+    MPI_Status *starray;
     
     comm = comm_ptr->handle;
     rank = comm_ptr->rank;
     
     /* check if multiple threads are calling this collective function */
     MPIDU_ERR_CHECK_MULTIPLE_THREADS_ENTER( comm_ptr );
-    
+
     /* If rank == root, then I recv lots, otherwise I send */
-    if ((comm_ptr->comm_kind == MPID_INTRACOMM) && (rank == root)) {
-        /* intracomm root */
-        comm_size = comm_ptr->local_size;
-        MPID_Datatype_get_extent_macro(recvtype, extent);
+    if (((comm_ptr->comm_kind == MPID_INTRACOMM) && (root == rank)) ||
+        ((comm_ptr->comm_kind == MPID_INTERCOMM) && (root == MPI_ROOT))) {
+        if (comm_ptr->comm_kind == MPID_INTRACOMM)
+            comm_size = comm_ptr->local_size;
+        else
+            comm_size = comm_ptr->remote_size;
 
+        MPID_Datatype_get_extent_macro(recvtype, extent);
 	/* each node can make sure it is not going to overflow aint */
         MPID_Ensure_Aint_fits_in_pointer(MPI_VOID_PTR_CAST_TO_MPI_AINT recvbuf +
 					 displs[rank] * extent);
 
-        for ( i=0; i<root; i++ ) {
+        reqarray = (MPI_Request *) MPIU_Malloc(comm_size * sizeof(MPI_Request));
+        starray = (MPI_Request *) MPIU_Malloc(comm_size * sizeof(MPI_Status));
+
+        reqs = 0;
+        for (i = 0; i < comm_size; i++) {
             if (recvcnts[i]) {
-                mpi_errno = MPIC_Recv(((char *)recvbuf+displs[i]*extent), 
-                                      recvcnts[i], recvtype, i,
-                                      MPIR_GATHERV_TAG, comm,
-                                      MPI_STATUS_IGNORE);
+                if ((comm_ptr->comm_kind == MPID_INTRACOMM) && (i == rank) &&
+                    (sendbuf != MPI_IN_PLACE)) {
+                    mpi_errno = MPIR_Localcopy(sendbuf, sendcnt, sendtype,
+                                               ((char *)recvbuf+displs[rank]*extent), 
+                                               recvcnts[rank], recvtype);
+                }
+                else {
+                    mpi_errno = MPIC_Irecv(((char *)recvbuf+displs[i]*extent), 
+                                           recvcnts[i], recvtype, i,
+                                           MPIR_GATHERV_TAG, comm,
+                                           &reqarray[reqs++]);
+                }
 		/* --BEGIN ERROR HANDLING-- */
-                if (mpi_errno)
-		{
+                if (mpi_errno) {
 		    mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0);
 		    return mpi_errno;
 		}
 		/* --END ERROR HANDLING-- */
             }
         }
-        if (sendbuf != MPI_IN_PLACE) {
-            if (recvcnts[rank]) {
-		MPID_Ensure_Aint_fits_in_pointer(MPI_VOID_PTR_CAST_TO_MPI_AINT recvbuf +
-						 displs[rank]*extent);
-                mpi_errno = MPIR_Localcopy(sendbuf, sendcnt, sendtype,
-                                           ((char *)recvbuf+displs[rank]*extent), 
-                                           recvcnts[rank], recvtype);
-		/* --BEGIN ERROR HANDLING-- */
-                if (mpi_errno)
-		{
-		    mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0);
-		    return mpi_errno;
-		}
-		/* --END ERROR HANDLING-- */
+        /* Wait for all of the posted non-blocking receives to finish. */
+        mpi_errno = NMPI_Waitall(reqs, reqarray, starray);
+        /* --BEGIN ERROR HANDLING-- */
+        if (mpi_errno == MPI_ERR_IN_STATUS) {
+            for (i = 0; i < reqs; i++) {
+                if (starray[i].MPI_ERROR != MPI_SUCCESS)
+                    mpi_errno = starray[i].MPI_ERROR;
             }
         }
-        MPID_Ensure_Aint_fits_in_pointer(MPI_VOID_PTR_CAST_TO_MPI_AINT recvbuf +
-					 displs[rank] * extent);
 
-        for ( i=root+1; i<comm_size; i++ ) {
-            if (recvcnts[i]) {
-                mpi_errno = MPIC_Recv(((char *)recvbuf+displs[i]*extent), 
-                                      recvcnts[i], recvtype, i,
-                                      MPIR_GATHERV_TAG, comm,
-                                      MPI_STATUS_IGNORE);
-		/* --BEGIN ERROR HANDLING-- */
-                if (mpi_errno)
-		{
-		    mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0);
-		    return mpi_errno;
-		}
-		/* --END ERROR HANDLING-- */
-            }
-        }
+        MPIU_Free(reqarray);
+        MPIU_Free(starray);
     }
-    
-    else if ((comm_ptr->comm_kind == MPID_INTERCOMM) && (root == MPI_ROOT)) {
-        /* intercommunicator root */
-        remote_comm_size = comm_ptr->remote_size;
-        MPID_Datatype_get_extent_macro(recvtype, extent);
 
-	MPID_Ensure_Aint_fits_in_pointer(MPI_VOID_PTR_CAST_TO_MPI_AINT recvbuf +
-					 displs[rank] * extent);
-        for (i=0; i<remote_comm_size; i++) {
-            if (recvcnts[i]) {
-                mpi_errno = MPIC_Recv(((char *)recvbuf+displs[i]*extent), 
-                                      recvcnts[i], recvtype, i,
-                                      MPIR_GATHERV_TAG, comm,
-                                      MPI_STATUS_IGNORE);
-		/* --BEGIN ERROR HANDLING-- */
-                if (mpi_errno)
-		{
-		    mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0);
-		    return mpi_errno;
-		}
-		/* --END ERROR HANDLING-- */
-            }                
-        }
-    }
-
     else if (root != MPI_PROC_NULL) { /* non-root nodes, and in the intercomm. case, non-root nodes on remote side */
         if (sendcnt)
             mpi_errno = MPIC_Send(sendbuf, sendcnt, sendtype, root, 

Modified: mpich2/trunk/src/mpi/coll/scatterv.c
===================================================================
--- mpich2/trunk/src/mpi/coll/scatterv.c	2009-03-03 17:44:40 UTC (rev 3923)
+++ mpich2/trunk/src/mpi/coll/scatterv.c	2009-03-03 19:11:53 UTC (rev 3924)
@@ -55,10 +55,12 @@
 	MPID_Comm *comm_ptr )
 {
     static const char FCNAME[] = "MPIR_Scatterv";
-    int rank, mpi_errno = MPI_SUCCESS;
+    int rank, comm_size, mpi_errno = MPI_SUCCESS;
     MPI_Comm comm;
     MPI_Aint extent;
-    int      i;
+    int      i, reqs;
+    MPI_Request *reqarray;
+    MPI_Status *starray;
 
     comm = comm_ptr->handle;
     rank = comm_ptr->rank;
@@ -67,11 +69,13 @@
     MPIDU_ERR_CHECK_MULTIPLE_THREADS_ENTER( comm_ptr );
 
     /* If I'm the root, then scatter */
-    if ((comm_ptr->comm_kind == MPID_INTRACOMM) && (rank == root)) {
-        /* intracomm root */
-        int comm_size;
-        
-        comm_size = comm_ptr->local_size;
+    if (((comm_ptr->comm_kind == MPID_INTRACOMM) && (root == rank)) ||
+        ((comm_ptr->comm_kind == MPID_INTERCOMM) && (root == MPI_ROOT))) {
+        if (comm_ptr->comm_kind == MPID_INTRACOMM)
+            comm_size = comm_ptr->local_size;
+        else
+            comm_size = comm_ptr->remote_size;
+
         MPID_Datatype_get_extent_macro(sendtype, extent);
         /* We need a check to ensure extent will fit in a
          * pointer. That needs extent * (max count) but we can't get
@@ -81,74 +85,43 @@
          * this? */
         MPID_Ensure_Aint_fits_in_pointer(MPI_VOID_PTR_CAST_TO_MPI_AINT sendbuf + extent);
 
-        /* We could use Isend here, but since the receivers need to execute
-           a simple Recv, it may not make much difference in performance, 
-           and using the blocking version is simpler */
-        for ( i=0; i<root; i++ ) {
+        reqarray = (MPI_Request *) MPIU_Malloc(comm_size * sizeof(MPI_Request));
+        starray = (MPI_Request *) MPIU_Malloc(comm_size * sizeof(MPI_Status));
+
+        reqs = 0;
+        for (i = 0; i < comm_size; i++) {
             if (sendcnts[i]) {
-                mpi_errno = MPIC_Send(((char *)sendbuf+displs[i]*extent), 
-                                      sendcnts[i], sendtype, i,
-                                      MPIR_SCATTERV_TAG, comm);
+                if ((comm_ptr->comm_kind == MPID_INTRACOMM) && (i == rank) &&
+                    (sendbuf != MPI_IN_PLACE)) {
+                    mpi_errno = MPIR_Localcopy(((char *)sendbuf+displs[rank]*extent), 
+                                               sendcnts[rank], sendtype, 
+                                               recvbuf, recvcnt, recvtype);
+                }
+                else {
+                    mpi_errno = MPIC_Isend(((char *)sendbuf+displs[i]*extent), 
+                                           sendcnts[i], sendtype, i,
+                                           MPIR_SCATTERV_TAG, comm, &reqarray[reqs++]);
+                }
 		/* --BEGIN ERROR HANDLING-- */
-                if (mpi_errno)
-		{
+                if (mpi_errno) {
 		    mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0);
 		    return mpi_errno;
 		}
 		/* --END ERROR HANDLING-- */
             }
         }
-        if (recvbuf != MPI_IN_PLACE) {
-            if (sendcnts[rank]) {
-                mpi_errno = MPIR_Localcopy(((char *)sendbuf+displs[rank]*extent), 
-                                           sendcnts[rank], sendtype, 
-                                           recvbuf, recvcnt, recvtype);
-		/* --BEGIN ERROR HANDLING-- */
-                if (mpi_errno)
-		{
-		    mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0);
-		    return mpi_errno;
-		}
-		/* --END ERROR HANDLING-- */
+        /* Wait for all of the posted non-blocking sends to finish. */
+        mpi_errno = NMPI_Waitall(reqs, reqarray, starray);
+        /* --BEGIN ERROR HANDLING-- */
+        if (mpi_errno == MPI_ERR_IN_STATUS) {
+            for (i = 0; i < reqs; i++) {
+                if (starray[i].MPI_ERROR != MPI_SUCCESS)
+                    mpi_errno = starray[i].MPI_ERROR;
             }
-        }        
-        for ( i=root+1; i<comm_size; i++ ) {
-            if (sendcnts[i]) {
-                mpi_errno = MPIC_Send(((char *)sendbuf+displs[i]*extent), 
-                                      sendcnts[i], sendtype, i, 
-                                      MPIR_SCATTERV_TAG, comm);
-		/* --BEGIN ERROR HANDLING-- */
-                if (mpi_errno)
-		{
-		    mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0);
-		    return mpi_errno;
-		}
-		/* --END ERROR HANDLING-- */
-            }
         }
-    }
 
-    else if ((comm_ptr->comm_kind == MPID_INTERCOMM) && (root == MPI_ROOT)) {
-        /* intercommunicator root */
-        int remote_comm_size;
-
-        remote_comm_size = comm_ptr->remote_size;
-        MPID_Datatype_get_extent_macro(sendtype, extent);
-        
-        for (i=0; i<remote_comm_size; i++) {
-            if (sendcnts[i]) {
-                mpi_errno = MPIC_Send(((char *)sendbuf+displs[i]*extent), 
-                                      sendcnts[i], sendtype, i,
-                                      MPIR_SCATTERV_TAG, comm);
-		/* --BEGIN ERROR HANDLING-- */
-                if (mpi_errno)
-		{
-		    mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0);
-		    return mpi_errno;
-		}
-		/* --END ERROR HANDLING-- */
-            }          
-        }
+        MPIU_Free(reqarray);
+        MPIU_Free(starray);
     }
 
     else if (root != MPI_PROC_NULL) { /* non-root nodes, and in the intercomm. case, non-root nodes on remote side */



More information about the mpich2-commits mailing list