mirror of https://github.com/inclusionAI/AReaL
PullRequest 27: support bf16 training

Merge branch fw/bf16 of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/27

Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com>

* support bf16 training

parent 9a07ab9460
commit fb23009e99
@@ -68,7 +68,9 @@ async def async_invoke_function(
                 break

             except Exception as e:
-                logger.error(f"Async invocation failed on attempt {retries + 1}:{str(e)}, URL: {url}, Headers: {session.headers}")
+                logger.error(
+                    f"Async invocation failed on attempt {retries + 1}:{str(e)}, URL: {url}, Headers: {session.headers}"
+                )

                 retries += 1
                 if retries > max_retries:
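For readers skimming the diff: the reformatted logging call sits inside a retry loop. A minimal self-contained sketch of that pattern, assuming an aiohttp session and using illustrative names (invoke_with_retries is not the project's function):

import asyncio
import logging

import aiohttp

logger = logging.getLogger(__name__)


async def invoke_with_retries(url: str, payload: dict, max_retries: int = 3) -> dict:
    # Retry until the call succeeds or the retry budget is exhausted.
    retries = 0
    async with aiohttp.ClientSession() as session:
        while True:
            try:
                async with session.post(url, json=payload) as resp:
                    resp.raise_for_status()
                    return await resp.json()
            except Exception as e:
                logger.error(
                    f"Async invocation failed on attempt {retries + 1}: {e}, URL: {url}"
                )
                retries += 1
                if retries > max_retries:
                    raise
                await asyncio.sleep(2 ** retries)  # simple exponential backoff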
@@ -25,7 +25,7 @@ def handle(event, context):
     answers = event.get("answers", "")
     solutions = event.get("solutions", "")

-    #print(f"math payload:{event}\n")
+    # print(f"math payload:{event}\n")
     # answers and solutions are json lists, and call process_results then collect result into a list
     if isinstance(answers, str):
         answers = json.loads(answers)
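The surrounding handler tolerates both JSON-encoded strings and already-decoded lists. A self-contained sketch of that normalization, with a stand-in return value where the real code calls process_results:

import json


def handle(event: dict, context=None):
    # answers/solutions may arrive as JSON-encoded strings or as lists;
    # normalize both before verification.
    answers = event.get("answers", "")
    solutions = event.get("solutions", "")
    if isinstance(answers, str):
        answers = json.loads(answers)
    if isinstance(solutions, str):
        solutions = json.loads(solutions)
    return list(zip(answers, solutions))  # stand-in for process_results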
@@ -58,7 +58,7 @@ def math_verify(generateds: List, query_ids: List, batch_size=10, timeout=5) ->
             "query_ids": [query_ids[i] for i in indices],
         }

-        #print(batch_args)
+        # print(batch_args)
         batch_args_list.append(batch_args)

     results_batch = batch_function_call(batch_args_list, "python_math", timeout)
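This hunk sits in loop code that slices parallel lists into per-batch argument dicts for batch_function_call. A hedged sketch of that batching shape; the "answers" key is an assumption for illustration, only "query_ids" is visible in the hunk:

from typing import Any, Dict, List


def build_batches(generateds: List[str], query_ids: List[str], batch_size: int = 10) -> List[Dict[str, Any]]:
    # Group parallel lists into per-batch argument dicts, mirroring the
    # batch_args_list construction above.
    batch_args_list = []
    for start in range(0, len(query_ids), batch_size):
        indices = range(start, min(start + batch_size, len(query_ids)))
        batch_args_list.append(
            {
                "answers": [generateds[i] for i in indices],
                "query_ids": [query_ids[i] for i in indices],
            }
        )
    return batch_args_list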
@@ -214,10 +214,8 @@ class ModelTrainEvalConfig:
     :type path: str
     :param gradient_checkpointing: Whether to use gradient checkpointing to save memory.
     :type gradient_checkpointing: bool
-    :param enable_fp16: Whether to use fp16 precision.
-    :type enable_fp16: bool
-    :param enable_bf16: Whether to use bf16 precision. Mutually exclusive with fp16.
-    :type enable_bf16: bool
+    :param bf16: Whether to use bf16 precision. Otherwise use fp16.
+    :type bf16: bool
     :param offload: Whether to offload model parameters to CPU. Only valid for the DeepSpeed backend.
     :type offload: bool
     :param parallel: Configuration for parallelism.
@@ -236,8 +234,7 @@ class ModelTrainEvalConfig:
     )
     path: str = ""
     gradient_checkpointing: bool = True
-    enable_fp16: bool = True
-    enable_bf16: bool = False
+    bf16: bool = False
     offload: bool = False
     zero_stage: int = dataclasses.field(
         metadata={"choices": [0, 1, 2, 3]},
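The two config hunks above collapse the enable_fp16/enable_bf16 pair into a single bf16 flag, which makes the contradictory both-True state unrepresentable. A minimal standalone sketch of the idea (PrecisionConfig is illustrative, not AReaL code):

import dataclasses

import torch


@dataclasses.dataclass
class PrecisionConfig:
    # A single flag replaces the old enable_fp16/enable_bf16 pair, so the
    # "both True" misconfiguration can no longer be expressed.
    bf16: bool = False

    @property
    def dtype(self) -> torch.dtype:
        return torch.bfloat16 if self.bf16 else torch.float16


cfg = PrecisionConfig(bf16=True)
assert cfg.dtype is torch.bfloat16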
@@ -58,11 +58,6 @@ def check_valid_backend(role: str, model: ModelTrainEvalConfig):


 def check_valid_model_and_path(role: str, model: ModelTrainEvalConfig):
-    if model.enable_bf16 and model.enable_fp16:
-        raise ValueError(
-            f"For model `{role}`, enable_bf16 and" " enable_fp16 cannot be both True."
-        )
-
     if not os.path.exists(model.path):
         raise FileNotFoundError(
             f"The model path `{model.path}` for `{role}` does not exist locally. "
@@ -606,6 +606,7 @@ class CommonExperimentConfig(Experiment):
                 "vllm",
                 args=dict(
                     model_path=model_cfg.path,
+                    dtype="bfloat16" if model_cfg.bf16 else "float16",
                     **vllm_dict_args,
                 ),
             ),
@@ -633,7 +634,7 @@ class CommonExperimentConfig(Experiment):
             is_critic=model_cfg.type.is_critic,
             init_from_scratch=model_cfg.init_from_scratch,
             init_critic_from_actor=model_cfg.init_critic_from_actor,
-            dtype="bf16" if model_cfg.enable_bf16 else "fp16",
+            dtype="bf16" if model_cfg.bf16 else "fp16",
         )
         hf_config = transformers.AutoConfig.from_pretrained(
             model_cfg.path,
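Note that the one flag now feeds two naming conventions: vLLM takes "bfloat16"/"float16" while the ReaL model factory takes "bf16"/"fp16". A small sketch pinning down the mapping (resolve_dtype_strings is illustrative):

import torch


def resolve_dtype_strings(bf16: bool) -> dict:
    # The same flag feeds two naming conventions: vLLM expects the long
    # torch-style names, while the ReaL model factory uses short ones.
    return {
        "vllm": "bfloat16" if bf16 else "float16",
        "real": "bf16" if bf16 else "fp16",
    }


assert resolve_dtype_strings(True)["vllm"] == str(torch.bfloat16).split(".")[-1]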
@@ -282,8 +282,6 @@ class PPOCODEConfig(CommonExperimentConfig):
                 "function call address FUNCTIONCALL_SERVICE_DOMAIN is invalid."
             )

-
-
         # interfaces
         actor_interface = ModelInterfaceAbstraction(
             "ppo_actor",
@@ -277,13 +277,12 @@ class PPOMATHConfig(CommonExperimentConfig):
             raise RuntimeError(
                 "Dataset json path REAL_MATH_METADATA_PATH does not exist."
             )
-
         domain = os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "")
         if domain and (not (domain.startswith("http://") and ":" in domain)):
             raise RuntimeError(
                 "function call address FUNCTIONCALL_SERVICE_DOMAIN is invalid."
             )


         # interfaces
         actor_interface = ModelInterfaceAbstraction(
@@ -82,8 +82,7 @@ def make_train_backend_config(
             ),
             offload_optimizer_state=model_cfg.optimizer.offload,
             offload_param=model_cfg.offload,
-            enable_bf16=model_cfg.enable_bf16,
-            enable_fp16=model_cfg.enable_fp16,
+            bf16=model_cfg.bf16,
         ),
     )
 elif model_cfg.backend == "megatron":
@@ -100,8 +99,7 @@ def make_train_backend_config(
     return ModelBackendAbstraction(
         "megatron",
         args=dict(
-            enable_bf16=model_cfg.enable_bf16,
-            enable_fp16=model_cfg.enable_fp16,
+            bf16=model_cfg.bf16,
             zero_stage=model_cfg.zero_stage,
             optimizer=model_cfg.optimizer,
             **megatron_args,
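Both backend branches now forward the single bf16 switch. A hedged sketch of the factory shape; the argument names other than bf16 and zero_stage are invented for illustration:

def make_backend_args(backend: str, bf16: bool, zero_stage: int = 2) -> dict:
    # Both training backends receive the single bf16 switch instead of
    # the old enable_bf16/enable_fp16 pair.
    common = {"bf16": bf16, "zero_stage": zero_stage}
    if backend == "deepspeed":
        return {"offload_param": False, **common}
    if backend == "megatron":
        return {"optimizer": "adam", **common}
    raise NotImplementedError(f"Unsupported backend {backend}")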
@@ -883,8 +883,7 @@ class ReaLMegatronEngine(model_api.PipelinableEngine):

 @dataclasses.dataclass
 class MegatronTrainBackend(model_api.ModelBackend, MegatronConfig):
-    enable_fp16: bool = True
-    enable_bf16: bool = False
+    bf16: bool = False
     zero_stage: int = dataclasses.field(
         metadata={"choices": [0, 1, 2, 3]},
         default=2,
@@ -946,8 +945,7 @@ class MegatronTrainBackend(model_api.ModelBackend, MegatronConfig):
         lr = self.optimizer.lr
         opt_cfg = MegatronOptimizerConfig(
             optimizer=self.optimizer.type,
-            fp16=self.enable_fp16,
-            bf16=self.enable_bf16,
+            bf16=self.bf16,
             lr=lr,
             min_lr=self.optimizer.min_lr_ratio * lr,
             weight_decay=wd,
@@ -166,6 +166,7 @@ class vLLMGenerationEngine(model_api.PipelinableEngine, LLM):
 @dataclasses.dataclass
 class vLLMGenerationBackend(vLLMConfig, model_api.ModelBackend):
     model_path: str = ""
+    dtype: str = "bfloat16"

     def _initialize(
         self, model: model_api.Model, spec: model_api.FinetuneSpec
@@ -187,7 +188,7 @@ class vLLMGenerationBackend(vLLMConfig, model_api.ModelBackend):
             trust_remote_code=True,
             max_model_len=self.max_model_len,
             seed=seeding.get_seed(),
-            dtype=torch.float16,
+            dtype=getattr(torch, self.dtype),
             kv_cache_dtype=self.kv_cache_type,
             device=constants.current_device(),
             # Parallelism.
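getattr(torch, self.dtype) resolves the configured string to a real torch.dtype. A defensive sketch of that resolution (parse_torch_dtype and its validation are illustrative additions, not part of the PR):

import torch


def parse_torch_dtype(name: str) -> torch.dtype:
    # getattr(torch, "bfloat16") -> torch.bfloat16, as in the hunk above;
    # validate so a typo fails loudly instead of returning a non-dtype.
    attr = getattr(torch, name, None)
    if not isinstance(attr, torch.dtype):
        raise ValueError(f"Not a torch dtype: {name}")
    return attr


assert parse_torch_dtype("bfloat16") is torch.bfloat16
assert parse_torch_dtype("float16") is torch.float16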
@@ -285,8 +285,7 @@ class HFModelRegistry:
         hf_config = self.config_to_hf_converter(model.config)
         hf_config.architectures = [self.hf_cls_name]
         hf_config.name_or_path = str(save_dir)
-        # HACK: because currently our interface codes are all written in float16
-        hf_config.torch_dtype = "float16"
+        hf_config.torch_dtype = str(model.dtype).strip("torch.")

         param_size = sum(
             [value.numel() * value.element_size() for value in hf_sd.values()]
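One caveat worth recording: str.strip("torch.") strips characters from the set {t, o, r, c, h, .}, not the literal prefix. It happens to be safe for the dtype names in play here, as this runnable check shows:

import torch

# str.strip("torch.") removes *characters* from the set {t, o, r, c, h, .},
# not the literal "torch." prefix. It works for the dtype names used here:
assert str(torch.float16).strip("torch.") == "float16"
assert str(torch.bfloat16).strip("torch.") == "bfloat16"

# A prefix-safe alternative with the same result:
assert str(torch.bfloat16).split(".")[-1] == "bfloat16"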
@@ -446,7 +446,7 @@ class PPOActorInterface(model_api.ModelInterface):
         res = SequenceSample(
             keys=["packed_ref_logprobs"],
             ids=input_.ids,
-            dtypes=dict(packed_ref_logprobs=torch.float16),
+            dtypes=dict(packed_ref_logprobs=model.module.dtype),
             trailing_shapes=dict(packed_ref_logprobs=()),
             data=dict(packed_ref_logprobs=logprobs),
             seqlens=dict(
@@ -949,7 +949,7 @@ class PPOCriticInterface(model_api.ModelInterface):
             data=dict(packed_input_ids=input_.data["packed_input_ids"]),
         )
         if self.disable_value:
-            scores = torch.zeros_like(input_.data["packed_input_ids"]).to(torch.float16)
+            scores = input_.data["packed_input_ids"].new_zeros(dtype=module.dtype)
         else:
             scores = module.forward(input_=input_flattend, mb_spec=mb_spec)

@@ -964,7 +964,7 @@ class PPOCriticInterface(model_api.ModelInterface):
         res = SequenceSample(
             keys=["values"],
             ids=input_.ids,
-            dtypes=dict(values=torch.float16),
+            dtypes=dict(values=module.dtype),
             trailing_shapes=dict(values=()),
             data=dict(values=scores),
             seqlens=dict(values=input_.seqlens["packed_input_ids"]),
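The interface hunks above replace hard-coded torch.float16 buffers with the module's own dtype, so bf16 training produces bf16 tensors end to end. A tiny sketch of the principle; a plain nn.Linear stands in for the project's engine, which exposes .dtype directly:

import torch

# Deriving buffer dtypes from the live module keeps bf16 runs consistent,
# instead of pinning everything to float16 as before.
module = torch.nn.Linear(4, 1, dtype=torch.bfloat16)
values = torch.zeros(8, dtype=module.weight.dtype)
assert values.dtype is torch.bfloat16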
@@ -909,7 +909,7 @@ def make_real_model(
     elif dtype == "bf16":
         dtype = torch.bfloat16
     elif dtype == "fp32":
-        dtype == torch.float32
+        dtype = torch.float32
     else:
         raise NotImplementedError(f"Unsupported dtype {dtype}")

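Beyond fixing the `dtype == torch.float32` typo (a comparison where an assignment was intended, which silently left dtype as the string "fp32"), a lookup table would make that class of bug harder to write. A hedged alternative sketch, not what the PR does:

import torch

# A table lookup avoids the chain of elif branches where the fixed
# `dtype == torch.float32` typo could hide.
_DTYPE_MAP = {
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
    "fp32": torch.float32,
}


def to_torch_dtype(name: str) -> torch.dtype:
    try:
        return _DTYPE_MAP[name]
    except KeyError:
        raise NotImplementedError(f"Unsupported dtype {name}") from None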
@@ -77,8 +77,7 @@ class ProfileLayers:
             warmup_steps_proportion=0.0,
             min_lr_ratio=0.0,
             zero_stage=1,
-            enable_fp16=True,
-            enable_bf16=False,
+            bf16=False,
         ),
     )
