Добрый день! Меня зовут Александр, я занимаюсь с компанией ITFB Group проектами, связанными с искусственным интеллектом, в частности, с обучением и использованием нейронных сетей. Думаю, ни для кого не секрет, что многие из успехов в глубинном обучении достигнуты отчасти благодаря тому, что разработчики, говоря простым языком, взяли модели побольше и натренировали их на огромных объёмах данных. Однако, чтобы прогнать эти самые огромные объёмы данных через модель, нужно либо очень много времени, либо каким-то образом распределить работу на много вычислительных узлов — сделать обучение параллельным. Я видел на Хабре пару статей на эту тему, но дерзну попробовать написать ещё одну. Добавить кое-каких деталей, а что-то, что уже было, надеюсь, получится объяснить попроще. Поехали!
Самый простой подход к параллельном обучению нейронных сетей — это параллелизм по данным. Допустим, у нас есть в распоряжении дюжина GPU или ещё каких-то устройств. Давайте назовём их узлами. На каждый узел мы помещаем идентичную копию нашей модели. Термин «модель» тут означает то же самое, что и нейронная сеть. Каждая такая копия получает собственную уникальную порцию обучающих данных (mini-batch), а дальше всё, как если бы мы обучали модель на единственном узле: данные проходят через модель, предсказание модели сравнивается с эталонными значениями, вычисляется функция ошибки и градиенты для весов модели. Тут мы добавляем буквально один дополнительный шаг: перед тем, как мы должны были бы обновить веса модели, мы сначала усредняем градиенты между всеми нашими копиями. В самом прямом смысле — складываем, делим на количество копий и рассылаем всем обратно что получилось. Математически это никак не отличается от того, как если бы мы на одном-единственном устройстве просто обрабатывали выборку побольше за один шаг оптимизации.
Всё хорошо, но есть две проблемы. Первая — нам нужно это усреднение сделать быстро, за время, которые было бы сравнительно небольшим по сравнению с тем, что мы тратим на вычисления, иначе теряется смысл всей затеи. Вторая проблема — при оптимизации на большей выборке наша модель внезапно начинает обучаться хуже, т. е. медленнее в пересчёте на количество обработанных данных, а то и совсем не достигает желаемой точности.
К этой проблеме нужно подступиться с инженерной стороны. Во-первых, неплохо бы иметь высокопроизводительную сеть и использовать подходящий алгоритм усреднения. Кому интересно почитать поглубже — делайте поиск в сторону ring-AllReduce. Дальше можно придумывать разные оптимизации. Вот, например, обратное распространение ошибки, известное также как наш любимый бэкпроп или backward pass, — оно же происходит от слоя к слою. Посчитали самый последний слой — и можно уже усреднять его градиенты, тем временем рассчитывая следующий слой. Можно пересылать данные по сети в сжатом виде, можно пробовать делать это всё асинхронно, т. е. не ждать, если какой-то из узлов вдруг не успел посчитать свою порцию градиентов, а в следующей итерации использовать устаревшие (stale) градиенты, часто с ограничением на то, насколько они могут быть… эм... несвежими (bound staleness). Впрочем, этот подход последнее время не очень часто случается видеть.
Наконец, можно обработать больше данных на каждом из устройств перед тем как обмениваться градиентами. Логика такая: время на выполнение всех вычислений пропорционально количеству сэмплов, а то и увеличивается даже медленней, чем объём работы, т. е. показывает, как это называют в англоязычной литературе, sub-linear scaling. Так происходит, потому что при большем количестве данных устройства типа GPU работают более эффективно. Если совсем на пальцах — всем ядрам хватает работы, и доступ к памяти тоже загружается по полной.
А вот время на обмен градиентов пропорционально только размеру самой модели. Получается, что чем больше считаем сэмплов за одно обновление весов, тем меньше относительный вклад времени на сетевое взаимодействие. Если не хватает памяти посчитать больше сэмплов в одном мини-батче, можно посчитать один, сохранить градиенты в отдельном буфере или даже в какой-то внешней памяти, и считать следующий мини-батч, не делая обновления. Такой подход называется Gradient Accumulation. Всё хорошо, но при этом размер выборки, относительно которой делается шаг оптимизации (effective mini-batch size), становится еще больше, и это еще больше усугубляет нашу вторую проблему.
Что же не так с большим батчем? Во-первых, если батч (ну т. е. выборка данных, на которой мы считаем градиенты модели) большой, то количество обновлений модели за один проход по фиксированному объёму данных (всей нашей обучающей выборке) будет меньше. Тут простая арифметика. Если у нас всего сто сэмплов, и мы за один раз считаем десять, то получается десять обновлений. Если мы за один раз обрабатываем двадцать сэмплов, то, соответственно, шагов оптимизации мы сделаем только пять. На всякий случай: обновление или «шаг» оптимизатора — это когда мы к каждому из весов прибавили его градиент, умноженный на константу, которая называется learning late, или lr. Если lr у нас фиксированный, например, равный 0,0001, то за меньшее количество обновлений веса моделей смогут сместиться на меньшее расстояние.
В этом наблюдении уже есть подсказка про то, что можно сделать — можно увеличить learning rate. Возможное объяснение тут такое. Скажем, мы обрабатываем по одному сэмплу за раз, как в классическом stochastic gradient descent — стохастическом градиентном спуске. Представьте, что мы стоим на многомерной оптимизационной гиперповерхности: наши координаты — все веса модели, наша высота — функция ошибки. Градиент как бы говорит нам: «Иди налево». Посчитали еще один сэмпл — «Иди направо». И мы так прыгаем туда-сюда, но есть, к счастью, теоремы, которые обещают, что если так попрыгать некоторое время, то мы, скорее всего, придём куда надо. Если мы посчитали больше сэмплов за один раз, то наша идея о хорошем направлении более точная, что-ли. Мы можем пройти большее расстояние, прежде чем проверять еще раз, куда нам идти дальше.
Тут хотелось бы сделать одну оговорку. Теория обучения нейронных сетей сильно отстаёт от практики. Статьи, в которых что-то доказывается строго в математическом смысле, делают это часто для сферических коней в вакууме типа нейронных сетей с одним слоем и двумя нейронами. Не шучу. Статьи, которые что-то анализируют, часто скорее «теоретизируют», нежели анализируют. В те же далёкие времена, когда я познакомился с теорией оптимизации в университете, там вообще было в основном про расход топлива в ракете и вот это всё.
А сейчас большая часть статей — и вообще, и про большой батч в частности — просто забивает на теорию и занимается исключительно практикой. Объяснения дальше я выбрал те, которые кажутся мне более-менее адекватными и коррелируют с практическим опытом. Но читайте со здоровым скептицизмом.
Продолжаем. В благословенные времена, когда гнаться за точностью систем машинного обучения на датасете под названием Imagenet (задача object recognition) стало уже не очень модно, а языковые модели еще не захватили сцену, наметилось ещё одно небольшое соревнование. Кто натренирует ResNet50 (популярную свёрточную нейронную сеть для задач компьютерного зрения) до заданной точности быстрее? На том же самом датасете Imagenet.
Помимо разных ухищрений в оптимизации, задача закономерно вылилась в вопрос: «Каким хитрым способом нам выдержать батч побольше?»
Дело в том, что бесконечно увеличивать learning rate увы, не выходит. Обучение оказывается нестабильным. Хотя наши градиенты и становятся менее резкими, небольшое отклонение при очень большом шаге может выбросить нашу модель в такие тёмные углы векторного пространства, из которых она уже не сможет выбраться. Или один NaN (Not a Number — например, ошибка переполнения в арифметике с плавающей точкой) в активациях — и всё пропало.
Следующий шаг для решения этой проблемы — более «умные» оптимизаторы, которые накапливают некоторую статистику о том, насколько сильно изменялись параметры, и сами подкручивают learning rate. Входящие тогда в моду LARS и LAMB рекламировались как раз как «средство от большого батча».
Еще один нюанс: на ранних этапах обучения модель знает о наших обучающих данных примерно столько, сколько Джон Сноу, т. е. ничего. Ошибки получаются большие, градиенты, соответственно, тоже большие. И шаги, стало быть, нужно делать поменьше, чтобы не улететь в кювет, так сказать. На сцену выходит идея прогрева (preheat) — начинать обучение с маленьким learning rate, постепенно увеличивать его до какого-то заданного значения, а затем уменьшать, как мы обычно делали и без прогрева, чтобы к концу тренировки как бы «пришлифовать» модель аккуратными маленькими шагами.
Прогрев, если что, не обязательно только про learning rate. Можно, например, начинать прогрев с одним оптимизатором, а потом плавно переходить на другой: например, с rmsprop — в momentum SGD. Есть случаи, когда это применялось тоже как раз для проблемы большого батча.
А в 2019 году вышла статья под названием «Don't Decay the Learning Rate, Increase the Batch Size». Заголовок говорит сам за себя. Получается, что в идеале стоило бы начать обучение на небольшом количестве узлов, чтобы зря не считать градиенты на большой выборке, когда грубых шагов достаточно для движения в нужном направлении. А потом постепенно добавлять новые узлы в «пул», чтобы «заполировать» модель батчем побольше.
Но вернёмся к большому батчу. Вся эта идея о том, что можно делать более смелые шаги, если есть более точный градиент, работает, только если у нас поверхность оптимизации более-менее гладкая. А она, к сожалению, — не совсем. Тут слегка могут помочь методы оптимизации второго порядка, такие как Natural Gradient Descent. Возвращаясь к моей великолепной метафоре с путешествием по оптимизационной поверхности, мы видим перед собой холмик и понимаем, что идти надо по дуге. Другое дело, что простыми холмиками дело не ограничивается, и, к сожалению, то, на сколько методы второго порядка позволяют нам увеличить размер батча, тоже имеет свой предел. Еще один момент: в пересчёте на количество шагов модель сходитcя быстрее. Но считать матрицу производных второго порядка для весов модели — очень дорого в смысле вычислительных затрат. На практике применяются приближённые методы, такие как Kronecker-factored Approximate Curvature (K-FAC), но всё равно считать их не так быстро.
Но ладно бы модель просто сходилась медленней. Часто оказывается так, что модель вообще не достигает точности, которая получается с батчем поменьше. Вы уже чувствуете себя как дома на оптимизационной поверхности? Одна из теорий, которые пытаются объяснить, что происходит, предполагает, что на ней есть пологие минимумы, а есть впадины, как бы это сказать, более острой формы. Пологие минимумы — это хорошо. Острые минимумы — не очень. Попробуйте тут остановиться и подумать: «Почему?» Логика тут такая. У нас же, кроме тренировочной выборки, есть еще и тестовая, которая по замыслу соответствует всем тем данным в реальном мире, которые наша модель не имела счастья видеть. И оптимизационная поверхность — она же не просто так, она зависит от данных. Но поскольку тестовая и обучающая выборки у нас по замыслу похожи, то и поверхности ошибки выходят довольно похожими. Но всё-таки немного отличающимися. И вот если мы находимся на дне пологого минимума в обучающей выборке, то, скорее всего, в тестовой выборке, да и вообще в «настоящих» данных, — примерно там тоже большой пологий минимум, и мы в нём тоже где-то внизу. А если мы в остром минимуме — шаг в сторону, и мы уже на поверхности. Обидно!
Вот и картинка с одномерной иллюстрацией этого эффекта.
Так, а при чем тут всё это, спросите вы? Если модель забредёт в такую узкую расщелину, то гладкие устойчивые градиенты не позволяют ей легко оттуда выбраться. А резкие шаги в случайных направлениях, которых у нас в достатке при маленьком батче, позволяют как бы «выпрыгнуть» и дальше уже, в идеале, скатиться в пологую равнинку, где, прыгай или не прыгай, останешься всё равно там, чего нам, собственно говоря, и хотелось. В общем, стохастичность, которую даёт нам маленький батч, очень полезна.
Тут кто-то может резонно предложить просто добавить случайного шума к нашим градиентам — и действительно, так можно сделать! Другое дело, что это посылает коту под хвост нашу идею делать шаги побольше благодаря более точным градиентам.
Статья, в которой предложили это объяснение с острыми и пологими минимумами, называется «On Large-Batch Training for Deep Learning: Generalization Gap and Sharp Minima». Однако в скором времени аж сам Бенджио пишет в «Sharp Minima Can Generalize For Deep Nets», что, может, и не всё так однозначно. Но так или иначе, на моей практике модель на большом батче действительно может застрять далеко от заданной точности. Еще пара практических трюков, которые могут помочь, — это «рестарт» оптимизатора, т. е. сброс накопленной статистики по градиентам, или более затейливые режимы изменения learning rate, такие как cosine annealing.
Немного конкретики: насколько большой этот самый большой батч? Зависит от данных, от модели, и много от чего еще. В языковых моделях вообще путаница — то считали строки, то отдельные токены. Но раз уж заговорили про гонку «Imagenet за 60 секунд» — на одном GPU можно вместить сотню-другую сэмплов, в зависимости от объёма памяти. Весь ImageNet, если что, — это 1,2 миллиона сэмплов. Не так уж много по современным меркам, но на одном устройстве тогда прогнать всё это через ResNet50 около ста раз (за примерно столько достигалась максимальная точность) занимало дни, если не недели. Раздув размер батча до 8К Goyal, с коллегами смогли это сделать чуть меньше чем за час. Это как раз статья, в которой описали, что надо увеличивать lr и делать прогрев — «Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour». Потом Akiba et al. (Extremely Large Minibatch SGD: Training ResNet-50 on ImageNet in 15 Minutes) со сменой оптимизатора на переправе дожали батч до 32К, а время, соответственно, до 15 минут. Ну и наконец, размер батча в 64К — например, можно посмотреть «Highly Scalable Deep Learning Training System with Mixed-Precision: Training ImageNet in Four Minutes».
В довесок ко всем этим трюкам про которые я рассказывал выше, оказалось что искать подходящие гиперпараметры становится гораздо увлекательней. Например делать weight decay с разным множителем для разных слоёв.
Что можно сказать в заключение? В целом всё не так уж плохо. До разумного предела упомянутые выше техники вполне работают. Увеличиваем learning rate пропорционально размеру батча, делаем прогрев, и оптимизаторы типа AdamW работают вполне себе стабильно. Доходим до момента, где всё практически без сбоев, и если иногда модель «коллапсит», откатываемся на чекпоинт и пробуем на других данных. Часто прокатывает.
А вообще, если данных настолько уж много, то, наверное, имеет смысл взять и модель побольше. А для модели побольше уже имеет смысл применять Model Parallel Training — вместо того, чтобы поделить данные на, скажем, десять частей и тренировать на них копии модели, мы саму модель поделим на десять частей, и проблема с большим батчем уходит сама собой. Насколько можно дробить модель — тоже имеет, конечно, свой предел. На практике имеет смысл гибридный подход: делим модель на какое-то количество частей, и сколько-то таких групп участвуют в data-parallel, как если бы каждая из них была одним узлом. А про то, как делать этот самый data-parallel, можем поговорить в следующий раз, если вам понравился такой подход к повествованию и тема сама по себе. Пока-пока.