private/apigen: rework request parameter handling

- Previously unused struct Endpoint.Request now defines the form
	of the request body.
- Path parameters (e.g. "id" in "/delete/{id}") are defined in
	the Endpoint.PathParams field.
- Endpoint.Params has been renamed to Endpoint.QueryParams to
	eliminate confusion.

Change-Id: Ifef51ca2f362c33086f0e43e936d50b0fdd18aa1
This commit is contained in:
Jeremy Wharton 2022-07-13 22:43:33 -05:00 committed by Storj Robot
parent 3afd7bcc8b
commit 731fecd96f
4 changed files with 151 additions and 143 deletions

View File

@ -17,7 +17,8 @@ type Endpoint struct {
NoAPIAuth bool
Request interface{}
Response interface{}
Params []Param
QueryParams []Param
PathParams []Param
}
// CookieAuth returns endpoint's cookie auth status.

View File

@ -6,7 +6,6 @@ package apigen
import (
"fmt"
"go/format"
"net/http"
"os"
"reflect"
"sort"
@ -58,11 +57,11 @@ func (a *API) generateGo() ([]byte, error) {
i := func(paths ...string) {
for _, path := range paths {
if getPackageName(path) == a.PackageName {
return
continue
}
if _, ok := imports.All[path]; ok {
return
continue
}
imports.All[path] = true
@ -82,10 +81,10 @@ func (a *API) generateGo() ([]byte, error) {
for _, group := range a.EndpointGroups {
for _, method := range group.endpoints {
if method.Request != nil {
i(reflect.TypeOf(method.Request).Elem().PkgPath())
i(getElementaryType(reflect.TypeOf(method.Request)).PkgPath())
}
if method.Response != nil {
i(reflect.TypeOf(method.Response).Elem().PkgPath())
i(getElementaryType(reflect.TypeOf(method.Response)).PkgPath())
}
}
}
@ -100,20 +99,31 @@ func (a *API) generateGo() ([]byte, error) {
p("")
params := make(map[*fullEndpoint][]Param)
for _, group := range a.EndpointGroups {
p("type %sService interface {", group.Name)
for _, e := range group.endpoints {
var params string
for _, param := range e.Params {
params += param.Type.String() + ", "
params[e] = append(e.QueryParams, e.PathParams...)
var paramStr string
for _, param := range params[e] {
paramStr += param.Type.String() + ", "
}
if e.Request != nil {
paramStr += reflect.TypeOf(e.Request).String() + ", "
}
i("context", "storj.io/storj/private/api")
if e.Response != nil {
responseType := reflect.TypeOf(e.Response)
p("%s(context.Context, "+params+") (%s, api.HTTPError)", e.MethodName, a.handleTypesPackage(responseType))
returnParam := a.handleTypesPackage(responseType)
if responseType == getElementaryType(responseType) {
returnParam = "*" + returnParam
}
p("%s(context.Context, "+paramStr+") (%s, api.HTTPError)", e.MethodName, returnParam)
} else {
p("%s(context.Context, "+params+") (api.HTTPError)", e.MethodName)
p("%s(context.Context, "+paramStr+") (api.HTTPError)", e.MethodName)
}
}
p("}")
@ -182,39 +192,9 @@ func (a *API) generateGo() ([]byte, error) {
p("")
}
switch endpoint.Method {
case http.MethodGet:
for _, param := range endpoint.Params {
switch param.Type {
case reflect.TypeOf(uuid.UUID{}):
i("storj.io/common/uuid")
handleUUIDQuery(p, param)
continue
case reflect.TypeOf(time.Time{}):
i("time")
handleTimeQuery(p, param)
continue
case reflect.TypeOf(""):
handleStringQuery(p, param)
continue
}
}
case http.MethodPatch:
for _, param := range endpoint.Params {
if param.Type == reflect.TypeOf(uuid.UUID{}) {
handleUUIDParam(p, param)
} else {
handleBody(p, param)
}
}
case http.MethodPost:
for _, param := range endpoint.Params {
handleBody(p, param)
}
case http.MethodDelete:
for _, param := range endpoint.Params {
handleUUIDParam(p, param)
}
handleParams(p, i, endpoint.QueryParams, endpoint.PathParams)
if endpoint.Request != nil {
handleBody(p, endpoint.Request)
}
var methodFormat string
@ -224,27 +204,11 @@ func (a *API) generateGo() ([]byte, error) {
methodFormat = "httpErr := h.service.%s(ctx, "
}
switch endpoint.Method {
case http.MethodGet:
for _, methodParam := range endpoint.Params {
methodFormat += methodParam.Name + ", "
}
case http.MethodPatch:
for _, methodParam := range endpoint.Params {
if methodParam.Type == reflect.TypeOf(uuid.UUID{}) {
methodFormat += methodParam.Name + ", "
} else {
methodFormat += "*" + methodParam.Name + ", "
}
}
case http.MethodPost:
for _, methodParam := range endpoint.Params {
methodFormat += "*" + methodParam.Name + ", "
}
case http.MethodDelete:
for _, methodParam := range endpoint.Params {
methodFormat += methodParam.Name + ", "
}
for _, param := range params[endpoint] {
methodFormat += param.Name + ", "
}
if endpoint.Request != nil {
methodFormat += "payload"
}
methodFormat += ")"
@ -306,67 +270,77 @@ func (a *API) generateGo() ([]byte, error) {
// handleTypesPackage handles the way some type is used in generated code.
// If type is from the same package then we use only type's name.
// If type is from external package then we use type along with its appropriate package name.
func (a *API) handleTypesPackage(t reflect.Type) interface{} {
func (a *API) handleTypesPackage(t reflect.Type) string {
if strings.HasPrefix(t.String(), a.PackageName) {
return t.Elem().Name()
}
return t
return t.String()
}
// handleStringQuery handles request query param of type string.
func handleStringQuery(p func(format string, a ...interface{}), param Param) {
p("%s := r.URL.Query().Get(\"%s\")", param.Name, param.Name)
p("if %s == \"\" {", param.Name)
p("api.ServeError(h.log, w, http.StatusBadRequest, errs.New(\"parameter '%s' can't be empty\"))", param.Name)
p("return")
p("}")
p("")
}
// handleParams handles parsing of URL query parameters or path parameters.
func handleParams(p func(format string, a ...interface{}), i func(paths ...string), queryParams, pathParams []Param) {
for _, params := range []*[]Param{&queryParams, &pathParams} {
for _, param := range *params {
varName := param.Name
if param.Type != reflect.TypeOf("") {
varName += "Param"
}
// handleUUIDQuery handles request query param of type uuid.UUID.
func handleUUIDQuery(p func(format string, a ...interface{}), param Param) {
p("%s, err := uuid.FromString(r.URL.Query().Get(\"%s\"))", param.Name, param.Name)
p("if err != nil {")
p("api.ServeError(h.log, w, http.StatusBadRequest, err)")
p("return")
p("}")
p("")
}
switch params {
case &queryParams:
p("%s := r.URL.Query().Get(\"%s\")", varName, param.Name)
p("if %s == \"\" {", varName)
p("api.ServeError(h.log, w, http.StatusBadRequest, errs.New(\"parameter '%s' can't be empty\"))", param.Name)
p("return")
p("}")
p("")
case &pathParams:
p("%s, ok := mux.Vars(r)[\"%s\"]", varName, param.Name)
p("if !ok {")
p("api.ServeError(h.log, w, http.StatusBadRequest, errs.New(\"missing %s route param\"))", param.Name)
p("return")
p("}")
p("")
}
// handleTimeQuery handles request query param of type time.Time.
func handleTimeQuery(p func(format string, a ...interface{}), param Param) {
p("%s, err := time.Parse(dateLayout, r.URL.Query().Get(\"%s\"))", param.Name, param.Name)
p("if err != nil {")
p("api.ServeError(h.log, w, http.StatusBadRequest, err)")
p("return")
p("}")
p("")
}
switch param.Type {
case reflect.TypeOf(uuid.UUID{}):
i("storj.io/common/uuid")
p("%s, err := uuid.FromString(%s)", param.Name, varName)
case reflect.TypeOf(time.Time{}):
i("time")
p("%s, err := time.Parse(dateLayout, %s)", param.Name, varName)
default:
p("")
continue
}
// handleUUIDParam handles request inline param of type uuid.UUID.
func handleUUIDParam(p func(format string, a ...interface{}), param Param) {
p("%sParam, ok := mux.Vars(r)[\"%s\"]", param.Name, param.Name)
p("if !ok {")
p("api.ServeError(h.log, w, http.StatusBadRequest, errs.New(\"missing %s route param\"))", param.Name)
p("return")
p("}")
p("")
p("%s, err := uuid.FromString(%sParam)", param.Name, param.Name)
p("if err != nil {")
p("api.ServeError(h.log, w, http.StatusBadRequest, err)")
p("return")
p("}")
p("")
p("if err != nil {")
p("api.ServeError(h.log, w, http.StatusBadRequest, err)")
p("return")
p("}")
p("")
}
}
}
// handleBody handles request body.
func handleBody(p func(format string, a ...interface{}), param Param) {
p("%s := &%s{}", param.Name, param.Type)
p("if err = json.NewDecoder(r.Body).Decode(&%s); err != nil {", param.Name)
func handleBody(p func(format string, a ...interface{}), body interface{}) {
p("payload := %s{}", reflect.TypeOf(body).String())
p("if err = json.NewDecoder(r.Body).Decode(&payload); err != nil {")
p("api.ServeError(h.log, w, http.StatusBadRequest, err)")
p("return")
p("}")
p("")
}
// getElementaryType simplifies a Go type.
func getElementaryType(t reflect.Type) reflect.Type {
switch t.Kind() {
case reflect.Array, reflect.Chan, reflect.Map, reflect.Ptr, reflect.Slice:
return getElementaryType(t.Elem())
default:
return t
}
}

View File

@ -116,17 +116,18 @@ func (h *ProjectManagementHandler) handleGenCreateProject(w http.ResponseWriter,
ctx, err = h.auth.IsAuthenticated(ctx, r, true, true)
if err != nil {
h.auth.RemoveAuthCookie(w)
api.ServeError(h.log, w, http.StatusUnauthorized, err)
return
}
projectInfo := &console.ProjectInfo{}
if err = json.NewDecoder(r.Body).Decode(&projectInfo); err != nil {
payload := console.ProjectInfo{}
if err = json.NewDecoder(r.Body).Decode(&payload); err != nil {
api.ServeError(h.log, w, http.StatusBadRequest, err)
return
}
retVal, httpErr := h.service.GenCreateProject(ctx, *projectInfo)
retVal, httpErr := h.service.GenCreateProject(ctx, payload)
if httpErr.Err != nil {
api.ServeError(h.log, w, httpErr.Status, httpErr.Err)
return
@ -164,13 +165,13 @@ func (h *ProjectManagementHandler) handleGenUpdateProject(w http.ResponseWriter,
return
}
projectInfo := &console.ProjectInfo{}
if err = json.NewDecoder(r.Body).Decode(&projectInfo); err != nil {
payload := console.ProjectInfo{}
if err = json.NewDecoder(r.Body).Decode(&payload); err != nil {
api.ServeError(h.log, w, http.StatusBadRequest, err)
return
}
retVal, httpErr := h.service.GenUpdateProject(ctx, id, *projectInfo)
retVal, httpErr := h.service.GenUpdateProject(ctx, id, payload)
if httpErr.Err != nil {
api.ServeError(h.log, w, httpErr.Status, httpErr.Err)
return
@ -223,6 +224,7 @@ func (h *ProjectManagementHandler) handleGenGetUsersProjects(w http.ResponseWrit
ctx, err = h.auth.IsAuthenticated(ctx, r, true, true)
if err != nil {
h.auth.RemoveAuthCookie(w)
api.ServeError(h.log, w, http.StatusUnauthorized, err)
return
}
@ -253,7 +255,13 @@ func (h *ProjectManagementHandler) handleGenGetSingleBucketUsageRollup(w http.Re
return
}
projectID, err := uuid.FromString(r.URL.Query().Get("projectID"))
projectIDParam := r.URL.Query().Get("projectID")
if projectIDParam == "" {
api.ServeError(h.log, w, http.StatusBadRequest, errs.New("parameter 'projectID' can't be empty"))
return
}
projectID, err := uuid.FromString(projectIDParam)
if err != nil {
api.ServeError(h.log, w, http.StatusBadRequest, err)
return
@ -265,13 +273,25 @@ func (h *ProjectManagementHandler) handleGenGetSingleBucketUsageRollup(w http.Re
return
}
since, err := time.Parse(dateLayout, r.URL.Query().Get("since"))
sinceParam := r.URL.Query().Get("since")
if sinceParam == "" {
api.ServeError(h.log, w, http.StatusBadRequest, errs.New("parameter 'since' can't be empty"))
return
}
since, err := time.Parse(dateLayout, sinceParam)
if err != nil {
api.ServeError(h.log, w, http.StatusBadRequest, err)
return
}
before, err := time.Parse(dateLayout, r.URL.Query().Get("before"))
beforeParam := r.URL.Query().Get("before")
if beforeParam == "" {
api.ServeError(h.log, w, http.StatusBadRequest, errs.New("parameter 'before' can't be empty"))
return
}
before, err := time.Parse(dateLayout, beforeParam)
if err != nil {
api.ServeError(h.log, w, http.StatusBadRequest, err)
return
@ -303,19 +323,37 @@ func (h *ProjectManagementHandler) handleGenGetBucketUsageRollups(w http.Respons
return
}
projectID, err := uuid.FromString(r.URL.Query().Get("projectID"))
projectIDParam := r.URL.Query().Get("projectID")
if projectIDParam == "" {
api.ServeError(h.log, w, http.StatusBadRequest, errs.New("parameter 'projectID' can't be empty"))
return
}
projectID, err := uuid.FromString(projectIDParam)
if err != nil {
api.ServeError(h.log, w, http.StatusBadRequest, err)
return
}
since, err := time.Parse(dateLayout, r.URL.Query().Get("since"))
sinceParam := r.URL.Query().Get("since")
if sinceParam == "" {
api.ServeError(h.log, w, http.StatusBadRequest, errs.New("parameter 'since' can't be empty"))
return
}
since, err := time.Parse(dateLayout, sinceParam)
if err != nil {
api.ServeError(h.log, w, http.StatusBadRequest, err)
return
}
before, err := time.Parse(dateLayout, r.URL.Query().Get("before"))
beforeParam := r.URL.Query().Get("before")
if beforeParam == "" {
api.ServeError(h.log, w, http.StatusBadRequest, errs.New("parameter 'before' can't be empty"))
return
}
before, err := time.Parse(dateLayout, beforeParam)
if err != nil {
api.ServeError(h.log, w, http.StatusBadRequest, err)
return
@ -347,13 +385,13 @@ func (h *APIKeyManagementHandler) handleGenCreateAPIKey(w http.ResponseWriter, r
return
}
apikeyInfo := &console.CreateAPIKeyRequest{}
if err = json.NewDecoder(r.Body).Decode(&apikeyInfo); err != nil {
payload := console.CreateAPIKeyRequest{}
if err = json.NewDecoder(r.Body).Decode(&payload); err != nil {
api.ServeError(h.log, w, http.StatusBadRequest, err)
return
}
retVal, httpErr := h.service.GenCreateAPIKey(ctx, *apikeyInfo)
retVal, httpErr := h.service.GenCreateAPIKey(ctx, payload)
if httpErr.Err != nil {
api.ServeError(h.log, w, httpErr.Status, httpErr.Err)
return

View File

@ -30,19 +30,17 @@ func main() {
Description: "Creates new Project with given info",
MethodName: "GenCreateProject",
Response: &console.Project{},
Params: []apigen.Param{
apigen.NewParam("projectInfo", console.ProjectInfo{}),
},
Request: console.ProjectInfo{},
})
g.Patch("/update/{id}", &apigen.Endpoint{
Name: "Update Project",
Description: "Updates project with given info",
MethodName: "GenUpdateProject",
Response: &console.Project{},
Params: []apigen.Param{
Response: console.Project{},
Request: console.ProjectInfo{},
PathParams: []apigen.Param{
apigen.NewParam("id", uuid.UUID{}),
apigen.NewParam("projectInfo", console.ProjectInfo{}),
},
})
@ -50,8 +48,7 @@ func main() {
Name: "Delete Project",
Description: "Deletes project by id",
MethodName: "GenDeleteProject",
Response: nil,
Params: []apigen.Param{
PathParams: []apigen.Param{
apigen.NewParam("id", uuid.UUID{}),
},
})
@ -67,8 +64,8 @@ func main() {
Name: "Get Project's Single Bucket Usage",
Description: "Gets project's single bucket usage by bucket ID",
MethodName: "GenGetSingleBucketUsageRollup",
Response: &accounting.BucketUsageRollup{},
Params: []apigen.Param{
Response: accounting.BucketUsageRollup{},
QueryParams: []apigen.Param{
apigen.NewParam("projectID", uuid.UUID{}),
apigen.NewParam("bucket", ""),
apigen.NewParam("since", time.Time{}),
@ -81,7 +78,7 @@ func main() {
Description: "Gets project's all buckets usage",
MethodName: "GenGetBucketUsageRollups",
Response: []accounting.BucketUsageRollup{},
Params: []apigen.Param{
QueryParams: []apigen.Param{
apigen.NewParam("projectID", uuid.UUID{}),
apigen.NewParam("since", time.Time{}),
apigen.NewParam("before", time.Time{}),
@ -96,10 +93,8 @@ func main() {
Name: "Create new macaroon API key",
Description: "Creates new macaroon API key with given info",
MethodName: "GenCreateAPIKey",
Response: &console.CreateAPIKeyResponse{},
Params: []apigen.Param{
apigen.NewParam("apikeyInfo", console.CreateAPIKeyRequest{}),
},
Response: console.CreateAPIKeyResponse{},
Request: console.CreateAPIKeyRequest{},
})
}
@ -110,7 +105,7 @@ func main() {
Name: "Get User",
Description: "Gets User by request context",
MethodName: "GenGetUser",
Response: &console.ResponseUser{},
Response: console.ResponseUser{},
})
}