diff --git a/src/api/slurm_pmi.c b/src/api/slurm_pmi.c index ccc003e73719fee8815902fede4ed0cd3864bd8d..acb4eb085aa1c8700da63d37049a565c528217ac 100644 --- a/src/api/slurm_pmi.c +++ b/src/api/slurm_pmi.c @@ -32,6 +32,7 @@ #include "src/common/slurm_protocol_defs.h" #include "src/common/xmalloc.h" +int pmi_fd = -1; uint16_t srun_port = 0; slurm_addr srun_addr; @@ -83,9 +84,9 @@ int slurm_send_kvs_comm_set(struct kvs_comm_set *kvs_set_ptr) int slurm_get_kvs_comm_set(struct kvs_comm_set **kvs_set_ptr, int pmi_rank, int pmi_size) { - int rc, pmi_fd; + int rc, srun_fd; slurm_msg_t msg_send, msg_rcv; - slurm_addr slurm_address; + slurm_addr slurm_addr; char hostname[64]; uint16_t port; kvs_get_msg_t data; @@ -93,14 +94,24 @@ int slurm_get_kvs_comm_set(struct kvs_comm_set **kvs_set_ptr, if (kvs_set_ptr == NULL) return EINVAL; - if ((rc = _get_addr()) != SLURM_SUCCESS) + if ((rc = _get_addr()) != SLURM_SUCCESS) { + error("_get_addr: %m"); return rc; - if ((pmi_fd = slurm_init_msg_engine_port(0)) < 0) - return SLURM_ERROR; - if (slurm_get_stream_addr(pmi_fd, &slurm_address) < 0) + } + if (pmi_fd < 0) { + if ((pmi_fd = slurm_init_msg_engine_port(0)) < 0) { + error("slurm_init_msg_engine_port: %m"); + return SLURM_ERROR; + } + fd_set_blocking(pmi_fd); + } + if (slurm_get_stream_addr(pmi_fd, &slurm_addr) < 0) { + error("slurm_get_stream_addr: %m"); return SLURM_ERROR; - fd_set_nonblocking(pmi_fd); - port = slurm_address.sin_port; + } + /* hostname is not set here, so slurm_get_addr fails + slurm_get_addr(&slurm_addr, &port, hostname, sizeof(hostname)); */ + port = ntohs(slurm_addr.sin_port); getnodename(hostname, sizeof(hostname)); data.task_id = pmi_rank; @@ -111,19 +122,44 @@ int slurm_get_kvs_comm_set(struct kvs_comm_set **kvs_set_ptr, msg_send.msg_type = PMI_KVS_GET_REQ; msg_send.data = &data; - /* Send the RPC to the local srun communcation manager */ - if (slurm_send_recv_node_msg(&msg_send, &msg_rcv, 0) < 0) + /* Send the RPC to the srun communcation manager */ + if 
(slurm_send_recv_node_msg(&msg_send, &msg_rcv, 0) < 0) { + error("slurm_send_recv_node_msg: %m"); return SLURM_ERROR; - if (msg_rcv.msg_type != RESPONSE_SLURM_RC) + } + if (msg_rcv.msg_type != RESPONSE_SLURM_RC) { + error("slurm_get_kvs_comm_set msg_type=%d", msg_rcv.msg_type); return SLURM_UNEXPECTED_MSG_ERROR; + } rc = ((return_code_msg_t *) msg_rcv.data)->return_code; slurm_free_return_code_msg((return_code_msg_t *) msg_rcv.data); - if (rc != SLURM_SUCCESS) + if (rc != SLURM_SUCCESS) { + error("slurm_get_kvs_comm_set error_code=%d", rc); return rc; + } /* get the message after all tasks reach the barrier */ -/* slurm_close_accepted_conn(pmi_fd); Consider leaving socket open */ - *kvs_set_ptr = NULL; +info("waiting for msg on port %u", port); + srun_fd = slurm_accept_msg_conn(pmi_fd, &srun_addr); + if (srun_fd < 0) { + error("slurm_accept_msg_conn: %m"); + return errno; + } +again: if (slurm_receive_msg(srun_fd, &msg_rcv, 0) < 0) { + if (errno == EINTR) + goto again; + error("slurm_receive_msg: %m"); + return errno; + } +info("got msg"); + if (msg_rcv.msg_type != PMI_KVS_GET_RESP) { + error("slurm_get_kvs_comm_set msg_type=%d", msg_rcv.msg_type); + return SLURM_UNEXPECTED_MSG_ERROR; + } + slurm_send_rc_msg(&msg_rcv, SLURM_SUCCESS); +info("sent reply"); + slurm_close_accepted_conn(srun_fd); + *kvs_set_ptr = msg_rcv.data; return SLURM_SUCCESS; } diff --git a/src/common/slurm_protocol_socket_implementation.c b/src/common/slurm_protocol_socket_implementation.c index 8442690f1a872c97a112ba6a895b432ad1f66dbc..07370e3e5801a66225ee59f64a3c54f32a06fe24 100644 --- a/src/common/slurm_protocol_socket_implementation.c +++ b/src/common/slurm_protocol_socket_implementation.c @@ -751,7 +751,7 @@ void _slurm_get_addr (slurm_addr *addr, uint16_t *port, char *host, (void *) &h_buf, sizeof(h_buf), &h_err ); if (he != NULL) { - *port = addr->sin_port; + *port = ntohs(addr->sin_port); strncpy(host, he->h_name, buflen); } else { error("Lookup failed: %s", host_strerror(h_err)); diff 
--git a/src/srun/allocate.c b/src/srun/allocate.c index acd8c804d526971296e3b63356709dcbc6f72f80..c999ae763014554d4619f0bd4730b258e99c138b 100644 --- a/src/srun/allocate.c +++ b/src/srun/allocate.c @@ -257,7 +257,7 @@ _accept_msg_connection(slurm_fd slurmctld_fd, } slurm_get_addr(&cli_addr, &port, host, sizeof(host)); - debug2("got message connection from %s:%d", host, ntohs(port)); + debug2("got message connection from %s:%d", host, port); msg = xmalloc(sizeof(*msg)); diff --git a/src/srun/msg.c b/src/srun/msg.c index 06caeeabdac3db7035aebcc363ff9f117678348c..af25de6878ca19a5bfb9eedece237cb19cd23378 100644 --- a/src/srun/msg.c +++ b/src/srun/msg.c @@ -1166,12 +1166,12 @@ extern slurm_fd slurmctld_msg_init(void) if (slurm_get_stream_addr(slurmctld_fd, &slurm_address) < 0) fatal("slurm_get_stream_addr error %m"); fd_set_nonblocking(slurmctld_fd); - /* hostname is not set, so slurm_get_addr fails + /* hostname is not set, so slurm_get_addr fails slurm_get_addr(&slurm_address, &port, hostname, sizeof(hostname)); */ - port = slurm_address.sin_port; + port = ntohs(slurm_address.sin_port); getnodename(hostname, sizeof(hostname)); slurmctld_comm_addr.hostname = xstrdup(hostname); - slurmctld_comm_addr.port = ntohs(port); + slurmctld_comm_addr.port = port; debug2("slurmctld messasges to host=%s,port=%u", slurmctld_comm_addr.hostname, slurmctld_comm_addr.port); diff --git a/src/srun/pmi.c b/src/srun/pmi.c index 7f504dab72b49454132cb9abf0a474402b6c15dc..4e3258e8e6668cd0cccaa1f0e0d26f57a52e8126 100644 --- a/src/srun/pmi.c +++ b/src/srun/pmi.c @@ -33,7 +33,10 @@ #include <slurm/slurm_errno.h> #include "src/api/slurm_pmi.h" +#include "src/common/macros.h" #include "src/common/slurm_protocol_defs.h" +#include "src/common/xsignal.h" +#include "src/common/xstring.h" #include "src/common/xmalloc.h" #define _DEBUG 1 @@ -51,12 +54,133 @@ struct barrier_resp *barrier_ptr = NULL; uint16_t barrier_resp_cnt = 0; uint16_t barrier_cnt = 0; -/* transmit the KVS keypairs to all tasks, 
waiting at a barrier */ +struct agent_arg { + struct barrier_resp *barrier_xmit_ptr; + int barrier_xmit_cnt; + struct kvs_comm **kvs_xmit_ptr; + int kvs_xmit_cnt; +}; + +static void *_agent(void *x); +static struct kvs_comm *_find_kvs_by_name(char *name); +struct kvs_comm **_kvs_comm_dup(void); +static void _kvs_xmit_tasks(void); +static void _merge_named_kvs(struct kvs_comm *kvs_orig, + struct kvs_comm *kvs_new); +static void _move_kvs(struct kvs_comm *kvs_new); +static void _print_kvs(void); + +/* Transmit the KVS keypairs to all tasks, waiting at a barrier + * This will take some time, so we work with a copy of the KVS keypairs. + * We also work with a private copy of the barrier data and clear the + * global data pointers so any new barrier requests get treated as + * completely independent of this one. */ static void _kvs_xmit_tasks(void) { + struct agent_arg args; + pthread_attr_t attr; + pthread_t agent_id; + #if _DEBUG info("All tasks at barrier, transmit KVS keypairs now"); #endif + /* copy the data */ + args.barrier_xmit_ptr = barrier_ptr; + args.barrier_xmit_cnt = barrier_cnt; + barrier_ptr = NULL; + barrier_resp_cnt = 0; + barrier_cnt = 0; + args.kvs_xmit_ptr = _kvs_comm_dup(); + args.kvs_xmit_cnt = kvs_comm_cnt; + + /* Spawn a pthread to transmit it */ + slurm_attr_init(&attr); + pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_DETACHED); +// if (pthread_create(&agent_id, &attr, _agent, (void *) &args)) +// fatal("pthread_create"); +/* FIXME: signaling problem if pthread; args is a stack local of this function, so it must be given a longer lifetime (e.g. heap-allocated) before the detached thread is re-enabled */ +_agent((void *) &args); +} + +static void *_agent(void *x) +{ + struct agent_arg *args = (struct agent_arg *) x; + struct kvs_comm_set kvs_set; + int i, j, rc; + slurm_msg_t msg_send, msg_rcv; + + /* send the KVS set to every task recorded at the barrier; a failed send is logged and that task is skipped */ + kvs_set.kvs_comm_recs = args->kvs_xmit_cnt; + kvs_set.kvs_comm_ptr = args->kvs_xmit_ptr; + msg_send.msg_type = PMI_KVS_GET_RESP; + msg_send.data = (void *) &kvs_set; + for (i=0; i<args->barrier_xmit_cnt; i++) { + debug2("KVS_Barrier msg to %s:%u", +
args->barrier_xmit_ptr[i].hostname, + args->barrier_xmit_ptr[i].port); + slurm_set_addr(&msg_send.address, + args->barrier_xmit_ptr[i].port, + args->barrier_xmit_ptr[i].hostname); + //if (slurm_send_recv_node_msg(&msg_send, &msg_rcv, 0) < 0) { + if (slurm_send_only_node_msg(&msg_send) < 0) { + error("KVS_Barrier reply fail to %s", + args->barrier_xmit_ptr[i].hostname); + continue; + } +continue; +/* FIXME: timing problem waiting for reply; the reply checks below are unreachable and msg_rcv is never filled while the send-only call above is used */ + if (msg_rcv.msg_type != RESPONSE_SLURM_RC) { + error("KVS_Barrier msg reply type %d bad", + msg_rcv.msg_type); + continue; + } + rc = ((return_code_msg_t *) msg_rcv.data)->return_code; + slurm_free_return_code_msg((return_code_msg_t *) msg_rcv.data); + if (rc != SLURM_SUCCESS) + error("KVS_Barrier rc=%d", rc); + } + + /* Release the private barrier and KVS copies referenced by args */ + for (i=0; i<args->barrier_xmit_cnt; i++) + xfree(args->barrier_xmit_ptr[i].hostname); + xfree(args->barrier_xmit_ptr); + for (i=0; i<args->kvs_xmit_cnt; i++) { + for (j=0; j<args->kvs_xmit_ptr[i]->kvs_cnt; j++) { + xfree(args->kvs_xmit_ptr[i]->kvs_keys[j]); + xfree(args->kvs_xmit_ptr[i]->kvs_values[j]); + } + xfree(args->kvs_xmit_ptr[i]->kvs_keys); + xfree(args->kvs_xmit_ptr[i]->kvs_values); + xfree(args->kvs_xmit_ptr[i]->kvs_name); + xfree(args->kvs_xmit_ptr[i]); + } + xfree(args->kvs_xmit_ptr); + return NULL; +} + +/* deep-copy the global KVS comm array (names, keys and values); the copy is released by _agent() */ +struct kvs_comm **_kvs_comm_dup(void) +{ + int i, j; + struct kvs_comm **rc_kvs; + + rc_kvs = xmalloc(sizeof(struct kvs_comm *) * kvs_comm_cnt); + for (i=0; i<kvs_comm_cnt; i++) { + rc_kvs[i] = xmalloc(sizeof(struct kvs_comm)); + rc_kvs[i]->kvs_name = xstrdup(kvs_comm_ptr[i]->kvs_name); + rc_kvs[i]->kvs_cnt = kvs_comm_ptr[i]->kvs_cnt; + rc_kvs[i]->kvs_keys = + xmalloc(sizeof(char *) * rc_kvs[i]->kvs_cnt); + rc_kvs[i]->kvs_values = + xmalloc(sizeof(char *) * rc_kvs[i]->kvs_cnt); + for (j=0; j<rc_kvs[i]->kvs_cnt; j++) { + rc_kvs[i]->kvs_keys[j] = + xstrdup(kvs_comm_ptr[i]->kvs_keys[j]); + rc_kvs[i]->kvs_values[j] = 
+ xstrdup(kvs_comm_ptr[i]->kvs_values[j]); + } + } + return rc_kvs; } /* return pointer to named kvs element or NULL if not found */ @@ -76,6 +200,7 @@ static void _merge_named_kvs(struct kvs_comm *kvs_orig, struct kvs_comm *kvs_new) { int i, j; + for (i=0; i<kvs_new->kvs_cnt; i++) { for (j=0; j<kvs_orig->kvs_cnt; j++) { if (strcmp(kvs_new->kvs_keys[i], kvs_orig->kvs_keys[j])) @@ -156,8 +281,9 @@ extern int pmi_kvs_get(kvs_get_msg_t *kvs_get_ptr) int rc = SLURM_SUCCESS; #if _DEBUG - info("pmi_kvs_get: rank:%u size:%u port:%u, host:%s", kvs_get_ptr->task_id, - kvs_get_ptr->size, kvs_get_ptr->port, kvs_get_ptr->hostname); + info("pmi_kvs_get: rank:%u size:%u port:%u, host:%s", + kvs_get_ptr->task_id, kvs_get_ptr->size, + kvs_get_ptr->port, kvs_get_ptr->hostname); #endif if (kvs_get_ptr->size == 0) { error("PMK_KVS_Barrier reached with size == 0"); @@ -167,7 +293,7 @@ extern int pmi_kvs_get(kvs_get_msg_t *kvs_get_ptr) pthread_mutex_lock(&kvs_mutex); if (barrier_cnt == 0) { barrier_cnt = kvs_get_ptr->size; - barrier_ptr = xmalloc(sizeof(struct barrier_resp) * barrier_cnt); + barrier_ptr = xmalloc(sizeof(struct barrier_resp)*barrier_cnt); } else if (barrier_cnt != kvs_get_ptr->size) { error("PMK_KVS_Barrier task count inconsistent (%u != %u)", barrier_cnt, kvs_get_ptr->size); diff --git a/testsuite/expect/test7.2 b/testsuite/expect/test7.2 index d55a27c9031b0e666ea3bdab1141353782b7de0c..92889e5f9b89630410a12ae02761185ff6a95485 100755 --- a/testsuite/expect/test7.2 +++ b/testsuite/expect/test7.2 @@ -67,7 +67,7 @@ exec $bin_chmod 700 $file_prog_get # Spawn a job to test BNR functionality # set timeout $max_job_delay -spawn $srun -N1 -n1 -O -t1 $file_prog_get +spawn $srun -l -N1-2 -n4 -O -t1 $file_prog_get expect { -re "FAILURE" { send_user "\nFAILURE: some error occured\n" diff --git a/testsuite/expect/test7.2.prog.c b/testsuite/expect/test7.2.prog.c index fe3cf68591989b8766d7143c5f768e87a9261bfc..2132a3ada903b9e873361a26d7f482d09622e8c8 100644 --- 
a/testsuite/expect/test7.2.prog.c +++ b/testsuite/expect/test7.2.prog.c @@ -150,8 +150,8 @@ main (int argc, char **argv) exit(1); } printf("PMI_Barrier completed\n"); - /* Task 0 only: Now lets get all keypairs and validate */ - if (pmi_rank == 0) { + /* Tasks 0 and 1 only: Now lets get all keypairs and validate */ + if (pmi_rank <= 1) { for (i=0; i<pmi_size; i++) { snprintf(key, key_len, "ATTR_1_%d", i); if ((rc = PMI_KVS_Get(kvs_name, key, val, val_len))