# Source code for cube_solver.solver.utils

"""Solver utils module."""
from __future__ import annotations

import numpy as np
from pathlib import Path
from collections import deque
from dataclasses import asdict
from typing_extensions import TYPE_CHECKING
from typing import Union, Tuple, Sequence, Dict, Callable

from ..logger import logger
from ..defs import CoordsType, NEXT_MOVES
from ..cube.enums import Move
from ..cube.cube import Cube, apply_move
from .defs import NONE, FlattenCoords, TransitionDef, PruningDef, TableDef

if TYPE_CHECKING:
    from .solver import BaseSolver


def flatten(coords: CoordsType) -> FlattenCoords:
    """
    Get the flattened cube coordinates.

    Integer coordinates are kept as single entries, and tuple coordinates are
    expanded in place, preserving the original left-to-right order.

    Parameters
    ----------
    coords : tuple of (int or tuple of int)
        Cube coordinates.

    Returns
    -------
    flatten_coords : tuple of int
        Flattened cube coordinates.

    Examples
    --------
    >>> from cube_solver import Cube
    >>> from cube_solver.solver import utils
    >>> cube = Cube()
    >>> coords = cube.get_coords(partial_corner_perm=True, partial_edge_perm=True)
    >>> coords
    (0, 0, (0, 1656), (0, 11856, 1656))
    >>> utils.flatten(coords)
    (0, 0, 0, 1656, 0, 11856, 1656)
    """
    # ``sum`` with an empty-tuple start concatenates the pieces left to right;
    # a bare int is wrapped in a 1-tuple so it can be concatenated.
    return sum(((coord,) if isinstance(coord, int) else coord for coord in coords), ())
def select(coords: FlattenCoords, indexes: Union[int, Tuple[int, ...], None]) -> FlattenCoords:
    """
    Select coordinates.

    Parameters
    ----------
    coords : tuple of int
        Coordinates.
    indexes : int or tuple of int or None
        Index or indexes of the coordinates to select.
        If ``None``, all coordinates are selected.

    Returns
    -------
    selected_coords : tuple of int
        Selected coordinates. A single ``int`` index still yields a 1-tuple.

    Examples
    --------
    >>> from cube_solver import Cube
    >>> from cube_solver.solver import utils
    >>> cube = Cube()
    >>> coords = cube.get_coords(partial_corner_perm=True, partial_edge_perm=True)
    >>> flatten_coords = utils.flatten(coords)
    >>> utils.select(flatten_coords, None)
    (0, 0, 0, 1656, 0, 11856, 1656)
    >>> utils.select(flatten_coords, 5)
    (11856,)
    >>> utils.select(flatten_coords, (3, 5, 6))
    (1656, 11856, 1656)
    """
    if indexes is None:
        # No selection requested: hand back the very same tuple.
        return coords
    # Normalise a single index into a 1-tuple so both cases share one path.
    positions = (indexes,) if isinstance(indexes, int) else indexes
    return tuple(coords[position] for position in positions)
def load_tables(path: Union[str, Path]) -> Dict[str, np.ndarray]:
    """
    Load the tables from a file.

    The file is opened with ``numpy.load(..., allow_pickle=False)``, so pickled
    object arrays are refused. Every array is copied into a plain ``dict``
    before the archive is closed.

    Parameters
    ----------
    path : str or Path
        Path of the file.

    Returns
    -------
    tables : dict
        Dictionary containing the tables, keyed by their name in the archive.

    Raises
    ------
    TypeError
        If ``path`` is neither ``str`` nor ``Path``.
    """
    if not isinstance(path, (str, Path)):
        raise TypeError(f"path must be str or Path, not {type(path).__name__}")
    with np.load(Path(path), allow_pickle=False) as archive:
        return dict(archive)
def save_tables(path: Union[str, Path], tables: Dict[str, np.ndarray]) -> None:
    """
    Save the tables into a single file.

    The tables are written with :func:`numpy.savez` as an uncompressed ``.npz``
    archive, one array per dictionary key. Missing parent directories of
    ``path`` are created.

    Parameters
    ----------
    path : str or Path
        Path of the file.
    tables : dict
        Dictionary containing the tables.

    Raises
    ------
    TypeError
        If ``path`` is neither ``str`` nor ``Path``, or ``tables`` is not a ``dict``.
    """
    if not isinstance(path, (str, Path)):
        raise TypeError(f"path must be str or Path, not {type(path).__name__}")
    if not isinstance(tables, dict):
        raise TypeError(f"tables must be dict, not {type(tables).__name__}")
    path = Path(path)
    # parents=True so nested destinations (e.g. ``a/b/file.npz``) do not fail
    # with FileNotFoundError when an intermediate directory is missing.
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("wb") as file:
        np.savez(file, **tables)  # type: ignore
def get_tables(filename: str, tables_defs: Sequence[TableDef], generate_table_fn: Callable[..., np.ndarray],
               accumulate: bool = False) -> Dict[str, np.ndarray]:
    """
    Create or load tables from the ``tables/`` directory according to the ``tables_defs``.

    If the file does not exist, create every table from ``tables_defs`` and save the file.
    If it exists but is missing some tables, create only the missing tables and update the file.
    Unless ``accumulate`` is ``True``, tables stored in the file whose names do not appear
    in ``tables_defs`` are dropped and the file is updated.

    Parameters
    ----------
    filename : str
        Name of the file in the ``tables/`` directory (a relative path, so it is
        resolved against the current working directory).
    tables_defs : list of TableDef
        Table definitions.
    generate_table_fn : Callable
        Function to generate the table. It must accept the TableDef keyword arguments.
    accumulate : bool, optional
        Whether to keep the tables not included in ``tables_defs``. Default is ``False``.

    Returns
    -------
    tables : dict
        Dictionary containing the tables. The keys represent the name of the table from
        :attr:`TableDef.name`.

    Raises
    ------
    TypeError
        If an argument, or an element of ``tables_defs``, has the wrong type.
    """
    if not isinstance(filename, str):
        raise TypeError(f"filename must be str, not {type(filename).__name__}")
    if not isinstance(tables_defs, list):
        raise TypeError(f"tables_defs must be list, not {type(tables_defs).__name__}")
    if not callable(generate_table_fn):
        raise TypeError(f"generate_table_fn must be Callable, not {type(generate_table_fn).__name__}")
    if not isinstance(accumulate, bool):
        raise TypeError(f"accumulate must be bool, not {type(accumulate).__name__}")
    for kwargs in tables_defs:
        if not isinstance(kwargs, (TransitionDef, PruningDef)):
            raise TypeError(f"tables_defs elements must be TableDef, not {type(kwargs).__name__}")

    path = Path(f"tables/{filename}")
    # Only the load belongs in the ``try``: a FileNotFoundError raised while generating
    # or saving tables must propagate instead of being mistaken for a missing file
    # (which would regenerate everything and overwrite the existing file).
    try:
        tables = load_tables(path)
    except FileNotFoundError:
        logger.info(f"Creating {path}")
        tables = {kwargs.name: generate_table_fn(**asdict(kwargs)) for kwargs in tables_defs}
        save_tables(path, tables)
        return tables

    missing = [kwargs for kwargs in tables_defs if kwargs.name not in tables]
    stale = set() if accumulate else tables.keys() - {kwargs.name for kwargs in tables_defs}
    if missing or stale:
        logger.info(f"Updating {path}")
        for kwargs in missing:
            # Re-check so a name repeated in ``tables_defs`` is generated only once.
            if kwargs.name not in tables:
                tables[kwargs.name] = generate_table_fn(**asdict(kwargs))
        for name in stale:
            del tables[name]
        save_tables(path, tables)
    return tables
def generate_transition_table(coord_name: str, coord_size: int) -> np.ndarray:
    """
    Generate the cube coordinate transition table.

    Row ``i`` holds, for each move in ``NEXT_MOVES[Move.NONE]`` (in that order),
    the value of ``coord_name`` after applying the move to a cube whose
    ``coord_name`` coordinate has been set to ``i``.

    Parameters
    ----------
    coord_name : str
        Cube coordinate name.
    coord_size : int
        Cube coordinate size. Must fit in the ``uint16`` table entries.

    Returns
    -------
    transition_table : ndarray
        Cube coordinate transition table of shape
        ``(coord_size, len(NEXT_MOVES[Move.NONE]))`` and dtype ``uint16``.

    Raises
    ------
    TypeError
        If ``coord_name`` is not ``str`` or ``coord_size`` is not ``int``.
    ValueError
        If ``coord_size`` is not in ``[1, 65536]``.
    """
    if not isinstance(coord_name, str):
        raise TypeError(f"coord_name must be str, not {type(coord_name).__name__}")
    if not isinstance(coord_size, int):
        raise TypeError(f"coord_size must be int, not {type(coord_size).__name__}")
    # Largest number of distinct values a uint16 entry can index.
    limit = np.iinfo(np.uint16).max + 1
    if not 0 < coord_size <= limit:
        raise ValueError(f"coord_size must be > 0 and <= {limit} (got {coord_size})")

    moves = NEXT_MOVES[Move.NONE]
    table = np.zeros((coord_size, len(moves)), dtype=np.uint16)
    probe = Cube()
    for value in range(coord_size):
        probe.set_coord(coord_name, value)
        # Assign the whole row at once from the per-move results.
        table[value] = [apply_move(probe, move).get_coord(coord_name) for move in moves]
    return table
def generate_pruning_table(solver: BaseSolver, phase: int, shape: Union[int, Tuple[int, ...]],
                           indexes: Union[int, Tuple[int, ...], None], **kwargs) -> np.ndarray:
    """
    Generate the phase coordinates pruning table.

    A breadth-first search starts from the coordinates of a fresh ``Cube()`` and
    expands positions with ``solver.phase_moves[phase]``. Each position is projected
    to a table index via ``solver.phase_coords`` followed by :func:`select` with
    ``indexes``. Each table entry stores the BFS depth at which its index was
    first reached. Entries never reached keep the value ``NONE``.

    Parameters
    ----------
    solver : BaseSolver
        Solver object.
    phase : int
        Solver phase (0-indexed).
    shape : int or tuple of int
        Shape of the pruning table.
    indexes : int or tuple of int or None
        Index or indexes of the phase coordinates to use for the pruning table.
        If ``None``, use all the phase coordinates.
    **kwargs
        Additional keyword arguments are accepted and ignored.

    Returns
    -------
    pruning_table : ndarray
        Phase coordinates pruning table (dtype ``int8``).

    Raises
    ------
    TypeError
        If ``phase``, ``shape``, ``indexes`` or their elements have the wrong type.
    ValueError
        If ``phase`` is outside ``[0, solver.num_phases)``.
    """
    if not isinstance(phase, int):
        raise TypeError(f"phase must be int, not {type(phase).__name__}")
    if not isinstance(shape, (int, tuple)):
        raise TypeError(f"shape must be int or tuple, not {type(shape).__name__}")
    if indexes is not None and not isinstance(indexes, (int, tuple)):
        raise TypeError(f"indexes must be int or tuple or None, not {type(indexes).__name__}")
    if phase < 0 or phase >= solver.num_phases:
        raise ValueError(f"phase must be >= 0 and < {solver.num_phases} (got {phase})")
    for size in (shape if isinstance(shape, tuple) else ()):
        if not isinstance(size, int):
            raise TypeError(f"shape elements must be int, not {type(size).__name__}")
    for index in (indexes if isinstance(indexes, tuple) else ()):
        if not isinstance(index, int):
            raise TypeError(f"indexes elements must be int, not {type(index).__name__}")

    def table_key(position):
        # Project the full solver coordinates onto the pruning-table index.
        return select(solver.phase_coords(flatten(position), phase), indexes)

    table = np.full(shape, NONE, dtype=np.int8)
    start = solver.get_coords(Cube())
    table[table_key(start)] = 0

    frontier = deque([(start, 0)])
    while frontier:
        position, depth = frontier.popleft()
        child_depth = depth + 1
        for move in solver.phase_moves[phase]:
            child = solver.next_position(position, move)
            key = table_key(child)
            if table[key] != NONE:
                # Index already reached at an equal or shallower depth.
                continue
            table[key] = child_depth
            frontier.append((child, child_depth))
    return table