pcm-participant/ai/service/algorithm.go

168 lines
3.4 KiB
Go

package service
import (
"context"
"gitlink.org.cn/JointCloud/pcm-participant-ai/algorithm"
"gitlink.org.cn/JointCloud/pcm-participant-ai/errs"
)
// Algorithm is a service-layer wrapper around an algorithm.IAlgorithm
// backend. Its methods forward to the backend and post-process the results.
type Algorithm struct {
	algo algorithm.IAlgorithm // backend every method delegates to
}

// NewAlgorithm returns an Algorithm service that delegates to algo.
func NewAlgorithm(algo algorithm.IAlgorithm) *Algorithm {
	svc := new(Algorithm)
	svc.algo = algo
	return svc
}
// All lists every algorithm reported by the backend, optionally narrowed by
// filter (a nil filter means "no filtering").
//
// Backend errors are returned unchanged. errs.Error_Not_Found_Algorithm is
// returned when the backend reports no algorithms, or when filter removes
// all of them.
func (a *Algorithm) All(ctx context.Context, filter *algorithm.Filter) (algorithm.Algorithms, error) {
	algs, err := a.algo.All(ctx)
	return narrowAlgorithms(algs, err, filter)
}

// Train lists the backend's training algorithms, optionally narrowed by
// filter (nil means no filtering).
//
// Error behavior is the same as All: backend errors pass through, and an
// empty listing or empty filtered result yields
// errs.Error_Not_Found_Algorithm.
func (a *Algorithm) Train(ctx context.Context, filter *algorithm.Filter) (algorithm.Algorithms, error) {
	algs, err := a.algo.Train(ctx)
	return narrowAlgorithms(algs, err, filter)
}

// Infer lists the backend's inference algorithms, optionally narrowed by
// filter (nil means no filtering).
//
// Error behavior is the same as All: backend errors pass through, and an
// empty listing or empty filtered result yields
// errs.Error_Not_Found_Algorithm.
func (a *Algorithm) Infer(ctx context.Context, filter *algorithm.Filter) (algorithm.Algorithms, error) {
	algs, err := a.algo.Infer(ctx)
	return narrowAlgorithms(algs, err, filter)
}

// narrowAlgorithms is the shared post-processing for All, Train and Infer.
//
// It returns (nil, err) when the backend call failed. It reports an empty
// listing as errs.Error_Not_Found_Algorithm. When filter is non-nil, it
// applies the filter and reports an empty result the same way. On any
// error the returned slice is nil, so callers never see a partial result.
func narrowAlgorithms(algs algorithm.Algorithms, err error, filter *algorithm.Filter) (algorithm.Algorithms, error) {
	if err != nil {
		return nil, err
	}
	if len(algs) == 0 {
		return nil, errs.Error_Not_Found_Algorithm
	}
	if filter == nil {
		return algs, nil
	}
	filtered := filter.Apply(algs)
	if len(filtered) == 0 {
		return nil, errs.Error_Not_Found_Algorithm
	}
	return filtered, nil
}
// Create forwards param to the backend's Create.
//
// On success it returns the backend's response. On failure it returns the
// backend error together with a nil response, discarding anything the
// backend may have returned alongside the error.
func (a *Algorithm) Create(ctx context.Context, param *algorithm.CreateParam) (*algorithm.CreateResp, error) {
	resp, err := a.algo.Create(ctx, param)
	if err == nil {
		return resp, nil
	}
	return nil, err
}
// TrainParam resolves the training parameter for the single training
// algorithm selected by param.Id and param.Name (via Train with a filter).
//
// The parameter source is picked from the algorithm's Features, first match
// wins:
//  1. SpecBondWithId with a WithId provider: WithIdVsn() when
//     SpecHasVersionControl is set and WithIdVsn is provided, else WithId().
//  2. SpecBasedOnFileSystem: WithFilePath(), if provided.
//  3. Otherwise, SpecBasedOnRepository: WithCodeRepo(), if provided.
//  4. Fallback: the algorithm's Spec().
//
// Cases 2 and 3 are exclusive. When SpecBasedOnFileSystem is set but no
// WithFilePath provider exists, the repository source is not consulted and
// Spec() is used.
//
// Errors from Train and from whichever provider is invoked are returned
// unchanged. errs.Error_Not_Found_Algorithm is returned when the filter does
// not select exactly one training algorithm.
func (a *Algorithm) TrainParam(ctx context.Context, param *AlgorithmParam) (algorithm.TrainParameter, error) {
	// NOTE(review): a nil param yields (nil, nil), so callers must nil-check
	// the result. Preserved because existing callers may rely on it.
	if param == nil {
		return nil, nil
	}
	filter := algorithm.NewFilter().
		SetId(param.Id).
		SetName(param.Name)
	filtered, err := a.Train(ctx, filter)
	if err != nil {
		return nil, err
	}
	// Ambiguous selections (more than one match) are also reported as
	// "not found".
	if len(filtered) != 1 {
		return nil, errs.Error_Not_Found_Algorithm
	}

	chosen := filtered[0]
	// Read the feature set once instead of re-fetching it for every flag.
	ft := chosen.Features()

	switch {
	case ft.SpecBondWithId && ft.WithId != nil:
		if ft.SpecHasVersionControl && ft.WithIdVsn != nil {
			version, err := ft.WithIdVsn()
			if err != nil {
				return nil, err
			}
			return version, nil
		}
		id, err := ft.WithId()
		if err != nil {
			return nil, err
		}
		return id, nil
	case ft.SpecBasedOnFileSystem:
		// Without a file-path provider, fall through to Spec() below. The
		// repository source is intentionally not tried for this spec kind.
		if ft.WithFilePath != nil {
			filePath, err := ft.WithFilePath()
			if err != nil {
				return nil, err
			}
			return filePath, nil
		}
	case ft.SpecBasedOnRepository:
		if ft.WithCodeRepo != nil {
			codeRepo, err := ft.WithCodeRepo()
			if err != nil {
				return nil, err
			}
			return codeRepo, nil
		}
	}

	// No feature-specific provider applied: fall back to the generic spec.
	spec, err := chosen.Spec()
	if err != nil {
		return nil, err
	}
	return spec, nil
}