PullRequest: 27 support bf16 training

Merge branch fw/bf16 of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/27

Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* support bf16 training
博惟 2025-03-12 13:04:30 +08:00
parent 9a07ab9460
commit fb23009e99
15 changed files with 23 additions and 36 deletions


@ -68,7 +68,9 @@ async def async_invoke_function(
break
except Exception as e:
logger.error(f"Async invocation failed on attempt {retries + 1}:{str(e)}, URL: {url}, Headers: {session.headers}")
logger.error(
f"Async invocation failed on attempt {retries + 1}:{str(e)}, URL: {url}, Headers: {session.headers}"
)
retries += 1
if retries > max_retries:


@ -25,7 +25,7 @@ def handle(event, context):
answers = event.get("answers", "")
solutions = event.get("solutions", "")
#print(f"math payload:{event}\n")
# print(f"math payload:{event}\n")
# answers and solutions are json lists, and call process_results then collect result into a list
if isinstance(answers, str):
answers = json.loads(answers)


@ -58,7 +58,7 @@ def math_verify(generateds: List, query_ids: List, batch_size=10, timeout=5) ->
"query_ids": [query_ids[i] for i in indices],
}
#print(batch_args)
# print(batch_args)
batch_args_list.append(batch_args)
results_batch = batch_function_call(batch_args_list, "python_math", timeout)


@ -214,10 +214,8 @@ class ModelTrainEvalConfig:
:type path: str
:param gradient_checkpointing: Whether to use gradient checkpointing to save memory.
:type gradient_checkpointing: bool
:param enable_fp16: Whether to use fp16 precision.
:type enable_fp16: bool
:param enable_bf16: Whether to use bf16 precision. Mutually exclusive with fp16.
:type enable_bf16: bool
:param bf16: Whether to use bf16 precision. Otherwise use fp16.
:type bf16: bool
:param offload: Whether to offload model parameters to CPU. Only valid for the DeepSpeed backend.
:type offload: bool
:param parallel: Configuration for parallelism.
@ -236,8 +234,7 @@ class ModelTrainEvalConfig:
)
path: str = ""
gradient_checkpointing: bool = True
enable_fp16: bool = True
enable_bf16: bool = False
bf16: bool = False
offload: bool = False
zero_stage: int = dataclasses.field(
metadata={"choices": [0, 1, 2, 3]},
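Collapsing enable_fp16/enable_bf16 into the single bf16 flag also removes the invalid state where both were enabled, which is why the mutual-exclusion check in the next file disappears. A minimal sketch of the resulting precision selection (PrecisionConfig and train_dtype are illustrative names, not repository code):

import dataclasses

import torch


@dataclasses.dataclass
class PrecisionConfig:
    # Mirrors the new field: True selects bfloat16, otherwise float16 is used.
    bf16: bool = False


def train_dtype(cfg: PrecisionConfig) -> torch.dtype:
    # One flag, one branch: the contradictory fp16+bf16 combination can no
    # longer be expressed, so no extra validation is needed.
    return torch.bfloat16 if cfg.bf16 else torch.float16


assert train_dtype(PrecisionConfig(bf16=True)) is torch.bfloat16
assert train_dtype(PrecisionConfig()) is torch.float16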


@ -58,11 +58,6 @@ def check_valid_backend(role: str, model: ModelTrainEvalConfig):
def check_valid_model_and_path(role: str, model: ModelTrainEvalConfig):
if model.enable_bf16 and model.enable_fp16:
raise ValueError(
f"For model `{role}`, enable_bf16 and" " enable_fp16 cannot be both True."
)
if not os.path.exists(model.path):
raise FileNotFoundError(
f"The model path `{model.path}` for `{role}` does not exist locally. "


@ -606,6 +606,7 @@ class CommonExperimentConfig(Experiment):
"vllm",
args=dict(
model_path=model_cfg.path,
dtype="bfloat16" if model_cfg.bf16 else "float16",
**vllm_dict_args,
),
),
@ -633,7 +634,7 @@ class CommonExperimentConfig(Experiment):
is_critic=model_cfg.type.is_critic,
init_from_scratch=model_cfg.init_from_scratch,
init_critic_from_actor=model_cfg.init_critic_from_actor,
dtype="bf16" if model_cfg.enable_bf16 else "fp16",
dtype="bf16" if model_cfg.bf16 else "fp16",
)
hf_config = transformers.AutoConfig.from_pretrained(
model_cfg.path,
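Both generation and training dtypes are now derived from the same flag, but the two consumers spell it differently: vLLM takes the long names ("bfloat16"/"float16") while the internal model factory takes the short ones ("bf16"/"fp16"). A small illustrative helper, not part of the repository, that keeps the two spellings together:

def dtype_names(bf16: bool) -> tuple:
    # Returns (vllm_dtype, model_dtype).
    return ("bfloat16", "bf16") if bf16 else ("float16", "fp16")


assert dtype_names(True) == ("bfloat16", "bf16")
assert dtype_names(False) == ("float16", "fp16")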


@ -282,8 +282,6 @@ class PPOCODEConfig(CommonExperimentConfig):
"function call address FUNCTIONCALL_SERVICE_DOMAIN is invalid."
)
# interfaces
actor_interface = ModelInterfaceAbstraction(
"ppo_actor",


@ -277,13 +277,12 @@ class PPOMATHConfig(CommonExperimentConfig):
raise RuntimeError(
"Dataset json path REAL_MATH_METADATA_PATH does not exist."
)
domain = os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "")
if domain and (not (domain.startswith("http://") and ":" in domain)):
raise RuntimeError(
"function call address FUNCTIONCALL_SERVICE_DOMAIN is invalid."
)
# interfaces
actor_interface = ModelInterfaceAbstraction(


@ -82,8 +82,7 @@ def make_train_backend_config(
),
offload_optimizer_state=model_cfg.optimizer.offload,
offload_param=model_cfg.offload,
enable_bf16=model_cfg.enable_bf16,
enable_fp16=model_cfg.enable_fp16,
bf16=model_cfg.bf16,
),
)
elif model_cfg.backend == "megatron":
@ -100,8 +99,7 @@ def make_train_backend_config(
return ModelBackendAbstraction(
"megatron",
args=dict(
enable_bf16=model_cfg.enable_bf16,
enable_fp16=model_cfg.enable_fp16,
bf16=model_cfg.bf16,
zero_stage=model_cfg.zero_stage,
optimizer=model_cfg.optimizer,
**megatron_args,


@ -883,8 +883,7 @@ class ReaLMegatronEngine(model_api.PipelinableEngine):
@dataclasses.dataclass
class MegatronTrainBackend(model_api.ModelBackend, MegatronConfig):
enable_fp16: bool = True
enable_bf16: bool = False
bf16: bool = False
zero_stage: int = dataclasses.field(
metadata={"choices": [0, 1, 2, 3]},
default=2,
@ -946,8 +945,7 @@ class MegatronTrainBackend(model_api.ModelBackend, MegatronConfig):
lr = self.optimizer.lr
opt_cfg = MegatronOptimizerConfig(
optimizer=self.optimizer.type,
fp16=self.enable_fp16,
bf16=self.enable_bf16,
bf16=self.bf16,
lr=lr,
min_lr=self.optimizer.min_lr_ratio * lr,
weight_decay=wd,
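For the Megatron backend only bf16 is forwarded to the optimizer config; per the new config docstring, fp16 is the fallback whenever bf16 is off. A hedged sketch of that derivation (OptimizerPrecision is an illustrative dataclass, not Megatron's or the repository's):

import dataclasses


@dataclasses.dataclass
class OptimizerPrecision:
    bf16: bool = False

    @property
    def fp16(self) -> bool:
        # Mixed precision runs in fp16 exactly when bf16 is not requested.
        return not self.bf16


assert OptimizerPrecision(bf16=True).fp16 is False
assert OptimizerPrecision().fp16 is True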


@ -166,6 +166,7 @@ class vLLMGenerationEngine(model_api.PipelinableEngine, LLM):
@dataclasses.dataclass
class vLLMGenerationBackend(vLLMConfig, model_api.ModelBackend):
model_path: str = ""
dtype: str = "bfloat16"
def _initialize(
self, model: model_api.Model, spec: model_api.FinetuneSpec
@ -187,7 +188,7 @@ class vLLMGenerationBackend(vLLMConfig, model_api.ModelBackend):
trust_remote_code=True,
max_model_len=self.max_model_len,
seed=seeding.get_seed(),
dtype=torch.float16,
dtype=getattr(torch, self.dtype),
kv_cache_dtype=self.kv_cache_type,
device=constants.current_device(),
# Parallelism.
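getattr(torch, self.dtype) resolves the configured string ("bfloat16" or "float16") to the corresponding torch dtype object. A minimal sketch of that lookup with a guard against misspelled names (resolve_dtype is an illustrative helper, not from the repository):

import torch


def resolve_dtype(name: str) -> torch.dtype:
    # getattr(torch, "bfloat16") is torch.bfloat16; anything that does not
    # exist or is not a dtype attribute is rejected.
    attr = getattr(torch, name, None)
    if not isinstance(attr, torch.dtype):
        raise ValueError(f"not a torch dtype: {name!r}")
    return attr


assert resolve_dtype("bfloat16") is torch.bfloat16
assert resolve_dtype("float16") is torch.float16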


@ -285,8 +285,7 @@ class HFModelRegistry:
hf_config = self.config_to_hf_converter(model.config)
hf_config.architectures = [self.hf_cls_name]
hf_config.name_or_path = str(save_dir)
# HACK: because currently our interface codes are all written in float16
hf_config.torch_dtype = "float16"
hf_config.torch_dtype = str(model.dtype).strip("torch.")
param_size = sum(
[value.numel() * value.element_size() for value in hf_sd.values()]
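The exported HF config now records the model's actual dtype instead of a hard-coded "float16". Note that str() of a torch dtype keeps the module prefix (str(torch.bfloat16) == "torch.bfloat16"), so the prefix has to be trimmed before writing torch_dtype; a hedged sketch of one way to do that (dtype_name is illustrative, not repository code):

import torch


def dtype_name(dtype: torch.dtype) -> str:
    # "torch.bfloat16" -> "bfloat16", "torch.float16" -> "float16", etc.
    return str(dtype).split(".")[-1]


assert dtype_name(torch.bfloat16) == "bfloat16"
assert dtype_name(torch.float16) == "float16"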


@ -446,7 +446,7 @@ class PPOActorInterface(model_api.ModelInterface):
res = SequenceSample(
keys=["packed_ref_logprobs"],
ids=input_.ids,
dtypes=dict(packed_ref_logprobs=torch.float16),
dtypes=dict(packed_ref_logprobs=model.module.dtype),
trailing_shapes=dict(packed_ref_logprobs=()),
data=dict(packed_ref_logprobs=logprobs),
seqlens=dict(
@ -949,7 +949,7 @@ class PPOCriticInterface(model_api.ModelInterface):
data=dict(packed_input_ids=input_.data["packed_input_ids"]),
)
if self.disable_value:
scores = torch.zeros_like(input_.data["packed_input_ids"]).to(torch.float16)
scores = input_.data["packed_input_ids"].new_zeros(dtype=module.dtype)
else:
scores = module.forward(input_=input_flattend, mb_spec=mb_spec)
@ -964,7 +964,7 @@ class PPOCriticInterface(model_api.ModelInterface):
res = SequenceSample(
keys=["values"],
ids=input_.ids,
dtypes=dict(values=torch.float16),
dtypes=dict(values=module.dtype),
trailing_shapes=dict(values=()),
data=dict(values=scores),
seqlens=dict(values=input_.seqlens["packed_input_ids"]),
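The actor and critic interfaces now declare their output buffers in the module's own dtype rather than a fixed torch.float16, so bf16 models keep their precision end to end. A tiny standalone illustration of the pattern (not the repository's SequenceSample API):

import torch

module_dtype = torch.bfloat16  # stands in for module.dtype at runtime
packed_input_ids = torch.tensor([1, 2, 3, 4])

# Zero scores that match the input's shape but use the module's dtype.
scores = torch.zeros_like(packed_input_ids, dtype=module_dtype)
assert scores.shape == packed_input_ids.shape
assert scores.dtype is torch.bfloat16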


@ -909,7 +909,7 @@ def make_real_model(
elif dtype == "bf16":
dtype = torch.bfloat16
elif dtype == "fp32":
dtype == torch.float32
dtype = torch.float32
else:
raise NotImplementedError(f"Unsupported dtype {dtype}")
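The fp32 branch previously used == instead of =, so dtype stayed a string; the diff fixes the assignment. A dictionary lookup expresses the same mapping more compactly and avoids that class of typo; a sketch assuming only these three names are accepted (parse_dtype is illustrative, not repository code):

import torch

_DTYPES = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}


def parse_dtype(name: str) -> torch.dtype:
    try:
        return _DTYPES[name]
    except KeyError:
        raise NotImplementedError(f"Unsupported dtype {name}") from None


assert parse_dtype("bf16") is torch.bfloat16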


@ -77,8 +77,7 @@ class ProfileLayers:
warmup_steps_proportion=0.0,
min_lr_ratio=0.0,
zero_stage=1,
enable_fp16=True,
enable_bf16=False,
bf16=False,
),
)