detached #1
|
|
@ -21,8 +21,8 @@ yarn-error.log
|
||||||
/.nova
|
/.nova
|
||||||
/.vscode
|
/.vscode
|
||||||
/.zed
|
/.zed
|
||||||
/bin
|
/bot/bin
|
||||||
/include
|
/bot/include
|
||||||
/lib
|
/bot/lib
|
||||||
/lib64
|
/bot/lib64
|
||||||
pyvenv.cfg
|
pyvenv.cfg
|
||||||
|
|
|
||||||
|
|
@ -1,247 +0,0 @@
|
||||||
<#
|
|
||||||
.Synopsis
|
|
||||||
Activate a Python virtual environment for the current PowerShell session.
|
|
||||||
|
|
||||||
.Description
|
|
||||||
Pushes the python executable for a virtual environment to the front of the
|
|
||||||
$Env:PATH environment variable and sets the prompt to signify that you are
|
|
||||||
in a Python virtual environment. Makes use of the command line switches as
|
|
||||||
well as the `pyvenv.cfg` file values present in the virtual environment.
|
|
||||||
|
|
||||||
.Parameter VenvDir
|
|
||||||
Path to the directory that contains the virtual environment to activate. The
|
|
||||||
default value for this is the parent of the directory that the Activate.ps1
|
|
||||||
script is located within.
|
|
||||||
|
|
||||||
.Parameter Prompt
|
|
||||||
The prompt prefix to display when this virtual environment is activated. By
|
|
||||||
default, this prompt is the name of the virtual environment folder (VenvDir)
|
|
||||||
surrounded by parentheses and followed by a single space (ie. '(.venv) ').
|
|
||||||
|
|
||||||
.Example
|
|
||||||
Activate.ps1
|
|
||||||
Activates the Python virtual environment that contains the Activate.ps1 script.
|
|
||||||
|
|
||||||
.Example
|
|
||||||
Activate.ps1 -Verbose
|
|
||||||
Activates the Python virtual environment that contains the Activate.ps1 script,
|
|
||||||
and shows extra information about the activation as it executes.
|
|
||||||
|
|
||||||
.Example
|
|
||||||
Activate.ps1 -VenvDir C:\Users\MyUser\Common\.venv
|
|
||||||
Activates the Python virtual environment located in the specified location.
|
|
||||||
|
|
||||||
.Example
|
|
||||||
Activate.ps1 -Prompt "MyPython"
|
|
||||||
Activates the Python virtual environment that contains the Activate.ps1 script,
|
|
||||||
and prefixes the current prompt with the specified string (surrounded in
|
|
||||||
parentheses) while the virtual environment is active.
|
|
||||||
|
|
||||||
.Notes
|
|
||||||
On Windows, it may be required to enable this Activate.ps1 script by setting the
|
|
||||||
execution policy for the user. You can do this by issuing the following PowerShell
|
|
||||||
command:
|
|
||||||
|
|
||||||
PS C:\> Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser
|
|
||||||
|
|
||||||
For more information on Execution Policies:
|
|
||||||
https://go.microsoft.com/fwlink/?LinkID=135170
|
|
||||||
|
|
||||||
#>
|
|
||||||
Param(
|
|
||||||
[Parameter(Mandatory = $false)]
|
|
||||||
[String]
|
|
||||||
$VenvDir,
|
|
||||||
[Parameter(Mandatory = $false)]
|
|
||||||
[String]
|
|
||||||
$Prompt
|
|
||||||
)
|
|
||||||
|
|
||||||
<# Function declarations --------------------------------------------------- #>
|
|
||||||
|
|
||||||
<#
|
|
||||||
.Synopsis
|
|
||||||
Remove all shell session elements added by the Activate script, including the
|
|
||||||
addition of the virtual environment's Python executable from the beginning of
|
|
||||||
the PATH variable.
|
|
||||||
|
|
||||||
.Parameter NonDestructive
|
|
||||||
If present, do not remove this function from the global namespace for the
|
|
||||||
session.
|
|
||||||
|
|
||||||
#>
|
|
||||||
function global:deactivate ([switch]$NonDestructive) {
|
|
||||||
# Revert to original values
|
|
||||||
|
|
||||||
# The prior prompt:
|
|
||||||
if (Test-Path -Path Function:_OLD_VIRTUAL_PROMPT) {
|
|
||||||
Copy-Item -Path Function:_OLD_VIRTUAL_PROMPT -Destination Function:prompt
|
|
||||||
Remove-Item -Path Function:_OLD_VIRTUAL_PROMPT
|
|
||||||
}
|
|
||||||
|
|
||||||
# The prior PYTHONHOME:
|
|
||||||
if (Test-Path -Path Env:_OLD_VIRTUAL_PYTHONHOME) {
|
|
||||||
Copy-Item -Path Env:_OLD_VIRTUAL_PYTHONHOME -Destination Env:PYTHONHOME
|
|
||||||
Remove-Item -Path Env:_OLD_VIRTUAL_PYTHONHOME
|
|
||||||
}
|
|
||||||
|
|
||||||
# The prior PATH:
|
|
||||||
if (Test-Path -Path Env:_OLD_VIRTUAL_PATH) {
|
|
||||||
Copy-Item -Path Env:_OLD_VIRTUAL_PATH -Destination Env:PATH
|
|
||||||
Remove-Item -Path Env:_OLD_VIRTUAL_PATH
|
|
||||||
}
|
|
||||||
|
|
||||||
# Just remove the VIRTUAL_ENV altogether:
|
|
||||||
if (Test-Path -Path Env:VIRTUAL_ENV) {
|
|
||||||
Remove-Item -Path env:VIRTUAL_ENV
|
|
||||||
}
|
|
||||||
|
|
||||||
# Just remove VIRTUAL_ENV_PROMPT altogether.
|
|
||||||
if (Test-Path -Path Env:VIRTUAL_ENV_PROMPT) {
|
|
||||||
Remove-Item -Path env:VIRTUAL_ENV_PROMPT
|
|
||||||
}
|
|
||||||
|
|
||||||
# Just remove the _PYTHON_VENV_PROMPT_PREFIX altogether:
|
|
||||||
if (Get-Variable -Name "_PYTHON_VENV_PROMPT_PREFIX" -ErrorAction SilentlyContinue) {
|
|
||||||
Remove-Variable -Name _PYTHON_VENV_PROMPT_PREFIX -Scope Global -Force
|
|
||||||
}
|
|
||||||
|
|
||||||
# Leave deactivate function in the global namespace if requested:
|
|
||||||
if (-not $NonDestructive) {
|
|
||||||
Remove-Item -Path function:deactivate
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
<#
|
|
||||||
.Description
|
|
||||||
Get-PyVenvConfig parses the values from the pyvenv.cfg file located in the
|
|
||||||
given folder, and returns them in a map.
|
|
||||||
|
|
||||||
For each line in the pyvenv.cfg file, if that line can be parsed into exactly
|
|
||||||
two strings separated by `=` (with any amount of whitespace surrounding the =)
|
|
||||||
then it is considered a `key = value` line. The left hand string is the key,
|
|
||||||
the right hand is the value.
|
|
||||||
|
|
||||||
If the value starts with a `'` or a `"` then the first and last character is
|
|
||||||
stripped from the value before being captured.
|
|
||||||
|
|
||||||
.Parameter ConfigDir
|
|
||||||
Path to the directory that contains the `pyvenv.cfg` file.
|
|
||||||
#>
|
|
||||||
function Get-PyVenvConfig(
|
|
||||||
[String]
|
|
||||||
$ConfigDir
|
|
||||||
) {
|
|
||||||
Write-Verbose "Given ConfigDir=$ConfigDir, obtain values in pyvenv.cfg"
|
|
||||||
|
|
||||||
# Ensure the file exists, and issue a warning if it doesn't (but still allow the function to continue).
|
|
||||||
$pyvenvConfigPath = Join-Path -Resolve -Path $ConfigDir -ChildPath 'pyvenv.cfg' -ErrorAction Continue
|
|
||||||
|
|
||||||
# An empty map will be returned if no config file is found.
|
|
||||||
$pyvenvConfig = @{ }
|
|
||||||
|
|
||||||
if ($pyvenvConfigPath) {
|
|
||||||
|
|
||||||
Write-Verbose "File exists, parse `key = value` lines"
|
|
||||||
$pyvenvConfigContent = Get-Content -Path $pyvenvConfigPath
|
|
||||||
|
|
||||||
$pyvenvConfigContent | ForEach-Object {
|
|
||||||
$keyval = $PSItem -split "\s*=\s*", 2
|
|
||||||
if ($keyval[0] -and $keyval[1]) {
|
|
||||||
$val = $keyval[1]
|
|
||||||
|
|
||||||
# Remove extraneous quotations around a string value.
|
|
||||||
if ("'""".Contains($val.Substring(0, 1))) {
|
|
||||||
$val = $val.Substring(1, $val.Length - 2)
|
|
||||||
}
|
|
||||||
|
|
||||||
$pyvenvConfig[$keyval[0]] = $val
|
|
||||||
Write-Verbose "Adding Key: '$($keyval[0])'='$val'"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return $pyvenvConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
<# Begin Activate script --------------------------------------------------- #>
|
|
||||||
|
|
||||||
# Determine the containing directory of this script
|
|
||||||
$VenvExecPath = Split-Path -Parent $MyInvocation.MyCommand.Definition
|
|
||||||
$VenvExecDir = Get-Item -Path $VenvExecPath
|
|
||||||
|
|
||||||
Write-Verbose "Activation script is located in path: '$VenvExecPath'"
|
|
||||||
Write-Verbose "VenvExecDir Fullname: '$($VenvExecDir.FullName)"
|
|
||||||
Write-Verbose "VenvExecDir Name: '$($VenvExecDir.Name)"
|
|
||||||
|
|
||||||
# Set values required in priority: CmdLine, ConfigFile, Default
|
|
||||||
# First, get the location of the virtual environment, it might not be
|
|
||||||
# VenvExecDir if specified on the command line.
|
|
||||||
if ($VenvDir) {
|
|
||||||
Write-Verbose "VenvDir given as parameter, using '$VenvDir' to determine values"
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
Write-Verbose "VenvDir not given as a parameter, using parent directory name as VenvDir."
|
|
||||||
$VenvDir = $VenvExecDir.Parent.FullName.TrimEnd("\\/")
|
|
||||||
Write-Verbose "VenvDir=$VenvDir"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Next, read the `pyvenv.cfg` file to determine any required value such
|
|
||||||
# as `prompt`.
|
|
||||||
$pyvenvCfg = Get-PyVenvConfig -ConfigDir $VenvDir
|
|
||||||
|
|
||||||
# Next, set the prompt from the command line, or the config file, or
|
|
||||||
# just use the name of the virtual environment folder.
|
|
||||||
if ($Prompt) {
|
|
||||||
Write-Verbose "Prompt specified as argument, using '$Prompt'"
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
Write-Verbose "Prompt not specified as argument to script, checking pyvenv.cfg value"
|
|
||||||
if ($pyvenvCfg -and $pyvenvCfg['prompt']) {
|
|
||||||
Write-Verbose " Setting based on value in pyvenv.cfg='$($pyvenvCfg['prompt'])'"
|
|
||||||
$Prompt = $pyvenvCfg['prompt'];
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
Write-Verbose " Setting prompt based on parent's directory's name. (Is the directory name passed to venv module when creating the virtual environment)"
|
|
||||||
Write-Verbose " Got leaf-name of $VenvDir='$(Split-Path -Path $venvDir -Leaf)'"
|
|
||||||
$Prompt = Split-Path -Path $venvDir -Leaf
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Write-Verbose "Prompt = '$Prompt'"
|
|
||||||
Write-Verbose "VenvDir='$VenvDir'"
|
|
||||||
|
|
||||||
# Deactivate any currently active virtual environment, but leave the
|
|
||||||
# deactivate function in place.
|
|
||||||
deactivate -nondestructive
|
|
||||||
|
|
||||||
# Now set the environment variable VIRTUAL_ENV, used by many tools to determine
|
|
||||||
# that there is an activated venv.
|
|
||||||
$env:VIRTUAL_ENV = $VenvDir
|
|
||||||
|
|
||||||
if (-not $Env:VIRTUAL_ENV_DISABLE_PROMPT) {
|
|
||||||
|
|
||||||
Write-Verbose "Setting prompt to '$Prompt'"
|
|
||||||
|
|
||||||
# Set the prompt to include the env name
|
|
||||||
# Make sure _OLD_VIRTUAL_PROMPT is global
|
|
||||||
function global:_OLD_VIRTUAL_PROMPT { "" }
|
|
||||||
Copy-Item -Path function:prompt -Destination function:_OLD_VIRTUAL_PROMPT
|
|
||||||
New-Variable -Name _PYTHON_VENV_PROMPT_PREFIX -Description "Python virtual environment prompt prefix" -Scope Global -Option ReadOnly -Visibility Public -Value $Prompt
|
|
||||||
|
|
||||||
function global:prompt {
|
|
||||||
Write-Host -NoNewline -ForegroundColor Green "($_PYTHON_VENV_PROMPT_PREFIX) "
|
|
||||||
_OLD_VIRTUAL_PROMPT
|
|
||||||
}
|
|
||||||
$env:VIRTUAL_ENV_PROMPT = $Prompt
|
|
||||||
}
|
|
||||||
|
|
||||||
# Clear PYTHONHOME
|
|
||||||
if (Test-Path -Path Env:PYTHONHOME) {
|
|
||||||
Copy-Item -Path Env:PYTHONHOME -Destination Env:_OLD_VIRTUAL_PYTHONHOME
|
|
||||||
Remove-Item -Path Env:PYTHONHOME
|
|
||||||
}
|
|
||||||
|
|
||||||
# Add the venv to the PATH
|
|
||||||
Copy-Item -Path Env:PATH -Destination Env:_OLD_VIRTUAL_PATH
|
|
||||||
$Env:PATH = "$VenvExecDir$([System.IO.Path]::PathSeparator)$Env:PATH"
|
|
||||||
|
|
@ -1,69 +0,0 @@
|
||||||
# This file must be used with "source bin/activate" *from bash*
|
|
||||||
# you cannot run it directly
|
|
||||||
|
|
||||||
deactivate () {
|
|
||||||
# reset old environment variables
|
|
||||||
if [ -n "${_OLD_VIRTUAL_PATH:-}" ] ; then
|
|
||||||
PATH="${_OLD_VIRTUAL_PATH:-}"
|
|
||||||
export PATH
|
|
||||||
unset _OLD_VIRTUAL_PATH
|
|
||||||
fi
|
|
||||||
if [ -n "${_OLD_VIRTUAL_PYTHONHOME:-}" ] ; then
|
|
||||||
PYTHONHOME="${_OLD_VIRTUAL_PYTHONHOME:-}"
|
|
||||||
export PYTHONHOME
|
|
||||||
unset _OLD_VIRTUAL_PYTHONHOME
|
|
||||||
fi
|
|
||||||
|
|
||||||
# This should detect bash and zsh, which have a hash command that must
|
|
||||||
# be called to get it to forget past commands. Without forgetting
|
|
||||||
# past commands the $PATH changes we made may not be respected
|
|
||||||
if [ -n "${BASH:-}" -o -n "${ZSH_VERSION:-}" ] ; then
|
|
||||||
hash -r 2> /dev/null
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [ -n "${_OLD_VIRTUAL_PS1:-}" ] ; then
|
|
||||||
PS1="${_OLD_VIRTUAL_PS1:-}"
|
|
||||||
export PS1
|
|
||||||
unset _OLD_VIRTUAL_PS1
|
|
||||||
fi
|
|
||||||
|
|
||||||
unset VIRTUAL_ENV
|
|
||||||
unset VIRTUAL_ENV_PROMPT
|
|
||||||
if [ ! "${1:-}" = "nondestructive" ] ; then
|
|
||||||
# Self destruct!
|
|
||||||
unset -f deactivate
|
|
||||||
fi
|
|
||||||
}
|
|
||||||
|
|
||||||
# unset irrelevant variables
|
|
||||||
deactivate nondestructive
|
|
||||||
|
|
||||||
VIRTUAL_ENV="/var/www/html/su-secret-santa/bot"
|
|
||||||
export VIRTUAL_ENV
|
|
||||||
|
|
||||||
_OLD_VIRTUAL_PATH="$PATH"
|
|
||||||
PATH="$VIRTUAL_ENV/bin:$PATH"
|
|
||||||
export PATH
|
|
||||||
|
|
||||||
# unset PYTHONHOME if set
|
|
||||||
# this will fail if PYTHONHOME is set to the empty string (which is bad anyway)
|
|
||||||
# could use `if (set -u; : $PYTHONHOME) ;` in bash
|
|
||||||
if [ -n "${PYTHONHOME:-}" ] ; then
|
|
||||||
_OLD_VIRTUAL_PYTHONHOME="${PYTHONHOME:-}"
|
|
||||||
unset PYTHONHOME
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [ -z "${VIRTUAL_ENV_DISABLE_PROMPT:-}" ] ; then
|
|
||||||
_OLD_VIRTUAL_PS1="${PS1:-}"
|
|
||||||
PS1="(bot) ${PS1:-}"
|
|
||||||
export PS1
|
|
||||||
VIRTUAL_ENV_PROMPT="(bot) "
|
|
||||||
export VIRTUAL_ENV_PROMPT
|
|
||||||
fi
|
|
||||||
|
|
||||||
# This should detect bash and zsh, which have a hash command that must
|
|
||||||
# be called to get it to forget past commands. Without forgetting
|
|
||||||
# past commands the $PATH changes we made may not be respected
|
|
||||||
if [ -n "${BASH:-}" -o -n "${ZSH_VERSION:-}" ] ; then
|
|
||||||
hash -r 2> /dev/null
|
|
||||||
fi
|
|
||||||
|
|
@ -1,26 +0,0 @@
|
||||||
# This file must be used with "source bin/activate.csh" *from csh*.
|
|
||||||
# You cannot run it directly.
|
|
||||||
# Created by Davide Di Blasi <davidedb@gmail.com>.
|
|
||||||
# Ported to Python 3.3 venv by Andrew Svetlov <andrew.svetlov@gmail.com>
|
|
||||||
|
|
||||||
alias deactivate 'test $?_OLD_VIRTUAL_PATH != 0 && setenv PATH "$_OLD_VIRTUAL_PATH" && unset _OLD_VIRTUAL_PATH; rehash; test $?_OLD_VIRTUAL_PROMPT != 0 && set prompt="$_OLD_VIRTUAL_PROMPT" && unset _OLD_VIRTUAL_PROMPT; unsetenv VIRTUAL_ENV; unsetenv VIRTUAL_ENV_PROMPT; test "\!:*" != "nondestructive" && unalias deactivate'
|
|
||||||
|
|
||||||
# Unset irrelevant variables.
|
|
||||||
deactivate nondestructive
|
|
||||||
|
|
||||||
setenv VIRTUAL_ENV "/var/www/html/su-secret-santa/bot"
|
|
||||||
|
|
||||||
set _OLD_VIRTUAL_PATH="$PATH"
|
|
||||||
setenv PATH "$VIRTUAL_ENV/bin:$PATH"
|
|
||||||
|
|
||||||
|
|
||||||
set _OLD_VIRTUAL_PROMPT="$prompt"
|
|
||||||
|
|
||||||
if (! "$?VIRTUAL_ENV_DISABLE_PROMPT") then
|
|
||||||
set prompt = "(bot) $prompt"
|
|
||||||
setenv VIRTUAL_ENV_PROMPT "(bot) "
|
|
||||||
endif
|
|
||||||
|
|
||||||
alias pydoc python -m pydoc
|
|
||||||
|
|
||||||
rehash
|
|
||||||
|
|
@ -1,69 +0,0 @@
|
||||||
# This file must be used with "source <venv>/bin/activate.fish" *from fish*
|
|
||||||
# (https://fishshell.com/); you cannot run it directly.
|
|
||||||
|
|
||||||
function deactivate -d "Exit virtual environment and return to normal shell environment"
|
|
||||||
# reset old environment variables
|
|
||||||
if test -n "$_OLD_VIRTUAL_PATH"
|
|
||||||
set -gx PATH $_OLD_VIRTUAL_PATH
|
|
||||||
set -e _OLD_VIRTUAL_PATH
|
|
||||||
end
|
|
||||||
if test -n "$_OLD_VIRTUAL_PYTHONHOME"
|
|
||||||
set -gx PYTHONHOME $_OLD_VIRTUAL_PYTHONHOME
|
|
||||||
set -e _OLD_VIRTUAL_PYTHONHOME
|
|
||||||
end
|
|
||||||
|
|
||||||
if test -n "$_OLD_FISH_PROMPT_OVERRIDE"
|
|
||||||
set -e _OLD_FISH_PROMPT_OVERRIDE
|
|
||||||
# prevents error when using nested fish instances (Issue #93858)
|
|
||||||
if functions -q _old_fish_prompt
|
|
||||||
functions -e fish_prompt
|
|
||||||
functions -c _old_fish_prompt fish_prompt
|
|
||||||
functions -e _old_fish_prompt
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
set -e VIRTUAL_ENV
|
|
||||||
set -e VIRTUAL_ENV_PROMPT
|
|
||||||
if test "$argv[1]" != "nondestructive"
|
|
||||||
# Self-destruct!
|
|
||||||
functions -e deactivate
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
# Unset irrelevant variables.
|
|
||||||
deactivate nondestructive
|
|
||||||
|
|
||||||
set -gx VIRTUAL_ENV "/var/www/html/su-secret-santa/bot"
|
|
||||||
|
|
||||||
set -gx _OLD_VIRTUAL_PATH $PATH
|
|
||||||
set -gx PATH "$VIRTUAL_ENV/bin" $PATH
|
|
||||||
|
|
||||||
# Unset PYTHONHOME if set.
|
|
||||||
if set -q PYTHONHOME
|
|
||||||
set -gx _OLD_VIRTUAL_PYTHONHOME $PYTHONHOME
|
|
||||||
set -e PYTHONHOME
|
|
||||||
end
|
|
||||||
|
|
||||||
if test -z "$VIRTUAL_ENV_DISABLE_PROMPT"
|
|
||||||
# fish uses a function instead of an env var to generate the prompt.
|
|
||||||
|
|
||||||
# Save the current fish_prompt function as the function _old_fish_prompt.
|
|
||||||
functions -c fish_prompt _old_fish_prompt
|
|
||||||
|
|
||||||
# With the original prompt function renamed, we can override with our own.
|
|
||||||
function fish_prompt
|
|
||||||
# Save the return status of the last command.
|
|
||||||
set -l old_status $status
|
|
||||||
|
|
||||||
# Output the venv prompt; color taken from the blue of the Python logo.
|
|
||||||
printf "%s%s%s" (set_color 4B8BBE) "(bot) " (set_color normal)
|
|
||||||
|
|
||||||
# Restore the return status of the previous command.
|
|
||||||
echo "exit $old_status" | .
|
|
||||||
# Output the original/"old" prompt.
|
|
||||||
_old_fish_prompt
|
|
||||||
end
|
|
||||||
|
|
||||||
set -gx _OLD_FISH_PROMPT_OVERRIDE "$VIRTUAL_ENV"
|
|
||||||
set -gx VIRTUAL_ENV_PROMPT "(bot) "
|
|
||||||
end
|
|
||||||
|
|
@ -1,8 +0,0 @@
|
||||||
#!/var/www/html/su-secret-santa/bot/bin/python3.11
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
import re
|
|
||||||
import sys
|
|
||||||
from dotenv.__main__ import cli
|
|
||||||
if __name__ == '__main__':
|
|
||||||
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
|
|
||||||
sys.exit(cli())
|
|
||||||
|
|
@ -1,8 +0,0 @@
|
||||||
#!/var/www/html/su-secret-santa/bot/bin/python3.11
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
import re
|
|
||||||
import sys
|
|
||||||
from pip._internal.cli.main import main
|
|
||||||
if __name__ == '__main__':
|
|
||||||
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
|
|
||||||
sys.exit(main())
|
|
||||||
|
|
@ -1,8 +0,0 @@
|
||||||
#!/var/www/html/su-secret-santa/bot/bin/python3.11
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
import re
|
|
||||||
import sys
|
|
||||||
from pip._internal.cli.main import main
|
|
||||||
if __name__ == '__main__':
|
|
||||||
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
|
|
||||||
sys.exit(main())
|
|
||||||
|
|
@ -1,8 +0,0 @@
|
||||||
#!/var/www/html/su-secret-santa/bot/bin/python3.11
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
import re
|
|
||||||
import sys
|
|
||||||
from pip._internal.cli.main import main
|
|
||||||
if __name__ == '__main__':
|
|
||||||
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
|
|
||||||
sys.exit(main())
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
python3.11
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
python3.11
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
/usr/bin/python3.11
|
|
||||||
|
|
@ -1,222 +0,0 @@
|
||||||
# don't import any costly modules
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
|
|
||||||
|
|
||||||
is_pypy = '__pypy__' in sys.builtin_module_names
|
|
||||||
|
|
||||||
|
|
||||||
def warn_distutils_present():
|
|
||||||
if 'distutils' not in sys.modules:
|
|
||||||
return
|
|
||||||
if is_pypy and sys.version_info < (3, 7):
|
|
||||||
# PyPy for 3.6 unconditionally imports distutils, so bypass the warning
|
|
||||||
# https://foss.heptapod.net/pypy/pypy/-/blob/be829135bc0d758997b3566062999ee8b23872b4/lib-python/3/site.py#L250
|
|
||||||
return
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
warnings.warn(
|
|
||||||
"Distutils was imported before Setuptools, but importing Setuptools "
|
|
||||||
"also replaces the `distutils` module in `sys.modules`. This may lead "
|
|
||||||
"to undesirable behaviors or errors. To avoid these issues, avoid "
|
|
||||||
"using distutils directly, ensure that setuptools is installed in the "
|
|
||||||
"traditional way (e.g. not an editable install), and/or make sure "
|
|
||||||
"that setuptools is always imported before distutils."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def clear_distutils():
|
|
||||||
if 'distutils' not in sys.modules:
|
|
||||||
return
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
warnings.warn("Setuptools is replacing distutils.")
|
|
||||||
mods = [
|
|
||||||
name
|
|
||||||
for name in sys.modules
|
|
||||||
if name == "distutils" or name.startswith("distutils.")
|
|
||||||
]
|
|
||||||
for name in mods:
|
|
||||||
del sys.modules[name]
|
|
||||||
|
|
||||||
|
|
||||||
def enabled():
|
|
||||||
"""
|
|
||||||
Allow selection of distutils by environment variable.
|
|
||||||
"""
|
|
||||||
which = os.environ.get('SETUPTOOLS_USE_DISTUTILS', 'local')
|
|
||||||
return which == 'local'
|
|
||||||
|
|
||||||
|
|
||||||
def ensure_local_distutils():
|
|
||||||
import importlib
|
|
||||||
|
|
||||||
clear_distutils()
|
|
||||||
|
|
||||||
# With the DistutilsMetaFinder in place,
|
|
||||||
# perform an import to cause distutils to be
|
|
||||||
# loaded from setuptools._distutils. Ref #2906.
|
|
||||||
with shim():
|
|
||||||
importlib.import_module('distutils')
|
|
||||||
|
|
||||||
# check that submodules load as expected
|
|
||||||
core = importlib.import_module('distutils.core')
|
|
||||||
assert '_distutils' in core.__file__, core.__file__
|
|
||||||
assert 'setuptools._distutils.log' not in sys.modules
|
|
||||||
|
|
||||||
|
|
||||||
def do_override():
|
|
||||||
"""
|
|
||||||
Ensure that the local copy of distutils is preferred over stdlib.
|
|
||||||
|
|
||||||
See https://github.com/pypa/setuptools/issues/417#issuecomment-392298401
|
|
||||||
for more motivation.
|
|
||||||
"""
|
|
||||||
if enabled():
|
|
||||||
warn_distutils_present()
|
|
||||||
ensure_local_distutils()
|
|
||||||
|
|
||||||
|
|
||||||
class _TrivialRe:
|
|
||||||
def __init__(self, *patterns):
|
|
||||||
self._patterns = patterns
|
|
||||||
|
|
||||||
def match(self, string):
|
|
||||||
return all(pat in string for pat in self._patterns)
|
|
||||||
|
|
||||||
|
|
||||||
class DistutilsMetaFinder:
|
|
||||||
def find_spec(self, fullname, path, target=None):
|
|
||||||
# optimization: only consider top level modules and those
|
|
||||||
# found in the CPython test suite.
|
|
||||||
if path is not None and not fullname.startswith('test.'):
|
|
||||||
return
|
|
||||||
|
|
||||||
method_name = 'spec_for_{fullname}'.format(**locals())
|
|
||||||
method = getattr(self, method_name, lambda: None)
|
|
||||||
return method()
|
|
||||||
|
|
||||||
def spec_for_distutils(self):
|
|
||||||
if self.is_cpython():
|
|
||||||
return
|
|
||||||
|
|
||||||
import importlib
|
|
||||||
import importlib.abc
|
|
||||||
import importlib.util
|
|
||||||
|
|
||||||
try:
|
|
||||||
mod = importlib.import_module('setuptools._distutils')
|
|
||||||
except Exception:
|
|
||||||
# There are a couple of cases where setuptools._distutils
|
|
||||||
# may not be present:
|
|
||||||
# - An older Setuptools without a local distutils is
|
|
||||||
# taking precedence. Ref #2957.
|
|
||||||
# - Path manipulation during sitecustomize removes
|
|
||||||
# setuptools from the path but only after the hook
|
|
||||||
# has been loaded. Ref #2980.
|
|
||||||
# In either case, fall back to stdlib behavior.
|
|
||||||
return
|
|
||||||
|
|
||||||
class DistutilsLoader(importlib.abc.Loader):
|
|
||||||
def create_module(self, spec):
|
|
||||||
mod.__name__ = 'distutils'
|
|
||||||
return mod
|
|
||||||
|
|
||||||
def exec_module(self, module):
|
|
||||||
pass
|
|
||||||
|
|
||||||
return importlib.util.spec_from_loader(
|
|
||||||
'distutils', DistutilsLoader(), origin=mod.__file__
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def is_cpython():
|
|
||||||
"""
|
|
||||||
Suppress supplying distutils for CPython (build and tests).
|
|
||||||
Ref #2965 and #3007.
|
|
||||||
"""
|
|
||||||
return os.path.isfile('pybuilddir.txt')
|
|
||||||
|
|
||||||
def spec_for_pip(self):
|
|
||||||
"""
|
|
||||||
Ensure stdlib distutils when running under pip.
|
|
||||||
See pypa/pip#8761 for rationale.
|
|
||||||
"""
|
|
||||||
if self.pip_imported_during_build():
|
|
||||||
return
|
|
||||||
clear_distutils()
|
|
||||||
self.spec_for_distutils = lambda: None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def pip_imported_during_build(cls):
|
|
||||||
"""
|
|
||||||
Detect if pip is being imported in a build script. Ref #2355.
|
|
||||||
"""
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
return any(
|
|
||||||
cls.frame_file_is_setup(frame) for frame, line in traceback.walk_stack(None)
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def frame_file_is_setup(frame):
|
|
||||||
"""
|
|
||||||
Return True if the indicated frame suggests a setup.py file.
|
|
||||||
"""
|
|
||||||
# some frames may not have __file__ (#2940)
|
|
||||||
return frame.f_globals.get('__file__', '').endswith('setup.py')
|
|
||||||
|
|
||||||
def spec_for_sensitive_tests(self):
|
|
||||||
"""
|
|
||||||
Ensure stdlib distutils when running select tests under CPython.
|
|
||||||
|
|
||||||
python/cpython#91169
|
|
||||||
"""
|
|
||||||
clear_distutils()
|
|
||||||
self.spec_for_distutils = lambda: None
|
|
||||||
|
|
||||||
sensitive_tests = (
|
|
||||||
[
|
|
||||||
'test.test_distutils',
|
|
||||||
'test.test_peg_generator',
|
|
||||||
'test.test_importlib',
|
|
||||||
]
|
|
||||||
if sys.version_info < (3, 10)
|
|
||||||
else [
|
|
||||||
'test.test_distutils',
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
for name in DistutilsMetaFinder.sensitive_tests:
|
|
||||||
setattr(
|
|
||||||
DistutilsMetaFinder,
|
|
||||||
f'spec_for_{name}',
|
|
||||||
DistutilsMetaFinder.spec_for_sensitive_tests,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
DISTUTILS_FINDER = DistutilsMetaFinder()
|
|
||||||
|
|
||||||
|
|
||||||
def add_shim():
|
|
||||||
DISTUTILS_FINDER in sys.meta_path or insert_shim()
|
|
||||||
|
|
||||||
|
|
||||||
class shim:
|
|
||||||
def __enter__(self):
|
|
||||||
insert_shim()
|
|
||||||
|
|
||||||
def __exit__(self, exc, value, tb):
|
|
||||||
remove_shim()
|
|
||||||
|
|
||||||
|
|
||||||
def insert_shim():
|
|
||||||
sys.meta_path.insert(0, DISTUTILS_FINDER)
|
|
||||||
|
|
||||||
|
|
||||||
def remove_shim():
|
|
||||||
try:
|
|
||||||
sys.meta_path.remove(DISTUTILS_FINDER)
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
__import__('_distutils_hack').do_override()
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
pip
|
|
||||||
|
|
@ -1,279 +0,0 @@
|
||||||
A. HISTORY OF THE SOFTWARE
|
|
||||||
==========================
|
|
||||||
|
|
||||||
Python was created in the early 1990s by Guido van Rossum at Stichting
|
|
||||||
Mathematisch Centrum (CWI, see https://www.cwi.nl) in the Netherlands
|
|
||||||
as a successor of a language called ABC. Guido remains Python's
|
|
||||||
principal author, although it includes many contributions from others.
|
|
||||||
|
|
||||||
In 1995, Guido continued his work on Python at the Corporation for
|
|
||||||
National Research Initiatives (CNRI, see https://www.cnri.reston.va.us)
|
|
||||||
in Reston, Virginia where he released several versions of the
|
|
||||||
software.
|
|
||||||
|
|
||||||
In May 2000, Guido and the Python core development team moved to
|
|
||||||
BeOpen.com to form the BeOpen PythonLabs team. In October of the same
|
|
||||||
year, the PythonLabs team moved to Digital Creations, which became
|
|
||||||
Zope Corporation. In 2001, the Python Software Foundation (PSF, see
|
|
||||||
https://www.python.org/psf/) was formed, a non-profit organization
|
|
||||||
created specifically to own Python-related Intellectual Property.
|
|
||||||
Zope Corporation was a sponsoring member of the PSF.
|
|
||||||
|
|
||||||
All Python releases are Open Source (see https://opensource.org for
|
|
||||||
the Open Source Definition). Historically, most, but not all, Python
|
|
||||||
releases have also been GPL-compatible; the table below summarizes
|
|
||||||
the various releases.
|
|
||||||
|
|
||||||
Release Derived Year Owner GPL-
|
|
||||||
from compatible? (1)
|
|
||||||
|
|
||||||
0.9.0 thru 1.2 1991-1995 CWI yes
|
|
||||||
1.3 thru 1.5.2 1.2 1995-1999 CNRI yes
|
|
||||||
1.6 1.5.2 2000 CNRI no
|
|
||||||
2.0 1.6 2000 BeOpen.com no
|
|
||||||
1.6.1 1.6 2001 CNRI yes (2)
|
|
||||||
2.1 2.0+1.6.1 2001 PSF no
|
|
||||||
2.0.1 2.0+1.6.1 2001 PSF yes
|
|
||||||
2.1.1 2.1+2.0.1 2001 PSF yes
|
|
||||||
2.1.2 2.1.1 2002 PSF yes
|
|
||||||
2.1.3 2.1.2 2002 PSF yes
|
|
||||||
2.2 and above 2.1.1 2001-now PSF yes
|
|
||||||
|
|
||||||
Footnotes:
|
|
||||||
|
|
||||||
(1) GPL-compatible doesn't mean that we're distributing Python under
|
|
||||||
the GPL. All Python licenses, unlike the GPL, let you distribute
|
|
||||||
a modified version without making your changes open source. The
|
|
||||||
GPL-compatible licenses make it possible to combine Python with
|
|
||||||
other software that is released under the GPL; the others don't.
|
|
||||||
|
|
||||||
(2) According to Richard Stallman, 1.6.1 is not GPL-compatible,
|
|
||||||
because its license has a choice of law clause. According to
|
|
||||||
CNRI, however, Stallman's lawyer has told CNRI's lawyer that 1.6.1
|
|
||||||
is "not incompatible" with the GPL.
|
|
||||||
|
|
||||||
Thanks to the many outside volunteers who have worked under Guido's
|
|
||||||
direction to make these releases possible.
|
|
||||||
|
|
||||||
|
|
||||||
B. TERMS AND CONDITIONS FOR ACCESSING OR OTHERWISE USING PYTHON
|
|
||||||
===============================================================
|
|
||||||
|
|
||||||
Python software and documentation are licensed under the
|
|
||||||
Python Software Foundation License Version 2.
|
|
||||||
|
|
||||||
Starting with Python 3.8.6, examples, recipes, and other code in
|
|
||||||
the documentation are dual licensed under the PSF License Version 2
|
|
||||||
and the Zero-Clause BSD license.
|
|
||||||
|
|
||||||
Some software incorporated into Python is under different licenses.
|
|
||||||
The licenses are listed with code falling under that license.
|
|
||||||
|
|
||||||
|
|
||||||
PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2
|
|
||||||
--------------------------------------------
|
|
||||||
|
|
||||||
1. This LICENSE AGREEMENT is between the Python Software Foundation
|
|
||||||
("PSF"), and the Individual or Organization ("Licensee") accessing and
|
|
||||||
otherwise using this software ("Python") in source or binary form and
|
|
||||||
its associated documentation.
|
|
||||||
|
|
||||||
2. Subject to the terms and conditions of this License Agreement, PSF hereby
|
|
||||||
grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce,
|
|
||||||
analyze, test, perform and/or display publicly, prepare derivative works,
|
|
||||||
distribute, and otherwise use Python alone or in any derivative version,
|
|
||||||
provided, however, that PSF's License Agreement and PSF's notice of copyright,
|
|
||||||
i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010,
|
|
||||||
2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023 Python Software Foundation;
|
|
||||||
All Rights Reserved" are retained in Python alone or in any derivative version
|
|
||||||
prepared by Licensee.
|
|
||||||
|
|
||||||
3. In the event Licensee prepares a derivative work that is based on
|
|
||||||
or incorporates Python or any part thereof, and wants to make
|
|
||||||
the derivative work available to others as provided herein, then
|
|
||||||
Licensee hereby agrees to include in any such work a brief summary of
|
|
||||||
the changes made to Python.
|
|
||||||
|
|
||||||
4. PSF is making Python available to Licensee on an "AS IS"
|
|
||||||
basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR
|
|
||||||
IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND
|
|
||||||
DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS
|
|
||||||
FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT
|
|
||||||
INFRINGE ANY THIRD PARTY RIGHTS.
|
|
||||||
|
|
||||||
5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON
|
|
||||||
FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS
|
|
||||||
A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON,
|
|
||||||
OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF.
|
|
||||||
|
|
||||||
6. This License Agreement will automatically terminate upon a material
|
|
||||||
breach of its terms and conditions.
|
|
||||||
|
|
||||||
7. Nothing in this License Agreement shall be deemed to create any
|
|
||||||
relationship of agency, partnership, or joint venture between PSF and
|
|
||||||
Licensee. This License Agreement does not grant permission to use PSF
|
|
||||||
trademarks or trade name in a trademark sense to endorse or promote
|
|
||||||
products or services of Licensee, or any third party.
|
|
||||||
|
|
||||||
8. By copying, installing or otherwise using Python, Licensee
|
|
||||||
agrees to be bound by the terms and conditions of this License
|
|
||||||
Agreement.
|
|
||||||
|
|
||||||
|
|
||||||
BEOPEN.COM LICENSE AGREEMENT FOR PYTHON 2.0
|
|
||||||
-------------------------------------------
|
|
||||||
|
|
||||||
BEOPEN PYTHON OPEN SOURCE LICENSE AGREEMENT VERSION 1
|
|
||||||
|
|
||||||
1. This LICENSE AGREEMENT is between BeOpen.com ("BeOpen"), having an
|
|
||||||
office at 160 Saratoga Avenue, Santa Clara, CA 95051, and the
|
|
||||||
Individual or Organization ("Licensee") accessing and otherwise using
|
|
||||||
this software in source or binary form and its associated
|
|
||||||
documentation ("the Software").
|
|
||||||
|
|
||||||
2. Subject to the terms and conditions of this BeOpen Python License
|
|
||||||
Agreement, BeOpen hereby grants Licensee a non-exclusive,
|
|
||||||
royalty-free, world-wide license to reproduce, analyze, test, perform
|
|
||||||
and/or display publicly, prepare derivative works, distribute, and
|
|
||||||
otherwise use the Software alone or in any derivative version,
|
|
||||||
provided, however, that the BeOpen Python License is retained in the
|
|
||||||
Software, alone or in any derivative version prepared by Licensee.
|
|
||||||
|
|
||||||
3. BeOpen is making the Software available to Licensee on an "AS IS"
|
|
||||||
basis. BEOPEN MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR
|
|
||||||
IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, BEOPEN MAKES NO AND
|
|
||||||
DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS
|
|
||||||
FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF THE SOFTWARE WILL NOT
|
|
||||||
INFRINGE ANY THIRD PARTY RIGHTS.
|
|
||||||
|
|
||||||
4. BEOPEN SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF THE
|
|
||||||
SOFTWARE FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS
|
|
||||||
AS A RESULT OF USING, MODIFYING OR DISTRIBUTING THE SOFTWARE, OR ANY
|
|
||||||
DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF.
|
|
||||||
|
|
||||||
5. This License Agreement will automatically terminate upon a material
|
|
||||||
breach of its terms and conditions.
|
|
||||||
|
|
||||||
6. This License Agreement shall be governed by and interpreted in all
|
|
||||||
respects by the law of the State of California, excluding conflict of
|
|
||||||
law provisions. Nothing in this License Agreement shall be deemed to
|
|
||||||
create any relationship of agency, partnership, or joint venture
|
|
||||||
between BeOpen and Licensee. This License Agreement does not grant
|
|
||||||
permission to use BeOpen trademarks or trade names in a trademark
|
|
||||||
sense to endorse or promote products or services of Licensee, or any
|
|
||||||
third party. As an exception, the "BeOpen Python" logos available at
|
|
||||||
http://www.pythonlabs.com/logos.html may be used according to the
|
|
||||||
permissions granted on that web page.
|
|
||||||
|
|
||||||
7. By copying, installing or otherwise using the software, Licensee
|
|
||||||
agrees to be bound by the terms and conditions of this License
|
|
||||||
Agreement.
|
|
||||||
|
|
||||||
|
|
||||||
CNRI LICENSE AGREEMENT FOR PYTHON 1.6.1
|
|
||||||
---------------------------------------
|
|
||||||
|
|
||||||
1. This LICENSE AGREEMENT is between the Corporation for National
|
|
||||||
Research Initiatives, having an office at 1895 Preston White Drive,
|
|
||||||
Reston, VA 20191 ("CNRI"), and the Individual or Organization
|
|
||||||
("Licensee") accessing and otherwise using Python 1.6.1 software in
|
|
||||||
source or binary form and its associated documentation.
|
|
||||||
|
|
||||||
2. Subject to the terms and conditions of this License Agreement, CNRI
|
|
||||||
hereby grants Licensee a nonexclusive, royalty-free, world-wide
|
|
||||||
license to reproduce, analyze, test, perform and/or display publicly,
|
|
||||||
prepare derivative works, distribute, and otherwise use Python 1.6.1
|
|
||||||
alone or in any derivative version, provided, however, that CNRI's
|
|
||||||
License Agreement and CNRI's notice of copyright, i.e., "Copyright (c)
|
|
||||||
1995-2001 Corporation for National Research Initiatives; All Rights
|
|
||||||
Reserved" are retained in Python 1.6.1 alone or in any derivative
|
|
||||||
version prepared by Licensee. Alternately, in lieu of CNRI's License
|
|
||||||
Agreement, Licensee may substitute the following text (omitting the
|
|
||||||
quotes): "Python 1.6.1 is made available subject to the terms and
|
|
||||||
conditions in CNRI's License Agreement. This Agreement together with
|
|
||||||
Python 1.6.1 may be located on the internet using the following
|
|
||||||
unique, persistent identifier (known as a handle): 1895.22/1013. This
|
|
||||||
Agreement may also be obtained from a proxy server on the internet
|
|
||||||
using the following URL: http://hdl.handle.net/1895.22/1013".
|
|
||||||
|
|
||||||
3. In the event Licensee prepares a derivative work that is based on
|
|
||||||
or incorporates Python 1.6.1 or any part thereof, and wants to make
|
|
||||||
the derivative work available to others as provided herein, then
|
|
||||||
Licensee hereby agrees to include in any such work a brief summary of
|
|
||||||
the changes made to Python 1.6.1.
|
|
||||||
|
|
||||||
4. CNRI is making Python 1.6.1 available to Licensee on an "AS IS"
|
|
||||||
basis. CNRI MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR
|
|
||||||
IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, CNRI MAKES NO AND
|
|
||||||
DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS
|
|
||||||
FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON 1.6.1 WILL NOT
|
|
||||||
INFRINGE ANY THIRD PARTY RIGHTS.
|
|
||||||
|
|
||||||
5. CNRI SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON
|
|
||||||
1.6.1 FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS
|
|
||||||
A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON 1.6.1,
|
|
||||||
OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF.
|
|
||||||
|
|
||||||
6. This License Agreement will automatically terminate upon a material
|
|
||||||
breach of its terms and conditions.
|
|
||||||
|
|
||||||
7. This License Agreement shall be governed by the federal
|
|
||||||
intellectual property law of the United States, including without
|
|
||||||
limitation the federal copyright law, and, to the extent such
|
|
||||||
U.S. federal law does not apply, by the law of the Commonwealth of
|
|
||||||
Virginia, excluding Virginia's conflict of law provisions.
|
|
||||||
Notwithstanding the foregoing, with regard to derivative works based
|
|
||||||
on Python 1.6.1 that incorporate non-separable material that was
|
|
||||||
previously distributed under the GNU General Public License (GPL), the
|
|
||||||
law of the Commonwealth of Virginia shall govern this License
|
|
||||||
Agreement only as to issues arising under or with respect to
|
|
||||||
Paragraphs 4, 5, and 7 of this License Agreement. Nothing in this
|
|
||||||
License Agreement shall be deemed to create any relationship of
|
|
||||||
agency, partnership, or joint venture between CNRI and Licensee. This
|
|
||||||
License Agreement does not grant permission to use CNRI trademarks or
|
|
||||||
trade name in a trademark sense to endorse or promote products or
|
|
||||||
services of Licensee, or any third party.
|
|
||||||
|
|
||||||
8. By clicking on the "ACCEPT" button where indicated, or by copying,
|
|
||||||
installing or otherwise using Python 1.6.1, Licensee agrees to be
|
|
||||||
bound by the terms and conditions of this License Agreement.
|
|
||||||
|
|
||||||
ACCEPT
|
|
||||||
|
|
||||||
|
|
||||||
CWI LICENSE AGREEMENT FOR PYTHON 0.9.0 THROUGH 1.2
|
|
||||||
--------------------------------------------------
|
|
||||||
|
|
||||||
Copyright (c) 1991 - 1995, Stichting Mathematisch Centrum Amsterdam,
|
|
||||||
The Netherlands. All rights reserved.
|
|
||||||
|
|
||||||
Permission to use, copy, modify, and distribute this software and its
|
|
||||||
documentation for any purpose and without fee is hereby granted,
|
|
||||||
provided that the above copyright notice appear in all copies and that
|
|
||||||
both that copyright notice and this permission notice appear in
|
|
||||||
supporting documentation, and that the name of Stichting Mathematisch
|
|
||||||
Centrum or CWI not be used in advertising or publicity pertaining to
|
|
||||||
distribution of the software without specific, written prior
|
|
||||||
permission.
|
|
||||||
|
|
||||||
STICHTING MATHEMATISCH CENTRUM DISCLAIMS ALL WARRANTIES WITH REGARD TO
|
|
||||||
THIS SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
|
||||||
FITNESS, IN NO EVENT SHALL STICHTING MATHEMATISCH CENTRUM BE LIABLE
|
|
||||||
FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
|
||||||
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
|
||||||
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
|
||||||
OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
|
||||||
|
|
||||||
ZERO-CLAUSE BSD LICENSE FOR CODE IN THE PYTHON DOCUMENTATION
|
|
||||||
----------------------------------------------------------------------
|
|
||||||
|
|
||||||
Permission to use, copy, modify, and/or distribute this software for any
|
|
||||||
purpose with or without fee is hereby granted.
|
|
||||||
|
|
||||||
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH
|
|
||||||
REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
|
|
||||||
AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,
|
|
||||||
INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
|
|
||||||
LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR
|
|
||||||
OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
|
|
||||||
PERFORMANCE OF THIS SOFTWARE.
|
|
||||||
|
|
@ -1,123 +0,0 @@
|
||||||
Metadata-Version: 2.3
|
|
||||||
Name: aiohappyeyeballs
|
|
||||||
Version: 2.6.1
|
|
||||||
Summary: Happy Eyeballs for asyncio
|
|
||||||
License: PSF-2.0
|
|
||||||
Author: J. Nick Koston
|
|
||||||
Author-email: nick@koston.org
|
|
||||||
Requires-Python: >=3.9
|
|
||||||
Classifier: Development Status :: 5 - Production/Stable
|
|
||||||
Classifier: Intended Audience :: Developers
|
|
||||||
Classifier: Natural Language :: English
|
|
||||||
Classifier: Operating System :: OS Independent
|
|
||||||
Classifier: Topic :: Software Development :: Libraries
|
|
||||||
Classifier: Programming Language :: Python :: 3
|
|
||||||
Classifier: Programming Language :: Python :: 3.9
|
|
||||||
Classifier: Programming Language :: Python :: 3.10
|
|
||||||
Classifier: Programming Language :: Python :: 3.11
|
|
||||||
Classifier: Programming Language :: Python :: 3.12
|
|
||||||
Classifier: Programming Language :: Python :: 3.13
|
|
||||||
Classifier: License :: OSI Approved :: Python Software Foundation License
|
|
||||||
Project-URL: Bug Tracker, https://github.com/aio-libs/aiohappyeyeballs/issues
|
|
||||||
Project-URL: Changelog, https://github.com/aio-libs/aiohappyeyeballs/blob/main/CHANGELOG.md
|
|
||||||
Project-URL: Documentation, https://aiohappyeyeballs.readthedocs.io
|
|
||||||
Project-URL: Repository, https://github.com/aio-libs/aiohappyeyeballs
|
|
||||||
Description-Content-Type: text/markdown
|
|
||||||
|
|
||||||
# aiohappyeyeballs
|
|
||||||
|
|
||||||
<p align="center">
|
|
||||||
<a href="https://github.com/aio-libs/aiohappyeyeballs/actions/workflows/ci.yml?query=branch%3Amain">
|
|
||||||
<img src="https://img.shields.io/github/actions/workflow/status/aio-libs/aiohappyeyeballs/ci-cd.yml?branch=main&label=CI&logo=github&style=flat-square" alt="CI Status" >
|
|
||||||
</a>
|
|
||||||
<a href="https://aiohappyeyeballs.readthedocs.io">
|
|
||||||
<img src="https://img.shields.io/readthedocs/aiohappyeyeballs.svg?logo=read-the-docs&logoColor=fff&style=flat-square" alt="Documentation Status">
|
|
||||||
</a>
|
|
||||||
<a href="https://codecov.io/gh/aio-libs/aiohappyeyeballs">
|
|
||||||
<img src="https://img.shields.io/codecov/c/github/aio-libs/aiohappyeyeballs.svg?logo=codecov&logoColor=fff&style=flat-square" alt="Test coverage percentage">
|
|
||||||
</a>
|
|
||||||
</p>
|
|
||||||
<p align="center">
|
|
||||||
<a href="https://python-poetry.org/">
|
|
||||||
<img src="https://img.shields.io/badge/packaging-poetry-299bd7?style=flat-square&logo=" alt="Poetry">
|
|
||||||
</a>
|
|
||||||
<a href="https://github.com/astral-sh/ruff">
|
|
||||||
<img src="https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json" alt="Ruff">
|
|
||||||
</a>
|
|
||||||
<a href="https://github.com/pre-commit/pre-commit">
|
|
||||||
<img src="https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white&style=flat-square" alt="pre-commit">
|
|
||||||
</a>
|
|
||||||
</p>
|
|
||||||
<p align="center">
|
|
||||||
<a href="https://pypi.org/project/aiohappyeyeballs/">
|
|
||||||
<img src="https://img.shields.io/pypi/v/aiohappyeyeballs.svg?logo=python&logoColor=fff&style=flat-square" alt="PyPI Version">
|
|
||||||
</a>
|
|
||||||
<img src="https://img.shields.io/pypi/pyversions/aiohappyeyeballs.svg?style=flat-square&logo=python&logoColor=fff" alt="Supported Python versions">
|
|
||||||
<img src="https://img.shields.io/pypi/l/aiohappyeyeballs.svg?style=flat-square" alt="License">
|
|
||||||
</p>
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
**Documentation**: <a href="https://aiohappyeyeballs.readthedocs.io" target="_blank">https://aiohappyeyeballs.readthedocs.io </a>
|
|
||||||
|
|
||||||
**Source Code**: <a href="https://github.com/aio-libs/aiohappyeyeballs" target="_blank">https://github.com/aio-libs/aiohappyeyeballs </a>
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
[Happy Eyeballs](https://en.wikipedia.org/wiki/Happy_Eyeballs)
|
|
||||||
([RFC 8305](https://www.rfc-editor.org/rfc/rfc8305.html))
|
|
||||||
|
|
||||||
## Use case
|
|
||||||
|
|
||||||
This library exists to allow connecting with
|
|
||||||
[Happy Eyeballs](https://en.wikipedia.org/wiki/Happy_Eyeballs)
|
|
||||||
([RFC 8305](https://www.rfc-editor.org/rfc/rfc8305.html))
|
|
||||||
when you
|
|
||||||
already have a list of addrinfo and not a DNS name.
|
|
||||||
|
|
||||||
The stdlib version of `loop.create_connection()`
|
|
||||||
will only work when you pass in an unresolved name which
|
|
||||||
is not a good fit when using DNS caching or resolving
|
|
||||||
names via another method such as `zeroconf`.
|
|
||||||
|
|
||||||
## Installation
|
|
||||||
|
|
||||||
Install this via pip (or your favourite package manager):
|
|
||||||
|
|
||||||
`pip install aiohappyeyeballs`
|
|
||||||
|
|
||||||
## License
|
|
||||||
|
|
||||||
[aiohappyeyeballs is licensed under the same terms as cpython itself.](https://github.com/python/cpython/blob/main/LICENSE)
|
|
||||||
|
|
||||||
## Example usage
|
|
||||||
|
|
||||||
```python
|
|
||||||
|
|
||||||
addr_infos = await loop.getaddrinfo("example.org", 80)
|
|
||||||
|
|
||||||
socket = await start_connection(addr_infos)
|
|
||||||
socket = await start_connection(addr_infos, local_addr_infos=local_addr_infos, happy_eyeballs_delay=0.2)
|
|
||||||
|
|
||||||
transport, protocol = await loop.create_connection(
|
|
||||||
MyProtocol, sock=socket, ...)
|
|
||||||
|
|
||||||
# Remove the first address for each family from addr_info
|
|
||||||
pop_addr_infos_interleave(addr_info, 1)
|
|
||||||
|
|
||||||
# Remove all matching address from addr_info
|
|
||||||
remove_addr_infos(addr_info, "dead::beef::")
|
|
||||||
|
|
||||||
# Convert a local_addr to local_addr_infos
|
|
||||||
local_addr_infos = addr_to_addr_infos(("127.0.0.1",0))
|
|
||||||
```
|
|
||||||
|
|
||||||
## Credits
|
|
||||||
|
|
||||||
This package contains code from cpython and is licensed under the same terms as cpython itself.
|
|
||||||
|
|
||||||
This package was created with
|
|
||||||
[Copier](https://copier.readthedocs.io/) and the
|
|
||||||
[browniebroke/pypackage-template](https://github.com/browniebroke/pypackage-template)
|
|
||||||
project template.
|
|
||||||
|
|
||||||
|
|
@ -1,16 +0,0 @@
|
||||||
aiohappyeyeballs-2.6.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
|
|
||||||
aiohappyeyeballs-2.6.1.dist-info/LICENSE,sha256=Oy-B_iHRgcSZxZolbI4ZaEVdZonSaaqFNzv7avQdo78,13936
|
|
||||||
aiohappyeyeballs-2.6.1.dist-info/METADATA,sha256=NSXlhJwAfi380eEjAo7BQ4P_TVal9xi0qkyZWibMsVM,5915
|
|
||||||
aiohappyeyeballs-2.6.1.dist-info/RECORD,,
|
|
||||||
aiohappyeyeballs-2.6.1.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
|
|
||||||
aiohappyeyeballs/__init__.py,sha256=x7kktHEtaD9quBcWDJPuLeKyjuVAI-Jj14S9B_5hcTs,361
|
|
||||||
aiohappyeyeballs/__pycache__/__init__.cpython-311.pyc,,
|
|
||||||
aiohappyeyeballs/__pycache__/_staggered.cpython-311.pyc,,
|
|
||||||
aiohappyeyeballs/__pycache__/impl.cpython-311.pyc,,
|
|
||||||
aiohappyeyeballs/__pycache__/types.cpython-311.pyc,,
|
|
||||||
aiohappyeyeballs/__pycache__/utils.cpython-311.pyc,,
|
|
||||||
aiohappyeyeballs/_staggered.py,sha256=edfVowFx-P-ywJjIEF3MdPtEMVODujV6CeMYr65otac,6900
|
|
||||||
aiohappyeyeballs/impl.py,sha256=Dlcm2mTJ28ucrGnxkb_fo9CZzLAkOOBizOt7dreBbXE,9681
|
|
||||||
aiohappyeyeballs/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
||||||
aiohappyeyeballs/types.py,sha256=YZJIAnyoV4Dz0WFtlaf_OyE4EW7Xus1z7aIfNI6tDDQ,425
|
|
||||||
aiohappyeyeballs/utils.py,sha256=on9GxIR0LhEfZu8P6Twi9hepX9zDanuZM20MWsb3xlQ,3028
|
|
||||||
|
|
@ -1,4 +0,0 @@
|
||||||
Wheel-Version: 1.0
|
|
||||||
Generator: poetry-core 2.1.1
|
|
||||||
Root-Is-Purelib: true
|
|
||||||
Tag: py3-none-any
|
|
||||||
|
|
@ -1,14 +0,0 @@
|
||||||
__version__ = "2.6.1"
|
|
||||||
|
|
||||||
from .impl import start_connection
|
|
||||||
from .types import AddrInfoType, SocketFactoryType
|
|
||||||
from .utils import addr_to_addr_infos, pop_addr_infos_interleave, remove_addr_infos
|
|
||||||
|
|
||||||
__all__ = (
|
|
||||||
"AddrInfoType",
|
|
||||||
"SocketFactoryType",
|
|
||||||
"addr_to_addr_infos",
|
|
||||||
"pop_addr_infos_interleave",
|
|
||||||
"remove_addr_infos",
|
|
||||||
"start_connection",
|
|
||||||
)
|
|
||||||
|
|
@ -1,207 +0,0 @@
|
||||||
import asyncio
|
|
||||||
import contextlib
|
|
||||||
|
|
||||||
# PY3.9: Import Callable from typing until we drop Python 3.9 support
|
|
||||||
# https://github.com/python/cpython/issues/87131
|
|
||||||
from typing import (
|
|
||||||
TYPE_CHECKING,
|
|
||||||
Any,
|
|
||||||
Awaitable,
|
|
||||||
Callable,
|
|
||||||
Iterable,
|
|
||||||
List,
|
|
||||||
Optional,
|
|
||||||
Set,
|
|
||||||
Tuple,
|
|
||||||
TypeVar,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
|
|
||||||
_T = TypeVar("_T")
|
|
||||||
|
|
||||||
RE_RAISE_EXCEPTIONS = (SystemExit, KeyboardInterrupt)
|
|
||||||
|
|
||||||
|
|
||||||
def _set_result(wait_next: "asyncio.Future[None]") -> None:
|
|
||||||
"""Set the result of a future if it is not already done."""
|
|
||||||
if not wait_next.done():
|
|
||||||
wait_next.set_result(None)
|
|
||||||
|
|
||||||
|
|
||||||
async def _wait_one(
|
|
||||||
futures: "Iterable[asyncio.Future[Any]]",
|
|
||||||
loop: asyncio.AbstractEventLoop,
|
|
||||||
) -> _T:
|
|
||||||
"""Wait for the first future to complete."""
|
|
||||||
wait_next = loop.create_future()
|
|
||||||
|
|
||||||
def _on_completion(fut: "asyncio.Future[Any]") -> None:
|
|
||||||
if not wait_next.done():
|
|
||||||
wait_next.set_result(fut)
|
|
||||||
|
|
||||||
for f in futures:
|
|
||||||
f.add_done_callback(_on_completion)
|
|
||||||
|
|
||||||
try:
|
|
||||||
return await wait_next
|
|
||||||
finally:
|
|
||||||
for f in futures:
|
|
||||||
f.remove_done_callback(_on_completion)
|
|
||||||
|
|
||||||
|
|
||||||
async def staggered_race(
|
|
||||||
coro_fns: Iterable[Callable[[], Awaitable[_T]]],
|
|
||||||
delay: Optional[float],
|
|
||||||
*,
|
|
||||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
|
||||||
) -> Tuple[Optional[_T], Optional[int], List[Optional[BaseException]]]:
|
|
||||||
"""
|
|
||||||
Run coroutines with staggered start times and take the first to finish.
|
|
||||||
|
|
||||||
This method takes an iterable of coroutine functions. The first one is
|
|
||||||
started immediately. From then on, whenever the immediately preceding one
|
|
||||||
fails (raises an exception), or when *delay* seconds has passed, the next
|
|
||||||
coroutine is started. This continues until one of the coroutines complete
|
|
||||||
successfully, in which case all others are cancelled, or until all
|
|
||||||
coroutines fail.
|
|
||||||
|
|
||||||
The coroutines provided should be well-behaved in the following way:
|
|
||||||
|
|
||||||
* They should only ``return`` if completed successfully.
|
|
||||||
|
|
||||||
* They should always raise an exception if they did not complete
|
|
||||||
successfully. In particular, if they handle cancellation, they should
|
|
||||||
probably reraise, like this::
|
|
||||||
|
|
||||||
try:
|
|
||||||
# do work
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
# undo partially completed work
|
|
||||||
raise
|
|
||||||
|
|
||||||
Args:
|
|
||||||
----
|
|
||||||
coro_fns: an iterable of coroutine functions, i.e. callables that
|
|
||||||
return a coroutine object when called. Use ``functools.partial`` or
|
|
||||||
lambdas to pass arguments.
|
|
||||||
|
|
||||||
delay: amount of time, in seconds, between starting coroutines. If
|
|
||||||
``None``, the coroutines will run sequentially.
|
|
||||||
|
|
||||||
loop: the event loop to use. If ``None``, the running loop is used.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
-------
|
|
||||||
tuple *(winner_result, winner_index, exceptions)* where
|
|
||||||
|
|
||||||
- *winner_result*: the result of the winning coroutine, or ``None``
|
|
||||||
if no coroutines won.
|
|
||||||
|
|
||||||
- *winner_index*: the index of the winning coroutine in
|
|
||||||
``coro_fns``, or ``None`` if no coroutines won. If the winning
|
|
||||||
coroutine may return None on success, *winner_index* can be used
|
|
||||||
to definitively determine whether any coroutine won.
|
|
||||||
|
|
||||||
- *exceptions*: list of exceptions returned by the coroutines.
|
|
||||||
``len(exceptions)`` is equal to the number of coroutines actually
|
|
||||||
started, and the order is the same as in ``coro_fns``. The winning
|
|
||||||
coroutine's entry is ``None``.
|
|
||||||
|
|
||||||
"""
|
|
||||||
loop = loop or asyncio.get_running_loop()
|
|
||||||
exceptions: List[Optional[BaseException]] = []
|
|
||||||
tasks: Set[asyncio.Task[Optional[Tuple[_T, int]]]] = set()
|
|
||||||
|
|
||||||
async def run_one_coro(
|
|
||||||
coro_fn: Callable[[], Awaitable[_T]],
|
|
||||||
this_index: int,
|
|
||||||
start_next: "asyncio.Future[None]",
|
|
||||||
) -> Optional[Tuple[_T, int]]:
|
|
||||||
"""
|
|
||||||
Run a single coroutine.
|
|
||||||
|
|
||||||
If the coroutine fails, set the exception in the exceptions list and
|
|
||||||
start the next coroutine by setting the result of the start_next.
|
|
||||||
|
|
||||||
If the coroutine succeeds, return the result and the index of the
|
|
||||||
coroutine in the coro_fns list.
|
|
||||||
|
|
||||||
If SystemExit or KeyboardInterrupt is raised, re-raise it.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
result = await coro_fn()
|
|
||||||
except RE_RAISE_EXCEPTIONS:
|
|
||||||
raise
|
|
||||||
except BaseException as e:
|
|
||||||
exceptions[this_index] = e
|
|
||||||
_set_result(start_next) # Kickstart the next coroutine
|
|
||||||
return None
|
|
||||||
|
|
||||||
return result, this_index
|
|
||||||
|
|
||||||
start_next_timer: Optional[asyncio.TimerHandle] = None
|
|
||||||
start_next: Optional[asyncio.Future[None]]
|
|
||||||
task: asyncio.Task[Optional[Tuple[_T, int]]]
|
|
||||||
done: Union[asyncio.Future[None], asyncio.Task[Optional[Tuple[_T, int]]]]
|
|
||||||
coro_iter = iter(coro_fns)
|
|
||||||
this_index = -1
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
if coro_fn := next(coro_iter, None):
|
|
||||||
this_index += 1
|
|
||||||
exceptions.append(None)
|
|
||||||
start_next = loop.create_future()
|
|
||||||
task = loop.create_task(run_one_coro(coro_fn, this_index, start_next))
|
|
||||||
tasks.add(task)
|
|
||||||
start_next_timer = (
|
|
||||||
loop.call_later(delay, _set_result, start_next) if delay else None
|
|
||||||
)
|
|
||||||
elif not tasks:
|
|
||||||
# We exhausted the coro_fns list and no tasks are running
|
|
||||||
# so we have no winner and all coroutines failed.
|
|
||||||
break
|
|
||||||
|
|
||||||
while tasks or start_next:
|
|
||||||
done = await _wait_one(
|
|
||||||
(*tasks, start_next) if start_next else tasks, loop
|
|
||||||
)
|
|
||||||
if done is start_next:
|
|
||||||
# The current task has failed or the timer has expired
|
|
||||||
# so we need to start the next task.
|
|
||||||
start_next = None
|
|
||||||
if start_next_timer:
|
|
||||||
start_next_timer.cancel()
|
|
||||||
start_next_timer = None
|
|
||||||
|
|
||||||
# Break out of the task waiting loop to start the next
|
|
||||||
# task.
|
|
||||||
break
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
assert isinstance(done, asyncio.Task)
|
|
||||||
|
|
||||||
tasks.remove(done)
|
|
||||||
if winner := done.result():
|
|
||||||
return *winner, exceptions
|
|
||||||
finally:
|
|
||||||
# We either have:
|
|
||||||
# - a winner
|
|
||||||
# - all tasks failed
|
|
||||||
# - a KeyboardInterrupt or SystemExit.
|
|
||||||
|
|
||||||
#
|
|
||||||
# If the timer is still running, cancel it.
|
|
||||||
#
|
|
||||||
if start_next_timer:
|
|
||||||
start_next_timer.cancel()
|
|
||||||
|
|
||||||
#
|
|
||||||
# If there are any tasks left, cancel them and than
|
|
||||||
# wait them so they fill the exceptions list.
|
|
||||||
#
|
|
||||||
for task in tasks:
|
|
||||||
task.cancel()
|
|
||||||
with contextlib.suppress(asyncio.CancelledError):
|
|
||||||
await task
|
|
||||||
|
|
||||||
return None, None, exceptions
|
|
||||||
|
|
@ -1,259 +0,0 @@
|
||||||
"""Base implementation."""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import collections
|
|
||||||
import contextlib
|
|
||||||
import functools
|
|
||||||
import itertools
|
|
||||||
import socket
|
|
||||||
from typing import List, Optional, Sequence, Set, Union
|
|
||||||
|
|
||||||
from . import _staggered
|
|
||||||
from .types import AddrInfoType, SocketFactoryType
|
|
||||||
|
|
||||||
|
|
||||||
async def start_connection(
|
|
||||||
addr_infos: Sequence[AddrInfoType],
|
|
||||||
*,
|
|
||||||
local_addr_infos: Optional[Sequence[AddrInfoType]] = None,
|
|
||||||
happy_eyeballs_delay: Optional[float] = None,
|
|
||||||
interleave: Optional[int] = None,
|
|
||||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
|
||||||
socket_factory: Optional[SocketFactoryType] = None,
|
|
||||||
) -> socket.socket:
|
|
||||||
"""
|
|
||||||
Connect to a TCP server.
|
|
||||||
|
|
||||||
Create a socket connection to a specified destination. The
|
|
||||||
destination is specified as a list of AddrInfoType tuples as
|
|
||||||
returned from getaddrinfo().
|
|
||||||
|
|
||||||
The arguments are, in order:
|
|
||||||
|
|
||||||
* ``family``: the address family, e.g. ``socket.AF_INET`` or
|
|
||||||
``socket.AF_INET6``.
|
|
||||||
* ``type``: the socket type, e.g. ``socket.SOCK_STREAM`` or
|
|
||||||
``socket.SOCK_DGRAM``.
|
|
||||||
* ``proto``: the protocol, e.g. ``socket.IPPROTO_TCP`` or
|
|
||||||
``socket.IPPROTO_UDP``.
|
|
||||||
* ``canonname``: the canonical name of the address, e.g.
|
|
||||||
``"www.python.org"``.
|
|
||||||
* ``sockaddr``: the socket address
|
|
||||||
|
|
||||||
This method is a coroutine which will try to establish the connection
|
|
||||||
in the background. When successful, the coroutine returns a
|
|
||||||
socket.
|
|
||||||
|
|
||||||
The expected use case is to use this method in conjunction with
|
|
||||||
loop.create_connection() to establish a connection to a server::
|
|
||||||
|
|
||||||
socket = await start_connection(addr_infos)
|
|
||||||
transport, protocol = await loop.create_connection(
|
|
||||||
MyProtocol, sock=socket, ...)
|
|
||||||
"""
|
|
||||||
if not (current_loop := loop):
|
|
||||||
current_loop = asyncio.get_running_loop()
|
|
||||||
|
|
||||||
single_addr_info = len(addr_infos) == 1
|
|
||||||
|
|
||||||
if happy_eyeballs_delay is not None and interleave is None:
|
|
||||||
# If using happy eyeballs, default to interleave addresses by family
|
|
||||||
interleave = 1
|
|
||||||
|
|
||||||
if interleave and not single_addr_info:
|
|
||||||
addr_infos = _interleave_addrinfos(addr_infos, interleave)
|
|
||||||
|
|
||||||
sock: Optional[socket.socket] = None
|
|
||||||
# uvloop can raise RuntimeError instead of OSError
|
|
||||||
exceptions: List[List[Union[OSError, RuntimeError]]] = []
|
|
||||||
if happy_eyeballs_delay is None or single_addr_info:
|
|
||||||
# not using happy eyeballs
|
|
||||||
for addrinfo in addr_infos:
|
|
||||||
try:
|
|
||||||
sock = await _connect_sock(
|
|
||||||
current_loop,
|
|
||||||
exceptions,
|
|
||||||
addrinfo,
|
|
||||||
local_addr_infos,
|
|
||||||
None,
|
|
||||||
socket_factory,
|
|
||||||
)
|
|
||||||
break
|
|
||||||
except (RuntimeError, OSError):
|
|
||||||
continue
|
|
||||||
else: # using happy eyeballs
|
|
||||||
open_sockets: Set[socket.socket] = set()
|
|
||||||
try:
|
|
||||||
sock, _, _ = await _staggered.staggered_race(
|
|
||||||
(
|
|
||||||
functools.partial(
|
|
||||||
_connect_sock,
|
|
||||||
current_loop,
|
|
||||||
exceptions,
|
|
||||||
addrinfo,
|
|
||||||
local_addr_infos,
|
|
||||||
open_sockets,
|
|
||||||
socket_factory,
|
|
||||||
)
|
|
||||||
for addrinfo in addr_infos
|
|
||||||
),
|
|
||||||
happy_eyeballs_delay,
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
# If we have a winner, staggered_race will
|
|
||||||
# cancel the other tasks, however there is a
|
|
||||||
# small race window where any of the other tasks
|
|
||||||
# can be done before they are cancelled which
|
|
||||||
# will leave the socket open. To avoid this problem
|
|
||||||
# we pass a set to _connect_sock to keep track of
|
|
||||||
# the open sockets and close them here if there
|
|
||||||
# are any "runner up" sockets.
|
|
||||||
for s in open_sockets:
|
|
||||||
if s is not sock:
|
|
||||||
with contextlib.suppress(OSError):
|
|
||||||
s.close()
|
|
||||||
open_sockets = None # type: ignore[assignment]
|
|
||||||
|
|
||||||
if sock is None:
|
|
||||||
all_exceptions = [exc for sub in exceptions for exc in sub]
|
|
||||||
try:
|
|
||||||
first_exception = all_exceptions[0]
|
|
||||||
if len(all_exceptions) == 1:
|
|
||||||
raise first_exception
|
|
||||||
else:
|
|
||||||
# If they all have the same str(), raise one.
|
|
||||||
model = str(first_exception)
|
|
||||||
if all(str(exc) == model for exc in all_exceptions):
|
|
||||||
raise first_exception
|
|
||||||
# Raise a combined exception so the user can see all
|
|
||||||
# the various error messages.
|
|
||||||
msg = "Multiple exceptions: {}".format(
|
|
||||||
", ".join(str(exc) for exc in all_exceptions)
|
|
||||||
)
|
|
||||||
# If the errno is the same for all exceptions, raise
|
|
||||||
# an OSError with that errno.
|
|
||||||
if isinstance(first_exception, OSError):
|
|
||||||
first_errno = first_exception.errno
|
|
||||||
if all(
|
|
||||||
isinstance(exc, OSError) and exc.errno == first_errno
|
|
||||||
for exc in all_exceptions
|
|
||||||
):
|
|
||||||
raise OSError(first_errno, msg)
|
|
||||||
elif isinstance(first_exception, RuntimeError) and all(
|
|
||||||
isinstance(exc, RuntimeError) for exc in all_exceptions
|
|
||||||
):
|
|
||||||
raise RuntimeError(msg)
|
|
||||||
# We have a mix of OSError and RuntimeError
|
|
||||||
# so we have to pick which one to raise.
|
|
||||||
# and we raise OSError for compatibility
|
|
||||||
raise OSError(msg)
|
|
||||||
finally:
|
|
||||||
all_exceptions = None # type: ignore[assignment]
|
|
||||||
exceptions = None # type: ignore[assignment]
|
|
||||||
|
|
||||||
return sock
|
|
||||||
|
|
||||||
|
|
||||||
async def _connect_sock(
|
|
||||||
loop: asyncio.AbstractEventLoop,
|
|
||||||
exceptions: List[List[Union[OSError, RuntimeError]]],
|
|
||||||
addr_info: AddrInfoType,
|
|
||||||
local_addr_infos: Optional[Sequence[AddrInfoType]] = None,
|
|
||||||
open_sockets: Optional[Set[socket.socket]] = None,
|
|
||||||
socket_factory: Optional[SocketFactoryType] = None,
|
|
||||||
) -> socket.socket:
|
|
||||||
"""
|
|
||||||
Create, bind and connect one socket.
|
|
||||||
|
|
||||||
If open_sockets is passed, add the socket to the set of open sockets.
|
|
||||||
Any failure caught here will remove the socket from the set and close it.
|
|
||||||
|
|
||||||
Callers can use this set to close any sockets that are not the winner
|
|
||||||
of all staggered tasks in the result there are runner up sockets aka
|
|
||||||
multiple winners.
|
|
||||||
"""
|
|
||||||
my_exceptions: List[Union[OSError, RuntimeError]] = []
|
|
||||||
exceptions.append(my_exceptions)
|
|
||||||
family, type_, proto, _, address = addr_info
|
|
||||||
sock = None
|
|
||||||
try:
|
|
||||||
if socket_factory is not None:
|
|
||||||
sock = socket_factory(addr_info)
|
|
||||||
else:
|
|
||||||
sock = socket.socket(family=family, type=type_, proto=proto)
|
|
||||||
if open_sockets is not None:
|
|
||||||
open_sockets.add(sock)
|
|
||||||
sock.setblocking(False)
|
|
||||||
if local_addr_infos is not None:
|
|
||||||
for lfamily, _, _, _, laddr in local_addr_infos:
|
|
||||||
# skip local addresses of different family
|
|
||||||
if lfamily != family:
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
sock.bind(laddr)
|
|
||||||
break
|
|
||||||
except OSError as exc:
|
|
||||||
msg = (
|
|
||||||
f"error while attempting to bind on "
|
|
||||||
f"address {laddr!r}: "
|
|
||||||
f"{(exc.strerror or '').lower()}"
|
|
||||||
)
|
|
||||||
exc = OSError(exc.errno, msg)
|
|
||||||
my_exceptions.append(exc)
|
|
||||||
else: # all bind attempts failed
|
|
||||||
if my_exceptions:
|
|
||||||
raise my_exceptions.pop()
|
|
||||||
else:
|
|
||||||
raise OSError(f"no matching local address with {family=} found")
|
|
||||||
await loop.sock_connect(sock, address)
|
|
||||||
return sock
|
|
||||||
except (RuntimeError, OSError) as exc:
|
|
||||||
my_exceptions.append(exc)
|
|
||||||
if sock is not None:
|
|
||||||
if open_sockets is not None:
|
|
||||||
open_sockets.remove(sock)
|
|
||||||
try:
|
|
||||||
sock.close()
|
|
||||||
except OSError as e:
|
|
||||||
my_exceptions.append(e)
|
|
||||||
raise
|
|
||||||
raise
|
|
||||||
except:
|
|
||||||
if sock is not None:
|
|
||||||
if open_sockets is not None:
|
|
||||||
open_sockets.remove(sock)
|
|
||||||
try:
|
|
||||||
sock.close()
|
|
||||||
except OSError as e:
|
|
||||||
my_exceptions.append(e)
|
|
||||||
raise
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
exceptions = my_exceptions = None # type: ignore[assignment]
|
|
||||||
|
|
||||||
|
|
||||||
def _interleave_addrinfos(
|
|
||||||
addrinfos: Sequence[AddrInfoType], first_address_family_count: int = 1
|
|
||||||
) -> List[AddrInfoType]:
|
|
||||||
"""Interleave list of addrinfo tuples by family."""
|
|
||||||
# Group addresses by family
|
|
||||||
addrinfos_by_family: collections.OrderedDict[int, List[AddrInfoType]] = (
|
|
||||||
collections.OrderedDict()
|
|
||||||
)
|
|
||||||
for addr in addrinfos:
|
|
||||||
family = addr[0]
|
|
||||||
if family not in addrinfos_by_family:
|
|
||||||
addrinfos_by_family[family] = []
|
|
||||||
addrinfos_by_family[family].append(addr)
|
|
||||||
addrinfos_lists = list(addrinfos_by_family.values())
|
|
||||||
|
|
||||||
reordered: List[AddrInfoType] = []
|
|
||||||
if first_address_family_count > 1:
|
|
||||||
reordered.extend(addrinfos_lists[0][: first_address_family_count - 1])
|
|
||||||
del addrinfos_lists[0][: first_address_family_count - 1]
|
|
||||||
reordered.extend(
|
|
||||||
a
|
|
||||||
for a in itertools.chain.from_iterable(itertools.zip_longest(*addrinfos_lists))
|
|
||||||
if a is not None
|
|
||||||
)
|
|
||||||
return reordered
|
|
||||||
|
|
@ -1,17 +0,0 @@
|
||||||
"""Types for aiohappyeyeballs."""
|
|
||||||
|
|
||||||
import socket
|
|
||||||
|
|
||||||
# PY3.9: Import Callable from typing until we drop Python 3.9 support
|
|
||||||
# https://github.com/python/cpython/issues/87131
|
|
||||||
from typing import Callable, Tuple, Union
|
|
||||||
|
|
||||||
AddrInfoType = Tuple[
|
|
||||||
Union[int, socket.AddressFamily],
|
|
||||||
Union[int, socket.SocketKind],
|
|
||||||
int,
|
|
||||||
str,
|
|
||||||
Tuple, # type: ignore[type-arg]
|
|
||||||
]
|
|
||||||
|
|
||||||
SocketFactoryType = Callable[[AddrInfoType], socket.socket]
|
|
||||||
|
|
@ -1,97 +0,0 @@
|
||||||
"""Utility functions for aiohappyeyeballs."""
|
|
||||||
|
|
||||||
import ipaddress
|
|
||||||
import socket
|
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
|
||||||
|
|
||||||
from .types import AddrInfoType
|
|
||||||
|
|
||||||
|
|
||||||
def addr_to_addr_infos(
|
|
||||||
addr: Optional[
|
|
||||||
Union[Tuple[str, int, int, int], Tuple[str, int, int], Tuple[str, int]]
|
|
||||||
],
|
|
||||||
) -> Optional[List[AddrInfoType]]:
|
|
||||||
"""Convert an address tuple to a list of addr_info tuples."""
|
|
||||||
if addr is None:
|
|
||||||
return None
|
|
||||||
host = addr[0]
|
|
||||||
port = addr[1]
|
|
||||||
is_ipv6 = ":" in host
|
|
||||||
if is_ipv6:
|
|
||||||
flowinfo = 0
|
|
||||||
scopeid = 0
|
|
||||||
addr_len = len(addr)
|
|
||||||
if addr_len >= 4:
|
|
||||||
scopeid = addr[3] # type: ignore[misc]
|
|
||||||
if addr_len >= 3:
|
|
||||||
flowinfo = addr[2] # type: ignore[misc]
|
|
||||||
addr = (host, port, flowinfo, scopeid)
|
|
||||||
family = socket.AF_INET6
|
|
||||||
else:
|
|
||||||
addr = (host, port)
|
|
||||||
family = socket.AF_INET
|
|
||||||
return [(family, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", addr)]
|
|
||||||
|
|
||||||
|
|
||||||
def pop_addr_infos_interleave(
|
|
||||||
addr_infos: List[AddrInfoType], interleave: Optional[int] = None
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Pop addr_info from the list of addr_infos by family up to interleave times.
|
|
||||||
|
|
||||||
The interleave parameter is used to know how many addr_infos for
|
|
||||||
each family should be popped of the top of the list.
|
|
||||||
"""
|
|
||||||
seen: Dict[int, int] = {}
|
|
||||||
if interleave is None:
|
|
||||||
interleave = 1
|
|
||||||
to_remove: List[AddrInfoType] = []
|
|
||||||
for addr_info in addr_infos:
|
|
||||||
family = addr_info[0]
|
|
||||||
if family not in seen:
|
|
||||||
seen[family] = 0
|
|
||||||
if seen[family] < interleave:
|
|
||||||
to_remove.append(addr_info)
|
|
||||||
seen[family] += 1
|
|
||||||
for addr_info in to_remove:
|
|
||||||
addr_infos.remove(addr_info)
|
|
||||||
|
|
||||||
|
|
||||||
def _addr_tuple_to_ip_address(
|
|
||||||
addr: Union[Tuple[str, int], Tuple[str, int, int, int]],
|
|
||||||
) -> Union[
|
|
||||||
Tuple[ipaddress.IPv4Address, int], Tuple[ipaddress.IPv6Address, int, int, int]
|
|
||||||
]:
|
|
||||||
"""Convert an address tuple to an IPv4Address."""
|
|
||||||
return (ipaddress.ip_address(addr[0]), *addr[1:])
|
|
||||||
|
|
||||||
|
|
||||||
def remove_addr_infos(
|
|
||||||
addr_infos: List[AddrInfoType],
|
|
||||||
addr: Union[Tuple[str, int], Tuple[str, int, int, int]],
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Remove an address from the list of addr_infos.
|
|
||||||
|
|
||||||
The addr value is typically the return value of
|
|
||||||
sock.getpeername().
|
|
||||||
"""
|
|
||||||
bad_addrs_infos: List[AddrInfoType] = []
|
|
||||||
for addr_info in addr_infos:
|
|
||||||
if addr_info[-1] == addr:
|
|
||||||
bad_addrs_infos.append(addr_info)
|
|
||||||
if bad_addrs_infos:
|
|
||||||
for bad_addr_info in bad_addrs_infos:
|
|
||||||
addr_infos.remove(bad_addr_info)
|
|
||||||
return
|
|
||||||
# Slow path in case addr is formatted differently
|
|
||||||
match_addr = _addr_tuple_to_ip_address(addr)
|
|
||||||
for addr_info in addr_infos:
|
|
||||||
if match_addr == _addr_tuple_to_ip_address(addr_info[-1]):
|
|
||||||
bad_addrs_infos.append(addr_info)
|
|
||||||
if bad_addrs_infos:
|
|
||||||
for bad_addr_info in bad_addrs_infos:
|
|
||||||
addr_infos.remove(bad_addr_info)
|
|
||||||
return
|
|
||||||
raise ValueError(f"Address {addr} not found in addr_infos")
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
pip
|
|
||||||
|
|
@ -1,262 +0,0 @@
|
||||||
Metadata-Version: 2.4
|
|
||||||
Name: aiohttp
|
|
||||||
Version: 3.13.2
|
|
||||||
Summary: Async http client/server framework (asyncio)
|
|
||||||
Maintainer-email: aiohttp team <team@aiohttp.org>
|
|
||||||
License: Apache-2.0 AND MIT
|
|
||||||
Project-URL: Homepage, https://github.com/aio-libs/aiohttp
|
|
||||||
Project-URL: Chat: Matrix, https://matrix.to/#/#aio-libs:matrix.org
|
|
||||||
Project-URL: Chat: Matrix Space, https://matrix.to/#/#aio-libs-space:matrix.org
|
|
||||||
Project-URL: CI: GitHub Actions, https://github.com/aio-libs/aiohttp/actions?query=workflow%3ACI
|
|
||||||
Project-URL: Coverage: codecov, https://codecov.io/github/aio-libs/aiohttp
|
|
||||||
Project-URL: Docs: Changelog, https://docs.aiohttp.org/en/stable/changes.html
|
|
||||||
Project-URL: Docs: RTD, https://docs.aiohttp.org
|
|
||||||
Project-URL: GitHub: issues, https://github.com/aio-libs/aiohttp/issues
|
|
||||||
Project-URL: GitHub: repo, https://github.com/aio-libs/aiohttp
|
|
||||||
Classifier: Development Status :: 5 - Production/Stable
|
|
||||||
Classifier: Framework :: AsyncIO
|
|
||||||
Classifier: Intended Audience :: Developers
|
|
||||||
Classifier: Operating System :: POSIX
|
|
||||||
Classifier: Operating System :: MacOS :: MacOS X
|
|
||||||
Classifier: Operating System :: Microsoft :: Windows
|
|
||||||
Classifier: Programming Language :: Python
|
|
||||||
Classifier: Programming Language :: Python :: 3
|
|
||||||
Classifier: Programming Language :: Python :: 3.9
|
|
||||||
Classifier: Programming Language :: Python :: 3.10
|
|
||||||
Classifier: Programming Language :: Python :: 3.11
|
|
||||||
Classifier: Programming Language :: Python :: 3.12
|
|
||||||
Classifier: Programming Language :: Python :: 3.13
|
|
||||||
Classifier: Programming Language :: Python :: 3.14
|
|
||||||
Classifier: Topic :: Internet :: WWW/HTTP
|
|
||||||
Requires-Python: >=3.9
|
|
||||||
Description-Content-Type: text/x-rst
|
|
||||||
License-File: LICENSE.txt
|
|
||||||
License-File: vendor/llhttp/LICENSE
|
|
||||||
Requires-Dist: aiohappyeyeballs>=2.5.0
|
|
||||||
Requires-Dist: aiosignal>=1.4.0
|
|
||||||
Requires-Dist: async-timeout<6.0,>=4.0; python_version < "3.11"
|
|
||||||
Requires-Dist: attrs>=17.3.0
|
|
||||||
Requires-Dist: frozenlist>=1.1.1
|
|
||||||
Requires-Dist: multidict<7.0,>=4.5
|
|
||||||
Requires-Dist: propcache>=0.2.0
|
|
||||||
Requires-Dist: yarl<2.0,>=1.17.0
|
|
||||||
Provides-Extra: speedups
|
|
||||||
Requires-Dist: aiodns>=3.3.0; extra == "speedups"
|
|
||||||
Requires-Dist: Brotli; platform_python_implementation == "CPython" and extra == "speedups"
|
|
||||||
Requires-Dist: brotlicffi; platform_python_implementation != "CPython" and extra == "speedups"
|
|
||||||
Requires-Dist: backports.zstd; (platform_python_implementation == "CPython" and python_version < "3.14") and extra == "speedups"
|
|
||||||
Dynamic: license-file
|
|
||||||
|
|
||||||
==================================
|
|
||||||
Async http client/server framework
|
|
||||||
==================================
|
|
||||||
|
|
||||||
.. image:: https://raw.githubusercontent.com/aio-libs/aiohttp/master/docs/aiohttp-plain.svg
|
|
||||||
:height: 64px
|
|
||||||
:width: 64px
|
|
||||||
:alt: aiohttp logo
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
||||||
.. image:: https://github.com/aio-libs/aiohttp/workflows/CI/badge.svg
|
|
||||||
:target: https://github.com/aio-libs/aiohttp/actions?query=workflow%3ACI
|
|
||||||
:alt: GitHub Actions status for master branch
|
|
||||||
|
|
||||||
.. image:: https://codecov.io/gh/aio-libs/aiohttp/branch/master/graph/badge.svg
|
|
||||||
:target: https://codecov.io/gh/aio-libs/aiohttp
|
|
||||||
:alt: codecov.io status for master branch
|
|
||||||
|
|
||||||
.. image:: https://badge.fury.io/py/aiohttp.svg
|
|
||||||
:target: https://pypi.org/project/aiohttp
|
|
||||||
:alt: Latest PyPI package version
|
|
||||||
|
|
||||||
.. image:: https://img.shields.io/pypi/dm/aiohttp
|
|
||||||
:target: https://pypistats.org/packages/aiohttp
|
|
||||||
:alt: Downloads count
|
|
||||||
|
|
||||||
.. image:: https://readthedocs.org/projects/aiohttp/badge/?version=latest
|
|
||||||
:target: https://docs.aiohttp.org/
|
|
||||||
:alt: Latest Read The Docs
|
|
||||||
|
|
||||||
.. image:: https://img.shields.io/endpoint?url=https://codspeed.io/badge.json
|
|
||||||
:target: https://codspeed.io/aio-libs/aiohttp
|
|
||||||
:alt: Codspeed.io status for aiohttp
|
|
||||||
|
|
||||||
|
|
||||||
Key Features
|
|
||||||
============
|
|
||||||
|
|
||||||
- Supports both client and server side of HTTP protocol.
|
|
||||||
- Supports both client and server Web-Sockets out-of-the-box and avoids
|
|
||||||
Callback Hell.
|
|
||||||
- Provides Web-server with middleware and pluggable routing.
|
|
||||||
|
|
||||||
|
|
||||||
Getting started
|
|
||||||
===============
|
|
||||||
|
|
||||||
Client
|
|
||||||
------
|
|
||||||
|
|
||||||
To get something from the web:
|
|
||||||
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
import aiohttp
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.get('http://python.org') as response:
|
|
||||||
|
|
||||||
print("Status:", response.status)
|
|
||||||
print("Content-type:", response.headers['content-type'])
|
|
||||||
|
|
||||||
html = await response.text()
|
|
||||||
print("Body:", html[:15], "...")
|
|
||||||
|
|
||||||
asyncio.run(main())
|
|
||||||
|
|
||||||
This prints:
|
|
||||||
|
|
||||||
.. code-block::
|
|
||||||
|
|
||||||
Status: 200
|
|
||||||
Content-type: text/html; charset=utf-8
|
|
||||||
Body: <!doctype html> ...
|
|
||||||
|
|
||||||
Coming from `requests <https://requests.readthedocs.io/>`_ ? Read `why we need so many lines <https://aiohttp.readthedocs.io/en/latest/http_request_lifecycle.html>`_.
|
|
||||||
|
|
||||||
Server
|
|
||||||
------
|
|
||||||
|
|
||||||
An example using a simple server:
|
|
||||||
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
# examples/server_simple.py
|
|
||||||
from aiohttp import web
|
|
||||||
|
|
||||||
async def handle(request):
|
|
||||||
name = request.match_info.get('name', "Anonymous")
|
|
||||||
text = "Hello, " + name
|
|
||||||
return web.Response(text=text)
|
|
||||||
|
|
||||||
async def wshandle(request):
|
|
||||||
ws = web.WebSocketResponse()
|
|
||||||
await ws.prepare(request)
|
|
||||||
|
|
||||||
async for msg in ws:
|
|
||||||
if msg.type == web.WSMsgType.text:
|
|
||||||
await ws.send_str("Hello, {}".format(msg.data))
|
|
||||||
elif msg.type == web.WSMsgType.binary:
|
|
||||||
await ws.send_bytes(msg.data)
|
|
||||||
elif msg.type == web.WSMsgType.close:
|
|
||||||
break
|
|
||||||
|
|
||||||
return ws
|
|
||||||
|
|
||||||
|
|
||||||
app = web.Application()
|
|
||||||
app.add_routes([web.get('/', handle),
|
|
||||||
web.get('/echo', wshandle),
|
|
||||||
web.get('/{name}', handle)])
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
web.run_app(app)
|
|
||||||
|
|
||||||
|
|
||||||
Documentation
|
|
||||||
=============
|
|
||||||
|
|
||||||
https://aiohttp.readthedocs.io/
|
|
||||||
|
|
||||||
|
|
||||||
Demos
|
|
||||||
=====
|
|
||||||
|
|
||||||
https://github.com/aio-libs/aiohttp-demos
|
|
||||||
|
|
||||||
|
|
||||||
External links
|
|
||||||
==============
|
|
||||||
|
|
||||||
* `Third party libraries
|
|
||||||
<http://aiohttp.readthedocs.io/en/latest/third_party.html>`_
|
|
||||||
* `Built with aiohttp
|
|
||||||
<http://aiohttp.readthedocs.io/en/latest/built_with.html>`_
|
|
||||||
* `Powered by aiohttp
|
|
||||||
<http://aiohttp.readthedocs.io/en/latest/powered_by.html>`_
|
|
||||||
|
|
||||||
Feel free to make a Pull Request for adding your link to these pages!
|
|
||||||
|
|
||||||
|
|
||||||
Communication channels
|
|
||||||
======================
|
|
||||||
|
|
||||||
*aio-libs Discussions*: https://github.com/aio-libs/aiohttp/discussions
|
|
||||||
|
|
||||||
*Matrix*: `#aio-libs:matrix.org <https://matrix.to/#/#aio-libs:matrix.org>`_
|
|
||||||
|
|
||||||
We support `Stack Overflow
|
|
||||||
<https://stackoverflow.com/questions/tagged/aiohttp>`_.
|
|
||||||
Please add *aiohttp* tag to your question there.
|
|
||||||
|
|
||||||
Requirements
|
|
||||||
============
|
|
||||||
|
|
||||||
- attrs_
|
|
||||||
- multidict_
|
|
||||||
- yarl_
|
|
||||||
- frozenlist_
|
|
||||||
|
|
||||||
Optionally you may install the aiodns_ library (highly recommended for sake of speed).
|
|
||||||
|
|
||||||
.. _aiodns: https://pypi.python.org/pypi/aiodns
|
|
||||||
.. _attrs: https://github.com/python-attrs/attrs
|
|
||||||
.. _multidict: https://pypi.python.org/pypi/multidict
|
|
||||||
.. _frozenlist: https://pypi.org/project/frozenlist/
|
|
||||||
.. _yarl: https://pypi.python.org/pypi/yarl
|
|
||||||
.. _async-timeout: https://pypi.python.org/pypi/async_timeout
|
|
||||||
|
|
||||||
License
|
|
||||||
=======
|
|
||||||
|
|
||||||
``aiohttp`` is offered under the Apache 2 license.
|
|
||||||
|
|
||||||
|
|
||||||
Keepsafe
|
|
||||||
========
|
|
||||||
|
|
||||||
The aiohttp community would like to thank Keepsafe
|
|
||||||
(https://www.getkeepsafe.com) for its support in the early days of
|
|
||||||
the project.
|
|
||||||
|
|
||||||
|
|
||||||
Source code
|
|
||||||
===========
|
|
||||||
|
|
||||||
The latest developer version is available in a GitHub repository:
|
|
||||||
https://github.com/aio-libs/aiohttp
|
|
||||||
|
|
||||||
Benchmarks
|
|
||||||
==========
|
|
||||||
|
|
||||||
If you are interested in efficiency, the AsyncIO community maintains a
|
|
||||||
list of benchmarks on the official wiki:
|
|
||||||
https://github.com/python/asyncio/wiki/Benchmarks
|
|
||||||
|
|
||||||
--------
|
|
||||||
|
|
||||||
.. image:: https://img.shields.io/matrix/aio-libs:matrix.org?label=Discuss%20on%20Matrix%20at%20%23aio-libs%3Amatrix.org&logo=matrix&server_fqdn=matrix.org&style=flat
|
|
||||||
:target: https://matrix.to/#/%23aio-libs:matrix.org
|
|
||||||
:alt: Matrix Room — #aio-libs:matrix.org
|
|
||||||
|
|
||||||
.. image:: https://img.shields.io/matrix/aio-libs-space:matrix.org?label=Discuss%20on%20Matrix%20at%20%23aio-libs-space%3Amatrix.org&logo=matrix&server_fqdn=matrix.org&style=flat
|
|
||||||
:target: https://matrix.to/#/%23aio-libs-space:matrix.org
|
|
||||||
:alt: Matrix Space — #aio-libs-space:matrix.org
|
|
||||||
|
|
||||||
.. image:: https://insights.linuxfoundation.org/api/badge/health-score?project=aiohttp
|
|
||||||
:target: https://insights.linuxfoundation.org/project/aiohttp
|
|
||||||
:alt: LFX Health Score
|
|
||||||
|
|
@ -1,139 +0,0 @@
|
||||||
aiohttp-3.13.2.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
|
|
||||||
aiohttp-3.13.2.dist-info/METADATA,sha256=3xr8ZyYTInh909TqCdZhKIC37g5nTgyP-Nj_yglCs5A,8135
|
|
||||||
aiohttp-3.13.2.dist-info/RECORD,,
|
|
||||||
aiohttp-3.13.2.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
||||||
aiohttp-3.13.2.dist-info/WHEEL,sha256=nENvFvUt2sTxh7qTwFTbrHft1Jd6WkcTog-2x3-pWGY,193
|
|
||||||
aiohttp-3.13.2.dist-info/licenses/LICENSE.txt,sha256=n4DQ2311WpQdtFchcsJw7L2PCCuiFd3QlZhZQu2Uqes,588
|
|
||||||
aiohttp-3.13.2.dist-info/licenses/vendor/llhttp/LICENSE,sha256=68qFTgE0zSVtZzYnwgSZ9CV363S6zwi58ltianPJEnc,1105
|
|
||||||
aiohttp-3.13.2.dist-info/top_level.txt,sha256=iv-JIaacmTl-hSho3QmphcKnbRRYx1st47yjz_178Ro,8
|
|
||||||
aiohttp/.hash/_cparser.pxd.hash,sha256=pjs-sEXNw_eijXGAedwG-BHnlFp8B7sOCgUagIWaU2A,121
|
|
||||||
aiohttp/.hash/_find_header.pxd.hash,sha256=_mbpD6vM-CVCKq3ulUvsOAz5Wdo88wrDzfpOsMQaMNA,125
|
|
||||||
aiohttp/.hash/_http_parser.pyx.hash,sha256=ju4DG_uNv8rTD6pu3IunE1ysx3ZbH4OjiQHUb_URSoA,125
|
|
||||||
aiohttp/.hash/_http_writer.pyx.hash,sha256=9txOh7t7c3y-vLmiuEY5dltmXvEo0CYyU4U853yyv9E,125
|
|
||||||
aiohttp/.hash/hdrs.py.hash,sha256=v6IaKbsxjsdQxBzhb5AjP0x_9G3rUe84D7avf7AI4cs,116
|
|
||||||
aiohttp/__init__.py,sha256=YJ2jOOSU0hSTbloGbi5-jtcDfSmpBp2RTQEQAt0ccOA,8302
|
|
||||||
aiohttp/__pycache__/__init__.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/_cookie_helpers.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/abc.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/base_protocol.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/client.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/client_exceptions.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/client_middleware_digest_auth.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/client_middlewares.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/client_proto.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/client_reqrep.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/client_ws.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/compression_utils.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/connector.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/cookiejar.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/formdata.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/hdrs.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/helpers.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/http.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/http_exceptions.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/http_parser.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/http_websocket.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/http_writer.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/log.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/multipart.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/payload.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/payload_streamer.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/pytest_plugin.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/resolver.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/streams.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/tcp_helpers.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/test_utils.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/tracing.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/typedefs.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/web.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/web_app.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/web_exceptions.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/web_fileresponse.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/web_log.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/web_middlewares.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/web_protocol.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/web_request.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/web_response.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/web_routedef.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/web_runner.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/web_server.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/web_urldispatcher.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/web_ws.cpython-311.pyc,,
|
|
||||||
aiohttp/__pycache__/worker.cpython-311.pyc,,
|
|
||||||
aiohttp/_cookie_helpers.py,sha256=INC-1MTQU7yJqBVmV48Fw30kzZH47KdVUrP_bbfpGvs,13647
|
|
||||||
aiohttp/_cparser.pxd,sha256=UnbUYCHg4NdXfgyRVYAMv2KTLWClB4P-xCrvtj_r7ew,4295
|
|
||||||
aiohttp/_find_header.pxd,sha256=0GfwFCPN2zxEKTO1_MA5sYq2UfzsG8kcV3aTqvwlz3g,68
|
|
||||||
aiohttp/_headers.pxi,sha256=n701k28dVPjwRnx5j6LpJhLTfj7dqu2vJt7f0O60Oyg,2007
|
|
||||||
aiohttp/_http_parser.cpython-311-aarch64-linux-gnu.so,sha256=QmiJqs9g4Hu-wOIfGNAoXpkMjbxoSmvSO40tQZ5v-i0,2965864
|
|
||||||
aiohttp/_http_parser.pyx,sha256=tmA1PaJn7H8U1nyXtoHJV44pxYVzqXAf1UgJaYPaw28,28219
|
|
||||||
aiohttp/_http_writer.cpython-311-aarch64-linux-gnu.so,sha256=VKki78U6dDv7i02lTWx0HSzhb_hCThwQoNGWhdTjZ6g,664344
|
|
||||||
aiohttp/_http_writer.pyx,sha256=VlFEBM6HoVv8a0AAJtc6JwFlsv2-cDE8-gB94p3dfhQ,4664
|
|
||||||
aiohttp/_websocket/.hash/mask.pxd.hash,sha256=Y0zBddk_ck3pi9-BFzMcpkcvCKvwvZ4GTtZFb9u1nxQ,128
|
|
||||||
aiohttp/_websocket/.hash/mask.pyx.hash,sha256=90owpXYM8_kIma4KUcOxhWSk-Uv4NVMBoCYeFM1B3d0,128
|
|
||||||
aiohttp/_websocket/.hash/reader_c.pxd.hash,sha256=5xf3oobk6vx4xbJm-xtZ1_QufB8fYFtLQV2MNdqUc1w,132
|
|
||||||
aiohttp/_websocket/__init__.py,sha256=Mar3R9_vBN_Ea4lsW7iTAVXD7OKswKPGqF5xgSyt77k,44
|
|
||||||
aiohttp/_websocket/__pycache__/__init__.cpython-311.pyc,,
|
|
||||||
aiohttp/_websocket/__pycache__/helpers.cpython-311.pyc,,
|
|
||||||
aiohttp/_websocket/__pycache__/models.cpython-311.pyc,,
|
|
||||||
aiohttp/_websocket/__pycache__/reader.cpython-311.pyc,,
|
|
||||||
aiohttp/_websocket/__pycache__/reader_c.cpython-311.pyc,,
|
|
||||||
aiohttp/_websocket/__pycache__/reader_py.cpython-311.pyc,,
|
|
||||||
aiohttp/_websocket/__pycache__/writer.cpython-311.pyc,,
|
|
||||||
aiohttp/_websocket/helpers.py,sha256=P-XLv8IUaihKzDenVUqfKU5DJbWE5HvG8uhvUZK8Ic4,5038
|
|
||||||
aiohttp/_websocket/mask.cpython-311-aarch64-linux-gnu.so,sha256=93nWiDT7VdSHaZlaqBZQFe5mURm0B-yVTvAkUHGazxU,406752
|
|
||||||
aiohttp/_websocket/mask.pxd,sha256=sBmZ1Amym9kW4Ge8lj1fLZ7mPPya4LzLdpkQExQXv5M,112
|
|
||||||
aiohttp/_websocket/mask.pyx,sha256=BHjOtV0O0w7xp9p0LNADRJvGmgfPn9sGeJvSs0fL__4,1397
|
|
||||||
aiohttp/_websocket/models.py,sha256=XAzjs_8JYszWXIgZ6R3ZRrF-tX9Q_6LiD49WRYojopM,2121
|
|
||||||
aiohttp/_websocket/reader.py,sha256=eC4qS0c5sOeQ2ebAHLaBpIaTVFaSKX79pY2xvh3Pqyw,1030
|
|
||||||
aiohttp/_websocket/reader_c.cpython-311-aarch64-linux-gnu.so,sha256=irbrCCCSS_hN6L3hw9okjPhF9MCAHpafU53rC_pipA4,2064488
|
|
||||||
aiohttp/_websocket/reader_c.pxd,sha256=nl_njtDrzlQU0rjgGGjZDB-swguE0tX_bCPobkShVa4,2625
|
|
||||||
aiohttp/_websocket/reader_c.py,sha256=gSsE_iSBr7-ORvOmgkCT7Jpj4_j3854i_Cp88Se1_6E,18791
|
|
||||||
aiohttp/_websocket/reader_py.py,sha256=gSsE_iSBr7-ORvOmgkCT7Jpj4_j3854i_Cp88Se1_6E,18791
|
|
||||||
aiohttp/_websocket/writer.py,sha256=2OvSktPmNh_g20h1cXJt2Xu8u6IvswnPjdur7OwBbJk,11261
|
|
||||||
aiohttp/abc.py,sha256=M66F4S6m00bIEn7y4ha_XLTMDmVQ9dPihfOVB0pGfOo,7149
|
|
||||||
aiohttp/base_protocol.py,sha256=Tp8cxUPQvv9kUPk3w6lAzk6d2MAzV3scwI_3Go3C47c,3025
|
|
||||||
aiohttp/client.py,sha256=fOQfwcIUL1NGAVRV4DDj6-wipBzeD8KZpmzhO-LLKp4,58357
|
|
||||||
aiohttp/client_exceptions.py,sha256=uyKbxI2peZhKl7lELBMx3UeusNkfpemPWpGFq0r6JeM,11367
|
|
||||||
aiohttp/client_middleware_digest_auth.py,sha256=BIoQJ5eWL5NNkPOmezTGrceWIho8ETDvS8NKvX-3Xdw,17088
|
|
||||||
aiohttp/client_middlewares.py,sha256=kP5N9CMzQPMGPIEydeVUiLUTLsw8Vl8Gr4qAWYdu3vM,1918
|
|
||||||
aiohttp/client_proto.py,sha256=56_WtLStZGBFPYKzgEgY6v24JkhV1y6JEmmuxeJT2So,12110
|
|
||||||
aiohttp/client_reqrep.py,sha256=eEREDrZ0M8ZFTt1wjHduR-P8_sm40K65gNz-iMGYask,53391
|
|
||||||
aiohttp/client_ws.py,sha256=1CIjIXwyzOMIYw6AjUES4-qUwbyVHW1seJKQfg_Rta8,15109
|
|
||||||
aiohttp/compression_utils.py,sha256=Cmn4bim6iDYUST1Fp66EBRDzIz_3gUQBLg4HkbEljrc,10408
|
|
||||||
aiohttp/connector.py,sha256=WQetKoSW7XnHA9r4o9OWwO3-n7ymOwBd2Tg_xHNw0Bs,68456
|
|
||||||
aiohttp/cookiejar.py,sha256=e28ZMQwJ5P0vbPX1OX4Se7-k3zeGvocFEqzGhwpG53k,18922
|
|
||||||
aiohttp/formdata.py,sha256=xqYMbUo1qoLYPuzY92XeR4pyEe-w-DNcToARDF3GUhA,6384
|
|
||||||
aiohttp/hdrs.py,sha256=2rj5MyA-6yRdYPhW5UKkW4iNWhEAlGIOSBH5D4FmKNE,5111
|
|
||||||
aiohttp/helpers.py,sha256=Q1307PCEnWz4RP8crUw8dk58c0YF2Ei3JywkKfRxz5E,30629
|
|
||||||
aiohttp/http.py,sha256=8o8j8xH70OWjnfTWA9V44NR785QPxEPrUtzMXiAVpwc,1842
|
|
||||||
aiohttp/http_exceptions.py,sha256=AZafFHgtAkAgrKZf8zYPU8VX2dq32-VAoP-UZxBLU0c,2960
|
|
||||||
aiohttp/http_parser.py,sha256=fACBNI47n9hnVPWfm5AJufuezsoYOF_VLp4bptYjvQI,37377
|
|
||||||
aiohttp/http_websocket.py,sha256=8VXFKw6KQUEmPg48GtRMB37v0gTK7A0inoxXuDxMZEc,842
|
|
||||||
aiohttp/http_writer.py,sha256=fbRtKPYSqRbtAdr_gqpjF2-4sI1ESL8dPDF-xY_mAMY,12446
|
|
||||||
aiohttp/log.py,sha256=BbNKx9e3VMIm0xYjZI0IcBBoS7wjdeIeSaiJE7-qK2g,325
|
|
||||||
aiohttp/multipart.py,sha256=6q6QRjKFVqaWzTbc7bkuBtXsTaQq5b2BhHxLBvAElac,40040
|
|
||||||
aiohttp/payload.py,sha256=O6nsYNULL7AeM2cyJ6TYX73ncVnL5xJwt5AegxwMKqw,40874
|
|
||||||
aiohttp/payload_streamer.py,sha256=ZzEYyfzcjGWkVkK3XR2pBthSCSIykYvY3Wr5cGQ2eTc,2211
|
|
||||||
aiohttp/py.typed,sha256=sow9soTwP9T_gEAQSVh7Gb8855h04Nwmhs2We-JRgZM,7
|
|
||||||
aiohttp/pytest_plugin.py,sha256=z4XwqmsKdyJCKxbGiA5kFf90zcedvomqk4RqjZbhKNk,12901
|
|
||||||
aiohttp/resolver.py,sha256=gsrfUpFf8iHlcHfJvY-1fiBHW3PRvRVNb5lNZBg3zlY,10031
|
|
||||||
aiohttp/streams.py,sha256=cQxo6Fyu_HDWDpbezGRVPIVYtVtTbSLRF7g511DNmSs,22601
|
|
||||||
aiohttp/tcp_helpers.py,sha256=BSadqVWaBpMFDRWnhaaR941N9MiDZ7bdTrxgCb0CW-M,961
|
|
||||||
aiohttp/test_utils.py,sha256=ZJSzZWjC76KSbtwddTKcP6vHpUl_ozfAf3F93ewmHRU,23016
|
|
||||||
aiohttp/tracing.py,sha256=-6aaW6l0J9uJD45LzR4cijYH0j62pt0U_nn_aVzFku4,14558
|
|
||||||
aiohttp/typedefs.py,sha256=wUlqwe9Mw9W8jT3HsYJcYk00qP3EMPz3nTkYXmeNN48,1657
|
|
||||||
aiohttp/web.py,sha256=JzSNmejg5G6YeFAnkIgZfytqbU86sNu844yYKmoUpqs,17852
|
|
||||||
aiohttp/web_app.py,sha256=lGU_aAMN-h3wy-LTTHi6SeKH8ydt1G51BXcCspgD5ZA,19452
|
|
||||||
aiohttp/web_exceptions.py,sha256=7nIuiwhZ39vJJ9KrWqArA5QcWbUdqkz2CLwEpJapeN8,10360
|
|
||||||
aiohttp/web_fileresponse.py,sha256=Xzau8EMrWNrFg3u46h4UEteg93G4zYq94CU6vy0HiqE,16362
|
|
||||||
aiohttp/web_log.py,sha256=rX5D7xLOX2B6BMdiZ-chme_KfJfW5IXEoFwLfkfkajs,7865
|
|
||||||
aiohttp/web_middlewares.py,sha256=sFI0AgeNjdyAjuz92QtMIpngmJSOxrqe2Jfbs4BNUu0,4165
|
|
||||||
aiohttp/web_protocol.py,sha256=c8a0PKGqfhIAiq2RboMsy1NRza4dnj6gnXIWvJUeCF0,27015
|
|
||||||
aiohttp/web_request.py,sha256=zN96OlMRlrCFOMRpdh7y9rvHP0Hm8zavC0OFCj0wlSg,29833
|
|
||||||
aiohttp/web_response.py,sha256=PKcziNU4LmftXqKVvoRMrAbOeVClpSN-iznHsiWezmU,29341
|
|
||||||
aiohttp/web_routedef.py,sha256=VT1GAx6BrawoDh5RwBwBu5wSABSqgWwAe74AUCyZAEo,6110
|
|
||||||
aiohttp/web_runner.py,sha256=v1G1nKiOOQgFnTSR4IMc6I9ReEFDMaHtMLvO_roDM-A,11786
|
|
||||||
aiohttp/web_server.py,sha256=-9WDKUAiR9ll-rSdwXSqG6YjaoW79d1R4y0BGSqgUMA,2888
|
|
||||||
aiohttp/web_urldispatcher.py,sha256=3ryu1ZOpcq79IYNMd6EjYWmQ_i6JbsJzS_IaV0yoYBg,44203
|
|
||||||
aiohttp/web_ws.py,sha256=lItgmyatkXh0M6EY7JoZnSZkUl6R0wv8B88X4ILqQbU,22739
|
|
||||||
aiohttp/worker.py,sha256=zT0iWN5Xze194bO6_VjHou0x7lR_k0MviN6Kadnk22g,8152
|
|
||||||
|
|
@ -1,7 +0,0 @@
|
||||||
Wheel-Version: 1.0
|
|
||||||
Generator: setuptools (80.9.0)
|
|
||||||
Root-Is-Purelib: false
|
|
||||||
Tag: cp311-cp311-manylinux_2_17_aarch64
|
|
||||||
Tag: cp311-cp311-manylinux2014_aarch64
|
|
||||||
Tag: cp311-cp311-manylinux_2_28_aarch64
|
|
||||||
|
|
||||||
|
|
@ -1,13 +0,0 @@
|
||||||
Copyright aio-libs contributors.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
|
|
@ -1,22 +0,0 @@
|
||||||
This software is licensed under the MIT License.
|
|
||||||
|
|
||||||
Copyright Fedor Indutny, 2018.
|
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a
|
|
||||||
copy of this software and associated documentation files (the
|
|
||||||
"Software"), to deal in the Software without restriction, including
|
|
||||||
without limitation the rights to use, copy, modify, merge, publish,
|
|
||||||
distribute, sublicense, and/or sell copies of the Software, and to permit
|
|
||||||
persons to whom the Software is furnished to do so, subject to the
|
|
||||||
following conditions:
|
|
||||||
|
|
||||||
The above copyright notice and this permission notice shall be included
|
|
||||||
in all copies or substantial portions of the Software.
|
|
||||||
|
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
|
||||||
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
|
||||||
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN
|
|
||||||
NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
|
|
||||||
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
|
|
||||||
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
|
|
||||||
USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
aiohttp
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
5276d46021e0e0d7577e0c9155800cbf62932d60a50783fec42aefb63febedec /home/runner/work/aiohttp/aiohttp/aiohttp/_cparser.pxd
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
d067f01423cddb3c442933b5fcc039b18ab651fcec1bc91c577693aafc25cf78 /home/runner/work/aiohttp/aiohttp/aiohttp/_find_header.pxd
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
b660353da267ec7f14d67c97b681c9578e29c58573a9701fd548096983dac36f /home/runner/work/aiohttp/aiohttp/aiohttp/_http_parser.pyx
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
56514404ce87a15bfc6b400026d73a270165b2fdbe70313cfa007de29ddd7e14 /home/runner/work/aiohttp/aiohttp/aiohttp/_http_writer.pyx
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
dab8f933203eeb245d60f856e542a45b888d5a110094620e4811f90f816628d1 /home/runner/work/aiohttp/aiohttp/aiohttp/hdrs.py
|
|
||||||
|
|
@ -1,278 +0,0 @@
|
||||||
__version__ = "3.13.2"
|
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Tuple
|
|
||||||
|
|
||||||
from . import hdrs as hdrs
|
|
||||||
from .client import (
|
|
||||||
BaseConnector,
|
|
||||||
ClientConnectionError,
|
|
||||||
ClientConnectionResetError,
|
|
||||||
ClientConnectorCertificateError,
|
|
||||||
ClientConnectorDNSError,
|
|
||||||
ClientConnectorError,
|
|
||||||
ClientConnectorSSLError,
|
|
||||||
ClientError,
|
|
||||||
ClientHttpProxyError,
|
|
||||||
ClientOSError,
|
|
||||||
ClientPayloadError,
|
|
||||||
ClientProxyConnectionError,
|
|
||||||
ClientRequest,
|
|
||||||
ClientResponse,
|
|
||||||
ClientResponseError,
|
|
||||||
ClientSession,
|
|
||||||
ClientSSLError,
|
|
||||||
ClientTimeout,
|
|
||||||
ClientWebSocketResponse,
|
|
||||||
ClientWSTimeout,
|
|
||||||
ConnectionTimeoutError,
|
|
||||||
ContentTypeError,
|
|
||||||
Fingerprint,
|
|
||||||
InvalidURL,
|
|
||||||
InvalidUrlClientError,
|
|
||||||
InvalidUrlRedirectClientError,
|
|
||||||
NamedPipeConnector,
|
|
||||||
NonHttpUrlClientError,
|
|
||||||
NonHttpUrlRedirectClientError,
|
|
||||||
RedirectClientError,
|
|
||||||
RequestInfo,
|
|
||||||
ServerConnectionError,
|
|
||||||
ServerDisconnectedError,
|
|
||||||
ServerFingerprintMismatch,
|
|
||||||
ServerTimeoutError,
|
|
||||||
SocketTimeoutError,
|
|
||||||
TCPConnector,
|
|
||||||
TooManyRedirects,
|
|
||||||
UnixConnector,
|
|
||||||
WSMessageTypeError,
|
|
||||||
WSServerHandshakeError,
|
|
||||||
request,
|
|
||||||
)
|
|
||||||
from .client_middleware_digest_auth import DigestAuthMiddleware
|
|
||||||
from .client_middlewares import ClientHandlerType, ClientMiddlewareType
|
|
||||||
from .compression_utils import set_zlib_backend
|
|
||||||
from .connector import (
|
|
||||||
AddrInfoType as AddrInfoType,
|
|
||||||
SocketFactoryType as SocketFactoryType,
|
|
||||||
)
|
|
||||||
from .cookiejar import CookieJar as CookieJar, DummyCookieJar as DummyCookieJar
|
|
||||||
from .formdata import FormData as FormData
|
|
||||||
from .helpers import BasicAuth, ChainMapProxy, ETag
|
|
||||||
from .http import (
|
|
||||||
HttpVersion as HttpVersion,
|
|
||||||
HttpVersion10 as HttpVersion10,
|
|
||||||
HttpVersion11 as HttpVersion11,
|
|
||||||
WebSocketError as WebSocketError,
|
|
||||||
WSCloseCode as WSCloseCode,
|
|
||||||
WSMessage as WSMessage,
|
|
||||||
WSMsgType as WSMsgType,
|
|
||||||
)
|
|
||||||
from .multipart import (
|
|
||||||
BadContentDispositionHeader as BadContentDispositionHeader,
|
|
||||||
BadContentDispositionParam as BadContentDispositionParam,
|
|
||||||
BodyPartReader as BodyPartReader,
|
|
||||||
MultipartReader as MultipartReader,
|
|
||||||
MultipartWriter as MultipartWriter,
|
|
||||||
content_disposition_filename as content_disposition_filename,
|
|
||||||
parse_content_disposition as parse_content_disposition,
|
|
||||||
)
|
|
||||||
from .payload import (
|
|
||||||
PAYLOAD_REGISTRY as PAYLOAD_REGISTRY,
|
|
||||||
AsyncIterablePayload as AsyncIterablePayload,
|
|
||||||
BufferedReaderPayload as BufferedReaderPayload,
|
|
||||||
BytesIOPayload as BytesIOPayload,
|
|
||||||
BytesPayload as BytesPayload,
|
|
||||||
IOBasePayload as IOBasePayload,
|
|
||||||
JsonPayload as JsonPayload,
|
|
||||||
Payload as Payload,
|
|
||||||
StringIOPayload as StringIOPayload,
|
|
||||||
StringPayload as StringPayload,
|
|
||||||
TextIOPayload as TextIOPayload,
|
|
||||||
get_payload as get_payload,
|
|
||||||
payload_type as payload_type,
|
|
||||||
)
|
|
||||||
from .payload_streamer import streamer as streamer
|
|
||||||
from .resolver import (
|
|
||||||
AsyncResolver as AsyncResolver,
|
|
||||||
DefaultResolver as DefaultResolver,
|
|
||||||
ThreadedResolver as ThreadedResolver,
|
|
||||||
)
|
|
||||||
from .streams import (
|
|
||||||
EMPTY_PAYLOAD as EMPTY_PAYLOAD,
|
|
||||||
DataQueue as DataQueue,
|
|
||||||
EofStream as EofStream,
|
|
||||||
FlowControlDataQueue as FlowControlDataQueue,
|
|
||||||
StreamReader as StreamReader,
|
|
||||||
)
|
|
||||||
from .tracing import (
|
|
||||||
TraceConfig as TraceConfig,
|
|
||||||
TraceConnectionCreateEndParams as TraceConnectionCreateEndParams,
|
|
||||||
TraceConnectionCreateStartParams as TraceConnectionCreateStartParams,
|
|
||||||
TraceConnectionQueuedEndParams as TraceConnectionQueuedEndParams,
|
|
||||||
TraceConnectionQueuedStartParams as TraceConnectionQueuedStartParams,
|
|
||||||
TraceConnectionReuseconnParams as TraceConnectionReuseconnParams,
|
|
||||||
TraceDnsCacheHitParams as TraceDnsCacheHitParams,
|
|
||||||
TraceDnsCacheMissParams as TraceDnsCacheMissParams,
|
|
||||||
TraceDnsResolveHostEndParams as TraceDnsResolveHostEndParams,
|
|
||||||
TraceDnsResolveHostStartParams as TraceDnsResolveHostStartParams,
|
|
||||||
TraceRequestChunkSentParams as TraceRequestChunkSentParams,
|
|
||||||
TraceRequestEndParams as TraceRequestEndParams,
|
|
||||||
TraceRequestExceptionParams as TraceRequestExceptionParams,
|
|
||||||
TraceRequestHeadersSentParams as TraceRequestHeadersSentParams,
|
|
||||||
TraceRequestRedirectParams as TraceRequestRedirectParams,
|
|
||||||
TraceRequestStartParams as TraceRequestStartParams,
|
|
||||||
TraceResponseChunkReceivedParams as TraceResponseChunkReceivedParams,
|
|
||||||
)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
# At runtime these are lazy-loaded at the bottom of the file.
|
|
||||||
from .worker import (
|
|
||||||
GunicornUVLoopWebWorker as GunicornUVLoopWebWorker,
|
|
||||||
GunicornWebWorker as GunicornWebWorker,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__: Tuple[str, ...] = (
|
|
||||||
"hdrs",
|
|
||||||
# client
|
|
||||||
"AddrInfoType",
|
|
||||||
"BaseConnector",
|
|
||||||
"ClientConnectionError",
|
|
||||||
"ClientConnectionResetError",
|
|
||||||
"ClientConnectorCertificateError",
|
|
||||||
"ClientConnectorDNSError",
|
|
||||||
"ClientConnectorError",
|
|
||||||
"ClientConnectorSSLError",
|
|
||||||
"ClientError",
|
|
||||||
"ClientHttpProxyError",
|
|
||||||
"ClientOSError",
|
|
||||||
"ClientPayloadError",
|
|
||||||
"ClientProxyConnectionError",
|
|
||||||
"ClientResponse",
|
|
||||||
"ClientRequest",
|
|
||||||
"ClientResponseError",
|
|
||||||
"ClientSSLError",
|
|
||||||
"ClientSession",
|
|
||||||
"ClientTimeout",
|
|
||||||
"ClientWebSocketResponse",
|
|
||||||
"ClientWSTimeout",
|
|
||||||
"ConnectionTimeoutError",
|
|
||||||
"ContentTypeError",
|
|
||||||
"Fingerprint",
|
|
||||||
"FlowControlDataQueue",
|
|
||||||
"InvalidURL",
|
|
||||||
"InvalidUrlClientError",
|
|
||||||
"InvalidUrlRedirectClientError",
|
|
||||||
"NonHttpUrlClientError",
|
|
||||||
"NonHttpUrlRedirectClientError",
|
|
||||||
"RedirectClientError",
|
|
||||||
"RequestInfo",
|
|
||||||
"ServerConnectionError",
|
|
||||||
"ServerDisconnectedError",
|
|
||||||
"ServerFingerprintMismatch",
|
|
||||||
"ServerTimeoutError",
|
|
||||||
"SocketFactoryType",
|
|
||||||
"SocketTimeoutError",
|
|
||||||
"TCPConnector",
|
|
||||||
"TooManyRedirects",
|
|
||||||
"UnixConnector",
|
|
||||||
"NamedPipeConnector",
|
|
||||||
"WSServerHandshakeError",
|
|
||||||
"request",
|
|
||||||
# client_middleware
|
|
||||||
"ClientMiddlewareType",
|
|
||||||
"ClientHandlerType",
|
|
||||||
# cookiejar
|
|
||||||
"CookieJar",
|
|
||||||
"DummyCookieJar",
|
|
||||||
# formdata
|
|
||||||
"FormData",
|
|
||||||
# helpers
|
|
||||||
"BasicAuth",
|
|
||||||
"ChainMapProxy",
|
|
||||||
"DigestAuthMiddleware",
|
|
||||||
"ETag",
|
|
||||||
"set_zlib_backend",
|
|
||||||
# http
|
|
||||||
"HttpVersion",
|
|
||||||
"HttpVersion10",
|
|
||||||
"HttpVersion11",
|
|
||||||
"WSMsgType",
|
|
||||||
"WSCloseCode",
|
|
||||||
"WSMessage",
|
|
||||||
"WebSocketError",
|
|
||||||
# multipart
|
|
||||||
"BadContentDispositionHeader",
|
|
||||||
"BadContentDispositionParam",
|
|
||||||
"BodyPartReader",
|
|
||||||
"MultipartReader",
|
|
||||||
"MultipartWriter",
|
|
||||||
"content_disposition_filename",
|
|
||||||
"parse_content_disposition",
|
|
||||||
# payload
|
|
||||||
"AsyncIterablePayload",
|
|
||||||
"BufferedReaderPayload",
|
|
||||||
"BytesIOPayload",
|
|
||||||
"BytesPayload",
|
|
||||||
"IOBasePayload",
|
|
||||||
"JsonPayload",
|
|
||||||
"PAYLOAD_REGISTRY",
|
|
||||||
"Payload",
|
|
||||||
"StringIOPayload",
|
|
||||||
"StringPayload",
|
|
||||||
"TextIOPayload",
|
|
||||||
"get_payload",
|
|
||||||
"payload_type",
|
|
||||||
# payload_streamer
|
|
||||||
"streamer",
|
|
||||||
# resolver
|
|
||||||
"AsyncResolver",
|
|
||||||
"DefaultResolver",
|
|
||||||
"ThreadedResolver",
|
|
||||||
# streams
|
|
||||||
"DataQueue",
|
|
||||||
"EMPTY_PAYLOAD",
|
|
||||||
"EofStream",
|
|
||||||
"StreamReader",
|
|
||||||
# tracing
|
|
||||||
"TraceConfig",
|
|
||||||
"TraceConnectionCreateEndParams",
|
|
||||||
"TraceConnectionCreateStartParams",
|
|
||||||
"TraceConnectionQueuedEndParams",
|
|
||||||
"TraceConnectionQueuedStartParams",
|
|
||||||
"TraceConnectionReuseconnParams",
|
|
||||||
"TraceDnsCacheHitParams",
|
|
||||||
"TraceDnsCacheMissParams",
|
|
||||||
"TraceDnsResolveHostEndParams",
|
|
||||||
"TraceDnsResolveHostStartParams",
|
|
||||||
"TraceRequestChunkSentParams",
|
|
||||||
"TraceRequestEndParams",
|
|
||||||
"TraceRequestExceptionParams",
|
|
||||||
"TraceRequestHeadersSentParams",
|
|
||||||
"TraceRequestRedirectParams",
|
|
||||||
"TraceRequestStartParams",
|
|
||||||
"TraceResponseChunkReceivedParams",
|
|
||||||
# workers (imported lazily with __getattr__)
|
|
||||||
"GunicornUVLoopWebWorker",
|
|
||||||
"GunicornWebWorker",
|
|
||||||
"WSMessageTypeError",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def __dir__() -> Tuple[str, ...]:
|
|
||||||
return __all__ + ("__doc__",)
|
|
||||||
|
|
||||||
|
|
||||||
def __getattr__(name: str) -> object:
|
|
||||||
global GunicornUVLoopWebWorker, GunicornWebWorker
|
|
||||||
|
|
||||||
# Importing gunicorn takes a long time (>100ms), so only import if actually needed.
|
|
||||||
if name in ("GunicornUVLoopWebWorker", "GunicornWebWorker"):
|
|
||||||
try:
|
|
||||||
from .worker import GunicornUVLoopWebWorker as guv, GunicornWebWorker as gw
|
|
||||||
except ImportError:
|
|
||||||
return None
|
|
||||||
|
|
||||||
GunicornUVLoopWebWorker = guv # type: ignore[misc]
|
|
||||||
GunicornWebWorker = gw # type: ignore[misc]
|
|
||||||
return guv if name == "GunicornUVLoopWebWorker" else gw
|
|
||||||
|
|
||||||
raise AttributeError(f"module {__name__} has no attribute {name}")
|
|
||||||
|
|
@ -1,334 +0,0 @@
|
||||||
"""
|
|
||||||
Internal cookie handling helpers.
|
|
||||||
|
|
||||||
This module contains internal utilities for cookie parsing and manipulation.
|
|
||||||
These are not part of the public API and may change without notice.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import re
|
|
||||||
from http.cookies import Morsel
|
|
||||||
from typing import List, Optional, Sequence, Tuple, cast
|
|
||||||
|
|
||||||
from .log import internal_logger
|
|
||||||
|
|
||||||
__all__ = (
|
|
||||||
"parse_set_cookie_headers",
|
|
||||||
"parse_cookie_header",
|
|
||||||
"preserve_morsel_with_coded_value",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Cookie parsing constants
|
|
||||||
# Allow more characters in cookie names to handle real-world cookies
|
|
||||||
# that don't strictly follow RFC standards (fixes #2683)
|
|
||||||
# RFC 6265 defines cookie-name token as per RFC 2616 Section 2.2,
|
|
||||||
# but many servers send cookies with characters like {} [] () etc.
|
|
||||||
# This makes the cookie parser more tolerant of real-world cookies
|
|
||||||
# while still providing some validation to catch obviously malformed names.
|
|
||||||
_COOKIE_NAME_RE = re.compile(r"^[!#$%&\'()*+\-./0-9:<=>?@A-Z\[\]^_`a-z{|}~]+$")
|
|
||||||
_COOKIE_KNOWN_ATTRS = frozenset( # AKA Morsel._reserved
|
|
||||||
(
|
|
||||||
"path",
|
|
||||||
"domain",
|
|
||||||
"max-age",
|
|
||||||
"expires",
|
|
||||||
"secure",
|
|
||||||
"httponly",
|
|
||||||
"samesite",
|
|
||||||
"partitioned",
|
|
||||||
"version",
|
|
||||||
"comment",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
_COOKIE_BOOL_ATTRS = frozenset( # AKA Morsel._flags
|
|
||||||
("secure", "httponly", "partitioned")
|
|
||||||
)
|
|
||||||
|
|
||||||
# SimpleCookie's pattern for parsing cookies with relaxed validation
|
|
||||||
# Based on http.cookies pattern but extended to allow more characters in cookie names
|
|
||||||
# to handle real-world cookies (fixes #2683)
|
|
||||||
_COOKIE_PATTERN = re.compile(
|
|
||||||
r"""
|
|
||||||
\s* # Optional whitespace at start of cookie
|
|
||||||
(?P<key> # Start of group 'key'
|
|
||||||
# aiohttp has extended to include [] for compatibility with real-world cookies
|
|
||||||
[\w\d!#%&'~_`><@,:/\$\*\+\-\.\^\|\)\(\?\}\{\=\[\]]+? # Any word of at least one letter
|
|
||||||
) # End of group 'key'
|
|
||||||
( # Optional group: there may not be a value.
|
|
||||||
\s*=\s* # Equal Sign
|
|
||||||
(?P<val> # Start of group 'val'
|
|
||||||
"(?:[^\\"]|\\.)*" # Any double-quoted string (properly closed)
|
|
||||||
| # or
|
|
||||||
"[^";]* # Unmatched opening quote (differs from SimpleCookie - issue #7993)
|
|
||||||
| # or
|
|
||||||
# Special case for "expires" attr - RFC 822, RFC 850, RFC 1036, RFC 1123
|
|
||||||
(\w{3,6}day|\w{3}),\s # Day of the week or abbreviated day (with comma)
|
|
||||||
[\w\d\s-]{9,11}\s[\d:]{8}\s # Date and time in specific format
|
|
||||||
(GMT|[+-]\d{4}) # Timezone: GMT or RFC 2822 offset like -0000, +0100
|
|
||||||
# NOTE: RFC 2822 timezone support is an aiohttp extension
|
|
||||||
# for issue #4493 - SimpleCookie does NOT support this
|
|
||||||
| # or
|
|
||||||
# ANSI C asctime() format: "Wed Jun 9 10:18:14 2021"
|
|
||||||
# NOTE: This is an aiohttp extension for issue #4327 - SimpleCookie does NOT support this format
|
|
||||||
\w{3}\s+\w{3}\s+[\s\d]\d\s+\d{2}:\d{2}:\d{2}\s+\d{4}
|
|
||||||
| # or
|
|
||||||
[\w\d!#%&'~_`><@,:/\$\*\+\-\.\^\|\)\(\?\}\{\=\[\]]* # Any word or empty string
|
|
||||||
) # End of group 'val'
|
|
||||||
)? # End of optional value group
|
|
||||||
\s* # Any number of spaces.
|
|
||||||
(\s+|;|$) # Ending either at space, semicolon, or EOS.
|
|
||||||
""",
|
|
||||||
re.VERBOSE | re.ASCII,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def preserve_morsel_with_coded_value(cookie: Morsel[str]) -> Morsel[str]:
|
|
||||||
"""
|
|
||||||
Preserve a Morsel's coded_value exactly as received from the server.
|
|
||||||
|
|
||||||
This function ensures that cookie encoding is preserved exactly as sent by
|
|
||||||
the server, which is critical for compatibility with old servers that have
|
|
||||||
strict requirements about cookie formats.
|
|
||||||
|
|
||||||
This addresses the issue described in https://github.com/aio-libs/aiohttp/pull/1453
|
|
||||||
where Python's SimpleCookie would re-encode cookies, breaking authentication
|
|
||||||
with certain servers.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cookie: A Morsel object from SimpleCookie
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A Morsel object with preserved coded_value
|
|
||||||
|
|
||||||
"""
|
|
||||||
mrsl_val = cast("Morsel[str]", cookie.get(cookie.key, Morsel()))
|
|
||||||
# We use __setstate__ instead of the public set() API because it allows us to
|
|
||||||
# bypass validation and set already validated state. This is more stable than
|
|
||||||
# setting protected attributes directly and unlikely to change since it would
|
|
||||||
# break pickling.
|
|
||||||
mrsl_val.__setstate__( # type: ignore[attr-defined]
|
|
||||||
{"key": cookie.key, "value": cookie.value, "coded_value": cookie.coded_value}
|
|
||||||
)
|
|
||||||
return mrsl_val
|
|
||||||
|
|
||||||
|
|
||||||
_unquote_sub = re.compile(r"\\(?:([0-3][0-7][0-7])|(.))").sub
|
|
||||||
|
|
||||||
|
|
||||||
def _unquote_replace(m: re.Match[str]) -> str:
|
|
||||||
"""
|
|
||||||
Replace function for _unquote_sub regex substitution.
|
|
||||||
|
|
||||||
Handles escaped characters in cookie values:
|
|
||||||
- Octal sequences are converted to their character representation
|
|
||||||
- Other escaped characters are unescaped by removing the backslash
|
|
||||||
"""
|
|
||||||
if m[1]:
|
|
||||||
return chr(int(m[1], 8))
|
|
||||||
return m[2]
|
|
||||||
|
|
||||||
|
|
||||||
def _unquote(value: str) -> str:
|
|
||||||
"""
|
|
||||||
Unquote a cookie value.
|
|
||||||
|
|
||||||
Vendored from http.cookies._unquote to ensure compatibility.
|
|
||||||
|
|
||||||
Note: The original implementation checked for None, but we've removed
|
|
||||||
that check since all callers already ensure the value is not None.
|
|
||||||
"""
|
|
||||||
# If there aren't any doublequotes,
|
|
||||||
# then there can't be any special characters. See RFC 2109.
|
|
||||||
if len(value) < 2:
|
|
||||||
return value
|
|
||||||
if value[0] != '"' or value[-1] != '"':
|
|
||||||
return value
|
|
||||||
|
|
||||||
# We have to assume that we must decode this string.
|
|
||||||
# Down to work.
|
|
||||||
|
|
||||||
# Remove the "s
|
|
||||||
value = value[1:-1]
|
|
||||||
|
|
||||||
# Check for special sequences. Examples:
|
|
||||||
# \012 --> \n
|
|
||||||
# \" --> "
|
|
||||||
#
|
|
||||||
return _unquote_sub(_unquote_replace, value)
|
|
||||||
|
|
||||||
|
|
||||||
def parse_cookie_header(header: str) -> List[Tuple[str, Morsel[str]]]:
|
|
||||||
"""
|
|
||||||
Parse a Cookie header according to RFC 6265 Section 5.4.
|
|
||||||
|
|
||||||
Cookie headers contain only name-value pairs separated by semicolons.
|
|
||||||
There are no attributes in Cookie headers - even names that match
|
|
||||||
attribute names (like 'path' or 'secure') should be treated as cookies.
|
|
||||||
|
|
||||||
This parser uses the same regex-based approach as parse_set_cookie_headers
|
|
||||||
to properly handle quoted values that may contain semicolons. When the
|
|
||||||
regex fails to match a malformed cookie, it falls back to simple parsing
|
|
||||||
to ensure subsequent cookies are not lost
|
|
||||||
https://github.com/aio-libs/aiohttp/issues/11632
|
|
||||||
|
|
||||||
Args:
|
|
||||||
header: The Cookie header value to parse
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of (name, Morsel) tuples for compatibility with SimpleCookie.update()
|
|
||||||
"""
|
|
||||||
if not header:
|
|
||||||
return []
|
|
||||||
|
|
||||||
cookies: List[Tuple[str, Morsel[str]]] = []
|
|
||||||
morsel: Morsel[str]
|
|
||||||
i = 0
|
|
||||||
n = len(header)
|
|
||||||
|
|
||||||
while i < n:
|
|
||||||
# Use the same pattern as parse_set_cookie_headers to find cookies
|
|
||||||
match = _COOKIE_PATTERN.match(header, i)
|
|
||||||
if not match:
|
|
||||||
# Fallback for malformed cookies https://github.com/aio-libs/aiohttp/issues/11632
|
|
||||||
# Find next semicolon to skip or attempt simple key=value parsing
|
|
||||||
next_semi = header.find(";", i)
|
|
||||||
eq_pos = header.find("=", i)
|
|
||||||
|
|
||||||
# Try to extract key=value if '=' comes before ';'
|
|
||||||
if eq_pos != -1 and (next_semi == -1 or eq_pos < next_semi):
|
|
||||||
end_pos = next_semi if next_semi != -1 else n
|
|
||||||
key = header[i:eq_pos].strip()
|
|
||||||
value = header[eq_pos + 1 : end_pos].strip()
|
|
||||||
|
|
||||||
# Validate the name (same as regex path)
|
|
||||||
if not _COOKIE_NAME_RE.match(key):
|
|
||||||
internal_logger.warning(
|
|
||||||
"Can not load cookie: Illegal cookie name %r", key
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
morsel = Morsel()
|
|
||||||
morsel.__setstate__( # type: ignore[attr-defined]
|
|
||||||
{"key": key, "value": _unquote(value), "coded_value": value}
|
|
||||||
)
|
|
||||||
cookies.append((key, morsel))
|
|
||||||
|
|
||||||
# Move to next cookie or end
|
|
||||||
i = next_semi + 1 if next_semi != -1 else n
|
|
||||||
continue
|
|
||||||
|
|
||||||
key = match.group("key")
|
|
||||||
value = match.group("val") or ""
|
|
||||||
i = match.end(0)
|
|
||||||
|
|
||||||
# Validate the name
|
|
||||||
if not key or not _COOKIE_NAME_RE.match(key):
|
|
||||||
internal_logger.warning("Can not load cookie: Illegal cookie name %r", key)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Create new morsel
|
|
||||||
morsel = Morsel()
|
|
||||||
# Preserve the original value as coded_value (with quotes if present)
|
|
||||||
# We use __setstate__ instead of the public set() API because it allows us to
|
|
||||||
# bypass validation and set already validated state. This is more stable than
|
|
||||||
# setting protected attributes directly and unlikely to change since it would
|
|
||||||
# break pickling.
|
|
||||||
morsel.__setstate__( # type: ignore[attr-defined]
|
|
||||||
{"key": key, "value": _unquote(value), "coded_value": value}
|
|
||||||
)
|
|
||||||
|
|
||||||
cookies.append((key, morsel))
|
|
||||||
|
|
||||||
return cookies
|
|
||||||
|
|
||||||
|
|
||||||
def parse_set_cookie_headers(headers: Sequence[str]) -> List[Tuple[str, Morsel[str]]]:
|
|
||||||
"""
|
|
||||||
Parse cookie headers using a vendored version of SimpleCookie parsing.
|
|
||||||
|
|
||||||
This implementation is based on SimpleCookie.__parse_string to ensure
|
|
||||||
compatibility with how SimpleCookie parses cookies, including handling
|
|
||||||
of malformed cookies with missing semicolons.
|
|
||||||
|
|
||||||
This function is used for both Cookie and Set-Cookie headers in order to be
|
|
||||||
forgiving. Ideally we would have followed RFC 6265 Section 5.2 (for Cookie
|
|
||||||
headers) and RFC 6265 Section 4.2.1 (for Set-Cookie headers), but the
|
|
||||||
real world data makes it impossible since we need to be a bit more forgiving.
|
|
||||||
|
|
||||||
NOTE: This implementation differs from SimpleCookie in handling unmatched quotes.
|
|
||||||
SimpleCookie will stop parsing when it encounters a cookie value with an unmatched
|
|
||||||
quote (e.g., 'cookie="value'), causing subsequent cookies to be silently dropped.
|
|
||||||
This implementation handles unmatched quotes more gracefully to prevent cookie loss.
|
|
||||||
See https://github.com/aio-libs/aiohttp/issues/7993
|
|
||||||
"""
|
|
||||||
parsed_cookies: List[Tuple[str, Morsel[str]]] = []
|
|
||||||
|
|
||||||
for header in headers:
|
|
||||||
if not header:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Parse cookie string using SimpleCookie's algorithm
|
|
||||||
i = 0
|
|
||||||
n = len(header)
|
|
||||||
current_morsel: Optional[Morsel[str]] = None
|
|
||||||
morsel_seen = False
|
|
||||||
|
|
||||||
while 0 <= i < n:
|
|
||||||
# Start looking for a cookie
|
|
||||||
match = _COOKIE_PATTERN.match(header, i)
|
|
||||||
if not match:
|
|
||||||
# No more cookies
|
|
||||||
break
|
|
||||||
|
|
||||||
key, value = match.group("key"), match.group("val")
|
|
||||||
i = match.end(0)
|
|
||||||
lower_key = key.lower()
|
|
||||||
|
|
||||||
if key[0] == "$":
|
|
||||||
if not morsel_seen:
|
|
||||||
# We ignore attributes which pertain to the cookie
|
|
||||||
# mechanism as a whole, such as "$Version".
|
|
||||||
continue
|
|
||||||
# Process as attribute
|
|
||||||
if current_morsel is not None:
|
|
||||||
attr_lower_key = lower_key[1:]
|
|
||||||
if attr_lower_key in _COOKIE_KNOWN_ATTRS:
|
|
||||||
current_morsel[attr_lower_key] = value or ""
|
|
||||||
elif lower_key in _COOKIE_KNOWN_ATTRS:
|
|
||||||
if not morsel_seen:
|
|
||||||
# Invalid cookie string - attribute before cookie
|
|
||||||
break
|
|
||||||
if lower_key in _COOKIE_BOOL_ATTRS:
|
|
||||||
# Boolean attribute with any value should be True
|
|
||||||
if current_morsel is not None and current_morsel.isReservedKey(key):
|
|
||||||
current_morsel[lower_key] = True
|
|
||||||
elif value is None:
|
|
||||||
# Invalid cookie string - non-boolean attribute without value
|
|
||||||
break
|
|
||||||
elif current_morsel is not None:
|
|
||||||
# Regular attribute with value
|
|
||||||
current_morsel[lower_key] = _unquote(value)
|
|
||||||
elif value is not None:
|
|
||||||
# This is a cookie name=value pair
|
|
||||||
# Validate the name
|
|
||||||
if key in _COOKIE_KNOWN_ATTRS or not _COOKIE_NAME_RE.match(key):
|
|
||||||
internal_logger.warning(
|
|
||||||
"Can not load cookies: Illegal cookie name %r", key
|
|
||||||
)
|
|
||||||
current_morsel = None
|
|
||||||
else:
|
|
||||||
# Create new morsel
|
|
||||||
current_morsel = Morsel()
|
|
||||||
# Preserve the original value as coded_value (with quotes if present)
|
|
||||||
# We use __setstate__ instead of the public set() API because it allows us to
|
|
||||||
# bypass validation and set already validated state. This is more stable than
|
|
||||||
# setting protected attributes directly and unlikely to change since it would
|
|
||||||
# break pickling.
|
|
||||||
current_morsel.__setstate__( # type: ignore[attr-defined]
|
|
||||||
{"key": key, "value": _unquote(value), "coded_value": value}
|
|
||||||
)
|
|
||||||
parsed_cookies.append((key, current_morsel))
|
|
||||||
morsel_seen = True
|
|
||||||
else:
|
|
||||||
# Invalid cookie string - no value for non-attribute
|
|
||||||
break
|
|
||||||
|
|
||||||
return parsed_cookies
|
|
||||||
|
|
@ -1,158 +0,0 @@
|
||||||
from libc.stdint cimport int32_t, uint8_t, uint16_t, uint64_t
|
|
||||||
|
|
||||||
|
|
||||||
cdef extern from "llhttp.h":
|
|
||||||
|
|
||||||
struct llhttp__internal_s:
|
|
||||||
int32_t _index
|
|
||||||
void* _span_pos0
|
|
||||||
void* _span_cb0
|
|
||||||
int32_t error
|
|
||||||
const char* reason
|
|
||||||
const char* error_pos
|
|
||||||
void* data
|
|
||||||
void* _current
|
|
||||||
uint64_t content_length
|
|
||||||
uint8_t type
|
|
||||||
uint8_t method
|
|
||||||
uint8_t http_major
|
|
||||||
uint8_t http_minor
|
|
||||||
uint8_t header_state
|
|
||||||
uint8_t lenient_flags
|
|
||||||
uint8_t upgrade
|
|
||||||
uint8_t finish
|
|
||||||
uint16_t flags
|
|
||||||
uint16_t status_code
|
|
||||||
void* settings
|
|
||||||
|
|
||||||
ctypedef llhttp__internal_s llhttp__internal_t
|
|
||||||
ctypedef llhttp__internal_t llhttp_t
|
|
||||||
|
|
||||||
ctypedef int (*llhttp_data_cb)(llhttp_t*, const char *at, size_t length) except -1
|
|
||||||
ctypedef int (*llhttp_cb)(llhttp_t*) except -1
|
|
||||||
|
|
||||||
struct llhttp_settings_s:
|
|
||||||
llhttp_cb on_message_begin
|
|
||||||
llhttp_data_cb on_url
|
|
||||||
llhttp_data_cb on_status
|
|
||||||
llhttp_data_cb on_header_field
|
|
||||||
llhttp_data_cb on_header_value
|
|
||||||
llhttp_cb on_headers_complete
|
|
||||||
llhttp_data_cb on_body
|
|
||||||
llhttp_cb on_message_complete
|
|
||||||
llhttp_cb on_chunk_header
|
|
||||||
llhttp_cb on_chunk_complete
|
|
||||||
|
|
||||||
llhttp_cb on_url_complete
|
|
||||||
llhttp_cb on_status_complete
|
|
||||||
llhttp_cb on_header_field_complete
|
|
||||||
llhttp_cb on_header_value_complete
|
|
||||||
|
|
||||||
ctypedef llhttp_settings_s llhttp_settings_t
|
|
||||||
|
|
||||||
enum llhttp_errno:
|
|
||||||
HPE_OK,
|
|
||||||
HPE_INTERNAL,
|
|
||||||
HPE_STRICT,
|
|
||||||
HPE_LF_EXPECTED,
|
|
||||||
HPE_UNEXPECTED_CONTENT_LENGTH,
|
|
||||||
HPE_CLOSED_CONNECTION,
|
|
||||||
HPE_INVALID_METHOD,
|
|
||||||
HPE_INVALID_URL,
|
|
||||||
HPE_INVALID_CONSTANT,
|
|
||||||
HPE_INVALID_VERSION,
|
|
||||||
HPE_INVALID_HEADER_TOKEN,
|
|
||||||
HPE_INVALID_CONTENT_LENGTH,
|
|
||||||
HPE_INVALID_CHUNK_SIZE,
|
|
||||||
HPE_INVALID_STATUS,
|
|
||||||
HPE_INVALID_EOF_STATE,
|
|
||||||
HPE_INVALID_TRANSFER_ENCODING,
|
|
||||||
HPE_CB_MESSAGE_BEGIN,
|
|
||||||
HPE_CB_HEADERS_COMPLETE,
|
|
||||||
HPE_CB_MESSAGE_COMPLETE,
|
|
||||||
HPE_CB_CHUNK_HEADER,
|
|
||||||
HPE_CB_CHUNK_COMPLETE,
|
|
||||||
HPE_PAUSED,
|
|
||||||
HPE_PAUSED_UPGRADE,
|
|
||||||
HPE_USER
|
|
||||||
|
|
||||||
ctypedef llhttp_errno llhttp_errno_t
|
|
||||||
|
|
||||||
enum llhttp_flags:
|
|
||||||
F_CHUNKED,
|
|
||||||
F_CONTENT_LENGTH
|
|
||||||
|
|
||||||
enum llhttp_type:
|
|
||||||
HTTP_REQUEST,
|
|
||||||
HTTP_RESPONSE,
|
|
||||||
HTTP_BOTH
|
|
||||||
|
|
||||||
enum llhttp_method:
|
|
||||||
HTTP_DELETE,
|
|
||||||
HTTP_GET,
|
|
||||||
HTTP_HEAD,
|
|
||||||
HTTP_POST,
|
|
||||||
HTTP_PUT,
|
|
||||||
HTTP_CONNECT,
|
|
||||||
HTTP_OPTIONS,
|
|
||||||
HTTP_TRACE,
|
|
||||||
HTTP_COPY,
|
|
||||||
HTTP_LOCK,
|
|
||||||
HTTP_MKCOL,
|
|
||||||
HTTP_MOVE,
|
|
||||||
HTTP_PROPFIND,
|
|
||||||
HTTP_PROPPATCH,
|
|
||||||
HTTP_SEARCH,
|
|
||||||
HTTP_UNLOCK,
|
|
||||||
HTTP_BIND,
|
|
||||||
HTTP_REBIND,
|
|
||||||
HTTP_UNBIND,
|
|
||||||
HTTP_ACL,
|
|
||||||
HTTP_REPORT,
|
|
||||||
HTTP_MKACTIVITY,
|
|
||||||
HTTP_CHECKOUT,
|
|
||||||
HTTP_MERGE,
|
|
||||||
HTTP_MSEARCH,
|
|
||||||
HTTP_NOTIFY,
|
|
||||||
HTTP_SUBSCRIBE,
|
|
||||||
HTTP_UNSUBSCRIBE,
|
|
||||||
HTTP_PATCH,
|
|
||||||
HTTP_PURGE,
|
|
||||||
HTTP_MKCALENDAR,
|
|
||||||
HTTP_LINK,
|
|
||||||
HTTP_UNLINK,
|
|
||||||
HTTP_SOURCE,
|
|
||||||
HTTP_PRI,
|
|
||||||
HTTP_DESCRIBE,
|
|
||||||
HTTP_ANNOUNCE,
|
|
||||||
HTTP_SETUP,
|
|
||||||
HTTP_PLAY,
|
|
||||||
HTTP_PAUSE,
|
|
||||||
HTTP_TEARDOWN,
|
|
||||||
HTTP_GET_PARAMETER,
|
|
||||||
HTTP_SET_PARAMETER,
|
|
||||||
HTTP_REDIRECT,
|
|
||||||
HTTP_RECORD,
|
|
||||||
HTTP_FLUSH
|
|
||||||
|
|
||||||
ctypedef llhttp_method llhttp_method_t;
|
|
||||||
|
|
||||||
void llhttp_settings_init(llhttp_settings_t* settings)
|
|
||||||
void llhttp_init(llhttp_t* parser, llhttp_type type,
|
|
||||||
const llhttp_settings_t* settings)
|
|
||||||
|
|
||||||
llhttp_errno_t llhttp_execute(llhttp_t* parser, const char* data, size_t len)
|
|
||||||
|
|
||||||
int llhttp_should_keep_alive(const llhttp_t* parser)
|
|
||||||
|
|
||||||
void llhttp_resume_after_upgrade(llhttp_t* parser)
|
|
||||||
|
|
||||||
llhttp_errno_t llhttp_get_errno(const llhttp_t* parser)
|
|
||||||
const char* llhttp_get_error_reason(const llhttp_t* parser)
|
|
||||||
const char* llhttp_get_error_pos(const llhttp_t* parser)
|
|
||||||
|
|
||||||
const char* llhttp_method_name(llhttp_method_t method)
|
|
||||||
|
|
||||||
void llhttp_set_lenient_headers(llhttp_t* parser, int enabled)
|
|
||||||
void llhttp_set_lenient_optional_cr_before_lf(llhttp_t* parser, int enabled)
|
|
||||||
void llhttp_set_lenient_spaces_after_chunk_size(llhttp_t* parser, int enabled)
|
|
||||||
|
|
@ -1,2 +0,0 @@
|
||||||
cdef extern from "_find_header.h":
|
|
||||||
int find_header(char *, int)
|
|
||||||
|
|
@ -1,83 +0,0 @@
|
||||||
# The file is autogenerated from aiohttp/hdrs.py
|
|
||||||
# Run ./tools/gen.py to update it after the origin changing.
|
|
||||||
|
|
||||||
from . import hdrs
|
|
||||||
cdef tuple headers = (
|
|
||||||
hdrs.ACCEPT,
|
|
||||||
hdrs.ACCEPT_CHARSET,
|
|
||||||
hdrs.ACCEPT_ENCODING,
|
|
||||||
hdrs.ACCEPT_LANGUAGE,
|
|
||||||
hdrs.ACCEPT_RANGES,
|
|
||||||
hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS,
|
|
||||||
hdrs.ACCESS_CONTROL_ALLOW_HEADERS,
|
|
||||||
hdrs.ACCESS_CONTROL_ALLOW_METHODS,
|
|
||||||
hdrs.ACCESS_CONTROL_ALLOW_ORIGIN,
|
|
||||||
hdrs.ACCESS_CONTROL_EXPOSE_HEADERS,
|
|
||||||
hdrs.ACCESS_CONTROL_MAX_AGE,
|
|
||||||
hdrs.ACCESS_CONTROL_REQUEST_HEADERS,
|
|
||||||
hdrs.ACCESS_CONTROL_REQUEST_METHOD,
|
|
||||||
hdrs.AGE,
|
|
||||||
hdrs.ALLOW,
|
|
||||||
hdrs.AUTHORIZATION,
|
|
||||||
hdrs.CACHE_CONTROL,
|
|
||||||
hdrs.CONNECTION,
|
|
||||||
hdrs.CONTENT_DISPOSITION,
|
|
||||||
hdrs.CONTENT_ENCODING,
|
|
||||||
hdrs.CONTENT_LANGUAGE,
|
|
||||||
hdrs.CONTENT_LENGTH,
|
|
||||||
hdrs.CONTENT_LOCATION,
|
|
||||||
hdrs.CONTENT_MD5,
|
|
||||||
hdrs.CONTENT_RANGE,
|
|
||||||
hdrs.CONTENT_TRANSFER_ENCODING,
|
|
||||||
hdrs.CONTENT_TYPE,
|
|
||||||
hdrs.COOKIE,
|
|
||||||
hdrs.DATE,
|
|
||||||
hdrs.DESTINATION,
|
|
||||||
hdrs.DIGEST,
|
|
||||||
hdrs.ETAG,
|
|
||||||
hdrs.EXPECT,
|
|
||||||
hdrs.EXPIRES,
|
|
||||||
hdrs.FORWARDED,
|
|
||||||
hdrs.FROM,
|
|
||||||
hdrs.HOST,
|
|
||||||
hdrs.IF_MATCH,
|
|
||||||
hdrs.IF_MODIFIED_SINCE,
|
|
||||||
hdrs.IF_NONE_MATCH,
|
|
||||||
hdrs.IF_RANGE,
|
|
||||||
hdrs.IF_UNMODIFIED_SINCE,
|
|
||||||
hdrs.KEEP_ALIVE,
|
|
||||||
hdrs.LAST_EVENT_ID,
|
|
||||||
hdrs.LAST_MODIFIED,
|
|
||||||
hdrs.LINK,
|
|
||||||
hdrs.LOCATION,
|
|
||||||
hdrs.MAX_FORWARDS,
|
|
||||||
hdrs.ORIGIN,
|
|
||||||
hdrs.PRAGMA,
|
|
||||||
hdrs.PROXY_AUTHENTICATE,
|
|
||||||
hdrs.PROXY_AUTHORIZATION,
|
|
||||||
hdrs.RANGE,
|
|
||||||
hdrs.REFERER,
|
|
||||||
hdrs.RETRY_AFTER,
|
|
||||||
hdrs.SEC_WEBSOCKET_ACCEPT,
|
|
||||||
hdrs.SEC_WEBSOCKET_EXTENSIONS,
|
|
||||||
hdrs.SEC_WEBSOCKET_KEY,
|
|
||||||
hdrs.SEC_WEBSOCKET_KEY1,
|
|
||||||
hdrs.SEC_WEBSOCKET_PROTOCOL,
|
|
||||||
hdrs.SEC_WEBSOCKET_VERSION,
|
|
||||||
hdrs.SERVER,
|
|
||||||
hdrs.SET_COOKIE,
|
|
||||||
hdrs.TE,
|
|
||||||
hdrs.TRAILER,
|
|
||||||
hdrs.TRANSFER_ENCODING,
|
|
||||||
hdrs.URI,
|
|
||||||
hdrs.UPGRADE,
|
|
||||||
hdrs.USER_AGENT,
|
|
||||||
hdrs.VARY,
|
|
||||||
hdrs.VIA,
|
|
||||||
hdrs.WWW_AUTHENTICATE,
|
|
||||||
hdrs.WANT_DIGEST,
|
|
||||||
hdrs.WARNING,
|
|
||||||
hdrs.X_FORWARDED_FOR,
|
|
||||||
hdrs.X_FORWARDED_HOST,
|
|
||||||
hdrs.X_FORWARDED_PROTO,
|
|
||||||
)
|
|
||||||
Binary file not shown.
|
|
@ -1,835 +0,0 @@
|
||||||
# Based on https://github.com/MagicStack/httptools
|
|
||||||
#
|
|
||||||
|
|
||||||
from cpython cimport (
|
|
||||||
Py_buffer,
|
|
||||||
PyBUF_SIMPLE,
|
|
||||||
PyBuffer_Release,
|
|
||||||
PyBytes_AsString,
|
|
||||||
PyBytes_AsStringAndSize,
|
|
||||||
PyObject_GetBuffer,
|
|
||||||
)
|
|
||||||
from cpython.mem cimport PyMem_Free, PyMem_Malloc
|
|
||||||
from libc.limits cimport ULLONG_MAX
|
|
||||||
from libc.string cimport memcpy
|
|
||||||
|
|
||||||
from multidict import CIMultiDict as _CIMultiDict, CIMultiDictProxy as _CIMultiDictProxy
|
|
||||||
from yarl import URL as _URL
|
|
||||||
|
|
||||||
from aiohttp import hdrs
|
|
||||||
from aiohttp.helpers import DEBUG, set_exception
|
|
||||||
|
|
||||||
from .http_exceptions import (
|
|
||||||
BadHttpMessage,
|
|
||||||
BadHttpMethod,
|
|
||||||
BadStatusLine,
|
|
||||||
ContentLengthError,
|
|
||||||
InvalidHeader,
|
|
||||||
InvalidURLError,
|
|
||||||
LineTooLong,
|
|
||||||
PayloadEncodingError,
|
|
||||||
TransferEncodingError,
|
|
||||||
)
|
|
||||||
from .http_parser import DeflateBuffer as _DeflateBuffer
|
|
||||||
from .http_writer import (
|
|
||||||
HttpVersion as _HttpVersion,
|
|
||||||
HttpVersion10 as _HttpVersion10,
|
|
||||||
HttpVersion11 as _HttpVersion11,
|
|
||||||
)
|
|
||||||
from .streams import EMPTY_PAYLOAD as _EMPTY_PAYLOAD, StreamReader as _StreamReader
|
|
||||||
|
|
||||||
cimport cython
|
|
||||||
|
|
||||||
from aiohttp cimport _cparser as cparser
|
|
||||||
|
|
||||||
include "_headers.pxi"
|
|
||||||
|
|
||||||
from aiohttp cimport _find_header
|
|
||||||
|
|
||||||
ALLOWED_UPGRADES = frozenset({"websocket"})
|
|
||||||
DEF DEFAULT_FREELIST_SIZE = 250
|
|
||||||
|
|
||||||
cdef extern from "Python.h":
|
|
||||||
int PyByteArray_Resize(object, Py_ssize_t) except -1
|
|
||||||
Py_ssize_t PyByteArray_Size(object) except -1
|
|
||||||
char* PyByteArray_AsString(object)
|
|
||||||
|
|
||||||
__all__ = ('HttpRequestParser', 'HttpResponseParser',
|
|
||||||
'RawRequestMessage', 'RawResponseMessage')
|
|
||||||
|
|
||||||
cdef object URL = _URL
|
|
||||||
cdef object URL_build = URL.build
|
|
||||||
cdef object CIMultiDict = _CIMultiDict
|
|
||||||
cdef object CIMultiDictProxy = _CIMultiDictProxy
|
|
||||||
cdef object HttpVersion = _HttpVersion
|
|
||||||
cdef object HttpVersion10 = _HttpVersion10
|
|
||||||
cdef object HttpVersion11 = _HttpVersion11
|
|
||||||
cdef object SEC_WEBSOCKET_KEY1 = hdrs.SEC_WEBSOCKET_KEY1
|
|
||||||
cdef object CONTENT_ENCODING = hdrs.CONTENT_ENCODING
|
|
||||||
cdef object EMPTY_PAYLOAD = _EMPTY_PAYLOAD
|
|
||||||
cdef object StreamReader = _StreamReader
|
|
||||||
cdef object DeflateBuffer = _DeflateBuffer
|
|
||||||
cdef bytes EMPTY_BYTES = b""
|
|
||||||
|
|
||||||
cdef inline object extend(object buf, const char* at, size_t length):
|
|
||||||
cdef Py_ssize_t s
|
|
||||||
cdef char* ptr
|
|
||||||
s = PyByteArray_Size(buf)
|
|
||||||
PyByteArray_Resize(buf, s + length)
|
|
||||||
ptr = PyByteArray_AsString(buf)
|
|
||||||
memcpy(ptr + s, at, length)
|
|
||||||
|
|
||||||
|
|
||||||
DEF METHODS_COUNT = 46;
|
|
||||||
|
|
||||||
cdef list _http_method = []
|
|
||||||
|
|
||||||
for i in range(METHODS_COUNT):
|
|
||||||
_http_method.append(
|
|
||||||
cparser.llhttp_method_name(<cparser.llhttp_method_t> i).decode('ascii'))
|
|
||||||
|
|
||||||
|
|
||||||
cdef inline str http_method_str(int i):
|
|
||||||
if i < METHODS_COUNT:
|
|
||||||
return <str>_http_method[i]
|
|
||||||
else:
|
|
||||||
return "<unknown>"
|
|
||||||
|
|
||||||
cdef inline object find_header(bytes raw_header):
|
|
||||||
cdef Py_ssize_t size
|
|
||||||
cdef char *buf
|
|
||||||
cdef int idx
|
|
||||||
PyBytes_AsStringAndSize(raw_header, &buf, &size)
|
|
||||||
idx = _find_header.find_header(buf, size)
|
|
||||||
if idx == -1:
|
|
||||||
return raw_header.decode('utf-8', 'surrogateescape')
|
|
||||||
return headers[idx]
|
|
||||||
|
|
||||||
|
|
||||||
@cython.freelist(DEFAULT_FREELIST_SIZE)
|
|
||||||
cdef class RawRequestMessage:
|
|
||||||
cdef readonly str method
|
|
||||||
cdef readonly str path
|
|
||||||
cdef readonly object version # HttpVersion
|
|
||||||
cdef readonly object headers # CIMultiDict
|
|
||||||
cdef readonly object raw_headers # tuple
|
|
||||||
cdef readonly object should_close
|
|
||||||
cdef readonly object compression
|
|
||||||
cdef readonly object upgrade
|
|
||||||
cdef readonly object chunked
|
|
||||||
cdef readonly object url # yarl.URL
|
|
||||||
|
|
||||||
def __init__(self, method, path, version, headers, raw_headers,
|
|
||||||
should_close, compression, upgrade, chunked, url):
|
|
||||||
self.method = method
|
|
||||||
self.path = path
|
|
||||||
self.version = version
|
|
||||||
self.headers = headers
|
|
||||||
self.raw_headers = raw_headers
|
|
||||||
self.should_close = should_close
|
|
||||||
self.compression = compression
|
|
||||||
self.upgrade = upgrade
|
|
||||||
self.chunked = chunked
|
|
||||||
self.url = url
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
info = []
|
|
||||||
info.append(("method", self.method))
|
|
||||||
info.append(("path", self.path))
|
|
||||||
info.append(("version", self.version))
|
|
||||||
info.append(("headers", self.headers))
|
|
||||||
info.append(("raw_headers", self.raw_headers))
|
|
||||||
info.append(("should_close", self.should_close))
|
|
||||||
info.append(("compression", self.compression))
|
|
||||||
info.append(("upgrade", self.upgrade))
|
|
||||||
info.append(("chunked", self.chunked))
|
|
||||||
info.append(("url", self.url))
|
|
||||||
sinfo = ', '.join(name + '=' + repr(val) for name, val in info)
|
|
||||||
return '<RawRequestMessage(' + sinfo + ')>'
|
|
||||||
|
|
||||||
def _replace(self, **dct):
|
|
||||||
cdef RawRequestMessage ret
|
|
||||||
ret = _new_request_message(self.method,
|
|
||||||
self.path,
|
|
||||||
self.version,
|
|
||||||
self.headers,
|
|
||||||
self.raw_headers,
|
|
||||||
self.should_close,
|
|
||||||
self.compression,
|
|
||||||
self.upgrade,
|
|
||||||
self.chunked,
|
|
||||||
self.url)
|
|
||||||
if "method" in dct:
|
|
||||||
ret.method = dct["method"]
|
|
||||||
if "path" in dct:
|
|
||||||
ret.path = dct["path"]
|
|
||||||
if "version" in dct:
|
|
||||||
ret.version = dct["version"]
|
|
||||||
if "headers" in dct:
|
|
||||||
ret.headers = dct["headers"]
|
|
||||||
if "raw_headers" in dct:
|
|
||||||
ret.raw_headers = dct["raw_headers"]
|
|
||||||
if "should_close" in dct:
|
|
||||||
ret.should_close = dct["should_close"]
|
|
||||||
if "compression" in dct:
|
|
||||||
ret.compression = dct["compression"]
|
|
||||||
if "upgrade" in dct:
|
|
||||||
ret.upgrade = dct["upgrade"]
|
|
||||||
if "chunked" in dct:
|
|
||||||
ret.chunked = dct["chunked"]
|
|
||||||
if "url" in dct:
|
|
||||||
ret.url = dct["url"]
|
|
||||||
return ret
|
|
||||||
|
|
||||||
cdef _new_request_message(str method,
|
|
||||||
str path,
|
|
||||||
object version,
|
|
||||||
object headers,
|
|
||||||
object raw_headers,
|
|
||||||
bint should_close,
|
|
||||||
object compression,
|
|
||||||
bint upgrade,
|
|
||||||
bint chunked,
|
|
||||||
object url):
|
|
||||||
cdef RawRequestMessage ret
|
|
||||||
ret = RawRequestMessage.__new__(RawRequestMessage)
|
|
||||||
ret.method = method
|
|
||||||
ret.path = path
|
|
||||||
ret.version = version
|
|
||||||
ret.headers = headers
|
|
||||||
ret.raw_headers = raw_headers
|
|
||||||
ret.should_close = should_close
|
|
||||||
ret.compression = compression
|
|
||||||
ret.upgrade = upgrade
|
|
||||||
ret.chunked = chunked
|
|
||||||
ret.url = url
|
|
||||||
return ret
|
|
||||||
|
|
||||||
|
|
||||||
@cython.freelist(DEFAULT_FREELIST_SIZE)
|
|
||||||
cdef class RawResponseMessage:
|
|
||||||
cdef readonly object version # HttpVersion
|
|
||||||
cdef readonly int code
|
|
||||||
cdef readonly str reason
|
|
||||||
cdef readonly object headers # CIMultiDict
|
|
||||||
cdef readonly object raw_headers # tuple
|
|
||||||
cdef readonly object should_close
|
|
||||||
cdef readonly object compression
|
|
||||||
cdef readonly object upgrade
|
|
||||||
cdef readonly object chunked
|
|
||||||
|
|
||||||
def __init__(self, version, code, reason, headers, raw_headers,
|
|
||||||
should_close, compression, upgrade, chunked):
|
|
||||||
self.version = version
|
|
||||||
self.code = code
|
|
||||||
self.reason = reason
|
|
||||||
self.headers = headers
|
|
||||||
self.raw_headers = raw_headers
|
|
||||||
self.should_close = should_close
|
|
||||||
self.compression = compression
|
|
||||||
self.upgrade = upgrade
|
|
||||||
self.chunked = chunked
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
info = []
|
|
||||||
info.append(("version", self.version))
|
|
||||||
info.append(("code", self.code))
|
|
||||||
info.append(("reason", self.reason))
|
|
||||||
info.append(("headers", self.headers))
|
|
||||||
info.append(("raw_headers", self.raw_headers))
|
|
||||||
info.append(("should_close", self.should_close))
|
|
||||||
info.append(("compression", self.compression))
|
|
||||||
info.append(("upgrade", self.upgrade))
|
|
||||||
info.append(("chunked", self.chunked))
|
|
||||||
sinfo = ', '.join(name + '=' + repr(val) for name, val in info)
|
|
||||||
return '<RawResponseMessage(' + sinfo + ')>'
|
|
||||||
|
|
||||||
|
|
||||||
cdef _new_response_message(object version,
|
|
||||||
int code,
|
|
||||||
str reason,
|
|
||||||
object headers,
|
|
||||||
object raw_headers,
|
|
||||||
bint should_close,
|
|
||||||
object compression,
|
|
||||||
bint upgrade,
|
|
||||||
bint chunked):
|
|
||||||
cdef RawResponseMessage ret
|
|
||||||
ret = RawResponseMessage.__new__(RawResponseMessage)
|
|
||||||
ret.version = version
|
|
||||||
ret.code = code
|
|
||||||
ret.reason = reason
|
|
||||||
ret.headers = headers
|
|
||||||
ret.raw_headers = raw_headers
|
|
||||||
ret.should_close = should_close
|
|
||||||
ret.compression = compression
|
|
||||||
ret.upgrade = upgrade
|
|
||||||
ret.chunked = chunked
|
|
||||||
return ret
|
|
||||||
|
|
||||||
|
|
||||||
@cython.internal
|
|
||||||
cdef class HttpParser:
|
|
||||||
|
|
||||||
cdef:
|
|
||||||
cparser.llhttp_t* _cparser
|
|
||||||
cparser.llhttp_settings_t* _csettings
|
|
||||||
|
|
||||||
bytes _raw_name
|
|
||||||
object _name
|
|
||||||
bytes _raw_value
|
|
||||||
bint _has_value
|
|
||||||
|
|
||||||
object _protocol
|
|
||||||
object _loop
|
|
||||||
object _timer
|
|
||||||
|
|
||||||
size_t _max_line_size
|
|
||||||
size_t _max_field_size
|
|
||||||
size_t _max_headers
|
|
||||||
bint _response_with_body
|
|
||||||
bint _read_until_eof
|
|
||||||
|
|
||||||
bint _started
|
|
||||||
object _url
|
|
||||||
bytearray _buf
|
|
||||||
str _path
|
|
||||||
str _reason
|
|
||||||
list _headers
|
|
||||||
list _raw_headers
|
|
||||||
bint _upgraded
|
|
||||||
list _messages
|
|
||||||
object _payload
|
|
||||||
bint _payload_error
|
|
||||||
object _payload_exception
|
|
||||||
object _last_error
|
|
||||||
bint _auto_decompress
|
|
||||||
int _limit
|
|
||||||
|
|
||||||
str _content_encoding
|
|
||||||
|
|
||||||
Py_buffer py_buf
|
|
||||||
|
|
||||||
def __cinit__(self):
|
|
||||||
self._cparser = <cparser.llhttp_t*> \
|
|
||||||
PyMem_Malloc(sizeof(cparser.llhttp_t))
|
|
||||||
if self._cparser is NULL:
|
|
||||||
raise MemoryError()
|
|
||||||
|
|
||||||
self._csettings = <cparser.llhttp_settings_t*> \
|
|
||||||
PyMem_Malloc(sizeof(cparser.llhttp_settings_t))
|
|
||||||
if self._csettings is NULL:
|
|
||||||
raise MemoryError()
|
|
||||||
|
|
||||||
def __dealloc__(self):
|
|
||||||
PyMem_Free(self._cparser)
|
|
||||||
PyMem_Free(self._csettings)
|
|
||||||
|
|
||||||
cdef _init(
|
|
||||||
self, cparser.llhttp_type mode,
|
|
||||||
object protocol, object loop, int limit,
|
|
||||||
object timer=None,
|
|
||||||
size_t max_line_size=8190, size_t max_headers=32768,
|
|
||||||
size_t max_field_size=8190, payload_exception=None,
|
|
||||||
bint response_with_body=True, bint read_until_eof=False,
|
|
||||||
bint auto_decompress=True,
|
|
||||||
):
|
|
||||||
cparser.llhttp_settings_init(self._csettings)
|
|
||||||
cparser.llhttp_init(self._cparser, mode, self._csettings)
|
|
||||||
self._cparser.data = <void*>self
|
|
||||||
self._cparser.content_length = 0
|
|
||||||
|
|
||||||
self._protocol = protocol
|
|
||||||
self._loop = loop
|
|
||||||
self._timer = timer
|
|
||||||
|
|
||||||
self._buf = bytearray()
|
|
||||||
self._payload = None
|
|
||||||
self._payload_error = 0
|
|
||||||
self._payload_exception = payload_exception
|
|
||||||
self._messages = []
|
|
||||||
|
|
||||||
self._raw_name = EMPTY_BYTES
|
|
||||||
self._raw_value = EMPTY_BYTES
|
|
||||||
self._has_value = False
|
|
||||||
|
|
||||||
self._max_line_size = max_line_size
|
|
||||||
self._max_headers = max_headers
|
|
||||||
self._max_field_size = max_field_size
|
|
||||||
self._response_with_body = response_with_body
|
|
||||||
self._read_until_eof = read_until_eof
|
|
||||||
self._upgraded = False
|
|
||||||
self._auto_decompress = auto_decompress
|
|
||||||
self._content_encoding = None
|
|
||||||
|
|
||||||
self._csettings.on_url = cb_on_url
|
|
||||||
self._csettings.on_status = cb_on_status
|
|
||||||
self._csettings.on_header_field = cb_on_header_field
|
|
||||||
self._csettings.on_header_value = cb_on_header_value
|
|
||||||
self._csettings.on_headers_complete = cb_on_headers_complete
|
|
||||||
self._csettings.on_body = cb_on_body
|
|
||||||
self._csettings.on_message_begin = cb_on_message_begin
|
|
||||||
self._csettings.on_message_complete = cb_on_message_complete
|
|
||||||
self._csettings.on_chunk_header = cb_on_chunk_header
|
|
||||||
self._csettings.on_chunk_complete = cb_on_chunk_complete
|
|
||||||
|
|
||||||
self._last_error = None
|
|
||||||
self._limit = limit
|
|
||||||
|
|
||||||
cdef _process_header(self):
|
|
||||||
cdef str value
|
|
||||||
if self._raw_name is not EMPTY_BYTES:
|
|
||||||
name = find_header(self._raw_name)
|
|
||||||
value = self._raw_value.decode('utf-8', 'surrogateescape')
|
|
||||||
|
|
||||||
self._headers.append((name, value))
|
|
||||||
|
|
||||||
if name is CONTENT_ENCODING:
|
|
||||||
self._content_encoding = value
|
|
||||||
|
|
||||||
self._has_value = False
|
|
||||||
self._raw_headers.append((self._raw_name, self._raw_value))
|
|
||||||
self._raw_name = EMPTY_BYTES
|
|
||||||
self._raw_value = EMPTY_BYTES
|
|
||||||
|
|
||||||
cdef _on_header_field(self, char* at, size_t length):
|
|
||||||
if self._has_value:
|
|
||||||
self._process_header()
|
|
||||||
|
|
||||||
if self._raw_name is EMPTY_BYTES:
|
|
||||||
self._raw_name = at[:length]
|
|
||||||
else:
|
|
||||||
self._raw_name += at[:length]
|
|
||||||
|
|
||||||
cdef _on_header_value(self, char* at, size_t length):
|
|
||||||
if self._raw_value is EMPTY_BYTES:
|
|
||||||
self._raw_value = at[:length]
|
|
||||||
else:
|
|
||||||
self._raw_value += at[:length]
|
|
||||||
self._has_value = True
|
|
||||||
|
|
||||||
cdef _on_headers_complete(self):
|
|
||||||
self._process_header()
|
|
||||||
|
|
||||||
should_close = not cparser.llhttp_should_keep_alive(self._cparser)
|
|
||||||
upgrade = self._cparser.upgrade
|
|
||||||
chunked = self._cparser.flags & cparser.F_CHUNKED
|
|
||||||
|
|
||||||
raw_headers = tuple(self._raw_headers)
|
|
||||||
headers = CIMultiDictProxy(CIMultiDict(self._headers))
|
|
||||||
|
|
||||||
if self._cparser.type == cparser.HTTP_REQUEST:
|
|
||||||
allowed = upgrade and headers.get("upgrade", "").lower() in ALLOWED_UPGRADES
|
|
||||||
if allowed or self._cparser.method == cparser.HTTP_CONNECT:
|
|
||||||
self._upgraded = True
|
|
||||||
else:
|
|
||||||
if upgrade and self._cparser.status_code == 101:
|
|
||||||
self._upgraded = True
|
|
||||||
|
|
||||||
# do not support old websocket spec
|
|
||||||
if SEC_WEBSOCKET_KEY1 in headers:
|
|
||||||
raise InvalidHeader(SEC_WEBSOCKET_KEY1)
|
|
||||||
|
|
||||||
encoding = None
|
|
||||||
enc = self._content_encoding
|
|
||||||
if enc is not None:
|
|
||||||
self._content_encoding = None
|
|
||||||
enc = enc.lower()
|
|
||||||
if enc in ('gzip', 'deflate', 'br', 'zstd'):
|
|
||||||
encoding = enc
|
|
||||||
|
|
||||||
if self._cparser.type == cparser.HTTP_REQUEST:
|
|
||||||
method = http_method_str(self._cparser.method)
|
|
||||||
msg = _new_request_message(
|
|
||||||
method, self._path,
|
|
||||||
self.http_version(), headers, raw_headers,
|
|
||||||
should_close, encoding, upgrade, chunked, self._url)
|
|
||||||
else:
|
|
||||||
msg = _new_response_message(
|
|
||||||
self.http_version(), self._cparser.status_code, self._reason,
|
|
||||||
headers, raw_headers, should_close, encoding,
|
|
||||||
upgrade, chunked)
|
|
||||||
|
|
||||||
if (
|
|
||||||
ULLONG_MAX > self._cparser.content_length > 0 or chunked or
|
|
||||||
self._cparser.method == cparser.HTTP_CONNECT or
|
|
||||||
(self._cparser.status_code >= 199 and
|
|
||||||
self._cparser.content_length == 0 and
|
|
||||||
self._read_until_eof)
|
|
||||||
):
|
|
||||||
payload = StreamReader(
|
|
||||||
self._protocol, timer=self._timer, loop=self._loop,
|
|
||||||
limit=self._limit)
|
|
||||||
else:
|
|
||||||
payload = EMPTY_PAYLOAD
|
|
||||||
|
|
||||||
self._payload = payload
|
|
||||||
if encoding is not None and self._auto_decompress:
|
|
||||||
self._payload = DeflateBuffer(payload, encoding)
|
|
||||||
|
|
||||||
if not self._response_with_body:
|
|
||||||
payload = EMPTY_PAYLOAD
|
|
||||||
|
|
||||||
self._messages.append((msg, payload))
|
|
||||||
|
|
||||||
cdef _on_message_complete(self):
|
|
||||||
self._payload.feed_eof()
|
|
||||||
self._payload = None
|
|
||||||
|
|
||||||
cdef _on_chunk_header(self):
|
|
||||||
self._payload.begin_http_chunk_receiving()
|
|
||||||
|
|
||||||
cdef _on_chunk_complete(self):
|
|
||||||
self._payload.end_http_chunk_receiving()
|
|
||||||
|
|
||||||
cdef object _on_status_complete(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
cdef inline http_version(self):
|
|
||||||
cdef cparser.llhttp_t* parser = self._cparser
|
|
||||||
|
|
||||||
if parser.http_major == 1:
|
|
||||||
if parser.http_minor == 0:
|
|
||||||
return HttpVersion10
|
|
||||||
elif parser.http_minor == 1:
|
|
||||||
return HttpVersion11
|
|
||||||
|
|
||||||
return HttpVersion(parser.http_major, parser.http_minor)
|
|
||||||
|
|
||||||
### Public API ###
|
|
||||||
|
|
||||||
def feed_eof(self):
|
|
||||||
cdef bytes desc
|
|
||||||
|
|
||||||
if self._payload is not None:
|
|
||||||
if self._cparser.flags & cparser.F_CHUNKED:
|
|
||||||
raise TransferEncodingError(
|
|
||||||
"Not enough data to satisfy transfer length header.")
|
|
||||||
elif self._cparser.flags & cparser.F_CONTENT_LENGTH:
|
|
||||||
raise ContentLengthError(
|
|
||||||
"Not enough data to satisfy content length header.")
|
|
||||||
elif cparser.llhttp_get_errno(self._cparser) != cparser.HPE_OK:
|
|
||||||
desc = cparser.llhttp_get_error_reason(self._cparser)
|
|
||||||
raise PayloadEncodingError(desc.decode('latin-1'))
|
|
||||||
else:
|
|
||||||
self._payload.feed_eof()
|
|
||||||
elif self._started:
|
|
||||||
self._on_headers_complete()
|
|
||||||
if self._messages:
|
|
||||||
return self._messages[-1][0]
|
|
||||||
|
|
||||||
def feed_data(self, data):
|
|
||||||
cdef:
|
|
||||||
size_t data_len
|
|
||||||
size_t nb
|
|
||||||
cdef cparser.llhttp_errno_t errno
|
|
||||||
|
|
||||||
PyObject_GetBuffer(data, &self.py_buf, PyBUF_SIMPLE)
|
|
||||||
data_len = <size_t>self.py_buf.len
|
|
||||||
|
|
||||||
errno = cparser.llhttp_execute(
|
|
||||||
self._cparser,
|
|
||||||
<char*>self.py_buf.buf,
|
|
||||||
data_len)
|
|
||||||
|
|
||||||
if errno is cparser.HPE_PAUSED_UPGRADE:
|
|
||||||
cparser.llhttp_resume_after_upgrade(self._cparser)
|
|
||||||
|
|
||||||
nb = cparser.llhttp_get_error_pos(self._cparser) - <char*>self.py_buf.buf
|
|
||||||
|
|
||||||
PyBuffer_Release(&self.py_buf)
|
|
||||||
|
|
||||||
if errno not in (cparser.HPE_OK, cparser.HPE_PAUSED_UPGRADE):
|
|
||||||
if self._payload_error == 0:
|
|
||||||
if self._last_error is not None:
|
|
||||||
ex = self._last_error
|
|
||||||
self._last_error = None
|
|
||||||
else:
|
|
||||||
after = cparser.llhttp_get_error_pos(self._cparser)
|
|
||||||
before = data[:after - <char*>self.py_buf.buf]
|
|
||||||
after_b = after.split(b"\r\n", 1)[0]
|
|
||||||
before = before.rsplit(b"\r\n", 1)[-1]
|
|
||||||
data = before + after_b
|
|
||||||
pointer = " " * (len(repr(before))-1) + "^"
|
|
||||||
ex = parser_error_from_errno(self._cparser, data, pointer)
|
|
||||||
self._payload = None
|
|
||||||
raise ex
|
|
||||||
|
|
||||||
if self._messages:
|
|
||||||
messages = self._messages
|
|
||||||
self._messages = []
|
|
||||||
else:
|
|
||||||
messages = ()
|
|
||||||
|
|
||||||
if self._upgraded:
|
|
||||||
return messages, True, data[nb:]
|
|
||||||
else:
|
|
||||||
return messages, False, b""
|
|
||||||
|
|
||||||
def set_upgraded(self, val):
|
|
||||||
self._upgraded = val
|
|
||||||
|
|
||||||
|
|
||||||
cdef class HttpRequestParser(HttpParser):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, protocol, loop, int limit, timer=None,
|
|
||||||
size_t max_line_size=8190, size_t max_headers=32768,
|
|
||||||
size_t max_field_size=8190, payload_exception=None,
|
|
||||||
bint response_with_body=True, bint read_until_eof=False,
|
|
||||||
bint auto_decompress=True,
|
|
||||||
):
|
|
||||||
self._init(cparser.HTTP_REQUEST, protocol, loop, limit, timer,
|
|
||||||
max_line_size, max_headers, max_field_size,
|
|
||||||
payload_exception, response_with_body, read_until_eof,
|
|
||||||
auto_decompress)
|
|
||||||
|
|
||||||
cdef object _on_status_complete(self):
|
|
||||||
cdef int idx1, idx2
|
|
||||||
if not self._buf:
|
|
||||||
return
|
|
||||||
self._path = self._buf.decode('utf-8', 'surrogateescape')
|
|
||||||
try:
|
|
||||||
idx3 = len(self._path)
|
|
||||||
if self._cparser.method == cparser.HTTP_CONNECT:
|
|
||||||
# authority-form,
|
|
||||||
# https://datatracker.ietf.org/doc/html/rfc7230#section-5.3.3
|
|
||||||
self._url = URL.build(authority=self._path, encoded=True)
|
|
||||||
elif idx3 > 1 and self._path[0] == '/':
|
|
||||||
# origin-form,
|
|
||||||
# https://datatracker.ietf.org/doc/html/rfc7230#section-5.3.1
|
|
||||||
idx1 = self._path.find("?")
|
|
||||||
if idx1 == -1:
|
|
||||||
query = ""
|
|
||||||
idx2 = self._path.find("#")
|
|
||||||
if idx2 == -1:
|
|
||||||
path = self._path
|
|
||||||
fragment = ""
|
|
||||||
else:
|
|
||||||
path = self._path[0: idx2]
|
|
||||||
fragment = self._path[idx2+1:]
|
|
||||||
|
|
||||||
else:
|
|
||||||
path = self._path[0:idx1]
|
|
||||||
idx1 += 1
|
|
||||||
idx2 = self._path.find("#", idx1+1)
|
|
||||||
if idx2 == -1:
|
|
||||||
query = self._path[idx1:]
|
|
||||||
fragment = ""
|
|
||||||
else:
|
|
||||||
query = self._path[idx1: idx2]
|
|
||||||
fragment = self._path[idx2+1:]
|
|
||||||
|
|
||||||
self._url = URL.build(
|
|
||||||
path=path,
|
|
||||||
query_string=query,
|
|
||||||
fragment=fragment,
|
|
||||||
encoded=True,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# absolute-form for proxy maybe,
|
|
||||||
# https://datatracker.ietf.org/doc/html/rfc7230#section-5.3.2
|
|
||||||
self._url = URL(self._path, encoded=True)
|
|
||||||
finally:
|
|
||||||
PyByteArray_Resize(self._buf, 0)
|
|
||||||
|
|
||||||
|
|
||||||
cdef class HttpResponseParser(HttpParser):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, protocol, loop, int limit, timer=None,
|
|
||||||
size_t max_line_size=8190, size_t max_headers=32768,
|
|
||||||
size_t max_field_size=8190, payload_exception=None,
|
|
||||||
bint response_with_body=True, bint read_until_eof=False,
|
|
||||||
bint auto_decompress=True
|
|
||||||
):
|
|
||||||
self._init(cparser.HTTP_RESPONSE, protocol, loop, limit, timer,
|
|
||||||
max_line_size, max_headers, max_field_size,
|
|
||||||
payload_exception, response_with_body, read_until_eof,
|
|
||||||
auto_decompress)
|
|
||||||
# Use strict parsing on dev mode, so users are warned about broken servers.
|
|
||||||
if not DEBUG:
|
|
||||||
cparser.llhttp_set_lenient_headers(self._cparser, 1)
|
|
||||||
cparser.llhttp_set_lenient_optional_cr_before_lf(self._cparser, 1)
|
|
||||||
cparser.llhttp_set_lenient_spaces_after_chunk_size(self._cparser, 1)
|
|
||||||
|
|
||||||
cdef object _on_status_complete(self):
|
|
||||||
if self._buf:
|
|
||||||
self._reason = self._buf.decode('utf-8', 'surrogateescape')
|
|
||||||
PyByteArray_Resize(self._buf, 0)
|
|
||||||
else:
|
|
||||||
self._reason = self._reason or ''
|
|
||||||
|
|
||||||
cdef int cb_on_message_begin(cparser.llhttp_t* parser) except -1:
|
|
||||||
cdef HttpParser pyparser = <HttpParser>parser.data
|
|
||||||
|
|
||||||
pyparser._started = True
|
|
||||||
pyparser._headers = []
|
|
||||||
pyparser._raw_headers = []
|
|
||||||
PyByteArray_Resize(pyparser._buf, 0)
|
|
||||||
pyparser._path = None
|
|
||||||
pyparser._reason = None
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
cdef int cb_on_url(cparser.llhttp_t* parser,
|
|
||||||
const char *at, size_t length) except -1:
|
|
||||||
cdef HttpParser pyparser = <HttpParser>parser.data
|
|
||||||
try:
|
|
||||||
if length > pyparser._max_line_size:
|
|
||||||
raise LineTooLong(
|
|
||||||
'Status line is too long', pyparser._max_line_size, length)
|
|
||||||
extend(pyparser._buf, at, length)
|
|
||||||
except BaseException as ex:
|
|
||||||
pyparser._last_error = ex
|
|
||||||
return -1
|
|
||||||
else:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
cdef int cb_on_status(cparser.llhttp_t* parser,
|
|
||||||
const char *at, size_t length) except -1:
|
|
||||||
cdef HttpParser pyparser = <HttpParser>parser.data
|
|
||||||
cdef str reason
|
|
||||||
try:
|
|
||||||
if length > pyparser._max_line_size:
|
|
||||||
raise LineTooLong(
|
|
||||||
'Status line is too long', pyparser._max_line_size, length)
|
|
||||||
extend(pyparser._buf, at, length)
|
|
||||||
except BaseException as ex:
|
|
||||||
pyparser._last_error = ex
|
|
||||||
return -1
|
|
||||||
else:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
cdef int cb_on_header_field(cparser.llhttp_t* parser,
|
|
||||||
const char *at, size_t length) except -1:
|
|
||||||
cdef HttpParser pyparser = <HttpParser>parser.data
|
|
||||||
cdef Py_ssize_t size
|
|
||||||
try:
|
|
||||||
pyparser._on_status_complete()
|
|
||||||
size = len(pyparser._raw_name) + length
|
|
||||||
if size > pyparser._max_field_size:
|
|
||||||
raise LineTooLong(
|
|
||||||
'Header name is too long', pyparser._max_field_size, size)
|
|
||||||
pyparser._on_header_field(at, length)
|
|
||||||
except BaseException as ex:
|
|
||||||
pyparser._last_error = ex
|
|
||||||
return -1
|
|
||||||
else:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
cdef int cb_on_header_value(cparser.llhttp_t* parser,
|
|
||||||
const char *at, size_t length) except -1:
|
|
||||||
cdef HttpParser pyparser = <HttpParser>parser.data
|
|
||||||
cdef Py_ssize_t size
|
|
||||||
try:
|
|
||||||
size = len(pyparser._raw_value) + length
|
|
||||||
if size > pyparser._max_field_size:
|
|
||||||
raise LineTooLong(
|
|
||||||
'Header value is too long', pyparser._max_field_size, size)
|
|
||||||
pyparser._on_header_value(at, length)
|
|
||||||
except BaseException as ex:
|
|
||||||
pyparser._last_error = ex
|
|
||||||
return -1
|
|
||||||
else:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
cdef int cb_on_headers_complete(cparser.llhttp_t* parser) except -1:
|
|
||||||
cdef HttpParser pyparser = <HttpParser>parser.data
|
|
||||||
try:
|
|
||||||
pyparser._on_status_complete()
|
|
||||||
pyparser._on_headers_complete()
|
|
||||||
except BaseException as exc:
|
|
||||||
pyparser._last_error = exc
|
|
||||||
return -1
|
|
||||||
else:
|
|
||||||
if pyparser._upgraded or pyparser._cparser.method == cparser.HTTP_CONNECT:
|
|
||||||
return 2
|
|
||||||
else:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
cdef int cb_on_body(cparser.llhttp_t* parser,
|
|
||||||
const char *at, size_t length) except -1:
|
|
||||||
cdef HttpParser pyparser = <HttpParser>parser.data
|
|
||||||
cdef bytes body = at[:length]
|
|
||||||
try:
|
|
||||||
pyparser._payload.feed_data(body, length)
|
|
||||||
except BaseException as underlying_exc:
|
|
||||||
reraised_exc = underlying_exc
|
|
||||||
if pyparser._payload_exception is not None:
|
|
||||||
reraised_exc = pyparser._payload_exception(str(underlying_exc))
|
|
||||||
|
|
||||||
set_exception(pyparser._payload, reraised_exc, underlying_exc)
|
|
||||||
|
|
||||||
pyparser._payload_error = 1
|
|
||||||
return -1
|
|
||||||
else:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
cdef int cb_on_message_complete(cparser.llhttp_t* parser) except -1:
|
|
||||||
cdef HttpParser pyparser = <HttpParser>parser.data
|
|
||||||
try:
|
|
||||||
pyparser._started = False
|
|
||||||
pyparser._on_message_complete()
|
|
||||||
except BaseException as exc:
|
|
||||||
pyparser._last_error = exc
|
|
||||||
return -1
|
|
||||||
else:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
cdef int cb_on_chunk_header(cparser.llhttp_t* parser) except -1:
|
|
||||||
cdef HttpParser pyparser = <HttpParser>parser.data
|
|
||||||
try:
|
|
||||||
pyparser._on_chunk_header()
|
|
||||||
except BaseException as exc:
|
|
||||||
pyparser._last_error = exc
|
|
||||||
return -1
|
|
||||||
else:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
cdef int cb_on_chunk_complete(cparser.llhttp_t* parser) except -1:
|
|
||||||
cdef HttpParser pyparser = <HttpParser>parser.data
|
|
||||||
try:
|
|
||||||
pyparser._on_chunk_complete()
|
|
||||||
except BaseException as exc:
|
|
||||||
pyparser._last_error = exc
|
|
||||||
return -1
|
|
||||||
else:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
cdef parser_error_from_errno(cparser.llhttp_t* parser, data, pointer):
|
|
||||||
cdef cparser.llhttp_errno_t errno = cparser.llhttp_get_errno(parser)
|
|
||||||
cdef bytes desc = cparser.llhttp_get_error_reason(parser)
|
|
||||||
|
|
||||||
err_msg = "{}:\n\n {!r}\n {}".format(desc.decode("latin-1"), data, pointer)
|
|
||||||
|
|
||||||
if errno in {cparser.HPE_CB_MESSAGE_BEGIN,
|
|
||||||
cparser.HPE_CB_HEADERS_COMPLETE,
|
|
||||||
cparser.HPE_CB_MESSAGE_COMPLETE,
|
|
||||||
cparser.HPE_CB_CHUNK_HEADER,
|
|
||||||
cparser.HPE_CB_CHUNK_COMPLETE,
|
|
||||||
cparser.HPE_INVALID_CONSTANT,
|
|
||||||
cparser.HPE_INVALID_HEADER_TOKEN,
|
|
||||||
cparser.HPE_INVALID_CONTENT_LENGTH,
|
|
||||||
cparser.HPE_INVALID_CHUNK_SIZE,
|
|
||||||
cparser.HPE_INVALID_EOF_STATE,
|
|
||||||
cparser.HPE_INVALID_TRANSFER_ENCODING}:
|
|
||||||
return BadHttpMessage(err_msg)
|
|
||||||
elif errno == cparser.HPE_INVALID_METHOD:
|
|
||||||
return BadHttpMethod(error=err_msg)
|
|
||||||
elif errno in {cparser.HPE_INVALID_STATUS,
|
|
||||||
cparser.HPE_INVALID_VERSION}:
|
|
||||||
return BadStatusLine(error=err_msg)
|
|
||||||
elif errno == cparser.HPE_INVALID_URL:
|
|
||||||
return InvalidURLError(err_msg)
|
|
||||||
|
|
||||||
return BadHttpMessage(err_msg)
|
|
||||||
Binary file not shown.
|
|
@ -1,162 +0,0 @@
|
||||||
from cpython.bytes cimport PyBytes_FromStringAndSize
|
|
||||||
from cpython.exc cimport PyErr_NoMemory
|
|
||||||
from cpython.mem cimport PyMem_Free, PyMem_Malloc, PyMem_Realloc
|
|
||||||
from cpython.object cimport PyObject_Str
|
|
||||||
from libc.stdint cimport uint8_t, uint64_t
|
|
||||||
from libc.string cimport memcpy
|
|
||||||
|
|
||||||
from multidict import istr
|
|
||||||
|
|
||||||
DEF BUF_SIZE = 16 * 1024 # 16KiB
|
|
||||||
|
|
||||||
cdef object _istr = istr
|
|
||||||
|
|
||||||
|
|
||||||
# ----------------- writer ---------------------------
|
|
||||||
|
|
||||||
cdef struct Writer:
|
|
||||||
char *buf
|
|
||||||
Py_ssize_t size
|
|
||||||
Py_ssize_t pos
|
|
||||||
bint heap_allocated
|
|
||||||
|
|
||||||
cdef inline void _init_writer(Writer* writer, char *buf):
|
|
||||||
writer.buf = buf
|
|
||||||
writer.size = BUF_SIZE
|
|
||||||
writer.pos = 0
|
|
||||||
writer.heap_allocated = 0
|
|
||||||
|
|
||||||
|
|
||||||
cdef inline void _release_writer(Writer* writer):
|
|
||||||
if writer.heap_allocated:
|
|
||||||
PyMem_Free(writer.buf)
|
|
||||||
|
|
||||||
|
|
||||||
cdef inline int _write_byte(Writer* writer, uint8_t ch):
|
|
||||||
cdef char * buf
|
|
||||||
cdef Py_ssize_t size
|
|
||||||
|
|
||||||
if writer.pos == writer.size:
|
|
||||||
# reallocate
|
|
||||||
size = writer.size + BUF_SIZE
|
|
||||||
if not writer.heap_allocated:
|
|
||||||
buf = <char*>PyMem_Malloc(size)
|
|
||||||
if buf == NULL:
|
|
||||||
PyErr_NoMemory()
|
|
||||||
return -1
|
|
||||||
memcpy(buf, writer.buf, writer.size)
|
|
||||||
else:
|
|
||||||
buf = <char*>PyMem_Realloc(writer.buf, size)
|
|
||||||
if buf == NULL:
|
|
||||||
PyErr_NoMemory()
|
|
||||||
return -1
|
|
||||||
writer.buf = buf
|
|
||||||
writer.size = size
|
|
||||||
writer.heap_allocated = 1
|
|
||||||
writer.buf[writer.pos] = <char>ch
|
|
||||||
writer.pos += 1
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
cdef inline int _write_utf8(Writer* writer, Py_UCS4 symbol):
|
|
||||||
cdef uint64_t utf = <uint64_t> symbol
|
|
||||||
|
|
||||||
if utf < 0x80:
|
|
||||||
return _write_byte(writer, <uint8_t>utf)
|
|
||||||
elif utf < 0x800:
|
|
||||||
if _write_byte(writer, <uint8_t>(0xc0 | (utf >> 6))) < 0:
|
|
||||||
return -1
|
|
||||||
return _write_byte(writer, <uint8_t>(0x80 | (utf & 0x3f)))
|
|
||||||
elif 0xD800 <= utf <= 0xDFFF:
|
|
||||||
# surogate pair, ignored
|
|
||||||
return 0
|
|
||||||
elif utf < 0x10000:
|
|
||||||
if _write_byte(writer, <uint8_t>(0xe0 | (utf >> 12))) < 0:
|
|
||||||
return -1
|
|
||||||
if _write_byte(writer, <uint8_t>(0x80 | ((utf >> 6) & 0x3f))) < 0:
|
|
||||||
return -1
|
|
||||||
return _write_byte(writer, <uint8_t>(0x80 | (utf & 0x3f)))
|
|
||||||
elif utf > 0x10FFFF:
|
|
||||||
# symbol is too large
|
|
||||||
return 0
|
|
||||||
else:
|
|
||||||
if _write_byte(writer, <uint8_t>(0xf0 | (utf >> 18))) < 0:
|
|
||||||
return -1
|
|
||||||
if _write_byte(writer,
|
|
||||||
<uint8_t>(0x80 | ((utf >> 12) & 0x3f))) < 0:
|
|
||||||
return -1
|
|
||||||
if _write_byte(writer,
|
|
||||||
<uint8_t>(0x80 | ((utf >> 6) & 0x3f))) < 0:
|
|
||||||
return -1
|
|
||||||
return _write_byte(writer, <uint8_t>(0x80 | (utf & 0x3f)))
|
|
||||||
|
|
||||||
|
|
||||||
cdef inline int _write_str(Writer* writer, str s):
|
|
||||||
cdef Py_UCS4 ch
|
|
||||||
for ch in s:
|
|
||||||
if _write_utf8(writer, ch) < 0:
|
|
||||||
return -1
|
|
||||||
|
|
||||||
|
|
||||||
cdef inline int _write_str_raise_on_nlcr(Writer* writer, object s):
|
|
||||||
cdef Py_UCS4 ch
|
|
||||||
cdef str out_str
|
|
||||||
if type(s) is str:
|
|
||||||
out_str = <str>s
|
|
||||||
elif type(s) is _istr:
|
|
||||||
out_str = PyObject_Str(s)
|
|
||||||
elif not isinstance(s, str):
|
|
||||||
raise TypeError("Cannot serialize non-str key {!r}".format(s))
|
|
||||||
else:
|
|
||||||
out_str = str(s)
|
|
||||||
|
|
||||||
for ch in out_str:
|
|
||||||
if ch == 0x0D or ch == 0x0A:
|
|
||||||
raise ValueError(
|
|
||||||
"Newline or carriage return detected in headers. "
|
|
||||||
"Potential header injection attack."
|
|
||||||
)
|
|
||||||
if _write_utf8(writer, ch) < 0:
|
|
||||||
return -1
|
|
||||||
|
|
||||||
|
|
||||||
# --------------- _serialize_headers ----------------------
|
|
||||||
|
|
||||||
def _serialize_headers(str status_line, headers):
|
|
||||||
cdef Writer writer
|
|
||||||
cdef object key
|
|
||||||
cdef object val
|
|
||||||
cdef char buf[BUF_SIZE]
|
|
||||||
|
|
||||||
_init_writer(&writer, buf)
|
|
||||||
|
|
||||||
try:
|
|
||||||
if _write_str(&writer, status_line) < 0:
|
|
||||||
raise
|
|
||||||
if _write_byte(&writer, b'\r') < 0:
|
|
||||||
raise
|
|
||||||
if _write_byte(&writer, b'\n') < 0:
|
|
||||||
raise
|
|
||||||
|
|
||||||
for key, val in headers.items():
|
|
||||||
if _write_str_raise_on_nlcr(&writer, key) < 0:
|
|
||||||
raise
|
|
||||||
if _write_byte(&writer, b':') < 0:
|
|
||||||
raise
|
|
||||||
if _write_byte(&writer, b' ') < 0:
|
|
||||||
raise
|
|
||||||
if _write_str_raise_on_nlcr(&writer, val) < 0:
|
|
||||||
raise
|
|
||||||
if _write_byte(&writer, b'\r') < 0:
|
|
||||||
raise
|
|
||||||
if _write_byte(&writer, b'\n') < 0:
|
|
||||||
raise
|
|
||||||
|
|
||||||
if _write_byte(&writer, b'\r') < 0:
|
|
||||||
raise
|
|
||||||
if _write_byte(&writer, b'\n') < 0:
|
|
||||||
raise
|
|
||||||
|
|
||||||
return PyBytes_FromStringAndSize(writer.buf, writer.pos)
|
|
||||||
finally:
|
|
||||||
_release_writer(&writer)
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
b01999d409b29bd916e067bc963d5f2d9ee63cfc9ae0bccb769910131417bf93 /home/runner/work/aiohttp/aiohttp/aiohttp/_websocket/mask.pxd
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
0478ceb55d0ed30ef1a7da742cd003449bc69a07cf9fdb06789bd2b347cbfffe /home/runner/work/aiohttp/aiohttp/aiohttp/_websocket/mask.pyx
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
9e5fe78ed0ebce5414d2b8e01868d90c1facc20b84d2d5ff6c23e86e44a155ae /home/runner/work/aiohttp/aiohttp/aiohttp/_websocket/reader_c.pxd
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
"""WebSocket protocol versions 13 and 8."""
|
|
||||||
|
|
@ -1,147 +0,0 @@
|
||||||
"""Helpers for WebSocket protocol versions 13 and 8."""
|
|
||||||
|
|
||||||
import functools
|
|
||||||
import re
|
|
||||||
from struct import Struct
|
|
||||||
from typing import TYPE_CHECKING, Final, List, Optional, Pattern, Tuple
|
|
||||||
|
|
||||||
from ..helpers import NO_EXTENSIONS
|
|
||||||
from .models import WSHandshakeError
|
|
||||||
|
|
||||||
UNPACK_LEN3 = Struct("!Q").unpack_from
|
|
||||||
UNPACK_CLOSE_CODE = Struct("!H").unpack
|
|
||||||
PACK_LEN1 = Struct("!BB").pack
|
|
||||||
PACK_LEN2 = Struct("!BBH").pack
|
|
||||||
PACK_LEN3 = Struct("!BBQ").pack
|
|
||||||
PACK_CLOSE_CODE = Struct("!H").pack
|
|
||||||
PACK_RANDBITS = Struct("!L").pack
|
|
||||||
MSG_SIZE: Final[int] = 2**14
|
|
||||||
MASK_LEN: Final[int] = 4
|
|
||||||
|
|
||||||
WS_KEY: Final[bytes] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
|
|
||||||
|
|
||||||
|
|
||||||
# Used by _websocket_mask_python
|
|
||||||
@functools.lru_cache
|
|
||||||
def _xor_table() -> List[bytes]:
|
|
||||||
return [bytes(a ^ b for a in range(256)) for b in range(256)]
|
|
||||||
|
|
||||||
|
|
||||||
def _websocket_mask_python(mask: bytes, data: bytearray) -> None:
|
|
||||||
"""Websocket masking function.
|
|
||||||
|
|
||||||
`mask` is a `bytes` object of length 4; `data` is a `bytearray`
|
|
||||||
object of any length. The contents of `data` are masked with `mask`,
|
|
||||||
as specified in section 5.3 of RFC 6455.
|
|
||||||
|
|
||||||
Note that this function mutates the `data` argument.
|
|
||||||
|
|
||||||
This pure-python implementation may be replaced by an optimized
|
|
||||||
version when available.
|
|
||||||
|
|
||||||
"""
|
|
||||||
assert isinstance(data, bytearray), data
|
|
||||||
assert len(mask) == 4, mask
|
|
||||||
|
|
||||||
if data:
|
|
||||||
_XOR_TABLE = _xor_table()
|
|
||||||
a, b, c, d = (_XOR_TABLE[n] for n in mask)
|
|
||||||
data[::4] = data[::4].translate(a)
|
|
||||||
data[1::4] = data[1::4].translate(b)
|
|
||||||
data[2::4] = data[2::4].translate(c)
|
|
||||||
data[3::4] = data[3::4].translate(d)
|
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING or NO_EXTENSIONS: # pragma: no cover
|
|
||||||
websocket_mask = _websocket_mask_python
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
from .mask import _websocket_mask_cython # type: ignore[import-not-found]
|
|
||||||
|
|
||||||
websocket_mask = _websocket_mask_cython
|
|
||||||
except ImportError: # pragma: no cover
|
|
||||||
websocket_mask = _websocket_mask_python
|
|
||||||
|
|
||||||
|
|
||||||
_WS_EXT_RE: Final[Pattern[str]] = re.compile(
|
|
||||||
r"^(?:;\s*(?:"
|
|
||||||
r"(server_no_context_takeover)|"
|
|
||||||
r"(client_no_context_takeover)|"
|
|
||||||
r"(server_max_window_bits(?:=(\d+))?)|"
|
|
||||||
r"(client_max_window_bits(?:=(\d+))?)))*$"
|
|
||||||
)
|
|
||||||
|
|
||||||
_WS_EXT_RE_SPLIT: Final[Pattern[str]] = re.compile(r"permessage-deflate([^,]+)?")
|
|
||||||
|
|
||||||
|
|
||||||
def ws_ext_parse(extstr: Optional[str], isserver: bool = False) -> Tuple[int, bool]:
|
|
||||||
if not extstr:
|
|
||||||
return 0, False
|
|
||||||
|
|
||||||
compress = 0
|
|
||||||
notakeover = False
|
|
||||||
for ext in _WS_EXT_RE_SPLIT.finditer(extstr):
|
|
||||||
defext = ext.group(1)
|
|
||||||
# Return compress = 15 when get `permessage-deflate`
|
|
||||||
if not defext:
|
|
||||||
compress = 15
|
|
||||||
break
|
|
||||||
match = _WS_EXT_RE.match(defext)
|
|
||||||
if match:
|
|
||||||
compress = 15
|
|
||||||
if isserver:
|
|
||||||
# Server never fail to detect compress handshake.
|
|
||||||
# Server does not need to send max wbit to client
|
|
||||||
if match.group(4):
|
|
||||||
compress = int(match.group(4))
|
|
||||||
# Group3 must match if group4 matches
|
|
||||||
# Compress wbit 8 does not support in zlib
|
|
||||||
# If compress level not support,
|
|
||||||
# CONTINUE to next extension
|
|
||||||
if compress > 15 or compress < 9:
|
|
||||||
compress = 0
|
|
||||||
continue
|
|
||||||
if match.group(1):
|
|
||||||
notakeover = True
|
|
||||||
# Ignore regex group 5 & 6 for client_max_window_bits
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
if match.group(6):
|
|
||||||
compress = int(match.group(6))
|
|
||||||
# Group5 must match if group6 matches
|
|
||||||
# Compress wbit 8 does not support in zlib
|
|
||||||
# If compress level not support,
|
|
||||||
# FAIL the parse progress
|
|
||||||
if compress > 15 or compress < 9:
|
|
||||||
raise WSHandshakeError("Invalid window size")
|
|
||||||
if match.group(2):
|
|
||||||
notakeover = True
|
|
||||||
# Ignore regex group 5 & 6 for client_max_window_bits
|
|
||||||
break
|
|
||||||
# Return Fail if client side and not match
|
|
||||||
elif not isserver:
|
|
||||||
raise WSHandshakeError("Extension for deflate not supported" + ext.group(1))
|
|
||||||
|
|
||||||
return compress, notakeover
|
|
||||||
|
|
||||||
|
|
||||||
def ws_ext_gen(
|
|
||||||
compress: int = 15, isserver: bool = False, server_notakeover: bool = False
|
|
||||||
) -> str:
|
|
||||||
# client_notakeover=False not used for server
|
|
||||||
# compress wbit 8 does not support in zlib
|
|
||||||
if compress < 9 or compress > 15:
|
|
||||||
raise ValueError(
|
|
||||||
"Compress wbits must between 9 and 15, zlib does not support wbits=8"
|
|
||||||
)
|
|
||||||
enabledext = ["permessage-deflate"]
|
|
||||||
if not isserver:
|
|
||||||
enabledext.append("client_max_window_bits")
|
|
||||||
|
|
||||||
if compress < 15:
|
|
||||||
enabledext.append("server_max_window_bits=" + str(compress))
|
|
||||||
if server_notakeover:
|
|
||||||
enabledext.append("server_no_context_takeover")
|
|
||||||
# if client_notakeover:
|
|
||||||
# enabledext.append('client_no_context_takeover')
|
|
||||||
return "; ".join(enabledext)
|
|
||||||
Binary file not shown.
|
|
@ -1,3 +0,0 @@
|
||||||
"""Cython declarations for websocket masking."""
|
|
||||||
|
|
||||||
cpdef void _websocket_mask_cython(bytes mask, bytearray data)
|
|
||||||
|
|
@ -1,48 +0,0 @@
|
||||||
from cpython cimport PyBytes_AsString
|
|
||||||
|
|
||||||
|
|
||||||
#from cpython cimport PyByteArray_AsString # cython still not exports that
|
|
||||||
cdef extern from "Python.h":
|
|
||||||
char* PyByteArray_AsString(bytearray ba) except NULL
|
|
||||||
|
|
||||||
from libc.stdint cimport uint32_t, uint64_t, uintmax_t
|
|
||||||
|
|
||||||
|
|
||||||
cpdef void _websocket_mask_cython(bytes mask, bytearray data):
|
|
||||||
"""Note, this function mutates its `data` argument
|
|
||||||
"""
|
|
||||||
cdef:
|
|
||||||
Py_ssize_t data_len, i
|
|
||||||
# bit operations on signed integers are implementation-specific
|
|
||||||
unsigned char * in_buf
|
|
||||||
const unsigned char * mask_buf
|
|
||||||
uint32_t uint32_msk
|
|
||||||
uint64_t uint64_msk
|
|
||||||
|
|
||||||
assert len(mask) == 4
|
|
||||||
|
|
||||||
data_len = len(data)
|
|
||||||
in_buf = <unsigned char*>PyByteArray_AsString(data)
|
|
||||||
mask_buf = <const unsigned char*>PyBytes_AsString(mask)
|
|
||||||
uint32_msk = (<uint32_t*>mask_buf)[0]
|
|
||||||
|
|
||||||
# TODO: align in_data ptr to achieve even faster speeds
|
|
||||||
# does it need in python ?! malloc() always aligns to sizeof(long) bytes
|
|
||||||
|
|
||||||
if sizeof(size_t) >= 8:
|
|
||||||
uint64_msk = uint32_msk
|
|
||||||
uint64_msk = (uint64_msk << 32) | uint32_msk
|
|
||||||
|
|
||||||
while data_len >= 8:
|
|
||||||
(<uint64_t*>in_buf)[0] ^= uint64_msk
|
|
||||||
in_buf += 8
|
|
||||||
data_len -= 8
|
|
||||||
|
|
||||||
|
|
||||||
while data_len >= 4:
|
|
||||||
(<uint32_t*>in_buf)[0] ^= uint32_msk
|
|
||||||
in_buf += 4
|
|
||||||
data_len -= 4
|
|
||||||
|
|
||||||
for i in range(0, data_len):
|
|
||||||
in_buf[i] ^= mask_buf[i]
|
|
||||||
|
|
@ -1,84 +0,0 @@
|
||||||
"""Models for WebSocket protocol versions 13 and 8."""
|
|
||||||
|
|
||||||
import json
|
|
||||||
from enum import IntEnum
|
|
||||||
from typing import Any, Callable, Final, NamedTuple, Optional, cast
|
|
||||||
|
|
||||||
WS_DEFLATE_TRAILING: Final[bytes] = bytes([0x00, 0x00, 0xFF, 0xFF])
|
|
||||||
|
|
||||||
|
|
||||||
class WSCloseCode(IntEnum):
|
|
||||||
OK = 1000
|
|
||||||
GOING_AWAY = 1001
|
|
||||||
PROTOCOL_ERROR = 1002
|
|
||||||
UNSUPPORTED_DATA = 1003
|
|
||||||
ABNORMAL_CLOSURE = 1006
|
|
||||||
INVALID_TEXT = 1007
|
|
||||||
POLICY_VIOLATION = 1008
|
|
||||||
MESSAGE_TOO_BIG = 1009
|
|
||||||
MANDATORY_EXTENSION = 1010
|
|
||||||
INTERNAL_ERROR = 1011
|
|
||||||
SERVICE_RESTART = 1012
|
|
||||||
TRY_AGAIN_LATER = 1013
|
|
||||||
BAD_GATEWAY = 1014
|
|
||||||
|
|
||||||
|
|
||||||
class WSMsgType(IntEnum):
|
|
||||||
# websocket spec types
|
|
||||||
CONTINUATION = 0x0
|
|
||||||
TEXT = 0x1
|
|
||||||
BINARY = 0x2
|
|
||||||
PING = 0x9
|
|
||||||
PONG = 0xA
|
|
||||||
CLOSE = 0x8
|
|
||||||
|
|
||||||
# aiohttp specific types
|
|
||||||
CLOSING = 0x100
|
|
||||||
CLOSED = 0x101
|
|
||||||
ERROR = 0x102
|
|
||||||
|
|
||||||
text = TEXT
|
|
||||||
binary = BINARY
|
|
||||||
ping = PING
|
|
||||||
pong = PONG
|
|
||||||
close = CLOSE
|
|
||||||
closing = CLOSING
|
|
||||||
closed = CLOSED
|
|
||||||
error = ERROR
|
|
||||||
|
|
||||||
|
|
||||||
class WSMessage(NamedTuple):
|
|
||||||
type: WSMsgType
|
|
||||||
# To type correctly, this would need some kind of tagged union for each type.
|
|
||||||
data: Any
|
|
||||||
extra: Optional[str]
|
|
||||||
|
|
||||||
def json(self, *, loads: Callable[[Any], Any] = json.loads) -> Any:
|
|
||||||
"""Return parsed JSON data.
|
|
||||||
|
|
||||||
.. versionadded:: 0.22
|
|
||||||
"""
|
|
||||||
return loads(self.data)
|
|
||||||
|
|
||||||
|
|
||||||
# Constructing the tuple directly to avoid the overhead of
|
|
||||||
# the lambda and arg processing since NamedTuples are constructed
|
|
||||||
# with a run time built lambda
|
|
||||||
# https://github.com/python/cpython/blob/d83fcf8371f2f33c7797bc8f5423a8bca8c46e5c/Lib/collections/__init__.py#L441
|
|
||||||
WS_CLOSED_MESSAGE = tuple.__new__(WSMessage, (WSMsgType.CLOSED, None, None))
|
|
||||||
WS_CLOSING_MESSAGE = tuple.__new__(WSMessage, (WSMsgType.CLOSING, None, None))
|
|
||||||
|
|
||||||
|
|
||||||
class WebSocketError(Exception):
|
|
||||||
"""WebSocket protocol parser error."""
|
|
||||||
|
|
||||||
def __init__(self, code: int, message: str) -> None:
|
|
||||||
self.code = code
|
|
||||||
super().__init__(code, message)
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return cast(str, self.args[1])
|
|
||||||
|
|
||||||
|
|
||||||
class WSHandshakeError(Exception):
|
|
||||||
"""WebSocket protocol handshake error."""
|
|
||||||
|
|
@ -1,31 +0,0 @@
|
||||||
"""Reader for WebSocket protocol versions 13 and 8."""
|
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
from ..helpers import NO_EXTENSIONS
|
|
||||||
|
|
||||||
if TYPE_CHECKING or NO_EXTENSIONS: # pragma: no cover
|
|
||||||
from .reader_py import (
|
|
||||||
WebSocketDataQueue as WebSocketDataQueuePython,
|
|
||||||
WebSocketReader as WebSocketReaderPython,
|
|
||||||
)
|
|
||||||
|
|
||||||
WebSocketReader = WebSocketReaderPython
|
|
||||||
WebSocketDataQueue = WebSocketDataQueuePython
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
from .reader_c import ( # type: ignore[import-not-found]
|
|
||||||
WebSocketDataQueue as WebSocketDataQueueCython,
|
|
||||||
WebSocketReader as WebSocketReaderCython,
|
|
||||||
)
|
|
||||||
|
|
||||||
WebSocketReader = WebSocketReaderCython
|
|
||||||
WebSocketDataQueue = WebSocketDataQueueCython
|
|
||||||
except ImportError: # pragma: no cover
|
|
||||||
from .reader_py import (
|
|
||||||
WebSocketDataQueue as WebSocketDataQueuePython,
|
|
||||||
WebSocketReader as WebSocketReaderPython,
|
|
||||||
)
|
|
||||||
|
|
||||||
WebSocketReader = WebSocketReaderPython
|
|
||||||
WebSocketDataQueue = WebSocketDataQueuePython
|
|
||||||
Binary file not shown.
|
|
@ -1,110 +0,0 @@
|
||||||
import cython
|
|
||||||
|
|
||||||
from .mask cimport _websocket_mask_cython as websocket_mask
|
|
||||||
|
|
||||||
|
|
||||||
cdef unsigned int READ_HEADER
|
|
||||||
cdef unsigned int READ_PAYLOAD_LENGTH
|
|
||||||
cdef unsigned int READ_PAYLOAD_MASK
|
|
||||||
cdef unsigned int READ_PAYLOAD
|
|
||||||
|
|
||||||
cdef int OP_CODE_NOT_SET
|
|
||||||
cdef int OP_CODE_CONTINUATION
|
|
||||||
cdef int OP_CODE_TEXT
|
|
||||||
cdef int OP_CODE_BINARY
|
|
||||||
cdef int OP_CODE_CLOSE
|
|
||||||
cdef int OP_CODE_PING
|
|
||||||
cdef int OP_CODE_PONG
|
|
||||||
|
|
||||||
cdef int COMPRESSED_NOT_SET
|
|
||||||
cdef int COMPRESSED_FALSE
|
|
||||||
cdef int COMPRESSED_TRUE
|
|
||||||
|
|
||||||
cdef object UNPACK_LEN3
|
|
||||||
cdef object UNPACK_CLOSE_CODE
|
|
||||||
cdef object TUPLE_NEW
|
|
||||||
|
|
||||||
cdef object WSMsgType
|
|
||||||
cdef object WSMessage
|
|
||||||
|
|
||||||
cdef object WS_MSG_TYPE_TEXT
|
|
||||||
cdef object WS_MSG_TYPE_BINARY
|
|
||||||
|
|
||||||
cdef set ALLOWED_CLOSE_CODES
|
|
||||||
cdef set MESSAGE_TYPES_WITH_CONTENT
|
|
||||||
|
|
||||||
cdef tuple EMPTY_FRAME
|
|
||||||
cdef tuple EMPTY_FRAME_ERROR
|
|
||||||
|
|
||||||
cdef class WebSocketDataQueue:
|
|
||||||
|
|
||||||
cdef unsigned int _size
|
|
||||||
cdef public object _protocol
|
|
||||||
cdef unsigned int _limit
|
|
||||||
cdef object _loop
|
|
||||||
cdef bint _eof
|
|
||||||
cdef object _waiter
|
|
||||||
cdef object _exception
|
|
||||||
cdef public object _buffer
|
|
||||||
cdef object _get_buffer
|
|
||||||
cdef object _put_buffer
|
|
||||||
|
|
||||||
cdef void _release_waiter(self)
|
|
||||||
|
|
||||||
cpdef void feed_data(self, object data, unsigned int size)
|
|
||||||
|
|
||||||
@cython.locals(size="unsigned int")
|
|
||||||
cdef _read_from_buffer(self)
|
|
||||||
|
|
||||||
cdef class WebSocketReader:
|
|
||||||
|
|
||||||
cdef WebSocketDataQueue queue
|
|
||||||
cdef unsigned int _max_msg_size
|
|
||||||
|
|
||||||
cdef Exception _exc
|
|
||||||
cdef bytearray _partial
|
|
||||||
cdef unsigned int _state
|
|
||||||
|
|
||||||
cdef int _opcode
|
|
||||||
cdef bint _frame_fin
|
|
||||||
cdef int _frame_opcode
|
|
||||||
cdef list _payload_fragments
|
|
||||||
cdef Py_ssize_t _frame_payload_len
|
|
||||||
|
|
||||||
cdef bytes _tail
|
|
||||||
cdef bint _has_mask
|
|
||||||
cdef bytes _frame_mask
|
|
||||||
cdef Py_ssize_t _payload_bytes_to_read
|
|
||||||
cdef unsigned int _payload_len_flag
|
|
||||||
cdef int _compressed
|
|
||||||
cdef object _decompressobj
|
|
||||||
cdef bint _compress
|
|
||||||
|
|
||||||
cpdef tuple feed_data(self, object data)
|
|
||||||
|
|
||||||
@cython.locals(
|
|
||||||
is_continuation=bint,
|
|
||||||
fin=bint,
|
|
||||||
has_partial=bint,
|
|
||||||
payload_merged=bytes,
|
|
||||||
)
|
|
||||||
cpdef void _handle_frame(self, bint fin, int opcode, object payload, int compressed) except *
|
|
||||||
|
|
||||||
@cython.locals(
|
|
||||||
start_pos=Py_ssize_t,
|
|
||||||
data_len=Py_ssize_t,
|
|
||||||
length=Py_ssize_t,
|
|
||||||
chunk_size=Py_ssize_t,
|
|
||||||
chunk_len=Py_ssize_t,
|
|
||||||
data_len=Py_ssize_t,
|
|
||||||
data_cstr="const unsigned char *",
|
|
||||||
first_byte="unsigned char",
|
|
||||||
second_byte="unsigned char",
|
|
||||||
f_start_pos=Py_ssize_t,
|
|
||||||
f_end_pos=Py_ssize_t,
|
|
||||||
has_mask=bint,
|
|
||||||
fin=bint,
|
|
||||||
had_fragments=Py_ssize_t,
|
|
||||||
payload_bytearray=bytearray,
|
|
||||||
)
|
|
||||||
cpdef void _feed_data(self, bytes data) except *
|
|
||||||
|
|
@ -1,476 +0,0 @@
|
||||||
"""Reader for WebSocket protocol versions 13 and 8."""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import builtins
|
|
||||||
from collections import deque
|
|
||||||
from typing import Deque, Final, Optional, Set, Tuple, Union
|
|
||||||
|
|
||||||
from ..base_protocol import BaseProtocol
|
|
||||||
from ..compression_utils import ZLibDecompressor
|
|
||||||
from ..helpers import _EXC_SENTINEL, set_exception
|
|
||||||
from ..streams import EofStream
|
|
||||||
from .helpers import UNPACK_CLOSE_CODE, UNPACK_LEN3, websocket_mask
|
|
||||||
from .models import (
|
|
||||||
WS_DEFLATE_TRAILING,
|
|
||||||
WebSocketError,
|
|
||||||
WSCloseCode,
|
|
||||||
WSMessage,
|
|
||||||
WSMsgType,
|
|
||||||
)
|
|
||||||
|
|
||||||
ALLOWED_CLOSE_CODES: Final[Set[int]] = {int(i) for i in WSCloseCode}
|
|
||||||
|
|
||||||
# States for the reader, used to parse the WebSocket frame
|
|
||||||
# integer values are used so they can be cythonized
|
|
||||||
READ_HEADER = 1
|
|
||||||
READ_PAYLOAD_LENGTH = 2
|
|
||||||
READ_PAYLOAD_MASK = 3
|
|
||||||
READ_PAYLOAD = 4
|
|
||||||
|
|
||||||
WS_MSG_TYPE_BINARY = WSMsgType.BINARY
|
|
||||||
WS_MSG_TYPE_TEXT = WSMsgType.TEXT
|
|
||||||
|
|
||||||
# WSMsgType values unpacked so they can by cythonized to ints
|
|
||||||
OP_CODE_NOT_SET = -1
|
|
||||||
OP_CODE_CONTINUATION = WSMsgType.CONTINUATION.value
|
|
||||||
OP_CODE_TEXT = WSMsgType.TEXT.value
|
|
||||||
OP_CODE_BINARY = WSMsgType.BINARY.value
|
|
||||||
OP_CODE_CLOSE = WSMsgType.CLOSE.value
|
|
||||||
OP_CODE_PING = WSMsgType.PING.value
|
|
||||||
OP_CODE_PONG = WSMsgType.PONG.value
|
|
||||||
|
|
||||||
EMPTY_FRAME_ERROR = (True, b"")
|
|
||||||
EMPTY_FRAME = (False, b"")
|
|
||||||
|
|
||||||
COMPRESSED_NOT_SET = -1
|
|
||||||
COMPRESSED_FALSE = 0
|
|
||||||
COMPRESSED_TRUE = 1
|
|
||||||
|
|
||||||
TUPLE_NEW = tuple.__new__
|
|
||||||
|
|
||||||
cython_int = int # Typed to int in Python, but cython with use a signed int in the pxd
|
|
||||||
|
|
||||||
|
|
||||||
class WebSocketDataQueue:
|
|
||||||
"""WebSocketDataQueue resumes and pauses an underlying stream.
|
|
||||||
|
|
||||||
It is a destination for WebSocket data.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop
|
|
||||||
) -> None:
|
|
||||||
self._size = 0
|
|
||||||
self._protocol = protocol
|
|
||||||
self._limit = limit * 2
|
|
||||||
self._loop = loop
|
|
||||||
self._eof = False
|
|
||||||
self._waiter: Optional[asyncio.Future[None]] = None
|
|
||||||
self._exception: Union[BaseException, None] = None
|
|
||||||
self._buffer: Deque[Tuple[WSMessage, int]] = deque()
|
|
||||||
self._get_buffer = self._buffer.popleft
|
|
||||||
self._put_buffer = self._buffer.append
|
|
||||||
|
|
||||||
def is_eof(self) -> bool:
|
|
||||||
return self._eof
|
|
||||||
|
|
||||||
def exception(self) -> Optional[BaseException]:
|
|
||||||
return self._exception
|
|
||||||
|
|
||||||
def set_exception(
|
|
||||||
self,
|
|
||||||
exc: BaseException,
|
|
||||||
exc_cause: builtins.BaseException = _EXC_SENTINEL,
|
|
||||||
) -> None:
|
|
||||||
self._eof = True
|
|
||||||
self._exception = exc
|
|
||||||
if (waiter := self._waiter) is not None:
|
|
||||||
self._waiter = None
|
|
||||||
set_exception(waiter, exc, exc_cause)
|
|
||||||
|
|
||||||
def _release_waiter(self) -> None:
|
|
||||||
if (waiter := self._waiter) is None:
|
|
||||||
return
|
|
||||||
self._waiter = None
|
|
||||||
if not waiter.done():
|
|
||||||
waiter.set_result(None)
|
|
||||||
|
|
||||||
def feed_eof(self) -> None:
|
|
||||||
self._eof = True
|
|
||||||
self._release_waiter()
|
|
||||||
self._exception = None # Break cyclic references
|
|
||||||
|
|
||||||
def feed_data(self, data: "WSMessage", size: "cython_int") -> None:
|
|
||||||
self._size += size
|
|
||||||
self._put_buffer((data, size))
|
|
||||||
self._release_waiter()
|
|
||||||
if self._size > self._limit and not self._protocol._reading_paused:
|
|
||||||
self._protocol.pause_reading()
|
|
||||||
|
|
||||||
async def read(self) -> WSMessage:
|
|
||||||
if not self._buffer and not self._eof:
|
|
||||||
assert not self._waiter
|
|
||||||
self._waiter = self._loop.create_future()
|
|
||||||
try:
|
|
||||||
await self._waiter
|
|
||||||
except (asyncio.CancelledError, asyncio.TimeoutError):
|
|
||||||
self._waiter = None
|
|
||||||
raise
|
|
||||||
return self._read_from_buffer()
|
|
||||||
|
|
||||||
def _read_from_buffer(self) -> WSMessage:
|
|
||||||
if self._buffer:
|
|
||||||
data, size = self._get_buffer()
|
|
||||||
self._size -= size
|
|
||||||
if self._size < self._limit and self._protocol._reading_paused:
|
|
||||||
self._protocol.resume_reading()
|
|
||||||
return data
|
|
||||||
if self._exception is not None:
|
|
||||||
raise self._exception
|
|
||||||
raise EofStream
|
|
||||||
|
|
||||||
|
|
||||||
class WebSocketReader:
|
|
||||||
def __init__(
|
|
||||||
self, queue: WebSocketDataQueue, max_msg_size: int, compress: bool = True
|
|
||||||
) -> None:
|
|
||||||
self.queue = queue
|
|
||||||
self._max_msg_size = max_msg_size
|
|
||||||
|
|
||||||
self._exc: Optional[Exception] = None
|
|
||||||
self._partial = bytearray()
|
|
||||||
self._state = READ_HEADER
|
|
||||||
|
|
||||||
self._opcode: int = OP_CODE_NOT_SET
|
|
||||||
self._frame_fin = False
|
|
||||||
self._frame_opcode: int = OP_CODE_NOT_SET
|
|
||||||
self._payload_fragments: list[bytes] = []
|
|
||||||
self._frame_payload_len = 0
|
|
||||||
|
|
||||||
self._tail: bytes = b""
|
|
||||||
self._has_mask = False
|
|
||||||
self._frame_mask: Optional[bytes] = None
|
|
||||||
self._payload_bytes_to_read = 0
|
|
||||||
self._payload_len_flag = 0
|
|
||||||
self._compressed: int = COMPRESSED_NOT_SET
|
|
||||||
self._decompressobj: Optional[ZLibDecompressor] = None
|
|
||||||
self._compress = compress
|
|
||||||
|
|
||||||
def feed_eof(self) -> None:
|
|
||||||
self.queue.feed_eof()
|
|
||||||
|
|
||||||
# data can be bytearray on Windows because proactor event loop uses bytearray
|
|
||||||
# and asyncio types this to Union[bytes, bytearray, memoryview] so we need
|
|
||||||
# coerce data to bytes if it is not
|
|
||||||
def feed_data(
|
|
||||||
self, data: Union[bytes, bytearray, memoryview]
|
|
||||||
) -> Tuple[bool, bytes]:
|
|
||||||
if type(data) is not bytes:
|
|
||||||
data = bytes(data)
|
|
||||||
|
|
||||||
if self._exc is not None:
|
|
||||||
return True, data
|
|
||||||
|
|
||||||
try:
|
|
||||||
self._feed_data(data)
|
|
||||||
except Exception as exc:
|
|
||||||
self._exc = exc
|
|
||||||
set_exception(self.queue, exc)
|
|
||||||
return EMPTY_FRAME_ERROR
|
|
||||||
|
|
||||||
return EMPTY_FRAME
|
|
||||||
|
|
||||||
def _handle_frame(
|
|
||||||
self,
|
|
||||||
fin: bool,
|
|
||||||
opcode: Union[int, cython_int], # Union intended: Cython pxd uses C int
|
|
||||||
payload: Union[bytes, bytearray],
|
|
||||||
compressed: Union[int, cython_int], # Union intended: Cython pxd uses C int
|
|
||||||
) -> None:
|
|
||||||
msg: WSMessage
|
|
||||||
if opcode in {OP_CODE_TEXT, OP_CODE_BINARY, OP_CODE_CONTINUATION}:
|
|
||||||
# load text/binary
|
|
||||||
if not fin:
|
|
||||||
# got partial frame payload
|
|
||||||
if opcode != OP_CODE_CONTINUATION:
|
|
||||||
self._opcode = opcode
|
|
||||||
self._partial += payload
|
|
||||||
if self._max_msg_size and len(self._partial) >= self._max_msg_size:
|
|
||||||
raise WebSocketError(
|
|
||||||
WSCloseCode.MESSAGE_TOO_BIG,
|
|
||||||
f"Message size {len(self._partial)} "
|
|
||||||
f"exceeds limit {self._max_msg_size}",
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
has_partial = bool(self._partial)
|
|
||||||
if opcode == OP_CODE_CONTINUATION:
|
|
||||||
if self._opcode == OP_CODE_NOT_SET:
|
|
||||||
raise WebSocketError(
|
|
||||||
WSCloseCode.PROTOCOL_ERROR,
|
|
||||||
"Continuation frame for non started message",
|
|
||||||
)
|
|
||||||
opcode = self._opcode
|
|
||||||
self._opcode = OP_CODE_NOT_SET
|
|
||||||
# previous frame was non finished
|
|
||||||
# we should get continuation opcode
|
|
||||||
elif has_partial:
|
|
||||||
raise WebSocketError(
|
|
||||||
WSCloseCode.PROTOCOL_ERROR,
|
|
||||||
"The opcode in non-fin frame is expected "
|
|
||||||
f"to be zero, got {opcode!r}",
|
|
||||||
)
|
|
||||||
|
|
||||||
assembled_payload: Union[bytes, bytearray]
|
|
||||||
if has_partial:
|
|
||||||
assembled_payload = self._partial + payload
|
|
||||||
self._partial.clear()
|
|
||||||
else:
|
|
||||||
assembled_payload = payload
|
|
||||||
|
|
||||||
if self._max_msg_size and len(assembled_payload) >= self._max_msg_size:
|
|
||||||
raise WebSocketError(
|
|
||||||
WSCloseCode.MESSAGE_TOO_BIG,
|
|
||||||
f"Message size {len(assembled_payload)} "
|
|
||||||
f"exceeds limit {self._max_msg_size}",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Decompress process must to be done after all packets
|
|
||||||
# received.
|
|
||||||
if compressed:
|
|
||||||
if not self._decompressobj:
|
|
||||||
self._decompressobj = ZLibDecompressor(suppress_deflate_header=True)
|
|
||||||
# XXX: It's possible that the zlib backend (isal is known to
|
|
||||||
# do this, maybe others too?) will return max_length bytes,
|
|
||||||
# but internally buffer more data such that the payload is
|
|
||||||
# >max_length, so we return one extra byte and if we're able
|
|
||||||
# to do that, then the message is too big.
|
|
||||||
payload_merged = self._decompressobj.decompress_sync(
|
|
||||||
assembled_payload + WS_DEFLATE_TRAILING,
|
|
||||||
(
|
|
||||||
self._max_msg_size + 1
|
|
||||||
if self._max_msg_size
|
|
||||||
else self._max_msg_size
|
|
||||||
),
|
|
||||||
)
|
|
||||||
if self._max_msg_size and len(payload_merged) > self._max_msg_size:
|
|
||||||
raise WebSocketError(
|
|
||||||
WSCloseCode.MESSAGE_TOO_BIG,
|
|
||||||
f"Decompressed message exceeds size limit {self._max_msg_size}",
|
|
||||||
)
|
|
||||||
elif type(assembled_payload) is bytes:
|
|
||||||
payload_merged = assembled_payload
|
|
||||||
else:
|
|
||||||
payload_merged = bytes(assembled_payload)
|
|
||||||
|
|
||||||
if opcode == OP_CODE_TEXT:
|
|
||||||
try:
|
|
||||||
text = payload_merged.decode("utf-8")
|
|
||||||
except UnicodeDecodeError as exc:
|
|
||||||
raise WebSocketError(
|
|
||||||
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
# XXX: The Text and Binary messages here can be a performance
|
|
||||||
# bottleneck, so we use tuple.__new__ to improve performance.
|
|
||||||
# This is not type safe, but many tests should fail in
|
|
||||||
# test_client_ws_functional.py if this is wrong.
|
|
||||||
self.queue.feed_data(
|
|
||||||
TUPLE_NEW(WSMessage, (WS_MSG_TYPE_TEXT, text, "")),
|
|
||||||
len(payload_merged),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.queue.feed_data(
|
|
||||||
TUPLE_NEW(WSMessage, (WS_MSG_TYPE_BINARY, payload_merged, "")),
|
|
||||||
len(payload_merged),
|
|
||||||
)
|
|
||||||
elif opcode == OP_CODE_CLOSE:
|
|
||||||
if len(payload) >= 2:
|
|
||||||
close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
|
|
||||||
if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES:
|
|
||||||
raise WebSocketError(
|
|
||||||
WSCloseCode.PROTOCOL_ERROR,
|
|
||||||
f"Invalid close code: {close_code}",
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
close_message = payload[2:].decode("utf-8")
|
|
||||||
except UnicodeDecodeError as exc:
|
|
||||||
raise WebSocketError(
|
|
||||||
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
|
|
||||||
) from exc
|
|
||||||
msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, close_code, close_message))
|
|
||||||
elif payload:
|
|
||||||
raise WebSocketError(
|
|
||||||
WSCloseCode.PROTOCOL_ERROR,
|
|
||||||
f"Invalid close frame: {fin} {opcode} {payload!r}",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, 0, ""))
|
|
||||||
|
|
||||||
self.queue.feed_data(msg, 0)
|
|
||||||
elif opcode == OP_CODE_PING:
|
|
||||||
msg = TUPLE_NEW(WSMessage, (WSMsgType.PING, payload, ""))
|
|
||||||
self.queue.feed_data(msg, len(payload))
|
|
||||||
elif opcode == OP_CODE_PONG:
|
|
||||||
msg = TUPLE_NEW(WSMessage, (WSMsgType.PONG, payload, ""))
|
|
||||||
self.queue.feed_data(msg, len(payload))
|
|
||||||
else:
|
|
||||||
raise WebSocketError(
|
|
||||||
WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _feed_data(self, data: bytes) -> None:
|
|
||||||
"""Return the next frame from the socket."""
|
|
||||||
if self._tail:
|
|
||||||
data, self._tail = self._tail + data, b""
|
|
||||||
|
|
||||||
start_pos: int = 0
|
|
||||||
data_len = len(data)
|
|
||||||
data_cstr = data
|
|
||||||
|
|
||||||
while True:
|
|
||||||
# read header
|
|
||||||
if self._state == READ_HEADER:
|
|
||||||
if data_len - start_pos < 2:
|
|
||||||
break
|
|
||||||
first_byte = data_cstr[start_pos]
|
|
||||||
second_byte = data_cstr[start_pos + 1]
|
|
||||||
start_pos += 2
|
|
||||||
|
|
||||||
fin = (first_byte >> 7) & 1
|
|
||||||
rsv1 = (first_byte >> 6) & 1
|
|
||||||
rsv2 = (first_byte >> 5) & 1
|
|
||||||
rsv3 = (first_byte >> 4) & 1
|
|
||||||
opcode = first_byte & 0xF
|
|
||||||
|
|
||||||
# frame-fin = %x0 ; more frames of this message follow
|
|
||||||
# / %x1 ; final frame of this message
|
|
||||||
# frame-rsv1 = %x0 ;
|
|
||||||
# 1 bit, MUST be 0 unless negotiated otherwise
|
|
||||||
# frame-rsv2 = %x0 ;
|
|
||||||
# 1 bit, MUST be 0 unless negotiated otherwise
|
|
||||||
# frame-rsv3 = %x0 ;
|
|
||||||
# 1 bit, MUST be 0 unless negotiated otherwise
|
|
||||||
#
|
|
||||||
# Remove rsv1 from this test for deflate development
|
|
||||||
if rsv2 or rsv3 or (rsv1 and not self._compress):
|
|
||||||
raise WebSocketError(
|
|
||||||
WSCloseCode.PROTOCOL_ERROR,
|
|
||||||
"Received frame with non-zero reserved bits",
|
|
||||||
)
|
|
||||||
|
|
||||||
if opcode > 0x7 and fin == 0:
|
|
||||||
raise WebSocketError(
|
|
||||||
WSCloseCode.PROTOCOL_ERROR,
|
|
||||||
"Received fragmented control frame",
|
|
||||||
)
|
|
||||||
|
|
||||||
has_mask = (second_byte >> 7) & 1
|
|
||||||
length = second_byte & 0x7F
|
|
||||||
|
|
||||||
# Control frames MUST have a payload
|
|
||||||
# length of 125 bytes or less
|
|
||||||
if opcode > 0x7 and length > 125:
|
|
||||||
raise WebSocketError(
|
|
||||||
WSCloseCode.PROTOCOL_ERROR,
|
|
||||||
"Control frame payload cannot be larger than 125 bytes",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set compress status if last package is FIN
|
|
||||||
# OR set compress status if this is first fragment
|
|
||||||
# Raise error if not first fragment with rsv1 = 0x1
|
|
||||||
if self._frame_fin or self._compressed == COMPRESSED_NOT_SET:
|
|
||||||
self._compressed = COMPRESSED_TRUE if rsv1 else COMPRESSED_FALSE
|
|
||||||
elif rsv1:
|
|
||||||
raise WebSocketError(
|
|
||||||
WSCloseCode.PROTOCOL_ERROR,
|
|
||||||
"Received frame with non-zero reserved bits",
|
|
||||||
)
|
|
||||||
|
|
||||||
self._frame_fin = bool(fin)
|
|
||||||
self._frame_opcode = opcode
|
|
||||||
self._has_mask = bool(has_mask)
|
|
||||||
self._payload_len_flag = length
|
|
||||||
self._state = READ_PAYLOAD_LENGTH
|
|
||||||
|
|
||||||
# read payload length
|
|
||||||
if self._state == READ_PAYLOAD_LENGTH:
|
|
||||||
len_flag = self._payload_len_flag
|
|
||||||
if len_flag == 126:
|
|
||||||
if data_len - start_pos < 2:
|
|
||||||
break
|
|
||||||
first_byte = data_cstr[start_pos]
|
|
||||||
second_byte = data_cstr[start_pos + 1]
|
|
||||||
start_pos += 2
|
|
||||||
self._payload_bytes_to_read = first_byte << 8 | second_byte
|
|
||||||
elif len_flag > 126:
|
|
||||||
if data_len - start_pos < 8:
|
|
||||||
break
|
|
||||||
self._payload_bytes_to_read = UNPACK_LEN3(data, start_pos)[0]
|
|
||||||
start_pos += 8
|
|
||||||
else:
|
|
||||||
self._payload_bytes_to_read = len_flag
|
|
||||||
|
|
||||||
self._state = READ_PAYLOAD_MASK if self._has_mask else READ_PAYLOAD
|
|
||||||
|
|
||||||
# read payload mask
|
|
||||||
if self._state == READ_PAYLOAD_MASK:
|
|
||||||
if data_len - start_pos < 4:
|
|
||||||
break
|
|
||||||
self._frame_mask = data_cstr[start_pos : start_pos + 4]
|
|
||||||
start_pos += 4
|
|
||||||
self._state = READ_PAYLOAD
|
|
||||||
|
|
||||||
if self._state == READ_PAYLOAD:
|
|
||||||
chunk_len = data_len - start_pos
|
|
||||||
if self._payload_bytes_to_read >= chunk_len:
|
|
||||||
f_end_pos = data_len
|
|
||||||
self._payload_bytes_to_read -= chunk_len
|
|
||||||
else:
|
|
||||||
f_end_pos = start_pos + self._payload_bytes_to_read
|
|
||||||
self._payload_bytes_to_read = 0
|
|
||||||
|
|
||||||
had_fragments = self._frame_payload_len
|
|
||||||
self._frame_payload_len += f_end_pos - start_pos
|
|
||||||
f_start_pos = start_pos
|
|
||||||
start_pos = f_end_pos
|
|
||||||
|
|
||||||
if self._payload_bytes_to_read != 0:
|
|
||||||
# If we don't have a complete frame, we need to save the
|
|
||||||
# data for the next call to feed_data.
|
|
||||||
self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
|
|
||||||
break
|
|
||||||
|
|
||||||
payload: Union[bytes, bytearray]
|
|
||||||
if had_fragments:
|
|
||||||
# We have to join the payload fragments get the payload
|
|
||||||
self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
|
|
||||||
if self._has_mask:
|
|
||||||
assert self._frame_mask is not None
|
|
||||||
payload_bytearray = bytearray(b"".join(self._payload_fragments))
|
|
||||||
websocket_mask(self._frame_mask, payload_bytearray)
|
|
||||||
payload = payload_bytearray
|
|
||||||
else:
|
|
||||||
payload = b"".join(self._payload_fragments)
|
|
||||||
self._payload_fragments.clear()
|
|
||||||
elif self._has_mask:
|
|
||||||
assert self._frame_mask is not None
|
|
||||||
payload_bytearray = data_cstr[f_start_pos:f_end_pos] # type: ignore[assignment]
|
|
||||||
if type(payload_bytearray) is not bytearray: # pragma: no branch
|
|
||||||
# Cython will do the conversion for us
|
|
||||||
# but we need to do it for Python and we
|
|
||||||
# will always get here in Python
|
|
||||||
payload_bytearray = bytearray(payload_bytearray)
|
|
||||||
websocket_mask(self._frame_mask, payload_bytearray)
|
|
||||||
payload = payload_bytearray
|
|
||||||
else:
|
|
||||||
payload = data_cstr[f_start_pos:f_end_pos]
|
|
||||||
|
|
||||||
self._handle_frame(
|
|
||||||
self._frame_fin, self._frame_opcode, payload, self._compressed
|
|
||||||
)
|
|
||||||
self._frame_payload_len = 0
|
|
||||||
self._state = READ_HEADER
|
|
||||||
|
|
||||||
# XXX: Cython needs slices to be bounded, so we can't omit the slice end here.
|
|
||||||
self._tail = data_cstr[start_pos:data_len] if start_pos < data_len else b""
|
|
||||||
|
|
@ -1,476 +0,0 @@
|
||||||
"""Reader for WebSocket protocol versions 13 and 8."""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import builtins
|
|
||||||
from collections import deque
|
|
||||||
from typing import Deque, Final, Optional, Set, Tuple, Union
|
|
||||||
|
|
||||||
from ..base_protocol import BaseProtocol
|
|
||||||
from ..compression_utils import ZLibDecompressor
|
|
||||||
from ..helpers import _EXC_SENTINEL, set_exception
|
|
||||||
from ..streams import EofStream
|
|
||||||
from .helpers import UNPACK_CLOSE_CODE, UNPACK_LEN3, websocket_mask
|
|
||||||
from .models import (
|
|
||||||
WS_DEFLATE_TRAILING,
|
|
||||||
WebSocketError,
|
|
||||||
WSCloseCode,
|
|
||||||
WSMessage,
|
|
||||||
WSMsgType,
|
|
||||||
)
|
|
||||||
|
|
||||||
ALLOWED_CLOSE_CODES: Final[Set[int]] = {int(i) for i in WSCloseCode}
|
|
||||||
|
|
||||||
# States for the reader, used to parse the WebSocket frame
|
|
||||||
# integer values are used so they can be cythonized
|
|
||||||
READ_HEADER = 1
|
|
||||||
READ_PAYLOAD_LENGTH = 2
|
|
||||||
READ_PAYLOAD_MASK = 3
|
|
||||||
READ_PAYLOAD = 4
|
|
||||||
|
|
||||||
WS_MSG_TYPE_BINARY = WSMsgType.BINARY
|
|
||||||
WS_MSG_TYPE_TEXT = WSMsgType.TEXT
|
|
||||||
|
|
||||||
# WSMsgType values unpacked so they can by cythonized to ints
|
|
||||||
OP_CODE_NOT_SET = -1
|
|
||||||
OP_CODE_CONTINUATION = WSMsgType.CONTINUATION.value
|
|
||||||
OP_CODE_TEXT = WSMsgType.TEXT.value
|
|
||||||
OP_CODE_BINARY = WSMsgType.BINARY.value
|
|
||||||
OP_CODE_CLOSE = WSMsgType.CLOSE.value
|
|
||||||
OP_CODE_PING = WSMsgType.PING.value
|
|
||||||
OP_CODE_PONG = WSMsgType.PONG.value
|
|
||||||
|
|
||||||
EMPTY_FRAME_ERROR = (True, b"")
|
|
||||||
EMPTY_FRAME = (False, b"")
|
|
||||||
|
|
||||||
COMPRESSED_NOT_SET = -1
|
|
||||||
COMPRESSED_FALSE = 0
|
|
||||||
COMPRESSED_TRUE = 1
|
|
||||||
|
|
||||||
TUPLE_NEW = tuple.__new__
|
|
||||||
|
|
||||||
cython_int = int # Typed to int in Python, but cython with use a signed int in the pxd
|
|
||||||
|
|
||||||
|
|
||||||
class WebSocketDataQueue:
|
|
||||||
"""WebSocketDataQueue resumes and pauses an underlying stream.
|
|
||||||
|
|
||||||
It is a destination for WebSocket data.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop
|
|
||||||
) -> None:
|
|
||||||
self._size = 0
|
|
||||||
self._protocol = protocol
|
|
||||||
self._limit = limit * 2
|
|
||||||
self._loop = loop
|
|
||||||
self._eof = False
|
|
||||||
self._waiter: Optional[asyncio.Future[None]] = None
|
|
||||||
self._exception: Union[BaseException, None] = None
|
|
||||||
self._buffer: Deque[Tuple[WSMessage, int]] = deque()
|
|
||||||
self._get_buffer = self._buffer.popleft
|
|
||||||
self._put_buffer = self._buffer.append
|
|
||||||
|
|
||||||
def is_eof(self) -> bool:
|
|
||||||
return self._eof
|
|
||||||
|
|
||||||
def exception(self) -> Optional[BaseException]:
|
|
||||||
return self._exception
|
|
||||||
|
|
||||||
def set_exception(
|
|
||||||
self,
|
|
||||||
exc: BaseException,
|
|
||||||
exc_cause: builtins.BaseException = _EXC_SENTINEL,
|
|
||||||
) -> None:
|
|
||||||
self._eof = True
|
|
||||||
self._exception = exc
|
|
||||||
if (waiter := self._waiter) is not None:
|
|
||||||
self._waiter = None
|
|
||||||
set_exception(waiter, exc, exc_cause)
|
|
||||||
|
|
||||||
def _release_waiter(self) -> None:
|
|
||||||
if (waiter := self._waiter) is None:
|
|
||||||
return
|
|
||||||
self._waiter = None
|
|
||||||
if not waiter.done():
|
|
||||||
waiter.set_result(None)
|
|
||||||
|
|
||||||
def feed_eof(self) -> None:
|
|
||||||
self._eof = True
|
|
||||||
self._release_waiter()
|
|
||||||
self._exception = None # Break cyclic references
|
|
||||||
|
|
||||||
def feed_data(self, data: "WSMessage", size: "cython_int") -> None:
|
|
||||||
self._size += size
|
|
||||||
self._put_buffer((data, size))
|
|
||||||
self._release_waiter()
|
|
||||||
if self._size > self._limit and not self._protocol._reading_paused:
|
|
||||||
self._protocol.pause_reading()
|
|
||||||
|
|
||||||
async def read(self) -> WSMessage:
|
|
||||||
if not self._buffer and not self._eof:
|
|
||||||
assert not self._waiter
|
|
||||||
self._waiter = self._loop.create_future()
|
|
||||||
try:
|
|
||||||
await self._waiter
|
|
||||||
except (asyncio.CancelledError, asyncio.TimeoutError):
|
|
||||||
self._waiter = None
|
|
||||||
raise
|
|
||||||
return self._read_from_buffer()
|
|
||||||
|
|
||||||
def _read_from_buffer(self) -> WSMessage:
|
|
||||||
if self._buffer:
|
|
||||||
data, size = self._get_buffer()
|
|
||||||
self._size -= size
|
|
||||||
if self._size < self._limit and self._protocol._reading_paused:
|
|
||||||
self._protocol.resume_reading()
|
|
||||||
return data
|
|
||||||
if self._exception is not None:
|
|
||||||
raise self._exception
|
|
||||||
raise EofStream
|
|
||||||
|
|
||||||
|
|
||||||
class WebSocketReader:
|
|
||||||
def __init__(
|
|
||||||
self, queue: WebSocketDataQueue, max_msg_size: int, compress: bool = True
|
|
||||||
) -> None:
|
|
||||||
self.queue = queue
|
|
||||||
self._max_msg_size = max_msg_size
|
|
||||||
|
|
||||||
self._exc: Optional[Exception] = None
|
|
||||||
self._partial = bytearray()
|
|
||||||
self._state = READ_HEADER
|
|
||||||
|
|
||||||
self._opcode: int = OP_CODE_NOT_SET
|
|
||||||
self._frame_fin = False
|
|
||||||
self._frame_opcode: int = OP_CODE_NOT_SET
|
|
||||||
self._payload_fragments: list[bytes] = []
|
|
||||||
self._frame_payload_len = 0
|
|
||||||
|
|
||||||
self._tail: bytes = b""
|
|
||||||
self._has_mask = False
|
|
||||||
self._frame_mask: Optional[bytes] = None
|
|
||||||
self._payload_bytes_to_read = 0
|
|
||||||
self._payload_len_flag = 0
|
|
||||||
self._compressed: int = COMPRESSED_NOT_SET
|
|
||||||
self._decompressobj: Optional[ZLibDecompressor] = None
|
|
||||||
self._compress = compress
|
|
||||||
|
|
||||||
def feed_eof(self) -> None:
|
|
||||||
self.queue.feed_eof()
|
|
||||||
|
|
||||||
# data can be bytearray on Windows because proactor event loop uses bytearray
|
|
||||||
# and asyncio types this to Union[bytes, bytearray, memoryview] so we need
|
|
||||||
# coerce data to bytes if it is not
|
|
||||||
def feed_data(
|
|
||||||
self, data: Union[bytes, bytearray, memoryview]
|
|
||||||
) -> Tuple[bool, bytes]:
|
|
||||||
if type(data) is not bytes:
|
|
||||||
data = bytes(data)
|
|
||||||
|
|
||||||
if self._exc is not None:
|
|
||||||
return True, data
|
|
||||||
|
|
||||||
try:
|
|
||||||
self._feed_data(data)
|
|
||||||
except Exception as exc:
|
|
||||||
self._exc = exc
|
|
||||||
set_exception(self.queue, exc)
|
|
||||||
return EMPTY_FRAME_ERROR
|
|
||||||
|
|
||||||
return EMPTY_FRAME
|
|
||||||
|
|
||||||
def _handle_frame(
|
|
||||||
self,
|
|
||||||
fin: bool,
|
|
||||||
opcode: Union[int, cython_int], # Union intended: Cython pxd uses C int
|
|
||||||
payload: Union[bytes, bytearray],
|
|
||||||
compressed: Union[int, cython_int], # Union intended: Cython pxd uses C int
|
|
||||||
) -> None:
|
|
||||||
msg: WSMessage
|
|
||||||
if opcode in {OP_CODE_TEXT, OP_CODE_BINARY, OP_CODE_CONTINUATION}:
|
|
||||||
# load text/binary
|
|
||||||
if not fin:
|
|
||||||
# got partial frame payload
|
|
||||||
if opcode != OP_CODE_CONTINUATION:
|
|
||||||
self._opcode = opcode
|
|
||||||
self._partial += payload
|
|
||||||
if self._max_msg_size and len(self._partial) >= self._max_msg_size:
|
|
||||||
raise WebSocketError(
|
|
||||||
WSCloseCode.MESSAGE_TOO_BIG,
|
|
||||||
f"Message size {len(self._partial)} "
|
|
||||||
f"exceeds limit {self._max_msg_size}",
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
has_partial = bool(self._partial)
|
|
||||||
if opcode == OP_CODE_CONTINUATION:
|
|
||||||
if self._opcode == OP_CODE_NOT_SET:
|
|
||||||
raise WebSocketError(
|
|
||||||
WSCloseCode.PROTOCOL_ERROR,
|
|
||||||
"Continuation frame for non started message",
|
|
||||||
)
|
|
||||||
opcode = self._opcode
|
|
||||||
self._opcode = OP_CODE_NOT_SET
|
|
||||||
# previous frame was non finished
|
|
||||||
# we should get continuation opcode
|
|
||||||
elif has_partial:
|
|
||||||
raise WebSocketError(
|
|
||||||
WSCloseCode.PROTOCOL_ERROR,
|
|
||||||
"The opcode in non-fin frame is expected "
|
|
||||||
f"to be zero, got {opcode!r}",
|
|
||||||
)
|
|
||||||
|
|
||||||
assembled_payload: Union[bytes, bytearray]
|
|
||||||
if has_partial:
|
|
||||||
assembled_payload = self._partial + payload
|
|
||||||
self._partial.clear()
|
|
||||||
else:
|
|
||||||
assembled_payload = payload
|
|
||||||
|
|
||||||
if self._max_msg_size and len(assembled_payload) >= self._max_msg_size:
|
|
||||||
raise WebSocketError(
|
|
||||||
WSCloseCode.MESSAGE_TOO_BIG,
|
|
||||||
f"Message size {len(assembled_payload)} "
|
|
||||||
f"exceeds limit {self._max_msg_size}",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Decompress process must to be done after all packets
|
|
||||||
# received.
|
|
||||||
if compressed:
|
|
||||||
if not self._decompressobj:
|
|
||||||
self._decompressobj = ZLibDecompressor(suppress_deflate_header=True)
|
|
||||||
# XXX: It's possible that the zlib backend (isal is known to
|
|
||||||
# do this, maybe others too?) will return max_length bytes,
|
|
||||||
# but internally buffer more data such that the payload is
|
|
||||||
# >max_length, so we return one extra byte and if we're able
|
|
||||||
# to do that, then the message is too big.
|
|
||||||
payload_merged = self._decompressobj.decompress_sync(
|
|
||||||
assembled_payload + WS_DEFLATE_TRAILING,
|
|
||||||
(
|
|
||||||
self._max_msg_size + 1
|
|
||||||
if self._max_msg_size
|
|
||||||
else self._max_msg_size
|
|
||||||
),
|
|
||||||
)
|
|
||||||
if self._max_msg_size and len(payload_merged) > self._max_msg_size:
|
|
||||||
raise WebSocketError(
|
|
||||||
WSCloseCode.MESSAGE_TOO_BIG,
|
|
||||||
f"Decompressed message exceeds size limit {self._max_msg_size}",
|
|
||||||
)
|
|
||||||
elif type(assembled_payload) is bytes:
|
|
||||||
payload_merged = assembled_payload
|
|
||||||
else:
|
|
||||||
payload_merged = bytes(assembled_payload)
|
|
||||||
|
|
||||||
if opcode == OP_CODE_TEXT:
|
|
||||||
try:
|
|
||||||
text = payload_merged.decode("utf-8")
|
|
||||||
except UnicodeDecodeError as exc:
|
|
||||||
raise WebSocketError(
|
|
||||||
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
# XXX: The Text and Binary messages here can be a performance
|
|
||||||
# bottleneck, so we use tuple.__new__ to improve performance.
|
|
||||||
# This is not type safe, but many tests should fail in
|
|
||||||
# test_client_ws_functional.py if this is wrong.
|
|
||||||
self.queue.feed_data(
|
|
||||||
TUPLE_NEW(WSMessage, (WS_MSG_TYPE_TEXT, text, "")),
|
|
||||||
len(payload_merged),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.queue.feed_data(
|
|
||||||
TUPLE_NEW(WSMessage, (WS_MSG_TYPE_BINARY, payload_merged, "")),
|
|
||||||
len(payload_merged),
|
|
||||||
)
|
|
||||||
elif opcode == OP_CODE_CLOSE:
|
|
||||||
if len(payload) >= 2:
|
|
||||||
close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
|
|
||||||
if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES:
|
|
||||||
raise WebSocketError(
|
|
||||||
WSCloseCode.PROTOCOL_ERROR,
|
|
||||||
f"Invalid close code: {close_code}",
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
close_message = payload[2:].decode("utf-8")
|
|
||||||
except UnicodeDecodeError as exc:
|
|
||||||
raise WebSocketError(
|
|
||||||
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
|
|
||||||
) from exc
|
|
||||||
msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, close_code, close_message))
|
|
||||||
elif payload:
|
|
||||||
raise WebSocketError(
|
|
||||||
WSCloseCode.PROTOCOL_ERROR,
|
|
||||||
f"Invalid close frame: {fin} {opcode} {payload!r}",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, 0, ""))
|
|
||||||
|
|
||||||
self.queue.feed_data(msg, 0)
|
|
||||||
elif opcode == OP_CODE_PING:
|
|
||||||
msg = TUPLE_NEW(WSMessage, (WSMsgType.PING, payload, ""))
|
|
||||||
self.queue.feed_data(msg, len(payload))
|
|
||||||
elif opcode == OP_CODE_PONG:
|
|
||||||
msg = TUPLE_NEW(WSMessage, (WSMsgType.PONG, payload, ""))
|
|
||||||
self.queue.feed_data(msg, len(payload))
|
|
||||||
else:
|
|
||||||
raise WebSocketError(
|
|
||||||
WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _feed_data(self, data: bytes) -> None:
|
|
||||||
"""Return the next frame from the socket."""
|
|
||||||
if self._tail:
|
|
||||||
data, self._tail = self._tail + data, b""
|
|
||||||
|
|
||||||
start_pos: int = 0
|
|
||||||
data_len = len(data)
|
|
||||||
data_cstr = data
|
|
||||||
|
|
||||||
while True:
|
|
||||||
# read header
|
|
||||||
if self._state == READ_HEADER:
|
|
||||||
if data_len - start_pos < 2:
|
|
||||||
break
|
|
||||||
first_byte = data_cstr[start_pos]
|
|
||||||
second_byte = data_cstr[start_pos + 1]
|
|
||||||
start_pos += 2
|
|
||||||
|
|
||||||
fin = (first_byte >> 7) & 1
|
|
||||||
rsv1 = (first_byte >> 6) & 1
|
|
||||||
rsv2 = (first_byte >> 5) & 1
|
|
||||||
rsv3 = (first_byte >> 4) & 1
|
|
||||||
opcode = first_byte & 0xF
|
|
||||||
|
|
||||||
# frame-fin = %x0 ; more frames of this message follow
|
|
||||||
# / %x1 ; final frame of this message
|
|
||||||
# frame-rsv1 = %x0 ;
|
|
||||||
# 1 bit, MUST be 0 unless negotiated otherwise
|
|
||||||
# frame-rsv2 = %x0 ;
|
|
||||||
# 1 bit, MUST be 0 unless negotiated otherwise
|
|
||||||
# frame-rsv3 = %x0 ;
|
|
||||||
# 1 bit, MUST be 0 unless negotiated otherwise
|
|
||||||
#
|
|
||||||
# Remove rsv1 from this test for deflate development
|
|
||||||
if rsv2 or rsv3 or (rsv1 and not self._compress):
|
|
||||||
raise WebSocketError(
|
|
||||||
WSCloseCode.PROTOCOL_ERROR,
|
|
||||||
"Received frame with non-zero reserved bits",
|
|
||||||
)
|
|
||||||
|
|
||||||
if opcode > 0x7 and fin == 0:
|
|
||||||
raise WebSocketError(
|
|
||||||
WSCloseCode.PROTOCOL_ERROR,
|
|
||||||
"Received fragmented control frame",
|
|
||||||
)
|
|
||||||
|
|
||||||
has_mask = (second_byte >> 7) & 1
|
|
||||||
length = second_byte & 0x7F
|
|
||||||
|
|
||||||
# Control frames MUST have a payload
|
|
||||||
# length of 125 bytes or less
|
|
||||||
if opcode > 0x7 and length > 125:
|
|
||||||
raise WebSocketError(
|
|
||||||
WSCloseCode.PROTOCOL_ERROR,
|
|
||||||
"Control frame payload cannot be larger than 125 bytes",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set compress status if last package is FIN
|
|
||||||
# OR set compress status if this is first fragment
|
|
||||||
# Raise error if not first fragment with rsv1 = 0x1
|
|
||||||
if self._frame_fin or self._compressed == COMPRESSED_NOT_SET:
|
|
||||||
self._compressed = COMPRESSED_TRUE if rsv1 else COMPRESSED_FALSE
|
|
||||||
elif rsv1:
|
|
||||||
raise WebSocketError(
|
|
||||||
WSCloseCode.PROTOCOL_ERROR,
|
|
||||||
"Received frame with non-zero reserved bits",
|
|
||||||
)
|
|
||||||
|
|
||||||
self._frame_fin = bool(fin)
|
|
||||||
self._frame_opcode = opcode
|
|
||||||
self._has_mask = bool(has_mask)
|
|
||||||
self._payload_len_flag = length
|
|
||||||
self._state = READ_PAYLOAD_LENGTH
|
|
||||||
|
|
||||||
# read payload length
|
|
||||||
if self._state == READ_PAYLOAD_LENGTH:
|
|
||||||
len_flag = self._payload_len_flag
|
|
||||||
if len_flag == 126:
|
|
||||||
if data_len - start_pos < 2:
|
|
||||||
break
|
|
||||||
first_byte = data_cstr[start_pos]
|
|
||||||
second_byte = data_cstr[start_pos + 1]
|
|
||||||
start_pos += 2
|
|
||||||
self._payload_bytes_to_read = first_byte << 8 | second_byte
|
|
||||||
elif len_flag > 126:
|
|
||||||
if data_len - start_pos < 8:
|
|
||||||
break
|
|
||||||
self._payload_bytes_to_read = UNPACK_LEN3(data, start_pos)[0]
|
|
||||||
start_pos += 8
|
|
||||||
else:
|
|
||||||
self._payload_bytes_to_read = len_flag
|
|
||||||
|
|
||||||
self._state = READ_PAYLOAD_MASK if self._has_mask else READ_PAYLOAD
|
|
||||||
|
|
||||||
# read payload mask
|
|
||||||
if self._state == READ_PAYLOAD_MASK:
|
|
||||||
if data_len - start_pos < 4:
|
|
||||||
break
|
|
||||||
self._frame_mask = data_cstr[start_pos : start_pos + 4]
|
|
||||||
start_pos += 4
|
|
||||||
self._state = READ_PAYLOAD
|
|
||||||
|
|
||||||
if self._state == READ_PAYLOAD:
|
|
||||||
chunk_len = data_len - start_pos
|
|
||||||
if self._payload_bytes_to_read >= chunk_len:
|
|
||||||
f_end_pos = data_len
|
|
||||||
self._payload_bytes_to_read -= chunk_len
|
|
||||||
else:
|
|
||||||
f_end_pos = start_pos + self._payload_bytes_to_read
|
|
||||||
self._payload_bytes_to_read = 0
|
|
||||||
|
|
||||||
had_fragments = self._frame_payload_len
|
|
||||||
self._frame_payload_len += f_end_pos - start_pos
|
|
||||||
f_start_pos = start_pos
|
|
||||||
start_pos = f_end_pos
|
|
||||||
|
|
||||||
if self._payload_bytes_to_read != 0:
|
|
||||||
# If we don't have a complete frame, we need to save the
|
|
||||||
# data for the next call to feed_data.
|
|
||||||
self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
|
|
||||||
break
|
|
||||||
|
|
||||||
payload: Union[bytes, bytearray]
|
|
||||||
if had_fragments:
|
|
||||||
# We have to join the payload fragments get the payload
|
|
||||||
self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
|
|
||||||
if self._has_mask:
|
|
||||||
assert self._frame_mask is not None
|
|
||||||
payload_bytearray = bytearray(b"".join(self._payload_fragments))
|
|
||||||
websocket_mask(self._frame_mask, payload_bytearray)
|
|
||||||
payload = payload_bytearray
|
|
||||||
else:
|
|
||||||
payload = b"".join(self._payload_fragments)
|
|
||||||
self._payload_fragments.clear()
|
|
||||||
elif self._has_mask:
|
|
||||||
assert self._frame_mask is not None
|
|
||||||
payload_bytearray = data_cstr[f_start_pos:f_end_pos] # type: ignore[assignment]
|
|
||||||
if type(payload_bytearray) is not bytearray: # pragma: no branch
|
|
||||||
# Cython will do the conversion for us
|
|
||||||
# but we need to do it for Python and we
|
|
||||||
# will always get here in Python
|
|
||||||
payload_bytearray = bytearray(payload_bytearray)
|
|
||||||
websocket_mask(self._frame_mask, payload_bytearray)
|
|
||||||
payload = payload_bytearray
|
|
||||||
else:
|
|
||||||
payload = data_cstr[f_start_pos:f_end_pos]
|
|
||||||
|
|
||||||
self._handle_frame(
|
|
||||||
self._frame_fin, self._frame_opcode, payload, self._compressed
|
|
||||||
)
|
|
||||||
self._frame_payload_len = 0
|
|
||||||
self._state = READ_HEADER
|
|
||||||
|
|
||||||
# XXX: Cython needs slices to be bounded, so we can't omit the slice end here.
|
|
||||||
self._tail = data_cstr[start_pos:data_len] if start_pos < data_len else b""
|
|
||||||
|
|
@ -1,262 +0,0 @@
|
||||||
"""WebSocket protocol versions 13 and 8."""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import random
|
|
||||||
import sys
|
|
||||||
from functools import partial
|
|
||||||
from typing import Final, Optional, Set, Union
|
|
||||||
|
|
||||||
from ..base_protocol import BaseProtocol
|
|
||||||
from ..client_exceptions import ClientConnectionResetError
|
|
||||||
from ..compression_utils import ZLibBackend, ZLibCompressor
|
|
||||||
from .helpers import (
|
|
||||||
MASK_LEN,
|
|
||||||
MSG_SIZE,
|
|
||||||
PACK_CLOSE_CODE,
|
|
||||||
PACK_LEN1,
|
|
||||||
PACK_LEN2,
|
|
||||||
PACK_LEN3,
|
|
||||||
PACK_RANDBITS,
|
|
||||||
websocket_mask,
|
|
||||||
)
|
|
||||||
from .models import WS_DEFLATE_TRAILING, WSMsgType
|
|
||||||
|
|
||||||
DEFAULT_LIMIT: Final[int] = 2**16
|
|
||||||
|
|
||||||
# WebSocket opcode boundary: opcodes 0-7 are data frames, 8-15 are control frames
|
|
||||||
# Control frames (ping, pong, close) are never compressed
|
|
||||||
WS_CONTROL_FRAME_OPCODE: Final[int] = 8
|
|
||||||
|
|
||||||
# For websockets, keeping latency low is extremely important as implementations
|
|
||||||
# generally expect to be able to send and receive messages quickly. We use a
|
|
||||||
# larger chunk size to reduce the number of executor calls and avoid task
|
|
||||||
# creation overhead, since both are significant sources of latency when chunks
|
|
||||||
# are small. A size of 16KiB was chosen as a balance between avoiding task
|
|
||||||
# overhead and not blocking the event loop too long with synchronous compression.
|
|
||||||
|
|
||||||
WEBSOCKET_MAX_SYNC_CHUNK_SIZE = 16 * 1024
|
|
||||||
|
|
||||||
|
|
||||||
class WebSocketWriter:
|
|
||||||
"""WebSocket writer.
|
|
||||||
|
|
||||||
The writer is responsible for sending messages to the client. It is
|
|
||||||
created by the protocol when a connection is established. The writer
|
|
||||||
should avoid implementing any application logic and should only be
|
|
||||||
concerned with the low-level details of the WebSocket protocol.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
protocol: BaseProtocol,
|
|
||||||
transport: asyncio.Transport,
|
|
||||||
*,
|
|
||||||
use_mask: bool = False,
|
|
||||||
limit: int = DEFAULT_LIMIT,
|
|
||||||
random: random.Random = random.Random(),
|
|
||||||
compress: int = 0,
|
|
||||||
notakeover: bool = False,
|
|
||||||
) -> None:
|
|
||||||
"""Initialize a WebSocket writer."""
|
|
||||||
self.protocol = protocol
|
|
||||||
self.transport = transport
|
|
||||||
self.use_mask = use_mask
|
|
||||||
self.get_random_bits = partial(random.getrandbits, 32)
|
|
||||||
self.compress = compress
|
|
||||||
self.notakeover = notakeover
|
|
||||||
self._closing = False
|
|
||||||
self._limit = limit
|
|
||||||
self._output_size = 0
|
|
||||||
self._compressobj: Optional[ZLibCompressor] = None
|
|
||||||
self._send_lock = asyncio.Lock()
|
|
||||||
self._background_tasks: Set[asyncio.Task[None]] = set()
|
|
||||||
|
|
||||||
async def send_frame(
|
|
||||||
self, message: bytes, opcode: int, compress: Optional[int] = None
|
|
||||||
) -> None:
|
|
||||||
"""Send a frame over the websocket with message as its payload."""
|
|
||||||
if self._closing and not (opcode & WSMsgType.CLOSE):
|
|
||||||
raise ClientConnectionResetError("Cannot write to closing transport")
|
|
||||||
|
|
||||||
if not (compress or self.compress) or opcode >= WS_CONTROL_FRAME_OPCODE:
|
|
||||||
# Non-compressed frames don't need lock or shield
|
|
||||||
self._write_websocket_frame(message, opcode, 0)
|
|
||||||
elif len(message) <= WEBSOCKET_MAX_SYNC_CHUNK_SIZE:
|
|
||||||
# Small compressed payloads - compress synchronously in event loop
|
|
||||||
# We need the lock even though sync compression has no await points.
|
|
||||||
# This prevents small frames from interleaving with large frames that
|
|
||||||
# compress in the executor, avoiding compressor state corruption.
|
|
||||||
async with self._send_lock:
|
|
||||||
self._send_compressed_frame_sync(message, opcode, compress)
|
|
||||||
else:
|
|
||||||
# Large compressed frames need shield to prevent corruption
|
|
||||||
# For large compressed frames, the entire compress+send
|
|
||||||
# operation must be atomic. If cancelled after compression but
|
|
||||||
# before send, the compressor state would be advanced but data
|
|
||||||
# not sent, corrupting subsequent frames.
|
|
||||||
# Create a task to shield from cancellation
|
|
||||||
# The lock is acquired inside the shielded task so the entire
|
|
||||||
# operation (lock + compress + send) completes atomically.
|
|
||||||
# Use eager_start on Python 3.12+ to avoid scheduling overhead
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
coro = self._send_compressed_frame_async_locked(message, opcode, compress)
|
|
||||||
if sys.version_info >= (3, 12):
|
|
||||||
send_task = asyncio.Task(coro, loop=loop, eager_start=True)
|
|
||||||
else:
|
|
||||||
send_task = loop.create_task(coro)
|
|
||||||
# Keep a strong reference to prevent garbage collection
|
|
||||||
self._background_tasks.add(send_task)
|
|
||||||
send_task.add_done_callback(self._background_tasks.discard)
|
|
||||||
await asyncio.shield(send_task)
|
|
||||||
|
|
||||||
# It is safe to return control to the event loop when using compression
|
|
||||||
# after this point as we have already sent or buffered all the data.
|
|
||||||
# Once we have written output_size up to the limit, we call the
|
|
||||||
# drain helper which waits for the transport to be ready to accept
|
|
||||||
# more data. This is a flow control mechanism to prevent the buffer
|
|
||||||
# from growing too large. The drain helper will return right away
|
|
||||||
# if the writer is not paused.
|
|
||||||
if self._output_size > self._limit:
|
|
||||||
self._output_size = 0
|
|
||||||
if self.protocol._paused:
|
|
||||||
await self.protocol._drain_helper()
|
|
||||||
|
|
||||||
def _write_websocket_frame(self, message: bytes, opcode: int, rsv: int) -> None:
|
|
||||||
"""
|
|
||||||
Write a websocket frame to the transport.
|
|
||||||
|
|
||||||
This method handles frame header construction, masking, and writing to transport.
|
|
||||||
It does not handle compression or flow control - those are the responsibility
|
|
||||||
of the caller.
|
|
||||||
"""
|
|
||||||
msg_length = len(message)
|
|
||||||
|
|
||||||
use_mask = self.use_mask
|
|
||||||
mask_bit = 0x80 if use_mask else 0
|
|
||||||
|
|
||||||
# Depending on the message length, the header is assembled differently.
|
|
||||||
# The first byte is reserved for the opcode and the RSV bits.
|
|
||||||
first_byte = 0x80 | rsv | opcode
|
|
||||||
if msg_length < 126:
|
|
||||||
header = PACK_LEN1(first_byte, msg_length | mask_bit)
|
|
||||||
header_len = 2
|
|
||||||
elif msg_length < 65536:
|
|
||||||
header = PACK_LEN2(first_byte, 126 | mask_bit, msg_length)
|
|
||||||
header_len = 4
|
|
||||||
else:
|
|
||||||
header = PACK_LEN3(first_byte, 127 | mask_bit, msg_length)
|
|
||||||
header_len = 10
|
|
||||||
|
|
||||||
if self.transport.is_closing():
|
|
||||||
raise ClientConnectionResetError("Cannot write to closing transport")
|
|
||||||
|
|
||||||
# https://datatracker.ietf.org/doc/html/rfc6455#section-5.3
|
|
||||||
# If we are using a mask, we need to generate it randomly
|
|
||||||
# and apply it to the message before sending it. A mask is
|
|
||||||
# a 32-bit value that is applied to the message using a
|
|
||||||
# bitwise XOR operation. It is used to prevent certain types
|
|
||||||
# of attacks on the websocket protocol. The mask is only used
|
|
||||||
# when aiohttp is acting as a client. Servers do not use a mask.
|
|
||||||
if use_mask:
|
|
||||||
mask = PACK_RANDBITS(self.get_random_bits())
|
|
||||||
message = bytearray(message)
|
|
||||||
websocket_mask(mask, message)
|
|
||||||
self.transport.write(header + mask + message)
|
|
||||||
self._output_size += MASK_LEN
|
|
||||||
elif msg_length > MSG_SIZE:
|
|
||||||
self.transport.write(header)
|
|
||||||
self.transport.write(message)
|
|
||||||
else:
|
|
||||||
self.transport.write(header + message)
|
|
||||||
|
|
||||||
self._output_size += header_len + msg_length
|
|
||||||
|
|
||||||
def _get_compressor(self, compress: Optional[int]) -> ZLibCompressor:
|
|
||||||
"""Get or create a compressor object for the given compression level."""
|
|
||||||
if compress:
|
|
||||||
# Do not set self._compress if compressing is for this frame
|
|
||||||
return ZLibCompressor(
|
|
||||||
level=ZLibBackend.Z_BEST_SPEED,
|
|
||||||
wbits=-compress,
|
|
||||||
max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
|
|
||||||
)
|
|
||||||
if not self._compressobj:
|
|
||||||
self._compressobj = ZLibCompressor(
|
|
||||||
level=ZLibBackend.Z_BEST_SPEED,
|
|
||||||
wbits=-self.compress,
|
|
||||||
max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
|
|
||||||
)
|
|
||||||
return self._compressobj
|
|
||||||
|
|
||||||
def _send_compressed_frame_sync(
|
|
||||||
self, message: bytes, opcode: int, compress: Optional[int]
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Synchronous send for small compressed frames.
|
|
||||||
|
|
||||||
This is used for small compressed payloads that compress synchronously in the event loop.
|
|
||||||
Since there are no await points, this is inherently cancellation-safe.
|
|
||||||
"""
|
|
||||||
# RSV are the reserved bits in the frame header. They are used to
|
|
||||||
# indicate that the frame is using an extension.
|
|
||||||
# https://datatracker.ietf.org/doc/html/rfc6455#section-5.2
|
|
||||||
compressobj = self._get_compressor(compress)
|
|
||||||
# (0x40) RSV1 is set for compressed frames
|
|
||||||
# https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1
|
|
||||||
self._write_websocket_frame(
|
|
||||||
(
|
|
||||||
compressobj.compress_sync(message)
|
|
||||||
+ compressobj.flush(
|
|
||||||
ZLibBackend.Z_FULL_FLUSH
|
|
||||||
if self.notakeover
|
|
||||||
else ZLibBackend.Z_SYNC_FLUSH
|
|
||||||
)
|
|
||||||
).removesuffix(WS_DEFLATE_TRAILING),
|
|
||||||
opcode,
|
|
||||||
0x40,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _send_compressed_frame_async_locked(
|
|
||||||
self, message: bytes, opcode: int, compress: Optional[int]
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Async send for large compressed frames with lock.
|
|
||||||
|
|
||||||
Acquires the lock and compresses large payloads asynchronously in
|
|
||||||
the executor. The lock is held for the entire operation to ensure
|
|
||||||
the compressor state is not corrupted by concurrent sends.
|
|
||||||
|
|
||||||
MUST be run shielded from cancellation. If cancelled after
|
|
||||||
compression but before sending, the compressor state would be
|
|
||||||
advanced but data not sent, corrupting subsequent frames.
|
|
||||||
"""
|
|
||||||
async with self._send_lock:
|
|
||||||
# RSV are the reserved bits in the frame header. They are used to
|
|
||||||
# indicate that the frame is using an extension.
|
|
||||||
# https://datatracker.ietf.org/doc/html/rfc6455#section-5.2
|
|
||||||
compressobj = self._get_compressor(compress)
|
|
||||||
# (0x40) RSV1 is set for compressed frames
|
|
||||||
# https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1
|
|
||||||
self._write_websocket_frame(
|
|
||||||
(
|
|
||||||
await compressobj.compress(message)
|
|
||||||
+ compressobj.flush(
|
|
||||||
ZLibBackend.Z_FULL_FLUSH
|
|
||||||
if self.notakeover
|
|
||||||
else ZLibBackend.Z_SYNC_FLUSH
|
|
||||||
)
|
|
||||||
).removesuffix(WS_DEFLATE_TRAILING),
|
|
||||||
opcode,
|
|
||||||
0x40,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def close(self, code: int = 1000, message: Union[bytes, str] = b"") -> None:
|
|
||||||
"""Close the websocket, sending the specified code and message."""
|
|
||||||
if isinstance(message, str):
|
|
||||||
message = message.encode("utf-8")
|
|
||||||
try:
|
|
||||||
await self.send_frame(
|
|
||||||
PACK_CLOSE_CODE(code) + message, opcode=WSMsgType.CLOSE
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
self._closing = True
|
|
||||||
|
|
@ -1,268 +0,0 @@
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
import socket
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from collections.abc import Sized
|
|
||||||
from http.cookies import BaseCookie, Morsel
|
|
||||||
from typing import (
|
|
||||||
TYPE_CHECKING,
|
|
||||||
Any,
|
|
||||||
Awaitable,
|
|
||||||
Callable,
|
|
||||||
Dict,
|
|
||||||
Generator,
|
|
||||||
Iterable,
|
|
||||||
List,
|
|
||||||
Optional,
|
|
||||||
Sequence,
|
|
||||||
Tuple,
|
|
||||||
TypedDict,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
|
|
||||||
from multidict import CIMultiDict
|
|
||||||
from yarl import URL
|
|
||||||
|
|
||||||
from ._cookie_helpers import parse_set_cookie_headers
|
|
||||||
from .typedefs import LooseCookies
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from .web_app import Application
|
|
||||||
from .web_exceptions import HTTPException
|
|
||||||
from .web_request import BaseRequest, Request
|
|
||||||
from .web_response import StreamResponse
|
|
||||||
else:
|
|
||||||
BaseRequest = Request = Application = StreamResponse = None
|
|
||||||
HTTPException = None
|
|
||||||
|
|
||||||
|
|
||||||
class AbstractRouter(ABC):
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self._frozen = False
|
|
||||||
|
|
||||||
def post_init(self, app: Application) -> None:
|
|
||||||
"""Post init stage.
|
|
||||||
|
|
||||||
Not an abstract method for sake of backward compatibility,
|
|
||||||
but if the router wants to be aware of the application
|
|
||||||
it can override this.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def frozen(self) -> bool:
|
|
||||||
return self._frozen
|
|
||||||
|
|
||||||
def freeze(self) -> None:
|
|
||||||
"""Freeze router."""
|
|
||||||
self._frozen = True
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def resolve(self, request: Request) -> "AbstractMatchInfo":
|
|
||||||
"""Return MATCH_INFO for given request"""
|
|
||||||
|
|
||||||
|
|
||||||
class AbstractMatchInfo(ABC):
|
|
||||||
|
|
||||||
__slots__ = ()
|
|
||||||
|
|
||||||
@property # pragma: no branch
|
|
||||||
@abstractmethod
|
|
||||||
def handler(self) -> Callable[[Request], Awaitable[StreamResponse]]:
|
|
||||||
"""Execute matched request handler"""
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def expect_handler(
|
|
||||||
self,
|
|
||||||
) -> Callable[[Request], Awaitable[Optional[StreamResponse]]]:
|
|
||||||
"""Expect handler for 100-continue processing"""
|
|
||||||
|
|
||||||
@property # pragma: no branch
|
|
||||||
@abstractmethod
|
|
||||||
def http_exception(self) -> Optional[HTTPException]:
|
|
||||||
"""HTTPException instance raised on router's resolving, or None"""
|
|
||||||
|
|
||||||
@abstractmethod # pragma: no branch
|
|
||||||
def get_info(self) -> Dict[str, Any]:
|
|
||||||
"""Return a dict with additional info useful for introspection"""
|
|
||||||
|
|
||||||
@property # pragma: no branch
|
|
||||||
@abstractmethod
|
|
||||||
def apps(self) -> Tuple[Application, ...]:
|
|
||||||
"""Stack of nested applications.
|
|
||||||
|
|
||||||
Top level application is left-most element.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def add_app(self, app: Application) -> None:
|
|
||||||
"""Add application to the nested apps stack."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def freeze(self) -> None:
|
|
||||||
"""Freeze the match info.
|
|
||||||
|
|
||||||
The method is called after route resolution.
|
|
||||||
|
|
||||||
After the call .add_app() is forbidden.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class AbstractView(ABC):
|
|
||||||
"""Abstract class based view."""
|
|
||||||
|
|
||||||
def __init__(self, request: Request) -> None:
|
|
||||||
self._request = request
|
|
||||||
|
|
||||||
@property
|
|
||||||
def request(self) -> Request:
|
|
||||||
"""Request instance."""
|
|
||||||
return self._request
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def __await__(self) -> Generator[None, None, StreamResponse]:
|
|
||||||
"""Execute the view handler."""
|
|
||||||
|
|
||||||
|
|
||||||
class ResolveResult(TypedDict):
|
|
||||||
"""Resolve result.
|
|
||||||
|
|
||||||
This is the result returned from an AbstractResolver's
|
|
||||||
resolve method.
|
|
||||||
|
|
||||||
:param hostname: The hostname that was provided.
|
|
||||||
:param host: The IP address that was resolved.
|
|
||||||
:param port: The port that was resolved.
|
|
||||||
:param family: The address family that was resolved.
|
|
||||||
:param proto: The protocol that was resolved.
|
|
||||||
:param flags: The flags that were resolved.
|
|
||||||
"""
|
|
||||||
|
|
||||||
hostname: str
|
|
||||||
host: str
|
|
||||||
port: int
|
|
||||||
family: int
|
|
||||||
proto: int
|
|
||||||
flags: int
|
|
||||||
|
|
||||||
|
|
||||||
class AbstractResolver(ABC):
|
|
||||||
"""Abstract DNS resolver."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def resolve(
|
|
||||||
self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
|
|
||||||
) -> List[ResolveResult]:
|
|
||||||
"""Return IP address for given hostname"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def close(self) -> None:
|
|
||||||
"""Release resolver"""
|
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
IterableBase = Iterable[Morsel[str]]
|
|
||||||
else:
|
|
||||||
IterableBase = Iterable
|
|
||||||
|
|
||||||
|
|
||||||
ClearCookiePredicate = Callable[["Morsel[str]"], bool]
|
|
||||||
|
|
||||||
|
|
||||||
class AbstractCookieJar(Sized, IterableBase):
|
|
||||||
"""Abstract Cookie Jar."""
|
|
||||||
|
|
||||||
def __init__(self, *, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
|
|
||||||
self._loop = loop or asyncio.get_running_loop()
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def quote_cookie(self) -> bool:
|
|
||||||
"""Return True if cookies should be quoted."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:
|
|
||||||
"""Clear all cookies if no predicate is passed."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def clear_domain(self, domain: str) -> None:
|
|
||||||
"""Clear all cookies for domain and all subdomains."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
|
|
||||||
"""Update cookies."""
|
|
||||||
|
|
||||||
def update_cookies_from_headers(
|
|
||||||
self, headers: Sequence[str], response_url: URL
|
|
||||||
) -> None:
|
|
||||||
"""Update cookies from raw Set-Cookie headers."""
|
|
||||||
if headers and (cookies_to_update := parse_set_cookie_headers(headers)):
|
|
||||||
self.update_cookies(cookies_to_update, response_url)
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def filter_cookies(self, request_url: URL) -> "BaseCookie[str]":
|
|
||||||
"""Return the jar's cookies filtered by their attributes."""
|
|
||||||
|
|
||||||
|
|
||||||
class AbstractStreamWriter(ABC):
|
|
||||||
"""Abstract stream writer."""
|
|
||||||
|
|
||||||
buffer_size: int = 0
|
|
||||||
output_size: int = 0
|
|
||||||
length: Optional[int] = 0
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def write(self, chunk: Union[bytes, bytearray, memoryview]) -> None:
|
|
||||||
"""Write chunk into stream."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def write_eof(self, chunk: bytes = b"") -> None:
|
|
||||||
"""Write last chunk."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def drain(self) -> None:
|
|
||||||
"""Flush the write buffer."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def enable_compression(
|
|
||||||
self, encoding: str = "deflate", strategy: Optional[int] = None
|
|
||||||
) -> None:
|
|
||||||
"""Enable HTTP body compression"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def enable_chunking(self) -> None:
|
|
||||||
"""Enable HTTP chunked mode"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def write_headers(
|
|
||||||
self, status_line: str, headers: "CIMultiDict[str]"
|
|
||||||
) -> None:
|
|
||||||
"""Write HTTP headers"""
|
|
||||||
|
|
||||||
def send_headers(self) -> None:
|
|
||||||
"""Force sending buffered headers if not already sent.
|
|
||||||
|
|
||||||
Required only if write_headers() buffers headers instead of sending immediately.
|
|
||||||
For backwards compatibility, this method does nothing by default.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class AbstractAccessLogger(ABC):
|
|
||||||
"""Abstract writer to access log."""
|
|
||||||
|
|
||||||
__slots__ = ("logger", "log_format")
|
|
||||||
|
|
||||||
def __init__(self, logger: logging.Logger, log_format: str) -> None:
|
|
||||||
self.logger = logger
|
|
||||||
self.log_format = log_format
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def log(self, request: BaseRequest, response: StreamResponse, time: float) -> None:
|
|
||||||
"""Emit log to logger."""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def enabled(self) -> bool:
|
|
||||||
"""Check if logger is enabled."""
|
|
||||||
return True
|
|
||||||
|
|
@ -1,100 +0,0 @@
|
||||||
import asyncio
|
|
||||||
from typing import Optional, cast
|
|
||||||
|
|
||||||
from .client_exceptions import ClientConnectionResetError
|
|
||||||
from .helpers import set_exception
|
|
||||||
from .tcp_helpers import tcp_nodelay
|
|
||||||
|
|
||||||
|
|
||||||
class BaseProtocol(asyncio.Protocol):
|
|
||||||
__slots__ = (
|
|
||||||
"_loop",
|
|
||||||
"_paused",
|
|
||||||
"_drain_waiter",
|
|
||||||
"_connection_lost",
|
|
||||||
"_reading_paused",
|
|
||||||
"transport",
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
|
|
||||||
self._loop: asyncio.AbstractEventLoop = loop
|
|
||||||
self._paused = False
|
|
||||||
self._drain_waiter: Optional[asyncio.Future[None]] = None
|
|
||||||
self._reading_paused = False
|
|
||||||
|
|
||||||
self.transport: Optional[asyncio.Transport] = None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def connected(self) -> bool:
|
|
||||||
"""Return True if the connection is open."""
|
|
||||||
return self.transport is not None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def writing_paused(self) -> bool:
|
|
||||||
return self._paused
|
|
||||||
|
|
||||||
def pause_writing(self) -> None:
|
|
||||||
assert not self._paused
|
|
||||||
self._paused = True
|
|
||||||
|
|
||||||
def resume_writing(self) -> None:
|
|
||||||
assert self._paused
|
|
||||||
self._paused = False
|
|
||||||
|
|
||||||
waiter = self._drain_waiter
|
|
||||||
if waiter is not None:
|
|
||||||
self._drain_waiter = None
|
|
||||||
if not waiter.done():
|
|
||||||
waiter.set_result(None)
|
|
||||||
|
|
||||||
def pause_reading(self) -> None:
|
|
||||||
if not self._reading_paused and self.transport is not None:
|
|
||||||
try:
|
|
||||||
self.transport.pause_reading()
|
|
||||||
except (AttributeError, NotImplementedError, RuntimeError):
|
|
||||||
pass
|
|
||||||
self._reading_paused = True
|
|
||||||
|
|
||||||
def resume_reading(self) -> None:
|
|
||||||
if self._reading_paused and self.transport is not None:
|
|
||||||
try:
|
|
||||||
self.transport.resume_reading()
|
|
||||||
except (AttributeError, NotImplementedError, RuntimeError):
|
|
||||||
pass
|
|
||||||
self._reading_paused = False
|
|
||||||
|
|
||||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
|
||||||
tr = cast(asyncio.Transport, transport)
|
|
||||||
tcp_nodelay(tr, True)
|
|
||||||
self.transport = tr
|
|
||||||
|
|
||||||
def connection_lost(self, exc: Optional[BaseException]) -> None:
|
|
||||||
# Wake up the writer if currently paused.
|
|
||||||
self.transport = None
|
|
||||||
if not self._paused:
|
|
||||||
return
|
|
||||||
waiter = self._drain_waiter
|
|
||||||
if waiter is None:
|
|
||||||
return
|
|
||||||
self._drain_waiter = None
|
|
||||||
if waiter.done():
|
|
||||||
return
|
|
||||||
if exc is None:
|
|
||||||
waiter.set_result(None)
|
|
||||||
else:
|
|
||||||
set_exception(
|
|
||||||
waiter,
|
|
||||||
ConnectionError("Connection lost"),
|
|
||||||
exc,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _drain_helper(self) -> None:
|
|
||||||
if self.transport is None:
|
|
||||||
raise ClientConnectionResetError("Connection lost")
|
|
||||||
if not self._paused:
|
|
||||||
return
|
|
||||||
waiter = self._drain_waiter
|
|
||||||
if waiter is None:
|
|
||||||
waiter = self._loop.create_future()
|
|
||||||
self._drain_waiter = waiter
|
|
||||||
await asyncio.shield(waiter)
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,421 +0,0 @@
|
||||||
"""HTTP related errors."""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import warnings
|
|
||||||
from typing import TYPE_CHECKING, Optional, Tuple, Union
|
|
||||||
|
|
||||||
from multidict import MultiMapping
|
|
||||||
|
|
||||||
from .typedefs import StrOrURL
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
import ssl
|
|
||||||
|
|
||||||
SSLContext = ssl.SSLContext
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
import ssl
|
|
||||||
|
|
||||||
SSLContext = ssl.SSLContext
|
|
||||||
except ImportError: # pragma: no cover
|
|
||||||
ssl = SSLContext = None # type: ignore[assignment]
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from .client_reqrep import ClientResponse, ConnectionKey, Fingerprint, RequestInfo
|
|
||||||
from .http_parser import RawResponseMessage
|
|
||||||
else:
|
|
||||||
RequestInfo = ClientResponse = ConnectionKey = RawResponseMessage = None
|
|
||||||
|
|
||||||
__all__ = (
|
|
||||||
"ClientError",
|
|
||||||
"ClientConnectionError",
|
|
||||||
"ClientConnectionResetError",
|
|
||||||
"ClientOSError",
|
|
||||||
"ClientConnectorError",
|
|
||||||
"ClientProxyConnectionError",
|
|
||||||
"ClientSSLError",
|
|
||||||
"ClientConnectorDNSError",
|
|
||||||
"ClientConnectorSSLError",
|
|
||||||
"ClientConnectorCertificateError",
|
|
||||||
"ConnectionTimeoutError",
|
|
||||||
"SocketTimeoutError",
|
|
||||||
"ServerConnectionError",
|
|
||||||
"ServerTimeoutError",
|
|
||||||
"ServerDisconnectedError",
|
|
||||||
"ServerFingerprintMismatch",
|
|
||||||
"ClientResponseError",
|
|
||||||
"ClientHttpProxyError",
|
|
||||||
"WSServerHandshakeError",
|
|
||||||
"ContentTypeError",
|
|
||||||
"ClientPayloadError",
|
|
||||||
"InvalidURL",
|
|
||||||
"InvalidUrlClientError",
|
|
||||||
"RedirectClientError",
|
|
||||||
"NonHttpUrlClientError",
|
|
||||||
"InvalidUrlRedirectClientError",
|
|
||||||
"NonHttpUrlRedirectClientError",
|
|
||||||
"WSMessageTypeError",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ClientError(Exception):
|
|
||||||
"""Base class for client connection errors."""
|
|
||||||
|
|
||||||
|
|
||||||
class ClientResponseError(ClientError):
|
|
||||||
"""Base class for exceptions that occur after getting a response.
|
|
||||||
|
|
||||||
request_info: An instance of RequestInfo.
|
|
||||||
history: A sequence of responses, if redirects occurred.
|
|
||||||
status: HTTP status code.
|
|
||||||
message: Error message.
|
|
||||||
headers: Response headers.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
request_info: RequestInfo,
|
|
||||||
history: Tuple[ClientResponse, ...],
|
|
||||||
*,
|
|
||||||
code: Optional[int] = None,
|
|
||||||
status: Optional[int] = None,
|
|
||||||
message: str = "",
|
|
||||||
headers: Optional[MultiMapping[str]] = None,
|
|
||||||
) -> None:
|
|
||||||
self.request_info = request_info
|
|
||||||
if code is not None:
|
|
||||||
if status is not None:
|
|
||||||
raise ValueError(
|
|
||||||
"Both code and status arguments are provided; "
|
|
||||||
"code is deprecated, use status instead"
|
|
||||||
)
|
|
||||||
warnings.warn(
|
|
||||||
"code argument is deprecated, use status instead",
|
|
||||||
DeprecationWarning,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
if status is not None:
|
|
||||||
self.status = status
|
|
||||||
elif code is not None:
|
|
||||||
self.status = code
|
|
||||||
else:
|
|
||||||
self.status = 0
|
|
||||||
self.message = message
|
|
||||||
self.headers = headers
|
|
||||||
self.history = history
|
|
||||||
self.args = (request_info, history)
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return "{}, message={!r}, url={!r}".format(
|
|
||||||
self.status,
|
|
||||||
self.message,
|
|
||||||
str(self.request_info.real_url),
|
|
||||||
)
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
args = f"{self.request_info!r}, {self.history!r}"
|
|
||||||
if self.status != 0:
|
|
||||||
args += f", status={self.status!r}"
|
|
||||||
if self.message != "":
|
|
||||||
args += f", message={self.message!r}"
|
|
||||||
if self.headers is not None:
|
|
||||||
args += f", headers={self.headers!r}"
|
|
||||||
return f"{type(self).__name__}({args})"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def code(self) -> int:
|
|
||||||
warnings.warn(
|
|
||||||
"code property is deprecated, use status instead",
|
|
||||||
DeprecationWarning,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
return self.status
|
|
||||||
|
|
||||||
@code.setter
|
|
||||||
def code(self, value: int) -> None:
|
|
||||||
warnings.warn(
|
|
||||||
"code property is deprecated, use status instead",
|
|
||||||
DeprecationWarning,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
self.status = value
|
|
||||||
|
|
||||||
|
|
||||||
class ContentTypeError(ClientResponseError):
|
|
||||||
"""ContentType found is not valid."""
|
|
||||||
|
|
||||||
|
|
||||||
class WSServerHandshakeError(ClientResponseError):
|
|
||||||
"""websocket server handshake error."""
|
|
||||||
|
|
||||||
|
|
||||||
class ClientHttpProxyError(ClientResponseError):
|
|
||||||
"""HTTP proxy error.
|
|
||||||
|
|
||||||
Raised in :class:`aiohttp.connector.TCPConnector` if
|
|
||||||
proxy responds with status other than ``200 OK``
|
|
||||||
on ``CONNECT`` request.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class TooManyRedirects(ClientResponseError):
|
|
||||||
"""Client was redirected too many times."""
|
|
||||||
|
|
||||||
|
|
||||||
class ClientConnectionError(ClientError):
|
|
||||||
"""Base class for client socket errors."""
|
|
||||||
|
|
||||||
|
|
||||||
class ClientConnectionResetError(ClientConnectionError, ConnectionResetError):
|
|
||||||
"""ConnectionResetError"""
|
|
||||||
|
|
||||||
|
|
||||||
class ClientOSError(ClientConnectionError, OSError):
|
|
||||||
"""OSError error."""
|
|
||||||
|
|
||||||
|
|
||||||
class ClientConnectorError(ClientOSError):
|
|
||||||
"""Client connector error.
|
|
||||||
|
|
||||||
Raised in :class:`aiohttp.connector.TCPConnector` if
|
|
||||||
a connection can not be established.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, connection_key: ConnectionKey, os_error: OSError) -> None:
|
|
||||||
self._conn_key = connection_key
|
|
||||||
self._os_error = os_error
|
|
||||||
super().__init__(os_error.errno, os_error.strerror)
|
|
||||||
self.args = (connection_key, os_error)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def os_error(self) -> OSError:
|
|
||||||
return self._os_error
|
|
||||||
|
|
||||||
@property
|
|
||||||
def host(self) -> str:
|
|
||||||
return self._conn_key.host
|
|
||||||
|
|
||||||
@property
|
|
||||||
def port(self) -> Optional[int]:
|
|
||||||
return self._conn_key.port
|
|
||||||
|
|
||||||
@property
|
|
||||||
def ssl(self) -> Union[SSLContext, bool, "Fingerprint"]:
|
|
||||||
return self._conn_key.ssl
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return "Cannot connect to host {0.host}:{0.port} ssl:{1} [{2}]".format(
|
|
||||||
self, "default" if self.ssl is True else self.ssl, self.strerror
|
|
||||||
)
|
|
||||||
|
|
||||||
# OSError.__reduce__ does too much black magick
|
|
||||||
__reduce__ = BaseException.__reduce__
|
|
||||||
|
|
||||||
|
|
||||||
class ClientConnectorDNSError(ClientConnectorError):
|
|
||||||
"""DNS resolution failed during client connection.
|
|
||||||
|
|
||||||
Raised in :class:`aiohttp.connector.TCPConnector` if
|
|
||||||
DNS resolution fails.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class ClientProxyConnectionError(ClientConnectorError):
|
|
||||||
"""Proxy connection error.
|
|
||||||
|
|
||||||
Raised in :class:`aiohttp.connector.TCPConnector` if
|
|
||||||
connection to proxy can not be established.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class UnixClientConnectorError(ClientConnectorError):
|
|
||||||
"""Unix connector error.
|
|
||||||
|
|
||||||
Raised in :py:class:`aiohttp.connector.UnixConnector`
|
|
||||||
if connection to unix socket can not be established.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, path: str, connection_key: ConnectionKey, os_error: OSError
|
|
||||||
) -> None:
|
|
||||||
self._path = path
|
|
||||||
super().__init__(connection_key, os_error)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def path(self) -> str:
|
|
||||||
return self._path
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return "Cannot connect to unix socket {0.path} ssl:{1} [{2}]".format(
|
|
||||||
self, "default" if self.ssl is True else self.ssl, self.strerror
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ServerConnectionError(ClientConnectionError):
|
|
||||||
"""Server connection errors."""
|
|
||||||
|
|
||||||
|
|
||||||
class ServerDisconnectedError(ServerConnectionError):
|
|
||||||
"""Server disconnected."""
|
|
||||||
|
|
||||||
def __init__(self, message: Union[RawResponseMessage, str, None] = None) -> None:
|
|
||||||
if message is None:
|
|
||||||
message = "Server disconnected"
|
|
||||||
|
|
||||||
self.args = (message,)
|
|
||||||
self.message = message
|
|
||||||
|
|
||||||
|
|
||||||
class ServerTimeoutError(ServerConnectionError, asyncio.TimeoutError):
|
|
||||||
"""Server timeout error."""
|
|
||||||
|
|
||||||
|
|
||||||
class ConnectionTimeoutError(ServerTimeoutError):
|
|
||||||
"""Connection timeout error."""
|
|
||||||
|
|
||||||
|
|
||||||
class SocketTimeoutError(ServerTimeoutError):
|
|
||||||
"""Socket timeout error."""
|
|
||||||
|
|
||||||
|
|
||||||
class ServerFingerprintMismatch(ServerConnectionError):
|
|
||||||
"""SSL certificate does not match expected fingerprint."""
|
|
||||||
|
|
||||||
def __init__(self, expected: bytes, got: bytes, host: str, port: int) -> None:
|
|
||||||
self.expected = expected
|
|
||||||
self.got = got
|
|
||||||
self.host = host
|
|
||||||
self.port = port
|
|
||||||
self.args = (expected, got, host, port)
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return "<{} expected={!r} got={!r} host={!r} port={!r}>".format(
|
|
||||||
self.__class__.__name__, self.expected, self.got, self.host, self.port
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ClientPayloadError(ClientError):
|
|
||||||
"""Response payload error."""
|
|
||||||
|
|
||||||
|
|
||||||
class InvalidURL(ClientError, ValueError):
|
|
||||||
"""Invalid URL.
|
|
||||||
|
|
||||||
URL used for fetching is malformed, e.g. it doesn't contains host
|
|
||||||
part.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Derive from ValueError for backward compatibility
|
|
||||||
|
|
||||||
def __init__(self, url: StrOrURL, description: Union[str, None] = None) -> None:
|
|
||||||
# The type of url is not yarl.URL because the exception can be raised
|
|
||||||
# on URL(url) call
|
|
||||||
self._url = url
|
|
||||||
self._description = description
|
|
||||||
|
|
||||||
if description:
|
|
||||||
super().__init__(url, description)
|
|
||||||
else:
|
|
||||||
super().__init__(url)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def url(self) -> StrOrURL:
|
|
||||||
return self._url
|
|
||||||
|
|
||||||
@property
|
|
||||||
def description(self) -> "str | None":
|
|
||||||
return self._description
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return f"<{self.__class__.__name__} {self}>"
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
if self._description:
|
|
||||||
return f"{self._url} - {self._description}"
|
|
||||||
return str(self._url)
|
|
||||||
|
|
||||||
|
|
||||||
class InvalidUrlClientError(InvalidURL):
|
|
||||||
"""Invalid URL client error."""
|
|
||||||
|
|
||||||
|
|
||||||
class RedirectClientError(ClientError):
|
|
||||||
"""Client redirect error."""
|
|
||||||
|
|
||||||
|
|
||||||
class NonHttpUrlClientError(ClientError):
|
|
||||||
"""Non http URL client error."""
|
|
||||||
|
|
||||||
|
|
||||||
class InvalidUrlRedirectClientError(InvalidUrlClientError, RedirectClientError):
|
|
||||||
"""Invalid URL redirect client error."""
|
|
||||||
|
|
||||||
|
|
||||||
class NonHttpUrlRedirectClientError(NonHttpUrlClientError, RedirectClientError):
|
|
||||||
"""Non http URL redirect client error."""
|
|
||||||
|
|
||||||
|
|
||||||
class ClientSSLError(ClientConnectorError):
|
|
||||||
"""Base error for ssl.*Errors."""
|
|
||||||
|
|
||||||
|
|
||||||
if ssl is not None:
|
|
||||||
cert_errors = (ssl.CertificateError,)
|
|
||||||
cert_errors_bases = (
|
|
||||||
ClientSSLError,
|
|
||||||
ssl.CertificateError,
|
|
||||||
)
|
|
||||||
|
|
||||||
ssl_errors = (ssl.SSLError,)
|
|
||||||
ssl_error_bases = (ClientSSLError, ssl.SSLError)
|
|
||||||
else: # pragma: no cover
|
|
||||||
cert_errors = tuple()
|
|
||||||
cert_errors_bases = (
|
|
||||||
ClientSSLError,
|
|
||||||
ValueError,
|
|
||||||
)
|
|
||||||
|
|
||||||
ssl_errors = tuple()
|
|
||||||
ssl_error_bases = (ClientSSLError,)
|
|
||||||
|
|
||||||
|
|
||||||
class ClientConnectorSSLError(*ssl_error_bases): # type: ignore[misc]
|
|
||||||
"""Response ssl error."""
|
|
||||||
|
|
||||||
|
|
||||||
class ClientConnectorCertificateError(*cert_errors_bases): # type: ignore[misc]
|
|
||||||
"""Response certificate error."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, connection_key: ConnectionKey, certificate_error: Exception
|
|
||||||
) -> None:
|
|
||||||
self._conn_key = connection_key
|
|
||||||
self._certificate_error = certificate_error
|
|
||||||
self.args = (connection_key, certificate_error)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def certificate_error(self) -> Exception:
|
|
||||||
return self._certificate_error
|
|
||||||
|
|
||||||
@property
|
|
||||||
def host(self) -> str:
|
|
||||||
return self._conn_key.host
|
|
||||||
|
|
||||||
@property
|
|
||||||
def port(self) -> Optional[int]:
|
|
||||||
return self._conn_key.port
|
|
||||||
|
|
||||||
@property
|
|
||||||
def ssl(self) -> bool:
|
|
||||||
return self._conn_key.is_ssl
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return (
|
|
||||||
"Cannot connect to host {0.host}:{0.port} ssl:{0.ssl} "
|
|
||||||
"[{0.certificate_error.__class__.__name__}: "
|
|
||||||
"{0.certificate_error.args}]".format(self)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class WSMessageTypeError(TypeError):
|
|
||||||
"""WebSocket message type is not valid."""
|
|
||||||
|
|
@ -1,476 +0,0 @@
|
||||||
"""
|
|
||||||
Digest authentication middleware for aiohttp client.
|
|
||||||
|
|
||||||
This middleware implements HTTP Digest Authentication according to RFC 7616,
|
|
||||||
providing a more secure alternative to Basic Authentication. It supports all
|
|
||||||
standard hash algorithms including MD5, SHA, SHA-256, SHA-512 and their session
|
|
||||||
variants, as well as both 'auth' and 'auth-int' quality of protection (qop) options.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import hashlib
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import time
|
|
||||||
from typing import (
|
|
||||||
Callable,
|
|
||||||
Dict,
|
|
||||||
Final,
|
|
||||||
FrozenSet,
|
|
||||||
List,
|
|
||||||
Literal,
|
|
||||||
Tuple,
|
|
||||||
TypedDict,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
|
|
||||||
from yarl import URL
|
|
||||||
|
|
||||||
from . import hdrs
|
|
||||||
from .client_exceptions import ClientError
|
|
||||||
from .client_middlewares import ClientHandlerType
|
|
||||||
from .client_reqrep import ClientRequest, ClientResponse
|
|
||||||
from .payload import Payload
|
|
||||||
|
|
||||||
|
|
||||||
class DigestAuthChallenge(TypedDict, total=False):
|
|
||||||
realm: str
|
|
||||||
nonce: str
|
|
||||||
qop: str
|
|
||||||
algorithm: str
|
|
||||||
opaque: str
|
|
||||||
domain: str
|
|
||||||
stale: str
|
|
||||||
|
|
||||||
|
|
||||||
DigestFunctions: Dict[str, Callable[[bytes], "hashlib._Hash"]] = {
|
|
||||||
"MD5": hashlib.md5,
|
|
||||||
"MD5-SESS": hashlib.md5,
|
|
||||||
"SHA": hashlib.sha1,
|
|
||||||
"SHA-SESS": hashlib.sha1,
|
|
||||||
"SHA256": hashlib.sha256,
|
|
||||||
"SHA256-SESS": hashlib.sha256,
|
|
||||||
"SHA-256": hashlib.sha256,
|
|
||||||
"SHA-256-SESS": hashlib.sha256,
|
|
||||||
"SHA512": hashlib.sha512,
|
|
||||||
"SHA512-SESS": hashlib.sha512,
|
|
||||||
"SHA-512": hashlib.sha512,
|
|
||||||
"SHA-512-SESS": hashlib.sha512,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# Compile the regex pattern once at module level for performance
|
|
||||||
_HEADER_PAIRS_PATTERN = re.compile(
|
|
||||||
r'(\w+)\s*=\s*(?:"((?:[^"\\]|\\.)*)"|([^\s,]+))'
|
|
||||||
# | | | | | | | | | || |
|
|
||||||
# +----|--|-|-|--|----|------|----|--||-----|--> alphanumeric key
|
|
||||||
# +--|-|-|--|----|------|----|--||-----|--> maybe whitespace
|
|
||||||
# | | | | | | | || |
|
|
||||||
# +-|-|--|----|------|----|--||-----|--> = (delimiter)
|
|
||||||
# +-|--|----|------|----|--||-----|--> maybe whitespace
|
|
||||||
# | | | | | || |
|
|
||||||
# +--|----|------|----|--||-----|--> group quoted or unquoted
|
|
||||||
# | | | | || |
|
|
||||||
# +----|------|----|--||-----|--> if quoted...
|
|
||||||
# +------|----|--||-----|--> anything but " or \
|
|
||||||
# +----|--||-----|--> escaped characters allowed
|
|
||||||
# +--||-----|--> or can be empty string
|
|
||||||
# || |
|
|
||||||
# +|-----|--> if unquoted...
|
|
||||||
# +-----|--> anything but , or <space>
|
|
||||||
# +--> at least one char req'd
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# RFC 7616: Challenge parameters to extract
|
|
||||||
CHALLENGE_FIELDS: Final[
|
|
||||||
Tuple[
|
|
||||||
Literal["realm", "nonce", "qop", "algorithm", "opaque", "domain", "stale"], ...
|
|
||||||
]
|
|
||||||
] = (
|
|
||||||
"realm",
|
|
||||||
"nonce",
|
|
||||||
"qop",
|
|
||||||
"algorithm",
|
|
||||||
"opaque",
|
|
||||||
"domain",
|
|
||||||
"stale",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Supported digest authentication algorithms
|
|
||||||
# Use a tuple of sorted keys for predictable documentation and error messages
|
|
||||||
SUPPORTED_ALGORITHMS: Final[Tuple[str, ...]] = tuple(sorted(DigestFunctions.keys()))
|
|
||||||
|
|
||||||
# RFC 7616: Fields that require quoting in the Digest auth header
|
|
||||||
# These fields must be enclosed in double quotes in the Authorization header.
|
|
||||||
# Algorithm, qop, and nc are never quoted per RFC specifications.
|
|
||||||
# This frozen set is used by the template-based header construction to
|
|
||||||
# automatically determine which fields need quotes.
|
|
||||||
QUOTED_AUTH_FIELDS: Final[FrozenSet[str]] = frozenset(
|
|
||||||
{"username", "realm", "nonce", "uri", "response", "opaque", "cnonce"}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def escape_quotes(value: str) -> str:
|
|
||||||
"""Escape double quotes for HTTP header values."""
|
|
||||||
return value.replace('"', '\\"')
|
|
||||||
|
|
||||||
|
|
||||||
def unescape_quotes(value: str) -> str:
|
|
||||||
"""Unescape double quotes in HTTP header values."""
|
|
||||||
return value.replace('\\"', '"')
|
|
||||||
|
|
||||||
|
|
||||||
def parse_header_pairs(header: str) -> Dict[str, str]:
|
|
||||||
"""
|
|
||||||
Parse key-value pairs from WWW-Authenticate or similar HTTP headers.
|
|
||||||
|
|
||||||
This function handles the complex format of WWW-Authenticate header values,
|
|
||||||
supporting both quoted and unquoted values, proper handling of commas in
|
|
||||||
quoted values, and whitespace variations per RFC 7616.
|
|
||||||
|
|
||||||
Examples of supported formats:
|
|
||||||
- key1="value1", key2=value2
|
|
||||||
- key1 = "value1" , key2="value, with, commas"
|
|
||||||
- key1=value1,key2="value2"
|
|
||||||
- realm="example.com", nonce="12345", qop="auth"
|
|
||||||
|
|
||||||
Args:
|
|
||||||
header: The header value string to parse
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary mapping parameter names to their values
|
|
||||||
"""
|
|
||||||
return {
|
|
||||||
stripped_key: unescape_quotes(quoted_val) if quoted_val else unquoted_val
|
|
||||||
for key, quoted_val, unquoted_val in _HEADER_PAIRS_PATTERN.findall(header)
|
|
||||||
if (stripped_key := key.strip())
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class DigestAuthMiddleware:
|
|
||||||
"""
|
|
||||||
HTTP digest authentication middleware for aiohttp client.
|
|
||||||
|
|
||||||
This middleware intercepts 401 Unauthorized responses containing a Digest
|
|
||||||
authentication challenge, calculates the appropriate digest credentials,
|
|
||||||
and automatically retries the request with the proper Authorization header.
|
|
||||||
|
|
||||||
Features:
|
|
||||||
- Handles all aspects of Digest authentication handshake automatically
|
|
||||||
- Supports all standard hash algorithms:
|
|
||||||
- MD5, MD5-SESS
|
|
||||||
- SHA, SHA-SESS
|
|
||||||
- SHA256, SHA256-SESS, SHA-256, SHA-256-SESS
|
|
||||||
- SHA512, SHA512-SESS, SHA-512, SHA-512-SESS
|
|
||||||
- Supports 'auth' and 'auth-int' quality of protection modes
|
|
||||||
- Properly handles quoted strings and parameter parsing
|
|
||||||
- Includes replay attack protection with client nonce count tracking
|
|
||||||
- Supports preemptive authentication per RFC 7616 Section 3.6
|
|
||||||
|
|
||||||
Standards compliance:
|
|
||||||
- RFC 7616: HTTP Digest Access Authentication (primary reference)
|
|
||||||
- RFC 2617: HTTP Authentication (deprecated by RFC 7616)
|
|
||||||
- RFC 1945: Section 11.1 (username restrictions)
|
|
||||||
|
|
||||||
Implementation notes:
|
|
||||||
The core digest calculation is inspired by the implementation in
|
|
||||||
https://github.com/requests/requests/blob/v2.18.4/requests/auth.py
|
|
||||||
with added support for modern digest auth features and error handling.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
login: str,
|
|
||||||
password: str,
|
|
||||||
preemptive: bool = True,
|
|
||||||
) -> None:
|
|
||||||
if login is None:
|
|
||||||
raise ValueError("None is not allowed as login value")
|
|
||||||
|
|
||||||
if password is None:
|
|
||||||
raise ValueError("None is not allowed as password value")
|
|
||||||
|
|
||||||
if ":" in login:
|
|
||||||
raise ValueError('A ":" is not allowed in username (RFC 1945#section-11.1)')
|
|
||||||
|
|
||||||
self._login_str: Final[str] = login
|
|
||||||
self._login_bytes: Final[bytes] = login.encode("utf-8")
|
|
||||||
self._password_bytes: Final[bytes] = password.encode("utf-8")
|
|
||||||
|
|
||||||
self._last_nonce_bytes = b""
|
|
||||||
self._nonce_count = 0
|
|
||||||
self._challenge: DigestAuthChallenge = {}
|
|
||||||
self._preemptive: bool = preemptive
|
|
||||||
# Set of URLs defining the protection space
|
|
||||||
self._protection_space: List[str] = []
|
|
||||||
|
|
||||||
async def _encode(
|
|
||||||
self, method: str, url: URL, body: Union[Payload, Literal[b""]]
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Build digest authorization header for the current challenge.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
method: The HTTP method (GET, POST, etc.)
|
|
||||||
url: The request URL
|
|
||||||
body: The request body (used for qop=auth-int)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A fully formatted Digest authorization header string
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ClientError: If the challenge is missing required parameters or
|
|
||||||
contains unsupported values
|
|
||||||
|
|
||||||
"""
|
|
||||||
challenge = self._challenge
|
|
||||||
if "realm" not in challenge:
|
|
||||||
raise ClientError(
|
|
||||||
"Malformed Digest auth challenge: Missing 'realm' parameter"
|
|
||||||
)
|
|
||||||
|
|
||||||
if "nonce" not in challenge:
|
|
||||||
raise ClientError(
|
|
||||||
"Malformed Digest auth challenge: Missing 'nonce' parameter"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Empty realm values are allowed per RFC 7616 (SHOULD, not MUST, contain host name)
|
|
||||||
realm = challenge["realm"]
|
|
||||||
nonce = challenge["nonce"]
|
|
||||||
|
|
||||||
# Empty nonce values are not allowed as they are security-critical for replay protection
|
|
||||||
if not nonce:
|
|
||||||
raise ClientError(
|
|
||||||
"Security issue: Digest auth challenge contains empty 'nonce' value"
|
|
||||||
)
|
|
||||||
|
|
||||||
qop_raw = challenge.get("qop", "")
|
|
||||||
# Preserve original algorithm case for response while using uppercase for processing
|
|
||||||
algorithm_original = challenge.get("algorithm", "MD5")
|
|
||||||
algorithm = algorithm_original.upper()
|
|
||||||
opaque = challenge.get("opaque", "")
|
|
||||||
|
|
||||||
# Convert string values to bytes once
|
|
||||||
nonce_bytes = nonce.encode("utf-8")
|
|
||||||
realm_bytes = realm.encode("utf-8")
|
|
||||||
path = URL(url).path_qs
|
|
||||||
|
|
||||||
# Process QoP
|
|
||||||
qop = ""
|
|
||||||
qop_bytes = b""
|
|
||||||
if qop_raw:
|
|
||||||
valid_qops = {"auth", "auth-int"}.intersection(
|
|
||||||
{q.strip() for q in qop_raw.split(",") if q.strip()}
|
|
||||||
)
|
|
||||||
if not valid_qops:
|
|
||||||
raise ClientError(
|
|
||||||
f"Digest auth error: Unsupported Quality of Protection (qop) value(s): {qop_raw}"
|
|
||||||
)
|
|
||||||
|
|
||||||
qop = "auth-int" if "auth-int" in valid_qops else "auth"
|
|
||||||
qop_bytes = qop.encode("utf-8")
|
|
||||||
|
|
||||||
if algorithm not in DigestFunctions:
|
|
||||||
raise ClientError(
|
|
||||||
f"Digest auth error: Unsupported hash algorithm: {algorithm}. "
|
|
||||||
f"Supported algorithms: {', '.join(SUPPORTED_ALGORITHMS)}"
|
|
||||||
)
|
|
||||||
hash_fn: Final = DigestFunctions[algorithm]
|
|
||||||
|
|
||||||
def H(x: bytes) -> bytes:
|
|
||||||
"""RFC 7616 Section 3: Hash function H(data) = hex(hash(data))."""
|
|
||||||
return hash_fn(x).hexdigest().encode()
|
|
||||||
|
|
||||||
def KD(s: bytes, d: bytes) -> bytes:
|
|
||||||
"""RFC 7616 Section 3: KD(secret, data) = H(concat(secret, ":", data))."""
|
|
||||||
return H(b":".join((s, d)))
|
|
||||||
|
|
||||||
# Calculate A1 and A2
|
|
||||||
A1 = b":".join((self._login_bytes, realm_bytes, self._password_bytes))
|
|
||||||
A2 = f"{method.upper()}:{path}".encode()
|
|
||||||
if qop == "auth-int":
|
|
||||||
if isinstance(body, Payload): # will always be empty bytes unless Payload
|
|
||||||
entity_bytes = await body.as_bytes() # Get bytes from Payload
|
|
||||||
else:
|
|
||||||
entity_bytes = body
|
|
||||||
entity_hash = H(entity_bytes)
|
|
||||||
A2 = b":".join((A2, entity_hash))
|
|
||||||
|
|
||||||
HA1 = H(A1)
|
|
||||||
HA2 = H(A2)
|
|
||||||
|
|
||||||
# Nonce count handling
|
|
||||||
if nonce_bytes == self._last_nonce_bytes:
|
|
||||||
self._nonce_count += 1
|
|
||||||
else:
|
|
||||||
self._nonce_count = 1
|
|
||||||
|
|
||||||
self._last_nonce_bytes = nonce_bytes
|
|
||||||
ncvalue = f"{self._nonce_count:08x}"
|
|
||||||
ncvalue_bytes = ncvalue.encode("utf-8")
|
|
||||||
|
|
||||||
# Generate client nonce
|
|
||||||
cnonce = hashlib.sha1(
|
|
||||||
b"".join(
|
|
||||||
[
|
|
||||||
str(self._nonce_count).encode("utf-8"),
|
|
||||||
nonce_bytes,
|
|
||||||
time.ctime().encode("utf-8"),
|
|
||||||
os.urandom(8),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
).hexdigest()[:16]
|
|
||||||
cnonce_bytes = cnonce.encode("utf-8")
|
|
||||||
|
|
||||||
# Special handling for session-based algorithms
|
|
||||||
if algorithm.upper().endswith("-SESS"):
|
|
||||||
HA1 = H(b":".join((HA1, nonce_bytes, cnonce_bytes)))
|
|
||||||
|
|
||||||
# Calculate the response digest
|
|
||||||
if qop:
|
|
||||||
noncebit = b":".join(
|
|
||||||
(nonce_bytes, ncvalue_bytes, cnonce_bytes, qop_bytes, HA2)
|
|
||||||
)
|
|
||||||
response_digest = KD(HA1, noncebit)
|
|
||||||
else:
|
|
||||||
response_digest = KD(HA1, b":".join((nonce_bytes, HA2)))
|
|
||||||
|
|
||||||
# Define a dict mapping of header fields to their values
|
|
||||||
# Group fields into always-present, optional, and qop-dependent
|
|
||||||
header_fields = {
|
|
||||||
# Always present fields
|
|
||||||
"username": escape_quotes(self._login_str),
|
|
||||||
"realm": escape_quotes(realm),
|
|
||||||
"nonce": escape_quotes(nonce),
|
|
||||||
"uri": path,
|
|
||||||
"response": response_digest.decode(),
|
|
||||||
"algorithm": algorithm_original,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Optional fields
|
|
||||||
if opaque:
|
|
||||||
header_fields["opaque"] = escape_quotes(opaque)
|
|
||||||
|
|
||||||
# QoP-dependent fields
|
|
||||||
if qop:
|
|
||||||
header_fields["qop"] = qop
|
|
||||||
header_fields["nc"] = ncvalue
|
|
||||||
header_fields["cnonce"] = cnonce
|
|
||||||
|
|
||||||
# Build header using templates for each field type
|
|
||||||
pairs: List[str] = []
|
|
||||||
for field, value in header_fields.items():
|
|
||||||
if field in QUOTED_AUTH_FIELDS:
|
|
||||||
pairs.append(f'{field}="{value}"')
|
|
||||||
else:
|
|
||||||
pairs.append(f"{field}={value}")
|
|
||||||
|
|
||||||
return f"Digest {', '.join(pairs)}"
|
|
||||||
|
|
||||||
def _in_protection_space(self, url: URL) -> bool:
|
|
||||||
"""
|
|
||||||
Check if the given URL is within the current protection space.
|
|
||||||
|
|
||||||
According to RFC 7616, a URI is in the protection space if any URI
|
|
||||||
in the protection space is a prefix of it (after both have been made absolute).
|
|
||||||
"""
|
|
||||||
request_str = str(url)
|
|
||||||
for space_str in self._protection_space:
|
|
||||||
# Check if request starts with space URL
|
|
||||||
if not request_str.startswith(space_str):
|
|
||||||
continue
|
|
||||||
# Exact match or space ends with / (proper directory prefix)
|
|
||||||
if len(request_str) == len(space_str) or space_str[-1] == "/":
|
|
||||||
return True
|
|
||||||
# Check next char is / to ensure proper path boundary
|
|
||||||
if request_str[len(space_str)] == "/":
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _authenticate(self, response: ClientResponse) -> bool:
|
|
||||||
"""
|
|
||||||
Takes the given response and tries digest-auth, if needed.
|
|
||||||
|
|
||||||
Returns true if the original request must be resent.
|
|
||||||
"""
|
|
||||||
if response.status != 401:
|
|
||||||
return False
|
|
||||||
|
|
||||||
auth_header = response.headers.get("www-authenticate", "")
|
|
||||||
if not auth_header:
|
|
||||||
return False # No authentication header present
|
|
||||||
|
|
||||||
method, sep, headers = auth_header.partition(" ")
|
|
||||||
if not sep:
|
|
||||||
# No space found in www-authenticate header
|
|
||||||
return False # Malformed auth header, missing scheme separator
|
|
||||||
|
|
||||||
if method.lower() != "digest":
|
|
||||||
# Not a digest auth challenge (could be Basic, Bearer, etc.)
|
|
||||||
return False
|
|
||||||
|
|
||||||
if not headers:
|
|
||||||
# We have a digest scheme but no parameters
|
|
||||||
return False # Malformed digest header, missing parameters
|
|
||||||
|
|
||||||
# We have a digest auth header with content
|
|
||||||
if not (header_pairs := parse_header_pairs(headers)):
|
|
||||||
# Failed to parse any key-value pairs
|
|
||||||
return False # Malformed digest header, no valid parameters
|
|
||||||
|
|
||||||
# Extract challenge parameters
|
|
||||||
self._challenge = {}
|
|
||||||
for field in CHALLENGE_FIELDS:
|
|
||||||
if value := header_pairs.get(field):
|
|
||||||
self._challenge[field] = value
|
|
||||||
|
|
||||||
# Update protection space based on domain parameter or default to origin
|
|
||||||
origin = response.url.origin()
|
|
||||||
|
|
||||||
if domain := self._challenge.get("domain"):
|
|
||||||
# Parse space-separated list of URIs
|
|
||||||
self._protection_space = []
|
|
||||||
for uri in domain.split():
|
|
||||||
# Remove quotes if present
|
|
||||||
uri = uri.strip('"')
|
|
||||||
if uri.startswith("/"):
|
|
||||||
# Path-absolute, relative to origin
|
|
||||||
self._protection_space.append(str(origin.join(URL(uri))))
|
|
||||||
else:
|
|
||||||
# Absolute URI
|
|
||||||
self._protection_space.append(str(URL(uri)))
|
|
||||||
else:
|
|
||||||
# No domain specified, protection space is entire origin
|
|
||||||
self._protection_space = [str(origin)]
|
|
||||||
|
|
||||||
# Return True only if we found at least one challenge parameter
|
|
||||||
return bool(self._challenge)
|
|
||||||
|
|
||||||
async def __call__(
|
|
||||||
self, request: ClientRequest, handler: ClientHandlerType
|
|
||||||
) -> ClientResponse:
|
|
||||||
"""Run the digest auth middleware."""
|
|
||||||
response = None
|
|
||||||
for retry_count in range(2):
|
|
||||||
# Apply authorization header if:
|
|
||||||
# 1. This is a retry after 401 (retry_count > 0), OR
|
|
||||||
# 2. Preemptive auth is enabled AND we have a challenge AND the URL is in protection space
|
|
||||||
if retry_count > 0 or (
|
|
||||||
self._preemptive
|
|
||||||
and self._challenge
|
|
||||||
and self._in_protection_space(request.url)
|
|
||||||
):
|
|
||||||
request.headers[hdrs.AUTHORIZATION] = await self._encode(
|
|
||||||
request.method, request.url, request.body
|
|
||||||
)
|
|
||||||
|
|
||||||
# Send the request
|
|
||||||
response = await handler(request)
|
|
||||||
|
|
||||||
# Check if we need to authenticate
|
|
||||||
if not self._authenticate(response):
|
|
||||||
break
|
|
||||||
|
|
||||||
# At this point, response is guaranteed to be defined
|
|
||||||
assert response is not None
|
|
||||||
return response
|
|
||||||
|
|
@ -1,55 +0,0 @@
|
||||||
"""Client middleware support."""
|
|
||||||
|
|
||||||
from collections.abc import Awaitable, Callable, Sequence
|
|
||||||
|
|
||||||
from .client_reqrep import ClientRequest, ClientResponse
|
|
||||||
|
|
||||||
__all__ = ("ClientMiddlewareType", "ClientHandlerType", "build_client_middlewares")
|
|
||||||
|
|
||||||
# Type alias for client request handlers - functions that process requests and return responses
|
|
||||||
ClientHandlerType = Callable[[ClientRequest], Awaitable[ClientResponse]]
|
|
||||||
|
|
||||||
# Type for client middleware - similar to server but uses ClientRequest/ClientResponse
|
|
||||||
ClientMiddlewareType = Callable[
|
|
||||||
[ClientRequest, ClientHandlerType], Awaitable[ClientResponse]
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def build_client_middlewares(
|
|
||||||
handler: ClientHandlerType,
|
|
||||||
middlewares: Sequence[ClientMiddlewareType],
|
|
||||||
) -> ClientHandlerType:
|
|
||||||
"""
|
|
||||||
Apply middlewares to request handler.
|
|
||||||
|
|
||||||
The middlewares are applied in reverse order, so the first middleware
|
|
||||||
in the list wraps all subsequent middlewares and the handler.
|
|
||||||
|
|
||||||
This implementation avoids using partial/update_wrapper to minimize overhead
|
|
||||||
and doesn't cache to avoid holding references to stateful middleware.
|
|
||||||
"""
|
|
||||||
# Optimize for single middleware case
|
|
||||||
if len(middlewares) == 1:
|
|
||||||
middleware = middlewares[0]
|
|
||||||
|
|
||||||
async def single_middleware_handler(req: ClientRequest) -> ClientResponse:
|
|
||||||
return await middleware(req, handler)
|
|
||||||
|
|
||||||
return single_middleware_handler
|
|
||||||
|
|
||||||
# Build the chain for multiple middlewares
|
|
||||||
current_handler = handler
|
|
||||||
|
|
||||||
for middleware in reversed(middlewares):
|
|
||||||
# Create a new closure that captures the current state
|
|
||||||
def make_wrapper(
|
|
||||||
mw: ClientMiddlewareType, next_h: ClientHandlerType
|
|
||||||
) -> ClientHandlerType:
|
|
||||||
async def wrapped(req: ClientRequest) -> ClientResponse:
|
|
||||||
return await mw(req, next_h)
|
|
||||||
|
|
||||||
return wrapped
|
|
||||||
|
|
||||||
current_handler = make_wrapper(middleware, current_handler)
|
|
||||||
|
|
||||||
return current_handler
|
|
||||||
|
|
@ -1,359 +0,0 @@
|
||||||
import asyncio
|
|
||||||
from contextlib import suppress
|
|
||||||
from typing import Any, Optional, Tuple, Union
|
|
||||||
|
|
||||||
from .base_protocol import BaseProtocol
|
|
||||||
from .client_exceptions import (
|
|
||||||
ClientConnectionError,
|
|
||||||
ClientOSError,
|
|
||||||
ClientPayloadError,
|
|
||||||
ServerDisconnectedError,
|
|
||||||
SocketTimeoutError,
|
|
||||||
)
|
|
||||||
from .helpers import (
|
|
||||||
_EXC_SENTINEL,
|
|
||||||
EMPTY_BODY_STATUS_CODES,
|
|
||||||
BaseTimerContext,
|
|
||||||
set_exception,
|
|
||||||
set_result,
|
|
||||||
)
|
|
||||||
from .http import HttpResponseParser, RawResponseMessage
|
|
||||||
from .http_exceptions import HttpProcessingError
|
|
||||||
from .streams import EMPTY_PAYLOAD, DataQueue, StreamReader
|
|
||||||
|
|
||||||
|
|
||||||
class ResponseHandler(BaseProtocol, DataQueue[Tuple[RawResponseMessage, StreamReader]]):
|
|
||||||
"""Helper class to adapt between Protocol and StreamReader."""
|
|
||||||
|
|
||||||
def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
|
|
||||||
BaseProtocol.__init__(self, loop=loop)
|
|
||||||
DataQueue.__init__(self, loop)
|
|
||||||
|
|
||||||
self._should_close = False
|
|
||||||
|
|
||||||
self._payload: Optional[StreamReader] = None
|
|
||||||
self._skip_payload = False
|
|
||||||
self._payload_parser = None
|
|
||||||
|
|
||||||
self._timer = None
|
|
||||||
|
|
||||||
self._tail = b""
|
|
||||||
self._upgraded = False
|
|
||||||
self._parser: Optional[HttpResponseParser] = None
|
|
||||||
|
|
||||||
self._read_timeout: Optional[float] = None
|
|
||||||
self._read_timeout_handle: Optional[asyncio.TimerHandle] = None
|
|
||||||
|
|
||||||
self._timeout_ceil_threshold: Optional[float] = 5
|
|
||||||
|
|
||||||
self._closed: Union[None, asyncio.Future[None]] = None
|
|
||||||
self._connection_lost_called = False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def closed(self) -> Union[None, asyncio.Future[None]]:
|
|
||||||
"""Future that is set when the connection is closed.
|
|
||||||
|
|
||||||
This property returns a Future that will be completed when the connection
|
|
||||||
is closed. The Future is created lazily on first access to avoid creating
|
|
||||||
futures that will never be awaited.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- A Future[None] if the connection is still open or was closed after
|
|
||||||
this property was accessed
|
|
||||||
- None if connection_lost() was already called before this property
|
|
||||||
was ever accessed (indicating no one is waiting for the closure)
|
|
||||||
"""
|
|
||||||
if self._closed is None and not self._connection_lost_called:
|
|
||||||
self._closed = self._loop.create_future()
|
|
||||||
return self._closed
|
|
||||||
|
|
||||||
@property
|
|
||||||
def upgraded(self) -> bool:
|
|
||||||
return self._upgraded
|
|
||||||
|
|
||||||
@property
|
|
||||||
def should_close(self) -> bool:
|
|
||||||
return bool(
|
|
||||||
self._should_close
|
|
||||||
or (self._payload is not None and not self._payload.is_eof())
|
|
||||||
or self._upgraded
|
|
||||||
or self._exception is not None
|
|
||||||
or self._payload_parser is not None
|
|
||||||
or self._buffer
|
|
||||||
or self._tail
|
|
||||||
)
|
|
||||||
|
|
||||||
def force_close(self) -> None:
|
|
||||||
self._should_close = True
|
|
||||||
|
|
||||||
def close(self) -> None:
|
|
||||||
self._exception = None # Break cyclic references
|
|
||||||
transport = self.transport
|
|
||||||
if transport is not None:
|
|
||||||
transport.close()
|
|
||||||
self.transport = None
|
|
||||||
self._payload = None
|
|
||||||
self._drop_timeout()
|
|
||||||
|
|
||||||
def abort(self) -> None:
|
|
||||||
self._exception = None # Break cyclic references
|
|
||||||
transport = self.transport
|
|
||||||
if transport is not None:
|
|
||||||
transport.abort()
|
|
||||||
self.transport = None
|
|
||||||
self._payload = None
|
|
||||||
self._drop_timeout()
|
|
||||||
|
|
||||||
def is_connected(self) -> bool:
|
|
||||||
return self.transport is not None and not self.transport.is_closing()
|
|
||||||
|
|
||||||
def connection_lost(self, exc: Optional[BaseException]) -> None:
|
|
||||||
self._connection_lost_called = True
|
|
||||||
self._drop_timeout()
|
|
||||||
|
|
||||||
original_connection_error = exc
|
|
||||||
reraised_exc = original_connection_error
|
|
||||||
|
|
||||||
connection_closed_cleanly = original_connection_error is None
|
|
||||||
|
|
||||||
if self._closed is not None:
|
|
||||||
# If someone is waiting for the closed future,
|
|
||||||
# we should set it to None or an exception. If
|
|
||||||
# self._closed is None, it means that
|
|
||||||
# connection_lost() was called already
|
|
||||||
# or nobody is waiting for it.
|
|
||||||
if connection_closed_cleanly:
|
|
||||||
set_result(self._closed, None)
|
|
||||||
else:
|
|
||||||
assert original_connection_error is not None
|
|
||||||
set_exception(
|
|
||||||
self._closed,
|
|
||||||
ClientConnectionError(
|
|
||||||
f"Connection lost: {original_connection_error !s}",
|
|
||||||
),
|
|
||||||
original_connection_error,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self._payload_parser is not None:
|
|
||||||
with suppress(Exception): # FIXME: log this somehow?
|
|
||||||
self._payload_parser.feed_eof()
|
|
||||||
|
|
||||||
uncompleted = None
|
|
||||||
if self._parser is not None:
|
|
||||||
try:
|
|
||||||
uncompleted = self._parser.feed_eof()
|
|
||||||
except Exception as underlying_exc:
|
|
||||||
if self._payload is not None:
|
|
||||||
client_payload_exc_msg = (
|
|
||||||
f"Response payload is not completed: {underlying_exc !r}"
|
|
||||||
)
|
|
||||||
if not connection_closed_cleanly:
|
|
||||||
client_payload_exc_msg = (
|
|
||||||
f"{client_payload_exc_msg !s}. "
|
|
||||||
f"{original_connection_error !r}"
|
|
||||||
)
|
|
||||||
set_exception(
|
|
||||||
self._payload,
|
|
||||||
ClientPayloadError(client_payload_exc_msg),
|
|
||||||
underlying_exc,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not self.is_eof():
|
|
||||||
if isinstance(original_connection_error, OSError):
|
|
||||||
reraised_exc = ClientOSError(*original_connection_error.args)
|
|
||||||
if connection_closed_cleanly:
|
|
||||||
reraised_exc = ServerDisconnectedError(uncompleted)
|
|
||||||
# assigns self._should_close to True as side effect,
|
|
||||||
# we do it anyway below
|
|
||||||
underlying_non_eof_exc = (
|
|
||||||
_EXC_SENTINEL
|
|
||||||
if connection_closed_cleanly
|
|
||||||
else original_connection_error
|
|
||||||
)
|
|
||||||
assert underlying_non_eof_exc is not None
|
|
||||||
assert reraised_exc is not None
|
|
||||||
self.set_exception(reraised_exc, underlying_non_eof_exc)
|
|
||||||
|
|
||||||
self._should_close = True
|
|
||||||
self._parser = None
|
|
||||||
self._payload = None
|
|
||||||
self._payload_parser = None
|
|
||||||
self._reading_paused = False
|
|
||||||
|
|
||||||
super().connection_lost(reraised_exc)
|
|
||||||
|
|
||||||
def eof_received(self) -> None:
|
|
||||||
# should call parser.feed_eof() most likely
|
|
||||||
self._drop_timeout()
|
|
||||||
|
|
||||||
def pause_reading(self) -> None:
|
|
||||||
super().pause_reading()
|
|
||||||
self._drop_timeout()
|
|
||||||
|
|
||||||
def resume_reading(self) -> None:
|
|
||||||
super().resume_reading()
|
|
||||||
self._reschedule_timeout()
|
|
||||||
|
|
||||||
def set_exception(
|
|
||||||
self,
|
|
||||||
exc: BaseException,
|
|
||||||
exc_cause: BaseException = _EXC_SENTINEL,
|
|
||||||
) -> None:
|
|
||||||
self._should_close = True
|
|
||||||
self._drop_timeout()
|
|
||||||
super().set_exception(exc, exc_cause)
|
|
||||||
|
|
||||||
def set_parser(self, parser: Any, payload: Any) -> None:
|
|
||||||
# TODO: actual types are:
|
|
||||||
# parser: WebSocketReader
|
|
||||||
# payload: WebSocketDataQueue
|
|
||||||
# but they are not generi enough
|
|
||||||
# Need an ABC for both types
|
|
||||||
self._payload = payload
|
|
||||||
self._payload_parser = parser
|
|
||||||
|
|
||||||
self._drop_timeout()
|
|
||||||
|
|
||||||
if self._tail:
|
|
||||||
data, self._tail = self._tail, b""
|
|
||||||
self.data_received(data)
|
|
||||||
|
|
||||||
def set_response_params(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
timer: Optional[BaseTimerContext] = None,
|
|
||||||
skip_payload: bool = False,
|
|
||||||
read_until_eof: bool = False,
|
|
||||||
auto_decompress: bool = True,
|
|
||||||
read_timeout: Optional[float] = None,
|
|
||||||
read_bufsize: int = 2**16,
|
|
||||||
timeout_ceil_threshold: float = 5,
|
|
||||||
max_line_size: int = 8190,
|
|
||||||
max_field_size: int = 8190,
|
|
||||||
) -> None:
|
|
||||||
self._skip_payload = skip_payload
|
|
||||||
|
|
||||||
self._read_timeout = read_timeout
|
|
||||||
|
|
||||||
self._timeout_ceil_threshold = timeout_ceil_threshold
|
|
||||||
|
|
||||||
self._parser = HttpResponseParser(
|
|
||||||
self,
|
|
||||||
self._loop,
|
|
||||||
read_bufsize,
|
|
||||||
timer=timer,
|
|
||||||
payload_exception=ClientPayloadError,
|
|
||||||
response_with_body=not skip_payload,
|
|
||||||
read_until_eof=read_until_eof,
|
|
||||||
auto_decompress=auto_decompress,
|
|
||||||
max_line_size=max_line_size,
|
|
||||||
max_field_size=max_field_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self._tail:
|
|
||||||
data, self._tail = self._tail, b""
|
|
||||||
self.data_received(data)
|
|
||||||
|
|
||||||
def _drop_timeout(self) -> None:
|
|
||||||
if self._read_timeout_handle is not None:
|
|
||||||
self._read_timeout_handle.cancel()
|
|
||||||
self._read_timeout_handle = None
|
|
||||||
|
|
||||||
def _reschedule_timeout(self) -> None:
|
|
||||||
timeout = self._read_timeout
|
|
||||||
if self._read_timeout_handle is not None:
|
|
||||||
self._read_timeout_handle.cancel()
|
|
||||||
|
|
||||||
if timeout:
|
|
||||||
self._read_timeout_handle = self._loop.call_later(
|
|
||||||
timeout, self._on_read_timeout
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self._read_timeout_handle = None
|
|
||||||
|
|
||||||
def start_timeout(self) -> None:
|
|
||||||
self._reschedule_timeout()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def read_timeout(self) -> Optional[float]:
|
|
||||||
return self._read_timeout
|
|
||||||
|
|
||||||
@read_timeout.setter
|
|
||||||
def read_timeout(self, read_timeout: Optional[float]) -> None:
|
|
||||||
self._read_timeout = read_timeout
|
|
||||||
|
|
||||||
def _on_read_timeout(self) -> None:
|
|
||||||
exc = SocketTimeoutError("Timeout on reading data from socket")
|
|
||||||
self.set_exception(exc)
|
|
||||||
if self._payload is not None:
|
|
||||||
set_exception(self._payload, exc)
|
|
||||||
|
|
||||||
def data_received(self, data: bytes) -> None:
|
|
||||||
self._reschedule_timeout()
|
|
||||||
|
|
||||||
if not data:
|
|
||||||
return
|
|
||||||
|
|
||||||
# custom payload parser - currently always WebSocketReader
|
|
||||||
if self._payload_parser is not None:
|
|
||||||
eof, tail = self._payload_parser.feed_data(data)
|
|
||||||
if eof:
|
|
||||||
self._payload = None
|
|
||||||
self._payload_parser = None
|
|
||||||
|
|
||||||
if tail:
|
|
||||||
self.data_received(tail)
|
|
||||||
return
|
|
||||||
|
|
||||||
if self._upgraded or self._parser is None:
|
|
||||||
# i.e. websocket connection, websocket parser is not set yet
|
|
||||||
self._tail += data
|
|
||||||
return
|
|
||||||
|
|
||||||
# parse http messages
|
|
||||||
try:
|
|
||||||
messages, upgraded, tail = self._parser.feed_data(data)
|
|
||||||
except BaseException as underlying_exc:
|
|
||||||
if self.transport is not None:
|
|
||||||
# connection.release() could be called BEFORE
|
|
||||||
# data_received(), the transport is already
|
|
||||||
# closed in this case
|
|
||||||
self.transport.close()
|
|
||||||
# should_close is True after the call
|
|
||||||
if isinstance(underlying_exc, HttpProcessingError):
|
|
||||||
exc = HttpProcessingError(
|
|
||||||
code=underlying_exc.code,
|
|
||||||
message=underlying_exc.message,
|
|
||||||
headers=underlying_exc.headers,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
exc = HttpProcessingError()
|
|
||||||
self.set_exception(exc, underlying_exc)
|
|
||||||
return
|
|
||||||
|
|
||||||
self._upgraded = upgraded
|
|
||||||
|
|
||||||
payload: Optional[StreamReader] = None
|
|
||||||
for message, payload in messages:
|
|
||||||
if message.should_close:
|
|
||||||
self._should_close = True
|
|
||||||
|
|
||||||
self._payload = payload
|
|
||||||
|
|
||||||
if self._skip_payload or message.code in EMPTY_BODY_STATUS_CODES:
|
|
||||||
self.feed_data((message, EMPTY_PAYLOAD), 0)
|
|
||||||
else:
|
|
||||||
self.feed_data((message, payload), 0)
|
|
||||||
|
|
||||||
if payload is not None:
|
|
||||||
# new message(s) was processed
|
|
||||||
# register timeout handler unsubscribing
|
|
||||||
# either on end-of-stream or immediately for
|
|
||||||
# EMPTY_PAYLOAD
|
|
||||||
if payload is not EMPTY_PAYLOAD:
|
|
||||||
payload.on_eof(self._drop_timeout)
|
|
||||||
else:
|
|
||||||
self._drop_timeout()
|
|
||||||
|
|
||||||
if upgraded and tail:
|
|
||||||
self.data_received(tail)
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,428 +0,0 @@
|
||||||
"""WebSocket client for asyncio."""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import sys
|
|
||||||
from types import TracebackType
|
|
||||||
from typing import Any, Optional, Type, cast
|
|
||||||
|
|
||||||
import attr
|
|
||||||
|
|
||||||
from ._websocket.reader import WebSocketDataQueue
|
|
||||||
from .client_exceptions import ClientError, ServerTimeoutError, WSMessageTypeError
|
|
||||||
from .client_reqrep import ClientResponse
|
|
||||||
from .helpers import calculate_timeout_when, set_result
|
|
||||||
from .http import (
|
|
||||||
WS_CLOSED_MESSAGE,
|
|
||||||
WS_CLOSING_MESSAGE,
|
|
||||||
WebSocketError,
|
|
||||||
WSCloseCode,
|
|
||||||
WSMessage,
|
|
||||||
WSMsgType,
|
|
||||||
)
|
|
||||||
from .http_websocket import _INTERNAL_RECEIVE_TYPES, WebSocketWriter
|
|
||||||
from .streams import EofStream
|
|
||||||
from .typedefs import (
|
|
||||||
DEFAULT_JSON_DECODER,
|
|
||||||
DEFAULT_JSON_ENCODER,
|
|
||||||
JSONDecoder,
|
|
||||||
JSONEncoder,
|
|
||||||
)
|
|
||||||
|
|
||||||
if sys.version_info >= (3, 11):
|
|
||||||
import asyncio as async_timeout
|
|
||||||
else:
|
|
||||||
import async_timeout
|
|
||||||
|
|
||||||
|
|
||||||
@attr.s(frozen=True, slots=True)
|
|
||||||
class ClientWSTimeout:
|
|
||||||
ws_receive = attr.ib(type=Optional[float], default=None)
|
|
||||||
ws_close = attr.ib(type=Optional[float], default=None)
|
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_WS_CLIENT_TIMEOUT = ClientWSTimeout(ws_receive=None, ws_close=10.0)
|
|
||||||
|
|
||||||
|
|
||||||
class ClientWebSocketResponse:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
reader: WebSocketDataQueue,
|
|
||||||
writer: WebSocketWriter,
|
|
||||||
protocol: Optional[str],
|
|
||||||
response: ClientResponse,
|
|
||||||
timeout: ClientWSTimeout,
|
|
||||||
autoclose: bool,
|
|
||||||
autoping: bool,
|
|
||||||
loop: asyncio.AbstractEventLoop,
|
|
||||||
*,
|
|
||||||
heartbeat: Optional[float] = None,
|
|
||||||
compress: int = 0,
|
|
||||||
client_notakeover: bool = False,
|
|
||||||
) -> None:
|
|
||||||
self._response = response
|
|
||||||
self._conn = response.connection
|
|
||||||
|
|
||||||
self._writer = writer
|
|
||||||
self._reader = reader
|
|
||||||
self._protocol = protocol
|
|
||||||
self._closed = False
|
|
||||||
self._closing = False
|
|
||||||
self._close_code: Optional[int] = None
|
|
||||||
self._timeout = timeout
|
|
||||||
self._autoclose = autoclose
|
|
||||||
self._autoping = autoping
|
|
||||||
self._heartbeat = heartbeat
|
|
||||||
self._heartbeat_cb: Optional[asyncio.TimerHandle] = None
|
|
||||||
self._heartbeat_when: float = 0.0
|
|
||||||
if heartbeat is not None:
|
|
||||||
self._pong_heartbeat = heartbeat / 2.0
|
|
||||||
self._pong_response_cb: Optional[asyncio.TimerHandle] = None
|
|
||||||
self._loop = loop
|
|
||||||
self._waiting: bool = False
|
|
||||||
self._close_wait: Optional[asyncio.Future[None]] = None
|
|
||||||
self._exception: Optional[BaseException] = None
|
|
||||||
self._compress = compress
|
|
||||||
self._client_notakeover = client_notakeover
|
|
||||||
self._ping_task: Optional[asyncio.Task[None]] = None
|
|
||||||
|
|
||||||
self._reset_heartbeat()
|
|
||||||
|
|
||||||
def _cancel_heartbeat(self) -> None:
|
|
||||||
self._cancel_pong_response_cb()
|
|
||||||
if self._heartbeat_cb is not None:
|
|
||||||
self._heartbeat_cb.cancel()
|
|
||||||
self._heartbeat_cb = None
|
|
||||||
if self._ping_task is not None:
|
|
||||||
self._ping_task.cancel()
|
|
||||||
self._ping_task = None
|
|
||||||
|
|
||||||
def _cancel_pong_response_cb(self) -> None:
|
|
||||||
if self._pong_response_cb is not None:
|
|
||||||
self._pong_response_cb.cancel()
|
|
||||||
self._pong_response_cb = None
|
|
||||||
|
|
||||||
def _reset_heartbeat(self) -> None:
|
|
||||||
if self._heartbeat is None:
|
|
||||||
return
|
|
||||||
self._cancel_pong_response_cb()
|
|
||||||
loop = self._loop
|
|
||||||
assert loop is not None
|
|
||||||
conn = self._conn
|
|
||||||
timeout_ceil_threshold = (
|
|
||||||
conn._connector._timeout_ceil_threshold if conn is not None else 5
|
|
||||||
)
|
|
||||||
now = loop.time()
|
|
||||||
when = calculate_timeout_when(now, self._heartbeat, timeout_ceil_threshold)
|
|
||||||
self._heartbeat_when = when
|
|
||||||
if self._heartbeat_cb is None:
|
|
||||||
# We do not cancel the previous heartbeat_cb here because
|
|
||||||
# it generates a significant amount of TimerHandle churn
|
|
||||||
# which causes asyncio to rebuild the heap frequently.
|
|
||||||
# Instead _send_heartbeat() will reschedule the next
|
|
||||||
# heartbeat if it fires too early.
|
|
||||||
self._heartbeat_cb = loop.call_at(when, self._send_heartbeat)
|
|
||||||
|
|
||||||
def _send_heartbeat(self) -> None:
|
|
||||||
self._heartbeat_cb = None
|
|
||||||
loop = self._loop
|
|
||||||
now = loop.time()
|
|
||||||
if now < self._heartbeat_when:
|
|
||||||
# Heartbeat fired too early, reschedule
|
|
||||||
self._heartbeat_cb = loop.call_at(
|
|
||||||
self._heartbeat_when, self._send_heartbeat
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
conn = self._conn
|
|
||||||
timeout_ceil_threshold = (
|
|
||||||
conn._connector._timeout_ceil_threshold if conn is not None else 5
|
|
||||||
)
|
|
||||||
when = calculate_timeout_when(now, self._pong_heartbeat, timeout_ceil_threshold)
|
|
||||||
self._cancel_pong_response_cb()
|
|
||||||
self._pong_response_cb = loop.call_at(when, self._pong_not_received)
|
|
||||||
|
|
||||||
coro = self._writer.send_frame(b"", WSMsgType.PING)
|
|
||||||
if sys.version_info >= (3, 12):
|
|
||||||
# Optimization for Python 3.12, try to send the ping
|
|
||||||
# immediately to avoid having to schedule
|
|
||||||
# the task on the event loop.
|
|
||||||
ping_task = asyncio.Task(coro, loop=loop, eager_start=True)
|
|
||||||
else:
|
|
||||||
ping_task = loop.create_task(coro)
|
|
||||||
|
|
||||||
if not ping_task.done():
|
|
||||||
self._ping_task = ping_task
|
|
||||||
ping_task.add_done_callback(self._ping_task_done)
|
|
||||||
else:
|
|
||||||
self._ping_task_done(ping_task)
|
|
||||||
|
|
||||||
def _ping_task_done(self, task: "asyncio.Task[None]") -> None:
|
|
||||||
"""Callback for when the ping task completes."""
|
|
||||||
if not task.cancelled() and (exc := task.exception()):
|
|
||||||
self._handle_ping_pong_exception(exc)
|
|
||||||
self._ping_task = None
|
|
||||||
|
|
||||||
def _pong_not_received(self) -> None:
|
|
||||||
self._handle_ping_pong_exception(
|
|
||||||
ServerTimeoutError(f"No PONG received after {self._pong_heartbeat} seconds")
|
|
||||||
)
|
|
||||||
|
|
||||||
def _handle_ping_pong_exception(self, exc: BaseException) -> None:
|
|
||||||
"""Handle exceptions raised during ping/pong processing."""
|
|
||||||
if self._closed:
|
|
||||||
return
|
|
||||||
self._set_closed()
|
|
||||||
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
|
|
||||||
self._exception = exc
|
|
||||||
self._response.close()
|
|
||||||
if self._waiting and not self._closing:
|
|
||||||
self._reader.feed_data(WSMessage(WSMsgType.ERROR, exc, None), 0)
|
|
||||||
|
|
||||||
def _set_closed(self) -> None:
|
|
||||||
"""Set the connection to closed.
|
|
||||||
|
|
||||||
Cancel any heartbeat timers and set the closed flag.
|
|
||||||
"""
|
|
||||||
self._closed = True
|
|
||||||
self._cancel_heartbeat()
|
|
||||||
|
|
||||||
def _set_closing(self) -> None:
|
|
||||||
"""Set the connection to closing.
|
|
||||||
|
|
||||||
Cancel any heartbeat timers and set the closing flag.
|
|
||||||
"""
|
|
||||||
self._closing = True
|
|
||||||
self._cancel_heartbeat()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def closed(self) -> bool:
|
|
||||||
return self._closed
|
|
||||||
|
|
||||||
@property
|
|
||||||
def close_code(self) -> Optional[int]:
|
|
||||||
return self._close_code
|
|
||||||
|
|
||||||
@property
|
|
||||||
def protocol(self) -> Optional[str]:
|
|
||||||
return self._protocol
|
|
||||||
|
|
||||||
@property
|
|
||||||
def compress(self) -> int:
|
|
||||||
return self._compress
|
|
||||||
|
|
||||||
@property
|
|
||||||
def client_notakeover(self) -> bool:
|
|
||||||
return self._client_notakeover
|
|
||||||
|
|
||||||
def get_extra_info(self, name: str, default: Any = None) -> Any:
|
|
||||||
"""extra info from connection transport"""
|
|
||||||
conn = self._response.connection
|
|
||||||
if conn is None:
|
|
||||||
return default
|
|
||||||
transport = conn.transport
|
|
||||||
if transport is None:
|
|
||||||
return default
|
|
||||||
return transport.get_extra_info(name, default)
|
|
||||||
|
|
||||||
def exception(self) -> Optional[BaseException]:
|
|
||||||
return self._exception
|
|
||||||
|
|
||||||
async def ping(self, message: bytes = b"") -> None:
|
|
||||||
await self._writer.send_frame(message, WSMsgType.PING)
|
|
||||||
|
|
||||||
async def pong(self, message: bytes = b"") -> None:
|
|
||||||
await self._writer.send_frame(message, WSMsgType.PONG)
|
|
||||||
|
|
||||||
async def send_frame(
|
|
||||||
self, message: bytes, opcode: WSMsgType, compress: Optional[int] = None
|
|
||||||
) -> None:
|
|
||||||
"""Send a frame over the websocket."""
|
|
||||||
await self._writer.send_frame(message, opcode, compress)
|
|
||||||
|
|
||||||
async def send_str(self, data: str, compress: Optional[int] = None) -> None:
|
|
||||||
if not isinstance(data, str):
|
|
||||||
raise TypeError("data argument must be str (%r)" % type(data))
|
|
||||||
await self._writer.send_frame(
|
|
||||||
data.encode("utf-8"), WSMsgType.TEXT, compress=compress
|
|
||||||
)
|
|
||||||
|
|
||||||
async def send_bytes(self, data: bytes, compress: Optional[int] = None) -> None:
|
|
||||||
if not isinstance(data, (bytes, bytearray, memoryview)):
|
|
||||||
raise TypeError("data argument must be byte-ish (%r)" % type(data))
|
|
||||||
await self._writer.send_frame(data, WSMsgType.BINARY, compress=compress)
|
|
||||||
|
|
||||||
async def send_json(
|
|
||||||
self,
|
|
||||||
data: Any,
|
|
||||||
compress: Optional[int] = None,
|
|
||||||
*,
|
|
||||||
dumps: JSONEncoder = DEFAULT_JSON_ENCODER,
|
|
||||||
) -> None:
|
|
||||||
await self.send_str(dumps(data), compress=compress)
|
|
||||||
|
|
||||||
async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bool:
|
|
||||||
# we need to break `receive()` cycle first,
|
|
||||||
# `close()` may be called from different task
|
|
||||||
if self._waiting and not self._closing:
|
|
||||||
assert self._loop is not None
|
|
||||||
self._close_wait = self._loop.create_future()
|
|
||||||
self._set_closing()
|
|
||||||
self._reader.feed_data(WS_CLOSING_MESSAGE, 0)
|
|
||||||
await self._close_wait
|
|
||||||
|
|
||||||
if self._closed:
|
|
||||||
return False
|
|
||||||
|
|
||||||
self._set_closed()
|
|
||||||
try:
|
|
||||||
await self._writer.close(code, message)
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
|
|
||||||
self._response.close()
|
|
||||||
raise
|
|
||||||
except Exception as exc:
|
|
||||||
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
|
|
||||||
self._exception = exc
|
|
||||||
self._response.close()
|
|
||||||
return True
|
|
||||||
|
|
||||||
if self._close_code:
|
|
||||||
self._response.close()
|
|
||||||
return True
|
|
||||||
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
async with async_timeout.timeout(self._timeout.ws_close):
|
|
||||||
msg = await self._reader.read()
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
|
|
||||||
self._response.close()
|
|
||||||
raise
|
|
||||||
except Exception as exc:
|
|
||||||
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
|
|
||||||
self._exception = exc
|
|
||||||
self._response.close()
|
|
||||||
return True
|
|
||||||
|
|
||||||
if msg.type is WSMsgType.CLOSE:
|
|
||||||
self._close_code = msg.data
|
|
||||||
self._response.close()
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def receive(self, timeout: Optional[float] = None) -> WSMessage:
|
|
||||||
receive_timeout = timeout or self._timeout.ws_receive
|
|
||||||
|
|
||||||
while True:
|
|
||||||
if self._waiting:
|
|
||||||
raise RuntimeError("Concurrent call to receive() is not allowed")
|
|
||||||
|
|
||||||
if self._closed:
|
|
||||||
return WS_CLOSED_MESSAGE
|
|
||||||
elif self._closing:
|
|
||||||
await self.close()
|
|
||||||
return WS_CLOSED_MESSAGE
|
|
||||||
|
|
||||||
try:
|
|
||||||
self._waiting = True
|
|
||||||
try:
|
|
||||||
if receive_timeout:
|
|
||||||
# Entering the context manager and creating
|
|
||||||
# Timeout() object can take almost 50% of the
|
|
||||||
# run time in this loop so we avoid it if
|
|
||||||
# there is no read timeout.
|
|
||||||
async with async_timeout.timeout(receive_timeout):
|
|
||||||
msg = await self._reader.read()
|
|
||||||
else:
|
|
||||||
msg = await self._reader.read()
|
|
||||||
self._reset_heartbeat()
|
|
||||||
finally:
|
|
||||||
self._waiting = False
|
|
||||||
if self._close_wait:
|
|
||||||
set_result(self._close_wait, None)
|
|
||||||
except (asyncio.CancelledError, asyncio.TimeoutError):
|
|
||||||
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
|
|
||||||
raise
|
|
||||||
except EofStream:
|
|
||||||
self._close_code = WSCloseCode.OK
|
|
||||||
await self.close()
|
|
||||||
return WSMessage(WSMsgType.CLOSED, None, None)
|
|
||||||
except ClientError:
|
|
||||||
# Likely ServerDisconnectedError when connection is lost
|
|
||||||
self._set_closed()
|
|
||||||
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
|
|
||||||
return WS_CLOSED_MESSAGE
|
|
||||||
except WebSocketError as exc:
|
|
||||||
self._close_code = exc.code
|
|
||||||
await self.close(code=exc.code)
|
|
||||||
return WSMessage(WSMsgType.ERROR, exc, None)
|
|
||||||
except Exception as exc:
|
|
||||||
self._exception = exc
|
|
||||||
self._set_closing()
|
|
||||||
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
|
|
||||||
await self.close()
|
|
||||||
return WSMessage(WSMsgType.ERROR, exc, None)
|
|
||||||
|
|
||||||
if msg.type not in _INTERNAL_RECEIVE_TYPES:
|
|
||||||
# If its not a close/closing/ping/pong message
|
|
||||||
# we can return it immediately
|
|
||||||
return msg
|
|
||||||
|
|
||||||
if msg.type is WSMsgType.CLOSE:
|
|
||||||
self._set_closing()
|
|
||||||
self._close_code = msg.data
|
|
||||||
if not self._closed and self._autoclose:
|
|
||||||
await self.close()
|
|
||||||
elif msg.type is WSMsgType.CLOSING:
|
|
||||||
self._set_closing()
|
|
||||||
elif msg.type is WSMsgType.PING and self._autoping:
|
|
||||||
await self.pong(msg.data)
|
|
||||||
continue
|
|
||||||
elif msg.type is WSMsgType.PONG and self._autoping:
|
|
||||||
continue
|
|
||||||
|
|
||||||
return msg
|
|
||||||
|
|
||||||
async def receive_str(self, *, timeout: Optional[float] = None) -> str:
|
|
||||||
msg = await self.receive(timeout)
|
|
||||||
if msg.type is not WSMsgType.TEXT:
|
|
||||||
raise WSMessageTypeError(
|
|
||||||
f"Received message {msg.type}:{msg.data!r} is not WSMsgType.TEXT"
|
|
||||||
)
|
|
||||||
return cast(str, msg.data)
|
|
||||||
|
|
||||||
async def receive_bytes(self, *, timeout: Optional[float] = None) -> bytes:
|
|
||||||
msg = await self.receive(timeout)
|
|
||||||
if msg.type is not WSMsgType.BINARY:
|
|
||||||
raise WSMessageTypeError(
|
|
||||||
f"Received message {msg.type}:{msg.data!r} is not WSMsgType.BINARY"
|
|
||||||
)
|
|
||||||
return cast(bytes, msg.data)
|
|
||||||
|
|
||||||
async def receive_json(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
loads: JSONDecoder = DEFAULT_JSON_DECODER,
|
|
||||||
timeout: Optional[float] = None,
|
|
||||||
) -> Any:
|
|
||||||
data = await self.receive_str(timeout=timeout)
|
|
||||||
return loads(data)
|
|
||||||
|
|
||||||
def __aiter__(self) -> "ClientWebSocketResponse":
|
|
||||||
return self
|
|
||||||
|
|
||||||
async def __anext__(self) -> WSMessage:
|
|
||||||
msg = await self.receive()
|
|
||||||
if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
|
|
||||||
raise StopAsyncIteration
|
|
||||||
return msg
|
|
||||||
|
|
||||||
async def __aenter__(self) -> "ClientWebSocketResponse":
|
|
||||||
return self
|
|
||||||
|
|
||||||
async def __aexit__(
|
|
||||||
self,
|
|
||||||
exc_type: Optional[Type[BaseException]],
|
|
||||||
exc_val: Optional[BaseException],
|
|
||||||
exc_tb: Optional[TracebackType],
|
|
||||||
) -> None:
|
|
||||||
await self.close()
|
|
||||||
|
|
@ -1,319 +0,0 @@
|
||||||
import asyncio
|
|
||||||
import sys
|
|
||||||
import zlib
|
|
||||||
from concurrent.futures import Executor
|
|
||||||
from typing import Any, Final, Optional, Protocol, TypedDict, cast
|
|
||||||
|
|
||||||
if sys.version_info >= (3, 12):
|
|
||||||
from collections.abc import Buffer
|
|
||||||
else:
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
Buffer = Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"]
|
|
||||||
|
|
||||||
try:
|
|
||||||
try:
|
|
||||||
import brotlicffi as brotli
|
|
||||||
except ImportError:
|
|
||||||
import brotli
|
|
||||||
|
|
||||||
HAS_BROTLI = True
|
|
||||||
except ImportError: # pragma: no cover
|
|
||||||
HAS_BROTLI = False
|
|
||||||
|
|
||||||
try:
|
|
||||||
if sys.version_info >= (3, 14):
|
|
||||||
from compression.zstd import ZstdDecompressor # noqa: I900
|
|
||||||
else: # TODO(PY314): Remove mentions of backports.zstd across codebase
|
|
||||||
from backports.zstd import ZstdDecompressor
|
|
||||||
|
|
||||||
HAS_ZSTD = True
|
|
||||||
except ImportError:
|
|
||||||
HAS_ZSTD = False
|
|
||||||
|
|
||||||
|
|
||||||
MAX_SYNC_CHUNK_SIZE = 1024
|
|
||||||
|
|
||||||
|
|
||||||
class ZLibCompressObjProtocol(Protocol):
|
|
||||||
def compress(self, data: Buffer) -> bytes: ...
|
|
||||||
def flush(self, mode: int = ..., /) -> bytes: ...
|
|
||||||
|
|
||||||
|
|
||||||
class ZLibDecompressObjProtocol(Protocol):
|
|
||||||
def decompress(self, data: Buffer, max_length: int = ...) -> bytes: ...
|
|
||||||
def flush(self, length: int = ..., /) -> bytes: ...
|
|
||||||
|
|
||||||
@property
|
|
||||||
def eof(self) -> bool: ...
|
|
||||||
|
|
||||||
|
|
||||||
class ZLibBackendProtocol(Protocol):
|
|
||||||
MAX_WBITS: int
|
|
||||||
Z_FULL_FLUSH: int
|
|
||||||
Z_SYNC_FLUSH: int
|
|
||||||
Z_BEST_SPEED: int
|
|
||||||
Z_FINISH: int
|
|
||||||
|
|
||||||
def compressobj(
|
|
||||||
self,
|
|
||||||
level: int = ...,
|
|
||||||
method: int = ...,
|
|
||||||
wbits: int = ...,
|
|
||||||
memLevel: int = ...,
|
|
||||||
strategy: int = ...,
|
|
||||||
zdict: Optional[Buffer] = ...,
|
|
||||||
) -> ZLibCompressObjProtocol: ...
|
|
||||||
def decompressobj(
|
|
||||||
self, wbits: int = ..., zdict: Buffer = ...
|
|
||||||
) -> ZLibDecompressObjProtocol: ...
|
|
||||||
|
|
||||||
def compress(
|
|
||||||
self, data: Buffer, /, level: int = ..., wbits: int = ...
|
|
||||||
) -> bytes: ...
|
|
||||||
def decompress(
|
|
||||||
self, data: Buffer, /, wbits: int = ..., bufsize: int = ...
|
|
||||||
) -> bytes: ...
|
|
||||||
|
|
||||||
|
|
||||||
class CompressObjArgs(TypedDict, total=False):
|
|
||||||
wbits: int
|
|
||||||
strategy: int
|
|
||||||
level: int
|
|
||||||
|
|
||||||
|
|
||||||
class ZLibBackendWrapper:
|
|
||||||
def __init__(self, _zlib_backend: ZLibBackendProtocol):
|
|
||||||
self._zlib_backend: ZLibBackendProtocol = _zlib_backend
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return getattr(self._zlib_backend, "__name__", "undefined")
|
|
||||||
|
|
||||||
@property
|
|
||||||
def MAX_WBITS(self) -> int:
|
|
||||||
return self._zlib_backend.MAX_WBITS
|
|
||||||
|
|
||||||
@property
|
|
||||||
def Z_FULL_FLUSH(self) -> int:
|
|
||||||
return self._zlib_backend.Z_FULL_FLUSH
|
|
||||||
|
|
||||||
@property
|
|
||||||
def Z_SYNC_FLUSH(self) -> int:
|
|
||||||
return self._zlib_backend.Z_SYNC_FLUSH
|
|
||||||
|
|
||||||
@property
|
|
||||||
def Z_BEST_SPEED(self) -> int:
|
|
||||||
return self._zlib_backend.Z_BEST_SPEED
|
|
||||||
|
|
||||||
@property
|
|
||||||
def Z_FINISH(self) -> int:
|
|
||||||
return self._zlib_backend.Z_FINISH
|
|
||||||
|
|
||||||
def compressobj(self, *args: Any, **kwargs: Any) -> ZLibCompressObjProtocol:
|
|
||||||
return self._zlib_backend.compressobj(*args, **kwargs)
|
|
||||||
|
|
||||||
def decompressobj(self, *args: Any, **kwargs: Any) -> ZLibDecompressObjProtocol:
|
|
||||||
return self._zlib_backend.decompressobj(*args, **kwargs)
|
|
||||||
|
|
||||||
def compress(self, data: Buffer, *args: Any, **kwargs: Any) -> bytes:
|
|
||||||
return self._zlib_backend.compress(data, *args, **kwargs)
|
|
||||||
|
|
||||||
def decompress(self, data: Buffer, *args: Any, **kwargs: Any) -> bytes:
|
|
||||||
return self._zlib_backend.decompress(data, *args, **kwargs)
|
|
||||||
|
|
||||||
# Everything not explicitly listed in the Protocol we just pass through
|
|
||||||
def __getattr__(self, attrname: str) -> Any:
|
|
||||||
return getattr(self._zlib_backend, attrname)
|
|
||||||
|
|
||||||
|
|
||||||
ZLibBackend: ZLibBackendWrapper = ZLibBackendWrapper(zlib)
|
|
||||||
|
|
||||||
|
|
||||||
def set_zlib_backend(new_zlib_backend: ZLibBackendProtocol) -> None:
|
|
||||||
ZLibBackend._zlib_backend = new_zlib_backend
|
|
||||||
|
|
||||||
|
|
||||||
def encoding_to_mode(
|
|
||||||
encoding: Optional[str] = None,
|
|
||||||
suppress_deflate_header: bool = False,
|
|
||||||
) -> int:
|
|
||||||
if encoding == "gzip":
|
|
||||||
return 16 + ZLibBackend.MAX_WBITS
|
|
||||||
|
|
||||||
return -ZLibBackend.MAX_WBITS if suppress_deflate_header else ZLibBackend.MAX_WBITS
|
|
||||||
|
|
||||||
|
|
||||||
class ZlibBaseHandler:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
mode: int,
|
|
||||||
executor: Optional[Executor] = None,
|
|
||||||
max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
|
|
||||||
):
|
|
||||||
self._mode = mode
|
|
||||||
self._executor = executor
|
|
||||||
self._max_sync_chunk_size = max_sync_chunk_size
|
|
||||||
|
|
||||||
|
|
||||||
class ZLibCompressor(ZlibBaseHandler):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
encoding: Optional[str] = None,
|
|
||||||
suppress_deflate_header: bool = False,
|
|
||||||
level: Optional[int] = None,
|
|
||||||
wbits: Optional[int] = None,
|
|
||||||
strategy: Optional[int] = None,
|
|
||||||
executor: Optional[Executor] = None,
|
|
||||||
max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
|
|
||||||
):
|
|
||||||
super().__init__(
|
|
||||||
mode=(
|
|
||||||
encoding_to_mode(encoding, suppress_deflate_header)
|
|
||||||
if wbits is None
|
|
||||||
else wbits
|
|
||||||
),
|
|
||||||
executor=executor,
|
|
||||||
max_sync_chunk_size=max_sync_chunk_size,
|
|
||||||
)
|
|
||||||
self._zlib_backend: Final = ZLibBackendWrapper(ZLibBackend._zlib_backend)
|
|
||||||
|
|
||||||
kwargs: CompressObjArgs = {}
|
|
||||||
kwargs["wbits"] = self._mode
|
|
||||||
if strategy is not None:
|
|
||||||
kwargs["strategy"] = strategy
|
|
||||||
if level is not None:
|
|
||||||
kwargs["level"] = level
|
|
||||||
self._compressor = self._zlib_backend.compressobj(**kwargs)
|
|
||||||
|
|
||||||
def compress_sync(self, data: bytes) -> bytes:
|
|
||||||
return self._compressor.compress(data)
|
|
||||||
|
|
||||||
async def compress(self, data: bytes) -> bytes:
|
|
||||||
"""Compress the data and returned the compressed bytes.
|
|
||||||
|
|
||||||
Note that flush() must be called after the last call to compress()
|
|
||||||
|
|
||||||
If the data size is large than the max_sync_chunk_size, the compression
|
|
||||||
will be done in the executor. Otherwise, the compression will be done
|
|
||||||
in the event loop.
|
|
||||||
|
|
||||||
**WARNING: This method is NOT cancellation-safe when used with flush().**
|
|
||||||
If this operation is cancelled, the compressor state may be corrupted.
|
|
||||||
The connection MUST be closed after cancellation to avoid data corruption
|
|
||||||
in subsequent compress operations.
|
|
||||||
|
|
||||||
For cancellation-safe compression (e.g., WebSocket), the caller MUST wrap
|
|
||||||
compress() + flush() + send operations in a shield and lock to ensure atomicity.
|
|
||||||
"""
|
|
||||||
# For large payloads, offload compression to executor to avoid blocking event loop
|
|
||||||
should_use_executor = (
|
|
||||||
self._max_sync_chunk_size is not None
|
|
||||||
and len(data) > self._max_sync_chunk_size
|
|
||||||
)
|
|
||||||
if should_use_executor:
|
|
||||||
return await asyncio.get_running_loop().run_in_executor(
|
|
||||||
self._executor, self._compressor.compress, data
|
|
||||||
)
|
|
||||||
return self.compress_sync(data)
|
|
||||||
|
|
||||||
def flush(self, mode: Optional[int] = None) -> bytes:
|
|
||||||
"""Flush the compressor synchronously.
|
|
||||||
|
|
||||||
**WARNING: This method is NOT cancellation-safe when called after compress().**
|
|
||||||
The flush() operation accesses shared compressor state. If compress() was
|
|
||||||
cancelled, calling flush() may result in corrupted data. The connection MUST
|
|
||||||
be closed after compress() cancellation.
|
|
||||||
|
|
||||||
For cancellation-safe compression (e.g., WebSocket), the caller MUST wrap
|
|
||||||
compress() + flush() + send operations in a shield and lock to ensure atomicity.
|
|
||||||
"""
|
|
||||||
return self._compressor.flush(
|
|
||||||
mode if mode is not None else self._zlib_backend.Z_FINISH
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ZLibDecompressor(ZlibBaseHandler):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
encoding: Optional[str] = None,
|
|
||||||
suppress_deflate_header: bool = False,
|
|
||||||
executor: Optional[Executor] = None,
|
|
||||||
max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
|
|
||||||
):
|
|
||||||
super().__init__(
|
|
||||||
mode=encoding_to_mode(encoding, suppress_deflate_header),
|
|
||||||
executor=executor,
|
|
||||||
max_sync_chunk_size=max_sync_chunk_size,
|
|
||||||
)
|
|
||||||
self._zlib_backend: Final = ZLibBackendWrapper(ZLibBackend._zlib_backend)
|
|
||||||
self._decompressor = self._zlib_backend.decompressobj(wbits=self._mode)
|
|
||||||
|
|
||||||
def decompress_sync(self, data: bytes, max_length: int = 0) -> bytes:
|
|
||||||
return self._decompressor.decompress(data, max_length)
|
|
||||||
|
|
||||||
async def decompress(self, data: bytes, max_length: int = 0) -> bytes:
|
|
||||||
"""Decompress the data and return the decompressed bytes.
|
|
||||||
|
|
||||||
If the data size is large than the max_sync_chunk_size, the decompression
|
|
||||||
will be done in the executor. Otherwise, the decompression will be done
|
|
||||||
in the event loop.
|
|
||||||
"""
|
|
||||||
if (
|
|
||||||
self._max_sync_chunk_size is not None
|
|
||||||
and len(data) > self._max_sync_chunk_size
|
|
||||||
):
|
|
||||||
return await asyncio.get_running_loop().run_in_executor(
|
|
||||||
self._executor, self._decompressor.decompress, data, max_length
|
|
||||||
)
|
|
||||||
return self.decompress_sync(data, max_length)
|
|
||||||
|
|
||||||
def flush(self, length: int = 0) -> bytes:
|
|
||||||
return (
|
|
||||||
self._decompressor.flush(length)
|
|
||||||
if length > 0
|
|
||||||
else self._decompressor.flush()
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def eof(self) -> bool:
|
|
||||||
return self._decompressor.eof
|
|
||||||
|
|
||||||
|
|
||||||
class BrotliDecompressor:
|
|
||||||
# Supports both 'brotlipy' and 'Brotli' packages
|
|
||||||
# since they share an import name. The top branches
|
|
||||||
# are for 'brotlipy' and bottom branches for 'Brotli'
|
|
||||||
def __init__(self) -> None:
|
|
||||||
if not HAS_BROTLI:
|
|
||||||
raise RuntimeError(
|
|
||||||
"The brotli decompression is not available. "
|
|
||||||
"Please install `Brotli` module"
|
|
||||||
)
|
|
||||||
self._obj = brotli.Decompressor()
|
|
||||||
|
|
||||||
def decompress_sync(self, data: bytes) -> bytes:
|
|
||||||
if hasattr(self._obj, "decompress"):
|
|
||||||
return cast(bytes, self._obj.decompress(data))
|
|
||||||
return cast(bytes, self._obj.process(data))
|
|
||||||
|
|
||||||
def flush(self) -> bytes:
|
|
||||||
if hasattr(self._obj, "flush"):
|
|
||||||
return cast(bytes, self._obj.flush())
|
|
||||||
return b""
|
|
||||||
|
|
||||||
|
|
||||||
class ZSTDDecompressor:
|
|
||||||
def __init__(self) -> None:
|
|
||||||
if not HAS_ZSTD:
|
|
||||||
raise RuntimeError(
|
|
||||||
"The zstd decompression is not available. "
|
|
||||||
"Please install `backports.zstd` module"
|
|
||||||
)
|
|
||||||
self._obj = ZstdDecompressor()
|
|
||||||
|
|
||||||
def decompress_sync(self, data: bytes) -> bytes:
|
|
||||||
return self._obj.decompress(data)
|
|
||||||
|
|
||||||
def flush(self) -> bytes:
|
|
||||||
return b""
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,522 +0,0 @@
|
||||||
import asyncio
|
|
||||||
import calendar
|
|
||||||
import contextlib
|
|
||||||
import datetime
|
|
||||||
import heapq
|
|
||||||
import itertools
|
|
||||||
import os # noqa
|
|
||||||
import pathlib
|
|
||||||
import pickle
|
|
||||||
import re
|
|
||||||
import time
|
|
||||||
import warnings
|
|
||||||
from collections import defaultdict
|
|
||||||
from collections.abc import Mapping
|
|
||||||
from http.cookies import BaseCookie, Morsel, SimpleCookie
|
|
||||||
from typing import (
|
|
||||||
DefaultDict,
|
|
||||||
Dict,
|
|
||||||
Iterable,
|
|
||||||
Iterator,
|
|
||||||
List,
|
|
||||||
Optional,
|
|
||||||
Set,
|
|
||||||
Tuple,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
|
|
||||||
from yarl import URL
|
|
||||||
|
|
||||||
from ._cookie_helpers import preserve_morsel_with_coded_value
|
|
||||||
from .abc import AbstractCookieJar, ClearCookiePredicate
|
|
||||||
from .helpers import is_ip_address
|
|
||||||
from .typedefs import LooseCookies, PathLike, StrOrURL
|
|
||||||
|
|
||||||
__all__ = ("CookieJar", "DummyCookieJar")
|
|
||||||
|
|
||||||
|
|
||||||
CookieItem = Union[str, "Morsel[str]"]
|
|
||||||
|
|
||||||
# We cache these string methods here as their use is in performance critical code.
|
|
||||||
_FORMAT_PATH = "{}/{}".format
|
|
||||||
_FORMAT_DOMAIN_REVERSED = "{1}.{0}".format
|
|
||||||
|
|
||||||
# The minimum number of scheduled cookie expirations before we start cleaning up
|
|
||||||
# the expiration heap. This is a performance optimization to avoid cleaning up the
|
|
||||||
# heap too often when there are only a few scheduled expirations.
|
|
||||||
_MIN_SCHEDULED_COOKIE_EXPIRATION = 100
|
|
||||||
_SIMPLE_COOKIE = SimpleCookie()
|
|
||||||
|
|
||||||
|
|
||||||
class CookieJar(AbstractCookieJar):
|
|
||||||
"""Implements cookie storage adhering to RFC 6265."""
|
|
||||||
|
|
||||||
DATE_TOKENS_RE = re.compile(
|
|
||||||
r"[\x09\x20-\x2F\x3B-\x40\x5B-\x60\x7B-\x7E]*"
|
|
||||||
r"(?P<token>[\x00-\x08\x0A-\x1F\d:a-zA-Z\x7F-\xFF]+)"
|
|
||||||
)
|
|
||||||
|
|
||||||
DATE_HMS_TIME_RE = re.compile(r"(\d{1,2}):(\d{1,2}):(\d{1,2})")
|
|
||||||
|
|
||||||
DATE_DAY_OF_MONTH_RE = re.compile(r"(\d{1,2})")
|
|
||||||
|
|
||||||
DATE_MONTH_RE = re.compile(
|
|
||||||
"(jan)|(feb)|(mar)|(apr)|(may)|(jun)|(jul)|(aug)|(sep)|(oct)|(nov)|(dec)",
|
|
||||||
re.I,
|
|
||||||
)
|
|
||||||
|
|
||||||
DATE_YEAR_RE = re.compile(r"(\d{2,4})")
|
|
||||||
|
|
||||||
# calendar.timegm() fails for timestamps after datetime.datetime.max
|
|
||||||
# Minus one as a loss of precision occurs when timestamp() is called.
|
|
||||||
MAX_TIME = (
|
|
||||||
int(datetime.datetime.max.replace(tzinfo=datetime.timezone.utc).timestamp()) - 1
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
calendar.timegm(time.gmtime(MAX_TIME))
|
|
||||||
except (OSError, ValueError):
|
|
||||||
# Hit the maximum representable time on Windows
|
|
||||||
# https://learn.microsoft.com/en-us/cpp/c-runtime-library/reference/localtime-localtime32-localtime64
|
|
||||||
# Throws ValueError on PyPy 3.9, OSError elsewhere
|
|
||||||
MAX_TIME = calendar.timegm((3000, 12, 31, 23, 59, 59, -1, -1, -1))
|
|
||||||
except OverflowError:
|
|
||||||
# #4515: datetime.max may not be representable on 32-bit platforms
|
|
||||||
MAX_TIME = 2**31 - 1
|
|
||||||
# Avoid minuses in the future, 3x faster
|
|
||||||
SUB_MAX_TIME = MAX_TIME - 1
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
unsafe: bool = False,
|
|
||||||
quote_cookie: bool = True,
|
|
||||||
treat_as_secure_origin: Union[StrOrURL, List[StrOrURL], None] = None,
|
|
||||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
|
||||||
) -> None:
|
|
||||||
super().__init__(loop=loop)
|
|
||||||
self._cookies: DefaultDict[Tuple[str, str], SimpleCookie] = defaultdict(
|
|
||||||
SimpleCookie
|
|
||||||
)
|
|
||||||
self._morsel_cache: DefaultDict[Tuple[str, str], Dict[str, Morsel[str]]] = (
|
|
||||||
defaultdict(dict)
|
|
||||||
)
|
|
||||||
self._host_only_cookies: Set[Tuple[str, str]] = set()
|
|
||||||
self._unsafe = unsafe
|
|
||||||
self._quote_cookie = quote_cookie
|
|
||||||
if treat_as_secure_origin is None:
|
|
||||||
treat_as_secure_origin = []
|
|
||||||
elif isinstance(treat_as_secure_origin, URL):
|
|
||||||
treat_as_secure_origin = [treat_as_secure_origin.origin()]
|
|
||||||
elif isinstance(treat_as_secure_origin, str):
|
|
||||||
treat_as_secure_origin = [URL(treat_as_secure_origin).origin()]
|
|
||||||
else:
|
|
||||||
treat_as_secure_origin = [
|
|
||||||
URL(url).origin() if isinstance(url, str) else url.origin()
|
|
||||||
for url in treat_as_secure_origin
|
|
||||||
]
|
|
||||||
self._treat_as_secure_origin = treat_as_secure_origin
|
|
||||||
self._expire_heap: List[Tuple[float, Tuple[str, str, str]]] = []
|
|
||||||
self._expirations: Dict[Tuple[str, str, str], float] = {}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def quote_cookie(self) -> bool:
|
|
||||||
return self._quote_cookie
|
|
||||||
|
|
||||||
def save(self, file_path: PathLike) -> None:
|
|
||||||
file_path = pathlib.Path(file_path)
|
|
||||||
with file_path.open(mode="wb") as f:
|
|
||||||
pickle.dump(self._cookies, f, pickle.HIGHEST_PROTOCOL)
|
|
||||||
|
|
||||||
def load(self, file_path: PathLike) -> None:
|
|
||||||
file_path = pathlib.Path(file_path)
|
|
||||||
with file_path.open(mode="rb") as f:
|
|
||||||
self._cookies = pickle.load(f)
|
|
||||||
|
|
||||||
def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:
|
|
||||||
if predicate is None:
|
|
||||||
self._expire_heap.clear()
|
|
||||||
self._cookies.clear()
|
|
||||||
self._morsel_cache.clear()
|
|
||||||
self._host_only_cookies.clear()
|
|
||||||
self._expirations.clear()
|
|
||||||
return
|
|
||||||
|
|
||||||
now = time.time()
|
|
||||||
to_del = [
|
|
||||||
key
|
|
||||||
for (domain, path), cookie in self._cookies.items()
|
|
||||||
for name, morsel in cookie.items()
|
|
||||||
if (
|
|
||||||
(key := (domain, path, name)) in self._expirations
|
|
||||||
and self._expirations[key] <= now
|
|
||||||
)
|
|
||||||
or predicate(morsel)
|
|
||||||
]
|
|
||||||
if to_del:
|
|
||||||
self._delete_cookies(to_del)
|
|
||||||
|
|
||||||
def clear_domain(self, domain: str) -> None:
|
|
||||||
self.clear(lambda x: self._is_domain_match(domain, x["domain"]))
|
|
||||||
|
|
||||||
def __iter__(self) -> "Iterator[Morsel[str]]":
|
|
||||||
self._do_expiration()
|
|
||||||
for val in self._cookies.values():
|
|
||||||
yield from val.values()
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
"""Return number of cookies.
|
|
||||||
|
|
||||||
This function does not iterate self to avoid unnecessary expiration
|
|
||||||
checks.
|
|
||||||
"""
|
|
||||||
return sum(len(cookie.values()) for cookie in self._cookies.values())
|
|
||||||
|
|
||||||
def _do_expiration(self) -> None:
|
|
||||||
"""Remove expired cookies."""
|
|
||||||
if not (expire_heap_len := len(self._expire_heap)):
|
|
||||||
return
|
|
||||||
|
|
||||||
# If the expiration heap grows larger than the number expirations
|
|
||||||
# times two, we clean it up to avoid keeping expired entries in
|
|
||||||
# the heap and consuming memory. We guard this with a minimum
|
|
||||||
# threshold to avoid cleaning up the heap too often when there are
|
|
||||||
# only a few scheduled expirations.
|
|
||||||
if (
|
|
||||||
expire_heap_len > _MIN_SCHEDULED_COOKIE_EXPIRATION
|
|
||||||
and expire_heap_len > len(self._expirations) * 2
|
|
||||||
):
|
|
||||||
# Remove any expired entries from the expiration heap
|
|
||||||
# that do not match the expiration time in the expirations
|
|
||||||
# as it means the cookie has been re-added to the heap
|
|
||||||
# with a different expiration time.
|
|
||||||
self._expire_heap = [
|
|
||||||
entry
|
|
||||||
for entry in self._expire_heap
|
|
||||||
if self._expirations.get(entry[1]) == entry[0]
|
|
||||||
]
|
|
||||||
heapq.heapify(self._expire_heap)
|
|
||||||
|
|
||||||
now = time.time()
|
|
||||||
to_del: List[Tuple[str, str, str]] = []
|
|
||||||
# Find any expired cookies and add them to the to-delete list
|
|
||||||
while self._expire_heap:
|
|
||||||
when, cookie_key = self._expire_heap[0]
|
|
||||||
if when > now:
|
|
||||||
break
|
|
||||||
heapq.heappop(self._expire_heap)
|
|
||||||
# Check if the cookie hasn't been re-added to the heap
|
|
||||||
# with a different expiration time as it will be removed
|
|
||||||
# later when it reaches the top of the heap and its
|
|
||||||
# expiration time is met.
|
|
||||||
if self._expirations.get(cookie_key) == when:
|
|
||||||
to_del.append(cookie_key)
|
|
||||||
|
|
||||||
if to_del:
|
|
||||||
self._delete_cookies(to_del)
|
|
||||||
|
|
||||||
def _delete_cookies(self, to_del: List[Tuple[str, str, str]]) -> None:
|
|
||||||
for domain, path, name in to_del:
|
|
||||||
self._host_only_cookies.discard((domain, name))
|
|
||||||
self._cookies[(domain, path)].pop(name, None)
|
|
||||||
self._morsel_cache[(domain, path)].pop(name, None)
|
|
||||||
self._expirations.pop((domain, path, name), None)
|
|
||||||
|
|
||||||
def _expire_cookie(self, when: float, domain: str, path: str, name: str) -> None:
|
|
||||||
cookie_key = (domain, path, name)
|
|
||||||
if self._expirations.get(cookie_key) == when:
|
|
||||||
# Avoid adding duplicates to the heap
|
|
||||||
return
|
|
||||||
heapq.heappush(self._expire_heap, (when, cookie_key))
|
|
||||||
self._expirations[cookie_key] = when
|
|
||||||
|
|
||||||
def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
|
|
||||||
"""Update cookies."""
|
|
||||||
hostname = response_url.raw_host
|
|
||||||
|
|
||||||
if not self._unsafe and is_ip_address(hostname):
|
|
||||||
# Don't accept cookies from IPs
|
|
||||||
return
|
|
||||||
|
|
||||||
if isinstance(cookies, Mapping):
|
|
||||||
cookies = cookies.items()
|
|
||||||
|
|
||||||
for name, cookie in cookies:
|
|
||||||
if not isinstance(cookie, Morsel):
|
|
||||||
tmp = SimpleCookie()
|
|
||||||
tmp[name] = cookie # type: ignore[assignment]
|
|
||||||
cookie = tmp[name]
|
|
||||||
|
|
||||||
domain = cookie["domain"]
|
|
||||||
|
|
||||||
# ignore domains with trailing dots
|
|
||||||
if domain and domain[-1] == ".":
|
|
||||||
domain = ""
|
|
||||||
del cookie["domain"]
|
|
||||||
|
|
||||||
if not domain and hostname is not None:
|
|
||||||
# Set the cookie's domain to the response hostname
|
|
||||||
# and set its host-only-flag
|
|
||||||
self._host_only_cookies.add((hostname, name))
|
|
||||||
domain = cookie["domain"] = hostname
|
|
||||||
|
|
||||||
if domain and domain[0] == ".":
|
|
||||||
# Remove leading dot
|
|
||||||
domain = domain[1:]
|
|
||||||
cookie["domain"] = domain
|
|
||||||
|
|
||||||
if hostname and not self._is_domain_match(domain, hostname):
|
|
||||||
# Setting cookies for different domains is not allowed
|
|
||||||
continue
|
|
||||||
|
|
||||||
path = cookie["path"]
|
|
||||||
if not path or path[0] != "/":
|
|
||||||
# Set the cookie's path to the response path
|
|
||||||
path = response_url.path
|
|
||||||
if not path.startswith("/"):
|
|
||||||
path = "/"
|
|
||||||
else:
|
|
||||||
# Cut everything from the last slash to the end
|
|
||||||
path = "/" + path[1 : path.rfind("/")]
|
|
||||||
cookie["path"] = path
|
|
||||||
path = path.rstrip("/")
|
|
||||||
|
|
||||||
if max_age := cookie["max-age"]:
|
|
||||||
try:
|
|
||||||
delta_seconds = int(max_age)
|
|
||||||
max_age_expiration = min(time.time() + delta_seconds, self.MAX_TIME)
|
|
||||||
self._expire_cookie(max_age_expiration, domain, path, name)
|
|
||||||
except ValueError:
|
|
||||||
cookie["max-age"] = ""
|
|
||||||
|
|
||||||
elif expires := cookie["expires"]:
|
|
||||||
if expire_time := self._parse_date(expires):
|
|
||||||
self._expire_cookie(expire_time, domain, path, name)
|
|
||||||
else:
|
|
||||||
cookie["expires"] = ""
|
|
||||||
|
|
||||||
key = (domain, path)
|
|
||||||
if self._cookies[key].get(name) != cookie:
|
|
||||||
# Don't blow away the cache if the same
|
|
||||||
# cookie gets set again
|
|
||||||
self._cookies[key][name] = cookie
|
|
||||||
self._morsel_cache[key].pop(name, None)
|
|
||||||
|
|
||||||
self._do_expiration()
|
|
||||||
|
|
||||||
def filter_cookies(self, request_url: URL = URL()) -> "BaseCookie[str]":
|
|
||||||
"""Returns this jar's cookies filtered by their attributes."""
|
|
||||||
# We always use BaseCookie now since all
|
|
||||||
# cookies set on on filtered are fully constructed
|
|
||||||
# Morsels, not just names and values.
|
|
||||||
filtered: BaseCookie[str] = BaseCookie()
|
|
||||||
if not self._cookies:
|
|
||||||
# Skip do_expiration() if there are no cookies.
|
|
||||||
return filtered
|
|
||||||
self._do_expiration()
|
|
||||||
if not self._cookies:
|
|
||||||
# Skip rest of function if no non-expired cookies.
|
|
||||||
return filtered
|
|
||||||
if type(request_url) is not URL:
|
|
||||||
warnings.warn(
|
|
||||||
"filter_cookies expects yarl.URL instances only,"
|
|
||||||
f"and will stop working in 4.x, got {type(request_url)}",
|
|
||||||
DeprecationWarning,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
request_url = URL(request_url)
|
|
||||||
hostname = request_url.raw_host or ""
|
|
||||||
|
|
||||||
is_not_secure = request_url.scheme not in ("https", "wss")
|
|
||||||
if is_not_secure and self._treat_as_secure_origin:
|
|
||||||
request_origin = URL()
|
|
||||||
with contextlib.suppress(ValueError):
|
|
||||||
request_origin = request_url.origin()
|
|
||||||
is_not_secure = request_origin not in self._treat_as_secure_origin
|
|
||||||
|
|
||||||
# Send shared cookie
|
|
||||||
key = ("", "")
|
|
||||||
for c in self._cookies[key].values():
|
|
||||||
# Check cache first
|
|
||||||
if c.key in self._morsel_cache[key]:
|
|
||||||
filtered[c.key] = self._morsel_cache[key][c.key]
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Build and cache the morsel
|
|
||||||
mrsl_val = self._build_morsel(c)
|
|
||||||
self._morsel_cache[key][c.key] = mrsl_val
|
|
||||||
filtered[c.key] = mrsl_val
|
|
||||||
|
|
||||||
if is_ip_address(hostname):
|
|
||||||
if not self._unsafe:
|
|
||||||
return filtered
|
|
||||||
domains: Iterable[str] = (hostname,)
|
|
||||||
else:
|
|
||||||
# Get all the subdomains that might match a cookie (e.g. "foo.bar.com", "bar.com", "com")
|
|
||||||
domains = itertools.accumulate(
|
|
||||||
reversed(hostname.split(".")), _FORMAT_DOMAIN_REVERSED
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get all the path prefixes that might match a cookie (e.g. "", "/foo", "/foo/bar")
|
|
||||||
paths = itertools.accumulate(request_url.path.split("/"), _FORMAT_PATH)
|
|
||||||
# Create every combination of (domain, path) pairs.
|
|
||||||
pairs = itertools.product(domains, paths)
|
|
||||||
|
|
||||||
path_len = len(request_url.path)
|
|
||||||
# Point 2: https://www.rfc-editor.org/rfc/rfc6265.html#section-5.4
|
|
||||||
for p in pairs:
|
|
||||||
if p not in self._cookies:
|
|
||||||
continue
|
|
||||||
for name, cookie in self._cookies[p].items():
|
|
||||||
domain = cookie["domain"]
|
|
||||||
|
|
||||||
if (domain, name) in self._host_only_cookies and domain != hostname:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Skip edge case when the cookie has a trailing slash but request doesn't.
|
|
||||||
if len(cookie["path"]) > path_len:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if is_not_secure and cookie["secure"]:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# We already built the Morsel so reuse it here
|
|
||||||
if name in self._morsel_cache[p]:
|
|
||||||
filtered[name] = self._morsel_cache[p][name]
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Build and cache the morsel
|
|
||||||
mrsl_val = self._build_morsel(cookie)
|
|
||||||
self._morsel_cache[p][name] = mrsl_val
|
|
||||||
filtered[name] = mrsl_val
|
|
||||||
|
|
||||||
return filtered
|
|
||||||
|
|
||||||
def _build_morsel(self, cookie: Morsel[str]) -> Morsel[str]:
|
|
||||||
"""Build a morsel for sending, respecting quote_cookie setting."""
|
|
||||||
if self._quote_cookie and cookie.coded_value and cookie.coded_value[0] == '"':
|
|
||||||
return preserve_morsel_with_coded_value(cookie)
|
|
||||||
morsel: Morsel[str] = Morsel()
|
|
||||||
if self._quote_cookie:
|
|
||||||
value, coded_value = _SIMPLE_COOKIE.value_encode(cookie.value)
|
|
||||||
else:
|
|
||||||
coded_value = value = cookie.value
|
|
||||||
# We use __setstate__ instead of the public set() API because it allows us to
|
|
||||||
# bypass validation and set already validated state. This is more stable than
|
|
||||||
# setting protected attributes directly and unlikely to change since it would
|
|
||||||
# break pickling.
|
|
||||||
morsel.__setstate__({"key": cookie.key, "value": value, "coded_value": coded_value}) # type: ignore[attr-defined]
|
|
||||||
return morsel
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _is_domain_match(domain: str, hostname: str) -> bool:
|
|
||||||
"""Implements domain matching adhering to RFC 6265."""
|
|
||||||
if hostname == domain:
|
|
||||||
return True
|
|
||||||
|
|
||||||
if not hostname.endswith(domain):
|
|
||||||
return False
|
|
||||||
|
|
||||||
non_matching = hostname[: -len(domain)]
|
|
||||||
|
|
||||||
if not non_matching.endswith("."):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return not is_ip_address(hostname)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _parse_date(cls, date_str: str) -> Optional[int]:
|
|
||||||
"""Implements date string parsing adhering to RFC 6265."""
|
|
||||||
if not date_str:
|
|
||||||
return None
|
|
||||||
|
|
||||||
found_time = False
|
|
||||||
found_day = False
|
|
||||||
found_month = False
|
|
||||||
found_year = False
|
|
||||||
|
|
||||||
hour = minute = second = 0
|
|
||||||
day = 0
|
|
||||||
month = 0
|
|
||||||
year = 0
|
|
||||||
|
|
||||||
for token_match in cls.DATE_TOKENS_RE.finditer(date_str):
|
|
||||||
|
|
||||||
token = token_match.group("token")
|
|
||||||
|
|
||||||
if not found_time:
|
|
||||||
time_match = cls.DATE_HMS_TIME_RE.match(token)
|
|
||||||
if time_match:
|
|
||||||
found_time = True
|
|
||||||
hour, minute, second = (int(s) for s in time_match.groups())
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not found_day:
|
|
||||||
day_match = cls.DATE_DAY_OF_MONTH_RE.match(token)
|
|
||||||
if day_match:
|
|
||||||
found_day = True
|
|
||||||
day = int(day_match.group())
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not found_month:
|
|
||||||
month_match = cls.DATE_MONTH_RE.match(token)
|
|
||||||
if month_match:
|
|
||||||
found_month = True
|
|
||||||
assert month_match.lastindex is not None
|
|
||||||
month = month_match.lastindex
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not found_year:
|
|
||||||
year_match = cls.DATE_YEAR_RE.match(token)
|
|
||||||
if year_match:
|
|
||||||
found_year = True
|
|
||||||
year = int(year_match.group())
|
|
||||||
|
|
||||||
if 70 <= year <= 99:
|
|
||||||
year += 1900
|
|
||||||
elif 0 <= year <= 69:
|
|
||||||
year += 2000
|
|
||||||
|
|
||||||
if False in (found_day, found_month, found_year, found_time):
|
|
||||||
return None
|
|
||||||
|
|
||||||
if not 1 <= day <= 31:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if year < 1601 or hour > 23 or minute > 59 or second > 59:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return calendar.timegm((year, month, day, hour, minute, second, -1, -1, -1))
|
|
||||||
|
|
||||||
|
|
||||||
class DummyCookieJar(AbstractCookieJar):
|
|
||||||
"""Implements a dummy cookie storage.
|
|
||||||
|
|
||||||
It can be used with the ClientSession when no cookie processing is needed.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, *, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
|
|
||||||
super().__init__(loop=loop)
|
|
||||||
|
|
||||||
def __iter__(self) -> "Iterator[Morsel[str]]":
|
|
||||||
while False:
|
|
||||||
yield None
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
@property
|
|
||||||
def quote_cookie(self) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def clear_domain(self, domain: str) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def filter_cookies(self, request_url: URL) -> "BaseCookie[str]":
|
|
||||||
return SimpleCookie()
|
|
||||||
|
|
@ -1,179 +0,0 @@
|
||||||
import io
|
|
||||||
import warnings
|
|
||||||
from typing import Any, Iterable, List, Optional
|
|
||||||
from urllib.parse import urlencode
|
|
||||||
|
|
||||||
from multidict import MultiDict, MultiDictProxy
|
|
||||||
|
|
||||||
from . import hdrs, multipart, payload
|
|
||||||
from .helpers import guess_filename
|
|
||||||
from .payload import Payload
|
|
||||||
|
|
||||||
__all__ = ("FormData",)
|
|
||||||
|
|
||||||
|
|
||||||
class FormData:
|
|
||||||
"""Helper class for form body generation.
|
|
||||||
|
|
||||||
Supports multipart/form-data and application/x-www-form-urlencoded.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
fields: Iterable[Any] = (),
|
|
||||||
quote_fields: bool = True,
|
|
||||||
charset: Optional[str] = None,
|
|
||||||
*,
|
|
||||||
default_to_multipart: bool = False,
|
|
||||||
) -> None:
|
|
||||||
self._writer = multipart.MultipartWriter("form-data")
|
|
||||||
self._fields: List[Any] = []
|
|
||||||
self._is_multipart = default_to_multipart
|
|
||||||
self._quote_fields = quote_fields
|
|
||||||
self._charset = charset
|
|
||||||
|
|
||||||
if isinstance(fields, dict):
|
|
||||||
fields = list(fields.items())
|
|
||||||
elif not isinstance(fields, (list, tuple)):
|
|
||||||
fields = (fields,)
|
|
||||||
self.add_fields(*fields)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_multipart(self) -> bool:
|
|
||||||
return self._is_multipart
|
|
||||||
|
|
||||||
def add_field(
|
|
||||||
self,
|
|
||||||
name: str,
|
|
||||||
value: Any,
|
|
||||||
*,
|
|
||||||
content_type: Optional[str] = None,
|
|
||||||
filename: Optional[str] = None,
|
|
||||||
content_transfer_encoding: Optional[str] = None,
|
|
||||||
) -> None:
|
|
||||||
|
|
||||||
if isinstance(value, io.IOBase):
|
|
||||||
self._is_multipart = True
|
|
||||||
elif isinstance(value, (bytes, bytearray, memoryview)):
|
|
||||||
msg = (
|
|
||||||
"In v4, passing bytes will no longer create a file field. "
|
|
||||||
"Please explicitly use the filename parameter or pass a BytesIO object."
|
|
||||||
)
|
|
||||||
if filename is None and content_transfer_encoding is None:
|
|
||||||
warnings.warn(msg, DeprecationWarning)
|
|
||||||
filename = name
|
|
||||||
|
|
||||||
type_options: MultiDict[str] = MultiDict({"name": name})
|
|
||||||
if filename is not None and not isinstance(filename, str):
|
|
||||||
raise TypeError("filename must be an instance of str. Got: %s" % filename)
|
|
||||||
if filename is None and isinstance(value, io.IOBase):
|
|
||||||
filename = guess_filename(value, name)
|
|
||||||
if filename is not None:
|
|
||||||
type_options["filename"] = filename
|
|
||||||
self._is_multipart = True
|
|
||||||
|
|
||||||
headers = {}
|
|
||||||
if content_type is not None:
|
|
||||||
if not isinstance(content_type, str):
|
|
||||||
raise TypeError(
|
|
||||||
"content_type must be an instance of str. Got: %s" % content_type
|
|
||||||
)
|
|
||||||
headers[hdrs.CONTENT_TYPE] = content_type
|
|
||||||
self._is_multipart = True
|
|
||||||
if content_transfer_encoding is not None:
|
|
||||||
if not isinstance(content_transfer_encoding, str):
|
|
||||||
raise TypeError(
|
|
||||||
"content_transfer_encoding must be an instance"
|
|
||||||
" of str. Got: %s" % content_transfer_encoding
|
|
||||||
)
|
|
||||||
msg = (
|
|
||||||
"content_transfer_encoding is deprecated. "
|
|
||||||
"To maintain compatibility with v4 please pass a BytesPayload."
|
|
||||||
)
|
|
||||||
warnings.warn(msg, DeprecationWarning)
|
|
||||||
self._is_multipart = True
|
|
||||||
|
|
||||||
self._fields.append((type_options, headers, value))
|
|
||||||
|
|
||||||
def add_fields(self, *fields: Any) -> None:
|
|
||||||
to_add = list(fields)
|
|
||||||
|
|
||||||
while to_add:
|
|
||||||
rec = to_add.pop(0)
|
|
||||||
|
|
||||||
if isinstance(rec, io.IOBase):
|
|
||||||
k = guess_filename(rec, "unknown")
|
|
||||||
self.add_field(k, rec) # type: ignore[arg-type]
|
|
||||||
|
|
||||||
elif isinstance(rec, (MultiDictProxy, MultiDict)):
|
|
||||||
to_add.extend(rec.items())
|
|
||||||
|
|
||||||
elif isinstance(rec, (list, tuple)) and len(rec) == 2:
|
|
||||||
k, fp = rec
|
|
||||||
self.add_field(k, fp)
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise TypeError(
|
|
||||||
"Only io.IOBase, multidict and (name, file) "
|
|
||||||
"pairs allowed, use .add_field() for passing "
|
|
||||||
"more complex parameters, got {!r}".format(rec)
|
|
||||||
)
|
|
||||||
|
|
||||||
def _gen_form_urlencoded(self) -> payload.BytesPayload:
|
|
||||||
# form data (x-www-form-urlencoded)
|
|
||||||
data = []
|
|
||||||
for type_options, _, value in self._fields:
|
|
||||||
data.append((type_options["name"], value))
|
|
||||||
|
|
||||||
charset = self._charset if self._charset is not None else "utf-8"
|
|
||||||
|
|
||||||
if charset == "utf-8":
|
|
||||||
content_type = "application/x-www-form-urlencoded"
|
|
||||||
else:
|
|
||||||
content_type = "application/x-www-form-urlencoded; charset=%s" % charset
|
|
||||||
|
|
||||||
return payload.BytesPayload(
|
|
||||||
urlencode(data, doseq=True, encoding=charset).encode(),
|
|
||||||
content_type=content_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _gen_form_data(self) -> multipart.MultipartWriter:
|
|
||||||
"""Encode a list of fields using the multipart/form-data MIME format"""
|
|
||||||
for dispparams, headers, value in self._fields:
|
|
||||||
try:
|
|
||||||
if hdrs.CONTENT_TYPE in headers:
|
|
||||||
part = payload.get_payload(
|
|
||||||
value,
|
|
||||||
content_type=headers[hdrs.CONTENT_TYPE],
|
|
||||||
headers=headers,
|
|
||||||
encoding=self._charset,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
part = payload.get_payload(
|
|
||||||
value, headers=headers, encoding=self._charset
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
raise TypeError(
|
|
||||||
"Can not serialize value type: %r\n "
|
|
||||||
"headers: %r\n value: %r" % (type(value), headers, value)
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
if dispparams:
|
|
||||||
part.set_content_disposition(
|
|
||||||
"form-data", quote_fields=self._quote_fields, **dispparams
|
|
||||||
)
|
|
||||||
# FIXME cgi.FieldStorage doesn't likes body parts with
|
|
||||||
# Content-Length which were sent via chunked transfer encoding
|
|
||||||
assert part.headers is not None
|
|
||||||
part.headers.popall(hdrs.CONTENT_LENGTH, None)
|
|
||||||
|
|
||||||
self._writer.append_payload(part)
|
|
||||||
|
|
||||||
self._fields.clear()
|
|
||||||
return self._writer
|
|
||||||
|
|
||||||
def __call__(self) -> Payload:
|
|
||||||
if self._is_multipart:
|
|
||||||
return self._gen_form_data()
|
|
||||||
else:
|
|
||||||
return self._gen_form_urlencoded()
|
|
||||||
|
|
@ -1,121 +0,0 @@
|
||||||
"""HTTP Headers constants."""
|
|
||||||
|
|
||||||
# After changing the file content call ./tools/gen.py
|
|
||||||
# to regenerate the headers parser
|
|
||||||
import itertools
|
|
||||||
from typing import Final, Set
|
|
||||||
|
|
||||||
from multidict import istr
|
|
||||||
|
|
||||||
METH_ANY: Final[str] = "*"
|
|
||||||
METH_CONNECT: Final[str] = "CONNECT"
|
|
||||||
METH_HEAD: Final[str] = "HEAD"
|
|
||||||
METH_GET: Final[str] = "GET"
|
|
||||||
METH_DELETE: Final[str] = "DELETE"
|
|
||||||
METH_OPTIONS: Final[str] = "OPTIONS"
|
|
||||||
METH_PATCH: Final[str] = "PATCH"
|
|
||||||
METH_POST: Final[str] = "POST"
|
|
||||||
METH_PUT: Final[str] = "PUT"
|
|
||||||
METH_TRACE: Final[str] = "TRACE"
|
|
||||||
|
|
||||||
METH_ALL: Final[Set[str]] = {
|
|
||||||
METH_CONNECT,
|
|
||||||
METH_HEAD,
|
|
||||||
METH_GET,
|
|
||||||
METH_DELETE,
|
|
||||||
METH_OPTIONS,
|
|
||||||
METH_PATCH,
|
|
||||||
METH_POST,
|
|
||||||
METH_PUT,
|
|
||||||
METH_TRACE,
|
|
||||||
}
|
|
||||||
|
|
||||||
ACCEPT: Final[istr] = istr("Accept")
|
|
||||||
ACCEPT_CHARSET: Final[istr] = istr("Accept-Charset")
|
|
||||||
ACCEPT_ENCODING: Final[istr] = istr("Accept-Encoding")
|
|
||||||
ACCEPT_LANGUAGE: Final[istr] = istr("Accept-Language")
|
|
||||||
ACCEPT_RANGES: Final[istr] = istr("Accept-Ranges")
|
|
||||||
ACCESS_CONTROL_MAX_AGE: Final[istr] = istr("Access-Control-Max-Age")
|
|
||||||
ACCESS_CONTROL_ALLOW_CREDENTIALS: Final[istr] = istr("Access-Control-Allow-Credentials")
|
|
||||||
ACCESS_CONTROL_ALLOW_HEADERS: Final[istr] = istr("Access-Control-Allow-Headers")
|
|
||||||
ACCESS_CONTROL_ALLOW_METHODS: Final[istr] = istr("Access-Control-Allow-Methods")
|
|
||||||
ACCESS_CONTROL_ALLOW_ORIGIN: Final[istr] = istr("Access-Control-Allow-Origin")
|
|
||||||
ACCESS_CONTROL_EXPOSE_HEADERS: Final[istr] = istr("Access-Control-Expose-Headers")
|
|
||||||
ACCESS_CONTROL_REQUEST_HEADERS: Final[istr] = istr("Access-Control-Request-Headers")
|
|
||||||
ACCESS_CONTROL_REQUEST_METHOD: Final[istr] = istr("Access-Control-Request-Method")
|
|
||||||
AGE: Final[istr] = istr("Age")
|
|
||||||
ALLOW: Final[istr] = istr("Allow")
|
|
||||||
AUTHORIZATION: Final[istr] = istr("Authorization")
|
|
||||||
CACHE_CONTROL: Final[istr] = istr("Cache-Control")
|
|
||||||
CONNECTION: Final[istr] = istr("Connection")
|
|
||||||
CONTENT_DISPOSITION: Final[istr] = istr("Content-Disposition")
|
|
||||||
CONTENT_ENCODING: Final[istr] = istr("Content-Encoding")
|
|
||||||
CONTENT_LANGUAGE: Final[istr] = istr("Content-Language")
|
|
||||||
CONTENT_LENGTH: Final[istr] = istr("Content-Length")
|
|
||||||
CONTENT_LOCATION: Final[istr] = istr("Content-Location")
|
|
||||||
CONTENT_MD5: Final[istr] = istr("Content-MD5")
|
|
||||||
CONTENT_RANGE: Final[istr] = istr("Content-Range")
|
|
||||||
CONTENT_TRANSFER_ENCODING: Final[istr] = istr("Content-Transfer-Encoding")
|
|
||||||
CONTENT_TYPE: Final[istr] = istr("Content-Type")
|
|
||||||
COOKIE: Final[istr] = istr("Cookie")
|
|
||||||
DATE: Final[istr] = istr("Date")
|
|
||||||
DESTINATION: Final[istr] = istr("Destination")
|
|
||||||
DIGEST: Final[istr] = istr("Digest")
|
|
||||||
ETAG: Final[istr] = istr("Etag")
|
|
||||||
EXPECT: Final[istr] = istr("Expect")
|
|
||||||
EXPIRES: Final[istr] = istr("Expires")
|
|
||||||
FORWARDED: Final[istr] = istr("Forwarded")
|
|
||||||
FROM: Final[istr] = istr("From")
|
|
||||||
HOST: Final[istr] = istr("Host")
|
|
||||||
IF_MATCH: Final[istr] = istr("If-Match")
|
|
||||||
IF_MODIFIED_SINCE: Final[istr] = istr("If-Modified-Since")
|
|
||||||
IF_NONE_MATCH: Final[istr] = istr("If-None-Match")
|
|
||||||
IF_RANGE: Final[istr] = istr("If-Range")
|
|
||||||
IF_UNMODIFIED_SINCE: Final[istr] = istr("If-Unmodified-Since")
|
|
||||||
KEEP_ALIVE: Final[istr] = istr("Keep-Alive")
|
|
||||||
LAST_EVENT_ID: Final[istr] = istr("Last-Event-ID")
|
|
||||||
LAST_MODIFIED: Final[istr] = istr("Last-Modified")
|
|
||||||
LINK: Final[istr] = istr("Link")
|
|
||||||
LOCATION: Final[istr] = istr("Location")
|
|
||||||
MAX_FORWARDS: Final[istr] = istr("Max-Forwards")
|
|
||||||
ORIGIN: Final[istr] = istr("Origin")
|
|
||||||
PRAGMA: Final[istr] = istr("Pragma")
|
|
||||||
PROXY_AUTHENTICATE: Final[istr] = istr("Proxy-Authenticate")
|
|
||||||
PROXY_AUTHORIZATION: Final[istr] = istr("Proxy-Authorization")
|
|
||||||
RANGE: Final[istr] = istr("Range")
|
|
||||||
REFERER: Final[istr] = istr("Referer")
|
|
||||||
RETRY_AFTER: Final[istr] = istr("Retry-After")
|
|
||||||
SEC_WEBSOCKET_ACCEPT: Final[istr] = istr("Sec-WebSocket-Accept")
|
|
||||||
SEC_WEBSOCKET_VERSION: Final[istr] = istr("Sec-WebSocket-Version")
|
|
||||||
SEC_WEBSOCKET_PROTOCOL: Final[istr] = istr("Sec-WebSocket-Protocol")
|
|
||||||
SEC_WEBSOCKET_EXTENSIONS: Final[istr] = istr("Sec-WebSocket-Extensions")
|
|
||||||
SEC_WEBSOCKET_KEY: Final[istr] = istr("Sec-WebSocket-Key")
|
|
||||||
SEC_WEBSOCKET_KEY1: Final[istr] = istr("Sec-WebSocket-Key1")
|
|
||||||
SERVER: Final[istr] = istr("Server")
|
|
||||||
SET_COOKIE: Final[istr] = istr("Set-Cookie")
|
|
||||||
TE: Final[istr] = istr("TE")
|
|
||||||
TRAILER: Final[istr] = istr("Trailer")
|
|
||||||
TRANSFER_ENCODING: Final[istr] = istr("Transfer-Encoding")
|
|
||||||
UPGRADE: Final[istr] = istr("Upgrade")
|
|
||||||
URI: Final[istr] = istr("URI")
|
|
||||||
USER_AGENT: Final[istr] = istr("User-Agent")
|
|
||||||
VARY: Final[istr] = istr("Vary")
|
|
||||||
VIA: Final[istr] = istr("Via")
|
|
||||||
WANT_DIGEST: Final[istr] = istr("Want-Digest")
|
|
||||||
WARNING: Final[istr] = istr("Warning")
|
|
||||||
WWW_AUTHENTICATE: Final[istr] = istr("WWW-Authenticate")
|
|
||||||
X_FORWARDED_FOR: Final[istr] = istr("X-Forwarded-For")
|
|
||||||
X_FORWARDED_HOST: Final[istr] = istr("X-Forwarded-Host")
|
|
||||||
X_FORWARDED_PROTO: Final[istr] = istr("X-Forwarded-Proto")
|
|
||||||
|
|
||||||
# These are the upper/lower case variants of the headers/methods
|
|
||||||
# Example: {'hOst', 'host', 'HoST', 'HOSt', 'hOsT', 'HosT', 'hoSt', ...}
|
|
||||||
METH_HEAD_ALL: Final = frozenset(
|
|
||||||
map("".join, itertools.product(*zip(METH_HEAD.upper(), METH_HEAD.lower())))
|
|
||||||
)
|
|
||||||
METH_CONNECT_ALL: Final = frozenset(
|
|
||||||
map("".join, itertools.product(*zip(METH_CONNECT.upper(), METH_CONNECT.lower())))
|
|
||||||
)
|
|
||||||
HOST_ALL: Final = frozenset(
|
|
||||||
map("".join, itertools.product(*zip(HOST.upper(), HOST.lower())))
|
|
||||||
)
|
|
||||||
|
|
@ -1,986 +0,0 @@
|
||||||
"""Various helper functions"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import base64
|
|
||||||
import binascii
|
|
||||||
import contextlib
|
|
||||||
import datetime
|
|
||||||
import enum
|
|
||||||
import functools
|
|
||||||
import inspect
|
|
||||||
import netrc
|
|
||||||
import os
|
|
||||||
import platform
|
|
||||||
import re
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
import weakref
|
|
||||||
from collections import namedtuple
|
|
||||||
from contextlib import suppress
|
|
||||||
from email.message import EmailMessage
|
|
||||||
from email.parser import HeaderParser
|
|
||||||
from email.policy import HTTP
|
|
||||||
from email.utils import parsedate
|
|
||||||
from math import ceil
|
|
||||||
from pathlib import Path
|
|
||||||
from types import MappingProxyType, TracebackType
|
|
||||||
from typing import (
|
|
||||||
Any,
|
|
||||||
Callable,
|
|
||||||
ContextManager,
|
|
||||||
Dict,
|
|
||||||
Generator,
|
|
||||||
Generic,
|
|
||||||
Iterable,
|
|
||||||
Iterator,
|
|
||||||
List,
|
|
||||||
Mapping,
|
|
||||||
Optional,
|
|
||||||
Protocol,
|
|
||||||
Tuple,
|
|
||||||
Type,
|
|
||||||
TypeVar,
|
|
||||||
Union,
|
|
||||||
get_args,
|
|
||||||
overload,
|
|
||||||
)
|
|
||||||
from urllib.parse import quote
|
|
||||||
from urllib.request import getproxies, proxy_bypass
|
|
||||||
|
|
||||||
import attr
|
|
||||||
from multidict import MultiDict, MultiDictProxy, MultiMapping
|
|
||||||
from propcache.api import under_cached_property as reify
|
|
||||||
from yarl import URL
|
|
||||||
|
|
||||||
from . import hdrs
|
|
||||||
from .log import client_logger
|
|
||||||
|
|
||||||
if sys.version_info >= (3, 11):
|
|
||||||
import asyncio as async_timeout
|
|
||||||
else:
|
|
||||||
import async_timeout
|
|
||||||
|
|
||||||
__all__ = ("BasicAuth", "ChainMapProxy", "ETag", "reify")
|
|
||||||
|
|
||||||
IS_MACOS = platform.system() == "Darwin"
|
|
||||||
IS_WINDOWS = platform.system() == "Windows"
|
|
||||||
|
|
||||||
PY_310 = sys.version_info >= (3, 10)
|
|
||||||
PY_311 = sys.version_info >= (3, 11)
|
|
||||||
|
|
||||||
|
|
||||||
_T = TypeVar("_T")
|
|
||||||
_S = TypeVar("_S")
|
|
||||||
|
|
||||||
_SENTINEL = enum.Enum("_SENTINEL", "sentinel")
|
|
||||||
sentinel = _SENTINEL.sentinel
|
|
||||||
|
|
||||||
NO_EXTENSIONS = bool(os.environ.get("AIOHTTP_NO_EXTENSIONS"))
|
|
||||||
|
|
||||||
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.1
|
|
||||||
EMPTY_BODY_STATUS_CODES = frozenset((204, 304, *range(100, 200)))
|
|
||||||
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.1
|
|
||||||
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.2
|
|
||||||
EMPTY_BODY_METHODS = hdrs.METH_HEAD_ALL
|
|
||||||
|
|
||||||
DEBUG = sys.flags.dev_mode or (
|
|
||||||
not sys.flags.ignore_environment and bool(os.environ.get("PYTHONASYNCIODEBUG"))
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
CHAR = {chr(i) for i in range(0, 128)}
|
|
||||||
CTL = {chr(i) for i in range(0, 32)} | {
|
|
||||||
chr(127),
|
|
||||||
}
|
|
||||||
SEPARATORS = {
|
|
||||||
"(",
|
|
||||||
")",
|
|
||||||
"<",
|
|
||||||
">",
|
|
||||||
"@",
|
|
||||||
",",
|
|
||||||
";",
|
|
||||||
":",
|
|
||||||
"\\",
|
|
||||||
'"',
|
|
||||||
"/",
|
|
||||||
"[",
|
|
||||||
"]",
|
|
||||||
"?",
|
|
||||||
"=",
|
|
||||||
"{",
|
|
||||||
"}",
|
|
||||||
" ",
|
|
||||||
chr(9),
|
|
||||||
}
|
|
||||||
TOKEN = CHAR ^ CTL ^ SEPARATORS
|
|
||||||
|
|
||||||
|
|
||||||
class noop:
|
|
||||||
def __await__(self) -> Generator[None, None, None]:
|
|
||||||
yield
|
|
||||||
|
|
||||||
|
|
||||||
class BasicAuth(namedtuple("BasicAuth", ["login", "password", "encoding"])):
|
|
||||||
"""Http basic authentication helper."""
|
|
||||||
|
|
||||||
def __new__(
|
|
||||||
cls, login: str, password: str = "", encoding: str = "latin1"
|
|
||||||
) -> "BasicAuth":
|
|
||||||
if login is None:
|
|
||||||
raise ValueError("None is not allowed as login value")
|
|
||||||
|
|
||||||
if password is None:
|
|
||||||
raise ValueError("None is not allowed as password value")
|
|
||||||
|
|
||||||
if ":" in login:
|
|
||||||
raise ValueError('A ":" is not allowed in login (RFC 1945#section-11.1)')
|
|
||||||
|
|
||||||
return super().__new__(cls, login, password, encoding)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def decode(cls, auth_header: str, encoding: str = "latin1") -> "BasicAuth":
|
|
||||||
"""Create a BasicAuth object from an Authorization HTTP header."""
|
|
||||||
try:
|
|
||||||
auth_type, encoded_credentials = auth_header.split(" ", 1)
|
|
||||||
except ValueError:
|
|
||||||
raise ValueError("Could not parse authorization header.")
|
|
||||||
|
|
||||||
if auth_type.lower() != "basic":
|
|
||||||
raise ValueError("Unknown authorization method %s" % auth_type)
|
|
||||||
|
|
||||||
try:
|
|
||||||
decoded = base64.b64decode(
|
|
||||||
encoded_credentials.encode("ascii"), validate=True
|
|
||||||
).decode(encoding)
|
|
||||||
except binascii.Error:
|
|
||||||
raise ValueError("Invalid base64 encoding.")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# RFC 2617 HTTP Authentication
|
|
||||||
# https://www.ietf.org/rfc/rfc2617.txt
|
|
||||||
# the colon must be present, but the username and password may be
|
|
||||||
# otherwise blank.
|
|
||||||
username, password = decoded.split(":", 1)
|
|
||||||
except ValueError:
|
|
||||||
raise ValueError("Invalid credentials.")
|
|
||||||
|
|
||||||
return cls(username, password, encoding=encoding)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_url(cls, url: URL, *, encoding: str = "latin1") -> Optional["BasicAuth"]:
|
|
||||||
"""Create BasicAuth from url."""
|
|
||||||
if not isinstance(url, URL):
|
|
||||||
raise TypeError("url should be yarl.URL instance")
|
|
||||||
# Check raw_user and raw_password first as yarl is likely
|
|
||||||
# to already have these values parsed from the netloc in the cache.
|
|
||||||
if url.raw_user is None and url.raw_password is None:
|
|
||||||
return None
|
|
||||||
return cls(url.user or "", url.password or "", encoding=encoding)
|
|
||||||
|
|
||||||
def encode(self) -> str:
|
|
||||||
"""Encode credentials."""
|
|
||||||
creds = (f"{self.login}:{self.password}").encode(self.encoding)
|
|
||||||
return "Basic %s" % base64.b64encode(creds).decode(self.encoding)
|
|
||||||
|
|
||||||
|
|
||||||
def strip_auth_from_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]:
|
|
||||||
"""Remove user and password from URL if present and return BasicAuth object."""
|
|
||||||
# Check raw_user and raw_password first as yarl is likely
|
|
||||||
# to already have these values parsed from the netloc in the cache.
|
|
||||||
if url.raw_user is None and url.raw_password is None:
|
|
||||||
return url, None
|
|
||||||
return url.with_user(None), BasicAuth(url.user or "", url.password or "")
|
|
||||||
|
|
||||||
|
|
||||||
def netrc_from_env() -> Optional[netrc.netrc]:
|
|
||||||
"""Load netrc from file.
|
|
||||||
|
|
||||||
Attempt to load it from the path specified by the env-var
|
|
||||||
NETRC or in the default location in the user's home directory.
|
|
||||||
|
|
||||||
Returns None if it couldn't be found or fails to parse.
|
|
||||||
"""
|
|
||||||
netrc_env = os.environ.get("NETRC")
|
|
||||||
|
|
||||||
if netrc_env is not None:
|
|
||||||
netrc_path = Path(netrc_env)
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
home_dir = Path.home()
|
|
||||||
except RuntimeError as e: # pragma: no cover
|
|
||||||
# if pathlib can't resolve home, it may raise a RuntimeError
|
|
||||||
client_logger.debug(
|
|
||||||
"Could not resolve home directory when "
|
|
||||||
"trying to look for .netrc file: %s",
|
|
||||||
e,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
netrc_path = home_dir / ("_netrc" if IS_WINDOWS else ".netrc")
|
|
||||||
|
|
||||||
try:
|
|
||||||
return netrc.netrc(str(netrc_path))
|
|
||||||
except netrc.NetrcParseError as e:
|
|
||||||
client_logger.warning("Could not parse .netrc file: %s", e)
|
|
||||||
except OSError as e:
|
|
||||||
netrc_exists = False
|
|
||||||
with contextlib.suppress(OSError):
|
|
||||||
netrc_exists = netrc_path.is_file()
|
|
||||||
# we couldn't read the file (doesn't exist, permissions, etc.)
|
|
||||||
if netrc_env or netrc_exists:
|
|
||||||
# only warn if the environment wanted us to load it,
|
|
||||||
# or it appears like the default file does actually exist
|
|
||||||
client_logger.warning("Could not read .netrc file: %s", e)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
@attr.s(auto_attribs=True, frozen=True, slots=True)
|
|
||||||
class ProxyInfo:
|
|
||||||
proxy: URL
|
|
||||||
proxy_auth: Optional[BasicAuth]
|
|
||||||
|
|
||||||
|
|
||||||
def basicauth_from_netrc(netrc_obj: Optional[netrc.netrc], host: str) -> BasicAuth:
|
|
||||||
"""
|
|
||||||
Return :py:class:`~aiohttp.BasicAuth` credentials for ``host`` from ``netrc_obj``.
|
|
||||||
|
|
||||||
:raises LookupError: if ``netrc_obj`` is :py:data:`None` or if no
|
|
||||||
entry is found for the ``host``.
|
|
||||||
"""
|
|
||||||
if netrc_obj is None:
|
|
||||||
raise LookupError("No .netrc file found")
|
|
||||||
auth_from_netrc = netrc_obj.authenticators(host)
|
|
||||||
|
|
||||||
if auth_from_netrc is None:
|
|
||||||
raise LookupError(f"No entry for {host!s} found in the `.netrc` file.")
|
|
||||||
login, account, password = auth_from_netrc
|
|
||||||
|
|
||||||
# TODO(PY311): username = login or account
|
|
||||||
# Up to python 3.10, account could be None if not specified,
|
|
||||||
# and login will be empty string if not specified. From 3.11,
|
|
||||||
# login and account will be empty string if not specified.
|
|
||||||
username = login if (login or account is None) else account
|
|
||||||
|
|
||||||
# TODO(PY311): Remove this, as password will be empty string
|
|
||||||
# if not specified
|
|
||||||
if password is None:
|
|
||||||
password = ""
|
|
||||||
|
|
||||||
return BasicAuth(username, password)
|
|
||||||
|
|
||||||
|
|
||||||
def proxies_from_env() -> Dict[str, ProxyInfo]:
|
|
||||||
proxy_urls = {
|
|
||||||
k: URL(v)
|
|
||||||
for k, v in getproxies().items()
|
|
||||||
if k in ("http", "https", "ws", "wss")
|
|
||||||
}
|
|
||||||
netrc_obj = netrc_from_env()
|
|
||||||
stripped = {k: strip_auth_from_url(v) for k, v in proxy_urls.items()}
|
|
||||||
ret = {}
|
|
||||||
for proto, val in stripped.items():
|
|
||||||
proxy, auth = val
|
|
||||||
if proxy.scheme in ("https", "wss"):
|
|
||||||
client_logger.warning(
|
|
||||||
"%s proxies %s are not supported, ignoring", proxy.scheme.upper(), proxy
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
if netrc_obj and auth is None:
|
|
||||||
if proxy.host is not None:
|
|
||||||
try:
|
|
||||||
auth = basicauth_from_netrc(netrc_obj, proxy.host)
|
|
||||||
except LookupError:
|
|
||||||
auth = None
|
|
||||||
ret[proto] = ProxyInfo(proxy, auth)
|
|
||||||
return ret
|
|
||||||
|
|
||||||
|
|
||||||
def get_env_proxy_for_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]:
|
|
||||||
"""Get a permitted proxy for the given URL from the env."""
|
|
||||||
if url.host is not None and proxy_bypass(url.host):
|
|
||||||
raise LookupError(f"Proxying is disallowed for `{url.host!r}`")
|
|
||||||
|
|
||||||
proxies_in_env = proxies_from_env()
|
|
||||||
try:
|
|
||||||
proxy_info = proxies_in_env[url.scheme]
|
|
||||||
except KeyError:
|
|
||||||
raise LookupError(f"No proxies found for `{url!s}` in the env")
|
|
||||||
else:
|
|
||||||
return proxy_info.proxy, proxy_info.proxy_auth
|
|
||||||
|
|
||||||
|
|
||||||
@attr.s(auto_attribs=True, frozen=True, slots=True)
|
|
||||||
class MimeType:
|
|
||||||
type: str
|
|
||||||
subtype: str
|
|
||||||
suffix: str
|
|
||||||
parameters: "MultiDictProxy[str]"
|
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache(maxsize=56)
|
|
||||||
def parse_mimetype(mimetype: str) -> MimeType:
|
|
||||||
"""Parses a MIME type into its components.
|
|
||||||
|
|
||||||
mimetype is a MIME type string.
|
|
||||||
|
|
||||||
Returns a MimeType object.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
>>> parse_mimetype('text/html; charset=utf-8')
|
|
||||||
MimeType(type='text', subtype='html', suffix='',
|
|
||||||
parameters={'charset': 'utf-8'})
|
|
||||||
|
|
||||||
"""
|
|
||||||
if not mimetype:
|
|
||||||
return MimeType(
|
|
||||||
type="", subtype="", suffix="", parameters=MultiDictProxy(MultiDict())
|
|
||||||
)
|
|
||||||
|
|
||||||
parts = mimetype.split(";")
|
|
||||||
params: MultiDict[str] = MultiDict()
|
|
||||||
for item in parts[1:]:
|
|
||||||
if not item:
|
|
||||||
continue
|
|
||||||
key, _, value = item.partition("=")
|
|
||||||
params.add(key.lower().strip(), value.strip(' "'))
|
|
||||||
|
|
||||||
fulltype = parts[0].strip().lower()
|
|
||||||
if fulltype == "*":
|
|
||||||
fulltype = "*/*"
|
|
||||||
|
|
||||||
mtype, _, stype = fulltype.partition("/")
|
|
||||||
stype, _, suffix = stype.partition("+")
|
|
||||||
|
|
||||||
return MimeType(
|
|
||||||
type=mtype, subtype=stype, suffix=suffix, parameters=MultiDictProxy(params)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class EnsureOctetStream(EmailMessage):
|
|
||||||
def __init__(self) -> None:
|
|
||||||
super().__init__()
|
|
||||||
# https://www.rfc-editor.org/rfc/rfc9110#section-8.3-5
|
|
||||||
self.set_default_type("application/octet-stream")
|
|
||||||
|
|
||||||
def get_content_type(self) -> str:
|
|
||||||
"""Re-implementation from Message
|
|
||||||
|
|
||||||
Returns application/octet-stream in place of plain/text when
|
|
||||||
value is wrong.
|
|
||||||
|
|
||||||
The way this class is used guarantees that content-type will
|
|
||||||
be present so simplify the checks wrt to the base implementation.
|
|
||||||
"""
|
|
||||||
value = self.get("content-type", "").lower()
|
|
||||||
|
|
||||||
# Based on the implementation of _splitparam in the standard library
|
|
||||||
ctype, _, _ = value.partition(";")
|
|
||||||
ctype = ctype.strip()
|
|
||||||
if ctype.count("/") != 1:
|
|
||||||
return self.get_default_type()
|
|
||||||
return ctype
|
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache(maxsize=56)
|
|
||||||
def parse_content_type(raw: str) -> Tuple[str, MappingProxyType[str, str]]:
|
|
||||||
"""Parse Content-Type header.
|
|
||||||
|
|
||||||
Returns a tuple of the parsed content type and a
|
|
||||||
MappingProxyType of parameters. The default returned value
|
|
||||||
is `application/octet-stream`
|
|
||||||
"""
|
|
||||||
msg = HeaderParser(EnsureOctetStream, policy=HTTP).parsestr(f"Content-Type: {raw}")
|
|
||||||
content_type = msg.get_content_type()
|
|
||||||
params = msg.get_params(())
|
|
||||||
content_dict = dict(params[1:]) # First element is content type again
|
|
||||||
return content_type, MappingProxyType(content_dict)
|
|
||||||
|
|
||||||
|
|
||||||
def guess_filename(obj: Any, default: Optional[str] = None) -> Optional[str]:
|
|
||||||
name = getattr(obj, "name", None)
|
|
||||||
if name and isinstance(name, str) and name[0] != "<" and name[-1] != ">":
|
|
||||||
return Path(name).name
|
|
||||||
return default
|
|
||||||
|
|
||||||
|
|
||||||
not_qtext_re = re.compile(r"[^\041\043-\133\135-\176]")
|
|
||||||
QCONTENT = {chr(i) for i in range(0x20, 0x7F)} | {"\t"}
|
|
||||||
|
|
||||||
|
|
||||||
def quoted_string(content: str) -> str:
|
|
||||||
"""Return 7-bit content as quoted-string.
|
|
||||||
|
|
||||||
Format content into a quoted-string as defined in RFC5322 for
|
|
||||||
Internet Message Format. Notice that this is not the 8-bit HTTP
|
|
||||||
format, but the 7-bit email format. Content must be in usascii or
|
|
||||||
a ValueError is raised.
|
|
||||||
"""
|
|
||||||
if not (QCONTENT > set(content)):
|
|
||||||
raise ValueError(f"bad content for quoted-string {content!r}")
|
|
||||||
return not_qtext_re.sub(lambda x: "\\" + x.group(0), content)
|
|
||||||
|
|
||||||
|
|
||||||
def content_disposition_header(
|
|
||||||
disptype: str, quote_fields: bool = True, _charset: str = "utf-8", **params: str
|
|
||||||
) -> str:
|
|
||||||
"""Sets ``Content-Disposition`` header for MIME.
|
|
||||||
|
|
||||||
This is the MIME payload Content-Disposition header from RFC 2183
|
|
||||||
and RFC 7579 section 4.2, not the HTTP Content-Disposition from
|
|
||||||
RFC 6266.
|
|
||||||
|
|
||||||
disptype is a disposition type: inline, attachment, form-data.
|
|
||||||
Should be valid extension token (see RFC 2183)
|
|
||||||
|
|
||||||
quote_fields performs value quoting to 7-bit MIME headers
|
|
||||||
according to RFC 7578. Set to quote_fields to False if recipient
|
|
||||||
can take 8-bit file names and field values.
|
|
||||||
|
|
||||||
_charset specifies the charset to use when quote_fields is True.
|
|
||||||
|
|
||||||
params is a dict with disposition params.
|
|
||||||
"""
|
|
||||||
if not disptype or not (TOKEN > set(disptype)):
|
|
||||||
raise ValueError(f"bad content disposition type {disptype!r}")
|
|
||||||
|
|
||||||
value = disptype
|
|
||||||
if params:
|
|
||||||
lparams = []
|
|
||||||
for key, val in params.items():
|
|
||||||
if not key or not (TOKEN > set(key)):
|
|
||||||
raise ValueError(f"bad content disposition parameter {key!r}={val!r}")
|
|
||||||
if quote_fields:
|
|
||||||
if key.lower() == "filename":
|
|
||||||
qval = quote(val, "", encoding=_charset)
|
|
||||||
lparams.append((key, '"%s"' % qval))
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
qval = quoted_string(val)
|
|
||||||
except ValueError:
|
|
||||||
qval = "".join(
|
|
||||||
(_charset, "''", quote(val, "", encoding=_charset))
|
|
||||||
)
|
|
||||||
lparams.append((key + "*", qval))
|
|
||||||
else:
|
|
||||||
lparams.append((key, '"%s"' % qval))
|
|
||||||
else:
|
|
||||||
qval = val.replace("\\", "\\\\").replace('"', '\\"')
|
|
||||||
lparams.append((key, '"%s"' % qval))
|
|
||||||
sparams = "; ".join("=".join(pair) for pair in lparams)
|
|
||||||
value = "; ".join((value, sparams))
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
def is_ip_address(host: Optional[str]) -> bool:
|
|
||||||
"""Check if host looks like an IP Address.
|
|
||||||
|
|
||||||
This check is only meant as a heuristic to ensure that
|
|
||||||
a host is not a domain name.
|
|
||||||
"""
|
|
||||||
if not host:
|
|
||||||
return False
|
|
||||||
# For a host to be an ipv4 address, it must be all numeric.
|
|
||||||
# The host must contain a colon to be an IPv6 address.
|
|
||||||
return ":" in host or host.replace(".", "").isdigit()
|
|
||||||
|
|
||||||
|
|
||||||
_cached_current_datetime: Optional[int] = None
|
|
||||||
_cached_formatted_datetime = ""
|
|
||||||
|
|
||||||
|
|
||||||
def rfc822_formatted_time() -> str:
|
|
||||||
global _cached_current_datetime
|
|
||||||
global _cached_formatted_datetime
|
|
||||||
|
|
||||||
now = int(time.time())
|
|
||||||
if now != _cached_current_datetime:
|
|
||||||
# Weekday and month names for HTTP date/time formatting;
|
|
||||||
# always English!
|
|
||||||
# Tuples are constants stored in codeobject!
|
|
||||||
_weekdayname = ("Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun")
|
|
||||||
_monthname = (
|
|
||||||
"", # Dummy so we can use 1-based month numbers
|
|
||||||
"Jan",
|
|
||||||
"Feb",
|
|
||||||
"Mar",
|
|
||||||
"Apr",
|
|
||||||
"May",
|
|
||||||
"Jun",
|
|
||||||
"Jul",
|
|
||||||
"Aug",
|
|
||||||
"Sep",
|
|
||||||
"Oct",
|
|
||||||
"Nov",
|
|
||||||
"Dec",
|
|
||||||
)
|
|
||||||
|
|
||||||
year, month, day, hh, mm, ss, wd, *tail = time.gmtime(now)
|
|
||||||
_cached_formatted_datetime = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % (
|
|
||||||
_weekdayname[wd],
|
|
||||||
day,
|
|
||||||
_monthname[month],
|
|
||||||
year,
|
|
||||||
hh,
|
|
||||||
mm,
|
|
||||||
ss,
|
|
||||||
)
|
|
||||||
_cached_current_datetime = now
|
|
||||||
return _cached_formatted_datetime
|
|
||||||
|
|
||||||
|
|
||||||
def _weakref_handle(info: "Tuple[weakref.ref[object], str]") -> None:
|
|
||||||
ref, name = info
|
|
||||||
ob = ref()
|
|
||||||
if ob is not None:
|
|
||||||
with suppress(Exception):
|
|
||||||
getattr(ob, name)()
|
|
||||||
|
|
||||||
|
|
||||||
def weakref_handle(
|
|
||||||
ob: object,
|
|
||||||
name: str,
|
|
||||||
timeout: float,
|
|
||||||
loop: asyncio.AbstractEventLoop,
|
|
||||||
timeout_ceil_threshold: float = 5,
|
|
||||||
) -> Optional[asyncio.TimerHandle]:
|
|
||||||
if timeout is not None and timeout > 0:
|
|
||||||
when = loop.time() + timeout
|
|
||||||
if timeout >= timeout_ceil_threshold:
|
|
||||||
when = ceil(when)
|
|
||||||
|
|
||||||
return loop.call_at(when, _weakref_handle, (weakref.ref(ob), name))
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def call_later(
|
|
||||||
cb: Callable[[], Any],
|
|
||||||
timeout: float,
|
|
||||||
loop: asyncio.AbstractEventLoop,
|
|
||||||
timeout_ceil_threshold: float = 5,
|
|
||||||
) -> Optional[asyncio.TimerHandle]:
|
|
||||||
if timeout is None or timeout <= 0:
|
|
||||||
return None
|
|
||||||
now = loop.time()
|
|
||||||
when = calculate_timeout_when(now, timeout, timeout_ceil_threshold)
|
|
||||||
return loop.call_at(when, cb)
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_timeout_when(
|
|
||||||
loop_time: float,
|
|
||||||
timeout: float,
|
|
||||||
timeout_ceiling_threshold: float,
|
|
||||||
) -> float:
|
|
||||||
"""Calculate when to execute a timeout."""
|
|
||||||
when = loop_time + timeout
|
|
||||||
if timeout > timeout_ceiling_threshold:
|
|
||||||
return ceil(when)
|
|
||||||
return when
|
|
||||||
|
|
||||||
|
|
||||||
class TimeoutHandle:
|
|
||||||
"""Timeout handle"""
|
|
||||||
|
|
||||||
__slots__ = ("_timeout", "_loop", "_ceil_threshold", "_callbacks")
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
loop: asyncio.AbstractEventLoop,
|
|
||||||
timeout: Optional[float],
|
|
||||||
ceil_threshold: float = 5,
|
|
||||||
) -> None:
|
|
||||||
self._timeout = timeout
|
|
||||||
self._loop = loop
|
|
||||||
self._ceil_threshold = ceil_threshold
|
|
||||||
self._callbacks: List[
|
|
||||||
Tuple[Callable[..., None], Tuple[Any, ...], Dict[str, Any]]
|
|
||||||
] = []
|
|
||||||
|
|
||||||
def register(
|
|
||||||
self, callback: Callable[..., None], *args: Any, **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
self._callbacks.append((callback, args, kwargs))
|
|
||||||
|
|
||||||
def close(self) -> None:
|
|
||||||
self._callbacks.clear()
|
|
||||||
|
|
||||||
def start(self) -> Optional[asyncio.TimerHandle]:
|
|
||||||
timeout = self._timeout
|
|
||||||
if timeout is not None and timeout > 0:
|
|
||||||
when = self._loop.time() + timeout
|
|
||||||
if timeout >= self._ceil_threshold:
|
|
||||||
when = ceil(when)
|
|
||||||
return self._loop.call_at(when, self.__call__)
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def timer(self) -> "BaseTimerContext":
|
|
||||||
if self._timeout is not None and self._timeout > 0:
|
|
||||||
timer = TimerContext(self._loop)
|
|
||||||
self.register(timer.timeout)
|
|
||||||
return timer
|
|
||||||
else:
|
|
||||||
return TimerNoop()
|
|
||||||
|
|
||||||
def __call__(self) -> None:
|
|
||||||
for cb, args, kwargs in self._callbacks:
|
|
||||||
with suppress(Exception):
|
|
||||||
cb(*args, **kwargs)
|
|
||||||
|
|
||||||
self._callbacks.clear()
|
|
||||||
|
|
||||||
|
|
||||||
class BaseTimerContext(ContextManager["BaseTimerContext"]):
|
|
||||||
|
|
||||||
__slots__ = ()
|
|
||||||
|
|
||||||
def assert_timeout(self) -> None:
|
|
||||||
"""Raise TimeoutError if timeout has been exceeded."""
|
|
||||||
|
|
||||||
|
|
||||||
class TimerNoop(BaseTimerContext):
|
|
||||||
|
|
||||||
__slots__ = ()
|
|
||||||
|
|
||||||
def __enter__(self) -> BaseTimerContext:
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(
|
|
||||||
self,
|
|
||||||
exc_type: Optional[Type[BaseException]],
|
|
||||||
exc_val: Optional[BaseException],
|
|
||||||
exc_tb: Optional[TracebackType],
|
|
||||||
) -> None:
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
class TimerContext(BaseTimerContext):
|
|
||||||
"""Low resolution timeout context manager"""
|
|
||||||
|
|
||||||
__slots__ = ("_loop", "_tasks", "_cancelled", "_cancelling")
|
|
||||||
|
|
||||||
def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
|
|
||||||
self._loop = loop
|
|
||||||
self._tasks: List[asyncio.Task[Any]] = []
|
|
||||||
self._cancelled = False
|
|
||||||
self._cancelling = 0
|
|
||||||
|
|
||||||
def assert_timeout(self) -> None:
|
|
||||||
"""Raise TimeoutError if timer has already been cancelled."""
|
|
||||||
if self._cancelled:
|
|
||||||
raise asyncio.TimeoutError from None
|
|
||||||
|
|
||||||
def __enter__(self) -> BaseTimerContext:
|
|
||||||
task = asyncio.current_task(loop=self._loop)
|
|
||||||
if task is None:
|
|
||||||
raise RuntimeError("Timeout context manager should be used inside a task")
|
|
||||||
|
|
||||||
if sys.version_info >= (3, 11):
|
|
||||||
# Remember if the task was already cancelling
|
|
||||||
# so when we __exit__ we can decide if we should
|
|
||||||
# raise asyncio.TimeoutError or let the cancellation propagate
|
|
||||||
self._cancelling = task.cancelling()
|
|
||||||
|
|
||||||
if self._cancelled:
|
|
||||||
raise asyncio.TimeoutError from None
|
|
||||||
|
|
||||||
self._tasks.append(task)
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(
|
|
||||||
self,
|
|
||||||
exc_type: Optional[Type[BaseException]],
|
|
||||||
exc_val: Optional[BaseException],
|
|
||||||
exc_tb: Optional[TracebackType],
|
|
||||||
) -> Optional[bool]:
|
|
||||||
enter_task: Optional[asyncio.Task[Any]] = None
|
|
||||||
if self._tasks:
|
|
||||||
enter_task = self._tasks.pop()
|
|
||||||
|
|
||||||
if exc_type is asyncio.CancelledError and self._cancelled:
|
|
||||||
assert enter_task is not None
|
|
||||||
# The timeout was hit, and the task was cancelled
|
|
||||||
# so we need to uncancel the last task that entered the context manager
|
|
||||||
# since the cancellation should not leak out of the context manager
|
|
||||||
if sys.version_info >= (3, 11):
|
|
||||||
# If the task was already cancelling don't raise
|
|
||||||
# asyncio.TimeoutError and instead return None
|
|
||||||
# to allow the cancellation to propagate
|
|
||||||
if enter_task.uncancel() > self._cancelling:
|
|
||||||
return None
|
|
||||||
raise asyncio.TimeoutError from exc_val
|
|
||||||
return None
|
|
||||||
|
|
||||||
def timeout(self) -> None:
|
|
||||||
if not self._cancelled:
|
|
||||||
for task in set(self._tasks):
|
|
||||||
task.cancel()
|
|
||||||
|
|
||||||
self._cancelled = True
|
|
||||||
|
|
||||||
|
|
||||||
def ceil_timeout(
|
|
||||||
delay: Optional[float], ceil_threshold: float = 5
|
|
||||||
) -> async_timeout.Timeout:
|
|
||||||
if delay is None or delay <= 0:
|
|
||||||
return async_timeout.timeout(None)
|
|
||||||
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
now = loop.time()
|
|
||||||
when = now + delay
|
|
||||||
if delay > ceil_threshold:
|
|
||||||
when = ceil(when)
|
|
||||||
return async_timeout.timeout_at(when)
|
|
||||||
|
|
||||||
|
|
||||||
class HeadersMixin:
|
|
||||||
"""Mixin for handling headers."""
|
|
||||||
|
|
||||||
ATTRS = frozenset(["_content_type", "_content_dict", "_stored_content_type"])
|
|
||||||
|
|
||||||
_headers: MultiMapping[str]
|
|
||||||
_content_type: Optional[str] = None
|
|
||||||
_content_dict: Optional[Dict[str, str]] = None
|
|
||||||
_stored_content_type: Union[str, None, _SENTINEL] = sentinel
|
|
||||||
|
|
||||||
def _parse_content_type(self, raw: Optional[str]) -> None:
|
|
||||||
self._stored_content_type = raw
|
|
||||||
if raw is None:
|
|
||||||
# default value according to RFC 2616
|
|
||||||
self._content_type = "application/octet-stream"
|
|
||||||
self._content_dict = {}
|
|
||||||
else:
|
|
||||||
content_type, content_mapping_proxy = parse_content_type(raw)
|
|
||||||
self._content_type = content_type
|
|
||||||
# _content_dict needs to be mutable so we can update it
|
|
||||||
self._content_dict = content_mapping_proxy.copy()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def content_type(self) -> str:
|
|
||||||
"""The value of content part for Content-Type HTTP header."""
|
|
||||||
raw = self._headers.get(hdrs.CONTENT_TYPE)
|
|
||||||
if self._stored_content_type != raw:
|
|
||||||
self._parse_content_type(raw)
|
|
||||||
assert self._content_type is not None
|
|
||||||
return self._content_type
|
|
||||||
|
|
||||||
@property
|
|
||||||
def charset(self) -> Optional[str]:
|
|
||||||
"""The value of charset part for Content-Type HTTP header."""
|
|
||||||
raw = self._headers.get(hdrs.CONTENT_TYPE)
|
|
||||||
if self._stored_content_type != raw:
|
|
||||||
self._parse_content_type(raw)
|
|
||||||
assert self._content_dict is not None
|
|
||||||
return self._content_dict.get("charset")
|
|
||||||
|
|
||||||
@property
|
|
||||||
def content_length(self) -> Optional[int]:
|
|
||||||
"""The value of Content-Length HTTP header."""
|
|
||||||
content_length = self._headers.get(hdrs.CONTENT_LENGTH)
|
|
||||||
return None if content_length is None else int(content_length)
|
|
||||||
|
|
||||||
|
|
||||||
def set_result(fut: "asyncio.Future[_T]", result: _T) -> None:
|
|
||||||
if not fut.done():
|
|
||||||
fut.set_result(result)
|
|
||||||
|
|
||||||
|
|
||||||
_EXC_SENTINEL = BaseException()
|
|
||||||
|
|
||||||
|
|
||||||
class ErrorableProtocol(Protocol):
|
|
||||||
def set_exception(
|
|
||||||
self,
|
|
||||||
exc: BaseException,
|
|
||||||
exc_cause: BaseException = ...,
|
|
||||||
) -> None: ... # pragma: no cover
|
|
||||||
|
|
||||||
|
|
||||||
def set_exception(
|
|
||||||
fut: "asyncio.Future[_T] | ErrorableProtocol",
|
|
||||||
exc: BaseException,
|
|
||||||
exc_cause: BaseException = _EXC_SENTINEL,
|
|
||||||
) -> None:
|
|
||||||
"""Set future exception.
|
|
||||||
|
|
||||||
If the future is marked as complete, this function is a no-op.
|
|
||||||
|
|
||||||
:param exc_cause: An exception that is a direct cause of ``exc``.
|
|
||||||
Only set if provided.
|
|
||||||
"""
|
|
||||||
if asyncio.isfuture(fut) and fut.done():
|
|
||||||
return
|
|
||||||
|
|
||||||
exc_is_sentinel = exc_cause is _EXC_SENTINEL
|
|
||||||
exc_causes_itself = exc is exc_cause
|
|
||||||
if not exc_is_sentinel and not exc_causes_itself:
|
|
||||||
exc.__cause__ = exc_cause
|
|
||||||
|
|
||||||
fut.set_exception(exc)
|
|
||||||
|
|
||||||
|
|
||||||
@functools.total_ordering
|
|
||||||
class AppKey(Generic[_T]):
|
|
||||||
"""Keys for static typing support in Application."""
|
|
||||||
|
|
||||||
__slots__ = ("_name", "_t", "__orig_class__")
|
|
||||||
|
|
||||||
# This may be set by Python when instantiating with a generic type. We need to
|
|
||||||
# support this, in order to support types that are not concrete classes,
|
|
||||||
# like Iterable, which can't be passed as the second parameter to __init__.
|
|
||||||
__orig_class__: Type[object]
|
|
||||||
|
|
||||||
def __init__(self, name: str, t: Optional[Type[_T]] = None):
|
|
||||||
# Prefix with module name to help deduplicate key names.
|
|
||||||
frame = inspect.currentframe()
|
|
||||||
while frame:
|
|
||||||
if frame.f_code.co_name == "<module>":
|
|
||||||
module: str = frame.f_globals["__name__"]
|
|
||||||
break
|
|
||||||
frame = frame.f_back
|
|
||||||
|
|
||||||
self._name = module + "." + name
|
|
||||||
self._t = t
|
|
||||||
|
|
||||||
def __lt__(self, other: object) -> bool:
|
|
||||||
if isinstance(other, AppKey):
|
|
||||||
return self._name < other._name
|
|
||||||
return True # Order AppKey above other types.
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
t = self._t
|
|
||||||
if t is None:
|
|
||||||
with suppress(AttributeError):
|
|
||||||
# Set to type arg.
|
|
||||||
t = get_args(self.__orig_class__)[0]
|
|
||||||
|
|
||||||
if t is None:
|
|
||||||
t_repr = "<<Unknown>>"
|
|
||||||
elif isinstance(t, type):
|
|
||||||
if t.__module__ == "builtins":
|
|
||||||
t_repr = t.__qualname__
|
|
||||||
else:
|
|
||||||
t_repr = f"{t.__module__}.{t.__qualname__}"
|
|
||||||
else:
|
|
||||||
t_repr = repr(t)
|
|
||||||
return f"<AppKey({self._name}, type={t_repr})>"
|
|
||||||
|
|
||||||
|
|
||||||
class ChainMapProxy(Mapping[Union[str, AppKey[Any]], Any]):
|
|
||||||
__slots__ = ("_maps",)
|
|
||||||
|
|
||||||
def __init__(self, maps: Iterable[Mapping[Union[str, AppKey[Any]], Any]]) -> None:
|
|
||||||
self._maps = tuple(maps)
|
|
||||||
|
|
||||||
def __init_subclass__(cls) -> None:
|
|
||||||
raise TypeError(
|
|
||||||
"Inheritance class {} from ChainMapProxy "
|
|
||||||
"is forbidden".format(cls.__name__)
|
|
||||||
)
|
|
||||||
|
|
||||||
@overload # type: ignore[override]
|
|
||||||
def __getitem__(self, key: AppKey[_T]) -> _T: ...
|
|
||||||
|
|
||||||
@overload
|
|
||||||
def __getitem__(self, key: str) -> Any: ...
|
|
||||||
|
|
||||||
def __getitem__(self, key: Union[str, AppKey[_T]]) -> Any:
|
|
||||||
for mapping in self._maps:
|
|
||||||
try:
|
|
||||||
return mapping[key]
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
raise KeyError(key)
|
|
||||||
|
|
||||||
@overload # type: ignore[override]
|
|
||||||
def get(self, key: AppKey[_T], default: _S) -> Union[_T, _S]: ...
|
|
||||||
|
|
||||||
@overload
|
|
||||||
def get(self, key: AppKey[_T], default: None = ...) -> Optional[_T]: ...
|
|
||||||
|
|
||||||
@overload
|
|
||||||
def get(self, key: str, default: Any = ...) -> Any: ...
|
|
||||||
|
|
||||||
def get(self, key: Union[str, AppKey[_T]], default: Any = None) -> Any:
|
|
||||||
try:
|
|
||||||
return self[key]
|
|
||||||
except KeyError:
|
|
||||||
return default
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
# reuses stored hash values if possible
|
|
||||||
return len(set().union(*self._maps))
|
|
||||||
|
|
||||||
def __iter__(self) -> Iterator[Union[str, AppKey[Any]]]:
|
|
||||||
d: Dict[Union[str, AppKey[Any]], Any] = {}
|
|
||||||
for mapping in reversed(self._maps):
|
|
||||||
# reuses stored hash values if possible
|
|
||||||
d.update(mapping)
|
|
||||||
return iter(d)
|
|
||||||
|
|
||||||
def __contains__(self, key: object) -> bool:
|
|
||||||
return any(key in m for m in self._maps)
|
|
||||||
|
|
||||||
def __bool__(self) -> bool:
|
|
||||||
return any(self._maps)
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
content = ", ".join(map(repr, self._maps))
|
|
||||||
return f"ChainMapProxy({content})"
|
|
||||||
|
|
||||||
|
|
||||||
# https://tools.ietf.org/html/rfc7232#section-2.3
|
|
||||||
_ETAGC = r"[!\x23-\x7E\x80-\xff]+"
|
|
||||||
_ETAGC_RE = re.compile(_ETAGC)
|
|
||||||
_QUOTED_ETAG = rf'(W/)?"({_ETAGC})"'
|
|
||||||
QUOTED_ETAG_RE = re.compile(_QUOTED_ETAG)
|
|
||||||
LIST_QUOTED_ETAG_RE = re.compile(rf"({_QUOTED_ETAG})(?:\s*,\s*|$)|(.)")
|
|
||||||
|
|
||||||
ETAG_ANY = "*"
|
|
||||||
|
|
||||||
|
|
||||||
@attr.s(auto_attribs=True, frozen=True, slots=True)
|
|
||||||
class ETag:
|
|
||||||
value: str
|
|
||||||
is_weak: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
def validate_etag_value(value: str) -> None:
|
|
||||||
if value != ETAG_ANY and not _ETAGC_RE.fullmatch(value):
|
|
||||||
raise ValueError(
|
|
||||||
f"Value {value!r} is not a valid etag. Maybe it contains '\"'?"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def parse_http_date(date_str: Optional[str]) -> Optional[datetime.datetime]:
|
|
||||||
"""Process a date string, return a datetime object"""
|
|
||||||
if date_str is not None:
|
|
||||||
timetuple = parsedate(date_str)
|
|
||||||
if timetuple is not None:
|
|
||||||
with suppress(ValueError):
|
|
||||||
return datetime.datetime(*timetuple[:6], tzinfo=datetime.timezone.utc)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache
|
|
||||||
def must_be_empty_body(method: str, code: int) -> bool:
|
|
||||||
"""Check if a request must return an empty body."""
|
|
||||||
return (
|
|
||||||
code in EMPTY_BODY_STATUS_CODES
|
|
||||||
or method in EMPTY_BODY_METHODS
|
|
||||||
or (200 <= code < 300 and method in hdrs.METH_CONNECT_ALL)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def should_remove_content_length(method: str, code: int) -> bool:
|
|
||||||
"""Check if a Content-Length header should be removed.
|
|
||||||
|
|
||||||
This should always be a subset of must_be_empty_body
|
|
||||||
"""
|
|
||||||
# https://www.rfc-editor.org/rfc/rfc9110.html#section-8.6-8
|
|
||||||
# https://www.rfc-editor.org/rfc/rfc9110.html#section-15.4.5-4
|
|
||||||
return code in EMPTY_BODY_STATUS_CODES or (
|
|
||||||
200 <= code < 300 and method in hdrs.METH_CONNECT_ALL
|
|
||||||
)
|
|
||||||
|
|
@ -1,72 +0,0 @@
|
||||||
import sys
|
|
||||||
from http import HTTPStatus
|
|
||||||
from typing import Mapping, Tuple
|
|
||||||
|
|
||||||
from . import __version__
|
|
||||||
from .http_exceptions import HttpProcessingError as HttpProcessingError
|
|
||||||
from .http_parser import (
|
|
||||||
HeadersParser as HeadersParser,
|
|
||||||
HttpParser as HttpParser,
|
|
||||||
HttpRequestParser as HttpRequestParser,
|
|
||||||
HttpResponseParser as HttpResponseParser,
|
|
||||||
RawRequestMessage as RawRequestMessage,
|
|
||||||
RawResponseMessage as RawResponseMessage,
|
|
||||||
)
|
|
||||||
from .http_websocket import (
|
|
||||||
WS_CLOSED_MESSAGE as WS_CLOSED_MESSAGE,
|
|
||||||
WS_CLOSING_MESSAGE as WS_CLOSING_MESSAGE,
|
|
||||||
WS_KEY as WS_KEY,
|
|
||||||
WebSocketError as WebSocketError,
|
|
||||||
WebSocketReader as WebSocketReader,
|
|
||||||
WebSocketWriter as WebSocketWriter,
|
|
||||||
WSCloseCode as WSCloseCode,
|
|
||||||
WSMessage as WSMessage,
|
|
||||||
WSMsgType as WSMsgType,
|
|
||||||
ws_ext_gen as ws_ext_gen,
|
|
||||||
ws_ext_parse as ws_ext_parse,
|
|
||||||
)
|
|
||||||
from .http_writer import (
|
|
||||||
HttpVersion as HttpVersion,
|
|
||||||
HttpVersion10 as HttpVersion10,
|
|
||||||
HttpVersion11 as HttpVersion11,
|
|
||||||
StreamWriter as StreamWriter,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = (
|
|
||||||
"HttpProcessingError",
|
|
||||||
"RESPONSES",
|
|
||||||
"SERVER_SOFTWARE",
|
|
||||||
# .http_writer
|
|
||||||
"StreamWriter",
|
|
||||||
"HttpVersion",
|
|
||||||
"HttpVersion10",
|
|
||||||
"HttpVersion11",
|
|
||||||
# .http_parser
|
|
||||||
"HeadersParser",
|
|
||||||
"HttpParser",
|
|
||||||
"HttpRequestParser",
|
|
||||||
"HttpResponseParser",
|
|
||||||
"RawRequestMessage",
|
|
||||||
"RawResponseMessage",
|
|
||||||
# .http_websocket
|
|
||||||
"WS_CLOSED_MESSAGE",
|
|
||||||
"WS_CLOSING_MESSAGE",
|
|
||||||
"WS_KEY",
|
|
||||||
"WebSocketReader",
|
|
||||||
"WebSocketWriter",
|
|
||||||
"ws_ext_gen",
|
|
||||||
"ws_ext_parse",
|
|
||||||
"WSMessage",
|
|
||||||
"WebSocketError",
|
|
||||||
"WSMsgType",
|
|
||||||
"WSCloseCode",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
SERVER_SOFTWARE: str = "Python/{0[0]}.{0[1]} aiohttp/{1}".format(
|
|
||||||
sys.version_info, __version__
|
|
||||||
)
|
|
||||||
|
|
||||||
RESPONSES: Mapping[int, Tuple[str, str]] = {
|
|
||||||
v: (v.phrase, v.description) for v in HTTPStatus.__members__.values()
|
|
||||||
}
|
|
||||||
|
|
@ -1,112 +0,0 @@
|
||||||
"""Low-level http related exceptions."""
|
|
||||||
|
|
||||||
from textwrap import indent
|
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
from .typedefs import _CIMultiDict
|
|
||||||
|
|
||||||
__all__ = ("HttpProcessingError",)
|
|
||||||
|
|
||||||
|
|
||||||
class HttpProcessingError(Exception):
|
|
||||||
"""HTTP error.
|
|
||||||
|
|
||||||
Shortcut for raising HTTP errors with custom code, message and headers.
|
|
||||||
|
|
||||||
code: HTTP Error code.
|
|
||||||
message: (optional) Error message.
|
|
||||||
headers: (optional) Headers to be sent in response, a list of pairs
|
|
||||||
"""
|
|
||||||
|
|
||||||
code = 0
|
|
||||||
message = ""
|
|
||||||
headers = None
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
code: Optional[int] = None,
|
|
||||||
message: str = "",
|
|
||||||
headers: Optional[_CIMultiDict] = None,
|
|
||||||
) -> None:
|
|
||||||
if code is not None:
|
|
||||||
self.code = code
|
|
||||||
self.headers = headers
|
|
||||||
self.message = message
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
msg = indent(self.message, " ")
|
|
||||||
return f"{self.code}, message:\n{msg}"
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return f"<{self.__class__.__name__}: {self.code}, message={self.message!r}>"
|
|
||||||
|
|
||||||
|
|
||||||
class BadHttpMessage(HttpProcessingError):
|
|
||||||
|
|
||||||
code = 400
|
|
||||||
message = "Bad Request"
|
|
||||||
|
|
||||||
def __init__(self, message: str, *, headers: Optional[_CIMultiDict] = None) -> None:
|
|
||||||
super().__init__(message=message, headers=headers)
|
|
||||||
self.args = (message,)
|
|
||||||
|
|
||||||
|
|
||||||
class HttpBadRequest(BadHttpMessage):
|
|
||||||
|
|
||||||
code = 400
|
|
||||||
message = "Bad Request"
|
|
||||||
|
|
||||||
|
|
||||||
class PayloadEncodingError(BadHttpMessage):
|
|
||||||
"""Base class for payload errors"""
|
|
||||||
|
|
||||||
|
|
||||||
class ContentEncodingError(PayloadEncodingError):
|
|
||||||
"""Content encoding error."""
|
|
||||||
|
|
||||||
|
|
||||||
class TransferEncodingError(PayloadEncodingError):
|
|
||||||
"""transfer encoding error."""
|
|
||||||
|
|
||||||
|
|
||||||
class ContentLengthError(PayloadEncodingError):
|
|
||||||
"""Not enough data to satisfy content length header."""
|
|
||||||
|
|
||||||
|
|
||||||
class LineTooLong(BadHttpMessage):
|
|
||||||
def __init__(
|
|
||||||
self, line: str, limit: str = "Unknown", actual_size: str = "Unknown"
|
|
||||||
) -> None:
|
|
||||||
super().__init__(
|
|
||||||
f"Got more than {limit} bytes ({actual_size}) when reading {line}."
|
|
||||||
)
|
|
||||||
self.args = (line, limit, actual_size)
|
|
||||||
|
|
||||||
|
|
||||||
class InvalidHeader(BadHttpMessage):
|
|
||||||
def __init__(self, hdr: Union[bytes, str]) -> None:
|
|
||||||
hdr_s = hdr.decode(errors="backslashreplace") if isinstance(hdr, bytes) else hdr
|
|
||||||
super().__init__(f"Invalid HTTP header: {hdr!r}")
|
|
||||||
self.hdr = hdr_s
|
|
||||||
self.args = (hdr,)
|
|
||||||
|
|
||||||
|
|
||||||
class BadStatusLine(BadHttpMessage):
|
|
||||||
def __init__(self, line: str = "", error: Optional[str] = None) -> None:
|
|
||||||
if not isinstance(line, str):
|
|
||||||
line = repr(line)
|
|
||||||
super().__init__(error or f"Bad status line {line!r}")
|
|
||||||
self.args = (line,)
|
|
||||||
self.line = line
|
|
||||||
|
|
||||||
|
|
||||||
class BadHttpMethod(BadStatusLine):
|
|
||||||
"""Invalid HTTP method in status line."""
|
|
||||||
|
|
||||||
def __init__(self, line: str = "", error: Optional[str] = None) -> None:
|
|
||||||
super().__init__(line, error or f"Bad HTTP method in status line {line!r}")
|
|
||||||
|
|
||||||
|
|
||||||
class InvalidURLError(BadHttpMessage):
|
|
||||||
pass
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,36 +0,0 @@
|
||||||
"""WebSocket protocol versions 13 and 8."""
|
|
||||||
|
|
||||||
from ._websocket.helpers import WS_KEY, ws_ext_gen, ws_ext_parse
|
|
||||||
from ._websocket.models import (
|
|
||||||
WS_CLOSED_MESSAGE,
|
|
||||||
WS_CLOSING_MESSAGE,
|
|
||||||
WebSocketError,
|
|
||||||
WSCloseCode,
|
|
||||||
WSHandshakeError,
|
|
||||||
WSMessage,
|
|
||||||
WSMsgType,
|
|
||||||
)
|
|
||||||
from ._websocket.reader import WebSocketReader
|
|
||||||
from ._websocket.writer import WebSocketWriter
|
|
||||||
|
|
||||||
# Messages that the WebSocketResponse.receive needs to handle internally
|
|
||||||
_INTERNAL_RECEIVE_TYPES = frozenset(
|
|
||||||
(WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.PING, WSMsgType.PONG)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = (
|
|
||||||
"WS_CLOSED_MESSAGE",
|
|
||||||
"WS_CLOSING_MESSAGE",
|
|
||||||
"WS_KEY",
|
|
||||||
"WebSocketReader",
|
|
||||||
"WebSocketWriter",
|
|
||||||
"WSMessage",
|
|
||||||
"WebSocketError",
|
|
||||||
"WSMsgType",
|
|
||||||
"WSCloseCode",
|
|
||||||
"ws_ext_gen",
|
|
||||||
"ws_ext_parse",
|
|
||||||
"WSHandshakeError",
|
|
||||||
"WSMessage",
|
|
||||||
)
|
|
||||||
|
|
@ -1,378 +0,0 @@
|
||||||
"""Http related parsers and protocol."""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import sys
|
|
||||||
from typing import ( # noqa
|
|
||||||
TYPE_CHECKING,
|
|
||||||
Any,
|
|
||||||
Awaitable,
|
|
||||||
Callable,
|
|
||||||
Iterable,
|
|
||||||
List,
|
|
||||||
NamedTuple,
|
|
||||||
Optional,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
|
|
||||||
from multidict import CIMultiDict
|
|
||||||
|
|
||||||
from .abc import AbstractStreamWriter
|
|
||||||
from .base_protocol import BaseProtocol
|
|
||||||
from .client_exceptions import ClientConnectionResetError
|
|
||||||
from .compression_utils import ZLibCompressor
|
|
||||||
from .helpers import NO_EXTENSIONS
|
|
||||||
|
|
||||||
__all__ = ("StreamWriter", "HttpVersion", "HttpVersion10", "HttpVersion11")
|
|
||||||
|
|
||||||
|
|
||||||
MIN_PAYLOAD_FOR_WRITELINES = 2048
|
|
||||||
IS_PY313_BEFORE_313_2 = (3, 13, 0) <= sys.version_info < (3, 13, 2)
|
|
||||||
IS_PY_BEFORE_312_9 = sys.version_info < (3, 12, 9)
|
|
||||||
SKIP_WRITELINES = IS_PY313_BEFORE_313_2 or IS_PY_BEFORE_312_9
|
|
||||||
# writelines is not safe for use
|
|
||||||
# on Python 3.12+ until 3.12.9
|
|
||||||
# on Python 3.13+ until 3.13.2
|
|
||||||
# and on older versions it not any faster than write
|
|
||||||
# CVE-2024-12254: https://github.com/python/cpython/pull/127656
|
|
||||||
|
|
||||||
|
|
||||||
class HttpVersion(NamedTuple):
|
|
||||||
major: int
|
|
||||||
minor: int
|
|
||||||
|
|
||||||
|
|
||||||
HttpVersion10 = HttpVersion(1, 0)
|
|
||||||
HttpVersion11 = HttpVersion(1, 1)
|
|
||||||
|
|
||||||
|
|
||||||
_T_OnChunkSent = Optional[Callable[[bytes], Awaitable[None]]]
|
|
||||||
_T_OnHeadersSent = Optional[Callable[["CIMultiDict[str]"], Awaitable[None]]]
|
|
||||||
|
|
||||||
|
|
||||||
class StreamWriter(AbstractStreamWriter):
|
|
||||||
|
|
||||||
length: Optional[int] = None
|
|
||||||
chunked: bool = False
|
|
||||||
_eof: bool = False
|
|
||||||
_compress: Optional[ZLibCompressor] = None
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
protocol: BaseProtocol,
|
|
||||||
loop: asyncio.AbstractEventLoop,
|
|
||||||
on_chunk_sent: _T_OnChunkSent = None,
|
|
||||||
on_headers_sent: _T_OnHeadersSent = None,
|
|
||||||
) -> None:
|
|
||||||
self._protocol = protocol
|
|
||||||
self.loop = loop
|
|
||||||
self._on_chunk_sent: _T_OnChunkSent = on_chunk_sent
|
|
||||||
self._on_headers_sent: _T_OnHeadersSent = on_headers_sent
|
|
||||||
self._headers_buf: Optional[bytes] = None
|
|
||||||
self._headers_written: bool = False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def transport(self) -> Optional[asyncio.Transport]:
|
|
||||||
return self._protocol.transport
|
|
||||||
|
|
||||||
@property
|
|
||||||
def protocol(self) -> BaseProtocol:
|
|
||||||
return self._protocol
|
|
||||||
|
|
||||||
def enable_chunking(self) -> None:
|
|
||||||
self.chunked = True
|
|
||||||
|
|
||||||
def enable_compression(
|
|
||||||
self, encoding: str = "deflate", strategy: Optional[int] = None
|
|
||||||
) -> None:
|
|
||||||
self._compress = ZLibCompressor(encoding=encoding, strategy=strategy)
|
|
||||||
|
|
||||||
def _write(self, chunk: Union[bytes, bytearray, memoryview]) -> None:
|
|
||||||
size = len(chunk)
|
|
||||||
self.buffer_size += size
|
|
||||||
self.output_size += size
|
|
||||||
transport = self._protocol.transport
|
|
||||||
if transport is None or transport.is_closing():
|
|
||||||
raise ClientConnectionResetError("Cannot write to closing transport")
|
|
||||||
transport.write(chunk)
|
|
||||||
|
|
||||||
def _writelines(self, chunks: Iterable[bytes]) -> None:
|
|
||||||
size = 0
|
|
||||||
for chunk in chunks:
|
|
||||||
size += len(chunk)
|
|
||||||
self.buffer_size += size
|
|
||||||
self.output_size += size
|
|
||||||
transport = self._protocol.transport
|
|
||||||
if transport is None or transport.is_closing():
|
|
||||||
raise ClientConnectionResetError("Cannot write to closing transport")
|
|
||||||
if SKIP_WRITELINES or size < MIN_PAYLOAD_FOR_WRITELINES:
|
|
||||||
transport.write(b"".join(chunks))
|
|
||||||
else:
|
|
||||||
transport.writelines(chunks)
|
|
||||||
|
|
||||||
def _write_chunked_payload(
|
|
||||||
self, chunk: Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"]
|
|
||||||
) -> None:
|
|
||||||
"""Write a chunk with proper chunked encoding."""
|
|
||||||
chunk_len_pre = f"{len(chunk):x}\r\n".encode("ascii")
|
|
||||||
self._writelines((chunk_len_pre, chunk, b"\r\n"))
|
|
||||||
|
|
||||||
def _send_headers_with_payload(
|
|
||||||
self,
|
|
||||||
chunk: Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"],
|
|
||||||
is_eof: bool,
|
|
||||||
) -> None:
|
|
||||||
"""Send buffered headers with payload, coalescing into single write."""
|
|
||||||
# Mark headers as written
|
|
||||||
self._headers_written = True
|
|
||||||
headers_buf = self._headers_buf
|
|
||||||
self._headers_buf = None
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
# Safe because callers (write() and write_eof()) only invoke this method
|
|
||||||
# after checking that self._headers_buf is truthy
|
|
||||||
assert headers_buf is not None
|
|
||||||
|
|
||||||
if not self.chunked:
|
|
||||||
# Non-chunked: coalesce headers with body
|
|
||||||
if chunk:
|
|
||||||
self._writelines((headers_buf, chunk))
|
|
||||||
else:
|
|
||||||
self._write(headers_buf)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Coalesce headers with chunked data
|
|
||||||
if chunk:
|
|
||||||
chunk_len_pre = f"{len(chunk):x}\r\n".encode("ascii")
|
|
||||||
if is_eof:
|
|
||||||
self._writelines((headers_buf, chunk_len_pre, chunk, b"\r\n0\r\n\r\n"))
|
|
||||||
else:
|
|
||||||
self._writelines((headers_buf, chunk_len_pre, chunk, b"\r\n"))
|
|
||||||
elif is_eof:
|
|
||||||
self._writelines((headers_buf, b"0\r\n\r\n"))
|
|
||||||
else:
|
|
||||||
self._write(headers_buf)
|
|
||||||
|
|
||||||
async def write(
|
|
||||||
self,
|
|
||||||
chunk: Union[bytes, bytearray, memoryview],
|
|
||||||
*,
|
|
||||||
drain: bool = True,
|
|
||||||
LIMIT: int = 0x10000,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Writes chunk of data to a stream.
|
|
||||||
|
|
||||||
write_eof() indicates end of stream.
|
|
||||||
writer can't be used after write_eof() method being called.
|
|
||||||
write() return drain future.
|
|
||||||
"""
|
|
||||||
if self._on_chunk_sent is not None:
|
|
||||||
await self._on_chunk_sent(chunk)
|
|
||||||
|
|
||||||
if isinstance(chunk, memoryview):
|
|
||||||
if chunk.nbytes != len(chunk):
|
|
||||||
# just reshape it
|
|
||||||
chunk = chunk.cast("c")
|
|
||||||
|
|
||||||
if self._compress is not None:
|
|
||||||
chunk = await self._compress.compress(chunk)
|
|
||||||
if not chunk:
|
|
||||||
return
|
|
||||||
|
|
||||||
if self.length is not None:
|
|
||||||
chunk_len = len(chunk)
|
|
||||||
if self.length >= chunk_len:
|
|
||||||
self.length = self.length - chunk_len
|
|
||||||
else:
|
|
||||||
chunk = chunk[: self.length]
|
|
||||||
self.length = 0
|
|
||||||
if not chunk:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Handle buffered headers for small payload optimization
|
|
||||||
if self._headers_buf and not self._headers_written:
|
|
||||||
self._send_headers_with_payload(chunk, False)
|
|
||||||
if drain and self.buffer_size > LIMIT:
|
|
||||||
self.buffer_size = 0
|
|
||||||
await self.drain()
|
|
||||||
return
|
|
||||||
|
|
||||||
if chunk:
|
|
||||||
if self.chunked:
|
|
||||||
self._write_chunked_payload(chunk)
|
|
||||||
else:
|
|
||||||
self._write(chunk)
|
|
||||||
|
|
||||||
if drain and self.buffer_size > LIMIT:
|
|
||||||
self.buffer_size = 0
|
|
||||||
await self.drain()
|
|
||||||
|
|
||||||
async def write_headers(
|
|
||||||
self, status_line: str, headers: "CIMultiDict[str]"
|
|
||||||
) -> None:
|
|
||||||
"""Write headers to the stream."""
|
|
||||||
if self._on_headers_sent is not None:
|
|
||||||
await self._on_headers_sent(headers)
|
|
||||||
# status + headers
|
|
||||||
buf = _serialize_headers(status_line, headers)
|
|
||||||
self._headers_written = False
|
|
||||||
self._headers_buf = buf
|
|
||||||
|
|
||||||
def send_headers(self) -> None:
|
|
||||||
"""Force sending buffered headers if not already sent."""
|
|
||||||
if not self._headers_buf or self._headers_written:
|
|
||||||
return
|
|
||||||
|
|
||||||
self._headers_written = True
|
|
||||||
headers_buf = self._headers_buf
|
|
||||||
self._headers_buf = None
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
# Safe because we only enter this block when self._headers_buf is truthy
|
|
||||||
assert headers_buf is not None
|
|
||||||
|
|
||||||
self._write(headers_buf)
|
|
||||||
|
|
||||||
def set_eof(self) -> None:
|
|
||||||
"""Indicate that the message is complete."""
|
|
||||||
if self._eof:
|
|
||||||
return
|
|
||||||
|
|
||||||
# If headers haven't been sent yet, send them now
|
|
||||||
# This handles the case where there's no body at all
|
|
||||||
if self._headers_buf and not self._headers_written:
|
|
||||||
self._headers_written = True
|
|
||||||
headers_buf = self._headers_buf
|
|
||||||
self._headers_buf = None
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
# Safe because we only enter this block when self._headers_buf is truthy
|
|
||||||
assert headers_buf is not None
|
|
||||||
|
|
||||||
# Combine headers and chunked EOF marker in a single write
|
|
||||||
if self.chunked:
|
|
||||||
self._writelines((headers_buf, b"0\r\n\r\n"))
|
|
||||||
else:
|
|
||||||
self._write(headers_buf)
|
|
||||||
elif self.chunked and self._headers_written:
|
|
||||||
# Headers already sent, just send the final chunk marker
|
|
||||||
self._write(b"0\r\n\r\n")
|
|
||||||
|
|
||||||
self._eof = True
|
|
||||||
|
|
||||||
async def write_eof(self, chunk: bytes = b"") -> None:
|
|
||||||
if self._eof:
|
|
||||||
return
|
|
||||||
|
|
||||||
if chunk and self._on_chunk_sent is not None:
|
|
||||||
await self._on_chunk_sent(chunk)
|
|
||||||
|
|
||||||
# Handle body/compression
|
|
||||||
if self._compress:
|
|
||||||
chunks: List[bytes] = []
|
|
||||||
chunks_len = 0
|
|
||||||
if chunk and (compressed_chunk := await self._compress.compress(chunk)):
|
|
||||||
chunks_len = len(compressed_chunk)
|
|
||||||
chunks.append(compressed_chunk)
|
|
||||||
|
|
||||||
flush_chunk = self._compress.flush()
|
|
||||||
chunks_len += len(flush_chunk)
|
|
||||||
chunks.append(flush_chunk)
|
|
||||||
assert chunks_len
|
|
||||||
|
|
||||||
# Send buffered headers with compressed data if not yet sent
|
|
||||||
if self._headers_buf and not self._headers_written:
|
|
||||||
self._headers_written = True
|
|
||||||
headers_buf = self._headers_buf
|
|
||||||
self._headers_buf = None
|
|
||||||
|
|
||||||
if self.chunked:
|
|
||||||
# Coalesce headers with compressed chunked data
|
|
||||||
chunk_len_pre = f"{chunks_len:x}\r\n".encode("ascii")
|
|
||||||
self._writelines(
|
|
||||||
(headers_buf, chunk_len_pre, *chunks, b"\r\n0\r\n\r\n")
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Coalesce headers with compressed data
|
|
||||||
self._writelines((headers_buf, *chunks))
|
|
||||||
await self.drain()
|
|
||||||
self._eof = True
|
|
||||||
return
|
|
||||||
|
|
||||||
# Headers already sent, just write compressed data
|
|
||||||
if self.chunked:
|
|
||||||
chunk_len_pre = f"{chunks_len:x}\r\n".encode("ascii")
|
|
||||||
self._writelines((chunk_len_pre, *chunks, b"\r\n0\r\n\r\n"))
|
|
||||||
elif len(chunks) > 1:
|
|
||||||
self._writelines(chunks)
|
|
||||||
else:
|
|
||||||
self._write(chunks[0])
|
|
||||||
await self.drain()
|
|
||||||
self._eof = True
|
|
||||||
return
|
|
||||||
|
|
||||||
# No compression - send buffered headers if not yet sent
|
|
||||||
if self._headers_buf and not self._headers_written:
|
|
||||||
# Use helper to send headers with payload
|
|
||||||
self._send_headers_with_payload(chunk, True)
|
|
||||||
await self.drain()
|
|
||||||
self._eof = True
|
|
||||||
return
|
|
||||||
|
|
||||||
# Handle remaining body
|
|
||||||
if self.chunked:
|
|
||||||
if chunk:
|
|
||||||
# Write final chunk with EOF marker
|
|
||||||
self._writelines(
|
|
||||||
(f"{len(chunk):x}\r\n".encode("ascii"), chunk, b"\r\n0\r\n\r\n")
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self._write(b"0\r\n\r\n")
|
|
||||||
await self.drain()
|
|
||||||
self._eof = True
|
|
||||||
return
|
|
||||||
|
|
||||||
if chunk:
|
|
||||||
self._write(chunk)
|
|
||||||
await self.drain()
|
|
||||||
|
|
||||||
self._eof = True
|
|
||||||
|
|
||||||
async def drain(self) -> None:
|
|
||||||
"""Flush the write buffer.
|
|
||||||
|
|
||||||
The intended use is to write
|
|
||||||
|
|
||||||
await w.write(data)
|
|
||||||
await w.drain()
|
|
||||||
"""
|
|
||||||
protocol = self._protocol
|
|
||||||
if protocol.transport is not None and protocol._paused:
|
|
||||||
await protocol._drain_helper()
|
|
||||||
|
|
||||||
|
|
||||||
def _safe_header(string: str) -> str:
|
|
||||||
if "\r" in string or "\n" in string:
|
|
||||||
raise ValueError(
|
|
||||||
"Newline or carriage return detected in headers. "
|
|
||||||
"Potential header injection attack."
|
|
||||||
)
|
|
||||||
return string
|
|
||||||
|
|
||||||
|
|
||||||
def _py_serialize_headers(status_line: str, headers: "CIMultiDict[str]") -> bytes:
|
|
||||||
headers_gen = (_safe_header(k) + ": " + _safe_header(v) for k, v in headers.items())
|
|
||||||
line = status_line + "\r\n" + "\r\n".join(headers_gen) + "\r\n\r\n"
|
|
||||||
return line.encode("utf-8")
|
|
||||||
|
|
||||||
|
|
||||||
_serialize_headers = _py_serialize_headers
|
|
||||||
|
|
||||||
try:
|
|
||||||
import aiohttp._http_writer as _http_writer # type: ignore[import-not-found]
|
|
||||||
|
|
||||||
_c_serialize_headers = _http_writer._serialize_headers
|
|
||||||
if not NO_EXTENSIONS:
|
|
||||||
_serialize_headers = _c_serialize_headers
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
@ -1,8 +0,0 @@
|
||||||
import logging
|
|
||||||
|
|
||||||
access_logger = logging.getLogger("aiohttp.access")
|
|
||||||
client_logger = logging.getLogger("aiohttp.client")
|
|
||||||
internal_logger = logging.getLogger("aiohttp.internal")
|
|
||||||
server_logger = logging.getLogger("aiohttp.server")
|
|
||||||
web_logger = logging.getLogger("aiohttp.web")
|
|
||||||
ws_logger = logging.getLogger("aiohttp.websocket")
|
|
||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
|
@ -1,78 +0,0 @@
|
||||||
"""
|
|
||||||
Payload implementation for coroutines as data provider.
|
|
||||||
|
|
||||||
As a simple case, you can upload data from file::
|
|
||||||
|
|
||||||
@aiohttp.streamer
|
|
||||||
async def file_sender(writer, file_name=None):
|
|
||||||
with open(file_name, 'rb') as f:
|
|
||||||
chunk = f.read(2**16)
|
|
||||||
while chunk:
|
|
||||||
await writer.write(chunk)
|
|
||||||
|
|
||||||
chunk = f.read(2**16)
|
|
||||||
|
|
||||||
Then you can use `file_sender` like this:
|
|
||||||
|
|
||||||
async with session.post('http://httpbin.org/post',
|
|
||||||
data=file_sender(file_name='huge_file')) as resp:
|
|
||||||
print(await resp.text())
|
|
||||||
|
|
||||||
..note:: Coroutine must accept `writer` as first argument
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
import types
|
|
||||||
import warnings
|
|
||||||
from typing import Any, Awaitable, Callable, Dict, Tuple
|
|
||||||
|
|
||||||
from .abc import AbstractStreamWriter
|
|
||||||
from .payload import Payload, payload_type
|
|
||||||
|
|
||||||
__all__ = ("streamer",)
|
|
||||||
|
|
||||||
|
|
||||||
class _stream_wrapper:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
coro: Callable[..., Awaitable[None]],
|
|
||||||
args: Tuple[Any, ...],
|
|
||||||
kwargs: Dict[str, Any],
|
|
||||||
) -> None:
|
|
||||||
self.coro = types.coroutine(coro)
|
|
||||||
self.args = args
|
|
||||||
self.kwargs = kwargs
|
|
||||||
|
|
||||||
async def __call__(self, writer: AbstractStreamWriter) -> None:
|
|
||||||
await self.coro(writer, *self.args, **self.kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class streamer:
|
|
||||||
def __init__(self, coro: Callable[..., Awaitable[None]]) -> None:
|
|
||||||
warnings.warn(
|
|
||||||
"@streamer is deprecated, use async generators instead",
|
|
||||||
DeprecationWarning,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
self.coro = coro
|
|
||||||
|
|
||||||
def __call__(self, *args: Any, **kwargs: Any) -> _stream_wrapper:
|
|
||||||
return _stream_wrapper(self.coro, args, kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
@payload_type(_stream_wrapper)
|
|
||||||
class StreamWrapperPayload(Payload):
|
|
||||||
async def write(self, writer: AbstractStreamWriter) -> None:
|
|
||||||
await self._value(writer)
|
|
||||||
|
|
||||||
def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
|
|
||||||
raise TypeError("Unable to decode.")
|
|
||||||
|
|
||||||
|
|
||||||
@payload_type(streamer)
|
|
||||||
class StreamPayload(StreamWrapperPayload):
|
|
||||||
def __init__(self, value: Any, *args: Any, **kwargs: Any) -> None:
|
|
||||||
super().__init__(value(), *args, **kwargs)
|
|
||||||
|
|
||||||
async def write(self, writer: AbstractStreamWriter) -> None:
|
|
||||||
await self._value(writer)
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
Marker
|
|
||||||
|
|
@ -1,444 +0,0 @@
|
||||||
import asyncio
|
|
||||||
import contextlib
|
|
||||||
import inspect
|
|
||||||
import warnings
|
|
||||||
from typing import (
|
|
||||||
Any,
|
|
||||||
Awaitable,
|
|
||||||
Callable,
|
|
||||||
Dict,
|
|
||||||
Iterator,
|
|
||||||
Optional,
|
|
||||||
Protocol,
|
|
||||||
Union,
|
|
||||||
overload,
|
|
||||||
)
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from .test_utils import (
|
|
||||||
BaseTestServer,
|
|
||||||
RawTestServer,
|
|
||||||
TestClient,
|
|
||||||
TestServer,
|
|
||||||
loop_context,
|
|
||||||
setup_test_loop,
|
|
||||||
teardown_test_loop,
|
|
||||||
unused_port as _unused_port,
|
|
||||||
)
|
|
||||||
from .web import Application, BaseRequest, Request
|
|
||||||
from .web_protocol import _RequestHandler
|
|
||||||
|
|
||||||
try:
|
|
||||||
import uvloop
|
|
||||||
except ImportError: # pragma: no cover
|
|
||||||
uvloop = None # type: ignore[assignment]
|
|
||||||
|
|
||||||
|
|
||||||
class AiohttpClient(Protocol):
|
|
||||||
@overload
|
|
||||||
async def __call__(
|
|
||||||
self,
|
|
||||||
__param: Application,
|
|
||||||
*,
|
|
||||||
server_kwargs: Optional[Dict[str, Any]] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> TestClient[Request, Application]: ...
|
|
||||||
@overload
|
|
||||||
async def __call__(
|
|
||||||
self,
|
|
||||||
__param: BaseTestServer,
|
|
||||||
*,
|
|
||||||
server_kwargs: Optional[Dict[str, Any]] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> TestClient[BaseRequest, None]: ...
|
|
||||||
|
|
||||||
|
|
||||||
class AiohttpServer(Protocol):
|
|
||||||
def __call__(
|
|
||||||
self, app: Application, *, port: Optional[int] = None, **kwargs: Any
|
|
||||||
) -> Awaitable[TestServer]: ...
|
|
||||||
|
|
||||||
|
|
||||||
class AiohttpRawServer(Protocol):
|
|
||||||
def __call__(
|
|
||||||
self, handler: _RequestHandler, *, port: Optional[int] = None, **kwargs: Any
|
|
||||||
) -> Awaitable[RawTestServer]: ...
|
|
||||||
|
|
||||||
|
|
||||||
def pytest_addoption(parser): # type: ignore[no-untyped-def]
|
|
||||||
parser.addoption(
|
|
||||||
"--aiohttp-fast",
|
|
||||||
action="store_true",
|
|
||||||
default=False,
|
|
||||||
help="run tests faster by disabling extra checks",
|
|
||||||
)
|
|
||||||
parser.addoption(
|
|
||||||
"--aiohttp-loop",
|
|
||||||
action="store",
|
|
||||||
default="pyloop",
|
|
||||||
help="run tests with specific loop: pyloop, uvloop or all",
|
|
||||||
)
|
|
||||||
parser.addoption(
|
|
||||||
"--aiohttp-enable-loop-debug",
|
|
||||||
action="store_true",
|
|
||||||
default=False,
|
|
||||||
help="enable event loop debug mode",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def pytest_fixture_setup(fixturedef): # type: ignore[no-untyped-def]
|
|
||||||
"""Set up pytest fixture.
|
|
||||||
|
|
||||||
Allow fixtures to be coroutines. Run coroutine fixtures in an event loop.
|
|
||||||
"""
|
|
||||||
func = fixturedef.func
|
|
||||||
|
|
||||||
if inspect.isasyncgenfunction(func):
|
|
||||||
# async generator fixture
|
|
||||||
is_async_gen = True
|
|
||||||
elif inspect.iscoroutinefunction(func):
|
|
||||||
# regular async fixture
|
|
||||||
is_async_gen = False
|
|
||||||
else:
|
|
||||||
# not an async fixture, nothing to do
|
|
||||||
return
|
|
||||||
|
|
||||||
strip_request = False
|
|
||||||
if "request" not in fixturedef.argnames:
|
|
||||||
fixturedef.argnames += ("request",)
|
|
||||||
strip_request = True
|
|
||||||
|
|
||||||
def wrapper(*args, **kwargs): # type: ignore[no-untyped-def]
|
|
||||||
request = kwargs["request"]
|
|
||||||
if strip_request:
|
|
||||||
del kwargs["request"]
|
|
||||||
|
|
||||||
# if neither the fixture nor the test use the 'loop' fixture,
|
|
||||||
# 'getfixturevalue' will fail because the test is not parameterized
|
|
||||||
# (this can be removed someday if 'loop' is no longer parameterized)
|
|
||||||
if "loop" not in request.fixturenames:
|
|
||||||
raise Exception(
|
|
||||||
"Asynchronous fixtures must depend on the 'loop' fixture or "
|
|
||||||
"be used in tests depending from it."
|
|
||||||
)
|
|
||||||
|
|
||||||
_loop = request.getfixturevalue("loop")
|
|
||||||
|
|
||||||
if is_async_gen:
|
|
||||||
# for async generators, we need to advance the generator once,
|
|
||||||
# then advance it again in a finalizer
|
|
||||||
gen = func(*args, **kwargs)
|
|
||||||
|
|
||||||
def finalizer(): # type: ignore[no-untyped-def]
|
|
||||||
try:
|
|
||||||
return _loop.run_until_complete(gen.__anext__())
|
|
||||||
except StopAsyncIteration:
|
|
||||||
pass
|
|
||||||
|
|
||||||
request.addfinalizer(finalizer)
|
|
||||||
return _loop.run_until_complete(gen.__anext__())
|
|
||||||
else:
|
|
||||||
return _loop.run_until_complete(func(*args, **kwargs))
|
|
||||||
|
|
||||||
fixturedef.func = wrapper
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def fast(request): # type: ignore[no-untyped-def]
|
|
||||||
"""--fast config option"""
|
|
||||||
return request.config.getoption("--aiohttp-fast")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def loop_debug(request): # type: ignore[no-untyped-def]
|
|
||||||
"""--enable-loop-debug config option"""
|
|
||||||
return request.config.getoption("--aiohttp-enable-loop-debug")
|
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
|
||||||
def _runtime_warning_context(): # type: ignore[no-untyped-def]
|
|
||||||
"""Context manager which checks for RuntimeWarnings.
|
|
||||||
|
|
||||||
This exists specifically to
|
|
||||||
avoid "coroutine 'X' was never awaited" warnings being missed.
|
|
||||||
|
|
||||||
If RuntimeWarnings occur in the context a RuntimeError is raised.
|
|
||||||
"""
|
|
||||||
with warnings.catch_warnings(record=True) as _warnings:
|
|
||||||
yield
|
|
||||||
rw = [
|
|
||||||
"{w.filename}:{w.lineno}:{w.message}".format(w=w)
|
|
||||||
for w in _warnings
|
|
||||||
if w.category == RuntimeWarning
|
|
||||||
]
|
|
||||||
if rw:
|
|
||||||
raise RuntimeError(
|
|
||||||
"{} Runtime Warning{},\n{}".format(
|
|
||||||
len(rw), "" if len(rw) == 1 else "s", "\n".join(rw)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
|
||||||
def _passthrough_loop_context(loop, fast=False): # type: ignore[no-untyped-def]
|
|
||||||
"""Passthrough loop context.
|
|
||||||
|
|
||||||
Sets up and tears down a loop unless one is passed in via the loop
|
|
||||||
argument when it's passed straight through.
|
|
||||||
"""
|
|
||||||
if loop:
|
|
||||||
# loop already exists, pass it straight through
|
|
||||||
yield loop
|
|
||||||
else:
|
|
||||||
# this shadows loop_context's standard behavior
|
|
||||||
loop = setup_test_loop()
|
|
||||||
yield loop
|
|
||||||
teardown_test_loop(loop, fast=fast)
|
|
||||||
|
|
||||||
|
|
||||||
def pytest_pycollect_makeitem(collector, name, obj): # type: ignore[no-untyped-def]
|
|
||||||
"""Fix pytest collecting for coroutines."""
|
|
||||||
if collector.funcnamefilter(name) and inspect.iscoroutinefunction(obj):
|
|
||||||
return list(collector._genfunctions(name, obj))
|
|
||||||
|
|
||||||
|
|
||||||
def pytest_pyfunc_call(pyfuncitem): # type: ignore[no-untyped-def]
|
|
||||||
"""Run coroutines in an event loop instead of a normal function call."""
|
|
||||||
fast = pyfuncitem.config.getoption("--aiohttp-fast")
|
|
||||||
if inspect.iscoroutinefunction(pyfuncitem.function):
|
|
||||||
existing_loop = (
|
|
||||||
pyfuncitem.funcargs.get("proactor_loop")
|
|
||||||
or pyfuncitem.funcargs.get("selector_loop")
|
|
||||||
or pyfuncitem.funcargs.get("uvloop_loop")
|
|
||||||
or pyfuncitem.funcargs.get("loop", None)
|
|
||||||
)
|
|
||||||
|
|
||||||
with _runtime_warning_context():
|
|
||||||
with _passthrough_loop_context(existing_loop, fast=fast) as _loop:
|
|
||||||
testargs = {
|
|
||||||
arg: pyfuncitem.funcargs[arg]
|
|
||||||
for arg in pyfuncitem._fixtureinfo.argnames
|
|
||||||
}
|
|
||||||
_loop.run_until_complete(pyfuncitem.obj(**testargs))
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def pytest_generate_tests(metafunc): # type: ignore[no-untyped-def]
|
|
||||||
if "loop_factory" not in metafunc.fixturenames:
|
|
||||||
return
|
|
||||||
|
|
||||||
loops = metafunc.config.option.aiohttp_loop
|
|
||||||
avail_factories: dict[str, Callable[[], asyncio.AbstractEventLoop]]
|
|
||||||
avail_factories = {"pyloop": asyncio.new_event_loop}
|
|
||||||
|
|
||||||
if uvloop is not None: # pragma: no cover
|
|
||||||
avail_factories["uvloop"] = uvloop.new_event_loop
|
|
||||||
|
|
||||||
if loops == "all":
|
|
||||||
loops = "pyloop,uvloop?"
|
|
||||||
|
|
||||||
factories = {} # type: ignore[var-annotated]
|
|
||||||
for name in loops.split(","):
|
|
||||||
required = not name.endswith("?")
|
|
||||||
name = name.strip(" ?")
|
|
||||||
if name not in avail_factories: # pragma: no cover
|
|
||||||
if required:
|
|
||||||
raise ValueError(
|
|
||||||
"Unknown loop '%s', available loops: %s"
|
|
||||||
% (name, list(factories.keys()))
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
continue
|
|
||||||
factories[name] = avail_factories[name]
|
|
||||||
metafunc.parametrize(
|
|
||||||
"loop_factory", list(factories.values()), ids=list(factories.keys())
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def loop(
|
|
||||||
loop_factory: Callable[[], asyncio.AbstractEventLoop],
|
|
||||||
fast: bool,
|
|
||||||
loop_debug: bool,
|
|
||||||
) -> Iterator[asyncio.AbstractEventLoop]:
|
|
||||||
"""Return an instance of the event loop."""
|
|
||||||
with loop_context(loop_factory, fast=fast) as _loop:
|
|
||||||
if loop_debug:
|
|
||||||
_loop.set_debug(True) # pragma: no cover
|
|
||||||
asyncio.set_event_loop(_loop)
|
|
||||||
yield _loop
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def proactor_loop() -> Iterator[asyncio.AbstractEventLoop]:
|
|
||||||
factory = asyncio.ProactorEventLoop # type: ignore[attr-defined]
|
|
||||||
|
|
||||||
with loop_context(factory) as _loop:
|
|
||||||
asyncio.set_event_loop(_loop)
|
|
||||||
yield _loop
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def unused_port(aiohttp_unused_port: Callable[[], int]) -> Callable[[], int]:
|
|
||||||
warnings.warn(
|
|
||||||
"Deprecated, use aiohttp_unused_port fixture instead",
|
|
||||||
DeprecationWarning,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
return aiohttp_unused_port
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def aiohttp_unused_port() -> Callable[[], int]:
|
|
||||||
"""Return a port that is unused on the current host."""
|
|
||||||
return _unused_port
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def aiohttp_server(loop: asyncio.AbstractEventLoop) -> Iterator[AiohttpServer]:
|
|
||||||
"""Factory to create a TestServer instance, given an app.
|
|
||||||
|
|
||||||
aiohttp_server(app, **kwargs)
|
|
||||||
"""
|
|
||||||
servers = []
|
|
||||||
|
|
||||||
async def go(
|
|
||||||
app: Application,
|
|
||||||
*,
|
|
||||||
host: str = "127.0.0.1",
|
|
||||||
port: Optional[int] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> TestServer:
|
|
||||||
server = TestServer(app, host=host, port=port)
|
|
||||||
await server.start_server(loop=loop, **kwargs)
|
|
||||||
servers.append(server)
|
|
||||||
return server
|
|
||||||
|
|
||||||
yield go
|
|
||||||
|
|
||||||
async def finalize() -> None:
|
|
||||||
while servers:
|
|
||||||
await servers.pop().close()
|
|
||||||
|
|
||||||
loop.run_until_complete(finalize())
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def test_server(aiohttp_server): # type: ignore[no-untyped-def] # pragma: no cover
|
|
||||||
warnings.warn(
|
|
||||||
"Deprecated, use aiohttp_server fixture instead",
|
|
||||||
DeprecationWarning,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
return aiohttp_server
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def aiohttp_raw_server(loop: asyncio.AbstractEventLoop) -> Iterator[AiohttpRawServer]:
|
|
||||||
"""Factory to create a RawTestServer instance, given a web handler.
|
|
||||||
|
|
||||||
aiohttp_raw_server(handler, **kwargs)
|
|
||||||
"""
|
|
||||||
servers = []
|
|
||||||
|
|
||||||
async def go(
|
|
||||||
handler: _RequestHandler, *, port: Optional[int] = None, **kwargs: Any
|
|
||||||
) -> RawTestServer:
|
|
||||||
server = RawTestServer(handler, port=port)
|
|
||||||
await server.start_server(loop=loop, **kwargs)
|
|
||||||
servers.append(server)
|
|
||||||
return server
|
|
||||||
|
|
||||||
yield go
|
|
||||||
|
|
||||||
async def finalize() -> None:
|
|
||||||
while servers:
|
|
||||||
await servers.pop().close()
|
|
||||||
|
|
||||||
loop.run_until_complete(finalize())
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def raw_test_server( # type: ignore[no-untyped-def] # pragma: no cover
|
|
||||||
aiohttp_raw_server,
|
|
||||||
):
|
|
||||||
warnings.warn(
|
|
||||||
"Deprecated, use aiohttp_raw_server fixture instead",
|
|
||||||
DeprecationWarning,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
return aiohttp_raw_server
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def aiohttp_client(loop: asyncio.AbstractEventLoop) -> Iterator[AiohttpClient]:
|
|
||||||
"""Factory to create a TestClient instance.
|
|
||||||
|
|
||||||
aiohttp_client(app, **kwargs)
|
|
||||||
aiohttp_client(server, **kwargs)
|
|
||||||
aiohttp_client(raw_server, **kwargs)
|
|
||||||
"""
|
|
||||||
clients = []
|
|
||||||
|
|
||||||
@overload
|
|
||||||
async def go(
|
|
||||||
__param: Application,
|
|
||||||
*,
|
|
||||||
server_kwargs: Optional[Dict[str, Any]] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> TestClient[Request, Application]: ...
|
|
||||||
|
|
||||||
@overload
|
|
||||||
async def go(
|
|
||||||
__param: BaseTestServer,
|
|
||||||
*,
|
|
||||||
server_kwargs: Optional[Dict[str, Any]] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> TestClient[BaseRequest, None]: ...
|
|
||||||
|
|
||||||
async def go(
|
|
||||||
__param: Union[Application, BaseTestServer],
|
|
||||||
*args: Any,
|
|
||||||
server_kwargs: Optional[Dict[str, Any]] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> TestClient[Any, Any]:
|
|
||||||
if isinstance(__param, Callable) and not isinstance( # type: ignore[arg-type]
|
|
||||||
__param, (Application, BaseTestServer)
|
|
||||||
):
|
|
||||||
__param = __param(loop, *args, **kwargs)
|
|
||||||
kwargs = {}
|
|
||||||
else:
|
|
||||||
assert not args, "args should be empty"
|
|
||||||
|
|
||||||
if isinstance(__param, Application):
|
|
||||||
server_kwargs = server_kwargs or {}
|
|
||||||
server = TestServer(__param, loop=loop, **server_kwargs)
|
|
||||||
client = TestClient(server, loop=loop, **kwargs)
|
|
||||||
elif isinstance(__param, BaseTestServer):
|
|
||||||
client = TestClient(__param, loop=loop, **kwargs)
|
|
||||||
else:
|
|
||||||
raise ValueError("Unknown argument type: %r" % type(__param))
|
|
||||||
|
|
||||||
await client.start_server()
|
|
||||||
clients.append(client)
|
|
||||||
return client
|
|
||||||
|
|
||||||
yield go
|
|
||||||
|
|
||||||
async def finalize() -> None:
|
|
||||||
while clients:
|
|
||||||
await clients.pop().close()
|
|
||||||
|
|
||||||
loop.run_until_complete(finalize())
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def test_client(aiohttp_client): # type: ignore[no-untyped-def] # pragma: no cover
|
|
||||||
warnings.warn(
|
|
||||||
"Deprecated, use aiohttp_client fixture instead",
|
|
||||||
DeprecationWarning,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
return aiohttp_client
|
|
||||||
|
|
@ -1,274 +0,0 @@
|
||||||
import asyncio
|
|
||||||
import socket
|
|
||||||
import weakref
|
|
||||||
from typing import Any, Dict, Final, List, Optional, Tuple, Type, Union
|
|
||||||
|
|
||||||
from .abc import AbstractResolver, ResolveResult
|
|
||||||
|
|
||||||
__all__ = ("ThreadedResolver", "AsyncResolver", "DefaultResolver")
|
|
||||||
|
|
||||||
|
|
||||||
try:
|
|
||||||
import aiodns
|
|
||||||
|
|
||||||
aiodns_default = hasattr(aiodns.DNSResolver, "getaddrinfo")
|
|
||||||
except ImportError: # pragma: no cover
|
|
||||||
aiodns = None # type: ignore[assignment]
|
|
||||||
aiodns_default = False
|
|
||||||
|
|
||||||
|
|
||||||
_NUMERIC_SOCKET_FLAGS = socket.AI_NUMERICHOST | socket.AI_NUMERICSERV
|
|
||||||
_NAME_SOCKET_FLAGS = socket.NI_NUMERICHOST | socket.NI_NUMERICSERV
|
|
||||||
_AI_ADDRCONFIG = socket.AI_ADDRCONFIG
|
|
||||||
if hasattr(socket, "AI_MASK"):
|
|
||||||
_AI_ADDRCONFIG &= socket.AI_MASK
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadedResolver(AbstractResolver):
|
|
||||||
"""Threaded resolver.
|
|
||||||
|
|
||||||
Uses an Executor for synchronous getaddrinfo() calls.
|
|
||||||
concurrent.futures.ThreadPoolExecutor is used by default.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
|
|
||||||
self._loop = loop or asyncio.get_running_loop()
|
|
||||||
|
|
||||||
async def resolve(
|
|
||||||
self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
|
|
||||||
) -> List[ResolveResult]:
|
|
||||||
infos = await self._loop.getaddrinfo(
|
|
||||||
host,
|
|
||||||
port,
|
|
||||||
type=socket.SOCK_STREAM,
|
|
||||||
family=family,
|
|
||||||
flags=_AI_ADDRCONFIG,
|
|
||||||
)
|
|
||||||
|
|
||||||
hosts: List[ResolveResult] = []
|
|
||||||
for family, _, proto, _, address in infos:
|
|
||||||
if family == socket.AF_INET6:
|
|
||||||
if len(address) < 3:
|
|
||||||
# IPv6 is not supported by Python build,
|
|
||||||
# or IPv6 is not enabled in the host
|
|
||||||
continue
|
|
||||||
if address[3]:
|
|
||||||
# This is essential for link-local IPv6 addresses.
|
|
||||||
# LL IPv6 is a VERY rare case. Strictly speaking, we should use
|
|
||||||
# getnameinfo() unconditionally, but performance makes sense.
|
|
||||||
resolved_host, _port = await self._loop.getnameinfo(
|
|
||||||
address, _NAME_SOCKET_FLAGS
|
|
||||||
)
|
|
||||||
port = int(_port)
|
|
||||||
else:
|
|
||||||
resolved_host, port = address[:2]
|
|
||||||
else: # IPv4
|
|
||||||
assert family == socket.AF_INET
|
|
||||||
resolved_host, port = address # type: ignore[misc]
|
|
||||||
hosts.append(
|
|
||||||
ResolveResult(
|
|
||||||
hostname=host,
|
|
||||||
host=resolved_host,
|
|
||||||
port=port,
|
|
||||||
family=family,
|
|
||||||
proto=proto,
|
|
||||||
flags=_NUMERIC_SOCKET_FLAGS,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return hosts
|
|
||||||
|
|
||||||
async def close(self) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class AsyncResolver(AbstractResolver):
|
|
||||||
"""Use the `aiodns` package to make asynchronous DNS lookups"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
|
||||||
*args: Any,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> None:
|
|
||||||
if aiodns is None:
|
|
||||||
raise RuntimeError("Resolver requires aiodns library")
|
|
||||||
|
|
||||||
self._loop = loop or asyncio.get_running_loop()
|
|
||||||
self._manager: Optional[_DNSResolverManager] = None
|
|
||||||
# If custom args are provided, create a dedicated resolver instance
|
|
||||||
# This means each AsyncResolver with custom args gets its own
|
|
||||||
# aiodns.DNSResolver instance
|
|
||||||
if args or kwargs:
|
|
||||||
self._resolver = aiodns.DNSResolver(*args, **kwargs)
|
|
||||||
return
|
|
||||||
# Use the shared resolver from the manager for default arguments
|
|
||||||
self._manager = _DNSResolverManager()
|
|
||||||
self._resolver = self._manager.get_resolver(self, self._loop)
|
|
||||||
|
|
||||||
if not hasattr(self._resolver, "gethostbyname"):
|
|
||||||
# aiodns 1.1 is not available, fallback to DNSResolver.query
|
|
||||||
self.resolve = self._resolve_with_query # type: ignore
|
|
||||||
|
|
||||||
async def resolve(
|
|
||||||
self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
|
|
||||||
) -> List[ResolveResult]:
|
|
||||||
try:
|
|
||||||
resp = await self._resolver.getaddrinfo(
|
|
||||||
host,
|
|
||||||
port=port,
|
|
||||||
type=socket.SOCK_STREAM,
|
|
||||||
family=family,
|
|
||||||
flags=_AI_ADDRCONFIG,
|
|
||||||
)
|
|
||||||
except aiodns.error.DNSError as exc:
|
|
||||||
msg = exc.args[1] if len(exc.args) >= 1 else "DNS lookup failed"
|
|
||||||
raise OSError(None, msg) from exc
|
|
||||||
hosts: List[ResolveResult] = []
|
|
||||||
for node in resp.nodes:
|
|
||||||
address: Union[Tuple[bytes, int], Tuple[bytes, int, int, int]] = node.addr
|
|
||||||
family = node.family
|
|
||||||
if family == socket.AF_INET6:
|
|
||||||
if len(address) > 3 and address[3]:
|
|
||||||
# This is essential for link-local IPv6 addresses.
|
|
||||||
# LL IPv6 is a VERY rare case. Strictly speaking, we should use
|
|
||||||
# getnameinfo() unconditionally, but performance makes sense.
|
|
||||||
result = await self._resolver.getnameinfo(
|
|
||||||
(address[0].decode("ascii"), *address[1:]),
|
|
||||||
_NAME_SOCKET_FLAGS,
|
|
||||||
)
|
|
||||||
resolved_host = result.node
|
|
||||||
else:
|
|
||||||
resolved_host = address[0].decode("ascii")
|
|
||||||
port = address[1]
|
|
||||||
else: # IPv4
|
|
||||||
assert family == socket.AF_INET
|
|
||||||
resolved_host = address[0].decode("ascii")
|
|
||||||
port = address[1]
|
|
||||||
hosts.append(
|
|
||||||
ResolveResult(
|
|
||||||
hostname=host,
|
|
||||||
host=resolved_host,
|
|
||||||
port=port,
|
|
||||||
family=family,
|
|
||||||
proto=0,
|
|
||||||
flags=_NUMERIC_SOCKET_FLAGS,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if not hosts:
|
|
||||||
raise OSError(None, "DNS lookup failed")
|
|
||||||
|
|
||||||
return hosts
|
|
||||||
|
|
||||||
async def _resolve_with_query(
|
|
||||||
self, host: str, port: int = 0, family: int = socket.AF_INET
|
|
||||||
) -> List[Dict[str, Any]]:
|
|
||||||
qtype: Final = "AAAA" if family == socket.AF_INET6 else "A"
|
|
||||||
|
|
||||||
try:
|
|
||||||
resp = await self._resolver.query(host, qtype)
|
|
||||||
except aiodns.error.DNSError as exc:
|
|
||||||
msg = exc.args[1] if len(exc.args) >= 1 else "DNS lookup failed"
|
|
||||||
raise OSError(None, msg) from exc
|
|
||||||
|
|
||||||
hosts = []
|
|
||||||
for rr in resp:
|
|
||||||
hosts.append(
|
|
||||||
{
|
|
||||||
"hostname": host,
|
|
||||||
"host": rr.host,
|
|
||||||
"port": port,
|
|
||||||
"family": family,
|
|
||||||
"proto": 0,
|
|
||||||
"flags": socket.AI_NUMERICHOST,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
if not hosts:
|
|
||||||
raise OSError(None, "DNS lookup failed")
|
|
||||||
|
|
||||||
return hosts
|
|
||||||
|
|
||||||
async def close(self) -> None:
|
|
||||||
if self._manager:
|
|
||||||
# Release the resolver from the manager if using the shared resolver
|
|
||||||
self._manager.release_resolver(self, self._loop)
|
|
||||||
self._manager = None # Clear reference to manager
|
|
||||||
self._resolver = None # type: ignore[assignment] # Clear reference to resolver
|
|
||||||
return
|
|
||||||
# Otherwise cancel our dedicated resolver
|
|
||||||
if self._resolver is not None:
|
|
||||||
self._resolver.cancel()
|
|
||||||
self._resolver = None # type: ignore[assignment] # Clear reference
|
|
||||||
|
|
||||||
|
|
||||||
class _DNSResolverManager:
|
|
||||||
"""Manager for aiodns.DNSResolver objects.
|
|
||||||
|
|
||||||
This class manages shared aiodns.DNSResolver instances
|
|
||||||
with no custom arguments across different event loops.
|
|
||||||
"""
|
|
||||||
|
|
||||||
_instance: Optional["_DNSResolverManager"] = None
|
|
||||||
|
|
||||||
def __new__(cls) -> "_DNSResolverManager":
|
|
||||||
if cls._instance is None:
|
|
||||||
cls._instance = super().__new__(cls)
|
|
||||||
cls._instance._init()
|
|
||||||
return cls._instance
|
|
||||||
|
|
||||||
def _init(self) -> None:
|
|
||||||
# Use WeakKeyDictionary to allow event loops to be garbage collected
|
|
||||||
self._loop_data: weakref.WeakKeyDictionary[
|
|
||||||
asyncio.AbstractEventLoop,
|
|
||||||
tuple["aiodns.DNSResolver", weakref.WeakSet["AsyncResolver"]],
|
|
||||||
] = weakref.WeakKeyDictionary()
|
|
||||||
|
|
||||||
def get_resolver(
|
|
||||||
self, client: "AsyncResolver", loop: asyncio.AbstractEventLoop
|
|
||||||
) -> "aiodns.DNSResolver":
|
|
||||||
"""Get or create the shared aiodns.DNSResolver instance for a specific event loop.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
client: The AsyncResolver instance requesting the resolver.
|
|
||||||
This is required to track resolver usage.
|
|
||||||
loop: The event loop to use for the resolver.
|
|
||||||
"""
|
|
||||||
# Create a new resolver and client set for this loop if it doesn't exist
|
|
||||||
if loop not in self._loop_data:
|
|
||||||
resolver = aiodns.DNSResolver(loop=loop)
|
|
||||||
client_set: weakref.WeakSet["AsyncResolver"] = weakref.WeakSet()
|
|
||||||
self._loop_data[loop] = (resolver, client_set)
|
|
||||||
else:
|
|
||||||
# Get the existing resolver and client set
|
|
||||||
resolver, client_set = self._loop_data[loop]
|
|
||||||
|
|
||||||
# Register this client with the loop
|
|
||||||
client_set.add(client)
|
|
||||||
return resolver
|
|
||||||
|
|
||||||
def release_resolver(
|
|
||||||
self, client: "AsyncResolver", loop: asyncio.AbstractEventLoop
|
|
||||||
) -> None:
|
|
||||||
"""Release the resolver for an AsyncResolver client when it's closed.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
client: The AsyncResolver instance to release.
|
|
||||||
loop: The event loop the resolver was using.
|
|
||||||
"""
|
|
||||||
# Remove client from its loop's tracking
|
|
||||||
current_loop_data = self._loop_data.get(loop)
|
|
||||||
if current_loop_data is None:
|
|
||||||
return
|
|
||||||
resolver, client_set = current_loop_data
|
|
||||||
client_set.discard(client)
|
|
||||||
# If no more clients for this loop, cancel and remove its resolver
|
|
||||||
if not client_set:
|
|
||||||
if resolver is not None:
|
|
||||||
resolver.cancel()
|
|
||||||
del self._loop_data[loop]
|
|
||||||
|
|
||||||
|
|
||||||
_DefaultType = Type[Union[AsyncResolver, ThreadedResolver]]
|
|
||||||
DefaultResolver: _DefaultType = AsyncResolver if aiodns_default else ThreadedResolver
|
|
||||||
|
|
@ -1,735 +0,0 @@
|
||||||
import asyncio
|
|
||||||
import collections
|
|
||||||
import warnings
|
|
||||||
from typing import (
|
|
||||||
Awaitable,
|
|
||||||
Callable,
|
|
||||||
Deque,
|
|
||||||
Final,
|
|
||||||
Generic,
|
|
||||||
List,
|
|
||||||
Optional,
|
|
||||||
Tuple,
|
|
||||||
TypeVar,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .base_protocol import BaseProtocol
|
|
||||||
from .helpers import (
|
|
||||||
_EXC_SENTINEL,
|
|
||||||
BaseTimerContext,
|
|
||||||
TimerNoop,
|
|
||||||
set_exception,
|
|
||||||
set_result,
|
|
||||||
)
|
|
||||||
from .log import internal_logger
|
|
||||||
|
|
||||||
__all__ = (
|
|
||||||
"EMPTY_PAYLOAD",
|
|
||||||
"EofStream",
|
|
||||||
"StreamReader",
|
|
||||||
"DataQueue",
|
|
||||||
)
|
|
||||||
|
|
||||||
_T = TypeVar("_T")
|
|
||||||
|
|
||||||
|
|
||||||
class EofStream(Exception):
|
|
||||||
"""eof stream indication."""
|
|
||||||
|
|
||||||
|
|
||||||
class AsyncStreamIterator(Generic[_T]):
|
|
||||||
|
|
||||||
__slots__ = ("read_func",)
|
|
||||||
|
|
||||||
def __init__(self, read_func: Callable[[], Awaitable[_T]]) -> None:
|
|
||||||
self.read_func = read_func
|
|
||||||
|
|
||||||
def __aiter__(self) -> "AsyncStreamIterator[_T]":
|
|
||||||
return self
|
|
||||||
|
|
||||||
async def __anext__(self) -> _T:
|
|
||||||
try:
|
|
||||||
rv = await self.read_func()
|
|
||||||
except EofStream:
|
|
||||||
raise StopAsyncIteration
|
|
||||||
if rv == b"":
|
|
||||||
raise StopAsyncIteration
|
|
||||||
return rv
|
|
||||||
|
|
||||||
|
|
||||||
class ChunkTupleAsyncStreamIterator:
|
|
||||||
|
|
||||||
__slots__ = ("_stream",)
|
|
||||||
|
|
||||||
def __init__(self, stream: "StreamReader") -> None:
|
|
||||||
self._stream = stream
|
|
||||||
|
|
||||||
def __aiter__(self) -> "ChunkTupleAsyncStreamIterator":
|
|
||||||
return self
|
|
||||||
|
|
||||||
async def __anext__(self) -> Tuple[bytes, bool]:
|
|
||||||
rv = await self._stream.readchunk()
|
|
||||||
if rv == (b"", False):
|
|
||||||
raise StopAsyncIteration
|
|
||||||
return rv
|
|
||||||
|
|
||||||
|
|
||||||
class AsyncStreamReaderMixin:
|
|
||||||
|
|
||||||
__slots__ = ()
|
|
||||||
|
|
||||||
def __aiter__(self) -> AsyncStreamIterator[bytes]:
|
|
||||||
return AsyncStreamIterator(self.readline) # type: ignore[attr-defined]
|
|
||||||
|
|
||||||
def iter_chunked(self, n: int) -> AsyncStreamIterator[bytes]:
|
|
||||||
"""Returns an asynchronous iterator that yields chunks of size n."""
|
|
||||||
return AsyncStreamIterator(lambda: self.read(n)) # type: ignore[attr-defined]
|
|
||||||
|
|
||||||
def iter_any(self) -> AsyncStreamIterator[bytes]:
|
|
||||||
"""Yield all available data as soon as it is received."""
|
|
||||||
return AsyncStreamIterator(self.readany) # type: ignore[attr-defined]
|
|
||||||
|
|
||||||
def iter_chunks(self) -> ChunkTupleAsyncStreamIterator:
|
|
||||||
"""Yield chunks of data as they are received by the server.
|
|
||||||
|
|
||||||
The yielded objects are tuples
|
|
||||||
of (bytes, bool) as returned by the StreamReader.readchunk method.
|
|
||||||
"""
|
|
||||||
return ChunkTupleAsyncStreamIterator(self) # type: ignore[arg-type]
|
|
||||||
|
|
||||||
|
|
||||||
class StreamReader(AsyncStreamReaderMixin):
|
|
||||||
"""An enhancement of asyncio.StreamReader.
|
|
||||||
|
|
||||||
Supports asynchronous iteration by line, chunk or as available::
|
|
||||||
|
|
||||||
async for line in reader:
|
|
||||||
...
|
|
||||||
async for chunk in reader.iter_chunked(1024):
|
|
||||||
...
|
|
||||||
async for slice in reader.iter_any():
|
|
||||||
...
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
__slots__ = (
|
|
||||||
"_protocol",
|
|
||||||
"_low_water",
|
|
||||||
"_high_water",
|
|
||||||
"_loop",
|
|
||||||
"_size",
|
|
||||||
"_cursor",
|
|
||||||
"_http_chunk_splits",
|
|
||||||
"_buffer",
|
|
||||||
"_buffer_offset",
|
|
||||||
"_eof",
|
|
||||||
"_waiter",
|
|
||||||
"_eof_waiter",
|
|
||||||
"_exception",
|
|
||||||
"_timer",
|
|
||||||
"_eof_callbacks",
|
|
||||||
"_eof_counter",
|
|
||||||
"total_bytes",
|
|
||||||
"total_compressed_bytes",
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
protocol: BaseProtocol,
|
|
||||||
limit: int,
|
|
||||||
*,
|
|
||||||
timer: Optional[BaseTimerContext] = None,
|
|
||||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
|
||||||
) -> None:
|
|
||||||
self._protocol = protocol
|
|
||||||
self._low_water = limit
|
|
||||||
self._high_water = limit * 2
|
|
||||||
if loop is None:
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
self._loop = loop
|
|
||||||
self._size = 0
|
|
||||||
self._cursor = 0
|
|
||||||
self._http_chunk_splits: Optional[List[int]] = None
|
|
||||||
self._buffer: Deque[bytes] = collections.deque()
|
|
||||||
self._buffer_offset = 0
|
|
||||||
self._eof = False
|
|
||||||
self._waiter: Optional[asyncio.Future[None]] = None
|
|
||||||
self._eof_waiter: Optional[asyncio.Future[None]] = None
|
|
||||||
self._exception: Optional[BaseException] = None
|
|
||||||
self._timer = TimerNoop() if timer is None else timer
|
|
||||||
self._eof_callbacks: List[Callable[[], None]] = []
|
|
||||||
self._eof_counter = 0
|
|
||||||
self.total_bytes = 0
|
|
||||||
self.total_compressed_bytes: Optional[int] = None
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
info = [self.__class__.__name__]
|
|
||||||
if self._size:
|
|
||||||
info.append("%d bytes" % self._size)
|
|
||||||
if self._eof:
|
|
||||||
info.append("eof")
|
|
||||||
if self._low_water != 2**16: # default limit
|
|
||||||
info.append("low=%d high=%d" % (self._low_water, self._high_water))
|
|
||||||
if self._waiter:
|
|
||||||
info.append("w=%r" % self._waiter)
|
|
||||||
if self._exception:
|
|
||||||
info.append("e=%r" % self._exception)
|
|
||||||
return "<%s>" % " ".join(info)
|
|
||||||
|
|
||||||
def get_read_buffer_limits(self) -> Tuple[int, int]:
|
|
||||||
return (self._low_water, self._high_water)
|
|
||||||
|
|
||||||
def exception(self) -> Optional[BaseException]:
|
|
||||||
return self._exception
|
|
||||||
|
|
||||||
def set_exception(
|
|
||||||
self,
|
|
||||||
exc: BaseException,
|
|
||||||
exc_cause: BaseException = _EXC_SENTINEL,
|
|
||||||
) -> None:
|
|
||||||
self._exception = exc
|
|
||||||
self._eof_callbacks.clear()
|
|
||||||
|
|
||||||
waiter = self._waiter
|
|
||||||
if waiter is not None:
|
|
||||||
self._waiter = None
|
|
||||||
set_exception(waiter, exc, exc_cause)
|
|
||||||
|
|
||||||
waiter = self._eof_waiter
|
|
||||||
if waiter is not None:
|
|
||||||
self._eof_waiter = None
|
|
||||||
set_exception(waiter, exc, exc_cause)
|
|
||||||
|
|
||||||
def on_eof(self, callback: Callable[[], None]) -> None:
|
|
||||||
if self._eof:
|
|
||||||
try:
|
|
||||||
callback()
|
|
||||||
except Exception:
|
|
||||||
internal_logger.exception("Exception in eof callback")
|
|
||||||
else:
|
|
||||||
self._eof_callbacks.append(callback)
|
|
||||||
|
|
||||||
def feed_eof(self) -> None:
|
|
||||||
self._eof = True
|
|
||||||
|
|
||||||
waiter = self._waiter
|
|
||||||
if waiter is not None:
|
|
||||||
self._waiter = None
|
|
||||||
set_result(waiter, None)
|
|
||||||
|
|
||||||
waiter = self._eof_waiter
|
|
||||||
if waiter is not None:
|
|
||||||
self._eof_waiter = None
|
|
||||||
set_result(waiter, None)
|
|
||||||
|
|
||||||
if self._protocol._reading_paused:
|
|
||||||
self._protocol.resume_reading()
|
|
||||||
|
|
||||||
for cb in self._eof_callbacks:
|
|
||||||
try:
|
|
||||||
cb()
|
|
||||||
except Exception:
|
|
||||||
internal_logger.exception("Exception in eof callback")
|
|
||||||
|
|
||||||
self._eof_callbacks.clear()
|
|
||||||
|
|
||||||
def is_eof(self) -> bool:
|
|
||||||
"""Return True if 'feed_eof' was called."""
|
|
||||||
return self._eof
|
|
||||||
|
|
||||||
def at_eof(self) -> bool:
|
|
||||||
"""Return True if the buffer is empty and 'feed_eof' was called."""
|
|
||||||
return self._eof and not self._buffer
|
|
||||||
|
|
||||||
async def wait_eof(self) -> None:
|
|
||||||
if self._eof:
|
|
||||||
return
|
|
||||||
|
|
||||||
assert self._eof_waiter is None
|
|
||||||
self._eof_waiter = self._loop.create_future()
|
|
||||||
try:
|
|
||||||
await self._eof_waiter
|
|
||||||
finally:
|
|
||||||
self._eof_waiter = None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def total_raw_bytes(self) -> int:
|
|
||||||
if self.total_compressed_bytes is None:
|
|
||||||
return self.total_bytes
|
|
||||||
return self.total_compressed_bytes
|
|
||||||
|
|
||||||
def unread_data(self, data: bytes) -> None:
|
|
||||||
"""rollback reading some data from stream, inserting it to buffer head."""
|
|
||||||
warnings.warn(
|
|
||||||
"unread_data() is deprecated "
|
|
||||||
"and will be removed in future releases (#3260)",
|
|
||||||
DeprecationWarning,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
if not data:
|
|
||||||
return
|
|
||||||
|
|
||||||
if self._buffer_offset:
|
|
||||||
self._buffer[0] = self._buffer[0][self._buffer_offset :]
|
|
||||||
self._buffer_offset = 0
|
|
||||||
self._size += len(data)
|
|
||||||
self._cursor -= len(data)
|
|
||||||
self._buffer.appendleft(data)
|
|
||||||
self._eof_counter = 0
|
|
||||||
|
|
||||||
# TODO: size is ignored, remove the param later
|
|
||||||
def feed_data(self, data: bytes, size: int = 0) -> None:
|
|
||||||
assert not self._eof, "feed_data after feed_eof"
|
|
||||||
|
|
||||||
if not data:
|
|
||||||
return
|
|
||||||
|
|
||||||
data_len = len(data)
|
|
||||||
self._size += data_len
|
|
||||||
self._buffer.append(data)
|
|
||||||
self.total_bytes += data_len
|
|
||||||
|
|
||||||
waiter = self._waiter
|
|
||||||
if waiter is not None:
|
|
||||||
self._waiter = None
|
|
||||||
set_result(waiter, None)
|
|
||||||
|
|
||||||
if self._size > self._high_water and not self._protocol._reading_paused:
|
|
||||||
self._protocol.pause_reading()
|
|
||||||
|
|
||||||
def begin_http_chunk_receiving(self) -> None:
|
|
||||||
if self._http_chunk_splits is None:
|
|
||||||
if self.total_bytes:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Called begin_http_chunk_receiving when some data was already fed"
|
|
||||||
)
|
|
||||||
self._http_chunk_splits = []
|
|
||||||
|
|
||||||
def end_http_chunk_receiving(self) -> None:
|
|
||||||
if self._http_chunk_splits is None:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Called end_chunk_receiving without calling "
|
|
||||||
"begin_chunk_receiving first"
|
|
||||||
)
|
|
||||||
|
|
||||||
# self._http_chunk_splits contains logical byte offsets from start of
|
|
||||||
# the body transfer. Each offset is the offset of the end of a chunk.
|
|
||||||
# "Logical" means bytes, accessible for a user.
|
|
||||||
# If no chunks containing logical data were received, current position
|
|
||||||
# is difinitely zero.
|
|
||||||
pos = self._http_chunk_splits[-1] if self._http_chunk_splits else 0
|
|
||||||
|
|
||||||
if self.total_bytes == pos:
|
|
||||||
# We should not add empty chunks here. So we check for that.
|
|
||||||
# Note, when chunked + gzip is used, we can receive a chunk
|
|
||||||
# of compressed data, but that data may not be enough for gzip FSM
|
|
||||||
# to yield any uncompressed data. That's why current position may
|
|
||||||
# not change after receiving a chunk.
|
|
||||||
return
|
|
||||||
|
|
||||||
self._http_chunk_splits.append(self.total_bytes)
|
|
||||||
|
|
||||||
# wake up readchunk when end of http chunk received
|
|
||||||
waiter = self._waiter
|
|
||||||
if waiter is not None:
|
|
||||||
self._waiter = None
|
|
||||||
set_result(waiter, None)
|
|
||||||
|
|
||||||
async def _wait(self, func_name: str) -> None:
|
|
||||||
if not self._protocol.connected:
|
|
||||||
raise RuntimeError("Connection closed.")
|
|
||||||
|
|
||||||
# StreamReader uses a future to link the protocol feed_data() method
|
|
||||||
# to a read coroutine. Running two read coroutines at the same time
|
|
||||||
# would have an unexpected behaviour. It would not possible to know
|
|
||||||
# which coroutine would get the next data.
|
|
||||||
if self._waiter is not None:
|
|
||||||
raise RuntimeError(
|
|
||||||
"%s() called while another coroutine is "
|
|
||||||
"already waiting for incoming data" % func_name
|
|
||||||
)
|
|
||||||
|
|
||||||
waiter = self._waiter = self._loop.create_future()
|
|
||||||
try:
|
|
||||||
with self._timer:
|
|
||||||
await waiter
|
|
||||||
finally:
|
|
||||||
self._waiter = None
|
|
||||||
|
|
||||||
async def readline(self) -> bytes:
|
|
||||||
return await self.readuntil()
|
|
||||||
|
|
||||||
async def readuntil(self, separator: bytes = b"\n") -> bytes:
|
|
||||||
seplen = len(separator)
|
|
||||||
if seplen == 0:
|
|
||||||
raise ValueError("Separator should be at least one-byte string")
|
|
||||||
|
|
||||||
if self._exception is not None:
|
|
||||||
raise self._exception
|
|
||||||
|
|
||||||
chunk = b""
|
|
||||||
chunk_size = 0
|
|
||||||
not_enough = True
|
|
||||||
|
|
||||||
while not_enough:
|
|
||||||
while self._buffer and not_enough:
|
|
||||||
offset = self._buffer_offset
|
|
||||||
ichar = self._buffer[0].find(separator, offset) + 1
|
|
||||||
# Read from current offset to found separator or to the end.
|
|
||||||
data = self._read_nowait_chunk(
|
|
||||||
ichar - offset + seplen - 1 if ichar else -1
|
|
||||||
)
|
|
||||||
chunk += data
|
|
||||||
chunk_size += len(data)
|
|
||||||
if ichar:
|
|
||||||
not_enough = False
|
|
||||||
|
|
||||||
if chunk_size > self._high_water:
|
|
||||||
raise ValueError("Chunk too big")
|
|
||||||
|
|
||||||
if self._eof:
|
|
||||||
break
|
|
||||||
|
|
||||||
if not_enough:
|
|
||||||
await self._wait("readuntil")
|
|
||||||
|
|
||||||
return chunk
|
|
||||||
|
|
||||||
async def read(self, n: int = -1) -> bytes:
|
|
||||||
if self._exception is not None:
|
|
||||||
raise self._exception
|
|
||||||
|
|
||||||
# migration problem; with DataQueue you have to catch
|
|
||||||
# EofStream exception, so common way is to run payload.read() inside
|
|
||||||
# infinite loop. what can cause real infinite loop with StreamReader
|
|
||||||
# lets keep this code one major release.
|
|
||||||
if __debug__:
|
|
||||||
if self._eof and not self._buffer:
|
|
||||||
self._eof_counter = getattr(self, "_eof_counter", 0) + 1
|
|
||||||
if self._eof_counter > 5:
|
|
||||||
internal_logger.warning(
|
|
||||||
"Multiple access to StreamReader in eof state, "
|
|
||||||
"might be infinite loop.",
|
|
||||||
stack_info=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not n:
|
|
||||||
return b""
|
|
||||||
|
|
||||||
if n < 0:
|
|
||||||
# This used to just loop creating a new waiter hoping to
|
|
||||||
# collect everything in self._buffer, but that would
|
|
||||||
# deadlock if the subprocess sends more than self.limit
|
|
||||||
# bytes. So just call self.readany() until EOF.
|
|
||||||
blocks = []
|
|
||||||
while True:
|
|
||||||
block = await self.readany()
|
|
||||||
if not block:
|
|
||||||
break
|
|
||||||
blocks.append(block)
|
|
||||||
return b"".join(blocks)
|
|
||||||
|
|
||||||
# TODO: should be `if` instead of `while`
|
|
||||||
# because waiter maybe triggered on chunk end,
|
|
||||||
# without feeding any data
|
|
||||||
while not self._buffer and not self._eof:
|
|
||||||
await self._wait("read")
|
|
||||||
|
|
||||||
return self._read_nowait(n)
|
|
||||||
|
|
||||||
async def readany(self) -> bytes:
|
|
||||||
if self._exception is not None:
|
|
||||||
raise self._exception
|
|
||||||
|
|
||||||
# TODO: should be `if` instead of `while`
|
|
||||||
# because waiter maybe triggered on chunk end,
|
|
||||||
# without feeding any data
|
|
||||||
while not self._buffer and not self._eof:
|
|
||||||
await self._wait("readany")
|
|
||||||
|
|
||||||
return self._read_nowait(-1)
|
|
||||||
|
|
||||||
async def readchunk(self) -> Tuple[bytes, bool]:
|
|
||||||
"""Returns a tuple of (data, end_of_http_chunk).
|
|
||||||
|
|
||||||
When chunked transfer
|
|
||||||
encoding is used, end_of_http_chunk is a boolean indicating if the end
|
|
||||||
of the data corresponds to the end of a HTTP chunk , otherwise it is
|
|
||||||
always False.
|
|
||||||
"""
|
|
||||||
while True:
|
|
||||||
if self._exception is not None:
|
|
||||||
raise self._exception
|
|
||||||
|
|
||||||
while self._http_chunk_splits:
|
|
||||||
pos = self._http_chunk_splits.pop(0)
|
|
||||||
if pos == self._cursor:
|
|
||||||
return (b"", True)
|
|
||||||
if pos > self._cursor:
|
|
||||||
return (self._read_nowait(pos - self._cursor), True)
|
|
||||||
internal_logger.warning(
|
|
||||||
"Skipping HTTP chunk end due to data "
|
|
||||||
"consumption beyond chunk boundary"
|
|
||||||
)
|
|
||||||
|
|
||||||
if self._buffer:
|
|
||||||
return (self._read_nowait_chunk(-1), False)
|
|
||||||
# return (self._read_nowait(-1), False)
|
|
||||||
|
|
||||||
if self._eof:
|
|
||||||
# Special case for signifying EOF.
|
|
||||||
# (b'', True) is not a final return value actually.
|
|
||||||
return (b"", False)
|
|
||||||
|
|
||||||
await self._wait("readchunk")
|
|
||||||
|
|
||||||
async def readexactly(self, n: int) -> bytes:
|
|
||||||
if self._exception is not None:
|
|
||||||
raise self._exception
|
|
||||||
|
|
||||||
blocks: List[bytes] = []
|
|
||||||
while n > 0:
|
|
||||||
block = await self.read(n)
|
|
||||||
if not block:
|
|
||||||
partial = b"".join(blocks)
|
|
||||||
raise asyncio.IncompleteReadError(partial, len(partial) + n)
|
|
||||||
blocks.append(block)
|
|
||||||
n -= len(block)
|
|
||||||
|
|
||||||
return b"".join(blocks)
|
|
||||||
|
|
||||||
def read_nowait(self, n: int = -1) -> bytes:
|
|
||||||
# default was changed to be consistent with .read(-1)
|
|
||||||
#
|
|
||||||
# I believe the most users don't know about the method and
|
|
||||||
# they are not affected.
|
|
||||||
if self._exception is not None:
|
|
||||||
raise self._exception
|
|
||||||
|
|
||||||
if self._waiter and not self._waiter.done():
|
|
||||||
raise RuntimeError(
|
|
||||||
"Called while some coroutine is waiting for incoming data."
|
|
||||||
)
|
|
||||||
|
|
||||||
return self._read_nowait(n)
|
|
||||||
|
|
||||||
def _read_nowait_chunk(self, n: int) -> bytes:
|
|
||||||
first_buffer = self._buffer[0]
|
|
||||||
offset = self._buffer_offset
|
|
||||||
if n != -1 and len(first_buffer) - offset > n:
|
|
||||||
data = first_buffer[offset : offset + n]
|
|
||||||
self._buffer_offset += n
|
|
||||||
|
|
||||||
elif offset:
|
|
||||||
self._buffer.popleft()
|
|
||||||
data = first_buffer[offset:]
|
|
||||||
self._buffer_offset = 0
|
|
||||||
|
|
||||||
else:
|
|
||||||
data = self._buffer.popleft()
|
|
||||||
|
|
||||||
data_len = len(data)
|
|
||||||
self._size -= data_len
|
|
||||||
self._cursor += data_len
|
|
||||||
|
|
||||||
chunk_splits = self._http_chunk_splits
|
|
||||||
# Prevent memory leak: drop useless chunk splits
|
|
||||||
while chunk_splits and chunk_splits[0] < self._cursor:
|
|
||||||
chunk_splits.pop(0)
|
|
||||||
|
|
||||||
if self._size < self._low_water and self._protocol._reading_paused:
|
|
||||||
self._protocol.resume_reading()
|
|
||||||
return data
|
|
||||||
|
|
||||||
def _read_nowait(self, n: int) -> bytes:
|
|
||||||
"""Read not more than n bytes, or whole buffer if n == -1"""
|
|
||||||
self._timer.assert_timeout()
|
|
||||||
|
|
||||||
chunks = []
|
|
||||||
while self._buffer:
|
|
||||||
chunk = self._read_nowait_chunk(n)
|
|
||||||
chunks.append(chunk)
|
|
||||||
if n != -1:
|
|
||||||
n -= len(chunk)
|
|
||||||
if n == 0:
|
|
||||||
break
|
|
||||||
|
|
||||||
return b"".join(chunks) if chunks else b""
|
|
||||||
|
|
||||||
|
|
||||||
class EmptyStreamReader(StreamReader): # lgtm [py/missing-call-to-init]
|
|
||||||
|
|
||||||
__slots__ = ("_read_eof_chunk",)
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self._read_eof_chunk = False
|
|
||||||
self.total_bytes = 0
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return "<%s>" % self.__class__.__name__
|
|
||||||
|
|
||||||
def exception(self) -> Optional[BaseException]:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def set_exception(
|
|
||||||
self,
|
|
||||||
exc: BaseException,
|
|
||||||
exc_cause: BaseException = _EXC_SENTINEL,
|
|
||||||
) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def on_eof(self, callback: Callable[[], None]) -> None:
|
|
||||||
try:
|
|
||||||
callback()
|
|
||||||
except Exception:
|
|
||||||
internal_logger.exception("Exception in eof callback")
|
|
||||||
|
|
||||||
def feed_eof(self) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def is_eof(self) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
def at_eof(self) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def wait_eof(self) -> None:
|
|
||||||
return
|
|
||||||
|
|
||||||
def feed_data(self, data: bytes, n: int = 0) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def readline(self) -> bytes:
|
|
||||||
return b""
|
|
||||||
|
|
||||||
async def read(self, n: int = -1) -> bytes:
|
|
||||||
return b""
|
|
||||||
|
|
||||||
# TODO add async def readuntil
|
|
||||||
|
|
||||||
async def readany(self) -> bytes:
|
|
||||||
return b""
|
|
||||||
|
|
||||||
async def readchunk(self) -> Tuple[bytes, bool]:
|
|
||||||
if not self._read_eof_chunk:
|
|
||||||
self._read_eof_chunk = True
|
|
||||||
return (b"", False)
|
|
||||||
|
|
||||||
return (b"", True)
|
|
||||||
|
|
||||||
async def readexactly(self, n: int) -> bytes:
|
|
||||||
raise asyncio.IncompleteReadError(b"", n)
|
|
||||||
|
|
||||||
def read_nowait(self, n: int = -1) -> bytes:
|
|
||||||
return b""
|
|
||||||
|
|
||||||
|
|
||||||
EMPTY_PAYLOAD: Final[StreamReader] = EmptyStreamReader()
|
|
||||||
|
|
||||||
|
|
||||||
class DataQueue(Generic[_T]):
|
|
||||||
"""DataQueue is a general-purpose blocking queue with one reader."""
|
|
||||||
|
|
||||||
def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
|
|
||||||
self._loop = loop
|
|
||||||
self._eof = False
|
|
||||||
self._waiter: Optional[asyncio.Future[None]] = None
|
|
||||||
self._exception: Optional[BaseException] = None
|
|
||||||
self._buffer: Deque[Tuple[_T, int]] = collections.deque()
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
return len(self._buffer)
|
|
||||||
|
|
||||||
def is_eof(self) -> bool:
|
|
||||||
return self._eof
|
|
||||||
|
|
||||||
def at_eof(self) -> bool:
|
|
||||||
return self._eof and not self._buffer
|
|
||||||
|
|
||||||
def exception(self) -> Optional[BaseException]:
|
|
||||||
return self._exception
|
|
||||||
|
|
||||||
def set_exception(
|
|
||||||
self,
|
|
||||||
exc: BaseException,
|
|
||||||
exc_cause: BaseException = _EXC_SENTINEL,
|
|
||||||
) -> None:
|
|
||||||
self._eof = True
|
|
||||||
self._exception = exc
|
|
||||||
if (waiter := self._waiter) is not None:
|
|
||||||
self._waiter = None
|
|
||||||
set_exception(waiter, exc, exc_cause)
|
|
||||||
|
|
||||||
def feed_data(self, data: _T, size: int = 0) -> None:
|
|
||||||
self._buffer.append((data, size))
|
|
||||||
if (waiter := self._waiter) is not None:
|
|
||||||
self._waiter = None
|
|
||||||
set_result(waiter, None)
|
|
||||||
|
|
||||||
def feed_eof(self) -> None:
|
|
||||||
self._eof = True
|
|
||||||
if (waiter := self._waiter) is not None:
|
|
||||||
self._waiter = None
|
|
||||||
set_result(waiter, None)
|
|
||||||
|
|
||||||
async def read(self) -> _T:
|
|
||||||
if not self._buffer and not self._eof:
|
|
||||||
assert not self._waiter
|
|
||||||
self._waiter = self._loop.create_future()
|
|
||||||
try:
|
|
||||||
await self._waiter
|
|
||||||
except (asyncio.CancelledError, asyncio.TimeoutError):
|
|
||||||
self._waiter = None
|
|
||||||
raise
|
|
||||||
if self._buffer:
|
|
||||||
data, _ = self._buffer.popleft()
|
|
||||||
return data
|
|
||||||
if self._exception is not None:
|
|
||||||
raise self._exception
|
|
||||||
raise EofStream
|
|
||||||
|
|
||||||
def __aiter__(self) -> AsyncStreamIterator[_T]:
|
|
||||||
return AsyncStreamIterator(self.read)
|
|
||||||
|
|
||||||
|
|
||||||
class FlowControlDataQueue(DataQueue[_T]):
|
|
||||||
"""FlowControlDataQueue resumes and pauses an underlying stream.
|
|
||||||
|
|
||||||
It is a destination for parsed data.
|
|
||||||
|
|
||||||
This class is deprecated and will be removed in version 4.0.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop
|
|
||||||
) -> None:
|
|
||||||
super().__init__(loop=loop)
|
|
||||||
self._size = 0
|
|
||||||
self._protocol = protocol
|
|
||||||
self._limit = limit * 2
|
|
||||||
|
|
||||||
def feed_data(self, data: _T, size: int = 0) -> None:
|
|
||||||
super().feed_data(data, size)
|
|
||||||
self._size += size
|
|
||||||
|
|
||||||
if self._size > self._limit and not self._protocol._reading_paused:
|
|
||||||
self._protocol.pause_reading()
|
|
||||||
|
|
||||||
async def read(self) -> _T:
|
|
||||||
if not self._buffer and not self._eof:
|
|
||||||
assert not self._waiter
|
|
||||||
self._waiter = self._loop.create_future()
|
|
||||||
try:
|
|
||||||
await self._waiter
|
|
||||||
except (asyncio.CancelledError, asyncio.TimeoutError):
|
|
||||||
self._waiter = None
|
|
||||||
raise
|
|
||||||
if self._buffer:
|
|
||||||
data, size = self._buffer.popleft()
|
|
||||||
self._size -= size
|
|
||||||
if self._size < self._limit and self._protocol._reading_paused:
|
|
||||||
self._protocol.resume_reading()
|
|
||||||
return data
|
|
||||||
if self._exception is not None:
|
|
||||||
raise self._exception
|
|
||||||
raise EofStream
|
|
||||||
|
|
@ -1,37 +0,0 @@
|
||||||
"""Helper methods to tune a TCP connection"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import socket
|
|
||||||
from contextlib import suppress
|
|
||||||
from typing import Optional # noqa
|
|
||||||
|
|
||||||
__all__ = ("tcp_keepalive", "tcp_nodelay")
|
|
||||||
|
|
||||||
|
|
||||||
if hasattr(socket, "SO_KEEPALIVE"):
|
|
||||||
|
|
||||||
def tcp_keepalive(transport: asyncio.Transport) -> None:
|
|
||||||
sock = transport.get_extra_info("socket")
|
|
||||||
if sock is not None:
|
|
||||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
|
|
||||||
|
|
||||||
else:
|
|
||||||
|
|
||||||
def tcp_keepalive(transport: asyncio.Transport) -> None: # pragma: no cover
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def tcp_nodelay(transport: asyncio.Transport, value: bool) -> None:
|
|
||||||
sock = transport.get_extra_info("socket")
|
|
||||||
|
|
||||||
if sock is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
if sock.family not in (socket.AF_INET, socket.AF_INET6):
|
|
||||||
return
|
|
||||||
|
|
||||||
value = bool(value)
|
|
||||||
|
|
||||||
# socket may be closed already, on windows OSError get raised
|
|
||||||
with suppress(OSError):
|
|
||||||
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, value)
|
|
||||||
|
|
@ -1,774 +0,0 @@
|
||||||
"""Utilities shared by tests."""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import contextlib
|
|
||||||
import gc
|
|
||||||
import inspect
|
|
||||||
import ipaddress
|
|
||||||
import os
|
|
||||||
import socket
|
|
||||||
import sys
|
|
||||||
import warnings
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from types import TracebackType
|
|
||||||
from typing import (
|
|
||||||
TYPE_CHECKING,
|
|
||||||
Any,
|
|
||||||
Callable,
|
|
||||||
Generic,
|
|
||||||
Iterator,
|
|
||||||
List,
|
|
||||||
Optional,
|
|
||||||
Type,
|
|
||||||
TypeVar,
|
|
||||||
cast,
|
|
||||||
overload,
|
|
||||||
)
|
|
||||||
from unittest import IsolatedAsyncioTestCase, mock
|
|
||||||
|
|
||||||
from aiosignal import Signal
|
|
||||||
from multidict import CIMultiDict, CIMultiDictProxy
|
|
||||||
from yarl import URL
|
|
||||||
|
|
||||||
import aiohttp
|
|
||||||
from aiohttp.client import (
|
|
||||||
_RequestContextManager,
|
|
||||||
_RequestOptions,
|
|
||||||
_WSRequestContextManager,
|
|
||||||
)
|
|
||||||
|
|
||||||
from . import ClientSession, hdrs
|
|
||||||
from .abc import AbstractCookieJar
|
|
||||||
from .client_reqrep import ClientResponse
|
|
||||||
from .client_ws import ClientWebSocketResponse
|
|
||||||
from .helpers import sentinel
|
|
||||||
from .http import HttpVersion, RawRequestMessage
|
|
||||||
from .streams import EMPTY_PAYLOAD, StreamReader
|
|
||||||
from .typedefs import StrOrURL
|
|
||||||
from .web import (
|
|
||||||
Application,
|
|
||||||
AppRunner,
|
|
||||||
BaseRequest,
|
|
||||||
BaseRunner,
|
|
||||||
Request,
|
|
||||||
Server,
|
|
||||||
ServerRunner,
|
|
||||||
SockSite,
|
|
||||||
UrlMappingMatchInfo,
|
|
||||||
)
|
|
||||||
from .web_protocol import _RequestHandler
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from ssl import SSLContext
|
|
||||||
else:
|
|
||||||
SSLContext = None
|
|
||||||
|
|
||||||
if sys.version_info >= (3, 11) and TYPE_CHECKING:
|
|
||||||
from typing import Unpack
|
|
||||||
|
|
||||||
if sys.version_info >= (3, 11):
|
|
||||||
from typing import Self
|
|
||||||
else:
|
|
||||||
Self = Any
|
|
||||||
|
|
||||||
_ApplicationNone = TypeVar("_ApplicationNone", Application, None)
|
|
||||||
_Request = TypeVar("_Request", bound=BaseRequest)
|
|
||||||
|
|
||||||
REUSE_ADDRESS = os.name == "posix" and sys.platform != "cygwin"
|
|
||||||
|
|
||||||
|
|
||||||
def get_unused_port_socket(
|
|
||||||
host: str, family: socket.AddressFamily = socket.AF_INET
|
|
||||||
) -> socket.socket:
|
|
||||||
return get_port_socket(host, 0, family)
|
|
||||||
|
|
||||||
|
|
||||||
def get_port_socket(
|
|
||||||
host: str, port: int, family: socket.AddressFamily
|
|
||||||
) -> socket.socket:
|
|
||||||
s = socket.socket(family, socket.SOCK_STREAM)
|
|
||||||
if REUSE_ADDRESS:
|
|
||||||
# Windows has different semantics for SO_REUSEADDR,
|
|
||||||
# so don't set it. Ref:
|
|
||||||
# https://docs.microsoft.com/en-us/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse
|
|
||||||
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
|
||||||
s.bind((host, port))
|
|
||||||
return s
|
|
||||||
|
|
||||||
|
|
||||||
def unused_port() -> int:
|
|
||||||
"""Return a port that is unused on the current host."""
|
|
||||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
|
||||||
s.bind(("127.0.0.1", 0))
|
|
||||||
return cast(int, s.getsockname()[1])
|
|
||||||
|
|
||||||
|
|
||||||
class BaseTestServer(ABC):
|
|
||||||
__test__ = False
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
scheme: str = "",
|
|
||||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
|
||||||
host: str = "127.0.0.1",
|
|
||||||
port: Optional[int] = None,
|
|
||||||
skip_url_asserts: bool = False,
|
|
||||||
socket_factory: Callable[
|
|
||||||
[str, int, socket.AddressFamily], socket.socket
|
|
||||||
] = get_port_socket,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> None:
|
|
||||||
self._loop = loop
|
|
||||||
self.runner: Optional[BaseRunner] = None
|
|
||||||
self._root: Optional[URL] = None
|
|
||||||
self.host = host
|
|
||||||
self.port = port
|
|
||||||
self._closed = False
|
|
||||||
self.scheme = scheme
|
|
||||||
self.skip_url_asserts = skip_url_asserts
|
|
||||||
self.socket_factory = socket_factory
|
|
||||||
|
|
||||||
async def start_server(
|
|
||||||
self, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
if self.runner:
|
|
||||||
return
|
|
||||||
self._loop = loop
|
|
||||||
self._ssl = kwargs.pop("ssl", None)
|
|
||||||
self.runner = await self._make_runner(handler_cancellation=True, **kwargs)
|
|
||||||
await self.runner.setup()
|
|
||||||
if not self.port:
|
|
||||||
self.port = 0
|
|
||||||
absolute_host = self.host
|
|
||||||
try:
|
|
||||||
version = ipaddress.ip_address(self.host).version
|
|
||||||
except ValueError:
|
|
||||||
version = 4
|
|
||||||
if version == 6:
|
|
||||||
absolute_host = f"[{self.host}]"
|
|
||||||
family = socket.AF_INET6 if version == 6 else socket.AF_INET
|
|
||||||
_sock = self.socket_factory(self.host, self.port, family)
|
|
||||||
self.host, self.port = _sock.getsockname()[:2]
|
|
||||||
site = SockSite(self.runner, sock=_sock, ssl_context=self._ssl)
|
|
||||||
await site.start()
|
|
||||||
server = site._server
|
|
||||||
assert server is not None
|
|
||||||
sockets = server.sockets # type: ignore[attr-defined]
|
|
||||||
assert sockets is not None
|
|
||||||
self.port = sockets[0].getsockname()[1]
|
|
||||||
if not self.scheme:
|
|
||||||
self.scheme = "https" if self._ssl else "http"
|
|
||||||
self._root = URL(f"{self.scheme}://{absolute_host}:{self.port}")
|
|
||||||
|
|
||||||
@abstractmethod # pragma: no cover
|
|
||||||
async def _make_runner(self, **kwargs: Any) -> BaseRunner:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def make_url(self, path: StrOrURL) -> URL:
|
|
||||||
assert self._root is not None
|
|
||||||
url = URL(path)
|
|
||||||
if not self.skip_url_asserts:
|
|
||||||
assert not url.absolute
|
|
||||||
return self._root.join(url)
|
|
||||||
else:
|
|
||||||
return URL(str(self._root) + str(path))
|
|
||||||
|
|
||||||
@property
|
|
||||||
def started(self) -> bool:
|
|
||||||
return self.runner is not None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def closed(self) -> bool:
|
|
||||||
return self._closed
|
|
||||||
|
|
||||||
@property
|
|
||||||
def handler(self) -> Server:
|
|
||||||
# for backward compatibility
|
|
||||||
# web.Server instance
|
|
||||||
runner = self.runner
|
|
||||||
assert runner is not None
|
|
||||||
assert runner.server is not None
|
|
||||||
return runner.server
|
|
||||||
|
|
||||||
async def close(self) -> None:
|
|
||||||
"""Close all fixtures created by the test client.
|
|
||||||
|
|
||||||
After that point, the TestClient is no longer usable.
|
|
||||||
|
|
||||||
This is an idempotent function: running close multiple times
|
|
||||||
will not have any additional effects.
|
|
||||||
|
|
||||||
close is also run when the object is garbage collected, and on
|
|
||||||
exit when used as a context manager.
|
|
||||||
|
|
||||||
"""
|
|
||||||
if self.started and not self.closed:
|
|
||||||
assert self.runner is not None
|
|
||||||
await self.runner.cleanup()
|
|
||||||
self._root = None
|
|
||||||
self.port = None
|
|
||||||
self._closed = True
|
|
||||||
|
|
||||||
def __enter__(self) -> None:
|
|
||||||
raise TypeError("Use async with instead")
|
|
||||||
|
|
||||||
def __exit__(
|
|
||||||
self,
|
|
||||||
exc_type: Optional[Type[BaseException]],
|
|
||||||
exc_value: Optional[BaseException],
|
|
||||||
traceback: Optional[TracebackType],
|
|
||||||
) -> None:
|
|
||||||
# __exit__ should exist in pair with __enter__ but never executed
|
|
||||||
pass # pragma: no cover
|
|
||||||
|
|
||||||
async def __aenter__(self) -> "BaseTestServer":
|
|
||||||
await self.start_server(loop=self._loop)
|
|
||||||
return self
|
|
||||||
|
|
||||||
async def __aexit__(
|
|
||||||
self,
|
|
||||||
exc_type: Optional[Type[BaseException]],
|
|
||||||
exc_value: Optional[BaseException],
|
|
||||||
traceback: Optional[TracebackType],
|
|
||||||
) -> None:
|
|
||||||
await self.close()
|
|
||||||
|
|
||||||
|
|
||||||
class TestServer(BaseTestServer):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
app: Application,
|
|
||||||
*,
|
|
||||||
scheme: str = "",
|
|
||||||
host: str = "127.0.0.1",
|
|
||||||
port: Optional[int] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
):
|
|
||||||
self.app = app
|
|
||||||
super().__init__(scheme=scheme, host=host, port=port, **kwargs)
|
|
||||||
|
|
||||||
async def _make_runner(self, **kwargs: Any) -> BaseRunner:
|
|
||||||
return AppRunner(self.app, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class RawTestServer(BaseTestServer):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
handler: _RequestHandler,
|
|
||||||
*,
|
|
||||||
scheme: str = "",
|
|
||||||
host: str = "127.0.0.1",
|
|
||||||
port: Optional[int] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> None:
|
|
||||||
self._handler = handler
|
|
||||||
super().__init__(scheme=scheme, host=host, port=port, **kwargs)
|
|
||||||
|
|
||||||
async def _make_runner(self, debug: bool = True, **kwargs: Any) -> ServerRunner:
|
|
||||||
srv = Server(self._handler, loop=self._loop, debug=debug, **kwargs)
|
|
||||||
return ServerRunner(srv, debug=debug, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class TestClient(Generic[_Request, _ApplicationNone]):
|
|
||||||
"""
|
|
||||||
A test client implementation.
|
|
||||||
|
|
||||||
To write functional tests for aiohttp based servers.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
__test__ = False
|
|
||||||
|
|
||||||
@overload
|
|
||||||
def __init__(
|
|
||||||
self: "TestClient[Request, Application]",
|
|
||||||
server: TestServer,
|
|
||||||
*,
|
|
||||||
cookie_jar: Optional[AbstractCookieJar] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> None: ...
|
|
||||||
@overload
|
|
||||||
def __init__(
|
|
||||||
self: "TestClient[_Request, None]",
|
|
||||||
server: BaseTestServer,
|
|
||||||
*,
|
|
||||||
cookie_jar: Optional[AbstractCookieJar] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> None: ...
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
server: BaseTestServer,
|
|
||||||
*,
|
|
||||||
cookie_jar: Optional[AbstractCookieJar] = None,
|
|
||||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> None:
|
|
||||||
if not isinstance(server, BaseTestServer):
|
|
||||||
raise TypeError(
|
|
||||||
"server must be TestServer instance, found type: %r" % type(server)
|
|
||||||
)
|
|
||||||
self._server = server
|
|
||||||
self._loop = loop
|
|
||||||
if cookie_jar is None:
|
|
||||||
cookie_jar = aiohttp.CookieJar(unsafe=True, loop=loop)
|
|
||||||
self._session = ClientSession(loop=loop, cookie_jar=cookie_jar, **kwargs)
|
|
||||||
self._session._retry_connection = False
|
|
||||||
self._closed = False
|
|
||||||
self._responses: List[ClientResponse] = []
|
|
||||||
self._websockets: List[ClientWebSocketResponse] = []
|
|
||||||
|
|
||||||
async def start_server(self) -> None:
|
|
||||||
await self._server.start_server(loop=self._loop)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def host(self) -> str:
|
|
||||||
return self._server.host
|
|
||||||
|
|
||||||
@property
|
|
||||||
def port(self) -> Optional[int]:
|
|
||||||
return self._server.port
|
|
||||||
|
|
||||||
@property
|
|
||||||
def server(self) -> BaseTestServer:
|
|
||||||
return self._server
|
|
||||||
|
|
||||||
@property
|
|
||||||
def app(self) -> _ApplicationNone:
|
|
||||||
return getattr(self._server, "app", None) # type: ignore[return-value]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def session(self) -> ClientSession:
|
|
||||||
"""An internal aiohttp.ClientSession.
|
|
||||||
|
|
||||||
Unlike the methods on the TestClient, client session requests
|
|
||||||
do not automatically include the host in the url queried, and
|
|
||||||
will require an absolute path to the resource.
|
|
||||||
|
|
||||||
"""
|
|
||||||
return self._session
|
|
||||||
|
|
||||||
def make_url(self, path: StrOrURL) -> URL:
|
|
||||||
return self._server.make_url(path)
|
|
||||||
|
|
||||||
async def _request(
|
|
||||||
self, method: str, path: StrOrURL, **kwargs: Any
|
|
||||||
) -> ClientResponse:
|
|
||||||
resp = await self._session.request(method, self.make_url(path), **kwargs)
|
|
||||||
# save it to close later
|
|
||||||
self._responses.append(resp)
|
|
||||||
return resp
|
|
||||||
|
|
||||||
if sys.version_info >= (3, 11) and TYPE_CHECKING:
|
|
||||||
|
|
||||||
def request(
|
|
||||||
self, method: str, path: StrOrURL, **kwargs: Unpack[_RequestOptions]
|
|
||||||
) -> _RequestContextManager: ...
|
|
||||||
|
|
||||||
def get(
|
|
||||||
self,
|
|
||||||
path: StrOrURL,
|
|
||||||
**kwargs: Unpack[_RequestOptions],
|
|
||||||
) -> _RequestContextManager: ...
|
|
||||||
|
|
||||||
def options(
|
|
||||||
self,
|
|
||||||
path: StrOrURL,
|
|
||||||
**kwargs: Unpack[_RequestOptions],
|
|
||||||
) -> _RequestContextManager: ...
|
|
||||||
|
|
||||||
def head(
|
|
||||||
self,
|
|
||||||
path: StrOrURL,
|
|
||||||
**kwargs: Unpack[_RequestOptions],
|
|
||||||
) -> _RequestContextManager: ...
|
|
||||||
|
|
||||||
def post(
|
|
||||||
self,
|
|
||||||
path: StrOrURL,
|
|
||||||
**kwargs: Unpack[_RequestOptions],
|
|
||||||
) -> _RequestContextManager: ...
|
|
||||||
|
|
||||||
def put(
|
|
||||||
self,
|
|
||||||
path: StrOrURL,
|
|
||||||
**kwargs: Unpack[_RequestOptions],
|
|
||||||
) -> _RequestContextManager: ...
|
|
||||||
|
|
||||||
def patch(
|
|
||||||
self,
|
|
||||||
path: StrOrURL,
|
|
||||||
**kwargs: Unpack[_RequestOptions],
|
|
||||||
) -> _RequestContextManager: ...
|
|
||||||
|
|
||||||
def delete(
|
|
||||||
self,
|
|
||||||
path: StrOrURL,
|
|
||||||
**kwargs: Unpack[_RequestOptions],
|
|
||||||
) -> _RequestContextManager: ...
|
|
||||||
|
|
||||||
else:
|
|
||||||
|
|
||||||
def request(
|
|
||||||
self, method: str, path: StrOrURL, **kwargs: Any
|
|
||||||
) -> _RequestContextManager:
|
|
||||||
"""Routes a request to tested http server.
|
|
||||||
|
|
||||||
The interface is identical to aiohttp.ClientSession.request,
|
|
||||||
except the loop kwarg is overridden by the instance used by the
|
|
||||||
test server.
|
|
||||||
|
|
||||||
"""
|
|
||||||
return _RequestContextManager(self._request(method, path, **kwargs))
|
|
||||||
|
|
||||||
def get(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
|
|
||||||
"""Perform an HTTP GET request."""
|
|
||||||
return _RequestContextManager(self._request(hdrs.METH_GET, path, **kwargs))
|
|
||||||
|
|
||||||
def post(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
|
|
||||||
"""Perform an HTTP POST request."""
|
|
||||||
return _RequestContextManager(self._request(hdrs.METH_POST, path, **kwargs))
|
|
||||||
|
|
||||||
def options(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
|
|
||||||
"""Perform an HTTP OPTIONS request."""
|
|
||||||
return _RequestContextManager(
|
|
||||||
self._request(hdrs.METH_OPTIONS, path, **kwargs)
|
|
||||||
)
|
|
||||||
|
|
||||||
def head(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
|
|
||||||
"""Perform an HTTP HEAD request."""
|
|
||||||
return _RequestContextManager(self._request(hdrs.METH_HEAD, path, **kwargs))
|
|
||||||
|
|
||||||
def put(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
|
|
||||||
"""Perform an HTTP PUT request."""
|
|
||||||
return _RequestContextManager(self._request(hdrs.METH_PUT, path, **kwargs))
|
|
||||||
|
|
||||||
def patch(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
|
|
||||||
"""Perform an HTTP PATCH request."""
|
|
||||||
return _RequestContextManager(
|
|
||||||
self._request(hdrs.METH_PATCH, path, **kwargs)
|
|
||||||
)
|
|
||||||
|
|
||||||
def delete(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
|
|
||||||
"""Perform an HTTP PATCH request."""
|
|
||||||
return _RequestContextManager(
|
|
||||||
self._request(hdrs.METH_DELETE, path, **kwargs)
|
|
||||||
)
|
|
||||||
|
|
||||||
def ws_connect(self, path: StrOrURL, **kwargs: Any) -> _WSRequestContextManager:
|
|
||||||
"""Initiate websocket connection.
|
|
||||||
|
|
||||||
The api corresponds to aiohttp.ClientSession.ws_connect.
|
|
||||||
|
|
||||||
"""
|
|
||||||
return _WSRequestContextManager(self._ws_connect(path, **kwargs))
|
|
||||||
|
|
||||||
async def _ws_connect(
|
|
||||||
self, path: StrOrURL, **kwargs: Any
|
|
||||||
) -> ClientWebSocketResponse:
|
|
||||||
ws = await self._session.ws_connect(self.make_url(path), **kwargs)
|
|
||||||
self._websockets.append(ws)
|
|
||||||
return ws
|
|
||||||
|
|
||||||
async def close(self) -> None:
|
|
||||||
"""Close all fixtures created by the test client.
|
|
||||||
|
|
||||||
After that point, the TestClient is no longer usable.
|
|
||||||
|
|
||||||
This is an idempotent function: running close multiple times
|
|
||||||
will not have any additional effects.
|
|
||||||
|
|
||||||
close is also run on exit when used as a(n) (asynchronous)
|
|
||||||
context manager.
|
|
||||||
|
|
||||||
"""
|
|
||||||
if not self._closed:
|
|
||||||
for resp in self._responses:
|
|
||||||
resp.close()
|
|
||||||
for ws in self._websockets:
|
|
||||||
await ws.close()
|
|
||||||
await self._session.close()
|
|
||||||
await self._server.close()
|
|
||||||
self._closed = True
|
|
||||||
|
|
||||||
def __enter__(self) -> None:
|
|
||||||
raise TypeError("Use async with instead")
|
|
||||||
|
|
||||||
def __exit__(
|
|
||||||
self,
|
|
||||||
exc_type: Optional[Type[BaseException]],
|
|
||||||
exc: Optional[BaseException],
|
|
||||||
tb: Optional[TracebackType],
|
|
||||||
) -> None:
|
|
||||||
# __exit__ should exist in pair with __enter__ but never executed
|
|
||||||
pass # pragma: no cover
|
|
||||||
|
|
||||||
async def __aenter__(self) -> Self:
|
|
||||||
await self.start_server()
|
|
||||||
return self
|
|
||||||
|
|
||||||
async def __aexit__(
|
|
||||||
self,
|
|
||||||
exc_type: Optional[Type[BaseException]],
|
|
||||||
exc: Optional[BaseException],
|
|
||||||
tb: Optional[TracebackType],
|
|
||||||
) -> None:
|
|
||||||
await self.close()
|
|
||||||
|
|
||||||
|
|
||||||
class AioHTTPTestCase(IsolatedAsyncioTestCase):
|
|
||||||
"""A base class to allow for unittest web applications using aiohttp.
|
|
||||||
|
|
||||||
Provides the following:
|
|
||||||
|
|
||||||
* self.client (aiohttp.test_utils.TestClient): an aiohttp test client.
|
|
||||||
* self.loop (asyncio.BaseEventLoop): the event loop in which the
|
|
||||||
application and server are running.
|
|
||||||
* self.app (aiohttp.web.Application): the application returned by
|
|
||||||
self.get_application()
|
|
||||||
|
|
||||||
Note that the TestClient's methods are asynchronous: you have to
|
|
||||||
execute function on the test client using asynchronous methods.
|
|
||||||
"""
|
|
||||||
|
|
||||||
async def get_application(self) -> Application:
|
|
||||||
"""Get application.
|
|
||||||
|
|
||||||
This method should be overridden
|
|
||||||
to return the aiohttp.web.Application
|
|
||||||
object to test.
|
|
||||||
"""
|
|
||||||
return self.get_app()
|
|
||||||
|
|
||||||
def get_app(self) -> Application:
|
|
||||||
"""Obsolete method used to constructing web application.
|
|
||||||
|
|
||||||
Use .get_application() coroutine instead.
|
|
||||||
"""
|
|
||||||
raise RuntimeError("Did you forget to define get_application()?")
|
|
||||||
|
|
||||||
async def asyncSetUp(self) -> None:
|
|
||||||
self.loop = asyncio.get_running_loop()
|
|
||||||
return await self.setUpAsync()
|
|
||||||
|
|
||||||
async def setUpAsync(self) -> None:
|
|
||||||
self.app = await self.get_application()
|
|
||||||
self.server = await self.get_server(self.app)
|
|
||||||
self.client = await self.get_client(self.server)
|
|
||||||
|
|
||||||
await self.client.start_server()
|
|
||||||
|
|
||||||
async def asyncTearDown(self) -> None:
|
|
||||||
return await self.tearDownAsync()
|
|
||||||
|
|
||||||
async def tearDownAsync(self) -> None:
|
|
||||||
await self.client.close()
|
|
||||||
|
|
||||||
async def get_server(self, app: Application) -> TestServer:
|
|
||||||
"""Return a TestServer instance."""
|
|
||||||
return TestServer(app, loop=self.loop)
|
|
||||||
|
|
||||||
async def get_client(self, server: TestServer) -> TestClient[Request, Application]:
|
|
||||||
"""Return a TestClient instance."""
|
|
||||||
return TestClient(server, loop=self.loop)
|
|
||||||
|
|
||||||
|
|
||||||
def unittest_run_loop(func: Any, *args: Any, **kwargs: Any) -> Any:
|
|
||||||
"""
|
|
||||||
A decorator dedicated to use with asynchronous AioHTTPTestCase test methods.
|
|
||||||
|
|
||||||
In 3.8+, this does nothing.
|
|
||||||
"""
|
|
||||||
warnings.warn(
|
|
||||||
"Decorator `@unittest_run_loop` is no longer needed in aiohttp 3.8+",
|
|
||||||
DeprecationWarning,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
return func
|
|
||||||
|
|
||||||
|
|
||||||
_LOOP_FACTORY = Callable[[], asyncio.AbstractEventLoop]
|
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
|
||||||
def loop_context(
|
|
||||||
loop_factory: _LOOP_FACTORY = asyncio.new_event_loop, fast: bool = False
|
|
||||||
) -> Iterator[asyncio.AbstractEventLoop]:
|
|
||||||
"""A contextmanager that creates an event_loop, for test purposes.
|
|
||||||
|
|
||||||
Handles the creation and cleanup of a test loop.
|
|
||||||
"""
|
|
||||||
loop = setup_test_loop(loop_factory)
|
|
||||||
yield loop
|
|
||||||
teardown_test_loop(loop, fast=fast)
|
|
||||||
|
|
||||||
|
|
||||||
def setup_test_loop(
|
|
||||||
loop_factory: _LOOP_FACTORY = asyncio.new_event_loop,
|
|
||||||
) -> asyncio.AbstractEventLoop:
|
|
||||||
"""Create and return an asyncio.BaseEventLoop instance.
|
|
||||||
|
|
||||||
The caller should also call teardown_test_loop,
|
|
||||||
once they are done with the loop.
|
|
||||||
"""
|
|
||||||
loop = loop_factory()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
return loop
|
|
||||||
|
|
||||||
|
|
||||||
def teardown_test_loop(loop: asyncio.AbstractEventLoop, fast: bool = False) -> None:
|
|
||||||
"""Teardown and cleanup an event_loop created by setup_test_loop."""
|
|
||||||
closed = loop.is_closed()
|
|
||||||
if not closed:
|
|
||||||
loop.call_soon(loop.stop)
|
|
||||||
loop.run_forever()
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
if not fast:
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
asyncio.set_event_loop(None)
|
|
||||||
|
|
||||||
|
|
||||||
def _create_app_mock() -> mock.MagicMock:
|
|
||||||
def get_dict(app: Any, key: str) -> Any:
|
|
||||||
return app.__app_dict[key]
|
|
||||||
|
|
||||||
def set_dict(app: Any, key: str, value: Any) -> None:
|
|
||||||
app.__app_dict[key] = value
|
|
||||||
|
|
||||||
app = mock.MagicMock(spec=Application)
|
|
||||||
app.__app_dict = {}
|
|
||||||
app.__getitem__ = get_dict
|
|
||||||
app.__setitem__ = set_dict
|
|
||||||
|
|
||||||
app._debug = False
|
|
||||||
app.on_response_prepare = Signal(app)
|
|
||||||
app.on_response_prepare.freeze()
|
|
||||||
return app
|
|
||||||
|
|
||||||
|
|
||||||
def _create_transport(sslcontext: Optional[SSLContext] = None) -> mock.Mock:
|
|
||||||
transport = mock.Mock()
|
|
||||||
|
|
||||||
def get_extra_info(key: str) -> Optional[SSLContext]:
|
|
||||||
if key == "sslcontext":
|
|
||||||
return sslcontext
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
transport.get_extra_info.side_effect = get_extra_info
|
|
||||||
return transport
|
|
||||||
|
|
||||||
|
|
||||||
def make_mocked_request(
|
|
||||||
method: str,
|
|
||||||
path: str,
|
|
||||||
headers: Any = None,
|
|
||||||
*,
|
|
||||||
match_info: Any = sentinel,
|
|
||||||
version: HttpVersion = HttpVersion(1, 1),
|
|
||||||
closing: bool = False,
|
|
||||||
app: Any = None,
|
|
||||||
writer: Any = sentinel,
|
|
||||||
protocol: Any = sentinel,
|
|
||||||
transport: Any = sentinel,
|
|
||||||
payload: StreamReader = EMPTY_PAYLOAD,
|
|
||||||
sslcontext: Optional[SSLContext] = None,
|
|
||||||
client_max_size: int = 1024**2,
|
|
||||||
loop: Any = ...,
|
|
||||||
) -> Request:
|
|
||||||
"""Creates mocked web.Request testing purposes.
|
|
||||||
|
|
||||||
Useful in unit tests, when spinning full web server is overkill or
|
|
||||||
specific conditions and errors are hard to trigger.
|
|
||||||
"""
|
|
||||||
task = mock.Mock()
|
|
||||||
if loop is ...:
|
|
||||||
# no loop passed, try to get the current one if
|
|
||||||
# its is running as we need a real loop to create
|
|
||||||
# executor jobs to be able to do testing
|
|
||||||
# with a real executor
|
|
||||||
try:
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
except RuntimeError:
|
|
||||||
loop = mock.Mock()
|
|
||||||
loop.create_future.return_value = ()
|
|
||||||
|
|
||||||
if version < HttpVersion(1, 1):
|
|
||||||
closing = True
|
|
||||||
|
|
||||||
if headers:
|
|
||||||
headers = CIMultiDictProxy(CIMultiDict(headers))
|
|
||||||
raw_hdrs = tuple(
|
|
||||||
(k.encode("utf-8"), v.encode("utf-8")) for k, v in headers.items()
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
headers = CIMultiDictProxy(CIMultiDict())
|
|
||||||
raw_hdrs = ()
|
|
||||||
|
|
||||||
chunked = "chunked" in headers.get(hdrs.TRANSFER_ENCODING, "").lower()
|
|
||||||
|
|
||||||
message = RawRequestMessage(
|
|
||||||
method,
|
|
||||||
path,
|
|
||||||
version,
|
|
||||||
headers,
|
|
||||||
raw_hdrs,
|
|
||||||
closing,
|
|
||||||
None,
|
|
||||||
False,
|
|
||||||
chunked,
|
|
||||||
URL(path),
|
|
||||||
)
|
|
||||||
if app is None:
|
|
||||||
app = _create_app_mock()
|
|
||||||
|
|
||||||
if transport is sentinel:
|
|
||||||
transport = _create_transport(sslcontext)
|
|
||||||
|
|
||||||
if protocol is sentinel:
|
|
||||||
protocol = mock.Mock()
|
|
||||||
protocol.transport = transport
|
|
||||||
type(protocol).peername = mock.PropertyMock(
|
|
||||||
return_value=transport.get_extra_info("peername")
|
|
||||||
)
|
|
||||||
type(protocol).ssl_context = mock.PropertyMock(return_value=sslcontext)
|
|
||||||
|
|
||||||
if writer is sentinel:
|
|
||||||
writer = mock.Mock()
|
|
||||||
writer.write_headers = make_mocked_coro(None)
|
|
||||||
writer.write = make_mocked_coro(None)
|
|
||||||
writer.write_eof = make_mocked_coro(None)
|
|
||||||
writer.drain = make_mocked_coro(None)
|
|
||||||
writer.transport = transport
|
|
||||||
|
|
||||||
protocol.transport = transport
|
|
||||||
protocol.writer = writer
|
|
||||||
|
|
||||||
req = Request(
|
|
||||||
message, payload, protocol, writer, task, loop, client_max_size=client_max_size
|
|
||||||
)
|
|
||||||
|
|
||||||
match_info = UrlMappingMatchInfo(
|
|
||||||
{} if match_info is sentinel else match_info, mock.Mock()
|
|
||||||
)
|
|
||||||
match_info.add_app(app)
|
|
||||||
req._match_info = match_info
|
|
||||||
|
|
||||||
return req
|
|
||||||
|
|
||||||
|
|
||||||
def make_mocked_coro(
|
|
||||||
return_value: Any = sentinel, raise_exception: Any = sentinel
|
|
||||||
) -> Any:
|
|
||||||
"""Creates a coroutine mock."""
|
|
||||||
|
|
||||||
async def mock_coro(*args: Any, **kwargs: Any) -> Any:
|
|
||||||
if raise_exception is not sentinel:
|
|
||||||
raise raise_exception
|
|
||||||
if not inspect.isawaitable(return_value):
|
|
||||||
return return_value
|
|
||||||
await return_value
|
|
||||||
|
|
||||||
return mock.Mock(wraps=mock_coro)
|
|
||||||
|
|
@ -1,455 +0,0 @@
|
||||||
from types import SimpleNamespace
|
|
||||||
from typing import TYPE_CHECKING, Mapping, Optional, Type, TypeVar
|
|
||||||
|
|
||||||
import attr
|
|
||||||
from aiosignal import Signal
|
|
||||||
from multidict import CIMultiDict
|
|
||||||
from yarl import URL
|
|
||||||
|
|
||||||
from .client_reqrep import ClientResponse
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from .client import ClientSession
|
|
||||||
|
|
||||||
_ParamT_contra = TypeVar("_ParamT_contra", contravariant=True)
|
|
||||||
_TracingSignal = Signal[ClientSession, SimpleNamespace, _ParamT_contra]
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = (
|
|
||||||
"TraceConfig",
|
|
||||||
"TraceRequestStartParams",
|
|
||||||
"TraceRequestEndParams",
|
|
||||||
"TraceRequestExceptionParams",
|
|
||||||
"TraceConnectionQueuedStartParams",
|
|
||||||
"TraceConnectionQueuedEndParams",
|
|
||||||
"TraceConnectionCreateStartParams",
|
|
||||||
"TraceConnectionCreateEndParams",
|
|
||||||
"TraceConnectionReuseconnParams",
|
|
||||||
"TraceDnsResolveHostStartParams",
|
|
||||||
"TraceDnsResolveHostEndParams",
|
|
||||||
"TraceDnsCacheHitParams",
|
|
||||||
"TraceDnsCacheMissParams",
|
|
||||||
"TraceRequestRedirectParams",
|
|
||||||
"TraceRequestChunkSentParams",
|
|
||||||
"TraceResponseChunkReceivedParams",
|
|
||||||
"TraceRequestHeadersSentParams",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TraceConfig:
|
|
||||||
"""First-class used to trace requests launched via ClientSession objects."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, trace_config_ctx_factory: Type[SimpleNamespace] = SimpleNamespace
|
|
||||||
) -> None:
|
|
||||||
self._on_request_start: _TracingSignal[TraceRequestStartParams] = Signal(self)
|
|
||||||
self._on_request_chunk_sent: _TracingSignal[TraceRequestChunkSentParams] = (
|
|
||||||
Signal(self)
|
|
||||||
)
|
|
||||||
self._on_response_chunk_received: _TracingSignal[
|
|
||||||
TraceResponseChunkReceivedParams
|
|
||||||
] = Signal(self)
|
|
||||||
self._on_request_end: _TracingSignal[TraceRequestEndParams] = Signal(self)
|
|
||||||
self._on_request_exception: _TracingSignal[TraceRequestExceptionParams] = (
|
|
||||||
Signal(self)
|
|
||||||
)
|
|
||||||
self._on_request_redirect: _TracingSignal[TraceRequestRedirectParams] = Signal(
|
|
||||||
self
|
|
||||||
)
|
|
||||||
self._on_connection_queued_start: _TracingSignal[
|
|
||||||
TraceConnectionQueuedStartParams
|
|
||||||
] = Signal(self)
|
|
||||||
self._on_connection_queued_end: _TracingSignal[
|
|
||||||
TraceConnectionQueuedEndParams
|
|
||||||
] = Signal(self)
|
|
||||||
self._on_connection_create_start: _TracingSignal[
|
|
||||||
TraceConnectionCreateStartParams
|
|
||||||
] = Signal(self)
|
|
||||||
self._on_connection_create_end: _TracingSignal[
|
|
||||||
TraceConnectionCreateEndParams
|
|
||||||
] = Signal(self)
|
|
||||||
self._on_connection_reuseconn: _TracingSignal[
|
|
||||||
TraceConnectionReuseconnParams
|
|
||||||
] = Signal(self)
|
|
||||||
self._on_dns_resolvehost_start: _TracingSignal[
|
|
||||||
TraceDnsResolveHostStartParams
|
|
||||||
] = Signal(self)
|
|
||||||
self._on_dns_resolvehost_end: _TracingSignal[TraceDnsResolveHostEndParams] = (
|
|
||||||
Signal(self)
|
|
||||||
)
|
|
||||||
self._on_dns_cache_hit: _TracingSignal[TraceDnsCacheHitParams] = Signal(self)
|
|
||||||
self._on_dns_cache_miss: _TracingSignal[TraceDnsCacheMissParams] = Signal(self)
|
|
||||||
self._on_request_headers_sent: _TracingSignal[TraceRequestHeadersSentParams] = (
|
|
||||||
Signal(self)
|
|
||||||
)
|
|
||||||
|
|
||||||
self._trace_config_ctx_factory = trace_config_ctx_factory
|
|
||||||
|
|
||||||
def trace_config_ctx(
|
|
||||||
self, trace_request_ctx: Optional[Mapping[str, str]] = None
|
|
||||||
) -> SimpleNamespace:
|
|
||||||
"""Return a new trace_config_ctx instance"""
|
|
||||||
return self._trace_config_ctx_factory(trace_request_ctx=trace_request_ctx)
|
|
||||||
|
|
||||||
def freeze(self) -> None:
|
|
||||||
self._on_request_start.freeze()
|
|
||||||
self._on_request_chunk_sent.freeze()
|
|
||||||
self._on_response_chunk_received.freeze()
|
|
||||||
self._on_request_end.freeze()
|
|
||||||
self._on_request_exception.freeze()
|
|
||||||
self._on_request_redirect.freeze()
|
|
||||||
self._on_connection_queued_start.freeze()
|
|
||||||
self._on_connection_queued_end.freeze()
|
|
||||||
self._on_connection_create_start.freeze()
|
|
||||||
self._on_connection_create_end.freeze()
|
|
||||||
self._on_connection_reuseconn.freeze()
|
|
||||||
self._on_dns_resolvehost_start.freeze()
|
|
||||||
self._on_dns_resolvehost_end.freeze()
|
|
||||||
self._on_dns_cache_hit.freeze()
|
|
||||||
self._on_dns_cache_miss.freeze()
|
|
||||||
self._on_request_headers_sent.freeze()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def on_request_start(self) -> "_TracingSignal[TraceRequestStartParams]":
|
|
||||||
return self._on_request_start
|
|
||||||
|
|
||||||
@property
|
|
||||||
def on_request_chunk_sent(
|
|
||||||
self,
|
|
||||||
) -> "_TracingSignal[TraceRequestChunkSentParams]":
|
|
||||||
return self._on_request_chunk_sent
|
|
||||||
|
|
||||||
@property
|
|
||||||
def on_response_chunk_received(
|
|
||||||
self,
|
|
||||||
) -> "_TracingSignal[TraceResponseChunkReceivedParams]":
|
|
||||||
return self._on_response_chunk_received
|
|
||||||
|
|
||||||
@property
|
|
||||||
def on_request_end(self) -> "_TracingSignal[TraceRequestEndParams]":
|
|
||||||
return self._on_request_end
|
|
||||||
|
|
||||||
@property
|
|
||||||
def on_request_exception(
|
|
||||||
self,
|
|
||||||
) -> "_TracingSignal[TraceRequestExceptionParams]":
|
|
||||||
return self._on_request_exception
|
|
||||||
|
|
||||||
@property
|
|
||||||
def on_request_redirect(
|
|
||||||
self,
|
|
||||||
) -> "_TracingSignal[TraceRequestRedirectParams]":
|
|
||||||
return self._on_request_redirect
|
|
||||||
|
|
||||||
@property
|
|
||||||
def on_connection_queued_start(
|
|
||||||
self,
|
|
||||||
) -> "_TracingSignal[TraceConnectionQueuedStartParams]":
|
|
||||||
return self._on_connection_queued_start
|
|
||||||
|
|
||||||
@property
|
|
||||||
def on_connection_queued_end(
|
|
||||||
self,
|
|
||||||
) -> "_TracingSignal[TraceConnectionQueuedEndParams]":
|
|
||||||
return self._on_connection_queued_end
|
|
||||||
|
|
||||||
@property
|
|
||||||
def on_connection_create_start(
|
|
||||||
self,
|
|
||||||
) -> "_TracingSignal[TraceConnectionCreateStartParams]":
|
|
||||||
return self._on_connection_create_start
|
|
||||||
|
|
||||||
@property
|
|
||||||
def on_connection_create_end(
|
|
||||||
self,
|
|
||||||
) -> "_TracingSignal[TraceConnectionCreateEndParams]":
|
|
||||||
return self._on_connection_create_end
|
|
||||||
|
|
||||||
@property
|
|
||||||
def on_connection_reuseconn(
|
|
||||||
self,
|
|
||||||
) -> "_TracingSignal[TraceConnectionReuseconnParams]":
|
|
||||||
return self._on_connection_reuseconn
|
|
||||||
|
|
||||||
@property
|
|
||||||
def on_dns_resolvehost_start(
|
|
||||||
self,
|
|
||||||
) -> "_TracingSignal[TraceDnsResolveHostStartParams]":
|
|
||||||
return self._on_dns_resolvehost_start
|
|
||||||
|
|
||||||
@property
|
|
||||||
def on_dns_resolvehost_end(
|
|
||||||
self,
|
|
||||||
) -> "_TracingSignal[TraceDnsResolveHostEndParams]":
|
|
||||||
return self._on_dns_resolvehost_end
|
|
||||||
|
|
||||||
@property
|
|
||||||
def on_dns_cache_hit(self) -> "_TracingSignal[TraceDnsCacheHitParams]":
|
|
||||||
return self._on_dns_cache_hit
|
|
||||||
|
|
||||||
@property
|
|
||||||
def on_dns_cache_miss(self) -> "_TracingSignal[TraceDnsCacheMissParams]":
|
|
||||||
return self._on_dns_cache_miss
|
|
||||||
|
|
||||||
@property
|
|
||||||
def on_request_headers_sent(
|
|
||||||
self,
|
|
||||||
) -> "_TracingSignal[TraceRequestHeadersSentParams]":
|
|
||||||
return self._on_request_headers_sent
|
|
||||||
|
|
||||||
|
|
||||||
@attr.s(auto_attribs=True, frozen=True, slots=True)
|
|
||||||
class TraceRequestStartParams:
|
|
||||||
"""Parameters sent by the `on_request_start` signal"""
|
|
||||||
|
|
||||||
method: str
|
|
||||||
url: URL
|
|
||||||
headers: "CIMultiDict[str]"
|
|
||||||
|
|
||||||
|
|
||||||
@attr.s(auto_attribs=True, frozen=True, slots=True)
|
|
||||||
class TraceRequestChunkSentParams:
|
|
||||||
"""Parameters sent by the `on_request_chunk_sent` signal"""
|
|
||||||
|
|
||||||
method: str
|
|
||||||
url: URL
|
|
||||||
chunk: bytes
|
|
||||||
|
|
||||||
|
|
||||||
@attr.s(auto_attribs=True, frozen=True, slots=True)
|
|
||||||
class TraceResponseChunkReceivedParams:
|
|
||||||
"""Parameters sent by the `on_response_chunk_received` signal"""
|
|
||||||
|
|
||||||
method: str
|
|
||||||
url: URL
|
|
||||||
chunk: bytes
|
|
||||||
|
|
||||||
|
|
||||||
@attr.s(auto_attribs=True, frozen=True, slots=True)
|
|
||||||
class TraceRequestEndParams:
|
|
||||||
"""Parameters sent by the `on_request_end` signal"""
|
|
||||||
|
|
||||||
method: str
|
|
||||||
url: URL
|
|
||||||
headers: "CIMultiDict[str]"
|
|
||||||
response: ClientResponse
|
|
||||||
|
|
||||||
|
|
||||||
@attr.s(auto_attribs=True, frozen=True, slots=True)
|
|
||||||
class TraceRequestExceptionParams:
|
|
||||||
"""Parameters sent by the `on_request_exception` signal"""
|
|
||||||
|
|
||||||
method: str
|
|
||||||
url: URL
|
|
||||||
headers: "CIMultiDict[str]"
|
|
||||||
exception: BaseException
|
|
||||||
|
|
||||||
|
|
||||||
@attr.s(auto_attribs=True, frozen=True, slots=True)
|
|
||||||
class TraceRequestRedirectParams:
|
|
||||||
"""Parameters sent by the `on_request_redirect` signal"""
|
|
||||||
|
|
||||||
method: str
|
|
||||||
url: URL
|
|
||||||
headers: "CIMultiDict[str]"
|
|
||||||
response: ClientResponse
|
|
||||||
|
|
||||||
|
|
||||||
@attr.s(auto_attribs=True, frozen=True, slots=True)
|
|
||||||
class TraceConnectionQueuedStartParams:
|
|
||||||
"""Parameters sent by the `on_connection_queued_start` signal"""
|
|
||||||
|
|
||||||
|
|
||||||
@attr.s(auto_attribs=True, frozen=True, slots=True)
|
|
||||||
class TraceConnectionQueuedEndParams:
|
|
||||||
"""Parameters sent by the `on_connection_queued_end` signal"""
|
|
||||||
|
|
||||||
|
|
||||||
@attr.s(auto_attribs=True, frozen=True, slots=True)
|
|
||||||
class TraceConnectionCreateStartParams:
|
|
||||||
"""Parameters sent by the `on_connection_create_start` signal"""
|
|
||||||
|
|
||||||
|
|
||||||
@attr.s(auto_attribs=True, frozen=True, slots=True)
|
|
||||||
class TraceConnectionCreateEndParams:
|
|
||||||
"""Parameters sent by the `on_connection_create_end` signal"""
|
|
||||||
|
|
||||||
|
|
||||||
@attr.s(auto_attribs=True, frozen=True, slots=True)
|
|
||||||
class TraceConnectionReuseconnParams:
|
|
||||||
"""Parameters sent by the `on_connection_reuseconn` signal"""
|
|
||||||
|
|
||||||
|
|
||||||
@attr.s(auto_attribs=True, frozen=True, slots=True)
|
|
||||||
class TraceDnsResolveHostStartParams:
|
|
||||||
"""Parameters sent by the `on_dns_resolvehost_start` signal"""
|
|
||||||
|
|
||||||
host: str
|
|
||||||
|
|
||||||
|
|
||||||
@attr.s(auto_attribs=True, frozen=True, slots=True)
|
|
||||||
class TraceDnsResolveHostEndParams:
|
|
||||||
"""Parameters sent by the `on_dns_resolvehost_end` signal"""
|
|
||||||
|
|
||||||
host: str
|
|
||||||
|
|
||||||
|
|
||||||
@attr.s(auto_attribs=True, frozen=True, slots=True)
|
|
||||||
class TraceDnsCacheHitParams:
|
|
||||||
"""Parameters sent by the `on_dns_cache_hit` signal"""
|
|
||||||
|
|
||||||
host: str
|
|
||||||
|
|
||||||
|
|
||||||
@attr.s(auto_attribs=True, frozen=True, slots=True)
|
|
||||||
class TraceDnsCacheMissParams:
|
|
||||||
"""Parameters sent by the `on_dns_cache_miss` signal"""
|
|
||||||
|
|
||||||
host: str
|
|
||||||
|
|
||||||
|
|
||||||
@attr.s(auto_attribs=True, frozen=True, slots=True)
|
|
||||||
class TraceRequestHeadersSentParams:
|
|
||||||
"""Parameters sent by the `on_request_headers_sent` signal"""
|
|
||||||
|
|
||||||
method: str
|
|
||||||
url: URL
|
|
||||||
headers: "CIMultiDict[str]"
|
|
||||||
|
|
||||||
|
|
||||||
class Trace:
|
|
||||||
"""Internal dependency holder class.
|
|
||||||
|
|
||||||
Used to keep together the main dependencies used
|
|
||||||
at the moment of send a signal.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
session: "ClientSession",
|
|
||||||
trace_config: TraceConfig,
|
|
||||||
trace_config_ctx: SimpleNamespace,
|
|
||||||
) -> None:
|
|
||||||
self._trace_config = trace_config
|
|
||||||
self._trace_config_ctx = trace_config_ctx
|
|
||||||
self._session = session
|
|
||||||
|
|
||||||
async def send_request_start(
|
|
||||||
self, method: str, url: URL, headers: "CIMultiDict[str]"
|
|
||||||
) -> None:
|
|
||||||
return await self._trace_config.on_request_start.send(
|
|
||||||
self._session,
|
|
||||||
self._trace_config_ctx,
|
|
||||||
TraceRequestStartParams(method, url, headers),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def send_request_chunk_sent(
|
|
||||||
self, method: str, url: URL, chunk: bytes
|
|
||||||
) -> None:
|
|
||||||
return await self._trace_config.on_request_chunk_sent.send(
|
|
||||||
self._session,
|
|
||||||
self._trace_config_ctx,
|
|
||||||
TraceRequestChunkSentParams(method, url, chunk),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def send_response_chunk_received(
|
|
||||||
self, method: str, url: URL, chunk: bytes
|
|
||||||
) -> None:
|
|
||||||
return await self._trace_config.on_response_chunk_received.send(
|
|
||||||
self._session,
|
|
||||||
self._trace_config_ctx,
|
|
||||||
TraceResponseChunkReceivedParams(method, url, chunk),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def send_request_end(
|
|
||||||
self,
|
|
||||||
method: str,
|
|
||||||
url: URL,
|
|
||||||
headers: "CIMultiDict[str]",
|
|
||||||
response: ClientResponse,
|
|
||||||
) -> None:
|
|
||||||
return await self._trace_config.on_request_end.send(
|
|
||||||
self._session,
|
|
||||||
self._trace_config_ctx,
|
|
||||||
TraceRequestEndParams(method, url, headers, response),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def send_request_exception(
|
|
||||||
self,
|
|
||||||
method: str,
|
|
||||||
url: URL,
|
|
||||||
headers: "CIMultiDict[str]",
|
|
||||||
exception: BaseException,
|
|
||||||
) -> None:
|
|
||||||
return await self._trace_config.on_request_exception.send(
|
|
||||||
self._session,
|
|
||||||
self._trace_config_ctx,
|
|
||||||
TraceRequestExceptionParams(method, url, headers, exception),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def send_request_redirect(
|
|
||||||
self,
|
|
||||||
method: str,
|
|
||||||
url: URL,
|
|
||||||
headers: "CIMultiDict[str]",
|
|
||||||
response: ClientResponse,
|
|
||||||
) -> None:
|
|
||||||
return await self._trace_config._on_request_redirect.send(
|
|
||||||
self._session,
|
|
||||||
self._trace_config_ctx,
|
|
||||||
TraceRequestRedirectParams(method, url, headers, response),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def send_connection_queued_start(self) -> None:
|
|
||||||
return await self._trace_config.on_connection_queued_start.send(
|
|
||||||
self._session, self._trace_config_ctx, TraceConnectionQueuedStartParams()
|
|
||||||
)
|
|
||||||
|
|
||||||
async def send_connection_queued_end(self) -> None:
|
|
||||||
return await self._trace_config.on_connection_queued_end.send(
|
|
||||||
self._session, self._trace_config_ctx, TraceConnectionQueuedEndParams()
|
|
||||||
)
|
|
||||||
|
|
||||||
async def send_connection_create_start(self) -> None:
|
|
||||||
return await self._trace_config.on_connection_create_start.send(
|
|
||||||
self._session, self._trace_config_ctx, TraceConnectionCreateStartParams()
|
|
||||||
)
|
|
||||||
|
|
||||||
async def send_connection_create_end(self) -> None:
|
|
||||||
return await self._trace_config.on_connection_create_end.send(
|
|
||||||
self._session, self._trace_config_ctx, TraceConnectionCreateEndParams()
|
|
||||||
)
|
|
||||||
|
|
||||||
async def send_connection_reuseconn(self) -> None:
|
|
||||||
return await self._trace_config.on_connection_reuseconn.send(
|
|
||||||
self._session, self._trace_config_ctx, TraceConnectionReuseconnParams()
|
|
||||||
)
|
|
||||||
|
|
||||||
async def send_dns_resolvehost_start(self, host: str) -> None:
|
|
||||||
return await self._trace_config.on_dns_resolvehost_start.send(
|
|
||||||
self._session, self._trace_config_ctx, TraceDnsResolveHostStartParams(host)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def send_dns_resolvehost_end(self, host: str) -> None:
|
|
||||||
return await self._trace_config.on_dns_resolvehost_end.send(
|
|
||||||
self._session, self._trace_config_ctx, TraceDnsResolveHostEndParams(host)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def send_dns_cache_hit(self, host: str) -> None:
|
|
||||||
return await self._trace_config.on_dns_cache_hit.send(
|
|
||||||
self._session, self._trace_config_ctx, TraceDnsCacheHitParams(host)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def send_dns_cache_miss(self, host: str) -> None:
|
|
||||||
return await self._trace_config.on_dns_cache_miss.send(
|
|
||||||
self._session, self._trace_config_ctx, TraceDnsCacheMissParams(host)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def send_request_headers(
|
|
||||||
self, method: str, url: URL, headers: "CIMultiDict[str]"
|
|
||||||
) -> None:
|
|
||||||
return await self._trace_config._on_request_headers_sent.send(
|
|
||||||
self._session,
|
|
||||||
self._trace_config_ctx,
|
|
||||||
TraceRequestHeadersSentParams(method, url, headers),
|
|
||||||
)
|
|
||||||
|
|
@ -1,69 +0,0 @@
|
||||||
import json
|
|
||||||
import os
|
|
||||||
from typing import (
|
|
||||||
TYPE_CHECKING,
|
|
||||||
Any,
|
|
||||||
Awaitable,
|
|
||||||
Callable,
|
|
||||||
Iterable,
|
|
||||||
Mapping,
|
|
||||||
Protocol,
|
|
||||||
Tuple,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
|
|
||||||
from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy, istr
|
|
||||||
from yarl import URL, Query as _Query
|
|
||||||
|
|
||||||
Query = _Query
|
|
||||||
|
|
||||||
DEFAULT_JSON_ENCODER = json.dumps
|
|
||||||
DEFAULT_JSON_DECODER = json.loads
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
_CIMultiDict = CIMultiDict[str]
|
|
||||||
_CIMultiDictProxy = CIMultiDictProxy[str]
|
|
||||||
_MultiDict = MultiDict[str]
|
|
||||||
_MultiDictProxy = MultiDictProxy[str]
|
|
||||||
from http.cookies import BaseCookie, Morsel
|
|
||||||
|
|
||||||
from .web import Request, StreamResponse
|
|
||||||
else:
|
|
||||||
_CIMultiDict = CIMultiDict
|
|
||||||
_CIMultiDictProxy = CIMultiDictProxy
|
|
||||||
_MultiDict = MultiDict
|
|
||||||
_MultiDictProxy = MultiDictProxy
|
|
||||||
|
|
||||||
Byteish = Union[bytes, bytearray, memoryview]
|
|
||||||
JSONEncoder = Callable[[Any], str]
|
|
||||||
JSONDecoder = Callable[[str], Any]
|
|
||||||
LooseHeaders = Union[
|
|
||||||
Mapping[str, str],
|
|
||||||
Mapping[istr, str],
|
|
||||||
_CIMultiDict,
|
|
||||||
_CIMultiDictProxy,
|
|
||||||
Iterable[Tuple[Union[str, istr], str]],
|
|
||||||
]
|
|
||||||
RawHeaders = Tuple[Tuple[bytes, bytes], ...]
|
|
||||||
StrOrURL = Union[str, URL]
|
|
||||||
|
|
||||||
LooseCookiesMappings = Mapping[str, Union[str, "BaseCookie[str]", "Morsel[Any]"]]
|
|
||||||
LooseCookiesIterables = Iterable[
|
|
||||||
Tuple[str, Union[str, "BaseCookie[str]", "Morsel[Any]"]]
|
|
||||||
]
|
|
||||||
LooseCookies = Union[
|
|
||||||
LooseCookiesMappings,
|
|
||||||
LooseCookiesIterables,
|
|
||||||
"BaseCookie[str]",
|
|
||||||
]
|
|
||||||
|
|
||||||
Handler = Callable[["Request"], Awaitable["StreamResponse"]]
|
|
||||||
|
|
||||||
|
|
||||||
class Middleware(Protocol):
|
|
||||||
def __call__(
|
|
||||||
self, request: "Request", handler: Handler
|
|
||||||
) -> Awaitable["StreamResponse"]: ...
|
|
||||||
|
|
||||||
|
|
||||||
PathLike = Union[str, "os.PathLike[str]"]
|
|
||||||
|
|
@ -1,592 +0,0 @@
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import socket
|
|
||||||
import sys
|
|
||||||
import warnings
|
|
||||||
from argparse import ArgumentParser
|
|
||||||
from collections.abc import Iterable
|
|
||||||
from contextlib import suppress
|
|
||||||
from importlib import import_module
|
|
||||||
from typing import (
|
|
||||||
TYPE_CHECKING,
|
|
||||||
Any,
|
|
||||||
Awaitable,
|
|
||||||
Callable,
|
|
||||||
Iterable as TypingIterable,
|
|
||||||
List,
|
|
||||||
Optional,
|
|
||||||
Set,
|
|
||||||
Type,
|
|
||||||
Union,
|
|
||||||
cast,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .abc import AbstractAccessLogger
|
|
||||||
from .helpers import AppKey as AppKey
|
|
||||||
from .log import access_logger
|
|
||||||
from .typedefs import PathLike
|
|
||||||
from .web_app import Application as Application, CleanupError as CleanupError
|
|
||||||
from .web_exceptions import (
|
|
||||||
HTTPAccepted as HTTPAccepted,
|
|
||||||
HTTPBadGateway as HTTPBadGateway,
|
|
||||||
HTTPBadRequest as HTTPBadRequest,
|
|
||||||
HTTPClientError as HTTPClientError,
|
|
||||||
HTTPConflict as HTTPConflict,
|
|
||||||
HTTPCreated as HTTPCreated,
|
|
||||||
HTTPError as HTTPError,
|
|
||||||
HTTPException as HTTPException,
|
|
||||||
HTTPExpectationFailed as HTTPExpectationFailed,
|
|
||||||
HTTPFailedDependency as HTTPFailedDependency,
|
|
||||||
HTTPForbidden as HTTPForbidden,
|
|
||||||
HTTPFound as HTTPFound,
|
|
||||||
HTTPGatewayTimeout as HTTPGatewayTimeout,
|
|
||||||
HTTPGone as HTTPGone,
|
|
||||||
HTTPInsufficientStorage as HTTPInsufficientStorage,
|
|
||||||
HTTPInternalServerError as HTTPInternalServerError,
|
|
||||||
HTTPLengthRequired as HTTPLengthRequired,
|
|
||||||
HTTPMethodNotAllowed as HTTPMethodNotAllowed,
|
|
||||||
HTTPMisdirectedRequest as HTTPMisdirectedRequest,
|
|
||||||
HTTPMove as HTTPMove,
|
|
||||||
HTTPMovedPermanently as HTTPMovedPermanently,
|
|
||||||
HTTPMultipleChoices as HTTPMultipleChoices,
|
|
||||||
HTTPNetworkAuthenticationRequired as HTTPNetworkAuthenticationRequired,
|
|
||||||
HTTPNoContent as HTTPNoContent,
|
|
||||||
HTTPNonAuthoritativeInformation as HTTPNonAuthoritativeInformation,
|
|
||||||
HTTPNotAcceptable as HTTPNotAcceptable,
|
|
||||||
HTTPNotExtended as HTTPNotExtended,
|
|
||||||
HTTPNotFound as HTTPNotFound,
|
|
||||||
HTTPNotImplemented as HTTPNotImplemented,
|
|
||||||
HTTPNotModified as HTTPNotModified,
|
|
||||||
HTTPOk as HTTPOk,
|
|
||||||
HTTPPartialContent as HTTPPartialContent,
|
|
||||||
HTTPPaymentRequired as HTTPPaymentRequired,
|
|
||||||
HTTPPermanentRedirect as HTTPPermanentRedirect,
|
|
||||||
HTTPPreconditionFailed as HTTPPreconditionFailed,
|
|
||||||
HTTPPreconditionRequired as HTTPPreconditionRequired,
|
|
||||||
HTTPProxyAuthenticationRequired as HTTPProxyAuthenticationRequired,
|
|
||||||
HTTPRedirection as HTTPRedirection,
|
|
||||||
HTTPRequestEntityTooLarge as HTTPRequestEntityTooLarge,
|
|
||||||
HTTPRequestHeaderFieldsTooLarge as HTTPRequestHeaderFieldsTooLarge,
|
|
||||||
HTTPRequestRangeNotSatisfiable as HTTPRequestRangeNotSatisfiable,
|
|
||||||
HTTPRequestTimeout as HTTPRequestTimeout,
|
|
||||||
HTTPRequestURITooLong as HTTPRequestURITooLong,
|
|
||||||
HTTPResetContent as HTTPResetContent,
|
|
||||||
HTTPSeeOther as HTTPSeeOther,
|
|
||||||
HTTPServerError as HTTPServerError,
|
|
||||||
HTTPServiceUnavailable as HTTPServiceUnavailable,
|
|
||||||
HTTPSuccessful as HTTPSuccessful,
|
|
||||||
HTTPTemporaryRedirect as HTTPTemporaryRedirect,
|
|
||||||
HTTPTooManyRequests as HTTPTooManyRequests,
|
|
||||||
HTTPUnauthorized as HTTPUnauthorized,
|
|
||||||
HTTPUnavailableForLegalReasons as HTTPUnavailableForLegalReasons,
|
|
||||||
HTTPUnprocessableEntity as HTTPUnprocessableEntity,
|
|
||||||
HTTPUnsupportedMediaType as HTTPUnsupportedMediaType,
|
|
||||||
HTTPUpgradeRequired as HTTPUpgradeRequired,
|
|
||||||
HTTPUseProxy as HTTPUseProxy,
|
|
||||||
HTTPVariantAlsoNegotiates as HTTPVariantAlsoNegotiates,
|
|
||||||
HTTPVersionNotSupported as HTTPVersionNotSupported,
|
|
||||||
NotAppKeyWarning as NotAppKeyWarning,
|
|
||||||
)
|
|
||||||
from .web_fileresponse import FileResponse as FileResponse
|
|
||||||
from .web_log import AccessLogger
|
|
||||||
from .web_middlewares import (
|
|
||||||
middleware as middleware,
|
|
||||||
normalize_path_middleware as normalize_path_middleware,
|
|
||||||
)
|
|
||||||
from .web_protocol import (
|
|
||||||
PayloadAccessError as PayloadAccessError,
|
|
||||||
RequestHandler as RequestHandler,
|
|
||||||
RequestPayloadError as RequestPayloadError,
|
|
||||||
)
|
|
||||||
from .web_request import (
|
|
||||||
BaseRequest as BaseRequest,
|
|
||||||
FileField as FileField,
|
|
||||||
Request as Request,
|
|
||||||
)
|
|
||||||
from .web_response import (
|
|
||||||
ContentCoding as ContentCoding,
|
|
||||||
Response as Response,
|
|
||||||
StreamResponse as StreamResponse,
|
|
||||||
json_response as json_response,
|
|
||||||
)
|
|
||||||
from .web_routedef import (
|
|
||||||
AbstractRouteDef as AbstractRouteDef,
|
|
||||||
RouteDef as RouteDef,
|
|
||||||
RouteTableDef as RouteTableDef,
|
|
||||||
StaticDef as StaticDef,
|
|
||||||
delete as delete,
|
|
||||||
get as get,
|
|
||||||
head as head,
|
|
||||||
options as options,
|
|
||||||
patch as patch,
|
|
||||||
post as post,
|
|
||||||
put as put,
|
|
||||||
route as route,
|
|
||||||
static as static,
|
|
||||||
view as view,
|
|
||||||
)
|
|
||||||
from .web_runner import (
|
|
||||||
AppRunner as AppRunner,
|
|
||||||
BaseRunner as BaseRunner,
|
|
||||||
BaseSite as BaseSite,
|
|
||||||
GracefulExit as GracefulExit,
|
|
||||||
NamedPipeSite as NamedPipeSite,
|
|
||||||
ServerRunner as ServerRunner,
|
|
||||||
SockSite as SockSite,
|
|
||||||
TCPSite as TCPSite,
|
|
||||||
UnixSite as UnixSite,
|
|
||||||
)
|
|
||||||
from .web_server import Server as Server
|
|
||||||
from .web_urldispatcher import (
|
|
||||||
AbstractResource as AbstractResource,
|
|
||||||
AbstractRoute as AbstractRoute,
|
|
||||||
DynamicResource as DynamicResource,
|
|
||||||
PlainResource as PlainResource,
|
|
||||||
PrefixedSubAppResource as PrefixedSubAppResource,
|
|
||||||
Resource as Resource,
|
|
||||||
ResourceRoute as ResourceRoute,
|
|
||||||
StaticResource as StaticResource,
|
|
||||||
UrlDispatcher as UrlDispatcher,
|
|
||||||
UrlMappingMatchInfo as UrlMappingMatchInfo,
|
|
||||||
View as View,
|
|
||||||
)
|
|
||||||
from .web_ws import (
|
|
||||||
WebSocketReady as WebSocketReady,
|
|
||||||
WebSocketResponse as WebSocketResponse,
|
|
||||||
WSMsgType as WSMsgType,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = (
|
|
||||||
# web_app
|
|
||||||
"AppKey",
|
|
||||||
"Application",
|
|
||||||
"CleanupError",
|
|
||||||
# web_exceptions
|
|
||||||
"NotAppKeyWarning",
|
|
||||||
"HTTPAccepted",
|
|
||||||
"HTTPBadGateway",
|
|
||||||
"HTTPBadRequest",
|
|
||||||
"HTTPClientError",
|
|
||||||
"HTTPConflict",
|
|
||||||
"HTTPCreated",
|
|
||||||
"HTTPError",
|
|
||||||
"HTTPException",
|
|
||||||
"HTTPExpectationFailed",
|
|
||||||
"HTTPFailedDependency",
|
|
||||||
"HTTPForbidden",
|
|
||||||
"HTTPFound",
|
|
||||||
"HTTPGatewayTimeout",
|
|
||||||
"HTTPGone",
|
|
||||||
"HTTPInsufficientStorage",
|
|
||||||
"HTTPInternalServerError",
|
|
||||||
"HTTPLengthRequired",
|
|
||||||
"HTTPMethodNotAllowed",
|
|
||||||
"HTTPMisdirectedRequest",
|
|
||||||
"HTTPMove",
|
|
||||||
"HTTPMovedPermanently",
|
|
||||||
"HTTPMultipleChoices",
|
|
||||||
"HTTPNetworkAuthenticationRequired",
|
|
||||||
"HTTPNoContent",
|
|
||||||
"HTTPNonAuthoritativeInformation",
|
|
||||||
"HTTPNotAcceptable",
|
|
||||||
"HTTPNotExtended",
|
|
||||||
"HTTPNotFound",
|
|
||||||
"HTTPNotImplemented",
|
|
||||||
"HTTPNotModified",
|
|
||||||
"HTTPOk",
|
|
||||||
"HTTPPartialContent",
|
|
||||||
"HTTPPaymentRequired",
|
|
||||||
"HTTPPermanentRedirect",
|
|
||||||
"HTTPPreconditionFailed",
|
|
||||||
"HTTPPreconditionRequired",
|
|
||||||
"HTTPProxyAuthenticationRequired",
|
|
||||||
"HTTPRedirection",
|
|
||||||
"HTTPRequestEntityTooLarge",
|
|
||||||
"HTTPRequestHeaderFieldsTooLarge",
|
|
||||||
"HTTPRequestRangeNotSatisfiable",
|
|
||||||
"HTTPRequestTimeout",
|
|
||||||
"HTTPRequestURITooLong",
|
|
||||||
"HTTPResetContent",
|
|
||||||
"HTTPSeeOther",
|
|
||||||
"HTTPServerError",
|
|
||||||
"HTTPServiceUnavailable",
|
|
||||||
"HTTPSuccessful",
|
|
||||||
"HTTPTemporaryRedirect",
|
|
||||||
"HTTPTooManyRequests",
|
|
||||||
"HTTPUnauthorized",
|
|
||||||
"HTTPUnavailableForLegalReasons",
|
|
||||||
"HTTPUnprocessableEntity",
|
|
||||||
"HTTPUnsupportedMediaType",
|
|
||||||
"HTTPUpgradeRequired",
|
|
||||||
"HTTPUseProxy",
|
|
||||||
"HTTPVariantAlsoNegotiates",
|
|
||||||
"HTTPVersionNotSupported",
|
|
||||||
# web_fileresponse
|
|
||||||
"FileResponse",
|
|
||||||
# web_middlewares
|
|
||||||
"middleware",
|
|
||||||
"normalize_path_middleware",
|
|
||||||
# web_protocol
|
|
||||||
"PayloadAccessError",
|
|
||||||
"RequestHandler",
|
|
||||||
"RequestPayloadError",
|
|
||||||
# web_request
|
|
||||||
"BaseRequest",
|
|
||||||
"FileField",
|
|
||||||
"Request",
|
|
||||||
# web_response
|
|
||||||
"ContentCoding",
|
|
||||||
"Response",
|
|
||||||
"StreamResponse",
|
|
||||||
"json_response",
|
|
||||||
# web_routedef
|
|
||||||
"AbstractRouteDef",
|
|
||||||
"RouteDef",
|
|
||||||
"RouteTableDef",
|
|
||||||
"StaticDef",
|
|
||||||
"delete",
|
|
||||||
"get",
|
|
||||||
"head",
|
|
||||||
"options",
|
|
||||||
"patch",
|
|
||||||
"post",
|
|
||||||
"put",
|
|
||||||
"route",
|
|
||||||
"static",
|
|
||||||
"view",
|
|
||||||
# web_runner
|
|
||||||
"AppRunner",
|
|
||||||
"BaseRunner",
|
|
||||||
"BaseSite",
|
|
||||||
"GracefulExit",
|
|
||||||
"ServerRunner",
|
|
||||||
"SockSite",
|
|
||||||
"TCPSite",
|
|
||||||
"UnixSite",
|
|
||||||
"NamedPipeSite",
|
|
||||||
# web_server
|
|
||||||
"Server",
|
|
||||||
# web_urldispatcher
|
|
||||||
"AbstractResource",
|
|
||||||
"AbstractRoute",
|
|
||||||
"DynamicResource",
|
|
||||||
"PlainResource",
|
|
||||||
"PrefixedSubAppResource",
|
|
||||||
"Resource",
|
|
||||||
"ResourceRoute",
|
|
||||||
"StaticResource",
|
|
||||||
"UrlDispatcher",
|
|
||||||
"UrlMappingMatchInfo",
|
|
||||||
"View",
|
|
||||||
# web_ws
|
|
||||||
"WebSocketReady",
|
|
||||||
"WebSocketResponse",
|
|
||||||
"WSMsgType",
|
|
||||||
# web
|
|
||||||
"run_app",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from ssl import SSLContext
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
from ssl import SSLContext
|
|
||||||
except ImportError: # pragma: no cover
|
|
||||||
SSLContext = object # type: ignore[misc,assignment]
|
|
||||||
|
|
||||||
# Only display warning when using -Wdefault, -We, -X dev or similar.
|
|
||||||
warnings.filterwarnings("ignore", category=NotAppKeyWarning, append=True)
|
|
||||||
|
|
||||||
HostSequence = TypingIterable[str]
|
|
||||||
|
|
||||||
|
|
||||||
async def _run_app(
|
|
||||||
app: Union[Application, Awaitable[Application]],
|
|
||||||
*,
|
|
||||||
host: Optional[Union[str, HostSequence]] = None,
|
|
||||||
port: Optional[int] = None,
|
|
||||||
path: Union[PathLike, TypingIterable[PathLike], None] = None,
|
|
||||||
sock: Optional[Union[socket.socket, TypingIterable[socket.socket]]] = None,
|
|
||||||
ssl_context: Optional[SSLContext] = None,
|
|
||||||
print: Optional[Callable[..., None]] = print,
|
|
||||||
backlog: int = 128,
|
|
||||||
reuse_address: Optional[bool] = None,
|
|
||||||
reuse_port: Optional[bool] = None,
|
|
||||||
**kwargs: Any, # TODO(PY311): Use Unpack
|
|
||||||
) -> None:
|
|
||||||
# An internal function to actually do all dirty job for application running
|
|
||||||
if asyncio.iscoroutine(app):
|
|
||||||
app = await app
|
|
||||||
|
|
||||||
app = cast(Application, app)
|
|
||||||
|
|
||||||
runner = AppRunner(app, **kwargs)
|
|
||||||
|
|
||||||
await runner.setup()
|
|
||||||
|
|
||||||
sites: List[BaseSite] = []
|
|
||||||
|
|
||||||
try:
|
|
||||||
if host is not None:
|
|
||||||
if isinstance(host, str):
|
|
||||||
sites.append(
|
|
||||||
TCPSite(
|
|
||||||
runner,
|
|
||||||
host,
|
|
||||||
port,
|
|
||||||
ssl_context=ssl_context,
|
|
||||||
backlog=backlog,
|
|
||||||
reuse_address=reuse_address,
|
|
||||||
reuse_port=reuse_port,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
for h in host:
|
|
||||||
sites.append(
|
|
||||||
TCPSite(
|
|
||||||
runner,
|
|
||||||
h,
|
|
||||||
port,
|
|
||||||
ssl_context=ssl_context,
|
|
||||||
backlog=backlog,
|
|
||||||
reuse_address=reuse_address,
|
|
||||||
reuse_port=reuse_port,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
elif path is None and sock is None or port is not None:
|
|
||||||
sites.append(
|
|
||||||
TCPSite(
|
|
||||||
runner,
|
|
||||||
port=port,
|
|
||||||
ssl_context=ssl_context,
|
|
||||||
backlog=backlog,
|
|
||||||
reuse_address=reuse_address,
|
|
||||||
reuse_port=reuse_port,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if path is not None:
|
|
||||||
if isinstance(path, (str, os.PathLike)):
|
|
||||||
sites.append(
|
|
||||||
UnixSite(
|
|
||||||
runner,
|
|
||||||
path,
|
|
||||||
ssl_context=ssl_context,
|
|
||||||
backlog=backlog,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
for p in path:
|
|
||||||
sites.append(
|
|
||||||
UnixSite(
|
|
||||||
runner,
|
|
||||||
p,
|
|
||||||
ssl_context=ssl_context,
|
|
||||||
backlog=backlog,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if sock is not None:
|
|
||||||
if not isinstance(sock, Iterable):
|
|
||||||
sites.append(
|
|
||||||
SockSite(
|
|
||||||
runner,
|
|
||||||
sock,
|
|
||||||
ssl_context=ssl_context,
|
|
||||||
backlog=backlog,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
for s in sock:
|
|
||||||
sites.append(
|
|
||||||
SockSite(
|
|
||||||
runner,
|
|
||||||
s,
|
|
||||||
ssl_context=ssl_context,
|
|
||||||
backlog=backlog,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
for site in sites:
|
|
||||||
await site.start()
|
|
||||||
|
|
||||||
if print: # pragma: no branch
|
|
||||||
names = sorted(str(s.name) for s in runner.sites)
|
|
||||||
print(
|
|
||||||
"======== Running on {} ========\n"
|
|
||||||
"(Press CTRL+C to quit)".format(", ".join(names))
|
|
||||||
)
|
|
||||||
|
|
||||||
# sleep forever by 1 hour intervals,
|
|
||||||
while True:
|
|
||||||
await asyncio.sleep(3600)
|
|
||||||
finally:
|
|
||||||
await runner.cleanup()
|
|
||||||
|
|
||||||
|
|
||||||
def _cancel_tasks(
|
|
||||||
to_cancel: Set["asyncio.Task[Any]"], loop: asyncio.AbstractEventLoop
|
|
||||||
) -> None:
|
|
||||||
if not to_cancel:
|
|
||||||
return
|
|
||||||
|
|
||||||
for task in to_cancel:
|
|
||||||
task.cancel()
|
|
||||||
|
|
||||||
loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True))
|
|
||||||
|
|
||||||
for task in to_cancel:
|
|
||||||
if task.cancelled():
|
|
||||||
continue
|
|
||||||
if task.exception() is not None:
|
|
||||||
loop.call_exception_handler(
|
|
||||||
{
|
|
||||||
"message": "unhandled exception during asyncio.run() shutdown",
|
|
||||||
"exception": task.exception(),
|
|
||||||
"task": task,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def run_app(
|
|
||||||
app: Union[Application, Awaitable[Application]],
|
|
||||||
*,
|
|
||||||
host: Optional[Union[str, HostSequence]] = None,
|
|
||||||
port: Optional[int] = None,
|
|
||||||
path: Union[PathLike, TypingIterable[PathLike], None] = None,
|
|
||||||
sock: Optional[Union[socket.socket, TypingIterable[socket.socket]]] = None,
|
|
||||||
shutdown_timeout: float = 60.0,
|
|
||||||
keepalive_timeout: float = 75.0,
|
|
||||||
ssl_context: Optional[SSLContext] = None,
|
|
||||||
print: Optional[Callable[..., None]] = print,
|
|
||||||
backlog: int = 128,
|
|
||||||
access_log_class: Type[AbstractAccessLogger] = AccessLogger,
|
|
||||||
access_log_format: str = AccessLogger.LOG_FORMAT,
|
|
||||||
access_log: Optional[logging.Logger] = access_logger,
|
|
||||||
handle_signals: bool = True,
|
|
||||||
reuse_address: Optional[bool] = None,
|
|
||||||
reuse_port: Optional[bool] = None,
|
|
||||||
handler_cancellation: bool = False,
|
|
||||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> None:
|
|
||||||
"""Run an app locally"""
|
|
||||||
if loop is None:
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
|
|
||||||
# Configure if and only if in debugging mode and using the default logger
|
|
||||||
if loop.get_debug() and access_log and access_log.name == "aiohttp.access":
|
|
||||||
if access_log.level == logging.NOTSET:
|
|
||||||
access_log.setLevel(logging.DEBUG)
|
|
||||||
if not access_log.hasHandlers():
|
|
||||||
access_log.addHandler(logging.StreamHandler())
|
|
||||||
|
|
||||||
main_task = loop.create_task(
|
|
||||||
_run_app(
|
|
||||||
app,
|
|
||||||
host=host,
|
|
||||||
port=port,
|
|
||||||
path=path,
|
|
||||||
sock=sock,
|
|
||||||
shutdown_timeout=shutdown_timeout,
|
|
||||||
keepalive_timeout=keepalive_timeout,
|
|
||||||
ssl_context=ssl_context,
|
|
||||||
print=print,
|
|
||||||
backlog=backlog,
|
|
||||||
access_log_class=access_log_class,
|
|
||||||
access_log_format=access_log_format,
|
|
||||||
access_log=access_log,
|
|
||||||
handle_signals=handle_signals,
|
|
||||||
reuse_address=reuse_address,
|
|
||||||
reuse_port=reuse_port,
|
|
||||||
handler_cancellation=handler_cancellation,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
loop.run_until_complete(main_task)
|
|
||||||
except (GracefulExit, KeyboardInterrupt): # pragma: no cover
|
|
||||||
pass
|
|
||||||
finally:
|
|
||||||
try:
|
|
||||||
main_task.cancel()
|
|
||||||
with suppress(asyncio.CancelledError):
|
|
||||||
loop.run_until_complete(main_task)
|
|
||||||
finally:
|
|
||||||
_cancel_tasks(asyncio.all_tasks(loop), loop)
|
|
||||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
def main(argv: List[str]) -> None:
|
|
||||||
arg_parser = ArgumentParser(
|
|
||||||
description="aiohttp.web Application server", prog="aiohttp.web"
|
|
||||||
)
|
|
||||||
arg_parser.add_argument(
|
|
||||||
"entry_func",
|
|
||||||
help=(
|
|
||||||
"Callable returning the `aiohttp.web.Application` instance to "
|
|
||||||
"run. Should be specified in the 'module:function' syntax."
|
|
||||||
),
|
|
||||||
metavar="entry-func",
|
|
||||||
)
|
|
||||||
arg_parser.add_argument(
|
|
||||||
"-H",
|
|
||||||
"--hostname",
|
|
||||||
help="TCP/IP hostname to serve on (default: localhost)",
|
|
||||||
default=None,
|
|
||||||
)
|
|
||||||
arg_parser.add_argument(
|
|
||||||
"-P",
|
|
||||||
"--port",
|
|
||||||
help="TCP/IP port to serve on (default: %(default)r)",
|
|
||||||
type=int,
|
|
||||||
default=8080,
|
|
||||||
)
|
|
||||||
arg_parser.add_argument(
|
|
||||||
"-U",
|
|
||||||
"--path",
|
|
||||||
help="Unix file system path to serve on. Can be combined with hostname "
|
|
||||||
"to serve on both Unix and TCP.",
|
|
||||||
)
|
|
||||||
args, extra_argv = arg_parser.parse_known_args(argv)
|
|
||||||
|
|
||||||
# Import logic
|
|
||||||
mod_str, _, func_str = args.entry_func.partition(":")
|
|
||||||
if not func_str or not mod_str:
|
|
||||||
arg_parser.error("'entry-func' not in 'module:function' syntax")
|
|
||||||
if mod_str.startswith("."):
|
|
||||||
arg_parser.error("relative module names not supported")
|
|
||||||
try:
|
|
||||||
module = import_module(mod_str)
|
|
||||||
except ImportError as ex:
|
|
||||||
arg_parser.error(f"unable to import {mod_str}: {ex}")
|
|
||||||
try:
|
|
||||||
func = getattr(module, func_str)
|
|
||||||
except AttributeError:
|
|
||||||
arg_parser.error(f"module {mod_str!r} has no attribute {func_str!r}")
|
|
||||||
|
|
||||||
# Compatibility logic
|
|
||||||
if args.path is not None and not hasattr(socket, "AF_UNIX"):
|
|
||||||
arg_parser.error(
|
|
||||||
"file system paths not supported by your operating environment"
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
|
||||||
|
|
||||||
if args.path and args.hostname is None:
|
|
||||||
host = port = None
|
|
||||||
else:
|
|
||||||
host = args.hostname or "localhost"
|
|
||||||
port = args.port
|
|
||||||
|
|
||||||
app = func(extra_argv)
|
|
||||||
run_app(app, host=host, port=port, path=args.path)
|
|
||||||
arg_parser.exit(message="Stopped\n")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__": # pragma: no branch
|
|
||||||
main(sys.argv[1:]) # pragma: no cover
|
|
||||||
|
|
@ -1,620 +0,0 @@
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
import warnings
|
|
||||||
from functools import lru_cache, partial, update_wrapper
|
|
||||||
from typing import (
|
|
||||||
TYPE_CHECKING,
|
|
||||||
Any,
|
|
||||||
AsyncIterator,
|
|
||||||
Awaitable,
|
|
||||||
Callable,
|
|
||||||
Dict,
|
|
||||||
Iterable,
|
|
||||||
Iterator,
|
|
||||||
List,
|
|
||||||
Mapping,
|
|
||||||
MutableMapping,
|
|
||||||
Optional,
|
|
||||||
Sequence,
|
|
||||||
Tuple,
|
|
||||||
Type,
|
|
||||||
TypeVar,
|
|
||||||
Union,
|
|
||||||
cast,
|
|
||||||
overload,
|
|
||||||
)
|
|
||||||
|
|
||||||
from aiosignal import Signal
|
|
||||||
from frozenlist import FrozenList
|
|
||||||
|
|
||||||
from . import hdrs
|
|
||||||
from .abc import (
|
|
||||||
AbstractAccessLogger,
|
|
||||||
AbstractMatchInfo,
|
|
||||||
AbstractRouter,
|
|
||||||
AbstractStreamWriter,
|
|
||||||
)
|
|
||||||
from .helpers import DEBUG, AppKey
|
|
||||||
from .http_parser import RawRequestMessage
|
|
||||||
from .log import web_logger
|
|
||||||
from .streams import StreamReader
|
|
||||||
from .typedefs import Handler, Middleware
|
|
||||||
from .web_exceptions import NotAppKeyWarning
|
|
||||||
from .web_log import AccessLogger
|
|
||||||
from .web_middlewares import _fix_request_current_app
|
|
||||||
from .web_protocol import RequestHandler
|
|
||||||
from .web_request import Request
|
|
||||||
from .web_response import StreamResponse
|
|
||||||
from .web_routedef import AbstractRouteDef
|
|
||||||
from .web_server import Server
|
|
||||||
from .web_urldispatcher import (
|
|
||||||
AbstractResource,
|
|
||||||
AbstractRoute,
|
|
||||||
Domain,
|
|
||||||
MaskDomain,
|
|
||||||
MatchedSubAppResource,
|
|
||||||
PrefixedSubAppResource,
|
|
||||||
SystemRoute,
|
|
||||||
UrlDispatcher,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = ("Application", "CleanupError")
|
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
_AppSignal = Signal["Application"]
|
|
||||||
_RespPrepareSignal = Signal[Request, StreamResponse]
|
|
||||||
_Middlewares = FrozenList[Middleware]
|
|
||||||
_MiddlewaresHandlers = Optional[Sequence[Tuple[Middleware, bool]]]
|
|
||||||
_Subapps = List["Application"]
|
|
||||||
else:
|
|
||||||
# No type checker mode, skip types
|
|
||||||
_AppSignal = Signal
|
|
||||||
_RespPrepareSignal = Signal
|
|
||||||
_Middlewares = FrozenList
|
|
||||||
_MiddlewaresHandlers = Optional[Sequence]
|
|
||||||
_Subapps = List
|
|
||||||
|
|
||||||
_T = TypeVar("_T")
|
|
||||||
_U = TypeVar("_U")
|
|
||||||
_Resource = TypeVar("_Resource", bound=AbstractResource)
|
|
||||||
|
|
||||||
|
|
||||||
def _build_middlewares(
|
|
||||||
handler: Handler, apps: Tuple["Application", ...]
|
|
||||||
) -> Callable[[Request], Awaitable[StreamResponse]]:
|
|
||||||
"""Apply middlewares to handler."""
|
|
||||||
for app in apps[::-1]:
|
|
||||||
for m, _ in app._middlewares_handlers: # type: ignore[union-attr]
|
|
||||||
handler = update_wrapper(partial(m, handler=handler), handler)
|
|
||||||
return handler
|
|
||||||
|
|
||||||
|
|
||||||
_cached_build_middleware = lru_cache(maxsize=1024)(_build_middlewares)
|
|
||||||
|
|
||||||
|
|
||||||
class Application(MutableMapping[Union[str, AppKey[Any]], Any]):
|
|
||||||
ATTRS = frozenset(
|
|
||||||
[
|
|
||||||
"logger",
|
|
||||||
"_debug",
|
|
||||||
"_router",
|
|
||||||
"_loop",
|
|
||||||
"_handler_args",
|
|
||||||
"_middlewares",
|
|
||||||
"_middlewares_handlers",
|
|
||||||
"_has_legacy_middlewares",
|
|
||||||
"_run_middlewares",
|
|
||||||
"_state",
|
|
||||||
"_frozen",
|
|
||||||
"_pre_frozen",
|
|
||||||
"_subapps",
|
|
||||||
"_on_response_prepare",
|
|
||||||
"_on_startup",
|
|
||||||
"_on_shutdown",
|
|
||||||
"_on_cleanup",
|
|
||||||
"_client_max_size",
|
|
||||||
"_cleanup_ctx",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
logger: logging.Logger = web_logger,
|
|
||||||
router: Optional[UrlDispatcher] = None,
|
|
||||||
middlewares: Iterable[Middleware] = (),
|
|
||||||
handler_args: Optional[Mapping[str, Any]] = None,
|
|
||||||
client_max_size: int = 1024**2,
|
|
||||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
|
||||||
debug: Any = ..., # mypy doesn't support ellipsis
|
|
||||||
) -> None:
|
|
||||||
if router is None:
|
|
||||||
router = UrlDispatcher()
|
|
||||||
else:
|
|
||||||
warnings.warn(
|
|
||||||
"router argument is deprecated", DeprecationWarning, stacklevel=2
|
|
||||||
)
|
|
||||||
assert isinstance(router, AbstractRouter), router
|
|
||||||
|
|
||||||
if loop is not None:
|
|
||||||
warnings.warn(
|
|
||||||
"loop argument is deprecated", DeprecationWarning, stacklevel=2
|
|
||||||
)
|
|
||||||
|
|
||||||
if debug is not ...:
|
|
||||||
warnings.warn(
|
|
||||||
"debug argument is deprecated", DeprecationWarning, stacklevel=2
|
|
||||||
)
|
|
||||||
self._debug = debug
|
|
||||||
self._router: UrlDispatcher = router
|
|
||||||
self._loop = loop
|
|
||||||
self._handler_args = handler_args
|
|
||||||
self.logger = logger
|
|
||||||
|
|
||||||
self._middlewares: _Middlewares = FrozenList(middlewares)
|
|
||||||
|
|
||||||
# initialized on freezing
|
|
||||||
self._middlewares_handlers: _MiddlewaresHandlers = None
|
|
||||||
# initialized on freezing
|
|
||||||
self._run_middlewares: Optional[bool] = None
|
|
||||||
self._has_legacy_middlewares: bool = True
|
|
||||||
|
|
||||||
self._state: Dict[Union[AppKey[Any], str], object] = {}
|
|
||||||
self._frozen = False
|
|
||||||
self._pre_frozen = False
|
|
||||||
self._subapps: _Subapps = []
|
|
||||||
|
|
||||||
self._on_response_prepare: _RespPrepareSignal = Signal(self)
|
|
||||||
self._on_startup: _AppSignal = Signal(self)
|
|
||||||
self._on_shutdown: _AppSignal = Signal(self)
|
|
||||||
self._on_cleanup: _AppSignal = Signal(self)
|
|
||||||
self._cleanup_ctx = CleanupContext()
|
|
||||||
self._on_startup.append(self._cleanup_ctx._on_startup)
|
|
||||||
self._on_cleanup.append(self._cleanup_ctx._on_cleanup)
|
|
||||||
self._client_max_size = client_max_size
|
|
||||||
|
|
||||||
def __init_subclass__(cls: Type["Application"]) -> None:
|
|
||||||
warnings.warn(
|
|
||||||
"Inheritance class {} from web.Application "
|
|
||||||
"is discouraged".format(cls.__name__),
|
|
||||||
DeprecationWarning,
|
|
||||||
stacklevel=3,
|
|
||||||
)
|
|
||||||
|
|
||||||
if DEBUG: # pragma: no cover
|
|
||||||
|
|
||||||
def __setattr__(self, name: str, val: Any) -> None:
|
|
||||||
if name not in self.ATTRS:
|
|
||||||
warnings.warn(
|
|
||||||
"Setting custom web.Application.{} attribute "
|
|
||||||
"is discouraged".format(name),
|
|
||||||
DeprecationWarning,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
super().__setattr__(name, val)
|
|
||||||
|
|
||||||
# MutableMapping API
|
|
||||||
|
|
||||||
def __eq__(self, other: object) -> bool:
|
|
||||||
return self is other
|
|
||||||
|
|
||||||
@overload # type: ignore[override]
|
|
||||||
def __getitem__(self, key: AppKey[_T]) -> _T: ...
|
|
||||||
|
|
||||||
@overload
|
|
||||||
def __getitem__(self, key: str) -> Any: ...
|
|
||||||
|
|
||||||
def __getitem__(self, key: Union[str, AppKey[_T]]) -> Any:
|
|
||||||
return self._state[key]
|
|
||||||
|
|
||||||
def _check_frozen(self) -> None:
|
|
||||||
if self._frozen:
|
|
||||||
warnings.warn(
|
|
||||||
"Changing state of started or joined application is deprecated",
|
|
||||||
DeprecationWarning,
|
|
||||||
stacklevel=3,
|
|
||||||
)
|
|
||||||
|
|
||||||
@overload # type: ignore[override]
|
|
||||||
def __setitem__(self, key: AppKey[_T], value: _T) -> None: ...
|
|
||||||
|
|
||||||
@overload
|
|
||||||
def __setitem__(self, key: str, value: Any) -> None: ...
|
|
||||||
|
|
||||||
def __setitem__(self, key: Union[str, AppKey[_T]], value: Any) -> None:
|
|
||||||
self._check_frozen()
|
|
||||||
if not isinstance(key, AppKey):
|
|
||||||
warnings.warn(
|
|
||||||
"It is recommended to use web.AppKey instances for keys.\n"
|
|
||||||
+ "https://docs.aiohttp.org/en/stable/web_advanced.html"
|
|
||||||
+ "#application-s-config",
|
|
||||||
category=NotAppKeyWarning,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
self._state[key] = value
|
|
||||||
|
|
||||||
def __delitem__(self, key: Union[str, AppKey[_T]]) -> None:
|
|
||||||
self._check_frozen()
|
|
||||||
del self._state[key]
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
return len(self._state)
|
|
||||||
|
|
||||||
def __iter__(self) -> Iterator[Union[str, AppKey[Any]]]:
|
|
||||||
return iter(self._state)
|
|
||||||
|
|
||||||
def __hash__(self) -> int:
|
|
||||||
return id(self)
|
|
||||||
|
|
||||||
@overload # type: ignore[override]
|
|
||||||
def get(self, key: AppKey[_T], default: None = ...) -> Optional[_T]: ...
|
|
||||||
|
|
||||||
@overload
|
|
||||||
def get(self, key: AppKey[_T], default: _U) -> Union[_T, _U]: ...
|
|
||||||
|
|
||||||
@overload
|
|
||||||
def get(self, key: str, default: Any = ...) -> Any: ...
|
|
||||||
|
|
||||||
def get(self, key: Union[str, AppKey[_T]], default: Any = None) -> Any:
|
|
||||||
return self._state.get(key, default)
|
|
||||||
|
|
||||||
########
|
|
||||||
@property
|
|
||||||
def loop(self) -> asyncio.AbstractEventLoop:
|
|
||||||
# Technically the loop can be None
|
|
||||||
# but we mask it by explicit type cast
|
|
||||||
# to provide more convenient type annotation
|
|
||||||
warnings.warn("loop property is deprecated", DeprecationWarning, stacklevel=2)
|
|
||||||
return cast(asyncio.AbstractEventLoop, self._loop)
|
|
||||||
|
|
||||||
def _set_loop(self, loop: Optional[asyncio.AbstractEventLoop]) -> None:
|
|
||||||
if loop is None:
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
if self._loop is not None and self._loop is not loop:
|
|
||||||
raise RuntimeError(
|
|
||||||
"web.Application instance initialized with different loop"
|
|
||||||
)
|
|
||||||
|
|
||||||
self._loop = loop
|
|
||||||
|
|
||||||
# set loop debug
|
|
||||||
if self._debug is ...:
|
|
||||||
self._debug = loop.get_debug()
|
|
||||||
|
|
||||||
# set loop to sub applications
|
|
||||||
for subapp in self._subapps:
|
|
||||||
subapp._set_loop(loop)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def pre_frozen(self) -> bool:
|
|
||||||
return self._pre_frozen
|
|
||||||
|
|
||||||
def pre_freeze(self) -> None:
|
|
||||||
if self._pre_frozen:
|
|
||||||
return
|
|
||||||
|
|
||||||
self._pre_frozen = True
|
|
||||||
self._middlewares.freeze()
|
|
||||||
self._router.freeze()
|
|
||||||
self._on_response_prepare.freeze()
|
|
||||||
self._cleanup_ctx.freeze()
|
|
||||||
self._on_startup.freeze()
|
|
||||||
self._on_shutdown.freeze()
|
|
||||||
self._on_cleanup.freeze()
|
|
||||||
self._middlewares_handlers = tuple(self._prepare_middleware())
|
|
||||||
self._has_legacy_middlewares = any(
|
|
||||||
not new_style for _, new_style in self._middlewares_handlers
|
|
||||||
)
|
|
||||||
|
|
||||||
# If current app and any subapp do not have middlewares avoid run all
|
|
||||||
# of the code footprint that it implies, which have a middleware
|
|
||||||
# hardcoded per app that sets up the current_app attribute. If no
|
|
||||||
# middlewares are configured the handler will receive the proper
|
|
||||||
# current_app without needing all of this code.
|
|
||||||
self._run_middlewares = True if self.middlewares else False
|
|
||||||
|
|
||||||
for subapp in self._subapps:
|
|
||||||
subapp.pre_freeze()
|
|
||||||
self._run_middlewares = self._run_middlewares or subapp._run_middlewares
|
|
||||||
|
|
||||||
@property
|
|
||||||
def frozen(self) -> bool:
|
|
||||||
return self._frozen
|
|
||||||
|
|
||||||
def freeze(self) -> None:
|
|
||||||
if self._frozen:
|
|
||||||
return
|
|
||||||
|
|
||||||
self.pre_freeze()
|
|
||||||
self._frozen = True
|
|
||||||
for subapp in self._subapps:
|
|
||||||
subapp.freeze()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def debug(self) -> bool:
|
|
||||||
warnings.warn("debug property is deprecated", DeprecationWarning, stacklevel=2)
|
|
||||||
return self._debug # type: ignore[no-any-return]
|
|
||||||
|
|
||||||
def _reg_subapp_signals(self, subapp: "Application") -> None:
|
|
||||||
def reg_handler(signame: str) -> None:
|
|
||||||
subsig = getattr(subapp, signame)
|
|
||||||
|
|
||||||
async def handler(app: "Application") -> None:
|
|
||||||
await subsig.send(subapp)
|
|
||||||
|
|
||||||
appsig = getattr(self, signame)
|
|
||||||
appsig.append(handler)
|
|
||||||
|
|
||||||
reg_handler("on_startup")
|
|
||||||
reg_handler("on_shutdown")
|
|
||||||
reg_handler("on_cleanup")
|
|
||||||
|
|
||||||
def add_subapp(self, prefix: str, subapp: "Application") -> PrefixedSubAppResource:
|
|
||||||
if not isinstance(prefix, str):
|
|
||||||
raise TypeError("Prefix must be str")
|
|
||||||
prefix = prefix.rstrip("/")
|
|
||||||
if not prefix:
|
|
||||||
raise ValueError("Prefix cannot be empty")
|
|
||||||
factory = partial(PrefixedSubAppResource, prefix, subapp)
|
|
||||||
return self._add_subapp(factory, subapp)
|
|
||||||
|
|
||||||
def _add_subapp(
|
|
||||||
self, resource_factory: Callable[[], _Resource], subapp: "Application"
|
|
||||||
) -> _Resource:
|
|
||||||
if self.frozen:
|
|
||||||
raise RuntimeError("Cannot add sub application to frozen application")
|
|
||||||
if subapp.frozen:
|
|
||||||
raise RuntimeError("Cannot add frozen application")
|
|
||||||
resource = resource_factory()
|
|
||||||
self.router.register_resource(resource)
|
|
||||||
self._reg_subapp_signals(subapp)
|
|
||||||
self._subapps.append(subapp)
|
|
||||||
subapp.pre_freeze()
|
|
||||||
if self._loop is not None:
|
|
||||||
subapp._set_loop(self._loop)
|
|
||||||
return resource
|
|
||||||
|
|
||||||
def add_domain(self, domain: str, subapp: "Application") -> MatchedSubAppResource:
|
|
||||||
if not isinstance(domain, str):
|
|
||||||
raise TypeError("Domain must be str")
|
|
||||||
elif "*" in domain:
|
|
||||||
rule: Domain = MaskDomain(domain)
|
|
||||||
else:
|
|
||||||
rule = Domain(domain)
|
|
||||||
factory = partial(MatchedSubAppResource, rule, subapp)
|
|
||||||
return self._add_subapp(factory, subapp)
|
|
||||||
|
|
||||||
def add_routes(self, routes: Iterable[AbstractRouteDef]) -> List[AbstractRoute]:
|
|
||||||
return self.router.add_routes(routes)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def on_response_prepare(self) -> _RespPrepareSignal:
|
|
||||||
return self._on_response_prepare
|
|
||||||
|
|
||||||
@property
|
|
||||||
def on_startup(self) -> _AppSignal:
|
|
||||||
return self._on_startup
|
|
||||||
|
|
||||||
@property
|
|
||||||
def on_shutdown(self) -> _AppSignal:
|
|
||||||
return self._on_shutdown
|
|
||||||
|
|
||||||
@property
|
|
||||||
def on_cleanup(self) -> _AppSignal:
|
|
||||||
return self._on_cleanup
|
|
||||||
|
|
||||||
@property
|
|
||||||
def cleanup_ctx(self) -> "CleanupContext":
|
|
||||||
return self._cleanup_ctx
|
|
||||||
|
|
||||||
@property
|
|
||||||
def router(self) -> UrlDispatcher:
|
|
||||||
return self._router
|
|
||||||
|
|
||||||
@property
|
|
||||||
def middlewares(self) -> _Middlewares:
|
|
||||||
return self._middlewares
|
|
||||||
|
|
||||||
def _make_handler(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
|
||||||
access_log_class: Type[AbstractAccessLogger] = AccessLogger,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> Server:
|
|
||||||
|
|
||||||
if not issubclass(access_log_class, AbstractAccessLogger):
|
|
||||||
raise TypeError(
|
|
||||||
"access_log_class must be subclass of "
|
|
||||||
"aiohttp.abc.AbstractAccessLogger, got {}".format(access_log_class)
|
|
||||||
)
|
|
||||||
|
|
||||||
self._set_loop(loop)
|
|
||||||
self.freeze()
|
|
||||||
|
|
||||||
kwargs["debug"] = self._debug
|
|
||||||
kwargs["access_log_class"] = access_log_class
|
|
||||||
if self._handler_args:
|
|
||||||
for k, v in self._handler_args.items():
|
|
||||||
kwargs[k] = v
|
|
||||||
|
|
||||||
return Server(
|
|
||||||
self._handle, # type: ignore[arg-type]
|
|
||||||
request_factory=self._make_request,
|
|
||||||
loop=self._loop,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
def make_handler(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
|
||||||
access_log_class: Type[AbstractAccessLogger] = AccessLogger,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> Server:
|
|
||||||
|
|
||||||
warnings.warn(
|
|
||||||
"Application.make_handler(...) is deprecated, use AppRunner API instead",
|
|
||||||
DeprecationWarning,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
|
|
||||||
return self._make_handler(
|
|
||||||
loop=loop, access_log_class=access_log_class, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
async def startup(self) -> None:
|
|
||||||
"""Causes on_startup signal
|
|
||||||
|
|
||||||
Should be called in the event loop along with the request handler.
|
|
||||||
"""
|
|
||||||
await self.on_startup.send(self)
|
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
|
||||||
"""Causes on_shutdown signal
|
|
||||||
|
|
||||||
Should be called before cleanup()
|
|
||||||
"""
|
|
||||||
await self.on_shutdown.send(self)
|
|
||||||
|
|
||||||
async def cleanup(self) -> None:
|
|
||||||
"""Causes on_cleanup signal
|
|
||||||
|
|
||||||
Should be called after shutdown()
|
|
||||||
"""
|
|
||||||
if self.on_cleanup.frozen:
|
|
||||||
await self.on_cleanup.send(self)
|
|
||||||
else:
|
|
||||||
# If an exception occurs in startup, ensure cleanup contexts are completed.
|
|
||||||
await self._cleanup_ctx._on_cleanup(self)
|
|
||||||
|
|
||||||
def _make_request(
|
|
||||||
self,
|
|
||||||
message: RawRequestMessage,
|
|
||||||
payload: StreamReader,
|
|
||||||
protocol: RequestHandler,
|
|
||||||
writer: AbstractStreamWriter,
|
|
||||||
task: "asyncio.Task[None]",
|
|
||||||
_cls: Type[Request] = Request,
|
|
||||||
) -> Request:
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
assert self._loop is not None
|
|
||||||
return _cls(
|
|
||||||
message,
|
|
||||||
payload,
|
|
||||||
protocol,
|
|
||||||
writer,
|
|
||||||
task,
|
|
||||||
self._loop,
|
|
||||||
client_max_size=self._client_max_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _prepare_middleware(self) -> Iterator[Tuple[Middleware, bool]]:
|
|
||||||
for m in reversed(self._middlewares):
|
|
||||||
if getattr(m, "__middleware_version__", None) == 1:
|
|
||||||
yield m, True
|
|
||||||
else:
|
|
||||||
warnings.warn(
|
|
||||||
f'old-style middleware "{m!r}" deprecated, see #2252',
|
|
||||||
DeprecationWarning,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
yield m, False
|
|
||||||
|
|
||||||
yield _fix_request_current_app(self), True
|
|
||||||
|
|
||||||
async def _handle(self, request: Request) -> StreamResponse:
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
debug = loop.get_debug()
|
|
||||||
match_info = await self._router.resolve(request)
|
|
||||||
if debug: # pragma: no cover
|
|
||||||
if not isinstance(match_info, AbstractMatchInfo):
|
|
||||||
raise TypeError(
|
|
||||||
"match_info should be AbstractMatchInfo "
|
|
||||||
"instance, not {!r}".format(match_info)
|
|
||||||
)
|
|
||||||
match_info.add_app(self)
|
|
||||||
|
|
||||||
match_info.freeze()
|
|
||||||
|
|
||||||
request._match_info = match_info
|
|
||||||
|
|
||||||
if request.headers.get(hdrs.EXPECT):
|
|
||||||
resp = await match_info.expect_handler(request)
|
|
||||||
await request.writer.drain()
|
|
||||||
if resp is not None:
|
|
||||||
return resp
|
|
||||||
|
|
||||||
handler = match_info.handler
|
|
||||||
|
|
||||||
if self._run_middlewares:
|
|
||||||
# If its a SystemRoute, don't cache building the middlewares since
|
|
||||||
# they are constructed for every MatchInfoError as a new handler
|
|
||||||
# is made each time.
|
|
||||||
if not self._has_legacy_middlewares and not isinstance(
|
|
||||||
match_info.route, SystemRoute
|
|
||||||
):
|
|
||||||
handler = _cached_build_middleware(handler, match_info.apps)
|
|
||||||
else:
|
|
||||||
for app in match_info.apps[::-1]:
|
|
||||||
for m, new_style in app._middlewares_handlers: # type: ignore[union-attr]
|
|
||||||
if new_style:
|
|
||||||
handler = update_wrapper(
|
|
||||||
partial(m, handler=handler), handler
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
handler = await m(app, handler) # type: ignore[arg-type,assignment]
|
|
||||||
|
|
||||||
return await handler(request)
|
|
||||||
|
|
||||||
def __call__(self) -> "Application":
|
|
||||||
"""gunicorn compatibility"""
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return f"<Application 0x{id(self):x}>"
|
|
||||||
|
|
||||||
def __bool__(self) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
class CleanupError(RuntimeError):
|
|
||||||
@property
|
|
||||||
def exceptions(self) -> List[BaseException]:
|
|
||||||
return cast(List[BaseException], self.args[1])
|
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
_CleanupContextBase = FrozenList[Callable[[Application], AsyncIterator[None]]]
|
|
||||||
else:
|
|
||||||
_CleanupContextBase = FrozenList
|
|
||||||
|
|
||||||
|
|
||||||
class CleanupContext(_CleanupContextBase):
|
|
||||||
def __init__(self) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self._exits: List[AsyncIterator[None]] = []
|
|
||||||
|
|
||||||
async def _on_startup(self, app: Application) -> None:
|
|
||||||
for cb in self:
|
|
||||||
it = cb(app).__aiter__()
|
|
||||||
await it.__anext__()
|
|
||||||
self._exits.append(it)
|
|
||||||
|
|
||||||
async def _on_cleanup(self, app: Application) -> None:
|
|
||||||
errors = []
|
|
||||||
for it in reversed(self._exits):
|
|
||||||
try:
|
|
||||||
await it.__anext__()
|
|
||||||
except StopAsyncIteration:
|
|
||||||
pass
|
|
||||||
except (Exception, asyncio.CancelledError) as exc:
|
|
||||||
errors.append(exc)
|
|
||||||
else:
|
|
||||||
errors.append(RuntimeError(f"{it!r} has more than one 'yield'"))
|
|
||||||
if errors:
|
|
||||||
if len(errors) == 1:
|
|
||||||
raise errors[0]
|
|
||||||
else:
|
|
||||||
raise CleanupError("Multiple errors on cleanup stage", errors)
|
|
||||||
|
|
@ -1,452 +0,0 @@
|
||||||
import warnings
|
|
||||||
from typing import Any, Dict, Iterable, List, Optional, Set # noqa
|
|
||||||
|
|
||||||
from yarl import URL
|
|
||||||
|
|
||||||
from .typedefs import LooseHeaders, StrOrURL
|
|
||||||
from .web_response import Response
|
|
||||||
|
|
||||||
__all__ = (
|
|
||||||
"HTTPException",
|
|
||||||
"HTTPError",
|
|
||||||
"HTTPRedirection",
|
|
||||||
"HTTPSuccessful",
|
|
||||||
"HTTPOk",
|
|
||||||
"HTTPCreated",
|
|
||||||
"HTTPAccepted",
|
|
||||||
"HTTPNonAuthoritativeInformation",
|
|
||||||
"HTTPNoContent",
|
|
||||||
"HTTPResetContent",
|
|
||||||
"HTTPPartialContent",
|
|
||||||
"HTTPMove",
|
|
||||||
"HTTPMultipleChoices",
|
|
||||||
"HTTPMovedPermanently",
|
|
||||||
"HTTPFound",
|
|
||||||
"HTTPSeeOther",
|
|
||||||
"HTTPNotModified",
|
|
||||||
"HTTPUseProxy",
|
|
||||||
"HTTPTemporaryRedirect",
|
|
||||||
"HTTPPermanentRedirect",
|
|
||||||
"HTTPClientError",
|
|
||||||
"HTTPBadRequest",
|
|
||||||
"HTTPUnauthorized",
|
|
||||||
"HTTPPaymentRequired",
|
|
||||||
"HTTPForbidden",
|
|
||||||
"HTTPNotFound",
|
|
||||||
"HTTPMethodNotAllowed",
|
|
||||||
"HTTPNotAcceptable",
|
|
||||||
"HTTPProxyAuthenticationRequired",
|
|
||||||
"HTTPRequestTimeout",
|
|
||||||
"HTTPConflict",
|
|
||||||
"HTTPGone",
|
|
||||||
"HTTPLengthRequired",
|
|
||||||
"HTTPPreconditionFailed",
|
|
||||||
"HTTPRequestEntityTooLarge",
|
|
||||||
"HTTPRequestURITooLong",
|
|
||||||
"HTTPUnsupportedMediaType",
|
|
||||||
"HTTPRequestRangeNotSatisfiable",
|
|
||||||
"HTTPExpectationFailed",
|
|
||||||
"HTTPMisdirectedRequest",
|
|
||||||
"HTTPUnprocessableEntity",
|
|
||||||
"HTTPFailedDependency",
|
|
||||||
"HTTPUpgradeRequired",
|
|
||||||
"HTTPPreconditionRequired",
|
|
||||||
"HTTPTooManyRequests",
|
|
||||||
"HTTPRequestHeaderFieldsTooLarge",
|
|
||||||
"HTTPUnavailableForLegalReasons",
|
|
||||||
"HTTPServerError",
|
|
||||||
"HTTPInternalServerError",
|
|
||||||
"HTTPNotImplemented",
|
|
||||||
"HTTPBadGateway",
|
|
||||||
"HTTPServiceUnavailable",
|
|
||||||
"HTTPGatewayTimeout",
|
|
||||||
"HTTPVersionNotSupported",
|
|
||||||
"HTTPVariantAlsoNegotiates",
|
|
||||||
"HTTPInsufficientStorage",
|
|
||||||
"HTTPNotExtended",
|
|
||||||
"HTTPNetworkAuthenticationRequired",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class NotAppKeyWarning(UserWarning):
|
|
||||||
"""Warning when not using AppKey in Application."""
|
|
||||||
|
|
||||||
|
|
||||||
############################################################
|
|
||||||
# HTTP Exceptions
|
|
||||||
############################################################
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPException(Response, Exception):
|
|
||||||
|
|
||||||
# You should set in subclasses:
|
|
||||||
# status = 200
|
|
||||||
|
|
||||||
status_code = -1
|
|
||||||
empty_body = False
|
|
||||||
|
|
||||||
__http_exception__ = True
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
headers: Optional[LooseHeaders] = None,
|
|
||||||
reason: Optional[str] = None,
|
|
||||||
body: Any = None,
|
|
||||||
text: Optional[str] = None,
|
|
||||||
content_type: Optional[str] = None,
|
|
||||||
) -> None:
|
|
||||||
if body is not None:
|
|
||||||
warnings.warn(
|
|
||||||
"body argument is deprecated for http web exceptions",
|
|
||||||
DeprecationWarning,
|
|
||||||
)
|
|
||||||
Response.__init__(
|
|
||||||
self,
|
|
||||||
status=self.status_code,
|
|
||||||
headers=headers,
|
|
||||||
reason=reason,
|
|
||||||
body=body,
|
|
||||||
text=text,
|
|
||||||
content_type=content_type,
|
|
||||||
)
|
|
||||||
Exception.__init__(self, self.reason)
|
|
||||||
if self.body is None and not self.empty_body:
|
|
||||||
self.text = f"{self.status}: {self.reason}"
|
|
||||||
|
|
||||||
def __bool__(self) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPError(HTTPException):
|
|
||||||
"""Base class for exceptions with status codes in the 400s and 500s."""
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPRedirection(HTTPException):
|
|
||||||
"""Base class for exceptions with status codes in the 300s."""
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPSuccessful(HTTPException):
|
|
||||||
"""Base class for exceptions with status codes in the 200s."""
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPOk(HTTPSuccessful):
|
|
||||||
status_code = 200
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPCreated(HTTPSuccessful):
|
|
||||||
status_code = 201
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPAccepted(HTTPSuccessful):
|
|
||||||
status_code = 202
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPNonAuthoritativeInformation(HTTPSuccessful):
|
|
||||||
status_code = 203
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPNoContent(HTTPSuccessful):
|
|
||||||
status_code = 204
|
|
||||||
empty_body = True
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPResetContent(HTTPSuccessful):
|
|
||||||
status_code = 205
|
|
||||||
empty_body = True
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPPartialContent(HTTPSuccessful):
|
|
||||||
status_code = 206
|
|
||||||
|
|
||||||
|
|
||||||
############################################################
|
|
||||||
# 3xx redirection
|
|
||||||
############################################################
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPMove(HTTPRedirection):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
location: StrOrURL,
|
|
||||||
*,
|
|
||||||
headers: Optional[LooseHeaders] = None,
|
|
||||||
reason: Optional[str] = None,
|
|
||||||
body: Any = None,
|
|
||||||
text: Optional[str] = None,
|
|
||||||
content_type: Optional[str] = None,
|
|
||||||
) -> None:
|
|
||||||
if not location:
|
|
||||||
raise ValueError("HTTP redirects need a location to redirect to.")
|
|
||||||
super().__init__(
|
|
||||||
headers=headers,
|
|
||||||
reason=reason,
|
|
||||||
body=body,
|
|
||||||
text=text,
|
|
||||||
content_type=content_type,
|
|
||||||
)
|
|
||||||
self.headers["Location"] = str(URL(location))
|
|
||||||
self.location = location
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPMultipleChoices(HTTPMove):
|
|
||||||
status_code = 300
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPMovedPermanently(HTTPMove):
|
|
||||||
status_code = 301
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPFound(HTTPMove):
|
|
||||||
status_code = 302
|
|
||||||
|
|
||||||
|
|
||||||
# This one is safe after a POST (the redirected location will be
|
|
||||||
# retrieved with GET):
|
|
||||||
class HTTPSeeOther(HTTPMove):
|
|
||||||
status_code = 303
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPNotModified(HTTPRedirection):
|
|
||||||
# FIXME: this should include a date or etag header
|
|
||||||
status_code = 304
|
|
||||||
empty_body = True
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPUseProxy(HTTPMove):
|
|
||||||
# Not a move, but looks a little like one
|
|
||||||
status_code = 305
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPTemporaryRedirect(HTTPMove):
|
|
||||||
status_code = 307
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPPermanentRedirect(HTTPMove):
|
|
||||||
status_code = 308
|
|
||||||
|
|
||||||
|
|
||||||
############################################################
|
|
||||||
# 4xx client error
|
|
||||||
############################################################
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPClientError(HTTPError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPBadRequest(HTTPClientError):
|
|
||||||
status_code = 400
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPUnauthorized(HTTPClientError):
|
|
||||||
status_code = 401
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPPaymentRequired(HTTPClientError):
|
|
||||||
status_code = 402
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPForbidden(HTTPClientError):
|
|
||||||
status_code = 403
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPNotFound(HTTPClientError):
|
|
||||||
status_code = 404
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPMethodNotAllowed(HTTPClientError):
|
|
||||||
status_code = 405
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
method: str,
|
|
||||||
allowed_methods: Iterable[str],
|
|
||||||
*,
|
|
||||||
headers: Optional[LooseHeaders] = None,
|
|
||||||
reason: Optional[str] = None,
|
|
||||||
body: Any = None,
|
|
||||||
text: Optional[str] = None,
|
|
||||||
content_type: Optional[str] = None,
|
|
||||||
) -> None:
|
|
||||||
allow = ",".join(sorted(allowed_methods))
|
|
||||||
super().__init__(
|
|
||||||
headers=headers,
|
|
||||||
reason=reason,
|
|
||||||
body=body,
|
|
||||||
text=text,
|
|
||||||
content_type=content_type,
|
|
||||||
)
|
|
||||||
self.headers["Allow"] = allow
|
|
||||||
self.allowed_methods: Set[str] = set(allowed_methods)
|
|
||||||
self.method = method.upper()
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPNotAcceptable(HTTPClientError):
|
|
||||||
status_code = 406
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPProxyAuthenticationRequired(HTTPClientError):
|
|
||||||
status_code = 407
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPRequestTimeout(HTTPClientError):
|
|
||||||
status_code = 408
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPConflict(HTTPClientError):
|
|
||||||
status_code = 409
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPGone(HTTPClientError):
|
|
||||||
status_code = 410
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPLengthRequired(HTTPClientError):
|
|
||||||
status_code = 411
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPPreconditionFailed(HTTPClientError):
|
|
||||||
status_code = 412
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPRequestEntityTooLarge(HTTPClientError):
|
|
||||||
status_code = 413
|
|
||||||
|
|
||||||
def __init__(self, max_size: float, actual_size: float, **kwargs: Any) -> None:
|
|
||||||
kwargs.setdefault(
|
|
||||||
"text",
|
|
||||||
"Maximum request body size {} exceeded, "
|
|
||||||
"actual body size {}".format(max_size, actual_size),
|
|
||||||
)
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPRequestURITooLong(HTTPClientError):
|
|
||||||
status_code = 414
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPUnsupportedMediaType(HTTPClientError):
|
|
||||||
status_code = 415
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPRequestRangeNotSatisfiable(HTTPClientError):
|
|
||||||
status_code = 416
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPExpectationFailed(HTTPClientError):
|
|
||||||
status_code = 417
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPMisdirectedRequest(HTTPClientError):
|
|
||||||
status_code = 421
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPUnprocessableEntity(HTTPClientError):
|
|
||||||
status_code = 422
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPFailedDependency(HTTPClientError):
|
|
||||||
status_code = 424
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPUpgradeRequired(HTTPClientError):
|
|
||||||
status_code = 426
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPPreconditionRequired(HTTPClientError):
|
|
||||||
status_code = 428
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPTooManyRequests(HTTPClientError):
|
|
||||||
status_code = 429
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPRequestHeaderFieldsTooLarge(HTTPClientError):
|
|
||||||
status_code = 431
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPUnavailableForLegalReasons(HTTPClientError):
|
|
||||||
status_code = 451
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
link: Optional[StrOrURL],
|
|
||||||
*,
|
|
||||||
headers: Optional[LooseHeaders] = None,
|
|
||||||
reason: Optional[str] = None,
|
|
||||||
body: Any = None,
|
|
||||||
text: Optional[str] = None,
|
|
||||||
content_type: Optional[str] = None,
|
|
||||||
) -> None:
|
|
||||||
super().__init__(
|
|
||||||
headers=headers,
|
|
||||||
reason=reason,
|
|
||||||
body=body,
|
|
||||||
text=text,
|
|
||||||
content_type=content_type,
|
|
||||||
)
|
|
||||||
self._link = None
|
|
||||||
if link:
|
|
||||||
self._link = URL(link)
|
|
||||||
self.headers["Link"] = f'<{str(self._link)}>; rel="blocked-by"'
|
|
||||||
|
|
||||||
@property
|
|
||||||
def link(self) -> Optional[URL]:
|
|
||||||
return self._link
|
|
||||||
|
|
||||||
|
|
||||||
############################################################
|
|
||||||
# 5xx Server Error
|
|
||||||
############################################################
|
|
||||||
# Response status codes beginning with the digit "5" indicate cases in
|
|
||||||
# which the server is aware that it has erred or is incapable of
|
|
||||||
# performing the request. Except when responding to a HEAD request, the
|
|
||||||
# server SHOULD include an entity containing an explanation of the error
|
|
||||||
# situation, and whether it is a temporary or permanent condition. User
|
|
||||||
# agents SHOULD display any included entity to the user. These response
|
|
||||||
# codes are applicable to any request method.
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPServerError(HTTPError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPInternalServerError(HTTPServerError):
|
|
||||||
status_code = 500
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPNotImplemented(HTTPServerError):
|
|
||||||
status_code = 501
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPBadGateway(HTTPServerError):
|
|
||||||
status_code = 502
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPServiceUnavailable(HTTPServerError):
|
|
||||||
status_code = 503
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPGatewayTimeout(HTTPServerError):
|
|
||||||
status_code = 504
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPVersionNotSupported(HTTPServerError):
|
|
||||||
status_code = 505
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPVariantAlsoNegotiates(HTTPServerError):
|
|
||||||
status_code = 506
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPInsufficientStorage(HTTPServerError):
|
|
||||||
status_code = 507
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPNotExtended(HTTPServerError):
|
|
||||||
status_code = 510
|
|
||||||
|
|
||||||
|
|
||||||
class HTTPNetworkAuthenticationRequired(HTTPServerError):
|
|
||||||
status_code = 511
|
|
||||||
|
|
@ -1,418 +0,0 @@
|
||||||
import asyncio
|
|
||||||
import io
|
|
||||||
import os
|
|
||||||
import pathlib
|
|
||||||
import sys
|
|
||||||
from contextlib import suppress
|
|
||||||
from enum import Enum, auto
|
|
||||||
from mimetypes import MimeTypes
|
|
||||||
from stat import S_ISREG
|
|
||||||
from types import MappingProxyType
|
|
||||||
from typing import ( # noqa
|
|
||||||
IO,
|
|
||||||
TYPE_CHECKING,
|
|
||||||
Any,
|
|
||||||
Awaitable,
|
|
||||||
Callable,
|
|
||||||
Final,
|
|
||||||
Iterator,
|
|
||||||
List,
|
|
||||||
Optional,
|
|
||||||
Set,
|
|
||||||
Tuple,
|
|
||||||
Union,
|
|
||||||
cast,
|
|
||||||
)
|
|
||||||
|
|
||||||
from . import hdrs
|
|
||||||
from .abc import AbstractStreamWriter
|
|
||||||
from .helpers import ETAG_ANY, ETag, must_be_empty_body
|
|
||||||
from .typedefs import LooseHeaders, PathLike
|
|
||||||
from .web_exceptions import (
|
|
||||||
HTTPForbidden,
|
|
||||||
HTTPNotFound,
|
|
||||||
HTTPNotModified,
|
|
||||||
HTTPPartialContent,
|
|
||||||
HTTPPreconditionFailed,
|
|
||||||
HTTPRequestRangeNotSatisfiable,
|
|
||||||
)
|
|
||||||
from .web_response import StreamResponse
|
|
||||||
|
|
||||||
__all__ = ("FileResponse",)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from .web_request import BaseRequest
|
|
||||||
|
|
||||||
|
|
||||||
_T_OnChunkSent = Optional[Callable[[bytes], Awaitable[None]]]
|
|
||||||
|
|
||||||
|
|
||||||
NOSENDFILE: Final[bool] = bool(os.environ.get("AIOHTTP_NOSENDFILE"))
|
|
||||||
|
|
||||||
CONTENT_TYPES: Final[MimeTypes] = MimeTypes()
|
|
||||||
|
|
||||||
# File extension to IANA encodings map that will be checked in the order defined.
|
|
||||||
ENCODING_EXTENSIONS = MappingProxyType(
|
|
||||||
{ext: CONTENT_TYPES.encodings_map[ext] for ext in (".br", ".gz")}
|
|
||||||
)
|
|
||||||
|
|
||||||
FALLBACK_CONTENT_TYPE = "application/octet-stream"
|
|
||||||
|
|
||||||
# Provide additional MIME type/extension pairs to be recognized.
|
|
||||||
# https://en.wikipedia.org/wiki/List_of_archive_formats#Compression_only
|
|
||||||
ADDITIONAL_CONTENT_TYPES = MappingProxyType(
|
|
||||||
{
|
|
||||||
"application/gzip": ".gz",
|
|
||||||
"application/x-brotli": ".br",
|
|
||||||
"application/x-bzip2": ".bz2",
|
|
||||||
"application/x-compress": ".Z",
|
|
||||||
"application/x-xz": ".xz",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class _FileResponseResult(Enum):
|
|
||||||
"""The result of the file response."""
|
|
||||||
|
|
||||||
SEND_FILE = auto() # Ie a regular file to send
|
|
||||||
NOT_ACCEPTABLE = auto() # Ie a socket, or non-regular file
|
|
||||||
PRE_CONDITION_FAILED = auto() # Ie If-Match or If-None-Match failed
|
|
||||||
NOT_MODIFIED = auto() # 304 Not Modified
|
|
||||||
|
|
||||||
|
|
||||||
# Add custom pairs and clear the encodings map so guess_type ignores them.
|
|
||||||
CONTENT_TYPES.encodings_map.clear()
|
|
||||||
for content_type, extension in ADDITIONAL_CONTENT_TYPES.items():
|
|
||||||
CONTENT_TYPES.add_type(content_type, extension)
|
|
||||||
|
|
||||||
|
|
||||||
_CLOSE_FUTURES: Set[asyncio.Future[None]] = set()
|
|
||||||
|
|
||||||
|
|
||||||
class FileResponse(StreamResponse):
|
|
||||||
"""A response object can be used to send files."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
path: PathLike,
|
|
||||||
chunk_size: int = 256 * 1024,
|
|
||||||
status: int = 200,
|
|
||||||
reason: Optional[str] = None,
|
|
||||||
headers: Optional[LooseHeaders] = None,
|
|
||||||
) -> None:
|
|
||||||
super().__init__(status=status, reason=reason, headers=headers)
|
|
||||||
|
|
||||||
self._path = pathlib.Path(path)
|
|
||||||
self._chunk_size = chunk_size
|
|
||||||
|
|
||||||
def _seek_and_read(self, fobj: IO[Any], offset: int, chunk_size: int) -> bytes:
|
|
||||||
fobj.seek(offset)
|
|
||||||
return fobj.read(chunk_size) # type: ignore[no-any-return]
|
|
||||||
|
|
||||||
async def _sendfile_fallback(
|
|
||||||
self, writer: AbstractStreamWriter, fobj: IO[Any], offset: int, count: int
|
|
||||||
) -> AbstractStreamWriter:
|
|
||||||
# To keep memory usage low,fobj is transferred in chunks
|
|
||||||
# controlled by the constructor's chunk_size argument.
|
|
||||||
|
|
||||||
chunk_size = self._chunk_size
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
chunk = await loop.run_in_executor(
|
|
||||||
None, self._seek_and_read, fobj, offset, chunk_size
|
|
||||||
)
|
|
||||||
while chunk:
|
|
||||||
await writer.write(chunk)
|
|
||||||
count = count - chunk_size
|
|
||||||
if count <= 0:
|
|
||||||
break
|
|
||||||
chunk = await loop.run_in_executor(None, fobj.read, min(chunk_size, count))
|
|
||||||
|
|
||||||
await writer.drain()
|
|
||||||
return writer
|
|
||||||
|
|
||||||
async def _sendfile(
|
|
||||||
self, request: "BaseRequest", fobj: IO[Any], offset: int, count: int
|
|
||||||
) -> AbstractStreamWriter:
|
|
||||||
writer = await super().prepare(request)
|
|
||||||
assert writer is not None
|
|
||||||
|
|
||||||
if NOSENDFILE or self.compression:
|
|
||||||
return await self._sendfile_fallback(writer, fobj, offset, count)
|
|
||||||
|
|
||||||
loop = request._loop
|
|
||||||
transport = request.transport
|
|
||||||
assert transport is not None
|
|
||||||
|
|
||||||
try:
|
|
||||||
await loop.sendfile(transport, fobj, offset, count)
|
|
||||||
except NotImplementedError:
|
|
||||||
return await self._sendfile_fallback(writer, fobj, offset, count)
|
|
||||||
|
|
||||||
await super().write_eof()
|
|
||||||
return writer
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _etag_match(etag_value: str, etags: Tuple[ETag, ...], *, weak: bool) -> bool:
|
|
||||||
if len(etags) == 1 and etags[0].value == ETAG_ANY:
|
|
||||||
return True
|
|
||||||
return any(
|
|
||||||
etag.value == etag_value for etag in etags if weak or not etag.is_weak
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _not_modified(
|
|
||||||
self, request: "BaseRequest", etag_value: str, last_modified: float
|
|
||||||
) -> Optional[AbstractStreamWriter]:
|
|
||||||
self.set_status(HTTPNotModified.status_code)
|
|
||||||
self._length_check = False
|
|
||||||
self.etag = etag_value
|
|
||||||
self.last_modified = last_modified
|
|
||||||
# Delete any Content-Length headers provided by user. HTTP 304
|
|
||||||
# should always have empty response body
|
|
||||||
return await super().prepare(request)
|
|
||||||
|
|
||||||
async def _precondition_failed(
|
|
||||||
self, request: "BaseRequest"
|
|
||||||
) -> Optional[AbstractStreamWriter]:
|
|
||||||
self.set_status(HTTPPreconditionFailed.status_code)
|
|
||||||
self.content_length = 0
|
|
||||||
return await super().prepare(request)
|
|
||||||
|
|
||||||
def _make_response(
|
|
||||||
self, request: "BaseRequest", accept_encoding: str
|
|
||||||
) -> Tuple[
|
|
||||||
_FileResponseResult, Optional[io.BufferedReader], os.stat_result, Optional[str]
|
|
||||||
]:
|
|
||||||
"""Return the response result, io object, stat result, and encoding.
|
|
||||||
|
|
||||||
If an uncompressed file is returned, the encoding is set to
|
|
||||||
:py:data:`None`.
|
|
||||||
|
|
||||||
This method should be called from a thread executor
|
|
||||||
since it calls os.stat which may block.
|
|
||||||
"""
|
|
||||||
file_path, st, file_encoding = self._get_file_path_stat_encoding(
|
|
||||||
accept_encoding
|
|
||||||
)
|
|
||||||
if not file_path:
|
|
||||||
return _FileResponseResult.NOT_ACCEPTABLE, None, st, None
|
|
||||||
|
|
||||||
etag_value = f"{st.st_mtime_ns:x}-{st.st_size:x}"
|
|
||||||
|
|
||||||
# https://www.rfc-editor.org/rfc/rfc9110#section-13.1.1-2
|
|
||||||
if (ifmatch := request.if_match) is not None and not self._etag_match(
|
|
||||||
etag_value, ifmatch, weak=False
|
|
||||||
):
|
|
||||||
return _FileResponseResult.PRE_CONDITION_FAILED, None, st, file_encoding
|
|
||||||
|
|
||||||
if (
|
|
||||||
(unmodsince := request.if_unmodified_since) is not None
|
|
||||||
and ifmatch is None
|
|
||||||
and st.st_mtime > unmodsince.timestamp()
|
|
||||||
):
|
|
||||||
return _FileResponseResult.PRE_CONDITION_FAILED, None, st, file_encoding
|
|
||||||
|
|
||||||
# https://www.rfc-editor.org/rfc/rfc9110#section-13.1.2-2
|
|
||||||
if (ifnonematch := request.if_none_match) is not None and self._etag_match(
|
|
||||||
etag_value, ifnonematch, weak=True
|
|
||||||
):
|
|
||||||
return _FileResponseResult.NOT_MODIFIED, None, st, file_encoding
|
|
||||||
|
|
||||||
if (
|
|
||||||
(modsince := request.if_modified_since) is not None
|
|
||||||
and ifnonematch is None
|
|
||||||
and st.st_mtime <= modsince.timestamp()
|
|
||||||
):
|
|
||||||
return _FileResponseResult.NOT_MODIFIED, None, st, file_encoding
|
|
||||||
|
|
||||||
fobj = file_path.open("rb")
|
|
||||||
with suppress(OSError):
|
|
||||||
# fstat() may not be available on all platforms
|
|
||||||
# Once we open the file, we want the fstat() to ensure
|
|
||||||
# the file has not changed between the first stat()
|
|
||||||
# and the open().
|
|
||||||
st = os.stat(fobj.fileno())
|
|
||||||
return _FileResponseResult.SEND_FILE, fobj, st, file_encoding
|
|
||||||
|
|
||||||
def _get_file_path_stat_encoding(
|
|
||||||
self, accept_encoding: str
|
|
||||||
) -> Tuple[Optional[pathlib.Path], os.stat_result, Optional[str]]:
|
|
||||||
file_path = self._path
|
|
||||||
for file_extension, file_encoding in ENCODING_EXTENSIONS.items():
|
|
||||||
if file_encoding not in accept_encoding:
|
|
||||||
continue
|
|
||||||
|
|
||||||
compressed_path = file_path.with_suffix(file_path.suffix + file_extension)
|
|
||||||
with suppress(OSError):
|
|
||||||
# Do not follow symlinks and ignore any non-regular files.
|
|
||||||
st = compressed_path.lstat()
|
|
||||||
if S_ISREG(st.st_mode):
|
|
||||||
return compressed_path, st, file_encoding
|
|
||||||
|
|
||||||
# Fallback to the uncompressed file
|
|
||||||
st = file_path.stat()
|
|
||||||
return file_path if S_ISREG(st.st_mode) else None, st, None
|
|
||||||
|
|
||||||
async def prepare(self, request: "BaseRequest") -> Optional[AbstractStreamWriter]:
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
# Encoding comparisons should be case-insensitive
|
|
||||||
# https://www.rfc-editor.org/rfc/rfc9110#section-8.4.1
|
|
||||||
accept_encoding = request.headers.get(hdrs.ACCEPT_ENCODING, "").lower()
|
|
||||||
try:
|
|
||||||
response_result, fobj, st, file_encoding = await loop.run_in_executor(
|
|
||||||
None, self._make_response, request, accept_encoding
|
|
||||||
)
|
|
||||||
except PermissionError:
|
|
||||||
self.set_status(HTTPForbidden.status_code)
|
|
||||||
return await super().prepare(request)
|
|
||||||
except OSError:
|
|
||||||
# Most likely to be FileNotFoundError or OSError for circular
|
|
||||||
# symlinks in python >= 3.13, so respond with 404.
|
|
||||||
self.set_status(HTTPNotFound.status_code)
|
|
||||||
return await super().prepare(request)
|
|
||||||
|
|
||||||
# Forbid special files like sockets, pipes, devices, etc.
|
|
||||||
if response_result is _FileResponseResult.NOT_ACCEPTABLE:
|
|
||||||
self.set_status(HTTPForbidden.status_code)
|
|
||||||
return await super().prepare(request)
|
|
||||||
|
|
||||||
if response_result is _FileResponseResult.PRE_CONDITION_FAILED:
|
|
||||||
return await self._precondition_failed(request)
|
|
||||||
|
|
||||||
if response_result is _FileResponseResult.NOT_MODIFIED:
|
|
||||||
etag_value = f"{st.st_mtime_ns:x}-{st.st_size:x}"
|
|
||||||
last_modified = st.st_mtime
|
|
||||||
return await self._not_modified(request, etag_value, last_modified)
|
|
||||||
|
|
||||||
assert fobj is not None
|
|
||||||
try:
|
|
||||||
return await self._prepare_open_file(request, fobj, st, file_encoding)
|
|
||||||
finally:
|
|
||||||
# We do not await here because we do not want to wait
|
|
||||||
# for the executor to finish before returning the response
|
|
||||||
# so the connection can begin servicing another request
|
|
||||||
# as soon as possible.
|
|
||||||
close_future = loop.run_in_executor(None, fobj.close)
|
|
||||||
# Hold a strong reference to the future to prevent it from being
|
|
||||||
# garbage collected before it completes.
|
|
||||||
_CLOSE_FUTURES.add(close_future)
|
|
||||||
close_future.add_done_callback(_CLOSE_FUTURES.remove)
|
|
||||||
|
|
||||||
async def _prepare_open_file(
|
|
||||||
self,
|
|
||||||
request: "BaseRequest",
|
|
||||||
fobj: io.BufferedReader,
|
|
||||||
st: os.stat_result,
|
|
||||||
file_encoding: Optional[str],
|
|
||||||
) -> Optional[AbstractStreamWriter]:
|
|
||||||
status = self._status
|
|
||||||
file_size: int = st.st_size
|
|
||||||
file_mtime: float = st.st_mtime
|
|
||||||
count: int = file_size
|
|
||||||
start: Optional[int] = None
|
|
||||||
|
|
||||||
if (ifrange := request.if_range) is None or file_mtime <= ifrange.timestamp():
|
|
||||||
# If-Range header check:
|
|
||||||
# condition = cached date >= last modification date
|
|
||||||
# return 206 if True else 200.
|
|
||||||
# if False:
|
|
||||||
# Range header would not be processed, return 200
|
|
||||||
# if True but Range header missing
|
|
||||||
# return 200
|
|
||||||
try:
|
|
||||||
rng = request.http_range
|
|
||||||
start = rng.start
|
|
||||||
end: Optional[int] = rng.stop
|
|
||||||
except ValueError:
|
|
||||||
# https://tools.ietf.org/html/rfc7233:
|
|
||||||
# A server generating a 416 (Range Not Satisfiable) response to
|
|
||||||
# a byte-range request SHOULD send a Content-Range header field
|
|
||||||
# with an unsatisfied-range value.
|
|
||||||
# The complete-length in a 416 response indicates the current
|
|
||||||
# length of the selected representation.
|
|
||||||
#
|
|
||||||
# Will do the same below. Many servers ignore this and do not
|
|
||||||
# send a Content-Range header with HTTP 416
|
|
||||||
self._headers[hdrs.CONTENT_RANGE] = f"bytes */{file_size}"
|
|
||||||
self.set_status(HTTPRequestRangeNotSatisfiable.status_code)
|
|
||||||
return await super().prepare(request)
|
|
||||||
|
|
||||||
# If a range request has been made, convert start, end slice
|
|
||||||
# notation into file pointer offset and count
|
|
||||||
if start is not None:
|
|
||||||
if start < 0 and end is None: # return tail of file
|
|
||||||
start += file_size
|
|
||||||
if start < 0:
|
|
||||||
# if Range:bytes=-1000 in request header but file size
|
|
||||||
# is only 200, there would be trouble without this
|
|
||||||
start = 0
|
|
||||||
count = file_size - start
|
|
||||||
else:
|
|
||||||
# rfc7233:If the last-byte-pos value is
|
|
||||||
# absent, or if the value is greater than or equal to
|
|
||||||
# the current length of the representation data,
|
|
||||||
# the byte range is interpreted as the remainder
|
|
||||||
# of the representation (i.e., the server replaces the
|
|
||||||
# value of last-byte-pos with a value that is one less than
|
|
||||||
# the current length of the selected representation).
|
|
||||||
count = (
|
|
||||||
min(end if end is not None else file_size, file_size) - start
|
|
||||||
)
|
|
||||||
|
|
||||||
if start >= file_size:
|
|
||||||
# HTTP 416 should be returned in this case.
|
|
||||||
#
|
|
||||||
# According to https://tools.ietf.org/html/rfc7233:
|
|
||||||
# If a valid byte-range-set includes at least one
|
|
||||||
# byte-range-spec with a first-byte-pos that is less than
|
|
||||||
# the current length of the representation, or at least one
|
|
||||||
# suffix-byte-range-spec with a non-zero suffix-length,
|
|
||||||
# then the byte-range-set is satisfiable. Otherwise, the
|
|
||||||
# byte-range-set is unsatisfiable.
|
|
||||||
self._headers[hdrs.CONTENT_RANGE] = f"bytes */{file_size}"
|
|
||||||
self.set_status(HTTPRequestRangeNotSatisfiable.status_code)
|
|
||||||
return await super().prepare(request)
|
|
||||||
|
|
||||||
status = HTTPPartialContent.status_code
|
|
||||||
# Even though you are sending the whole file, you should still
|
|
||||||
# return a HTTP 206 for a Range request.
|
|
||||||
self.set_status(status)
|
|
||||||
|
|
||||||
# If the Content-Type header is not already set, guess it based on the
|
|
||||||
# extension of the request path. The encoding returned by guess_type
|
|
||||||
# can be ignored since the map was cleared above.
|
|
||||||
if hdrs.CONTENT_TYPE not in self._headers:
|
|
||||||
if sys.version_info >= (3, 13):
|
|
||||||
guesser = CONTENT_TYPES.guess_file_type
|
|
||||||
else:
|
|
||||||
guesser = CONTENT_TYPES.guess_type
|
|
||||||
self.content_type = guesser(self._path)[0] or FALLBACK_CONTENT_TYPE
|
|
||||||
|
|
||||||
if file_encoding:
|
|
||||||
self._headers[hdrs.CONTENT_ENCODING] = file_encoding
|
|
||||||
self._headers[hdrs.VARY] = hdrs.ACCEPT_ENCODING
|
|
||||||
# Disable compression if we are already sending
|
|
||||||
# a compressed file since we don't want to double
|
|
||||||
# compress.
|
|
||||||
self._compression = False
|
|
||||||
|
|
||||||
self.etag = f"{st.st_mtime_ns:x}-{st.st_size:x}"
|
|
||||||
self.last_modified = file_mtime
|
|
||||||
self.content_length = count
|
|
||||||
|
|
||||||
self._headers[hdrs.ACCEPT_RANGES] = "bytes"
|
|
||||||
|
|
||||||
if status == HTTPPartialContent.status_code:
|
|
||||||
real_start = start
|
|
||||||
assert real_start is not None
|
|
||||||
self._headers[hdrs.CONTENT_RANGE] = "bytes {}-{}/{}".format(
|
|
||||||
real_start, real_start + count - 1, file_size
|
|
||||||
)
|
|
||||||
|
|
||||||
# If we are sending 0 bytes calling sendfile() will throw a ValueError
|
|
||||||
if count == 0 or must_be_empty_body(request.method, status):
|
|
||||||
return await super().prepare(request)
|
|
||||||
|
|
||||||
# be aware that start could be None or int=0 here.
|
|
||||||
offset = start or 0
|
|
||||||
|
|
||||||
return await self._sendfile(request, fobj, offset, count)
|
|
||||||
|
|
@ -1,216 +0,0 @@
|
||||||
import datetime
|
|
||||||
import functools
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import time as time_mod
|
|
||||||
from collections import namedtuple
|
|
||||||
from typing import Any, Callable, Dict, Iterable, List, Tuple # noqa
|
|
||||||
|
|
||||||
from .abc import AbstractAccessLogger
|
|
||||||
from .web_request import BaseRequest
|
|
||||||
from .web_response import StreamResponse
|
|
||||||
|
|
||||||
KeyMethod = namedtuple("KeyMethod", "key method")
|
|
||||||
|
|
||||||
|
|
||||||
class AccessLogger(AbstractAccessLogger):
|
|
||||||
"""Helper object to log access.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
log = logging.getLogger("spam")
|
|
||||||
log_format = "%a %{User-Agent}i"
|
|
||||||
access_logger = AccessLogger(log, log_format)
|
|
||||||
access_logger.log(request, response, time)
|
|
||||||
|
|
||||||
Format:
|
|
||||||
%% The percent sign
|
|
||||||
%a Remote IP-address (IP-address of proxy if using reverse proxy)
|
|
||||||
%t Time when the request was started to process
|
|
||||||
%P The process ID of the child that serviced the request
|
|
||||||
%r First line of request
|
|
||||||
%s Response status code
|
|
||||||
%b Size of response in bytes, including HTTP headers
|
|
||||||
%T Time taken to serve the request, in seconds
|
|
||||||
%Tf Time taken to serve the request, in seconds with floating fraction
|
|
||||||
in .06f format
|
|
||||||
%D Time taken to serve the request, in microseconds
|
|
||||||
%{FOO}i request.headers['FOO']
|
|
||||||
%{FOO}o response.headers['FOO']
|
|
||||||
%{FOO}e os.environ['FOO']
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
LOG_FORMAT_MAP = {
|
|
||||||
"a": "remote_address",
|
|
||||||
"t": "request_start_time",
|
|
||||||
"P": "process_id",
|
|
||||||
"r": "first_request_line",
|
|
||||||
"s": "response_status",
|
|
||||||
"b": "response_size",
|
|
||||||
"T": "request_time",
|
|
||||||
"Tf": "request_time_frac",
|
|
||||||
"D": "request_time_micro",
|
|
||||||
"i": "request_header",
|
|
||||||
"o": "response_header",
|
|
||||||
}
|
|
||||||
|
|
||||||
LOG_FORMAT = '%a %t "%r" %s %b "%{Referer}i" "%{User-Agent}i"'
|
|
||||||
FORMAT_RE = re.compile(r"%(\{([A-Za-z0-9\-_]+)\}([ioe])|[atPrsbOD]|Tf?)")
|
|
||||||
CLEANUP_RE = re.compile(r"(%[^s])")
|
|
||||||
_FORMAT_CACHE: Dict[str, Tuple[str, List[KeyMethod]]] = {}
|
|
||||||
|
|
||||||
def __init__(self, logger: logging.Logger, log_format: str = LOG_FORMAT) -> None:
|
|
||||||
"""Initialise the logger.
|
|
||||||
|
|
||||||
logger is a logger object to be used for logging.
|
|
||||||
log_format is a string with apache compatible log format description.
|
|
||||||
|
|
||||||
"""
|
|
||||||
super().__init__(logger, log_format=log_format)
|
|
||||||
|
|
||||||
_compiled_format = AccessLogger._FORMAT_CACHE.get(log_format)
|
|
||||||
if not _compiled_format:
|
|
||||||
_compiled_format = self.compile_format(log_format)
|
|
||||||
AccessLogger._FORMAT_CACHE[log_format] = _compiled_format
|
|
||||||
|
|
||||||
self._log_format, self._methods = _compiled_format
|
|
||||||
|
|
||||||
def compile_format(self, log_format: str) -> Tuple[str, List[KeyMethod]]:
|
|
||||||
"""Translate log_format into form usable by modulo formatting
|
|
||||||
|
|
||||||
All known atoms will be replaced with %s
|
|
||||||
Also methods for formatting of those atoms will be added to
|
|
||||||
_methods in appropriate order
|
|
||||||
|
|
||||||
For example we have log_format = "%a %t"
|
|
||||||
This format will be translated to "%s %s"
|
|
||||||
Also contents of _methods will be
|
|
||||||
[self._format_a, self._format_t]
|
|
||||||
These method will be called and results will be passed
|
|
||||||
to translated string format.
|
|
||||||
|
|
||||||
Each _format_* method receive 'args' which is list of arguments
|
|
||||||
given to self.log
|
|
||||||
|
|
||||||
Exceptions are _format_e, _format_i and _format_o methods which
|
|
||||||
also receive key name (by functools.partial)
|
|
||||||
|
|
||||||
"""
|
|
||||||
# list of (key, method) tuples, we don't use an OrderedDict as users
|
|
||||||
# can repeat the same key more than once
|
|
||||||
methods = list()
|
|
||||||
|
|
||||||
for atom in self.FORMAT_RE.findall(log_format):
|
|
||||||
if atom[1] == "":
|
|
||||||
format_key1 = self.LOG_FORMAT_MAP[atom[0]]
|
|
||||||
m = getattr(AccessLogger, "_format_%s" % atom[0])
|
|
||||||
key_method = KeyMethod(format_key1, m)
|
|
||||||
else:
|
|
||||||
format_key2 = (self.LOG_FORMAT_MAP[atom[2]], atom[1])
|
|
||||||
m = getattr(AccessLogger, "_format_%s" % atom[2])
|
|
||||||
key_method = KeyMethod(format_key2, functools.partial(m, atom[1]))
|
|
||||||
|
|
||||||
methods.append(key_method)
|
|
||||||
|
|
||||||
log_format = self.FORMAT_RE.sub(r"%s", log_format)
|
|
||||||
log_format = self.CLEANUP_RE.sub(r"%\1", log_format)
|
|
||||||
return log_format, methods
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _format_i(
|
|
||||||
key: str, request: BaseRequest, response: StreamResponse, time: float
|
|
||||||
) -> str:
|
|
||||||
if request is None:
|
|
||||||
return "(no headers)"
|
|
||||||
|
|
||||||
# suboptimal, make istr(key) once
|
|
||||||
return request.headers.get(key, "-")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _format_o(
|
|
||||||
key: str, request: BaseRequest, response: StreamResponse, time: float
|
|
||||||
) -> str:
|
|
||||||
# suboptimal, make istr(key) once
|
|
||||||
return response.headers.get(key, "-")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _format_a(request: BaseRequest, response: StreamResponse, time: float) -> str:
|
|
||||||
if request is None:
|
|
||||||
return "-"
|
|
||||||
ip = request.remote
|
|
||||||
return ip if ip is not None else "-"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _format_t(request: BaseRequest, response: StreamResponse, time: float) -> str:
|
|
||||||
tz = datetime.timezone(datetime.timedelta(seconds=-time_mod.timezone))
|
|
||||||
now = datetime.datetime.now(tz)
|
|
||||||
start_time = now - datetime.timedelta(seconds=time)
|
|
||||||
return start_time.strftime("[%d/%b/%Y:%H:%M:%S %z]")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _format_P(request: BaseRequest, response: StreamResponse, time: float) -> str:
|
|
||||||
return "<%s>" % os.getpid()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _format_r(request: BaseRequest, response: StreamResponse, time: float) -> str:
|
|
||||||
if request is None:
|
|
||||||
return "-"
|
|
||||||
return "{} {} HTTP/{}.{}".format(
|
|
||||||
request.method,
|
|
||||||
request.path_qs,
|
|
||||||
request.version.major,
|
|
||||||
request.version.minor,
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _format_s(request: BaseRequest, response: StreamResponse, time: float) -> int:
|
|
||||||
return response.status
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _format_b(request: BaseRequest, response: StreamResponse, time: float) -> int:
|
|
||||||
return response.body_length
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _format_T(request: BaseRequest, response: StreamResponse, time: float) -> str:
|
|
||||||
return str(round(time))
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _format_Tf(request: BaseRequest, response: StreamResponse, time: float) -> str:
|
|
||||||
return "%06f" % time
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _format_D(request: BaseRequest, response: StreamResponse, time: float) -> str:
|
|
||||||
return str(round(time * 1000000))
|
|
||||||
|
|
||||||
def _format_line(
|
|
||||||
self, request: BaseRequest, response: StreamResponse, time: float
|
|
||||||
) -> Iterable[Tuple[str, Callable[[BaseRequest, StreamResponse, float], str]]]:
|
|
||||||
return [(key, method(request, response, time)) for key, method in self._methods]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def enabled(self) -> bool:
|
|
||||||
"""Check if logger is enabled."""
|
|
||||||
# Avoid formatting the log line if it will not be emitted.
|
|
||||||
return self.logger.isEnabledFor(logging.INFO)
|
|
||||||
|
|
||||||
def log(self, request: BaseRequest, response: StreamResponse, time: float) -> None:
|
|
||||||
try:
|
|
||||||
fmt_info = self._format_line(request, response, time)
|
|
||||||
|
|
||||||
values = list()
|
|
||||||
extra = dict()
|
|
||||||
for key, value in fmt_info:
|
|
||||||
values.append(value)
|
|
||||||
|
|
||||||
if key.__class__ is str:
|
|
||||||
extra[key] = value
|
|
||||||
else:
|
|
||||||
k1, k2 = key # type: ignore[misc]
|
|
||||||
dct = extra.get(k1, {}) # type: ignore[var-annotated,has-type]
|
|
||||||
dct[k2] = value # type: ignore[index,has-type]
|
|
||||||
extra[k1] = dct # type: ignore[has-type,assignment]
|
|
||||||
|
|
||||||
self.logger.info(self._log_format % tuple(values), extra=extra)
|
|
||||||
except Exception:
|
|
||||||
self.logger.exception("Error in logging")
|
|
||||||
|
|
@ -1,121 +0,0 @@
|
||||||
import re
|
|
||||||
from typing import TYPE_CHECKING, Tuple, Type, TypeVar
|
|
||||||
|
|
||||||
from .typedefs import Handler, Middleware
|
|
||||||
from .web_exceptions import HTTPMove, HTTPPermanentRedirect
|
|
||||||
from .web_request import Request
|
|
||||||
from .web_response import StreamResponse
|
|
||||||
from .web_urldispatcher import SystemRoute
|
|
||||||
|
|
||||||
__all__ = (
|
|
||||||
"middleware",
|
|
||||||
"normalize_path_middleware",
|
|
||||||
)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from .web_app import Application
|
|
||||||
|
|
||||||
_Func = TypeVar("_Func")
|
|
||||||
|
|
||||||
|
|
||||||
async def _check_request_resolves(request: Request, path: str) -> Tuple[bool, Request]:
|
|
||||||
alt_request = request.clone(rel_url=path)
|
|
||||||
|
|
||||||
match_info = await request.app.router.resolve(alt_request)
|
|
||||||
alt_request._match_info = match_info
|
|
||||||
|
|
||||||
if match_info.http_exception is None:
|
|
||||||
return True, alt_request
|
|
||||||
|
|
||||||
return False, request
|
|
||||||
|
|
||||||
|
|
||||||
def middleware(f: _Func) -> _Func:
|
|
||||||
f.__middleware_version__ = 1 # type: ignore[attr-defined]
|
|
||||||
return f
|
|
||||||
|
|
||||||
|
|
||||||
def normalize_path_middleware(
|
|
||||||
*,
|
|
||||||
append_slash: bool = True,
|
|
||||||
remove_slash: bool = False,
|
|
||||||
merge_slashes: bool = True,
|
|
||||||
redirect_class: Type[HTTPMove] = HTTPPermanentRedirect,
|
|
||||||
) -> Middleware:
|
|
||||||
"""Factory for producing a middleware that normalizes the path of a request.
|
|
||||||
|
|
||||||
Normalizing means:
|
|
||||||
- Add or remove a trailing slash to the path.
|
|
||||||
- Double slashes are replaced by one.
|
|
||||||
|
|
||||||
The middleware returns as soon as it finds a path that resolves
|
|
||||||
correctly. The order if both merge and append/remove are enabled is
|
|
||||||
1) merge slashes
|
|
||||||
2) append/remove slash
|
|
||||||
3) both merge slashes and append/remove slash.
|
|
||||||
If the path resolves with at least one of those conditions, it will
|
|
||||||
redirect to the new path.
|
|
||||||
|
|
||||||
Only one of `append_slash` and `remove_slash` can be enabled. If both
|
|
||||||
are `True` the factory will raise an assertion error
|
|
||||||
|
|
||||||
If `append_slash` is `True` the middleware will append a slash when
|
|
||||||
needed. If a resource is defined with trailing slash and the request
|
|
||||||
comes without it, it will append it automatically.
|
|
||||||
|
|
||||||
If `remove_slash` is `True`, `append_slash` must be `False`. When enabled
|
|
||||||
the middleware will remove trailing slashes and redirect if the resource
|
|
||||||
is defined
|
|
||||||
|
|
||||||
If merge_slashes is True, merge multiple consecutive slashes in the
|
|
||||||
path into one.
|
|
||||||
"""
|
|
||||||
correct_configuration = not (append_slash and remove_slash)
|
|
||||||
assert correct_configuration, "Cannot both remove and append slash"
|
|
||||||
|
|
||||||
@middleware
|
|
||||||
async def impl(request: Request, handler: Handler) -> StreamResponse:
|
|
||||||
if isinstance(request.match_info.route, SystemRoute):
|
|
||||||
paths_to_check = []
|
|
||||||
if "?" in request.raw_path:
|
|
||||||
path, query = request.raw_path.split("?", 1)
|
|
||||||
query = "?" + query
|
|
||||||
else:
|
|
||||||
query = ""
|
|
||||||
path = request.raw_path
|
|
||||||
|
|
||||||
if merge_slashes:
|
|
||||||
paths_to_check.append(re.sub("//+", "/", path))
|
|
||||||
if append_slash and not request.path.endswith("/"):
|
|
||||||
paths_to_check.append(path + "/")
|
|
||||||
if remove_slash and request.path.endswith("/"):
|
|
||||||
paths_to_check.append(path[:-1])
|
|
||||||
if merge_slashes and append_slash:
|
|
||||||
paths_to_check.append(re.sub("//+", "/", path + "/"))
|
|
||||||
if merge_slashes and remove_slash:
|
|
||||||
merged_slashes = re.sub("//+", "/", path)
|
|
||||||
paths_to_check.append(merged_slashes[:-1])
|
|
||||||
|
|
||||||
for path in paths_to_check:
|
|
||||||
path = re.sub("^//+", "/", path) # SECURITY: GHSA-v6wp-4m6f-gcjg
|
|
||||||
resolves, request = await _check_request_resolves(request, path)
|
|
||||||
if resolves:
|
|
||||||
raise redirect_class(request.raw_path + query)
|
|
||||||
|
|
||||||
return await handler(request)
|
|
||||||
|
|
||||||
return impl
|
|
||||||
|
|
||||||
|
|
||||||
def _fix_request_current_app(app: "Application") -> Middleware:
|
|
||||||
@middleware
|
|
||||||
async def impl(request: Request, handler: Handler) -> StreamResponse:
|
|
||||||
match_info = request.match_info
|
|
||||||
prev = match_info.current_app
|
|
||||||
match_info.current_app = app
|
|
||||||
try:
|
|
||||||
return await handler(request)
|
|
||||||
finally:
|
|
||||||
match_info.current_app = prev
|
|
||||||
|
|
||||||
return impl
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue