From 9fed7b1749503f4d2813a6b6e618279bee2e5770 Mon Sep 17 00:00:00 2001 From: eileenaaa Date: Fri, 15 Aug 2025 11:00:39 +0800 Subject: [PATCH 1/2] cancel goroutine when has error --- worker/task.go | 30 ++++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/worker/task.go b/worker/task.go index 92c1d02350f..ab873d806d7 100644 --- a/worker/task.go +++ b/worker/task.go @@ -798,12 +798,19 @@ func (qs *queryState) handleUidPostings( errCh := make(chan error, numGo) outputs := make([]*pb.Result, numGo) + cctx, ccancel := context.WithCancel(ctx) + defer ccancel() calculate := func(start, end int) error { x.AssertTrue(start%width == 0) out := &pb.Result{} outputs[start/width] = out for i := start; i < end; i++ { + select { + case <-cctx.Done(): + return cctx.Err() + default: + } if i%100 == 0 { select { case <-ctx.Done(): @@ -951,7 +958,13 @@ func (qs *queryState) handleUidPostings( end = srcFn.n } go func(start, end int) { - errCh <- calculate(start, end) + if err := calculate(start, end); err != nil { + errCh <- err + ccancel() + return + } else { + errCh <- nil + } }(start, end) } for range numGo { @@ -1597,11 +1610,18 @@ func (qs *queryState) filterGeoFunction(ctx context.Context, arg funcArgs) error attribute.Int("num_go", numGo), attribute.Int("width", width))) + cctx, ccancel := context.WithCancel(ctx) + defer ccancel() filtered := make([]*pb.List, numGo) filter := func(idx, start, end int) error { filtered[idx] = &pb.List{} out := filtered[idx] for _, uid := range uids.Uids[start:end] { + select { + case <-cctx.Done(): + return cctx.Err() + default: + } pl, err := qs.cache.Get(x.DataKey(attr, uid)) if err != nil { return err @@ -1631,7 +1651,13 @@ func (qs *queryState) filterGeoFunction(ctx context.Context, arg funcArgs) error end = len(uids.Uids) } go func(idx, start, end int) { - errCh <- filter(idx, start, end) + if err := filter(idx, start, end); err != nil { + errCh <- err + ccancel() + return + } else { + errCh <- nil + } }(i, start, end) } for range numGo { From 2549c89c461c2184fd546d6ea8406335c1dcdb02 Mon Sep 17 00:00:00 2001 From: eileenaaa Date: Fri, 15 Aug 2025 16:42:04 +0800 Subject: [PATCH 2/2] use errgroup to cancel goroutine --- worker/task.go | 53 ++++++++++++++++---------------------------------- 1 file changed, 17 insertions(+), 36 deletions(-) diff --git a/worker/task.go b/worker/task.go index ab873d806d7..5e575997372 100644 --- a/worker/task.go +++ b/worker/task.go @@ -795,11 +795,9 @@ func (qs *queryState) handleUidPostings( needFiltering := needsStringFiltering(srcFn, q.Langs, q.Attr) isList := schema.State().IsList(q.Attr) - errCh := make(chan error, numGo) outputs := make([]*pb.Result, numGo) - cctx, ccancel := context.WithCancel(ctx) - defer ccancel() + eg, egCtx := errgroup.WithContext(ctx) calculate := func(start, end int) error { x.AssertTrue(start%width == 0) out := &pb.Result{} @@ -807,8 +805,8 @@ func (qs *queryState) handleUidPostings( for i := start; i < end; i++ { select { - case <-cctx.Done(): - return cctx.Err() + case <-egCtx.Done(): + return egCtx.Err() default: } if i%100 == 0 { @@ -957,20 +955,12 @@ func (qs *queryState) handleUidPostings( if end > srcFn.n { end = srcFn.n } - go func(start, end int) { - if err := calculate(start, end); err != nil { - errCh <- err - ccancel() - return - } else { - errCh <- nil - } - }(start, end) + eg.Go(func() error { + return calculate(start, end) + }) } - for range numGo { - if err := <-errCh; err != nil { - return err - } + if err := eg.Wait(); err != nil { + return err } // All goroutines are done. Now attach their results. out := args.out @@ -1610,16 +1600,15 @@ func (qs *queryState) filterGeoFunction(ctx context.Context, arg funcArgs) error attribute.Int("num_go", numGo), attribute.Int("width", width))) - cctx, ccancel := context.WithCancel(ctx) - defer ccancel() + eg, egCtx := errgroup.WithContext(ctx) filtered := make([]*pb.List, numGo) filter := func(idx, start, end int) error { filtered[idx] = &pb.List{} out := filtered[idx] for _, uid := range uids.Uids[start:end] { select { - case <-cctx.Done(): - return cctx.Err() + case <-egCtx.Done(): + return egCtx.Err() default: } pl, err := qs.cache.Get(x.DataKey(attr, uid)) @@ -1643,27 +1632,19 @@ func (qs *queryState) filterGeoFunction(ctx context.Context, arg funcArgs) error return nil } - errCh := make(chan error, numGo) for i := range numGo { start := i * width end := start + width if end > len(uids.Uids) { end = len(uids.Uids) } - go func(idx, start, end int) { - if err := filter(idx, start, end); err != nil { - errCh <- err - ccancel() - return - } else { - errCh <- nil - } - }(i, start, end) + idx := i + eg.Go(func() error { + return filter(idx, start, end) + }) } - for range numGo { - if err := <-errCh; err != nil { - return err - } + if err := eg.Wait(); err != nil { + return err } final := &pb.List{} for _, out := range filtered {