diff --git a/aux/src/SamplingHelpers.cxx b/aux/src/SamplingHelpers.cxx index 361782a5a..16f0a31a8 100644 --- a/aux/src/SamplingHelpers.cxx +++ b/aux/src/SamplingHelpers.cxx @@ -127,7 +127,7 @@ void Aux::fill_scalar_center(PointCloud::Dataset& scalar, const PointCloud::Data void Aux::fill_scalar_aux(PointCloud::Dataset& scalar, const PointCloud::Dataset& aux) { if (aux.empty()) { - raise("empty 'aux' PC. you probably fell victim to issue #426"); + return; // Skip scalar aux fill if aux point cloud is empty } const std::vector auxnames = { "max_wire_interval", "min_wire_interval", "max_wire_type", "min_wire_type", diff --git a/clus/src/Facade_Cluster.cxx b/clus/src/Facade_Cluster.cxx index 0b8f2a668..4bcc022a5 100644 --- a/clus/src/Facade_Cluster.cxx +++ b/clus/src/Facade_Cluster.cxx @@ -199,18 +199,48 @@ std::vector Cluster::add_corrected_points( std::vector blob_passed; blob_passed.resize(children().size(), 0); // not passed by default - if (correction_name == "T0Correction") { - const auto& pct = pcts->pc_transform("T0Correction"); + if (correction_name == "T0Correction" || correction_name == "SCECorrection" || correction_name == "T0SCECorrection") { + + const auto& pct = pcts->pc_transform(correction_name); + + std::vector out_coords; + std::vector scope_coords; + + if (correction_name == "T0Correction") { + // Preserve historical behavior: only x is replaced by the T0-corrected coordinate. + out_coords = {"x_t0cor", "y_t0cor", "z_t0cor"}; + scope_coords = {"x_t0cor", "y", "z"}; + } + else if (correction_name == "SCECorrection") { + // SCE is a full 3D spatial correction. + out_coords = {"x_scecor", "y_scecor", "z_scecor"}; + scope_coords = out_coords; + } + else { // T0SCECorrection + // Combined correction: T0 first, then full 3D SCE correction. + out_coords = {"x_t0scecor", "y_t0scecor", "z_t0scecor"}; + scope_coords = out_coords; + } + for (size_t iblob = 0; iblob < this->children().size(); ++iblob) { Blob* blob = this->children().at(iblob); auto &lpc_3d = blob->local_pcs().at("3d"); + auto corrected_points = pct->forward(lpc_3d, {"x", "y", "z"}, - {"x_t0cor","y_t0cor","z_t0cor"}, t0, + out_coords, t0, blob->wpid().face(), blob->wpid().apa()); - lpc_3d.add("x_t0cor", *corrected_points.get("x_t0cor")); // only add x_t0cor + + // Add only the arrays that will actually be used by the new scope. + for (const auto& coord : scope_coords) { + if (coord != "x" && coord != "y" && coord != "z") { + lpc_3d.add(coord, *corrected_points.get(coord)); + } + } + auto filter_result = pct->filter(corrected_points, - {"x_t0cor", "y_t0cor", "z_t0cor"}, + out_coords, t0, blob->wpid().face(), blob->wpid().apa()); + auto arr_filter = filter_result.get("filter")->elements(); for (size_t ipt = 0; ipt < arr_filter.size(); ++ipt) { if (arr_filter[ipt] == 1) { @@ -219,8 +249,9 @@ std::vector Cluster::add_corrected_points( } } } - // the new scope should have the same name as the correction name. This is how the code can find corrections in the code ... - m_scopes["T0Correction"] = {"3d", {"x_t0cor", "y", "z"}}; // add the new scope + + // The new scope should have the same name as the correction name. + m_scopes[correction_name] = {"3d", scope_coords}; } else { raise("Cluster::add_corrected_points: no such correction: %s", correction_name); } diff --git a/clus/src/PCTransforms.cxx b/clus/src/PCTransforms.cxx index bd2d995df..99efd78b1 100644 --- a/clus/src/PCTransforms.cxx +++ b/clus/src/PCTransforms.cxx @@ -4,6 +4,12 @@ // // detector_volumes which defaults to "DetectorVolumes". // +#include +#include + +#include +#include +#include #include "WireCellClus/IPCTransform.h" #include "WireCellIface/IDetectorVolumes.h" @@ -150,6 +156,275 @@ class T0Correction : public WireCell::Clus::IPCTransform std::map> m_drift_speeds; }; +class SCECorrection : public WireCell::Clus::IPCTransform +{ +public: + virtual ~SCECorrection() = default; + + SCECorrection(IDetectorVolumes::pointer dv, + const std::string& sce_file, + double cathode_eps = 2.5) + : m_dv(dv) + , m_cathode_eps(cathode_eps) + { + std::unique_ptr infile(TFile::Open(sce_file.c_str(), "READ")); + if (!infile || infile->IsZombie()) { + throw std::runtime_error("SCECorrection: failed to open SCE file: " + sce_file); + } + + m_bkwd_e[0] = must_clone_hist(infile.get(), "TrueBkwd_Displacement_X_E"); + m_bkwd_e[1] = must_clone_hist(infile.get(), "TrueBkwd_Displacement_Y_E"); + m_bkwd_e[2] = must_clone_hist(infile.get(), "TrueBkwd_Displacement_Z_E"); + + m_bkwd_w[0] = must_clone_hist(infile.get(), "TrueBkwd_Displacement_X_W"); + m_bkwd_w[1] = must_clone_hist(infile.get(), "TrueBkwd_Displacement_Y_W"); + m_bkwd_w[2] = must_clone_hist(infile.get(), "TrueBkwd_Displacement_Z_W"); + } + + virtual Point forward(const Point& pos_in, double cluster_t0, int face, int apa) const override + { + auto dpos = cal_offset(pos_in, face, apa); + Point pos_out(pos_in); + pos_out[0] += dpos[0]; + pos_out[1] += dpos[1]; + pos_out[2] += dpos[2]; + return pos_out; + } + + virtual Point backward(const Point& pos_in, double cluster_t0, int face, int apa) const override + { + // First-pass approximate inverse. + auto dpos = cal_offset(pos_in, face, apa); + Point pos_out(pos_in); + pos_out[0] -= dpos[0]; + pos_out[1] -= dpos[1]; + pos_out[2] -= dpos[2]; + return pos_out; + } + + virtual bool filter(const Point& pos_corr, double cluster_t0, int face, int apa) const override + { + auto wpid = m_dv->contained_by(pos_corr); + if (!wpid.valid()) return false; + return (wpid.apa() == apa && wpid.face() == face); + } + + virtual Dataset forward(const Dataset& pc_in, + const std::vector& arr_in_names, + const std::vector& arr_out_names, + double cluster_t0, int face, int apa) const override + { + const auto& arr_x = pc_in.get(arr_in_names[0])->elements(); + const auto& arr_y = pc_in.get(arr_in_names[1])->elements(); + const auto& arr_z = pc_in.get(arr_in_names[2])->elements(); + + std::vector arr_x_out(arr_x.size()); + std::vector arr_y_out(arr_y.size()); + std::vector arr_z_out(arr_z.size()); + + for (size_t i = 0; i < arr_x.size(); ++i) { + Point pin(arr_x[i], arr_y[i], arr_z[i]); + auto dpos = cal_offset(pin, face, apa); + arr_x_out[i] = arr_x[i] + dpos[0]; + arr_y_out[i] = arr_y[i] + dpos[1]; + arr_z_out[i] = arr_z[i] + dpos[2]; + } + + Dataset ds_out; + ds_out.add(arr_out_names[0], Array(arr_x_out)); + ds_out.add(arr_out_names[1], Array(arr_y_out)); + ds_out.add(arr_out_names[2], Array(arr_z_out)); + return ds_out; + } + + virtual Dataset backward(const Dataset& pc_in, + const std::vector& arr_in_names, + const std::vector& arr_out_names, + double cluster_t0, int face, int apa) const override + { + const auto& arr_x = pc_in.get(arr_in_names[0])->elements(); + const auto& arr_y = pc_in.get(arr_in_names[1])->elements(); + const auto& arr_z = pc_in.get(arr_in_names[2])->elements(); + + std::vector arr_x_out(arr_x.size()); + std::vector arr_y_out(arr_y.size()); + std::vector arr_z_out(arr_z.size()); + + for (size_t i = 0; i < arr_x.size(); ++i) { + Point pin(arr_x[i], arr_y[i], arr_z[i]); + auto dpos = cal_offset(pin, face, apa); + arr_x_out[i] = arr_x[i] - dpos[0]; + arr_y_out[i] = arr_y[i] - dpos[1]; + arr_z_out[i] = arr_z[i] - dpos[2]; + } + + Dataset ds_out; + ds_out.add(arr_out_names[0], Array(arr_x_out)); + ds_out.add(arr_out_names[1], Array(arr_y_out)); + ds_out.add(arr_out_names[2], Array(arr_z_out)); + return ds_out; + } + + virtual Dataset filter(const Dataset& pc_in, + const std::vector& arr_names, + double cluster_t0, int face, int apa) const override + { + const auto& arr_x = pc_in.get(arr_names[0])->elements(); + const auto& arr_y = pc_in.get(arr_names[1])->elements(); + const auto& arr_z = pc_in.get(arr_names[2])->elements(); + + std::vector arr_filter(arr_x.size(), 0); + for (size_t i = 0; i < arr_x.size(); ++i) { + auto wpid = m_dv->contained_by(Point(arr_x[i], arr_y[i], arr_z[i])); + if (wpid.valid() && wpid.apa() == apa && wpid.face() == face) { + arr_filter[i] = 1; + } + } + + Dataset ds; + ds.add("filter", Array(arr_filter)); + return ds; + } + +private: + IDetectorVolumes::pointer m_dv; + double m_cathode_eps{2.5}; + + TH3F* m_bkwd_e[3] = {nullptr, nullptr, nullptr}; + TH3F* m_bkwd_w[3] = {nullptr, nullptr, nullptr}; + + static TH3F* must_clone_hist(TFile* tf, const char* name) + { + auto* h = dynamic_cast(tf->Get(name)); + if (!h) { + throw std::runtime_error(std::string("SCECorrection: missing histogram: ") + name); + } + auto* hc = dynamic_cast(h->Clone()); + hc->SetDirectory(nullptr); + return hc; + } + + Point cal_offset(const Point& pin, int face, int apa) const + { + // SCE map axes are in cm; WCT internal length is mm. Convert. + const double mm_per_cm = 10.0; + + // === apa1 SCE fix: transform local-frame pin[0] back to global X === + // Use bounding-box check to detect local-frame points and reflect them + // around the anode plane. + WirePlaneId wpid_q(kAllLayers, face, apa); + const auto bb_q = m_dv->inner_bounds(wpid_q); + const auto bbmin = bb_q.bounds().first; + const auto bbmax = bb_q.bounds().second; + const double bb_x_min = bbmin.x(); + const double bb_x_max = bbmax.x(); + + double pin_x_mm = pin[0]; + if (pin_x_mm < bb_x_min - 50.0 || pin_x_mm > bb_x_max + 50.0) { + // pin[0] is outside the apa's sensitive volume X range (with 5cm tolerance). + // It is in a mirrored local frame; reflect around the appropriate anode plane. + const auto dirx_q = m_dv->face_dirx(wpid_q); + const double anode_x_mm = (dirx_q < 0) ? bb_x_max : bb_x_min; + pin_x_mm = 2.0 * anode_x_mm - pin_x_mm; + } + + double xx = pin_x_mm / mm_per_cm; + double yy = pin[1] / mm_per_cm; + double zz = pin[2] / mm_per_cm; + + if (xx < -199.999) xx = -199.999; + else if (xx > 199.999) xx = 199.999; + + if (yy < -199.999) yy = -199.999; + else if (yy > 199.999) yy = 199.999; + + if (zz < 0.001) zz = 0.001; + else if (zz > 499.999) zz = 499.999; + + if (std::abs(xx) < m_cathode_eps) { + const auto dirx = m_dv->face_dirx(WirePlaneId(kAllLayers, face, apa)); + xx = (dirx < 0 ? -m_cathode_eps : m_cathode_eps); + } + + TH3F** hs = (xx < 0.0) ? const_cast(m_bkwd_e) + : const_cast(m_bkwd_w); + + double dx_mm = hs[0]->Interpolate(xx, yy, zz) * mm_per_cm; + const double dy_mm = hs[1]->Interpolate(xx, yy, zz) * mm_per_cm; + const double dz_mm = hs[2]->Interpolate(xx, yy, zz) * mm_per_cm; + + if (pin_x_mm != pin[0]) { + dx_mm = -dx_mm; + } + + return Point(dx_mm, dy_mm, dz_mm); + } +}; + +class T0SCECorrection : public WireCell::Clus::IPCTransform +{ +public: + virtual ~T0SCECorrection() = default; + + T0SCECorrection(IDetectorVolumes::pointer dv, + const std::string& sce_file, + double cathode_eps = 2.5) + : m_t0(dv) + , m_sce(dv, sce_file, cathode_eps) + {} + + virtual Point forward(const Point& pos_in, double cluster_t0, int face, int apa) const override + { + auto p1 = m_t0.forward(pos_in, cluster_t0, face, apa); + return m_sce.forward(p1, cluster_t0, face, apa); + } + + virtual Point backward(const Point& pos_in, double cluster_t0, int face, int apa) const override + { + auto p1 = m_sce.backward(pos_in, cluster_t0, face, apa); + return m_t0.backward(p1, cluster_t0, face, apa); + } + + virtual bool filter(const Point& pos_corr, double cluster_t0, int face, int apa) const override + { + return m_sce.filter(pos_corr, cluster_t0, face, apa); + } + + virtual Dataset forward(const Dataset& pc_in, + const std::vector& arr_in_names, + const std::vector& arr_out_names, + double cluster_t0, int face, int apa) const override + { + auto tmp = m_t0.forward(pc_in, arr_in_names, arr_out_names, cluster_t0, face, apa); + return m_sce.forward(tmp, arr_out_names, arr_out_names, cluster_t0, face, apa); + } + + virtual Dataset backward(const Dataset& pc_in, + const std::vector& arr_in_names, + const std::vector& arr_out_names, + double cluster_t0, int face, int apa) const override + { + auto tmp = m_sce.backward(pc_in, arr_in_names, arr_out_names, cluster_t0, face, apa); + return m_t0.backward(tmp, arr_out_names, arr_out_names, cluster_t0, face, apa); + } + + virtual Dataset filter(const Dataset& pc_in, + const std::vector& arr_names, + double cluster_t0, int face, int apa) const override + { + return m_sce.filter(pc_in, arr_names, cluster_t0, face, apa); + } + +private: + T0Correction m_t0; + SCECorrection m_sce; +}; + + + + + + class PCTransformSet : public WireCell::Clus::IPCTransformSet, public WireCell::IConfigurable { @@ -161,13 +436,26 @@ class PCTransformSet : public WireCell::Clus::IPCTransformSet, virtual Configuration default_configuration() const { Configuration cfg; cfg["detector_volumes"] = "DetectorVolumes"; + cfg["enable_sce_correction"] = false; + cfg["sce_file"] = ""; + cfg["sce_cathode_eps"] = 2.5; return cfg; } virtual void configure(const Configuration& cfg) { std::string dvtn = get(cfg, "detector_volumes", "DetectorVolumes"); auto dv = Factory::find_tn(dvtn); m_pcts["T0Correction"] = std::make_shared(dv); - // ... + const bool enable_sce = get(cfg, "enable_sce_correction", false); + const std::string sce_file = get(cfg, "sce_file", ""); + const double sce_cathode_eps = get(cfg, "sce_cathode_eps", 2.5); + + if (enable_sce) { + if (sce_file.empty()) { + throw std::runtime_error("PCTransformSet: enable_sce_correction=true but sce_file is empty"); + } + m_pcts["SCECorrection"] = std::make_shared(dv, sce_file, sce_cathode_eps); + m_pcts["T0SCECorrection"] = std::make_shared(dv, sce_file, sce_cathode_eps); + } } virtual IPCTransform::pointer pc_transform(const std::string &name) const { diff --git a/clus/src/clustering_switch_scope.cxx b/clus/src/clustering_switch_scope.cxx index d64042bfc..4fe9f4136 100644 --- a/clus/src/clustering_switch_scope.cxx +++ b/clus/src/clustering_switch_scope.cxx @@ -74,7 +74,7 @@ static void clustering_switch_scope( for (size_t iclus = 0; iclus < live_clusters.size(); ++iclus) { Cluster* cluster = live_clusters.at(iclus); - if (correction_name == "T0Correction") { + if (correction_name == "T0Correction" || correction_name == "SCECorrection" || correction_name == "T0SCECorrection") { // Get original bounds before correction // info("Cluster {} original bounds:", iclus); // const auto [earliest_orig, latest_orig] = cluster->get_earliest_latest_points(); diff --git a/clus/wscript_build b/clus/wscript_build index 782176c83..60330713a 100644 --- a/clus/wscript_build +++ b/clus/wscript_build @@ -1,2 +1,2 @@ -bld.smplpkg('WireCellClus', use='WireCellAux WCPQuickhull') +bld.smplpkg('WireCellClus', use='WireCellAux WCPQuickhull ROOTSYS')