/*****************************************************************************\
 *  step_launch.c - launch a parallel job step
*
* $Id: spawn.c 7973 2006-05-08 23:52:35Z morrone $
*****************************************************************************
* Copyright (C) 2006 The Regents of the University of California.
* Produced at Lawrence Livermore National Laboratory (cf, DISCLAIMER).
* Written by Christopher J. Morrone <morrone2@llnl.gov>
* UCRL-CODE-217948.
*
* This file is part of SLURM, a resource management program.
* For details, see <http://www.llnl.gov/linux/slurm/>.
*
* SLURM is free software; you can redistribute it and/or modify it under
* the terms of the GNU General Public License as published by the Free
* Software Foundation; either version 2 of the License, or (at your option)
* any later version.
*
* SLURM is distributed in the hope that it will be useful, but WITHOUT ANY
* WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
* FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
* details.
*
* You should have received a copy of the GNU General Public License along
* with SLURM; if not, write to the Free Software Foundation, Inc.,
* 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
\*****************************************************************************/
#ifdef HAVE_CONFIG_H
# include "config.h"
#endif
#include <errno.h>
#include <pthread.h>
#include <stdarg.h>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <unistd.h>
#include <netinet/in.h>
#include <sys/param.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/un.h>
#include <slurm/slurm.h>
#include "src/common/hostlist.h"
#include "src/common/slurm_protocol_api.h"
#include "src/common/slurm_protocol_defs.h"
#include "src/common/xmalloc.h"
#include "src/common/xstring.h"
#include "src/common/eio.h"
#include "src/common/net.h"
#include "src/common/fd.h"
#include "src/common/slurm_auth.h"
#include "src/common/forward.h"
#include "src/common/plugstack.h"
#include "src/common/slurm_cred.h"
#include "src/api/step_ctx.h"
#include "src/api/step_pmi.h"
/**********************************************************************
* General declarations for step launch code
**********************************************************************/
static int _launch_tasks(slurm_step_ctx ctx,
launch_tasks_request_msg_t *launch_msg);
static client_io_t *_setup_step_client_io(slurm_step_ctx ctx,
slurm_step_io_fds_t fds,
bool labelio);
/* static int _get_step_addresses(const slurm_step_ctx ctx, */
/* slurm_addr **address, int *num_addresses); */
/**********************************************************************
* Message handler declarations
**********************************************************************/
static uid_t slurm_uid;
static int _msg_thr_create(struct step_launch_state *sls, int num_nodes);
static void _handle_msg(struct step_launch_state *sls, slurm_msg_t *msg);
static bool _message_socket_readable(eio_obj_t *obj);
static int _message_socket_accept(eio_obj_t *obj, List objs);
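/*
 * eio drives these callbacks: _message_socket_readable() is consulted
 * to decide whether the listening socket should remain in the poll set,
 * and _message_socket_accept() is invoked once a connection is pending.
 * (This describes how the handlers are used here, not the full eio
 * interface.)
 */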
static struct io_operations message_socket_ops = {
	.readable = &_message_socket_readable,
	.handle_read = &_message_socket_accept
};
/**********************************************************************
* API functions
**********************************************************************/
/*
 * slurm_job_step_launch_t_init - initialize a user-allocated
 *	slurm_job_step_launch_t structure with default values.
 *	This function will NOT allocate any new memory.
 * IN ptr - pointer to a structure allocated by the user.  The structure
 *	will be initialized.
 */
void slurm_job_step_launch_t_init (slurm_job_step_launch_t *ptr)
{
static slurm_step_io_fds_t fds = SLURM_STEP_IO_FDS_INITIALIZER;
ptr->argc = 0;
ptr->argv = NULL;
ptr->envc = 0;
ptr->env = NULL;
ptr->cwd = NULL;
ptr->buffered_stdio = true;
ptr->labelio = false;
ptr->remote_output_filename = NULL;
ptr->remote_error_filename = NULL;
ptr->remote_input_filename = NULL;
memcpy(&ptr->local_fds, &fds, sizeof(fds));
ptr->gid = getgid();
ptr->multi_prog = false;
ptr->slurmd_debug = 0;
ptr->parallel_debug = false;
ptr->task_start_callback = NULL;
ptr->task_finish_callback = NULL;
}
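/*
 * Example (a minimal sketch; my_argc/my_argv stand in for caller-built
 * task arguments):
 *
 *	slurm_job_step_launch_t params;
 *	slurm_job_step_launch_t_init(&params);
 *	params.argc = my_argc;
 *	params.argv = my_argv;
 *	params.labelio = true;
 */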
/*
 * slurm_step_launch - launch a parallel job step
 * IN ctx - job step context generated by slurm_step_ctx_create
 * IN params - job step launch parameters
 * RET SLURM_SUCCESS or SLURM_ERROR (with errno set)
 */
int slurm_step_launch (slurm_step_ctx ctx,
const slurm_job_step_launch_t *params)
{
launch_tasks_request_msg_t launch;
int i;
char **env = NULL;
debug("Entering slurm_step_launch");
if (ctx == NULL || ctx->magic != STEP_CTX_MAGIC) {
error("Not a valid slurm_step_ctx!");
slurm_seterrno(EINVAL);
return SLURM_ERROR;
}
/* Initialize launch state structure */
ctx->launch_state = xmalloc(sizeof(struct step_launch_state));
if (ctx->launch_state == NULL) {
error("Failed to allocate memory for step launch state: %m");
return SLURM_ERROR;
}
pthread_mutex_init(&ctx->launch_state->lock, NULL);
pthread_cond_init(&ctx->launch_state->cond, NULL);
ctx->launch_state->tasks_requested = ctx->step_req->num_tasks;
ctx->launch_state->tasks_start_success = 0;
ctx->launch_state->tasks_start_failure = 0;
ctx->launch_state->tasks_exited = 0;
ctx->launch_state->task_start_callback = params->task_start_callback;
ctx->launch_state->task_finish_callback = params->task_finish_callback;
/* Create message receiving sockets and handler thread */
	if (_msg_thr_create(ctx->launch_state, ctx->step_req->node_count)
	    != SLURM_SUCCESS) {
		error("Failed to create message handler thread");
		return SLURM_ERROR;
	}
/* Start tasks on compute nodes */
launch.job_id = ctx->alloc_resp->job_id;
launch.uid = ctx->step_req->user_id;
launch.gid = params->gid;
launch.argc = params->argc;
launch.argv = params->argv;
launch.cred = ctx->step_resp->cred;
launch.job_step_id = ctx->step_resp->job_step_id;
	/* FIXME - hostname, port, and IP are hardcoded placeholder values */
	env = env_array_create_for_step(ctx->step_resp,
					"localhost",
					15500,
					"127.0.0.1");
env_array_merge(&env, (const char **)params->env);
launch.envc = envcount(env);
launch.env = env;
launch.cwd = params->cwd;
launch.nnodes = ctx->step_req->node_count;
launch.nprocs = ctx->step_req->num_tasks;
launch.slurmd_debug = params->slurmd_debug;
launch.switch_job = ctx->step_resp->switch_job;
launch.task_prolog = NULL; /* FIXME - opt.task_prolog */
launch.task_epilog = NULL; /* FIXME - opt.task_epilog */
launch.cpu_bind_type = 0; /* FIXME opt.cpu_bind_type; */
launch.cpu_bind = NULL; /* FIXME opt.cpu_bind; */
launch.mem_bind_type = 0; /* FIXME opt.mem_bind_type; */
launch.mem_bind = NULL; /* FIXME opt.mem_bind; */
launch.multi_prog = params->multi_prog ? 1 : 0;
launch.options = job_options_create();
spank_set_remote_options (launch.options);
launch.ofname = params->remote_output_filename;
launch.efname = params->remote_error_filename;
launch.ifname = params->remote_input_filename;
launch.buffered_stdio = params->buffered_stdio ? 1 : 0;
	launch.task_flags = 0;
	if (params->parallel_debug)
		launch.task_flags |= TASK_PARALLEL_DEBUG;
/* Node specific message contents */
/* if (slurm_mpi_single_task_per_node ()) { */
/* for (i = 0; i < job->num_hosts; i++) */
/* job->tasks[i] = 1; */
/* } */
launch.tasks_to_launch = ctx->step_resp->step_layout->tasks;
launch.cpus_allocated = ctx->step_resp->step_layout->tasks;
launch.global_task_ids = ctx->step_resp->step_layout->tids;
ctx->launch_state->client_io = _setup_step_client_io(
ctx, params->local_fds, params->labelio);
if (ctx->launch_state->client_io == NULL)
return SLURM_ERROR;
if (client_io_handler_start(ctx->launch_state->client_io)
!= SLURM_SUCCESS)
return SLURM_ERROR;
launch.num_io_port = ctx->launch_state->client_io->num_listen;
launch.io_port = xmalloc(sizeof(uint16_t) * launch.num_io_port);
for (i = 0; i < launch.num_io_port; i++) {
launch.io_port[i] =
ntohs(ctx->launch_state->client_io->listenport[i]);
}
launch.num_resp_port = ctx->launch_state->num_resp_port;
launch.resp_port = xmalloc(sizeof(uint16_t) * launch.num_resp_port);
for (i = 0; i < launch.num_resp_port; i++) {
launch.resp_port[i] = ntohs(ctx->launch_state->resp_port[i]);
}
	if (_launch_tasks(ctx, &launch) != SLURM_SUCCESS) {
		env_array_free(env);
		return SLURM_ERROR;
	}
	env_array_free(env);
	return SLURM_SUCCESS;
}
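/*
 * Typical client call sequence (a sketch; step request setup and all
 * error handling are omitted, and the slurm_step_ctx_create() argument
 * is assumed to be a filled-in job_step_create_request_msg_t):
 *
 *	slurm_step_ctx ctx = slurm_step_ctx_create(&step_req);
 *	slurm_job_step_launch_t params;
 *	slurm_job_step_launch_t_init(&params);
 *	...set params.argc, params.argv, etc...
 *	slurm_step_launch(ctx, &params);
 *	slurm_step_launch_wait_start(ctx);
 *	slurm_step_launch_wait_finish(ctx);
 *	slurm_step_ctx_destroy(ctx);
 */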
/*
* Block until all tasks have started.
*/
int slurm_step_launch_wait_start(slurm_step_ctx ctx)
{
struct step_launch_state *sls = ctx->launch_state;
	/* Wait for all tasks to start (or fail to start) */
pthread_mutex_lock(&sls->lock);
while ((sls->tasks_start_success + sls->tasks_start_failure)
< sls->tasks_requested) {
pthread_cond_wait(&sls->cond, &sls->lock);
}
pthread_mutex_unlock(&sls->lock);
return 1;
}
/*
* Block until all tasks have finished (or failed to start altogether).
*/
void slurm_step_launch_wait_finish(slurm_step_ctx ctx)
{
struct step_launch_state *sls = ctx->launch_state;
	/* Wait until all tasks have started and all started tasks exit */
pthread_mutex_lock(&sls->lock);
while (((sls->tasks_start_success + sls->tasks_start_failure)
< sls->tasks_requested)
|| (sls->tasks_exited < sls->tasks_start_success)) {
pthread_cond_wait(&sls->cond, &sls->lock);
}
/* Then shutdown the message handler thread */
eio_signal_shutdown(sls->msg_handle);
pthread_join(sls->msg_thread, NULL);
eio_handle_destroy(sls->msg_handle);
/* Then wait for the IO thread to finish */
client_io_handler_finish(sls->client_io);
client_io_handler_destroy(sls->client_io);
pthread_mutex_unlock(&sls->lock);
	/* FIXME - put these in an sls-specific destructor */
pthread_mutex_destroy(&sls->lock);
pthread_cond_destroy(&sls->cond);
xfree(sls->resp_port);
}
/**********************************************************************
* Message handler functions
**********************************************************************/
static void *_msg_thr_internal(void *arg)
{
struct step_launch_state *sls = (struct step_launch_state *)arg;
eio_handle_mainloop(sls->msg_handle);
return NULL;
}
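/*
 * Return the number of listening ports needed to serve "nclients"
 * clients at no more than "cli_per_port" clients per port (a ceiling
 * division).  For example, div(100, 48) gives quot = 2 and rem = 4,
 * so 100 nodes require 3 response ports.
 */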
static inline int
_estimate_nports(int nclients, int cli_per_port)
{
div_t d;
d = div(nclients, cli_per_port);
return d.rem > 0 ? d.quot + 1 : d.quot;
}
static int _msg_thr_create(struct step_launch_state *sls, int num_nodes)
{
int sock = -1;
int port = -1;
eio_obj_t *obj;
int i;
debug("Entering _msg_thr_create()");
slurm_uid = (uid_t) slurm_get_slurm_user_id();
sls->msg_handle = eio_handle_create();
	/* allow up to 48 compute nodes to share each response port */
	sls->num_resp_port = _estimate_nports(num_nodes, 48);
sls->resp_port = xmalloc(sizeof(uint16_t) * sls->num_resp_port);
for (i = 0; i < sls->num_resp_port; i++) {
if (net_stream_listen(&sock, &port) < 0) {
error("unable to intialize step launch listening socket: %m");
return SLURM_ERROR;
}
sls->resp_port[i] = port;
obj = eio_obj_create(sock, &message_socket_ops, (void *)sls);
eio_new_initial_obj(sls->msg_handle, obj);
}
if (pthread_create(&sls->msg_thread, NULL,
_msg_thr_internal, (void *)sls) != 0) {
error("pthread_create of message thread: %m");
return SLURM_ERROR;
}
return SLURM_SUCCESS;
}
static bool _message_socket_readable(eio_obj_t *obj)
{
debug3("Called _message_socket_readable");
if (obj->shutdown == true) {
if (obj->fd != -1) {
debug2(" false, shutdown");
close(obj->fd);
obj->fd = -1;
/*_wait_for_connections();*/
} else {
debug2(" false");
}
return false;
}
return true;
}
static int _message_socket_accept(eio_obj_t *obj, List objs)
{
struct step_launch_state *sls = (struct step_launch_state *)obj->arg;
int fd;
unsigned char *uc;
short port;
struct sockaddr_un addr;
slurm_msg_t *msg = NULL;
	socklen_t len = sizeof(addr);
int timeout = 0; /* slurm default value */
List ret_list = NULL;
debug3("Called _msg_socket_accept");
while ((fd = accept(obj->fd, (struct sockaddr *)&addr,
(socklen_t *)&len)) < 0) {
if (errno == EINTR)
continue;
if (errno == EAGAIN
|| errno == ECONNABORTED
|| errno == EWOULDBLOCK) {
return SLURM_SUCCESS;
}
error("Error on msg accept socket: %m");
obj->shutdown = true;
return SLURM_SUCCESS;
}
fd_set_close_on_exec(fd);
fd_set_blocking(fd);
/* Should not call slurm_get_addr() because the IP may not be
in /etc/hosts. */
uc = (unsigned char *)&((struct sockaddr_in *)&addr)->sin_addr.s_addr;
port = ((struct sockaddr_in *)&addr)->sin_port;
debug2("got message connection from %u.%u.%u.%u:%d",
uc[0], uc[1], uc[2], uc[3], ntohs(port));
fflush(stdout);
msg = xmalloc(sizeof(slurm_msg_t));
forward_init(&msg->forward, NULL);
msg->ret_list = NULL;
msg->conn_fd = fd;
msg->forward_struct_init = 0;
/* multiple jobs (easily induced via no_alloc) and highly
* parallel jobs using PMI sometimes result in slow message
* responses and timeouts. Raise the default timeout for srun. */
timeout = slurm_get_msg_timeout() * 8;
again:
ret_list = slurm_receive_msg(fd, msg, timeout);
	if (!ret_list || (errno != SLURM_SUCCESS)) {
		if (errno == EINTR) {
			if (ret_list)
				list_destroy(ret_list);
			goto again;
		}
error("slurm_receive_msg[%u.%u.%u.%u]: %m",
uc[0],uc[1],uc[2],uc[3]);
goto cleanup;
}
	if (list_count(ret_list) > 0) {
error("_message_socket_accept connection: "
"got %d from receive, expecting 0",
list_count(ret_list));
}
msg->ret_list = ret_list;
_handle_msg(sls, msg); /* handle_msg frees msg */
cleanup:
if ((msg->conn_fd >= 0) && slurm_close_accepted_conn(msg->conn_fd) < 0)
error ("close(%d): %m", msg->conn_fd);
slurm_free_msg(msg);
return SLURM_SUCCESS;
}
static void
_launch_handler(struct step_launch_state *sls, slurm_msg_t *resp)
{
launch_tasks_response_msg_t *msg = resp->data;
pthread_mutex_lock(&sls->lock);
if (msg->return_code == SLURM_SUCCESS)
sls->tasks_start_success += msg->count_of_pids;
else
sls->tasks_start_failure += msg->count_of_pids;
if (sls->task_start_callback != NULL)
(sls->task_start_callback)(msg);
pthread_cond_signal(&sls->cond);
pthread_mutex_unlock(&sls->lock);
}
static void
_exit_handler(struct step_launch_state *sls, slurm_msg_t *exit_msg)
{
task_exit_msg_t *msg = (task_exit_msg_t *) exit_msg->data;
pthread_mutex_lock(&sls->lock);
sls->tasks_exited += msg->num_tasks;
if (sls->task_finish_callback != NULL)
(sls->task_finish_callback)(msg);
pthread_cond_signal(&sls->cond);
pthread_mutex_unlock(&sls->lock);
}
static void
_node_fail_handler(struct step_launch_state *sls, slurm_msg_t *fail_msg)
{
/*srun_node_fail_msg_t *nf = fail_msg->data;*/
pthread_mutex_lock(&sls->lock);
/* does nothing yet */
pthread_cond_signal(&sls->cond);
pthread_mutex_unlock(&sls->lock);
slurm_send_rc_msg(fail_msg, SLURM_SUCCESS);
}
static void
_handle_msg(struct step_launch_state *sls, slurm_msg_t *msg)
{
uid_t req_uid = g_slurm_auth_get_uid(msg->auth_cred);
uid_t uid = getuid();
int rc;
if ((req_uid != slurm_uid) && (req_uid != 0) && (req_uid != uid)) {
error ("Security violation, slurm message from uid %u",
(unsigned int) req_uid);
return;
}
switch (msg->msg_type) {
case RESPONSE_LAUNCH_TASKS:
debug2("received task launch\n");
_launch_handler(sls, msg);
slurm_free_launch_tasks_response_msg(msg->data);
break;
case MESSAGE_TASK_EXIT:
debug2("received task exit\n");
_exit_handler(sls, msg);
slurm_free_task_exit_msg(msg->data);
break;
case SRUN_NODE_FAIL:
debug2("received srun node fail\n");
_node_fail_handler(sls, msg);
slurm_free_srun_node_fail_msg(msg->data);
break;
case PMI_KVS_PUT_REQ:
debug2("PMI_KVS_PUT_REQ received\n");
rc = pmi_kvs_put((struct kvs_comm_set *) msg->data);
slurm_send_rc_msg(msg, rc);
break;
case PMI_KVS_GET_REQ:
debug2("PMI_KVS_GET_REQ received\n");
rc = pmi_kvs_get((kvs_get_msg_t *) msg->data);
slurm_send_rc_msg(msg, rc);
slurm_free_get_kvs_msg((kvs_get_msg_t *) msg->data);
break;
	default:
		error("received spurious message type: %d",
		      msg->msg_type);
break;
}
return;
}
/**********************************************************************
* Task launch functions
**********************************************************************/
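/*
 * Send the REQUEST_LAUNCH_TASKS message to every node in the step.
 * The message body is packed once and fanned out through SLURM's
 * message forwarding tree (seeded by forward_set_launch() with the
 * address of every node in the step); per-node return codes are then
 * collected and checked.
 */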
static int _launch_tasks(slurm_step_ctx ctx,
launch_tasks_request_msg_t *launch_msg)
{
slurm_msg_t msg;
Buf buffer = NULL;
hostlist_t hostlist = NULL;
hostlist_iterator_t itr = NULL;
int zero = 0;
List ret_list = NULL;
ListIterator ret_itr;
ListIterator ret_data_itr;
ret_types_t *ret;
ret_data_info_t *ret_data;
int timeout;
debug("Entering _launch_tasks");
msg.msg_type = REQUEST_LAUNCH_TASKS;
msg.data = launch_msg;
buffer = slurm_pack_msg_no_header(&msg);
hostlist = hostlist_create(ctx->step_resp->step_layout->node_list);
itr = hostlist_iterator_create(hostlist);
msg.srun_node_id = 0;
msg.ret_list = NULL;
msg.orig_addr.sin_addr.s_addr = 0;
msg.buffer = buffer;
memcpy(&msg.address, &ctx->step_resp->step_layout->node_addr[0],
sizeof(slurm_addr));
timeout = slurm_get_msg_timeout();
forward_set_launch(&msg.forward,
ctx->step_resp->step_layout->node_cnt,
&zero,
ctx->step_resp->step_layout->node_cnt,
ctx->step_resp->step_layout->node_addr,
itr,
timeout);
hostlist_iterator_destroy(itr);
hostlist_destroy(hostlist);
	ret_list = slurm_send_recv_rc_packed_msg(&msg, timeout);
if (ret_list == NULL) {
error("slurm_send_recv_rc_packed_msg failed miserably: %m");
return SLURM_ERROR;
}
ret_itr = list_iterator_create(ret_list);
while ((ret = list_next(ret_itr)) != NULL) {
debug("launch returned msg_rc=%d err=%d type=%d",
ret->msg_rc, ret->err, ret->type);
if (ret->msg_rc != SLURM_SUCCESS) {
ret_data_itr =
list_iterator_create(ret->ret_data_list);
while ((ret_data = list_next(ret_data_itr)) != NULL) {
errno = ret->err;
error("Task launch failed on node %s(%d): %m",
ret_data->node_name, ret_data->nodeid);
}
list_iterator_destroy(ret_data_itr);
} else {
#if 0 /* only for debugging */
ret_data_itr =
list_iterator_create(ret->ret_data_list);
while ((ret_data = list_next(ret_data_itr)) != NULL) {
errno = ret->err;
info("Launch success on node %s(%d)",
ret_data->node_name, ret_data->nodeid);
}
list_iterator_destroy(ret_data_itr);
#endif
}
}
list_iterator_destroy(ret_itr);
list_destroy(ret_list);
return SLURM_SUCCESS;
}
static client_io_t *_setup_step_client_io(slurm_step_ctx ctx,
slurm_step_io_fds_t fds,
bool labelio)
{
int siglen;
char *sig;
client_io_t *client_io;
if (slurm_cred_get_signature(ctx->step_resp->cred, &sig, &siglen)
< 0) {
debug("_setup_step_client_io slurm_cred_get_signature failed");
return NULL;
}
client_io = client_io_handler_create(fds,
ctx->step_req->num_tasks,
ctx->step_req->node_count,
sig,
labelio);
/* no need to free sig, it is just a pointer into the credential */
return client_io;
}