diff --git a/bolt.go b/bolt.go new file mode 100644 index 0000000..611154f --- /dev/null +++ b/bolt.go @@ -0,0 +1,242 @@ +package shield + +import ( + "encoding/binary" + "fmt" + "log" + "os" + + "github.com/boltdb/bolt" +) + +type BoltStore struct { + bolt *bolt.DB + path string + sumKey string + classKey string + classesKey string + logger *log.Logger + prefix string +} + +func NewBoltStore(path string, logger *log.Logger, prefix string) *BoltStore { + bs := &BoltStore{ + path: path, + sumKey: "shield:sum", + classKey: "shield:class", + classesKey: "shield:classes", + logger: logger, + prefix: prefix, + } + + bs.init() + return bs +} + +type Bucket struct { + *bolt.Bucket +} + +func (b Bucket) Get(key string) int64 { + buff := b.Bucket.Get([]byte(key)) + if val, n := binary.Varint(buff); n > 0 { + return val + } + return 0 +} + +func (b Bucket) IncrementBy(key string, inc int64) error { + ret := b.Bucket.Get([]byte(key)) + + value := int64(0) + if val, n := binary.Varint(ret); n > 0 { + value = val + } + + value += int64(inc) + + buff := make([]byte, 8) + binary.PutVarint(buff, int64(value)) + + if err := b.Bucket.Put([]byte(key), buff); err != nil { + return err + } + + return nil +} + +func (b Bucket) Update(key string, value int64) error { + buff := make([]byte, 8) + binary.PutVarint(buff, value) + return b.Bucket.Put([]byte(key), buff) +} + +func (rs *BoltStore) init() (conn *bolt.DB, err error) { + if rs.bolt == nil { + db, err := bolt.Open(rs.path, 0600, nil) + if err != nil { + return nil, err + } + + rs.bolt = db + + tx, err := db.Begin(true) + if err != nil { + return nil, err + } + + defer tx.Rollback() + + if _, err := tx.CreateBucketIfNotExists([]byte(rs.sumKey)); err != nil { + return nil, err + } + + if _, err := tx.CreateBucketIfNotExists([]byte(rs.classKey)); err != nil { + return nil, err + } + + if _, err := tx.CreateBucketIfNotExists([]byte(rs.classesKey)); err != nil { + return nil, err + } + + if err := tx.Commit(); err != nil { + return nil, err + } + + } + return rs.bolt, nil +} + +func (rs *BoltStore) Classes() (a []string, err error) { + err = rs.bolt.Update(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte(rs.classesKey)) + + return b.ForEach(func(k, v []byte) error { + a = append(a, string(v)) + return nil + }) + }) + + return +} + +func (rs *BoltStore) AddClass(class string) (err error) { + if class == "" { + return fmt.Errorf("invalid class: %s", class) + } + + err = rs.bolt.Update(func(tx *bolt.Tx) error { + b := Bucket{tx.Bucket([]byte(rs.classesKey))} + return b.Update(class, 0) + }) + + return +} + +func (rs *BoltStore) ClassWordCounts(class string, words []string) (mc map[string]int64, err error) { + key := fmt.Sprintf("%s:%s", rs.classKey, class) + + if err = rs.bolt.Update(func(tx *bolt.Tx) error { + b := Bucket{tx.Bucket([]byte(key))} + + mc = make(map[string]int64) + for _, v := range words { + mc[v] = b.Get(v) + } + + return nil + }); err != nil { + return + } + + return +} + +func (rs *BoltStore) IncrementClassWordCounts(m map[string]map[string]int64) (err error) { + type tuple struct { + word string + d int64 + } + + decrTuples := make(map[string][]*tuple, len(m)) + + if err = rs.bolt.Update(func(tx *bolt.Tx) error { + sb := Bucket{tx.Bucket([]byte(rs.sumKey))} + + for class, words := range m { + for word, d := range words { + if d > 0 { + key := fmt.Sprintf("%s:%s", rs.classKey, class) + + if bucket, err := tx.CreateBucketIfNotExists([]byte(key)); err == nil { + b := Bucket{bucket} + b.IncrementBy(word, d) + } + + sb.IncrementBy(class, d) + } else { + decrTuples[class] = append(decrTuples[class], &tuple{ + word: word, + d: d, + }) + } + } + } + + for class, paths := range decrTuples { + key := fmt.Sprintf("%s:%s", rs.classKey, class) + + b := Bucket{tx.Bucket([]byte(key))} + + for _, path := range paths { + if x := b.Get(path.word); x != 0 { + d := path.d + if (x + d) < 0 { + d = x * -1 + } + + b.IncrementBy(path.word, d) + sb.IncrementBy(class, d) + } + } + } + + return nil + }); err != nil { + return err + } + + return +} + +func (rs *BoltStore) TotalClassWordCounts() (m map[string]int64, err error) { + m = make(map[string]int64) + + err = rs.bolt.View(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte(rs.sumKey)) + + cursor := b.Cursor() + for k, val := cursor.First(); k != nil; k, val = cursor.Next() { + value, _ := binary.Varint(val) + m[string(k)] = int64(value) + } + + return nil + }) + + return +} + +func (rs *BoltStore) Reset() (err error) { + + if rs.bolt != nil { + rs.bolt.Close() + + defer os.Remove(rs.path) + + rs.bolt = nil + } + + rs.init() + return +} diff --git a/bolt_test.go b/bolt_test.go new file mode 100644 index 0000000..c3bbbac --- /dev/null +++ b/bolt_test.go @@ -0,0 +1,25 @@ +package shield + +import ( + "io/ioutil" + "testing" +) + +func TempFileName() string { + f, _ := ioutil.TempFile("", "") + return f.Name() +} + +var ( + boltStore = NewBoltStore(TempFileName(), logger, "") +) + +func TestBoltLearn(t *testing.T) { + sh := newShield(boltStore) + testLearn(t, sh) +} + +func TestBoltDecrement(t *testing.T) { + sh := newShield(boltStore) + testDecrement(t, sh) +} diff --git a/redis_test.go b/redis_test.go new file mode 100644 index 0000000..0fff1d3 --- /dev/null +++ b/redis_test.go @@ -0,0 +1,17 @@ +package shield + +import "testing" + +var ( + redisStore = NewRedisStore("127.0.0.1:6379", "", logger, "redis") +) + +func TestRedisLearn(t *testing.T) { + sh := newShield(redisStore) + testLearn(t, sh) +} + +func TestRedisDecrement(t *testing.T) { + sh := newShield(redisStore) + testDecrement(t, sh) +} diff --git a/shield_test.go b/shield_test.go index 725c6ca..7a57777 100644 --- a/shield_test.go +++ b/shield_test.go @@ -5,10 +5,26 @@ import ( "io/ioutil" "log" "os" + "reflect" "strings" "testing" ) +var ( + logger = log.New(os.Stderr, "", log.LstdFlags) +) + +func newShield(store Store) Shield { + tokenizer := NewEnglishTokenizer() + + sh := New(tokenizer, store) + err := sh.Reset() + if err != nil { + panic(err) + } + return sh +} + func readDataSet(dataFile, labelFile string, t *testing.T) []string { d, err := ioutil.ReadFile("testdata/" + dataFile) if err != nil { @@ -34,21 +50,7 @@ func readDataSet(dataFile, labelFile string, t *testing.T) []string { return a } -func newShield() Shield { - logger := log.New(os.Stderr, "", log.LstdFlags) - store := NewRedisStore("127.0.0.1:6379", "", logger, "redis") - tokenizer := NewEnglishTokenizer() - - sh := New(tokenizer, store) - err := sh.Reset() - if err != nil { - panic(err) - } - return sh -} - -func TestLearn(t *testing.T) { - sh := newShield() +func testLearn(t *testing.T, sh Shield) { testData := readDataSet("testdata.txt", "testlabels.txt", t) trainData := readDataSet("traindata.txt", "trainlabels.txt", t) @@ -88,8 +90,7 @@ func TestLearn(t *testing.T) { } } -func TestDecrement(t *testing.T) { - sh := newShield() +func testDecrement(t *testing.T, sh Shield) { sh.Learn("a", "hello") sh.Learn("a", "sunshine") sh.Learn("a", "tree") @@ -109,8 +110,13 @@ func TestDecrement(t *testing.T) { if err != nil { t.Fatal(err) } - if r := fmt.Sprintf("%v", m); r != "map[hello:0 sunshine:1 tree:0 water:1]" { - t.Fatal(r) + if !reflect.DeepEqual(m, map[string]int64{ + "hello": 0, + "sunshine": 1, + "tree": 0, + "water": 1, + }) { + t.Fatal(fmt.Sprintf("%v", m)) } m2, err := s.store.ClassWordCounts("b", []string{ @@ -120,8 +126,12 @@ func TestDecrement(t *testing.T) { if err != nil { t.Fatal(err) } - if r := fmt.Sprintf("%v", m2); r != "map[hello:0 iamb!:0]" { - t.Fatal(r) + + if !reflect.DeepEqual(m2, map[string]int64{ + "iamb!": 0, + "hello": 0, + }) { + t.Fatal(fmt.Sprintf("%v", m2)) } wc, err := s.store.TotalClassWordCounts()