diff --git a/src/common/slurm_protocol_defs.c b/src/common/slurm_protocol_defs.c index 86bca47a219ce9bb5c1f6f9277c04d227d174b5c..f1a5b1673900f00b30971001aa34d6b081ac38ac 100644 --- a/src/common/slurm_protocol_defs.c +++ b/src/common/slurm_protocol_defs.c @@ -326,19 +326,24 @@ void slurm_free_launch_tasks_request_msg ( launch_tasks_request_msg_t * msg ) int i ; if ( msg ) { - if ( msg -> credentials ) - xfree ( msg -> credentials ); + if ( msg -> credential ) + xfree ( msg -> credential ); if ( msg -> env ) for ( i = 0 ; i < msg -> envc ; i++ ) { if ( msg -> env[i] ) xfree ( msg -> env[i] ); } - xfree ( msg -> env ) ; + xfree ( msg -> env ) ; if ( msg -> cwd ) xfree ( msg -> cwd ); - if ( msg -> cmd_line ) - xfree ( msg -> cmd_line ); + if ( msg -> argv ) + for ( i = 0 ; i < msg -> argc ; i++ ) + { + if ( msg -> argv[i] ) + xfree ( msg -> argv[i] ); + } + xfree ( msg -> argv ) ; if ( msg -> global_task_ids ) xfree ( msg -> global_task_ids ); xfree ( msg ) ; @@ -349,8 +354,8 @@ void slurm_free_reattach_tasks_streams_msg ( reattach_tasks_streams_msg_t * msg { if ( msg ) { - if ( msg -> credentials ) - xfree ( msg -> credentials ); + if ( msg -> credential ) + xfree ( msg -> credential ); if ( msg -> global_task_ids ) xfree ( msg -> global_task_ids ); xfree ( msg ) ; diff --git a/src/common/slurm_protocol_defs.h b/src/common/slurm_protocol_defs.h index 0dc5663892c2ecdd9ef7bdb5a638343d007df2f8..8dede346260ff433e641275ca68bce67b0d1da3e 100644 --- a/src/common/slurm_protocol_defs.h +++ b/src/common/slurm_protocol_defs.h @@ -168,13 +168,13 @@ typedef struct slurm_protocol_header uint32_t body_length ; } header_t ; -typedef struct slurm_stream_io_header +typedef struct slurm_io_stream_header { uint16_t version ; /*version/magic number*/ char key[16] ; uint32_t task_id ; uint16_t type ; -} slurm_stream_io_header ; +} slurm_io_stream_header_t ; /* Job credential */ typedef struct slurm_job_credential @@ -256,12 +256,13 @@ typedef struct launch_tasks_request_msg uint32_t job_id ; uint32_t job_step_id ; uint32_t uid ; - slurm_job_credential_t* credentials; + slurm_job_credential_t* credential; uint32_t tasks_to_launch ; uint16_t envc ; char ** env ; char * cwd ; - char * cmd_line ; + uint16_t argc; + char ** argv; slurm_addr response_addr ; slurm_addr streams; uint32_t * global_task_ids; @@ -278,7 +279,7 @@ typedef struct reattach_tasks_streams_msg uint32_t job_id ; uint32_t job_step_id ; uint32_t uid ; - slurm_job_credential_t* credentials; + slurm_job_credential_t* credential; uint32_t tasks_to_reattach ; slurm_addr streams; uint32_t * global_task_ids; diff --git a/src/common/slurm_protocol_pack.c b/src/common/slurm_protocol_pack.c index 212fe58829e2cd71731c2b78bcb9191f4af55297..ada4e1bc41f9bd870916e04e50dac225bdbce43f 100644 --- a/src/common/slurm_protocol_pack.c +++ b/src/common/slurm_protocol_pack.c @@ -66,7 +66,7 @@ void unpack_header ( header_t * header , char ** buffer , uint32_t * length ) unpack32 ( & header -> body_length , ( void ** ) buffer , length ) ; } -void pack_stream_io_header ( slurm_stream_io_header * msg , void ** buffer , uint32_t * length ) +void pack_io_stream_header ( slurm_io_stream_header_t * msg , void ** buffer , uint32_t * length ) { assert ( msg != NULL ); @@ -76,7 +76,7 @@ void pack_stream_io_header ( slurm_stream_io_header * msg , void ** buffer , uin pack16( msg->type, buffer, length ) ; } -void unpack_stream_io_header ( slurm_stream_io_header * msg , void ** buffer , uint32_t * length ) +void unpack_io_stream_header ( slurm_io_stream_header_t * msg , void ** buffer , uint32_t * length ) { uint16_t uint16_tmp; @@ -1096,7 +1096,7 @@ void pack_reattach_tasks_streams_msg ( reattach_tasks_streams_msg_t * msg , void pack32 ( msg -> job_id , buffer , length ) ; pack32 ( msg -> job_step_id , buffer , length ) ; pack32 ( msg -> uid , buffer , length ) ; - pack_job_credential ( msg -> credentials , buffer , length ) ; + pack_job_credential ( msg -> credential , buffer , length ) ; pack32 ( msg -> tasks_to_reattach , buffer , length ) ; slurm_pack_slurm_addr ( & msg -> streams , buffer , length ) ; pack32_array ( msg -> global_task_ids , ( uint16_t ) msg -> tasks_to_reattach , buffer , length ) ; @@ -1117,7 +1117,7 @@ int unpack_reattach_tasks_streams_msg ( reattach_tasks_streams_msg_t ** msg_ptr unpack32 ( & msg -> job_id , buffer , length ) ; unpack32 ( & msg -> job_step_id , buffer , length ) ; unpack32 ( & msg -> uid , buffer , length ) ; - unpack_job_credential( & msg -> credentials , buffer , length ) ; + unpack_job_credential( & msg -> credential , buffer , length ) ; unpack32 ( & msg -> tasks_to_reattach , buffer , length ) ; slurm_unpack_slurm_addr_no_alloc ( & msg -> streams , buffer , length ) ; unpack32_array ( & msg -> global_task_ids , & uint16_tmp , buffer , length ) ; @@ -1154,11 +1154,11 @@ void pack_launch_tasks_request_msg ( launch_tasks_request_msg_t * msg , void ** pack32 ( msg -> job_id , buffer , length ) ; pack32 ( msg -> job_step_id , buffer , length ) ; pack32 ( msg -> uid , buffer , length ) ; - pack_job_credential ( msg -> credentials , buffer , length ) ; + pack_job_credential ( msg -> credential , buffer , length ) ; pack32 ( msg -> tasks_to_launch , buffer , length ) ; packstring_array ( msg -> env , msg -> envc , buffer , length ) ; packstr ( msg -> cwd , buffer , length ) ; - packstr ( msg -> cmd_line , buffer , length ) ; + packstring_array ( msg -> argv , msg -> argc , buffer , length ) ; slurm_pack_slurm_addr ( & msg -> response_addr , buffer , length ) ; slurm_pack_slurm_addr ( & msg -> streams , buffer , length ) ; pack32_array ( msg -> global_task_ids , ( uint16_t ) msg -> tasks_to_launch , buffer , length ) ; @@ -1179,11 +1179,11 @@ int unpack_launch_tasks_request_msg ( launch_tasks_request_msg_t ** msg_ptr , vo unpack32 ( & msg -> job_id , buffer , length ) ; unpack32 ( & msg -> job_step_id , buffer , length ) ; unpack32 ( & msg -> uid , buffer , length ) ; - unpack_job_credential( & msg -> credentials , buffer , length ) ; + unpack_job_credential( & msg -> credential , buffer , length ) ; unpack32 ( & msg -> tasks_to_launch , buffer , length ) ; unpackstring_array ( & msg -> env , & msg -> envc , buffer , length ) ; unpackstr_xmalloc ( & msg -> cwd , & uint16_tmp , buffer , length ) ; - unpackstr_xmalloc ( & msg -> cmd_line , & uint16_tmp , buffer , length ) ; + unpackstring_array ( & msg -> argv , & msg->argc , buffer , length ) ; slurm_unpack_slurm_addr_no_alloc ( & msg -> response_addr , buffer , length ) ; slurm_unpack_slurm_addr_no_alloc ( & msg -> streams , buffer , length ) ; unpack32_array ( & msg -> global_task_ids , & uint16_tmp , buffer , length ) ; diff --git a/src/common/slurm_protocol_pack.h b/src/common/slurm_protocol_pack.h index 0ef94d6540951ef3df8f0bc16ba01d4a0606a049..026ea97b5c1960f3f083fe96219fb600fbdd1f22 100644 --- a/src/common/slurm_protocol_pack.h +++ b/src/common/slurm_protocol_pack.h @@ -21,8 +21,8 @@ void pack_header ( header_t * header , char ** buffer , uint32_t * length ); void unpack_header ( header_t * header , char ** buffer , uint32_t * length ); /* Pack / Unpack methods for slurm io pipe streams header */ -void pack_stream_io_header ( slurm_stream_io_header * msg , void ** buffer , uint32_t * length ) ; -void unpack_stream_io_header ( slurm_stream_io_header * msg , void ** buffer , uint32_t * length ) ; +void pack_io_stream_header ( slurm_io_stream_header_t * msg , void ** buffer , uint32_t * length ) ; +void unpack_io_stream_header ( slurm_io_stream_header_t * msg , void ** buffer , uint32_t * length ) ; /* generic case statement Pack / Unpack methods for slurm protocol bodies */ int pack_msg ( slurm_msg_t const * msg , char ** buffer , uint32_t * buf_len ); diff --git a/src/common/slurm_protocol_util.c b/src/common/slurm_protocol_util.c index c2b27ff3d029063503ff686b93360e17805a0a4e..6a3e3a0cd80d219f90271cf4628b13ae789aaf6d 100644 --- a/src/common/slurm_protocol_util.c +++ b/src/common/slurm_protocol_util.c @@ -1,3 +1,6 @@ +#include <stdlib.h> +#include <assert.h> + #include <src/common/slurm_protocol_defs.h> #include <src/common/slurm_protocol_common.h> #include <src/common/slurm_protocol_util.h> @@ -22,4 +25,24 @@ void init_header ( header_t * header , slurm_msg_type_t msg_type , uint16_t flag header -> msg_type = msg_type ; } +/* checks to see that the specified header was sent from a node running the same version of the protocol as the current node */ +uint32_t check_io_stream_header_version( slurm_io_stream_header_t * header) +{ + if ( header -> version != SLURM_PROTOCOL_VERSION ) + { + info ( "Invalid Protocol Version %d ", header -> version ) ; + return SLURM_PROTOCOL_VERSION_ERROR ; + } + return SLURM_PROTOCOL_SUCCESS ; +} +/* simple function to create a header, always insuring that an accurate version string is inserted */ +void init_io_stream_header ( slurm_io_stream_header_t * header , char * key , uint32_t task_id , uint16_t type ) +{ + + assert ( key != NULL ); + header -> version = SLURM_PROTOCOL_VERSION ; + memcpy ( header -> key , key , SLURM_SSL_SIGNATURE_LENGTH ) ; + header -> task_id = task_id ; + header -> type = type ; +} diff --git a/src/common/slurm_protocol_util.h b/src/common/slurm_protocol_util.h index 2024cbc64e4f02327597975f5df366a1dbc643c4..c05024cfc9611ccca4c9e6eead885db06a458b77 100644 --- a/src/common/slurm_protocol_util.h +++ b/src/common/slurm_protocol_util.h @@ -16,7 +16,13 @@ #include <src/common/slurm_protocol_defs.h> #include <src/common/slurm_protocol_common.h> +#define SLURM_SSL_SIGNATURE_LENGTH 16 +#define SLURM_IO_STREAM_INOUT 0 +#define SLURM_IO_STREAM_SIGERR 1 uint32_t check_header_version( header_t * header) ; void init_header ( header_t * header , slurm_msg_type_t msg_type , uint16_t flags ) ; + +uint32_t check_io_stream_header_version( slurm_io_stream_header_t * header) ; +void init_io_stream_header ( slurm_io_stream_header_t * header , char * key , uint32_t task_id , uint16_t type ) ; #endif diff --git a/src/slurmd/task_mgr.c b/src/slurmd/task_mgr.c index c72206f0a01b1f1de7f05dc7a5bc70258c5d08ef..4c211f0662f1f744312fad948cb1b86003fa32df 100644 --- a/src/slurmd/task_mgr.c +++ b/src/slurmd/task_mgr.c @@ -115,12 +115,12 @@ int fan_out_task_launch ( launch_tasks_request_msg_t * launch_msg ) { curr_task = alloc_task ( shmem_ptr , curr_job_step ); task_start[i] = & curr_task -> task_start ; + curr_task -> task_id = launch_msg -> global_task_ids[i] ; /* fill in task_start struct */ task_start[i] -> launch_msg = launch_msg ; task_start[i] -> local_task_id = i ; - task_start[i] -> inout_dest = launch_msg -> streams ; - task_start[i] -> err_dest = launch_msg -> streams ; + task_start[i] -> io_streams_dest = launch_msg -> streams ; if ( pthread_create ( & task_start[i]->pthread_id , NULL , task_exec_thread , ( void * ) task_start[i] ) ) goto kill_threads; @@ -149,6 +149,7 @@ int forward_io ( task_start_t * task_arg ) { pthread_attr_t pthread_attr ; int local_errno; + slurm_io_stream_header_t io_header ; #define STDIN_OUT_SOCK 0 #define SIG_STDERR_SOCK 1 @@ -156,19 +157,41 @@ int forward_io ( task_start_t * task_arg ) posix_signal_pipe_ignore ( ) ; /* open stdout & stderr sockets */ - if ( ( task_arg->sockets[STDIN_OUT_SOCK] = slurm_open_stream ( & ( task_arg -> inout_dest ) ) ) == SLURM_PROTOCOL_ERROR ) + if ( ( task_arg->sockets[STDIN_OUT_SOCK] = slurm_open_stream ( & ( task_arg -> io_streams_dest ) ) ) == SLURM_PROTOCOL_ERROR ) { local_errno = errno ; info ( "error opening socket to srun to pipe stdout errno %i" , local_errno ) ; // pthread_exit ( 0 ) ; } + else + { + char buffer[sizeof(slurm_io_stream_header_t)] ; + char * buf_ptr = buffer ; + int buf_size = sizeof(slurm_io_stream_header_t) ; + int size = sizeof(slurm_io_stream_header_t) ; + + init_io_stream_header ( & io_header , task_arg -> launch_msg -> credential -> signature , task_arg -> launch_msg -> global_task_ids[task_arg -> local_task_id ] , SLURM_IO_STREAM_INOUT ) ; + pack_io_stream_header ( & io_header , & buf_ptr , & size ) ; + slurm_write_stream ( task_arg->sockets[STDIN_OUT_SOCK] , buffer , buf_size - size ) ; + } - if ( ( task_arg->sockets[SIG_STDERR_SOCK] = slurm_open_stream ( &( task_arg -> err_dest ) ) ) == SLURM_PROTOCOL_ERROR ) + if ( ( task_arg->sockets[SIG_STDERR_SOCK] = slurm_open_stream ( &( task_arg -> io_streams_dest ) ) ) == SLURM_PROTOCOL_ERROR ) { local_errno = errno ; info ( "error opening socket to srun to pipe stdout errno %i" , local_errno ) ; // pthread_exit ( 0 ) ; } + else + { + char buffer[sizeof(slurm_io_stream_header_t)] ; + char * buf_ptr = buffer ; + int buf_size = sizeof(slurm_io_stream_header_t) ; + int size = sizeof(slurm_io_stream_header_t) ; + + init_io_stream_header ( & io_header , task_arg -> launch_msg -> credential -> signature , task_arg -> launch_msg -> global_task_ids[task_arg -> local_task_id ] , SLURM_IO_STREAM_SIGERR ) ; + pack_io_stream_header ( & io_header , & buf_ptr , & size ) ; + slurm_write_stream ( task_arg->sockets[SIG_STDERR_SOCK] , buffer , buf_size - size ) ; + } /* spawn io pipe threads */ pthread_attr_init( & pthread_attr ) ; @@ -316,7 +339,7 @@ void * stdout_io_pipe_thread ( void * arg ) if ( difftime ( curr_time , last_reconnect_try ) > RECONNECT_RETRY_TIME ) { slurm_close_stream ( io_arg->sockets[STDIN_OUT_SOCK] ) ; - if ( ( io_arg->sockets[STDIN_OUT_SOCK] = slurm_open_stream ( & ( io_arg -> inout_dest ) ) ) == SLURM_PROTOCOL_ERROR ) + if ( ( io_arg->sockets[STDIN_OUT_SOCK] = slurm_open_stream ( & ( io_arg -> io_streams_dest ) ) ) == SLURM_PROTOCOL_ERROR ) { local_errno = errno ; info ( "error reconnecting socket to srun to pipe stdout errno %i" , local_errno ) ; @@ -404,7 +427,7 @@ void * stderr_io_pipe_thread ( void * arg ) if ( difftime ( curr_time , last_reconnect_try ) > RECONNECT_RETRY_TIME ) { slurm_close_stream ( io_arg->sockets[SIG_STDERR_SOCK] ) ; - if ( ( io_arg->sockets[SIG_STDERR_SOCK] = slurm_open_stream ( &( io_arg -> err_dest ) ) ) == SLURM_PROTOCOL_ERROR ) + if ( ( io_arg->sockets[SIG_STDERR_SOCK] = slurm_open_stream ( &( io_arg -> io_streams_dest ) ) ) == SLURM_PROTOCOL_ERROR ) { local_errno = errno ; info ( "error reconnecting socket to srun to pipe stderr errno %i" , local_errno ) ; @@ -486,17 +509,18 @@ void * task_exec_thread ( void * arg ) } /* setuid and gid*/ - if ( ( rc = setuid ( launch_msg->uid ) ) == SLURM_ERROR ) + if ( ( rc = setgid ( pwd -> pw_gid ) ) == SLURM_ERROR ) { - info ( "set user id failed " ) ; + info ( "set group id failed " ) ; _exit ( SLURM_FAILURE ) ; } - - if ( ( rc = setgid ( pwd -> pw_gid ) ) == SLURM_ERROR ) + + if ( ( rc = setuid ( launch_msg->uid ) ) == SLURM_ERROR ) { - info ( "set group id failed " ) ; + info ( "set user id failed " ) ; _exit ( SLURM_FAILURE ) ; } + /* initgroups */ /*if ( ( rc = initgroups ( pwd ->pw_name , pwd -> pw_gid ) ) == SLURM_ERROR ) { @@ -508,10 +532,9 @@ void * task_exec_thread ( void * arg ) /* run bash and cmdline */ debug( "cwd %s", launch_msg->cwd ) ; chdir ( launch_msg->cwd ) ; - debug( "cmdline %s", launch_msg->cmd_line ) ; - execl ("/bin/bash", "bash", "-c", launch_msg->cmd_line, 0); + //execl ("/bin/bash", "bash", "-c", launch_msg->cmd_line, 0); - //execle ( "/bin/sh", launch_msg->cmd_line , launch_msg->env ); + execve ( launch_msg->argv[0], launch_msg->argv , launch_msg->env ); close ( STDIN_FILENO ); close ( STDOUT_FILENO ); close ( STDERR_FILENO ); @@ -626,8 +649,7 @@ int reattach_tasks_streams ( reattach_tasks_streams_msg_t * req_msg ) task_t * task = find_task ( job_step_ptr , req_msg->global_task_ids[i] ) ; if ( task != NULL ) { - task -> task_start . inout_dest = req_msg -> streams ; - task -> task_start . err_dest = req_msg -> streams ; + task -> task_start . io_streams_dest = req_msg -> streams ; } else { diff --git a/src/slurmd/task_mgr.h b/src/slurmd/task_mgr.h index 26d764bba1c0bf0805c8fd213ac3cb4713baccf8..c3906b5b478c5fa2a1bf822590557e2b55aae7e5 100644 --- a/src/slurmd/task_mgr.h +++ b/src/slurmd/task_mgr.h @@ -42,7 +42,6 @@ typedef struct task_start int sockets[2]; int local_task_id; char addr_update; - slurm_addr inout_dest; - slurm_addr err_dest; + slurm_addr io_streams_dest; } task_start_t ; #endif diff --git a/testsuite/slurm_unit/slurmd/task_launch-test.c b/testsuite/slurm_unit/slurmd/task_launch-test.c index f0b53fd01d5215fe78213d43c4786aa43d0e2d40..67a79035c62116ee91c129359153bd053ff6725d 100644 --- a/testsuite/slurm_unit/slurmd/task_launch-test.c +++ b/testsuite/slurm_unit/slurmd/task_launch-test.c @@ -4,33 +4,40 @@ int main ( int argc , char* argv[] ) { slurm_msg_t request_msg ; slurm_msg_t response_msg ; - launch_tasks_request_msg_t launch_tasks_msg ; - slurm_addr io_pipe_addrs[2] ; - slurm_addr slurmd_addr ; - int gids[1] ; slurm_job_credential_t credential ; + launch_tasks_request_msg_t launch_tasks_msg ; + slurm_addr io_pipe_addrs ; + slurm_addr slurmd_addr ; + int gids[1] ; + char arg0[] = "./testme" ; + char arg1[] = "" ; + char * args[] = { arg0 , arg1 } ; - credential . node_list = "TESTING" ; gids[1] = 9999 ; + + credential . node_list = "TESTING" ; slurm_set_addr_char ( & slurmd_addr , 7002 , "localhost" ) ; + request_msg . msg_type = REQUEST_LAUNCH_TASKS ; request_msg . data = & launch_tasks_msg ; request_msg . address = slurmd_addr ; + + slurm_set_addr_char ( & io_pipe_addrs , 7071 , "localhost" ) ; - slurm_set_addr_char ( io_pipe_addrs , 7071 , "localhost" ) ; - slurm_set_addr_char ( io_pipe_addrs + 1 , 7072 , "localhost" ) ; + //kill_tasks_msg_t kill_tasks_msg ; launch_tasks_msg . job_id = 1000 ; - launch_tasks_msg . job_step_id = 2000 ; + launch_tasks_msg . job_step_id = 2000 ; launch_tasks_msg . uid = 8207 ; - launch_tasks_msg . credentials = & credential ; + launch_tasks_msg . credential = & credential ; launch_tasks_msg . tasks_to_launch = 1 ; launch_tasks_msg . envc = 0 ; launch_tasks_msg . env = NULL ; launch_tasks_msg . cwd = "." ; - launch_tasks_msg . cmd_line = "./testme" ; - launch_tasks_msg . streams = io_pipe_addrs ; + launch_tasks_msg . argc = 2 ; + launch_tasks_msg . argv = args ; + launch_tasks_msg . streams = io_pipe_addrs ; launch_tasks_msg . global_task_ids = gids ; - + slurm_send_only_node_msg ( & request_msg ) ; switch ( response_msg . msg_type ) diff --git a/testsuite/slurm_unit/slurmd/task_mgr-test.c b/testsuite/slurm_unit/slurmd/task_mgr-test.c index e658a5526dc1c0206a27368c574f860f69ca154f..2f722b502dbb956543ba0c71a6f6e98da77a9c72 100644 --- a/testsuite/slurm_unit/slurmd/task_mgr-test.c +++ b/testsuite/slurm_unit/slurmd/task_mgr-test.c @@ -6,22 +6,26 @@ int main ( int argc , char ** argv ) { launch_tasks_request_msg_t launch_tasks_msg ; - slurm_addr io_pipe_addrs[2] ; + slurm_addr io_pipe_addrs ; int gids[1] ; + char arg0[] = "./testme" ; + char arg1[] = "" ; + char * args[] = { arg0 , arg1 } ; + gids[1] = 9999 ; - slurm_set_addr_char ( io_pipe_addrs , 7071 , "localhost" ) ; - slurm_set_addr_char ( io_pipe_addrs + 1 , 7072 , "localhost" ) ; + slurm_set_addr_char ( & io_pipe_addrs , 7071 , "localhost" ) ; //kill_tasks_msg_t kill_tasks_msg ; launch_tasks_msg . job_id = 1000 ; launch_tasks_msg . job_step_id = 2000 ; launch_tasks_msg . uid = 8207 ; - launch_tasks_msg . credentials = NULL ; + launch_tasks_msg . credential = NULL ; launch_tasks_msg . tasks_to_launch = 1 ; launch_tasks_msg . envc = 0 ; launch_tasks_msg . env = NULL ; launch_tasks_msg . cwd = "." ; - launch_tasks_msg . cmd_line = "./testme" ; + launch_tasks_msg . argc = 2 ; + launch_tasks_msg . argv = args ; launch_tasks_msg . streams = io_pipe_addrs ; launch_tasks_msg . global_task_ids = gids ;