diff --git a/src/privileged_extensions.c b/src/privileged_extensions.c index c778697..ba18ebd 100644 --- a/src/privileged_extensions.c +++ b/src/privileged_extensions.c @@ -67,94 +67,99 @@ static void run_custom_script(const char *filename, const char *extname, running_custom_script = false; } -void handle_create_extension( - void (*process_utility_hook)(PROCESS_UTILITY_PARAMS), - PROCESS_UTILITY_PARAMS, const char *privileged_extensions, - const char *superuser, - const char *privileged_extensions_custom_scripts_path, - const extension_parameter_overrides *epos, const size_t total_epos) { - CreateExtensionStmt *stmt = (CreateExtensionStmt *)pstmt->utilityStmt; - char *filename = (char *)palloc(MAXPGPATH); - - // Run global before-create script. - { - DefElem *d_schema = NULL; - DefElem *d_new_version = NULL; - DefElem *d_cascade = NULL; - char *extschema = NULL; - char *extversion = NULL; - bool extcascade = false; - ListCell *option_cell = NULL; - bool already_switched_to_superuser = false; - - foreach (option_cell, stmt->options) { - DefElem *defel = (DefElem *)lfirst(option_cell); - - if (strcmp(defel->defname, "schema") == 0) { - d_schema = defel; - extschema = defGetString(d_schema); - } else if (strcmp(defel->defname, "new_version") == 0) { - d_new_version = defel; - extversion = defGetString(d_new_version); - } else if (strcmp(defel->defname, "cascade") == 0) { - d_cascade = defel; - extcascade = defGetBoolean(d_cascade); - } +void run_global_before_create_script(char *extname, List *options, const char *privileged_extensions_custom_scripts_path){ + DefElem *d_schema = NULL, *d_new_version = NULL, *d_cascade = NULL; + char *extschema = NULL, *extversion = NULL; + bool extcascade = false; + char filename[MAXPGPATH]; + + ListCell *option_cell = NULL; + + foreach (option_cell, options) { + DefElem *defel = (DefElem *)lfirst(option_cell); + + if (strcmp(defel->defname, "schema") == 0) { + d_schema = defel; + extschema = defGetString(d_schema); + } else if (strcmp(defel->defname, "new_version") == 0) { + d_new_version = defel; + extversion = defGetString(d_new_version); + } else if (strcmp(defel->defname, "cascade") == 0) { + d_cascade = defel; + extcascade = defGetBoolean(d_cascade); } + } - switch_to_superuser(superuser, - &already_switched_to_superuser); - - snprintf(filename, MAXPGPATH, "%s/before-create.sql", - privileged_extensions_custom_scripts_path); - run_custom_script(filename, stmt->extname, extschema, extversion, - extcascade); + snprintf(filename, MAXPGPATH, "%s/before-create.sql", + privileged_extensions_custom_scripts_path); + run_custom_script(filename, extname, extschema, extversion, + extcascade); +} - if (!already_switched_to_superuser) { - switch_to_original_role(); +void run_ext_before_create_script(char *extname, List *options, const char *privileged_extensions_custom_scripts_path){ + DefElem *d_schema = NULL; + DefElem *d_new_version = NULL; + DefElem *d_cascade = NULL; + char *extschema = NULL; + char *extversion = NULL; + bool extcascade = false; + ListCell *option_cell = NULL; + char filename[MAXPGPATH]; + + foreach (option_cell, options) { + DefElem *defel = (DefElem *)lfirst(option_cell); + + if (strcmp(defel->defname, "schema") == 0) { + d_schema = defel; + extschema = defGetString(d_schema); + } else if (strcmp(defel->defname, "new_version") == 0) { + d_new_version = defel; + extversion = defGetString(d_new_version); + } else if (strcmp(defel->defname, "cascade") == 0) { + d_cascade = defel; + extcascade = defGetBoolean(d_cascade); } } - // Run per-extension before-create script. - { - DefElem *d_schema = NULL; - DefElem *d_new_version = NULL; - DefElem *d_cascade = NULL; - char *extschema = NULL; - char *extversion = NULL; - bool extcascade = false; - ListCell *option_cell = NULL; - bool already_switched_to_superuser = false; - - foreach (option_cell, stmt->options) { - DefElem *defel = (DefElem *)lfirst(option_cell); - - if (strcmp(defel->defname, "schema") == 0) { - d_schema = defel; - extschema = defGetString(d_schema); - } else if (strcmp(defel->defname, "new_version") == 0) { - d_new_version = defel; - extversion = defGetString(d_new_version); - } else if (strcmp(defel->defname, "cascade") == 0) { - d_cascade = defel; - extcascade = defGetBoolean(d_cascade); - } - } - - switch_to_superuser(superuser, - &already_switched_to_superuser); - snprintf(filename, MAXPGPATH, "%s/%s/before-create.sql", - privileged_extensions_custom_scripts_path, stmt->extname); - run_custom_script(filename, stmt->extname, extschema, extversion, - extcascade); + snprintf(filename, MAXPGPATH, "%s/%s/before-create.sql", + privileged_extensions_custom_scripts_path, extname); + run_custom_script(filename, extname, extschema, extversion, + extcascade); +} - if (!already_switched_to_superuser) { - switch_to_original_role(); +void run_ext_after_create_script(char *extname, List *options, const char *privileged_extensions_custom_scripts_path){ + DefElem *d_schema = NULL; + DefElem *d_new_version = NULL; + DefElem *d_cascade = NULL; + char *extschema = NULL; + char *extversion = NULL; + bool extcascade = false; + ListCell *option_cell = NULL; + char filename[MAXPGPATH]; + + foreach (option_cell, options) { + DefElem *defel = (DefElem *)lfirst(option_cell); + + if (strcmp(defel->defname, "schema") == 0) { + d_schema = defel; + extschema = defGetString(d_schema); + } else if (strcmp(defel->defname, "new_version") == 0) { + d_new_version = defel; + extversion = defGetString(d_new_version); + } else if (strcmp(defel->defname, "cascade") == 0) { + d_cascade = defel; + extcascade = defGetBoolean(d_cascade); } } - // Apply overrides. + snprintf(filename, MAXPGPATH, "%s/%s/after-create.sql", + privileged_extensions_custom_scripts_path, extname); + run_custom_script(filename, extname, extschema, extversion, + extcascade); +} + +void override_create_ext_statement(CreateExtensionStmt *stmt, const size_t total_epos, const extension_parameter_overrides *epos){ for (size_t i = 0; i < total_epos; i++) { if (strcmp(epos[i].name, stmt->extname) == 0) { const extension_parameter_overrides *epo = &epos[i]; @@ -189,63 +194,6 @@ void handle_create_extension( } } } - - // Run `CREATE EXTENSION`. - if (is_string_in_comma_delimited_string(stmt->extname, - privileged_extensions)) { - bool already_switched_to_superuser = false; - switch_to_superuser(superuser, - &already_switched_to_superuser); - - run_process_utility_hook(process_utility_hook); - - if (!already_switched_to_superuser) { - switch_to_original_role(); - } - } else { - run_process_utility_hook(process_utility_hook); - } - - // Run per-extension after-create script. - { - DefElem *d_schema = NULL; - DefElem *d_new_version = NULL; - DefElem *d_cascade = NULL; - char *extschema = NULL; - char *extversion = NULL; - bool extcascade = false; - ListCell *option_cell = NULL; - bool already_switched_to_superuser = false; - - foreach (option_cell, stmt->options) { - DefElem *defel = (DefElem *)lfirst(option_cell); - - if (strcmp(defel->defname, "schema") == 0) { - d_schema = defel; - extschema = defGetString(d_schema); - } else if (strcmp(defel->defname, "new_version") == 0) { - d_new_version = defel; - extversion = defGetString(d_new_version); - } else if (strcmp(defel->defname, "cascade") == 0) { - d_cascade = defel; - extcascade = defGetBoolean(d_cascade); - } - } - - switch_to_superuser(superuser, - &already_switched_to_superuser); - - snprintf(filename, MAXPGPATH, "%s/%s/after-create.sql", - privileged_extensions_custom_scripts_path, stmt->extname); - run_custom_script(filename, stmt->extname, extschema, extversion, - extcascade); - - if (!already_switched_to_superuser) { - switch_to_original_role(); - } - } - - pfree(filename); } bool all_extensions_are_privileged(List *objects, const char *privileged_extensions){ diff --git a/src/privileged_extensions.h b/src/privileged_extensions.h index 83bb0d5..681422d 100644 --- a/src/privileged_extensions.h +++ b/src/privileged_extensions.h @@ -4,15 +4,16 @@ #include "extensions_parameter_overrides.h" #include "utils.h" -extern void handle_create_extension( - void (*process_utility_hook)(PROCESS_UTILITY_PARAMS), - PROCESS_UTILITY_PARAMS, const char *privileged_extensions, - const char *superuser, - const char *privileged_extensions_custom_scripts_path, - const extension_parameter_overrides *epos, const size_t total_epos); - bool all_extensions_are_privileged(List *objects, const char *privileged_extensions); bool is_extension_privileged(const char *extname, const char *privileged_extensions); +void run_global_before_create_script(char *extname, List *options, const char *privileged_extensions_custom_scripts_path); + +void run_ext_before_create_script(char *extname, List *options, const char *privileged_extensions_custom_scripts_path); + +void run_ext_after_create_script(char *extname, List *options, const char *privileged_extensions_custom_scripts_path); + +void override_create_ext_statement(CreateExtensionStmt *stmt, const size_t total_epos, const extension_parameter_overrides *epos); + #endif diff --git a/src/supautils.c b/src/supautils.c index b04cc31..216668c 100644 --- a/src/supautils.c +++ b/src/supautils.c @@ -448,13 +448,29 @@ static void supautils_hook(PROCESS_UTILITY_PARAMS) { constrain_extension(stmt->extname, cexts, total_cexts); - handle_create_extension(prev_hook, - PROCESS_UTILITY_ARGS, - privileged_extensions, - supautils_superuser, - privileged_extensions_custom_scripts_path, - epos, total_epos); - return; + if (is_extension_privileged(stmt->extname, privileged_extensions)) { + bool already_switched_to_superuser = false; + + switch_to_superuser(supautils_superuser, &already_switched_to_superuser); + + run_global_before_create_script(stmt->extname, stmt->options, privileged_extensions_custom_scripts_path); + + run_ext_before_create_script(stmt->extname, stmt->options, privileged_extensions_custom_scripts_path); + + override_create_ext_statement(stmt, total_epos, epos); + + run_process_utility_hook(prev_hook); + + run_ext_after_create_script(stmt->extname, stmt->options, privileged_extensions_custom_scripts_path); + + if (!already_switched_to_superuser) { + switch_to_original_role(); + } + + return; + } + + break; } /*