diff --git a/src/api/pmi.c b/src/api/pmi.c
index 74e60ffe95d997bef3de147dda87a32677303418..1e2591a9c4a0461db384726eb8c482a345ae5cf0 100644
--- a/src/api/pmi.c
+++ b/src/api/pmi.c
@@ -553,7 +553,8 @@ int PMI_Barrier( void )
 	int i, j, k, rc = PMI_SUCCESS;
 
 	/* Issue the RPC */
-	if (slurm_get_kvs_comm_set(&kvs_set_ptr, pmi_rank) != SLURM_SUCCESS)
+	if (slurm_get_kvs_comm_set(&kvs_set_ptr, pmi_rank, pmi_size)
+			!= SLURM_SUCCESS)
 		return PMI_FAIL;
 	if (kvs_set_ptr == NULL)
 		return PMI_SUCCESS;
diff --git a/src/api/slurm_pmi.c b/src/api/slurm_pmi.c
index 86ce5c318e456739b5fd44a5557cacbd660a2dea..ccc003e73719fee8815902fede4ed0cd3864bd8d 100644
--- a/src/api/slurm_pmi.c
+++ b/src/api/slurm_pmi.c
@@ -80,7 +80,8 @@ int slurm_send_kvs_comm_set(struct kvs_comm_set *kvs_set_ptr)
 }
 
 /* Wait for barrier and get full PMI Keyval space data */
-int slurm_get_kvs_comm_set(struct kvs_comm_set **kvs_set_ptr, int pmi_rank)
+int slurm_get_kvs_comm_set(struct kvs_comm_set **kvs_set_ptr,
+		int pmi_rank, int pmi_size)
 {
 	int rc, pmi_fd;
 	slurm_msg_t msg_send, msg_rcv;
@@ -103,6 +104,7 @@ int slurm_get_kvs_comm_set(struct kvs_comm_set **kvs_set_ptr, int pmi_rank)
 	getnodename(hostname, sizeof(hostname));
 
 	data.task_id = pmi_rank;
+	data.size = pmi_size;
 	data.port = port;
 	data.hostname = hostname;
 	msg_send.address = srun_addr;
diff --git a/src/api/slurm_pmi.h b/src/api/slurm_pmi.h
index 4450d1175a3429134b41c853ff5c053cd2204c2c..ed7113d7c24862bbf1f67853121ec8b4ed84f594 100644
--- a/src/api/slurm_pmi.h
+++ b/src/api/slurm_pmi.h
@@ -62,7 +62,8 @@ struct kvs_comm_set {
 int slurm_send_kvs_comm_set(struct kvs_comm_set *kvs_set_ptr);
 
 /* Wait for barrier and get full PMI Keyval space data */
-int slurm_get_kvs_comm_set(struct kvs_comm_set **kvs_set_ptr, int pmi_rank);
+int slurm_get_kvs_comm_set(struct kvs_comm_set **kvs_set_ptr,
+		int pmi_rank, int pmi_size);
 
 /* Free kvs_comm_set returned by slurm_get_kvs_comm_set() */
 void slurm_free_kvs_comm_set(struct kvs_comm_set *kvs_set_ptr);
diff --git a/src/common/slurm_protocol_defs.h b/src/common/slurm_protocol_defs.h
index 46579f9ccf8b62a4c4d941cb818b0398737d1efd..64e3458bd06e4a0a9142a7dd0dbe4a9fba07497f 100644
--- a/src/common/slurm_protocol_defs.h
+++ b/src/common/slurm_protocol_defs.h
@@ -455,6 +455,7 @@ typedef struct jobacct_msg {
 
 typedef struct kvs_get_msg {
 	uint16_t task_id;	/* job step's task id */
+	uint16_t size;		/* count of tasks in job */
 	uint16_t port;		/* port to be sent the kvs data */
 	char * hostname;	/* hostname to be sent the kvs data */
 } kvs_get_msg_t;
diff --git a/src/common/slurm_protocol_pack.c b/src/common/slurm_protocol_pack.c
index b9db1bc82fe8a416a764b62acbc697930cd662b8..042e56ecac444ebe7bb2d569e6c22adcd5604ec4 100644
--- a/src/common/slurm_protocol_pack.c
+++ b/src/common/slurm_protocol_pack.c
@@ -3370,6 +3370,7 @@ unpack_error:
 static void _pack_kvs_get(kvs_get_msg_t *msg_ptr, Buf buffer)
 {
 	pack16(msg_ptr->task_id, buffer);
+	pack16(msg_ptr->size, buffer);
 	pack16(msg_ptr->port, buffer);
 	packstr(msg_ptr->hostname, buffer);
 }
@@ -3382,6 +3383,7 @@ static int _unpack_kvs_get(kvs_get_msg_t **msg_ptr, Buf buffer)
 	msg = xmalloc(sizeof(struct kvs_get_msg));
 	*msg_ptr = msg;
 	safe_unpack16(&msg->task_id, buffer);
+	safe_unpack16(&msg->size, buffer);
 	safe_unpack16(&msg->port, buffer);
 	safe_unpackstr_xmalloc(&msg->hostname, &uint16_tmp, buffer);
 	return SLURM_SUCCESS;
diff --git a/src/srun/pmi.c b/src/srun/pmi.c
index c0bb1880fda62c6431a9452fe376bc55f5a5655c..7f504dab72b49454132cb9abf0a474402b6c15dc 100644
--- a/src/srun/pmi.c
+++ b/src/srun/pmi.c
@@ -43,6 +43,22 @@ pthread_mutex_t kvs_mutex = PTHREAD_MUTEX_INITIALIZER;
 int kvs_comm_cnt = 0;
 struct kvs_comm **kvs_comm_ptr = NULL;
 
+struct barrier_resp {
+	uint16_t port;		/* port from the task's request, 0 if slot unused */
+	char *hostname;		/* hostname from the task's request */
+};
+struct barrier_resp *barrier_ptr = NULL;	/* one slot per task, indexed by task id */
+uint16_t barrier_resp_cnt = 0;		/* count of tasks recorded so far */
+uint16_t barrier_cnt = 0;		/* expected task count, 0 until first request */
+
+/* transmit the KVS keypairs to all tasks, waiting at a barrier */
+static void _kvs_xmit_tasks(void)
+{
+#if _DEBUG
+	info("All tasks at barrier, transmit KVS keypairs now");
+#endif
+}
+
 /* return pointer to named kvs element or NULL if not found */
 static struct kvs_comm *_find_kvs_by_name(char *name)
 {
@@ -137,7 +153,63 @@ extern int pmi_kvs_put(struct kvs_comm_set *kvs_set_ptr)
 
 extern int pmi_kvs_get(kvs_get_msg_t *kvs_get_ptr)
 {
-debug("pmi_kvs_get: rank:%u port:%u, host:%s", kvs_get_ptr->task_id, kvs_get_ptr->port, kvs_get_ptr->hostname);
-	return SLURM_SUCCESS;
+	/*
+	 * Record one task's arrival at the PMI barrier. The first request
+	 * fixes the expected task count and allocates the response table;
+	 * once every slot is filled, _kvs_xmit_tasks() is called.
+	 * Takes ownership of kvs_get_ptr->hostname when the request is stored.
+	 * Returns SLURM_SUCCESS, or SLURM_ERROR for an invalid request.
+	 */
+	int rc = SLURM_SUCCESS;
+	struct barrier_resp *resp_ptr;
+
+#if _DEBUG
+	info("pmi_kvs_get: rank:%u size:%u port:%u, host:%s",
+		kvs_get_ptr->task_id, kvs_get_ptr->size,
+		kvs_get_ptr->port, kvs_get_ptr->hostname);
+#endif
+	if (kvs_get_ptr->size == 0) {
+		error("PMK_KVS_Barrier reached with size == 0");
+		return SLURM_ERROR;
+	}
+	if (kvs_get_ptr->port == 0) {
+		/* port 0 marks an unused table slot, so it cannot be stored */
+		error("PMK_KVS_Barrier reached with port == 0");
+		return SLURM_ERROR;
+	}
+
+	pthread_mutex_lock(&kvs_mutex);
+	if (barrier_cnt == 0) {
+		barrier_cnt = kvs_get_ptr->size;
+		/* relies on xmalloc() zero-filling so each slot starts at port 0 */
+		barrier_ptr = xmalloc(sizeof(struct barrier_resp) * barrier_cnt);
+	} else if (barrier_cnt != kvs_get_ptr->size) {
+		error("PMK_KVS_Barrier task count inconsistent (%u != %u)",
+			barrier_cnt, kvs_get_ptr->size);
+		rc = SLURM_ERROR;
+		goto fini;
+	}
+	if (kvs_get_ptr->task_id >= barrier_cnt) {
+		error("PMK_KVS_Barrier task id(%u) >= size(%u)",
+			kvs_get_ptr->task_id, barrier_cnt);
+		rc = SLURM_ERROR;
+		goto fini;
+	}
+	resp_ptr = &barrier_ptr[kvs_get_ptr->task_id];
+	if (resp_ptr->port == 0) {
+		barrier_resp_cnt++;
+	} else {
+		error("PMK_KVS_Barrier duplicate request from task %u",
+			kvs_get_ptr->task_id);
+		/* release the hostname kept from the earlier request */
+		xfree(resp_ptr->hostname);
+	}
+	resp_ptr->port = kvs_get_ptr->port;
+	resp_ptr->hostname = kvs_get_ptr->hostname;
+	kvs_get_ptr->hostname = NULL;	/* just moved the pointer */
+	if (barrier_resp_cnt == barrier_cnt)
+		_kvs_xmit_tasks();
+fini:	pthread_mutex_unlock(&kvs_mutex);
+	return rc;
 }
 