Skip to content

Commit 00e1646

Browse files
authored
Merge pull request #79 from LeanBitLab/fix-proofread-service-pkv-6163460110815717061
Only use valid pastKeyValues when sequence length is greater than 0
2 parents f2965af + 5b26635 commit 00e1646

1 file changed

Lines changed: 15 additions & 4 deletions

File tree

app/src/offline/java/helium314/keyboard/latin/utils/ProofreadService.kt

Lines changed: 15 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -457,8 +457,19 @@ class ProofreadService(private val context: Context) {
457457

458458
for (step in 0 until maxTokens) {
459459
// For KV-cache models, only pass the last token after first step
460+
var isValidPkv = false
461+
if (hasPkvInputs && pastKeyValues != null) {
462+
val currentPkv = pastKeyValues!!.values.firstOrNull()
463+
if (currentPkv != null) {
464+
val sequenceLength = currentPkv.info.shape[2]
465+
if (sequenceLength > 0) {
466+
isValidPkv = true
467+
}
468+
}
469+
}
470+
460471
// CRITICAL FIX: Only use valid pastKeyValues if model actually accepts PKV inputs
461-
val inputTokens = if (step > 0 && pastKeyValues != null && hasPkvInputs) {
472+
val inputTokens = if (step > 0 && isValidPkv) {
462473
longArrayOf(generatedTokens.last())
463474
} else {
464475
generatedTokens.toLongArray()
@@ -488,7 +499,7 @@ class ProofreadService(private val context: Context) {
488499
}
489500

490501
// Add past_key_values from previous step (if available and model expects them)
491-
if (hasPkvInputs && pastKeyValues != null) {
502+
if (isValidPkv) {
492503
for ((name, tensor) in pastKeyValues!!) {
493504
// Map present.X.* output names to past_key_values.X.* or pkv_* input names
494505
val inputName = name.replace("present", "past_key_values")
@@ -502,8 +513,8 @@ class ProofreadService(private val context: Context) {
502513
}
503514
}
504515
}
505-
} else if (hasPkvInputs && step == 0) {
506-
// First step with PKV model: provide zero tensors
516+
} else if (hasPkvInputs) {
517+
// First step with PKV model or invalid cache: provide zero tensors
507518
// T5 pkv format: pkv_0 to pkv_N where first half is decoder self-attn, second half is encoder cross-attn
508519
// Shape: [batch, num_heads, seq_len, head_dim]
509520

0 commit comments

Comments
 (0)