Skip to content

Commit d1f7c7e

Browse files
authored
[Vertex AI] Fix unsupported model name check introduced in #14610 (#14629)
1 parent b20d812 commit d1f7c7e

File tree

5 files changed

+38
-37
lines changed

5 files changed

+38
-37
lines changed

FirebaseVertexAI/Sources/GenerativeModel.swift

+5-11
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ public final class GenerativeModel: Sendable {
5353
/// Initializes a new remote model with the given parameters.
5454
///
5555
/// - Parameters:
56-
/// - name: The name of the model to use, for example `"gemini-1.0-pro"`.
56+
/// - modelResourceName: The resource name of the model to use, for example
57+
/// `"projects/{project-id}/locations/{location-id}/publishers/google/models/{model-name}"`.
5758
/// - firebaseInfo: Firebase data used by the SDK, including project ID and API key.
5859
/// - apiConfig: Configuration for the backend API used by this model.
5960
/// - generationConfig: The content generation parameters your model should use.
@@ -64,7 +65,7 @@ public final class GenerativeModel: Sendable {
6465
/// only text content is supported.
6566
/// - requestOptions: Configuration parameters for sending requests to the backend.
6667
/// - urlSession: The `URLSession` to use for requests; defaults to `URLSession.shared`.
67-
init(name: String,
68+
init(modelResourceName: String,
6869
firebaseInfo: FirebaseInfo,
6970
apiConfig: APIConfig,
7071
generationConfig: GenerationConfig? = nil,
@@ -74,14 +75,7 @@ public final class GenerativeModel: Sendable {
7475
systemInstruction: ModelContent? = nil,
7576
requestOptions: RequestOptions,
7677
urlSession: URLSession = .shared) {
77-
if !name.starts(with: GenerativeModel.geminiModelNamePrefix) {
78-
VertexLog.warning(code: .unsupportedGeminiModel, """
79-
Unsupported Gemini model "\(name)"; see \
80-
https://firebase.google.com/docs/vertex-ai/models for a list supported Gemini model names.
81-
""")
82-
}
83-
84-
modelResourceName = name
78+
self.modelResourceName = modelResourceName
8579
self.apiConfig = apiConfig
8680
generativeAIService = GenerativeAIService(
8781
firebaseInfo: firebaseInfo,
@@ -108,7 +102,7 @@ public final class GenerativeModel: Sendable {
108102
`\(VertexLog.enableArgumentKey)` as a launch argument in Xcode.
109103
""")
110104
}
111-
VertexLog.debug(code: .generativeModelInitialized, "Model \(name) initialized.")
105+
VertexLog.debug(code: .generativeModelInitialized, "Model \(modelResourceName) initialized.")
112106
}
113107

114108
/// Generates content from String and/or image inputs, given to the model as a prompt, that are

FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModel.swift

+2-9
Original file line numberDiff line numberDiff line change
@@ -47,21 +47,14 @@ public final class ImagenModel {
4747
/// Configuration parameters for sending requests to the backend.
4848
let requestOptions: RequestOptions
4949

50-
init(name: String,
50+
init(modelResourceName: String,
5151
firebaseInfo: FirebaseInfo,
5252
apiConfig: APIConfig,
5353
generationConfig: ImagenGenerationConfig?,
5454
safetySettings: ImagenSafetySettings?,
5555
requestOptions: RequestOptions,
5656
urlSession: URLSession = .shared) {
57-
if !name.starts(with: ImagenModel.imagenModelNamePrefix) {
58-
VertexLog.warning(code: .unsupportedImagenModel, """
59-
Unsupported Imagen model "\(name)"; see \
60-
https://firebase.google.com/docs/vertex-ai/models for a list supported Imagen model names.
61-
""")
62-
}
63-
64-
modelResourceName = name
57+
self.modelResourceName = modelResourceName
6558
self.apiConfig = apiConfig
6659
generativeAIService = GenerativeAIService(
6760
firebaseInfo: firebaseInfo,

FirebaseVertexAI/Sources/VertexAI.swift

+16-2
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,15 @@ public class VertexAI {
7070
systemInstruction: ModelContent? = nil,
7171
requestOptions: RequestOptions = RequestOptions())
7272
-> GenerativeModel {
73+
if !modelName.starts(with: GenerativeModel.geminiModelNamePrefix) {
74+
VertexLog.warning(code: .unsupportedGeminiModel, """
75+
Unsupported Gemini model "\(modelName)"; see \
76+
https://firebase.google.com/docs/vertex-ai/models for a list supported Gemini model names.
77+
""")
78+
}
79+
7380
return GenerativeModel(
74-
name: modelResourceName(modelName: modelName),
81+
modelResourceName: modelResourceName(modelName: modelName),
7582
firebaseInfo: firebaseInfo,
7683
apiConfig: apiConfig,
7784
generationConfig: generationConfig,
@@ -102,8 +109,15 @@ public class VertexAI {
102109
public func imagenModel(modelName: String, generationConfig: ImagenGenerationConfig? = nil,
103110
safetySettings: ImagenSafetySettings? = nil,
104111
requestOptions: RequestOptions = RequestOptions()) -> ImagenModel {
112+
if !modelName.starts(with: ImagenModel.imagenModelNamePrefix) {
113+
VertexLog.warning(code: .unsupportedImagenModel, """
114+
Unsupported Imagen model "\(modelName)"; see \
115+
https://firebase.google.com/docs/vertex-ai/models for a list supported Imagen model names.
116+
""")
117+
}
118+
105119
return ImagenModel(
106-
name: modelResourceName(modelName: modelName),
120+
modelResourceName: modelResourceName(modelName: modelName),
107121
firebaseInfo: firebaseInfo,
108122
apiConfig: apiConfig,
109123
generationConfig: generationConfig,

FirebaseVertexAI/Tests/Unit/ChatTests.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ final class ChatTests: XCTestCase {
5959
options: FirebaseOptions(googleAppID: "ignore",
6060
gcmSenderID: "ignore"))
6161
let model = GenerativeModel(
62-
name: "my-model",
62+
modelResourceName: "my-model",
6363
firebaseInfo: FirebaseInfo(
6464
projectID: "my-project-id",
6565
apiKey: "API_KEY",

FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift

+14-14
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ final class GenerativeModelTests: XCTestCase {
7070
configuration.protocolClasses = [MockURLProtocol.self]
7171
urlSession = try XCTUnwrap(URLSession(configuration: configuration))
7272
model = GenerativeModel(
73-
name: testModelResourceName,
73+
modelResourceName: testModelResourceName,
7474
firebaseInfo: testFirebaseInfo(),
7575
apiConfig: apiConfig,
7676
tools: nil,
@@ -276,7 +276,7 @@ final class GenerativeModelTests: XCTestCase {
276276
)
277277
let model = GenerativeModel(
278278
// Model name is prefixed with "models/".
279-
name: "models/test-model",
279+
modelResourceName: "models/test-model",
280280
firebaseInfo: testFirebaseInfo(),
281281
apiConfig: apiConfig,
282282
tools: nil,
@@ -399,7 +399,7 @@ final class GenerativeModelTests: XCTestCase {
399399
func testGenerateContent_appCheck_validToken() async throws {
400400
let appCheckToken = "test-valid-token"
401401
model = GenerativeModel(
402-
name: testModelResourceName,
402+
modelResourceName: testModelResourceName,
403403
firebaseInfo: testFirebaseInfo(appCheck: AppCheckInteropFake(token: appCheckToken)),
404404
apiConfig: apiConfig,
405405
tools: nil,
@@ -420,7 +420,7 @@ final class GenerativeModelTests: XCTestCase {
420420
func testGenerateContent_dataCollectionOff() async throws {
421421
let appCheckToken = "test-valid-token"
422422
model = GenerativeModel(
423-
name: testModelResourceName,
423+
modelResourceName: testModelResourceName,
424424
firebaseInfo: testFirebaseInfo(appCheck: AppCheckInteropFake(token: appCheckToken),
425425
privateAppID: true),
426426
apiConfig: apiConfig,
@@ -442,7 +442,7 @@ final class GenerativeModelTests: XCTestCase {
442442

443443
func testGenerateContent_appCheck_tokenRefreshError() async throws {
444444
model = GenerativeModel(
445-
name: testModelResourceName,
445+
modelResourceName: testModelResourceName,
446446
firebaseInfo: testFirebaseInfo(appCheck: AppCheckInteropFake(error: AppCheckErrorFake())),
447447
apiConfig: apiConfig,
448448
tools: nil,
@@ -463,7 +463,7 @@ final class GenerativeModelTests: XCTestCase {
463463
func testGenerateContent_auth_validAuthToken() async throws {
464464
let authToken = "test-valid-token"
465465
model = GenerativeModel(
466-
name: testModelResourceName,
466+
modelResourceName: testModelResourceName,
467467
firebaseInfo: testFirebaseInfo(auth: AuthInteropFake(token: authToken)),
468468
apiConfig: apiConfig,
469469
tools: nil,
@@ -483,7 +483,7 @@ final class GenerativeModelTests: XCTestCase {
483483

484484
func testGenerateContent_auth_nilAuthToken() async throws {
485485
model = GenerativeModel(
486-
name: testModelResourceName,
486+
modelResourceName: testModelResourceName,
487487
firebaseInfo: testFirebaseInfo(auth: AuthInteropFake(token: nil)),
488488
apiConfig: apiConfig,
489489
tools: nil,
@@ -503,7 +503,7 @@ final class GenerativeModelTests: XCTestCase {
503503

504504
func testGenerateContent_auth_authTokenRefreshError() async throws {
505505
model = GenerativeModel(
506-
name: "my-model",
506+
modelResourceName: "my-model",
507507
firebaseInfo: testFirebaseInfo(auth: AuthInteropFake(error: AuthErrorFake())),
508508
apiConfig: apiConfig,
509509
tools: nil,
@@ -900,7 +900,7 @@ final class GenerativeModelTests: XCTestCase {
900900
)
901901
let requestOptions = RequestOptions(timeout: expectedTimeout)
902902
model = GenerativeModel(
903-
name: testModelResourceName,
903+
modelResourceName: testModelResourceName,
904904
firebaseInfo: testFirebaseInfo(),
905905
apiConfig: apiConfig,
906906
tools: nil,
@@ -1204,7 +1204,7 @@ final class GenerativeModelTests: XCTestCase {
12041204
func testGenerateContentStream_appCheck_validToken() async throws {
12051205
let appCheckToken = "test-valid-token"
12061206
model = GenerativeModel(
1207-
name: testModelResourceName,
1207+
modelResourceName: testModelResourceName,
12081208
firebaseInfo: testFirebaseInfo(appCheck: AppCheckInteropFake(token: appCheckToken)),
12091209
apiConfig: apiConfig,
12101210
tools: nil,
@@ -1225,7 +1225,7 @@ final class GenerativeModelTests: XCTestCase {
12251225

12261226
func testGenerateContentStream_appCheck_tokenRefreshError() async throws {
12271227
model = GenerativeModel(
1228-
name: testModelResourceName,
1228+
modelResourceName: testModelResourceName,
12291229
firebaseInfo: testFirebaseInfo(appCheck: AppCheckInteropFake(error: AppCheckErrorFake())),
12301230
apiConfig: apiConfig,
12311231
tools: nil,
@@ -1375,7 +1375,7 @@ final class GenerativeModelTests: XCTestCase {
13751375
)
13761376
let requestOptions = RequestOptions(timeout: expectedTimeout)
13771377
model = GenerativeModel(
1378-
name: testModelResourceName,
1378+
modelResourceName: testModelResourceName,
13791379
firebaseInfo: testFirebaseInfo(),
13801380
apiConfig: apiConfig,
13811381
tools: nil,
@@ -1451,7 +1451,7 @@ final class GenerativeModelTests: XCTestCase {
14511451
parts: "You are a calculator. Use the provided tools."
14521452
)
14531453
model = GenerativeModel(
1454-
name: testModelResourceName,
1454+
modelResourceName: testModelResourceName,
14551455
firebaseInfo: testFirebaseInfo(),
14561456
apiConfig: apiConfig,
14571457
generationConfig: generationConfig,
@@ -1511,7 +1511,7 @@ final class GenerativeModelTests: XCTestCase {
15111511
)
15121512
let requestOptions = RequestOptions(timeout: expectedTimeout)
15131513
model = GenerativeModel(
1514-
name: testModelResourceName,
1514+
modelResourceName: testModelResourceName,
15151515
firebaseInfo: testFirebaseInfo(),
15161516
apiConfig: apiConfig,
15171517
tools: nil,

0 commit comments

Comments
 (0)