restic/vendor/github.com/Azure/azure-sdk-for-go/tools/apidiff/exports/package.go

166 lines
4.4 KiB
Go

// Copyright 2018 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package exports
import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"io/ioutil"
	"os"
	"strings"
)
// Package represents a Go package.
// It bundles the parsed AST of the package with the file set used to
// parse it and a cache of raw source file contents.
type Package struct {
// f is the file set the package was parsed with; getText uses it to
// resolve token positions to a file name and byte offset.
f *token.FileSet
// p is the parsed package AST; LoadPackage trims it to exported nodes.
p *ast.Package
// files caches source file contents keyed by file name; getText fills
// it lazily the first time text from a file is requested.
files map[string][]byte
}
// LoadPackage loads the package in the specified directory.
// It's required there is only one package in the directory.
//
// Files whose names end in "_test.go" are not parsed: exported identifiers
// in tests are not part of a package's API, and an external test package
// ("foo_test") living next to "foo" would otherwise be counted as a second
// package and rejected.
//
// An error is returned if the directory cannot be parsed, if it contains
// zero or more than one package, or if the package has no exported
// declarations. On error the returned Package is the zero value.
func LoadPackage(dir string) (pkg Package, err error) {
	fset := token.NewFileSet()
	// only consider non-test source files
	notTestFile := func(fi os.FileInfo) bool {
		return !strings.HasSuffix(fi.Name(), "_test.go")
	}
	packages, err := parser.ParseDir(fset, dir, notTestFile, 0)
	if err != nil {
		return Package{}, err
	}
	if len(packages) < 1 {
		return Package{}, fmt.Errorf("didn't find any packages in '%s'", dir)
	}
	if len(packages) > 1 {
		return Package{}, fmt.Errorf("found more than one package in '%s'", dir)
	}
	// exactly one entry remains; range is just the way to pull it out of the map
	for name, p := range packages {
		// trim any non-exported nodes (this modifies the AST in place)
		if !ast.PackageExports(p) {
			return Package{}, fmt.Errorf("package '%s' doesn't contain any exports", name)
		}
		return Package{
			f:     fset,
			p:     p,
			files: map[string][]byte{},
		}, nil
	}
	// shouldn't ever get here as the map holds exactly one package
	panic("failed to return package")
}
// GetExports returns the exported content of the package.
// It walks the package's AST and records struct and interface types,
// function declarations and const declarations in a new Content.
func (pkg Package) GetExports() (c Content) {
	c = NewContent()
	visit := func(node ast.Node) bool {
		switch decl := node.(type) {
		case *ast.FuncDecl:
			c.addFunc(pkg, decl)
			// Do not descend into the function: its body is of no
			// interest, and pruning it removes the bulk of the AST so
			// that only declarations remain to be examined. Skipping
			// this would break the collection logic.
			return false
		case *ast.GenDecl:
			if decl.Tok == token.CONST {
				c.addConst(pkg, decl)
			}
		case *ast.TypeSpec:
			// only struct and interface types are recorded
			switch typ := decl.Type.(type) {
			case *ast.StructType:
				c.addStruct(pkg, decl.Name.Name, typ)
			case *ast.InterfaceType:
				c.addInterface(pkg, decl.Name.Name, typ)
			}
		}
		return true
	}
	ast.Inspect(pkg.p, visit)
	return c
}
// Name returns the package name as recorded in the parsed AST.
func (pkg Package) Name() string {
	parsed := pkg.p
	return parsed.Name
}
// Get is a convenience wrapper that calls LoadPackage() on pkgDir and, if
// that succeeds, returns the result of GetExports() on the loaded package.
// If loading fails, an empty Content and the load error are returned.
func Get(pkgDir string) (Content, error) {
	loaded, loadErr := LoadPackage(pkgDir)
	if loadErr == nil {
		return loaded.GetExports(), nil
	}
	return Content{}, loadErr
}
// getText returns the source text starting at start and spanning
// end-start bytes, i.e. the half-open range [start, end).
// The containing file is read from disk on first use and cached in
// pkg.files; a failure to read the file results in a panic.
func (pkg Package) getText(start token.Pos, end token.Pos) string {
	// resolve the token position to a file name and byte offset
	pos := pkg.f.Position(start)
	src, cached := pkg.files[pos.Filename]
	if !cached {
		// first request for this file, load and remember its content
		var err error
		if src, err = ioutil.ReadFile(pos.Filename); err != nil {
			panic(err)
		}
		pkg.files[pos.Filename] = src
	}
	length := int(end - start)
	return string(src[pos.Offset : pos.Offset+length])
}
// translateFieldList iterates over the specified field list and invokes the
// callback with the name of the field and the source text of its type.
//
// A single ast.Field can declare several names that share one type (e.g.
// "A, B int" in a struct or "a, b int" in a parameter list); the callback
// is invoked once for every such name so that none of them is dropped.
// For fields without a name, e.g. anonymous fields in structs, unnamed
// return types etc., the callback is invoked once with a nil name.
func (pkg Package) translateFieldList(fl []*ast.Field, cb func(*string, string)) {
	for _, f := range fl {
		// the type text is shared by all names declared in this field
		t := pkg.getText(f.Type.Pos(), f.Type.End())
		if len(f.Names) == 0 {
			cb(nil, t)
			continue
		}
		for _, ident := range f.Names {
			// copy so that every callback receives its own pointer
			n := ident.Name
			cb(&n, t)
		}
	}
}
// buildFunc creates a Func object from the specified ast.FuncType.
// Params is set to the comma-delimited list of parameter types when the
// parameter list is non-nil, and Returns to the comma-delimited list of
// result types when the function declares results; otherwise the
// respective field is left unset.
func (pkg Package) buildFunc(ft *ast.FuncType) (f Func) {
	// joinTypes renders the types in fields as a comma-delimited string
	joinTypes := func(fields []*ast.Field) *string {
		joined := ""
		pkg.translateFieldList(fields, func(_ *string, typ string) {
			if joined != "" {
				joined += ","
			}
			joined += typ
		})
		return &joined
	}

	if ft.Params.List != nil {
		f.Params = joinTypes(ft.Params.List)
	}
	if ft.Results != nil {
		f.Returns = joinTypes(ft.Results.List)
	}
	return f
}