Skip to content

Commit e0afb5b

Browse files
authored
fix(go/core): flow name is not passed to the context (#3718)
1 parent 02d48d8 commit e0afb5b

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

go/core/flow.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ type flowContext struct {
4848
// DefineFlow creates a Flow that runs fn, and registers it as an action. fn takes an input of type In and returns an output of type Out.
4949
func DefineFlow[In, Out any](r api.Registry, name string, fn Func[In, Out]) *Flow[In, Out, struct{}] {
5050
return (*Flow[In, Out, struct{}])(DefineAction(r, name, api.ActionTypeFlow, nil, nil, func(ctx context.Context, input In) (Out, error) {
51-
fc := &flowContext{}
51+
fc := &flowContext{
52+
flowName: name,
53+
}
5254
ctx = flowContextKey.NewContext(ctx, fc)
5355
return fn(ctx, input)
5456
}))
@@ -65,7 +67,9 @@ func DefineFlow[In, Out any](r api.Registry, name string, fn Func[In, Out]) *Flo
6567
// Otherwise, it should ignore the callback and just return a result.
6668
func DefineStreamingFlow[In, Out, Stream any](r api.Registry, name string, fn StreamingFunc[In, Out, Stream]) *Flow[In, Out, Stream] {
6769
return (*Flow[In, Out, Stream])(DefineStreamingAction(r, name, api.ActionTypeFlow, nil, nil, func(ctx context.Context, input In, cb func(context.Context, Stream) error) (Out, error) {
68-
fc := &flowContext{}
70+
fc := &flowContext{
71+
flowName: name,
72+
}
6973
ctx = flowContextKey.NewContext(ctx, fc)
7074
return fn(ctx, input, cb)
7175
}))

go/core/flow_test.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,26 @@ func TestRunFlow(t *testing.T) {
6666
t.Errorf("got %d, want %d", got, want)
6767
}
6868
}
69+
70+
func TestFlowNameFromContext(t *testing.T) {
71+
r := registry.New()
72+
flows := []*Flow[struct{}, string, struct{}]{
73+
DefineFlow(r, "DefineFlow", func(ctx context.Context, _ struct{}) (string, error) {
74+
return FlowNameFromContext(ctx), nil
75+
}),
76+
DefineStreamingFlow(r, "DefineStreamingFlow", func(ctx context.Context, _ struct{}, s StreamCallback[struct{}]) (string, error) {
77+
return FlowNameFromContext(ctx), nil
78+
}),
79+
}
80+
for _, flow := range flows {
81+
t.Run(flow.Name(), func(t *testing.T) {
82+
got, err := flow.Run(context.Background(), struct{}{})
83+
if err != nil {
84+
t.Fatal(err)
85+
}
86+
if want := flow.Name(); got != want {
87+
t.Errorf("got '%s', want '%s'", got, want)
88+
}
89+
})
90+
}
91+
}

0 commit comments

Comments
 (0)