Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 41 additions & 6 deletions Runtime/LLMBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace LLMUnity
public class LLMBuilder : AssetPostprocessor
{
static List<StringPair> movedPairs = new List<StringPair>();
public static string BuildTempDir = Path.Combine(Application.temporaryCachePath, "LLMUnityBuild");
public static string BuildTempDir = Path.Combine(Directory.GetParent(Application.dataPath).FullName, "LLMUnityBuild");
static string movedCache = Path.Combine(BuildTempDir, "moved.json");

[InitializeOnLoadMethod]
Expand All @@ -36,6 +36,28 @@ public static string PluginLibraryDir(string platform, bool relative = false)
return Path.Combine(PluginDir(platform, relative), LLMUnitySetup.libraryName);
}

public static void Retry(System.Action action, int retries = 10, int delayMs = 100)
{
for (int i = 0; i < retries; i++)
{
try
{
action();
return;
}
catch (IOException)
{
if (i == retries - 1) throw;
System.Threading.Thread.Sleep(delayMs);
}
catch (System.UnauthorizedAccessException)
{
if (i == retries - 1) throw;
System.Threading.Thread.Sleep(delayMs);
}
}
}

/// <summary>
/// Performs an action for a file or a directory recursively
/// </summary>
Expand All @@ -46,7 +68,7 @@ public static void HandleActionFileRecursive(string source, string target, Actio
{
if (File.Exists(source))
{
actionCallback(source, target);
Retry(() => actionCallback(source, target));
}
else if (Directory.Exists(source))
{
Expand Down Expand Up @@ -106,8 +128,8 @@ public static bool DeletePath(string path)
LLMUnitySetup.LogError($"Safeguard: {path} will not be deleted because it may not be safe");
return false;
}
if (File.Exists(path)) File.Delete(path);
else if (Directory.Exists(path)) Directory.Delete(path, true);
if (File.Exists(path)) Retry(() => File.Delete(path));
else if (Directory.Exists(path)) Retry(() => Directory.Delete(path, true));
return true;
}

Expand Down Expand Up @@ -266,8 +288,21 @@ public static void Reset()
bool refresh = false;
foreach (var pair in movedPairs)
{
if (pair.source == "") refresh |= DeletePath(pair.target);
else refresh |= MoveAction(pair.target, pair.source, false);
if (pair.source == "")
{
refresh |= DeletePath(pair.target);
}
else
{
if (File.Exists(pair.source) || Directory.Exists(pair.source))
{
refresh |= DeletePath(pair.target);
}
else
{
refresh |= MoveAction(pair.target, pair.source, false);
}
}
}
if (refresh) AssetDatabase.Refresh();
DeletePath(movedCache);
Expand Down
12 changes: 10 additions & 2 deletions Runtime/LLMCaller.cs
Original file line number Diff line number Diff line change
Expand Up @@ -209,10 +209,18 @@ protected virtual Ret ConvertContent<Res, Ret>(string response, ContentCallback<
}
response = $"{{\"data\": [{responseArray}]}}";
}
return getContent(JsonUtility.FromJson<Res>(response));
try
{
return getContent(JsonUtility.FromJson<Res>(response));
}
catch (Exception e)
{
LLMUnitySetup.LogError($"Error converting response: {e.Message}\nResponse: {response}");
return default;
}
}

protected virtual void CancelRequestsLocal() {}
protected virtual void CancelRequestsLocal() { }

protected virtual void CancelRequestsRemote()
{
Expand Down
37 changes: 29 additions & 8 deletions Runtime/LLMCharacter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ public class LLMCharacter : LLMCaller
protected SemaphoreSlim chatLock = new SemaphoreSlim(1, 1);
protected string chatTemplate;
protected ChatTemplate template = null;
protected Task grammarTask;

/// \endcond

/// <summary>
Expand All @@ -157,7 +159,7 @@ public override void Awake()
int slotFromServer = llm.Register(this);
if (slot == -1) slot = slotFromServer;
}
InitGrammar();
grammarTask = InitGrammar();
InitHistory();
}

Expand Down Expand Up @@ -273,19 +275,33 @@ protected virtual async Task<bool> InitNKeep()
return true;
}

protected virtual void InitGrammar()
protected virtual async Task InitGrammar()
{
grammarString = "";
grammarJSONString = "";
if (!String.IsNullOrEmpty(grammar))
{
grammarString = File.ReadAllText(LLMUnitySetup.GetAssetPath(grammar));
if (!String.IsNullOrEmpty(grammarJSON))
LLMUnitySetup.LogWarning("Both GBNF and JSON grammars are set, only the GBNF will be used");
await LLMUnitySetup.AndroidExtractAsset(grammar, true);
string path = LLMUnitySetup.GetAssetPath(grammar);
if (File.Exists(path))
{
grammarString = File.ReadAllText(path);
if (!String.IsNullOrEmpty(grammarJSON))
LLMUnitySetup.LogWarning("Both GBNF and JSON grammars are set, only the GBNF will be used");
}
else
{
LLMUnitySetup.LogError($"Grammar file {path} not found!");
}
}
else if (!String.IsNullOrEmpty(grammarJSON))
{
grammarJSONString = File.ReadAllText(LLMUnitySetup.GetAssetPath(grammarJSON));
await LLMUnitySetup.AndroidExtractAsset(grammarJSON, true);
string path = LLMUnitySetup.GetAssetPath(grammarJSON);
if (File.Exists(path))
grammarJSONString = File.ReadAllText(path);
else
LLMUnitySetup.LogError($"Grammar file {path} not found!");
}
}

Expand Down Expand Up @@ -327,10 +343,10 @@ public virtual async Task SetGrammarFile(string path, bool gnbf)
#if UNITY_EDITOR
if (!EditorApplication.isPlaying) path = LLMUnitySetup.AddAsset(path);
#endif
await LLMUnitySetup.AndroidExtractAsset(path, true);
if (gnbf) grammar = path;
else grammarJSON = path;
InitGrammar();
grammarTask = InitGrammar();
await grammarTask;
}

/// <summary>
Expand Down Expand Up @@ -524,6 +540,7 @@ public virtual async Task<string> Chat(string query, Callback<string> callback =
await LoadTemplate();
if (!CheckTemplate()) return null;
if (!await InitNKeep()) return null;
if (grammarTask != null) await grammarTask;

ChatRequest request = await PromptWithQuery(query);
string result = await CompletionRequest(request, callback);
Expand Down Expand Up @@ -562,6 +579,7 @@ public virtual async Task<string> Complete(string prompt, Callback<string> callb
// call the callback function while the answer is received
// call the completionCallback function when the answer is fully received
await LoadTemplate();
if (grammarTask != null) await grammarTask;

ChatRequest request = GenerateRequest(prompt);
string result = await CompletionRequest(request, callback);
Expand Down Expand Up @@ -595,6 +613,7 @@ public virtual async Task Warmup(string query, EmptyCallback completionCallback
await LoadTemplate();
if (!CheckTemplate()) return;
if (!await InitNKeep()) return;
if (grammarTask != null) await grammarTask;

ChatRequest request;
if (String.IsNullOrEmpty(query))
Expand All @@ -608,6 +627,8 @@ public virtual async Task Warmup(string query, EmptyCallback completionCallback
}

request.n_predict = 0;
request.grammar = null;
request.json_schema = null;
await CompletionRequest(request);
completionCallback?.Invoke();
}
Expand Down
2 changes: 1 addition & 1 deletion Runtime/LLMManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -643,7 +643,7 @@ public static void SaveToDisk()
List<ModelEntry> modelEntriesBuild = new List<ModelEntry>();
foreach (ModelEntry modelEntry in modelEntries)
{
if (!modelEntry.includeInBuild) continue;
if (!modelEntry.includeInBuild && string.IsNullOrEmpty(modelEntry.url)) continue;
modelEntriesBuild.Add(modelEntry.OnlyRequiredFields());
}
string json = JsonUtility.ToJson(new LLMManagerStore
Expand Down