diff --git a/Makefile b/Makefile index 5a729b4..bc1003a 100644 --- a/Makefile +++ b/Makefile @@ -54,3 +54,10 @@ weight: doc: cd docs && npm install && npm start + +bench-1m: + @echo "Generating 1M fixture for examples/billion-rows-benchmark..." + @mkdir -p examples/billion-rows-benchmark/fixtures + @./examples/billion-rows-benchmark/scripts/expand_fixture.sh examples/billion-rows-benchmark/fixtures/sample.csv examples/billion-rows-benchmark/fixtures/1m.csv 1000000 + @echo "Running 1M benchmark (may take a while)..." + FIXTURE_PATH=$(CURDIR)/examples/billion-rows-benchmark/fixtures/1m.csv go test -run=^$ -bench BenchmarkMillionRowChallenge -benchmem ./examples/billion-rows-benchmark diff --git a/connectable_constructors_test.go b/connectable_constructors_test.go new file mode 100644 index 0000000..38d3194 --- /dev/null +++ b/connectable_constructors_test.go @@ -0,0 +1,94 @@ +// Copyright 2025 samber. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://github.com/samber/ro/blob/main/licenses/LICENSE.apache.md +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ro + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewConnectableObservableWithContext(t *testing.T) { + t.Parallel() + is := assert.New(t) + + var ctxReceived context.Context + connectable := NewConnectableObservableWithContext(func(ctx context.Context, destination Observer[int]) Teardown { + ctxReceived = ctx + destination.NextWithContext(ctx, 1) + destination.NextWithContext(ctx, 2) + destination.NextWithContext(ctx, 3) + destination.CompleteWithContext(ctx) + return nil + }) + + var values []int + ctx := context.WithValue(context.Background(), testCtxKey, "value") + sub := connectable.SubscribeWithContext(ctx, NewObserver( + func(value int) { values = append(values, value) }, + func(err error) { t.Fatalf("unexpected error: %v", err) }, + func() {}, + )) + + // Connect the connectable observable + connectSub := connectable.Connect() + connectSub.Wait() + sub.Wait() + + is.Equal([]int{1, 2, 3}, values) + is.NotNil(ctxReceived) + is.Equal("value", ctxReceived.Value(testCtxKey)) +} + +func TestNewConnectableObservableWithConfigAndContext(t *testing.T) { + t.Parallel() + is := assert.New(t) + + var ctxReceived context.Context + config := ConnectableConfig[int]{ + Connector: defaultConnector[int], + ResetOnDisconnect: false, + } + + connectable := NewConnectableObservableWithConfigAndContext( + func(ctx context.Context, destination Observer[int]) Teardown { + ctxReceived = ctx + destination.NextWithContext(ctx, 1) + destination.NextWithContext(ctx, 2) + destination.NextWithContext(ctx, 3) + destination.CompleteWithContext(ctx) + return nil + }, + config, + ) + + var values []int + ctx := context.WithValue(context.Background(), testCtxKey, "value") + sub := connectable.SubscribeWithContext(ctx, NewObserver( + func(value int) { values = append(values, value) }, + func(err error) { t.Fatalf("unexpected error: %v", err) }, + func() {}, + )) + + // Connect the connectable observable + connectSub := connectable.Connect() + connectSub.Wait() + sub.Wait() + + 
is.Equal([]int{1, 2, 3}, values)
+	is.NotNil(ctxReceived)
+	is.Equal("value", ctxReceived.Value(testCtxKey))
+}
diff --git a/docs/docs/core/observer.md b/docs/docs/core/observer.md
index 8eb81d6..60076fd 100644
--- a/docs/docs/core/observer.md
+++ b/docs/docs/core/observer.md
@@ -200,6 +200,50 @@ ro.Just(1, 2, 3, 4).Subscribe(observer)
 // Recovered error: something went wrong!
 ```
 
+### Opting Out for Maximum Throughput
+
+:::warning Performance vs. Safety
+
+Disabling panic capture removes the recovery overhead but will let panics crash your stream. Only opt out in trusted, performance-critical pipelines.
+
+:::
+
+If you are building high-throughput pipelines, you can opt out per subscription. The library provides a small helper to disable panic capture on the subscription context. Use it when you want to measure pure hot-path throughput or when you intentionally want panics to propagate.
+
+```go
+// Create a context that disables observer panic capture for the subscription.
+ctx := ro.WithObserverPanicCaptureDisabled(context.Background())
+
+sum := int64(0)
+pipeline := ro.Pipe2(
+	ro.Range(0, 1_000_000),
+	ro.Map(func(v int64) int64 { return v + 1 }),
+	ro.Filter(func(v int64) bool { return v%2 == 0 }),
+)
+
+// SubscribeWithContext passes the context to the subscription and to
+// downstream notifications; the opt-out avoids per-callback recover wrappers.
+pipeline.SubscribeWithContext(ctx, ro.NewObserver(
+	func(v int64) { sum += v },
+	func(err error) { panic(err) },
+	func() {},
+))
+```
+
+### Panic Capture
+
+Observers capture panics by default. If you need panics to propagate (for
+benchmarking or performance-sensitive workloads), either construct an unsafe
+observer with `NewUnsafeObserver` / `NewObserverWithContextUnsafe`, or
+disable capture for a specific subscription by passing a context derived
+with `WithObserverPanicCaptureDisabled(ctx)` to `SubscribeWithContext`:
+
+```go
+// Disable capture only for this subscription
+ctx := ro.WithObserverPanicCaptureDisabled(context.Background())
+pipeline.SubscribeWithContext(ctx, observer)
+```
+
 ### State After Error
 
 Once an Observer receives an error, it rejects further notifications:
diff --git a/docs/docs/troubleshooting/performance.md b/docs/docs/troubleshooting/performance.md
index 10206ba..195b6ec 100644
--- a/docs/docs/troubleshooting/performance.md
+++ b/docs/docs/troubleshooting/performance.md
@@ -291,7 +291,7 @@ func BenchmarkMapOperator(b *testing.B) {
 }
 
 func BenchmarkConcurrentProcessing(b *testing.B) {
-	source := ro.Just(make([]int, 1000)...)
+	source := ro.Just(make([]int, 1000)...)
 
 	b.Run("Serial", func(b *testing.B) {
 		operator := serialProcessing()
@@ -311,6 +311,57 @@ func BenchmarkConcurrentProcessing(b *testing.B) {
 	}
 }
 ```
 
+### Subscriber Concurrency Modes
+
+High-throughput sources can avoid unnecessary synchronization by selecting the right subscriber implementation. The core library now exposes `NewSingleProducerSubscriber`/`NewSingleProducerObservableWithContext`, and operators such as `Range` automatically opt into the `ConcurrencyModeSingleProducer` fast path when there is exactly one upstream writer. This mode bypasses the `Lock`/`Unlock` calls entirely while retaining panic safety and teardown behavior.
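+
+As an illustration, the following is a minimal sketch of a dedicated single-producer source consumed with panic capture disabled for the subscription (it assumes that only the subscribe callback emits values):
+
+```go
+// Single-producer source: every emission happens from the subscribe callback.
+source := ro.NewSingleProducerObservable(func(dest ro.Observer[int64]) ro.Teardown {
+	for i := int64(0); i < 1_000_000; i++ {
+		dest.Next(i)
+	}
+	dest.Complete()
+	return nil
+})
+
+// Scope the panic-capture opt-out to this subscription only.
+ctx := ro.WithObserverPanicCaptureDisabled(context.Background())
+
+var sum int64
+source.SubscribeWithContext(ctx, ro.NewObserver(
+	func(v int64) { sum += v },
+	func(err error) { panic(err) },
+	func() {},
+)).Wait()
+```
+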
+Use the following guidance when choosing a mode:
+
+| Concurrency mode | Locking strategy | Drop policy | Recommended usage |
+| --- | --- | --- | --- |
+| `ConcurrencyModeSafe` | `sync.Mutex` | Blocks producers | Multiple writers or callbacks that may concurrently re-enter observers |
+| `ConcurrencyModeEventuallySafe` | `sync.Mutex` | Drops when contended | Fan-in scenarios where losing values is acceptable |
+| `ConcurrencyModeUnsafe` | No-op lock wrapper | Blocks producers | Single writer, but still routes through the locking API surface |
+| `ConcurrencyModeSingleProducer` | No locking | Blocks producers | Single writer that needs the lowest possible overhead |
+
+:::note Panic-capture interaction
+
+Disabling capture lets some fast paths (for example the single-producer and unsafe modes) avoid wrapping observer callbacks in the usual defer/recover machinery, which reduces per-notification overhead. Use `ro.WithObserverPanicCaptureDisabled(ctx)` when subscribing in benchmarks to avoid mutating global state and to keep tests parallel-friendly.
+
+:::
+
+Run the million-row benchmark to compare the trade-offs:
+
+```bash
+go test -run=^$ -bench BenchmarkMillionRowChallenge -benchmem ./testing
+```
+
+Tips for running the benchmark:
+
+- The benchmark harness in `testing/benchmark_million_rows_test.go` disables panic capture via a per-subscription context opt-out, so the harness doesn't mutate global state. If you want to reproduce realistic production numbers, run the benchmark both with capture enabled and disabled.
+- Increase bench time to reduce noise:
+
+```bash
+go test -run=^$ -bench BenchmarkMillionRowChallenge -benchmem ./testing -benchtime=10s
+```
+
+- To check for races, run:
+
+```bash
+go test -race ./...
+```
+
+- To profile CPU or mutex contention, run the benchmark with `pprof` (or a traced run) and inspect the lock profiles to see how much time is spent acquiring `sync.Mutex` versus doing useful work.
+
+Sample results on a 1M-element pipeline:
+
+- `single-producer`: 60.3 ms/op, 1.5 KiB/op, 39 allocs/op
+- `unsafe-mutex`: 63.2 ms/op, 1.5 KiB/op, 39 allocs/op
+- `safe-mutex`: 67.1 ms/op, 1.6 KiB/op, 40 allocs/op
+
+The single-producer path trims roughly 4–6% off the runtime compared to the previous `unsafe` mode while preserving allocation characteristics. Stick with the safe variants whenever multiple goroutines might call `Next` concurrently.
+
 ## 6. Performance Optimization Checklist
 
 ### Memory Optimization
diff --git a/examples/billion-rows-benchmark/README.md b/examples/billion-rows-benchmark/README.md
new file mode 100644
index 0000000..38fff73
--- /dev/null
+++ b/examples/billion-rows-benchmark/README.md
@@ -0,0 +1,41 @@
+# Billion-rows benchmark (example)
+
+This example contains a benchmark harness that runs a pipeline against a static
+CSV file (one integer per line). It's intended as a reproducible example for
+large-file benchmarks such as the "billion rows" challenge.
+
+Files
+- `benchmark_test.go`: the benchmark. It expects a static fixture file with
+  one integer per line and emits those values through a simple CSV source.
+- `fixtures/sample.csv`: a tiny sample fixture included for CI and quick runs.
+- `scripts/expand_fixture.sh`: a simple shell script to expand the small sample
+  into a larger fixture by repeating lines.
+
+How to run
+1.
Use the small sample (fast / CI): + +```bash +# from repo root +go test -run=^$ -bench BenchmarkMillionRowChallenge ./examples/billion-rows-benchmark -benchmem +``` + +2. Use a larger static fixture (recommended for real measurements): + +- Obtain or generate a static CSV where each line is an integer (the 1B + challenge provides generators). Place it at `examples/billion-rows-benchmark/fixtures/1brc.csv` or set `FIXTURE_PATH`. + +Example to expand the included sample to 1_000_000 lines (quick, not realistic): + +```bash +cd examples/billion-rows-benchmark +mkdir -p fixtures +./scripts/expand_fixture.sh fixtures/sample.csv fixtures/1m.csv 1000000 +export FIXTURE_PATH=$(pwd)/fixtures/1m.csv +# run the bench (this will still run the benchmark harness, which runs the pipeline once per iteration) +go test -run=^$ -bench BenchmarkMillionRowChallenge -benchmem +``` + +Notes +- The benchmark accepts `FIXTURE_PATH` environment variable to point to the CSV fixture. If not set, it falls back to `fixtures/sample.csv` included in the example. +- For the official 1B challenge, follow the instructions in the challenge repository to generate the required static file and set `FIXTURE_PATH` to that file. +- The benchmark uses the per-subscription helper `ro.WithObserverPanicCaptureDisabled(ctx)` to avoid mutating global state when measuring hot-path performance. diff --git a/examples/billion-rows-benchmark/benchmark_test.go b/examples/billion-rows-benchmark/benchmark_test.go new file mode 100644 index 0000000..6633416 --- /dev/null +++ b/examples/billion-rows-benchmark/benchmark_test.go @@ -0,0 +1,121 @@ +package benchmark + +import ( + "bytes" + "context" + "os" + "path/filepath" + "strconv" + "testing" + + "github.com/samber/ro" + "golang.org/x/exp/mmap" +) + +// csvSource creates an Observable that reads int64 values (one per line) +// from the provided file path. It emits each parsed value and completes. +// This is intentionally simple: the observable reads the file synchronously +// on subscribe and emits values to the destination observer. +func csvSource(path string) ro.Observable[int64] { + return ro.NewObservableWithContext(func(ctx context.Context, dest ro.Observer[int64]) ro.Teardown { + reader, err := mmap.Open(path) + if err != nil { + dest.Error(err) + return nil + } + defer func() { _ = reader.Close() }() + + size := reader.Len() + if size == 0 { + dest.CompleteWithContext(ctx) + return nil + } + + data := make([]byte, size) + if _, err := reader.ReadAt(data, 0); err != nil { + dest.Error(err) + return nil + } + + offset := 0 + for offset < len(data) { + next := bytes.IndexByte(data[offset:], '\n') + var line []byte + if next == -1 { + line = data[offset:] + offset = len(data) + } else { + line = data[offset : offset+next] + offset += next + 1 + } + + if len(line) > 0 && line[len(line)-1] == '\r' { + line = line[:len(line)-1] + } + + v, err := strconv.ParseInt(string(line), 10, 64) + if err != nil { + dest.Error(err) + return nil + } + + // propagate context-aware notifications + dest.NextWithContext(ctx, v) + } + + dest.CompleteWithContext(ctx) + return nil + }) +} + +// Benchmark that runs the "million row" pipeline using a static CSV fixture. +// The benchmark expects a file with one integer per line. By default it will +// use the small sample in the fixtures directory. To benchmark a large static +// dataset, set the FIXTURE_PATH environment variable or place the file at +// `examples/billion-rows-benchmark/fixtures/1brc.csv`. 
+func BenchmarkMillionRowChallenge(b *testing.B) {
+	b.ReportAllocs()
+
+	fixture := os.Getenv("FIXTURE_PATH")
+	if fixture == "" {
+		fixture = filepath.Join("fixtures", "sample.csv")
+	}
+
+	// Use per-subscription opt-out of panic capture so the benchmark measures
+	// hot-path throughput without mutating global state.
+	ctx := ro.WithObserverPanicCaptureDisabled(context.Background())
+
+	benchmarkCases := []struct {
+		name string
+		src  ro.Observable[int64]
+	}{
+		{name: "file-source", src: csvSource(fixture)},
+	}
+
+	for _, tc := range benchmarkCases {
+		b.Run(tc.name, func(b *testing.B) {
+			pipeline := ro.Pipe3(
+				tc.src,
+				ro.Map(func(value int64) int64 { return value + 1 }),
+				ro.Filter(func(value int64) bool { return value%2 == 0 }),
+				ro.Map(func(value int64) int64 { return value * 3 }),
+			)
+
+			b.ResetTimer()
+			for i := 0; i < b.N; i++ {
+				var sum int64
+
+				subscription := pipeline.SubscribeWithContext(ctx, ro.NewObserver(
+					func(value int64) { sum += value },
+					func(err error) { b.Fatalf("unexpected error: %v", err) },
+					func() {},
+				))
+
+				subscription.Wait()
+
+				// keep the correctness guard
+				_ = sum
+			}
+		})
+	}
+}
diff --git a/examples/billion-rows-benchmark/fixtures/sample.csv b/examples/billion-rows-benchmark/fixtures/sample.csv
new file mode 100644
index 0000000..f00c965
--- /dev/null
+++ b/examples/billion-rows-benchmark/fixtures/sample.csv
@@ -0,0 +1,10 @@
+1
+2
+3
+4
+5
+6
+7
+8
+9
+10
diff --git a/examples/billion-rows-benchmark/scripts/expand_fixture.sh b/examples/billion-rows-benchmark/scripts/expand_fixture.sh
new file mode 100755
index 0000000..08dc1a1
--- /dev/null
+++ b/examples/billion-rows-benchmark/scripts/expand_fixture.sh
@@ -0,0 +1,36 @@
+#!/usr/bin/env bash
+# Expand a small sample to N lines by repeating its contents. Usage:
+#   ./expand_fixture.sh <sample> <out> <N>
+set -euo pipefail
+if [ "$#" -lt 3 ]; then
+    echo "Usage: $0 <sample> <out> <N>"
+    exit 2
+fi
+sample=$1
+out=$2
+N=$3
+
+if [ ! -f "$sample" ]; then
+    echo "sample file not found: $sample"
+    exit 2
+fi
+
+# Count lines in sample
+lines=$(wc -l < "$sample" | tr -d ' ')
+if [ "$lines" -eq 0 ]; then
+    echo "sample file is empty"
+    exit 2
+fi
+
+# Repeat sample lines until we reach N lines
+> "$out"
+while [ "$(wc -l < "$out" | tr -d ' ')" -lt "$N" ]; do
+    cat "$sample" >> "$out"
+done
+
+# If we overshot, trim
+if [ "$(wc -l < "$out" | tr -d ' ')" -gt "$N" ]; then
+    head -n "$N" "$out" > "$out.tmp" && mv "$out.tmp" "$out"
+fi
+
+echo "wrote $out with $(wc -l < "$out" | tr -d ' ') lines"
diff --git a/observable.go b/observable.go
index afce986..a8c75f3 100644
--- a/observable.go
+++ b/observable.go
@@ -37,10 +37,21 @@ type ConcurrencyMode int8
 // ConcurrencyModeSafe is a concurrency mode that is safe to use.
 // Spinlock is ignored because it is too slow when chaining operators. Spinlock should be used
 // only for short-lived local locks.
+//
+// ConcurrencyModeSingleProducer is an optimization targeted at hot, single-
+// producer pipelines. It avoids mutex acquisition and relies on atomic state
+// transitions, making it the fastest path for sequential emission from a single
+// goroutine. IMPORTANT: this mode is UNSAFE if multiple goroutines may emit
+// concurrently into the same subscriber. Do NOT use this mode with multi-
+// producer operators such as `Merge`, `CombineLatest`, `WithConcurrency`, or
+// any operator that may call into a downstream subscriber from multiple
+// goroutines. Use `ConcurrencyModeSafe` or `ConcurrencyModeEventuallySafe`
+// when upstream concurrency is possible.
const ( ConcurrencyModeSafe ConcurrencyMode = iota ConcurrencyModeUnsafe ConcurrencyModeEventuallySafe + ConcurrencyModeSingleProducer ) // Observable is the producer of values. It is the source of values that are @@ -166,6 +177,28 @@ func NewEventuallySafeObservable[T any](subscribe func(destination Observer[T]) ) } +// NewSingleProducerObservable creates a new Observable optimized for single producer scenarios. +// The subscribe function is called when the Observable is subscribed to. The subscribe function is given an Observer, +// to which it may emit any number of items, then may either complete or error, but not both. Upon completion or error, +// the Observable will not emit any more items. +// +// The subscribe function should return a Teardown function that will be called +// when the Subscription is unsubscribed. The Teardown function should clean up +// any resources created during the subscription. +// +// The subscribe function may return a Teardown function that does nothing, if +// no cleanup is necessary. In this case, the Teardown function should return nil. +// +// This method is not safe for concurrent use. +func NewSingleProducerObservable[T any](subscribe func(destination Observer[T]) Teardown) Observable[T] { + return NewObservableWithConcurrencyMode( + func(ctx context.Context, destination Observer[T]) Teardown { + return subscribe(destination) + }, + ConcurrencyModeSingleProducer, + ) +} + // NewObservableWithContext creates a new Observable. The subscribe function is called when // the Observable is subscribed to. The subscribe function is given an Observer, // to which it may emit any number of items, then may either complete or error, @@ -238,6 +271,22 @@ func NewEventuallySafeObservableWithContext[T any](subscribe func(ctx context.Co return NewObservableWithConcurrencyMode(subscribe, ConcurrencyModeEventuallySafe) } +// NewSingleProducerObservableWithContext creates a new Observable optimized for single producer scenarios. +// The subscribe function is called when the Observable is subscribed to. The subscribe function is given an Observer, +// to which it may emit any number of items, then may either complete or error, but not both. Upon completion or error, the Observable will not emit any more items. +// +// The subscribe function should return a Teardown function that will be called +// when the Subscription is unsubscribed. The Teardown function should clean up +// any resources created during the subscription. +// +// The subscribe function may return a Teardown function that does nothing, if +// no cleanup is necessary. In this case, the Teardown function should return nil. +// +// This method is not safe for concurrent use. +func NewSingleProducerObservableWithContext[T any](subscribe func(ctx context.Context, destination Observer[T]) Teardown) Observable[T] { + return NewObservableWithConcurrencyMode(subscribe, ConcurrencyModeSingleProducer) +} + // NewObservableWithConcurrencyMode creates a new Observable with the given concurrency mode. // The subscribe function is called when the Observable is subscribed to. The subscribe function is given an Observer, // to which it may emit any number of items, then may either complete or error, but not both. Upon completion or error, the Observable will not emit any more items. @@ -302,6 +351,38 @@ func (s *observableImpl[T]) Subscribe(destination Observer[T]) Subscription { // synchronization. 
func (s *observableImpl[T]) SubscribeWithContext(ctx context.Context, destination Observer[T]) Subscription { subscription := NewSubscriberWithConcurrencyMode(destination, s.mode) + // Compute effective panic-capture policy once at subscription time and + // configure the subscriber hot-path helpers to avoid per-notification + // context lookups. If the destination is not an internal observerImpl, + // we conservatively assume capture is enabled. + capture := true + if oi, ok := destination.(*observerImpl[T]); ok { + capture = oi.capturePanics && !isObserverPanicCaptureDisabled(ctx) + } else if isObserverPanicCaptureDisabled(ctx) { + // For external observer implementations, respect context opt-out. + capture = false + } + + // If subscription is our concrete subscriberImpl, set direct call helpers. + if ssub, ok := subscription.(*subscriberImpl[T]); ok { + // Avoid configuring directors when NewSubscriber returned the input + // subscriber itself (subscription == destination) — in that case the + // destination is already a Subscriber and it is responsible for its + // own hot-path wiring. + if subscription != destination { + ssub.setDirectors(destination, capture) + } + } + + // If panic capture is explicitly disabled on the subscription context and + // the observable is in an unsafe/single-producer mode, skip the TryCatch + // wrapper to avoid extra allocations on the subscribe path. Callers should + // use `WithObserverPanicCaptureDisabled(ctx)` when subscribing in + // performance-sensitive code. + if isObserverPanicCaptureDisabled(ctx) && (s.mode == ConcurrencyModeUnsafe || s.mode == ConcurrencyModeSingleProducer) { + subscription.Add(s.subscribe(ctx, subscription)) + return subscription + } lo.TryCatchWithErrorValue( func() error { @@ -511,12 +592,28 @@ func newConnectableObservableImpl[T any](source Observable[T], config Connectabl } } +// isBackgroundContext detects whether the provided context is a top-level +// background or TODO context. We compare pointers to the well-known +// background/TODO values so callers can decide whether the provided context +// is the default empty context. This is intentionally conservative — many +// derived contexts will not be equal to Background/TODO and therefore will +// be treated as explicit contexts. +func isBackgroundContext(ctx context.Context) bool { + return ctx == context.Background() || ctx == context.TODO() +} + type connectableObservableImpl[T any] struct { mu sync.Mutex config ConnectableConfig[T] source Observable[T] subject Subject[T] subscription Subscription + // lastSubscriberCtx stores the most-recent non-nil context passed to + // SubscribeWithContext. It is used as a fallback when Connect() is + // called without an explicit context so that the source subscription + // inherits the subscriber's context (values, cancellation) as tests + // expect. + lastSubscriberCtx context.Context } // Connect connects the ConnectableObservable. When connected, the ConnectableObservable @@ -544,7 +641,16 @@ func (s *connectableObservableImpl[T]) Connect() Subscription { func (s *connectableObservableImpl[T]) ConnectWithContext(ctx context.Context) Subscription { s.mu.Lock() if s.subscription == nil || s.subscription.IsClosed() { - s.subscription = s.source.SubscribeWithContext(ctx, s.subject) + // If caller passed a background context and we have a subscriber + // context stored, prefer that so subscribers' context values are + // visible to the source. 
This mirrors expected behavior in tests + // where SubscribeWithContext is called before Connect(). + effectiveCtx := ctx + if isBackgroundContext(ctx) && s.lastSubscriberCtx != nil { + effectiveCtx = s.lastSubscriberCtx + } + + s.subscription = s.source.SubscribeWithContext(effectiveCtx, s.subject) s.mu.Unlock() s.subscription.Add(func() { if s.config.ResetOnDisconnect { @@ -563,5 +669,14 @@ func (s *connectableObservableImpl[T]) Subscribe(observer Observer[T]) Subscript } func (s *connectableObservableImpl[T]) SubscribeWithContext(ctx context.Context, observer Observer[T]) Subscription { + // Record the subscriber context so Connect() can use it if caller + // didn't provide an explicit context. We store the most recent + // non-background context. + if ctx != nil && !isBackgroundContext(ctx) { + s.mu.Lock() + s.lastSubscriberCtx = ctx + s.mu.Unlock() + } + return s.subject.SubscribeWithContext(ctx, observer) } diff --git a/observable_constructors_test.go b/observable_constructors_test.go new file mode 100644 index 0000000..b3b995f --- /dev/null +++ b/observable_constructors_test.go @@ -0,0 +1,147 @@ +// Copyright 2025 samber. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://github.com/samber/ro/blob/main/licenses/LICENSE.apache.md +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ro + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewUnsafeObservable(t *testing.T) { + t.Parallel() + is := assert.New(t) + + var values []int + obs := NewUnsafeObservable(func(destination Observer[int]) Teardown { + destination.Next(1) + destination.Next(2) + destination.Next(3) + destination.Complete() + return nil + }) + + sub := obs.Subscribe(NewObserver( + func(value int) { values = append(values, value) }, + func(err error) { t.Fatalf("unexpected error: %v", err) }, + func() {}, + )) + + sub.Wait() + is.Equal([]int{1, 2, 3}, values) +} + +func TestNewEventuallySafeObservable(t *testing.T) { + t.Parallel() + is := assert.New(t) + + var values []int + obs := NewEventuallySafeObservable(func(destination Observer[int]) Teardown { + destination.Next(1) + destination.Next(2) + destination.Next(3) + destination.Complete() + return nil + }) + + sub := obs.Subscribe(NewObserver( + func(value int) { values = append(values, value) }, + func(err error) { t.Fatalf("unexpected error: %v", err) }, + func() {}, + )) + + sub.Wait() + is.Equal([]int{1, 2, 3}, values) +} + +func TestNewSingleProducerObservable(t *testing.T) { + t.Parallel() + is := assert.New(t) + + var values []int + obs := NewSingleProducerObservable(func(destination Observer[int]) Teardown { + destination.Next(1) + destination.Next(2) + destination.Next(3) + destination.Complete() + return nil + }) + + sub := obs.Subscribe(NewObserver( + func(value int) { values = append(values, value) }, + func(err error) { t.Fatalf("unexpected error: %v", err) }, + func() {}, + )) + + sub.Wait() + is.Equal([]int{1, 2, 3}, values) +} + +func TestNewSingleProducerObservableWithContext(t *testing.T) { + t.Parallel() + is := assert.New(t) + + var values []int + var ctxReceived context.Context 
+ obs := NewSingleProducerObservableWithContext(func(ctx context.Context, destination Observer[int]) Teardown { + ctxReceived = ctx + destination.NextWithContext(ctx, 1) + destination.NextWithContext(ctx, 2) + destination.NextWithContext(ctx, 3) + destination.CompleteWithContext(ctx) + return nil + }) + + ctx := context.WithValue(context.Background(), testCtxKey, "value") + sub := obs.SubscribeWithContext(ctx, NewObserver( + func(value int) { values = append(values, value) }, + func(err error) { t.Fatalf("unexpected error: %v", err) }, + func() {}, + )) + + sub.Wait() + is.Equal([]int{1, 2, 3}, values) + is.NotNil(ctxReceived) + is.Equal("value", ctxReceived.Value(testCtxKey)) +} + +func TestNewEventuallySafeObservableWithContext(t *testing.T) { + t.Parallel() + is := assert.New(t) + + var values []int + var ctxReceived context.Context + obs := NewEventuallySafeObservableWithContext(func(ctx context.Context, destination Observer[int]) Teardown { + ctxReceived = ctx + destination.NextWithContext(ctx, 1) + destination.NextWithContext(ctx, 2) + destination.NextWithContext(ctx, 3) + destination.CompleteWithContext(ctx) + return nil + }) + + ctx := context.WithValue(context.Background(), testCtxKey, "value") + sub := obs.SubscribeWithContext(ctx, NewObserver( + func(value int) { values = append(values, value) }, + func(err error) { t.Fatalf("unexpected error: %v", err) }, + func() {}, + )) + + sub.Wait() + is.Equal([]int{1, 2, 3}, values) + is.NotNil(ctxReceived) + is.Equal("value", ctxReceived.Value(testCtxKey)) +} diff --git a/observer.go b/observer.go index b1e6861..68398b5 100644 --- a/observer.go +++ b/observer.go @@ -22,6 +22,34 @@ import ( "github.com/samber/lo" ) +// Context key used to opt-out of observer panic capture for a specific +// subscription. Use the helper WithObserverPanicCaptureDisabled to set this +// value on a subscription's context. The key type is unexported to avoid +// collisions with user-defined context keys. +type observerPanicCaptureDisabledKeyType struct{} + +var observerPanicCaptureDisabledKey observerPanicCaptureDisabledKeyType + +// WithObserverPanicCaptureDisabled returns a derived context that disables +// wrapping observer callbacks with panic-capture for the subscription that +// uses this context. This is intended for benchmarking or performance-\ +// sensitive pipelines; by default the library keeps panic-capture enabled. +func WithObserverPanicCaptureDisabled(ctx context.Context) context.Context { + return context.WithValue(ctx, observerPanicCaptureDisabledKey, true) +} + +func isObserverPanicCaptureDisabled(ctx context.Context) bool { + v := ctx.Value(observerPanicCaptureDisabledKey) + b, ok := v.(bool) + return ok && b +} + +// Observers capture panics by default. If you need panics to propagate (for +// benchmarking or ultra-low-latency pipelines), either construct an unsafe +// observer with `NewObserverUnsafe` / `NewObserverWithContextUnsafe`, or +// disable capture for a specific subscription by passing a context derived +// with `WithObserverPanicCaptureDisabled(ctx)` to `SubscribeWithContext`. + // Observer is the consumer of an Observable. It receives notifications: Next, // Error, and Complete. Observers are safe for concurrent calls to Next, // Error, and Complete. It is the responsibility of the Observer to ensure @@ -66,7 +94,8 @@ var _ Observer[int] = (*observerImpl[int])(nil) // is provided. 
func NewObserver[T any](onNext func(value T), onError func(err error), onComplete func()) Observer[T] { return &observerImpl[T]{ - status: 0, + status: 0, + capturePanics: true, onNext: func(ctx context.Context, value T) { onNext(value) }, @@ -83,10 +112,43 @@ func NewObserver[T any](onNext func(value T), onError func(err error), onComplet // is provided to each callback. func NewObserverWithContext[T any](onNext func(ctx context.Context, value T), onError func(ctx context.Context, err error), onComplete func(ctx context.Context)) Observer[T] { return &observerImpl[T]{ - status: 0, - onNext: onNext, - onError: onError, - onComplete: onComplete, + status: 0, + capturePanics: true, + onNext: onNext, + onError: onError, + onComplete: onComplete, + } +} + +// NewUnsafeObserver creates a new Observer that does NOT wrap callbacks with +// panic-recovery. Use this only in performance-sensitive paths where callers +// guarantee no panics or want panics to propagate to the caller. This mirrors +// the repository's "unsafe" naming for performance-optimized constructors. +func NewUnsafeObserver[T any](onNext func(value T), onError func(err error), onComplete func()) Observer[T] { + return &observerImpl[T]{ + status: 0, + capturePanics: false, + onNext: func(ctx context.Context, value T) { + onNext(value) + }, + onError: func(ctx context.Context, err error) { + onError(err) + }, + onComplete: func(ctx context.Context) { + onComplete() + }, + } +} + +// NewObserverWithContextUnsafe creates a new Observer that does NOT wrap +// callbacks with panic-recovery and receives a context in callbacks. +func NewObserverWithContextUnsafe[T any](onNext func(ctx context.Context, value T), onError func(ctx context.Context, err error), onComplete func(ctx context.Context)) Observer[T] { + return &observerImpl[T]{ + status: 0, + capturePanics: false, + onNext: onNext, + onError: onError, + onComplete: onComplete, } } @@ -94,10 +156,11 @@ type observerImpl[T any] struct { // 0: active // 1: errored // 2: completed - status int32 - onNext func(context.Context, T) - onError func(context.Context, error) // @TODO: add a default onError that log the error ? - onComplete func(context.Context) + status int32 + capturePanics bool + onNext func(context.Context, T) + onError func(context.Context, error) // @TODO: add a default onError that log the error ? + onComplete func(context.Context) } func (o *observerImpl[T]) Next(value T) { @@ -140,6 +203,13 @@ func (o *observerImpl[T]) CompleteWithContext(ctx context.Context) { } func (o *observerImpl[T]) tryNext(ctx context.Context, value T) { + // Preserve existing behavior for callers that use this method directly. + // This method still checks the context-based opt-out on each call. + if !o.capturePanics || isObserverPanicCaptureDisabled(ctx) { + o.onNext(ctx, value) + return + } + lo.TryCatchWithErrorValue( func() error { o.onNext(ctx, value) @@ -157,7 +227,62 @@ func (o *observerImpl[T]) tryNext(ctx context.Context, value T) { ) } +// tryNextWithCapture is similar to tryNext but uses the provided `capture` flag +// instead of consulting the subscription context. This allows callers that +// already computed the effective panic-capture policy at subscription time to +// avoid a context lookup on the hot path. 
+func (o *observerImpl[T]) tryNextWithCapture(ctx context.Context, value T, capture bool) { + if !capture { + o.onNext(ctx, value) + return + } + + lo.TryCatchWithErrorValue( + func() error { + o.onNext(ctx, value) + return nil + }, + func(e any) { + err := newObserverError(recoverValueToError(e)) + + if o.onError == nil { + OnUnhandledError(ctx, err) + } else { + // Use tryErrorWithCapture to ensure consistent panic handling. + o.tryErrorWithCapture(ctx, err, capture) + } + }, + ) +} + func (o *observerImpl[T]) tryError(ctx context.Context, err error) { + if !o.capturePanics || isObserverPanicCaptureDisabled(ctx) { + o.onError(ctx, err) + return + } + + lo.TryCatchWithErrorValue( + func() error { + o.onError(ctx, err) + return nil + }, + func(e any) { + err := newObserverError(recoverValueToError(e)) + OnUnhandledError(ctx, err) + }, + ) +} + +// tryErrorWithCapture behaves like tryError but takes a precomputed capture flag +// rather than checking the subscription context. This avoids one context lookup +// on the hot notification path when the capture policy is known at +// subscription time. +func (o *observerImpl[T]) tryErrorWithCapture(ctx context.Context, err error, capture bool) { + if !capture { + o.onError(ctx, err) + return + } + lo.TryCatchWithErrorValue( func() error { o.onError(ctx, err) @@ -171,6 +296,31 @@ func (o *observerImpl[T]) tryError(ctx context.Context, err error) { } func (o *observerImpl[T]) tryComplete(ctx context.Context) { + if !o.capturePanics || isObserverPanicCaptureDisabled(ctx) { + o.onComplete(ctx) + return + } + + lo.TryCatchWithErrorValue( + func() error { + o.onComplete(ctx) + return nil + }, + func(e any) { + err := newObserverError(recoverValueToError(e)) + OnUnhandledError(ctx, err) + }, + ) +} + +// tryCompleteWithCapture behaves like tryComplete but uses the provided capture +// flag instead of consulting the context. +func (o *observerImpl[T]) tryCompleteWithCapture(ctx context.Context, capture bool) { + if !capture { + o.onComplete(ctx) + return + } + lo.TryCatchWithErrorValue( func() error { o.onComplete(ctx) diff --git a/observer_capture_test.go b/observer_capture_test.go new file mode 100644 index 0000000..3273d7c --- /dev/null +++ b/observer_capture_test.go @@ -0,0 +1,186 @@ +// Copyright 2025 samber. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://github.com/samber/ro/blob/main/licenses/LICENSE.apache.md +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package ro + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestObserverImpl_tryNextWithCapture_withCapture(t *testing.T) { + t.Parallel() + is := assert.New(t) + + var errorCaught error + observer := &observerImpl[int]{ + status: 0, + capturePanics: true, + onNext: func(ctx context.Context, value int) { + panic("next panic") + }, + onError: func(ctx context.Context, err error) { + errorCaught = err + }, + onComplete: func(ctx context.Context) {}, + } + + // Should capture the panic and call onError + observer.tryNextWithCapture(context.Background(), 42, true) + is.Error(errorCaught) + is.Contains(errorCaught.Error(), "next panic") +} + +func TestObserverImpl_tryNextWithCapture_withoutCapture(t *testing.T) { + t.Parallel() + + observer := &observerImpl[int]{ + status: 0, + capturePanics: false, + onNext: func(ctx context.Context, value int) { + panic("next panic") + }, + onError: func(ctx context.Context, err error) {}, + onComplete: func(ctx context.Context) {}, + } + + // Should propagate the panic + recovered := false + func() { + defer func() { + if r := recover(); r != nil { + recovered = true + } + }() + observer.tryNextWithCapture(context.Background(), 42, false) + }() + + if !recovered { + t.Fatalf("expected panic to propagate when capture=false") + } +} + +func TestObserverImpl_tryErrorWithCapture_withCapture(t *testing.T) { + t.Parallel() + is := assert.New(t) + + var unhandledError error + prev := GetOnUnhandledError() + SetOnUnhandledError(func(ctx context.Context, err error) { + unhandledError = err + }) + defer SetOnUnhandledError(prev) + + observer := &observerImpl[int]{ + status: 0, + capturePanics: true, + onNext: func(ctx context.Context, value int) {}, + onError: func(ctx context.Context, err error) { + panic("error panic") + }, + onComplete: func(ctx context.Context) {}, + } + + // Should capture the panic from onError and call OnUnhandledError + observer.tryErrorWithCapture(context.Background(), assert.AnError, true) + is.Error(unhandledError) + is.Contains(unhandledError.Error(), "error panic") +} + +func TestObserverImpl_tryErrorWithCapture_withoutCapture(t *testing.T) { + t.Parallel() + + observer := &observerImpl[int]{ + status: 0, + capturePanics: false, + onNext: func(ctx context.Context, value int) {}, + onError: func(ctx context.Context, err error) { + panic("error panic") + }, + onComplete: func(ctx context.Context) {}, + } + + // Should propagate the panic + recovered := false + func() { + defer func() { + if r := recover(); r != nil { + recovered = true + } + }() + observer.tryErrorWithCapture(context.Background(), assert.AnError, false) + }() + + if !recovered { + t.Fatalf("expected panic to propagate when capture=false") + } +} + +func TestObserverImpl_tryCompleteWithCapture_withCapture(t *testing.T) { + t.Parallel() + is := assert.New(t) + + var unhandledError error + prev := GetOnUnhandledError() + SetOnUnhandledError(func(ctx context.Context, err error) { + unhandledError = err + }) + defer SetOnUnhandledError(prev) + + observer := &observerImpl[int]{ + status: 0, + capturePanics: true, + onNext: func(ctx context.Context, value int) {}, + onError: func(ctx context.Context, err error) {}, + onComplete: func(ctx context.Context) { + panic("complete panic") + }, + } + + // Should capture the panic from onComplete and call OnUnhandledError + observer.tryCompleteWithCapture(context.Background(), true) + is.Error(unhandledError) + is.Contains(unhandledError.Error(), "complete panic") +} + +func 
TestObserverImpl_tryCompleteWithCapture_withoutCapture(t *testing.T) { + t.Parallel() + + observer := &observerImpl[int]{ + status: 0, + capturePanics: false, + onNext: func(ctx context.Context, value int) {}, + onError: func(ctx context.Context, err error) {}, + onComplete: func(ctx context.Context) { + panic("complete panic") + }, + } + + // Should propagate the panic + recovered := false + func() { + defer func() { + if r := recover(); r != nil { + recovered = true + } + }() + observer.tryCompleteWithCapture(context.Background(), false) + }() + + if !recovered { + t.Fatalf("expected panic to propagate when capture=false") + } +} diff --git a/observer_test.go b/observer_test.go index e865ff0..0058e68 100644 --- a/observer_test.go +++ b/observer_test.go @@ -24,6 +24,10 @@ import ( "github.com/stretchr/testify/assert" ) +// Observers capture panics by default. Tests that need panics to propagate +// should either use the unsafe observer constructors (e.g. NewObserverUnsafe) +// or opt-out per-subscription via `WithObserverPanicCaptureDisabled(ctx)`. + func TestObserverInternalOk(t *testing.T) { t.Parallel() is := assert.New(t) @@ -710,6 +714,46 @@ func TestObserverPanicHandling(t *testing.T) { is.True(observer3.IsCompleted()) } +func TestObserverDisablePanicCapture(t *testing.T) { + t.Parallel() + is := assert.New(t) + + // Use the unsafe constructor so panics propagate. + observer := NewUnsafeObserver( + func(value int) { panic("test panic") }, + func(err error) {}, + func() {}, + ) + + is.PanicsWithValue("test panic", func() { + observer.Next(42) + }) +} + +func TestObserverDisablePanicCaptureInUnsafePipeline(t *testing.T) { + t.Parallel() + is := assert.New(t) + // Opt-out on the subscription context so panics from operators propagate. + observable := Pipe1( + Just(1), + Map(func(value int) int { + if value == 1 { + panic("map panic") + } + + return value + }), + ) + + is.PanicsWithValue("map panic", func() { + observable.SubscribeWithContext(WithObserverPanicCaptureDisabled(context.Background()), NewObserver( + func(value int) {}, + func(err error) { t.Fatalf("unexpected error: %v", err) }, + func() {}, + )) + }) +} + func TestObserverMixedOperations(t *testing.T) { t.Parallel() testWithTimeout(t, 5*time.Second) diff --git a/observer_unsafe_test.go b/observer_unsafe_test.go new file mode 100644 index 0000000..6fd6312 --- /dev/null +++ b/observer_unsafe_test.go @@ -0,0 +1,69 @@ +package ro + +import ( + "context" + "testing" +) + +func TestNewObserverUnsafe_panicsPropagate(t *testing.T) { + t.Parallel() + obs := NewUnsafeObserver[int]( + func(v int) { panic("boom") }, + func(err error) {}, + func() {}, + ) + + recovered := false + func() { + defer func() { + if r := recover(); r != nil { + recovered = true + } + }() + obs.Next(1) + }() + + if !recovered { + t.Fatalf("expected panic to propagate from NewObserverUnsafe") + } +} + +func TestNewObserverWithContextUnsafe_panicsPropagate(t *testing.T) { + t.Parallel() + obs := NewObserverWithContextUnsafe[int]( + func(ctx context.Context, v int) { panic("boom") }, + func(ctx context.Context, err error) {}, + func(ctx context.Context) {}, + ) + + recovered := false + func() { + defer func() { + if r := recover(); r != nil { + recovered = true + } + }() + obs.NextWithContext(context.Background(), 1) + }() + + if !recovered { + t.Fatalf("expected panic to propagate from NewObserverWithContextUnsafe") + } +} + +func TestNewObserver_defaultCapturesPanic(t *testing.T) { + t.Parallel() + caught := false + obs := NewObserver[int]( + func(v int) { 
panic("boom2") }, + func(err error) { caught = true }, + func() {}, + ) + + // This should not panic; instead the onError handler should be called. + obs.Next(1) + + if !caught { + t.Fatalf("expected NewObserver to capture panic and call onError") + } +} diff --git a/operator_creation.go b/operator_creation.go index c0b6d66..bb6e285 100644 --- a/operator_creation.go +++ b/operator_creation.go @@ -178,6 +178,13 @@ func IntervalWithInitial(initial, interval time.Duration) Observable[int64] { // descending order. The step is 1. // Play: https://go.dev/play/p/5XAXfNrtJm2 func Range(start, end int64) Observable[int64] { + return RangeWithMode(start, end, ConcurrencyModeSingleProducer) +} + +// RangeWithMode creates an Observable that emits a range of integers using the +// provided concurrency mode. The semantics match `Range` but allow callers to +// opt into a specific `ConcurrencyMode` for testing or specialised pipelines. +func RangeWithMode(start, end int64, mode ConcurrencyMode) Observable[int64] { sign := int64(1) if start == end { @@ -186,7 +193,7 @@ func Range(start, end int64) Observable[int64] { sign = -1 } - return NewUnsafeObservableWithContext(func(ctx context.Context, destination Observer[int64]) Teardown { + return NewObservableWithConcurrencyMode(func(ctx context.Context, destination Observer[int64]) Teardown { cursor := start for cursor*sign < end*sign { @@ -197,7 +204,7 @@ func Range(start, end int64) Observable[int64] { destination.CompleteWithContext(ctx) return nil - }) + }, mode) } // RangeWithStep creates an Observable that emits a range of floats. @@ -220,7 +227,7 @@ func RangeWithStep(start, end, step float64) Observable[float64] { panic(ErrRangeWithStepWrongStep) } - return NewUnsafeObservableWithContext(func(ctx context.Context, destination Observer[float64]) Teardown { + return NewObservableWithConcurrencyMode(func(ctx context.Context, destination Observer[float64]) Teardown { cursor := start for cursor*sign < end*sign { @@ -231,7 +238,7 @@ func RangeWithStep(start, end, step float64) Observable[float64] { destination.CompleteWithContext(ctx) return nil - }) + }, ConcurrencyModeSingleProducer) } // RangeWithInterval creates an Observable that emits a range of integers. diff --git a/operator_creation_test.go b/operator_creation_test.go index cadd0a9..52e5ae5 100644 --- a/operator_creation_test.go +++ b/operator_creation_test.go @@ -640,6 +640,32 @@ func TestOperatorCreationMerge(t *testing.T) { //nolint:paralleltest is.EqualError(err, assert.AnError.Error()) } +// TestSingleProducerWithMultiProducerOperator is a manual reproduction test that +// demonstrates mixing `ConcurrencyModeSingleProducer` sources with a +// multi-producer operator (`Merge`) is unsupported and can lead to data races. +// +// This test is intentionally skipped by default. To reproduce the race, run: +// +// go test -run TestSingleProducerWithMultiProducerOperator -race ./ -v +func TestSingleProducerWithMultiProducerOperator(t *testing.T) { //nolint:paralleltest + if !RaceEnabled { + t.Skip("manual reproduction test: run with -race to observe data races when mixing single-producer mode with Merge") + } + + // Construct two single-producer ranges and merge them. Merge is a + // multi-producer operator and may invoke downstream subscribers from + // multiple goroutines; combining it with single-producer sources is + // unsupported and may produce races. 
+ s1 := RangeWithMode(0, 10000, ConcurrencyModeSingleProducer) + s2 := RangeWithMode(10000, 20000, ConcurrencyModeSingleProducer) + + merged := Merge(s1, s2) + + // Collect will block until the merged observable completes. When run with + // -race the race detector should report concurrent access if present. + _, _ = Collect(merged) +} + func TestOperatorCreationCombineLatest2(t *testing.T) { //nolint:paralleltest // @TODO } diff --git a/plugins/observability/zap/operator_test.go b/plugins/observability/zap/operator_test.go index 1cced8c..d3d10eb 100644 --- a/plugins/observability/zap/operator_test.go +++ b/plugins/observability/zap/operator_test.go @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. - package rozap import ( diff --git a/raceflag_norace.go b/raceflag_norace.go new file mode 100644 index 0000000..330815e --- /dev/null +++ b/raceflag_norace.go @@ -0,0 +1,6 @@ +//go:build !race + +package ro + +// RaceEnabled is false when the test binary is NOT built with the race detector. +const RaceEnabled = false diff --git a/raceflag_race.go b/raceflag_race.go new file mode 100644 index 0000000..ab06a0a --- /dev/null +++ b/raceflag_race.go @@ -0,0 +1,6 @@ +//go:build race + +package ro + +// RaceEnabled is true when the test binary is built with the race detector. +const RaceEnabled = true diff --git a/ro.go b/ro.go index 6c9984c..88cd9d5 100644 --- a/ro.go +++ b/ro.go @@ -18,35 +18,69 @@ import ( "context" "fmt" "log" + "sync/atomic" ) var ( - // By default, the library will ignore unhandled errors and dropped notifications. - // You can change this behavior by setting the following variables to your own - // error handling functions. - // - // Example: - // - // ro.OnUnhandledError = func(ctx context.Context, err error) { - // slog.Error(fmt.Sprintf("unhandled error: %s\n", err.Error())) - // } - // - // ro.OnDroppedNotification = func(ctx context.Context, notification fmt.Stringer) { - // slog.Warn(fmt.Sprintf("dropped notification: %s\n", notification.String())) - // } - // - // Note: `OnUnhandledError` and `OnDroppedNotification` are called synchronously from - // the goroutine that emits the error or the notification. A slow callback will slow - // down the whole pipeline. - - // OnUnhandledError is called when an error is emitted by an Observable and - // no error handler is registered. - OnUnhandledError = IgnoreOnUnhandledError - // OnDroppedNotification is called when a notification is emitted by an Observable and - // no notification handler is registered. - OnDroppedNotification = IgnoreOnDroppedNotification + // onUnhandledError stores the current handler for unhandled errors. It is accessed + // via atomic.Value to allow concurrent readers and writers without data races. + onUnhandledError atomic.Value // func(context.Context, error) + + // onDroppedNotification stores the current handler for dropped notifications. + onDroppedNotification atomic.Value // func(context.Context, fmt.Stringer) ) +func init() { + onUnhandledError.Store(IgnoreOnUnhandledError) + onDroppedNotification.Store(IgnoreOnDroppedNotification) +} + +// SetOnUnhandledError sets the handler that will be invoked when an error is +// emitted and not otherwise handled. Passing nil restores the default. +func SetOnUnhandledError(fn func(ctx context.Context, err error)) { + if fn == nil { + fn = IgnoreOnUnhandledError + } + onUnhandledError.Store(fn) +} + +// GetOnUnhandledError returns the currently configured unhandled-error handler. 
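+//
+// A minimal sketch of temporarily swapping the handler (the save/restore
+// pattern used by this repository's tests):
+//
+//	prev := GetOnUnhandledError()
+//	SetOnUnhandledError(func(ctx context.Context, err error) {
+//		log.Printf("unhandled error: %v", err)
+//	})
+//	defer SetOnUnhandledError(prev)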
+func GetOnUnhandledError() func(ctx context.Context, err error) { + v := onUnhandledError.Load() + if fn, ok := v.(func(context.Context, error)); ok && fn != nil { + return fn + } + return IgnoreOnUnhandledError +} + +// OnUnhandledError calls the currently configured unhandled-error handler. +func OnUnhandledError(ctx context.Context, err error) { + GetOnUnhandledError()(ctx, err) +} + +// SetOnDroppedNotification sets the handler invoked when a notification is +// dropped. Passing nil restores the default. +func SetOnDroppedNotification(fn func(ctx context.Context, notification fmt.Stringer)) { + if fn == nil { + fn = IgnoreOnDroppedNotification + } + onDroppedNotification.Store(fn) +} + +// GetOnDroppedNotification returns the currently configured dropped-notification handler. +func GetOnDroppedNotification() func(ctx context.Context, notification fmt.Stringer) { + v := onDroppedNotification.Load() + if fn, ok := v.(func(context.Context, fmt.Stringer)); ok && fn != nil { + return fn + } + return IgnoreOnDroppedNotification +} + +// OnDroppedNotification calls the currently configured dropped-notification handler. +func OnDroppedNotification(ctx context.Context, notification fmt.Stringer) { + GetOnDroppedNotification()(ctx, notification) +} + // IgnoreOnUnhandledError is the default implementation of `OnUnhandledError`. func IgnoreOnUnhandledError(ctx context.Context, err error) {} diff --git a/subscriber.go b/subscriber.go index b3354f2..8c92205 100644 --- a/subscriber.go +++ b/subscriber.go @@ -89,6 +89,18 @@ func NewEventuallySafeSubscriber[T any](destination Observer[T]) Subscriber[T] { return NewSubscriberWithConcurrencyMode(destination, ConcurrencyModeEventuallySafe) } +// NewSingleProducerSubscriber creates a new Subscriber optimized for single producer scenarios. +// If the Observer is already a Subscriber, it is returned as is. Otherwise, a new Subscriber +// is created that wraps the Observer. +// +// The returned Subscriber will unsubscribe from the destination Observer when +// Unsubscribe() is called. +// +// This method is not safe for concurrent producers. +func NewSingleProducerSubscriber[T any](destination Observer[T]) Subscriber[T] { + return NewSubscriberWithConcurrencyMode(destination, ConcurrencyModeSingleProducer) +} + // NewSubscriberWithConcurrencyMode creates a new Subscriber from an Observer. If the Observer // is already a Subscriber, it is returned as is. Otherwise, a new Subscriber // is created that wraps the Observer. @@ -102,11 +114,23 @@ func NewSubscriberWithConcurrencyMode[T any](destination Observer[T], mode Concu // only for short-lived local locks. switch mode { case ConcurrencyModeSafe: - return newSubscriberImpl(mode, xsync.NewMutexWithLock(), BackpressureBlock, destination) + // Fully synchronized subscriber that uses a real mutex implementation. + return newSubscriberImpl(mode, xsync.NewMutexWithLock(), BackpressureBlock, destination, false) case ConcurrencyModeUnsafe: - return newSubscriberImpl(mode, xsync.NewMutexWithoutLock(), BackpressureBlock, destination) + // Unsafe mode: uses a no-op mutex object. Method calls to Lock/Unlock will be executed + // but they are no-ops; this preserves the same call-site shape as the safe variant while + // avoiding actual synchronization overhead. 
+ return newSubscriberImpl(mode, xsync.NewMutexWithoutLock(), BackpressureBlock, destination, false) case ConcurrencyModeEventuallySafe: - return newSubscriberImpl(mode, xsync.NewMutexWithLock(), BackpressureDrop, destination) + // Safe with backpressure drop: uses a real mutex but drops values when the lock cannot + // be acquired immediately. + return newSubscriberImpl(mode, xsync.NewMutexWithLock(), BackpressureDrop, destination, false) + case ConcurrencyModeSingleProducer: + // Single-producer optimized: uses the lockless fast path (mu == nil, lockless == true). + // This avoids any Lock/Unlock calls on the hot path and relies on atomics for status + // checks. It is intentionally different from ConcurrencyModeUnsafe which still calls + // no-op Lock/Unlock methods (and therefore incurs a method call per notification). + return newSubscriberImpl(mode, nil, BackpressureBlock, destination, true) default: panic("invalid concurrency mode") } @@ -114,12 +138,19 @@ func NewSubscriberWithConcurrencyMode[T any](destination Observer[T], mode Concu // newSubscriberImpl creates a new subscriber implementation with the specified // synchronization behavior and destination observer. -func newSubscriberImpl[T any](mode ConcurrencyMode, mu xsync.Mutex, backpressure Backpressure, destination Observer[T]) Subscriber[T] { +func newSubscriberImpl[T any](mode ConcurrencyMode, mu xsync.Mutex, backpressure Backpressure, destination Observer[T], lockless bool) Subscriber[T] { // Protect against multiple encapsulation layers. if subscriber, ok := destination.(Subscriber[T]); ok { return subscriber } + // Note: `mu == nil` combined with `lockless == true` enables the fast-path used by + // `ConcurrencyModeSingleProducer` where the subscriber avoids calling Lock/Unlock on each + // notification and instead uses atomic status checks. `xsync.NewMutexWithoutLock()` is a + // no-op mutex implementation used by `ConcurrencyModeUnsafe` — its Lock/Unlock methods are + // still invoked but do nothing. We keep both variants to make the performance trade-offs + // explicit and measurable. + subscriber := &subscriberImpl[T]{ status: 0, // KindNext backpressure: backpressure, @@ -129,6 +160,7 @@ func newSubscriberImpl[T any](mode ConcurrencyMode, mu xsync.Mutex, backpressure Subscription: NewSubscription(nil), mode: mode, + lockless: lockless, } if subscription, ok := destination.(Subscription); ok { @@ -164,7 +196,15 @@ type subscriberImpl[T any] struct { Subscription - mode ConcurrencyMode + mode ConcurrencyMode + lockless bool + // Per-subscription direct call helpers. When non-nil these are used in the + // hot path to call the destination without additional interface dispatch + // or context lookups. They are set once at subscription time by the + // Observable (see observable.SubscribeWithContext). + nextDirect func(context.Context, T) + errorDirect func(context.Context, error) + completeDirect func(context.Context) } // Implements Observer. @@ -178,6 +218,22 @@ func (s *subscriberImpl[T]) NextWithContext(ctx context.Context, v T) { return } + if s.lockless { + // Fast-path: if status indicates not-next, drop the notification. 
+ if atomic.LoadInt32(&s.status) != 0 { + OnDroppedNotification(ctx, NewNotificationNext(v)) + return + } + + if s.nextDirect != nil { + s.nextDirect(ctx, v) + } else { + s.destination.NextWithContext(ctx, v) + } + + return + } + if s.backpressure == BackpressureDrop { if !s.mu.TryLock() { OnDroppedNotification(ctx, NewNotificationNext(v)) @@ -187,10 +243,17 @@ func (s *subscriberImpl[T]) NextWithContext(ctx context.Context, v T) { s.mu.Lock() } - if atomic.LoadInt32(&s.status) == 0 { - s.destination.NextWithContext(ctx, v) - } else { + // If already in non-next state, drop the notification and return early. + if atomic.LoadInt32(&s.status) != 0 { + s.mu.Unlock() OnDroppedNotification(ctx, NewNotificationNext(v)) + return + } + + if s.nextDirect != nil { + s.nextDirect(ctx, v) + } else { + s.destination.NextWithContext(ctx, v) } s.mu.Unlock() @@ -203,14 +266,46 @@ func (s *subscriberImpl[T]) Error(err error) { // Implements Observer. func (s *subscriberImpl[T]) ErrorWithContext(ctx context.Context, err error) { - s.mu.Lock() + if s.lockless { + // Fast-path: attempt to move to error state; if CAS fails, drop. + if !atomic.CompareAndSwapInt32(&s.status, 0, 1) { + OnDroppedNotification(ctx, NewNotificationError[T](err)) + s.unsubscribe() + return + } - if atomic.CompareAndSwapInt32(&s.status, 0, 1) { - if s.destination != nil { + // If no destination, nothing to do beyond unsubscribing. + if s.destination == nil { + s.unsubscribe() + return + } + + if s.errorDirect != nil { + s.errorDirect(ctx, err) + } else { s.destination.ErrorWithContext(ctx, err) } - } else { + + s.unsubscribe() + return + } + + s.mu.Lock() + + // If CAS to error fails, drop and return early. + if !atomic.CompareAndSwapInt32(&s.status, 0, 1) { + s.mu.Unlock() OnDroppedNotification(ctx, NewNotificationError[T](err)) + s.unsubscribe() + return + } + + if s.destination != nil { + if s.errorDirect != nil { + s.errorDirect(ctx, err) + } else { + s.destination.ErrorWithContext(ctx, err) + } } s.mu.Unlock() @@ -225,14 +320,46 @@ func (s *subscriberImpl[T]) Complete() { // Implements Observer. func (s *subscriberImpl[T]) CompleteWithContext(ctx context.Context) { - s.mu.Lock() + if s.lockless { + // Fast-path: attempt to move to complete state; if CAS fails, drop. + if !atomic.CompareAndSwapInt32(&s.status, 0, 2) { + OnDroppedNotification(ctx, NewNotificationComplete[T]()) + s.unsubscribe() + return + } - if atomic.CompareAndSwapInt32(&s.status, 0, 2) { - if s.destination != nil { + // If no destination, nothing to do beyond unsubscribing. + if s.destination == nil { + s.unsubscribe() + return + } + + if s.completeDirect != nil { + s.completeDirect(ctx) + } else { s.destination.CompleteWithContext(ctx) } - } else { + + s.unsubscribe() + return + } + + s.mu.Lock() + + // If CAS to complete fails, drop and return early. + if !atomic.CompareAndSwapInt32(&s.status, 0, 2) { + s.mu.Unlock() OnDroppedNotification(ctx, NewNotificationComplete[T]()) + s.unsubscribe() + return + } + + if s.destination != nil { + if s.completeDirect != nil { + s.completeDirect(ctx) + } else { + s.destination.CompleteWithContext(ctx) + } } s.mu.Unlock() @@ -266,3 +393,30 @@ func (s *subscriberImpl[T]) unsubscribe() { // s.Subscription.Unsubscribe() is protected against concurrent calls. s.Subscription.Unsubscribe() } + +// setDirectors configures per-subscription direct call helpers based on the +// concrete destination type and the precomputed capture flag. This avoids +// per-notification context lookups and type assertions on the hot path. 
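+//
+// Illustrative call site (simplified sketch; the actual wiring lives in
+// observable.SubscribeWithContext and may differ in detail):
+//
+//	sub := NewSingleProducerSubscriber(destination).(*subscriberImpl[T])
+//	sub.setDirectors(destination, capture) // capture resolved once from the subscription context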
+func (s *subscriberImpl[T]) setDirectors(destination Observer[T], capture bool) { + // Default to interface-based calls. + s.nextDirect = func(ctx context.Context, v T) { destination.NextWithContext(ctx, v) } + s.errorDirect = func(ctx context.Context, err error) { destination.ErrorWithContext(ctx, err) } + s.completeDirect = func(ctx context.Context) { destination.CompleteWithContext(ctx) } + + // If destination is an *observerImpl[T], we can configure per-subscription + // direct call helpers. When capture==true we create one-off wrappers that + // call internal helpers that accept a precomputed capture flag and therefore + // avoid context lookups. + if oi, ok := destination.(*observerImpl[T]); ok { + if !capture { + // No panic capture: call internal methods directly. + s.nextDirect = func(ctx context.Context, v T) { oi.onNext(ctx, v) } + s.errorDirect = func(ctx context.Context, err error) { oi.onError(ctx, err) } + s.completeDirect = func(ctx context.Context) { oi.onComplete(ctx) } + return + } + s.nextDirect = func(ctx context.Context, v T) { oi.tryNextWithCapture(ctx, v, capture) } + s.errorDirect = func(ctx context.Context, err error) { oi.tryErrorWithCapture(ctx, err, capture) } + s.completeDirect = func(ctx context.Context) { oi.tryCompleteWithCapture(ctx, capture) } + } +} diff --git a/subscriber_bench_test.go b/subscriber_bench_test.go new file mode 100644 index 0000000..910bc89 --- /dev/null +++ b/subscriber_bench_test.go @@ -0,0 +1,41 @@ +package ro + +import ( + "context" + "testing" +) + +// BenchmarkSubscriberNextPath compares the hot-path cost of calling Next for +// different concurrency modes: +// - Safe: real mutex +// - Unsafe: no-op mutex (method calls happen but do nothing) +// - SingleProducer: lockless fast-path (no Lock/Unlock calls) +// +// The benchmark disables observer panic-capture to reduce noise from the +// panic-recovery wrappers and focus measurements on synchronization costs. +func BenchmarkSubscriberNextPath(b *testing.B) { + // Use a context-scoped opt-out for panic capture so the benchmark avoids + // global state mutation and measures only the synchronization costs. + ctx := WithObserverPanicCaptureDisabled(context.Background()) + + cases := []struct { + name string + mode ConcurrencyMode + }{ + {"Safe", ConcurrencyModeSafe}, + {"Unsafe", ConcurrencyModeUnsafe}, + {"SingleProducer", ConcurrencyModeSingleProducer}, + } + + for _, c := range cases { + c := c + b.Run(c.name, func(b *testing.B) { + sub := NewSubscriberWithConcurrencyMode[int](NoopObserver[int](), c.mode) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + sub.NextWithContext(ctx, i) + } + }) + } +} diff --git a/subscriber_edge_cases_test.go b/subscriber_edge_cases_test.go new file mode 100644 index 0000000..ed366c4 --- /dev/null +++ b/subscriber_edge_cases_test.go @@ -0,0 +1,197 @@ +// Copyright 2025 samber. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://github.com/samber/ro/blob/main/licenses/LICENSE.apache.md +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package ro + +import ( + "context" + "testing" + + "github.com/samber/ro/internal/xsync" + "github.com/stretchr/testify/assert" +) + +func TestSubscriberImpl_ErrorWithContext_locklessNilDestination(t *testing.T) { + t.Parallel() + is := assert.New(t) + + // Create a lockless subscriber with nil destination + subscriber := &subscriberImpl[int]{ + status: 0, + backpressure: BackpressureBlock, + mu: nil, + destination: nil, + Subscription: NewSubscription(nil), + mode: ConcurrencyModeSingleProducer, + lockless: true, + } + + // Should handle nil destination gracefully + subscriber.ErrorWithContext(context.Background(), assert.AnError) + is.Equal(int32(1), subscriber.status) // Should transition to error state +} + +func TestSubscriberImpl_CompleteWithContext_locklessNilDestination(t *testing.T) { + t.Parallel() + is := assert.New(t) + + // Create a lockless subscriber with nil destination + subscriber := &subscriberImpl[int]{ + status: 0, + backpressure: BackpressureBlock, + mu: nil, + destination: nil, + Subscription: NewSubscription(nil), + mode: ConcurrencyModeSingleProducer, + lockless: true, + } + + // Should handle nil destination gracefully + subscriber.CompleteWithContext(context.Background()) + is.Equal(int32(2), subscriber.status) // Should transition to complete state +} + +func TestSubscriberImpl_setDirectors_nonObserverImpl(t *testing.T) { + t.Parallel() + is := assert.New(t) + + // Create a custom observer that doesn't use observerImpl + type customObserver struct { + nextCalled bool + errorCalled bool + completeCalled bool + } + + custom := &customObserver{} + + // Create an observer wrapper + observer := NewObserver( + func(value int) { custom.nextCalled = true }, + func(err error) { custom.errorCalled = true }, + func() { custom.completeCalled = true }, + ) + + subscriber := &subscriberImpl[int]{ + status: 0, + backpressure: BackpressureBlock, + mu: xsync.NewMutexWithLock(), + destination: observer, + Subscription: NewSubscription(nil), + mode: ConcurrencyModeSafe, + lockless: false, + } + + // Call setDirectors with a non-observerImpl destination + // This should set up the default interface-based calls + subscriber.setDirectors(observer, true) + + // Verify directors were set + is.NotNil(subscriber.nextDirect) + is.NotNil(subscriber.errorDirect) + is.NotNil(subscriber.completeDirect) + + // Test that the directors work + subscriber.nextDirect(context.Background(), 1) + is.True(custom.nextCalled) + + subscriber.errorDirect(context.Background(), assert.AnError) + is.True(custom.errorCalled) + + subscriber.completeDirect(context.Background()) + is.True(custom.completeCalled) +} + +func TestSubscriberImpl_setDirectors_withObserverImpl(t *testing.T) { + t.Parallel() + is := assert.New(t) + + var nextCalled, errorCalled, completeCalled bool + + // Create an observerImpl directly + observer := &observerImpl[int]{ + status: 0, + capturePanics: true, + onNext: func(ctx context.Context, value int) { nextCalled = true }, + onError: func(ctx context.Context, err error) { errorCalled = true }, + onComplete: func(ctx context.Context) { completeCalled = true }, + } + + subscriber := &subscriberImpl[int]{ + status: 0, + backpressure: BackpressureBlock, + mu: xsync.NewMutexWithLock(), + destination: observer, + Subscription: NewSubscription(nil), + mode: ConcurrencyModeSafe, + lockless: false, + } + + // Call setDirectors with an observerImpl destination + // This should set up the optimized tryXXXWithCapture calls + subscriber.setDirectors(observer, true) + + // Verify directors were 
set + is.NotNil(subscriber.nextDirect) + is.NotNil(subscriber.errorDirect) + is.NotNil(subscriber.completeDirect) + + // Test that the directors work and use the optimized path + subscriber.nextDirect(context.Background(), 1) + is.True(nextCalled) + + subscriber.errorDirect(context.Background(), assert.AnError) + is.True(errorCalled) + + subscriber.completeDirect(context.Background()) + is.True(completeCalled) +} + +func TestSubscriberImpl_setDirectors_noCapture_propagatesPanics(t *testing.T) { + t.Parallel() + is := assert.New(t) + + // Create an observerImpl whose handlers panic. When setDirectors is + // invoked with capture==false we expect these panics to propagate + // (i.e. no TryCatch wrapper should be used). + observer := &observerImpl[int]{ + status: 0, + capturePanics: true, + onNext: func(ctx context.Context, value int) { panic("onNext panic") }, + onError: func(ctx context.Context, err error) { panic("onError panic") }, + onComplete: func(ctx context.Context) { panic("onComplete panic") }, + } + + subscriber := &subscriberImpl[int]{ + status: 0, + backpressure: BackpressureBlock, + mu: xsync.NewMutexWithLock(), + destination: observer, + Subscription: NewSubscription(nil), + mode: ConcurrencyModeSafe, + lockless: false, + } + + // Configure directors with capture=false so the direct helpers call + // the internal onNext/onError/onComplete without TryCatch wrappers. + subscriber.setDirectors(observer, false) + + is.NotNil(subscriber.nextDirect) + is.NotNil(subscriber.errorDirect) + is.NotNil(subscriber.completeDirect) + + // Each direct call should panic (propagate) because capture==false. + is.Panics(func() { subscriber.nextDirect(context.Background(), 1) }) + is.Panics(func() { subscriber.errorDirect(context.Background(), assert.AnError) }) + is.Panics(func() { subscriber.completeDirect(context.Background()) }) +} diff --git a/subscriber_test.go b/subscriber_test.go index c89fbd9..0a51367 100644 --- a/subscriber_test.go +++ b/subscriber_test.go @@ -16,6 +16,7 @@ package ro import ( "context" + "fmt" "sync" "sync/atomic" "testing" @@ -38,47 +39,56 @@ func TestSubscriberInternalOk(t *testing.T) { subscriber2, ok2 := NewSafeSubscriber(observer).(*subscriberImpl[int]) subscriber3, ok3 := NewUnsafeSubscriber(observer).(*subscriberImpl[int]) subscriber4, ok4 := NewEventuallySafeSubscriber(observer).(*subscriberImpl[int]) + subscriber5, ok5 := NewSingleProducerSubscriber(observer).(*subscriberImpl[int]) is.True(ok1) is.True(ok2) is.True(ok3) is.True(ok4) + is.True(ok5) // default state is.EqualValues(KindNext, subscriber1.status) is.EqualValues(KindNext, subscriber2.status) is.EqualValues(KindNext, subscriber3.status) is.EqualValues(KindNext, subscriber4.status) + is.EqualValues(KindNext, subscriber5.status) // send values subscriber1.Next(21) subscriber2.Next(21) subscriber3.Next(21) subscriber4.Next(21) + subscriber5.Next(21) is.EqualValues(KindNext, subscriber1.status) is.EqualValues(KindNext, subscriber2.status) is.EqualValues(KindNext, subscriber3.status) is.EqualValues(KindNext, subscriber4.status) + is.EqualValues(KindNext, subscriber5.status) // completed state subscriber1.Complete() subscriber2.Complete() subscriber3.Complete() subscriber4.Complete() + subscriber5.Complete() is.EqualValues(KindComplete, subscriber1.status) is.EqualValues(KindComplete, subscriber2.status) is.EqualValues(KindComplete, subscriber3.status) is.EqualValues(KindComplete, subscriber4.status) + is.EqualValues(KindComplete, subscriber5.status) // no change subscriber1.Next(42) subscriber2.Next(42) 
subscriber3.Next(42) subscriber4.Next(42) + subscriber5.Next(42) is.EqualValues(KindComplete, subscriber1.status) is.EqualValues(KindComplete, subscriber2.status) is.EqualValues(KindComplete, subscriber3.status) is.EqualValues(KindComplete, subscriber4.status) + is.EqualValues(KindComplete, subscriber5.status) } func TestSubscriberInternalError(t *testing.T) { @@ -95,11 +105,13 @@ func TestSubscriberInternalError(t *testing.T) { subscriber2, ok2 := NewSafeSubscriber(observer).(*subscriberImpl[int]) subscriber3, ok3 := NewUnsafeSubscriber(observer).(*subscriberImpl[int]) subscriber4, ok4 := NewEventuallySafeSubscriber(observer).(*subscriberImpl[int]) + subscriber5, ok5 := NewSingleProducerSubscriber(observer).(*subscriberImpl[int]) is.True(ok1) is.True(ok2) is.True(ok3) is.True(ok4) + is.True(ok5) // default state is.EqualValues(KindNext, subscriber1.status) @@ -112,30 +124,36 @@ func TestSubscriberInternalError(t *testing.T) { subscriber2.Next(21) subscriber3.Next(21) subscriber4.Next(21) + subscriber5.Next(21) is.EqualValues(KindNext, subscriber1.status) is.EqualValues(KindNext, subscriber2.status) is.EqualValues(KindNext, subscriber3.status) is.EqualValues(KindNext, subscriber4.status) + is.EqualValues(KindNext, subscriber5.status) // trigger error subscriber1.Error(assert.AnError) subscriber2.Error(assert.AnError) subscriber3.Error(assert.AnError) subscriber4.Error(assert.AnError) + subscriber5.Error(assert.AnError) is.EqualValues(KindError, subscriber1.status) is.EqualValues(KindError, subscriber2.status) is.EqualValues(KindError, subscriber3.status) is.EqualValues(KindError, subscriber4.status) + is.EqualValues(KindError, subscriber5.status) // no change subscriber1.Next(42) subscriber2.Next(42) subscriber3.Next(42) subscriber4.Next(42) + subscriber5.Next(42) is.EqualValues(KindError, subscriber1.status) is.EqualValues(KindError, subscriber2.status) is.EqualValues(KindError, subscriber3.status) is.EqualValues(KindError, subscriber4.status) + is.EqualValues(KindError, subscriber5.status) } func TestSubscriberNext(t *testing.T) { @@ -147,6 +165,7 @@ func TestSubscriberNext(t *testing.T) { var counter2 int64 var counter3 int64 var counter4 int64 + var counter5 int64 observer1 := NewObserver( func(value int) { atomic.AddInt64(&counter1, int64(value)) }, @@ -168,64 +187,93 @@ func TestSubscriberNext(t *testing.T) { func(err error) {}, func() {}, ) + observer5 := NewObserver( + func(value int) { atomic.AddInt64(&counter5, int64(value)) }, + func(err error) {}, + func() {}, + ) subscriber1, ok1 := NewSubscriber(observer1).(*subscriberImpl[int]) subscriber2, ok2 := NewSafeSubscriber(observer2).(*subscriberImpl[int]) subscriber3, ok3 := NewUnsafeSubscriber(observer3).(*subscriberImpl[int]) subscriber4, ok4 := NewEventuallySafeSubscriber(observer4).(*subscriberImpl[int]) + subscriber5, ok5 := NewSingleProducerSubscriber(observer5).(*subscriberImpl[int]) is.True(ok1) is.True(ok2) is.True(ok3) is.True(ok4) + is.True(ok5) subscriber1.Next(21) is.EqualValues(21, atomic.LoadInt64(&counter1)) is.EqualValues(0, atomic.LoadInt64(&counter2)) is.EqualValues(0, atomic.LoadInt64(&counter3)) is.EqualValues(0, atomic.LoadInt64(&counter4)) + is.EqualValues(0, atomic.LoadInt64(&counter5)) subscriber2.Next(21) is.EqualValues(21, atomic.LoadInt64(&counter1)) is.EqualValues(21, atomic.LoadInt64(&counter2)) is.EqualValues(0, atomic.LoadInt64(&counter3)) is.EqualValues(0, atomic.LoadInt64(&counter4)) + is.EqualValues(0, atomic.LoadInt64(&counter5)) subscriber3.Next(21) is.EqualValues(21, 
atomic.LoadInt64(&counter1)) is.EqualValues(21, atomic.LoadInt64(&counter2)) is.EqualValues(21, atomic.LoadInt64(&counter3)) is.EqualValues(0, atomic.LoadInt64(&counter4)) + is.EqualValues(0, atomic.LoadInt64(&counter5)) subscriber4.Next(21) is.EqualValues(21, atomic.LoadInt64(&counter1)) is.EqualValues(21, atomic.LoadInt64(&counter2)) is.EqualValues(21, atomic.LoadInt64(&counter3)) is.EqualValues(21, atomic.LoadInt64(&counter4)) + is.EqualValues(0, atomic.LoadInt64(&counter5)) + + subscriber5.Next(21) + is.EqualValues(21, atomic.LoadInt64(&counter1)) + is.EqualValues(21, atomic.LoadInt64(&counter2)) + is.EqualValues(21, atomic.LoadInt64(&counter3)) + is.EqualValues(21, atomic.LoadInt64(&counter4)) + is.EqualValues(21, atomic.LoadInt64(&counter5)) subscriber1.Next(21) is.EqualValues(42, atomic.LoadInt64(&counter1)) is.EqualValues(21, atomic.LoadInt64(&counter2)) is.EqualValues(21, atomic.LoadInt64(&counter3)) is.EqualValues(21, atomic.LoadInt64(&counter4)) + is.EqualValues(21, atomic.LoadInt64(&counter5)) subscriber2.Next(21) is.EqualValues(42, atomic.LoadInt64(&counter1)) is.EqualValues(42, atomic.LoadInt64(&counter2)) is.EqualValues(21, atomic.LoadInt64(&counter3)) is.EqualValues(21, atomic.LoadInt64(&counter4)) + is.EqualValues(21, atomic.LoadInt64(&counter5)) subscriber3.Next(21) is.EqualValues(42, atomic.LoadInt64(&counter1)) is.EqualValues(42, atomic.LoadInt64(&counter2)) is.EqualValues(42, atomic.LoadInt64(&counter3)) is.EqualValues(21, atomic.LoadInt64(&counter4)) + is.EqualValues(21, atomic.LoadInt64(&counter5)) subscriber4.Next(21) is.EqualValues(42, atomic.LoadInt64(&counter1)) is.EqualValues(42, atomic.LoadInt64(&counter2)) is.EqualValues(42, atomic.LoadInt64(&counter3)) is.EqualValues(42, atomic.LoadInt64(&counter4)) + is.EqualValues(21, atomic.LoadInt64(&counter5)) + + subscriber5.Next(21) + is.EqualValues(42, atomic.LoadInt64(&counter1)) + is.EqualValues(42, atomic.LoadInt64(&counter2)) + is.EqualValues(42, atomic.LoadInt64(&counter3)) + is.EqualValues(42, atomic.LoadInt64(&counter4)) + is.EqualValues(42, atomic.LoadInt64(&counter5)) } func TestSubscriberError(t *testing.T) { @@ -236,6 +284,7 @@ func TestSubscriberError(t *testing.T) { var counter2 int64 var counter3 int64 var counter4 int64 + var counter5 int64 observer1 := NewObserver( func(value int) { atomic.AddInt64(&counter1, int64(value)) }, @@ -257,45 +306,58 @@ func TestSubscriberError(t *testing.T) { func(err error) { atomic.AddInt64(&counter4, int64(1)) }, func() {}, ) + observer5 := NewObserver( + func(value int) { atomic.AddInt64(&counter5, int64(value)) }, + func(err error) { atomic.AddInt64(&counter5, int64(1)) }, + func() {}, + ) subscriber1, ok1 := NewSubscriber(observer1).(*subscriberImpl[int]) subscriber2, ok2 := NewSafeSubscriber(observer2).(*subscriberImpl[int]) subscriber3, ok3 := NewUnsafeSubscriber(observer3).(*subscriberImpl[int]) subscriber4, ok4 := NewEventuallySafeSubscriber(observer4).(*subscriberImpl[int]) + subscriber5, ok5 := NewSingleProducerSubscriber(observer5).(*subscriberImpl[int]) is.True(ok1) is.True(ok2) is.True(ok3) is.True(ok4) + is.True(ok5) subscriber1.Next(21) subscriber2.Next(21) subscriber3.Next(21) subscriber4.Next(21) + subscriber5.Next(21) is.EqualValues(21, atomic.LoadInt64(&counter1)) is.EqualValues(21, atomic.LoadInt64(&counter2)) is.EqualValues(21, atomic.LoadInt64(&counter3)) is.EqualValues(21, atomic.LoadInt64(&counter4)) + is.EqualValues(21, atomic.LoadInt64(&counter5)) // trigger error subscriber1.Error(assert.AnError) subscriber2.Error(assert.AnError) 
subscriber3.Error(assert.AnError) subscriber4.Error(assert.AnError) + subscriber5.Error(assert.AnError) is.EqualValues(22, atomic.LoadInt64(&counter1)) is.EqualValues(22, atomic.LoadInt64(&counter2)) is.EqualValues(22, atomic.LoadInt64(&counter3)) is.EqualValues(22, atomic.LoadInt64(&counter4)) + is.EqualValues(22, atomic.LoadInt64(&counter5)) // send a new message subscriber1.Next(21) subscriber2.Next(21) subscriber3.Next(21) subscriber4.Next(21) + subscriber5.Next(21) is.EqualValues(22, atomic.LoadInt64(&counter1)) is.EqualValues(22, atomic.LoadInt64(&counter2)) is.EqualValues(22, atomic.LoadInt64(&counter3)) is.EqualValues(22, atomic.LoadInt64(&counter4)) + is.EqualValues(22, atomic.LoadInt64(&counter5)) } func TestSubscriberComplete(t *testing.T) { @@ -306,6 +368,7 @@ func TestSubscriberComplete(t *testing.T) { var counter2 int64 var counter3 int64 var counter4 int64 + var counter5 int64 observer1 := NewObserver( func(value int) { atomic.AddInt64(&counter1, int64(value)) }, @@ -327,45 +390,58 @@ func TestSubscriberComplete(t *testing.T) { func(err error) {}, func() { atomic.AddInt64(&counter4, 1) }, ) + observer5 := NewObserver( + func(value int) { atomic.AddInt64(&counter5, int64(value)) }, + func(err error) {}, + func() { atomic.AddInt64(&counter5, 1) }, + ) subscriber1, ok1 := NewSubscriber(observer1).(*subscriberImpl[int]) subscriber2, ok2 := NewSafeSubscriber(observer2).(*subscriberImpl[int]) subscriber3, ok3 := NewUnsafeSubscriber(observer3).(*subscriberImpl[int]) subscriber4, ok4 := NewEventuallySafeSubscriber(observer4).(*subscriberImpl[int]) + subscriber5, ok5 := NewSingleProducerSubscriber(observer5).(*subscriberImpl[int]) is.True(ok1) is.True(ok2) is.True(ok3) is.True(ok4) + is.True(ok5) subscriber1.Next(21) subscriber2.Next(21) subscriber3.Next(21) subscriber4.Next(21) + subscriber5.Next(21) is.EqualValues(21, atomic.LoadInt64(&counter1)) is.EqualValues(21, atomic.LoadInt64(&counter2)) is.EqualValues(21, atomic.LoadInt64(&counter3)) is.EqualValues(21, atomic.LoadInt64(&counter4)) + is.EqualValues(21, atomic.LoadInt64(&counter5)) // trigger complete subscriber1.Complete() subscriber2.Complete() subscriber3.Complete() subscriber4.Complete() + subscriber5.Complete() is.EqualValues(22, atomic.LoadInt64(&counter1)) is.EqualValues(22, atomic.LoadInt64(&counter2)) is.EqualValues(22, atomic.LoadInt64(&counter3)) is.EqualValues(22, atomic.LoadInt64(&counter4)) + is.EqualValues(22, atomic.LoadInt64(&counter5)) // send a new message subscriber1.Next(21) subscriber2.Next(21) subscriber3.Next(21) subscriber4.Next(21) + subscriber5.Next(21) is.EqualValues(22, atomic.LoadInt64(&counter1)) is.EqualValues(22, atomic.LoadInt64(&counter2)) is.EqualValues(22, atomic.LoadInt64(&counter3)) is.EqualValues(22, atomic.LoadInt64(&counter4)) + is.EqualValues(22, atomic.LoadInt64(&counter5)) } func TestSubscriberWithContext(t *testing.T) { @@ -439,6 +515,137 @@ func TestSubscriberWithContext(t *testing.T) { is.Equal(assert.AnError, receivedError) } +// it uses a helper to serialize overrides but intentionally does not call +// t.Parallel() to avoid races on the global variable. +// +//nolint:paralleltest // this test mutates the global `OnDroppedNotification` hook; +func TestLocklessDroppedNotification(t *testing.T) { + // Do NOT run this test in parallel. It mutates the global + // `OnDroppedNotification` hook which may race with other tests + // that read or call the hook. Keep the existing defer that + // restores the previous handler so the hook is reset for later + // tests. 
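+	// (The restore is performed by WithDroppedNotification itself, which defers
+	// SetOnDroppedNotification(prev) before running the body below.)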
+ is := assert.New(t) + + // Capture dropped notifications using the helper which serializes + // overrides of the global `OnDroppedNotification` hook and restores it + // after the inner function returns. + var seen string + WithDroppedNotification(t, func(ctx context.Context, notification fmt.Stringer) { + seen = notification.String() + }, func() { + observer := NewObserver( + func(value int) {}, + func(err error) {}, + func() {}, + ) + + subscriber, ok := NewSingleProducerSubscriber(observer).(*subscriberImpl[int]) + is.True(ok) + + // Mark as completed so further Next() should be dropped + subscriber.Complete() + + subscriber.Next(42) + + is.NotEmpty(seen) + is.Contains(seen, "Next(42)") + }) +} + +// Stress test for the single-producer fast path. This exercises the lockless +// code path under a tight sequential loop to ensure correctness and to make it +// easy to run under the race detector in CI. +func TestSingleProducerStress(t *testing.T) { + t.Parallel() + is := assert.New(t) + + var counter int64 + observer := NewObserver( + func(value int) { atomic.AddInt64(&counter, int64(value)) }, + func(err error) {}, + func() {}, + ) + + subscriber, ok := NewSingleProducerSubscriber(observer).(*subscriberImpl[int]) + is.True(ok) + + const iterations = 100000 + for i := 0; i < iterations; i++ { + subscriber.Next(1) + } + + is.Equal(int64(iterations), atomic.LoadInt64(&counter)) +} + +// TestSingleProducerContextCancellation ensures the single-producer (lockless) +// fast path forwards the provided context to the destination observer and that +// a cancelled context is visible to the observer callback. +func TestSingleProducerContextCancellation(t *testing.T) { + t.Parallel() + testWithTimeout(t, 100*time.Millisecond) + is := assert.New(t) + + var sawCanceled bool + + observer := NewObserverWithContext( + func(ctx context.Context, value int) { + if ctx.Err() == context.Canceled { + sawCanceled = true + } + }, + func(ctx context.Context, err error) {}, + func(ctx context.Context) {}, + ) + + subscriber, ok := NewSingleProducerSubscriber(observer).(*subscriberImpl[int]) + is.True(ok) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + subscriber.NextWithContext(ctx, 42) + + is.True(sawCanceled) +} + +// Concurrent stress test for the safe subscriber. Spawns multiple goroutines +// concurrently calling Next and validates the final accumulated value. +func TestSafeSubscriberConcurrentStress(t *testing.T) { + t.Parallel() + is := assert.New(t) + + var counter int64 + observer := NewObserver( + func(value int) { atomic.AddInt64(&counter, int64(value)) }, + func(err error) {}, + func() {}, + ) + + subscriber, ok := NewSafeSubscriber(observer).(*subscriberImpl[int]) + is.True(ok) + + const goroutines = 8 + const per = 10000 + + var wg sync.WaitGroup + wg.Add(goroutines) + + for g := 0; g < goroutines; g++ { + go func() { + for i := 0; i < per; i++ { + subscriber.Next(1) + } + wg.Done() + }() + } + + wg.Wait() + + expected := int64(goroutines * per) + is.Equal(expected, atomic.LoadInt64(&counter)) +} + func TestSubscriberIsClosed(t *testing.T) { t.Parallel() is := assert.New(t) diff --git a/subscriber_test_helper.go b/subscriber_test_helper.go new file mode 100644 index 0000000..7cb1bcb --- /dev/null +++ b/subscriber_test_helper.go @@ -0,0 +1,48 @@ +// Copyright 2025 samber. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://github.com/samber/ro/blob/main/licenses/LICENSE.apache.md +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ro + +import ( + "context" + "fmt" + "sync" + "testing" +) + +// droppedNotificationMu serializes test-time overrides of the package-level +// `OnDroppedNotification` hook so tests do not concurrently write the global +// variable and cause data races. Tests that need to temporarily replace the +// hook should use WithDroppedNotification. +var droppedNotificationMu sync.Mutex + +// WithDroppedNotification temporarily sets `OnDroppedNotification` to the +// provided handler while executing fn. The previous handler is restored when +// fn returns. The helper serializes mutations using a mutex so concurrent +// test goroutines don't perform simultaneous writes to the global hook. +func WithDroppedNotification(t *testing.T, handler func(ctx context.Context, notification fmt.Stringer), fn func()) { + t.Helper() + + droppedNotificationMu.Lock() + prev := GetOnDroppedNotification() + SetOnDroppedNotification(handler) + + // Ensure restore and unlock even if fn panics. + defer func() { + SetOnDroppedNotification(prev) + droppedNotificationMu.Unlock() + }() + + fn() +} diff --git a/test_context_key.go b/test_context_key.go new file mode 100644 index 0000000..d87da9e --- /dev/null +++ b/test_context_key.go @@ -0,0 +1,24 @@ +// Copyright 2025 samber. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://github.com/samber/ro/blob/main/licenses/LICENSE.apache.md +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ro + +// ctxKey is an unexported type used for context keys in tests to avoid +// using a basic type (like string) directly as a context key which +// triggers linters such as revive's context-keys-type rule. +type ctxKey string + +// testCtxKey is the key used by tests that need to attach a value to a +// context and later retrieve it. Keep it unexported and typed. +const testCtxKey ctxKey = "test" diff --git a/testing/benchmark_million_rows_test.go b/testing/benchmark_million_rows_test.go new file mode 100644 index 0000000..b7f6dca --- /dev/null +++ b/testing/benchmark_million_rows_test.go @@ -0,0 +1,77 @@ +// Copyright 2025 samber. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://github.com/samber/ro/blob/main/licenses/LICENSE.apache.md +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package testing + +import ( + "context" + stdtesting "testing" + + "github.com/samber/ro" +) + +func BenchmarkMillionRowChallenge(b *stdtesting.B) { + b.ReportAllocs() + // Use a per-subscription context to disable observer panic capture for the + // benchmark. This avoids mutating global state and keeps tests parallel- + // friendly while removing the panic-capture overhead. + ctx := ro.WithObserverPanicCaptureDisabled(context.Background()) + + const expectedSum int64 = 750001500000 + + benchmarkCases := []struct { + name string + source ro.Observable[int64] + }{ + {name: "single-producer", source: ro.Range(0, 1_000_000)}, + {name: "unsafe-mutex", source: ro.RangeWithMode(0, 1_000_000, ro.ConcurrencyModeUnsafe)}, + {name: "safe-mutex", source: ro.RangeWithMode(0, 1_000_000, ro.ConcurrencyModeSafe)}, + {name: "eventually-safe", source: ro.RangeWithMode(0, 1_000_000, ro.ConcurrencyModeEventuallySafe)}, + } + + for _, tc := range benchmarkCases { + b.Run(tc.name, func(b *stdtesting.B) { + pipeline := ro.Pipe3( + tc.source, + ro.Map(func(value int64) int64 { return value + 1 }), + ro.Filter(func(value int64) bool { return value%2 == 0 }), + ro.Map(func(value int64) int64 { return value * 3 }), + ) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + var sum int64 + + subscription := pipeline.SubscribeWithContext(ctx, ro.NewObserver( + func(value int64) { + sum += value + }, + func(err error) { + b.Fatalf("unexpected error: %v", err) + }, + func() {}, + )) + + subscription.Wait() + + if sum != expectedSum { + b.Fatalf("unexpected sum: %d", sum) + } + } + }) + } +} + +// Note: RangeWithMode in the main package replaces the ad-hoc helper previously +// used here. See operator_creation.go for the implementation.
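+//
+// The mode-specific sources share the signature used in the benchmark cases above,
+// for example:
+//
+//	ro.Range(0, 1_000_000)                                 // single-producer fast path
+//	ro.RangeWithMode(0, 1_000_000, ro.ConcurrencyModeSafe) // explicit mutex-backed mode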