
Federated Learning - FedSGD, FedAvg

Paper Link

🥑 Key Takeaways

  • 1๏ธโƒฃ FL ๊ฐœ๋…์„ ์ฒ˜์Œ์œผ๋กœ ์ œ์‹œํ•˜์˜€๋‹ค.
  • 2๏ธโƒฃ IID, Non-IID ์ƒ๊ด€์—†์ด ๊ฐ ๋ชจ๋ธ์˜ intialization point๊ฐ€ ๊ฐ™์•„์•ผ updated local weight๋“ค์„ averageํ•˜์—ฌ global weight์— ๋ฐ˜์˜ํ•˜๋Š” FedAVGํ•™์Šต ๋ฐฉ์‹์ด ์ž˜ ์ž‘๋™ํ•œ๋‹ค.

Federated Learning

  • Background: when privacy is important or the amount of data is large, traditional training in a data center is difficult.
  • Federated Learning: a method that leaves the training data distributed on mobile devices and trains a shared model by aggregating locally computed updates.

Advantages

  • Training can proceed without direct access to the raw training data.
  • Trust in the server is still required, but if the training objective can be defined in terms of each client's data, federated learning can greatly reduce privacy and security risks.
  • Since the attack surface is limited to the devices, the connection to the cloud can be minimized.

Federated Optimization

  • Federated optimization: the optimization problem implicit in federated learning
  • Differences from distributed optimization:
    • 1️⃣ Non-IID: each client's training data depend on that user's device usage, so a client's local dataset is not representative of the overall distribution (a partition sketch follows this list).
    • 2️⃣ Unbalanced: some users use the service far more heavily than others, so the amount of local data varies across clients.
    • 3️⃣ Massively distributed: the number of participating clients is much larger than the number of examples per client.
    • 4️⃣ Limited communication: mobile devices are often offline or on slow, expensive connections.
  • This paper focuses on the non-IID and unbalanced properties and examines the importance of communication cost.

Communication cost

  • Communication cost: the cost of sending and receiving data (communication) between the clients and the server
  • It is affected by the parameter size, the number of devices, and the distance between the server and the clients.
  • Data center optimization
    • Communication costs are relatively small; computation costs dominate.
    • Recent advances in GPUs have reduced computation costs.
  • Federated optimization
    • Communication costs dominate; upload bandwidth is typically limited to 1 MB/s or less.
    • Clients usually participate in optimization only when they are charging and on an unmetered Wi-Fi connection.
    • Each client can participate in only a few update rounds per day.
    • In contrast, since each client's dataset is small relative to the full dataset, modern smartphone processors (including GPUs) are fast enough that computation cost is almost negligible.
  • ๋”ฐ๋ผ์„œ ๋ชจ๋ธ ํ•™์Šต์„ ์œ„ํ•œ communication rounds ์ˆ˜๋ฅผ ์ค„์ด๊ธฐ ์œ„ํ•ด ์ถ”๊ฐ€์ ์ธ ๊ณ„์‚ฐ์„ ์‚ฌ์šฉํ•˜๋Š” ๊ฒƒ์ด ๋ชฉํ‘œ์ด๋‹ค. ์ด๋ฅผ ์œ„ํ•œ ๋‘ ๊ฐ€์ง€ ์ฃผ์š” ๋ฐฉ๋ฒ•์ด ์žˆ๋‹ค:
    • 1๏ธโƒฃ Increased parallelism: ๊ฐ communication round ๊ฐ„์— ๋” ๋งŽ์€ ํด๋ผ์ด์–ธํŠธ๋ฅผ ๋…๋ฆฝ์ ์œผ๋กœ ์ž‘์—…ํ•˜๊ฒŒ ํ•œ๋‹ค.
    • 2๏ธโƒฃ Increased computation per client: ๊ฐ ํด๋ผ์ด์–ธํŠธ๊ฐ€ ๊ฐ„๋‹จํ•œ ๊ณ„์‚ฐ(์˜ˆ: gradient ๊ณ„์‚ฐ)์„ ์ˆ˜ํ–‰ํ•˜๋Š” ๋Œ€์‹ , ๋” ๋ณต์žกํ•œ ๊ณ„์‚ฐ์„ ์ˆ˜ํ–‰ํ•œ๋‹ค.
    • ์‹คํ—˜์—์„œ๋Š” ๋‘ ๊ฐ€์ง€ ์ ‘๊ทผ ๋ฐฉ์‹์„ ๋ชจ๋‘ ์กฐ์‚ฌํ•˜์ง€๋งŒ, ํด๋ผ์ด์–ธํŠธ ๊ฐ„ ์ตœ์†Œํ•œ์˜ ๋ณ‘๋ ฌ์„ฑ์„ ์‚ฌ์šฉํ•  ๊ฒฝ์šฐ ์–ป๋Š” ์†๋„ ํ–ฅ์ƒ์€ ์ฃผ๋กœ ๊ฐ ํด๋ผ์ด์–ธํŠธ์—์„œ ๋” ๋งŽ์€ ๊ณ„์‚ฐ์„ ์ถ”๊ฐ€ํ•˜๋Š” ๋ฐ(2๏ธโƒฃ)์„œ ๋‚˜์˜จ๋‹ค.

1. FedSGD (Stochastic Gradient Descent)

  • A baseline (naïve) algorithm introduced to explain the FL concept; it is rarely used in practice.
  • It assumes synchronous updates and proceeds in communication rounds.


  • Training procedure (a minimal sketch in code follows this list)
    • 1️⃣ Global weight initialization
    • 2️⃣ Client sampling with client fraction hyper-parameter $C$
      • e.g., $C=0.75$: sample 3 of 4 clients
    • 3️⃣ Local learning
      • Server: sends the global weights to the clients.
      • Client: computes a new gradient on its own local data.
    • 4️⃣ Parameter update
      • Client: sends the computed gradient to the server.
      • Server: takes a weighted average of the received gradients and updates the global weights.
    • 5️⃣ Repeat steps 2-4
  • Hyper-parameter
    • client fraction hyper-parameter $C$
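A minimal sketch of one FedSGD round, assuming a linear model with squared loss; `clients`, the learning rate, and all names are illustrative, not the paper's notation. The server step implements $w_{t+1} \leftarrow w_t - \eta \sum_k \frac{n_k}{n} g_k$, where $g_k$ is client $k$'s gradient and $n_k$ its local dataset size.

```python
import numpy as np

def local_gradient(w, X, y):
    """Full-batch gradient of the squared loss on one client's local data."""
    return X.T @ (X @ w - y) / len(y)

def fedsgd_round(w, clients, C=0.75, lr=0.1, rng=np.random.default_rng(0)):
    """One synchronous FedSGD communication round over (X, y) client datasets."""
    m = max(int(C * len(clients)), 1)                 # number of sampled clients
    sampled = rng.choice(len(clients), size=m, replace=False)
    n_total = sum(len(clients[k][1]) for k in sampled)
    # Average the client gradients, weighted by local dataset size.
    grad = sum(len(clients[k][1]) / n_total * local_gradient(w, *clients[k])
               for k in sampled)
    return w - lr * grad                              # server-side gradient step
```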

2. FedAvg

  • FedAvg = FedSGD + the mini-batch (and local epoch) concept
  • FedSGD is the special case of FedAvg with $B=\infty$ and $E=1$.
  • Main hyper-parameters
    • client fraction hyper-parameter $C$
    • mini-batch size $B$
    • number of local epochs $E$
  • It assumes synchronous updates and proceeds in communication rounds.
  • Training procedure
    • 1️⃣ Global weight initialization
    • 2️⃣ Client sampling with client fraction hyper-parameter $C$
      • e.g., $C=0.75$: sample 3 of 4 clients
    • 3️⃣ Local learning
      • Server: sends the global weights to the clients.
      • Client: computes gradient updates on its own local data in mini-batches of size $B$; one pass over all mini-batches completes one epoch, and this is repeated for $E$ epochs. Each client ends up with a single updated weight vector.
    • 4️⃣ Parameter update
      • Client: sends its updated weights (not raw gradients, as in FedSGD) to the server.
      • Server: takes a weighted average of the received weights and updates the global weights.
    • 5️⃣ Repeat steps 2-4
  • A minimal sketch standing in for the algorithm's pseudo-code is given below.

FedAvg greatly reduces the number of communication rounds, dramatically shortening the time needed to train models on distributed data. The authors also conjecture that the averaging process acts as a form of regularization, similar to dropout.

Robust to unbalanced, non-IID data distributions

์ผ๋ฐ˜์ ์ธ Non-convex objectives์—์„œ๋Š” ํŒŒ๋ผ๋ฏธํ„ฐ ๊ณต๊ฐ„์—์„œ ๋ชจ๋ธ์„ ํ‰๊ท ํ™”ํ•˜๋Š” ๊ฒŒ ๋ชจ๋ธ์— ์ข‹์ง€ ์•Š๋‹ค. ํ•˜์ง€๋งŒ ์ด ๋…ผ๋ฌธ์€ FedAVG๋ฅผ ์ ์šฉํ•˜๋ฉด ๋ชจ๋ธ์„ ํฌ๊ฒŒ ๊ฐœ์„ ์‹œํ‚ฌ ์ˆ˜ ์žˆ๋‹ค๊ณ  ์ฃผ์žฅํ•˜์˜€๋‹ค. ์ €์ž๋“ค์€ MNIST ๋ฐ์ดํ„ฐ์…‹์„ ์‚ฌ์šฉํ•˜์˜€๋‹ค.

(Figure: loss of the averaged MNIST model; left: different initializations, right: shared initialization)

Two MNIST digit-recognition models $w$ and $w'$ were each trained on 600 different IID samples from the MNIST training set. Training used SGD with a fixed learning rate of 0.1 for 240 updates; with a mini-batch size of 50, this amounts to 20 passes over each mini-dataset of size 600. With this much training, the models begin to overfit their local datasets.

  • Left figure: the two models are trained from different random initializations and then averaged; the result is poor.
  • Right figure (FedAvg): the two models start from the same random initialization, each is trained independently on a different subset of the data, and then the two models are averaged.

In other words, this training scheme works well when the models share the same initialization point, so that the weights produced by the later updates can be averaged (regardless of IID or non-IID data). This experiment demonstrates robustness to unbalanced and non-IID data distributions.
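The figure can be reproduced schematically by evaluating the loss of the parameter-space mixture $\theta w + (1-\theta) w'$ over a range of mixing weights. A minimal sketch, assuming `loss_fn`, `w`, and `w_prime` come from the two trained models described above:

```python
import numpy as np

def interpolation_curve(loss_fn, w, w_prime,
                        thetas=np.linspace(-0.2, 1.2, 50)):
    """Loss of the mixture theta * w + (1 - theta) * w_prime for each theta."""
    return [(theta, loss_fn(theta * w + (1 - theta) * w_prime))
            for theta in thetas]
```

With a shared initialization the curve stays low between the two endpoints; with different initializations it spikes in the middle.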

Hyper-parameter tuning order

  • Tune in the order $C \rightarrow B \rightarrow E$.
  • Increasing $E$ requires more time than decreasing $B$.
