mirror of https://github.com/inclusionAI/AReaL
0725_format
This commit is contained in:
parent
6b8bfcf9a4
commit
4ff813ae9f
|
@ -582,8 +582,7 @@ def update_weights_from_distributed(
|
|||
for addr in addresses
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
|
||||
logger.info(f"Distributed update weights done in {time.perf_counter() - tik}s")
|
||||
|
||||
return uvloop.run(_fn())
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import os
|
||||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
|
||||
import colorama
|
||||
|
|
|
@ -5,8 +5,8 @@ import sys
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
import wandb
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
from torch.utils.data import Subset
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
|
||||
from arealite.api.cli_args import GRPOConfig, load_expr_config
|
||||
from arealite.api.io_struct import AllocationMode, FinetuneSpec, WeightUpdateMeta
|
||||
|
@ -125,10 +125,7 @@ def main(args):
|
|||
# but `WeightUpdateMeta.from_fsdp_nccl` has to be executed on all ranks
|
||||
# due to `engine.get_param_specs()`.
|
||||
# Therefore, we create weight update meta on all ranks, then broadcast the one on rank 0.
|
||||
weight_update_meta = [
|
||||
WeightUpdateMeta.from_disk(
|
||||
config.saver)
|
||||
]
|
||||
weight_update_meta = [WeightUpdateMeta.from_disk(config.saver)]
|
||||
dist.broadcast_object_list(weight_update_meta, src=0)
|
||||
weight_update_meta = weight_update_meta[0]
|
||||
|
||||
|
@ -145,7 +142,7 @@ def main(args):
|
|||
processor=processor,
|
||||
enable_thinking=False,
|
||||
)
|
||||
|
||||
|
||||
# Run training.
|
||||
saver = Saver(config.saver, ft_spec, for_recover=False)
|
||||
logger = StatsLogger(config.stats_logger, ft_spec)
|
||||
|
|
|
@ -45,7 +45,9 @@ def get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size):
|
|||
image_token = "<|vision_start|><|image_pad|><|vision_end|>"
|
||||
else:
|
||||
image_token = processor.image_token if processor is not None else "<image>"
|
||||
example["problem"] = example["problem"].replace("<image>", image_token).replace("different", "")
|
||||
example["problem"] = (
|
||||
example["problem"].replace("<image>", image_token).replace("different", "")
|
||||
)
|
||||
processed_images = []
|
||||
for image in images:
|
||||
processed_images.append(convert_image(image, 336 * 336))
|
||||
|
@ -108,7 +110,9 @@ def get_clevr_count_70k_rl_dataset(path, split, processor, rank, world_size):
|
|||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": sample["problem"].replace("<image>", image_token).replace("different", ""),
|
||||
"content": sample["problem"]
|
||||
.replace("<image>", image_token)
|
||||
.replace("different", ""),
|
||||
}
|
||||
]
|
||||
messages.insert(0, system_prompt)
|
||||
|
|
Loading…
Reference in New Issue