Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion spackl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,29 @@
from spackl import db, file

# __all__ must list names as strings; module objects here make
# `from spackl import *` raise a TypeError.
__all__ = ['db', 'file']
__version__ = '0.1.0'
__version__ = '0.1.1'


# Maps the 'type' value of a database config entry to the db class that
# get_db instantiates for it.
DB_TYPE_MAP = {
'bigquery': db.BigQuery,
'postgres': db.Postgres,
'redshift': db.Redshift,
}
# Maps the filetype argument of get_file to the file class it instantiates.
FILE_TYPE_MAP = {
'csv': file.CSV,
}


def get_db(name):
    """
    Builds a db object from the config entry with the given name

    Args:
        name : str - The 'name' value of the config entry to look up

    Returns:
        An instance of the class mapped to the entry's 'type' in DB_TYPE_MAP,
        created with the entry's items as keyword arguments. A missing or
        unmapped 'type' falls back to db.Postgres.

    Raises:
        ValueError if no config entry has the given name
    """
    for entry in db.Config().dbs:
        if entry['name'] != name:
            continue
        db_class = DB_TYPE_MAP.get(entry.get('type', 'postgres'), db.Postgres)
        return db_class(**entry)

    raise ValueError('DB with name "{}" not found.'.format(name))


def get_file(filepath, filetype='csv', **kwargs):
    """
    Builds a file object for the given path

    Args:
        filepath : Passed as the first argument to the file class

    Kwargs:
        filetype : str - Key into FILE_TYPE_MAP; unmapped values fall back to file.CSV
        kwargs : Arbitrary parameters to pass to the file class

    Returns:
        An instance of the selected file class
    """
    return FILE_TYPE_MAP.get(filetype, file.CSV)(filepath, **kwargs)
18 changes: 18 additions & 0 deletions spackl/db/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,21 @@ def execute(self, query_string, **kwargs):
None
"""
raise NotImplementedError()

@abstractmethod
def query_df(self, query_string, **kwargs):
    """
    Abstract hook: run a query against the source database and hand back
    the result as a pandas DataFrame

    Implementations should call self.connect() first when not connected

    Args:
        query_string : str - The query to run against the database

    Kwargs:
        kwargs : Arbitrary parameters to pass to the query engine

    Returns:
        pandas DataFrame

    Raises:
        NotImplementedError unless overridden
    """
    raise NotImplementedError()
14 changes: 9 additions & 5 deletions spackl/db/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,18 +74,22 @@ def _close(self):
"""
return

def _query(self, query_string):
def _query(self, query_string, **kwargs):
self.connect()
query_job = self._conn.query(query_string)
query_job = self._conn.query(query_string, **kwargs)
return query_job.result()

def query(self, query_string):
def query(self, query_string, **kwargs):
    """
    Runs a query and wraps the returned rows in a QueryResult

    Args:
        query_string : str - The query to run against the database

    Kwargs:
        kwargs : Arbitrary parameters forwarded to self._query

    Returns:
        QueryResult built from the value self._query returns
    """
    from .result import QueryResult
    # Forward kwargs so they are not silently dropped; execute and query_df
    # already pass them through to _query.
    result = self._query(query_string, **kwargs)
    return QueryResult(result)

def execute(self, query_string):
self._query(query_string)
def execute(self, query_string, **kwargs):
    """
    Runs a query for its effect only; the value returned by self._query
    is discarded

    Args:
        query_string : str - The query to run against the database

    Kwargs:
        kwargs : Arbitrary parameters forwarded to self._query

    Returns:
        None
    """
    _ = self._query(query_string, **kwargs)

def query_df(self, query_string, **kwargs):
    """
    Runs a query and converts the result with its to_dataframe() method

    Args:
        query_string : str - The query to run against the database

    Kwargs:
        kwargs : Arbitrary parameters forwarded to self._query

    Returns:
        Whatever to_dataframe() returns for the value self._query produces
    """
    return self._query(query_string, **kwargs).to_dataframe()

def list_tables(self, dataset_id):
"""
Expand Down
1 change: 0 additions & 1 deletion spackl/db/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ class Config(object):
config_file : str - The path of the config file to load

TODO(aaronbiller): allow config to create config file and write to it
TODO(aaronbiller): allow instantiation of BaseDb right from the config
"""
_config = None
_dbs = list()
Expand Down
7 changes: 7 additions & 0 deletions spackl/db/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,10 @@ def execute(self, query_string, **kwargs):
with self._conn.begin():
self._query(query_string, **kwargs)
self.close()

def query_df(self, query_string, **kwargs):
    """
    Runs a query through pandas.read_sql and returns the resulting DataFrame

    Calls self.connect() before the query and self.close() afterwards

    Args:
        query_string : str - The query to run against the database

    Kwargs:
        kwargs : Arbitrary parameters to pass to pandas.read_sql

    Returns:
        pandas DataFrame
    """
    from pandas import read_sql
    self.connect()
    try:
        return read_sql(query_string, self._conn, **kwargs)
    finally:
        # Close even when read_sql raises so a failed query does not leave
        # the connection open.
        self.close()
21 changes: 8 additions & 13 deletions spackl/db/result.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
"""
Base classes to house the results of queries against a source database
"""
import six

from collections import OrderedDict
from google.cloud.bigquery.table import RowIterator
from sqlalchemy.engine import ResultProxy

Expand All @@ -12,6 +9,9 @@

class QueryResult(BaseResult):
def __init__(self, query_iterator=None):
keys = list()
result = list()

if query_iterator is not None:
if not isinstance(query_iterator, (RowIterator, ResultProxy)):
raise TypeError(
Expand All @@ -20,16 +20,11 @@ def __init__(self, query_iterator=None):
'RowIterator, returned by a call to google.cloud.bigquery.Client.query().result()'
% type(query_iterator))

result = [OrderedDict(row) for row in query_iterator]
else:
result = list()

if result:
keys = list(six.iterkeys(result[0]))
if not all(sorted(keys) == sorted(list(six.iterkeys(x))) for x in result):
raise AttributeError('keys arg does not match all result keys')
else:
keys = list()
if isinstance(query_iterator, ResultProxy):
keys = list(query_iterator.keys())
elif isinstance(query_iterator, RowIterator):
keys = [field.name for field in query_iterator.schema]
result = list(query_iterator)

super(QueryResult, self).__init__(keys, result)

Expand Down
17 changes: 16 additions & 1 deletion spackl/file/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,21 @@ def query(self, **kwargs):
kwargs : Arbitrary parameters to pass to the query method

Returns:
IOResult containing the file data
FileResult containing the file data
"""
raise NotImplementedError()

@abstractmethod
def query_df(self, **kwargs):
    """
    Abstract hook: read the file and hand back its data as a pandas DataFrame

    Implementations should call self.open() first when the file is not open

    Kwargs:
        kwargs : Arbitrary parameters to pass to the query method

    Returns:
        pandas DataFrame

    Raises:
        NotImplementedError unless overridden
    """
    raise NotImplementedError()
23 changes: 8 additions & 15 deletions spackl/file/csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,10 @@ class CSV(BaseFile):
that can be opened, read, and closed
Kwargs:
name : str - The canonical name to use for this instance
use_pandas : bool - Choose to use pandas to read the csv file, a better option if you plan
to ultimately convert the result to a DataFrame.
NOTE: This causes the query method to return a DataFrame instead of a FileResult
csv_kwargs : Parameters to pass to the csv reader (fieldnames, delimiter, dialect, etc)
"""
def __init__(self, file_path_or_obj, name=None, use_pandas=False, **csv_kwargs):
def __init__(self, file_path_or_obj, name=None, **csv_kwargs):
self._name = name
self._use_pandas = use_pandas

self._file = None
self._csv_kwargs = dict(**csv_kwargs)
Expand Down Expand Up @@ -60,7 +56,7 @@ def _set_file(self, file_path_or_obj):
_log.warning('Provided path does not exist : %s', file_path_or_obj)
file = None
else:
if zipfile.is_zipfile(str(file)) and not self._use_pandas:
if zipfile.is_zipfile(str(file)):
file = zipfile.ZipFile(str(file))
self._file = file

Expand Down Expand Up @@ -90,15 +86,7 @@ def _get_dialect(self):
self._data.seek(0)
return dialect

def _load_using_pandas(self, **kwargs):
from pandas import read_csv
return read_csv(self._file, **kwargs)

def query(self, use_pandas=False, pd_kwargs=dict(), **kwargs):
if use_pandas or self._use_pandas:
# Skip loading method and return a dataframe
return self._load_using_pandas(**pd_kwargs)

def query(self, **kwargs):
_kwargs = dict(**self._csv_kwargs)
_kwargs.update(**kwargs)

Expand All @@ -112,3 +100,8 @@ def query(self, use_pandas=False, pd_kwargs=dict(), **kwargs):
self.close()

return result

def query_df(self, **kwargs):
    """
    Loads the csv into a pandas DataFrame using pandas.read_csv

    Kwargs:
        kwargs : Arbitrary parameters to pass to pandas.read_csv

    Returns:
        pandas DataFrame
    """
    from pandas import read_csv

    source = self._file
    if isinstance(source, zipfile.ZipFile):
        # Give pandas the archive's filename rather than the ZipFile object
        source = source.filename
    return read_csv(source, **kwargs)
Loading