diff --git a/internal/postgres/custom.go b/internal/postgres/custom.go new file mode 100644 index 0000000..c360669 --- /dev/null +++ b/internal/postgres/custom.go @@ -0,0 +1,42 @@ +package postgres + +import ( + "context" + "database/sql" + "errors" + "gorm.io/gorm" +) + +func ExecuteRawStatement(ctx context.Context, db *gorm.DB, query string, args ...any) error { + if query == "" { + return errors.New("query can not be empty") + } + + if args == nil { + return errors.New("arguments can not be nil") + } + + result := db.WithContext(ctx).Exec(query, args...) + if result.Error != nil { + return result.Error + } + + return nil +} + +func QueryRawStatement(ctx context.Context, db *gorm.DB, query string, args ...any) (*sql.Rows, error) { + if query == "" { + return nil, errors.New("query can not be empty") + } + + if args == nil { + return nil, errors.New("arguments can not be nil") + } + + result, err := db.WithContext(ctx).Raw(query, args...).Rows() + if err != nil { + return nil, err + } + + return result, nil +} diff --git a/internal/postgres/custom_test.go b/internal/postgres/custom_test.go new file mode 100644 index 0000000..22c176b --- /dev/null +++ b/internal/postgres/custom_test.go @@ -0,0 +1,125 @@ +package postgres + +import ( + "context" + "database/sql" + "git.anthrove.art/Anthrove/otter-space-sdk/v2/test" + "gorm.io/gorm" + "reflect" + "testing" +) + +func TestExecuteRawStatement(t *testing.T) { + // Setup trow away container + ctx := context.Background() + container, gormDB, err := test.StartPostgresContainer(ctx) + if err != nil { + t.Fatalf("Could not start PostgreSQL container: %v", err) + } + defer container.Terminate(ctx) + + // Test + type args struct { + ctx context.Context + db *gorm.DB + query string + args []any + } + tests := []struct { + name string + args args + want *sql.Rows + wantErr bool + }{ + { + name: "Test 01: Empty Query", + args: args{ + ctx: ctx, + db: gormDB, + query: "", + args: nil, + }, + want: nil, + wantErr: true, + }, + { + name: "Test 02: Nil Query", + args: args{ + ctx: ctx, + db: gormDB, + query: "aasd", + args: nil, + }, + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ExecuteRawStatement(tt.args.ctx, tt.args.db, tt.args.query, tt.args.args...) + if (err != nil) != tt.wantErr { + t.Errorf("ExecuteRawStatement() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} + +func TestQueryRawStatement(t *testing.T) { + // Setup trow away container + ctx := context.Background() + container, gormDB, err := test.StartPostgresContainer(ctx) + if err != nil { + t.Fatalf("Could not start PostgreSQL container: %v", err) + } + defer container.Terminate(ctx) + + // Test + type args struct { + ctx context.Context + db *gorm.DB + query string + args []any + } + tests := []struct { + name string + args args + want *sql.Rows + wantErr bool + }{ + { + name: "Test 01: Empty Query", + args: args{ + ctx: ctx, + db: gormDB, + query: "", + args: nil, + }, + want: nil, + wantErr: true, + }, + { + name: "Test 02: Nil Query", + args: args{ + ctx: ctx, + db: gormDB, + query: "aasd", + args: nil, + }, + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := QueryRawStatement(tt.args.ctx, tt.args.db, tt.args.query, tt.args.args...) + if (err != nil) != tt.wantErr { + t.Errorf("QueryRawStatement() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("QueryRawStatement() got = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/database/database.go b/pkg/database/database.go index 9ae09b2..3a6beeb 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -2,7 +2,7 @@ package database import ( "context" - + "database/sql" "git.anthrove.art/Anthrove/otter-space-sdk/v2/pkg/models" ) @@ -27,4 +27,10 @@ type OtterSpace interface { // TagGroup contains all function that are needed to manage the TagGroup TagGroup + + // ExecuteRawStatement run a custom query. + ExecuteRawStatement(ctx context.Context, query string, args ...any) error + + // QueryRawStatement runs a custom query and returns the table + QueryRawStatement(ctx context.Context, query string, args ...any) (*sql.Rows, error) } diff --git a/pkg/database/postgres.go b/pkg/database/postgres.go index 57a87dd..001d7b0 100644 --- a/pkg/database/postgres.go +++ b/pkg/database/postgres.go @@ -2,6 +2,7 @@ package database import ( "context" + "database/sql" "embed" "fmt" log2 "log" @@ -160,8 +161,6 @@ func (p *postgresqlConnection) GetSourceByDomain(ctx context.Context, sourceDoma return postgres.GetSourceByDomain(ctx, p.db, sourceDomain) } -// NEW FUNCTIONS - func (p *postgresqlConnection) UpdateUserSourceScrapeTimeInterval(ctx context.Context, anthroveUserID models.AnthroveUserID, sourceID models.AnthroveSourceID, scrapeTime models.AnthroveScrapeTimeInterval) error { return postgres.UpdateUserSourceScrapeTimeInterval(ctx, p.db, anthroveUserID, sourceID, scrapeTime) } @@ -231,6 +230,14 @@ func (p *postgresqlConnection) CreateTagGroupInBatch(ctx context.Context, tagGro } +func (p *postgresqlConnection) ExecuteRawStatement(ctx context.Context, query string, args ...any) error { + return postgres.ExecuteRawStatement(ctx, p.db, query, args...) +} + +func (p *postgresqlConnection) QueryRawStatement(ctx context.Context, query string, args ...any) (*sql.Rows, error) { + return postgres.QueryRawStatement(ctx, p.db, query, args...) +} + // HELPER func (p *postgresqlConnection) migrateDatabase(dbPool *gorm.DB) error { diff --git a/pkg/database/postgres_test.go b/pkg/database/postgres_test.go index 46efe94..221b882 100644 --- a/pkg/database/postgres_test.go +++ b/pkg/database/postgres_test.go @@ -2,6 +2,7 @@ package database import ( "context" + "database/sql" "fmt" "reflect" "testing" @@ -3671,3 +3672,120 @@ func checkFavoritePosts(got *models.FavoriteList, want *models.FavoriteList) boo return true } + +func Test_postgresqlConnection_ExecuteRawStatement(t *testing.T) { + // Setup trow away container + ctx := context.Background() + container, gormDB, err := test.StartPostgresContainer(ctx) + if err != nil { + t.Fatalf("Could not start PostgreSQL container: %v", err) + } + defer container.Terminate(ctx) + + // Test + type args struct { + ctx context.Context + query string + args []any + } + tests := []struct { + name string + args args + want *sql.Rows + wantErr bool + }{ + { + name: "Test 01: Empty Query", + args: args{ + ctx: ctx, + query: "", + args: nil, + }, + want: nil, + wantErr: true, + }, + { + name: "Test 02: Nil Query", + args: args{ + ctx: ctx, + query: "aasd", + args: nil, + }, + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &postgresqlConnection{ + db: gormDB, + debug: true, + } + err := p.ExecuteRawStatement(tt.args.ctx, tt.args.query, tt.args.args...) + if (err != nil) != tt.wantErr { + t.Errorf("ExecuteRawStatement() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} + +func Test_postgresqlConnection_QueryRawStatement(t *testing.T) { + // Setup trow away container + ctx := context.Background() + container, gormDB, err := test.StartPostgresContainer(ctx) + if err != nil { + t.Fatalf("Could not start PostgreSQL container: %v", err) + } + defer container.Terminate(ctx) + + // Test + type args struct { + ctx context.Context + query string + args []any + } + tests := []struct { + name string + args args + want *sql.Rows + wantErr bool + }{ + { + name: "Test 01: Empty Query", + args: args{ + ctx: ctx, + query: "", + args: nil, + }, + want: nil, + wantErr: true, + }, + { + name: "Test 02: Nil Query", + args: args{ + ctx: ctx, + query: "aasd", + args: nil, + }, + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &postgresqlConnection{ + db: gormDB, + debug: true, + } + got, err := p.QueryRawStatement(tt.args.ctx, tt.args.query, tt.args.args...) + if (err != nil) != tt.wantErr { + t.Errorf("QueryRawStatement() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("QueryRawStatement() got = %v, want %v", got, tt.want) + } + }) + } +}