mirror of https://github.com/Wox-launcher/Wox
325 lines
11 KiB
Go
325 lines
11 KiB
Go
package plugin
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"path"
|
|
"sync"
|
|
"time"
|
|
"wox/ai"
|
|
"wox/common"
|
|
"wox/i18n"
|
|
"wox/setting"
|
|
"wox/util"
|
|
|
|
"github.com/samber/lo"
|
|
)
|
|
|
|
// LogLevel is the severity a plugin passes to API.Log. It is an alias of
// string (not a distinct type), so plain string values are accepted as well.
type LogLevel = string

// Recognized log levels. The values are case-sensitive string constants.
const (
	LogLevelDebug   LogLevel = "Debug"
	LogLevelInfo    LogLevel = "Info"
	LogLevelWarning LogLevel = "Warning"
	LogLevelError   LogLevel = "Error"
)
|
|
|
|
type API interface {
|
|
ChangeQuery(ctx context.Context, query common.PlainQuery)
|
|
HideApp(ctx context.Context)
|
|
ShowApp(ctx context.Context)
|
|
Notify(ctx context.Context, description string)
|
|
Log(ctx context.Context, level LogLevel, msg string)
|
|
GetTranslation(ctx context.Context, key string) string
|
|
GetSetting(ctx context.Context, key string) string
|
|
SaveSetting(ctx context.Context, key string, value string, isPlatformSpecific bool)
|
|
OnSettingChanged(ctx context.Context, callback func(key string, value string))
|
|
OnGetDynamicSetting(ctx context.Context, callback func(key string) string)
|
|
OnDeepLink(ctx context.Context, callback func(arguments map[string]string))
|
|
OnUnload(ctx context.Context, callback func())
|
|
OnMRURestore(ctx context.Context, callback func(mruData MRUData) (*QueryResult, error))
|
|
RegisterQueryCommands(ctx context.Context, commands []MetadataCommand)
|
|
AIChatStream(ctx context.Context, model common.Model, conversations []common.Conversation, options common.ChatOptions, callback common.ChatStreamFunc) error
|
|
}
|
|
|
|
// APIImpl is the API implementation bound to a single plugin instance.
// Construct it with NewAPI, which fills in every field below.
type APIImpl struct {
	// pluginInstance is the plugin this API acts for; its metadata, settings
	// and callback lists are read and appended to by the methods in this file.
	pluginInstance *Instance

	// logger is the per-plugin logger created by NewAPI in a folder named after
	// the plugin under the plugin log directory.
	logger *util.Log

	toolCallStartTimeMap *util.HashMap[string, int64] // store the start time of tool calls (tool call id -> first-seen timestamp)
}
|
|
|
|
func (a *APIImpl) ChangeQuery(ctx context.Context, query common.PlainQuery) {
|
|
GetPluginManager().GetUI().ChangeQuery(ctx, query)
|
|
}
|
|
|
|
func (a *APIImpl) HideApp(ctx context.Context) {
|
|
GetPluginManager().GetUI().HideApp(ctx)
|
|
}
|
|
|
|
func (a *APIImpl) ShowApp(ctx context.Context) {
|
|
GetPluginManager().GetUI().ShowApp(ctx, common.ShowContext{
|
|
SelectAll: true,
|
|
})
|
|
}
|
|
|
|
func (a *APIImpl) Notify(ctx context.Context, message string) {
|
|
GetPluginManager().GetUI().Notify(ctx, common.NotifyMsg{
|
|
PluginId: a.pluginInstance.Metadata.Id,
|
|
Text: a.GetTranslation(ctx, message),
|
|
DisplaySeconds: 5,
|
|
})
|
|
}
|
|
|
|
func (a *APIImpl) Log(ctx context.Context, level LogLevel, msg string) {
|
|
logCtx := util.NewComponentContext(ctx, a.pluginInstance.Metadata.Name)
|
|
if level == LogLevelError {
|
|
a.logger.Error(logCtx, msg)
|
|
logger.Error(logCtx, msg)
|
|
return
|
|
}
|
|
|
|
if level == LogLevelInfo {
|
|
a.logger.Info(logCtx, msg)
|
|
logger.Info(logCtx, msg)
|
|
return
|
|
}
|
|
|
|
if level == LogLevelDebug {
|
|
a.logger.Debug(logCtx, msg)
|
|
logger.Debug(logCtx, msg)
|
|
return
|
|
}
|
|
|
|
if level == LogLevelWarning {
|
|
a.logger.Warn(logCtx, msg)
|
|
logger.Warn(logCtx, msg)
|
|
return
|
|
}
|
|
}
|
|
|
|
func (a *APIImpl) GetTranslation(ctx context.Context, key string) string {
|
|
if a.pluginInstance.IsSystemPlugin {
|
|
return i18n.GetI18nManager().TranslateWox(ctx, key)
|
|
} else {
|
|
return i18n.GetI18nManager().TranslatePlugin(ctx, key, a.pluginInstance.PluginDirectory)
|
|
}
|
|
}
|
|
|
|
func (a *APIImpl) GetSetting(ctx context.Context, key string) string {
|
|
// try to get platform specific setting first
|
|
platformSpecificKey := key + "@" + util.GetCurrentPlatform()
|
|
v, exist := a.pluginInstance.Setting.Get(platformSpecificKey)
|
|
if exist {
|
|
return v
|
|
}
|
|
|
|
v, exist = a.pluginInstance.Setting.Get(key)
|
|
if exist {
|
|
return v
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func (a *APIImpl) SaveSetting(ctx context.Context, key string, value string, isPlatformSpecific bool) {
|
|
finalKey := key
|
|
if isPlatformSpecific {
|
|
finalKey = key + "@" + util.GetCurrentPlatform()
|
|
} else {
|
|
// if not platform specific, remove platform specific setting, otherwise it will be loaded first
|
|
a.pluginInstance.Setting.Delete(key + "@" + util.GetCurrentPlatform())
|
|
}
|
|
|
|
existValue, exist := a.pluginInstance.Setting.Get(finalKey)
|
|
a.pluginInstance.Setting.Set(finalKey, value)
|
|
if !exist || (existValue != value) {
|
|
for _, callback := range a.pluginInstance.SettingChangeCallbacks {
|
|
callback(key, value)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (a *APIImpl) OnSettingChanged(ctx context.Context, callback func(key string, value string)) {
|
|
a.pluginInstance.SettingChangeCallbacks = append(a.pluginInstance.SettingChangeCallbacks, callback)
|
|
}
|
|
|
|
func (a *APIImpl) OnGetDynamicSetting(ctx context.Context, callback func(key string) string) {
|
|
a.pluginInstance.DynamicSettingCallbacks = append(a.pluginInstance.DynamicSettingCallbacks, callback)
|
|
}
|
|
|
|
func (a *APIImpl) OnDeepLink(ctx context.Context, callback func(arguments map[string]string)) {
|
|
if !a.pluginInstance.Metadata.IsSupportFeature(MetadataFeatureDeepLink) {
|
|
a.Log(ctx, LogLevelError, "plugin has no access to deep link feature")
|
|
return
|
|
}
|
|
|
|
a.pluginInstance.DeepLinkCallbacks = append(a.pluginInstance.DeepLinkCallbacks, callback)
|
|
}
|
|
|
|
func (a *APIImpl) OnUnload(ctx context.Context, callback func()) {
|
|
a.pluginInstance.UnloadCallbacks = append(a.pluginInstance.UnloadCallbacks, callback)
|
|
}
|
|
|
|
func (a *APIImpl) RegisterQueryCommands(ctx context.Context, commands []MetadataCommand) {
|
|
a.pluginInstance.Setting.QueryCommands.Set(lo.Map(commands, func(command MetadataCommand, _ int) setting.PluginQueryCommand {
|
|
return setting.PluginQueryCommand{
|
|
Command: command.Command,
|
|
Description: command.Description,
|
|
}
|
|
}))
|
|
}
|
|
|
|
func (a *APIImpl) AIChatStream(ctx context.Context, model common.Model, conversations []common.Conversation, options common.ChatOptions, callback common.ChatStreamFunc) error {
|
|
//check if plugin has the feature permission
|
|
if !a.pluginInstance.Metadata.IsSupportFeature(MetadataFeatureAI) {
|
|
return fmt.Errorf("plugin has no access to ai feature")
|
|
}
|
|
|
|
provider, providerErr := GetPluginManager().GetAIProvider(ctx, model.Provider)
|
|
if providerErr != nil {
|
|
return providerErr
|
|
}
|
|
|
|
// // resize images in the conversation
|
|
// for i, conversation := range conversations {
|
|
// for j, image := range conversation.Images {
|
|
// image.Resize(600)
|
|
// resizeImage(ctx, image, 600)
|
|
|
|
// // resize image if it's too large
|
|
// maxWidth := 600
|
|
// if image.Bounds().Dx() > maxWidth {
|
|
// start := util.GetSystemTimestamp()
|
|
// conversations[i].Images[j] = imaging.Resize(image, maxWidth, 0, imaging.Lanczos)
|
|
// a.Log(ctx, LogLevelDebug, fmt.Sprintf("resizing image (%d -> %d) in ai chat, cost %d ms", image.Bounds().Dx(), maxWidth, util.GetSystemTimestamp()-start))
|
|
// }
|
|
// }
|
|
// }
|
|
|
|
stream, err := provider.ChatStream(ctx, model, conversations, options)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if callback != nil {
|
|
util.Go(ctx, "ai chat stream", func() {
|
|
for {
|
|
streamResult, streamErr := stream.Receive(ctx)
|
|
if streamErr != nil {
|
|
// may be for loop too fast
|
|
if streamErr == ai.ChatStreamNoContentErr {
|
|
time.Sleep(time.Millisecond * 200)
|
|
continue
|
|
}
|
|
|
|
util.GetLogger().Info(ctx, fmt.Sprintf("AI: failed to read stream from ai provider: %s", streamErr.Error()))
|
|
callback(common.ChatStreamData{
|
|
Status: common.ChatStreamStatusError,
|
|
Data: streamErr.Error(),
|
|
ToolCalls: []common.ToolCallInfo{},
|
|
})
|
|
return
|
|
}
|
|
|
|
util.GetLogger().Debug(ctx, fmt.Sprintf("AI: Received stream from ai provider: status=%s, data=%s, tool calls=%d", streamResult.Status, streamResult.Data, len(streamResult.ToolCalls)))
|
|
|
|
a.applyStartTimeIfAbsent(&streamResult)
|
|
|
|
if streamResult.Status == common.ChatStreamStatusStreaming {
|
|
callback(streamResult)
|
|
continue
|
|
}
|
|
|
|
if streamResult.Status == common.ChatStreamStatusStreamed {
|
|
// execute tool calls
|
|
// we execute tool calls asynchronously, but wait for all tool calls to finish before sending the final result
|
|
var sw = sync.WaitGroup{}
|
|
|
|
for toolCallIndex, toolCall := range streamResult.ToolCalls {
|
|
util.GetLogger().Info(ctx, fmt.Sprintf("AI: Tool call is pending to execute, name: %s, args: %v", toolCall.Name, toolCall.Arguments))
|
|
|
|
for _, tool := range options.Tools {
|
|
if tool.Name == toolCall.Name {
|
|
sw.Add(1)
|
|
|
|
util.GetLogger().Info(ctx, fmt.Sprintf("AI: Executing tool: %s with args: %v, toolcall id: %s, toolcall status: %s", tool.Name, toolCall.Arguments, toolCall.Id, toolCall.Status))
|
|
|
|
// update tool call status to running and sync to caller
|
|
streamResult.Status = common.ChatStreamStatusRunningToolCall
|
|
streamResult.ToolCalls[toolCallIndex].Status = common.ToolCallStatusRunning
|
|
|
|
util.Go(ctx, "ai tool call execution", func() {
|
|
toolResponse, toolErr := tool.Callback(ctx, toolCall.Arguments)
|
|
if toolErr != nil {
|
|
util.GetLogger().Error(ctx, fmt.Sprintf("AI: tool execution failed: %s", toolErr.Error()))
|
|
streamResult.ToolCalls[toolCallIndex].Status = common.ToolCallStatusFailed
|
|
streamResult.ToolCalls[toolCallIndex].Response = toolErr.Error()
|
|
} else {
|
|
streamResult.ToolCalls[toolCallIndex].Status = common.ToolCallStatusSucceeded
|
|
streamResult.ToolCalls[toolCallIndex].Response = toolResponse.Text
|
|
streamResult.ToolCalls[toolCallIndex].EndTimestamp = util.GetSystemTimestamp()
|
|
}
|
|
|
|
callback(streamResult)
|
|
sw.Done()
|
|
}, func() {
|
|
util.GetLogger().Error(ctx, fmt.Sprintf("AI: tool execution failed with panic, name: %s", tool.Name))
|
|
streamResult.ToolCalls[toolCallIndex].Status = common.ToolCallStatusFailed
|
|
streamResult.ToolCalls[toolCallIndex].Response = "tool execution failed with panic"
|
|
|
|
callback(streamResult)
|
|
sw.Done()
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
sw.Wait()
|
|
|
|
anyToolCallFailed := lo.SomeBy(streamResult.ToolCalls, func(toolCall common.ToolCallInfo) bool {
|
|
return toolCall.Status == common.ToolCallStatusFailed
|
|
})
|
|
if anyToolCallFailed {
|
|
streamResult.Status = common.ChatStreamStatusError
|
|
callback(streamResult)
|
|
} else {
|
|
streamResult.Status = common.ChatStreamStatusFinished
|
|
callback(streamResult)
|
|
}
|
|
return
|
|
}
|
|
}
|
|
})
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (a *APIImpl) applyStartTimeIfAbsent(streamResult *common.ChatStreamData) {
|
|
for toolCallIndex, toolCall := range streamResult.ToolCalls {
|
|
startTime := util.GetSystemTimestamp()
|
|
if v, ok := a.toolCallStartTimeMap.Load(toolCall.Id); ok {
|
|
startTime = v
|
|
} else {
|
|
a.toolCallStartTimeMap.Store(toolCall.Id, startTime)
|
|
}
|
|
streamResult.ToolCalls[toolCallIndex].StartTimestamp = startTime
|
|
}
|
|
}
|
|
|
|
func (a *APIImpl) OnMRURestore(ctx context.Context, callback func(mruData MRUData) (*QueryResult, error)) {
|
|
if !a.pluginInstance.Metadata.IsSupportFeature(MetadataFeatureMRU) {
|
|
a.Log(ctx, LogLevelError, "plugin has no access to MRU feature")
|
|
return
|
|
}
|
|
|
|
a.pluginInstance.MRURestoreCallbacks = append(a.pluginInstance.MRURestoreCallbacks, callback)
|
|
}
|
|
|
|
func NewAPI(instance *Instance) API {
|
|
apiImpl := &APIImpl{pluginInstance: instance}
|
|
logFolder := path.Join(util.GetLocation().GetLogPluginDirectory(), instance.Metadata.Name)
|
|
apiImpl.logger = util.CreateLogger(logFolder)
|
|
apiImpl.toolCallStartTimeMap = util.NewHashMap[string, int64]()
|
|
return apiImpl
|
|
}
|