diff --git a/internal/migration_acceptance_tests/function_cases_test.go b/internal/migration_acceptance_tests/function_cases_test.go index 69dc081..9f8b785 100644 --- a/internal/migration_acceptance_tests/function_cases_test.go +++ b/internal/migration_acceptance_tests/function_cases_test.go @@ -138,7 +138,7 @@ var functionAcceptanceTestCases = []acceptanceTestCase{ LANGUAGE SQL IMMUTABLE RETURNS NULL ON NULL INPUT - RETURN CONCAT(a, b); + RETURN CONCAT(a, b); `}, expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies}, }, @@ -246,7 +246,7 @@ var functionAcceptanceTestCases = []acceptanceTestCase{ LANGUAGE SQL IMMUTABLE RETURNS NULL ON NULL INPUT - RETURN CONCAT(a, b); + RETURN CONCAT(a, b); `, }, expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies}, @@ -571,6 +571,90 @@ var functionAcceptanceTestCases = []acceptanceTestCase{ }, expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies}, }, + + { + name: "Non-sql function used a default for a column", + oldSchemaDDL: nil, + newSchemaDDL: []string{ + ` + CREATE FUNCTION add() RETURNS text + LANGUAGE plpgsql + AS $$ + declare + begin + return 'hi'; + end; + $$; + + CREATE TABLE foobar ( + foo text DEFAULT add() NOT NULL + ); + + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies}, + }, + { + name: "Deletion of non-sql function used a default for a column", + oldSchemaDDL: []string{ + ` + CREATE FUNCTION add() RETURNS text + LANGUAGE plpgsql + AS $$ + declare + begin + return 'hi'; + end; + $$; + + CREATE TABLE foobar ( + foo text DEFAULT add() NOT NULL + ); + + `, + }, + newSchemaDDL: nil, + expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies, diff.MigrationHazardTypeDeletesData}, + }, + + { + name: "Sql function used a default for a column", + oldSchemaDDL: nil, + newSchemaDDL: []string{ + ` + CREATE FUNCTION add() RETURNS text + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN 'hi'; + + CREATE TABLE foobar ( + foo text DEFAULT add() NOT NULL + ); + + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{}, + }, + { + name: "Deletion of sql function used a default for a column", + oldSchemaDDL: []string{ + ` + CREATE FUNCTION add() RETURNS text + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN 'hi'; + + CREATE TABLE foobar ( + foo text DEFAULT add() NOT NULL + ); + + `, + }, + newSchemaDDL: nil, + expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeDeletesData}, + }, } func (suite *acceptanceTestSuite) TestFunctionTestCases() { diff --git a/internal/queries/dml.sql.go b/internal/queries/dml.sql.go index 7c7e70e..51c445c 100644 --- a/internal/queries/dml.sql.go +++ b/internal/queries/dml.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.18.0 +// sqlc v1.28.0 package queries diff --git a/internal/queries/models.sql.go b/internal/queries/models.sql.go index 5da8a9e..81c5718 100644 --- a/internal/queries/models.sql.go +++ b/internal/queries/models.sql.go @@ -1,7 +1,5 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.18.0 +// sqlc v1.28.0 package queries - -import () diff --git a/internal/queries/queries.sql b/internal/queries/queries.sql index 56a2419..8ea282c 100644 --- a/internal/queries/queries.sql +++ b/internal/queries/queries.sql @@ -103,7 +103,8 @@ SELECT identity_col_seq.seqmin AS min_value, identity_col_seq.seqcache AS cache_size, identity_col_seq.seqcycle AS is_cycle, - pg_catalog.format_type(a.atttypid, a.atttypmod) AS column_type + pg_catalog.format_type(a.atttypid, a.atttypmod) AS column_type, + default_value.oid AS default_value_oid FROM pg_catalog.pg_attribute AS a LEFT JOIN pg_catalog.pg_attrdef AS d @@ -117,6 +118,11 @@ LEFT JOIN ON a.attrelid = identity_col_seq.owner_relid AND a.attnum = identity_col_seq.owner_attnum +LEFT JOIN + pg_catalog.pg_attrdef AS default_value + ON + a.attrelid = default_value.adrelid + AND a.attnum = default_value.adnum WHERE a.attrelid = $1 AND a.attnum > 0 diff --git a/internal/queries/queries.sql.go b/internal/queries/queries.sql.go index 8e02988..bca145b 100644 --- a/internal/queries/queries.sql.go +++ b/internal/queries/queries.sql.go @@ -129,7 +129,8 @@ SELECT identity_col_seq.seqmin AS min_value, identity_col_seq.seqcache AS cache_size, identity_col_seq.seqcycle AS is_cycle, - pg_catalog.format_type(a.atttypid, a.atttypmod) AS column_type + pg_catalog.format_type(a.atttypid, a.atttypmod) AS column_type, + default_value.oid AS default_value_oid FROM pg_catalog.pg_attribute AS a LEFT JOIN pg_catalog.pg_attrdef AS d @@ -143,6 +144,11 @@ LEFT JOIN ON a.attrelid = identity_col_seq.owner_relid AND a.attnum = identity_col_seq.owner_attnum +LEFT JOIN + pg_catalog.pg_attrdef AS default_value + ON + a.attrelid = default_value.adrelid + AND a.attnum = default_value.adnum WHERE a.attrelid = $1 AND a.attnum > 0 @@ -165,6 +171,7 @@ type GetColumnsForTableRow struct { CacheSize sql.NullInt64 IsCycle sql.NullBool ColumnType string + DefaultValueOid interface{} } func (q *Queries) GetColumnsForTable(ctx context.Context, attrelid interface{}) ([]GetColumnsForTableRow, error) { @@ -191,6 +198,7 @@ func (q *Queries) GetColumnsForTable(ctx context.Context, attrelid interface{}) &i.CacheSize, &i.IsCycle, &i.ColumnType, + &i.DefaultValueOid, ); err != nil { return nil, err } diff --git a/internal/schema/schema.go b/internal/schema/schema.go index e8c50e1..af00a9a 100644 --- a/internal/schema/schema.go +++ b/internal/schema/schema.go @@ -234,6 +234,8 @@ type ( // It is used for data-packing purposes Size int Identity *ColumnIdentity + + DependsOnFunctions []SchemaQualifiedName } ) @@ -878,6 +880,11 @@ func (s *schemaFetcher) buildTable( } } + dependsOnFunctions, _ := s.fetchDependsOnFunctions(ctx, "pg_attrdef", column.DefaultValueOid) + if err != nil { + return Table{}, fmt.Errorf("fetchDependsOnFunctions(%s): %w", column.DefaultValueOid, err) + } + columns = append(columns, Column{ Name: column.ColumnName, Type: column.ColumnType, @@ -888,9 +895,10 @@ func (s *schemaFetcher) buildTable( // ''::text // CURRENT_TIMESTAMP // If empty, indicates that there is no default value. - Default: column.DefaultValue, - Size: int(column.ColumnSize), - Identity: identity, + Default: column.DefaultValue, + Size: int(column.ColumnSize), + Identity: identity, + DependsOnFunctions: dependsOnFunctions, }) } @@ -905,6 +913,7 @@ func (s *schemaFetcher) buildTable( SchemaName: table.TableSchemaName, EscapedName: EscapeIdentifier(table.TableName), } + return Table{ SchemaQualifiedName: schemaQualifiedName, Columns: columns, diff --git a/pkg/diff/sql_generator.go b/pkg/diff/sql_generator.go index 94f67ab..7892b60 100644 --- a/pkg/diff/sql_generator.go +++ b/pkg/diff/sql_generator.go @@ -1074,6 +1074,13 @@ func (t *tableSQLVertexGenerator) GetAddAlterDependencies(table, _ schema.Table) mustRun(t.GetSQLVertexId(table, diffTypeAddAlter)).after(buildTableVertexId(*table.ParentTable, diffTypeAddAlter)), ) } + + for _, col := range table.Columns { + for _, depFunction := range col.DependsOnFunctions { + deps = append(deps, mustRun(t.GetSQLVertexId(table, diffTypeDelete)).after(buildFunctionVertexId(depFunction, diffTypeDelete))) + } + } + return deps, nil } @@ -1142,6 +1149,13 @@ func (t *tableSQLVertexGenerator) GetDeleteDependencies(table schema.Table) ([]d mustRun(t.GetSQLVertexId(table, diffTypeDelete)).after(buildTableVertexId(*table.ParentTable, diffTypeDelete)), ) } + + for _, col := range table.Columns { + for _, depFunction := range col.DependsOnFunctions { + deps = append(deps, mustRun(t.GetSQLVertexId(table, diffTypeDelete)).before(buildFunctionVertexId(depFunction, diffTypeDelete))) + } + } + return deps, nil } @@ -1407,13 +1421,26 @@ func buildColumnVertexId(columnName string, diffType diffType) sqlVertexId { } func (csg *columnSQLVertexGenerator) GetAddAlterDependencies(col, _ schema.Column) ([]dependency, error) { - return []dependency{ + + var deps []dependency = []dependency{ mustRun(csg.GetSQLVertexId(col, diffTypeDelete)).before(csg.GetSQLVertexId(col, diffTypeAddAlter)), - }, nil + } + + for _, depFunction := range col.DependsOnFunctions { + deps = append(deps, mustRun(csg.GetSQLVertexId(col, diffTypeDelete)).after(buildFunctionVertexId(depFunction, diffTypeDelete))) + } + + return deps, nil + } -func (csg *columnSQLVertexGenerator) GetDeleteDependencies(_ schema.Column) ([]dependency, error) { - return nil, nil +func (csg *columnSQLVertexGenerator) GetDeleteDependencies(col schema.Column) ([]dependency, error) { + + var deps []dependency + for _, depFunction := range col.DependsOnFunctions { + deps = append(deps, mustRun(csg.GetSQLVertexId(col, diffTypeDelete)).before(buildFunctionVertexId(depFunction, diffTypeDelete))) + } + return deps, nil } type renameConflictingIndexSQLVertexGenerator struct {