diff --git a/fastchat/model/rwkv_model.py b/fastchat/model/rwkv_model.py index bdbc14584..0a0c65bd7 100644 --- a/fastchat/model/rwkv_model.py +++ b/fastchat/model/rwkv_model.py @@ -17,9 +17,13 @@ def __init__(self, model_path): "Experimental support. Please use ChatRWKV if you want to chat with RWKV" ) self.config = SimpleNamespace(is_encoder_decoder=False) - self.model = RWKV(model=model_path, strategy="cuda fp16") - # two GPUs - # self.model = RWKV(model=model_path, strategy="cuda:0 fp16 *20 -> cuda:1 fp16") + n = torch.cuda.device_count() + strategy = ( + " -> ".join(f"cuda:{i} fp16" for i in range(n)) + if n > 1 + else "cuda fp16" + ) + self.model = RWKV(model=model_path, strategy=strategy) self.tokenizer = None self.model_path = model_path