@@ -11,6 +11,7 @@ import (
1111 "strconv"
1212 "strings"
1313 "testing"
14+ "time"
1415
1516 "github.com/getkin/kin-openapi/openapi3"
1617 "github.com/stretchr/testify/require"
@@ -134,7 +135,7 @@ func TestListModelsWithDefaultProvider(t *testing.T) {
134135 }
135136}
136137
137- func TestAbortRun (t * testing.T ) {
138+ func TestCancelRun (t * testing.T ) {
138139 tool := ToolDef {Instructions : "What is the capital of the united states?" }
139140
140141 run , err := g .Evaluate (context .Background (), Options {DisableCache : true , IncludeEvents : true }, tool )
@@ -146,7 +147,7 @@ func TestAbortRun(t *testing.T) {
146147 <- run .Events ()
147148
148149 if err := run .Close (); err != nil {
149- t .Errorf ("Error aborting run: %v" , err )
150+ t .Errorf ("Error canceling run: %v" , err )
150151 }
151152
152153 if run .State () != Error {
@@ -158,6 +159,77 @@ func TestAbortRun(t *testing.T) {
158159 }
159160}
160161
162+ func TestAbortChatCompletionRun (t * testing.T ) {
163+ tool := ToolDef {Instructions : "What is the capital of the united states?" }
164+
165+ run , err := g .Evaluate (context .Background (), Options {DisableCache : true , IncludeEvents : true }, tool )
166+ if err != nil {
167+ t .Errorf ("Error executing tool: %v" , err )
168+ }
169+
170+ // Abort the run after the first event from the LLM
171+ for e := range run .Events () {
172+ if e .Call != nil && e .Call .Type == EventTypeCallProgress && len (e .Call .Output ) > 0 && e .Call .Output [0 ].Content != "Waiting for model response..." {
173+ break
174+ }
175+ }
176+
177+ if err := g .AbortRun (context .Background (), run ); err != nil {
178+ t .Errorf ("Error aborting run: %v" , err )
179+ }
180+
181+ // Wait for run to stop
182+ for range run .Events () {
183+ continue
184+ }
185+
186+ if run .State () != Finished {
187+ t .Errorf ("Unexpected run state: %s" , run .State ())
188+ }
189+
190+ if out , err := run .Text (); err != nil {
191+ t .Errorf ("Error reading output: %v" , err )
192+ } else if strings .TrimSpace (out ) != "ABORTED BY USER" && ! strings .HasSuffix (out , "\n ABORTED BY USER" ) {
193+ t .Errorf ("Unexpected output: %s" , out )
194+ }
195+ }
196+
197+ func TestAbortCommandRun (t * testing.T ) {
198+ tool := ToolDef {Instructions : "#!/usr/bin/env bash\n echo Hello, world!\n sleep 5\n echo Hello, again!\n sleep 5" }
199+
200+ run , err := g .Evaluate (context .Background (), Options {DisableCache : true , IncludeEvents : true }, tool )
201+ if err != nil {
202+ t .Errorf ("Error executing tool: %v" , err )
203+ }
204+
205+ // Abort the run after the first event.
206+ for e := range run .Events () {
207+ if e .Call != nil && e .Call .Type == EventTypeChat {
208+ time .Sleep (2 * time .Second )
209+ break
210+ }
211+ }
212+
213+ if err := g .AbortRun (context .Background (), run ); err != nil {
214+ t .Errorf ("Error aborting run: %v" , err )
215+ }
216+
217+ // Wait for run to stop
218+ for range run .Events () {
219+ continue
220+ }
221+
222+ if run .State () != Finished {
223+ t .Errorf ("Unexpected run state: %s" , run .State ())
224+ }
225+
226+ if out , err := run .Text (); err != nil {
227+ t .Errorf ("Error reading output: %v" , err )
228+ } else if ! strings .Contains (out , "Hello, world!" ) || strings .Contains (out , "Hello, again!" ) || ! strings .HasSuffix (out , "\n ABORTED BY USER" ) {
229+ t .Errorf ("Unexpected output: %s" , out )
230+ }
231+ }
232+
161233func TestSimpleEvaluate (t * testing.T ) {
162234 tool := ToolDef {Instructions : "What is the capital of the united states?" }
163235
@@ -844,6 +916,69 @@ func TestToolChat(t *testing.T) {
844916 }
845917}
846918
919+ func TestAbortChat (t * testing.T ) {
920+ tool := ToolDef {
921+ Chat : true ,
922+ Instructions : "You are a chat bot. Don't finish the conversation until I say 'bye'." ,
923+ Tools : []string {"sys.chat.finish" },
924+ }
925+
926+ run , err := g .Evaluate (context .Background (), Options {DisableCache : true , IncludeEvents : true }, tool )
927+ if err != nil {
928+ t .Fatalf ("Error executing tool: %v" , err )
929+ }
930+ inputs := []string {
931+ "Tell me a joke." ,
932+ "What was my first message?" ,
933+ }
934+
935+ // Just wait for the chat to start up.
936+ for range run .Events () {
937+ continue
938+ }
939+
940+ for i , input := range inputs {
941+ run , err = run .NextChat (context .Background (), input )
942+ if err != nil {
943+ t .Fatalf ("Error sending next input %q: %v" , input , err )
944+ }
945+
946+ // Abort the run after the first event from the LLM
947+ for e := range run .Events () {
948+ if e .Call != nil && e .Call .Type == EventTypeCallProgress && len (e .Call .Output ) > 0 && e .Call .Output [0 ].Content != "Waiting for model response..." {
949+ break
950+ }
951+ }
952+
953+ if i == 0 {
954+ if err := g .AbortRun (context .Background (), run ); err != nil {
955+ t .Fatalf ("Error aborting run: %v" , err )
956+ }
957+ }
958+
959+ // Wait for the run to complete
960+ for range run .Events () {
961+ continue
962+ }
963+
964+ out , err := run .Text ()
965+ if err != nil {
966+ t .Errorf ("Error reading output: %s" , run .ErrorOutput ())
967+ t .Fatalf ("Error reading output: %v" , err )
968+ }
969+
970+ if i == 0 {
971+ if strings .TrimSpace (out ) != "ABORTED BY USER" && ! strings .HasSuffix (out , "\n ABORTED BY USER" ) {
972+ t .Fatalf ("Unexpected output: %s" , out )
973+ }
974+ } else {
975+ if ! strings .Contains (out , "Tell me a joke" ) {
976+ t .Errorf ("Unexpected output: %s" , out )
977+ }
978+ }
979+ }
980+ }
981+
847982func TestFileChat (t * testing.T ) {
848983 wd , err := os .Getwd ()
849984 if err != nil {
0 commit comments