| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from functools import partial |
| from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple |
|
|
| from .processors.feedback import preprocess_feedback_dataset |
| from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example |
| from .processors.pretrain import preprocess_pretrain_dataset |
| from .processors.supervised import ( |
| preprocess_packed_supervised_dataset, |
| preprocess_supervised_dataset, |
| print_supervised_dataset_example, |
| ) |
| from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example |
|
|
|
|
| if TYPE_CHECKING: |
| from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments |
|
|
| from ..hparams import DataArguments |
| from .template import Template |
|
|
|
|
def get_preprocess_and_print_func(
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
) -> Tuple[Callable, Callable]:
    """Select the dataset preprocessing callable and example printer for a stage.

    Dispatch rules (first match wins):

    * ``"pt"``: pretrain preprocessing, unsupervised example printer.
    * ``"sft"`` while ``training_args.predict_with_generate`` is falsy:
      supervised preprocessing (the packed variant when ``data_args.packing``
      is truthy), supervised example printer.
    * ``"rm"``: pairwise preprocessing, pairwise example printer.
    * ``"kto"``: feedback preprocessing, supervised example printer.
    * anything else (e.g. ``"ppo"``, or ``"sft"`` with generation enabled):
      unsupervised preprocessing, unsupervised example printer.

    Args:
        data_args: Data options; ``packing`` and ``neat_packing`` are read here
            and the whole object is bound into the preprocessing callable.
        training_args: Only ``predict_with_generate`` is consulted, and only
            when ``stage`` is ``"sft"``.
        stage: Training stage identifier used for dispatch.
        template: Bound into every preprocessing callable except the pretrain one.
        tokenizer: Bound into every returned callable.
        processor: Bound into the non-packed supervised, pairwise, feedback and
            unsupervised preprocessing callables.

    Returns:
        A ``(preprocess_func, print_function)`` pair of ``functools.partial``
        objects with the arguments above already bound.

    Note:
        When both ``data_args.packing`` and ``data_args.neat_packing`` are
        truthy in the ``"sft"`` branch, ``OptimizedTypedSequence.__init__`` from
        ``datasets.arrow_writer`` is reassigned on the class itself, so the
        change outlives this call.
    """
    if stage == "pt":
        return (
            partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args),
            partial(print_unsupervised_dataset_example, tokenizer=tokenizer),
        )

    if stage == "sft" and not training_args.predict_with_generate:
        if not data_args.packing:
            tokenize = partial(
                preprocess_supervised_dataset,
                template=template,
                tokenizer=tokenizer,
                processor=processor,
                data_args=data_args,
            )
        else:
            if data_args.neat_packing:
                # Imported lazily: only needed when the patch below is applied.
                from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence

                # Replacement constructor: forwards only `type`, `try_type` and
                # `optimized_int_type` to TypedSequence.__init__; any other
                # keyword argument is discarded.
                # NOTE(review): presumably this bypasses OptimizedTypedSequence's
                # own constructor logic for packed columns -- confirm against the
                # installed `datasets` version.
                def __init__(self, data, **kwargs):
                    return TypedSequence.__init__(
                        self,
                        data,
                        type=kwargs.get("type"),
                        try_type=kwargs.get("try_type"),
                        optimized_int_type=kwargs.get("optimized_int_type"),
                    )

                # Class-level assignment: affects every later instantiation.
                OptimizedTypedSequence.__init__ = __init__

            tokenize = partial(
                preprocess_packed_supervised_dataset,
                template=template,
                tokenizer=tokenizer,
                data_args=data_args,
            )

        return tokenize, partial(print_supervised_dataset_example, tokenizer=tokenizer)

    if stage == "rm":
        return (
            partial(
                preprocess_pairwise_dataset,
                template=template,
                tokenizer=tokenizer,
                processor=processor,
                data_args=data_args,
            ),
            partial(print_pairwise_dataset_example, tokenizer=tokenizer),
        )

    if stage == "kto":
        return (
            partial(
                preprocess_feedback_dataset,
                template=template,
                tokenizer=tokenizer,
                processor=processor,
                data_args=data_args,
            ),
            partial(print_supervised_dataset_example, tokenizer=tokenizer),
        )

    # Fallback for every remaining stage value.
    return (
        partial(
            preprocess_unsupervised_dataset,
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            data_args=data_args,
        ),
        partial(print_unsupervised_dataset_example, tokenizer=tokenizer),
    )
|
|