diff --git a/swift/Sources/CoreAIShared/Bundle/ModelBundle.swift b/swift/Sources/CoreAIShared/Bundle/ModelBundle.swift index 155a786..e25b9a8 100644 --- a/swift/Sources/CoreAIShared/Bundle/ModelBundle.swift +++ b/swift/Sources/CoreAIShared/Bundle/ModelBundle.swift @@ -82,9 +82,14 @@ public struct ModelBundle: Sendable { case kindMismatch(expected: BundleKind, got: BundleKind) case missingField(String) case missingAsset(key: String, path: URL) + case pointedAtModelAsset(URL) public var description: String { switch self { + case .pointedAtModelAsset(let url): + return "'\(url.lastPathComponent)' is a model asset, not a model bundle " + + "directory. A model bundle directory contains metadata, a tokenizer, " + + "and a model asset." case .missingMetadata(let url): return "metadata.json not found at \(url.path)" case .malformedMetadata(let url, let err): @@ -114,6 +119,16 @@ public struct ModelBundle: Sendable { } public init(at url: URL) throws { + // A model bundle is a *directory* (metadata.json + assets + tokenizer). + // If the caller points us directly at a `.aimodel`/`.aimodelc` asset, + // fail with actionable guidance. This must run before any filesystem + // read: a compiled `.aimodelc` is itself a directory holding its own + // unrelated metadata.json, which would otherwise parse as a bogus 0.1 + // bundle and surface a misleading "unsupported metadata_version" error. + let ext = url.pathExtension.lowercased() + if ext == "aimodel" || ext == "aimodelc" { + throw BundleError.pointedAtModelAsset(url) + } let metadataURL = url.appending(path: "metadata.json") guard FileManager.default.fileExists(atPath: metadataURL.path) else { throw BundleError.missingMetadata(metadataURL) diff --git a/swift/Sources/Tools/benchmark/BenchmarkMain.swift b/swift/Sources/Tools/benchmark/BenchmarkMain.swift index 47726f0..1e855e7 100644 --- a/swift/Sources/Tools/benchmark/BenchmarkMain.swift +++ b/swift/Sources/Tools/benchmark/BenchmarkMain.swift @@ -22,7 +22,7 @@ struct LLMBenchmark: AsyncParsableCommand { abstract: "LLM inference benchmark for CoreAI models" ) - @Option(name: .customLong("model"), help: "Path to model bundle directory") + @Option(name: .customLong("model"), help: "Path to a model bundle directory") var model: String @Option(name: [.customShort("p"), .customLong("prompt-tokens")], help: "Length of prompt") diff --git a/swift/Tests/CoreAISharedTests/ModelBundleTests.swift b/swift/Tests/CoreAISharedTests/ModelBundleTests.swift index 9a05013..793855d 100644 --- a/swift/Tests/CoreAISharedTests/ModelBundleTests.swift +++ b/swift/Tests/CoreAISharedTests/ModelBundleTests.swift @@ -78,4 +78,40 @@ struct ModelBundleTests { _ = try ModelBundle(at: dir) } } + + @Test("Pointing at a .aimodelc asset throws pointedAtModelAsset, not a parse error") + func pointedAtCompiledAssetThrows() throws { + // A compiled `.aimodelc` is a directory with its own unrelated + // metadata.json. Pointing the tool at it must fail fast with guidance, + // not parse that inner metadata as a bogus 0.1 bundle. + let bundleDir = FileManager.default.temporaryDirectory.appending( + path: "ModelBundleTests-\(UUID().uuidString)" + ) + let asset = bundleDir.appending(path: "model.aimodelc") + try FileManager.default.createDirectory(at: asset, withIntermediateDirectories: true) + try """ + { "producer": "coreai-build", "assetVersion": "2.0" } + """.write( + to: asset.appending(path: "metadata.json"), atomically: true, encoding: .utf8) + + let error = #expect(throws: ModelBundle.BundleError.self) { + _ = try ModelBundle(at: asset) + } + guard case .pointedAtModelAsset = error else { + Issue.record("expected pointedAtModelAsset, got \(String(describing: error))") + return + } + #expect(String(describing: error).contains("model.aimodelc")) + } + + @Test("Pointing at a .aimodel asset throws pointedAtModelAsset") + func pointedAtUncompiledAssetThrows() throws { + let error = #expect(throws: ModelBundle.BundleError.self) { + _ = try ModelBundle(from: "/some/where/model.aimodel") + } + guard case .pointedAtModelAsset = error else { + Issue.record("expected pointedAtModelAsset, got \(String(describing: error))") + return + } + } }