From 9876c6f8a1a2b0ac6fc230d2ccdd401ceb781bca Mon Sep 17 00:00:00 2001
From: "Christopher J. Morrone" <morrone2@llnl.gov>
Date: Tue, 29 Nov 2005 01:09:32 +0000
Subject: [PATCH] First step towards fixing srun signalling when using --attach
 --join.

Fixes fallout from the srun message handler fork() by passing the
slurm_step_layout_t data from the message process to the main srun
process. (The MPIR debugger proctable stuff should probably be handled
in the same place, rather than having four separate message types of its
own, PIPE_MPIR_*).
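
For reference, the new pipe message written by _update_step_layout() and
consumed by par_thr()/_handle_update_step_layout() carries the following
fields, in order. This is only a summary of the code below, shown as a
hypothetical struct for readability; the actual transfer is a sequence of
raw write()/read() calls on the pipe:

	int       type;          /* PIPE_UPDATE_STEP_LAYOUT, read by par_thr() */
	int       dummy;         /* 0xdeadbeef, ignored by the reader          */
	int       nodeid;        /* node this update describes                 */
	uint32_t  num_hosts;     /* nodes in the job step                      */
	uint32_t  num_tasks;     /* tasks in the job step                      */
	int       len;           /* strlen(host[nodeid]) + 1                   */
	char      host[len];     /* NUL-terminated hostname of the node        */
	uint32_t  ntasks;        /* layout->tasks[nodeid]                      */
	uint32_t  tids[ntasks];  /* global task ids running on the node        */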

Also rewrote par_thr() to use a switch statement rather than a long
chain of else-if statements.
---
 src/common/global_srun.c |   7 +-
 src/srun/msg.c           | 157 +++++++++++++++++++++++++++++++++------
 src/srun/srun_job.h      |   3 +-
 3 files changed, 140 insertions(+), 27 deletions(-)

diff --git a/src/common/global_srun.c b/src/common/global_srun.c
index 5d2b8421ec7..17bf9ab1e4d 100644
--- a/src/common/global_srun.c
+++ b/src/common/global_srun.c
@@ -135,9 +135,10 @@ job_active_tasks_on_host(srun_job_t *job, int hostid)
 
 	slurm_mutex_lock(&job->task_mutex);
 	for (i = 0; i < job->step_layout->tasks[hostid]; i++) {
-		uint32_t tid = job->step_layout->tids[hostid][i];
-		debug("Task %d state: %d", tid, job->task_state[tid]);
-		if (job->task_state[tid] == SRUN_TASK_RUNNING) 
+		uint32_t *tids = job->step_layout->tids[hostid];
+		xassert(tids != NULL);
+		debug("Task %d state: %d", tids[i], job->task_state[tids[i]]);
+		if (job->task_state[tids[i]] == SRUN_TASK_RUNNING) 
 			retval++;
 	}
 	slurm_mutex_unlock(&job->task_mutex);
diff --git a/src/srun/msg.c b/src/srun/msg.c
index 9e623146dc2..caa968b39bf 100644
--- a/src/srun/msg.c
+++ b/src/srun/msg.c
@@ -104,6 +104,24 @@ static void     _node_fail_handler(char *nodelist, srun_job_t *job);
 #define _poll_wr_isset(pfd) ((pfd).revents & POLLOUT)
 #define _poll_err(pfd)      ((pfd).revents & POLLERR)
 
+#define safe_read(fd, ptr, size) do {					\
+		if (read(fd, ptr, size) != size) {			\
+			debug("%s:%d: %s: read (%d bytes) failed: %m",	\
+			      __FILE__, __LINE__, __CURRENT_FUNC__,	\
+			      (int)size);				\
+			goto rwfail;					\
+		}							\
+	} while (0)
+
+#define safe_write(fd, ptr, size) do {					\
+		if (write(fd, ptr, size) != size) {			\
+			debug("%s:%d: %s: write (%d bytes) failed: %m",	\
+			      __FILE__, __LINE__, __CURRENT_FUNC__,	\
+			      (int)size);				\
+			goto rwfail;					\
+		}							\
+	} while (0)
+
 /*
  * Install entry in the MPI_proctable for host with node id `nodeid'
  *  and the number of tasks `ntasks' with pid array `pid'
@@ -162,6 +180,71 @@ _build_proctable(srun_job_t *job, char *host, int nodeid, int ntasks, uint32_t *
 	}
 }
 
+static void _update_step_layout(int fd, slurm_step_layout_t *layout, int nodeid)
+{
+	int msg_type = PIPE_UPDATE_STEP_LAYOUT;
+	int dummy = 0xdeadbeef;
+	int len = 0;
+	
+	safe_write(fd, &msg_type, sizeof(int)); /* read by par_thr() */
+	safe_write(fd, &dummy, sizeof(int));    /* read by par_thr() */
+
+	/* the rest are read by _handle_update_step_layout() */
+	safe_write(fd, &nodeid, sizeof(int));
+	safe_write(fd, &layout->num_hosts, sizeof(uint32_t));
+	safe_write(fd, &layout->num_tasks, sizeof(uint32_t));
+
+	len = strlen(layout->host[nodeid]) + 1;
+	safe_write(fd, &len, sizeof(int));
+	safe_write(fd, layout->host[nodeid], len);
+
+	safe_write(fd, &layout->tasks[nodeid], sizeof(uint32_t));
+	safe_write(fd, layout->tids[nodeid],
+		   layout->tasks[nodeid]*sizeof(uint32_t));
+
+	return;
+
+rwfail:
+	error("write to srun main process failed");
+	return;
+}
+
+static void _handle_update_step_layout(int fd, slurm_step_layout_t *layout)
+{
+	int nodeid;
+	int len = 0;
+
+	safe_read(fd, &nodeid, sizeof(int));
+	safe_read(fd, &layout->num_hosts, sizeof(uint32_t));
+	safe_read(fd, &layout->num_tasks, sizeof(uint32_t));
+	xassert(nodeid >= 0 && nodeid <= layout->num_tasks);
+
+	/* If this is the first call to this function, then we probably need
+	   to initialize some of the arrays */
+	if (layout->host == NULL)
+		layout->host = xmalloc(layout->num_hosts * sizeof(char *));
+	if (layout->tasks == NULL)
+		layout->tasks = xmalloc(layout->num_hosts * sizeof(uint32_t *));
+	if (layout->tids == NULL)
+		layout->tids = xmalloc(layout->num_hosts * sizeof(uint32_t *));
+
+	safe_read(fd, &len, sizeof(int));
+	/*xassert(layout->host[nodeid] == NULL);*/
+	layout->host[nodeid] = xmalloc(len);
+	safe_read(fd, layout->host[nodeid], len);
+
+	safe_read(fd, &layout->tasks[nodeid], sizeof(uint32_t));
+	xassert(layout->tids[nodeid] == NULL);
+	layout->tids[nodeid] = xmalloc(layout->tasks[nodeid]*sizeof(uint32_t));
+	safe_read(fd, layout->tids[nodeid],
+		  layout->tasks[nodeid]*sizeof(uint32_t));
+	return;
+
+rwfail:
+	error("read from srun message-handler process failed");
+	return;
+}
+
 static void _dump_proctable(srun_job_t *job)
 {
 	int node_inx, task_inx, taskid;
@@ -427,7 +510,6 @@ _reattach_handler(srun_job_t *job, slurm_msg_t *msg)
 {
 	int i;
 	reattach_tasks_response_msg_t *resp = msg->data;
-	pipe_enum_t pipe_enum = PIPE_HOST_STATE;
 	
 	if ((resp->srun_node_id < 0) || (resp->srun_node_id >= job->nhosts)) {
 		error ("Invalid reattach response received");
@@ -439,12 +521,13 @@ _reattach_handler(srun_job_t *job, slurm_msg_t *msg)
 	slurm_mutex_unlock(&job->task_mutex);
 
 	if(message_thread) {
-		write(job->forked_msg->
-		      par_msg->msg_pipe[1],&pipe_enum,sizeof(int));
+		pipe_enum_t pipe_enum = PIPE_HOST_STATE;
+		write(job->forked_msg->par_msg->msg_pipe[1],
+		      &pipe_enum, sizeof(int));
 		write(job->forked_msg->par_msg->msg_pipe[1],
-		      &resp->srun_node_id,sizeof(int));
+		      &resp->srun_node_id, sizeof(int));
 		write(job->forked_msg->par_msg->msg_pipe[1],
-		      &job->host_state[resp->srun_node_id],sizeof(int));
+		      &job->host_state[resp->srun_node_id], sizeof(int));
 	}
 
 	if (resp->return_code != 0) {
@@ -467,14 +550,18 @@ _reattach_handler(srun_job_t *job, slurm_msg_t *msg)
 	job->step_layout->tids[resp->srun_node_id]  = 
 		xmalloc( resp->ntasks * sizeof(uint32_t) );
 
-	job->step_layout->tasks[resp->srun_node_id] = resp->ntasks;      
+	job->step_layout->tasks[resp->srun_node_id] = resp->ntasks;
 
 	for (i = 0; i < resp->ntasks; i++) {
 		job->step_layout->tids[resp->srun_node_id][i] = resp->gtids[i];
 		job->hostid[resp->gtids[i]]      = resp->srun_node_id;
 	}
+	_update_step_layout(job->forked_msg->par_msg->msg_pipe[1],
+			    job->step_layout, resp->srun_node_id);
 
 	/* Build process table for any parallel debugger
+	 * FIXME - does remote_arg* need to be updated
+         *         in the main srun process?
          */
 	if ((remote_argc == 0) && (resp->executable_name)) {
 		remote_argc = 1;
@@ -906,6 +993,11 @@ msg_thr(void *arg)
 	return (void *)1;
 }
 
+
+/*
+ *  This function runs in a pthread of the parent srun process and
+ *  handles messages from the srun message-handler process.
+ */
 void *
 par_thr(void *arg)
 {
@@ -930,10 +1022,12 @@ par_thr(void *arg)
 			continue;
 		} 
 
-		if(type == PIPE_JOB_STATE) {
+		switch(type) {
+		case PIPE_JOB_STATE:
 			debug("PIPE_JOB_STATE, c = %d", c);
 			update_job_state(job, c);
-		} else if(type == PIPE_TASK_STATE) {
+			break;
+		case PIPE_TASK_STATE:
 			debug("PIPE_TASK_STATE, c = %d", c);
 			if(tid == -1) {
 				tid = c;
@@ -949,7 +1043,8 @@ par_thr(void *arg)
 				update_job_state(job, SRUN_JOB_TERMINATED);
 			}
 			tid = -1;
-		} else if(type == PIPE_TASK_EXITCODE) {
+			break;
+		case PIPE_TASK_EXITCODE:
 			debug("PIPE_TASK_EXITCODE");
 			if(tid == -1) {
 				debug("  setting tid");
@@ -961,7 +1056,8 @@ par_thr(void *arg)
 			job->tstatus[tid] = c;
 			slurm_mutex_unlock(&job->task_mutex);
 			tid = -1;
-		} else if(type == PIPE_HOST_STATE) {
+			break;
+		case PIPE_HOST_STATE:
 			if(tid == -1) {
 				tid = c;
 				continue;
@@ -970,20 +1066,24 @@ par_thr(void *arg)
 			job->host_state[tid] = c;
 			slurm_mutex_unlock(&job->task_mutex);
 			tid = -1;
-		} else if(type == PIPE_SIGNALED) {
+			break;
+		case PIPE_SIGNALED:
 			slurm_mutex_lock(&job->state_mutex);
 			job->signaled = c;
 			slurm_mutex_unlock(&job->state_mutex);
-		} else if(type == PIPE_MPIR_PROCTABLE_SIZE) {
+			break;
+		case PIPE_MPIR_PROCTABLE_SIZE:
 			if(MPIR_proctable_size == 0) {
 				MPIR_proctable_size = c;
 				MPIR_proctable = 
 					xmalloc(sizeof(MPIR_PROCDESC) * c);
-			}		
-		} else if(type == PIPE_MPIR_TOTALVIEW_JOBID) {
+			}
+			break;
+		case PIPE_MPIR_TOTALVIEW_JOBID:
 			totalview_jobid = NULL;
 			xstrfmtcat(totalview_jobid, "%lu", c);
-		} else if(type == PIPE_MPIR_PROCDESC) {
+			break;
+		case PIPE_MPIR_PROCDESC:
 			if(tid == -1) {
 				tid = c;
 				continue;
@@ -992,20 +1092,31 @@ par_thr(void *arg)
 				nodeid = c;
 				continue;
 			}
-			MPIR_PROCDESC *tv   = &MPIR_proctable[tid];
-			tv->host_name       = job->step_layout->host[nodeid];
-			tv->executable_name = remote_argv[0];
-			tv->pid             = c;
-			tid = -1;
-			nodeid = -1;
-		} else if(type == PIPE_MPIR_DEBUG_STATE) {
+			{
+				MPIR_PROCDESC *tv = &MPIR_proctable[tid];
+				tv->host_name = job->step_layout->host[nodeid];
+				debug("tv->host_name = %s", tv->host_name);
+				tv->executable_name = remote_argv[0];
+				tv->pid = c;
+				tid = -1;
+				nodeid = -1;
+			}
+			break;
+		case PIPE_MPIR_DEBUG_STATE:
 			MPIR_debug_state = c;
 			MPIR_Breakpoint();
 			if (opt.debugger_test)
 				_dump_proctable(job);
+			break;
+		case PIPE_UPDATE_STEP_LAYOUT:
+			_handle_update_step_layout(par_msg->msg_pipe[0],
+						   job->step_layout);
+			break;
+		default:
+			error("Unrecognized message from message thread %d",
+			      type);
 		}
 		type = PIPE_NONE;
-		
 	}
 	close(par_msg->msg_pipe[0]); // close excess fildes    
 	close(msg_par->msg_pipe[1]); // close excess fildes
diff --git a/src/srun/srun_job.h b/src/srun/srun_job.h
index 09dd683faea..151ff988e51 100644
--- a/src/srun/srun_job.h
+++ b/src/srun/srun_job.h
@@ -56,7 +56,8 @@ typedef enum {
 	PIPE_MPIR_PROCTABLE_SIZE,
 	PIPE_MPIR_TOTALVIEW_JOBID,
 	PIPE_MPIR_PROCDESC,
-	PIPE_MPIR_DEBUG_STATE
+	PIPE_MPIR_DEBUG_STATE,
+	PIPE_UPDATE_STEP_LAYOUT
 } pipe_enum_t;
 
 typedef enum {
-- 
GitLab