Federated Learning-FedSGD, FedAvg
๐ฅ Key Takeaways
- 1๏ธโฃ FL ๊ฐ๋ ์ ์ฒ์์ผ๋ก ์ ์ํ์๋ค.
- 2๏ธโฃ IID, Non-IID ์๊ด์์ด ๊ฐ ๋ชจ๋ธ์ intialization point๊ฐ ๊ฐ์์ผ updated local weight๋ค์ averageํ์ฌ global weight์ ๋ฐ์ํ๋ FedAVGํ์ต ๋ฐฉ์์ด ์ ์๋ํ๋ค.
Federated Learning
- Background: ๊ฐ์ธ์ ๋ณด ๋ณดํธ๊ฐ ์ค์ํ๊ฑฐ๋, ๋ฐ์ดํฐ ์์ด ๋ง์ ๊ฒฝ์ฐ traditionalํ data center์์์ ํ์ต์ด ์ด๋ ต๋ค.
- Federated Learning: ํ์ต ๋ฐ์ดํฐ๋ฅผ ๋ชจ๋ฐ์ผ ๊ธฐ๊ธฐ์ ๋ถ์ฐ์ํจ ์ฑ๋ก, local์์ ๊ณ์ฐ๋ updated๋ฅผ ๋ชจ์์ ๊ณต์ ๋ชจ๋ธ์ ํ์ตํ๋ ๋ฐฉ๋ฒ
์ฅ์
- raw training data์ ์ง์ ์ ๊ทผํ ํ์ ์์ด ํ์ต์ ์งํํ ์ ์๋ค.
- ์๋ฒ์ ๋ํ ์ ๋ขฐ๋ ์ฌ์ ํ ํ์ํ์ง๋ง, ํ์ต ๋ชฉํ๊ฐ ๊ฐ ํด๋ผ์ด์ธํธ์ ๋ฐ์ดํฐ๋ก ์ ์๋ ์ ์๋ค๋ฉด, Federated Learning์ ๊ฐ์ธ์ ๋ณด ๋ณดํธ์ ๋ณด์ risk๋ฅผ ํฌ๊ฒ ์ค์ผ ์ ์์.
- ๊ณต๊ฒฉ ํ๋ฉด์ด ๊ธฐ๊ธฐ๋ง ๊ตญํ๋๋ฏ๋ก, ํด๋ผ์ฐ๋์์ ์ฐ๊ฒฐ์ ์ต์ํํ ์ ์์.
Federated Optimization
- Federated Optimization: Federated learning์์ ์๋ฌต์ ์ผ๋ก ๋ฐ์ํ๋ ์ต์ ํ ๋ฌธ์
- ๋ถ์ฐ ์ต์ ํ์์ ์ฐจ์ด:
- 1๏ธโฃ Non-IID: ๊ฐ ํด๋ผ์ด์ธํธ์ ํ์ต ๋ฐ์ดํฐ๋ ํน์ ์ฌ์ฉ์์ ๊ธฐ๊ธฐ ์ฌ์ฉ์ ๋ฐ๋ผ ๋ฌ๋ผ์ ธ ํด๋น ํด๋ผ์ด์ธํธ์ ๋ฐ์ดํฐ์ ์ ์ ์ฒด ๋ถํฌ๋ฅผ ๋ํํ์ง ์์.
- 2๏ธโฃ Unbalanced: ์ผ๋ถ ์ฌ์ฉ์๋ ๋ค๋ฅธ ์ฌ์ฉ์๋ณด๋ค ์๋น์ค๋ฅผ ํจ์ฌ ๋ ๋ง์ด ์ฌ์ฉํ์ฌ, ๋ก์ปฌ ๋ฐ์ดํฐ ์์ด ๋ค๋ฆ.
- 3๏ธโฃ Massively distributed: ์ฐธ์ฌํ๋ ํด๋ผ์ด์ธํธ์ ์๊ฐ ๊ฐ ํด๋ผ์ด์ธํธ์ ์์ ์๋ณด๋ค ํจ์ฌ ๋ง์.
- 4๏ธโฃ Limited communication: ๋ชจ๋ฐ์ผ ๊ธฐ๊ธฐ๋ ์ข ์ข ์คํ๋ผ์ธ ์ํ๊ฑฐ๋ ๋๋ฆฌ๊ฑฐ๋ ๋น์ผ ์ฐ๊ฒฐ์ ์ฌ์ฉ.
- ์ด ๋ ผ๋ฌธ์์๋ non-IID์ unbalanced ํน์ฑ์ ์ค์ ์ ๋๊ณ , communication cost์ ์ค์์ฑ์ ์ดํด ๋ณธ๋ค.
Communication cost
- communication cost: client์ server ๊ฐ์ ์๋ฃ ์ก์์ (communication)์ ๋๋ ๋น์ฉ
- parameter ํฌ๊ธฐ, # device, ๊ทธ๋ฆฌ๊ณ server์ client๊ฐ์ ๊ฑฐ๋ฆฌ์ ์ํฅ์ ๋ฐ๋๋ค.
- Data center optimization
- communication costs๊ฐ ์๋์ ์ผ๋ก ์๊ณ , computation costs๊ฐ ํฌ๋ค.
- ์ต๊ทผ GPU์ ๋ฐ๋ฌ๋ก computation costs๊ฐ ์ค์ด๋ค์๋ค.
- federated optimization
- communication costs๊ฐ ์ง๋ฐฐ์ ์ด๋ค. ๋ณดํต upload bandwidth๋ 1MB/s ์ดํ๋ก ์ ํ๋๋ค
- client๋ ์ฃผ๋ก ์ถฉ์ ์ค์ด๊ณ , ์ ๋ฃ๊ฐ ์๋ Wi-Fi ์ฐ๊ฒฐ์ ์์ ๋๋ง ์ต์ ํ์ ์ฐธ์ฌํ๋ค.
- ๊ฐ client๋ ํ๋ฃจ์ ๋ช ๋ฒ๋ง update rounds์๋ง ์ฐธ์ฌํ ์ ์๋ค.
- ๋ฐ๋ฉด, ๊ฐ ํด๋ผ์ด์ธํธ์ ๋ฐ์ดํฐ์ ์ ์ ์ฒด ๋ฐ์ดํฐ์ ์ ๋นํด ์๊ธฐ ๋๋ฌธ์, ํ๋ smartphone์ processor(GPU ํฌํจ)๋ ๊ณ์ฐ ๋น์ฉ์ ๊ฑฐ์ ๋ฌด์ํ ์ ์์ ์ ๋๋ก ๋น ๋ฅด๋ค.
- ๋ฐ๋ผ์ ๋ชจ๋ธ ํ์ต์ ์ํ communication rounds ์๋ฅผ ์ค์ด๊ธฐ ์ํด ์ถ๊ฐ์ ์ธ ๊ณ์ฐ์ ์ฌ์ฉํ๋ ๊ฒ์ด ๋ชฉํ์ด๋ค. ์ด๋ฅผ ์ํ ๋ ๊ฐ์ง ์ฃผ์ ๋ฐฉ๋ฒ์ด ์๋ค:
- 1๏ธโฃ Increased parallelism: ๊ฐ communication round ๊ฐ์ ๋ ๋ง์ ํด๋ผ์ด์ธํธ๋ฅผ ๋ ๋ฆฝ์ ์ผ๋ก ์์ ํ๊ฒ ํ๋ค.
- 2๏ธโฃ Increased computation per client: ๊ฐ ํด๋ผ์ด์ธํธ๊ฐ ๊ฐ๋จํ ๊ณ์ฐ(์: gradient ๊ณ์ฐ)์ ์ํํ๋ ๋์ , ๋ ๋ณต์กํ ๊ณ์ฐ์ ์ํํ๋ค.
- ์คํ์์๋ ๋ ๊ฐ์ง ์ ๊ทผ ๋ฐฉ์์ ๋ชจ๋ ์กฐ์ฌํ์ง๋ง, ํด๋ผ์ด์ธํธ ๊ฐ ์ต์ํ์ ๋ณ๋ ฌ์ฑ์ ์ฌ์ฉํ ๊ฒฝ์ฐ ์ป๋ ์๋ ํฅ์์ ์ฃผ๋ก ๊ฐ ํด๋ผ์ด์ธํธ์์ ๋ ๋ง์ ๊ณ์ฐ์ ์ถ๊ฐํ๋ ๋ฐ(2๏ธโฃ)์ ๋์จ๋ค.
1. FedSGD(Stochastic Gradient Descent)
- FL ๊ฐ๋ ์ค๋ช ์ ์ํด ๋์จ baseline ๋ฐฉ๋ฒ(Naรฏve algorithm)์ด๋ฉฐ, ์ค์ ๋ก ์ ์ฐ์ด์ง ์๋๋ค.
- Synchronous update ๋ฐฉ์์ ๊ฐ์ ํ์ฌ comminucation rounds๋ก ์งํ๋๋ค.
- ํ์ต ๊ณผ์
- 1๏ธโฃ Global weight initialization
- 2๏ธโฃ Client sampling with client fraction hyper-parameter $C$
- ex. $C=0.75$: client 4๊ฐ ์ค 3๊ฐ sampling
- 3๏ธโฃ Local learning
- Server: global weight์ client๋ก ๋ณด๋ธ๋ค.
- Client: ๊ฐ์ ๊ฐ์ง local data๋ก ์๋ก์ด gradient๋ฅผ ๊ณ์ฐํ๋ค.
- 4๏ธโฃ Update parameter
- Client: ๊ณ์ฐ๋ gradient๋ฅผ server์ ๋ณด๋ธ๋ค.
- Server: ๋ฐ์ gradients๋ฅผ ๊ฐ์คํ๊ท ๋ด์ด ์๋กญ๊ฒ global weight๋ฅผ updateํ๋ค.
- 5๏ธโฃ 2~4๋ฒ ๋ฐ๋ณต
- Hyper-parameter
- client fraction hyper-parameter $C$
2. FedAvg
- FedAVG = FedSGD + mini-batch ๊ฐ๋
- FedSGD = $B=\infty \land E=1$ ์ธ FedAVG
- ์ฃผ์ Hyper-parameter
- client fraction hyper-parameter $C$
- mini-batch size $B$
- epoch $E$
- Synchronous update ๋ฐฉ์์ ๊ฐ์ ํ์ฌ comminucation rounds๋ก ์งํ๋๋ค.
- ํ์ต ๊ณผ์
- 1๏ธโฃ Global weight initialization
- 2๏ธโฃ Client sampling with client fraction hyper-parameter $C$
- ex. $C=0.75$: client 4๊ฐ ์ค 3๊ฐ sampling
- 3๏ธโฃ Local learning
- Server: global weight์ client๋ก ๋ณด๋ธ๋ค.
- Client: ๊ฐ์ ๊ฐ์ง local data๋ก ์๋ก์ด gradient๋ฅผ mini-batch ๋จ์($B$)๋ก ๊ณ์ฐํ๋ค. ๋ชจ๋ mini-batch๋ฅผ ๋ค ๋๋ฉด ํ๋์ epoch๊ฐ ์งํ๋ ๊ฒ์ด๋ค. ํ๋์ client๋น ํ๋์ weight๊ฐ ๋์ค๊ฒ ๋๋ค.
- 4๏ธโฃ Update parameter
- Client: ๊ณ์ฐ๋ gradient๋ฅผ server์ ๋ณด๋ธ๋ค.
- Server: ๋ฐ์ gradients๋ฅผ ๊ฐ์คํ๊ท ๋ด์ด ์๋กญ๊ฒ global weight๋ฅผ updateํ๋ค.
- 5๏ธโฃ 2~4๋ฒ ๋ฐ๋ณต
- Algorithm์ pseudo-code๋ ๋ค์๊ณผ ๊ฐ๋ค.
FedAVG๋ communication round๋ฅผ ํฌ๊ฒ ์ค์ฌ, ๋ถ์ฐ ๋ฐ์ดํฐ์์ ๋ชจ๋ธ์ ํ์ตํ๋ ๋ฐ ๋๋ ์๊ฐ์ ๋ํญ ๋จ์ถ์์ผฐ๋ค. ๋ํ ์ ์๋ค์ averaging process๊ฐ ๋ง์น dropout๊ณผ ๊ฐ์ regularization์ฒ๋ผ ์์ฉํ๋ ๊ฒ์ผ๋ก ์ถ์ธกํ๋ค.
Robust to imbalanced, Non-IID data distribution
์ผ๋ฐ์ ์ธ Non-convex objectives์์๋ ํ๋ผ๋ฏธํฐ ๊ณต๊ฐ์์ ๋ชจ๋ธ์ ํ๊ท ํํ๋ ๊ฒ ๋ชจ๋ธ์ ์ข์ง ์๋ค. ํ์ง๋ง ์ด ๋ ผ๋ฌธ์ FedAVG๋ฅผ ์ ์ฉํ๋ฉด ๋ชจ๋ธ์ ํฌ๊ฒ ๊ฐ์ ์ํฌ ์ ์๋ค๊ณ ์ฃผ์ฅํ์๋ค. ์ ์๋ค์ MNIST ๋ฐ์ดํฐ์ ์ ์ฌ์ฉํ์๋ค.
๋ MNIST ์ซ์ ์ธ์ ๋ชจ๋ธ $w$์ $wโ$๋ ๊ฐ๊ฐ MNIST ํ๋ จ ์ธํธ์์ 600๊ฐ์ ์๋ก ๋ค๋ฅธ IID ์ํ๋ก ํ๋ จ๋์๋ค. ํ๋ จ์ SGD ๋ฐฉ์์ผ๋ก, fixed learning rate 0.1๋ก 240๋ฒ์ ์ ๋ฐ์ดํธ๋ฅผ ๊ฑฐ์ณ ์ด๋ฃจ์ด์ก๋ค. ๊ฐ mini-batch ํฌ๊ธฐ๋ 50์ด๋ฉฐ, mini-dataset ํฌ๊ธฐ 600์ ๋ํด 20๋ฒ์ ํจ์ค๋ฅผ ์ํํ๋ค. ์ด ์ ๋ ํ๋ จ์ด ์งํ๋๋ฉด, ๋ชจ๋ธ์ ๊ฐ์์ ๋ก์ปฌ ๋ฐ์ดํฐ์ ์ ๊ณผ์ ํฉ๋๊ธฐ ์์ํ๋ค.
- ์ผ์ชฝ Figure: ๋ ๋ชจ๋ธ์ ์๋ก ๋ค๋ฅธ random initialization์์๋ถํฐ ํ๋ จํ ํ ํ๊ท ํํ ๋ชจ์ต. ์ข์ง ์์ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์ธ๋ค.
- ์ค๋ฅธ์ชฝ Figure(FedAVG): ๋ ๋ชจ๋ธ์ same random initialization์์ ์์ํ๊ณ , ๊ฐ ๋ชจ๋ธ์ ๋ฐ์ดํฐ๋ฅผ ์๋ก ๋ค๋ฅธ ๋ถ๋ถ ์งํฉ์ ๋ํด ๋ ๋ฆฝ์ ์ผ๋ก ํ๋ จ์ํจ ํ ๋ชจ๋ธ์ ํ๊ท ํํ ๋ชจ์ต
์ฆ, ๊ฐ ๋ชจ๋ธ์ intialization point๊ฐ ๊ฐ์์ผ ๋์ค์ ์ ๋ฐ์ดํธ๋์ด ๋์จ weight๋ค์ ํ๊ท ๋ด์์๋ ์ด ํ์ต ๋ฐฉ์์ด ์ ์๋ํ๋ค!(IID, Non-IID ์๊ด์์ด) ์ด ์คํ์ ํตํด ๋ถ๊ท ํ์ ์ด๊ณ ๋น๋ ๋ฆฝ์ ์ด์ง ์์ ๋ฐ์ดํฐ ๋ถํฌ์๋ robustํจ์ ์ ์ฆํ์๋ค.
Hyper-parameter tuning ์์
- $C \rightarrow B \rightarrow E$
- $E$๋ฅผ ๋๋ฆฌ๋ ๊ฒ์ $B$๋ฅผ ์ค์ด๋ ๊ฒ๋ณด๋ค ๋ ๋ง์ ์๊ฐ์ ํ์๋ก ํจ.