Skip to content

Commit 3359004

Browse files
committed
support encrypted connection
Signed-off-by: Jeeva Kandasamy <[email protected]>
1 parent 611c57d commit 3359004

File tree

10 files changed

+486
-36
lines changed

10 files changed

+486
-36
lines changed

.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
/esphome_config
2+
/.esphome/
3+
/secrets.yaml

examples/commom.go

+13-6
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,16 @@ import (
1111
)
1212

1313
const (
14-
EnvHostAddress = "ESPHOME_ADDRESS"
15-
EnvPassword = "ESPHOME_PASSWORD"
14+
EnvHostAddress = "ESPHOME_ADDRESS"
15+
EnvPassword = "ESPHOME_PASSWORD"
16+
EnvEncryptionKey = "ESPHOME_ENCRYPTION_KEY"
1617
)
1718

1819
var (
19-
HostAddressFlag = flag.String("address", "", "esphome node hostname or IP with port. example: my_esphome.local:6053")
20-
PasswordFlag = flag.String("password", "", "esphome node API password")
21-
TimeoutFlag = flag.Duration("timeout", 10*time.Second, "communication timeout")
20+
HostAddressFlag = flag.String("address", "", "esphome node hostname or IP with port. example: my_esphome.local:6053")
21+
PasswordFlag = flag.String("password", "", "esphome node API password")
22+
EncryptionKeyFlag = flag.String("encryption-key", "", "esphome node API encryption key")
23+
TimeoutFlag = flag.Duration("timeout", 10*time.Second, "communication timeout")
2224
)
2325

2426
func GetClient(handlerFunc func(msg proto.Message)) (*esphome.Client, error) {
@@ -38,11 +40,16 @@ func GetClient(handlerFunc func(msg proto.Message)) (*esphome.Client, error) {
3840
*PasswordFlag = os.Getenv(EnvPassword)
3941
}
4042

43+
// update encryption key
44+
if *EncryptionKeyFlag == "" {
45+
*EncryptionKeyFlag = os.Getenv(EnvEncryptionKey)
46+
}
47+
4148
if handlerFunc == nil {
4249
handlerFunc = handlerFuncImpl
4350
}
4451

45-
client, err := esphome.Init("mycontroller.org", *HostAddressFlag, *TimeoutFlag, handlerFunc)
52+
client, err := esphome.GetClient("mycontroller.org", *HostAddressFlag, *EncryptionKeyFlag, *TimeoutFlag, handlerFunc)
4653
if err != nil {
4754
return nil, err
4855
}

go.mod

+9-1
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,12 @@ module github.com/mycontroller-org/esphome_api
22

33
go 1.18
44

5-
require google.golang.org/protobuf v1.28.0
5+
require (
6+
github.com/flynn/noise v1.0.1-0.20220214164934-d803f5c4b0f4
7+
google.golang.org/protobuf v1.28.0
8+
)
9+
10+
require (
11+
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 // indirect
12+
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 // indirect
13+
)

go.sum

+17
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,25 @@
1+
github.com/flynn/noise v1.0.1-0.20220214164934-d803f5c4b0f4 h1:6pcIWmKkQZdpPjs/pD9OLt0NwftBozNE0Nm5zMCG2C4=
2+
github.com/flynn/noise v1.0.1-0.20220214164934-d803f5c4b0f4/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
13
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
24
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
35
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
6+
github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI=
7+
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
8+
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
9+
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
10+
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
11+
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 h1:It14KIkyBFYkHkwZ7k45minvA9aorojkyjGk9KJ5B/w=
12+
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
13+
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
14+
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 h1:nxC68pudNYkKU6jWhgrqdreuFiOQWj1Fs7T3VrH4Pjw=
15+
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
16+
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
17+
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
18+
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
419
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
520
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
621
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
722
google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw=
823
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
24+
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
25+
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=

pkg/api/custom_marshal.go renamed to pkg/api/api_frame_helper_default.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import (
1010
"google.golang.org/protobuf/proto"
1111
)
1212

13-
func Marshal(message proto.Message) ([]byte, error) {
13+
func Marshal_(message proto.Message) ([]byte, error) {
1414
messageBytes, err := proto.Marshal(message)
1515
if err != nil {
1616
return nil, err
@@ -36,7 +36,7 @@ func Marshal(message proto.Message) ([]byte, error) {
3636
return bytesPack[:index], nil
3737
}
3838

39-
func ReadMessage(reader *bufio.Reader) (proto.Message, error) {
39+
func Unmarshal_(reader *bufio.Reader) (proto.Message, error) {
4040
firstByte, err := reader.ReadByte()
4141
if err != nil {
4242
return nil, err

pkg/client/client.go

+22-27
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package client
22

33
import (
44
"bufio"
5-
"errors"
65
"fmt"
76
"net"
87
"sync"
@@ -11,15 +10,10 @@ import (
1110
"google.golang.org/protobuf/proto"
1211

1312
"github.com/mycontroller-org/esphome_api/pkg/api"
13+
"github.com/mycontroller-org/esphome_api/pkg/connection"
1414
types "github.com/mycontroller-org/esphome_api/pkg/types"
1515
)
1616

17-
// Error types
18-
var (
19-
ErrPassword = errors.New("esphome_api: invalid password")
20-
ErrCommunicationTimeout = errors.New("esphome_api: communication timeout")
21-
)
22-
2317
// Client struct.
2418
type Client struct {
2519
ID string
@@ -31,11 +25,12 @@ type Client struct {
3125
lastMessageAt time.Time
3226
handlerFunc func(proto.Message)
3327
CommunicationTimeout time.Duration
28+
apiConn connection.ApiConnection
3429
}
3530

36-
// Init func
37-
func Init(clientID, addr string, timeout time.Duration, handlerFunc func(proto.Message)) (*Client, error) {
38-
conn, err := net.DialTimeout("tcp", addr, timeout)
31+
// GetClient returns esphome api client
32+
func GetClient(clientID, address, encryptionKey string, timeout time.Duration, handlerFunc func(proto.Message)) (*Client, error) {
33+
conn, err := net.DialTimeout("tcp", address, timeout)
3934
if err != nil {
4035
return nil, err
4136
}
@@ -45,6 +40,11 @@ func Init(clientID, addr string, timeout time.Duration, handlerFunc func(proto.M
4540
handlerFunc = func(msg proto.Message) {}
4641
}
4742

43+
apiConn, err := connection.GetConnection(conn, timeout, encryptionKey)
44+
if err != nil {
45+
return nil, err
46+
}
47+
4848
c := &Client{
4949
ID: clientID,
5050
conn: conn,
@@ -53,7 +53,15 @@ func Init(clientID, addr string, timeout time.Duration, handlerFunc func(proto.M
5353
stopChan: make(chan bool),
5454
handlerFunc: handlerFunc,
5555
CommunicationTimeout: timeout,
56+
apiConn: apiConn,
5657
}
58+
59+
// call handshake, used in encrypted connection
60+
err = apiConn.Handshake()
61+
if err != nil {
62+
return nil, err
63+
}
64+
5765
go c.messageReader()
5866
return c, nil
5967
}
@@ -103,7 +111,7 @@ func (c *Client) Login(password string) error {
103111
}
104112
connectResponse := message.(*api.ConnectResponse)
105113
if connectResponse.InvalidPassword {
106-
return ErrPassword
114+
return types.ErrPassword
107115
}
108116

109117
return nil
@@ -181,7 +189,7 @@ func (c *Client) messageReader() {
181189

182190
func (c *Client) getMessage() error {
183191
var message proto.Message
184-
message, err := api.ReadMessage(c.reader)
192+
message, err := c.apiConn.Read(c.reader)
185193
if err == nil {
186194
c.lastMessageAt = time.Now()
187195
// check waiting map
@@ -244,20 +252,7 @@ func (c *Client) handleInternal(message proto.Message) bool {
244252
}
245253

246254
func (c *Client) Send(message proto.Message) error {
247-
packed, err := api.Marshal(message)
248-
if err != nil {
249-
return err
250-
}
251-
if err = c.conn.SetWriteDeadline(time.Now().Add(c.CommunicationTimeout)); err != nil {
252-
return err
253-
}
254-
if _, err = c.conn.Write(packed); err != nil {
255-
return err
256-
}
257-
if err = c.conn.SetWriteDeadline(time.Time{}); err != nil {
258-
return err
259-
}
260-
return nil
255+
return c.apiConn.Write(message)
261256
}
262257

263258
func (c *Client) SendAndWaitForResponse(message proto.Message, messageType uint64) (proto.Message, error) {
@@ -276,7 +271,7 @@ func (c *Client) waitForMessage(messageType uint64) (proto.Message, error) {
276271
case message := <-in:
277272
return message, nil
278273
case <-time.After(c.CommunicationTimeout):
279-
return nil, ErrCommunicationTimeout
274+
return nil, types.ErrCommunicationTimeout
280275
}
281276
}
282277

0 commit comments

Comments
 (0)