在Android上使用Gemma 2B进行设备端LLM处理 AI聊天应用
随着大型语言模型(LLMs)的不断进步,将它们集成到移动应用中变得越来越可行且有益。设备端LLM处理具有多种优势,如降低延迟、增强隐私保护和离线功能。
通过直接在设备上运行LLMs,应用程序可以提供实时响应,无需依赖持续的互联网连接或将敏感数据暴露给外部服务器。
本博客探讨了在Android上进行设备端LLM处理的概念,展示了如何使用Kotlin实现这一功能。我们将逐步介绍利用LLMs进行实时文本生成和处理的Android应用程序的关键组件,提供一种高效且安全的方式来直接在设备上处理语言模型。
入门指南
我们使用的是Gemma 2B ,Gemma 是一系列轻量级、开源的模型,基于谷歌创建 Gemini 模型时所用的研究和技术的成果。您可以从提供的链接下载该模型,解压后即可使用。
要开始使用,请创建一个新的 Android 项目,我们将使用 Compose。我们将借助谷歌的 Mediapipe 与模型进行交互。MediaPipe Solutions 提供了一系列库和工具,旨在帮助您快速将人工智能(AI)和机器学习(ML)功能集成到您的应用程序中。
将模型复制到设备
- 在您的计算机上下载文件夹中解压下载的模型,并连接您的移动设备。
- 在终端中使用adb运行以下命令:
1 2 3
| adb shell rm -r /data/local/tmp/llm/ adb shell mkdir -p /data/local/tmp/llm/ adb push gemma2b.bin /data/local/tmp/llm/gemma2b.bin
|
这些命令将把gemma2b.bin模型文件复制到临时目录。现在模型已位于正确位置,我们可以开始编码部分了。
开始编码
在 AndroidManifest 文件中添加以下内容以支持原生库:
1 2 3 4 5 6 7 8 9
| <uses-native-library android:name="libOpenCL.so" android:required="false" /> <uses-native-library android:name="libOpenCL-car.so" android:required="false" /> <uses-native-library android:name="libOpenCL-pixel.so" android:required="false" />
|
在您的 Android 应用的 build.gradle 文件中添加以下依赖项:
1 2 3
| dependencies { implementation 'com.google.mediapipe:tasks-genai:0.10.14' }
|
LLMTask 类
我们实现的核心组件是 LLMTask 类。该类负责 LLM 推理的初始化和执行,确保模型在设备上高效运行。让我们分解这个类的关键元素:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43
| class LLMTask(context: Context) { private val _partialResults = MutableSharedFlow<Pair<String, Boolean>>( extraBufferCapacity = 1, onBufferOverflow = BufferOverflow.DROP_OLDEST ) val partialResults: SharedFlow<Pair<String, Boolean>> = _partialResults.asSharedFlow() private var llmInference: LlmInference
init { val options = LlmInference.LlmInferenceOptions.builder() .setModelPath(MODEL_PATH) .setMaxTokens(2048) .setTopK(50) .setTemperature(0.7f) .setRandomSeed(1) .setResultListener { partialResult, done -> _partialResults.tryEmit(partialResult to done) } .build()
llmInference = LlmInference.createFromOptions( context, options ) }
fun generateResponse(prompt: String) { llmInference.generateResponseAsync(prompt) }
companion object { private const val MODEL_PATH = "/data/local/tmp/llm/gemma2b.bin" private var instance: LLMTask? = null fun getInstance(context: Context): LLMTask { return if (instance != null) { instance!! } else { LLMTask(context).also { instance = it } } } } }
|
关键组件
- MutableSharedFlow 和 SharedFlow: 这些用于管理从LLM推理中得到的中间结果流。MutableSharedFlow允许我们发送新的结果,而SharedFlow则将这些结果暴露给应用程序的其他部分。
- LlmInference 初始化: LlmInference实例通过选项初始化,包括模型路径、最大令牌数和一个处理中间结果的结果监听器。
我们可以使用以下配置选项来初始化LlmInference,
- modelPath: 模型在项目目录中的存储路径。
- maxTokens: 模型处理的最大令牌数(输入令牌 + 输出令牌)。默认值为512。
- topK: 模型在生成每一步考虑的令牌数量。限制预测为最可能的k个令牌。默认值为40。
- temperature: 生成过程中引入的随机性量。较高的温度会导致生成文本更具创造性,而较低的温度则产生更可预测的生成结果。默认值为0.8。
- randomSeed: 文本生成过程中使用的随机种子。默认值为0。
- loraPath: 设备上LoRA模型的绝对路径。注意:这仅兼容GPU模型。
- resultListener: 设置结果监听器以异步接收结果。仅在使用异步生成方法时适用。
- errorListener: 设置一个可选的错误监听器。
使用LLMState管理状态
1 2 3 4 5 6 7 8 9
| sealed class LLMState { data object LLMModelLoading : LLMState() data object LLMModelLoaded : LLMState() data object LLMResponseLoading : LLMState() data object LLMResponseLoaded : LLMState()
val isLLMModelLoading get() = this is LLMModelLoading val isLLMResponseLoading get() = this is LLMResponseLoading }
|
这个密封类有助于管理和响应不同的状态,例如模型加载中、模型已加载以及响应生成时。
ChatState 类
ChatState 类负责维护聊天状态,包括用户消息和LLM响应。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88
| class ChatState( messages: List<ChatDataModel> = emptyList() ) { private val _chatMessages: MutableList<ChatDataModel> = messages.toMutableStateList() val chatMessages: List<ChatDataModel> get() = _chatMessages.map { model -> val isUser = model.isUser val prefixToRemove = if (isUser) USER_PREFIX else MODEL_PREFIX model.copy( chatMessage = model.chatMessage .replace( START_TURN + prefixToRemove + "\n", "" ) .replace( END_TURN, "" ) ) }.reversed()
val fullPrompt get() = _chatMessages.takeLast(5).joinToString("\n") { it.chatMessage }
fun createLLMLoadingMessage(): String { val chatMessage = ChatDataModel( chatMessage = "", isUser = false ) _chatMessages.add(chatMessage) return chatMessage.id }
fun appendFirstLLMResponse( id: String, message: String, ) { appendLLMResponse( id, "$START_TURN$MODEL_PREFIX\n$message", false ) }
fun appendLLMResponse( id: String, message: String, done: Boolean ) { val index = _chatMessages.indexOfFirst { it.id == id } if (index != -1) { val newText = if (done) { _chatMessages[index].chatMessage + message + END_TURN } else { _chatMessages[index].chatMessage + message } _chatMessages[index] = _chatMessages[index].copy(chatMessage = newText) } }
fun appendUserMessage( message: String, ) { val chatMessage = ChatDataModel( chatMessage = "$START_TURN$USER_PREFIX\n$message$END_TURN", isUser = true ) _chatMessages.add(chatMessage) }
fun addErrorLLMResponse(e: Exception) { _chatMessages.add( ChatDataModel( chatMessage = e.localizedMessage ?: "Error generating message", isUser = false ) ) }
companion object { private const val MODEL_PREFIX = "model" private const val USER_PREFIX = "user" private const val START_TURN = "<start_of_turn>" private const val END_TURN = "<end_of_turn>" } }
|
关键方法:
- createLLMLoadingMessage:向聊天状态中添加一条新的加载消息并返回其ID。
- appendFirstLLMResponse 和 appendLLMResponse:这些方法处理将部分和完整的LLM响应追加到聊天消息中。
- appendUserMessage:将用户消息添加到聊天状态中。
- addErrorLLMResponse:如果在LLM处理过程中出现问题,则添加一条错误消息。
- fullPrompt:拼接最后5条消息,为LLM提供更好的上下文。
ChatViewModel 类
ChatViewModel 类管理 UI 与 LLM 处理逻辑之间的交互。它使用 Kotlin 协程来管理异步任务,并相应地更新 UI 状态。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44
| @HiltViewModel class ChatViewModel @Inject constructor(@ApplicationContext private val context: Context) : ViewModel() { private val _llmState = MutableStateFlow<LLMState>(LLMState.LLMModelLoading) val llmState = _llmState.asStateFlow() private val _chatState: MutableStateFlow<ChatState> = MutableStateFlow(ChatState()) val chatState: StateFlow<ChatState> = _chatState.asStateFlow()
fun initLLMModel() { viewModelScope.launch(Dispatchers.IO) { _llmState.emit(LLMState.LLMModelLoading) LLMTask.getInstance(context) }.invokeOnCompletion { _llmState.value = LLMState.LLMModelLoaded } }
fun sendMessage(message: String) { viewModelScope.launch(Dispatchers.IO) { _chatState.value.appendUserMessage(message) try { _llmState.emit(LLMState.LLMResponseLoading) var currentLLMResponseId: String? = _chatState.value.createLLMLoadingMessage() LLMTask.getInstance(context).generateResponse(_chatState.value.fullPrompt) LLMTask.getInstance(context).partialResults .collectIndexed { index, (partialResult, done) -> currentLLMResponseId?.let { id -> if (index == 0) { _chatState.value.appendFirstLLMResponse(id, partialResult) } else { _chatState.value.appendLLMResponse(id, partialResult, done) } if (done) { _llmState.emit(LLMState.LLMResponseLoaded) currentLLMResponseId = null } } } } catch (e: Exception) { _chatState.value.addErrorLLMResponse(e) } } } }
|
关键功能:
- initLLMModel:初始化 LLM 模型并相应地更新状态。
- sendMessage:处理用户消息,生成 LLM 响应,并使用部分和最终结果更新聊天状态。
在聊天界面中的整合

优势
隐私保护: 通过在设备本地处理数据,设备端大型语言模型减少了将敏感信息通过互联网传输的需求,增强了用户隐私保护。
离线功能: 设备端大型语言模型无需互联网连接即可运行,使用户即使在离线环境下也能访问语言处理功能。
低延迟: 本地处理数据减少了将数据发送到远程服务器进行处理时的延迟,从而实现更快的响应时间。
降低数据成本: 用户可以避免因将数据发送到远程服务器进行处理而产生的数据费用。
定制化: 设备端大型语言模型可以根据特定用例或设备进行定制和优化,提供更大的灵活性和性能优化。
成本效益: 长期来看,设备端处理更具成本效益,因为它减少了昂贵服务器基础设施和数据传输成本的需求。
缺点
模型规模与复杂性: 设备端处理要求将大型语言模型(LLM)存储在本地设备上,由于现代LLM的规模和复杂性,这可能颇具挑战。较大的模型需要更多的存储空间和计算资源,这可能对低端设备造成压力。
资源密集型: 在设备上运行LLM可能会消耗大量资源,特别是对于复杂模型或长序列。这可能导致电池消耗增加和性能下降,尤其是在较旧或性能较弱的设备上。
模型更新: 跟上LLM模型的最新进展和改进可能颇具挑战。更新模型需要更新应用程序,这对用户来说可能并非总是可行或实际。
🚀 喜欢我在最新Medium文章中的见解吗?如果你觉得有帮助,请考虑给予掌声(👏) 并与你的网络分享。
别忘了点击‘关注’按钮。 📚 让我们保持联系,一起探索更多! 🚀📌