feat(source): Include source parameter in gRPC server methods and initialization

This commit is contained in:
SoXX 2024-08-15 09:56:40 +02:00
parent 87ae16ff1b
commit 8d8c827309
2 changed files with 20 additions and 11 deletions

View File

@ -17,14 +17,16 @@ type server struct {
taskExecutionFunction TaskExecution
sendMessageExecution SendMessageExecution
getMessageExecution GetMessageExecution
source models.Source
}
func NewGrpcServer(taskExecutionFunction TaskExecution, sendMessageExecution SendMessageExecution, getMessageExecution GetMessageExecution) gRPC.PlugConnectorServer {
func NewGrpcServer(source models.Source, taskExecutionFunction TaskExecution, sendMessageExecution SendMessageExecution, getMessageExecution GetMessageExecution) gRPC.PlugConnectorServer {
return &server{
ctx: make(map[string]context.CancelFunc),
taskExecutionFunction: taskExecutionFunction,
sendMessageExecution: sendMessageExecution,
getMessageExecution: getMessageExecution,
source: source,
}
}
@ -51,7 +53,7 @@ func (s *server) TaskStart(ctx context.Context, creation *gRPC.PlugTaskCreation)
go func() {
// FIXME: better implement this methode, works for now but needs refactoring
err := s.taskExecutionFunction(ctx, creation.UserSourceName, user, creation.DeepScrape, creation.ApiKey, func() {
err := s.taskExecutionFunction(ctx, s.source, creation.UserSourceName, user, creation.DeepScrape, creation.ApiKey, func() {
s.removeTask(id)
})
if err != nil {
@ -107,7 +109,7 @@ func (s *server) Ping(_ context.Context, request *gRPC.PingRequest) (*gRPC.PongR
func (s *server) SendMessage(ctx context.Context, request *gRPC.SendMessageRequest) (*gRPC.SendMessageResponse, error) {
messageResponse := gRPC.SendMessageResponse{Success: true}
err := s.sendMessageExecution(ctx, request.UserSourceId, request.UserSourceName, request.Message)
err := s.sendMessageExecution(ctx, s.source, request.UserSourceId, request.UserSourceName, request.Message)
if err != nil {
return nil, err
}
@ -118,7 +120,7 @@ func (s *server) SendMessage(ctx context.Context, request *gRPC.SendMessageReque
func (s *server) GetUserMessages(ctx context.Context, request *gRPC.GetMessagesRequest) (*gRPC.GetMessagesResponse, error) {
messageResponse := gRPC.GetMessagesResponse{}
messages, err := s.getMessageExecution(ctx, request.UserSourceId, request.UserSourceName)
messages, err := s.getMessageExecution(ctx, s.source, request.UserSourceId, request.UserSourceName)
if err != nil {
return nil, err
}

View File

@ -22,9 +22,9 @@ type Message struct {
CreatedAt *timestamp.Timestamp
}
type TaskExecution func(ctx context.Context, userSourceUsername string, anthroveUser models.User, deepScrape bool, apiKey string, cancelFunction func()) error
type SendMessageExecution func(ctx context.Context, userSourceID string, userSourceUsername string, message string) error
type GetMessageExecution func(ctx context.Context, userSourceID string, userSourceUsername string) ([]Message, error)
type TaskExecution func(ctx context.Context, source models.Source, userSourceUsername string, anthroveUser models.User, deepScrape bool, apiKey string, cancelFunction func()) error
type SendMessageExecution func(ctx context.Context, source models.Source, userSourceID string, userSourceUsername string, message string) error
type GetMessageExecution func(ctx context.Context, source models.Source, userSourceID string, userSourceUsername string) ([]Message, error)
type Plug struct {
address string
@ -47,13 +47,14 @@ func NewPlug(ctx context.Context, address string, port string, source models.Sou
func (p *Plug) Listen() error {
var err error
var source models.Source
log.Print("Check if source exists")
_, err = database.GetSourceByDomain(p.ctx, p.source.Domain)
source, err = database.GetSourceByDomain(p.ctx, p.source.Domain)
if err != nil {
if err.Error() == otterError.NoDataFound {
log.Printf("iIitializing source!")
if _, err = database.CreateSource(p.ctx, p.source); err != nil {
log.Printf("Initalizing source!")
if source, err = database.CreateSource(p.ctx, p.source); err != nil {
panic(err)
}
} else {
@ -61,6 +62,8 @@ func (p *Plug) Listen() error {
}
}
p.source = source
log.Print("Source exists")
lis, err := net.Listen("tcp", fmt.Sprintf("%s:%s", p.address, p.port))
@ -70,7 +73,7 @@ func (p *Plug) Listen() error {
grpcServer := grpc.NewServer()
pb.RegisterPlugConnectorServer(grpcServer, NewGrpcServer(p.taskExecutionFunction, p.sendMessageExecution, p.getMessageExecution))
pb.RegisterPlugConnectorServer(grpcServer, NewGrpcServer(p.source, p.taskExecutionFunction, p.sendMessageExecution, p.getMessageExecution))
err = grpcServer.Serve(lis)
if err != nil {
@ -80,6 +83,10 @@ func (p *Plug) Listen() error {
return nil
}
func (p *Plug) GetSource() models.Source {
return p.source
}
func (p *Plug) TaskExecutionFunction(function TaskExecution) {
p.taskExecutionFunction = function
}