์ž๋ฐ”์Šคํฌ๋ฆฝํŠธ๋ฅผ ํ™œ์„ฑํ™” ํ•ด์ฃผ์„ธ์š”

Backward Pass, Optimizer, Batch Normalization

[TIL] ์˜์นด X ๋ฉ‹์Ÿ์ด์‚ฌ์ž์ฒ˜๋Ÿผ (AI ์—”์ง€๋‹ˆ์–ด ์œก์„ฑ ๋ถ€ํŠธ ์บ ํ”„ 2๊ธฐ) 5์ฃผ์ฐจ


Introduction


We are now in Week 5. This chapter covered the overall MLP pipeline, with the Forward Pass and Backward Pass as the main topics. It also covered the Optimizer, used to get the best model performance, and Batch Normalization, used for training speed and stability. It felt like the most overwhelming lecture so far, and everyone in my peer group agreed. Still, watching the shared YouTube lectures and blog posts helped raise my understanding at least a little.

Week 5


  1. Backward Pass
  2. Optimizer
  3. Models That Learn Without Memorizing
  4. Batch Normalization
  5. (Practice) MLP MNIST classification (2)

Backward Pass


Back Propagation ์ด๋ž€?

  • To adjust each parameter in the direction that reduces the loss, compute the partial derivative of the loss with respect to each parameter and use it to update that parameter
  • Partial Derivative: in a multivariate function like an MLP, the variables affect the output jointly, so we take the derivative with respect to one particular variable while treating the other variables as constants
  • Chain Rule: using the chain rule, the partial derivative of the overall loss can be obtained from the local gradient saved during the forward pass and the global (upstream) gradient flowing back through back propagation (a small sketch follows this list)
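
As a quick illustration (my own minimal sketch, not code from the lecture), here is the chain rule worked by hand for a single linear neuron with a squared-error loss; all the numbers are made up:

```python
# Minimal chain-rule sketch: one linear neuron y = w*x + b with loss L = (y - t)^2.
x, w, b, t = 2.0, 0.5, 0.1, 1.0

# Forward pass: compute and cache intermediate values (the "local" information).
y = w * x + b          # y = 1.1
L = (y - t) ** 2       # L = 0.01

# Backward pass: multiply the upstream (global) gradient by each local gradient.
dL_dy = 2 * (y - t)    # upstream gradient arriving at the node
dL_dw = dL_dy * x      # chain rule: dL/dw = dL/dy * dy/dw, where dy/dw = x
dL_db = dL_dy * 1.0    # dy/db = 1

# Gradient-descent update in the direction that reduces the loss.
lr = 0.1
w -= lr * dL_dw
b -= lr * dL_db
print(w, b, L)
```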

Forward Pass - Backward Pass

  • MLP์˜ ์—ฐ์‚ฐ๋“ค์€ matrix multiplication๊ณผ nonlinear activation function ์œผ๋กœ ๊ตฌ์„ฑ๋˜์–ด ์•„๋ž˜์™€ ๊ฐ™์ด ๊ตฌ๋ถ„๊ฐ€๋Šฅ

  • Forward Pass ๋“ค์˜ basic operation 4๊ฐ€์ง€๋ฅผ ๋ณด๊ณ , Backward Pass๋ฅผ ํ• ๋•Œ ์—๋Ÿฌ ์‹œ๊ทธ๋„์ด ์–ด๋–ป๊ฒŒ ์ „๋‹ฌ๋˜๋Š”์ง€ ์‚ดํŽด๋ณด์ž.


    [Figure 1: Addition Operation]

    [Figure 2: Multiplication Operation]

    [Figure 3: Common Variable Operation]

    [Figure 4: Nonlinear Activation Function]

  • Nonlinear Activation Function

    • Activation function์˜ backward pass๋ฅผ ์œ„ํ•ด์„œ๋Š” ํ•ด๋‹น node์—์„œ ๊ฐ ํ•จ์ˆ˜์˜ ๋ฏธ๋ถ„๊ฐ’์ด ํ•„์š”.

      • Sigmoid, ReLu …
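
Putting the four basic operations and the activation derivatives together, here is a minimal sketch of the node-level backward rules (my own illustration, assuming NumPy arrays):

```python
import numpy as np

# How each basic node routes the upstream gradient g during the backward pass:
# - addition distributes g unchanged to both inputs,
# - multiplication swaps the inputs and scales g,
# - a variable used in several places sums the gradients from each use,
# - an activation multiplies g by its own derivative.

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def sigmoid_backward(g, x):
    s = sigmoid(x)
    return g * s * (1.0 - s)   # d(sigmoid)/dx = s(1 - s)

def relu_backward(g, x):
    return g * (x > 0)         # d(ReLU)/dx is 1 where x > 0, else 0

def add_backward(g):
    return g, g                # z = a + b: dz/da = dz/db = 1

def mul_backward(g, a, b):
    return g * b, g * a        # z = a * b: dz/da = b, dz/db = a

def common_variable_backward(g1, g2):
    return g1 + g2             # gradients from each use of the variable are summed
```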

Optimizer


  • Gradient descent updates the parameters many times to find the parameter values that are optimal with respect to the loss.
  • In deep learning the parameter space is extremely high-dimensional -> finding the global optimal point is impossible.
  • The goal is to avoid saddle points (shaped like a horse's saddle) and find a good local minimum.

Gradient-based Methods

  • First-order Optimization Methods
  • Update the parameters in the opposite direction of the loss function's gradient to obtain parameters with a smaller loss
  • Uses mini-batches
  1. Stochastic Gradient Descent (SGD)

    • Updates parameters in the direction opposite to the gradient
    • The fastest and easiest to apply, but prone to falling into saddle points
    • The gradient is noisy, so the update direction tends to oscillate
  2. Momentum

    • Keeps the gradient from changing too quickly, encouraging updates in a consistent direction
    • Adds a hyper-parameter, the momentum factor
  3. AdaGrad

    • Performs parameter-wise updates to fix the problem of update directions oscillating too strongly
    • Applies a parameter-wise learning rate based on the gradient history
    • Shrinks the updates of parameters that have been updated a lot, and grows the updates of parameters that have not been updated much
    • The learning rate keeps decreasing, which makes it hard to use in deep learning.
  4. RMSprop

    • Applies momentum to AdaGrad's gradient accumulation S
    • Reduces the effect of gradients from the distant past
  5. Adam

    • The most widely used
    • A combination of RMSprop and momentum
    • Uses a technique called bias correction to keep each momentum term from behaving erratically early in training (a sketch of all five update rules follows this list)
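
To make the update rules above concrete, here is a minimal NumPy sketch of one step of each optimizer (my own illustration, not code from the lecture; the hyper-parameter values are common defaults, not prescribed ones):

```python
import numpy as np

# One update step per optimizer, for a parameter w with gradient g.
# lr, beta, and eps values below are typical defaults, chosen for illustration.

def sgd(w, g, lr=0.01):
    return w - lr * g

def momentum(w, g, v, lr=0.01, beta=0.9):
    v = beta * v + g                       # accumulate velocity in a consistent direction
    return w - lr * v, v

def adagrad(w, g, s, lr=0.01, eps=1e-8):
    s = s + g ** 2                         # full history of squared gradients (only grows)
    return w - lr * g / (np.sqrt(s) + eps), s

def rmsprop(w, g, s, lr=0.001, beta=0.9, eps=1e-8):
    s = beta * s + (1 - beta) * g ** 2     # exponential average: distant past fades out
    return w - lr * g / (np.sqrt(s) + eps), s

def adam(w, g, m, v, t, lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
    # t is the step count, starting at 1.
    m = b1 * m + (1 - b1) * g              # first moment (momentum part)
    v = b2 * v + (1 - b2) * g ** 2         # second moment (RMSprop part)
    m_hat = m / (1 - b1 ** t)              # bias correction stabilizes the early steps
    v_hat = v / (1 - b2 ** t)
    return w - lr * m_hat / (np.sqrt(v_hat) + eps), m, v
```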


  • Learning Rate Scheduling

    • ํ•™์Šต์ด ์งํ–‰๋ ์ˆ˜๋ก parameter๊ฐ€ ์ตœ์  ๊ฐ’์œผ๋กœ ๋‹ค๊ฐ€๊ฐ€๊ธฐ ๋•Œ๋ฌธ์— learning rate๋ฅผ ์ค„์—ฌ ๋” ์ •ํ™•ํ•œ ์ˆ˜๋ ด์„ ์‹œ๋„
    • Linear decay, step decay, exponential decay
  • Parameter Initialization

    • The initial parameter settings also matter.
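
A minimal sketch of the three decay schedules listed above (my own illustration; the constants are arbitrary):

```python
import math

def linear_decay(lr0, epoch, total_epochs=100):
    # Linearly anneal from lr0 down to 0 over the full run.
    return lr0 * max(0.0, 1.0 - epoch / total_epochs)

def step_decay(lr0, epoch, drop=0.5, every=10):
    # Halve the learning rate every 10 epochs (illustrative values).
    return lr0 * (drop ** (epoch // every))

def exponential_decay(lr0, epoch, k=0.05):
    # Smooth exponential decrease.
    return lr0 * math.exp(-k * epoch)
```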

Models That Learn Without Memorizing


  • Regularization

    • Overfitting์„ ๋ง‰๊ธฐ ์œ„ํ•œ ๊ธฐ๋ฒ•
    • Overfitting: ๋ฐ์ดํ„ฐ์˜ ์กด์žฌํ•˜๋Š” noise๊นŒ์ง€ ํ•™์Šตํ•จ์— ๋”ฐ๋ผ ํ•™์Šต ๋ฐ์ดํ„ฐ๊ฐ€ ์•„๋‹Œ ๋ฐ์ดํ„ฐ์— ๋Œ€ํ•ด ์ •ํ™•ํ•œ ์ถ”๋ก ์„ ํ•˜์ง€ ๋ชปํ•˜๋Š” ๊ฒฝ์šฐ
    1. Norm Regularizations
    2. Early Stopping
      • Validation set์˜ ์„ฑ๋Šฅํ–ฅ์ƒ์ด ๋” ์ด์ƒ ๋‚˜ํƒ€๋‚˜์ง€ ์•Š์„๋•Œ ํ•™์Šต์„ ๋ฉˆ์ถ”๋Š” ๊ธฐ๋ฒ•
      • ํ•˜์ง€๋งŒ ์‹ค์ œ ํ•™์Šต์‹œ validation ์„ฑ๋Šฅ์ด ํ•œ์ฐธ ์˜ค๋ฅด์ง€ ์•Š์„๋•Œ๊ฐ€ ์žˆ๊ธฐ ๋•Œ๋ฌธ์— ์ฃผ์˜๊ฐ€ ํ•„์š”
    3. Ensemble Methods
      • ๋‹ค์–‘ํ•œ hyper-parmeter ์กฐ์ ˆ + randomness
    4. Dropout
      • ๋งค๋ฒˆ forward pass๋ฅผ ํ•  ๋•Œ๋งˆ๋‹ค ์ „์ฒด parameter ์ค‘ ์ผ๋ถ€๋ฅผ masking
      • ๋ชจ๋ธ ์ „์ฒด parameter ์ค‘ ์ผ๋ถ€๋ฅผ ์ด์šฉํ•ด์„œ๋„ ์ข‹์€ ์„ฑ๋Šฅ์„ ์–ป์„์ˆ˜ ์žˆ๋„๋ก ์œ ๋„
      • batch normalizaion? ์ด ๋น„์Šทํ•œ ํšจ๊ณผ๋ฅผ ๋‚ธ๋‹ค
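
In practice dropout is usually implemented by masking activations rather than the weights themselves; here is a minimal inverted-dropout sketch (my own illustration, assuming NumPy):

```python
import numpy as np

def dropout_forward(x, p_drop=0.5, training=True):
    # Inverted dropout: randomly zero units during training and rescale
    # so the expected activation matches inference time.
    if not training:
        return x                              # no masking at inference
    mask = (np.random.rand(*x.shape) >= p_drop) / (1.0 - p_drop)
    return x * mask
```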

Batch Normalization


  • ํ•™์Šต์•ˆ์ •๋„, ํ•™์Šต์†๋„์— ๋งŽ์€ ๊ฐœ์„ ์„ ์ค€ ์•Œ๊ณ ๋ฆฌ์ฆ˜

  • Activation Distribution Assumption

    • ๋ชจ๋ธ ์ž์ฒด์— ๋Œ€ํ•œ ๋ถ„์„๊ณผ ์—ฌ๋Ÿฌ ์œ ์šฉํ•œ ์•Œ๊ณ ๋ฆฌ์ฆ˜ ๋Œ€๋ถ€๋ถ„ activation๊ณผ parameter ๋ถ„ํฌ์— ๋Œ€ํ•ด Gaussian์„ ๊ฐ€์ •
    • ์‹ค์ œ๋กœ๋Š” ์ด Gaussian ๋ถ„ํฌ๋ฅผ ๋”ฐ๋ฅด์ง€ ์•Š๋Š”๋‹ค. -> ํ•™์Šต์†๋„๊ฐ€ ๋Š๋ ค์ง€๊ณ  ํ•™์Šต๋ฐฉํ–ฅ์ด ์ผ์ •ํ•˜์ง€ ์•Š๋‹ค.
    • ์ด๋ฅผ ์œ„ํ•ด mini-batch ๋‹จ์œ„๋กœ activation์€ normalize ํ•˜์—ฌ ์›ํ•˜๋Š” ๋ถ„ํฌ๋กœ ๋งŒ๋“ค์–ด์ค€๋‹ค.
  • Batch Normalizaition

    • ๊ฐ layer์˜ activation์„ batch ๋‹จ์œ„๋กœ normalize๋ฅผ ํ•˜์—ฌ ์›ํ•˜๋Š” ๋ถ„ํฌ๋กœ ๋งŒ๋“ค์–ด ์คŒ
    • RNA๋‚˜ lstm์˜ ๊ฒฝ์šฐ ๋ชจ๋ธ์˜ ํŠน์„ฑ์ƒ ์‚ฌ์šฉ์ด ์–ด๋ ต๋‹ค.
    • ํ•™์Šต ๊ณผ์ •์—์„œ๋Š” mini-batch ์ „์ฒด์˜ ์ •๋ณด๋ฅผ ์ด์šฉํ•ด batch-statics๋ฅผ ๊ณ„์‚ฐํ•˜์—ฌ normalize์— ์‚ฌ์šฉ
