mirror of https://github.com/inclusionAI/AReaL
fix: multiple model families
This commit is contained in:
parent
4ac9595295
commit
8c592f8ca1
|
@ -122,6 +122,7 @@ class ReaLModel(nn.Module):
|
|||
config: model_api.ReaLModelConfig,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
hf_model_family: Optional[str] = None,
|
||||
):
|
||||
super().__init__()
|
||||
if dtype is None:
|
||||
|
@ -173,6 +174,14 @@ class ReaLModel(nn.Module):
|
|||
)
|
||||
self.contiguous_param = None
|
||||
|
||||
self.hf_model_family = hf_model_family
|
||||
|
||||
def save_to_hf(self, tokenizer, save_dir):
    """Save the model using the converter of its own HuggingFace family.

    Dispatches to the ``to_<hf_model_family>`` method of this object, where
    the family name is read from ``self.hf_model_family``.

    Args:
        tokenizer: Forwarded unchanged to the family-specific converter.
        save_dir: Forwarded unchanged to the family-specific converter.

    Returns:
        Whatever the family-specific ``to_<family>`` method returns.

    Raises:
        ValueError: If ``self.hf_model_family`` is None.
        AttributeError: If this object has no ``to_<family>`` method.
    """
    family = self.hf_model_family
    if family is None:
        # Without this guard the lookup below would search for "to_None"
        # and fail with an error that hides the real cause.
        raise ValueError(
            "Cannot save to HuggingFace format: hf_model_family is not set."
        )
    converter = getattr(self, f"to_{family}")
    return converter(tokenizer, save_dir)
|
||||
|
||||
def load_from_hf(self, load_dir):
    """Load the model using the converter of its own HuggingFace family.

    Dispatches to the ``from_<hf_model_family>`` method of this object, where
    the family name is read from ``self.hf_model_family``.

    Args:
        load_dir: Forwarded unchanged to the family-specific loader.

    Returns:
        Whatever the family-specific ``from_<family>`` method returns.

    Raises:
        ValueError: If ``self.hf_model_family`` is None.
        AttributeError: If this object has no ``from_<family>`` method.
    """
    family = self.hf_model_family
    if family is None:
        # Without this guard the lookup below would search for "from_None"
        # and fail with an error that hides the real cause.
        raise ValueError(
            "Cannot load from HuggingFace format: hf_model_family is not set."
        )
    loader = getattr(self, f"from_{family}")
    return loader(load_dir)
|
||||
|
||||
@property
|
||||
def pre_process(self):
|
||||
# A workaround to make Megatron-LM backend happy.
|
||||
|
@ -918,12 +927,7 @@ def make_real_model(
|
|||
model_path=model_path,
|
||||
is_critic=is_critic or init_critic_from_actor,
|
||||
)
|
||||
m = ReaLModel(mconfig, dtype=dtype, device=device)
|
||||
|
||||
# Since we load from `hf_model_family`, we should save to `hf_model_family`.
|
||||
# The following line creates a convenient function to save the model.
|
||||
setattr(ReaLModel, "save_to_hf", getattr(ReaLModel, f"to_{hf_model_family}"))
|
||||
setattr(ReaLModel, "load_from_hf", getattr(ReaLModel, f"from_{hf_model_family}"))
|
||||
m = ReaLModel(mconfig, dtype=dtype, device=device, hf_model_family=hf_model_family)
|
||||
|
||||
if not init_from_scratch:
|
||||
m._instantiation_hooks.append(
|
||||
|
|
Loading…
Reference in New Issue