diff --git a/worker/task.go b/worker/task.go index 92c1d02350f..5e575997372 100644 --- a/worker/task.go +++ b/worker/task.go @@ -795,15 +795,20 @@ func (qs *queryState) handleUidPostings( needFiltering := needsStringFiltering(srcFn, q.Langs, q.Attr) isList := schema.State().IsList(q.Attr) - errCh := make(chan error, numGo) outputs := make([]*pb.Result, numGo) + eg, egCtx := errgroup.WithContext(ctx) calculate := func(start, end int) error { x.AssertTrue(start%width == 0) out := &pb.Result{} outputs[start/width] = out for i := start; i < end; i++ { + select { + case <-egCtx.Done(): + return egCtx.Err() + default: + } if i%100 == 0 { select { case <-ctx.Done(): @@ -950,14 +955,12 @@ func (qs *queryState) handleUidPostings( if end > srcFn.n { end = srcFn.n } - go func(start, end int) { - errCh <- calculate(start, end) - }(start, end) + eg.Go(func() error { + return calculate(start, end) + }) } - for range numGo { - if err := <-errCh; err != nil { - return err - } + if err := eg.Wait(); err != nil { + return err } // All goroutines are done. Now attach their results. out := args.out @@ -1597,11 +1600,17 @@ func (qs *queryState) filterGeoFunction(ctx context.Context, arg funcArgs) error attribute.Int("num_go", numGo), attribute.Int("width", width))) + eg, egCtx := errgroup.WithContext(ctx) filtered := make([]*pb.List, numGo) filter := func(idx, start, end int) error { filtered[idx] = &pb.List{} out := filtered[idx] for _, uid := range uids.Uids[start:end] { + select { + case <-egCtx.Done(): + return egCtx.Err() + default: + } pl, err := qs.cache.Get(x.DataKey(attr, uid)) if err != nil { return err @@ -1623,21 +1632,19 @@ func (qs *queryState) filterGeoFunction(ctx context.Context, arg funcArgs) error return nil } - errCh := make(chan error, numGo) for i := range numGo { start := i * width end := start + width if end > len(uids.Uids) { end = len(uids.Uids) } - go func(idx, start, end int) { - errCh <- filter(idx, start, end) - }(i, start, end) + idx := i + eg.Go(func() error { + return filter(idx, start, end) + }) } - for range numGo { - if err := <-errCh; err != nil { - return err - } + if err := eg.Wait(); err != nil { + return err } final := &pb.List{} for _, out := range filtered {