Contextual Generate model
Signed-off-by: Sean Smith <[email protected]>
sean-smith committed Mar 3, 2025
1 parent cb16cf6 commit 26236cd
Showing 10 changed files with 283 additions and 0 deletions.
@@ -0,0 +1,19 @@
# Contextual LLM Integration for LlamaIndex

This package provides a Contextual LLM integration for LlamaIndex.

## Installation

```bash
pip install llama-index-llms-contextual
```

## Usage

```python
from llama_index.llms.contextual import Contextual

llm = Contextual(model="contextual-clm", api_key="your_api_key")

response = llm.complete("Explain the importance of Grounded Language Models.")
```
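
The GLM is designed to ground its answers in context you supply at request time. The `complete` method accepts a `knowledge` argument, a list of strings (e.g. retrieved passages) that the model treats as its source of truth. A minimal sketch, reusing the `llm` from above:

```python
response = llm.complete(
    "What color is the sky?",
    knowledge=["The sky is blue."],
)
print(response.text)
```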
@@ -0,0 +1 @@
python_sources()
@@ -0,0 +1,3 @@
from llama_index.llms.contextual.base import Contextual

__all__ = ["Contextual"]
@@ -0,0 +1,115 @@
from typing import Any, List, Optional

from pydantic import Field

from llama_index.core.base.llms.types import (
    ChatMessage,
    CompletionResponse,
    MessageRole,
)
from llama_index.core.llms.callbacks import llm_completion_callback
from llama_index.llms.openai_like import OpenAILike

from contextual import ContextualAI


class Contextual(OpenAILike):
    """
    Generate a response using Contextual's Grounded Language Model (GLM), an LLM
    engineered specifically to prioritize faithfulness to in-context retrievals over
    parametric knowledge, reducing hallucinations in Retrieval-Augmented Generation.

    The total request cannot exceed 32,000 tokens. Email [email protected] with any
    feedback or questions.

    Examples:
        `pip install llama-index-llms-contextual`

        ```python
        from llama_index.llms.contextual import Contextual

        # Set up the Contextual class with the required model and API key
        llm = Contextual(model="v1", api_key="your_api_key")

        # Call the complete method with a query
        response = llm.complete("Explain the importance of low latency LLMs")
        print(response)
        ```
    """

    model: str = Field(
        default="v1",
        description="The model to use. Currently only supports `v1`.",
    )
    api_key: str = Field(default=None, description="The API key to use.")
    base_url: str = Field(
        default="https://api.contextual.ai/v1/generate",
        description="The base URL to use.",
    )
    avoid_commentary: bool = Field(
        default=False,
        description=(
            "Flag to indicate whether the model should avoid providing additional "
            "commentary in responses. Commentary is conversational in nature and does "
            "not contain verifiable claims; therefore, commentary is not strictly "
            "grounded in available context. However, commentary may provide useful "
            "context that improves the helpfulness of responses."
        ),
    )
    client: Any = Field(default=None, description="Contextual AI client")

    def __init__(
        self,
        model: str,
        api_key: str,
        base_url: Optional[str] = None,
        avoid_commentary: bool = False,
        **openai_llm_kwargs: Any,
    ) -> None:
        super().__init__(
            model=model,
            api_key=api_key,
            api_base=base_url,
            is_chat_model=openai_llm_kwargs.pop("is_chat_model", True),
            **openai_llm_kwargs,
        )

        try:
            self.client = ContextualAI(api_key=api_key, base_url=base_url)
        except Exception as e:
            raise ValueError(f"Error initializing ContextualAI client: {e}") from e

    @classmethod
    def class_name(cls) -> str:
        """Get class name."""
        return "contextual-clm"

    # Synchronous methods
    @llm_completion_callback()
    def complete(
        self, prompt: str, knowledge: Optional[List[str]] = None, **kwargs: Any
    ) -> CompletionResponse:
        """
        Generate a completion for the given prompt.

        Args:
            prompt (str): The input prompt to generate a completion for.
            knowledge (Optional[List[str]]): Knowledge sources (e.g. retrieved
                passages) for the model to ground its response in.
            **kwargs: Additional keyword arguments for the API request.

        Returns:
            CompletionResponse: The generated text completion.
        """
        return self._generate(
            knowledge=knowledge,
            messages=[ChatMessage(role=MessageRole.USER, content=prompt)],
            system_prompt=self.system_prompt,
            **kwargs,
        )

    def _generate(
        self,
        knowledge: Optional[List[str]],
        messages: List[ChatMessage],
        system_prompt: Optional[str],
        **kwargs: Any,
    ) -> CompletionResponse:
        """Call the Contextual generate endpoint and wrap the result."""
        raw_message = self.client.generate.create(
            messages=[
                {"role": msg.role, "content": msg.blocks[0].text} for msg in messages
            ],
            knowledge=knowledge or [],
            model=self.model,
            system_prompt=system_prompt,
            avoid_commentary=self.avoid_commentary,
            temperature=kwargs.get("temperature", 0.0),
            max_new_tokens=kwargs.get("max_tokens", 1024),
            top_p=kwargs.get("top_p", 1),
        )
        return CompletionResponse(text=raw_message.response)
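
Note how `_generate` translates `complete()` kwargs into the generate endpoint's parameters: `max_tokens` is forwarded as `max_new_tokens`, while `temperature` and `top_p` pass through with defaults of 0.0 and 1. A usage sketch with illustrative values:

```python
# Illustrative call: these kwargs are read out of **kwargs by _generate
# and forwarded to client.generate.create().
response = llm.complete(
    "Explain the importance of low latency LLMs",
    knowledge=["Latency under 200 ms keeps interactions conversational."],  # illustrative passage
    temperature=0.5,  # forwarded as `temperature` (default 0.0)
    max_tokens=256,   # forwarded as `max_new_tokens` (default 1024)
    top_p=0.9,        # forwarded as `top_p` (default 1)
)
print(response.text)
```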
@@ -0,0 +1,92 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Contextual GLM"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install llama-index-llms-contextual"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"CompletionResponse(text=\"I apologize, but I am unable to provide information about Grounded Language Models. I am an AI assistant created by Contextual AI. I don't have relevant documentation about that topic, but feel free to ask me something else!\", additional_kwargs={}, raw=None, logprobs=None, delta=None)"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from llama_index.llms.contextual import Contextual\n",
"from dotenv import load_dotenv\n",
"import os\n",
"\n",
"# Set up the Contextual class with the required model and API key\n",
"# Store the API key in a .env file as CONTEXTUAL_API_KEY\n",
"load_dotenv()\n",
"llm = Contextual(model=\"v1\", api_key=os.getenv(\"CONTEXTUAL_API_KEY\"))\n",
"\n",
"# Call the complete method with a query\n",
"llm.complete(\"Explain the importance of Grounded Language Models.\", temperature=0.5, max_tokens=1024, avoid_commentary=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'The sky is blue.'"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\n",
"llm.complete(\"what color is the sky?\", knowledge=[\"The sky is blue\"], avoid_commentary=False, temperature=0.9, max_tokens=1)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "llama-index-VCjo73HL-py3.10",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
@@ -0,0 +1,46 @@
[build-system]
build-backend = "poetry.core.masonry.api"
requires = ["poetry-core"]

[tool.codespell]
check-filenames = true
check-hidden = true
skip = "*.csv,*.html,*.json,*.jsonl,*.pdf,*.txt,*.ipynb"

[tool.llamahub]
contains_example = false
import_path = "llama_index.llms.contextual"

[tool.llamahub.class_authors]
Contextual = "sean-smith"

[tool.mypy]
disallow_untyped_defs = true
exclude = ["_static", "build", "examples", "notebooks", "venv"]
ignore_missing_imports = true
python_version = "3.8"

[tool.poetry]
authors = ["Sean Smith <[email protected]>"]
description = "llama-index contextual integration"
exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-llms-contextual"
readme = "README.md"
version = "0.0.1"

[tool.poetry.dependencies]
python = ">=3.9,<4.0"
llama-index-llms-openai-like = "^0.3.3"
contextual-client = "^0.4.0"

[tool.poetry.group.dev.dependencies.black]
extras = ["jupyter"]
version = "<=23.9.1,>=23.7.0"

[tool.poetry.group.dev.dependencies.codespell]
extras = ["toml"]
version = ">=v2.2.6"

[[tool.poetry.packages]]
include = "llama_index/"
@@ -0,0 +1 @@
python_sources()
Empty file.
@@ -0,0 +1,6 @@
from llama_index.core.base.llms.base import BaseLLM
from llama_index.llms.contextual import Contextual


def test_llm_class():
    names_of_base_classes = [b.__name__ for b in Contextual.__mro__]
    assert BaseLLM.__name__ in names_of_base_classes
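
The test above only checks class inheritance. A sketch of a behavioral test that stubs out the `ContextualAI` client so `complete()` never makes a network call (assuming the client can be constructed with a dummy key):

```python
from unittest.mock import MagicMock

from llama_index.llms.contextual import Contextual


def test_complete_uses_generate_endpoint():
    # Construct with a dummy key; the real client is replaced below.
    llm = Contextual(model="v1", api_key="fake-key")

    # Stub the generate endpoint so no request leaves the process.
    llm.client = MagicMock()
    llm.client.generate.create.return_value = MagicMock(response="The sky is blue.")

    result = llm.complete("What color is the sky?", knowledge=["The sky is blue."])

    assert result.text == "The sky is blue."
    llm.client.generate.create.assert_called_once()
```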
