From a7ad973ff646e149e7cb004462901a8631532e18 Mon Sep 17 00:00:00 2001 From: tzwang Date: Mon, 14 Jul 2025 17:58:06 +0800 Subject: [PATCH] update service --- ai/service/service.go | 433 +++++++++++++++++++------------------ ai/service/service_test.go | 5 +- client/api/ai.go | 12 - 3 files changed, 230 insertions(+), 220 deletions(-) diff --git a/ai/service/service.go b/ai/service/service.go index e35b88b..9f4c823 100644 --- a/ai/service/service.go +++ b/ai/service/service.go @@ -2,8 +2,12 @@ package service import ( "context" + "errors" + "fmt" + "gitlink.org.cn/JointCloud/pcm-participant-ai/image" + "sync" + "gitlink.org.cn/JointCloud/pcm-participant-ai/algorithm" - "gitlink.org.cn/JointCloud/pcm-participant-ai/common" "gitlink.org.cn/JointCloud/pcm-participant-ai/dataset" "gitlink.org.cn/JointCloud/pcm-participant-ai/model" "gitlink.org.cn/JointCloud/pcm-participant-ai/platform" @@ -11,303 +15,322 @@ import ( "gitlink.org.cn/JointCloud/pcm-participant-ai/task" "gitlink.org.cn/JointCloud/pcm-participant-octopus" openI "gitlink.org.cn/JointCloud/pcm-participant-openi" - "sync" ) type Service struct { - rmap map[platform.Id]*Resource - dmap map[platform.Id]*Dataset - tmap map[platform.Id]*Task - amap map[platform.Id]*Algorithm - imap map[platform.Id]*Image - mmap map[platform.Id]*Model - rlock sync.Mutex - dlock sync.Mutex - tlock sync.Mutex + resourceMap sync.Map + datasetMap sync.Map + taskMap sync.Map + algorithmMap sync.Map + imageMap sync.Map + modelMap sync.Map } func NewService(platforms ...platform.IPlatform) (*Service, error) { - rmap := make(map[platform.Id]*Resource) - amap := make(map[platform.Id]*Algorithm) - imap := make(map[platform.Id]*Image) - mmap := make(map[platform.Id]*Model) + s := &Service{} for _, pf := range platforms { - switch pf.Type() { - case platform.OpenI: - openI, ok := pf.(*openI.OpenI) - if !ok { - - } - if openI.Res == nil { - - } - if openI.Img == nil { - - } - - resource := NewResource(openI.Res) - rmap[pf.Id()] = resource - - alg := NewAlgorithm(openI.Alg) - amap[pf.Id()] = alg - - img := NewImage(openI.Img) - imap[pf.Id()] = img - - mdl := NewModel(openI.Mdl) - mmap[pf.Id()] = mdl - - case platform.Octopus: - oct, ok := pf.(*octopus.Octopus) - if !ok { - - } - alg := NewAlgorithm(oct.Alg) - amap[pf.Id()] = alg + if err := s.registerPlatform(pf); err != nil { + return nil, fmt.Errorf("failed to register platform %d: %w", pf.Id(), err) } } - return &Service{rmap: rmap, amap: amap, imap: imap, mmap: mmap}, nil + return s, nil } -// resource -func (s *Service) GetResourceSpecs(ctx context.Context, pfId int64, rtype string) (interface{}, error) { - var pid = platform.Id(pfId) - res, found := s.rmap[pid] - if !found { +// registerPlatform 注册单个平台及其所有组件 +func (s *Service) registerPlatform(pf platform.IPlatform) error { + // 内部注册函数 + register := func(res resource.IResource, img image.IImage, task task.Task, ds dataset.IDataset, alg algorithm.IAlgorithm, mdl model.IModel, platformName string) error { + if res == nil || img == nil || task == nil || ds == nil || alg == nil || mdl == nil { + return fmt.Errorf("%s platform missing required components", platformName) + } + s.resourceMap.Store(pf.Id(), NewResource(res)) + s.imageMap.Store(pf.Id(), NewImage(img)) + s.taskMap.Store(pf.Id(), NewTask(task)) + s.datasetMap.Store(pf.Id(), NewDataset(ds)) + s.algorithmMap.Store(pf.Id(), NewAlgorithm(alg)) + s.modelMap.Store(pf.Id(), NewModel(mdl)) + return nil } - specs, err := res.GetResourcespecs(ctx, rtype) - if err != nil { - return nil, err + + switch pt := pf.(type) { + case *openI.OpenI: + return register(pt.Res, pt.Img, pt.Task, pt.Ds, pt.Alg, pt.Mdl, "OpenI") + case *octopus.Octopus: + return register(pt.Res, pt.Img, pt.Task, pt.Ds, pt.Alg, pt.Mdl, "Octopus") + default: + return fmt.Errorf("unsupported platform type: %T", pf) } - return specs, nil } -func (s *Service) TrainResources(ctx context.Context, pfId int64) ([]*resource.Spec, error) { - var pid = platform.Id(pfId) - res, found := s.rmap[pid] - if !found { +// Resource operations +func (s *Service) GetResourceSpecs(ctx context.Context, id int64, rtype string) (interface{}, error) { + pid := platform.Id(id) + val, err := s.loadResource(pid) + if err != nil { + return nil, err + } + return val.GetResourcespecs(ctx, rtype) +} - } - resources, err := res.Train(ctx, nil) +func (s *Service) TrainResources(ctx context.Context, id int64) ([]*resource.Spec, error) { + pid := platform.Id(id) + val, err := s.loadResource(pid) if err != nil { return nil, err } - specs, err := resources.Specs() + + resources, err := val.Train(ctx, nil) if err != nil { - return nil, err + return nil, fmt.Errorf("resource training failed: %w", err) } - return specs, nil + + return resources.Specs() } func (s *Service) AllTrainResources(ctx context.Context) ([]*resource.Spec, error) { - all := make([]*resource.Spec, 0) - for id, _ := range s.rmap { - resources, err := s.TrainResources(ctx, int64(id)) - if err != nil { - return nil, err - } - if len(resources) == 0 { - continue - } - all = common.ConcatMultipleSlices([][]*resource.Spec{all, resources}) - } - return all, nil + var ( + all []*resource.Spec + mu sync.Mutex + wg sync.WaitGroup + errList []error + ) + s.resourceMap.Range(func(key, value interface{}) bool { + wg.Add(1) + go func(id platform.Id) { + defer wg.Done() + + resources, err := s.TrainResources(ctx, int64(id)) + + mu.Lock() + defer mu.Unlock() + + if err != nil { + errList = append(errList, fmt.Errorf("platform %d: %w", id, err)) + return + } + + if len(resources) > 0 { + all = append(all, resources...) + } + }(key.(platform.Id)) + + return true + }) + + wg.Wait() + return all, errors.Join(errList...) } -// algorithm -func (s *Service) TrainAlgorithms(ctx context.Context, pfId int64) ([]*algorithm.Algorithm, error) { - var pid = platform.Id(pfId) - a, found := s.amap[pid] - if !found { +// Algorithm operations +func (s *Service) TrainAlgorithms(ctx context.Context, id int64) ([]*algorithm.Algorithm, error) { + pid := platform.Id(id) + val, err := s.loadAlgorithm(pid) + if err != nil { + return nil, err + } - } - alg, err := a.Train(ctx, nil) + alg, err := val.Train(ctx, nil) if err != nil { - return nil, err + return nil, fmt.Errorf("algorithm training failed: %w", err) } - algorithms, err := alg.Algorithms() - if err != nil { - return nil, err - } - return algorithms, nil + + return alg.Algorithms() } func (s *Service) AllTrainAlgorithms(ctx context.Context) ([]*algorithm.Algorithm, error) { - all := make([]*algorithm.Algorithm, 0) - for id, _ := range s.amap { - algorithms, err := s.TrainAlgorithms(ctx, int64(id)) - if err != nil { - return nil, err - } - if len(algorithms) == 0 { - continue - } - all = common.ConcatMultipleSlices([][]*algorithm.Algorithm{all, algorithms}) - } - return all, nil + var ( + all []*algorithm.Algorithm + mu sync.Mutex + wg sync.WaitGroup + errs []error + ) + + s.algorithmMap.Range(func(key, value interface{}) bool { + wg.Add(1) + go func(id platform.Id) { + defer wg.Done() + + algorithms, err := s.TrainAlgorithms(ctx, int64(id)) + mu.Lock() + defer mu.Unlock() + + if err != nil { + errs = append(errs, fmt.Errorf("platform %d: %w", id, err)) + return + } + + if len(algorithms) > 0 { + all = append(all, algorithms...) + } + }(key.(platform.Id)) + + return true + }) + + wg.Wait() + return all, errors.Join(errs...) } -func (s *Service) CreateAlgorithm(ctx context.Context, pfId int64, param *algorithm.CreateParam) (interface{}, error) { - var pid = platform.Id(pfId) - alg, found := s.amap[pid] - if !found { - - } - resp, err := alg.Create(ctx, param) +func (s *Service) CreateAlgorithm(ctx context.Context, id int64, param *algorithm.CreateParam) (interface{}, error) { + pid := platform.Id(id) + val, err := s.loadAlgorithm(pid) if err != nil { return nil, err } - - return resp, nil + return val.Create(ctx, param) } -// task func (s *Service) CreateTrainTask(ctx context.Context, param *CreateTrainTaskParam) (interface{}, error) { - trainParams, err := s.generateParamsForTrainTask(ctx, param.Id, param) - if err != nil { - return nil, err + if param == nil || param.Id == nil { + return nil, fmt.Errorf("invalid task parameters") } - t, found := s.tmap[*param.Id] - if !found { - } - resp, err := t.createTrainingTask(ctx, trainParams) + trainParams, err := s.generateParamsForTrainTask(ctx, *param.Id, param) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to generate train params: %w", err) } - return resp, nil + + val, err := s.loadTask(*param.Id) + if err != nil { + return nil, fmt.Errorf("failed to load task: %w", err) + } + + return val.createTrainingTask(ctx, trainParams) } -func (s *Service) generateParamsForTrainTask(ctx context.Context, pid *platform.Id, cp *CreateTrainTaskParam) (*task.TrainParams, error) { - res, found := s.rmap[*pid] - if !found { - +func (s *Service) generateParamsForTrainTask(ctx context.Context, id platform.Id, cp *CreateTrainTaskParam) (*task.TrainParams, error) { + res, err := s.loadResource(id) + if err != nil { + return nil, err } - dat, found := s.dmap[*pid] - if !found { + dat, err := s.loadDataset(id) + if err != nil { + return nil, err } - img, found := s.imap[*pid] - if !found { + img, err := s.loadImage(id) + if err != nil { + return nil, err } param := &task.TrainParams{} - resourceParam, err := res.TrainParam(ctx, cp.Resource) - if err != nil { - return nil, err - } - datasetParam, err := dat.TrainParam(ctx, cp.Dataset) - if err != nil { - return nil, err - } - imgParam, err := img.TrainParam(ctx, cp.Image) - if err != nil { - return nil, err + if resourceParam, err := res.TrainParam(ctx, cp.Resource); err != nil { + return nil, fmt.Errorf("resource param error: %w", err) + } else { + param.Resource = resourceParam } - param.Resource = resourceParam - param.Dataset = datasetParam - param.Image = imgParam + if datasetParam, err := dat.TrainParam(ctx, cp.Dataset); err != nil { + return nil, fmt.Errorf("dataset param error: %w", err) + } else { + param.Dataset = datasetParam + } + + if imgParam, err := img.TrainParam(ctx, cp.Image); err != nil { + return nil, fmt.Errorf("image param error: %w", err) + } else { + param.Image = imgParam + } return param, nil } -func (s *Service) TaskResultSync(ctx context.Context, pfId int64, param *task.ResultSyncParam) error { - var pid = platform.Id(pfId) - tsk, found := s.tmap[pid] - if !found { - - } - err := tsk.resultSync(ctx, param) +func (s *Service) TaskResultSync(ctx context.Context, id int64, param *task.ResultSyncParam) error { + pid := platform.Id(id) + val, err := s.loadTask(pid) if err != nil { return err } - - return nil + return val.resultSync(ctx, param) } -func (s *Service) TaskLog(ctx context.Context, pfId int64, id string) (interface{}, error) { - var pid = platform.Id(pfId) - tsk, found := s.tmap[pid] - if !found { - - } - resp, err := tsk.getLog(ctx, id) +func (s *Service) TaskLog(ctx context.Context, id int64, taskId string) (interface{}, error) { + pid := platform.Id(id) + val, err := s.loadTask(pid) if err != nil { return nil, err } - - return resp, nil + return val.getLog(ctx, taskId) } -func (s *Service) GetTrainingTask(ctx context.Context, pfId int64, id string) (interface{}, error) { - var pid = platform.Id(pfId) - tsk, found := s.tmap[pid] - if !found { - - } - resp, err := tsk.getTrainingTask(ctx, id) +func (s *Service) GetTrainingTask(ctx context.Context, id int64, taskId string) (interface{}, error) { + pid := platform.Id(id) + val, err := s.loadTask(pid) if err != nil { return nil, err } - - return resp, nil + return val.getTrainingTask(ctx, taskId) } -func (s *Service) GetInferenceTask(ctx context.Context, pfId int64, id string) (interface{}, error) { - var pid = platform.Id(pfId) - tsk, found := s.tmap[pid] - if !found { - - } - resp, err := tsk.getInferenceTask(ctx, id) +func (s *Service) GetInferenceTask(ctx context.Context, id int64, taskId string) (interface{}, error) { + pid := platform.Id(id) + val, err := s.loadTask(pid) if err != nil { return nil, err } - - return resp, nil + return val.getInferenceTask(ctx, taskId) } -// dataset -func (s *Service) CreateDataset(ctx context.Context, pfId int64, param *dataset.CreateParam) (interface{}, error) { - var pid = platform.Id(pfId) - ds, found := s.dmap[pid] - if !found { - - } - resp, err := ds.Create(ctx, param) +// Dataset operations +func (s *Service) CreateDataset(ctx context.Context, id int64, param *dataset.CreateParam) (interface{}, error) { + pid := platform.Id(id) + val, err := s.loadDataset(pid) if err != nil { return nil, err } - - return resp, nil + return val.Create(ctx, param) } -// model -func (s *Service) CreateModel(ctx context.Context, pfId int64, param *model.CreateParam) (interface{}, error) { - var pid = platform.Id(pfId) - mdl, found := s.mmap[pid] - if !found { - - } - resp, err := mdl.Create(ctx, param) +// Model operations +func (s *Service) CreateModel(ctx context.Context, id int64, param *model.CreateParam) (interface{}, error) { + pid := platform.Id(id) + val, err := s.loadModel(pid) if err != nil { return nil, err } - - return resp, nil + return val.Create(ctx, param) } -func (s *Service) TestFuncRes(ctx context.Context, pfId int64) { - var pid = platform.Id(pfId) - res, found := s.rmap[pid] - if !found { +// Helper methods for loading +func (s *Service) loadResource(id platform.Id) (*Resource, error) { + return loadFromSyncMap[*Resource](&s.resourceMap, id, "resource") +} +func (s *Service) loadDataset(id platform.Id) (*Dataset, error) { + return loadFromSyncMap[*Dataset](&s.datasetMap, id, "dataset") +} + +func (s *Service) loadTask(id platform.Id) (*Task, error) { + return loadFromSyncMap[*Task](&s.taskMap, id, "task") +} + +func (s *Service) loadAlgorithm(id platform.Id) (*Algorithm, error) { + return loadFromSyncMap[*Algorithm](&s.algorithmMap, id, "algorithm") +} + +func (s *Service) loadImage(id platform.Id) (*Image, error) { + return loadFromSyncMap[*Image](&s.imageMap, id, "image") +} + +func (s *Service) loadModel(id platform.Id) (*Model, error) { + return loadFromSyncMap[*Model](&s.modelMap, id, "model") +} + +func loadFromSyncMap[T any](m *sync.Map, id platform.Id, resourceType string) (T, error) { + var zero T + val, ok := m.Load(id) + if !ok { + return zero, fmt.Errorf("%s for platform ID %d not found", resourceType, id) } - res.TrainParam(context.Background(), nil) + result, ok := val.(T) + if !ok { + return zero, fmt.Errorf("invalid %s type for platform ID %d", resourceType, id) + } + + return result, nil } diff --git a/ai/service/service_test.go b/ai/service/service_test.go index 2e3f685..b3968e1 100644 --- a/ai/service/service_test.go +++ b/ai/service/service_test.go @@ -7,7 +7,6 @@ import ( "github.com/smartystreets/goconvey/convey" aiconf "gitlink.org.cn/JointCloud/pcm-participant-ai/config" "gitlink.org.cn/JointCloud/pcm-participant-ai/platform" - "gitlink.org.cn/JointCloud/pcm-participant-octopus" openI "gitlink.org.cn/JointCloud/pcm-participant-openi" "gitlink.org.cn/JointCloud/pcm-participant-openi/common" "testing" @@ -17,11 +16,11 @@ import ( func TestService(t *testing.T) { convey.Convey("Test Service", t, func() { o, _ := openI.New(aiconf.Cfg[aiconf.OpenI].Username, aiconf.Cfg[aiconf.OpenI].Password, aiconf.Cfg[aiconf.OpenI].APIKey, platform.Id(123), aiconf.Cfg[aiconf.OpenI].DataRepo) - oct, _ := octopus.New(aiconf.Cfg[aiconf.Octopus].URL, aiconf.Cfg[aiconf.Octopus].Username, aiconf.Cfg[aiconf.Octopus].Password, platform.Id(123)) + //oct, _ := octopus.New(aiconf.Cfg[aiconf.Octopus].URL, aiconf.Cfg[aiconf.Octopus].Username, aiconf.Cfg[aiconf.Octopus].Password, platform.Id(456)) common.InitClient() - svc, err := NewService(o, oct) + svc, err := NewService(o) if err != nil { } diff --git a/client/api/ai.go b/client/api/ai.go index d683072..4282340 100644 --- a/client/api/ai.go +++ b/client/api/ai.go @@ -209,15 +209,3 @@ func (a *aiApi) CreateTrainTaskHandler(c *gin.Context) { } model.Response(c, http.StatusOK, "success", resp) } - -// TestFuncResHandler 处理测试资源相关功能的请求 -func (a *aiApi) TestFuncResHandler(c *gin.Context) { - pfIdStr := c.Query("pfId") - pfId, err := strconv.ParseInt(pfIdStr, 10, 64) - if err != nil { - model.Response(c, http.StatusBadRequest, "invalid pfId", nil) - return - } - a.service.TestFuncRes(c, pfId) - model.Response(c, http.StatusOK, "success", nil) -}