From 8d8c827309f4f89a42b13a18b94c24eb4d56ab68 Mon Sep 17 00:00:00 2001 From: SoXX Date: Thu, 15 Aug 2024 09:56:40 +0200 Subject: [PATCH] feat(source): Include `source` parameter in gRPC server methods and initialization --- pkg/plug/grpc.go | 10 ++++++---- pkg/plug/server.go | 21 ++++++++++++++------- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/pkg/plug/grpc.go b/pkg/plug/grpc.go index 583139d..4e23af4 100644 --- a/pkg/plug/grpc.go +++ b/pkg/plug/grpc.go @@ -17,14 +17,16 @@ type server struct { taskExecutionFunction TaskExecution sendMessageExecution SendMessageExecution getMessageExecution GetMessageExecution + source models.Source } -func NewGrpcServer(taskExecutionFunction TaskExecution, sendMessageExecution SendMessageExecution, getMessageExecution GetMessageExecution) gRPC.PlugConnectorServer { +func NewGrpcServer(source models.Source, taskExecutionFunction TaskExecution, sendMessageExecution SendMessageExecution, getMessageExecution GetMessageExecution) gRPC.PlugConnectorServer { return &server{ ctx: make(map[string]context.CancelFunc), taskExecutionFunction: taskExecutionFunction, sendMessageExecution: sendMessageExecution, getMessageExecution: getMessageExecution, + source: source, } } @@ -51,7 +53,7 @@ func (s *server) TaskStart(ctx context.Context, creation *gRPC.PlugTaskCreation) go func() { // FIXME: better implement this methode, works for now but needs refactoring - err := s.taskExecutionFunction(ctx, creation.UserSourceName, user, creation.DeepScrape, creation.ApiKey, func() { + err := s.taskExecutionFunction(ctx, s.source, creation.UserSourceName, user, creation.DeepScrape, creation.ApiKey, func() { s.removeTask(id) }) if err != nil { @@ -107,7 +109,7 @@ func (s *server) Ping(_ context.Context, request *gRPC.PingRequest) (*gRPC.PongR func (s *server) SendMessage(ctx context.Context, request *gRPC.SendMessageRequest) (*gRPC.SendMessageResponse, error) { messageResponse := gRPC.SendMessageResponse{Success: true} - err := s.sendMessageExecution(ctx, request.UserSourceId, request.UserSourceName, request.Message) + err := s.sendMessageExecution(ctx, s.source, request.UserSourceId, request.UserSourceName, request.Message) if err != nil { return nil, err } @@ -118,7 +120,7 @@ func (s *server) SendMessage(ctx context.Context, request *gRPC.SendMessageReque func (s *server) GetUserMessages(ctx context.Context, request *gRPC.GetMessagesRequest) (*gRPC.GetMessagesResponse, error) { messageResponse := gRPC.GetMessagesResponse{} - messages, err := s.getMessageExecution(ctx, request.UserSourceId, request.UserSourceName) + messages, err := s.getMessageExecution(ctx, s.source, request.UserSourceId, request.UserSourceName) if err != nil { return nil, err } diff --git a/pkg/plug/server.go b/pkg/plug/server.go index ac1292d..1831328 100644 --- a/pkg/plug/server.go +++ b/pkg/plug/server.go @@ -22,9 +22,9 @@ type Message struct { CreatedAt *timestamp.Timestamp } -type TaskExecution func(ctx context.Context, userSourceUsername string, anthroveUser models.User, deepScrape bool, apiKey string, cancelFunction func()) error -type SendMessageExecution func(ctx context.Context, userSourceID string, userSourceUsername string, message string) error -type GetMessageExecution func(ctx context.Context, userSourceID string, userSourceUsername string) ([]Message, error) +type TaskExecution func(ctx context.Context, source models.Source, userSourceUsername string, anthroveUser models.User, deepScrape bool, apiKey string, cancelFunction func()) error +type SendMessageExecution func(ctx context.Context, source models.Source, userSourceID string, userSourceUsername string, message string) error +type GetMessageExecution func(ctx context.Context, source models.Source, userSourceID string, userSourceUsername string) ([]Message, error) type Plug struct { address string @@ -47,13 +47,14 @@ func NewPlug(ctx context.Context, address string, port string, source models.Sou func (p *Plug) Listen() error { var err error + var source models.Source log.Print("Check if source exists") - _, err = database.GetSourceByDomain(p.ctx, p.source.Domain) + source, err = database.GetSourceByDomain(p.ctx, p.source.Domain) if err != nil { if err.Error() == otterError.NoDataFound { - log.Printf("iIitializing source!") - if _, err = database.CreateSource(p.ctx, p.source); err != nil { + log.Printf("Initalizing source!") + if source, err = database.CreateSource(p.ctx, p.source); err != nil { panic(err) } } else { @@ -61,6 +62,8 @@ func (p *Plug) Listen() error { } } + p.source = source + log.Print("Source exists") lis, err := net.Listen("tcp", fmt.Sprintf("%s:%s", p.address, p.port)) @@ -70,7 +73,7 @@ func (p *Plug) Listen() error { grpcServer := grpc.NewServer() - pb.RegisterPlugConnectorServer(grpcServer, NewGrpcServer(p.taskExecutionFunction, p.sendMessageExecution, p.getMessageExecution)) + pb.RegisterPlugConnectorServer(grpcServer, NewGrpcServer(p.source, p.taskExecutionFunction, p.sendMessageExecution, p.getMessageExecution)) err = grpcServer.Serve(lis) if err != nil { @@ -80,6 +83,10 @@ func (p *Plug) Listen() error { return nil } +func (p *Plug) GetSource() models.Source { + return p.source +} + func (p *Plug) TaskExecutionFunction(function TaskExecution) { p.taskExecutionFunction = function }