diff --git a/internal/postgres/custom.go b/internal/postgres/custom.go index 919184f..c360669 100644 --- a/internal/postgres/custom.go +++ b/internal/postgres/custom.go @@ -7,7 +7,24 @@ import ( "gorm.io/gorm" ) -func ExecuteRawStatement(ctx context.Context, db *gorm.DB, query string, args ...any) (*sql.Rows, error) { +func ExecuteRawStatement(ctx context.Context, db *gorm.DB, query string, args ...any) error { + if query == "" { + return errors.New("query can not be empty") + } + + if args == nil { + return errors.New("arguments can not be nil") + } + + result := db.WithContext(ctx).Exec(query, args...) + if result.Error != nil { + return result.Error + } + + return nil +} + +func QueryRawStatement(ctx context.Context, db *gorm.DB, query string, args ...any) (*sql.Rows, error) { if query == "" { return nil, errors.New("query can not be empty") } @@ -16,7 +33,7 @@ func ExecuteRawStatement(ctx context.Context, db *gorm.DB, query string, args .. return nil, errors.New("arguments can not be nil") } - result, err := db.WithContext(ctx).Exec(query, args...).Rows() + result, err := db.WithContext(ctx).Raw(query, args...).Rows() if err != nil { return nil, err } diff --git a/internal/postgres/custom_test.go b/internal/postgres/custom_test.go index e763e2a..22c176b 100644 --- a/internal/postgres/custom_test.go +++ b/internal/postgres/custom_test.go @@ -56,13 +56,69 @@ func TestExecuteRawStatement(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := ExecuteRawStatement(tt.args.ctx, tt.args.db, tt.args.query, tt.args.args...) + err := ExecuteRawStatement(tt.args.ctx, tt.args.db, tt.args.query, tt.args.args...) if (err != nil) != tt.wantErr { t.Errorf("ExecuteRawStatement() error = %v, wantErr %v", err, tt.wantErr) return } + }) + } +} + +func TestQueryRawStatement(t *testing.T) { + // Setup trow away container + ctx := context.Background() + container, gormDB, err := test.StartPostgresContainer(ctx) + if err != nil { + t.Fatalf("Could not start PostgreSQL container: %v", err) + } + defer container.Terminate(ctx) + + // Test + type args struct { + ctx context.Context + db *gorm.DB + query string + args []any + } + tests := []struct { + name string + args args + want *sql.Rows + wantErr bool + }{ + { + name: "Test 01: Empty Query", + args: args{ + ctx: ctx, + db: gormDB, + query: "", + args: nil, + }, + want: nil, + wantErr: true, + }, + { + name: "Test 02: Nil Query", + args: args{ + ctx: ctx, + db: gormDB, + query: "aasd", + args: nil, + }, + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := QueryRawStatement(tt.args.ctx, tt.args.db, tt.args.query, tt.args.args...) + if (err != nil) != tt.wantErr { + t.Errorf("QueryRawStatement() error = %v, wantErr %v", err, tt.wantErr) + return + } if !reflect.DeepEqual(got, tt.want) { - t.Errorf("ExecuteRawStatement() got = %v, want %v", got, tt.want) + t.Errorf("QueryRawStatement() got = %v, want %v", got, tt.want) } }) } diff --git a/pkg/database/database.go b/pkg/database/database.go index 4364884..3a6beeb 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -3,7 +3,6 @@ package database import ( "context" "database/sql" - "git.anthrove.art/Anthrove/otter-space-sdk/v2/pkg/models" ) @@ -30,5 +29,8 @@ type OtterSpace interface { TagGroup // ExecuteRawStatement run a custom query. - ExecuteRawStatement(ctx context.Context, query string, args ...any) (*sql.Rows, error) + ExecuteRawStatement(ctx context.Context, query string, args ...any) error + + // QueryRawStatement runs a custom query and returns the table + QueryRawStatement(ctx context.Context, query string, args ...any) (*sql.Rows, error) } diff --git a/pkg/database/postgres.go b/pkg/database/postgres.go index d5e9a4b..001d7b0 100644 --- a/pkg/database/postgres.go +++ b/pkg/database/postgres.go @@ -230,10 +230,14 @@ func (p *postgresqlConnection) CreateTagGroupInBatch(ctx context.Context, tagGro } -func (p *postgresqlConnection) ExecuteRawStatement(ctx context.Context, query string, args ...any) (*sql.Rows, error) { +func (p *postgresqlConnection) ExecuteRawStatement(ctx context.Context, query string, args ...any) error { return postgres.ExecuteRawStatement(ctx, p.db, query, args...) } +func (p *postgresqlConnection) QueryRawStatement(ctx context.Context, query string, args ...any) (*sql.Rows, error) { + return postgres.QueryRawStatement(ctx, p.db, query, args...) +} + // HELPER func (p *postgresqlConnection) migrateDatabase(dbPool *gorm.DB) error { diff --git a/pkg/database/postgres_test.go b/pkg/database/postgres_test.go index 85a4eed..221b882 100644 --- a/pkg/database/postgres_test.go +++ b/pkg/database/postgres_test.go @@ -3721,13 +3721,70 @@ func Test_postgresqlConnection_ExecuteRawStatement(t *testing.T) { db: gormDB, debug: true, } - got, err := p.ExecuteRawStatement(tt.args.ctx, tt.args.query, tt.args.args...) + err := p.ExecuteRawStatement(tt.args.ctx, tt.args.query, tt.args.args...) if (err != nil) != tt.wantErr { t.Errorf("ExecuteRawStatement() error = %v, wantErr %v", err, tt.wantErr) return } + }) + } +} + +func Test_postgresqlConnection_QueryRawStatement(t *testing.T) { + // Setup trow away container + ctx := context.Background() + container, gormDB, err := test.StartPostgresContainer(ctx) + if err != nil { + t.Fatalf("Could not start PostgreSQL container: %v", err) + } + defer container.Terminate(ctx) + + // Test + type args struct { + ctx context.Context + query string + args []any + } + tests := []struct { + name string + args args + want *sql.Rows + wantErr bool + }{ + { + name: "Test 01: Empty Query", + args: args{ + ctx: ctx, + query: "", + args: nil, + }, + want: nil, + wantErr: true, + }, + { + name: "Test 02: Nil Query", + args: args{ + ctx: ctx, + query: "aasd", + args: nil, + }, + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &postgresqlConnection{ + db: gormDB, + debug: true, + } + got, err := p.QueryRawStatement(tt.args.ctx, tt.args.query, tt.args.args...) + if (err != nil) != tt.wantErr { + t.Errorf("QueryRawStatement() error = %v, wantErr %v", err, tt.wantErr) + return + } if !reflect.DeepEqual(got, tt.want) { - t.Errorf("ExecuteRawStatement() got = %v, want %v", got, tt.want) + t.Errorf("QueryRawStatement() got = %v, want %v", got, tt.want) } }) }