Skip to content
106 changes: 82 additions & 24 deletions pipe/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,13 @@ import (
// commandStage is a pipeline `Stage` based on running an external
// command and piping the data through its stdin and stdout.
type commandStage struct {
name string
stdin io.Closer
cmd *exec.Cmd
name string
cmd *exec.Cmd

// lateClosers is a list of things that have to be closed once the
// command has finished.
lateClosers []io.Closer

done chan struct{}
wg errgroup.Group
stderr bytes.Buffer
Expand All @@ -30,6 +34,10 @@ type commandStage struct {
ctxErr atomic.Value
}

var (
_ Stage = (*commandStage)(nil)
)

// Command returns a pipeline `Stage` based on the specified external
// `command`, run with the given command-line `args`. Its stdin and
// stdout are handled as usual, and its stderr is collected and
Expand Down Expand Up @@ -59,33 +67,80 @@ func (s *commandStage) Name() string {
return s.name
}

func (s *commandStage) Preferences() StagePreferences {
prefs := StagePreferences{
StdinPreference: IOPreferenceFile,
StdoutPreference: IOPreferenceFile,
}
if s.cmd.Stdin != nil {
prefs.StdinPreference = IOPreferenceNil
}
if s.cmd.Stdout != nil {
prefs.StdoutPreference = IOPreferenceNil
}

return prefs
}

func (s *commandStage) Start(
ctx context.Context, env Env, stdin io.ReadCloser,
) (io.ReadCloser, error) {
ctx context.Context, env Env, stdin io.ReadCloser, stdout io.WriteCloser,
) error {
if s.cmd.Dir == "" {
s.cmd.Dir = env.Dir
}

s.setupEnv(ctx, env)

// Things that have to be closed as soon as the command has
// started:
var earlyClosers []io.Closer

// See the type command for `Stage` and the long comment in
// `Pipeline.WithStdin()` for the explanation of this unwrapping
// and closing behavior.

if stdin != nil {
// See the long comment in `Pipeline.Start()` for the
// explanation of this special case.
switch stdin := stdin.(type) {
case nopCloser:
case readerNopCloser:
// In this case, we shouldn't close it. But unwrap it for
// efficiency's sake:
s.cmd.Stdin = stdin.Reader
case nopCloserWriterTo:
case readerWriterToNopCloser:
// In this case, we shouldn't close it. But unwrap it for
// efficiency's sake:
s.cmd.Stdin = stdin.Reader
case *os.File:
// In this case, we can close stdin as soon as the command
// has started:
s.cmd.Stdin = stdin
earlyClosers = append(earlyClosers, stdin)
default:
// In this case, we need to close `stdin`, but we should
// only do so after the command has finished:
s.cmd.Stdin = stdin
s.lateClosers = append(s.lateClosers, stdin)
}
// Also keep a copy so that we can close it when the command exits:
s.stdin = stdin
}

stdout, err := s.cmd.StdoutPipe()
if err != nil {
return nil, err
if stdout != nil {
// See the long comment in `Pipeline.Start()` for the
// explanation of this special case.
switch stdout := stdout.(type) {
case writerNopCloser:
// In this case, we shouldn't close it. But unwrap it for
// efficiency's sake:
s.cmd.Stdout = stdout.Writer
case *os.File:
// In this case, we can close stdout as soon as the command
// has started:
s.cmd.Stdout = stdout
earlyClosers = append(earlyClosers, stdout)
default:
// In this case, we need to close `stdout`, but we should
// only do so after the command has finished:
s.cmd.Stdout = stdout
s.lateClosers = append(s.lateClosers, stdout)
}
}

// If the caller hasn't arranged otherwise, read the command's
Expand All @@ -97,7 +152,7 @@ func (s *commandStage) Start(
// can be sure.
p, err := s.cmd.StderrPipe()
if err != nil {
return nil, err
return err
}
s.wg.Go(func() error {
_, err := io.Copy(&s.stderr, p)
Expand All @@ -114,7 +169,11 @@ func (s *commandStage) Start(
s.runInOwnProcessGroup()

if err := s.cmd.Start(); err != nil {
return nil, err
return err
}

for _, closer := range earlyClosers {
_ = closer.Close()
}

// Arrange for the process to be killed (gently) if the context
Expand All @@ -128,7 +187,7 @@ func (s *commandStage) Start(
}
}()

return stdout, nil
return nil
}

// setupEnv sets or modifies the environment that will be passed to
Expand Down Expand Up @@ -217,19 +276,18 @@ func (s *commandStage) Wait() error {

// Make sure that any stderr is copied before `s.cmd.Wait()`
// closes the read end of the pipe:
wErr := s.wg.Wait()
wgErr := s.wg.Wait()

err := s.cmd.Wait()
err = s.filterCmdError(err)

if err == nil && wErr != nil {
err = wErr
if err == nil && wgErr != nil {
err = wgErr
}

if s.stdin != nil {
cErr := s.stdin.Close()
if cErr != nil && err == nil {
return cErr
for _, closer := range s.lateClosers {
if closeErr := closer.Close(); closeErr != nil && err == nil {
err = closeErr
}
}

Expand Down
3 changes: 2 additions & 1 deletion pipe/command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ func TestCopyEnvWithOverride(t *testing.T) {
ex := ex
t.Run(ex.label, func(t *testing.T) {
assert.ElementsMatch(t, ex.expectedResult,
copyEnvWithOverrides(ex.env, ex.overrides))
copyEnvWithOverrides(ex.env, ex.overrides),
)
})
}
}
4 changes: 4 additions & 0 deletions pipe/export_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
package pipe

// This file exports a functions to be used only for testing.
var UnwrapNopCloser = unwrapNopCloser
41 changes: 34 additions & 7 deletions pipe/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
// StageFunc is a function that can be used to power a `goStage`. It
// should read its input from `stdin` and write its output to
// `stdout`. `stdin` and `stdout` will be closed automatically (if
// necessary) once the function returns.
// non-nil) once the function returns.
//
// Neither `stdin` nor `stdout` are necessarily buffered. If the
// `StageFunc` requires buffering, it needs to arrange that itself.
Expand Down Expand Up @@ -38,26 +38,53 @@ type goStage struct {
err error
}

var (
_ Stage = (*goStage)(nil)
)

func (s *goStage) Name() string {
return s.name
}

func (s *goStage) Start(ctx context.Context, env Env, stdin io.ReadCloser) (io.ReadCloser, error) {
r, w := io.Pipe()
func (s *goStage) Preferences() StagePreferences {
return StagePreferences{
StdinPreference: IOPreferenceUndefined,
StdoutPreference: IOPreferenceUndefined,
}
}

func (s *goStage) Start(
ctx context.Context, env Env, stdin io.ReadCloser, stdout io.WriteCloser,
) error {
var r io.Reader = stdin
if stdin, ok := stdin.(readerNopCloser); ok {
r = stdin.Reader
}

var w io.Writer = stdout
if stdout, ok := stdout.(writerNopCloser); ok {
w = stdout.Writer
}

go func() {
s.err = s.f(ctx, env, stdin, w)
if err := w.Close(); err != nil && s.err == nil {
s.err = fmt.Errorf("error closing output pipe for stage %q: %w", s.Name(), err)
s.err = s.f(ctx, env, r, w)

if stdout != nil {
if err := stdout.Close(); err != nil && s.err == nil {
s.err = fmt.Errorf("error closing stdout for stage %q: %w", s.Name(), err)
}
}

if stdin != nil {
if err := stdin.Close(); err != nil && s.err == nil {
s.err = fmt.Errorf("error closing stdin for stage %q: %w", s.Name(), err)
}
}

close(s.done)
}()

return r, nil
return nil
}

func (s *goStage) Wait() error {
Expand Down
62 changes: 0 additions & 62 deletions pipe/iocopier.go

This file was deleted.

30 changes: 20 additions & 10 deletions pipe/memorylimit.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@ import (

const memoryPollInterval = time.Second

// ErrMemoryLimitExceeded is the error that will be used to kill a process, if
// necessary, from MemoryLimit.
// ErrMemoryLimitExceeded is the error that will be used to kill a
// process, if necessary, from MemoryLimit.
var ErrMemoryLimitExceeded = errors.New("memory limit exceeded")

// LimitableStage is the superset of Stage that must be implemented by stages
// passed to MemoryLimit and MemoryObserver.
// LimitableStage is the superset of `Stage` that must be implemented
// by stages passed to MemoryLimit and MemoryObserver.
type LimitableStage interface {
Stage

Expand Down Expand Up @@ -175,12 +175,24 @@ func (m *memoryWatchStage) Name() string {
return m.stage.Name() + m.nameSuffix
}

func (m *memoryWatchStage) Start(ctx context.Context, env Env, stdin io.ReadCloser) (io.ReadCloser, error) {
io, err := m.stage.Start(ctx, env, stdin)
if err != nil {
return nil, err
func (m *memoryWatchStage) Preferences() StagePreferences {
return m.stage.Preferences()
}

func (m *memoryWatchStage) Start(
ctx context.Context, env Env, stdin io.ReadCloser, stdout io.WriteCloser,
) error {
if err := m.stage.Start(ctx, env, stdin, stdout); err != nil {
return err
}

m.monitor(ctx)

return nil
}

// monitor starts up a goroutine that monitors the memory of `m`.
func (m *memoryWatchStage) monitor(ctx context.Context) {
ctx, cancel := context.WithCancel(ctx)
m.cancel = cancel
m.wg.Add(1)
Expand All @@ -189,8 +201,6 @@ func (m *memoryWatchStage) Start(ctx context.Context, env Env, stdin io.ReadClos
m.watch(ctx, m.stage)
m.wg.Done()
}()

return io, nil
}

func (m *memoryWatchStage) Wait() error {
Expand Down
Loading