Skip to content

fix: fix with stderr data race #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions script/constructor.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
package script

// TODO: generate constructrs from bitfield/script's source code
// TODO: generate constructors from bitfield/script's source code

import (
"context"
"net/http"
"sync"

"github.com/bitfield/script"
)

func newPipeFrom(pipe *script.Pipe) *Pipe {
return &Pipe{Pipe: pipe}
return &Pipe{
Pipe: pipe,
mu: new(sync.Mutex),
}
}

func NewPipe() *Pipe {
Expand Down
44 changes: 36 additions & 8 deletions script/contextual.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"os/exec"
"path/filepath"
"strings"
"sync"
"text/template"

"github.com/bitfield/script"
Expand All @@ -25,15 +26,30 @@ var NewReadAutoCloser = script.NewReadAutoCloser
type Pipe struct {
*script.Pipe

stderr io.Writer // captured from WithStderr

// wd is the working directory for current pipe.
wd string

mu *sync.Mutex // protects the following fields

stderr io.Writer // captured from WithStderr

// env is the environment variables for current pipe.
// Non-empty value will be set to the exec.Command instance.
env []string
}

func (p *Pipe) environments() []string {
p.mu.Lock()
defer p.mu.Unlock()
return p.env
}

func (p *Pipe) stdErr() io.Writer {
p.mu.Lock()
defer p.mu.Unlock()
return p.stderr
}

func (p *Pipe) At(dir string) *Pipe {
p.wd = dir
return p
Expand All @@ -46,23 +62,35 @@ func (p *Pipe) WithCurrentEnv() *Pipe {

// WithEnv sets the environment variables for the current pipe.
func (p *Pipe) WithEnv(env []string) *Pipe {
p.mu.Lock()
defer p.mu.Unlock()

p.env = env
return p
}

// AppendEnv appends the environment variables for the current pipe.
func (p *Pipe) AppendEnv(env ...string) *Pipe {
p.mu.Lock()
defer p.mu.Unlock()

p.env = append(p.env, env...)
return p
}

// WithEnvKV sets the environment variable key-value pair for the current pipe.
func (p *Pipe) WithEnvKV(key, value string) *Pipe {
p.mu.Lock()
defer p.mu.Unlock()

p.env = append(p.env, key+"="+value)
return p
}

func (p *Pipe) WithStderr(w io.Writer) *Pipe {
p.mu.Lock()
defer p.mu.Unlock()

p.stderr = w
p.Pipe = p.Pipe.WithStderr(w)
return p
Expand Down Expand Up @@ -132,8 +160,8 @@ func (p *Pipe) execContext(
if p.wd != "" {
cmd.Dir = p.wd
}
if len(p.env) > 0 {
cmd.Env = p.env
if envs := p.environments(); len(envs) > 0 {
cmd.Env = envs
}

return cmd
Expand All @@ -150,8 +178,8 @@ func (p *Pipe) ExecContext(ctx context.Context, cmdLine string) *Pipe {
cmd.Stdin = r
cmd.Stdout = w
cmd.Stderr = w
if p.stderr != nil {
cmd.Stderr = p.stderr
if stderr := p.stdErr(); stderr != nil {
cmd.Stderr = stderr
}

if err := cmd.Start(); err != nil {
Expand Down Expand Up @@ -189,8 +217,8 @@ func (p *Pipe) ExecForEachContext(ctx context.Context, cmdLine string) *Pipe {
cmd := p.execContext(ctx, args[0], args[1:])
cmd.Stdout = w
cmd.Stderr = w
if p.stderr != nil {
cmd.Stderr = p.stderr
if stderr := p.stdErr(); stderr != nil {
cmd.Stderr = stderr
}
err = cmd.Start()
if err != nil {
Expand Down
15 changes: 15 additions & 0 deletions script/contextual_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package script

import (
"context"
"testing"
)

func TestWithStdErr_IsConcurrencySafeAfterExec(t *testing.T) {
t.Parallel()
ctx := context.Background()
err := ExecContext(ctx, "echo").WithStderr(nil).Wait()
if err != nil {
t.Fatal(err)
}
}
4 changes: 2 additions & 2 deletions script/go.mod
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
module github.com/b4fun/script-contextual/script

go 1.22.3
go 1.22.7

require (
github.com/bitfield/script v0.22.1
github.com/bitfield/script v0.23.0
mvdan.cc/sh/v3 v3.7.0
)

Expand Down
4 changes: 2 additions & 2 deletions script/go.sum
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
github.com/bitfield/script v0.22.1 h1:DphxoC5ssYciwd0ZS+N0Xae46geAD/0mVWh6a2NUxM4=
github.com/bitfield/script v0.22.1/go.mod h1:fv+6x4OzVsRs6qAlc7wiGq8fq1b5orhtQdtW0dwjUHI=
github.com/bitfield/script v0.23.0 h1:N0R5yLEl6wJIS9PR/A6xXwjMsplMubyxdi05N5l0X28=
github.com/bitfield/script v0.23.0/go.mod h1:fv+6x4OzVsRs6qAlc7wiGq8fq1b5orhtQdtW0dwjUHI=
github.com/frankban/quicktest v1.14.5 h1:dfYrrRyLtiqT9GyKXgdh+k4inNeTvmGbuSgZ3lx3GhA=
github.com/frankban/quicktest v1.14.5/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
Expand Down
4 changes: 2 additions & 2 deletions tests/go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module tests

go 1.22.3
go 1.22.7

replace github.com/b4fun/script-contextual/script => ../script

Expand All @@ -11,7 +11,7 @@ require (
)

require (
github.com/bitfield/script v0.22.1 // indirect
github.com/bitfield/script v0.23.0 // indirect
github.com/itchyny/gojq v0.12.13 // indirect
github.com/itchyny/timefmt-go v0.1.5 // indirect
golang.org/x/sys v0.10.0 // indirect
Expand Down
4 changes: 2 additions & 2 deletions tests/go.sum
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
github.com/bitfield/script v0.22.1 h1:DphxoC5ssYciwd0ZS+N0Xae46geAD/0mVWh6a2NUxM4=
github.com/bitfield/script v0.22.1/go.mod h1:fv+6x4OzVsRs6qAlc7wiGq8fq1b5orhtQdtW0dwjUHI=
github.com/bitfield/script v0.23.0 h1:N0R5yLEl6wJIS9PR/A6xXwjMsplMubyxdi05N5l0X28=
github.com/bitfield/script v0.23.0/go.mod h1:fv+6x4OzVsRs6qAlc7wiGq8fq1b5orhtQdtW0dwjUHI=
github.com/frankban/quicktest v1.14.5 h1:dfYrrRyLtiqT9GyKXgdh+k4inNeTvmGbuSgZ3lx3GhA=
github.com/frankban/quicktest v1.14.5/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
Expand Down
175 changes: 175 additions & 0 deletions tests/script_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1768,6 +1768,40 @@ func TestWithStdout_SetsSpecifiedWriterAsStdout(t *testing.T) {
}
}

func TestWithEnv_UnsetsAllEnvVarsGivenEmptySlice(t *testing.T) {
t.Parallel()
p := script.NewPipe().WithEnv([]string{"ENV1=test1"}).Exec("sh -c 'echo ENV1=$ENV1'")
want := "ENV1=test1\n"
got, err := p.String()
if err != nil {
t.Fatal(err)
}
if got != want {
t.Fatalf("want %q, got %q", want, got)
}
got, err = p.Echo("").WithEnv([]string{}).Exec("sh -c 'echo ENV1=$ENV1'").String()
if err != nil {
t.Fatal(err)
}
want = "ENV1=\n"
if got != want {
t.Errorf("want %q, got %q", want, got)
}
}

func TestWithEnv_SetsGivenVariablesForSubsequentExec(t *testing.T) {
t.Parallel()
env := []string{"ENV1=test1", "ENV2=test2"}
got, err := script.NewPipe().WithEnv(env).Exec("sh -c 'echo ENV1=$ENV1 ENV2=$ENV2'").String()
if err != nil {
t.Fatal(err)
}
want := "ENV1=test1 ENV2=test2\n"
if got != want {
t.Errorf("want %q, got %q", want, got)
}
}

func TestErrorReturnsErrorSetByPreviousPipeStage(t *testing.T) {
t.Parallel()
p := script.File("testdata/nonexistent.txt")
Expand Down Expand Up @@ -1850,6 +1884,135 @@ func TestReadReturnsErrorGivenReadErrorOnPipe(t *testing.T) {
}
}

func TestWait_ReturnsErrorPresentOnPipe(t *testing.T) {
t.Parallel()
p := script.Echo("a\nb\nc\n").ExecForEach("{{invalid template syntax}}")
if p.Wait() == nil {
t.Error("want error, got nil")
}
}

func TestWait_DoesNotReturnErrorForValidExecution(t *testing.T) {
t.Parallel()
p := script.Echo("a\nb\nc\n").ExecForEach("echo \"{{.}}\"")
if err := p.Wait(); err != nil {
t.Fatal(err)
}
}

var base64Cases = []struct {
name string
decoded string
encoded string
}{
{
name: "empty string",
decoded: "",
encoded: "",
},
{
name: "single line string",
decoded: "hello world",
encoded: "aGVsbG8gd29ybGQ=",
},
{
name: "multi line string",
decoded: "hello\nthere\nworld\n",
encoded: "aGVsbG8KdGhlcmUKd29ybGQK",
},
}

func TestEncodeBase64_CorrectlyEncodes(t *testing.T) {
t.Parallel()
for _, tc := range base64Cases {
t.Run(tc.name, func(t *testing.T) {
got, err := script.Echo(tc.decoded).EncodeBase64().String()
if err != nil {
t.Fatal(err)
}
if got != tc.encoded {
t.Logf("input %q incorrectly encoded:", tc.decoded)
t.Error(cmp.Diff(tc.encoded, got))
}
})
}
}

func TestDecodeBase64_CorrectlyDecodes(t *testing.T) {
t.Parallel()
for _, tc := range base64Cases {
t.Run(tc.name, func(t *testing.T) {
got, err := script.Echo(tc.encoded).DecodeBase64().String()
if err != nil {
t.Fatal(err)
}
if got != tc.decoded {
t.Logf("input %q incorrectly decoded:", tc.encoded)
t.Error(cmp.Diff(tc.decoded, got))
}
})
}
}

func TestEncodeBase64_FollowedByDecodeRecoversOriginal(t *testing.T) {
t.Parallel()
for _, tc := range base64Cases {
t.Run(tc.name, func(t *testing.T) {
decoded, err := script.Echo(tc.decoded).EncodeBase64().DecodeBase64().String()
if err != nil {
t.Fatal(err)
}
if decoded != tc.decoded {
t.Error("encode-decode round trip failed:", cmp.Diff(tc.decoded, decoded))
}
encoded, err := script.Echo(tc.encoded).DecodeBase64().EncodeBase64().String()
if err != nil {
t.Fatal(err)
}
if encoded != tc.encoded {
t.Error("decode-encode round trip failed:", cmp.Diff(tc.encoded, encoded))
}
})
}
}

func TestDecodeBase64_CorrectlyDecodesInputToBytes(t *testing.T) {
t.Parallel()
input := "CAAAEA=="
got, err := script.Echo(input).DecodeBase64().Bytes()
if err != nil {
t.Fatal(err)
}
want := []byte{8, 0, 0, 16}
if !bytes.Equal(want, got) {
t.Logf("input %#v incorrectly decoded:", input)
t.Error(cmp.Diff(want, got))
}
}

func TestEncodeBase64_CorrectlyEncodesInputBytes(t *testing.T) {
t.Parallel()
input := []byte{8, 0, 0, 16}
reader := bytes.NewReader(input)
want := "CAAAEA=="
got, err := script.NewPipe().WithReader(reader).EncodeBase64().String()
if err != nil {
t.Fatal(err)
}
if got != want {
t.Logf("input %#v incorrectly encoded:", input)
t.Error(cmp.Diff(want, got))
}
}

func TestWithStdErr_IsConcurrencySafeAfterExec(t *testing.T) {
t.Parallel()
err := script.Exec("echo").WithStderr(nil).Wait()
if err != nil {
t.Fatal(err)
}
}

func ExampleArgs() {
script.Args().Stdout()
// prints command-line arguments
Expand Down Expand Up @@ -1969,6 +2132,12 @@ func ExamplePipe_CountLines() {
// 3
}

func ExamplePipe_DecodeBase64() {
script.Echo("SGVsbG8sIHdvcmxkIQ==").DecodeBase64().Stdout()
// Output:
// Hello, world!
}

func ExamplePipe_Do() {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
data, err := io.ReadAll(r.Body)
Expand Down Expand Up @@ -2004,6 +2173,12 @@ func ExamplePipe_Echo() {
// Hello, world!
}

func ExamplePipe_EncodeBase64() {
script.Echo("Hello, world!").EncodeBase64().Stdout()
// Output:
// SGVsbG8sIHdvcmxkIQ==
}

func ExamplePipe_ExitStatus() {
p := script.Exec("echo")
fmt.Println(p.ExitStatus())
Expand Down
Loading