@@ -457,8 +457,19 @@ class ProofreadService(private val context: Context) {
457457
458458 for (step in 0 until maxTokens) {
459459 // For KV-cache models, only pass the last token after first step
460+ var isValidPkv = false
461+ if (hasPkvInputs && pastKeyValues != null ) {
462+ val currentPkv = pastKeyValues!! .values.firstOrNull()
463+ if (currentPkv != null ) {
464+ val sequenceLength = currentPkv.info.shape[2 ]
465+ if (sequenceLength > 0 ) {
466+ isValidPkv = true
467+ }
468+ }
469+ }
470+
460471 // CRITICAL FIX: Only use valid pastKeyValues if model actually accepts PKV inputs
461- val inputTokens = if (step > 0 && pastKeyValues != null && hasPkvInputs ) {
472+ val inputTokens = if (step > 0 && isValidPkv ) {
462473 longArrayOf(generatedTokens.last())
463474 } else {
464475 generatedTokens.toLongArray()
@@ -488,7 +499,7 @@ class ProofreadService(private val context: Context) {
488499 }
489500
490501 // Add past_key_values from previous step (if available and model expects them)
491- if (hasPkvInputs && pastKeyValues != null ) {
502+ if (isValidPkv ) {
492503 for ((name, tensor) in pastKeyValues!! ) {
493504 // Map present.X.* output names to past_key_values.X.* or pkv_* input names
494505 val inputName = name.replace(" present" , " past_key_values" )
@@ -502,8 +513,8 @@ class ProofreadService(private val context: Context) {
502513 }
503514 }
504515 }
505- } else if (hasPkvInputs && step == 0 ) {
506- // First step with PKV model: provide zero tensors
516+ } else if (hasPkvInputs) {
517+ // First step with PKV model or invalid cache : provide zero tensors
507518 // T5 pkv format: pkv_0 to pkv_N where first half is decoder self-attn, second half is encoder cross-attn
508519 // Shape: [batch, num_heads, seq_len, head_dim]
509520
0 commit comments