Merge pull request #9342 from [BEAM-7866][BEAM-5148] Cherry-picks mongodb fixes to 2.15.0 release branch
yifanzou authored Aug 14, 2019
2 parents 45de258 + cc9e966 commit 7931ec0
Showing 3 changed files with 436 additions and 112 deletions.
225 changes: 183 additions & 42 deletions sdks/python/apache_beam/io/mongodbio.py
@@ -52,21 +52,35 @@
"""

from __future__ import absolute_import
from __future__ import division

import logging

from bson import objectid
from pymongo import MongoClient
from pymongo import ReplaceOne
import struct

import apache_beam as beam
from apache_beam.io import iobase
from apache_beam.io.range_trackers import OffsetRangeTracker
from apache_beam.io.range_trackers import OrderedPositionRangeTracker
from apache_beam.transforms import DoFn
from apache_beam.transforms import PTransform
from apache_beam.transforms import Reshuffle
from apache_beam.utils.annotations import experimental

try:
# Mongodb has its own bundled bson, which is not compatible with the bson
# package (https://github.com/py-bson/bson/issues/82). Try to import objectid
# and if it fails because the bson package is installed, MongoDB IO will not
# work but at least the rest of the SDK will work.
from bson import objectid

# pymongo also internally depends on bson.
from pymongo import ASCENDING
from pymongo import DESCENDING
from pymongo import MongoClient
from pymongo import ReplaceOne
except ImportError:
objectid = None
logging.warning("Could not find a compatible bson package.")

__all__ = ['ReadFromMongoDB', 'WriteToMongoDB']


@@ -139,50 +153,49 @@ def __init__(self,
self.filter = filter
self.projection = projection
self.spec = extra_client_params
self.doc_count = self._get_document_count()
self.avg_doc_size = self._get_avg_document_size()
self.client = None

def estimate_size(self):
return self.avg_doc_size * self.doc_count
with MongoClient(self.uri, **self.spec) as client:
return client[self.db].command('collstats', self.coll).get('size')
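
For reference, a standalone sketch of the 'collstats' call that estimate_size now relies on; the server address and the beam_db/beam_coll names are assumptions:

from pymongo import MongoClient

client = MongoClient('mongodb://localhost:27017')
stats = client['beam_db'].command('collstats', 'beam_coll')
# 'size' is the uncompressed collection size in bytes; 'count' is the
# document count.
print(stats.get('size'), 'bytes across', stats.get('count'), 'documents')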

def split(self, desired_bundle_size, start_position=None, stop_position=None):
# use document cursor index as the start and stop positions
if start_position is None:
start_position = 0
if stop_position is None:
stop_position = self.doc_count
start_position, stop_position = self._replace_none_positions(
start_position, stop_position)

# get an estimate on how many documents should be included in a split batch
desired_bundle_count = desired_bundle_size // self.avg_doc_size
desired_bundle_size_in_mb = desired_bundle_size // 1024 // 1024
split_keys = self._get_split_keys(desired_bundle_size_in_mb, start_position,
stop_position)

bundle_start = start_position
while bundle_start < stop_position:
bundle_end = min(stop_position, bundle_start + desired_bundle_count)
yield iobase.SourceBundle(weight=bundle_end - bundle_start,
for split_key_id in split_keys:
if bundle_start >= stop_position:
break
bundle_end = min(stop_position, split_key_id)
yield iobase.SourceBundle(weight=desired_bundle_size_in_mb,
source=self,
start_position=bundle_start,
stop_position=bundle_end)
bundle_start = bundle_end
# add the range from the last split_key up to stop_position
if bundle_start < stop_position:
yield iobase.SourceBundle(weight=desired_bundle_size_in_mb,
source=self,
start_position=bundle_start,
stop_position=stop_position)
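
A toy rehearsal of the bundle arithmetic above, with integers standing in for ObjectId positions (all values invented):

start, stop = 0, 100
split_keys = [30, 70]  # stand-in for what splitVector hands back
bundles, bundle_start = [], start
for key in split_keys:
    if bundle_start >= stop:
        break
    bundle_end = min(stop, key)
    bundles.append((bundle_start, bundle_end))
    bundle_start = bundle_end
if bundle_start < stop:
    bundles.append((bundle_start, stop))  # close the final range
assert bundles == [(0, 30), (30, 70), (70, 100)]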

def get_range_tracker(self, start_position, stop_position):
if start_position is None:
start_position = 0
if stop_position is None:
stop_position = self.doc_count
return OffsetRangeTracker(start_position, stop_position)
start_position, stop_position = self._replace_none_positions(
start_position, stop_position)
return _ObjectIdRangeTracker(start_position, stop_position)

def read(self, range_tracker):
with MongoClient(self.uri, **self.spec) as client:
# docs is a MongoDB Cursor
docs = client[self.db][self.coll].find(
filter=self.filter, projection=self.projection
)[range_tracker.start_position():range_tracker.stop_position()]
for index in range(range_tracker.start_position(),
range_tracker.stop_position()):
if not range_tracker.try_claim(index):
all_filters = self._merge_id_filter(range_tracker)
docs_cursor = client[self.db][self.coll].find(filter=all_filters)
for doc in docs_cursor:
if not range_tracker.try_claim(doc['_id']):
return
yield docs[index - range_tracker.start_position()]
yield doc
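
The switch from cursor-index slicing to claiming each document's _id is what lets the runner shrink a bundle mid-read. A toy model of that try_claim contract (names and values hypothetical):

class FakeTracker(object):
    def __init__(self, start, stop):
        self._start, self._stop = start, stop

    def try_claim(self, pos):
        # Claim succeeds only while pos is inside the (shrinkable) range.
        return self._start <= pos < self._stop

tracker = FakeTracker(10, 20)
emitted = []
for doc in ({'_id': i} for i in range(10, 25)):
    if not tracker.try_claim(doc['_id']):
        break
    emitted.append(doc['_id'])
assert emitted == list(range(10, 20))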

def display_data(self):
res = super(_BoundedMongoSource, self).display_data()
@@ -194,18 +207,146 @@ def display_data(self):
res['mongo_client_spec'] = self.spec
return res

def _get_avg_document_size(self):
def _get_split_keys(self, desired_chunk_size_in_mb, start_pos, end_pos):
# Calls the MongoDB splitVector command to get document ids at the split
# positions for the desired bundle size; if the desired chunk size is
# smaller than 1 MB, use the MongoDB default split size of 1 MB.
if desired_chunk_size_in_mb < 1:
desired_chunk_size_in_mb = 1
if start_pos >= end_pos:
# single document not splittable
return []
with MongoClient(self.uri, **self.spec) as client:
size = client[self.db].command('collstats', self.coll).get('avgObjSize')
if size is None or size <= 0:
raise ValueError(
'Collection %s not found or average doc size is '
'incorrect', self.coll)
return size

def _get_document_count(self):
name_space = '%s.%s' % (self.db, self.coll)
return (client[self.db].command(
'splitVector',
name_space,
keyPattern={'_id': 1}, # Ascending index
min={'_id': start_pos},
max={'_id': end_pos},
maxChunkSize=desired_chunk_size_in_mb)['splitKeys'])
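
A standalone sketch of the same splitVector admin command; the server, namespace, and endpoint ids below are assumptions:

from bson import objectid
from pymongo import MongoClient

client = MongoClient('mongodb://localhost:27017')
result = client['beam_db'].command(
    'splitVector',
    'beam_db.beam_coll',
    keyPattern={'_id': 1},
    min={'_id': objectid.ObjectId('000000000000000000000000')},
    max={'_id': objectid.ObjectId('ffffffffffffffffffffffff')},
    maxChunkSize=1)
# Each split key is a document holding a boundary _id; consecutive
# boundaries delimit chunks of at most maxChunkSize megabytes.
print(result['splitKeys'])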

def _merge_id_filter(self, range_tracker):
# Merge the default filter with the refined _id field range from the
# range_tracker. See more at
# https://docs.mongodb.com/manual/reference/operator/query/and/
all_filters = {
'$and': [
self.filter.copy(),
# add an additional range filter to the query. $gte specifies the start
# position (inclusive) and $lt specifies the end position (exclusive),
# see more at
# https://docs.mongodb.com/manual/reference/operator/query/gte/ and
# https://docs.mongodb.com/manual/reference/operator/query/lt/
{
'_id': {
'$gte': range_tracker.start_position(),
'$lt': range_tracker.stop_position()
}
},
]
}

return all_filters
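
With hypothetical endpoints, the merged query document looks like this:

from bson import objectid

start_id = objectid.ObjectId('5d0000000000000000000000')  # assumed bounds
stop_id = objectid.ObjectId('5d0000010000000000000000')
user_filter = {'number_mod_2': 0}  # assumed user-supplied filter
merged = {'$and': [user_filter,
                   {'_id': {'$gte': start_id, '$lt': stop_id}}]}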

def _get_head_document_id(self, sort_order):
with MongoClient(self.uri, **self.spec) as client:
return max(client[self.db][self.coll].count_documents(self.filter), 0)
cursor = client[self.db][self.coll].find(filter={}, projection=[]).sort([
('_id', sort_order)
]).limit(1)
try:
return cursor[0]['_id']
except IndexError:
raise ValueError('Empty Mongodb collection')

def _replace_none_positions(self, start_position, stop_position):
if start_position is None:
start_position = self._get_head_document_id(ASCENDING)
if stop_position is None:
last_doc_id = self._get_head_document_id(DESCENDING)
# increment last doc id binary value by 1 to make sure the last document
# is not excluded
stop_position = _ObjectIdHelper.increment_id(last_doc_id, 1)
return start_position, stop_position


class _ObjectIdHelper(object):
"""A Utility class to manipulate bson object ids."""

@classmethod
def id_to_int(cls, id):
"""
Args:
id: ObjectId required for each MongoDB document _id field.
Returns: Converted integer value of ObjectId's 12 bytes binary value.
"""
# converts object id binary to integer
# id object is bytes type with size of 12
ints = struct.unpack('>III', id.binary)
return (ints[0] << 64) + (ints[1] << 32) + ints[2]
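
A worked example of the same unpack-and-shift (the byte value is made up):

import struct

raw = b'\x00' * 11 + b'\x2a'  # twelve bytes ending in 42
high, middle, low = struct.unpack('>III', raw)  # three big-endian uint32s
assert (high << 64) + (middle << 32) + low == 42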

@classmethod
def int_to_id(cls, number):
"""
Args:
number(int): The integer value to be used to convert to ObjectId.
Returns: The ObjectId that has the 12 bytes binary converted from the
integer value.
"""
# converts integer value to object id. Int value should be less than
# (2 ^ 96) so it can be convert to 12 bytes required by object id.
if number < 0 or number >= (1 << 96):
raise ValueError('number value must be within [0, %s)' % (1 << 96))
ints = [(number & 0xffffffff0000000000000000) >> 64,
(number & 0x00000000ffffffff00000000) >> 32,
number & 0x0000000000000000ffffffff]

bytes = struct.pack('>III', *ints)
return objectid.ObjectId(bytes)

@classmethod
def increment_id(cls, object_id, inc):
"""
Args:
object_id: The ObjectId to change.
inc(int): The incremental int value to be added to ObjectId.
Returns:
"""
# increment object_id binary value by inc value and return new object id.
id_number = _ObjectIdHelper.id_to_int(object_id)
new_number = id_number + inc
if new_number < 0 or new_number >= (1 << 96):
raise ValueError('invalid increment, inc value must be within ['
'%s, %s)' % (0 - id_number, (1 << 96) - id_number))
return _ObjectIdHelper.int_to_id(new_number)
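
A round-trip sketch for the helper, assuming pymongo's bundled bson and the _ObjectIdHelper class above (the hex id is made up):

from bson import objectid

oid = objectid.ObjectId('5d53b4a5f0f0f0f0f0f0f0f0')
n = _ObjectIdHelper.id_to_int(oid)
assert _ObjectIdHelper.int_to_id(n) == oid
# The +1 increment is how _replace_none_positions keeps the last document
# inside the half-open [start, stop) range.
assert _ObjectIdHelper.id_to_int(_ObjectIdHelper.increment_id(oid, 1)) == n + 1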


class _ObjectIdRangeTracker(OrderedPositionRangeTracker):
"""RangeTracker for tracking mongodb _id of bson ObjectId type."""

def position_to_fraction(self, pos, start, end):
pos_number = _ObjectIdHelper.id_to_int(pos)
start_number = _ObjectIdHelper.id_to_int(start)
end_number = _ObjectIdHelper.id_to_int(end)
return (pos_number - start_number) / (end_number - start_number)

def fraction_to_position(self, fraction, start, end):
start_number = _ObjectIdHelper.id_to_int(start)
end_number = _ObjectIdHelper.id_to_int(end)
total = end_number - start_number
pos = int(total * fraction + start_number)
# make sure split position is larger than start position and smaller than
# end position.
if pos <= start_number:
return _ObjectIdHelper.increment_id(start, 1)
if pos >= end_number:
return _ObjectIdHelper.increment_id(end, -1)
return _ObjectIdHelper.int_to_id(pos)
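
A small numeric rehearsal of the interpolation, with invented integers standing in for id_to_int values:

from __future__ import division  # same import this diff adds to the module

start_n, end_n = 10, 30
pos = int((end_n - start_n) * 0.5 + start_n)  # fraction 0.5 -> midpoint
assert pos == 20
assert (pos - start_n) / (end_n - start_n) == 0.5  # inverse mapping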


@experimental()
18 changes: 10 additions & 8 deletions sdks/python/apache_beam/io/mongodbio_it_test.py
@@ -42,32 +42,33 @@ def run(argv=None):
default=default_coll,
help='mongo uri string for connection')
parser.add_argument('--num_documents',
default=1000,
default=100000,
help='The expected number of documents to be generated '
'for write or read',
type=int)
parser.add_argument('--batch_size',
default=100,
default=10000,
help=('batch size for writing to mongodb'))
known_args, pipeline_args = parser.parse_known_args(argv)

# Test Write to MongoDB
with TestPipeline(options=PipelineOptions(pipeline_args)) as p:
start_time = time.time()
logging.info('Writing %d documents to mongodb' % known_args.num_documents)
docs = [{
'number': x,
'number_mod_2': x % 2,
'number_mod_3': x % 3
} for x in range(known_args.num_documents)]

start_time = time.time()
_ = p | 'Create documents' >> beam.Create(docs) \
| 'WriteToMongoDB' >> beam.io.WriteToMongoDB(known_args.mongo_uri,
known_args.mongo_db,
known_args.mongo_coll,
known_args.batch_size)
logging.info('Writing %d documents to mongodb finished in %.3f seconds' %
(known_args.num_documents, time.time() - start_time))
elapsed = time.time() - start_time
logging.info('Writing %d documents to mongodb finished in %.3f seconds' %
(known_args.num_documents, elapsed))

# Test Read from MongoDB
with TestPipeline(options=PipelineOptions(pipeline_args)) as p:
@@ -80,11 +81,12 @@ def run(argv=None):
known_args.mongo_coll,
projection=['number']) \
| 'Map' >> beam.Map(lambda doc: doc['number'])

assert_that(
r, equal_to([number for number in range(known_args.num_documents)]))
logging.info('Read %d documents from mongodb finished in %.3f seconds' %
(known_args.num_documents, time.time() - start_time))

elapsed = time.time() - start_time
logging.info('Read %d documents from mongodb finished in %.3f seconds' %
(known_args.num_documents, elapsed))


if __name__ == "__main__":