mirror of https://github.com/inclusionAI/AReaL
0725_11
Commit fb1796d941
@@ -4,6 +4,7 @@ from dataclasses import MISSING
 from io import BytesIO
 from typing import List

+from PIL import Image
 from PIL.Image import Image as ImageObject
@@ -21,3 +22,24 @@ def image2base64(images: List[ImageObject] | ImageObject) -> List[str] | str:
         byte_images.append(byte_image)

     return byte_images
+
+
+def pad_images_batch_to_max_size(images):
+    max_width = max(image.size[0] for image in images)
+    max_height = max(image.size[1] for image in images)
+
+    padded_images = []
+
+    for image in images:
+        width, height = image.size
+
+        padding_left = (max_width - width) // 2
+        padding_top = (max_height - height) // 2
+
+        padded_image = Image.new("RGB", (max_width, max_height), (0, 0, 0))
+        padded_image.paste(image, (padding_left, padding_top))
+
+        padded_images.append(padded_image)
+
+    return padded_images
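The hunk above adds pad_images_batch_to_max_size to arealite/utils/image.py (the module name follows from the import added later in this commit): every image in a batch is pasted, centered, onto a black canvas sized to the batch's maximum width and height. A minimal usage sketch; the image sizes are illustrative:

from PIL import Image
from arealite.utils.image import pad_images_batch_to_max_size

# Two images of different sizes (illustrative values).
small = Image.new("RGB", (100, 80), (255, 0, 0))
large = Image.new("RGB", (200, 150), (0, 255, 0))

padded = pad_images_batch_to_max_size([small, large])

# Both outputs now share the batch-maximum size; the small image sits
# centered on a black background with (50, 35) pixels of padding.
assert all(img.size == (200, 150) for img in padded)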
@@ -10,7 +10,7 @@ from transformers import AutoProcessor, PreTrainedTokenizerFast
 from arealite.api.cli_args import GenerationHyperparameters
 from arealite.api.io_struct import VLMRequest
 from arealite.utils.data import concat_padded_tensors
-from arealite.utils.image import image2base64
+from arealite.utils.image import image2base64, pad_images_batch_to_max_size
 from arealite.workflow.rlvr import RLVRWorkflow
@@ -28,10 +28,11 @@ class VisionRLVRWorkflow(RLVRWorkflow):
         self.processor = processor

     async def arun_episode(self, engine, data):
         # self.processor.tokenizer.add_generation_prompt=True
+        padded_images = pad_images_batch_to_max_size(data["images"])
         processed_input = self.processor(
-            images=data["images"],
+            images=padded_images,
             text=data["messages"],
             padding=False,
             return_tensors="pt",
@@ -41,7 +42,7 @@ class VisionRLVRWorkflow(RLVRWorkflow):

         n_samples = self.gconfig.n_samples

-        byte_images = image2base64(data["images"])
+        byte_images = image2base64(padded_images)

         req = VLMRequest(
             rid=uuid.uuid4().hex,
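With these two hunks, arun_episode pads the sample's images once and reuses the padded list both for the HF processor call and for the base64 payload sent to the inference engine, so trainer-side tokenization and the rollout request see identical pixels. A condensed sketch of that ordering; prepare_vlm_inputs is a hypothetical name used only for illustration, since the commit does this inline inside the workflow:

from arealite.utils.image import image2base64, pad_images_batch_to_max_size

def prepare_vlm_inputs(processor, images, messages):
    # Pad first, then feed the *same* padded images to both consumers.
    padded_images = pad_images_batch_to_max_size(images)

    processed_input = processor(
        images=padded_images,
        text=messages,
        padding=False,
        return_tensors="pt",
    )
    byte_images = image2base64(padded_images)
    return processed_input, byte_images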
@@ -50,13 +51,13 @@ class VisionRLVRWorkflow(RLVRWorkflow):
             gconfig=self.gconfig.new(n_samples=1),
         )
         resps = await asyncio.gather(*[engine.agenerate(req) for _ in range(n_samples)])

         version = engine.get_version()
         prompt_strs = []
         completions_strs = []
         rewards = []
         seqlens = []

         results = []
         for resp in resps:
             seq = resp.input_tokens + resp.output_tokens
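The request is built with gconfig.new(n_samples=1) and then issued n_samples times concurrently via asyncio.gather, as the context above shows. A minimal sketch of that sampling pattern, assuming an engine exposing the agenerate coroutine used here:

import asyncio

async def sample_n(engine, req, n_samples):
    # Launch n_samples single-sample generation requests concurrently and
    # collect all responses once they complete.
    return await asyncio.gather(*(engine.agenerate(req) for _ in range(n_samples)))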
@@ -118,4 +119,3 @@ class VisionRLVRWorkflow(RLVRWorkflow):
                 f.write(info + "\n")

         return concat_padded_tensors(results)
-
@@ -8,7 +8,6 @@ import wandb
 from torchdata.stateful_dataloader import StatefulDataLoader
 from torch.utils.data import Subset

-from arealite.workflow.vision_rlvr import VisionRLVRWorkflow
 from arealite.api.cli_args import GRPOConfig, load_expr_config
 from arealite.api.io_struct import AllocationMode, FinetuneSpec, WeightUpdateMeta
 from arealite.dataset.__init__ import get_custom_dataset
@@ -18,6 +17,7 @@ from arealite.utils.device import log_gpu_stats
 from arealite.utils.evaluator import Evaluator
 from arealite.utils.saver import Saver
 from arealite.utils.stats_logger import StatsLogger
+from arealite.workflow.vision_rlvr import VisionRLVRWorkflow
 from realhf.api.core.data_api import load_hf_processor_and_tokenizer
 from realhf.base import stats_tracker
@@ -9,7 +9,6 @@ from PIL.Image import Image as ImageObject

 def convert_image(
     image: Union[Dict[str, Any], ImageObject, str],
-    min_pixels: Optional[int],
     max_pixels: Optional[int],
 ) -> ImageObject:
     if max_pixels is not None and (image.width * image.height) > max_pixels:
@@ -19,13 +18,6 @@ def convert_image(
         )
         image = image.resize((width, height))

-    if min_pixels is not None and (image.width * image.height) < min_pixels:
-        resize_factor = math.sqrt(min_pixels / (image.width * image.height))
-        width, height = int(image.width * resize_factor), int(
-            image.height * resize_factor
-        )
-        image = image.resize((width, height))
-
     if image.mode != "RGB":
         image = image.convert("RGB")
     with BytesIO() as output:
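After this change, convert_image only ever downscales: the min_pixels upscaling branch is gone, and callers pass a single pixel budget. A self-contained sketch of the post-change behavior under a hypothetical helper name; the dict/path input handling of the original function is omitted:

import math
from PIL.Image import Image as ImageObject

def downscale_to_max_pixels(image: ImageObject, max_pixels: int) -> ImageObject:
    # Shrink the image so its area does not exceed max_pixels, keeping the
    # aspect ratio, then normalize the mode to RGB.
    if image.width * image.height > max_pixels:
        resize_factor = math.sqrt(max_pixels / (image.width * image.height))
        width = int(image.width * resize_factor)
        height = int(image.height * resize_factor)
        image = image.resize((width, height))
    if image.mode != "RGB":
        image = image.convert("RGB")
    return image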
@@ -53,10 +45,10 @@ def get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size):
             image_token = "<|vision_start|><|image_pad|><|vision_end|>"
         else:
             image_token = processor.image_token if processor is not None else "<image>"
-        example["problem"] = example["problem"].replace("<image>", image_token)
+        example["problem"] = example["problem"].replace("<image>", image_token).replace("different", "")
         processed_images = []
         for image in images:
-            processed_images.append(convert_image(image, 113 * 113, 336 * 336))
+            processed_images.append(convert_image(image, 336 * 336))
         example["images"] = processed_images
         example["seq"] = example["problem"] + example["answer"] + tokenizer.eos_token
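Besides the smaller pixel budget, the SFT path now also strips the word "different" from the CLEVR-Count question after substituting the image token. A small worked example with an illustrative problem string (not taken from the dataset):

image_token = "<|vision_start|><|image_pad|><|vision_end|>"
problem = "<image> How many different objects are there?"

processed = problem.replace("<image>", image_token).replace("different", "")
# -> "<|vision_start|><|image_pad|><|vision_end|> How many  objects are there?"
# str.replace drops only the word itself, leaving a double space behind.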
@@ -100,7 +92,7 @@ def get_clevr_count_70k_rl_dataset(path, split, processor, rank, world_size):

     def process(sample):
         processed_images = [
-            convert_image(image, 113 * 113, 336 * 336) for image in sample["images"]
+            convert_image(image, 336 * 336) for image in sample["images"]
         ]
         if "qwen" in processor.image_processor.image_processor_type.lower():
            image_token = "<|vision_start|><|image_pad|><|vision_end|>"
|
@ -116,7 +108,7 @@ def get_clevr_count_70k_rl_dataset(path, split, processor, rank, world_size):
|
|||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": sample["problem"].replace("<image>", image_token),
|
||||
"content": sample["problem"].replace("<image>", image_token).replace("different", ""),
|
||||
}
|
||||
]
|
||||
messages.insert(0, system_prompt)
|
||||
|
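The RL dataset applies the same substitution when building the chat messages, then prepends a system prompt. A minimal sketch of that construction; the sample problem and the system prompt content are illustrative, since neither appears in this hunk:

image_token = "<|vision_start|><|image_pad|><|vision_end|>"
system_prompt = {"role": "system", "content": "You are a helpful assistant."}  # illustrative
sample = {"problem": "<image> How many different objects are there?"}  # illustrative

messages = [
    {
        "role": "user",
        "content": sample["problem"].replace("<image>", image_token).replace("different", ""),
    }
]
messages.insert(0, system_prompt)
# messages now starts with the system turn, followed by the user turn.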