diff --git a/src/pecanpy/cli.py b/src/pecanpy/cli.py index 577e59a2..33a1cba2 100755 --- a/src/pecanpy/cli.py +++ b/src/pecanpy/cli.py @@ -156,7 +156,7 @@ def parse_args(): parser.add_argument( "--random_state", type=int, - default=None, + default=42, help="Random seed for generating random walks.", ) diff --git a/src/pecanpy/pecanpy.py b/src/pecanpy/pecanpy.py index 923562d0..98a419e5 100755 --- a/src/pecanpy/pecanpy.py +++ b/src/pecanpy/pecanpy.py @@ -3,6 +3,7 @@ from gensim.models import Word2Vec from numba import njit from numba import prange +from numba import set_num_threads from numba_progress import ProgressBar from .graph import BaseGraph @@ -18,11 +19,6 @@ from .typing import Uint64Array from .wrappers import Timer -try: - from numba.np.ufunc.parallel import get_thread_id -except ImportError: # numba<0.56 - from numba.np.ufunc.parallel import _get_thread_id as get_thread_id - class Base(BaseGraph): """Base node2vec object. @@ -88,12 +84,13 @@ def __init__( verbose: bool = False, extend: bool = False, gamma: float = 0, - random_state: Optional[int] = None, + random_state: int = None, ): super().__init__() self.p = p self.q = q - self.workers = workers # TODO: not doing anything, need to fix. + self.workers = workers + set_num_threads(workers) self.verbose = verbose self.extend = extend self.gamma = gamma @@ -144,12 +141,15 @@ def simulate_walks( has_nbrs = self.get_has_nbrs() verbose = self.verbose + # Create list of seeds + random_states = self._get_random_seeds(random_state, tot_num_jobs) + # Acquire numba progress proxy for displaying the progress bar with ProgressBar(total=tot_num_jobs, disable=not verbose) as progress: walk_idx_mat = self._random_walks( tot_num_jobs, walk_length, - random_state, + random_states, start_node_idx_ary, has_nbrs, move_forward, @@ -161,22 +161,24 @@ def simulate_walks( return walks + @staticmethod + def _get_random_seeds(base_seed: int, num_jobs: int) -> np.ndarray: + """Get random number generators for each thread.""" + rng = np.random.default_rng(base_seed) + return rng.integers(0, 2**31 - 1, size=num_jobs, dtype=np.int32) + @staticmethod @njit(parallel=True, nogil=True) def _random_walks( tot_num_jobs: int, walk_length: int, - random_state: Optional[int], + random_states: Optional[np.ndarray], start_node_idx_ary: Uint32Array, has_nbrs: HasNbrs, move_forward: MoveForward, progress_proxy: ProgressBar, ) -> Uint32Array: """Simulate a random walk starting from start node.""" - # Seed the random number generator - if random_state is not None: - np.random.seed(random_state + get_thread_id()) - # use the last entry of each walk index array to keep track of the # effective walk length walk_idx_mat = np.zeros((tot_num_jobs, walk_length + 2), dtype=np.uint32) @@ -184,6 +186,7 @@ def _random_walks( walk_idx_mat[:, -1] = walk_length + 1 # set to full walk length by default for i in prange(tot_num_jobs): + np.random.seed(random_states[i]) # initialize first step as normal random walk start_node_idx = walk_idx_mat[i, 0] if has_nbrs(start_node_idx):