Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ sealed interface LlmProvider {
*
* @property nameRes A string resource ID representing the display name of the provider.
* @property iconRes A drawable resource ID representing the icon of the provider if present.
* @property model The identifier of the model that this provider serves (e.g.
* "moz-summarization"). Used for telemetry and logging.
*/
data class Info(val nameRes: Int, val iconRes: Int? = null)
data class Info(val nameRes: Int, val iconRes: Int? = null, val model: String? = null)

/**
* Metadata about this provider, including its display name.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,19 @@ import mozilla.components.lib.llm.mlpa.service.ChatService.Request.ModelID
internal class MlpaLlm(
val chatService: ChatService,
val authorizationToken: AuthorizationToken,
val model: ModelID,
) : Llm {
override suspend fun prompt(prompt: Prompt): Flow<String> = chatService.completion(
authorizationToken,
request = prompt.asRequest,
request = prompt.toRequest(model),
)
}

internal val Prompt.asRequest
get() = Request(
model = ModelID.mozSummarization,
messages = buildList {
systemPrompt?.let { add(Message.system(it)) }
add(Message.user(userPrompt))
},
stream = true,
)
internal fun Prompt.toRequest(model: ModelID) = Request(
model = model,
messages = buildList {
systemPrompt?.let { add(Message.system(it)) }
add(Message.user(userPrompt))
},
stream = true,
)
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import mozilla.components.concept.llm.ErrorCode
import mozilla.components.concept.llm.Llm
import mozilla.components.concept.llm.LlmProvider
import mozilla.components.lib.llm.mlpa.service.ChatService
import mozilla.components.lib.llm.mlpa.service.ChatService.Request.ModelID
import mozilla.components.lib.llm.mlpa.service.ChatServiceError
import mozilla.components.lib.llm.mlpa.service.MlpaService

Expand All @@ -37,7 +38,13 @@ class MlpaLlmProvider(
val storage: MlpaTokenStorage,
val mlpaService: MlpaService,
) : CloudLlmProvider {
override val info = LlmProvider.Info(nameRes = R.string.mlpa_llm_provider_name, iconRes = R.drawable.firefox_icon)
private val model = ModelID.mozSummarization

override val info = LlmProvider.Info(
nameRes = R.string.mlpa_llm_provider_name,
iconRes = R.drawable.firefox_icon,
model = model.value,
)
private val _state = MutableStateFlow<State>(State.Available)

/**
Expand All @@ -55,7 +62,7 @@ class MlpaLlmProvider(
*/
override suspend fun prepare() {
tokenProvider.fetchToken()
.onSuccess { _state.value = State.Ready(MlpaLlm(chatService, it)) }
.onSuccess { _state.value = State.Ready(MlpaLlm(chatService, it, model)) }
.onFailure {
_state.value = State.Unavailable(
it as? Llm.Exception
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class MlpaLlmTest {
successChatService.completion(token, request)
},
authorizationToken = AuthorizationToken.Integrity("my-test-token"),
model = ChatService.Request.ModelID.mozSummarization,
)

val actual = llm.prompt(Prompt("This is my prompt")).toList()
Expand All @@ -49,6 +50,7 @@ class MlpaLlmTest {
val llm = MlpaLlm(
chatService = failureChatService,
authorizationToken = AuthorizationToken.Integrity("my-test-token"),
model = ChatService.Request.ModelID.mozSummarization,
)

llm.prompt(Prompt("This is my prompt"))
Expand All @@ -67,6 +69,7 @@ class MlpaLlmTest {
successChatService.completion(token, request)
},
authorizationToken = AuthorizationToken.Integrity("my-test-token"),
model = ChatService.Request.ModelID.mozSummarization,
)

llm.prompt(Prompt("user prompt", "system prompt")).toList()
Expand All @@ -88,6 +91,7 @@ class MlpaLlmTest {
successChatService.completion(token, request)
},
authorizationToken = AuthorizationToken.Integrity("my-test-token"),
model = ChatService.Request.ModelID.mozSummarization,
)

llm.prompt(Prompt("user prompt", null)).toList()
Expand Down
19 changes: 17 additions & 2 deletions mobile/android/fenix/app/metrics.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15398,10 +15398,17 @@ ai_summarize:
type: string
error_type:
description: |
If an error occurred, the resulting error code. Examples include:
If an error occurred, a human-readable name for the failure
(typically the exception class name, e.g. "RateLimited",
"RequestTooLarge").
type: string
error_code:
description: |
If an error occurred, the integer error code associated with the
failure. Examples include:
1005 (meaning content was too large to summarize)
1007 (meaning LLM service has rate-limited the user)
type: string
type: quantity
summarize_duration_ms:
description: |
Time in milliseconds between request and completion.
Expand All @@ -15413,6 +15420,7 @@ ai_summarize:
type: string
bugs:
- https://bugzilla.mozilla.org/show_bug.cgi?id=2025523
- https://bugzilla.mozilla.org/show_bug.cgi?id=2036474
data_reviews:
- https://phabricator.services.mozilla.com/D289882
data_sensitivity:
Expand Down Expand Up @@ -15443,8 +15451,15 @@ ai_summarize:
description: |
Records when the user initiates summarization.
(e.g. taps menu button, toolbar icon or shakes the device).
extra_keys:
trigger:
description: |
How the user initiated summarization. Possible values:
"shake", "menu".
type: string
bugs:
- https://bugzilla.mozilla.org/show_bug.cgi?id=2025523
- https://bugzilla.mozilla.org/show_bug.cgi?id=2036474
data_reviews:
- https://phabricator.services.mozilla.com/D289882
data_sensitivity:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,11 @@ class SummarizationTelemetryMiddleware(
when (action) {
ViewAppeared -> handleViewAppeared(stateBefore)
is SummarizationRequested -> {
sessionTelemetry = sessionTelemetry.copy(model = action.info.nameRes.toString())
sessionTelemetry = sessionTelemetry.copy(model = action.info.model)
}
is ContentExtracted -> handleExtractedContent(action.content)
is SummarizationCompleted -> recordSummarizationCompleted()
is SummarizationFailed -> recordSummarizationCompleted(success = false, action.throwable.errorType)
is SummarizationFailed -> recordSummarizationCompleted(success = false, action.throwable)
ViewDismissed -> {
AiSummarize.closed.record(
AiSummarize.ClosedExtra(
Expand Down Expand Up @@ -119,8 +119,6 @@ class SummarizationTelemetryMiddleware(
}

private fun handleViewAppeared(stateBefore: SummarizationState) {
AiSummarize.requested.record()
timerId = AiSummarize.duration.start()
if (stateBefore is SummarizationState.Inert) {
val trigger = if (stateBefore.initializedWithShake) {
SummarizationTrigger.SHAKE
Expand All @@ -129,6 +127,10 @@ class SummarizationTelemetryMiddleware(
}
sessionTelemetry = sessionTelemetry.copy(trigger = trigger)
}
AiSummarize.requested.record(
AiSummarize.RequestedExtra(trigger = sessionTelemetry.trigger?.toString()),
)
timerId = AiSummarize.duration.start()
}

private fun handleExtractedContent(content: Content) {
Expand All @@ -151,7 +153,7 @@ class SummarizationTelemetryMiddleware(
)
}

private fun recordSummarizationCompleted(success: Boolean = true, errorType: String? = null) {
private fun recordSummarizationCompleted(success: Boolean = true, error: Throwable? = null) {
timerId?.let {
AiSummarize.duration.stopAndAccumulate(it)
timerId = null
Expand All @@ -161,7 +163,8 @@ class SummarizationTelemetryMiddleware(
AiSummarize.CompletedExtra(
connectionType = connectionType.toString(),
contentType = sessionTelemetry.contentMetrics?.contentType,
errorType = errorType,
errorType = error?.errorType,
errorCode = error?.errorCode,
language = sessionTelemetry.contentMetrics?.language,
lengthChars = sessionTelemetry.contentMetrics?.charCount,
lengthWords = sessionTelemetry.contentMetrics?.wordCount,
Expand All @@ -173,4 +176,6 @@ class SummarizationTelemetryMiddleware(
}
}

private val Throwable.errorType get() = (this as? Llm.Exception)?.errorCode?.value?.toString()
private val Throwable.errorType get() = this::class.simpleName

private val Throwable.errorCode get() = (this as? Llm.Exception)?.errorCode?.value
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,11 @@ class SummarizationTelemetryMiddlewareTest {
invokeMiddleware(ViewAppeared)
invokeMiddleware(createContentExtractedAction())

val extras = AiSummarize.started.testGetValue()!!.first().extra!!
assertEquals("SHAKE", extras["trigger"])
val startedExtras = AiSummarize.started.testGetValue()!!.first().extra!!
assertEquals("SHAKE", startedExtras["trigger"])

val requestedExtras = AiSummarize.requested.testGetValue()!!.first().extra!!
assertEquals("SHAKE", requestedExtras["trigger"])
}

@Test
Expand All @@ -77,8 +80,11 @@ class SummarizationTelemetryMiddlewareTest {
invokeMiddleware(ViewAppeared)
invokeMiddleware(createContentExtractedAction())

val extras = AiSummarize.started.testGetValue()!!.first().extra!!
assertEquals("MENU", extras["trigger"])
val startedExtras = AiSummarize.started.testGetValue()!!.first().extra!!
assertEquals("MENU", startedExtras["trigger"])

val requestedExtras = AiSummarize.requested.testGetValue()!!.first().extra!!
assertEquals("MENU", requestedExtras["trigger"])
}

@Test
Expand All @@ -88,7 +94,7 @@ class SummarizationTelemetryMiddlewareTest {
every { store.state } returns SummarizationState.Inert(initializedWithShake = false)
invokeMiddleware(ViewAppeared)
invokeMiddleware(
SummarizationRequested(LlmProvider.Info(nameRes = 42)),
SummarizationRequested(LlmProvider.Info(nameRes = 42, model = TEST_MODEL)),
)
invokeMiddleware(
createContentExtractedAction(
Expand All @@ -106,7 +112,7 @@ class SummarizationTelemetryMiddlewareTest {

val extras = snapshot.first().extra!!
assertEquals("MENU", extras["trigger"])
assertEquals("42", extras["model"])
assertEquals(TEST_MODEL, extras["model"])
assertEquals("120", extras["length_words"])
assertEquals("15", extras["length_chars"])
assertEquals("[recipe]", extras["content_type"])
Expand All @@ -125,13 +131,14 @@ class SummarizationTelemetryMiddlewareTest {
val extras = snapshot.first().extra!!
assertEquals("true", extras["success"])
assertEquals("WIFI", extras["connection_type"])
assertEquals("42", extras["model"])
assertEquals(TEST_MODEL, extras["model"])
assertNull(extras["error_type"])
assertNull(extras["error_code"])
assertNotNull(extras["summarize_duration_ms"])
}

@Test
fun `WHEN SummarizationFailed is received THEN summarization_completed is recorded with success false`() {
fun `WHEN SummarizationFailed with Llm Exception THEN error_type is exception name and error_code is the int code`() {
assertNull(AiSummarize.completed.testGetValue())

setupFullSession()
Expand All @@ -143,7 +150,21 @@ class SummarizationTelemetryMiddlewareTest {

val extras = snapshot.first().extra!!
assertEquals("false", extras["success"])
assertEquals(exception.errorCode.value.toString(), extras["error_type"])
assertEquals("Exception", extras["error_type"])
assertEquals("1001", extras["error_code"])
}

@Test
fun `WHEN SummarizationFailed with non-Llm throwable THEN error_type is class name and error_code is absent`() {
assertNull(AiSummarize.completed.testGetValue())

setupFullSession()
invokeMiddleware(SummarizationFailed(IllegalStateException("oops")))

val extras = AiSummarize.completed.testGetValue()!!.first().extra!!
assertEquals("false", extras["success"])
assertEquals("IllegalStateException", extras["error_type"])
assertNull(extras["error_code"])
}

@Test
Expand All @@ -160,12 +181,12 @@ class SummarizationTelemetryMiddlewareTest {
every { store.state } returns SummarizationState.Inert(initializedWithShake = false)
invokeMiddleware(ViewAppeared)
invokeMiddleware(
SummarizationRequested(LlmProvider.Info(nameRes = 99)),
SummarizationRequested(LlmProvider.Info(nameRes = 99, model = "another-model")),
)
invokeMiddleware(ViewDismissed)

val extras = AiSummarize.closed.testGetValue()!!.first().extra!!
assertEquals("99", extras["model"])
assertEquals("another-model", extras["model"])
}

@Test
Expand Down Expand Up @@ -234,7 +255,7 @@ class SummarizationTelemetryMiddlewareTest {
every { store.state } returns SummarizationState.Inert(initializedWithShake = false)
invokeMiddleware(ViewAppeared)
invokeMiddleware(
SummarizationRequested(LlmProvider.Info(nameRes = 42)),
SummarizationRequested(LlmProvider.Info(nameRes = 42, model = TEST_MODEL)),
)
invokeMiddleware(createContentExtractedAction())
}
Expand All @@ -251,4 +272,8 @@ class SummarizationTelemetryMiddlewareTest {
action = action,
)
}

private companion object {
const val TEST_MODEL = "moz-summarization"
}
}