Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 24 additions & 8 deletions autotest/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import shutil
from pathlib import Path
from pprint import pformat

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -38,6 +39,7 @@
)
from flopy.modflow import Modflow, ModflowDis
from flopy.modpath import Modpath6, Modpath6Bas
from flopy.plot.plotutil import to_prt_pathlines
from flopy.utils import (
CellBudgetFile,
HeadFile,
Expand Down Expand Up @@ -1508,8 +1510,6 @@ def test_vtk_unstructured(function_tmpdir, unstructured_grid):

@requires_pkg("vtk", "pyvista")
def test_vtk_to_pyvista(function_tmpdir):
from pprint import pformat

from autotest.test_mp7_cases import Mp7Cases

case_mf6 = Mp7Cases.mp7_mf6(function_tmpdir)
Expand All @@ -1529,13 +1529,29 @@ def test_vtk_to_pyvista(function_tmpdir):
assert grid.n_cells == gwf.modelgrid.nnodes

vtk.add_pathline_points(pls)
grid, pathlines = vtk.to_pyvista()
grid, mp7_pls = vtk.to_pyvista()
n_pts = sum(pl.shape[0] for pl in pls)
assert mp7_pls.n_points == n_pts
assert mp7_pls.n_cells == n_pts + len(pls)
assert "particleid" in mp7_pls.point_data
assert "time" in mp7_pls.point_data
assert "k" in mp7_pls.point_data

vtk = Vtk(model=gwf, binary=True, smooth=False)
assert not any(vtk.to_pyvista())

prt_pathlines = to_prt_pathlines(np.hstack(pls).view(np.recarray))

vtk.add_model(gwf)
vtk.add_pathline_points(prt_pathlines)
grid, prt_pls = vtk.to_pyvista()
n_pts = sum(pl.shape[0] for pl in pls)
assert pathlines.n_points == n_pts
assert pathlines.n_cells == n_pts + len(pls)
assert "particleid" in pathlines.point_data
assert "time" in pathlines.point_data
assert "k" in pathlines.point_data
assert prt_pls.n_points == n_pts
assert prt_pls.n_cells == n_pts + len(pls)
assert "imdl" in prt_pls.point_data
assert "iprp" in prt_pls.point_data
assert "irpt" in prt_pls.point_data
assert "trelease" in prt_pls.point_data

# uncomment to debug
# grid.plot()
Expand Down
22 changes: 11 additions & 11 deletions flopy/export/vtk.py
Original file line number Diff line number Diff line change
Expand Up @@ -1109,17 +1109,17 @@ def add_pathline_points(self, pathlines, timeseries=False):
pids = np.unique(pathlines.particleid)
pathlines = [pathlines[pathlines.particleid == pid] for pid in pids]
elif all(k in pathlines.dtype.names for k in prt_fields):
pls = []
for imdl in np.unique(pathlines.imdl):
for iprp in np.unique(pathlines.iprp):
for irpt in np.unique(pathlines.irpt):
pl = pathlines[
(pathlines.imdl == imdl)
& (pathlines.iprp == iprp)
& (pathlines.irpt == irpt)
]
pls.extend([pl[pl.trelease == t] for t in np.unique(pl.t)])
pathlines = pls
# particle composite key
keys = np.column_stack(
[
pathlines["imdl"],
pathlines["iprp"],
pathlines["irpt"],
pathlines["trelease"],
]
)
_, inv = np.unique(keys, axis=0, return_inverse=True)
pathlines = [pathlines[inv == i] for i in range(inv.max() + 1)]
else:
raise ValueError("Unrecognized pathline dtype")
else:
Expand Down
Loading