oauth2.go 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. package controllers
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "github.com/astaxie/beego/logs"
  7. "github.com/astaxie/beego/orm"
  8. "io"
  9. "nginx-ui/server/config"
  10. "nginx-ui/server/models"
  11. "nginx-ui/server/utils"
  12. )
  13. type Oauth2Controller struct {
  14. BaseController
  15. }
  // Oauth2SSOReq is the JSON request body that Callback decodes.
  type Oauth2SSOReq struct {
  	// Code is the authorization code that Callback passes to the token exchange.
  	Code string `json:"code"`
  	// Scope is decoded from the "scope" key; the handlers visible here do not read it.
  	Scope string `json:"scope"`
  	// State is decoded from the "state" key; the handlers visible here do not read it.
  	State string `json:"state"`
  }
  21. // Get 获取oauth2.0的登录url
  22. func (c *Oauth2Controller) Get() {
  23. state, err := utils.RandPassword(6)
  24. if err != nil {
  25. c.ErrorJson(err)
  26. return
  27. }
  28. url := config.OauthConfig.AuthCodeURL(state)
  29. c.addRespData("redirect_url", url).addRespData("state", state).json()
  30. }
  31. // Callback 用户注册
  32. func (c *Oauth2Controller) Callback() {
  33. var ssoReq Oauth2SSOReq
  34. err := json.Unmarshal(c.Ctx.Input.RequestBody, &ssoReq)
  35. if err != nil {
  36. logs.Error(err, string(c.Ctx.Input.RequestBody))
  37. c.ErrorJson(err)
  38. return
  39. }
  40. oauth := config.OauthConfig
  41. if len(ssoReq.Code) == 0 {
  42. c.setCode(-1).setMsg("登录失败(Code):code is empty").json()
  43. return
  44. }
  45. token, err := oauth.Exchange(context.Background(), ssoReq.Code)
  46. if err != nil {
  47. logs.Error("ExchangeToken", err)
  48. c.setCode(-1).setMsg("登录失败(Exchange):" + err.Error()).json()
  49. return
  50. }
  51. client := oauth.Client(context.Background(), token)
  52. resp, err := client.Get(oauth.Userinfo)
  53. if err != nil {
  54. logs.Error("GetUserinfo", err)
  55. c.setCode(-1).setMsg(fmt.Sprintf("登录失败(Userinfo):%s", err.Error())).json()
  56. return
  57. }
  58. defer resp.Body.Close()
  59. content, err := io.ReadAll(resp.Body)
  60. if err != nil {
  61. logs.Error("GetUserinfo Read Body", err)
  62. c.setCode(-1).setMsg(fmt.Sprintf("登录失败(Userinfo):%s", err.Error())).json()
  63. return
  64. }
  65. user := models.User{}
  66. err = json.Unmarshal(content, &user)
  67. if err != nil {
  68. logs.Error("GetUserinfo Unmarshal", err)
  69. }
  70. if len(user.Account) == 0 {
  71. c.setCode(-1).setMsg("登录失败,请确认userinfo接口返回了account字段").json()
  72. return
  73. }
  74. if len(user.Nickname) == 0 {
  75. user.Nickname = user.Account
  76. }
  77. o := orm.NewOrm()
  78. err = o.Read(&user, "Account")
  79. if err != nil {
  80. _, err = o.Insert(&user)
  81. }
  82. user.Password = ""
  83. if err != nil {
  84. c.setCode(-1).setMsg(fmt.Sprintf("保存用户失败:%s", err.Error())).json()
  85. return
  86. }
  87. c.SetSession("user", user)
  88. c.setData(user).json()
  89. }