Рекомендация для читателей:
Прежде чем погрузиться в детали, советую ознакомиться с двумя отличными статьями инженера из Яндекса (статья 1, статья 2). В них отлично объясняются принципы дистилляции, её применение в промышленных задачах и ключевые практические аспекты. Это идеальный старт для тех, кто только начинает знакомиться с темой.
Однако, если вы, как и я, стремитесь к глубокому пониманию — этого может оказаться недостаточно. В данном обзоре мы пойдём дальше:
Математическая формализация: Разберём более глубако уравнения, лежащие в основе дистилляции, включая функцию потерь с температурным параметром, оптимизацию распределений и законы масштабирования из работы Apple.
Примеры кода: Покажем, как реализовать дистилляцию на практике — от простых моделей на PyTorch до тонкой настройки гиперпараметров.
Нюансы исследований: Ответим на вопросы, оставшиеся за рамками вводных материалов. Например, почему «слишком умный учитель» вредит ученику и как математически обосновать оптимальное соотношение их размеров.
Для кого это?
Если вы хотите не просто использовать дистилляцию «из коробки», а понимать, как и почему она работает — этот разбор для вас. Мы заглянем «под капот» методов, чтобы вы могли осознанно применять их в своих проектах.
Knowledge Distillation (Дистилляция знаний) — это метод обучения моделей-студентов (обычно меньшего размера и менее сложных) путем передачи "знаний" от предварительно обученной модели-учителя (обычно большей и более сложной). Основная идея заключается в том, что модель-учитель, обладающая большей емкостью и обученная на большом объеме данных, может передать не только свои "жесткие" предсказания (например, класс объекта), но и более богатую информацию о распределении вероятностей классов, которую модель-студент может использовать для более эффективного обучения.
В парадигме Knowledge Distillation участвуют две основные модели:
Teacher (Учитель): Это большая, предварительно обученная модель, которая считается "экспертом" в решении определенной задачи. Учитель уже достиг высокой точности и обладает "знаниями", которые мы хотим передать студенту. Математически учитель представляется как функция, которая для входных данных xx выдает распределение вероятностей
по классам
.
Student (Студент): Это меньшая, более простая модель, которую мы хотим обучить. Цель студента — научиться имитировать поведение учителя, чтобы достичь сравнимой производительности, но при этом быть более эффективной с точки зрения вычислительных ресурсов, памяти или времени инференса. Студент представляется как функция , где
— параметры модели, которые мы оптимизируем в процессе обучения.
Функция потерь (Loss Function) в Knowledge Distillation:
Общая цель Knowledge Distillation — минимизировать разницу между предсказаниями учителя и студента. Это формализуется через функцию потерь, которая зависит от предсказаний учителя
и студента
. Процесс обучения заключается в поиске оптимальных параметров $\theta$ для студента, которые минимизируют эту функцию потерь:
Это общее выражение, и конкретный вид функции потерь и способ дистилляции определяют различные подходы. Рассмотрим два основных подхода: hard-label и soft-label дистилляцию.
Это общее выражение, и конкретный вид функции потерь и способ дистилляции определяют различные подходы. Рассмотрим два основных подхода: hard-label и soft-label дистилляцию.
Hard-label Distillation для GPT моделей: объяснение на пальцах
Представьте, что у нас есть две модели:
Учитель (Teacher): Большая, мощная GPT модель, например, GPT-3 или что-то подобное. Она обладает огромным количеством знаний о языке и мире, и способна генерировать очень качественный и связный текст.
Студент (Student): Маленькая, более компактная GPT модель, например, уменьшенная версия GPT или Transformer меньшего размера. Она менее ресурсоемкая, но изначально уступает учителю в качестве генерации текста.
Наша цель - "научить" маленькую модель-студента генерировать текст так же хорошо, как и большая модель-учитель, используя метод Hard-label Distillation.
Шаги Hard-label Distillation в этом контексте:
Генерация "жестких" меток учителем (Большой GPT):
Мы берем большой набор текстовых данных (например, обучающую выборку, на которой изначально обучался учитель, или просто большой корпус текстов).
Для каждого фрагмента текста (или запроса) из этого набора, мы просим большую модель-учителя сгенерировать текст. В контексте GPT, это означает, что мы подаем учителю входной текст (например, начало предложения или запрос) и просим его сгенерировать продолжение.
Учитель генерирует последовательность токенов, которые он считает наиболее вероятными для продолжения данного текста. Эти сгенерированные последовательности токенов и являются нашими "жесткими" метками.
Пример:
Входной текст (запрос): "Столица Франции - это"
Учитель (Большая GPT) генерирует: "Париж." (токены: "Па", "ри", "ж", ".")
"Жесткая" метка: Последовательность токенов: ("Па", "ри", "ж", ".")
Мы повторяем этот процесс для большого количества различных входных текстов, получая набор пар: (исходный входной текст, "жесткая" метка - последовательность токенов, сгенерированная учителем).
Обучение студента (Маленький GPT) на "жестких" метках:
Теперь у нас есть синтетический датасет, состоящий из пар (исходный входной текст, "жесткая" метка). Мы будем использовать этот датасет для обучения маленькой модели-студента.
Мы обучаем студента предсказывать "жесткие" метки, сгенерированные учителем, используя стандартную задачу языкового моделирования. Это означает, что для каждого входного текста мы хотим, чтобы студент генерировал последовательность токенов, максимально похожую на "жесткую" метку, сгенерированную учителем.
В процессе обучения мы используем функцию потерь кросс-энтропии. Мы сравниваем распределение вероятностей токенов, предсказанное студентом, с "жесткой" меткой (которая по сути является распределением, где вероятность "правильного" токена равна 1, а всех остальных - 0). Мы стремимся минимизировать эту кросс-энтропию, заставляя студента "подражать" учителю в предсказании токенов.
В нашем примере, если студент на вход "Столица Франции - это" предсказывает, например, "Лондон", то функция потерь будет высокой, так как "жесткая" метка учителя была "Париж". В процессе обучения студент будет корректировать свои параметры, чтобы в будущем для аналогичных запросов предсказывать "Париж" или что-то очень похожее на предсказание учителя.
Почему маленькая модель может предсказывать те же токены, что и большая?
Передача знаний через "жесткие" метки: Хотя Hard-label Distillation и теряет часть информации из распределения вероятностей учителя, она все равно эффективно передает ключевые знания о том, какие токены являются наиболее вероятными в определенных контекстах. Большая модель, будучи хорошо обученной, "знает", какие продолжения текста являются грамматически правильными, семантически уместными и стилистически подходящими. Генерируя "жесткие" метки, она как бы "подсказывает" маленькой модели, какие именно токены нужно предсказывать.
Фокус на наиболее важной информации: "Жесткие" метки концентрируются на наиболее вероятных токенах. В языковом моделировании часто бывает так, что для многих контекстов есть один или несколько доминирующих "правильных" продолжений. Hard-label Distillation помогает маленькой модели быстро освоить эти наиболее важные закономерности, игнорируя менее значимые детали, которые могут быть избыточными для достижения хорошего качества генерации.
Упрощение задачи обучения: Обучение на "жестких" метках превращает дистилляцию в стандартную задачу обучения с учителем. Это упрощает процесс обучения и позволяет использовать хорошо известные методы и оптимизаторы. Маленькой модели не нужно пытаться воспроизвести все тонкости распределения вероятностей учителя, ей достаточно научиться предсказывать наиболее вероятные токены, что является более простой задачей.
Важно отметить ограничения Hard-label Distillation:
Потеря "мягкой" информации: Как и указано в тексте, Hard-label Distillation теряет информацию о вероятностях других классов и "мягких" отношениях между классами. В контексте языковых моделей это означает, что студент может не улавливать все нюансы стиля, семантики и разнообразия, которые присутствуют в распределении вероятностей учителя. Например, учитель может знать, что "Париж" является самым вероятным ответом на "Столица Франции - это", но также понимать, что "Рим" или "Берлин" являются менее вероятными, но все же допустимыми ответами в определенных контекстах. Hard-label Distillation фокусируется только на "Париже", игнорируя эту "мягкую" информацию.
Потенциальное ухудшение разнообразия: Из-за фокусировки на "жестких" метках, студент может стать менее разнообразным в своих генерациях, чем учитель. Он может слишком точно копировать наиболее вероятные ответы учителя, упуская возможность генерировать альтернативные, но все еще качественные варианты.
Математическая формализация:
1. Генерация "жестких" меток учителем: Для каждого примераиз обучающей выборки, учитель
предсказывает распределение вероятностей классов. "Жесткая" метка
выбирается как класс с максимальной вероятностью, предсказанной учителем. В контексте языков моделей, где
представляет собой последовательность токенов, учитель генерирует последовательность "жестких" меток
для
примеров. Здесь
представляет собой последовательность токенов длиной
.
В более простом варианте, для классификации, . В случае последовательностей, учитель может генерировать целые последовательности наиболее вероятных токенов.
2. Обучение студента на "жестких" метках: Студентобучается максимизировать логарифмическую вероятность "жестких" меток, сгенерированных учителем. Это стандартная задача обучения с учителем, где целевыми метками являются
. Функция потерь, которую мы минимизируем (или эквивалентно, максимизируем отрицательную потерю), представляет собой ожидание логарифмической вероятности "жестких" меток под распределением
учителя.
В практической реализации, это ожидание аппроксимируется эмпирическим средним по обучающей выборке. Для последовательностей текста, функция потерь выглядит следующим образом:
Здесь:
* — количество примеров в обучающей выборке.
* — длина последовательности для $n$-го примера.
* —
-й токен в последовательности "жестких" меток для
-го примера, сгенерированных учителем.
* — префикс последовательности до
-го токена.
* — вероятность предсказания студентом
-го токена
при условии предыдущих токенов
, параметризованная
.
Эта функция потерь представляет собой кросс-энтропию между распределением "жестких" меток, сгенерированных учителем, и предсказаниями студента. Мы стремимся максимизировать эту величину, что эквивалентно минимизации отрицательной логарифмической правдоподобности или кросс-энтропии.
Преимущества и недостатки Hard-label Distillation:
Преимущества: Простота реализации и понимания. Можно использовать стандартные методы обучения с учителем.
Недостатки: Потеря информации, содержащейся в распределении вероятностей учителя. "Жесткие" метки содержат только информацию о наиболее вероятном классе, игнорируя вероятности других классов и "мягкие" отношения между классами, которые учитель "знает". Это может ограничить эффективность передачи знаний.
Ниже представлена реализация Hard-label Distillation с использованием подхода, применяемого в проекте Open R1. Процесс разделен на два этапа: генерация данных учителем и обучение ученика.
@misc{openr1,
title = {Open R1: A fully open reproduction of DeepSeek-R1},
url = {https://github.com/huggingface/open-r1},
author = {Hugging Face},
month = {January},
year = {2025}
}
import argparse
from datasets import load_dataset
from typing import Optional, Dict, Any
from distilabel.pipeline import Pipeline
from distilabel.models import vLLM
from distilabel.steps.tasks import TextGeneration
def build_hard_label_pipeline(
teacher_model: str,
base_url: str = "http://localhost:8000/v1",
prompt_column: Optional[str] = None,
prompt_template: str = "{{ instruction }}",
temperature: float = 0.0,
max_new_tokens: int = 4096,
input_batch_size: int = 32,
) -> Pipeline:
"""
Description:
---------------
Создает конвейер для генерации "жестких" меток с использованием модели-учителя.
Args:
---------------
teacher_model: Идентификатор модели-учителя
base_url: URL сервера vLLM
prompt_column: Имя колонки в датасете, содержащей входные тексты
prompt_template: Шаблон для форматирования промптов
temperature: Температура для генерации (0.0 для "жестких" меток)
max_new_tokens: Максимальное количество генерируемых токенов
input_batch_size: Размер батча для входных данных
Returns:
---------------
Настроенный конвейер Distilabel
Raises:
---------------
Exception: В случае ошибки настройки конвейера
Examples:
---------------
>>> pipeline = build_hard_label_pipeline("deepseek-ai/DeepSeek-R1")
>>> pipeline.run(dataset)
"""
# Настраиваем параметры генерации с temperature=0 для получения детерминированных ответов
generation_kwargs: Dict[str, Any] = {
"max_new_tokens": max_new_tokens,
"temperature": temperature,
"top_p": 1.0,
"do_sample": False, # Отключаем семплирование для получения "жестких" меток
}
with Pipeline(
name="hard-label-distillation",
description="Конвейер для генерации 'жестких' меток с использованием модели-учителя",
) as pipeline:
# Настраиваем модель-учителя через vLLM
teacher = vLLM(
model=teacher_model,
tokenizer=teacher_model,
extra_kwargs={
"tensor_parallel_size": 1, # Можно увеличить для больших моделей
"max_model_len": max_new_tokens + 2048, # Добавляем запас для контекста
},
generation_kwargs=generation_kwargs,
)
# Настраиваем шаг генерации текста
text_generation = TextGeneration(
llm=teacher,
template=prompt_template,
num_generations=1, # Для "жестких" меток нам нужна только одна генерация
input_mappings={"instruction": prompt_column} if prompt_column is not None else {},
input_batch_size=input_batch_size,
)
return pipeline
def generate_hard_labels(
dataset_name: str,
dataset_split: str = "train",
teacher_model: str = "deepseek-ai/DeepSeek-R1",
output_dataset: str = "my-username/hard-label-distill-dataset",
prompt_column: str = "problem",
prompt_template: str = "You will be given a problem. Please reason step by step, and put your final answer within \\boxed{}: {{ instruction }}",
max_examples: Optional[int] = None,
private: bool = False,
) -> Any:
"""
Description:
---------------
Генерирует "жесткие" метки с использованием модели-учителя и сохраняет результаты как набор данных на HuggingFace Hub.
Args:
---------------
dataset_name: Имя исходного датасета
dataset_split: Имя сплита датасета
teacher_model: Модель-учитель для генерации "жестких" меток
output_dataset: Имя выходного датасета на HuggingFace Hub
prompt_column: Имя колонки, содержащей входные данные
prompt_template: Шаблон для форматирования промптов
max_examples: Максимальное количество примеров для обработки
private: Приватный ли выходной датасет
Returns:
---------------
Датасет с "жесткими" метками
Raises:
---------------
Exception: В случае ошибки генерации меток
Examples:
---------------
>>> hard_label_dataset = generate_hard_labels("my-dataset", "train")
>>> hard_label_dataset.push_to_hub("my-username/hard-label-dataset")
"""
# Загружаем исходный датасет
print(f"Загрузка датасета '{dataset_name}' (сплит: {dataset_split})...")
dataset = load_dataset(dataset_name, split=dataset_split)
# Ограничиваем количество примеров, если указано
if max_examples is not None and max_examples < len(dataset):
dataset = dataset.select(range(max_examples))
print(f"Создание конвейера для генерации 'жестких' меток с использованием {teacher_model}...")
pipeline = build_hard_label_pipeline(
teacher_model=teacher_model,
prompt_column=prompt_column,
prompt_template=prompt_template,
)
print(f"Запуск конвейера для генерации 'жестких' меток на {len(dataset)} примерах...")
# Генерируем "жесткие" метки
hard_label_dataset = pipeline.run(dataset=dataset)
# Сохраняем результаты на HuggingFace Hub
if output_dataset:
print(f"Сохранение результатов в '{output_dataset}'...")
hard_label_dataset.push_to_hub(output_dataset, private=private)
print(f"Датасет с 'жесткими' метками успешно сохранен в '{output_dataset}'.")
return hard_label_dataset
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Генерация 'жестких' меток с использованием модели-учителя")
parser.add_argument("--dataset", type=str, required=True, help="Имя исходного датасета")
parser.add_argument("--split", type=str, default="train", help="Сплит датасета")
parser.add_argument("--teacher-model", type=str, default="deepseek-ai/DeepSeek-R1", help="Модель-учитель")
parser.add_argument("--output-dataset", type=str, required=True, help="Имя выходного датасета")
parser.add_argument("--prompt-column", type=str, default="problem", help="Колонка с входными данными")
parser.add_argument("--prompt-template", type=str,
default="You will be given a problem. Please reason step by step, and put your final answer within \\boxed{}: {{ instruction }}",
help="Шаблон для форматирования промптов")
parser.add_argument("--max-examples", type=int, default=None, help="Максимальное количество примеров")
parser.add_argument("--private", action="store_true", help="Сделать выходной датасет приватным")
args = parser.parse_args()
generate_hard_labels(
dataset_name=args.dataset,
dataset_split=args.split,
teacher_model=args.teacher_model,
output_dataset=args.output_dataset,
prompt_column=args.prompt_column,
prompt_template=args.prompt_template,
max_examples=args.max_examples,
private=args.private,
)
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Optional, Dict, Any
import datasets
import torch
import transformers
from datasets import load_dataset
from transformers import AutoTokenizer, set_seed
from transformers.trainer_utils import get_last_checkpoint
from trl import SFTTrainer, ModelConfig, TrlParser, get_peft_config
from open_r1.configs import SFTConfig
from open_r1.utils.wandb_logging import init_wandb_training
logger = logging.getLogger(__name__)
@dataclass
class HardLabelDistillConfig(SFTConfig):
"""Конфигурация для обучения ученика с использованием Hard-label Distillation."""
dataset_name: str = field(
default=None, metadata={"help": "Датасет с 'жесткими' метками, сгенерированными учителем"}
)
input_column: str = field(
default="problem", metadata={"help": "Колонка с входными данными"}
)
target_column: str = field(
default="generation_0", metadata={"help": "Колонка с выходными данными (жесткими метками) учителя"}
)
max_seq_length: int = field(
default=2048, metadata={"help": "Максимальная длина последовательности"}
)
def train_student_model(config: HardLabelDistillConfig, model_args: ModelConfig) -> None:
"""
Description:
---------------
Обучает модель-ученика на 'жестких' метках, сгенерированных учителем.
Args:
---------------
config: Конфигурация обучения
model_args: Конфигурация модели
Returns:
---------------
None
Raises:
---------------
Exception: В случае ошибки обучения модели
Examples:
---------------
>>> train_student_model(config, model_args)
"""
# Настраиваем логирование
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
log_level = config.get_process_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
# Устанавливаем сид для воспроизводимости
set_seed(config.seed)
# Проверяем наличие последнего чекпоинта
last_checkpoint: Optional[str] = None
if os.path.isdir(config.output_dir):
last_checkpoint = get_last_checkpoint(config.output_dir)
if last_checkpoint is not None:
logger.info(f"Найден чекпоинт, продолжаем обучение с {last_checkpoint}")
# Инициализируем Weights & Biases, если нужно
if "wandb" in config.report_to:
init_wandb_training(config)
# Загружаем датасет с 'жесткими' метками
logger.info(f"Загрузка датасета с 'жесткими' метками: {config.dataset_name}")
dataset = load_dataset(config.dataset_name)
# Подготавливаем входные данные и метки для обучения
def prepare_dataset(examples: Dict[str, Any]) -> Dict[str, Any]:
"""Форматирует данные для обучения с учителем."""
return {
"input_ids": examples[config.input_column],
"labels": examples[config.target_column],
}
# Трансформируем датасет
dataset = dataset.map(prepare_dataset, batched=True)
# Загружаем токенизатор
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
revision=model_args.model_revision,
trust_remote_code=model_args.trust_remote_code,
)
# Настраиваем chat_template, если указан
if config.chat_template is not None:
tokenizer.chat_template = config.chat_template
# Настраиваем параметры модели
torch_dtype = (
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
)
model_kwargs: Dict[str, Any] = dict(
revision=model_args.model_revision,
trust_remote_code=model_args.trust_remote_code,
torch_dtype=torch_dtype,
use_cache=False if config.gradient_checkpointing else True,
)
config.model_init_kwargs = model_kwargs
# Создаем SFT тренер
trainer = SFTTrainer(
model=model_args.model_name_or_path,
args=config,
train_dataset=dataset["train"],
eval_dataset=dataset["validation"] if "validation" in dataset and config.eval_strategy != "no" else None,
processing_class=tokenizer,
peft_config=get_peft_config(model_args),
)
# Запускаем обучение
logger.info("Начало обучения модели-ученика...")
checkpoint: Optional[str] = None
if config.resume_from_checkpoint is not None:
checkpoint = config.resume_from_checkpoint
elif last_checkpoint is not None:
checkpoint = last_checkpoint
train_result = trainer.train(resume_from_checkpoint=checkpoint)
metrics = train_result.metrics
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()
# Сохраняем модель
logger.info(f"Сохранение модели в {config.output_dir}")
trainer.save_model(config.output_dir)
# Создаем карточку модели и загружаем на HuggingFace Hub, если нужно
kwargs: Dict[str, Any] = {
"dataset_name": config.dataset_name,
"tags": ["hard-label-distillation", "open-r1"],
}
if trainer.accelerator.is_main_process:
trainer.create_model_card(**kwargs)
# Восстанавливаем кэш для быстрого инференса
trainer.model.config.use_cache = True
trainer.model.config.save_pretrained(config.output_dir)
# Оцениваем модель, если нужно
if config.do_eval and "validation" in dataset:
logger.info("Оценка модели...")
metrics = trainer.evaluate()
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
# Загружаем модель на HuggingFace Hub, если нужно
if config.push_to_hub:
logger.info("Загрузка модели на HuggingFace Hub...")
trainer.push_to_hub(**kwargs)
if __name__ == "__main__":
# Создаем парсер аргументов
parser = TrlParser((HardLabelDistillConfig, ModelConfig))
config, model_args = parser.parse_args_and_config()
# Запускаем обучение
train_student_model(config, model_args)
# Этап 1: Генерация "жестких" меток с использованием модели-учителя
python hard_label_distill.py \
--dataset AI-MO/NuminaMath-TIR \
--teacher-model deepseek-ai/DeepSeek-R1 \
--output-dataset username/hard-label-math-dataset \
--prompt-column problem
# Этап 2: Обучение модели-ученика на сгенерированных "жестких" метках
accelerate launch --config_file=recipes/accelerate_configs/zero3.yaml train_student.py \
--model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
--dataset_name username/hard-label-math-dataset \
--input_column problem \
--target_column generation_0 \
--learning_rate 1.0e-5 \
--num_train_epochs 2 \
--packing \
--max_seq_length 4096 \
--per_device_train_batch_size 8 \
--gradient_accumulation_steps 4 \
--gradient_checkpointing \
--bf16 \
--output_dir models/Qwen2.5-1.5B-Hard-Label-Distill
Концепция:
Soft-label distillation, предложенная Хинтоном и соавторами в их знаменитой статье "Distilling the Knowledge in a Neural Network" (2015), является более совершенным методом дистилляции знаний. В отличие от Hard-label distillation, этот подход использует не только "жесткие" метки, но и полное распределение вероятностей, предсказанное учителем, в качестве "мягких" меток (soft labels).
"Мягкие" метки содержат значительно больше информации, чем "жесткие", поскольку они отражают уверенность учителя в различных классах и отношения между ними. Например, учитель может предсказать для изображения собаки вероятности [0.8 для "собака", 0.15 для "волк", 0.03 для "лиса", 0.02 для других классов]. Эта информация гораздо богаче, чем просто метка "собака".
Ключевым компонентом метода является "temperature scaling" (масштабирование температуры), который делает распределение вероятностей более "мягким" и информативным путем деления логитов модели на параметр температуры T > 1.
Soft-label Distillation для GPT моделей: объяснение на пальцах
Представьте, что у нас есть две модели:
Учитель (Teacher): Большая, мощная GPT модель с 175 миллиардами параметров. Она обладает глубоким пониманием языка и мира.
Студент (Student): Компактная GPT модель с 1.5 миллиардами параметров. Намного быстрее и экономичнее, но изначально уступает учителю в качестве.
Наша цель - научить студента генерировать текст так же хорошо, как учитель, используя Soft-label Distillation.
Шаги Soft-label Distillation:
Генерация "мягких" меток учителем:
Для запроса "Столица Франции - это" большая модель-учитель не просто выдает "Париж", но вычисляет вероятности для всех возможных следующих токенов:
"Париж": 0.92
"город": 0.03
"Рим": 0.01
... (и тысячи других токенов с малыми вероятностями)
Проблема: это распределение слишком "острое" - один токен имеет почти всю вероятность. Чтобы извлечь больше полезных знаний, применяем temperature scaling:
Делим логиты на температуру T (например, T = 2.0) перед применением softmax:
"Париж": 0.70 (уменьшилось с 0.92)
"город": 0.08 (увеличилось с 0.03)
"Рим": 0.05 (увеличилось с 0.01)
... (другие токены тоже получают больше вероятности)
Эти "смягченные" распределения сохраняют намного больше информации о том, что модель-учитель "знает".
Обучение модели-студента:
Студент обучается не только предсказывать правильный токен, но и воспроизводить всё распределение вероятностей учителя.
Для этого используется КЛ-дивергенция (или кросс-энтропия) между распределениями учителя и студента.
Важно: распределение студента также "смягчается" с той же температурой T для сопоставимости.
Функция потерь умножается на T² для компенсации уменьшения градиентов.
Комбинированное обучение:
Обычно используется комбинация двух функций потерь:
α · (Потери от "мягких" меток) + (1-α) · (Стандартные потери от "жестких" меток)
Где α - коэффициент, обычно от 0.5 до 0.9
Почему это работает лучше Hard-label Distillation?
"Темные знания" (Dark Knowledge): Как назвал Хинтон, относительные вероятности "неправильных" ответов содержат ценную информацию. Например, если модель путает "собаку" с "волком", но не с "самолетом", это важная информация.
Передача неопределенности: Студент учится не только правильным ответам, но и тому, в каких случаях стоит сомневаться.
Более богатый сигнал: Вместо одного бита информации на каждый пример (правильный/неправильный класс), студент получает информацию о всем распределении вероятностей.
Математическая формализация:
1. "Мягкие" метки учителя с температурой T:
Если- логит для класса (токена)
от учителя, то "мягкая" метка с температурой T:
Разберем каждый элемент формулы:
* : Это "мягкая" вероятность для
-го токена, с учетом температуры
. Именно это распределение вероятностей, сгенерированное учителем, мы будем использовать как "мягкую метку".
* : Это логит (logit) для
-го токена, выданный моделью-учителем. Логиты - это значения, которые модель выдает перед применением функции softmax. Они представляют собой "сырые" оценки того, насколько модель уверена в каждом токене. Чем больше логит, тем больше уверенность модели в этом токене.
* : Это параметр температуры (temperature). Как мы разбирали уже выше, температура используется для "смягчения" распределения вероятностей.
* : Это экспоненциальная функция
.
* : Это сумма экспоненциальных значений логитов, деленных на температуру, для всех возможных токенов
. Эта сумма используется для нормализации, чтобы вероятности в итоге суммировались к 1.
Пошаговое объяснение:
1. Деление логитов на температуру: Когда мы делим логиты на температуру
, мы уменьшаем абсолютные значения логитов.
2. Экспоненцирование: Экспоненциальная функция преобразует логиты в положительные значения.
3. Нормализация: Деление на сумму экспоненциальных значений всех логитов гарантирует, что полученные значения
будут представлять собой вероятностное распределение, то есть будут неотрицательными и в сумме дадут 1. Это стандартная операция softmax, но с применением температуры.
Интуиция и эффект температуры:
* При высокой температуре (например, T = 2.0), распределение вероятностей становится более "мягким" или "ровным". Вероятности для менее вероятных токенов увеличиваются, а вероятность наиболее вероятного токена уменьшается. Это позволяет "вытащить" больше информации из распределения, включая "темные знания" о менее вероятных, но все же релевантных вариантах.
* При низкой температуре (приближающейся к T = 1.0, или даже меньше), распределение становится более "острым". Вероятность наиболее вероятного токена приближается к 1, а вероятности остальных токенов стремятся к 0. При T=1 это стандартный softmax. При распределение становится дельта-функцией, выбирая только токен с наибольшим логитом.
2. Аналогично для студента:
где - логит студента для класса
.
Аналогия с формулой учителя: Эта формула абсолютно идентична формуле для учителя, за исключением того, что здесь используются логиты, выданные моделью-студентом.
* : "Мягкая" вероятность для
-го токена, сгенерированная студентом с температурой
.
* : Логит для
-го токена, выданный моделью-студентом.
* Цель: Мы применяем ту же температуру к распределению студента, чтобы сделать его сопоставимым с "мягкими" метками учителя. Это необходимо для корректного расчета функции потерь дистилляции.
3. Функция потерь для Soft-label Distillation:
Множитель компенсирует уменьшение градиентов из-за temperature scaling.
Разберем компоненты:
* : Функция потерь Soft-label Distillation. Это значение, которое мы хотим минимизировать в процессе обучения студента.
* : Квадрат температуры. Этот множитель используется для масштабирования функции потерь и компенсации уменьшения градиентов, вызванного температурой.
* : KL-дивергенция ( Kullback-Leibler divergence) между распределением учителя $p^T$ и распределением студента
.
* : Это развернутая формула KL-дивергенции для дискретных распределений.
Пошаговое объяснение KL-дивергенции:
1. : Отношение вероятности учителя к вероятности студента для каждого токена
. Если студент предсказывает вероятность
близкую к вероятности учителя
, это отношение будет близко к 1.
2. : Логарифм этого отношения. Если отношение близко к 1, логарифм будет близок к 0. Если
сильно отличается от
, логарифм будет иметь большее абсолютное значение (отрицательное, если
, и положительное, если
).
3. : Умножение на
взвешивает вклад каждого токена в общую дивергенцию. Токены, которые учитель считает более вероятными (высокое
), вносят больший вклад в функцию потерь.
4. : Суммирование по всем токенам
дает общую KL-дивергенцию. KL-дивергенция измеряет "расстояние" между двумя распределениями вероятностей. В контексте дистилляции, она измеряет, насколько распределение студента
отличается от распределения учителя
.
Роль:
* Применение температуры "смягчает" распределения, что может привести к уменьшению величины градиентов при обучении. Умножение на
масштабирует функцию потерь, чтобы компенсировать это уменьшение и сделать градиенты более значимыми, особенно на ранних этапах обучения. Это эмпирическая коррекция, которая помогает стабилизировать и ускорить обучение.
* Цель: Минимизируя
, мы заставляем распределение вероятностей студента
максимально приблизиться к распределению вероятностей учителя
. Студент учится не только предсказывать "правильный" токен, но и имитировать всю "манеру мышления" учителя, выраженную в распределении вероятностей.
4. Комбинированная функция потерь:
где - стандартная кросс-энтропия с истинными метками,
- коэффициент баланса.
Разберем компоненты:
* : Общая функция потерь, используемая для обучения студента.
* : Коэффициент баланса (обычно от 0.5 до 0.9). Он определяет, насколько сильно мы полагаемся на "мягкие" метки учителя по сравнению со стандартными "жесткими" метками.
* : Функция потерь Soft-label Distillation, которую мы разобрали выше.
* : Стандартная функция потерь "жестких" меток, обычно кросс-энтропия между предсказаниями студента и истинными (one-hot) метками.
(Стандартные потери "жестких" меток):
* В обычной задаче обучения языковой модели, мы имеем "жесткие" метки - это истинные следующие токены в обучающих данных. Например, для фразы "Столица Франции - это Париж", "Париж" является "жесткой" меткой.
* вычисляется как кросс-энтропия между распределением вероятностей, предсказанным студентом (обычно с
, то есть стандартный softmax), и one-hot вектором, представляющим истинный токен. Эта функция потерь заставляет студента предсказывать именно "правильный" токен.
Комбинирование и
:
* Комбинирование "мягких" и "жестких" потерь позволяет студенту учиться как у учителя (через), так и из исходных данных (через
).
* Коэффициент позволяет настроить баланс.
* Высокое (например, 0.9) означает, что мы больше полагаемся на знания учителя, переданные через "мягкие" метки. Это может быть полезно, когда учитель обладает значительно лучшими знаниями, чем можно извлечь только из "жестких" меток.
* Низкое (например, 0.5) означает, что мы в равной степени учитываем как знания учителя, так и "жесткие" метки. Это может быть полезно, когда мы хотим, чтобы студент сохранил способность хорошо работать и на исходных данных, а не только имитировал учителя.
Практическая реализация Soft-label Distillation для GPT моделей
Программный код был заимствован из репозитория: https://github.com/arcee-ai/DistillKit
1. Конфигурация дистилляции
Первым шагом необходимо настроить параметры дистилляции, включая температуру и коэффициент баланса между мягкими и жесткими метками:
"""
Здесь temperature: 2.0 соответствует параметру T в формулах, который "смягчает" распределение вероятностей, а alpha: 0.5 - это коэффициент α, который определяет соотношение между потерями от мягких и жестких меток.
"""
config = {
"project_name": "distil-multilayer", # Название проекта
"dataset": {
"name": "mlabonne/FineTome-100k", # Название датасета
"split": "train", # Раздел датасета для тренировки
"num_samples": 1000, # Количество образцов для тренировки (можно ограничить)
"seed": 42 # Значение для инициализации генератора случайных чисел
},
"models": {
"teacher": "arcee-ai/Arcee-Spark", # Модель учителя
"student": "Qwen/Qwen2-1.5B" # Модель студента
},
"tokenizer": {
"max_length": 4096, # Максимальная длина токенов
"chat_template": (
"{% for message in messages %}"
"{% if loop.first and messages[0]['role'] != 'system' %}"
"{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}"
"{% endif %}"
"{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
"{% endfor %}"
"{% if add_generation_prompt %}"
"{{ '<|im_start|>assistant\n' }}"
"{% endif %}"
) # Шаблон для форматирования сообщений в чате
},
"training": {
"output_dir": "./results", # Директория для сохранения результатов
"num_train_epochs": 3, # Количество эпох для тренировки
"per_device_train_batch_size": 1, # Размер батча для тренировки на одном устройстве
"gradient_accumulation_steps": 8, # Количество шагов для накопления градиентов
"save_steps": 1000, # Шаги для сохранения модели
"logging_steps": 2, # Шаги для логирования
"save_total_limit": 2, # Лимит на количество сохраняемых моделей
"learning_rate": 2e-5, # Скорость обучения
"weight_decay": 0.01, # Коэффициент регуляризации
"warmup_ratio": 0.2, # Доля шагов для разгона скорости обучения
"lr_scheduler_type": "linear", # Тип планировщика скорости обучения
"resume_from_checkpoint": None, # Путь к чекпоинту для возобновления тренировки (если есть)
"fp16": False, # Использовать ли 16-битное число с плавающей точкой
"bf16": True, # Использовать ли BFloat16
"max_grad_norm": 1.0, # Максимальная норма градиента
"group_by_length": False # Группировать ли батчи по длине
},
"distillation": {
"temperature": 2.0, # Температура для дистилляции
"alpha": 0.5 # Коэффициент альфа для дистилляции
},
"model_config": {
"use_flash_attention": True # Использовать ли Flash Attention
}
}
2. Подготовка моделей учителя и студента
Для дистилляции необходимо загрузить как модель-учитель (более крупную), так и модель-студент (более компактную):
import torch
from typing import Dict, Any
from transformers import AutoModelForCausalLM
def load_models_with_flash_attention(config: Dict[str, Any]) -> Dict[str, AutoModelForCausalLM]:
"""
Description:
---------------
Загружает модели с настройкой флеш-внимания для ускорения.
Args:
---------------
config: Конфигурация моделей и параметров
Returns:
---------------
Словарь с загруженными моделями
Raises:
---------------
KeyError: Если в конфигурации отсутствуют необходимые ключи
Examples:
---------------
>>> config = {
... "model_config": {"use_flash_attention": True},
... "models": {"teacher": "teacher_model_path", "student": "student_model_path"}
... }
>>> load_models_with_flash_attention(config)
{'teacher_model': <transformers.models.model_name.model.ModelName object>,
'student_model': <transformers.models.model_name.model.ModelName object>}
"""
# Настройки для загрузки моделей
model_kwargs: Dict[str, Any] = {"torch_dtype": torch.bfloat16}
# Проверка на использование flash attention
if config["model_config"]["use_flash_attention"]:
model_kwargs["attn_implementation"] = "flash_attention_2"
# Загрузка моделей
teacher_model = AutoModelForCausalLM.from_pretrained(config["models"]["teacher"], **model_kwargs)
student_model = AutoModelForCausalLM.from_pretrained(config["models"]["student"], **model_kwargs)
return {"teacher_model": teacher_model, "student_model": student_model}
# Вызов функции
models = load_models_with_flash_attention(config)
# Теперь models содержит загруженные модели
teacher_model = models["teacher_model"]
student_model = models["student_model"]
3. Реализация функции потерь с мягкими метками
Ключевым компонентом является функция потерь Soft-label Distillation. Рассмотрим её реализацию из файла distil_logits.py:
"""
Это прямая реализация формулы KL-дивергенции. Обратите внимание на следующие ключевые моменты:
1. Логиты масштабируются температурой T перед применением функций softmax/log_softmax.
2. Потери умножаются на T² для компенсации уменьшения градиентов, как описано в теории.
3. Финальная функция потерь комбинирует мягкие метки (KL-дивергенция) и жесткие метки (original_loss) с коэффициентом α.
"""
from typing import Any
import torch
import torch.nn.functional as F
def distillation_loss(
self,
student_logits: torch.Tensor,
teacher_logits: torch.Tensor,
inputs: Any,
original_loss: torch.Tensor,
config: Dict[str, Any]
) -> torch.Tensor:
"""
Description:
---------------
Вычисляет потери дистилляции между логитами студента и учителя.
Args:
---------------
student_logits: Логиты студента.
teacher_logits: Логиты учителя.
inputs: Входные данные.
original_loss: Исходные потери.
config: Конфигурация моделей и параметров.
Returns:
---------------
Общие потери, включающие дистилляционные потери и исходные потери.
Raises:
---------------
KeyError: Если в конфигурации отсутствуют необходимые ключи.
Examples:
---------------
>>> config = {
... "distillation": {"temperature": 2.0, "alpha": 0.5},
... "tokenizer": {"max_length": 512}
... }
>>> student_logits = torch.randn(3, 512)
>>> teacher_logits = torch.randn(3, 512)
>>> inputs = ...
>>> original_loss = torch.tensor(0.5)
>>> distillation_loss(self, student_logits, teacher_logits, inputs, original_loss, config)
tensor(0.25)
"""
# Приведение размерностей логитов учителя и студента к одинаковому размеру
student_logits, teacher_logits = pad_logits(
student_logits.to(self.model.device),
teacher_logits.to(self.model.device)
)
# Масштабирование логитов с помощью температуры T
temperature = config["distillation"]["temperature"]
student_logits_scaled = student_logits / temperature
teacher_logits_scaled = teacher_logits / temperature
# Расчёт KL-дивергенции между распределениями учителя и студента
loss_kd = F.kl_div(
F.log_softmax(student_logits_scaled, dim=-1), # log(q_i^T)
F.softmax(teacher_logits_scaled, dim=-1), # p_i^T
reduction='batchmean'
) * (temperature ** 2) / config["tokenizer"]["max_length"]
# Комбинирование потерь от мягких и жестких меток
alpha = config["distillation"]["alpha"]
total_loss = alpha * loss_kd + (1 - alpha) * original_loss
return total_loss
4. Обработка различных размеров словарей
Поскольку модели учителя и студента могут иметь разный размер словаря токенов, необходима дополнительная функция для согласования размерности их логитов:
"""
Эта функция добавляет нулевые логиты к меньшему распределению, чтобы обеспечить одинаковую размерность для сравнения распределений.
"""
from typing import Tuple
import torch
def pad_logits(
student_logits: torch.Tensor,
teacher_logits: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Description:
---------------
Приводит размерности логитов студента и учителя к одинаковому размеру.
Args:
---------------
student_logits: Логиты студента.
teacher_logits: Логиты учителя.
Returns:
---------------
Кортеж из логитов студента и учителя с одинаковыми размерностями.
Raises:
---------------
ValueError: Если размерности логитов не совпадают и не могут быть приведены к одинаковому размеру.
Examples:
---------------
>>> student_logits = torch.randn(3, 512)
>>> teacher_logits = torch.randn(3, 510)
>>> pad_logits(student_logits, teacher_logits)
(tensor([...]), tensor([...]))
"""
# Определение размеров логитов
student_size, teacher_size = student_logits.size(-1), teacher_logits.size(-1)
# Если размеры не совпадают, добавляем паддинг
if student_size != teacher_size:
pad_size = abs(student_size - teacher_size)
pad_tensor = torch.zeros(
(*teacher_logits.shape[:-1], pad_size),
dtype=teacher_logits.dtype,
device=teacher_logits.device
)
# Возвращаем логиты с добавленным паддингом
if student_size < teacher_size:
return torch.cat([student_logits, pad_tensor], dim=-1), teacher_logits
else:
return student_logits, torch.cat([teacher_logits, pad_tensor], dim=-1)
# Возвращаем логиты без изменений, если размеры совпадают
return student_logits, teacher_logits
5. Кастомный тренер для дистилляции
Для интеграции процесса дистилляции в процесс обучения создаётся специальный класс тренера, который переопределяет функцию вычисления потерь:
"""
Этот класс:
1. Получает выходы (логиты) как от студента, так и от учителя
2. Замораживает веса учителя с помощью `torch.no_grad()`
3. Вычисляет комбинированную функцию потерь с использованием потерь от мягких и жестких меток
"""
from typing import Dict, Any, Union, Tuple
import torch
import torch.nn.functional as F
from transformers import SFTTrainer
class LogitsTrainer(SFTTrainer):
"""
Description:
---------------
Класс для обучения модели с использованием дистилляции логитов.
"""
def compute_loss(
self,
model: torch.nn.Module,
inputs: Dict[str, Any],
return_outputs: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, Any]]:
"""
Description:
---------------
Вычисляет комбинированную функцию потерь для модели студента и учителя.
Args:
---------------
model: Модель студента.
inputs: Входные данные.
return_outputs: Флаг для возврата выходов модели.
Returns:
---------------
Комбинированная функция потерь и, если указано, выходы модели.
Raises:
---------------
ValueError: Если входные данные не соответствуют ожидаемым.
Examples:
---------------
>>> model = ...
>>> inputs = ...
>>> trainer = LogitsTrainer()
>>> trainer.compute_loss(model, inputs, return_outputs=True)
(tensor(0.5), ...)
"""
# Перемещение входных данных на устройство модели
inputs = {k: v.to(model.device) if hasattr(v, 'to') else v for k, v in inputs.items()}
# Перемещение модели учителя на устройство модели
self.teacher_model = self.teacher_model.to(model.device)
# Получение модулей моделей, если они существуют
student_model = model.module if hasattr(model, 'module') else model
teacher_model = self.teacher_model.module if hasattr(self.teacher_model, 'module') else self.teacher_model
# Получение выходов моделей
student_outputs = student_model(**inputs)
with torch.no_grad(): # Учитель не обучается
teacher_outputs = teacher_model(**inputs)
# Вычисление комбинированной функции потерь
custom_loss = self.distillation_loss(
student_outputs.logits,
teacher_outputs.logits,
inputs,
student_outputs.loss
)
# Возврат потерь и выходов модели, если указано
if return_outputs:
return custom_loss, student_outputs
return custom_loss
def pad_logits(
self,
student_logits: torch.Tensor,
teacher_logits: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Description:
---------------
Приводит размерности логитов студента и учителя к одинаковому размеру.
Args:
---------------
student_logits: Логиты студента.
teacher_logits: Логиты учителя.
Returns:
---------------
Кортеж из логитов студента и учителя с одинаковыми размерностями.
Raises:
---------------
ValueError: Если размерности логитов не совпадают и не могут быть приведены к одинаковому размеру.
Examples:
---------------
>>> student_logits = torch.randn(3, 512)
>>> teacher_logits = torch.randn(3, 510)
>>> trainer = LogitsTrainer()
>>> trainer.pad_logits(student_logits, teacher_logits)
(tensor([...]), tensor([...]))
"""
# Определение размеров логитов
student_size, teacher_size = student_logits.size(-1), teacher_logits.size(-1)
# Если размеры не совпадают, добавляем паддинг
if student_size != teacher_size:
pad_size = abs(student_size - teacher_size)
pad_tensor = torch.zeros(
(*teacher_logits.shape[:-1], pad_size),
dtype=teacher_logits.dtype,
device=teacher_logits.device
)
# Возвращаем логиты с добавленным паддингом
if student_size < teacher_size:
return torch.cat([student_logits, pad_tensor], dim=-1), teacher_logits
else:
return student_logits, torch.cat([teacher_logits, pad_tensor], dim=-1)
# Возвращаем логиты без изменений, если размеры совпадают
return student_logits, teacher_logits
def distillation_loss(
self,
student_logits: torch.Tensor,
teacher_logits: torch.Tensor,
inputs: Any,
original_loss: torch.Tensor
) -> torch.Tensor:
"""
Description:
---------------
Вычисляет потери дистилляции между логитами студента и учителя.
Args:
---------------
student_logits: Логиты студента.
teacher_logits: Логиты учителя.
inputs: Входные данные.
original_loss: Исходные потери.
Returns:
---------------
Общие потери, включающие дистилляционные потери и исходные потери.
Raises:
---------------
KeyError: Если в конфигурации отсутствуют необходимые ключи.
Examples:
---------------
>>> config = {
... "distillation": {"temperature": 2.0, "alpha": 0.5},
... "tokenizer": {"max_length": 512}
... }
>>> student_logits = torch.randn(3, 512)
>>> teacher_logits = torch.randn(3, 512)
>>> inputs = ...
>>> original_loss = torch.tensor(0.5)
>>> trainer = LogitsTrainer()
>>> trainer.distillation_loss(student_logits, teacher_logits, inputs, original_loss)
tensor(0.25)
"""
# Приведение размерностей логитов учителя и студента к одинаковому размеру
student_logits, teacher_logits = self.pad_logits(
student_logits.to(self.model.device),
teacher_logits.to(self.model.device)
)
# Масштабирование логитов с помощью температуры T
temperature = config["distillation"]["temperature"]
student_logits_scaled = student_logits / temperature
teacher_logits_scaled = teacher_logits / temperature
# Расчёт KL-дивергенции между распределениями учителя и студента
loss_kd = F.kl_div(
F.log_softmax(student_logits_scaled, dim=-1), # log(q_i^T)
F.softmax(teacher_logits_scaled, dim=-1), # p_i^T
reduction='batchmean'
) * (temperature ** 2) / config["tokenizer"]["max_length"]
# Комбинирование потерь от мягких и жестких меток
alpha = config["distillation"]["alpha"]
total_loss = alpha * loss_kd + (1 - alpha) * original_loss
return total_loss
6. Подготовка тренера и запуск обучения
После определения всех компонентов можно инициализировать тренер и запустить процесс дистилляции:
"""
Обратите внимание, что модель-учитель добавляется к тренеру как атрибут, чтобы она была доступна внутри функции `compute_loss`.
"""
# Импорт необходимых библиотек
from transformers import TrainingArguments
from accelerate import Accelerator
# Инициализация accelerator
accelerator = Accelerator()
# Аргументы обучения
training_arguments = TrainingArguments(**config["training"])
# Проверка наличия предобработанного датасета
if 'tokenized_dataset' not in locals():
# Если датасет не предобработан, выполняем необходимую предобработку
# Код предобработки датасета должен быть здесь...
print("Необходимо сначала выполнить предобработку датасета!")
# Создание кастомного SFT тренера
trainer = LogitsTrainer(
model=student_model,
train_dataset=tokenized_dataset["train"],
eval_dataset=tokenized_dataset["test"],
tokenizer=student_tokenizer,
args=training_arguments,
max_seq_length=config["tokenizer"]["max_length"],
dataset_text_field="text",
)
# Добавление модели-учителя к тренеру
trainer.teacher_model = teacher_model
# Подготовка к распределенному обучению
trainer = accelerator.prepare(trainer)
# Запуск обучения
trainer.train(resume_from_checkpoint=config["training"]["resume_from_checkpoint"])
# Сохранение финальной модели
trainer.save_model(config["training"]["output_dir"])
print(f"Обучение завершено. Модель сохранена в {config['training']['output_dir']}")
Преимущества Soft-label Distillation:
Более полная передача знаний: Студент получает доступ к "темным знаниям" учителя — информации о сложных случаях, тонких различиях между классами и степени неопределенности.
Лучшие результаты: Студенты, обученные этим методом, обычно демонстрируют производительность ближе к учителю по сравнению с Hard-label Distillation.
Улучшенная генерализация: Модели лучше работают на новых данных, так как учатся не только "что" предсказывать, но и "с какой уверенностью".
Контроль через температуру: Параметр T позволяет настраивать степень "мягкости" дистилляции. Более высокие значения T делают распределение более равномерным, помогая передать больше информации о маловероятных классах.
Совместимость с другими методами: Легко комбинируется с другими техниками улучшения моделей.
Недостатки Soft-label Distillation:
Вычислительные затраты: Для языковых моделей с большими словарями (50,000+ токенов) хранение и передача полных распределений вероятностей требует значительных ресурсов.
Сложность реализации: Требует доступа к логитам/вероятностям учителя, а не только к финальным предсказаниям.
Настройка гиперпараметров: Необходимо тщательно подбирать температуру T и коэффициент α для оптимальных результатов.
Зависимость от качества учителя: Если учитель имеет систематические ошибки, они могут передаться студенту.
Сравнение Hard-label и Soft-label Distillation:
Аспект | Hard-label Distillation | Soft-label Distillation |
---|---|---|
Передаваемая информация | Только итоговые классы/токены | Полные распределения вероятностей |
Температура | Не используется | Используется для "смягчения" распределений |
Сложность реализации | Простая | Средняя |
Вычислительные требования | Низкие | Средние-высокие |
Объем хранимых данных | Малый | Большой (особенно для языковых моделей) |
Качество получаемой модели | Хорошее | Лучшее |
Способность передавать неопределенность | Низкая | Высокая |
Эффективность для языковых моделей | Средняя | Высокая |
В заключение, Soft-label Distillation предлагает более мощный метод передачи знаний от учителя к ученику, особенно для сложных задач, где важны тонкие различия между классами и понимание неопределенности. Ключевое отличие от Hard-label Distillation заключается в использовании полных распределений вероятностей и temperature scaling, что позволяет извлечь "темные знания" и научить студента не только выдавать правильные ответы, но и воспроизводить тонкие нюансы рассуждений учителя.
После того, как DeepSeek представил в open source свой метод дистилляции знаний для R1, исследователи из Apple и Оксфордского университета быстро предложили закон масштабирования дистилляции и уже 28 февраля завершили все эксперименты и загрузили 67-страничную статью на arXiv.
Рассмотрим мотивацию исследования, которая сводится к следующим пунктам:
Текущее состояние исследований законов масштабирования моделей: В последние годы исследования выявили взаимосвязь между производительностью языковых моделей, их размером и объемом данных для обучения. Однако систематических исследований законов масштабирования в контексте дистилляции пока не проводилось.
Проблема стоимости вывода модели: С увеличением размера языковых моделей значительно возрастает стоимость вывода. Исследование того, как снизить стоимость вывода без потери производительности, становится важной задачей.
Эффективность и производительность дистилляции: Теоретически, дистилляция может снизить стоимость вывода, однако в академических кругах нет единого мнения относительно методов дистилляции, особенно в том, как рационально распределить вычислительные ресурсы для создания наиболее мощных моделей, что остается большой неопределенностью.
Экстраполяции закона масштабирования дистилляции. Закон масштабирования дистилляции (Уравнение 8) аппроксимирован на слабых учениках
для ряда учителей с потерями
. Сплошные линии представляют прогнозируемое поведение модели для невидимых учителей при заданной конфигурации ученика (интерполяция), а пунктирные линии представляют прогнозируемое поведение модели за пределами видимых учителей и для области сильных учеников
.
Традиционный закон масштабирования (Scaling Laws) для больших моделей демонстрирует, что производительность языковой модели (LM) может улучшаться с увеличением вычислительных ресурсов, если модель следует оптимальной вычислительной парадигме обучения. Однако постоянный рост затрат на инференс делает этот подход все менее практичным, что заставляет исследователей искать альтернативные методы, включая переобучение и дистилляцию, для создания небольших, но мощных моделей.
Исследователи провели обширные эксперименты, используя модели-студенты и модели-учителя с параметрами от 143 миллионов до 12,6 миллиардов и объемом данных до 512 миллиардов токенов. Целью было изучить взаимосвязь между производительностью модели и вычислительными ресурсами в процессе дистилляции, а также найти способы оптимизации распределения этих ресурсов.
В следующей таблице показано значение символов, используемых в этой статье:
Таблица. Выражения, связанные с законами масштабирования, используемые в данной работе. В каждом случае всегда относится к ученику, а не к обучению с учителем.
Выражение | Значение |
---|---|
Количество параметров модели/ученика/учителя, не связанных с эмбеддингом. В тексте, когда мы упоминаем параметры, мы всегда имеем в виду параметры, не связанные с эмбеддингом, если не указано иное. Подробности см. в Приложении H.2. | |
Количество токенов, на которых предобучена модель/учитель. | |
Количество токенов, на которых дистиллирован ученик. | |
Соотношение токенов на параметр, или MM-соотношение. В работе Hoffmann et al. (2022), M принимает оптимальное значение | |
Кросс-энтропия модели, которая представляет собой валидационную кросс-энтропию модели на данных, оцениваемую по закону масштабирования с учителем для модели с N параметрами, обученной на D токенах. (Уравнение 1). | |
Кросс-энтропия учителя, которая представляет собой валидационную кросс-энтропию учителя на данных, оцениваемую по закону масштабирования с учителем для учителя с | |
Кросс-энтропия ученика, которая представляет собой валидационную кросс-энтропию ученика на данных, оцениваемую по нашему закону масштабирования дистилляции для ученика с | |
Кросс-энтропия ученика с учителем, которая представляет собой валидационную кросс-энтропию ученика на данных, если бы ученик был обучен с учителем, оцениваемую по закону масштабирования с учителем для ученика с |
Пояснение: Кросс-энтропия — это метрика, измеряющая расхождение между предсказанным распределением вероятностей модели и истинным распределением. Чем ниже кросс-энтропия, тем лучше модель предсказывает правильные токены. Это основной показатель качества языковой модели.
Пояснение к правилу Чинчиллы: Исследование Hoffmann et al. (2022) установило эмпирическое правило оптимального соотношения между количеством параметров модели и количеством токенов для обучения — примерно 20 токенов на каждый параметр. Это правило позволяет эффективно распределять вычислительные ресурсы при обучении крупных языковых моделей.
Центральным вкладом исследования является формулировка закона масштабирования дистилляции:
Объяснение переменных:
— кросс-энтропия студента (мера ошибки предсказания; чем ниже, тем лучше модель).
— кросс-энтропия учителя (мера ошибки предсказания большой модели).
— количество неэмбеддинговых параметров студента (основные обучаемые параметры модели).
— количество токенов, использованных для обучения студента при дистилляции.
— потенциальная кросс-энтропия студента при обычном обучении без дистилляции, определяемая классическим законом масштабирования:
— коэффициенты, определяемые эмпирически.
и
— положительные коэффициенты, зависящие от архитектуры модели и характеристик набора данных.
Физический смысл формулы:
1. Базовая часть: — студент не может быть лучше учителя.
2. Модифицирующая часть: Остальная часть формулы описывает, насколько эффективно студент может приблизиться к учителю в зависимости от своего размера, количества данных и качества учителя.
Ключевые выводы:
1. Студент не может превзойти учителя (всегда). Кросс-энтропия (L) - это мера ошибки модели. Чем ниже значение L, тем лучше модель предсказывает данные.
2. Чем ближе потенциальная производительность студента к производительности учителя, тем эффективнее дистилляция.
3. При фиксированном учителе закон масштабирования дистилляции не превосходит обычный закон масштабирования.
Практическое применение:
Этот закон позволяет оптимально распределить вычислительные ресурсы между учителем и студентом и прогнозировать эффективность дистилляции.
То есть: Этот закон описывает, как качество маленькой модели зависит от трех факторов: размера самой модели, количества данных для обучения и качества большой модели-учителя. Ключевой вывод: студент никогда не может быть лучше учителя, но насколько близко он подойдет к учителю, зависит от его собственных возможностей и объема тренировки.
Рассмотрев общий закон масштабирования дистилляции, важно также понять практические аспекты реализации этого процесса, в частности, как управлять балансом между имитацией учителя и самостоятельным обучением модели-ученика.
Основная идея дистилляции знаний заключается в переносе информации от большой модели-учителя к компактной модели-ученику. В этом процессе прогнозируемое распределение вероятностей модели-учителя используется в качестве целевого распределения для модели-ученика. Обучение происходит путем минимизации расхождения Кульбака-Лейблера (KL-дивергенции) между распределениями ученика и учителя:
где:
- и
— выходные логиты моделей учителя и ученика соответственно
- — температура дистилляции, контролирующая "сглаженность" распределения вероятностей учителя
- — функция softmax, преобразующая логиты в вероятности
- — размер словаря
Комбинированная функция потерь для модели-ученика объединяет несколько компонентов:
где:
- — потеря при предсказании следующего токена (стандартная кросс-энтропия)
-— потеря при дистилляции знаний (KL-дивергенция)
-— регуляризационная Z-потеря, стабилизирующая обучение путем нормализации логитов
-— коэффициент смешивания, определяющий баланс между обучением на "чистых" данных и имитацией учителя
-— весовой коэффициент для Z-потери
Для определения влияния параметров дистилляции на эффективность закона масштабирования, исследователи провели серию экспериментов. Чтобы исключить влияние данных и сосредоточиться именно на роли модели-учителя, эксперименты проводились в режиме "чистой дистилляции" с . Результаты показали, что такой выбор
даёт результаты, статистически сопоставимые с использованием оптимальных значений
.
Во всех экспериментах использовалась фиксированная температура дистилляции , которая эмпирически показала наилучшую эффективность для обучения модели-ученика.
Коэффициенты смешивания.
(a) Модели-ученики шести размеров, обученные с соотношением
, дистиллируются от моделей-учителей размеров
, обученных с соотношение
, с различными значениями коэффициента смешивания
. Значения
и
соответствуют стандартному обучению и чистой дистилляции соответственно.
(b) Оптимальные коэффициенты смешивания, дающие наименьшую потерю на валидационном наборе для каждой пары учитель-ученик.
Эти эксперименты подтверждают, что параметры дистилляции оказывают существенное влияние на итоговую производительность модели-ученика, и их оптимальный выбор напрямую связан с размерами моделей учителя и ученика, что согласуется с общим законом масштабирования дистилляции.
Дистилляция знаний — это метод, позволяющий передать способности большой нейронной модели (учителя) меньшей и вычислительно эффективной модели (ученику). Процесс основан на обучении модели-ученика имитировать распределение вероятностей модели-учителя путём минимизации расхождения Кульбака-Лейблера между их предсказаниями.
Эффективность дистилляции определяется балансом нескольких компонентов в функции потерь:
Стандартной кросс-энтропии при предсказании следующего токена
KL-дивергенции при имитации учителя
Регуляризационной Z-потери для стабилизации обучения
Два ключевых параметра контролируют этот процесс:
Коэффициент смешивания λ, регулирующий баланс между самостоятельным обучением и имитацией учителя
Температура дистилляции τ, влияющая на "сглаженность" распределения вероятностей
Экспериментальные исследования демонстрируют, что режим "чистой дистилляции" (λ = 1) при температуре τ = 1 часто даёт результаты, сопоставимые с оптимально подобранными параметрами. Однако наиболее важным открытием является то, что идеальные значения этих параметров системно зависят от соотношения размеров конкретной пары моделей учитель-ученик.
Это открытие соответствует общему закону масштабирования дистилляции и имеет прямое практическое применение: для достижения максимальной эффективности при практической реализации дистилляции необходим индивидуальный подбор параметров с учётом размеров используемых моделей, что позволяет существенно улучшить итоговую производительность компактной модели при сохранении её вычислительной эффективности.
Размер модели учителя и объем обучающих данных на которых обучался учитель, фиксированы, а размер модели ученика и объем дистилляционных данных варьируются. Цель состоит в том, чтобы изучить, как производительность модели ученика меняется в зависимости от ее размера и объема обработанных дистилляционных данных в условиях фиксированной модели учителя. Таким образом, можно определить оптимальную производительность модели студента при различных масштабах и объемах данных.
Из результатов эксперимента можно заметить, что:
При высокой вычислительной мощности, чем больше масштаб параметров модели ученика, тем меньше его функция потерь, и чем больше масштаб модели учителя, тем очевиднее эта тенденция.
Когда размер моделей ученика и учителя определен, становится понятно, что чем больше вычислительная мощность, тем лучше будет работать модель ученика.
При низкой вычислительной мощности производительность модели сначала улучшится, а затем ослабнет с размером модели. Здесь легко понять, что более крупные модели не полностью обучаются при меньшей вычислительной мощности.
В особых случаях модель ученика может превзойти модель учителя и показать способность к обобщению. Я лично предполагаю, что модель учителя может быть недообучена в таких сценариях.
Размер модели ученика и объем данных дистилляции фиксированы, а размер модели учителя и объем обучающих данных варьируются. Цель состоит в том, чтобы изучить, как эффективность модели учителя влияет на конечную эффективность модели ученика. Таким образом, можно определить оптимальный размер модели учителя и объем обучающих данных для максимизации производительности модели ученика.
Как видно из результатов, чем больше параметры у модели учителя, тем ниже перекрестная энтропия модели ученика. Это показывает, что для достижения наилучшего эффекта дистилляции производительность модели учителя должна соответствовать возможностям модели ученика.
Чтобы понять, когда дистилляция приносит пользу, на следующем рисунке сравнивается производительность дистилляции и контролируемого обучения при фиксированных вычислительных ресурсах. Результаты показывают, что контролируемое обучение всегда превосходит дистилляцию при наличии достаточного количества вычислений или данных у учащихся. При умеренном бюджете данных дистилляция имеет преимущества, однако при наличии больших объемов данных контролируемое обучение превосходит дистилляцию.
Подводя итог, можно сказать, что при ограниченных вычислительных ресурсах дистилляция обычно более эффективна, чем контролируемое обучение. Это связано с тем, что дистилляция может быстрее усваивать эффективные представления признаков под руководством модели учителя, тем самым достигая более высокой производительности при меньших вычислительных ресурсах.
Сила обучающего сигнала: Модели учителей разных размеров могут обеспечивать разную силу обучающего сигнала, которая обычно измеряется с помощью потери перекрестной энтропии. Более крупная модель учителя может обеспечить более сильный сигнал обучения (более низкая перекрестная энтропия), тем самым помогая модели ученика лучше учиться.
Увеличение затрат: использование более крупной модели учителя повлечет за собой более высокие затраты из-за необходимости вычисления логитов модели учителя. Это означает, что более крупная модель учителя не только более затратна в обучении, но и потребляет больше вычислительных ресурсов при использовании для дистилляции.
На рисунке ниже показано изменение потери перекрестной энтропии модели студента при различных бюджетах данных дистилляции. Результаты показывают, что оптимальная потеря учителя (представленная красной линией) уменьшается по степенному закону с увеличением численности учащихся до тех пор, пока потеря ученика не сравняется с оптимальной потерей учителя.
Как видно на другом рисунке ниже, по мере увеличения объема данных дистилляции, перекрестная энтропия оптимальной модели учителя постепенно уменьшается. Таким образом, можно сделать вывод, что: когда вычислительные ресурсы ограничены, выбор меньшей модели учителя может снизить затраты на вывод, при этом обеспечивая эффективные сигналы обучения для модели ученика.
Целью вычислительно оптимальной дистилляции является определение способа создания модели студента желаемого размера с наименьшей перекрестной энтропией при заданном вычислительном бюджете. В частности, необходимо найти оптимальный объем данных для обучения учащихся, размер модели учителя, данные для обучения, чтобы минимизировать перекрестную энтропию студента и при этом удовлетворить ограничениям вычислительного бюджета.
На рисунке ниже мы видим:
Контролируемое обучение всегда соответствует наилучшему варианту настройки дистилляции при достаточном вычислительном бюджете: Контролируемое обучение всегда соответствует наилучшему варианту настройки дистилляции при определенном общем вычислительном бюджете. Это означает, что контролируемое обучение может достичь той же производительности, что и дистилляция, если вычислительный бюджет достаточно велик.
Если в вычисления включено обучение учителя, перекрестная энтропия учащихся всегда выше, чем в контролируемой обстановке: это означает, что если вашей единственной целью является создание наилучшей модели с целевым размером и у вас нет доступа к учителю, вам следует выбрать контролируемое обучение вместо обучения учителя и последующей дистилляции. Напротив, если цель состоит в том, чтобы выделить семейство моделей или использовать учителя в качестве обслуживающей модели, то выделение может оказаться более выгодным с вычислительной точки зрения, чем контролируемое обучение.
Меньшие модели с большей вероятностью получат выгоду от контролируемого предварительного обучения, в то время как более крупные модели с большей вероятностью получат выгоду от дистилляции: Меньшие модели с большей вероятностью получат выгоду от контролируемого обучения при больших вычислительных бюджетах, в то время как более крупные модели с большей вероятностью получат выгоду от дистилляции при больших вычислительных бюджетах.
На рисунке ниже показаны тенденции изменения оптимального размера учителя и объема обучающих данных по мере изменения вычислительного бюджета. Токены моделей студентов и преподавателей масштабируются по степенному закону, причем токены студентов растут быстрее. Размер лучшей модели учителя сначала увеличивается, пока не станет немного больше ученика, а затем стабилизируется. Это связано с тем, что использование большой модели учителя для вывода обходится дорого, и по мере увеличения количества токенов учеников более эффективным становится переобучение модели учителя.
В результате исследований авторы пришли к следующим выводам:
Предсказуемость производительности через закон масштабирования: Производительность модели-студента размером , полученной путем дистилляции из модели-учителя размером
с использованием
токенов, может быть предсказана с помощью разработанного закона масштабирования дистилляции.
Практическое значение: Это позволяет заранее оценить, какой результат можно получить от процесса дистилляции, не проводя дорогостоящих экспериментов. Компания может спланировать свои ресурсы и решить, стоит ли вкладываться в дистилляцию, или лучше выбрать другой подход к созданию эффективной модели.
Влияние параметров учителя на студента: Размер модели-учителя NTNT и количество токенов для её обучения определяют кросс-энтропию модели-учителя
, которая, в свою очередь, влияет на кросс-энтропию модели-студента.
Наглядный пример: Представьте учителя как источник знаний для студента. Если учитель сам недостаточно образован (высокая кросс-энтропия), он не сможет хорошо обучить студента, какими бы способностями студент ни обладал.
Феномен "разрыва в способностях": Исследование выявило интересный эффект - более сильный учитель может привести к худшему студенту, что объясняется "разрывом в способностях" (capacity gap). Влияние кросс-энтропии модели-учителя на потери модели-студента следует степенному закону, который переключается между двумя режимами в зависимости от относительной способности к обучению студента и учителя. Исследование показало, что важен именно разрыв в способности к обучению (гипотезное пространство и оптимизационная способность) между учителем и студентом, а не просто их относительный размер.
Аналогия для понимания: Представьте, что профессор квантовой физики пытается обучить первоклассника. Несмотря на высокую квалификацию профессора, первоклассник не сможет усвоить сложный материал из-за разрыва в способностях к обучению. Аналогично, если модель-учитель слишком сложна и "мыслит" на уровне, недоступном модели-студенту, эффективность обучения снижается.
U-образная зависимость ошибки студента: Эмпирически подтверждается U-образная зависимость ошибки студента от размера учителя при фиксированном размере студента, что теоретически обосновывается разрывом в емкости между ними.
Визуальное представление: Если изобразить ошибку студента на графике, где по горизонтальной оси отложен размер учителя, мы увидим U-образную кривую. Это означает, что существует оптимальный размер учителя для данного студента — не слишком маленький (недостаточно знаний) и не слишком большой (слишком сложное представление знаний).
Результаты исследования показывают, что дистилляция становится более эффективной, чем обучение с учителем, при соблюдении следующих условий:
Общее количество вычислений или токенов для студента не превышает пороговое значение, связанное с размером студента, согласно новому закону масштабирования.
Практический сценарий: Для компании с ограниченным бюджетом на вычисления, которая хочет создать модель размером 1 миллиард параметров, дистилляция может быть оптимальным выбором, если доступно менее 20 миллиардов токенов для обучения (согласно правилу Чинчиллы).
Модель-учитель уже существует, или обучение модели-учителя имеет применение за пределами одной дистилляции.
Бизнес-кейс: Если компания уже обучила крупную модель для своих основных задач, имеет смысл использовать её для дистилляции меньших, специализированных моделей для развертывания на мобильных устройствах или в средах с ограниченными ресурсами.
🔥Не пропустите важные обновления и углубленные материалы!🔥
Хотите быть в курсе самых свежих обзоров и исследований в мире ML и AI? Переходите по ссылкам ниже, чтобы получить доступ к эксклюзивному контенту:
📌 Все обзоры также доступны в нашем Telegram канале TheWeeklyBrief📢
📌 Более подробный обзор с математической формализацией и программным кодом ждет вас в нашем репозитории Weekly-arXiv-ML-AI-Research-Review 👩💻📂✨
Не упустите шанс глубже погрузиться в мир технологий! 🚀