diff --git a/src/common/slurm_auth.c b/src/common/slurm_auth.c index 36d7d67d862a5eb4ba274ff204a92056e5c80dbf..ff0b3a83f93b78fa350089f1cee643ac0069f009 100644 --- a/src/common/slurm_auth.c +++ b/src/common/slurm_auth.c @@ -93,8 +93,9 @@ static const char *syms[] = { * A global authentication context. "Global" in the sense that there's * only one, with static bindings. We don't export it. */ -static slurm_auth_ops_t ops; -static plugin_context_t *g_context = NULL; +static slurm_auth_ops_t *ops = NULL; +static plugin_context_t **g_context = NULL; +static int g_context_num = -1; static pthread_mutex_t context_lock = PTHREAD_MUTEX_INITIALIZER; static const char *slurm_auth_generic_errstr(int slurm_errno) @@ -131,12 +132,12 @@ extern int slurm_auth_init(char *auth_type) char *type = NULL; char *plugin_type = "auth"; - if (init_run && g_context) + if (init_run && (g_context_num > 0)) return retval; slurm_mutex_lock(&context_lock); - if (g_context) + if (g_context_num > 0) goto done; if (auth_type) @@ -144,14 +145,20 @@ extern int slurm_auth_init(char *auth_type) type = slurm_get_auth_type(); - g_context = plugin_context_create( - plugin_type, type, (void **)&ops, syms, sizeof(syms)); + g_context_num = 0; - if (!g_context) { + xrealloc(ops, sizeof(slurm_auth_ops_t) * (g_context_num + 1)); + xrealloc(g_context, sizeof(plugin_context_t) * (g_context_num + 1)); + + g_context[g_context_num] = plugin_context_create( + plugin_type, type, (void **)ops, syms, sizeof(syms)); + + if (!g_context[g_context_num]) { error("cannot create %s context for %s", plugin_type, type); retval = SLURM_ERROR; goto done; } + g_context_num++; init_run = true; done: @@ -163,14 +170,30 @@ done: /* Release all global memory associated with the plugin */ extern int slurm_auth_fini(void) { - int rc; + int i, rc = SLURM_SUCCESS, rc2; + slurm_mutex_lock(&context_lock); if (!g_context) - return SLURM_SUCCESS; + goto done; init_run = false; - rc = plugin_context_destroy(g_context); - g_context = NULL; + + for (i = 0; i < g_context_num; i++) { + rc2 = plugin_context_destroy(g_context[i]); + if (rc2) { + debug("%s: %s: %s", + __func__, g_context[i]->type, + slurm_strerror(rc2)); + rc = SLURM_ERROR; + } + } + + xfree(ops); + xfree(g_context); + g_context_num = -1; + +done: + slurm_mutex_unlock(&context_lock); return rc; } @@ -186,7 +209,7 @@ void *g_slurm_auth_create(char *auth_info) if (slurm_auth_init(NULL) < 0) return NULL; - return (*(ops.create))(auth_info); + return (*(ops[0].create))(auth_info); } int g_slurm_auth_destroy(void *cred) @@ -194,7 +217,7 @@ int g_slurm_auth_destroy(void *cred) if (slurm_auth_init(NULL) < 0) return SLURM_ERROR; - return (*(ops.destroy))(cred); + return (*(ops[0].destroy))(cred); } int g_slurm_auth_verify(void *cred, char *auth_info) @@ -202,7 +225,7 @@ int g_slurm_auth_verify(void *cred, char *auth_info) if (slurm_auth_init(NULL) < 0) return SLURM_ERROR; - return (*(ops.verify))(cred, auth_info); + return (*(ops[0].verify))(cred, auth_info); } uid_t g_slurm_auth_get_uid(void *cred, char *auth_info) @@ -210,7 +233,7 @@ uid_t g_slurm_auth_get_uid(void *cred, char *auth_info) if (slurm_auth_init(NULL) < 0) return SLURM_AUTH_NOBODY; - return (*(ops.get_uid))(cred, auth_info); + return (*(ops[0].get_uid))(cred, auth_info); } gid_t g_slurm_auth_get_gid(void *cred, char *auth_info) @@ -218,7 +241,7 @@ gid_t g_slurm_auth_get_gid(void *cred, char *auth_info) if (slurm_auth_init(NULL) < 0) return SLURM_AUTH_NOBODY; - return (*(ops.get_gid))(cred, auth_info); + return (*(ops[0].get_gid))(cred, auth_info); } char *g_slurm_auth_get_host(void *cred, char *auth_info) @@ -226,7 +249,7 @@ char *g_slurm_auth_get_host(void *cred, char *auth_info) if (slurm_auth_init(NULL) < 0) return NULL; - return (*(ops.get_host))(cred, auth_info); + return (*(ops[0].get_host))(cred, auth_info); } int g_slurm_auth_pack(void *cred, Buf buf, uint16_t protocol_version) @@ -235,10 +258,10 @@ int g_slurm_auth_pack(void *cred, Buf buf, uint16_t protocol_version) return SLURM_ERROR; if (protocol_version >= SLURM_19_05_PROTOCOL_VERSION) { - pack32(*ops.plugin_id, buf); - return (*(ops.pack))(cred, buf, protocol_version); + pack32(*ops[0].plugin_id, buf); + return (*(ops[0].pack))(cred, buf, protocol_version); } else if (protocol_version >= SLURM_MIN_PROTOCOL_VERSION) { - packstr(ops.plugin_type, buf); + packstr(ops[0].plugin_type, buf); /* * This next field was packed with plugin_version within each * individual auth plugin, but upon unpack was never checked @@ -246,7 +269,7 @@ int g_slurm_auth_pack(void *cred, Buf buf, uint16_t protocol_version) * symbol, just pack a zero here instead. */ pack32(0, buf); - return (*(ops.pack))(cred, buf, protocol_version); + return (*(ops[0].pack))(cred, buf, protocol_version); } else { error("%s: protocol_version %hu not supported", __func__, protocol_version); @@ -263,24 +286,24 @@ void *g_slurm_auth_unpack(Buf buf, uint16_t protocol_version) if (protocol_version >= SLURM_19_05_PROTOCOL_VERSION) { safe_unpack32(&plugin_id, buf); - if (plugin_id != *(ops.plugin_id)) { + if (plugin_id != *(ops[0].plugin_id)) { error("%s: remote plugin_id %u != %u", - __func__, plugin_id, *(ops.plugin_id)); + __func__, plugin_id, *(ops[0].plugin_id)); return NULL; } - return (*(ops.unpack))(buf, protocol_version); + return (*(ops[0].unpack))(buf, protocol_version); } else if (protocol_version >= SLURM_MIN_PROTOCOL_VERSION) { char *plugin_type; uint32_t uint32_tmp, version; safe_unpackmem_ptr(&plugin_type, &uint32_tmp, buf); - if (xstrcmp(plugin_type, ops.plugin_type)) { + if (xstrcmp(plugin_type, ops[0].plugin_type)) { error("%s: remote plugin_type `%s` != `%s`", - __func__, plugin_type, ops.plugin_type); + __func__, plugin_type, ops[0].plugin_type); return NULL; } safe_unpack32(&version, buf); - return (*(ops.unpack))(buf, protocol_version); + return (*(ops[0].unpack))(buf, protocol_version); } else { error("%s: protocol_version %hu not supported", __func__, protocol_version); @@ -296,7 +319,7 @@ int g_slurm_auth_print(void *cred, FILE *fp) if (slurm_auth_init(NULL) < 0) return SLURM_ERROR; - return (*(ops.print))(cred, fp); + return (*(ops[0].print))(cred, fp); } int g_slurm_auth_errno(void *cred) @@ -304,7 +327,7 @@ int g_slurm_auth_errno(void *cred) if (slurm_auth_init(NULL) < 0) return SLURM_ERROR; - return (*(ops.sa_errno))(cred); + return (*(ops[0].sa_errno))(cred); } const char *g_slurm_auth_errstr(int slurm_errno) @@ -318,5 +341,5 @@ const char *g_slurm_auth_errstr(int slurm_errno) if ((generic = slurm_auth_generic_errstr(slurm_errno))) return generic; - return (*(ops.sa_errstr))(slurm_errno); + return (*(ops[0].sa_errstr))(slurm_errno); }