Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
0564285
input_json for context_attribution example
dennislwei Feb 13, 2026
5f49958
Mark sentence boundaries in all_but_last_message
dennislwei Feb 13, 2026
3b8b6e2
Add context_attribution.yaml configuration
dennislwei Feb 13, 2026
fbdbe65
Escape braces in context_attribution.yaml instruction
dennislwei Feb 14, 2026
0c08b34
Add all_but_last_message source to DecodeSentences
dennislwei Feb 24, 2026
4b1087a
Update context_attribution.yaml for all_but_last_message decoding
dennislwei Feb 27, 2026
22a82ab
Merge branch 'feat/issue119-context-attribution' of github.com:dennis…
dennislwei Feb 27, 2026
e1875e5
Add context_attribution_single config and rename files
dennislwei Feb 27, 2026
0b6d234
Add WrapInList transformation rule
dennislwei Feb 27, 2026
af35e6f
Add test_canned_input for context_attribution_all and _single
dennislwei Mar 3, 2026
8e0dce0
Add YamlJsonCombo entries for context_attribution_all and _single
dennislwei Mar 4, 2026
08c2939
Add test_canned_output for context_attribution_all and _single
dennislwei Mar 4, 2026
e12aec3
Replace "citation" with "attribution" in context_attribution_*.yaml
dennislwei Mar 9, 2026
22d3662
Merge branch 'feat/issue119-context-attribution' of github.com:dennis…
dennislwei Mar 9, 2026
dbbd733
Merge branch 'main' into feat/issue119-context-attribution
dennislwei Mar 13, 2026
6a75d09
Rename context_attribution_all to context-attribution and remove cont…
dennislwei Mar 13, 2026
91517eb
Add repo_id for context-attribution and regenerate expected result
dennislwei Mar 13, 2026
8d2c156
Add test_run_transformers expected output for context-attribution
dennislwei Mar 13, 2026
71f8c46
Add test_run_ollama expected output for context-attribution
dennislwei Mar 13, 2026
13e7a21
Exclude context-attribution from test_run_ollama
dennislwei Mar 13, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 42 additions & 15 deletions src/granite_common/intrinsics/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ def sentence_delimiter(tag, sentence_num) -> str:


def mark_sentence_boundaries(
    split_strings: list[list[str]], tag_prefix: str, index: int = 0
) -> tuple[list[str], int]:
    """
    Modify one or more input strings by inserting a tag in the form
    ``<[prefix][number]>`` before each sentence.

    :param split_strings: Input string(s), pre-split into sentences
    :param tag_prefix: String to place before the number part of each tagged
        sentence boundary.
    :param index: Starting index for sentence numbering (default: 0). Passing
        the index returned by a previous call lets numbering continue across
        multiple groups of strings (e.g. documents, then message history).

    :returns: Tuple of (list of input strings with all sentence boundaries
        marked, next available index)
    """
    result = []
    for sentences in split_strings:
        to_concat = []
        for sentence in sentences:
            # Prepend the numbered boundary tag to each sentence; the counter
            # runs continuously across all strings in `split_strings`.
            to_concat.append(f"{sentence_delimiter(tag_prefix, index)}{sentence}")
            index += 1
        result.append(" ".join(to_concat))
    # Return the next unused index so callers can chain further marking.
    return result, index


def move_documents_to_message(
Expand Down Expand Up @@ -252,10 +253,11 @@ def __init__(
f"Received {self.sentence_boundaries}."
)
for k, v in self.sentence_boundaries.items():
if k not in ("last_message", "documents"):
if k not in ("last_message", "documents", "all_but_last_message"):
raise ValueError(
f"Unexpected location '{k}' in 'sentence_boundaries' field. "
f"Value should be 'last_message' or 'documents'."
f"Value should be 'last_message', 'documents', or "
f"'all_but_last_message'."
)
if not isinstance(v, str):
raise TypeError(
Expand All @@ -282,20 +284,27 @@ def _mark_sentence_boundaries(
:param chat_completion: Argument to :func:`_transform()`
:type chat_completion: ChatCompletion
:return: Copy of original chat completion with sentence boundaries marked in
the last message and in documents.
the last message, in documents, and/or in all but the last message.
:rtype: ChatCompletion
"""
# Initialize sentence index counter
index = 0

# Mark sentence boundaries in the last message.
if "last_message" in self.sentence_boundaries:
messages = chat_completion.messages.copy() # Do not modify input!
last_message_as_sentences = list(
self.sentence_splitter.tokenize(messages[-1].content)
)
rewritten_last_message_text = mark_sentence_boundaries(
[last_message_as_sentences], self.sentence_boundaries["last_message"]
)[0]
messages[-1].content = rewritten_last_message_text
rewritten_texts, _ = mark_sentence_boundaries(
[last_message_as_sentences],
self.sentence_boundaries["last_message"],
index,
)
messages[-1].content = rewritten_texts[0]
chat_completion = chat_completion.model_copy(update={"messages": messages})
# Reset index for subsequent cases (documents, all_but_last_message)
index = 0

# Mark sentence boundaries in documents if present
if (
Expand All @@ -309,13 +318,14 @@ def _mark_sentence_boundaries(
# The documents input to the model consists of the original documents
# with each sentence boundary marked with <c0>, <c1>, ... <ck-1>,
# where `k` is the number of sentences in ALL documents.
rewritten_doc_texts, index = mark_sentence_boundaries(
docs_as_sentences, self.sentence_boundaries["documents"], index
)
rewritten_docs = [
doc.model_copy(update={"text": text})
for doc, text in zip(
chat_completion.extra_body.documents,
mark_sentence_boundaries(
docs_as_sentences, self.sentence_boundaries["documents"]
),
rewritten_doc_texts,
strict=True,
)
]
Expand All @@ -326,6 +336,23 @@ def _mark_sentence_boundaries(
chat_completion = chat_completion.model_copy(
update={"extra_body": extra_body}
)

# Mark sentence boundaries in all messages except the last one
if "all_but_last_message" in self.sentence_boundaries:
messages = chat_completion.messages.copy() # Do not modify input!
# Process all messages except the last one
for i in range(len(messages) - 1):
message_as_sentences = list(
self.sentence_splitter.tokenize(messages[i].content)
)
rewritten_texts, index = mark_sentence_boundaries(
[message_as_sentences],
self.sentence_boundaries["all_but_last_message"],
index,
)
messages[i].content = rewritten_texts[0]
chat_completion = chat_completion.model_copy(update={"messages": messages})

return chat_completion

def _transform(
Expand Down
200 changes: 131 additions & 69 deletions src/granite_common/intrinsics/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,30 +449,38 @@ def __init__(
config: dict,
input_path_expr: list[str | int | None],
/,
source: str,
source: str | list[str],
output_names: dict,
):
"""
:param source: Name of the location to look for sentences; can be "last_message"
or "documents".
:param source: Name of the location to look for sentences, or a list of
locations to process in order. Each element can be "last_message",
"documents", or "all_but_last_message".
:param output_names: Names of new result fields to add
"""
super().__init__(config, input_path_expr)

allowed_sources = ("last_message", "documents")
if source not in allowed_sources:
raise ValueError(
f"'source' argument must be one of {allowed_sources}. "
f"Received '{source}'"
)
# Normalize source to always be a list
if isinstance(source, str):
source = [source]

allowed_sources = ("last_message", "documents", "all_but_last_message")
for s in source:
if s not in allowed_sources:
raise ValueError(
f"'source' argument must be one of {allowed_sources}. "
f"Received '{s}'"
)
self.source = source

if not isinstance(output_names, dict):
raise TypeError(
f"Expected mapping for output_names, but received {output_names}"
)
for k in output_names:
if source == "documents" and k == "document_id":
if "documents" in self.source and k == "document_id":
continue
if "all_but_last_message" in self.source and k == "message_index":
continue
if k not in ("begin", "end", "text"):
raise ValueError(f"Unexpected key '{k}' in output_names")
Expand All @@ -482,6 +490,7 @@ def __init__(
self.end_name = output_names.get("end")
self.text_name = output_names.get("text")
self.document_id_name = output_names.get("document_id")
self.message_index_name = output_names.get("message_index")

if config["docs_as_message"] and config["docs_as_message"] not in [
"json",
Expand All @@ -506,86 +515,127 @@ def _prepare(
f"'{self.rule_name()}' rule requires this object."
)

if self.source == "documents":
tag = self.config["sentence_boundaries"]["documents"]
if tag is None:
raise ValueError(
f"'{self.rule_name()}' attempting to decode document sentences, "
f"but 'sentence_boundaries' section of config file is missing "
f"the entry that tells how to tag document sentence boundaries."
)
begins = []
ends = []
texts = []
document_ids = []
message_indices = []
next_sentence_num = 0

for s in self.source:
if s == "documents":
tag = self.config["sentence_boundaries"]["documents"]
if tag is None:
raise ValueError(
f"'{self.rule_name()}' attempting to decode document "
f"sentences, but 'sentence_boundaries' section of config "
f"file is missing the entry that tells how to tag document "
f"sentence boundaries."
)

if not self.config["docs_as_message"]:
# Most common path: Documents from extra_body
documents = chat_completion.extra_body.documents
else:
# Model requires documents in a user message. Decode the message.
if self.config["docs_as_message"] == "json":
documents_json = json.loads(chat_completion.messages[0].content)
documents = [Document.model_validate(d) for d in documents_json]
elif self.config["docs_as_message"] == "roles":
documents = []
for message in chat_completion.messages:
if message.role.startswith("document "):
document = Document(
doc_id=message.role[len("document ") :],
text=message.content,
)
documents.append(document)
if not self.config["docs_as_message"]:
# Most common path: Documents from extra_body
documents = chat_completion.extra_body.documents
else:
# Model requires documents in a user message. Decode the message.
if self.config["docs_as_message"] == "json":
documents_json = json.loads(chat_completion.messages[0].content)
documents = [Document.model_validate(d) for d in documents_json]
elif self.config["docs_as_message"] == "roles":
documents = []
for message in chat_completion.messages:
if message.role.startswith("document "):
document = Document(
doc_id=message.role[len("document ") :],
text=message.content,
)
documents.append(document)
else:
raise ValueError(
f"Unsupported doc type {self.config['docs_as_message']}"
)

if documents is None:
documents = []

# De-split the sentences in each document in turn. Sentence numbers
# start at zero on the first document and continue in subsequent
# documents.
for d in documents:
local_results = _desplit_sentences(d.text, tag, next_sentence_num)
num_local_sentences = len(local_results["begins"])
begins.extend(local_results["begins"])
ends.extend(local_results["ends"])
texts.extend(local_results["texts"])
document_ids.extend([d.doc_id] * num_local_sentences)
message_indices.extend([None] * num_local_sentences)
next_sentence_num += num_local_sentences

elif s == "last_message":
tag = self.config["sentence_boundaries"]["last_message"]
if tag is None:
raise ValueError(
f"Unsupported doc type {self.config['docs_as_message']}"
f"'{self.rule_name()}' attempting to decode the last message, "
f"but 'sentence_boundaries' section of config file is missing "
f"the entry that tells how to tag message sentence boundaries."
)

if documents is None:
documents = []
# Use second-to-last turn if input processing added an instruction turn
message_ix = -2 if self.config["instruction"] else -1
target_text = chat_completion.messages[message_ix].content

# De-split the sentences in each document in turn. Sentence numbers
# start at zero on the first document and continue in subsequent documents.
begins = []
ends = []
texts = []
document_ids = []

next_sentence_num = 0
for d in documents:
local_results = _desplit_sentences(d.text, tag, next_sentence_num)
local_results = _desplit_sentences(target_text, tag, next_sentence_num)
num_local_sentences = len(local_results["begins"])
begins.extend(local_results["begins"])
ends.extend(local_results["ends"])
texts.extend(local_results["texts"])
document_ids.extend([d.doc_id] * num_local_sentences)
document_ids.extend([None] * num_local_sentences)
message_indices.extend([None] * num_local_sentences)
next_sentence_num += num_local_sentences

return {
"begins": begins,
"ends": ends,
"texts": texts,
"document_ids": document_ids,
}
if self.source == "last_message":
tag = self.config["sentence_boundaries"]["last_message"]
if tag is None:
raise ValueError(
f"'{self.rule_name()}' attempting to decode the last message, "
f"but 'sentence_boundaries' section of config file is missing "
f"the entry that tells how to tag message sentence boundaries."
)
elif s == "all_but_last_message":
tag = self.config["sentence_boundaries"]["all_but_last_message"]
if tag is None:
raise ValueError(
f"'{self.rule_name()}' attempting to decode conversation "
f"history sentences, but 'sentence_boundaries' section of "
f"config file is missing the entry that tells how to tag "
f"all_but_last_message sentence boundaries."
)

# Use second-to-last turn if the input processing added an instruction turn
message_ix = -2 if self.config["instruction"] else -1
target_text = chat_completion.messages[message_ix].content
# Use second-to-last as the boundary if an instruction turn was added
last_ix = -2 if self.config["instruction"] else -1
messages = chat_completion.messages[:last_ix]
for i, message in enumerate(messages):
local_results = _desplit_sentences(
message.content, tag, next_sentence_num
)
num_local_sentences = len(local_results["begins"])
begins.extend(local_results["begins"])
ends.extend(local_results["ends"])
texts.extend(local_results["texts"])
document_ids.extend([None] * num_local_sentences)
message_indices.extend([i] * num_local_sentences)
next_sentence_num += num_local_sentences

return _desplit_sentences(target_text, tag, 0)
else:
raise ValueError(f"Unexpected source string '{s}'")

raise ValueError(f"Unexpected source string '{self.source}'")
return {
"begins": begins,
"ends": ends,
"texts": texts,
"document_ids": document_ids,
"message_indices": message_indices,
}

def _transform(self, value: Any, path: tuple, prepare_output: dict) -> dict:
# Unpack global values we set aside during the prepare phase
begins = prepare_output["begins"]
ends = prepare_output["ends"]
texts = prepare_output["texts"]
document_ids = prepare_output.get("document_ids")
document_ids = prepare_output["document_ids"]
message_indices = prepare_output["message_indices"]

if not isinstance(value, int):
raise TypeError(
Expand All @@ -603,6 +653,8 @@ def _transform(self, value: Any, path: tuple, prepare_output: dict) -> dict:
result[self.text_name] = texts[sentence_num]
if self.document_id_name is not None:
result[self.document_id_name] = document_ids[sentence_num]
if self.message_index_name is not None:
result[self.message_index_name] = message_indices[sentence_num]
return result


Expand Down Expand Up @@ -779,6 +831,15 @@ def _transform(self, value, path, prepare_output):
return {self.field_name: value}


class WrapInList(InPlaceTransformation):
    """
    Transformation rule that wraps a value within a JSON structure in a
    single-element list, leaving the value itself unchanged.
    """

    # Name by which this rule is referenced in YAML configuration files
    # (resolved through the module-level NAME_TO_RULE registry).
    YAML_NAME = "wrap_in_list"

    def _transform(self, value, path, prepare_output):
        # `path` and `prepare_output` are part of the InPlaceTransformation
        # call signature but are not needed here; the value is wrapped as-is.
        return [value]


class MergeSpans(InPlaceTransformation):
"""Merge adjacent spans into larger spans."""

Expand Down Expand Up @@ -940,6 +1001,7 @@ def _transform(self, value, path, prepare_output):
Nest,
Project,
TokenToFloat,
WrapInList,
]
NAME_TO_RULE = {cls.YAML_NAME: cls for cls in ALL_RULES}

Expand Down
Loading