Skip to content

Commit 582f341

Browse files
znullCopilot
andcommitted
Decouple close-responsibility from stdin/stdout dynamic type
Rather than inferring who is responsible for closing a pipeline stage's stdin/stdout from the dynamic type, communicate that explicitly in StartOptions. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 9b53532 commit 582f341

6 files changed

Lines changed: 201 additions & 65 deletions

File tree

pipe/close_responsibility_test.go

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
package pipe
2+
3+
import (
4+
"context"
5+
"io"
6+
"os/exec"
7+
"strings"
8+
"sync/atomic"
9+
"testing"
10+
)
11+
12+
// readCloseSpy records whether Close was called.
13+
type readCloseSpy struct {
14+
io.Reader
15+
closed atomic.Bool
16+
}
17+
18+
func (r *readCloseSpy) Close() error {
19+
r.closed.Store(true)
20+
return nil
21+
}
22+
23+
// writeCloseSpy records whether Close was called.
24+
type writeCloseSpy struct {
25+
io.Writer
26+
closed atomic.Bool
27+
}
28+
29+
func (w *writeCloseSpy) Close() error {
30+
w.closed.Store(true)
31+
return nil
32+
}
33+
34+
// TestGoStageHonorsLeaveOpenFlags verifies that a Function stage closes
35+
// stdin/stdout iff the corresponding StartOptions.Leave*Open flag is unset.
36+
func TestGoStageHonorsLeaveOpenFlags(t *testing.T) {
37+
cases := []struct {
38+
name string
39+
leaveIn, leaveOut bool
40+
}{
41+
{"own both", false, false},
42+
{"leave stdin open", true, false},
43+
{"leave stdout open", false, true},
44+
{"leave both open", true, true},
45+
}
46+
for _, tc := range cases {
47+
t.Run(tc.name, func(t *testing.T) {
48+
in := &readCloseSpy{Reader: strings.NewReader("hi")}
49+
out := &writeCloseSpy{Writer: io.Discard}
50+
51+
s := Function("f", func(_ context.Context, _ Env, stdin io.Reader, stdout io.Writer) error {
52+
_, err := io.Copy(stdout, stdin)
53+
return err
54+
})
55+
56+
if err := s.Start(context.Background(), Env{}, in, out, StartOptions{
57+
LeaveStdinOpen: tc.leaveIn,
58+
LeaveStdoutOpen: tc.leaveOut,
59+
}); err != nil {
60+
t.Fatalf("Start: %v", err)
61+
}
62+
if err := s.Wait(); err != nil {
63+
t.Fatalf("Wait: %v", err)
64+
}
65+
66+
if got, want := in.closed.Load(), !tc.leaveIn; got != want {
67+
t.Errorf("stdin closed = %v, want %v (LeaveStdinOpen=%v)", got, want, tc.leaveIn)
68+
}
69+
if got, want := out.closed.Load(), !tc.leaveOut; got != want {
70+
t.Errorf("stdout closed = %v, want %v (LeaveStdoutOpen=%v)", got, want, tc.leaveOut)
71+
}
72+
})
73+
}
74+
}
75+
76+
// TestCommandStageHonorsLeaveStdinOpen verifies that a command stage closes a
77+
// non-file stdin (a "late" closer) iff LeaveStdinOpen is unset. An empty
78+
// reader is used so exec.Cmd's input-copy goroutine sees EOF promptly.
79+
func TestCommandStageHonorsLeaveStdinOpen(t *testing.T) {
80+
for _, leave := range []bool{false, true} {
81+
name := "owns stdin"
82+
if leave {
83+
name = "leaves stdin open"
84+
}
85+
t.Run(name, func(t *testing.T) {
86+
in := &readCloseSpy{Reader: strings.NewReader("")}
87+
88+
cmd := exec.Command("true")
89+
s := CommandStage("true", cmd).(*commandStage)
90+
91+
if err := s.Start(context.Background(), Env{}, in, nil, StartOptions{
92+
LeaveStdinOpen: leave,
93+
}); err != nil {
94+
t.Fatalf("Start: %v", err)
95+
}
96+
if err := s.Wait(); err != nil {
97+
t.Fatalf("Wait: %v", err)
98+
}
99+
100+
if got, want := in.closed.Load(), !leave; got != want {
101+
t.Errorf("stdin closed = %v, want %v (LeaveStdinOpen=%v)", got, want, leave)
102+
}
103+
})
104+
}
105+
}
106+
107+
// TestCommandStageHonorsLeaveStdoutOpen verifies the stdout counterpart: a
108+
// non-file stdout (routed through the pooled-copy path) is closed iff
109+
// LeaveStdoutOpen is unset.
110+
func TestCommandStageHonorsLeaveStdoutOpen(t *testing.T) {
111+
for _, leave := range []bool{false, true} {
112+
name := "owns stdout"
113+
if leave {
114+
name = "leaves stdout open"
115+
}
116+
t.Run(name, func(t *testing.T) {
117+
out := &writeCloseSpy{Writer: io.Discard}
118+
119+
cmd := exec.Command("true")
120+
s := CommandStage("true", cmd).(*commandStage)
121+
122+
if err := s.Start(context.Background(), Env{}, nil, out, StartOptions{
123+
LeaveStdoutOpen: leave,
124+
}); err != nil {
125+
t.Fatalf("Start: %v", err)
126+
}
127+
if err := s.Wait(); err != nil {
128+
t.Fatalf("Wait: %v", err)
129+
}
130+
131+
if got, want := out.closed.Load(), !leave; got != want {
132+
t.Errorf("stdout closed = %v, want %v (LeaveStdoutOpen=%v)", got, want, leave)
133+
}
134+
})
135+
}
136+
}

pipe/command.go

Lines changed: 26 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ func (s *commandStage) Preferences() StagePreferences {
7777
}
7878

7979
func (s *commandStage) Start(
80-
ctx context.Context, env Env, stdin io.ReadCloser, stdout io.WriteCloser, _ StartOptions,
80+
ctx context.Context, env Env, stdin io.ReadCloser, stdout io.WriteCloser, opts StartOptions,
8181
) error {
8282
if s.cmd.Dir == "" {
8383
s.cmd.Dir = env.Dir
@@ -92,62 +92,44 @@ func (s *commandStage) Start(
9292
// See the type comment for `Stage` and the long comment in
9393
// `Pipeline.WithStdin()` for the explanation of this unwrapping
9494
// and closing behavior.
95-
9695
if stdin != nil {
97-
switch stdin := stdin.(type) {
98-
case readerNopCloser:
99-
// In this case, we shouldn't close it. But unwrap it for
100-
// efficiency's sake:
101-
s.cmd.Stdin = UnwrapReader(stdin)
102-
case *os.File:
103-
// In this case, we can close stdin as soon as the command
104-
// has started:
105-
s.cmd.Stdin = stdin
106-
earlyClosers = append(earlyClosers, stdin)
96+
// For a non-wrapped value this is a no-op.
97+
reader := UnwrapReader(stdin)
98+
s.cmd.Stdin = reader
99+
100+
switch {
101+
case opts.LeaveStdinOpen:
102+
// leave it open.
107103
default:
108-
// In this case, we need to close `stdin`, but we should
109-
// only do so after the command has finished:
110-
s.cmd.Stdin = stdin
111-
s.lateClosers = append(s.lateClosers, stdin)
104+
if _, ok := reader.(*os.File); ok {
105+
// We can close our copy as soon as the command has started
106+
earlyClosers = append(earlyClosers, stdin)
107+
} else {
108+
// We need to close `stdin`, but only after the command has finished
109+
s.lateClosers = append(s.lateClosers, stdin)
110+
}
112111
}
113112
}
114113

115114
if stdout != nil {
116-
// See the long comment in `Pipeline.Start()` for the
117-
// explanation of this special case.
118-
switch stdout := stdout.(type) {
119-
case writerNopCloser:
120-
// We shouldn't close the wrapped writer. Unwrap it; if
121-
// it's an `*os.File`, exec.Cmd can pass the fd directly
122-
// to the child. Otherwise route the copy through our own
123-
// pipe so we can use a pooled buffer.
124-
writer := UnwrapWriter(stdout)
125-
if f, ok := writer.(*os.File); ok {
126-
s.cmd.Stdout = f
127-
} else {
128-
ec, err := s.setupPooledStdout(writer)
129-
if err != nil {
130-
return err
131-
}
132-
earlyClosers = append(earlyClosers, ec)
115+
writer := UnwrapWriter(stdout)
116+
if f, ok := writer.(*os.File); ok {
117+
s.cmd.Stdout = f
118+
if !opts.LeaveStdoutOpen {
119+
earlyClosers = append(earlyClosers, stdout)
133120
}
134-
case *os.File:
135-
// In this case, we can close stdout as soon as the command
136-
// has started:
137-
s.cmd.Stdout = stdout
138-
earlyClosers = append(earlyClosers, stdout)
139-
default:
140-
// In this case, we need to close `stdout`, but we should
141-
// only do so after the command has finished. We also
142-
// route the copy through our own pipe so we can use a
121+
} else {
122+
// Route the copy through our own pipe so we can use a
143123
// pooled buffer rather than letting exec.Cmd allocate a
144124
// fresh 32KB buffer for its internal io.Copy.
145-
ec, err := s.setupPooledStdout(stdout)
125+
ec, err := s.setupPooledStdout(writer)
146126
if err != nil {
147127
return err
148128
}
149129
earlyClosers = append(earlyClosers, ec)
150-
s.lateClosers = append(s.lateClosers, stdout)
130+
if !opts.LeaveStdoutOpen {
131+
s.lateClosers = append(s.lateClosers, stdout)
132+
}
151133
}
152134
}
153135

pipe/function.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,12 @@ func (s *goStage) Start(
7474
s.err = opts.PanicHandler(p)
7575
}
7676
}
77-
if stdout != nil {
77+
if stdout != nil && !opts.LeaveStdoutOpen {
7878
if err := stdout.Close(); err != nil && s.err == nil {
7979
s.err = fmt.Errorf("error closing stdout for stage %q: %w", s.Name(), err)
8080
}
8181
}
82-
if stdin != nil {
82+
if stdin != nil && !opts.LeaveStdinOpen {
8383
if err := stdin.Close(); err != nil && s.err == nil {
8484
s.err = fmt.Errorf("error closing stdin for stage %q: %w", s.Name(), err)
8585
}

pipe/nop_closer.go

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,7 @@ func (readerNopCloser) Close() error {
2525
return nil
2626
}
2727

28-
// writerNopCloser is a WriteCloser that wraps a provided `io.Writer`, but
29-
// whose `Close()` method does nothing. It should be unwrapped (via
30-
// [UnwrapWriter]) before use where fast-path interfaces such as
31-
// `io.ReaderFrom` are relevant.
28+
// writerNopCloser is the stdout counterpart of [readerNopCloser]
3229
type writerNopCloser struct {
3330
io.Writer
3431
}

pipe/pipeline.go

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ type Pipeline struct {
6060
stages []Stage
6161
cancel func()
6262

63+
leaveStdoutOpen bool // only matters when stdout is non-nil
64+
6365
// Atomically written and read value, nonzero if the pipeline has
6466
// been started. This is only used for lifecycle sanity checks but
6567
// does not guarantee that clients are using the class correctly.
@@ -152,6 +154,7 @@ func WithStdin(stdin io.Reader) Option {
152154
func WithStdout(stdout io.Writer) Option {
153155
return func(p *Pipeline) {
154156
p.stdout = writerNopCloser{stdout}
157+
p.leaveStdoutOpen = true
155158
}
156159
}
157160

@@ -160,6 +163,7 @@ func WithStdout(stdout io.Writer) Option {
160163
func WithStdoutCloser(stdout io.WriteCloser) Option {
161164
return func(p *Pipeline) {
162165
p.stdout = stdout
166+
p.leaveStdoutOpen = false
163167
}
164168
}
165169

@@ -258,6 +262,19 @@ type stageStarter struct {
258262
stdout io.WriteCloser
259263
}
260264

265+
// startOptions builds the StartOptions for the stage at index i. It sets
266+
// LeaveStdinOpen/LeaveStdoutOpen for the first and last stages, as appropriate.
267+
func (p *Pipeline) startOptions(i int) StartOptions {
268+
opts := StartOptions{PanicHandler: p.panicHandler}
269+
if i == 0 && p.stdin != nil {
270+
opts.LeaveStdinOpen = true
271+
}
272+
if i == len(p.stages)-1 && p.stdout != nil {
273+
opts.LeaveStdoutOpen = p.leaveStdoutOpen
274+
}
275+
return opts
276+
}
277+
261278
// Start starts the commands in the pipeline. If `Start()` exits
262279
// without an error, `Wait()` must also be called, to allow all
263280
// resources to be freed.
@@ -320,7 +337,7 @@ func (p *Pipeline) Start(ctx context.Context) error {
320337
// Close the pipe that the previous stage was writing to.
321338
// That should cause it to exit even if it's not minding
322339
// its context.
323-
if stageStarters[i].stdin != nil {
340+
if stageStarters[i].stdin != nil && !p.startOptions(i).LeaveStdinOpen {
324341
_ = stageStarters[i].stdin.Close()
325342
}
326343

@@ -361,7 +378,7 @@ func (p *Pipeline) Start(ctx context.Context) error {
361378
} else {
362379
nextSS.stdin, ss.stdout = io.Pipe()
363380
}
364-
if err := s.Start(ctx, p.env, ss.stdin, ss.stdout, StartOptions{PanicHandler: p.panicHandler}); err != nil {
381+
if err := s.Start(ctx, p.env, ss.stdin, ss.stdout, p.startOptions(i)); err != nil {
365382
nextSS.stdin.Close()
366383
ss.stdout.Close()
367384
return abort(i, err)
@@ -376,7 +393,7 @@ func (p *Pipeline) Start(ctx context.Context) error {
376393
s := p.stages[i]
377394
ss := &stageStarters[i]
378395

379-
if err := s.Start(ctx, p.env, ss.stdin, ss.stdout, StartOptions{PanicHandler: p.panicHandler}); err != nil {
396+
if err := s.Start(ctx, p.env, ss.stdin, ss.stdout, p.startOptions(i)); err != nil {
380397
return abort(i, err)
381398
}
382399
}
@@ -387,6 +404,7 @@ func (p *Pipeline) Start(ctx context.Context) error {
387404
func (p *Pipeline) Output(ctx context.Context) ([]byte, error) {
388405
var buf bytes.Buffer
389406
p.stdout = writerNopCloser{&buf}
407+
p.leaveStdoutOpen = true
390408
err := p.Run(ctx)
391409
return buf.Bytes(), err
392410
}

pipe/stage.go

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@ import (
1010
//
1111
// Who closes stdin and stdout?
1212
//
13-
// A `Stage` as a whole needs to be responsible for closing its end of
13+
// A `Stage` as a whole is responsible for closing its end of
1414
// stdin and stdout (assuming that `Start()` returns successfully).
1515
// Its doing so tells the previous/next stage that it is done
1616
// reading/writing data, which can affect their behavior. Therefore,
1717
// it should close each one as soon as it is done with it. If the
1818
// caller wants to suppress the closing of stdin/stdout, it can always
19-
// wrap the corresponding argument in a "nopCloser".
19+
// indicate otherwise via `StartOptions`.
2020
//
2121
// How this should be done depends on whether stdin/stdout are of type
2222
// `*os.File`.
@@ -63,18 +63,17 @@ import (
6363
// }()
6464
//
6565
// From the point of view of the pipeline as a whole, if stdin is
66-
// provided by the user (`WithStdin()`), then we don't want to close
67-
// it at all, whether it's an `*os.File` or not. For this reason,
68-
// stdin has to be wrapped using a `readerNopCloser` before being
69-
// passed into the first stage. For efficiency reasons, the first
70-
// stage should ideally unwrap its stdin argument (using
71-
// [UnwrapReader]) before actually using it. If the wrapped value is
72-
// an `*os.File` and the stage is a command stage, then unwrapping is
73-
// also important to get the right semantics.
66+
// provided by the user (`WithStdin()`), then we don't want the first
67+
// stage to close it at all, whether it's an `*os.File` or not. The
68+
// pipeline communicates this by setting `StartOptions.LeaveStdinOpen`
69+
// when it starts that stage. stdin is still wrapped in a
70+
// `readerNopCloser` before being passed in, but only so that a bare
71+
// `io.Reader` satisfies `io.ReadCloser`, and so that a command stage
72+
// can recover the underlying object via [UnwrapReader].
7473
//
7574
// For stdout, it depends on whether the user supplied it using
76-
// `WithStdout()` or `WithStdoutCloser()`. If the former, then the
77-
// considerations are the same as for stdin.
75+
// `WithStdout()` or `WithStdoutCloser()`. [UnwrapWriter] plays the same
76+
// role for stdout that [UnwrapReader] plays for stdin.
7877
//
7978
// [1] It's theoretically possible for a command to pass the open file
8079
// descriptor to another, longer-lived process, in which case the
@@ -118,6 +117,10 @@ type StartOptions struct {
118117
// handler), converting it into an error. Stage types that don't run
119118
// user code in a library-spawned goroutine ignore it.
120119
PanicHandler StagePanicHandler
120+
121+
// LeaveStd{in,out}Open tell the stage that it must NOT close stdin/stdout
122+
LeaveStdinOpen bool
123+
LeaveStdoutOpen bool
121124
}
122125

123126
// StagePanicHandler is a function that handles panics in a pipeline's stages.

0 commit comments

Comments
 (0)