diff --git a/internal/storeLink/storeLink.go b/internal/storeLink/storeLink.go index 70ce291b1..27a821941 100644 --- a/internal/storeLink/storeLink.go +++ b/internal/storeLink/storeLink.go @@ -73,6 +73,11 @@ const ( ) var ( + Datasets = []string{"mnist", "cifar10"} + AlgorithmsForDatasets = map[string][]string{ + "mnist": {"fcn"}, + "cifar10": {"cnn"}, + } OctImgStatus = map[int32]string{ 1: "未上传", 3: "制作完成", @@ -153,6 +158,10 @@ func GetModelNamesByType(t string) ([]string, error) { } func GetDatasetsNames(ctx context.Context, collectorMap map[string]collector.AiCollector) ([]string, error) { + return Datasets, nil +} + +func GetDatasetsNamesSync(ctx context.Context, collectorMap map[string]collector.AiCollector) ([]string, error) { var wg sync.WaitGroup var errCh = make(chan interface{}, len(collectorMap)) var errs []interface{} @@ -225,6 +234,14 @@ func GetDatasetsNames(ctx context.Context, collectorMap map[string]collector.AiC } func GetAlgorithms(ctx context.Context, collectorMap map[string]collector.AiCollector, resourceType string, taskType string, dataset string) ([]string, error) { + algorithm := AlgorithmsForDatasets[dataset] + if len(algorithm) != 0 { + return algorithm, nil + } + return nil, errors.New("not found") +} + +func GetAlgorithmsSync(ctx context.Context, collectorMap map[string]collector.AiCollector, resourceType string, taskType string, dataset string) ([]string, error) { var names []string var wg sync.WaitGroup var errCh = make(chan interface{}, len(collectorMap))