client.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. from urllib.parse import urlencode
  2. import requests
  3. from django.urls import reverse
  4. from django.utils.crypto import get_random_string
  5. from requests.exceptions import RequestException
  6. from . import exceptions
  # Session key under which create_login_url stores the OAuth2 ``state`` value
  # that get_code_grant later pops and compares against the callback.
  SESSION_STATE = "oauth2_state"

  # Number of characters requested from get_random_string for ``state``.
  STATE_LENGTH = 40

  # Timeout passed to the ``requests`` calls in this module (requests
  # interprets it in seconds).
  REQUESTS_TIMEOUT = 30
  def create_login_url(request):
      """Build the provider authorization URL and remember the anti-CSRF state.

      A fresh random ``state`` of STATE_LENGTH characters is stored in the
      session under SESSION_STATE (get_code_grant pops it to validate the
      callback) and sent to the provider along with the other
      authorization-code request parameters.

      Args:
          request: Current request; ``request.session`` receives the state and
              ``request.settings`` supplies the OAuth2 configuration.

      Returns:
          ``oauth2_login_url`` with ``response_type``, ``client_id``,
          ``redirect_uri``, ``scope`` and ``state`` appended as a query string.
      """
      state = get_random_string(STATE_LENGTH)
      request.session[SESSION_STATE] = state

      params = {
          "response_type": "code",
          "client_id": request.settings.oauth2_client_id,
          "redirect_uri": get_redirect_uri(request),
          "scope": request.settings.oauth2_scopes,
          "state": state,
      }

      login_url = request.settings.oauth2_login_url
      # The configured URL may already carry a query string; appending a second
      # "?" would corrupt it, so join with "&" in that case (the token and user
      # URL builders in this module do the same).
      separator = "&" if "?" in login_url else "?"
      return f"{login_url}{separator}{urlencode(params)}"
  def get_code_grant(request):
      """Validate the provider's callback and return the authorization code.

      The state saved by create_login_url is removed from the session before
      any check runs, so it can be used for only one callback attempt.

      Raises:
          exceptions.OAuth2AccessDeniedError: callback has ``error=access_denied``.
          exceptions.OAuth2StateNotSetError: no state was stored in the session.
          exceptions.OAuth2StateNotProvidedError: callback has no ``state``.
          exceptions.OAuth2StateMismatchError: callback ``state`` differs from
              the stored one.
          exceptions.OAuth2CodeNotProvidedError: callback has no ``code``.
      """
      # Pop first: the stored state is single-use regardless of outcome.
      expected_state = request.session.pop(SESSION_STATE, None)
      query = request.GET

      if query.get("error") == "access_denied":
          raise exceptions.OAuth2AccessDeniedError()

      if not expected_state:
          raise exceptions.OAuth2StateNotSetError()

      returned_state = query.get("state")
      if not returned_state:
          raise exceptions.OAuth2StateNotProvidedError()
      if returned_state != expected_state:
          raise exceptions.OAuth2StateMismatchError()

      code = query.get("code")
      if not code:
          raise exceptions.OAuth2CodeNotProvidedError()
      return code
  def get_access_token(request, code_grant):
      """Exchange an authorization code for an access token.

      The token request carries ``grant_type``, ``client_id``,
      ``client_secret``, ``redirect_uri`` and ``code``. When
      ``oauth2_token_method`` is ``"GET"`` these are appended to
      ``oauth2_token_url`` as a query string; for any other method they are
      sent with ``requests.post`` as form data. Extra request headers are
      parsed from ``oauth2_token_extra_headers``.

      Args:
          request: Current request; ``request.settings`` supplies the OAuth2
              configuration.
          code_grant: The authorization code received from the provider.

      Returns:
          The value found in the JSON response at ``oauth2_json_token_path``.

      Raises:
          exceptions.OAuth2AccessTokenRequestError: the HTTP request failed.
          exceptions.OAuth2AccessTokenResponseError: status code was not 200.
          exceptions.OAuth2AccessTokenJSONError: body is not a JSON object.
          exceptions.OAuth2AccessTokenNotProvidedError: no token at the path.
      """
      settings = request.settings
      token_url = settings.oauth2_token_url
      payload = {
          "grant_type": "authorization_code",
          "client_id": settings.oauth2_client_id,
          "client_secret": settings.oauth2_client_secret,
          "redirect_uri": get_redirect_uri(request),
          "code": code_grant,
      }
      extra_headers = get_headers_dict(settings.oauth2_token_extra_headers)

      try:
          if settings.oauth2_token_method != "GET":
              response = requests.post(
                  token_url,
                  data=payload,
                  headers=extra_headers,
                  timeout=REQUESTS_TIMEOUT,
              )
          else:
              # Keep any query string already present on the configured URL.
              joiner = "&" if "?" in token_url else "?"
              response = requests.get(
                  token_url + joiner + urlencode(payload),
                  headers=extra_headers,
                  timeout=REQUESTS_TIMEOUT,
              )
      except RequestException:
          raise exceptions.OAuth2AccessTokenRequestError()

      if response.status_code != 200:
          raise exceptions.OAuth2AccessTokenResponseError()

      try:
          body = response.json()
          # A JSON array/scalar is as unusable as invalid JSON here.
          if not isinstance(body, dict):
              raise TypeError()
      except (ValueError, TypeError):
          raise exceptions.OAuth2AccessTokenJSONError()

      access_token = get_value_from_json(settings.oauth2_json_token_path, body)
      if not access_token:
          raise exceptions.OAuth2AccessTokenNotProvidedError()
      return access_token
  # Keys of the dict returned by get_user_data, each mapped to the name of the
  # ``request.settings`` attribute holding the JSON path for that value.
  JSON_MAPPING = dict(
      id="oauth2_json_id_path",
      name="oauth2_json_name_path",
      email="oauth2_json_email_path",
      avatar="oauth2_json_avatar_path",
  )
  def get_user_data(request, access_token):
      """Fetch the user's profile from the provider using ``access_token``.

      Where the token goes is controlled by ``oauth2_user_token_location``:
      ``"QUERY"`` appends ``<oauth2_user_token_name>=<token>`` to the URL,
      ``"HEADER_BEARER"`` sends header ``<oauth2_user_token_name>: Bearer
      <token>``, and any other value sends the raw token in that header.
      ``oauth2_user_method == "GET"`` selects GET; anything else selects POST.

      Returns:
          A dict with the keys of JSON_MAPPING, each value extracted from the
          JSON response via get_value_from_json (``None`` when not found).

      Raises:
          exceptions.OAuth2UserDataRequestError: the HTTP request failed.
          exceptions.OAuth2UserDataResponseError: status code was not 200.
          exceptions.OAuth2UserDataJSONError: body is not a JSON object.
      """
      settings = request.settings
      headers = get_headers_dict(settings.oauth2_user_extra_headers)
      user_url = settings.oauth2_user_url
      location = settings.oauth2_user_token_location
      token_name = settings.oauth2_user_token_name

      if location == "QUERY":
          joiner = "&" if "?" in user_url else "?"
          user_url = user_url + joiner + urlencode({token_name: access_token})
      elif location == "HEADER_BEARER":
          headers[token_name] = f"Bearer {access_token}"
      else:
          headers[token_name] = access_token

      send = requests.get if settings.oauth2_user_method == "GET" else requests.post
      try:
          response = send(user_url, headers=headers, timeout=REQUESTS_TIMEOUT)
      except RequestException:
          raise exceptions.OAuth2UserDataRequestError()

      if response.status_code != 200:
          raise exceptions.OAuth2UserDataResponseError()

      try:
          body = response.json()
          # A JSON array/scalar is as unusable as invalid JSON here.
          if not isinstance(body, dict):
              raise TypeError()
      except (ValueError, TypeError):
          raise exceptions.OAuth2UserDataJSONError()

      user_data = {}
      for key, setting_name in JSON_MAPPING.items():
          user_data[key] = get_value_from_json(getattr(settings, setting_name), body)
      return user_data
  def get_redirect_uri(request):
      """Return the absolute URL of the ``misago:oauth2-complete`` view."""
      complete_path = reverse("misago:oauth2-complete")
      return request.build_absolute_uri(complete_path)
  def get_headers_dict(headers_str):
      """Parse newline-separated ``Name: value`` lines into a headers dict.

      Only the first ``:`` on a line separates name from value, so values may
      themselves contain colons. Lines without a colon, or whose name or value
      is empty after stripping, are ignored. A repeated name keeps its last
      value. Empty or ``None`` input yields ``{}``.
      """
      parsed = {}
      for raw_line in (headers_str or "").splitlines():
          name, separator, value = raw_line.strip().partition(":")
          name = name.strip()
          value = value.strip()
          if separator and name and value:
              parsed[name] = value
      return parsed
  def get_value_from_json(path, json):
      """Look up a value in a decoded JSON object by key or dotted key path.

      Args:
          path: A single key (``"email"``) or a dot-separated chain of keys
              for nested objects (``"user.email"``). Empty/``None`` means
              "not configured".
          json: The decoded JSON object (a dict) to read from.

      Returns:
          For a single key: the value converted to ``str`` and stripped, or
          ``None`` when the key is missing, JSON ``null``, or blank.
          For a dotted path: ``None`` when any step is missing, falsy, or is
          not an object; otherwise the leaf value, with strings stripped
          (blank strings become ``None``) and other types returned unchanged.
      """
      if not path:
          return None

      if "." not in path:
          value = json.get(path)
          # JSON null means "absent"; str(None) would leak the text "None"
          # (e.g. as a user's avatar URL or e-mail).
          if value is None:
              return None
          return str(value).strip() or None

      data = json
      for path_part in path.split("."):
          # A string/list/number along the path can't be descended into;
          # report the path as unresolved instead of raising AttributeError.
          if not isinstance(data, dict):
              return None
          data = data.get(path_part)
          if not data:
              return None

      # Match the single-key branch: surrounding whitespace is not data.
      if isinstance(data, str):
          return data.strip() or None
      return data