Всем привет. На связи Игорь Буянов, старший разработчик в MTS AI. Этот пост — текстовый вариант моего доклада, с которым я выступал в прошлую пятницу на Pycon 2024. Расскажу о том, как мы оптимизировали параметры аугментаций для текстовых данных и что из этого получилось. Текст рассчитан на широкий круг читателей, поэтому если вы слышите про аугментации впервые — не пугайтесь, разберемся.
Работа каждого ML-инженера — сделать свою модель лучше. Чтобы этого достичь, нужно либо работать над моделью, либо повышать качество и количество данных. Мы рассмотрим второй путь. Какие здесь могут быть проблемы? Иногда данные бывают настолько специфичными и редкими, что просто так достать новую порцию нельзя, или данные находятся на разметке, а улучшить модель нужно «вот прям щас». А может мы маленький стартап и денег на разметку просто нет. Что делать в таких случаях?
Мы можем взять существующий пример из датасет и немного изменить его, совсем чуть-чуть, чтобы его связь с классом не потерялась. Пусть изменения и небольшие, но по факту мы получаем уже другой пример, на котором можем обучаться. Вот это и называется аугментацией. Кто-то может спросить, а почему они вообще работают? Давайте посмотрим вот на эту картинку.
Готов спорить, что слева — типичная картинка из известных CV датасетов типа ImageNet. Она правильная, красивая, качественная, ее точно снимал фотограф в студии. Проблема появляется, когда модель, обученная на таких правильных картинках, начинает работать с фотографиями обычных пользователей. Завал горизонта, обрезанные фотографии, расфокусировка и прочие дефекты — вот с чем придется работать модели. Но если модель никогда не видела, что котик может быть представлен только «в половину», то она может сильно просесть в уверенности своего предсказания, а то и вовсе предсказать что-нибудь не то.
Мы можем это исправить, если сами, с помощью трансформации в одну строчку кода, наделаем таких «неправильных» фоток, как на картинке выше справа. Таким образом мы только за счёт компьюта сможем снизить влияние этого эффекта.
Если говорить в общем, то с помощью аугментаций, мы пытаемся расширить разнообразие нашего датасета и таким образом помогаем модели обобщить свои знания о природе данных, чтобы лучше справляться с задачей.
Выше мы посмотрели на пример CV, поскольку он супер наглядный. Давайте теперь посмотрим, какие аугментации есть для текста. Я их разделил на две группы: алгоритмические и генеративные.
В эту группу попадают те аугментации, что пишем мы руками. Часто они совершают одно очень простое действие. К примерам таких аугментаций можно отнести так называемые простые дата аугментации (easy data augmentation, eda), про них речь пойдет далее в посте.
Еще примером могут быть опечатки. Идея чем-то похожа на пример с котиков. Если вы работаете с текстами, которые печатают люди, особенно на смартфонах, то логично ожидать, что они будут содержать опечатки. Для моделей слово с опечаткой — это совсем другое слово, поэтому не стоит удивляться, что она будет вести себя на них странно. Но мы можем сами сделать опечатки. Способов много: случайная замена символов, эмуляция qwerty-клавиатуры или даже имитация ошибок самих пользователей. Для последнего есть хорошая библиотека Sage.
Ко второй группе относятся всё, где используется любая генеративная модель. Обратный перевод, как ходовой пример. Его идея — перевести исходный текст на какой-нибудь другой язык, а затем с того языка перевести обратно. Смотрите как это работает на тексте классика:
«Я помню чудное мгновенье: передо мной явилась ты» → “I remember a wonderful moment: you appeared before me” → «Я помню замечательный момент: вы появились передо мной»
Как видите, смысл сохранен, но вот лексически тексты различаются. В эпоху до трансформеров была такая забава: прогнать текст цепочкой через 10 разных языков и вернуть обратно на русский. Иногда смысл терялся полностью, иногда частично, и порой это порождало смешных текстов. Сейчас у меня уже не получилось повторить также, переводчики всё-таки стали куда лучше, чем 10 лет назад. С другой стороны, я не долго пытался.
Конечно, в эту же категорию попадают вообще любые способы аугментации датасета с помощью ChatGPT. Особенность языковых моделей в том, что вы можете напрямую синтезировать целые датасеты, но нужно помнить, что это всё еще не решение всех проблем, но ведется много работы в эту сторону.
Одна из моих задач — строить и улучшать классификаторы интентов. Именно эти модели понимают (ну должны, во всяком случае), что хочет пользователь. Интентов у нас сотни и в них наблюдается естественный дисбаланс данных — есть супер популярные интенты, а есть такие, что за месяц обращений будет, как кот наплакал. Мы называем такие интенты малыми или тонкими, для них граница — 100 примеров в обучающей выборке.
С ними вот какая проблема. Как правило, это специфичные просьбы более общих интентов, которые имеют большое количество данных. Грубо говоря, их различие может быть в какой-то фразе или даже слове. Модель же, просто потому что более общий интент видит значительно чаще, начинает все эти специфичные интенты также запихивать в общий. Подумалось, что если мы увеличить объём мелких интентов, то они смогут потеснить крупные интенты и не раствориться в градиентах.
Когда я обдумывал задачу, у меня в голове держались три мысли:
Аугментации обычно имеют два параметра: вероятность срабатывания и/или количество срабатываний.
Изменяя параметры аугментации, мы изменяем итоговые данные, а значит и качество модели.
Почему бы не выстраивать аугментации в последовательности, чтобы получить большее разнообразие?
Из них появился общий вопрос:
КАКИЕ НАСТРОЙКИ И ПОСЛЕДОВАТЕЛЬНОСТЬ АУГМЕНТАЦИЙ ДАДУТ МАКСИМАЛЬНО КАЧЕСТВЕННУЮ МОДЕЛЬ?
Тут рядом пробежала мысль, что у моделей есть свои параметры, которые мы тюним, чтобы добиться максимально возможного выхлопа. Они не обучаются во время тренировки модели, мы их задаем сами, поэтому мы называем их гиперпараметрами. Проведя параллели, в голову пришел ответ:
ПРЕДСТАВИМ АУГМЕНТАЦИЮ КАК ГИПЕРПАРАМЕТР
Далее встал вопрос, как будем эти гиперпараметры искать. На уме три способа:
поиск по сетке (grid search) — полный перебор всего пространства поиска. Надежно, но непозволительно долго;
случайный поиск (random search) — просто произвольно выбираем параметры и смотрим, что будет. Быстрее, чем по поиск по сетке, но непредсказуемо;
байесовским поиск (bayesian search) — поиск с использованием фреймворка байесовской оптимизации. Доверившись мат. аппарату, можно сказать, что это оптимальный вариант. Но, честно сказать, мне просто хотелось его пощупать.
Хорошо, с поиском решили, но у нас есть одна маленькая проблема. Наша тренировочная выборка настолько большая, что обучение одной модели занимает 6 часов на четырех GPU, а оптимизация требует десятки итераций, чтобы получить какой-то внятный результат. Мы бы дождались сильного ИИ, если бы использовали нашу продовую модель. Мы решили это обойти с помощью прокси-модели в виде логистической регрессии.
Итак, общий сетап у нас такой:
Прокси-модель для оценки качества — логистическая регрессия.
В качестве векторизатора используем корпоративный BERT.
Целевая функция для оптимизации —. Такой странный вид нужен для фреймворка оптимизации, он работает только на поиск минимума. Макро используем, потому что оно проседает сильнее, чем micro и weighted.
Максимальное количество шагов — 100. Просто потому что 100.
Обучение проводится лишь на малых классах, а не на всём датасете. Тоже такой прокси-подход в угоду времени обученияю
Теперь давайте поговорим про набор аугментаций, который мы использовали. Почти все они составляют набор EDA, кроме одной кастомной.
Проходимся по токенам и крутим рандом-машину. Если рандом выдал число больше порога, то удаляем токен. Можем сделать так несколько раз. Пример:
«Как мне пополнить счет сим-карты» → «Как ___ пополнить счет сим-карты»
Случайно выбираем два токена и меняем местами. Можем сделать так несколько раз. Пример:
«Как мне пополнить счет сим-карты» → «Счет мне пополнить как сим-карты»
Проходимся по словам в тексте и заменяем их на синоним. Можем сделать несколько раз. Для этой аугментации нужен внешний ресурс — словарь с наборами синонимов для каждого слова. Такие наборы называются тезаурусы. Для русского языка есть RuWordNet.
«Как мне пополнить счет сим-карты» → «Как мне обогатить счет сим-карты»
А вот это как раз наша кастомная аугментация. Я работаю над голосовым чат-ботом, а значит мои тексты — распознанная речь. Речевые паузы — это всякие «ээ», «аа», «мм» и прочие звуки, которые люди вставляют, когда думают по ходу речи. Мы как-то заметили, что наличие такой паузы может сильно смутить модель, поэтому решили сделать это аугментацией. Работает так. Прыгаем по промежуткам между токенами и крутим рандом машину. Если число больше порога, то случайно выбираем паузу из набора «ээ», «аа», «мм» — они самые частые в данных — и вставляем. Можем сделать так несколько раз.
«Как мне пополнить счет сим-карты» → «ээ как мне пополнить счет мм сим-карты»
Такая же механика с рандом-машиной, но в этот раз прыгаем по токена и, если повезет, то просто удваиваем токен. Тоже можем сделать так несколько раз.
«Как мне пополнить счет сим-карты» → «как мне пополнить счет счет сим-карты»
К сожалению, у меня не хватило времени закодить перебор последовательностей нормально, поэтому пришлось вручную выбрать три последовательности, которые мне показались разумными для перебора. Вот они на слайде ниже:
Немного расскажу про два инструмента, которые помогли всё это реализовать.
Это открытая библиотека перебора гиперпараметров. Ее я выбрал исключительно потому что встретил в одной из статей, когда изучал тему. Вы можете выбрать любой другой фреймворк, который вам знаком, например, Optuna. Вот снипет кода, который показывает основные компоненты для работы :
Мы должны определить функцию, внутри которой делаем всё необходимое для подсчёта функции оптимизации. Функция (питонячья) обязательно должна возвращать значение функции оптимизации. Далее идет словарь, в котором мы определяем пространство поиска гиперпараметров, и, наконец, сам запуск.
Это моя скромная открытая библиотечка, в которой я поначалу собирал всякие методы препроцессинга текста, потому что мне надоело однажды каждый раз писать эти функции заново. Кроме того, хотелось бы иметь способ логирования препроцессинга, поэтому все фабрики — объекты, которые выполняют препроцессинг — могут быть сериализованы в виде yaml файла, а дальше вы можете их сохранить в dvc, wandb, clearml и т. д. Позже я добавил еще и аугментации. Вот снипет кода для понимания:
На слайде ниже показан весь алгоритм в сборке
А вот снипет кода, который этот алгоритм реализует:
Сверху объявляется банк аугментаций, затем в permutitions идут последовательности. Далее все просто: собирается фабрика, через нее прогоняется весь тренировочный датасет, который потом полностью соединяется с исходным. Наконец, обучается модель и результат подсчитывается на валидационной выборке.
На выходе от NePS нас ждет табличка, где в каждой строчке указаны параметры, которые были выбраны, время работы и функция потерь (сортировку выполнил я).
Итоговая конфигурация получилась такой:
Примечательно, что это последовательность, в которой нет удаления токенов. Чтобы как-то глазами оценить, как выглядят тексты после этой аугментации, вот вам несколько примеров:
Теперь самое интересное — результаты самой модели.
Мы сравнили по F1 macro варианты прокси-модели (логистическая регрессия и BERT как энкодер) в трех случаях обучения:
На чистых данных малых классов. Ррезульта — 0.902;
На худшем варианте аугментации малых классов. Результат — 0.887 (-0.01);
На лучшем варианте аугментации малых классов. Результат — 0.933 (+0.03).
О чем нам это говорит? Во-первых, у нас есть вероятность навредить аугментациями, это очень важно. Во-вторых, нам удалось найти аугментацию, которая дала прирост аж на 3 пункта. Тут ремарку сделаю, что вообще говоря, чем ближе мы к единице по качеству, тем сложнее улучшить модель. На таких значениях F1 три пункта — это очень много. И всё это за счёт компьюта и относительно небольшого количества времени. Да, здесь речь идет о прокси-модели, но если представить, что эта модель у нас целевая, то прям вау.
На продовой же модели после прогона мелких классов через лучшую аугментацию мы получили прирост в два пункта, что тоже немало.
Давайте посмотрим на статистику результатов оптимизации:
Что мы можем сказать:
Чтобы получить отрицательный результат, вам должно очень не повезти.
Если бы вы рандомно выбрали аугментации, то скорее всего получили бы прирост в 2 пункта (здесь речь про прокси-модель, напоминаю).
Кажется, что поиск гиперпараметров себя оправдал, потому что лучший результат тоже на везунчика.
Можно было бы сказать, что заморачиваться так ради одного пункта не стоило. Но по факту заморочиться надо один раз — написать логику перебора. Дальше можно ставить на ночь считаться эксперименты, а затем забрать лучший результат. Как я писал выше, на таких уровнях качества разница в 1 пункт — это много.
В какую сторону можно двигаться:
Самое простое, что бросается в глаза — перебор последовательностей. Это не долго, но мне времени не хватило.
Провести верификацию переноса результатов с прокси-модели на реальную — то, что мы получили прирост в 2 пункта на продовой модели еще не говорит о том, что нам не просто повезло. В идеале надо прогнать десяток экспериментов с разными аугментациями и посмотреть на корреляцию между качеством продовой модели и прокси-модели.
Или вовсе включить в работу реальную модель, но в упрощенном варианте. Например, можно заморозить энкодер или использовать дистилированную версию Берта. Да, это всё еще не сама продовая модели, но приближение явно лучше, чем логистическая регрессия.
Можно также расширять аугментации по вкусу. Мы использовали алгоритмические из-за их простоты, но никто не мешает использовать генеративные.
Для нашего домена RuWordNet не очень подходящий тезаурус, поэтому можно попытаться собрать свой, пусть он даже будет не очень чистый.
Еще можно поработать с ключевыми словами, чтобы их случайно не удалить или не исказить слишком сильно, иначе модель не сможет ни на что опереться.
Я могу предсказать пару вопросов, которые могли возникнуть у вас:
«Так че там с малыми классами?» — этот вопрос я получал во время моих тестовых прогонов, но к сожалению, детальные метрики потерялись.
«А сколько можно добавлять аугментаций?» — действительно, я писал, что я просто прогнал весь набор текстов с тонкими классами, но оптимально ли это? Вообще, на этот вопрос нельзя так просто ответить, есть целые работы, которые этому посвящены. Но вам никто не мешает представить этот вопрос как поиск гиперпараметров.
На этом у меня всё. Хочу сказать большое спасибо коллегам из MTS AI, которые помогали мне с подготовкой к выступлению, и организаторам Pycon за крутую конфу.
Да-да, теперь ссылка на мой тг-канал, в котором я пишу почаще, чем здесь.