plug-sdk/pkg/plug/grpc.go
SoXX 38cecba455
All checks were successful
Gitea Build Check / Build (push) Successful in 48s
Gitea Build Check / Build (pull_request) Successful in 50s
fix(grpc): spelling
2024-09-04 14:56:27 +02:00

251 lines
7.7 KiB
Go

package plug
import (
	"context"
	"errors"
	"sync"

	"git.anthrove.art/Anthrove/otter-space-sdk/v4/pkg/database"
	"git.anthrove.art/Anthrove/otter-space-sdk/v4/pkg/models"
	gRPC "git.anthrove.art/Anthrove/plug-sdk/v3/pkg/grpc"
	gonanoid "github.com/matoous/go-nanoid/v2"
	log "github.com/sirupsen/logrus"
	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/codes"
	"go.opentelemetry.io/otel/trace"
)
// server implements gRPC.PlugConnectorServer on top of the callbacks supplied
// to NewGrpcServer.
type server struct {
	gRPC.UnimplementedPlugConnectorServer

	// mu guards ctx. The map is touched by the RPC handlers and by the task
	// goroutines (through the done callback handed to the task function), which
	// run concurrently.
	mu sync.Mutex
	// ctx maps the ID of every registered task to the function that cancels
	// the task's context.
	ctx map[string]context.CancelFunc

	taskExecutionFunction TaskExecution
	sendMessageExecution  SendMessageExecution
	getMessageExecution   GetMessageExecution
	source                models.Source
}

// NewGrpcServer returns a gRPC.PlugConnectorServer for the given source that
// delegates task execution and message handling to the supplied callbacks.
func NewGrpcServer(source models.Source, taskExecutionFunction TaskExecution, sendMessageExecution SendMessageExecution, getMessageExecution GetMessageExecution) gRPC.PlugConnectorServer {
	return &server{
		ctx:                   make(map[string]context.CancelFunc),
		taskExecutionFunction: taskExecutionFunction,
		sendMessageExecution:  sendMessageExecution,
		getMessageExecution:   getMessageExecution,
		source:                source,
	}
}

// TaskStart generates a task ID, loads the requested user source and, if the
// account is validated, runs the task execution function in a new goroutine.
//
// It returns the task ID with state RUNNING once the goroutine is launched.
// A nil status and the error are returned when ID generation or the user
// source lookup fails; a STOPPED status together with an error is returned
// when the user source is not validated.
func (s *server) TaskStart(ctx context.Context, creation *gRPC.PlugTaskCreation) (*gRPC.PlugTaskStatus, error) {
	ctx, span := tracer.Start(ctx, "TaskStart")
	defer span.End()

	var plugTaskState gRPC.PlugTaskStatus

	id, err := gonanoid.New()
	if err != nil {
		span.RecordError(err)
		span.SetStatus(codes.Error, err.Error())
		return nil, err
	}
	span.AddEvent("Generated task ID", trace.WithAttributes(attribute.String("task_id", id)))

	plugTaskState.TaskId = id
	plugTaskState.TaskState = gRPC.PlugTaskState_RUNNING

	userSource, err := database.GetUserSourceByID(ctx, models.UserSourceID(creation.UserSourceId))
	if err != nil {
		span.RecordError(err)
		span.SetStatus(codes.Error, err.Error())
		return nil, err
	}
	span.AddEvent("Retrieved user source", trace.WithAttributes(attribute.String("user_source_id", creation.UserSourceId)))

	if !userSource.AccountValidate {
		err = errors.New("user is not validated")
		log.WithContext(ctx).WithError(err).WithField("task_id", id).Error("Task execution failed")
		span.RecordError(err)
		span.SetStatus(codes.Error, err.Error())
		plugTaskState.TaskState = gRPC.PlugTaskState_STOPPED
		return &plugTaskState, err
	}

	// gRPC cancels the request context as soon as the call returns, which would
	// stop the scraping immediately. Every task therefore gets its own context,
	// detached from the request and carrying only the request's trace ID. Its
	// cancel function is stored under the task ID so the task can be stopped later.
	detached := trace.ContextWithSpanContext(context.Background(), trace.NewSpanContext(trace.SpanContextConfig{TraceID: span.SpanContext().TraceID()}))
	taskCtx, cancel := context.WithCancel(detached)

	s.mu.Lock()
	s.ctx[id] = cancel
	s.mu.Unlock()

	span.AddEvent("Created new context for task", trace.WithAttributes(attribute.String("task_id", id)))

	log.WithContext(taskCtx).WithFields(log.Fields{
		"task_id":        id,
		"user_source_id": creation.UserSourceId,
	}).Debug("Starting task")

	go func() {
		// The request span is ended when TaskStart returns, normally long before
		// the task finishes, so the outcome is recorded on a span owned by this
		// goroutine instead.
		execCtx, execSpan := tracer.Start(taskCtx, "TaskExecution", trace.WithAttributes(attribute.String("task_id", id)))
		defer execSpan.End()

		// NOTE(review): the task is only unregistered through the done callback;
		// confirm every task function calls it on all exit paths, otherwise the
		// task keeps being reported as RUNNING.
		err := s.taskExecutionFunction(execCtx, userSource, creation.DeepScrape, creation.ApiKey, func() {
			s.removeTask(id)
		})
		if err != nil {
			log.WithContext(execCtx).WithError(err).WithField("task_id", id).Error("Task execution failed")
			execSpan.RecordError(err)
			execSpan.SetStatus(codes.Error, err.Error())
			return
		}

		log.WithContext(execCtx).WithField("task_id", id).Debug("Task completed successfully")
		execSpan.AddEvent("Task completed successfully", trace.WithAttributes(attribute.String("task_id", id)))
	}()

	span.SetAttributes(attribute.String("task_id", id))

	return &plugTaskState, nil
}

// TaskStatus reports RUNNING when the given task ID is currently registered
// and UNKNOWN otherwise. It never returns an error.
func (s *server) TaskStatus(ctx context.Context, task *gRPC.PlugTask) (*gRPC.PlugTaskStatus, error) {
	ctx, span := tracer.Start(ctx, "TaskStatus")
	defer span.End()

	var plugTaskState gRPC.PlugTaskStatus

	s.mu.Lock()
	_, found := s.ctx[task.TaskId]
	s.mu.Unlock()

	plugTaskState.TaskId = task.TaskId
	plugTaskState.TaskState = gRPC.PlugTaskState_RUNNING
	if !found {
		plugTaskState.TaskState = gRPC.PlugTaskState_UNKNOWN
	}

	span.AddEvent("Determined task state", trace.WithAttributes(attribute.String("task_id", task.TaskId), attribute.String("task_state", plugTaskState.TaskState.String())))

	log.WithContext(ctx).WithFields(log.Fields{
		"task_id":    task.TaskId,
		"task_state": plugTaskState.TaskState,
	}).Debug("Task status requested")

	span.SetAttributes(attribute.String("task_id", task.TaskId))

	return &plugTaskState, nil
}

// TaskCancel cancels and unregisters the task with the given ID. The returned
// state is always STOPPED, including for IDs that were not registered.
func (s *server) TaskCancel(ctx context.Context, task *gRPC.PlugTask) (*gRPC.PlugTaskStatus, error) {
	ctx, span := tracer.Start(ctx, "TaskCancel")
	defer span.End()

	var plugTaskState gRPC.PlugTaskStatus
	plugTaskState.TaskState = gRPC.PlugTaskState_STOPPED
	plugTaskState.TaskId = task.TaskId

	s.removeTask(task.TaskId)
	span.AddEvent("Removed task", trace.WithAttributes(attribute.String("task_id", task.TaskId)))

	log.WithContext(ctx).WithFields(log.Fields{
		"task_id":    task.TaskId,
		"task_state": plugTaskState.TaskState,
	}).Debug("Task cancellation requested")

	span.SetAttributes(attribute.String("task_id", task.TaskId))

	return &plugTaskState, nil
}

// removeTask cancels the context of the task registered under taskID and
// forgets the task. It does nothing when no such task is registered, so it is
// safe to call more than once for the same ID.
func (s *server) removeTask(taskID string) {
	s.mu.Lock()
	cancel, exists := s.ctx[taskID]
	if exists {
		delete(s.ctx, taskID)
	}
	s.mu.Unlock()

	// Cancel outside the critical section; the entry is already gone, so a
	// concurrent removeTask for the same ID cannot cancel twice.
	if exists {
		cancel()
	}
}
// GetUserMessages loads the user source named in the request, fetches its
// messages through the configured getMessageExecution callback and returns
// them as gRPC messages. Every returned message carries the user source ID in
// FromUserSourceId.
//
// A failed lookup or a failed callback is recorded on the span, logged and
// returned with a nil response.
func (s *server) GetUserMessages(ctx context.Context, message *gRPC.GetMessagesRequest) (*gRPC.GetMessagesResponse, error) {
	ctx, span := tracer.Start(ctx, "GetUserMessages")
	defer span.End()

	userSource, err := database.GetUserSourceByID(ctx, models.UserSourceID(message.UserSourceId))
	if err != nil {
		span.RecordError(err)
		span.SetStatus(codes.Error, err.Error())
		log.WithContext(ctx).WithError(err).Error("Getting userSource")
		return nil, err
	}

	received, err := s.getMessageExecution(ctx, userSource)
	if err != nil {
		span.RecordError(err)
		span.SetStatus(codes.Error, err.Error())
		log.WithContext(ctx).WithError(err).Error("Execution function")
		return nil, err
	}

	response := new(gRPC.GetMessagesResponse)
	for _, msg := range received {
		converted := &gRPC.Message{
			FromUserSourceId: string(userSource.ID),
			CreatedAt:        msg.CreatedAt,
			Body:             msg.Body,
			Title:            msg.Title,
		}
		response.Messages = append(response.Messages, converted)
	}

	span.SetAttributes(
		attribute.String("user_source_id", string(userSource.ID)),
		attribute.String("user_id", string(userSource.UserID)),
		attribute.String("source_id", string(userSource.SourceID)),
	)

	log.WithContext(ctx).WithFields(log.Fields{
		"user_source_id": userSource.ID,
		"user_id":        userSource.UserID,
		"source_id":      userSource.SourceID,
		"len_messages":   len(received),
	}).Debug("Got User messages")

	// Both fallible calls above returned early on failure, so this is the
	// success path.
	return response, nil
}
// SendMessage hands the request's message to the configured
// sendMessageExecution callback. The callback receives a user source that has
// only its ID set, taken from the request.
//
// The response's Success flag is true only when the callback returns no error;
// on failure the error is logged, recorded on the span and returned together
// with a response whose Success is false.
func (s *server) SendMessage(ctx context.Context, message *gRPC.SendMessageRequest) (*gRPC.SendMessageResponse, error) {
	ctx, span := tracer.Start(ctx, "SendMessage")
	defer span.End()

	var reply gRPC.SendMessageResponse

	target := models.UserSource{
		BaseModel: models.BaseModel[models.UserSourceID]{
			ID: models.UserSourceID(message.UserSourceId),
		},
	}

	if err := s.sendMessageExecution(ctx, target, message.Message); err != nil {
		log.WithContext(ctx).WithError(err).Error("Sending message execution")
		span.RecordError(err)
		span.SetStatus(codes.Error, err.Error())
		return &reply, err
	}

	reply.Success = true
	return &reply, nil
}
// Ping answers a ping by echoing the request's message and timestamp back in
// the response. It never returns an error.
func (s *server) Ping(ctx context.Context, ping *gRPC.PingRequest) (*gRPC.PongResponse, error) {
	ctx, span := tracer.Start(ctx, "Ping")
	defer span.End()

	pong := &gRPC.PongResponse{
		Message:   ping.Message,
		Timestamp: ping.Timestamp,
	}

	// The field key was previously misspelled "messsage", so the value could
	// not be found under the expected name in structured log output.
	log.WithContext(ctx).WithFields(log.Fields{
		"message":   ping.Message,
		"timestamp": ping.Timestamp,
	}).Trace("Got pinged")

	return pong, nil
}