
Datasets API

Dataset loading and management utilities for text and classification datasets.

mi_crow.datasets

BaseDataset

BaseDataset(ds, store, loading_strategy=LoadingStrategy.MEMORY)

Bases: ABC

Abstract base class for datasets with support for multiple sources, loading strategies, and Store integration.

Loading Strategies:

  • MEMORY: Load the entire dataset into memory (fastest random access, highest memory usage)
  • DISK: Save to disk and read dynamically via memory-mapped Arrow files (supports len/getitem, lower memory usage)
  • STREAMING: True streaming mode using IterableDataset (lowest memory; no len/getitem, stratification, or limit support)

A usage sketch follows the constructor source listing below.

Initialize dataset.

Parameters:

  ds (Dataset | IterableDataset): HuggingFace Dataset or IterableDataset. Required.
  store (Store): Store instance for caching/persistence. Required.
  loading_strategy (LoadingStrategy): How to load data (MEMORY, DISK, or STREAMING). Default: MEMORY.

Raises:

  ValueError: If store is None, loading_strategy is invalid, or dataset operations fail.
  OSError: If file system operations fail.

Source code in src/mi_crow/datasets/base_dataset.py
def __init__(
    self,
    ds: Dataset | IterableDataset,
    store: Store,
    loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
):
    """
    Initialize dataset.

    Args:
        ds: HuggingFace Dataset or IterableDataset
        store: Store instance for caching/persistence
        loading_strategy: How to load data (MEMORY, DISK, or STREAMING)

    Raises:
        ValueError: If store is None, loading_strategy is invalid, or dataset operations fail
        OSError: If file system operations fail
    """
    self._validate_initialization_params(store, loading_strategy)

    self._store = store
    self._loading_strategy = loading_strategy
    self._dataset_dir: Path = Path(store.base_path) / store.dataset_prefix

    is_iterable_input = isinstance(ds, IterableDataset)

    if loading_strategy == LoadingStrategy.MEMORY:
        # MEMORY: Convert to Dataset if needed, save to disk, load fully into memory
        self._is_iterable = False
        if is_iterable_input:
            ds = Dataset.from_generator(lambda: iter(ds))
        self._ds = self._save_and_load_dataset(ds, use_memory_mapping=False)
    elif loading_strategy == LoadingStrategy.DISK:
        # DISK: Save to disk, use memory-mapped Arrow files (supports len/getitem)
        self._is_iterable = False
        if is_iterable_input:
            ds = Dataset.from_generator(lambda: iter(ds))
        self._ds = self._save_and_load_dataset(ds, use_memory_mapping=True)
    elif loading_strategy == LoadingStrategy.STREAMING:
        # STREAMING: Convert to IterableDataset, don't save to disk (no len/getitem)
        if not is_iterable_input:
            ds = ds.to_iterable_dataset()
        self._is_iterable = True
        self._ds = ds
        # Don't save to disk for iterable-only mode
    else:
        raise ValueError(
            f"Unknown loading strategy: {loading_strategy}. Must be one of: {[s.value for s in LoadingStrategy]}"
        )
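
The loading strategy is usually the only decision to make up front. A minimal sketch, assuming the import paths below (the LocalStore path in particular is a guess) and a concrete subclass such as TextDataset, since BaseDataset itself is abstract:

from mi_crow.datasets import TextDataset, LoadingStrategy
from mi_crow.store import LocalStore  # assumed import path

store = LocalStore("store/my_texts")  # hypothetical store location

# MEMORY: small datasets, fastest random access
ds_mem = TextDataset.from_csv("data.csv", store=store, loading_strategy=LoadingStrategy.MEMORY)

# DISK: memory-mapped Arrow files; len()/indexing still work
ds_disk = TextDataset.from_csv("data.csv", store=store, loading_strategy=LoadingStrategy.DISK)

# STREAMING: sequential iteration only (no len(), indexing, stratification, or limit)
ds_stream = TextDataset.from_csv("data.csv", store=store, loading_strategy=LoadingStrategy.STREAMING)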

is_streaming property

is_streaming

Whether this dataset is streaming (DISK or STREAMING).

__getitem__ abstractmethod

__getitem__(idx)

Get item(s) by index.

Source code in src/mi_crow/datasets/base_dataset.py
@abstractmethod
def __getitem__(self, idx: IndexLike) -> Any:
    """Get item(s) by index."""
    pass

__len__ abstractmethod

__len__()

Return the number of items in the dataset.

Source code in src/mi_crow/datasets/base_dataset.py
@abstractmethod
def __len__(self) -> int:
    """Return the number of items in the dataset."""
    pass

extract_texts_from_batch abstractmethod

extract_texts_from_batch(batch)

Extract text strings from a batch.

Parameters:

  batch (List[Any]): A batch as returned by iter_batches(). Required.

Returns:

  List[str]: List of text strings ready for model inference.

Source code in src/mi_crow/datasets/base_dataset.py
@abstractmethod
def extract_texts_from_batch(self, batch: List[Any]) -> List[str]:
    """Extract text strings from a batch.

    Args:
        batch: A batch as returned by iter_batches()

    Returns:
        List of text strings ready for model inference
    """
    pass

from_csv classmethod

from_csv(source, store, *, loading_strategy=LoadingStrategy.MEMORY, text_field='text', delimiter=',', stratify_by=None, stratify_seed=None, drop_na_columns=None, **kwargs)

Load dataset from CSV file.

Parameters:

  source (Union[str, Path]): Path to CSV file. Required.
  store (Store): Store instance. Required.
  loading_strategy (LoadingStrategy): Loading strategy. Default: MEMORY.
  text_field (str): Name of the column containing text. Default: 'text'.
  delimiter (str): CSV delimiter. Default: ','.
  stratify_by (Optional[str]): Optional column used for stratified sampling (non-streaming only). Default: None.
  stratify_seed (Optional[int]): Optional RNG seed for stratified sampling. Default: None.
  drop_na_columns (Optional[List[str]]): Optional list of columns to check for None/empty values. Default: None.
  **kwargs (Any): Additional arguments passed to load_dataset.

Returns:

  BaseDataset: BaseDataset instance.

Raises:

  FileNotFoundError: If CSV file doesn't exist.
  ValueError: If store is None or source is invalid.
  RuntimeError: If dataset loading fails.

Source code in src/mi_crow/datasets/base_dataset.py
@classmethod
def from_csv(
    cls,
    source: Union[str, Path],
    store: Store,
    *,
    loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
    text_field: str = "text",
    delimiter: str = ",",
    stratify_by: Optional[str] = None,
    stratify_seed: Optional[int] = None,
    drop_na_columns: Optional[List[str]] = None,
    **kwargs: Any,
) -> "BaseDataset":
    """
    Load dataset from CSV file.

    Args:
        source: Path to CSV file
        store: Store instance
        loading_strategy: Loading strategy
        text_field: Name of the column containing text
        delimiter: CSV delimiter (default: comma)
        stratify_by: Optional column used for stratified sampling (non-streaming only)
        stratify_seed: Optional RNG seed for stratified sampling
        drop_na_columns: Optional list of columns to check for None/empty values
        **kwargs: Additional arguments passed to load_dataset

    Returns:
        BaseDataset instance

    Raises:
        FileNotFoundError: If CSV file doesn't exist
        ValueError: If store is None or source is invalid
        RuntimeError: If dataset loading fails
    """
    if store is None:
        raise ValueError("store cannot be None")

    use_streaming = loading_strategy == LoadingStrategy.STREAMING
    if (stratify_by or drop_na_columns) and use_streaming:
        raise NotImplementedError("Stratification and drop_na are not supported for STREAMING datasets.")

    ds = cls._load_csv_source(
        source,
        delimiter=delimiter,
        streaming=use_streaming,
        **kwargs,
    )

    if not use_streaming and (stratify_by or drop_na_columns):
        ds = cls._postprocess_non_streaming_dataset(
            ds,
            stratify_by=stratify_by,
            stratify_seed=stratify_seed,
            drop_na_columns=drop_na_columns,
        )

    return cls(ds, store=store, loading_strategy=loading_strategy)
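
A usage sketch for the CSV-specific options (import paths, the file name reviews.csv, and the label column are assumptions; the call is made on a concrete subclass because BaseDataset is abstract):

from mi_crow.datasets import TextDataset, LoadingStrategy
from mi_crow.store import LocalStore  # assumed import path

store = LocalStore("store/reviews")

# Semicolon-delimited CSV; drop rows with empty text and stratify on a hypothetical 'label' column.
dataset = TextDataset.from_csv(
    "reviews.csv",
    store=store,
    loading_strategy=LoadingStrategy.MEMORY,
    delimiter=";",
    drop_na_columns=["text"],
    stratify_by="label",
    stratify_seed=42,
)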

from_disk classmethod

from_disk(store, *, loading_strategy=LoadingStrategy.MEMORY, **kwargs)

Load dataset from already-saved Arrow files on disk.

Use this when you've previously saved a dataset and want to reload it without re-downloading from HuggingFace or re-applying transformations.

Parameters:

  store (Store): Store instance pointing to where the dataset was saved (dataset will be loaded from store.base_path/store.dataset_prefix/). Required.
  loading_strategy (LoadingStrategy): Loading strategy (MEMORY or DISK only, not STREAMING). Default: MEMORY.
  **kwargs (Any): Additional arguments (for subclass compatibility).

Returns:

  BaseDataset: BaseDataset instance loaded from disk.

Raises:

  ValueError: If store is None or loading_strategy is STREAMING.
  FileNotFoundError: If dataset directory doesn't exist.
  RuntimeError: If dataset loading fails.

Example:

# First: save dataset
dataset_store = LocalStore("store/my_dataset")
dataset = ClassificationDataset.from_huggingface(..., store=dataset_store)
# Dataset saved to: store/my_dataset/datasets/*.arrow

# Later: reload from disk
dataset_store = LocalStore("store/my_dataset")
dataset = ClassificationDataset.from_disk(store=dataset_store)

Source code in src/mi_crow/datasets/base_dataset.py
@classmethod
def from_disk(
    cls,
    store: Store,
    *,
    loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
    **kwargs: Any,
) -> "BaseDataset":
    """
    Load dataset from already-saved Arrow files on disk.

    Use this when you've previously saved a dataset and want to reload it
    without re-downloading from HuggingFace or re-applying transformations.

    Args:
        store: Store instance pointing to where the dataset was saved
               (dataset will be loaded from store.base_path/store.dataset_prefix/)
        loading_strategy: Loading strategy (MEMORY or DISK only, not STREAMING)
        **kwargs: Additional arguments (for subclass compatibility)

    Returns:
        BaseDataset instance loaded from disk

    Raises:
        ValueError: If store is None or loading_strategy is STREAMING
        FileNotFoundError: If dataset directory doesn't exist
        RuntimeError: If dataset loading fails

    Example:
        # First: save dataset
        dataset_store = LocalStore("store/my_dataset")
        dataset = ClassificationDataset.from_huggingface(..., store=dataset_store)
        # Dataset saved to: store/my_dataset/datasets/*.arrow

        # Later: reload from disk
        dataset_store = LocalStore("store/my_dataset")
        dataset = ClassificationDataset.from_disk(store=dataset_store)
    """
    if store is None:
        raise ValueError("store cannot be None")

    if loading_strategy == LoadingStrategy.STREAMING:
        raise ValueError("STREAMING loading strategy not supported for from_disk(). Use MEMORY or DISK.")

    dataset_dir = Path(store.base_path) / store.dataset_prefix

    if not dataset_dir.exists():
        raise FileNotFoundError(
            f"Dataset directory not found: {dataset_dir}. "
            f"Make sure you've previously saved a dataset to this store location."
        )

    # Verify it's a valid Arrow dataset directory
    arrow_files = list(dataset_dir.glob("*.arrow"))
    if not arrow_files:
        raise FileNotFoundError(
            f"No Arrow files found in {dataset_dir}. Directory exists but doesn't contain a valid dataset."
        )

    try:
        use_memory_mapping = loading_strategy == LoadingStrategy.DISK
        ds = load_from_disk(str(dataset_dir), keep_in_memory=not use_memory_mapping)
    except Exception as e:
        raise RuntimeError(f"Failed to load dataset from {dataset_dir}. Error: {e}") from e

    return cls(ds, store=store, loading_strategy=loading_strategy)

from_huggingface classmethod

from_huggingface(repo_id, store, *, split='train', loading_strategy=LoadingStrategy.MEMORY, revision=None, streaming=None, filters=None, limit=None, stratify_by=None, stratify_seed=None, **kwargs)

Load dataset from HuggingFace Hub.

Parameters:

  repo_id (str): HuggingFace dataset repository ID. Required.
  store (Store): Store instance. Required.
  split (str): Dataset split (e.g., "train", "validation"). Default: 'train'.
  loading_strategy (LoadingStrategy): Loading strategy (MEMORY, DISK, or STREAMING). Default: MEMORY.
  revision (Optional[str]): Optional git revision/branch/tag. Default: None.
  streaming (Optional[bool]): Optional override for streaming (if None, uses loading_strategy). Default: None.
  filters (Optional[Dict[str, Any]]): Optional dict of column->value pairs used for exact-match filtering. Default: None.
  limit (Optional[int]): Optional maximum number of rows to keep (applied after filtering/stratification). Default: None.
  stratify_by (Optional[str]): Optional column to use for stratified sampling (non-streaming only). Default: None.
  stratify_seed (Optional[int]): Optional RNG seed for deterministic stratification. Default: None.
  **kwargs (Any): Additional arguments passed to load_dataset.

Returns:

  BaseDataset: BaseDataset instance.

Raises:

  ValueError: If repo_id is empty or store is None.
  RuntimeError: If dataset loading fails.

Source code in src/mi_crow/datasets/base_dataset.py
@classmethod
def from_huggingface(
    cls,
    repo_id: str,
    store: Store,
    *,
    split: str = "train",
    loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
    revision: Optional[str] = None,
    streaming: Optional[bool] = None,
    filters: Optional[Dict[str, Any]] = None,
    limit: Optional[int] = None,
    stratify_by: Optional[str] = None,
    stratify_seed: Optional[int] = None,
    **kwargs: Any,
) -> "BaseDataset":
    """
    Load dataset from HuggingFace Hub.

    Args:
        repo_id: HuggingFace dataset repository ID
        store: Store instance
        split: Dataset split (e.g., "train", "validation")
        loading_strategy: Loading strategy (MEMORY, DISK, or STREAMING)
        revision: Optional git revision/branch/tag
        streaming: Optional override for streaming (if None, uses loading_strategy)
        filters: Optional dict of column->value pairs used for exact-match filtering
        limit: Optional maximum number of rows to keep (applied after filtering/stratification)
        stratify_by: Optional column to use for stratified sampling (non-streaming only)
        stratify_seed: Optional RNG seed for deterministic stratification
        **kwargs: Additional arguments passed to load_dataset

    Returns:
        BaseDataset instance

    Raises:
        ValueError: If repo_id is empty or store is None
        RuntimeError: If dataset loading fails
    """
    if not repo_id or not isinstance(repo_id, str) or not repo_id.strip():
        raise ValueError(f"repo_id must be a non-empty string, got: {repo_id!r}")

    if store is None:
        raise ValueError("store cannot be None")

    # Determine if we should use streaming for HuggingFace load_dataset
    use_streaming = streaming if streaming is not None else (loading_strategy == LoadingStrategy.STREAMING)

    if stratify_by and loading_strategy == LoadingStrategy.STREAMING:
        raise NotImplementedError("Stratification is not supported for STREAMING datasets.")

    try:
        ds = load_dataset(
            path=repo_id,
            split=split,
            revision=revision,
            streaming=use_streaming,
            **kwargs,
        )
    except Exception as e:
        raise RuntimeError(
            f"Failed to load dataset from HuggingFace Hub: repo_id={repo_id!r}, "
            f"split={split!r}, revision={revision!r}. Error: {e}"
        ) from e

    if use_streaming:
        if filters or limit or stratify_by:
            raise NotImplementedError(
                "filters, limit, and stratification are not supported when streaming datasets. "
                "Choose MEMORY or DISK loading strategy instead."
            )
    else:
        ds = cls._postprocess_non_streaming_dataset(
            ds,
            filters=filters,
            limit=limit,
            stratify_by=stratify_by,
            stratify_seed=stratify_seed,
        )

    return cls(ds, store=store, loading_strategy=loading_strategy)
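
A sketch of loading from the Hub with filtering and a row limit (the repo ID, filter column, and import paths are illustrative assumptions; use a concrete subclass, since BaseDataset is abstract):

from mi_crow.datasets import TextDataset, LoadingStrategy
from mi_crow.store import LocalStore  # assumed import path

store = LocalStore("store/hub_sample")

dataset = TextDataset.from_huggingface(
    "some-org/some-text-dataset",   # hypothetical repo_id
    store=store,
    split="train",
    loading_strategy=LoadingStrategy.DISK,
    filters={"language": "en"},     # exact-match filter on a hypothetical column
    limit=1000,                     # keep at most 1,000 rows after filtering
)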

from_json classmethod

from_json(source, store, *, loading_strategy=LoadingStrategy.MEMORY, text_field='text', stratify_by=None, stratify_seed=None, drop_na_columns=None, **kwargs)

Load dataset from JSON or JSONL file.

Parameters:

  source (Union[str, Path]): Path to JSON or JSONL file. Required.
  store (Store): Store instance. Required.
  loading_strategy (LoadingStrategy): Loading strategy. Default: MEMORY.
  text_field (str): Name of the field containing text (for JSON objects). Default: 'text'.
  stratify_by (Optional[str]): Optional column used for stratified sampling (non-streaming only). Default: None.
  stratify_seed (Optional[int]): Optional RNG seed for stratified sampling. Default: None.
  drop_na_columns (Optional[List[str]]): Optional list of columns to check for None/empty values. Default: None.
  **kwargs (Any): Additional arguments passed to load_dataset.

Returns:

  BaseDataset: BaseDataset instance.

Raises:

  FileNotFoundError: If JSON file doesn't exist.
  ValueError: If store is None or source is invalid.
  RuntimeError: If dataset loading fails.

Source code in src/mi_crow/datasets/base_dataset.py
@classmethod
def from_json(
    cls,
    source: Union[str, Path],
    store: Store,
    *,
    loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
    text_field: str = "text",
    stratify_by: Optional[str] = None,
    stratify_seed: Optional[int] = None,
    drop_na_columns: Optional[List[str]] = None,
    **kwargs: Any,
) -> "BaseDataset":
    """
    Load dataset from JSON or JSONL file.

    Args:
        source: Path to JSON or JSONL file
        store: Store instance
        loading_strategy: Loading strategy
        text_field: Name of the field containing text (for JSON objects)
        stratify_by: Optional column used for stratified sampling (non-streaming only)
        stratify_seed: Optional RNG seed for stratified sampling
        drop_na_columns: Optional list of columns to check for None/empty values
        **kwargs: Additional arguments passed to load_dataset

    Returns:
        BaseDataset instance

    Raises:
        FileNotFoundError: If JSON file doesn't exist
        ValueError: If store is None or source is invalid
        RuntimeError: If dataset loading fails
    """
    if store is None:
        raise ValueError("store cannot be None")

    use_streaming = loading_strategy == LoadingStrategy.STREAMING
    if (stratify_by or drop_na_columns) and use_streaming:
        raise NotImplementedError("Stratification and drop_na are not supported for STREAMING datasets.")

    ds = cls._load_json_source(
        source,
        streaming=use_streaming,
        **kwargs,
    )

    if not use_streaming and (stratify_by or drop_na_columns):
        ds = cls._postprocess_non_streaming_dataset(
            ds,
            stratify_by=stratify_by,
            stratify_seed=stratify_seed,
            drop_na_columns=drop_na_columns,
        )

    return cls(ds, store=store, loading_strategy=loading_strategy)
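
A sketch for JSONL input where the text lives under a non-default key (file name, key, and import paths are assumptions):

from mi_crow.datasets import TextDataset, LoadingStrategy
from mi_crow.store import LocalStore  # assumed import path

store = LocalStore("store/prompts")

dataset = TextDataset.from_json(
    "prompts.jsonl",
    store=store,
    loading_strategy=LoadingStrategy.MEMORY,
    text_field="content",  # hypothetical field name in each JSON object
)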

get_all_texts abstractmethod

get_all_texts()

Get all texts from the dataset.

Returns:

  List[str]: List of all text strings in the dataset.

Raises:

  NotImplementedError: If not supported for streaming datasets.

Source code in src/mi_crow/datasets/base_dataset.py
@abstractmethod
def get_all_texts(self) -> List[str]:
    """Get all texts from the dataset.

    Returns:
        List of all text strings in the dataset

    Raises:
        NotImplementedError: If not supported for streaming datasets
    """
    pass

get_batch

get_batch(start, batch_size)

Get a contiguous batch of items.

Parameters:

  start (int): Starting index. Required.
  batch_size (int): Number of items to retrieve. Required.

Returns:

  List[Any]: List of items.

Raises:

  NotImplementedError: If loading_strategy is STREAMING.

Source code in src/mi_crow/datasets/base_dataset.py
def get_batch(self, start: int, batch_size: int) -> List[Any]:
    """
    Get a contiguous batch of items.

    Args:
        start: Starting index
        batch_size: Number of items to retrieve

    Returns:
        List of items

    Raises:
        NotImplementedError: If loading_strategy is STREAMING
    """
    if self._loading_strategy == LoadingStrategy.STREAMING:
        raise NotImplementedError("get_batch not supported for STREAMING datasets. Use iter_batches instead.")
    if batch_size <= 0:
        return []
    end = min(start + batch_size, len(self))
    if start >= end:
        return []
    return self[start:end]
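
A short sketch contrasting random-access batching with the streaming-safe alternative (dataset is assumed to be an already constructed instance, and non-streaming for the first call):

# MEMORY/DISK: random-access batching
first_batch = dataset.get_batch(start=0, batch_size=32)

# Any strategy, including STREAMING: sequential batching
for batch in dataset.iter_batches(batch_size=32):
    texts = dataset.extract_texts_from_batch(batch)
    ...  # run inference on texts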

head

head(n=5)

Get first n items.

Works for all loading strategies.

Parameters:

  n (int): Number of items to retrieve. Default: 5.

Returns:

  List[Any]: List of first n items.

Source code in src/mi_crow/datasets/base_dataset.py
def head(self, n: int = 5) -> List[Any]:
    """
    Get first n items.

    Works for all loading strategies.

    Args:
        n: Number of items to retrieve (default: 5)

    Returns:
        List of first n items
    """
    if self._loading_strategy == LoadingStrategy.STREAMING:
        items = []
        for i, item in enumerate(self.iter_items()):
            if i >= n:
                break
            items.append(item)
        return items
    return self[:n]

iter_batches abstractmethod

iter_batches(batch_size)

Iterate over items in batches.

Source code in src/mi_crow/datasets/base_dataset.py
@abstractmethod
def iter_batches(self, batch_size: int) -> Iterator[List[Any]]:
    """Iterate over items in batches."""
    pass

iter_items abstractmethod

iter_items()

Iterate over items one by one.

Source code in src/mi_crow/datasets/base_dataset.py
@abstractmethod
def iter_items(self) -> Iterator[Any]:
    """Iterate over items one by one."""
    pass

sample

sample(n=5)

Get n random items from the dataset.

Works for MEMORY and DISK strategies only.

Parameters:

  n (int): Number of items to sample. Default: 5.

Returns:

  List[Any]: List of n randomly sampled items.

Raises:

  NotImplementedError: If loading_strategy is STREAMING.

Source code in src/mi_crow/datasets/base_dataset.py
def sample(self, n: int = 5) -> List[Any]:
    """
    Get n random items from the dataset.

    Works for MEMORY and DISK strategies only.

    Args:
        n: Number of items to sample

    Returns:
        List of n randomly sampled items

    Raises:
        NotImplementedError: If loading_strategy is STREAMING
    """
    if self._loading_strategy == LoadingStrategy.STREAMING:
        raise NotImplementedError(
            "sample() not supported for STREAMING datasets. Use iter_items() and sample manually."
        )

    dataset_len = len(self)
    if n <= 0:
        return []
    if n >= dataset_len:
        # Return all items in random order
        indices = list(range(dataset_len))
        random.shuffle(indices)
        return [self[i] for i in indices]

    # Sample n random indices
    indices = random.sample(range(dataset_len), n)
    # Use __getitem__ with list of indices
    return self[indices]
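
A quick inspection sketch (dataset is assumed to be an existing MEMORY or DISK instance):

preview = dataset.head(3)          # first 3 items; works for every loading strategy
random_items = dataset.sample(10)  # 10 random items; MEMORY/DISK only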

ClassificationDataset

ClassificationDataset(ds, store, loading_strategy=LoadingStrategy.MEMORY, text_field='text', category_field='category')

Bases: BaseDataset

Classification dataset with text and category/label columns. Each item is a dict with 'text' and label column(s) as keys. Supports single or multiple label columns.

Initialize classification dataset.

Parameters:

  ds (Dataset | IterableDataset): HuggingFace Dataset or IterableDataset. Required.
  store (Store): Store instance. Required.
  loading_strategy (LoadingStrategy): Loading strategy. Default: MEMORY.
  text_field (str): Name of the column containing text. Default: 'text'.
  category_field (Union[str, List[str]]): Name(s) of the column(s) containing category/label. Can be a single string or a list of strings for multiple labels. Default: 'category'.

Raises:

  ValueError: If text_field or category_field is empty, or fields not found in dataset.

Source code in src/mi_crow/datasets/classification_dataset.py
def __init__(
    self,
    ds: Dataset | IterableDataset,
    store: Store,
    loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
    text_field: str = "text",
    category_field: Union[str, List[str]] = "category",
):
    """
    Initialize classification dataset.

    Args:
        ds: HuggingFace Dataset or IterableDataset
        store: Store instance
        loading_strategy: Loading strategy
        text_field: Name of the column containing text
        category_field: Name(s) of the column(s) containing category/label.
                      Can be a single string or a list of strings for multiple labels.

    Raises:
        ValueError: If text_field or category_field is empty, or fields not found in dataset
    """
    self._validate_text_field(text_field)

    # Normalize category_field to list
    if isinstance(category_field, str):
        self._category_fields = [category_field]
    else:
        self._category_fields = list(category_field)

    self._validate_category_fields(self._category_fields)

    # Validate dataset
    is_iterable = isinstance(ds, IterableDataset)
    if not is_iterable:
        if text_field not in ds.column_names:
            raise ValueError(f"Dataset must have a '{text_field}' column; got columns: {ds.column_names}")
        for cat_field in self._category_fields:
            if cat_field not in ds.column_names:
                raise ValueError(f"Dataset must have a '{cat_field}' column; got columns: {ds.column_names}")
        # Set format with all required columns
        format_columns = [text_field] + self._category_fields
        ds.set_format("python", columns=format_columns)

    self._text_field = text_field
    self._category_field = category_field  # Keep original for backward compatibility
    super().__init__(ds, store=store, loading_strategy=loading_strategy)

__getitem__

__getitem__(idx)

Get item(s) by index. Returns dict with 'text' and label column(s) as keys.

For single label: {"text": "...", "category": "..."} For multiple labels: {"text": "...", "label1": "...", "label2": "..."}

Parameters:

  idx (IndexLike): Index (int), slice, or sequence of indices. Required.

Returns:

  Union[Dict[str, Any], List[Dict[str, Any]]]: Single item dict or list of item dicts.

Raises:

  NotImplementedError: If loading_strategy is STREAMING.
  IndexError: If index is out of bounds.
  ValueError: If dataset is empty.

Source code in src/mi_crow/datasets/classification_dataset.py
def __getitem__(self, idx: IndexLike) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
    """
    Get item(s) by index. Returns dict with 'text' and label column(s) as keys.

    For single label: {"text": "...", "category": "..."}
    For multiple labels: {"text": "...", "label1": "...", "label2": "..."}

    Args:
        idx: Index (int), slice, or sequence of indices

    Returns:
        Single item dict or list of item dicts

    Raises:
        NotImplementedError: If loading_strategy is STREAMING
        IndexError: If index is out of bounds
        ValueError: If dataset is empty
    """
    if self._loading_strategy == LoadingStrategy.STREAMING:
        raise NotImplementedError("Indexing not supported for STREAMING datasets. Use iter_items or iter_batches.")

    dataset_len = len(self)
    if dataset_len == 0:
        raise ValueError("Cannot index into empty dataset")

    if isinstance(idx, int):
        if idx < 0:
            idx = dataset_len + idx
        if idx < 0 or idx >= dataset_len:
            raise IndexError(f"Index {idx} out of bounds for dataset of length {dataset_len}")
        row = self._ds[idx]
        return self._extract_item_from_row(row)

    if isinstance(idx, slice):
        start, stop, step = idx.indices(dataset_len)
        if step != 1:
            indices = list(range(start, stop, step))
            selected = self._ds.select(indices)
        else:
            selected = self._ds.select(range(start, stop))
        return [self._extract_item_from_row(row) for row in selected]

    if isinstance(idx, Sequence):
        # Validate all indices are in bounds
        invalid_indices = [i for i in idx if not (0 <= i < dataset_len)]
        if invalid_indices:
            raise IndexError(f"Indices out of bounds: {invalid_indices} (dataset length: {dataset_len})")
        selected = self._ds.select(list(idx))
        return [self._extract_item_from_row(row) for row in selected]

    raise TypeError(f"Invalid index type: {type(idx)}")
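
A sketch of the supported index types and the shapes they return (dataset is assumed to be a non-streaming ClassificationDataset with a single 'category' label column):

item = dataset[0]             # {"text": "...", "category": "..."}
last = dataset[-1]            # negative indices count from the end
window = dataset[0:4]         # list of 4 item dicts
picked = dataset[[1, 5, 9]]   # list selected by explicit indices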

__len__

__len__()

Return the number of items in the dataset.

Raises:

  NotImplementedError: If loading_strategy is STREAMING.

Source code in src/mi_crow/datasets/classification_dataset.py
def __len__(self) -> int:
    """
    Return the number of items in the dataset.

    Raises:
        NotImplementedError: If loading_strategy is STREAMING
    """
    if self._loading_strategy == LoadingStrategy.STREAMING:
        raise NotImplementedError("len() not supported for STREAMING datasets")
    return self._ds.num_rows

extract_texts_from_batch

extract_texts_from_batch(batch)

Extract text strings from a batch of classification items.

Parameters:

  batch (List[Dict[str, Any]]): List of dicts with 'text' and category fields. Required.

Returns:

  List[Optional[str]]: List of text strings from the batch.

Raises:

  ValueError: If 'text' key is not found in any batch item.

Source code in src/mi_crow/datasets/classification_dataset.py
def extract_texts_from_batch(self, batch: List[Dict[str, Any]]) -> List[Optional[str]]:
    """Extract text strings from a batch of classification items.

    Args:
        batch: List of dicts with 'text' and category fields

    Returns:
        List of text strings from the batch

    Raises:
        ValueError: If 'text' key is not found in any batch item
    """
    texts = []
    for item in batch:
        if "text" not in item:
            raise ValueError(f"'text' key not found in batch item. Available keys: {list(item.keys())}")
        texts.append(item["text"])
    return texts

from_csv classmethod

from_csv(source, store, *, loading_strategy=LoadingStrategy.MEMORY, text_field='text', category_field='category', delimiter=',', stratify_by=None, stratify_seed=None, drop_na=False, **kwargs)

Load classification dataset from CSV file.

Parameters:

  source (Union[str, Path]): Path to CSV file. Required.
  store (Store): Store instance. Required.
  loading_strategy (LoadingStrategy): Loading strategy. Default: MEMORY.
  text_field (str): Name of the column containing text. Default: 'text'.
  category_field (Union[str, List[str]]): Name(s) of the column(s) containing category/label. Default: 'category'.
  delimiter (str): CSV delimiter. Default: ','.
  stratify_by (Optional[str]): Optional column used for stratified sampling. Default: None.
  stratify_seed (Optional[int]): Optional RNG seed for stratified sampling. Default: None.
  drop_na (bool): Whether to drop rows with None/empty text or categories. Default: False.
  **kwargs (Any): Additional arguments for load_dataset.

Returns:

  ClassificationDataset: ClassificationDataset instance.

Raises:

  FileNotFoundError: If CSV file doesn't exist.
  RuntimeError: If dataset loading fails.

Source code in src/mi_crow/datasets/classification_dataset.py
@classmethod
def from_csv(
    cls,
    source: Union[str, Path],
    store: Store,
    *,
    loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
    text_field: str = "text",
    category_field: Union[str, List[str]] = "category",
    delimiter: str = ",",
    stratify_by: Optional[str] = None,
    stratify_seed: Optional[int] = None,
    drop_na: bool = False,
    **kwargs: Any,
) -> "ClassificationDataset":
    """
    Load classification dataset from CSV file.

    Args:
        source: Path to CSV file
        store: Store instance
        loading_strategy: Loading strategy
        text_field: Name of the column containing text
        category_field: Name(s) of the column(s) containing category/label
        delimiter: CSV delimiter (default: comma)
        stratify_by: Optional column used for stratified sampling
        stratify_seed: Optional RNG seed for stratified sampling
        drop_na: Whether to drop rows with None/empty text or categories
        **kwargs: Additional arguments for load_dataset

    Returns:
        ClassificationDataset instance

    Raises:
        FileNotFoundError: If CSV file doesn't exist
        RuntimeError: If dataset loading fails
    """
    if store is None:
        raise ValueError("store cannot be None")

    use_streaming = loading_strategy == LoadingStrategy.STREAMING
    if (stratify_by or drop_na) and use_streaming:
        raise NotImplementedError("Stratification and drop_na are not supported for STREAMING datasets.")

    # Load CSV using parent's static method
    ds = cls._load_csv_source(
        source,
        delimiter=delimiter,
        streaming=use_streaming,
        **kwargs,
    )

    # Apply postprocessing if not streaming
    if not use_streaming and (stratify_by or drop_na):
        drop_na_columns = None
        if drop_na:
            cat_fields = [category_field] if isinstance(category_field, str) else category_field
            drop_na_columns = [text_field] + list(cat_fields)

        ds = cls._postprocess_non_streaming_dataset(
            ds,
            stratify_by=stratify_by,
            stratify_seed=stratify_seed,
            drop_na_columns=drop_na_columns,
        )

    return cls(
        ds,
        store=store,
        loading_strategy=loading_strategy,
        text_field=text_field,
        category_field=category_field,
    )
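
A sketch with multiple label columns and drop_na (file name, column names, and import paths are assumptions):

from mi_crow.datasets import ClassificationDataset
from mi_crow.store import LocalStore  # assumed import path

store = LocalStore("store/tickets")

dataset = ClassificationDataset.from_csv(
    "tickets.csv",
    store=store,
    text_field="body",
    category_field=["topic", "priority"],  # two label columns
    drop_na=True,  # drops rows with empty 'body', 'topic', or 'priority'
)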

from_disk classmethod

from_disk(store, *, loading_strategy=LoadingStrategy.MEMORY, text_field='text', category_field='category')

Load classification dataset from already-saved Arrow files on disk.

Use this when you've previously saved a dataset and want to reload it without re-downloading from HuggingFace or re-applying transformations.

Parameters:

  store (Store): Store instance pointing to where the dataset was saved. Required.
  loading_strategy (LoadingStrategy): Loading strategy (MEMORY or DISK only). Default: MEMORY.
  text_field (str): Name of the column containing text. Default: 'text'.
  category_field (Union[str, List[str]]): Name(s) of the column(s) containing category/label. Default: 'category'.

Returns:

  ClassificationDataset: ClassificationDataset instance loaded from disk.

Raises:

  FileNotFoundError: If dataset directory doesn't exist or contains no Arrow files.
  ValueError: If required fields are not in the loaded dataset.

Example:

# First: save dataset
dataset_store = LocalStore("store/wgmix_test")
dataset = ClassificationDataset.from_huggingface(
    "allenai/wildguardmix",
    store=dataset_store,
    limit=100
)
# Dataset saved to: store/wgmix_test/datasets/*.arrow

# Later: reload from disk
dataset_store = LocalStore("store/wgmix_test")
dataset = ClassificationDataset.from_disk(
    store=dataset_store,
    text_field="prompt",
    category_field="prompt_harm_label"
)

Source code in src/mi_crow/datasets/classification_dataset.py
@classmethod
def from_disk(
    cls,
    store: Store,
    *,
    loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
    text_field: str = "text",
    category_field: Union[str, List[str]] = "category",
) -> "ClassificationDataset":
    """
    Load classification dataset from already-saved Arrow files on disk.

    Use this when you've previously saved a dataset and want to reload it
    without re-downloading from HuggingFace or re-applying transformations.

    Args:
        store: Store instance pointing to where the dataset was saved
        loading_strategy: Loading strategy (MEMORY or DISK only)
        text_field: Name of the column containing text
        category_field: Name(s) of the column(s) containing category/label

    Returns:
        ClassificationDataset instance loaded from disk

    Raises:
        FileNotFoundError: If dataset directory doesn't exist or contains no Arrow files
        ValueError: If required fields are not in the loaded dataset

    Example:
        # First: save dataset
        dataset_store = LocalStore("store/wgmix_test")
        dataset = ClassificationDataset.from_huggingface(
            "allenai/wildguardmix",
            store=dataset_store,
            limit=100
        )
        # Dataset saved to: store/wgmix_test/datasets/*.arrow

        # Later: reload from disk
        dataset_store = LocalStore("store/wgmix_test")
        dataset = ClassificationDataset.from_disk(
            store=dataset_store,
            text_field="prompt",
            category_field="prompt_harm_label"
        )
    """

    if store is None:
        raise ValueError("store cannot be None")

    if loading_strategy == LoadingStrategy.STREAMING:
        raise ValueError("STREAMING loading strategy not supported for from_disk(). Use MEMORY or DISK.")

    dataset_dir = Path(store.base_path) / store.dataset_prefix

    if not dataset_dir.exists():
        raise FileNotFoundError(
            f"Dataset directory not found: {dataset_dir}. "
            f"Make sure you've previously saved a dataset to this store location."
        )

    # Verify it's a valid Arrow dataset directory
    arrow_files = list(dataset_dir.glob("*.arrow"))
    if not arrow_files:
        raise FileNotFoundError(
            f"No Arrow files found in {dataset_dir}. Directory exists but doesn't contain a valid dataset."
        )

    try:
        use_memory_mapping = loading_strategy == LoadingStrategy.DISK
        ds = load_from_disk(str(dataset_dir), keep_in_memory=not use_memory_mapping)
    except Exception as e:
        raise RuntimeError(f"Failed to load dataset from {dataset_dir}. Error: {e}") from e

    # Create ClassificationDataset with the loaded dataset and field names
    return cls(
        ds,
        store=store,
        loading_strategy=loading_strategy,
        text_field=text_field,
        category_field=category_field,
    )

from_huggingface classmethod

from_huggingface(repo_id, store, *, split='train', loading_strategy=LoadingStrategy.MEMORY, revision=None, text_field='text', category_field='category', filters=None, limit=None, stratify_by=None, stratify_seed=None, streaming=None, drop_na=False, **kwargs)

Load classification dataset from HuggingFace Hub.

Parameters:

  repo_id (str): HuggingFace dataset repository ID. Required.
  store (Store): Store instance. Required.
  split (str): Dataset split. Default: 'train'.
  loading_strategy (LoadingStrategy): Loading strategy. Default: MEMORY.
  revision (Optional[str]): Optional git revision. Default: None.
  text_field (str): Name of the column containing text. Default: 'text'.
  category_field (Union[str, List[str]]): Name(s) of the column(s) containing category/label. Default: 'category'.
  filters (Optional[Dict[str, Any]]): Optional filters to apply (dict of column: value). Default: None.
  limit (Optional[int]): Optional limit on number of rows. Default: None.
  stratify_by (Optional[str]): Optional column used for stratified sampling (non-streaming only). Default: None.
  stratify_seed (Optional[int]): Optional RNG seed for stratified sampling. Default: None.
  streaming (Optional[bool]): Optional override for streaming. Default: None.
  drop_na (bool): Whether to drop rows with None/empty text or categories. Default: False.
  **kwargs (Any): Additional arguments for load_dataset.

Returns:

  ClassificationDataset: ClassificationDataset instance.

Raises:

  ValueError: If parameters are invalid.
  RuntimeError: If dataset loading fails.

Source code in src/mi_crow/datasets/classification_dataset.py
@classmethod
def from_huggingface(
    cls,
    repo_id: str,
    store: Store,
    *,
    split: str = "train",
    loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
    revision: Optional[str] = None,
    text_field: str = "text",
    category_field: Union[str, List[str]] = "category",
    filters: Optional[Dict[str, Any]] = None,
    limit: Optional[int] = None,
    stratify_by: Optional[str] = None,
    stratify_seed: Optional[int] = None,
    streaming: Optional[bool] = None,
    drop_na: bool = False,
    **kwargs: Any,
) -> "ClassificationDataset":
    """
    Load classification dataset from HuggingFace Hub.

    Args:
        repo_id: HuggingFace dataset repository ID
        store: Store instance
        split: Dataset split
        loading_strategy: Loading strategy
        revision: Optional git revision
        text_field: Name of the column containing text
        category_field: Name(s) of the column(s) containing category/label
        filters: Optional filters to apply (dict of column: value)
        limit: Optional limit on number of rows
        stratify_by: Optional column used for stratified sampling (non-streaming only)
        stratify_seed: Optional RNG seed for stratified sampling
        streaming: Optional override for streaming
        drop_na: Whether to drop rows with None/empty text or categories
        **kwargs: Additional arguments for load_dataset

    Returns:
        ClassificationDataset instance

    Raises:
        ValueError: If parameters are invalid
        RuntimeError: If dataset loading fails
    """
    use_streaming = streaming if streaming is not None else (loading_strategy == LoadingStrategy.STREAMING)

    if (stratify_by or drop_na) and use_streaming:
        raise NotImplementedError(
            "Stratification and drop_na are not supported for streaming datasets. Use MEMORY or DISK."
        )

    try:
        ds = load_dataset(
            path=repo_id,
            split=split,
            revision=revision,
            streaming=use_streaming,
            **kwargs,
        )

        if use_streaming:
            if filters or limit:
                raise NotImplementedError(
                    "filters and limit are not supported when streaming datasets. Choose MEMORY or DISK."
                )
        else:
            drop_na_columns = None
            if drop_na:
                cat_fields = [category_field] if isinstance(category_field, str) else category_field
                drop_na_columns = [text_field] + list(cat_fields)

            ds = cls._postprocess_non_streaming_dataset(
                ds,
                filters=filters,
                limit=limit,
                stratify_by=stratify_by,
                stratify_seed=stratify_seed,
                drop_na_columns=drop_na_columns,
            )
    except Exception as e:
        raise RuntimeError(
            f"Failed to load classification dataset from HuggingFace Hub: "
            f"repo_id={repo_id!r}, split={split!r}, text_field={text_field!r}, "
            f"category_field={category_field!r}. Error: {e}"
        ) from e

    return cls(
        ds,
        store=store,
        loading_strategy=loading_strategy,
        text_field=text_field,
        category_field=category_field,
    )
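
A sketch building on the from_disk example above: the same wildguardmix load with a row limit, stratification, and drop_na (the LocalStore import path is an assumption):

from mi_crow.datasets import ClassificationDataset
from mi_crow.store import LocalStore  # assumed import path

store = LocalStore("store/wgmix_test")

dataset = ClassificationDataset.from_huggingface(
    "allenai/wildguardmix",
    store=store,
    text_field="prompt",
    category_field="prompt_harm_label",
    limit=100,
    stratify_by="prompt_harm_label",
    stratify_seed=0,
    drop_na=True,
)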

from_json classmethod

from_json(source, store, *, loading_strategy=LoadingStrategy.MEMORY, text_field='text', category_field='category', stratify_by=None, stratify_seed=None, drop_na=False, **kwargs)

Load classification dataset from JSON/JSONL file.

Parameters:

  source (Union[str, Path]): Path to JSON or JSONL file. Required.
  store (Store): Store instance. Required.
  loading_strategy (LoadingStrategy): Loading strategy. Default: MEMORY.
  text_field (str): Name of the field containing text. Default: 'text'.
  category_field (Union[str, List[str]]): Name(s) of the field(s) containing category/label. Default: 'category'.
  stratify_by (Optional[str]): Optional column used for stratified sampling. Default: None.
  stratify_seed (Optional[int]): Optional RNG seed for stratified sampling. Default: None.
  drop_na (bool): Whether to drop rows with None/empty text or categories. Default: False.
  **kwargs (Any): Additional arguments for load_dataset.

Returns:

  ClassificationDataset: ClassificationDataset instance.

Raises:

  FileNotFoundError: If JSON file doesn't exist.
  RuntimeError: If dataset loading fails.

Source code in src/mi_crow/datasets/classification_dataset.py
@classmethod
def from_json(
    cls,
    source: Union[str, Path],
    store: Store,
    *,
    loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
    text_field: str = "text",
    category_field: Union[str, List[str]] = "category",
    stratify_by: Optional[str] = None,
    stratify_seed: Optional[int] = None,
    drop_na: bool = False,
    **kwargs: Any,
) -> "ClassificationDataset":
    """
    Load classification dataset from JSON/JSONL file.

    Args:
        source: Path to JSON or JSONL file
        store: Store instance
        loading_strategy: Loading strategy
        text_field: Name of the field containing text
        category_field: Name(s) of the field(s) containing category/label
        stratify_by: Optional column used for stratified sampling
        stratify_seed: Optional RNG seed for stratified sampling
        drop_na: Whether to drop rows with None/empty text or categories
        **kwargs: Additional arguments for load_dataset

    Returns:
        ClassificationDataset instance

    Raises:
        FileNotFoundError: If JSON file doesn't exist
        RuntimeError: If dataset loading fails
    """
    if store is None:
        raise ValueError("store cannot be None")

    use_streaming = loading_strategy == LoadingStrategy.STREAMING
    if (stratify_by or drop_na) and use_streaming:
        raise NotImplementedError("Stratification and drop_na are not supported for STREAMING datasets.")

    # Load JSON using parent's static method
    ds = cls._load_json_source(
        source,
        streaming=use_streaming,
        **kwargs,
    )

    # Apply postprocessing if not streaming
    if not use_streaming and (stratify_by or drop_na):
        drop_na_columns = None
        if drop_na:
            cat_fields = [category_field] if isinstance(category_field, str) else category_field
            drop_na_columns = [text_field] + list(cat_fields)

        ds = cls._postprocess_non_streaming_dataset(
            ds,
            stratify_by=stratify_by,
            stratify_seed=stratify_seed,
            drop_na_columns=drop_na_columns,
        )

    return cls(
        ds,
        store=store,
        loading_strategy=loading_strategy,
        text_field=text_field,
        category_field=category_field,
    )

get_all_texts

get_all_texts()

Get all texts from the dataset.

Returns:

  List[Optional[str]]: List of all text strings.

Raises:

  NotImplementedError: If loading_strategy is STREAMING and dataset is very large.

Source code in src/mi_crow/datasets/classification_dataset.py
def get_all_texts(self) -> List[Optional[str]]:
    """Get all texts from the dataset.

    Returns:
        List of all text strings

    Raises:
        NotImplementedError: If loading_strategy is STREAMING and dataset is very large
    """
    if self._loading_strategy == LoadingStrategy.STREAMING:
        return [item["text"] for item in self.iter_items()]
    return list(self._ds[self._text_field])

get_categories

get_categories()

Get unique categories in the dataset, excluding None values.

Returns:

  Union[List[Any], Dict[str, List[Any]]]:
    • For a single label column: list of unique category values
    • For multiple label columns: dict mapping column name to list of unique categories

Raises:

  NotImplementedError: If loading_strategy is STREAMING and dataset is large.

Source code in src/mi_crow/datasets/classification_dataset.py
def get_categories(self) -> Union[List[Any], Dict[str, List[Any]]]:  # noqa: C901
    """
    Get unique categories in the dataset, excluding None values.

    Returns:
        - For single label column: List of unique category values
        - For multiple label columns: Dict mapping column name to list of unique categories

    Raises:
        NotImplementedError: If loading_strategy is STREAMING and dataset is large
    """
    if len(self._category_fields) == 1:
        # Single label: return list for backward compatibility
        cat_field = self._category_fields[0]
        if self._loading_strategy == LoadingStrategy.STREAMING:
            categories = set()
            for item in self.iter_items():
                cat = item[cat_field]
                if cat is not None:
                    categories.add(cat)
            return sorted(list(categories))  # noqa: C414
        categories = [cat for cat in set(self._ds[cat_field]) if cat is not None]
        return sorted(categories)
    else:
        # Multiple labels: return dict
        result = {}
        if self._loading_strategy == LoadingStrategy.STREAMING:
            # Collect categories from all items
            category_sets = {field: set() for field in self._category_fields}
            for item in self.iter_items():
                for field in self._category_fields:
                    cat = item[field]
                    if cat is not None:
                        category_sets[field].add(cat)
            for field in self._category_fields:
                result[field] = sorted(list(category_sets[field]))  # noqa: C414
        else:
            # Use direct column access
            for field in self._category_fields:
                categories = [cat for cat in set(self._ds[field]) if cat is not None]
                result[field] = sorted(categories)
        return result
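
The return shape depends on how many label columns were configured; a sketch (the concrete category values are illustrative only):

cats = dataset.get_categories()
# single label column    -> e.g. ["harmful", "unharmful"]
# multiple label columns -> e.g. {"topic": [...], "priority": [...]}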

get_categories_for_texts

get_categories_for_texts(texts)

Get categories for given texts (if texts match dataset texts).

Parameters:

  texts (List[Optional[str]]): List of text strings to look up. Required.

Returns:

  Union[List[Any], List[Dict[str, Any]]]:
    • For a single label column: list of category values (one per text)
    • For multiple label columns: list of dicts with label columns as keys

Raises:

  NotImplementedError: If loading_strategy is STREAMING.
  ValueError: If texts list is empty.

Source code in src/mi_crow/datasets/classification_dataset.py
def get_categories_for_texts(self, texts: List[Optional[str]]) -> Union[List[Any], List[Dict[str, Any]]]:
    """
    Get categories for given texts (if texts match dataset texts).

    Args:
        texts: List of text strings to look up

    Returns:
        - For single label column: List of category values (one per text)
        - For multiple label columns: List of dicts with label columns as keys

    Raises:
        NotImplementedError: If loading_strategy is STREAMING
        ValueError: If texts list is empty
    """
    if self._loading_strategy == LoadingStrategy.STREAMING:
        raise NotImplementedError("get_categories_for_texts not supported for STREAMING datasets")

    if not texts:
        raise ValueError("texts list cannot be empty")

    if len(self._category_fields) == 1:
        # Single label: return list for backward compatibility
        cat_field = self._category_fields[0]
        text_to_category = {row[self._text_field]: row[cat_field] for row in self._ds}
        return [text_to_category.get(text) for text in texts]
    else:
        # Multiple labels: return list of dicts
        text_to_categories = {
            row[self._text_field]: {field: row[field] for field in self._category_fields} for row in self._ds
        }
        return [text_to_categories.get(text) for text in texts]

iter_batches

iter_batches(batch_size)

Iterate over items in batches. Each batch is a list of dicts with 'text' and label column(s) as keys.

For single label: [{"text": "...", "category_column_1": "..."}, ...]
For multiple labels: [{"text": "...", "category_column_1": "...", "category_column_2": "..."}, ...]

Parameters:

  batch_size (int): Number of items per batch. Required.

Yields:

  List[Dict[str, Any]]: Lists of item dictionaries (batches).

Raises:

  ValueError: If batch_size <= 0 or required fields are not found in any row.

Source code in src/mi_crow/datasets/classification_dataset.py
def iter_batches(self, batch_size: int) -> Iterator[List[Dict[str, Any]]]:
    """
    Iterate over items in batches. Each batch is a list of dicts with 'text' and label column(s) as keys.

    For single label: [{"text": "...", "category_column_1": "..."}, ...]
    For multiple labels: [{"text": "...", "category_column_1": "...", "category_column_2": "..."}, ...]

    Args:
        batch_size: Number of items per batch

    Yields:
        Lists of item dictionaries (batches)

    Raises:
        ValueError: If batch_size <= 0 or required fields are not found in any row
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be > 0, got: {batch_size}")

    if self._loading_strategy == LoadingStrategy.STREAMING:
        batch = []
        for row in self._ds:
            batch.append(self._extract_item_from_row(row))
            if len(batch) >= batch_size:
                yield batch
                batch = []
        if batch:
            yield batch
    else:
        # Use select to get batches with proper format
        for i in range(0, len(self), batch_size):
            end = min(i + batch_size, len(self))
            batch_list = self[i:end]
            yield batch_list
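
A short sketch of batched iteration (same clf_ds assumption as above; the dict keys depend on the dataset's label columns, so "category_column_1" is a placeholder):

for batch in clf_ds.iter_batches(batch_size=32):
    # batch is a list of dicts, e.g. {"text": "...", "category_column_1": "..."}
    texts = [item["text"] for item in batch]
    print(len(texts), texts[0])
    break  # inspect only the first batch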

iter_items

iter_items()

Iterate over items one by one. Yields dict with 'text' and label column(s) as keys.

For single label: {"text": "...", "category_column_1": "..."}
For multiple labels: {"text": "...", "category_column_1": "...", "category_column_2": "..."}

Yields:

Type Description
Dict[str, Any]

Item dictionaries with text and category fields

Raises:

Type Description
ValueError

If required fields are not found in any row

Source code in src/mi_crow/datasets/classification_dataset.py (lines 192-206)
def iter_items(self) -> Iterator[Dict[str, Any]]:
    """
    Iterate over items one by one. Yields dict with 'text' and label column(s) as keys.

    For single label: {"text": "...", "category_column_1": "..."}
    For multiple labels: {"text": "...", "category_column_1": "...", "category_column_2": "..."}

    Yields:
        Item dictionaries with text and category fields

    Raises:
        ValueError: If required fields are not found in any row
    """
    for row in self._ds:
        yield self._extract_item_from_row(row)
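
And item-wise iteration, under the same assumptions:

for item in clf_ds.iter_items():
    print(item["text"], item.get("category_column_1"))
    break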

LoadingStrategy

Bases: Enum

Strategy for loading dataset data.

Choose the best strategy for your use case:

  • MEMORY: Load entire dataset into memory (fastest random access, highest memory usage). Best for small datasets that fit in memory and workloads that need fast random access.

  • DISK: Save to disk and read dynamically via memory-mapped Arrow files (supports len/getitem, lower memory usage). Best for large datasets that don't fit in memory but still need random access.

  • STREAMING: True streaming mode using IterableDataset (lowest memory, no len/getitem support). Best for very large datasets where only sequential iteration is needed.
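
As an illustrative sketch (the LoadingStrategy import path is assumed to be mi_crow.datasets; adjust if your install exports it elsewhere):

from mi_crow.datasets import TextDataset, LoadingStrategy
from mi_crow.store import LocalStore

store = LocalStore(base_path="./store")

# DISK keeps memory usage low while still supporting len() and indexing
dataset = TextDataset.from_huggingface(
    "roneneldan/TinyStories",
    store=store,
    loading_strategy=LoadingStrategy.DISK,
)
print(len(dataset), dataset[0][:50])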

TextDataset

TextDataset(ds, store, loading_strategy=LoadingStrategy.DISK, text_field='text')

Bases: BaseDataset

Text-only dataset with support for multiple sources and loading strategies. Each item is a string (text snippet).

Initialize text dataset.

Parameters:

Name Type Description Default
ds Dataset | IterableDataset

HuggingFace Dataset or IterableDataset

required
store Store

Store instance

required
loading_strategy LoadingStrategy

Loading strategy

DISK
text_field str

Name of the column containing text

'text'

Raises:

Type Description
ValueError

If text_field is empty or not found in dataset

Source code in src/mi_crow/datasets/text_dataset.py (lines 20-55)
def __init__(
    self,
    ds: Dataset | IterableDataset,
    store: Store,
    loading_strategy: LoadingStrategy = LoadingStrategy.DISK,
    text_field: str = "text",
):
    """
    Initialize text dataset.

    Args:
        ds: HuggingFace Dataset or IterableDataset
        store: Store instance
        loading_strategy: Loading strategy
        text_field: Name of the column containing text

    Raises:
        ValueError: If text_field is empty or not found in dataset
    """
    self._validate_text_field(text_field)

    # Validate and prepare dataset
    is_iterable = isinstance(ds, IterableDataset)
    if not is_iterable:
        if text_field not in ds.column_names:
            raise ValueError(f"Dataset must have a '{text_field}' column; got columns: {ds.column_names}")
        # Keep only text column for memory efficiency
        columns_to_remove = [c for c in ds.column_names if c != text_field]
        if columns_to_remove:
            ds = ds.remove_columns(columns_to_remove)
        if text_field != "text":
            ds = ds.rename_column(text_field, "text")
        ds.set_format("python", columns=["text"])

    self._text_field = text_field
    super().__init__(ds, store=store, loading_strategy=loading_strategy)
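
A small sketch of constructing a TextDataset directly from an in-memory HuggingFace Dataset (the "content" column name is hypothetical; it is renamed to "text" internally):

from datasets import Dataset
from mi_crow.datasets import TextDataset
from mi_crow.store import LocalStore

hf_ds = Dataset.from_dict({"content": ["first snippet", "second snippet"]})
dataset = TextDataset(
    hf_ds,
    store=LocalStore(base_path="./store"),
    text_field="content",
)
print(len(dataset), dataset[0])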

__getitem__

__getitem__(idx)

Get text item(s) by index.

Parameters:

Name Type Description Default
idx IndexLike

Index (int), slice, or sequence of indices

required

Returns:

Type Description
Union[Optional[str], List[Optional[str]]]

Single text string or list of text strings

Raises:

Type Description
NotImplementedError

If loading_strategy is STREAMING

IndexError

If index is out of bounds

ValueError

If dataset is empty

Source code in src/mi_crow/datasets/text_dataset.py (lines 103-149)
def __getitem__(self, idx: IndexLike) -> Union[Optional[str], List[Optional[str]]]:
    """
    Get text item(s) by index.

    Args:
        idx: Index (int), slice, or sequence of indices

    Returns:
        Single text string or list of text strings

    Raises:
        NotImplementedError: If loading_strategy is STREAMING
        IndexError: If index is out of bounds
        ValueError: If dataset is empty
    """
    if self._loading_strategy == LoadingStrategy.STREAMING:
        raise NotImplementedError("Indexing not supported for STREAMING datasets. Use iter_items or iter_batches.")

    dataset_len = len(self)
    if dataset_len == 0:
        raise ValueError("Cannot index into empty dataset")

    if isinstance(idx, int):
        if idx < 0:
            idx = dataset_len + idx
        if idx < 0 or idx >= dataset_len:
            raise IndexError(f"Index {idx} out of bounds for dataset of length {dataset_len}")
        return self._ds[idx]["text"]

    if isinstance(idx, slice):
        start, stop, step = idx.indices(dataset_len)
        if step != 1:
            indices = list(range(start, stop, step))
            out = self._ds.select(indices)["text"]
        else:
            out = self._ds.select(range(start, stop))["text"]
        return list(out)

    if isinstance(idx, Sequence):
        # Validate all indices are in bounds
        invalid_indices = [i for i in idx if not (0 <= i < dataset_len)]
        if invalid_indices:
            raise IndexError(f"Indices out of bounds: {invalid_indices} (dataset length: {dataset_len})")
        out = self._ds.select(list(idx))["text"]
        return list(out)

    raise TypeError(f"Invalid index type: {type(idx)}")
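
Indexing then accepts ints, slices, and sequences of indices (dataset as constructed above):

first = dataset[0]        # single text string
last = dataset[-1]        # negative int indices are supported
window = dataset[0:2]     # slice -> list of texts
picked = dataset[[0, 1]]  # explicit index sequence -> list of texts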

__len__

__len__()

Return the number of items in the dataset.

Raises:

Type Description
NotImplementedError

If loading_strategy is STREAMING

Source code in src/mi_crow/datasets/text_dataset.py (lines 92-101)
def __len__(self) -> int:
    """
    Return the number of items in the dataset.

    Raises:
        NotImplementedError: If loading_strategy is STREAMING
    """
    if self._loading_strategy == LoadingStrategy.STREAMING:
        raise NotImplementedError("len() not supported for STREAMING datasets")
    return self._ds.num_rows

extract_texts_from_batch

extract_texts_from_batch(batch)

Extract text strings from a batch.

For TextDataset, batch items are already strings, so return as-is.

Parameters:

Name Type Description Default
batch List[Optional[str]]

List of text strings

required

Returns:

Type Description
List[Optional[str]]

List of text strings (same as input)

Source code in src/mi_crow/datasets/text_dataset.py (lines 193-204)
def extract_texts_from_batch(self, batch: List[Optional[str]]) -> List[Optional[str]]:
    """Extract text strings from a batch.

    For TextDataset, batch items are already strings, so return as-is.

    Args:
        batch: List of text strings

    Returns:
        List of text strings (same as input)
    """
    return batch

from_csv classmethod

from_csv(source, store, *, loading_strategy=LoadingStrategy.MEMORY, text_field='text', delimiter=',', stratify_by=None, stratify_seed=None, drop_na=False, **kwargs)

Load text dataset from CSV file.

Parameters:

Name Type Description Default
source Union[str, Path]

Path to CSV file

required
store Store

Store instance

required
loading_strategy LoadingStrategy

Loading strategy

MEMORY
text_field str

Name of the column containing text

'text'
delimiter str

CSV delimiter (default: comma)

','
stratify_by Optional[str]

Optional column to use for stratified sampling

None
stratify_seed Optional[int]

Optional RNG seed for stratified sampling

None
drop_na bool

Whether to drop rows with None/empty text

False
**kwargs Any

Additional arguments for load_dataset

{}

Returns:

Type Description
'TextDataset'

TextDataset instance

Raises:

Type Description
FileNotFoundError

If CSV file doesn't exist

RuntimeError

If dataset loading fails

Source code in src/mi_crow/datasets/text_dataset.py (lines 419-484)
@classmethod
def from_csv(
    cls,
    source: Union[str, Path],
    store: Store,
    *,
    loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
    text_field: str = "text",
    delimiter: str = ",",
    stratify_by: Optional[str] = None,
    stratify_seed: Optional[int] = None,
    drop_na: bool = False,
    **kwargs: Any,
) -> "TextDataset":
    """
    Load text dataset from CSV file.

    Args:
        source: Path to CSV file
        store: Store instance
        loading_strategy: Loading strategy
        text_field: Name of the column containing text
        delimiter: CSV delimiter (default: comma)
        stratify_by: Optional column to use for stratified sampling
        stratify_seed: Optional RNG seed for stratified sampling
        drop_na: Whether to drop rows with None/empty text
        **kwargs: Additional arguments for load_dataset

    Returns:
        TextDataset instance

    Raises:
        FileNotFoundError: If CSV file doesn't exist
        RuntimeError: If dataset loading fails
    """
    if store is None:
        raise ValueError("store cannot be None")

    use_streaming = loading_strategy == LoadingStrategy.STREAMING
    if (stratify_by or drop_na) and use_streaming:
        raise NotImplementedError("Stratification and drop_na are not supported for STREAMING datasets.")

    # Load CSV using parent's static method
    ds = cls._load_csv_source(
        source,
        delimiter=delimiter,
        streaming=use_streaming,
        **kwargs,
    )

    # Apply postprocessing if not streaming
    if not use_streaming and (stratify_by or drop_na):
        drop_na_columns = [text_field] if drop_na else None
        ds = cls._postprocess_non_streaming_dataset(
            ds,
            stratify_by=stratify_by,
            stratify_seed=stratify_seed,
            drop_na_columns=drop_na_columns,
        )

    return cls(
        ds,
        store=store,
        loading_strategy=loading_strategy,
        text_field=text_field,
    )
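
A hedged sketch of loading a local CSV (the file path and column name are placeholders):

from mi_crow.datasets import TextDataset
from mi_crow.store import LocalStore

dataset = TextDataset.from_csv(
    "data/snippets.csv",
    store=LocalStore(base_path="./store"),
    text_field="text",
    drop_na=True,
)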

from_disk classmethod

from_disk(store, *, loading_strategy=LoadingStrategy.MEMORY, text_field='text')

Load text dataset from already-saved Arrow files on disk.

Use this when you've previously saved a dataset and want to reload it without re-downloading from HuggingFace or re-applying transformations.

Parameters:

Name Type Description Default
store Store

Store instance pointing to where the dataset was saved

required
loading_strategy LoadingStrategy

Loading strategy (MEMORY or DISK only)

MEMORY
text_field str

Name of the column containing text

'text'

Returns:

Type Description
'TextDataset'

TextDataset instance loaded from disk

Raises:

Type Description
FileNotFoundError

If dataset directory doesn't exist or contains no Arrow files

Example

# First: save the dataset
dataset_store = LocalStore("store/my_texts")
dataset = TextDataset.from_huggingface(
    "wikipedia",
    store=dataset_store,
    limit=1000,
)
# Dataset saved to: store/my_texts/datasets/*.arrow

# Later: reload from disk
dataset_store = LocalStore("store/my_texts")
dataset = TextDataset.from_disk(store=dataset_store)

Source code in src/mi_crow/datasets/text_dataset.py (lines 344-417)
@classmethod
def from_disk(
    cls,
    store: Store,
    *,
    loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
    text_field: str = "text",
) -> "TextDataset":
    """
    Load text dataset from already-saved Arrow files on disk.

    Use this when you've previously saved a dataset and want to reload it
    without re-downloading from HuggingFace or re-applying transformations.

    Args:
        store: Store instance pointing to where the dataset was saved
        loading_strategy: Loading strategy (MEMORY or DISK only)
        text_field: Name of the column containing text

    Returns:
        TextDataset instance loaded from disk

    Raises:
        FileNotFoundError: If dataset directory doesn't exist or contains no Arrow files

    Example:
        # First: save dataset
        dataset_store = LocalStore("store/my_texts")
        dataset = TextDataset.from_huggingface(
            "wikipedia",
            store=dataset_store,
            limit=1000
        )
        # Dataset saved to: store/my_texts/datasets/*.arrow

        # Later: reload from disk
        dataset_store = LocalStore("store/my_texts")
        dataset = TextDataset.from_disk(store=dataset_store)
    """

    if store is None:
        raise ValueError("store cannot be None")

    if loading_strategy == LoadingStrategy.STREAMING:
        raise ValueError("STREAMING loading strategy not supported for from_disk(). Use MEMORY or DISK.")

    dataset_dir = Path(store.base_path) / store.dataset_prefix

    if not dataset_dir.exists():
        raise FileNotFoundError(
            f"Dataset directory not found: {dataset_dir}. "
            f"Make sure you've previously saved a dataset to this store location."
        )

    # Verify it's a valid Arrow dataset directory
    arrow_files = list(dataset_dir.glob("*.arrow"))
    if not arrow_files:
        raise FileNotFoundError(
            f"No Arrow files found in {dataset_dir}. Directory exists but doesn't contain a valid dataset."
        )

    try:
        use_memory_mapping = loading_strategy == LoadingStrategy.DISK
        ds = load_from_disk(str(dataset_dir), keep_in_memory=not use_memory_mapping)
    except Exception as e:
        raise RuntimeError(f"Failed to load dataset from {dataset_dir}. Error: {e}") from e

    # Create TextDataset with the loaded dataset and field name
    return cls(
        ds,
        store=store,
        loading_strategy=loading_strategy,
        text_field=text_field,
    )

from_huggingface classmethod

from_huggingface(repo_id, store, *, split='train', loading_strategy=LoadingStrategy.MEMORY, revision=None, text_field='text', filters=None, limit=None, stratify_by=None, stratify_seed=None, streaming=None, drop_na=False, **kwargs)

Load text dataset from HuggingFace Hub.

Parameters:

Name Type Description Default
repo_id str

HuggingFace dataset repository ID

required
store Store

Store instance

required
split str

Dataset split

'train'
loading_strategy LoadingStrategy

Loading strategy

MEMORY
revision Optional[str]

Optional git revision

None
text_field str

Name of the column containing text

'text'
filters Optional[Dict[str, Any]]

Optional filters to apply (dict of column: value)

None
limit Optional[int]

Optional limit on number of rows

None
stratify_by Optional[str]

Optional column used for stratified sampling (non-streaming only)

None
stratify_seed Optional[int]

Optional RNG seed for deterministic stratification

None
streaming Optional[bool]

Optional override for streaming

None
drop_na bool

Whether to drop rows with None/empty text

False
**kwargs Any

Additional arguments for load_dataset

{}

Returns:

Type Description
'TextDataset'

TextDataset instance

Raises:

Type Description
ValueError

If parameters are invalid

RuntimeError

If dataset loading fails

Source code in src/mi_crow/datasets/text_dataset.py (lines 261-342)
@classmethod
def from_huggingface(
    cls,
    repo_id: str,
    store: Store,
    *,
    split: str = "train",
    loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
    revision: Optional[str] = None,
    text_field: str = "text",
    filters: Optional[Dict[str, Any]] = None,
    limit: Optional[int] = None,
    stratify_by: Optional[str] = None,
    stratify_seed: Optional[int] = None,
    streaming: Optional[bool] = None,
    drop_na: bool = False,
    **kwargs: Any,
) -> "TextDataset":
    """
    Load text dataset from HuggingFace Hub.

    Args:
        repo_id: HuggingFace dataset repository ID
        store: Store instance
        split: Dataset split
        loading_strategy: Loading strategy
        revision: Optional git revision
        text_field: Name of the column containing text
        filters: Optional filters to apply (dict of column: value)
        limit: Optional limit on number of rows
        stratify_by: Optional column used for stratified sampling (non-streaming only)
        stratify_seed: Optional RNG seed for deterministic stratification
        streaming: Optional override for streaming
        drop_na: Whether to drop rows with None/empty text
        **kwargs: Additional arguments for load_dataset

    Returns:
        TextDataset instance

    Raises:
        ValueError: If parameters are invalid
        RuntimeError: If dataset loading fails
    """
    use_streaming = streaming if streaming is not None else (loading_strategy == LoadingStrategy.STREAMING)

    if (stratify_by or drop_na) and use_streaming:
        raise NotImplementedError(
            "Stratification and drop_na are not supported for streaming datasets. Use MEMORY or DISK."
        )

    try:
        ds = load_dataset(
            path=repo_id,
            split=split,
            revision=revision,
            streaming=use_streaming,
            **kwargs,
        )

        if use_streaming:
            if filters or limit:
                raise NotImplementedError(
                    "filters and limit are not supported when streaming datasets. Choose MEMORY or DISK."
                )
        else:
            drop_na_columns = [text_field] if drop_na else None
            ds = cls._postprocess_non_streaming_dataset(
                ds,
                filters=filters,
                limit=limit,
                stratify_by=stratify_by,
                stratify_seed=stratify_seed,
                drop_na_columns=drop_na_columns,
            )
    except Exception as e:
        raise RuntimeError(
            f"Failed to load text dataset from HuggingFace Hub: "
            f"repo_id={repo_id!r}, split={split!r}, text_field={text_field!r}. "
            f"Error: {e}"
        ) from e

    return cls(ds, store=store, loading_strategy=loading_strategy, text_field=text_field)
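
For illustration, a non-streaming load with a row limit and a simple column filter (the repo id, filter column, and value are placeholders):

from mi_crow.datasets import TextDataset
from mi_crow.store import LocalStore

dataset = TextDataset.from_huggingface(
    "user/some-text-dataset",
    store=LocalStore(base_path="./store"),
    split="train",
    limit=5000,
    filters={"lang": "en"},
    drop_na=True,
)
print(len(dataset))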

from_json classmethod

from_json(source, store, *, loading_strategy=LoadingStrategy.MEMORY, text_field='text', stratify_by=None, stratify_seed=None, drop_na=False, **kwargs)

Load text dataset from JSON/JSONL file.

Parameters:

Name Type Description Default
source Union[str, Path]

Path to JSON or JSONL file

required
store Store

Store instance

required
loading_strategy LoadingStrategy

Loading strategy

MEMORY
text_field str

Name of the field containing text

'text'
stratify_by Optional[str]

Optional column to use for stratified sampling

None
stratify_seed Optional[int]

Optional RNG seed for stratified sampling

None
drop_na bool

Whether to drop rows with None/empty text

False
**kwargs Any

Additional arguments for load_dataset

{}

Returns:

Type Description
'TextDataset'

TextDataset instance

Raises:

Type Description
FileNotFoundError

If JSON file doesn't exist

RuntimeError

If dataset loading fails

Source code in src/mi_crow/datasets/text_dataset.py (lines 486-548)
@classmethod
def from_json(
    cls,
    source: Union[str, Path],
    store: Store,
    *,
    loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
    text_field: str = "text",
    stratify_by: Optional[str] = None,
    stratify_seed: Optional[int] = None,
    drop_na: bool = False,
    **kwargs: Any,
) -> "TextDataset":
    """
    Load text dataset from JSON/JSONL file.

    Args:
        source: Path to JSON or JSONL file
        store: Store instance
        loading_strategy: Loading strategy
        text_field: Name of the field containing text
        stratify_by: Optional column to use for stratified sampling
        stratify_seed: Optional RNG seed for stratified sampling
        drop_na: Whether to drop rows with None/empty text
        **kwargs: Additional arguments for load_dataset

    Returns:
        TextDataset instance

    Raises:
        FileNotFoundError: If JSON file doesn't exist
        RuntimeError: If dataset loading fails
    """
    if store is None:
        raise ValueError("store cannot be None")

    use_streaming = loading_strategy == LoadingStrategy.STREAMING
    if (stratify_by or drop_na) and use_streaming:
        raise NotImplementedError("Stratification and drop_na are not supported for STREAMING datasets.")

    # Load JSON using parent's static method
    ds = cls._load_json_source(
        source,
        streaming=use_streaming,
        **kwargs,
    )

    # Apply postprocessing if not streaming
    if not use_streaming and (stratify_by or drop_na):
        drop_na_columns = [text_field] if drop_na else None
        ds = cls._postprocess_non_streaming_dataset(
            ds,
            stratify_by=stratify_by,
            stratify_seed=stratify_seed,
            drop_na_columns=drop_na_columns,
        )

    return cls(
        ds,
        store=store,
        loading_strategy=loading_strategy,
        text_field=text_field,
    )
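
A similar sketch for a JSONL file (path and field name are placeholders):

from mi_crow.datasets import TextDataset
from mi_crow.store import LocalStore

dataset = TextDataset.from_json(
    "data/snippets.jsonl",
    store=LocalStore(base_path="./store"),
    text_field="body",
)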

from_local classmethod

from_local(source, store, *, loading_strategy=LoadingStrategy.MEMORY, text_field='text', recursive=True)

Load from a local directory or file(s).

Supported
  • Directory of .txt files (each file becomes one example)
  • JSONL/JSON/CSV/TSV files with a text column

Parameters:

Name Type Description Default
source Union[str, Path]

Path to directory or file

required
store Store

Store instance

required
loading_strategy LoadingStrategy

Loading strategy

MEMORY
text_field str

Name of the column/field containing text

'text'
recursive bool

Whether to recursively search directories for .txt files

True

Returns:

Type Description
'TextDataset'

TextDataset instance

Raises:

Type Description
FileNotFoundError

If source path doesn't exist

ValueError

If source is invalid or unsupported file type

RuntimeError

If file operations fail

Source code in src/mi_crow/datasets/text_dataset.py (lines 550-629)
@classmethod
def from_local(
    cls,
    source: Union[str, Path],
    store: Store,
    *,
    loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
    text_field: str = "text",
    recursive: bool = True,
) -> "TextDataset":
    """
    Load from a local directory or file(s).

    Supported:
      - Directory of .txt files (each file becomes one example)
      - JSONL/JSON/CSV/TSV files with a text column

    Args:
        source: Path to directory or file
        store: Store instance
        loading_strategy: Loading strategy
        text_field: Name of the column/field containing text
        recursive: Whether to recursively search directories for .txt files

    Returns:
        TextDataset instance

    Raises:
        FileNotFoundError: If source path doesn't exist
        ValueError: If source is invalid or unsupported file type
        RuntimeError: If file operations fail
    """
    p = Path(source)
    if not p.exists():
        raise FileNotFoundError(f"Source path does not exist: {source}")

    if p.is_dir():
        txts: List[str] = []
        pattern = "**/*.txt" if recursive else "*.txt"
        try:
            for fp in sorted(p.glob(pattern)):
                txts.append(fp.read_text(encoding="utf-8", errors="ignore"))
        except OSError as e:
            raise RuntimeError(f"Failed to read text files from directory {source}. Error: {e}") from e

        if not txts:
            raise ValueError(f"No .txt files found in directory: {source} (recursive={recursive})")

        ds = Dataset.from_dict({"text": txts})
    else:
        suffix = p.suffix.lower()
        if suffix in {".jsonl", ".json"}:
            return cls.from_json(
                source,
                store=store,
                loading_strategy=loading_strategy,
                text_field=text_field,
            )
        elif suffix in {".csv"}:
            return cls.from_csv(
                source,
                store=store,
                loading_strategy=loading_strategy,
                text_field=text_field,
            )
        elif suffix in {".tsv"}:
            return cls.from_csv(
                source,
                store=store,
                loading_strategy=loading_strategy,
                text_field=text_field,
                delimiter="\t",
            )
        else:
            raise ValueError(
                f"Unsupported file type: {suffix} for source: {source}. "
                f"Use directory of .txt, or JSON/JSONL/CSV/TSV."
            )

    return cls(ds, store=store, loading_strategy=loading_strategy, text_field=text_field)
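
A sketch for a directory of .txt files (the directory path is a placeholder; each file becomes one example):

from mi_crow.datasets import TextDataset
from mi_crow.store import LocalStore

dataset = TextDataset.from_local(
    "data/notes/",
    store=LocalStore(base_path="./store"),
    recursive=True,
)
print(len(dataset))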

get_all_texts

get_all_texts()

Get all texts from the dataset.

Returns:

Type Description
List[Optional[str]]

List of all text strings

Note

For STREAMING datasets, this falls back to iterating over the entire stream to collect all texts, which can be slow for very large datasets.

Source code in src/mi_crow/datasets/text_dataset.py (lines 206-217)
def get_all_texts(self) -> List[Optional[str]]:
    """Get all texts from the dataset.

    Returns:
        List of all text strings

    Note:
        For STREAMING datasets, this falls back to iterating over the entire
        stream to collect all texts.
    """
    if self._loading_strategy == LoadingStrategy.STREAMING:
        return list(self.iter_items())
    return list(self._ds["text"])

iter_batches

iter_batches(batch_size)

Iterate over text items in batches.

Parameters:

Name Type Description Default
batch_size int

Number of items per batch

required

Yields:

Type Description
List[Optional[str]]

Lists of text strings (batches)

Raises:

Type Description
ValueError

If batch_size <= 0 or text field is not found in any row

Source code in src/mi_crow/datasets/text_dataset.py (lines 164-191)
def iter_batches(self, batch_size: int) -> Iterator[List[Optional[str]]]:
    """
    Iterate over text items in batches.

    Args:
        batch_size: Number of items per batch

    Yields:
        Lists of text strings (batches)

    Raises:
        ValueError: If batch_size <= 0 or text field is not found in any row
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be > 0, got: {batch_size}")

    if self._loading_strategy == LoadingStrategy.STREAMING:
        batch = []
        for row in self._ds:
            batch.append(self._extract_text_from_row(row))
            if len(batch) >= batch_size:
                yield batch
                batch = []
        if batch:
            yield batch
    else:
        for batch in self._ds.iter(batch_size=batch_size):
            yield list(batch["text"])
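
A short sketch of batched text iteration (dataset assumed to be a TextDataset loaded as above):

for texts in dataset.iter_batches(batch_size=16):
    # texts is a list of up to 16 strings
    print(len(texts), texts[0][:50])
    break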

iter_items

iter_items()

Iterate over text items one by one.

Yields:

Type Description
Optional[str]

Text strings from the dataset

Raises:

Type Description
ValueError

If text field is not found in any row

Source code in src/mi_crow/datasets/text_dataset.py (lines 151-162)
def iter_items(self) -> Iterator[Optional[str]]:
    """
    Iterate over text items one by one.

    Yields:
        Text strings from the dataset

    Raises:
        ValueError: If text field is not found in any row
    """
    for row in self._ds:
        yield self._extract_text_from_row(row)

random_sample

random_sample(n, seed=None)

Create a new TextDataset with n randomly sampled items.

Parameters:

Name Type Description Default
n int

Number of items to sample

required
seed Optional[int]

Optional random seed for reproducibility

None

Returns:

Type Description
'TextDataset'

New TextDataset instance with sampled items

Raises:

Type Description
NotImplementedError

If loading_strategy is STREAMING

ValueError

If n <= 0

Source code in src/mi_crow/datasets/text_dataset.py (lines 219-259)
def random_sample(self, n: int, seed: Optional[int] = None) -> "TextDataset":
    """Create a new TextDataset with n randomly sampled items.

    Args:
        n: Number of items to sample
        seed: Optional random seed for reproducibility

    Returns:
        New TextDataset instance with sampled items

    Raises:
        NotImplementedError: If loading_strategy is STREAMING
        ValueError: If n <= 0
    """
    if self._loading_strategy == LoadingStrategy.STREAMING:
        raise NotImplementedError(
            "random_sample() not supported for STREAMING datasets. Use iter_items() and sample manually."
        )

    if n <= 0:
        raise ValueError(f"n must be > 0, got: {n}")

    dataset_len = len(self)
    if n >= dataset_len:
        if seed is not None:
            random.seed(seed)
        indices = list(range(dataset_len))
        random.shuffle(indices)
        sampled_ds = self._ds.select(indices)
    else:
        if seed is not None:
            random.seed(seed)
        indices = random.sample(range(dataset_len), n)
        sampled_ds = self._ds.select(indices)

    return TextDataset(
        sampled_ds,
        store=self._store,
        loading_strategy=self._loading_strategy,
        text_field=self._text_field,
    )

TextDataset.random_sample()

The TextDataset.random_sample() method creates a new TextDataset instance with randomly sampled items from the original dataset. This is useful for creating smaller subsets of large datasets for testing or training.

Parameters

  • n (int): Number of items to sample. Must be greater than 0.
  • seed (Optional[int]): Optional random seed for reproducibility. If provided, ensures the same random sample is generated across runs.

Returns

A new TextDataset instance containing the randomly sampled items.

Example

from mi_crow.datasets import TextDataset
from mi_crow.store import LocalStore

store = LocalStore(base_path="./store")

# Load a large dataset
dataset = TextDataset.from_huggingface(
    "roneneldan/TinyStories",
    split="train",
    store=store,
    text_field="text"
)

print(f"Original dataset size: {len(dataset)}")  # e.g., 2119719

# Sample 1000 random items
sampled_dataset = dataset.random_sample(1000, seed=42)
print(f"Sampled dataset size: {len(sampled_dataset)}")  # 1000

# Use the sampled dataset downstream, e.g. for activation saving
# (lm is a model object created elsewhere in these docs)
run_id = lm.activations.save_activations_dataset(
    dataset=sampled_dataset,
    layer_signature="layer_0",
    batch_size=4
)

Notes

  • Works with MEMORY and DISK loading strategies only. Not supported for STREAMING datasets.
  • If n >= len(dataset), returns all items in random order.
  • The method preserves the original dataset's loading strategy, store, and text field configuration.
  • For reproducible results, always specify a seed parameter.