debian-mirror-gitlab/app/services/ml/experiment_tracking/candidate_repository.rb

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

100 lines
2.6 KiB
Ruby
Raw Normal View History

2022-11-25 23:54:43 +05:30
# frozen_string_literal: true
module Ml
  module ExperimentTracking
    # Data-access helper for ML experiment-tracking candidates of a single
    # project: creating/updating candidates and attaching metrics, params and
    # tags (stored as candidate metadata) to them.
    class CandidateRepository
      attr_accessor :project, :user, :experiment, :candidate

      # @param project project that lookups are scoped to and new candidates belong to
      # @param user user recorded on newly created candidates
      def initialize(project, user)
        @project = project
        @user = user
      end

      # Looks up a candidate of this repository's project by its eid.
      # Delegates to the ::Ml::Candidate.with_project_id_and_eid scope; see that
      # scope for the exact return value.
      def by_eid(eid)
        ::Ml::Candidate.with_project_id_and_eid(project.id, eid)
      end

      # Creates a candidate under +experiment+ and then bulk-inserts its tags.
      #
      # @param experiment experiment whose +candidates+ association receives the record
      # @param start_time start time to store; nil is stored as 0
      # @param tags [Array<Hash>, nil] tag definitions with :key and :value entries
      # @param name [String, nil] explicit candidate name; when blank, the value
      #   of the 'mlflow.runName' tag (if any) is used instead
      # @return the created candidate
      # @raise whatever +experiment.candidates.create!+ raises when the record
      #   cannot be created (presumably ActiveRecord::RecordInvalid — confirm)
      def create!(experiment, start_time, tags = nil, name = nil)
        candidate = experiment.candidates.create!(
          user: user,
          name: candidate_name(name, tags),
          project: project,
          start_time: start_time || 0
        )

        add_tags(candidate, tags)
        candidate
      end

      # Applies optional status / end_time changes and saves the candidate.
      # A nil argument leaves the corresponding attribute untouched; status is
      # downcased before it is assigned.
      #
      # @return the result of +candidate.save+
      def update(candidate, status, end_time)
        candidate.status = status.downcase if status
        candidate.end_time = end_time if end_time
        candidate.save
      end

      # Creates a single metric record on +candidate+ via +metrics.create!+.
      def add_metric!(candidate, name, value, tracked_at, step)
        candidate.metrics.create!(
          name: name,
          value: value,
          tracked_at: tracked_at,
          step: step
        )
      end

      # Creates a single param record on +candidate+ via +params.create!+.
      def add_param!(candidate, name, value)
        candidate.params.create!(name: name, value: value)
      end

      # Creates a single tag (metadata record) on +candidate+ via +metadata.create!+.
      def add_tag!(candidate, name, value)
        candidate.metadata.create!(name: name, value: value)
      end

      # Bulk-inserts metrics. Each definition supplies :key (stored as name),
      # :value, :timestamp (stored as tracked_at) and :step.
      def add_metrics(candidate, metric_definitions)
        extra_keys = { tracked_at: :timestamp, step: :step }
        insert_many(candidate, metric_definitions, ::Ml::CandidateMetric, extra_keys)
      end

      # Bulk-inserts params. Each definition supplies :key (stored as name) and :value.
      def add_params(candidate, param_definitions)
        insert_many(candidate, param_definitions, ::Ml::CandidateParam)
      end

      # Bulk-inserts tags as candidate metadata. Each definition supplies
      # :key (stored as name) and :value.
      def add_tags(candidate, tag_definitions)
        insert_many(candidate, tag_definitions, ::Ml::CandidateMetadata)
      end

      private

      # created_at / updated_at attributes, both set to the current time.
      def timestamps
        current_time = Time.zone.now
        { created_at: current_time, updated_at: current_time }
      end

      # Builds one row per definition and writes them with a single
      # +entity_class.insert_all+ call. Does nothing when +candidate+ or
      # +definitions+ is blank.
      #
      # @param extra_keys [Hash{Symbol => Symbol}] destination column => key to
      #   read from each definition
      def insert_many(candidate, definitions, entity_class, extra_keys = {})
        return unless candidate.present? && definitions.present?

        # Read the clock once so every row of this batch shares the same
        # created_at / updated_at instead of drifting per row.
        batch_timestamps = timestamps

        entities = definitions.map do |d|
          {
            candidate_id: candidate.id,
            name: d[:key],
            value: d[:value],
            **extra_keys.transform_values { |old_key| d[old_key] },
            **batch_timestamps
          }
        end

        entity_class.insert_all(entities, returning: false) unless entities.empty?
      end

      # Returns +name+ when present; otherwise the value of the
      # 'mlflow.runName' tag, or nil when neither is available.
      def candidate_name(name, tags)
        return name if name.present?
        return unless tags.present?

        tags.detect { |t| t[:key] == 'mlflow.runName' }&.dig(:value)
      end
    end
  end
end