Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package ai.reveng.toolkit.ghidra.binarysimilarity.ui.aidecompiler;

import ai.reveng.invoker.ApiException;
import ai.reveng.model.AiDecompilationTaskStatus;
import ai.reveng.model.GetAiDecompilationTask;
import ai.reveng.model.DecompilationData;
import ai.reveng.model.WorkflowProgress;
import ai.reveng.toolkit.ghidra.core.services.api.GhidraRevengService;
import ai.reveng.toolkit.ghidra.core.services.api.TypedApiInterface;
import ai.reveng.toolkit.ghidra.core.services.api.types.AIDecompilationStatus;
import ai.reveng.toolkit.ghidra.core.services.logging.ReaiLoggingService;
import ai.reveng.toolkit.ghidra.plugins.ReaiPluginPackage;
import docking.ActionContext;
Expand Down Expand Up @@ -40,7 +41,7 @@ public class AIDecompilationdWindow extends ComponentProviderAdapter {
private JComponent component;
private Function function;
private TaskMonitorComponent taskMonitorComponent;
private final Map<Function, GetAiDecompilationTask> cache = new java.util.HashMap<>();
private final Map<Function, AIDecompilationStatus> cache = new java.util.HashMap<>();


public AIDecompilationdWindow(PluginTool tool, String owner) {
Expand Down Expand Up @@ -172,6 +173,11 @@ private JComponent buildComponent() {
return component;
}

@Override
public JComponent getComponent() {
return component;
}

/**
* Apply the predicted function name to the current function
*/
Expand All @@ -181,11 +187,11 @@ private void applyPredictedName() {
}

var cachedStatus = cache.get(function);
if (cachedStatus == null || cachedStatus.getPredictedFunctionName() == null) {
if (cachedStatus == null || cachedStatus.predictedFunctionName() == null) {
return;
}

String predictedName = cachedStatus.getPredictedFunctionName();
String predictedName = cachedStatus.predictedFunctionName();
var program = function.getProgram();

int txId = program.startTransaction("Rename function to predicted name");
Expand All @@ -199,34 +205,70 @@ private void applyPredictedName() {
}
}

@Override
public JComponent getComponent() {
return component;
}



public void setDisplayedValuesBasedOnStatus(Function function, GetAiDecompilationTask status) {
public void setDisplayedValuesBasedOnStatus(Function function, AIDecompilationStatus status) {
this.function = function;
if (status.getStatus() == AiDecompilationTaskStatus.SUCCESS) {
setCode(status.getDecompilation());
descriptionArea.setText("<html>%s</html>".formatted(status.getSummary()));

// Show predicted name if available
String predictedName = status.getPredictedFunctionName();
if (predictedName != null && !predictedName.isEmpty()) {
predictedNameLabel.setText("Predicted name: " + predictedName);
predictedNamePanel.setVisible(true);
} else {
switch (status.status()) {
case COMPLETED -> {
setCode(withInlineComments(status.decompilation(), status.inlineComments()));
descriptionArea.setText("<html>%s</html>".formatted(status.summary() == null ? "" : status.summary()));

String predictedName = status.predictedFunctionName();
if (predictedName != null && !predictedName.isEmpty()) {
predictedNameLabel.setText("Predicted name: " + predictedName);
predictedNamePanel.setVisible(true);
} else {
predictedNamePanel.setVisible(false);
}
}
case FAILED -> {
setCode("");
descriptionArea.setText("Decompilation failed");
predictedNamePanel.setVisible(false);
}
} else if (status.getStatus() == AiDecompilationTaskStatus.ERROR) {
setCode("");
descriptionArea.setText("Decompilation failed");
predictedNamePanel.setVisible(false);
case UNINITIALISED, PENDING, RUNNING -> {
setCode("");
descriptionArea.setText("Decompiling %s ...".formatted(function.getName()));
predictedNamePanel.setVisible(false);
}
default -> {
// Unknown status — leave existing UI state untouched.
}
}
}

/**
* Splice inline comments into the decompilation as `// comment` lines above each
* targeted line (1-indexed). Preserves the leading indentation of the target line.
*/
private static String withInlineComments(String decompilation, java.util.List<AIDecompilationStatus.InlineCommentEntry> comments) {
if (decompilation == null || decompilation.isEmpty() || comments == null || comments.isEmpty()) {
return decompilation == null ? "" : decompilation;
}
var byLine = new java.util.HashMap<Long, String>();
for (var entry : comments) {
byLine.put(entry.line(), entry.comment());
}
String[] lines = decompilation.split("\n", -1);
var out = new StringBuilder();
for (int i = 0; i < lines.length; i++) {
long lineNumber = i + 1L;
String comment = byLine.get(lineNumber);
if (comment != null) {
int indentEnd = 0;
while (indentEnd < lines[i].length() && Character.isWhitespace(lines[i].charAt(indentEnd))) {
indentEnd++;
}
String indent = lines[i].substring(0, indentEnd);
out.append(indent).append("// ").append(comment).append('\n');
}
out.append(lines[i]);
if (i < lines.length - 1) {
out.append('\n');
}
}
return out.toString();
}

private void setCode(String code) {
String text = code;
textArea.setText(text);
Expand All @@ -250,6 +292,11 @@ public void refresh(GhidraRevengService.FunctionWithID function) {
// Only start decompilation if the window is visible and the status of the analysis is complete.
if (this.isVisible()) {
taskMonitorComponent.setVisible(true);
// Replace the initial "no function selected" placeholder before the first poll lands.
this.function = function.function();
setCode("");
descriptionArea.setText("Decompiling %s ...".formatted(function.function().getName()));
predictedNamePanel.setVisible(false);
// Start a new background task to decompile the function
var task = new AIDecompTask(tool, function);
var builder = TaskBuilder.withTask(task);
Expand Down Expand Up @@ -280,28 +327,30 @@ public void locationChanged(ProgramLocation loc) {
}


void newStatusForFunction(Function function, GetAiDecompilationTask status) {
void newStatusForFunction(Function function, AIDecompilationStatus status) {
cache.put(function, status);
if (function == this.function) {
SwingUtilities.invokeLater(() ->
setDisplayedValuesBasedOnStatus(function, status)
);
}
if (status.getStatus() == AiDecompilationTaskStatus.SUCCESS) {
if (status.status() == DecompilationData.StatusEnum.COMPLETED) {
var logger = tool.getService(ReaiLoggingService.class);
logger.info("AI Decompilation finished for function %s: %s".formatted(function.getName(), status.getDecompilation()));
logger.info("AI Decompilation finished for function %s: %s".formatted(function.getName(), status.decompilation()));
if (!hasPendingDecompilations()) {
taskMonitorComponent.setVisible(false);
}
} else if (status.getStatus() == AiDecompilationTaskStatus.ERROR) {
} else if (status.status() == DecompilationData.StatusEnum.FAILED) {
if (!hasPendingDecompilations()) {
taskMonitorComponent.setVisible(false);
}
}
}

private boolean hasPendingDecompilations() {
return cache.values().stream().anyMatch(s -> s.getStatus() == AiDecompilationTaskStatus.PENDING);
return cache.values().stream().anyMatch(s ->
s.status() == DecompilationData.StatusEnum.PENDING
|| s.status() == DecompilationData.StatusEnum.RUNNING);
}
class AIDecompTask extends Task {

Expand All @@ -318,7 +367,7 @@ public AIDecompTask(PluginTool tool, GhidraRevengService.FunctionWithID function
public void run(TaskMonitor monitor) throws CancelledException {
var fID = functionWithID.functionID();
// Check if there is an existing process already, because the trigger API will fail with 400 if there is
if (service.getApi().pollAIDecompileStatus(fID).getStatus() == AiDecompilationTaskStatus.UNINITIALISED) {
if (service.getApi().pollAIDecompileStatus(fID).status() == DecompilationData.StatusEnum.UNINITIALISED) {
// Trigger the decompilation
service.getApi().triggerAIDecompilationForFunctionID(fID);
}
Expand All @@ -330,18 +379,22 @@ public void run(TaskMonitor monitor) throws CancelledException {
private void waitForDecomp(TypedApiInterface.FunctionID id, TaskMonitor monitor) throws CancelledException {
var logger = tool.getService(ReaiLoggingService.class);
var api = service.getApi();
GetAiDecompilationTask lastDecompStatus = null;
AIDecompilationStatus lastDecompStatus = null;
boolean inlineCommentsTriggered = false;
while (true) {
var newStatus = api.pollAIDecompileStatus(id);
if (lastDecompStatus == null || !Objects.equals(newStatus.getStatus(), lastDecompStatus.getStatus())) {
if (lastDecompStatus == null
|| !Objects.equals(newStatus.status(), lastDecompStatus.status())
|| !Objects.equals(newStatus.inlineCommentsStatus(), lastDecompStatus.inlineCommentsStatus())
|| newStatus.inlineComments().size() != lastDecompStatus.inlineComments().size()) {
lastDecompStatus = newStatus;

newStatusForFunction(functionWithID.function(), newStatus);
}
monitor.setMessage("Waiting for AI Decompilation for %s ... Current status: %s".formatted(functionWithID.function().getName(), lastDecompStatus.getStatus()));
monitor.setMessage("Waiting for AI Decompilation for %s ... Current status: %s".formatted(functionWithID.function().getName(), lastDecompStatus.status()));
monitor.checkCancelled();
switch (newStatus.getStatus()) {
switch (newStatus.status()) {
case PENDING:
case RUNNING:
case UNINITIALISED:
try {
// Wait a second before polling again. We don't want to spam the API with requests too often
Expand All @@ -350,14 +403,40 @@ private void waitForDecomp(TypedApiInterface.FunctionID id, TaskMonitor monitor)
throw new RuntimeException(e);
}
break;
case SUCCESS:
case COMPLETED:
monitor.setProgress(monitor.getMaximum());
return;
case ERROR:
logger.error("Decompilation failed: %s".formatted(newStatus.getDecompilation()));
// Decompilation is done; now wait for inline comments to land before stopping.
var commentsStatus = newStatus.inlineCommentsStatus();
if (commentsStatus == WorkflowProgress.StatusEnum.COMPLETED) {
return;
}
if (commentsStatus == WorkflowProgress.StatusEnum.FAILED) {
logger.error("Inline comments generation failed for function %s".formatted(functionWithID.function().getName()));
return;
}
if (!inlineCommentsTriggered) {
// Trigger on first entry to COMPLETED regardless of reported status: if comments
// haven't been requested (or the status endpoint was unreachable), POSTing kicks
// them off; if they're already running the server treats it as a regenerate.
inlineCommentsTriggered = true;
try {
api.triggerAIDecompilationInlineComments(id);
} catch (RuntimeException e) {
logger.error("Failed to trigger inline comments: %s".formatted(e.getMessage()));
return;
}
}
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
break;
case FAILED:
logger.error("Decompilation failed: %s".formatted(newStatus.decompilation()));
return;
default:
throw new RuntimeException("Unknown status: %s".formatted(newStatus.getStatus()));
throw new RuntimeException("Unknown status: %s".formatted(newStatus.status()));
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -645,42 +645,37 @@ public String decompileFunctionViaAI(FunctionWithID functionWithID, TaskMonitor
// Check if there is an existing process already, because the trigger API will fail with 400 if there is
var fID = functionWithID.functionID;
var function = functionWithID.function;
if (api.pollAIDecompileStatus(fID).getStatus() == AiDecompilationTaskStatus.UNINITIALISED){
if (api.pollAIDecompileStatus(fID).status() == DecompilationData.StatusEnum.UNINITIALISED){
// Trigger the decompilation
api.triggerAIDecompilationForFunctionID(fID);
}

String lastStatus;

while (true) {
if (monitor.isCancelled()) {
return "Decompilation cancelled";
}
var status = api.pollAIDecompileStatus(fID);
window.setDisplayedValuesBasedOnStatus(function, status);

switch (status.getStatus()) {
switch (status.status()) {
case PENDING:
case RUNNING:
case UNINITIALISED:
try {
Thread.sleep(100);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
// monitor.incrementProgress(100);
break;
case SUCCESS:
case COMPLETED:
monitor.setProgress(monitor.getMaximum());
window.setDisplayedValuesBasedOnStatus(function, status);
return status.getDecompilation();
case ERROR:
return "Decompilation failed: %s".formatted(status.getStatus());
return status.decompilation();
case FAILED:
return "Decompilation failed: %s".formatted(status.status());
default:
throw new RuntimeException("Unknown status: %s".formatted(status.getStatus()));
throw new RuntimeException("Unknown status: %s".formatted(status.status()));
}



}
}

Expand Down Expand Up @@ -925,7 +920,7 @@ public BoxPlot getNameScoreForMatch(GhidraFunctionMatch functionMatch) {

public void openFunctionInPortal(TypedApiInterface.FunctionID functionID) {
var details = api.getFunctionDetails(functionID);
openPortal("analyses", String.format("%s?fn=%s", details.analysisId().id(), functionID.value()));
openPortal("analyses", String.format("%s?view=functions&fn=%s", details.analysisId().id(), functionID.value()));
}

public void openCollectionInPortal(Collection collection) {
Expand Down
Loading
Loading