storj/private/apigen/gogen.go

292 lines
7.7 KiB
Go
Raw Normal View History

// Copyright (C) 2022 Storj Labs, Inc.
// See LICENSE for copying information.
package apigen
import (
	"fmt"
	"go/format"
	"net/http"
	"os"
	"reflect"
	"sort"
	"strings"
	"time"

	"github.com/zeebo/errs"
	"golang.org/x/text/cases"
	"golang.org/x/text/language"

	"storj.io/common/uuid"
)
// MustWriteGo renders the API as Go source and stores it at path with 0644
// permissions. Any generation or write failure is turned into a panic.
func (a *API) MustWriteGo(path string) {
	generated, err := a.generateGo()
	if err != nil {
		panic(errs.Wrap(err))
	}

	if err := os.WriteFile(path, generated, 0644); err != nil {
		panic(errs.Wrap(err))
	}
}
// generateGo generates the Go source code of the API handlers and returns it
// formatted by go/format.
//
// For every endpoint group it emits an error class, a service interface, a
// Handler type with a constructor that registers the routes on a mux
// subrouter, and one HTTP handler per endpoint. Endpoints are stored in a map,
// so interface methods, route registrations and handlers are sorted before
// being written: the output is identical between runs and routes are
// registered in a stable order.
func (a *API) generateGo() ([]byte, error) {
	var buf []byte
	// p appends one formatted line of generated code to buf.
	p := func(layout string, args ...interface{}) {
		buf = append(buf, fmt.Sprintf(layout+"\n", args...)...)
	}

	// serveBadRequest emits the body of an already-opened error branch that
	// answers 400 with errExpr, then closes the branch and adds a blank line.
	serveBadRequest := func(errExpr string) {
		p("api.ServeError(h.log, w, http.StatusBadRequest, %s)", errExpr)
		p("return")
		p("}")
		p("")
	}

	// decodeBody emits code that JSON-decodes the request body into a new
	// value of typ named name.
	decodeBody := func(name string, typ reflect.Type) {
		p("%s := &%s{}", name, typ)
		p("if err = json.NewDecoder(r.Body).Decode(&%s); err != nil {", name)
		serveBadRequest("err")
	}

	getPackageName := func(path string) string {
		pathPackages := strings.Split(path, "/")
		return pathPackages[len(pathPackages)-1]
	}

	uuidType := reflect.TypeOf(uuid.UUID{})
	timeType := reflect.TypeOf(time.Time{})
	stringType := reflect.TypeOf("")

	// Imports every generated file needs, one slice per import group.
	importGroups := [][]string{
		{"context", "encoding/json", "net/http", "strconv", "time"},
		{"github.com/gorilla/mux", "github.com/zeebo/errs", "go.uber.org/zap"},
		{"storj.io/common/uuid", "storj.io/storj/private/api"},
	}
	imported := make(map[string]bool)
	for _, group := range importGroups {
		for _, path := range group {
			imported[path] = true
		}
	}

	// Collect, once per path, the packages of request/response types declared
	// outside the generated package. Request and Response are checked
	// independently, so a same-package Request no longer hides the Response.
	var typeImports []string
	addTypeImport := func(value interface{}) {
		if value == nil {
			return
		}
		path := reflect.TypeOf(value).Elem().PkgPath()
		// Predeclared element types (e.g. the string of []string) have no
		// package path and must not produce an empty import.
		if path == "" || imported[path] || getPackageName(path) == a.PackageName {
			return
		}
		imported[path] = true
		typeImports = append(typeImports, path)
	}
	for _, group := range a.EndpointGroups {
		for _, endpoint := range group.Endpoints {
			addTypeImport(endpoint.Request)
			addTypeImport(endpoint.Response)
		}
	}
	sort.Strings(typeImports)
	// Type imports are appended to the last import group.
	last := len(importGroups) - 1
	importGroups[last] = append(importGroups[last], typeImports...)

	p("// AUTOGENERATED BY private/apigen")
	p("// DO NOT EDIT.")
	p("")
	p("package %s", a.PackageName)
	p("")
	p("import (")
	for i, group := range importGroups {
		if i > 0 {
			p("")
		}
		for _, path := range group {
			p(`"%s"`, path)
		}
	}
	// The import block is closed exactly once, whatever the number of groups.
	p(")")
	p("")
	p("const dateLayout = \"2006-01-02T15:04:05.000Z\"")
	p("")

	for _, group := range a.EndpointGroups {
		titlePrefix := cases.Title(language.Und).String(group.Prefix)

		p("var Err%sAPI = errs.Class(\"%s %s api\")", titlePrefix, a.PackageName, group.Prefix)
		p("")

		// Service interface methods, sorted by method name.
		methods := make([]string, 0, len(group.Endpoints))
		for _, e := range group.Endpoints {
			var params string
			for _, param := range e.Params {
				params += param.Type.String() + ", "
			}
			responseType := a.handleTypesPackage(reflect.TypeOf(e.Response))
			methods = append(methods, fmt.Sprintf("%s(context.Context, %s) (%s, api.HTTPError)", e.MethodName, params, responseType))
		}
		sort.Strings(methods)

		p("type %sService interface {", group.Name)
		for _, method := range methods {
			p("%s", method)
		}
		p("}")
		p("")
		p("// Handler is an api handler that exposes all %s related functionality.", group.Prefix)
		p("type Handler struct {")
		p("log *zap.Logger")
		p("service %sService", group.Name)
		p("auth api.Auth")
		p("}")
		p("")
		p("func New%s(log *zap.Logger, service %sService, router *mux.Router, auth api.Auth) *Handler {", group.Name, group.Name)
		p("handler := &Handler{")
		p("log: log,")
		p("service: service,")
		p("auth: auth,")
		p("}")
		p("")
		p("%sRouter := router.PathPrefix(\"/api/v0/%s\").Subrouter()", group.Prefix, group.Prefix)

		// Route registrations in a stable (sorted) order: mux matches routes
		// in registration order, so a random order would make matching
		// nondeterministic.
		routes := make([]string, 0, len(group.Endpoints))
		for pathMethod, endpoint := range group.Endpoints {
			routes = append(routes, fmt.Sprintf("%sRouter.HandleFunc(\"%s\", handler.%s).Methods(\"%s\")",
				group.Prefix, pathMethod.Path, "handle"+endpoint.MethodName, pathMethod.Method))
		}
		sort.Strings(routes)
		for _, route := range routes {
			p("%s", route)
		}
		p("")
		p("return handler")
		p("}")

		// Each handler is generated at the end of buf, cut out into its own
		// string, and the collected handlers are written back sorted.
		handlers := make([]string, 0, len(group.Endpoints))
		for pathMethod, endpoint := range group.Endpoints {
			start := len(buf)

			p("")
			p("func (h *Handler) %s(w http.ResponseWriter, r *http.Request) {", "handle"+endpoint.MethodName)
			p("ctx := r.Context()")
			p("var err error")
			p("defer mon.Task()(&ctx)(&err)")
			p("")
			p("w.Header().Set(\"Content-Type\", \"application/json\")")
			p("")

			// Endpoints that opt out of both cookie and API key auth skip the check.
			if !endpoint.NoCookieAuth || !endpoint.NoAPIAuth {
				p("ctx, err = h.auth.IsAuthenticated(ctx, r, %t, %t)", !endpoint.NoCookieAuth, !endpoint.NoAPIAuth)
				p("if err != nil {")
				p("api.ServeError(h.log, w, http.StatusUnauthorized, err)")
				p("return")
				p("}")
				p("")
			}

			// Emit parameter decoding and build the service call's argument list.
			var args string
			switch pathMethod.Method {
			case http.MethodGet:
				for _, param := range endpoint.Params {
					switch param.Type {
					case uuidType:
						p("%s, err := uuid.FromString(r.URL.Query().Get(\"%s\"))", param.Name, param.Name)
						p("if err != nil {")
						serveBadRequest("err")
					case timeType:
						p("%s, err := time.Parse(dateLayout, r.URL.Query().Get(\"%s\"))", param.Name, param.Name)
						p("if err != nil {")
						serveBadRequest("err")
					case stringType:
						p("%s := r.URL.Query().Get(\"%s\")", param.Name, param.Name)
						p("if %s == \"\" {", param.Name)
						serveBadRequest(fmt.Sprintf("errs.New(\"parameter '%s' can't be empty\")", param.Name))
					}
					args += param.Name + ", "
				}
			case http.MethodPatch:
				for _, param := range endpoint.Params {
					if param.Type != uuidType {
						decodeBody(param.Name, param.Type)
						args += "*" + param.Name + ", "
						continue
					}
					p("%sParam, ok := mux.Vars(r)[\"%s\"]", param.Name, param.Name)
					p("if !ok {")
					serveBadRequest(fmt.Sprintf("errs.New(\"missing %s route param\")", param.Name))
					p("%s, err := uuid.FromString(%sParam)", param.Name, param.Name)
					p("if err != nil {")
					serveBadRequest("err")
					args += param.Name + ", "
				}
			case http.MethodPost:
				for _, param := range endpoint.Params {
					decodeBody(param.Name, param.Type)
					args += "*" + param.Name + ", "
				}
			}

			p("retVal, httpErr := h.service.%s(ctx, %s)", endpoint.MethodName, args)
			p("if httpErr.Err != nil {")
			p("api.ServeError(h.log, w, httpErr.Status, httpErr.Err)")
			p("return")
			p("}")
			p("")
			p("err = json.NewEncoder(w).Encode(retVal)")
			p("if err != nil {")
			p("h.log.Debug(\"failed to write json %s response\", zap.Error(Err%sAPI.Wrap(err)))", endpoint.MethodName, titlePrefix)
			p("}")
			p("}")
			p("")

			handlers = append(handlers, string(buf[start:]))
			buf = buf[:start]
		}
		sort.Strings(handlers)
		for _, handler := range handlers {
			buf = append(buf, handler...)
		}
	}

	output, err := format.Source(buf)
	if err != nil {
		return nil, err
	}
	return output, nil
}
// handleTypesPackage returns the spelling of t to use in generated code.
//
// Named types declared in the package being generated (the last segment of
// their import path equals a.PackageName) are written unqualified, e.g.
// "Project" instead of "console.Project". All other named types keep their
// package qualifier. Pointer, slice, array and map types are unwrapped
// recursively, so the qualifier is also stripped from their element types,
// e.g. "[]*Project".
//
// The result is either a string or a reflect.Type; both print the type
// spelling when formatted with %s.
func (a *API) handleTypesPackage(t reflect.Type) interface{} {
	if t.Name() != "" {
		pkgPath := t.PkgPath()
		// Compare the whole last path segment, so that package "console" does
		// not match types from "consoleapi".
		if pkgPath != "" && pkgPath[strings.LastIndex(pkgPath, "/")+1:] == a.PackageName {
			return t.Name()
		}
		return t
	}

	switch t.Kind() {
	case reflect.Ptr:
		return fmt.Sprintf("*%s", a.handleTypesPackage(t.Elem()))
	case reflect.Slice:
		return fmt.Sprintf("[]%s", a.handleTypesPackage(t.Elem()))
	case reflect.Array:
		return fmt.Sprintf("[%d]%s", t.Len(), a.handleTypesPackage(t.Elem()))
	case reflect.Map:
		return fmt.Sprintf("map[%s]%s", a.handleTypesPackage(t.Key()), a.handleTypesPackage(t.Elem()))
	}
	// Other unnamed types (func, chan, struct literals) keep reflect's spelling.
	return t
}