@@ -2,7 +2,6 @@ package client
2
2
3
3
import (
4
4
"bufio"
5
- "errors"
6
5
"fmt"
7
6
"net"
8
7
"sync"
@@ -11,15 +10,10 @@ import (
11
10
"google.golang.org/protobuf/proto"
12
11
13
12
"github.com/mycontroller-org/esphome_api/pkg/api"
13
+ "github.com/mycontroller-org/esphome_api/pkg/connection"
14
14
types "github.com/mycontroller-org/esphome_api/pkg/types"
15
15
)
16
16
17
- // Error types
18
- var (
19
- ErrPassword = errors .New ("esphome_api: invalid password" )
20
- ErrCommunicationTimeout = errors .New ("esphome_api: communication timeout" )
21
- )
22
-
23
17
// Client struct.
24
18
type Client struct {
25
19
ID string
@@ -31,11 +25,12 @@ type Client struct {
31
25
lastMessageAt time.Time
32
26
handlerFunc func (proto.Message )
33
27
CommunicationTimeout time.Duration
28
+ apiConn connection.ApiConnection
34
29
}
35
30
36
- // Init func
37
- func Init (clientID , addr string , timeout time.Duration , handlerFunc func (proto.Message )) (* Client , error ) {
38
- conn , err := net .DialTimeout ("tcp" , addr , timeout )
31
+ // GetClient returns esphome api client
32
+ func GetClient (clientID , address , encryptionKey string , timeout time.Duration , handlerFunc func (proto.Message )) (* Client , error ) {
33
+ conn , err := net .DialTimeout ("tcp" , address , timeout )
39
34
if err != nil {
40
35
return nil , err
41
36
}
@@ -45,6 +40,11 @@ func Init(clientID, addr string, timeout time.Duration, handlerFunc func(proto.M
45
40
handlerFunc = func (msg proto.Message ) {}
46
41
}
47
42
43
+ apiConn , err := connection .GetConnection (conn , timeout , encryptionKey )
44
+ if err != nil {
45
+ return nil , err
46
+ }
47
+
48
48
c := & Client {
49
49
ID : clientID ,
50
50
conn : conn ,
@@ -53,7 +53,15 @@ func Init(clientID, addr string, timeout time.Duration, handlerFunc func(proto.M
53
53
stopChan : make (chan bool ),
54
54
handlerFunc : handlerFunc ,
55
55
CommunicationTimeout : timeout ,
56
+ apiConn : apiConn ,
56
57
}
58
+
59
+ // call handshake, used in encrypted connection
60
+ err = apiConn .Handshake ()
61
+ if err != nil {
62
+ return nil , err
63
+ }
64
+
57
65
go c .messageReader ()
58
66
return c , nil
59
67
}
@@ -103,7 +111,7 @@ func (c *Client) Login(password string) error {
103
111
}
104
112
connectResponse := message .(* api.ConnectResponse )
105
113
if connectResponse .InvalidPassword {
106
- return ErrPassword
114
+ return types . ErrPassword
107
115
}
108
116
109
117
return nil
@@ -181,7 +189,7 @@ func (c *Client) messageReader() {
181
189
182
190
func (c * Client ) getMessage () error {
183
191
var message proto.Message
184
- message , err := api . ReadMessage (c .reader )
192
+ message , err := c . apiConn . Read (c .reader )
185
193
if err == nil {
186
194
c .lastMessageAt = time .Now ()
187
195
// check waiting map
@@ -244,20 +252,7 @@ func (c *Client) handleInternal(message proto.Message) bool {
244
252
}
245
253
246
254
func (c * Client ) Send (message proto.Message ) error {
247
- packed , err := api .Marshal (message )
248
- if err != nil {
249
- return err
250
- }
251
- if err = c .conn .SetWriteDeadline (time .Now ().Add (c .CommunicationTimeout )); err != nil {
252
- return err
253
- }
254
- if _ , err = c .conn .Write (packed ); err != nil {
255
- return err
256
- }
257
- if err = c .conn .SetWriteDeadline (time.Time {}); err != nil {
258
- return err
259
- }
260
- return nil
255
+ return c .apiConn .Write (message )
261
256
}
262
257
263
258
func (c * Client ) SendAndWaitForResponse (message proto.Message , messageType uint64 ) (proto.Message , error ) {
@@ -276,7 +271,7 @@ func (c *Client) waitForMessage(messageType uint64) (proto.Message, error) {
276
271
case message := <- in :
277
272
return message , nil
278
273
case <- time .After (c .CommunicationTimeout ):
279
- return nil , ErrCommunicationTimeout
274
+ return nil , types . ErrCommunicationTimeout
280
275
}
281
276
}
282
277
0 commit comments