diff --git a/src/common/slurm_protocol_defs.c b/src/common/slurm_protocol_defs.c index 0d7eba21016496d90016f62443aa83dc6442149e..12a17772ae0a4592bbfa8411f4e8ab48825e82e6 100644 --- a/src/common/slurm_protocol_defs.c +++ b/src/common/slurm_protocol_defs.c @@ -221,6 +221,8 @@ void slurm_free_launch_tasks_response_msg(launch_tasks_response_msg_t * if (msg) { if (msg->node_name) xfree(msg->node_name); + if (msg->local_pids) + xfree(msg->local_pids); xfree(msg); } } diff --git a/src/common/slurm_protocol_defs.h b/src/common/slurm_protocol_defs.h index 50d7a87c4997aebde1bfb4db8ac1062934097268..ec6af6291fb94def238dd0b46587151922f59f13 100644 --- a/src/common/slurm_protocol_defs.h +++ b/src/common/slurm_protocol_defs.h @@ -260,7 +260,8 @@ typedef struct launch_tasks_response_msg { uint32_t return_code; char *node_name; uint32_t srun_node_id; - uint32_t local_pid; + uint32_t count_of_pids; + uint32_t *local_pids; } launch_tasks_response_msg_t; typedef struct task_ext_msg { diff --git a/src/common/slurm_protocol_pack.c b/src/common/slurm_protocol_pack.c index 85bf34af37c0c2fbc30ce97e16a358a9759da8c4..043e17c8a093c39279ca4b25eaefcb2b3bb841d9 100644 --- a/src/common/slurm_protocol_pack.c +++ b/src/common/slurm_protocol_pack.c @@ -798,9 +798,13 @@ _unpack_resource_allocation_response_msg(resource_allocation_response_msg_t safe_unpack32_array((uint32_t **) & (tmp_ptr->cpus_per_node), &uint32_tmp, buffer); + if (tmp_ptr->num_cpu_groups != uint32_tmp) + goto unpack_error; safe_unpack32_array((uint32_t **) & (tmp_ptr->cpu_count_reps), &uint32_tmp, buffer); + if (tmp_ptr->num_cpu_groups != uint32_tmp) + goto unpack_error; } else { tmp_ptr->cpus_per_node = NULL; tmp_ptr->cpu_count_reps = NULL; @@ -875,9 +879,13 @@ static int safe_unpack32_array((uint32_t **) & (tmp_ptr->cpus_per_node), &uint32_tmp, buffer); + if (tmp_ptr->num_cpu_groups != uint32_tmp) + goto unpack_error; safe_unpack32_array((uint32_t **) & (tmp_ptr->cpu_count_reps), &uint32_tmp, buffer); + if (tmp_ptr->num_cpu_groups != uint32_tmp) + goto unpack_error; } safe_unpack32(&tmp_ptr->job_step_id, buffer); @@ -1936,6 +1944,8 @@ _unpack_task_exit_msg(task_exit_msg_t ** msg_ptr, Buf buffer) safe_unpack32(&msg->return_code, buffer); safe_unpack32(&msg->num_tasks, buffer); safe_unpack32_array(&msg->task_id_list, &uint32_tmp, buffer); + if (msg->num_tasks != uint32_tmp) + goto unpack_error; return SLURM_SUCCESS; unpack_error: @@ -1951,7 +1961,9 @@ _pack_launch_tasks_response_msg(launch_tasks_response_msg_t * msg, Buf buffer) pack32(msg->return_code, buffer); packstr(msg->node_name, buffer); pack32(msg->srun_node_id, buffer); - pack32(msg->local_pid, buffer); + pack32(msg->count_of_pids, buffer); + pack32_array(msg->local_pids, + msg->count_of_pids, buffer); } static int @@ -1959,6 +1971,7 @@ _unpack_launch_tasks_response_msg(launch_tasks_response_msg_t ** msg_ptr, Buf buffer) { uint16_t uint16_tmp; + uint32_t uint32_tmp; launch_tasks_response_msg_t *msg; msg = xmalloc(sizeof(launch_tasks_response_msg_t)); @@ -1967,7 +1980,10 @@ _unpack_launch_tasks_response_msg(launch_tasks_response_msg_t ** safe_unpack32(&msg->return_code, buffer); safe_unpackstr_xmalloc(&msg->node_name, &uint16_tmp, buffer); safe_unpack32(&msg->srun_node_id, buffer); - safe_unpack32(&msg->local_pid, buffer); + safe_unpack32(&msg->count_of_pids, buffer); + safe_unpack32_array(&msg->local_pids, &uint32_tmp, buffer); + if (msg->count_of_pids != uint32_tmp) + goto unpack_error; return SLURM_SUCCESS; unpack_error: @@ -2035,6 +2051,8 @@ _unpack_launch_tasks_request_msg(launch_tasks_request_msg_t ** safe_unpackstr_xmalloc(&msg->efname, &uint16_tmp, buffer); safe_unpackstr_xmalloc(&msg->ifname, &uint16_tmp, buffer); safe_unpack32_array(&msg->global_task_ids, &uint32_tmp, buffer); + if (msg->tasks_to_launch != uint32_tmp) + goto unpack_error; #ifdef HAVE_LIBELAN3 qsw_alloc_jobinfo(&msg->qsw_job); diff --git a/src/slurmd/req.c b/src/slurmd/req.c index c1dda4844463fae247b1bb594d7d36d3abe95afa..82033d59d75d558f7526321b7ccc5780ded78081 100644 --- a/src/slurmd/req.c +++ b/src/slurmd/req.c @@ -217,6 +217,8 @@ _rpc_launch_tasks(slurm_msg_t *msg, slurm_addr *cli) resp.node_name = conf->hostname; resp.srun_node_id = req->srun_node_id; resp.return_code = rc; + resp.count_of_pids = 0; + resp.local_pids = NULL; /* array type of uint32_t */ slurm_send_only_node_msg(&resp_msg); } diff --git a/src/srun/msg.c b/src/srun/msg.c index 72e31ed1481e396fda85f468a4149ccc7ef1f9ac..58482b15ca2833e005c070e7d049ca1705587990 100644 --- a/src/srun/msg.c +++ b/src/srun/msg.c @@ -86,6 +86,26 @@ static char * _taskid2hostname(int task_id, job_t * job); #define _poll_err(pfd) ((pfd).revents & POLLERR) +#ifdef HAVE_TOTALVIEW +static void +_build_tv_list(launch_tasks_response_msg_t *msg) +{ + MPIR_PROCDESC * tv_tasks; + int i; + + if (!opt.totalview) + return; + + for (i=0; i<msg->count_of_pids; i++) { + tv_tasks = &MPIR_proctable[MPIR_proctable_size++]; + tv_tasks->host_name = msg->node_name; + tv_tasks->executable_name = opt.progname; + tv_tasks->pid = msg->local_pid[i]; + } + msg->node_name = NULL; /* nothing to free */ +} +#endif + static void _launch_handler(job_t *job, slurm_msg_t *resp) { @@ -106,15 +126,7 @@ _launch_handler(job_t *job, slurm_msg_t *resp) job->host_state[msg->srun_node_id] = SRUN_HOST_REPLIED; #ifdef HAVE_TOTALVIEW - if (opt.totalview) { - MPIR_PROCDESC * tv_tasks; - tv_tasks = - &MPIR_proctable[MPIR_proctable_size++]; - tv_tasks->host_name = msg->node_name; - msg->node_name = NULL; /* nothing to free */ - tv_tasks->executable_name = opt.progname; - tv_tasks->pid = msg->local_pid; - } + _build_tv_list(msg); #endif } else error("launch resp from %s has bad task_id %d",