diff --git a/spackl/__init__.py b/spackl/__init__.py index 6fb335e..f359182 100644 --- a/spackl/__init__.py +++ b/spackl/__init__.py @@ -1,4 +1,29 @@ from spackl import db, file __all__ = [db, file] -__version__ = '0.1.0' +__version__ = '0.1.1' + + +DB_TYPE_MAP = { + 'bigquery': db.BigQuery, + 'postgres': db.Postgres, + 'redshift': db.Redshift, +} +FILE_TYPE_MAP = { + 'csv': file.CSV, +} + + +def get_db(name): + conf = db.Config() + for c in conf.dbs: + if c['name'] == name: + dbtype = c.get('type', 'postgres') + _db = DB_TYPE_MAP.get(dbtype, db.Postgres) + return _db(**c) + raise ValueError('DB with name "{}" not found.'.format(name)) + + +def get_file(filepath, filetype='csv', **kwargs): + _file = FILE_TYPE_MAP.get(filetype, file.CSV) + return _file(filepath, **kwargs) diff --git a/spackl/db/base.py b/spackl/db/base.py index db626bd..0adf38d 100644 --- a/spackl/db/base.py +++ b/spackl/db/base.py @@ -97,3 +97,21 @@ def execute(self, query_string, **kwargs): None """ raise NotImplementedError() + + @abstractmethod + def query_df(self, query_string, **kwargs): + """ + Runs a query against the source database and returns a pandas DataFrame + + If not connected, should call self.connect() first + + Args: + query_string : str - The query to run against the database + + Kwargs: + kwargs : Arbitrary parameters to pass to the query engine + + Returns: + pandas DataFrame + """ + raise NotImplementedError() diff --git a/spackl/db/bigquery.py b/spackl/db/bigquery.py index 4bd038f..3472d48 100644 --- a/spackl/db/bigquery.py +++ b/spackl/db/bigquery.py @@ -74,18 +74,22 @@ def _close(self): """ return - def _query(self, query_string): + def _query(self, query_string, **kwargs): self.connect() - query_job = self._conn.query(query_string) + query_job = self._conn.query(query_string, **kwargs) return query_job.result() - def query(self, query_string): + def query(self, query_string, **kwargs): from .result import QueryResult result = self._query(query_string) return 
QueryResult(result) - def execute(self, query_string): - self._query(query_string) + def execute(self, query_string, **kwargs): + self._query(query_string, **kwargs) + + def query_df(self, query_string, **kwargs): + result = self._query(query_string, **kwargs) + return result.to_dataframe() def list_tables(self, dataset_id): """ diff --git a/spackl/db/config.py b/spackl/db/config.py index 45a5281..1eca874 100644 --- a/spackl/db/config.py +++ b/spackl/db/config.py @@ -43,7 +43,6 @@ class Config(object): config_file : str - The path of the config file to load TODO(aaronbiller): allow config to create config file and write to it - TODO(aaronbiller): allow instantiation of BaseDb right from the config """ _config = None _dbs = list() diff --git a/spackl/db/postgres.py b/spackl/db/postgres.py index a338857..f62aa5c 100644 --- a/spackl/db/postgres.py +++ b/spackl/db/postgres.py @@ -66,3 +66,10 @@ def execute(self, query_string, **kwargs): with self._conn.begin(): self._query(query_string, **kwargs) self.close() + + def query_df(self, query_string, **kwargs): + from pandas import read_sql + self.connect() + df = read_sql(query_string, self._conn, **kwargs) + self.close() + return df diff --git a/spackl/db/result.py b/spackl/db/result.py index 64e34c8..e9b5c1f 100644 --- a/spackl/db/result.py +++ b/spackl/db/result.py @@ -1,9 +1,6 @@ """ Base classes to house the results of queries against a source databse """ -import six - -from collections import OrderedDict from google.cloud.bigquery.table import RowIterator from sqlalchemy.engine import ResultProxy @@ -12,6 +9,9 @@ class QueryResult(BaseResult): def __init__(self, query_iterator=None): + keys = list() + result = list() + if query_iterator is not None: if not isinstance(query_iterator, (RowIterator, ResultProxy)): raise TypeError( @@ -20,16 +20,11 @@ def __init__(self, query_iterator=None): 'RowIterator, returned by a call to google.cloud.bigquery.Client.query().result()' % type(query_iterator)) - result = 
[OrderedDict(row) for row in query_iterator] - else: - result = list() - - if result: - keys = list(six.iterkeys(result[0])) - if not all(sorted(keys) == sorted(list(six.iterkeys(x))) for x in result): - raise AttributeError('keys arg does not match all result keys') - else: - keys = list() + if isinstance(query_iterator, ResultProxy): + keys = list(query_iterator.keys()) + elif isinstance(query_iterator, RowIterator): + keys = [field.name for field in query_iterator.schema] + result = list(query_iterator) super(QueryResult, self).__init__(keys, result) diff --git a/spackl/file/base.py b/spackl/file/base.py index 6de812c..458663c 100644 --- a/spackl/file/base.py +++ b/spackl/file/base.py @@ -55,6 +55,21 @@ def query(self, **kwargs): kwargs : Arbitrary parameters to pass to the query method Returns: - IOResult containing the file data + FileResult containing the file data + """ + raise NotImplementedError() + + @abstractmethod + def query_df(self, **kwargs): + """ + Reads the open file object + + If not open, should call self.open() first + + Kwargs: + kwargs : Arbitrary parameters to pass to the query method + + Returns: + pandas DataFrame """ raise NotImplementedError() diff --git a/spackl/file/csv.py b/spackl/file/csv.py index 3561dff..34bc5c4 100644 --- a/spackl/file/csv.py +++ b/spackl/file/csv.py @@ -23,14 +23,10 @@ class CSV(BaseFile): that can be opened, read, and closed Kwargs: name : str - The canonical name to use for this instance - use_pandas : bool - Choose to use pandas to read the csv file, a better option if you plan - to ultimately convert the result to a DataFrame. 
- NOTE: This causes the query method to return a DataFrame instead of a FileResult csv_kwargs : Parameters to pass to the csv reader (fieldnames, delimiter, dialect, etc) """ - def __init__(self, file_path_or_obj, name=None, use_pandas=False, **csv_kwargs): + def __init__(self, file_path_or_obj, name=None, **csv_kwargs): self._name = name - self._use_pandas = use_pandas self._file = None self._csv_kwargs = dict(**csv_kwargs) @@ -60,7 +56,7 @@ def _set_file(self, file_path_or_obj): _log.warning('Provided path does not exist : %s', file_path_or_obj) file = None else: - if zipfile.is_zipfile(str(file)) and not self._use_pandas: + if zipfile.is_zipfile(str(file)): file = zipfile.ZipFile(str(file)) self._file = file @@ -90,15 +86,7 @@ def _get_dialect(self): self._data.seek(0) return dialect - def _load_using_pandas(self, **kwargs): - from pandas import read_csv - return read_csv(self._file, **kwargs) - - def query(self, use_pandas=False, pd_kwargs=dict(), **kwargs): - if use_pandas or self._use_pandas: - # Skip loading method and return a dataframe - return self._load_using_pandas(**pd_kwargs) - + def query(self, **kwargs): _kwargs = dict(**self._csv_kwargs) _kwargs.update(**kwargs) @@ -112,3 +100,8 @@ def query(self, use_pandas=False, pd_kwargs=dict(), **kwargs): self.close() return result + + def query_df(self, **kwargs): + from pandas import read_csv + fp = self._file.filename if isinstance(self._file, zipfile.ZipFile) else self._file + return read_csv(fp, **kwargs) diff --git a/spackl/result.py b/spackl/result.py index 4f9ea54..53f1705 100644 --- a/spackl/result.py +++ b/spackl/result.py @@ -21,7 +21,7 @@ def __init__(self, keys, row): self._row = row def __repr__(self): - return str(tuple([v for v in self.values()])) + return str(self.values()) def __str__(self): return self.__repr__() @@ -31,8 +31,14 @@ def __bool__(self): __nonzero__ = __bool__ + def __iter__(self): + for val in self._row: + yield val + def __getattr__(self, name): - return 
self.__getitem__(str(name)) + if name in self._keys: + return self.__getitem__(str(name)) + return super(ResultRow, self).__getattr__(name) def __getitem__(self, key): if isinstance(key, six.string_types): @@ -56,11 +62,15 @@ def __eq__(self, other): def __ne__(self, other): return not self == other + @property + def columns(self): + return self._keys + def values(self): - return tuple(six.itervalues(self._row)) + return tuple(self._row) def keys(self): - return list(six.iterkeys(self._row)) + return self._keys def items(self): for key, value in six.iteritems(self._row): @@ -82,7 +92,7 @@ class ResultCol(object): key : str - The key (AKA column name) of the query result col : tuple - The column of data from the query result """ - __slots__ = ['_index', '_key', '_col'] + __slots__ = ['_key', '_col'] def __init__(self, key, col): self._key = key @@ -101,18 +111,8 @@ def __bool__(self): __nonzero__ = __bool__ def __iter__(self): - self._index = 0 - return self - - def __next__(self): - self._index += 1 - try: - col = self._col[self._index - 1] - except IndexError: - raise StopIteration - return col - - next = __next__ + for val in self._col: + yield val def __getattr__(self, name): return self.__getitem__(str(name)) @@ -145,39 +145,21 @@ def __eq__(self, other): def __ne__(self, other): return not self == other - def _rquery_format(self): - """ - Special method for models.SourcePair to call when formatting the rquery + def values(self): + return tuple(self._col) - Returns: - str - The properly formatted values of this column - """ - # Filter out None values - col = [v for v in self._col if v is not None] - _len = len(col) + def keys(self): + return [self._key] - # Calling str() on string type values to avoid unicode strings in python 2 - if _len == 0: - # Gracefully handle an empty column - # Fill the result with a nonsense value to prevent false positives - return str("('__xxx__EMPTYRESULT__xxx__')") - elif _len == 1: - # Handle a single-value column without creating 
an invalid sql syntax. - # When formatting the rquery in a SourcePair, using the query with a - # value like (1,) will raise a syntax error. - v = col[0] - if isinstance(v, six.string_types): - v = "'{}'".format(str(v)) - return str('({})'.format(v)) - else: - return str(tuple([str(v) if isinstance(v, six.string_types) else v for v in col])) + def items(self): + yield (self._key, self._col) class BaseResult(object): """ Base class for containing contents of a query result or a file """ - __slots__ = ['_index', '_keys', '_result'] + __slots__ = ['_keys', '_result'] def __init__(self, keys, result): if not isinstance(keys, list): @@ -186,13 +168,13 @@ def __init__(self, keys, result): if not isinstance(result, list): result = [result] - self._result = result + self._result = [ResultRow(keys, r) for r in result] def __repr__(self): return ''.format(br=self) def __str__(self): - return self.json() + return self.jsons() def __bool__(self): return bool(self._result) @@ -200,18 +182,8 @@ def __bool__(self): __nonzero__ = __bool__ def __iter__(self): - self._index = 0 - return self - - def __next__(self): - self._index += 1 - try: - row = self._result[self._index - 1] - except IndexError: - raise StopIteration - return ResultRow(self._keys, row) - - next = __next__ + for row in self._result: + yield row def __getattr__(self, name): return self.__getitem__(str(name)) @@ -225,8 +197,7 @@ def __getitem__(self, key): value = ResultCol(key, col) elif isinstance(key, six.integer_types): # Return the row corresponding with this index - row = self._result[key] - value = ResultRow(self._keys, row) + value = self._result[key] elif isinstance(key, slice): # Return the rows corresponding with this slice sliced = [self._result[ii] for ii in range(*key.indices(len(self._result)))] @@ -246,18 +217,15 @@ def __eq__(self, other): def __ne__(self, other): return not self == other - @classmethod - def _from_part(cls, keys, result): + @staticmethod + def _from_part(keys, result): """ Get a 
new BaseResult from an existing sliced or filtered result Returns: BaseResult """ - br = cls([], []) - br._keys = keys - br._result = result - return br + return BaseResult(keys, result) @property def empty(self): @@ -270,14 +238,8 @@ def empty(self): return bool(not self._keys and not self._result) @property - def result(self): - """ - Get the full results of the query as a list of dicts - - Returns: - list of dicts - [{column: value, ... }, ... ] - """ - return copy.deepcopy(self._result) + def columns(self): + return self._keys def dict(self): """ @@ -290,21 +252,21 @@ def dict(self): def json(self): """ - Get the full results of the query as a json string + Get the full results of the query as a list of dicts Returns: - string + list of dicts """ - return json.dumps(self._result, cls=DtDecEncoder) + return [{k: row[k] for k in self._keys} for row in self._result] - def list(self): + def jsons(self): """ - Get the full results of the query as a row-based list of tuples + Get the full results of the query as a json string Returns: - list of tuples - [(value, ... ), ... 
] + string """ - return [tuple([v for v in six.itervalues(row)]) for row in self._result] + return json.dumps(self.json(), cls=DtDecEncoder) def df(self, *args, **kwargs): """ @@ -314,28 +276,9 @@ def df(self, *args, **kwargs): pandas.DataFrame """ from pandas import DataFrame - return DataFrame(self._result, *args, **kwargs) - - def first(self): - """ - Get the first row of the result - - Returns: - QueryResultRow - """ - return ResultRow(self._keys, self._result[0]) - - def values(self): - for item in self.list(): - yield item - - def keys(self): - for key in self._keys: - yield key - - def items(self): - for key, value in six.iteritems(self.dict()): - yield (key, value) + _kwargs = {'columns': self._keys} + _kwargs.update(**kwargs) + return DataFrame.from_records(self._result, *args, **_kwargs) def get(self, key, default=None): try: @@ -352,19 +295,19 @@ def pop(self, index=-1): index : int - The index of the row to remove and return Returns: - QueryResultRow + ResultRow """ return ResultRow(self._keys, self._result.pop(index)) def append(self, other): """ - Append a QueryResultRow with matching keys to the current result + Append a ResultRow with matching keys to the current result Args: - other : QueryResultRow - The row to append + other : ResultRow - The row to append """ if not isinstance(other, ResultRow): - raise NotImplementedError('Appending object must be a QueryResultRow') + raise NotImplementedError('Appending object must be a ResultRow') if self.empty: self._keys = other._keys elif self._keys != other._keys: @@ -373,33 +316,35 @@ def append(self, other): def extend(self, other): """ - Extend the current result with another QueryResult with matching keys + Extend the current result with another BaseResult with matching keys Args: - other : QueryResult - The row to append + other : BaseResult - The results to concat """ if not isinstance(other, BaseResult): - raise NotImplementedError('Extending object must be a QueryResult') + raise 
NotImplementedError('Extending object must be a BaseResult') if self.empty: self._keys = other._keys + elif other.empty: + other._keys = self._keys elif self._keys != other._keys: - raise ValueError('Keys in other QueryResult to not match, cannot extend') + raise ValueError('Keys in other BaseResult do not match, cannot extend') self._result.extend(other._result) def filter(self, predicate, inplace=False): """ - Filter the query result rows + Filter the result rows Args: predicate : callable - The function to apply to each result row, should return a boolean Kwargs: - inplace : boolean - Whether to alter the results in-place or return a new QueryResult object + inplace : boolean - Whether to alter the results in-place or return a new BaseResult object Returns: - QueryResult with filtered results + BaseResult with filtered results """ - filtered = [row for row in self._result if predicate(ResultRow(self._keys, row))] + filtered = [row for row in self._result if predicate(row)] if inplace: self._result = filtered return None