diff --git a/.golangci.yml b/.golangci.yml index a4f3df5a..038c18c8 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -2,13 +2,13 @@ version: "2" linters: default: all disable: - - cyclop + #- cyclop - depguard - err113 # disabled temporarily: there are just too many issues to address - - errchkjson - - errorlint + #- errchkjson + #- errorlint - exhaustruct - - forcetypeassert + #- forcetypeassert - funlen - gochecknoglobals - gochecknoinits @@ -18,10 +18,10 @@ linters: - gomoddirectives # moved to mono-repo, multi-modules, so replace directives are needed - gosmopolitan - inamedparam - - ireturn - - lll + - ireturn # this repo adopted a pattern where there are quite many returned interfaces. To be challenged with v2 + #- lll - musttag - - nestif + #- nestif - nilerr # nilerr crashes on this repo - nlreturn - noinlineerr @@ -31,7 +31,7 @@ linters: - testpackage - thelper - tparallel - - unparam + #- unparam - varnamelen - whitespace - wrapcheck @@ -43,8 +43,17 @@ linters: goconst: min-len: 2 min-occurrences: 3 + cyclop: + max-complexity: 25 gocyclo: - min-complexity: 45 + min-complexity: 25 + gocognit: + min-complexity: 35 + exhaustive: + default-signifies-exhaustive: true + default-case-required: true + lll: + line-length: 180 exclusions: generated: lax presets: @@ -61,13 +70,17 @@ formatters: enable: - gofmt - goimports + settings: + # local prefixes regroup imports from these packages + goimports: + local-prefixes: + - github.com/go-openapi exclusions: generated: lax paths: - .worktrees - third_party$ - builtin$ - - examples$ issues: # Maximum issues count per one linter. # Set to 0 to disable. diff --git a/client-middleware/opentracing/opentracing.go b/client-middleware/opentracing/opentracing.go index df50f799..ffaf688e 100644 --- a/client-middleware/opentracing/opentracing.go +++ b/client-middleware/opentracing/opentracing.go @@ -7,11 +7,12 @@ import ( "fmt" "net/http" - "github.com/go-openapi/strfmt" "github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go/ext" "github.com/opentracing/opentracing-go/log" + "github.com/go-openapi/strfmt" + "github.com/go-openapi/runtime" ) diff --git a/client/content_negotiation_test.go b/client/content_negotiation_test.go index 568d7d7c..a0f74dca 100644 --- a/client/content_negotiation_test.go +++ b/client/content_negotiation_test.go @@ -238,9 +238,9 @@ func writerSetBody(payload any) runtime.ClientRequestWriter { }) } -func writerSetHeader(name, value string) runtime.ClientRequestWriter { +func writerSetContentType(value string) runtime.ClientRequestWriter { return runtime.ClientRequestWriterFunc(func(r runtime.ClientRequest, _ strfmt.Registry) error { - return r.SetHeaderParam(name, value) + return r.SetHeaderParam("Content-Type", value) }) } @@ -250,9 +250,9 @@ func writerSetForm(name string, values ...string) runtime.ClientRequestWriter { }) } -func writerSetFile(name string, files ...runtime.NamedReadCloser) runtime.ClientRequestWriter { +func writerSetFileToUpload(files ...runtime.NamedReadCloser) runtime.ClientRequestWriter { return runtime.ClientRequestWriterFunc(func(r runtime.ClientRequest, _ strfmt.Registry) error { - return r.SetFileParam(name, files...) + return r.SetFileParam("upload", files...) }) } @@ -303,8 +303,8 @@ type staticFileWithCT struct{ *staticFile } func (f *staticFileWithCT) ContentType() string { return f.ct } -func newFile(name, data string) runtime.NamedReadCloser { - return &staticFile{name: name, r: strings.NewReader(data)} +func newFile(data string) runtime.NamedReadCloser { + return &staticFile{name: "doc.txt", r: strings.NewReader(data)} } func newFileWithCT(name, data, ct string) runtime.NamedReadCloser { @@ -374,7 +374,7 @@ func payloadStructCases() iter.Seq[buildHTTPCase] { name: "struct + SetHeader Content-Type is ignored — picker wins", mediaType: runtime.JSONMime, writer: writerCompose( - writerSetHeader("Content-Type", "application/x-ignored"), + writerSetContentType("application/x-ignored"), writerSetBody(task{Content: "y"}), ), wantContentType: runtime.JSONMime, @@ -509,7 +509,7 @@ func payloadReaderCases() iter.Seq[buildHTTPCase] { name: "io.Reader + SetHeader Content-Type wins over picker", mediaType: runtime.JSONMime, writer: writerCompose( - writerSetHeader("Content-Type", vendorMime), + writerSetContentType(vendorMime), writerSetBody(bytes.NewReader([]byte("v"))), ), wantContentType: vendorMime, @@ -520,7 +520,7 @@ func payloadReaderCases() iter.Seq[buildHTTPCase] { name: "io.Reader + SetHeader wins over payload ContentType()", mediaType: runtime.JSONMime, writer: writerCompose( - writerSetHeader("Content-Type", "application/x-explicit"), + writerSetContentType("application/x-explicit"), writerSetBody(&readerWithCT{ Reader: strings.NewReader("body"), ct: ndjsonMime, @@ -535,7 +535,7 @@ func payloadReaderCases() iter.Seq[buildHTTPCase] { mediaType: runtime.JSONMime, consumes: []string{runtime.JSONMime, runtime.DefaultMime}, writer: writerCompose( - writerSetHeader("Content-Type", "application/x-explicit"), + writerSetContentType("application/x-explicit"), writerSetBody(bytes.NewReader([]byte("v"))), ), wantContentType: "application/x-explicit", @@ -546,7 +546,7 @@ func payloadReaderCases() iter.Seq[buildHTTPCase] { name: "io.ReadCloser + SetHeader Content-Type wins", mediaType: runtime.TextMime, writer: writerCompose( - writerSetHeader("Content-Type", vendorMime), + writerSetContentType(vendorMime), writerSetBody(io.NopCloser(strings.NewReader("data"))), ), wantContentType: vendorMime, @@ -615,7 +615,7 @@ func fileFieldCases() iter.Seq[buildHTTPCase] { { name: "file field + multipart mime", mediaType: runtime.MultipartFormMime, - writer: writerSetFile("upload", newFile("doc.txt", "filebody")), + writer: writerSetFileToUpload(newFile("filebody")), wantContentTypePrefix: runtime.MultipartFormMime + "; boundary=", wantBody: bodyContainsAll(`name="upload"`, `filename="doc.txt"`, "filebody"), }, @@ -624,14 +624,14 @@ func fileFieldCases() iter.Seq[buildHTTPCase] { // the file content travels as a regular form value. name: "file field + urlencoded mime — file inlined as form value", mediaType: runtime.URLencodedFormMime, - writer: writerSetFile("upload", newFile("doc.txt", "abc")), + writer: writerSetFileToUpload(newFile("abc")), wantContentType: runtime.URLencodedFormMime, wantBody: bodyContainsAll("upload=abc"), }, { name: "file with declared ContentType()", mediaType: runtime.MultipartFormMime, - writer: writerSetFile("upload", newFileWithCT("doc.txt", "x", "application/json")), + writer: writerSetFileToUpload(newFileWithCT("doc.txt", "x", "application/json")), wantContentTypePrefix: runtime.MultipartFormMime + "; boundary=", wantBody: bodyContainsAll("application/json"), }, @@ -646,7 +646,7 @@ func formAndFileFieldCases() iter.Seq[buildHTTPCase] { mediaType: runtime.MultipartFormMime, writer: writerCompose( writerSetForm("name", "fido"), - writerSetFile("upload", newFile("doc.txt", "filebody")), + writerSetFileToUpload(newFile("filebody")), ), wantContentTypePrefix: runtime.MultipartFormMime + "; boundary=", wantBody: bodyContainsAll(`name="name"`, "fido", `filename="doc.txt"`, "filebody"), @@ -656,7 +656,7 @@ func formAndFileFieldCases() iter.Seq[buildHTTPCase] { mediaType: runtime.URLencodedFormMime, writer: writerCompose( writerSetForm("name", "fido"), - writerSetFile("upload", newFile("doc.txt", "abc")), + writerSetFileToUpload(newFile("abc")), ), wantContentType: runtime.URLencodedFormMime, wantBody: bodyContainsAll("name=fido", "upload=abc"), @@ -748,7 +748,7 @@ func submitWiringCases() iter.Seq[submitCase] { name: "consumes [json] + SetHeader Content-Type — escape hatch wins", consumes: []string{runtime.JSONMime}, writer: writerCompose( - writerSetHeader("Content-Type", vendorMime), + writerSetContentType(vendorMime), writerSetBody(bytes.NewReader([]byte("data"))), ), wantContentType: vendorMime, diff --git a/client/httptrace_tls.go b/client/httptrace_tls.go index 8b0f555c..063fb259 100644 --- a/client/httptrace_tls.go +++ b/client/httptrace_tls.go @@ -71,7 +71,7 @@ func introspectTLSConfig(client *http.Client) *tls.Config { func (s *traceSession) emitTLSDiagnostic(state tls.ConnectionState, err error) { s.emitf("# TLS DIAGNOSTIC") - //nolint:exhaustive // tlsAxisGeneric is handled by the default branch. + // tlsAxisGeneric is handled by the default branch. switch axis := classifyTLSError(err); axis { case tlsAxisProtocolVersion: s.diagnoseProtocolVersion(state, err) @@ -215,7 +215,6 @@ func (s *traceSession) diagnoseCertInvalid(certInvalid x509.CertificateInvalidEr cert := certInvalid.Cert s.emitf("# reason: %s", certInvalidReasonName(certInvalid.Reason)) - //nolint:exhaustive // Less-common reasons render via the default branch (issuer + NotAfter dump). switch certInvalid.Reason { case x509.Expired: s.emitf("# leaf: subject=%s", cert.Subject) @@ -231,6 +230,7 @@ func (s *traceSession) diagnoseCertInvalid(certInvalid x509.CertificateInvalidEr s.emitf("# suggested: set TLSClientOptions.ServerName to match") s.emitf("# one of the cert SANs, or fix the cert.") default: + // Less-common reasons render via the default branch (issuer + NotAfter dump). s.emitf("# leaf: subject=%s, issuer=%s", cert.Subject, cert.Issuer) s.emitf("# NotBefore=%s", cert.NotBefore.UTC().Format(time.RFC3339)) s.emitf("# NotAfter=%s", cert.NotAfter.UTC().Format(time.RFC3339)) @@ -326,7 +326,7 @@ func cipherSuiteNames(ids []uint16) []string { // human-readable label. The stdlib does not expose a String() // method for these, so we keep a small table. // -//nolint:exhaustive // Anything outside the listed cases falls through to the numeric default. +// Anything outside the listed cases falls through to the numeric default. func certInvalidReasonName(r x509.InvalidReason) string { switch r { case x509.NotAuthorizedToSign: diff --git a/client/internal/request/request.go b/client/internal/request/request.go index cc564c30..d5e42438 100644 --- a/client/internal/request/request.go +++ b/client/internal/request/request.go @@ -6,6 +6,7 @@ package request import ( "bytes" "context" + "errors" "fmt" "io" "log" @@ -82,7 +83,7 @@ var _ runtime.ClientRequest = new(Request) // ensure compliance to the interface // [Request.SetHeaderParam] during WriteToRequest, and we treat that as an intentional escape hatch // 2. use payload's [runtime.ContentTyper] declaration (in this case, the produced payload knows its content type) // 3. use `application/octet-stream` if it is available in the registered producers -// 4. otherwise ser the picker's mediaType +// 4. otherwise set the picker's mediaType // // For multi-part requests, the content type of each part is auto-detected using the following sequence: // @@ -316,7 +317,9 @@ func (r *Request) SetConsumes(consumers []string) { // // On error the cancel is invoked internally and a no-op cancel is returned, // so callers can defer cancel unconditionally. -func (r *Request) BuildHTTPContext(parentCtx context.Context, mediaType, basePath string, producers map[string]runtime.Producer, registry strfmt.Registry, auth runtime.ClientAuthInfoWriter) (*http.Request, context.CancelFunc, error) { +func (r *Request) BuildHTTPContext(parentCtx context.Context, mediaType, basePath string, + producers map[string]runtime.Producer, registry strfmt.Registry, auth runtime.ClientAuthInfoWriter, +) (*http.Request, context.CancelFunc, error) { if err := r.writer.WriteToRequest(r, registry); err != nil { return nil, noop, err } @@ -362,11 +365,13 @@ func (r *Request) usesStreamingBody(mediaType string) bool { if (len(r.formFields) > 0 || len(r.fileFields) > 0) && r.isMultipart(mediaType) { return true } + if r.payload != nil { if _, ok := r.payload.(io.Reader); ok { return true } } + return false } @@ -409,7 +414,9 @@ func (r *Request) isMultipart(mediaType string) bool { // // Auth is trivial in this flow because the buffer is already populated when the auth helper // asks for the body via r.GetBody(). -func (r *Request) buildBufferedRequest(ctx context.Context, mediaType, basePath string, producers map[string]runtime.Producer, registry strfmt.Registry, auth runtime.ClientAuthInfoWriter) (*http.Request, error) { +func (r *Request) buildBufferedRequest(ctx context.Context, mediaType, basePath string, + producers map[string]runtime.Producer, registry strfmt.Registry, auth runtime.ClientAuthInfoWriter, +) (*http.Request, error) { var body io.Reader var err error @@ -450,7 +457,9 @@ func (r *Request) buildBufferedRequest(ctx context.Context, mediaType, basePath // (it would otherwise park forever on pw.Write with no reader). // // For stream payloads it closes the user-provided io.ReadCloser. -func (r *Request) buildStreamingRequest(ctx context.Context, mediaType, basePath string, producers map[string]runtime.Producer, registry strfmt.Registry, auth runtime.ClientAuthInfoWriter) (req *http.Request, retErr error) { +func (r *Request) buildStreamingRequest(ctx context.Context, mediaType, basePath string, + producers map[string]runtime.Producer, registry strfmt.Registry, auth runtime.ClientAuthInfoWriter, +) (req *http.Request, retErr error) { var body io.Reader if len(r.formFields) > 0 || len(r.fileFields) > 0 { body = r.writeMultipartBody(ctx, mediaType) @@ -603,7 +612,7 @@ func (r *Request) applyAuthWithBodyCopy(auth runtime.ClientAuthInfoWriter, body // underlying pipe/stream. Caller treats body as ignorable when // err != nil per Go convention; the defer reads it via closure. if copyErr != nil { - return body, fmt.Errorf("error retrieving the response body: %v", copyErr) + return body, fmt.Errorf("error copying the request body: %w", copyErr) } if authErr != nil { @@ -731,7 +740,7 @@ func (r *Request) streamMultipartParts(ctx context.Context, mp *multipart.Writer const contentTypeBufferSize = 512 buf := make([]byte, contentTypeBufferSize) size, err := fi.Read(buf) - if err != nil && err != io.EOF { + if err != nil && !errors.Is(err, io.EOF) { logClose(err, pw) return } @@ -789,7 +798,13 @@ func (r *Request) writeStreamPayload(mediaType string, producers map[string]runt if rdr, ok := r.payload.(io.ReadCloser); ok { return rdr } - return r.payload.(io.Reader) + + rdr, ok := r.payload.(io.Reader) + if !ok { + panic("internal error: payload expected to be an io.Reader") // guaranteed by earlier checks + } + + return rdr } // writeNonStreamPayload runs the producer registered for mediaType @@ -903,7 +918,8 @@ func logClose(err error, pw *io.PipeWriter) { } } -func mangleContentType(_, boundary string) string { +func mangleContentType(mediaType, boundary string) string { + _ = mediaType // reserved for future enhancement: honor caller-provided media type // Proposal for enhancement: honor caller's boundary if specified return "multipart/form-data; boundary=" + boundary } diff --git a/client/internal/request/request_test.go b/client/internal/request/request_test.go index 68d61f38..c5d157eb 100644 --- a/client/internal/request/request_test.go +++ b/client/internal/request/request_test.go @@ -304,7 +304,7 @@ func TestBuildRequest_BuildHTTP_XMLPayload(t *testing.T) { } func TestBuildRequest_BuildHTTP_TextPayload(t *testing.T) { - const bd = "Tom: Organ trail; John: Bird watching" + const bd = "Tom: Oregon trail; John: Bird watching" reqWrtr := runtime.ClientRequestWriterFunc(func(req runtime.ClientRequest, _ strfmt.Registry) error { _ = req.SetBodyParam(bd) diff --git a/client/opentelemetry.go b/client/opentelemetry.go index 69426745..d11f7919 100644 --- a/client/opentelemetry.go +++ b/client/opentelemetry.go @@ -8,14 +8,15 @@ import ( "net/http" "strings" - "github.com/go-openapi/runtime" - "github.com/go-openapi/strfmt" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" "go.opentelemetry.io/otel/propagation" semconv "go.opentelemetry.io/otel/semconv/v1.37.0" "go.opentelemetry.io/otel/trace" + + "github.com/go-openapi/runtime" + "github.com/go-openapi/strfmt" ) const ( diff --git a/client/opentelemetry_test.go b/client/opentelemetry_test.go index 6268e216..dc1e86ab 100644 --- a/client/opentelemetry_test.go +++ b/client/opentelemetry_test.go @@ -8,9 +8,6 @@ import ( "net/http" "testing" - "github.com/go-openapi/runtime" - "github.com/go-openapi/testify/v2/assert" - "github.com/go-openapi/testify/v2/require" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" @@ -18,6 +15,10 @@ import ( tracesdk "go.opentelemetry.io/otel/sdk/trace" "go.opentelemetry.io/otel/sdk/trace/tracetest" "go.opentelemetry.io/otel/trace" + + "github.com/go-openapi/runtime" + "github.com/go-openapi/testify/v2/assert" + "github.com/go-openapi/testify/v2/require" ) func Test_OpenTelemetryRuntime_submit(t *testing.T) { diff --git a/client/runtime.go b/client/runtime.go index 3c828d37..62576463 100644 --- a/client/runtime.go +++ b/client/runtime.go @@ -401,7 +401,7 @@ func (r *Runtime) dumpResponse(res *http.Response, ct string) error { // Falls back to the "*/*" entry if no match found. func (r *Runtime) resolveConsumer(ct string) (runtime.Consumer, error) { if _, _, err := mime.ParseMediaType(ct); err != nil { - return nil, fmt.Errorf("parse content type: %s", err) + return nil, fmt.Errorf("parse content type: %w", err) } if cons, ok := mediatype.Lookup(r.Consumers, ct, r.matchOpts()...); ok { return cons, nil diff --git a/client/runtime_test.go b/client/runtime_test.go index 4ff45cb0..db4b49df 100644 --- a/client/runtime_test.go +++ b/client/runtime_test.go @@ -118,7 +118,8 @@ func TestRuntime_Concurrent(t *testing.T) { require.NoError(t, err) assert.IsType(t, []task{}, res) - actual := res.([]task) + actual, ok := res.([]task) + require.TrueT(t, ok) assert.Equal(t, result, actual) } @@ -164,7 +165,8 @@ func TestRuntime_Canary(t *testing.T) { require.NoError(t, err) assert.IsType(t, []task{}, res) - actual := res.([]task) + actual, ok := res.([]task) + require.TrueT(t, ok) assert.Equal(t, result, actual) } @@ -216,7 +218,8 @@ func TestRuntime_XMLCanary(t *testing.T) { require.NoError(t, err) assert.IsType(t, tasks{}, res) - actual := res.(tasks) + actual, ok := res.(tasks) + require.TrueT(t, ok) assert.Equal(t, result, actual) } @@ -258,7 +261,8 @@ func TestRuntime_TextCanary(t *testing.T) { require.NoError(t, err) assert.IsType(t, "", res) - actual := res.(string) + actual, ok := res.(string) + require.TrueT(t, ok) assert.EqualT(t, result, actual) } @@ -303,7 +307,8 @@ func TestRuntime_CSVCanary(t *testing.T) { require.NoError(t, err) assert.IsType(t, bytes.Buffer{}, res) - actual := res.(bytes.Buffer) + actual, ok := res.(bytes.Buffer) + require.TrueT(t, ok) assert.EqualT(t, result, actual.String()) } @@ -361,7 +366,8 @@ func TestRuntime_CustomTransport(t *testing.T) { require.NoError(t, err) assert.IsType(t, []task{}, res) - actual := res.([]task) + actual, ok := res.([]task) + require.TrueT(t, ok) assert.Equal(t, result, actual) } @@ -475,7 +481,8 @@ func TestRuntime_AuthCanary(t *testing.T) { require.NoError(t, err) assert.IsType(t, []task{}, res) - actual := res.([]task) + actual, ok := res.([]task) + require.TrueT(t, ok) assert.Equal(t, result, actual) } @@ -575,7 +582,8 @@ func TestRuntime_PickConsumer(t *testing.T) { require.NoError(t, err) assert.IsType(t, []task{}, res) - actual := res.([]task) + actual, ok := res.([]task) + require.TrueT(t, ok) assert.Equal(t, result, actual) } @@ -628,7 +636,8 @@ func TestRuntime_ContentTypeCanary(t *testing.T) { require.NoError(t, err) assert.IsType(t, []task{}, res) - actual := res.([]task) + actual, ok := res.([]task) + require.TrueT(t, ok) assert.Equal(t, result, actual) } @@ -683,7 +692,8 @@ func TestRuntime_ChunkedResponse(t *testing.T) { require.NoError(t, err) assert.IsType(t, []task{}, res) - actual := res.([]task) + actual, ok := res.([]task) + require.TrueT(t, ok) assert.Equal(t, result, actual) } @@ -883,7 +893,8 @@ func TestRuntime_FallbackConsumer(t *testing.T) { require.NoError(t, err) assert.IsType(t, []byte{}, res) - actual := res.([]byte) + actual, ok := res.([]byte) + require.TrueT(t, ok) assert.EqualValues(t, result, actual) } @@ -933,7 +944,8 @@ func TestRuntime_AuthHeaderParamDetected(t *testing.T) { require.NoError(t, err) assert.IsType(t, []task{}, res) - actual := res.([]task) + actual, ok := res.([]task) + require.TrueT(t, ok) assert.Equal(t, result, actual) } @@ -941,7 +953,6 @@ func TestRuntime_Timeout(t *testing.T) { //nolint:maintidx // linter evaluates t const ( // these values should be sufficient for most CI engines clientTimeout time.Duration = 25 * time.Millisecond - serverDelay time.Duration = 100 * time.Millisecond clientNoTimeout time.Duration = 250 * time.Millisecond ctxError = "context deadline exceeded" ) @@ -974,7 +985,7 @@ func TestRuntime_Timeout(t *testing.T) { //nolint:maintidx // linter evaluates t }) t.Run("with timeout specified as a request parameter, no operation context", func(t *testing.T) { - host, cleaner := serverBuilder(t, serverDelay, result) + host, cleaner := serverBuilder(t, result) t.Cleanup(cleaner) rt := New(host, "/", []string{schemeHTTP}) @@ -1005,7 +1016,7 @@ func TestRuntime_Timeout(t *testing.T) { //nolint:maintidx // linter evaluates t }) t.Run("with timeout specified as a request parameter, no context at all", func(t *testing.T) { - host, cleaner := serverBuilder(t, serverDelay, result) + host, cleaner := serverBuilder(t, result) t.Cleanup(cleaner) rt := New(host, "/", []string{schemeHTTP}) @@ -1036,7 +1047,7 @@ func TestRuntime_Timeout(t *testing.T) { //nolint:maintidx // linter evaluates t }) t.Run("with inherited operation context, timeout specified as operation context, request timeout set to 0", func(t *testing.T) { - host, cleaner := serverBuilder(t, serverDelay, result) + host, cleaner := serverBuilder(t, result) t.Cleanup(cleaner) rt := New(host, "/", []string{schemeHTTP}) @@ -1073,7 +1084,7 @@ func TestRuntime_Timeout(t *testing.T) { //nolint:maintidx // linter evaluates t }) t.Run("with a fresh operation context, timeout specified as operation context, request timeout set to 0", func(t *testing.T) { - host, cleaner := serverBuilder(t, serverDelay, result) + host, cleaner := serverBuilder(t, result) t.Cleanup(cleaner) rt := New(host, "/", []string{schemeHTTP}) rt.Context = nil @@ -1110,7 +1121,7 @@ func TestRuntime_Timeout(t *testing.T) { //nolint:maintidx // linter evaluates t t.Run("with an hypothetical timeout specified as runtime context, no operation context", func(t *testing.T) { // in real life, the runtime context may be cancellable for other reasons than timeout - host, cleaner := serverBuilder(t, serverDelay, result) + host, cleaner := serverBuilder(t, result) t.Cleanup(cleaner) t.Run("should not time out", func(t *testing.T) { @@ -1151,7 +1162,7 @@ func TestRuntime_Timeout(t *testing.T) { //nolint:maintidx // linter evaluates t }) t.Run("with multiple timeouts set, shortest wins", func(t *testing.T) { - host, cleaner := serverBuilder(t, serverDelay, result) + host, cleaner := serverBuilder(t, result) t.Cleanup(cleaner) rt := New(host, "/", []string{schemeHTTP}) @@ -1219,7 +1230,7 @@ func TestRuntime_Timeout(t *testing.T) { //nolint:maintidx // linter evaluates t }) t.Run("with no context, explicit infinite wait", func(t *testing.T) { - host, cleaner := serverBuilder(t, serverDelay, result) + host, cleaner := serverBuilder(t, result) t.Cleanup(cleaner) rt := New(host, "/", []string{schemeHTTP}) @@ -1241,7 +1252,7 @@ func TestRuntime_Timeout(t *testing.T) { //nolint:maintidx // linter evaluates t requestEmptyWriter := runtime.ClientRequestWriterFunc(func(_ runtime.ClientRequest, _ strfmt.Registry) error { return nil }) - host, cleaner := serverBuilder(t, serverDelay, result) + host, cleaner := serverBuilder(t, result) t.Cleanup(cleaner) rt := New(host, "/", []string{schemeHTTP}) @@ -1336,7 +1347,8 @@ func TestGetBodyCallsBeforeRoundTrip(t *testing.T) { res, err := openAPIClient.Submit(operation) require.NoError(t, err) - actual := res.(string) + actual, ok := res.(string) + require.TrueT(t, ok) require.EqualT(t, "test result", actual) } @@ -1395,7 +1407,10 @@ func assertResult(result []task) func(testing.TB, any) { } } -func serverBuilder(t testing.TB, delay time.Duration, result []task) (string, func()) { +const serverDelay = 100 * time.Millisecond + +func serverBuilder(t testing.TB, result []task) (string, func()) { + delay := serverDelay server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { ctx := req.Context() timer := time.NewTimer(delay) diff --git a/client/tls.go b/client/tls.go index 94a005d0..017694fa 100644 --- a/client/tls.go +++ b/client/tls.go @@ -46,6 +46,7 @@ type TLSClientOptions struct { // If set, it will be combined with the other loaded certificates (see LoadedCA and CA). // If neither LoadedCA or CA is set, the provided pool will override the system // certificate pool. + // // The caller must not use the supplied pool after calling TLSClientAuth. LoadedCAPool *x509.CertPool @@ -104,7 +105,7 @@ func TLSClientAuth(opts TLSClientOptions) (*tls.Config, error) { if opts.Certificate != "" { cert, err := tls.LoadX509KeyPair(opts.Certificate, opts.Key) if err != nil { - return nil, fmt.Errorf("tls client cert: %v", err) + return nil, fmt.Errorf("tls client cert: %w", err) } cfg.Certificates = []tls.Certificate{cert} } else if opts.LoadedCertificate != nil { @@ -115,7 +116,7 @@ func TLSClientAuth(opts TLSClientOptions) (*tls.Config, error) { // understands) and pairs with the canonical "PRIVATE KEY" PEM label. keyBytes, err := x509.MarshalPKCS8PrivateKey(opts.LoadedKey) if err != nil { - return nil, fmt.Errorf("tls client priv key: %v", err) + return nil, fmt.Errorf("tls client priv key: %w", err) } block = pem.Block{Type: "PRIVATE KEY", Bytes: keyBytes} @@ -123,7 +124,7 @@ func TLSClientAuth(opts TLSClientOptions) (*tls.Config, error) { cert, err := tls.X509KeyPair(certPem, keyPem) if err != nil { - return nil, fmt.Errorf("tls client cert: %v", err) + return nil, fmt.Errorf("tls client cert: %w", err) } cfg.Certificates = []tls.Certificate{cert} } @@ -147,7 +148,7 @@ func TLSClientAuth(opts TLSClientOptions) (*tls.Config, error) { // load ca cert caCert, err := os.ReadFile(opts.CA) if err != nil { - return nil, fmt.Errorf("tls client ca: %v", err) + return nil, fmt.Errorf("tls client ca: %w", err) } caCertPool := basePool(opts.LoadedCAPool) caCertPool.AppendCertsFromPEM(caCert) @@ -165,7 +166,7 @@ func TLSClientAuth(opts TLSClientOptions) (*tls.Config, error) { return cfg, nil } -// TLSTransport creates a [http] client transport suitable for mutual [tls] auth. +// TLSTransport creates a [http.RoundTripper] for a client transport,suitable for mutual TLS auth. func TLSTransport(opts TLSClientOptions) (http.RoundTripper, error) { cfg, err := TLSClientAuth(opts) if err != nil { @@ -185,9 +186,12 @@ func TLSClient(opts TLSClientOptions) (*http.Client, error) { } // basePool returns pool if non-nil; otherwise it returns a new empty cert pool. +// +// Clones the pool provided up front by the caller. func basePool(pool *x509.CertPool) *x509.CertPool { if pool == nil { return x509.NewCertPool() } + return pool.Clone() } diff --git a/client_request_test.go b/client_request_test.go index 4dae9dd1..96441fb6 100644 --- a/client_request_test.go +++ b/client_request_test.go @@ -8,6 +8,7 @@ import ( "github.com/go-openapi/strfmt" "github.com/go-openapi/testify/v2/assert" + "github.com/go-openapi/testify/v2/require" ) func TestRequestWriterFunc(t *testing.T) { @@ -20,5 +21,8 @@ func TestRequestWriterFunc(t *testing.T) { tr := new(TestClientRequest) _ = hand.WriteToRequest(tr, nil) assert.EqualT(t, "blahblah", tr.Headers.Get("Blah")) - assert.EqualT(t, "Adriana", tr.Body.(struct{ Name string }).Name) + + body, ok := tr.Body.(struct{ Name string }) + require.TrueT(t, ok) + assert.EqualT(t, "Adriana", body.Name) } diff --git a/client_response.go b/client_response.go index 92668db4..7b4b7e40 100644 --- a/client_response.go +++ b/client_response.go @@ -59,7 +59,7 @@ func (o *APIError) Error() string { if err, ok := o.Response.(error); ok { resp = []byte("'" + sanitizer.Replace(err.Error()) + "'") } else { - resp, _ = json.Marshal(o.Response) + resp, _ = json.Marshal(o.Response) //nolint:errchkjson // error swallowed as this is our last best effort attempt } return fmt.Sprintf("%s (status %d): %s", o.OperationName, o.Code, resp) diff --git a/docs/examples/client/tracing/main.go b/docs/examples/client/tracing/main.go index 1c8d3698..bac3a433 100644 --- a/docs/examples/client/tracing/main.go +++ b/docs/examples/client/tracing/main.go @@ -10,11 +10,12 @@ package main import ( + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + "github.com/go-openapi/runtime" "github.com/go-openapi/runtime/client" "github.com/go-openapi/strfmt" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" ) const samplePetID = 42 diff --git a/docs/examples/server/pipeline/main.go b/docs/examples/server/pipeline/main.go index 07ef137b..82ec9dec 100644 --- a/docs/examples/server/pipeline/main.go +++ b/docs/examples/server/pipeline/main.go @@ -15,12 +15,13 @@ import ( "os" "time" + "github.com/justinas/alice" + "github.com/go-openapi/analysis" "github.com/go-openapi/loads" "github.com/go-openapi/runtime" "github.com/go-openapi/runtime/middleware" "github.com/go-openapi/runtime/middleware/untyped" - "github.com/justinas/alice" ) // --- Stubs (excluded from rendered snippets) ------------------------ diff --git a/form.go b/form.go index fe5333b3..63f57b73 100644 --- a/form.go +++ b/form.go @@ -242,7 +242,9 @@ func parseFormBody(r *http.Request, maxMemory, maxBody int64) error { mt, _, _ := ContentType(r.Header) if mt == MultipartFormMime { - //nolint:gosec // G120: false positive (gosec doesn't track the Body). See https://github.com/securego/gosec/blob/de65614d10a6b84029e3e1215567b8ce7e490f23/testutils/g120_samples.go#L57 + //nolint:gosec // G120: false positive -- see below + // gosec doesn't track the Body. + // See https://github.com/securego/gosec/blob/de65614d10a6b84029e3e1215567b8ce7e490f23/testutils/g120_samples.go#L57 return r.ParseMultipartForm(maxMemory) } return r.ParseForm() diff --git a/form_test.go b/form_test.go index dab33943..d11d123b 100644 --- a/form_test.go +++ b/form_test.go @@ -92,6 +92,8 @@ func assertParseError(t *testing.T, err error, wantName string, reasonCheck func // assertCompositeContains extracts a *errors.CompositeError from err // and asserts that at least n inner errors satisfy match. +// +//nolint:unparam // left variable n for future assertions func assertCompositeContains(t *testing.T, err error, n int, match func(error) bool) { t.Helper() require.Error(t, err) diff --git a/internal/testing/data.go b/internal/testing/data.go index 2ba4f30a..0690217a 100644 --- a/internal/testing/data.go +++ b/internal/testing/data.go @@ -780,7 +780,7 @@ const InvalidJSON = `{ }, "info": { "contact": "apiteam@wordnik.com", - "description": "This is a sample server Petstore server. You can find out more about Swagger \n at http://swagger.wordnik.com or on irc.freenode.net, #swagger. For this sample,\n you can use the api key \"special-key\" to test the authorization filters", + "description": "This is a sample server Petstore server...", "license": "Apache 2.0", "licenseUrl": "http://www.apache.org/licenses/LICENSE-2.0.html", "termsOfServiceUrl": "http://helloreverb.com/terms/", diff --git a/internal/testing/petstore/api.go b/internal/testing/petstore/api.go index 7e0b8556..0d3fcefb 100644 --- a/internal/testing/petstore/api.go +++ b/internal/testing/petstore/api.go @@ -60,13 +60,19 @@ func NewAPI(t gotest.TB) (*loads.Document, *untyped.API) { return nil, errors.Unauthenticated("token") })) api.RegisterAuthorizer(runtime.AuthorizerFunc(func(r *http.Request, user any) error { - if r.Method == http.MethodPost && strings.HasPrefix(r.URL.Path, "/api/pets") && user.(string) != apiPrincipal { - if user.(string) == apiUser { + userStr, ok := user.(string) + if !ok { + return goerrors.New("unauthorized") + } + + if r.Method == http.MethodPost && strings.HasPrefix(r.URL.Path, "/api/pets") && userStr != apiPrincipal { + if userStr == apiUser { return errors.CompositeValidationError(errors.New(errors.InvalidTypeCode, "unauthorized")) } return goerrors.New("unauthorized") } + return nil })) api.RegisterOperation("get", "/pets", new(stubOperationHandler)) diff --git a/internal/testing/simplepetstore/api.go b/internal/testing/simplepetstore/api.go index 020809d6..a51ed1b9 100644 --- a/internal/testing/simplepetstore/api.go +++ b/internal/testing/simplepetstore/api.go @@ -5,6 +5,7 @@ package simplepetstore import ( "encoding/json" + stderrors "errors" "net/http" "sync" "sync/atomic" @@ -38,21 +39,55 @@ var getAllPets = runtime.OperationHandlerFunc(func(_ any) (any, error) { }) var createPet = runtime.OperationHandlerFunc(func(data any) (any, error) { - body := data.(map[string]any)["pet"].(map[string]any) + asMap, ok := data.(map[string]any) + if !ok { + return nil, stderrors.New("bad data: wanted map") + } + pet := asMap["pet"] + body, ok := pet.(map[string]any) + if !ok { + return nil, stderrors.New("bad pet body: wanted map") + } + + name, ok := body["name"].(string) + if !ok { + return nil, stderrors.New("bad name: wanted string") + } + status, ok := body["status"].(string) + if !ok { + return nil, stderrors.New("bad status: wanted string") + } + return addPet(Pet{ - Name: body["name"].(string), - Status: body["status"].(string), + Name: name, + Status: status, }), nil }) var deletePet = runtime.OperationHandlerFunc(func(data any) (any, error) { - id := data.(map[string]any)["id"].(int64) + asMap, ok := data.(map[string]any) + if !ok { + return nil, stderrors.New("bad data: wanted map") + } + id, ok := asMap["id"].(int64) + if !ok { + return nil, stderrors.New("bad id: wanted int64") + } + removePet(id) return map[string]any{}, nil }) var getPetByID = runtime.OperationHandlerFunc(func(data any) (any, error) { - id := data.(map[string]any)["id"].(int64) + asMap, ok := data.(map[string]any) + if !ok { + return nil, stderrors.New("bad data: wanted map") + } + id, ok := asMap["id"].(int64) + if !ok { + return nil, stderrors.New("bad id: wanted int64") + } + return petByID(id) }) diff --git a/middleware/context.go b/middleware/context.go index 2680c1cd..0942edef 100644 --- a/middleware/context.go +++ b/middleware/context.go @@ -5,6 +5,7 @@ package middleware import ( stdContext "context" + stderrors "errors" "fmt" "net/http" "strings" @@ -187,53 +188,57 @@ func newRoutableUntypedAPI(spec *loads.Document, api *untyped.API, context *Cont if spec == nil || api == nil { return nil } + analyzer := analysis.New(spec.Spec()) for method, hls := range analyzer.Operations() { um := strings.ToUpper(method) for path, op := range hls { schemes := analyzer.SecurityRequirementsFor(op) - if oh, ok := api.OperationHandlerFor(method, path); ok { - if handlers == nil { - handlers = make(map[string]map[string]http.Handler) + oh, ok := api.OperationHandlerFor(method, path) + if !ok { + continue + } + + if handlers == nil { + handlers = make(map[string]map[string]http.Handler) + } + if b, ok := handlers[um]; !ok || b == nil { + handlers[um] = make(map[string]http.Handler) + } + + var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // lookup route info in the context + route, rCtx, _ := context.RouteInfo(r) + if rCtx != nil { + r = rCtx } - if b, ok := handlers[um]; !ok || b == nil { - handlers[um] = make(map[string]http.Handler) + + // bind and validate the request using reflection + var bound any + var validation error + bound, r, validation = context.BindAndValidate(r, route) + if validation != nil { + context.Respond(w, r, route.Produces, route, validation) + return } - var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // lookup route info in the context - route, rCtx, _ := context.RouteInfo(r) - if rCtx != nil { - r = rCtx - } - - // bind and validate the request using reflection - var bound any - var validation error - bound, r, validation = context.BindAndValidate(r, route) - if validation != nil { - context.Respond(w, r, route.Produces, route, validation) - return - } - - // actually handle the request - result, err := oh.Handle(bound) - if err != nil { - // respond with failure - context.Respond(w, r, route.Produces, route, err) - return - } - - // respond with success - context.Respond(w, r, route.Produces, route, result) - }) - - if len(schemes) > 0 { - handler = newSecureAPI(context, handler) + // actually handle the request + result, err := oh.Handle(bound) + if err != nil { + // respond with failure + context.Respond(w, r, route.Produces, route, err) + return } - handlers[um][path] = handler + + // respond with success + context.Respond(w, r, route.Produces, route, result) + }) + + if len(schemes) > 0 { + handler = newSecureAPI(context, handler) } + handlers[um][path] = handler } } @@ -357,57 +362,42 @@ func (c *Context) RequiredProduces() []string { // BindValidRequest binds a params object to a request but only when the request is valid // if the request is not valid an error will be returned. func (c *Context) BindValidRequest(request *http.Request, route *MatchedRoute, binder RequestBinder) error { - var res []error var requestContentType string // check and validate content type, select consumer if runtime.HasBody(request) { - ct, _, err := runtime.ContentType(request.Header) + ct, cons, err := c.bindRequestBody(request, route) if err != nil { - res = append(res, err) - } else { - c.debugLogf("validating content type for %q against [%s]", ct, strings.Join(route.Consumes, ", ")) - if err := validateContentType(route.Consumes, ct); err != nil { - res = append(res, err) - } - if len(res) == 0 { - cons, ok := mediatype.Lookup(route.Consumers, ct, c.matchOpts()...) - if !ok { - res = append(res, errors.New(http.StatusInternalServerError, "no consumer registered for %s", ct)) - } else { - route.Consumer = cons - requestContentType = ct - } - } + return errors.CompositeValidationError(err) } + + // happy path + requestContentType = ct + route.Consumer = cons } // check and validate the response format - if len(res) == 0 { - // if the route does not provide Produces and a default contentType could not be identified - // based on a body, typical for GET and DELETE requests, then default contentType to. - if len(route.Produces) == 0 && requestContentType == "" { - requestContentType = "*/*" - } + // if the route does not provide Produces and a default contentType could not be identified + // based on a body, typical for GET and DELETE requests, then default contentType to. + if len(route.Produces) == 0 && requestContentType == "" { + requestContentType = "*/*" + } - if str := negotiate.ContentType(request, route.Produces, requestContentType, c.negotiateOpts()...); str == "" { - res = append(res, errors.InvalidResponseFormat(request.Header.Get(runtime.HeaderAccept), route.Produces)) - } + str := negotiate.ContentType(request, route.Produces, requestContentType, c.negotiateOpts()...) + if str == "" { + return errors.CompositeValidationError( + errors.InvalidResponseFormat(request.Header.Get(runtime.HeaderAccept), route.Produces), + ) + } + + if binder == nil { + return nil } // now bind the request with the provided binder // it's assumed the binder will also validate the request and return an error if the // request is invalid - if binder != nil && len(res) == 0 { - if err := binder.BindRequest(request, route); err != nil { - return err - } - } - - if len(res) > 0 { - return errors.CompositeValidationError(res...) - } - return nil + return binder.BindRequest(request, route) } // ContentType gets the parsed value of a content type @@ -515,7 +505,8 @@ func (c *Context) Authorize(request *http.Request, route *MatchedRoute) (any, *h } if route.Authorizer != nil { if err := route.Authorizer.Authorize(request, usr); err != nil { - if _, ok := err.(errors.Error); ok { + var apiError errors.Error + if stderrors.As(err, &apiError) { return nil, nil, err } @@ -562,91 +553,29 @@ func (c *Context) NotFound(rw http.ResponseWriter, r *http.Request) { // Respond renders the response after doing some content negotiation. func (c *Context) Respond(rw http.ResponseWriter, r *http.Request, produces []string, route *MatchedRoute, data any) { c.debugLogf("responding to %s %s with produces: %v", r.Method, r.URL.Path, produces) - offers := []string{} - for _, mt := range produces { - if mt != c.api.DefaultProduces() { - offers = append(offers, mt) - } - } - // the default producer is last so more specific producers take precedence - offers = append(offers, c.api.DefaultProduces()) - c.debugLogf("offers: %v", offers) + offers := c.buildOffers(produces) var format string format, r = c.ResponseFormat(r, offers) rw.Header().Set(runtime.HeaderContentType, format) if resp, ok := data.(Responder); ok { - producers := route.Producers - // producers contains keys with normalized format, if a format has MIME type parameter such as `text/plain; charset=utf-8` - // then you must provide `text/plain` to get the correct producer. HOWEVER, format here is not normalized. - prod, ok := producers[normalizeOffer(format)] - if !ok { - prods := c.api.ProducersFor(normalizeOffers([]string{c.api.DefaultProduces()})) - pr, ok := prods[c.api.DefaultProduces()] - if !ok { - panic(fmt.Errorf("%d: %s", http.StatusInternalServerError, cantFindProducer(format))) - } - prod = pr - } - resp.WriteResponse(rw, prod) + c.respondWithResponder(rw, r, route, resp, format) return } if err, ok := data.(error); ok { - if format == "" { - rw.Header().Set(runtime.HeaderContentType, runtime.JSONMime) - } - - if realm := security.FailedBasicAuth(r); realm != "" { - rw.Header().Set("WWW-Authenticate", fmt.Sprintf("Basic realm=%q", realm)) - } - - if route == nil || route.Operation == nil { - c.api.ServeErrorFor("")(rw, r, err) - return - } - c.api.ServeErrorFor(route.Operation.ID)(rw, r, err) + c.respondWithError(rw, r, produces, route, err, format) return } if route == nil || route.Operation == nil { - rw.WriteHeader(http.StatusOK) - if r.Method == http.MethodHead { - return - } - producers := c.api.ProducersFor(normalizeOffers(offers)) - prod, ok := producers[format] - if !ok { - panic(fmt.Errorf("%d: %s", http.StatusInternalServerError, cantFindProducer(format))) - } - if err := prod.Produce(rw, data); err != nil { - panic(err) // let the recovery middleware deal with this - } + c.respondWithoutCode(rw, r, data, format, offers) return } if _, code, ok := route.Operation.SuccessResponse(); ok { - rw.WriteHeader(code) - if code == http.StatusNoContent || r.Method == http.MethodHead { - return - } - - producers := route.Producers - prod, ok := producers[format] - if !ok { - if !ok { - prods := c.api.ProducersFor(normalizeOffers([]string{c.api.DefaultProduces()})) - pr, ok := prods[c.api.DefaultProduces()] - if !ok { - panic(fmt.Errorf("%d: %s", http.StatusInternalServerError, cantFindProducer(format))) - } - prod = pr - } - } - if err := prod.Produce(rw, data); err != nil { - panic(err) // let the recovery middleware deal with this - } + c.respondWithCode(rw, r, route, code, data, format) return } @@ -738,6 +667,120 @@ func (c *Context) RoutesHandler(builder Builder) http.Handler { return NewRouter(c, b(NewOperationExecutor(c))) } +func (c *Context) bindRequestBody(request *http.Request, route *MatchedRoute) (string, runtime.Consumer, error) { + ct, _, err := runtime.ContentType(request.Header) + if err != nil { + return "", nil, err + } + + c.debugLogf("validating content type for %q against [%s]", ct, strings.Join(route.Consumes, ", ")) + if err := validateContentType(route.Consumes, ct); err != nil { + return "", nil, err + } + + cons, ok := mediatype.Lookup(route.Consumers, ct, c.matchOpts()...) + if !ok { + return "", nil, errors.New(http.StatusInternalServerError, "no consumer registered for %s", ct) + } + + return ct, cons, nil +} + +func (c *Context) respondWithResponder(rw http.ResponseWriter, r *http.Request, route *MatchedRoute, resp Responder, format string) { + _ = r + producers := route.Producers + + // producers contains keys with normalized format, if a format has MIME type parameter such as `text/plain; charset=utf-8` + // then you must provide `text/plain` to get the correct producer. HOWEVER, format here is not normalized. + prod, ok := producers[normalizeOffer(format)] + if !ok { + prods := c.api.ProducersFor(normalizeOffers([]string{c.api.DefaultProduces()})) + pr, ok := prods[c.api.DefaultProduces()] + if !ok { + panic(fmt.Errorf("%d: %s", http.StatusInternalServerError, cantFindProducer(format))) + } + prod = pr + } + + resp.WriteResponse(rw, prod) +} + +func (c *Context) respondWithError(rw http.ResponseWriter, r *http.Request, produces []string, route *MatchedRoute, err error, format string) { + _ = produces + + if format == "" { + rw.Header().Set(runtime.HeaderContentType, runtime.JSONMime) + } + + if realm := security.FailedBasicAuth(r); realm != "" { + rw.Header().Set("WWW-Authenticate", fmt.Sprintf("Basic realm=%q", realm)) + } + + if route == nil || route.Operation == nil { + c.api.ServeErrorFor("")(rw, r, err) + return + } + + c.api.ServeErrorFor(route.Operation.ID)(rw, r, err) +} + +func (c *Context) respondWithoutCode(rw http.ResponseWriter, r *http.Request, data any, format string, offers []string) { + rw.WriteHeader(http.StatusOK) + if r.Method == http.MethodHead { + return + } + + producers := c.api.ProducersFor(normalizeOffers(offers)) + prod, ok := producers[format] + if !ok { + panic(fmt.Errorf("%d: %s", http.StatusInternalServerError, cantFindProducer(format))) + } + + if err := prod.Produce(rw, data); err != nil { + panic(err) // let the recovery middleware deal with this + } +} + +func (c *Context) buildOffers(produces []string) []string { + offers := make([]string, 0, len(produces)+1) + + for _, mt := range produces { + if mt != c.api.DefaultProduces() { + offers = append(offers, mt) + } + } + + // the default producer is last so more specific producers take precedence + offers = append(offers, c.api.DefaultProduces()) + c.debugLogf("offers: %v", offers) + + return offers +} + +func (c *Context) respondWithCode(rw http.ResponseWriter, r *http.Request, route *MatchedRoute, code int, data any, format string) { + rw.WriteHeader(code) + if code == http.StatusNoContent || r.Method == http.MethodHead { + return + } + + producers := route.Producers + prod, ok := producers[format] + if !ok { + if !ok { + prods := c.api.ProducersFor(normalizeOffers([]string{c.api.DefaultProduces()})) + pr, ok := prods[c.api.DefaultProduces()] + if !ok { + panic(fmt.Errorf("%d: %s", http.StatusInternalServerError, cantFindProducer(format))) + } + prod = pr + } + } + + if err := prod.Produce(rw, data); err != nil { + panic(err) // let the recovery middleware deal with this + } +} + // uiOptionsForHandler bridges the deprecated [UIOption] set to the new [docui.Option] set. func (c Context) uiOptionsForHandler(opts []UIOption) []docui.Option { uiOpts := uiOptionsWithDefaults(opts) diff --git a/middleware/context_test.go b/middleware/context_test.go index 5a700f4b..203daadc 100644 --- a/middleware/context_test.go +++ b/middleware/context_test.go @@ -49,12 +49,12 @@ func assertAPIError(t *testing.T, wantCode int, err error) { require.Error(t, err) - ce, ok := err.(*apierrors.CompositeError) - assert.TrueT(t, ok) - assert.NotEmpty(t, ce.Errors) + var ce *apierrors.CompositeError + require.TrueT(t, errors.As(err, &ce)) + require.NotEmpty(t, ce.Errors) - ae, ok := ce.Errors[0].(apierrors.Error) - assert.TrueT(t, ok) + var ae apierrors.Error + require.TrueT(t, errors.As(ce.Errors[0], &ae)) assert.EqualT(t, wantCode, int(ae.Code())) } @@ -425,8 +425,8 @@ func TestContextAuthorize_WithAuthorizer(t *testing.T) { request.SetBasicAuth("anyother", "anyother") p, reqWithCtx, err = ctx.Authorize(request, ri) require.Error(t, err) - ae, ok := err.(apierrors.Error) - require.TrueTf(t, ok, "expected an apierrors.Error, but got %T", err) + var ae apierrors.Error + require.TrueTf(t, errors.As(err, &ae), "expected an apierrors.Error, but got %T", err) assert.EqualT(t, http.StatusForbidden, int(ae.Code())) assert.Nil(t, p) assert.Nil(t, reqWithCtx) diff --git a/middleware/denco/router.go b/middleware/denco/router.go index 82ee80c8..e380a138 100644 --- a/middleware/denco/router.go +++ b/middleware/denco/router.go @@ -30,8 +30,8 @@ const ( // PathParamCharacter indicates a RESTCONF path param. PathParamCharacter = '=' - // MaxSize is max size of records and internal slice. - MaxSize = (1 << 22) - 1 //nolint:mnd + // MaxSize is the maximum size of records and internal slice (encoded over 22 bits). + MaxSize = (1 << baseBits) - 1 ) // Router represents a URL router. @@ -54,9 +54,12 @@ func New() *Router { } } -// Lookup returns data and path parameters that associated with path. +// Lookup returns data and path parameters which are associated to the path. +// // params is a slice of the [Param] that arranged in the order in which parameters appeared. -// e.g. when built routing path is "/path/to/:id/:name" and given path is "/path/to/1/alice". params order is [{"id": "1"}, {"name": "alice"}], not [{"name": "alice"}, {"id": "1"}]. +// +// e.g. when built routing path is "/path/to/:id/:name" and given path is "/path/to/1/alice", +// params order is [{"id": "1"}, {"name": "alice"}], not [{"name": "alice"}, {"id": "1"}]. func (rt *Router) Lookup(path string) (data any, params Params, found bool) { if data, found = rt.static[path]; found { return data, nil, true @@ -145,6 +148,7 @@ func newDoubleArray() *doubleArray { type baseCheck uint32 const ( + baseBits = 22 flagsBits = 10 checkBits = 8 ) @@ -158,7 +162,7 @@ func (bc *baseCheck) SetBase(base int) { } func (bc baseCheck) Check() byte { - return byte(bc) //nolint:gosec // integer conversion is ok + return byte(bc) //nolint:gosec // integer conversion is ok: we pick the last 8 bits } func (bc *baseCheck) SetCheck(check byte) { diff --git a/middleware/denco/router_test.go b/middleware/denco/router_test.go index 1c17ee2b..9eeb84b5 100644 --- a/middleware/denco/router_test.go +++ b/middleware/denco/router_test.go @@ -196,7 +196,9 @@ func runLookupTest(t *testing.T, records []denco.Record, testcases []testcase) { for _, testcase := range testcases { data, params, found := r.Lookup(testcase.path) if !reflect.DeepEqual(data, testcase.value) || !reflect.DeepEqual(params, denco.Params(testcase.params)) || !reflect.DeepEqual(found, testcase.found) { - t.Errorf("Router.Lookup(%q) => (%#v, %#v, %#v), want (%#v, %#v, %#v)", testcase.path, data, params, found, testcase.value, denco.Params(testcase.params), testcase.found) + t.Errorf("Router.Lookup(%q) => (%#v, %#v, %#v), want (%#v, %#v, %#v)", + testcase.path, data, params, found, testcase.value, denco.Params(testcase.params), testcase.found, + ) } } } @@ -317,26 +319,41 @@ func TestRouter_Lookup_realURIs(t *testing.T) { {pathGists, pathGists, nil, true}, {"/gists/1", pathGistsID, []denco.Param{{paramID, "1"}}, true}, {"/gists/2/star", pathGistsIDStar, []denco.Param{{paramID, "2"}}, true}, - {"/repos/naoina/denco/git/blobs/03c3bbc7f0d12268b9ca53d4fbfd8dc5ae5697b9", pathReposGitBlobs, []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}, {paramSHA, valSHA1}}, true}, - {"/repos/naoina/denco/git/commits/03c3bbc7f0d12268b9ca53d4fbfd8dc5ae5697b9", pathReposGitCommits, []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}, {paramSHA, valSHA1}}, true}, + {"/repos/naoina/denco/git/blobs/03c3bbc7f0d12268b9ca53d4fbfd8dc5ae5697b9", pathReposGitBlobs, + []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}, {paramSHA, valSHA1}}, true, + }, + {"/repos/naoina/denco/git/commits/03c3bbc7f0d12268b9ca53d4fbfd8dc5ae5697b9", pathReposGitCommits, + []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}, {paramSHA, valSHA1}}, true, + }, {"/repos/naoina/denco/git/refs", pathReposGitRefs, []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}}, true}, - {"/repos/naoina/denco/git/tags/03c3bbc7f0d12268b9ca53d4fbfd8dc5ae5697b9", pathReposGitTags, []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}, {paramSHA, valSHA1}}, true}, - {"/repos/naoina/denco/git/trees/03c3bbc7f0d12268b9ca53d4fbfd8dc5ae5697b9", pathReposGitTrees, []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}, {paramSHA, valSHA1}}, true}, + {"/repos/naoina/denco/git/tags/03c3bbc7f0d12268b9ca53d4fbfd8dc5ae5697b9", pathReposGitTags, + []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}, {paramSHA, valSHA1}}, true, + }, + {"/repos/naoina/denco/git/trees/03c3bbc7f0d12268b9ca53d4fbfd8dc5ae5697b9", pathReposGitTrees, + []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}, {paramSHA, valSHA1}}, true, + }, {pathIssues, pathIssues, nil, true}, {pathUserIssues, pathUserIssues, nil, true}, {"/orgs/something/issues", pathOrgsIssues, []denco.Param{{paramOrg, valSomething}}, true}, {"/repos/naoina/denco/issues", pathReposIssues, []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}}, true}, {"/repos/naoina/denco/issues/1", pathReposIssue, []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}, {paramNumber, "1"}}, true}, {"/repos/naoina/denco/assignees", pathReposAssignees, []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}}, true}, - {"/repos/naoina/denco/assignees/foo", pathReposAssignee, []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}, {"assignee", valFoo}}, true}, - {"/repos/naoina/denco/issues/1/comments", pathReposIssueComments, []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}, {paramNumber, "1"}}, true}, + {"/repos/naoina/denco/assignees/foo", pathReposAssignee, + []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}, {"assignee", valFoo}}, true, + }, + {"/repos/naoina/denco/issues/1/comments", pathReposIssueComments, + []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}, {paramNumber, "1"}}, true}, {"/repos/naoina/denco/issues/1/events", pathReposIssueEvents, []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}, {paramNumber, "1"}}, true}, {"/repos/naoina/denco/labels", pathReposLabels, []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}}, true}, {"/repos/naoina/denco/labels/bug", pathReposLabel, []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}, {"name", "bug"}}, true}, {"/repos/naoina/denco/issues/1/labels", pathReposIssueLabels, []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}, {paramNumber, "1"}}, true}, - {"/repos/naoina/denco/milestones/1/labels", pathReposMilestoneLabels, []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}, {paramNumber, "1"}}, true}, + {"/repos/naoina/denco/milestones/1/labels", pathReposMilestoneLabels, + []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}, {paramNumber, "1"}}, true, + }, {"/repos/naoina/denco/milestones", pathReposMilestones, []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}}, true}, - {"/repos/naoina/denco/milestones/1", pathReposMilestone, []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}, {paramNumber, "1"}}, true}, + {"/repos/naoina/denco/milestones/1", pathReposMilestone, + []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}, {paramNumber, "1"}}, true, + }, {pathEmojis, pathEmojis, nil, true}, {pathGitignoreTemplates, pathGitignoreTemplates, nil, true}, {"/gitignore/templates/Go", pathGitignoreTemplate, []denco.Param{{"name", "Go"}}, true}, @@ -376,10 +393,14 @@ func TestRouter_Lookup_realURIs(t *testing.T) { {"/repos/naoina/denco/collaborators", pathReposCollaborators, []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}}, true}, {"/repos/naoina/denco/collaborators/something", pathReposCollaborator, []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}, {paramUser, valSomething}}, true}, {"/repos/naoina/denco/comments", pathReposComments, []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}}, true}, - {"/repos/naoina/denco/commits/03c3bbc7f0d12268b9ca53d4fbfd8dc5ae5697b9/comments", pathReposCommitsSHAComments, []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}, {paramSHA, valSHA1}}, true}, + {"/repos/naoina/denco/commits/03c3bbc7f0d12268b9ca53d4fbfd8dc5ae5697b9/comments", pathReposCommitsSHAComments, + []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}, {paramSHA, valSHA1}}, true, + }, {"/repos/naoina/denco/comments/1", pathReposComment, []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}, {paramID, "1"}}, true}, {"/repos/naoina/denco/commits", pathReposCommits, []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}}, true}, - {"/repos/naoina/denco/commits/03c3bbc7f0d12268b9ca53d4fbfd8dc5ae5697b9", pathReposCommit, []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}, {paramSHA, valSHA1}}, true}, + {"/repos/naoina/denco/commits/03c3bbc7f0d12268b9ca53d4fbfd8dc5ae5697b9", pathReposCommit, + []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}, {paramSHA, valSHA1}}, true, + }, {"/repos/naoina/denco/readme", pathReposReadme, []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}}, true}, {"/repos/naoina/denco/keys", pathReposKeys, []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}}, true}, {"/repos/naoina/denco/keys/1", pathReposKey, []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}, {paramID, "1"}}, true}, @@ -390,7 +411,9 @@ func TestRouter_Lookup_realURIs(t *testing.T) { {"/repos/naoina/denco/hooks/2", pathReposHook, []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}, {paramID, "2"}}, true}, {"/repos/naoina/denco/releases", pathReposReleases, []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}}, true}, {"/repos/naoina/denco/releases/1", pathReposRelease, []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}, {paramID, "1"}}, true}, - {"/repos/naoina/denco/releases/1/assets", pathReposReleaseAssets, []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}, {paramID, "1"}}, true}, + {"/repos/naoina/denco/releases/1/assets", pathReposReleaseAssets, + []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}, {paramID, "1"}}, true, + }, {"/repos/naoina/denco/stats/contributors", pathReposStatsContributors, []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}}, true}, {"/repos/naoina/denco/stats/commit_activity", pathReposStatsCommitActivity, []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}}, true}, {"/repos/naoina/denco/stats/code_frequency", pathReposStatsCodeFrequency, []denco.Param{{paramOwner, valNaoina}, {paramRepo, valDenco}}, true}, @@ -401,7 +424,9 @@ func TestRouter_Lookup_realURIs(t *testing.T) { {pathSearchCode, pathSearchCode, nil, true}, {pathSearchIssues, pathSearchIssues, nil, true}, {pathSearchUsers, pathSearchUsers, nil, true}, - {"/legacy/issues/search/naoina/denco/closed/test", pathLegacyIssuesSearch, []denco.Param{{paramOwner, valNaoina}, {"repository", valDenco}, {"state", "closed"}, {paramKeyword, valTest}}, true}, + {"/legacy/issues/search/naoina/denco/closed/test", pathLegacyIssuesSearch, + []denco.Param{{paramOwner, valNaoina}, {"repository", valDenco}, {"state", "closed"}, {paramKeyword, valTest}}, true, + }, {"/legacy/repos/search/test", pathLegacyReposSearch, []denco.Param{{paramKeyword, valTest}}, true}, {"/legacy/user/search/test", pathLegacyUserSearch, []denco.Param{{paramKeyword, valTest}}, true}, {"/legacy/user/email/naoina@kuune.org", pathLegacyUserEmail, []denco.Param{{"email", "naoina@kuune.org"}}, true}, diff --git a/middleware/denco/server.go b/middleware/denco/server.go index e6c0976d..5c36b07d 100644 --- a/middleware/denco/server.go +++ b/middleware/denco/server.go @@ -9,7 +9,7 @@ import ( "net/http" ) -// Mux represents a multiplexer for HTTP request. +// Mux represents a multiplexer for HTTP requests. type Mux struct{} // NewMux returns a new [Mux]. @@ -17,27 +17,27 @@ func NewMux() *Mux { return &Mux{} } -// GET is shorthand of [Mux].Handler("GET", path, handler). +// GET is shorthand for [Mux.Handler] ("GET", path, handler). func (m *Mux) GET(path string, handler HandlerFunc) Handler { return m.Handler("GET", path, handler) } -// POST is shorthand of [Mux].Handler("POST", path, handler). +// POST is shorthand for [Mux.Handler] ("POST", path, handler). func (m *Mux) POST(path string, handler HandlerFunc) Handler { return m.Handler("POST", path, handler) } -// PUT is shorthand of [Mux].Handler("PUT", path, handler). +// PUT is shorthand for [Mux.Handler] ("PUT", path, handler). func (m *Mux) PUT(path string, handler HandlerFunc) Handler { return m.Handler("PUT", path, handler) } -// HEAD is shorthand of [Mux].Handler("HEAD", path, handler). +// HEAD is shorthand for [Mux.Handler]("HEAD", path, handler). func (m *Mux) HEAD(path string, handler HandlerFunc) Handler { return m.Handler("HEAD", path, handler) } -// Handler returns a handler for HTTP method. +// Handler returns a [Handler] for a HTTP method. func (m *Mux) Handler(method, path string, handler HandlerFunc) Handler { return Handler{ Method: method, @@ -63,7 +63,7 @@ func (m *Mux) Build(handlers []Handler) (http.Handler, error) { return mux, nil } -// Handler represents a handler of HTTP request. +// Handler represents a handler of HTTP requests. type Handler struct { // Method is an HTTP method. Method string @@ -75,7 +75,7 @@ type Handler struct { Func HandlerFunc } -// HandlerFunc is aliased to type of handler function. +// HandlerFunc is an aliase to the handler function, similar to [http.HandlerFunc]. type HandlerFunc func(w http.ResponseWriter, r *http.Request, params Params) type serveMux struct { @@ -88,7 +88,7 @@ func newServeMux() *serveMux { } } -// ServeHTTP implements http.Handler interface. +// ServeHTTP implements the [http.Handler] interface. func (mux *serveMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { handler, params := mux.handler(r.Method, r.URL.Path) handler(w, r, params) @@ -97,15 +97,17 @@ func (mux *serveMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (mux *serveMux) handler(method, path string) (HandlerFunc, []Param) { if router, found := mux.routers[method]; found { if handler, params, found := router.Lookup(path); found { - return handler.(HandlerFunc), params + return handler.(HandlerFunc), params //nolint:forcetypeassert // type is guaranteed when the path is found } } return NotFound, nil } // NotFound replies to the request with an HTTP 404 not found error. -// NotFound is called when unknown HTTP method or a handler not found. -// If you want to use the your own NotFound handler, please overwrite this variable. +// +// NotFound is called when unknown HTTP methods are being user or a handler not found. +// +// If you want to use your own NotFound handler, please overwrite this variable. var NotFound = func(w http.ResponseWriter, r *http.Request, _ Params) { http.NotFound(w, r) } diff --git a/middleware/parameter.go b/middleware/parameter.go index da37d9ac..d8b8e3ba 100644 --- a/middleware/parameter.go +++ b/middleware/parameter.go @@ -6,7 +6,7 @@ package middleware import ( "encoding" "encoding/base64" - "fmt" + stderrors "errors" "io" "net/http" "reflect" @@ -56,117 +56,139 @@ func (p *untypedParamBinder) Type() reflect.Type { } func (p *untypedParamBinder) Bind(request *http.Request, routeParams RouteParams, consumer runtime.Consumer, target reflect.Value) error { - // fmt.Println("binding", p.name, "as", p.Type()) switch p.parameter.In { case "query": - data, custom, hasKey, err := p.readValue(runtime.Values(request.URL.Query()), target) - if err != nil { - return err - } - if custom { - return nil - } - - return p.bindValue(data, hasKey, target) + return p.bindQuery(request, routeParams, consumer, target) case "header": - data, custom, hasKey, err := p.readValue(runtime.Values(request.Header), target) - if err != nil { - return err - } - if custom { - return nil - } - return p.bindValue(data, hasKey, target) + return p.bindHeader(request, routeParams, consumer, target) case "path": - data, custom, hasKey, err := p.readValue(routeParams, target) - if err != nil { - return err - } - if custom { - return nil - } - return p.bindValue(data, hasKey, target) + return p.bindPath(request, routeParams, consumer, target) case "formData": - mt, _, ctErr := runtime.ContentType(request.Header) - if ctErr != nil { - return errors.InvalidContentType("", []string{runtime.MultipartFormMime, runtime.URLencodedFormMime}) - } + return p.bindFormData(request, routeParams, consumer, target) - if mt != runtime.MultipartFormMime && mt != runtime.URLencodedFormMime { - return errors.InvalidContentType(mt, []string{runtime.MultipartFormMime, runtime.URLencodedFormMime}) - } + case "body": + return p.bindBody(request, routeParams, consumer, target) + default: + return errors.New(http.StatusInternalServerError, "invalid parameter location: %q", p.parameter.In) + } +} - // Parse via the shared helper. The helper routes on Content-Type - // (multipart/form-data → ParseMultipartForm; all non-multipart types, - // including application/x-www-form-urlencoded, → ParseForm) - // and applies the default 32 MiB body cap via http.MaxBytesReader. - // Idempotent across the per-parameter loop: stdlib short-circuits - // when r.MultipartForm / r.PostForm are already populated. - if _, perr := runtime.BindForm(request, runtime.BindFormMaxParseMemory(defaultMaxMemory)); perr != nil { - return perr - } +func (p *untypedParamBinder) bindQuery(request *http.Request, _ RouteParams, _ runtime.Consumer, target reflect.Value) error { + data, custom, hasKey, err := p.readValue(runtime.Values(request.URL.Query()), target) + if err != nil { + return err + } + if custom { + return nil + } + + return p.bindValue(data, hasKey, target) +} + +func (p *untypedParamBinder) bindHeader(request *http.Request, _ RouteParams, _ runtime.Consumer, target reflect.Value) error { + data, custom, hasKey, err := p.readValue(runtime.Values(request.Header), target) + if err != nil { + return err + } + if custom { + return nil + } + return p.bindValue(data, hasKey, target) +} + +func (p *untypedParamBinder) bindPath(_ *http.Request, routeParams RouteParams, _ runtime.Consumer, target reflect.Value) error { + data, custom, hasKey, err := p.readValue(routeParams, target) + if err != nil { + return err + } + if custom { + return nil + } + return p.bindValue(data, hasKey, target) +} + +func (p *untypedParamBinder) bindFormData(request *http.Request, _ RouteParams, _ runtime.Consumer, target reflect.Value) error { + mt, _, ctErr := runtime.ContentType(request.Header) + if ctErr != nil { + return errors.InvalidContentType("", []string{runtime.MultipartFormMime, runtime.URLencodedFormMime}) + } + + if mt != runtime.MultipartFormMime && mt != runtime.URLencodedFormMime { + return errors.InvalidContentType(mt, []string{runtime.MultipartFormMime, runtime.URLencodedFormMime}) + } - if p.parameter.Type == "file" { - file, header, ffErr := request.FormFile(p.parameter.Name) - if ffErr != nil { - if p.parameter.Required { - return errors.NewParseError(p.Name, p.parameter.In, "", ffErr) - } + // Parse via the shared helper. The helper routes on Content-Type + // (multipart/form-data → ParseMultipartForm; all non-multipart types, + // including application/x-www-form-urlencoded, → ParseForm) + // and applies the default 32 MiB body cap via http.MaxBytesReader. + // Idempotent across the per-parameter loop: stdlib short-circuits + // when r.MultipartForm / r.PostForm are already populated. + if _, perr := runtime.BindForm(request, runtime.BindFormMaxParseMemory(defaultMaxMemory)); perr != nil { + return perr + } - return nil + if p.parameter.Type == "file" { + file, header, ffErr := request.FormFile(p.parameter.Name) + if ffErr != nil { + if p.parameter.Required { + return errors.NewParseError(p.Name, p.parameter.In, "", ffErr) } - target.Set(reflect.ValueOf(runtime.File{Data: file, Header: header})) return nil } - if request.MultipartForm != nil { - data, custom, hasKey, rvErr := p.readValue(runtime.Values(request.MultipartForm.Value), target) - if rvErr != nil { - return rvErr - } - if custom { - return nil - } - return p.bindValue(data, hasKey, target) - } - data, custom, hasKey, err := p.readValue(runtime.Values(request.PostForm), target) - if err != nil { - return err + target.Set(reflect.ValueOf(runtime.File{Data: file, Header: header})) + return nil + } + + if request.MultipartForm != nil { + data, custom, hasKey, rvErr := p.readValue(runtime.Values(request.MultipartForm.Value), target) + if rvErr != nil { + return rvErr } if custom { return nil } return p.bindValue(data, hasKey, target) + } + data, custom, hasKey, err := p.readValue(runtime.Values(request.PostForm), target) + if err != nil { + return err + } + if custom { + return nil + } + return p.bindValue(data, hasKey, target) +} - case "body": - newValue := reflect.New(target.Type()) - if !runtime.HasBody(request) { - if p.parameter.Default != nil { - target.Set(reflect.ValueOf(p.parameter.Default)) - } +func (p *untypedParamBinder) bindBody(request *http.Request, _ RouteParams, consumer runtime.Consumer, target reflect.Value) error { + newValue := reflect.New(target.Type()) + if !runtime.HasBody(request) { + if p.parameter.Default != nil { + target.Set(reflect.ValueOf(p.parameter.Default)) + } + + return nil + } + if err := consumer.Consume(request.Body, newValue.Interface()); err != nil { + if stderrors.Is(err, io.EOF) && p.parameter.Default != nil { + target.Set(reflect.ValueOf(p.parameter.Default)) return nil } - if err := consumer.Consume(request.Body, newValue.Interface()); err != nil { - if err == io.EOF && p.parameter.Default != nil { - target.Set(reflect.ValueOf(p.parameter.Default)) - return nil - } - tpe := p.parameter.Type - if p.parameter.Format != "" { - tpe = p.parameter.Format - } - return errors.InvalidType(p.Name, p.parameter.In, tpe, nil) + tpe := p.parameter.Type + if p.parameter.Format != "" { + tpe = p.parameter.Format } - target.Set(reflect.Indirect(newValue)) - return nil - default: - return fmt.Errorf("%d: invalid parameter location %q", http.StatusInternalServerError, p.parameter.In) + return errors.InvalidType(p.Name, p.parameter.In, tpe, nil) } + + target.Set(reflect.Indirect(newValue)) + + return nil } func (p *untypedParamBinder) typeForSchema(tpe, format string, items *spec.Items) reflect.Type { @@ -252,20 +274,51 @@ func (p *untypedParamBinder) bindValue(data []string, hasKey bool, target reflec if p.parameter.Type == typeArray { return p.setSliceFieldValue(target, p.parameter.Default, data, hasKey) } + var d string if len(data) > 0 { d = data[len(data)-1] } + return p.setFieldValue(target, p.parameter.Default, d, hasKey) } -func (p *untypedParamBinder) setFieldValue(target reflect.Value, defaultValue any, data string, hasKey bool) error { //nolint:gocyclo +func (p *untypedParamBinder) isMissingAndRequired(hasKey bool, data string) bool { + return p.parameter.Required && + p.parameter.Default == nil && + (!hasKey || (!p.parameter.AllowEmptyValue && data == "")) +} + +func (p *untypedParamBinder) setByte(target, defVal reflect.Value, tpe, data string) error { + if data == "" { + if target.CanSet() { + target.SetBytes(defVal.Bytes()) + } + + return nil + } + + b, err := base64.StdEncoding.DecodeString(data) + if err != nil { + b, err = base64.URLEncoding.DecodeString(data) + if err != nil { + return errors.InvalidType(p.Name, p.parameter.In, tpe, data) + } + } + if target.CanSet() { + target.SetBytes(b) + } + + return nil +} + +func (p *untypedParamBinder) setFieldValue(target reflect.Value, defaultValue any, data string, hasKey bool) error { tpe := p.parameter.Type if p.parameter.Format != "" { tpe = p.parameter.Format } - if (!hasKey || (!p.parameter.AllowEmptyValue && data == "")) && p.parameter.Required && p.parameter.Default == nil { + if p.isMissingAndRequired(hasKey, data) { return errors.Required(p.Name, p.parameter.In, data) } @@ -283,27 +336,15 @@ func (p *untypedParamBinder) setFieldValue(target reflect.Value, defaultValue an } if tpe == "byte" { - if data == "" { - if target.CanSet() { - target.SetBytes(defVal.Bytes()) - } - return nil - } - - b, err := base64.StdEncoding.DecodeString(data) - if err != nil { - b, err = base64.URLEncoding.DecodeString(data) - if err != nil { - return errors.InvalidType(p.Name, p.parameter.In, tpe, data) - } - } - if target.CanSet() { - target.SetBytes(b) - } - return nil + return p.setByte(target, defVal, tpe, data) } - switch target.Kind() { //nolint:exhaustive // we want to check only types that map from a swagger parameter + return p.setReflectFieldValue(target, defVal, tpe, data, hasKey) +} + +//nolint:gocyclo,cyclop // not much we can simplify further significantly: the big case with all types is unavoidable. +func (p *untypedParamBinder) setReflectFieldValue(target, defVal reflect.Value, tpe, data string, hasKey bool) error { + switch target.Kind() { // we want to check only types that map from a swagger parameter case reflect.Bool: if data == "" { if target.CanSet() { @@ -318,6 +359,7 @@ func (p *untypedParamBinder) setFieldValue(target reflect.Value, defaultValue an if target.CanSet() { target.SetBool(b) } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: if data == "" { if target.CanSet() { @@ -403,6 +445,7 @@ func (p *untypedParamBinder) setFieldValue(target reflect.Value, defaultValue an default: return errors.InvalidType(p.Name, p.parameter.In, tpe, data) } + return nil } @@ -410,20 +453,30 @@ func (p *untypedParamBinder) tryUnmarshaler(target reflect.Value, defaultValue a if !target.CanSet() { return false, nil } + // When a type implements encoding.TextUnmarshaler we'll use that instead of reflecting some more - if reflect.PointerTo(target.Type()).Implements(textUnmarshalType) { - if defaultValue != nil && len(data) == 0 { - target.Set(reflect.ValueOf(defaultValue)) - return true, nil - } - value := reflect.New(target.Type()) - if err := value.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(data)); err != nil { - return true, err - } - target.Set(reflect.Indirect(value)) + ttyp := target.Type() + if !reflect.PointerTo(ttyp).Implements(textUnmarshalType) { + return false, nil + } + + if defaultValue != nil && len(data) == 0 { + target.Set(reflect.ValueOf(defaultValue)) return true, nil } - return false, nil + + value := reflect.New(ttyp) + if !value.CanInterface() { + return false, nil + } + + if err := value.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(data)); err != nil { //nolint:forcetypeassert // this is guaranteed by the reflect check above + return true, err + } + + target.Set(reflect.Indirect(value)) + + return true, nil } func (p *untypedParamBinder) readFormattedSliceFieldValue(data string, target reflect.Value) ([]string, bool, error) { diff --git a/middleware/parameter_test.go b/middleware/parameter_test.go index 3684193e..1f181474 100644 --- a/middleware/parameter_test.go +++ b/middleware/parameter_test.go @@ -39,6 +39,8 @@ func init() { } func testCollectionFormat(t *testing.T, param *spec.Parameter, valid bool) { + t.Helper() + binder := &untypedParamBinder{ parameter: param, } diff --git a/middleware/request.go b/middleware/request.go index ad781663..08a0362d 100644 --- a/middleware/request.go +++ b/middleware/request.go @@ -40,8 +40,25 @@ func NewUntypedRequestBinder(parameters map[string]spec.Parameter, spec *spec.Sw // Bind perform the databinding and validation. func (o *UntypedRequestBinder) Bind(request *http.Request, routeParams RouteParams, consumer runtime.Consumer, data any) error { + err := o.bind(request, routeParams, consumer, data) + if err == nil { + return nil // avoids returning a nil-interface + } + + return err +} + +// SetLogger allows for injecting a logger to catch debug entries. +// +// The logger is enabled in DEBUG mode only. +func (o *UntypedRequestBinder) SetLogger(lg logger.Logger) { + o.debugLogf = debugLogfFunc(lg) +} + +func (o *UntypedRequestBinder) bind(request *http.Request, routeParams RouteParams, consumer runtime.Consumer, data any) *errors.CompositeError { val := reflect.Indirect(reflect.ValueOf(data)) isMap := val.Kind() == reflect.Map + var result []error o.debugLogf("binding %d parameters for %s %s", len(o.Parameters), request.Method, request.URL.EscapedPath()) for fieldName, param := range o.Parameters { @@ -94,13 +111,6 @@ func (o *UntypedRequestBinder) Bind(request *http.Request, routeParams RoutePara return nil } -// SetLogger allows for injecting a logger to catch debug entries. -// -// The logger is enabled in DEBUG mode only. -func (o *UntypedRequestBinder) SetLogger(lg logger.Logger) { - o.debugLogf = debugLogfFunc(lg) -} - func (o *UntypedRequestBinder) setDebugLogf(fn func(string, ...any)) { o.debugLogf = fn } diff --git a/middleware/request_test.go b/middleware/request_test.go index 8819a747..a0323724 100644 --- a/middleware/request_test.go +++ b/middleware/request_test.go @@ -210,7 +210,9 @@ func TestRequestBindingDefaultValue(t *testing.T) { assert.Equal(t, age, data[paramKeyAge]) assert.InDelta(t, factor, data["factor"], 1e-6) assert.InDelta(t, score, data["score"], 1e-6) - assert.EqualT(t, "hello", string(data["picture"].(strfmt.Base64))) + formatted, ok := data["picture"].(strfmt.Base64) + require.TrueT(t, ok) + assert.EqualT(t, "hello", string(formatted)) } func TestRequestBindingForInvalid(t *testing.T) { diff --git a/middleware/router.go b/middleware/router.go index d375fd77..939cf733 100644 --- a/middleware/router.go +++ b/middleware/router.go @@ -349,49 +349,62 @@ func (m *MatchedRoute) NeedsAuth() bool { func (d *defaultRouter) Lookup(method, path string) (*MatchedRoute, bool) { mth := strings.ToUpper(method) d.debugLogf("looking up route for %s %s", method, path) - if Debug { - if len(d.routers) == 0 { + if len(d.routers) == 0 { + if Debug { d.debugLogf("there are no known routers") } + panic("internal error: no router is configured") + } + + if Debug { for meth := range d.routers { d.debugLogf("got a router for %s", meth) } } - if router, ok := d.routers[mth]; ok { - if m, rp, ok := router.Lookup(fpath.Clean(escapeLiteralColons(path))); ok && m != nil { - if entry, ok := m.(*routeEntry); ok { - d.debugLogf("found a route for %s %s with %d parameters", method, path, len(entry.Parameters)) - var params RouteParams - for _, p := range rp { - v, err := url.PathUnescape(p.Value) - if err != nil { - d.debugLogf("failed to escape %q: %v", p.Value, err) - v = p.Value - } - // a workaround to handle fragment/composing parameters until they are supported in denco router - // check if this parameter is a fragment within a path segment - const enclosureSize = 2 - if xpos := strings.Index(entry.PathPattern, fmt.Sprintf("{%s}", p.Name)) + len(p.Name) + enclosureSize; xpos < len(entry.PathPattern) && entry.PathPattern[xpos] != '/' { - // extract fragment parameters - ep := strings.Split(entry.PathPattern[xpos:], "/")[0] - pnames, pvalues := decodeCompositParams(p.Name, v, ep, nil, nil) - for i, pname := range pnames { - params = append(params, RouteParam{Name: pname, Value: pvalues[i]}) - } - } else { - // use the parameter directly - params = append(params, RouteParam{Name: p.Name, Value: v}) - } - } - return &MatchedRoute{routeEntry: *entry, Params: params}, true + + router, ok := d.routers[mth] + if !ok { + d.debugLogf("couldn't find a route by method for %s %s", method, path) + return nil, false + } + + m, rp, ok := router.Lookup(fpath.Clean(escapeLiteralColons(path))) + if !ok || m == nil { + d.debugLogf("couldn't find a route by path for %s %s", method, path) + return nil, false + } + + entry, ok := m.(*routeEntry) + if !ok { + return nil, false + } + + d.debugLogf("found a route for %s %s with %d parameters", method, path, len(entry.Parameters)) + var params RouteParams + for _, p := range rp { + v, err := url.PathUnescape(p.Value) + if err != nil { + d.debugLogf("failed to escape %q: %v", p.Value, err) + v = p.Value + } + + // a workaround to handle fragment/composing parameters until they are supported in denco router + // check if this parameter is a fragment within a path segment + const enclosureSize = 2 + if xpos := strings.Index(entry.PathPattern, fmt.Sprintf("{%s}", p.Name)) + len(p.Name) + enclosureSize; xpos < len(entry.PathPattern) && entry.PathPattern[xpos] != '/' { + // extract fragment parameters + ep := strings.Split(entry.PathPattern[xpos:], "/")[0] + pnames, pvalues := decodeCompositParams(p.Name, v, ep, nil, nil) + for i, pname := range pnames { + params = append(params, RouteParam{Name: pname, Value: pvalues[i]}) } } else { - d.debugLogf("couldn't find a route by path for %s %s", method, path) + // use the parameter directly + params = append(params, RouteParam{Name: p.Name, Value: v}) } - } else { - d.debugLogf("couldn't find a route by method for %s %s", method, path) } - return nil, false + + return &MatchedRoute{routeEntry: *entry, Params: params}, true } func (d *defaultRouter) OtherMethods(method, path string) []string { @@ -426,25 +439,28 @@ func escapeLiteralColons(path string) string { func decodeCompositParams(name string, value string, pattern string, names []string, values []string) ([]string, []string) { pleft := strings.Index(pattern, "{") names = append(names, name) + if pleft < 0 { if strings.HasSuffix(value, pattern) { values = append(values, value[:len(value)-len(pattern)]) } else { values = append(values, "") } + + return names, values + } + + toskip := pattern[:pleft] + pright := strings.Index(pattern, "}") + vright := strings.Index(value, toskip) + if vright >= 0 { + values = append(values, value[:vright]) } else { - toskip := pattern[:pleft] - pright := strings.Index(pattern, "}") - vright := strings.Index(value, toskip) - if vright >= 0 { - values = append(values, value[:vright]) - } else { - values = append(values, "") - value = "" - } - return decodeCompositParams(pattern[pleft+1:pright], value[vright+len(toskip):], pattern[pright+1:], names, values) + values = append(values, "") + value = "" } - return names, values + + return decodeCompositParams(pattern[pleft+1:pright], value[vright+len(toskip):], pattern[pright+1:], names, values) } func (d *defaultRouteBuilder) AddRoute(method, path string, operation *spec.Operation) { diff --git a/middleware/router_test.go b/middleware/router_test.go index 0beb934c..25aad2ed 100644 --- a/middleware/router_test.go +++ b/middleware/router_test.go @@ -118,7 +118,8 @@ func TestRouterBuilder(t *testing.T) { rec := postRecords[0] assert.EqualT(t, "/pets", rec.Key) - val := rec.Value.(*routeEntry) + val, ok := rec.Value.(*routeEntry) + require.TrueT(t, ok) assert.Len(t, val.Consumers, 2) assert.Len(t, val.Producers, 2) assert.Len(t, val.Consumes, 2) @@ -133,7 +134,8 @@ func TestRouterBuilder(t *testing.T) { recG := getRecords[0] assert.EqualT(t, "/pets", recG.Key) - valG := recG.Value.(*routeEntry) + valG, ok := recG.Value.(*routeEntry) + require.TrueT(t, ok) assert.Len(t, valG.Consumers, 2) assert.Len(t, valG.Producers, 4) assert.Len(t, valG.Consumes, 2) diff --git a/middleware/untyped_request_test.go b/middleware/untyped_request_test.go index 1f95bf8e..f67119bc 100644 --- a/middleware/untyped_request_test.go +++ b/middleware/untyped_request_test.go @@ -65,7 +65,8 @@ func TestUntypedFileUpload(t *testing.T) { assert.Equal(t, "the-name", data[paramKeyName]) assert.NotNil(t, data["file"]) assert.IsType(t, runtime.File{}, data["file"]) - file := data["file"].(runtime.File) + file, ok := data["file"].(runtime.File) + require.TrueT(t, ok) require.NotNil(t, file.Header) assert.EqualT(t, "plain-jane.txt", file.Header.Filename) @@ -156,7 +157,8 @@ func TestUntypedOptionalFileUpload(t *testing.T) { assert.Equal(t, "the-name", data[paramKeyName]) assert.NotNil(t, data["file"]) assert.IsType(t, runtime.File{}, data["file"]) - file := data["file"].(runtime.File) + file, ok := data["file"].(runtime.File) + require.TrueT(t, ok) assert.NotNil(t, file.Header) assert.EqualT(t, "plain-jane.txt", file.Header.Filename) @@ -211,5 +213,7 @@ func TestUntypedBindingTypesForValid(t *testing.T) { assert.InDelta(t, score, data["score"], 1e-6) pb, err := base64.URLEncoding.DecodeString(picture) require.NoError(t, err) - assert.EqualValues(t, pb, data["picture"].(strfmt.Base64)) + formatted, ok := data["picture"].(strfmt.Base64) + require.TrueT(t, ok) + assert.EqualValues(t, pb, formatted) } diff --git a/middleware/validation.go b/middleware/validation.go index c583e191..63a78d48 100644 --- a/middleware/validation.go +++ b/middleware/validation.go @@ -4,6 +4,7 @@ package middleware import ( + stderrors "errors" "net/http" "strings" @@ -73,41 +74,48 @@ func (v *validation) debugLogf(format string, args ...any) { func (v *validation) parameters() { v.debugLogf("validating request parameters for %s %s", v.request.Method, v.request.URL.EscapedPath()) - if result := v.route.Binder.Bind(v.request, v.route.Params, v.route.Consumer, v.bound); result != nil { - if result.Error() == "validation failure list" { - for _, e := range result.(*errors.Validation).Value.([]any) { - v.result = append(v.result, e.(error)) - } - return + result := v.route.Binder.bind(v.request, v.route.Params, v.route.Consumer, v.bound) + if result == nil { + return + } + + for _, e := range result.Errors { + var validationErr *errors.Validation + if stderrors.As(e, &validationErr) { + v.result = append(v.result, validationErr) } - v.result = append(v.result, result) } } func (v *validation) contentType() { - if len(v.result) == 0 && runtime.HasBody(v.request) { - v.debugLogf("validating body content type for %s %s", v.request.Method, v.request.URL.EscapedPath()) - ct, _, req, err := v.context.ContentType(v.request) - if err != nil { + if len(v.result) > 0 || !runtime.HasBody(v.request) { + return + } + + v.debugLogf("validating body content type for %s %s", v.request.Method, v.request.URL.EscapedPath()) + ct, _, req, err := v.context.ContentType(v.request) + if err != nil { + v.result = append(v.result, err) + } else { + v.request = req + } + + if len(v.result) == 0 { + v.debugLogf("validating content type for %q against [%s]", ct, strings.Join(v.route.Consumes, ", ")) + if err := validateContentType(v.route.Consumes, ct, v.context.matchOpts()...); err != nil { v.result = append(v.result, err) - } else { - v.request = req } + } - if len(v.result) == 0 { - v.debugLogf("validating content type for %q against [%s]", ct, strings.Join(v.route.Consumes, ", ")) - if err := validateContentType(v.route.Consumes, ct, v.context.matchOpts()...); err != nil { - v.result = append(v.result, err) - } - } - if ct != "" && v.route.Consumer == nil { - cons, ok := mediatype.Lookup(v.route.Consumers, ct, v.context.matchOpts()...) - if !ok { - v.result = append(v.result, errors.New(http.StatusInternalServerError, "no consumer registered for %s", ct)) - } else { - v.route.Consumer = cons - } - } + if ct == "" || v.route.Consumer != nil { + return + } + + cons, ok := mediatype.Lookup(v.route.Consumers, ct, v.context.matchOpts()...) + if !ok { + v.result = append(v.result, errors.New(http.StatusInternalServerError, "no consumer registered for %s", ct)) + } else { + v.route.Consumer = cons } } diff --git a/security/authenticator.go b/security/authenticator.go index 2430997b..d84eb370 100644 --- a/security/authenticator.go +++ b/security/authenticator.go @@ -19,8 +19,8 @@ const ( accessTokenParam = "access_token" ) -// HttpAuthenticator is a function that authenticates a HTTP request. -func HttpAuthenticator(handler func(*http.Request) (bool, any, error)) runtime.Authenticator { //nolint:revive +// HTTPAuthenticator is a function that authenticates a HTTP request. +func HTTPAuthenticator(handler func(*http.Request) (bool, any, error)) runtime.Authenticator { return runtime.AuthenticatorFunc(func(params any) (bool, any, error) { if request, ok := params.(*http.Request); ok { return handler(request) @@ -32,7 +32,14 @@ func HttpAuthenticator(handler func(*http.Request) (bool, any, error)) runtime.A }) } -// ScopedAuthenticator is a function that authenticates a HTTP request against a list of valid scopes. +// HttpAuthenticator aliases [HTTPAuthenticator] for backward-compatibility. +// +// Deprecated: use HTTPAuthenticator instead. +func HttpAuthenticator(handler func(*http.Request) (bool, any, error)) runtime.Authenticator { //nolint:revive + return HTTPAuthenticator(handler) +} + +// ScopedAuthenticator is a function that authenticates an [http.Request] against a list of valid scopes. func ScopedAuthenticator(handler func(*ScopedAuthRequest) (bool, any, error)) runtime.Authenticator { return runtime.AuthenticatorFunc(func(params any) (bool, any, error) { if request, ok := params.(*ScopedAuthRequest); ok { @@ -219,7 +226,7 @@ func APIKeyAuthCtx(name, in string, authenticate TokenAuthenticationCtx) runtime }) } -// ScopedAuthRequest contains both a [http] request and the required scopes for a particular operation. +// ScopedAuthRequest contains both the [http.Request] and the required scopes for a particular operation. type ScopedAuthRequest struct { Request *http.Request RequiredScopes []string diff --git a/security/authorizer_test.go b/security/authorizer_test.go index bef9512e..62cd82d9 100644 --- a/security/authorizer_test.go +++ b/security/authorizer_test.go @@ -23,8 +23,8 @@ func TestAuthenticator(t *testing.T) { r, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/", nil) require.NoError(t, err) - t.Run("with HttpAuthenticator", func(t *testing.T) { - auth := HttpAuthenticator(func(_ *http.Request) (bool, any, error) { return true, "test", nil }) + t.Run("with HTTPAuthenticator", func(t *testing.T) { + auth := HTTPAuthenticator(func(_ *http.Request) (bool, any, error) { return true, "test", nil }) t.Run("authenticator should work on *http.Request", func(t *testing.T) { isAuth, user, err := auth.Authenticate(r) diff --git a/security/bearer_auth_test.go b/security/bearer_auth_test.go index 498fbd83..3082d350 100644 --- a/security/bearer_auth_test.go +++ b/security/bearer_auth_test.go @@ -169,7 +169,10 @@ func TestBearerAuthCtx(t *testing.T) { }) } -func testIsAuthorized(_ context.Context, req *http.Request, authorizer runtime.Authenticator, expectAuthorized authExpectation, extraAsserters ...func(context.Context, *testing.T)) func(*testing.T) { +func testIsAuthorized(_ context.Context, + req *http.Request, authorizer runtime.Authenticator, + expectAuthorized authExpectation, extraAsserters ...func(context.Context, *testing.T), +) func(*testing.T) { return func(t *testing.T) { //nolint:contextcheck hasToken, usr, err := authorizer.Authenticate(&ScopedAuthRequest{Request: req}) switch expectAuthorized { @@ -192,6 +195,9 @@ func testIsAuthorized(_ context.Context, req *http.Request, authorizer runtime.A assert.FalseT(t, hasToken) assert.Nil(t, usr) assert.Empty(t, OAuth2SchemeName(req)) + + default: + t.FailNow() } for _, contextAsserter := range extraAsserters { diff --git a/server-middleware/mediatype/mediatype.go b/server-middleware/mediatype/mediatype.go index 2138b826..41a32a16 100644 --- a/server-middleware/mediatype/mediatype.go +++ b/server-middleware/mediatype/mediatype.go @@ -197,18 +197,14 @@ func Parse(s string) (MediaType, error) { if plus := strings.LastIndexByte(mt.Subtype, '+'); plus >= 0 && plus < len(mt.Subtype)-1 { mt.Suffix = mt.Subtype[plus+1:] } + if q, ok := params["q"]; ok { - if qf, perr := strconv.ParseFloat(q, 64); perr == nil { - if qf < 0 { - qf = 0 - } - if qf > 1 { - qf = 1 - } + if qf, isFloat := boundedQ(q); isFloat { mt.Q = qf } delete(params, "q") } + if len(params) > 0 { mt.Params = params } @@ -267,6 +263,23 @@ func (m MediaType) Specificity() int { return SpecificityExactWithParams } +func boundedQ(q string) (float64, bool) { + qf, err := strconv.ParseFloat(q, 64) + if err != nil { + return 0, false + } + + if qf < 0 { + qf = 0 + } + + if qf > 1 { + qf = 1 + } + + return qf, true +} + // typeAgrees reports whether two top-level types match, allowing "*" on // either side. A type of "*" without a "*" subtype is rejected per RFC // 7231 §5.3.2 ("*/sub" is not valid), but Parse never produces such a diff --git a/server-middleware/mediatype/mediatype_test.go b/server-middleware/mediatype/mediatype_test.go index d6f54cbc..7614a6f0 100644 --- a/server-middleware/mediatype/mediatype_test.go +++ b/server-middleware/mediatype/mediatype_test.go @@ -292,7 +292,10 @@ func TestBestMatch(t *testing.T) { {"image/* matches gif first", pngWild, []string{imageGIF, imageJPG}, imageGIF}, {"image/png beats image/* on specificity (2)", pngWild, []string{imageGIF, imagePNG}, imagePNG}, {"image/png beats image/* (offer order doesn't override)", pngWild, []string{imagePNG, imageGIF}, imagePNG}, - {"vendor params don't break match", "application/vnd.google.protobuf;proto=io.prometheus.client.MetricFamily;encoding=delimited;q=0.7,text/plain;version=0.0.4;q=0.3", []string{textPlain}, textPlain}, + {"vendor params don't break match", + "application/vnd.google.protobuf;proto=io.prometheus.client.MetricFamily;encoding=delimited;q=0.7,text/plain;version=0.0.4;q=0.3", + []string{textPlain}, textPlain, + }, // vendor MIME types are NOT structurally matched against // "+json" — text/json doesn't match application/vnd.cia.v1+json. {"vendor MIME unmatched", jsonMime, []string{"application/vnd.cia.v1+json"}, ""}, diff --git a/text.go b/text.go index 24e7eaf5..3764a87f 100644 --- a/text.go +++ b/text.go @@ -36,7 +36,7 @@ func TextConsumer() Consumer { if tu, ok := data.(encoding.TextUnmarshaler); ok { err := tu.UnmarshalText(b) if err != nil { - return fmt.Errorf("text consumer: %v", err) + return fmt.Errorf("text consumer: %w", err) } return nil @@ -70,7 +70,7 @@ func TextProducer() Producer { if tm, ok := data.(encoding.TextMarshaler); ok { txt, err := tm.MarshalText() if err != nil { - return fmt.Errorf("text producer: %v", err) + return fmt.Errorf("text producer: %w", err) } _, err = writer.Write(txt) return err diff --git a/yamlpc/yaml.go b/yamlpc/yaml.go index ca71edbb..b7fab889 100644 --- a/yamlpc/yaml.go +++ b/yamlpc/yaml.go @@ -6,8 +6,9 @@ package yamlpc import ( "io" - "github.com/go-openapi/runtime" yaml "go.yaml.in/yaml/v3" + + "github.com/go-openapi/runtime" ) // YAMLConsumer creates a consumer for [yaml] data.