168 lines
3.4 KiB
Go
168 lines
3.4 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"gitlink.org.cn/JointCloud/pcm-participant-ai/algorithm"
|
|
"gitlink.org.cn/JointCloud/pcm-participant-ai/errs"
|
|
)
|
|
|
|
type Algorithm struct {
|
|
algo algorithm.IAlgorithm
|
|
}
|
|
|
|
func NewAlgorithm(algo algorithm.IAlgorithm) *Algorithm {
|
|
return &Algorithm{algo: algo}
|
|
}
|
|
|
|
func (a *Algorithm) All(ctx context.Context, filter *algorithm.Filter) (algorithm.Algorithms, error) {
|
|
alg, err := a.algo.All(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(alg) == 0 {
|
|
return nil, errs.Error_Not_Found_Algorithm
|
|
}
|
|
|
|
if filter != nil {
|
|
filtered := filter.Apply(alg)
|
|
if len(filtered) == 0 {
|
|
return nil, errs.Error_Not_Found_Algorithm
|
|
}
|
|
return filtered, nil
|
|
}
|
|
|
|
return alg, nil
|
|
}
|
|
|
|
func (a *Algorithm) Train(ctx context.Context, filter *algorithm.Filter) (algorithm.Algorithms, error) {
|
|
alg, err := a.algo.Train(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(alg) == 0 {
|
|
return nil, errs.Error_Not_Found_Algorithm
|
|
}
|
|
|
|
if filter != nil {
|
|
filtered := filter.Apply(alg)
|
|
if len(filtered) == 0 {
|
|
return nil, errs.Error_Not_Found_Algorithm
|
|
}
|
|
return filtered, nil
|
|
}
|
|
|
|
return alg, nil
|
|
}
|
|
|
|
func (a *Algorithm) Infer(ctx context.Context, filter *algorithm.Filter) (algorithm.Algorithms, error) {
|
|
alg, err := a.algo.Infer(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(alg) == 0 {
|
|
return nil, errs.Error_Not_Found_Algorithm
|
|
}
|
|
|
|
if filter != nil {
|
|
filtered := filter.Apply(alg)
|
|
if len(filtered) == 0 {
|
|
return nil, errs.Error_Not_Found_Algorithm
|
|
}
|
|
return filtered, nil
|
|
}
|
|
|
|
return alg, nil
|
|
}
|
|
|
|
func (a *Algorithm) Create(ctx context.Context, param *algorithm.CreateParam) (*algorithm.CreateResp, error) {
|
|
alg, err := a.algo.Create(ctx, param)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return alg, nil
|
|
}
|
|
|
|
func (a *Algorithm) TrainParam(ctx context.Context, param *AlgorithmParam) (algorithm.TrainParameter, error) {
|
|
if param == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
filter := algorithm.NewFilter().
|
|
SetId(param.Id).
|
|
SetName(param.Name)
|
|
|
|
filtered, err := a.Train(ctx, filter)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(filtered) != 1 {
|
|
return nil, errs.Error_Not_Found_Algorithm
|
|
}
|
|
|
|
var (
|
|
chosen = filtered[0]
|
|
ftSpecNotFound = true
|
|
trainParam algorithm.TrainParameter
|
|
)
|
|
|
|
if chosen.Features().SpecBondWithId {
|
|
if withId := chosen.Features().WithId; withId != nil {
|
|
ftSpecNotFound = false
|
|
if chosen.Features().SpecHasVersionControl {
|
|
if withIdVsn := chosen.Features().WithIdVsn; withIdVsn != nil {
|
|
version, err := withIdVsn()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
trainParam = version
|
|
} else {
|
|
id, err := withId()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
trainParam = id
|
|
}
|
|
} else {
|
|
id, err := withId()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
trainParam = id
|
|
}
|
|
}
|
|
}
|
|
|
|
if ftSpecNotFound {
|
|
if chosen.Features().SpecBasedOnFileSystem {
|
|
if withFilePath := chosen.Features().WithFilePath; withFilePath != nil {
|
|
ftSpecNotFound = false
|
|
filePath, err := withFilePath()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
trainParam = filePath
|
|
}
|
|
} else if chosen.Features().SpecBasedOnRepository {
|
|
if withCodeRepo := chosen.Features().WithCodeRepo; withCodeRepo != nil {
|
|
ftSpecNotFound = false
|
|
codeRepo, err := withCodeRepo()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
trainParam = codeRepo
|
|
}
|
|
}
|
|
}
|
|
|
|
if ftSpecNotFound {
|
|
spec, err := chosen.Spec()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
trainParam = spec
|
|
}
|
|
|
|
return trainParam, nil
|
|
}
|