pcm-participant/ai/service/service.go

337 lines
8.5 KiB
Go

package service
import (
"context"
"errors"
"fmt"
"gitlink.org.cn/JointCloud/pcm-participant-ai/image"
"sync"
"gitlink.org.cn/JointCloud/pcm-participant-ai/algorithm"
"gitlink.org.cn/JointCloud/pcm-participant-ai/dataset"
"gitlink.org.cn/JointCloud/pcm-participant-ai/model"
"gitlink.org.cn/JointCloud/pcm-participant-ai/platform"
"gitlink.org.cn/JointCloud/pcm-participant-ai/resource"
"gitlink.org.cn/JointCloud/pcm-participant-ai/task"
"gitlink.org.cn/JointCloud/pcm-participant-octopus"
openI "gitlink.org.cn/JointCloud/pcm-participant-openi"
)
// Service routes AI operations to the components of registered platforms.
//
// Every field is a sync.Map keyed by the platform's ID. registerPlatform fills
// all six maps for a platform in one go, and the load* helpers read the
// wrappers back (e.g. resourceMap entries are read as *Resource).
type Service struct {
	resourceMap, datasetMap, taskMap sync.Map
	algorithmMap, imageMap, modelMap sync.Map
}
// NewService builds a Service and registers every supplied platform with it.
//
// Registration stops at the first platform that cannot be registered; the
// partially populated Service is then discarded and a wrapped error is
// returned. A nil entry in platforms is rejected with an error instead of
// being dereferenced.
func NewService(platforms ...platform.IPlatform) (*Service, error) {
	s := &Service{}
	for i, pf := range platforms {
		// Calling pf.Id() on a nil interface panics, and the error path below
		// does exactly that, so nil entries must be caught before any use.
		if pf == nil {
			return nil, fmt.Errorf("failed to register platform at index %d: platform is nil", i)
		}
		if err := s.registerPlatform(pf); err != nil {
			return nil, fmt.Errorf("failed to register platform %d: %w", pf.Id(), err)
		}
	}
	return s, nil
}
// registerPlatform registers a single platform together with all of its
// components.
//
// Only *openI.OpenI and *octopus.Octopus are supported; any other dynamic type
// (including a nil interface) yields an "unsupported platform type" error. A
// nil pointer of a supported type is rejected before its fields are read. All
// six components (resource, image, task, dataset, algorithm, model) must be
// non-nil, otherwise nothing is stored and an error is returned.
func (s *Service) registerPlatform(pf platform.IPlatform) error {
	// register validates the components and stores one wrapper per component
	// under the platform's ID. The task parameter is named tsk so that it does
	// not shadow the imported task package.
	register := func(res resource.IResource, img image.IImage, tsk task.Task, ds dataset.IDataset, alg algorithm.IAlgorithm, mdl model.IModel, platformName string) error {
		if res == nil || img == nil || tsk == nil || ds == nil || alg == nil || mdl == nil {
			return fmt.Errorf("%s platform missing required components", platformName)
		}
		// Resolve the ID once so that all six maps share the same key.
		id := pf.Id()
		s.resourceMap.Store(id, NewResource(res))
		s.imageMap.Store(id, NewImage(img))
		s.taskMap.Store(id, NewTask(tsk))
		s.datasetMap.Store(id, NewDataset(ds))
		s.algorithmMap.Store(id, NewAlgorithm(alg))
		s.modelMap.Store(id, NewModel(mdl))
		return nil
	}

	switch pt := pf.(type) {
	case *openI.OpenI:
		// A typed nil pointer satisfies the type switch but would panic on
		// the field accesses below.
		if pt == nil {
			return fmt.Errorf("%s platform is nil", "OpenI")
		}
		return register(pt.Res, pt.Img, pt.Task, pt.Ds, pt.Alg, pt.Mdl, "OpenI")
	case *octopus.Octopus:
		if pt == nil {
			return fmt.Errorf("%s platform is nil", "Octopus")
		}
		return register(pt.Res, pt.Img, pt.Task, pt.Ds, pt.Alg, pt.Mdl, "Octopus")
	default:
		return fmt.Errorf("unsupported platform type: %T", pf)
	}
}
// Resource operations

// GetResourceSpecs looks up the resource wrapper registered for platform id
// and returns the result of its GetResourcespecs call for rtype. An error is
// returned when no resource wrapper is registered for that platform.
func (s *Service) GetResourceSpecs(ctx context.Context, id int64, rtype string) (interface{}, error) {
	res, err := s.loadResource(platform.Id(id))
	if err != nil {
		return nil, err
	}
	return res.GetResourcespecs(ctx, rtype)
}
// TrainResources calls Train (with a nil second argument) on the resource
// wrapper of platform id and returns the specs reported by the result.
// A Train failure is wrapped as "resource training failed".
func (s *Service) TrainResources(ctx context.Context, id int64) ([]*resource.Spec, error) {
	res, err := s.loadResource(platform.Id(id))
	if err != nil {
		return nil, err
	}

	trained, trainErr := res.Train(ctx, nil)
	if trainErr != nil {
		return nil, fmt.Errorf("resource training failed: %w", trainErr)
	}
	return trained.Specs()
}
// AllTrainResources runs TrainResources for every platform present in
// resourceMap, one goroutine per platform, and merges the outcomes.
//
// The returned slice holds the specs of all platforms that succeeded, in no
// particular order. Failures are prefixed with the platform ID and combined
// with errors.Join, so the error is nil only when every platform succeeded.
func (s *Service) AllTrainResources(ctx context.Context) ([]*resource.Spec, error) {
	// Snapshot the registered platform IDs first, then fan out.
	var ids []platform.Id
	s.resourceMap.Range(func(key, _ interface{}) bool {
		ids = append(ids, key.(platform.Id))
		return true
	})

	var (
		guard    sync.Mutex // protects specs and failures
		pending  sync.WaitGroup
		specs    []*resource.Spec
		failures []error
	)
	pending.Add(len(ids))
	for _, pid := range ids {
		go func(pid platform.Id) {
			defer pending.Done()
			found, err := s.TrainResources(ctx, int64(pid))

			guard.Lock()
			defer guard.Unlock()
			if err != nil {
				failures = append(failures, fmt.Errorf("platform %d: %w", pid, err))
				return
			}
			if len(found) == 0 {
				return
			}
			specs = append(specs, found...)
		}(pid)
	}
	pending.Wait()

	return specs, errors.Join(failures...)
}
// Algorithm operations

// TrainAlgorithms calls Train (with a nil second argument) on the algorithm
// wrapper of platform id and returns the algorithms reported by the result.
// A Train failure is wrapped as "algorithm training failed".
func (s *Service) TrainAlgorithms(ctx context.Context, id int64) ([]*algorithm.Algorithm, error) {
	algo, err := s.loadAlgorithm(platform.Id(id))
	if err != nil {
		return nil, err
	}

	trained, trainErr := algo.Train(ctx, nil)
	if trainErr != nil {
		return nil, fmt.Errorf("algorithm training failed: %w", trainErr)
	}
	return trained.Algorithms()
}
// AllTrainAlgorithms runs TrainAlgorithms for every platform present in
// algorithmMap, one goroutine per platform, and merges the outcomes.
//
// The returned slice holds the algorithms of all platforms that succeeded, in
// no particular order. Failures are prefixed with the platform ID and combined
// with errors.Join, so the error is nil only when every platform succeeded.
func (s *Service) AllTrainAlgorithms(ctx context.Context) ([]*algorithm.Algorithm, error) {
	// Snapshot the registered platform IDs first, then fan out.
	var ids []platform.Id
	s.algorithmMap.Range(func(key, _ interface{}) bool {
		ids = append(ids, key.(platform.Id))
		return true
	})

	var (
		guard     sync.Mutex // protects collected and failures
		pending   sync.WaitGroup
		collected []*algorithm.Algorithm
		failures  []error
	)
	pending.Add(len(ids))
	for _, pid := range ids {
		go func(pid platform.Id) {
			defer pending.Done()
			found, err := s.TrainAlgorithms(ctx, int64(pid))

			guard.Lock()
			defer guard.Unlock()
			if err != nil {
				failures = append(failures, fmt.Errorf("platform %d: %w", pid, err))
				return
			}
			if len(found) == 0 {
				return
			}
			collected = append(collected, found...)
		}(pid)
	}
	pending.Wait()

	return collected, errors.Join(failures...)
}
// CreateAlgorithm forwards param to the Create method of the algorithm wrapper
// registered for platform id and returns its result unchanged. An error is
// returned when no algorithm wrapper is registered for that platform.
func (s *Service) CreateAlgorithm(ctx context.Context, id int64, param *algorithm.CreateParam) (interface{}, error) {
	algo, err := s.loadAlgorithm(platform.Id(id))
	if err != nil {
		return nil, err
	}
	return algo.Create(ctx, param)
}
// CreateTrainTask builds the training parameters for the platform referenced
// by param.Id and submits them through that platform's task wrapper.
//
// A nil param or a nil param.Id is rejected up front. Parameter generation
// runs before the task wrapper is looked up, and a failure in either step is
// wrapped with a message naming the step.
func (s *Service) CreateTrainTask(ctx context.Context, param *CreateTrainTaskParam) (interface{}, error) {
	if param == nil || param.Id == nil {
		return nil, fmt.Errorf("invalid task parameters")
	}
	pid := *param.Id

	trainParams, genErr := s.generateParamsForTrainTask(ctx, pid, param)
	if genErr != nil {
		return nil, fmt.Errorf("failed to generate train params: %w", genErr)
	}

	runner, loadErr := s.loadTask(pid)
	if loadErr != nil {
		return nil, fmt.Errorf("failed to load task: %w", loadErr)
	}
	return runner.createTrainingTask(ctx, trainParams)
}
// generateParamsForTrainTask assembles the task.TrainParams for platform id
// from the resource, dataset and image selections carried by cp.
//
// All three wrappers are looked up before any TrainParam call is made, so a
// missing wrapper fails fast. The TrainParam calls then run in the order
// resource, dataset, image, and the first failure is returned wrapped with the
// name of the component that produced it.
func (s *Service) generateParamsForTrainTask(ctx context.Context, id platform.Id, cp *CreateTrainTaskParam) (*task.TrainParams, error) {
	resWrap, err := s.loadResource(id)
	if err != nil {
		return nil, err
	}
	dsWrap, err := s.loadDataset(id)
	if err != nil {
		return nil, err
	}
	imgWrap, err := s.loadImage(id)
	if err != nil {
		return nil, err
	}

	resourceParam, err := resWrap.TrainParam(ctx, cp.Resource)
	if err != nil {
		return nil, fmt.Errorf("resource param error: %w", err)
	}
	datasetParam, err := dsWrap.TrainParam(ctx, cp.Dataset)
	if err != nil {
		return nil, fmt.Errorf("dataset param error: %w", err)
	}
	imageParam, err := imgWrap.TrainParam(ctx, cp.Image)
	if err != nil {
		return nil, fmt.Errorf("image param error: %w", err)
	}

	return &task.TrainParams{
		Resource: resourceParam,
		Dataset:  datasetParam,
		Image:    imageParam,
	}, nil
}
// TaskResultSync forwards param to the resultSync method of the task wrapper
// registered for platform id. An error is returned when no task wrapper is
// registered for that platform.
func (s *Service) TaskResultSync(ctx context.Context, id int64, param *task.ResultSyncParam) error {
	runner, err := s.loadTask(platform.Id(id))
	if err != nil {
		return err
	}
	return runner.resultSync(ctx, param)
}
// TaskLog returns the result of getLog for taskId on the task wrapper
// registered for platform id. An error is returned when no task wrapper is
// registered for that platform.
func (s *Service) TaskLog(ctx context.Context, id int64, taskId string) (interface{}, error) {
	runner, err := s.loadTask(platform.Id(id))
	if err != nil {
		return nil, err
	}
	return runner.getLog(ctx, taskId)
}
// GetTrainingTask returns the result of getTrainingTask for taskId on the task
// wrapper registered for platform id. An error is returned when no task
// wrapper is registered for that platform.
func (s *Service) GetTrainingTask(ctx context.Context, id int64, taskId string) (interface{}, error) {
	runner, err := s.loadTask(platform.Id(id))
	if err != nil {
		return nil, err
	}
	return runner.getTrainingTask(ctx, taskId)
}
// GetInferenceTask returns the result of getInferenceTask for taskId on the
// task wrapper registered for platform id. An error is returned when no task
// wrapper is registered for that platform.
func (s *Service) GetInferenceTask(ctx context.Context, id int64, taskId string) (interface{}, error) {
	runner, err := s.loadTask(platform.Id(id))
	if err != nil {
		return nil, err
	}
	return runner.getInferenceTask(ctx, taskId)
}
// Dataset operations

// CreateDataset forwards param to the Create method of the dataset wrapper
// registered for platform id and returns its result unchanged. An error is
// returned when no dataset wrapper is registered for that platform.
func (s *Service) CreateDataset(ctx context.Context, id int64, param *dataset.CreateParam) (interface{}, error) {
	ds, err := s.loadDataset(platform.Id(id))
	if err != nil {
		return nil, err
	}
	return ds.Create(ctx, param)
}
// Model operations

// CreateModel forwards param to the Create method of the model wrapper
// registered for platform id and returns its result unchanged. An error is
// returned when no model wrapper is registered for that platform.
func (s *Service) CreateModel(ctx context.Context, id int64, param *model.CreateParam) (interface{}, error) {
	mdl, err := s.loadModel(platform.Id(id))
	if err != nil {
		return nil, err
	}
	return mdl.Create(ctx, param)
}
// Helper methods for loading

// loadResource returns the *Resource stored in resourceMap for platform id, or
// an error when the entry is absent or has a different type.
func (s *Service) loadResource(id platform.Id) (*Resource, error) {
	registry := &s.resourceMap
	return loadFromSyncMap[*Resource](registry, id, "resource")
}
// loadDataset returns the *Dataset stored in datasetMap for platform id, or an
// error when the entry is absent or has a different type.
func (s *Service) loadDataset(id platform.Id) (*Dataset, error) {
	registry := &s.datasetMap
	return loadFromSyncMap[*Dataset](registry, id, "dataset")
}
// loadTask returns the *Task stored in taskMap for platform id, or an error
// when the entry is absent or has a different type.
func (s *Service) loadTask(id platform.Id) (*Task, error) {
	registry := &s.taskMap
	return loadFromSyncMap[*Task](registry, id, "task")
}
// loadAlgorithm returns the *Algorithm stored in algorithmMap for platform id,
// or an error when the entry is absent or has a different type.
func (s *Service) loadAlgorithm(id platform.Id) (*Algorithm, error) {
	registry := &s.algorithmMap
	return loadFromSyncMap[*Algorithm](registry, id, "algorithm")
}
// loadImage returns the *Image stored in imageMap for platform id, or an error
// when the entry is absent or has a different type.
func (s *Service) loadImage(id platform.Id) (*Image, error) {
	registry := &s.imageMap
	return loadFromSyncMap[*Image](registry, id, "image")
}
// loadModel returns the *Model stored in modelMap for platform id, or an error
// when the entry is absent or has a different type.
func (s *Service) loadModel(id platform.Id) (*Model, error) {
	registry := &s.modelMap
	return loadFromSyncMap[*Model](registry, id, "model")
}
// loadFromSyncMap fetches the value stored under id in m and returns it as T.
//
// resourceType is used only to label the error messages. The zero value of T
// is returned together with an error when id has no entry in m, or when the
// stored value is not of type T.
func loadFromSyncMap[T any](m *sync.Map, id platform.Id, resourceType string) (T, error) {
	stored, found := m.Load(id)
	if !found {
		var none T
		return none, fmt.Errorf("%s for platform ID %d not found", resourceType, id)
	}

	if typed, matches := stored.(T); matches {
		return typed, nil
	}

	var none T
	return none, fmt.Errorf("invalid %s type for platform ID %d", resourceType, id)
}