Skip to content
2 changes: 1 addition & 1 deletion src/pecanpy/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def parse_args():
parser.add_argument(
"--random_state",
type=int,
default=None,
default=42,
help="Random seed for generating random walks.",
)

Expand Down
29 changes: 16 additions & 13 deletions src/pecanpy/pecanpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from gensim.models import Word2Vec
from numba import njit
from numba import prange
from numba import set_num_threads
from numba_progress import ProgressBar

from .graph import BaseGraph
Expand All @@ -18,11 +19,6 @@
from .typing import Uint64Array
from .wrappers import Timer

try:
from numba.np.ufunc.parallel import get_thread_id
except ImportError: # numba<0.56
from numba.np.ufunc.parallel import _get_thread_id as get_thread_id


class Base(BaseGraph):
"""Base node2vec object.
Expand Down Expand Up @@ -88,12 +84,13 @@ def __init__(
verbose: bool = False,
extend: bool = False,
gamma: float = 0,
random_state: Optional[int] = None,
random_state: int = None,
):
super().__init__()
self.p = p
self.q = q
self.workers = workers # TODO: not doing anything, need to fix.
self.workers = workers
set_num_threads(workers)
self.verbose = verbose
self.extend = extend
self.gamma = gamma
Expand Down Expand Up @@ -144,12 +141,15 @@ def simulate_walks(
has_nbrs = self.get_has_nbrs()
verbose = self.verbose

# Create list of seeds
random_states = self._get_random_seeds(random_state, tot_num_jobs)

# Acquire numba progress proxy for displaying the progress bar
with ProgressBar(total=tot_num_jobs, disable=not verbose) as progress:
walk_idx_mat = self._random_walks(
tot_num_jobs,
walk_length,
random_state,
random_states,
start_node_idx_ary,
has_nbrs,
move_forward,
Expand All @@ -161,29 +161,32 @@ def simulate_walks(

return walks

@staticmethod
def _get_random_seeds(base_seed: int, num_jobs: int) -> np.ndarray:
"""Get random number generators for each thread."""
rng = np.random.default_rng(base_seed)
return rng.integers(0, 2**31 - 1, size=num_jobs, dtype=np.int32)

@staticmethod
@njit(parallel=True, nogil=True)
def _random_walks(
tot_num_jobs: int,
walk_length: int,
random_state: Optional[int],
random_states: Optional[np.ndarray],
start_node_idx_ary: Uint32Array,
has_nbrs: HasNbrs,
move_forward: MoveForward,
progress_proxy: ProgressBar,
) -> Uint32Array:
"""Simulate a random walk starting from start node."""
# Seed the random number generator
if random_state is not None:
np.random.seed(random_state + get_thread_id())

# use the last entry of each walk index array to keep track of the
# effective walk length
walk_idx_mat = np.zeros((tot_num_jobs, walk_length + 2), dtype=np.uint32)
walk_idx_mat[:, 0] = start_node_idx_ary # initialize seeds
walk_idx_mat[:, -1] = walk_length + 1 # set to full walk length by default

for i in prange(tot_num_jobs):
np.random.seed(random_states[i])
# initialize first step as normal random walk
start_node_idx = walk_idx_mat[i, 0]
if has_nbrs(start_node_idx):
Expand Down
Loading