440 lines
12 KiB
Go
440 lines
12 KiB
Go
// Copyright 2010 Google Inc.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
// Package mockgen generates mock implementations of Go interfaces.
|
|
package mockgen
|
|
|
|
// TODO: This does not support recursive embedded interfaces.
|
|
// TODO: This does not support embedding package-local interfaces in a separate file.
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"go/format"
|
|
"go/token"
|
|
"log"
|
|
"path"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"unicode"
|
|
|
|
"github.com/otokaze/mock/mockgen/model"
|
|
)
|
|
|
|
// gomockImportPath is the import path of the gomock runtime package that
// every generated mock file depends on (for *gomock.Controller and
// *gomock.Call).
const gomockImportPath = "github.com/otokaze/mock/gomock"
|
|
|
|
// Package-level settings. They are not read anywhere in this section;
// presumably they are populated from command-line flags elsewhere in the
// package — NOTE(review): confirm against the flag definitions.
var (
	imports    string
	auxFiles   string
	buildFlags string
	execOnly   string
	progOnly   bool
)
|
|
|
|
// Generator accumulates the source text of a generated mock file.
//
// Output is appended to an internal buffer line by line; call Output to get
// the gofmt-formatted result.
type Generator struct {
	// buf holds the raw, unformatted generated source.
	buf bytes.Buffer
	// indent is the tab prefix written before every emitted line.
	indent string

	// MockNames optionally overrides the mock type name per interface name;
	// it may be empty.
	MockNames map[string]string
	// Filename is the source file the mocks were generated from; it may be empty.
	Filename string
	// SrcPackage is the source package, used in the header when Filename is
	// empty; it may be empty.
	SrcPackage string
	// SrcInterfaces lists the mocked interfaces, used in the header when
	// Filename is empty; it may be empty.
	SrcInterfaces string

	// packageMap maps an import path to the local package name used for it
	// in the generated file.
	packageMap map[string]string
}
|
|
|
|
func (g *Generator) p(format string, args ...interface{}) {
|
|
fmt.Fprintf(&g.buf, g.indent+format+"\n", args...)
|
|
}
|
|
|
|
func (g *Generator) in() {
|
|
g.indent += "\t"
|
|
}
|
|
|
|
func (g *Generator) out() {
|
|
if len(g.indent) > 0 {
|
|
g.indent = g.indent[0 : len(g.indent)-1]
|
|
}
|
|
}
|
|
|
|
// removeDot strips a single trailing '.' from s, if present. All other
// strings are returned unchanged.
func removeDot(s string) string {
	if s == "" || s[len(s)-1] != '.' {
		return s
	}
	return s[:len(s)-1]
}
|
|
|
|
// sanitize cleans up a string to make a suitable package name.
//
// Every rune that cannot appear at its position in a Go identifier is
// replaced by '_'. The first rune must be a letter or '_'; later runes may
// also be digits. A result consisting of just "_" is replaced with "x",
// because the blank identifier cannot be used as a package name.
func sanitize(s string) string {
	out := make([]rune, 0, len(s))
	for _, r := range s {
		valid := unicode.IsLetter(r) || r == '_'
		if len(out) > 0 {
			// Digits are only legal after the first character.
			valid = valid || unicode.IsDigit(r)
		}
		if !valid {
			r = '_'
		}
		out = append(out, r)
	}
	if len(out) == 1 && out[0] == '_' {
		return "x"
	}
	return string(out)
}
|
|
|
|
// Generate gen mock by pkg.
|
|
func (g *Generator) Generate(pkg *model.Package, pkgName string, outputPackagePath string) error {
|
|
g.p("// Code generated by MockGen. DO NOT EDIT.")
|
|
if g.Filename != "" {
|
|
g.p("// Source: %v", g.Filename)
|
|
} else {
|
|
g.p("// Source: %v (interfaces: %v)", g.SrcPackage, g.SrcInterfaces)
|
|
}
|
|
g.p("")
|
|
|
|
// Get all required imports, and generate unique names for them all.
|
|
im := pkg.Imports()
|
|
im[gomockImportPath] = true
|
|
|
|
// Only import reflect if it's used. We only use reflect in mocked methods
|
|
// so only import if any of the mocked interfaces have methods.
|
|
for _, intf := range pkg.Interfaces {
|
|
if len(intf.Methods) > 0 {
|
|
im["reflect"] = true
|
|
break
|
|
}
|
|
}
|
|
|
|
// Sort keys to make import alias generation predictable
|
|
sortedPaths := make([]string, len(im))
|
|
x := 0
|
|
for pth := range im {
|
|
sortedPaths[x] = pth
|
|
x++
|
|
}
|
|
sort.Strings(sortedPaths)
|
|
|
|
g.packageMap = make(map[string]string, len(im))
|
|
localNames := make(map[string]bool, len(im))
|
|
for _, pth := range sortedPaths {
|
|
base := sanitize(path.Base(pth))
|
|
|
|
// Local names for an imported package can usually be the basename of the import path.
|
|
// A couple of situations don't permit that, such as duplicate local names
|
|
// (e.g. importing "html/template" and "text/template"), or where the basename is
|
|
// a keyword (e.g. "foo/case").
|
|
// try base0, base1, ...
|
|
pkgName := base
|
|
i := 0
|
|
for localNames[pkgName] || token.Lookup(pkgName).IsKeyword() {
|
|
pkgName = base + strconv.Itoa(i)
|
|
i++
|
|
}
|
|
|
|
g.packageMap[pth] = pkgName
|
|
localNames[pkgName] = true
|
|
}
|
|
|
|
g.p("// Package %v is a generated GoMock package.", pkgName)
|
|
g.p("package %v", pkgName)
|
|
g.p("")
|
|
g.p("import (")
|
|
g.in()
|
|
for path, pkg := range g.packageMap {
|
|
if path == outputPackagePath {
|
|
continue
|
|
}
|
|
g.p("%v %q", pkg, path)
|
|
}
|
|
for _, path := range pkg.DotImports {
|
|
g.p(". %q", path)
|
|
}
|
|
g.out()
|
|
g.p(")")
|
|
|
|
for _, intf := range pkg.Interfaces {
|
|
if err := g.GenerateMockInterface(intf, outputPackagePath); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// The name of the mock type to use for the given interface identifier.
|
|
func (g *Generator) mockName(typeName string) string {
|
|
if mockName, ok := g.MockNames[typeName]; ok {
|
|
return mockName
|
|
}
|
|
|
|
return "Mock" + typeName
|
|
}
|
|
|
|
// GenerateMockInterface gen mock intf.
|
|
func (g *Generator) GenerateMockInterface(intf *model.Interface, outputPackagePath string) error {
|
|
mockType := g.mockName(intf.Name)
|
|
|
|
g.p("")
|
|
g.p("// %v is a mock of %v interface", mockType, intf.Name)
|
|
g.p("type %v struct {", mockType)
|
|
g.in()
|
|
g.p("ctrl *gomock.Controller")
|
|
g.p("recorder *%vMockRecorder", mockType)
|
|
g.out()
|
|
g.p("}")
|
|
g.p("")
|
|
|
|
g.p("// %vMockRecorder is the mock recorder for %v", mockType, mockType)
|
|
g.p("type %vMockRecorder struct {", mockType)
|
|
g.in()
|
|
g.p("mock *%v", mockType)
|
|
g.out()
|
|
g.p("}")
|
|
g.p("")
|
|
|
|
// TODO: Re-enable this if we can import the interface reliably.
|
|
//g.p("// Verify that the mock satisfies the interface at compile time.")
|
|
//g.p("var _ %v = (*%v)(nil)", typeName, mockType)
|
|
//g.p("")
|
|
|
|
g.p("// New%v creates a new mock instance", mockType)
|
|
g.p("func New%v(ctrl *gomock.Controller) *%v {", mockType, mockType)
|
|
g.in()
|
|
g.p("mock := &%v{ctrl: ctrl}", mockType)
|
|
g.p("mock.recorder = &%vMockRecorder{mock}", mockType)
|
|
g.p("return mock")
|
|
g.out()
|
|
g.p("}")
|
|
g.p("")
|
|
|
|
// XXX: possible name collision here if someone has EXPECT in their interface.
|
|
g.p("// EXPECT returns an object that allows the caller to indicate expected use")
|
|
g.p("func (m *%v) EXPECT() *%vMockRecorder {", mockType, mockType)
|
|
g.in()
|
|
g.p("return m.recorder")
|
|
g.out()
|
|
g.p("}")
|
|
|
|
g.GenerateMockMethods(mockType, intf, outputPackagePath)
|
|
|
|
return nil
|
|
}
|
|
|
|
// GenerateMockMethods gen mock methods.
|
|
func (g *Generator) GenerateMockMethods(mockType string, intf *model.Interface, pkgOverride string) {
|
|
for _, m := range intf.Methods {
|
|
g.p("")
|
|
g.GenerateMockMethod(mockType, m, pkgOverride)
|
|
g.p("")
|
|
g.GenerateMockRecorderMethod(mockType, m)
|
|
}
|
|
}
|
|
|
|
// makeArgString renders a Go parameter list from parallel name and type
// slices. A parameter's type is written only when the following parameter
// has a different type, so runs of consecutive same-typed parameters share
// one type (e.g. "a, b int, c string").
func makeArgString(argNames, argTypes []string) string {
	parts := make([]string, 0, len(argNames))
	for i := range argNames {
		part := argNames[i]
		sharesNextType := i+1 < len(argTypes) && argTypes[i+1] == argTypes[i]
		if !sharesNextType {
			part += " " + argTypes[i]
		}
		parts = append(parts, part)
	}
	return strings.Join(parts, ", ")
}
|
|
|
|
// GenerateMockMethod generates a mock method implementation.
|
|
// If non-empty, pkgOverride is the package in which unqualified types reside.
|
|
func (g *Generator) GenerateMockMethod(mockType string, m *model.Method, pkgOverride string) error {
|
|
argNames := g.getArgNames(m)
|
|
argTypes := g.getArgTypes(m, pkgOverride)
|
|
argString := makeArgString(argNames, argTypes)
|
|
|
|
rets := make([]string, len(m.Out))
|
|
for i, p := range m.Out {
|
|
rets[i] = p.Type.String(g.packageMap, pkgOverride)
|
|
}
|
|
retString := strings.Join(rets, ", ")
|
|
if len(rets) > 1 {
|
|
retString = "(" + retString + ")"
|
|
}
|
|
if retString != "" {
|
|
retString = " " + retString
|
|
}
|
|
|
|
ia := newIdentifierAllocator(argNames)
|
|
idRecv := ia.allocateIdentifier("m")
|
|
|
|
g.p("// %v mocks base method", m.Name)
|
|
g.p("func (%v *%v) %v(%v)%v {", idRecv, mockType, m.Name, argString, retString)
|
|
g.in()
|
|
|
|
var callArgs string
|
|
if m.Variadic == nil {
|
|
if len(argNames) > 0 {
|
|
callArgs = ", " + strings.Join(argNames, ", ")
|
|
}
|
|
} else {
|
|
// Non-trivial. The generated code must build a []interface{},
|
|
// but the variadic argument may be any type.
|
|
idVarArgs := ia.allocateIdentifier("varargs")
|
|
idVArg := ia.allocateIdentifier("a")
|
|
g.p("%s := []interface{}{%s}", idVarArgs, strings.Join(argNames[:len(argNames)-1], ", "))
|
|
g.p("for _, %s := range %s {", idVArg, argNames[len(argNames)-1])
|
|
g.in()
|
|
g.p("%s = append(%s, %s)", idVarArgs, idVarArgs, idVArg)
|
|
g.out()
|
|
g.p("}")
|
|
callArgs = ", " + idVarArgs + "..."
|
|
}
|
|
if len(m.Out) == 0 {
|
|
g.p(`%v.ctrl.Call(%v, %q%v)`, idRecv, idRecv, m.Name, callArgs)
|
|
} else {
|
|
idRet := ia.allocateIdentifier("ret")
|
|
g.p(`%v := %v.ctrl.Call(%v, %q%v)`, idRet, idRecv, idRecv, m.Name, callArgs)
|
|
|
|
// Go does not allow "naked" type assertions on nil values, so we use the two-value form here.
|
|
// The value of that is either (x.(T), true) or (Z, false), where Z is the zero value for T.
|
|
// Happily, this coincides with the semantics we want here.
|
|
retNames := make([]string, len(rets))
|
|
for i, t := range rets {
|
|
retNames[i] = ia.allocateIdentifier(fmt.Sprintf("ret%d", i))
|
|
g.p("%s, _ := %s[%d].(%s)", retNames[i], idRet, i, t)
|
|
}
|
|
g.p("return " + strings.Join(retNames, ", "))
|
|
}
|
|
|
|
g.out()
|
|
g.p("}")
|
|
return nil
|
|
}
|
|
|
|
// GenerateMockRecorderMethod gen mock recorder method.
|
|
func (g *Generator) GenerateMockRecorderMethod(mockType string, m *model.Method) error {
|
|
argNames := g.getArgNames(m)
|
|
|
|
var argString string
|
|
if m.Variadic == nil {
|
|
argString = strings.Join(argNames, ", ")
|
|
} else {
|
|
argString = strings.Join(argNames[:len(argNames)-1], ", ")
|
|
}
|
|
if argString != "" {
|
|
argString += " interface{}"
|
|
}
|
|
|
|
if m.Variadic != nil {
|
|
if argString != "" {
|
|
argString += ", "
|
|
}
|
|
argString += fmt.Sprintf("%s ...interface{}", argNames[len(argNames)-1])
|
|
}
|
|
|
|
ia := newIdentifierAllocator(argNames)
|
|
idRecv := ia.allocateIdentifier("mr")
|
|
|
|
g.p("// %v indicates an expected call of %v", m.Name, m.Name)
|
|
g.p("func (%s *%vMockRecorder) %v(%v) *gomock.Call {", idRecv, mockType, m.Name, argString)
|
|
g.in()
|
|
|
|
var callArgs string
|
|
if m.Variadic == nil {
|
|
if len(argNames) > 0 {
|
|
callArgs = ", " + strings.Join(argNames, ", ")
|
|
}
|
|
} else {
|
|
if len(argNames) == 1 {
|
|
// Easy: just use ... to push the arguments through.
|
|
callArgs = ", " + argNames[0] + "..."
|
|
} else {
|
|
// Hard: create a temporary slice.
|
|
idVarArgs := ia.allocateIdentifier("varargs")
|
|
g.p("%s := append([]interface{}{%s}, %s...)",
|
|
idVarArgs,
|
|
strings.Join(argNames[:len(argNames)-1], ", "),
|
|
argNames[len(argNames)-1])
|
|
callArgs = ", " + idVarArgs + "..."
|
|
}
|
|
}
|
|
g.p(`return %s.mock.ctrl.RecordCallWithMethodType(%s.mock, "%s", reflect.TypeOf((*%s)(nil).%s)%s)`, idRecv, idRecv, m.Name, mockType, m.Name, callArgs)
|
|
|
|
g.out()
|
|
g.p("}")
|
|
return nil
|
|
}
|
|
|
|
func (g *Generator) getArgNames(m *model.Method) []string {
|
|
argNames := make([]string, len(m.In))
|
|
for i, p := range m.In {
|
|
name := p.Name
|
|
if name == "" {
|
|
name = fmt.Sprintf("arg%d", i)
|
|
}
|
|
argNames[i] = name
|
|
}
|
|
if m.Variadic != nil {
|
|
name := m.Variadic.Name
|
|
if name == "" {
|
|
name = fmt.Sprintf("arg%d", len(m.In))
|
|
}
|
|
argNames = append(argNames, name)
|
|
}
|
|
return argNames
|
|
}
|
|
|
|
func (g *Generator) getArgTypes(m *model.Method, pkgOverride string) []string {
|
|
argTypes := make([]string, len(m.In))
|
|
for i, p := range m.In {
|
|
argTypes[i] = p.Type.String(g.packageMap, pkgOverride)
|
|
}
|
|
if m.Variadic != nil {
|
|
argTypes = append(argTypes, "..."+m.Variadic.Type.String(g.packageMap, pkgOverride))
|
|
}
|
|
return argTypes
|
|
}
|
|
|
|
// identifierAllocator records identifiers that are already in use, so that
// newly allocated names never collide with them.
type identifierAllocator map[string]struct{}

// newIdentifierAllocator returns an allocator that treats every name in
// taken as already used.
func newIdentifierAllocator(taken []string) identifierAllocator {
	used := make(identifierAllocator, len(taken))
	for i := range taken {
		used[taken[i]] = struct{}{}
	}
	return used
}

// allocateIdentifier reserves and returns want if it is free. Otherwise it
// tries want_2, want_3, ... in order and reserves the first free candidate.
func (o identifierAllocator) allocateIdentifier(want string) string {
	inUse := func(name string) bool {
		_, ok := o[name]
		return ok
	}

	id := want
	for n := 2; inUse(id); n++ {
		id = want + "_" + strconv.Itoa(n)
	}
	o[id] = struct{}{}
	return id
}
|
|
|
|
// Output returns the generator's output, formatted in the standard Go style.
|
|
func (g *Generator) Output() []byte {
|
|
src, err := format.Source(g.buf.Bytes())
|
|
if err != nil {
|
|
log.Fatalf("Failed to format generated source code: %s\n%s", err, g.buf.String())
|
|
}
|
|
return src
|
|
}
|