Skip to content

Commit e24b99a

Browse files
cbmswmatthias-kleiner
authored andcommitted
TPC splines: add a possibility to merge specific sectors
1 parent 3874e31 commit e24b99a

2 files changed

Lines changed: 144 additions & 12 deletions

File tree

Detectors/TPC/calibration/include/TPCCalibration/TPCFastSpaceChargeCorrectionHelper.h

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ using namespace o2::gpu;
4141

4242
class TPCFastSpaceChargeCorrectionHelper
4343
{
44+
public:
45+
using SectorScales = std::array<double, TPCFastTransformGeo::getNumberOfSectors()>;
46+
4447
public:
4548
/// _____________ Constructors / destructors __________________________
4649

@@ -115,15 +118,32 @@ class TPCFastSpaceChargeCorrectionHelper
115118
/// initialise inverse transformation from linear combination of several input corrections
116119
void initInverse(std::vector<o2::gpu::TPCFastSpaceChargeCorrection*>& corrections, const std::vector<float>& scaling, bool prn);
117120

118-
/// merge several corrections
121+
/// weighted add of several corrections
119122
/// \param mainCorrection main correction
120123
/// \param scale scaling factor for the main correction
121124
/// \param additionalCorrections vector of pairs of additional corrections and their scaling factors
122-
/// \param prn printout flag
123125
/// \return main correction merged with additional corrections
126+
void addCorrections(
127+
o2::gpu::TPCFastSpaceChargeCorrection& mainCorrection, double scale,
128+
const std::vector<std::pair<const o2::gpu::TPCFastSpaceChargeCorrection*, double>>& additionalCorrections);
129+
130+
/// weighted add of several corrections with sector-dependent scaling factors
131+
/// \param mainCorrection main correction
132+
/// \param scale scaling factor for the main correction
133+
/// \param additionalCorrections vector of pairs of additional corrections and their sector-dependent scaling factors
134+
/// \return main correction merged with additional corrections
135+
void addCorrections(
136+
o2::gpu::TPCFastSpaceChargeCorrection& mainCorrection, SectorScales scale,
137+
const std::vector<std::pair<const o2::gpu::TPCFastSpaceChargeCorrection*, SectorScales>>& additionalCorrections);
138+
139+
/// merge of two corrections sector-wise
140+
/// \param destinationCorrection main correction to which the source correction will be added
141+
/// \param sourceCorrection correction to be added to the main correction
142+
/// \param sectors vector of sector indices for which the correction will be added
143+
/// \return main correction merged with the source correction
124144
void mergeCorrections(
125-
o2::gpu::TPCFastSpaceChargeCorrection& mainCorrection, float scale,
126-
const std::vector<std::pair<const o2::gpu::TPCFastSpaceChargeCorrection*, float>>& additionalCorrections, bool prn);
145+
o2::gpu::TPCFastSpaceChargeCorrection& destinationCorrection, const o2::gpu::TPCFastSpaceChargeCorrection& sourceCorrection,
146+
const std::vector<int>& sectors);
127147

128148
/// how far the voxel mean is allowed to be outside of the voxel (1.1 means 10%)
129149
void setVoxelMeanValidityRange(double range)

Detectors/TPC/calibration/src/TPCFastSpaceChargeCorrectionHelper.cxx

Lines changed: 120 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1013,19 +1013,38 @@ void TPCFastSpaceChargeCorrectionHelper::initInverse(std::vector<o2::gpu::TPCFas
10131013
LOGP(info, "Inverse tooks: {}s", duration);
10141014
}
10151015

1016-
void TPCFastSpaceChargeCorrectionHelper::mergeCorrections(
1017-
o2::gpu::TPCFastSpaceChargeCorrection& mainCorrection, float mainScale,
1018-
const std::vector<std::pair<const o2::gpu::TPCFastSpaceChargeCorrection*, float>>& additionalCorrections, bool /*prn*/)
1016+
void TPCFastSpaceChargeCorrectionHelper::addCorrections(
1017+
o2::gpu::TPCFastSpaceChargeCorrection& mainCorrection, double mainScale,
1018+
const std::vector<std::pair<const o2::gpu::TPCFastSpaceChargeCorrection*, double>>& additionalCorrections)
10191019
{
1020-
/// merge several corrections
1020+
/// weighted add of several corrections
1021+
SectorScales mainSectorScale;
1022+
mainSectorScale.fill(mainScale);
1023+
std::vector<std::pair<const o2::gpu::TPCFastSpaceChargeCorrection*, SectorScales>> additionalSectorScales;
1024+
for (const auto& corr : additionalCorrections) {
1025+
SectorScales sectorScale;
1026+
sectorScale.fill(corr.second);
1027+
additionalSectorScales.emplace_back(corr.first, sectorScale);
1028+
}
1029+
1030+
addCorrections(mainCorrection, mainSectorScale, additionalSectorScales);
1031+
}
1032+
1033+
void TPCFastSpaceChargeCorrectionHelper::addCorrections(
1034+
o2::gpu::TPCFastSpaceChargeCorrection& mainCorrection, SectorScales mainScale,
1035+
const std::vector<std::pair<const o2::gpu::TPCFastSpaceChargeCorrection*, SectorScales>>& additionalCorrections)
1036+
{
1037+
/// weighted add of several corrections
10211038

10221039
TStopwatch watch;
1023-
LOG(info) << "fast space charge correction helper: Merge corrections";
1040+
LOG(info) << "fast space charge correction helper: Add corrections";
10241041

10251042
const auto& geo = mainCorrection.getGeometry();
10261043

10271044
for (int sector = 0; sector < geo.getNumberOfSectors(); sector++) {
10281045

1046+
float secMainScale = mainScale[sector];
1047+
10291048
auto myThread = [&](int iThread) {
10301049
for (int row = iThread; row < geo.getNumberOfRows(); row += mNthreads) {
10311050
auto& rowInfo = mainCorrection.getRowInfo(row);
@@ -1040,8 +1059,11 @@ void TPCFastSpaceChargeCorrectionHelper::mergeCorrections(
10401059
constexpr int nKnotPar3d = nKnotPar1d * 3;
10411060

10421061
{ // scale the main correction
1043-
1044-
double parscale[4] = {mainScale, mainScale, mainScale, mainScale * mainScale};
1062+
for (int i = 0; i < 3; i++) {
1063+
secRowInfo.maxCorr[i] *= secMainScale;
1064+
secRowInfo.minCorr[i] *= secMainScale;
1065+
}
1066+
double parscale[4] = {secMainScale, secMainScale, secMainScale, secMainScale * secMainScale};
10451067
for (int iknot = 0, ind = 0; iknot < spline.getNumberOfKnots(); iknot++) {
10461068
for (int ipar = 0; ipar < nKnotPar1d; ++ipar) {
10471069
for (int idim = 0; idim < 3; idim++, ind++) {
@@ -1074,6 +1096,10 @@ void TPCFastSpaceChargeCorrectionHelper::mergeCorrections(
10741096
const auto& corr = *(additionalCorrections[icorr].first);
10751097
double scale = additionalCorrections[icorr].second;
10761098
auto& linfo = corr.getRowInfo(row);
1099+
// double scale = additionalCorrections[icorr].second[sector];
1100+
// auto& linfo = corr.getSectorRowInfo(sector, row);
1101+
// secRowInfo.updateMaxValues(linfo.getMaxValues(), scale);
1102+
// secRowInfo.updateMaxValues(linfo.getMinValues(), scale);
10771103

10781104
double scaleU = rowInfo.gridMeasured.getYscale() / linfo.gridMeasured.getYscale();
10791105
double scaleV = rowInfo.gridMeasured.getZscale() / linfo.gridMeasured.getZscale();
@@ -1150,7 +1176,93 @@ void TPCFastSpaceChargeCorrectionHelper::mergeCorrections(
11501176
}
11511177

11521178
} // sector
1153-
float duration = watch.RealTime();
1179+
double duration = watch.RealTime();
1180+
LOGP(info, "Merge of corrections tooks: {}s", duration);
1181+
}
1182+
1183+
void TPCFastSpaceChargeCorrectionHelper::mergeCorrections(o2::gpu::TPCFastSpaceChargeCorrection& destinationCorrection,
1184+
const o2::gpu::TPCFastSpaceChargeCorrection& sourceCorrection,
1185+
const std::vector<int>& sectors)
1186+
{
1187+
/// merge of two corrections sector-wise
1188+
TStopwatch watch;
1189+
LOG(info) << "fast space charge correction helper: Merge corrections";
1190+
1191+
const auto& geo = destinationCorrection.getGeometry();
1192+
1193+
for (int sector : sectors) {
1194+
if (sector < 0 || sector >= geo.getNumberOfSectors()) {
1195+
LOGP(fatal, "Invalid sector number {}. Valid range is [0, {})", sector, geo.getNumberOfSectors());
1196+
continue;
1197+
}
1198+
auto myThread = [&](int iThread) {
1199+
for (int row = iThread; row < geo.getNumberOfRows(); row += mNthreads) {
1200+
1201+
{ // replace the direct correction
1202+
const auto& destSpline = destinationCorrection.getSpline(sector, row);
1203+
float* destSplineParameters = destinationCorrection.getCorrectionData(sector, row);
1204+
const auto& sourceSpline = sourceCorrection.getSpline(sector, row);
1205+
const float* sourceSplineParameters = sourceCorrection.getCorrectionData(sector, row);
1206+
1207+
// ensure the splines are compatible
1208+
if (destSpline.getGridX1().getNumberOfKnots() != sourceSpline.getGridX1().getNumberOfKnots() ||
1209+
destSpline.getGridX2().getNumberOfKnots() != sourceSpline.getGridX2().getNumberOfKnots()) {
1210+
LOGP(error, "Splines for sector {} row {} are not compatible: number of knots in U or V direction do not match", sector, row);
1211+
continue;
1212+
}
1213+
// replace the destination correction with the source correction for this sector and row
1214+
memcpy(destSplineParameters, sourceSplineParameters, destSpline.getNumberOfParameters() * sizeof(float));
1215+
}
1216+
1217+
{ // replace the inverse correction X
1218+
const auto& destSpline = destinationCorrection.getSplineInvX(sector, row);
1219+
float* destSplineParameters = destinationCorrection.getCorrectionDataInvX(sector, row);
1220+
const auto& sourceSpline = sourceCorrection.getSplineInvX(sector, row);
1221+
const float* sourceSplineParameters = sourceCorrection.getCorrectionDataInvX(sector, row);
1222+
// ensure the splines are compatible
1223+
if (destSpline.getGridX1().getNumberOfKnots() != sourceSpline.getGridX1().getNumberOfKnots() ||
1224+
destSpline.getGridX2().getNumberOfKnots() != sourceSpline.getGridX2().getNumberOfKnots()) {
1225+
LOGP(error, "Inverse X splines for sector {} row {} are not compatible: number of knots in U or V direction do not match", sector, row);
1226+
continue;
1227+
}
1228+
memcpy(destSplineParameters, sourceSplineParameters, destSpline.getNumberOfParameters() * sizeof(float));
1229+
}
1230+
1231+
{ // replace the inverse correction YZ
1232+
const auto& destSpline = destinationCorrection.getSplineInvYZ(sector, row);
1233+
float* destSplineParameters = destinationCorrection.getCorrectionDataInvYZ(sector, row);
1234+
const auto& sourceSpline = sourceCorrection.getSplineInvYZ(sector, row);
1235+
const float* sourceSplineParameters = sourceCorrection.getCorrectionDataInvYZ(sector, row);
1236+
// ensure the splines are compatible
1237+
if (destSpline.getGridX1().getNumberOfKnots() != sourceSpline.getGridX1().getNumberOfKnots() ||
1238+
destSpline.getGridX2().getNumberOfKnots() != sourceSpline.getGridX2().getNumberOfKnots()) {
1239+
LOGP(error, "Inverse YZ splines for sector {} row {} are not compatible: number of knots in U or V direction do not match", sector, row);
1240+
continue;
1241+
}
1242+
memcpy(destSplineParameters, sourceSplineParameters, destSpline.getNumberOfParameters() * sizeof(float));
1243+
}
1244+
1245+
// replace the sector row info
1246+
auto& destSecRowInfo = destinationCorrection.getSectorRowInfo(sector, row);
1247+
const auto& sourceSecRowInfo = sourceCorrection.getSectorRowInfo(sector, row);
1248+
destSecRowInfo = sourceSecRowInfo;
1249+
} // row
1250+
}; // thread
1251+
1252+
std::vector<std::thread> threads(mNthreads);
1253+
1254+
// run n threads
1255+
for (int i = 0; i < mNthreads; i++) {
1256+
threads[i] = std::thread(myThread, i);
1257+
}
1258+
1259+
// wait for the threads to finish
1260+
for (auto& th : threads) {
1261+
th.join();
1262+
}
1263+
1264+
} // sector
1265+
double duration = watch.RealTime();
11541266
LOGP(info, "Merge of corrections tooks: {}s", duration);
11551267
}
11561268

0 commit comments

Comments
 (0)