oauth2.go 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. package controllers
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "github.com/astaxie/beego/logs"
  7. "github.com/astaxie/beego/orm"
  8. "io"
  9. "nginx-ui/server/config"
  10. "nginx-ui/server/models"
  11. "nginx-ui/server/utils"
  12. )
  13. type Oauth2Controller struct {
  14. BaseController
  15. }
  16. type Oauth2SSOReq struct {
  17. Code string `json:"code"`
  18. Scope string `json:"scope"`
  19. State string `json:"state"`
  20. }
  21. // Get 获取oauth2.0的登录url
  22. func (c *Oauth2Controller) Get() {
  23. state, err := utils.RandPassword(6)
  24. if err != nil {
  25. c.ErrorJson(err)
  26. return
  27. }
  28. url := config.OauthConfig.AuthCodeURL(state)
  29. c.addRespData("redirect_url", url).addRespData("state", state).json()
  30. }
  31. // Callback 用户注册
  32. func (c *Oauth2Controller) Callback() {
  33. var ssoReq Oauth2SSOReq
  34. err := json.Unmarshal(c.Ctx.Input.RequestBody, &ssoReq)
  35. if err != nil {
  36. logs.Error(err, string(c.Ctx.Input.RequestBody))
  37. c.ErrorJson(err)
  38. return
  39. }
  40. oauth := config.OauthConfig
  41. if len(ssoReq.Code) == 0 {
  42. c.setCode(-1).setMsg("登录失败(Code):code is empty").json()
  43. return
  44. }
  45. token, err := oauth.Exchange(context.Background(), ssoReq.Code)
  46. if err != nil {
  47. logs.Error("ExchangeToken", err)
  48. c.setCode(-1).setMsg("登录失败(Exchange):" + err.Error()).json()
  49. return
  50. }
  51. client := oauth.Client(context.Background(), token)
  52. resp, err := client.Get(oauth.Userinfo)
  53. if err != nil {
  54. logs.Error("GetUserinfo", err)
  55. c.setCode(-1).setMsg(fmt.Sprintf("登录失败(Userinfo):%s", err.Error())).json()
  56. return
  57. }
  58. defer resp.Body.Close()
  59. content, err := io.ReadAll(resp.Body)
  60. if err != nil {
  61. logs.Error("GetUserinfo Read Body", err)
  62. c.setCode(-1).setMsg(fmt.Sprintf("登录失败(Userinfo):%s", err.Error())).json()
  63. return
  64. }
  65. user := models.User{}
  66. err = json.Unmarshal(content, &user)
  67. if err != nil {
  68. logs.Error("GetUserinfo Unmarshal", err)
  69. c.setCode(-1).setMsg(fmt.Sprintf("登录失败(Unmarshal):%s", err.Error())).json()
  70. return
  71. }
  72. if len(user.Account) == 0 {
  73. c.setCode(-1).setMsg("登录失败,请确认userinfo接口返回了account字段").json()
  74. return
  75. }
  76. o := orm.NewOrm()
  77. err = o.Read(&user, "Account")
  78. if err != nil {
  79. _, err = o.Insert(&user)
  80. }
  81. user.Password = ""
  82. if err != nil {
  83. c.setCode(-1).setMsg(fmt.Sprintf("保存用户失败:%s", err.Error())).json()
  84. return
  85. }
  86. c.SetSession("user", user)
  87. c.setData(user).json()
  88. }