Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 32 additions & 7 deletions peer/peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -482,10 +482,10 @@ type Peer struct {
userAgent string
services wire.ServiceFlag
versionKnown bool
handshakeDone bool
advertisedProtoVer uint32 // protocol version advertised by remote
protocolVersion uint32 // negotiated protocol version
sendHeadersPreferred bool // peer sent a sendheaders message
versionSent bool
verAckReceived bool

knownInventory *lru.Set[wire.InvVect]
Expand Down Expand Up @@ -705,6 +705,18 @@ func (p *Peer) VersionKnown() bool {
return versionKnown
}

// HandshakeDone returns whether initial version messages were sent and
// received.
//
// This function is safe for concurrent access.
func (p *Peer) HandshakeDone() bool {
p.flagsMtx.Lock()
handshakeDone := p.handshakeDone
p.flagsMtx.Unlock()

return handshakeDone
}

// VerAckReceived returns whether or not a verack message was received by the
// peer.
//
Expand Down Expand Up @@ -1652,7 +1664,7 @@ out:

case iv := <-p.outputInvChan:
// No handshake? They'll find out soon enough.
if p.VersionKnown() {
if p.HandshakeDone() {
invSendQueue = append(invSendQueue, iv)
}

Expand Down Expand Up @@ -2080,9 +2092,6 @@ func (p *Peer) writeLocalVersionMsg() error {
return err
}

p.flagsMtx.Lock()
p.versionSent = true
p.flagsMtx.Unlock()
return nil
}

Expand All @@ -2094,7 +2103,15 @@ func (p *Peer) negotiateInboundProtocol() error {
return err
}

return p.writeLocalVersionMsg()
if err := p.writeLocalVersionMsg(); err != nil {
return err
}

p.flagsMtx.Lock()
p.handshakeDone = true
p.flagsMtx.Unlock()

return nil
}

// negotiateOutboundProtocol sends our version message then waits to receive a
Expand All @@ -2105,7 +2122,15 @@ func (p *Peer) negotiateOutboundProtocol() error {
return err
}

return p.readRemoteVersionMsg()
if err := p.readRemoteVersionMsg(); err != nil {
return err
}

p.flagsMtx.Lock()
p.handshakeDone = true
p.flagsMtx.Unlock()

return nil
}

// start begins processing input and output messages.
Expand Down