mirror of https://github.com/inclusionAI/AReaL
reformatted
parent eff8f09149
commit 6bde86a934

@@ -690,7 +690,7 @@ class ClusterSpecConfig:
 @dataclass
 class DatasetConfig:
-    path: str =field(
+    path: str = field(
         default=MISSING,
         metadata={
             "help": "Path to the dataset. Can be a local path or a HuggingFace dataset name."
@@ -47,14 +47,17 @@ class LLMResponse:
     @property
     def output_len(self) -> int:
         return len(self.output_tokens)
 
+
 @dataclass
 class VLMRequest(LLMRequest):
-    image_data: Optional[List[ImageObject|str]] = field(default_factory=list)
+    image_data: Optional[List[ImageObject | str]] = field(default_factory=list)
 
+
 @dataclass
 class VLMResponse(LLMResponse):
-    input_images: List[ImageObject|str] = field(default_factory=list)
+    input_images: List[ImageObject | str] = field(default_factory=list)
 
+
 @dataclass
 class FinetuneSpec:
@@ -142,6 +145,7 @@ class AllocationMode:
         raise ValueError(
             f"Unknown how to resolve parallelism strategy: {allocation_mode}"
         )
+
     @staticmethod
     def extract_decoupled_alloc(allocation_mode: str) -> Dict:
         pattern = re.compile(
@@ -4,36 +4,39 @@ import transformers
 VALID_DATASETS = ["gsm8k", "clevr_count_70k"]
 
 
 def get_custom_dataset(
     path: str,
     rank: int,
     world_size: int,
-    training_type: str= "sft",
+    training_type: str = "sft",
     split: Optional[str] = None,
     tokenizer: Optional[transformers.PreTrainedTokenizerFast] = None,
     processor: Optional[transformers.AutoProcessor] = None,
 ):
     if "gsm8k" in path and training_type == "sft":
         from examples.arealite.dataset.gsm8k import get_gsm8k_sft_dataset
 
         return get_gsm8k_sft_dataset(path, split, tokenizer, rank, world_size)
     elif "gsm8k" in path and training_type == "rl":
         from examples.arealite.dataset.gsm8k import get_gsm8k_rl_dataset
 
         return get_gsm8k_rl_dataset(path, split, rank, world_size)
     elif "clevr_count_70k" in path and training_type == "sft":
         from examples.arealite.dataset.clevr_count_70k import (
             get_clevr_count_70k_sft_dataset,
         )
 
         return get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size)
     elif "clevr_count_70k" in path and training_type == "rl":
         from examples.arealite.dataset.clevr_count_70k import (
             get_clevr_count_70k_rl_dataset,
         )
-        return get_clevr_count_70k_rl_dataset(path, split,processor, rank, world_size)
+
+        return get_clevr_count_70k_rl_dataset(path, split, processor, rank, world_size)
     else:
         raise ValueError(
             f"Dataset {path} with split {split} and training type {training_type} is not supported. "
             f"Supported datasets are: {VALID_DATASETS}. "
         )
@@ -55,7 +55,7 @@ class BaseHFEngine(TrainEngine):
         self.own_global_group = False
         self._parallelism_group: dist.ProcessGroup
         self.weight_update_group_initialized = False
 
         self.model_config = AutoConfig.from_pretrained(
             pretrained_model_name_or_path=self.config.path,
             trust_remote_code=True,
@@ -96,8 +96,10 @@ class BaseHFEngine(TrainEngine):
         dtype = getattr(torch, self.config.dtype)
 
         if self.is_vision_model:
-            dtype = torch.bfloat16
-            self.processor, self.tokenizer = load_hf_processor_and_tokenizer(self.config.path)
+            dtype = torch.bfloat16
+            self.processor, self.tokenizer = load_hf_processor_and_tokenizer(
+                self.config.path
+            )
 
         tik = time.perf_counter()
         with torch.device("cuda"):
@@ -132,7 +134,7 @@ class BaseHFEngine(TrainEngine):
         )
         if self.config.disable_dropout:
             disable_dropout_in_model(model)
 
         if self.config.gradient_checkpointing:
             model.gradient_checkpointing_enable(
                 gradient_checkpointing_kwargs={"use_reentrant": False}
@@ -231,14 +233,12 @@ class BaseHFEngine(TrainEngine):
         assert self.lr_scheduler is not None
         self.lr_scheduler.step()
 
-
-
     def prepare_mb_list(self, input_: TensorDict) -> MicroBatchList:
         assert "attention_mask" in input_ and "input_ids" in input_
         if self.is_vision_model:
-            assert "pixel_values" in input_ and "image_grid_thw" in input_, (
-                "For vision-language models, pixel_values and image_grid_thw must be present in input_"
-            )
+            assert (
+                "pixel_values" in input_ and "image_grid_thw" in input_
+            ), "For vision-language models, pixel_values and image_grid_thw must be present in input_"
 
         if isinstance(input_, dict):
             input_ = TensorDict(input_, batch_size=[input_["input_ids"].shape[0]])
@@ -303,7 +303,6 @@ class BaseHFEngine(TrainEngine):
 
         loss *= loss_scale
 
-
         grad_norm = torch.nn.utils.clip_grad_norm_(
             self.model.parameters(),
             self.optimizer_config.gradient_clipping,
@@ -49,7 +49,6 @@ class FSDPEngine(BaseHFEngine):
 
         self.create_process_group()
         self.create_device_model()
 
-
         # Wrap with FSDP2
         # Simple auto wrap policy
@@ -101,7 +100,10 @@ class FSDPEngine(BaseHFEngine):
             self.load_optimizer_state(meta.path)
 
     def _save_model_to_hf(
-        self, path: str, tokenizer: Optional[PreTrainedTokenizerFast], processor: Optional[AutoProcessor]
+        self,
+        path: str,
+        tokenizer: Optional[PreTrainedTokenizerFast],
+        processor: Optional[AutoProcessor],
     ):
         """Save model in HuggingFace format."""
         if self.model is None:
@@ -147,7 +149,7 @@ class FSDPEngine(BaseHFEngine):
             dist.barrier()
             torch.cuda.synchronize()
         elif meta.type == "disk":
-            self._save_model_to_hf(meta.path, self.tokenizer,self.processor)
+            self._save_model_to_hf(meta.path, self.tokenizer, self.processor)
             # dist.barrier() are called when _save_model_to_hf finished
             if dist.get_rank() == 0:
                 update_name = names.update_weights_from_disk(
@@ -240,7 +242,6 @@ class FSDPEngine(BaseHFEngine):
 
         loss *= loss_scale
         loss.backward()
 
-
         # NOTE: grad norm clip function is different
 
@@ -257,8 +257,6 @@ class FSDPPPOActor(FSDPEngine):
         return self.actor.ppo_update(*args, **kwargs)
 
 
-
-
 def grpo_loss_fn(
     logits: torch.Tensor,
     input_data: Dict,
@@ -41,6 +41,7 @@ class FSDPLMEngine(FSDPEngine):
     def evaluate_lm(self, data):
         return self.lm_engine.evaluate_lm(data)
 
+
 def compute_packed_sft_loss(logits: torch.Tensor, input_: TensorDict) -> torch.Tensor:
     packed_input_ids: torch.Tensor = input_["input_ids"]
     cu_seqlens: torch.Tensor = input_["cu_seqlens"]
@@ -49,7 +50,7 @@ def compute_packed_sft_loss(logits: torch.Tensor, input_: TensorDict) -> torch.Tensor:
     logprobs = gather_logprobs(logits, torch.roll(packed_input_ids, shifts=-1, dims=-1))
     loss_mask = torch.roll(loss_mask, shifts=-1, dims=-1)
     logprobs = torch.where(loss_mask, logprobs, 0)
 
     loss = -logprobs.sum() / loss_mask.count_nonzero()
     with torch.no_grad():
         seqlogp = torch.zeros(
@@ -228,7 +228,9 @@ class RemoteSGLangEngine(InferenceEngine):
             return server
         raise NotImplementedError("Only round-robin scheduling is implemented.")
 
-    async def agenerate(self, req: LLMRequest|VLMRequest) -> LLMResponse|VLMResponse:
+    async def agenerate(
+        self, req: LLMRequest | VLMRequest
+    ) -> LLMResponse | VLMResponse:
         """Async version of generate using aiohttp."""
         # Prepare request payload
         gconfig = req.gconfig
@@ -318,28 +320,28 @@ class RemoteSGLangEngine(InferenceEngine):
             sample_params["max_new_tokens"] -= len(output_tokens)
 
         latency = time.perf_counter() - start_time
 
         if isinstance(req, VLMRequest):
             response = VLMResponse(
                 input_tokens=req.input_ids,
                 input_images=req.image_data,
                 output_tokens=accumulated_output_tokens,
                 output_logprobs=accumulated_output_logprobs,
                 output_versions=accumulated_versions,
                 stop_reason=stop_reason,
                 latency=latency,
                 ttft=latency,  # Simplified for non-streaming
             )
         else:
-            response=LLMResponse(
+            response = LLMResponse(
                 input_tokens=req.input_ids,
                 output_tokens=accumulated_output_tokens,
                 output_logprobs=accumulated_output_logprobs,
                 output_versions=accumulated_versions,
                 stop_reason=stop_reason,
                 latency=latency,
                 ttft=latency,  # Simplified for non-streaming
             )
         return response
 
     def update_weights(self, meta):
@@ -526,7 +528,7 @@ class RemoteSGLangEngine(InferenceEngine):
         ):
             try:
                 data = next(self.data_generator)
 
             except StopIteration:
                 self.data_generator = iter(dataloader)
                 data = next(self.data_generator)
@@ -65,8 +65,9 @@ def pad_sequences_to_tensors(
         return TensorDict()
     skip_keys = {"pixel_values", "image_grid_thw"}
     max_length = max(
-        len(seq) for item in sequence_list
-        for key, seq in item.items()
+        len(seq)
+        for item in sequence_list
+        for key, seq in item.items()
         if key not in skip_keys
     )
     result = {}
@@ -79,14 +80,18 @@ def pad_sequences_to_tensors(
             x = item[key]
             if not torch.is_tensor(x):
                 x = torch.tensor(x)
-            padded_x=torch.nn.functional.pad(
-                x, (0, max_length - len(item[key])), value=pad_value
-            )
+            padded_x = torch.nn.functional.pad(
+                x, (0, max_length - len(item[key])), value=pad_value
+            )
             padded.append(padded_x)
         result[key] = torch.stack(padded)
     attention_mask = [
         [1] * len(next(iter(item[key] for key in item.keys() if key not in skip_keys)))
-        + [0] * (max_length - len(next(iter(item[key] for key in item.keys() if key not in skip_keys))))
+        + [0]
+        * (
+            max_length
+            - len(next(iter(item[key] for key in item.keys() if key not in skip_keys)))
+        )
         for item in sequence_list
     ]
     result["attention_mask"] = torch.tensor(attention_mask, dtype=torch.bool)
@@ -139,7 +144,7 @@ def concat_padded_tensors(
                 tensors_to_concat.append(tensor)
                 continue
             current_length = tensor.shape[1]
-            if key == "pixel_values" or key== "image_grid_thw":
+            if key == "pixel_values" or key == "image_grid_thw":
                 tensors_to_concat.append(tensor)
                 continue
             if current_length < max_length:
@@ -150,7 +155,7 @@ def concat_padded_tensors(
                     padding = torch.zeros(
                         (tensor.shape[0], pad_width), dtype=tensor.dtype
                     )
 
                 else:
                     # Pad feature tensors with pad_value
                     padding = torch.full(
@@ -323,7 +328,7 @@ def split_padded_tensor_dict_into_mb_list(
     to_split = {}
     not_to_split = {}
     for key, value in data.items():
-        if key=="image_grid_thw" or key=="pixel_values":
+        if key == "image_grid_thw" or key == "pixel_values":
             continue
         if not torch.is_tensor(value) or value.numel() != bs * max_seqlen:
             not_to_split[key] = value
@@ -368,7 +373,7 @@ def split_padded_tensor_dict_into_mb_list(
 
     for group_index in group_indices:
         group_pixel_values = [pixel_values[i] for i in group_index]
-        group_image_grid_thw = [image_grid_thw[i].squeeze()for i in group_index]
+        group_image_grid_thw = [image_grid_thw[i].squeeze() for i in group_index]
 
         # Stack pixel_values for each group (assuming pixel_values is a list of tensors)
         pixel_values_split.append(torch.stack(group_pixel_values))
@@ -46,7 +46,7 @@ def fsdp2_clip_grad_norm_(
     grads = [p.grad for p in parameters if p.grad is not None]
     total_norm = _get_total_norm(grads, norm_type, error_if_nonfinite, foreach)
     total_norm = total_norm.to(torch.cuda.current_device(), non_blocking=True)
 
     _clip_grads_with_norm_(parameters, max_norm, total_norm, foreach)
     return total_norm
 
@@ -7,20 +7,17 @@ from typing import List
 from PIL.Image import Image as ImageObject
 
 
-def image2base64(images: List[ImageObject]|ImageObject)-> List[str]|str:
+def image2base64(images: List[ImageObject] | ImageObject) -> List[str] | str:
 
     if isinstance(images, ImageObject):
         images = [images]
 
     byte_images = []
     for image in images:
         with BytesIO() as buffer:
             image.save(buffer, format="PNG")
             buffer.seek(0)
-            byte_image = base64.b64encode(buffer.read()).decode('utf-8')
+            byte_image = base64.b64encode(buffer.read()).decode("utf-8")
             byte_images.append(byte_image)
 
     return byte_images
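
A minimal usage sketch of the image2base64 helper in the hunk above; the image here is hypothetical and the helper is assumed to be in scope:

from PIL import Image

# Hypothetical 4x4 image; image2base64 is the helper defined in the hunk above.
img = Image.new("RGB", (4, 4), color=(255, 0, 0))
encoded = image2base64(img)  # a single image is wrapped into a list internally
assert isinstance(encoded, list) and isinstance(encoded[0], str)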
@@ -5,6 +5,7 @@ VALID_VISION_MODELS = [
     "qwen2_5_vl",
 ]
 
+
 # Copied from trl
 def disable_dropout_in_model(model: torch.nn.Module) -> None:
     for module in model.modules():
@@ -34,7 +34,7 @@ class VL_RLVRWorkflow(RLVRWorkflow):
             return_tensors="pt",
         )
 
-        input_ids=processed_input["input_ids"].tolist()[0]
+        input_ids = processed_input["input_ids"].tolist()[0]
 
         n_samples = self.gconfig.n_samples
 
@@ -62,13 +62,19 @@ class VL_RLVRWorkflow(RLVRWorkflow):
                 completion_ids=resp.output_tokens,
                 **data,
             )
 
             res = dict(
                 # unsqueeze to add an additional batch dimension
                 input_ids=torch.tensor(seq).unsqueeze(0),
                 loss_mask=torch.tensor(loss_mask).unsqueeze(0),
-                pixel_values=processed_input["pixel_values"].clone().detach().unsqueeze(0),
-                image_grid_thw=processed_input["image_grid_thw"].clone().detach().unsqueeze(0),
+                pixel_values=processed_input["pixel_values"]
+                .clone()
+                .detach()
+                .unsqueeze(0),
+                image_grid_thw=processed_input["image_grid_thw"]
+                .clone()
+                .detach()
+                .unsqueeze(0),
                 logprobs=torch.tensor(logprobs).unsqueeze(0),
                 versions=torch.tensor(versions).unsqueeze(0),
                 attention_mask=torch.ones(len(seq), dtype=torch.bool).unsqueeze(0),
@@ -24,7 +24,7 @@ from realhf.base import stats_tracker
 def extract_answer(pred_str, data_name, use_last_number=True):
     match = re.findall(r"\[([0-9\.]+)\]", pred_str)
     if match:
         return match[-1]
 
     return ""
 
@@ -56,32 +56,35 @@ def extract_solution(solution_str, method="strict") -> str | None:
             break
     return final_answer
 
-def clevr_count_70k_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **kwargs):
+
+def clevr_count_70k_reward_fn(
+    prompt, completions, prompt_ids, completion_ids, answer, **kwargs
+):
     is_thinking = "thinking" in completions.lower()
 
     sol = extract_answer(completions, data_name="")  # str number
-    ans =answer
+    ans = answer
 
     if sol is None:
         return 0
     if ans is None:
         return 0
 
     if sol.strip() == ans.strip():
         print(f"completions: {completions}, answer: {answer}")
         if is_thinking:
             return 1
         else:
             return 1
 
     if re.match(r"^\[\d+(\.\d+)?\]$", sol.strip()):
         return 0.05
 
     return 0
 
 
 def main(args):
-    os.environ["WANDB_API_KEY"]=""
+    os.environ["WANDB_API_KEY"] = ""
     wandb.init(project="clevr_70k")
 
     config, _ = load_expr_config(args, GRPOConfig)
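
A hypothetical call illustrating the scoring rule of clevr_count_70k_reward_fn above; the completion text and the bare-number answer format are assumptions, not values taken from the dataset:

# Assumes clevr_count_70k_reward_fn and extract_answer from the hunk above are in scope.
score = clevr_count_70k_reward_fn(
    prompt="How many cubes are there?",
    completions="Counting the objects gives [3].",
    prompt_ids=None,
    completion_ids=None,
    answer="3",
)
# extract_answer pulls "3" out of "[3]"; an exact match with `answer` scores 1,
# otherwise the function falls through to the 0.05 / 0 branches.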
@@ -90,22 +93,22 @@ def main(args):
     rank = int(os.getenv("RANK"))
     world_size = int(os.getenv("WORLD_SIZE"))
     processor, tokenizer = load_hf_processor_and_tokenizer(config.tokenizer_path)
-    train_dataset=get_custom_dataset(
-        path=config.train_dataset.path,
-        rank=rank,
-        world_size=world_size,
-        split="train",
-        training_type="rl",
-        processor=processor
-    )
-    valid_dataset=get_custom_dataset(
-        path=config.valid_dataset.path,
-        rank=rank,
-        world_size=world_size,
-        split="test",
-        training_type="rl",
-        processor=processor
-    )
+    train_dataset = get_custom_dataset(
+        path=config.train_dataset.path,
+        rank=rank,
+        world_size=world_size,
+        split="train",
+        training_type="rl",
+        processor=processor,
+    )
+    valid_dataset = get_custom_dataset(
+        path=config.valid_dataset.path,
+        rank=rank,
+        world_size=world_size,
+        split="test",
+        training_type="rl",
+        processor=processor,
+    )
     # Create dataset and dataloaders
     train_dataloader = StatefulDataLoader(
         train_dataset,
@@ -142,7 +145,7 @@ def main(args):
     actor.initialize(None, ft_spec)
     ref = None
     if config.actor.kl_ctl > 0 and config.ref is not None:
-        ref =FSDPPPOActor(config=config.ref)
+        ref = FSDPPPOActor(config=config.ref)
         ref.initialize(None, ft_spec)
 
     # Create rollout workflow
@@ -200,7 +203,7 @@ def main(args):
         if ref is not None:
             with stats_tracker.record_timing("ref_logp"):
                 batch["ref_logp"] = ref.compute_logp(batch)
 
                 log_gpu_stats("ref logp")
 
         with stats_tracker.record_timing("compute_advantage"):
@@ -212,8 +215,8 @@ def main(args):
             stats_tracker.scope("grpo_actor"),
         ):
             stats = actor.ppo_update(batch)
-            wandb.log({"actor_reward": stats[0]['grpo_actor/final_reward/avg']})
+            wandb.log({"actor_reward": stats[0]["grpo_actor/final_reward/avg"]})
 
             actor.step_lr_scheduler()
             log_gpu_stats("ppo update")
 
@@ -252,7 +255,14 @@ def main(args):
                 cnt += 1
             batch = eval_rollout.wait(cnt, timeout=None)
             rewards = batch["rewards"].float().to(actor.device)
-            wandb.log({"eval_reward": rewards.mean().item(), "epoch": epoch, "step": step, "global_step": global_step})
+            wandb.log(
+                {
+                    "eval_reward": rewards.mean().item(),
+                    "epoch": epoch,
+                    "step": step,
+                    "global_step": global_step,
+                }
+            )
             with stats_tracker.scope("grpo-eval"):
                 stats_tracker.denominator(
                     n_seqs=torch.ones(
@@ -281,6 +291,6 @@ def main(args):
     actor.destroy()
     wandb.finish()
 
 
 if __name__ == "__main__":
     main(sys.argv[1:])
@@ -22,25 +22,25 @@ def main_sft():
     rank = int(os.getenv("RANK"))
     world_size = int(os.getenv("WORLD_SIZE"))
     processor, tokenizer = load_hf_processor_and_tokenizer(config.tokenizer_path)
-    train_dataset=get_custom_dataset(
+    train_dataset = get_custom_dataset(
         path=config.train_dataset.path,
         rank=rank,
         world_size=world_size,
         split="train",
         training_type="sft",
         tokenizer=tokenizer,
         processor=processor,
     )
-    valid_dataset=get_custom_dataset(
+    valid_dataset = get_custom_dataset(
         path=config.valid_dataset.path,
         rank=rank,
         world_size=world_size,
         split="test",
         training_type="sft",
         tokenizer=tokenizer,
         processor=processor,
     )
 
     # Create dataset and dataloaders
     train_dataloader = StatefulDataLoader(
         train_dataset,
@@ -9,20 +9,31 @@ from datasets.distributed import split_dataset_by_node
 from PIL.Image import Image as ImageObject
 
 
-def input_text(text:str):
+def input_text(text: str):
     return {"type": "input_text", "text": text}
 
 
 def input_image(base64_image: str):
-    return {"type": "input_image", "image_url": f"data:image/jpeg;base64,{base64_image}"}
-def build_raw_message(sample: Dict[str, Any], base64_images: list[str]) -> list[Dict[str, Any]]:
+    return {
+        "type": "input_image",
+        "image_url": f"data:image/jpeg;base64,{base64_image}",
+    }
+
+
+def build_raw_message(
+    sample: Dict[str, Any], base64_images: list[str]
+) -> list[Dict[str, Any]]:
     raw_message = []
-    problem_parts = [part.strip() for part in sample["problem"].split("<image>") if part.strip()]
+    problem_parts = [
+        part.strip() for part in sample["problem"].split("<image>") if part.strip()
+    ]
     insert_list = []
     for i, part in enumerate(problem_parts):
         if i > 0 or sample["problem"].startswith("<image>"):
             insert_list.append("image")
         part = part.strip()
         if part:
             insert_list.append("text")
     image_index = 0
     text_index = 0
@@ -40,17 +51,25 @@ def build_raw_message(sample: Dict[str, Any], base64_images: list[str]) -> list[Dict[str, Any]]:
 
 def encode_image(image_file):
     return base64.b64encode(image_file).decode("utf-8")
 
 
 def convert_image(
-    image: Union[Dict[str, Any], ImageObject, str], min_pixels: Optional[int], max_pixels: Optional[int]
+    image: Union[Dict[str, Any], ImageObject, str],
+    min_pixels: Optional[int],
+    max_pixels: Optional[int],
 ) -> ImageObject:
     if max_pixels is not None and (image.width * image.height) > max_pixels:
         resize_factor = math.sqrt(max_pixels / (image.width * image.height))
-        width, height = int(image.width * resize_factor), int(image.height * resize_factor)
+        width, height = int(image.width * resize_factor), int(
+            image.height * resize_factor
+        )
         image = image.resize((width, height))
 
     if min_pixels is not None and (image.width * image.height) < min_pixels:
         resize_factor = math.sqrt(min_pixels / (image.width * image.height))
-        width, height = int(image.width * resize_factor), int(image.height * resize_factor)
+        width, height = int(image.width * resize_factor), int(
+            image.height * resize_factor
+        )
         image = image.resize((width, height))
 
     if image.mode != "RGB":
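
A standalone sketch of the down-scaling rule used in convert_image above; the 336 * 336 cap matches the max_pixels value passed by the dataset code in this diff, and the input image size is hypothetical:

import math
from PIL import Image

img = Image.new("RGB", (800, 600))  # 480000 pixels, above the cap
max_pixels = 336 * 336              # 112896, the upper bound used in this diff
resize_factor = math.sqrt(max_pixels / (img.width * img.height))
img = img.resize((int(img.width * resize_factor), int(img.height * resize_factor)))
print(img.size)                     # (387, 290), now within the cap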
@@ -59,32 +78,34 @@ def convert_image(
     image.save(output, format="JPEG")
     return output.getvalue()
 
 
 def get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size):
-    '''
+    """
     "clevr_count_70k": {
         "image_key": "images",
         "question_key": "problem",
         "answer_key": "answer"
     },
-    '''
+    """
     dataset = load_dataset(path=path, split=split)
     dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
 
     tokenizer = processor.tokenizer
 
     def process_example(example, idx):
         # Add query_id column
         images = example["images"]
-        if 'qwen' in processor.image_processor.image_processor_type.lower():
-            image_token="<|vision_start|><|image_pad|><|vision_end|>"
+        if "qwen" in processor.image_processor.image_processor_type.lower():
+            image_token = "<|vision_start|><|image_pad|><|vision_end|>"
         else:
             image_token = processor.image_token if processor is not None else "<image>"
         example["problem"] = example["problem"].replace("<image>", image_token)
         processed_images = []
         for image in images:
-            processed_images.append(convert_image(image,113*113,336*336))
+            processed_images.append(convert_image(image, 113 * 113, 336 * 336))
         example["images"] = processed_images
         example["seq"] = example["problem"] + example["answer"] + tokenizer.eos_token
 
         return example
 
     dataset = dataset.map(
@ -93,8 +114,8 @@ def get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size):
|
|||
)
|
||||
|
||||
def _process(example):
|
||||
text=example["seq"]
|
||||
processed_input=processor(
|
||||
text = example["seq"]
|
||||
processed_input = processor(
|
||||
text=[text],
|
||||
images=example["images"],
|
||||
padding=False,
|
||||
|
@@ -103,38 +124,52 @@ def get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size):
             return_attention_mask=False,
         )
 
-        example["input_ids"] =processed_input["input_ids"].squeeze(0)
+        example["input_ids"] = processed_input["input_ids"].squeeze(0)
         example["pixel_values"] = processed_input["pixel_values"]
         example["image_grid_thw"] = processed_input["image_grid_thw"]
         answer_token = tokenizer.encode(example["answer"])
-        loss_mask = [0] * (len(example["input_ids"]) - len(answer_token))+[1]*len(answer_token)
-        example["loss_mask"]=loss_mask
+        loss_mask = [0] * (len(example["input_ids"]) - len(answer_token)) + [1] * len(
+            answer_token
+        )
+        example["loss_mask"] = loss_mask
         return example
 
-    dataset = dataset.map(lambda x: _process(x),remove_columns=["images","seq","problem","answer"])
+    dataset = dataset.map(
+        lambda x: _process(x), remove_columns=["images", "seq", "problem", "answer"]
+    )
     return dataset
 
-def get_clevr_count_70k_rl_dataset(path, split,processor, rank, world_size):
+
+def get_clevr_count_70k_rl_dataset(path, split, processor, rank, world_size):
     dataset = load_dataset(path=path, split=split)
     dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
 
     def process(sample):
-        processed_images = [convert_image(image, 113*113, 336*336) for image in sample["images"]]
-        if 'qwen' in processor.image_processor.image_processor_type.lower():
-            image_token="<|vision_start|><|image_pad|><|vision_end|>"
+        processed_images = [
+            convert_image(image, 113 * 113, 336 * 336) for image in sample["images"]
+        ]
+        if "qwen" in processor.image_processor.image_processor_type.lower():
+            image_token = "<|vision_start|><|image_pad|><|vision_end|>"
         else:
             image_token = processor.image_token if processor is not None else "<image>"
         system_prompt = {
             "role": "system",
             "content": (
                 "Solve the following question: count the number of items in the image and provide the final answer in [ ] format, ensuring that only the number is inside the brackets without any additional text or explanations. "
-            )
+            ),
         }
 
-        messages =[{"role": "user", "content": sample["problem"].replace("<image>", image_token)}]
+        messages = [
+            {
+                "role": "user",
+                "content": sample["problem"].replace("<image>", image_token),
+            }
+        ]
         messages.insert(0, system_prompt)
-        messages=processor.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+        messages = processor.tokenizer.apply_chat_template(
+            messages, add_generation_prompt=True, tokenize=False
+        )
         return {"messages": messages, "images": processed_images}
 
     dataset = dataset.map(process).remove_columns(["problem"])
     return dataset
@@ -5,6 +5,7 @@ from datasets.distributed import split_dataset_by_node
 def get_gsm8k_sft_dataset(path, split, tokenizer, rank, world_size):
     dataset = load_dataset(path=path, name="main", split=split)
     dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
 
     def process(sample):
         seq_token = tokenizer.encode(
             sample["question"] + sample["answer"] + tokenizer.eos_token
@@ -16,12 +17,14 @@ def get_gsm8k_sft_dataset(path, split, tokenizer, rank, world_size):
     dataset = dataset.map(process).remove_columns(["question", "answer"])
     return dataset
 
-def get_gsm8k_rl_dataset(path,split, rank, world_size):
+
+def get_gsm8k_rl_dataset(path, split, rank, world_size):
     dataset = load_dataset(path=path, name="main", split=split)
     dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
 
     def process(sample):
         messages = [{"role": "user", "content": sample["question"]}]
         return {"messages": messages}
 
     dataset = dataset.map(process).remove_columns(["question"])
     return dataset
@@ -68,6 +68,7 @@ def load_hf_tokenizer(
         tokenizer.pad_token_id = tokenizer.eos_token_id
     return tokenizer
 
+
 @lru_cache(maxsize=8)
 def load_hf_processor_and_tokenizer(
     model_name_or_path: str,