diff --git a/util/encoding/encoding_test.go b/util/encoding/encoding_test.go index f1a0cccdc..a7d230a64 100644 --- a/util/encoding/encoding_test.go +++ b/util/encoding/encoding_test.go @@ -2,6 +2,7 @@ package encoding import ( "bytes" + "io" "testing" "github.com/stretchr/testify/require" @@ -60,3 +61,40 @@ func TestScalarHexString(t *testing.T) { ErrFatal(err) require.True(t, sc.Equal(s2)) } + +// Tests for error cases +type MockFailingReader struct { + data []byte +} +type MockEmptyReader struct { + data []byte +} + +func (m *MockFailingReader) Read(p []byte) (n int, err error) { + return copy(p, m.data), io.EOF +} +func (m *MockEmptyReader) Read(p []byte) (n int, err error) { + return 0, nil +} + +func TestReadHexPointErrorInvalidHexEnc(t *testing.T) { + // Test case: invalid hex encoding + reader := bytes.NewReader([]byte("invalidhex")) + _, err := ReadHexPoint(s, reader) + require.Error(t, err, "Expected error when reading invalid hex encoding, but got nil") +} + +func TestReadHexPointErrorReaderFails(t *testing.T) { + // Test case: reader fails + mockReader1 := &MockFailingReader{data: []byte("abc")} + _, err := ReadHexPoint(s, mockReader1) + require.Error(t, err, "Expected error when reader fails, but got nil") +} + +func TestReadHexPointErrorNotEnoughBytes(t *testing.T) { + // Test case: not enough bytes from stream + mockReader2 := &MockEmptyReader{data: []byte("abc")} + _, err := ReadHexPoint(s, mockReader2) + require.Error(t, err, "Expected error when not enough bytes from stream, but got nil") + require.EqualError(t, err, "didn't get enough bytes from stream", "Expected error message: didn't get enough bytes from stream, but got %s", err.Error()) +} diff --git a/util/random/rand_test.go b/util/random/rand_test.go index 42b50ac03..36f62f76e 100644 --- a/util/random/rand_test.go +++ b/util/random/rand_test.go @@ -3,14 +3,18 @@ package random import ( "bytes" "crypto/rand" + "fmt" + "math/big" + "strconv" "strings" "testing" ) const size = 32 +const readerStream = "some io.Reader stream to be used for testing" func TestMixedEntropy(t *testing.T) { - r := strings.NewReader("some io.Reader stream to be used for testing") + r := strings.NewReader(readerStream) cipher := New(r, rand.Reader) src := make([]byte, size) @@ -57,13 +61,12 @@ func TestCryptoOnly(t *testing.T) { } func TestUserOnly(t *testing.T) { - seed := "some io.Reader stream to be used for testing" - cipher1 := New(strings.NewReader(seed)) + cipher1 := New(strings.NewReader(readerStream)) src := make([]byte, size) copy(src, []byte("hello")) dst1 := make([]byte, size) cipher1.XORKeyStream(dst1, src) - cipher2 := New(strings.NewReader(seed)) + cipher2 := New(strings.NewReader(readerStream)) dst2 := make([]byte, size) cipher2.XORKeyStream(dst2, src) if !bytes.Equal(dst1, dst2) { @@ -84,3 +87,63 @@ func TestIncorrectSize(t *testing.T) { dst := make([]byte, size+1) cipher.XORKeyStream(dst, src) } + +func TestBits(t *testing.T) { + testCases := []struct { + bitlen uint // input bit length + exact bool // whether the exact bit length should be enforced + }{ + {bitlen: 128, exact: false}, + {bitlen: 256, exact: true}, + {bitlen: 512, exact: false}, + {bitlen: 1024, exact: true}, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("bitlen: %d exact: %s", tc.bitlen, strconv.FormatBool(tc.exact)), func(t *testing.T) { + r := strings.NewReader(readerStream) + cipher := New(r, rand.Reader) + + bigIntBytes := Bits(tc.bitlen, tc.exact, cipher) + bigInt := new(big.Int).SetBytes(bigIntBytes) + + // Check if the bit length matches the expected length + expectedBitLen := tc.bitlen + if tc.exact && uint(bigInt.BitLen()) != expectedBitLen { + t.Errorf("Generated BigInt with exact bits doesn't match the expected bit length: got %d, expected %d", bigInt.BitLen(), expectedBitLen) + } else if !tc.exact && uint(bigInt.BitLen()) > expectedBitLen { + t.Errorf("Generated BigInt with more bits than maximum bit length: got %d, expected at most %d", bigInt.BitLen(), expectedBitLen) + } + }) + } +} + +func TestInt(t *testing.T) { + testCases := []struct { + modulusBitLen uint // Bit length of the modulus + }{ + {128}, + {256}, + {512}, + {1024}, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("modulusBitlen: %d", tc.modulusBitLen), func(t *testing.T) { + modulus, err := rand.Prime(rand.Reader, int(tc.modulusBitLen)) + if err != nil { + t.Fatalf("Failed to generate random prime: %v", err) + } + + r := strings.NewReader(readerStream) + cipher := New(r, rand.Reader) + + randomInt := Int(modulus, cipher) + + // Check if the generated BigInt is less than the modulus + if randomInt.Cmp(modulus) >= 0 { + t.Errorf("Generated BigInt %v is not less than the modulus %v", randomInt, modulus) + } + }) + } +}