109 lines
3.9 KiB
Ruby
109 lines
3.9 KiB
Ruby
# frozen_string_literal: true
|
|
|
|
module Database
  # Test-support guard that watches SQL statements and raises when data modification
  # statements tracked inside savepoints touch a set of tables rejected by
  # PreventCrossJoins.only_ci_or_only_main?.
  #
  # All state is kept in a Hash stored under Thread.current[:transaction_tracker]
  # (see .initial_data for its keys).
  module PreventCrossDatabaseModification
    CrossDatabaseModificationAcrossUnsupportedTablesError = Class.new(StandardError)

    # Adds a way to temporarily switch the guard off around a block.
    module GitlabDatabaseMixin
      # Runs the given block with the guard disabled and restores the previous
      # +:enabled+ value afterwards, even when the block raises.
      #
      # When there is no context, or the guard is already disabled, the block is simply
      # yielded to and the context is left untouched.
      #
      # @param url [Object] not used by this method; presumably a link explaining why the
      #   modification is allowed -- confirm against callers
      # @return [Object] whatever the block returns
      def allow_cross_database_modification_within_transaction(url:)
        cross_database_context = Database::PreventCrossDatabaseModification.cross_database_context
        return yield unless cross_database_context && cross_database_context[:enabled]

        transaction_tracker_enabled_was = cross_database_context[:enabled]
        cross_database_context[:enabled] = false

        # The ensure is scoped to the section that actually flipped the flag, so the
        # early-return path above can never overwrite :enabled with nil.
        begin
          yield
        ensure
          cross_database_context[:enabled] = transaction_tracker_enabled_was
        end
      end
    end

    # Helpers for switching the guard on and off (included into RSpec examples).
    module SpecHelpers
      # Subscribes to 'sql.active_record' notifications, resets the context and enables
      # the guard.
      #
      # With a block: yields, then always cleans up, and returns the block's value.
      # Without a block: leaves the guard enabled; the caller is expected to invoke
      # #cleanup_with_cross_database_modification_prevented later.
      def with_cross_database_modification_prevented
        # Five block parameters are kept on purpose so the block's arity is unchanged;
        # only the payload is used.
        subscriber = ActiveSupport::Notifications.subscribe('sql.active_record') do |_name, _start, _finish, _id, payload|
          PreventCrossDatabaseModification.prevent_cross_database_modification!(payload[:connection], payload[:sql])
        end

        PreventCrossDatabaseModification.reset_cross_database_context!
        PreventCrossDatabaseModification.cross_database_context.merge!(enabled: true, subscriber: subscriber)

        yield if block_given?
      ensure
        cleanup_with_cross_database_modification_prevented if block_given?
      end

      # Unsubscribes the stored subscriber and disables the guard.
      #
      # Does nothing when no context exists, so it is safe to call from an after hook or
      # an ensure clause even if setup never completed.
      def cleanup_with_cross_database_modification_prevented
        context = PreventCrossDatabaseModification.cross_database_context
        return unless context

        ActiveSupport::Notifications.unsubscribe(context[:subscriber])
        context[:enabled] = false
      end
    end

    # @return [Hash, nil] the tracking context stored in Thread.current, or nil when
    #   .reset_cross_database_context! has not been called for the current thread
    def self.cross_database_context
      Thread.current[:transaction_tracker]
    end

    # Replaces the current context with a fresh, disabled one.
    def self.reset_cross_database_context!
      Thread.current[:transaction_tracker] = initial_data
    end

    # @return [Hash] a new context: guard disabled, per-database depth counters
    #   defaulting to 0 and per-database table sets defaulting to an empty Set
    def self.initial_data
      {
        enabled: false,
        transaction_depth_by_db: Hash.new { |h, k| h[k] = 0 },
        modified_tables_by_db: Hash.new { |h, k| h[k] = Set.new }
      }
    end

    # Inspects one SQL statement and updates the tracking context.
    #
    # - 'SAVEPOINT...' increments the depth counter of the connection's database.
    # - 'RELEASE SAVEPOINT...' / 'ROLLBACK TO SAVEPOINT...' decrements it and forgets the
    #   tables recorded for that database once the counter is zero or below.
    # - Any other statement is ignored while every counter is zero; otherwise its DML
    #   tables (as reported by PgQuery) are recorded and the combined set is validated.
    #
    # @param connection [#pool] object whose pool.db_config.name names the database
    # @param sql [String] the statement text
    # @raise [CrossDatabaseModificationAcrossUnsupportedTablesError] when
    #   PreventCrossJoins.only_ci_or_only_main? rejects the tables recorded so far
    # @return [nil]
    def self.prevent_cross_database_modification!(connection, sql)
      context = cross_database_context

      # The context may be missing entirely, e.g. when this runs on a thread for which
      # reset_cross_database_context! was never called; treat that as "guard disabled".
      return unless context && context[:enabled]

      database = connection.pool.db_config.name

      if sql.start_with?('SAVEPOINT')
        context[:transaction_depth_by_db][database] += 1

        return
      elsif sql.start_with?('RELEASE SAVEPOINT', 'ROLLBACK TO SAVEPOINT')
        context[:transaction_depth_by_db][database] -= 1

        # `<= 0` rather than `== 0`: the counter starts at 0, so releasing a savepoint
        # that was never counted drives it negative.
        if context[:transaction_depth_by_db][database] <= 0
          context[:modified_tables_by_db][database].clear
        end

        return
      end

      return if context[:transaction_depth_by_db].values.all?(&:zero?)

      tables = PgQuery.parse(sql).dml_tables

      return if tables.empty?

      context[:modified_tables_by_db][database].merge(tables)

      all_tables = context[:modified_tables_by_db].values.flat_map(&:to_a)

      unless PreventCrossJoins.only_ci_or_only_main?(all_tables)
        raise Database::PreventCrossDatabaseModification::CrossDatabaseModificationAcrossUnsupportedTablesError,
          "Cross-database data modification queries (CI and Main) were detected within " \
          "a transaction '#{all_tables.join(", ")}' discovered"
      end
    end
  end
end
|
|
|
|
# Prepend the mixin onto Gitlab::Database's singleton class so that
# allow_cross_database_modification_within_transaction becomes callable on Gitlab::Database.
Gitlab::Database.singleton_class.prepend(Database::PreventCrossDatabaseModification::GitlabDatabaseMixin)
|
|
|
|
RSpec.configure do |rspec_config|
  rspec_config.include ::Database::PreventCrossDatabaseModification::SpecHelpers

  # Separate before/after hooks are deliberate: an around hook causes problems with
  # let_it_be record creations because it makes an extra savepoint, which breaks the
  # transaction count logic.
  rspec_config.before(:each, :prevent_cross_database_modification) { with_cross_database_modification_prevented }
  rspec_config.after(:each, :prevent_cross_database_modification) { cleanup_with_cross_database_modification_prevented }
end
|