From a9408ffeb1fc4b7d221db67f2225255cd14a9d96 Mon Sep 17 00:00:00 2001
From: Moe Jette <jette1@llnl.gov>
Date: Wed, 10 Oct 2007 16:16:44 +0000
Subject: [PATCH] Restore support for the following srun RPCs: SRUN_PING,
 SRUN_USER_MSG, and SRUN_EXEC (including OpenMPI checkpoint).

---
 src/api/step_launch.c | 137 +++++++++++++++++++++++++++++++++++++++---
 src/api/step_launch.h |   5 ++
 src/srun/srun.c       |   3 +-
 src/srun/srun.h       |   2 -
 4 files changed, 133 insertions(+), 14 deletions(-)

diff --git a/src/api/step_launch.c b/src/api/step_launch.c
index 880cdbb3363..c91cf25509b 100644
--- a/src/api/step_launch.c
+++ b/src/api/step_launch.c
@@ -31,6 +31,7 @@
 #endif
 
 #include <errno.h>
+#include <fcntl.h>
 #include <pthread.h>
 #include <stdarg.h>
 #include <stdlib.h>
@@ -40,6 +41,7 @@
 #include <netinet/in.h>
 #include <sys/param.h>
 #include <sys/socket.h>
+#include <sys/stat.h>
 #include <sys/types.h>
 #include <sys/un.h>
 #include <netdb.h> /* for gethostbyname */
@@ -79,11 +81,13 @@ static void _print_launch_msg(launch_tasks_request_msg_t *msg,
 /**********************************************************************
  * Message handler declarations
  **********************************************************************/
+static pid_t  srun_ppid = (pid_t) 0;
 static uid_t  slurm_uid;
-static int _msg_thr_create(struct step_launch_state *sls, int num_nodes);
+static void _exec_prog(slurm_msg_t *msg);
+static int  _msg_thr_create(struct step_launch_state *sls, int num_nodes);
 static void _handle_msg(struct step_launch_state *sls, slurm_msg_t *msg);
 static bool _message_socket_readable(eio_obj_t *obj);
-static int _message_socket_accept(eio_obj_t *obj, List objs);
+static int  _message_socket_accept(eio_obj_t *obj, List objs);
 
 static struct io_operations message_socket_ops = {
 	readable:	&_message_socket_readable,
@@ -907,6 +911,7 @@ _handle_msg(struct step_launch_state *sls, slurm_msg_t *msg)
 {
 	uid_t req_uid = g_slurm_auth_get_uid(msg->auth_cred);
 	uid_t uid = getuid();
+	srun_user_msg_t *um;
 	int rc;
 	
 	if ((req_uid != slurm_uid) && (req_uid != 0) && (req_uid != uid)) {
@@ -926,21 +931,35 @@ _handle_msg(struct step_launch_state *sls, slurm_msg_t *msg)
 		_exit_handler(sls, msg);
 		slurm_free_task_exit_msg(msg->data);
 		break;
-	case SRUN_NODE_FAIL:
-		debug2("received srun node fail");
-		_node_fail_handler(sls, msg);
-		slurm_free_srun_node_fail_msg(msg->data);
+	case SRUN_PING:
+		debug3("slurmctld ping received");
+		slurm_send_rc_msg(msg, SLURM_SUCCESS);
+		slurm_free_srun_ping_msg(msg->data);
 		break;
-	case SRUN_TIMEOUT:
-		debug2("received job step timeout message");
-		_timeout_handler(sls, msg);
-		slurm_free_srun_timeout_msg(msg->data);
+	case SRUN_EXEC:
+		_exec_prog(msg);
+		slurm_free_srun_exec_msg(msg->data);
 		break;
 	case SRUN_JOB_COMPLETE:
 		debug2("received job step complete message");
 		_job_complete_handler(sls, msg);
 		slurm_free_srun_job_complete_msg(msg->data);
 		break;
+	case SRUN_TIMEOUT:
+		debug2("received job step timeout message");
+		_timeout_handler(sls, msg);
+		slurm_free_srun_timeout_msg(msg->data);
+		break;
+	case SRUN_USER_MSG:
+		um = msg->data;
+		info("%s", um->msg);
+		slurm_free_srun_user_msg(msg->data);
+		break;
+	case SRUN_NODE_FAIL:
+		debug2("received srun node fail");
+		_node_fail_handler(sls, msg);
+		slurm_free_srun_node_fail_msg(msg->data);
+		break;
 	case PMI_KVS_PUT_REQ:
 		debug2("PMI_KVS_PUT_REQ received");
 		rc = pmi_kvs_put((struct kvs_comm_set *) msg->data);
@@ -1055,3 +1074,101 @@ static void _print_launch_msg(launch_tasks_request_msg_t *msg,
 	debug3("uid:%ld gid:%ld cwd:%s %d", (long) msg->uid,
 		(long) msg->gid, msg->cwd, nodeid);
 }
+
+void record_ppid(void)
+{
+	srun_ppid = getppid();
+}
+
+/* This is used to initiate an OpenMPI checkpoint program, 
+ * but is written to be general purpose */
+static void
+_exec_prog(slurm_msg_t *msg)
+{
+	pid_t child;
+	int pfd[2], status, exit_code = 0, i;
+	ssize_t len;
+	char *argv[4], buf[256] = "";
+	time_t now = time(NULL);
+	bool checkpoint = false;
+	srun_exec_msg_t *exec_msg = msg->data;
+
+	if (exec_msg->argc > 2) {
+		verbose("Exec '%s %s' for %u.%u", 
+			exec_msg->argv[0], exec_msg->argv[1],
+			exec_msg->job_id, exec_msg->step_id);
+	} else {
+		verbose("Exec '%s' for %u.%u", 
+			exec_msg->argv[0], 
+			exec_msg->job_id, exec_msg->step_id);
+	}
+
+	if (strcmp(exec_msg->argv[0], "ompi-checkpoint") == 0) {
+		if (srun_ppid)
+			checkpoint = true;
+		else {
+			error("Can not create checkpoint, no srun_ppid set");
+			exit_code = EINVAL;
+			goto fini;
+		}
+	}
+	if (checkpoint) {
+		/* OpenMPI specific checkpoint support */
+		info("Checkpoint started at %s", ctime(&now));
+		for (i=0; (exec_msg->argv[i] && (i<2)); i++) {
+			argv[i] = exec_msg->argv[i];
+		}
+		snprintf(buf, sizeof(buf), "%ld", (long) srun_ppid);
+		argv[i] = buf;
+		argv[i+1] = NULL;
+	}
+
+	if (pipe(pfd) == -1) {
+		snprintf(buf, sizeof(buf), "pipe: %s", strerror(errno));
+		error("%s", buf);
+		exit_code = errno;
+		goto fini;
+	}
+
+	child = fork();
+	if (child == 0) {
+		int fd = open("/dev/null", O_RDONLY);
+		dup2(fd, 0);		/* stdin from /dev/null */
+		dup2(pfd[1], 1);	/* stdout to pipe */
+		dup2(pfd[1], 2);	/* stderr to pipe */
+		close(pfd[0]);
+		close(pfd[1]);
+		if (checkpoint)
+			execvp(exec_msg->argv[0], argv);
+		else
+			execvp(exec_msg->argv[0], exec_msg->argv);
+		error("execvp(%s): %m", exec_msg->argv[0]);
+	} else if (child < 0) {
+		snprintf(buf, sizeof(buf), "fork: %s", strerror(errno));
+		error("%s", buf);
+		exit_code = errno;
+		goto fini;
+	} else {
+		close(pfd[1]);
+		len = read(pfd[0], buf, sizeof(buf));
+		close(pfd[0]);
+		waitpid(child, &status, 0);
+		exit_code = WEXITSTATUS(status);
+	}
+
+fini:	if (checkpoint) {
+		now = time(NULL);
+		if (exit_code) {
+			info("Checkpoint completion code %d at %s", 
+				exit_code, ctime(&now));
+		} else {
+			info("Checkpoint completed successfully at %s",
+				ctime(&now));
+		}
+		if (buf[0])
+			info("Checkpoint location: %s", buf);
+		slurm_checkpoint_complete(exec_msg->job_id, exec_msg->step_id,
+			time(NULL), (uint32_t) exit_code, buf);
+	}
+}
+
diff --git a/src/api/step_launch.h b/src/api/step_launch.h
index 53ab565c9f7..1faca3f13c9 100644
--- a/src/api/step_launch.h
+++ b/src/api/step_launch.h
@@ -94,4 +94,9 @@ struct step_launch_state * step_launch_state_create(slurm_step_ctx_t *ctx);
  */
 void step_launch_state_destroy(struct step_launch_state *sls);
 
+/*
+ * Record the parent process ID of the program which spawned this.
+ * Needed to locate the mpirun program for OpenMPI checkpoint
+ */
+void record_ppid(void);
 #endif /* _STEP_LAUNCH_H */
diff --git a/src/srun/srun.c b/src/srun/srun.c
index 5bb9ebe0575..41701500eac 100644
--- a/src/srun/srun.c
+++ b/src/srun/srun.c
@@ -100,7 +100,6 @@
 #define	TYPE_SCRIPT	2
 
 mpi_plugin_client_info_t mpi_job_info[1];
-pid_t srun_ppid = 0;
 static struct termios termdefaults;
 int global_rc;
 srun_job_t *job = NULL;
@@ -185,7 +184,7 @@ int srun(int ac, char **av)
 		error ("srun initialization failed");
 		exit (1);
 	}
-	srun_ppid = getppid();
+	record_ppid();
 	
 	/* reinit log with new verbosity (if changed by command line)
 	 */
diff --git a/src/srun/srun.h b/src/srun/srun.h
index 90f9aaf6230..25c492b1830 100644
--- a/src/srun/srun.h
+++ b/src/srun/srun.h
@@ -37,8 +37,6 @@
 #include "src/api/step_io.h"
 #include "src/srun/srun_job.h"
 
-extern pid_t srun_ppid;		/* required for OpenMPI checkpoint */
-
 void srun_set_stdio_fds(srun_job_t *job, slurm_step_io_fds_t *cio_fds);
 
 #endif /* !_HAVE_SRUN_H */
-- 
GitLab