fix: multiple model families

This commit is contained in:
wanghuaijie.whj 2025-03-18 15:42:03 +08:00
parent 4ac9595295
commit 8c592f8ca1
1 changed files with 10 additions and 6 deletions

View File

@ -122,6 +122,7 @@ class ReaLModel(nn.Module):
config: model_api.ReaLModelConfig,
dtype: Optional[torch.dtype] = None,
device: Optional[Union[str, torch.device]] = None,
hf_model_family: Optional[str] = None,
):
super().__init__()
if dtype is None:
@ -173,6 +174,14 @@ class ReaLModel(nn.Module):
)
self.contiguous_param = None
self.hf_model_family = hf_model_family
def save_to_hf(self, tokenizer, save_dir):
    """Save the model using the converter of the configured HF model family.

    Dispatches to ``self.to_<hf_model_family>(tokenizer, save_dir)`` and
    returns whatever that converter returns.

    Args:
        tokenizer: Forwarded unchanged to the family-specific converter.
        save_dir: Forwarded unchanged to the family-specific converter.

    Raises:
        AttributeError: If ``hf_model_family`` is ``None`` or the model has
            no ``to_<hf_model_family>`` converter.
    """
    family = self.hf_model_family
    # `hf_model_family` is optional at construction time; without this guard
    # the lookup below would fail with a confusing "no attribute 'to_None'".
    if family is None:
        raise AttributeError(
            "Cannot save to HuggingFace format: `hf_model_family` is not set "
            "on this model."
        )
    converter = getattr(self, f"to_{family}", None)
    if converter is None:
        raise AttributeError(
            f"No HuggingFace save converter `to_{family}` is registered for "
            f"hf_model_family={family!r}."
        )
    return converter(tokenizer, save_dir)
def load_from_hf(self, load_dir):
    """Load the model using the converter of the configured HF model family.

    Dispatches to ``self.from_<hf_model_family>(load_dir)`` and returns
    whatever that converter returns.

    Args:
        load_dir: Forwarded unchanged to the family-specific converter.

    Raises:
        AttributeError: If ``hf_model_family`` is ``None`` or the model has
            no ``from_<hf_model_family>`` converter.
    """
    family = self.hf_model_family
    # `hf_model_family` is optional at construction time; without this guard
    # the lookup below would fail with a confusing "no attribute 'from_None'".
    if family is None:
        raise AttributeError(
            "Cannot load from HuggingFace format: `hf_model_family` is not "
            "set on this model."
        )
    converter = getattr(self, f"from_{family}", None)
    if converter is None:
        raise AttributeError(
            f"No HuggingFace load converter `from_{family}` is registered "
            f"for hf_model_family={family!r}."
        )
    return converter(load_dir)
@property
def pre_process(self):
# A workaround to make Megatron-LM backend happy.
@ -918,12 +927,7 @@ def make_real_model(
model_path=model_path,
is_critic=is_critic or init_critic_from_actor,
)
m = ReaLModel(mconfig, dtype=dtype, device=device)
# Since we load from `hf_model_family`, we should save to `hf_model_family`.
# The following line creates a convenient function to save the model.
setattr(ReaLModel, "save_to_hf", getattr(ReaLModel, f"to_{hf_model_family}"))
setattr(ReaLModel, "load_from_hf", getattr(ReaLModel, f"from_{hf_model_family}"))
m = ReaLModel(mconfig, dtype=dtype, device=device, hf_model_family=hf_model_family)
if not init_from_scratch:
m._instantiation_hooks.append(