16
16
17
17
package com.google.firebase.vertexai.common
18
18
19
+ import com.google.firebase.FirebaseApp
19
20
import com.google.firebase.vertexai.BuildConfig
20
21
import com.google.firebase.vertexai.common.util.commonTest
21
22
import com.google.firebase.vertexai.common.util.createResponses
@@ -49,12 +50,18 @@ import kotlinx.coroutines.withTimeout
49
50
import kotlinx.serialization.ExperimentalSerializationApi
50
51
import kotlinx.serialization.encodeToString
51
52
import kotlinx.serialization.json.JsonObject
53
+ import org.junit.Before
52
54
import org.junit.Test
53
55
import org.junit.runner.RunWith
54
56
import org.junit.runners.Parameterized
57
+ import org.mockito.Mockito
55
58
56
59
private val TEST_CLIENT_ID = " genai-android/test"
57
60
61
+ private val TEST_APP_ID = " 1:android:12345"
62
+
63
+ private val TEST_VERSION = 1
64
+
58
65
internal class APIControllerTests {
59
66
private val testTimeout = 5 .seconds
60
67
@@ -87,6 +94,14 @@ internal class APIControllerTests {
87
94
88
95
@OptIn(ExperimentalSerializationApi ::class )
89
96
internal class RequestFormatTests {
97
+
98
+ private val mockFirebaseApp = Mockito .mock<FirebaseApp >()
99
+
100
+ @Before
101
+ fun setup () {
102
+ Mockito .`when `(mockFirebaseApp.isDataCollectionDefaultEnabled).thenReturn(false )
103
+ }
104
+
90
105
@Test
91
106
fun `using default endpoint` () = doBlocking {
92
107
val channel = ByteChannel (autoFlush = true )
@@ -101,6 +116,9 @@ internal class RequestFormatTests {
101
116
RequestOptions (),
102
117
mockEngine,
103
118
" genai-android/${BuildConfig .VERSION_NAME } " ,
119
+ mockFirebaseApp,
120
+ TEST_VERSION ,
121
+ TEST_APP_ID ,
104
122
null ,
105
123
)
106
124
@@ -128,6 +146,9 @@ internal class RequestFormatTests {
128
146
RequestOptions (timeout = 5 .seconds, endpoint = " https://my.custom.endpoint" ),
129
147
mockEngine,
130
148
TEST_CLIENT_ID ,
149
+ mockFirebaseApp,
150
+ TEST_VERSION ,
151
+ TEST_APP_ID ,
131
152
null ,
132
153
)
133
154
@@ -155,6 +176,9 @@ internal class RequestFormatTests {
155
176
RequestOptions (),
156
177
mockEngine,
157
178
TEST_CLIENT_ID ,
179
+ mockFirebaseApp,
180
+ TEST_VERSION ,
181
+ TEST_APP_ID ,
158
182
null ,
159
183
)
160
184
@@ -163,6 +187,35 @@ internal class RequestFormatTests {
163
187
mockEngine.requestHistory.first().headers[" x-goog-api-client" ] shouldBe TEST_CLIENT_ID
164
188
}
165
189
190
+ @Test
191
+ fun `ml monitoring header is set correctly if data collection is enabled` () = doBlocking {
192
+ val response = JSON .encodeToString(CountTokensResponse .Internal (totalTokens = 10 ))
193
+ val mockEngine = MockEngine {
194
+ respond(response, HttpStatusCode .OK , headersOf(HttpHeaders .ContentType , " application/json" ))
195
+ }
196
+
197
+ Mockito .`when `(mockFirebaseApp.isDataCollectionDefaultEnabled).thenReturn(true )
198
+
199
+ val controller =
200
+ APIController (
201
+ " super_cool_test_key" ,
202
+ " gemini-pro-1.5" ,
203
+ RequestOptions (),
204
+ mockEngine,
205
+ TEST_CLIENT_ID ,
206
+ mockFirebaseApp,
207
+ TEST_VERSION ,
208
+ TEST_APP_ID ,
209
+ null ,
210
+ )
211
+
212
+ withTimeout(5 .seconds) { controller.countTokens(textCountTokenRequest(" cats" )) }
213
+
214
+ mockEngine.requestHistory.first().headers[" X-Firebase-AppId" ] shouldBe TEST_APP_ID
215
+ mockEngine.requestHistory.first().headers[" X-Firebase-AppVersion" ] shouldBe
216
+ TEST_VERSION .toString()
217
+ }
218
+
166
219
@Test
167
220
fun `ToolConfig serialization contains correct keys` () = doBlocking {
168
221
val channel = ByteChannel (autoFlush = true )
@@ -178,6 +231,9 @@ internal class RequestFormatTests {
178
231
RequestOptions (),
179
232
mockEngine,
180
233
TEST_CLIENT_ID ,
234
+ mockFirebaseApp,
235
+ TEST_VERSION ,
236
+ TEST_APP_ID ,
181
237
null ,
182
238
)
183
239
@@ -229,6 +285,9 @@ internal class RequestFormatTests {
229
285
RequestOptions (),
230
286
mockEngine,
231
287
TEST_CLIENT_ID ,
288
+ mockFirebaseApp,
289
+ TEST_VERSION ,
290
+ TEST_APP_ID ,
232
291
testHeaderProvider,
233
292
)
234
293
@@ -263,6 +322,9 @@ internal class RequestFormatTests {
263
322
RequestOptions (),
264
323
mockEngine,
265
324
TEST_CLIENT_ID ,
325
+ mockFirebaseApp,
326
+ TEST_VERSION ,
327
+ TEST_APP_ID ,
266
328
testHeaderProvider,
267
329
)
268
330
@@ -286,6 +348,9 @@ internal class RequestFormatTests {
286
348
RequestOptions (),
287
349
mockEngine,
288
350
TEST_CLIENT_ID ,
351
+ mockFirebaseApp,
352
+ TEST_VERSION ,
353
+ TEST_APP_ID ,
289
354
null ,
290
355
)
291
356
@@ -309,6 +374,12 @@ internal class RequestFormatTests {
309
374
310
375
@RunWith(Parameterized ::class )
311
376
internal class ModelNamingTests (private val modelName : String , private val actualName : String ) {
377
+ private val mockFirebaseApp = Mockito .mock<FirebaseApp >()
378
+
379
+ @Before
380
+ fun setup () {
381
+ Mockito .`when `(mockFirebaseApp.isDataCollectionDefaultEnabled).thenReturn(false )
382
+ }
312
383
313
384
@Test
314
385
fun `request should include right model name` () = doBlocking {
@@ -324,6 +395,9 @@ internal class ModelNamingTests(private val modelName: String, private val actua
324
395
RequestOptions (),
325
396
mockEngine,
326
397
TEST_CLIENT_ID ,
398
+ mockFirebaseApp,
399
+ TEST_VERSION ,
400
+ TEST_APP_ID ,
327
401
null ,
328
402
)
329
403
0 commit comments