From 408c97743256db480e86b52b9743c14e3b5febff Mon Sep 17 00:00:00 2001 From: SoXX Date: Sat, 26 Oct 2024 20:30:12 +0200 Subject: [PATCH] fix: nil pointer for DB --- pkg/plug/algorithm.go | 23 +--------------------- pkg/plug/grpc.go | 45 ++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 43 insertions(+), 25 deletions(-) diff --git a/pkg/plug/algorithm.go b/pkg/plug/algorithm.go index 1950c54..fc29956 100644 --- a/pkg/plug/algorithm.go +++ b/pkg/plug/algorithm.go @@ -6,7 +6,6 @@ import ( log "github.com/sirupsen/logrus" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" - "gorm.io/gorm" ) type User struct { @@ -31,7 +30,7 @@ type Plug interface { GetUserProfile(ctx context.Context, apiKey string, userSource models.UserSource) (User, error) } -func Algorithm(ctx context.Context, plugInterface Plug, db *gorm.DB, userSource models.UserSource, deepScrape bool, apiKey string) (TaskSummery, error) { +func algorithm(ctx context.Context, plugInterface Plug, userSource models.UserSource, anthroveUserFavCount int64, deepScrape bool, apiKey string) (TaskSummery, error) { ctx, span := tracer.Start(ctx, "mainScrapeAlgorithm") defer span.End() @@ -54,14 +53,6 @@ func Algorithm(ctx context.Context, plugInterface Plug, db *gorm.DB, userSource DeletedPosts: 0, } - anthroveUserFavCount, err := getUserFavoriteCountFromDatabase(ctx, db, userSource.UserID, userSource.ID) - if err != nil { - span.RecordError(err) - span.SetStatus(codes.Error, err.Error()) - log.WithContext(ctx).WithFields(basicLoggingInfo).WithError(err).Error("Failed to get user favorite count from db") - return taskSummery, err - } - profile, err := plugInterface.GetUserProfile(ctx, apiKey, userSource) if err != nil { return taskSummery, err @@ -115,15 +106,3 @@ outer: log.WithContext(ctx).WithFields(basicLoggingInfo).Info("Completed scraping algorithm") return taskSummery, nil } - -// getUserFavoriteCountFromDatabase -func getUserFavoriteCountFromDatabase(ctx context.Context, gorm *gorm.DB, userID models.UserID, userSourceID models.UserSourceID) (int64, error) { - var count int64 - - err := gorm.WithContext(ctx).Model(&models.UserFavorite{}).Where("user_id = ? AND user_source_id = ?", userID, userSourceID).Count(&count).Error - if err != nil { - return count, err - } - - return count, nil -} diff --git a/pkg/plug/grpc.go b/pkg/plug/grpc.go index 379671a..fd45029 100644 --- a/pkg/plug/grpc.go +++ b/pkg/plug/grpc.go @@ -3,6 +3,7 @@ package plug import ( "context" "errors" + "gorm.io/gorm" "time" "git.anthrove.art/Anthrove/otter-space-sdk/v4/pkg/database" @@ -114,14 +115,40 @@ func (s *server) TaskStart(ctx context.Context, creation *gRPC.PlugTaskCreation) "user_source_id": creation.UserSourceId, }).Debug("Starting task") + db, err := database.GetGorm(taskCtx) + if err != nil { + log.WithContext(taskCtx).WithError(err).WithField("task_id", id).Error("Task execution failed") + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + dberr := database.UpdateScrapeHistory(taskCtx, models.ScrapeHistory{ + ScrapeTaskID: models.ScrapeTaskID(id), + UserSourceID: userSource.ID, + FinishedAt: time.Now(), + Error: errorString(err), + }) + return &plugTaskState, errors.Join(err, dberr) + } + + anthroveUserFavCount, err := getUserFavoriteCountFromDatabase(taskCtx, db, userSource.UserID, userSource.ID) + if err != nil { + log.WithContext(taskCtx).WithError(err).WithField("task_id", id).Error("Task execution failed") + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + dberr := database.UpdateScrapeHistory(taskCtx, models.ScrapeHistory{ + ScrapeTaskID: models.ScrapeTaskID(id), + UserSourceID: userSource.ID, + FinishedAt: time.Now(), + Error: errorString(err), + }) + return &plugTaskState, errors.Join(err, dberr) + } + go func() { - var err error - gorm, err := database.GetGorm(taskCtx) log.WithContext(taskCtx).WithError(err).WithField("task_id", id).Error("Failed to get Gorm client") span.RecordError(err) span.SetStatus(codes.Error, err.Error()) - taskSummery, err := Algorithm(taskCtx, s.plugInterface, gorm, userSource, creation.DeepScrape, creation.ApiKey) + taskSummery, err := algorithm(taskCtx, s.plugInterface, userSource, anthroveUserFavCount, creation.DeepScrape, creation.ApiKey) if err != nil { log.WithContext(taskCtx).WithError(err).WithField("task_id", id).Error("Task execution failed") span.RecordError(err) @@ -306,3 +333,15 @@ func errorString(err error) string { } return "" } + +// getUserFavoriteCountFromDatabase +func getUserFavoriteCountFromDatabase(ctx context.Context, gorm *gorm.DB, userID models.UserID, userSourceID models.UserSourceID) (int64, error) { + var count int64 + + err := gorm.WithContext(ctx).Model(&models.UserFavorite{}).Where("user_id = ? AND user_source_id = ?", userID, userSourceID).Count(&count).Error + if err != nil { + return count, err + } + + return count, nil +}