348 lines
11 KiB
Go
348 lines
11 KiB
Go
package plug
|
|
|
|
import (
	"context"
	"errors"
	"sync"
	"time"

	"git.anthrove.art/Anthrove/otter-space-sdk/v4/pkg/database"
	"git.anthrove.art/Anthrove/otter-space-sdk/v4/pkg/models"
	gRPC "git.anthrove.art/Anthrove/plug-sdk/v4/pkg/grpc"
	gonanoid "github.com/matoous/go-nanoid/v2"
	log "github.com/sirupsen/logrus"
	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/codes"
	"go.opentelemetry.io/otel/trace"
	"gorm.io/gorm"
)
|
|
|
|
type server struct {
|
|
gRPC.UnimplementedPlugConnectorServer
|
|
ctx map[string]context.CancelFunc
|
|
plugInterface Plug
|
|
sendMessageExecution SendMessageExecution
|
|
getMessageExecution GetMessageExecution
|
|
source models.Source
|
|
}
|
|
|
|
func NewGrpcServer(source models.Source, plugAPIInterface Plug, sendMessageExecution SendMessageExecution, getMessageExecution GetMessageExecution) gRPC.PlugConnectorServer {
|
|
return &server{
|
|
ctx: make(map[string]context.CancelFunc),
|
|
plugInterface: plugAPIInterface,
|
|
sendMessageExecution: sendMessageExecution,
|
|
getMessageExecution: getMessageExecution,
|
|
source: source,
|
|
}
|
|
}
|
|
|
|
func (s *server) TaskStart(ctx context.Context, creation *gRPC.PlugTaskCreation) (*gRPC.PlugTaskStatus, error) {
|
|
ctx, span := tracer.Start(ctx, "TaskStart")
|
|
defer span.End()
|
|
|
|
var plugTaskState gRPC.PlugTaskStatus
|
|
|
|
id, err := gonanoid.New(25)
|
|
if err != nil {
|
|
span.RecordError(err)
|
|
span.SetStatus(codes.Error, err.Error())
|
|
return nil, err
|
|
}
|
|
span.AddEvent("Generated task ID", trace.WithAttributes(attribute.String("task_id", id)))
|
|
|
|
scrapeTaskHistory := models.ScrapeHistory{
|
|
ScrapeTaskID: models.ScrapeTaskID(id),
|
|
UserSourceID: models.UserSourceID(creation.UserSourceId),
|
|
}
|
|
scrapeTaskHistory, err = database.CreateScrapeHistory(ctx, scrapeTaskHistory)
|
|
if err != nil {
|
|
span.RecordError(err)
|
|
span.SetStatus(codes.Error, err.Error())
|
|
return nil, err
|
|
}
|
|
span.AddEvent("Creates ScrapeTaskHistory", trace.WithAttributes(
|
|
attribute.String("user_source_id", creation.UserSourceId),
|
|
attribute.String("scrape_task_id", id),
|
|
))
|
|
|
|
plugTaskState.TaskId = id
|
|
plugTaskState.TaskState = gRPC.PlugTaskState_RUNNING
|
|
|
|
userSource, err := database.GetUserSourceByID(ctx, models.UserSourceID(creation.UserSourceId))
|
|
if err != nil {
|
|
span.RecordError(err)
|
|
span.SetStatus(codes.Error, err.Error())
|
|
|
|
dberr := database.UpdateScrapeHistory(ctx, models.ScrapeHistory{
|
|
ScrapeTaskID: models.ScrapeTaskID(id),
|
|
UserSourceID: userSource.ID,
|
|
FinishedAt: time.Now(),
|
|
Error: err.Error(),
|
|
})
|
|
|
|
return nil, errors.Join(err, dberr)
|
|
}
|
|
span.AddEvent("Retrieved user source", trace.WithAttributes(attribute.String("user_source_id", creation.UserSourceId)))
|
|
|
|
if !userSource.AccountValidate {
|
|
err = errors.New("user is not validated")
|
|
|
|
log.WithContext(ctx).WithError(err).WithField("task_id", id).Error("Task execution failed")
|
|
span.RecordError(err)
|
|
span.SetStatus(codes.Error, err.Error())
|
|
|
|
plugTaskState.TaskState = gRPC.PlugTaskState_STOPPED
|
|
|
|
dberr := database.UpdateScrapeHistory(ctx, models.ScrapeHistory{
|
|
ScrapeTaskID: models.ScrapeTaskID(id),
|
|
UserSourceID: userSource.ID,
|
|
FinishedAt: time.Now(),
|
|
Error: err.Error(),
|
|
})
|
|
|
|
return &plugTaskState, errors.Join(err, dberr)
|
|
}
|
|
|
|
// gRPC closes the context after the call ended. So the whole scrapping stopped without waiting
|
|
// by using this method we assign a new context to each new request we get.
|
|
// This can be used for example to close the context with the given id
|
|
ctx = trace.ContextWithSpanContext(context.Background(), trace.NewSpanContext(trace.SpanContextConfig{TraceID: span.SpanContext().TraceID()}))
|
|
taskCtx, cancel := context.WithCancel(ctx)
|
|
s.ctx[id] = cancel
|
|
span.AddEvent("Created new context for task", trace.WithAttributes(attribute.String("task_id", id)))
|
|
|
|
log.WithContext(taskCtx).WithFields(log.Fields{
|
|
"task_id": id,
|
|
"user_source_id": creation.UserSourceId,
|
|
}).Debug("Starting task")
|
|
|
|
db, err := database.GetGorm(taskCtx)
|
|
if err != nil {
|
|
log.WithContext(taskCtx).WithError(err).WithField("task_id", id).Error("Task execution failed")
|
|
span.RecordError(err)
|
|
span.SetStatus(codes.Error, err.Error())
|
|
dberr := database.UpdateScrapeHistory(taskCtx, models.ScrapeHistory{
|
|
ScrapeTaskID: models.ScrapeTaskID(id),
|
|
UserSourceID: userSource.ID,
|
|
FinishedAt: time.Now(),
|
|
Error: errorString(err),
|
|
})
|
|
return &plugTaskState, errors.Join(err, dberr)
|
|
}
|
|
|
|
anthroveUserFavCount, err := getUserFavoriteCountFromDatabase(taskCtx, db, userSource.UserID, userSource.ID)
|
|
if err != nil {
|
|
log.WithContext(taskCtx).WithError(err).WithField("task_id", id).Error("Task execution failed")
|
|
span.RecordError(err)
|
|
span.SetStatus(codes.Error, err.Error())
|
|
dberr := database.UpdateScrapeHistory(taskCtx, models.ScrapeHistory{
|
|
ScrapeTaskID: models.ScrapeTaskID(id),
|
|
UserSourceID: userSource.ID,
|
|
FinishedAt: time.Now(),
|
|
Error: errorString(err),
|
|
})
|
|
return &plugTaskState, errors.Join(err, dberr)
|
|
}
|
|
|
|
go func() {
|
|
log.WithContext(taskCtx).WithError(err).WithField("task_id", id).Error("Failed to get Gorm client")
|
|
span.RecordError(err)
|
|
span.SetStatus(codes.Error, err.Error())
|
|
|
|
taskSummery, err := algorithm(taskCtx, s.plugInterface, userSource, anthroveUserFavCount, creation.DeepScrape, creation.ApiKey)
|
|
if err != nil {
|
|
log.WithContext(taskCtx).WithError(err).WithField("task_id", id).Error("Task execution failed")
|
|
span.RecordError(err)
|
|
span.SetStatus(codes.Error, err.Error())
|
|
} else {
|
|
log.WithContext(taskCtx).WithField("task_id", id).Debug("Task completed successfully")
|
|
span.AddEvent("Task completed successfully", trace.WithAttributes(attribute.String("task_id", id)))
|
|
}
|
|
|
|
err = database.UpdateScrapeHistory(taskCtx, models.ScrapeHistory{
|
|
ScrapeTaskID: models.ScrapeTaskID(id),
|
|
UserSourceID: userSource.ID,
|
|
FinishedAt: time.Now(),
|
|
Error: errorString(err),
|
|
AddedPosts: taskSummery.AddedPosts,
|
|
DeletedPosts: taskSummery.DeletedPosts,
|
|
})
|
|
|
|
if err != nil {
|
|
log.WithContext(taskCtx).WithError(err).WithField("task_id", id).Error("Task execution failed")
|
|
span.RecordError(err)
|
|
span.SetStatus(codes.Error, err.Error())
|
|
}
|
|
s.removeTask(id)
|
|
}()
|
|
|
|
span.SetAttributes(attribute.String("task_id", id))
|
|
return &plugTaskState, nil
|
|
}
|
|
|
|
func (s *server) TaskStatus(ctx context.Context, task *gRPC.PlugTask) (*gRPC.PlugTaskStatus, error) {
|
|
ctx, span := tracer.Start(ctx, "TaskStatus")
|
|
defer span.End()
|
|
|
|
var plugTaskState gRPC.PlugTaskStatus
|
|
|
|
_, found := s.ctx[task.TaskId]
|
|
plugTaskState.TaskId = task.TaskId
|
|
|
|
plugTaskState.TaskState = gRPC.PlugTaskState_RUNNING
|
|
|
|
if !found {
|
|
plugTaskState.TaskState = gRPC.PlugTaskState_UNKNOWN
|
|
}
|
|
span.AddEvent("Determined task state", trace.WithAttributes(attribute.String("task_id", task.TaskId), attribute.String("task_state", plugTaskState.TaskState.String())))
|
|
|
|
log.WithContext(ctx).WithFields(log.Fields{
|
|
"task_id": task.TaskId,
|
|
"task_state": plugTaskState.TaskState,
|
|
}).Debug("Task status requested")
|
|
|
|
span.SetAttributes(attribute.String("task_id", task.TaskId))
|
|
return &plugTaskState, nil
|
|
}
|
|
|
|
func (s *server) TaskCancel(ctx context.Context, task *gRPC.PlugTask) (*gRPC.PlugTaskStatus, error) {
|
|
ctx, span := tracer.Start(ctx, "TaskCancel")
|
|
defer span.End()
|
|
|
|
var plugTaskState gRPC.PlugTaskStatus
|
|
|
|
plugTaskState.TaskState = gRPC.PlugTaskState_STOPPED
|
|
plugTaskState.TaskId = task.TaskId
|
|
|
|
s.removeTask(task.TaskId)
|
|
span.AddEvent("Removed task", trace.WithAttributes(attribute.String("task_id", task.TaskId)))
|
|
|
|
log.WithContext(ctx).WithFields(log.Fields{
|
|
"task_id": task.TaskId,
|
|
"task_state": plugTaskState.TaskState,
|
|
}).Debug("Task cancellation requested")
|
|
|
|
span.SetAttributes(attribute.String("task_id", task.TaskId))
|
|
return &plugTaskState, nil
|
|
}
|
|
|
|
func (s *server) GetUserMessages(ctx context.Context, message *gRPC.GetMessagesRequest) (*gRPC.GetMessagesResponse, error) {
|
|
ctx, span := tracer.Start(ctx, "GetUserMessages")
|
|
defer span.End()
|
|
|
|
userSourceID := models.UserSourceID(message.UserSourceId)
|
|
|
|
userSource, err := database.GetUserSourceByID(ctx, userSourceID)
|
|
|
|
if err != nil {
|
|
span.RecordError(err)
|
|
span.SetStatus(codes.Error, err.Error())
|
|
log.WithContext(ctx).WithError(err).Error("Getting userSource")
|
|
return nil, err
|
|
}
|
|
|
|
messages, err := s.getMessageExecution(ctx, userSource)
|
|
if err != nil {
|
|
span.RecordError(err)
|
|
span.SetStatus(codes.Error, err.Error())
|
|
log.WithContext(ctx).WithError(err).Error("Execution function")
|
|
return nil, err
|
|
}
|
|
|
|
var response gRPC.GetMessagesResponse
|
|
for _, message := range messages {
|
|
response.Messages = append(response.Messages, &gRPC.Message{
|
|
FromUserSourceId: string(userSource.ID),
|
|
CreatedAt: message.CreatedAt,
|
|
Body: message.Body,
|
|
Title: message.Title,
|
|
})
|
|
}
|
|
|
|
span.SetAttributes(
|
|
attribute.String("user_source_id", string(userSource.ID)),
|
|
attribute.String("user_id", string(userSource.UserID)),
|
|
attribute.String("source_id", string(userSource.SourceID)),
|
|
)
|
|
|
|
fields := log.Fields{
|
|
"user_source_id": userSource.ID,
|
|
"user_id": userSource.UserID,
|
|
"source_id": userSource.SourceID,
|
|
"len_messages": len(messages),
|
|
}
|
|
|
|
log.WithContext(ctx).WithFields(fields).Debug("Got User messages")
|
|
|
|
return &response, err
|
|
|
|
}
|
|
|
|
func (s *server) SendMessage(ctx context.Context, message *gRPC.SendMessageRequest) (*gRPC.SendMessageResponse, error) {
|
|
ctx, span := tracer.Start(ctx, "SendMessage")
|
|
defer span.End()
|
|
|
|
response := &gRPC.SendMessageResponse{
|
|
Success: false,
|
|
}
|
|
|
|
sourceID := models.UserSourceID(message.UserSourceId)
|
|
userSource := models.UserSource{BaseModel: models.BaseModel[models.UserSourceID]{ID: sourceID}}
|
|
|
|
err := s.sendMessageExecution(ctx, userSource, message.Message)
|
|
if err != nil {
|
|
log.WithContext(ctx).WithError(err).Error("Sending message execution")
|
|
span.RecordError(err)
|
|
span.SetStatus(codes.Error, err.Error())
|
|
|
|
return response, err
|
|
}
|
|
|
|
response.Success = true
|
|
return response, err
|
|
}
|
|
|
|
func (s *server) Ping(ctx context.Context, ping *gRPC.PingRequest) (*gRPC.PongResponse, error) {
|
|
ctx, span := tracer.Start(ctx, "Ping")
|
|
defer span.End()
|
|
|
|
var pong gRPC.PongResponse
|
|
pong.Message = ping.Message
|
|
pong.Timestamp = ping.Timestamp
|
|
|
|
fields := log.Fields{
|
|
"messsage": ping.Message,
|
|
"timestamp": ping.Timestamp,
|
|
}
|
|
log.WithContext(ctx).WithFields(fields).Trace("Got pinged")
|
|
|
|
return &pong, nil
|
|
}
|
|
|
|
func (s *server) removeTask(taskID string) {
|
|
fn, exists := s.ctx[taskID]
|
|
if !exists {
|
|
return
|
|
}
|
|
fn()
|
|
delete(s.ctx, taskID)
|
|
}
|
|
|
|
// errorString returns the message of err, or the empty string when err is
// nil. It lets callers record an optional error in a plain string field.
func errorString(err error) string {
	if err == nil {
		return ""
	}
	return err.Error()
}
|
|
|
|
// getUserFavoriteCountFromDatabase
|
|
func getUserFavoriteCountFromDatabase(ctx context.Context, gorm *gorm.DB, userID models.UserID, userSourceID models.UserSourceID) (int64, error) {
|
|
var count int64
|
|
|
|
err := gorm.WithContext(ctx).Model(&models.UserFavorite{}).Where("user_id = ? AND user_source_id = ?", userID, userSourceID).Count(&count).Error
|
|
if err != nil {
|
|
return count, err
|
|
}
|
|
|
|
return count, nil
|
|
}
|