0725_format

This commit is contained in:
朱晗 2025-07-25 19:17:26 +08:00
parent 6b8bfcf9a4
commit 4ff813ae9f
4 changed files with 11 additions and 11 deletions

View File

@ -582,8 +582,7 @@ def update_weights_from_distributed(
for addr in addresses
]
)
logger.info(f"Distributed update weights done in {time.perf_counter() - tik}s")
return uvloop.run(_fn())

View File

@ -1,5 +1,5 @@
import os
import asyncio
import os
import uuid
import colorama

View File

@ -5,8 +5,8 @@ import sys
import torch
import torch.distributed as dist
import wandb
from torchdata.stateful_dataloader import StatefulDataLoader
from torch.utils.data import Subset
from torchdata.stateful_dataloader import StatefulDataLoader
from arealite.api.cli_args import GRPOConfig, load_expr_config
from arealite.api.io_struct import AllocationMode, FinetuneSpec, WeightUpdateMeta
@ -125,10 +125,7 @@ def main(args):
# but `WeightUpdateMeta.from_fsdp_nccl` has to be executed on all ranks
# due to `engine.get_param_specs()`.
# Therefore, we create weight update meta on all ranks, then broadcast the one on rank 0.
weight_update_meta = [
WeightUpdateMeta.from_disk(
config.saver)
]
weight_update_meta = [WeightUpdateMeta.from_disk(config.saver)]
dist.broadcast_object_list(weight_update_meta, src=0)
weight_update_meta = weight_update_meta[0]
@ -145,7 +142,7 @@ def main(args):
processor=processor,
enable_thinking=False,
)
# Run training.
saver = Saver(config.saver, ft_spec, for_recover=False)
logger = StatsLogger(config.stats_logger, ft_spec)

View File

@ -45,7 +45,9 @@ def get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size):
image_token = "<|vision_start|><|image_pad|><|vision_end|>"
else:
image_token = processor.image_token if processor is not None else "<image>"
example["problem"] = example["problem"].replace("<image>", image_token).replace("different", "")
example["problem"] = (
example["problem"].replace("<image>", image_token).replace("different", "")
)
processed_images = []
for image in images:
processed_images.append(convert_image(image, 336 * 336))
@ -108,7 +110,9 @@ def get_clevr_count_70k_rl_dataset(path, split, processor, rank, world_size):
messages = [
{
"role": "user",
"content": sample["problem"].replace("<image>", image_token).replace("different", ""),
"content": sample["problem"]
.replace("<image>", image_token)
.replace("different", ""),
}
]
messages.insert(0, system_prompt)