2019-04-22 02:59:20 +00:00

440 lines
12 KiB

// Copyright 2010 Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package mockgen generates mock implementations of Go interfaces.
package mockgen
// TODO: This does not support recursive embedded interfaces.
// TODO: This does not support embedding package-local interfaces in a separate file.
import (
	"bytes"
	"fmt"
	"go/format"
	"go/token"
	"log"
	"path"
	"sort"
	"strconv"
	"strings"
	"unicode"

	"github.com/otokaze/mock/mockgen/model"
)
// gomockImportPath is the import path of the gomock runtime package that the
// generated mocks refer to; Generate always adds it to the import set.
const gomockImportPath = "github.com/otokaze/mock/gomock"
// Package-level string options. Nothing in this file's generator code reads
// or writes them.
// NOTE(review): presumably populated by option/flag handling elsewhere in the
// package — confirm before relying on their zero values.
var imports, auxFiles, buildFlags, execOnly string

// progOnly is a package-level switch that this file's generator code never
// reads. NOTE(review): presumably set elsewhere in the package — confirm.
var progOnly bool
// Generator a generator struct.
type Generator struct {
buf bytes.Buffer
indent string
MockNames map[string]string //may be empty
Filename string // may be empty
SrcPackage, SrcInterfaces string // may be empty
packageMap map[string]string // map from import path to package name
func (g *Generator) p(format string, args ...interface{}) {
fmt.Fprintf(&g.buf, g.indent+format+"\n", args...)
func (g *Generator) in() {
g.indent += "\t"
func (g *Generator) out() {
if len(g.indent) > 0 {
g.indent = g.indent[0 : len(g.indent)-1]
// removeDot returns s without its final byte when that byte is '.'; any other
// string (including "") is returned unchanged. Only one dot is removed.
func removeDot(s string) string {
	n := len(s)
	if n == 0 || s[n-1] != '.' {
		return s
	}
	return s[:n-1]
}
// sanitize cleans up a string to make a suitable package name.
func sanitize(s string) string {
t := ""
for _, r := range s {
if t == "" {
if unicode.IsLetter(r) || r == '_' {
t += string(r)
} else {
if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' {
t += string(r)
t += "_"
if t == "_" {
t = "x"
return t
// Generate writes a complete mock source file for pkg into g's buffer:
//   - a "Code generated" header;
//   - the package clause for pkgName;
//   - an import block;
//   - one mock type, with its methods, for every interface in pkg.
//
// outputPackagePath is the import path the mocks will live in. It is never
// imported by the generated file, and it is passed down as the package
// override used when rendering types. Generate returns the first error
// reported by GenerateMockInterface.
func (g *Generator) Generate(pkg *model.Package, pkgName string, outputPackagePath string) error {
	g.p("// Code generated by MockGen. DO NOT EDIT.")
	if g.Filename != "" {
		g.p("// Source: %v", g.Filename)
	} else {
		g.p("// Source: %v (interfaces: %v)", g.SrcPackage, g.SrcInterfaces)
	}
	// Blank line so the header does not merge into the package doc comment.
	g.p("")

	// Get all required imports, and generate unique names for them all.
	im := pkg.Imports()
	im[gomockImportPath] = true
	// Only import reflect if it's used. We only use reflect in mocked methods
	// so only import if any of the mocked interfaces have methods.
	for _, intf := range pkg.Interfaces {
		if len(intf.Methods) > 0 {
			im["reflect"] = true
			break
		}
	}

	// Sort keys to make import alias generation (and import order)
	// predictable; map iteration order is random.
	sortedPaths := make([]string, 0, len(im))
	for pth := range im {
		sortedPaths = append(sortedPaths, pth)
	}
	sort.Strings(sortedPaths)

	g.packageMap = make(map[string]string, len(im))
	localNames := make(map[string]bool, len(im))
	for _, pth := range sortedPaths {
		base := sanitize(path.Base(pth))
		// Local names for an imported package can usually be the basename of the import path.
		// A couple of situations don't permit that, such as duplicate local names
		// (e.g. importing "html/template" and "text/template"), or where the basename is
		// a keyword (e.g. "foo/case").
		// try base0, base1, ...
		localName := base
		for i := 0; localNames[localName] || token.Lookup(localName).IsKeyword(); i++ {
			localName = base + strconv.Itoa(i)
		}
		g.packageMap[pth] = localName
		localNames[localName] = true
	}

	g.p("// Package %v is a generated GoMock package.", pkgName)
	g.p("package %v", pkgName)
	g.p("")
	g.p("import (")
	g.in()
	for _, pth := range sortedPaths {
		// The mocks are written into outputPackagePath; a package must not
		// import itself.
		if pth == outputPackagePath {
			continue
		}
		g.p("%v %q", g.packageMap[pth], pth)
	}
	for _, pth := range pkg.DotImports {
		g.p(". %q", pth)
	}
	g.out()
	g.p(")")

	for _, intf := range pkg.Interfaces {
		if err := g.GenerateMockInterface(intf, outputPackagePath); err != nil {
			return err
		}
	}
	return nil
}
// The name of the mock type to use for the given interface identifier.
func (g *Generator) mockName(typeName string) string {
if mockName, ok := g.MockNames[typeName]; ok {
return mockName
return "Mock" + typeName
// GenerateMockInterface gen mock intf.
func (g *Generator) GenerateMockInterface(intf *model.Interface, outputPackagePath string) error {
mockType := g.mockName(intf.Name)
g.p("// %v is a mock of %v interface", mockType, intf.Name)
g.p("type %v struct {", mockType)
g.p("ctrl *gomock.Controller")
g.p("recorder *%vMockRecorder", mockType)
g.p("// %vMockRecorder is the mock recorder for %v", mockType, mockType)
g.p("type %vMockRecorder struct {", mockType)
g.p("mock *%v", mockType)
// TODO: Re-enable this if we can import the interface reliably.
//g.p("// Verify that the mock satisfies the interface at compile time.")
//g.p("var _ %v = (*%v)(nil)", typeName, mockType)
g.p("// New%v creates a new mock instance", mockType)
g.p("func New%v(ctrl *gomock.Controller) *%v {", mockType, mockType)
g.p("mock := &%v{ctrl: ctrl}", mockType)
g.p("mock.recorder = &%vMockRecorder{mock}", mockType)
g.p("return mock")
// XXX: possible name collision here if someone has EXPECT in their interface.
g.p("// EXPECT returns an object that allows the caller to indicate expected use")
g.p("func (m *%v) EXPECT() *%vMockRecorder {", mockType, mockType)
g.p("return m.recorder")
g.GenerateMockMethods(mockType, intf, outputPackagePath)
return nil
// GenerateMockMethods gen mock methods.
func (g *Generator) GenerateMockMethods(mockType string, intf *model.Interface, pkgOverride string) {
for _, m := range intf.Methods {
g.GenerateMockMethod(mockType, m, pkgOverride)
g.GenerateMockRecorderMethod(mockType, m)
func makeArgString(argNames, argTypes []string) string {
args := make([]string, len(argNames))
for i, name := range argNames {
// specify the type only once for consecutive args of the same type
if i+1 < len(argTypes) && argTypes[i] == argTypes[i+1] {
args[i] = name
} else {
args[i] = name + " " + argTypes[i]
return strings.Join(args, ", ")
// GenerateMockMethod generates a mock method implementation on *mockType for m.
// If non-empty, pkgOverride is the package in which unqualified types reside.
//
// The generated body passes every argument to ctrl.Call; a variadic argument
// is first copied element by element into a []interface{}. Each returned value
// is then converted back to its declared type. The returned error is currently
// always nil.
func (g *Generator) GenerateMockMethod(mockType string, m *model.Method, pkgOverride string) error {
	argNames := g.getArgNames(m)
	argTypes := g.getArgTypes(m, pkgOverride)
	argString := makeArgString(argNames, argTypes)

	rets := make([]string, len(m.Out))
	for i, p := range m.Out {
		rets[i] = p.Type.String(g.packageMap, pkgOverride)
	}
	retString := strings.Join(rets, ", ")
	if len(rets) > 1 {
		retString = "(" + retString + ")"
	}
	if retString != "" {
		retString = " " + retString
	}

	// Generated local names must not collide with the parameter names.
	ia := newIdentifierAllocator(argNames)
	idRecv := ia.allocateIdentifier("m")

	g.p("// %v mocks base method", m.Name)
	g.p("func (%v *%v) %v(%v)%v {", idRecv, mockType, m.Name, argString, retString)
	g.in()

	var callArgs string
	if m.Variadic == nil {
		if len(argNames) > 0 {
			callArgs = ", " + strings.Join(argNames, ", ")
		}
	} else {
		// Non-trivial. The generated code must build a []interface{},
		// but the variadic argument may be any type.
		idVarArgs := ia.allocateIdentifier("varargs")
		idVArg := ia.allocateIdentifier("a")
		g.p("%s := []interface{}{%s}", idVarArgs, strings.Join(argNames[:len(argNames)-1], ", "))
		g.p("for _, %s := range %s {", idVArg, argNames[len(argNames)-1])
		g.in()
		g.p("%s = append(%s, %s)", idVarArgs, idVarArgs, idVArg)
		g.out()
		g.p("}")
		callArgs = ", " + idVarArgs + "..."
	}
	if len(m.Out) == 0 {
		g.p(`%v.ctrl.Call(%v, %q%v)`, idRecv, idRecv, m.Name, callArgs)
	} else {
		idRet := ia.allocateIdentifier("ret")
		g.p(`%v := %v.ctrl.Call(%v, %q%v)`, idRet, idRecv, idRecv, m.Name, callArgs)

		// Go does not allow "naked" type assertions on nil values, so we use the two-value form here.
		// The value of that is either (x.(T), true) or (Z, false), where Z is the zero value for T.
		// Happily, this coincides with the semantics we want here.
		retNames := make([]string, len(rets))
		for i, t := range rets {
			retNames[i] = ia.allocateIdentifier(fmt.Sprintf("ret%d", i))
			g.p("%s, _ := %s[%d].(%s)", retNames[i], idRet, i, t)
		}
		// Pass the names as an argument rather than splicing them into the
		// format string.
		g.p("return %s", strings.Join(retNames, ", "))
	}

	g.out()
	g.p("}")
	return nil
}
// GenerateMockRecorderMethod writes the recorder method for m on
// <mockType>MockRecorder. Every fixed parameter is typed interface{}, and the
// variadic one, if any, is typed ...interface{}. The generated body forwards
// the receiver's mock, the method name, the reflect.TypeOf of the mock's method
// and all arguments to ctrl.RecordCallWithMethodType, and returns the
// resulting *gomock.Call. The returned error is currently always nil.
func (g *Generator) GenerateMockRecorderMethod(mockType string, m *model.Method) error {
	argNames := g.getArgNames(m)

	var argString string
	if m.Variadic == nil {
		argString = strings.Join(argNames, ", ")
	} else {
		argString = strings.Join(argNames[:len(argNames)-1], ", ")
	}
	if argString != "" {
		argString += " interface{}"
	}
	if m.Variadic != nil {
		if argString != "" {
			argString += ", "
		}
		argString += fmt.Sprintf("%s ...interface{}", argNames[len(argNames)-1])
	}

	// Generated local names must not collide with the parameter names.
	ia := newIdentifierAllocator(argNames)
	idRecv := ia.allocateIdentifier("mr")

	g.p("// %v indicates an expected call of %v", m.Name, m.Name)
	g.p("func (%s *%vMockRecorder) %v(%v) *gomock.Call {", idRecv, mockType, m.Name, argString)
	g.in()

	var callArgs string
	switch {
	case m.Variadic == nil:
		if len(argNames) > 0 {
			callArgs = ", " + strings.Join(argNames, ", ")
		}
	case len(argNames) == 1:
		// Easy: just use ... to push the arguments through.
		callArgs = ", " + argNames[0] + "..."
	default:
		// Hard: create a temporary slice holding the fixed arguments followed
		// by the variadic ones.
		idVarArgs := ia.allocateIdentifier("varargs")
		g.p("%s := append([]interface{}{%s}, %s...)",
			idVarArgs,
			strings.Join(argNames[:len(argNames)-1], ", "),
			argNames[len(argNames)-1])
		callArgs = ", " + idVarArgs + "..."
	}
	g.p(`return %s.mock.ctrl.RecordCallWithMethodType(%s.mock, "%s", reflect.TypeOf((*%s)(nil).%s)%s)`, idRecv, idRecv, m.Name, mockType, m.Name, callArgs)

	g.out()
	g.p("}")
	return nil
}
// getArgNames returns the parameter names used by the generated mock and
// recorder methods for m. There is one name per entry of m.In, followed by one
// for m.Variadic when present. Unnamed and blank ("_") parameters get the
// placeholder "arg<index>". The generated bodies pass every parameter on by
// name, and "_" cannot be used as a value.
// NOTE: a placeholder can still collide with an explicitly named parameter
// such as "arg0"; that case is not handled here.
func (g *Generator) getArgNames(m *model.Method) []string {
	nameFor := func(name string, index int) string {
		if name == "" || name == "_" {
			return fmt.Sprintf("arg%d", index)
		}
		return name
	}

	argNames := make([]string, len(m.In))
	for i, p := range m.In {
		argNames[i] = nameFor(p.Name, i)
	}
	if m.Variadic != nil {
		argNames = append(argNames, nameFor(m.Variadic.Name, len(m.In)))
	}
	return argNames
}
func (g *Generator) getArgTypes(m *model.Method, pkgOverride string) []string {
argTypes := make([]string, len(m.In))
for i, p := range m.In {
argTypes[i] = p.Type.String(g.packageMap, pkgOverride)
if m.Variadic != nil {
argTypes = append(argTypes, "..."+m.Variadic.Type.String(g.packageMap, pkgOverride))
return argTypes
type identifierAllocator map[string]struct{}
func newIdentifierAllocator(taken []string) identifierAllocator {
a := make(identifierAllocator, len(taken))
for _, s := range taken {
a[s] = struct{}{}
return a
func (o identifierAllocator) allocateIdentifier(want string) string {
id := want
for i := 2; ; i++ {
if _, ok := o[id]; !ok {
o[id] = struct{}{}
return id
id = want + "_" + strconv.Itoa(i)
// Output returns the accumulated source formatted by go/format (standard gofmt
// style). If formatting fails, for example because the generated code does
// not parse, it reports the error together with the raw source through
// log.Fatalf, which terminates the process instead of returning.
func (g *Generator) Output() []byte {
	formatted, err := format.Source(g.buf.Bytes())
	if err == nil {
		return formatted
	}
	log.Fatalf("Failed to format generated source code: %s\n%s", err, g.buf.String())
	return formatted
}