Commit 44c831c3 authored by Michael Buehlmann's avatar Michael Buehlmann
Browse files

allow reading single source rank in py module

parent 36a6fac5
......@@ -17,14 +17,18 @@ namespace py = pybind11;
class PyGenericIO : public gio::GenericIO {
public:
PyGenericIO(const std::string& filename, gio::GenericIO::FileIO method=gio::GenericIO::FileIOPOSIX, gio::GenericIO::MismatchBehavior redistribute=gio::GenericIO::MismatchRedistribute)
PyGenericIO(
const std::string& filename,
gio::GenericIO::FileIO method=gio::GenericIO::FileIOPOSIX,
gio::GenericIO::MismatchBehavior redistribute=gio::GenericIO::MismatchRedistribute,
int eff_rank = -1)
#ifdef GENERICIO_NO_MPI
: gio::GenericIO(filename, method), num_ranks(0) {
#else
: gio::GenericIO(MPI_COMM_WORLD, filename, method), num_ranks(0) {
#endif
// open headers and rank info
openAndReadHeader(redistribute);
openAndReadHeader(redistribute, eff_rank);
num_ranks = readNRanks();
// read variable info
getVariableInfo(variables);
......@@ -67,8 +71,15 @@ public:
bool collective_stats=true,
int eff_rank=-1
) {
int rank;
#ifdef GENERICIO_NO_MPI
rank = 0;
#else
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
#endif
// read number of elements
int64_t num_elem = readNumElems();
int64_t num_elem = readNumElems(eff_rank);
// if no argument, read all
if(!var_names.has_value()) {
......@@ -108,8 +119,10 @@ public:
}
}
}
readData(eff_rank, print_stats, collective_stats);
clearVariables();
#ifndef GENERICIO_NO_MPI
MPI_Barrier(MPI_COMM_WORLD);
#endif
......@@ -136,6 +149,12 @@ public:
return scale;
}
std::vector<int> get_source_ranks() {
std::vector<int> sr;
getSourceRanks(sr);
return sr;
}
private:
int num_ranks;
std::vector<gio::GenericIO::VariableInfo> variables;
......@@ -151,7 +170,7 @@ std::map<std::string, py::array> read_genericio(
bool rebalance_source_ranks=false,
int eff_rank=-1
) {
PyGenericIO reader(filename, method, redistribute);
PyGenericIO reader(filename, method, redistribute, eff_rank);
#ifndef GENERICIO_NO_MPI
if(rebalance_source_ranks)
reader.rebalanceSourceRanks();
......@@ -259,12 +278,12 @@ PYBIND11_MODULE(pygio, m) {
.def("read_phys_origin", &PyGenericIO::read_phys_origin)
.def("read_phys_scale", &PyGenericIO::read_phys_scale)
.def("read", &PyGenericIO::read,
py::arg("variables")=nullptr,
py::kw_only(),
py::arg("variables")=nullptr,
py::arg("print_stats")=true,
py::arg("collective_stats")=true,
py::arg("eff_rank")=-1)
.def("get_source_ranks", &PyGenericIO::getSourceRanks)
.def("get_source_ranks", &PyGenericIO::get_source_ranks)
#ifndef GENERICIO_NO_MPI
.def("rebalance_source_ranks", &PyGenericIO::rebalanceSourceRanks)
#endif
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment