diff --git a/Haruspex.java b/Haruspex.java index 2fc03dc..25a9da6 100644 --- a/Haruspex.java +++ b/Haruspex.java @@ -41,6 +41,10 @@ import java.util.ArrayList; import java.io.FileWriter; import java.io.PrintWriter; +import java.io.File; +import java.io.BufferedReader; +import java.io.FileReader; +import java.io.IOException; import ghidra.app.script.GhidraScript; import ghidra.program.model.symbol.*; @@ -48,92 +52,181 @@ import ghidra.app.decompiler.DecompInterface; import ghidra.app.decompiler.DecompileOptions; import ghidra.app.decompiler.DecompileResults; +import ghidra.util.exception.CancelledException; -public class Haruspex extends GhidraScript -{ - List functions; - DecompileOptions options; - DecompInterface decomp; - String outputPath = "/tmp/haruspex.out"; - static int TIMEOUT = 60; - - @Override - public void run() throws Exception - { - printf("\nHaruspex.java - Extract Ghidra decompiler's pseudo-code\n"); - printf("Copyright (c) 2022 Marco Ivaldi \n\n"); - - // ask for output directory path - try { - outputPath = askString("Output directory path", "Enter the path of the output directory:"); - } catch (Exception e) { - printf("Output directory not supplied, using default \"%s\".\n", outputPath); - } - - // get all functions - functions = new ArrayList<>(); - getAllFunctions(); - - // extract pseudo-code of all functions (using default options) - decomp = new DecompInterface(); - options = new DecompileOptions(); - decomp.setOptions(options); - decomp.toggleCCode(true); - decomp.toggleSyntaxTree(true); - decomp.setSimplificationStyle("decompile"); - if (!decomp.openProgram(currentProgram)) { - printf("Could not initialize the decompiler, exiting.\n\n"); - return; - } - printf("Extracting pseudo-code from %d functions...\n\n", functions.size()); - functions.forEach(f -> extractPseudoCode(f)); - } - - // collect all Function objects into a global ArrayList - public void getAllFunctions() - { - SymbolTable st = currentProgram.getSymbolTable(); - SymbolIterator si = st.getSymbolIterator(); - - while (si.hasNext()) { - Symbol s = si.next(); - if ( (s.getSymbolType() == SymbolType.FUNCTION) && (!s.isExternal()) ) { - Function fun = getFunctionAt(s.getAddress()); - if (!fun.isThunk()) { - functions.add(fun); - } - } - } - } - - // extract the pseudo-code of a function - // @param func target function - public void extractPseudoCode(Function func) - { - DecompileResults res = decomp.decompileFunction(func, TIMEOUT, monitor); - if(res.getDecompiledFunction() != null){ - saveToFile(outputPath, func.getName() + "@" + func.getEntryPoint() + ".c", res.getDecompiledFunction().getC()); - } - else{ - printf("Can't decompile %s\n\n", func.getName()); - } - } - - // save results to file - // @param path name of the output directory - // @param name name of the output file - // @param output content to save to file - public void saveToFile(String path, String name, String output) - { - try { - FileWriter fw = new FileWriter(path + "/" + name); - PrintWriter pw = new PrintWriter(fw); - pw.write(output); - pw.close(); - - } catch (Exception e) { - printf("Cannot write to output file \"%s\".\n\n", path + "/" + name); - return; - } - } -} +public class Haruspex extends GhidraScript { + List < Function > functions; + DecompileOptions options; + DecompInterface decomp; + String outputPath; + String functionsFilePath; + static int TIMEOUT = 60; + static int BATCHSIZE = 25; // Specify your batch size + + @Override + public void run() throws Exception { + printf("\nHaruspex.java - Extract Ghidra decompiler's pseudo-code\n"); + printf("Copyright (c) 2022 Marco Ivaldi \n\n"); + + String[] inputPaths = new String[2]; + try { + inputPaths[0] = askString("Output directory path", "Enter the path of the output directory:"); + + // Allow empty path + try { + inputPaths[1] = askString("Function Names File Path", "Enter the path of the function names file:"); + } catch (CancelledException e) { + inputPaths[1] = null; + } + + } catch (CancelledException e) { + // Handle cancellation + printf("Script execution canceled by the user.\n"); + return; + } + + outputPath = inputPaths[0]; + functionsFilePath = inputPaths[1]; + + List < String > functionNames; + functions = new ArrayList < > (); + + // Check if function file path is provided + if (functionsFilePath != null && !functionsFilePath.isEmpty()) { + try { + functionNames = readFunctionNames(functionsFilePath); + } catch (Exception e) { + printf("Error reading function names file: %s\n", e.getMessage()); + return; + } + getFunctionsByName(functionNames); + } else { + // If no function file path is provided, extract all functions + getAllFunctions(); + } + + // Create the output directory if it doesn't exist + File outputDir = new File(outputPath); + if (!outputDir.exists()) { + try { + outputDir.mkdirs(); + } catch (SecurityException e) { + printf("Error creating output directory: %s\n", e.getMessage()); + return; + } + } + + printf("Reading function names from: \"%s\".\n", functionsFilePath); + printf("Output directory: \"%s\".\n", outputPath); + + // extract pseudo-code of all functions (using default options) + extractPseudoCode(); + } + + // Reads function names from a file and returns them as a list + // @param filePath The path to the file containing function names + private List < String > readFunctionNames(String filePath) { + List < String > names = new ArrayList < > (); + try (BufferedReader reader = new BufferedReader(new FileReader(new File(filePath)))) { + String line; + while ((line = reader.readLine()) != null) { + names.add(line.trim()); + } + } catch (IOException e) { + printf("Error reading file: %s\n", e.getMessage()); + } + return names; + } + + // Collect Function objects into a global ArrayList based on names + private void getFunctionsByName(List < String > functionNames) { + SymbolTable st = currentProgram.getSymbolTable(); + for (String functionName: functionNames) { + + // Get all symbols with function name + SymbolIterator si = st.getSymbols(functionName); + + // Iterate over symbols + while (si.hasNext()) { + Symbol s = si.next(); + // Check if the symbol is a function + if (s.getSymbolType() == SymbolType.FUNCTION) { + Function fun = getFunctionAt(s.getAddress()); + if (fun != null && !fun.isThunk()) { + functions.add(fun); + } + } + } + } + } + + // Collects all Function objects into a global ArrayList + public void getAllFunctions() { + SymbolTable st = currentProgram.getSymbolTable(); + SymbolIterator si = st.getSymbolIterator(); + + while (si.hasNext()) { + Symbol s = si.next(); + if ((s.getSymbolType() == SymbolType.FUNCTION) && (!s.isExternal())) { + Function fun = getFunctionAt(s.getAddress()); + if (!fun.isThunk()) { + functions.add(fun); + } + } + } + } + + // Initalizes the decompiler and extracts pseudo-code from a batch of functions + private void extractPseudoCode() { + decomp = new DecompInterface(); + options = new DecompileOptions(); + decomp.setOptions(options); + decomp.toggleCCode(true); + decomp.toggleSyntaxTree(true); + decomp.setSimplificationStyle("decompile"); + + if (!decomp.openProgram(currentProgram)) { + printf("Could not initialize the decompiler, exiting.\n\n"); + return; + } + + printf("Extracting pseudo-code from %d functions...\n\n", functions.size()); + + // batch processing for better extraction performance + for (int i = 0; i < functions.size(); i += BATCHSIZE) { + List < Function > batch = functions.subList(i, Math.min(i + BATCHSIZE, functions.size())); + batchExtractPseudoCode(batch); + } + } + + // extract the pseudo-code of a function + // @param func target function batch + public void batchExtractPseudoCode(List < Function > batch) { + for (Function func: batch) { + DecompileResults res = decomp.decompileFunction(func, TIMEOUT, monitor); + if (res.getDecompiledFunction() != null) { + saveToFile(outputPath, func.getName() + "@" + func.getEntryPoint() + ".c", res.getDecompiledFunction().getC()); + } else { + printf("Can't decompile %s\n\n", func.getName()); + } + } + } + + // save results to file + // @param path name of the output directory + // @param name name of the output file + // @param output content to save to file + public void saveToFile(String path, String name, String output) { + String fullPath = path + File.separator + name; + try { + FileWriter fw = new FileWriter(fullPath); + PrintWriter pw = new PrintWriter(fw); + pw.write(output); + pw.close(); + + } catch (IOException e) { + printf("Error writing to output file \"%s\". %s\n\n", fullPath, e.getMessage()); + return; + } + } +} \ No newline at end of file diff --git a/Rhabdomancer.java b/Rhabdomancer.java index 2face90..f224cc0 100644 --- a/Rhabdomancer.java +++ b/Rhabdomancer.java @@ -26,6 +26,7 @@ * - Copy the script into your ghidra_scripts directory * - Open the Script Manager in Ghidra and run the script * - You can also run it via the Tools > Rhabdomancer menu or the shortcut "Y" + * - Choose the output directory to save the results * - Open Window > Comments and navigate [BAD] candidate points in tier 0-2 * * Inspired by The Ghidra Book (No Starch, 2020). Tested with Ghidra v10.3. @@ -44,6 +45,10 @@ import java.util.Map; import java.util.LinkedHashMap; import java.util.Iterator; +import java.io.File; +import java.io.FileWriter; +import java.io.BufferedWriter; +import java.io.IOException; import ghidra.app.script.GhidraScript; import ghidra.program.model.symbol.*; @@ -52,6 +57,8 @@ public class Rhabdomancer extends GhidraScript { + private BufferedWriter outputFile; + @Override public void run() throws Exception { @@ -199,6 +206,13 @@ public void run() throws Exception bad.put("[BAD 1]", tier1); bad.put("[BAD 2]", tier2); + // Create a new file to store function names containing insecure API calls + String outputDirectory = askString("Choose Output Directory", "Select the directory to save the results"); + File outputFilePath = new File(outputDirectory); + try { + outputFile = new BufferedWriter(new FileWriter(outputFilePath)); + printf("Writing results to: %s\n", outputFilePath); + // enumerate candidate points at each tier Iterator>> i = bad.entrySet().iterator(); while (i.hasNext()) { @@ -209,6 +223,11 @@ public void run() throws Exception funcs.forEach(f -> listCalls(f, entry.getKey() + " " + f.getName())); } } + catch (IOException e) { + e.printStackTrace(); + } + outputFile.close(); + } // collect Function objects associated with the specified name // @param name function name @@ -263,8 +282,15 @@ public void listCalls(Function dstFunc, String tag) if (getBookmarks(callAddr).length == 0) { createBookmark(callAddr, "Insecure function - " + tag, dstName + " is called"); } + + try { + outputFile.write(srcName + "\n"); + } catch (IOException e) { + e.printStackTrace(); + } + } } } } -} +} \ No newline at end of file