diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 384f417..654eedb 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -13,7 +13,7 @@ jobs: with: go-version: ${{ matrix.go-version }} - uses: actions/checkout@v3 - - run: go test ./... + - run: go test -race ./... gocritic: runs-on: ubuntu-latest diff --git a/script.go b/script.go index 953ef68..cff9351 100644 --- a/script.go +++ b/script.go @@ -28,13 +28,14 @@ import ( // Pipe represents a pipe object with an associated [ReadAutoCloser]. type Pipe struct { // Reader is the underlying reader. - Reader ReadAutoCloser - stdout, stderr io.Writer - httpClient *http.Client + Reader ReadAutoCloser + stdout io.Writer + httpClient *http.Client - // because pipe stages are concurrent, protect 'err' - mu *sync.Mutex - err error + // because pipe stages are concurrent, protect 'err' and 'stderr' + mu *sync.Mutex + err error + stderr io.Writer } // Args creates a pipe containing the program's command-line arguments from @@ -414,8 +415,9 @@ func (p *Pipe) Exec(cmdLine string) *Pipe { cmd.Stdin = r cmd.Stdout = w cmd.Stderr = w - if p.stderr != nil { - cmd.Stderr = p.stderr + pipeStderr := p.stdErr() + if pipeStderr != nil { + cmd.Stderr = pipeStderr } err = cmd.Start() if err != nil { @@ -454,8 +456,9 @@ func (p *Pipe) ExecForEach(cmdLine string) *Pipe { cmd := exec.Command(args[0], args[1:]...) cmd.Stdout = w cmd.Stderr = w - if p.stderr != nil { - cmd.Stderr = p.stderr + pipeStderr := p.stdErr() + if pipeStderr != nil { + cmd.Stderr = pipeStderr } err = cmd.Start() if err != nil { @@ -839,6 +842,18 @@ func (p *Pipe) Slice() ([]string, error) { return result, p.Error() } +// stdErr returns the pipe's configured standard error writer for commands run +// via [Pipe.Exec] and [Pipe.ExecForEach]. The default is nil, which means that +// error output will go to the pipe. +func (p *Pipe) stdErr() io.Writer { + if p.mu == nil { // uninitialised pipe + return nil + } + p.mu.Lock() + defer p.mu.Unlock() + return p.stderr +} + // Stdout copies the pipe's contents to its configured standard output (using // [Pipe.WithStdout]), or to [os.Stdout] otherwise, and returns the number of // bytes successfully written, together with any error. @@ -913,10 +928,11 @@ func (p *Pipe) WithReader(r io.Reader) *Pipe { return p } -// WithStderr redirects the standard error output for commands run via -// [Pipe.Exec] or [Pipe.ExecForEach] to the writer w, instead of going to the -// pipe as it normally would. +// WithStderr sets the standard error output for [Pipe.Exec] or +// [Pipe.ExecForEach] commands to w, instead of the pipe. func (p *Pipe) WithStderr(w io.Writer) *Pipe { + p.mu.Lock() + defer p.mu.Unlock() p.stderr = w return p } diff --git a/script_test.go b/script_test.go index 296869a..e006e7b 100644 --- a/script_test.go +++ b/script_test.go @@ -1971,6 +1971,14 @@ func TestEncodeBase64_CorrectlyEncodesInputBytes(t *testing.T) { } } +func TestWithStdErr_IsConcurrencySafeAfterExec(t *testing.T) { + t.Parallel() + err := script.Exec("echo").WithStderr(nil).Wait() + if err != nil { + t.Fatal(err) + } +} + func ExampleArgs() { script.Args().Stdout() // prints command-line arguments