Skip to content

Commit 6498e68

Browse files
committed
misc fixes from code review
1 parent ccea5cb commit 6498e68

File tree

16 files changed

+212
-87
lines changed

16 files changed

+212
-87
lines changed

.env

+3-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ WAVE_ENV=development
66
WAVE_TLS=false
77
WAVE_DB_HOSTNAME=127.0.0.1
88
WAVE_DB_USERNAME=postgres
9-
WAVE_DB_PASSWORD=pastgres
9+
WAVE_DB_PASSWORD=postgres
10+
WAVE_DB_ADMIN_USERNAME=postgres
11+
WAVE_DB_ADMIN_PASSWORD=postgres
1012
WAVE_DB_NAME=wave_dev
1113
WAVE_DB_TLS=disable

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ You'll need `npm`, `go`, and `docker-compose` available.
88

99
### Installing dependencies
1010

11-
Install [reflex](https://github.com/cespare/reflex), [forego](https://github.com/ddollar/forego), and [go-bindata](https://github.com/jteeuwen/go-bindata), run `go get -t` and `npm install`.
11+
Install [reflex](https://github.com/cespare/reflex), [forego](https://github.com/ddollar/forego), and [go-bindata](https://github.com/jteeuwen/go-bindata), run `make embedded-assets`, `go get -t`, and `npm install`.
1212

1313
```
1414
make deps

controllers/controllers.go

+4-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@ func init() {
1414
for _, client := range VisualClients {
1515
err := client.WriteJSON(event)
1616
if err != nil {
17-
log.Warn(err)
17+
log.WithFields(log.Fields{
18+
"at": "controllers.init",
19+
"error": err.Error(),
20+
}).Warn("error writing visual update to client")
1821
}
1922
}
2023
}

controllers/router.go

+3
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package controllers
33
import (
44
"github.com/gin-gonic/contrib/static"
55
"github.com/gin-gonic/gin"
6+
"github.com/hkparker/Wave/engines/visualizer"
67
"github.com/hkparker/Wave/helpers"
78
"github.com/hkparker/Wave/middleware"
89
)
@@ -51,6 +52,8 @@ func NewAPI() *gin.Engine {
5152
}
5253

5354
func NewCollector() *gin.Engine {
55+
visualizer.Load()
56+
5457
collector := gin.Default()
5558
collector.GET("/frames", pollCollector)
5659
return collector

controllers/users_test.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,8 @@ func TestAdminCanAssignUserPassword(t *testing.T) {
336336
assert.Equal(200, resp.StatusCode)
337337
user.Reload()
338338
assert.Equal(true, user.ValidAuthentication("1234"))
339-
//assert.Equal(0, len(user.Sessions))
339+
user.DestroyAllSessions()
340+
assert.Equal(0, len(user.Sessions))
340341
}
341342

342343
func TestUserCannotAssignUserPassword(t *testing.T) {

controllers/visualizer.go

+4
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ func streamVisualization(c *gin.Context) {
2525
for _, event := range visualizer.CatchupEvents() {
2626
err := conn.WriteJSON(event)
2727
if err != nil {
28+
log.WithFields(log.Fields{
29+
"at": "controllers.streamVisualization",
30+
"error": err.Error(),
31+
}).Error("error writing catch-up event")
2832
}
2933
}
3034
VisualClients[id] = conn

engines/ids/ids.go

+26-7
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
package ids
22

33
import (
4+
"encoding/json"
45
"fmt"
56
log "github.com/Sirupsen/logrus"
67
"github.com/hkparker/Wave/helpers"
78
"github.com/hkparker/Wave/models"
89
"github.com/robertkrimen/otto"
10+
"sync"
911
)
1012

1113
var VMs = make(map[string][]*otto.Otto, 0)
@@ -15,11 +17,19 @@ var Alerts = make(chan models.Alert, 0)
1517
func init() {
1618
go processAlerts()
1719
go prepareVMs()
18-
buildVMs()
1920
}
2021

2122
var alerting_function = func(call otto.FunctionCall) otto.Value {
22-
Alerts <- models.Alert{} //call.Argument(0).String()
23+
new_alert := models.Alert{}
24+
err := json.Unmarshal([]byte(call.Argument(0).String()), new_alert)
25+
if err != nil {
26+
log.WithFields(log.Fields{
27+
"at": "ids.alerting_function",
28+
"error": err.Error(),
29+
}).Error("bad alert from rule")
30+
} else {
31+
Alerts <- new_alert
32+
}
2333
return otto.Value{}
2434
}
2535

@@ -59,7 +69,11 @@ func buildVMs() (vm_set []*otto.Otto) {
5969
vm := otto.New()
6070
_, err := vm.Run(string(rule_data))
6171
if err != nil {
62-
log.Error(err)
72+
log.WithFields(log.Fields{
73+
"at": "ids.buildVMs",
74+
"file": rule_file,
75+
"error": err.Error(),
76+
}).Error("error loading rule data into VM")
6377
}
6478
vm.Set("alert", alerting_function)
6579
vm_set = append(vm_set, vm)
@@ -79,10 +93,15 @@ func Insert(frame string, parsed models.Wireless80211Frame) {
7993
vm_set = <-NewVMs
8094
VMs[parsed.Interface] = vm_set
8195
}
96+
var evals sync.WaitGroup
8297
for _, vm := range vm_set {
83-
_, err := vm.Run(fmt.Sprintf("evaluate(%s)", frame))
84-
if err != nil {
85-
log.Error(err)
86-
}
98+
evals.Add(1)
99+
go func(vm *otto.Otto) {
100+
_, err := vm.Run(fmt.Sprintf("evaluate(%s)", frame))
101+
if err != nil {
102+
log.WithFields(log.Fields{}).Error(err)
103+
}
104+
}(vm)
87105
}
106+
evals.Wait()
88107
}

engines/ids/rules/example-rule.js

+7
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,11 @@ var manifest = {
77
function evaluate(frame) {
88
// We can pretty print each from by logging JSON.stringify
99
//console.log(JSON.stringify(frame, null, 2))
10+
11+
// If this frame results in an IDS alert, we send an
12+
alert(JSON.stringify({
13+
"Title": "Example rule got a frame",
14+
"Rule": manifest.Name,
15+
"Severity": "low"
16+
}))
1017
}

engines/visualizer/visualizer.go

+19-2
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,26 @@ import (
77

88
var Devices = make(map[string]models.Device)
99
var DevicesMux sync.Mutex
10+
var Networks = make(map[string][]models.Network)
11+
var NetworksMux sync.Mutex
1012

11-
func init() {
12-
// Load devices from DB
13+
func Load() {
14+
DevicesMux.Lock()
15+
defer DevicesMux.Unlock()
16+
NetworksMux.Lock()
17+
defer NetworksMux.Unlock()
18+
19+
all_devices := make([]models.Device, 0)
20+
models.Orm.Find(&all_devices)
21+
for _, device := range all_devices {
22+
Devices[device.MAC] = device
23+
}
24+
25+
all_networks := make([]models.Network, 0)
26+
models.Orm.Find(&all_networks)
27+
for _, network := range all_networks {
28+
Networks[network.SSID] = append(Networks[network.SSID], network)
29+
}
1330
}
1431

1532
func Insert(frame models.Wireless80211Frame) {

helpers/database.go

+4
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ import (
88
var DBHostname string
99
var DBUsername string
1010
var DBPassword string
11+
var DBAdminUsername string
12+
var DBAdminPassword string
1113
var DBName string
1214
var DBTLS string
1315

@@ -18,11 +20,13 @@ func setDatabase() {
1820
}
1921

2022
DBUsername = os.Getenv("WAVE_DB_USERNAME")
23+
DBAdminUsername = os.Getenv("WAVE_DB_ADMIN_USERNAME")
2124
if DBUsername == "" {
2225
log.Fatal("WAVE_DB_USERNAME envar must be provided")
2326
}
2427

2528
DBPassword = os.Getenv("WAVE_DB_PASSWORD")
29+
DBAdminPassword = os.Getenv("WAVE_DB_ADMIN_PASSWORD")
2630
if DBPassword == "" {
2731
log.Fatal("WAVE_DB_PASSWORD envar must be provided")
2832
}

main.go

+1-6
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,7 @@ import (
1111
func main() {
1212
// Setup environment
1313
helpers.Setup()
14-
models.Connect(
15-
helpers.DBUsername,
16-
helpers.DBPassword,
17-
helpers.DBName,
18-
helpers.DBTLS,
19-
)
14+
models.Connect()
2015

2116
// Start Collector server
2217
go helpers.RunTLS(

models/collector.go

+24-6
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,10 @@ func Collectors() (collectors []Collector, err error) {
4444
func CreateCollector(name string) (collector Collector, err error) {
4545
cert_data, key_data, err := newCollectorKeys()
4646
if err != nil {
47-
log.Errorf("failed to create collector: %s", err)
47+
log.WithFields(log.Fields{
48+
"at": "models.CreateCollector",
49+
"error": err.Error(),
50+
}).Error("failed to create collector")
4851
return
4952
}
5053
collector = Collector{
@@ -60,7 +63,10 @@ func newCollectorKeys() (cert_data []byte, key_data []byte, err error) {
6063
api_cert := APITLSCertificate()
6164
ca, err := x509.ParseCertificate(api_cert.Certificate[0])
6265
if err != nil {
63-
log.Errorf("error parsing API TLS certificate for new collector: %s", err)
66+
log.WithFields(log.Fields{
67+
"at": "models.newColletorKeys",
68+
"error": err.Error(),
69+
}).Error("error parsing API TLS certificate for new collector")
6470
return
6571
}
6672
ca_key := api_cert.PrivateKey
@@ -78,21 +84,30 @@ func newCollectorKeys() (cert_data []byte, key_data []byte, err error) {
7884
}
7985
collector_priv, err := rsa.GenerateKey(rand.Reader, 2048)
8086
if err != nil {
81-
log.Errorf("error generating private key for new collector: %s", err)
87+
log.WithFields(log.Fields{
88+
"at": "models.newColletorKeys",
89+
"error": err.Error(),
90+
}).Error("error generating private key for new collector")
8291
return
8392
}
8493
collector_pub := &collector_priv.PublicKey
8594
collector_cert_data, err := x509.CreateCertificate(rand.Reader, collector_cert, ca, collector_pub, ca_key)
8695
if err != nil {
87-
log.Errorf("error creating collector certificate: %s", err)
96+
log.WithFields(log.Fields{
97+
"at": "models.newColletorKeys",
98+
"error": err.Error(),
99+
}).Error("error creating collector certificate")
88100
return
89101
}
90102

91103
// Create PEM encoding of certificate
92104
var cert_buffer bytes.Buffer
93105
err = pem.Encode(&cert_buffer, &pem.Block{Type: "CERTIFICATE", Bytes: collector_cert_data})
94106
if err != nil {
95-
log.Errorf("could not PEM encode collector certificate data: %s", err)
107+
log.WithFields(log.Fields{
108+
"at": "models.newColletorKeys",
109+
"error": err.Error(),
110+
}).Error("could not PEM encode collector certificate data")
96111
return
97112
}
98113
cert_data = cert_buffer.Bytes()
@@ -101,7 +116,10 @@ func newCollectorKeys() (cert_data []byte, key_data []byte, err error) {
101116
var key_buffer bytes.Buffer
102117
err = pem.Encode(&key_buffer, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(collector_priv)})
103118
if err != nil {
104-
log.Errorf("could not PEM encode collector key data: %s", err)
119+
log.WithFields(log.Fields{
120+
"at": "models.newColletorKeys",
121+
"error": err.Error(),
122+
}).Error("could not PEM encode collector key data")
105123
return
106124
}
107125
key_data = key_buffer.Bytes()

models/database.go

+37-31
Original file line numberDiff line numberDiff line change
@@ -3,51 +3,57 @@ package models
33
import (
44
"fmt"
55
log "github.com/Sirupsen/logrus"
6+
"github.com/hkparker/Wave/helpers"
67
"github.com/jinzhu/gorm"
78
_ "github.com/jinzhu/gorm/dialects/postgres"
89
_ "github.com/jinzhu/gorm/dialects/sqlite"
910
)
1011

1112
var Orm *gorm.DB
1213

13-
func Connect(db_username, db_password, db_name, db_ssl string) {
14-
db_check_args := fmt.Sprintf(
15-
"user=%s password=%s sslmode=%s",
16-
db_username,
17-
db_password,
18-
db_ssl,
19-
)
20-
var err error
21-
check, err := gorm.Open("postgres", db_check_args)
22-
if err != nil {
23-
log.WithFields(log.Fields{
24-
"at": "models.Connect",
25-
"user": db_username,
26-
"ssl": db_ssl,
27-
"error": err.Error(),
28-
}).Fatal("unable to connect to database server")
29-
}
30-
if check.Exec(fmt.Sprintf("SELECT 1 FROM pg_database WHERE datname = '%s'", db_name)).RowsAffected != 1 {
31-
check.Exec(fmt.Sprintf("CREATE DATABASE %s", db_name))
32-
log.WithFields(log.Fields{
33-
"at": "models.Connect",
34-
"db_name": db_name,
35-
}).Info("created missing database")
14+
func Connect() {
15+
if helpers.DBAdminUsername != "" {
16+
db_check_args := fmt.Sprintf(
17+
"user=%s password=%s sslmode=%s",
18+
helpers.DBAdminUsername,
19+
helpers.DBAdminPassword,
20+
helpers.DBTLS,
21+
)
22+
var err error
23+
check, err := gorm.Open("postgres", db_check_args)
24+
if err != nil {
25+
log.WithFields(log.Fields{
26+
"at": "models.Connect",
27+
"user": helpers.DBAdminUsername,
28+
"ssl": helpers.DBTLS,
29+
"error": err.Error(),
30+
}).Fatal("unable to connect to database server")
31+
}
32+
if check.Exec(fmt.Sprintf("SELECT 1 FROM pg_database WHERE datname = '%s'", helpers.DBName)).RowsAffected != 1 {
33+
check.Exec(fmt.Sprintf("CREATE DATABASE %s", helpers.DBName))
34+
log.WithFields(log.Fields{
35+
"at": "models.Connect",
36+
"db_name": helpers.DBName,
37+
}).Info("created missing database")
38+
}
3639
}
40+
3741
db_args := fmt.Sprintf(
3842
"user=%s password=%s sslmode=%s dbname=%s",
39-
db_username,
40-
db_password,
41-
db_ssl,
42-
db_name,
43+
helpers.DBUsername,
44+
helpers.DBPassword,
45+
helpers.DBTLS,
46+
helpers.DBName,
4347
)
48+
var err error
4449
Orm, err = gorm.Open("postgres", db_args)
4550
if err != nil {
4651
log.WithFields(log.Fields{
47-
"at": "models.Connect",
48-
"user": db_username,
49-
"ssl": db_ssl,
50-
"error": err.Error(),
52+
"at": "models.Connect",
53+
"user": helpers.DBUsername,
54+
"ssl": helpers.DBTLS,
55+
"dbname": helpers.DBName,
56+
"error": err.Error(),
5157
}).Fatal("unable to connect to database server")
5258
}
5359
Setup()

0 commit comments

Comments
 (0)