-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.cpp
More file actions
49 lines (39 loc) · 1.47 KB
/
main.cpp
File metadata and controls
49 lines (39 loc) · 1.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <lambdamart/lambdamart.h>
using namespace std;
using namespace LambdaMART;
/**
 * Trains a LambdaMART model using the files named in @p config and, when a
 * validation dataset is configured, predicts on it and saves the output to
 * config->output_result.
 *
 * @param config Parsed configuration; must be non-null. Fields read:
 *               train_data, train_query, valid_data, valid_query,
 *               output_result.
 *
 * NOTE(review): the Dataset, RawDataset, Booster and Model objects below are
 * heap-allocated and never freed. Whether Booster takes ownership of the
 * datasets (or of the returned Model) is not visible from this file, so the
 * allocations are left unchanged; confirm the ownership contract before
 * converting them to std::unique_ptr.
 */
void demo(Config* config) {
    const char* train = config->train_data.c_str();
    const char* train_query = config->train_query.c_str();
    Log::Info("Loading training dataset %s and query boundaries %s", train, train_query);
    auto* X_train = new Dataset(config);
    X_train->load_dataset(train, train_query);

    // Validation data is optional; X_valid stays nullptr when none is configured.
    RawDataset* X_valid = nullptr;
    if (!config->valid_data.empty()) {
        const char* vali = config->valid_data.c_str();
        const char* vali_query = config->valid_query.c_str();
        Log::Info("Loading validation dataset %s and query boundaries %s", vali, vali_query);
        X_valid = new RawDataset();
        X_valid->load_dataset(vali, vali_query);
    } else {
        Log::Info("No validation dataset");
    }

    Log::Info("Start training...");
    Model* model = (new Booster(X_train, X_valid, config))->train();
    Log::Info("Training finished.");

    // Prediction runs on the validation set only; without this guard a null
    // X_valid would be handed to predict() when no validation data was given.
    if (X_valid == nullptr) {
        Log::Info("Skipping prediction: no validation dataset");
        return;
    }
    Log::Info("Predicting with validation dataset and saving output to %s", config->output_result.c_str());
    model->predict(X_valid, config->output_result);
}
int main(int argc, char** argv) {
cout << version() << endl;
if (argc <= 1) {
cout << help() << endl;
exit(0);
}
Log::Info("Using configuration file %s", argv[1]);
auto* config = new Config(argv[1]);
demo(config);
return 0;
}