337 lines
8.5 KiB
Go
337 lines
8.5 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"gitlink.org.cn/JointCloud/pcm-participant-ai/image"
|
|
"sync"
|
|
|
|
"gitlink.org.cn/JointCloud/pcm-participant-ai/algorithm"
|
|
"gitlink.org.cn/JointCloud/pcm-participant-ai/dataset"
|
|
"gitlink.org.cn/JointCloud/pcm-participant-ai/model"
|
|
"gitlink.org.cn/JointCloud/pcm-participant-ai/platform"
|
|
"gitlink.org.cn/JointCloud/pcm-participant-ai/resource"
|
|
"gitlink.org.cn/JointCloud/pcm-participant-ai/task"
|
|
"gitlink.org.cn/JointCloud/pcm-participant-octopus"
|
|
openI "gitlink.org.cn/JointCloud/pcm-participant-openi"
|
|
)
|
|
|
|
// Service dispatches AI-platform operations to per-platform components.
//
// Each map is keyed by the platform ID that registerPlatform obtains from
// pf.Id(), and holds the wrapper built around that platform's component. The
// load* helpers read the values back as *Resource, *Dataset, *Task,
// *Algorithm, *Image and *Model respectively.
//
// sync.Map must not be copied after first use, so a Service should only be
// handled through a pointer (NewService returns *Service).
type Service struct {
	// resourceMap holds *Resource wrappers.
	resourceMap sync.Map
	// datasetMap holds *Dataset wrappers.
	datasetMap sync.Map
	// taskMap holds *Task wrappers.
	taskMap sync.Map
	// algorithmMap holds *Algorithm wrappers.
	algorithmMap sync.Map
	// imageMap holds *Image wrappers.
	imageMap sync.Map
	// modelMap holds *Model wrappers.
	modelMap sync.Map
}
|
|
|
|
func NewService(platforms ...platform.IPlatform) (*Service, error) {
|
|
s := &Service{}
|
|
for _, pf := range platforms {
|
|
if err := s.registerPlatform(pf); err != nil {
|
|
return nil, fmt.Errorf("failed to register platform %d: %w", pf.Id(), err)
|
|
}
|
|
}
|
|
return s, nil
|
|
}
|
|
|
|
// registerPlatform 注册单个平台及其所有组件
|
|
func (s *Service) registerPlatform(pf platform.IPlatform) error {
|
|
// 内部注册函数
|
|
register := func(res resource.IResource, img image.IImage, task task.Task, ds dataset.IDataset, alg algorithm.IAlgorithm, mdl model.IModel, platformName string) error {
|
|
if res == nil || img == nil || task == nil || ds == nil || alg == nil || mdl == nil {
|
|
return fmt.Errorf("%s platform missing required components", platformName)
|
|
}
|
|
|
|
s.resourceMap.Store(pf.Id(), NewResource(res))
|
|
s.imageMap.Store(pf.Id(), NewImage(img))
|
|
s.taskMap.Store(pf.Id(), NewTask(task))
|
|
s.datasetMap.Store(pf.Id(), NewDataset(ds))
|
|
s.algorithmMap.Store(pf.Id(), NewAlgorithm(alg))
|
|
s.modelMap.Store(pf.Id(), NewModel(mdl))
|
|
return nil
|
|
}
|
|
|
|
switch pt := pf.(type) {
|
|
case *openI.OpenI:
|
|
return register(pt.Res, pt.Img, pt.Task, pt.Ds, pt.Alg, pt.Mdl, "OpenI")
|
|
case *octopus.Octopus:
|
|
return register(pt.Res, pt.Img, pt.Task, pt.Ds, pt.Alg, pt.Mdl, "Octopus")
|
|
default:
|
|
return fmt.Errorf("unsupported platform type: %T", pf)
|
|
}
|
|
}
|
|
|
|
// Resource operations
|
|
func (s *Service) GetResourceSpecs(ctx context.Context, id int64, rtype string) (interface{}, error) {
|
|
pid := platform.Id(id)
|
|
val, err := s.loadResource(pid)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return val.GetResourcespecs(ctx, rtype)
|
|
}
|
|
|
|
func (s *Service) TrainResources(ctx context.Context, id int64) ([]*resource.Spec, error) {
|
|
pid := platform.Id(id)
|
|
val, err := s.loadResource(pid)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
resources, err := val.Train(ctx, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("resource training failed: %w", err)
|
|
}
|
|
|
|
return resources.Specs()
|
|
}
|
|
|
|
func (s *Service) AllTrainResources(ctx context.Context) ([]*resource.Spec, error) {
|
|
var (
|
|
all []*resource.Spec
|
|
mu sync.Mutex
|
|
wg sync.WaitGroup
|
|
errList []error
|
|
)
|
|
|
|
s.resourceMap.Range(func(key, value interface{}) bool {
|
|
wg.Add(1)
|
|
go func(id platform.Id) {
|
|
defer wg.Done()
|
|
|
|
resources, err := s.TrainResources(ctx, int64(id))
|
|
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
|
|
if err != nil {
|
|
errList = append(errList, fmt.Errorf("platform %d: %w", id, err))
|
|
return
|
|
}
|
|
|
|
if len(resources) > 0 {
|
|
all = append(all, resources...)
|
|
}
|
|
}(key.(platform.Id))
|
|
|
|
return true
|
|
})
|
|
|
|
wg.Wait()
|
|
return all, errors.Join(errList...)
|
|
}
|
|
|
|
// Algorithm operations
|
|
func (s *Service) TrainAlgorithms(ctx context.Context, id int64) ([]*algorithm.Algorithm, error) {
|
|
pid := platform.Id(id)
|
|
val, err := s.loadAlgorithm(pid)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
alg, err := val.Train(ctx, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("algorithm training failed: %w", err)
|
|
}
|
|
|
|
return alg.Algorithms()
|
|
}
|
|
|
|
func (s *Service) AllTrainAlgorithms(ctx context.Context) ([]*algorithm.Algorithm, error) {
|
|
var (
|
|
all []*algorithm.Algorithm
|
|
mu sync.Mutex
|
|
wg sync.WaitGroup
|
|
errs []error
|
|
)
|
|
|
|
s.algorithmMap.Range(func(key, value interface{}) bool {
|
|
wg.Add(1)
|
|
go func(id platform.Id) {
|
|
defer wg.Done()
|
|
|
|
algorithms, err := s.TrainAlgorithms(ctx, int64(id))
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
|
|
if err != nil {
|
|
errs = append(errs, fmt.Errorf("platform %d: %w", id, err))
|
|
return
|
|
}
|
|
|
|
if len(algorithms) > 0 {
|
|
all = append(all, algorithms...)
|
|
}
|
|
}(key.(platform.Id))
|
|
|
|
return true
|
|
})
|
|
|
|
wg.Wait()
|
|
return all, errors.Join(errs...)
|
|
}
|
|
|
|
func (s *Service) CreateAlgorithm(ctx context.Context, id int64, param *algorithm.CreateParam) (interface{}, error) {
|
|
pid := platform.Id(id)
|
|
val, err := s.loadAlgorithm(pid)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return val.Create(ctx, param)
|
|
}
|
|
|
|
func (s *Service) CreateTrainTask(ctx context.Context, param *CreateTrainTaskParam) (interface{}, error) {
|
|
if param == nil || param.Id == nil {
|
|
return nil, fmt.Errorf("invalid task parameters")
|
|
}
|
|
|
|
trainParams, err := s.generateParamsForTrainTask(ctx, *param.Id, param)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to generate train params: %w", err)
|
|
}
|
|
|
|
val, err := s.loadTask(*param.Id)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to load task: %w", err)
|
|
}
|
|
|
|
return val.createTrainingTask(ctx, trainParams)
|
|
}
|
|
|
|
func (s *Service) generateParamsForTrainTask(ctx context.Context, id platform.Id, cp *CreateTrainTaskParam) (*task.TrainParams, error) {
|
|
res, err := s.loadResource(id)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
dat, err := s.loadDataset(id)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
img, err := s.loadImage(id)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
param := &task.TrainParams{}
|
|
|
|
if resourceParam, err := res.TrainParam(ctx, cp.Resource); err != nil {
|
|
return nil, fmt.Errorf("resource param error: %w", err)
|
|
} else {
|
|
param.Resource = resourceParam
|
|
}
|
|
|
|
if datasetParam, err := dat.TrainParam(ctx, cp.Dataset); err != nil {
|
|
return nil, fmt.Errorf("dataset param error: %w", err)
|
|
} else {
|
|
param.Dataset = datasetParam
|
|
}
|
|
|
|
if imgParam, err := img.TrainParam(ctx, cp.Image); err != nil {
|
|
return nil, fmt.Errorf("image param error: %w", err)
|
|
} else {
|
|
param.Image = imgParam
|
|
}
|
|
|
|
return param, nil
|
|
}
|
|
|
|
func (s *Service) TaskResultSync(ctx context.Context, id int64, param *task.ResultSyncParam) error {
|
|
pid := platform.Id(id)
|
|
val, err := s.loadTask(pid)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return val.resultSync(ctx, param)
|
|
}
|
|
|
|
func (s *Service) TaskLog(ctx context.Context, id int64, taskId string) (interface{}, error) {
|
|
pid := platform.Id(id)
|
|
val, err := s.loadTask(pid)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return val.getLog(ctx, taskId)
|
|
}
|
|
|
|
func (s *Service) GetTrainingTask(ctx context.Context, id int64, taskId string) (interface{}, error) {
|
|
pid := platform.Id(id)
|
|
val, err := s.loadTask(pid)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return val.getTrainingTask(ctx, taskId)
|
|
}
|
|
|
|
func (s *Service) GetInferenceTask(ctx context.Context, id int64, taskId string) (interface{}, error) {
|
|
pid := platform.Id(id)
|
|
val, err := s.loadTask(pid)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return val.getInferenceTask(ctx, taskId)
|
|
}
|
|
|
|
// Dataset operations
|
|
func (s *Service) CreateDataset(ctx context.Context, id int64, param *dataset.CreateParam) (interface{}, error) {
|
|
pid := platform.Id(id)
|
|
val, err := s.loadDataset(pid)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return val.Create(ctx, param)
|
|
}
|
|
|
|
// Model operations
|
|
func (s *Service) CreateModel(ctx context.Context, id int64, param *model.CreateParam) (interface{}, error) {
|
|
pid := platform.Id(id)
|
|
val, err := s.loadModel(pid)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return val.Create(ctx, param)
|
|
}
|
|
|
|
// Helper methods for loading
|
|
func (s *Service) loadResource(id platform.Id) (*Resource, error) {
|
|
return loadFromSyncMap[*Resource](&s.resourceMap, id, "resource")
|
|
}
|
|
|
|
func (s *Service) loadDataset(id platform.Id) (*Dataset, error) {
|
|
return loadFromSyncMap[*Dataset](&s.datasetMap, id, "dataset")
|
|
}
|
|
|
|
func (s *Service) loadTask(id platform.Id) (*Task, error) {
|
|
return loadFromSyncMap[*Task](&s.taskMap, id, "task")
|
|
}
|
|
|
|
func (s *Service) loadAlgorithm(id platform.Id) (*Algorithm, error) {
|
|
return loadFromSyncMap[*Algorithm](&s.algorithmMap, id, "algorithm")
|
|
}
|
|
|
|
func (s *Service) loadImage(id platform.Id) (*Image, error) {
|
|
return loadFromSyncMap[*Image](&s.imageMap, id, "image")
|
|
}
|
|
|
|
func (s *Service) loadModel(id platform.Id) (*Model, error) {
|
|
return loadFromSyncMap[*Model](&s.modelMap, id, "model")
|
|
}
|
|
|
|
func loadFromSyncMap[T any](m *sync.Map, id platform.Id, resourceType string) (T, error) {
|
|
var zero T
|
|
val, ok := m.Load(id)
|
|
if !ok {
|
|
return zero, fmt.Errorf("%s for platform ID %d not found", resourceType, id)
|
|
}
|
|
|
|
result, ok := val.(T)
|
|
if !ok {
|
|
return zero, fmt.Errorf("invalid %s type for platform ID %d", resourceType, id)
|
|
}
|
|
|
|
return result, nil
|
|
}
|