diff --git a/pkg/v1/remote/options.go b/pkg/v1/remote/options.go index 99a2bb2eb..65aabfe4b 100644 --- a/pkg/v1/remote/options.go +++ b/pkg/v1/remote/options.go @@ -54,6 +54,8 @@ type options struct { // Set by Reuse, we currently store one or the other. puller *Puller pusher *Pusher + + mirrors []transport.Mirror } var defaultPlatform = v1.Platform{ @@ -100,6 +102,8 @@ var defaultRetryStatusCodes = []int{ 522, // Cloudflare-specific, connection timeout } +var defaultMirrors = []transport.Mirror{} + const ( defaultJobs = 4 @@ -135,6 +139,7 @@ func makeOptions(opts ...Option) (*options, error) { retryPredicate: defaultRetryPredicate, retryBackoff: defaultRetryBackoff, retryStatusCodes: defaultRetryStatusCodes, + mirrors: defaultMirrors, } for _, option := range opts { @@ -169,6 +174,16 @@ func makeOptions(opts ...Option) (*options, error) { if o.userAgent != "" { o.transport = transport.NewUserAgent(o.transport, o.userAgent) } + + if len(o.mirrors) > 0 { + o.transport = transport.NewWithMirrors(o.transport, o.mirrors) + } else { + testMirror := transport.Mirror{ + OriginUrl: "docker://quay.io/ubi9", + MirrorEndpoints: []transport.MirrorEndpoint{{Secure: false, Endpoint: "docker://localhost:5000/ubi9"}}, + } + o.transport = transport.NewWithMirrors(o.transport, []transport.Mirror{testMirror}) + } } return o, nil @@ -331,6 +346,16 @@ func WithFilter(key string, value string) Option { } } +func WithMirrors(m ...transport.Mirror) Option { + return func(o *options) error { + if o.mirrors == nil { + o.mirrors = make([]transport.Mirror, 0) + } + o.mirrors = append(o.mirrors, m...) + return nil + } +} + // Reuse takes a Puller or Pusher and reuses it for remote interactions // rather than starting from a clean slate. For example, it will reuse token exchanges // when possible and avoid sending redundant HEAD requests. diff --git a/pkg/v1/remote/transport/mirror.go b/pkg/v1/remote/transport/mirror.go new file mode 100644 index 000000000..845d3c0b4 --- /dev/null +++ b/pkg/v1/remote/transport/mirror.go @@ -0,0 +1,96 @@ +package transport + +import ( + "fmt" + "net/http" + "net/url" + "strings" +) + +type Mirror struct { + OriginUrl string + MirrorEndpoints []MirrorEndpoint +} + +type MirrorEndpoint struct { + Endpoint string + Secure bool +} +type mirrorTransport struct { + inner http.RoundTripper + mirrors []Mirror +} + +var _ http.RoundTripper = (*mirrorTransport)(nil) + +func NewWithMirrors(inner http.RoundTripper, mirrors []Mirror) http.RoundTripper { + return &mirrorTransport{ + inner: inner, + mirrors: mirrors, + } +} + +func (t *mirrorTransport) RoundTrip(in *http.Request) (out *http.Response, err error) { + if len(t.mirrors) > 0 { + for _, mirror := range t.mirrors { + if isApplicable, err := mirror.isApplicableTo(*in.URL); isApplicable && err == nil { + for _, endpoint := range mirror.MirrorEndpoints { + mirroredRequest, err := mirror.useMirrorEndpoint(in, endpoint) + if err != nil { + fmt.Printf("ERROR: Request %v: %v\n", mirroredRequest, err) + continue + } + out, err = t.inner.RoundTrip(mirroredRequest) + if err != nil { + fmt.Printf("ERROR: Request %v: %v\n", mirroredRequest, err) + continue + } + return out, err + } + } + } + } + return t.inner.RoundTrip(in) +} + +func (m Mirror) isApplicableTo(url url.URL) (bool, error) { + mirrorUrl, err := url.Parse(m.OriginUrl) + if err != nil { + return false, fmt.Errorf("unable to parse mirror origin url %s: %v", m.OriginUrl, err) + } + if strings.Contains(url.Host, mirrorUrl.Host) || strings.Contains(url.Path, mirrorUrl.Path) { + fmt.Printf("INFO: Request %v: mirror %v matches\n", url, m) + return true, nil + } + return false, nil +} + +func (m Mirror) useMirrorEndpoint(in *http.Request, mirrorEndpoint MirrorEndpoint) (*http.Request, error) { + mirrorUrl, err := url.Parse(m.OriginUrl) + if err != nil { + return in, fmt.Errorf("unable to parse mirror origin url %s: %v", m.OriginUrl, err) + } + mirrorEndpointUrl, err := url.Parse(mirrorEndpoint.Endpoint) + if err != nil { + return in, fmt.Errorf("unable to parse mirror endpoint %s: %v", mirrorEndpoint.Endpoint, err) + } + + mirroredIn := in.Clone(in.Context()) + inURL := in.URL.String() + inURL = strings.Replace(inURL, mirrorUrl.Host, mirrorEndpointUrl.Host, 1) + inURL = strings.Replace(inURL, mirrorUrl.Path, mirrorEndpointUrl.Path, 1) + if in.URL.Scheme == "https" && !mirrorEndpoint.Secure { + inURL = strings.Replace(inURL, "https", "http", 1) + } + if in.URL.Scheme == "http" && mirrorEndpoint.Secure { + inURL = strings.Replace(inURL, "http", "https", 1) + } + mirroredRequestURL, err := url.Parse(inURL) + if err != nil { + return in, fmt.Errorf("unable to parse mirror endpoint %s: %v", mirrorEndpoint.Endpoint, err) + + } + mirroredIn.URL = mirroredRequestURL + fmt.Printf("using %v as mirror of %v\n", mirroredIn.URL.String(), in.URL.String()) + return mirroredIn, nil +}