Skip to content

Commit

Permalink
Podcasts and Threads (#493)
Browse files Browse the repository at this point in the history
**Threads**
* Experimental support for Thread Index 
* Define a Thread by Time Range and Description
* Thread descriptions are indexed and matched during search processing
to find time ranges
* Time range used to filter matches
* E.g., for multiple episodes of a podcast in a single index

**Relevance Improvements**
**Bug fixes**
**Unit tests**
  • Loading branch information
umeshma authored Dec 13, 2024
1 parent f790c63 commit 45f03fc
Show file tree
Hide file tree
Showing 16 changed files with 485 additions and 64 deletions.
4 changes: 4 additions & 0 deletions ts/examples/chat/src/memory/chatMemory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1158,6 +1158,7 @@ export async function runChatMemory(): Promise<void> {
def.options.skipEntities = argBool("Skip entity matching", false);
def.options.skipActions = argBool("Skip action matching", false);
def.options.skipTopics = argBool("Skip topics matching", false);
def.options.threads = argBool("Use most likely thread", false);
return def;
}
commands.search.metadata = searchDef();
Expand Down Expand Up @@ -1333,6 +1334,9 @@ export async function runChatMemory(): Promise<void> {
if (namedArgs.fallback) {
searchOptions.fallbackSearch = { maxMatches: 10 };
}
if (namedArgs.threads) {
searchOptions.threadSearch = { maxMatches: 1, minScore: 0.8 };
}
if (!namedArgs.eval) {
// just translate user query into structured query without eval
const translationContext = await context.searcher.buildContext(
Expand Down
22 changes: 17 additions & 5 deletions ts/examples/chat/src/memory/chatMemoryPrinter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -297,17 +297,29 @@ export class ChatMemoryPrinter extends ChatPrinter {
}
}

/**
 * Prints the question extracted from a search response, if any.
 *
 * Nothing is written when `result` is undefined or when
 * `getSearchQuestion` returns a falsy value.
 *
 * @param result Search response to inspect; may be undefined.
 * @param debug Not read by this method.
 */
public writeSearchQuestion(
    result:
        | conversation.SearchTermsActionResponse
        | conversation.SearchTermsActionResponseV2
        | undefined,
    debug: boolean = false,
) {
    if (!result) {
        return;
    }
    const questionText = getSearchQuestion(result);
    if (!questionText) {
        return;
    }
    // Question line in bright cyan, followed by a blank line.
    this.writeInColor(chalk.cyanBright, `Question: ${questionText}`);
    this.writeLine();
}

public writeSearchTermsResult(
result:
| conversation.SearchTermsActionResponse
| conversation.SearchTermsActionResponseV2,
debug: boolean = false,
) {
const question = getSearchQuestion(result);
if (question) {
this.writeInColor(chalk.cyanBright, `Question: ${question}`);
this.writeLine();
}
this.writeSearchQuestion(result);
if (result.response && result.response.answer) {
this.writeResultStats(result.response);
if (result.response.answer.answer) {
Expand Down
4 changes: 2 additions & 2 deletions ts/examples/chat/src/memory/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,10 @@ export function argClean(defaultValue = false): ArgDef {
};
}

export function argPause(): ArgDef {
export function argPause(defaultValue = 0): ArgDef {
return {
type: "number",
defaultValue: 0,
defaultValue,
description: "Pause for given milliseconds after each iteration",
};
}
Expand Down
6 changes: 5 additions & 1 deletion ts/examples/chat/src/memory/emailMemory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ import {
import chalk from "chalk";
import { convertMsgFiles } from "./importer.js";
import fs from "fs";
import { Result, success } from "typechat";
import { error, Result, success } from "typechat";

export async function createEmailMemory(
models: Models,
Expand Down Expand Up @@ -341,7 +341,11 @@ export function createEmailCommands(
options,
previousUserInputs,
);
if (!searchResults) {
return error("No search results");
}
context.printer.writeLine();
context.printer.writeSearchQuestion(searchResults);
context.printer.writeResultStats(searchResults?.response);
context.printer.writeLine();
return success(searchResults);
Expand Down
84 changes: 80 additions & 4 deletions ts/examples/chat/src/memory/podcastMemory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import {
} from "./common.js";
import path from "path";
import {
asyncArray,
createWorkQueueFolder,
ensureDir,
getFileName,
Expand Down Expand Up @@ -74,7 +75,10 @@ export function createPodcastCommands(
): void {
commands.importPodcast = importPodcast;
commands.podcastConvert = podcastConvert;
commands.prodcastIndex = podcastIndex;
commands.podcastIndex = podcastIndex;
commands.podcastAddThread = podcastAddThread;
commands.podcastListThreads = podcastListThreads;

//-----------
// COMMANDS
//---------
Expand All @@ -83,16 +87,19 @@ export function createPodcastCommands(
description: "Import a podcast transcript.",
args: {
sourcePath: argSourceFileOrFolder(),
name: arg("Thread name"),
description: arg("Thread description"),
},
options: {
startAt: arg("Start date and time"),
length: argNum("Length of the podcast in minutes", 60),
clean: argClean(),
maxTurns: argNum("Max turns"),
pauseMs: argPause(),
pauseMs: argPause(1000),
},
};
}
commands.importPodcast.metadata = importPodcastDef();
async function importPodcast(args: string[]): Promise<void> {
const namedArgs = parseNamedArguments(args, importPodcastDef());
let sourcePath: string = namedArgs.sourcePath;
Expand All @@ -102,6 +109,7 @@ export function createPodcastCommands(
}

await podcastConvert(namedArgs);
await podcastAddThread(namedArgs);
const turnsFilePath = getTurnsFolderPath(sourcePath);
namedArgs.sourcePath = turnsFilePath;
await podcastIndex(namedArgs);
Expand Down Expand Up @@ -139,11 +147,11 @@ export function createPodcastCommands(
options: {
clean: argClean(),
maxTurns: argNum("Max turns"),
pauseMs: argPause(),
pauseMs: argPause(1000),
},
};
}
commands.importPodcast.metadata = podcastIndexDef();
commands.podcastIndex.metadata = podcastIndexDef();
async function podcastIndex(args: string[] | NamedArgs) {
const namedArgs = parseNamedArguments(args, podcastIndexDef());
let sourcePath: string = namedArgs.sourcePath;
Expand All @@ -159,6 +167,66 @@ export function createPodcastCommands(
context.printer.writeError(`${sourcePath} is not a directory`);
}
}

/**
 * Builds the command metadata for `podcastAddThread`.
 *
 * Positional args: `sourcePath`, `name`, `description`.
 * Options: `startAt`, and `length` (declared with 60 passed to argNum).
 */
function podcastAddThreadDef(): CommandMetadata {
    const threadArgs = {
        sourcePath: argSourceFileOrFolder(),
        name: arg("Thread name"),
        description: arg("Thread description"),
    };
    const threadOptions = {
        startAt: arg("Start date and time"),
        length: argNum("Length of the podcast in minutes", 60),
    };
    return {
        description: "Add a sub-thread to the podcast index",
        args: threadArgs,
        options: threadOptions,
    };
}
// Attach the definition to the registered command.
commands.podcastAddThread.metadata = podcastAddThreadDef();
/**
 * Adds a temporal thread for a transcript to the podcast conversation's
 * thread index and prints the thread that was added.
 *
 * The thread's time range is derived from the `startAt` and `length`
 * arguments; its description is the transcript overview built from the
 * transcript metadata and the turns loaded from `sourcePath`.
 *
 * @param args Raw argument strings or already-parsed named arguments.
 */
async function podcastAddThread(args: string[] | NamedArgs): Promise<void> {
    // Parse against this command's own definition (the one registered as its
    // metadata) so that the `name` and `description` args read below are
    // declared to the parser.
    const namedArgs = parseNamedArguments(args, podcastAddThreadDef());
    const sourcePath = namedArgs.sourcePath;
    const timeRange = conversation.parseTranscriptDuration(
        namedArgs.startAt,
        namedArgs.length,
    );
    if (!timeRange) {
        // Without a time range there is nothing to register.
        context.printer.writeError("Time range required");
        return;
    }
    const turns =
        await conversation.loadTurnsFromTranscriptFile(sourcePath);
    const metadata: conversation.TranscriptMetadata = {
        sourcePath,
        name: namedArgs.name,
        description: namedArgs.description,
        startAt: namedArgs.startAt,
        lengthMinutes: namedArgs.length,
    };
    // The overview text becomes the thread description stored in the index.
    const overview = conversation.createTranscriptOverview(metadata, turns);
    const threadDef: conversation.ThreadTimeRange = {
        type: "temporal",
        description: overview,
        timeRange,
    };
    const threads =
        await context.podcastMemory.conversation.getThreadIndex();
    await threads.add(threadDef);
    writeThread(threadDef);
}
commands.podcastListThreads.metadata = "List all registered threads";
/**
 * Prints every thread in the podcast conversation's thread index, each
 * preceded by its zero-based position in brackets.
 *
 * @param args Not read by this command.
 */
async function podcastListThreads(args: string[]) {
    const threadIndex =
        await context.podcastMemory.conversation.getThreadIndex();
    // Materialize the async sequence of entries before printing.
    const registered: conversation.ConversationThread[] =
        await asyncArray.toArray(threadIndex.entries());
    registered.forEach((thread, position) => {
        context.printer.writeLine(`[${position}]`);
        writeThread(thread);
    });
}

return;

//---
Expand Down Expand Up @@ -249,4 +317,12 @@ export function createPodcastCommands(
`${context.podcastMemory.conversationName}_stats.json`,
);
}

/**
 * Prints a thread: its description, the ISO-8601 start timestamp, the
 * ISO-8601 stop timestamp when the range has one, and a trailing blank line.
 *
 * @param t Thread whose description and time range are written.
 */
function writeThread(t: conversation.ConversationThread) {
    context.printer.writeLine(t.description);
    const range = conversation.toDateRange(t.timeRange);
    context.printer.writeLine(range.startDate.toISOString());
    // Guard rather than assert non-null: a range without a stop date would
    // otherwise throw a TypeError here.
    if (range.stopDate) {
        context.printer.writeLine(range.stopDate.toISOString());
    }
    context.printer.writeLine();
}
}
21 changes: 21 additions & 0 deletions ts/packages/knowledgeProcessor/src/conversation/conversation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ import {
StorageProvider,
} from "../storageProvider.js";
import { RecentItems, createRecentItemsWindow } from "../temporal.js";
import { createThreadIndexOnStorage, ThreadIndex } from "./threads.js";

export interface ConversationSettings {
indexSettings: TextIndexSettings;
Expand Down Expand Up @@ -147,6 +148,7 @@ export interface Conversation<
* Returns the index of
*/
getActionIndex(): Promise<ActionIndex<TActionId, MessageId>>;
getThreadIndex(): Promise<ThreadIndex<string>>;
/**
*
* @param removeMessages If you want the original messages also removed. Set to false if you just want to rebuild the indexes
Expand Down Expand Up @@ -279,6 +281,8 @@ export async function createConversation(
let entityIndex: EntityIndex | undefined;
const actionPath = path.join(rootPath, "actions");
let actionIndex: ActionIndex | undefined;
const threadsPath = path.join(rootPath, "threads");
let threadIndex: ThreadIndex | undefined;

const thisConversation: Conversation<string, string, string> = {
settings,
Expand All @@ -288,6 +292,7 @@ export async function createConversation(
getEntityIndex,
getTopicsIndex,
getActionIndex,
getThreadIndex,
clear,
addMessage,
addKnowledgeForMessage,
Expand Down Expand Up @@ -344,6 +349,22 @@ export async function createConversation(
return actionIndex;
}

/**
 * Returns the conversation's thread index, creating it on first use and
 * reusing the cached instance afterwards.
 */
async function getThreadIndex(): Promise<ThreadIndex> {
    if (threadIndex) {
        return threadIndex;
    }
    // Using file provider until stable
    const fileProvider = createFileSystemStorageProvider(
        rootPath,
        folderSettings,
        fSys,
    );
    threadIndex = await createThreadIndexOnStorage(threadsPath, fileProvider);
    return threadIndex;
}

async function getTopicsIndex(level?: number): Promise<TopicIndex> {
const name = topicsName(level);
let topicIndex = topics.get(name);
Expand Down
23 changes: 10 additions & 13 deletions ts/packages/knowledgeProcessor/src/conversation/entities.ts
Original file line number Diff line number Diff line change
Expand Up @@ -380,36 +380,39 @@ export async function createEntityIndexOnStorage<TSourceId = string>(

terms = terms.filter((t) => !noiseTerms.has(t));
if (terms && terms.length > 0) {
const hitCounter = createHitTable<EntityId>();
const entityIdHitTable = createHitTable<EntityId>();
const scoreBoost = terms.length;
await Promise.all([
nameIndex.getNearestHitsMultiple(
terms,
hitCounter,
entityIdHitTable,
options.nameSearchOptions?.maxMatches ?? options.maxMatches,
options.nameSearchOptions?.minScore ?? options.minScore,
scoreBoost,
nameAliases,
),
typeIndex.getNearestHitsMultiple(
terms,
hitCounter,
entityIdHitTable,
options.maxMatches,
options.minScore,
scoreBoost,
),
facetIndex.getNearestHitsMultiple(
terms,
hitCounter,
entityIdHitTable,
options.facetSearchOptions?.maxMatches,
options.facetSearchOptions?.minScore ?? options.minScore,
),
]);
let entityHits = hitCounter.getTopK(determineTopK(options)).sort();
entityIdHitTable.roundScores(2);
let entityIdHits = entityIdHitTable
.getTopK(determineTopK(options))
.sort();

results.entityIds = [
...intersectMultiple(
entityHits,
entityIdHits,
itemsFromTemporalSequence(results.temporalSequence),
),
];
Expand Down Expand Up @@ -472,13 +475,7 @@ export async function createEntityIndexOnStorage<TSourceId = string>(
}

function determineTopK(options: EntitySearchOptions): number {
const topK =
options.topK ??
Math.max(
options.maxMatches,
options.nameSearchOptions?.maxMatches ?? 0,
//options.facetSearchOptions?.maxMatches ?? 0,
);
const topK = options.topK;
return topK === undefined || topK < 3 ? 3 : topK;
}
}
Expand Down
1 change: 1 addition & 0 deletions ts/packages/knowledgeProcessor/src/conversation/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ export * from "./topics.js";
export * from "./topicSchema.js";
export * from "./transcript.js";
export * from "./actions.js";
export * from "./threads.js";

export * from "./searchResponse.js";
export * from "./searchProcessor.js";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import {
import { createTypeScriptJsonValidator } from "typechat/ts";
import { SearchAction } from "./knowledgeSearchSchema.js";
import { dateTime, loadSchema } from "typeagent";
import { DateTime, DateTimeRange } from "./dateTimeSchema.js";
import { DateTime, DateTimeRange, DateVal, TimeVal } from "./dateTimeSchema.js";
import { SearchTermsAction } from "./knowledgeTermSearchSchema.js";
import { SearchTermsActionV2 } from "./knowledgeTermSearchSchema2.js";

Expand Down Expand Up @@ -196,3 +196,20 @@ export function dateTimeToDate(dateTime: DateTime): Date {
}
return dt;
}

/**
 * Converts a JavaScript `Date` into a `DateTime` value.
 *
 * Components are read with the local-time getters. `month` is 1-based
 * (`getMonth()` is zero-based, so 1 is added).
 *
 * @param dt Date to convert.
 * @returns A `DateTime` with `date` (day, month, year) and `time`
 *          (hour, minute, seconds) populated from `dt`.
 */
export function dateToDateTime(dt: Date): DateTime {
    return {
        date: {
            day: dt.getDate(),
            month: dt.getMonth() + 1,
            year: dt.getFullYear(),
        },
        time: {
            hour: dt.getHours(),
            minute: dt.getMinutes(),
            seconds: dt.getSeconds(),
        },
    };
}
Loading

0 comments on commit 45f03fc

Please sign in to comment.