Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 18 additions & 11 deletions .github/workflows/python-app.yml
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# This workflow will install Python dependencies, run tests and lint with a single version of Python
# This workflow installs Python dependencies, lints, and runs the test suite.
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: Python application

on:
push:
branches: [ main ]
branches: [ main, rewrite ]
pull_request:
branches: [ main ]
branches: [ main, rewrite ]

permissions:
contents: read
Expand All @@ -18,19 +18,26 @@ jobs:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v3
- name: Set up Python 3.10
uses: actions/setup-python@v3
- uses: actions/checkout@v4
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: "3.10"
python-version: "3.12"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install flake8 pytest poetry
python -m pip install --upgrade pip poetry
# `poetry install` picks up dev-dependencies (pytest, flake8) from
# the [tool.poetry.dev-dependencies] table, so we don't need to
# pip-install them separately and risk version skew.
poetry install
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
poetry run flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude=.venv
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
poetry run flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --exclude=.venv
- name: Run tests
# Default pytest config skips the `integration` marker (CLIP download).
# Fast tests only here — they're enough to catch dead imports and
# regressions in the loader / CLI wiring.
run: poetry run pytest -v
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ tags
*.ann
*.pt

# local-only workspace state
.claude/
.venv/
.pytest_cache/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
Binary file removed docs/testimage.png
Binary file not shown.
Binary file removed graphs/embed_d.jpg
Binary file not shown.
Binary file removed graphs/embed_n.jpg
Binary file not shown.
Binary file removed graphs/mde_d.gif
Binary file not shown.
Binary file removed graphs/mde_n.gif
Binary file not shown.
Binary file removed graphs/normalized-d.jpg
Binary file not shown.
Binary file removed graphs/normalized-lg.jpg
Binary file not shown.
Binary file removed graphs/normalized.jpg
Binary file not shown.
Binary file removed graphs/plotted_ims.jpg
Binary file not shown.
46 changes: 31 additions & 15 deletions memery/cli.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import sys
from pathlib import Path
from typing import Optional

import typer
from memery.core import Memery

import memery
import streamlit.cli
from typing import Optional
from memery.core import Memery
# Sometimes you just want to be able to pipe information through the terminal. This is that command

app = typer.Typer()
Expand All @@ -13,28 +16,41 @@ def main():
@app.command()
def recall(
root: str = typer.Argument('.', help="Image folder to search"),
text: str = typer.Option(None, *("-t", "--text"), help="Text query"),
image: str = typer.Option(None, *("-i", "--image"), help="Filepath to image query") ,
number: int = typer.Option(10, *("-n", "--number"), help="Number of results to return")
) -> list[str]:
"""Search recursively over a folder from the command line"""
text: str = typer.Option(None, "-t", "--text", help="Text query"),
negative: str = typer.Option(None, "-nt", "--negative-text", help="Negative text query"),
image: str = typer.Option(None, "-i", "--image", help="Filepath to image query"),
number: int = typer.Option(10, "-n", "--number", help="Number of results to return"),
) -> None:
"""Search recursively over a folder from the command line."""
memery = Memery()
ranked = memery.query_flow(root, query=text, image_query=image)
ranked = memery.query_flow(root, query=text, negative_query=negative, image_query=image)
print(ranked[:number])

@app.command()
def serve(root: Optional[str] = typer.Argument(None)):
"""Runs the streamlit GUI in your browser"""
app_path = memery.__file__.replace('__init__.py','streamlit_app.py')
if root is None:
streamlit.cli.main(['run', app_path, './images'])
else:
streamlit.cli.main(['run', app_path, f'{root}'])
# Importing here so `memery --help` doesn't pay the streamlit import cost
from streamlit.web import cli as stcli

app_path = str(Path(memery.__file__).parent / "streamlit_app.py")
target_root = root if root is not None else "./images"
sys.argv = ["streamlit", "run", app_path, "--", target_root]
sys.exit(stcli.main())

@app.command()
def build(
root: str = typer.Argument('.'),
workers: int = typer.Option(default=0)
workers: int = typer.Option(
2,
"--workers", "-w",
help=(
"DataLoader workers for image preprocessing. Measured on macOS/MPS: "
"2 wins ~15% on realistic corpora (4000+ images). Higher counts hurt "
"because the GPU is a serial bottleneck and IPC overhead piles up. "
"Pass 0 to disable workers entirely (slightly faster on tiny folders, "
"or as a fallback if multiprocessing misbehaves in your environment)."
),
),
):
'''
Indexes the directory and all subdirectories
Expand Down
12 changes: 11 additions & 1 deletion memery/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,17 @@ def index_flow(self, root: str, num_workers=0) -> tuple[str, str]:
# Crafting and encoding
crafted_files = crafter.crafter(new_files, device, num_workers=num_workers)
model = self.get_model()
new_embeddings = encoder.image_encoder(crafted_files, device, model)
new_embeddings, surviving = encoder.image_encoder(crafted_files, device, model)

# `surviving` maps embedding rows back to indices into new_files;
# it can be shorter than new_files if any image failed to decode
# after passing verify_image (truncated files, weird color spaces).
# Filter new_files down to the same set so join_all can correlate
# by position without misalignment.
if len(surviving) != len(new_files):
dropped = len(new_files) - len(surviving)
print(f"Skipped {dropped} image(s) that failed to decode")
new_files = [new_files[int(i)] for i in surviving.tolist()]

# Reindexing
db = indexer.join_all(archive_db, new_files, new_embeddings)
Expand Down
126 changes: 88 additions & 38 deletions memery/crafter.py
Original file line number Diff line number Diff line change
@@ -1,71 +1,121 @@
from typing import Optional

import torch
from torch import Tensor, device
from torch.utils.data import DataLoader, default_collate
from torchvision.datasets import VisionDataset
from PIL import Image, ImageFile
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from torch.utils.data import DataLoader


def make_dataset(new_files: list[str]) -> tuple[list[str], list[str]]:
'''Returns a list of samples of a form (path_to_sample, class) and in
this case the class is just the filename'''
samples = []
slugs = []
for i, f in enumerate(new_files):
path, slug = f
samples.append((str(path), i))
slugs.append((slug, i))
return(samples, slugs)

def pil_loader(path: str) -> Image.Image:
ImageFile.LOAD_TRUNCATED_IMAGES = True # Allow truncated images
from PIL import Image, ImageFile


def make_dataset(new_files: list[tuple[str, str]]) -> list[tuple[str, int]]:
'''Returns a list of (path, index) pairs.

The previous implementation also built a parallel `slugs` list that
nothing downstream ever read — pure dead weight on every build.
'''
return [(str(path), i) for i, (path, _slug) in enumerate(new_files)]


def pil_loader(path: str) -> Optional[Image.Image]:
"""Open `path` and return an RGB PIL image, or None on failure.

Uses PIL's JPEG `draft` mode to decode at the smallest IDCT scale that's
still ≥ 256px. For large JPEGs this skips most of the inverse-DCT work
and saves a substantial amount of file I/O — measured ~1.6x speedup on
a realistic Downloads sample (~165KB median image size). For non-JPEG
formats `draft` is a no-op, so it's safe to leave on unconditionally.
"""
ImageFile.LOAD_TRUNCATED_IMAGES = True # tolerate partially-truncated files
try:
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
with open(path, 'rb') as f:
# open path as file to avoid ResourceWarning
# https://github.com/python-pillow/Pillow/issues/835
with open(path, 'rb') as f:
img = Image.open(f)
# Hint the JPEG decoder for partial-scale decode. 256 is chosen
# to stay safely above the 224 CLIP input size after CenterCrop.
try:
img.draft('RGB', (256, 256))
except (AttributeError, OSError):
pass
return img.convert('RGB')
except Exception as e:
print(f"Skipping image {path}: {e}")
return None


class DatasetImagePaths(VisionDataset):

def __init__(self, new_files, transforms = None):
super(DatasetImagePaths, self).__init__(new_files, transforms=transforms)
samples, slugs = make_dataset(new_files)
self.samples = samples
self.slugs = slugs

def __init__(self, new_files, transforms=None):
super().__init__(new_files, transforms=transforms)
self.samples = make_dataset(new_files)
self.loader = pil_loader
self.root = 'file dataset'

def __len__(self):
return(len(self.samples))
return len(self.samples)

def __getitem__(self, index):
path, target = self.samples[index]
try:
sample = self.loader(path)
if sample is not None:
if self.transforms is not None:
sample = self.transforms(sample)
return sample, target
if sample is None:
return None
if self.transforms is not None:
sample = self.transforms(sample)
return sample, target
except Exception as e:
print(f"Skipping file {path} due to error: {e}")
return None


def safe_collate(batch):
"""Drop items that failed to decode, then run the default collate.

Without this, a single `None` from `__getitem__` (e.g. a file that
passed `verify_image` but failed at decode time) would crash the entire
DataLoader with an unhelpful traceback. Returning `None` for an empty
batch lets the encoder loop skip it cleanly.
"""
batch = [b for b in batch if b is not None]
if not batch:
return None
return default_collate(batch)


def clip_transform(n_px: int) -> Compose:
return Compose([
Resize(n_px, interpolation=Image.BICUBIC),
CenterCrop(n_px),
ToTensor(),
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
Normalize((0.48145466, 0.4578275, 0.40821073),
(0.26862954, 0.26130258, 0.27577711)),
])

def crafter(new_files: list[str], device: device, batch_size: int=128, num_workers: int=4):

def crafter(new_files: list[tuple[str, str]], device: device,
batch_size: int = 128, num_workers: int = 0) -> DataLoader:
"""Build the DataLoader used to feed CLIP.

`num_workers=0` is the macOS default because DataLoader workers use
`spawn` (fork is unsafe with PyTorch), each costs ~2s of startup, and
on typical-sized images (~150KB) the parallelism gain is smaller than
the spawn overhead. For folders dominated by very large images
(multi-megabyte phone photos / scans), `--workers 4-8` can win
significantly — measured at scale, not at small N.
"""
with torch.no_grad():
imagefiles=DatasetImagePaths(new_files, clip_transform(224))
img_loader=DataLoader(imagefiles, batch_size=batch_size, shuffle=False, num_workers=num_workers)
return(img_loader)
imagefiles = DatasetImagePaths(new_files, clip_transform(224))
img_loader = DataLoader(
imagefiles,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
collate_fn=safe_collate,
)
return img_loader


def preproc(img: Tensor) -> Compose:
transformed = clip_transform(224)(img)
return(transformed)
def preproc(img: Image.Image) -> Tensor:
"""Apply the CLIP preprocessing pipeline to a single PIL image."""
return clip_transform(224)(img)
52 changes: 46 additions & 6 deletions memery/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,57 @@
def load_model(device: device) -> CLIP:
model, _ = clip.load("ViT-B/32", device, jit=False)
model = model.float()
return(model)
# Inference-only — disables any train-time behavior in submodules.
# CLIP ViT-B/32 has no batchnorm or dropout that fires here, so this is
# mostly defensive correctness, but it's free.
model.eval()
return model

def image_encoder(img_loader: DataLoader, device: device, model: CLIP):
image_embeddings = torch.tensor(()).to(device)
"""Encode a DataLoader's worth of images to L2-normalized CLIP features.

Returns ``(embeddings, labels)``:
* ``embeddings``: ``(K, 512)`` CPU tensor of unit-norm features, where
``K`` is the number of images that *successfully* survived decoding
and collation. May be smaller than ``len(dataset)`` if any items
were dropped by ``crafter.safe_collate``.
* ``labels``: ``(K,)`` CPU long tensor mapping each row of
``embeddings`` back to its original index in ``new_files``. The
caller uses this to keep file paths and embeddings aligned even
when some files failed to decode mid-batch.

Implementation notes:
* Features are accumulated in a Python list and concatenated once at
the end. The previous in-loop ``torch.cat`` was O(n²) in the number
of batches and on MPS paid a Metal kernel-launch cost per concat —
~15s wasted on a 702-batch run.
* The final tensor is moved to CPU here. Everything downstream (annoy,
torch.save) is CPU-only; leaving embeddings on MPS caused
``annoy.add_item`` to force a Metal ``waitUntilCompleted`` for every
single float, i.e. ~46M GPU syncs on an 89k-image library.
"""
feature_chunks = []
label_chunks = []
with torch.no_grad():
for images, labels in tqdm(img_loader):
batch_features = model.encode_image(images.to(device))
image_embeddings = torch.cat((image_embeddings, batch_features)).to(device)
for batch in tqdm(img_loader):
if batch is None:
# safe_collate returns None when every image in a batch
# failed to decode — skip rather than crash.
continue
images, labels = batch
feature_chunks.append(model.encode_image(images.to(device)))
label_chunks.append(labels)

if not feature_chunks:
return (
torch.empty((0, 512)),
torch.empty((0,), dtype=torch.long),
)

image_embeddings = torch.cat(feature_chunks, dim=0)
image_embeddings = image_embeddings / image_embeddings.norm(dim=-1, keepdim=True)
return(image_embeddings)
surviving_labels = torch.cat(label_chunks, dim=0)
return image_embeddings.cpu(), surviving_labels.cpu()

def text_encoder(text: str, device: device, model: CLIP):
with torch.no_grad():
Expand Down
Loading
Loading