This commit is contained in:
lichangye.lcy 2025-07-25 19:12:25 +08:00
parent fb1796d941
commit a4ad671d3b
2 changed files with 3 additions and 5 deletions

View File

@@ -68,12 +68,10 @@ def main(args):
)
train_size = len(train_dataset)
subset_size = int(0.3 * train_size)
subset_size = int(1.0 * train_size)
# Randomly pick subset_size sample indices (the fraction of train_size computed above)
random_indices = torch.randperm(train_size).tolist()[:subset_size]
# Build a new subset dataset restricted to the selected indices
subset_train_dataset = Subset(train_dataset, random_indices)
valid_dataset = get_custom_dataset(

View File

@@ -1,9 +1,9 @@
experiment_name: clevr_count_70k-grpo
trial_name: trial0
trial_name: trial1
seed: 1
total_train_epochs: 3
total_train_epochs: 1
tokenizer_path: ${actor.path}
async_training: true