Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding Reddit as a third party datasource to EvaDB #1452

Open
wants to merge 3 commits into
base: staging
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions docs/source/reference/databases/reddit.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
Reddit
==========

The connection to Reddit is based on the `praw <https://praw.readthedocs.io/>`_ library.

Dependency
----------

* praw


Parameters
----------

Required:

* ``subreddit`` is the name of the subreddit from which the data is fetched.
* ``client_id`` is the unique identifier issued to the client when creating credentials on Reddit. Refer to the `First Steps <https://github.com/reddit-archive/reddit/wiki/OAuth2-Quick-Start-Example#first-steps>`_ guide for more details on how to get this and the next two parameters.
* ``clientSecret`` is the secret key obtained when credentials are created that is used for authentication and authorization.
* ``userAgent`` is a string of your choosing that explains your use of the Reddit API. More details are available in the guide linked above.

Optional:


Create Connection
-----------------

.. code-block:: text

CREATE DATABASE reddit_data WITH ENGINE = 'reddit', PARAMETERS = {
"subreddit": "AskReddit",
"client_id": "abcd",
"clientSecret": "abcd1234",
"userAgent": "Eva DB Staging Build"
};

Supported Tables
----------------

* ``submissions``: Lists top submissions in the given subreddit. Check `databases/reddit/table_column_info.py` for all the available columns in the table.

.. code-block:: sql

    SELECT * FROM reddit_data.submissions LIMIT 3;

.. note::

    Looking for another table from Reddit? Please raise a `Feature Request <https://github.com/georgia-tech-db/evadb/issues/new/choose>`_.
2 changes: 2 additions & 0 deletions evadb/third_party/databases/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ def _get_database_handler(engine: str, **kwargs):
return mod.HackernewsSearchHandler(engine, **kwargs)
elif engine == "slack":
return mod.SlackHandler(engine, **kwargs)
elif engine == "reddit":
return mod.RedditHandler(engine, **kwargs)
else:
raise NotImplementedError(f"Engine {engine} is not supported")

Expand Down
15 changes: 15 additions & 0 deletions evadb/third_party/databases/reddit/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# coding=utf-8
# Copyright 2018-2023 EvaDB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""third party/applications/reddit"""
168 changes: 168 additions & 0 deletions evadb/third_party/databases/reddit/reddit_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
# coding=utf-8
# Copyright 2018-2023 EvaDB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pandas as pd
from praw import Reddit
from prawcore import ResponseException

from .table_column_info import SUBMISSION_COLUMNS
from ..types import DBHandler, DBHandlerResponse, DBHandlerStatus


class RedditHandler(DBHandler):
    """Data source handler exposing subreddit submissions through praw.

    Connection parameters (passed as kwargs):
        subreddit: name of the subreddit to read from.
        client_id / clientId: OAuth client id issued by Reddit.
        client_secret / clientSecret: OAuth client secret.
        user_agent / userAgent: descriptive user-agent string for the API.

    Both snake_case and camelCase spellings are accepted for the credential
    parameters, since the documentation and tests currently use a mix of
    the two.
    """

    def __init__(self, name: str, **kwargs):
        super().__init__(name)

        def _param(*names):
            # Return the value of the first parameter present under any
            # accepted spelling, or None if none is supplied.
            for key in names:
                if key in kwargs:
                    return kwargs[key]
            return None

        self.client_id = _param("client_id", "clientId")
        self.client_secret = _param("client_secret", "clientSecret")
        self.user_agent = _param("user_agent", "userAgent")
        self.subreddit = kwargs.get("subreddit")
        # Populated by connect(); initialized to None so other methods can
        # detect a missing connection instead of raising AttributeError.
        self.client = None

    def connect(self):
        """Create the praw Reddit client.

        Returns:
            DBHandlerStatus: success flag plus error text on failure.
        """
        try:
            self.client = Reddit(
                client_id=self.client_id,
                client_secret=self.client_secret,
                user_agent=self.user_agent,
            )
            return DBHandlerStatus(status=True)
        except Exception as e:
            return DBHandlerStatus(status=False, error=str(e))

    @property
    def supported_table(self):
        """Mapping of table name -> {"columns": ..., "generator": ...}."""

        def _submission_generator():
            # praw fetches lazily; iterating .hot() streams submissions
            # from the configured subreddit.
            for submission in self.client.subreddit(self.subreddit).hot():
                yield {
                    property_name: getattr(submission, property_name)
                    for property_name, _ in SUBMISSION_COLUMNS
                }

        mapping = {
            "submissions": {
                "columns": SUBMISSION_COLUMNS,
                "generator": _submission_generator(),
            },
        }
        return mapping

    def disconnect(self):
        """No action required to disconnect from the Reddit data source.

        TODO: Add support for destroying the session token if it is used
        in other flows.
        """
        return

    def check_connection(self) -> DBHandlerStatus:
        """Verify the connection by fetching the authenticated user."""
        if self.client is None:
            return DBHandlerStatus(
                status=False, error="Not connected to the database."
            )
        try:
            self.client.user.me()
        except ResponseException as e:
            return DBHandlerStatus(
                status=False, error=f"Received ResponseException: {e.response}"
            )
        return DBHandlerStatus(status=True)

    def get_tables(self) -> DBHandlerResponse:
        """Return a DataFrame listing the names of the supported tables."""
        connection_status = self.check_connection()
        if not connection_status.status:
            return DBHandlerResponse(data=None, error=str(connection_status))

        try:
            tables_df = pd.DataFrame(
                list(self.supported_table.keys()), columns=["table_name"]
            )
            return DBHandlerResponse(data=tables_df)
        except Exception as e:
            return DBHandlerResponse(data=None, error=str(e))

    def get_columns(self, table_name: str) -> DBHandlerResponse:
        """Return a DataFrame of (name, dtype) pairs for the given table."""
        if table_name not in self.supported_table:
            # Guard against a KeyError on unknown tables.
            return DBHandlerResponse(
                data=None,
                error="{} is not supported or does not exist.".format(table_name),
            )
        columns = self.supported_table[table_name]["columns"]
        columns_df = pd.DataFrame(columns, columns=["name", "dtype"])
        return DBHandlerResponse(data=columns_df)

    def select(self, table_name: str) -> DBHandlerResponse:
        """
        Returns a generator that yields the data from the given table.
        Args:
            table_name (str): name of the table whose data is to be retrieved.
        Returns:
            DBHandlerResponse
        """
        if self.client is None:
            return DBHandlerResponse(data=None, error="Not connected to the database.")
        try:
            if table_name not in self.supported_table:
                return DBHandlerResponse(
                    data=None,
                    error="{} is not supported or does not exist.".format(table_name),
                )
            # TODO: Projection column trimming optimization opportunity
            return DBHandlerResponse(
                data=None,
                data_generator=self.supported_table[table_name]["generator"],
            )
        except Exception as e:
            return DBHandlerResponse(data=None, error=str(e))
27 changes: 27 additions & 0 deletions evadb/third_party/databases/reddit/table_column_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# coding=utf-8
# Copyright 2018-2023 EvaDB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Column specification for the Reddit ``submissions`` table.

Each entry is a ``[column_name, dtype]`` pair. The names are attribute
names of ``praw.models.Submission`` objects and are read via ``getattr``
by the Reddit handler's row generator.
"""
from typing import Optional

# NOTE(review): the dtypes below mirror what the handler expects, not a
# verified praw schema — e.g. praw documents `edited` as False-or-timestamp
# and `distinguished` as str-or-None; confirm against the praw docs before
# relying on these types for anything stricter than display.
SUBMISSION_COLUMNS = [
    ["author", str],
    ["author_flair_text", Optional[str]],
    ["clicked", bool],
    ["created_utc", str],
    ["distinguished", bool],
    ["edited", bool],
    ["id", str],
    ["is_original_content", bool],
    ["is_self", bool],
    ["link_flair_text", Optional[str]],
    ["locked", bool],
    ["name", str],
    ["num_comments", int],
    ["over_18", bool],
    ["permalink", str],
    ["saved", bool],
    ["score", float],
    ["selftext", str],
    ["spoiler", bool],
    ["stickied", bool],
    ["title", str],
    ["upvote_ratio", float],
    ["url", str],
]
Empty file added run_reddit_command.py
Empty file.
1 change: 1 addition & 0 deletions script/formatting/spelling.txt
Original file line number Diff line number Diff line change
Expand Up @@ -695,6 +695,7 @@ PlanOprType
Popen
PostgresHandler
PostgresNativeStorageEngineTest
praw
PredicateExecutor
PredicatePlan
PredictEmployee
Expand Down
7 changes: 6 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,10 @@ def read(path, encoding="utf-8"):
"replicate"
]

reddit_libs = [
"praw"
]

### NEEDED FOR DEVELOPER TESTING ONLY

dev_libs = [
Expand Down Expand Up @@ -183,8 +187,9 @@ def read(path, encoding="utf-8"):
"xgboost": xgboost_libs,
"forecasting": forecasting_libs,
"hackernews": hackernews_libs,
"reddit": reddit_libs,
# everything except ray, qdrant, ludwig and postgres. The first three fail on pyhton 3.11.
"dev": dev_libs + vision_libs + document_libs + function_libs + notebook_libs + forecasting_libs + sklearn_libs + imagegen_libs + xgboost_libs
"dev": dev_libs + vision_libs + document_libs + function_libs + notebook_libs + forecasting_libs + sklearn_libs + imagegen_libs + xgboost_libs + reddit_libs
}

setup(
Expand Down
60 changes: 60 additions & 0 deletions test/integration_tests/long/test_reddit_datasource.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# coding=utf-8
# Copyright 2018-2023 EvaDB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

from test.markers import reddit_skip_marker
from test.util import get_evadb_for_testing

import pytest

from evadb.server.command_handler import execute_query_fetch_all
from evadb.third_party.databases.reddit.table_column_info import SUBMISSION_COLUMNS


@pytest.mark.notparallel
class RedditDataSourceTest(unittest.TestCase):
    """Integration tests for the Reddit third-party data source."""

    def setUp(self):
        self.evadb = get_evadb_for_testing()
        # Start each test from a clean catalog state.
        self.evadb.catalog().reset()

    def tearDown(self):
        # Remove the database registered by the test, if it was created.
        execute_query_fetch_all(self.evadb, "DROP DATABASE IF EXISTS reddit_data;")

    @reddit_skip_marker
    def test_should_run_select_query_on_reddit(self):
        # Register the Reddit data source with placeholder credentials.
        connection_params = {
            "subreddit": "cricket",
            "client_id": "clientid..",
            "client_secret": "clientsecret..",
            "user_agent": "test script for dev eva",
        }
        create_query = f"""CREATE DATABASE reddit_data
            WITH ENGINE = "reddit",
            PARAMETERS = {connection_params};"""
        execute_query_fetch_all(self.evadb, create_query)

        # Fetch a small batch of submissions and validate shape + columns.
        batch = execute_query_fetch_all(
            self.evadb, "SELECT * FROM reddit_data.submissions LIMIT 10;"
        )
        self.assertEqual(len(batch), 10)
        expected_columns = [
            "submissions.{}".format(column_name)
            for column_name, _ in SUBMISSION_COLUMNS
        ]
        self.assertEqual(batch.columns, expected_columns)


if __name__ == "__main__":
unittest.main()
4 changes: 4 additions & 0 deletions test/markers.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,7 @@
stable_diffusion_skip_marker = pytest.mark.skipif(
is_replicate_available() is False, reason="requires replicate"
)

# Unlike the availability-based `skipif` markers above, Reddit tests are
# always skipped: they need real Reddit API credentials (client id/secret),
# which are not available in CI.
reddit_skip_marker = pytest.mark.skip(
    reason="requires Reddit secret key"
)