diff --git a/src/write_batch.rs b/src/write_batch.rs index 4f9630d..4c6bda4 100644 --- a/src/write_batch.rs +++ b/src/write_batch.rs @@ -141,7 +141,10 @@ impl WriteBatchPy { let inner = inner_mut!(self)?; let key = encode_key(key, self.raw_mode)?; let value = encode_value(value, &self.dumps, self.raw_mode)?; - match column_family { + match column_family + .as_ref() + .or(self.default_column_family.as_ref()) + { Some(cf) => inner.put_cf(&cf.cf, key, value), None => inner.put(key, value), } @@ -203,7 +206,10 @@ impl WriteBatchPy { ) -> PyResult<()> { let inner = inner_mut!(self)?; let key = encode_key(key, self.raw_mode)?; - match column_family { + match column_family + .as_ref() + .or(self.default_column_family.as_ref()) + { Some(cf) => inner.delete_cf(&cf.cf, key), None => inner.delete(key), } diff --git a/test/test_rdict.py b/test/test_rdict.py index 712e7c1..ca55ccc 100644 --- a/test/test_rdict.py +++ b/test/test_rdict.py @@ -590,6 +590,58 @@ def tearDownClass(cls): Rdict.destroy(cls.path, cls.opt) +class TestWriteBatch(unittest.TestCase): + test_dict = None + opt = None + path = "./temp_write_batch" + cf2 = str("square") + cf3 = str("qubic") + + @classmethod + def setUpClass(cls) -> None: + cls.opt = Options(raw_mode=False) + cls.opt.create_if_missing(True) + cls.opt.create_missing_column_families(True) + cf = {cls.cf2: Options(), cls.cf3: Options()} + cls.test_dict = Rdict(cls.path, column_families=cf, options=cls.opt) + + def test_write_batch_default_none(self): + assert self.test_dict is not None + wb = WriteBatch() + for i in range(100): + wb.put(i, i) + wb.put(i, i**2, self.test_dict.get_column_family_handle(self.cf2)) + self.test_dict.write(wb) + square_dict = self.test_dict.get_column_family(self.cf2) + for i in range(100): + # default column family is None, write to default column family + self.assertEqual(self.test_dict[i], i) + self.assertEqual(square_dict[i], i**2) + + def test_write_batch_default(self): + assert self.test_dict is not None + wb = WriteBatch() + wb.set_default_column_family(self.test_dict.get_column_family_handle(self.cf3)) + for i in range(100): + wb.put(i, i**3) + wb.put(i, i**2, self.test_dict.get_column_family_handle(self.cf2)) + self.test_dict.write(wb) + square_dict = self.test_dict.get_column_family(self.cf2) + quibic_dict = self.test_dict.get_column_family(self.cf3) + for i in range(100): + # default column family is set, write to default column family + self.assertEqual(square_dict[i], i**2) + self.assertEqual(quibic_dict[i], i**3) + + @classmethod + def tearDownClass(cls): + assert cls.test_dict is not None + cls.test_dict.close() + assert cls.opt is not None + gc.collect() + Rdict.destroy(cls.path, cls.opt) + + class TestWideColumnsRaw(unittest.TestCase): test_dict = None opt = None