Skip to content

Commit ed21c45

Browse files
Fixes for cases when target data is absent (on GPU and in eval_set).
ref:d6afb0c6ece9cb3940272eac9e4ccf6a8ee40cd8
1 parent 002d849 commit ed21c45

26 files changed

Lines changed: 1369 additions & 124 deletions

File tree

catboost/cuda/data/binarizations_manager.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -344,10 +344,13 @@ namespace NCatboostCuda {
344344
void TBinarizedFeaturesManager::CreateCtrConfigsFromDescription(const NCatboostOptions::TCtrDescription& ctrDescription,
345345
TMap<ECtrType, TSet<TCtrConfig>>* grouppedConfigs) const{
346346
for (const auto& prior : ctrDescription.GetPriors()) {
347-
CB_ENSURE(!TargetBorders.empty(), "Enable ctr description should be done after target borders are set");
347+
ECtrType type = ctrDescription.Type;
348+
if ((type != ECtrType::Counter) && !HasTargetBinarization()) {
349+
continue;
350+
}
351+
348352
CB_ENSURE(ctrDescription.GetPriors().size(), "Set priors first");
349353

350-
ECtrType type = ctrDescription.Type;
351354
TCtrConfig defaultConfig;
352355

353356
defaultConfig.Prior = prior;

catboost/cuda/gpu_data/dataset_helpers.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -104,13 +104,15 @@ TVector<ui32> NCatboostCuda::GetLearnFeatureIds(NCatboostCuda::TBinarizedFeature
104104

105105
namespace NCatboostCuda {
106106
TMirrorBuffer<ui8> BuildBinarizedTarget(const TBinarizedFeaturesManager& featuresManager, const TVector<float>& targets) {
107-
CB_ENSURE(featuresManager.HasTargetBinarization(),
108-
"Error: No target binarization found. Can't make binarized target. Probably input labels columns was constant") ;
109-
auto& borders = featuresManager.GetTargetBorders();
110-
111-
auto binarizedTarget = NCB::BinarizeLine<ui8>(targets,
112-
ENanMode::Forbidden,
113-
borders);
107+
TVector<ui8> binarizedTarget;
108+
if (featuresManager.HasTargetBinarization()) {
109+
auto& borders = featuresManager.GetTargetBorders();
110+
binarizedTarget = NCB::BinarizeLine<ui8>(targets,
111+
ENanMode::Forbidden,
112+
borders);
113+
} else {
114+
binarizedTarget.resize(targets.size(), 0);
115+
}
114116

115117
TMirrorBuffer<ui8> binarizedTargetGpu = TMirrorBuffer<ui8>::Create(NCudaLib::TMirrorMapping(binarizedTarget.size()));
116118
binarizedTargetGpu.Write(binarizedTarget);

catboost/cuda/methods/boosting_progress_tracker.cpp

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ namespace NCatboostCuda {
2121
TBoostingProgressTracker::TBoostingProgressTracker(const NCatboostOptions::TCatBoostOptions& catBoostOptions,
2222
const NCatboostOptions::TOutputFilesOptions& outputFilesOptions,
2323
bool hasTest,
24+
bool testHasTarget,
2425
ui32 cpuApproxDim,
2526
const TMaybe<std::function<bool(const TMetricsAndTimeLeftHistory&)>>& onEndIterationCallback)
2627
: CatboostOptions(catBoostOptions)
@@ -37,6 +38,7 @@ namespace NCatboostCuda {
3738
, ProfileInfo(catBoostOptions.BoostingOptions->IterationCount)
3839
, MetricDescriptions(GetMetricsDescription(GetCpuMetrics(Metrics)))
3940
, IsSkipOnTrainFlags(GetSkipMetricOnTrain(GetCpuMetrics(Metrics)))
41+
, IsSkipOnTestFlags(GetSkipMetricOnTest(testHasTarget, GetCpuMetrics(Metrics)))
4042
{
4143
if (OutputOptions.AllowWriteFiles()) {
4244
CreateMetaFile(OutputFiles,
@@ -58,6 +60,11 @@ namespace NCatboostCuda {
5860
CatboostOptions.Save(&options);
5961
CatBoostOptionsStr = ToString<NJson::TJsonValue>(options);
6062
}
63+
64+
if (HasTest && IsSkipOnTestFlags[0]) {
65+
CATBOOST_WARNING_LOG << "Warning: Eval metric " << Metrics[0]->GetMetricDescription() <<
66+
" needs Target data, but test dataset does not have it so it won't be calculated" << Endl;
67+
}
6168
}
6269

6370
void TBoostingProgressTracker::OnFirstCall() {
@@ -82,12 +89,13 @@ namespace NCatboostCuda {
8289
History.TimeHistory.push_back({ProfileInfo.GetProfileResults().PassedTime,
8390
ProfileInfo.GetProfileResults().RemainingTime});
8491

92+
constexpr size_t evalMetricIdx = 0;
8593
Log((int)Iteration,
8694
MetricDescriptions,
8795
History.LearnMetricsHistory,
8896
History.TestMetricsHistory,
89-
ErrorTracker.GetBestError(),
90-
ErrorTracker.GetBestIteration(),
97+
!IsSkipOnTestFlags[evalMetricIdx] ? TMaybe<double>(ErrorTracker.GetBestError()) : Nothing(),
98+
!IsSkipOnTestFlags[evalMetricIdx] ? TMaybe<int>(ErrorTracker.GetBestIteration()) : Nothing(),
9199
ProfileInfo.GetProfileResults(),
92100
LearnToken,
93101
TestTokens,
@@ -126,15 +134,20 @@ namespace NCatboostCuda {
126134
// In case of changing the order it should be changed in CPU mode also.
127135
const int errorTrackerMetricIdx = calcErrorTrackerMetric ? 0 : -1;
128136
for (int i = 0; i < Metrics.ysize(); ++i) {
129-
if (calcAllMetrics || i == errorTrackerMetricIdx) {
130-
auto metricValue = Metrics[i]->GetCpuMetric().GetFinalError(metricCalcer.Compute(Metrics[i].Get()));
131-
History.AddTestError(0 /*testIdx*/, Metrics[i]->GetCpuMetric(), metricValue, i == errorTrackerMetricIdx);
132-
133-
if (i == errorTrackerMetricIdx) {
134-
ErrorTracker.AddError(metricValue, static_cast<int>(GetCurrentIteration()));
135-
if (OutputOptions.UseBestModel && static_cast<int>(GetCurrentIteration() + 1) >= OutputOptions.BestModelMinTrees) {
136-
BestModelMinTreesTracker.AddError(metricValue, static_cast<int>(GetCurrentIteration()));
137-
}
137+
if (!calcAllMetrics && (i != errorTrackerMetricIdx)) {
138+
continue;
139+
}
140+
if (IsSkipOnTestFlags[i]) {
141+
continue;
142+
}
143+
144+
auto metricValue = Metrics[i]->GetCpuMetric().GetFinalError(metricCalcer.Compute(Metrics[i].Get()));
145+
History.AddTestError(0 /*testIdx*/, Metrics[i]->GetCpuMetric(), metricValue, i == errorTrackerMetricIdx);
146+
147+
if (i == errorTrackerMetricIdx) {
148+
ErrorTracker.AddError(metricValue, static_cast<int>(GetCurrentIteration()));
149+
if (OutputOptions.UseBestModel && static_cast<int>(GetCurrentIteration() + 1) >= OutputOptions.BestModelMinTrees) {
150+
BestModelMinTreesTracker.AddError(metricValue, static_cast<int>(GetCurrentIteration()));
138151
}
139152
}
140153
}
@@ -171,8 +184,12 @@ namespace NCatboostCuda {
171184

172185
// WriteHistory & update ErrorTracker
173186
for (ui64 iteration = 0; iteration < Iteration; ++iteration) {
187+
const int testIdxToLog = 0;
174188
if (ShouldCalcMetricOnIteration(iteration) && iteration < History.TestMetricsHistory.size()) {
175-
const int testIdxToLog = 0;
189+
if (IsSkipOnTestFlags[testIdxToLog]) {
190+
continue;
191+
}
192+
176193
const int metricIdxToLog = 0;
177194
const TString& metricDescription = Metrics[metricIdxToLog]->GetCpuMetric().GetDescription();
178195
const double error = History.TestMetricsHistory[iteration][testIdxToLog].at(metricDescription);
@@ -187,8 +204,8 @@ namespace NCatboostCuda {
187204
MetricDescriptions,
188205
History.LearnMetricsHistory,
189206
History.TestMetricsHistory,
190-
ErrorTracker.GetBestError(),
191-
ErrorTracker.GetBestIteration(),
207+
!IsSkipOnTestFlags[testIdxToLog] ? TMaybe<double>(ErrorTracker.GetBestError()) : Nothing(),
208+
!IsSkipOnTestFlags[testIdxToLog] ? TMaybe<int>(ErrorTracker.GetBestIteration()) : Nothing(),
192209
TProfileResults(History.TimeHistory[iteration].PassedTime, History.TimeHistory[iteration].RemainingTime),
193210
LearnToken,
194211
TestTokens,

catboost/cuda/methods/boosting_progress_tracker.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ namespace NCatboostCuda {
2828
TBoostingProgressTracker(const NCatboostOptions::TCatBoostOptions& catBoostOptions,
2929
const NCatboostOptions::TOutputFilesOptions& outputFilesOptions,
3030
bool hasTest,
31+
bool testHasTarget,
3132
ui32 cpuApproxDim,
3233
const TMaybe<std::function<bool(const TMetricsAndTimeLeftHistory&)>>& onEndIterationCallback);
3334

@@ -73,6 +74,10 @@ namespace NCatboostCuda {
7374
return this->History;
7475
}
7576

77+
bool EvalMetricWasCalculated() const {
78+
return HasTest && !IsSkipOnTestFlags[0];
79+
}
80+
7681
private:
7782
void OnFirstCall();
7883

@@ -128,6 +133,7 @@ namespace NCatboostCuda {
128133

129134
TVector<TString> MetricDescriptions;
130135
TVector<bool> IsSkipOnTrainFlags;
136+
TVector<bool> IsSkipOnTestFlags;
131137
TVector<TVector<double>> BestTestCursor;
132138

133139
size_t Iteration = 0;

catboost/cuda/train_lib/train_template.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,15 @@ namespace NCatboostCuda {
4444
TBoostingProgressTracker progressTracker(catBoostOptions,
4545
outputOptions,
4646
test != nullptr,
47+
/*testHasTarget*/ (test != nullptr) && test->MetaInfo.HasTarget,
4748
approxDimension,
4849
onEndIterationCallback);
4950

5051
boosting.SetBoostingProgressTracker(&progressTracker);
5152

5253
auto model = boosting.Run();
5354

54-
if (test) {
55+
if (progressTracker.EvalMetricWasCalculated()) {
5556
const auto& errorTracker = progressTracker.GetErrorTracker();
5657
CATBOOST_NOTICE_LOG << "bestTest = " << errorTracker.GetBestError() << Endl;
5758
CATBOOST_NOTICE_LOG << "bestIteration = " << errorTracker.GetBestIteration() << Endl;
@@ -62,6 +63,9 @@ namespace NCatboostCuda {
6263
if (outputOptions.ShrinkModelToBestIteration()) {
6364
if (test == nullptr) {
6465
CATBOOST_INFO_LOG << "Warning: can't use-best-model without test set. Will skip model shrinking";
66+
} else if (!progressTracker.EvalMetricWasCalculated()) {
67+
CATBOOST_INFO_LOG << "Warning: can't use-best-model because eval metric was not calculated "
68+
"due to the absence of target data in test set. Will skip model shrinking";
6569
} else {
6670
const auto& errorTracker = progressTracker.GetErrorTracker();
6771
const auto& bestModelTracker = progressTracker.GetBestModelMinTreesTracker();

catboost/libs/algo/helpers.cpp

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -132,27 +132,33 @@ void CalcErrors(
132132
}
133133
const auto& targetData = testDataPtr->TargetData;
134134

135-
auto target = GetMaybeTarget(targetData).GetOrElse(TConstArrayRef<float>());
135+
auto maybeTarget = GetMaybeTarget(targetData);
136+
auto target = maybeTarget.GetOrElse(TConstArrayRef<float>());
136137
auto weights = GetWeights(targetData);
137138
auto queryInfo = GetGroupInfo(targetData);
138139

139140
const auto& testApprox = ctx->LearnProgress.TestApprox[testIdx];
140141
for (int i = 0; i < errors.ysize(); ++i) {
141-
if (calcAllMetrics || i == errorTrackerMetricIdx) {
142-
const auto& additiveStats = EvalErrors(
143-
testApprox,
144-
target,
145-
weights,
146-
queryInfo,
147-
errors[i],
148-
ctx->LocalExecutor
149-
);
150-
bool updateBestIteration = (i == 0) && (testIdx == trainingDataProviders.Test.size() - 1);
151-
ctx->LearnProgress.MetricsAndTimeHistory.AddTestError(testIdx,
152-
*errors[i].Get(),
153-
errors[i]->GetFinalError(additiveStats),
154-
updateBestIteration);
142+
if (!calcAllMetrics && (i != errorTrackerMetricIdx)) {
143+
continue;
155144
}
145+
if (!maybeTarget && errors[i]->NeedTarget()) {
146+
continue;
147+
}
148+
149+
const auto& additiveStats = EvalErrors(
150+
testApprox,
151+
target,
152+
weights,
153+
queryInfo,
154+
errors[i],
155+
ctx->LocalExecutor
156+
);
157+
bool updateBestIteration = (i == 0) && (testIdx == trainingDataProviders.Test.size() - 1);
158+
ctx->LearnProgress.MetricsAndTimeHistory.AddTestError(testIdx,
159+
*errors[i].Get(),
160+
errors[i]->GetFinalError(additiveStats),
161+
updateBestIteration);
156162
}
157163
}
158164
}

catboost/libs/data_new/target.cpp

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -768,7 +768,9 @@ TBinClassTarget::TBinClassTarget(
768768
)
769769
{
770770
if (!skipCheck) {
771-
CheckDataSize(target->size(), (size_t)GetObjectCount(), "target");
771+
if (target) {
772+
CheckDataSize(target->size(), (size_t)GetObjectCount(), "target");
773+
}
772774
CheckDataSize(weights->GetSize(), GetObjectCount(), "weights");
773775
CheckMaybeEmptyBaseline(baseline, GetObjectCount());
774776
}
@@ -779,7 +781,9 @@ TBinClassTarget::TBinClassTarget(
779781

780782

781783
void TBinClassTarget::GetSourceDataForSubsetCreation(TSubsetTargetDataCache* subsetTargetDataCache) const {
782-
subsetTargetDataCache->Targets.emplace(Target, TSharedVector<float>());
784+
if (Target) {
785+
subsetTargetDataCache->Targets.emplace(Target, TSharedVector<float>());
786+
}
783787
subsetTargetDataCache->Weights.emplace(Weights, TSharedWeights<float>());
784788
if (Baseline) {
785789
subsetTargetDataCache->Baselines.emplace(Baseline, TSharedVector<float>());
@@ -793,7 +797,7 @@ TTargetDataProviderPtr TBinClassTarget::GetSubset(
793797
return MakeIntrusive<TBinClassTarget>(
794798
GetSpecification().Description,
795799
std::move(objectsGrouping),
796-
subsetTargetDataCache.Targets.at(Target),
800+
Target ? subsetTargetDataCache.Targets.at(Target) : Target,
797801
subsetTargetDataCache.Weights.at(Weights),
798802

799803
// reuse empty vector
@@ -875,7 +879,9 @@ TMultiClassTarget::TMultiClassTarget(
875879
classCount >= 2,
876880
"MultiClass target data must have at least two classes (got " << classCount <<")"
877881
);
878-
CheckDataSize(target->size(), (size_t)GetObjectCount(), "target");
882+
if (target) {
883+
CheckDataSize(target->size(), (size_t)GetObjectCount(), "target");
884+
}
879885
CheckDataSize(weights->GetSize(), GetObjectCount(), "weights");
880886
CheckBaseline(baseline, GetObjectCount(), classCount);
881887
}
@@ -892,7 +898,9 @@ TMultiClassTarget::TMultiClassTarget(
892898

893899

894900
void TMultiClassTarget::GetSourceDataForSubsetCreation(TSubsetTargetDataCache* subsetTargetDataCache) const {
895-
subsetTargetDataCache->Targets.emplace(Target, TSharedVector<float>());
901+
if (Target) {
902+
subsetTargetDataCache->Targets.emplace(Target, TSharedVector<float>());
903+
}
896904
subsetTargetDataCache->Weights.emplace(Weights, TSharedWeights<float>());
897905
for (const auto& oneBaseline : Baseline) {
898906
subsetTargetDataCache->Baselines.emplace(oneBaseline, TSharedVector<float>());
@@ -913,7 +921,7 @@ TTargetDataProviderPtr TMultiClassTarget::GetSubset(
913921
GetSpecification().Description,
914922
std::move(objectsGrouping),
915923
ClassCount,
916-
subsetTargetDataCache.Targets.at(Target),
924+
Target ? subsetTargetDataCache.Targets.at(Target) : Target,
917925
subsetTargetDataCache.Weights.at(Weights),
918926
std::move(subsetBaseline),
919927
true
@@ -981,7 +989,9 @@ TRegressionTarget::TRegressionTarget(
981989
)
982990
{
983991
if (!skipCheck) {
984-
CheckDataSize(target->size(), (size_t)GetObjectCount(), "target");
992+
if (target) {
993+
CheckDataSize(target->size(), (size_t)GetObjectCount(), "target");
994+
}
985995
CheckDataSize(weights->GetSize(), GetObjectCount(), "weights");
986996
CheckMaybeEmptyBaseline(baseline, GetObjectCount());
987997
}
@@ -991,7 +1001,9 @@ TRegressionTarget::TRegressionTarget(
9911001
}
9921002

9931003
void TRegressionTarget::GetSourceDataForSubsetCreation(TSubsetTargetDataCache* subsetTargetDataCache) const {
994-
subsetTargetDataCache->Targets.emplace(Target, TSharedVector<float>());
1004+
if (Target) {
1005+
subsetTargetDataCache->Targets.emplace(Target, TSharedVector<float>());
1006+
}
9951007
subsetTargetDataCache->Weights.emplace(Weights, TSharedWeights<float>());
9961008
if (Baseline) {
9971009
subsetTargetDataCache->Baselines.emplace(Baseline, TSharedVector<float>());
@@ -1005,7 +1017,7 @@ TTargetDataProviderPtr TRegressionTarget::GetSubset(
10051017
return MakeIntrusive<TRegressionTarget>(
10061018
GetSpecification().Description,
10071019
std::move(objectsGrouping),
1008-
subsetTargetDataCache.Targets.at(Target),
1020+
Target ? subsetTargetDataCache.Targets.at(Target) : Target,
10091021
subsetTargetDataCache.Weights.at(Weights),
10101022

10111023
// reuse empty vector
@@ -1054,7 +1066,9 @@ TGroupwiseRankingTarget::TGroupwiseRankingTarget(
10541066
)
10551067
{
10561068
if (!skipCheck) {
1057-
CheckDataSize(target->size(), (size_t)GetObjectCount(), "target");
1069+
if (target) {
1070+
CheckDataSize(target->size(), (size_t)GetObjectCount(), "target");
1071+
}
10581072
CheckDataSize(weights->GetSize(), GetObjectCount(), "weights");
10591073
CheckMaybeEmptyBaseline(baseline, GetObjectCount());
10601074
CheckGroupInfo(*groupInfo, *ObjectsGrouping, false);
@@ -1066,7 +1080,9 @@ TGroupwiseRankingTarget::TGroupwiseRankingTarget(
10661080
}
10671081

10681082
void TGroupwiseRankingTarget::GetSourceDataForSubsetCreation(TSubsetTargetDataCache* subsetTargetDataCache) const {
1069-
subsetTargetDataCache->Targets.emplace(Target, TSharedVector<float>());
1083+
if (Target) {
1084+
subsetTargetDataCache->Targets.emplace(Target, TSharedVector<float>());
1085+
}
10701086
subsetTargetDataCache->Weights.emplace(Weights, TSharedWeights<float>());
10711087
if (Baseline) {
10721088
subsetTargetDataCache->Baselines.emplace(Baseline, TSharedVector<float>());
@@ -1081,7 +1097,7 @@ TTargetDataProviderPtr TGroupwiseRankingTarget::GetSubset(
10811097
return MakeIntrusive<TGroupwiseRankingTarget>(
10821098
GetSpecification().Description,
10831099
std::move(objectsGrouping),
1084-
subsetTargetDataCache.Targets.at(Target),
1100+
Target ? subsetTargetDataCache.Targets.at(Target) : Target,
10851101
subsetTargetDataCache.Weights.at(Weights),
10861102

10871103
// reuse empty vector

0 commit comments

Comments
 (0)