Skip to content

Commit

Permalink
Merge pull request #143 from rstudio/spacy
Browse files Browse the repository at this point in the history
  • Loading branch information
isabelizimm authored Apr 18, 2023
2 parents ebaddb5 + e41d4a8 commit 152dd6f
Show file tree
Hide file tree
Showing 8 changed files with 267 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ You can use vetiver with:
- [torch](https://pytorch.org/)
- [statsmodels](https://www.statsmodels.org/stable/index.html)
- [xgboost](https://xgboost.readthedocs.io/en/stable/)
- [spacy](https://spacy.io/)
- or utilize [custom handlers](https://rstudio.github.io/vetiver-python/stable/advancedusage/custom_handler.html) to support your own models!

## Installation
Expand Down
1 change: 1 addition & 0 deletions docs/_quarto.yml
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ quartodoc:
- TorchHandler
- StatsmodelsHandler
- XGBoostHandler
- SpacyHandler

metadata-files:
- _sidebar.yml
Expand Down
1 change: 1 addition & 0 deletions docs/index.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ You can use vetiver with:
- [torch](https://pytorch.org/)
- [statsmodels](https://www.statsmodels.org/stable/index.html)
- [xgboost](https://xgboost.readthedocs.io/en/stable/)
- [spacy](https://spacy.io/)
- or utilize [custom handlers](https://rstudio.github.io/vetiver-python/stable/advancedusage/custom_handler.html) to support your own models!

## Installation
Expand Down
4 changes: 4 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ all_models =
vetiver[torch]
vetiver[statsmodels]
vetiver[xgboost]
vetiver[spacy]

dev =
pytest
Expand All @@ -67,6 +68,9 @@ torch =
xgboost =
xgboost

spacy =
spacy

typecheck =
pyright
pandas-stubs
1 change: 1 addition & 0 deletions vetiver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .handlers.torch import TorchHandler # noqa
from .handlers.statsmodels import StatsmodelsHandler # noqa
from .handlers.xgboost import XGBoostHandler # noqa
from .handlers.spacy import SpacyHandler # noqa
from .helpers import api_data_to_frame # noqa
from .rsconnect import deploy_rsconnect # noqa
from .monitor import compute_metrics, pin_metrics, plot_metrics, _rolling_df # noqa
Expand Down
83 changes: 83 additions & 0 deletions vetiver/handlers/spacy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from .base import BaseHandler
from ..prototype import vetiver_create_prototype
from ..helpers import api_data_to_frame

import pandas as pd

# spacy is an optional dependency: record whether it imported instead of
# letting a missing package break the import of this module.
try:
    import spacy
except ImportError:
    spacy_exists = False
else:
    spacy_exists = True


class SpacyHandler(BaseHandler):
    """Handler used to build VetiverModels around spacy pipelines.

    Parameters
    ----------
    model :
        a trained and fit spacy model
    """

    # The lambda defers the `spacy.Language` lookup until it is called.
    model_class = staticmethod(lambda: spacy.Language)

    # Only advertise the pip requirement when spacy could be imported.
    if spacy_exists:
        pip_name = "spacy"

    def construct_prototype(self):
        """Validate the prototype data and build the input prototype.

        ``self.prototype_data`` may be ``None``, a pandas Series, a
        one-column pandas DataFrame, or a dict with exactly one key.

        Returns
        -------
        prototype :
            Input data prototype for spacy model, as produced by
            ``vetiver_create_prototype``

        Raises
        ------
        TypeError
            If the prototype data is not ``None``, a dict, a Series or a
            DataFrame.
        ValueError
            If a DataFrame does not have exactly one column, or a dict does
            not have exactly one key.
        """
        data = self.prototype_data
        accepted_types = (pd.Series, pd.DataFrame, dict)

        # Anything other than "no prototype" must be one of the accepted types.
        if not (data is None or isinstance(data, accepted_types)):
            raise TypeError(
                "Spacy prototype must be a dict, pandas Series, or pandas DataFrame"
            )

        # A DataFrame prototype must have exactly one column.
        if isinstance(data, pd.DataFrame) and len(data.columns) != 1:
            raise ValueError("Spacy prototype data must be a 1-column pandas DataFrame")

        # A dict prototype must have exactly one key.
        if isinstance(data, dict) and len(data) != 1:
            raise ValueError("Spacy prototype data must be a dictionary with 1 key")

        return vetiver_create_prototype(data)

    def handler_predict(self, input_data, check_prototype):
        """Produce the response for the /predict endpoint in VetiverAPI.

        The incoming data is converted with ``api_data_to_frame``, its first
        column is streamed through ``self.model.pipe`` and each resulting doc
        is serialized with ``to_json()``.

        Parameters
        ----------
        input_data :
            Test data
        check_prototype :
            Accepted as part of the handler signature; not read by this
            method.

        Returns
        -------
        prediction
            pandas Series holding one ``doc.to_json()`` result per input row

        Raises
        ------
        ImportError
            If spacy could not be imported.
        """
        if not spacy_exists:
            raise ImportError("Cannot import `spacy`")

        frame = api_data_to_frame(input_data)
        texts = frame.iloc[:, 0]  # only the first column is sent to the model

        return pd.Series([doc.to_json() for doc in self.model.pipe(texts)])
6 changes: 6 additions & 0 deletions vetiver/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ def _(pred_data):
return pd.DataFrame([dict(s) for s in pred_data])


@api_data_to_frame.register(pd.DataFrame)
def _pd_frame(pred_data):
    """Return a pandas DataFrame unchanged; it needs no conversion."""
    return pred_data


@api_data_to_frame.register(dict)
def _dict(pred_data):
    """Wrap a single dict in a list and dispatch through `api_data_to_frame` again."""
    as_list = [pred_data]
    return api_data_to_frame(as_list)
Expand Down
170 changes: 170 additions & 0 deletions vetiver/tests/test_spacy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
import pytest

spacy = pytest.importorskip("spacy", reason="spacy library not installed")

import numpy as np # noqa
import pandas as pd # noqa
from fastapi.testclient import TestClient # noqa
from numpy import nan # noqa
import vetiver # noqa


@spacy.language.Language.component("animals")
def animal_component_function(doc):
    """Label every span found by the module-level `matcher` as an ANIMAL entity.

    `matcher` is assigned at module level after this function is defined, so
    it is looked up only when the component actually runs on a doc.
    """
    doc.ents = [
        spacy.tokens.Span(doc, begin, stop, label="ANIMAL")
        for _match_id, begin, stop in matcher(doc)  # noqa
    ]
    return doc


# Blank English pipeline; this is the object the `spacy_model` fixture returns.
nlp = spacy.blank("en")
# Run the three animal words through the pipeline so they can serve as patterns.
animals = list(nlp.pipe(["dog", "cat", "turtle"]))
# Module-level matcher that `animal_component_function` calls on each doc.
matcher = spacy.matcher.PhraseMatcher(nlp.vocab)
matcher.add("ANIMAL", animals)
# Add the component registered under the name "animals" to the pipeline.
nlp.add_pipe("animals")


@pytest.fixture
def spacy_model():
    """Provide the module-level `nlp` pipeline to tests."""
    return nlp


@pytest.fixture()
def vetiver_client_with_prototype(spacy_model):  # With check_prototype=True
    """Test client for a VetiverAPI built with a one-column DataFrame prototype."""
    prototype = pd.DataFrame({"new_column": ["one", "two", "three"]})
    model = vetiver.VetiverModel(spacy_model, "animals", prototype_data=prototype)
    api = vetiver.VetiverAPI(model, check_prototype=True)
    api.app.root_path = "/predict"
    return TestClient(api.app)


@pytest.fixture(scope="function")
def vetiver_client_with_prototype_series(spacy_model):  # With check_prototype=True
    """Test client for a VetiverAPI built with a pandas Series prototype."""
    prototype = pd.Series({"new_column": ["one", "two", "three"]})
    model = vetiver.VetiverModel(spacy_model, "animals", prototype_data=prototype)
    api = vetiver.VetiverAPI(model, check_prototype=True)
    api.app.root_path = "/predict"
    return TestClient(api.app)


@pytest.fixture
def vetiver_client_no_prototype(spacy_model):  # With check_prototype=False
    """Test client for a VetiverAPI whose model was created without a prototype."""
    model = vetiver.VetiverModel(spacy_model, "animals")
    api = vetiver.VetiverAPI(model, check_prototype=False)
    api.app.root_path = "/predict"
    return TestClient(api.app)


@pytest.mark.parametrize("data", ["a", 1, [1, 2, 3]])
def test_bad_prototype_data(data, spacy_model):
    """Prototype data of an unsupported type must raise TypeError."""
    model_kwargs = {"prototype_data": data}
    with pytest.raises(TypeError):
        vetiver.VetiverModel(spacy_model, "animals", **model_kwargs)


@pytest.mark.parametrize(
    "data",
    [
        {"col": ["1", "2"], "col2": [1, 2]},
        pd.DataFrame({"col": ["1", "2"], "col2": [1, 2]}),
    ],
)
def test_bad_prototype_shape(data, spacy_model):
    """A two-key dict or two-column DataFrame prototype must raise ValueError."""
    model_kwargs = {"prototype_data": data}
    with pytest.raises(ValueError):
        vetiver.VetiverModel(spacy_model, "animals", **model_kwargs)


@pytest.mark.parametrize("data", [{"col": "1"}, pd.DataFrame({"col": ["1"]})])
def test_good_prototype_shape(data, spacy_model):
    """A one-key dict or one-column DataFrame yields the expected prototype."""
    model = vetiver.VetiverModel(spacy_model, "animals", prototype_data=data)
    constructed = model.prototype.construct().dict()

    assert constructed == {"col": "1"}


def test_vetiver_predict_with_prototype(vetiver_client_with_prototype):
    """Predict through the prototype-checked client and compare the full response."""
    texts = pd.DataFrame({"new_column": ["turtles", "i have a dog"]})

    response = vetiver.predict(endpoint=vetiver_client_with_prototype, data=texts)
    assert isinstance(response, pd.DataFrame), response

    turtles_doc = {
        "text": "turtles",
        "ents": [],
        "sents": [{"start": 0, "end": 7}],
        "tokens": [{"id": 0, "start": 0, "end": 7}],
    }
    dog_doc = {
        "text": "i have a dog",
        "ents": [{"start": 9, "end": 12, "label": "ANIMAL"}],
        "sents": nan,
        "tokens": [
            {"id": 0, "start": 0, "end": 1},
            {"id": 1, "start": 2, "end": 6},
            {"id": 2, "start": 7, "end": 8},
            {"id": 3, "start": 9, "end": 12},
        ],
    }
    assert response.to_dict() == {"0": turtles_doc, "1": dog_doc}


def test_vetiver_predict_no_prototype(vetiver_client_no_prototype):
    """Predict through the client without prototype checking and compare the response."""
    texts = pd.DataFrame({"uhhh": ["turtles", "i have a dog"]})

    response = vetiver.predict(endpoint=vetiver_client_no_prototype, data=texts)
    assert isinstance(response, pd.DataFrame), response

    turtles_doc = {
        "text": "turtles",
        "ents": [],
        "sents": [{"start": 0, "end": 7}],
        "tokens": [{"id": 0, "start": 0, "end": 7}],
    }
    dog_doc = {
        "text": "i have a dog",
        "ents": [{"start": 9, "end": 12, "label": "ANIMAL"}],
        "sents": nan,
        "tokens": [
            {"id": 0, "start": 0, "end": 1},
            {"id": 1, "start": 2, "end": 6},
            {"id": 2, "start": 7, "end": 8},
            {"id": 3, "start": 9, "end": 12},
        ],
    }
    assert response.to_dict() == {"0": turtles_doc, "1": dog_doc}


def test_serialize_no_prototype(spacy_model):
    """Round-trip a prototype-less model through a temporary pins board."""
    import pins

    temp_board = pins.board_temp(allow_pickle_read=True)
    original = vetiver.VetiverModel(spacy_model, "animals")
    vetiver.vetiver_pin_write(board=temp_board, model=original)

    restored = vetiver.VetiverModel.from_pin(temp_board, "animals")
    assert isinstance(restored.model, spacy.lang.en.English)


def test_serialize_prototype(spacy_model):
    """Round-trip a model with a DataFrame prototype through a temporary pins board."""
    import pins

    temp_board = pins.board_temp(allow_pickle_read=True)
    prototype = pd.DataFrame({"text": ["text"]})
    original = vetiver.VetiverModel(spacy_model, "animals", prototype_data=prototype)
    vetiver.vetiver_pin_write(board=temp_board, model=original)

    restored = vetiver.VetiverModel.from_pin(temp_board, "animals")
    assert isinstance(restored.model, spacy.lang.en.English)

0 comments on commit 152dd6f

Please sign in to comment.