Skip to content

Add the built-in tool run_command_in_terminal for AI to execute commands in the connected PowerShell session #398

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
239 changes: 237 additions & 2 deletions shell/AIShell.Abstraction/NamedPipe.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,21 @@ public enum MessageType : int
/// A message from AIShell to command-line shell to send code block.
/// </summary>
PostCode = 4,

/// <summary>
/// A message from AIShell to command-line shell to run a command.
/// </summary>
RunCommand = 5,

/// <summary>
/// A message from AIShell to command-line shell to ask for the result of a previous command run.
/// </summary>
AskCommandOutput = 6,

/// <summary>
/// A message from command-line shell to AIShell to post the result of a command.
/// </summary>
PostResult = 7,
}

/// <summary>
Expand Down Expand Up @@ -201,6 +216,95 @@ public PostCodeMessage(List<string> codeBlocks)
}
}

/// <summary>
/// Message for <see cref="MessageType.RunCommand"/>.
/// </summary>
public sealed class RunCommandMessage : PipeMessage
{
/// <summary>
/// Gets the command to run.
/// </summary>
public string Command { get; }

/// <summary>
/// Gets whether the command should be run in blocking mode.
/// </summary>
public bool Blocking { get; }

/// <summary>
/// Creates an instance of <see cref="RunCommandMessage"/>.
/// </summary>
public RunCommandMessage(string command, bool blocking)
: base(MessageType.RunCommand)
{
ArgumentException.ThrowIfNullOrEmpty(command);

Command = command;
Blocking = blocking;
}
}

/// <summary>
/// Message for <see cref="MessageType.AskCommandOutput"/>.
/// </summary>
public sealed class AskCommandOutputMessage : PipeMessage
{
/// <summary>
/// Gets the id of the command to retrieve the output for.
/// </summary>
public string CommandId { get; }

/// <summary>
/// Creates an instance of <see cref="AskCommandOutputMessage"/>.
/// </summary>
public AskCommandOutputMessage(string commandId)
: base(MessageType.AskCommandOutput)
{
ArgumentException.ThrowIfNullOrEmpty(commandId);
CommandId = commandId;
}
}

/// <summary>
/// Message for <see cref="MessageType.PostResult"/>.
/// </summary>
public sealed class PostResultMessage : PipeMessage
{
/// <summary>
/// Gets the result of the command for a blocking 'run_command' too call.
/// Or, for a non-blocking call, gets the id for retrieving the result later.
/// </summary>
public string Output { get; }

/// <summary>
/// Gets whether the command execution had any error.
/// i.e. a native command returned a non-zero exit code, or a powershell command threw any errors.
/// </summary>
public bool HadError { get; }

/// <summary>
/// Gets a value indicating whether the operation was canceled by the user.
/// </summary>
public bool UserCancelled { get; }

/// <summary>
/// Gets the internal exception message that is thrown when trying to run the command.
/// </summary>
public string Exception { get; }

/// <summary>
/// Creates an instance of <see cref="PostResultMessage"/>.
/// </summary>
public PostResultMessage(string output, bool hadError, bool userCancelled, string exception)
: base(MessageType.PostResult)
{
Output = output;
HadError = hadError;
UserCancelled = userCancelled;
Exception = exception;
}
}

/// <summary>
/// The base type for common pipe operations.
/// </summary>
Expand Down Expand Up @@ -301,7 +405,7 @@ protected async Task<PipeMessage> GetMessageAsync(CancellationToken cancellation
return null;
}

if (type > (int)MessageType.PostCode)
if (type > (int)MessageType.PostResult)
{
_pipeStream.Close();
throw new IOException($"Unknown message type received: {type}. Connection was dropped.");
Expand Down Expand Up @@ -344,9 +448,12 @@ private static PipeMessage DeserializePayload(int type, ReadOnlySpan<byte> bytes
{
(int)MessageType.PostQuery => JsonSerializer.Deserialize<PostQueryMessage>(bytes),
(int)MessageType.AskConnection => JsonSerializer.Deserialize<AskConnectionMessage>(bytes),
(int)MessageType.PostContext => JsonSerializer.Deserialize<PostContextMessage>(bytes),
(int)MessageType.AskContext => JsonSerializer.Deserialize<AskContextMessage>(bytes),
(int)MessageType.PostContext => JsonSerializer.Deserialize<PostContextMessage>(bytes),
(int)MessageType.PostCode => JsonSerializer.Deserialize<PostCodeMessage>(bytes),
(int)MessageType.RunCommand => JsonSerializer.Deserialize<RunCommandMessage>(bytes),
(int)MessageType.AskCommandOutput => JsonSerializer.Deserialize<AskCommandOutputMessage>(bytes),
(int)MessageType.PostResult => JsonSerializer.Deserialize<PostResultMessage>(bytes),
_ => throw new NotSupportedException("Unreachable code"),
};
}
Expand Down Expand Up @@ -465,6 +572,16 @@ public async Task StartProcessingAsync(int timeout, CancellationToken cancellati
InvokeOnPostCode((PostCodeMessage)message);
break;

case MessageType.RunCommand:
var result = InvokeOnRunCommand((RunCommandMessage)message);
SendMessage(result);
break;

case MessageType.AskCommandOutput:
var output = InvokeOnAskCommandOutput((AskCommandOutputMessage)message);
SendMessage(output);
break;

default:
// Log: unexpected messages ignored.
break;
Expand Down Expand Up @@ -537,6 +654,66 @@ private PostContextMessage InvokeOnAskContext(AskContextMessage message)
return null;
}

/// <summary>
/// Helper to invoke the <see cref="OnRunCommand"/> event.
/// </summary>
private PostResultMessage InvokeOnRunCommand(RunCommandMessage message)
{
if (OnRunCommand is null)
{
// Log: event handler not set.
return new PostResultMessage(
output: "Command execution is not supported.",
hadError: true,
userCancelled: false,
exception: null);
}

try
{
return OnRunCommand(message);
}
catch (Exception e)
{
// Log: exception when invoking 'OnRunCommand'
return new PostResultMessage(
output: "Failed to execute the command due to an internal error.",
hadError: true,
userCancelled: false,
exception: e.Message);
}
}

/// <summary>
/// Helper to invoke the <see cref="OnAskCommandOutput"/> event.
/// </summary>
private PostResultMessage InvokeOnAskCommandOutput(AskCommandOutputMessage message)
{
if (OnAskCommandOutput is null)
{
// Log: event handler not set.
return new PostResultMessage(
output: "Retrieving command output is not supported.",
hadError: true,
userCancelled: false,
exception: null);
}

try
{
return OnAskCommandOutput(message);
}
catch (Exception e)
{
// Log: exception when invoking 'OnAskCommandOutput'
return new PostResultMessage(
output: "Failed to retrieve the command output due to an internal error.",
hadError: true,
userCancelled: false,
exception: e.Message);
}
}

/// <summary>
/// Event for handling the <see cref="MessageType.PostCode"/> message.
/// </summary>
Expand All @@ -551,6 +728,16 @@ private PostContextMessage InvokeOnAskContext(AskContextMessage message)
/// Event for handling the <see cref="MessageType.AskContext"/> message.
/// </summary>
public event Func<AskContextMessage, PostContextMessage> OnAskContext;

/// <summary>
/// Event for handling the <see cref="MessageType.RunCommand"/> message.
/// </summary>
public event Func<RunCommandMessage, PostResultMessage> OnRunCommand;

/// <summary>
/// Event for handling the <see cref="MessageType.AskCommandOutput"/> message.
/// </summary>
public event Func<AskCommandOutputMessage, PostResultMessage> OnAskCommandOutput;
}

/// <summary>
Expand Down Expand Up @@ -771,4 +958,52 @@ public async Task<PostContextMessage> AskContext(AskContextMessage message, Canc

return postContext;
}

/// <summary>
/// Run a command in the connected PowerShell session.
/// </summary>
/// <param name="message">The <see cref="MessageType.RunCommand"/> message.</param>
/// <param name="cancellationToken">A cancellation token.</param>
/// <returns>A <see cref="MessageType.PostResult"/> message as the response.</returns>
/// <exception cref="IOException">Throws when the pipe is closed by the other side.</exception>
public async Task<PostResultMessage> RunCommand(RunCommandMessage message, CancellationToken cancellationToken)
{
// Send the request message to the shell.
SendMessage(message);

// Receiving response from the shell.
var response = await GetMessageAsync(cancellationToken);
if (response is not PostResultMessage postResult)
{
// Log: unexpected message. drop connection.
_client.Close();
throw new IOException($"Expecting '{MessageType.PostResult}' response, but received '{message.Type}' message.");
}

return postResult;
}

/// <summary>
/// Ask for the output of a previously run command in the connected PowerShell session.
/// </summary>
/// <param name="message">The <see cref="MessageType.AskCommandOutput"/> message.</param>
/// <param name="cancellationToken">A cancellation token.</param>
/// <returns>A <see cref="MessageType.PostResult"/> message as the response.</returns>
/// <exception cref="IOException">Throws when the pipe is closed by the other side.</exception>
public async Task<PostResultMessage> AskCommandOutput(AskCommandOutputMessage message, CancellationToken cancellationToken)
{
// Send the request message to the shell.
SendMessage(message);

// Receiving response from the shell.
var response = await GetMessageAsync(cancellationToken);
if (response is not PostResultMessage postResult)
{
// Log: unexpected message. drop connection.
_client.Close();
throw new IOException($"Expecting '{MessageType.PostResult}' response, but received '{message.Type}' message.");
}

return postResult;
}
}
4 changes: 2 additions & 2 deletions shell/AIShell.Integration/AIShell.psd1
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
PowerShellVersion = '7.4.6'
PowerShellHostName = 'ConsoleHost'
FunctionsToExport = @()
CmdletsToExport = @('Start-AIShell','Invoke-AIShell','Resolve-Error')
CmdletsToExport = @('Start-AIShell','Invoke-AIShell', 'Invoke-AICommand', 'Resolve-Error')
VariablesToExport = '*'
AliasesToExport = @('aish', 'askai', 'fixit')
AliasesToExport = @('aish', 'askai', 'fixit', 'airun')
HelpInfoURI = 'https://aka.ms/aishell-help'
PrivateData = @{ PSData = @{ Prerelease = 'preview5'; ProjectUri = 'https://github.com/PowerShell/AIShell' } }
}
Loading
Loading