Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(race-conditon): fix data race condition when accessing stderr on pipe #209

Merged
merged 11 commits into from
Sep 2, 2024
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
with:
go-version: ${{ matrix.go-version }}
- uses: actions/checkout@v3
- run: go test ./...
- run: go test -race ./...

gocritic:
runs-on: ubuntu-latest
Expand Down
38 changes: 32 additions & 6 deletions script.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ type Pipe struct {
stdout, stderr io.Writer
httpClient *http.Client

// because pipe stages are concurrent, protect 'err'
// because pipe stages are concurrent, protect 'err' and 'stderr'
bitfield marked this conversation as resolved.
Show resolved Hide resolved
mu *sync.Mutex
err error
}
Expand Down Expand Up @@ -385,8 +385,9 @@ func (p *Pipe) Exec(cmdLine string) *Pipe {
cmd.Stdin = r
cmd.Stdout = w
cmd.Stderr = w
if p.stderr != nil {
cmd.Stderr = p.stderr
pipeStderr := p.getStderr()
if pipeStderr != nil {
cmd.Stderr = pipeStderr
}
err = cmd.Start()
if err != nil {
Expand Down Expand Up @@ -425,8 +426,9 @@ func (p *Pipe) ExecForEach(cmdLine string) *Pipe {
cmd := exec.Command(args[0], args[1:]...)
cmd.Stdout = w
cmd.Stderr = w
if p.stderr != nil {
cmd.Stderr = p.stderr
pipeStderr := p.getStderr()
if pipeStderr != nil {
cmd.Stderr = pipeStderr
}
err = cmd.Start()
if err != nil {
Expand Down Expand Up @@ -762,6 +764,30 @@ func (p *Pipe) SetError(err error) {
p.err = err
}

// setStderr sets the stderr writer on the pipe. This field
// is protected by a mutex since stderr is accessed inside a
// goroutine from [Pipe.Exec].
func (p *Pipe) setStderr(stderr io.Writer) {
if p.mu == nil { // uninitialised pipe
return
}
p.mu.Lock()
defer p.mu.Unlock()
p.stderr = stderr
}

// getStderr obtains the stderr writer on the pipe. This field
// is protected by a mutex since stderr is accessed inside a
// goroutine from [Pipe.Exec].
func (p *Pipe) getStderr() io.Writer {
if p.mu == nil { // uninitialised pipe
return nil
}
p.mu.Lock()
bitfield marked this conversation as resolved.
Show resolved Hide resolved
defer p.mu.Unlock()
return p.stderr
}

// SHA256Sum returns the hex-encoded SHA-256 hash of the entire contents of the
// pipe, or an error.
func (p *Pipe) SHA256Sum() (string, error) {
Expand Down Expand Up @@ -887,7 +913,7 @@ func (p *Pipe) WithReader(r io.Reader) *Pipe {
// [Pipe.Exec] or [Pipe.ExecForEach] to the writer w, instead of going to the
// pipe as it normally would.
func (p *Pipe) WithStderr(w io.Writer) *Pipe {
p.stderr = w
p.setStderr(w)
return p
}

Expand Down
13 changes: 13 additions & 0 deletions script_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1850,6 +1850,19 @@ func TestReadReturnsErrorGivenReadErrorOnPipe(t *testing.T) {
}
}

// TestWithStdErrAfterExec is a regression test that was added to test against
// a race condition for Pipe.stderr.
func TestWithStdErrAfterExec(t *testing.T) {
bitfield marked this conversation as resolved.
Show resolved Hide resolved
t.Parallel()
stdOut := new(bytes.Buffer)
stdErr := new(bytes.Buffer)

_, err := script.Exec("echo").WithStdout(stdOut).WithStderr(stdErr).Stdout()
if err != nil {
t.Fatal(err)
}
}

func ExampleArgs() {
script.Args().Stdout()
// prints command-line arguments
Expand Down