[mpich2-commits] r3924 - mpich2/trunk/src/mpi/coll
balaji at mcs.anl.gov
balaji at mcs.anl.gov
Tue Mar 3 13:11:53 CST 2009
Author: balaji
Date: 2009-03-03 13:11:53 -0600 (Tue, 03 Mar 2009)
New Revision: 3924
Modified:
mpich2/trunk/src/mpi/coll/gatherv.c
mpich2/trunk/src/mpi/coll/scatterv.c
Log:
Change Gatherv/Scatterv to use non-blocking communication. Reviewed by thakur.
Modified: mpich2/trunk/src/mpi/coll/gatherv.c
===================================================================
--- mpich2/trunk/src/mpi/coll/gatherv.c 2009-03-03 17:44:40 UTC (rev 3923)
+++ mpich2/trunk/src/mpi/coll/gatherv.c 2009-03-03 19:11:53 UTC (rev 3924)
@@ -57,99 +57,69 @@
int mpi_errno = MPI_SUCCESS;
MPI_Comm comm;
MPI_Aint extent;
- int i;
+ int i, reqs;
+ MPI_Request *reqarray;
+ MPI_Status *starray;
comm = comm_ptr->handle;
rank = comm_ptr->rank;
/* check if multiple threads are calling this collective function */
MPIDU_ERR_CHECK_MULTIPLE_THREADS_ENTER( comm_ptr );
-
+
/* If rank == root, then I recv lots, otherwise I send */
- if ((comm_ptr->comm_kind == MPID_INTRACOMM) && (rank == root)) {
- /* intracomm root */
- comm_size = comm_ptr->local_size;
- MPID_Datatype_get_extent_macro(recvtype, extent);
+ if (((comm_ptr->comm_kind == MPID_INTRACOMM) && (root == rank)) ||
+ ((comm_ptr->comm_kind == MPID_INTERCOMM) && (root == MPI_ROOT))) {
+ if (comm_ptr->comm_kind == MPID_INTRACOMM)
+ comm_size = comm_ptr->local_size;
+ else
+ comm_size = comm_ptr->remote_size;
+ MPID_Datatype_get_extent_macro(recvtype, extent);
/* each node can make sure it is not going to overflow aint */
MPID_Ensure_Aint_fits_in_pointer(MPI_VOID_PTR_CAST_TO_MPI_AINT recvbuf +
displs[rank] * extent);
- for ( i=0; i<root; i++ ) {
+ reqarray = (MPI_Request *) MPIU_Malloc(comm_size * sizeof(MPI_Request));
+ starray = (MPI_Request *) MPIU_Malloc(comm_size * sizeof(MPI_Status));
+
+ reqs = 0;
+ for (i = 0; i < comm_size; i++) {
if (recvcnts[i]) {
- mpi_errno = MPIC_Recv(((char *)recvbuf+displs[i]*extent),
- recvcnts[i], recvtype, i,
- MPIR_GATHERV_TAG, comm,
- MPI_STATUS_IGNORE);
+ if ((comm_ptr->comm_kind == MPID_INTRACOMM) && (i == rank) &&
+ (sendbuf != MPI_IN_PLACE)) {
+ mpi_errno = MPIR_Localcopy(sendbuf, sendcnt, sendtype,
+ ((char *)recvbuf+displs[rank]*extent),
+ recvcnts[rank], recvtype);
+ }
+ else {
+ mpi_errno = MPIC_Irecv(((char *)recvbuf+displs[i]*extent),
+ recvcnts[i], recvtype, i,
+ MPIR_GATHERV_TAG, comm,
+ &reqarray[reqs++]);
+ }
/* --BEGIN ERROR HANDLING-- */
- if (mpi_errno)
- {
+ if (mpi_errno) {
mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0);
return mpi_errno;
}
/* --END ERROR HANDLING-- */
}
}
- if (sendbuf != MPI_IN_PLACE) {
- if (recvcnts[rank]) {
- MPID_Ensure_Aint_fits_in_pointer(MPI_VOID_PTR_CAST_TO_MPI_AINT recvbuf +
- displs[rank]*extent);
- mpi_errno = MPIR_Localcopy(sendbuf, sendcnt, sendtype,
- ((char *)recvbuf+displs[rank]*extent),
- recvcnts[rank], recvtype);
- /* --BEGIN ERROR HANDLING-- */
- if (mpi_errno)
- {
- mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0);
- return mpi_errno;
- }
- /* --END ERROR HANDLING-- */
+ /* ... then wait for *all* of them to finish: */
+ mpi_errno = NMPI_Waitall(reqs, reqarray, starray);
+ /* --BEGIN ERROR HANDLING-- */
+ if (mpi_errno == MPI_ERR_IN_STATUS) {
+ for (i = 0; i < reqs; i++) {
+ if (starray[i].MPI_ERROR != MPI_SUCCESS)
+ mpi_errno = starray[i].MPI_ERROR;
}
}
- MPID_Ensure_Aint_fits_in_pointer(MPI_VOID_PTR_CAST_TO_MPI_AINT recvbuf +
- displs[rank] * extent);
- for ( i=root+1; i<comm_size; i++ ) {
- if (recvcnts[i]) {
- mpi_errno = MPIC_Recv(((char *)recvbuf+displs[i]*extent),
- recvcnts[i], recvtype, i,
- MPIR_GATHERV_TAG, comm,
- MPI_STATUS_IGNORE);
- /* --BEGIN ERROR HANDLING-- */
- if (mpi_errno)
- {
- mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0);
- return mpi_errno;
- }
- /* --END ERROR HANDLING-- */
- }
- }
+ MPIU_Free(reqarray);
+ MPIU_Free(starray);
}
-
- else if ((comm_ptr->comm_kind == MPID_INTERCOMM) && (root == MPI_ROOT)) {
- /* intercommunicator root */
- remote_comm_size = comm_ptr->remote_size;
- MPID_Datatype_get_extent_macro(recvtype, extent);
- MPID_Ensure_Aint_fits_in_pointer(MPI_VOID_PTR_CAST_TO_MPI_AINT recvbuf +
- displs[rank] * extent);
- for (i=0; i<remote_comm_size; i++) {
- if (recvcnts[i]) {
- mpi_errno = MPIC_Recv(((char *)recvbuf+displs[i]*extent),
- recvcnts[i], recvtype, i,
- MPIR_GATHERV_TAG, comm,
- MPI_STATUS_IGNORE);
- /* --BEGIN ERROR HANDLING-- */
- if (mpi_errno)
- {
- mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0);
- return mpi_errno;
- }
- /* --END ERROR HANDLING-- */
- }
- }
- }
-
else if (root != MPI_PROC_NULL) { /* non-root nodes, and in the intercomm. case, non-root nodes on remote side */
if (sendcnt)
mpi_errno = MPIC_Send(sendbuf, sendcnt, sendtype, root,
Modified: mpich2/trunk/src/mpi/coll/scatterv.c
===================================================================
--- mpich2/trunk/src/mpi/coll/scatterv.c 2009-03-03 17:44:40 UTC (rev 3923)
+++ mpich2/trunk/src/mpi/coll/scatterv.c 2009-03-03 19:11:53 UTC (rev 3924)
@@ -55,10 +55,12 @@
MPID_Comm *comm_ptr )
{
static const char FCNAME[] = "MPIR_Scatterv";
- int rank, mpi_errno = MPI_SUCCESS;
+ int rank, comm_size, mpi_errno = MPI_SUCCESS;
MPI_Comm comm;
MPI_Aint extent;
- int i;
+ int i, reqs;
+ MPI_Request *reqarray;
+ MPI_Status *starray;
comm = comm_ptr->handle;
rank = comm_ptr->rank;
@@ -67,11 +69,13 @@
MPIDU_ERR_CHECK_MULTIPLE_THREADS_ENTER( comm_ptr );
/* If I'm the root, then scatter */
- if ((comm_ptr->comm_kind == MPID_INTRACOMM) && (rank == root)) {
- /* intracomm root */
- int comm_size;
-
- comm_size = comm_ptr->local_size;
+ if (((comm_ptr->comm_kind == MPID_INTRACOMM) && (root == rank)) ||
+ ((comm_ptr->comm_kind == MPID_INTERCOMM) && (root == MPI_ROOT))) {
+ if (comm_ptr->comm_kind == MPID_INTRACOMM)
+ comm_size = comm_ptr->local_size;
+ else
+ comm_size = comm_ptr->remote_size;
+
MPID_Datatype_get_extent_macro(sendtype, extent);
/* We need a check to ensure extent will fit in a
* pointer. That needs extent * (max count) but we can't get
@@ -81,74 +85,43 @@
* this? */
MPID_Ensure_Aint_fits_in_pointer(MPI_VOID_PTR_CAST_TO_MPI_AINT sendbuf + extent);
- /* We could use Isend here, but since the receivers need to execute
- a simple Recv, it may not make much difference in performance,
- and using the blocking version is simpler */
- for ( i=0; i<root; i++ ) {
+ reqarray = (MPI_Request *) MPIU_Malloc(comm_size * sizeof(MPI_Request));
+ starray = (MPI_Request *) MPIU_Malloc(comm_size * sizeof(MPI_Status));
+
+ reqs = 0;
+ for (i = 0; i < comm_size; i++) {
if (sendcnts[i]) {
- mpi_errno = MPIC_Send(((char *)sendbuf+displs[i]*extent),
- sendcnts[i], sendtype, i,
- MPIR_SCATTERV_TAG, comm);
+ if ((comm_ptr->comm_kind == MPID_INTRACOMM) && (i == rank) &&
+ (sendbuf != MPI_IN_PLACE)) {
+ mpi_errno = MPIR_Localcopy(((char *)sendbuf+displs[rank]*extent),
+ sendcnts[rank], sendtype,
+ recvbuf, recvcnt, recvtype);
+ }
+ else {
+ mpi_errno = MPIC_Isend(((char *)sendbuf+displs[i]*extent),
+ sendcnts[i], sendtype, i,
+ MPIR_SCATTERV_TAG, comm, &reqarray[reqs++]);
+ }
/* --BEGIN ERROR HANDLING-- */
- if (mpi_errno)
- {
+ if (mpi_errno) {
mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0);
return mpi_errno;
}
/* --END ERROR HANDLING-- */
}
}
- if (recvbuf != MPI_IN_PLACE) {
- if (sendcnts[rank]) {
- mpi_errno = MPIR_Localcopy(((char *)sendbuf+displs[rank]*extent),
- sendcnts[rank], sendtype,
- recvbuf, recvcnt, recvtype);
- /* --BEGIN ERROR HANDLING-- */
- if (mpi_errno)
- {
- mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0);
- return mpi_errno;
- }
- /* --END ERROR HANDLING-- */
+ /* ... then wait for *all* of them to finish: */
+ mpi_errno = NMPI_Waitall(reqs, reqarray, starray);
+ /* --BEGIN ERROR HANDLING-- */
+ if (mpi_errno == MPI_ERR_IN_STATUS) {
+ for (i = 0; i < reqs; i++) {
+ if (starray[i].MPI_ERROR != MPI_SUCCESS)
+ mpi_errno = starray[i].MPI_ERROR;
}
- }
- for ( i=root+1; i<comm_size; i++ ) {
- if (sendcnts[i]) {
- mpi_errno = MPIC_Send(((char *)sendbuf+displs[i]*extent),
- sendcnts[i], sendtype, i,
- MPIR_SCATTERV_TAG, comm);
- /* --BEGIN ERROR HANDLING-- */
- if (mpi_errno)
- {
- mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0);
- return mpi_errno;
- }
- /* --END ERROR HANDLING-- */
- }
}
- }
- else if ((comm_ptr->comm_kind == MPID_INTERCOMM) && (root == MPI_ROOT)) {
- /* intercommunicator root */
- int remote_comm_size;
-
- remote_comm_size = comm_ptr->remote_size;
- MPID_Datatype_get_extent_macro(sendtype, extent);
-
- for (i=0; i<remote_comm_size; i++) {
- if (sendcnts[i]) {
- mpi_errno = MPIC_Send(((char *)sendbuf+displs[i]*extent),
- sendcnts[i], sendtype, i,
- MPIR_SCATTERV_TAG, comm);
- /* --BEGIN ERROR HANDLING-- */
- if (mpi_errno)
- {
- mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0);
- return mpi_errno;
- }
- /* --END ERROR HANDLING-- */
- }
- }
+ MPIU_Free(reqarray);
+ MPIU_Free(starray);
}
else if (root != MPI_PROC_NULL) { /* non-root nodes, and in the intercomm. case, non-root nodes on remote side */
More information about the mpich2-commits mailing list