Skip to content

Commit a4727f0

Browse files
davidmotsonDavid Motsonashvili
authored andcommitted
Add ML monitoring info to GenAI requests (#6804)
Co-authored-by: David Motsonashvili <[email protected]>
1 parent 307ef13 commit a4727f0

File tree

9 files changed

+151
-1
lines changed

9 files changed

+151
-1
lines changed

firebase-vertexai/firebase-vertexai.gradle.kts

+1
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ dependencies {
115115
testImplementation(libs.kotlin.coroutines.test)
116116
testImplementation(libs.robolectric)
117117
testImplementation(libs.truth)
118+
testImplementation(libs.mockito.core)
118119

119120
androidTestImplementation(libs.androidx.espresso.core)
120121
androidTestImplementation(libs.androidx.test.junit)

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/FirebaseVertexAI.kt

+2
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ internal constructor(
7171
return GenerativeModel(
7272
"projects/${firebaseApp.options.projectId}/locations/${location}/publishers/google/models/${modelName}",
7373
firebaseApp.options.apiKey,
74+
firebaseApp,
7475
generationConfig,
7576
safetySettings,
7677
tools,
@@ -105,6 +106,7 @@ internal constructor(
105106
return ImagenModel(
106107
"projects/${firebaseApp.options.projectId}/locations/${location}/publishers/google/models/${modelName}",
107108
firebaseApp.options.apiKey,
109+
firebaseApp,
108110
generationConfig,
109111
safetySettings,
110112
requestOptions,

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/GenerativeModel.kt

+3
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package com.google.firebase.vertexai
1818

1919
import android.graphics.Bitmap
20+
import com.google.firebase.FirebaseApp
2021
import com.google.firebase.appcheck.interop.InteropAppCheckTokenProvider
2122
import com.google.firebase.auth.internal.InternalAuthProvider
2223
import com.google.firebase.vertexai.common.APIController
@@ -59,6 +60,7 @@ internal constructor(
5960
internal constructor(
6061
modelName: String,
6162
apiKey: String,
63+
firebaseApp: FirebaseApp,
6264
generationConfig: GenerationConfig? = null,
6365
safetySettings: List<SafetySetting>? = null,
6466
tools: List<Tool>? = null,
@@ -79,6 +81,7 @@ internal constructor(
7981
modelName,
8082
requestOptions,
8183
"gl-kotlin/${KotlinVersion.CURRENT} fire/${BuildConfig.VERSION_NAME}",
84+
firebaseApp,
8285
AppCheckHeaderProvider(TAG, appCheckTokenProvider, internalAuthProvider),
8386
),
8487
)

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/ImagenModel.kt

+3
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package com.google.firebase.vertexai
1818

19+
import com.google.firebase.FirebaseApp
1920
import com.google.firebase.appcheck.interop.InteropAppCheckTokenProvider
2021
import com.google.firebase.auth.internal.InternalAuthProvider
2122
import com.google.firebase.vertexai.common.APIController
@@ -46,6 +47,7 @@ internal constructor(
4647
internal constructor(
4748
modelName: String,
4849
apiKey: String,
50+
firebaseApp: FirebaseApp,
4951
generationConfig: ImagenGenerationConfig? = null,
5052
safetySettings: ImagenSafetySettings? = null,
5153
requestOptions: RequestOptions = RequestOptions(),
@@ -60,6 +62,7 @@ internal constructor(
6062
modelName,
6163
requestOptions,
6264
"gl-kotlin/${KotlinVersion.CURRENT} fire/${BuildConfig.VERSION_NAME}",
65+
firebaseApp,
6366
AppCheckHeaderProvider(TAG, appCheckTokenProvider, internalAuthProvider),
6467
),
6568
)

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/APIController.kt

+30-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package com.google.firebase.vertexai.common
1818

1919
import android.util.Log
2020
import com.google.firebase.Firebase
21+
import com.google.firebase.FirebaseApp
2122
import com.google.firebase.options
2223
import com.google.firebase.vertexai.common.util.decodeToFlow
2324
import com.google.firebase.vertexai.common.util.fullModelName
@@ -91,6 +92,9 @@ internal constructor(
9192
private val requestOptions: RequestOptions,
9293
httpEngine: HttpClientEngine,
9394
private val apiClient: String,
95+
private val firebaseApp: FirebaseApp,
96+
private val appVersion: Int = 0,
97+
private val googleAppId: String,
9498
private val headerProvider: HeaderProvider?,
9599
) {
96100

@@ -99,8 +103,19 @@ internal constructor(
99103
model: String,
100104
requestOptions: RequestOptions,
101105
apiClient: String,
106+
firebaseApp: FirebaseApp,
102107
headerProvider: HeaderProvider? = null,
103-
) : this(key, model, requestOptions, OkHttp.create(), apiClient, headerProvider)
108+
) : this(
109+
key,
110+
model,
111+
requestOptions,
112+
OkHttp.create(),
113+
apiClient,
114+
firebaseApp,
115+
getVersionNumber(firebaseApp),
116+
firebaseApp.options.applicationId,
117+
headerProvider
118+
)
104119

105120
private val model = fullModelName(model)
106121

@@ -175,6 +190,10 @@ internal constructor(
175190
contentType(ContentType.Application.Json)
176191
header("x-goog-api-key", key)
177192
header("x-goog-api-client", apiClient)
193+
if (firebaseApp.isDataCollectionDefaultEnabled) {
194+
header("X-Firebase-AppId", googleAppId)
195+
header("X-Firebase-AppVersion", appVersion)
196+
}
178197
}
179198

180199
private suspend fun HttpRequestBuilder.applyHeaderProvider() {
@@ -240,6 +259,16 @@ internal constructor(
240259

241260
companion object {
242261
private val TAG = APIController::class.java.simpleName
262+
263+
private fun getVersionNumber(app: FirebaseApp): Int {
264+
try {
265+
val context = app.applicationContext
266+
return context.packageManager.getPackageInfo(context.packageName, 0).versionCode
267+
} catch (e: Exception) {
268+
Log.d(TAG, "Error while getting app version: ${e.message}")
269+
return 0
270+
}
271+
}
243272
}
244273
}
245274

firebase-vertexai/src/test/java/com/google/firebase/vertexai/GenerativeModelTesting.kt

+18
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package com.google.firebase.vertexai
1818

19+
import com.google.firebase.FirebaseApp
1920
import com.google.firebase.vertexai.common.APIController
2021
import com.google.firebase.vertexai.common.JSON
2122
import com.google.firebase.vertexai.common.util.doBlocking
@@ -42,10 +43,21 @@ import kotlin.time.Duration.Companion.seconds
4243
import kotlinx.coroutines.withTimeout
4344
import kotlinx.serialization.ExperimentalSerializationApi
4445
import kotlinx.serialization.encodeToString
46+
import org.junit.Before
4547
import org.junit.Test
48+
import org.mockito.Mockito
4649

4750
internal class GenerativeModelTesting {
4851
private val TEST_CLIENT_ID = "test"
52+
private val TEST_APP_ID = "1:android:12345"
53+
private val TEST_VERSION = 1
54+
55+
private var mockFirebaseApp: FirebaseApp = Mockito.mock<FirebaseApp>()
56+
57+
@Before
58+
fun setup() {
59+
Mockito.`when`(mockFirebaseApp.isDataCollectionDefaultEnabled).thenReturn(false)
60+
}
4961

5062
@Test
5163
fun `system calling in request`() = doBlocking {
@@ -64,6 +76,9 @@ internal class GenerativeModelTesting {
6476
RequestOptions(timeout = 5.seconds, endpoint = "https://my.custom.endpoint"),
6577
mockEngine,
6678
TEST_CLIENT_ID,
79+
mockFirebaseApp,
80+
TEST_VERSION,
81+
TEST_APP_ID,
6782
null,
6883
)
6984

@@ -109,6 +124,9 @@ internal class GenerativeModelTesting {
109124
RequestOptions(),
110125
mockEngine,
111126
TEST_CLIENT_ID,
127+
mockFirebaseApp,
128+
TEST_VERSION,
129+
TEST_APP_ID,
112130
null,
113131
)
114132

firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/APIControllerTests.kt

+74
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package com.google.firebase.vertexai.common
1818

19+
import com.google.firebase.FirebaseApp
1920
import com.google.firebase.vertexai.BuildConfig
2021
import com.google.firebase.vertexai.common.util.commonTest
2122
import com.google.firebase.vertexai.common.util.createResponses
@@ -49,12 +50,18 @@ import kotlinx.coroutines.withTimeout
4950
import kotlinx.serialization.ExperimentalSerializationApi
5051
import kotlinx.serialization.encodeToString
5152
import kotlinx.serialization.json.JsonObject
53+
import org.junit.Before
5254
import org.junit.Test
5355
import org.junit.runner.RunWith
5456
import org.junit.runners.Parameterized
57+
import org.mockito.Mockito
5558

5659
private val TEST_CLIENT_ID = "genai-android/test"
5760

61+
private val TEST_APP_ID = "1:android:12345"
62+
63+
private val TEST_VERSION = 1
64+
5865
internal class APIControllerTests {
5966
private val testTimeout = 5.seconds
6067

@@ -87,6 +94,14 @@ internal class APIControllerTests {
8794

8895
@OptIn(ExperimentalSerializationApi::class)
8996
internal class RequestFormatTests {
97+
98+
private val mockFirebaseApp = Mockito.mock<FirebaseApp>()
99+
100+
@Before
101+
fun setup() {
102+
Mockito.`when`(mockFirebaseApp.isDataCollectionDefaultEnabled).thenReturn(false)
103+
}
104+
90105
@Test
91106
fun `using default endpoint`() = doBlocking {
92107
val channel = ByteChannel(autoFlush = true)
@@ -101,6 +116,9 @@ internal class RequestFormatTests {
101116
RequestOptions(),
102117
mockEngine,
103118
"genai-android/${BuildConfig.VERSION_NAME}",
119+
mockFirebaseApp,
120+
TEST_VERSION,
121+
TEST_APP_ID,
104122
null,
105123
)
106124

@@ -128,6 +146,9 @@ internal class RequestFormatTests {
128146
RequestOptions(timeout = 5.seconds, endpoint = "https://my.custom.endpoint"),
129147
mockEngine,
130148
TEST_CLIENT_ID,
149+
mockFirebaseApp,
150+
TEST_VERSION,
151+
TEST_APP_ID,
131152
null,
132153
)
133154

@@ -155,6 +176,9 @@ internal class RequestFormatTests {
155176
RequestOptions(),
156177
mockEngine,
157178
TEST_CLIENT_ID,
179+
mockFirebaseApp,
180+
TEST_VERSION,
181+
TEST_APP_ID,
158182
null,
159183
)
160184

@@ -163,6 +187,35 @@ internal class RequestFormatTests {
163187
mockEngine.requestHistory.first().headers["x-goog-api-client"] shouldBe TEST_CLIENT_ID
164188
}
165189

190+
@Test
191+
fun `ml monitoring header is set correctly if data collection is enabled`() = doBlocking {
192+
val response = JSON.encodeToString(CountTokensResponse.Internal(totalTokens = 10))
193+
val mockEngine = MockEngine {
194+
respond(response, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
195+
}
196+
197+
Mockito.`when`(mockFirebaseApp.isDataCollectionDefaultEnabled).thenReturn(true)
198+
199+
val controller =
200+
APIController(
201+
"super_cool_test_key",
202+
"gemini-pro-1.5",
203+
RequestOptions(),
204+
mockEngine,
205+
TEST_CLIENT_ID,
206+
mockFirebaseApp,
207+
TEST_VERSION,
208+
TEST_APP_ID,
209+
null,
210+
)
211+
212+
withTimeout(5.seconds) { controller.countTokens(textCountTokenRequest("cats")) }
213+
214+
mockEngine.requestHistory.first().headers["X-Firebase-AppId"] shouldBe TEST_APP_ID
215+
mockEngine.requestHistory.first().headers["X-Firebase-AppVersion"] shouldBe
216+
TEST_VERSION.toString()
217+
}
218+
166219
@Test
167220
fun `ToolConfig serialization contains correct keys`() = doBlocking {
168221
val channel = ByteChannel(autoFlush = true)
@@ -178,6 +231,9 @@ internal class RequestFormatTests {
178231
RequestOptions(),
179232
mockEngine,
180233
TEST_CLIENT_ID,
234+
mockFirebaseApp,
235+
TEST_VERSION,
236+
TEST_APP_ID,
181237
null,
182238
)
183239

@@ -229,6 +285,9 @@ internal class RequestFormatTests {
229285
RequestOptions(),
230286
mockEngine,
231287
TEST_CLIENT_ID,
288+
mockFirebaseApp,
289+
TEST_VERSION,
290+
TEST_APP_ID,
232291
testHeaderProvider,
233292
)
234293

@@ -263,6 +322,9 @@ internal class RequestFormatTests {
263322
RequestOptions(),
264323
mockEngine,
265324
TEST_CLIENT_ID,
325+
mockFirebaseApp,
326+
TEST_VERSION,
327+
TEST_APP_ID,
266328
testHeaderProvider,
267329
)
268330

@@ -286,6 +348,9 @@ internal class RequestFormatTests {
286348
RequestOptions(),
287349
mockEngine,
288350
TEST_CLIENT_ID,
351+
mockFirebaseApp,
352+
TEST_VERSION,
353+
TEST_APP_ID,
289354
null,
290355
)
291356

@@ -309,6 +374,12 @@ internal class RequestFormatTests {
309374

310375
@RunWith(Parameterized::class)
311376
internal class ModelNamingTests(private val modelName: String, private val actualName: String) {
377+
private val mockFirebaseApp = Mockito.mock<FirebaseApp>()
378+
379+
@Before
380+
fun setup() {
381+
Mockito.`when`(mockFirebaseApp.isDataCollectionDefaultEnabled).thenReturn(false)
382+
}
312383

313384
@Test
314385
fun `request should include right model name`() = doBlocking {
@@ -324,6 +395,9 @@ internal class ModelNamingTests(private val modelName: String, private val actua
324395
RequestOptions(),
325396
mockEngine,
326397
TEST_CLIENT_ID,
398+
mockFirebaseApp,
399+
TEST_VERSION,
400+
TEST_APP_ID,
327401
null,
328402
)
329403

firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/util/tests.kt

+10
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
package com.google.firebase.vertexai.common.util
2020

21+
import com.google.firebase.FirebaseApp
2122
import com.google.firebase.vertexai.common.APIController
2223
import com.google.firebase.vertexai.common.JSON
2324
import com.google.firebase.vertexai.type.Candidate
@@ -33,8 +34,11 @@ import io.ktor.http.headersOf
3334
import io.ktor.utils.io.ByteChannel
3435
import kotlinx.serialization.ExperimentalSerializationApi
3536
import kotlinx.serialization.encodeToString
37+
import org.mockito.Mockito
3638

3739
private val TEST_CLIENT_ID = "genai-android/test"
40+
private val TEST_APP_ID = "1:android:12345"
41+
private val TEST_VERSION = 1
3842

3943
internal fun prepareStreamingResponse(
4044
response: List<GenerateContentResponse.Internal>
@@ -90,6 +94,9 @@ internal fun commonTest(
9094
requestOptions: RequestOptions = RequestOptions(),
9195
block: CommonTest,
9296
) = doBlocking {
97+
val mockFirebaseApp = Mockito.mock<FirebaseApp>()
98+
Mockito.`when`(mockFirebaseApp.isDataCollectionDefaultEnabled).thenReturn(false)
99+
93100
val channel = ByteChannel(autoFlush = true)
94101
val apiController =
95102
APIController(
@@ -100,6 +107,9 @@ internal fun commonTest(
100107
respond(channel, status, headersOf(HttpHeaders.ContentType, "application/json"))
101108
},
102109
TEST_CLIENT_ID,
110+
mockFirebaseApp,
111+
TEST_VERSION,
112+
TEST_APP_ID,
103113
null,
104114
)
105115
CommonTestScope(channel, apiController).block()

0 commit comments

Comments
 (0)