Skip to content

Commit 83ba9bc

Browse files
authored
Merge pull request #199 from martinthomson/quic_record_layer
QUIC record layer changes
2 parents a14404e + e78e097 commit 83ba9bc

11 files changed

+205
-129
lines changed

client-state-machine.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ func (state clientStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
161161
var offeredPSK PreSharedKey
162162
var earlyHash crypto.Hash
163163
var earlySecret []byte
164-
var clientEarlyTrafficKeys keySet
164+
var clientEarlyTrafficKeys KeySet
165165
var clientHello *HandshakeMessage
166166
if key, ok := state.Config.PSKs.Get(state.Opts.ServerName); ok {
167167
offeredPSK = key

common_test.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,12 @@ func assertNotByteEquals(t *testing.T, a, b []byte) {
7474
func assertCipherSuiteParamsEquals(t *testing.T, a, b CipherSuiteParams) {
7575
t.Helper()
7676
assertEquals(t, a.Suite, b.Suite)
77-
// Can't compare aeadFactory values
77+
// Can't compare AEADFactory values
7878
assertEquals(t, a.Hash, b.Hash)
79-
assertEquals(t, a.KeyLen, b.KeyLen)
80-
assertEquals(t, a.IvLen, b.IvLen)
79+
assertEquals(t, len(a.KeyLengths), len(b.KeyLengths))
80+
for k, v := range a.KeyLengths {
81+
assertEquals(t, v, b.KeyLengths[k])
82+
}
8183
}
8284

8385
func assertDeepEquals(t *testing.T, a, b interface{}) {

conn.go

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,8 @@ type Config struct {
129129
NonBlocking bool
130130
UseDTLS bool
131131

132+
RecordLayer RecordLayerFactory
133+
132134
// The same config object can be shared among different connections, so it
133135
// needs its own mutex
134136
mutex sync.RWMutex
@@ -270,28 +272,33 @@ type Conn struct {
270272
handshakeComplete bool
271273

272274
readBuffer []byte
273-
in, out *RecordLayer
275+
in, out RecordLayer
274276
hsCtx *HandshakeContext
275277
}
276278

277279
func NewConn(conn net.Conn, config *Config, isClient bool) *Conn {
278280
c := &Conn{conn: conn, config: config, isClient: isClient, hsCtx: &HandshakeContext{}}
279281
if !config.UseDTLS {
280-
c.in = NewRecordLayerTLS(c.conn, directionRead)
281-
c.out = NewRecordLayerTLS(c.conn, directionWrite)
282+
if config.RecordLayer == nil {
283+
c.in = NewRecordLayerTLS(c.conn, DirectionRead)
284+
c.out = NewRecordLayerTLS(c.conn, DirectionWrite)
285+
} else {
286+
c.in = config.RecordLayer.NewLayer(c.conn, DirectionRead)
287+
c.out = config.RecordLayer.NewLayer(c.conn, DirectionWrite)
288+
}
282289
c.hsCtx.hIn = NewHandshakeLayerTLS(c.hsCtx, c.in)
283290
c.hsCtx.hOut = NewHandshakeLayerTLS(c.hsCtx, c.out)
284291
} else {
285-
c.in = NewRecordLayerDTLS(c.conn, directionRead)
286-
c.out = NewRecordLayerDTLS(c.conn, directionWrite)
292+
c.in = NewRecordLayerDTLS(c.conn, DirectionRead)
293+
c.out = NewRecordLayerDTLS(c.conn, DirectionWrite)
287294
c.hsCtx.hIn = NewHandshakeLayerDTLS(c.hsCtx, c.in)
288295
c.hsCtx.hOut = NewHandshakeLayerDTLS(c.hsCtx, c.out)
289296
c.hsCtx.timeoutMS = initialTimeout
290297
c.hsCtx.timers = newTimerSet()
291298
c.hsCtx.waitingNextFlight = true
292299
}
293-
c.in.label = c.label()
294-
c.out.label = c.label()
300+
c.in.SetLabel(c.label())
301+
c.out.SetLabel(c.label())
295302
c.hsCtx.hIn.nonblocking = c.config.NonBlocking
296303
return c
297304
}
@@ -598,15 +605,15 @@ func (c *Conn) takeAction(actionGeneric HandshakeAction) Alert {
598605
logf(logTypeHandshake, "%s Rekey with data still in handshake buffers", label)
599606
return AlertDecodeError
600607
}
601-
err := c.in.Rekey(action.epoch, action.KeySet.cipher, action.KeySet.key, action.KeySet.iv)
608+
err := c.in.Rekey(action.epoch, action.KeySet.Cipher, &action.KeySet)
602609
if err != nil {
603610
logf(logTypeHandshake, "%s Unable to rekey inbound: %v", label, err)
604611
return AlertInternalError
605612
}
606613

607614
case RekeyOut:
608615
logf(logTypeHandshake, "%s Rekeying out to %s: %+v", label, action.epoch.label(), action.KeySet)
609-
err := c.out.Rekey(action.epoch, action.KeySet.cipher, action.KeySet.key, action.KeySet.iv)
616+
err := c.out.Rekey(action.epoch, action.KeySet.Cipher, &action.KeySet)
610617
if err != nil {
611618
logf(logTypeHandshake, "%s Unable to rekey outbound: %v", label, err)
612619
return AlertInternalError
@@ -906,7 +913,7 @@ func (c *Conn) Writable() bool {
906913
}
907914

908915
// If we're a client in 0-RTT, then we're writable.
909-
if c.isClient && c.out.cipher.epoch == EpochEarlyData {
916+
if c.isClient && c.out.Epoch() == EpochEarlyData {
910917
return true
911918
}
912919

conn_test.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -322,11 +322,14 @@ func init() {
322322
}
323323
}
324324

325-
func assertKeySetEquals(t *testing.T, k1, k2 keySet) {
325+
func assertKeySetEquals(t *testing.T, k1, k2 KeySet) {
326326
t.Helper()
327327
// Assume cipher is the same
328-
assertByteEquals(t, k1.iv, k2.iv)
329-
assertByteEquals(t, k1.key, k2.key)
328+
assertTrue(t, len(k1.Keys) > 0, "assert that there are some keys")
329+
assertEquals(t, len(k1.Keys), len(k2.Keys))
330+
for k, v := range k1.Keys {
331+
assertByteEquals(t, v, k2.Keys[k])
332+
}
330333
}
331334

332335
func computeExporter(t *testing.T, c *Conn, label string, context []byte, length int) []byte {

crypto.go

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,13 @@ import (
2727

2828
var prng = rand.Reader
2929

30-
type aeadFactory func(key []byte) (cipher.AEAD, error)
30+
type AEADFactory func(key []byte) (cipher.AEAD, error)
3131

3232
type CipherSuiteParams struct {
33-
Suite CipherSuite
34-
Cipher aeadFactory // Cipher factory
35-
Hash crypto.Hash // Hash function
36-
KeyLen int // Key length in octets
37-
IvLen int // IV length in octets
33+
Suite CipherSuite
34+
Cipher AEADFactory // Cipher factory
35+
Hash crypto.Hash // Hash function
36+
KeyLengths map[string]int // This maps keys (the label used for HKDF-Expand-Label) to the length of the key needed.
3837
}
3938

4039
type signatureAlgorithm uint8
@@ -91,18 +90,16 @@ var (
9190

9291
cipherSuiteMap = map[CipherSuite]CipherSuiteParams{
9392
TLS_AES_128_GCM_SHA256: {
94-
Suite: TLS_AES_128_GCM_SHA256,
95-
Cipher: newAESGCM,
96-
Hash: crypto.SHA256,
97-
KeyLen: 16,
98-
IvLen: 12,
93+
Suite: TLS_AES_128_GCM_SHA256,
94+
Cipher: newAESGCM,
95+
Hash: crypto.SHA256,
96+
KeyLengths: map[string]int{labelForKey: 16, labelForIV: 12},
9997
},
10098
TLS_AES_256_GCM_SHA384: {
101-
Suite: TLS_AES_256_GCM_SHA384,
102-
Cipher: newAESGCM,
103-
Hash: crypto.SHA384,
104-
KeyLen: 32,
105-
IvLen: 12,
99+
Suite: TLS_AES_256_GCM_SHA384,
100+
Cipher: newAESGCM,
101+
Hash: crypto.SHA384,
102+
KeyLengths: map[string]int{labelForKey: 32, labelForIV: 12},
106103
},
107104
}
108105

@@ -604,19 +601,18 @@ func computeFinishedData(params CipherSuiteParams, baseKey []byte, input []byte)
604601
return mac.Sum(nil)
605602
}
606603

607-
type keySet struct {
608-
cipher aeadFactory
609-
key []byte
610-
iv []byte
604+
type KeySet struct {
605+
Cipher AEADFactory
606+
Keys map[string][]byte
611607
}
612608

613-
func makeTrafficKeys(params CipherSuiteParams, secret []byte) keySet {
609+
func makeTrafficKeys(params CipherSuiteParams, secret []byte) KeySet {
614610
logf(logTypeCrypto, "making traffic keys: secret=%x", secret)
615-
return keySet{
616-
cipher: params.Cipher,
617-
key: HkdfExpandLabel(params.Hash, secret, "key", []byte{}, params.KeyLen),
618-
iv: HkdfExpandLabel(params.Hash, secret, "iv", []byte{}, params.IvLen),
611+
ks := KeySet{Cipher: params.Cipher, Keys: make(map[string][]byte, len(params.KeyLengths))}
612+
for label, length := range params.KeyLengths {
613+
ks.Keys[label] = HkdfExpandLabel(params.Hash, secret, label, []byte{}, length)
619614
}
615+
return ks
620616
}
621617

622618
func MakeNewSelfSignedCert(name string, alg SignatureScheme) (crypto.Signer, *x509.Certificate, error) {

handshake-layer.go

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ func (h *HandshakeLayer) HandshakeMessageFromBody(body HandshakeMessageBody) (*H
120120
type HandshakeLayer struct {
121121
ctx *HandshakeContext // The handshake we are attached to
122122
nonblocking bool // Should we operate in nonblocking mode
123-
conn *RecordLayer // Used for reading/writing records
123+
conn RecordLayer // Used for reading/writing records
124124
frame *frameReader // The buffered frame reader
125125
datagram bool // Is this DTLS?
126126
msgSeq uint32 // The DTLS message sequence number
@@ -153,7 +153,7 @@ func (d handshakeLayerFrameDetails) frameLen(hdr []byte) (int, error) {
153153
return int(val), nil
154154
}
155155

156-
func NewHandshakeLayerTLS(c *HandshakeContext, r *RecordLayer) *HandshakeLayer {
156+
func NewHandshakeLayerTLS(c *HandshakeContext, r RecordLayer) *HandshakeLayer {
157157
h := HandshakeLayer{}
158158
h.ctx = c
159159
h.conn = r
@@ -163,7 +163,7 @@ func NewHandshakeLayerTLS(c *HandshakeContext, r *RecordLayer) *HandshakeLayer {
163163
return &h
164164
}
165165

166-
func NewHandshakeLayerDTLS(c *HandshakeContext, r *RecordLayer) *HandshakeLayer {
166+
func NewHandshakeLayerDTLS(c *HandshakeContext, r RecordLayer) *HandshakeLayer {
167167
h := HandshakeLayer{}
168168
h.ctx = c
169169
h.conn = r
@@ -174,8 +174,15 @@ func NewHandshakeLayerDTLS(c *HandshakeContext, r *RecordLayer) *HandshakeLayer
174174
}
175175

176176
func (h *HandshakeLayer) readRecord() error {
177-
logf(logTypeVerbose, "Trying to read record")
178-
pt, err := h.conn.readRecordAnyEpoch()
177+
var pt *TLSPlaintext
178+
var err error
179+
180+
if h.datagram {
181+
logf(logTypeVerbose, "Trying to read record")
182+
pt, err = h.conn.(*DefaultRecordLayer).ReadRecordAnyEpoch()
183+
} else {
184+
pt, err = h.conn.ReadRecord()
185+
}
179186
if err != nil {
180187
return err
181188
}
@@ -204,7 +211,7 @@ func (h *HandshakeLayer) readRecord() error {
204211
}
205212

206213
assert(h.ctx.hIn.conn != nil)
207-
if pt.epoch != h.ctx.hIn.conn.cipher.epoch {
214+
if pt.epoch != h.ctx.hIn.conn.Epoch() {
208215
// This is out of order but we're dropping it.
209216
// TODO([email protected]): If server, need to retransmit Finished.
210217
if pt.epoch == EpochClear || pt.epoch == EpochHandshakeData {
@@ -394,9 +401,13 @@ func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) {
394401
}
395402

396403
func (h *HandshakeLayer) QueueMessage(hm *HandshakeMessage) error {
397-
hm.cipher = h.conn.cipher
398-
h.queued = append(h.queued, hm)
399-
return nil
404+
if h.datagram {
405+
hm.cipher = h.conn.(*DefaultRecordLayer).cipher
406+
h.queued = append(h.queued, hm)
407+
return nil
408+
}
409+
_, err := h.WriteMessages([]*HandshakeMessage{hm})
410+
return err
400411
}
401412

402413
func (h *HandshakeLayer) SendQueuedMessages() (int, error) {
@@ -456,22 +467,30 @@ func (h *HandshakeLayer) writeFragment(hm *HandshakeMessage, start int, room int
456467
buf = body
457468
}
458469

470+
var err error
459471
if h.datagram {
460472
// Remember that we sent this.
461473
h.ctx.sentFragments = append(h.ctx.sentFragments, &SentHandshakeFragment{
462474
hm.seq,
463475
start,
464476
len(body),
465-
h.conn.cipher.combineSeq(true),
477+
h.conn.(*DefaultRecordLayer).cipher.combineSeq(true),
466478
false,
467479
})
480+
err = h.conn.(*DefaultRecordLayer).writeRecordWithPadding(
481+
&TLSPlaintext{
482+
contentType: RecordTypeHandshake,
483+
fragment: buf,
484+
},
485+
hm.cipher, 0)
486+
} else {
487+
err = h.conn.WriteRecord(
488+
&TLSPlaintext{
489+
contentType: RecordTypeHandshake,
490+
fragment: buf,
491+
})
468492
}
469-
return true, start + bodylen, h.conn.writeRecordWithPadding(
470-
&TLSPlaintext{
471-
contentType: RecordTypeHandshake,
472-
fragment: buf,
473-
},
474-
hm.cipher, 0)
493+
return true, start + bodylen, err
475494
}
476495

477496
func (h *HandshakeLayer) WriteMessage(hm *HandshakeMessage) (int, error) {

handshake-layer_test.go

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ func TestMessageFromBody(t *testing.T) {
154154
chValid := unhex(chValidHex)
155155

156156
b := bytes.NewBuffer(nil)
157-
h := NewHandshakeLayerTLS(&HandshakeContext{}, NewRecordLayerTLS(b, directionRead))
157+
h := NewHandshakeLayerTLS(&HandshakeContext{}, NewRecordLayerTLS(b, DirectionRead))
158158

159159
// Test successful conversion
160160
hm, err := h.HandshakeMessageFromBody(&chValidIn)
@@ -172,7 +172,7 @@ func TestMessageFromBody(t *testing.T) {
172172
func newHandshakeLayerFromBytes(d []byte) *HandshakeLayer {
173173
hc := &HandshakeContext{}
174174
b := bytes.NewBuffer(d)
175-
hc.hIn = NewHandshakeLayerTLS(hc, NewRecordLayerTLS(b, directionRead))
175+
hc.hIn = NewHandshakeLayerTLS(hc, NewRecordLayerTLS(b, DirectionRead))
176176
return hc.hIn
177177
}
178178

@@ -224,7 +224,7 @@ func TestReadHandshakeMessage(t *testing.T) {
224224
}
225225

226226
func testWriteHandshakeMessage(h *HandshakeLayer, hm *HandshakeMessage) error {
227-
hm.cipher = h.conn.cipher
227+
hm.cipher = h.conn.(*DefaultRecordLayer).cipher
228228
_, err := h.WriteMessage(hm)
229229
return err
230230
}
@@ -235,26 +235,26 @@ func TestWriteHandshakeMessage(t *testing.T) {
235235

236236
// Test successful write of single message
237237
b := bytes.NewBuffer(nil)
238-
h := NewHandshakeLayerTLS(&HandshakeContext{}, NewRecordLayerTLS(b, directionWrite))
238+
h := NewHandshakeLayerTLS(&HandshakeContext{}, NewRecordLayerTLS(b, DirectionWrite))
239239
err := testWriteHandshakeMessage(h, shortMessageIn)
240240
assertNotError(t, err, "Failed to write valid short message")
241241
assertByteEquals(t, b.Bytes(), short)
242242

243243
// Test successful write of single long message
244244
b = bytes.NewBuffer(nil)
245-
h = NewHandshakeLayerTLS(&HandshakeContext{}, NewRecordLayerTLS(b, directionWrite))
245+
h = NewHandshakeLayerTLS(&HandshakeContext{}, NewRecordLayerTLS(b, DirectionWrite))
246246
err = testWriteHandshakeMessage(h, longMessageIn)
247247
assertNotError(t, err, "Failed to write valid long message")
248248
assertByteEquals(t, b.Bytes(), long)
249249

250250
// Test write failure on message too large
251251
b = bytes.NewBuffer(nil)
252-
h = NewHandshakeLayerTLS(&HandshakeContext{}, NewRecordLayerTLS(b, directionWrite))
252+
h = NewHandshakeLayerTLS(&HandshakeContext{}, NewRecordLayerTLS(b, DirectionWrite))
253253
err = testWriteHandshakeMessage(h, tooLongMessageIn)
254254
assertError(t, err, "Wrote a message exceeding the length bound")
255255

256256
// Test write failure on underlying write failure
257-
h = NewHandshakeLayerTLS(&HandshakeContext{}, NewRecordLayerTLS(ErrorReadWriter{}, directionWrite))
257+
h = NewHandshakeLayerTLS(&HandshakeContext{}, NewRecordLayerTLS(ErrorReadWriter{}, DirectionWrite))
258258
err = testWriteHandshakeMessage(h, longMessageIn)
259259
assertError(t, err, "Write succeeded despite error in full fragment send")
260260
err = testWriteHandshakeMessage(h, shortMessageIn)
@@ -265,7 +265,7 @@ type testReassembleFixture struct {
265265
t *testing.T
266266
c HandshakeContext
267267
h *HandshakeLayer
268-
r *RecordLayer
268+
r *DefaultRecordLayer
269269
rd *pipeConn
270270
wr *pipeConn
271271
m0 *HandshakeMessage
@@ -298,7 +298,7 @@ func newTestReassembleFixture(t *testing.T) *testReassembleFixture {
298298
f.m1 = newHsFragment(m1, 1, 0, 2048)
299299
f.rd, f.wr = pipe()
300300

301-
f.r = NewRecordLayerDTLS(f.rd, directionRead)
301+
f.r = NewRecordLayerDTLS(f.rd, DirectionRead)
302302
f.h = NewHandshakeLayerDTLS(&f.c, f.r)
303303
f.c.hIn = f.h
304304
f.c.timers = newTimerSet()

0 commit comments

Comments
 (0)