Commit 8fea9a24 authored by Michael Buehlmann's avatar Michael Buehlmann
Browse files

add more class methods to py interface, whitespace

parent 44c831c3
......@@ -18,8 +18,8 @@ namespace py = pybind11;
class PyGenericIO : public gio::GenericIO {
public:
PyGenericIO(
const std::string& filename,
gio::GenericIO::FileIO method=gio::GenericIO::FileIOPOSIX,
const std::string& filename,
gio::GenericIO::FileIO method=gio::GenericIO::FileIOPOSIX,
gio::GenericIO::MismatchBehavior redistribute=gio::GenericIO::MismatchRedistribute,
int eff_rank = -1)
#ifdef GENERICIO_NO_MPI
......@@ -66,8 +66,8 @@ public:
}
std::map<std::string, py::array> read(
std::optional<std::vector<std::string>> var_names,
bool print_stats=true,
std::optional<std::vector<std::string>> var_names,
bool print_stats=true,
bool collective_stats=true,
int eff_rank=-1
) {
......@@ -80,7 +80,7 @@ public:
// read number of elements
int64_t num_elem = readNumElems(eff_rank);
// if no argument, read all
if(!var_names.has_value()) {
var_names.emplace(std::vector<std::string>());
......@@ -94,8 +94,8 @@ public:
for(const std::string& var_name: *var_names) {
auto varp = std::find_if(
variables.begin(),
variables.end(),
variables.begin(),
variables.end(),
[&var_name](const auto& v){ return v.Name == var_name; }
);
if (varp != variables.end()) {
......@@ -122,7 +122,7 @@ public:
readData(eff_rank, print_stats, collective_stats);
clearVariables();
#ifndef GENERICIO_NO_MPI
MPI_Barrier(MPI_COMM_WORLD);
#endif
......@@ -161,10 +161,10 @@ private:
};
std::map<std::string, py::array> read_genericio(
std::string filename,
std::optional<std::vector<std::string>> var_names,
PyGenericIO::FileIO method=PyGenericIO::FileIO::FileIOPOSIX,
PyGenericIO::MismatchBehavior redistribute=PyGenericIO::MismatchBehavior::MismatchRedistribute,
std::string filename,
std::optional<std::vector<std::string>> var_names,
PyGenericIO::FileIO method=PyGenericIO::FileIO::FileIOPOSIX,
PyGenericIO::MismatchBehavior redistribute=PyGenericIO::MismatchBehavior::MismatchRedistribute,
bool print_stats=true,
bool collective_stats=true,
bool rebalance_source_ranks=false,
......@@ -179,8 +179,8 @@ std::map<std::string, py::array> read_genericio(
}
void inspect_genericio(
std::string filename,
PyGenericIO::FileIO method=PyGenericIO::FileIO::FileIOPOSIX,
std::string filename,
PyGenericIO::FileIO method=PyGenericIO::FileIO::FileIOPOSIX,
PyGenericIO::MismatchBehavior redistribute=PyGenericIO::MismatchBehavior::MismatchRedistribute
) {
PyGenericIO reader(filename, method, redistribute);
......@@ -189,9 +189,9 @@ void inspect_genericio(
#ifndef GENERICIO_NO_MPI
void write_genericio(
std::string filename,
std::map<std::string, py::array> variables,
std::array<double, 3> phys_scale, std::array<double, 3> phys_origin,
std::string filename,
std::map<std::string, py::array> variables,
std::array<double, 3> phys_scale, std::array<double, 3> phys_origin,
PyGenericIO::FileIO method=PyGenericIO::FileIO::FileIOPOSIX
) {
// check data integrity, find particle count
......@@ -218,15 +218,15 @@ void write_genericio(
}
for(auto& [name, data]: variables) {
if(py::isinstance<py::array_t<float>>(data))
if(py::isinstance<py::array_t<float>>(data))
writer.addVariable(name.c_str(), reinterpret_cast<float*>(data.mutable_data()));
else if(py::isinstance<py::array_t<double>>(data))
else if(py::isinstance<py::array_t<double>>(data))
writer.addVariable(name.c_str(), reinterpret_cast<double*>(data.mutable_data()));
else if(py::isinstance<py::array_t<int32_t>>(data))
else if(py::isinstance<py::array_t<int32_t>>(data))
writer.addVariable(name.c_str(), reinterpret_cast<int32_t*>(data.mutable_data()));
else if(py::isinstance<py::array_t<int64_t>>(data))
else if(py::isinstance<py::array_t<int64_t>>(data))
writer.addVariable(name.c_str(), reinterpret_cast<int64_t*>(data.mutable_data()));
else if(py::isinstance<py::array_t<uint16_t>>(data))
else if(py::isinstance<py::array_t<uint16_t>>(data))
writer.addVariable(name.c_str(), reinterpret_cast<uint16_t*>(data.mutable_data()));
else
throw std::runtime_error("array dtype not supported for " + name);
......@@ -249,7 +249,7 @@ PYBIND11_MODULE(pygio, m) {
MPI_Initialized(&initialized);
if(!initialized) {
int level_provided;
MPI_Init_thread(nullptr, nullptr, MPI_THREAD_SINGLE, &level_provided);
MPI_Init_thread(nullptr, nullptr, MPI_THREAD_SINGLE, &level_provided);
}
});
#endif
......@@ -267,9 +267,9 @@ PYBIND11_MODULE(pygio, m) {
.value("MismatchRedistribute", PyGenericIO::MismatchBehavior::MismatchRedistribute);
pyGenericIO.def(
py::init<std::string, PyGenericIO::FileIO, PyGenericIO::MismatchBehavior>(),
py::arg("filename"),
py::arg("method")=PyGenericIO::FileIO::FileIOPOSIX,
py::init<std::string, PyGenericIO::FileIO, PyGenericIO::MismatchBehavior>(),
py::arg("filename"),
py::arg("method")=PyGenericIO::FileIO::FileIOPOSIX,
py::arg("redistribute")=PyGenericIO::MismatchBehavior::MismatchRedistribute)
.def("inspect", &PyGenericIO::inspect, "Print variable infos and size of GenericIO file")
.def("get_variables", &PyGenericIO::get_variables, "Get a list of VariableInformations defined in the GenericIO file")
......@@ -280,10 +280,12 @@ PYBIND11_MODULE(pygio, m) {
.def("read", &PyGenericIO::read,
py::arg("variables")=nullptr,
py::kw_only(),
py::arg("print_stats")=true,
py::arg("print_stats")=true,
py::arg("collective_stats")=true,
py::arg("eff_rank")=-1)
.def("get_source_ranks", &PyGenericIO::get_source_ranks)
.def("readGlobalRankNumber", &PyGenericIO::readGlobalRankNumber)
.def("readNRanks", &PyGenericIO::readNRanks)
#ifndef GENERICIO_NO_MPI
.def("rebalance_source_ranks", &PyGenericIO::rebalanceSourceRanks)
#endif
......@@ -299,30 +301,30 @@ PYBIND11_MODULE(pygio, m) {
(vi.IsFloat ? "float" : "int") + " name='" + vi.Name + "'>";
});
m.def("read_genericio", &read_genericio,
py::arg("filename"),
py::arg("variables")=nullptr,
m.def("read_genericio", &read_genericio,
py::arg("filename"),
py::arg("variables")=nullptr,
py::kw_only(),
py::arg("method")=PyGenericIO::FileIO::FileIOPOSIX,
py::arg("redistribute")=PyGenericIO::MismatchBehavior::MismatchRedistribute,
py::arg("method")=PyGenericIO::FileIO::FileIOPOSIX,
py::arg("redistribute")=PyGenericIO::MismatchBehavior::MismatchRedistribute,
py::arg("print_stats")=true,
py::arg("collective_stats")=true,
py::arg("rebalance_sourceranks")=false,
py::arg("eff_rank")=-1,
py::return_value_policy::move);
m.def("inspect_genericio", &inspect_genericio,
py::arg("filename"),
m.def("inspect_genericio", &inspect_genericio,
py::arg("filename"),
py::kw_only(),
py::arg("method")=PyGenericIO::FileIO::FileIOPOSIX,
py::arg("method")=PyGenericIO::FileIO::FileIOPOSIX,
py::arg("redistribute")=PyGenericIO::MismatchBehavior::MismatchRedistribute);
#ifndef GENERICIO_NO_MPI
m.def("write_genericio", &write_genericio,
py::arg("filename"),
py::arg("variables"),
py::arg("phys_scale"),
py::arg("phys_origin") = std::array<double, 3>({0., 0., 0.}),
m.def("write_genericio", &write_genericio,
py::arg("filename"),
py::arg("variables"),
py::arg("phys_scale"),
py::arg("phys_origin") = std::array<double, 3>({0., 0., 0.}),
py::kw_only(),
py::arg("method")=PyGenericIO::FileIO::FileIOPOSIX);
#endif
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment