client.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. from urllib.parse import urlencode
  2. import requests
  3. from django.urls import reverse
  4. from django.utils.crypto import get_random_string
  5. from requests.exceptions import RequestException
  6. from . import exceptions
  7. SESSION_STATE = "oauth2_state"
  8. STATE_LENGTH = 40
  9. REQUESTS_TIMEOUT = 30
  10. def create_login_url(request):
  11. state = get_random_string(STATE_LENGTH)
  12. request.session[SESSION_STATE] = state
  13. quote = {
  14. "response_type": "code",
  15. "client_id": request.settings.oauth2_client_id,
  16. "redirect_uri": get_redirect_uri(request),
  17. "scope": request.settings.oauth2_scopes,
  18. "state": state,
  19. }
  20. return "%s?%s" % (request.settings.oauth2_login_url, urlencode(quote))
  21. def get_code_grant(request):
  22. session_state = request.session.pop(SESSION_STATE, None)
  23. if request.GET.get("error") == "access_denied":
  24. raise exceptions.OAuth2AccessDeniedError()
  25. if not session_state:
  26. raise exceptions.OAuth2StateNotSetError()
  27. provider_state = request.GET.get("state")
  28. if not provider_state:
  29. raise exceptions.OAuth2StateNotProvidedError()
  30. if provider_state != session_state:
  31. raise exceptions.OAuth2StateMismatchError()
  32. code_grant = request.GET.get("code")
  33. if not code_grant:
  34. raise exceptions.OAuth2CodeNotProvidedError()
  35. return code_grant
  36. def get_access_token(request, code_grant):
  37. token_url = request.settings.oauth2_token_url
  38. data = {
  39. "grant_type": "authorization_code",
  40. "client_id": request.settings.oauth2_client_id,
  41. "client_secret": request.settings.oauth2_client_secret,
  42. "redirect_uri": get_redirect_uri(request),
  43. "code": code_grant,
  44. }
  45. headers = get_headers_dict(request.settings.oauth2_token_extra_headers)
  46. try:
  47. r = requests.post(
  48. token_url,
  49. data=data,
  50. headers=headers,
  51. timeout=REQUESTS_TIMEOUT,
  52. )
  53. except RequestException:
  54. raise exceptions.OAuth2AccessTokenRequestError()
  55. if r.status_code != 200:
  56. raise exceptions.OAuth2AccessTokenResponseError()
  57. try:
  58. response_json = r.json()
  59. if not isinstance(response_json, dict):
  60. raise TypeError()
  61. except (ValueError, TypeError):
  62. raise exceptions.OAuth2AccessTokenJSONError()
  63. access_token = get_value_from_json(
  64. request.settings.oauth2_json_token_path,
  65. response_json,
  66. )
  67. if not access_token:
  68. raise exceptions.OAuth2AccessTokenNotProvidedError()
  69. return access_token
  70. JSON_MAPPING = {
  71. "id": "oauth2_json_id_path",
  72. "name": "oauth2_json_name_path",
  73. "email": "oauth2_json_email_path",
  74. "avatar": "oauth2_json_avatar_path",
  75. }
  76. def get_user_data(request, access_token):
  77. headers = get_headers_dict(request.settings.oauth2_user_extra_headers)
  78. user_url = request.settings.oauth2_user_url
  79. if request.settings.oauth2_user_token_location == "QUERY":
  80. user_url += "&" if "?" in user_url else "?"
  81. user_url += urlencode({request.settings.oauth2_user_token_name: access_token})
  82. elif request.settings.oauth2_user_token_location == "HEADER_BEARER":
  83. headers[request.settings.oauth2_user_token_name] = f"Bearer {access_token}"
  84. else:
  85. headers[request.settings.oauth2_user_token_name] = access_token
  86. try:
  87. if request.settings.oauth2_user_method == "GET":
  88. r = requests.get(user_url, headers=headers, timeout=REQUESTS_TIMEOUT)
  89. else:
  90. r = requests.post(user_url, headers=headers, timeout=REQUESTS_TIMEOUT)
  91. except RequestException:
  92. raise exceptions.OAuth2UserDataRequestError()
  93. if r.status_code != 200:
  94. raise exceptions.OAuth2UserDataResponseError()
  95. try:
  96. response_json = r.json()
  97. if not isinstance(response_json, dict):
  98. raise TypeError()
  99. except (ValueError, TypeError):
  100. raise exceptions.OAuth2UserDataJSONError()
  101. clean_data = {
  102. key: get_value_from_json(getattr(request.settings, setting), response_json)
  103. for key, setting in JSON_MAPPING.items()
  104. }
  105. return clean_data, response_json
  106. def get_redirect_uri(request):
  107. return request.build_absolute_uri(reverse("misago:oauth2-complete"))
  108. def get_headers_dict(headers_str):
  109. headers = {}
  110. if not headers_str:
  111. return headers
  112. for header in headers_str.splitlines():
  113. header = header.strip()
  114. if ":" not in header:
  115. continue
  116. header_name, header_value = [part.strip() for part in header.split(":", 1)]
  117. if header_name and header_value:
  118. headers[header_name] = header_value
  119. return headers
  120. def get_value_from_json(path, json):
  121. if not path:
  122. return None
  123. if "." not in path:
  124. return str(json.get(path, "")).strip() or None
  125. data = json
  126. for path_part in path.split("."):
  127. data = data.get(path_part)
  128. if not data:
  129. return None
  130. return data