lichangye.lcy 2025-07-25 18:52:29 +08:00
commit fb1796d941
4 changed files with 34 additions and 20 deletions

View File

@@ -4,6 +4,7 @@ from dataclasses import MISSING
from io import BytesIO
from typing import List
from PIL import Image
from PIL.Image import Image as ImageObject
@@ -21,3 +22,24 @@ def image2base64(images: List[ImageObject] | ImageObject) -> List[str] | str:
        byte_images.append(byte_image)
    return byte_images


def pad_images_batch_to_max_size(images):
    max_width = max(image.size[0] for image in images)
    max_height = max(image.size[1] for image in images)
    padded_images = []
    for image in images:
        width, height = image.size
        padding_left = (max_width - width) // 2
        padding_top = (max_height - height) // 2
        padded_image = Image.new("RGB", (max_width, max_height), (0, 0, 0))
        padded_image.paste(image, (padding_left, padding_top))
        padded_images.append(padded_image)
    return padded_images
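
The new helper letterboxes every image in a batch onto a black canvas sized to the batch's maximum width and height, centering each original. A minimal usage sketch (the three image sizes below are made up for illustration):

# Illustrative only: exercising the new helper on a hypothetical mixed-size batch.
from PIL import Image

from arealite.utils.image import pad_images_batch_to_max_size

# Three images with different sizes, as might come out of one dataloader batch.
batch = [
    Image.new("RGB", (224, 224), (255, 0, 0)),
    Image.new("RGB", (336, 112), (0, 255, 0)),
    Image.new("RGB", (112, 336), (0, 0, 255)),
]

padded = pad_images_batch_to_max_size(batch)

# Every image is now 336x336, centered on a black canvas.
assert all(img.size == (336, 336) for img in padded)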

View File

@@ -10,7 +10,7 @@ from transformers import AutoProcessor, PreTrainedTokenizerFast
from arealite.api.cli_args import GenerationHyperparameters
from arealite.api.io_struct import VLMRequest
from arealite.utils.data import concat_padded_tensors
from arealite.utils.image import image2base64
from arealite.utils.image import image2base64, pad_images_batch_to_max_size
from arealite.workflow.rlvr import RLVRWorkflow
@@ -28,10 +28,11 @@ class VisionRLVRWorkflow(RLVRWorkflow):
        self.processor = processor

    async def arun_episode(self, engine, data):
        # self.processor.tokenizer.add_generation_prompt=True
        padded_images = pad_images_batch_to_max_size(data["images"])
        processed_input = self.processor(
            images=data["images"],
            images=padded_images,
            text=data["messages"],
            padding=False,
            return_tensors="pt",
@@ -41,7 +42,7 @@ class VisionRLVRWorkflow(RLVRWorkflow):
        n_samples = self.gconfig.n_samples
        byte_images = image2base64(data["images"])
        byte_images = image2base64(padded_images)
        req = VLMRequest(
            rid=uuid.uuid4().hex,
@@ -50,13 +51,13 @@ class VisionRLVRWorkflow(RLVRWorkflow):
            gconfig=self.gconfig.new(n_samples=1),
        )
        resps = await asyncio.gather(*[engine.agenerate(req) for _ in range(n_samples)])
        version = engine.get_version()
        prompt_strs = []
        completions_strs = []
        rewards = []
        seqlens = []
        results = []
        for resp in resps:
            seq = resp.input_tokens + resp.output_tokens
@@ -118,4 +119,3 @@ class VisionRLVRWorkflow(RLVRWorkflow):
            f.write(info + "\n")
        return concat_padded_tensors(results)
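
The substance of this change: the same padded_images now feed both the processor call that produces the model inputs and the image2base64 payload carried by the VLMRequest, so training and generation see identical pixels. Below is a rough sketch of the base64 leg; the PNG-over-BytesIO encoding is an assumption and may differ from the real image2base64 in arealite.utils.image.

# Sketch only: the actual encoding inside image2base64 is assumed here.
import base64
from io import BytesIO

from PIL import Image


def to_base64(image: Image.Image) -> str:
    # Serialize the PIL image to PNG bytes in memory, then base64-encode.
    with BytesIO() as buf:
        image.save(buf, format="PNG")
        return base64.b64encode(buf.getvalue()).decode("utf-8")


padded = Image.new("RGB", (336, 336), (0, 0, 0))
payload = to_base64(padded)  # what the VLMRequest would carry after this change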

View File

@@ -8,7 +8,6 @@ import wandb
from torchdata.stateful_dataloader import StatefulDataLoader
from torch.utils.data import Subset
from arealite.workflow.vision_rlvr import VisionRLVRWorkflow
from arealite.api.cli_args import GRPOConfig, load_expr_config
from arealite.api.io_struct import AllocationMode, FinetuneSpec, WeightUpdateMeta
from arealite.dataset.__init__ import get_custom_dataset
@@ -18,6 +17,7 @@ from arealite.utils.device import log_gpu_stats
from arealite.utils.evaluator import Evaluator
from arealite.utils.saver import Saver
from arealite.utils.stats_logger import StatsLogger
from arealite.workflow.vision_rlvr import VisionRLVRWorkflow
from realhf.api.core.data_api import load_hf_processor_and_tokenizer
from realhf.base import stats_tracker

View File

@@ -9,7 +9,6 @@ from PIL.Image import Image as ImageObject
def convert_image(
    image: Union[Dict[str, Any], ImageObject, str],
    min_pixels: Optional[int],
    max_pixels: Optional[int],
) -> ImageObject:
    if max_pixels is not None and (image.width * image.height) > max_pixels:
@@ -19,13 +18,6 @@ def convert_image(
        )
        image = image.resize((width, height))
    if min_pixels is not None and (image.width * image.height) < min_pixels:
        resize_factor = math.sqrt(min_pixels / (image.width * image.height))
        width, height = int(image.width * resize_factor), int(
            image.height * resize_factor
        )
        image = image.resize((width, height))
    if image.mode != "RGB":
        image = image.convert("RGB")
    with BytesIO() as output:
@@ -53,10 +45,10 @@ def get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size):
            image_token = "<|vision_start|><|image_pad|><|vision_end|>"
        else:
            image_token = processor.image_token if processor is not None else "<image>"
        example["problem"] = example["problem"].replace("<image>", image_token)
        example["problem"] = example["problem"].replace("<image>", image_token).replace("different", "")
        processed_images = []
        for image in images:
            processed_images.append(convert_image(image, 113 * 113, 336 * 336))
            processed_images.append(convert_image(image, 336 * 336))
        example["images"] = processed_images
        example["seq"] = example["problem"] + example["answer"] + tokenizer.eos_token
@@ -100,7 +92,7 @@ def get_clevr_count_70k_rl_dataset(path, split, processor, rank, world_size):
    def process(sample):
        processed_images = [
            convert_image(image, 113 * 113, 336 * 336) for image in sample["images"]
            convert_image(image, 336 * 336) for image in sample["images"]
        ]
        if "qwen" in processor.image_processor.image_processor_type.lower():
            image_token = "<|vision_start|><|image_pad|><|vision_end|>"
@@ -116,7 +108,7 @@ def get_clevr_count_70k_rl_dataset(path, split, processor, rank, world_size):
        messages = [
            {
                "role": "user",
                "content": sample["problem"].replace("<image>", image_token),
                "content": sample["problem"].replace("<image>", image_token).replace("different", ""),
            }
        ]
        messages.insert(0, system_prompt)
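
With the min_pixels parameter and its upscaling branch removed, convert_image now only caps image area at max_pixels (336 * 336 at both call sites). Below is a self-contained sketch of the surviving branch, reusing the sqrt-of-area scaling visible in the deleted block; the function name shrink_to_max_pixels is hypothetical.

# Hypothetical standalone version of the area-capped resize that remains after
# this commit; mirrors the sqrt-scaling pattern of the deleted min_pixels code.
import math

from PIL import Image
from PIL.Image import Image as ImageObject


def shrink_to_max_pixels(image: ImageObject, max_pixels: int) -> ImageObject:
    if (image.width * image.height) > max_pixels:
        # Scale both sides by sqrt(area ratio) so the resulting area ~= max_pixels.
        resize_factor = math.sqrt(max_pixels / (image.width * image.height))
        width = int(image.width * resize_factor)
        height = int(image.height * resize_factor)
        image = image.resize((width, height))
    return image


# E.g. a 1000x1000 image capped at 336*336 pixels becomes 336x336.
img = shrink_to_max_pixels(Image.new("RGB", (1000, 1000)), 336 * 336)
assert img.size == (336, 336)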