// Copyright 2018 Microsoft Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package exports import ( "fmt" "go/ast" "go/parser" "go/token" "io/ioutil" ) // Package represents a Go package. type Package struct { f *token.FileSet p *ast.Package files map[string][]byte } // LoadPackage loads the package in the specified directory. // It's required there is only one package in the directory. func LoadPackage(dir string) (pkg Package, err error) { pkg.files = map[string][]byte{} pkg.f = token.NewFileSet() packages, err := parser.ParseDir(pkg.f, dir, nil, 0) if err != nil { return } if len(packages) < 1 { err = fmt.Errorf("didn't find any packages in '%s'", dir) return } if len(packages) > 1 { err = fmt.Errorf("found more than one package in '%s'", dir) return } for pn := range packages { p := packages[pn] // trim any non-exported nodes if exp := ast.PackageExports(p); !exp { err = fmt.Errorf("package '%s' doesn't contain any exports", pn) return } pkg.p = p return } // shouldn't ever get here... panic("failed to return package") } // GetExports returns the exported content of the package. 
func (pkg Package) GetExports() (c Content) {
	c = NewContent()
	visit := func(node ast.Node) bool {
		switch decl := node.(type) {
		case *ast.FuncDecl:
			c.addFunc(pkg, decl)
			// Do not descend into the function body. Skipping bodies prunes
			// most of the package's AST, so only declarations of interest are
			// visited (descending here would break a lot of code).
			return false
		case *ast.GenDecl:
			if decl.Tok == token.CONST {
				c.addConst(pkg, decl)
			}
		case *ast.TypeSpec:
			switch typ := decl.Type.(type) {
			case *ast.StructType:
				c.addStruct(pkg, decl.Name.Name, typ)
			case *ast.InterfaceType:
				c.addInterface(pkg, decl.Name.Name, typ)
			}
		}
		return true
	}
	ast.Inspect(pkg.p, visit)
	return c
}

// Name reports the name declared by the package clause of the loaded package.
func (pkg Package) Name() string {
	name := pkg.p.Name
	return name
}

// Get is a shortcut for LoadPackage(pkgDir) followed by GetExports(). It
// returns the package's exported content, or the LoadPackage error (paired
// with an empty Content) if loading fails.
func Get(pkgDir string) (Content, error) {
	loaded, loadErr := LoadPackage(pkgDir)
	if loadErr != nil {
		return Content{}, loadErr
	}
	return loaded.GetExports(), nil
}

// getText returns the source text in the half-open range [start, end).
// The containing file is read from disk the first time it is needed and
// cached in pkg.files. A read failure panics.
func (pkg Package) getText(start token.Pos, end token.Pos) string {
	// resolve start to a file name and byte offset within that file
	pos := pkg.f.Position(start)

	src, cached := pkg.files[pos.Filename]
	if !cached {
		var readErr error
		src, readErr = ioutil.ReadFile(pos.Filename)
		if readErr != nil {
			panic(readErr)
		}
		pkg.files[pos.Filename] = src
	}

	from := pos.Offset
	to := from + int(end-start)
	return string(src[from:to])
}

// iterates over the specified field list, for each field the specified
// callback is invoked with the name of the field and the type name. the field
// name can be nil, e.g. anonymous fields in structs, unnamed return types etc.
func (pkg Package) translateFieldList(fl []*ast.Field, cb func(*string, string)) { for _, f := range fl { var name *string if f.Names != nil { n := pkg.getText(f.Names[0].Pos(), f.Names[0].End()) name = &n } t := pkg.getText(f.Type.Pos(), f.Type.End()) cb(name, t) } } // creates a Func object from the specified ast.FuncType func (pkg Package) buildFunc(ft *ast.FuncType) (f Func) { // appends a to s, comma-delimited style, and returns s appendString := func(s, a string) string { if s != "" { s += "," } s += a return s } // build the params type list if ft.Params.List != nil { p := "" pkg.translateFieldList(ft.Params.List, func(n *string, t string) { p = appendString(p, t) }) f.Params = &p } // build the return types list if ft.Results != nil { r := "" pkg.translateFieldList(ft.Results.List, func(n *string, t string) { r = appendString(r, t) }) f.Returns = &r } return }