diff --git a/src/mpid/ch4/netmod/ofi/Makefile.mk b/src/mpid/ch4/netmod/ofi/Makefile.mk index 490aa18e0f4..6b107f60921 100644 --- a/src/mpid/ch4/netmod/ofi/Makefile.mk +++ b/src/mpid/ch4/netmod/ofi/Makefile.mk @@ -16,6 +16,7 @@ mpi_core_sources += src/mpid/ch4/netmod/ofi/func_table.c \ src/mpid/ch4/netmod/ofi/ofi_win.c \ src/mpid/ch4/netmod/ofi/ofi_part.c \ src/mpid/ch4/netmod/ofi/ofi_events.c \ + src/mpid/ch4/netmod/ofi/ofi_huge.c \ src/mpid/ch4/netmod/ofi/ofi_progress.c \ src/mpid/ch4/netmod/ofi/ofi_am_events.c \ src/mpid/ch4/netmod/ofi/ofi_nic.c \ diff --git a/src/mpid/ch4/netmod/ofi/globals.c b/src/mpid/ch4/netmod/ofi/globals.c index 67ea56bf50a..40f534b9697 100644 --- a/src/mpid/ch4/netmod/ofi/globals.c +++ b/src/mpid/ch4/netmod/ofi/globals.c @@ -7,10 +7,10 @@ #include "ofi_impl.h" MPIDI_OFI_global_t MPIDI_OFI_global; -MPIDI_OFI_huge_recv_t *MPIDI_unexp_huge_recv_head = NULL; -MPIDI_OFI_huge_recv_t *MPIDI_unexp_huge_recv_tail = NULL; -MPIDI_OFI_huge_recv_list_t *MPIDI_posted_huge_recv_head = NULL; -MPIDI_OFI_huge_recv_list_t *MPIDI_posted_huge_recv_tail = NULL; +MPIDI_OFI_huge_recv_list_t *MPIDI_huge_ctrl_head = NULL; +MPIDI_OFI_huge_recv_list_t *MPIDI_huge_ctrl_tail = NULL; +MPIDI_OFI_huge_recv_list_t *MPIDI_huge_recv_head = NULL; +MPIDI_OFI_huge_recv_list_t *MPIDI_huge_recv_tail = NULL; unsigned long long PVAR_COUNTER_nic_sent_bytes_count[MPIDI_OFI_MAX_NICS] ATTRIBUTE((unused)); unsigned long long PVAR_COUNTER_nic_recvd_bytes_count[MPIDI_OFI_MAX_NICS] ATTRIBUTE((unused)); diff --git a/src/mpid/ch4/netmod/ofi/ofi_control.h b/src/mpid/ch4/netmod/ofi/ofi_control.h deleted file mode 100644 index 60d8539dbde..00000000000 --- a/src/mpid/ch4/netmod/ofi/ofi_control.h +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Copyright (C) by Argonne National Laboratory - * See COPYRIGHT in top-level directory - */ - -#ifndef OFI_CONTROL_H_INCLUDED -#define OFI_CONTROL_H_INCLUDED - -#include "ofi_am_impl.h" - -MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_do_control_send(MPIDI_OFI_send_control_t * control, - char *send_buf, - size_t msgsize, - int rank, MPIR_Comm * comm_ptr, - MPIR_Request * ackreq) -{ - int mpi_errno = MPI_SUCCESS; - MPIR_FUNC_ENTER; - - control->origin_rank = comm_ptr->rank; - control->send_buf = (uintptr_t) send_buf; - control->msgsize = msgsize; - control->comm_id = comm_ptr->context_id; - control->ackreq = ackreq; - - mpi_errno = MPIDI_OFI_do_inject(rank, comm_ptr, - MPIDI_OFI_INTERNAL_HANDLER_CONTROL, - (void *) control, sizeof(*control), 0, 0); - - MPIR_FUNC_EXIT; - return mpi_errno; -} - - -#endif /* OFI_CONTROL_H_INCLUDED */ diff --git a/src/mpid/ch4/netmod/ofi/ofi_events.c b/src/mpid/ch4/netmod/ofi/ofi_events.c index 00520fafa8f..f8c4c09e171 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_events.c +++ b/src/mpid/ch4/netmod/ofi/ofi_events.c @@ -14,10 +14,8 @@ static int peek_event(int vni, struct fi_cq_tagged_entry *wc, MPIR_Request * rreq); static int peek_empty_event(int vni, struct fi_cq_tagged_entry *wc, MPIR_Request * rreq); -static int recv_huge_event(int vni, struct fi_cq_tagged_entry *wc, MPIR_Request * rreq); static int send_huge_event(int vni, struct fi_cq_tagged_entry *wc, MPIR_Request * sreq); static int ssend_ack_event(int vni, struct fi_cq_tagged_entry *wc, MPIR_Request * sreq); -static uintptr_t recv_rbase(MPIDI_OFI_huge_recv_t * recv); static int chunk_done_event(int vni, struct fi_cq_tagged_entry *wc, MPIR_Request * req); static int inject_emu_event(int vni, struct fi_cq_tagged_entry *wc, MPIR_Request * req); static int accept_probe_event(int vni, struct fi_cq_tagged_entry *wc, MPIR_Request * rreq); @@ -32,77 +30,25 @@ static int am_read_event(int vni, struct fi_cq_tagged_entry *wc, MPIR_Request * static int peek_event(int vni, struct fi_cq_tagged_entry *wc, MPIR_Request * rreq) { int mpi_errno = MPI_SUCCESS; - size_t count = 0; MPIR_FUNC_ENTER; - rreq->status.MPI_SOURCE = MPIDI_OFI_cqe_get_source(wc, false); - rreq->status.MPI_TAG = MPIDI_OFI_init_get_tag(wc->tag); - rreq->status.MPI_ERROR = MPI_SUCCESS; if (MPIDI_OFI_HUGE_SEND & wc->tag) { - MPIDI_OFI_huge_recv_t *list_ptr; - bool found_msg = false; - - /* If this is a huge message, find the control message on the unexpected list that matches - * with this and return the size in that. */ - LL_FOREACH(MPIDI_unexp_huge_recv_head, list_ptr) { - uint64_t context_id = MPIDI_OFI_CONTEXT_MASK & wc->tag; - uint64_t tag = MPIDI_OFI_TAG_MASK & wc->tag; - if (list_ptr->remote_info.comm_id == context_id && - list_ptr->remote_info.origin_rank == MPIDI_OFI_cqe_get_source(wc, false) && - list_ptr->remote_info.tag == tag) { - count = list_ptr->remote_info.msgsize; - found_msg = true; - } - } - if (!found_msg) { - MPIDI_OFI_huge_recv_t *recv_elem; - MPIDI_OFI_huge_recv_list_t *huge_list_ptr; - - /* Create an element in the posted list that only indicates a peek and will be - * deleted as soon as it's fulfilled without being matched. */ - recv_elem = (MPIDI_OFI_huge_recv_t *) MPL_calloc(sizeof(*recv_elem), 1, MPL_MEM_COMM); - MPIR_ERR_CHKANDJUMP(recv_elem == NULL, mpi_errno, MPI_ERR_OTHER, "**nomem"); - recv_elem->peek = true; - MPIR_Comm *comm_ptr = rreq->comm; - recv_elem->comm_ptr = comm_ptr; - MPIDIU_map_set(MPIDI_OFI_global.huge_recv_counters, rreq->handle, recv_elem, - MPL_MEM_BUFFER); - - huge_list_ptr = - (MPIDI_OFI_huge_recv_list_t *) MPL_calloc(sizeof(*huge_list_ptr), 1, MPL_MEM_COMM); - MPIR_ERR_CHKANDJUMP(huge_list_ptr == NULL, mpi_errno, MPI_ERR_OTHER, "**nomem"); - recv_elem->remote_info.comm_id = huge_list_ptr->comm_id = - MPIDI_OFI_CONTEXT_MASK & wc->tag; - recv_elem->remote_info.origin_rank = huge_list_ptr->rank = - MPIDI_OFI_cqe_get_source(wc, false); - recv_elem->remote_info.tag = huge_list_ptr->tag = MPIDI_OFI_TAG_MASK & wc->tag; - recv_elem->localreq = huge_list_ptr->rreq = rreq; - recv_elem->event_id = MPIDI_OFI_EVENT_GET_HUGE; - recv_elem->done_fn = MPIDI_OFI_recv_event; - recv_elem->wc = *wc; - if (MPIDI_OFI_COMM(comm_ptr).enable_striping) { - recv_elem->cur_offset = MPIDI_OFI_STRIPE_CHUNK_SIZE; - } else { - recv_elem->cur_offset = MPIDI_OFI_global.max_msg_size; - } - - LL_APPEND(MPIDI_posted_huge_recv_head, MPIDI_posted_huge_recv_tail, huge_list_ptr); - goto fn_exit; - } - } else { - /* Otherwise just get the size of the message we've already received. */ - count = wc->len; + mpi_errno = MPIDI_OFI_peek_huge_event(vni, wc, rreq); + goto fn_exit; } - MPIR_STATUS_SET_COUNT(rreq->status, count); + + rreq->status.MPI_SOURCE = MPIDI_OFI_cqe_get_source(wc, false); + rreq->status.MPI_TAG = MPIDI_OFI_init_get_tag(wc->tag); + rreq->status.MPI_ERROR = MPI_SUCCESS; + MPIR_STATUS_SET_COUNT(rreq->status, wc->len); /* util_id should be the last thing to change in rreq. Reason is * we use util_id to indicate peek_event has completed and all the * relevant values have been copied to rreq. */ MPL_atomic_release_store_int(&(MPIDI_OFI_REQUEST(rreq, util_id)), MPIDI_OFI_PEEK_FOUND); + fn_exit: MPIR_FUNC_EXIT; return mpi_errno; - fn_fail: - goto fn_exit; } static int peek_empty_event(int vni, struct fi_cq_tagged_entry *wc, MPIR_Request * rreq) @@ -134,110 +80,6 @@ static int peek_empty_event(int vni, struct fi_cq_tagged_entry *wc, MPIR_Request return MPI_SUCCESS; } -/* If we posted a huge receive, this event gets called to translate the - * completion queue entry into a get huge event */ -static int recv_huge_event(int vni, struct fi_cq_tagged_entry *wc, MPIR_Request * rreq) -{ - int mpi_errno = MPI_SUCCESS; - MPIDI_OFI_huge_recv_t *recv_elem = NULL; - MPIR_Comm *comm_ptr; - MPIR_FUNC_ENTER; - - bool ready_to_get = false; - /* Check that the sender didn't underflow the message by sending less than - * the huge message threshold. When striping is enabled underflow occurs if - * the sender sends < MPIDI_OFI_STRIPE_CHUNK_SIZE through the huge message protocol - * or < MPIDI_OFI_global.stripe_threshold through normal send */ - if (((wc->len < MPIDI_OFI_STRIPE_CHUNK_SIZE || - (wc->len > MPIDI_OFI_STRIPE_CHUNK_SIZE && wc->len < MPIDI_OFI_global.stripe_threshold)) && - MPIDI_OFI_COMM(rreq->comm).enable_striping) || - (wc->len < MPIDI_OFI_global.max_msg_size && !MPIDI_OFI_COMM(rreq->comm).enable_striping)) { - return MPIDI_OFI_recv_event(vni, wc, rreq, MPIDI_OFI_REQUEST(rreq, event_id)); - } - - comm_ptr = rreq->comm; - MPIR_T_PVAR_COUNTER_INC(MULTINIC, nic_recvd_bytes_count[MPIDI_OFI_REQUEST(rreq, nic_num)], - wc->len); - /* Check to see if the tracker is already in the unexpected list. - * Otherwise, allocate one. */ - { - MPIDI_OFI_huge_recv_t *list_ptr; - - MPL_DBG_MSG_FMT(MPIR_DBG_PT2PT, VERBOSE, - (MPL_DBG_FDEST, "SEARCHING HUGE UNEXPECTED LIST: (%d, %d, %llu)", - comm_ptr->context_id, MPIDI_OFI_cqe_get_source(wc, false), - (MPIDI_OFI_TAG_MASK & wc->tag))); - - LL_FOREACH(MPIDI_unexp_huge_recv_head, list_ptr) { - if (list_ptr->remote_info.comm_id == comm_ptr->context_id && - list_ptr->remote_info.origin_rank == MPIDI_OFI_cqe_get_source(wc, false) && - list_ptr->remote_info.tag == (MPIDI_OFI_TAG_MASK & wc->tag)) { - MPL_DBG_MSG_FMT(MPIR_DBG_PT2PT, VERBOSE, - (MPL_DBG_FDEST, "MATCHED HUGE UNEXPECTED LIST: (%d, %d, %llu, %d)", - comm_ptr->context_id, MPIDI_OFI_cqe_get_source(wc, false), - (MPIDI_OFI_TAG_MASK & wc->tag), rreq->handle)); - - LL_DELETE(MPIDI_unexp_huge_recv_head, MPIDI_unexp_huge_recv_tail, list_ptr); - - recv_elem = list_ptr; - MPIDIU_map_set(MPIDI_OFI_global.huge_recv_counters, rreq->handle, recv_elem, - MPL_MEM_COMM); - break; - } - } - } - - if (recv_elem) { - ready_to_get = true; - } else { - MPIDI_OFI_huge_recv_list_t *list_ptr; - - MPL_DBG_MSG_FMT(MPIR_DBG_PT2PT, VERBOSE, - (MPL_DBG_FDEST, "CREATING HUGE POSTED ENTRY: (%d, %d, %llu)", - comm_ptr->context_id, MPIDI_OFI_cqe_get_source(wc, false), - (MPIDI_OFI_TAG_MASK & wc->tag))); - - recv_elem = (MPIDI_OFI_huge_recv_t *) MPL_calloc(sizeof(*recv_elem), 1, MPL_MEM_BUFFER); - MPIR_ERR_CHKANDJUMP(recv_elem == NULL, mpi_errno, MPI_ERR_OTHER, "**nomem"); - MPIDIU_map_set(MPIDI_OFI_global.huge_recv_counters, rreq->handle, recv_elem, - MPL_MEM_BUFFER); - - list_ptr = (MPIDI_OFI_huge_recv_list_t *) MPL_calloc(sizeof(*list_ptr), 1, MPL_MEM_BUFFER); - if (!list_ptr) - MPIR_ERR_SETANDJUMP(mpi_errno, MPI_ERR_OTHER, "**nomem"); - - list_ptr->comm_id = comm_ptr->context_id; - list_ptr->rank = MPIDI_OFI_cqe_get_source(wc, false); - list_ptr->tag = (MPIDI_OFI_TAG_MASK & wc->tag); - list_ptr->rreq = rreq; - - LL_APPEND(MPIDI_posted_huge_recv_head, MPIDI_posted_huge_recv_tail, list_ptr); - } - - /* Plug the information for the huge event into the receive request and go - * to the MPIDI_OFI_get_huge_event function. */ - recv_elem->event_id = MPIDI_OFI_EVENT_GET_HUGE; - recv_elem->peek = false; - recv_elem->comm_ptr = comm_ptr; - recv_elem->localreq = rreq; - recv_elem->done_fn = MPIDI_OFI_recv_event; - recv_elem->wc = *wc; - if (MPIDI_OFI_COMM(comm_ptr).enable_striping) { - recv_elem->cur_offset = MPIDI_OFI_STRIPE_CHUNK_SIZE; - } else { - recv_elem->cur_offset = MPIDI_OFI_global.max_msg_size; - } - if (ready_to_get) { - MPIDI_OFI_get_huge_event(vni, NULL, (MPIR_Request *) recv_elem); - } - - fn_exit: - MPIR_FUNC_EXIT; - return mpi_errno; - fn_fail: - goto fn_exit; -} - static int send_huge_event(int vni, struct fi_cq_tagged_entry *wc, MPIR_Request * sreq) { int mpi_errno = MPI_SUCCESS; @@ -248,21 +90,11 @@ static int send_huge_event(int vni, struct fi_cq_tagged_entry *wc, MPIR_Request if (c == 0) { MPIR_Comm *comm; - void *ptr; struct fid_mr **huge_send_mrs; comm = sreq->comm; num_nics = MPIDI_OFI_COMM(comm).enable_striping ? MPIDI_OFI_global.num_nics : 1; - /* Look for the memory region using the sreq handle */ - ptr = MPIDIU_map_lookup(MPIDI_OFI_global.huge_send_counters, sreq->handle); - MPIR_Assert(ptr != MPIDIU_MAP_NOT_FOUND); - - huge_send_mrs = (struct fid_mr **) ptr; - - /* Send a cleanup message to the receivier and clean up local - * resources. */ - /* Clean up the local counter */ - MPIDIU_map_erase(MPIDI_OFI_global.huge_send_counters, sreq->handle); + huge_send_mrs = MPIDI_OFI_REQUEST(sreq, huge.send_mrs); /* Clean up the memory region */ if (!MPIDI_OFI_ENABLE_MR_PROV_KEY) { @@ -306,129 +138,6 @@ static int ssend_ack_event(int vni, struct fi_cq_tagged_entry *wc, MPIR_Request return mpi_errno; } -static uintptr_t recv_rbase(MPIDI_OFI_huge_recv_t * recv_elem) -{ - if (!MPIDI_OFI_ENABLE_MR_VIRT_ADDRESS) { - return 0; - } else { - return recv_elem->remote_info.send_buf; - } -} - -/* Note: MPIDI_OFI_get_huge_event is invoked from three places -- - * 1. In recv_huge_event, when recv buffer is matched and first chunk received, and - * when control message (with remote info) has also been received. - * 2. In MPIDI_OFI_get_huge, as a callback when control message is received, and - * when first chunk has been matched and received. - * - * recv_huge_event will fill the local request information, and the control message - * callback will fill the remote (sender) information. Lastly -- - * - * 3. As the event function when RDMA read (issued here) completes. - */ -int MPIDI_OFI_get_huge_event(int vni, struct fi_cq_tagged_entry *wc, MPIR_Request * req) -{ - int mpi_errno = MPI_SUCCESS; - MPIDI_OFI_huge_recv_t *recv_elem = (MPIDI_OFI_huge_recv_t *) req; - uint64_t remote_key; - size_t bytesLeft, bytesToGet; - MPIR_FUNC_ENTER; - - void *recv_buf = MPIDI_OFI_REQUEST(recv_elem->localreq, util.iov.iov_base); - - if (MPIDI_OFI_COMM(recv_elem->comm_ptr).enable_striping) { - /* Subtract one stripe_chunk_size because we send the first chunk via a regular message - * instead of the memory region */ - recv_elem->stripe_size = (recv_elem->remote_info.msgsize - MPIDI_OFI_STRIPE_CHUNK_SIZE) - / MPIDI_OFI_global.num_nics; /* striping */ - - if (recv_elem->stripe_size > MPIDI_OFI_global.max_msg_size) { - recv_elem->stripe_size = MPIDI_OFI_global.max_msg_size; - } - if (recv_elem->chunks_outstanding) - recv_elem->chunks_outstanding--; - bytesLeft = recv_elem->remote_info.msgsize - recv_elem->cur_offset; - bytesToGet = (bytesLeft <= recv_elem->stripe_size) ? bytesLeft : recv_elem->stripe_size; - } else { - /* Subtract one max_msg_size because we send the first chunk via a regular message - * instead of the memory region */ - bytesLeft = recv_elem->remote_info.msgsize - recv_elem->cur_offset; - bytesToGet = (bytesLeft <= MPIDI_OFI_global.max_msg_size) ? - bytesLeft : MPIDI_OFI_global.max_msg_size; - } - if (bytesToGet == 0ULL && recv_elem->chunks_outstanding == 0) { - MPIDI_OFI_send_control_t ctrl; - /* recv_elem->localreq may be freed during done_fn. - * Need to backup the handle here for later use with MPIDIU_map_erase. */ - uint64_t key_to_erase = recv_elem->localreq->handle; - recv_elem->wc.len = recv_elem->cur_offset; - recv_elem->done_fn(vni, &recv_elem->wc, recv_elem->localreq, recv_elem->event_id); - ctrl.type = MPIDI_OFI_CTRL_HUGEACK; - mpi_errno = - MPIDI_OFI_do_control_send(&ctrl, NULL, 0, recv_elem->remote_info.origin_rank, - recv_elem->comm_ptr, recv_elem->remote_info.ackreq); - MPIR_ERR_CHECK(mpi_errno); - - MPIDIU_map_erase(MPIDI_OFI_global.huge_recv_counters, key_to_erase); - MPL_free(recv_elem); - - goto fn_exit; - } - - int nic = 0; - int vni_src = recv_elem->remote_info.vni_src; - int vni_dst = recv_elem->remote_info.vni_dst; - if (MPIDI_OFI_COMM(recv_elem->comm_ptr).enable_striping) { /* if striping enabled */ - MPIDI_OFI_cntr_incr(recv_elem->comm_ptr, vni_src, nic); - if (recv_elem->cur_offset >= MPIDI_OFI_STRIPE_CHUNK_SIZE && bytesLeft > 0) { - for (nic = 0; nic < MPIDI_OFI_global.num_nics; nic++) { - int ctx_idx = MPIDI_OFI_get_ctx_index(recv_elem->comm_ptr, vni_dst, nic); - remote_key = recv_elem->remote_info.rma_keys[nic]; - - bytesLeft = recv_elem->remote_info.msgsize - recv_elem->cur_offset; - if (bytesLeft <= 0) { - break; - } - bytesToGet = - (bytesLeft <= recv_elem->stripe_size) ? bytesLeft : recv_elem->stripe_size; - - MPIDI_OFI_CALL_RETRY(fi_read(MPIDI_OFI_global.ctx[ctx_idx].tx, (void *) ((char *) recv_buf + recv_elem->cur_offset), /* local buffer */ - bytesToGet, /* bytes */ - NULL, /* descriptor */ - MPIDI_OFI_comm_to_phys(recv_elem->comm_ptr, recv_elem->remote_info.origin_rank, nic, vni_dst, vni_src), recv_rbase(recv_elem) + recv_elem->cur_offset, /* remote maddr */ - remote_key, /* Key */ - (void *) &recv_elem->context), nic, /* Context */ - rdma_readfrom, FALSE); - MPIR_T_PVAR_COUNTER_INC(MULTINIC, nic_recvd_bytes_count[nic], bytesToGet); - MPIR_T_PVAR_COUNTER_INC(MULTINIC, striped_nic_recvd_bytes_count[nic], bytesToGet); - recv_elem->cur_offset += bytesToGet; - recv_elem->chunks_outstanding++; - } - } - } else { - int ctx_idx = MPIDI_OFI_get_ctx_index(recv_elem->comm_ptr, vni_src, nic); - remote_key = recv_elem->remote_info.rma_keys[nic]; - MPIDI_OFI_cntr_incr(recv_elem->comm_ptr, vni_src, nic); - MPIDI_OFI_CALL_RETRY(fi_read(MPIDI_OFI_global.ctx[ctx_idx].tx, /* endpoint */ - (void *) ((char *) recv_buf + recv_elem->cur_offset), /* local buffer */ - bytesToGet, /* bytes */ - NULL, /* descriptor */ - MPIDI_OFI_comm_to_phys(recv_elem->comm_ptr, recv_elem->remote_info.origin_rank, nic, vni_src, vni_dst), /* Destination */ - recv_rbase(recv_elem) + recv_elem->cur_offset, /* remote maddr */ - remote_key, /* Key */ - (void *) &recv_elem->context), vni_src, rdma_readfrom, /* Context */ - FALSE); - MPIR_T_PVAR_COUNTER_INC(MULTINIC, nic_recvd_bytes_count[nic], bytesToGet); - recv_elem->cur_offset += bytesToGet; - } - - fn_exit: - MPIR_FUNC_EXIT; - return mpi_errno; - fn_fail: - goto fn_exit; -} - static int chunk_done_event(int vni, struct fi_cq_tagged_entry *wc, MPIR_Request * req) { int c; @@ -769,7 +478,11 @@ int MPIDI_OFI_dispatch_function(int vni, struct fi_cq_tagged_entry *wc, MPIR_Req break; case MPIDI_OFI_EVENT_RECV_HUGE: - mpi_errno = recv_huge_event(vni, wc, req); + if (wc->tag & MPIDI_OFI_HUGE_SEND) { + mpi_errno = MPIDI_OFI_recv_huge_event(vni, wc, req); + } else { + mpi_errno = MPIDI_OFI_recv_event(vni, wc, req, MPIDI_OFI_EVENT_RECV_HUGE); + } break; case MPIDI_OFI_EVENT_RECV_PACK: @@ -862,10 +575,10 @@ int MPIDI_OFI_handle_cq_error(int vni, int nic, ssize_t ret) break; case MPIR_REQUEST_KIND__RECV: + req->status.MPI_ERROR = MPI_ERR_TRUNCATE; mpi_errno = MPIDI_OFI_dispatch_function(vni, (struct fi_cq_tagged_entry *) &e, req); - req->status.MPI_ERROR = MPI_ERR_TRUNCATE; break; default: diff --git a/src/mpid/ch4/netmod/ofi/ofi_events.h b/src/mpid/ch4/netmod/ofi/ofi_events.h index 79f31916591..f88a3060cdb 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_events.h +++ b/src/mpid/ch4/netmod/ofi/ofi_events.h @@ -9,11 +9,9 @@ #include "ofi_impl.h" #include "ofi_am_impl.h" #include "ofi_am_events.h" -#include "ofi_control.h" #include "utlist.h" int MPIDI_OFI_rma_done_event(int vni, struct fi_cq_tagged_entry *wc, MPIR_Request * in_req); -int MPIDI_OFI_get_huge_event(int vni, struct fi_cq_tagged_entry *wc, MPIR_Request * req); int MPIDI_OFI_dispatch_function(int vni, struct fi_cq_tagged_entry *wc, MPIR_Request * req); MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_cqe_get_source(struct fi_cq_tagged_entry *wc, bool has_err) @@ -55,8 +53,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_recv_event(int vni, struct fi_cq_tagged_e size_t count; MPIR_FUNC_ENTER; + if (wc->tag & MPIDI_OFI_HUGE_SEND) { + mpi_errno = MPIDI_OFI_recv_huge_event(vni, wc, rreq); + goto fn_exit; + } rreq->status.MPI_SOURCE = MPIDI_OFI_cqe_get_source(wc, true); - rreq->status.MPI_ERROR = MPIDI_OFI_idata_get_error_bits(wc->data); + if (!rreq->status.MPI_ERROR) { + rreq->status.MPI_ERROR = MPIDI_OFI_idata_get_error_bits(wc->data); + } rreq->status.MPI_TAG = MPIDI_OFI_init_get_tag(wc->tag); count = wc->len; MPIR_STATUS_SET_COUNT(rreq->status, count); @@ -131,7 +135,6 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_recv_event(int vni, struct fi_cq_tagged_e MPIDIU_request_complete(rreq); - /* Polling loop will check for truncation */ fn_exit: MPIR_FUNC_EXIT; return mpi_errno; diff --git a/src/mpid/ch4/netmod/ofi/ofi_huge.c b/src/mpid/ch4/netmod/ofi/ofi_huge.c new file mode 100644 index 00000000000..ef227f73681 --- /dev/null +++ b/src/mpid/ch4/netmod/ofi/ofi_huge.c @@ -0,0 +1,394 @@ +/* + * Copyright (C) by Argonne National Laboratory + * See COPYRIGHT in top-level directory + */ + +#include +#include "ofi_impl.h" +#include "ofi_events.h" + +static int get_huge(MPIR_Request * rreq); +static int get_huge_complete(MPIR_Request * rreq); + +static int get_huge(MPIR_Request * rreq) +{ + int mpi_errno = MPI_SUCCESS; + MPIDI_OFI_huge_remote_info_t *info = MPIDI_OFI_REQUEST(rreq, huge.remote_info); + + MPI_Aint cur_offset; + if (MPIDI_OFI_COMM(rreq->comm).enable_striping) { + cur_offset = MPIDI_OFI_STRIPE_CHUNK_SIZE; + } else { + cur_offset = MPIDI_OFI_global.max_msg_size; + } + + MPI_Aint data_sz = MPIDI_OFI_REQUEST(rreq, util.iov.iov_len); + + if (data_sz < info->msgsize) { + rreq->status.MPI_ERROR = MPI_ERR_TRUNCATE; + info->msgsize = data_sz; + } + + if (data_sz < cur_offset) { + /* huge message sent to small recv buffer */ + mpi_errno = get_huge_complete(rreq); + MPIR_ERR_CHECK(mpi_errno); + goto fn_exit; + } + + MPIDI_OFI_huge_recv_t *recv_elem = NULL; + recv_elem = (MPIDI_OFI_huge_recv_t *) MPL_calloc(sizeof(*recv_elem), 1, MPL_MEM_BUFFER); + MPIR_ERR_CHKANDJUMP(recv_elem == NULL, mpi_errno, MPI_ERR_OTHER, "**nomem"); + recv_elem->event_id = MPIDI_OFI_EVENT_GET_HUGE; + recv_elem->localreq = rreq; + recv_elem->cur_offset = cur_offset; + + MPIDI_OFI_get_huge_event(info->vni_dst, NULL, (MPIR_Request *) recv_elem); + + fn_exit: + return mpi_errno; + fn_fail: + goto fn_exit; +} + +static int get_huge_complete(MPIR_Request * rreq) +{ + int mpi_errno = MPI_SUCCESS; + MPIR_FUNC_ENTER; + + MPIDI_OFI_huge_remote_info_t *info = MPIDI_OFI_REQUEST(rreq, huge.remote_info); + + /* note: it's receiver ack sender */ + int vni_remote = info->vni_src; + int vni_local = info->vni_dst; + + struct fi_cq_tagged_entry wc; + wc.len = info->msgsize; + wc.data = info->origin_rank; + wc.tag = info->tag; + MPIDI_OFI_recv_event(vni_local, &wc, rreq, MPIDI_OFI_EVENT_GET_HUGE); + + MPIDI_OFI_send_control_t ctrl; + ctrl.type = MPIDI_OFI_CTRL_HUGEACK; + ctrl.u.huge_ack.ackreq = info->ackreq; + mpi_errno = MPIDI_NM_am_send_hdr(info->origin_rank, rreq->comm, + MPIDI_OFI_INTERNAL_HANDLER_CONTROL, + &ctrl, sizeof(ctrl), vni_local, vni_remote); + MPIR_ERR_CHECK(mpi_errno); + + MPL_free(info); + + fn_exit: + MPIR_FUNC_EXIT; + return mpi_errno; + fn_fail: + goto fn_exit; +} + +/* this function called by recv event of a huge message */ +int MPIDI_OFI_recv_huge_event(int vni, struct fi_cq_tagged_entry *wc, MPIR_Request * rreq) +{ + int mpi_errno = MPI_SUCCESS; + MPIR_Comm *comm_ptr; + MPIR_FUNC_ENTER; + + bool ready_to_get = false; + if (MPIDI_OFI_REQUEST(rreq, event_id) != MPIDI_OFI_EVENT_RECV_HUGE) { + /* huge send recved by a small buffer */ + } else if (MPIDI_OFI_COMM(rreq->comm).enable_striping) { + MPIR_Assert(wc->len == MPIDI_OFI_STRIPE_CHUNK_SIZE); + } else { + MPIR_Assert(wc->len == MPIDI_OFI_global.max_msg_size); + } + + comm_ptr = rreq->comm; + MPIR_T_PVAR_COUNTER_INC(MULTINIC, nic_recvd_bytes_count[MPIDI_OFI_REQUEST(rreq, nic_num)], + wc->len); + if (MPIDI_OFI_REQUEST(rreq, huge.remote_info)) { + /* this is mrecv, we already got remote info */ + ready_to_get = true; + } else { + /* Check for remote control info */ + MPIDI_OFI_huge_recv_list_t *list_ptr; + MPIR_Context_id_t comm_id = comm_ptr->recvcontext_id; + int rank = MPIDI_OFI_cqe_get_source(wc, false); + int tag = (MPIDI_OFI_TAG_MASK & wc->tag); + + LL_FOREACH(MPIDI_huge_ctrl_head, list_ptr) { + if (list_ptr->comm_id == comm_id && list_ptr->rank == rank && list_ptr->tag == tag) { + MPIDI_OFI_REQUEST(rreq, huge.remote_info) = list_ptr->u.info; + LL_DELETE(MPIDI_huge_ctrl_head, MPIDI_huge_ctrl_tail, list_ptr); + MPL_free(list_ptr); + ready_to_get = true; + break; + } + } + } + + if (!ready_to_get) { + MPIDI_OFI_huge_recv_list_t *list_ptr; + + list_ptr = (MPIDI_OFI_huge_recv_list_t *) MPL_calloc(sizeof(*list_ptr), 1, MPL_MEM_BUFFER); + if (!list_ptr) + MPIR_ERR_SETANDJUMP(mpi_errno, MPI_ERR_OTHER, "**nomem"); + + list_ptr->comm_id = comm_ptr->recvcontext_id; + list_ptr->rank = MPIDI_OFI_cqe_get_source(wc, false); + list_ptr->tag = (MPIDI_OFI_TAG_MASK & wc->tag); + list_ptr->u.rreq = rreq; + + LL_APPEND(MPIDI_huge_recv_head, MPIDI_huge_recv_tail, list_ptr); + /* control handler will finish the recv */ + } else { + /* proceed to get the huge message */ + mpi_errno = get_huge(rreq); + MPIR_ERR_CHECK(mpi_errno); + } + + fn_exit: + MPIR_FUNC_EXIT; + return mpi_errno; + fn_fail: + goto fn_exit; +} + +/* This function is called when we receive a huge control message */ +int MPIDI_OFI_recv_huge_control(MPIR_Context_id_t comm_id, int rank, int tag, + MPIDI_OFI_huge_remote_info_t * info_ptr) +{ + int mpi_errno = MPI_SUCCESS; + MPIR_FUNC_ENTER; + + MPIDI_OFI_huge_recv_list_t *list_ptr; + MPIR_Request *rreq = NULL; + MPIDI_OFI_huge_remote_info_t *info; + + /* need persist the info. It will eventually get freed at recv completion */ + info = MPL_malloc(sizeof(MPIDI_OFI_huge_remote_info_t), MPL_MEM_OTHER); + MPIR_Assert(info); + memcpy(info, info_ptr, sizeof(*info)); + + /* If there has been a posted receive, search through the list of unmatched + * receives to find the one that goes with the incoming message. */ + LL_FOREACH(MPIDI_huge_recv_head, list_ptr) { + if (list_ptr->comm_id == comm_id && list_ptr->rank == rank && list_ptr->tag == tag) { + rreq = list_ptr->u.rreq; + LL_DELETE(MPIDI_huge_recv_head, MPIDI_huge_recv_tail, list_ptr); + MPL_free(list_ptr); + break; + } + } + + if (!rreq) { + list_ptr = (MPIDI_OFI_huge_recv_list_t *) MPL_calloc(sizeof(MPIDI_OFI_huge_recv_list_t), + 1, MPL_MEM_OTHER); + if (!list_ptr) { + MPIR_ERR_SETANDJUMP(mpi_errno, MPI_ERR_OTHER, "**nomem"); + } + list_ptr->comm_id = comm_id; + list_ptr->rank = rank; + list_ptr->tag = tag; + list_ptr->u.info = info; + + LL_APPEND(MPIDI_huge_ctrl_head, MPIDI_huge_ctrl_tail, list_ptr); + /* let MPIDI_OFI_recv_huge_event finish the recv */ + } else if (MPIDI_OFI_REQUEST(rreq, kind) == MPIDI_OFI_req_kind__mprobe) { + /* attach info and finish the mprobe */ + MPIDI_OFI_REQUEST(rreq, huge.remote_info) = info; + MPIR_STATUS_SET_COUNT(rreq->status, info->msgsize); + MPL_atomic_release_store_int(&(MPIDI_OFI_REQUEST(rreq, util_id)), MPIDI_OFI_PEEK_FOUND); + } else { + /* attach info and finish recv */ + MPIDI_OFI_REQUEST(rreq, huge.remote_info) = info; + mpi_errno = get_huge(rreq); + MPIR_ERR_CHECK(mpi_errno); + } + + fn_exit: + MPIR_FUNC_EXIT; + return mpi_errno; + fn_fail: + goto fn_exit; +} + +int MPIDI_OFI_peek_huge_event(int vni, struct fi_cq_tagged_entry *wc, MPIR_Request * rreq) +{ + int mpi_errno = MPI_SUCCESS; + MPIR_FUNC_ENTER; + + MPI_Aint count = 0; + MPIDI_OFI_huge_recv_list_t *list_ptr; + bool found_msg = false; + + /* If this is a huge message, find the control message on the unexpected list that matches + * with this and return the size in that. */ + LL_FOREACH(MPIDI_huge_ctrl_head, list_ptr) { + /* FIXME: fix the type of comm_id */ + MPIR_Context_id_t comm_id = rreq->comm->recvcontext_id; + int rank = MPIDI_OFI_cqe_get_source(wc, false); + int tag = (int) (MPIDI_OFI_TAG_MASK & wc->tag); + if (list_ptr->comm_id == comm_id && list_ptr->rank == rank && list_ptr->tag == tag) { + count = list_ptr->u.info->msgsize; + found_msg = true; + break; + } + } + if (found_msg) { + if (MPIDI_OFI_REQUEST(rreq, kind) == MPIDI_OFI_req_kind__mprobe) { + MPIDI_OFI_REQUEST(rreq, huge.remote_info) = list_ptr->u.info; + LL_DELETE(MPIDI_huge_ctrl_head, MPIDI_huge_ctrl_tail, list_ptr); + MPL_free(list_ptr); + } + rreq->status.MPI_SOURCE = MPIDI_OFI_cqe_get_source(wc, false); + rreq->status.MPI_TAG = MPIDI_OFI_init_get_tag(wc->tag); + rreq->status.MPI_ERROR = MPI_SUCCESS; + MPIR_STATUS_SET_COUNT(rreq->status, count); + /* util_id should be the last thing to change in rreq. Reason is + * we use util_id to indicate peek_event has completed and all the + * relevant values have been copied to rreq. */ + MPL_atomic_release_store_int(&(MPIDI_OFI_REQUEST(rreq, util_id)), MPIDI_OFI_PEEK_FOUND); + } else if (MPIDI_OFI_REQUEST(rreq, kind) == MPIDI_OFI_req_kind__probe) { + /* return not found for this probe. User can probe again. */ + MPL_atomic_release_store_int(&(MPIDI_OFI_REQUEST(rreq, util_id)), MPIDI_OFI_PEEK_NOT_FOUND); + } else if (MPIDI_OFI_REQUEST(rreq, kind) == MPIDI_OFI_req_kind__mprobe) { + /* fill the status with wc info. Count is still missing */ + rreq->status.MPI_SOURCE = MPIDI_OFI_cqe_get_source(wc, false); + rreq->status.MPI_TAG = MPIDI_OFI_init_get_tag(wc->tag); + rreq->status.MPI_ERROR = MPI_SUCCESS; + + /* post the rreq to list and let control handler handle it */ + MPIDI_OFI_huge_recv_list_t *huge_list_ptr; + + huge_list_ptr = + (MPIDI_OFI_huge_recv_list_t *) MPL_calloc(sizeof(*huge_list_ptr), 1, MPL_MEM_COMM); + MPIR_ERR_CHKANDJUMP(huge_list_ptr == NULL, mpi_errno, MPI_ERR_OTHER, "**nomem"); + + huge_list_ptr->comm_id = rreq->comm->recvcontext_id; + huge_list_ptr->rank = MPIDI_OFI_cqe_get_source(wc, false); + huge_list_ptr->tag = MPIDI_OFI_TAG_MASK & wc->tag; + huge_list_ptr->u.rreq = rreq; + + LL_APPEND(MPIDI_huge_recv_head, MPIDI_huge_recv_tail, huge_list_ptr); + } + + + fn_exit: + MPIR_FUNC_EXIT; + return mpi_errno; + fn_fail: + goto fn_exit; +} + +static uintptr_t recv_rbase(MPIDI_OFI_huge_remote_info_t * remote_info) +{ + if (!MPIDI_OFI_ENABLE_MR_VIRT_ADDRESS) { + return 0; + } else { + return (uintptr_t) remote_info->send_buf; + } +} + +/* Note: MPIDI_OFI_get_huge_event is invoked from three places -- + * 1. In MPIDI_OFI_recv_huge_event, when recv buffer is matched and first chunk received, and + * when control message (with remote info) has also been received. + * 2. In MPIDI_OFI_recv_huge_control, as a callback when control message is received, and + * when first chunk has been matched and received. + * + * MPIDI_OFI_recv_huge_event will fill the local request information, and + * MPIDI_OFI_recv_huge_control will fill the remote (sender) information. Lastly -- + * + * 3. As the event function when RDMA read (issued here) completes. + */ +int MPIDI_OFI_get_huge_event(int vni, struct fi_cq_tagged_entry *wc, MPIR_Request * req) +{ + int mpi_errno = MPI_SUCCESS; + MPIDI_OFI_huge_recv_t *recv_elem = (MPIDI_OFI_huge_recv_t *) req; + MPIDI_OFI_huge_remote_info_t *info = MPIDI_OFI_REQUEST(recv_elem->localreq, huge.remote_info); + MPIR_Comm *comm = recv_elem->localreq->comm; + uint64_t remote_key; + size_t bytesLeft, bytesToGet; + MPIR_FUNC_ENTER; + + void *recv_buf = MPIDI_OFI_REQUEST(recv_elem->localreq, util.iov.iov_base); + + if (MPIDI_OFI_COMM(comm).enable_striping) { + /* Subtract one stripe_chunk_size because we send the first chunk via a regular message + * instead of the memory region */ + recv_elem->stripe_size = (info->msgsize - MPIDI_OFI_STRIPE_CHUNK_SIZE) + / MPIDI_OFI_global.num_nics; /* striping */ + + if (recv_elem->stripe_size > MPIDI_OFI_global.max_msg_size) { + recv_elem->stripe_size = MPIDI_OFI_global.max_msg_size; + } + if (recv_elem->chunks_outstanding) + recv_elem->chunks_outstanding--; + bytesLeft = info->msgsize - recv_elem->cur_offset; + bytesToGet = (bytesLeft <= recv_elem->stripe_size) ? bytesLeft : recv_elem->stripe_size; + } else { + /* Subtract one max_msg_size because we send the first chunk via a regular message + * instead of the memory region */ + bytesLeft = info->msgsize - recv_elem->cur_offset; + bytesToGet = (bytesLeft <= MPIDI_OFI_global.max_msg_size) ? + bytesLeft : MPIDI_OFI_global.max_msg_size; + } + if (bytesToGet == 0ULL && recv_elem->chunks_outstanding == 0) { + mpi_errno = get_huge_complete(recv_elem->localreq); + MPIR_ERR_CHECK(mpi_errno); + MPL_free(recv_elem); + goto fn_exit; + } + + int vni_src = info->vni_src; + int vni_dst = info->vni_dst; + if (MPIDI_OFI_COMM(comm).enable_striping) { /* if striping enabled */ + if (recv_elem->cur_offset >= MPIDI_OFI_STRIPE_CHUNK_SIZE && bytesLeft > 0) { + for (int nic = 0; nic < MPIDI_OFI_global.num_nics; nic++) { + int ctx_idx = MPIDI_OFI_get_ctx_index(comm, vni_dst, nic); + remote_key = info->rma_keys[nic]; + + bytesLeft = info->msgsize - recv_elem->cur_offset; + if (bytesLeft <= 0) { + break; + } + bytesToGet = + (bytesLeft <= recv_elem->stripe_size) ? bytesLeft : recv_elem->stripe_size; + + /* FIXME: Can we issue concurrent fi_read with the same context? */ + MPIDI_OFI_cntr_incr(recv_elem->comm_ptr, vni_src, nic); + MPIDI_OFI_CALL_RETRY(fi_read(MPIDI_OFI_global.ctx[ctx_idx].tx, (void *) ((char *) recv_buf + recv_elem->cur_offset), /* local buffer */ + bytesToGet, /* bytes */ + NULL, /* descriptor */ + MPIDI_OFI_comm_to_phys(comm, info->origin_rank, nic, vni_dst, vni_src), recv_rbase(info) + recv_elem->cur_offset, /* remote maddr */ + remote_key, /* Key */ + (void *) &recv_elem->context), nic, /* Context */ + rdma_readfrom, FALSE); + MPIR_T_PVAR_COUNTER_INC(MULTINIC, nic_recvd_bytes_count[nic], bytesToGet); + MPIR_T_PVAR_COUNTER_INC(MULTINIC, striped_nic_recvd_bytes_count[nic], bytesToGet); + recv_elem->cur_offset += bytesToGet; + recv_elem->chunks_outstanding++; + } + } + } else { + int nic = 0; + int ctx_idx = MPIDI_OFI_get_ctx_index(comm, vni_src, nic); + remote_key = info->rma_keys[nic]; + MPIDI_OFI_cntr_incr(comm, vni_src, nic); + MPIDI_OFI_CALL_RETRY(fi_read(MPIDI_OFI_global.ctx[ctx_idx].tx, /* endpoint */ + (void *) ((char *) recv_buf + recv_elem->cur_offset), /* local buffer */ + bytesToGet, /* bytes */ + NULL, /* descriptor */ + MPIDI_OFI_comm_to_phys(comm, info->origin_rank, nic, vni_src, vni_dst), /* Destination */ + recv_rbase(info) + recv_elem->cur_offset, /* remote maddr */ + remote_key, /* Key */ + (void *) &recv_elem->context), vni_src, rdma_readfrom, /* Context */ + FALSE); + MPIR_T_PVAR_COUNTER_INC(MULTINIC, nic_recvd_bytes_count[nic], bytesToGet); + recv_elem->cur_offset += bytesToGet; + } + + fn_exit: + MPIR_FUNC_EXIT; + return mpi_errno; + fn_fail: + goto fn_exit; +} diff --git a/src/mpid/ch4/netmod/ofi/ofi_impl.h b/src/mpid/ch4/netmod/ofi/ofi_impl.h index 2a52421ff03..24d47384a4b 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_impl.h +++ b/src/mpid/ch4/netmod/ofi/ofi_impl.h @@ -308,6 +308,11 @@ MPL_STATIC_INLINE_PREFIX void MPIDI_OFI_cntr_set(int ctx_idx, int val) #define MPIDI_OFI_COLL_MR_KEY 1 #define MPIDI_OFI_INVALID_MR_KEY 0xFFFFFFFFFFFFFFFFULL int MPIDI_OFI_retry_progress(void); +int MPIDI_OFI_recv_huge_event(int vni, struct fi_cq_tagged_entry *wc, MPIR_Request * rreq); +int MPIDI_OFI_recv_huge_control(MPIR_Context_id_t comm_id, int rank, int tag, + MPIDI_OFI_huge_remote_info_t * info); +int MPIDI_OFI_peek_huge_event(int vni, struct fi_cq_tagged_entry *wc, MPIR_Request * rreq); +int MPIDI_OFI_get_huge_event(int vni, struct fi_cq_tagged_entry *wc, MPIR_Request * req); int MPIDI_OFI_control_handler(void *am_hdr, void *data, MPI_Aint data_sz, uint32_t attr, MPIR_Request ** req); int MPIDI_OFI_am_rdma_read_ack_handler(void *am_hdr, void *data, diff --git a/src/mpid/ch4/netmod/ofi/ofi_init.c b/src/mpid/ch4/netmod/ofi/ofi_init.c index eb941a3183a..d9e20b0a288 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_init.c +++ b/src/mpid/ch4/netmod/ofi/ofi_init.c @@ -558,10 +558,6 @@ int MPIDI_OFI_init_local(int *tag_bits) MPIDIU_map_create(&MPIDI_OFI_global.win_map, MPL_MEM_RMA); MPIDIU_map_create(&MPIDI_OFI_global.req_map, MPL_MEM_OTHER); - /* Create huge protocol maps */ - MPIDIU_map_create(&MPIDI_OFI_global.huge_send_counters, MPL_MEM_COMM); - MPIDIU_map_create(&MPIDI_OFI_global.huge_recv_counters, MPL_MEM_COMM); - /* Initialize RMA keys allocator */ MPIDI_OFI_mr_key_allocator_init(); @@ -904,9 +900,6 @@ int MPIDI_OFI_mpi_finalize_hook(void) MPIDIU_map_destroy(MPIDI_OFI_global.win_map); MPIDIU_map_destroy(MPIDI_OFI_global.req_map); - MPIDIU_map_destroy(MPIDI_OFI_global.huge_send_counters); - MPIDIU_map_destroy(MPIDI_OFI_global.huge_recv_counters); - if (MPIDI_OFI_ENABLE_AM) { for (int vni = 0; vni < MPIDI_OFI_global.num_vnis; vni++) { while (MPIDI_OFI_global.per_vni[vni].am_unordered_msgs) { diff --git a/src/mpid/ch4/netmod/ofi/ofi_pre.h b/src/mpid/ch4/netmod/ofi/ofi_pre.h index 593de7ef64e..e56f6c65e15 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_pre.h +++ b/src/mpid/ch4/netmod/ofi/ofi_pre.h @@ -176,6 +176,11 @@ typedef struct { MPI_Aint data_sz; /* save data_sz to avoid double checking */ } MPIDI_OFI_am_request_t; +enum MPIDI_OFI_req_kind { + MPIDI_OFI_req_kind__any, + MPIDI_OFI_req_kind__probe, + MPIDI_OFI_req_kind__mprobe, +}; typedef struct { struct fi_context context[MPIDI_OFI_CONTEXT_STRUCTS]; /* fixed field, do not move */ @@ -184,6 +189,11 @@ typedef struct { MPI_Datatype datatype; int nic_num; /* Store the nic number so we can use it to cancel a request later * if needed. */ + enum MPIDI_OFI_req_kind kind; + union { + struct fid_mr **send_mrs; + void *remote_info; + } huge; union { struct { void *buf; diff --git a/src/mpid/ch4/netmod/ofi/ofi_probe.h b/src/mpid/ch4/netmod/ofi/ofi_probe.h index 14181444b55..cae8b5fc38f 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_probe.h +++ b/src/mpid/ch4/netmod/ofi/ofi_probe.h @@ -14,8 +14,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_do_iprobe(int source, int context_offset, MPIDI_av_entry_t * addr, int vni_src, int vni_dst, int *flag, - MPI_Status * status, - MPIR_Request ** message, uint64_t peek_flags) + MPI_Status * status, MPIR_Request ** message) { int mpi_errno = MPI_SUCCESS; fi_addr_t remote_proc; @@ -41,6 +40,12 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_do_iprobe(int source, } else { rreq = &r; } + if (message) { + MPIDI_OFI_REQUEST(rreq, kind) = MPIDI_OFI_req_kind__mprobe; + } else { + MPIDI_OFI_REQUEST(rreq, kind) = MPIDI_OFI_req_kind__probe; + } + MPIDI_OFI_REQUEST(rreq, huge.remote_info) = NULL; rreq->comm = comm; MPIR_Comm_add_ref(comm); @@ -58,8 +63,11 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_do_iprobe(int source, msg.context = (void *) &(MPIDI_OFI_REQUEST(rreq, context)); msg.data = 0; - MPIDI_OFI_CALL_RETURN(fi_trecvmsg(MPIDI_OFI_global.ctx[ctx_idx].rx, &msg, - peek_flags | FI_PEEK | FI_COMPLETION), ofi_err); + uint64_t recv_flags = FI_PEEK | FI_COMPLETION; + if (message) { + recv_flags |= FI_CLAIM; + } + MPIDI_OFI_CALL_RETURN(fi_trecvmsg(MPIDI_OFI_global.ctx[ctx_idx].rx, &msg, recv_flags), ofi_err); if (ofi_err == -FI_ENOMSG) { *flag = 0; if (message) @@ -138,7 +146,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_improbe(int source, MPIDI_OFI_THREAD_CS_ENTER_VCI_OPTIONAL(vni_dst); /* Set flags for mprobe peek, when ready */ mpi_errno = MPIDI_OFI_do_iprobe(source, tag, comm, context_offset, addr, vni_src, vni_dst, - flag, status, message, FI_CLAIM | FI_COMPLETION); + flag, status, message); MPIDI_OFI_THREAD_CS_EXIT_VCI_OPTIONAL(vni_dst); if (mpi_errno != MPI_SUCCESS) @@ -166,7 +174,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iprobe(int source, } else { MPIDI_OFI_THREAD_CS_ENTER_VCI_OPTIONAL(vni_dst); mpi_errno = MPIDI_OFI_do_iprobe(source, tag, comm, context_offset, addr, - vni_src, vni_dst, flag, status, NULL, 0ULL); + vni_src, vni_dst, flag, status, NULL); MPIDI_OFI_THREAD_CS_EXIT_VCI_OPTIONAL(vni_dst); } diff --git a/src/mpid/ch4/netmod/ofi/ofi_recv.h b/src/mpid/ch4/netmod/ofi/ofi_recv.h index 40049aa2dc5..3a4d389af96 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_recv.h +++ b/src/mpid/ch4/netmod/ofi/ofi_recv.h @@ -156,6 +156,10 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_do_irecv(void *buf, } *request = rreq; + MPIDI_OFI_REQUEST(rreq, kind) = MPIDI_OFI_req_kind__any; + if (!flags) { + MPIDI_OFI_REQUEST(rreq, huge.remote_info) = NULL; /* for huge recv remote info */ + } /* Calculate the correct NICs. */ sender_nic = MPIDI_OFI_multx_sender_nic_index(comm, comm->recvcontext_id, MPIR_Process.rank, @@ -226,6 +230,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_do_irecv(void *buf, } /* Read ordering unnecessary for context_id, so use relaxed load */ MPL_atomic_relaxed_store_int(&MPIDI_OFI_REQUEST(rreq, util_id), context_id); + MPIDI_OFI_REQUEST(rreq, util.iov.iov_base) = recv_buf; + MPIDI_OFI_REQUEST(rreq, util.iov.iov_len) = data_sz; if (unlikely(data_sz >= MPIDI_OFI_global.max_msg_size) && !MPIDI_OFI_COMM(comm).enable_striping) { MPIDI_OFI_REQUEST(rreq, event_id) = MPIDI_OFI_EVENT_RECV_HUGE; @@ -240,8 +246,6 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_do_irecv(void *buf, } else if (MPIDI_OFI_REQUEST(rreq, event_id) != MPIDI_OFI_EVENT_RECV_PACK) MPIDI_OFI_REQUEST(rreq, event_id) = MPIDI_OFI_EVENT_RECV; - MPIDI_OFI_REQUEST(rreq, util.iov.iov_base) = recv_buf; - MPIDI_OFI_REQUEST(rreq, util.iov.iov_len) = data_sz; if (!flags) { MPIDI_OFI_CALL_RETRY(fi_trecv(MPIDI_OFI_global.ctx[ctx_idx].rx, recv_buf, diff --git a/src/mpid/ch4/netmod/ofi/ofi_send.h b/src/mpid/ch4/netmod/ofi/ofi_send.h index 124cd8cbe3c..406c11502a3 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_send.h +++ b/src/mpid/ch4/netmod/ofi/ofi_send.h @@ -278,14 +278,11 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_send_normal(const void *buf, MPI_Aint cou tsenddata, FALSE /* eagain */); MPIR_T_PVAR_COUNTER_INC(MULTINIC, nic_sent_bytes_count[sender_nic], data_sz); } else if (unlikely(1)) { - MPIDI_OFI_send_control_t ctrl; - int i, num_nics = MPIDI_OFI_global.num_nics; + int num_nics = MPIDI_OFI_global.num_nics; uint64_t rma_keys[MPIDI_OFI_MAX_NICS]; struct fid_mr **huge_send_mrs; uint64_t msg_size = MPIDI_OFI_STRIPE_CHUNK_SIZE; - MPIDI_OFI_REQUEST(sreq, event_id) = MPIDI_OFI_EVENT_SEND_HUGE; - MPIR_cc_inc(sreq->cc_ptr); if (!MPIDI_OFI_COMM(comm).enable_striping) { num_nics = 1; @@ -295,18 +292,17 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_send_normal(const void *buf, MPI_Aint cou (struct fid_mr **) MPL_malloc((num_nics * sizeof(struct fid_mr *)), MPL_MEM_BUFFER); if (!MPIDI_OFI_ENABLE_MR_PROV_KEY) { /* Set up a memory region for the lmt data transfer */ - for (i = 0; i < num_nics; i++) { - ctrl.rma_keys[i] = + for (int i = 0; i < num_nics; i++) { + rma_keys[i] = MPIDI_OFI_mr_key_alloc(MPIDI_OFI_LOCAL_MR_KEY, MPIDI_OFI_INVALID_MR_KEY); - rma_keys[i] = ctrl.rma_keys[i]; } } else { /* zero them to avoid warnings */ - for (i = 0; i < num_nics; i++) { + for (int i = 0; i < num_nics; i++) { rma_keys[i] = 0; } } - for (i = 0; i < num_nics; i++) { + for (int i = 0; i < num_nics; i++) { MPIDI_OFI_CALL(fi_mr_reg(MPIDI_OFI_global.ctx[MPIDI_OFI_get_ctx_index(comm, vni_local, i)].domain, /* In: Domain Object */ send_buf, /* In: Lower memory address */ data_sz, /* In: Length */ @@ -317,13 +313,11 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_send_normal(const void *buf, MPI_Aint cou &huge_send_mrs[i], /* Out: memregion object */ NULL), mr_reg); /* In: context */ } - /* Create map to the memory region */ - MPIDIU_map_set(MPIDI_OFI_global.huge_send_counters, sreq->handle, huge_send_mrs, - MPL_MEM_BUFFER); + MPIDI_OFI_REQUEST(sreq, huge.send_mrs) = huge_send_mrs; if (MPIDI_OFI_ENABLE_MR_PROV_KEY) { /* MR_BASIC */ - for (i = 0; i < num_nics; i++) { - ctrl.rma_keys[i] = fi_mr_key(huge_send_mrs[i]); + for (int i = 0; i < num_nics; i++) { + rma_keys[i] = fi_mr_key(huge_send_mrs[i]); } } @@ -335,6 +329,29 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_send_normal(const void *buf, MPI_Aint cou MPIR_Comm_add_ref(comm); /* Store ordering unnecessary for dst_rank, so use relaxed store */ MPL_atomic_relaxed_store_int(&MPIDI_OFI_REQUEST(sreq, util_id), dst_rank); + + /* send ctrl message first */ + MPIDI_OFI_send_control_t ctrl; + ctrl.type = MPIDI_OFI_CTRL_HUGE; + for (int i = 0; i < num_nics; i++) { + ctrl.u.huge.info.rma_keys[i] = rma_keys[i]; + } + ctrl.u.huge.info.comm_id = comm->context_id; + ctrl.u.huge.info.tag = tag; + ctrl.u.huge.info.origin_rank = comm->rank; + ctrl.u.huge.info.vni_src = vni_src; + ctrl.u.huge.info.vni_dst = vni_dst; + ctrl.u.huge.info.send_buf = send_buf; + ctrl.u.huge.info.msgsize = data_sz; + ctrl.u.huge.info.ackreq = sreq; + + mpi_errno = MPIDI_NM_am_send_hdr(dst_rank, comm, MPIDI_OFI_INTERNAL_HANDLER_CONTROL, + &ctrl, sizeof(ctrl), vni_src, vni_dst); + MPIR_ERR_CHECK(mpi_errno); + + /* send main native message next */ + MPIDI_OFI_REQUEST(sreq, event_id) = MPIDI_OFI_EVENT_SEND_HUGE; + match_bits |= MPIDI_OFI_HUGE_SEND; /* Add the bit for a huge message */ MPIDI_OFI_CALL_RETRY(fi_tsenddata(MPIDI_OFI_global.ctx[ctx_idx].tx, send_buf, msg_size, NULL /* desc */ , @@ -346,15 +363,6 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_send_normal(const void *buf, MPI_Aint cou vni_local, tsenddata, FALSE /* eagain */); MPIR_T_PVAR_COUNTER_INC(MULTINIC, nic_sent_bytes_count[sender_nic], msg_size); MPIR_T_PVAR_COUNTER_INC(MULTINIC, striped_nic_sent_bytes_count[sender_nic], msg_size); - ctrl.type = MPIDI_OFI_CTRL_HUGE; - ctrl.seqno = 0; - ctrl.tag = tag; - ctrl.vni_src = vni_src; - ctrl.vni_dst = vni_dst; - - /* Send information about the memory region here to get the lmt going. */ - mpi_errno = MPIDI_OFI_do_control_send(&ctrl, send_buf, data_sz, dst_rank, comm, sreq); - MPIR_ERR_CHECK(mpi_errno); } fn_exit: diff --git a/src/mpid/ch4/netmod/ofi/ofi_types.h b/src/mpid/ch4/netmod/ofi/ofi_types.h index d37206c08bc..3d50f4f44cd 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_types.h +++ b/src/mpid/ch4/netmod/ofi/ofi_types.h @@ -361,10 +361,6 @@ typedef struct { * OFI provider at MPI initialization.*/ MPIDI_OFI_atomic_valid_t win_op_table[MPIR_DATATYPE_N_PREDEFINED][MPIDIG_ACCU_NUM_OP]; - /* huge protocol globals */ - void *huge_send_counters; - void *huge_recv_counters; - /* Active Message Globals */ MPL_atomic_int_t am_inflight_inject_emus; MPL_atomic_int_t am_inflight_rma_send_mrs; @@ -386,17 +382,27 @@ typedef struct { } MPIDI_OFI_global_t; typedef struct { - int16_t type; - int16_t seqno; + MPIR_Context_id_t comm_id; int origin_rank; + int tag; MPIR_Request *ackreq; - uintptr_t send_buf; + void *send_buf; size_t msgsize; - int comm_id; uint64_t rma_keys[MPIDI_OFI_MAX_NICS]; - int tag; int vni_src; int vni_dst; +} MPIDI_OFI_huge_remote_info_t; + +typedef struct { + int16_t type; + union { + struct { + MPIDI_OFI_huge_remote_info_t info; + } huge; + struct { + MPIR_Request *ackreq; + } huge_ack; + } u; } MPIDI_OFI_send_control_t; typedef struct MPIDI_OFI_win_acc_hint { @@ -493,18 +499,11 @@ typedef struct MPIDI_OFI_huge_recv { char pad[MPIDI_REQUEST_HDR_SIZE]; struct fi_context context[MPIDI_OFI_CONTEXT_STRUCTS]; /* fixed field, do not move */ int event_id; /* fixed field, do not move */ - int (*done_fn) (int vni, struct fi_cq_tagged_entry * wc, MPIR_Request * req, int event_id); - MPIDI_OFI_send_control_t remote_info; - bool peek; /* Flag to indicate whether this struct has been created to track an uncompleted peek - * operation. */ size_t cur_offset; size_t stripe_size; int chunks_outstanding; MPIR_Comm *comm_ptr; MPIR_Request *localreq; - struct fi_cq_tagged_entry wc; - struct MPIDI_OFI_huge_recv *next; /* Points to the next entry in the unexpected list - * (when in the unexpected list) */ } MPIDI_OFI_huge_recv_t; /* The list of posted huge receives that haven't been matched yet. These need @@ -512,19 +511,22 @@ typedef struct MPIDI_OFI_huge_recv { * data from the remote memory region and we need a way of matching up the * control messages with the "real" requests. */ typedef struct MPIDI_OFI_huge_recv_list { - int comm_id; + MPIR_Context_id_t comm_id; int rank; int tag; - MPIR_Request *rreq; + union { + MPIDI_OFI_huge_remote_info_t *info; /* ctrl list */ + MPIR_Request *rreq; /* recv list */ + } u; struct MPIDI_OFI_huge_recv_list *next; } MPIDI_OFI_huge_recv_list_t; /* Externs */ extern MPIDI_OFI_global_t MPIDI_OFI_global; -extern MPIDI_OFI_huge_recv_t *MPIDI_unexp_huge_recv_head; -extern MPIDI_OFI_huge_recv_t *MPIDI_unexp_huge_recv_tail; -extern MPIDI_OFI_huge_recv_list_t *MPIDI_posted_huge_recv_head; -extern MPIDI_OFI_huge_recv_list_t *MPIDI_posted_huge_recv_tail; +extern MPIDI_OFI_huge_recv_list_t *MPIDI_huge_ctrl_head; +extern MPIDI_OFI_huge_recv_list_t *MPIDI_huge_ctrl_tail; +extern MPIDI_OFI_huge_recv_list_t *MPIDI_huge_recv_head; +extern MPIDI_OFI_huge_recv_list_t *MPIDI_huge_recv_tail; extern MPIDI_OFI_capabilities_t MPIDI_OFI_caps_list[MPIDI_OFI_NUM_SETS]; diff --git a/src/mpid/ch4/netmod/ofi/util.c b/src/mpid/ch4/netmod/ofi/util.c index 955e88e7860..870e1a2362e 100644 --- a/src/mpid/ch4/netmod/ofi/util.c +++ b/src/mpid/ch4/netmod/ofi/util.c @@ -150,87 +150,6 @@ void MPIDI_OFI_mr_key_allocator_destroy(void) MPL_free(mr_key_allocator.bitmask); } -/* Translate the control message to get a huge message into a request to - * actually perform the data transfer. */ -static int MPIDI_OFI_get_huge(int vni, MPIDI_OFI_send_control_t * info) -{ - MPIDI_OFI_huge_recv_t *recv_elem = NULL; - int mpi_errno = MPI_SUCCESS; - MPIR_FUNC_ENTER; - - bool ready_to_get = false; - - /* If there has been a posted receive, search through the list of unmatched - * receives to find the one that goes with the incoming message. */ - { - MPIDI_OFI_huge_recv_list_t *list_ptr; - - MPL_DBG_MSG_FMT(MPIR_DBG_PT2PT, VERBOSE, - (MPL_DBG_FDEST, "SEARCHING POSTED LIST: (%d, %d, %d)", info->comm_id, - info->origin_rank, info->tag)); - - LL_FOREACH(MPIDI_posted_huge_recv_head, list_ptr) { - if (list_ptr->comm_id == info->comm_id && - list_ptr->rank == info->origin_rank && list_ptr->tag == info->tag) { - MPL_DBG_MSG_FMT(MPIR_DBG_PT2PT, VERBOSE, - (MPL_DBG_FDEST, "MATCHED POSTED LIST: (%d, %d, %d, %d)", - info->comm_id, info->origin_rank, info->tag, - list_ptr->rreq->handle)); - - LL_DELETE(MPIDI_posted_huge_recv_head, MPIDI_posted_huge_recv_tail, list_ptr); - - recv_elem = (MPIDI_OFI_huge_recv_t *) - MPIDIU_map_lookup(MPIDI_OFI_global.huge_recv_counters, list_ptr->rreq->handle); - - /* If this is a "peek" element for an MPI_Probe, it shouldn't be matched. Grab the - * important information and remove the element from the list. */ - if (recv_elem->peek) { - MPIR_STATUS_SET_COUNT(recv_elem->localreq->status, info->msgsize); - MPL_atomic_release_store_int(&(MPIDI_OFI_REQUEST(recv_elem->localreq, util_id)), - MPIDI_OFI_PEEK_FOUND); - MPIDIU_map_erase(MPIDI_OFI_global.huge_recv_counters, - recv_elem->localreq->handle); - MPL_free(recv_elem); - recv_elem = NULL; - } - - MPL_free(list_ptr); - break; - } - } - } - - if (recv_elem) { - ready_to_get = true; - } else { - /* Put the struct describing the transfer on an unexpected list to be retrieved later */ - MPL_DBG_MSG_FMT(MPIR_DBG_PT2PT, VERBOSE, - (MPL_DBG_FDEST, "CREATING UNEXPECTED HUGE RECV: (%d, %d, %d)", - info->comm_id, info->origin_rank, info->tag)); - - /* If this is unexpected, create a new tracker and put it in the unexpected list. */ - recv_elem = (MPIDI_OFI_huge_recv_t *) MPL_calloc(sizeof(*recv_elem), 1, MPL_MEM_COMM); - if (!recv_elem) - MPIR_ERR_SETANDJUMP(mpi_errno, MPI_ERR_OTHER, "**nomem"); - - LL_APPEND(MPIDI_unexp_huge_recv_head, MPIDI_unexp_huge_recv_tail, recv_elem); - } - - recv_elem->event_id = MPIDI_OFI_EVENT_GET_HUGE; - recv_elem->remote_info = *info; - recv_elem->next = NULL; - if (ready_to_get) { - MPIDI_OFI_get_huge_event(vni, NULL, (MPIR_Request *) recv_elem); - } - - MPIR_FUNC_EXIT; - - fn_exit: - return mpi_errno; - fn_fail: - goto fn_exit; -} - int MPIDI_OFI_control_handler(void *am_hdr, void *data, MPI_Aint data_sz, uint32_t attr, MPIR_Request ** req) { @@ -241,18 +160,20 @@ int MPIDI_OFI_control_handler(void *am_hdr, void *data, MPI_Aint data_sz, *req = NULL; } + int local_vci = MPIDIG_AM_ATTR_DST_VCI(attr); + MPIR_AssertDeclValue(int remote_vci, MPIDIG_AM_ATTR_SRC_VCI(attr)); switch (ctrlsend->type) { - case MPIDI_OFI_CTRL_HUGEACK:{ - /* FIXME: need vni from the callback parameters */ - mpi_errno = MPIDI_OFI_dispatch_function(0, NULL, ctrlsend->ackreq); - goto fn_exit; - } + case MPIDI_OFI_CTRL_HUGEACK: + mpi_errno = MPIDI_OFI_dispatch_function(local_vci, NULL, ctrlsend->u.huge_ack.ackreq); break; - case MPIDI_OFI_CTRL_HUGE:{ - mpi_errno = MPIDI_OFI_get_huge(0, ctrlsend); - goto fn_exit; - } + case MPIDI_OFI_CTRL_HUGE: + MPIR_Assert(local_vci == ctrlsend->u.huge.info.vni_dst); + MPIR_Assert(remote_vci == ctrlsend->u.huge.info.vni_src); + mpi_errno = MPIDI_OFI_recv_huge_control(ctrlsend->u.huge.info.comm_id, + ctrlsend->u.huge.info.origin_rank, + ctrlsend->u.huge.info.tag, + &(ctrlsend->u.huge.info)); break; default: diff --git a/test/mpi/errors/pt2pt/testlist b/test/mpi/errors/pt2pt/testlist index 4e1e629b5a1..e21ed1fd68b 100644 --- a/test/mpi/errors/pt2pt/testlist +++ b/test/mpi/errors/pt2pt/testlist @@ -1,5 +1,6 @@ proberank 1 truncmsg1 2 +truncmsg1 2 env=MPIR_CVAR_CH4_OFI_EAGER_MAX_MSG_SIZE=16384 truncmsg2 2 truncmsg_mrecv 2 mpiversion=3.0 # multiple completion ests diff --git a/test/mpi/errors/pt2pt/truncmsg1.c b/test/mpi/errors/pt2pt/truncmsg1.c index 7ebe7402808..2328a27bb83 100644 --- a/test/mpi/errors/pt2pt/truncmsg1.c +++ b/test/mpi/errors/pt2pt/truncmsg1.c @@ -83,6 +83,14 @@ int main(int argc, char *argv[]) err = MPI_Recv(buf, LongLen - 1, MPI_INT, source, 0, MPI_COMM_WORLD, &status); errs += checkTruncError(err, "long"); } + /* Test when the receive buffer is much shorter */ + if (rank == source) { + err = MPI_Send(buf, LongLen, MPI_INT, dest, 0, MPI_COMM_WORLD); + errs += checkOk(err, "long"); + } else if (rank == dest) { + err = MPI_Recv(buf, ShortLen, MPI_INT, source, 0, MPI_COMM_WORLD, &status); + errs += checkTruncError(err, "long-receive-short"); + } } free(buf); diff --git a/test/mpi/maint/jenkins/xfail.conf b/test/mpi/maint/jenkins/xfail.conf index dc2c31b6148..09ad90d5a8c 100644 --- a/test/mpi/maint/jenkins/xfail.conf +++ b/test/mpi/maint/jenkins/xfail.conf @@ -35,7 +35,6 @@ * * * ch4:ofi * /^idup_nb/ xfail=ticket3794 test/mpi/threads/comm/testlist * * * ch4:ucx * /^idup_comm_gen/ xfail=ticket3794 test/mpi/threads/comm/testlist * * * ch4:ucx * /^idup_nb/ xfail=ticket3794 test/mpi/threads/comm/testlist -* * * ch4:ofi * /^mt_.*_huge.* env=MPIR_CVAR_CH4_OFI_EAGER_MAX_MSG_SIZE=16384/ xfail=ticket5359 test/mpi/threads/pt2pt/testlist ################################################################################ # misc special build * * nofast * * /^large_acc_flush_local/ xfail=issue4663 test/mpi/rma/testlist diff --git a/test/mpi/pt2pt/probe_unexp.c b/test/mpi/pt2pt/probe_unexp.c index 0dda3336b61..3e867f15d4f 100644 --- a/test/mpi/pt2pt/probe_unexp.c +++ b/test/mpi/pt2pt/probe_unexp.c @@ -17,8 +17,10 @@ char buf[1 << MAX_BUF_SIZE_LG]; * been called. This program may hang if MPI_Probe() does not return when the * message finally arrives (see req #375). */ + int main(int argc, char **argv) { + MPI_Comm comm; int p_size; int p_rank; int msg_size_lg; @@ -27,115 +29,121 @@ int main(int argc, char **argv) MTest_Init(&argc, &argv); - MPI_Comm_size(MPI_COMM_WORLD, &p_size); - MPI_Comm_rank(MPI_COMM_WORLD, &p_rank); - /* To improve reporting of problems about operations, we - * change the error handler to errors return */ - MPI_Comm_set_errhandler(MPI_COMM_WORLD, MPI_ERRORS_RETURN); - - - for (msg_size_lg = 0; msg_size_lg <= MAX_BUF_SIZE_LG; msg_size_lg++) { - const int msg_size = 1 << msg_size_lg; - int msg_cnt; + while (MTestGetIntracommGeneral(&comm, 2, 1)) { + if (comm == MPI_COMM_NULL) { + continue; + } - MTestPrintfMsg(2, "testing messages of size %d\n", msg_size); - for (msg_cnt = 0; msg_cnt < NUM_MSGS_PER_BUF_SIZE; msg_cnt++) { - MPI_Status status; - const int tag = msg_size_lg * NUM_MSGS_PER_BUF_SIZE + msg_cnt; + MPI_Comm_size(comm, &p_size); + MPI_Comm_rank(comm, &p_rank); + /* To improve reporting of problems about operations, we + * change the error handler to errors return */ + MPI_Comm_set_errhandler(comm, MPI_ERRORS_RETURN); + + + for (msg_size_lg = 0; msg_size_lg <= MAX_BUF_SIZE_LG; msg_size_lg++) { + const int msg_size = 1 << msg_size_lg; + int msg_cnt; + + MTestPrintfMsg(2, "testing messages of size %d\n", msg_size); + for (msg_cnt = 0; msg_cnt < NUM_MSGS_PER_BUF_SIZE; msg_cnt++) { + MPI_Status status; + const int tag = msg_size_lg * NUM_MSGS_PER_BUF_SIZE + msg_cnt; + + MTestPrintfMsg(2, "Message count %d\n", msg_cnt); + if (p_rank == 0) { + int p; + + for (p = 1; p < p_size; p++) { + /* Wait for synchronization message */ + mpi_errno = MPI_Recv(NULL, 0, MPI_BYTE, MPI_ANY_SOURCE, tag, comm, &status); + if (mpi_errno != MPI_SUCCESS && errs++ < 10) { + MTestPrintError(mpi_errno); + } + + if (status.MPI_TAG != tag && errs++ < 10) { + printf + ("ERROR: unexpected message tag from MPI_Recv(): lp=0, rp=%d, expected=%d, actual=%d, count=%d\n", + status.MPI_SOURCE, status.MPI_TAG, tag, msg_cnt); + } +# if defined(VERBOSE) + { + printf("sending message: p=%d s=%d c=%d\n", + status.MPI_SOURCE, msg_size, msg_cnt); + } +# endif - MTestPrintfMsg(2, "Message count %d\n", msg_cnt); - if (p_rank == 0) { - int p; + /* Send unexpected message which hopefully MPI_Probe() is + * already waiting for at the remote process */ + mpi_errno = MPI_Send(buf, msg_size, MPI_BYTE, + status.MPI_SOURCE, status.MPI_TAG, comm); + if (mpi_errno != MPI_SUCCESS && errs++ < 10) { + MTestPrintError(mpi_errno); + } + } + } else { + int incoming_msg_size; - for (p = 1; p < p_size; p++) { - /* Wait for synchronization message */ - mpi_errno = MPI_Recv(NULL, 0, MPI_BYTE, MPI_ANY_SOURCE, - tag, MPI_COMM_WORLD, &status); + /* Send synchronization message */ + mpi_errno = MPI_Send(NULL, 0, MPI_BYTE, 0, tag, comm); if (mpi_errno != MPI_SUCCESS && errs++ < 10) { MTestPrintError(mpi_errno); } + /* Perform probe, hopefully before the main process can + * send its reply */ + mpi_errno = MPI_Probe(MPI_ANY_SOURCE, MPI_ANY_TAG, comm, &status); + if (mpi_errno != MPI_SUCCESS && errs++ < 10) { + MTestPrintError(mpi_errno); + } + mpi_errno = MPI_Get_count(&status, MPI_BYTE, &incoming_msg_size); + if (mpi_errno != MPI_SUCCESS && errs++ < 10) { + MTestPrintError(mpi_errno); + } + if (status.MPI_SOURCE != 0 && errs++ < 10) { + printf + ("ERROR: unexpected message source from MPI_Probe(): p=%d, expected=0, actual=%d, count=%d\n", + p_rank, status.MPI_SOURCE, msg_cnt); + } if (status.MPI_TAG != tag && errs++ < 10) { printf - ("ERROR: unexpected message tag from MPI_Recv(): lp=0, rp=%d, expected=%d, actual=%d, count=%d\n", - status.MPI_SOURCE, status.MPI_TAG, tag, msg_cnt); + ("ERROR: unexpected message tag from MPI_Probe(): p=%d, expected=%d, actual=%d, count=%d\n", + p_rank, tag, status.MPI_TAG, msg_cnt); } -# if defined(VERBOSE) - { - printf("sending message: p=%d s=%d c=%d\n", - status.MPI_SOURCE, msg_size, msg_cnt); + if (incoming_msg_size != msg_size && errs++ < 10) { + printf + ("ERROR: unexpected message size from MPI_Probe(): p=%d, expected=%d, actual=%d, count=%d\n", + p_rank, msg_size, incoming_msg_size, msg_cnt); } -# endif - /* Send unexpected message which hopefully MPI_Probe() is - * already waiting for at the remote process */ - mpi_errno = MPI_Send(buf, msg_size, MPI_BYTE, - status.MPI_SOURCE, status.MPI_TAG, MPI_COMM_WORLD); + /* Receive the probed message from the main process */ + mpi_errno = MPI_Recv(buf, msg_size, MPI_BYTE, 0, tag, comm, &status); if (mpi_errno != MPI_SUCCESS && errs++ < 10) { MTestPrintError(mpi_errno); } - } - } else { - int incoming_msg_size; - - /* Send synchronization message */ - mpi_errno = MPI_Send(NULL, 0, MPI_BYTE, 0, tag, MPI_COMM_WORLD); - if (mpi_errno != MPI_SUCCESS && errs++ < 10) { - MTestPrintError(mpi_errno); - } - - /* Perform probe, hopefully before the main process can - * send its reply */ - mpi_errno = MPI_Probe(MPI_ANY_SOURCE, MPI_ANY_TAG, MPI_COMM_WORLD, &status); - if (mpi_errno != MPI_SUCCESS && errs++ < 10) { - MTestPrintError(mpi_errno); - } - mpi_errno = MPI_Get_count(&status, MPI_BYTE, &incoming_msg_size); - if (mpi_errno != MPI_SUCCESS && errs++ < 10) { - MTestPrintError(mpi_errno); - } - if (status.MPI_SOURCE != 0 && errs++ < 10) { - printf - ("ERROR: unexpected message source from MPI_Probe(): p=%d, expected=0, actual=%d, count=%d\n", - p_rank, status.MPI_SOURCE, msg_cnt); - } - if (status.MPI_TAG != tag && errs++ < 10) { - printf - ("ERROR: unexpected message tag from MPI_Probe(): p=%d, expected=%d, actual=%d, count=%d\n", - p_rank, tag, status.MPI_TAG, msg_cnt); - } - if (incoming_msg_size != msg_size && errs++ < 10) { - printf - ("ERROR: unexpected message size from MPI_Probe(): p=%d, expected=%d, actual=%d, count=%d\n", - p_rank, msg_size, incoming_msg_size, msg_cnt); - } - - /* Receive the probed message from the main process */ - mpi_errno = MPI_Recv(buf, msg_size, MPI_BYTE, 0, tag, MPI_COMM_WORLD, &status); - if (mpi_errno != MPI_SUCCESS && errs++ < 10) { - MTestPrintError(mpi_errno); - } - mpi_errno = MPI_Get_count(&status, MPI_BYTE, &incoming_msg_size); - if (mpi_errno != MPI_SUCCESS && errs++ < 10) { - MTestPrintError(mpi_errno); - } - if (status.MPI_SOURCE != 0 && errs++ < 10) { - printf - ("ERROR: unexpected message source from MPI_Recv(): p=%d, expected=0, actual=%d, count=%d\n", - p_rank, status.MPI_SOURCE, msg_cnt); - } - if (status.MPI_TAG != tag && errs++ < 10) { - printf - ("ERROR: unexpected message tag from MPI_Recv(): p=%d, expected=%d, actual=%d, count=%d\n", - p_rank, tag, status.MPI_TAG, msg_cnt); - } - if (incoming_msg_size != msg_size && errs++ < 10) { - printf - ("ERROR: unexpected message size from MPI_Recv(): p=%d, expected=%d, actual=%d, count=%d\n", - p_rank, msg_size, incoming_msg_size, msg_cnt); + mpi_errno = MPI_Get_count(&status, MPI_BYTE, &incoming_msg_size); + if (mpi_errno != MPI_SUCCESS && errs++ < 10) { + MTestPrintError(mpi_errno); + } + if (status.MPI_SOURCE != 0 && errs++ < 10) { + printf + ("ERROR: unexpected message source from MPI_Recv(): p=%d, expected=0, actual=%d, count=%d\n", + p_rank, status.MPI_SOURCE, msg_cnt); + } + if (status.MPI_TAG != tag && errs++ < 10) { + printf + ("ERROR: unexpected message tag from MPI_Recv(): p=%d, expected=%d, actual=%d, count=%d\n", + p_rank, tag, status.MPI_TAG, msg_cnt); + } + if (incoming_msg_size != msg_size && errs++ < 10) { + printf + ("ERROR: unexpected message size from MPI_Recv(): p=%d, expected=%d, actual=%d, count=%d\n", + p_rank, msg_size, incoming_msg_size, msg_cnt); + } } } } + MTestFreeComm(&comm); } MTest_Finalize(errs); diff --git a/test/mpi/pt2pt/testlist.in b/test/mpi/pt2pt/testlist.in index 66656e5c02f..9f713342e95 100644 --- a/test/mpi/pt2pt/testlist.in +++ b/test/mpi/pt2pt/testlist.in @@ -52,10 +52,14 @@ waitany_null 1 # perhaps disable in the release tarball large_message 3 mem=6.5 mprobe 2 +mprobe 2 env=MPIR_CVAR_CH4_OFI_EAGER_MAX_MSG_SIZE=16384 +mprobe 2 env=MPIR_CVAR_CH4_OFI_AM_LONG_FORCE_PIPELINE=true big_count_status 1 many_isend 3 manylmt 2 huge_underflow 2 +huge_underflow 2 env=MPIR_CVAR_CH4_OFI_EAGER_MAX_MSG_SIZE=16384 +huge_underflow 2 env=MPIR_CVAR_CH4_OFI_AM_LONG_FORCE_PIPELINE=true huge_anysrc 2 huge_dupcomm 2 huge_ssend 2 diff --git a/test/mpi/threads/pt2pt/Makefile.am b/test/mpi/threads/pt2pt/Makefile.am index 536a1b20150..aed8b7e4e2c 100644 --- a/test/mpi/threads/pt2pt/Makefile.am +++ b/test/mpi/threads/pt2pt/Makefile.am @@ -8,7 +8,7 @@ include $(top_srcdir)/Makefile_threads.mtest EXTRA_DIST = testlist noinst_PROGRAMS = threads threaded_sr alltoall sendselfth greq_wait greq_test \ - multisend multisend2 multisend3 multisend4 ibsend \ + multisend multisend2 multisend3 multisend4 ibsend ssend \ mt_sendrecv mt_bsendrecv mt_ssendrecv \ mt_isendirecv mt_ibsendirecv mt_issendirecv \ mt_sendrecv_huge mt_bsendrecv_huge mt_ssendrecv_huge \ diff --git a/test/mpi/threads/pt2pt/ssend.c b/test/mpi/threads/pt2pt/ssend.c new file mode 100644 index 00000000000..c4b165f5191 --- /dev/null +++ b/test/mpi/threads/pt2pt/ssend.c @@ -0,0 +1,137 @@ +/* + * Copyright (C) by Argonne National Laboratory + * See COPYRIGHT in top-level directory + */ + +#include +#include +#include "mpitest.h" +#include "mpithreadtest.h" + +#define MAX_COUNT 1024 * 1600 + +#define NUM_THREADS 4 +#define NUM_MSG_PER_THREAD 10 +#define NUM_CHECK 4 + +#define TOTAL NUM_THREADS * NUM_MSG_PER_THREAD + +int buf[TOTAL][MAX_COUNT]; +MPI_Request reqs[TOTAL]; +int counts[NUM_THREADS]; + +MPI_Comm comm = MPI_COMM_WORLD; +int tag = 1; + +static MTEST_THREAD_RETURN_TYPE do_ssend(void *arg) +{ + int id = (long) arg; + int base = id * NUM_MSG_PER_THREAD; + + for (int i = 0; i < NUM_MSG_PER_THREAD; i++) { + buf[base + i][0] = id; + int count = 1; + if (i % 2 == 0) { + count = MAX_COUNT; + } + MPI_Issend(buf[base + i], count, MPI_INT, 1, tag, comm, &reqs[base + i]); + } + return NULL; +} + +static MTEST_THREAD_RETURN_TYPE do_recv(void *arg) +{ + int id = (long) arg; + int base = id * NUM_MSG_PER_THREAD; + + for (int i = 0; i < NUM_CHECK; i++) { + MPI_Irecv(buf[base + i], MAX_COUNT, MPI_INT, 0, tag, comm, &reqs[base + i]); + } + MPI_Waitall(NUM_CHECK, reqs + base, MPI_STATUSES_IGNORE); + return NULL; +} + +int main(int argc, char *argv[]) +{ + int errs = 0; + + int provided; + MTest_Init_thread(&argc, &argv, MPI_THREAD_MULTIPLE, &provided); + if (provided != MPI_THREAD_MULTIPLE) { + printf("MPI_THREAD_MULTIPLE not supported by the MPI implementation\n"); + MPI_Abort(MPI_COMM_WORLD, -1); + } + + int rank, size; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + MPI_Comm_size(MPI_COMM_WORLD, &size); + if (size != 2) { + printf("This test require 2 processes\n"); + MPI_Abort(MPI_COMM_WORLD, -1); + } + + + if (rank == 0) { + for (int i = 0; i < NUM_THREADS; i++) { + /* Issend NUM_MSG_PER_THREAD * NUM_THREADS messages */ + MTest_Start_thread(do_ssend, (void *) (long) i); + } + MTest_Join_threads(); + + for (int i = 0; i < NUM_CHECK * NUM_THREADS; i++) { + int id, indx; + MPI_Waitany(TOTAL, reqs, &indx, MPI_STATUS_IGNORE); + id = indx / NUM_MSG_PER_THREAD; + printf(" - %d - %d send complete\n", id, indx); + counts[id]++; + } + + MPI_Send(counts, NUM_THREADS, MPI_INT, 1, tag + 1, comm); + + MPI_Barrier(comm); + + MPI_Waitall(TOTAL, reqs, MPI_STATUSES_IGNORE); + } else { +#if 0 + for (int i = 0; i < NUM_THREADS; i++) { + /* Receive NUM_CHECK * NUM_THREADS messages */ + MTest_Start_thread(do_recv, (void *) (long) i); + } + MTest_Join_threads(); + + for (int j = 0; j < NUM_THREADS; j++) { + for (int i = 0; i < NUM_CHECK; i++) { + int id = buf[j * NUM_MSG_PER_THREAD + i][0]; + counts[id]++; + } + } +#else + for (int i = 0; i < NUM_CHECK * NUM_THREADS; i++) { + MPI_Irecv(buf[i], MAX_COUNT, MPI_INT, 0, tag, comm, &reqs[i]); + } + MPI_Waitall(NUM_CHECK * NUM_THREADS, reqs, MPI_STATUSES_IGNORE); + for (int i = 0; i < NUM_CHECK * NUM_THREADS; i++) { + int id = buf[i][0]; + counts[id]++; + } +#endif + int recv_counts[TOTAL]; + MPI_Recv(recv_counts, NUM_THREADS, MPI_INT, 0, tag + 1, comm, MPI_STATUS_IGNORE); + for (int i = 0; i < NUM_THREADS; i++) { + if (counts[i] != recv_counts[i]) { + errs++; + } + printf("From thread %d, received %d messages, sender reported %d ssend completed\n", i, + counts[i], recv_counts[i]); + } + + MPI_Barrier(comm); + + for (int i = NUM_CHECK * NUM_THREADS; i < TOTAL; i++) { + MPI_Recv(&buf[i], MAX_COUNT, MPI_INT, 0, tag, comm, MPI_STATUS_IGNORE); + } + } + + MTest_Finalize(errs); + return 0; +}