diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index 45469ce..d423385 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -1,13 +1,13 @@ -# This workflow will install Python dependencies, run tests and lint with a single version of Python +# This workflow installs Python dependencies, lints, and runs the test suite. # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions name: Python application on: push: - branches: [ main ] + branches: [ main, rewrite ] pull_request: - branches: [ main ] + branches: [ main, rewrite ] permissions: contents: read @@ -18,19 +18,26 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - name: Set up Python 3.10 - uses: actions/setup-python@v3 + - uses: actions/checkout@v4 + - name: Set up Python 3.12 + uses: actions/setup-python@v5 with: - python-version: "3.10" + python-version: "3.12" - name: Install dependencies run: | - python -m pip install --upgrade pip - pip install flake8 pytest poetry + python -m pip install --upgrade pip poetry + # `poetry install` picks up dev-dependencies (pytest, flake8) from + # the [tool.poetry.dev-dependencies] table, so we don't need to + # pip-install them separately and risk version skew. poetry install - name: Lint with flake8 run: | # stop the build if there are Python syntax errors or undefined names - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + poetry run flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude=.venv # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + poetry run flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --exclude=.venv + - name: Run tests + # Default pytest config skips the `integration` marker (CLIP download). 
+ # Fast tests only here — they're enough to catch dead imports and + # regressions in the loader / CLI wiring. + run: poetry run pytest -v diff --git a/.gitignore b/.gitignore index 00d1fa1..8ec9ef4 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,11 @@ tags *.ann *.pt +# local-only workspace state +.claude/ +.venv/ +.pytest_cache/ + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/docs/testimage.png b/docs/testimage.png deleted file mode 100644 index c0351ce..0000000 Binary files a/docs/testimage.png and /dev/null differ diff --git a/graphs/embed_d.jpg b/graphs/embed_d.jpg deleted file mode 100644 index 068cdcf..0000000 Binary files a/graphs/embed_d.jpg and /dev/null differ diff --git a/graphs/embed_n.jpg b/graphs/embed_n.jpg deleted file mode 100644 index 206031d..0000000 Binary files a/graphs/embed_n.jpg and /dev/null differ diff --git a/graphs/mde_d.gif b/graphs/mde_d.gif deleted file mode 100644 index 66439f8..0000000 Binary files a/graphs/mde_d.gif and /dev/null differ diff --git a/graphs/mde_n.gif b/graphs/mde_n.gif deleted file mode 100644 index 6a79d9e..0000000 Binary files a/graphs/mde_n.gif and /dev/null differ diff --git a/graphs/normalized-d.jpg b/graphs/normalized-d.jpg deleted file mode 100644 index 3846c6a..0000000 Binary files a/graphs/normalized-d.jpg and /dev/null differ diff --git a/graphs/normalized-lg.jpg b/graphs/normalized-lg.jpg deleted file mode 100644 index 519b1b8..0000000 Binary files a/graphs/normalized-lg.jpg and /dev/null differ diff --git a/graphs/normalized.jpg b/graphs/normalized.jpg deleted file mode 100644 index a2d34db..0000000 Binary files a/graphs/normalized.jpg and /dev/null differ diff --git a/graphs/plotted_ims.jpg b/graphs/plotted_ims.jpg deleted file mode 100644 index 0cd6d53..0000000 Binary files a/graphs/plotted_ims.jpg and /dev/null differ diff --git a/memery/cli.py b/memery/cli.py index 2d6ed4e..caae7eb 100644 --- a/memery/cli.py +++ b/memery/cli.py @@ -1,8 +1,11 @@ +import sys +from 
pathlib import Path +from typing import Optional + import typer -from memery.core import Memery + import memery -import streamlit.cli -from typing import Optional +from memery.core import Memery # Sometimes you just want to be able to pipe information through the terminal. This is that command app = typer.Typer() @@ -13,28 +16,41 @@ def main(): @app.command() def recall( root: str = typer.Argument('.', help="Image folder to search"), - text: str = typer.Option(None, *("-t", "--text"), help="Text query"), - image: str = typer.Option(None, *("-i", "--image"), help="Filepath to image query") , - number: int = typer.Option(10, *("-n", "--number"), help="Number of results to return") - ) -> list[str]: - """Search recursively over a folder from the command line""" + text: str = typer.Option(None, "-t", "--text", help="Text query"), + negative: str = typer.Option(None, "-nt", "--negative-text", help="Negative text query"), + image: str = typer.Option(None, "-i", "--image", help="Filepath to image query"), + number: int = typer.Option(10, "-n", "--number", help="Number of results to return"), + ) -> None: + """Search recursively over a folder from the command line.""" memery = Memery() - ranked = memery.query_flow(root, query=text, image_query=image) + ranked = memery.query_flow(root, query=text, negative_query=negative, image_query=image) print(ranked[:number]) @app.command() def serve(root: Optional[str] = typer.Argument(None)): """Runs the streamlit GUI in your browser""" - app_path = memery.__file__.replace('__init__.py','streamlit_app.py') - if root is None: - streamlit.cli.main(['run', app_path, './images']) - else: - streamlit.cli.main(['run', app_path, f'{root}']) + # Importing here so `memery --help` doesn't pay the streamlit import cost + from streamlit.web import cli as stcli + + app_path = str(Path(memery.__file__).parent / "streamlit_app.py") + target_root = root if root is not None else "./images" + sys.argv = ["streamlit", "run", app_path, "--", target_root] 
+ sys.exit(stcli.main()) @app.command() def build( root: str = typer.Argument('.'), - workers: int = typer.Option(default=0) + workers: int = typer.Option( + 2, + "--workers", "-w", + help=( + "DataLoader workers for image preprocessing. Measured on macOS/MPS: " + "2 wins ~15% on realistic corpora (4000+ images). Higher counts hurt " + "because the GPU is a serial bottleneck and IPC overhead piles up. " + "Pass 0 to disable workers entirely (slightly faster on tiny folders, " + "or as a fallback if multiprocessing misbehaves in your environment)." + ), + ), ): ''' Indexes the directory and all subdirectories diff --git a/memery/core.py b/memery/core.py index 07e1b38..5304d50 100644 --- a/memery/core.py +++ b/memery/core.py @@ -59,7 +59,17 @@ def index_flow(self, root: str, num_workers=0) -> tuple[str, str]: # Crafting and encoding crafted_files = crafter.crafter(new_files, device, num_workers=num_workers) model = self.get_model() - new_embeddings = encoder.image_encoder(crafted_files, device, model) + new_embeddings, surviving = encoder.image_encoder(crafted_files, device, model) + + # `surviving` maps embedding rows back to indices into new_files; + # it can be shorter than new_files if any image failed to decode + # after passing verify_image (truncated files, weird color spaces). + # Filter new_files down to the same set so join_all can correlate + # by position without misalignment. 
+ if len(surviving) != len(new_files): + dropped = len(new_files) - len(surviving) + print(f"Skipped {dropped} image(s) that failed to decode") + new_files = [new_files[int(i)] for i in surviving.tolist()] # Reindexing db = indexer.join_all(archive_db, new_files, new_embeddings) diff --git a/memery/crafter.py b/memery/crafter.py index fa55275..0071af1 100644 --- a/memery/crafter.py +++ b/memery/crafter.py @@ -1,71 +1,121 @@ +from typing import Optional + import torch from torch import Tensor, device +from torch.utils.data import DataLoader, default_collate from torchvision.datasets import VisionDataset -from PIL import Image, ImageFile from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize -from torch.utils.data import DataLoader - - -def make_dataset(new_files: list[str]) -> tuple[list[str], list[str]]: - '''Returns a list of samples of a form (path_to_sample, class) and in - this case the class is just the filename''' - samples = [] - slugs = [] - for i, f in enumerate(new_files): - path, slug = f - samples.append((str(path), i)) - slugs.append((slug, i)) - return(samples, slugs) - -def pil_loader(path: str) -> Image.Image: - ImageFile.LOAD_TRUNCATED_IMAGES = True # Allow truncated images +from PIL import Image, ImageFile + + +def make_dataset(new_files: list[tuple[str, str]]) -> list[tuple[str, int]]: + '''Returns a list of (path, index) pairs. + + The previous implementation also built a parallel `slugs` list that + nothing downstream ever read — pure dead weight on every build. + ''' + return [(str(path), i) for i, (path, _slug) in enumerate(new_files)] + + +def pil_loader(path: str) -> Optional[Image.Image]: + """Open `path` and return an RGB PIL image, or None on failure. + + Uses PIL's JPEG `draft` mode to decode at the smallest IDCT scale that's + still ≥ 256px. 
For large JPEGs this skips most of the inverse-DCT work + and saves a substantial amount of file I/O — measured ~1.6x speedup on + a realistic Downloads sample (~165KB median image size). For non-JPEG + formats `draft` is a no-op, so it's safe to leave on unconditionally. + """ + ImageFile.LOAD_TRUNCATED_IMAGES = True # tolerate partially-truncated files try: - # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) - with open(path, 'rb') as f: + # open path as file to avoid ResourceWarning + # https://github.com/python-pillow/Pillow/issues/835 + with open(path, 'rb') as f: img = Image.open(f) + # Hint the JPEG decoder for partial-scale decode. 256 is chosen + # to stay safely above the 224 CLIP input size after CenterCrop. + try: + img.draft('RGB', (256, 256)) + except (AttributeError, OSError): + pass return img.convert('RGB') except Exception as e: print(f"Skipping image {path}: {e}") return None + class DatasetImagePaths(VisionDataset): - - def __init__(self, new_files, transforms = None): - super(DatasetImagePaths, self).__init__(new_files, transforms=transforms) - samples, slugs = make_dataset(new_files) - self.samples = samples - self.slugs = slugs + + def __init__(self, new_files, transforms=None): + super().__init__(new_files, transforms=transforms) + self.samples = make_dataset(new_files) self.loader = pil_loader self.root = 'file dataset' + def __len__(self): - return(len(self.samples)) + return len(self.samples) def __getitem__(self, index): path, target = self.samples[index] try: sample = self.loader(path) - if sample is not None: - if self.transforms is not None: - sample = self.transforms(sample) - return sample, target + if sample is None: + return None + if self.transforms is not None: + sample = self.transforms(sample) + return sample, target except Exception as e: print(f"Skipping file {path} due to error: {e}") return None + +def safe_collate(batch): + """Drop items that failed to decode, then run the 
default collate. + + Without this, a single `None` from `__getitem__` (e.g. a file that + passed `verify_image` but failed at decode time) would crash the entire + DataLoader with an unhelpful traceback. Returning `None` for an empty + batch lets the encoder loop skip it cleanly. + """ + batch = [b for b in batch if b is not None] + if not batch: + return None + return default_collate(batch) + + def clip_transform(n_px: int) -> Compose: return Compose([ Resize(n_px, interpolation=Image.BICUBIC), CenterCrop(n_px), ToTensor(), - Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + Normalize((0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711)), ]) -def crafter(new_files: list[str], device: device, batch_size: int=128, num_workers: int=4): + +def crafter(new_files: list[tuple[str, str]], device: device, + batch_size: int = 128, num_workers: int = 0) -> DataLoader: + """Build the DataLoader used to feed CLIP. + + `num_workers=0` is the macOS default because DataLoader workers use + `spawn` (fork is unsafe with PyTorch), each costs ~2s of startup, and + on typical-sized images (~150KB) the parallelism gain is smaller than + the spawn overhead. For folders dominated by very large images + (multi-megabyte phone photos / scans), `--workers 4-8` can win + significantly — measured at scale, not at small N. 
+ """ with torch.no_grad(): - imagefiles=DatasetImagePaths(new_files, clip_transform(224)) - img_loader=DataLoader(imagefiles, batch_size=batch_size, shuffle=False, num_workers=num_workers) - return(img_loader) + imagefiles = DatasetImagePaths(new_files, clip_transform(224)) + img_loader = DataLoader( + imagefiles, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + collate_fn=safe_collate, + ) + return img_loader + -def preproc(img: Tensor) -> Compose: - transformed = clip_transform(224)(img) - return(transformed) \ No newline at end of file +def preproc(img: Image.Image) -> Tensor: + """Apply the CLIP preprocessing pipeline to a single PIL image.""" + return clip_transform(224)(img) diff --git a/memery/encoder.py b/memery/encoder.py index 0989ac5..ac5a1ba 100644 --- a/memery/encoder.py +++ b/memery/encoder.py @@ -9,17 +9,57 @@ def load_model(device: device) -> CLIP: model, _ = clip.load("ViT-B/32", device, jit=False) model = model.float() - return(model) + # Inference-only — disables any train-time behavior in submodules. + # CLIP ViT-B/32 has no batchnorm or dropout that fires here, so this is + # mostly defensive correctness, but it's free. + model.eval() + return model def image_encoder(img_loader: DataLoader, device: device, model: CLIP): - image_embeddings = torch.tensor(()).to(device) + """Encode a DataLoader's worth of images to L2-normalized CLIP features. + + Returns ``(embeddings, labels)``: + * ``embeddings``: ``(K, 512)`` CPU tensor of unit-norm features, where + ``K`` is the number of images that *successfully* survived decoding + and collation. May be smaller than ``len(dataset)`` if any items + were dropped by ``crafter.safe_collate``. + * ``labels``: ``(K,)`` CPU long tensor mapping each row of + ``embeddings`` back to its original index in ``new_files``. The + caller uses this to keep file paths and embeddings aligned even + when some files failed to decode mid-batch. 
+ + Implementation notes: + * Features are accumulated in a Python list and concatenated once at + the end. The previous in-loop ``torch.cat`` was O(n²) in the number + of batches and on MPS paid a Metal kernel-launch cost per concat — + ~15s wasted on a 702-batch run. + * The final tensor is moved to CPU here. Everything downstream (annoy, + torch.save) is CPU-only; leaving embeddings on MPS caused + ``annoy.add_item`` to force a Metal ``waitUntilCompleted`` for every + single float, i.e. ~46M GPU syncs on an 89k-image library. + """ + feature_chunks = [] + label_chunks = [] with torch.no_grad(): - for images, labels in tqdm(img_loader): - batch_features = model.encode_image(images.to(device)) - image_embeddings = torch.cat((image_embeddings, batch_features)).to(device) + for batch in tqdm(img_loader): + if batch is None: + # safe_collate returns None when every image in a batch + # failed to decode — skip rather than crash. + continue + images, labels = batch + feature_chunks.append(model.encode_image(images.to(device))) + label_chunks.append(labels) + + if not feature_chunks: + return ( + torch.empty((0, 512)), + torch.empty((0,), dtype=torch.long), + ) + image_embeddings = torch.cat(feature_chunks, dim=0) image_embeddings = image_embeddings / image_embeddings.norm(dim=-1, keepdim=True) - return(image_embeddings) + surviving_labels = torch.cat(label_chunks, dim=0) + return image_embeddings.cpu(), surviving_labels.cpu() def text_encoder(text: str, device: device, model: CLIP): with torch.no_grad(): diff --git a/memery/gui.py b/memery/gui.py deleted file mode 100644 index fd98433..0000000 --- a/memery/gui.py +++ /dev/null @@ -1,82 +0,0 @@ -# import ipywidgets as widgets - -# from .core import query_flow -# from pathlib import Path -# from IPython.display import clear_output - - - -# def get_image(file_loc): -# filepath = Path(file_loc) -# file = open(filepath, 'rb') -# image = widgets.Image(value=file.read(),width=200) - -# return(image) - -# def 
get_grid(filepaths, n=4): -# imgs = [get_image(f) for f in filepaths[:n] if Path(f).exists()] -# grid = widgets.GridBox(imgs, layout=widgets.Layout(grid_template_columns="repeat(auto-fit, 200px)")) -# return(grid) - -# from PIL import Image -# from io import BytesIO - -# def update_tabs(path, query, n_images, searches, tabs, logbox, im_display_zone, image_query=None): -# stem = Path(path.value).stem -# slug = f"{stem}:{str(query.value)}" -# if slug not in searches.keys(): -# with logbox: -# print(slug) -# if image_query: -# im_queries = [name for name, data in image_query.items()] - -# img = [Image.open(BytesIO(file_info['content'])).convert('RGB') for name, file_info in image_query.items()] -# ranked = query_flow(path.value, query.value, image_query=img[-1]) -# slug = slug + f'/{im_queries}' - -# if len(im_queries) > 0: -# with im_display_zone: -# clear_output() -# display(img[-1]) -# else: -# ranked = query_flow(path.value, query.value) -# searches[f'{slug}'] = ranked - -# tabs.children = [get_grid(v, n=n_images.value) for v in searches.values()] -# for i, k in enumerate(searches.keys()): -# tabs.set_title(i, k) -# tabs.selected_index = len(searches)-1 - - -# # return(True) - -# class appPage(): - -# def __init__(self): -# self.inputs_layout = widgets.Layout(max_width='80%') - -# self.path = widgets.Text(placeholder='path/to/image/folder', value='images/', layout=self.inputs_layout) -# self.query = widgets.Text(placeholder='a funny dog meme', value='a funny dog meme', layout=self.inputs_layout) - -# self.image_query = widgets.FileUpload() -# self.im_display_zone = widgets.Output(max_height='5rem') - -# self.n_images = widgets.IntSlider(description='#', value=4, layout=self.inputs_layout) -# self.go = widgets.Button(description="Search", layout=self.inputs_layout) -# self.logbox = widgets.Output(layout=widgets.Layout(max_width='80%', height="3rem", overflow="none")) -# self.all_inputs_layout = widgets.Layout(max_width='80vw', min_height='40vh', flex_flow='row 
wrap', align_content='flex-start') - -# self.inputs = widgets.Box([self.path, self.query, self.image_query, self.n_images, self.go, self.im_display_zone, self.logbox], layout=self.all_inputs_layout) -# self.tabs = widgets.Tab() -# self.page = widgets.AppLayout(left_sidebar=self.inputs, center=self.tabs) - -# self.searches = {} -# self.go.on_click(self.page_update) - -# display(self.page) - -# def page_update(self, b): - -# update_tabs(self.path, self.query, self.n_images, self.searches, self.tabs, self.logbox, self.im_display_zone, self.image_query.value) - - diff --git a/memery/indexer.py b/memery/indexer.py index 2013f19..f4d9e80 100644 --- a/memery/indexer.py +++ b/memery/indexer.py @@ -1,5 +1,11 @@ from annoy import AnnoyIndex +import numpy as np import torch +from tqdm import tqdm + + +EMBED_DIM = 512 + def join_all(db, new_files, new_embeddings) -> dict: start = len(db) @@ -11,30 +17,55 @@ def join_all(db, new_files, new_embeddings) -> dict: 'fpath': path, 'embed': new_embeddings[i], } - return(db) + return db + + +def _to_cpu_array(emb): + """Coerce one stored embedding to a CPU-resident NumPy array. + + Annoy iterates the vector with PyFloat_AsDouble. If `emb` is a torch + Tensor still on a GPU device, that path fires `Tensor.item()` once per + element, and on MPS each call costs a full `waitUntilCompleted`. So + we make sure annoy gets a CPU-resident NumPy array instead. + """ + if isinstance(emb, torch.Tensor): + # `.detach().cpu()` is a single bulk device-to-host copy. + return emb.detach().cpu().numpy() + if isinstance(emb, np.ndarray): + return emb + return np.asarray(emb, dtype=np.float32) + def build_treemap(db) -> AnnoyIndex: - treemap = AnnoyIndex(512, 'angular') - for k, v in db.items(): - treemap.add_item(k, v['embed']) + """Build an angular Annoy index over the database's embeddings. 
- # Build the treemap, with 5 trees rn - treemap.build(5) + The previous implementation was the dominant bottleneck on macOS/MPS at + scale: handing GPU tensors to annoy caused element-wise GPU syncs, so a + library of ~90k images spent more time in this single function than in + the entire CLIP encoding pass. This version converts the embeddings to + CPU in a single bulk step per item and shows progress. + """ + treemap = AnnoyIndex(EMBED_DIM, 'angular') + if not db: + treemap.build(5) + return treemap - return(treemap) + for k, v in tqdm(db.items(), desc="Indexing", total=len(db)): + treemap.add_item(k, _to_cpu_array(v['embed'])) + + treemap.build(5) + return treemap def save_archives(root, treemap, db) -> tuple[str, str]: - dbpath = root/'memery.pt' + dbpath = root / 'memery.pt' if dbpath.exists(): -# dbpath.rename(root/'memery-bak.pt') dbpath.unlink() torch.save(db, dbpath) - treepath = root/'memery.ann' + treepath = root / 'memery.ann' if treepath.exists(): -# treepath.rename(root/'memery-bak.ann') treepath.unlink() treemap.save(str(treepath)) - return(str(dbpath), str(treepath)) \ No newline at end of file + return str(dbpath), str(treepath) diff --git a/memery/loader.py b/memery/loader.py index 078ef3c..50cbe85 100644 --- a/memery/loader.py +++ b/memery/loader.py @@ -30,14 +30,26 @@ def verify_image(f: str): logging.exception('Skipping bad file: %s\ndue to %s', f, e) pass -def archive_loader(filepaths: list[str], db: Any) -> tuple[ set[str], list[str] ]: # Just guessing on the return type - - current_hashes = [hash for path, hash in filepaths] - archive_db = {i:db[item[0]] for i, item in enumerate(db.items()) if item[1]['hash'] in current_hashes} - archive_hashes = [v['hash'] for v in archive_db.values()] - new_files = [(str(path), hash) for path, hash in filepaths if hash not in archive_hashes and verify_image(path)] - - return(archive_db, new_files) +def archive_loader( + filepaths: list[tuple[Path, str]], + db: dict[int, dict[str, Any]], +) -> 
tuple[dict[int, dict[str, Any]], list[tuple[str, str]]]: + # `filepaths` already only contains files that passed verify_image upstream + # in get_valid_images, so re-verifying here is wasted work. Use a set for + # O(1) hash lookups instead of an O(n) `in` against a list. + current_hashes = {h for _, h in filepaths} + archive_db = { + i: db[item[0]] + for i, item in enumerate(db.items()) + if item[1]['hash'] in current_hashes + } + archive_hashes = {v['hash'] for v in archive_db.values()} + new_files = [ + (str(path), h) + for path, h in filepaths + if h not in archive_hashes + ] + return archive_db, new_files def db_loader(dbpath: str, device: device) -> Any: ''' diff --git a/memery/streamlit_app.py b/memery/streamlit_app.py index 9e4fbc2..22c01b8 100644 --- a/memery/streamlit_app.py +++ b/memery/streamlit_app.py @@ -1,27 +1,24 @@ # Builtins +import argparse +import sys from pathlib import Path + +# Dependencies +import streamlit as st from PIL import Image -from io import StringIO -import sys -import argparse -from threading import current_thread -from contextlib import contextmanager # Local from memery.core import Memery -# Dependencies -import streamlit as st -from streamlit.report_thread import REPORT_CONTEXT_ATTR_NAME - # Parses the args from the command line def parse_args(args: list[str]): parser = argparse.ArgumentParser() - parser.add_argument('root', help='starting directory to search') + parser.add_argument('root', nargs='?', default='./images', + help='starting directory to search') return parser.parse_args(args) -# Initalize session state +# Initialize session state args = parse_args(sys.argv[1:]) if 'memery' not in st.session_state: st.session_state['memery'] = Memery() @@ -38,14 +35,14 @@ def parse_args(args: list[str]): do_clear_cache = st.button(label="Clear Cache") num_workers = st.slider(label="Number of workers", max_value=8) -dir_l, dir_r = st.sidebar.columns([3,1]) +dir_l, dir_r = st.sidebar.columns([3, 1]) with dir_l: path = 
st.text_input(label='Directory', value=args.root) with dir_r: st.title("") do_index = st.button(label="Index", key='do_index') -search_l, search_r = st.sidebar.columns([3,1]) +search_l, search_r = st.sidebar.columns([3, 1]) with search_l: text_query = st.text_input(label='Text query', value='') negative_text_query = st.text_input(label='Negative Text query', value='') @@ -56,7 +53,7 @@ def parse_args(args: list[str]): image_query = st.sidebar.file_uploader(label='Image query') image_query_display = st.sidebar.container() -if image_query: # Display the image query if there is one +if image_query: # Display the image query if there is one img = Image.open(image_query).convert('RGB') with image_query_display: st.image(img) @@ -64,94 +61,68 @@ def parse_args(args: list[str]): skipped_files_box = st.sidebar.expander(label='Skipped files', expanded=False) # Draw the main page -sizes = {'small': 115, 'medium':230, 'large':332, 'xlarge':600} -l, m, r = st.columns([4,1,1]) +sizes = {'small': 115, 'medium': 230, 'large': 332, 'xlarge': 600} +l, m, r = st.columns([4, 1, 1]) with l: - num_images = st.slider(label='Number of images',value=12) + num_images = st.slider(label='Number of images', value=12) with m: - size_choice = st.selectbox(label='Image width', options=[k for k in sizes.keys()], index=1) + size_choice = st.selectbox(label='Image width', options=list(sizes.keys()), index=1) with r: captions_on = st.checkbox(label="Caption filenames", value=False) image_display_zone = st.container() + # Index the directory def index(logbox, path, num_workers): - if Path(path).exists(): - with logbox: - with st_stdout('info'): - memery.index_flow(path, num_workers) - else: - with logbox: - with st_stdout('warning'): - print(f'{path} does not exist!') + if not Path(path).exists(): + logbox.warning(f'{path} does not exist!') + return + with logbox, st.spinner(f'Indexing {path}...'): + memery.index_flow(path, num_workers) + logbox.success('Done indexing') + # Clears out the database 
and treemap files def clear_cache(root, logbox): memery.clean(root) - with logbox: - with st_stdout('info'): - print("Cleaned database and index files") + logbox.info("Cleaned database and index files") + # Runs a search -def search(root, text_query, negative_text_query, image_query, image_display_zone, skipped_files_box, num_images, captions_on, sizes, size_choice): - if not Path(path).exists(): - with logbox: - with st_stdout('warning'): - print(f'{path} does not exist!') - return - with logbox: - with st_stdout('info'): - ranked = memery.query_flow(root, text_query, negative_text_query, image_query) # Modified line +def search(root, text_query, negative_text_query, image_query, + image_display_zone, skipped_files_box, + num_images, captions_on, sizes, size_choice): + if not Path(root).exists(): + logbox.warning(f'{root} does not exist!') + return + with logbox, st.spinner('Searching...'): + ranked = memery.query_flow( + root, text_query, negative_text_query, image_query + ) + if not ranked: + logbox.info('No results.') + return + ims_to_display = {} size = sizes[size_choice] for o in ranked[:num_images]: - name = o.replace(path, '') + name = o.replace(root, '') try: ims_to_display[name] = Image.open(o).convert('RGB') except Exception as e: with skipped_files_box: st.warning(f'Skipping bad file: {name}\ndue to {type(e)}') - pass with image_display_zone: if captions_on: - st.image([o for o in ims_to_display.values()], width=size, channels='RGB', caption=[o for o in ims_to_display.keys()]) + st.image(list(ims_to_display.values()), + width=size, channels='RGB', + caption=list(ims_to_display.keys())) else: - st.image([o for o in ims_to_display.values()], width=sizes[size_choice], channels='RGB') - + st.image(list(ims_to_display.values()), + width=size, channels='RGB') + logbox.success(f'Found {len(ranked)} matches.') -@contextmanager -def st_redirect(src, dst): - placeholder = st.empty() - output_func = getattr(placeholder, dst) - - with StringIO() as buffer: - 
old_write = src.write - - def new_write(b): - if getattr(current_thread(), REPORT_CONTEXT_ATTR_NAME, None): - buffer.write(b + '') - output_func(buffer.getvalue() + '') - else: - old_write(b) - - try: - src.write = new_write - yield - finally: - src.write = old_write - - -@contextmanager -def st_stdout(dst): - with st_redirect(sys.stdout, dst): - yield - - -@contextmanager -def st_stderr(dst): - with st_redirect(sys.stderr, dst): - yield # Decide which actions to take if do_clear_cache: @@ -159,5 +130,6 @@ def st_stderr(dst): elif do_index: index(logbox, path, num_workers) elif search_button or text_query or image_query: - search(path, text_query, negative_text_query, image_query, image_display_zone, skipped_files_box, num_images, captions_on, sizes, size_choice) # Modified line - + search(path, text_query, negative_text_query, image_query, + image_display_zone, skipped_files_box, + num_images, captions_on, sizes, size_choice) diff --git a/notebooks/00_core.ipynb b/notebooks/00_core.ipynb deleted file mode 100644 index 4c2ba01..0000000 --- a/notebooks/00_core.ipynb +++ /dev/null @@ -1,318 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# default_exp core" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#hide\n", - "from nbdev.showdoc import *" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Core\n", - "\n", - "> Index, query and save embeddings of images by folder" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Rationale\n", - "\n", - "**Memery takes a folder of images, and a search query, and returns a list of ranked images.**\n", - "\n", - "The images and query are both projected into a high-dimensional semantic space, courtesy of OpenAI's [https://github.com/openai/CLIP](https://openai.com/blog/clip/). 
These embeddings are indexed and treemapped using the [Annoy](https://github.com/spotify/annoy) library, which provides nearest-neighbor results for the search query. These results are then transmitted to the user interface (currently as a list of file locations).\n", - "\n", - "We provide various interfaces for the end user, which all call upon the function `query_flow` and `index_flow` below.\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Modular flow system" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Memery uses the Neural Search design pattern as described by Han Xiao in e.g. [General Neural Elastic Search and Go Way Beyond](https://hanxiao.io/2019/07/29/Generic-Neural-Elastic-Search-From-bert-as-service-and-Go-Way-Beyond)&c.\n", - "\n", - "This is a system designed to be scalable and distributed if necessary. Even for a single-machine scenario, I like the functional style of it: grab data, transform it and pass it downstream, all the way from the folder to the output widget.\n", - "\n", - "There are two main types of operater in this pattern: **flows** and **executors**.\n", - "\n", - "**Flows** are specific patterns of data manipulation and storage. **Executors** are the operators that transform the data within the flow. \n", - "\n", - "There are two core flows to any search system: indexing, and querying. The plan here is to make executors that can be composed into flows and then compose the flows into a UI that supports querying and, to some extent, indexing as well.\n", - "\n", - "The core executors for this use case are:\n", - " - Loader\n", - " - Crafter\n", - " - Encoder\n", - " - Indexer\n", - " - Ranker\n", - " - Gateway\n", - " \n", - "\n", - "**NB: The executors are currently implemented as functions. 
A future upgrade will change the names to verbs to match, or change their implementation to classes if they're going to act as nouns.**\n", - "\n", - "These executors are being implemented ad hoc in the flow functions, but should probably be given single entry points and have their specific logic happen within their own files. Deeper abstractions with less coupling." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Flows" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "import time\n", - "import torch\n", - "\n", - "from pathlib import Path\n", - "from memery.loader import get_image_files, get_valid_images, archive_loader, db_loader, treemap_loader \n", - "from memery.crafter import crafter, preproc\n", - "from memery.encoder import image_encoder, text_encoder, image_query_encoder\n", - "from memery.indexer import join_all, build_treemap, save_archives\n", - "from memery.ranker import ranker, nns_to_files" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Indexing" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "def index_flow(path):\n", - " '''Indexes images in path, returns the location of save files'''\n", - " root = Path(path)\n", - " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - " \n", - " # Loading\n", - " filepaths = loader.get_image_files(root)\n", - " archive_db = {}\n", - " \n", - " archive_db, new_files = loader.archive_loader(filepaths, root, device)\n", - " print(f\"Loaded {len(archive_db)} encodings\")\n", - " print(f\"Encoding {len(new_files)} new images\")\n", - "\n", - " # Crafting and encoding\n", - " crafted_files = crafter.crafter(new_files, device)\n", - " new_embeddings = encoder.image_encoder(crafted_files, device)\n", - " \n", - " # Reindexing\n", - " db = indexer.join_all(archive_db, 
new_files, new_embeddings)\n", - " print(\"Building treemap\")\n", - " t = indexer.build_treemap(db)\n", - " \n", - " print(f\"Saving {len(db)} encodings\")\n", - " save_paths = indexer.save_archives(root, t, db)\n", - "\n", - " return(save_paths)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "show_doc(index_flow)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can index the local `images` folder to test" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "# delete the current savefile for testing purposes\n", - "Path('images/memery.pt').unlink()\n", - "Path('images/memery.ann').unlink()\n", - "\n", - "# run the index flow. returns the path\n", - "save_paths = index_flow('./images')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "assert save_paths # Returns True if the path exists\n", - "save_paths" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Querying" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "def query_flow(path, query=None, image_query=None):\n", - " '''\n", - " Indexes a folder and returns file paths ranked by query.\n", - " \n", - " Parameters:\n", - " path (str): Folder to search\n", - " query (str): Search query text\n", - " image_query (Tensor): Search query image(s)\n", - "\n", - " Returns:\n", - " list of file paths ranked by query\n", - " '''\n", - " start_time = time.time()\n", - " root = Path(path)\n", - " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - " \n", - " # Check if we should re-index the files\n", - " print(\"Checking files\")\n", - " dbpath = root/'memery.pt'\n", - " db = loader.db_loader(dbpath, device)\n", - " treepath = root/'memery.ann'\n", - " 
treemap = treemap_loader(treepath)\n", - " filepaths = get_valid_images(root)\n", - "\n", - " # # Rebuild the tree if it doesn't \n", - " # if treemap == None or len(db) != len(filepaths):\n", - " # print('Indexing')\n", - " # dbpath, treepath = index_flow(root)\n", - " # treemap = loader.treemap_loader(Path(treepath))\n", - " # db = loader.db_loader(dbpath, device)\n", - " \n", - " # Convert queries to vector\n", - " print('Converting query')\n", - " if image_query:\n", - " img = crafter.preproc(image_query)\n", - " if query and image_query:\n", - " text_vec = encoder.text_encoder(query, device)\n", - " image_vec = encoder.image_query_encoder(img, device)\n", - " query_vec = text_vec + image_vec\n", - " elif query:\n", - " query_vec = encoder.text_encoder(query, device)\n", - " elif image_query:\n", - " query_vec = encoder.image_query_encoder(img, device)\n", - " else:\n", - " print('No query!')\n", - "\n", - " # Rank db by query \n", - " print(f\"Searching {len(db)} images\")\n", - " indexes = ranker.ranker(query_vec, treemap)\n", - " ranked_files = ranker.nns_to_files(db, indexes)\n", - " \n", - " print(f\"Done in {time.time() - start_time} seconds\")\n", - " \n", - " return(ranked_files)\n", - "\n", - " " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "show_doc(query_flow)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ranked = query_flow('./images', 'dog')\n", - "\n", - "print(ranked[0])\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "assert ranked[0] == \"images/memes/Wholesome-Meme-8.jpg\"" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "![](images/memes/Wholesome-Meme-8.jpg)\n", - "\n", - "*Then what?! What are the limitations of this system? What are its options? What configuration can i do if i'm a power user? 
Why did you organize things this way instead of a different way?*\n", - "\n", - "*This, and probably each of the following notebooks, would benefit from a small recording session where I try to explain it to an imaginary audience. So that I can get the narrative of how it works, and then arrange the code around that.*\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "name": "python", - "version": "3.7.7" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/notebooks/01_loader.ipynb b/notebooks/01_loader.ipynb deleted file mode 100644 index bf84bc3..0000000 --- a/notebooks/01_loader.ipynb +++ /dev/null @@ -1,338 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#default_exp loader" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#hide\n", - "from nbdev.showdoc import *" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Loader\n", - "> Functions for finding and loading image files and saved embeddings\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## File manipulation" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "from pathlib import Path\n", - "from PIL import Image\n", - "from tqdm import tqdm" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**NB: A lot of this implementation is too specific, especially the slugified filenames being used for dictionary IDs. 
Should be replaced with a better database implementation.**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "def slugify(filepath):\n", - " return f'{filepath.stem}_{str(filepath.stat().st_mtime).split(\".\")[0]}'\n", - "\n", - "def get_image_files(path):\n", - " img_extensions = {'.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp'}\n", - " return [(f, slugify(f)) for f in tqdm(path.rglob('*')) if f.suffix in img_extensions]\n", - "\n", - "def get_valid_images(path):\n", - " filepaths = get_image_files(path)\n", - " return [f for f in filepaths if verify_image(f[0])]\n", - "\n", - "# This returns boolean and should be called is_valid_image or something like that\n", - "def verify_image(f):\n", - " try:\n", - " img = Image.open(f)\n", - " img.verify() \n", - " return(True)\n", - " except Exception as e:\n", - " print(f'Skipping bad file: {f}\\ndue to {type(e)}')\n", - " pass\n", - " " - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Demonstrating the usage here, not a great test though:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "root = Path('./images')\n", - "\n", - "\n", - "filepaths = get_image_files(root)\n", - "\n", - "len(filepaths)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "filepaths[:3]" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Loaders\n", - "\n", - "So we have a list of paths and slugified filenames from the folder. We want to see if there's an archive, so that we don't have to recalculate tensors for images we've seen before. 
Then we want to pass that directly to the indexer, but send the new images through the crafter and encoder first.\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "import torch\n", - "import torchvision" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We want to use the GPU, if possible, for all the pyTorch functions. But if we can't get access to it we need to fallback to CPU. Either way we call it `device` and pass it to each function in the executors that use torch." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "device" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The `archive_loader` is only called in `indexFlow`. It takes the list of image files and the folder they're in (and the torch device), opens an archive if there is one" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "def archive_loader(filepaths, root, device):\n", - " dbpath = root/'memery.pt'\n", - "# dbpath_backup = root/'memery.pt'\n", - " db = db_loader(dbpath, device)\n", - " \n", - " current_slugs = [slug for path, slug in filepaths] \n", - " archive_db = {i:db[item[0]] for i, item in enumerate(db.items()) if item[1]['slug'] in current_slugs} \n", - " archive_slugs = [v['slug'] for v in archive_db.values()]\n", - " new_files = [(str(path), slug) for path, slug in filepaths if slug not in archive_slugs and verify_image(path)]\n", - " \n", - " return(archive_db, new_files)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The `db_loader` takes a location and returns either the archive dictionary or an empty dictionary. 
Decomposed to its own function so it can be called separately from `archive_loader` or `queryFlow`. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "def db_loader(dbpath, device):\n", - "\n", - " # check for savefile or backup and extract\n", - " if Path(dbpath).exists():\n", - " db = torch.load(dbpath, device)\n", - "# elif dbpath_backup.exists():\n", - "# db = torch.load(dbpath_backup)\n", - " else:\n", - " db = {}\n", - " return(db)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The library `annoy`, [Approximate Nearest Neighbors Oh Yeah!](https://github.com/spotify/annoy) allows us to search through vector space for approximate matches instead of exact best-similarity matches. We sacrifice accuracy for speed, so we can search through tens of thousands of images in less than a thousand times the time it would take to search through tens of images. There's got to be a better way to put that." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "from annoy import AnnoyIndex" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "def treemap_loader(treepath):\n", - " treemap = AnnoyIndex(512, 'angular')\n", - "\n", - " if treepath.exists():\n", - " treemap.load(str(treepath))\n", - " else:\n", - " treemap = None\n", - " return(treemap)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "treepath = Path('images/memery.ann')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "treemap = AnnoyIndex(512, 'angular')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "if treepath.exists():\n", - " treemap.load(str(treepath))\n", - "else:\n", - " treemap = None" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Here we just test on the local image folder" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "archive_db, new_files = archive_loader(get_image_files(root), root, device)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "len(archive_db), len(new_files), treemap.get_n_items()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "dbpath = root/'memery.pt'\n", - "# dbpath_backup = root/'memery.pt'\n", - "db = db_loader(dbpath, device)\n", - "\n", - "current_slugs = [slug for path, slug in filepaths] " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "archive_db = {i:db[item[0]] for i, item in enumerate(db.items()) if item[1]['slug'] in 
current_slugs} " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "len(archive_db)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/notebooks/02_crafter.ipynb b/notebooks/02_crafter.ipynb deleted file mode 100644 index 5e6aa17..0000000 --- a/notebooks/02_crafter.ipynb +++ /dev/null @@ -1,280 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#default_exp crafter" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#hide\n", - "from nbdev.showdoc import *" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Crafter\n", - "\n", - "Takes a list of image filenames and transforms them to batches of the correct dimensions for CLIP. \n", - "\n", - "This executor subclasses PyTorch's VisionDataset (for its file-loading expertise) and DataLoaders. The `DatasetImagePaths` takes a list of image paths and a transfom, returns the transformed tensors when called. 
DataLoader does batching internally so we pass it along to the encoder in that format.\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "import torch\n", - "from torchvision.datasets import VisionDataset\n", - "from PIL import Image\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "def make_dataset(new_files):\n", - " '''Returns a list of samples of a form (path_to_sample, class) and in \n", - " this case the class is just the filename'''\n", - " samples = []\n", - " slugs = []\n", - " for i, f in enumerate(new_files):\n", - " path, slug = f\n", - " samples.append((str(path), i))\n", - " slugs.append((slug, i))\n", - " return(samples, slugs)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "def pil_loader(path: str) -> Image.Image:\n", - " # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)\n", - " with open(path, 'rb') as f:\n", - " img = Image.open(f)\n", - " return img.convert('RGB')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "class DatasetImagePaths(VisionDataset):\n", - " def __init__(self, new_files, transforms = None):\n", - " super(DatasetImagePaths, self).__init__(new_files, transforms=transforms)\n", - " samples, slugs = make_dataset(new_files)\n", - " self.samples = samples\n", - " self.slugs = slugs\n", - " self.loader = pil_loader\n", - " self.root = 'file dataset'\n", - " def __len__(self):\n", - " return(len(self.samples))\n", - " \n", - " def __getitem__(self, index):\n", - " path, target = self.samples[index]\n", - " sample = self.loader(path)\n", - " if sample is not None:\n", - " if self.transforms is not None:\n", - " sample = self.transforms(sample)\n", - " return 
sample, target" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "new_files = [('images/memes/Wholesome-Meme-8.jpg', 'Wholesome-Meme-8'), ('images/memes/Wholesome-Meme-1.jpg', 'Wholesome-Meme-1')]#, ('images/corrupted-file.jpeg', 'corrupted-file.jpeg')]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "crafted = DatasetImagePaths(new_files)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "crafted[0][0]" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Okay, that seems to work decently. Test with transforms, which I will just find in CLIP source code and copy over, to prevent having to import CLIP in this executor." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "def clip_transform(n_px):\n", - " return Compose([\n", - " Resize(n_px, interpolation=Image.BICUBIC),\n", - " CenterCrop(n_px),\n", - " ToTensor(),\n", - " Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),\n", - " ])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Put that all together, and wrap in a DataLoader for batching. In future, need to figure out how to pick batch size and number of workers programmatically bsed on device capabilities." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "def crafter(new_files, device, batch_size=128, num_workers=4): \n", - " with torch.no_grad():\n", - " imagefiles=DatasetImagePaths(new_files, clip_transform(224))\n", - " img_loader=torch.utils.data.DataLoader(imagefiles, batch_size=batch_size, shuffle=False, num_workers=num_workers)\n", - " return(img_loader)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "device" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "crafted_files = crafter(new_files, device)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "crafted_files.batch_size, crafted_files.num_workers" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "file = new_files[1][0]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "def preproc(img):\n", - " transformed = clip_transform(224)(img)\n", - " return(transformed)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "im = preproc([Image.open(file)][0])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# %matplotlib inline" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# show_image(im)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - } - }, - 
"nbformat": 4, - "nbformat_minor": 4 -} diff --git a/notebooks/03_encoder.ipynb b/notebooks/03_encoder.ipynb deleted file mode 100644 index 159281e..0000000 --- a/notebooks/03_encoder.ipynb +++ /dev/null @@ -1,185 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#default_exp encoder" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#hide\n", - "from nbdev.showdoc import *" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Encoder\n", - "\n", - "\n", - "This is just a wrapper around CLIP functions. Cool thing here is we can use the one model for both image and text!\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "import torch\n", - "import clip\n", - "from tqdm import tqdm\n", - "\n", - "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "model, _ = clip.load(\"ViT-B/32\", device, jit=False) \n", - "model = model.float()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "def image_encoder(img_loader, device):\n", - " image_embeddings = torch.tensor(()).to(device)\n", - " with torch.no_grad():\n", - " for images, labels in tqdm(img_loader):\n", - " batch_features = model.encode_image(images.to(device))\n", - " image_embeddings = torch.cat((image_embeddings, batch_features)).to(device)\n", - " \n", - " image_embeddings = image_embeddings / image_embeddings.norm(dim=-1, keepdim=True)\n", - " return(image_embeddings)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "new_files = [('images/memes/Wholesome-Meme-8.jpg', 'Wholesome-Meme-8'), ('images/memes/Wholesome-Meme-1.jpg', 'Wholesome-Meme-1')]" - ] - }, - { - "cell_type": "code", - 
"execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from memery.crafter import crafter" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "img_loader = crafter(new_files, device)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for images, labels in img_loader:\n", - " print(images)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "image_embeddings = image_encoder(img_loader, device)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "image_embeddings.shape" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The text encoder returns a 512d vector just like the image encoder" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "def text_encoder(text, device):\n", - " with torch.no_grad():\n", - " text = clip.tokenize(text).to(device)\n", - " text_features = model.encode_text(text)\n", - " text_features = text_features / text_features.norm(dim=-1, keepdim=True)\n", - " return(text_features)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "text_embedding = text_encoder('a funny dog meme', device)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "text_embedding.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "def image_query_encoder(image, device):\n", - " with torch.no_grad():\n", - " image_embed = model.encode_image(image.unsqueeze(0).to(device))\n", - " image_embed = image_embed / image_embed.norm(dim=-1, keepdim=True)\n", - " return(image_embed)" - ] - } - ], - 
"metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/notebooks/04_indexer.ipynb b/notebooks/04_indexer.ipynb deleted file mode 100644 index 60c7ed3..0000000 --- a/notebooks/04_indexer.ipynb +++ /dev/null @@ -1,244 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#default_exp indexer" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#hide\n", - "from nbdev.showdoc import *" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Indexer\n", - "\n", - "Given a dataset of tensors, returns a dictionary archive and a treemap structure (and saves them to disk)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Joiner\n", - "\n", - "This executor `needs` both Encoder and Loader to send it the new and old vectors, respectively. So it needs to be preceded by the **join_all** component to make sure we're not missing new data before handing it over to the indexer -- or indexing old data that no longer exists!" 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "def join_all(db, new_files, new_embeddings):\n", - " start = len(db)\n", - " for i, file in enumerate(new_files):\n", - " path, slug = file\n", - " index = i + start\n", - " db[index] = {\n", - " 'slug': slug,\n", - " 'fpath': path,\n", - " 'embed': new_embeddings[i],\n", - " }\n", - " return(db)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "from pathlib import Path\n", - "from memery.loader import get_image_files, db_loader, archive_loader\n", - "from memery.crafter import crafter\n", - "from memery.encoder import image_encoder\n", - "\n", - "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "root = Path('images/')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "filepaths = get_image_files(root)\n", - "archive_db = {}\n", - "\n", - "\n", - "archive_db, new_files = archive_loader(filepaths, root, device)\n", - "print(f\"Loaded {len(archive_db)} encodings\")\n", - "print(f\"Encoding {len(new_files)} new images\")\n", - "\n", - "crafted_files = crafter(new_files, device)\n", - "new_embeddings = image_encoder(crafted_files, device)\n", - "\n", - "db = join_all(archive_db, new_files, new_embeddings)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# db = db_loader(root/'memery.pt',device)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "[o[0] for o in db.items()][:5]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "len(db)" - ] - }, - { - "cell_type": 
"markdown", - "metadata": {}, - "source": [ - "Building treemap takes a long time. I don't think `annoy` uses the GPU at all?" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "from annoy import AnnoyIndex" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "def build_treemap(db):\n", - " treemap = AnnoyIndex(512, 'angular')\n", - " for k, v in db.items():\n", - " treemap.add_item(k, v['embed'])\n", - "\n", - " # Build the treemap, with 5 trees rn\n", - " treemap.build(5)\n", - "\n", - " return(treemap)\n", - " " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "t = build_treemap(db)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "t.get_n_items(), t.get_n_trees()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "import torch" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "def save_archives(root, treemap, db):\n", - " dbpath = root/'memery.pt'\n", - " if dbpath.exists():\n", - "# dbpath.rename(root/'memery-bak.pt')\n", - " dbpath.unlink()\n", - " torch.save(db, dbpath)\n", - " \n", - " treepath = root/'memery.ann'\n", - " if treepath.exists():\n", - "# treepath.rename(root/'memery-bak.ann')\n", - " treepath.unlink()\n", - " treemap.save(str(treepath))\n", - " \n", - " return(str(dbpath), str(treepath))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "save_archives(root, t, db)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": 
"Python 3", - "language": "python", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/notebooks/05_ranker.ipynb b/notebooks/05_ranker.ipynb deleted file mode 100644 index 77ea415..0000000 --- a/notebooks/05_ranker.ipynb +++ /dev/null @@ -1,152 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#default_exp ranker" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#hide\n", - "from nbdev.showdoc import *" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Ranker\n", - "\n", - "Takes a query and an index and finds the nearest neighbors or most similar scores. Ideally this is just a simple Annoy `get_nns_by_vector`, or in the simple case a similarity score across all the vectors." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "\n", - "\n", - "from pathlib import Path\n", - "\n", - "from memery.loader import treemap_loader, db_loader\n", - "from memery.encoder import text_encoder" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "treemap = treemap_loader(Path('./images/memery.ann'))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "if treemap:\n", - " treemap.get_n_items()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "def ranker(query_vec, treemap):\n", - " nn_indexes = treemap.get_nns_by_vector(query_vec[0], treemap.get_n_items())\n", - " return(nn_indexes)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "def nns_to_files(db, indexes):\n", - "# return([[v['fpath'] for k,v in db.items() if 
v['index'] == ind][0] for ind in indexes])\n", - " return([db[ind]['fpath'] for ind in indexes])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "db = db_loader(Path('images/memery.pt'), device)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "query = 'dog'" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "query_vec = text_encoder(query, device)\n", - "indexes = ranker(query_vec, treemap)\n", - "ranked_files = nns_to_files(db, indexes)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ranked_files[:5]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/notebooks/07_cli.ipynb b/notebooks/07_cli.ipynb deleted file mode 100644 index 8b29e4e..0000000 --- a/notebooks/07_cli.ipynb +++ /dev/null @@ -1,147 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#default_exp cli" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#hide\n", - "from nbdev.showdoc import *" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# CLI" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "import typer\n", - "import memery.core\n", - "import streamlit.cli" - ] - 
}, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "app = typer.Typer()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Sometimes you just want to be able to pipe information through the terminal. This is that command" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "@app.command()\n", - "def recall(path: str, query: str, n: int = 10):\n", - " \"\"\"Search recursively over a folder from the command line\"\"\"\n", - " ranked = memery.core.query_flow(path, query=query)\n", - " print(ranked[:n])\n", - "# return(ranked)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "recall('./images', 'a funny dog meme')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "More often, though, you probably want to sift through image visually. The `memery serve` command will open a browser app on your local device, using Streamlit library." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "@app.command()\n", - "def serve():\n", - " \"\"\"Runs the streamlit GUI in your browser\"\"\"\n", - " path = memery.__file__.replace('__init__.py','streamlit_app.py')\n", - " streamlit.cli.main(['run',path])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# serve()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export \n", - "def __main__():\n", - " app()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/notebooks/08_jupyter_gui.ipynb b/notebooks/08_jupyter_gui.ipynb deleted file mode 100644 index 5b7e7bc..0000000 --- a/notebooks/08_jupyter_gui.ipynb +++ /dev/null @@ -1,266 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#default_exp gui" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#hide\n", - "from nbdev.showdoc import *" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# GUI" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "import ipywidgets as widgets\n", - "\n", - "from memery.core import query_flow\n", - "from pathlib import Path\n", - "from IPython.display import clear_output\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## App design\n", - "\n", - "So 
what zones do we need for a proper image search app? Two examples come to mind: https://same.energy and https://images.google.com. One is minimalist and brutalist while the other is maximalist in features and refined in design.\n", - "\n", - "Same.energy proves that all you need for image search is a text box, a button, and images. (At least, that's how it started off, and sometimes how it is today. They're A/B testing heavily right now, and we'll see what it evolves into.) If you click on an image result, you are now searching for that image. If you add text, it asks if you want to search for the image with text or just the image. This can lead in any hill-climbing direction the user wants, I suppose. \n", - "\n", - "Google Images has up to six toolbars overhanging the images, and a complicated lightbox selection window that shows the individual image with a subset of similar images below it. Nested and stacked, providing lots of specific search and filtering capabilities. Not as likely to induce a wikiwalk. They've introduced \"collections\" now, which are presumably meant to replace the \"download to random image folder\" functionality of current browsers.\n", - "\n", - "There's also Pinterest, of course, though their engineering is geared more toward gaming Google results than finding the right image by search. Thye have a great browse mode though, and save features. 
Best of all, they have a goodreads-style user tagging function that allows for a whole different way of sorting images than availableon the other sites.\n", - "\n", - "The functions available from these sites include:\n", - "\n", - "- Text query\n", - "- Image query\n", - "- Text and image query (totally doable with CLIP vectors)\n", - "- Browse visually similar images\n", - "- Save images (to cloud mostly)\n", - "- Filter images by:\n", - " - Size\n", - " - Color\n", - " - Type\n", - " - Time\n", - " - Usage rights\n", - "- Visit homepage for image\n", - "- Tagging images\n", - "- Searching by tags additively\n", - "- Filtering out by tags\n", - "\n", - "Tags and filter categories can both be simulated with CLIP vectors of text tokens like \"green\" or \"noisy\" or \"illustration\" or \"menswear\". Size of image can be inferred directly from filesize or recorded from bitmap data in the `crafter`. Images as search queries and visually similar image browser are the same function but in different user interaction modes. And image links can be to local files, rather than homepages. Saving images not as relevant in this context, though easily sending them somewhere else is. \n", - "\n", - "Thus there are really three projects here:\n", - "- Basic app functionality with search and grid\n", - "- Visually simillar image browsing and search\n", - "- Tagging and filtering, auto and manual\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Basic app functionality\n", - "\n", - "We want a unified search bar (variable inputs and a button) and an image grid. And each search should remain accessible after it's run, so we can navigate between and compare. It would be nice to use browser-native navigation but for now, with the plan to run a notebook in Voila and serve locally, better to use `ipywidgets` Tabs mode. 
Eventually it would also be good to replace or upgrade `ipyplot` or better navigation, but first we should sketch out the new-tab functionality.\n", - "\n", - "Need a tabs output, an event loop, a dictionary of searches run, each search returning a list of filenames to be printed in a sub-output within the tab. All wrapped in a VBox with the inputs.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "filepaths = ['images/memes/Wholesome-Meme-8.jpg', 'images/memes/Wholesome-Meme-1.jpg']" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "def get_image(file_loc):\n", - " filepath = Path(file_loc)\n", - " file = open(filepath, 'rb')\n", - " image = widgets.Image(value=file.read(),width=200)\n", - " \n", - " return(image)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "display(get_image(filepaths[0]))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "imgs = [get_image(f) for f in filepaths]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "def get_grid(filepaths, n=4):\n", - " imgs = [get_image(f) for f in filepaths[:n] if Path(f).exists()]\n", - " grid = widgets.GridBox(imgs, layout=widgets.Layout(grid_template_columns=\"repeat(auto-fit, 200px)\"))\n", - " return(grid)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "get_grid(filepaths)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "from PIL import Image\n", - "from io import BytesIO" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "def 
update_tabs(path, query, n_images, searches, tabs, logbox, im_display_zone, image_query=None):\n", - " stem = Path(path.value).stem\n", - " slug = f\"{stem}:{str(query.value)}\"\n", - " if slug not in searches.keys():\n", - " with logbox:\n", - " print(slug)\n", - " if image_query:\n", - " im_queries = [name for name, data in image_query.items()]\n", - " \n", - " img = [Image.open(BytesIO(file_info['content'])).convert('RGB') for name, file_info in image_query.items()]\n", - " ranked = query_flow(path.value, query.value, image_query=img[-1])\n", - " slug = slug + f'/{im_queries}'\n", - " \n", - " if len(im_queries) > 0:\n", - " with im_display_zone:\n", - " clear_output()\n", - " display(img[-1])\n", - " else:\n", - " ranked = query_flow(path.value, query.value)\n", - " searches[f'{slug}'] = ranked\n", - " \n", - " tabs.children = [get_grid(v, n=n_images.value) for v in searches.values()]\n", - " for i, k in enumerate(searches.keys()):\n", - " tabs.set_title(i, k)\n", - " tabs.selected_index = len(searches)-1\n", - "\n", - " \n", - "# return(True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "class appPage():\n", - " \n", - " def __init__(self):\n", - " self.inputs_layout = widgets.Layout(max_width='80%')\n", - "\n", - " self.path = widgets.Text(placeholder='path/to/image/folder', value='images/', layout=self.inputs_layout)\n", - " self.query = widgets.Text(placeholder='a funny dog meme', value='a funny dog meme', layout=self.inputs_layout)\n", - " \n", - " self.image_query = widgets.FileUpload()\n", - " self.im_display_zone = widgets.Output(max_height='5rem')\n", - "\n", - " self.n_images = widgets.IntSlider(description='#', value=4, layout=self.inputs_layout)\n", - " self.go = widgets.Button(description=\"Search\", layout=self.inputs_layout)\n", - " self.logbox = widgets.Output(layout=widgets.Layout(max_width='80%', height=\"3rem\", overflow=\"none\"))\n", - " 
self.all_inputs_layout = widgets.Layout(max_width='80vw', min_height='40vh', flex_flow='row wrap', align_content='flex-start')\n", - "\n", - " self.inputs = widgets.Box([self.path, self.query, self.image_query, self.n_images, self.go, self.im_display_zone, self.logbox], layout=self.all_inputs_layout)\n", - " self.tabs = widgets.Tab()\n", - " self.page = widgets.AppLayout(left_sidebar=self.inputs, center=self.tabs)\n", - "\n", - " self.searches = {}\n", - " self.go.on_click(self.page_update)\n", - " \n", - " display(self.page)\n", - "\n", - " def page_update(self, b):\n", - " \n", - " update_tabs(self.path, self.query, self.n_images, self.searches, self.tabs, self.logbox, self.im_display_zone, self.image_query.value)\n", - "\n", - " \n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "app = appPage()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/notebooks/09_streamlit_app.ipynb b/notebooks/09_streamlit_app.ipynb deleted file mode 100644 index f8e4352..0000000 --- a/notebooks/09_streamlit_app.ipynb +++ /dev/null @@ -1,224 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#default_exp streamlit_app" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#hide\n", - "from nbdev.showdoc import *" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Streamlit app" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Streamlit is a more convenient way to activate a quick user-facing GUI than Voila was, especially because of Voila having conflicting dependencies with nbdev.\n", - "\n", - "However, Streamlit wants a `.py` file instead of a notebook for development. 
This is kind of annoying, because to get the hot-reload effect from Streamlit we have to develop outside the notebook, but to maintain documentation (and compile with everything else) we have to keep the main source of truth right here. Perhaps a solution will present itself later; meanwhile, I have been using a scratch file `streamlit-app.py` for development and then copied it back here." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "This is a workaround for the query_flow printing to stdout. Maybe it should be handled natively in Streamlit? " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export \n", - "import streamlit as st\n", - "from memery import core\n", - "\n", - "from pathlib import Path\n", - "from PIL import Image\n", - "\n", - "from streamlit.report_thread import REPORT_CONTEXT_ATTR_NAME\n", - "from threading import current_thread\n", - "from contextlib import contextmanager\n", - "from io import StringIO\n", - "import sys" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export \n", - "@contextmanager\n", - "def st_redirect(src, dst):\n", - " placeholder = st.empty()\n", - " output_func = getattr(placeholder, dst)\n", - "\n", - " with StringIO() as buffer:\n", - " old_write = src.write\n", - "\n", - " def new_write(b):\n", - " if getattr(current_thread(), REPORT_CONTEXT_ATTR_NAME, None):\n", - " buffer.write(b + '')\n", - " output_func(buffer.getvalue() + '')\n", - " else:\n", - " old_write(b)\n", - "\n", - " try:\n", - " src.write = new_write\n", - " yield\n", - " finally:\n", - " src.write = old_write\n", - "\n", - "\n", - "@contextmanager\n", - "def st_stdout(dst):\n", - " with st_redirect(sys.stdout, dst):\n", - " yield\n", - "\n", - "\n", - "@contextmanager\n", - "def st_stderr(dst):\n", - " with st_redirect(sys.stderr, dst):\n", - " yield" - ] - }, - { - "cell_type": "markdown", - 
"metadata": {}, - "source": [ - "Trying to make good use of streamlit's caching service here; if the search query and folder are the same as a previous search, it will serve the cached version. Might present some breakage points though, yet to see." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "@st.cache\n", - "def send_image_query(path, text_query, image_query):\n", - " ranked = core.query_flow(path, text_query, image_query=img)\n", - " return(ranked)\n", - "\n", - "@st.cache\n", - "def send_text_query(path, text_query):\n", - " ranked = core.query_flow(path, text_query)\n", - " return(ranked)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "This is the sidebar content" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "st.sidebar.title(\"Memery\")\n", - "\n", - "path = st.sidebar.text_input(label='Directory', value='./images')\n", - "text_query = st.sidebar.text_input(label='Text query', value='')\n", - "image_query = st.sidebar.file_uploader(label='Image query')\n", - "im_display_zone = st.sidebar.beta_container()\n", - "logbox = st.sidebar.beta_container()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The image grid parameters" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "sizes = {'small': 115, 'medium':230, 'large':332, 'xlarge':600}\n", - "\n", - "l, m, r = st.beta_columns([4,1,1])\n", - "with l:\n", - " num_images = st.slider(label='Number of images',value=12)\n", - "with m:\n", - " size_choice = st.selectbox(label='Image width', options=[k for k in sizes.keys()], index=1)\n", - "with r:\n", - " captions_on = st.checkbox(label=\"Caption filenames\", value=False)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "And the main event 
loop, triggered every time the query parameters change.\n", - "\n", - "This doesn't really work in Jupyter at all. Hope it does once it's compiled." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "if text_query or image_query:\n", - " with logbox:\n", - " with st_stdout('info'):\n", - " if image_query is not None:\n", - " img = Image.open(image_query).convert('RGB')\n", - " with im_display_zone:\n", - " st.image(img)\n", - " ranked = send_image_query(path, text_query, image_query)\n", - " else:\n", - " ranked = send_text_query(path, text_query)\n", - " ims = [Image.open(o).convert('RGB') for o in ranked[:num_images]]\n", - " names = [o.replace(path, '') for o in ranked[:num_images]]\n", - "\n", - " if captions_on:\n", - " images = st.image(ims, width=sizes[size_choice], channels='RGB', caption=names)\n", - " else:\n", - " images = st.image(ims, width=sizes[size_choice], channels='RGB')" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/notebooks/_visualize.ipynb b/notebooks/_visualize.ipynb deleted file mode 100644 index 40bf13e..0000000 --- a/notebooks/_visualize.ipynb +++ /dev/null @@ -1,552 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#default_exp visualizer" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#hide\n", - "from nbdev.showdoc import *\n", - "%matplotlib inline" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Visualize" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "---\n", - "## Dimensionality Reduction\n", - "\n", - "One use-case for `memery` is to explore large image datasets, for cleaning and curation purposes. 
Sifting images by hand takes a long time, and it's near impossible to keep all the images in your mind at noce.\n", - "\n", - "Even with semantic search capabilities, it's hard to get an overview of all the images. CLIP sees things in many more dimensions than humans do, so no matter how many searches you run you can't be sure if you're missing some outliers you don't even know to search for.\n", - "\n", - "The ideal overview would be a map of all the images along all the dimensions, but we don't know how to visualize or parse 512-dimensional spaces for human brains. So we have to do dimensional reduction: find a function in some space with ≤ 3 dimensions that best emulates the 512-dim embeddings we have, and map that instead.\n", - "\n", - "The recent advance in dimensional reduction is Minimum Distortion Embedding, an abstraction over all types of embeddings like PCA, t-SNE, or k-means clustering. We can use the `pymde` library to embed them and `matplotlib` to draw the images as their own markers on the graph. 
We'll also need `torch` to process the tensors, and `memery` functions to process the database" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import pymde\n", - "import torch\n", - "from pathlib import Path\n", - "from memery.loader import db_loader\n", - "\n", - "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's get a database of embeddings from the local folder" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "db = db_loader('images/memery.pt', device)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "db[0].keys()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "embeds = torch.stack([v['embed'] for v in db.values()], 0)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "There are two methods to invoke with `pymde`: `preserve_neighbors` and `preserve_distances`. They create different textures in the final product. Let's see what each looks like on our sample dataset." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "mde_n = pymde.preserve_neighbors(embeds, verbose=False, device='cuda')\n", - "mde_d = pymde.preserve_distances(embeds, verbose=False, device='cuda')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "embed_n = mde_n.embed(verbose=False, snapshot_every=1)\n", - "embed_d = mde_d.embed(verbose=False, snapshot_every=1)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "pymde.plot(embed_n)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "pymde.plot(embed_d)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "mde_n.play(savepath='./graphs/mde_n.gif')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "mde_d.play(savepath='./graphs/mde_d.gif')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "assert embed_n.shape" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "---\n", - "Now I want to plot images as markers, instead of little dots. Haven't figured out yet how to merge this with `pymde.plot` functions, so I'm doing it right in matplotlib. \n", - "\n", - "If we just plot the images at their coordinates, they will overlap (especially on the `preserve_neighbors` plot) so eventually maybe I can normalize the x and y axes and plot things on a grid? 
at least a little bit" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import matplotlib.pyplot as plt\n", - "from matplotlib.offsetbox import OffsetImage, AnnotationBbox\n", - "from tqdm import tqdm" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def plot_images_from_tensors(coords, image_paths, dpi=600, savefile = 'default.jpg', zoom=0.03):\n", - " fig, ax = plt.subplots()\n", - " fig.dpi = dpi\n", - " fig.set_size_inches(8,8)\n", - " \n", - " ax.xaxis.set_visible(False) \n", - " ax.yaxis.set_visible(False)\n", - " \n", - " cc = coords.cpu()\n", - " x_max, y_max = cc.argmax(0)\n", - " x_min, y_min = cc.argmin(0)\n", - " \n", - " low = min(cc[x_min][0], cc[y_min][1])\n", - " high = max(cc[x_max][0], cc[y_max][1])\n", - " sq_lim = max(abs(low), abs(high))\n", - " \n", - " plt.xlim(low, high)\n", - " plt.ylim(low, high)\n", - " \n", - "# plt.xlim(-sq_lim, sq_lim)\n", - "# plt.ylim(-sq_lim, sq_lim)\n", - "\n", - " for i, coord in tqdm(enumerate(coords)):\n", - " try:\n", - " x, y = coord\n", - "\n", - " path = str(image_paths[i])\n", - " with open(path, 'rb') as image_file:\n", - " image = plt.imread(image_file)\n", - "\n", - " im = OffsetImage(image, zoom=zoom, resample=False)\n", - " im.image.axes = ax\n", - " ab = AnnotationBbox(im, (x,y), frameon=False, pad=0.0,)\n", - " ax.add_artist(ab)\n", - " except SyntaxError:\n", - " pass\n", - " print(\"Drawing images as markers...\")\n", - " plt.savefig(savefile)\n", - " print(f'Saved image to {savefile}')\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "filenames = [v['fpath'] for v in db.values()]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "savefile = 'graphs/embed_n.jpg'\n", - "\n", - "plot_images_from_tensors(embed_n, filenames, 
savefile=savefile)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "savefile = 'graphs/embed_d.jpg'\n", - "\n", - "plot_images_from_tensors(embed_d, filenames, savefile=savefile)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "I suppose it makes sense that the `preserve_neighbors` function clumps things together and the `preserve_distances` spreads them out. It's nice to see the actual distances and texture of the data, for sure. But I'd also like to be able to see them bigger, with only relative data about where they are to each other. Let's see if we can implement a normalization function and plot them again.\n", - "\n", - "Currently the embedding tensor is basically a list pairs of floats. Can I convert those to a set of integers that's the length of the amount of images? I don't know how to do this in matrix math so I'll try it more simply first." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "len(embed_n)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "embed_list = [(float(x),float(y)) for x,y in embed_n]\n", - "embed_dict = {k: v for k, v in zip(filenames, embed_list)}\n", - "len(embed_dict)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def normalize_embeds(embed_dict):\n", - " sort_x = {k: v[0] for k, v in sorted(embed_dict.items(), key=lambda item: item[1][0])}\n", - " norm_x = {item[0]: i for i, item in enumerate(sort_x.items())}\n", - " \n", - " sort_y = {k: v[1] for k, v in sorted(embed_dict.items(), key=lambda item: item[1][1])}\n", - " norm_y = {item[0]: i for i, item in enumerate(sort_y.items())}\n", - "\n", - " normalized_dict = {k: (norm_x[k], norm_y[k]) for k in embed_dict.keys()}\n", - " return(normalized_dict)" - ] - }, - { - "cell_type": "code", - 
"execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "norm_dict = normalize_embeds(embed_dict)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "len(norm_dict)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "I probably could do that all in torch but right now I'm just going to pipe it back into tensors and put it through my plotting function:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "norms = torch.stack([torch.tensor([x, y]) for x, y in norm_dict.values()])\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "plot_images_from_tensors(norms, filenames, savefile='graphs/normalized.jpg')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "It worked!! The clusters still exist but their distances are relaxed so they can be displayed better on the graph. It's removing some information, for sure. but unclear if that is information a human needs.\n", - "\n", - "I wonder if it works on the `preserve_distances` method..." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "embed_list = [(float(x),float(y)) for x,y in embed_d]\n", - "embed_dict = {k: v for k, v in zip(filenames, embed_list)}\n", - "norm_dict = normalize_embeds(embed_dict)\n", - "norms = torch.stack([torch.tensor([x, y]) for x, y in norm_dict.values()])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "plot_images_from_tensors(norms, filenames, savefile='graphs/normalized-d.jpg')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "This looks okay. 
It reduces overall distances but keeps relative distances? Still not sure what the actionalbe difference between these two methods is. \n", - "\n", - "Well, it works okay for now. The next question is, how to incorporate it into a working GUI?\n", - "\n", - "I wonder how matplotlib does natively, for a much larger dataset. Let's see:\n", - "\n", - "# Large dataset" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def normalize_tensors(embdgs, names):\n", - " embed_list = [(float(x),float(y)) for x,y in embdgs]\n", - " embed_dict = {k: v for k, v in zip(names, embed_list)}\n", - " norm_dict = normalize_embeds(embed_dict)\n", - " norms = torch.stack([torch.tensor([x, y]) for x, y in norm_dict.values()])\n", - " return(norms)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "db = db_loader('/home/mage/Pictures/memes/memery.pt', device)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "filenames = [v['fpath'] for v in db.values()]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "clips = torch.stack([v['embed'] for v in db.values()])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "filenames[:5]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "mde_lg = pymde.preserve_neighbors(clips, verbose=False, device='cuda')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "embed_lg = mde_lg.embed(verbose=False, snapshot_every=1)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "norms_lg = normalize_tensors(embed_lg,filenames)\n", - "len(norms_lg)" - ] - }, - { 
- "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "plot_images_from_tensors(embed_lg, filenames, savefile='graphs/normalized-lg.jpg')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "---\n", - "\n", - "### Be careful here\n", - "\n", - "It is possible to use embeddings as target coordinates to delete sections of the data:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "to_delete = []\n", - "for coord, img in zip(#embedding, filenames):\n", - " x, y = coord\n", - " if x < -2 or y < -1:\n", - " to_delete.append(img)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "len(to_delete)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for img in to_delete:\n", - " imgpath = Path(img)\n", - " imgpath.unlink()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "It worked! A better distribution and fewer of the wrong things" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/notebooks/_working_pipeline.ipynb b/notebooks/_working_pipeline.ipynb deleted file mode 100644 index 22800a4..0000000 --- a/notebooks/_working_pipeline.ipynb +++ /dev/null @@ -1,1320 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Modular flow system\n", - "\n", - "I have decided to adapt the design system from Jina into this repo, at least for prototyping purposes. Their distributed systems approach seems quite good but is too muc complexity for me to add right away. 
Insetead I'm going to replicate the essential design pattern, that of Flows and Executors.\n", - "\n", - "**Flows** are specific patterns of data manipulation and storage. **Executors** are the operators that transform the data within the flow. \n", - "\n", - "There are two core flows to any search system: indexing, and querying. The plan here is to make executors that can be composed into flows and then compose the flows into a UI that supports querying and, to some extent, indexing as well.\n", - "\n", - "The core executors for this use case are:\n", - " - Loader\n", - " - Crafter\n", - " - Encoder\n", - " - Indexer\n", - " - Ranker\n", - " - Gateway\n", - " \n", - "In this file I try to build these so that the Jupyter notebook itself can be run as a Flow for indexing and then querying. From there it should be easy to abstract the functions and classes and messaging or whatever is necessary for microservices etc." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:26:12.215480Z", - "iopub.status.busy": "2021-05-17T23:26:12.215153Z", - "iopub.status.idle": "2021-05-17T23:26:12.218180Z", - "shell.execute_reply": "2021-05-17T23:26:12.217661Z", - "shell.execute_reply.started": "2021-05-17T23:26:12.215435Z" - } - }, - "outputs": [], - "source": [ - "# move these to main function eventually but for now we're going in notebook order\n", - "args = {\n", - " \"path\": \"/home/mage/Pictures/memes/\",\n", - " \"query\": \"scary cat\",\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Loader\n", - "\n", - "The loader takes a directory or list of image files and checks them against database or checkpoint. If there is a saved checkpoint and the files haven't changed, it loads the checkpoint and sends the data directly to Ranker. If not, it sends them to Crafter. 
Ideally it could send new images to Crafter and load dictionary of old images at the same time, without re-encoding old images.\n", - "\n", - "The process of indexing could actually happen in the background while querying happens on the old index! This means putting the logic in the Flow rather than the Loader, I suppose.\n", - "\n", - "Maybe build dictionary `{filename_timestamp : vector}` to databse as a simple version control mechanism. Then, if any filenames exist but with a different timestamp, we load those under their own key. And we can throw out any filename_timestamp that doesn't exist, before indexing. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:26:12.219219Z", - "iopub.status.busy": "2021-05-17T23:26:12.218988Z", - "iopub.status.idle": "2021-05-17T23:26:12.221615Z", - "shell.execute_reply": "2021-05-17T23:26:12.221131Z", - "shell.execute_reply.started": "2021-05-17T23:26:12.219201Z" - } - }, - "outputs": [], - "source": [ - "from pathlib import Path\n", - "\n", - "root = Path(args['path'])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:26:12.222765Z", - "iopub.status.busy": "2021-05-17T23:26:12.222612Z", - "iopub.status.idle": "2021-05-17T23:26:12.225218Z", - "shell.execute_reply": "2021-05-17T23:26:12.224478Z", - "shell.execute_reply.started": "2021-05-17T23:26:12.222747Z" - } - }, - "outputs": [], - "source": [ - "def slugify(filepath):\n", - " return f'{filepath.stem}_{str(filepath.stat().st_mtime).split(\".\")[0]}'" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:26:12.226412Z", - "iopub.status.busy": "2021-05-17T23:26:12.226182Z", - "iopub.status.idle": "2021-05-17T23:26:12.229182Z", - "shell.execute_reply": "2021-05-17T23:26:12.228619Z", - "shell.execute_reply.started": 
"2021-05-17T23:26:12.226386Z" - } - }, - "outputs": [], - "source": [ - "# filenames = path.iterdir()\n", - "def get_image_files(path):\n", - " return [(f, slugify(f)) for f in path.rglob('*') if f.suffix in ['.jpg', '.png', '.jpeg']]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:26:12.230161Z", - "iopub.status.busy": "2021-05-17T23:26:12.229994Z", - "iopub.status.idle": "2021-05-17T23:26:12.271774Z", - "shell.execute_reply": "2021-05-17T23:26:12.271310Z", - "shell.execute_reply.started": "2021-05-17T23:26:12.230139Z" - } - }, - "outputs": [], - "source": [ - "filepaths = get_image_files(root)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:26:12.272506Z", - "iopub.status.busy": "2021-05-17T23:26:12.272342Z", - "iopub.status.idle": "2021-05-17T23:26:12.279507Z", - "shell.execute_reply": "2021-05-17T23:26:12.279149Z", - "shell.execute_reply.started": "2021-05-17T23:26:12.272456Z" - } - }, - "outputs": [], - "source": [ - "len(filepaths)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:26:12.280416Z", - "iopub.status.busy": "2021-05-17T23:26:12.280299Z", - "iopub.status.idle": "2021-05-17T23:26:12.283291Z", - "shell.execute_reply": "2021-05-17T23:26:12.282872Z", - "shell.execute_reply.started": "2021-05-17T23:26:12.280400Z" - } - }, - "outputs": [], - "source": [ - "filepaths[:5]" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "So we have a list of paths and slugified filenames from the folder. We want to see if there's an archive, so that we don't have to recalculate tensors for images we've seen before. 
Then we want to pass that directly to the indexer, but send the new images through the crafter and encoder first.\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-12T23:37:14.029490Z", - "iopub.status.busy": "2021-05-12T23:37:14.028825Z", - "iopub.status.idle": "2021-05-12T23:37:14.080913Z", - "shell.execute_reply": "2021-05-12T23:37:14.080380Z", - "shell.execute_reply.started": "2021-05-12T23:37:14.029406Z" - } - }, - "source": [ - "But I need to separate out the logic for the crafter and encoder from the simple loading of archives and pictures. This component should only provide the dictionary of archived CLIP embeddings, the treemap (eventually) and the locations of the new images to review, and let the downstream components deal with them." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:26:12.283975Z", - "iopub.status.busy": "2021-05-17T23:26:12.283864Z", - "iopub.status.idle": "2021-05-17T23:26:12.768725Z", - "shell.execute_reply": "2021-05-17T23:26:12.768101Z", - "shell.execute_reply.started": "2021-05-17T23:26:12.283960Z" - } - }, - "outputs": [], - "source": [ - "import torch\n", - "import torchvision\n", - "\n", - "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "device" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:26:12.770648Z", - "iopub.status.busy": "2021-05-17T23:26:12.770436Z", - "iopub.status.idle": "2021-05-17T23:26:12.775477Z", - "shell.execute_reply": "2021-05-17T23:26:12.774783Z", - "shell.execute_reply.started": "2021-05-17T23:26:12.770627Z" - } - }, - "outputs": [], - "source": [ - "def files_archive_loader(filepaths, root, device):\n", - " dbpath = root/'memery.pt'\n", - "# dbpath_backup = root/'memery.pt'\n", - " db = db_loader(dbpath)\n", - " \n", - " current_slugs 
= [slug for path, slug in filepaths] \n", - " archive_db = {k:db[k] for k in db if k in current_slugs} \n", - " archive_slugs = [v['slug'] for v in archive_db.values()]\n", - " new_files = [(str(path), slug) for path, slug in filepaths if slug not in archive_slugs]\n", - " \n", - " return(archive_db, new_files)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:26:12.778681Z", - "iopub.status.busy": "2021-05-17T23:26:12.778540Z", - "iopub.status.idle": "2021-05-17T23:26:12.781397Z", - "shell.execute_reply": "2021-05-17T23:26:12.780882Z", - "shell.execute_reply.started": "2021-05-17T23:26:12.778662Z" - } - }, - "outputs": [], - "source": [ - "def db_loader(dbpath):\n", - " # check for savefile or backup and extract\n", - " if dbpath.exists():\n", - " db = torch.load(dbpath)\n", - "# elif dbpath_backup.exists():\n", - "# db = torch.load(dbpath_backup)\n", - " else:\n", - " db = {}\n", - " return(db)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:26:12.782479Z", - "iopub.status.busy": "2021-05-17T23:26:12.782283Z", - "iopub.status.idle": "2021-05-17T23:26:12.785981Z", - "shell.execute_reply": "2021-05-17T23:26:12.785150Z", - "shell.execute_reply.started": "2021-05-17T23:26:12.782459Z" - } - }, - "outputs": [], - "source": [ - "def treemap_loader(treepath):\n", - " treemap = AnnoyIndex(512, 'angular')\n", - "\n", - " if treepath.exists():\n", - " treemap.load(str(treepath))\n", - " else:\n", - " treemap = None\n", - " return(treemap)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:26:12.787044Z", - "iopub.status.busy": "2021-05-17T23:26:12.786880Z", - "iopub.status.idle": "2021-05-17T23:26:14.981021Z", - "shell.execute_reply": "2021-05-17T23:26:14.980482Z", - "shell.execute_reply.started": 
"2021-05-17T23:26:12.787026Z" - } - }, - "outputs": [], - "source": [ - "archive_db, new_files = files_archive_loader(get_image_files(Path(args['path'])), root, device)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:26:14.982030Z", - "iopub.status.busy": "2021-05-17T23:26:14.981859Z", - "iopub.status.idle": "2021-05-17T23:26:14.986891Z", - "shell.execute_reply": "2021-05-17T23:26:14.986449Z", - "shell.execute_reply.started": "2021-05-17T23:26:14.982010Z" - } - }, - "outputs": [], - "source": [ - "len(archive_db)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:26:14.987764Z", - "iopub.status.busy": "2021-05-17T23:26:14.987610Z", - "iopub.status.idle": "2021-05-17T23:26:14.991252Z", - "shell.execute_reply": "2021-05-17T23:26:14.990625Z", - "shell.execute_reply.started": "2021-05-17T23:26:14.987747Z" - } - }, - "outputs": [], - "source": [ - "len(new_files)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:26:14.992366Z", - "iopub.status.busy": "2021-05-17T23:26:14.992157Z", - "iopub.status.idle": "2021-05-17T23:26:14.996332Z", - "shell.execute_reply": "2021-05-17T23:26:14.995383Z", - "shell.execute_reply.started": "2021-05-17T23:26:14.992343Z" - } - }, - "outputs": [], - "source": [ - "\n", - "len(new_files),len(archive_db)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Crafter\n", - "\n", - "Takes a list of image filenames and transforms them to batches of the correct dimensions for CLIP. 
Need to figure out a way around torchvision's loader idiosyncrasies here: currently it just loads images from subfolders, needs to operate okay if pointed at a single folder of images, or recursively, or an arbitrary list of files.\n", - "\n", - "Then, too, it would be nice to eventually putthis work on the client computer using torchscript or something. So that it only sends 224x224x3 images over the wire. And we only have to compute those once per image, since we're storing a database of finished vectors which should be even smaller\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:26:14.997972Z", - "iopub.status.busy": "2021-05-17T23:26:14.997652Z", - "iopub.status.idle": "2021-05-17T23:26:15.000660Z", - "shell.execute_reply": "2021-05-17T23:26:15.000138Z", - "shell.execute_reply.started": "2021-05-17T23:26:14.997945Z" - } - }, - "outputs": [], - "source": [ - "from torchvision.datasets import VisionDataset\n", - "from PIL import Image" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:26:15.005865Z", - "iopub.status.busy": "2021-05-17T23:26:15.005459Z", - "iopub.status.idle": "2021-05-17T23:26:15.010256Z", - "shell.execute_reply": "2021-05-17T23:26:15.009011Z", - "shell.execute_reply.started": "2021-05-17T23:26:15.005837Z" - } - }, - "outputs": [], - "source": [ - "def make_dataset(new_files):\n", - " '''Returns a list of samples of a form (path_to_sample, class) and in \n", - " this case the class is just the filename'''\n", - " samples = []\n", - " slugs = []\n", - " for i, f in enumerate(new_files):\n", - " path, slug = f\n", - " samples.append((str(path), i))\n", - " slugs.append((slug, i))\n", - " return(samples, slugs)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:26:15.011746Z", - 
"iopub.status.busy": "2021-05-17T23:26:15.011203Z", - "iopub.status.idle": "2021-05-17T23:26:15.015031Z", - "shell.execute_reply": "2021-05-17T23:26:15.014629Z", - "shell.execute_reply.started": "2021-05-17T23:26:15.011681Z" - } - }, - "outputs": [], - "source": [ - "def pil_loader(path: str) -> Image.Image:\n", - " # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)\n", - " with open(path, 'rb') as f:\n", - " img = Image.open(f)\n", - " return img.convert('RGB')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:26:15.015970Z", - "iopub.status.busy": "2021-05-17T23:26:15.015807Z", - "iopub.status.idle": "2021-05-17T23:26:15.020168Z", - "shell.execute_reply": "2021-05-17T23:26:15.019612Z", - "shell.execute_reply.started": "2021-05-17T23:26:15.015951Z" - } - }, - "outputs": [], - "source": [ - "class DatasetImagePaths(VisionDataset):\n", - " def __init__(self, new_files, transforms = None):\n", - " super(DatasetImagePaths, self).__init__(new_files, transforms=transforms)\n", - " samples, slugs = make_dataset(new_files)\n", - " self.samples = samples\n", - " self.slugs = slugs\n", - " self.loader = pil_loader\n", - " self.root = 'file dataset'\n", - " def __len__(self):\n", - " return(len(self.samples))\n", - " \n", - " def __getitem__(self, index):\n", - " path, target = self.samples[index]\n", - " sample = self.loader(path)\n", - " if self.transforms is not None:\n", - " sample = self.transforms(sample)\n", - " return sample, target" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:26:15.021137Z", - "iopub.status.busy": "2021-05-17T23:26:15.020930Z", - "iopub.status.idle": "2021-05-17T23:26:15.024133Z", - "shell.execute_reply": "2021-05-17T23:26:15.023597Z", - "shell.execute_reply.started": "2021-05-17T23:26:15.021117Z" - } - }, - "outputs": [], - 
"source": [ - "crafted = DatasetImagePaths(new_files)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:27:10.327359Z", - "iopub.status.busy": "2021-05-17T23:27:10.327061Z", - "iopub.status.idle": "2021-05-17T23:27:10.331376Z", - "shell.execute_reply": "2021-05-17T23:27:10.330348Z", - "shell.execute_reply.started": "2021-05-17T23:27:10.327324Z" - } - }, - "outputs": [], - "source": [ - "if len(crafted) > 0:\n", - " crafted[0][0].show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Okay, that seems to work decently. Test with transforms, which I will just find in CLIP source code and copy over, to prevent having to import CLIP in this executor." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:27:10.532077Z", - "iopub.status.busy": "2021-05-17T23:27:10.531910Z", - "iopub.status.idle": "2021-05-17T23:27:10.535139Z", - "shell.execute_reply": "2021-05-17T23:27:10.534199Z", - "shell.execute_reply.started": "2021-05-17T23:27:10.532056Z" - } - }, - "outputs": [], - "source": [ - "from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:27:10.672197Z", - "iopub.status.busy": "2021-05-17T23:27:10.672025Z", - "iopub.status.idle": "2021-05-17T23:27:10.675311Z", - "shell.execute_reply": "2021-05-17T23:27:10.674703Z", - "shell.execute_reply.started": "2021-05-17T23:27:10.672178Z" - } - }, - "outputs": [], - "source": [ - "def clip_transform(n_px):\n", - " return Compose([\n", - " Resize(n_px, interpolation=Image.BICUBIC),\n", - " CenterCrop(n_px),\n", - " ToTensor(),\n", - " Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),\n", - " ])" - ] - }, - { - "cell_type": "code", - 
"execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:27:10.783218Z", - "iopub.status.busy": "2021-05-17T23:27:10.783066Z", - "iopub.status.idle": "2021-05-17T23:27:10.785719Z", - "shell.execute_reply": "2021-05-17T23:27:10.785213Z", - "shell.execute_reply.started": "2021-05-17T23:27:10.783202Z" - } - }, - "outputs": [], - "source": [ - "crafted_transformed = DatasetImagePaths(new_files, clip_transform(224))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:27:10.914257Z", - "iopub.status.busy": "2021-05-17T23:27:10.914086Z", - "iopub.status.idle": "2021-05-17T23:27:10.916361Z", - "shell.execute_reply": "2021-05-17T23:27:10.915817Z", - "shell.execute_reply.started": "2021-05-17T23:27:10.914238Z" - } - }, - "outputs": [], - "source": [ - "# crafted_transformed[0][0].shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:27:11.049878Z", - "iopub.status.busy": "2021-05-17T23:27:11.049661Z", - "iopub.status.idle": "2021-05-17T23:27:11.052757Z", - "shell.execute_reply": "2021-05-17T23:27:11.052083Z", - "shell.execute_reply.started": "2021-05-17T23:27:11.049856Z" - } - }, - "outputs": [], - "source": [ - "# to_pil = torchvision.transforms.ToPILImage()\n", - "# img = to_pil(crafted_transformed[0][0])\n", - "# img.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Put that all together, and wrap in a DataLoader for batching. In future, need to figure out how to pick batch size and number of workers programmatically bsed on device capabilities." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:27:11.510935Z", - "iopub.status.busy": "2021-05-17T23:27:11.510756Z", - "iopub.status.idle": "2021-05-17T23:27:11.514316Z", - "shell.execute_reply": "2021-05-17T23:27:11.513561Z", - "shell.execute_reply.started": "2021-05-17T23:27:11.510917Z" - } - }, - "outputs": [], - "source": [ - "def crafter(new_files, device, batch_size=128, num_workers=4): \n", - " with torch.no_grad():\n", - " imagefiles=DatasetImagePaths(new_files, clip_transform(224))\n", - " img_loader=torch.utils.data.DataLoader(imagefiles, batch_size=batch_size, shuffle=False, num_workers=num_workers)\n", - " return(img_loader)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:27:11.876682Z", - "iopub.status.busy": "2021-05-17T23:27:11.876512Z", - "iopub.status.idle": "2021-05-17T23:27:11.880238Z", - "shell.execute_reply": "2021-05-17T23:27:11.879080Z", - "shell.execute_reply.started": "2021-05-17T23:27:11.876665Z" - } - }, - "outputs": [], - "source": [ - "img_loader = crafter(new_files, device)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:27:12.182811Z", - "iopub.status.busy": "2021-05-17T23:27:12.182645Z", - "iopub.status.idle": "2021-05-17T23:27:12.186305Z", - "shell.execute_reply": "2021-05-17T23:27:12.185653Z", - "shell.execute_reply.started": "2021-05-17T23:27:12.182794Z" - } - }, - "outputs": [], - "source": [ - "img_loader" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Encoder\n", - "\n", - "CLIP wrapper takes batched tensors or text queries and returns batched 512-dim vectors. size of batch depends on GPU, but if we're putting all that on a server anyway it's a matter of accounting. Does batching go here though? 
Or in the crafter?\n", - "\n", - "cool thing here is we can use one encoder for both image and text, just check type on the way in. but first probably keep it simple and make two functions.\n", - "\n", - "could index previous queries as vectors in a different map and use for predictive/history -- keep a little database of previous queries already in vector format and their ranked NNs, so that the user can see history offline?" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:27:13.353998Z", - "iopub.status.busy": "2021-05-17T23:27:13.353324Z", - "iopub.status.idle": "2021-05-17T23:27:14.660546Z", - "shell.execute_reply": "2021-05-17T23:27:14.659900Z", - "shell.execute_reply.started": "2021-05-17T23:27:13.353916Z" - } - }, - "outputs": [], - "source": [ - "import clip\n", - "from tqdm import tqdm\n", - "model, _ = clip.load(\"ViT-B/32\", device)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:27:14.661653Z", - "iopub.status.busy": "2021-05-17T23:27:14.661465Z", - "iopub.status.idle": "2021-05-17T23:27:14.665037Z", - "shell.execute_reply": "2021-05-17T23:27:14.664655Z", - "shell.execute_reply.started": "2021-05-17T23:27:14.661618Z" - } - }, - "outputs": [], - "source": [ - "def image_encoder(img_loader, device):\n", - " image_embeddings = torch.tensor(()).to(device)\n", - " with torch.no_grad():\n", - " for images, labels in tqdm(img_loader):\n", - " batch_features = model.encode_image(images)\n", - " image_embeddings = torch.cat((image_embeddings, batch_features)).to(device)\n", - " \n", - " image_embeddings = image_embeddings / image_embeddings.norm(dim=-1, keepdim=True)\n", - " return(image_embeddings)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:27:15.166413Z", - "iopub.status.busy": 
"2021-05-17T23:27:15.166108Z", - "iopub.status.idle": "2021-05-17T23:27:15.300321Z", - "shell.execute_reply": "2021-05-17T23:27:15.299842Z", - "shell.execute_reply.started": "2021-05-17T23:27:15.166374Z" - } - }, - "outputs": [], - "source": [ - "new_embeddings = image_encoder(img_loader, device)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:27:16.366074Z", - "iopub.status.busy": "2021-05-17T23:27:16.365827Z", - "iopub.status.idle": "2021-05-17T23:27:16.370513Z", - "shell.execute_reply": "2021-05-17T23:27:16.369708Z", - "shell.execute_reply.started": "2021-05-17T23:27:16.366042Z" - } - }, - "outputs": [], - "source": [ - "def text_encoder(text, device):\n", - " with torch.no_grad():\n", - " text = clip.tokenize(text).to(device)\n", - " text_features = model.encode_text(text)\n", - " text_features = text_features / text_features.norm(dim=-1, keepdim=True)\n", - " return(text_features)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Indexer\n", - "\n", - "Annoy treemap or FAISS or other solutions. Given a dataset of tensors, returns a dictionary or database or treemap structure, something that is searchable for later. It would be nice to be able to diff this somehow, or make sure that it's up-to-date. Maybe keeping two copies is okay? One for backup and quick-searching, one for main search once it's indexed any new images. \n", - "\n", - "This executor `needs` both Encoder and Loader to send it the new and old vectors, respectively. So it needs to be preceded by some kind of **join_all** component that can makesure we're not missing new data before handing it over to the indexer. 
Hm" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:27:22.830650Z", - "iopub.status.busy": "2021-05-17T23:27:22.829915Z", - "iopub.status.idle": "2021-05-17T23:27:22.835413Z", - "shell.execute_reply": "2021-05-17T23:27:22.834944Z", - "shell.execute_reply.started": "2021-05-17T23:27:22.830565Z" - } - }, - "outputs": [], - "source": [ - "root = Path(args['path'])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:27:23.548128Z", - "iopub.status.busy": "2021-05-17T23:27:23.547975Z", - "iopub.status.idle": "2021-05-17T23:27:23.551213Z", - "shell.execute_reply": "2021-05-17T23:27:23.550679Z", - "shell.execute_reply.started": "2021-05-17T23:27:23.548112Z" - } - }, - "outputs": [], - "source": [ - "def join_all(db, new_files, new_embeddings):\n", - " for i, file in enumerate(new_files):\n", - " path, slug = file\n", - " start = len(db)\n", - " index = i + start\n", - " archive_db[slug] = {\n", - " 'slug': slug,\n", - " 'fpath': path,\n", - " 'embed': new_embeddings[i],\n", - " 'index': index\n", - " }\n", - " return(db)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:27:26.689841Z", - "iopub.status.busy": "2021-05-17T23:27:26.689681Z", - "iopub.status.idle": "2021-05-17T23:27:26.692632Z", - "shell.execute_reply": "2021-05-17T23:27:26.691974Z", - "shell.execute_reply.started": "2021-05-17T23:27:26.689825Z" - } - }, - "outputs": [], - "source": [ - "db = join_all(archive_db,\n", - " new_files,\n", - " new_embeddings\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:27:27.321954Z", - "iopub.status.busy": "2021-05-17T23:27:27.321741Z", - "iopub.status.idle": "2021-05-17T23:27:27.326550Z", - "shell.execute_reply": 
"2021-05-17T23:27:27.325029Z", - "shell.execute_reply.started": "2021-05-17T23:27:27.321935Z" - } - }, - "outputs": [], - "source": [ - "len(db)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "And build treemap" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:27:28.602453Z", - "iopub.status.busy": "2021-05-17T23:27:28.601731Z", - "iopub.status.idle": "2021-05-17T23:27:28.613957Z", - "shell.execute_reply": "2021-05-17T23:27:28.611655Z", - "shell.execute_reply.started": "2021-05-17T23:27:28.602368Z" - } - }, - "outputs": [], - "source": [ - "from annoy import AnnoyIndex" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:27:29.075028Z", - "iopub.status.busy": "2021-05-17T23:27:29.074813Z", - "iopub.status.idle": "2021-05-17T23:27:29.078199Z", - "shell.execute_reply": "2021-05-17T23:27:29.077644Z", - "shell.execute_reply.started": "2021-05-17T23:27:29.075010Z" - } - }, - "outputs": [], - "source": [ - "def build_treemap(db):\n", - " treemap = AnnoyIndex(512, 'angular')\n", - " for v in db.values():\n", - " treemap.add_item(v['index'], v['embed'])\n", - "\n", - " # Build the treemap, with 5 trees rn\n", - " treemap.build(5)\n", - "\n", - " return(treemap)\n", - " " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:27:29.615962Z", - "iopub.status.busy": "2021-05-17T23:27:29.615800Z", - "iopub.status.idle": "2021-05-17T23:27:47.259986Z", - "shell.execute_reply": "2021-05-17T23:27:47.259488Z", - "shell.execute_reply.started": "2021-05-17T23:27:29.615943Z" - } - }, - "outputs": [], - "source": [ - "t = build_treemap(db)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:27:47.261342Z", - 
"iopub.status.busy": "2021-05-17T23:27:47.261093Z", - "iopub.status.idle": "2021-05-17T23:27:47.265327Z", - "shell.execute_reply": "2021-05-17T23:27:47.264924Z", - "shell.execute_reply.started": "2021-05-17T23:27:47.261322Z" - } - }, - "outputs": [], - "source": [ - "t.get_n_items(), t.get_n_trees()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:27:47.266399Z", - "iopub.status.busy": "2021-05-17T23:27:47.266168Z", - "iopub.status.idle": "2021-05-17T23:27:47.269406Z", - "shell.execute_reply": "2021-05-17T23:27:47.269053Z", - "shell.execute_reply.started": "2021-05-17T23:27:47.266382Z" - } - }, - "outputs": [], - "source": [ - "def save_archives(root, treemap, db):\n", - " dbpath = root/'memery.pt'\n", - " if dbpath.exists():\n", - "# dbpath.rename(root/'memery-bak.pt')\n", - " dbpath.unlink()\n", - " torch.save(db, dbpath)\n", - " \n", - " treepath = root/'memery.ann'\n", - " if treepath.exists():\n", - "# treepath.rename(root/'memery-bak.ann')\n", - " treepath.unlink()\n", - " treemap.save(str(treepath))\n", - " \n", - " return(str(dbpath), str(treepath))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:27:47.270195Z", - "iopub.status.busy": "2021-05-17T23:27:47.270078Z", - "iopub.status.idle": "2021-05-17T23:27:47.361769Z", - "shell.execute_reply": "2021-05-17T23:27:47.361432Z", - "shell.execute_reply.started": "2021-05-17T23:27:47.270180Z" - } - }, - "outputs": [], - "source": [ - "save_archives(root, t, db)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Ranker\n", - "\n", - "Takes a query and an index and finds the nearest neighbors or most similar scores. Ideally this is just a simple Annoy `get_nns_by_vector`, or in the simple case a similarity score across all the vectors." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:27:55.387079Z", - "iopub.status.busy": "2021-05-17T23:27:55.386363Z", - "iopub.status.idle": "2021-05-17T23:27:55.397260Z", - "shell.execute_reply": "2021-05-17T23:27:55.394454Z", - "shell.execute_reply.started": "2021-05-17T23:27:55.386997Z" - } - }, - "outputs": [], - "source": [ - "def ranker(query_vec, treemap):\n", - " nn_indexes = treemap.get_nns_by_vector(query_vec[0], treemap.get_n_items())\n", - " return(nn_indexes)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:26:15.001671Z", - "iopub.status.busy": "2021-05-17T23:26:15.001534Z", - "iopub.status.idle": "2021-05-17T23:26:15.004383Z", - "shell.execute_reply": "2021-05-17T23:26:15.003536Z", - "shell.execute_reply.started": "2021-05-17T23:26:15.001654Z" - } - }, - "outputs": [], - "source": [ - "from IPython.display import Image as IMG" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:27:56.008469Z", - "iopub.status.busy": "2021-05-17T23:27:56.008293Z", - "iopub.status.idle": "2021-05-17T23:27:56.012267Z", - "shell.execute_reply": "2021-05-17T23:27:56.011056Z", - "shell.execute_reply.started": "2021-05-17T23:27:56.008450Z" - } - }, - "outputs": [], - "source": [ - "def printi(filenames, n=5):\n", - " for im in filenames[:n]:\n", - " display(IMG(filename=im[0], width=200))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:27:55.520152Z", - "iopub.status.busy": "2021-05-17T23:27:55.519884Z", - "iopub.status.idle": "2021-05-17T23:27:55.524543Z", - "shell.execute_reply": "2021-05-17T23:27:55.523632Z", - "shell.execute_reply.started": "2021-05-17T23:27:55.520126Z" - } - }, - "outputs": [], - "source": [ - "def 
rank_5(text):\n", - " query_vec = text_encoder(text, device)\n", - " indexes = ranker(query_vec, t)\n", - " filenames =[[v['fpath'] for k,v in db.items() if v['index'] == ind] for ind in indexes]\n", - " return(filenames)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:27:56.551897Z", - "iopub.status.busy": "2021-05-17T23:27:56.551621Z", - "iopub.status.idle": "2021-05-17T23:27:57.496956Z", - "shell.execute_reply": "2021-05-17T23:27:57.496325Z", - "shell.execute_reply.started": "2021-05-17T23:27:56.551836Z" - } - }, - "outputs": [], - "source": [ - "printi(rank_5(args['query']))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "I think we have to call that a success!" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Gateway\n", - "\n", - "Takes a query and processes it through either Indexing Flow or Querying Flow, passing along arguments. The main entrypoint for each iteration of the index/query process.\n", - "\n", - "Querying Flow can technically process either text or image search, becuase the CLIP encoder will put them into the same embedding space. So we might as well build in a method for either, and make it available to the user, since it's impressive and useful and relatively easy to build.\n", - "\n", - "Eventually the Gateway process probably needs to be quite complicated, for serving all the different users and for delivering REST APIs to different clients. For now we will run this locally, in a notebook. Then build out a GUI from there using `mediapy` or `widgets`. That should reveal the basic necessities of the UI, and then we can separate out the GUI client from the server." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:28:07.161111Z", - "iopub.status.busy": "2021-05-17T23:28:07.160442Z", - "iopub.status.idle": "2021-05-17T23:28:07.173215Z", - "shell.execute_reply": "2021-05-17T23:28:07.171845Z", - "shell.execute_reply.started": "2021-05-17T23:28:07.161028Z" - } - }, - "outputs": [], - "source": [ - "def indexFlow(path):\n", - " root = Path(path)\n", - " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - " \n", - " filepaths = get_image_files(root)\n", - " archive_db, new_files = files_archive_loader(filepaths, root, device)\n", - " print(f\"Loaded {len(archive_db)} encodings\")\n", - " print(f\"Encoding {len(new_files)} new images\")\n", - " crafted_files = crafter(new_files, device)\n", - " new_embeddings = image_encoder(crafted_files, device)\n", - " \n", - " db = join_all(archive_db, new_files, new_embeddings)\n", - " print(\"Building treemap\")\n", - " t = build_treemap(db)\n", - " \n", - " print(f\"Saving {len(db)}images\")\n", - " save_paths = save_archives(root, t, db)\n", - " print(\"Done\")\n", - " return(save_paths)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:28:10.313351Z", - "iopub.status.busy": "2021-05-17T23:28:10.313166Z", - "iopub.status.idle": "2021-05-17T23:28:28.543108Z", - "shell.execute_reply": "2021-05-17T23:28:28.542515Z", - "shell.execute_reply.started": "2021-05-17T23:28:10.313334Z" - } - }, - "outputs": [], - "source": [ - "save_paths = indexFlow(args['path'])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:28:28.544063Z", - "iopub.status.busy": "2021-05-17T23:28:28.543945Z", - "iopub.status.idle": "2021-05-17T23:28:28.547992Z", - "shell.execute_reply": "2021-05-17T23:28:28.547123Z", - 
"shell.execute_reply.started": "2021-05-17T23:28:28.544047Z" - } - }, - "outputs": [], - "source": [ - "save_paths" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "To search:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:08:36.054349Z", - "iopub.status.busy": "2021-05-17T23:08:36.054132Z", - "iopub.status.idle": "2021-05-17T23:08:36.059544Z", - "shell.execute_reply": "2021-05-17T23:08:36.058990Z", - "shell.execute_reply.started": "2021-05-17T23:08:36.054324Z" - } - }, - "outputs": [], - "source": [ - "def queryFlow(path, query): \n", - " root = Path(path)\n", - " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - " \n", - " dbpath = root/'memery.pt'\n", - " db = db_loader(dbpath)\n", - " treepath = root/'memery.ann'\n", - " treemap = treemap_loader(treepath)\n", - " \n", - " if treemap == None or db == {}:\n", - " dbpath, treepath = indexFlow(root)\n", - " treemap = treemap_loader(treepath)\n", - " db = file\n", - " \n", - " print(f\"Searching {len(db)} images\")\n", - " query_vec = text_encoder(query, device)\n", - " indexes = ranker(query_vec, treemap)\n", - " ranked_files = [[v['fpath'] for k,v in db.items() if v['index'] == ind] for ind in indexes]\n", - " return(ranked_files)\n", - "\n", - " " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:12:15.974818Z", - "iopub.status.busy": "2021-05-17T23:12:15.974655Z", - "iopub.status.idle": "2021-05-17T23:12:16.791693Z", - "shell.execute_reply": "2021-05-17T23:12:16.791335Z", - "shell.execute_reply.started": "2021-05-17T23:12:15.974800Z" - } - }, - "outputs": [], - "source": [ - "ranked = queryFlow(args['path'], 'dog')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2021-05-17T23:12:16.792617Z", - 
"iopub.status.busy": "2021-05-17T23:12:16.792501Z", - "iopub.status.idle": "2021-05-17T23:12:16.808254Z", - "shell.execute_reply": "2021-05-17T23:12:16.807905Z", - "shell.execute_reply.started": "2021-05-17T23:12:16.792601Z" - } - }, - "outputs": [], - "source": [ - "printi(ranked)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Interactive process\n", - "Currently the objective is to take the following inputs:\n", - "- a location with images\n", - "- a text or image query,\n", - "\n", - "and return the following outputs:\n", - "- a list of image files within that location ranked by similarity to that query,\n", - "\n", - "with a minimum of duplicated effort, and a general ease-of-use for both the programmer and the casual API user." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## TODO:\n", - "\n", - "- Cleanup repo\n", - "- Rough interactive GUI\n", - "\n", - "- Optimize the image loader and number of trees based on memory and db size\n", - "- Type annotations\n", - "\n", - "## DONE:\n", - "- _Code for joining archived data to new data_\n", - "- _Code for saving indexes to archive_\n", - "- _Flows_\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.7" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/notebooks/memery b/notebooks/memery deleted file mode 120000 index 1659db9..0000000 --- a/notebooks/memery +++ /dev/null @@ -1 +0,0 @@ -../memery/ \ No newline 
at end of file diff --git a/notebooks/memery.ipynb b/notebooks/memery.ipynb deleted file mode 100644 index 5854a20..0000000 --- a/notebooks/memery.ipynb +++ /dev/null @@ -1,91 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from memery.core import Memery" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "memery = Memery()\n", - "ranked = memery.query_flow('../images', 'dad joke')\n", - "\n", - "print(ranked[:5])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "memery = Memery()\n", - "root = '../images/'\n", - "db = memery.get_db(root + 'memery.pt')\n", - "index = memery.get_index(root + 'memery.ann')\n", - "model = memery.get_model()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "memery.index_flow(root)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "memery.reset_state()\n", - "memery.model = None" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "memery.query_flow(root, 'Wow its already working')" - ] - } - ], - "metadata": { - "interpreter": { - "hash": "deeee0b52e76b5e3a563dfd39c9570f6111f9f254cd04b55dab6af9643751b0b" - }, - "kernelspec": { - "display_name": "Python 3.9.12 ('memery-OXFjyqC6-py3.9')", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.6" - }, - "orig_nbformat": 4 - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/pyproject.toml b/pyproject.toml index ed57053..6ec5e53 100644 --- 
a/pyproject.toml +++ b/pyproject.toml @@ -5,20 +5,19 @@ description = "" authors = ["deepfates ", "wkrettek "] [tool.poetry.dependencies] -python = "^3.9" +python = ">=3.10,<3.13" torch = "^2.2.0" annoy = "^1.17.0" torchvision = "^0.17.0" tqdm = "^4.64.0" -Pillow = "^9.1.0" -typer = "^0.4.1" -streamlit = "1.3.1" +Pillow = "^10.0.0" +typer = ">=0.12,<1.0" +streamlit = "^1.30.0" clip = {git = "https://github.com/openai/CLIP", rev = "main"} ftfy = "^6.1.1" regex = "^2022.4.24" -altair = "^4.0.0" +altair = "^5.0.0" numpy = "^1.24.0" -protobuf = "^3.20.0" [tool.poetry.scripts] memery = "memery.cli:main" @@ -26,7 +25,16 @@ memery = "memery.cli:main" [tool.poetry.dev-dependencies] ipywidgets = "^7.7.0" ipython = "^8.3.0" +pytest = "^8.0.0" +flake8 = "^7.0.0" [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" + +[tool.pytest.ini_options] +testpaths = ["tests"] +markers = [ + "integration: end-to-end tests that load the CLIP model (slow, ~338MB download on first run)", +] +addopts = "-m 'not integration'" diff --git a/notebooks/__init__.py b/tests/__init__.py similarity index 100% rename from notebooks/__init__.py rename to tests/__init__.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..16bad55 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,56 @@ +"""Shared fixtures for the memery test suite. + +The bundled ``images/`` folder is the canonical fixture corpus, but tests +shouldn't touch it directly — running ``index_flow`` writes ``memery.ann`` +and ``memery.pt`` into whatever folder you point it at, and we don't want +those artifacts ending up in the working tree. So most fixtures copy a +subset into a tmp dir and let pytest clean it up. 
+""" +from __future__ import annotations + +import shutil +from pathlib import Path + +import pytest + +REPO_ROOT = Path(__file__).resolve().parent.parent +BUNDLED_IMAGES = REPO_ROOT / "images" + + +@pytest.fixture(scope="session") +def bundled_images_dir() -> Path: + """Read-only path to the bundled images/ folder.""" + assert BUNDLED_IMAGES.is_dir(), f"missing fixture corpus: {BUNDLED_IMAGES}" + return BUNDLED_IMAGES + + +@pytest.fixture +def tiny_image_dir(tmp_path: Path, bundled_images_dir: Path) -> Path: + """A tmp dir with ~3 valid images plus the corrupt fixture. + + Fast enough to encode end-to-end with real CLIP under the integration + marker without the test suite taking forever. + """ + dst = tmp_path / "tiny" + dst.mkdir() + valid = [ + "memes/Wholesome-Meme-68.jpg", # the "dad joke" winner from earlier + "memes/Wholesome-Meme-1.jpg", + "memes/cute-dog-with-cupcake-P9E2YL5-min.jpg", + ] + for rel in valid: + src = bundled_images_dir / rel + assert src.exists(), f"fixture moved or missing: {src}" + shutil.copy(src, dst / src.name) + # include the deliberately corrupt file so loader tests can exercise it + corrupt = bundled_images_dir / "memes" / "corrupted-file.jpeg" + if corrupt.exists(): + shutil.copy(corrupt, dst / corrupt.name) + return dst + + +@pytest.fixture +def empty_image_dir(tmp_path: Path) -> Path: + d = tmp_path / "empty" + d.mkdir() + return d diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..bec942a --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,28 @@ +"""CLI smoke tests — exercises typer wiring without touching CLIP.""" +from __future__ import annotations + +from typer.testing import CliRunner + +from memery.cli import app + +runner = CliRunner() + + +def test_cli_help_lists_all_commands(): + result = runner.invoke(app, ["--help"]) + assert result.exit_code == 0 + for cmd in ("recall", "build", "serve", "purge"): + assert cmd in result.stdout + + +def test_recall_help_exposes_negative_flag(): + 
result = runner.invoke(app, ["recall", "--help"]) + assert result.exit_code == 0 + assert "--negative-text" in result.stdout or "-nt" in result.stdout + + +def test_purge_is_idempotent_on_empty_dir(empty_image_dir): + """purge on a folder with no index files should succeed quietly.""" + result = runner.invoke(app, ["purge", str(empty_image_dir)]) + assert result.exit_code == 0 + assert "Purged" in result.stdout diff --git a/tests/test_core.py b/tests/test_core.py new file mode 100644 index 0000000..2d27ee6 --- /dev/null +++ b/tests/test_core.py @@ -0,0 +1,89 @@ +"""End-to-end tests for the index/query pipeline. + +Marked ``integration`` because they download CLIP weights on first run +(~338MB) and run the model. Skipped in default test runs; opt in with +``pytest -m integration``. +""" +from __future__ import annotations + +from pathlib import Path + +import pytest + +from memery.core import Memery + +pytestmark = pytest.mark.integration + + +def test_index_flow_writes_archives(tiny_image_dir: Path): + m = Memery(root=str(tiny_image_dir)) + db_path, tree_path = m.index_flow(str(tiny_image_dir)) + assert Path(db_path).exists() + assert Path(tree_path).exists() + + +def test_query_flow_returns_ranked_paths(tiny_image_dir: Path): + m = Memery(root=str(tiny_image_dir)) + ranked = m.query_flow(str(tiny_image_dir), query="dog") + assert isinstance(ranked, list) + assert len(ranked) >= 1 + # the cupcake-dog fixture should rank above the wholesome-meme texts + assert "cute-dog-with-cupcake-P9E2YL5-min.jpg" in ranked[0] + + +def test_query_flow_no_query_returns_empty(tiny_image_dir: Path): + m = Memery(root=str(tiny_image_dir)) + out = m.query_flow(str(tiny_image_dir)) + # current behaviour: returns empty string when there's nothing to search for + assert out == "" or out == [] + + +def test_clean_removes_index_files(tiny_image_dir: Path): + m = Memery(root=str(tiny_image_dir)) + m.index_flow(str(tiny_image_dir)) + assert (tiny_image_dir / "memery.ann").exists() + 
m.clean(str(tiny_image_dir)) + assert not (tiny_image_dir / "memery.ann").exists() + assert not (tiny_image_dir / "memery.pt").exists() + + +def test_index_flow_keeps_paths_aligned_when_decode_fails_mid_batch( + tiny_image_dir: Path, monkeypatch +): + """Files that pass verify_image but fail at decode time used to corrupt the + db: safe_collate would drop the failed item from a batch, but the encoder + didn't tell the indexer which item was dropped, so subsequent files got + tagged with the wrong embedding (and the last file would IndexError). + + Simulate that case by monkey-patching pil_loader to return None for a + specific file. The resulting db must contain only successfully-decoded + files, with paths correctly correlated to their embeddings. + """ + import torch + + from memery import crafter + + real_loader = crafter.pil_loader + + def flaky_loader(path: str): + if "Wholesome-Meme-1.jpg" in path: + return None + return real_loader(path) + + monkeypatch.setattr(crafter, "pil_loader", flaky_loader) + + m = Memery(root=str(tiny_image_dir)) + db_path, _ = m.index_flow(str(tiny_image_dir)) + + db = torch.load(db_path, map_location="cpu", weights_only=False) + paths = [entry["fpath"] for entry in db.values()] + + # The deliberately-flaky file must NOT appear in the db + assert not any("Wholesome-Meme-1.jpg" in p for p in paths) + # The other fixtures must still be there + assert any("Wholesome-Meme-68.jpg" in p for p in paths) + assert any("cute-dog-with-cupcake" in p for p in paths) + # And every entry must have a 512-dim embedding (i.e. 
nothing got + # IndexError'd into a half-built state) + for entry in db.values(): + assert entry["embed"].shape == (512,) diff --git a/tests/test_crafter.py b/tests/test_crafter.py new file mode 100644 index 0000000..fd5d399 --- /dev/null +++ b/tests/test_crafter.py @@ -0,0 +1,40 @@ +"""Tests for crafter — pil_loader fast path and safe_collate.""" +from __future__ import annotations + +from pathlib import Path + +import pytest +import torch + +from memery import crafter + + +def test_pil_loader_returns_rgb_image(bundled_images_dir: Path): + f = bundled_images_dir / "memes" / "Wholesome-Meme-1.jpg" + img = crafter.pil_loader(str(f)) + assert img is not None + assert img.mode == "RGB" + + +def test_pil_loader_returns_none_on_corrupt(bundled_images_dir: Path): + f = bundled_images_dir / "memes" / "corrupted-file.jpeg" + if not f.exists(): + pytest.skip("corrupt fixture missing") + out = crafter.pil_loader(str(f)) + assert out is None + + +def test_safe_collate_drops_none_items(): + """A batch with mixed None and real items collates only the real ones.""" + a = (torch.zeros(3, 8, 8), 0) + b = (torch.zeros(3, 8, 8), 1) + out = crafter.safe_collate([a, None, b]) + images, labels = out + assert images.shape == (2, 3, 8, 8) + assert labels.tolist() == [0, 1] + + +def test_safe_collate_returns_none_on_all_failed(): + """A batch where every item failed yields None — encoder skips it.""" + assert crafter.safe_collate([None, None]) is None + assert crafter.safe_collate([]) is None diff --git a/tests/test_loader.py b/tests/test_loader.py new file mode 100644 index 0000000..d727e33 --- /dev/null +++ b/tests/test_loader.py @@ -0,0 +1,54 @@ +"""Unit tests for memery.loader — no CLIP, no torch model loading.""" +from __future__ import annotations + +from pathlib import Path + +from memery import loader + + +def test_hash_path_combines_stem_and_mtime(tmp_path: Path): + f = tmp_path / "hello.jpg" + f.write_bytes(b"\x00") + h = loader.hash_path(f) + assert h.startswith("hello_") + # 
mtime portion is everything after the underscore and is an integer-ish string + _, mtime = h.split("_", 1) + assert mtime.isdigit() + + +def test_get_image_files_filters_extensions(tmp_path: Path): + (tmp_path / "a.jpg").write_bytes(b"\x00") + (tmp_path / "b.PNG").write_bytes(b"\x00") # uppercase suffix is NOT matched + (tmp_path / "notes.txt").write_text("nope") + sub = tmp_path / "nested" + sub.mkdir() + (sub / "c.png").write_bytes(b"\x00") + + paths = loader.get_image_files(tmp_path) + names = sorted(p.name for p, _ in paths) + # current behaviour: extension matching is case-sensitive on the suffix set + assert "a.jpg" in names + assert "c.png" in names + assert "notes.txt" not in names + + +def test_get_valid_images_skips_corrupt(tiny_image_dir: Path): + valid = loader.get_valid_images(tiny_image_dir) + names = {Path(p).name for p, _ in valid} + # corrupt fixture must be excluded + assert "corrupted-file.jpeg" not in names + # the three real fixtures must all be there + assert "Wholesome-Meme-68.jpg" in names + assert "Wholesome-Meme-1.jpg" in names + assert "cute-dog-with-cupcake-P9E2YL5-min.jpg" in names + + +def test_treemap_loader_returns_none_on_missing(tmp_path: Path): + assert loader.treemap_loader(tmp_path / "nope.ann") is None + + +def test_db_loader_returns_empty_dict_on_missing(tmp_path: Path): + import torch + + out = loader.db_loader(tmp_path / "nope.pt", torch.device("cpu")) + assert out == {}