feat(source): Include source parameter in gRPC server methods and initialization

This commit is contained in:
SoXX 2024-08-15 09:56:40 +02:00
parent c01290847d
commit b6974c4a9f
2 changed files with 20 additions and 11 deletions

View File

@ -17,14 +17,16 @@ type server struct {
taskExecutionFunction TaskExecution taskExecutionFunction TaskExecution
sendMessageExecution SendMessageExecution sendMessageExecution SendMessageExecution
getMessageExecution GetMessageExecution getMessageExecution GetMessageExecution
source models.Source
} }
func NewGrpcServer(taskExecutionFunction TaskExecution, sendMessageExecution SendMessageExecution, getMessageExecution GetMessageExecution) gRPC.PlugConnectorServer { func NewGrpcServer(source models.Source, taskExecutionFunction TaskExecution, sendMessageExecution SendMessageExecution, getMessageExecution GetMessageExecution) gRPC.PlugConnectorServer {
return &server{ return &server{
ctx: make(map[string]context.CancelFunc), ctx: make(map[string]context.CancelFunc),
taskExecutionFunction: taskExecutionFunction, taskExecutionFunction: taskExecutionFunction,
sendMessageExecution: sendMessageExecution, sendMessageExecution: sendMessageExecution,
getMessageExecution: getMessageExecution, getMessageExecution: getMessageExecution,
source: source,
} }
} }
@ -51,7 +53,7 @@ func (s *server) TaskStart(ctx context.Context, creation *gRPC.PlugTaskCreation)
go func() { go func() {
// FIXME: better implement this methode, works for now but needs refactoring // FIXME: better implement this methode, works for now but needs refactoring
err := s.taskExecutionFunction(ctx, creation.UserSourceName, user, creation.DeepScrape, creation.ApiKey, func() { err := s.taskExecutionFunction(ctx, s.source, creation.UserSourceName, user, creation.DeepScrape, creation.ApiKey, func() {
s.removeTask(id) s.removeTask(id)
}) })
if err != nil { if err != nil {
@ -107,7 +109,7 @@ func (s *server) Ping(_ context.Context, request *gRPC.PingRequest) (*gRPC.PongR
func (s *server) SendMessage(ctx context.Context, request *gRPC.SendMessageRequest) (*gRPC.SendMessageResponse, error) { func (s *server) SendMessage(ctx context.Context, request *gRPC.SendMessageRequest) (*gRPC.SendMessageResponse, error) {
messageResponse := gRPC.SendMessageResponse{Success: true} messageResponse := gRPC.SendMessageResponse{Success: true}
err := s.sendMessageExecution(ctx, request.UserSourceId, request.UserSourceName, request.Message) err := s.sendMessageExecution(ctx, s.source, request.UserSourceId, request.UserSourceName, request.Message)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -118,7 +120,7 @@ func (s *server) SendMessage(ctx context.Context, request *gRPC.SendMessageReque
func (s *server) GetUserMessages(ctx context.Context, request *gRPC.GetMessagesRequest) (*gRPC.GetMessagesResponse, error) { func (s *server) GetUserMessages(ctx context.Context, request *gRPC.GetMessagesRequest) (*gRPC.GetMessagesResponse, error) {
messageResponse := gRPC.GetMessagesResponse{} messageResponse := gRPC.GetMessagesResponse{}
messages, err := s.getMessageExecution(ctx, request.UserSourceId, request.UserSourceName) messages, err := s.getMessageExecution(ctx, s.source, request.UserSourceId, request.UserSourceName)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -22,9 +22,9 @@ type Message struct {
CreatedAt *timestamp.Timestamp CreatedAt *timestamp.Timestamp
} }
type TaskExecution func(ctx context.Context, userSourceUsername string, anthroveUser models.User, deepScrape bool, apiKey string, cancelFunction func()) error type TaskExecution func(ctx context.Context, source models.Source, userSourceUsername string, anthroveUser models.User, deepScrape bool, apiKey string, cancelFunction func()) error
type SendMessageExecution func(ctx context.Context, userSourceID string, userSourceUsername string, message string) error type SendMessageExecution func(ctx context.Context, source models.Source, userSourceID string, userSourceUsername string, message string) error
type GetMessageExecution func(ctx context.Context, userSourceID string, userSourceUsername string) ([]Message, error) type GetMessageExecution func(ctx context.Context, source models.Source, userSourceID string, userSourceUsername string) ([]Message, error)
type Plug struct { type Plug struct {
address string address string
@ -47,13 +47,14 @@ func NewPlug(ctx context.Context, address string, port string, source models.Sou
func (p *Plug) Listen() error { func (p *Plug) Listen() error {
var err error var err error
var source models.Source
log.Print("Check if source exists") log.Print("Check if source exists")
_, err = database.GetSourceByDomain(p.ctx, p.source.Domain) source, err = database.GetSourceByDomain(p.ctx, p.source.Domain)
if err != nil { if err != nil {
if err.Error() == otterError.NoDataFound { if err.Error() == otterError.NoDataFound {
log.Printf("iIitializing source!") log.Printf("Initalizing source!")
if _, err = database.CreateSource(p.ctx, p.source); err != nil { if source, err = database.CreateSource(p.ctx, p.source); err != nil {
panic(err) panic(err)
} }
} else { } else {
@ -61,6 +62,8 @@ func (p *Plug) Listen() error {
} }
} }
p.source = source
log.Print("Source exists") log.Print("Source exists")
lis, err := net.Listen("tcp", fmt.Sprintf("%s:%s", p.address, p.port)) lis, err := net.Listen("tcp", fmt.Sprintf("%s:%s", p.address, p.port))
@ -70,7 +73,7 @@ func (p *Plug) Listen() error {
grpcServer := grpc.NewServer() grpcServer := grpc.NewServer()
pb.RegisterPlugConnectorServer(grpcServer, NewGrpcServer(p.taskExecutionFunction, p.sendMessageExecution, p.getMessageExecution)) pb.RegisterPlugConnectorServer(grpcServer, NewGrpcServer(p.source, p.taskExecutionFunction, p.sendMessageExecution, p.getMessageExecution))
err = grpcServer.Serve(lis) err = grpcServer.Serve(lis)
if err != nil { if err != nil {
@ -80,6 +83,10 @@ func (p *Plug) Listen() error {
return nil return nil
} }
func (p *Plug) GetSource() models.Source {
return p.source
}
func (p *Plug) TaskExecutionFunction(function TaskExecution) { func (p *Plug) TaskExecutionFunction(function TaskExecution) {
p.taskExecutionFunction = function p.taskExecutionFunction = function
} }