Skip to content

fix(gralora): use _cast_input_dtype in forward instead of raw .to()#3209

Open
Chessing234 wants to merge 1 commit into
huggingface:mainfrom
Chessing234:fix/gralora-cast-input-dtype
Open

fix(gralora): use _cast_input_dtype in forward instead of raw .to()#3209
Chessing234 wants to merge 1 commit into
huggingface:mainfrom
Chessing234:fix/gralora-cast-input-dtype

Conversation

@Chessing234
Copy link
Copy Markdown
Contributor

Bug

Linear.forward in src/peft/tuners/gralora/layer.py casts the input tensor with x.to(gralora_dtype) in two places. This bypasses the disable_lora_input_dtype_casting context manager provided by BaseTunerLayer, which lets callers suppress dtype casting during specific operations (e.g. AMP, mixed-precision inference).

Root cause

BaseTunerLayer._cast_input_dtype checks self._disable_lora_input_dtype_casting and is a no-op when the context manager is active. A bare .to(dtype) call has no such guard.

Fix

Replace both raw casts with the inherited helper:

# before (line 367)
dropout(x.to(gralora_dtype)).view(...)
# after
dropout(self._cast_input_dtype(x, gralora_dtype)).view(...)

# before (line 382)
gralora_B_general(gralora_A_general(dropout(x.to(gralora_dtype))))
# after
gralora_B_general(gralora_A_general(dropout(self._cast_input_dtype(x, gralora_dtype))))

This is the same pattern already applied to vera (#3172), poly (#3177), fourierft (#3170), and pvera (#3208). gralora was missed in those fixes.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant