@@ -1013,19 +1013,38 @@ void TPCFastSpaceChargeCorrectionHelper::initInverse(std::vector<o2::gpu::TPCFas
10131013 LOGP (info, " Inverse tooks: {}s" , duration);
10141014}
10151015
1016- void TPCFastSpaceChargeCorrectionHelper::mergeCorrections (
1017- o2::gpu::TPCFastSpaceChargeCorrection& mainCorrection, float mainScale,
1018- const std::vector<std::pair<const o2::gpu::TPCFastSpaceChargeCorrection*, float >>& additionalCorrections, bool /* prn */ )
1016+ void TPCFastSpaceChargeCorrectionHelper::addCorrections (
1017+ o2::gpu::TPCFastSpaceChargeCorrection& mainCorrection, double mainScale,
1018+ const std::vector<std::pair<const o2::gpu::TPCFastSpaceChargeCorrection*, double >>& additionalCorrections)
10191019{
1020- // / merge several corrections
1020+ // / weighted add of several corrections
1021+ SectorScales mainSectorScale;
1022+ mainSectorScale.fill (mainScale);
1023+ std::vector<std::pair<const o2::gpu::TPCFastSpaceChargeCorrection*, SectorScales>> additionalSectorScales;
1024+ for (const auto & corr : additionalCorrections) {
1025+ SectorScales sectorScale;
1026+ sectorScale.fill (corr.second );
1027+ additionalSectorScales.emplace_back (corr.first , sectorScale);
1028+ }
1029+
1030+ addCorrections (mainCorrection, mainSectorScale, additionalSectorScales);
1031+ }
1032+
1033+ void TPCFastSpaceChargeCorrectionHelper::addCorrections (
1034+ o2::gpu::TPCFastSpaceChargeCorrection& mainCorrection, SectorScales mainScale,
1035+ const std::vector<std::pair<const o2::gpu::TPCFastSpaceChargeCorrection*, SectorScales>>& additionalCorrections)
1036+ {
1037+ // / weighted add of several corrections
10211038
10221039 TStopwatch watch;
1023- LOG (info) << " fast space charge correction helper: Merge corrections" ;
1040+ LOG (info) << " fast space charge correction helper: Add corrections" ;
10241041
10251042 const auto & geo = mainCorrection.getGeometry ();
10261043
10271044 for (int sector = 0 ; sector < geo.getNumberOfSectors (); sector++) {
10281045
1046+ float secMainScale = mainScale[sector];
1047+
10291048 auto myThread = [&](int iThread) {
10301049 for (int row = iThread; row < geo.getNumberOfRows (); row += mNthreads ) {
10311050 auto & rowInfo = mainCorrection.getRowInfo (row);
@@ -1040,8 +1059,11 @@ void TPCFastSpaceChargeCorrectionHelper::mergeCorrections(
10401059 constexpr int nKnotPar3d = nKnotPar1d * 3 ;
10411060
10421061 { // scale the main correction
1043-
1044- double parscale[4 ] = {mainScale, mainScale, mainScale, mainScale * mainScale};
1062+ for (int i = 0 ; i < 3 ; i++) {
1063+ secRowInfo.maxCorr [i] *= secMainScale;
1064+ secRowInfo.minCorr [i] *= secMainScale;
1065+ }
1066+ double parscale[4 ] = {secMainScale, secMainScale, secMainScale, secMainScale * secMainScale};
10451067 for (int iknot = 0 , ind = 0 ; iknot < spline.getNumberOfKnots (); iknot++) {
10461068 for (int ipar = 0 ; ipar < nKnotPar1d; ++ipar) {
10471069 for (int idim = 0 ; idim < 3 ; idim++, ind++) {
@@ -1074,6 +1096,10 @@ void TPCFastSpaceChargeCorrectionHelper::mergeCorrections(
10741096 const auto & corr = *(additionalCorrections[icorr].first );
10751097 double scale = additionalCorrections[icorr].second ;
10761098 auto & linfo = corr.getRowInfo (row);
1099+ // double scale = additionalCorrections[icorr].second[sector];
1100+ // auto& linfo = corr.getSectorRowInfo(sector, row);
1101+ // secRowInfo.updateMaxValues(linfo.getMaxValues(), scale);
1102+ // secRowInfo.updateMaxValues(linfo.getMinValues(), scale);
10771103
10781104 double scaleU = rowInfo.gridMeasured .getYscale () / linfo.gridMeasured .getYscale ();
10791105 double scaleV = rowInfo.gridMeasured .getZscale () / linfo.gridMeasured .getZscale ();
@@ -1150,7 +1176,93 @@ void TPCFastSpaceChargeCorrectionHelper::mergeCorrections(
11501176 }
11511177
11521178 } // sector
1153- float duration = watch.RealTime ();
1179+ double duration = watch.RealTime ();
1180+ LOGP (info, " Merge of corrections tooks: {}s" , duration);
1181+ }
1182+
1183+ void TPCFastSpaceChargeCorrectionHelper::mergeCorrections (o2::gpu::TPCFastSpaceChargeCorrection& destinationCorrection,
1184+ const o2::gpu::TPCFastSpaceChargeCorrection& sourceCorrection,
1185+ const std::vector<int >& sectors)
1186+ {
1187+ // / merge of two corrections sector-wise
1188+ TStopwatch watch;
1189+ LOG (info) << " fast space charge correction helper: Merge corrections" ;
1190+
1191+ const auto & geo = destinationCorrection.getGeometry ();
1192+
1193+ for (int sector : sectors) {
1194+ if (sector < 0 || sector >= geo.getNumberOfSectors ()) {
1195+ LOGP (fatal, " Invalid sector number {}. Valid range is [0, {})" , sector, geo.getNumberOfSectors ());
1196+ continue ;
1197+ }
1198+ auto myThread = [&](int iThread) {
1199+ for (int row = iThread; row < geo.getNumberOfRows (); row += mNthreads ) {
1200+
1201+ { // replace the direct correction
1202+ const auto & destSpline = destinationCorrection.getSpline (sector, row);
1203+ float * destSplineParameters = destinationCorrection.getCorrectionData (sector, row);
1204+ const auto & sourceSpline = sourceCorrection.getSpline (sector, row);
1205+ const float * sourceSplineParameters = sourceCorrection.getCorrectionData (sector, row);
1206+
1207+ // ensure the splines are compatible
1208+ if (destSpline.getGridX1 ().getNumberOfKnots () != sourceSpline.getGridX1 ().getNumberOfKnots () ||
1209+ destSpline.getGridX2 ().getNumberOfKnots () != sourceSpline.getGridX2 ().getNumberOfKnots ()) {
1210+ LOGP (error, " Splines for sector {} row {} are not compatible: number of knots in U or V direction do not match" , sector, row);
1211+ continue ;
1212+ }
1213+ // replace the destination correction with the source correction for this sector and row
1214+ memcpy (destSplineParameters, sourceSplineParameters, destSpline.getNumberOfParameters () * sizeof (float ));
1215+ }
1216+
1217+ { // replace the inverse correction X
1218+ const auto & destSpline = destinationCorrection.getSplineInvX (sector, row);
1219+ float * destSplineParameters = destinationCorrection.getCorrectionDataInvX (sector, row);
1220+ const auto & sourceSpline = sourceCorrection.getSplineInvX (sector, row);
1221+ const float * sourceSplineParameters = sourceCorrection.getCorrectionDataInvX (sector, row);
1222+ // ensure the splines are compatible
1223+ if (destSpline.getGridX1 ().getNumberOfKnots () != sourceSpline.getGridX1 ().getNumberOfKnots () ||
1224+ destSpline.getGridX2 ().getNumberOfKnots () != sourceSpline.getGridX2 ().getNumberOfKnots ()) {
1225+ LOGP (error, " Inverse X splines for sector {} row {} are not compatible: number of knots in U or V direction do not match" , sector, row);
1226+ continue ;
1227+ }
1228+ memcpy (destSplineParameters, sourceSplineParameters, destSpline.getNumberOfParameters () * sizeof (float ));
1229+ }
1230+
1231+ { // replace the inverse correction YZ
1232+ const auto & destSpline = destinationCorrection.getSplineInvYZ (sector, row);
1233+ float * destSplineParameters = destinationCorrection.getCorrectionDataInvYZ (sector, row);
1234+ const auto & sourceSpline = sourceCorrection.getSplineInvYZ (sector, row);
1235+ const float * sourceSplineParameters = sourceCorrection.getCorrectionDataInvYZ (sector, row);
1236+ // ensure the splines are compatible
1237+ if (destSpline.getGridX1 ().getNumberOfKnots () != sourceSpline.getGridX1 ().getNumberOfKnots () ||
1238+ destSpline.getGridX2 ().getNumberOfKnots () != sourceSpline.getGridX2 ().getNumberOfKnots ()) {
1239+ LOGP (error, " Inverse YZ splines for sector {} row {} are not compatible: number of knots in U or V direction do not match" , sector, row);
1240+ continue ;
1241+ }
1242+ memcpy (destSplineParameters, sourceSplineParameters, destSpline.getNumberOfParameters () * sizeof (float ));
1243+ }
1244+
1245+ // replace the sector row info
1246+ auto & destSecRowInfo = destinationCorrection.getSectorRowInfo (sector, row);
1247+ const auto & sourceSecRowInfo = sourceCorrection.getSectorRowInfo (sector, row);
1248+ destSecRowInfo = sourceSecRowInfo;
1249+ } // row
1250+ }; // thread
1251+
1252+ std::vector<std::thread> threads (mNthreads );
1253+
1254+ // run n threads
1255+ for (int i = 0 ; i < mNthreads ; i++) {
1256+ threads[i] = std::thread (myThread, i);
1257+ }
1258+
1259+ // wait for the threads to finish
1260+ for (auto & th : threads) {
1261+ th.join ();
1262+ }
1263+
1264+ } // sector
1265+ double duration = watch.RealTime ();
11541266 LOGP (info, " Merge of corrections tooks: {}s" , duration);
11551267}
11561268
0 commit comments