0711_6
|
@ -13,10 +13,10 @@ def get_custom_dataset(
|
|||
):
|
||||
|
||||
if "gsm8k" in path and training_type == "sft":
|
||||
from AReaL.examples.arealite.dataset.gsm8k import get_gsm8k_sft_dataset
|
||||
from examples.arealite.dataset.gsm8k import get_gsm8k_sft_dataset
|
||||
return get_gsm8k_sft_dataset(path, split, tokenizer, rank, world_size)
|
||||
elif "clevr_count_70k" in path and training_type == "sft":
|
||||
from AReaL.examples.arealite.dataset.clevr_count_70k import get_clevr_count_70k_sft_dataset
|
||||
from examples.arealite.dataset.clevr_count_70k import get_clevr_count_70k_sft_dataset
|
||||
return get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size)
|
||||
else:
|
||||
raise ValueError(
|
||||
|
@ -24,4 +24,3 @@ def get_custom_dataset(
|
|||
f"Supported datasets are: {VALID_DATASETS}. "
|
||||
)
|
||||
|
||||
|
Before Width: | Height: | Size: 31 KiB After Width: | Height: | Size: 31 KiB |
Before Width: | Height: | Size: 136 KiB After Width: | Height: | Size: 136 KiB |
Before Width: | Height: | Size: 240 KiB After Width: | Height: | Size: 240 KiB |
Before Width: | Height: | Size: 232 KiB After Width: | Height: | Size: 232 KiB |
Before Width: | Height: | Size: 125 KiB After Width: | Height: | Size: 125 KiB |
Before Width: | Height: | Size: 224 KiB After Width: | Height: | Size: 224 KiB |
Before Width: | Height: | Size: 18 KiB After Width: | Height: | Size: 18 KiB |
Before Width: | Height: | Size: 57 KiB After Width: | Height: | Size: 57 KiB |