diff --git a/AUTHORS.txt b/AUTHORS.txt index e9ab3dc..31e930d 100644 --- a/AUTHORS.txt +++ b/AUTHORS.txt @@ -137,6 +137,7 @@ Michael MacDonald Michael MacDonald Michiel De Backker <38858977+backkem@users.noreply.github.com> Mike Coleman +Mikhail Mindgamesnl mission-liao mohammadne diff --git a/peerconnection.go b/peerconnection.go index 867dafc..379be7b 100644 --- a/peerconnection.go +++ b/peerconnection.go @@ -1865,7 +1865,7 @@ func (pc *PeerConnection) RemoveTrack(sender *RTPSender) (err error) { return } -func (pc *PeerConnection) newTransceiverFromTrack(direction RTPTransceiverDirection, track TrackLocal) (t *RTPTransceiver, err error) { +func (pc *PeerConnection) newTransceiverFromTrack(direction RTPTransceiverDirection, track TrackLocal, init ...RTPTransceiverInit) (t *RTPTransceiver, err error) { var ( r *RTPReceiver s *RTPSender @@ -1885,6 +1885,13 @@ func (pc *PeerConnection) newTransceiverFromTrack(direction RTPTransceiverDirect if err != nil { return } + + // Allow RTPTransceiverInit to override SSRC + if s != nil && len(s.trackEncodings) == 1 && + len(init) == 1 && len(init[0].SendEncodings) == 1 && init[0].SendEncodings[0].SSRC != 0 { + s.trackEncodings[0].ssrc = init[0].SendEncodings[0].SSRC + } + return newRTPTransceiver(r, s, direction, track.Kind(), pc.api), nil } @@ -1910,7 +1917,7 @@ func (pc *PeerConnection) AddTransceiverFromKind(kind RTPCodecType, init ...RTPT if err != nil { return nil, err } - t, err = pc.newTransceiverFromTrack(direction, track) + t, err = pc.newTransceiverFromTrack(direction, track, init...) if err != nil { return nil, err } @@ -1942,7 +1949,7 @@ func (pc *PeerConnection) AddTransceiverFromTrack(track TrackLocal, init ...RTPT direction = init[0].Direction } - t, err = pc.newTransceiverFromTrack(direction, track) + t, err = pc.newTransceiverFromTrack(direction, track, init...) if err == nil { pc.mu.Lock() pc.addRTPTransceiver(t) diff --git a/peerconnection_media_test.go b/peerconnection_media_test.go index 2571db5..bb81d38 100644 --- a/peerconnection_media_test.go +++ b/peerconnection_media_test.go @@ -152,9 +152,11 @@ func TestPeerConnection_Media_Sample(t *testing.T) { }() go func() { + parameters := sender.GetParameters() + for { time.Sleep(time.Millisecond * 100) - if routineErr := pcOffer.WriteRTCP([]rtcp.Packet{&rtcp.PictureLossIndication{SenderSSRC: uint32(sender.trackEncodings[0].ssrc), MediaSSRC: uint32(sender.trackEncodings[0].ssrc)}}); routineErr != nil { + if routineErr := pcOffer.WriteRTCP([]rtcp.Packet{&rtcp.PictureLossIndication{SenderSSRC: uint32(parameters.Encodings[0].SSRC), MediaSSRC: uint32(parameters.Encodings[0].SSRC)}}); routineErr != nil { awaitRTCPSenderSend <- routineErr } diff --git a/rtpsender.go b/rtpsender.go index 7b154a5..cbff6f8 100644 --- a/rtpsender.go +++ b/rtpsender.go @@ -111,6 +111,9 @@ func (r *RTPSender) Transport() *DTLSTransport { } func (r *RTPSender) getParameters() RTPSendParameters { + r.mu.RLock() + defer r.mu.RUnlock() + var encodings []RTPEncodingParameters for _, trackEncoding := range r.trackEncodings { var rid string @@ -196,19 +199,10 @@ func (r *RTPSender) AddEncoding(track TrackLocal) error { } func (r *RTPSender) addEncoding(track TrackLocal) { - ssrc := SSRC(randutil.NewMathRandomGenerator().Uint32()) trackEncoding := &trackEncoding{ - track: track, - srtpStream: &srtpWriterFuture{ssrc: ssrc}, - ssrc: ssrc, - } - trackEncoding.srtpStream.rtpSender = r - trackEncoding.rtcpInterceptor = r.api.interceptor.BindRTCPReader( - interceptor.RTPReaderFunc(func(in []byte, a interceptor.Attributes) (n int, attributes interceptor.Attributes, err error) { - n, err = trackEncoding.srtpStream.Read(in) - return n, a, err - }), - ) + track: track, + ssrc: SSRC(randutil.NewMathRandomGenerator().Uint32()), + } r.trackEncodings = append(r.trackEncodings, trackEncoding) } @@ -295,8 +289,13 @@ func (r *RTPSender) Send(parameters RTPSendParameters) error { return errRTPSenderTrackRemoved } - for idx, trackEncoding := range r.trackEncodings { + for idx := range r.trackEncodings { + trackEncoding := r.trackEncodings[idx] + srtpStream := &srtpWriterFuture{ssrc: parameters.Encodings[idx].SSRC, rtpSender: r} writeStream := &interceptorToTrackLocalWriter{} + + trackEncoding.srtpStream = srtpStream + trackEncoding.ssrc = parameters.Encodings[idx].SSRC trackEncoding.context = &baseTrackLocalContext{ id: r.id, params: r.api.mediaEngine.getRTPParametersByKind(trackEncoding.track.Kind(), []RTPTransceiverDirection{RTPTransceiverDirectionSendonly}), @@ -318,13 +317,21 @@ func (r *RTPSender) Send(parameters RTPSendParameters) error { codec.RTPCodecCapability, parameters.HeaderExtensions, ) - srtpStream := trackEncoding.srtpStream + + trackEncoding.rtcpInterceptor = r.api.interceptor.BindRTCPReader( + interceptor.RTPReaderFunc(func(in []byte, a interceptor.Attributes) (n int, attributes interceptor.Attributes, err error) { + n, err = trackEncoding.srtpStream.Read(in) + return n, a, err + }), + ) + rtpInterceptor := r.api.interceptor.BindLocalStream( &trackEncoding.streamInfo, interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { return srtpStream.WriteRTP(header, payload) }), ) + writeStream.interceptor.Store(rtpInterceptor) } @@ -355,7 +362,9 @@ func (r *RTPSender) Stop() error { errs := []error{} for _, trackEncoding := range r.trackEncodings { r.api.interceptor.UnbindLocalStream(&trackEncoding.streamInfo) - errs = append(errs, trackEncoding.srtpStream.Close()) + if trackEncoding.srtpStream != nil { + errs = append(errs, trackEncoding.srtpStream.Close()) + } } return util.FlattenErrs(errs) diff --git a/rtptransceiverinit_go_test.go b/rtptransceiverinit_go_test.go new file mode 100644 index 0000000..b326874 --- /dev/null +++ b/rtptransceiverinit_go_test.go @@ -0,0 +1,79 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +//go:build !js +// +build !js + +package webrtc + +import ( + "context" + "testing" + "time" + + "github.com/pion/transport/v3/test" + "github.com/stretchr/testify/assert" +) + +func Test_RTPTransceiverInit_SSRC(t *testing.T) { + lim := test.TimeOut(time.Second * 30) //nolint + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + track, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeOpus}, "a", "b") + assert.NoError(t, err) + + t.Run("SSRC of 0 is ignored", func(t *testing.T) { + offerer, answerer, err := newPair() + assert.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + answerer.OnTrack(func(track *TrackRemote, _ *RTPReceiver) { + assert.NotEqual(t, 0, track.SSRC()) + cancel() + }) + + _, err = offerer.AddTransceiverFromTrack(track, RTPTransceiverInit{ + Direction: RTPTransceiverDirectionSendonly, + SendEncodings: []RTPEncodingParameters{ + { + RTPCodingParameters: RTPCodingParameters{ + SSRC: 0, + }, + }, + }, + }) + assert.NoError(t, err) + assert.NoError(t, signalPair(offerer, answerer)) + sendVideoUntilDone(ctx.Done(), t, []*TrackLocalStaticSample{track}) + closePairNow(t, offerer, answerer) + }) + + t.Run("SSRC of 5000", func(t *testing.T) { + offerer, answerer, err := newPair() + assert.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + answerer.OnTrack(func(track *TrackRemote, _ *RTPReceiver) { + assert.NotEqual(t, 5000, track.SSRC()) + cancel() + }) + + _, err = offerer.AddTransceiverFromTrack(track, RTPTransceiverInit{ + Direction: RTPTransceiverDirectionSendonly, + SendEncodings: []RTPEncodingParameters{ + { + RTPCodingParameters: RTPCodingParameters{ + SSRC: 5000, + }, + }, + }, + }) + assert.NoError(t, err) + assert.NoError(t, signalPair(offerer, answerer)) + sendVideoUntilDone(ctx.Done(), t, []*TrackLocalStaticSample{track}) + closePairNow(t, offerer, answerer) + }) +}