Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion llmfoundry/command_utils/data_prep/convert_dataset_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def convert_dataset_hf(
)
loader = build_dataloader(
dataset=hf_dataset,
batch_size=512,
batch_size=1,
num_workers=num_workers,
)
samples = generate_samples(
Expand Down Expand Up @@ -405,6 +405,7 @@ def convert_dataset_hf(
columns=columns,
out=os.path.join(out_root, folder_split),
compression=compression,
size_limit="128mb",
) as out:
if denominator is not None:
for sample in tqdm(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ def convert_finetuning_dataset(
out=out,
compression=compression,
keep_local=keep_local,
size_limit="128mb",
) as out:
examples_removed = 0
for sample in tqdm(samples, desc=split_name):
Expand Down
12 changes: 5 additions & 7 deletions llmfoundry/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,13 +161,11 @@ def __iter__(self) -> Iterable[dict[str, NDArray]]:
)
iids = encoded['input_ids']
buffer = buffer + self.bos_tokens + iids + self.eos_tokens
while len(buffer) >= self.max_length:
concat_sample = buffer[:self.max_length]
buffer = buffer[self.max_length:] if self.should_wrap else []
yield {
# convert to ndarray to store in MDS format
'tokens': np.asarray(concat_sample, dtype=np.int32),
}
yield {
# convert to ndarray to store in MDS format
'tokens': np.asarray(buffer, dtype=np.int32),
}
buffer = []


def stream_remote_local_validate(
Expand Down
Empty file.
92 changes: 92 additions & 0 deletions scripts/data_prep/data_lib/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#% utils
def _banner(msg):
    """Print ``msg`` framed above and below by a rule of ``#`` characters.

    The rule is exactly ``len(msg)`` characters wide.
    """
    rule = "#" * len(msg)
    print(rule, msg, rule, sep="\n")

def str_rows_features(ds):
    """Summarise a split-keyed dataset mapping as a one-line string.

    Args:
        ds: Mapping from split name to dataset. It must contain a
            ``"train"`` entry that supports ``len()`` and exposes a
            ``features`` attribute.

    Returns:
        ``"rows: <n> features: <features>"`` where both values are taken
        from the ``"train"`` split.
    """
    train = ds["train"]
    # len(ds) on a split mapping counts the splits, not the rows, so the
    # row count must be taken from the same split the features come from.
    return f"rows: {len(train)} features: {train.features}"

def get_datasets():  # target_repo):
    """Return a freshly built registry of the known datasets.

    Each entry maps a short dataset name to a dict holding the Hugging Face
    repo ids of the ``original`` and ``decontaminated`` variants and the
    dataset ``kind`` (``"instruct"`` or ``"pretrain"``). Instruct entries
    also carry a ``template`` name; ``finemath`` carries a ``ds_name``.

    A new dict is built on every call, so callers may mutate the result.
    """
    tulu = dict(
        original="allenai/tulu-3-sft-olmo-2-mixture",
        decontaminated="LocalResearchGroup/split-tulu-3-sft-olmo-2-mixture-decontaminated",
        kind="instruct",
        template="tulu-with-template",
    )
    numina = dict(
        original="AI-MO/NuminaMath-CoT",
        decontaminated="LocalResearchGroup/split-NuminaMath-CoT-decontaminated",
        kind="instruct",
        template="numina-with-template",
    )
    glaive = dict(
        original="glaiveai/glaive-code-assistant-v3",
        decontaminated="LocalResearchGroup/split-glaive-code-assistant-v3-decontaminated",
        kind="instruct",
        template="glaive-with-template",
    )
    finemath = dict(
        original="HuggingFaceTB/finemath",
        decontaminated="LocalResearchGroup/split-finemath-decontaminated",
        kind="pretrain",
        ds_name="finemath-4plus",
    )
    pythonedu = dict(
        original="Avelina/python-edu",
        decontaminated="LocalResearchGroup/split-avelina-python-edu-decontaminated",
        kind="pretrain",
    )
    return dict(
        tulu=tulu,
        numina=numina,
        glaive=glaive,
        finemath=finemath,
        pythonedu=pythonedu,
    )


def rel_path(name, decontaminated):
    """Build the repo-relative name for dataset ``name``.

    The result is ``name``, followed by ``-with-template`` when the dataset's
    registry ``kind`` is ``"instruct"``, followed by ``-decontaminated`` when
    ``decontaminated`` is truthy. Raises ``KeyError`` for a name that is not
    in the registry returned by ``get_datasets()``.
    """
    pieces = [f"{name}"]
    if get_datasets()[name]["kind"] == "instruct":
        pieces.append("-with-template")
    if decontaminated:
        pieces.append("-decontaminated")
    return "".join(pieces)

#% Allow adding extra datasets to CONSTS

def add_dataset_config(name, splits):
    """Store ``splits`` under key ``name`` in llm-foundry's ``CONSTS`` table.

    An existing entry with the same key is overwritten.
    """
    # Imported lazily so this module can be loaded without llm-foundry.
    from llmfoundry.command_utils.data_prep import convert_dataset_hf as hf_convert

    hf_convert.CONSTS[name] = splits


def generate_constants(total_rows, chars_per_sample, chars_per_token):
    """Create a ``DatasetConstants`` with a ``train`` and a ``test`` split.

    Both splits use their own name as ``hf_split`` and ``folder_split``,
    record ``total_rows`` as ``raw_samples`` and leave ``truncated_samples``
    unset (``None``).
    """
    from llmfoundry.command_utils.data_prep.convert_dataset_hf import (
        CONSTS,
        DataSplitConstants,
        DatasetConstants,
    )

    ds_const = DatasetConstants(
        chars_per_sample=chars_per_sample,
        chars_per_token=chars_per_token,
    )
    for split in ("train", "test"):
        ds_const.splits[split] = DataSplitConstants(
            hf_split=split,
            folder_split=split,
            raw_samples=total_rows,
            truncated_samples=None,
        )
    return ds_const


def register_new_datasets(target="LocalResearchGroup"):
    """Register every known dataset in llm-foundry's ``CONSTS`` table.

    For each dataset returned by ``get_datasets()`` the same constants are
    registered under both its ``original`` and its ``decontaminated`` repo id.

    Args:
        target: Currently unused; kept for backward compatibility.
    """
    constants = {
        "finemath": generate_constants(6_700_000, 6212, 4),
        "tulu": generate_constants(939_000, 6212, 4),
        # Was written `859_00` (= 85,900): a digit-grouping typo for the
        # ~859k rows of NuminaMath-CoT.
        "numina": generate_constants(859_000, 6212, 4),
        "pythonedu": generate_constants(7_680_000, 6212, 4),
        "glaive": generate_constants(950_000, 6212, 4),
    }
    for name, cfg in get_datasets().items():
        add_dataset_config(cfg["original"], constants[name])
        add_dataset_config(cfg["decontaminated"], constants[name])

62 changes: 62 additions & 0 deletions scripts/data_prep/download_tokens.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from argparse import ArgumentParser, Namespace, BooleanOptionalAction
from huggingface_hub import HfApi, login
import os

from data_lib.utils import get_datasets, rel_path


def main(args):
    """Download the tokenized repo of every dataset in ``args.datasets``.

    Each dataset is fetched from ``<args.user_org>/<rel_path>-tokenized``
    into ``<args.out>/<dataset>-tokens``.
    """
    hub = HfApi()

    for ds in args.datasets:
        # `ld` and `datadown` keep their names: the `=` specifier in the log
        # line below prints the variable names verbatim.
        ld = f"{args.out}/{ds}-tokens"
        datadown = f"{args.user_org}/{rel_path(ds, args.decontaminated)}-tokenized"
        print(f"downloading {datadown=} to {ld=}\n")
        hub.snapshot_download(
            repo_id=datadown,
            repo_type="dataset",
            local_dir=ld,
        )

def parse_args() -> Namespace:
    """Build the command-line parser and parse the process arguments.

    Options: ``--datasets`` (one or more registry names, defaulting to all of
    them), ``--user_org``, ``--out`` and ``--decontaminated`` /
    ``--no-decontaminated``.
    """
    datasets = get_datasets().keys()

    parser = ArgumentParser(
        description="Downloads tokenized versions of train/test 1M, 100k, 10k, 1k",
    )
    parser.add_argument("--datasets", nargs="+", choices=datasets, default=datasets)
    parser.add_argument(
        "--user_org",
        default="LocalResearchGroup",
        help="user/org containing tokenizations",
    )
    parser.add_argument("--out", default=".", help="local download folder")
    parser.add_argument(
        "--decontaminated",
        action=BooleanOptionalAction,
        default=False,
        help="use decontaminated dataset instead of original one",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    # Prompt for a login only when no token is provided through the
    # environment (an empty value counts as missing).
    if not os.getenv("HUGGING_FACE_HUB_TOKEN"):
        print("No Hugging Face token found. Please login.")
        login()
    main(args)
161 changes: 161 additions & 0 deletions scripts/data_prep/text_dataset_preproc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@

from argparse import ArgumentParser, Namespace, BooleanOptionalAction
from datasets import load_dataset, load_from_disk, DatasetDict
from llmfoundry.data.finetuning.tasks import dataset_constructor
from data_lib.utils import get_datasets, register_new_datasets, _banner, str_rows_features


def create_refactor(dataset, decontaminated):
    """Load the dataset described by one registry entry.

    ``dataset`` is a registry dict. Its optional ``ds_name`` and
    ``after_pull`` entries are forwarded to ``pull_orifinal_ds`` together
    with the ``decontaminated`` or ``original`` repo id, selected by the
    ``decontaminated`` flag.
    """
    subset = dataset.get("ds_name")
    post_process = dataset.get("after_pull")
    source_key = "decontaminated" if decontaminated else "original"
    return pull_orifinal_ds(dataset[source_key], decontaminated, subset, post_process)

def pull_orifinal_ds(
    hf_ds_src,
    decontaminated,
    ds_name=None,
    after_pull=None,
):
    """Load ``hf_ds_src`` with ``load_dataset`` and optionally post-process it.

    Prints a banner naming the repo and config, registers the project's
    datasets in llm-foundry's constants, loads config ``ds_name`` of the repo
    and, when ``after_pull`` is given, returns
    ``after_pull(dataset, decontaminated)`` instead of the raw dataset.
    """
    config_label = "default" if ds_name is None else ds_name
    _banner(f"Loading dataset {hf_ds_src}/{config_label}")
    if ds_name:
        _banner(ds_name)

    from llmfoundry.command_utils.data_prep.convert_dataset_hf import CONSTS

    register_new_datasets()
    dataset = load_dataset(path=hf_ds_src, name=ds_name)
    if after_pull is None:
        return dataset
    return after_pull(dataset, decontaminated)


#% main loop
def _main_loop(args):
    """Process each dataset in ``args.datasets`` and push it to the Hub.

    Every dataset is loaded and templated through its ``after_pull`` hook,
    then pushed as a public repo named
    ``<user_org>/<dataset>-with-template[-decontaminated]`` using the
    ``default`` config and 128MB shards.
    """
    ds_config = get_datasets()
    # Attach the post-pull hook that templates each instruct dataset.
    hooks = {
        "tulu": filter_tulu,
        "numina": process_numina,
        "glaive": process_glaive,
    }
    for key, hook in hooks.items():
        ds_config[key]["after_pull"] = hook

    for ds in args.datasets:
        dataset = create_refactor(ds_config[ds], args.decontaminated)
        suffix = "-decontaminated" if args.decontaminated else ""
        hf_repo = f"{args.user_org}/{ds}-with-template{suffix}"
        dataset.push_to_hub(
            hf_repo,
            config_name="default",
            private=False,
            max_shard_size="128MB",
        )


#% chat ml template and filtering of original datasets
def apply_chatml_template(inp: dict, k_prompt: str, k_response: str):
    """Render one record as a ChatML prompt/response pair.

    Args:
        inp: Source record.
        k_prompt: Key in ``inp`` holding the user message.
        k_response: Key in ``inp`` holding the assistant message.

    Returns:
        A dict with ``prompt`` (fixed system turn plus the user turn) and
        ``response`` (assistant turn followed by ``<|endoftext|>``).
    """
    user_text = inp[k_prompt]
    prompt = (
        "<|im_start|>system\nYou are a helpful AI assistant named SmolLM, trained by Local Research Group<|im_end|>\n"
        f"<|im_start|>user\n{user_text}\n<|im_end|>\n"
    )
    assistant_text = inp[k_response]
    response = (
        f"<|im_start|>assistant\n{assistant_text}<|im_end|>\n"
        "<|endoftext|>"
    )
    return dict(prompt=prompt, response=response)


def template_to_tulu(inp: dict):
    """Apply the ChatML template using the ``prompt``/``response`` fields."""
    return apply_chatml_template(inp, k_prompt="prompt", k_response="response")


def template_to_numina(inp: dict):
    """Apply the ChatML template using the ``problem``/``solution`` fields."""
    return apply_chatml_template(inp, k_prompt="problem", k_response="solution")


def template_to_glaive(inp: dict):
    """Apply the ChatML template using the ``question``/``answer`` fields."""
    return apply_chatml_template(inp, k_prompt="question", k_response="answer")


def filter_tulu(dataset, decontaminated):
    """Filter the tulu dataset and convert it to templated prompt/response rows.

    For the original (non-decontaminated) source, rows are kept only when
    ``source`` is set, does not contain ``"aya"``, and the row has exactly two
    messages. The user and assistant messages are then extracted into
    ``prompt``/``response`` and passed through the ChatML template. When
    ``decontaminated`` is true, the extraction and templating steps are
    skipped.

    Returns the resulting dataset.
    """
    print(f"\n\ntulu {str_rows_features(dataset)}\n\n")
    # NOTE(review): indentation was lost in the view this was reviewed from,
    # so the extent of this `if` body is an assumption. Confirm which
    # `remove_columns` calls are meant to run for decontaminated input.
    if not decontaminated:
        dataset = dataset.filter(lambda r: r["source"] is not None and "aya" not in r["source"] and len(r["messages"]) == 2)
        dataset = dataset.remove_columns(["source", "dataset"])
    dataset = dataset.remove_columns(["id"])

    def extract_qa(messages):
        # First user message and first assistant message; None when absent.
        user_question = next((msg["content"] for msg in messages if msg["role"] == "user"), None)
        assistant_response = next((msg["content"] for msg in messages if msg["role"] == "assistant"), None)
        return {"prompt": user_question, "response": assistant_response}

    # Apply function to dataset
    dataset = dataset.map(lambda example: extract_qa(example["messages"])) if not decontaminated else dataset
    dataset = dataset.remove_columns(["messages"]) if not decontaminated else dataset
    dataset = dataset.map(lambda example: template_to_tulu(example)) if not decontaminated else dataset
    print(f"tulu after {str_rows_features(dataset)}")
    return dataset


def process_numina(dataset, decontaminated):
    """Template the numina dataset and drop its source columns.

    Every row is rendered through ``template_to_numina``. Afterwards
    ``source``, ``problem`` and ``solution`` are removed, plus ``messages``
    when the input is the original (non-decontaminated) dataset.
    """
    print(f"numina {str_rows_features(dataset)}")
    dataset = dataset.map(template_to_numina)
    # Drop the leftover source columns so only the templated fields remain.
    leftover = ["source", "problem", "solution"]
    if not decontaminated:
        leftover.append("messages")
    dataset = dataset.remove_columns(leftover)
    print(f"numina processed: {str_rows_features(dataset)}")
    return dataset


def process_glaive(dataset, decontaminated):
    """Template the glaive dataset and drop its ``question``/``answer`` columns.

    ``decontaminated`` is accepted for signature parity with the other
    post-pull hooks and is not used here.
    """
    print(f"glaive {str_rows_features(dataset)}")

    templated = dataset.map(template_to_glaive)
    dataset = templated.remove_columns(["question", "answer"])
    print(f"glaive processed: {str_rows_features(dataset)}")

    return dataset

#% argument parsing section
def main(args):
    """Run the processing loop unless no datasets were selected."""
    if not args.datasets:
        return
    _main_loop(args)


def parse_args() -> Namespace:
    """Build the command-line parser and parse the process arguments.

    ``--datasets`` accepts one or more names of registry entries whose
    ``kind`` is ``"instruct"`` and defaults to all of them. ``--user_org``
    is the Hub namespace to push to. ``--decontaminated`` /
    ``--no-decontaminated`` selects which source variant is processed.
    """
    # Build the registry once instead of once per key.
    registry = get_datasets()
    ds = [name for name, cfg in registry.items() if cfg["kind"] == "instruct"]

    parser = ArgumentParser(
        description="Refactor instruct datasets with and without decontamination",
    )

    parser.add_argument(
        "--datasets",
        nargs="+",
        choices=ds,
        default=ds,
    )

    parser.add_argument(
        "--user_org",
        default="LocalResearchGroup",
        help="user/org base namespace default is `LocalResearchGroup`",
    )

    parser.add_argument(
        "--decontaminated",
        action=BooleanOptionalAction,
        default=False,
        help="use decontaminated dataset instead of original one",
    )

    return parser.parse_args()


if __name__ == "__main__":
    import os

    args = parse_args()
    # Prompt for a login only when no token is provided through the
    # environment (an empty value counts as missing).
    if not os.environ.get("HUGGING_FACE_HUB_TOKEN"):
        # `login` is not imported at module level in this script, so calling
        # it raised NameError; import it here, next to its only use.
        from huggingface_hub import login

        print("No Hugging Face token found. Please login.")
        login()
    main(args)

Loading