Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 17 additions & 12 deletions src/crypto/cipher/benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,12 @@ func BenchmarkAESGCM(b *testing.B) {
}
}

func benchmarkAESStream(b *testing.B, mode func(cipher.Block, []byte) cipher.Stream, buf []byte) {
func benchmarkAESStream(b *testing.B, mode func(cipher.Block, []byte) cipher.Stream, buf []byte, keySize int) {
b.SetBytes(int64(len(buf)))

var key [16]byte
key := make([]byte, keySize)
var iv [16]byte
aes, _ := aes.NewCipher(key[:])
aes, _ := aes.NewCipher(key)
stream := mode(aes, iv[:])

b.ResetTimer()
Expand All @@ -87,15 +87,20 @@ const almost1K = 1024 - 5
const almost8K = 8*1024 - 5

// BenchmarkAESCTR drives benchmarkAESStream with cipher.NewCTR over zeroed
// buffers of 50, almost1K and almost8K bytes, reported as the sub-benchmarks
// "50", "1K" and "8K".
//
// NOTE(review): this span comes from a rendered diff, so removed and added
// lines appear side by side. The three top-level b.Run calls pass three
// arguments to benchmarkAESStream and look like the pre-change lines; the
// loop over 128/192/256-bit keys, which passes keySize as a fourth argument,
// looks like their replacement. Confirm against the real file.
func BenchmarkAESCTR(b *testing.B) {
b.Run("50", func(b *testing.B) {
benchmarkAESStream(b, cipher.NewCTR, make([]byte, 50))
})
b.Run("1K", func(b *testing.B) {
benchmarkAESStream(b, cipher.NewCTR, make([]byte, almost1K))
})
b.Run("8K", func(b *testing.B) {
benchmarkAESStream(b, cipher.NewCTR, make([]byte, almost8K))
})
// One group of sub-benchmarks per key length; the group is named after the
// key length in bits ("128", "192", "256").
for _, keyBits := range []int{128, 192, 256} {
// benchmarkAESStream takes the key length in bytes.
keySize := keyBits / 8
b.Run(strconv.Itoa(keyBits), func(b *testing.B) {
b.Run("50", func(b *testing.B) {
benchmarkAESStream(b, cipher.NewCTR, make([]byte, 50), keySize)
})
b.Run("1K", func(b *testing.B) {
benchmarkAESStream(b, cipher.NewCTR, make([]byte, almost1K), keySize)
})
b.Run("8K", func(b *testing.B) {
benchmarkAESStream(b, cipher.NewCTR, make([]byte, almost8K), keySize)
})
})
}
}

func BenchmarkAESCBCEncrypt1K(b *testing.B) {
Expand Down
64 changes: 64 additions & 0 deletions src/crypto/cipher/ctr_aes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"crypto/internal/boring"
"crypto/internal/cryptotest"
fipsaes "crypto/internal/fips140/aes"
"encoding/binary"
"encoding/hex"
"fmt"
"math/rand"
Expand Down Expand Up @@ -117,6 +118,60 @@ func makeTestingCiphers(aesBlock cipher.Block, iv []byte) (genericCtr, multibloc
return cipher.NewCTR(wrap(aesBlock), iv), cipher.NewCTR(aesBlock, iv)
}

// TestCTR_AES_blocks8FastPathMatchesGeneric checks that the overflow-aware
// eight-block code path yields exactly the same keystream as the generic CTR
// implementation for a handful of IVs, including ones whose low 64 counter
// bits are about to wrap.
func TestCTR_AES_blocks8FastPathMatchesGeneric(t *testing.T) {
	block, err := aes.NewCipher(make([]byte, aes.BlockSize))
	if err != nil {
		t.Fatal(err)
	}
	if _, isFIPS := block.(*fipsaes.Block); !isFIPS {
		t.Skip("requires crypto/internal/fips140/aes")
	}

	// Eight blocks of zero input: XORing zeros with the keystream leaves the
	// raw keystream in the output buffer.
	zeros := make([]byte, 8*aes.BlockSize)

	const maxLo = ^uint64(0)

	for _, tc := range []struct {
		name   string
		hi, lo uint64
	}{
		{name: "Zero", hi: 0, lo: 0},
		// lo+7 == maxLo: the last of the eight counters still fits.
		{name: "NearOverflowMinus7", hi: 1, lo: maxLo - 7},
		// lo+7 wraps: the eighth counter carries into hi.
		{name: "NearOverflowMinus6", hi: 2, lo: maxLo - 6},
		// The very next counter after the IV already carries into hi.
		{name: "Overflow", hi: 0, lo: maxLo},
	} {
		t.Run(tc.name, func(t *testing.T) {
			// The IV is the 128-bit big-endian counter hi:lo.
			var iv [aes.BlockSize]byte
			binary.BigEndian.PutUint64(iv[:8], tc.hi)
			binary.BigEndian.PutUint64(iv[8:], tc.lo)

			generic, multiblock := makeTestingCiphers(block, iv[:])

			want := make([]byte, len(zeros))
			got := make([]byte, len(zeros))
			generic.XORKeyStream(want, zeros)
			multiblock.XORKeyStream(got, zeros)

			if bytes.Equal(got, want) {
				return
			}

			// Decrypting each keystream block recovers the counter that
			// produced it, which makes a mismatch much easier to read.
			t.Fatalf("mismatch for iv %#x:%#x\n"+
				"asm keystream: %x\n"+
				"gen keystream: %x\n"+
				"asm counters: %x\n"+
				"gen counters: %x",
				tc.hi, tc.lo, got, want,
				extractCounters(block, got),
				extractCounters(block, want))
		})
	}
}

func randBytes(t *testing.T, r *rand.Rand, count int) []byte {
t.Helper()
buf := make([]byte, count)
Expand Down Expand Up @@ -297,3 +352,12 @@ func TestCTR_AES_multiblock_XORKeyStreamAt(t *testing.T) {
})
}
}

// extractCounters recovers the counter blocks behind a CTR keystream by
// passing each block-sized chunk of keystream through block.Decrypt. The
// result has the same length as keystream, and keystream itself is left
// untouched. keystream is expected to be a whole number of blocks long.
func extractCounters(block cipher.Block, keystream []byte) []byte {
	counters := make([]byte, len(keystream))
	step := block.BlockSize()
	for off := 0; off < len(keystream); off += step {
		end := off + step
		block.Decrypt(counters[off:end], keystream[off:end])
	}
	return counters
}
84 changes: 72 additions & 12 deletions src/crypto/internal/fips140/aes/_asm/ctr/ctr_amd64_asm.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,79 @@ func ctrBlocks(numBlocks int) {
bswap := XMM()
MOVOU(bswapMask(), bswap)

blocks := make([]VecVirtual, 0, numBlocks)
blocks := make([]VecVirtual, numBlocks)

// For the 8-block case we optimize counter generation. We build the first
// counter as usual, then check whether the remaining seven increments will
// overflow. When they do not (the common case) we keep the work entirely in
// XMM registers to avoid expensive general-purpose -> XMM moves. Otherwise
// we fall back to the traditional scalar path.
if numBlocks == 8 {
for i := range blocks {
blocks[i] = XMM()
}

// Lay out counter block plaintext.
for i := 0; i < numBlocks; i++ {
x := XMM()
blocks = append(blocks, x)

MOVQ(ivlo, x)
PINSRQ(Imm(1), ivhi, x)
PSHUFB(bswap, x)
if i < numBlocks-1 {
ADDQ(Imm(1), ivlo)
ADCQ(Imm(0), ivhi)
base := XMM()
tmp := GP64()
addVec := XMM()

MOVQ(ivlo, blocks[0])
PINSRQ(Imm(1), ivhi, blocks[0])
MOVAPS(blocks[0], base)
PSHUFB(bswap, blocks[0])

// Check whether any of these eight counters will overflow.
MOVQ(ivlo, tmp)
ADDQ(Imm(uint64(numBlocks-1)), tmp)
slowLabel := fmt.Sprintf("ctr%d_slow", numBlocks)
doneLabel := fmt.Sprintf("ctr%d_done", numBlocks)
JC(LabelRef(slowLabel))

// Fast branch: create an XMM increment vector containing the value 1.
// Adding it to the base counter yields each subsequent counter.
XORQ(tmp, tmp)
INCQ(tmp)
PXOR(addVec, addVec)
PINSRQ(Imm(0), tmp, addVec)

for i := 1; i < numBlocks; i++ {
PADDQ(addVec, base)
MOVAPS(base, blocks[i])
}
JMP(LabelRef(doneLabel))

Label(slowLabel)
ADDQ(Imm(1), ivlo)
ADCQ(Imm(0), ivhi)
for i := 1; i < numBlocks; i++ {
MOVQ(ivlo, blocks[i])
PINSRQ(Imm(1), ivhi, blocks[i])
if i < numBlocks-1 {
ADDQ(Imm(1), ivlo)
ADCQ(Imm(0), ivhi)
}
}

Label(doneLabel)

// Convert little-endian counters to big-endian after the branch since
// both paths share the same shuffle sequence.
for i := 1; i < numBlocks; i++ {
PSHUFB(bswap, blocks[i])
}
} else {
// Lay out counter block plaintext.
for i := 0; i < numBlocks; i++ {
x := XMM()
blocks[i] = x

MOVQ(ivlo, x)
PINSRQ(Imm(1), ivhi, x)
PSHUFB(bswap, x)
if i < numBlocks-1 {
ADDQ(Imm(1), ivlo)
ADCQ(Imm(0), ivhi)
}
}
}

Expand Down
39 changes: 33 additions & 6 deletions src/crypto/internal/fips140/aes/ctr_amd64.s
Original file line number Diff line number Diff line change
Expand Up @@ -286,41 +286,68 @@ TEXT ·ctrBlocks8Asm(SB), $0-48
MOVOU bswapMask<>+0(SB), X0
MOVQ SI, X1
PINSRQ $0x01, DI, X1
MOVAPS X1, X8
PSHUFB X0, X1
MOVQ SI, R8
ADDQ $0x07, R8
JC ctr8_slow
XORQ R8, R8
INCQ R8
PXOR X9, X9
PINSRQ $0x00, R8, X9
PADDQ X9, X8
MOVAPS X8, X2
PADDQ X9, X8
MOVAPS X8, X3
PADDQ X9, X8
MOVAPS X8, X4
PADDQ X9, X8
MOVAPS X8, X5
PADDQ X9, X8
MOVAPS X8, X6
PADDQ X9, X8
MOVAPS X8, X7
PADDQ X9, X8
MOVAPS X8, X8
JMP ctr8_done

ctr8_slow:
ADDQ $0x01, SI
ADCQ $0x00, DI
MOVQ SI, X2
PINSRQ $0x01, DI, X2
PSHUFB X0, X2
ADDQ $0x01, SI
ADCQ $0x00, DI
MOVQ SI, X3
PINSRQ $0x01, DI, X3
PSHUFB X0, X3
ADDQ $0x01, SI
ADCQ $0x00, DI
MOVQ SI, X4
PINSRQ $0x01, DI, X4
PSHUFB X0, X4
ADDQ $0x01, SI
ADCQ $0x00, DI
MOVQ SI, X5
PINSRQ $0x01, DI, X5
PSHUFB X0, X5
ADDQ $0x01, SI
ADCQ $0x00, DI
MOVQ SI, X6
PINSRQ $0x01, DI, X6
PSHUFB X0, X6
ADDQ $0x01, SI
ADCQ $0x00, DI
MOVQ SI, X7
PINSRQ $0x01, DI, X7
PSHUFB X0, X7
ADDQ $0x01, SI
ADCQ $0x00, DI
MOVQ SI, X8
PINSRQ $0x01, DI, X8

ctr8_done:
PSHUFB X0, X2
PSHUFB X0, X3
PSHUFB X0, X4
PSHUFB X0, X5
PSHUFB X0, X6
PSHUFB X0, X7
PSHUFB X0, X8
MOVUPS (CX), X0
PXOR X0, X1
Expand Down