chore: migrate to gitea
This commit is contained in:
7
vendor/github.com/quic-go/qpack/.codecov.yml
generated
vendored
Normal file
7
vendor/github.com/quic-go/qpack/.codecov.yml
generated
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
coverage:
|
||||
round: nearest
|
||||
status:
|
||||
project:
|
||||
default:
|
||||
threshold: 1
|
||||
patch: false
|
||||
6
vendor/github.com/quic-go/qpack/.gitignore
generated
vendored
Normal file
6
vendor/github.com/quic-go/qpack/.gitignore
generated
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
fuzzing/*.zip
|
||||
fuzzing/coverprofile
|
||||
fuzzing/crashers
|
||||
fuzzing/sonarprofile
|
||||
fuzzing/suppressions
|
||||
fuzzing/corpus/
|
||||
3
vendor/github.com/quic-go/qpack/.gitmodules
generated
vendored
Normal file
3
vendor/github.com/quic-go/qpack/.gitmodules
generated
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
[submodule "interop/qifs"]
|
||||
path = interop/qifs
|
||||
url = https://github.com/qpackers/qifs.git
|
||||
22
vendor/github.com/quic-go/qpack/.golangci.yml
generated
vendored
Normal file
22
vendor/github.com/quic-go/qpack/.golangci.yml
generated
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
version: "2"
|
||||
linters:
|
||||
default: none
|
||||
enable:
|
||||
- asciicheck
|
||||
- copyloopvar
|
||||
- exhaustive
|
||||
- govet
|
||||
- ineffassign
|
||||
- misspell
|
||||
- nolintlint
|
||||
- prealloc
|
||||
- staticcheck
|
||||
- unconvert
|
||||
- unparam
|
||||
- unused
|
||||
- usetesting
|
||||
formatters:
|
||||
enable:
|
||||
- gofmt
|
||||
- gofumpt
|
||||
- goimports
|
||||
7
vendor/github.com/quic-go/qpack/LICENSE.md
generated
vendored
Normal file
7
vendor/github.com/quic-go/qpack/LICENSE.md
generated
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
Copyright 2019 Marten Seemann
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
21
vendor/github.com/quic-go/qpack/README.md
generated
vendored
Normal file
21
vendor/github.com/quic-go/qpack/README.md
generated
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
# QPACK
|
||||
|
||||
[](https://pkg.go.dev/github.com/quic-go/qpack)
|
||||
[](https://codecov.io/gh/quic-go/qpack)
|
||||
[](https://bugs.chromium.org/p/oss-fuzz/issues/list?sort=-opened&can=1&q=proj:quic-go)
|
||||
|
||||
This is a minimal QPACK ([RFC 9204](https://datatracker.ietf.org/doc/html/rfc9204)) implementation in Go. It reuses the Huffman encoder / decoder code from the [HPACK implementation in the Go standard library](https://github.com/golang/net/tree/master/http2/hpack).
|
||||
|
||||
It is fully interoperable with other QPACK implementations (both encoders and decoders). However, it does not support the dynamic table and relies solely on the static table and string literals (including Huffman encoding), which limits compression efficiency. If you're interested in dynamic table support, please comment on [issue #33](https://github.com/quic-go/qpack/issues/33).
|
||||
|
||||
## Running the Interop Tests
|
||||
|
||||
Install the [QPACK interop files](https://github.com/qpackers/qifs/) by running
|
||||
```bash
|
||||
git submodule update --init --recursive
|
||||
```
|
||||
|
||||
Then run the tests:
|
||||
```bash
|
||||
go test -v ./interop
|
||||
```
|
||||
183
vendor/github.com/quic-go/qpack/decoder.go
generated
vendored
Normal file
183
vendor/github.com/quic-go/qpack/decoder.go
generated
vendored
Normal file
@@ -0,0 +1,183 @@
|
||||
package qpack
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"golang.org/x/net/http2/hpack"
|
||||
)
|
||||
|
||||
// An invalidIndexError is returned when decoding encounters an invalid index
|
||||
// (e.g., an index that is out of bounds for the static table).
|
||||
type invalidIndexError int
|
||||
|
||||
func (e invalidIndexError) Error() string {
|
||||
return fmt.Sprintf("invalid indexed representation index %d", int(e))
|
||||
}
|
||||
|
||||
var errNoDynamicTable = errors.New("no dynamic table")
|
||||
|
||||
// A Decoder decodes QPACK header blocks.
|
||||
// A Decoder can be reused to decode multiple header blocks on different streams
|
||||
// on the same connection (e.g., headers then trailers).
|
||||
// This will be useful when dynamic table support is added.
|
||||
type Decoder struct{}
|
||||
|
||||
// DecodeFunc is a function that decodes the next header field from a header block.
|
||||
// It should be called repeatedly until it returns io.EOF.
|
||||
// It returns io.EOF when all header fields have been decoded.
|
||||
// Any error other than io.EOF indicates a decoding error.
|
||||
type DecodeFunc func() (HeaderField, error)
|
||||
|
||||
// NewDecoder returns a new Decoder.
|
||||
func NewDecoder() *Decoder {
|
||||
return &Decoder{}
|
||||
}
|
||||
|
||||
// Decode returns a function that decodes header fields from the given header block.
|
||||
// It does not copy the slice; the caller must ensure it remains valid during decoding.
|
||||
func (d *Decoder) Decode(p []byte) DecodeFunc {
|
||||
var readRequiredInsertCount bool
|
||||
var readDeltaBase bool
|
||||
|
||||
return func() (HeaderField, error) {
|
||||
if !readRequiredInsertCount {
|
||||
requiredInsertCount, rest, err := readVarInt(8, p)
|
||||
if err != nil {
|
||||
return HeaderField{}, err
|
||||
}
|
||||
p = rest
|
||||
readRequiredInsertCount = true
|
||||
if requiredInsertCount != 0 {
|
||||
return HeaderField{}, errors.New("expected Required Insert Count to be zero")
|
||||
}
|
||||
}
|
||||
|
||||
if !readDeltaBase {
|
||||
base, rest, err := readVarInt(7, p)
|
||||
if err != nil {
|
||||
return HeaderField{}, err
|
||||
}
|
||||
p = rest
|
||||
readDeltaBase = true
|
||||
if base != 0 {
|
||||
return HeaderField{}, errors.New("expected Base to be zero")
|
||||
}
|
||||
}
|
||||
|
||||
if len(p) == 0 {
|
||||
return HeaderField{}, io.EOF
|
||||
}
|
||||
|
||||
b := p[0]
|
||||
var hf HeaderField
|
||||
var rest []byte
|
||||
var err error
|
||||
switch {
|
||||
case (b & 0x80) > 0: // 1xxxxxxx
|
||||
hf, rest, err = d.parseIndexedHeaderField(p)
|
||||
case (b & 0xc0) == 0x40: // 01xxxxxx
|
||||
hf, rest, err = d.parseLiteralHeaderField(p)
|
||||
case (b & 0xe0) == 0x20: // 001xxxxx
|
||||
hf, rest, err = d.parseLiteralHeaderFieldWithoutNameReference(p)
|
||||
default:
|
||||
err = fmt.Errorf("unexpected type byte: %#x", b)
|
||||
}
|
||||
p = rest
|
||||
if err != nil {
|
||||
return HeaderField{}, err
|
||||
}
|
||||
return hf, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Decoder) parseIndexedHeaderField(buf []byte) (_ HeaderField, rest []byte, _ error) {
|
||||
if buf[0]&0x40 == 0 {
|
||||
return HeaderField{}, buf, errNoDynamicTable
|
||||
}
|
||||
index, rest, err := readVarInt(6, buf)
|
||||
if err != nil {
|
||||
return HeaderField{}, buf, err
|
||||
}
|
||||
hf, ok := d.at(index)
|
||||
if !ok {
|
||||
return HeaderField{}, buf, invalidIndexError(index)
|
||||
}
|
||||
return hf, rest, nil
|
||||
}
|
||||
|
||||
func (d *Decoder) parseLiteralHeaderField(buf []byte) (_ HeaderField, rest []byte, _ error) {
|
||||
if buf[0]&0x10 == 0 {
|
||||
return HeaderField{}, buf, errNoDynamicTable
|
||||
}
|
||||
// We don't need to check the value of the N-bit here.
|
||||
// It's only relevant when re-encoding header fields,
|
||||
// and determines whether the header field can be added to the dynamic table.
|
||||
// Since we don't support the dynamic table, we can ignore it.
|
||||
index, rest, err := readVarInt(4, buf)
|
||||
if err != nil {
|
||||
return HeaderField{}, buf, err
|
||||
}
|
||||
hf, ok := d.at(index)
|
||||
if !ok {
|
||||
return HeaderField{}, buf, invalidIndexError(index)
|
||||
}
|
||||
buf = rest
|
||||
if len(buf) == 0 {
|
||||
return HeaderField{}, buf, io.ErrUnexpectedEOF
|
||||
}
|
||||
usesHuffman := buf[0]&0x80 > 0
|
||||
val, rest, err := d.readString(rest, 7, usesHuffman)
|
||||
if err != nil {
|
||||
return HeaderField{}, rest, err
|
||||
}
|
||||
hf.Value = val
|
||||
return hf, rest, nil
|
||||
}
|
||||
|
||||
func (d *Decoder) parseLiteralHeaderFieldWithoutNameReference(buf []byte) (_ HeaderField, rest []byte, _ error) {
|
||||
usesHuffmanForName := buf[0]&0x8 > 0
|
||||
name, rest, err := d.readString(buf, 3, usesHuffmanForName)
|
||||
if err != nil {
|
||||
return HeaderField{}, rest, err
|
||||
}
|
||||
buf = rest
|
||||
if len(buf) == 0 {
|
||||
return HeaderField{}, rest, io.ErrUnexpectedEOF
|
||||
}
|
||||
usesHuffmanForVal := buf[0]&0x80 > 0
|
||||
val, rest, err := d.readString(buf, 7, usesHuffmanForVal)
|
||||
if err != nil {
|
||||
return HeaderField{}, rest, err
|
||||
}
|
||||
return HeaderField{Name: name, Value: val}, rest, nil
|
||||
}
|
||||
|
||||
func (d *Decoder) readString(buf []byte, n uint8, usesHuffman bool) (string, []byte, error) {
|
||||
l, buf, err := readVarInt(n, buf)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
if uint64(len(buf)) < l {
|
||||
return "", nil, io.ErrUnexpectedEOF
|
||||
}
|
||||
var val string
|
||||
if usesHuffman {
|
||||
val, err = hpack.HuffmanDecodeToString(buf[:l])
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
} else {
|
||||
val = string(buf[:l])
|
||||
}
|
||||
buf = buf[l:]
|
||||
return val, buf, nil
|
||||
}
|
||||
|
||||
func (d *Decoder) at(i uint64) (hf HeaderField, ok bool) {
|
||||
if i >= uint64(len(staticTableEntries)) {
|
||||
return
|
||||
}
|
||||
return staticTableEntries[i], true
|
||||
}
|
||||
95
vendor/github.com/quic-go/qpack/encoder.go
generated
vendored
Normal file
95
vendor/github.com/quic-go/qpack/encoder.go
generated
vendored
Normal file
@@ -0,0 +1,95 @@
|
||||
package qpack
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"golang.org/x/net/http2/hpack"
|
||||
)
|
||||
|
||||
// An Encoder performs QPACK encoding.
|
||||
type Encoder struct {
|
||||
wrotePrefix bool
|
||||
|
||||
w io.Writer
|
||||
buf []byte
|
||||
}
|
||||
|
||||
// NewEncoder returns a new Encoder which performs QPACK encoding. An
|
||||
// encoded data is written to w.
|
||||
func NewEncoder(w io.Writer) *Encoder {
|
||||
return &Encoder{w: w}
|
||||
}
|
||||
|
||||
// WriteField encodes f into a single Write to e's underlying Writer.
|
||||
// This function may also produce bytes for the Header Block Prefix
|
||||
// if necessary. If produced, it is done before encoding f.
|
||||
func (e *Encoder) WriteField(f HeaderField) error {
|
||||
// write the Header Block Prefix
|
||||
if !e.wrotePrefix {
|
||||
e.buf = appendVarInt(e.buf, 8, 0)
|
||||
e.buf = appendVarInt(e.buf, 7, 0)
|
||||
e.wrotePrefix = true
|
||||
}
|
||||
|
||||
idxAndVals, nameFound := encoderMap[f.Name]
|
||||
if nameFound {
|
||||
if idxAndVals.values == nil {
|
||||
if len(f.Value) == 0 {
|
||||
e.writeIndexedField(idxAndVals.idx)
|
||||
} else {
|
||||
e.writeLiteralFieldWithNameReference(&f, idxAndVals.idx)
|
||||
}
|
||||
} else {
|
||||
valIdx, valueFound := idxAndVals.values[f.Value]
|
||||
if valueFound {
|
||||
e.writeIndexedField(valIdx)
|
||||
} else {
|
||||
e.writeLiteralFieldWithNameReference(&f, idxAndVals.idx)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
e.writeLiteralFieldWithoutNameReference(f)
|
||||
}
|
||||
|
||||
_, err := e.w.Write(e.buf)
|
||||
e.buf = e.buf[:0]
|
||||
return err
|
||||
}
|
||||
|
||||
// Close declares that the encoding is complete and resets the Encoder
|
||||
// to be reused again for a new header block.
|
||||
func (e *Encoder) Close() error {
|
||||
e.wrotePrefix = false
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Encoder) writeLiteralFieldWithoutNameReference(f HeaderField) {
|
||||
offset := len(e.buf)
|
||||
e.buf = appendVarInt(e.buf, 3, hpack.HuffmanEncodeLength(f.Name))
|
||||
e.buf[offset] ^= 0x20 ^ 0x8
|
||||
e.buf = hpack.AppendHuffmanString(e.buf, f.Name)
|
||||
offset = len(e.buf)
|
||||
e.buf = appendVarInt(e.buf, 7, hpack.HuffmanEncodeLength(f.Value))
|
||||
e.buf[offset] ^= 0x80
|
||||
e.buf = hpack.AppendHuffmanString(e.buf, f.Value)
|
||||
}
|
||||
|
||||
// Encodes a header field whose name is present in one of the tables.
|
||||
func (e *Encoder) writeLiteralFieldWithNameReference(f *HeaderField, id uint8) {
|
||||
offset := len(e.buf)
|
||||
e.buf = appendVarInt(e.buf, 4, uint64(id))
|
||||
// Set the 01NTxxxx pattern, forcing N to 0 and T to 1
|
||||
e.buf[offset] ^= 0x50
|
||||
offset = len(e.buf)
|
||||
e.buf = appendVarInt(e.buf, 7, hpack.HuffmanEncodeLength(f.Value))
|
||||
e.buf[offset] ^= 0x80
|
||||
e.buf = hpack.AppendHuffmanString(e.buf, f.Value)
|
||||
}
|
||||
|
||||
// Encodes an indexed field, meaning it's entirely defined in one of the tables.
|
||||
func (e *Encoder) writeIndexedField(id uint8) {
|
||||
offset := len(e.buf)
|
||||
e.buf = appendVarInt(e.buf, 6, uint64(id))
|
||||
// Set the 1Txxxxxx pattern, forcing T to 1
|
||||
e.buf[offset] ^= 0xc0
|
||||
}
|
||||
16
vendor/github.com/quic-go/qpack/header_field.go
generated
vendored
Normal file
16
vendor/github.com/quic-go/qpack/header_field.go
generated
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
package qpack
|
||||
|
||||
// A HeaderField is a name-value pair. Both the name and value are
|
||||
// treated as opaque sequences of octets.
|
||||
type HeaderField struct {
|
||||
Name string
|
||||
Value string
|
||||
}
|
||||
|
||||
// IsPseudo reports whether the header field is an HTTP3 pseudo header.
|
||||
// That is, it reports whether it starts with a colon.
|
||||
// It is not otherwise guaranteed to be a valid pseudo header field,
|
||||
// though.
|
||||
func (hf HeaderField) IsPseudo() bool {
|
||||
return len(hf.Name) != 0 && hf.Name[0] == ':'
|
||||
}
|
||||
255
vendor/github.com/quic-go/qpack/static_table.go
generated
vendored
Normal file
255
vendor/github.com/quic-go/qpack/static_table.go
generated
vendored
Normal file
@@ -0,0 +1,255 @@
|
||||
package qpack
|
||||
|
||||
var staticTableEntries = [...]HeaderField{
|
||||
{Name: ":authority"},
|
||||
{Name: ":path", Value: "/"},
|
||||
{Name: "age", Value: "0"},
|
||||
{Name: "content-disposition"},
|
||||
{Name: "content-length", Value: "0"},
|
||||
{Name: "cookie"},
|
||||
{Name: "date"},
|
||||
{Name: "etag"},
|
||||
{Name: "if-modified-since"},
|
||||
{Name: "if-none-match"},
|
||||
{Name: "last-modified"},
|
||||
{Name: "link"},
|
||||
{Name: "location"},
|
||||
{Name: "referer"},
|
||||
{Name: "set-cookie"},
|
||||
{Name: ":method", Value: "CONNECT"},
|
||||
{Name: ":method", Value: "DELETE"},
|
||||
{Name: ":method", Value: "GET"},
|
||||
{Name: ":method", Value: "HEAD"},
|
||||
{Name: ":method", Value: "OPTIONS"},
|
||||
{Name: ":method", Value: "POST"},
|
||||
{Name: ":method", Value: "PUT"},
|
||||
{Name: ":scheme", Value: "http"},
|
||||
{Name: ":scheme", Value: "https"},
|
||||
{Name: ":status", Value: "103"},
|
||||
{Name: ":status", Value: "200"},
|
||||
{Name: ":status", Value: "304"},
|
||||
{Name: ":status", Value: "404"},
|
||||
{Name: ":status", Value: "503"},
|
||||
{Name: "accept", Value: "*/*"},
|
||||
{Name: "accept", Value: "application/dns-message"},
|
||||
{Name: "accept-encoding", Value: "gzip, deflate, br"},
|
||||
{Name: "accept-ranges", Value: "bytes"},
|
||||
{Name: "access-control-allow-headers", Value: "cache-control"},
|
||||
{Name: "access-control-allow-headers", Value: "content-type"},
|
||||
{Name: "access-control-allow-origin", Value: "*"},
|
||||
{Name: "cache-control", Value: "max-age=0"},
|
||||
{Name: "cache-control", Value: "max-age=2592000"},
|
||||
{Name: "cache-control", Value: "max-age=604800"},
|
||||
{Name: "cache-control", Value: "no-cache"},
|
||||
{Name: "cache-control", Value: "no-store"},
|
||||
{Name: "cache-control", Value: "public, max-age=31536000"},
|
||||
{Name: "content-encoding", Value: "br"},
|
||||
{Name: "content-encoding", Value: "gzip"},
|
||||
{Name: "content-type", Value: "application/dns-message"},
|
||||
{Name: "content-type", Value: "application/javascript"},
|
||||
{Name: "content-type", Value: "application/json"},
|
||||
{Name: "content-type", Value: "application/x-www-form-urlencoded"},
|
||||
{Name: "content-type", Value: "image/gif"},
|
||||
{Name: "content-type", Value: "image/jpeg"},
|
||||
{Name: "content-type", Value: "image/png"},
|
||||
{Name: "content-type", Value: "text/css"},
|
||||
{Name: "content-type", Value: "text/html; charset=utf-8"},
|
||||
{Name: "content-type", Value: "text/plain"},
|
||||
{Name: "content-type", Value: "text/plain;charset=utf-8"},
|
||||
{Name: "range", Value: "bytes=0-"},
|
||||
{Name: "strict-transport-security", Value: "max-age=31536000"},
|
||||
{Name: "strict-transport-security", Value: "max-age=31536000; includesubdomains"},
|
||||
{Name: "strict-transport-security", Value: "max-age=31536000; includesubdomains; preload"},
|
||||
{Name: "vary", Value: "accept-encoding"},
|
||||
{Name: "vary", Value: "origin"},
|
||||
{Name: "x-content-type-options", Value: "nosniff"},
|
||||
{Name: "x-xss-protection", Value: "1; mode=block"},
|
||||
{Name: ":status", Value: "100"},
|
||||
{Name: ":status", Value: "204"},
|
||||
{Name: ":status", Value: "206"},
|
||||
{Name: ":status", Value: "302"},
|
||||
{Name: ":status", Value: "400"},
|
||||
{Name: ":status", Value: "403"},
|
||||
{Name: ":status", Value: "421"},
|
||||
{Name: ":status", Value: "425"},
|
||||
{Name: ":status", Value: "500"},
|
||||
{Name: "accept-language"},
|
||||
{Name: "access-control-allow-credentials", Value: "FALSE"},
|
||||
{Name: "access-control-allow-credentials", Value: "TRUE"},
|
||||
{Name: "access-control-allow-headers", Value: "*"},
|
||||
{Name: "access-control-allow-methods", Value: "get"},
|
||||
{Name: "access-control-allow-methods", Value: "get, post, options"},
|
||||
{Name: "access-control-allow-methods", Value: "options"},
|
||||
{Name: "access-control-expose-headers", Value: "content-length"},
|
||||
{Name: "access-control-request-headers", Value: "content-type"},
|
||||
{Name: "access-control-request-method", Value: "get"},
|
||||
{Name: "access-control-request-method", Value: "post"},
|
||||
{Name: "alt-svc", Value: "clear"},
|
||||
{Name: "authorization"},
|
||||
{Name: "content-security-policy", Value: "script-src 'none'; object-src 'none'; base-uri 'none'"},
|
||||
{Name: "early-data", Value: "1"},
|
||||
{Name: "expect-ct"},
|
||||
{Name: "forwarded"},
|
||||
{Name: "if-range"},
|
||||
{Name: "origin"},
|
||||
{Name: "purpose", Value: "prefetch"},
|
||||
{Name: "server"},
|
||||
{Name: "timing-allow-origin", Value: "*"},
|
||||
{Name: "upgrade-insecure-requests", Value: "1"},
|
||||
{Name: "user-agent"},
|
||||
{Name: "x-forwarded-for"},
|
||||
{Name: "x-frame-options", Value: "deny"},
|
||||
{Name: "x-frame-options", Value: "sameorigin"},
|
||||
}
|
||||
|
||||
// Only needed for tests.
|
||||
// use go:linkname to retrieve the static table.
|
||||
//
|
||||
//nolint:unused
|
||||
func getStaticTable() []HeaderField {
|
||||
return staticTableEntries[:]
|
||||
}
|
||||
|
||||
type indexAndValues struct {
|
||||
idx uint8
|
||||
values map[string]uint8
|
||||
}
|
||||
|
||||
// A map of the header names from the static table to their index in the table.
|
||||
// This is used by the encoder to quickly find if a header is in the static table
|
||||
// and what value should be used to encode it.
|
||||
// There's a second level of mapping for the headers that have some predefined
|
||||
// values in the static table.
|
||||
var encoderMap = map[string]indexAndValues{
|
||||
":authority": {0, nil},
|
||||
":path": {1, map[string]uint8{"/": 1}},
|
||||
"age": {2, map[string]uint8{"0": 2}},
|
||||
"content-disposition": {3, nil},
|
||||
"content-length": {4, map[string]uint8{"0": 4}},
|
||||
"cookie": {5, nil},
|
||||
"date": {6, nil},
|
||||
"etag": {7, nil},
|
||||
"if-modified-since": {8, nil},
|
||||
"if-none-match": {9, nil},
|
||||
"last-modified": {10, nil},
|
||||
"link": {11, nil},
|
||||
"location": {12, nil},
|
||||
"referer": {13, nil},
|
||||
"set-cookie": {14, nil},
|
||||
":method": {15, map[string]uint8{
|
||||
"CONNECT": 15,
|
||||
"DELETE": 16,
|
||||
"GET": 17,
|
||||
"HEAD": 18,
|
||||
"OPTIONS": 19,
|
||||
"POST": 20,
|
||||
"PUT": 21,
|
||||
}},
|
||||
":scheme": {22, map[string]uint8{
|
||||
"http": 22,
|
||||
"https": 23,
|
||||
}},
|
||||
":status": {24, map[string]uint8{
|
||||
"103": 24,
|
||||
"200": 25,
|
||||
"304": 26,
|
||||
"404": 27,
|
||||
"503": 28,
|
||||
"100": 63,
|
||||
"204": 64,
|
||||
"206": 65,
|
||||
"302": 66,
|
||||
"400": 67,
|
||||
"403": 68,
|
||||
"421": 69,
|
||||
"425": 70,
|
||||
"500": 71,
|
||||
}},
|
||||
"accept": {29, map[string]uint8{
|
||||
"*/*": 29,
|
||||
"application/dns-message": 30,
|
||||
}},
|
||||
"accept-encoding": {31, map[string]uint8{"gzip, deflate, br": 31}},
|
||||
"accept-ranges": {32, map[string]uint8{"bytes": 32}},
|
||||
"access-control-allow-headers": {33, map[string]uint8{
|
||||
"cache-control": 33,
|
||||
"content-type": 34,
|
||||
"*": 75,
|
||||
}},
|
||||
"access-control-allow-origin": {35, map[string]uint8{"*": 35}},
|
||||
"cache-control": {36, map[string]uint8{
|
||||
"max-age=0": 36,
|
||||
"max-age=2592000": 37,
|
||||
"max-age=604800": 38,
|
||||
"no-cache": 39,
|
||||
"no-store": 40,
|
||||
"public, max-age=31536000": 41,
|
||||
}},
|
||||
"content-encoding": {42, map[string]uint8{
|
||||
"br": 42,
|
||||
"gzip": 43,
|
||||
}},
|
||||
"content-type": {44, map[string]uint8{
|
||||
"application/dns-message": 44,
|
||||
"application/javascript": 45,
|
||||
"application/json": 46,
|
||||
"application/x-www-form-urlencoded": 47,
|
||||
"image/gif": 48,
|
||||
"image/jpeg": 49,
|
||||
"image/png": 50,
|
||||
"text/css": 51,
|
||||
"text/html; charset=utf-8": 52,
|
||||
"text/plain": 53,
|
||||
"text/plain;charset=utf-8": 54,
|
||||
}},
|
||||
"range": {55, map[string]uint8{"bytes=0-": 55}},
|
||||
"strict-transport-security": {56, map[string]uint8{
|
||||
"max-age=31536000": 56,
|
||||
"max-age=31536000; includesubdomains": 57,
|
||||
"max-age=31536000; includesubdomains; preload": 58,
|
||||
}},
|
||||
"vary": {59, map[string]uint8{
|
||||
"accept-encoding": 59,
|
||||
"origin": 60,
|
||||
}},
|
||||
"x-content-type-options": {61, map[string]uint8{"nosniff": 61}},
|
||||
"x-xss-protection": {62, map[string]uint8{"1; mode=block": 62}},
|
||||
// ":status" is duplicated and takes index 63 to 71
|
||||
"accept-language": {72, nil},
|
||||
"access-control-allow-credentials": {73, map[string]uint8{
|
||||
"FALSE": 73,
|
||||
"TRUE": 74,
|
||||
}},
|
||||
// "access-control-allow-headers" is duplicated and takes index 75
|
||||
"access-control-allow-methods": {76, map[string]uint8{
|
||||
"get": 76,
|
||||
"get, post, options": 77,
|
||||
"options": 78,
|
||||
}},
|
||||
"access-control-expose-headers": {79, map[string]uint8{"content-length": 79}},
|
||||
"access-control-request-headers": {80, map[string]uint8{"content-type": 80}},
|
||||
"access-control-request-method": {81, map[string]uint8{
|
||||
"get": 81,
|
||||
"post": 82,
|
||||
}},
|
||||
"alt-svc": {83, map[string]uint8{"clear": 83}},
|
||||
"authorization": {84, nil},
|
||||
"content-security-policy": {85, map[string]uint8{
|
||||
"script-src 'none'; object-src 'none'; base-uri 'none'": 85,
|
||||
}},
|
||||
"early-data": {86, map[string]uint8{"1": 86}},
|
||||
"expect-ct": {87, nil},
|
||||
"forwarded": {88, nil},
|
||||
"if-range": {89, nil},
|
||||
"origin": {90, nil},
|
||||
"purpose": {91, map[string]uint8{"prefetch": 91}},
|
||||
"server": {92, nil},
|
||||
"timing-allow-origin": {93, map[string]uint8{"*": 93}},
|
||||
"upgrade-insecure-requests": {94, map[string]uint8{"1": 94}},
|
||||
"user-agent": {95, nil},
|
||||
"x-forwarded-for": {96, nil},
|
||||
"x-frame-options": {97, map[string]uint8{
|
||||
"deny": 97,
|
||||
"sameorigin": 98,
|
||||
}},
|
||||
}
|
||||
69
vendor/github.com/quic-go/qpack/varint.go
generated
vendored
Normal file
69
vendor/github.com/quic-go/qpack/varint.go
generated
vendored
Normal file
@@ -0,0 +1,69 @@
|
||||
package qpack
|
||||
|
||||
// copied from the Go standard library HPACK implementation
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
var errVarintOverflow = errors.New("varint integer overflow")
|
||||
|
||||
// appendVarInt appends i, as encoded in variable integer form using n
|
||||
// bit prefix, to dst and returns the extended buffer.
|
||||
//
|
||||
// See
|
||||
// http://http2.github.io/http2-spec/compression.html#integer.representation
|
||||
func appendVarInt(dst []byte, n byte, i uint64) []byte {
|
||||
k := uint64((1 << n) - 1)
|
||||
if i < k {
|
||||
return append(dst, byte(i))
|
||||
}
|
||||
dst = append(dst, byte(k))
|
||||
i -= k
|
||||
for ; i >= 128; i >>= 7 {
|
||||
dst = append(dst, byte(0x80|(i&0x7f)))
|
||||
}
|
||||
return append(dst, byte(i))
|
||||
}
|
||||
|
||||
// readVarInt reads an unsigned variable length integer off the
|
||||
// beginning of p. n is the parameter as described in
|
||||
// http://http2.github.io/http2-spec/compression.html#rfc.section.5.1.
|
||||
//
|
||||
// n must always be between 1 and 8.
|
||||
//
|
||||
// The returned remain buffer is either a smaller suffix of p, or err != nil.
|
||||
// The error is io.ErrUnexpectedEOF if p doesn't contain a complete integer.
|
||||
func readVarInt(n byte, p []byte) (i uint64, remain []byte, err error) {
|
||||
if n < 1 || n > 8 {
|
||||
panic("bad n")
|
||||
}
|
||||
if len(p) == 0 {
|
||||
return 0, p, io.ErrUnexpectedEOF
|
||||
}
|
||||
i = uint64(p[0])
|
||||
if n < 8 {
|
||||
i &= (1 << uint64(n)) - 1
|
||||
}
|
||||
if i < (1<<uint64(n))-1 {
|
||||
return i, p[1:], nil
|
||||
}
|
||||
|
||||
origP := p
|
||||
p = p[1:]
|
||||
var m uint64
|
||||
for len(p) > 0 {
|
||||
b := p[0]
|
||||
p = p[1:]
|
||||
i += uint64(b&127) << m
|
||||
if b&128 == 0 {
|
||||
return i, p, nil
|
||||
}
|
||||
m += 7
|
||||
if m >= 63 { // TODO: proper overflow check. making this up.
|
||||
return 0, origP, errVarintOverflow
|
||||
}
|
||||
}
|
||||
return 0, origP, io.ErrUnexpectedEOF
|
||||
}
|
||||
18
vendor/github.com/quic-go/quic-go/.gitignore
generated
vendored
Normal file
18
vendor/github.com/quic-go/quic-go/.gitignore
generated
vendored
Normal file
@@ -0,0 +1,18 @@
|
||||
debug
|
||||
debug.test
|
||||
main
|
||||
mockgen_tmp.go
|
||||
*.qtr
|
||||
*.qlog
|
||||
*.sqlog
|
||||
*.txt
|
||||
race.[0-9]*
|
||||
|
||||
fuzzing/*/*.zip
|
||||
fuzzing/*/coverprofile
|
||||
fuzzing/*/crashers
|
||||
fuzzing/*/sonarprofile
|
||||
fuzzing/*/suppressions
|
||||
fuzzing/*/corpus/
|
||||
|
||||
gomock_reflect_*/
|
||||
109
vendor/github.com/quic-go/quic-go/.golangci.yml
generated
vendored
Normal file
109
vendor/github.com/quic-go/quic-go/.golangci.yml
generated
vendored
Normal file
@@ -0,0 +1,109 @@
|
||||
version: "2"
|
||||
linters:
|
||||
default: none
|
||||
enable:
|
||||
- asciicheck
|
||||
- copyloopvar
|
||||
- depguard
|
||||
- exhaustive
|
||||
- govet
|
||||
- ineffassign
|
||||
- misspell
|
||||
- nolintlint
|
||||
- prealloc
|
||||
- staticcheck
|
||||
- unconvert
|
||||
- unparam
|
||||
- unused
|
||||
- usetesting
|
||||
settings:
|
||||
depguard:
|
||||
rules:
|
||||
random:
|
||||
deny:
|
||||
- pkg: "math/rand$"
|
||||
desc: use math/rand/v2
|
||||
- pkg: "golang.org/x/exp/rand"
|
||||
desc: use math/rand/v2
|
||||
quicvarint:
|
||||
list-mode: strict
|
||||
files:
|
||||
- '**/github.com/quic-go/quic-go/quicvarint/*'
|
||||
- '!$test'
|
||||
allow:
|
||||
- $gostd
|
||||
rsa:
|
||||
list-mode: original
|
||||
deny:
|
||||
- pkg: crypto/rsa
|
||||
desc: "use crypto/ed25519 instead"
|
||||
ginkgo:
|
||||
list-mode: original
|
||||
deny:
|
||||
- pkg: github.com/onsi/ginkgo
|
||||
desc: "use standard Go tests"
|
||||
- pkg: github.com/onsi/ginkgo/v2
|
||||
desc: "use standard Go tests"
|
||||
- pkg: github.com/onsi/gomega
|
||||
desc: "use standard Go tests"
|
||||
http3-internal:
|
||||
list-mode: lax
|
||||
files:
|
||||
- '**/http3/**'
|
||||
deny:
|
||||
- pkg: 'github.com/quic-go/quic-go/internal'
|
||||
desc: 'no dependency on quic-go/internal'
|
||||
allow:
|
||||
- 'github.com/quic-go/quic-go/internal/synctest'
|
||||
misspell:
|
||||
ignore-rules:
|
||||
- ect
|
||||
# see https://github.com/ldez/usetesting/issues/10
|
||||
usetesting:
|
||||
context-background: false
|
||||
context-todo: false
|
||||
exclusions:
|
||||
generated: lax
|
||||
presets:
|
||||
- comments
|
||||
- common-false-positives
|
||||
- legacy
|
||||
- std-error-handling
|
||||
rules:
|
||||
- linters:
|
||||
- depguard
|
||||
path: internal/qtls
|
||||
- linters:
|
||||
- exhaustive
|
||||
- prealloc
|
||||
- unparam
|
||||
path: _test\.go
|
||||
- linters:
|
||||
- staticcheck
|
||||
path: _test\.go
|
||||
text: 'SA1029:' # inappropriate key in call to context.WithValue
|
||||
# WebTransport still relies on the ConnectionTracingID and ConnectionTracingKey.
|
||||
# See https://github.com/quic-go/quic-go/issues/4405 for more details.
|
||||
- linters:
|
||||
- staticcheck
|
||||
paths:
|
||||
- http3/
|
||||
- integrationtests/self/http_test.go
|
||||
text: 'SA1019:.+quic\.ConnectionTracing(ID|Key)'
|
||||
paths:
|
||||
- internal/handshake/cipher_suite.go
|
||||
- third_party$
|
||||
- builtin$
|
||||
- examples$
|
||||
formatters:
|
||||
enable:
|
||||
- gofmt
|
||||
- gofumpt
|
||||
- goimports
|
||||
exclusions:
|
||||
generated: lax
|
||||
paths:
|
||||
- internal/handshake/cipher_suite.go
|
||||
- third_party$
|
||||
- builtin$
|
||||
- examples$
|
||||
21
vendor/github.com/quic-go/quic-go/LICENSE
generated
vendored
Normal file
21
vendor/github.com/quic-go/quic-go/LICENSE
generated
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2016 the quic-go authors & Google, Inc.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
61
vendor/github.com/quic-go/quic-go/README.md
generated
vendored
Normal file
61
vendor/github.com/quic-go/quic-go/README.md
generated
vendored
Normal file
@@ -0,0 +1,61 @@
|
||||
<div align="center" style="margin-bottom: 15px;">
|
||||
<img src="./assets/quic-go-logo.png" width="700" height="auto">
|
||||
</div>
|
||||
|
||||
# A QUIC implementation in pure Go
|
||||
|
||||
|
||||
[](https://quic-go.net/docs/)
|
||||
[](https://pkg.go.dev/github.com/quic-go/quic-go)
|
||||
[](https://codecov.io/gh/quic-go/quic-go/)
|
||||
[](https://issues.oss-fuzz.com/issues?q=quic-go)
|
||||
|
||||
quic-go is an implementation of the QUIC protocol ([RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000), [RFC 9001](https://datatracker.ietf.org/doc/html/rfc9001), [RFC 9002](https://datatracker.ietf.org/doc/html/rfc9002)) in Go. It has support for HTTP/3 ([RFC 9114](https://datatracker.ietf.org/doc/html/rfc9114)), including QPACK ([RFC 9204](https://datatracker.ietf.org/doc/html/rfc9204)) and HTTP Datagrams ([RFC 9297](https://datatracker.ietf.org/doc/html/rfc9297)).
|
||||
|
||||
In addition to these base RFCs, it also implements the following RFCs:
|
||||
|
||||
* Unreliable Datagram Extension ([RFC 9221](https://datatracker.ietf.org/doc/html/rfc9221))
|
||||
* Datagram Packetization Layer Path MTU Discovery (DPLPMTUD, [RFC 8899](https://datatracker.ietf.org/doc/html/rfc8899))
|
||||
* QUIC Version 2 ([RFC 9369](https://datatracker.ietf.org/doc/html/rfc9369))
|
||||
* QUIC Event Logging using qlog ([draft-ietf-quic-qlog-main-schema](https://datatracker.ietf.org/doc/draft-ietf-quic-qlog-main-schema/) and [draft-ietf-quic-qlog-quic-events](https://datatracker.ietf.org/doc/draft-ietf-quic-qlog-quic-events/))
|
||||
* QUIC Stream Resets with Partial Delivery ([draft-ietf-quic-reliable-stream-reset](https://datatracker.ietf.org/doc/html/draft-ietf-quic-reliable-stream-reset-07))
|
||||
|
||||
Support for WebTransport over HTTP/3 ([draft-ietf-webtrans-http3](https://datatracker.ietf.org/doc/draft-ietf-webtrans-http3/)) is implemented in [webtransport-go](https://github.com/quic-go/webtransport-go).
|
||||
|
||||
Detailed documentation can be found on [quic-go.net](https://quic-go.net/docs/).
|
||||
|
||||
## Projects using quic-go
|
||||
|
||||
| Project | Description | Stars |
|
||||
| ---------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------- |
|
||||
| [AdGuardHome](https://github.com/AdguardTeam/AdGuardHome) | Free and open source, powerful network-wide ads & trackers blocking DNS server. |  |
|
||||
| [algernon](https://github.com/xyproto/algernon) | Small self-contained pure-Go web server with Lua, Markdown, HTTP/2, QUIC, Redis and PostgreSQL support |  |
|
||||
| [caddy](https://github.com/caddyserver/caddy/) | Fast, multi-platform web server with automatic HTTPS |  |
|
||||
| [cloudflared](https://github.com/cloudflare/cloudflared) | A tunneling daemon that proxies traffic from the Cloudflare network to your origins |  |
|
||||
| [frp](https://github.com/fatedier/frp) | A fast reverse proxy to help you expose a local server behind a NAT or firewall to the internet |  |
|
||||
| [go-libp2p](https://github.com/libp2p/go-libp2p) | libp2p implementation in Go, powering [Kubo](https://github.com/ipfs/kubo) (IPFS) and [Lotus](https://github.com/filecoin-project/lotus) (Filecoin), among others |  |
|
||||
| [gost](https://github.com/go-gost/gost) | A simple security tunnel written in Go |  |
|
||||
| [Hysteria](https://github.com/apernet/hysteria) | A powerful, lightning fast and censorship resistant proxy |  |
|
||||
| [Mercure](https://github.com/dunglas/mercure) | An open, easy, fast, reliable and battery-efficient solution for real-time communications |  |
|
||||
| [nodepass](https://github.com/yosebyte/nodepass) | A secure, efficient TCP/UDP tunneling solution that delivers fast, reliable access across network restrictions using pre-established TCP/QUIC connections |  |
|
||||
| [OONI Probe](https://github.com/ooni/probe-cli) | Next generation OONI Probe. Library and CLI tool. |  |
|
||||
| [reverst](https://github.com/flipt-io/reverst) | Reverse Tunnels in Go over HTTP/3 and QUIC |  |
|
||||
| [RoadRunner](https://github.com/roadrunner-server/roadrunner) | High-performance PHP application server, process manager written in Go and powered with plugins |  |
|
||||
| [syncthing](https://github.com/syncthing/syncthing/) | Open Source Continuous File Synchronization |  |
|
||||
| [traefik](https://github.com/traefik/traefik) | The Cloud Native Application Proxy |  |
|
||||
| [v2ray-core](https://github.com/v2fly/v2ray-core) | A platform for building proxies to bypass network restrictions |  |
|
||||
| [YoMo](https://github.com/yomorun/yomo) | Streaming Serverless Framework for Geo-distributed System |  |
|
||||
|
||||
If you'd like to see your project added to this list, please send us a PR.
|
||||
|
||||
## Release Policy
|
||||
|
||||
quic-go always aims to support the latest two Go releases.
|
||||
|
||||
## Contributing
|
||||
|
||||
We are always happy to welcome new contributors! We have a number of self-contained issues that are suitable for first-time contributors, they are tagged with [help wanted](https://github.com/quic-go/quic-go/issues?q=is%3Aissue+is%3Aopen+label%3A%22help+wanted%22). If you have any questions, please feel free to reach out by opening an issue or leaving a comment.
|
||||
|
||||
## License
|
||||
|
||||
The code is licensed under the MIT license. The logo and brand assets are excluded from the MIT license. See [assets/LICENSE.md](https://github.com/quic-go/quic-go/tree/master/assets/LICENSE.md) for the full usage policy and details.
|
||||
14
vendor/github.com/quic-go/quic-go/SECURITY.md
generated
vendored
Normal file
14
vendor/github.com/quic-go/quic-go/SECURITY.md
generated
vendored
Normal file
@@ -0,0 +1,14 @@
|
||||
# Security Policy
|
||||
|
||||
quic-go is an implementation of the QUIC protocol and related standards. No software is perfect, and we take reports of potential security issues very seriously.
|
||||
|
||||
## Reporting a Vulnerability
|
||||
|
||||
If you discover a vulnerability that could affect production deployments (e.g., a remotely exploitable issue), please report it [**privately**](https://github.com/quic-go/quic-go/security/advisories/new).
|
||||
Please **DO NOT file a public issue** for exploitable vulnerabilities.
|
||||
|
||||
If the issue is theoretical, non-exploitable, or related to an experimental feature, you may discuss it openly by filing a regular issue.
|
||||
|
||||
## Reporting a non-security bug
|
||||
|
||||
For bugs, feature requests, or other non-security concerns, please open a GitHub [issue](https://github.com/quic-go/quic-go/issues/new).
|
||||
92
vendor/github.com/quic-go/quic-go/buffer_pool.go
generated
vendored
Normal file
92
vendor/github.com/quic-go/quic-go/buffer_pool.go
generated
vendored
Normal file
@@ -0,0 +1,92 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
type packetBuffer struct {
|
||||
Data []byte
|
||||
|
||||
// refCount counts how many packets Data is used in.
|
||||
// It doesn't support concurrent use.
|
||||
// It is > 1 when used for coalesced packet.
|
||||
refCount int
|
||||
}
|
||||
|
||||
// Split increases the refCount.
|
||||
// It must be called when a packet buffer is used for more than one packet,
|
||||
// e.g. when splitting coalesced packets.
|
||||
func (b *packetBuffer) Split() {
|
||||
b.refCount++
|
||||
}
|
||||
|
||||
// Decrement decrements the reference counter.
|
||||
// It doesn't put the buffer back into the pool.
|
||||
func (b *packetBuffer) Decrement() {
|
||||
b.refCount--
|
||||
if b.refCount < 0 {
|
||||
panic("negative packetBuffer refCount")
|
||||
}
|
||||
}
|
||||
|
||||
// MaybeRelease puts the packet buffer back into the pool,
|
||||
// if the reference counter already reached 0.
|
||||
func (b *packetBuffer) MaybeRelease() {
|
||||
// only put the packetBuffer back if it's not used any more
|
||||
if b.refCount == 0 {
|
||||
b.putBack()
|
||||
}
|
||||
}
|
||||
|
||||
// Release puts back the packet buffer into the pool.
|
||||
// It should be called when processing is definitely finished.
|
||||
func (b *packetBuffer) Release() {
|
||||
b.Decrement()
|
||||
if b.refCount != 0 {
|
||||
panic("packetBuffer refCount not zero")
|
||||
}
|
||||
b.putBack()
|
||||
}
|
||||
|
||||
// Len returns the length of Data
|
||||
func (b *packetBuffer) Len() protocol.ByteCount { return protocol.ByteCount(len(b.Data)) }
|
||||
func (b *packetBuffer) Cap() protocol.ByteCount { return protocol.ByteCount(cap(b.Data)) }
|
||||
|
||||
func (b *packetBuffer) putBack() {
|
||||
if cap(b.Data) == protocol.MaxPacketBufferSize {
|
||||
bufferPool.Put(b)
|
||||
return
|
||||
}
|
||||
if cap(b.Data) == protocol.MaxLargePacketBufferSize {
|
||||
largeBufferPool.Put(b)
|
||||
return
|
||||
}
|
||||
panic("putPacketBuffer called with packet of wrong size!")
|
||||
}
|
||||
|
||||
var bufferPool, largeBufferPool sync.Pool
|
||||
|
||||
func getPacketBuffer() *packetBuffer {
|
||||
buf := bufferPool.Get().(*packetBuffer)
|
||||
buf.refCount = 1
|
||||
buf.Data = buf.Data[:0]
|
||||
return buf
|
||||
}
|
||||
|
||||
func getLargePacketBuffer() *packetBuffer {
|
||||
buf := largeBufferPool.Get().(*packetBuffer)
|
||||
buf.refCount = 1
|
||||
buf.Data = buf.Data[:0]
|
||||
return buf
|
||||
}
|
||||
|
||||
func init() {
|
||||
bufferPool.New = func() any {
|
||||
return &packetBuffer{Data: make([]byte, 0, protocol.MaxPacketBufferSize)}
|
||||
}
|
||||
largeBufferPool.New = func() any {
|
||||
return &packetBuffer{Data: make([]byte, 0, protocol.MaxLargePacketBufferSize)}
|
||||
}
|
||||
}
|
||||
109
vendor/github.com/quic-go/quic-go/client.go
generated
vendored
Normal file
109
vendor/github.com/quic-go/quic-go/client.go
generated
vendored
Normal file
@@ -0,0 +1,109 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// make it possible to mock connection ID for initial generation in the tests
|
||||
var generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
|
||||
|
||||
// DialAddr establishes a new QUIC connection to a server.
|
||||
// It resolves the address, and then creates a new UDP connection to dial the QUIC server.
|
||||
// When the QUIC connection is closed, this UDP connection is closed.
|
||||
// See [Dial] for more details.
|
||||
func DialAddr(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (*Conn, error) {
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tr, err := setupTransport(udpConn, tlsConf, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn, err := tr.dial(ctx, udpAddr, addr, tlsConf, conf, false)
|
||||
if err != nil {
|
||||
tr.Close()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// DialAddrEarly establishes a new 0-RTT QUIC connection to a server.
|
||||
// See [DialAddr] for more details.
|
||||
func DialAddrEarly(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (*Conn, error) {
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tr, err := setupTransport(udpConn, tlsConf, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn, err := tr.dial(ctx, udpAddr, addr, tlsConf, conf, true)
|
||||
if err != nil {
|
||||
tr.Close()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn.
|
||||
// See [Dial] for more details.
|
||||
func DialEarly(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (*Conn, error) {
|
||||
dl, err := setupTransport(c, tlsConf, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn, err := dl.DialEarly(ctx, addr, tlsConf, conf)
|
||||
if err != nil {
|
||||
dl.Close()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
|
||||
// If the PacketConn satisfies the [OOBCapablePacketConn] interface (as a [net.UDPConn] does),
|
||||
// ECN and packet info support will be enabled. In this case, ReadMsgUDP and WriteMsgUDP
|
||||
// will be used instead of ReadFrom and WriteTo to read/write packets.
|
||||
// The [tls.Config] must define an application protocol (using tls.Config.NextProtos).
|
||||
//
|
||||
// This is a convenience function. More advanced use cases should instantiate a [Transport],
|
||||
// which offers configuration options for a more fine-grained control of the connection establishment,
|
||||
// including reusing the underlying UDP socket for multiple QUIC connections.
|
||||
func Dial(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (*Conn, error) {
|
||||
dl, err := setupTransport(c, tlsConf, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn, err := dl.Dial(ctx, addr, tlsConf, conf)
|
||||
if err != nil {
|
||||
dl.Close()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func setupTransport(c net.PacketConn, tlsConf *tls.Config, createdPacketConn bool) (*Transport, error) {
|
||||
if tlsConf == nil {
|
||||
return nil, errors.New("quic: tls.Config not set")
|
||||
}
|
||||
return &Transport{
|
||||
Conn: c,
|
||||
createdConn: createdPacketConn,
|
||||
isSingleUse: true,
|
||||
}, nil
|
||||
}
|
||||
58
vendor/github.com/quic-go/quic-go/closed_conn.go
generated
vendored
Normal file
58
vendor/github.com/quic-go/quic-go/closed_conn.go
generated
vendored
Normal file
@@ -0,0 +1,58 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"math/bits"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
// A closedLocalConn is a connection that we closed locally.
|
||||
// When receiving packets for such a connection, we need to retransmit the packet containing the CONNECTION_CLOSE frame,
|
||||
// with an exponential backoff.
|
||||
type closedLocalConn struct {
|
||||
counter atomic.Uint32
|
||||
logger utils.Logger
|
||||
|
||||
sendPacket func(net.Addr, packetInfo)
|
||||
}
|
||||
|
||||
var _ packetHandler = &closedLocalConn{}
|
||||
|
||||
// newClosedLocalConn creates a new closedLocalConn and runs it.
|
||||
func newClosedLocalConn(sendPacket func(net.Addr, packetInfo), logger utils.Logger) packetHandler {
|
||||
return &closedLocalConn{
|
||||
sendPacket: sendPacket,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *closedLocalConn) handlePacket(p receivedPacket) {
|
||||
n := c.counter.Add(1)
|
||||
// exponential backoff
|
||||
// only send a CONNECTION_CLOSE for the 1st, 2nd, 4th, 8th, 16th, ... packet arriving
|
||||
if bits.OnesCount32(n) != 1 {
|
||||
return
|
||||
}
|
||||
c.logger.Debugf("Received %d packets after sending CONNECTION_CLOSE. Retransmitting.", n)
|
||||
c.sendPacket(p.remoteAddr, p.info)
|
||||
}
|
||||
|
||||
func (c *closedLocalConn) destroy(error) {}
|
||||
func (c *closedLocalConn) closeWithTransportError(TransportErrorCode) {}
|
||||
|
||||
// A closedRemoteConn is a connection that was closed remotely.
|
||||
// For such a connection, we might receive reordered packets that were sent before the CONNECTION_CLOSE.
|
||||
// We can just ignore those packets.
|
||||
type closedRemoteConn struct{}
|
||||
|
||||
var _ packetHandler = &closedRemoteConn{}
|
||||
|
||||
func newClosedRemoteConn() packetHandler {
|
||||
return &closedRemoteConn{}
|
||||
}
|
||||
|
||||
func (c *closedRemoteConn) handlePacket(receivedPacket) {}
|
||||
func (c *closedRemoteConn) destroy(error) {}
|
||||
func (c *closedRemoteConn) closeWithTransportError(TransportErrorCode) {}
|
||||
19
vendor/github.com/quic-go/quic-go/codecov.yml
generated
vendored
Normal file
19
vendor/github.com/quic-go/quic-go/codecov.yml
generated
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
coverage:
|
||||
round: nearest
|
||||
ignore:
|
||||
- http3/gzip_reader.go
|
||||
- example/
|
||||
- interop/
|
||||
- internal/handshake/cipher_suite.go
|
||||
- internal/mocks/
|
||||
- internal/utils/linkedlist/linkedlist.go
|
||||
- internal/testdata
|
||||
- internal/synctest
|
||||
- testutils/
|
||||
- fuzzing/
|
||||
- metrics/
|
||||
status:
|
||||
project:
|
||||
default:
|
||||
threshold: 0.5
|
||||
patch: false
|
||||
129
vendor/github.com/quic-go/quic-go/config.go
generated
vendored
Normal file
129
vendor/github.com/quic-go/quic-go/config.go
generated
vendored
Normal file
@@ -0,0 +1,129 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/quicvarint"
|
||||
)
|
||||
|
||||
// Clone clones a Config.
|
||||
func (c *Config) Clone() *Config {
|
||||
copy := *c
|
||||
return ©
|
||||
}
|
||||
|
||||
func (c *Config) handshakeTimeout() time.Duration {
|
||||
return 2 * c.HandshakeIdleTimeout
|
||||
}
|
||||
|
||||
func (c *Config) maxRetryTokenAge() time.Duration {
|
||||
return c.handshakeTimeout()
|
||||
}
|
||||
|
||||
func validateConfig(config *Config) error {
|
||||
if config == nil {
|
||||
return nil
|
||||
}
|
||||
const maxStreams = 1 << 60
|
||||
if config.MaxIncomingStreams > maxStreams {
|
||||
config.MaxIncomingStreams = maxStreams
|
||||
}
|
||||
if config.MaxIncomingUniStreams > maxStreams {
|
||||
config.MaxIncomingUniStreams = maxStreams
|
||||
}
|
||||
if config.MaxStreamReceiveWindow > quicvarint.Max {
|
||||
config.MaxStreamReceiveWindow = quicvarint.Max
|
||||
}
|
||||
if config.MaxConnectionReceiveWindow > quicvarint.Max {
|
||||
config.MaxConnectionReceiveWindow = quicvarint.Max
|
||||
}
|
||||
if config.InitialPacketSize > 0 && config.InitialPacketSize < protocol.MinInitialPacketSize {
|
||||
config.InitialPacketSize = protocol.MinInitialPacketSize
|
||||
}
|
||||
if config.InitialPacketSize > protocol.MaxPacketBufferSize {
|
||||
config.InitialPacketSize = protocol.MaxPacketBufferSize
|
||||
}
|
||||
// check that all QUIC versions are actually supported
|
||||
for _, v := range config.Versions {
|
||||
if !protocol.IsValidVersion(v) {
|
||||
return fmt.Errorf("invalid QUIC version: %s", v)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// populateConfig populates fields in the quic.Config with their default values, if none are set
|
||||
// it may be called with nil
|
||||
func populateConfig(config *Config) *Config {
|
||||
if config == nil {
|
||||
config = &Config{}
|
||||
}
|
||||
versions := config.Versions
|
||||
if len(versions) == 0 {
|
||||
versions = protocol.SupportedVersions
|
||||
}
|
||||
handshakeIdleTimeout := protocol.DefaultHandshakeIdleTimeout
|
||||
if config.HandshakeIdleTimeout != 0 {
|
||||
handshakeIdleTimeout = config.HandshakeIdleTimeout
|
||||
}
|
||||
idleTimeout := protocol.DefaultIdleTimeout
|
||||
if config.MaxIdleTimeout != 0 {
|
||||
idleTimeout = config.MaxIdleTimeout
|
||||
}
|
||||
initialStreamReceiveWindow := config.InitialStreamReceiveWindow
|
||||
if initialStreamReceiveWindow == 0 {
|
||||
initialStreamReceiveWindow = protocol.DefaultInitialMaxStreamData
|
||||
}
|
||||
maxStreamReceiveWindow := config.MaxStreamReceiveWindow
|
||||
if maxStreamReceiveWindow == 0 {
|
||||
maxStreamReceiveWindow = protocol.DefaultMaxReceiveStreamFlowControlWindow
|
||||
}
|
||||
initialConnectionReceiveWindow := config.InitialConnectionReceiveWindow
|
||||
if initialConnectionReceiveWindow == 0 {
|
||||
initialConnectionReceiveWindow = protocol.DefaultInitialMaxData
|
||||
}
|
||||
maxConnectionReceiveWindow := config.MaxConnectionReceiveWindow
|
||||
if maxConnectionReceiveWindow == 0 {
|
||||
maxConnectionReceiveWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindow
|
||||
}
|
||||
maxIncomingStreams := config.MaxIncomingStreams
|
||||
if maxIncomingStreams == 0 {
|
||||
maxIncomingStreams = protocol.DefaultMaxIncomingStreams
|
||||
} else if maxIncomingStreams < 0 {
|
||||
maxIncomingStreams = 0
|
||||
}
|
||||
maxIncomingUniStreams := config.MaxIncomingUniStreams
|
||||
if maxIncomingUniStreams == 0 {
|
||||
maxIncomingUniStreams = protocol.DefaultMaxIncomingUniStreams
|
||||
} else if maxIncomingUniStreams < 0 {
|
||||
maxIncomingUniStreams = 0
|
||||
}
|
||||
initialPacketSize := config.InitialPacketSize
|
||||
if initialPacketSize == 0 {
|
||||
initialPacketSize = protocol.InitialPacketSize
|
||||
}
|
||||
|
||||
return &Config{
|
||||
GetConfigForClient: config.GetConfigForClient,
|
||||
Versions: versions,
|
||||
HandshakeIdleTimeout: handshakeIdleTimeout,
|
||||
MaxIdleTimeout: idleTimeout,
|
||||
KeepAlivePeriod: config.KeepAlivePeriod,
|
||||
InitialStreamReceiveWindow: initialStreamReceiveWindow,
|
||||
MaxStreamReceiveWindow: maxStreamReceiveWindow,
|
||||
InitialConnectionReceiveWindow: initialConnectionReceiveWindow,
|
||||
MaxConnectionReceiveWindow: maxConnectionReceiveWindow,
|
||||
AllowConnectionWindowIncrease: config.AllowConnectionWindowIncrease,
|
||||
MaxIncomingStreams: maxIncomingStreams,
|
||||
MaxIncomingUniStreams: maxIncomingUniStreams,
|
||||
TokenStore: config.TokenStore,
|
||||
EnableDatagrams: config.EnableDatagrams,
|
||||
InitialPacketSize: initialPacketSize,
|
||||
DisablePathMTUDiscovery: config.DisablePathMTUDiscovery,
|
||||
EnableStreamResetPartialDelivery: config.EnableStreamResetPartialDelivery,
|
||||
Allow0RTT: config.Allow0RTT,
|
||||
Tracer: config.Tracer,
|
||||
}
|
||||
}
|
||||
212
vendor/github.com/quic-go/quic-go/conn_id_generator.go
generated
vendored
Normal file
212
vendor/github.com/quic-go/quic-go/conn_id_generator.go
generated
vendored
Normal file
@@ -0,0 +1,212 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/monotime"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
type connRunnerCallbacks struct {
|
||||
AddConnectionID func(protocol.ConnectionID)
|
||||
RemoveConnectionID func(protocol.ConnectionID)
|
||||
ReplaceWithClosed func([]protocol.ConnectionID, []byte, time.Duration)
|
||||
}
|
||||
|
||||
// The memory address of the Transport is used as the key.
|
||||
type connRunners map[connRunner]connRunnerCallbacks
|
||||
|
||||
func (cr connRunners) AddConnectionID(id protocol.ConnectionID) {
|
||||
for _, c := range cr {
|
||||
c.AddConnectionID(id)
|
||||
}
|
||||
}
|
||||
|
||||
func (cr connRunners) RemoveConnectionID(id protocol.ConnectionID) {
|
||||
for _, c := range cr {
|
||||
c.RemoveConnectionID(id)
|
||||
}
|
||||
}
|
||||
|
||||
func (cr connRunners) ReplaceWithClosed(ids []protocol.ConnectionID, b []byte, expiry time.Duration) {
|
||||
for _, c := range cr {
|
||||
c.ReplaceWithClosed(ids, b, expiry)
|
||||
}
|
||||
}
|
||||
|
||||
type connIDToRetire struct {
|
||||
t monotime.Time
|
||||
connID protocol.ConnectionID
|
||||
}
|
||||
|
||||
type connIDGenerator struct {
|
||||
generator ConnectionIDGenerator
|
||||
highestSeq uint64
|
||||
connRunners connRunners
|
||||
|
||||
activeSrcConnIDs map[uint64]protocol.ConnectionID
|
||||
connIDsToRetire []connIDToRetire // sorted by t
|
||||
initialClientDestConnID *protocol.ConnectionID // nil for the client
|
||||
|
||||
statelessResetter *statelessResetter
|
||||
|
||||
queueControlFrame func(wire.Frame)
|
||||
}
|
||||
|
||||
func newConnIDGenerator(
|
||||
runner connRunner,
|
||||
initialConnectionID protocol.ConnectionID,
|
||||
initialClientDestConnID *protocol.ConnectionID, // nil for the client
|
||||
statelessResetter *statelessResetter,
|
||||
callbacks connRunnerCallbacks,
|
||||
queueControlFrame func(wire.Frame),
|
||||
generator ConnectionIDGenerator,
|
||||
) *connIDGenerator {
|
||||
m := &connIDGenerator{
|
||||
generator: generator,
|
||||
activeSrcConnIDs: make(map[uint64]protocol.ConnectionID),
|
||||
statelessResetter: statelessResetter,
|
||||
connRunners: map[connRunner]connRunnerCallbacks{runner: callbacks},
|
||||
queueControlFrame: queueControlFrame,
|
||||
}
|
||||
m.activeSrcConnIDs[0] = initialConnectionID
|
||||
m.initialClientDestConnID = initialClientDestConnID
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *connIDGenerator) SetMaxActiveConnIDs(limit uint64) error {
|
||||
if m.generator.ConnectionIDLen() == 0 {
|
||||
return nil
|
||||
}
|
||||
// The active_connection_id_limit transport parameter is the number of
|
||||
// connection IDs the peer will store. This limit includes the connection ID
|
||||
// used during the handshake, and the one sent in the preferred_address
|
||||
// transport parameter.
|
||||
// We currently don't send the preferred_address transport parameter,
|
||||
// so we can issue (limit - 1) connection IDs.
|
||||
for i := uint64(len(m.activeSrcConnIDs)); i < min(limit, protocol.MaxIssuedConnectionIDs); i++ {
|
||||
if err := m.issueNewConnID(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.ConnectionID, expiry monotime.Time) error {
|
||||
if seq > m.highestSeq {
|
||||
return &qerr.TransportError{
|
||||
ErrorCode: qerr.ProtocolViolation,
|
||||
ErrorMessage: fmt.Sprintf("retired connection ID %d (highest issued: %d)", seq, m.highestSeq),
|
||||
}
|
||||
}
|
||||
connID, ok := m.activeSrcConnIDs[seq]
|
||||
// We might already have deleted this connection ID, if this is a duplicate frame.
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if connID == sentWithDestConnID {
|
||||
return &qerr.TransportError{
|
||||
ErrorCode: qerr.ProtocolViolation,
|
||||
ErrorMessage: fmt.Sprintf("retired connection ID %d (%s), which was used as the Destination Connection ID on this packet", seq, connID),
|
||||
}
|
||||
}
|
||||
m.queueConnIDForRetiring(connID, expiry)
|
||||
|
||||
delete(m.activeSrcConnIDs, seq)
|
||||
// Don't issue a replacement for the initial connection ID.
|
||||
if seq == 0 {
|
||||
return nil
|
||||
}
|
||||
return m.issueNewConnID()
|
||||
}
|
||||
|
||||
func (m *connIDGenerator) queueConnIDForRetiring(connID protocol.ConnectionID, expiry monotime.Time) {
|
||||
idx := slices.IndexFunc(m.connIDsToRetire, func(c connIDToRetire) bool {
|
||||
return c.t.After(expiry)
|
||||
})
|
||||
if idx == -1 {
|
||||
idx = len(m.connIDsToRetire)
|
||||
}
|
||||
m.connIDsToRetire = slices.Insert(m.connIDsToRetire, idx, connIDToRetire{t: expiry, connID: connID})
|
||||
}
|
||||
|
||||
func (m *connIDGenerator) issueNewConnID() error {
|
||||
connID, err := m.generator.GenerateConnectionID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.activeSrcConnIDs[m.highestSeq+1] = connID
|
||||
m.connRunners.AddConnectionID(connID)
|
||||
m.queueControlFrame(&wire.NewConnectionIDFrame{
|
||||
SequenceNumber: m.highestSeq + 1,
|
||||
ConnectionID: connID,
|
||||
StatelessResetToken: m.statelessResetter.GetStatelessResetToken(connID),
|
||||
})
|
||||
m.highestSeq++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *connIDGenerator) SetHandshakeComplete(connIDExpiry monotime.Time) {
|
||||
if m.initialClientDestConnID != nil {
|
||||
m.queueConnIDForRetiring(*m.initialClientDestConnID, connIDExpiry)
|
||||
m.initialClientDestConnID = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *connIDGenerator) RemoveRetiredConnIDs(now monotime.Time) {
|
||||
if len(m.connIDsToRetire) == 0 {
|
||||
return
|
||||
}
|
||||
for _, c := range m.connIDsToRetire {
|
||||
if c.t.After(now) {
|
||||
break
|
||||
}
|
||||
m.connRunners.RemoveConnectionID(c.connID)
|
||||
m.connIDsToRetire = m.connIDsToRetire[1:]
|
||||
}
|
||||
}
|
||||
|
||||
func (m *connIDGenerator) RemoveAll() {
|
||||
if m.initialClientDestConnID != nil {
|
||||
m.connRunners.RemoveConnectionID(*m.initialClientDestConnID)
|
||||
}
|
||||
for _, connID := range m.activeSrcConnIDs {
|
||||
m.connRunners.RemoveConnectionID(connID)
|
||||
}
|
||||
for _, c := range m.connIDsToRetire {
|
||||
m.connRunners.RemoveConnectionID(c.connID)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *connIDGenerator) ReplaceWithClosed(connClose []byte, expiry time.Duration) {
|
||||
connIDs := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+len(m.connIDsToRetire)+1)
|
||||
if m.initialClientDestConnID != nil {
|
||||
connIDs = append(connIDs, *m.initialClientDestConnID)
|
||||
}
|
||||
for _, connID := range m.activeSrcConnIDs {
|
||||
connIDs = append(connIDs, connID)
|
||||
}
|
||||
for _, c := range m.connIDsToRetire {
|
||||
connIDs = append(connIDs, c.connID)
|
||||
}
|
||||
m.connRunners.ReplaceWithClosed(connIDs, connClose, expiry)
|
||||
}
|
||||
|
||||
func (m *connIDGenerator) AddConnRunner(runner connRunner, r connRunnerCallbacks) {
|
||||
// The transport might have already been added earlier.
|
||||
// This happens if the application migrates back to and old path.
|
||||
if _, ok := m.connRunners[runner]; ok {
|
||||
return
|
||||
}
|
||||
m.connRunners[runner] = r
|
||||
if m.initialClientDestConnID != nil {
|
||||
r.AddConnectionID(*m.initialClientDestConnID)
|
||||
}
|
||||
for _, connID := range m.activeSrcConnIDs {
|
||||
r.AddConnectionID(connID)
|
||||
}
|
||||
}
|
||||
321
vendor/github.com/quic-go/quic-go/conn_id_manager.go
generated
vendored
Normal file
321
vendor/github.com/quic-go/quic-go/conn_id_manager.go
generated
vendored
Normal file
@@ -0,0 +1,321 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
type newConnID struct {
|
||||
SequenceNumber uint64
|
||||
ConnectionID protocol.ConnectionID
|
||||
StatelessResetToken protocol.StatelessResetToken
|
||||
}
|
||||
|
||||
type connIDManager struct {
|
||||
queue []newConnID
|
||||
|
||||
highestProbingID uint64
|
||||
pathProbing map[pathID]newConnID // initialized lazily
|
||||
|
||||
handshakeComplete bool
|
||||
activeSequenceNumber uint64
|
||||
highestRetired uint64
|
||||
activeConnectionID protocol.ConnectionID
|
||||
activeStatelessResetToken *protocol.StatelessResetToken
|
||||
|
||||
// We change the connection ID after sending on average
|
||||
// protocol.PacketsPerConnectionID packets. The actual value is randomized
|
||||
// hide the packet loss rate from on-path observers.
|
||||
rand utils.Rand
|
||||
packetsSinceLastChange uint32
|
||||
packetsPerConnectionID uint32
|
||||
|
||||
addStatelessResetToken func(protocol.StatelessResetToken)
|
||||
removeStatelessResetToken func(protocol.StatelessResetToken)
|
||||
queueControlFrame func(wire.Frame)
|
||||
|
||||
closed bool
|
||||
}
|
||||
|
||||
func newConnIDManager(
|
||||
initialDestConnID protocol.ConnectionID,
|
||||
addStatelessResetToken func(protocol.StatelessResetToken),
|
||||
removeStatelessResetToken func(protocol.StatelessResetToken),
|
||||
queueControlFrame func(wire.Frame),
|
||||
) *connIDManager {
|
||||
return &connIDManager{
|
||||
activeConnectionID: initialDestConnID,
|
||||
addStatelessResetToken: addStatelessResetToken,
|
||||
removeStatelessResetToken: removeStatelessResetToken,
|
||||
queueControlFrame: queueControlFrame,
|
||||
queue: make([]newConnID, 0, protocol.MaxActiveConnectionIDs),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *connIDManager) AddFromPreferredAddress(connID protocol.ConnectionID, resetToken protocol.StatelessResetToken) error {
|
||||
return h.addConnectionID(1, connID, resetToken)
|
||||
}
|
||||
|
||||
func (h *connIDManager) Add(f *wire.NewConnectionIDFrame) error {
|
||||
if err := h.add(f); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(h.queue) >= protocol.MaxActiveConnectionIDs {
|
||||
return &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *connIDManager) add(f *wire.NewConnectionIDFrame) error {
|
||||
if h.activeConnectionID.Len() == 0 {
|
||||
return &qerr.TransportError{
|
||||
ErrorCode: qerr.ProtocolViolation,
|
||||
ErrorMessage: "received NEW_CONNECTION_ID frame but zero-length connection IDs are in use",
|
||||
}
|
||||
}
|
||||
// If the NEW_CONNECTION_ID frame is reordered, such that its sequence number is smaller than the currently active
|
||||
// connection ID or if it was already retired, send the RETIRE_CONNECTION_ID frame immediately.
|
||||
if f.SequenceNumber < max(h.activeSequenceNumber, h.highestProbingID) || f.SequenceNumber < h.highestRetired {
|
||||
h.queueControlFrame(&wire.RetireConnectionIDFrame{
|
||||
SequenceNumber: f.SequenceNumber,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
if f.RetirePriorTo != 0 && h.pathProbing != nil {
|
||||
for id, entry := range h.pathProbing {
|
||||
if entry.SequenceNumber < f.RetirePriorTo {
|
||||
h.queueControlFrame(&wire.RetireConnectionIDFrame{
|
||||
SequenceNumber: entry.SequenceNumber,
|
||||
})
|
||||
h.removeStatelessResetToken(entry.StatelessResetToken)
|
||||
delete(h.pathProbing, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Retire elements in the queue.
|
||||
// Doesn't retire the active connection ID.
|
||||
if f.RetirePriorTo > h.highestRetired {
|
||||
var newQueue []newConnID
|
||||
for _, entry := range h.queue {
|
||||
if entry.SequenceNumber >= f.RetirePriorTo {
|
||||
newQueue = append(newQueue, entry)
|
||||
} else {
|
||||
h.queueControlFrame(&wire.RetireConnectionIDFrame{SequenceNumber: entry.SequenceNumber})
|
||||
}
|
||||
}
|
||||
h.queue = newQueue
|
||||
h.highestRetired = f.RetirePriorTo
|
||||
}
|
||||
|
||||
if f.SequenceNumber == h.activeSequenceNumber {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := h.addConnectionID(f.SequenceNumber, f.ConnectionID, f.StatelessResetToken); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Retire the active connection ID, if necessary.
|
||||
if h.activeSequenceNumber < f.RetirePriorTo {
|
||||
// The queue is guaranteed to have at least one element at this point.
|
||||
h.updateConnectionID()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *connIDManager) addConnectionID(seq uint64, connID protocol.ConnectionID, resetToken protocol.StatelessResetToken) error {
|
||||
// fast path: add to the end of the queue
|
||||
if len(h.queue) == 0 || h.queue[len(h.queue)-1].SequenceNumber < seq {
|
||||
h.queue = append(h.queue, newConnID{
|
||||
SequenceNumber: seq,
|
||||
ConnectionID: connID,
|
||||
StatelessResetToken: resetToken,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// slow path: insert in the middle
|
||||
for i, entry := range h.queue {
|
||||
if entry.SequenceNumber == seq {
|
||||
if entry.ConnectionID != connID {
|
||||
return fmt.Errorf("received conflicting connection IDs for sequence number %d", seq)
|
||||
}
|
||||
if entry.StatelessResetToken != resetToken {
|
||||
return fmt.Errorf("received conflicting stateless reset tokens for sequence number %d", seq)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// insert at the correct position to maintain sorted order
|
||||
if entry.SequenceNumber > seq {
|
||||
h.queue = slices.Insert(h.queue, i, newConnID{
|
||||
SequenceNumber: seq,
|
||||
ConnectionID: connID,
|
||||
StatelessResetToken: resetToken,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return nil // unreachable
|
||||
}
|
||||
|
||||
func (h *connIDManager) updateConnectionID() {
|
||||
h.assertNotClosed()
|
||||
h.queueControlFrame(&wire.RetireConnectionIDFrame{
|
||||
SequenceNumber: h.activeSequenceNumber,
|
||||
})
|
||||
h.highestRetired = max(h.highestRetired, h.activeSequenceNumber)
|
||||
if h.activeStatelessResetToken != nil {
|
||||
h.removeStatelessResetToken(*h.activeStatelessResetToken)
|
||||
}
|
||||
|
||||
front := h.queue[0]
|
||||
h.queue = h.queue[1:]
|
||||
h.activeSequenceNumber = front.SequenceNumber
|
||||
h.activeConnectionID = front.ConnectionID
|
||||
h.activeStatelessResetToken = &front.StatelessResetToken
|
||||
h.packetsSinceLastChange = 0
|
||||
h.packetsPerConnectionID = protocol.PacketsPerConnectionID/2 + uint32(h.rand.Int31n(protocol.PacketsPerConnectionID))
|
||||
h.addStatelessResetToken(*h.activeStatelessResetToken)
|
||||
}
|
||||
|
||||
func (h *connIDManager) Close() {
|
||||
h.closed = true
|
||||
if h.activeStatelessResetToken != nil {
|
||||
h.removeStatelessResetToken(*h.activeStatelessResetToken)
|
||||
}
|
||||
if h.pathProbing != nil {
|
||||
for _, entry := range h.pathProbing {
|
||||
h.removeStatelessResetToken(entry.StatelessResetToken)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// is called when the server performs a Retry
|
||||
// and when the server changes the connection ID in the first Initial sent
|
||||
func (h *connIDManager) ChangeInitialConnID(newConnID protocol.ConnectionID) {
|
||||
if h.activeSequenceNumber != 0 {
|
||||
panic("expected first connection ID to have sequence number 0")
|
||||
}
|
||||
h.activeConnectionID = newConnID
|
||||
}
|
||||
|
||||
// is called when the server provides a stateless reset token in the transport parameters
|
||||
func (h *connIDManager) SetStatelessResetToken(token protocol.StatelessResetToken) {
|
||||
h.assertNotClosed()
|
||||
if h.activeSequenceNumber != 0 {
|
||||
panic("expected first connection ID to have sequence number 0")
|
||||
}
|
||||
h.activeStatelessResetToken = &token
|
||||
h.addStatelessResetToken(token)
|
||||
}
|
||||
|
||||
func (h *connIDManager) SentPacket() {
|
||||
h.packetsSinceLastChange++
|
||||
}
|
||||
|
||||
func (h *connIDManager) shouldUpdateConnID() bool {
|
||||
if !h.handshakeComplete {
|
||||
return false
|
||||
}
|
||||
// initiate the first change as early as possible (after handshake completion)
|
||||
if len(h.queue) > 0 && h.activeSequenceNumber == 0 {
|
||||
return true
|
||||
}
|
||||
// For later changes, only change if
|
||||
// 1. The queue of connection IDs is filled more than 50%.
|
||||
// 2. We sent at least PacketsPerConnectionID packets
|
||||
return 2*len(h.queue) >= protocol.MaxActiveConnectionIDs &&
|
||||
h.packetsSinceLastChange >= h.packetsPerConnectionID
|
||||
}
|
||||
|
||||
func (h *connIDManager) Get() protocol.ConnectionID {
|
||||
h.assertNotClosed()
|
||||
if h.shouldUpdateConnID() {
|
||||
h.updateConnectionID()
|
||||
}
|
||||
return h.activeConnectionID
|
||||
}
|
||||
|
||||
func (h *connIDManager) SetHandshakeComplete() {
|
||||
h.handshakeComplete = true
|
||||
}
|
||||
|
||||
// GetConnIDForPath retrieves a connection ID for a new path (i.e. not the active one).
|
||||
// Once a connection ID is allocated for a path, it cannot be used for a different path.
|
||||
// When called with the same pathID, it will return the same connection ID,
|
||||
// unless the peer requested that this connection ID be retired.
|
||||
func (h *connIDManager) GetConnIDForPath(id pathID) (protocol.ConnectionID, bool) {
|
||||
h.assertNotClosed()
|
||||
// if we're using zero-length connection IDs, we don't need to change the connection ID
|
||||
if h.activeConnectionID.Len() == 0 {
|
||||
return protocol.ConnectionID{}, true
|
||||
}
|
||||
|
||||
if h.pathProbing == nil {
|
||||
h.pathProbing = make(map[pathID]newConnID)
|
||||
}
|
||||
entry, ok := h.pathProbing[id]
|
||||
if ok {
|
||||
return entry.ConnectionID, true
|
||||
}
|
||||
if len(h.queue) == 0 {
|
||||
return protocol.ConnectionID{}, false
|
||||
}
|
||||
front := h.queue[0]
|
||||
h.queue = h.queue[1:]
|
||||
h.pathProbing[id] = front
|
||||
h.highestProbingID = front.SequenceNumber
|
||||
h.addStatelessResetToken(front.StatelessResetToken)
|
||||
return front.ConnectionID, true
|
||||
}
|
||||
|
||||
func (h *connIDManager) RetireConnIDForPath(pathID pathID) {
|
||||
h.assertNotClosed()
|
||||
// if we're using zero-length connection IDs, we don't need to change the connection ID
|
||||
if h.activeConnectionID.Len() == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
entry, ok := h.pathProbing[pathID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
h.queueControlFrame(&wire.RetireConnectionIDFrame{
|
||||
SequenceNumber: entry.SequenceNumber,
|
||||
})
|
||||
h.removeStatelessResetToken(entry.StatelessResetToken)
|
||||
delete(h.pathProbing, pathID)
|
||||
}
|
||||
|
||||
func (h *connIDManager) IsActiveStatelessResetToken(token protocol.StatelessResetToken) bool {
|
||||
if h.activeStatelessResetToken != nil {
|
||||
if *h.activeStatelessResetToken == token {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if h.pathProbing != nil {
|
||||
for _, entry := range h.pathProbing {
|
||||
if entry.StatelessResetToken == token {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Using the connIDManager after it has been closed can have disastrous effects:
|
||||
// If the connection ID is rotated, a new entry would be inserted into the packet handler map,
|
||||
// leading to a memory leak of the connection struct.
|
||||
// See https://github.com/quic-go/quic-go/pull/4852 for more details.
|
||||
func (h *connIDManager) assertNotClosed() {
|
||||
if h.closed {
|
||||
panic("connection ID manager is closed")
|
||||
}
|
||||
}
|
||||
3147
vendor/github.com/quic-go/quic-go/connection.go
generated
vendored
Normal file
3147
vendor/github.com/quic-go/quic-go/connection.go
generated
vendored
Normal file
File diff suppressed because it is too large
Load Diff
315
vendor/github.com/quic-go/quic-go/connection_logging.go
generated
vendored
Normal file
315
vendor/github.com/quic-go/quic-go/connection_logging.go
generated
vendored
Normal file
@@ -0,0 +1,315 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
"github.com/quic-go/quic-go/qlog"
|
||||
)
|
||||
|
||||
// ConvertFrame converts a wire.Frame into a logging.Frame.
|
||||
// This makes it possible for external packages to access the frames.
|
||||
// Furthermore, it removes the data slices from CRYPTO and STREAM frames.
|
||||
func toQlogFrame(frame wire.Frame) qlog.Frame {
|
||||
switch f := frame.(type) {
|
||||
case *wire.AckFrame:
|
||||
// We use a pool for ACK frames.
|
||||
// Implementations of the tracer interface may hold on to frames, so we need to make a copy here.
|
||||
return qlog.Frame{Frame: toQlogAckFrame(f)}
|
||||
case *wire.CryptoFrame:
|
||||
return qlog.Frame{
|
||||
Frame: &qlog.CryptoFrame{
|
||||
Offset: int64(f.Offset),
|
||||
Length: int64(len(f.Data)),
|
||||
},
|
||||
}
|
||||
case *wire.StreamFrame:
|
||||
return qlog.Frame{
|
||||
Frame: &qlog.StreamFrame{
|
||||
StreamID: f.StreamID,
|
||||
Offset: int64(f.Offset),
|
||||
Length: int64(f.DataLen()),
|
||||
Fin: f.Fin,
|
||||
},
|
||||
}
|
||||
case *wire.DatagramFrame:
|
||||
return qlog.Frame{
|
||||
Frame: &qlog.DatagramFrame{
|
||||
Length: int64(len(f.Data)),
|
||||
},
|
||||
}
|
||||
default:
|
||||
return qlog.Frame{Frame: frame}
|
||||
}
|
||||
}
|
||||
|
||||
func toQlogAckFrame(f *wire.AckFrame) *qlog.AckFrame {
|
||||
ack := &qlog.AckFrame{
|
||||
AckRanges: slices.Clone(f.AckRanges),
|
||||
DelayTime: f.DelayTime,
|
||||
ECNCE: f.ECNCE,
|
||||
ECT0: f.ECT0,
|
||||
ECT1: f.ECT1,
|
||||
}
|
||||
return ack
|
||||
}
|
||||
|
||||
func (c *Conn) logLongHeaderPacket(p *longHeaderPacket, ecn protocol.ECN, datagramID qlog.DatagramID) {
|
||||
// quic-go logging
|
||||
if c.logger.Debug() {
|
||||
p.header.Log(c.logger)
|
||||
if p.ack != nil {
|
||||
wire.LogFrame(c.logger, p.ack, true)
|
||||
}
|
||||
for _, frame := range p.frames {
|
||||
wire.LogFrame(c.logger, frame.Frame, true)
|
||||
}
|
||||
for _, frame := range p.streamFrames {
|
||||
wire.LogFrame(c.logger, frame.Frame, true)
|
||||
}
|
||||
}
|
||||
|
||||
// tracing
|
||||
if c.qlogger != nil {
|
||||
numFrames := len(p.frames) + len(p.streamFrames)
|
||||
if p.ack != nil {
|
||||
numFrames++
|
||||
}
|
||||
frames := make([]qlog.Frame, 0, numFrames)
|
||||
if p.ack != nil {
|
||||
frames = append(frames, toQlogFrame(p.ack))
|
||||
}
|
||||
for _, f := range p.frames {
|
||||
frames = append(frames, toQlogFrame(f.Frame))
|
||||
}
|
||||
for _, f := range p.streamFrames {
|
||||
frames = append(frames, toQlogFrame(f.Frame))
|
||||
}
|
||||
c.qlogger.RecordEvent(qlog.PacketSent{
|
||||
Header: qlog.PacketHeader{
|
||||
PacketType: toQlogPacketType(p.header.Type),
|
||||
KeyPhaseBit: p.header.KeyPhase,
|
||||
PacketNumber: p.header.PacketNumber,
|
||||
Version: p.header.Version,
|
||||
SrcConnectionID: p.header.SrcConnectionID,
|
||||
DestConnectionID: p.header.DestConnectionID,
|
||||
},
|
||||
Raw: qlog.RawInfo{
|
||||
Length: int(p.length),
|
||||
PayloadLength: int(p.header.Length),
|
||||
},
|
||||
DatagramID: datagramID,
|
||||
Frames: frames,
|
||||
ECN: toQlogECN(ecn),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) logShortHeaderPacket(p shortHeaderPacket, ecn protocol.ECN, size protocol.ByteCount) {
|
||||
c.logShortHeaderPacketWithDatagramID(p, ecn, size, false, 0)
|
||||
}
|
||||
|
||||
func (c *Conn) logShortHeaderPacketWithDatagramID(p shortHeaderPacket, ecn protocol.ECN, size protocol.ByteCount, isCoalesced bool, datagramID qlog.DatagramID) {
|
||||
if c.logger.Debug() && !isCoalesced {
|
||||
c.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, 1-RTT (ECN: %s)", p.PacketNumber, size, c.logID, ecn)
|
||||
}
|
||||
// quic-go logging
|
||||
if c.logger.Debug() {
|
||||
wire.LogShortHeader(c.logger, p.DestConnID, p.PacketNumber, p.PacketNumberLen, p.KeyPhase)
|
||||
if p.Ack != nil {
|
||||
wire.LogFrame(c.logger, p.Ack, true)
|
||||
}
|
||||
for _, f := range p.Frames {
|
||||
wire.LogFrame(c.logger, f.Frame, true)
|
||||
}
|
||||
for _, f := range p.StreamFrames {
|
||||
wire.LogFrame(c.logger, f.Frame, true)
|
||||
}
|
||||
}
|
||||
|
||||
// tracing
|
||||
if c.qlogger != nil {
|
||||
numFrames := len(p.Frames) + len(p.StreamFrames)
|
||||
if p.Ack != nil {
|
||||
numFrames++
|
||||
}
|
||||
fs := make([]qlog.Frame, 0, numFrames)
|
||||
if p.Ack != nil {
|
||||
fs = append(fs, toQlogFrame(p.Ack))
|
||||
}
|
||||
for _, f := range p.Frames {
|
||||
fs = append(fs, toQlogFrame(f.Frame))
|
||||
}
|
||||
for _, f := range p.StreamFrames {
|
||||
fs = append(fs, toQlogFrame(f.Frame))
|
||||
}
|
||||
c.qlogger.RecordEvent(qlog.PacketSent{
|
||||
Header: qlog.PacketHeader{
|
||||
PacketType: qlog.PacketType1RTT,
|
||||
KeyPhaseBit: p.KeyPhase,
|
||||
PacketNumber: p.PacketNumber,
|
||||
Version: c.version,
|
||||
DestConnectionID: p.DestConnID,
|
||||
},
|
||||
Raw: qlog.RawInfo{
|
||||
Length: int(size),
|
||||
PayloadLength: int(size - wire.ShortHeaderLen(p.DestConnID, p.PacketNumberLen)),
|
||||
},
|
||||
DatagramID: datagramID,
|
||||
Frames: fs,
|
||||
ECN: toQlogECN(ecn),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) logCoalescedPacket(packet *coalescedPacket, ecn protocol.ECN) {
|
||||
var datagramID qlog.DatagramID
|
||||
if c.qlogger != nil {
|
||||
datagramID = qlog.CalculateDatagramID(packet.buffer.Data)
|
||||
}
|
||||
if c.logger.Debug() {
|
||||
// There's a short period between dropping both Initial and Handshake keys and completion of the handshake,
|
||||
// during which we might call PackCoalescedPacket but just pack a short header packet.
|
||||
if len(packet.longHdrPackets) == 0 && packet.shortHdrPacket != nil {
|
||||
c.logShortHeaderPacketWithDatagramID(
|
||||
*packet.shortHdrPacket,
|
||||
ecn,
|
||||
packet.shortHdrPacket.Length,
|
||||
false,
|
||||
datagramID,
|
||||
)
|
||||
return
|
||||
}
|
||||
if len(packet.longHdrPackets) > 1 {
|
||||
c.logger.Debugf("-> Sending coalesced packet (%d parts, %d bytes) for connection %s", len(packet.longHdrPackets), packet.buffer.Len(), c.logID)
|
||||
} else {
|
||||
c.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, %s", packet.longHdrPackets[0].header.PacketNumber, packet.buffer.Len(), c.logID, packet.longHdrPackets[0].EncryptionLevel())
|
||||
}
|
||||
}
|
||||
for _, p := range packet.longHdrPackets {
|
||||
c.logLongHeaderPacket(p, ecn, datagramID)
|
||||
}
|
||||
if p := packet.shortHdrPacket; p != nil {
|
||||
c.logShortHeaderPacketWithDatagramID(*p, ecn, p.Length, true, datagramID)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) qlogTransportParameters(tp *wire.TransportParameters, sentBy protocol.Perspective, restore bool) {
|
||||
ev := qlog.ParametersSet{
|
||||
Restore: restore,
|
||||
OriginalDestinationConnectionID: tp.OriginalDestinationConnectionID,
|
||||
InitialSourceConnectionID: tp.InitialSourceConnectionID,
|
||||
RetrySourceConnectionID: tp.RetrySourceConnectionID,
|
||||
StatelessResetToken: tp.StatelessResetToken,
|
||||
DisableActiveMigration: tp.DisableActiveMigration,
|
||||
MaxIdleTimeout: tp.MaxIdleTimeout,
|
||||
MaxUDPPayloadSize: tp.MaxUDPPayloadSize,
|
||||
AckDelayExponent: tp.AckDelayExponent,
|
||||
MaxAckDelay: tp.MaxAckDelay,
|
||||
ActiveConnectionIDLimit: tp.ActiveConnectionIDLimit,
|
||||
InitialMaxData: tp.InitialMaxData,
|
||||
InitialMaxStreamDataBidiLocal: tp.InitialMaxStreamDataBidiLocal,
|
||||
InitialMaxStreamDataBidiRemote: tp.InitialMaxStreamDataBidiRemote,
|
||||
InitialMaxStreamDataUni: tp.InitialMaxStreamDataUni,
|
||||
InitialMaxStreamsBidi: int64(tp.MaxBidiStreamNum),
|
||||
InitialMaxStreamsUni: int64(tp.MaxUniStreamNum),
|
||||
MaxDatagramFrameSize: tp.MaxDatagramFrameSize,
|
||||
EnableResetStreamAt: tp.EnableResetStreamAt,
|
||||
}
|
||||
if sentBy == c.perspective {
|
||||
ev.Initiator = qlog.InitiatorLocal
|
||||
} else {
|
||||
ev.Initiator = qlog.InitiatorRemote
|
||||
}
|
||||
if tp.PreferredAddress != nil {
|
||||
ev.PreferredAddress = &qlog.PreferredAddress{
|
||||
IPv4: tp.PreferredAddress.IPv4,
|
||||
IPv6: tp.PreferredAddress.IPv6,
|
||||
ConnectionID: tp.PreferredAddress.ConnectionID,
|
||||
StatelessResetToken: tp.PreferredAddress.StatelessResetToken,
|
||||
}
|
||||
}
|
||||
c.qlogger.RecordEvent(ev)
|
||||
}
|
||||
|
||||
func toQlogECN(ecn protocol.ECN) qlog.ECN {
|
||||
//nolint:exhaustive // only need to handle the 3 valid values
|
||||
switch ecn {
|
||||
case protocol.ECT0:
|
||||
return qlog.ECT0
|
||||
case protocol.ECT1:
|
||||
return qlog.ECT1
|
||||
case protocol.ECNCE:
|
||||
return qlog.ECNCE
|
||||
default:
|
||||
return qlog.ECNUnsupported
|
||||
}
|
||||
}
|
||||
|
||||
func toQlogPacketType(pt protocol.PacketType) qlog.PacketType {
|
||||
var qpt qlog.PacketType
|
||||
switch pt {
|
||||
case protocol.PacketTypeInitial:
|
||||
qpt = qlog.PacketTypeInitial
|
||||
case protocol.PacketTypeHandshake:
|
||||
qpt = qlog.PacketTypeHandshake
|
||||
case protocol.PacketType0RTT:
|
||||
qpt = qlog.PacketType0RTT
|
||||
case protocol.PacketTypeRetry:
|
||||
qpt = qlog.PacketTypeRetry
|
||||
}
|
||||
return qpt
|
||||
}
|
||||
|
||||
func toPathEndpointInfo(addr *net.UDPAddr) qlog.PathEndpointInfo {
|
||||
if addr == nil {
|
||||
return qlog.PathEndpointInfo{}
|
||||
}
|
||||
|
||||
var info qlog.PathEndpointInfo
|
||||
if addr.IP == nil || addr.IP.To4() != nil {
|
||||
addrPort := netip.AddrPortFrom(netip.AddrFrom4([4]byte(addr.IP.To4())), uint16(addr.Port))
|
||||
if addrPort.IsValid() {
|
||||
info.IPv4 = addrPort
|
||||
}
|
||||
} else {
|
||||
addrPort := netip.AddrPortFrom(netip.AddrFrom16([16]byte(addr.IP.To16())), uint16(addr.Port))
|
||||
if addrPort.IsValid() {
|
||||
info.IPv6 = addrPort
|
||||
}
|
||||
}
|
||||
return info
|
||||
}
|
||||
|
||||
// startedConnectionEvent builds a StartedConnection event using consistent logic
|
||||
// for both endpoints. If the local address is unspecified (e.g., dual-stack
|
||||
// listener), it selects the family based on the remote address and uses the
|
||||
// unspecified address of that family with the local port.
|
||||
func startedConnectionEvent(local, remote *net.UDPAddr) qlog.StartedConnection {
|
||||
var localInfo, remoteInfo qlog.PathEndpointInfo
|
||||
if remote != nil {
|
||||
remoteInfo = toPathEndpointInfo(remote)
|
||||
}
|
||||
if local != nil {
|
||||
if local.IP == nil || local.IP.IsUnspecified() {
|
||||
// Choose local family based on the remote address family.
|
||||
if remote != nil && remote.IP.To4() != nil {
|
||||
ap := netip.AddrPortFrom(netip.AddrFrom4([4]byte{}), uint16(local.Port))
|
||||
if ap.IsValid() {
|
||||
localInfo.IPv4 = ap
|
||||
}
|
||||
} else if remote != nil && remote.IP.To16() != nil && remote.IP.To4() == nil {
|
||||
ap := netip.AddrPortFrom(netip.AddrFrom16([16]byte{}), uint16(local.Port))
|
||||
if ap.IsValid() {
|
||||
localInfo.IPv6 = ap
|
||||
}
|
||||
}
|
||||
} else {
|
||||
localInfo = toPathEndpointInfo(local)
|
||||
}
|
||||
}
|
||||
return qlog.StartedConnection{Local: localInfo, Remote: remoteInfo}
|
||||
}
|
||||
249
vendor/github.com/quic-go/quic-go/crypto_stream.go
generated
vendored
Normal file
249
vendor/github.com/quic-go/quic-go/crypto_stream.go
generated
vendored
Normal file
@@ -0,0 +1,249 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"slices"
|
||||
"strconv"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
const disableClientHelloScramblingEnv = "QUIC_GO_DISABLE_CLIENTHELLO_SCRAMBLING"
|
||||
|
||||
// The baseCryptoStream is used by the cryptoStream and the initialCryptoStream.
|
||||
// This allows us to implement different logic for PopCryptoFrame for the two streams.
|
||||
type baseCryptoStream struct {
|
||||
queue frameSorter
|
||||
|
||||
highestOffset protocol.ByteCount
|
||||
finished bool
|
||||
|
||||
writeOffset protocol.ByteCount
|
||||
writeBuf []byte
|
||||
}
|
||||
|
||||
func newCryptoStream() *cryptoStream {
|
||||
return &cryptoStream{baseCryptoStream{queue: *newFrameSorter()}}
|
||||
}
|
||||
|
||||
func (s *baseCryptoStream) HandleCryptoFrame(f *wire.CryptoFrame) error {
|
||||
highestOffset := f.Offset + protocol.ByteCount(len(f.Data))
|
||||
if maxOffset := highestOffset; maxOffset > protocol.MaxCryptoStreamOffset {
|
||||
return &qerr.TransportError{
|
||||
ErrorCode: qerr.CryptoBufferExceeded,
|
||||
ErrorMessage: fmt.Sprintf("received invalid offset %d on crypto stream, maximum allowed %d", maxOffset, protocol.MaxCryptoStreamOffset),
|
||||
}
|
||||
}
|
||||
if s.finished {
|
||||
if highestOffset > s.highestOffset {
|
||||
// reject crypto data received after this stream was already finished
|
||||
return &qerr.TransportError{
|
||||
ErrorCode: qerr.ProtocolViolation,
|
||||
ErrorMessage: "received crypto data after change of encryption level",
|
||||
}
|
||||
}
|
||||
// ignore data with a smaller offset than the highest received
|
||||
// could e.g. be a retransmission
|
||||
return nil
|
||||
}
|
||||
s.highestOffset = max(s.highestOffset, highestOffset)
|
||||
return s.queue.Push(f.Data, f.Offset, nil)
|
||||
}
|
||||
|
||||
// GetCryptoData retrieves data that was received in CRYPTO frames
|
||||
func (s *baseCryptoStream) GetCryptoData() []byte {
|
||||
_, data, _ := s.queue.Pop()
|
||||
return data
|
||||
}
|
||||
|
||||
func (s *baseCryptoStream) Finish() error {
|
||||
if s.queue.HasMoreData() {
|
||||
return &qerr.TransportError{
|
||||
ErrorCode: qerr.ProtocolViolation,
|
||||
ErrorMessage: "encryption level changed, but crypto stream has more data to read",
|
||||
}
|
||||
}
|
||||
s.finished = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Writes writes data that should be sent out in CRYPTO frames
|
||||
func (s *baseCryptoStream) Write(p []byte) (int, error) {
|
||||
s.writeBuf = append(s.writeBuf, p...)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (s *baseCryptoStream) HasData() bool {
|
||||
return len(s.writeBuf) > 0
|
||||
}
|
||||
|
||||
func (s *baseCryptoStream) PopCryptoFrame(maxLen protocol.ByteCount) *wire.CryptoFrame {
|
||||
f := &wire.CryptoFrame{Offset: s.writeOffset}
|
||||
n := min(f.MaxDataLen(maxLen), protocol.ByteCount(len(s.writeBuf)))
|
||||
if n <= 0 {
|
||||
return nil
|
||||
}
|
||||
f.Data = s.writeBuf[:n]
|
||||
s.writeBuf = s.writeBuf[n:]
|
||||
s.writeOffset += n
|
||||
return f
|
||||
}
|
||||
|
||||
type cryptoStream struct {
|
||||
baseCryptoStream
|
||||
}
|
||||
|
||||
type clientHelloCut struct {
|
||||
start protocol.ByteCount
|
||||
end protocol.ByteCount
|
||||
}
|
||||
|
||||
type initialCryptoStream struct {
|
||||
baseCryptoStream
|
||||
|
||||
scramble bool
|
||||
end protocol.ByteCount
|
||||
cuts [2]clientHelloCut
|
||||
}
|
||||
|
||||
func newInitialCryptoStream(isClient bool) *initialCryptoStream {
|
||||
var scramble bool
|
||||
if isClient {
|
||||
disabled, err := strconv.ParseBool(os.Getenv(disableClientHelloScramblingEnv))
|
||||
scramble = err != nil || !disabled
|
||||
}
|
||||
s := &initialCryptoStream{
|
||||
baseCryptoStream: baseCryptoStream{queue: *newFrameSorter()},
|
||||
scramble: scramble,
|
||||
}
|
||||
for i := range len(s.cuts) {
|
||||
s.cuts[i].start = protocol.InvalidByteCount
|
||||
s.cuts[i].end = protocol.InvalidByteCount
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *initialCryptoStream) HasData() bool {
|
||||
// The ClientHello might be written in multiple parts.
|
||||
// In order to correctly split the ClientHello, we need the entire ClientHello has been queued.
|
||||
if s.scramble && s.writeOffset == 0 && s.cuts[0].start == protocol.InvalidByteCount {
|
||||
return false
|
||||
}
|
||||
return s.baseCryptoStream.HasData()
|
||||
}
|
||||
|
||||
func (s *initialCryptoStream) Write(p []byte) (int, error) {
|
||||
s.writeBuf = append(s.writeBuf, p...)
|
||||
if !s.scramble {
|
||||
return len(p), nil
|
||||
}
|
||||
if s.cuts[0].start == protocol.InvalidByteCount {
|
||||
sniPos, sniLen, echPos, err := findSNIAndECH(s.writeBuf)
|
||||
if errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
return len(p), nil
|
||||
}
|
||||
if err != nil {
|
||||
return len(p), err
|
||||
}
|
||||
if sniPos == -1 && echPos == -1 {
|
||||
// Neither SNI nor ECH found.
|
||||
// There's nothing to scramble.
|
||||
s.scramble = false
|
||||
return len(p), nil
|
||||
}
|
||||
s.end = protocol.ByteCount(len(s.writeBuf))
|
||||
s.cuts[0].start = protocol.ByteCount(sniPos + sniLen/2) // right in the middle
|
||||
s.cuts[0].end = protocol.ByteCount(sniPos + sniLen)
|
||||
if echPos > 0 {
|
||||
// ECH extension found, cut the ECH extension type value (a uint16) in half
|
||||
start := protocol.ByteCount(echPos + 1)
|
||||
s.cuts[1].start = start
|
||||
// cut somewhere (16 bytes), most likely in the ECH extension value
|
||||
s.cuts[1].end = min(start+16, s.end)
|
||||
}
|
||||
slices.SortFunc(s.cuts[:], func(a, b clientHelloCut) int {
|
||||
if a.start == protocol.InvalidByteCount {
|
||||
return 1
|
||||
}
|
||||
if a.start > b.start {
|
||||
return 1
|
||||
}
|
||||
return -1
|
||||
})
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (s *initialCryptoStream) PopCryptoFrame(maxLen protocol.ByteCount) *wire.CryptoFrame {
|
||||
if !s.scramble {
|
||||
return s.baseCryptoStream.PopCryptoFrame(maxLen)
|
||||
}
|
||||
|
||||
// send out the skipped parts
|
||||
if s.writeOffset == s.end {
|
||||
var foundCuts bool
|
||||
var f *wire.CryptoFrame
|
||||
for i, c := range s.cuts {
|
||||
if c.start == protocol.InvalidByteCount {
|
||||
continue
|
||||
}
|
||||
foundCuts = true
|
||||
if f != nil {
|
||||
break
|
||||
}
|
||||
f = &wire.CryptoFrame{Offset: c.start}
|
||||
n := min(f.MaxDataLen(maxLen), c.end-c.start)
|
||||
if n <= 0 {
|
||||
return nil
|
||||
}
|
||||
f.Data = s.writeBuf[c.start : c.start+n]
|
||||
s.cuts[i].start += n
|
||||
if s.cuts[i].start == c.end {
|
||||
s.cuts[i].start = protocol.InvalidByteCount
|
||||
s.cuts[i].end = protocol.InvalidByteCount
|
||||
foundCuts = false
|
||||
}
|
||||
}
|
||||
if !foundCuts {
|
||||
// no more cuts found, we're done sending out everything up until s.end
|
||||
s.writeBuf = s.writeBuf[s.end:]
|
||||
s.end = protocol.InvalidByteCount
|
||||
s.scramble = false
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
nextCut := clientHelloCut{start: protocol.InvalidByteCount, end: protocol.InvalidByteCount}
|
||||
for _, c := range s.cuts {
|
||||
if c.start == protocol.InvalidByteCount {
|
||||
continue
|
||||
}
|
||||
if c.start > s.writeOffset {
|
||||
nextCut = c
|
||||
break
|
||||
}
|
||||
}
|
||||
f := &wire.CryptoFrame{Offset: s.writeOffset}
|
||||
maxOffset := nextCut.start
|
||||
if maxOffset == protocol.InvalidByteCount {
|
||||
maxOffset = s.end
|
||||
}
|
||||
n := min(f.MaxDataLen(maxLen), maxOffset-s.writeOffset)
|
||||
if n <= 0 {
|
||||
return nil
|
||||
}
|
||||
f.Data = s.writeBuf[s.writeOffset : s.writeOffset+n]
|
||||
// Don't reslice the writeBuf yet.
|
||||
// This is done once all parts have been sent out.
|
||||
s.writeOffset += n
|
||||
if s.writeOffset == nextCut.start {
|
||||
s.writeOffset = nextCut.end
|
||||
}
|
||||
|
||||
return f
|
||||
}
|
||||
73
vendor/github.com/quic-go/quic-go/crypto_stream_manager.go
generated
vendored
Normal file
73
vendor/github.com/quic-go/quic-go/crypto_stream_manager.go
generated
vendored
Normal file
@@ -0,0 +1,73 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
type cryptoStreamManager struct {
|
||||
initialStream *initialCryptoStream
|
||||
handshakeStream *cryptoStream
|
||||
oneRTTStream *cryptoStream
|
||||
}
|
||||
|
||||
func newCryptoStreamManager(
|
||||
initialStream *initialCryptoStream,
|
||||
handshakeStream *cryptoStream,
|
||||
oneRTTStream *cryptoStream,
|
||||
) *cryptoStreamManager {
|
||||
return &cryptoStreamManager{
|
||||
initialStream: initialStream,
|
||||
handshakeStream: handshakeStream,
|
||||
oneRTTStream: oneRTTStream,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error {
|
||||
//nolint:exhaustive // CRYPTO frames cannot be sent in 0-RTT packets.
|
||||
switch encLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
return m.initialStream.HandleCryptoFrame(frame)
|
||||
case protocol.EncryptionHandshake:
|
||||
return m.handshakeStream.HandleCryptoFrame(frame)
|
||||
case protocol.Encryption1RTT:
|
||||
return m.oneRTTStream.HandleCryptoFrame(frame)
|
||||
default:
|
||||
return fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *cryptoStreamManager) GetCryptoData(encLevel protocol.EncryptionLevel) []byte {
|
||||
//nolint:exhaustive // CRYPTO frames cannot be sent in 0-RTT packets.
|
||||
switch encLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
return m.initialStream.GetCryptoData()
|
||||
case protocol.EncryptionHandshake:
|
||||
return m.handshakeStream.GetCryptoData()
|
||||
case protocol.Encryption1RTT:
|
||||
return m.oneRTTStream.GetCryptoData()
|
||||
default:
|
||||
panic(fmt.Sprintf("received CRYPTO frame with unexpected encryption level: %s", encLevel))
|
||||
}
|
||||
}
|
||||
|
||||
func (m *cryptoStreamManager) GetPostHandshakeData(maxSize protocol.ByteCount) *wire.CryptoFrame {
|
||||
if !m.oneRTTStream.HasData() {
|
||||
return nil
|
||||
}
|
||||
return m.oneRTTStream.PopCryptoFrame(maxSize)
|
||||
}
|
||||
|
||||
func (m *cryptoStreamManager) Drop(encLevel protocol.EncryptionLevel) error {
|
||||
//nolint:exhaustive // 1-RTT keys should never get dropped.
|
||||
switch encLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
return m.initialStream.Finish()
|
||||
case protocol.EncryptionHandshake:
|
||||
return m.handshakeStream.Finish()
|
||||
default:
|
||||
panic(fmt.Sprintf("dropped unexpected encryption level: %s", encLevel))
|
||||
}
|
||||
}
|
||||
137
vendor/github.com/quic-go/quic-go/datagram_queue.go
generated
vendored
Normal file
137
vendor/github.com/quic-go/quic-go/datagram_queue.go
generated
vendored
Normal file
@@ -0,0 +1,137 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/internal/utils/ringbuffer"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
const (
|
||||
maxDatagramSendQueueLen = 32
|
||||
maxDatagramRcvQueueLen = 128
|
||||
)
|
||||
|
||||
type datagramQueue struct {
|
||||
sendMx sync.Mutex
|
||||
sendQueue ringbuffer.RingBuffer[*wire.DatagramFrame]
|
||||
sent chan struct{} // used to notify Add that a datagram was dequeued
|
||||
|
||||
rcvMx sync.Mutex
|
||||
rcvQueue [][]byte
|
||||
rcvd chan struct{} // used to notify Receive that a new datagram was received
|
||||
|
||||
closeErr error
|
||||
closed chan struct{}
|
||||
|
||||
hasData func()
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
func newDatagramQueue(hasData func(), logger utils.Logger) *datagramQueue {
|
||||
return &datagramQueue{
|
||||
hasData: hasData,
|
||||
rcvd: make(chan struct{}, 1),
|
||||
sent: make(chan struct{}, 1),
|
||||
closed: make(chan struct{}),
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Add queues a new DATAGRAM frame for sending.
|
||||
// Up to 32 DATAGRAM frames will be queued.
|
||||
// Once that limit is reached, Add blocks until the queue size has reduced.
|
||||
func (h *datagramQueue) Add(f *wire.DatagramFrame) error {
|
||||
h.sendMx.Lock()
|
||||
|
||||
for {
|
||||
if h.sendQueue.Len() < maxDatagramSendQueueLen {
|
||||
h.sendQueue.PushBack(f)
|
||||
h.sendMx.Unlock()
|
||||
h.hasData()
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
case <-h.sent: // drain the queue so we don't loop immediately
|
||||
default:
|
||||
}
|
||||
h.sendMx.Unlock()
|
||||
select {
|
||||
case <-h.closed:
|
||||
return h.closeErr
|
||||
case <-h.sent:
|
||||
}
|
||||
h.sendMx.Lock()
|
||||
}
|
||||
}
|
||||
|
||||
// Peek gets the next DATAGRAM frame for sending.
|
||||
// If actually sent out, Pop needs to be called before the next call to Peek.
|
||||
func (h *datagramQueue) Peek() *wire.DatagramFrame {
|
||||
h.sendMx.Lock()
|
||||
defer h.sendMx.Unlock()
|
||||
if h.sendQueue.Empty() {
|
||||
return nil
|
||||
}
|
||||
return h.sendQueue.PeekFront()
|
||||
}
|
||||
|
||||
func (h *datagramQueue) Pop() {
|
||||
h.sendMx.Lock()
|
||||
defer h.sendMx.Unlock()
|
||||
_ = h.sendQueue.PopFront()
|
||||
select {
|
||||
case h.sent <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// HandleDatagramFrame handles a received DATAGRAM frame.
|
||||
func (h *datagramQueue) HandleDatagramFrame(f *wire.DatagramFrame) {
|
||||
data := make([]byte, len(f.Data))
|
||||
copy(data, f.Data)
|
||||
var queued bool
|
||||
h.rcvMx.Lock()
|
||||
if len(h.rcvQueue) < maxDatagramRcvQueueLen {
|
||||
h.rcvQueue = append(h.rcvQueue, data)
|
||||
queued = true
|
||||
select {
|
||||
case h.rcvd <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
h.rcvMx.Unlock()
|
||||
if !queued && h.logger.Debug() {
|
||||
h.logger.Debugf("Discarding received DATAGRAM frame (%d bytes payload)", len(f.Data))
|
||||
}
|
||||
}
|
||||
|
||||
// Receive gets a received DATAGRAM frame.
|
||||
func (h *datagramQueue) Receive(ctx context.Context) ([]byte, error) {
|
||||
for {
|
||||
h.rcvMx.Lock()
|
||||
if len(h.rcvQueue) > 0 {
|
||||
data := h.rcvQueue[0]
|
||||
h.rcvQueue = h.rcvQueue[1:]
|
||||
h.rcvMx.Unlock()
|
||||
return data, nil
|
||||
}
|
||||
h.rcvMx.Unlock()
|
||||
select {
|
||||
case <-h.rcvd:
|
||||
continue
|
||||
case <-h.closed:
|
||||
return nil, h.closeErr
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *datagramQueue) CloseWithError(e error) {
|
||||
h.closeErr = e
|
||||
close(h.closed)
|
||||
}
|
||||
105
vendor/github.com/quic-go/quic-go/errors.go
generated
vendored
Normal file
105
vendor/github.com/quic-go/quic-go/errors.go
generated
vendored
Normal file
@@ -0,0 +1,105 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
)
|
||||
|
||||
type (
|
||||
// TransportError indicates an error that occurred on the QUIC transport layer.
|
||||
// Every transport error other than CONNECTION_REFUSED and APPLICATION_ERROR is
|
||||
// likely a bug in the implementation.
|
||||
TransportError = qerr.TransportError
|
||||
// ApplicationError is an application-defined error.
|
||||
ApplicationError = qerr.ApplicationError
|
||||
// VersionNegotiationError indicates a failure to negotiate a QUIC version.
|
||||
VersionNegotiationError = qerr.VersionNegotiationError
|
||||
// StatelessResetError indicates a stateless reset was received.
|
||||
// This can happen when the peer reboots, or when packets are misrouted.
|
||||
// See section 10.3 of RFC 9000 for details.
|
||||
StatelessResetError = qerr.StatelessResetError
|
||||
// IdleTimeoutError indicates that the connection timed out because it was inactive for too long.
|
||||
IdleTimeoutError = qerr.IdleTimeoutError
|
||||
// HandshakeTimeoutError indicates that the connection timed out before completing the handshake.
|
||||
HandshakeTimeoutError = qerr.HandshakeTimeoutError
|
||||
)
|
||||
|
||||
type (
|
||||
// TransportErrorCode is a QUIC transport error code, see section 20 of RFC 9000.
|
||||
TransportErrorCode = qerr.TransportErrorCode
|
||||
// ApplicationErrorCode is an QUIC application error code.
|
||||
ApplicationErrorCode = qerr.ApplicationErrorCode
|
||||
// StreamErrorCode is a QUIC stream error code. The meaning of the value is defined by the application.
|
||||
StreamErrorCode = qerr.StreamErrorCode
|
||||
)
|
||||
|
||||
const (
|
||||
// NoError is the NO_ERROR transport error code.
|
||||
NoError = qerr.NoError
|
||||
// InternalError is the INTERNAL_ERROR transport error code.
|
||||
InternalError = qerr.InternalError
|
||||
// ConnectionRefused is the CONNECTION_REFUSED transport error code.
|
||||
ConnectionRefused = qerr.ConnectionRefused
|
||||
// FlowControlError is the FLOW_CONTROL_ERROR transport error code.
|
||||
FlowControlError = qerr.FlowControlError
|
||||
// StreamLimitError is the STREAM_LIMIT_ERROR transport error code.
|
||||
StreamLimitError = qerr.StreamLimitError
|
||||
// StreamStateError is the STREAM_STATE_ERROR transport error code.
|
||||
StreamStateError = qerr.StreamStateError
|
||||
// FinalSizeError is the FINAL_SIZE_ERROR transport error code.
|
||||
FinalSizeError = qerr.FinalSizeError
|
||||
// FrameEncodingError is the FRAME_ENCODING_ERROR transport error code.
|
||||
FrameEncodingError = qerr.FrameEncodingError
|
||||
// TransportParameterError is the TRANSPORT_PARAMETER_ERROR transport error code.
|
||||
TransportParameterError = qerr.TransportParameterError
|
||||
// ConnectionIDLimitError is the CONNECTION_ID_LIMIT_ERROR transport error code.
|
||||
ConnectionIDLimitError = qerr.ConnectionIDLimitError
|
||||
// ProtocolViolation is the PROTOCOL_VIOLATION transport error code.
|
||||
ProtocolViolation = qerr.ProtocolViolation
|
||||
// InvalidToken is the INVALID_TOKEN transport error code.
|
||||
InvalidToken = qerr.InvalidToken
|
||||
// ApplicationErrorErrorCode is the APPLICATION_ERROR transport error code.
|
||||
ApplicationErrorErrorCode = qerr.ApplicationErrorErrorCode
|
||||
// CryptoBufferExceeded is the CRYPTO_BUFFER_EXCEEDED transport error code.
|
||||
CryptoBufferExceeded = qerr.CryptoBufferExceeded
|
||||
// KeyUpdateError is the KEY_UPDATE_ERROR transport error code.
|
||||
KeyUpdateError = qerr.KeyUpdateError
|
||||
// AEADLimitReached is the AEAD_LIMIT_REACHED transport error code.
|
||||
AEADLimitReached = qerr.AEADLimitReached
|
||||
// NoViablePathError is the NO_VIABLE_PATH_ERROR transport error code.
|
||||
NoViablePathError = qerr.NoViablePathError
|
||||
)
|
||||
|
||||
// A StreamError is used to signal stream cancellations.
|
||||
// It is returned from the Read and Write methods of the [ReceiveStream], [SendStream] and [Stream].
|
||||
type StreamError struct {
|
||||
StreamID StreamID
|
||||
ErrorCode StreamErrorCode
|
||||
Remote bool
|
||||
}
|
||||
|
||||
func (e *StreamError) Is(target error) bool {
|
||||
t, ok := target.(*StreamError)
|
||||
return ok && e.StreamID == t.StreamID && e.ErrorCode == t.ErrorCode && e.Remote == t.Remote
|
||||
}
|
||||
|
||||
func (e *StreamError) Error() string {
|
||||
pers := "local"
|
||||
if e.Remote {
|
||||
pers = "remote"
|
||||
}
|
||||
return fmt.Sprintf("stream %d canceled by %s with error code %d", e.StreamID, pers, e.ErrorCode)
|
||||
}
|
||||
|
||||
// DatagramTooLargeError is returned from Conn.SendDatagram if the payload is too large to be sent.
|
||||
type DatagramTooLargeError struct {
|
||||
MaxDatagramPayloadSize int64
|
||||
}
|
||||
|
||||
func (e *DatagramTooLargeError) Is(target error) bool {
|
||||
t, ok := target.(*DatagramTooLargeError)
|
||||
return ok && e.MaxDatagramPayloadSize == t.MaxDatagramPayloadSize
|
||||
}
|
||||
|
||||
func (e *DatagramTooLargeError) Error() string { return "DATAGRAM frame too large" }
|
||||
274
vendor/github.com/quic-go/quic-go/frame_sorter.go
generated
vendored
Normal file
274
vendor/github.com/quic-go/quic-go/frame_sorter.go
generated
vendored
Normal file
@@ -0,0 +1,274 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
list "github.com/quic-go/quic-go/internal/utils/linkedlist"
|
||||
)
|
||||
|
||||
// byteInterval is an interval from one ByteCount to the other
|
||||
type byteInterval struct {
|
||||
Start protocol.ByteCount
|
||||
End protocol.ByteCount
|
||||
}
|
||||
|
||||
var byteIntervalElementPool sync.Pool
|
||||
|
||||
func init() {
|
||||
byteIntervalElementPool = *list.NewPool[byteInterval]()
|
||||
}
|
||||
|
||||
type frameSorterEntry struct {
|
||||
Data []byte
|
||||
DoneCb func()
|
||||
}
|
||||
|
||||
type frameSorter struct {
|
||||
queue map[protocol.ByteCount]frameSorterEntry
|
||||
readPos protocol.ByteCount
|
||||
gaps *list.List[byteInterval]
|
||||
}
|
||||
|
||||
var errDuplicateStreamData = errors.New("duplicate stream data")
|
||||
|
||||
func newFrameSorter() *frameSorter {
|
||||
s := frameSorter{
|
||||
gaps: list.NewWithPool[byteInterval](&byteIntervalElementPool),
|
||||
queue: make(map[protocol.ByteCount]frameSorterEntry),
|
||||
}
|
||||
s.gaps.PushFront(byteInterval{Start: 0, End: protocol.MaxByteCount})
|
||||
return &s
|
||||
}
|
||||
|
||||
func (s *frameSorter) Push(data []byte, offset protocol.ByteCount, doneCb func()) error {
|
||||
err := s.push(data, offset, doneCb)
|
||||
if err == errDuplicateStreamData {
|
||||
if doneCb != nil {
|
||||
doneCb()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *frameSorter) push(data []byte, offset protocol.ByteCount, doneCb func()) error {
|
||||
if len(data) == 0 {
|
||||
return errDuplicateStreamData
|
||||
}
|
||||
|
||||
start := offset
|
||||
end := offset + protocol.ByteCount(len(data))
|
||||
|
||||
if end <= s.gaps.Front().Value.Start {
|
||||
return errDuplicateStreamData
|
||||
}
|
||||
|
||||
startGap, startsInGap := s.findStartGap(start)
|
||||
endGap, endsInGap := s.findEndGap(startGap, end)
|
||||
|
||||
startGapEqualsEndGap := startGap == endGap
|
||||
|
||||
if (startGapEqualsEndGap && end <= startGap.Value.Start) ||
|
||||
(!startGapEqualsEndGap && startGap.Value.End >= endGap.Value.Start && end <= startGap.Value.Start) {
|
||||
return errDuplicateStreamData
|
||||
}
|
||||
|
||||
startGapNext := startGap.Next()
|
||||
startGapEnd := startGap.Value.End // save it, in case startGap is modified
|
||||
endGapStart := endGap.Value.Start // save it, in case endGap is modified
|
||||
endGapEnd := endGap.Value.End // save it, in case endGap is modified
|
||||
var adjustedStartGapEnd bool
|
||||
var wasCut bool
|
||||
|
||||
pos := start
|
||||
var hasReplacedAtLeastOne bool
|
||||
for {
|
||||
oldEntry, ok := s.queue[pos]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
oldEntryLen := protocol.ByteCount(len(oldEntry.Data))
|
||||
if end-pos > oldEntryLen || (hasReplacedAtLeastOne && end-pos == oldEntryLen) {
|
||||
// The existing frame is shorter than the new frame. Replace it.
|
||||
delete(s.queue, pos)
|
||||
pos += oldEntryLen
|
||||
hasReplacedAtLeastOne = true
|
||||
if oldEntry.DoneCb != nil {
|
||||
oldEntry.DoneCb()
|
||||
}
|
||||
} else {
|
||||
if !hasReplacedAtLeastOne {
|
||||
return errDuplicateStreamData
|
||||
}
|
||||
// The existing frame is longer than the new frame.
|
||||
// Cut the new frame such that the end aligns with the start of the existing frame.
|
||||
data = data[:pos-start]
|
||||
end = pos
|
||||
wasCut = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !startsInGap && !hasReplacedAtLeastOne {
|
||||
// cut the frame, such that it starts at the start of the gap
|
||||
data = data[startGap.Value.Start-start:]
|
||||
start = startGap.Value.Start
|
||||
wasCut = true
|
||||
}
|
||||
if start <= startGap.Value.Start {
|
||||
if end >= startGap.Value.End {
|
||||
// The frame covers the whole startGap. Delete the gap.
|
||||
s.gaps.Remove(startGap)
|
||||
} else {
|
||||
startGap.Value.Start = end
|
||||
}
|
||||
} else if !hasReplacedAtLeastOne {
|
||||
startGap.Value.End = start
|
||||
adjustedStartGapEnd = true
|
||||
}
|
||||
|
||||
if !startGapEqualsEndGap {
|
||||
s.deleteConsecutive(startGapEnd)
|
||||
var nextGap *list.Element[byteInterval]
|
||||
for gap := startGapNext; gap.Value.End < endGapStart; gap = nextGap {
|
||||
nextGap = gap.Next()
|
||||
s.deleteConsecutive(gap.Value.End)
|
||||
s.gaps.Remove(gap)
|
||||
}
|
||||
}
|
||||
|
||||
if !endsInGap && start != endGapEnd && end > endGapEnd {
|
||||
// cut the frame, such that it ends at the end of the gap
|
||||
data = data[:endGapEnd-start]
|
||||
end = endGapEnd
|
||||
wasCut = true
|
||||
}
|
||||
if end == endGapEnd {
|
||||
if !startGapEqualsEndGap {
|
||||
// The frame covers the whole endGap. Delete the gap.
|
||||
s.gaps.Remove(endGap)
|
||||
}
|
||||
} else {
|
||||
if startGapEqualsEndGap && adjustedStartGapEnd {
|
||||
// The frame split the existing gap into two.
|
||||
s.gaps.InsertAfter(byteInterval{Start: end, End: startGapEnd}, startGap)
|
||||
} else if !startGapEqualsEndGap {
|
||||
endGap.Value.Start = end
|
||||
}
|
||||
}
|
||||
|
||||
if wasCut && len(data) < protocol.MinStreamFrameBufferSize {
|
||||
newData := make([]byte, len(data))
|
||||
copy(newData, data)
|
||||
data = newData
|
||||
if doneCb != nil {
|
||||
doneCb()
|
||||
doneCb = nil
|
||||
}
|
||||
}
|
||||
|
||||
if s.gaps.Len() > protocol.MaxStreamFrameSorterGaps {
|
||||
return errors.New("too many gaps in received data")
|
||||
}
|
||||
|
||||
s.queue[start] = frameSorterEntry{Data: data, DoneCb: doneCb}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *frameSorter) findStartGap(offset protocol.ByteCount) (*list.Element[byteInterval], bool) {
|
||||
for gap := s.gaps.Front(); gap != nil; gap = gap.Next() {
|
||||
if offset >= gap.Value.Start && offset <= gap.Value.End {
|
||||
return gap, true
|
||||
}
|
||||
if offset < gap.Value.Start {
|
||||
return gap, false
|
||||
}
|
||||
}
|
||||
panic("no gap found")
|
||||
}
|
||||
|
||||
func (s *frameSorter) findEndGap(startGap *list.Element[byteInterval], offset protocol.ByteCount) (*list.Element[byteInterval], bool) {
|
||||
for gap := startGap; gap != nil; gap = gap.Next() {
|
||||
if offset >= gap.Value.Start && offset < gap.Value.End {
|
||||
return gap, true
|
||||
}
|
||||
if offset < gap.Value.Start {
|
||||
return gap.Prev(), false
|
||||
}
|
||||
}
|
||||
panic("no gap found")
|
||||
}
|
||||
|
||||
// deleteConsecutive deletes consecutive frames from the queue, starting at pos
|
||||
func (s *frameSorter) deleteConsecutive(pos protocol.ByteCount) {
|
||||
for {
|
||||
oldEntry, ok := s.queue[pos]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
oldEntryLen := protocol.ByteCount(len(oldEntry.Data))
|
||||
delete(s.queue, pos)
|
||||
if oldEntry.DoneCb != nil {
|
||||
oldEntry.DoneCb()
|
||||
}
|
||||
pos += oldEntryLen
|
||||
}
|
||||
}
|
||||
|
||||
func (s *frameSorter) Pop() (protocol.ByteCount, []byte, func()) {
|
||||
entry, ok := s.queue[s.readPos]
|
||||
if !ok {
|
||||
return s.readPos, nil, nil
|
||||
}
|
||||
delete(s.queue, s.readPos)
|
||||
offset := s.readPos
|
||||
s.readPos += protocol.ByteCount(len(entry.Data))
|
||||
if s.gaps.Front().Value.End <= s.readPos {
|
||||
panic("frame sorter BUG: read position higher than a gap")
|
||||
}
|
||||
return offset, entry.Data, entry.DoneCb
|
||||
}
|
||||
|
||||
// HasMoreData says if there is any more data queued at *any* offset.
|
||||
func (s *frameSorter) HasMoreData() bool {
|
||||
return len(s.queue) > 0
|
||||
}
|
||||
|
||||
var errTooLittleData = errors.New("too little data")
|
||||
|
||||
// Peek copies len(p) consecutive bytes starting at offset into p, without removing them.
|
||||
// It is only possible to peek from an offset where a frame starts.
|
||||
//
|
||||
// If there isn't enough consecutive data available, errTooLittleData is returned.
|
||||
func (s *frameSorter) Peek(offset protocol.ByteCount, p []byte) error {
|
||||
if len(p) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// first, check if we have enough consecutive data available
|
||||
pos := offset
|
||||
remaining := len(p)
|
||||
for remaining > 0 {
|
||||
entry, ok := s.queue[pos]
|
||||
if !ok {
|
||||
return errTooLittleData
|
||||
}
|
||||
entryLen := len(entry.Data)
|
||||
if remaining <= entryLen {
|
||||
break // enough data available
|
||||
}
|
||||
remaining -= entryLen
|
||||
pos += protocol.ByteCount(entryLen)
|
||||
}
|
||||
|
||||
pos = offset
|
||||
var copied int
|
||||
for copied < len(p) {
|
||||
entry := s.queue[pos] // the entry is guaranteed to exist from the check above
|
||||
copied += copy(p[copied:], entry.Data)
|
||||
pos += protocol.ByteCount(len(entry.Data))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
295
vendor/github.com/quic-go/quic-go/framer.go
generated
vendored
Normal file
295
vendor/github.com/quic-go/quic-go/framer.go
generated
vendored
Normal file
@@ -0,0 +1,295 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/ackhandler"
|
||||
"github.com/quic-go/quic-go/internal/flowcontrol"
|
||||
"github.com/quic-go/quic-go/internal/monotime"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils/ringbuffer"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
"github.com/quic-go/quic-go/quicvarint"
|
||||
)
|
||||
|
||||
const (
|
||||
maxPathResponses = 256
|
||||
maxControlFrames = 16 << 10
|
||||
)
|
||||
|
||||
// This is the largest possible size of a stream-related control frame
|
||||
// (which is the RESET_STREAM frame).
|
||||
const maxStreamControlFrameSize = 25
|
||||
|
||||
type streamFrameGetter interface {
|
||||
popStreamFrame(protocol.ByteCount, protocol.Version) (ackhandler.StreamFrame, *wire.StreamDataBlockedFrame, bool)
|
||||
}
|
||||
|
||||
type streamControlFrameGetter interface {
|
||||
getControlFrame(monotime.Time) (_ ackhandler.Frame, ok, hasMore bool)
|
||||
}
|
||||
|
||||
type framer struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
activeStreams map[protocol.StreamID]streamFrameGetter
|
||||
streamQueue ringbuffer.RingBuffer[protocol.StreamID]
|
||||
streamsWithControlFrames map[protocol.StreamID]streamControlFrameGetter
|
||||
|
||||
controlFrameMutex sync.Mutex
|
||||
controlFrames []wire.Frame
|
||||
pathResponses []*wire.PathResponseFrame
|
||||
connFlowController flowcontrol.ConnectionFlowController
|
||||
queuedTooManyControlFrames bool
|
||||
}
|
||||
|
||||
func newFramer(connFlowController flowcontrol.ConnectionFlowController) *framer {
|
||||
return &framer{
|
||||
activeStreams: make(map[protocol.StreamID]streamFrameGetter),
|
||||
streamsWithControlFrames: make(map[protocol.StreamID]streamControlFrameGetter),
|
||||
connFlowController: connFlowController,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *framer) HasData() bool {
|
||||
f.mutex.Lock()
|
||||
hasData := !f.streamQueue.Empty()
|
||||
f.mutex.Unlock()
|
||||
if hasData {
|
||||
return true
|
||||
}
|
||||
f.controlFrameMutex.Lock()
|
||||
defer f.controlFrameMutex.Unlock()
|
||||
return len(f.streamsWithControlFrames) > 0 || len(f.controlFrames) > 0 || len(f.pathResponses) > 0
|
||||
}
|
||||
|
||||
func (f *framer) QueueControlFrame(frame wire.Frame) {
|
||||
f.controlFrameMutex.Lock()
|
||||
defer f.controlFrameMutex.Unlock()
|
||||
|
||||
if pr, ok := frame.(*wire.PathResponseFrame); ok {
|
||||
// Only queue up to maxPathResponses PATH_RESPONSE frames.
|
||||
// This limit should be high enough to never be hit in practice,
|
||||
// unless the peer is doing something malicious.
|
||||
if len(f.pathResponses) >= maxPathResponses {
|
||||
return
|
||||
}
|
||||
f.pathResponses = append(f.pathResponses, pr)
|
||||
return
|
||||
}
|
||||
// This is a hack.
|
||||
if len(f.controlFrames) >= maxControlFrames {
|
||||
f.queuedTooManyControlFrames = true
|
||||
return
|
||||
}
|
||||
f.controlFrames = append(f.controlFrames, frame)
|
||||
}
|
||||
|
||||
func (f *framer) Append(
|
||||
frames []ackhandler.Frame,
|
||||
streamFrames []ackhandler.StreamFrame,
|
||||
maxLen protocol.ByteCount,
|
||||
now monotime.Time,
|
||||
v protocol.Version,
|
||||
) ([]ackhandler.Frame, []ackhandler.StreamFrame, protocol.ByteCount) {
|
||||
f.controlFrameMutex.Lock()
|
||||
frames, controlFrameLen := f.appendControlFrames(frames, maxLen, now, v)
|
||||
maxLen -= controlFrameLen
|
||||
|
||||
var lastFrame ackhandler.StreamFrame
|
||||
var streamFrameLen protocol.ByteCount
|
||||
f.mutex.Lock()
|
||||
// pop STREAM frames, until less than 128 bytes are left in the packet
|
||||
numActiveStreams := f.streamQueue.Len()
|
||||
for i := 0; i < numActiveStreams; i++ {
|
||||
if protocol.MinStreamFrameSize > maxLen {
|
||||
break
|
||||
}
|
||||
sf, blocked := f.getNextStreamFrame(maxLen, v)
|
||||
if sf.Frame != nil {
|
||||
streamFrames = append(streamFrames, sf)
|
||||
maxLen -= sf.Frame.Length(v)
|
||||
lastFrame = sf
|
||||
streamFrameLen += sf.Frame.Length(v)
|
||||
}
|
||||
// If the stream just became blocked on stream flow control, attempt to pack the
|
||||
// STREAM_DATA_BLOCKED into the same packet.
|
||||
if blocked != nil {
|
||||
l := blocked.Length(v)
|
||||
// In case it doesn't fit, queue it for the next packet.
|
||||
if maxLen < l {
|
||||
f.controlFrames = append(f.controlFrames, blocked)
|
||||
break
|
||||
}
|
||||
frames = append(frames, ackhandler.Frame{Frame: blocked})
|
||||
maxLen -= l
|
||||
controlFrameLen += l
|
||||
}
|
||||
}
|
||||
|
||||
// The only way to become blocked on connection-level flow control is by sending STREAM frames.
|
||||
if isBlocked, offset := f.connFlowController.IsNewlyBlocked(); isBlocked {
|
||||
blocked := &wire.DataBlockedFrame{MaximumData: offset}
|
||||
l := blocked.Length(v)
|
||||
// In case it doesn't fit, queue it for the next packet.
|
||||
if maxLen >= l {
|
||||
frames = append(frames, ackhandler.Frame{Frame: blocked})
|
||||
controlFrameLen += l
|
||||
} else {
|
||||
f.controlFrames = append(f.controlFrames, blocked)
|
||||
}
|
||||
}
|
||||
|
||||
f.mutex.Unlock()
|
||||
f.controlFrameMutex.Unlock()
|
||||
|
||||
if lastFrame.Frame != nil {
|
||||
// account for the smaller size of the last STREAM frame
|
||||
streamFrameLen -= lastFrame.Frame.Length(v)
|
||||
lastFrame.Frame.DataLenPresent = false
|
||||
streamFrameLen += lastFrame.Frame.Length(v)
|
||||
}
|
||||
|
||||
return frames, streamFrames, controlFrameLen + streamFrameLen
|
||||
}
|
||||
|
||||
func (f *framer) appendControlFrames(
|
||||
frames []ackhandler.Frame,
|
||||
maxLen protocol.ByteCount,
|
||||
now monotime.Time,
|
||||
v protocol.Version,
|
||||
) ([]ackhandler.Frame, protocol.ByteCount) {
|
||||
var length protocol.ByteCount
|
||||
// add a PATH_RESPONSE first, but only pack a single PATH_RESPONSE per packet
|
||||
if len(f.pathResponses) > 0 {
|
||||
frame := f.pathResponses[0]
|
||||
frameLen := frame.Length(v)
|
||||
if frameLen <= maxLen {
|
||||
frames = append(frames, ackhandler.Frame{Frame: frame})
|
||||
length += frameLen
|
||||
f.pathResponses = f.pathResponses[1:]
|
||||
}
|
||||
}
|
||||
|
||||
// add stream-related control frames
|
||||
for id, str := range f.streamsWithControlFrames {
|
||||
start:
|
||||
remainingLen := maxLen - length
|
||||
if remainingLen <= maxStreamControlFrameSize {
|
||||
break
|
||||
}
|
||||
fr, ok, hasMore := str.getControlFrame(now)
|
||||
if !hasMore {
|
||||
delete(f.streamsWithControlFrames, id)
|
||||
}
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
frames = append(frames, fr)
|
||||
length += fr.Frame.Length(v)
|
||||
if hasMore {
|
||||
// It is rare that a stream has more than one control frame to queue.
|
||||
// We don't want to spawn another loop for just to cover that case.
|
||||
goto start
|
||||
}
|
||||
}
|
||||
|
||||
for len(f.controlFrames) > 0 {
|
||||
frame := f.controlFrames[len(f.controlFrames)-1]
|
||||
frameLen := frame.Length(v)
|
||||
if length+frameLen > maxLen {
|
||||
break
|
||||
}
|
||||
frames = append(frames, ackhandler.Frame{Frame: frame})
|
||||
length += frameLen
|
||||
f.controlFrames = f.controlFrames[:len(f.controlFrames)-1]
|
||||
}
|
||||
|
||||
return frames, length
|
||||
}
|
||||
|
||||
// QueuedTooManyControlFrames says if the control frame queue exceeded its maximum queue length.
|
||||
// This is a hack.
|
||||
// It is easier to implement than propagating an error return value in QueueControlFrame.
|
||||
// The correct solution would be to queue frames with their respective structs.
|
||||
// See https://github.com/quic-go/quic-go/issues/4271 for the queueing of stream-related control frames.
|
||||
func (f *framer) QueuedTooManyControlFrames() bool {
|
||||
return f.queuedTooManyControlFrames
|
||||
}
|
||||
|
||||
func (f *framer) AddActiveStream(id protocol.StreamID, str streamFrameGetter) {
|
||||
f.mutex.Lock()
|
||||
if _, ok := f.activeStreams[id]; !ok {
|
||||
f.streamQueue.PushBack(id)
|
||||
f.activeStreams[id] = str
|
||||
}
|
||||
f.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (f *framer) AddStreamWithControlFrames(id protocol.StreamID, str streamControlFrameGetter) {
|
||||
f.controlFrameMutex.Lock()
|
||||
if _, ok := f.streamsWithControlFrames[id]; !ok {
|
||||
f.streamsWithControlFrames[id] = str
|
||||
}
|
||||
f.controlFrameMutex.Unlock()
|
||||
}
|
||||
|
||||
// RemoveActiveStream is called when a stream completes.
|
||||
func (f *framer) RemoveActiveStream(id protocol.StreamID) {
|
||||
f.mutex.Lock()
|
||||
delete(f.activeStreams, id)
|
||||
// We don't delete the stream from the streamQueue,
|
||||
// since we'd have to iterate over the ringbuffer.
|
||||
// Instead, we check if the stream is still in activeStreams when appending STREAM frames.
|
||||
f.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (f *framer) getNextStreamFrame(maxLen protocol.ByteCount, v protocol.Version) (ackhandler.StreamFrame, *wire.StreamDataBlockedFrame) {
|
||||
id := f.streamQueue.PopFront()
|
||||
// This should never return an error. Better check it anyway.
|
||||
// The stream will only be in the streamQueue, if it enqueued itself there.
|
||||
str, ok := f.activeStreams[id]
|
||||
// The stream might have been removed after being enqueued.
|
||||
if !ok {
|
||||
return ackhandler.StreamFrame{}, nil
|
||||
}
|
||||
// For the last STREAM frame, we'll remove the DataLen field later.
|
||||
// Therefore, we can pretend to have more bytes available when popping
|
||||
// the STREAM frame (which will always have the DataLen set).
|
||||
maxLen += protocol.ByteCount(quicvarint.Len(uint64(maxLen)))
|
||||
frame, blocked, hasMoreData := str.popStreamFrame(maxLen, v)
|
||||
if hasMoreData { // put the stream back in the queue (at the end)
|
||||
f.streamQueue.PushBack(id)
|
||||
} else { // no more data to send. Stream is not active
|
||||
delete(f.activeStreams, id)
|
||||
}
|
||||
// Note that the frame.Frame can be nil:
|
||||
// * if the stream was canceled after it said it had data
|
||||
// * the remaining size doesn't allow us to add another STREAM frame
|
||||
return frame, blocked
|
||||
}
|
||||
|
||||
func (f *framer) Handle0RTTRejection() {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
f.controlFrameMutex.Lock()
|
||||
defer f.controlFrameMutex.Unlock()
|
||||
|
||||
f.streamQueue.Clear()
|
||||
for id := range f.activeStreams {
|
||||
delete(f.activeStreams, id)
|
||||
}
|
||||
var j int
|
||||
for i, frame := range f.controlFrames {
|
||||
switch frame.(type) {
|
||||
case *wire.MaxDataFrame, *wire.MaxStreamDataFrame, *wire.MaxStreamsFrame,
|
||||
*wire.DataBlockedFrame, *wire.StreamDataBlockedFrame, *wire.StreamsBlockedFrame:
|
||||
continue
|
||||
default:
|
||||
f.controlFrames[j] = f.controlFrames[i]
|
||||
j++
|
||||
}
|
||||
}
|
||||
f.controlFrames = slices.Delete(f.controlFrames, j, len(f.controlFrames))
|
||||
}
|
||||
9
vendor/github.com/quic-go/quic-go/http3/README.md
generated
vendored
Normal file
9
vendor/github.com/quic-go/quic-go/http3/README.md
generated
vendored
Normal file
@@ -0,0 +1,9 @@
|
||||
# HTTP/3
|
||||
|
||||
[](https://quic-go.net/docs/)
|
||||
[](https://pkg.go.dev/github.com/quic-go/quic-go/http3)
|
||||
|
||||
This package implements HTTP/3 ([RFC 9114](https://datatracker.ietf.org/doc/html/rfc9114)), including QPACK ([RFC 9204](https://datatracker.ietf.org/doc/html/rfc9204)) and HTTP Datagrams ([RFC 9297](https://datatracker.ietf.org/doc/html/rfc9297)).
|
||||
It aims to provide feature parity with the standard library's HTTP/1.1 and HTTP/2 implementation.
|
||||
|
||||
Detailed documentation can be found on [quic-go.net](https://quic-go.net/docs/).
|
||||
137
vendor/github.com/quic-go/quic-go/http3/body.go
generated
vendored
Normal file
137
vendor/github.com/quic-go/quic-go/http3/body.go
generated
vendored
Normal file
@@ -0,0 +1,137 @@
|
||||
package http3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
)
|
||||
|
||||
// Settingser allows waiting for and retrieving the peer's HTTP/3 settings.
|
||||
type Settingser interface {
|
||||
// ReceivedSettings returns a channel that is closed once the peer's SETTINGS frame was received.
|
||||
// Settings can be obtained from the Settings method after the channel was closed.
|
||||
ReceivedSettings() <-chan struct{}
|
||||
// Settings returns the settings received on this connection.
|
||||
// It is only valid to call this function after the channel returned by ReceivedSettings was closed.
|
||||
Settings() *Settings
|
||||
}
|
||||
|
||||
var errTooMuchData = errors.New("peer sent too much data")
|
||||
|
||||
// The body is used in the requestBody (for a http.Request) and the responseBody (for a http.Response).
|
||||
type body struct {
|
||||
str *Stream
|
||||
|
||||
remainingContentLength int64
|
||||
violatedContentLength bool
|
||||
hasContentLength bool
|
||||
}
|
||||
|
||||
func newBody(str *Stream, contentLength int64) *body {
|
||||
b := &body{str: str}
|
||||
if contentLength >= 0 {
|
||||
b.hasContentLength = true
|
||||
b.remainingContentLength = contentLength
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func (r *body) StreamID() quic.StreamID { return r.str.StreamID() }
|
||||
|
||||
func (r *body) checkContentLengthViolation() error {
|
||||
if !r.hasContentLength {
|
||||
return nil
|
||||
}
|
||||
if r.remainingContentLength < 0 || r.remainingContentLength == 0 && r.str.hasMoreData() {
|
||||
if !r.violatedContentLength {
|
||||
r.str.CancelRead(quic.StreamErrorCode(ErrCodeMessageError))
|
||||
r.str.CancelWrite(quic.StreamErrorCode(ErrCodeMessageError))
|
||||
r.violatedContentLength = true
|
||||
}
|
||||
return errTooMuchData
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *body) Read(b []byte) (int, error) {
|
||||
if err := r.checkContentLengthViolation(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if r.hasContentLength {
|
||||
b = b[:min(int64(len(b)), r.remainingContentLength)]
|
||||
}
|
||||
n, err := r.str.Read(b)
|
||||
r.remainingContentLength -= int64(n)
|
||||
if err := r.checkContentLengthViolation(); err != nil {
|
||||
return n, err
|
||||
}
|
||||
return n, maybeReplaceError(err)
|
||||
}
|
||||
|
||||
func (r *body) Close() error {
|
||||
r.str.CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled))
|
||||
return nil
|
||||
}
|
||||
|
||||
type requestBody struct {
|
||||
body
|
||||
connCtx context.Context
|
||||
rcvdSettings <-chan struct{}
|
||||
getSettings func() *Settings
|
||||
}
|
||||
|
||||
var _ io.ReadCloser = &requestBody{}
|
||||
|
||||
func newRequestBody(str *Stream, contentLength int64, connCtx context.Context, rcvdSettings <-chan struct{}, getSettings func() *Settings) *requestBody {
|
||||
return &requestBody{
|
||||
body: *newBody(str, contentLength),
|
||||
connCtx: connCtx,
|
||||
rcvdSettings: rcvdSettings,
|
||||
getSettings: getSettings,
|
||||
}
|
||||
}
|
||||
|
||||
type hijackableBody struct {
|
||||
body body
|
||||
|
||||
// only set for the http.Response
|
||||
// The channel is closed when the user is done with this response:
|
||||
// either when Read() errors, or when Close() is called.
|
||||
reqDone chan<- struct{}
|
||||
reqDoneOnce sync.Once
|
||||
}
|
||||
|
||||
var _ io.ReadCloser = &hijackableBody{}
|
||||
|
||||
func newResponseBody(str *Stream, contentLength int64, done chan<- struct{}) *hijackableBody {
|
||||
return &hijackableBody{
|
||||
body: *newBody(str, contentLength),
|
||||
reqDone: done,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *hijackableBody) Read(b []byte) (int, error) {
|
||||
n, err := r.body.Read(b)
|
||||
if err != nil {
|
||||
r.requestDone()
|
||||
}
|
||||
return n, maybeReplaceError(err)
|
||||
}
|
||||
|
||||
func (r *hijackableBody) requestDone() {
|
||||
if r.reqDone != nil {
|
||||
r.reqDoneOnce.Do(func() {
|
||||
close(r.reqDone)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (r *hijackableBody) Close() error {
|
||||
r.requestDone()
|
||||
// If the EOF was read, CancelRead() is a no-op.
|
||||
r.body.str.CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled))
|
||||
return nil
|
||||
}
|
||||
61
vendor/github.com/quic-go/quic-go/http3/capsule.go
generated
vendored
Normal file
61
vendor/github.com/quic-go/quic-go/http3/capsule.go
generated
vendored
Normal file
@@ -0,0 +1,61 @@
|
||||
package http3
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/quic-go/quic-go/quicvarint"
|
||||
)
|
||||
|
||||
// CapsuleType is the type of the capsule
|
||||
type CapsuleType uint64
|
||||
|
||||
// CapsuleProtocolHeader is the header value used to advertise support for the capsule protocol
|
||||
const CapsuleProtocolHeader = "Capsule-Protocol"
|
||||
|
||||
type exactReader struct {
|
||||
R io.LimitedReader
|
||||
}
|
||||
|
||||
func (r *exactReader) Read(b []byte) (int, error) {
|
||||
n, err := r.R.Read(b)
|
||||
if err == io.EOF && r.R.N > 0 {
|
||||
return n, io.ErrUnexpectedEOF
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// ParseCapsule parses the header of a Capsule.
|
||||
// It returns an io.Reader that can be used to read the Capsule value.
|
||||
// The Capsule value must be read entirely (i.e. until the io.EOF) before using r again.
|
||||
func ParseCapsule(r quicvarint.Reader) (CapsuleType, io.Reader, error) {
|
||||
cbr := countingByteReader{Reader: r}
|
||||
ct, err := quicvarint.Read(&cbr)
|
||||
if err != nil {
|
||||
// If an io.EOF is returned without consuming any bytes, return it unmodified.
|
||||
// Otherwise, return an io.ErrUnexpectedEOF.
|
||||
if err == io.EOF && cbr.NumRead > 0 {
|
||||
return 0, nil, io.ErrUnexpectedEOF
|
||||
}
|
||||
return 0, nil, err
|
||||
}
|
||||
l, err := quicvarint.Read(r)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return 0, nil, io.ErrUnexpectedEOF
|
||||
}
|
||||
return 0, nil, err
|
||||
}
|
||||
return CapsuleType(ct), &exactReader{R: io.LimitedReader{R: r, N: int64(l)}}, nil
|
||||
}
|
||||
|
||||
// WriteCapsule writes a capsule
|
||||
func WriteCapsule(w quicvarint.Writer, ct CapsuleType, value []byte) error {
|
||||
b := make([]byte, 0, 16)
|
||||
b = quicvarint.Append(b, uint64(ct))
|
||||
b = quicvarint.Append(b, uint64(len(value)))
|
||||
if _, err := w.Write(b); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := w.Write(value)
|
||||
return err
|
||||
}
|
||||
505
vendor/github.com/quic-go/quic-go/http3/client.go
generated
vendored
Normal file
505
vendor/github.com/quic-go/quic-go/http3/client.go
generated
vendored
Normal file
@@ -0,0 +1,505 @@
|
||||
package http3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
"net/textproto"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/qpack"
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/http3/qlog"
|
||||
"github.com/quic-go/quic-go/qlogwriter"
|
||||
)
|
||||
|
||||
const (
|
||||
// MethodGet0RTT allows a GET request to be sent using 0-RTT.
|
||||
// Note that 0-RTT doesn't provide replay protection and should only be used for idempotent requests.
|
||||
MethodGet0RTT = "GET_0RTT"
|
||||
// MethodHead0RTT allows a HEAD request to be sent using 0-RTT.
|
||||
// Note that 0-RTT doesn't provide replay protection and should only be used for idempotent requests.
|
||||
MethodHead0RTT = "HEAD_0RTT"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultUserAgent = "quic-go HTTP/3"
|
||||
defaultMaxResponseHeaderBytes = 10 * 1 << 20 // 10 MB
|
||||
)
|
||||
|
||||
var errGoAway = errors.New("connection in graceful shutdown")
|
||||
|
||||
type errConnUnusable struct{ e error }
|
||||
|
||||
func (e *errConnUnusable) Unwrap() error { return e.e }
|
||||
func (e *errConnUnusable) Error() string { return fmt.Sprintf("http3: conn unusable: %s", e.e.Error()) }
|
||||
|
||||
const max1xxResponses = 5 // arbitrary bound on number of informational responses
|
||||
|
||||
var defaultQuicConfig = &quic.Config{
|
||||
MaxIncomingStreams: -1, // don't allow the server to create bidirectional streams
|
||||
KeepAlivePeriod: 10 * time.Second,
|
||||
}
|
||||
|
||||
// ClientConn is an HTTP/3 client doing requests to a single remote server.
|
||||
type ClientConn struct {
|
||||
conn *quic.Conn
|
||||
rawConn *rawConn
|
||||
|
||||
decoder *qpack.Decoder
|
||||
|
||||
// Additional HTTP/3 settings.
|
||||
// It is invalid to specify any settings defined by RFC 9114 (HTTP/3) and RFC 9297 (HTTP Datagrams).
|
||||
additionalSettings map[uint64]uint64
|
||||
|
||||
// maxResponseHeaderBytes specifies a limit on how many response bytes are
|
||||
// allowed in the server's response header.
|
||||
maxResponseHeaderBytes int
|
||||
|
||||
// disableCompression, if true, prevents the Transport from requesting compression with an
|
||||
// "Accept-Encoding: gzip" request header when the Request contains no existing Accept-Encoding value.
|
||||
// If the Transport requests gzip on its own and gets a gzipped response, it's transparently
|
||||
// decoded in the Response.Body.
|
||||
// However, if the user explicitly requested gzip it is not automatically uncompressed.
|
||||
disableCompression bool
|
||||
|
||||
streamMx sync.Mutex
|
||||
maxStreamID quic.StreamID // set once a GOAWAY frame is received
|
||||
lastStreamID quic.StreamID // the highest stream ID that was opened
|
||||
|
||||
qlogger qlogwriter.Recorder
|
||||
logger *slog.Logger
|
||||
|
||||
requestWriter *requestWriter
|
||||
}
|
||||
|
||||
var _ http.RoundTripper = &ClientConn{}
|
||||
|
||||
func newClientConn(
|
||||
conn *quic.Conn,
|
||||
enableDatagrams bool,
|
||||
additionalSettings map[uint64]uint64,
|
||||
maxResponseHeaderBytes int,
|
||||
disableCompression bool,
|
||||
logger *slog.Logger,
|
||||
) *ClientConn {
|
||||
var qlogger qlogwriter.Recorder
|
||||
if qlogTrace := conn.QlogTrace(); qlogTrace != nil && qlogTrace.SupportsSchemas(qlog.EventSchema) {
|
||||
qlogger = qlogTrace.AddProducer()
|
||||
}
|
||||
c := &ClientConn{
|
||||
conn: conn,
|
||||
additionalSettings: additionalSettings,
|
||||
disableCompression: disableCompression,
|
||||
maxStreamID: invalidStreamID,
|
||||
lastStreamID: invalidStreamID,
|
||||
logger: logger,
|
||||
qlogger: qlogger,
|
||||
decoder: qpack.NewDecoder(),
|
||||
}
|
||||
if maxResponseHeaderBytes <= 0 {
|
||||
c.maxResponseHeaderBytes = defaultMaxResponseHeaderBytes
|
||||
} else {
|
||||
c.maxResponseHeaderBytes = maxResponseHeaderBytes
|
||||
}
|
||||
c.requestWriter = newRequestWriter()
|
||||
c.rawConn = newRawConn(
|
||||
conn,
|
||||
enableDatagrams,
|
||||
c.onStreamsEmpty,
|
||||
c.handleControlStream,
|
||||
qlogger,
|
||||
c.logger,
|
||||
)
|
||||
// send the SETTINGs frame, using 0-RTT data, if possible
|
||||
go func() {
|
||||
_, err := c.rawConn.openControlStream(&settingsFrame{
|
||||
Datagram: enableDatagrams,
|
||||
Other: additionalSettings,
|
||||
MaxFieldSectionSize: int64(c.maxResponseHeaderBytes),
|
||||
})
|
||||
if err != nil {
|
||||
if c.logger != nil {
|
||||
c.logger.Debug("setting up connection failed", "error", err)
|
||||
}
|
||||
c.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeInternalError), "")
|
||||
return
|
||||
}
|
||||
}()
|
||||
return c
|
||||
}
|
||||
|
||||
// OpenRequestStream opens a new request stream on the HTTP/3 connection.
|
||||
func (c *ClientConn) OpenRequestStream(ctx context.Context) (*RequestStream, error) {
|
||||
return c.openRequestStream(ctx, c.requestWriter, nil, c.disableCompression, c.maxResponseHeaderBytes)
|
||||
}
|
||||
|
||||
func (c *ClientConn) openRequestStream(
|
||||
ctx context.Context,
|
||||
requestWriter *requestWriter,
|
||||
reqDone chan<- struct{},
|
||||
disableCompression bool,
|
||||
maxHeaderBytes int,
|
||||
) (*RequestStream, error) {
|
||||
c.streamMx.Lock()
|
||||
maxStreamID := c.maxStreamID
|
||||
var nextStreamID quic.StreamID
|
||||
if c.lastStreamID == invalidStreamID {
|
||||
nextStreamID = 0
|
||||
} else {
|
||||
nextStreamID = c.lastStreamID + 4
|
||||
}
|
||||
c.streamMx.Unlock()
|
||||
// Streams with stream ID equal to or greater than the stream ID carried in the GOAWAY frame
|
||||
// will be rejected, see section 5.2 of RFC 9114.
|
||||
if maxStreamID != invalidStreamID && nextStreamID >= maxStreamID {
|
||||
return nil, errGoAway
|
||||
}
|
||||
|
||||
str, err := c.conn.OpenStreamSync(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.streamMx.Lock()
|
||||
// take the maximum here, as multiple OpenStreamSync calls might have returned concurrently
|
||||
if c.lastStreamID == invalidStreamID {
|
||||
c.lastStreamID = str.StreamID()
|
||||
} else {
|
||||
c.lastStreamID = max(c.lastStreamID, str.StreamID())
|
||||
}
|
||||
// check again, in case a (or another) GOAWAY frame was received
|
||||
maxStreamID = c.maxStreamID
|
||||
c.streamMx.Unlock()
|
||||
|
||||
if maxStreamID != invalidStreamID && str.StreamID() >= maxStreamID {
|
||||
str.CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled))
|
||||
str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
|
||||
return nil, errGoAway
|
||||
}
|
||||
|
||||
hstr := c.rawConn.TrackStream(str)
|
||||
rsp := &http.Response{}
|
||||
trace := httptrace.ContextClientTrace(ctx)
|
||||
return newRequestStream(
|
||||
newStream(hstr, c.rawConn, trace, func(r io.Reader, hf *headersFrame) error {
|
||||
hdr, err := decodeTrailers(r, hf, maxHeaderBytes, c.decoder, c.qlogger, str.StreamID())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rsp.Trailer = hdr
|
||||
return nil
|
||||
}, c.qlogger),
|
||||
requestWriter,
|
||||
reqDone,
|
||||
c.decoder,
|
||||
disableCompression,
|
||||
maxHeaderBytes,
|
||||
rsp,
|
||||
), nil
|
||||
}
|
||||
|
||||
func (c *ClientConn) handleUnidirectionalStream(str *quic.ReceiveStream) {
|
||||
c.rawConn.handleUnidirectionalStream(str, false)
|
||||
}
|
||||
|
||||
func (c *ClientConn) handleControlStream(str *quic.ReceiveStream, fp *frameParser) {
|
||||
for {
|
||||
f, err := fp.ParseNext(c.qlogger)
|
||||
if err != nil {
|
||||
var serr *quic.StreamError
|
||||
if err == io.EOF || errors.As(err, &serr) {
|
||||
c.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeClosedCriticalStream), "")
|
||||
return
|
||||
}
|
||||
c.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), "")
|
||||
return
|
||||
}
|
||||
// GOAWAY is the only frame allowed at this point:
|
||||
// * unexpected frames are ignored by the frame parser
|
||||
// * we don't support any extension that might add support for more frames
|
||||
goaway, ok := f.(*goAwayFrame)
|
||||
if !ok {
|
||||
c.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "")
|
||||
return
|
||||
}
|
||||
if goaway.StreamID%4 != 0 { // client-initiated, bidirectional streams
|
||||
c.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), "")
|
||||
return
|
||||
}
|
||||
c.streamMx.Lock()
|
||||
// the server is not allowed to increase the Stream ID in subsequent GOAWAY frames
|
||||
if c.maxStreamID != invalidStreamID && goaway.StreamID > c.maxStreamID {
|
||||
c.streamMx.Unlock()
|
||||
c.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), "")
|
||||
return
|
||||
}
|
||||
c.maxStreamID = goaway.StreamID
|
||||
c.streamMx.Unlock()
|
||||
|
||||
hasActiveStreams := c.rawConn.hasActiveStreams()
|
||||
// immediately close the connection if there are currently no active requests
|
||||
if !hasActiveStreams {
|
||||
c.CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientConn) onStreamsEmpty() {
|
||||
c.streamMx.Lock()
|
||||
defer c.streamMx.Unlock()
|
||||
|
||||
// The server is performing a graceful shutdown.
|
||||
if c.maxStreamID != invalidStreamID {
|
||||
c.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "")
|
||||
}
|
||||
}
|
||||
|
||||
// RoundTrip executes a request and returns a response
|
||||
func (c *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
rsp, err := c.roundTrip(req)
|
||||
if err != nil && req.Context().Err() != nil {
|
||||
// if the context was canceled, return the context cancellation error
|
||||
err = req.Context().Err()
|
||||
}
|
||||
return rsp, err
|
||||
}
|
||||
|
||||
func (c *ClientConn) roundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Immediately send out this request, if this is a 0-RTT request.
|
||||
switch req.Method {
|
||||
case MethodGet0RTT:
|
||||
// don't modify the original request
|
||||
reqCopy := *req
|
||||
req = &reqCopy
|
||||
req.Method = http.MethodGet
|
||||
case MethodHead0RTT:
|
||||
// don't modify the original request
|
||||
reqCopy := *req
|
||||
req = &reqCopy
|
||||
req.Method = http.MethodHead
|
||||
default:
|
||||
// wait for the handshake to complete
|
||||
select {
|
||||
case <-c.conn.HandshakeComplete():
|
||||
case <-req.Context().Done():
|
||||
return nil, req.Context().Err()
|
||||
}
|
||||
}
|
||||
|
||||
// It is only possible to send an Extended CONNECT request once the SETTINGS were received.
|
||||
// See section 3 of RFC 8441.
|
||||
if isExtendedConnectRequest(req) {
|
||||
connCtx := c.conn.Context()
|
||||
// wait for the server's SETTINGS frame to arrive
|
||||
select {
|
||||
case <-c.rawConn.ReceivedSettings():
|
||||
case <-connCtx.Done():
|
||||
return nil, context.Cause(connCtx)
|
||||
}
|
||||
if !c.rawConn.Settings().EnableExtendedConnect {
|
||||
return nil, errors.New("http3: server didn't enable Extended CONNECT")
|
||||
}
|
||||
}
|
||||
|
||||
reqDone := make(chan struct{})
|
||||
str, err := c.openRequestStream(
|
||||
req.Context(),
|
||||
c.requestWriter,
|
||||
reqDone,
|
||||
c.disableCompression,
|
||||
c.maxResponseHeaderBytes,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, &errConnUnusable{e: err}
|
||||
}
|
||||
|
||||
// Request Cancellation:
|
||||
// This go routine keeps running even after RoundTripOpt() returns.
|
||||
// It is shut down when the application is done processing the body.
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
select {
|
||||
case <-req.Context().Done():
|
||||
str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
|
||||
str.CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled))
|
||||
case <-reqDone:
|
||||
}
|
||||
}()
|
||||
|
||||
rsp, err := c.doRequest(req, str)
|
||||
if err != nil { // if any error occurred
|
||||
close(reqDone)
|
||||
<-done
|
||||
return nil, maybeReplaceError(err)
|
||||
}
|
||||
return rsp, maybeReplaceError(err)
|
||||
}
|
||||
|
||||
// ReceivedSettings returns a channel that is closed once the server's HTTP/3 settings were received.
|
||||
// Settings can be obtained from the Settings method after the channel was closed.
|
||||
func (c *ClientConn) ReceivedSettings() <-chan struct{} {
|
||||
return c.rawConn.ReceivedSettings()
|
||||
}
|
||||
|
||||
// Settings returns the HTTP/3 settings for this connection.
|
||||
// It is only valid to call this function after the channel returned by ReceivedSettings was closed.
|
||||
func (c *ClientConn) Settings() *Settings {
|
||||
return c.rawConn.Settings()
|
||||
}
|
||||
|
||||
// CloseWithError closes the connection with the given error code and message.
|
||||
// It is invalid to call this function after the connection was closed.
|
||||
func (c *ClientConn) CloseWithError(code quic.ApplicationErrorCode, msg string) error {
|
||||
return c.conn.CloseWithError(code, msg)
|
||||
}
|
||||
|
||||
// Context returns a context that is cancelled when the connection is closed.
|
||||
func (c *ClientConn) Context() context.Context {
|
||||
return c.conn.Context()
|
||||
}
|
||||
|
||||
// cancelingReader reads from the io.Reader.
|
||||
// It cancels writing on the stream if any error other than io.EOF occurs.
|
||||
type cancelingReader struct {
|
||||
r io.Reader
|
||||
str *RequestStream
|
||||
}
|
||||
|
||||
func (r *cancelingReader) Read(b []byte) (int, error) {
|
||||
n, err := r.r.Read(b)
|
||||
if err != nil && err != io.EOF {
|
||||
r.str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *ClientConn) sendRequestBody(str *RequestStream, body io.ReadCloser, contentLength int64) error {
|
||||
defer body.Close()
|
||||
buf := make([]byte, bodyCopyBufferSize)
|
||||
sr := &cancelingReader{str: str, r: body}
|
||||
if contentLength == -1 {
|
||||
_, err := io.CopyBuffer(str, sr, buf)
|
||||
return err
|
||||
}
|
||||
|
||||
// make sure we don't send more bytes than the content length
|
||||
n, err := io.CopyBuffer(str, io.LimitReader(sr, contentLength), buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var extra int64
|
||||
extra, err = io.CopyBuffer(io.Discard, sr, buf)
|
||||
n += extra
|
||||
if n > contentLength {
|
||||
str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
|
||||
return fmt.Errorf("http: ContentLength=%d with Body length %d", contentLength, n)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *ClientConn) doRequest(req *http.Request, str *RequestStream) (*http.Response, error) {
|
||||
trace := httptrace.ContextClientTrace(req.Context())
|
||||
var sendingReqFailed bool
|
||||
if err := str.sendRequestHeader(req); err != nil {
|
||||
traceWroteRequest(trace, err)
|
||||
if c.logger != nil {
|
||||
c.logger.Debug("error writing request", "error", err)
|
||||
}
|
||||
sendingReqFailed = true
|
||||
}
|
||||
if !sendingReqFailed {
|
||||
if req.Body == nil {
|
||||
traceWroteRequest(trace, nil)
|
||||
str.Close()
|
||||
} else {
|
||||
// send the request body asynchronously
|
||||
go func() {
|
||||
defer str.Close()
|
||||
contentLength := int64(-1)
|
||||
// According to the documentation for http.Request.ContentLength,
|
||||
// a value of 0 with a non-nil Body is also treated as unknown content length.
|
||||
if req.ContentLength > 0 {
|
||||
contentLength = req.ContentLength
|
||||
}
|
||||
err := c.sendRequestBody(str, req.Body, contentLength)
|
||||
traceWroteRequest(trace, err)
|
||||
if err != nil {
|
||||
if c.logger != nil {
|
||||
c.logger.Debug("error writing request", "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Trailer) > 0 {
|
||||
if err := str.sendRequestTrailer(req); err != nil {
|
||||
if c.logger != nil {
|
||||
c.logger.Debug("error writing trailers", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// copy from net/http: support 1xx responses
|
||||
var num1xx int // number of informational 1xx headers received
|
||||
var res *http.Response
|
||||
for {
|
||||
var err error
|
||||
res, err = str.ReadResponse()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resCode := res.StatusCode
|
||||
is1xx := 100 <= resCode && resCode <= 199
|
||||
// treat 101 as a terminal status, see https://github.com/golang/go/issues/26161
|
||||
is1xxNonTerminal := is1xx && resCode != http.StatusSwitchingProtocols
|
||||
if is1xxNonTerminal {
|
||||
num1xx++
|
||||
if num1xx > max1xxResponses {
|
||||
str.CancelRead(quic.StreamErrorCode(ErrCodeExcessiveLoad))
|
||||
str.CancelWrite(quic.StreamErrorCode(ErrCodeExcessiveLoad))
|
||||
return nil, errors.New("http3: too many 1xx informational responses")
|
||||
}
|
||||
traceGot1xxResponse(trace, resCode, textproto.MIMEHeader(res.Header))
|
||||
if resCode == http.StatusContinue {
|
||||
traceGot100Continue(trace)
|
||||
}
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
connState := c.conn.ConnectionState().TLS
|
||||
res.TLS = &connState
|
||||
res.Request = req
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// RawClientConn is a low-level HTTP/3 client connection.
|
||||
// It allows the application to take control of the stream accept loops,
|
||||
// giving the application the ability to handle streams originating from the server.
|
||||
type RawClientConn struct {
|
||||
*ClientConn
|
||||
}
|
||||
|
||||
// HandleUnidirectionalStream handles an incoming unidirectional stream.
|
||||
func (c *RawClientConn) HandleUnidirectionalStream(str *quic.ReceiveStream) {
|
||||
c.rawConn.handleUnidirectionalStream(str, false)
|
||||
}
|
||||
|
||||
// HandleBidirectionalStream handles an incoming bidirectional stream.
|
||||
func (c *ClientConn) HandleBidirectionalStream(str *quic.Stream) {
|
||||
// According to RFC 9114, the server is not allowed to open bidirectional streams.
|
||||
c.rawConn.CloseWithError(
|
||||
quic.ApplicationErrorCode(ErrCodeStreamCreationError),
|
||||
fmt.Sprintf("server opened bidirectional stream %d", str.StreamID()),
|
||||
)
|
||||
}
|
||||
321
vendor/github.com/quic-go/quic-go/http3/conn.go
generated
vendored
Normal file
321
vendor/github.com/quic-go/quic-go/http3/conn.go
generated
vendored
Normal file
@@ -0,0 +1,321 @@
|
||||
package http3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/http3/qlog"
|
||||
"github.com/quic-go/quic-go/qlogwriter"
|
||||
"github.com/quic-go/quic-go/quicvarint"
|
||||
)
|
||||
|
||||
const maxQuarterStreamID = 1<<60 - 1
|
||||
|
||||
// invalidStreamID is a stream ID that is invalid. The first valid stream ID in QUIC is 0.
|
||||
const invalidStreamID = quic.StreamID(-1)
|
||||
|
||||
// rawConn is an HTTP/3 connection.
|
||||
// It provides HTTP/3 specific functionality by wrapping a quic.Conn,
|
||||
// in particular handling of unidirectional HTTP/3 streams, SETTINGS and datagrams.
|
||||
type rawConn struct {
|
||||
conn *quic.Conn
|
||||
|
||||
logger *slog.Logger
|
||||
|
||||
enableDatagrams bool
|
||||
|
||||
streamMx sync.Mutex
|
||||
streams map[quic.StreamID]*stateTrackingStream
|
||||
|
||||
rcvdControlStr atomic.Bool
|
||||
rcvdQPACKEncoderStr atomic.Bool
|
||||
rcvdQPACKDecoderStr atomic.Bool
|
||||
controlStrHandler func(*quic.ReceiveStream, *frameParser) // is called *after* the SETTINGS frame was parsed
|
||||
|
||||
onStreamsEmpty func()
|
||||
|
||||
settings *Settings
|
||||
receivedSettings chan struct{}
|
||||
|
||||
qlogger qlogwriter.Recorder
|
||||
qloggerWG sync.WaitGroup // tracks goroutines that may produce qlog events
|
||||
}
|
||||
|
||||
func newRawConn(
|
||||
quicConn *quic.Conn,
|
||||
enableDatagrams bool,
|
||||
onStreamsEmpty func(),
|
||||
controlStrHandler func(*quic.ReceiveStream, *frameParser),
|
||||
qlogger qlogwriter.Recorder,
|
||||
logger *slog.Logger,
|
||||
) *rawConn {
|
||||
c := &rawConn{
|
||||
conn: quicConn,
|
||||
logger: logger,
|
||||
enableDatagrams: enableDatagrams,
|
||||
receivedSettings: make(chan struct{}),
|
||||
streams: make(map[quic.StreamID]*stateTrackingStream),
|
||||
qlogger: qlogger,
|
||||
onStreamsEmpty: onStreamsEmpty,
|
||||
controlStrHandler: controlStrHandler,
|
||||
}
|
||||
if qlogger != nil {
|
||||
context.AfterFunc(quicConn.Context(), c.closeQlogger)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *rawConn) OpenUniStream() (*quic.SendStream, error) {
|
||||
return c.conn.OpenUniStream()
|
||||
}
|
||||
|
||||
// openControlStream opens the control stream and sends the SETTINGS frame.
|
||||
// It returns the control stream (needed by the server for sending GOAWAY later).
|
||||
func (c *rawConn) openControlStream(settings *settingsFrame) (*quic.SendStream, error) {
|
||||
c.qloggerWG.Add(1)
|
||||
defer c.qloggerWG.Done()
|
||||
|
||||
str, err := c.conn.OpenUniStream()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b := make([]byte, 0, 64)
|
||||
b = quicvarint.Append(b, streamTypeControlStream)
|
||||
b = settings.Append(b)
|
||||
if c.qlogger != nil {
|
||||
sf := qlog.SettingsFrame{
|
||||
MaxFieldSectionSize: settings.MaxFieldSectionSize,
|
||||
Other: maps.Clone(settings.Other),
|
||||
}
|
||||
if settings.Datagram {
|
||||
sf.Datagram = pointer(true)
|
||||
}
|
||||
if settings.ExtendedConnect {
|
||||
sf.ExtendedConnect = pointer(true)
|
||||
}
|
||||
c.qlogger.RecordEvent(qlog.FrameCreated{
|
||||
StreamID: str.StreamID(),
|
||||
Raw: qlog.RawInfo{Length: len(b)},
|
||||
Frame: qlog.Frame{Frame: sf},
|
||||
})
|
||||
}
|
||||
if _, err := str.Write(b); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return str, nil
|
||||
}
|
||||
|
||||
func (c *rawConn) TrackStream(str *quic.Stream) *stateTrackingStream {
|
||||
hstr := newStateTrackingStream(str, c, func(b []byte) error { return c.sendDatagram(str.StreamID(), b) })
|
||||
|
||||
c.streamMx.Lock()
|
||||
c.streams[str.StreamID()] = hstr
|
||||
c.qloggerWG.Add(1)
|
||||
c.streamMx.Unlock()
|
||||
return hstr
|
||||
}
|
||||
|
||||
func (c *rawConn) RemoteAddr() net.Addr {
|
||||
return c.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
func (c *rawConn) ConnectionState() quic.ConnectionState {
|
||||
return c.conn.ConnectionState()
|
||||
}
|
||||
|
||||
func (c *rawConn) clearStream(id quic.StreamID) {
|
||||
c.streamMx.Lock()
|
||||
defer c.streamMx.Unlock()
|
||||
|
||||
if _, ok := c.streams[id]; ok {
|
||||
delete(c.streams, id)
|
||||
c.qloggerWG.Done()
|
||||
}
|
||||
if len(c.streams) == 0 {
|
||||
c.onStreamsEmpty()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *rawConn) hasActiveStreams() bool {
|
||||
c.streamMx.Lock()
|
||||
defer c.streamMx.Unlock()
|
||||
|
||||
return len(c.streams) > 0
|
||||
}
|
||||
|
||||
func (c *rawConn) CloseWithError(code quic.ApplicationErrorCode, msg string) error {
|
||||
return c.conn.CloseWithError(code, msg)
|
||||
}
|
||||
|
||||
func (c *rawConn) handleUnidirectionalStream(str *quic.ReceiveStream, isServer bool) {
|
||||
c.qloggerWG.Add(1)
|
||||
defer c.qloggerWG.Done()
|
||||
|
||||
streamType, err := quicvarint.Read(quicvarint.NewReader(str))
|
||||
if err != nil {
|
||||
if c.logger != nil {
|
||||
c.logger.Debug("reading stream type on stream failed", "stream ID", str.StreamID(), "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
// We're only interested in the control stream here.
|
||||
switch streamType {
|
||||
case streamTypeControlStream:
|
||||
case streamTypeQPACKEncoderStream:
|
||||
if isFirst := c.rcvdQPACKEncoderStr.CompareAndSwap(false, true); !isFirst {
|
||||
c.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate QPACK encoder stream")
|
||||
}
|
||||
// Our QPACK implementation doesn't use the dynamic table yet.
|
||||
return
|
||||
case streamTypeQPACKDecoderStream:
|
||||
if isFirst := c.rcvdQPACKDecoderStr.CompareAndSwap(false, true); !isFirst {
|
||||
c.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate QPACK decoder stream")
|
||||
}
|
||||
// Our QPACK implementation doesn't use the dynamic table yet.
|
||||
return
|
||||
case streamTypePushStream:
|
||||
if isServer {
|
||||
// only the server can push
|
||||
c.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "")
|
||||
} else {
|
||||
// we never increased the Push ID, so we don't expect any push streams
|
||||
c.CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), "")
|
||||
}
|
||||
return
|
||||
default:
|
||||
str.CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError))
|
||||
return
|
||||
}
|
||||
// Only a single control stream is allowed.
|
||||
if isFirstControlStr := c.rcvdControlStr.CompareAndSwap(false, true); !isFirstControlStr {
|
||||
c.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate control stream")
|
||||
return
|
||||
}
|
||||
c.handleControlStream(str)
|
||||
}
|
||||
|
||||
func (c *rawConn) handleControlStream(str *quic.ReceiveStream) {
|
||||
fp := &frameParser{closeConn: c.conn.CloseWithError, r: str, streamID: str.StreamID()}
|
||||
f, err := fp.ParseNext(c.qlogger)
|
||||
if err != nil {
|
||||
var serr *quic.StreamError
|
||||
if err == io.EOF || errors.As(err, &serr) {
|
||||
c.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeClosedCriticalStream), "")
|
||||
return
|
||||
}
|
||||
c.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), "")
|
||||
return
|
||||
}
|
||||
sf, ok := f.(*settingsFrame)
|
||||
if !ok {
|
||||
c.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), "")
|
||||
return
|
||||
}
|
||||
c.settings = &Settings{
|
||||
EnableDatagrams: sf.Datagram,
|
||||
EnableExtendedConnect: sf.ExtendedConnect,
|
||||
Other: sf.Other,
|
||||
}
|
||||
close(c.receivedSettings)
|
||||
if sf.Datagram {
|
||||
// If datagram support was enabled on our side as well as on the server side,
|
||||
// we can expect it to have been negotiated both on the transport and on the HTTP/3 layer.
|
||||
// Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT).
|
||||
if c.enableDatagrams && !c.ConnectionState().SupportsDatagrams.Remote {
|
||||
c.CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support")
|
||||
return
|
||||
}
|
||||
c.qloggerWG.Add(1)
|
||||
go func() {
|
||||
defer c.qloggerWG.Done()
|
||||
if err := c.receiveDatagrams(); err != nil {
|
||||
if c.logger != nil {
|
||||
c.logger.Debug("receiving datagrams failed", "error", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if c.controlStrHandler != nil {
|
||||
c.controlStrHandler(str, fp)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *rawConn) sendDatagram(streamID quic.StreamID, b []byte) error {
|
||||
// TODO: this creates a lot of garbage and an additional copy
|
||||
data := make([]byte, 0, len(b)+8)
|
||||
quarterStreamID := uint64(streamID / 4)
|
||||
data = quicvarint.Append(data, uint64(streamID/4))
|
||||
data = append(data, b...)
|
||||
if c.qlogger != nil {
|
||||
c.qlogger.RecordEvent(qlog.DatagramCreated{
|
||||
QuaterStreamID: quarterStreamID,
|
||||
Raw: qlog.RawInfo{
|
||||
Length: len(data),
|
||||
PayloadLength: len(b),
|
||||
},
|
||||
})
|
||||
}
|
||||
return c.conn.SendDatagram(data)
|
||||
}
|
||||
|
||||
func (c *rawConn) receiveDatagrams() error {
|
||||
for {
|
||||
b, err := c.conn.ReceiveDatagram(context.Background())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
quarterStreamID, n, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
c.CloseWithError(quic.ApplicationErrorCode(ErrCodeDatagramError), "")
|
||||
return fmt.Errorf("could not read quarter stream id: %w", err)
|
||||
}
|
||||
if c.qlogger != nil {
|
||||
c.qlogger.RecordEvent(qlog.DatagramParsed{
|
||||
QuaterStreamID: quarterStreamID,
|
||||
Raw: qlog.RawInfo{
|
||||
Length: len(b),
|
||||
PayloadLength: len(b) - n,
|
||||
},
|
||||
})
|
||||
}
|
||||
if quarterStreamID > maxQuarterStreamID {
|
||||
c.CloseWithError(quic.ApplicationErrorCode(ErrCodeDatagramError), "")
|
||||
return fmt.Errorf("invalid quarter stream id: %w", err)
|
||||
}
|
||||
streamID := quic.StreamID(4 * quarterStreamID)
|
||||
c.streamMx.Lock()
|
||||
dg, ok := c.streams[streamID]
|
||||
c.streamMx.Unlock()
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
dg.enqueueDatagram(b[n:])
|
||||
}
|
||||
}
|
||||
|
||||
// ReceivedSettings returns a channel that is closed once the peer's SETTINGS frame was received.
|
||||
// Settings can be optained from the Settings method after the channel was closed.
|
||||
func (c *rawConn) ReceivedSettings() <-chan struct{} { return c.receivedSettings }
|
||||
|
||||
// Settings returns the settings received on this connection.
|
||||
// It is only valid to call this function after the channel returned by ReceivedSettings was closed.
|
||||
func (c *rawConn) Settings() *Settings { return c.settings }
|
||||
|
||||
// closeQlogger waits for all goroutines that may produce qlog events to finish,
|
||||
// then closes the qlogger.
|
||||
func (c *rawConn) closeQlogger() {
|
||||
if c.qlogger == nil {
|
||||
return
|
||||
}
|
||||
c.qloggerWG.Wait()
|
||||
c.qlogger.Close()
|
||||
}
|
||||
63
vendor/github.com/quic-go/quic-go/http3/error.go
generated
vendored
Normal file
63
vendor/github.com/quic-go/quic-go/http3/error.go
generated
vendored
Normal file
@@ -0,0 +1,63 @@
|
||||
package http3
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
)
|
||||
|
||||
// Error is returned from the round tripper (for HTTP clients)
|
||||
// and inside the HTTP handler (for HTTP servers) if an HTTP/3 error occurs.
|
||||
// See section 8 of RFC 9114.
|
||||
type Error struct {
|
||||
Remote bool
|
||||
ErrorCode ErrCode
|
||||
ErrorMessage string
|
||||
}
|
||||
|
||||
var _ error = &Error{}
|
||||
|
||||
func (e *Error) Error() string {
|
||||
s := e.ErrorCode.string()
|
||||
if s == "" {
|
||||
s = fmt.Sprintf("H3 error (%#x)", uint64(e.ErrorCode))
|
||||
}
|
||||
// Usually errors are remote. Only make it explicit for local errors.
|
||||
if !e.Remote {
|
||||
s += " (local)"
|
||||
}
|
||||
if e.ErrorMessage != "" {
|
||||
s += ": " + e.ErrorMessage
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (e *Error) Is(target error) bool {
|
||||
t, ok := target.(*Error)
|
||||
return ok && e.ErrorCode == t.ErrorCode && e.Remote == t.Remote
|
||||
}
|
||||
|
||||
func maybeReplaceError(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
e Error
|
||||
strErr *quic.StreamError
|
||||
appErr *quic.ApplicationError
|
||||
)
|
||||
switch {
|
||||
default:
|
||||
return err
|
||||
case errors.As(err, &strErr):
|
||||
e.Remote = strErr.Remote
|
||||
e.ErrorCode = ErrCode(strErr.ErrorCode)
|
||||
case errors.As(err, &appErr):
|
||||
e.Remote = appErr.Remote
|
||||
e.ErrorCode = ErrCode(appErr.ErrorCode)
|
||||
e.ErrorMessage = appErr.ErrorMessage
|
||||
}
|
||||
return &e
|
||||
}
|
||||
84
vendor/github.com/quic-go/quic-go/http3/error_codes.go
generated
vendored
Normal file
84
vendor/github.com/quic-go/quic-go/http3/error_codes.go
generated
vendored
Normal file
@@ -0,0 +1,84 @@
|
||||
package http3
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
)
|
||||
|
||||
type ErrCode quic.ApplicationErrorCode
|
||||
|
||||
const (
|
||||
ErrCodeNoError ErrCode = 0x100
|
||||
ErrCodeGeneralProtocolError ErrCode = 0x101
|
||||
ErrCodeInternalError ErrCode = 0x102
|
||||
ErrCodeStreamCreationError ErrCode = 0x103
|
||||
ErrCodeClosedCriticalStream ErrCode = 0x104
|
||||
ErrCodeFrameUnexpected ErrCode = 0x105
|
||||
ErrCodeFrameError ErrCode = 0x106
|
||||
ErrCodeExcessiveLoad ErrCode = 0x107
|
||||
ErrCodeIDError ErrCode = 0x108
|
||||
ErrCodeSettingsError ErrCode = 0x109
|
||||
ErrCodeMissingSettings ErrCode = 0x10a
|
||||
ErrCodeRequestRejected ErrCode = 0x10b
|
||||
ErrCodeRequestCanceled ErrCode = 0x10c
|
||||
ErrCodeRequestIncomplete ErrCode = 0x10d
|
||||
ErrCodeMessageError ErrCode = 0x10e
|
||||
ErrCodeConnectError ErrCode = 0x10f
|
||||
ErrCodeVersionFallback ErrCode = 0x110
|
||||
ErrCodeDatagramError ErrCode = 0x33
|
||||
ErrCodeQPACKDecompressionFailed ErrCode = 0x200
|
||||
)
|
||||
|
||||
func (e ErrCode) String() string {
|
||||
s := e.string()
|
||||
if s != "" {
|
||||
return s
|
||||
}
|
||||
return fmt.Sprintf("unknown error code: %#x", uint16(e))
|
||||
}
|
||||
|
||||
func (e ErrCode) string() string {
|
||||
switch e {
|
||||
case ErrCodeNoError:
|
||||
return "H3_NO_ERROR"
|
||||
case ErrCodeGeneralProtocolError:
|
||||
return "H3_GENERAL_PROTOCOL_ERROR"
|
||||
case ErrCodeInternalError:
|
||||
return "H3_INTERNAL_ERROR"
|
||||
case ErrCodeStreamCreationError:
|
||||
return "H3_STREAM_CREATION_ERROR"
|
||||
case ErrCodeClosedCriticalStream:
|
||||
return "H3_CLOSED_CRITICAL_STREAM"
|
||||
case ErrCodeFrameUnexpected:
|
||||
return "H3_FRAME_UNEXPECTED"
|
||||
case ErrCodeFrameError:
|
||||
return "H3_FRAME_ERROR"
|
||||
case ErrCodeExcessiveLoad:
|
||||
return "H3_EXCESSIVE_LOAD"
|
||||
case ErrCodeIDError:
|
||||
return "H3_ID_ERROR"
|
||||
case ErrCodeSettingsError:
|
||||
return "H3_SETTINGS_ERROR"
|
||||
case ErrCodeMissingSettings:
|
||||
return "H3_MISSING_SETTINGS"
|
||||
case ErrCodeRequestRejected:
|
||||
return "H3_REQUEST_REJECTED"
|
||||
case ErrCodeRequestCanceled:
|
||||
return "H3_REQUEST_CANCELLED"
|
||||
case ErrCodeRequestIncomplete:
|
||||
return "H3_INCOMPLETE_REQUEST"
|
||||
case ErrCodeMessageError:
|
||||
return "H3_MESSAGE_ERROR"
|
||||
case ErrCodeConnectError:
|
||||
return "H3_CONNECT_ERROR"
|
||||
case ErrCodeVersionFallback:
|
||||
return "H3_VERSION_FALLBACK"
|
||||
case ErrCodeDatagramError:
|
||||
return "H3_DATAGRAM_ERROR"
|
||||
case ErrCodeQPACKDecompressionFailed:
|
||||
return "QPACK_DECOMPRESSION_FAILED"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
327
vendor/github.com/quic-go/quic-go/http3/frames.go
generated
vendored
Normal file
327
vendor/github.com/quic-go/quic-go/http3/frames.go
generated
vendored
Normal file
@@ -0,0 +1,327 @@
|
||||
package http3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"maps"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/http3/qlog"
|
||||
"github.com/quic-go/quic-go/qlogwriter"
|
||||
"github.com/quic-go/quic-go/quicvarint"
|
||||
)
|
||||
|
||||
// FrameType is the frame type of a HTTP/3 frame
|
||||
type FrameType uint64
|
||||
|
||||
type frame any
|
||||
|
||||
// The maximum length of an encoded HTTP/3 frame header is 16:
|
||||
// The frame has a type and length field, both QUIC varints (maximum 8 bytes in length)
|
||||
const frameHeaderLen = 16
|
||||
|
||||
type countingByteReader struct {
|
||||
quicvarint.Reader
|
||||
NumRead int
|
||||
}
|
||||
|
||||
func (r *countingByteReader) ReadByte() (byte, error) {
|
||||
b, err := r.Reader.ReadByte()
|
||||
if err == nil {
|
||||
r.NumRead++
|
||||
}
|
||||
return b, err
|
||||
}
|
||||
|
||||
func (r *countingByteReader) Read(b []byte) (int, error) {
|
||||
n, err := r.Reader.Read(b)
|
||||
r.NumRead += n
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (r *countingByteReader) Reset() {
|
||||
r.NumRead = 0
|
||||
}
|
||||
|
||||
type frameParser struct {
|
||||
r io.Reader
|
||||
streamID quic.StreamID
|
||||
closeConn func(quic.ApplicationErrorCode, string) error
|
||||
}
|
||||
|
||||
func (p *frameParser) ParseNext(qlogger qlogwriter.Recorder) (frame, error) {
|
||||
r := &countingByteReader{Reader: quicvarint.NewReader(p.r)}
|
||||
for {
|
||||
t, err := quicvarint.Read(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
l, err := quicvarint.Read(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch t {
|
||||
case 0x0: // DATA
|
||||
if qlogger != nil {
|
||||
qlogger.RecordEvent(qlog.FrameParsed{
|
||||
StreamID: p.streamID,
|
||||
Raw: qlog.RawInfo{
|
||||
Length: int(l) + r.NumRead,
|
||||
PayloadLength: int(l),
|
||||
},
|
||||
Frame: qlog.Frame{Frame: qlog.DataFrame{}},
|
||||
})
|
||||
}
|
||||
return &dataFrame{Length: l}, nil
|
||||
case 0x1: // HEADERS
|
||||
return &headersFrame{
|
||||
Length: l,
|
||||
headerLen: r.NumRead,
|
||||
}, nil
|
||||
case 0x4: // SETTINGS
|
||||
return parseSettingsFrame(r, l, p.streamID, qlogger)
|
||||
case 0x3: // unsupported: CANCEL_PUSH
|
||||
if qlogger != nil {
|
||||
qlogger.RecordEvent(qlog.FrameParsed{
|
||||
StreamID: p.streamID,
|
||||
Raw: qlog.RawInfo{Length: r.NumRead, PayloadLength: int(l)},
|
||||
Frame: qlog.Frame{Frame: qlog.CancelPushFrame{}},
|
||||
})
|
||||
}
|
||||
case 0x5: // unsupported: PUSH_PROMISE
|
||||
if qlogger != nil {
|
||||
qlogger.RecordEvent(qlog.FrameParsed{
|
||||
StreamID: p.streamID,
|
||||
Raw: qlog.RawInfo{Length: r.NumRead, PayloadLength: int(l)},
|
||||
Frame: qlog.Frame{Frame: qlog.PushPromiseFrame{}},
|
||||
})
|
||||
}
|
||||
case 0x7: // GOAWAY
|
||||
return parseGoAwayFrame(r, l, p.streamID, qlogger)
|
||||
case 0xd: // unsupported: MAX_PUSH_ID
|
||||
if qlogger != nil {
|
||||
qlogger.RecordEvent(qlog.FrameParsed{
|
||||
StreamID: p.streamID,
|
||||
Raw: qlog.RawInfo{Length: r.NumRead, PayloadLength: int(l)},
|
||||
Frame: qlog.Frame{Frame: qlog.MaxPushIDFrame{}},
|
||||
})
|
||||
}
|
||||
case 0x2, 0x6, 0x8, 0x9: // reserved frame types
|
||||
if qlogger != nil {
|
||||
qlogger.RecordEvent(qlog.FrameParsed{
|
||||
StreamID: p.streamID,
|
||||
Raw: qlog.RawInfo{Length: r.NumRead + int(l), PayloadLength: int(l)},
|
||||
Frame: qlog.Frame{Frame: qlog.ReservedFrame{Type: t}},
|
||||
})
|
||||
}
|
||||
p.closeConn(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "")
|
||||
return nil, fmt.Errorf("http3: reserved frame type: %d", t)
|
||||
default:
|
||||
// unknown frame types
|
||||
if qlogger != nil {
|
||||
qlogger.RecordEvent(qlog.FrameParsed{
|
||||
StreamID: p.streamID,
|
||||
Raw: qlog.RawInfo{Length: r.NumRead, PayloadLength: int(l)},
|
||||
Frame: qlog.Frame{Frame: qlog.UnknownFrame{Type: t}},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// skip over the payload
|
||||
if _, err := io.CopyN(io.Discard, r, int64(l)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
type dataFrame struct {
|
||||
Length uint64
|
||||
}
|
||||
|
||||
func (f *dataFrame) Append(b []byte) []byte {
|
||||
b = quicvarint.Append(b, 0x0)
|
||||
return quicvarint.Append(b, f.Length)
|
||||
}
|
||||
|
||||
type headersFrame struct {
|
||||
Length uint64
|
||||
headerLen int // number of bytes read for type and length field
|
||||
}
|
||||
|
||||
func (f *headersFrame) Append(b []byte) []byte {
|
||||
b = quicvarint.Append(b, 0x1)
|
||||
return quicvarint.Append(b, f.Length)
|
||||
}
|
||||
|
||||
const (
|
||||
// SETTINGS_MAX_FIELD_SECTION_SIZE
|
||||
settingMaxFieldSectionSize = 0x6
|
||||
// Extended CONNECT, RFC 9220
|
||||
settingExtendedConnect = 0x8
|
||||
// HTTP Datagrams, RFC 9297
|
||||
settingDatagram = 0x33
|
||||
)
|
||||
|
||||
type settingsFrame struct {
|
||||
MaxFieldSectionSize int64 // SETTINGS_MAX_FIELD_SECTION_SIZE, -1 if not set
|
||||
|
||||
Datagram bool // HTTP Datagrams, RFC 9297
|
||||
ExtendedConnect bool // Extended CONNECT, RFC 9220
|
||||
Other map[uint64]uint64 // all settings that we don't explicitly recognize
|
||||
}
|
||||
|
||||
func pointer[T any](v T) *T {
|
||||
return &v
|
||||
}
|
||||
|
||||
func parseSettingsFrame(r *countingByteReader, l uint64, streamID quic.StreamID, qlogger qlogwriter.Recorder) (*settingsFrame, error) {
|
||||
if l > 8*(1<<10) {
|
||||
return nil, fmt.Errorf("unexpected size for SETTINGS frame: %d", l)
|
||||
}
|
||||
buf := make([]byte, l)
|
||||
if _, err := io.ReadFull(r, buf); err != nil {
|
||||
if err == io.ErrUnexpectedEOF {
|
||||
return nil, io.EOF
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
frame := &settingsFrame{MaxFieldSectionSize: -1}
|
||||
b := bytes.NewReader(buf)
|
||||
settingsFrame := qlog.SettingsFrame{MaxFieldSectionSize: -1}
|
||||
var readMaxFieldSectionSize, readDatagram, readExtendedConnect bool
|
||||
for b.Len() > 0 {
|
||||
id, err := quicvarint.Read(b)
|
||||
if err != nil { // should not happen. We allocated the whole frame already.
|
||||
return nil, err
|
||||
}
|
||||
val, err := quicvarint.Read(b)
|
||||
if err != nil { // should not happen. We allocated the whole frame already.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch id {
|
||||
case settingMaxFieldSectionSize:
|
||||
if readMaxFieldSectionSize {
|
||||
return nil, fmt.Errorf("duplicate setting: %d", id)
|
||||
}
|
||||
readMaxFieldSectionSize = true
|
||||
frame.MaxFieldSectionSize = int64(val)
|
||||
settingsFrame.MaxFieldSectionSize = int64(val)
|
||||
case settingExtendedConnect:
|
||||
if readExtendedConnect {
|
||||
return nil, fmt.Errorf("duplicate setting: %d", id)
|
||||
}
|
||||
readExtendedConnect = true
|
||||
if val != 0 && val != 1 {
|
||||
return nil, fmt.Errorf("invalid value for SETTINGS_ENABLE_CONNECT_PROTOCOL: %d", val)
|
||||
}
|
||||
frame.ExtendedConnect = val == 1
|
||||
if qlogger != nil {
|
||||
settingsFrame.ExtendedConnect = pointer(frame.ExtendedConnect)
|
||||
}
|
||||
case settingDatagram:
|
||||
if readDatagram {
|
||||
return nil, fmt.Errorf("duplicate setting: %d", id)
|
||||
}
|
||||
readDatagram = true
|
||||
if val != 0 && val != 1 {
|
||||
return nil, fmt.Errorf("invalid value for SETTINGS_H3_DATAGRAM: %d", val)
|
||||
}
|
||||
frame.Datagram = val == 1
|
||||
if qlogger != nil {
|
||||
settingsFrame.Datagram = pointer(frame.Datagram)
|
||||
}
|
||||
default:
|
||||
if _, ok := frame.Other[id]; ok {
|
||||
return nil, fmt.Errorf("duplicate setting: %d", id)
|
||||
}
|
||||
if frame.Other == nil {
|
||||
frame.Other = make(map[uint64]uint64)
|
||||
}
|
||||
frame.Other[id] = val
|
||||
}
|
||||
}
|
||||
if qlogger != nil {
|
||||
settingsFrame.Other = maps.Clone(frame.Other)
|
||||
|
||||
qlogger.RecordEvent(qlog.FrameParsed{
|
||||
StreamID: streamID,
|
||||
Raw: qlog.RawInfo{
|
||||
Length: r.NumRead,
|
||||
PayloadLength: int(l),
|
||||
},
|
||||
Frame: qlog.Frame{Frame: settingsFrame},
|
||||
})
|
||||
}
|
||||
return frame, nil
|
||||
}
|
||||
|
||||
func (f *settingsFrame) Append(b []byte) []byte {
|
||||
b = quicvarint.Append(b, 0x4)
|
||||
var l int
|
||||
if f.MaxFieldSectionSize >= 0 {
|
||||
l += quicvarint.Len(settingMaxFieldSectionSize) + quicvarint.Len(uint64(f.MaxFieldSectionSize))
|
||||
}
|
||||
for id, val := range f.Other {
|
||||
l += quicvarint.Len(id) + quicvarint.Len(val)
|
||||
}
|
||||
if f.Datagram {
|
||||
l += quicvarint.Len(settingDatagram) + quicvarint.Len(1)
|
||||
}
|
||||
if f.ExtendedConnect {
|
||||
l += quicvarint.Len(settingExtendedConnect) + quicvarint.Len(1)
|
||||
}
|
||||
b = quicvarint.Append(b, uint64(l))
|
||||
if f.MaxFieldSectionSize >= 0 {
|
||||
b = quicvarint.Append(b, settingMaxFieldSectionSize)
|
||||
b = quicvarint.Append(b, uint64(f.MaxFieldSectionSize))
|
||||
}
|
||||
if f.Datagram {
|
||||
b = quicvarint.Append(b, settingDatagram)
|
||||
b = quicvarint.Append(b, 1)
|
||||
}
|
||||
if f.ExtendedConnect {
|
||||
b = quicvarint.Append(b, settingExtendedConnect)
|
||||
b = quicvarint.Append(b, 1)
|
||||
}
|
||||
for id, val := range f.Other {
|
||||
b = quicvarint.Append(b, id)
|
||||
b = quicvarint.Append(b, val)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
type goAwayFrame struct {
|
||||
StreamID quic.StreamID
|
||||
}
|
||||
|
||||
func parseGoAwayFrame(r *countingByteReader, l uint64, streamID quic.StreamID, qlogger qlogwriter.Recorder) (*goAwayFrame, error) {
|
||||
frame := &goAwayFrame{}
|
||||
startLen := r.NumRead
|
||||
id, err := quicvarint.Read(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if r.NumRead-startLen != int(l) {
|
||||
return nil, errors.New("GOAWAY frame: inconsistent length")
|
||||
}
|
||||
frame.StreamID = quic.StreamID(id)
|
||||
if qlogger != nil {
|
||||
qlogger.RecordEvent(qlog.FrameParsed{
|
||||
StreamID: streamID,
|
||||
Raw: qlog.RawInfo{Length: r.NumRead, PayloadLength: int(l)},
|
||||
Frame: qlog.Frame{Frame: qlog.GoAwayFrame{StreamID: frame.StreamID}},
|
||||
})
|
||||
}
|
||||
return frame, nil
|
||||
}
|
||||
|
||||
func (f *goAwayFrame) Append(b []byte) []byte {
|
||||
b = quicvarint.Append(b, 0x7)
|
||||
b = quicvarint.Append(b, uint64(quicvarint.Len(uint64(f.StreamID))))
|
||||
return quicvarint.Append(b, uint64(f.StreamID))
|
||||
}
|
||||
39
vendor/github.com/quic-go/quic-go/http3/gzip_reader.go
generated
vendored
Normal file
39
vendor/github.com/quic-go/quic-go/http3/gzip_reader.go
generated
vendored
Normal file
@@ -0,0 +1,39 @@
|
||||
package http3
|
||||
|
||||
// copied from net/transport.go
|
||||
|
||||
// gzipReader wraps a response body so it can lazily
|
||||
// call gzip.NewReader on the first call to Read
|
||||
import (
|
||||
"compress/gzip"
|
||||
"io"
|
||||
)
|
||||
|
||||
// call gzip.NewReader on the first call to Read
|
||||
type gzipReader struct {
|
||||
body io.ReadCloser // underlying Response.Body
|
||||
zr *gzip.Reader // lazily-initialized gzip reader
|
||||
zerr error // sticky error
|
||||
}
|
||||
|
||||
func newGzipReader(body io.ReadCloser) io.ReadCloser {
|
||||
return &gzipReader{body: body}
|
||||
}
|
||||
|
||||
func (gz *gzipReader) Read(p []byte) (n int, err error) {
|
||||
if gz.zerr != nil {
|
||||
return 0, gz.zerr
|
||||
}
|
||||
if gz.zr == nil {
|
||||
gz.zr, err = gzip.NewReader(gz.body)
|
||||
if err != nil {
|
||||
gz.zerr = err
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return gz.zr.Read(p)
|
||||
}
|
||||
|
||||
func (gz *gzipReader) Close() error {
|
||||
return gz.body.Close()
|
||||
}
|
||||
380
vendor/github.com/quic-go/quic-go/http3/headers.go
generated
vendored
Normal file
380
vendor/github.com/quic-go/quic-go/http3/headers.go
generated
vendored
Normal file
@@ -0,0 +1,380 @@
|
||||
package http3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/net/http/httpguts"
|
||||
|
||||
"github.com/quic-go/qpack"
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/http3/qlog"
|
||||
"github.com/quic-go/quic-go/qlogwriter"
|
||||
)
|
||||
|
||||
type qpackError struct{ err error }
|
||||
|
||||
func (e *qpackError) Error() string { return fmt.Sprintf("qpack: %v", e.err) }
|
||||
func (e *qpackError) Unwrap() error { return e.err }
|
||||
|
||||
var errHeaderTooLarge = errors.New("http3: headers too large")
|
||||
|
||||
type header struct {
|
||||
// Pseudo header fields defined in RFC 9114
|
||||
Path string
|
||||
Method string
|
||||
Authority string
|
||||
Scheme string
|
||||
Status string
|
||||
// for Extended connect
|
||||
Protocol string
|
||||
// parsed and deduplicated. -1 if no Content-Length header is sent
|
||||
ContentLength int64
|
||||
// all non-pseudo headers
|
||||
Headers http.Header
|
||||
}
|
||||
|
||||
// connection-specific header fields must not be sent on HTTP/3
|
||||
var invalidHeaderFields = [...]string{
|
||||
"connection",
|
||||
"keep-alive",
|
||||
"proxy-connection",
|
||||
"transfer-encoding",
|
||||
"upgrade",
|
||||
}
|
||||
|
||||
func parseHeaders(decodeFn qpack.DecodeFunc, isRequest bool, sizeLimit int, headerFields *[]qpack.HeaderField) (header, error) {
|
||||
hdr := header{Headers: make(http.Header)}
|
||||
var readFirstRegularHeader, readContentLength bool
|
||||
var contentLengthStr string
|
||||
for {
|
||||
h, err := decodeFn()
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return header{}, &qpackError{err}
|
||||
}
|
||||
if headerFields != nil {
|
||||
*headerFields = append(*headerFields, h)
|
||||
}
|
||||
// RFC 9114, section 4.2.2:
|
||||
// The size of a field list is calculated based on the uncompressed size of fields,
|
||||
// including the length of the name and value in bytes plus an overhead of 32 bytes for each field.
|
||||
sizeLimit -= len(h.Name) + len(h.Value) + 32
|
||||
if sizeLimit < 0 {
|
||||
return header{}, errHeaderTooLarge
|
||||
}
|
||||
// field names need to be lowercase, see section 4.2 of RFC 9114
|
||||
if strings.ToLower(h.Name) != h.Name {
|
||||
return header{}, fmt.Errorf("header field is not lower-case: %s", h.Name)
|
||||
}
|
||||
if !httpguts.ValidHeaderFieldValue(h.Value) {
|
||||
return header{}, fmt.Errorf("invalid header field value for %s: %q", h.Name, h.Value)
|
||||
}
|
||||
if h.IsPseudo() {
|
||||
if readFirstRegularHeader {
|
||||
// all pseudo headers must appear before regular header fields, see section 4.3 of RFC 9114
|
||||
return header{}, fmt.Errorf("received pseudo header %s after a regular header field", h.Name)
|
||||
}
|
||||
var isResponsePseudoHeader bool // pseudo headers are either valid for requests or for responses
|
||||
var isDuplicatePseudoHeader bool // pseudo headers are allowed to appear exactly once
|
||||
switch h.Name {
|
||||
case ":path":
|
||||
isDuplicatePseudoHeader = hdr.Path != ""
|
||||
hdr.Path = h.Value
|
||||
case ":method":
|
||||
isDuplicatePseudoHeader = hdr.Method != ""
|
||||
hdr.Method = h.Value
|
||||
case ":authority":
|
||||
isDuplicatePseudoHeader = hdr.Authority != ""
|
||||
hdr.Authority = h.Value
|
||||
case ":protocol":
|
||||
isDuplicatePseudoHeader = hdr.Protocol != ""
|
||||
hdr.Protocol = h.Value
|
||||
case ":scheme":
|
||||
isDuplicatePseudoHeader = hdr.Scheme != ""
|
||||
hdr.Scheme = h.Value
|
||||
case ":status":
|
||||
isDuplicatePseudoHeader = hdr.Status != ""
|
||||
hdr.Status = h.Value
|
||||
isResponsePseudoHeader = true
|
||||
default:
|
||||
return header{}, fmt.Errorf("unknown pseudo header: %s", h.Name)
|
||||
}
|
||||
if isDuplicatePseudoHeader {
|
||||
return header{}, fmt.Errorf("duplicate pseudo header: %s", h.Name)
|
||||
}
|
||||
if isRequest && isResponsePseudoHeader {
|
||||
return header{}, fmt.Errorf("invalid request pseudo header: %s", h.Name)
|
||||
}
|
||||
if !isRequest && !isResponsePseudoHeader {
|
||||
return header{}, fmt.Errorf("invalid response pseudo header: %s", h.Name)
|
||||
}
|
||||
} else {
|
||||
if !httpguts.ValidHeaderFieldName(h.Name) {
|
||||
return header{}, fmt.Errorf("invalid header field name: %q", h.Name)
|
||||
}
|
||||
for _, invalidField := range invalidHeaderFields {
|
||||
if h.Name == invalidField {
|
||||
return header{}, fmt.Errorf("invalid header field name: %q", h.Name)
|
||||
}
|
||||
}
|
||||
if h.Name == "te" && h.Value != "trailers" {
|
||||
return header{}, fmt.Errorf("invalid TE header field value: %q", h.Value)
|
||||
}
|
||||
readFirstRegularHeader = true
|
||||
switch h.Name {
|
||||
case "content-length":
|
||||
// Ignore duplicate Content-Length headers.
|
||||
// Fail if the duplicates differ.
|
||||
if !readContentLength {
|
||||
readContentLength = true
|
||||
contentLengthStr = h.Value
|
||||
} else if contentLengthStr != h.Value {
|
||||
return header{}, fmt.Errorf("contradicting content lengths (%s and %s)", contentLengthStr, h.Value)
|
||||
}
|
||||
default:
|
||||
hdr.Headers.Add(h.Name, h.Value)
|
||||
}
|
||||
}
|
||||
}
|
||||
hdr.ContentLength = -1
|
||||
if len(contentLengthStr) > 0 {
|
||||
// use ParseUint instead of ParseInt, so that parsing fails on negative values
|
||||
cl, err := strconv.ParseUint(contentLengthStr, 10, 63)
|
||||
if err != nil {
|
||||
return header{}, fmt.Errorf("invalid content length: %w", err)
|
||||
}
|
||||
hdr.Headers.Set("Content-Length", contentLengthStr)
|
||||
hdr.ContentLength = int64(cl)
|
||||
}
|
||||
return hdr, nil
|
||||
}
|
||||
|
||||
func parseTrailers(decodeFn qpack.DecodeFunc, headerFields *[]qpack.HeaderField) (http.Header, error) {
|
||||
h := make(http.Header)
|
||||
for {
|
||||
hf, err := decodeFn()
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return nil, &qpackError{err}
|
||||
}
|
||||
if headerFields != nil {
|
||||
*headerFields = append(*headerFields, hf)
|
||||
}
|
||||
if hf.IsPseudo() {
|
||||
return nil, fmt.Errorf("http3: received pseudo header in trailer: %s", hf.Name)
|
||||
}
|
||||
h.Add(hf.Name, hf.Value)
|
||||
}
|
||||
return h, nil
|
||||
}
|
||||
|
||||
func requestFromHeaders(decodeFn qpack.DecodeFunc, sizeLimit int, headerFields *[]qpack.HeaderField) (*http.Request, error) {
|
||||
hdr, err := parseHeaders(decodeFn, true, sizeLimit, headerFields)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// concatenate cookie headers, see https://tools.ietf.org/html/rfc6265#section-5.4
|
||||
if len(hdr.Headers["Cookie"]) > 0 {
|
||||
hdr.Headers.Set("Cookie", strings.Join(hdr.Headers["Cookie"], "; "))
|
||||
}
|
||||
|
||||
isConnect := hdr.Method == http.MethodConnect
|
||||
// Extended CONNECT, see https://datatracker.ietf.org/doc/html/rfc8441#section-4
|
||||
isExtendedConnected := isConnect && hdr.Protocol != ""
|
||||
if isExtendedConnected {
|
||||
if hdr.Scheme == "" || hdr.Path == "" || hdr.Authority == "" {
|
||||
return nil, errors.New("extended CONNECT: :scheme, :path and :authority must not be empty")
|
||||
}
|
||||
} else if isConnect {
|
||||
if hdr.Path != "" || hdr.Authority == "" { // normal CONNECT
|
||||
return nil, errors.New(":path must be empty and :authority must not be empty")
|
||||
}
|
||||
} else if len(hdr.Path) == 0 || len(hdr.Authority) == 0 || len(hdr.Method) == 0 {
|
||||
return nil, errors.New(":path, :authority and :method must not be empty")
|
||||
}
|
||||
|
||||
if !isExtendedConnected && len(hdr.Protocol) > 0 {
|
||||
return nil, errors.New(":protocol must be empty")
|
||||
}
|
||||
|
||||
var u *url.URL
|
||||
var requestURI string
|
||||
|
||||
protocol := "HTTP/3.0"
|
||||
|
||||
if isConnect {
|
||||
u = &url.URL{}
|
||||
if isExtendedConnected {
|
||||
u, err = url.ParseRequestURI(hdr.Path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
protocol = hdr.Protocol
|
||||
} else {
|
||||
u.Path = hdr.Path
|
||||
}
|
||||
u.Scheme = hdr.Scheme
|
||||
u.Host = hdr.Authority
|
||||
requestURI = hdr.Authority
|
||||
} else {
|
||||
u, err = url.ParseRequestURI(hdr.Path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid content length: %w", err)
|
||||
}
|
||||
requestURI = hdr.Path
|
||||
}
|
||||
|
||||
req := &http.Request{
|
||||
Method: hdr.Method,
|
||||
URL: u,
|
||||
Proto: protocol,
|
||||
ProtoMajor: 3,
|
||||
ProtoMinor: 0,
|
||||
Header: hdr.Headers,
|
||||
Body: nil,
|
||||
ContentLength: hdr.ContentLength,
|
||||
Host: hdr.Authority,
|
||||
RequestURI: requestURI,
|
||||
}
|
||||
req.Trailer = extractAnnouncedTrailers(req.Header)
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// updateResponseFromHeaders sets up http.Response as an HTTP/3 response,
|
||||
// using the decoded qpack header filed.
|
||||
// It is only called for the HTTP header (and not the HTTP trailer).
|
||||
// It takes an http.Response as an argument to allow the caller to set the trailer later on.
|
||||
func updateResponseFromHeaders(rsp *http.Response, decodeFn qpack.DecodeFunc, sizeLimit int, headerFields *[]qpack.HeaderField) error {
|
||||
hdr, err := parseHeaders(decodeFn, false, sizeLimit, headerFields)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if hdr.Status == "" {
|
||||
return errors.New("missing :status field")
|
||||
}
|
||||
rsp.Proto = "HTTP/3.0"
|
||||
rsp.ProtoMajor = 3
|
||||
rsp.Header = hdr.Headers
|
||||
rsp.Trailer = extractAnnouncedTrailers(rsp.Header)
|
||||
rsp.ContentLength = hdr.ContentLength
|
||||
|
||||
status, err := strconv.Atoi(hdr.Status)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid status code: %w", err)
|
||||
}
|
||||
rsp.StatusCode = status
|
||||
rsp.Status = hdr.Status + " " + http.StatusText(status)
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractAnnouncedTrailers extracts trailer keys from the "Trailer" header.
|
||||
// It returns a map with the announced keys set to nil values, and removes the "Trailer" header.
|
||||
// It handles both duplicate as well as comma-separated values for the Trailer header.
|
||||
// For example:
|
||||
//
|
||||
// Trailer: Trailer1, Trailer2
|
||||
// Trailer: Trailer3
|
||||
//
|
||||
// Will result in a map containing the keys "Trailer1", "Trailer2", "Trailer3" with nil values.
|
||||
func extractAnnouncedTrailers(header http.Header) http.Header {
|
||||
rawTrailers, ok := header["Trailer"]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
trailers := make(http.Header)
|
||||
for _, rawVal := range rawTrailers {
|
||||
for _, val := range strings.Split(rawVal, ",") {
|
||||
trailers[http.CanonicalHeaderKey(textproto.TrimString(val))] = nil
|
||||
}
|
||||
}
|
||||
delete(header, "Trailer")
|
||||
return trailers
|
||||
}
|
||||
|
||||
// writeTrailers encodes and writes HTTP trailers as a HEADERS frame.
|
||||
// It returns true if trailers were written, false if there were no trailers to write.
|
||||
func writeTrailers(wr io.Writer, trailers http.Header, streamID quic.StreamID, qlogger qlogwriter.Recorder) (bool, error) {
|
||||
var hasValues bool
|
||||
for k, vals := range trailers {
|
||||
if httpguts.ValidTrailerHeader(k) && len(vals) > 0 {
|
||||
hasValues = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasValues {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
enc := qpack.NewEncoder(&buf)
|
||||
var headerFields []qlog.HeaderField
|
||||
if qlogger != nil {
|
||||
headerFields = make([]qlog.HeaderField, 0, len(trailers))
|
||||
}
|
||||
|
||||
for k, vals := range trailers {
|
||||
if len(vals) == 0 {
|
||||
continue
|
||||
}
|
||||
if !httpguts.ValidTrailerHeader(k) {
|
||||
continue
|
||||
}
|
||||
lowercaseKey := strings.ToLower(k)
|
||||
for _, v := range vals {
|
||||
if err := enc.WriteField(qpack.HeaderField{Name: lowercaseKey, Value: v}); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if qlogger != nil {
|
||||
headerFields = append(headerFields, qlog.HeaderField{Name: lowercaseKey, Value: v})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
b := make([]byte, 0, frameHeaderLen+buf.Len())
|
||||
b = (&headersFrame{Length: uint64(buf.Len())}).Append(b)
|
||||
b = append(b, buf.Bytes()...)
|
||||
if qlogger != nil {
|
||||
qlogCreatedHeadersFrame(qlogger, streamID, len(b), buf.Len(), headerFields)
|
||||
}
|
||||
_, err := wr.Write(b)
|
||||
return true, err
|
||||
}
|
||||
|
||||
func decodeTrailers(r io.Reader, hf *headersFrame, maxHeaderBytes int, decoder *qpack.Decoder, qlogger qlogwriter.Recorder, streamID quic.StreamID) (http.Header, error) {
|
||||
if hf.Length > uint64(maxHeaderBytes) {
|
||||
maybeQlogInvalidHeadersFrame(qlogger, streamID, hf.Length)
|
||||
return nil, fmt.Errorf("http3: HEADERS frame too large: %d bytes (max: %d)", hf.Length, maxHeaderBytes)
|
||||
}
|
||||
|
||||
b := make([]byte, hf.Length)
|
||||
if _, err := io.ReadFull(r, b); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decodeFn := decoder.Decode(b)
|
||||
var fields []qpack.HeaderField
|
||||
if qlogger != nil {
|
||||
fields = make([]qpack.HeaderField, 0, 16)
|
||||
}
|
||||
trailers, err := parseTrailers(decodeFn, &fields)
|
||||
if err != nil {
|
||||
maybeQlogInvalidHeadersFrame(qlogger, streamID, hf.Length)
|
||||
return nil, err
|
||||
}
|
||||
if qlogger != nil {
|
||||
qlogParsedHeadersFrame(qlogger, streamID, hf, fields)
|
||||
}
|
||||
return trailers, nil
|
||||
}
|
||||
48
vendor/github.com/quic-go/quic-go/http3/ip_addr.go
generated
vendored
Normal file
48
vendor/github.com/quic-go/quic-go/http3/ip_addr.go
generated
vendored
Normal file
@@ -0,0 +1,48 @@
|
||||
package http3
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// An addrList represents a list of network endpoint addresses.
|
||||
// Copy from [net.addrList] and change type from [net.Addr] to [net.IPAddr]
|
||||
type addrList []net.IPAddr
|
||||
|
||||
// isIPv4 reports whether addr contains an IPv4 address.
|
||||
func isIPv4(addr net.IPAddr) bool {
|
||||
return addr.IP.To4() != nil
|
||||
}
|
||||
|
||||
// isNotIPv4 reports whether addr does not contain an IPv4 address.
|
||||
func isNotIPv4(addr net.IPAddr) bool { return !isIPv4(addr) }
|
||||
|
||||
// forResolve returns the most appropriate address in address for
|
||||
// a call to ResolveTCPAddr, ResolveUDPAddr, or ResolveIPAddr.
|
||||
// IPv4 is preferred, unless addr contains an IPv6 literal.
|
||||
func (addrs addrList) forResolve(network, addr string) net.IPAddr {
|
||||
var want6 bool
|
||||
switch network {
|
||||
case "ip":
|
||||
// IPv6 literal (addr does NOT contain a port)
|
||||
want6 = strings.ContainsRune(addr, ':')
|
||||
case "tcp", "udp":
|
||||
// IPv6 literal. (addr contains a port, so look for '[')
|
||||
want6 = strings.ContainsRune(addr, '[')
|
||||
}
|
||||
if want6 {
|
||||
return addrs.first(isNotIPv4)
|
||||
}
|
||||
return addrs.first(isIPv4)
|
||||
}
|
||||
|
||||
// first returns the first address which satisfies strategy, or if
|
||||
// none do, then the first address of any kind.
|
||||
func (addrs addrList) first(strategy func(net.IPAddr) bool) net.IPAddr {
|
||||
for _, addr := range addrs {
|
||||
if strategy(addr) {
|
||||
return addr
|
||||
}
|
||||
}
|
||||
return addrs[0]
|
||||
}
|
||||
11
vendor/github.com/quic-go/quic-go/http3/mockgen.go
generated
vendored
Normal file
11
vendor/github.com/quic-go/quic-go/http3/mockgen.go
generated
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
//go:build gomock || generate
|
||||
|
||||
package http3
|
||||
|
||||
//go:generate sh -c "go tool mockgen -typed -build_flags=\"-tags=gomock\" -mock_names=TestClientConnInterface=MockClientConn -package http3 -destination mock_clientconn_test.go github.com/quic-go/quic-go/http3 TestClientConnInterface"
|
||||
type TestClientConnInterface = clientConn
|
||||
|
||||
//go:generate sh -c "go tool mockgen -typed -build_flags=\"-tags=gomock\" -mock_names=DatagramStream=MockDatagramStream -package http3 -destination mock_datagram_stream_test.go github.com/quic-go/quic-go/http3 DatagramStream"
|
||||
type DatagramStream = datagramStream
|
||||
|
||||
//go:generate sh -c "go tool mockgen -typed -package http3 -destination mock_quic_listener_test.go github.com/quic-go/quic-go/http3 QUICListener"
|
||||
56
vendor/github.com/quic-go/quic-go/http3/qlog.go
generated
vendored
Normal file
56
vendor/github.com/quic-go/quic-go/http3/qlog.go
generated
vendored
Normal file
@@ -0,0 +1,56 @@
|
||||
package http3
|
||||
|
||||
import (
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/http3/qlog"
|
||||
"github.com/quic-go/quic-go/qlogwriter"
|
||||
|
||||
"github.com/quic-go/qpack"
|
||||
)
|
||||
|
||||
func maybeQlogInvalidHeadersFrame(qlogger qlogwriter.Recorder, streamID quic.StreamID, l uint64) {
|
||||
if qlogger != nil {
|
||||
qlogger.RecordEvent(qlog.FrameParsed{
|
||||
StreamID: streamID,
|
||||
Raw: qlog.RawInfo{PayloadLength: int(l)},
|
||||
Frame: qlog.Frame{Frame: qlog.HeadersFrame{}},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func qlogParsedHeadersFrame(qlogger qlogwriter.Recorder, streamID quic.StreamID, hf *headersFrame, hfs []qpack.HeaderField) {
|
||||
headerFields := make([]qlog.HeaderField, len(hfs))
|
||||
for i, hf := range hfs {
|
||||
headerFields[i] = qlog.HeaderField{
|
||||
Name: hf.Name,
|
||||
Value: hf.Value,
|
||||
}
|
||||
}
|
||||
qlogger.RecordEvent(qlog.FrameParsed{
|
||||
StreamID: streamID,
|
||||
Raw: qlog.RawInfo{
|
||||
Length: int(hf.Length) + hf.headerLen,
|
||||
PayloadLength: int(hf.Length),
|
||||
},
|
||||
Frame: qlog.Frame{Frame: qlog.HeadersFrame{
|
||||
HeaderFields: headerFields,
|
||||
}},
|
||||
})
|
||||
}
|
||||
|
||||
func qlogCreatedHeadersFrame(qlogger qlogwriter.Recorder, streamID quic.StreamID, length, payloadLength int, hfs []qlog.HeaderField) {
|
||||
headerFields := make([]qlog.HeaderField, len(hfs))
|
||||
for i, hf := range hfs {
|
||||
headerFields[i] = qlog.HeaderField{
|
||||
Name: hf.Name,
|
||||
Value: hf.Value,
|
||||
}
|
||||
}
|
||||
qlogger.RecordEvent(qlog.FrameCreated{
|
||||
StreamID: streamID,
|
||||
Raw: qlog.RawInfo{Length: length, PayloadLength: payloadLength},
|
||||
Frame: qlog.Frame{Frame: qlog.HeadersFrame{
|
||||
HeaderFields: headerFields,
|
||||
}},
|
||||
})
|
||||
}
|
||||
138
vendor/github.com/quic-go/quic-go/http3/qlog/event.go
generated
vendored
Normal file
138
vendor/github.com/quic-go/quic-go/http3/qlog/event.go
generated
vendored
Normal file
@@ -0,0 +1,138 @@
|
||||
package qlog
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/qlogwriter/jsontext"
|
||||
)
|
||||
|
||||
type encoderHelper struct {
|
||||
enc *jsontext.Encoder
|
||||
err error
|
||||
}
|
||||
|
||||
func (h *encoderHelper) WriteToken(t jsontext.Token) {
|
||||
if h.err != nil {
|
||||
return
|
||||
}
|
||||
h.err = h.enc.WriteToken(t)
|
||||
}
|
||||
|
||||
type RawInfo struct {
|
||||
Length int // full packet length, including header and AEAD authentication tag
|
||||
PayloadLength int // length of the packet payload, excluding AEAD tag
|
||||
}
|
||||
|
||||
func (i RawInfo) HasValues() bool {
|
||||
return i.Length != 0 || i.PayloadLength != 0
|
||||
}
|
||||
|
||||
func (i RawInfo) encode(enc *jsontext.Encoder) error {
|
||||
h := encoderHelper{enc: enc}
|
||||
h.WriteToken(jsontext.BeginObject)
|
||||
if i.Length != 0 {
|
||||
h.WriteToken(jsontext.String("length"))
|
||||
h.WriteToken(jsontext.Uint(uint64(i.Length)))
|
||||
}
|
||||
if i.PayloadLength != 0 {
|
||||
h.WriteToken(jsontext.String("payload_length"))
|
||||
h.WriteToken(jsontext.Uint(uint64(i.PayloadLength)))
|
||||
}
|
||||
h.WriteToken(jsontext.EndObject)
|
||||
return h.err
|
||||
}
|
||||
|
||||
type FrameParsed struct {
|
||||
StreamID quic.StreamID
|
||||
Raw RawInfo
|
||||
Frame Frame
|
||||
}
|
||||
|
||||
func (e FrameParsed) Name() string { return "http3:frame_parsed" }
|
||||
|
||||
func (e FrameParsed) Encode(enc *jsontext.Encoder, _ time.Time) error {
|
||||
h := encoderHelper{enc: enc}
|
||||
h.WriteToken(jsontext.BeginObject)
|
||||
h.WriteToken(jsontext.String("stream_id"))
|
||||
h.WriteToken(jsontext.Uint(uint64(e.StreamID)))
|
||||
if e.Raw.HasValues() {
|
||||
h.WriteToken(jsontext.String("raw"))
|
||||
if err := e.Raw.encode(enc); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
h.WriteToken(jsontext.String("frame"))
|
||||
if err := e.Frame.encode(enc); err != nil {
|
||||
return err
|
||||
}
|
||||
h.WriteToken(jsontext.EndObject)
|
||||
return h.err
|
||||
}
|
||||
|
||||
type FrameCreated struct {
|
||||
StreamID quic.StreamID
|
||||
Raw RawInfo
|
||||
Frame Frame
|
||||
}
|
||||
|
||||
func (e FrameCreated) Name() string { return "http3:frame_created" }
|
||||
|
||||
func (e FrameCreated) Encode(enc *jsontext.Encoder, _ time.Time) error {
|
||||
h := encoderHelper{enc: enc}
|
||||
h.WriteToken(jsontext.BeginObject)
|
||||
h.WriteToken(jsontext.String("stream_id"))
|
||||
h.WriteToken(jsontext.Uint(uint64(e.StreamID)))
|
||||
if e.Raw.HasValues() {
|
||||
h.WriteToken(jsontext.String("raw"))
|
||||
if err := e.Raw.encode(enc); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
h.WriteToken(jsontext.String("frame"))
|
||||
if err := e.Frame.encode(enc); err != nil {
|
||||
return err
|
||||
}
|
||||
h.WriteToken(jsontext.EndObject)
|
||||
return h.err
|
||||
}
|
||||
|
||||
type DatagramCreated struct {
|
||||
QuaterStreamID uint64
|
||||
Raw RawInfo
|
||||
}
|
||||
|
||||
func (e DatagramCreated) Name() string { return "http3:datagram_created" }
|
||||
|
||||
func (e DatagramCreated) Encode(enc *jsontext.Encoder, _ time.Time) error {
|
||||
h := encoderHelper{enc: enc}
|
||||
h.WriteToken(jsontext.BeginObject)
|
||||
h.WriteToken(jsontext.String("quater_stream_id"))
|
||||
h.WriteToken(jsontext.Uint(e.QuaterStreamID))
|
||||
h.WriteToken(jsontext.String("raw"))
|
||||
if err := e.Raw.encode(enc); err != nil {
|
||||
return err
|
||||
}
|
||||
h.WriteToken(jsontext.EndObject)
|
||||
return h.err
|
||||
}
|
||||
|
||||
type DatagramParsed struct {
|
||||
QuaterStreamID uint64
|
||||
Raw RawInfo
|
||||
}
|
||||
|
||||
func (e DatagramParsed) Name() string { return "http3:datagram_parsed" }
|
||||
|
||||
func (e DatagramParsed) Encode(enc *jsontext.Encoder, _ time.Time) error {
|
||||
h := encoderHelper{enc: enc}
|
||||
h.WriteToken(jsontext.BeginObject)
|
||||
h.WriteToken(jsontext.String("quater_stream_id"))
|
||||
h.WriteToken(jsontext.Uint(e.QuaterStreamID))
|
||||
h.WriteToken(jsontext.String("raw"))
|
||||
if err := e.Raw.encode(enc); err != nil {
|
||||
return err
|
||||
}
|
||||
h.WriteToken(jsontext.EndObject)
|
||||
return h.err
|
||||
}
|
||||
220
vendor/github.com/quic-go/quic-go/http3/qlog/frame.go
generated
vendored
Normal file
220
vendor/github.com/quic-go/quic-go/http3/qlog/frame.go
generated
vendored
Normal file
@@ -0,0 +1,220 @@
|
||||
package qlog
|
||||
|
||||
import (
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/qlogwriter/jsontext"
|
||||
)
|
||||
|
||||
// Frame represents an HTTP/3 frame.
|
||||
type Frame struct {
|
||||
Frame any
|
||||
}
|
||||
|
||||
func (f Frame) encode(enc *jsontext.Encoder) error {
|
||||
switch frame := f.Frame.(type) {
|
||||
case DataFrame:
|
||||
return frame.encode(enc)
|
||||
case HeadersFrame:
|
||||
return frame.encode(enc)
|
||||
case GoAwayFrame:
|
||||
return frame.encode(enc)
|
||||
case SettingsFrame:
|
||||
return frame.encode(enc)
|
||||
case PushPromiseFrame:
|
||||
return frame.encode(enc)
|
||||
case CancelPushFrame:
|
||||
return frame.encode(enc)
|
||||
case MaxPushIDFrame:
|
||||
return frame.encode(enc)
|
||||
case ReservedFrame:
|
||||
return frame.encode(enc)
|
||||
case UnknownFrame:
|
||||
return frame.encode(enc)
|
||||
}
|
||||
// This shouldn't happen if the code is correctly logging frames.
|
||||
// Write a null token to produce valid JSON.
|
||||
return enc.WriteToken(jsontext.Null)
|
||||
}
|
||||
|
||||
// A DataFrame is a DATA frame
|
||||
type DataFrame struct{}
|
||||
|
||||
func (f *DataFrame) encode(enc *jsontext.Encoder) error {
|
||||
h := encoderHelper{enc: enc}
|
||||
h.WriteToken(jsontext.BeginObject)
|
||||
h.WriteToken(jsontext.String("frame_type"))
|
||||
h.WriteToken(jsontext.String("data"))
|
||||
h.WriteToken(jsontext.EndObject)
|
||||
return h.err
|
||||
}
|
||||
|
||||
type HeaderField struct {
|
||||
Name string
|
||||
Value string
|
||||
}
|
||||
|
||||
// A HeadersFrame is a HEADERS frame
|
||||
type HeadersFrame struct {
|
||||
HeaderFields []HeaderField
|
||||
}
|
||||
|
||||
func (f *HeadersFrame) encode(enc *jsontext.Encoder) error {
|
||||
h := encoderHelper{enc: enc}
|
||||
h.WriteToken(jsontext.BeginObject)
|
||||
h.WriteToken(jsontext.String("frame_type"))
|
||||
h.WriteToken(jsontext.String("headers"))
|
||||
if len(f.HeaderFields) > 0 {
|
||||
h.WriteToken(jsontext.String("header_fields"))
|
||||
h.WriteToken(jsontext.BeginArray)
|
||||
for _, f := range f.HeaderFields {
|
||||
h.WriteToken(jsontext.BeginObject)
|
||||
h.WriteToken(jsontext.String("name"))
|
||||
h.WriteToken(jsontext.String(f.Name))
|
||||
h.WriteToken(jsontext.String("value"))
|
||||
h.WriteToken(jsontext.String(f.Value))
|
||||
h.WriteToken(jsontext.EndObject)
|
||||
}
|
||||
h.WriteToken(jsontext.EndArray)
|
||||
}
|
||||
h.WriteToken(jsontext.EndObject)
|
||||
return h.err
|
||||
}
|
||||
|
||||
// A GoAwayFrame is a GOAWAY frame
|
||||
type GoAwayFrame struct {
|
||||
StreamID quic.StreamID
|
||||
}
|
||||
|
||||
func (f *GoAwayFrame) encode(enc *jsontext.Encoder) error {
|
||||
h := encoderHelper{enc: enc}
|
||||
h.WriteToken(jsontext.BeginObject)
|
||||
h.WriteToken(jsontext.String("frame_type"))
|
||||
h.WriteToken(jsontext.String("goaway"))
|
||||
h.WriteToken(jsontext.String("id"))
|
||||
h.WriteToken(jsontext.Uint(uint64(f.StreamID)))
|
||||
h.WriteToken(jsontext.EndObject)
|
||||
return h.err
|
||||
}
|
||||
|
||||
type SettingsFrame struct {
|
||||
MaxFieldSectionSize int64
|
||||
Datagram *bool
|
||||
ExtendedConnect *bool
|
||||
Other map[uint64]uint64
|
||||
}
|
||||
|
||||
func (f *SettingsFrame) encode(enc *jsontext.Encoder) error {
|
||||
h := encoderHelper{enc: enc}
|
||||
h.WriteToken(jsontext.BeginObject)
|
||||
h.WriteToken(jsontext.String("frame_type"))
|
||||
h.WriteToken(jsontext.String("settings"))
|
||||
h.WriteToken(jsontext.String("settings"))
|
||||
h.WriteToken(jsontext.BeginArray)
|
||||
if f.MaxFieldSectionSize >= 0 {
|
||||
h.WriteToken(jsontext.BeginObject)
|
||||
h.WriteToken(jsontext.String("name"))
|
||||
h.WriteToken(jsontext.String("settings_max_field_section_size"))
|
||||
h.WriteToken(jsontext.String("value"))
|
||||
h.WriteToken(jsontext.Uint(uint64(f.MaxFieldSectionSize)))
|
||||
h.WriteToken(jsontext.EndObject)
|
||||
}
|
||||
if f.Datagram != nil {
|
||||
h.WriteToken(jsontext.BeginObject)
|
||||
h.WriteToken(jsontext.String("name"))
|
||||
h.WriteToken(jsontext.String("settings_h3_datagram"))
|
||||
h.WriteToken(jsontext.String("value"))
|
||||
h.WriteToken(jsontext.Bool(*f.Datagram))
|
||||
h.WriteToken(jsontext.EndObject)
|
||||
}
|
||||
if f.ExtendedConnect != nil {
|
||||
h.WriteToken(jsontext.BeginObject)
|
||||
h.WriteToken(jsontext.String("name"))
|
||||
h.WriteToken(jsontext.String("settings_enable_connect_protocol"))
|
||||
h.WriteToken(jsontext.String("value"))
|
||||
h.WriteToken(jsontext.Bool(*f.ExtendedConnect))
|
||||
h.WriteToken(jsontext.EndObject)
|
||||
}
|
||||
if len(f.Other) > 0 {
|
||||
for k, v := range f.Other {
|
||||
h.WriteToken(jsontext.BeginObject)
|
||||
h.WriteToken(jsontext.String("name"))
|
||||
h.WriteToken(jsontext.String("unknown"))
|
||||
h.WriteToken(jsontext.String("name_bytes"))
|
||||
h.WriteToken(jsontext.Uint(k))
|
||||
h.WriteToken(jsontext.String("value"))
|
||||
h.WriteToken(jsontext.Uint(v))
|
||||
h.WriteToken(jsontext.EndObject)
|
||||
}
|
||||
}
|
||||
h.WriteToken(jsontext.EndArray)
|
||||
h.WriteToken(jsontext.EndObject)
|
||||
return h.err
|
||||
}
|
||||
|
||||
// A PushPromiseFrame is a PUSH_PROMISE frame
|
||||
type PushPromiseFrame struct{}
|
||||
|
||||
func (f *PushPromiseFrame) encode(enc *jsontext.Encoder) error {
|
||||
h := encoderHelper{enc: enc}
|
||||
h.WriteToken(jsontext.BeginObject)
|
||||
h.WriteToken(jsontext.String("frame_type"))
|
||||
h.WriteToken(jsontext.String("push_promise"))
|
||||
h.WriteToken(jsontext.EndObject)
|
||||
return h.err
|
||||
}
|
||||
|
||||
// A CancelPushFrame is a CANCEL_PUSH frame
|
||||
type CancelPushFrame struct{}
|
||||
|
||||
func (f *CancelPushFrame) encode(enc *jsontext.Encoder) error {
|
||||
h := encoderHelper{enc: enc}
|
||||
h.WriteToken(jsontext.BeginObject)
|
||||
h.WriteToken(jsontext.String("frame_type"))
|
||||
h.WriteToken(jsontext.String("cancel_push"))
|
||||
h.WriteToken(jsontext.EndObject)
|
||||
return h.err
|
||||
}
|
||||
|
||||
// A MaxPushIDFrame is a MAX_PUSH_ID frame
|
||||
type MaxPushIDFrame struct{}
|
||||
|
||||
func (f *MaxPushIDFrame) encode(enc *jsontext.Encoder) error {
|
||||
h := encoderHelper{enc: enc}
|
||||
h.WriteToken(jsontext.BeginObject)
|
||||
h.WriteToken(jsontext.String("frame_type"))
|
||||
h.WriteToken(jsontext.String("max_push_id"))
|
||||
h.WriteToken(jsontext.EndObject)
|
||||
return h.err
|
||||
}
|
||||
|
||||
// A ReservedFrame is one of the reserved frame types
|
||||
type ReservedFrame struct {
|
||||
Type uint64
|
||||
}
|
||||
|
||||
func (f *ReservedFrame) encode(enc *jsontext.Encoder) error {
|
||||
h := encoderHelper{enc: enc}
|
||||
h.WriteToken(jsontext.BeginObject)
|
||||
h.WriteToken(jsontext.String("frame_type"))
|
||||
h.WriteToken(jsontext.String("reserved"))
|
||||
h.WriteToken(jsontext.String("frame_type_bytes"))
|
||||
h.WriteToken(jsontext.Uint(f.Type))
|
||||
h.WriteToken(jsontext.EndObject)
|
||||
return h.err
|
||||
}
|
||||
|
||||
// An UnknownFrame is an unknown frame type
|
||||
type UnknownFrame struct {
|
||||
Type uint64
|
||||
}
|
||||
|
||||
func (f *UnknownFrame) encode(enc *jsontext.Encoder) error {
|
||||
h := encoderHelper{enc: enc}
|
||||
h.WriteToken(jsontext.BeginObject)
|
||||
h.WriteToken(jsontext.String("frame_type"))
|
||||
h.WriteToken(jsontext.String("unknown"))
|
||||
h.WriteToken(jsontext.String("frame_type_bytes"))
|
||||
h.WriteToken(jsontext.Uint(f.Type))
|
||||
h.WriteToken(jsontext.EndObject)
|
||||
return h.err
|
||||
}
|
||||
15
vendor/github.com/quic-go/quic-go/http3/qlog/qlog_dir.go
generated
vendored
Normal file
15
vendor/github.com/quic-go/quic-go/http3/qlog/qlog_dir.go
generated
vendored
Normal file
@@ -0,0 +1,15 @@
|
||||
package qlog
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/qlog"
|
||||
"github.com/quic-go/quic-go/qlogwriter"
|
||||
)
|
||||
|
||||
const EventSchema = "urn:ietf:params:qlog:events:http3-12"
|
||||
|
||||
func DefaultConnectionTracer(ctx context.Context, isClient bool, connID quic.ConnectionID) qlogwriter.Trace {
|
||||
return qlog.DefaultConnectionTracerWithSchemas(ctx, isClient, connID, []string{qlog.EventSchema, EventSchema})
|
||||
}
|
||||
321
vendor/github.com/quic-go/quic-go/http3/request_writer.go
generated
vendored
Normal file
321
vendor/github.com/quic-go/quic-go/http3/request_writer.go
generated
vendored
Normal file
@@ -0,0 +1,321 @@
|
||||
package http3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/net/http/httpguts"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
"golang.org/x/net/idna"
|
||||
|
||||
"github.com/quic-go/qpack"
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/http3/qlog"
|
||||
"github.com/quic-go/quic-go/qlogwriter"
|
||||
)
|
||||
|
||||
const bodyCopyBufferSize = 8 * 1024
|
||||
|
||||
type requestWriter struct {
|
||||
mutex sync.Mutex
|
||||
encoder *qpack.Encoder
|
||||
headerBuf *bytes.Buffer
|
||||
}
|
||||
|
||||
func newRequestWriter() *requestWriter {
|
||||
headerBuf := &bytes.Buffer{}
|
||||
encoder := qpack.NewEncoder(headerBuf)
|
||||
return &requestWriter{
|
||||
encoder: encoder,
|
||||
headerBuf: headerBuf,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *requestWriter) WriteRequestHeader(wr io.Writer, req *http.Request, gzip bool, streamID quic.StreamID, qlogger qlogwriter.Recorder) error {
|
||||
buf := &bytes.Buffer{}
|
||||
if err := w.writeHeaders(buf, req, gzip, streamID, qlogger); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := wr.Write(buf.Bytes()); err != nil {
|
||||
return err
|
||||
}
|
||||
trace := httptrace.ContextClientTrace(req.Context())
|
||||
traceWroteHeaders(trace)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *requestWriter) writeHeaders(wr io.Writer, req *http.Request, gzip bool, streamID quic.StreamID, qlogger qlogwriter.Recorder) error {
|
||||
w.mutex.Lock()
|
||||
defer w.mutex.Unlock()
|
||||
defer w.encoder.Close()
|
||||
defer w.headerBuf.Reset()
|
||||
|
||||
var trailers string
|
||||
if len(req.Trailer) > 0 {
|
||||
keys := make([]string, 0, len(req.Trailer))
|
||||
for k := range req.Trailer {
|
||||
if httpguts.ValidTrailerHeader(k) {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
}
|
||||
trailers = strings.Join(keys, ", ")
|
||||
}
|
||||
|
||||
headerFields, err := w.encodeHeaders(req, gzip, trailers, actualContentLength(req), qlogger != nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
b := make([]byte, 0, 128)
|
||||
b = (&headersFrame{Length: uint64(w.headerBuf.Len())}).Append(b)
|
||||
if qlogger != nil {
|
||||
qlogCreatedHeadersFrame(qlogger, streamID, len(b)+w.headerBuf.Len(), w.headerBuf.Len(), headerFields)
|
||||
}
|
||||
if _, err := wr.Write(b); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = wr.Write(w.headerBuf.Bytes())
|
||||
return err
|
||||
}
|
||||
|
||||
func isExtendedConnectRequest(req *http.Request) bool {
|
||||
return req.Method == http.MethodConnect && req.Proto != "" && req.Proto != "HTTP/1.1"
|
||||
}
|
||||
|
||||
// copied from net/transport.go
|
||||
// Modified to support Extended CONNECT:
|
||||
// Contrary to what the godoc for the http.Request says,
|
||||
// we do respect the Proto field if the method is CONNECT.
|
||||
//
|
||||
// The returned header fields are only set if doQlog is true.
|
||||
func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64, doQlog bool) ([]qlog.HeaderField, error) {
|
||||
host := req.Host
|
||||
if host == "" {
|
||||
host = req.URL.Host
|
||||
}
|
||||
host, err := httpguts.PunycodeHostPort(host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !httpguts.ValidHostHeader(host) {
|
||||
return nil, errors.New("http3: invalid Host header")
|
||||
}
|
||||
|
||||
// http.NewRequest sets this field to HTTP/1.1
|
||||
isExtendedConnect := isExtendedConnectRequest(req)
|
||||
|
||||
var path string
|
||||
if req.Method != http.MethodConnect || isExtendedConnect {
|
||||
path = req.URL.RequestURI()
|
||||
if !validPseudoPath(path) {
|
||||
orig := path
|
||||
path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host)
|
||||
if !validPseudoPath(path) {
|
||||
if req.URL.Opaque != "" {
|
||||
return nil, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque)
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid request :path %q", orig)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for any invalid headers and return an error before we
|
||||
// potentially pollute our hpack state. (We want to be able to
|
||||
// continue to reuse the hpack encoder for future requests)
|
||||
for k, vv := range req.Header {
|
||||
if !httpguts.ValidHeaderFieldName(k) {
|
||||
return nil, fmt.Errorf("invalid HTTP header name %q", k)
|
||||
}
|
||||
for _, v := range vv {
|
||||
if !httpguts.ValidHeaderFieldValue(v) {
|
||||
return nil, fmt.Errorf("invalid HTTP header value %q for header %q", v, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enumerateHeaders := func(f func(name, value string)) {
|
||||
// 8.1.2.3 Request Pseudo-Header Fields
|
||||
// The :path pseudo-header field includes the path and query parts of the
|
||||
// target URI (the path-absolute production and optionally a '?' character
|
||||
// followed by the query production (see Sections 3.3 and 3.4 of
|
||||
// [RFC3986]).
|
||||
f(":authority", host)
|
||||
f(":method", req.Method)
|
||||
if req.Method != http.MethodConnect || isExtendedConnect {
|
||||
f(":path", path)
|
||||
f(":scheme", req.URL.Scheme)
|
||||
}
|
||||
if isExtendedConnect {
|
||||
f(":protocol", req.Proto)
|
||||
}
|
||||
if trailers != "" {
|
||||
f("trailer", trailers)
|
||||
}
|
||||
|
||||
var didUA bool
|
||||
for k, vv := range req.Header {
|
||||
if strings.EqualFold(k, "host") || strings.EqualFold(k, "content-length") {
|
||||
// Host is :authority, already sent.
|
||||
// Content-Length is automatic, set below.
|
||||
continue
|
||||
} else if strings.EqualFold(k, "connection") || strings.EqualFold(k, "proxy-connection") ||
|
||||
strings.EqualFold(k, "transfer-encoding") || strings.EqualFold(k, "upgrade") ||
|
||||
strings.EqualFold(k, "keep-alive") {
|
||||
// Per 8.1.2.2 Connection-Specific Header
|
||||
// Fields, don't send connection-specific
|
||||
// fields. We have already checked if any
|
||||
// are error-worthy so just ignore the rest.
|
||||
continue
|
||||
} else if strings.EqualFold(k, "user-agent") {
|
||||
// Match Go's http1 behavior: at most one
|
||||
// User-Agent. If set to nil or empty string,
|
||||
// then omit it. Otherwise if not mentioned,
|
||||
// include the default (below).
|
||||
didUA = true
|
||||
if len(vv) < 1 {
|
||||
continue
|
||||
}
|
||||
vv = vv[:1]
|
||||
if vv[0] == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
for _, v := range vv {
|
||||
f(k, v)
|
||||
}
|
||||
}
|
||||
if shouldSendReqContentLength(req.Method, contentLength) {
|
||||
f("content-length", strconv.FormatInt(contentLength, 10))
|
||||
}
|
||||
if addGzipHeader {
|
||||
f("accept-encoding", "gzip")
|
||||
}
|
||||
if !didUA {
|
||||
f("user-agent", defaultUserAgent)
|
||||
}
|
||||
}
|
||||
|
||||
// Do a first pass over the headers counting bytes to ensure
|
||||
// we don't exceed cc.peerMaxHeaderListSize. This is done as a
|
||||
// separate pass before encoding the headers to prevent
|
||||
// modifying the hpack state.
|
||||
hlSize := uint64(0)
|
||||
enumerateHeaders(func(name, value string) {
|
||||
hf := hpack.HeaderField{Name: name, Value: value}
|
||||
hlSize += uint64(hf.Size())
|
||||
})
|
||||
|
||||
// TODO: check maximum header list size
|
||||
// if hlSize > cc.peerMaxHeaderListSize {
|
||||
// return errRequestHeaderListSize
|
||||
// }
|
||||
|
||||
trace := httptrace.ContextClientTrace(req.Context())
|
||||
traceHeaders := traceHasWroteHeaderField(trace)
|
||||
|
||||
// Header list size is ok. Write the headers.
|
||||
var headerFields []qlog.HeaderField
|
||||
if doQlog {
|
||||
headerFields = make([]qlog.HeaderField, 0, len(req.Header))
|
||||
}
|
||||
enumerateHeaders(func(name, value string) {
|
||||
name = strings.ToLower(name)
|
||||
w.encoder.WriteField(qpack.HeaderField{Name: name, Value: value})
|
||||
if traceHeaders {
|
||||
traceWroteHeaderField(trace, name, value)
|
||||
}
|
||||
if doQlog {
|
||||
headerFields = append(headerFields, qlog.HeaderField{Name: name, Value: value})
|
||||
}
|
||||
})
|
||||
|
||||
return headerFields, nil
|
||||
}
|
||||
|
||||
// authorityAddr returns a given authority (a host/IP, or host:port / ip:port)
|
||||
// and returns a host:port. The port 443 is added if needed.
|
||||
func authorityAddr(authority string) (addr string) {
|
||||
host, port, err := net.SplitHostPort(authority)
|
||||
if err != nil { // authority didn't have a port
|
||||
port = "443"
|
||||
host = authority
|
||||
}
|
||||
if a, err := idna.ToASCII(host); err == nil {
|
||||
host = a
|
||||
}
|
||||
// IPv6 address literal, without a port:
|
||||
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
|
||||
return host + ":" + port
|
||||
}
|
||||
return net.JoinHostPort(host, port)
|
||||
}
|
||||
|
||||
// validPseudoPath reports whether v is a valid :path pseudo-header
|
||||
// value. It must be either:
|
||||
//
|
||||
// *) a non-empty string starting with '/'
|
||||
// *) the string '*', for OPTIONS requests.
|
||||
//
|
||||
// For now this is only used a quick check for deciding when to clean
|
||||
// up Opaque URLs before sending requests from the Transport.
|
||||
// See golang.org/issue/16847
|
||||
//
|
||||
// We used to enforce that the path also didn't start with "//", but
|
||||
// Google's GFE accepts such paths and Chrome sends them, so ignore
|
||||
// that part of the spec. See golang.org/issue/19103.
|
||||
func validPseudoPath(v string) bool {
|
||||
return (len(v) > 0 && v[0] == '/') || v == "*"
|
||||
}
|
||||
|
||||
// actualContentLength returns a sanitized version of
|
||||
// req.ContentLength, where 0 actually means zero (not unknown) and -1
|
||||
// means unknown.
|
||||
func actualContentLength(req *http.Request) int64 {
|
||||
if req.Body == nil {
|
||||
return 0
|
||||
}
|
||||
if req.ContentLength != 0 {
|
||||
return req.ContentLength
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// shouldSendReqContentLength reports whether the http2.Transport should send
|
||||
// a "content-length" request header. This logic is basically a copy of the net/http
|
||||
// transferWriter.shouldSendContentLength.
|
||||
// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown).
|
||||
// -1 means unknown.
|
||||
func shouldSendReqContentLength(method string, contentLength int64) bool {
|
||||
if contentLength > 0 {
|
||||
return true
|
||||
}
|
||||
if contentLength < 0 {
|
||||
return false
|
||||
}
|
||||
// For zero bodies, whether we send a content-length depends on the method.
|
||||
// It also kinda doesn't matter for http2 either way, with END_STREAM.
|
||||
switch method {
|
||||
case "POST", "PUT", "PATCH":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// WriteRequestTrailer writes HTTP trailers to the stream.
|
||||
// It should be called after the request body has been fully written.
|
||||
func (w *requestWriter) WriteRequestTrailer(wr io.Writer, req *http.Request, streamID quic.StreamID, qlogger qlogwriter.Recorder) error {
|
||||
_, err := writeTrailers(wr, req.Trailer, streamID, qlogger)
|
||||
return err
|
||||
}
|
||||
368
vendor/github.com/quic-go/quic-go/http3/response_writer.go
generated
vendored
Normal file
368
vendor/github.com/quic-go/quic-go/http3/response_writer.go
generated
vendored
Normal file
@@ -0,0 +1,368 @@
|
||||
package http3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/qpack"
|
||||
"github.com/quic-go/quic-go/http3/qlog"
|
||||
|
||||
"golang.org/x/net/http/httpguts"
|
||||
)
|
||||
|
||||
// The HTTPStreamer allows taking over a HTTP/3 stream. The interface is implemented by the http.ResponseWriter.
|
||||
// When a stream is taken over, it's the caller's responsibility to close the stream.
|
||||
type HTTPStreamer interface {
|
||||
HTTPStream() *Stream
|
||||
}
|
||||
|
||||
const maxSmallResponseSize = 4096
|
||||
|
||||
type responseWriter struct {
|
||||
str *Stream
|
||||
|
||||
conn *rawConn
|
||||
header http.Header
|
||||
trailers map[string]struct{}
|
||||
buf []byte
|
||||
status int // status code passed to WriteHeader
|
||||
|
||||
// for responses smaller than maxSmallResponseSize, we buffer calls to Write,
|
||||
// and automatically add the Content-Length header
|
||||
smallResponseBuf []byte
|
||||
|
||||
contentLen int64 // if handler set valid Content-Length header
|
||||
numWritten int64 // bytes written
|
||||
headerComplete bool // set once WriteHeader is called with a status code >= 200
|
||||
headerWritten bool // set once the response header has been serialized to the stream
|
||||
isHead bool
|
||||
trailerWritten bool // set once the response trailers has been serialized to the stream
|
||||
|
||||
hijacked bool // set on HTTPStream is called
|
||||
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
var (
|
||||
_ http.ResponseWriter = &responseWriter{}
|
||||
_ http.Flusher = &responseWriter{}
|
||||
_ Settingser = &responseWriter{}
|
||||
_ HTTPStreamer = &responseWriter{}
|
||||
// make sure that we implement (some of the) methods used by the http.ResponseController
|
||||
_ interface {
|
||||
SetReadDeadline(time.Time) error
|
||||
SetWriteDeadline(time.Time) error
|
||||
Flush()
|
||||
FlushError() error
|
||||
} = &responseWriter{}
|
||||
)
|
||||
|
||||
func newResponseWriter(str *Stream, conn *rawConn, isHead bool, logger *slog.Logger) *responseWriter {
|
||||
return &responseWriter{
|
||||
str: str,
|
||||
conn: conn,
|
||||
header: http.Header{},
|
||||
buf: make([]byte, frameHeaderLen),
|
||||
isHead: isHead,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *responseWriter) Header() http.Header {
|
||||
return w.header
|
||||
}
|
||||
|
||||
func (w *responseWriter) WriteHeader(status int) {
|
||||
if w.headerComplete {
|
||||
return
|
||||
}
|
||||
|
||||
// http status must be 3 digits
|
||||
if status < 100 || status > 999 {
|
||||
panic(fmt.Sprintf("invalid WriteHeader code %v", status))
|
||||
}
|
||||
w.status = status
|
||||
|
||||
// immediately write 1xx headers
|
||||
if status < 200 {
|
||||
w.writeHeader(status)
|
||||
return
|
||||
}
|
||||
|
||||
// We're done with headers once we write a status >= 200.
|
||||
w.headerComplete = true
|
||||
// Add Date header.
|
||||
// This is what the standard library does.
|
||||
// Can be disabled by setting the Date header to nil.
|
||||
if _, ok := w.header["Date"]; !ok {
|
||||
w.header.Set("Date", time.Now().UTC().Format(http.TimeFormat))
|
||||
}
|
||||
// Content-Length checking
|
||||
// use ParseUint instead of ParseInt, as negative values are invalid
|
||||
if clen := w.header.Get("Content-Length"); clen != "" {
|
||||
if cl, err := strconv.ParseUint(clen, 10, 63); err == nil {
|
||||
w.contentLen = int64(cl)
|
||||
} else {
|
||||
// emit a warning for malformed Content-Length and remove it
|
||||
logger := w.logger
|
||||
if logger == nil {
|
||||
logger = slog.Default()
|
||||
}
|
||||
logger.Error("Malformed Content-Length", "value", clen)
|
||||
w.header.Del("Content-Length")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *responseWriter) sniffContentType(p []byte) {
|
||||
// If no content type, apply sniffing algorithm to body.
|
||||
// We can't use `w.header.Get` here since if the Content-Type was set to nil, we shouldn't do sniffing.
|
||||
_, haveType := w.header["Content-Type"]
|
||||
|
||||
// If the Content-Encoding was set and is non-blank, we shouldn't sniff the body.
|
||||
hasCE := w.header.Get("Content-Encoding") != ""
|
||||
if !hasCE && !haveType && len(p) > 0 {
|
||||
w.header.Set("Content-Type", http.DetectContentType(p))
|
||||
}
|
||||
}
|
||||
|
||||
func (w *responseWriter) Write(p []byte) (int, error) {
|
||||
bodyAllowed := bodyAllowedForStatus(w.status)
|
||||
if !w.headerComplete {
|
||||
w.sniffContentType(p)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
bodyAllowed = true
|
||||
}
|
||||
if !bodyAllowed {
|
||||
return 0, http.ErrBodyNotAllowed
|
||||
}
|
||||
|
||||
w.numWritten += int64(len(p))
|
||||
if w.contentLen != 0 && w.numWritten > w.contentLen {
|
||||
return 0, http.ErrContentLength
|
||||
}
|
||||
|
||||
if w.isHead {
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
if !w.headerWritten {
|
||||
// Buffer small responses.
|
||||
// This allows us to automatically set the Content-Length field.
|
||||
if len(w.smallResponseBuf)+len(p) < maxSmallResponseSize {
|
||||
w.smallResponseBuf = append(w.smallResponseBuf, p...)
|
||||
return len(p), nil
|
||||
}
|
||||
}
|
||||
return w.doWrite(p)
|
||||
}
|
||||
|
||||
func (w *responseWriter) doWrite(p []byte) (int, error) {
|
||||
if !w.headerWritten {
|
||||
w.sniffContentType(w.smallResponseBuf)
|
||||
if err := w.writeHeader(w.status); err != nil {
|
||||
return 0, maybeReplaceError(err)
|
||||
}
|
||||
w.headerWritten = true
|
||||
}
|
||||
|
||||
l := uint64(len(w.smallResponseBuf) + len(p))
|
||||
if l == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
df := &dataFrame{Length: l}
|
||||
w.buf = w.buf[:0]
|
||||
w.buf = df.Append(w.buf)
|
||||
if w.str.qlogger != nil {
|
||||
w.str.qlogger.RecordEvent(qlog.FrameCreated{
|
||||
StreamID: w.str.StreamID(),
|
||||
Raw: qlog.RawInfo{Length: len(w.buf) + int(l), PayloadLength: int(l)},
|
||||
Frame: qlog.Frame{Frame: qlog.DataFrame{}},
|
||||
})
|
||||
}
|
||||
if _, err := w.str.writeUnframed(w.buf); err != nil {
|
||||
return 0, maybeReplaceError(err)
|
||||
}
|
||||
if len(w.smallResponseBuf) > 0 {
|
||||
if _, err := w.str.writeUnframed(w.smallResponseBuf); err != nil {
|
||||
return 0, maybeReplaceError(err)
|
||||
}
|
||||
w.smallResponseBuf = nil
|
||||
}
|
||||
var n int
|
||||
if len(p) > 0 {
|
||||
var err error
|
||||
n, err = w.str.writeUnframed(p)
|
||||
if err != nil {
|
||||
return n, maybeReplaceError(err)
|
||||
}
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (w *responseWriter) writeHeader(status int) error {
|
||||
var headerFields []qlog.HeaderField // only used for qlog
|
||||
var headers bytes.Buffer
|
||||
enc := qpack.NewEncoder(&headers)
|
||||
if err := enc.WriteField(qpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)}); err != nil {
|
||||
return err
|
||||
}
|
||||
if w.str.qlogger != nil {
|
||||
headerFields = append(headerFields, qlog.HeaderField{Name: ":status", Value: strconv.Itoa(status)})
|
||||
}
|
||||
|
||||
// Handle trailer fields
|
||||
if vals, ok := w.header["Trailer"]; ok {
|
||||
for _, val := range vals {
|
||||
for _, trailer := range strings.Split(val, ",") {
|
||||
// We need to convert to the canonical header key value here because this will be called when using
|
||||
// headers.Add or headers.Set.
|
||||
trailer = textproto.CanonicalMIMEHeaderKey(strings.TrimSpace(trailer))
|
||||
w.declareTrailer(trailer)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for k, v := range w.header {
|
||||
if _, excluded := w.trailers[k]; excluded {
|
||||
continue
|
||||
}
|
||||
// Ignore "Trailer:" prefixed headers
|
||||
if strings.HasPrefix(k, http.TrailerPrefix) {
|
||||
continue
|
||||
}
|
||||
for index := range v {
|
||||
name := strings.ToLower(k)
|
||||
value := v[index]
|
||||
if err := enc.WriteField(qpack.HeaderField{Name: name, Value: value}); err != nil {
|
||||
return err
|
||||
}
|
||||
if w.str.qlogger != nil {
|
||||
headerFields = append(headerFields, qlog.HeaderField{Name: name, Value: value})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
buf := make([]byte, 0, frameHeaderLen+headers.Len())
|
||||
buf = (&headersFrame{Length: uint64(headers.Len())}).Append(buf)
|
||||
buf = append(buf, headers.Bytes()...)
|
||||
|
||||
if w.str.qlogger != nil {
|
||||
qlogCreatedHeadersFrame(w.str.qlogger, w.str.StreamID(), len(buf), headers.Len(), headerFields)
|
||||
}
|
||||
|
||||
_, err := w.str.writeUnframed(buf)
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *responseWriter) FlushError() error {
|
||||
if !w.headerComplete {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
_, err := w.doWrite(nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *responseWriter) flushTrailers() {
|
||||
if w.trailerWritten {
|
||||
return
|
||||
}
|
||||
if err := w.writeTrailers(); err != nil {
|
||||
w.logger.Debug("could not write trailers", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *responseWriter) Flush() {
|
||||
if err := w.FlushError(); err != nil {
|
||||
if w.logger != nil {
|
||||
w.logger.Debug("could not flush to stream", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// declareTrailer adds a trailer to the trailer list, while also validating that the trailer has a
|
||||
// valid name.
|
||||
func (w *responseWriter) declareTrailer(k string) {
|
||||
if !httpguts.ValidTrailerHeader(k) {
|
||||
// Forbidden by RFC 9110, section 6.5.1.
|
||||
w.logger.Debug("ignoring invalid trailer", slog.String("header", k))
|
||||
return
|
||||
}
|
||||
if w.trailers == nil {
|
||||
w.trailers = make(map[string]struct{})
|
||||
}
|
||||
w.trailers[k] = struct{}{}
|
||||
}
|
||||
|
||||
// writeTrailers will write trailers to the stream if there are any.
|
||||
func (w *responseWriter) writeTrailers() error {
|
||||
// promote headers added via "Trailer:" convention as trailers, these can be added after
|
||||
// streaming the status/headers have been written.
|
||||
for k := range w.header {
|
||||
if strings.HasPrefix(k, http.TrailerPrefix) {
|
||||
w.declareTrailer(k)
|
||||
}
|
||||
}
|
||||
|
||||
if len(w.trailers) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
trailers := make(http.Header, len(w.trailers))
|
||||
for trailer := range w.trailers {
|
||||
if vals, ok := w.header[trailer]; ok {
|
||||
trailers[strings.TrimPrefix(trailer, http.TrailerPrefix)] = vals
|
||||
}
|
||||
}
|
||||
|
||||
written, err := writeTrailers(w.str.datagramStream, trailers, w.str.StreamID(), w.str.qlogger)
|
||||
if written {
|
||||
w.trailerWritten = true
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *responseWriter) HTTPStream() *Stream {
|
||||
w.hijacked = true
|
||||
w.Flush()
|
||||
return w.str
|
||||
}
|
||||
|
||||
func (w *responseWriter) wasStreamHijacked() bool { return w.hijacked }
|
||||
|
||||
func (w *responseWriter) ReceivedSettings() <-chan struct{} {
|
||||
return w.conn.ReceivedSettings()
|
||||
}
|
||||
|
||||
func (w *responseWriter) Settings() *Settings {
|
||||
return w.conn.Settings()
|
||||
}
|
||||
|
||||
func (w *responseWriter) SetReadDeadline(deadline time.Time) error {
|
||||
return w.str.SetReadDeadline(deadline)
|
||||
}
|
||||
|
||||
func (w *responseWriter) SetWriteDeadline(deadline time.Time) error {
|
||||
return w.str.SetWriteDeadline(deadline)
|
||||
}
|
||||
|
||||
// copied from http2/http2.go
|
||||
// bodyAllowedForStatus reports whether a given response status code
|
||||
// permits a body. See RFC 2616, section 4.4.
|
||||
func bodyAllowedForStatus(status int) bool {
|
||||
switch {
|
||||
case status >= 100 && status <= 199:
|
||||
return false
|
||||
case status == http.StatusNoContent:
|
||||
return false
|
||||
case status == http.StatusNotModified:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
745
vendor/github.com/quic-go/quic-go/http3/server.go
generated
vendored
Normal file
745
vendor/github.com/quic-go/quic-go/http3/server.go
generated
vendored
Normal file
@@ -0,0 +1,745 @@
|
||||
package http3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/http3/qlog"
|
||||
"github.com/quic-go/quic-go/qlogwriter"
|
||||
)
|
||||
|
||||
// NextProtoH3 is the ALPN protocol negotiated during the TLS handshake, for QUIC v1 and v2.
|
||||
const NextProtoH3 = "h3"
|
||||
|
||||
// StreamType is the stream type of a unidirectional stream.
|
||||
type StreamType uint64
|
||||
|
||||
const (
|
||||
streamTypeControlStream = 0
|
||||
streamTypePushStream = 1
|
||||
streamTypeQPACKEncoderStream = 2
|
||||
streamTypeQPACKDecoderStream = 3
|
||||
)
|
||||
|
||||
// A QUICListener listens for incoming QUIC connections.
|
||||
type QUICListener interface {
|
||||
Accept(context.Context) (*quic.Conn, error)
|
||||
Addr() net.Addr
|
||||
io.Closer
|
||||
}
|
||||
|
||||
var _ QUICListener = &quic.EarlyListener{}
|
||||
|
||||
// ConfigureTLSConfig creates a new tls.Config which can be used
|
||||
// to create a quic.Listener meant for serving HTTP/3.
|
||||
func ConfigureTLSConfig(tlsConf *tls.Config) *tls.Config {
|
||||
// Workaround for https://github.com/golang/go/issues/60506.
|
||||
// This initializes the session tickets _before_ cloning the config.
|
||||
_, _ = tlsConf.DecryptTicket(nil, tls.ConnectionState{})
|
||||
config := tlsConf.Clone()
|
||||
config.NextProtos = []string{NextProtoH3}
|
||||
if gfc := config.GetConfigForClient; gfc != nil {
|
||||
config.GetConfigForClient = func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
conf, err := gfc(ch)
|
||||
if conf == nil || err != nil {
|
||||
return conf, err
|
||||
}
|
||||
return ConfigureTLSConfig(conf), nil
|
||||
}
|
||||
}
|
||||
return config
|
||||
}
|
||||
|
||||
// contextKey is a value for use with context.WithValue. It's used as
|
||||
// a pointer so it fits in an interface{} without allocation.
|
||||
type contextKey struct {
|
||||
name string
|
||||
}
|
||||
|
||||
func (k *contextKey) String() string { return "quic-go/http3 context value " + k.name }
|
||||
|
||||
// ServerContextKey is a context key. It can be used in HTTP
|
||||
// handlers with Context.Value to access the server that
|
||||
// started the handler. The associated value will be of
|
||||
// type *http3.Server.
|
||||
var ServerContextKey = &contextKey{"http3-server"}
|
||||
|
||||
// RemoteAddrContextKey is a context key. It can be used in
|
||||
// HTTP handlers with Context.Value to access the remote
|
||||
// address of the connection. The associated value will be of
|
||||
// type net.Addr.
|
||||
//
|
||||
// Use this value instead of [http.Request.RemoteAddr] if you
|
||||
// require access to the remote address of the connection rather
|
||||
// than its string representation.
|
||||
var RemoteAddrContextKey = &contextKey{"remote-addr"}
|
||||
|
||||
// listener contains info about specific listener added with addListener
|
||||
type listener struct {
|
||||
ln *QUICListener
|
||||
port int // 0 means that no info about port is available
|
||||
|
||||
// if this listener was constructed by the application, it won't be closed when the server is closed
|
||||
createdLocally bool
|
||||
}
|
||||
|
||||
// Server is a HTTP/3 server.
|
||||
type Server struct {
|
||||
// Addr optionally specifies the UDP address for the server to listen on,
|
||||
// in the form "host:port".
|
||||
//
|
||||
// When used by ListenAndServe and ListenAndServeTLS methods, if empty,
|
||||
// ":https" (port 443) is used. See net.Dial for details of the address
|
||||
// format.
|
||||
//
|
||||
// Otherwise, if Port is not set and underlying QUIC listeners do not
|
||||
// have valid port numbers, the port part is used in Alt-Svc headers set
|
||||
// with SetQUICHeaders.
|
||||
Addr string
|
||||
|
||||
// Port is used in Alt-Svc response headers set with SetQUICHeaders. If
|
||||
// needed Port can be manually set when the Server is created.
|
||||
//
|
||||
// This is useful when a Layer 4 firewall is redirecting UDP traffic and
|
||||
// clients must use a port different from the port the Server is
|
||||
// listening on.
|
||||
Port int
|
||||
|
||||
// TLSConfig provides a TLS configuration for use by server. It must be
|
||||
// set for ListenAndServe and Serve methods.
|
||||
TLSConfig *tls.Config
|
||||
|
||||
// QUICConfig provides the parameters for QUIC connection created with Serve.
|
||||
// If nil, it uses reasonable default values.
|
||||
//
|
||||
// Configured versions are also used in Alt-Svc response header set with SetQUICHeaders.
|
||||
QUICConfig *quic.Config
|
||||
|
||||
// Handler is the HTTP request handler to use. If not set, defaults to
|
||||
// http.NotFound.
|
||||
Handler http.Handler
|
||||
|
||||
// EnableDatagrams enables support for HTTP/3 datagrams (RFC 9297).
|
||||
// If set to true, QUICConfig.EnableDatagrams will be set.
|
||||
EnableDatagrams bool
|
||||
|
||||
// MaxHeaderBytes controls the maximum number of bytes the server will
|
||||
// read parsing the request HEADERS frame. It does not limit the size of
|
||||
// the request body. If zero or negative, http.DefaultMaxHeaderBytes is
|
||||
// used.
|
||||
MaxHeaderBytes int
|
||||
|
||||
// AdditionalSettings specifies additional HTTP/3 settings.
|
||||
// It is invalid to specify any settings defined by RFC 9114 (HTTP/3) and RFC 9297 (HTTP Datagrams).
|
||||
AdditionalSettings map[uint64]uint64
|
||||
|
||||
// IdleTimeout specifies how long until idle clients connection should be
|
||||
// closed. Idle refers only to the HTTP/3 layer, activity at the QUIC layer
|
||||
// like PING frames are not considered.
|
||||
// If zero or negative, there is no timeout.
|
||||
IdleTimeout time.Duration
|
||||
|
||||
// ConnContext optionally specifies a function that modifies the context used for a new connection c.
|
||||
// The provided ctx has a ServerContextKey value.
|
||||
ConnContext func(ctx context.Context, c *quic.Conn) context.Context
|
||||
|
||||
Logger *slog.Logger
|
||||
|
||||
mutex sync.RWMutex
|
||||
listeners []listener
|
||||
|
||||
closed bool
|
||||
closeCtx context.Context // canceled when the server is closed
|
||||
closeCancel context.CancelFunc // cancels the closeCtx
|
||||
graceCtx context.Context // canceled when the server is closed or gracefully closed
|
||||
graceCancel context.CancelFunc // cancels the graceCtx
|
||||
connCount atomic.Int64
|
||||
connHandlingDone chan struct{}
|
||||
|
||||
altSvcHeader string
|
||||
}
|
||||
|
||||
// ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/3 requests on incoming connections.
|
||||
//
|
||||
// If s.Addr is blank, ":https" is used.
|
||||
func (s *Server) ListenAndServe() error {
|
||||
ln, err := s.setupListenerForConn(s.TLSConfig, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer s.removeListener(ln)
|
||||
|
||||
return s.serveListener(*ln)
|
||||
}
|
||||
|
||||
// ListenAndServeTLS listens on the UDP address s.Addr and calls s.Handler to handle HTTP/3 requests on incoming connections.
|
||||
//
|
||||
// If s.Addr is blank, ":https" is used.
|
||||
func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
|
||||
var err error
|
||||
certs := make([]tls.Certificate, 1)
|
||||
certs[0], err = tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// We currently only use the cert-related stuff from tls.Config,
|
||||
// so we don't need to make a full copy.
|
||||
ln, err := s.setupListenerForConn(&tls.Config{Certificates: certs}, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer s.removeListener(ln)
|
||||
|
||||
return s.serveListener(*ln)
|
||||
}
|
||||
|
||||
// Serve an existing UDP connection.
|
||||
// It is possible to reuse the same connection for outgoing connections.
|
||||
// Closing the server does not close the connection.
|
||||
func (s *Server) Serve(conn net.PacketConn) error {
|
||||
ln, err := s.setupListenerForConn(s.TLSConfig, conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer s.removeListener(ln)
|
||||
|
||||
return s.serveListener(*ln)
|
||||
}
|
||||
|
||||
// init initializes the contexts used for shutting down the server.
|
||||
// It must be called with the mutex held.
|
||||
func (s *Server) init() {
|
||||
if s.closeCtx == nil {
|
||||
s.closeCtx, s.closeCancel = context.WithCancel(context.Background())
|
||||
s.graceCtx, s.graceCancel = context.WithCancel(s.closeCtx)
|
||||
}
|
||||
s.connHandlingDone = make(chan struct{}, 1)
|
||||
}
|
||||
|
||||
func (s *Server) decreaseConnCount() {
|
||||
if s.connCount.Add(-1) == 0 && s.graceCtx.Err() != nil {
|
||||
close(s.connHandlingDone)
|
||||
}
|
||||
}
|
||||
|
||||
// ServeQUICConn serves a single QUIC connection.
|
||||
func (s *Server) ServeQUICConn(conn *quic.Conn) error {
|
||||
s.mutex.Lock()
|
||||
if s.closed {
|
||||
s.mutex.Unlock()
|
||||
return http.ErrServerClosed
|
||||
}
|
||||
|
||||
s.init()
|
||||
s.mutex.Unlock()
|
||||
|
||||
s.connCount.Add(1)
|
||||
defer s.decreaseConnCount()
|
||||
|
||||
return s.handleConn(conn)
|
||||
}
|
||||
|
||||
// ServeListener serves an existing QUIC listener.
|
||||
// Make sure you use http3.ConfigureTLSConfig to configure a tls.Config
|
||||
// and use it to construct a http3-friendly QUIC listener.
|
||||
// Closing the server does not close the listener. It is the application's responsibility to close them.
|
||||
// ServeListener always returns a non-nil error. After Shutdown or Close, the returned error is http.ErrServerClosed.
|
||||
func (s *Server) ServeListener(ln QUICListener) error {
|
||||
s.mutex.Lock()
|
||||
if err := s.addListener(&ln, false); err != nil {
|
||||
s.mutex.Unlock()
|
||||
return err
|
||||
}
|
||||
s.mutex.Unlock()
|
||||
defer s.removeListener(&ln)
|
||||
|
||||
return s.serveListener(ln)
|
||||
}
|
||||
|
||||
func (s *Server) serveListener(ln QUICListener) error {
|
||||
for {
|
||||
conn, err := ln.Accept(s.graceCtx)
|
||||
// server closed
|
||||
if errors.Is(err, quic.ErrServerClosed) || s.graceCtx.Err() != nil {
|
||||
return http.ErrServerClosed
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.connCount.Add(1)
|
||||
go func() {
|
||||
defer s.decreaseConnCount()
|
||||
if err := s.handleConn(conn); err != nil {
|
||||
if s.Logger != nil {
|
||||
s.Logger.Debug("handling connection failed", "error", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
var errServerWithoutTLSConfig = errors.New("use of http3.Server without TLSConfig")
|
||||
|
||||
func (s *Server) setupListenerForConn(tlsConf *tls.Config, conn net.PacketConn) (*QUICListener, error) {
|
||||
if tlsConf == nil {
|
||||
return nil, errServerWithoutTLSConfig
|
||||
}
|
||||
|
||||
baseConf := ConfigureTLSConfig(tlsConf)
|
||||
quicConf := s.QUICConfig
|
||||
if quicConf == nil {
|
||||
quicConf = &quic.Config{Allow0RTT: true}
|
||||
} else {
|
||||
quicConf = s.QUICConfig.Clone()
|
||||
}
|
||||
if s.EnableDatagrams {
|
||||
quicConf.EnableDatagrams = true
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
closed := s.closed
|
||||
if closed {
|
||||
return nil, http.ErrServerClosed
|
||||
}
|
||||
|
||||
var ln QUICListener
|
||||
var err error
|
||||
if conn == nil {
|
||||
addr := s.Addr
|
||||
if addr == "" {
|
||||
addr = ":https"
|
||||
}
|
||||
ln, err = quic.ListenAddrEarly(addr, baseConf, quicConf)
|
||||
} else {
|
||||
ln, err = quic.ListenEarly(conn, baseConf, quicConf)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.addListener(&ln, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ln, nil
|
||||
}
|
||||
|
||||
func extractPort(addr string) (int, error) {
|
||||
_, portStr, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
portInt, err := net.LookupPort("tcp", portStr)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return portInt, nil
|
||||
}
|
||||
|
||||
func (s *Server) generateAltSvcHeader() {
|
||||
if len(s.listeners) == 0 {
|
||||
// Don't announce any ports since no one is listening for connections
|
||||
s.altSvcHeader = ""
|
||||
return
|
||||
}
|
||||
|
||||
// This code assumes that we will use protocol.SupportedVersions if no quic.Config is passed.
|
||||
|
||||
var altSvc []string
|
||||
addPort := func(port int) {
|
||||
altSvc = append(altSvc, fmt.Sprintf(`%s=":%d"; ma=2592000`, NextProtoH3, port))
|
||||
}
|
||||
|
||||
if s.Port != 0 {
|
||||
// if Port is specified, we must use it instead of the
|
||||
// listener addresses since there's a reason it's specified.
|
||||
addPort(s.Port)
|
||||
} else {
|
||||
// if we have some listeners assigned, try to find ports
|
||||
// which we can announce, otherwise nothing should be announced
|
||||
validPortsFound := false
|
||||
for _, info := range s.listeners {
|
||||
if info.port != 0 {
|
||||
addPort(info.port)
|
||||
validPortsFound = true
|
||||
}
|
||||
}
|
||||
if !validPortsFound {
|
||||
if port, err := extractPort(s.Addr); err == nil {
|
||||
addPort(port)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.altSvcHeader = strings.Join(altSvc, ",")
|
||||
}
|
||||
|
||||
func (s *Server) addListener(l *QUICListener, createdLocally bool) error {
|
||||
if s.closed {
|
||||
return http.ErrServerClosed
|
||||
}
|
||||
s.init()
|
||||
|
||||
laddr := (*l).Addr()
|
||||
if port, err := extractPort(laddr.String()); err == nil {
|
||||
s.listeners = append(s.listeners, listener{ln: l, port: port, createdLocally: createdLocally})
|
||||
} else {
|
||||
logger := s.Logger
|
||||
if logger == nil {
|
||||
logger = slog.Default()
|
||||
}
|
||||
logger.Error("Unable to extract port from listener, will not be announced using SetQUICHeaders", "local addr", laddr, "error", err)
|
||||
s.listeners = append(s.listeners, listener{ln: l, port: 0, createdLocally: createdLocally})
|
||||
}
|
||||
s.generateAltSvcHeader()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) removeListener(l *QUICListener) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
s.listeners = slices.DeleteFunc(s.listeners, func(info listener) bool {
|
||||
return info.ln == l
|
||||
})
|
||||
s.generateAltSvcHeader()
|
||||
}
|
||||
|
||||
func (s *Server) NewRawServerConn(conn *quic.Conn) (*RawServerConn, error) {
|
||||
hconn, _, _, err := s.newRawServerConn(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return hconn, nil
|
||||
}
|
||||
|
||||
func (s *Server) newRawServerConn(conn *quic.Conn) (*RawServerConn, *quic.SendStream, qlogwriter.Recorder, error) {
|
||||
var qlogger qlogwriter.Recorder
|
||||
if qlogTrace := conn.QlogTrace(); qlogTrace != nil && qlogTrace.SupportsSchemas(qlog.EventSchema) {
|
||||
qlogger = qlogTrace.AddProducer()
|
||||
}
|
||||
connCtx := conn.Context()
|
||||
connCtx = context.WithValue(connCtx, ServerContextKey, s)
|
||||
connCtx = context.WithValue(connCtx, http.LocalAddrContextKey, conn.LocalAddr())
|
||||
connCtx = context.WithValue(connCtx, RemoteAddrContextKey, conn.RemoteAddr())
|
||||
if s.ConnContext != nil {
|
||||
connCtx = s.ConnContext(connCtx, conn)
|
||||
if connCtx == nil {
|
||||
panic("http3: ConnContext returned nil")
|
||||
}
|
||||
}
|
||||
hconn := newRawServerConn(
|
||||
conn,
|
||||
s.EnableDatagrams,
|
||||
s.IdleTimeout,
|
||||
qlogger,
|
||||
s.Logger,
|
||||
connCtx,
|
||||
s.Handler,
|
||||
s.maxHeaderBytes(),
|
||||
)
|
||||
|
||||
// open the control stream and send a SETTINGS frame, it's also used to send a GOAWAY frame later
|
||||
// when the server is gracefully closed
|
||||
ctrlStr, err := hconn.openControlStream(&settingsFrame{
|
||||
MaxFieldSectionSize: int64(s.maxHeaderBytes()),
|
||||
Datagram: s.EnableDatagrams,
|
||||
ExtendedConnect: true,
|
||||
Other: s.AdditionalSettings,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("opening the control stream failed: %w", err)
|
||||
}
|
||||
return hconn, ctrlStr, qlogger, nil
|
||||
}
|
||||
|
||||
// handleConn handles the HTTP/3 exchange on a QUIC connection.
|
||||
// It blocks until all HTTP handlers for all streams have returned.
|
||||
func (s *Server) handleConn(conn *quic.Conn) error {
|
||||
hconn, ctrlStr, qlogger, err := s.newRawServerConn(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
str, err := conn.AcceptUniStream(context.Background())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go hconn.HandleUnidirectionalStream(str)
|
||||
}
|
||||
}()
|
||||
|
||||
var nextStreamID quic.StreamID
|
||||
var handleErr error
|
||||
var inGracefulShutdown bool
|
||||
// Process all requests immediately.
|
||||
// It's the client's responsibility to decide which requests are eligible for 0-RTT.
|
||||
ctx := s.graceCtx
|
||||
for {
|
||||
// The context used here is:
|
||||
// * before graceful shutdown: s.graceCtx
|
||||
// * after graceful shutdown: s.closeCtx
|
||||
// This allows us to keep accepting (and resetting) streams after graceful shutdown has started.
|
||||
str, err := conn.AcceptStream(ctx)
|
||||
if err != nil {
|
||||
// the underlying connection was closed (by either side)
|
||||
if conn.Context().Err() != nil {
|
||||
var appErr *quic.ApplicationError
|
||||
if !errors.As(err, &appErr) || appErr.ErrorCode != quic.ApplicationErrorCode(ErrCodeNoError) {
|
||||
handleErr = fmt.Errorf("accepting stream failed: %w", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
// server (not gracefully) closed, close the connection immediately
|
||||
if s.closeCtx.Err() != nil {
|
||||
hconn.CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "")
|
||||
handleErr = http.ErrServerClosed
|
||||
break
|
||||
}
|
||||
inGracefulShutdown = s.graceCtx.Err() != nil
|
||||
if !inGracefulShutdown {
|
||||
var appErr *quic.ApplicationError
|
||||
if !errors.As(err, &appErr) || appErr.ErrorCode != quic.ApplicationErrorCode(ErrCodeNoError) {
|
||||
handleErr = fmt.Errorf("accepting stream failed: %w", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// gracefully closed, send GOAWAY frame and wait for requests to complete or grace period to end
|
||||
// new requests will be rejected and shouldn't be sent
|
||||
if qlogger != nil {
|
||||
qlogger.RecordEvent(qlog.FrameCreated{
|
||||
StreamID: ctrlStr.StreamID(),
|
||||
Frame: qlog.Frame{Frame: qlog.GoAwayFrame{StreamID: nextStreamID}},
|
||||
})
|
||||
}
|
||||
wg.Add(1)
|
||||
// Send the GOAWAY frame in a separate Goroutine.
|
||||
// Sending might block if the peer didn't grant enough flow control credit.
|
||||
// Write is guaranteed to return once the connection is closed.
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = ctrlStr.Write((&goAwayFrame{StreamID: nextStreamID}).Append(nil))
|
||||
}()
|
||||
ctx = s.closeCtx
|
||||
continue
|
||||
}
|
||||
if inGracefulShutdown {
|
||||
str.CancelRead(quic.StreamErrorCode(ErrCodeRequestRejected))
|
||||
str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestRejected))
|
||||
continue
|
||||
}
|
||||
|
||||
nextStreamID = str.StreamID() + 4
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
// HandleRequestStream will return once the request has been handled,
|
||||
// or the underlying connection is closed.
|
||||
defer wg.Done()
|
||||
hconn.HandleRequestStream(str)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
return handleErr
|
||||
}
|
||||
|
||||
func (s *Server) maxHeaderBytes() int {
|
||||
if s.MaxHeaderBytes <= 0 {
|
||||
return http.DefaultMaxHeaderBytes
|
||||
}
|
||||
return s.MaxHeaderBytes
|
||||
}
|
||||
|
||||
// Close the server immediately, aborting requests and sending CONNECTION_CLOSE frames to connected clients.
|
||||
// Close in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established.
|
||||
// It is the caller's responsibility to close any connection passed to ServeQUICConn.
|
||||
func (s *Server) Close() error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
s.closed = true
|
||||
// server is never used
|
||||
if s.closeCtx == nil {
|
||||
return nil
|
||||
}
|
||||
s.closeCancel()
|
||||
|
||||
var err error
|
||||
for _, l := range s.listeners {
|
||||
if l.createdLocally {
|
||||
if cerr := (*l.ln).Close(); cerr != nil && err == nil {
|
||||
err = cerr
|
||||
}
|
||||
}
|
||||
}
|
||||
if s.connCount.Load() == 0 {
|
||||
return err
|
||||
}
|
||||
// wait for all connections to be closed
|
||||
<-s.connHandlingDone
|
||||
return err
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the server without interrupting any active connections.
|
||||
// The server sends a GOAWAY frame first, then or for all running requests to complete.
|
||||
// Shutdown in combination with ListenAndServe may race if it is called before a UDP socket is established.
|
||||
// It is recommended to use Serve instead.
|
||||
func (s *Server) Shutdown(ctx context.Context) error {
|
||||
s.mutex.Lock()
|
||||
s.closed = true
|
||||
// server was never used
|
||||
if s.closeCtx == nil {
|
||||
s.mutex.Unlock()
|
||||
return nil
|
||||
}
|
||||
s.graceCancel()
|
||||
|
||||
// close all listeners
|
||||
var closeErrs []error
|
||||
for _, l := range s.listeners {
|
||||
if l.createdLocally {
|
||||
if err := (*l.ln).Close(); err != nil {
|
||||
closeErrs = append(closeErrs, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
s.mutex.Unlock()
|
||||
if len(closeErrs) > 0 {
|
||||
return errors.Join(closeErrs...)
|
||||
}
|
||||
|
||||
if s.connCount.Load() == 0 {
|
||||
return s.Close()
|
||||
}
|
||||
select {
|
||||
case <-s.connHandlingDone: // all connections were closed
|
||||
// When receiving a GOAWAY frame, HTTP/3 clients are expected to close the connection
|
||||
// once all requests were successfully handled...
|
||||
return s.Close()
|
||||
case <-ctx.Done():
|
||||
// ... however, clients handling long-lived requests (and misbehaving clients),
|
||||
// might not do so before the context is cancelled.
|
||||
// In this case, we close the server, which closes all existing connections
|
||||
// (expect those passed to ServeQUICConn).
|
||||
_ = s.Close()
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// ErrNoAltSvcPort is the error returned by SetQUICHeaders when no port was found
|
||||
// for Alt-Svc to announce. This can happen if listening on a PacketConn without a port
|
||||
// (UNIX socket, for example) and no port is specified in Server.Port or Server.Addr.
|
||||
var ErrNoAltSvcPort = errors.New("no port can be announced, specify it explicitly using Server.Port or Server.Addr")
|
||||
|
||||
// SetQUICHeaders can be used to set the proper headers that announce that this server supports HTTP/3.
|
||||
// The values set by default advertise all the ports the server is listening on, but can be
|
||||
// changed to a specific port by setting Server.Port before launching the server.
|
||||
// If no listener's Addr().String() returns an address with a valid port, Server.Addr will be used
|
||||
// to extract the port, if specified.
|
||||
// For example, a server launched using ListenAndServe on an address with port 443 would set:
|
||||
//
|
||||
// Alt-Svc: h3=":443"; ma=2592000
|
||||
func (s *Server) SetQUICHeaders(hdr http.Header) error {
|
||||
s.mutex.RLock()
|
||||
defer s.mutex.RUnlock()
|
||||
|
||||
if s.altSvcHeader == "" {
|
||||
return ErrNoAltSvcPort
|
||||
}
|
||||
// use the map directly to avoid constant canonicalization since the key is already canonicalized
|
||||
hdr["Alt-Svc"] = append(hdr["Alt-Svc"], s.altSvcHeader)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListenAndServeQUIC listens on the UDP network address addr and calls the
|
||||
// handler for HTTP/3 requests on incoming connections. http.DefaultServeMux is
|
||||
// used when handler is nil.
|
||||
func ListenAndServeQUIC(addr, certFile, keyFile string, handler http.Handler) error {
|
||||
server := &Server{
|
||||
Addr: addr,
|
||||
Handler: handler,
|
||||
}
|
||||
return server.ListenAndServeTLS(certFile, keyFile)
|
||||
}
|
||||
|
||||
// ListenAndServeTLS listens on the given network address for both TLS/TCP and QUIC
|
||||
// connections in parallel. It returns if one of the two returns an error.
|
||||
// http.DefaultServeMux is used when handler is nil.
|
||||
// The correct Alt-Svc headers for QUIC are set.
|
||||
func ListenAndServeTLS(addr, certFile, keyFile string, handler http.Handler) error {
|
||||
// Load certs
|
||||
var err error
|
||||
certs := make([]tls.Certificate, 1)
|
||||
certs[0], err = tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// We currently only use the cert-related stuff from tls.Config,
|
||||
// so we don't need to make a full copy.
|
||||
config := &tls.Config{
|
||||
Certificates: certs,
|
||||
}
|
||||
|
||||
if addr == "" {
|
||||
addr = ":https"
|
||||
}
|
||||
|
||||
// Open the listeners
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
udpConn, err := net.ListenUDP("udp", udpAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer udpConn.Close()
|
||||
|
||||
if handler == nil {
|
||||
handler = http.DefaultServeMux
|
||||
}
|
||||
// Start the servers
|
||||
quicServer := &Server{
|
||||
TLSConfig: config,
|
||||
Handler: handler,
|
||||
}
|
||||
|
||||
hErr := make(chan error, 1)
|
||||
qErr := make(chan error, 1)
|
||||
go func() {
|
||||
hErr <- http.ListenAndServeTLS(addr, certFile, keyFile, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
quicServer.SetQUICHeaders(w.Header())
|
||||
handler.ServeHTTP(w, r)
|
||||
}))
|
||||
}()
|
||||
go func() {
|
||||
qErr <- quicServer.Serve(udpConn)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-hErr:
|
||||
quicServer.Close()
|
||||
return err
|
||||
case err := <-qErr:
|
||||
// Cannot close the HTTP server or wait for requests to complete properly :/
|
||||
return err
|
||||
}
|
||||
}
|
||||
261
vendor/github.com/quic-go/quic-go/http3/server_conn.go
generated
vendored
Normal file
261
vendor/github.com/quic-go/quic-go/http3/server_conn.go
generated
vendored
Normal file
@@ -0,0 +1,261 @@
|
||||
package http3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/qpack"
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/qlogwriter"
|
||||
)
|
||||
|
||||
// RawServerConn is an HTTP/3 server connection.
|
||||
// It can be used for advanced use cases where the application wants to manage the QUIC connection lifecycle.
|
||||
type RawServerConn struct {
|
||||
rawConn rawConn
|
||||
|
||||
idleTimeout time.Duration
|
||||
idleTimer *time.Timer
|
||||
|
||||
serverContext context.Context
|
||||
requestHandler http.Handler
|
||||
maxHeaderBytes int
|
||||
|
||||
decoder *qpack.Decoder
|
||||
|
||||
qlogger qlogwriter.Recorder
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func newRawServerConn(
|
||||
conn *quic.Conn,
|
||||
enableDatagrams bool,
|
||||
idleTimeout time.Duration,
|
||||
qlogger qlogwriter.Recorder,
|
||||
logger *slog.Logger,
|
||||
serverContext context.Context,
|
||||
requestHandler http.Handler,
|
||||
maxHeaderBytes int,
|
||||
) *RawServerConn {
|
||||
c := &RawServerConn{
|
||||
idleTimeout: idleTimeout,
|
||||
serverContext: serverContext,
|
||||
requestHandler: requestHandler,
|
||||
maxHeaderBytes: maxHeaderBytes,
|
||||
decoder: qpack.NewDecoder(),
|
||||
qlogger: qlogger,
|
||||
logger: logger,
|
||||
}
|
||||
c.rawConn = *newRawConn(conn, enableDatagrams, c.onStreamsEmpty, nil, qlogger, logger)
|
||||
if idleTimeout > 0 {
|
||||
c.idleTimer = time.AfterFunc(idleTimeout, c.onIdleTimer)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *RawServerConn) onStreamsEmpty() {
|
||||
if c.idleTimeout > 0 {
|
||||
c.idleTimer.Reset(c.idleTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RawServerConn) onIdleTimer() {
|
||||
c.CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "idle timeout")
|
||||
}
|
||||
|
||||
// CloseWithError closes the connection with the given error code and message.
|
||||
func (c *RawServerConn) CloseWithError(code quic.ApplicationErrorCode, msg string) error {
|
||||
if c.idleTimer != nil {
|
||||
c.idleTimer.Stop()
|
||||
}
|
||||
return c.rawConn.CloseWithError(code, msg)
|
||||
}
|
||||
|
||||
// HandleRequestStream handles an HTTP/3 request on a bidirectional request stream.
|
||||
// The stream can either be obtained by calling AcceptStream on the underlying QUIC connection,
|
||||
// or (internally) by using the server's stream accept loop.
|
||||
func (c *RawServerConn) HandleRequestStream(str *quic.Stream) {
|
||||
hstr := c.rawConn.TrackStream(str)
|
||||
c.handleRequestStream(hstr)
|
||||
}
|
||||
|
||||
func (c *RawServerConn) requestMaxHeaderBytes() int {
|
||||
if c.maxHeaderBytes <= 0 {
|
||||
return http.DefaultMaxHeaderBytes
|
||||
}
|
||||
return c.maxHeaderBytes
|
||||
}
|
||||
|
||||
func (c *RawServerConn) openControlStream(settings *settingsFrame) (*quic.SendStream, error) {
|
||||
return c.rawConn.openControlStream(settings)
|
||||
}
|
||||
|
||||
func (c *RawServerConn) handleRequestStream(str *stateTrackingStream) {
|
||||
if c.idleTimeout > 0 {
|
||||
// This only applies if the stream is the first active stream,
|
||||
// but it's ok to stop a stopped timer.
|
||||
c.idleTimer.Stop()
|
||||
}
|
||||
|
||||
conn := &c.rawConn
|
||||
qlogger := c.qlogger
|
||||
decoder := c.decoder
|
||||
connCtx := c.serverContext
|
||||
maxHeaderBytes := c.requestMaxHeaderBytes()
|
||||
|
||||
fp := &frameParser{closeConn: conn.CloseWithError, r: str, streamID: str.StreamID()}
|
||||
frame, err := fp.ParseNext(qlogger)
|
||||
if err != nil {
|
||||
str.CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete))
|
||||
str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete))
|
||||
return
|
||||
}
|
||||
hf, ok := frame.(*headersFrame)
|
||||
if !ok {
|
||||
conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "expected first frame to be a HEADERS frame")
|
||||
return
|
||||
}
|
||||
if hf.Length > uint64(maxHeaderBytes) {
|
||||
maybeQlogInvalidHeadersFrame(qlogger, str.StreamID(), hf.Length)
|
||||
// stop the client from sending more data
|
||||
str.CancelRead(quic.StreamErrorCode(ErrCodeExcessiveLoad))
|
||||
// send a 431 Response (Request Header Fields Too Large)
|
||||
c.rejectWithHeaderFieldsTooLarge(str)
|
||||
return
|
||||
}
|
||||
headerBlock := make([]byte, hf.Length)
|
||||
if _, err := io.ReadFull(str, headerBlock); err != nil {
|
||||
maybeQlogInvalidHeadersFrame(qlogger, str.StreamID(), hf.Length)
|
||||
str.CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete))
|
||||
str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete))
|
||||
return
|
||||
}
|
||||
decodeFn := decoder.Decode(headerBlock)
|
||||
var hfs []qpack.HeaderField
|
||||
if qlogger != nil {
|
||||
hfs = make([]qpack.HeaderField, 0, 16)
|
||||
}
|
||||
req, err := requestFromHeaders(decodeFn, maxHeaderBytes, &hfs)
|
||||
if qlogger != nil {
|
||||
qlogParsedHeadersFrame(qlogger, str.StreamID(), hf, hfs)
|
||||
}
|
||||
if err != nil {
|
||||
if errors.Is(err, errHeaderTooLarge) {
|
||||
// stop the client from sending more data
|
||||
str.CancelRead(quic.StreamErrorCode(ErrCodeExcessiveLoad))
|
||||
// send a 431 Response (Request Header Fields Too Large)
|
||||
c.rejectWithHeaderFieldsTooLarge(str)
|
||||
return
|
||||
}
|
||||
|
||||
errCode := ErrCodeMessageError
|
||||
var qpackErr *qpackError
|
||||
if errors.As(err, &qpackErr) {
|
||||
errCode = ErrCodeQPACKDecompressionFailed
|
||||
}
|
||||
str.CancelRead(quic.StreamErrorCode(errCode))
|
||||
str.CancelWrite(quic.StreamErrorCode(errCode))
|
||||
return
|
||||
}
|
||||
|
||||
connState := conn.ConnectionState().TLS
|
||||
req.TLS = &connState
|
||||
req.RemoteAddr = conn.RemoteAddr().String()
|
||||
|
||||
// Check that the client doesn't send more data in DATA frames than indicated by the Content-Length header (if set).
|
||||
// See section 4.1.2 of RFC 9114.
|
||||
contentLength := int64(-1)
|
||||
if _, ok := req.Header["Content-Length"]; ok && req.ContentLength >= 0 {
|
||||
contentLength = req.ContentLength
|
||||
}
|
||||
hstr := newStream(str, conn, nil, func(r io.Reader, hf *headersFrame) error {
|
||||
trailers, err := decodeTrailers(r, hf, maxHeaderBytes, decoder, qlogger, str.StreamID())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Trailer = trailers
|
||||
return nil
|
||||
}, qlogger)
|
||||
body := newRequestBody(hstr, contentLength, connCtx, conn.ReceivedSettings(), conn.Settings)
|
||||
req.Body = body
|
||||
|
||||
if c.logger != nil {
|
||||
c.logger.Debug("handling request", "method", req.Method, "host", req.Host, "uri", req.RequestURI)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(connCtx)
|
||||
req = req.WithContext(ctx)
|
||||
context.AfterFunc(str.Context(), cancel)
|
||||
|
||||
r := newResponseWriter(hstr, conn, req.Method == http.MethodHead, c.logger)
|
||||
handler := c.requestHandler
|
||||
if handler == nil {
|
||||
handler = http.DefaultServeMux
|
||||
}
|
||||
|
||||
// It's the client's responsibility to decide which requests are eligible for 0-RTT.
|
||||
var panicked bool
|
||||
func() {
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
panicked = true
|
||||
if p == http.ErrAbortHandler {
|
||||
return
|
||||
}
|
||||
// Copied from net/http/server.go
|
||||
const size = 64 << 10
|
||||
buf := make([]byte, size)
|
||||
buf = buf[:runtime.Stack(buf, false)]
|
||||
logger := c.logger
|
||||
if logger == nil {
|
||||
logger = slog.Default()
|
||||
}
|
||||
logger.Error("http3: panic serving", "arg", p, "trace", string(buf))
|
||||
}
|
||||
}()
|
||||
handler.ServeHTTP(r, req)
|
||||
}()
|
||||
|
||||
if r.wasStreamHijacked() {
|
||||
return
|
||||
}
|
||||
|
||||
// abort the stream when there is a panic
|
||||
if panicked {
|
||||
str.CancelRead(quic.StreamErrorCode(ErrCodeInternalError))
|
||||
str.CancelWrite(quic.StreamErrorCode(ErrCodeInternalError))
|
||||
return
|
||||
}
|
||||
|
||||
// response not written to the client yet, set Content-Length
|
||||
if !r.headerWritten {
|
||||
if _, haveCL := r.header["Content-Length"]; !haveCL {
|
||||
r.header.Set("Content-Length", strconv.FormatInt(r.numWritten, 10))
|
||||
}
|
||||
}
|
||||
r.Flush()
|
||||
r.flushTrailers()
|
||||
|
||||
// If the EOF was read by the handler, CancelRead() is a no-op.
|
||||
str.CancelRead(quic.StreamErrorCode(ErrCodeNoError))
|
||||
str.Close()
|
||||
}
|
||||
|
||||
func (c *RawServerConn) rejectWithHeaderFieldsTooLarge(str *stateTrackingStream) {
|
||||
hstr := newStream(str, &c.rawConn, nil, nil, c.qlogger)
|
||||
defer hstr.Close()
|
||||
r := newResponseWriter(hstr, &c.rawConn, false, c.logger)
|
||||
r.WriteHeader(http.StatusRequestHeaderFieldsTooLarge)
|
||||
r.Flush()
|
||||
}
|
||||
|
||||
// HandleUnidirectionalStream handles an incoming unidirectional stream.
|
||||
func (c *RawServerConn) HandleUnidirectionalStream(str *quic.ReceiveStream) {
|
||||
c.rawConn.handleUnidirectionalStream(str, true)
|
||||
}
|
||||
173
vendor/github.com/quic-go/quic-go/http3/state_tracking_stream.go
generated
vendored
Normal file
173
vendor/github.com/quic-go/quic-go/http3/state_tracking_stream.go
generated
vendored
Normal file
@@ -0,0 +1,173 @@
|
||||
package http3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
)
|
||||
|
||||
const streamDatagramQueueLen = 32
|
||||
|
||||
// stateTrackingStream is an implementation of quic.Stream that delegates
|
||||
// to an underlying stream
|
||||
// it takes care of proxying send and receive errors onto an implementation of
|
||||
// the errorSetter interface (intended to be occupied by a datagrammer)
|
||||
// it is also responsible for clearing the stream based on its ID from its
|
||||
// parent connection, this is done through the streamClearer interface when
|
||||
// both the send and receive sides are closed
|
||||
type stateTrackingStream struct {
|
||||
*quic.Stream
|
||||
|
||||
sendDatagram func([]byte) error
|
||||
hasData chan struct{}
|
||||
queue [][]byte // TODO: use a ring buffer
|
||||
|
||||
mx sync.Mutex
|
||||
sendErr error
|
||||
recvErr error
|
||||
|
||||
clearer streamClearer
|
||||
}
|
||||
|
||||
var _ datagramStream = &stateTrackingStream{}
|
||||
|
||||
type streamClearer interface {
|
||||
clearStream(quic.StreamID)
|
||||
}
|
||||
|
||||
func newStateTrackingStream(s *quic.Stream, clearer streamClearer, sendDatagram func([]byte) error) *stateTrackingStream {
|
||||
t := &stateTrackingStream{
|
||||
Stream: s,
|
||||
clearer: clearer,
|
||||
sendDatagram: sendDatagram,
|
||||
hasData: make(chan struct{}, 1),
|
||||
}
|
||||
|
||||
context.AfterFunc(s.Context(), func() {
|
||||
t.closeSend(context.Cause(s.Context()))
|
||||
})
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
func (s *stateTrackingStream) closeSend(e error) {
|
||||
s.mx.Lock()
|
||||
defer s.mx.Unlock()
|
||||
|
||||
// clear the stream the first time both the send
|
||||
// and receive are finished
|
||||
if s.sendErr == nil {
|
||||
if s.recvErr != nil {
|
||||
s.clearer.clearStream(s.StreamID())
|
||||
}
|
||||
s.sendErr = e
|
||||
}
|
||||
}
|
||||
|
||||
func (s *stateTrackingStream) closeReceive(e error) {
|
||||
s.mx.Lock()
|
||||
defer s.mx.Unlock()
|
||||
|
||||
// clear the stream the first time both the send
|
||||
// and receive are finished
|
||||
if s.recvErr == nil {
|
||||
if s.sendErr != nil {
|
||||
s.clearer.clearStream(s.StreamID())
|
||||
}
|
||||
s.recvErr = e
|
||||
s.signalHasDatagram()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *stateTrackingStream) Close() error {
|
||||
s.closeSend(errors.New("write on closed stream"))
|
||||
return s.Stream.Close()
|
||||
}
|
||||
|
||||
func (s *stateTrackingStream) CancelWrite(e quic.StreamErrorCode) {
|
||||
s.closeSend(&quic.StreamError{StreamID: s.StreamID(), ErrorCode: e})
|
||||
s.Stream.CancelWrite(e)
|
||||
}
|
||||
|
||||
func (s *stateTrackingStream) Write(b []byte) (int, error) {
|
||||
n, err := s.Stream.Write(b)
|
||||
if err != nil && !errors.Is(err, os.ErrDeadlineExceeded) {
|
||||
s.closeSend(err)
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (s *stateTrackingStream) CancelRead(e quic.StreamErrorCode) {
|
||||
s.closeReceive(&quic.StreamError{StreamID: s.StreamID(), ErrorCode: e})
|
||||
s.Stream.CancelRead(e)
|
||||
}
|
||||
|
||||
func (s *stateTrackingStream) Read(b []byte) (int, error) {
|
||||
n, err := s.Stream.Read(b)
|
||||
if err != nil && !errors.Is(err, os.ErrDeadlineExceeded) {
|
||||
s.closeReceive(err)
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (s *stateTrackingStream) SendDatagram(b []byte) error {
|
||||
s.mx.Lock()
|
||||
sendErr := s.sendErr
|
||||
s.mx.Unlock()
|
||||
if sendErr != nil {
|
||||
return sendErr
|
||||
}
|
||||
|
||||
return s.sendDatagram(b)
|
||||
}
|
||||
|
||||
func (s *stateTrackingStream) signalHasDatagram() {
|
||||
select {
|
||||
case s.hasData <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (s *stateTrackingStream) enqueueDatagram(data []byte) {
|
||||
s.mx.Lock()
|
||||
defer s.mx.Unlock()
|
||||
|
||||
if s.recvErr != nil {
|
||||
return
|
||||
}
|
||||
if len(s.queue) >= streamDatagramQueueLen {
|
||||
return
|
||||
}
|
||||
s.queue = append(s.queue, data)
|
||||
s.signalHasDatagram()
|
||||
}
|
||||
|
||||
func (s *stateTrackingStream) ReceiveDatagram(ctx context.Context) ([]byte, error) {
|
||||
start:
|
||||
s.mx.Lock()
|
||||
if len(s.queue) > 0 {
|
||||
data := s.queue[0]
|
||||
s.queue = s.queue[1:]
|
||||
s.mx.Unlock()
|
||||
return data, nil
|
||||
}
|
||||
if receiveErr := s.recvErr; receiveErr != nil {
|
||||
s.mx.Unlock()
|
||||
return nil, receiveErr
|
||||
}
|
||||
s.mx.Unlock()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, context.Cause(ctx)
|
||||
case <-s.hasData:
|
||||
}
|
||||
goto start
|
||||
}
|
||||
|
||||
func (s *stateTrackingStream) QUICStream() *quic.Stream {
|
||||
return s.Stream
|
||||
}
|
||||
406
vendor/github.com/quic-go/quic-go/http3/stream.go
generated
vendored
Normal file
406
vendor/github.com/quic-go/quic-go/http3/stream.go
generated
vendored
Normal file
@@ -0,0 +1,406 @@
|
||||
package http3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/http3/qlog"
|
||||
"github.com/quic-go/quic-go/qlogwriter"
|
||||
|
||||
"github.com/quic-go/qpack"
|
||||
)
|
||||
|
||||
type datagramStream interface {
|
||||
io.ReadWriteCloser
|
||||
CancelRead(quic.StreamErrorCode)
|
||||
CancelWrite(quic.StreamErrorCode)
|
||||
StreamID() quic.StreamID
|
||||
Context() context.Context
|
||||
SetDeadline(time.Time) error
|
||||
SetReadDeadline(time.Time) error
|
||||
SetWriteDeadline(time.Time) error
|
||||
SendDatagram(b []byte) error
|
||||
ReceiveDatagram(ctx context.Context) ([]byte, error)
|
||||
|
||||
QUICStream() *quic.Stream
|
||||
}
|
||||
|
||||
// A Stream is an HTTP/3 stream.
|
||||
//
|
||||
// When writing to and reading from the stream, data is framed in HTTP/3 DATA frames.
|
||||
type Stream struct {
|
||||
datagramStream
|
||||
conn *rawConn
|
||||
frameParser *frameParser
|
||||
|
||||
buf []byte // used as a temporary buffer when writing the HTTP/3 frame headers
|
||||
|
||||
bytesRemainingInFrame uint64
|
||||
|
||||
qlogger qlogwriter.Recorder
|
||||
|
||||
parseTrailer func(io.Reader, *headersFrame) error
|
||||
parsedTrailer bool
|
||||
}
|
||||
|
||||
func newStream(
|
||||
str datagramStream,
|
||||
conn *rawConn,
|
||||
trace *httptrace.ClientTrace,
|
||||
parseTrailer func(io.Reader, *headersFrame) error,
|
||||
qlogger qlogwriter.Recorder,
|
||||
) *Stream {
|
||||
return &Stream{
|
||||
datagramStream: str,
|
||||
conn: conn,
|
||||
buf: make([]byte, 16),
|
||||
qlogger: qlogger,
|
||||
parseTrailer: parseTrailer,
|
||||
frameParser: &frameParser{
|
||||
r: &tracingReader{Reader: str, trace: trace},
|
||||
streamID: str.StreamID(),
|
||||
closeConn: conn.CloseWithError,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Stream) Read(b []byte) (int, error) {
|
||||
if s.bytesRemainingInFrame == 0 {
|
||||
parseLoop:
|
||||
for {
|
||||
frame, err := s.frameParser.ParseNext(s.qlogger)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
switch f := frame.(type) {
|
||||
case *dataFrame:
|
||||
if s.parsedTrailer {
|
||||
return 0, errors.New("DATA frame received after trailers")
|
||||
}
|
||||
s.bytesRemainingInFrame = f.Length
|
||||
break parseLoop
|
||||
case *headersFrame:
|
||||
if s.parsedTrailer {
|
||||
maybeQlogInvalidHeadersFrame(s.qlogger, s.StreamID(), f.Length)
|
||||
return 0, errors.New("additional HEADERS frame received after trailers")
|
||||
}
|
||||
s.parsedTrailer = true
|
||||
return 0, s.parseTrailer(s.datagramStream, f)
|
||||
default:
|
||||
s.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "")
|
||||
// parseNextFrame skips over unknown frame types
|
||||
// Therefore, this condition is only entered when we parsed another known frame type.
|
||||
return 0, fmt.Errorf("peer sent an unexpected frame: %T", f)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var n int
|
||||
var err error
|
||||
if s.bytesRemainingInFrame < uint64(len(b)) {
|
||||
n, err = s.datagramStream.Read(b[:s.bytesRemainingInFrame])
|
||||
} else {
|
||||
n, err = s.datagramStream.Read(b)
|
||||
}
|
||||
s.bytesRemainingInFrame -= uint64(n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (s *Stream) hasMoreData() bool {
|
||||
return s.bytesRemainingInFrame > 0
|
||||
}
|
||||
|
||||
func (s *Stream) Write(b []byte) (int, error) {
|
||||
s.buf = s.buf[:0]
|
||||
s.buf = (&dataFrame{Length: uint64(len(b))}).Append(s.buf)
|
||||
if s.qlogger != nil {
|
||||
s.qlogger.RecordEvent(qlog.FrameCreated{
|
||||
StreamID: s.StreamID(),
|
||||
Raw: qlog.RawInfo{
|
||||
Length: len(s.buf) + len(b),
|
||||
PayloadLength: len(b),
|
||||
},
|
||||
Frame: qlog.Frame{Frame: qlog.DataFrame{}},
|
||||
})
|
||||
}
|
||||
if _, err := s.datagramStream.Write(s.buf); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return s.datagramStream.Write(b)
|
||||
}
|
||||
|
||||
func (s *Stream) writeUnframed(b []byte) (int, error) {
|
||||
return s.datagramStream.Write(b)
|
||||
}
|
||||
|
||||
func (s *Stream) StreamID() quic.StreamID {
|
||||
return s.datagramStream.StreamID()
|
||||
}
|
||||
|
||||
func (s *Stream) SendDatagram(b []byte) error {
|
||||
// TODO: reject if datagrams are not negotiated (yet)
|
||||
return s.datagramStream.SendDatagram(b)
|
||||
}
|
||||
|
||||
func (s *Stream) ReceiveDatagram(ctx context.Context) ([]byte, error) {
|
||||
// TODO: reject if datagrams are not negotiated (yet)
|
||||
return s.datagramStream.ReceiveDatagram(ctx)
|
||||
}
|
||||
|
||||
// A RequestStream is a low-level abstraction representing an HTTP/3 request stream.
|
||||
// It decouples sending of the HTTP request from reading the HTTP response, allowing
|
||||
// the application to optimistically use the stream (and, for example, send datagrams)
|
||||
// before receiving the response.
|
||||
//
|
||||
// This is only needed for advanced use case, e.g. WebTransport and the various
|
||||
// MASQUE proxying protocols.
|
||||
type RequestStream struct {
|
||||
str *Stream
|
||||
|
||||
responseBody io.ReadCloser // set by ReadResponse
|
||||
|
||||
decoder *qpack.Decoder
|
||||
requestWriter *requestWriter
|
||||
maxHeaderBytes int
|
||||
reqDone chan<- struct{}
|
||||
disableCompression bool
|
||||
response *http.Response
|
||||
|
||||
sentRequest bool
|
||||
requestedGzip bool
|
||||
isConnect bool
|
||||
}
|
||||
|
||||
func newRequestStream(
|
||||
str *Stream,
|
||||
requestWriter *requestWriter,
|
||||
reqDone chan<- struct{},
|
||||
decoder *qpack.Decoder,
|
||||
disableCompression bool,
|
||||
maxHeaderBytes int,
|
||||
rsp *http.Response,
|
||||
) *RequestStream {
|
||||
return &RequestStream{
|
||||
str: str,
|
||||
requestWriter: requestWriter,
|
||||
reqDone: reqDone,
|
||||
decoder: decoder,
|
||||
disableCompression: disableCompression,
|
||||
maxHeaderBytes: maxHeaderBytes,
|
||||
response: rsp,
|
||||
}
|
||||
}
|
||||
|
||||
// Read reads data from the underlying stream.
|
||||
//
|
||||
// It can only be used after the request has been sent (using SendRequestHeader)
|
||||
// and the response has been consumed (using ReadResponse).
|
||||
func (s *RequestStream) Read(b []byte) (int, error) {
|
||||
if s.responseBody == nil {
|
||||
return 0, errors.New("http3: invalid use of RequestStream.Read before ReadResponse")
|
||||
}
|
||||
return s.responseBody.Read(b)
|
||||
}
|
||||
|
||||
// StreamID returns the QUIC stream ID of the underlying QUIC stream.
|
||||
func (s *RequestStream) StreamID() quic.StreamID {
|
||||
return s.str.StreamID()
|
||||
}
|
||||
|
||||
// Write writes data to the stream.
|
||||
//
|
||||
// It can only be used after the request has been sent (using SendRequestHeader).
|
||||
func (s *RequestStream) Write(b []byte) (int, error) {
|
||||
if !s.sentRequest {
|
||||
return 0, errors.New("http3: invalid use of RequestStream.Write before SendRequestHeader")
|
||||
}
|
||||
return s.str.Write(b)
|
||||
}
|
||||
|
||||
// Close closes the send-direction of the stream.
|
||||
// It does not close the receive-direction of the stream.
|
||||
func (s *RequestStream) Close() error {
|
||||
return s.str.Close()
|
||||
}
|
||||
|
||||
// CancelRead aborts receiving on this stream.
|
||||
// See [quic.Stream.CancelRead] for more details.
|
||||
func (s *RequestStream) CancelRead(errorCode quic.StreamErrorCode) {
|
||||
s.str.CancelRead(errorCode)
|
||||
}
|
||||
|
||||
// CancelWrite aborts sending on this stream.
|
||||
// See [quic.Stream.CancelWrite] for more details.
|
||||
func (s *RequestStream) CancelWrite(errorCode quic.StreamErrorCode) {
|
||||
s.str.CancelWrite(errorCode)
|
||||
}
|
||||
|
||||
// Context returns a context derived from the underlying QUIC stream's context.
|
||||
// See [quic.Stream.Context] for more details.
|
||||
func (s *RequestStream) Context() context.Context {
|
||||
return s.str.Context()
|
||||
}
|
||||
|
||||
// SetReadDeadline sets the deadline for Read calls.
|
||||
func (s *RequestStream) SetReadDeadline(t time.Time) error {
|
||||
return s.str.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
// SetWriteDeadline sets the deadline for Write calls.
|
||||
func (s *RequestStream) SetWriteDeadline(t time.Time) error {
|
||||
return s.str.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
// SetDeadline sets the read and write deadlines associated with the stream.
|
||||
// It is equivalent to calling both SetReadDeadline and SetWriteDeadline.
|
||||
func (s *RequestStream) SetDeadline(t time.Time) error {
|
||||
return s.str.SetDeadline(t)
|
||||
}
|
||||
|
||||
// SendDatagrams send a new HTTP Datagram (RFC 9297).
|
||||
//
|
||||
// It is only possible to send datagrams if the server enabled support for this extension.
|
||||
// It is recommended (though not required) to send the request before calling this method,
|
||||
// as the server might drop datagrams which it can't associate with an existing request.
|
||||
func (s *RequestStream) SendDatagram(b []byte) error {
|
||||
return s.str.SendDatagram(b)
|
||||
}
|
||||
|
||||
// ReceiveDatagram receives HTTP Datagrams (RFC 9297).
|
||||
//
|
||||
// It is only possible if support for HTTP Datagrams was enabled, using the EnableDatagram
|
||||
// option on the [Transport].
|
||||
func (s *RequestStream) ReceiveDatagram(ctx context.Context) ([]byte, error) {
|
||||
return s.str.ReceiveDatagram(ctx)
|
||||
}
|
||||
|
||||
// SendRequestHeader sends the HTTP request.
|
||||
//
|
||||
// It can only used for requests that don't have a request body.
|
||||
// It is invalid to call it more than once.
|
||||
// It is invalid to call it after Write has been called.
|
||||
func (s *RequestStream) SendRequestHeader(req *http.Request) error {
|
||||
if req.Body != nil && req.Body != http.NoBody {
|
||||
return errors.New("http3: invalid use of RequestStream.SendRequestHeader with a request that has a request body")
|
||||
}
|
||||
return s.sendRequestHeader(req)
|
||||
}
|
||||
|
||||
func (s *RequestStream) sendRequestHeader(req *http.Request) error {
|
||||
if s.sentRequest {
|
||||
return errors.New("http3: invalid duplicate use of RequestStream.SendRequestHeader")
|
||||
}
|
||||
if !s.disableCompression && req.Method != http.MethodHead &&
|
||||
req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" {
|
||||
s.requestedGzip = true
|
||||
}
|
||||
s.isConnect = req.Method == http.MethodConnect
|
||||
s.sentRequest = true
|
||||
return s.requestWriter.WriteRequestHeader(s.str.datagramStream, req, s.requestedGzip, s.str.StreamID(), s.str.qlogger)
|
||||
}
|
||||
|
||||
// sendRequestTrailer sends request trailers to the stream.
|
||||
// It should be called after the request body has been fully written.
|
||||
func (s *RequestStream) sendRequestTrailer(req *http.Request) error {
|
||||
return s.requestWriter.WriteRequestTrailer(s.str.datagramStream, req, s.str.StreamID(), s.str.qlogger)
|
||||
}
|
||||
|
||||
// ReadResponse reads the HTTP response from the stream.
|
||||
//
|
||||
// It must be called after sending the request (using SendRequestHeader).
|
||||
// It is invalid to call it more than once.
|
||||
// It doesn't set Response.Request and Response.TLS.
|
||||
// It is invalid to call it after Read has been called.
|
||||
func (s *RequestStream) ReadResponse() (*http.Response, error) {
|
||||
if !s.sentRequest {
|
||||
return nil, errors.New("http3: invalid use of RequestStream.ReadResponse before SendRequestHeader")
|
||||
}
|
||||
frame, err := s.str.frameParser.ParseNext(s.str.qlogger)
|
||||
if err != nil {
|
||||
s.str.CancelRead(quic.StreamErrorCode(ErrCodeFrameError))
|
||||
s.str.CancelWrite(quic.StreamErrorCode(ErrCodeFrameError))
|
||||
return nil, fmt.Errorf("http3: parsing frame failed: %w", err)
|
||||
}
|
||||
hf, ok := frame.(*headersFrame)
|
||||
if !ok {
|
||||
s.str.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "expected first frame to be a HEADERS frame")
|
||||
return nil, errors.New("http3: expected first frame to be a HEADERS frame")
|
||||
}
|
||||
if hf.Length > uint64(s.maxHeaderBytes) {
|
||||
maybeQlogInvalidHeadersFrame(s.str.qlogger, s.str.StreamID(), hf.Length)
|
||||
s.str.CancelRead(quic.StreamErrorCode(ErrCodeFrameError))
|
||||
s.str.CancelWrite(quic.StreamErrorCode(ErrCodeFrameError))
|
||||
return nil, fmt.Errorf("http3: HEADERS frame too large: %d bytes (max: %d)", hf.Length, s.maxHeaderBytes)
|
||||
}
|
||||
headerBlock := make([]byte, hf.Length)
|
||||
if _, err := io.ReadFull(s.str.datagramStream, headerBlock); err != nil {
|
||||
maybeQlogInvalidHeadersFrame(s.str.qlogger, s.str.StreamID(), hf.Length)
|
||||
s.str.CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete))
|
||||
s.str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete))
|
||||
return nil, fmt.Errorf("http3: failed to read response headers: %w", err)
|
||||
}
|
||||
decodeFn := s.decoder.Decode(headerBlock)
|
||||
var hfs []qpack.HeaderField
|
||||
if s.str.qlogger != nil {
|
||||
hfs = make([]qpack.HeaderField, 0, 16)
|
||||
}
|
||||
res := s.response
|
||||
err = updateResponseFromHeaders(res, decodeFn, s.maxHeaderBytes, &hfs)
|
||||
if s.str.qlogger != nil {
|
||||
qlogParsedHeadersFrame(s.str.qlogger, s.str.StreamID(), hf, hfs)
|
||||
}
|
||||
if err != nil {
|
||||
errCode := ErrCodeMessageError
|
||||
var qpackErr *qpackError
|
||||
if errors.As(err, &qpackErr) {
|
||||
errCode = ErrCodeQPACKDecompressionFailed
|
||||
}
|
||||
s.str.CancelRead(quic.StreamErrorCode(errCode))
|
||||
s.str.CancelWrite(quic.StreamErrorCode(errCode))
|
||||
return nil, fmt.Errorf("http3: invalid response: %w", err)
|
||||
}
|
||||
|
||||
// Check that the server doesn't send more data in DATA frames than indicated by the Content-Length header (if set).
|
||||
// See section 4.1.2 of RFC 9114.
|
||||
respBody := newResponseBody(s.str, res.ContentLength, s.reqDone)
|
||||
|
||||
// Rules for when to set Content-Length are defined in https://tools.ietf.org/html/rfc7230#section-3.3.2.
|
||||
isInformational := res.StatusCode >= 100 && res.StatusCode < 200
|
||||
isNoContent := res.StatusCode == http.StatusNoContent
|
||||
isSuccessfulConnect := s.isConnect && res.StatusCode >= 200 && res.StatusCode < 300
|
||||
if (isInformational || isNoContent || isSuccessfulConnect) && res.ContentLength == -1 {
|
||||
res.ContentLength = 0
|
||||
}
|
||||
if s.requestedGzip && res.Header.Get("Content-Encoding") == "gzip" {
|
||||
res.Header.Del("Content-Encoding")
|
||||
res.Header.Del("Content-Length")
|
||||
res.ContentLength = -1
|
||||
s.responseBody = newGzipReader(respBody)
|
||||
res.Uncompressed = true
|
||||
} else {
|
||||
s.responseBody = respBody
|
||||
}
|
||||
res.Body = s.responseBody
|
||||
return res, nil
|
||||
}
|
||||
|
||||
type tracingReader struct {
|
||||
io.Reader
|
||||
readFirst bool
|
||||
trace *httptrace.ClientTrace
|
||||
}
|
||||
|
||||
func (r *tracingReader) Read(b []byte) (int, error) {
|
||||
n, err := r.Reader.Read(b)
|
||||
if n > 0 && !r.readFirst {
|
||||
traceGotFirstResponseByte(r.trace)
|
||||
r.readFirst = true
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
105
vendor/github.com/quic-go/quic-go/http3/trace.go
generated
vendored
Normal file
105
vendor/github.com/quic-go/quic-go/http3/trace.go
generated
vendored
Normal file
@@ -0,0 +1,105 @@
|
||||
package http3
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http/httptrace"
|
||||
"net/textproto"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
)
|
||||
|
||||
func traceGetConn(trace *httptrace.ClientTrace, hostPort string) {
|
||||
if trace != nil && trace.GetConn != nil {
|
||||
trace.GetConn(hostPort)
|
||||
}
|
||||
}
|
||||
|
||||
// fakeConn is a wrapper for quic.EarlyConnection
|
||||
// because the quic connection does not implement net.Conn.
|
||||
type fakeConn struct {
|
||||
conn *quic.Conn
|
||||
}
|
||||
|
||||
func (c *fakeConn) Close() error { panic("connection operation prohibited") }
|
||||
func (c *fakeConn) Read(p []byte) (int, error) { panic("connection operation prohibited") }
|
||||
func (c *fakeConn) Write(p []byte) (int, error) { panic("connection operation prohibited") }
|
||||
func (c *fakeConn) SetDeadline(t time.Time) error { panic("connection operation prohibited") }
|
||||
func (c *fakeConn) SetReadDeadline(t time.Time) error { panic("connection operation prohibited") }
|
||||
func (c *fakeConn) SetWriteDeadline(t time.Time) error { panic("connection operation prohibited") }
|
||||
func (c *fakeConn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() }
|
||||
func (c *fakeConn) LocalAddr() net.Addr { return c.conn.LocalAddr() }
|
||||
|
||||
func traceGotConn(trace *httptrace.ClientTrace, conn *quic.Conn, reused bool) {
|
||||
if trace != nil && trace.GotConn != nil {
|
||||
trace.GotConn(httptrace.GotConnInfo{
|
||||
Conn: &fakeConn{conn: conn},
|
||||
Reused: reused,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func traceGotFirstResponseByte(trace *httptrace.ClientTrace) {
|
||||
if trace != nil && trace.GotFirstResponseByte != nil {
|
||||
trace.GotFirstResponseByte()
|
||||
}
|
||||
}
|
||||
|
||||
func traceGot1xxResponse(trace *httptrace.ClientTrace, code int, header textproto.MIMEHeader) {
|
||||
if trace != nil && trace.Got1xxResponse != nil {
|
||||
trace.Got1xxResponse(code, header)
|
||||
}
|
||||
}
|
||||
|
||||
func traceGot100Continue(trace *httptrace.ClientTrace) {
|
||||
if trace != nil && trace.Got100Continue != nil {
|
||||
trace.Got100Continue()
|
||||
}
|
||||
}
|
||||
|
||||
func traceHasWroteHeaderField(trace *httptrace.ClientTrace) bool {
|
||||
return trace != nil && trace.WroteHeaderField != nil
|
||||
}
|
||||
|
||||
func traceWroteHeaderField(trace *httptrace.ClientTrace, k, v string) {
|
||||
if trace != nil && trace.WroteHeaderField != nil {
|
||||
trace.WroteHeaderField(k, []string{v})
|
||||
}
|
||||
}
|
||||
|
||||
func traceWroteHeaders(trace *httptrace.ClientTrace) {
|
||||
if trace != nil && trace.WroteHeaders != nil {
|
||||
trace.WroteHeaders()
|
||||
}
|
||||
}
|
||||
|
||||
func traceWroteRequest(trace *httptrace.ClientTrace, err error) {
|
||||
if trace != nil && trace.WroteRequest != nil {
|
||||
trace.WroteRequest(httptrace.WroteRequestInfo{Err: err})
|
||||
}
|
||||
}
|
||||
|
||||
func traceConnectStart(trace *httptrace.ClientTrace, network, addr string) {
|
||||
if trace != nil && trace.ConnectStart != nil {
|
||||
trace.ConnectStart(network, addr)
|
||||
}
|
||||
}
|
||||
|
||||
func traceConnectDone(trace *httptrace.ClientTrace, network, addr string, err error) {
|
||||
if trace != nil && trace.ConnectDone != nil {
|
||||
trace.ConnectDone(network, addr, err)
|
||||
}
|
||||
}
|
||||
|
||||
func traceTLSHandshakeStart(trace *httptrace.ClientTrace) {
|
||||
if trace != nil && trace.TLSHandshakeStart != nil {
|
||||
trace.TLSHandshakeStart()
|
||||
}
|
||||
}
|
||||
|
||||
func traceTLSHandshakeDone(trace *httptrace.ClientTrace, state tls.ConnectionState, err error) {
|
||||
if trace != nil && trace.TLSHandshakeDone != nil {
|
||||
trace.TLSHandshakeDone(state, err)
|
||||
}
|
||||
}
|
||||
538
vendor/github.com/quic-go/quic-go/http3/transport.go
generated
vendored
Normal file
538
vendor/github.com/quic-go/quic-go/http3/transport.go
generated
vendored
Normal file
@@ -0,0 +1,538 @@
|
||||
package http3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"golang.org/x/net/http/httpguts"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
)
|
||||
|
||||
// Settings are HTTP/3 settings that apply to the underlying connection.
|
||||
type Settings struct {
|
||||
// Support for HTTP/3 datagrams (RFC 9297)
|
||||
EnableDatagrams bool
|
||||
// Extended CONNECT, RFC 9220
|
||||
EnableExtendedConnect bool
|
||||
// Other settings, defined by the application
|
||||
Other map[uint64]uint64
|
||||
}
|
||||
|
||||
// RoundTripOpt are options for the Transport.RoundTripOpt method.
|
||||
type RoundTripOpt struct {
|
||||
// OnlyCachedConn controls whether the Transport may create a new QUIC connection.
|
||||
// If set true and no cached connection is available, RoundTripOpt will return ErrNoCachedConn.
|
||||
OnlyCachedConn bool
|
||||
}
|
||||
|
||||
type clientConn interface {
|
||||
OpenRequestStream(context.Context) (*RequestStream, error)
|
||||
RoundTrip(*http.Request) (*http.Response, error)
|
||||
handleUnidirectionalStream(*quic.ReceiveStream)
|
||||
}
|
||||
|
||||
type roundTripperWithCount struct {
|
||||
cancel context.CancelFunc
|
||||
dialing chan struct{} // closed as soon as quic.Dial(Early) returned
|
||||
dialErr error
|
||||
conn *quic.Conn
|
||||
clientConn clientConn
|
||||
|
||||
useCount atomic.Int64
|
||||
}
|
||||
|
||||
func (r *roundTripperWithCount) Close() error {
|
||||
r.cancel()
|
||||
<-r.dialing
|
||||
if r.conn != nil {
|
||||
return r.conn.CloseWithError(0, "")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Transport implements the http.RoundTripper interface
|
||||
type Transport struct {
|
||||
// TLSClientConfig specifies the TLS configuration to use with
|
||||
// tls.Client. If nil, the default configuration is used.
|
||||
TLSClientConfig *tls.Config
|
||||
|
||||
// QUICConfig is the quic.Config used for dialing new connections.
|
||||
// If nil, reasonable default values will be used.
|
||||
QUICConfig *quic.Config
|
||||
|
||||
// Dial specifies an optional dial function for creating QUIC
|
||||
// connections for requests.
|
||||
// If Dial is nil, a UDPConn will be created at the first request
|
||||
// and will be reused for subsequent connections to other servers.
|
||||
Dial func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (*quic.Conn, error)
|
||||
|
||||
// Enable support for HTTP/3 datagrams (RFC 9297).
|
||||
// If a QUICConfig is set, datagram support also needs to be enabled on the QUIC layer by setting EnableDatagrams.
|
||||
EnableDatagrams bool
|
||||
|
||||
// Additional HTTP/3 settings.
|
||||
// It is invalid to specify any settings defined by RFC 9114 (HTTP/3) and RFC 9297 (HTTP Datagrams).
|
||||
AdditionalSettings map[uint64]uint64
|
||||
|
||||
// MaxResponseHeaderBytes specifies a limit on how many response bytes are
|
||||
// allowed in the server's response header.
|
||||
// Zero means to use a default limit.
|
||||
MaxResponseHeaderBytes int
|
||||
|
||||
// DisableCompression, if true, prevents the Transport from requesting compression with an
|
||||
// "Accept-Encoding: gzip" request header when the Request contains no existing Accept-Encoding value.
|
||||
// If the Transport requests gzip on its own and gets a gzipped response, it's transparently
|
||||
// decoded in the Response.Body.
|
||||
// However, if the user explicitly requested gzip it is not automatically uncompressed.
|
||||
DisableCompression bool
|
||||
|
||||
Logger *slog.Logger
|
||||
|
||||
mutex sync.Mutex
|
||||
|
||||
initOnce sync.Once
|
||||
initErr error
|
||||
|
||||
newClientConn func(*quic.Conn) clientConn
|
||||
|
||||
clients map[string]*roundTripperWithCount
|
||||
transport *quic.Transport
|
||||
closed bool
|
||||
}
|
||||
|
||||
var (
|
||||
_ http.RoundTripper = &Transport{}
|
||||
_ io.Closer = &Transport{}
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrNoCachedConn is returned when Transport.OnlyCachedConn is set
|
||||
ErrNoCachedConn = errors.New("http3: no cached connection was available")
|
||||
// ErrTransportClosed is returned when attempting to use a closed Transport
|
||||
ErrTransportClosed = errors.New("http3: transport is closed")
|
||||
)
|
||||
|
||||
func (t *Transport) init() error {
|
||||
if t.newClientConn == nil {
|
||||
t.newClientConn = func(conn *quic.Conn) clientConn {
|
||||
return newClientConn(
|
||||
conn,
|
||||
t.EnableDatagrams,
|
||||
t.AdditionalSettings,
|
||||
t.MaxResponseHeaderBytes,
|
||||
t.DisableCompression,
|
||||
t.Logger,
|
||||
)
|
||||
}
|
||||
}
|
||||
if t.QUICConfig == nil {
|
||||
t.QUICConfig = defaultQuicConfig.Clone()
|
||||
t.QUICConfig.EnableDatagrams = t.EnableDatagrams
|
||||
}
|
||||
if t.EnableDatagrams && !t.QUICConfig.EnableDatagrams {
|
||||
return errors.New("HTTP Datagrams enabled, but QUIC Datagrams disabled")
|
||||
}
|
||||
if len(t.QUICConfig.Versions) == 0 {
|
||||
t.QUICConfig = t.QUICConfig.Clone()
|
||||
t.QUICConfig.Versions = []quic.Version{quic.SupportedVersions()[0]}
|
||||
}
|
||||
if len(t.QUICConfig.Versions) != 1 {
|
||||
return errors.New("can only use a single QUIC version for dialing a HTTP/3 connection")
|
||||
}
|
||||
if t.QUICConfig.MaxIncomingStreams == 0 {
|
||||
t.QUICConfig.MaxIncomingStreams = -1 // don't allow any bidirectional streams
|
||||
}
|
||||
if t.Dial == nil {
|
||||
udpConn, err := net.ListenUDP("udp", nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.transport = &quic.Transport{Conn: udpConn}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RoundTripOpt is like RoundTrip, but takes options.
|
||||
func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
|
||||
rsp, err := t.roundTripOpt(req, opt)
|
||||
if err != nil {
|
||||
if req.Body != nil {
|
||||
req.Body.Close()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return rsp, nil
|
||||
}
|
||||
|
||||
func (t *Transport) roundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
|
||||
t.initOnce.Do(func() { t.initErr = t.init() })
|
||||
if t.initErr != nil {
|
||||
return nil, t.initErr
|
||||
}
|
||||
|
||||
if req.URL == nil {
|
||||
return nil, errors.New("http3: nil Request.URL")
|
||||
}
|
||||
if req.URL.Scheme != "https" {
|
||||
return nil, fmt.Errorf("http3: unsupported protocol scheme: %s", req.URL.Scheme)
|
||||
}
|
||||
if req.URL.Host == "" {
|
||||
return nil, errors.New("http3: no Host in request URL")
|
||||
}
|
||||
if req.Header == nil {
|
||||
return nil, errors.New("http3: nil Request.Header")
|
||||
}
|
||||
if req.Method != "" && !validMethod(req.Method) {
|
||||
return nil, fmt.Errorf("http3: invalid method %q", req.Method)
|
||||
}
|
||||
for k, vv := range req.Header {
|
||||
if !httpguts.ValidHeaderFieldName(k) {
|
||||
return nil, fmt.Errorf("http3: invalid http header field name %q", k)
|
||||
}
|
||||
for _, v := range vv {
|
||||
if !httpguts.ValidHeaderFieldValue(v) {
|
||||
return nil, fmt.Errorf("http3: invalid http header field value %q for key %v", v, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return t.doRoundTripOpt(req, opt, false)
|
||||
}
|
||||
|
||||
func (t *Transport) doRoundTripOpt(req *http.Request, opt RoundTripOpt, isRetried bool) (*http.Response, error) {
|
||||
hostname := authorityAddr(hostnameFromURL(req.URL))
|
||||
trace := httptrace.ContextClientTrace(req.Context())
|
||||
traceGetConn(trace, hostname)
|
||||
cl, isReused, err := t.getClient(req.Context(), hostname, opt.OnlyCachedConn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
select {
|
||||
case <-cl.dialing:
|
||||
case <-req.Context().Done():
|
||||
return nil, context.Cause(req.Context())
|
||||
}
|
||||
|
||||
if cl.dialErr != nil {
|
||||
t.removeClient(hostname)
|
||||
return nil, cl.dialErr
|
||||
}
|
||||
defer cl.useCount.Add(-1)
|
||||
traceGotConn(trace, cl.conn, isReused)
|
||||
rsp, err := cl.clientConn.RoundTrip(req)
|
||||
if err != nil {
|
||||
// request aborted due to context cancellation
|
||||
select {
|
||||
case <-req.Context().Done():
|
||||
return nil, err
|
||||
default:
|
||||
}
|
||||
if isRetried {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t.removeClient(hostname)
|
||||
req, err = canRetryRequest(err, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return t.doRoundTripOpt(req, opt, true)
|
||||
}
|
||||
return rsp, nil
|
||||
}
|
||||
|
||||
func canRetryRequest(err error, req *http.Request) (*http.Request, error) {
|
||||
// error occurred while opening the stream, we can be sure that the request wasn't sent out
|
||||
var connErr *errConnUnusable
|
||||
if errors.As(err, &connErr) {
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// If the request stream is reset, we can only be sure that the request wasn't processed
|
||||
// if the error code is H3_REQUEST_REJECTED.
|
||||
var e *Error
|
||||
if !errors.As(err, &e) || e.ErrorCode != ErrCodeRequestRejected {
|
||||
return nil, err
|
||||
}
|
||||
// if the body is nil (or http.NoBody), it's safe to reuse this request and its body
|
||||
if req.Body == nil || req.Body == http.NoBody {
|
||||
return req, nil
|
||||
}
|
||||
// if the request body can be reset back to its original state via req.GetBody, do that
|
||||
if req.GetBody != nil {
|
||||
newBody, err := req.GetBody()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
reqCopy := *req
|
||||
reqCopy.Body = newBody
|
||||
req = &reqCopy
|
||||
return &reqCopy, nil
|
||||
}
|
||||
return nil, fmt.Errorf("http3: Transport: cannot retry err [%w] after Request.Body was written; define Request.GetBody to avoid this error", err)
|
||||
}
|
||||
|
||||
// RoundTrip does a round trip.
|
||||
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return t.RoundTripOpt(req, RoundTripOpt{})
|
||||
}
|
||||
|
||||
func (t *Transport) getClient(ctx context.Context, hostname string, onlyCached bool) (rtc *roundTripperWithCount, isReused bool, err error) {
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
if t.closed {
|
||||
return nil, false, ErrTransportClosed
|
||||
}
|
||||
|
||||
if t.clients == nil {
|
||||
t.clients = make(map[string]*roundTripperWithCount)
|
||||
}
|
||||
|
||||
cl, ok := t.clients[hostname]
|
||||
if !ok {
|
||||
if onlyCached {
|
||||
return nil, false, ErrNoCachedConn
|
||||
}
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
cl = &roundTripperWithCount{
|
||||
dialing: make(chan struct{}),
|
||||
cancel: cancel,
|
||||
}
|
||||
go func() {
|
||||
defer close(cl.dialing)
|
||||
defer cancel()
|
||||
conn, rt, err := t.dial(ctx, hostname)
|
||||
if err != nil {
|
||||
cl.dialErr = err
|
||||
return
|
||||
}
|
||||
cl.conn = conn
|
||||
cl.clientConn = rt
|
||||
}()
|
||||
t.clients[hostname] = cl
|
||||
}
|
||||
select {
|
||||
case <-cl.dialing:
|
||||
if cl.dialErr != nil {
|
||||
delete(t.clients, hostname)
|
||||
return nil, false, cl.dialErr
|
||||
}
|
||||
select {
|
||||
case <-cl.conn.HandshakeComplete():
|
||||
isReused = true
|
||||
default:
|
||||
}
|
||||
default:
|
||||
}
|
||||
cl.useCount.Add(1)
|
||||
return cl, isReused, nil
|
||||
}
|
||||
|
||||
func (t *Transport) dial(ctx context.Context, hostname string) (*quic.Conn, clientConn, error) {
|
||||
var tlsConf *tls.Config
|
||||
if t.TLSClientConfig == nil {
|
||||
tlsConf = &tls.Config{}
|
||||
} else {
|
||||
tlsConf = t.TLSClientConfig.Clone()
|
||||
}
|
||||
if tlsConf.ServerName == "" {
|
||||
sni, _, err := net.SplitHostPort(hostname)
|
||||
if err != nil {
|
||||
// It's ok if net.SplitHostPort returns an error - it could be a hostname/IP address without a port.
|
||||
sni = hostname
|
||||
}
|
||||
tlsConf.ServerName = sni
|
||||
}
|
||||
// Replace existing ALPNs by H3
|
||||
tlsConf.NextProtos = []string{NextProtoH3}
|
||||
|
||||
dial := t.Dial
|
||||
if dial == nil {
|
||||
dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (*quic.Conn, error) {
|
||||
network := "udp"
|
||||
udpAddr, err := t.resolveUDPAddr(ctx, network, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
trace := httptrace.ContextClientTrace(ctx)
|
||||
traceConnectStart(trace, network, udpAddr.String())
|
||||
traceTLSHandshakeStart(trace)
|
||||
conn, err := t.transport.DialEarly(ctx, udpAddr, tlsCfg, cfg)
|
||||
var state tls.ConnectionState
|
||||
if conn != nil {
|
||||
state = conn.ConnectionState().TLS
|
||||
}
|
||||
traceTLSHandshakeDone(trace, state, err)
|
||||
traceConnectDone(trace, network, udpAddr.String(), err)
|
||||
return conn, err
|
||||
}
|
||||
}
|
||||
conn, err := dial(ctx, hostname, tlsConf, t.QUICConfig)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
clientConn := t.newClientConn(conn)
|
||||
go func() {
|
||||
for {
|
||||
str, err := conn.AcceptUniStream(context.Background())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go clientConn.handleUnidirectionalStream(str)
|
||||
}
|
||||
}()
|
||||
return conn, clientConn, nil
|
||||
}
|
||||
|
||||
func (t *Transport) resolveUDPAddr(ctx context.Context, network, addr string) (*net.UDPAddr, error) {
|
||||
host, portStr, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
port, err := net.LookupPort(network, portStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resolver := net.DefaultResolver
|
||||
ipAddrs, err := resolver.LookupIPAddr(ctx, host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
addrs := addrList(ipAddrs)
|
||||
ip := addrs.forResolve(network, addr)
|
||||
return &net.UDPAddr{IP: ip.IP, Port: port, Zone: ip.Zone}, nil
|
||||
}
|
||||
|
||||
func (t *Transport) removeClient(hostname string) {
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
if t.clients == nil {
|
||||
return
|
||||
}
|
||||
delete(t.clients, hostname)
|
||||
}
|
||||
|
||||
// NewClientConn creates a new HTTP/3 client connection on top of a QUIC connection.
|
||||
// Most users should use RoundTrip instead of creating a connection directly.
|
||||
// Specifically, it is not needed to perform GET, POST, HEAD and CONNECT requests.
|
||||
//
|
||||
// Obtaining a ClientConn is only needed for more advanced use cases, such as
|
||||
// using Extended CONNECT for WebTransport or the various MASQUE protocols.
|
||||
func (t *Transport) NewClientConn(conn *quic.Conn) *ClientConn {
|
||||
c := newClientConn(
|
||||
conn,
|
||||
t.EnableDatagrams,
|
||||
t.AdditionalSettings,
|
||||
t.MaxResponseHeaderBytes,
|
||||
t.DisableCompression,
|
||||
t.Logger,
|
||||
)
|
||||
go func() {
|
||||
for {
|
||||
str, err := conn.AcceptUniStream(context.Background())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go c.handleUnidirectionalStream(str)
|
||||
}
|
||||
}()
|
||||
return c
|
||||
}
|
||||
|
||||
// NewRawClientConn creates a new low-level HTTP/3 client connection on top of a QUIC connection.
|
||||
// Unlike NewClientConn, the returned RawClientConn allows the application to take control
|
||||
// of the stream accept loops, by calling HandleUnidirectionalStream for incoming unidirectional
|
||||
// streams and HandleBidirectionalStream for incoming bidirectional streams.
|
||||
func (t *Transport) NewRawClientConn(conn *quic.Conn) *RawClientConn {
|
||||
return &RawClientConn{
|
||||
ClientConn: newClientConn(
|
||||
conn,
|
||||
t.EnableDatagrams,
|
||||
t.AdditionalSettings,
|
||||
t.MaxResponseHeaderBytes,
|
||||
t.DisableCompression,
|
||||
t.Logger,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the QUIC connections that this Transport has used.
|
||||
// A Transport cannot be used after it has been closed.
|
||||
func (t *Transport) Close() error {
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
for _, cl := range t.clients {
|
||||
if err := cl.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
t.clients = nil
|
||||
if t.transport != nil {
|
||||
if err := t.transport.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := t.transport.Conn.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
t.transport = nil
|
||||
}
|
||||
t.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func hostnameFromURL(url *url.URL) string {
|
||||
if url != nil {
|
||||
return url.Host
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func validMethod(method string) bool {
|
||||
/*
|
||||
Method = "OPTIONS" ; Section 9.2
|
||||
| "GET" ; Section 9.3
|
||||
| "HEAD" ; Section 9.4
|
||||
| "POST" ; Section 9.5
|
||||
| "PUT" ; Section 9.6
|
||||
| "DELETE" ; Section 9.7
|
||||
| "TRACE" ; Section 9.8
|
||||
| "CONNECT" ; Section 9.9
|
||||
| extension-method
|
||||
extension-method = token
|
||||
token = 1*<any CHAR except CTLs or separators>
|
||||
*/
|
||||
return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1
|
||||
}
|
||||
|
||||
// copied from net/http/http.go
|
||||
func isNotToken(r rune) bool {
|
||||
return !httpguts.IsTokenRune(r)
|
||||
}
|
||||
|
||||
// CloseIdleConnections closes any QUIC connections in the transport's pool that are currently idle.
|
||||
// An idle connection is one that was previously used for requests but is now sitting unused.
|
||||
// This method does not interrupt any connections currently in use.
|
||||
// It also does not affect connections obtained via NewClientConn.
|
||||
func (t *Transport) CloseIdleConnections() {
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
for hostname, cl := range t.clients {
|
||||
if cl.useCount.Load() == 0 {
|
||||
cl.Close()
|
||||
delete(t.clients, hostname)
|
||||
}
|
||||
}
|
||||
}
|
||||
215
vendor/github.com/quic-go/quic-go/interface.go
generated
vendored
Normal file
215
vendor/github.com/quic-go/quic-go/interface.go
generated
vendored
Normal file
@@ -0,0 +1,215 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/handshake"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/qlogwriter"
|
||||
)
|
||||
|
||||
// The StreamID is the ID of a QUIC stream.
|
||||
type StreamID = protocol.StreamID
|
||||
|
||||
// A Version is a QUIC version number.
|
||||
type Version = protocol.Version
|
||||
|
||||
const (
|
||||
// Version1 is RFC 9000
|
||||
Version1 = protocol.Version1
|
||||
// Version2 is RFC 9369
|
||||
Version2 = protocol.Version2
|
||||
)
|
||||
|
||||
// SupportedVersions returns the support versions, sorted in descending order of preference.
|
||||
func SupportedVersions() []Version {
|
||||
// clone the slice to prevent the caller from modifying the slice
|
||||
return slices.Clone(protocol.SupportedVersions)
|
||||
}
|
||||
|
||||
// A ClientToken is a token received by the client.
|
||||
// It can be used to skip address validation on future connection attempts.
|
||||
type ClientToken struct {
|
||||
data []byte
|
||||
rtt time.Duration
|
||||
}
|
||||
|
||||
type TokenStore interface {
|
||||
// Pop searches for a ClientToken associated with the given key.
|
||||
// Since tokens are not supposed to be reused, it must remove the token from the cache.
|
||||
// It returns nil when no token is found.
|
||||
Pop(key string) (token *ClientToken)
|
||||
|
||||
// Put adds a token to the cache with the given key. It might get called
|
||||
// multiple times in a connection.
|
||||
Put(key string, token *ClientToken)
|
||||
}
|
||||
|
||||
// Err0RTTRejected is the returned from:
|
||||
// - Open{Uni}Stream{Sync}
|
||||
// - Accept{Uni}Stream
|
||||
// - Stream.Read and Stream.Write
|
||||
//
|
||||
// when the server rejects a 0-RTT connection attempt.
|
||||
var Err0RTTRejected = errors.New("0-RTT rejected")
|
||||
|
||||
// QUICVersionContextKey can be used to find out the QUIC version of a TLS handshake from the
|
||||
// context returned by tls.Config.ClientInfo.Context.
|
||||
var QUICVersionContextKey = handshake.QUICVersionContextKey
|
||||
|
||||
// StatelessResetKey is a key used to derive stateless reset tokens.
|
||||
type StatelessResetKey [32]byte
|
||||
|
||||
// TokenGeneratorKey is a key used to encrypt session resumption tokens.
|
||||
type TokenGeneratorKey = handshake.TokenProtectorKey
|
||||
|
||||
// A ConnectionID is a QUIC Connection ID, as defined in RFC 9000.
|
||||
// It is not able to handle QUIC Connection IDs longer than 20 bytes,
|
||||
// as they are allowed by RFC 8999.
|
||||
type ConnectionID = protocol.ConnectionID
|
||||
|
||||
// ConnectionIDFromBytes interprets b as a [ConnectionID]. It panics if b is
|
||||
// longer than 20 bytes.
|
||||
func ConnectionIDFromBytes(b []byte) ConnectionID {
|
||||
return protocol.ParseConnectionID(b)
|
||||
}
|
||||
|
||||
// A ConnectionIDGenerator allows the application to take control over the generation of Connection IDs.
|
||||
// Connection IDs generated by an implementation must be of constant length.
|
||||
type ConnectionIDGenerator interface {
|
||||
// GenerateConnectionID generates a new Connection ID.
|
||||
// Generated Connection IDs must be unique and observers should not be able to correlate two Connection IDs.
|
||||
GenerateConnectionID() (ConnectionID, error)
|
||||
|
||||
// ConnectionIDLen returns the length of Connection IDs generated by this implementation.
|
||||
// Implementations must return constant-length Connection IDs with lengths between 0 and 20 bytes.
|
||||
// A length of 0 can only be used when an endpoint doesn't need to multiplex connections during migration.
|
||||
ConnectionIDLen() int
|
||||
}
|
||||
|
||||
// Config contains all configuration data needed for a QUIC server or client.
|
||||
type Config struct {
|
||||
// GetConfigForClient is called for incoming connections.
|
||||
// If the error is not nil, the connection attempt is refused.
|
||||
GetConfigForClient func(info *ClientInfo) (*Config, error)
|
||||
// The QUIC versions that can be negotiated.
|
||||
// If not set, it uses all versions available.
|
||||
Versions []Version
|
||||
// HandshakeIdleTimeout is the idle timeout before completion of the handshake.
|
||||
// If we don't receive any packet from the peer within this time, the connection attempt is aborted.
|
||||
// Additionally, if the handshake doesn't complete in twice this time, the connection attempt is also aborted.
|
||||
// If this value is zero, the timeout is set to 5 seconds.
|
||||
HandshakeIdleTimeout time.Duration
|
||||
// MaxIdleTimeout is the maximum duration that may pass without any incoming network activity.
|
||||
// The actual value for the idle timeout is the minimum of this value and the peer's.
|
||||
// This value only applies after the handshake has completed.
|
||||
// If the timeout is exceeded, the connection is closed.
|
||||
// If this value is zero, the timeout is set to 30 seconds.
|
||||
MaxIdleTimeout time.Duration
|
||||
// The TokenStore stores tokens received from the server.
|
||||
// Tokens are used to skip address validation on future connection attempts.
|
||||
// The key used to store tokens is the ServerName from the tls.Config, if set
|
||||
// otherwise the token is associated with the server's IP address.
|
||||
TokenStore TokenStore
|
||||
// InitialStreamReceiveWindow is the initial size of the stream-level flow control window for receiving data.
|
||||
// If the application is consuming data quickly enough, the flow control auto-tuning algorithm
|
||||
// will increase the window up to MaxStreamReceiveWindow.
|
||||
// If this value is zero, it will default to 512 KB.
|
||||
// Values larger than the maximum varint (quicvarint.Max) will be clipped to that value.
|
||||
InitialStreamReceiveWindow uint64
|
||||
// MaxStreamReceiveWindow is the maximum stream-level flow control window for receiving data.
|
||||
// If this value is zero, it will default to 6 MB.
|
||||
// Values larger than the maximum varint (quicvarint.Max) will be clipped to that value.
|
||||
MaxStreamReceiveWindow uint64
|
||||
// InitialConnectionReceiveWindow is the initial size of the stream-level flow control window for receiving data.
|
||||
// If the application is consuming data quickly enough, the flow control auto-tuning algorithm
|
||||
// will increase the window up to MaxConnectionReceiveWindow.
|
||||
// If this value is zero, it will default to 512 KB.
|
||||
// Values larger than the maximum varint (quicvarint.Max) will be clipped to that value.
|
||||
InitialConnectionReceiveWindow uint64
|
||||
// MaxConnectionReceiveWindow is the connection-level flow control window for receiving data.
|
||||
// If this value is zero, it will default to 15 MB.
|
||||
// Values larger than the maximum varint (quicvarint.Max) will be clipped to that value.
|
||||
MaxConnectionReceiveWindow uint64
|
||||
// AllowConnectionWindowIncrease is called every time the connection flow controller attempts
|
||||
// to increase the connection flow control window.
|
||||
// If set, the caller can prevent an increase of the window. Typically, it would do so to
|
||||
// limit the memory usage.
|
||||
// To avoid deadlocks, it is not valid to call other functions on the connection or on streams
|
||||
// in this callback.
|
||||
AllowConnectionWindowIncrease func(conn *Conn, delta uint64) bool
|
||||
// MaxIncomingStreams is the maximum number of concurrent bidirectional streams that a peer is allowed to open.
|
||||
// If not set, it will default to 100.
|
||||
// If set to a negative value, it doesn't allow any bidirectional streams.
|
||||
// Values larger than 2^60 will be clipped to that value.
|
||||
MaxIncomingStreams int64
|
||||
// MaxIncomingUniStreams is the maximum number of concurrent unidirectional streams that a peer is allowed to open.
|
||||
// If not set, it will default to 100.
|
||||
// If set to a negative value, it doesn't allow any unidirectional streams.
|
||||
// Values larger than 2^60 will be clipped to that value.
|
||||
MaxIncomingUniStreams int64
|
||||
// KeepAlivePeriod defines whether this peer will periodically send a packet to keep the connection alive.
|
||||
// If set to 0, then no keep alive is sent. Otherwise, the keep alive is sent on that period (or at most
|
||||
// every half of MaxIdleTimeout, whichever is smaller).
|
||||
KeepAlivePeriod time.Duration
|
||||
// InitialPacketSize is the initial size (and the lower limit) for packets sent.
|
||||
// Under most circumstances, it is not necessary to manually set this value,
|
||||
// since path MTU discovery quickly finds the path's MTU.
|
||||
// If set too high, the path might not support packets of that size, leading to a timeout of the QUIC handshake.
|
||||
// Values below 1200 are invalid.
|
||||
InitialPacketSize uint16
|
||||
// DisablePathMTUDiscovery disables Path MTU Discovery (RFC 8899).
|
||||
// This allows the sending of QUIC packets that fully utilize the available MTU of the path.
|
||||
// Path MTU discovery is only available on systems that allow setting of the Don't Fragment (DF) bit.
|
||||
DisablePathMTUDiscovery bool
|
||||
// Allow0RTT allows the application to decide if a 0-RTT connection attempt should be accepted.
|
||||
// Only valid for the server.
|
||||
Allow0RTT bool
|
||||
// Enable QUIC datagram support (RFC 9221).
|
||||
EnableDatagrams bool
|
||||
// Enable QUIC Stream Resets with Partial Delivery.
|
||||
// See https://datatracker.ietf.org/doc/html/draft-ietf-quic-reliable-stream-reset-07.
|
||||
EnableStreamResetPartialDelivery bool
|
||||
|
||||
Tracer func(ctx context.Context, isClient bool, connID ConnectionID) qlogwriter.Trace
|
||||
}
|
||||
|
||||
// ClientInfo contains information about an incoming connection attempt.
|
||||
type ClientInfo struct {
|
||||
// RemoteAddr is the remote address on the Initial packet.
|
||||
// Unless AddrVerified is set, the address is not yet verified, and could be a spoofed IP address.
|
||||
RemoteAddr net.Addr
|
||||
// AddrVerified says if the remote address was verified using QUIC's Retry mechanism.
|
||||
// Note that the Retry mechanism costs one network roundtrip,
|
||||
// and is not performed unless Transport.MaxUnvalidatedHandshakes is surpassed.
|
||||
AddrVerified bool
|
||||
}
|
||||
|
||||
// ConnectionState records basic details about a QUIC connection.
|
||||
type ConnectionState struct {
|
||||
// TLS contains information about the TLS connection state, incl. the tls.ConnectionState.
|
||||
TLS tls.ConnectionState
|
||||
// SupportsDatagrams indicates support for QUIC datagrams (RFC 9221).
|
||||
SupportsDatagrams struct {
|
||||
// Remote is true if the peer advertised datagram support.
|
||||
// Local is true if datagram support was enabled via Config.EnableDatagrams.
|
||||
Remote, Local bool
|
||||
}
|
||||
// SupportsStreamResetPartialDelivery indicates support for QUIC Stream Resets with Partial Delivery.
|
||||
SupportsStreamResetPartialDelivery struct {
|
||||
// Remote is true if the peer advertised support.
|
||||
// Local is true if support was enabled via Config.EnableStreamResetPartialDelivery.
|
||||
Remote, Local bool
|
||||
}
|
||||
// Used0RTT says if 0-RTT resumption was used.
|
||||
Used0RTT bool
|
||||
// Version is the QUIC version of the QUIC connection.
|
||||
Version Version
|
||||
// GSO says if generic segmentation offload is used.
|
||||
GSO bool
|
||||
}
|
||||
33
vendor/github.com/quic-go/quic-go/internal/ackhandler/ack_eliciting.go
generated
vendored
Normal file
33
vendor/github.com/quic-go/quic-go/internal/ackhandler/ack_eliciting.go
generated
vendored
Normal file
@@ -0,0 +1,33 @@
|
||||
package ackhandler
|
||||
|
||||
import "github.com/quic-go/quic-go/internal/wire"
|
||||
|
||||
// IsFrameTypeAckEliciting returns true if the frame is ack-eliciting.
|
||||
func IsFrameTypeAckEliciting(t wire.FrameType) bool {
|
||||
//nolint:exhaustive // The default case catches the rest.
|
||||
switch t {
|
||||
case wire.FrameTypeAck, wire.FrameTypeAckECN:
|
||||
return false
|
||||
case wire.FrameTypeConnectionClose, wire.FrameTypeApplicationClose:
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// IsFrameAckEliciting returns true if the frame is ack-eliciting.
|
||||
func IsFrameAckEliciting(f wire.Frame) bool {
|
||||
_, isAck := f.(*wire.AckFrame)
|
||||
_, isConnectionClose := f.(*wire.ConnectionCloseFrame)
|
||||
return !isAck && !isConnectionClose
|
||||
}
|
||||
|
||||
// HasAckElicitingFrames returns true if at least one frame is ack-eliciting.
|
||||
func HasAckElicitingFrames(fs []Frame) bool {
|
||||
for _, f := range fs {
|
||||
if IsFrameAckEliciting(f.Frame) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
340
vendor/github.com/quic-go/quic-go/internal/ackhandler/ecn.go
generated
vendored
Normal file
340
vendor/github.com/quic-go/quic-go/internal/ackhandler/ecn.go
generated
vendored
Normal file
@@ -0,0 +1,340 @@
|
||||
package ackhandler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/qlog"
|
||||
"github.com/quic-go/quic-go/qlogwriter"
|
||||
)
|
||||
|
||||
type ecnState uint8
|
||||
|
||||
const (
|
||||
ecnStateInitial ecnState = iota
|
||||
ecnStateTesting
|
||||
ecnStateUnknown
|
||||
ecnStateCapable
|
||||
ecnStateFailed
|
||||
)
|
||||
|
||||
const (
|
||||
// ecnFailedNoECNCounts is emitted when an ACK acknowledges ECN-marked packets,
|
||||
// but doesn't contain any ECN counts
|
||||
ecnFailedNoECNCounts = "ACK doesn't contain ECN marks"
|
||||
// ecnFailedDecreasedECNCounts is emitted when an ACK frame decreases ECN counts
|
||||
ecnFailedDecreasedECNCounts = "ACK decreases ECN counts"
|
||||
// ecnFailedLostAllTestingPackets is emitted when all ECN testing packets are declared lost
|
||||
ecnFailedLostAllTestingPackets = "all ECN testing packets declared lost"
|
||||
// ecnFailedMoreECNCountsThanSent is emitted when an ACK contains more ECN counts than ECN-marked packets were sent
|
||||
ecnFailedMoreECNCountsThanSent = "ACK contains more ECN counts than ECN-marked packets sent"
|
||||
// ecnFailedTooFewECNCounts is emitted when an ACK contains fewer ECN counts than it acknowledges packets
|
||||
ecnFailedTooFewECNCounts = "ACK contains fewer new ECN counts than acknowledged ECN-marked packets"
|
||||
// ecnFailedManglingDetected is emitted when the path marks all ECN-marked packets as CE
|
||||
ecnFailedManglingDetected = "ECN mangling detected"
|
||||
)
|
||||
|
||||
// must fit into an uint8, otherwise numSentTesting and numLostTesting must have a larger type
|
||||
const numECNTestingPackets = 10
|
||||
|
||||
type ecnHandler interface {
|
||||
SentPacket(protocol.PacketNumber, protocol.ECN)
|
||||
Mode() protocol.ECN
|
||||
HandleNewlyAcked(packets []packetWithPacketNumber, ect0, ect1, ecnce int64) (congested bool)
|
||||
LostPacket(protocol.PacketNumber)
|
||||
}
|
||||
|
||||
// The ecnTracker performs ECN validation of a path.
|
||||
// Once failed, it doesn't do any re-validation of the path.
|
||||
// It is designed only work for 1-RTT packets, it doesn't handle multiple packet number spaces.
|
||||
// In order to avoid revealing any internal state to on-path observers,
|
||||
// callers should make sure to start using ECN (i.e. calling Mode) for the very first 1-RTT packet sent.
|
||||
// The validation logic implemented here strictly follows the algorithm described in RFC 9000 section 13.4.2 and A.4.
|
||||
type ecnTracker struct {
|
||||
state ecnState
|
||||
numSentTesting, numLostTesting uint8
|
||||
|
||||
firstTestingPacket protocol.PacketNumber
|
||||
lastTestingPacket protocol.PacketNumber
|
||||
firstCapablePacket protocol.PacketNumber
|
||||
|
||||
numSentECT0, numSentECT1 int64
|
||||
numAckedECT0, numAckedECT1, numAckedECNCE int64
|
||||
|
||||
qlogger qlogwriter.Recorder
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
var _ ecnHandler = &ecnTracker{}
|
||||
|
||||
func newECNTracker(logger utils.Logger, qlogger qlogwriter.Recorder) *ecnTracker {
|
||||
return &ecnTracker{
|
||||
firstTestingPacket: protocol.InvalidPacketNumber,
|
||||
lastTestingPacket: protocol.InvalidPacketNumber,
|
||||
firstCapablePacket: protocol.InvalidPacketNumber,
|
||||
state: ecnStateInitial,
|
||||
logger: logger,
|
||||
qlogger: qlogger,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *ecnTracker) SentPacket(pn protocol.PacketNumber, ecn protocol.ECN) {
|
||||
//nolint:exhaustive // These are the only ones we need to take care of.
|
||||
switch ecn {
|
||||
case protocol.ECNNon:
|
||||
return
|
||||
case protocol.ECT0:
|
||||
e.numSentECT0++
|
||||
case protocol.ECT1:
|
||||
e.numSentECT1++
|
||||
case protocol.ECNUnsupported:
|
||||
if e.state != ecnStateFailed {
|
||||
panic("didn't expect ECN to be unsupported")
|
||||
}
|
||||
default:
|
||||
panic(fmt.Sprintf("sent packet with unexpected ECN marking: %s", ecn))
|
||||
}
|
||||
|
||||
if e.state == ecnStateCapable && e.firstCapablePacket == protocol.InvalidPacketNumber {
|
||||
e.firstCapablePacket = pn
|
||||
}
|
||||
|
||||
if e.state != ecnStateTesting {
|
||||
return
|
||||
}
|
||||
|
||||
e.numSentTesting++
|
||||
if e.firstTestingPacket == protocol.InvalidPacketNumber {
|
||||
e.firstTestingPacket = pn
|
||||
}
|
||||
if e.numSentECT0+e.numSentECT1 >= numECNTestingPackets {
|
||||
if e.qlogger != nil {
|
||||
e.qlogger.RecordEvent(qlog.ECNStateUpdated{
|
||||
State: qlog.ECNStateUnknown,
|
||||
})
|
||||
}
|
||||
e.state = ecnStateUnknown
|
||||
e.lastTestingPacket = pn
|
||||
}
|
||||
}
|
||||
|
||||
func (e *ecnTracker) Mode() protocol.ECN {
|
||||
switch e.state {
|
||||
case ecnStateInitial:
|
||||
if e.qlogger != nil {
|
||||
e.qlogger.RecordEvent(qlog.ECNStateUpdated{
|
||||
State: qlog.ECNStateTesting,
|
||||
})
|
||||
}
|
||||
e.state = ecnStateTesting
|
||||
return e.Mode()
|
||||
case ecnStateTesting, ecnStateCapable:
|
||||
return protocol.ECT0
|
||||
case ecnStateUnknown, ecnStateFailed:
|
||||
return protocol.ECNNon
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown ECN state: %d", e.state))
|
||||
}
|
||||
}
|
||||
|
||||
func (e *ecnTracker) LostPacket(pn protocol.PacketNumber) {
|
||||
if e.state != ecnStateTesting && e.state != ecnStateUnknown {
|
||||
return
|
||||
}
|
||||
if !e.isTestingPacket(pn) {
|
||||
return
|
||||
}
|
||||
e.numLostTesting++
|
||||
// Only proceed if we have sent all 10 testing packets.
|
||||
if e.state != ecnStateUnknown {
|
||||
return
|
||||
}
|
||||
if e.numLostTesting >= e.numSentTesting {
|
||||
e.logger.Debugf("Disabling ECN. All testing packets were lost.")
|
||||
if e.qlogger != nil {
|
||||
e.qlogger.RecordEvent(qlog.ECNStateUpdated{
|
||||
State: qlog.ECNStateFailed,
|
||||
Trigger: ecnFailedLostAllTestingPackets,
|
||||
})
|
||||
}
|
||||
e.state = ecnStateFailed
|
||||
return
|
||||
}
|
||||
// Path validation also fails if some testing packets are lost, and all other testing packets where CE-marked
|
||||
e.failIfMangled()
|
||||
}
|
||||
|
||||
// HandleNewlyAcked handles the ECN counts on an ACK frame.
|
||||
// It must only be called for ACK frames that increase the largest acknowledged packet number,
|
||||
// see section 13.4.2.1 of RFC 9000.
|
||||
func (e *ecnTracker) HandleNewlyAcked(packets []packetWithPacketNumber, ect0, ect1, ecnce int64) (congested bool) {
|
||||
if e.state == ecnStateFailed {
|
||||
return false
|
||||
}
|
||||
|
||||
// ECN validation can fail if the received total count for either ECT(0) or ECT(1) exceeds
|
||||
// the total number of packets sent with each corresponding ECT codepoint.
|
||||
if ect0 > e.numSentECT0 || ect1 > e.numSentECT1 {
|
||||
e.logger.Debugf("Disabling ECN. Received more ECT(0) / ECT(1) acknowledgements than packets sent.")
|
||||
if e.qlogger != nil {
|
||||
e.qlogger.RecordEvent(qlog.ECNStateUpdated{
|
||||
State: qlog.ECNStateFailed,
|
||||
Trigger: ecnFailedMoreECNCountsThanSent,
|
||||
})
|
||||
}
|
||||
e.state = ecnStateFailed
|
||||
return false
|
||||
}
|
||||
|
||||
// Count ECT0 and ECT1 marks that we used when sending the packets that are now being acknowledged.
|
||||
var ackedECT0, ackedECT1 int64
|
||||
for _, p := range packets {
|
||||
//nolint:exhaustive // We only ever send ECT(0) and ECT(1).
|
||||
switch e.ecnMarking(p.PacketNumber) {
|
||||
case protocol.ECT0:
|
||||
ackedECT0++
|
||||
case protocol.ECT1:
|
||||
ackedECT1++
|
||||
}
|
||||
}
|
||||
|
||||
// If an ACK frame newly acknowledges a packet that the endpoint sent with either the ECT(0) or ECT(1)
|
||||
// codepoint set, ECN validation fails if the corresponding ECN counts are not present in the ACK frame.
|
||||
// This check detects:
|
||||
// * paths that bleach all ECN marks, and
|
||||
// * peers that don't report any ECN counts
|
||||
if (ackedECT0 > 0 || ackedECT1 > 0) && ect0 == 0 && ect1 == 0 && ecnce == 0 {
|
||||
e.logger.Debugf("Disabling ECN. ECN-marked packet acknowledged, but no ECN counts on ACK frame.")
|
||||
if e.qlogger != nil {
|
||||
e.qlogger.RecordEvent(qlog.ECNStateUpdated{
|
||||
State: qlog.ECNStateFailed,
|
||||
Trigger: ecnFailedNoECNCounts,
|
||||
})
|
||||
}
|
||||
e.state = ecnStateFailed
|
||||
return false
|
||||
}
|
||||
|
||||
// Determine the increase in ECT0, ECT1 and ECNCE marks
|
||||
newECT0 := ect0 - e.numAckedECT0
|
||||
newECT1 := ect1 - e.numAckedECT1
|
||||
newECNCE := ecnce - e.numAckedECNCE
|
||||
|
||||
// We're only processing ACKs that increase the Largest Acked.
|
||||
// Therefore, the ECN counters should only ever increase.
|
||||
// Any decrease means that the peer's counting logic is broken.
|
||||
if newECT0 < 0 || newECT1 < 0 || newECNCE < 0 {
|
||||
e.logger.Debugf("Disabling ECN. ECN counts decreased unexpectedly.")
|
||||
if e.qlogger != nil {
|
||||
e.qlogger.RecordEvent(qlog.ECNStateUpdated{
|
||||
State: qlog.ECNStateFailed,
|
||||
Trigger: ecnFailedDecreasedECNCounts,
|
||||
})
|
||||
}
|
||||
e.state = ecnStateFailed
|
||||
return false
|
||||
}
|
||||
|
||||
// ECN validation also fails if the sum of the increase in ECT(0) and ECN-CE counts is less than the number
|
||||
// of newly acknowledged packets that were originally sent with an ECT(0) marking.
|
||||
// This could be the result of (partial) bleaching.
|
||||
if newECT0+newECNCE < ackedECT0 {
|
||||
e.logger.Debugf("Disabling ECN. Received less ECT(0) + ECN-CE than packets sent with ECT(0).")
|
||||
if e.qlogger != nil {
|
||||
e.qlogger.RecordEvent(qlog.ECNStateUpdated{
|
||||
State: qlog.ECNStateFailed,
|
||||
Trigger: ecnFailedTooFewECNCounts,
|
||||
})
|
||||
}
|
||||
e.state = ecnStateFailed
|
||||
return false
|
||||
}
|
||||
// Similarly, ECN validation fails if the sum of the increases to ECT(1) and ECN-CE counts is less than
|
||||
// the number of newly acknowledged packets sent with an ECT(1) marking.
|
||||
if newECT1+newECNCE < ackedECT1 {
|
||||
e.logger.Debugf("Disabling ECN. Received less ECT(1) + ECN-CE than packets sent with ECT(1).")
|
||||
if e.qlogger != nil {
|
||||
e.qlogger.RecordEvent(qlog.ECNStateUpdated{
|
||||
State: qlog.ECNStateFailed,
|
||||
Trigger: ecnFailedTooFewECNCounts,
|
||||
})
|
||||
}
|
||||
e.state = ecnStateFailed
|
||||
return false
|
||||
}
|
||||
|
||||
// update our counters
|
||||
e.numAckedECT0 = ect0
|
||||
e.numAckedECT1 = ect1
|
||||
e.numAckedECNCE = ecnce
|
||||
|
||||
// Detect mangling (a path remarking all ECN-marked testing packets as CE),
|
||||
// once all 10 testing packets have been sent out.
|
||||
if e.state == ecnStateUnknown {
|
||||
e.failIfMangled()
|
||||
if e.state == ecnStateFailed {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if e.state == ecnStateTesting || e.state == ecnStateUnknown {
|
||||
var ackedTestingPacket bool
|
||||
for _, p := range packets {
|
||||
if e.isTestingPacket(p.PacketNumber) {
|
||||
ackedTestingPacket = true
|
||||
break
|
||||
}
|
||||
}
|
||||
// This check won't succeed if the path is mangling ECN-marks (i.e. rewrites all ECN-marked packets to CE).
|
||||
if ackedTestingPacket && (newECT0 > 0 || newECT1 > 0) {
|
||||
e.logger.Debugf("ECN capability confirmed.")
|
||||
if e.qlogger != nil {
|
||||
e.qlogger.RecordEvent(qlog.ECNStateUpdated{
|
||||
State: qlog.ECNStateCapable,
|
||||
})
|
||||
}
|
||||
e.state = ecnStateCapable
|
||||
}
|
||||
}
|
||||
|
||||
// Don't trust CE marks before having confirmed ECN capability of the path.
|
||||
// Otherwise, mangling would be misinterpreted as actual congestion.
|
||||
return e.state == ecnStateCapable && newECNCE > 0
|
||||
}
|
||||
|
||||
// failIfMangled fails ECN validation if all testing packets are lost or CE-marked.
|
||||
func (e *ecnTracker) failIfMangled() {
|
||||
numAckedECNCE := e.numAckedECNCE + int64(e.numLostTesting)
|
||||
if e.numSentECT0+e.numSentECT1 > numAckedECNCE {
|
||||
return
|
||||
}
|
||||
if e.qlogger != nil {
|
||||
e.qlogger.RecordEvent(qlog.ECNStateUpdated{
|
||||
State: qlog.ECNStateFailed,
|
||||
Trigger: ecnFailedManglingDetected,
|
||||
})
|
||||
}
|
||||
e.state = ecnStateFailed
|
||||
}
|
||||
|
||||
func (e *ecnTracker) ecnMarking(pn protocol.PacketNumber) protocol.ECN {
|
||||
if pn < e.firstTestingPacket || e.firstTestingPacket == protocol.InvalidPacketNumber {
|
||||
return protocol.ECNNon
|
||||
}
|
||||
if pn < e.lastTestingPacket || e.lastTestingPacket == protocol.InvalidPacketNumber {
|
||||
return protocol.ECT0
|
||||
}
|
||||
if pn < e.firstCapablePacket || e.firstCapablePacket == protocol.InvalidPacketNumber {
|
||||
return protocol.ECNNon
|
||||
}
|
||||
// We don't need to deal with the case when ECN validation fails,
|
||||
// since we're ignoring any ECN counts reported in ACK frames in that case.
|
||||
return protocol.ECT0
|
||||
}
|
||||
|
||||
func (e *ecnTracker) isTestingPacket(pn protocol.PacketNumber) bool {
|
||||
if e.firstTestingPacket == protocol.InvalidPacketNumber {
|
||||
return false
|
||||
}
|
||||
return pn >= e.firstTestingPacket && (pn <= e.lastTestingPacket || e.lastTestingPacket == protocol.InvalidPacketNumber)
|
||||
}
|
||||
21
vendor/github.com/quic-go/quic-go/internal/ackhandler/frame.go
generated
vendored
Normal file
21
vendor/github.com/quic-go/quic-go/internal/ackhandler/frame.go
generated
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
package ackhandler
|
||||
|
||||
import (
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
// FrameHandler handles the acknowledgement and the loss of a frame.
|
||||
type FrameHandler interface {
|
||||
OnAcked(wire.Frame)
|
||||
OnLost(wire.Frame)
|
||||
}
|
||||
|
||||
type Frame struct {
|
||||
Frame wire.Frame // nil if the frame has already been acknowledged in another packet
|
||||
Handler FrameHandler
|
||||
}
|
||||
|
||||
type StreamFrame struct {
|
||||
Frame *wire.StreamFrame
|
||||
Handler FrameHandler
|
||||
}
|
||||
39
vendor/github.com/quic-go/quic-go/internal/ackhandler/interfaces.go
generated
vendored
Normal file
39
vendor/github.com/quic-go/quic-go/internal/ackhandler/interfaces.go
generated
vendored
Normal file
@@ -0,0 +1,39 @@
|
||||
package ackhandler
|
||||
|
||||
import (
|
||||
"github.com/quic-go/quic-go/internal/monotime"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
// SentPacketHandler handles ACKs received for outgoing packets
|
||||
type SentPacketHandler interface {
|
||||
// SentPacket may modify the packet
|
||||
SentPacket(t monotime.Time, pn, largestAcked protocol.PacketNumber, streamFrames []StreamFrame, frames []Frame, encLevel protocol.EncryptionLevel, ecn protocol.ECN, size protocol.ByteCount, isPathMTUProbePacket, isPathProbePacket bool)
|
||||
// ReceivedAck processes an ACK frame.
|
||||
// It does not store a copy of the frame.
|
||||
ReceivedAck(f *wire.AckFrame, encLevel protocol.EncryptionLevel, rcvTime monotime.Time) (bool /* 1-RTT packet acked */, error)
|
||||
ReceivedPacket(protocol.EncryptionLevel, monotime.Time)
|
||||
ReceivedBytes(_ protocol.ByteCount, rcvTime monotime.Time)
|
||||
DropPackets(_ protocol.EncryptionLevel, rcvTime monotime.Time)
|
||||
ResetForRetry(rcvTime monotime.Time)
|
||||
|
||||
// The SendMode determines if and what kind of packets can be sent.
|
||||
SendMode(now monotime.Time) SendMode
|
||||
// TimeUntilSend is the time when the next packet should be sent.
|
||||
// It is used for pacing packets.
|
||||
TimeUntilSend() monotime.Time
|
||||
SetMaxDatagramSize(count protocol.ByteCount)
|
||||
|
||||
// only to be called once the handshake is complete
|
||||
QueueProbePacket(protocol.EncryptionLevel) bool /* was a packet queued */
|
||||
|
||||
ECNMode(isShortHeaderPacket bool) protocol.ECN // isShortHeaderPacket should only be true for non-coalesced 1-RTT packets
|
||||
PeekPacketNumber(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen)
|
||||
PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber
|
||||
|
||||
GetLossDetectionTimeout() monotime.Time
|
||||
OnLossDetectionTimeout(now monotime.Time) error
|
||||
|
||||
MigratedPath(now monotime.Time, initialMaxPacketSize protocol.ByteCount)
|
||||
}
|
||||
73
vendor/github.com/quic-go/quic-go/internal/ackhandler/lost_packet_tracker.go
generated
vendored
Normal file
73
vendor/github.com/quic-go/quic-go/internal/ackhandler/lost_packet_tracker.go
generated
vendored
Normal file
@@ -0,0 +1,73 @@
|
||||
package ackhandler
|
||||
|
||||
import (
|
||||
"iter"
|
||||
"slices"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/monotime"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
type lostPacket struct {
|
||||
PacketNumber protocol.PacketNumber
|
||||
SendTime monotime.Time
|
||||
}
|
||||
|
||||
type lostPacketTracker struct {
|
||||
maxLength int
|
||||
lostPackets []lostPacket
|
||||
}
|
||||
|
||||
func newLostPacketTracker(maxLength int) *lostPacketTracker {
|
||||
return &lostPacketTracker{
|
||||
maxLength: maxLength,
|
||||
// Preallocate a small slice only.
|
||||
// Hopefully we won't lose many packets.
|
||||
lostPackets: make([]lostPacket, 0, 4),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *lostPacketTracker) Add(p protocol.PacketNumber, sendTime monotime.Time) {
|
||||
if len(t.lostPackets) == t.maxLength {
|
||||
t.lostPackets = t.lostPackets[1:]
|
||||
}
|
||||
t.lostPackets = append(t.lostPackets, lostPacket{
|
||||
PacketNumber: p,
|
||||
SendTime: sendTime,
|
||||
})
|
||||
}
|
||||
|
||||
// Delete deletes a packet from the lost packet tracker.
|
||||
// This function is not optimized for performance if many packets are lost,
|
||||
// but it is only used when a spurious loss is detected, which is rare.
|
||||
func (t *lostPacketTracker) Delete(pn protocol.PacketNumber) {
|
||||
t.lostPackets = slices.DeleteFunc(t.lostPackets, func(p lostPacket) bool {
|
||||
return p.PacketNumber == pn
|
||||
})
|
||||
}
|
||||
|
||||
func (t *lostPacketTracker) All() iter.Seq2[protocol.PacketNumber, monotime.Time] {
|
||||
return func(yield func(protocol.PacketNumber, monotime.Time) bool) {
|
||||
for _, p := range t.lostPackets {
|
||||
if !yield(p.PacketNumber, p.SendTime) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *lostPacketTracker) DeleteBefore(ti monotime.Time) {
|
||||
if len(t.lostPackets) == 0 {
|
||||
return
|
||||
}
|
||||
if !t.lostPackets[0].SendTime.Before(ti) {
|
||||
return
|
||||
}
|
||||
var idx int
|
||||
for ; idx < len(t.lostPackets); idx++ {
|
||||
if !t.lostPackets[idx].SendTime.Before(ti) {
|
||||
break
|
||||
}
|
||||
}
|
||||
t.lostPackets = slices.Delete(t.lostPackets, 0, idx)
|
||||
}
|
||||
6
vendor/github.com/quic-go/quic-go/internal/ackhandler/mockgen.go
generated
vendored
Normal file
6
vendor/github.com/quic-go/quic-go/internal/ackhandler/mockgen.go
generated
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
//go:build gomock || generate
|
||||
|
||||
package ackhandler
|
||||
|
||||
//go:generate sh -c "go tool mockgen -typed -build_flags=\"-tags=gomock\" -package ackhandler -destination mock_ecn_handler_test.go github.com/quic-go/quic-go/internal/ackhandler ECNHandler"
|
||||
type ECNHandler = ecnHandler
|
||||
60
vendor/github.com/quic-go/quic-go/internal/ackhandler/packet.go
generated
vendored
Normal file
60
vendor/github.com/quic-go/quic-go/internal/ackhandler/packet.go
generated
vendored
Normal file
@@ -0,0 +1,60 @@
|
||||
package ackhandler
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/monotime"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
type packetWithPacketNumber struct {
|
||||
PacketNumber protocol.PacketNumber
|
||||
*packet
|
||||
}
|
||||
|
||||
// A Packet is a packet
|
||||
type packet struct {
|
||||
SendTime monotime.Time
|
||||
StreamFrames []StreamFrame
|
||||
Frames []Frame
|
||||
LargestAcked protocol.PacketNumber // InvalidPacketNumber if the packet doesn't contain an ACK
|
||||
Length protocol.ByteCount
|
||||
EncryptionLevel protocol.EncryptionLevel
|
||||
|
||||
IsPathMTUProbePacket bool // We don't report the loss of Path MTU probe packets to the congestion controller.
|
||||
|
||||
includedInBytesInFlight bool
|
||||
isPathProbePacket bool
|
||||
}
|
||||
|
||||
func (p *packet) Outstanding() bool {
|
||||
return !p.IsPathMTUProbePacket && !p.isPathProbePacket && p.IsAckEliciting()
|
||||
}
|
||||
|
||||
func (p *packet) IsAckEliciting() bool {
|
||||
return len(p.StreamFrames) > 0 || len(p.Frames) > 0
|
||||
}
|
||||
|
||||
var packetPool = sync.Pool{New: func() any { return &packet{} }}
|
||||
|
||||
func getPacket() *packet {
|
||||
p := packetPool.Get().(*packet)
|
||||
p.StreamFrames = nil
|
||||
p.Frames = nil
|
||||
p.LargestAcked = 0
|
||||
p.Length = 0
|
||||
p.EncryptionLevel = protocol.EncryptionLevel(0)
|
||||
p.SendTime = 0
|
||||
p.IsPathMTUProbePacket = false
|
||||
p.includedInBytesInFlight = false
|
||||
p.isPathProbePacket = false
|
||||
return p
|
||||
}
|
||||
|
||||
// We currently only return Packets back into the pool when they're acknowledged (not when they're lost).
|
||||
// This simplifies the code, and gives the vast majority of the performance benefit we can gain from using the pool.
|
||||
func putPacket(p *packet) {
|
||||
p.Frames = nil
|
||||
p.StreamFrames = nil
|
||||
packetPool.Put(p)
|
||||
}
|
||||
84
vendor/github.com/quic-go/quic-go/internal/ackhandler/packet_number_generator.go
generated
vendored
Normal file
84
vendor/github.com/quic-go/quic-go/internal/ackhandler/packet_number_generator.go
generated
vendored
Normal file
@@ -0,0 +1,84 @@
|
||||
package ackhandler
|
||||
|
||||
import (
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
type packetNumberGenerator interface {
|
||||
Peek() protocol.PacketNumber
|
||||
// Pop pops the packet number.
|
||||
// It reports if the packet number (before the one just popped) was skipped.
|
||||
// It never skips more than one packet number in a row.
|
||||
Pop() (skipped bool, _ protocol.PacketNumber)
|
||||
}
|
||||
|
||||
type sequentialPacketNumberGenerator struct {
|
||||
next protocol.PacketNumber
|
||||
}
|
||||
|
||||
var _ packetNumberGenerator = &sequentialPacketNumberGenerator{}
|
||||
|
||||
func newSequentialPacketNumberGenerator(initial protocol.PacketNumber) packetNumberGenerator {
|
||||
return &sequentialPacketNumberGenerator{next: initial}
|
||||
}
|
||||
|
||||
func (p *sequentialPacketNumberGenerator) Peek() protocol.PacketNumber {
|
||||
return p.next
|
||||
}
|
||||
|
||||
func (p *sequentialPacketNumberGenerator) Pop() (bool, protocol.PacketNumber) {
|
||||
next := p.next
|
||||
p.next++
|
||||
return false, next
|
||||
}
|
||||
|
||||
// The skippingPacketNumberGenerator generates the packet number for the next packet
|
||||
// it randomly skips a packet number every averagePeriod packets (on average).
|
||||
// It is guaranteed to never skip two consecutive packet numbers.
|
||||
type skippingPacketNumberGenerator struct {
|
||||
period protocol.PacketNumber
|
||||
maxPeriod protocol.PacketNumber
|
||||
|
||||
next protocol.PacketNumber
|
||||
nextToSkip protocol.PacketNumber
|
||||
|
||||
rng utils.Rand
|
||||
}
|
||||
|
||||
var _ packetNumberGenerator = &skippingPacketNumberGenerator{}
|
||||
|
||||
func newSkippingPacketNumberGenerator(initial, initialPeriod, maxPeriod protocol.PacketNumber) packetNumberGenerator {
|
||||
g := &skippingPacketNumberGenerator{
|
||||
next: initial,
|
||||
period: initialPeriod,
|
||||
maxPeriod: maxPeriod,
|
||||
}
|
||||
g.generateNewSkip()
|
||||
return g
|
||||
}
|
||||
|
||||
func (p *skippingPacketNumberGenerator) Peek() protocol.PacketNumber {
|
||||
if p.next == p.nextToSkip {
|
||||
return p.next + 1
|
||||
}
|
||||
return p.next
|
||||
}
|
||||
|
||||
func (p *skippingPacketNumberGenerator) Pop() (bool, protocol.PacketNumber) {
|
||||
next := p.next
|
||||
if p.next == p.nextToSkip {
|
||||
next++
|
||||
p.next += 2
|
||||
p.generateNewSkip()
|
||||
return true, next
|
||||
}
|
||||
p.next++ // generate a new packet number for the next packet
|
||||
return false, next
|
||||
}
|
||||
|
||||
func (p *skippingPacketNumberGenerator) generateNewSkip() {
|
||||
// make sure that there are never two consecutive packet numbers that are skipped
|
||||
p.nextToSkip = p.next + 3 + protocol.PacketNumber(p.rng.Int31n(int32(2*p.period)))
|
||||
p.period = min(2*p.period, p.maxPeriod)
|
||||
}
|
||||
119
vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_handler.go
generated
vendored
Normal file
119
vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_handler.go
generated
vendored
Normal file
@@ -0,0 +1,119 @@
|
||||
package ackhandler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/monotime"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
type ReceivedPacketHandler struct {
|
||||
initialPackets *receivedPacketTracker
|
||||
handshakePackets *receivedPacketTracker
|
||||
appDataPackets appDataReceivedPacketTracker
|
||||
|
||||
lowest1RTTPacket protocol.PacketNumber
|
||||
}
|
||||
|
||||
func NewReceivedPacketHandler(logger utils.Logger) *ReceivedPacketHandler {
|
||||
return &ReceivedPacketHandler{
|
||||
initialPackets: newReceivedPacketTracker(),
|
||||
handshakePackets: newReceivedPacketTracker(),
|
||||
appDataPackets: *newAppDataReceivedPacketTracker(logger),
|
||||
lowest1RTTPacket: protocol.InvalidPacketNumber,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ReceivedPacketHandler) ReceivedPacket(
|
||||
pn protocol.PacketNumber,
|
||||
ecn protocol.ECN,
|
||||
encLevel protocol.EncryptionLevel,
|
||||
rcvTime monotime.Time,
|
||||
ackEliciting bool,
|
||||
) error {
|
||||
switch encLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
return h.initialPackets.ReceivedPacket(pn, ecn, ackEliciting)
|
||||
case protocol.EncryptionHandshake:
|
||||
// The Handshake packet number space might already have been dropped as a result
|
||||
// of processing the CRYPTO frame that was contained in this packet.
|
||||
if h.handshakePackets == nil {
|
||||
return nil
|
||||
}
|
||||
return h.handshakePackets.ReceivedPacket(pn, ecn, ackEliciting)
|
||||
case protocol.Encryption0RTT:
|
||||
if h.lowest1RTTPacket != protocol.InvalidPacketNumber && pn > h.lowest1RTTPacket {
|
||||
return fmt.Errorf("received packet number %d on a 0-RTT packet after receiving %d on a 1-RTT packet", pn, h.lowest1RTTPacket)
|
||||
}
|
||||
return h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, ackEliciting)
|
||||
case protocol.Encryption1RTT:
|
||||
if h.lowest1RTTPacket == protocol.InvalidPacketNumber || pn < h.lowest1RTTPacket {
|
||||
h.lowest1RTTPacket = pn
|
||||
}
|
||||
return h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, ackEliciting)
|
||||
default:
|
||||
panic(fmt.Sprintf("received packet with unknown encryption level: %s", encLevel))
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ReceivedPacketHandler) IgnorePacketsBelow(pn protocol.PacketNumber) {
|
||||
h.appDataPackets.IgnoreBelow(pn)
|
||||
}
|
||||
|
||||
func (h *ReceivedPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) {
|
||||
//nolint:exhaustive // 1-RTT packet number space is never dropped.
|
||||
switch encLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
h.initialPackets = nil
|
||||
case protocol.EncryptionHandshake:
|
||||
h.handshakePackets = nil
|
||||
case protocol.Encryption0RTT:
|
||||
// Nothing to do here.
|
||||
// If we are rejecting 0-RTT, no 0-RTT packets will have been decrypted.
|
||||
default:
|
||||
panic(fmt.Sprintf("Cannot drop keys for encryption level %s", encLevel))
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ReceivedPacketHandler) GetAlarmTimeout() monotime.Time {
|
||||
return h.appDataPackets.GetAlarmTimeout()
|
||||
}
|
||||
|
||||
func (h *ReceivedPacketHandler) GetAckFrame(encLevel protocol.EncryptionLevel, now monotime.Time, onlyIfQueued bool) *wire.AckFrame {
|
||||
//nolint:exhaustive // 0-RTT packets can't contain ACK frames.
|
||||
switch encLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
if h.initialPackets != nil {
|
||||
return h.initialPackets.GetAckFrame()
|
||||
}
|
||||
return nil
|
||||
case protocol.EncryptionHandshake:
|
||||
if h.handshakePackets != nil {
|
||||
return h.handshakePackets.GetAckFrame()
|
||||
}
|
||||
return nil
|
||||
case protocol.Encryption1RTT:
|
||||
return h.appDataPackets.GetAckFrame(now, onlyIfQueued)
|
||||
default:
|
||||
// 0-RTT packets can't contain ACK frames
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ReceivedPacketHandler) IsPotentiallyDuplicate(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) bool {
|
||||
switch encLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
if h.initialPackets != nil {
|
||||
return h.initialPackets.IsPotentiallyDuplicate(pn)
|
||||
}
|
||||
case protocol.EncryptionHandshake:
|
||||
if h.handshakePackets != nil {
|
||||
return h.handshakePackets.IsPotentiallyDuplicate(pn)
|
||||
}
|
||||
case protocol.Encryption0RTT, protocol.Encryption1RTT:
|
||||
return h.appDataPackets.IsPotentiallyDuplicate(pn)
|
||||
}
|
||||
panic("unexpected encryption level")
|
||||
}
|
||||
159
vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_history.go
generated
vendored
Normal file
159
vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_history.go
generated
vendored
Normal file
@@ -0,0 +1,159 @@
|
||||
package ackhandler
|
||||
|
||||
import (
|
||||
"iter"
|
||||
"slices"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// interval is an interval from one PacketNumber to the other
|
||||
type interval struct {
|
||||
Start protocol.PacketNumber
|
||||
End protocol.PacketNumber
|
||||
}
|
||||
|
||||
// The receivedPacketHistory stores if a packet number has already been received.
|
||||
// It generates ACK ranges which can be used to assemble an ACK frame.
|
||||
// It does not store packet contents.
|
||||
type receivedPacketHistory struct {
|
||||
ranges []interval // maximum length: protocol.MaxNumAckRanges
|
||||
|
||||
deletedBelow protocol.PacketNumber
|
||||
}
|
||||
|
||||
func newReceivedPacketHistory() *receivedPacketHistory {
|
||||
return &receivedPacketHistory{
|
||||
deletedBelow: protocol.InvalidPacketNumber,
|
||||
}
|
||||
}
|
||||
|
||||
// ReceivedPacket registers a packet with PacketNumber p and updates the ranges
|
||||
func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) bool /* is a new packet (and not a duplicate / delayed packet) */ {
|
||||
// ignore delayed packets, if we already deleted the range
|
||||
if p < h.deletedBelow {
|
||||
return false
|
||||
}
|
||||
|
||||
isNew := h.addToRanges(p)
|
||||
// Delete old ranges, if we're tracking too many of them.
|
||||
// This is a DoS defense against a peer that sends us too many gaps.
|
||||
if len(h.ranges) > protocol.MaxNumAckRanges {
|
||||
h.ranges = slices.Delete(h.ranges, 0, len(h.ranges)-protocol.MaxNumAckRanges)
|
||||
}
|
||||
return isNew
|
||||
}
|
||||
|
||||
func (h *receivedPacketHistory) addToRanges(p protocol.PacketNumber) bool /* is a new packet (and not a duplicate / delayed packet) */ {
|
||||
if len(h.ranges) == 0 {
|
||||
h.ranges = append(h.ranges, interval{Start: p, End: p})
|
||||
return true
|
||||
}
|
||||
|
||||
for i := len(h.ranges) - 1; i >= 0; i-- {
|
||||
// p already included in an existing range. Nothing to do here
|
||||
if p >= h.ranges[i].Start && p <= h.ranges[i].End {
|
||||
return false
|
||||
}
|
||||
|
||||
if h.ranges[i].End == p-1 { // extend a range at the end
|
||||
h.ranges[i].End = p
|
||||
return true
|
||||
}
|
||||
if h.ranges[i].Start == p+1 { // extend a range at the beginning
|
||||
h.ranges[i].Start = p
|
||||
|
||||
if i > 0 && h.ranges[i-1].End+1 == h.ranges[i].Start { // merge two ranges
|
||||
h.ranges[i-1].End = h.ranges[i].End
|
||||
h.ranges = slices.Delete(h.ranges, i, i+1)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// create a new range after the current one
|
||||
if p > h.ranges[i].End {
|
||||
h.ranges = slices.Insert(h.ranges, i+1, interval{Start: p, End: p})
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// create a new range at the beginning
|
||||
h.ranges = slices.Insert(h.ranges, 0, interval{Start: p, End: p})
|
||||
return true
|
||||
}
|
||||
|
||||
// DeleteBelow deletes all entries below (but not including) p
|
||||
func (h *receivedPacketHistory) DeleteBelow(p protocol.PacketNumber) {
|
||||
if p < h.deletedBelow {
|
||||
return
|
||||
}
|
||||
h.deletedBelow = p
|
||||
|
||||
if len(h.ranges) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
idx := -1
|
||||
for i := 0; i < len(h.ranges); i++ {
|
||||
if h.ranges[i].End < p { // delete a whole range
|
||||
idx = i
|
||||
} else if p > h.ranges[i].Start && p <= h.ranges[i].End {
|
||||
h.ranges[i].Start = p
|
||||
break
|
||||
} else { // no ranges affected. Nothing to do
|
||||
break
|
||||
}
|
||||
}
|
||||
if idx >= 0 {
|
||||
h.ranges = slices.Delete(h.ranges, 0, idx+1)
|
||||
}
|
||||
}
|
||||
|
||||
// Backward returns an iterator over the ranges in reverse order
|
||||
func (h *receivedPacketHistory) Backward() iter.Seq[interval] {
|
||||
return func(yield func(interval) bool) {
|
||||
for i := len(h.ranges) - 1; i >= 0; i-- {
|
||||
if !yield(h.ranges[i]) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *receivedPacketHistory) HighestMissingUpTo(p protocol.PacketNumber) protocol.PacketNumber {
|
||||
if len(h.ranges) == 0 || (h.deletedBelow != protocol.InvalidPacketNumber && p < h.deletedBelow) {
|
||||
return protocol.InvalidPacketNumber
|
||||
}
|
||||
p = min(h.ranges[len(h.ranges)-1].End, p)
|
||||
for i := len(h.ranges) - 1; i >= 0; i-- {
|
||||
r := h.ranges[i]
|
||||
if p >= r.Start && p <= r.End { // p is contained in this range
|
||||
highest := r.Start - 1 // highest packet in the gap before this range
|
||||
if h.deletedBelow != protocol.InvalidPacketNumber && highest < h.deletedBelow {
|
||||
return protocol.InvalidPacketNumber
|
||||
}
|
||||
return highest
|
||||
}
|
||||
if i >= 1 && p > h.ranges[i-1].End && p <= r.Start {
|
||||
// p is in the gap between the previous range and this range
|
||||
return p
|
||||
}
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func (h *receivedPacketHistory) IsPotentiallyDuplicate(p protocol.PacketNumber) bool {
|
||||
if p < h.deletedBelow {
|
||||
return true
|
||||
}
|
||||
// Iterating over the slices is faster than using a binary search (using slices.BinarySearchFunc).
|
||||
for i := len(h.ranges) - 1; i >= 0; i-- {
|
||||
if p > h.ranges[i].End {
|
||||
return false
|
||||
}
|
||||
if p <= h.ranges[i].End && p >= h.ranges[i].Start {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
228
vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_tracker.go
generated
vendored
Normal file
228
vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_tracker.go
generated
vendored
Normal file
@@ -0,0 +1,228 @@
|
||||
package ackhandler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/monotime"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
const reorderingThreshold = 1
|
||||
|
||||
// The receivedPacketTracker tracks packets for the Initial and Handshake packet number space.
|
||||
// Every received packet is acknowledged immediately.
|
||||
type receivedPacketTracker struct {
|
||||
ect0, ect1, ecnce uint64
|
||||
|
||||
packetHistory receivedPacketHistory
|
||||
|
||||
lastAck *wire.AckFrame
|
||||
hasNewAck bool // true as soon as we received an ack-eliciting new packet
|
||||
}
|
||||
|
||||
func newReceivedPacketTracker() *receivedPacketTracker {
|
||||
return &receivedPacketTracker{packetHistory: *newReceivedPacketHistory()}
|
||||
}
|
||||
|
||||
func (h *receivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, ackEliciting bool) error {
|
||||
if isNew := h.packetHistory.ReceivedPacket(pn); !isNew {
|
||||
return fmt.Errorf("receivedPacketTracker BUG: ReceivedPacket called for old / duplicate packet %d", pn)
|
||||
}
|
||||
|
||||
//nolint:exhaustive // Only need to count ECT(0), ECT(1) and ECN-CE.
|
||||
switch ecn {
|
||||
case protocol.ECT0:
|
||||
h.ect0++
|
||||
case protocol.ECT1:
|
||||
h.ect1++
|
||||
case protocol.ECNCE:
|
||||
h.ecnce++
|
||||
}
|
||||
if !ackEliciting {
|
||||
return nil
|
||||
}
|
||||
h.hasNewAck = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *receivedPacketTracker) GetAckFrame() *wire.AckFrame {
|
||||
if !h.hasNewAck {
|
||||
return nil
|
||||
}
|
||||
|
||||
// This function always returns the same ACK frame struct, filled with the most recent values.
|
||||
ack := h.lastAck
|
||||
if ack == nil {
|
||||
ack = &wire.AckFrame{}
|
||||
}
|
||||
ack.Reset()
|
||||
ack.ECT0 = h.ect0
|
||||
ack.ECT1 = h.ect1
|
||||
ack.ECNCE = h.ecnce
|
||||
for r := range h.packetHistory.Backward() {
|
||||
ack.AckRanges = append(ack.AckRanges, wire.AckRange{Smallest: r.Start, Largest: r.End})
|
||||
}
|
||||
|
||||
h.lastAck = ack
|
||||
h.hasNewAck = false
|
||||
return ack
|
||||
}
|
||||
|
||||
func (h *receivedPacketTracker) IsPotentiallyDuplicate(pn protocol.PacketNumber) bool {
|
||||
return h.packetHistory.IsPotentiallyDuplicate(pn)
|
||||
}
|
||||
|
||||
// number of ack-eliciting packets received before sending an ACK
|
||||
const packetsBeforeAck = 2
|
||||
|
||||
// The appDataReceivedPacketTracker tracks packets received in the Application Data packet number space.
|
||||
// It waits until at least 2 packets were received before queueing an ACK, or until the max_ack_delay was reached.
|
||||
type appDataReceivedPacketTracker struct {
|
||||
receivedPacketTracker
|
||||
|
||||
largestObservedRcvdTime monotime.Time
|
||||
|
||||
largestObserved protocol.PacketNumber
|
||||
ignoreBelow protocol.PacketNumber
|
||||
|
||||
maxAckDelay time.Duration
|
||||
ackQueued bool // true if we need send a new ACK
|
||||
|
||||
ackElicitingPacketsReceivedSinceLastAck int
|
||||
ackAlarm monotime.Time
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
func newAppDataReceivedPacketTracker(logger utils.Logger) *appDataReceivedPacketTracker {
|
||||
h := &appDataReceivedPacketTracker{
|
||||
receivedPacketTracker: *newReceivedPacketTracker(),
|
||||
maxAckDelay: protocol.MaxAckDelay,
|
||||
logger: logger,
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *appDataReceivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, rcvTime monotime.Time, ackEliciting bool) error {
|
||||
if err := h.receivedPacketTracker.ReceivedPacket(pn, ecn, ackEliciting); err != nil {
|
||||
return err
|
||||
}
|
||||
if pn >= h.largestObserved {
|
||||
h.largestObserved = pn
|
||||
h.largestObservedRcvdTime = rcvTime
|
||||
}
|
||||
if !ackEliciting {
|
||||
return nil
|
||||
}
|
||||
h.ackElicitingPacketsReceivedSinceLastAck++
|
||||
isMissing := h.isMissing(pn)
|
||||
if !h.ackQueued && h.shouldQueueACK(pn, ecn, isMissing) {
|
||||
h.ackQueued = true
|
||||
h.ackAlarm = 0 // cancel the ack alarm
|
||||
}
|
||||
if !h.ackQueued {
|
||||
// No ACK queued, but we'll need to acknowledge the packet after max_ack_delay.
|
||||
h.ackAlarm = rcvTime.Add(h.maxAckDelay)
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tSetting ACK timer to max ack delay: %s", h.maxAckDelay)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IgnoreBelow sets a lower limit for acknowledging packets.
|
||||
// Packets with packet numbers smaller than p will not be acked.
|
||||
func (h *appDataReceivedPacketTracker) IgnoreBelow(pn protocol.PacketNumber) {
|
||||
if pn <= h.ignoreBelow {
|
||||
return
|
||||
}
|
||||
h.ignoreBelow = pn
|
||||
h.packetHistory.DeleteBelow(pn)
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tIgnoring all packets below %d.", pn)
|
||||
}
|
||||
}
|
||||
|
||||
// isMissing says if a packet was reported missing in the last ACK.
|
||||
func (h *appDataReceivedPacketTracker) isMissing(p protocol.PacketNumber) bool {
|
||||
if h.lastAck == nil || p < h.ignoreBelow {
|
||||
return false
|
||||
}
|
||||
return p < h.lastAck.LargestAcked() && !h.lastAck.AcksPacket(p)
|
||||
}
|
||||
|
||||
func (h *appDataReceivedPacketTracker) hasNewMissingPackets() bool {
|
||||
if h.lastAck == nil {
|
||||
return false
|
||||
}
|
||||
if h.largestObserved < reorderingThreshold {
|
||||
return false
|
||||
}
|
||||
highestMissing := h.packetHistory.HighestMissingUpTo(h.largestObserved - reorderingThreshold)
|
||||
if highestMissing == protocol.InvalidPacketNumber {
|
||||
return false
|
||||
}
|
||||
if highestMissing < h.lastAck.LargestAcked() {
|
||||
// the packet was already reported missing in the last ACK
|
||||
return false
|
||||
}
|
||||
return highestMissing > h.lastAck.LargestAcked()-reorderingThreshold
|
||||
}
|
||||
|
||||
func (h *appDataReceivedPacketTracker) shouldQueueACK(pn protocol.PacketNumber, ecn protocol.ECN, wasMissing bool) bool {
|
||||
// Send an ACK if this packet was reported missing in an ACK sent before.
|
||||
// Ack decimation with reordering relies on the timer to send an ACK, but if
|
||||
// missing packets we reported in the previous ACK, send an ACK immediately.
|
||||
if wasMissing {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tQueueing ACK because packet %d was missing before.", pn)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// send an ACK every 2 ack-eliciting packets
|
||||
if h.ackElicitingPacketsReceivedSinceLastAck >= packetsBeforeAck {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using initial threshold: %d).", h.ackElicitingPacketsReceivedSinceLastAck, packetsBeforeAck)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// queue an ACK if there are new missing packets to report
|
||||
if h.hasNewMissingPackets() {
|
||||
h.logger.Debugf("\tQueuing ACK because there's a new missing packet to report.")
|
||||
return true
|
||||
}
|
||||
|
||||
// queue an ACK if the packet was ECN-CE marked
|
||||
if ecn == protocol.ECNCE {
|
||||
h.logger.Debugf("\tQueuing ACK because the packet was ECN-CE marked.")
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *appDataReceivedPacketTracker) GetAckFrame(now monotime.Time, onlyIfQueued bool) *wire.AckFrame {
|
||||
if onlyIfQueued && !h.ackQueued {
|
||||
if h.ackAlarm.IsZero() || h.ackAlarm.After(now) {
|
||||
return nil
|
||||
}
|
||||
if h.logger.Debug() && !h.ackAlarm.IsZero() {
|
||||
h.logger.Debugf("Sending ACK because the ACK timer expired.")
|
||||
}
|
||||
}
|
||||
ack := h.receivedPacketTracker.GetAckFrame()
|
||||
if ack == nil {
|
||||
return nil
|
||||
}
|
||||
ack.DelayTime = max(0, now.Sub(h.largestObservedRcvdTime))
|
||||
h.ackQueued = false
|
||||
h.ackAlarm = 0
|
||||
h.ackElicitingPacketsReceivedSinceLastAck = 0
|
||||
return ack
|
||||
}
|
||||
|
||||
func (h *appDataReceivedPacketTracker) GetAlarmTimeout() monotime.Time { return h.ackAlarm }
|
||||
46
vendor/github.com/quic-go/quic-go/internal/ackhandler/send_mode.go
generated
vendored
Normal file
46
vendor/github.com/quic-go/quic-go/internal/ackhandler/send_mode.go
generated
vendored
Normal file
@@ -0,0 +1,46 @@
|
||||
package ackhandler
|
||||
|
||||
import "fmt"
|
||||
|
||||
// The SendMode says what kind of packets can be sent.
|
||||
type SendMode uint8
|
||||
|
||||
const (
|
||||
// SendNone means that no packets should be sent
|
||||
SendNone SendMode = iota
|
||||
// SendAck means an ACK-only packet should be sent
|
||||
SendAck
|
||||
// SendPTOInitial means that an Initial probe packet should be sent
|
||||
SendPTOInitial
|
||||
// SendPTOHandshake means that a Handshake probe packet should be sent
|
||||
SendPTOHandshake
|
||||
// SendPTOAppData means that an Application data probe packet should be sent
|
||||
SendPTOAppData
|
||||
// SendPacingLimited means that the pacer doesn't allow sending of a packet right now,
|
||||
// but will do in a little while.
|
||||
// The timestamp when sending is allowed again can be obtained via the SentPacketHandler.TimeUntilSend.
|
||||
SendPacingLimited
|
||||
// SendAny means that any packet should be sent
|
||||
SendAny
|
||||
)
|
||||
|
||||
func (s SendMode) String() string {
|
||||
switch s {
|
||||
case SendNone:
|
||||
return "none"
|
||||
case SendAck:
|
||||
return "ack"
|
||||
case SendPTOInitial:
|
||||
return "pto (Initial)"
|
||||
case SendPTOHandshake:
|
||||
return "pto (Handshake)"
|
||||
case SendPTOAppData:
|
||||
return "pto (Application Data)"
|
||||
case SendAny:
|
||||
return "any"
|
||||
case SendPacingLimited:
|
||||
return "pacing limited"
|
||||
default:
|
||||
return fmt.Sprintf("invalid send mode: %d", s)
|
||||
}
|
||||
}
|
||||
1143
vendor/github.com/quic-go/quic-go/internal/ackhandler/sent_packet_handler.go
generated
vendored
Normal file
1143
vendor/github.com/quic-go/quic-go/internal/ackhandler/sent_packet_handler.go
generated
vendored
Normal file
File diff suppressed because it is too large
Load Diff
274
vendor/github.com/quic-go/quic-go/internal/ackhandler/sent_packet_history.go
generated
vendored
Normal file
274
vendor/github.com/quic-go/quic-go/internal/ackhandler/sent_packet_history.go
generated
vendored
Normal file
@@ -0,0 +1,274 @@
|
||||
package ackhandler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"iter"
|
||||
"slices"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
const maxSkippedPackets = 4
|
||||
|
||||
type sentPacketHistory struct {
|
||||
packets []*packet
|
||||
pathProbePackets []packetWithPacketNumber
|
||||
skippedPackets []protocol.PacketNumber
|
||||
|
||||
numOutstanding int
|
||||
|
||||
firstPacketNumber protocol.PacketNumber
|
||||
highestPacketNumber protocol.PacketNumber
|
||||
}
|
||||
|
||||
func newSentPacketHistory(isAppData bool) *sentPacketHistory {
|
||||
h := &sentPacketHistory{
|
||||
highestPacketNumber: protocol.InvalidPacketNumber,
|
||||
firstPacketNumber: protocol.InvalidPacketNumber,
|
||||
}
|
||||
if isAppData {
|
||||
h.packets = make([]*packet, 0, 32)
|
||||
h.skippedPackets = make([]protocol.PacketNumber, 0, maxSkippedPackets)
|
||||
} else {
|
||||
h.packets = make([]*packet, 0, 6)
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) checkSequentialPacketNumberUse(pn protocol.PacketNumber) {
|
||||
if h.highestPacketNumber != protocol.InvalidPacketNumber {
|
||||
if pn != h.highestPacketNumber+1 {
|
||||
panic("non-sequential packet number use")
|
||||
}
|
||||
}
|
||||
h.highestPacketNumber = pn
|
||||
if len(h.packets) == 0 {
|
||||
h.firstPacketNumber = pn
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) SkippedPacket(pn protocol.PacketNumber) {
|
||||
h.checkSequentialPacketNumberUse(pn)
|
||||
if len(h.packets) > 0 {
|
||||
h.packets = append(h.packets, nil)
|
||||
}
|
||||
if len(h.skippedPackets) == maxSkippedPackets {
|
||||
h.skippedPackets = slices.Delete(h.skippedPackets, 0, 1)
|
||||
}
|
||||
h.skippedPackets = append(h.skippedPackets, pn)
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) SentPacket(pn protocol.PacketNumber, p *packet) {
|
||||
h.checkSequentialPacketNumberUse(pn)
|
||||
h.packets = append(h.packets, p)
|
||||
if p.Outstanding() {
|
||||
h.numOutstanding++
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) SentPathProbePacket(pn protocol.PacketNumber, p *packet) {
|
||||
h.checkSequentialPacketNumberUse(pn)
|
||||
h.packets = append(h.packets, &packet{isPathProbePacket: true})
|
||||
h.pathProbePackets = append(h.pathProbePackets, packetWithPacketNumber{PacketNumber: pn, packet: p})
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) Packets() iter.Seq2[protocol.PacketNumber, *packet] {
|
||||
return func(yield func(protocol.PacketNumber, *packet) bool) {
|
||||
// h.firstPacketNumber might be updated in the yield function,
|
||||
// so we need to save it here.
|
||||
firstPacketNumber := h.firstPacketNumber
|
||||
for i, p := range h.packets {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
if !yield(firstPacketNumber+protocol.PacketNumber(i), p) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) PathProbes() iter.Seq2[protocol.PacketNumber, *packet] {
|
||||
return func(yield func(protocol.PacketNumber, *packet) bool) {
|
||||
for _, p := range h.pathProbePackets {
|
||||
if !yield(p.PacketNumber, p.packet) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// FirstOutstanding returns the first outstanding packet.
|
||||
func (h *sentPacketHistory) FirstOutstanding() (protocol.PacketNumber, *packet) {
|
||||
if !h.HasOutstandingPackets() {
|
||||
return protocol.InvalidPacketNumber, nil
|
||||
}
|
||||
for i, p := range h.packets {
|
||||
if p != nil && p.Outstanding() {
|
||||
return h.firstPacketNumber + protocol.PacketNumber(i), p
|
||||
}
|
||||
}
|
||||
return protocol.InvalidPacketNumber, nil
|
||||
}
|
||||
|
||||
// FirstOutstandingPathProbe returns the first outstanding path probe packet
|
||||
func (h *sentPacketHistory) FirstOutstandingPathProbe() (protocol.PacketNumber, *packet) {
|
||||
if len(h.pathProbePackets) == 0 {
|
||||
return protocol.InvalidPacketNumber, nil
|
||||
}
|
||||
return h.pathProbePackets[0].PacketNumber, h.pathProbePackets[0].packet
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) SkippedPackets() iter.Seq[protocol.PacketNumber] {
|
||||
return func(yield func(protocol.PacketNumber) bool) {
|
||||
for _, p := range h.skippedPackets {
|
||||
if !yield(p) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) Len() int {
|
||||
return len(h.packets)
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) NumOutstanding() int {
|
||||
return h.numOutstanding
|
||||
}
|
||||
|
||||
// Remove removes a packet from the sent packet history.
|
||||
// It must not be used for skipped packet numbers.
|
||||
func (h *sentPacketHistory) Remove(pn protocol.PacketNumber) error {
|
||||
idx, ok := h.getIndex(pn)
|
||||
if !ok {
|
||||
return fmt.Errorf("packet %d not found in sent packet history", pn)
|
||||
}
|
||||
p := h.packets[idx]
|
||||
if p.Outstanding() {
|
||||
h.numOutstanding--
|
||||
if h.numOutstanding < 0 {
|
||||
panic("negative number of outstanding packets")
|
||||
}
|
||||
}
|
||||
h.packets[idx] = nil
|
||||
// clean up all skipped packets directly before this packet number
|
||||
var hasPacketBefore bool
|
||||
for idx > 0 {
|
||||
idx--
|
||||
if h.packets[idx] != nil {
|
||||
hasPacketBefore = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasPacketBefore {
|
||||
h.cleanupStart()
|
||||
}
|
||||
if len(h.packets) > 0 && h.packets[0] == nil {
|
||||
panic("cleanup failed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemovePathProbe removes a path probe packet.
|
||||
// It scales O(N), but that's ok, since we don't expect to send many path probe packets.
|
||||
// It is not valid to call this function in IteratePathProbes.
|
||||
func (h *sentPacketHistory) RemovePathProbe(pn protocol.PacketNumber) *packet {
|
||||
var packetToDelete *packet
|
||||
idx := -1
|
||||
for i, p := range h.pathProbePackets {
|
||||
if p.PacketNumber == pn {
|
||||
packetToDelete = p.packet
|
||||
idx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if idx != -1 {
|
||||
// don't use slices.Delete, because it zeros the deleted element
|
||||
copy(h.pathProbePackets[idx:], h.pathProbePackets[idx+1:])
|
||||
h.pathProbePackets = h.pathProbePackets[:len(h.pathProbePackets)-1]
|
||||
}
|
||||
return packetToDelete
|
||||
}
|
||||
|
||||
// getIndex gets the index of packet p in the packets slice.
|
||||
func (h *sentPacketHistory) getIndex(p protocol.PacketNumber) (int, bool) {
|
||||
if len(h.packets) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
if p < h.firstPacketNumber {
|
||||
return 0, false
|
||||
}
|
||||
index := int(p - h.firstPacketNumber)
|
||||
if index > len(h.packets)-1 {
|
||||
return 0, false
|
||||
}
|
||||
return index, true
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) HasOutstandingPackets() bool {
|
||||
return h.numOutstanding > 0
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) HasOutstandingPathProbes() bool {
|
||||
return len(h.pathProbePackets) > 0
|
||||
}
|
||||
|
||||
// delete all nil entries at the beginning of the packets slice
|
||||
func (h *sentPacketHistory) cleanupStart() {
|
||||
for i, p := range h.packets {
|
||||
if p != nil {
|
||||
h.packets = h.packets[i:]
|
||||
h.firstPacketNumber += protocol.PacketNumber(i)
|
||||
return
|
||||
}
|
||||
}
|
||||
h.packets = h.packets[:0]
|
||||
h.firstPacketNumber = protocol.InvalidPacketNumber
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) LowestPacketNumber() protocol.PacketNumber {
|
||||
if len(h.packets) == 0 {
|
||||
return protocol.InvalidPacketNumber
|
||||
}
|
||||
return h.firstPacketNumber
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) DeclareLost(pn protocol.PacketNumber) {
|
||||
idx, ok := h.getIndex(pn)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
p := h.packets[idx]
|
||||
if p.Outstanding() {
|
||||
h.numOutstanding--
|
||||
if h.numOutstanding < 0 {
|
||||
panic("negative number of outstanding packets")
|
||||
}
|
||||
}
|
||||
h.packets[idx] = nil
|
||||
if idx == 0 {
|
||||
h.cleanupStart()
|
||||
}
|
||||
}
|
||||
|
||||
// Difference returns the difference between two packet numbers a and b (a - b),
|
||||
// taking into account any skipped packet numbers between them.
|
||||
//
|
||||
// Note that old skipped packets are garbage collected at some point,
|
||||
// so this function is not guaranteed to return the correct result after a while.
|
||||
func (h *sentPacketHistory) Difference(a, b protocol.PacketNumber) protocol.PacketNumber {
|
||||
diff := a - b
|
||||
if len(h.skippedPackets) == 0 {
|
||||
return diff
|
||||
}
|
||||
if a < h.skippedPackets[0] || b > h.skippedPackets[len(h.skippedPackets)-1] {
|
||||
return diff
|
||||
}
|
||||
for _, p := range h.skippedPackets {
|
||||
if p > b && p < a {
|
||||
diff--
|
||||
}
|
||||
}
|
||||
return diff
|
||||
}
|
||||
22
vendor/github.com/quic-go/quic-go/internal/congestion/bandwidth.go
generated
vendored
Normal file
22
vendor/github.com/quic-go/quic-go/internal/congestion/bandwidth.go
generated
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
package congestion
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// Bandwidth of a connection
|
||||
type Bandwidth uint64
|
||||
|
||||
const (
|
||||
// BitsPerSecond is 1 bit per second
|
||||
BitsPerSecond Bandwidth = 1
|
||||
// BytesPerSecond is 1 byte per second
|
||||
BytesPerSecond = 8 * BitsPerSecond
|
||||
)
|
||||
|
||||
// BandwidthFromDelta calculates the bandwidth from a number of bytes and a time delta
|
||||
func BandwidthFromDelta(bytes protocol.ByteCount, delta time.Duration) Bandwidth {
|
||||
return Bandwidth(bytes) * Bandwidth(time.Second) / Bandwidth(delta) * BytesPerSecond
|
||||
}
|
||||
20
vendor/github.com/quic-go/quic-go/internal/congestion/clock.go
generated
vendored
Normal file
20
vendor/github.com/quic-go/quic-go/internal/congestion/clock.go
generated
vendored
Normal file
@@ -0,0 +1,20 @@
|
||||
package congestion
|
||||
|
||||
import (
|
||||
"github.com/quic-go/quic-go/internal/monotime"
|
||||
)
|
||||
|
||||
// A Clock returns the current time
|
||||
type Clock interface {
|
||||
Now() monotime.Time
|
||||
}
|
||||
|
||||
// DefaultClock implements the Clock interface using the Go stdlib clock.
|
||||
type DefaultClock struct{}
|
||||
|
||||
var _ Clock = DefaultClock{}
|
||||
|
||||
// Now gets the current time
|
||||
func (DefaultClock) Now() monotime.Time {
|
||||
return monotime.Now()
|
||||
}
|
||||
214
vendor/github.com/quic-go/quic-go/internal/congestion/cubic.go
generated
vendored
Normal file
214
vendor/github.com/quic-go/quic-go/internal/congestion/cubic.go
generated
vendored
Normal file
@@ -0,0 +1,214 @@
|
||||
package congestion
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/monotime"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// This cubic implementation is based on the one found in Chromiums's QUIC
|
||||
// implementation, in the files net/quic/congestion_control/cubic.{hh,cc}.
|
||||
|
||||
// Constants based on TCP defaults.
|
||||
// The following constants are in 2^10 fractions of a second instead of ms to
|
||||
// allow a 10 shift right to divide.
|
||||
|
||||
// 1024*1024^3 (first 1024 is from 0.100^3)
|
||||
// where 0.100 is 100 ms which is the scaling round trip time.
|
||||
const (
|
||||
cubeScale = 40
|
||||
cubeCongestionWindowScale = 410
|
||||
cubeFactor = 1 << cubeScale / cubeCongestionWindowScale / maxDatagramSize
|
||||
// TODO: when re-enabling cubic, make sure to use the actual packet size here
|
||||
maxDatagramSize = protocol.ByteCount(protocol.InitialPacketSize)
|
||||
)
|
||||
|
||||
const defaultNumConnections = 1
|
||||
|
||||
// Default Cubic backoff factor
|
||||
const beta float32 = 0.7
|
||||
|
||||
// Additional backoff factor when loss occurs in the concave part of the Cubic
|
||||
// curve. This additional backoff factor is expected to give up bandwidth to
|
||||
// new concurrent flows and speed up convergence.
|
||||
const betaLastMax float32 = 0.85
|
||||
|
||||
// Cubic implements the cubic algorithm from TCP
|
||||
type Cubic struct {
|
||||
clock Clock
|
||||
|
||||
// Number of connections to simulate.
|
||||
numConnections int
|
||||
|
||||
// Time when this cycle started, after last loss event.
|
||||
epoch monotime.Time
|
||||
|
||||
// Max congestion window used just before last loss event.
|
||||
// Note: to improve fairness to other streams an additional back off is
|
||||
// applied to this value if the new value is below our latest value.
|
||||
lastMaxCongestionWindow protocol.ByteCount
|
||||
|
||||
// Number of acked bytes since the cycle started (epoch).
|
||||
ackedBytesCount protocol.ByteCount
|
||||
|
||||
// TCP Reno equivalent congestion window in packets.
|
||||
estimatedTCPcongestionWindow protocol.ByteCount
|
||||
|
||||
// Origin point of cubic function.
|
||||
originPointCongestionWindow protocol.ByteCount
|
||||
|
||||
// Time to origin point of cubic function in 2^10 fractions of a second.
|
||||
timeToOriginPoint uint32
|
||||
|
||||
// Last congestion window in packets computed by cubic function.
|
||||
lastTargetCongestionWindow protocol.ByteCount
|
||||
}
|
||||
|
||||
// NewCubic returns a new Cubic instance
|
||||
func NewCubic(clock Clock) *Cubic {
|
||||
c := &Cubic{
|
||||
clock: clock,
|
||||
numConnections: defaultNumConnections,
|
||||
}
|
||||
c.Reset()
|
||||
return c
|
||||
}
|
||||
|
||||
// Reset is called after a timeout to reset the cubic state
|
||||
func (c *Cubic) Reset() {
|
||||
c.epoch = 0
|
||||
c.lastMaxCongestionWindow = 0
|
||||
c.ackedBytesCount = 0
|
||||
c.estimatedTCPcongestionWindow = 0
|
||||
c.originPointCongestionWindow = 0
|
||||
c.timeToOriginPoint = 0
|
||||
c.lastTargetCongestionWindow = 0
|
||||
}
|
||||
|
||||
func (c *Cubic) alpha() float32 {
|
||||
// TCPFriendly alpha is described in Section 3.3 of the CUBIC paper. Note that
|
||||
// beta here is a cwnd multiplier, and is equal to 1-beta from the paper.
|
||||
// We derive the equivalent alpha for an N-connection emulation as:
|
||||
b := c.beta()
|
||||
return 3 * float32(c.numConnections) * float32(c.numConnections) * (1 - b) / (1 + b)
|
||||
}
|
||||
|
||||
func (c *Cubic) beta() float32 {
|
||||
// kNConnectionBeta is the backoff factor after loss for our N-connection
|
||||
// emulation, which emulates the effective backoff of an ensemble of N
|
||||
// TCP-Reno connections on a single loss event. The effective multiplier is
|
||||
// computed as:
|
||||
return (float32(c.numConnections) - 1 + beta) / float32(c.numConnections)
|
||||
}
|
||||
|
||||
func (c *Cubic) betaLastMax() float32 {
|
||||
// betaLastMax is the additional backoff factor after loss for our
|
||||
// N-connection emulation, which emulates the additional backoff of
|
||||
// an ensemble of N TCP-Reno connections on a single loss event. The
|
||||
// effective multiplier is computed as:
|
||||
return (float32(c.numConnections) - 1 + betaLastMax) / float32(c.numConnections)
|
||||
}
|
||||
|
||||
// OnApplicationLimited is called on ack arrival when sender is unable to use
|
||||
// the available congestion window. Resets Cubic state during quiescence.
|
||||
func (c *Cubic) OnApplicationLimited() {
|
||||
// When sender is not using the available congestion window, the window does
|
||||
// not grow. But to be RTT-independent, Cubic assumes that the sender has been
|
||||
// using the entire window during the time since the beginning of the current
|
||||
// "epoch" (the end of the last loss recovery period). Since
|
||||
// application-limited periods break this assumption, we reset the epoch when
|
||||
// in such a period. This reset effectively freezes congestion window growth
|
||||
// through application-limited periods and allows Cubic growth to continue
|
||||
// when the entire window is being used.
|
||||
c.epoch = 0
|
||||
}
|
||||
|
||||
// CongestionWindowAfterPacketLoss computes a new congestion window to use after
|
||||
// a loss event. Returns the new congestion window in packets. The new
|
||||
// congestion window is a multiplicative decrease of our current window.
|
||||
func (c *Cubic) CongestionWindowAfterPacketLoss(currentCongestionWindow protocol.ByteCount) protocol.ByteCount {
|
||||
if currentCongestionWindow+maxDatagramSize < c.lastMaxCongestionWindow {
|
||||
// We never reached the old max, so assume we are competing with another
|
||||
// flow. Use our extra back off factor to allow the other flow to go up.
|
||||
c.lastMaxCongestionWindow = protocol.ByteCount(c.betaLastMax() * float32(currentCongestionWindow))
|
||||
} else {
|
||||
c.lastMaxCongestionWindow = currentCongestionWindow
|
||||
}
|
||||
c.epoch = 0 // Reset time.
|
||||
return protocol.ByteCount(float32(currentCongestionWindow) * c.beta())
|
||||
}
|
||||
|
||||
// CongestionWindowAfterAck computes a new congestion window to use after a received ACK.
|
||||
// Returns the new congestion window in packets. The new congestion window
|
||||
// follows a cubic function that depends on the time passed since last
|
||||
// packet loss.
|
||||
func (c *Cubic) CongestionWindowAfterAck(
|
||||
ackedBytes protocol.ByteCount,
|
||||
currentCongestionWindow protocol.ByteCount,
|
||||
delayMin time.Duration,
|
||||
eventTime monotime.Time,
|
||||
) protocol.ByteCount {
|
||||
c.ackedBytesCount += ackedBytes
|
||||
|
||||
if c.epoch.IsZero() {
|
||||
// First ACK after a loss event.
|
||||
c.epoch = eventTime // Start of epoch.
|
||||
c.ackedBytesCount = ackedBytes // Reset count.
|
||||
// Reset estimated_tcp_congestion_window_ to be in sync with cubic.
|
||||
c.estimatedTCPcongestionWindow = currentCongestionWindow
|
||||
if c.lastMaxCongestionWindow <= currentCongestionWindow {
|
||||
c.timeToOriginPoint = 0
|
||||
c.originPointCongestionWindow = currentCongestionWindow
|
||||
} else {
|
||||
c.timeToOriginPoint = uint32(math.Cbrt(float64(cubeFactor * (c.lastMaxCongestionWindow - currentCongestionWindow))))
|
||||
c.originPointCongestionWindow = c.lastMaxCongestionWindow
|
||||
}
|
||||
}
|
||||
|
||||
// Change the time unit from microseconds to 2^10 fractions per second. Take
|
||||
// the round trip time in account. This is done to allow us to use shift as a
|
||||
// divide operator.
|
||||
elapsedTime := int64(eventTime.Add(delayMin).Sub(c.epoch)/time.Microsecond) << 10 / (1000 * 1000)
|
||||
|
||||
// Right-shifts of negative, signed numbers have implementation-dependent
|
||||
// behavior, so force the offset to be positive, as is done in the kernel.
|
||||
offset := int64(c.timeToOriginPoint) - elapsedTime
|
||||
if offset < 0 {
|
||||
offset = -offset
|
||||
}
|
||||
|
||||
deltaCongestionWindow := protocol.ByteCount(cubeCongestionWindowScale*offset*offset*offset) * maxDatagramSize >> cubeScale
|
||||
var targetCongestionWindow protocol.ByteCount
|
||||
if elapsedTime > int64(c.timeToOriginPoint) {
|
||||
targetCongestionWindow = c.originPointCongestionWindow + deltaCongestionWindow
|
||||
} else {
|
||||
targetCongestionWindow = c.originPointCongestionWindow - deltaCongestionWindow
|
||||
}
|
||||
// Limit the CWND increase to half the acked bytes.
|
||||
targetCongestionWindow = min(targetCongestionWindow, currentCongestionWindow+c.ackedBytesCount/2)
|
||||
|
||||
// Increase the window by approximately Alpha * 1 MSS of bytes every
|
||||
// time we ack an estimated tcp window of bytes. For small
|
||||
// congestion windows (less than 25), the formula below will
|
||||
// increase slightly slower than linearly per estimated tcp window
|
||||
// of bytes.
|
||||
c.estimatedTCPcongestionWindow += protocol.ByteCount(float32(c.ackedBytesCount) * c.alpha() * float32(maxDatagramSize) / float32(c.estimatedTCPcongestionWindow))
|
||||
c.ackedBytesCount = 0
|
||||
|
||||
// We have a new cubic congestion window.
|
||||
c.lastTargetCongestionWindow = targetCongestionWindow
|
||||
|
||||
// Compute target congestion_window based on cubic target and estimated TCP
|
||||
// congestion_window, use highest (fastest).
|
||||
if targetCongestionWindow < c.estimatedTCPcongestionWindow {
|
||||
targetCongestionWindow = c.estimatedTCPcongestionWindow
|
||||
}
|
||||
return targetCongestionWindow
|
||||
}
|
||||
|
||||
// SetNumConnections sets the number of emulated connections
|
||||
func (c *Cubic) SetNumConnections(n int) {
|
||||
c.numConnections = n
|
||||
}
|
||||
330
vendor/github.com/quic-go/quic-go/internal/congestion/cubic_sender.go
generated
vendored
Normal file
330
vendor/github.com/quic-go/quic-go/internal/congestion/cubic_sender.go
generated
vendored
Normal file
@@ -0,0 +1,330 @@
|
||||
package congestion
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/monotime"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/qlog"
|
||||
"github.com/quic-go/quic-go/qlogwriter"
|
||||
)
|
||||
|
||||
const (
|
||||
// maxDatagramSize is the default maximum packet size used in the Linux TCP implementation.
|
||||
// Used in QUIC for congestion window computations in bytes.
|
||||
initialMaxDatagramSize = protocol.ByteCount(protocol.InitialPacketSize)
|
||||
maxBurstPackets = 3
|
||||
renoBeta = 0.7 // Reno backoff factor.
|
||||
minCongestionWindowPackets = 2
|
||||
initialCongestionWindow = 32
|
||||
)
|
||||
|
||||
type cubicSender struct {
|
||||
hybridSlowStart HybridSlowStart
|
||||
rttStats *utils.RTTStats
|
||||
connStats *utils.ConnectionStats
|
||||
cubic *Cubic
|
||||
pacer *pacer
|
||||
clock Clock
|
||||
|
||||
reno bool
|
||||
|
||||
// Track the largest packet that has been sent.
|
||||
largestSentPacketNumber protocol.PacketNumber
|
||||
|
||||
// Track the largest packet that has been acked.
|
||||
largestAckedPacketNumber protocol.PacketNumber
|
||||
|
||||
// Track the largest packet number outstanding when a CWND cutback occurs.
|
||||
largestSentAtLastCutback protocol.PacketNumber
|
||||
|
||||
// Whether the last loss event caused us to exit slowstart.
|
||||
// Used for stats collection of slowstartPacketsLost
|
||||
lastCutbackExitedSlowstart bool
|
||||
|
||||
// Congestion window in bytes.
|
||||
congestionWindow protocol.ByteCount
|
||||
|
||||
// Slow start congestion window in bytes, aka ssthresh.
|
||||
slowStartThreshold protocol.ByteCount
|
||||
|
||||
// ACK counter for the Reno implementation.
|
||||
numAckedPackets uint64
|
||||
|
||||
initialCongestionWindow protocol.ByteCount
|
||||
initialMaxCongestionWindow protocol.ByteCount
|
||||
|
||||
maxDatagramSize protocol.ByteCount
|
||||
|
||||
lastState qlog.CongestionState
|
||||
qlogger qlogwriter.Recorder
|
||||
}
|
||||
|
||||
var (
|
||||
_ SendAlgorithm = &cubicSender{}
|
||||
_ SendAlgorithmWithDebugInfos = &cubicSender{}
|
||||
)
|
||||
|
||||
// NewCubicSender makes a new cubic sender
|
||||
func NewCubicSender(
|
||||
clock Clock,
|
||||
rttStats *utils.RTTStats,
|
||||
connStats *utils.ConnectionStats,
|
||||
initialMaxDatagramSize protocol.ByteCount,
|
||||
reno bool,
|
||||
qlogger qlogwriter.Recorder,
|
||||
) *cubicSender {
|
||||
return newCubicSender(
|
||||
clock,
|
||||
rttStats,
|
||||
connStats,
|
||||
reno,
|
||||
initialMaxDatagramSize,
|
||||
initialCongestionWindow*initialMaxDatagramSize,
|
||||
protocol.MaxCongestionWindowPackets*initialMaxDatagramSize,
|
||||
qlogger,
|
||||
)
|
||||
}
|
||||
|
||||
func newCubicSender(
|
||||
clock Clock,
|
||||
rttStats *utils.RTTStats,
|
||||
connStats *utils.ConnectionStats,
|
||||
reno bool,
|
||||
initialMaxDatagramSize,
|
||||
initialCongestionWindow,
|
||||
initialMaxCongestionWindow protocol.ByteCount,
|
||||
qlogger qlogwriter.Recorder,
|
||||
) *cubicSender {
|
||||
c := &cubicSender{
|
||||
rttStats: rttStats,
|
||||
connStats: connStats,
|
||||
largestSentPacketNumber: protocol.InvalidPacketNumber,
|
||||
largestAckedPacketNumber: protocol.InvalidPacketNumber,
|
||||
largestSentAtLastCutback: protocol.InvalidPacketNumber,
|
||||
initialCongestionWindow: initialCongestionWindow,
|
||||
initialMaxCongestionWindow: initialMaxCongestionWindow,
|
||||
congestionWindow: initialCongestionWindow,
|
||||
slowStartThreshold: protocol.MaxByteCount,
|
||||
cubic: NewCubic(clock),
|
||||
clock: clock,
|
||||
reno: reno,
|
||||
qlogger: qlogger,
|
||||
maxDatagramSize: initialMaxDatagramSize,
|
||||
}
|
||||
c.pacer = newPacer(c.BandwidthEstimate)
|
||||
if c.qlogger != nil {
|
||||
c.lastState = qlog.CongestionStateSlowStart
|
||||
c.qlogger.RecordEvent(qlog.CongestionStateUpdated{
|
||||
State: qlog.CongestionStateSlowStart,
|
||||
})
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// TimeUntilSend returns when the next packet should be sent.
|
||||
func (c *cubicSender) TimeUntilSend(_ protocol.ByteCount) monotime.Time {
|
||||
return c.pacer.TimeUntilSend()
|
||||
}
|
||||
|
||||
func (c *cubicSender) HasPacingBudget(now monotime.Time) bool {
|
||||
return c.pacer.Budget(now) >= c.maxDatagramSize
|
||||
}
|
||||
|
||||
func (c *cubicSender) maxCongestionWindow() protocol.ByteCount {
|
||||
return c.maxDatagramSize * protocol.MaxCongestionWindowPackets
|
||||
}
|
||||
|
||||
func (c *cubicSender) minCongestionWindow() protocol.ByteCount {
|
||||
return c.maxDatagramSize * minCongestionWindowPackets
|
||||
}
|
||||
|
||||
func (c *cubicSender) OnPacketSent(
|
||||
sentTime monotime.Time,
|
||||
_ protocol.ByteCount,
|
||||
packetNumber protocol.PacketNumber,
|
||||
bytes protocol.ByteCount,
|
||||
isRetransmittable bool,
|
||||
) {
|
||||
c.pacer.SentPacket(sentTime, bytes)
|
||||
if !isRetransmittable {
|
||||
return
|
||||
}
|
||||
c.largestSentPacketNumber = packetNumber
|
||||
c.hybridSlowStart.OnPacketSent(packetNumber)
|
||||
}
|
||||
|
||||
func (c *cubicSender) CanSend(bytesInFlight protocol.ByteCount) bool {
|
||||
return bytesInFlight < c.GetCongestionWindow()
|
||||
}
|
||||
|
||||
func (c *cubicSender) InRecovery() bool {
|
||||
return c.largestAckedPacketNumber != protocol.InvalidPacketNumber && c.largestAckedPacketNumber <= c.largestSentAtLastCutback
|
||||
}
|
||||
|
||||
func (c *cubicSender) InSlowStart() bool {
|
||||
return c.GetCongestionWindow() < c.slowStartThreshold
|
||||
}
|
||||
|
||||
func (c *cubicSender) GetCongestionWindow() protocol.ByteCount {
|
||||
return c.congestionWindow
|
||||
}
|
||||
|
||||
func (c *cubicSender) MaybeExitSlowStart() {
|
||||
if c.InSlowStart() &&
|
||||
c.hybridSlowStart.ShouldExitSlowStart(c.rttStats.LatestRTT(), c.rttStats.MinRTT(), c.GetCongestionWindow()/c.maxDatagramSize) {
|
||||
// exit slow start
|
||||
c.slowStartThreshold = c.congestionWindow
|
||||
c.maybeQlogStateChange(qlog.CongestionStateCongestionAvoidance)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cubicSender) OnPacketAcked(
|
||||
ackedPacketNumber protocol.PacketNumber,
|
||||
ackedBytes protocol.ByteCount,
|
||||
priorInFlight protocol.ByteCount,
|
||||
eventTime monotime.Time,
|
||||
) {
|
||||
c.largestAckedPacketNumber = max(ackedPacketNumber, c.largestAckedPacketNumber)
|
||||
if c.InRecovery() {
|
||||
return
|
||||
}
|
||||
c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, priorInFlight, eventTime)
|
||||
if c.InSlowStart() {
|
||||
c.hybridSlowStart.OnPacketAcked(ackedPacketNumber)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cubicSender) OnCongestionEvent(packetNumber protocol.PacketNumber, lostBytes, priorInFlight protocol.ByteCount) {
|
||||
c.connStats.PacketsLost.Add(1)
|
||||
c.connStats.BytesLost.Add(uint64(lostBytes))
|
||||
|
||||
// TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets
|
||||
// already sent should be treated as a single loss event, since it's expected.
|
||||
if packetNumber <= c.largestSentAtLastCutback {
|
||||
return
|
||||
}
|
||||
c.lastCutbackExitedSlowstart = c.InSlowStart()
|
||||
c.maybeQlogStateChange(qlog.CongestionStateRecovery)
|
||||
|
||||
if c.reno {
|
||||
c.congestionWindow = protocol.ByteCount(float64(c.congestionWindow) * renoBeta)
|
||||
} else {
|
||||
c.congestionWindow = c.cubic.CongestionWindowAfterPacketLoss(c.congestionWindow)
|
||||
}
|
||||
if minCwnd := c.minCongestionWindow(); c.congestionWindow < minCwnd {
|
||||
c.congestionWindow = minCwnd
|
||||
}
|
||||
c.slowStartThreshold = c.congestionWindow
|
||||
c.largestSentAtLastCutback = c.largestSentPacketNumber
|
||||
// reset packet count from congestion avoidance mode. We start
|
||||
// counting again when we're out of recovery.
|
||||
c.numAckedPackets = 0
|
||||
}
|
||||
|
||||
// Called when we receive an ack. Normal TCP tracks how many packets one ack
|
||||
// represents, but quic has a separate ack for each packet.
|
||||
func (c *cubicSender) maybeIncreaseCwnd(
|
||||
_ protocol.PacketNumber,
|
||||
ackedBytes protocol.ByteCount,
|
||||
priorInFlight protocol.ByteCount,
|
||||
eventTime monotime.Time,
|
||||
) {
|
||||
// Do not increase the congestion window unless the sender is close to using
|
||||
// the current window.
|
||||
if !c.isCwndLimited(priorInFlight) {
|
||||
c.cubic.OnApplicationLimited()
|
||||
c.maybeQlogStateChange(qlog.CongestionStateApplicationLimited)
|
||||
return
|
||||
}
|
||||
if c.congestionWindow >= c.maxCongestionWindow() {
|
||||
return
|
||||
}
|
||||
if c.InSlowStart() {
|
||||
// TCP slow start, exponential growth, increase by one for each ACK.
|
||||
c.congestionWindow += c.maxDatagramSize
|
||||
c.maybeQlogStateChange(qlog.CongestionStateSlowStart)
|
||||
return
|
||||
}
|
||||
// Congestion avoidance
|
||||
c.maybeQlogStateChange(qlog.CongestionStateCongestionAvoidance)
|
||||
if c.reno {
|
||||
// Classic Reno congestion avoidance.
|
||||
c.numAckedPackets++
|
||||
if c.numAckedPackets >= uint64(c.congestionWindow/c.maxDatagramSize) {
|
||||
c.congestionWindow += c.maxDatagramSize
|
||||
c.numAckedPackets = 0
|
||||
}
|
||||
} else {
|
||||
c.congestionWindow = min(
|
||||
c.maxCongestionWindow(),
|
||||
c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cubicSender) isCwndLimited(bytesInFlight protocol.ByteCount) bool {
|
||||
congestionWindow := c.GetCongestionWindow()
|
||||
if bytesInFlight >= congestionWindow {
|
||||
return true
|
||||
}
|
||||
availableBytes := congestionWindow - bytesInFlight
|
||||
slowStartLimited := c.InSlowStart() && bytesInFlight > congestionWindow/2
|
||||
return slowStartLimited || availableBytes <= maxBurstPackets*c.maxDatagramSize
|
||||
}
|
||||
|
||||
// BandwidthEstimate returns the current bandwidth estimate
|
||||
func (c *cubicSender) BandwidthEstimate() Bandwidth {
|
||||
srtt := c.rttStats.SmoothedRTT()
|
||||
if srtt == 0 {
|
||||
// This should never happen, but if it does, avoid division by zero.
|
||||
srtt = protocol.TimerGranularity
|
||||
}
|
||||
return BandwidthFromDelta(c.GetCongestionWindow(), srtt)
|
||||
}
|
||||
|
||||
// OnRetransmissionTimeout is called on an retransmission timeout
|
||||
func (c *cubicSender) OnRetransmissionTimeout(packetsRetransmitted bool) {
|
||||
c.largestSentAtLastCutback = protocol.InvalidPacketNumber
|
||||
if !packetsRetransmitted {
|
||||
return
|
||||
}
|
||||
c.hybridSlowStart.Restart()
|
||||
c.cubic.Reset()
|
||||
c.slowStartThreshold = c.congestionWindow / 2
|
||||
c.congestionWindow = c.minCongestionWindow()
|
||||
}
|
||||
|
||||
// OnConnectionMigration is called when the connection is migrated (?)
|
||||
func (c *cubicSender) OnConnectionMigration() {
|
||||
c.hybridSlowStart.Restart()
|
||||
c.largestSentPacketNumber = protocol.InvalidPacketNumber
|
||||
c.largestAckedPacketNumber = protocol.InvalidPacketNumber
|
||||
c.largestSentAtLastCutback = protocol.InvalidPacketNumber
|
||||
c.lastCutbackExitedSlowstart = false
|
||||
c.cubic.Reset()
|
||||
c.numAckedPackets = 0
|
||||
c.congestionWindow = c.initialCongestionWindow
|
||||
c.slowStartThreshold = c.initialMaxCongestionWindow
|
||||
}
|
||||
|
||||
func (c *cubicSender) maybeQlogStateChange(new qlog.CongestionState) {
|
||||
if c.qlogger == nil || new == c.lastState {
|
||||
return
|
||||
}
|
||||
c.qlogger.RecordEvent(qlog.CongestionStateUpdated{State: new})
|
||||
c.lastState = new
|
||||
}
|
||||
|
||||
func (c *cubicSender) SetMaxDatagramSize(s protocol.ByteCount) {
|
||||
if s < c.maxDatagramSize {
|
||||
panic(fmt.Sprintf("congestion BUG: decreased max datagram size from %d to %d", c.maxDatagramSize, s))
|
||||
}
|
||||
cwndIsMinCwnd := c.congestionWindow == c.minCongestionWindow()
|
||||
c.maxDatagramSize = s
|
||||
if cwndIsMinCwnd {
|
||||
c.congestionWindow = c.minCongestionWindow()
|
||||
}
|
||||
c.pacer.SetMaxDatagramSize(s)
|
||||
}
|
||||
112
vendor/github.com/quic-go/quic-go/internal/congestion/hybrid_slow_start.go
generated
vendored
Normal file
112
vendor/github.com/quic-go/quic-go/internal/congestion/hybrid_slow_start.go
generated
vendored
Normal file
@@ -0,0 +1,112 @@
|
||||
package congestion
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// Note(pwestin): the magic clamping numbers come from the original code in
|
||||
// tcp_cubic.c.
|
||||
const hybridStartLowWindow = protocol.ByteCount(16)
|
||||
|
||||
// Number of delay samples for detecting the increase of delay.
|
||||
const hybridStartMinSamples = uint32(8)
|
||||
|
||||
// Exit slow start if the min rtt has increased by more than 1/8th.
|
||||
const hybridStartDelayFactorExp = 3 // 2^3 = 8
|
||||
// The original paper specifies 2 and 8ms, but those have changed over time.
|
||||
const (
|
||||
hybridStartDelayMinThresholdUs = int64(4000)
|
||||
hybridStartDelayMaxThresholdUs = int64(16000)
|
||||
)
|
||||
|
||||
// HybridSlowStart implements the TCP hybrid slow start algorithm
|
||||
type HybridSlowStart struct {
|
||||
endPacketNumber protocol.PacketNumber
|
||||
lastSentPacketNumber protocol.PacketNumber
|
||||
started bool
|
||||
currentMinRTT time.Duration
|
||||
rttSampleCount uint32
|
||||
hystartFound bool
|
||||
}
|
||||
|
||||
// StartReceiveRound is called for the start of each receive round (burst) in the slow start phase.
|
||||
func (s *HybridSlowStart) StartReceiveRound(lastSent protocol.PacketNumber) {
|
||||
s.endPacketNumber = lastSent
|
||||
s.currentMinRTT = 0
|
||||
s.rttSampleCount = 0
|
||||
s.started = true
|
||||
}
|
||||
|
||||
// IsEndOfRound returns true if this ack is the last packet number of our current slow start round.
|
||||
func (s *HybridSlowStart) IsEndOfRound(ack protocol.PacketNumber) bool {
|
||||
return s.endPacketNumber < ack
|
||||
}
|
||||
|
||||
// ShouldExitSlowStart should be called on every new ack frame, since a new
|
||||
// RTT measurement can be made then.
|
||||
// rtt: the RTT for this ack packet.
|
||||
// minRTT: is the lowest delay (RTT) we have seen during the session.
|
||||
// congestionWindow: the congestion window in packets.
|
||||
func (s *HybridSlowStart) ShouldExitSlowStart(latestRTT time.Duration, minRTT time.Duration, congestionWindow protocol.ByteCount) bool {
|
||||
if !s.started {
|
||||
// Time to start the hybrid slow start.
|
||||
s.StartReceiveRound(s.lastSentPacketNumber)
|
||||
}
|
||||
if s.hystartFound {
|
||||
return true
|
||||
}
|
||||
// Second detection parameter - delay increase detection.
|
||||
// Compare the minimum delay (s.currentMinRTT) of the current
|
||||
// burst of packets relative to the minimum delay during the session.
|
||||
// Note: we only look at the first few(8) packets in each burst, since we
|
||||
// only want to compare the lowest RTT of the burst relative to previous
|
||||
// bursts.
|
||||
s.rttSampleCount++
|
||||
if s.rttSampleCount <= hybridStartMinSamples {
|
||||
if s.currentMinRTT == 0 || s.currentMinRTT > latestRTT {
|
||||
s.currentMinRTT = latestRTT
|
||||
}
|
||||
}
|
||||
// We only need to check this once per round.
|
||||
if s.rttSampleCount == hybridStartMinSamples {
|
||||
// Divide minRTT by 8 to get a rtt increase threshold for exiting.
|
||||
minRTTincreaseThresholdUs := int64(minRTT / time.Microsecond >> hybridStartDelayFactorExp)
|
||||
// Ensure the rtt threshold is never less than 2ms or more than 16ms.
|
||||
minRTTincreaseThresholdUs = min(minRTTincreaseThresholdUs, hybridStartDelayMaxThresholdUs)
|
||||
minRTTincreaseThreshold := time.Duration(max(minRTTincreaseThresholdUs, hybridStartDelayMinThresholdUs)) * time.Microsecond
|
||||
|
||||
if s.currentMinRTT > (minRTT + minRTTincreaseThreshold) {
|
||||
s.hystartFound = true
|
||||
}
|
||||
}
|
||||
// Exit from slow start if the cwnd is greater than 16 and
|
||||
// increasing delay is found.
|
||||
return congestionWindow >= hybridStartLowWindow && s.hystartFound
|
||||
}
|
||||
|
||||
// OnPacketSent is called when a packet was sent
|
||||
func (s *HybridSlowStart) OnPacketSent(packetNumber protocol.PacketNumber) {
|
||||
s.lastSentPacketNumber = packetNumber
|
||||
}
|
||||
|
||||
// OnPacketAcked gets invoked after ShouldExitSlowStart, so it's best to end
|
||||
// the round when the final packet of the burst is received and start it on
|
||||
// the next incoming ack.
|
||||
func (s *HybridSlowStart) OnPacketAcked(ackedPacketNumber protocol.PacketNumber) {
|
||||
if s.IsEndOfRound(ackedPacketNumber) {
|
||||
s.started = false
|
||||
}
|
||||
}
|
||||
|
||||
// Started returns true if started
|
||||
func (s *HybridSlowStart) Started() bool {
|
||||
return s.started
|
||||
}
|
||||
|
||||
// Restart the slow start phase
|
||||
func (s *HybridSlowStart) Restart() {
|
||||
s.started = false
|
||||
s.hystartFound = false
|
||||
}
|
||||
27
vendor/github.com/quic-go/quic-go/internal/congestion/interface.go
generated
vendored
Normal file
27
vendor/github.com/quic-go/quic-go/internal/congestion/interface.go
generated
vendored
Normal file
@@ -0,0 +1,27 @@
|
||||
package congestion
|
||||
|
||||
import (
|
||||
"github.com/quic-go/quic-go/internal/monotime"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// A SendAlgorithm performs congestion control
|
||||
type SendAlgorithm interface {
|
||||
TimeUntilSend(bytesInFlight protocol.ByteCount) monotime.Time
|
||||
HasPacingBudget(now monotime.Time) bool
|
||||
OnPacketSent(sentTime monotime.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool)
|
||||
CanSend(bytesInFlight protocol.ByteCount) bool
|
||||
MaybeExitSlowStart()
|
||||
OnPacketAcked(number protocol.PacketNumber, ackedBytes protocol.ByteCount, priorInFlight protocol.ByteCount, eventTime monotime.Time)
|
||||
OnCongestionEvent(number protocol.PacketNumber, lostBytes protocol.ByteCount, priorInFlight protocol.ByteCount)
|
||||
OnRetransmissionTimeout(packetsRetransmitted bool)
|
||||
SetMaxDatagramSize(protocol.ByteCount)
|
||||
}
|
||||
|
||||
// A SendAlgorithmWithDebugInfos is a SendAlgorithm that exposes some debug infos
|
||||
type SendAlgorithmWithDebugInfos interface {
|
||||
SendAlgorithm
|
||||
InSlowStart() bool
|
||||
InRecovery() bool
|
||||
GetCongestionWindow() protocol.ByteCount
|
||||
}
|
||||
110
vendor/github.com/quic-go/quic-go/internal/congestion/pacer.go
generated
vendored
Normal file
110
vendor/github.com/quic-go/quic-go/internal/congestion/pacer.go
generated
vendored
Normal file
@@ -0,0 +1,110 @@
|
||||
package congestion
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/monotime"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
const maxBurstSizePackets = 10
|
||||
|
||||
// The pacer implements a token bucket pacing algorithm.
|
||||
type pacer struct {
|
||||
budgetAtLastSent protocol.ByteCount
|
||||
maxDatagramSize protocol.ByteCount
|
||||
lastSentTime monotime.Time
|
||||
adjustedBandwidth func() uint64 // in bytes/s
|
||||
}
|
||||
|
||||
func newPacer(getBandwidth func() Bandwidth) *pacer {
|
||||
p := &pacer{
|
||||
maxDatagramSize: initialMaxDatagramSize,
|
||||
adjustedBandwidth: func() uint64 {
|
||||
// Bandwidth is in bits/s. We need the value in bytes/s.
|
||||
bw := uint64(getBandwidth() / BytesPerSecond)
|
||||
// Use a slightly higher value than the actual measured bandwidth.
|
||||
// RTT variations then won't result in under-utilization of the congestion window.
|
||||
// Ultimately, this will result in sending packets as acknowledgments are received rather than when timers fire,
|
||||
// provided the congestion window is fully utilized and acknowledgments arrive at regular intervals.
|
||||
return bw * 5 / 4
|
||||
},
|
||||
}
|
||||
p.budgetAtLastSent = p.maxBurstSize()
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *pacer) SentPacket(sendTime monotime.Time, size protocol.ByteCount) {
|
||||
budget := p.Budget(sendTime)
|
||||
if size >= budget {
|
||||
p.budgetAtLastSent = 0
|
||||
} else {
|
||||
p.budgetAtLastSent = budget - size
|
||||
}
|
||||
p.lastSentTime = sendTime
|
||||
}
|
||||
|
||||
func (p *pacer) Budget(now monotime.Time) protocol.ByteCount {
|
||||
if p.lastSentTime.IsZero() {
|
||||
return p.maxBurstSize()
|
||||
}
|
||||
delta := now.Sub(p.lastSentTime)
|
||||
var added protocol.ByteCount
|
||||
if delta > 0 {
|
||||
added = p.timeScaledBandwidth(uint64(delta.Nanoseconds()))
|
||||
}
|
||||
budget := p.budgetAtLastSent + added
|
||||
if added > 0 && budget < p.budgetAtLastSent {
|
||||
budget = protocol.MaxByteCount
|
||||
}
|
||||
return min(p.maxBurstSize(), budget)
|
||||
}
|
||||
|
||||
func (p *pacer) maxBurstSize() protocol.ByteCount {
|
||||
return max(
|
||||
p.timeScaledBandwidth(uint64((protocol.MinPacingDelay + protocol.TimerGranularity).Nanoseconds())),
|
||||
maxBurstSizePackets*p.maxDatagramSize,
|
||||
)
|
||||
}
|
||||
|
||||
// timeScaledBandwidth calculates the number of bytes that may be sent within
|
||||
// a given time interval (ns nanoseconds), based on the current bandwidth estimate.
|
||||
// It caps the scaled value to the maximum allowed burst and handles overflows.
|
||||
func (p *pacer) timeScaledBandwidth(ns uint64) protocol.ByteCount {
|
||||
bw := p.adjustedBandwidth()
|
||||
if bw == 0 {
|
||||
return 0
|
||||
}
|
||||
const nsPerSecond = 1e9
|
||||
maxBurst := maxBurstSizePackets * p.maxDatagramSize
|
||||
var scaled protocol.ByteCount
|
||||
if ns > math.MaxUint64/bw {
|
||||
scaled = maxBurst
|
||||
} else {
|
||||
scaled = protocol.ByteCount(bw * ns / nsPerSecond)
|
||||
}
|
||||
return scaled
|
||||
}
|
||||
|
||||
// TimeUntilSend returns when the next packet should be sent.
|
||||
// It returns zero if a packet can be sent immediately.
|
||||
func (p *pacer) TimeUntilSend() monotime.Time {
|
||||
if p.budgetAtLastSent >= p.maxDatagramSize {
|
||||
return 0
|
||||
}
|
||||
diff := 1e9 * uint64(p.maxDatagramSize-p.budgetAtLastSent)
|
||||
bw := p.adjustedBandwidth()
|
||||
// We might need to round up this value.
|
||||
// Otherwise, we might have a budget (slightly) smaller than the datagram size when the timer expires.
|
||||
d := diff / bw
|
||||
// this is effectively a math.Ceil, but using only integer math
|
||||
if diff%bw > 0 {
|
||||
d++
|
||||
}
|
||||
return p.lastSentTime.Add(max(protocol.MinPacingDelay, time.Duration(d)*time.Nanosecond))
|
||||
}
|
||||
|
||||
func (p *pacer) SetMaxDatagramSize(s protocol.ByteCount) {
|
||||
p.maxDatagramSize = s
|
||||
}
|
||||
122
vendor/github.com/quic-go/quic-go/internal/flowcontrol/base_flow_controller.go
generated
vendored
Normal file
122
vendor/github.com/quic-go/quic-go/internal/flowcontrol/base_flow_controller.go
generated
vendored
Normal file
@@ -0,0 +1,122 @@
|
||||
package flowcontrol
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/monotime"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
type baseFlowController struct {
|
||||
// for sending data
|
||||
bytesSent protocol.ByteCount
|
||||
sendWindow protocol.ByteCount
|
||||
lastBlockedAt protocol.ByteCount
|
||||
|
||||
// for receiving data
|
||||
//nolint:structcheck // The mutex is used both by the stream and the connection flow controller
|
||||
mutex sync.Mutex
|
||||
bytesRead protocol.ByteCount
|
||||
highestReceived protocol.ByteCount
|
||||
receiveWindow protocol.ByteCount
|
||||
receiveWindowSize protocol.ByteCount
|
||||
maxReceiveWindowSize protocol.ByteCount
|
||||
|
||||
allowWindowIncrease func(size protocol.ByteCount) bool
|
||||
|
||||
epochStartTime monotime.Time
|
||||
epochStartOffset protocol.ByteCount
|
||||
rttStats *utils.RTTStats
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
// IsNewlyBlocked says if it is newly blocked by flow control.
|
||||
// For every offset, it only returns true once.
|
||||
// If it is blocked, the offset is returned.
|
||||
func (c *baseFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) {
|
||||
if c.SendWindowSize() != 0 || c.sendWindow == c.lastBlockedAt {
|
||||
return false, 0
|
||||
}
|
||||
c.lastBlockedAt = c.sendWindow
|
||||
return true, c.sendWindow
|
||||
}
|
||||
|
||||
func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) {
|
||||
c.bytesSent += n
|
||||
}
|
||||
|
||||
// UpdateSendWindow is called after receiving a MAX_{STREAM_}DATA frame.
|
||||
func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) (updated bool) {
|
||||
if offset > c.sendWindow {
|
||||
c.sendWindow = offset
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *baseFlowController) SendWindowSize() protocol.ByteCount {
|
||||
// this only happens during connection establishment, when data is sent before we receive the peer's transport parameters
|
||||
if c.bytesSent > c.sendWindow {
|
||||
return 0
|
||||
}
|
||||
return c.sendWindow - c.bytesSent
|
||||
}
|
||||
|
||||
// needs to be called with locked mutex
|
||||
func (c *baseFlowController) addBytesRead(n protocol.ByteCount) {
|
||||
c.bytesRead += n
|
||||
}
|
||||
|
||||
func (c *baseFlowController) hasWindowUpdate() bool {
|
||||
bytesRemaining := c.receiveWindow - c.bytesRead
|
||||
// update the window when more than the threshold was consumed
|
||||
return bytesRemaining <= protocol.ByteCount(float64(c.receiveWindowSize)*(1-protocol.WindowUpdateThreshold))
|
||||
}
|
||||
|
||||
// getWindowUpdate updates the receive window, if necessary
|
||||
// it returns the new offset
|
||||
func (c *baseFlowController) getWindowUpdate(now monotime.Time) protocol.ByteCount {
|
||||
if !c.hasWindowUpdate() {
|
||||
return 0
|
||||
}
|
||||
|
||||
c.maybeAdjustWindowSize(now)
|
||||
c.receiveWindow = c.bytesRead + c.receiveWindowSize
|
||||
return c.receiveWindow
|
||||
}
|
||||
|
||||
// maybeAdjustWindowSize increases the receiveWindowSize if we're sending updates too often.
|
||||
// For details about auto-tuning, see https://docs.google.com/document/d/1SExkMmGiz8VYzV3s9E35JQlJ73vhzCekKkDi85F1qCE/edit?usp=sharing.
|
||||
func (c *baseFlowController) maybeAdjustWindowSize(now monotime.Time) {
|
||||
bytesReadInEpoch := c.bytesRead - c.epochStartOffset
|
||||
// don't do anything if less than half the window has been consumed
|
||||
if bytesReadInEpoch <= c.receiveWindowSize/2 {
|
||||
return
|
||||
}
|
||||
rtt := c.rttStats.SmoothedRTT()
|
||||
if rtt == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
fraction := float64(bytesReadInEpoch) / float64(c.receiveWindowSize)
|
||||
if now.Sub(c.epochStartTime) < time.Duration(4*fraction*float64(rtt)) {
|
||||
// window is consumed too fast, try to increase the window size
|
||||
newSize := min(2*c.receiveWindowSize, c.maxReceiveWindowSize)
|
||||
if newSize > c.receiveWindowSize && (c.allowWindowIncrease == nil || c.allowWindowIncrease(newSize-c.receiveWindowSize)) {
|
||||
c.receiveWindowSize = newSize
|
||||
}
|
||||
}
|
||||
c.startNewAutoTuningEpoch(now)
|
||||
}
|
||||
|
||||
func (c *baseFlowController) startNewAutoTuningEpoch(now monotime.Time) {
|
||||
c.epochStartTime = now
|
||||
c.epochStartOffset = c.bytesRead
|
||||
}
|
||||
|
||||
func (c *baseFlowController) checkFlowControlViolation() bool {
|
||||
return c.highestReceived > c.receiveWindow
|
||||
}
|
||||
113
vendor/github.com/quic-go/quic-go/internal/flowcontrol/connection_flow_controller.go
generated
vendored
Normal file
113
vendor/github.com/quic-go/quic-go/internal/flowcontrol/connection_flow_controller.go
generated
vendored
Normal file
@@ -0,0 +1,113 @@
|
||||
package flowcontrol
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/monotime"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
type connectionFlowController struct {
|
||||
baseFlowController
|
||||
}
|
||||
|
||||
var _ ConnectionFlowController = &connectionFlowController{}
|
||||
|
||||
// NewConnectionFlowController gets a new flow controller for the connection
|
||||
// It is created before we receive the peer's transport parameters, thus it starts with a sendWindow of 0.
|
||||
func NewConnectionFlowController(
|
||||
receiveWindow protocol.ByteCount,
|
||||
maxReceiveWindow protocol.ByteCount,
|
||||
allowWindowIncrease func(size protocol.ByteCount) bool,
|
||||
rttStats *utils.RTTStats,
|
||||
logger utils.Logger,
|
||||
) *connectionFlowController {
|
||||
return &connectionFlowController{
|
||||
baseFlowController: baseFlowController{
|
||||
rttStats: rttStats,
|
||||
receiveWindow: receiveWindow,
|
||||
receiveWindowSize: receiveWindow,
|
||||
maxReceiveWindowSize: maxReceiveWindow,
|
||||
allowWindowIncrease: allowWindowIncrease,
|
||||
logger: logger,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// IncrementHighestReceived adds an increment to the highestReceived value
|
||||
func (c *connectionFlowController) IncrementHighestReceived(increment protocol.ByteCount, now monotime.Time) error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
// If this is the first frame received on this connection, start flow-control auto-tuning.
|
||||
if c.highestReceived == 0 {
|
||||
c.startNewAutoTuningEpoch(now)
|
||||
}
|
||||
c.highestReceived += increment
|
||||
|
||||
if c.checkFlowControlViolation() {
|
||||
return &qerr.TransportError{
|
||||
ErrorCode: qerr.FlowControlError,
|
||||
ErrorMessage: fmt.Sprintf("received %d bytes for the connection, allowed %d bytes", c.highestReceived, c.receiveWindow),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *connectionFlowController) AddBytesRead(n protocol.ByteCount) (hasWindowUpdate bool) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
c.addBytesRead(n)
|
||||
return c.hasWindowUpdate()
|
||||
}
|
||||
|
||||
func (c *connectionFlowController) GetWindowUpdate(now monotime.Time) protocol.ByteCount {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
oldWindowSize := c.receiveWindowSize
|
||||
offset := c.getWindowUpdate(now)
|
||||
if c.logger.Debug() && oldWindowSize < c.receiveWindowSize {
|
||||
c.logger.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10))
|
||||
}
|
||||
return offset
|
||||
}
|
||||
|
||||
// EnsureMinimumWindowSize sets a minimum window size
|
||||
// it should make sure that the connection-level window is increased when a stream-level window grows
|
||||
func (c *connectionFlowController) EnsureMinimumWindowSize(inc protocol.ByteCount, now monotime.Time) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
if inc <= c.receiveWindowSize {
|
||||
return
|
||||
}
|
||||
newSize := min(inc, c.maxReceiveWindowSize)
|
||||
if delta := newSize - c.receiveWindowSize; delta > 0 && c.allowWindowIncrease(delta) {
|
||||
c.receiveWindowSize = newSize
|
||||
if c.logger.Debug() {
|
||||
c.logger.Debugf("Increasing receive flow control window for the connection to %d, in response to stream flow control window increase", newSize)
|
||||
}
|
||||
}
|
||||
c.startNewAutoTuningEpoch(now)
|
||||
}
|
||||
|
||||
// Reset rests the flow controller. This happens when 0-RTT is rejected.
|
||||
// All stream data is invalidated, it's as if we had never opened a stream and never sent any data.
|
||||
// At that point, we only have sent stream data, but we didn't have the keys to open 1-RTT keys yet.
|
||||
func (c *connectionFlowController) Reset() error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
if c.bytesRead > 0 || c.highestReceived > 0 || !c.epochStartTime.IsZero() {
|
||||
return errors.New("flow controller reset after reading data")
|
||||
}
|
||||
c.bytesSent = 0
|
||||
c.lastBlockedAt = 0
|
||||
c.sendWindow = 0
|
||||
return nil
|
||||
}
|
||||
46
vendor/github.com/quic-go/quic-go/internal/flowcontrol/interface.go
generated
vendored
Normal file
46
vendor/github.com/quic-go/quic-go/internal/flowcontrol/interface.go
generated
vendored
Normal file
@@ -0,0 +1,46 @@
|
||||
package flowcontrol
|
||||
|
||||
import (
|
||||
"github.com/quic-go/quic-go/internal/monotime"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
type flowController interface {
|
||||
// for sending
|
||||
SendWindowSize() protocol.ByteCount
|
||||
UpdateSendWindow(protocol.ByteCount) (updated bool)
|
||||
AddBytesSent(protocol.ByteCount)
|
||||
// for receiving
|
||||
GetWindowUpdate(monotime.Time) protocol.ByteCount // returns 0 if no update is necessary
|
||||
}
|
||||
|
||||
// A StreamFlowController is a flow controller for a QUIC stream.
|
||||
type StreamFlowController interface {
|
||||
flowController
|
||||
AddBytesRead(protocol.ByteCount) (hasStreamWindowUpdate, hasConnWindowUpdate bool)
|
||||
// UpdateHighestReceived is called when a new highest offset is received
|
||||
// final has to be to true if this is the final offset of the stream,
|
||||
// as contained in a STREAM frame with FIN bit, and the RESET_STREAM frame
|
||||
UpdateHighestReceived(offset protocol.ByteCount, final bool, now monotime.Time) error
|
||||
// Abandon is called when reading from the stream is aborted early,
|
||||
// and there won't be any further calls to AddBytesRead.
|
||||
Abandon()
|
||||
IsNewlyBlocked() bool
|
||||
}
|
||||
|
||||
// The ConnectionFlowController is the flow controller for the connection.
|
||||
type ConnectionFlowController interface {
|
||||
flowController
|
||||
AddBytesRead(protocol.ByteCount) (hasWindowUpdate bool)
|
||||
Reset() error
|
||||
IsNewlyBlocked() (bool, protocol.ByteCount)
|
||||
}
|
||||
|
||||
type connectionFlowControllerI interface {
|
||||
ConnectionFlowController
|
||||
// The following two methods are not supposed to be called from outside this packet, but are needed internally
|
||||
// for sending
|
||||
EnsureMinimumWindowSize(protocol.ByteCount, monotime.Time)
|
||||
// for receiving
|
||||
IncrementHighestReceived(protocol.ByteCount, monotime.Time) error
|
||||
}
|
||||
154
vendor/github.com/quic-go/quic-go/internal/flowcontrol/stream_flow_controller.go
generated
vendored
Normal file
154
vendor/github.com/quic-go/quic-go/internal/flowcontrol/stream_flow_controller.go
generated
vendored
Normal file
@@ -0,0 +1,154 @@
|
||||
package flowcontrol
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/monotime"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
type streamFlowController struct {
|
||||
baseFlowController
|
||||
|
||||
streamID protocol.StreamID
|
||||
|
||||
connection connectionFlowControllerI
|
||||
|
||||
receivedFinalOffset bool
|
||||
}
|
||||
|
||||
var _ StreamFlowController = &streamFlowController{}
|
||||
|
||||
// NewStreamFlowController gets a new flow controller for a stream
|
||||
func NewStreamFlowController(
|
||||
streamID protocol.StreamID,
|
||||
cfc ConnectionFlowController,
|
||||
receiveWindow protocol.ByteCount,
|
||||
maxReceiveWindow protocol.ByteCount,
|
||||
initialSendWindow protocol.ByteCount,
|
||||
rttStats *utils.RTTStats,
|
||||
logger utils.Logger,
|
||||
) StreamFlowController {
|
||||
return &streamFlowController{
|
||||
streamID: streamID,
|
||||
connection: cfc.(connectionFlowControllerI),
|
||||
baseFlowController: baseFlowController{
|
||||
rttStats: rttStats,
|
||||
receiveWindow: receiveWindow,
|
||||
receiveWindowSize: receiveWindow,
|
||||
maxReceiveWindowSize: maxReceiveWindow,
|
||||
sendWindow: initialSendWindow,
|
||||
logger: logger,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateHighestReceived updates the highestReceived value, if the offset is higher.
|
||||
func (c *streamFlowController) UpdateHighestReceived(offset protocol.ByteCount, final bool, now monotime.Time) error {
|
||||
// If the final offset for this stream is already known, check for consistency.
|
||||
if c.receivedFinalOffset {
|
||||
// If we receive another final offset, check that it's the same.
|
||||
if final && offset != c.highestReceived {
|
||||
return &qerr.TransportError{
|
||||
ErrorCode: qerr.FinalSizeError,
|
||||
ErrorMessage: fmt.Sprintf("received inconsistent final offset for stream %d (old: %d, new: %d bytes)", c.streamID, c.highestReceived, offset),
|
||||
}
|
||||
}
|
||||
// Check that the offset is below the final offset.
|
||||
if offset > c.highestReceived {
|
||||
return &qerr.TransportError{
|
||||
ErrorCode: qerr.FinalSizeError,
|
||||
ErrorMessage: fmt.Sprintf("received offset %d for stream %d, but final offset was already received at %d", offset, c.streamID, c.highestReceived),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if final {
|
||||
c.receivedFinalOffset = true
|
||||
}
|
||||
if offset == c.highestReceived {
|
||||
return nil
|
||||
}
|
||||
// A higher offset was received before. This can happen due to reordering.
|
||||
if offset < c.highestReceived {
|
||||
if final {
|
||||
return &qerr.TransportError{
|
||||
ErrorCode: qerr.FinalSizeError,
|
||||
ErrorMessage: fmt.Sprintf("received final offset %d for stream %d, but already received offset %d before", offset, c.streamID, c.highestReceived),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// If this is the first frame received for this stream, start flow-control auto-tuning.
|
||||
if c.highestReceived == 0 {
|
||||
c.startNewAutoTuningEpoch(now)
|
||||
}
|
||||
increment := offset - c.highestReceived
|
||||
c.highestReceived = offset
|
||||
|
||||
if c.checkFlowControlViolation() {
|
||||
return &qerr.TransportError{
|
||||
ErrorCode: qerr.FlowControlError,
|
||||
ErrorMessage: fmt.Sprintf("received %d bytes on stream %d, allowed %d bytes", offset, c.streamID, c.receiveWindow),
|
||||
}
|
||||
}
|
||||
return c.connection.IncrementHighestReceived(increment, now)
|
||||
}
|
||||
|
||||
func (c *streamFlowController) AddBytesRead(n protocol.ByteCount) (hasStreamWindowUpdate, hasConnWindowUpdate bool) {
|
||||
c.mutex.Lock()
|
||||
c.addBytesRead(n)
|
||||
hasStreamWindowUpdate = c.shouldQueueWindowUpdate()
|
||||
c.mutex.Unlock()
|
||||
hasConnWindowUpdate = c.connection.AddBytesRead(n)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *streamFlowController) Abandon() {
|
||||
c.mutex.Lock()
|
||||
unread := c.highestReceived - c.bytesRead
|
||||
c.bytesRead = c.highestReceived
|
||||
c.mutex.Unlock()
|
||||
if unread > 0 {
|
||||
c.connection.AddBytesRead(unread)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *streamFlowController) AddBytesSent(n protocol.ByteCount) {
|
||||
c.baseFlowController.AddBytesSent(n)
|
||||
c.connection.AddBytesSent(n)
|
||||
}
|
||||
|
||||
func (c *streamFlowController) SendWindowSize() protocol.ByteCount {
|
||||
return min(c.baseFlowController.SendWindowSize(), c.connection.SendWindowSize())
|
||||
}
|
||||
|
||||
func (c *streamFlowController) IsNewlyBlocked() bool {
|
||||
blocked, _ := c.baseFlowController.IsNewlyBlocked()
|
||||
return blocked
|
||||
}
|
||||
|
||||
func (c *streamFlowController) shouldQueueWindowUpdate() bool {
|
||||
return !c.receivedFinalOffset && c.hasWindowUpdate()
|
||||
}
|
||||
|
||||
func (c *streamFlowController) GetWindowUpdate(now monotime.Time) protocol.ByteCount {
|
||||
// If we already received the final offset for this stream, the peer won't need any additional flow control credit.
|
||||
if c.receivedFinalOffset {
|
||||
return 0
|
||||
}
|
||||
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
oldWindowSize := c.receiveWindowSize
|
||||
offset := c.getWindowUpdate(now)
|
||||
if c.receiveWindowSize > oldWindowSize { // auto-tuning enlarged the window size
|
||||
c.logger.Debugf("Increasing receive flow control window for stream %d to %d", c.streamID, c.receiveWindowSize)
|
||||
c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize)*protocol.ConnectionFlowControlMultiplier), now)
|
||||
}
|
||||
return offset
|
||||
}
|
||||
90
vendor/github.com/quic-go/quic-go/internal/handshake/aead.go
generated
vendored
Normal file
90
vendor/github.com/quic-go/quic-go/internal/handshake/aead.go
generated
vendored
Normal file
@@ -0,0 +1,90 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
func createAEAD(suite cipherSuite, trafficSecret []byte, v protocol.Version) *xorNonceAEAD {
|
||||
keyLabel := hkdfLabelKeyV1
|
||||
ivLabel := hkdfLabelIVV1
|
||||
if v == protocol.Version2 {
|
||||
keyLabel = hkdfLabelKeyV2
|
||||
ivLabel = hkdfLabelIVV2
|
||||
}
|
||||
key := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, keyLabel, suite.KeyLen)
|
||||
iv := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, ivLabel, suite.IVLen())
|
||||
return suite.AEAD(key, iv)
|
||||
}
|
||||
|
||||
type longHeaderSealer struct {
|
||||
aead *xorNonceAEAD
|
||||
headerProtector headerProtector
|
||||
nonceBuf [8]byte
|
||||
}
|
||||
|
||||
var _ LongHeaderSealer = &longHeaderSealer{}
|
||||
|
||||
func newLongHeaderSealer(aead *xorNonceAEAD, headerProtector headerProtector) LongHeaderSealer {
|
||||
if aead.NonceSize() != 8 {
|
||||
panic("unexpected nonce size")
|
||||
}
|
||||
return &longHeaderSealer{
|
||||
aead: aead,
|
||||
headerProtector: headerProtector,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *longHeaderSealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte {
|
||||
binary.BigEndian.PutUint64(s.nonceBuf[:], uint64(pn))
|
||||
return s.aead.Seal(dst, s.nonceBuf[:], src, ad)
|
||||
}
|
||||
|
||||
func (s *longHeaderSealer) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
|
||||
s.headerProtector.EncryptHeader(sample, firstByte, pnBytes)
|
||||
}
|
||||
|
||||
func (s *longHeaderSealer) Overhead() int {
|
||||
return s.aead.Overhead()
|
||||
}
|
||||
|
||||
type longHeaderOpener struct {
|
||||
aead *xorNonceAEAD
|
||||
headerProtector headerProtector
|
||||
highestRcvdPN protocol.PacketNumber // highest packet number received (which could be successfully unprotected)
|
||||
|
||||
// use a single array to avoid allocations
|
||||
nonceBuf [8]byte
|
||||
}
|
||||
|
||||
var _ LongHeaderOpener = &longHeaderOpener{}
|
||||
|
||||
func newLongHeaderOpener(aead *xorNonceAEAD, headerProtector headerProtector) LongHeaderOpener {
|
||||
if aead.NonceSize() != 8 {
|
||||
panic("unexpected nonce size")
|
||||
}
|
||||
return &longHeaderOpener{
|
||||
aead: aead,
|
||||
headerProtector: headerProtector,
|
||||
}
|
||||
}
|
||||
|
||||
func (o *longHeaderOpener) DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber {
|
||||
return protocol.DecodePacketNumber(wirePNLen, o.highestRcvdPN, wirePN)
|
||||
}
|
||||
|
||||
func (o *longHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
|
||||
binary.BigEndian.PutUint64(o.nonceBuf[:], uint64(pn))
|
||||
dec, err := o.aead.Open(dst, o.nonceBuf[:], src, ad)
|
||||
if err == nil {
|
||||
o.highestRcvdPN = max(o.highestRcvdPN, pn)
|
||||
} else {
|
||||
err = ErrDecryptionFailed
|
||||
}
|
||||
return dec, err
|
||||
}
|
||||
|
||||
func (o *longHeaderOpener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
|
||||
o.headerProtector.DecryptHeader(sample, firstByte, pnBytes)
|
||||
}
|
||||
104
vendor/github.com/quic-go/quic-go/internal/handshake/cipher_suite.go
generated
vendored
Normal file
104
vendor/github.com/quic-go/quic-go/internal/handshake/cipher_suite.go
generated
vendored
Normal file
@@ -0,0 +1,104 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
)
|
||||
|
||||
// These cipher suite implementations are copied from the standard library crypto/tls package.
|
||||
|
||||
const aeadNonceLength = 12
|
||||
|
||||
type cipherSuite struct {
|
||||
ID uint16
|
||||
Hash crypto.Hash
|
||||
KeyLen int
|
||||
AEAD func(key, nonceMask []byte) *xorNonceAEAD
|
||||
}
|
||||
|
||||
func (s cipherSuite) IVLen() int { return aeadNonceLength }
|
||||
|
||||
func getCipherSuite(id uint16) cipherSuite {
|
||||
switch id {
|
||||
case tls.TLS_AES_128_GCM_SHA256:
|
||||
return cipherSuite{ID: tls.TLS_AES_128_GCM_SHA256, Hash: crypto.SHA256, KeyLen: 16, AEAD: aeadAESGCMTLS13}
|
||||
case tls.TLS_CHACHA20_POLY1305_SHA256:
|
||||
return cipherSuite{ID: tls.TLS_CHACHA20_POLY1305_SHA256, Hash: crypto.SHA256, KeyLen: 32, AEAD: aeadChaCha20Poly1305}
|
||||
case tls.TLS_AES_256_GCM_SHA384:
|
||||
return cipherSuite{ID: tls.TLS_AES_256_GCM_SHA384, Hash: crypto.SHA384, KeyLen: 32, AEAD: aeadAESGCMTLS13}
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown cypher suite: %d", id))
|
||||
}
|
||||
}
|
||||
|
||||
func aeadAESGCMTLS13(key, nonceMask []byte) *xorNonceAEAD {
|
||||
if len(nonceMask) != aeadNonceLength {
|
||||
panic("tls: internal error: wrong nonce length")
|
||||
}
|
||||
aes, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
aead, err := cipher.NewGCM(aes)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ret := &xorNonceAEAD{aead: aead}
|
||||
copy(ret.nonceMask[:], nonceMask)
|
||||
return ret
|
||||
}
|
||||
|
||||
func aeadChaCha20Poly1305(key, nonceMask []byte) *xorNonceAEAD {
|
||||
if len(nonceMask) != aeadNonceLength {
|
||||
panic("tls: internal error: wrong nonce length")
|
||||
}
|
||||
aead, err := chacha20poly1305.New(key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ret := &xorNonceAEAD{aead: aead}
|
||||
copy(ret.nonceMask[:], nonceMask)
|
||||
return ret
|
||||
}
|
||||
|
||||
// xorNonceAEAD wraps an AEAD by XORing in a fixed pattern to the nonce
|
||||
// before each call.
|
||||
type xorNonceAEAD struct {
|
||||
nonceMask [aeadNonceLength]byte
|
||||
aead cipher.AEAD
|
||||
}
|
||||
|
||||
func (f *xorNonceAEAD) NonceSize() int { return 8 } // 64-bit sequence number
|
||||
func (f *xorNonceAEAD) Overhead() int { return f.aead.Overhead() }
|
||||
func (f *xorNonceAEAD) explicitNonceLen() int { return 0 }
|
||||
|
||||
func (f *xorNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte {
|
||||
for i, b := range nonce {
|
||||
f.nonceMask[4+i] ^= b
|
||||
}
|
||||
result := f.aead.Seal(out, f.nonceMask[:], plaintext, additionalData)
|
||||
for i, b := range nonce {
|
||||
f.nonceMask[4+i] ^= b
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (f *xorNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) {
|
||||
for i, b := range nonce {
|
||||
f.nonceMask[4+i] ^= b
|
||||
}
|
||||
result, err := f.aead.Open(out, f.nonceMask[:], ciphertext, additionalData)
|
||||
for i, b := range nonce {
|
||||
f.nonceMask[4+i] ^= b
|
||||
}
|
||||
|
||||
return result, err
|
||||
}
|
||||
720
vendor/github.com/quic-go/quic-go/internal/handshake/crypto_setup.go
generated
vendored
Normal file
720
vendor/github.com/quic-go/quic-go/internal/handshake/crypto_setup.go
generated
vendored
Normal file
@@ -0,0 +1,720 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
"github.com/quic-go/quic-go/qlog"
|
||||
"github.com/quic-go/quic-go/qlogwriter"
|
||||
"github.com/quic-go/quic-go/quicvarint"
|
||||
)
|
||||
|
||||
type quicVersionContextKey struct{}
|
||||
|
||||
var QUICVersionContextKey = &quicVersionContextKey{}
|
||||
|
||||
const clientSessionStateRevision = 5
|
||||
|
||||
type cryptoSetup struct {
|
||||
tlsConf *tls.Config
|
||||
conn *tls.QUICConn
|
||||
|
||||
events []Event
|
||||
|
||||
version protocol.Version
|
||||
|
||||
ourParams *wire.TransportParameters
|
||||
peerParams *wire.TransportParameters
|
||||
|
||||
zeroRTTParameters *wire.TransportParameters
|
||||
allow0RTT bool
|
||||
|
||||
rttStats *utils.RTTStats
|
||||
|
||||
qlogger qlogwriter.Recorder
|
||||
logger utils.Logger
|
||||
|
||||
perspective protocol.Perspective
|
||||
|
||||
handshakeCompleteTime time.Time
|
||||
|
||||
zeroRTTOpener LongHeaderOpener // only set for the server
|
||||
zeroRTTSealer LongHeaderSealer // only set for the client
|
||||
|
||||
initialOpener LongHeaderOpener
|
||||
initialSealer LongHeaderSealer
|
||||
|
||||
handshakeOpener LongHeaderOpener
|
||||
handshakeSealer LongHeaderSealer
|
||||
|
||||
used0RTT atomic.Bool
|
||||
|
||||
aead *updatableAEAD
|
||||
has1RTTSealer bool
|
||||
has1RTTOpener bool
|
||||
}
|
||||
|
||||
var _ CryptoSetup = &cryptoSetup{}
|
||||
|
||||
// NewCryptoSetupClient creates a new crypto setup for the client
|
||||
func NewCryptoSetupClient(
|
||||
connID protocol.ConnectionID,
|
||||
tp *wire.TransportParameters,
|
||||
tlsConf *tls.Config,
|
||||
enable0RTT bool,
|
||||
rttStats *utils.RTTStats,
|
||||
qlogger qlogwriter.Recorder,
|
||||
logger utils.Logger,
|
||||
version protocol.Version,
|
||||
) CryptoSetup {
|
||||
cs := newCryptoSetup(
|
||||
connID,
|
||||
tp,
|
||||
rttStats,
|
||||
qlogger,
|
||||
logger,
|
||||
protocol.PerspectiveClient,
|
||||
version,
|
||||
)
|
||||
|
||||
tlsConf = tlsConf.Clone()
|
||||
tlsConf.MinVersion = tls.VersionTLS13
|
||||
cs.tlsConf = tlsConf
|
||||
cs.allow0RTT = enable0RTT
|
||||
|
||||
cs.conn = tls.QUICClient(&tls.QUICConfig{
|
||||
TLSConfig: tlsConf,
|
||||
EnableSessionEvents: true,
|
||||
})
|
||||
cs.conn.SetTransportParameters(cs.ourParams.Marshal(protocol.PerspectiveClient))
|
||||
|
||||
return cs
|
||||
}
|
||||
|
||||
// NewCryptoSetupServer creates a new crypto setup for the server
|
||||
func NewCryptoSetupServer(
|
||||
connID protocol.ConnectionID,
|
||||
localAddr, remoteAddr net.Addr,
|
||||
tp *wire.TransportParameters,
|
||||
tlsConf *tls.Config,
|
||||
allow0RTT bool,
|
||||
rttStats *utils.RTTStats,
|
||||
qlogger qlogwriter.Recorder,
|
||||
logger utils.Logger,
|
||||
version protocol.Version,
|
||||
) CryptoSetup {
|
||||
cs := newCryptoSetup(
|
||||
connID,
|
||||
tp,
|
||||
rttStats,
|
||||
qlogger,
|
||||
logger,
|
||||
protocol.PerspectiveServer,
|
||||
version,
|
||||
)
|
||||
cs.allow0RTT = allow0RTT
|
||||
|
||||
tlsConf = setupConfigForServer(tlsConf, localAddr, remoteAddr)
|
||||
|
||||
cs.tlsConf = tlsConf
|
||||
cs.conn = tls.QUICServer(&tls.QUICConfig{
|
||||
TLSConfig: tlsConf,
|
||||
EnableSessionEvents: true,
|
||||
})
|
||||
return cs
|
||||
}
|
||||
|
||||
func newCryptoSetup(
|
||||
connID protocol.ConnectionID,
|
||||
tp *wire.TransportParameters,
|
||||
rttStats *utils.RTTStats,
|
||||
qlogger qlogwriter.Recorder,
|
||||
logger utils.Logger,
|
||||
perspective protocol.Perspective,
|
||||
version protocol.Version,
|
||||
) *cryptoSetup {
|
||||
initialSealer, initialOpener := NewInitialAEAD(connID, perspective, version)
|
||||
if qlogger != nil {
|
||||
qlogger.RecordEvent(qlog.KeyUpdated{
|
||||
Trigger: qlog.KeyUpdateTLS,
|
||||
KeyType: encLevelToKeyType(protocol.EncryptionInitial, protocol.PerspectiveClient),
|
||||
})
|
||||
qlogger.RecordEvent(qlog.KeyUpdated{
|
||||
Trigger: qlog.KeyUpdateTLS,
|
||||
KeyType: encLevelToKeyType(protocol.EncryptionInitial, protocol.PerspectiveServer),
|
||||
})
|
||||
}
|
||||
return &cryptoSetup{
|
||||
initialSealer: initialSealer,
|
||||
initialOpener: initialOpener,
|
||||
aead: newUpdatableAEAD(rttStats, qlogger, logger, version),
|
||||
events: make([]Event, 0, 16),
|
||||
ourParams: tp,
|
||||
rttStats: rttStats,
|
||||
qlogger: qlogger,
|
||||
logger: logger,
|
||||
perspective: perspective,
|
||||
version: version,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) ChangeConnectionID(id protocol.ConnectionID) {
|
||||
initialSealer, initialOpener := NewInitialAEAD(id, h.perspective, h.version)
|
||||
h.initialSealer = initialSealer
|
||||
h.initialOpener = initialOpener
|
||||
if h.qlogger != nil {
|
||||
h.qlogger.RecordEvent(qlog.KeyUpdated{
|
||||
Trigger: qlog.KeyUpdateTLS,
|
||||
KeyType: encLevelToKeyType(protocol.EncryptionInitial, protocol.PerspectiveClient),
|
||||
})
|
||||
h.qlogger.RecordEvent(qlog.KeyUpdated{
|
||||
Trigger: qlog.KeyUpdateTLS,
|
||||
KeyType: encLevelToKeyType(protocol.EncryptionInitial, protocol.PerspectiveServer),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) SetLargest1RTTAcked(pn protocol.PacketNumber) error {
|
||||
return h.aead.SetLargestAcked(pn)
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) StartHandshake(ctx context.Context) error {
|
||||
err := h.conn.Start(context.WithValue(ctx, QUICVersionContextKey, h.version))
|
||||
if err != nil {
|
||||
return wrapError(err)
|
||||
}
|
||||
for {
|
||||
ev := h.conn.NextEvent()
|
||||
if err := h.handleEvent(ev); err != nil {
|
||||
return wrapError(err)
|
||||
}
|
||||
if ev.Kind == tls.QUICNoEvent {
|
||||
break
|
||||
}
|
||||
}
|
||||
if h.perspective == protocol.PerspectiveClient {
|
||||
if h.zeroRTTSealer != nil && h.zeroRTTParameters != nil {
|
||||
h.logger.Debugf("Doing 0-RTT.")
|
||||
h.events = append(h.events, Event{Kind: EventRestoredTransportParameters, TransportParameters: h.zeroRTTParameters})
|
||||
} else {
|
||||
h.logger.Debugf("Not doing 0-RTT. Has sealer: %t, has params: %t", h.zeroRTTSealer != nil, h.zeroRTTParameters != nil)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the crypto setup.
|
||||
// It aborts the handshake, if it is still running.
|
||||
func (h *cryptoSetup) Close() error {
|
||||
return h.conn.Close()
|
||||
}
|
||||
|
||||
// HandleMessage handles a TLS handshake message.
|
||||
// It is called by the crypto streams when a new message is available.
|
||||
func (h *cryptoSetup) HandleMessage(data []byte, encLevel protocol.EncryptionLevel) error {
|
||||
if err := h.handleMessage(data, encLevel); err != nil {
|
||||
return wrapError(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) handleMessage(data []byte, encLevel protocol.EncryptionLevel) error {
|
||||
if err := h.conn.HandleData(encLevel.ToTLSEncryptionLevel(), data); err != nil {
|
||||
return err
|
||||
}
|
||||
for {
|
||||
ev := h.conn.NextEvent()
|
||||
if err := h.handleEvent(ev); err != nil {
|
||||
return err
|
||||
}
|
||||
if ev.Kind == tls.QUICNoEvent {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) handleEvent(ev tls.QUICEvent) (err error) {
|
||||
switch ev.Kind {
|
||||
case tls.QUICNoEvent:
|
||||
return nil
|
||||
case tls.QUICSetReadSecret:
|
||||
h.setReadKey(ev.Level, ev.Suite, ev.Data)
|
||||
return nil
|
||||
case tls.QUICSetWriteSecret:
|
||||
h.setWriteKey(ev.Level, ev.Suite, ev.Data)
|
||||
return nil
|
||||
case tls.QUICTransportParameters:
|
||||
return h.handleTransportParameters(ev.Data)
|
||||
case tls.QUICTransportParametersRequired:
|
||||
h.conn.SetTransportParameters(h.ourParams.Marshal(h.perspective))
|
||||
return nil
|
||||
case tls.QUICRejectedEarlyData:
|
||||
h.rejected0RTT()
|
||||
return nil
|
||||
case tls.QUICWriteData:
|
||||
h.writeRecord(ev.Level, ev.Data)
|
||||
return nil
|
||||
case tls.QUICHandshakeDone:
|
||||
h.handshakeComplete()
|
||||
return nil
|
||||
case tls.QUICStoreSession:
|
||||
if h.perspective == protocol.PerspectiveServer {
|
||||
panic("cryptoSetup BUG: unexpected QUICStoreSession event for the server")
|
||||
}
|
||||
ev.SessionState.Extra = append(
|
||||
ev.SessionState.Extra,
|
||||
addSessionStateExtraPrefix(h.marshalDataForSessionState(ev.SessionState.EarlyData)),
|
||||
)
|
||||
return h.conn.StoreSession(ev.SessionState)
|
||||
case tls.QUICResumeSession:
|
||||
var allowEarlyData bool
|
||||
switch h.perspective {
|
||||
case protocol.PerspectiveClient:
|
||||
// for clients, this event occurs when a session ticket is selected
|
||||
allowEarlyData = h.handleDataFromSessionState(
|
||||
findSessionStateExtraData(ev.SessionState.Extra),
|
||||
ev.SessionState.EarlyData,
|
||||
)
|
||||
case protocol.PerspectiveServer:
|
||||
// for servers, this event occurs when receiving the client's session ticket
|
||||
allowEarlyData = h.handleSessionTicket(
|
||||
findSessionStateExtraData(ev.SessionState.Extra),
|
||||
ev.SessionState.EarlyData,
|
||||
)
|
||||
}
|
||||
if ev.SessionState.EarlyData {
|
||||
ev.SessionState.EarlyData = allowEarlyData
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
// Unknown events should be ignored.
|
||||
// crypto/tls will ensure that this is safe to do.
|
||||
// See the discussion following https://github.com/golang/go/issues/68124#issuecomment-2187042510 for details.
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) NextEvent() Event {
|
||||
if len(h.events) == 0 {
|
||||
return Event{Kind: EventNoEvent}
|
||||
}
|
||||
ev := h.events[0]
|
||||
h.events = h.events[1:]
|
||||
return ev
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) handleTransportParameters(data []byte) error {
|
||||
var tp wire.TransportParameters
|
||||
if err := tp.Unmarshal(data, h.perspective.Opposite()); err != nil {
|
||||
return err
|
||||
}
|
||||
h.peerParams = &tp
|
||||
h.events = append(h.events, Event{Kind: EventReceivedTransportParameters, TransportParameters: h.peerParams})
|
||||
return nil
|
||||
}
|
||||
|
||||
// must be called after receiving the transport parameters
|
||||
func (h *cryptoSetup) marshalDataForSessionState(earlyData bool) []byte {
|
||||
b := make([]byte, 0, 256)
|
||||
b = quicvarint.Append(b, clientSessionStateRevision)
|
||||
if earlyData {
|
||||
// only save the transport parameters for 0-RTT enabled session tickets
|
||||
return h.peerParams.MarshalForSessionTicket(b)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) handleDataFromSessionState(data []byte, earlyData bool) (allowEarlyData bool) {
|
||||
tp, err := decodeDataFromSessionState(data, earlyData)
|
||||
if err != nil {
|
||||
h.logger.Debugf("Restoring of transport parameters from session ticket failed: %s", err.Error())
|
||||
return
|
||||
}
|
||||
// The session ticket might have been saved from a connection that allowed 0-RTT,
|
||||
// and therefore contain transport parameters.
|
||||
// Only use them if 0-RTT is actually used on the new connection.
|
||||
if tp != nil && h.allow0RTT {
|
||||
h.zeroRTTParameters = tp
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func decodeDataFromSessionState(b []byte, earlyData bool) (*wire.TransportParameters, error) {
|
||||
ver, l, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b = b[l:]
|
||||
if ver != clientSessionStateRevision {
|
||||
return nil, fmt.Errorf("mismatching version. Got %d, expected %d", ver, clientSessionStateRevision)
|
||||
}
|
||||
if !earlyData {
|
||||
return nil, nil
|
||||
}
|
||||
var tp wire.TransportParameters
|
||||
if err := tp.UnmarshalFromSessionTicket(b); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &tp, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) getDataForSessionTicket() []byte {
|
||||
return (&sessionTicket{
|
||||
Parameters: h.ourParams,
|
||||
}).Marshal()
|
||||
}
|
||||
|
||||
// GetSessionTicket generates a new session ticket.
|
||||
// Due to limitations in crypto/tls, it's only possible to generate a single session ticket per connection.
|
||||
// It is only valid for the server.
|
||||
func (h *cryptoSetup) GetSessionTicket() ([]byte, error) {
|
||||
if err := h.conn.SendSessionTicket(tls.QUICSessionTicketOptions{
|
||||
EarlyData: h.allow0RTT,
|
||||
Extra: [][]byte{addSessionStateExtraPrefix(h.getDataForSessionTicket())},
|
||||
}); err != nil {
|
||||
// Session tickets might be disabled by tls.Config.SessionTicketsDisabled.
|
||||
// We can't check h.tlsConfig here, since the actual config might have been obtained from
|
||||
// the GetConfigForClient callback.
|
||||
// See https://github.com/golang/go/issues/62032.
|
||||
// This error assertion can be removed once we drop support for Go 1.25.
|
||||
if strings.Contains(err.Error(), "session ticket keys unavailable") {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
// If session tickets are disabled, NextEvent will immediately return QUICNoEvent,
|
||||
// and we will return a nil ticket.
|
||||
var ticket []byte
|
||||
for {
|
||||
ev := h.conn.NextEvent()
|
||||
if ev.Kind == tls.QUICNoEvent {
|
||||
break
|
||||
}
|
||||
if ev.Kind == tls.QUICWriteData && ev.Level == tls.QUICEncryptionLevelApplication {
|
||||
if ticket != nil {
|
||||
h.logger.Errorf("unexpected multiple session tickets")
|
||||
continue
|
||||
}
|
||||
ticket = ev.Data
|
||||
} else {
|
||||
h.logger.Errorf("unexpected event: %v", ev.Kind)
|
||||
}
|
||||
}
|
||||
return ticket, nil
|
||||
}
|
||||
|
||||
// handleSessionTicket is called for the server when receiving the client's session ticket.
|
||||
// It reads parameters from the session ticket and checks whether to accept 0-RTT if the session ticket enabled 0-RTT.
|
||||
// Note that the fact that the session ticket allows 0-RTT doesn't mean that the actual TLS handshake enables 0-RTT:
|
||||
// A client may use a 0-RTT enabled session to resume a TLS session without using 0-RTT.
|
||||
func (h *cryptoSetup) handleSessionTicket(data []byte, using0RTT bool) (allowEarlyData bool) {
|
||||
var t sessionTicket
|
||||
if err := t.Unmarshal(data); err != nil {
|
||||
h.logger.Debugf("Unmarshalling session ticket failed: %s", err.Error())
|
||||
return false
|
||||
}
|
||||
if !using0RTT {
|
||||
return false
|
||||
}
|
||||
valid := h.ourParams.ValidFor0RTT(t.Parameters)
|
||||
if !valid {
|
||||
h.logger.Debugf("Transport parameters changed. Rejecting 0-RTT.")
|
||||
return false
|
||||
}
|
||||
if !h.allow0RTT {
|
||||
h.logger.Debugf("0-RTT not allowed. Rejecting 0-RTT.")
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// rejected0RTT is called for the client when the server rejects 0-RTT.
|
||||
func (h *cryptoSetup) rejected0RTT() {
|
||||
h.logger.Debugf("0-RTT was rejected. Dropping 0-RTT keys.")
|
||||
|
||||
had0RTTKeys := h.zeroRTTSealer != nil
|
||||
h.zeroRTTSealer = nil
|
||||
|
||||
if had0RTTKeys {
|
||||
h.events = append(h.events, Event{Kind: EventDiscard0RTTKeys})
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) setReadKey(el tls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) {
|
||||
suite := getCipherSuite(suiteID)
|
||||
//nolint:exhaustive // The TLS stack doesn't export Initial keys.
|
||||
switch el {
|
||||
case tls.QUICEncryptionLevelEarly:
|
||||
if h.perspective == protocol.PerspectiveClient {
|
||||
panic("Received 0-RTT read key for the client")
|
||||
}
|
||||
h.zeroRTTOpener = newLongHeaderOpener(
|
||||
createAEAD(suite, trafficSecret, h.version),
|
||||
newHeaderProtector(suite, trafficSecret, true, h.version),
|
||||
)
|
||||
h.used0RTT.Store(true)
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Installed 0-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID))
|
||||
}
|
||||
case tls.QUICEncryptionLevelHandshake:
|
||||
h.handshakeOpener = newLongHeaderOpener(
|
||||
createAEAD(suite, trafficSecret, h.version),
|
||||
newHeaderProtector(suite, trafficSecret, true, h.version),
|
||||
)
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Installed Handshake Read keys (using %s)", tls.CipherSuiteName(suite.ID))
|
||||
}
|
||||
case tls.QUICEncryptionLevelApplication:
|
||||
h.aead.SetReadKey(suite, trafficSecret)
|
||||
h.has1RTTOpener = true
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Installed 1-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID))
|
||||
}
|
||||
default:
|
||||
panic("unexpected read encryption level")
|
||||
}
|
||||
h.events = append(h.events, Event{Kind: EventReceivedReadKeys})
|
||||
if h.qlogger != nil {
|
||||
h.qlogger.RecordEvent(qlog.KeyUpdated{
|
||||
Trigger: qlog.KeyUpdateTLS,
|
||||
KeyType: encLevelToKeyType(protocol.FromTLSEncryptionLevel(el), h.perspective.Opposite()),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) setWriteKey(el tls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) {
|
||||
suite := getCipherSuite(suiteID)
|
||||
//nolint:exhaustive // The TLS stack doesn't export Initial keys.
|
||||
switch el {
|
||||
case tls.QUICEncryptionLevelEarly:
|
||||
if h.perspective == protocol.PerspectiveServer {
|
||||
panic("Received 0-RTT write key for the server")
|
||||
}
|
||||
h.zeroRTTSealer = newLongHeaderSealer(
|
||||
createAEAD(suite, trafficSecret, h.version),
|
||||
newHeaderProtector(suite, trafficSecret, true, h.version),
|
||||
)
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Installed 0-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID))
|
||||
}
|
||||
if h.qlogger != nil {
|
||||
h.qlogger.RecordEvent(qlog.KeyUpdated{
|
||||
Trigger: qlog.KeyUpdateTLS,
|
||||
KeyType: encLevelToKeyType(protocol.Encryption0RTT, h.perspective),
|
||||
})
|
||||
}
|
||||
// don't set used0RTT here. 0-RTT might still get rejected.
|
||||
return
|
||||
case tls.QUICEncryptionLevelHandshake:
|
||||
h.handshakeSealer = newLongHeaderSealer(
|
||||
createAEAD(suite, trafficSecret, h.version),
|
||||
newHeaderProtector(suite, trafficSecret, true, h.version),
|
||||
)
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Installed Handshake Write keys (using %s)", tls.CipherSuiteName(suite.ID))
|
||||
}
|
||||
case tls.QUICEncryptionLevelApplication:
|
||||
h.aead.SetWriteKey(suite, trafficSecret)
|
||||
h.has1RTTSealer = true
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Installed 1-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID))
|
||||
}
|
||||
if h.zeroRTTSealer != nil {
|
||||
// Once we receive handshake keys, we know that 0-RTT was not rejected.
|
||||
h.used0RTT.Store(true)
|
||||
h.zeroRTTSealer = nil
|
||||
h.logger.Debugf("Dropping 0-RTT keys.")
|
||||
if h.qlogger != nil {
|
||||
h.qlogger.RecordEvent(qlog.KeyDiscarded{KeyType: qlog.KeyTypeClient0RTT})
|
||||
}
|
||||
}
|
||||
default:
|
||||
panic("unexpected write encryption level")
|
||||
}
|
||||
if h.qlogger != nil {
|
||||
h.qlogger.RecordEvent(qlog.KeyUpdated{
|
||||
Trigger: qlog.KeyUpdateTLS,
|
||||
KeyType: encLevelToKeyType(protocol.FromTLSEncryptionLevel(el), h.perspective),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// writeRecord is called when TLS writes data
|
||||
func (h *cryptoSetup) writeRecord(encLevel tls.QUICEncryptionLevel, p []byte) {
|
||||
//nolint:exhaustive // handshake records can only be written for Initial and Handshake.
|
||||
switch encLevel {
|
||||
case tls.QUICEncryptionLevelInitial:
|
||||
h.events = append(h.events, Event{Kind: EventWriteInitialData, Data: p})
|
||||
case tls.QUICEncryptionLevelHandshake:
|
||||
h.events = append(h.events, Event{Kind: EventWriteHandshakeData, Data: p})
|
||||
case tls.QUICEncryptionLevelApplication:
|
||||
panic("unexpected write")
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpected write encryption level: %s", encLevel))
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) DiscardInitialKeys() {
|
||||
dropped := h.initialOpener != nil
|
||||
h.initialOpener = nil
|
||||
h.initialSealer = nil
|
||||
if dropped {
|
||||
h.logger.Debugf("Dropping Initial keys.")
|
||||
if h.qlogger != nil {
|
||||
h.qlogger.RecordEvent(qlog.KeyDiscarded{KeyType: qlog.KeyTypeClientInitial})
|
||||
h.qlogger.RecordEvent(qlog.KeyDiscarded{KeyType: qlog.KeyTypeServerInitial})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) handshakeComplete() {
|
||||
h.handshakeCompleteTime = time.Now()
|
||||
h.events = append(h.events, Event{Kind: EventHandshakeComplete})
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) SetHandshakeConfirmed() {
|
||||
h.aead.SetHandshakeConfirmed()
|
||||
// drop Handshake keys
|
||||
var dropped bool
|
||||
if h.handshakeOpener != nil {
|
||||
h.handshakeOpener = nil
|
||||
h.handshakeSealer = nil
|
||||
dropped = true
|
||||
}
|
||||
if dropped {
|
||||
h.logger.Debugf("Dropping Handshake keys.")
|
||||
if h.qlogger != nil {
|
||||
h.qlogger.RecordEvent(qlog.KeyDiscarded{KeyType: qlog.KeyTypeClientHandshake})
|
||||
h.qlogger.RecordEvent(qlog.KeyDiscarded{KeyType: qlog.KeyTypeServerHandshake})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) GetInitialSealer() (LongHeaderSealer, error) {
|
||||
if h.initialSealer == nil {
|
||||
return nil, ErrKeysDropped
|
||||
}
|
||||
return h.initialSealer, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) Get0RTTSealer() (LongHeaderSealer, error) {
|
||||
if h.zeroRTTSealer == nil {
|
||||
return nil, ErrKeysDropped
|
||||
}
|
||||
return h.zeroRTTSealer, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) GetHandshakeSealer() (LongHeaderSealer, error) {
|
||||
if h.handshakeSealer == nil {
|
||||
if h.initialSealer == nil {
|
||||
return nil, ErrKeysDropped
|
||||
}
|
||||
return nil, ErrKeysNotYetAvailable
|
||||
}
|
||||
return h.handshakeSealer, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) {
|
||||
if !h.has1RTTSealer {
|
||||
return nil, ErrKeysNotYetAvailable
|
||||
}
|
||||
return h.aead, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) GetInitialOpener() (LongHeaderOpener, error) {
|
||||
if h.initialOpener == nil {
|
||||
return nil, ErrKeysDropped
|
||||
}
|
||||
return h.initialOpener, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) Get0RTTOpener() (LongHeaderOpener, error) {
|
||||
if h.zeroRTTOpener == nil {
|
||||
if h.initialOpener != nil {
|
||||
return nil, ErrKeysNotYetAvailable
|
||||
}
|
||||
// if the initial opener is also not available, the keys were already dropped
|
||||
return nil, ErrKeysDropped
|
||||
}
|
||||
return h.zeroRTTOpener, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) GetHandshakeOpener() (LongHeaderOpener, error) {
|
||||
if h.handshakeOpener == nil {
|
||||
if h.initialOpener != nil {
|
||||
return nil, ErrKeysNotYetAvailable
|
||||
}
|
||||
// if the initial opener is also not available, the keys were already dropped
|
||||
return nil, ErrKeysDropped
|
||||
}
|
||||
return h.handshakeOpener, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) {
|
||||
if h.zeroRTTOpener != nil && time.Since(h.handshakeCompleteTime) > 3*h.rttStats.PTO(true) {
|
||||
h.zeroRTTOpener = nil
|
||||
h.logger.Debugf("Dropping 0-RTT keys.")
|
||||
if h.qlogger != nil {
|
||||
h.qlogger.RecordEvent(qlog.KeyDiscarded{KeyType: qlog.KeyTypeClient0RTT})
|
||||
}
|
||||
}
|
||||
|
||||
if !h.has1RTTOpener {
|
||||
return nil, ErrKeysNotYetAvailable
|
||||
}
|
||||
return h.aead, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) ConnectionState() ConnectionState {
|
||||
return ConnectionState{
|
||||
ConnectionState: h.conn.ConnectionState(),
|
||||
Used0RTT: h.used0RTT.Load(),
|
||||
}
|
||||
}
|
||||
|
||||
func wrapError(err error) error {
|
||||
if alertErr := tls.AlertError(0); errors.As(err, &alertErr) {
|
||||
return qerr.NewLocalCryptoError(uint8(alertErr), err)
|
||||
}
|
||||
return &qerr.TransportError{ErrorCode: qerr.InternalError, ErrorMessage: err.Error()}
|
||||
}
|
||||
|
||||
func encLevelToKeyType(encLevel protocol.EncryptionLevel, pers protocol.Perspective) qlog.KeyType {
|
||||
if pers == protocol.PerspectiveServer {
|
||||
switch encLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
return qlog.KeyTypeServerInitial
|
||||
case protocol.EncryptionHandshake:
|
||||
return qlog.KeyTypeServerHandshake
|
||||
case protocol.Encryption0RTT:
|
||||
return qlog.KeyTypeServer0RTT
|
||||
case protocol.Encryption1RTT:
|
||||
return qlog.KeyTypeServer1RTT
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
switch encLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
return qlog.KeyTypeClientInitial
|
||||
case protocol.EncryptionHandshake:
|
||||
return qlog.KeyTypeClientHandshake
|
||||
case protocol.Encryption0RTT:
|
||||
return qlog.KeyTypeClient0RTT
|
||||
case protocol.Encryption1RTT:
|
||||
return qlog.KeyTypeClient1RTT
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
21
vendor/github.com/quic-go/quic-go/internal/handshake/fake_conn.go
generated
vendored
Normal file
21
vendor/github.com/quic-go/quic-go/internal/handshake/fake_conn.go
generated
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type conn struct {
|
||||
localAddr, remoteAddr net.Addr
|
||||
}
|
||||
|
||||
var _ net.Conn = &conn{}
|
||||
|
||||
func (c *conn) Read([]byte) (int, error) { return 0, nil }
|
||||
func (c *conn) Write([]byte) (int, error) { return 0, nil }
|
||||
func (c *conn) Close() error { return nil }
|
||||
func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr }
|
||||
func (c *conn) LocalAddr() net.Addr { return c.localAddr }
|
||||
func (c *conn) SetReadDeadline(time.Time) error { return nil }
|
||||
func (c *conn) SetWriteDeadline(time.Time) error { return nil }
|
||||
func (c *conn) SetDeadline(time.Time) error { return nil }
|
||||
134
vendor/github.com/quic-go/quic-go/internal/handshake/header_protector.go
generated
vendored
Normal file
134
vendor/github.com/quic-go/quic-go/internal/handshake/header_protector.go
generated
vendored
Normal file
@@ -0,0 +1,134 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/crypto/chacha20"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
type headerProtector interface {
|
||||
EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte)
|
||||
DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte)
|
||||
}
|
||||
|
||||
func hkdfHeaderProtectionLabel(v protocol.Version) string {
|
||||
if v == protocol.Version2 {
|
||||
return "quicv2 hp"
|
||||
}
|
||||
return "quic hp"
|
||||
}
|
||||
|
||||
func newHeaderProtector(suite cipherSuite, trafficSecret []byte, isLongHeader bool, v protocol.Version) headerProtector {
|
||||
hkdfLabel := hkdfHeaderProtectionLabel(v)
|
||||
switch suite.ID {
|
||||
case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384:
|
||||
return newAESHeaderProtector(suite, trafficSecret, isLongHeader, hkdfLabel)
|
||||
case tls.TLS_CHACHA20_POLY1305_SHA256:
|
||||
return newChaChaHeaderProtector(suite, trafficSecret, isLongHeader, hkdfLabel)
|
||||
default:
|
||||
panic(fmt.Sprintf("Invalid cipher suite id: %d", suite.ID))
|
||||
}
|
||||
}
|
||||
|
||||
type aesHeaderProtector struct {
|
||||
mask [16]byte // AES always has a 16 byte block size
|
||||
block cipher.Block
|
||||
isLongHeader bool
|
||||
}
|
||||
|
||||
var _ headerProtector = &aesHeaderProtector{}
|
||||
|
||||
func newAESHeaderProtector(suite cipherSuite, trafficSecret []byte, isLongHeader bool, hkdfLabel string) headerProtector {
|
||||
hpKey := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, hkdfLabel, suite.KeyLen)
|
||||
block, err := aes.NewCipher(hpKey)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("error creating new AES cipher: %s", err))
|
||||
}
|
||||
return &aesHeaderProtector{
|
||||
block: block,
|
||||
isLongHeader: isLongHeader,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *aesHeaderProtector) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
|
||||
p.apply(sample, firstByte, hdrBytes)
|
||||
}
|
||||
|
||||
func (p *aesHeaderProtector) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
|
||||
p.apply(sample, firstByte, hdrBytes)
|
||||
}
|
||||
|
||||
func (p *aesHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []byte) {
|
||||
if len(sample) != len(p.mask) {
|
||||
panic("invalid sample size")
|
||||
}
|
||||
p.block.Encrypt(p.mask[:], sample)
|
||||
if p.isLongHeader {
|
||||
*firstByte ^= p.mask[0] & 0xf
|
||||
} else {
|
||||
*firstByte ^= p.mask[0] & 0x1f
|
||||
}
|
||||
for i := range hdrBytes {
|
||||
hdrBytes[i] ^= p.mask[i+1]
|
||||
}
|
||||
}
|
||||
|
||||
type chachaHeaderProtector struct {
|
||||
mask [5]byte
|
||||
|
||||
key [32]byte
|
||||
isLongHeader bool
|
||||
}
|
||||
|
||||
var _ headerProtector = &chachaHeaderProtector{}
|
||||
|
||||
func newChaChaHeaderProtector(suite cipherSuite, trafficSecret []byte, isLongHeader bool, hkdfLabel string) headerProtector {
|
||||
hpKey := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, hkdfLabel, suite.KeyLen)
|
||||
|
||||
p := &chachaHeaderProtector{
|
||||
isLongHeader: isLongHeader,
|
||||
}
|
||||
copy(p.key[:], hpKey)
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *chachaHeaderProtector) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
|
||||
p.apply(sample, firstByte, hdrBytes)
|
||||
}
|
||||
|
||||
func (p *chachaHeaderProtector) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
|
||||
p.apply(sample, firstByte, hdrBytes)
|
||||
}
|
||||
|
||||
func (p *chachaHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []byte) {
|
||||
if len(sample) != 16 {
|
||||
panic("invalid sample size")
|
||||
}
|
||||
for i := 0; i < 5; i++ {
|
||||
p.mask[i] = 0
|
||||
}
|
||||
cipher, err := chacha20.NewUnauthenticatedCipher(p.key[:], sample[4:])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
cipher.SetCounter(binary.LittleEndian.Uint32(sample[:4]))
|
||||
cipher.XORKeyStream(p.mask[:], p.mask[:])
|
||||
p.applyMask(firstByte, hdrBytes)
|
||||
}
|
||||
|
||||
func (p *chachaHeaderProtector) applyMask(firstByte *byte, hdrBytes []byte) {
|
||||
if p.isLongHeader {
|
||||
*firstByte ^= p.mask[0] & 0xf
|
||||
} else {
|
||||
*firstByte ^= p.mask[0] & 0x1f
|
||||
}
|
||||
for i := range hdrBytes {
|
||||
hdrBytes[i] ^= p.mask[i+1]
|
||||
}
|
||||
}
|
||||
27
vendor/github.com/quic-go/quic-go/internal/handshake/hkdf.go
generated
vendored
Normal file
27
vendor/github.com/quic-go/quic-go/internal/handshake/hkdf.go
generated
vendored
Normal file
@@ -0,0 +1,27 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"encoding/binary"
|
||||
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
|
||||
// hkdfExpandLabel HKDF expands a label as defined in RFC 8446, section 7.1.
|
||||
func hkdfExpandLabel(hash crypto.Hash, secret, context []byte, label string, length int) []byte {
|
||||
b := make([]byte, 3, 3+6+len(label)+1+len(context))
|
||||
binary.BigEndian.PutUint16(b, uint16(length))
|
||||
b[2] = uint8(6 + len(label))
|
||||
b = append(b, []byte("tls13 ")...)
|
||||
b = append(b, []byte(label)...)
|
||||
b = b[:3+6+len(label)+1]
|
||||
b[3+6+len(label)] = uint8(len(context))
|
||||
b = append(b, context...)
|
||||
|
||||
out := make([]byte, length)
|
||||
n, err := hkdf.Expand(hash.New, secret, b).Read(out)
|
||||
if err != nil || n != length {
|
||||
panic("quic: HKDF-Expand-Label invocation failed unexpectedly")
|
||||
}
|
||||
return out
|
||||
}
|
||||
71
vendor/github.com/quic-go/quic-go/internal/handshake/initial_aead.go
generated
vendored
Normal file
71
vendor/github.com/quic-go/quic-go/internal/handshake/initial_aead.go
generated
vendored
Normal file
@@ -0,0 +1,71 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/tls"
|
||||
|
||||
"golang.org/x/crypto/hkdf"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
var (
|
||||
quicSaltV1 = []byte{0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8, 0x0c, 0xad, 0xcc, 0xbb, 0x7f, 0x0a}
|
||||
quicSaltV2 = []byte{0x0d, 0xed, 0xe3, 0xde, 0xf7, 0x00, 0xa6, 0xdb, 0x81, 0x93, 0x81, 0xbe, 0x6e, 0x26, 0x9d, 0xcb, 0xf9, 0xbd, 0x2e, 0xd9}
|
||||
)
|
||||
|
||||
const (
|
||||
hkdfLabelKeyV1 = "quic key"
|
||||
hkdfLabelKeyV2 = "quicv2 key"
|
||||
hkdfLabelIVV1 = "quic iv"
|
||||
hkdfLabelIVV2 = "quicv2 iv"
|
||||
)
|
||||
|
||||
func getSalt(v protocol.Version) []byte {
|
||||
if v == protocol.Version2 {
|
||||
return quicSaltV2
|
||||
}
|
||||
return quicSaltV1
|
||||
}
|
||||
|
||||
var initialSuite = getCipherSuite(tls.TLS_AES_128_GCM_SHA256)
|
||||
|
||||
// NewInitialAEAD creates a new AEAD for Initial encryption / decryption.
|
||||
func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v protocol.Version) (LongHeaderSealer, LongHeaderOpener) {
|
||||
clientSecret, serverSecret := computeSecrets(connID, v)
|
||||
var mySecret, otherSecret []byte
|
||||
if pers == protocol.PerspectiveClient {
|
||||
mySecret = clientSecret
|
||||
otherSecret = serverSecret
|
||||
} else {
|
||||
mySecret = serverSecret
|
||||
otherSecret = clientSecret
|
||||
}
|
||||
myKey, myIV := computeInitialKeyAndIV(mySecret, v)
|
||||
otherKey, otherIV := computeInitialKeyAndIV(otherSecret, v)
|
||||
|
||||
encrypter := initialSuite.AEAD(myKey, myIV)
|
||||
decrypter := initialSuite.AEAD(otherKey, otherIV)
|
||||
|
||||
return newLongHeaderSealer(encrypter, newHeaderProtector(initialSuite, mySecret, true, v)),
|
||||
newLongHeaderOpener(decrypter, newAESHeaderProtector(initialSuite, otherSecret, true, hkdfHeaderProtectionLabel(v)))
|
||||
}
|
||||
|
||||
func computeSecrets(connID protocol.ConnectionID, v protocol.Version) (clientSecret, serverSecret []byte) {
|
||||
initialSecret := hkdf.Extract(crypto.SHA256.New, connID.Bytes(), getSalt(v))
|
||||
clientSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "client in", crypto.SHA256.Size())
|
||||
serverSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "server in", crypto.SHA256.Size())
|
||||
return
|
||||
}
|
||||
|
||||
func computeInitialKeyAndIV(secret []byte, v protocol.Version) (key, iv []byte) {
|
||||
keyLabel := hkdfLabelKeyV1
|
||||
ivLabel := hkdfLabelIVV1
|
||||
if v == protocol.Version2 {
|
||||
keyLabel = hkdfLabelKeyV2
|
||||
ivLabel = hkdfLabelIVV2
|
||||
}
|
||||
key = hkdfExpandLabel(crypto.SHA256, secret, []byte{}, keyLabel, 16)
|
||||
iv = hkdfExpandLabel(crypto.SHA256, secret, []byte{}, ivLabel, 12)
|
||||
return
|
||||
}
|
||||
140
vendor/github.com/quic-go/quic-go/internal/handshake/interface.go
generated
vendored
Normal file
140
vendor/github.com/quic-go/quic-go/internal/handshake/interface.go
generated
vendored
Normal file
@@ -0,0 +1,140 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/monotime"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrKeysNotYetAvailable is returned when an opener or a sealer is requested for an encryption level,
|
||||
// but the corresponding opener has not yet been initialized
|
||||
// This can happen when packets arrive out of order.
|
||||
ErrKeysNotYetAvailable = errors.New("CryptoSetup: keys at this encryption level not yet available")
|
||||
// ErrKeysDropped is returned when an opener or a sealer is requested for an encryption level,
|
||||
// but the corresponding keys have already been dropped.
|
||||
ErrKeysDropped = errors.New("CryptoSetup: keys were already dropped")
|
||||
// ErrDecryptionFailed is returned when the AEAD fails to open the packet.
|
||||
ErrDecryptionFailed = errors.New("decryption failed")
|
||||
)
|
||||
|
||||
type headerDecryptor interface {
|
||||
DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte)
|
||||
}
|
||||
|
||||
// LongHeaderOpener opens a long header packet
|
||||
type LongHeaderOpener interface {
|
||||
headerDecryptor
|
||||
DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber
|
||||
Open(dst, src []byte, pn protocol.PacketNumber, associatedData []byte) ([]byte, error)
|
||||
}
|
||||
|
||||
// ShortHeaderOpener opens a short header packet
|
||||
type ShortHeaderOpener interface {
|
||||
headerDecryptor
|
||||
DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber
|
||||
Open(dst, src []byte, rcvTime monotime.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, associatedData []byte) ([]byte, error)
|
||||
}
|
||||
|
||||
// LongHeaderSealer seals a long header packet
|
||||
type LongHeaderSealer interface {
|
||||
Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte
|
||||
EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte)
|
||||
Overhead() int
|
||||
}
|
||||
|
||||
// ShortHeaderSealer seals a short header packet
|
||||
type ShortHeaderSealer interface {
|
||||
LongHeaderSealer
|
||||
KeyPhase() protocol.KeyPhaseBit
|
||||
}
|
||||
|
||||
type ConnectionState struct {
|
||||
tls.ConnectionState
|
||||
Used0RTT bool
|
||||
}
|
||||
|
||||
// EventKind is the kind of handshake event.
|
||||
type EventKind uint8
|
||||
|
||||
const (
|
||||
// EventNoEvent signals that there are no new handshake events
|
||||
EventNoEvent EventKind = iota + 1
|
||||
// EventWriteInitialData contains new CRYPTO data to send at the Initial encryption level
|
||||
EventWriteInitialData
|
||||
// EventWriteHandshakeData contains new CRYPTO data to send at the Handshake encryption level
|
||||
EventWriteHandshakeData
|
||||
// EventReceivedReadKeys signals that new decryption keys are available.
|
||||
// It doesn't say which encryption level those keys are for.
|
||||
EventReceivedReadKeys
|
||||
// EventDiscard0RTTKeys signals that the Handshake keys were discarded.
|
||||
EventDiscard0RTTKeys
|
||||
// EventReceivedTransportParameters contains the transport parameters sent by the peer.
|
||||
EventReceivedTransportParameters
|
||||
// EventRestoredTransportParameters contains the transport parameters restored from the session ticket.
|
||||
// It is only used for the client.
|
||||
EventRestoredTransportParameters
|
||||
// EventHandshakeComplete signals that the TLS handshake was completed.
|
||||
EventHandshakeComplete
|
||||
)
|
||||
|
||||
func (k EventKind) String() string {
|
||||
switch k {
|
||||
case EventNoEvent:
|
||||
return "EventNoEvent"
|
||||
case EventWriteInitialData:
|
||||
return "EventWriteInitialData"
|
||||
case EventWriteHandshakeData:
|
||||
return "EventWriteHandshakeData"
|
||||
case EventReceivedReadKeys:
|
||||
return "EventReceivedReadKeys"
|
||||
case EventDiscard0RTTKeys:
|
||||
return "EventDiscard0RTTKeys"
|
||||
case EventReceivedTransportParameters:
|
||||
return "EventReceivedTransportParameters"
|
||||
case EventRestoredTransportParameters:
|
||||
return "EventRestoredTransportParameters"
|
||||
case EventHandshakeComplete:
|
||||
return "EventHandshakeComplete"
|
||||
default:
|
||||
return "Unknown EventKind"
|
||||
}
|
||||
}
|
||||
|
||||
// Event is a handshake event.
|
||||
type Event struct {
|
||||
Kind EventKind
|
||||
Data []byte
|
||||
TransportParameters *wire.TransportParameters
|
||||
}
|
||||
|
||||
// CryptoSetup handles the handshake and protecting / unprotecting packets
|
||||
type CryptoSetup interface {
|
||||
StartHandshake(context.Context) error
|
||||
io.Closer
|
||||
ChangeConnectionID(protocol.ConnectionID)
|
||||
GetSessionTicket() ([]byte, error)
|
||||
|
||||
HandleMessage([]byte, protocol.EncryptionLevel) error
|
||||
NextEvent() Event
|
||||
|
||||
SetLargest1RTTAcked(protocol.PacketNumber) error
|
||||
DiscardInitialKeys()
|
||||
SetHandshakeConfirmed()
|
||||
ConnectionState() ConnectionState
|
||||
|
||||
GetInitialOpener() (LongHeaderOpener, error)
|
||||
GetHandshakeOpener() (LongHeaderOpener, error)
|
||||
Get0RTTOpener() (LongHeaderOpener, error)
|
||||
Get1RTTOpener() (ShortHeaderOpener, error)
|
||||
|
||||
GetInitialSealer() (LongHeaderSealer, error)
|
||||
GetHandshakeSealer() (LongHeaderSealer, error)
|
||||
Get0RTTSealer() (LongHeaderSealer, error)
|
||||
Get1RTTSealer() (ShortHeaderSealer, error)
|
||||
}
|
||||
66
vendor/github.com/quic-go/quic-go/internal/handshake/retry.go
generated
vendored
Normal file
66
vendor/github.com/quic-go/quic-go/internal/handshake/retry.go
generated
vendored
Normal file
@@ -0,0 +1,66 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// Instead of using an init function, the AEADs are created lazily.
|
||||
// For more details see https://github.com/quic-go/quic-go/issues/4894.
|
||||
var (
|
||||
retryAEADv1 cipher.AEAD // used for QUIC v1 (RFC 9000)
|
||||
retryAEADv2 cipher.AEAD // used for QUIC v2 (RFC 9369)
|
||||
)
|
||||
|
||||
func initAEAD(key [16]byte) cipher.AEAD {
|
||||
aes, err := aes.NewCipher(key[:])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
aead, err := cipher.NewGCM(aes)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return aead
|
||||
}
|
||||
|
||||
var (
|
||||
retryBuf bytes.Buffer
|
||||
retryMutex sync.Mutex
|
||||
retryNonceV1 = [12]byte{0x46, 0x15, 0x99, 0xd3, 0x5d, 0x63, 0x2b, 0xf2, 0x23, 0x98, 0x25, 0xbb}
|
||||
retryNonceV2 = [12]byte{0xd8, 0x69, 0x69, 0xbc, 0x2d, 0x7c, 0x6d, 0x99, 0x90, 0xef, 0xb0, 0x4a}
|
||||
)
|
||||
|
||||
// GetRetryIntegrityTag calculates the integrity tag on a Retry packet
|
||||
func GetRetryIntegrityTag(retry []byte, origDestConnID protocol.ConnectionID, version protocol.Version) *[16]byte {
|
||||
retryMutex.Lock()
|
||||
defer retryMutex.Unlock()
|
||||
|
||||
retryBuf.WriteByte(uint8(origDestConnID.Len()))
|
||||
retryBuf.Write(origDestConnID.Bytes())
|
||||
retryBuf.Write(retry)
|
||||
defer retryBuf.Reset()
|
||||
|
||||
var tag [16]byte
|
||||
var sealed []byte
|
||||
if version == protocol.Version2 {
|
||||
if retryAEADv2 == nil {
|
||||
retryAEADv2 = initAEAD([16]byte{0x8f, 0xb4, 0xb0, 0x1b, 0x56, 0xac, 0x48, 0xe2, 0x60, 0xfb, 0xcb, 0xce, 0xad, 0x7c, 0xcc, 0x92})
|
||||
}
|
||||
sealed = retryAEADv2.Seal(tag[:0], retryNonceV2[:], nil, retryBuf.Bytes())
|
||||
} else {
|
||||
if retryAEADv1 == nil {
|
||||
retryAEADv1 = initAEAD([16]byte{0xbe, 0x0c, 0x69, 0x0b, 0x9f, 0x66, 0x57, 0x5a, 0x1d, 0x76, 0x6b, 0x54, 0xe3, 0x68, 0xc8, 0x4e})
|
||||
}
|
||||
sealed = retryAEADv1.Seal(tag[:0], retryNonceV1[:], nil, retryBuf.Bytes())
|
||||
}
|
||||
if len(sealed) != 16 {
|
||||
panic(fmt.Sprintf("unexpected Retry integrity tag length: %d", len(sealed)))
|
||||
}
|
||||
return &tag
|
||||
}
|
||||
56
vendor/github.com/quic-go/quic-go/internal/handshake/session_ticket.go
generated
vendored
Normal file
56
vendor/github.com/quic-go/quic-go/internal/handshake/session_ticket.go
generated
vendored
Normal file
@@ -0,0 +1,56 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
"github.com/quic-go/quic-go/quicvarint"
|
||||
)
|
||||
|
||||
const sessionTicketRevision = 5
|
||||
|
||||
type sessionTicket struct {
|
||||
Parameters *wire.TransportParameters
|
||||
}
|
||||
|
||||
func (t *sessionTicket) Marshal() []byte {
|
||||
b := make([]byte, 0, 256)
|
||||
b = quicvarint.Append(b, sessionTicketRevision)
|
||||
return t.Parameters.MarshalForSessionTicket(b)
|
||||
}
|
||||
|
||||
func (t *sessionTicket) Unmarshal(b []byte) error {
|
||||
rev, l, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return errors.New("failed to read session ticket revision")
|
||||
}
|
||||
b = b[l:]
|
||||
if rev != sessionTicketRevision {
|
||||
return fmt.Errorf("unknown session ticket revision: %d", rev)
|
||||
}
|
||||
var tp wire.TransportParameters
|
||||
if err := tp.UnmarshalFromSessionTicket(b); err != nil {
|
||||
return fmt.Errorf("unmarshaling transport parameters from session ticket failed: %s", err.Error())
|
||||
}
|
||||
t.Parameters = &tp
|
||||
return nil
|
||||
}
|
||||
|
||||
const extraPrefix = "quic-go1"
|
||||
|
||||
func addSessionStateExtraPrefix(b []byte) []byte {
|
||||
return append([]byte(extraPrefix), b...)
|
||||
}
|
||||
|
||||
func findSessionStateExtraData(extras [][]byte) []byte {
|
||||
prefix := []byte(extraPrefix)
|
||||
for _, extra := range extras {
|
||||
if len(extra) < len(prefix) || !bytes.Equal(prefix, extra[:len(prefix)]) {
|
||||
continue
|
||||
}
|
||||
return extra[len(prefix):]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
39
vendor/github.com/quic-go/quic-go/internal/handshake/tls_config.go
generated
vendored
Normal file
39
vendor/github.com/quic-go/quic-go/internal/handshake/tls_config.go
generated
vendored
Normal file
@@ -0,0 +1,39 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
)
|
||||
|
||||
func setupConfigForServer(conf *tls.Config, localAddr, remoteAddr net.Addr) *tls.Config {
|
||||
// Workaround for https://github.com/golang/go/issues/60506.
|
||||
// This initializes the session tickets _before_ cloning the config.
|
||||
_, _ = conf.DecryptTicket(nil, tls.ConnectionState{})
|
||||
|
||||
conf = conf.Clone()
|
||||
conf.MinVersion = tls.VersionTLS13
|
||||
|
||||
// The tls.Config contains two callbacks that pass in a tls.ClientHelloInfo.
|
||||
// Since crypto/tls doesn't do it, we need to make sure to set the Conn field with a fake net.Conn
|
||||
// that allows the caller to get the local and the remote address.
|
||||
if conf.GetConfigForClient != nil {
|
||||
gcfc := conf.GetConfigForClient
|
||||
conf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
|
||||
c, err := gcfc(info)
|
||||
if c != nil {
|
||||
// we're returning a tls.Config here, so we need to apply this recursively
|
||||
c = setupConfigForServer(c, localAddr, remoteAddr)
|
||||
}
|
||||
return c, err
|
||||
}
|
||||
}
|
||||
if conf.GetCertificate != nil {
|
||||
gc := conf.GetCertificate
|
||||
conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
|
||||
return gc(info)
|
||||
}
|
||||
}
|
||||
return conf
|
||||
}
|
||||
126
vendor/github.com/quic-go/quic-go/internal/handshake/token_generator.go
generated
vendored
Normal file
126
vendor/github.com/quic-go/quic-go/internal/handshake/token_generator.go
generated
vendored
Normal file
@@ -0,0 +1,126 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/asn1"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
const (
|
||||
tokenPrefixIP byte = iota
|
||||
tokenPrefixString
|
||||
)
|
||||
|
||||
// A Token is derived from the client address and can be used to verify the ownership of this address.
|
||||
type Token struct {
|
||||
IsRetryToken bool
|
||||
SentTime time.Time
|
||||
encodedRemoteAddr []byte
|
||||
// only set for tokens sent in NEW_TOKEN frames
|
||||
RTT time.Duration
|
||||
// only set for retry tokens
|
||||
OriginalDestConnectionID protocol.ConnectionID
|
||||
RetrySrcConnectionID protocol.ConnectionID
|
||||
}
|
||||
|
||||
// ValidateRemoteAddr validates the address, but does not check expiration
|
||||
func (t *Token) ValidateRemoteAddr(addr net.Addr) bool {
|
||||
return bytes.Equal(encodeRemoteAddr(addr), t.encodedRemoteAddr)
|
||||
}
|
||||
|
||||
// token is the struct that is used for ASN1 serialization and deserialization
|
||||
type token struct {
|
||||
IsRetryToken bool
|
||||
RemoteAddr []byte
|
||||
Timestamp int64
|
||||
RTT int64 // in mus
|
||||
OriginalDestConnectionID []byte
|
||||
RetrySrcConnectionID []byte
|
||||
}
|
||||
|
||||
// A TokenGenerator generates tokens
|
||||
type TokenGenerator struct {
|
||||
tokenProtector tokenProtector
|
||||
}
|
||||
|
||||
// NewTokenGenerator initializes a new TokenGenerator
|
||||
func NewTokenGenerator(key TokenProtectorKey) *TokenGenerator {
|
||||
return &TokenGenerator{tokenProtector: *newTokenProtector(key)}
|
||||
}
|
||||
|
||||
// NewRetryToken generates a new token for a Retry for a given source address
|
||||
func (g *TokenGenerator) NewRetryToken(
|
||||
raddr net.Addr,
|
||||
origDestConnID protocol.ConnectionID,
|
||||
retrySrcConnID protocol.ConnectionID,
|
||||
) ([]byte, error) {
|
||||
data, err := asn1.Marshal(token{
|
||||
IsRetryToken: true,
|
||||
RemoteAddr: encodeRemoteAddr(raddr),
|
||||
OriginalDestConnectionID: origDestConnID.Bytes(),
|
||||
RetrySrcConnectionID: retrySrcConnID.Bytes(),
|
||||
Timestamp: time.Now().UnixNano(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return g.tokenProtector.NewToken(data)
|
||||
}
|
||||
|
||||
// NewToken generates a new token to be sent in a NEW_TOKEN frame
|
||||
func (g *TokenGenerator) NewToken(raddr net.Addr, rtt time.Duration) ([]byte, error) {
|
||||
data, err := asn1.Marshal(token{
|
||||
RemoteAddr: encodeRemoteAddr(raddr),
|
||||
Timestamp: time.Now().UnixNano(),
|
||||
RTT: rtt.Microseconds(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return g.tokenProtector.NewToken(data)
|
||||
}
|
||||
|
||||
// DecodeToken decodes a token
|
||||
func (g *TokenGenerator) DecodeToken(encrypted []byte) (*Token, error) {
|
||||
// if the client didn't send any token, DecodeToken will be called with a nil-slice
|
||||
if len(encrypted) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
data, err := g.tokenProtector.DecodeToken(encrypted)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t := &token{}
|
||||
rest, err := asn1.Unmarshal(data, t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(rest) != 0 {
|
||||
return nil, fmt.Errorf("rest when unpacking token: %d", len(rest))
|
||||
}
|
||||
token := &Token{
|
||||
IsRetryToken: t.IsRetryToken,
|
||||
SentTime: time.Unix(0, t.Timestamp),
|
||||
encodedRemoteAddr: t.RemoteAddr,
|
||||
}
|
||||
if t.IsRetryToken {
|
||||
token.OriginalDestConnectionID = protocol.ParseConnectionID(t.OriginalDestConnectionID)
|
||||
token.RetrySrcConnectionID = protocol.ParseConnectionID(t.RetrySrcConnectionID)
|
||||
} else {
|
||||
token.RTT = time.Duration(t.RTT) * time.Microsecond
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// encodeRemoteAddr encodes a remote address such that it can be saved in the token
|
||||
func encodeRemoteAddr(remoteAddr net.Addr) []byte {
|
||||
if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok {
|
||||
return append([]byte{tokenPrefixIP}, udpAddr.IP...)
|
||||
}
|
||||
return append([]byte{tokenPrefixString}, []byte(remoteAddr.String())...)
|
||||
}
|
||||
74
vendor/github.com/quic-go/quic-go/internal/handshake/token_protector.go
generated
vendored
Normal file
74
vendor/github.com/quic-go/quic-go/internal/handshake/token_protector.go
generated
vendored
Normal file
@@ -0,0 +1,74 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
|
||||
// TokenProtectorKey is the key used to encrypt both Retry and session resumption tokens.
|
||||
type TokenProtectorKey [32]byte
|
||||
|
||||
const tokenNonceSize = 32
|
||||
|
||||
// tokenProtector is used to create and verify a token
|
||||
type tokenProtector struct {
|
||||
key TokenProtectorKey
|
||||
}
|
||||
|
||||
// newTokenProtector creates a source for source address tokens
|
||||
func newTokenProtector(key TokenProtectorKey) *tokenProtector {
|
||||
return &tokenProtector{key: key}
|
||||
}
|
||||
|
||||
// NewToken encodes data into a new token.
|
||||
func (s *tokenProtector) NewToken(data []byte) ([]byte, error) {
|
||||
var nonce [tokenNonceSize]byte
|
||||
if _, err := rand.Read(nonce[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
aead, aeadNonce, err := s.createAEAD(nonce[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return append(nonce[:], aead.Seal(nil, aeadNonce, data, nil)...), nil
|
||||
}
|
||||
|
||||
// DecodeToken decodes a token.
|
||||
func (s *tokenProtector) DecodeToken(p []byte) ([]byte, error) {
|
||||
if len(p) < tokenNonceSize {
|
||||
return nil, fmt.Errorf("token too short: %d", len(p))
|
||||
}
|
||||
nonce := p[:tokenNonceSize]
|
||||
aead, aeadNonce, err := s.createAEAD(nonce)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return aead.Open(nil, aeadNonce, p[tokenNonceSize:], nil)
|
||||
}
|
||||
|
||||
func (s *tokenProtector) createAEAD(nonce []byte) (cipher.AEAD, []byte, error) {
|
||||
h := hkdf.New(sha256.New, s.key[:], nonce, []byte("quic-go token source"))
|
||||
key := make([]byte, 32) // use a 32 byte key, in order to select AES-256
|
||||
if _, err := io.ReadFull(h, key); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
aeadNonce := make([]byte, 12)
|
||||
if _, err := io.ReadFull(h, aeadNonce); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
c, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
aead, err := cipher.NewGCM(c)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return aead, aeadNonce, nil
|
||||
}
|
||||
372
vendor/github.com/quic-go/quic-go/internal/handshake/updatable_aead.go
generated
vendored
Normal file
372
vendor/github.com/quic-go/quic-go/internal/handshake/updatable_aead.go
generated
vendored
Normal file
@@ -0,0 +1,372 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/cipher"
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/monotime"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/qlog"
|
||||
"github.com/quic-go/quic-go/qlogwriter"
|
||||
)
|
||||
|
||||
var keyUpdateInterval atomic.Uint64
|
||||
|
||||
func init() {
|
||||
keyUpdateInterval.Store(protocol.KeyUpdateInterval)
|
||||
}
|
||||
|
||||
func SetKeyUpdateInterval(v uint64) (reset func()) {
|
||||
old := keyUpdateInterval.Swap(v)
|
||||
return func() { keyUpdateInterval.Store(old) }
|
||||
}
|
||||
|
||||
// FirstKeyUpdateInterval is the maximum number of packets we send or receive before initiating the first key update.
|
||||
// It's a package-level variable to allow modifying it for testing purposes.
|
||||
var FirstKeyUpdateInterval uint64 = 100
|
||||
|
||||
type updatableAEAD struct {
|
||||
suite cipherSuite
|
||||
|
||||
keyPhase protocol.KeyPhase
|
||||
largestAcked protocol.PacketNumber
|
||||
firstPacketNumber protocol.PacketNumber
|
||||
handshakeConfirmed bool
|
||||
|
||||
invalidPacketLimit uint64
|
||||
invalidPacketCount uint64
|
||||
|
||||
// Time when the keys should be dropped. Keys are dropped on the next call to Open().
|
||||
prevRcvAEADExpiry monotime.Time
|
||||
prevRcvAEAD cipher.AEAD
|
||||
|
||||
firstRcvdWithCurrentKey protocol.PacketNumber
|
||||
firstSentWithCurrentKey protocol.PacketNumber
|
||||
highestRcvdPN protocol.PacketNumber // highest packet number received (which could be successfully unprotected)
|
||||
numRcvdWithCurrentKey uint64
|
||||
numSentWithCurrentKey uint64
|
||||
rcvAEAD cipher.AEAD
|
||||
sendAEAD cipher.AEAD
|
||||
// caches cipher.AEAD.Overhead(). This speeds up calls to Overhead().
|
||||
aeadOverhead int
|
||||
|
||||
nextRcvAEAD cipher.AEAD
|
||||
nextSendAEAD cipher.AEAD
|
||||
nextRcvTrafficSecret []byte
|
||||
nextSendTrafficSecret []byte
|
||||
|
||||
headerDecrypter headerProtector
|
||||
headerEncrypter headerProtector
|
||||
|
||||
rttStats *utils.RTTStats
|
||||
|
||||
qlogger qlogwriter.Recorder
|
||||
logger utils.Logger
|
||||
version protocol.Version
|
||||
|
||||
// use a single slice to avoid allocations
|
||||
nonceBuf []byte
|
||||
}
|
||||
|
||||
var (
|
||||
_ ShortHeaderOpener = &updatableAEAD{}
|
||||
_ ShortHeaderSealer = &updatableAEAD{}
|
||||
)
|
||||
|
||||
func newUpdatableAEAD(rttStats *utils.RTTStats, qlogger qlogwriter.Recorder, logger utils.Logger, version protocol.Version) *updatableAEAD {
|
||||
return &updatableAEAD{
|
||||
firstPacketNumber: protocol.InvalidPacketNumber,
|
||||
largestAcked: protocol.InvalidPacketNumber,
|
||||
firstRcvdWithCurrentKey: protocol.InvalidPacketNumber,
|
||||
firstSentWithCurrentKey: protocol.InvalidPacketNumber,
|
||||
rttStats: rttStats,
|
||||
qlogger: qlogger,
|
||||
logger: logger,
|
||||
version: version,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) rollKeys() {
|
||||
if a.prevRcvAEAD != nil {
|
||||
a.logger.Debugf("Dropping key phase %d ahead of scheduled time. Drop time was: %s", a.keyPhase-1, a.prevRcvAEADExpiry)
|
||||
if a.qlogger != nil {
|
||||
a.qlogger.RecordEvent(qlog.KeyDiscarded{
|
||||
KeyType: qlog.KeyTypeClient1RTT,
|
||||
KeyPhase: a.keyPhase - 1,
|
||||
})
|
||||
a.qlogger.RecordEvent(qlog.KeyDiscarded{
|
||||
KeyType: qlog.KeyTypeServer1RTT,
|
||||
KeyPhase: a.keyPhase - 1,
|
||||
})
|
||||
}
|
||||
a.prevRcvAEADExpiry = 0
|
||||
}
|
||||
|
||||
a.keyPhase++
|
||||
a.firstRcvdWithCurrentKey = protocol.InvalidPacketNumber
|
||||
a.firstSentWithCurrentKey = protocol.InvalidPacketNumber
|
||||
a.numRcvdWithCurrentKey = 0
|
||||
a.numSentWithCurrentKey = 0
|
||||
a.prevRcvAEAD = a.rcvAEAD
|
||||
a.rcvAEAD = a.nextRcvAEAD
|
||||
a.sendAEAD = a.nextSendAEAD
|
||||
|
||||
a.nextRcvTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextRcvTrafficSecret)
|
||||
a.nextSendTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextSendTrafficSecret)
|
||||
a.nextRcvAEAD = createAEAD(a.suite, a.nextRcvTrafficSecret, a.version)
|
||||
a.nextSendAEAD = createAEAD(a.suite, a.nextSendTrafficSecret, a.version)
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) startKeyDropTimer(now monotime.Time) {
|
||||
d := 3 * a.rttStats.PTO(true)
|
||||
a.logger.Debugf("Starting key drop timer to drop key phase %d (in %s)", a.keyPhase-1, d)
|
||||
a.prevRcvAEADExpiry = now.Add(d)
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte {
|
||||
return hkdfExpandLabel(hash, ts, []byte{}, "quic ku", hash.Size())
|
||||
}
|
||||
|
||||
// SetReadKey sets the read key.
|
||||
// For the client, this function is called before SetWriteKey.
|
||||
// For the server, this function is called after SetWriteKey.
|
||||
func (a *updatableAEAD) SetReadKey(suite cipherSuite, trafficSecret []byte) {
|
||||
a.rcvAEAD = createAEAD(suite, trafficSecret, a.version)
|
||||
a.headerDecrypter = newHeaderProtector(suite, trafficSecret, false, a.version)
|
||||
if a.suite.ID == 0 { // suite is not set yet
|
||||
a.setAEADParameters(a.rcvAEAD, suite)
|
||||
}
|
||||
|
||||
a.nextRcvTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret)
|
||||
a.nextRcvAEAD = createAEAD(suite, a.nextRcvTrafficSecret, a.version)
|
||||
}
|
||||
|
||||
// SetWriteKey sets the write key.
|
||||
// For the client, this function is called after SetReadKey.
|
||||
// For the server, this function is called before SetReadKey.
|
||||
func (a *updatableAEAD) SetWriteKey(suite cipherSuite, trafficSecret []byte) {
|
||||
a.sendAEAD = createAEAD(suite, trafficSecret, a.version)
|
||||
a.headerEncrypter = newHeaderProtector(suite, trafficSecret, false, a.version)
|
||||
if a.suite.ID == 0 { // suite is not set yet
|
||||
a.setAEADParameters(a.sendAEAD, suite)
|
||||
}
|
||||
|
||||
a.nextSendTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret)
|
||||
a.nextSendAEAD = createAEAD(suite, a.nextSendTrafficSecret, a.version)
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) setAEADParameters(aead cipher.AEAD, suite cipherSuite) {
|
||||
a.nonceBuf = make([]byte, aead.NonceSize())
|
||||
a.aeadOverhead = aead.Overhead()
|
||||
a.suite = suite
|
||||
switch suite.ID {
|
||||
case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384:
|
||||
a.invalidPacketLimit = protocol.InvalidPacketLimitAES
|
||||
case tls.TLS_CHACHA20_POLY1305_SHA256:
|
||||
a.invalidPacketLimit = protocol.InvalidPacketLimitChaCha
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown cipher suite %d", suite.ID))
|
||||
}
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber {
|
||||
return protocol.DecodePacketNumber(wirePNLen, a.highestRcvdPN, wirePN)
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) Open(dst, src []byte, rcvTime monotime.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) {
|
||||
dec, err := a.open(dst, src, rcvTime, pn, kp, ad)
|
||||
if err == ErrDecryptionFailed {
|
||||
a.invalidPacketCount++
|
||||
if a.invalidPacketCount >= a.invalidPacketLimit {
|
||||
return nil, &qerr.TransportError{ErrorCode: qerr.AEADLimitReached}
|
||||
}
|
||||
}
|
||||
if err == nil {
|
||||
a.highestRcvdPN = max(a.highestRcvdPN, pn)
|
||||
}
|
||||
return dec, err
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) open(dst, src []byte, rcvTime monotime.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) {
|
||||
if a.prevRcvAEAD != nil && !a.prevRcvAEADExpiry.IsZero() && rcvTime.After(a.prevRcvAEADExpiry) {
|
||||
a.prevRcvAEAD = nil
|
||||
a.logger.Debugf("Dropping key phase %d", a.keyPhase-1)
|
||||
a.prevRcvAEADExpiry = 0
|
||||
if a.qlogger != nil {
|
||||
a.qlogger.RecordEvent(qlog.KeyDiscarded{
|
||||
KeyType: qlog.KeyTypeClient1RTT,
|
||||
KeyPhase: a.keyPhase - 1,
|
||||
})
|
||||
a.qlogger.RecordEvent(qlog.KeyDiscarded{
|
||||
KeyType: qlog.KeyTypeServer1RTT,
|
||||
KeyPhase: a.keyPhase - 1,
|
||||
})
|
||||
}
|
||||
}
|
||||
binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn))
|
||||
if kp != a.keyPhase.Bit() {
|
||||
if a.keyPhase > 0 && a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber || pn < a.firstRcvdWithCurrentKey {
|
||||
if a.prevRcvAEAD == nil {
|
||||
return nil, ErrKeysDropped
|
||||
}
|
||||
// we updated the key, but the peer hasn't updated yet
|
||||
dec, err := a.prevRcvAEAD.Open(dst, a.nonceBuf, src, ad)
|
||||
if err != nil {
|
||||
err = ErrDecryptionFailed
|
||||
}
|
||||
return dec, err
|
||||
}
|
||||
// try opening the packet with the next key phase
|
||||
dec, err := a.nextRcvAEAD.Open(dst, a.nonceBuf, src, ad)
|
||||
if err != nil {
|
||||
return nil, ErrDecryptionFailed
|
||||
}
|
||||
// Opening succeeded. Check if the peer was allowed to update.
|
||||
if a.keyPhase > 0 && a.firstSentWithCurrentKey == protocol.InvalidPacketNumber {
|
||||
return nil, &qerr.TransportError{
|
||||
ErrorCode: qerr.KeyUpdateError,
|
||||
ErrorMessage: "keys updated too quickly",
|
||||
}
|
||||
}
|
||||
a.rollKeys()
|
||||
a.logger.Debugf("Peer updated keys to %d", a.keyPhase)
|
||||
// The peer initiated this key update. It's safe to drop the keys for the previous generation now.
|
||||
// Start a timer to drop the previous key generation.
|
||||
a.startKeyDropTimer(rcvTime)
|
||||
if a.qlogger != nil {
|
||||
a.qlogger.RecordEvent(qlog.KeyUpdated{
|
||||
Trigger: qlog.KeyUpdateRemote,
|
||||
KeyType: qlog.KeyTypeClient1RTT,
|
||||
KeyPhase: a.keyPhase,
|
||||
})
|
||||
a.qlogger.RecordEvent(qlog.KeyUpdated{
|
||||
Trigger: qlog.KeyUpdateRemote,
|
||||
KeyType: qlog.KeyTypeServer1RTT,
|
||||
KeyPhase: a.keyPhase,
|
||||
})
|
||||
}
|
||||
a.firstRcvdWithCurrentKey = pn
|
||||
return dec, err
|
||||
}
|
||||
// The AEAD we're using here will be the qtls.aeadAESGCM13.
|
||||
// It uses the nonce provided here and XOR it with the IV.
|
||||
dec, err := a.rcvAEAD.Open(dst, a.nonceBuf, src, ad)
|
||||
if err != nil {
|
||||
return dec, ErrDecryptionFailed
|
||||
}
|
||||
a.numRcvdWithCurrentKey++
|
||||
if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber {
|
||||
// We initiated the key updated, and now we received the first packet protected with the new key phase.
|
||||
// Therefore, we are certain that the peer rolled its keys as well. Start a timer to drop the old keys.
|
||||
if a.keyPhase > 0 {
|
||||
a.logger.Debugf("Peer confirmed key update to phase %d", a.keyPhase)
|
||||
a.startKeyDropTimer(rcvTime)
|
||||
}
|
||||
a.firstRcvdWithCurrentKey = pn
|
||||
}
|
||||
return dec, err
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte {
|
||||
if a.firstSentWithCurrentKey == protocol.InvalidPacketNumber {
|
||||
a.firstSentWithCurrentKey = pn
|
||||
}
|
||||
if a.firstPacketNumber == protocol.InvalidPacketNumber {
|
||||
a.firstPacketNumber = pn
|
||||
}
|
||||
a.numSentWithCurrentKey++
|
||||
binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn))
|
||||
// The AEAD we're using here will be the qtls.aeadAESGCM13.
|
||||
// It uses the nonce provided here and XOR it with the IV.
|
||||
return a.sendAEAD.Seal(dst, a.nonceBuf, src, ad)
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) SetLargestAcked(pn protocol.PacketNumber) error {
|
||||
if a.firstSentWithCurrentKey != protocol.InvalidPacketNumber &&
|
||||
pn >= a.firstSentWithCurrentKey && a.numRcvdWithCurrentKey == 0 {
|
||||
return &qerr.TransportError{
|
||||
ErrorCode: qerr.KeyUpdateError,
|
||||
ErrorMessage: fmt.Sprintf("received ACK for key phase %d, but peer didn't update keys", a.keyPhase),
|
||||
}
|
||||
}
|
||||
a.largestAcked = pn
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) SetHandshakeConfirmed() {
|
||||
a.handshakeConfirmed = true
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) updateAllowed() bool {
|
||||
if !a.handshakeConfirmed {
|
||||
return false
|
||||
}
|
||||
// the first key update is allowed as soon as the handshake is confirmed
|
||||
return a.keyPhase == 0 ||
|
||||
// subsequent key updates as soon as a packet sent with that key phase has been acknowledged
|
||||
(a.firstSentWithCurrentKey != protocol.InvalidPacketNumber &&
|
||||
a.largestAcked != protocol.InvalidPacketNumber &&
|
||||
a.largestAcked >= a.firstSentWithCurrentKey)
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) shouldInitiateKeyUpdate() bool {
|
||||
if !a.updateAllowed() {
|
||||
return false
|
||||
}
|
||||
// Initiate the first key update shortly after the handshake, in order to exercise the key update mechanism.
|
||||
if a.keyPhase == 0 {
|
||||
if a.numRcvdWithCurrentKey >= FirstKeyUpdateInterval || a.numSentWithCurrentKey >= FirstKeyUpdateInterval {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if a.numRcvdWithCurrentKey >= keyUpdateInterval.Load() {
|
||||
a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %d", a.numRcvdWithCurrentKey, a.keyPhase+1)
|
||||
return true
|
||||
}
|
||||
if a.numSentWithCurrentKey >= keyUpdateInterval.Load() {
|
||||
a.logger.Debugf("Sent %d packets with current key phase. Initiating key update to the next key phase: %d", a.numSentWithCurrentKey, a.keyPhase+1)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) KeyPhase() protocol.KeyPhaseBit {
|
||||
if a.shouldInitiateKeyUpdate() {
|
||||
a.rollKeys()
|
||||
if a.qlogger != nil {
|
||||
a.qlogger.RecordEvent(qlog.KeyUpdated{
|
||||
Trigger: qlog.KeyUpdateLocal,
|
||||
KeyType: qlog.KeyTypeClient1RTT,
|
||||
KeyPhase: a.keyPhase,
|
||||
})
|
||||
a.qlogger.RecordEvent(qlog.KeyUpdated{
|
||||
Trigger: qlog.KeyUpdateLocal,
|
||||
KeyType: qlog.KeyTypeServer1RTT,
|
||||
KeyPhase: a.keyPhase,
|
||||
})
|
||||
}
|
||||
}
|
||||
return a.keyPhase.Bit()
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) Overhead() int {
|
||||
return a.aeadOverhead
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
|
||||
a.headerEncrypter.EncryptHeader(sample, firstByte, hdrBytes)
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
|
||||
a.headerDecrypter.DecryptHeader(sample, firstByte, hdrBytes)
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) FirstPacketNumber() protocol.PacketNumber {
|
||||
return a.firstPacketNumber
|
||||
}
|
||||
90
vendor/github.com/quic-go/quic-go/internal/monotime/time.go
generated
vendored
Normal file
90
vendor/github.com/quic-go/quic-go/internal/monotime/time.go
generated
vendored
Normal file
@@ -0,0 +1,90 @@
|
||||
// Package monotime provides a monotonic time representation that is useful for
|
||||
// measuring elapsed time.
|
||||
// It is designed as a memory optimized drop-in replacement for time.Time, with
|
||||
// a monotime.Time consuming just 8 bytes instead of 24 bytes.
|
||||
package monotime
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// The absolute value doesn't matter, but it should be in the past,
|
||||
// so that every timestamp obtained with Now() is non-zero,
|
||||
// even on systems with low timer resolutions (e.g. Windows).
|
||||
var start = time.Now().Add(-time.Hour)
|
||||
|
||||
// A Time represents an instant in monotonic time.
|
||||
// Times can be compared using the comparison operators, but the specific
|
||||
// value is implementation-dependent and should not be relied upon.
|
||||
// The zero value of Time doesn't have any specific meaning.
|
||||
type Time int64
|
||||
|
||||
// Now returns the current monotonic time.
|
||||
func Now() Time {
|
||||
return Time(time.Since(start).Nanoseconds())
|
||||
}
|
||||
|
||||
// Sub returns the duration t-t2. If the result exceeds the maximum (or minimum)
|
||||
// value that can be stored in a Duration, the maximum (or minimum) duration
|
||||
// will be returned.
|
||||
// To compute t-d for a duration d, use t.Add(-d).
|
||||
func (t Time) Sub(t2 Time) time.Duration {
|
||||
return time.Duration(t - t2)
|
||||
}
|
||||
|
||||
// Add returns the time t+d.
|
||||
func (t Time) Add(d time.Duration) Time {
|
||||
return Time(int64(t) + d.Nanoseconds())
|
||||
}
|
||||
|
||||
// After reports whether the time instant t is after t2.
|
||||
func (t Time) After(t2 Time) bool {
|
||||
return t > t2
|
||||
}
|
||||
|
||||
// Before reports whether the time instant t is before t2.
|
||||
func (t Time) Before(t2 Time) bool {
|
||||
return t < t2
|
||||
}
|
||||
|
||||
// IsZero reports whether t represents the zero time instant.
|
||||
func (t Time) IsZero() bool {
|
||||
return t == 0
|
||||
}
|
||||
|
||||
// Equal reports whether t and t2 represent the same time instant.
|
||||
func (t Time) Equal(t2 Time) bool {
|
||||
return t == t2
|
||||
}
|
||||
|
||||
// ToTime converts the monotonic time to a time.Time value.
|
||||
// The returned time.Time will have the same instant as the monotonic time,
|
||||
// but may be subject to clock adjustments.
|
||||
func (t Time) ToTime() time.Time {
|
||||
if t.IsZero() {
|
||||
return time.Time{}
|
||||
}
|
||||
return start.Add(time.Duration(t))
|
||||
}
|
||||
|
||||
// Since returns the time elapsed since t. It is shorthand for Now().Sub(t).
|
||||
func Since(t Time) time.Duration {
|
||||
return Now().Sub(t)
|
||||
}
|
||||
|
||||
// Until returns the duration until t.
|
||||
// It is shorthand for t.Sub(Now()).
|
||||
// If t is in the past, the returned duration will be negative.
|
||||
func Until(t Time) time.Duration {
|
||||
return time.Duration(t - Now())
|
||||
}
|
||||
|
||||
// FromTime converts a time.Time to a monotonic Time.
|
||||
// The conversion is relative to the package's start time and may lose
|
||||
// precision if the time.Time is far from the start time.
|
||||
func FromTime(t time.Time) Time {
|
||||
if t.IsZero() {
|
||||
return 0
|
||||
}
|
||||
return Time(t.Sub(start).Nanoseconds())
|
||||
}
|
||||
116
vendor/github.com/quic-go/quic-go/internal/protocol/connection_id.go
generated
vendored
Normal file
116
vendor/github.com/quic-go/quic-go/internal/protocol/connection_id.go
generated
vendored
Normal file
@@ -0,0 +1,116 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
var ErrInvalidConnectionIDLen = errors.New("invalid Connection ID length")
|
||||
|
||||
// An ArbitraryLenConnectionID is a QUIC Connection ID able to represent Connection IDs according to RFC 8999.
|
||||
// Future QUIC versions might allow connection ID lengths up to 255 bytes, while QUIC v1
|
||||
// restricts the length to 20 bytes.
|
||||
type ArbitraryLenConnectionID []byte
|
||||
|
||||
func (c ArbitraryLenConnectionID) Len() int {
|
||||
return len(c)
|
||||
}
|
||||
|
||||
func (c ArbitraryLenConnectionID) Bytes() []byte {
|
||||
return c
|
||||
}
|
||||
|
||||
func (c ArbitraryLenConnectionID) String() string {
|
||||
if c.Len() == 0 {
|
||||
return "(empty)"
|
||||
}
|
||||
return hex.EncodeToString(c.Bytes())
|
||||
}
|
||||
|
||||
const maxConnectionIDLen = 20
|
||||
|
||||
// A ConnectionID in QUIC
|
||||
type ConnectionID struct {
|
||||
b [20]byte
|
||||
l uint8
|
||||
}
|
||||
|
||||
// GenerateConnectionID generates a connection ID using cryptographic random
|
||||
func GenerateConnectionID(l int) (ConnectionID, error) {
|
||||
var c ConnectionID
|
||||
c.l = uint8(l)
|
||||
_, err := rand.Read(c.b[:l])
|
||||
return c, err
|
||||
}
|
||||
|
||||
// ParseConnectionID interprets b as a Connection ID.
|
||||
// It panics if b is longer than 20 bytes.
|
||||
func ParseConnectionID(b []byte) ConnectionID {
|
||||
if len(b) > maxConnectionIDLen {
|
||||
panic("invalid conn id length")
|
||||
}
|
||||
var c ConnectionID
|
||||
c.l = uint8(len(b))
|
||||
copy(c.b[:c.l], b)
|
||||
return c
|
||||
}
|
||||
|
||||
// GenerateConnectionIDForInitial generates a connection ID for the Initial packet.
|
||||
// It uses a length randomly chosen between 8 and 20 bytes.
|
||||
func GenerateConnectionIDForInitial() (ConnectionID, error) {
|
||||
r := make([]byte, 1)
|
||||
if _, err := rand.Read(r); err != nil {
|
||||
return ConnectionID{}, err
|
||||
}
|
||||
l := MinConnectionIDLenInitial + int(r[0])%(maxConnectionIDLen-MinConnectionIDLenInitial+1)
|
||||
return GenerateConnectionID(l)
|
||||
}
|
||||
|
||||
// ReadConnectionID reads a connection ID of length len from the given io.Reader.
|
||||
// It returns io.EOF if there are not enough bytes to read.
|
||||
func ReadConnectionID(r io.Reader, l int) (ConnectionID, error) {
|
||||
var c ConnectionID
|
||||
if l == 0 {
|
||||
return c, nil
|
||||
}
|
||||
if l > maxConnectionIDLen {
|
||||
return c, ErrInvalidConnectionIDLen
|
||||
}
|
||||
c.l = uint8(l)
|
||||
_, err := io.ReadFull(r, c.b[:l])
|
||||
if err == io.ErrUnexpectedEOF {
|
||||
return c, io.EOF
|
||||
}
|
||||
return c, err
|
||||
}
|
||||
|
||||
// Len returns the length of the connection ID in bytes
|
||||
func (c ConnectionID) Len() int {
|
||||
return int(c.l)
|
||||
}
|
||||
|
||||
// Bytes returns the byte representation
|
||||
func (c ConnectionID) Bytes() []byte {
|
||||
return c.b[:c.l]
|
||||
}
|
||||
|
||||
func (c ConnectionID) String() string {
|
||||
if c.Len() == 0 {
|
||||
return "(empty)"
|
||||
}
|
||||
return hex.EncodeToString(c.Bytes())
|
||||
}
|
||||
|
||||
type DefaultConnectionIDGenerator struct {
|
||||
ConnLen int
|
||||
}
|
||||
|
||||
func (d *DefaultConnectionIDGenerator) GenerateConnectionID() (ConnectionID, error) {
|
||||
return GenerateConnectionID(d.ConnLen)
|
||||
}
|
||||
|
||||
func (d *DefaultConnectionIDGenerator) ConnectionIDLen() int {
|
||||
return d.ConnLen
|
||||
}
|
||||
65
vendor/github.com/quic-go/quic-go/internal/protocol/encryption_level.go
generated
vendored
Normal file
65
vendor/github.com/quic-go/quic-go/internal/protocol/encryption_level.go
generated
vendored
Normal file
@@ -0,0 +1,65 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// EncryptionLevel is the encryption level
|
||||
// Default value is Unencrypted
|
||||
type EncryptionLevel uint8
|
||||
|
||||
const (
|
||||
// EncryptionInitial is the Initial encryption level
|
||||
EncryptionInitial EncryptionLevel = 1 + iota
|
||||
// EncryptionHandshake is the Handshake encryption level
|
||||
EncryptionHandshake
|
||||
// Encryption0RTT is the 0-RTT encryption level
|
||||
Encryption0RTT
|
||||
// Encryption1RTT is the 1-RTT encryption level
|
||||
Encryption1RTT
|
||||
)
|
||||
|
||||
func (e EncryptionLevel) String() string {
|
||||
switch e {
|
||||
case EncryptionInitial:
|
||||
return "Initial"
|
||||
case EncryptionHandshake:
|
||||
return "Handshake"
|
||||
case Encryption0RTT:
|
||||
return "0-RTT"
|
||||
case Encryption1RTT:
|
||||
return "1-RTT"
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
func (e EncryptionLevel) ToTLSEncryptionLevel() tls.QUICEncryptionLevel {
|
||||
switch e {
|
||||
case EncryptionInitial:
|
||||
return tls.QUICEncryptionLevelInitial
|
||||
case EncryptionHandshake:
|
||||
return tls.QUICEncryptionLevelHandshake
|
||||
case Encryption1RTT:
|
||||
return tls.QUICEncryptionLevelApplication
|
||||
case Encryption0RTT:
|
||||
return tls.QUICEncryptionLevelEarly
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpected encryption level: %s", e))
|
||||
}
|
||||
}
|
||||
|
||||
func FromTLSEncryptionLevel(e tls.QUICEncryptionLevel) EncryptionLevel {
|
||||
switch e {
|
||||
case tls.QUICEncryptionLevelInitial:
|
||||
return EncryptionInitial
|
||||
case tls.QUICEncryptionLevelHandshake:
|
||||
return EncryptionHandshake
|
||||
case tls.QUICEncryptionLevelApplication:
|
||||
return Encryption1RTT
|
||||
case tls.QUICEncryptionLevelEarly:
|
||||
return Encryption0RTT
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpect encryption level: %s", e))
|
||||
}
|
||||
}
|
||||
36
vendor/github.com/quic-go/quic-go/internal/protocol/key_phase.go
generated
vendored
Normal file
36
vendor/github.com/quic-go/quic-go/internal/protocol/key_phase.go
generated
vendored
Normal file
@@ -0,0 +1,36 @@
|
||||
package protocol
|
||||
|
||||
// KeyPhase is the key phase
|
||||
type KeyPhase uint64
|
||||
|
||||
// Bit determines the key phase bit
|
||||
func (p KeyPhase) Bit() KeyPhaseBit {
|
||||
if p%2 == 0 {
|
||||
return KeyPhaseZero
|
||||
}
|
||||
return KeyPhaseOne
|
||||
}
|
||||
|
||||
// KeyPhaseBit is the key phase bit
|
||||
type KeyPhaseBit uint8
|
||||
|
||||
const (
|
||||
// KeyPhaseUndefined is an undefined key phase
|
||||
KeyPhaseUndefined KeyPhaseBit = iota
|
||||
// KeyPhaseZero is key phase 0
|
||||
KeyPhaseZero
|
||||
// KeyPhaseOne is key phase 1
|
||||
KeyPhaseOne
|
||||
)
|
||||
|
||||
func (p KeyPhaseBit) String() string {
|
||||
//nolint:exhaustive
|
||||
switch p {
|
||||
case KeyPhaseZero:
|
||||
return "0"
|
||||
case KeyPhaseOne:
|
||||
return "1"
|
||||
default:
|
||||
return "undefined"
|
||||
}
|
||||
}
|
||||
57
vendor/github.com/quic-go/quic-go/internal/protocol/packet_number.go
generated
vendored
Normal file
57
vendor/github.com/quic-go/quic-go/internal/protocol/packet_number.go
generated
vendored
Normal file
@@ -0,0 +1,57 @@
|
||||
package protocol
|
||||
|
||||
// A PacketNumber in QUIC
|
||||
type PacketNumber int64
|
||||
|
||||
// InvalidPacketNumber is a packet number that is never sent.
|
||||
// In QUIC, 0 is a valid packet number.
|
||||
const InvalidPacketNumber PacketNumber = -1
|
||||
|
||||
// PacketNumberLen is the length of the packet number in bytes
|
||||
type PacketNumberLen uint8
|
||||
|
||||
const (
|
||||
// PacketNumberLen1 is a packet number length of 1 byte
|
||||
PacketNumberLen1 PacketNumberLen = 1
|
||||
// PacketNumberLen2 is a packet number length of 2 bytes
|
||||
PacketNumberLen2 PacketNumberLen = 2
|
||||
// PacketNumberLen3 is a packet number length of 3 bytes
|
||||
PacketNumberLen3 PacketNumberLen = 3
|
||||
// PacketNumberLen4 is a packet number length of 4 bytes
|
||||
PacketNumberLen4 PacketNumberLen = 4
|
||||
)
|
||||
|
||||
// DecodePacketNumber calculates the packet number based its length and the last seen packet number
|
||||
// This function is taken from https://www.rfc-editor.org/rfc/rfc9000.html#section-a.3.
|
||||
func DecodePacketNumber(length PacketNumberLen, largest PacketNumber, truncated PacketNumber) PacketNumber {
|
||||
expected := largest + 1
|
||||
win := PacketNumber(1 << (length * 8))
|
||||
hwin := win / 2
|
||||
mask := win - 1
|
||||
candidate := (expected & ^mask) | truncated
|
||||
if candidate <= expected-hwin && candidate < 1<<62-win {
|
||||
return candidate + win
|
||||
}
|
||||
if candidate > expected+hwin && candidate >= win {
|
||||
return candidate - win
|
||||
}
|
||||
return candidate
|
||||
}
|
||||
|
||||
// PacketNumberLengthForHeader gets the length of the packet number for the public header
|
||||
// it never chooses a PacketNumberLen of 1 byte, since this is too short under certain circumstances
|
||||
func PacketNumberLengthForHeader(pn, largestAcked PacketNumber) PacketNumberLen {
|
||||
var numUnacked PacketNumber
|
||||
if largestAcked == InvalidPacketNumber {
|
||||
numUnacked = pn + 1
|
||||
} else {
|
||||
numUnacked = pn - largestAcked
|
||||
}
|
||||
if numUnacked < 1<<(16-1) {
|
||||
return PacketNumberLen2
|
||||
}
|
||||
if numUnacked < 1<<(24-1) {
|
||||
return PacketNumberLen3
|
||||
}
|
||||
return PacketNumberLen4
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user