diff --git a/example/embed/di/injector_gen.go b/example/embed/di/injector_gen.go index c6849fc..44eacb9 100644 --- a/example/embed/di/injector_gen.go +++ b/example/embed/di/injector_gen.go @@ -9,8 +9,8 @@ import ( // NewInfra initializes dependencies and constructs Infra. func NewInfra() *Infra { - readerDatabaseConfig := config.NewReaderDatabaseConfig() - database := infra.NewDatabase(readerDatabaseConfig) + databaseConfig := config.NewReaderDatabaseConfig() + database := infra.NewDatabase(databaseConfig) return &Infra{ Database: database, diff --git a/example/embed/injector_gen.go b/example/embed/injector_gen.go index afe5671..b934ee4 100644 --- a/example/embed/injector_gen.go +++ b/example/embed/injector_gen.go @@ -10,9 +10,9 @@ import ( // NewApp initializes dependencies and constructs App. func NewApp(infra *di.Infra) *App { database := infra.Database - user := service.NewUser(database) + userService := service.NewUser(database) return &App{ - UserService: user, + UserService: userService, } } diff --git a/example/embedded-error/di/injector_gen.go b/example/embedded-error/di/injector_gen.go index 18f7d36..4d716ff 100644 --- a/example/embedded-error/di/injector_gen.go +++ b/example/embedded-error/di/injector_gen.go @@ -9,8 +9,8 @@ import ( // NewInfra initializes dependencies and constructs Infra. func NewInfra() (*Infra, error) { - readerDatabaseConfig := config.NewReaderDatabaseConfig() - database, err := infra.NewDatabase(readerDatabaseConfig) + databaseConfig := config.NewReaderDatabaseConfig() + database, err := infra.NewDatabase(databaseConfig) if err != nil { return nil, err } diff --git a/example/embedded-error/injector_gen.go b/example/embedded-error/injector_gen.go index 38bcbcf..d7b4794 100644 --- a/example/embedded-error/injector_gen.go +++ b/example/embedded-error/injector_gen.go @@ -9,10 +9,10 @@ import ( // NewApp initializes dependencies and constructs App. func NewApp(database *infra.Database) *App { - user := service.NewUser(database) + userService := service.NewUser(database) return &App{ - UserService: user, + UserService: userService, } } diff --git a/example/embedded/di/injector_gen.go b/example/embedded/di/injector_gen.go index 88a0746..1deb3bd 100644 --- a/example/embedded/di/injector_gen.go +++ b/example/embedded/di/injector_gen.go @@ -9,8 +9,8 @@ import ( // NewInfra initializes dependencies and constructs Infra. func NewInfra() *Infra { - readerDatabaseConfig := config.NewReaderDatabaseConfig() - database := infra.NewDatabase(readerDatabaseConfig) + databaseConfig := config.NewReaderDatabaseConfig() + database := infra.NewDatabase(databaseConfig) return &Infra{ Database: database, diff --git a/example/embedded/injector_gen.go b/example/embedded/injector_gen.go index 48fe35a..ea5a654 100644 --- a/example/embedded/injector_gen.go +++ b/example/embedded/injector_gen.go @@ -9,9 +9,9 @@ import ( // NewApp initializes dependencies and constructs App. func NewApp(database *infra.Database) *App { - user := service.NewUser(database) + userService := service.NewUser(database) return &App{ - UserService: user, + UserService: userService, } } diff --git a/example/name-conflict/injector_gen.go b/example/name-conflict/injector_gen.go index 08ca241..6d8a81b 100644 --- a/example/name-conflict/injector_gen.go +++ b/example/name-conflict/injector_gen.go @@ -11,8 +11,8 @@ import ( // NewTaskContainer initializes dependencies and constructs TaskContainer. func NewTaskContainer() (*TaskContainer, error) { - writerDatabaseConfig := config.NewWriterDatabaseConfig() - database, err := infra.NewDatabase(writerDatabaseConfig) + databaseConfig := config.NewWriterDatabaseConfig() + database, err := infra.NewDatabase(databaseConfig) if err != nil { return nil, err } @@ -27,8 +27,8 @@ func NewTaskContainer() (*TaskContainer, error) { // NewUserContainer initializes dependencies and constructs UserContainer. func NewUserContainer() (*UserContainer, error) { - writerDatabaseConfig := config.NewWriterDatabaseConfig() - database, err := infra.NewDatabase(writerDatabaseConfig) + databaseConfig := config.NewWriterDatabaseConfig() + database, err := infra.NewDatabase(databaseConfig) if err != nil { return nil, err } diff --git a/example/returns/injector_gen.go b/example/returns/injector_gen.go index 8b35e6f..1787ae5 100644 --- a/example/returns/injector_gen.go +++ b/example/returns/injector_gen.go @@ -8,9 +8,9 @@ import ( // NewGreeter initializes dependencies and constructs app. func NewGreeter() greeter.Greeter { - greeter2 := greeter.NewGreeter() + service := greeter.NewGreeter() return &app{ - service: greeter2, + service: service, } } diff --git a/example/simple/injector_gen.go b/example/simple/injector_gen.go index 706d9cd..c675adc 100644 --- a/example/simple/injector_gen.go +++ b/example/simple/injector_gen.go @@ -10,11 +10,11 @@ import ( // NewContainer initializes dependencies and constructs Container. func NewContainer() *Container { - readerDatabaseConfig := config.NewReaderDatabaseConfig() - database := infra.NewDatabase(readerDatabaseConfig) - user := service.NewUser(database) + databaseConfig := config.NewReaderDatabaseConfig() + database := infra.NewDatabase(databaseConfig) + userService := service.NewUser(database) return &Container{ - UserService: user, + UserService: userService, } } diff --git a/example/with-error/injector_gen.go b/example/with-error/injector_gen.go index 7986fb0..272cb79 100644 --- a/example/with-error/injector_gen.go +++ b/example/with-error/injector_gen.go @@ -10,15 +10,15 @@ import ( // NewContainer initializes dependencies and constructs Container. func NewContainer() (*Container, error) { - readerDatabaseConfig := config.NewReaderDatabaseConfig() - database, err := infra.NewDatabase(readerDatabaseConfig) + databaseConfig := config.NewReaderDatabaseConfig() + database, err := infra.NewDatabase(databaseConfig) if err != nil { return nil, err } - user := service.NewUser(database) + userService := service.NewUser(database) return &Container{ - UserService: user, + UserService: userService, }, nil } diff --git a/internal/plan/plan.go b/internal/plan/plan.go index 7ec7a98..248d47c 100644 --- a/internal/plan/plan.go +++ b/internal/plan/plan.go @@ -144,6 +144,8 @@ func Build(c ir.Container, idx *Index, opts Options) (Plan, []diag.Diag) { } } + renameOutputSteps(r.steps, outputs) + returnsErr := false for _, s := range r.steps { if s.Kind == StepKindProvider && s.Provider != nil && s.Provider.ReturnsError { @@ -612,6 +614,10 @@ func deriveInputName(t types.Type) string { if t == nil { return "arg" } + // Resolve Go 1.22+ type aliases (`type Client = valkey.Client`) so + // the alias's own name flows through instead of falling out to the + // "arg" sentinel. + t = types.Unalias(t) if ptr, ok := t.(*types.Pointer); ok { return deriveInputName(ptr.Elem()) } @@ -623,6 +629,80 @@ func deriveInputName(t types.Type) string { return "arg" } +// renameOutputSteps renames non-input steps that produce a container +// output to lowerFirst(field name). This puts the field's own identifier +// at the call site (`tx := tx.New(...)` for a `Tx` field, suffixed if it +// would shadow an existing step name). Steps that are not bound to any +// output keep the type-derived name picked at resolution time. +// +// Renaming runs in two passes so that an output whose desired base name +// is currently held by another step that is also about to be renamed can +// claim the now-vacated name without unnecessarily suffixing. Without +// this, a hypothetical swap (step A holds "foo" and wants "db", step B +// holds "db" and wants "foo") would land on `foo2`/`db` instead of the +// clean `foo`/`db`. +func renameOutputSteps(steps []Step, outputs []Output) { + used := make(map[string]bool, len(steps)) + for _, st := range steps { + used[st.VarName] = true + } + + type pendingRename struct { + stepIdx int + base string + } + var renames []pendingRename + // A shared step (bound to more than one container field, e.g. when two + // fields request the same dependency) appears multiple times in + // outputs. Decide the rename for each step at the first valid output + // and skip any later occurrences so we don't queue the same step + // twice, which would leak the dropped candidate name into `used` and + // force unrelated steps onto a suffix. + decided := make(map[int]bool, len(outputs)) + for _, o := range outputs { + if o.StepIndex < 0 || o.StepIndex >= len(steps) { + continue + } + if decided[o.StepIndex] { + continue + } + s := &steps[o.StepIndex] + if s.Kind == StepKindInput { + decided[o.StepIndex] = true + continue + } + base := lowerFirst(o.FieldName) + if base == "" { + continue + } + if base == s.VarName { + // Existing name already lines up with this field; no rename + // needed even if later outputs would have picked a different + // base. + decided[o.StepIndex] = true + continue + } + delete(used, s.VarName) + renames = append(renames, pendingRename{stepIdx: o.StepIndex, base: base}) + decided[o.StepIndex] = true + } + + for _, r := range renames { + pick := r.base + if used[pick] { + for i := 2; ; i++ { + try := fmt.Sprintf("%s%d", r.base, i) + if !used[try] { + pick = try + break + } + } + } + steps[r.stepIdx].VarName = pick + used[pick] = true + } +} + func varNameForEmbed(es embedSource, existing []Step) string { // FieldName may be a dotted selector (e.g. "CommonInfra.DB") when the // source comes from a promoted field; only the leaf segment is a valid @@ -651,11 +731,21 @@ func varNameForEmbed(es embedSource, existing []Step) string { } func varNameForProvider(p *ir.Provider, existing []Step) string { - base := p.FuncName - if strings.HasPrefix(base, "New") && len(base) > 3 { - base = base[3:] + // Name the variable after what the call produces, not after the + // constructor function. `db.Open(...) *sql.DB` reads more naturally + // as `db := db.Open(...)` than `open := db.Open(...)`, and + // container-field-bound steps later get renamed once more to the + // destination field name. + base := deriveInputName(p.Result) + if base == "arg" { + // Anonymous or unnamed result type — fall back to the function + // name for a less generic label than "arg". + base = p.FuncName + if strings.HasPrefix(base, "New") && len(base) > 3 { + base = base[3:] + } + base = lowerFirst(base) } - base = lowerFirst(base) if base == "" { base = "v" } diff --git a/internal/plan/plan_test.go b/internal/plan/plan_test.go index 056f93b..0c1f51f 100644 --- a/internal/plan/plan_test.go +++ b/internal/plan/plan_test.go @@ -144,6 +144,242 @@ type Container struct { } } +func TestBuild_FieldBoundStepUsesFieldName(t *testing.T) { + t.Parallel() + + // `Tx Transactor inject:""` should produce a variable named after the + // destination field — "tx" — rather than the function ("new") or the + // result type ("transactor"). + src := `package test +type Transactor struct{} +func New() Transactor { return Transactor{} } +type Container struct { + Tx Transactor ` + "`inject:\"\"`" + ` +} +` + p, _ := mustBuild(t, src, "Container", plan.Options{}) + + if len(p.Steps) != 1 { + t.Fatalf("steps = %d, want 1", len(p.Steps)) + } + if got, want := p.Steps[0].VarName, "tx"; got != want { + t.Errorf("var name = %q, want %q", got, want) + } +} + +func TestBuild_IntermediateStepUsesResultType(t *testing.T) { + t.Parallel() + + // `New() *Foo` consumed transitively by `Make` (whose result is the + // container's field) is an intermediate step with no field name to + // borrow from. It should fall back to the result type — "foo" — not + // the function name. + src := `package test +type Foo struct{} +type Bar struct{} +func New() *Foo { return nil } +func Make(f *Foo) *Bar { return nil } +type Container struct { + Bar *Bar ` + "`inject:\"\"`" + ` +} +` + p, _ := mustBuild(t, src, "Container", plan.Options{}) + + var fooStep, barStep plan.Step + for _, s := range p.Steps { + switch { + case s.Provider != nil && s.Provider.FuncName == "New": + fooStep = s + case s.Provider != nil && s.Provider.FuncName == "Make": + barStep = s + } + } + if got, want := fooStep.VarName, "foo"; got != want { + t.Errorf("intermediate var name = %q, want %q", got, want) + } + if got, want := barStep.VarName, "bar"; got != want { + t.Errorf("field-bound var name = %q, want %q", got, want) + } +} + +func TestBuild_FieldNameSwapNoUnnecessarySuffix(t *testing.T) { + t.Parallel() + + // Two steps want to swap names: step A holds "foo" (result type Foo) + // and is bound to a "Db" field (wants "db"), while step B holds "db" + // (result type Db) and is bound to a "Foo" field (wants "foo"). The + // rename pass must vacate both old names before picking the new ones + // so each side gets its preferred name without a `_2` suffix. + src := `package test +type Foo struct{} +type Db struct{} +func NewFoo() Foo { return Foo{} } +func NewDb() Db { return Db{} } +type Container struct { + Db Foo ` + "`inject:\"\"`" + ` + Foo Db ` + "`inject:\"\"`" + ` +} +` + p, _ := mustBuild(t, src, "Container", plan.Options{}) + + bindings := map[string]string{} + for _, o := range p.Outputs { + bindings[o.FieldName] = p.Steps[o.StepIndex].VarName + } + if got, want := bindings["Db"], "db"; got != want { + t.Errorf("field Db var = %q, want %q (swap should not force a suffix)", got, want) + } + if got, want := bindings["Foo"], "foo"; got != want { + t.Errorf("field Foo var = %q, want %q (swap should not force a suffix)", got, want) + } +} + +func TestBuild_TypeAliasResultDerivesAliasName(t *testing.T) { + t.Parallel() + + // A provider returning a type alias used to fall through deriveInputName's + // *types.Named check (alias is *types.Alias in Go 1.22+), leaving the + // variable to fall back to the function name. After types.Unalias is + // applied inside deriveInputName, the alias's target name flows through. + src := `package test +type Real struct{} +type Alias = Real +func New() Alias { return Alias{} } +type Container struct { + V Alias ` + "`inject:\"\"`" + ` +} +` + p, _ := mustBuild(t, src, "Container", plan.Options{}) + + // Steps are: one provider step (New). It's field-bound to V → renamed + // to lowerFirst("V") = "v". The interesting check is the intermediate + // case: if no field were attached, we'd want "real" not "new" / "arg". + // Reuse the existing intermediate-naming test shape: + src2 := `package test +type Real struct{} +type Alias = Real +type Wrap struct{} +func New() Alias { return Alias{} } +func Wrapper(a Alias) Wrap { return Wrap{} } +type Container struct { + W Wrap ` + "`inject:\"\"`" + ` +} +` + p2, _ := mustBuild(t, src2, "Container", plan.Options{}) + + var aliasStep plan.Step + for _, s := range p2.Steps { + if s.Provider != nil && s.Provider.FuncName == "New" { + aliasStep = s + break + } + } + if got, want := aliasStep.VarName, "real"; got != want { + t.Errorf("alias intermediate var = %q, want %q (Unalias should resolve to Real)", got, want) + } + + // Sanity: the field-bound case in p (V Alias) lands on "v" through + // the rename pass. + if p.Outputs[0].FieldName != "V" || p.Steps[p.Outputs[0].StepIndex].VarName != "v" { + t.Errorf("field-bound alias did not land on field name; outputs=%+v steps=%+v", + p.Outputs, p.Steps) + } +} + +func TestBuild_SharedStepDoesNotLeakCandidate(t *testing.T) { + t.Parallel() + + // Two container fields share the same provider step (both want *DB). + // The naive rename pass would queue the step twice — once with "db" + // and once with "backup" — and mark both names as used, forcing an + // unrelated step that wanted "backup" onto "backup2". The decided- + // once gate ensures only the first field name is queued. + src := `package test +type DB struct{} +type Backup struct{} +func NewDB() *DB { return nil } +func NewBackup() *Backup { return nil } +type Container struct { + DB *DB ` + "`inject:\"\"`" + ` + Backup *Backup ` + "`inject:\"\"`" + ` + Twin *DB ` + "`inject:\"\"`" + ` +} +` + p, _ := mustBuild(t, src, "Container", plan.Options{}) + + var backupVar string + for _, s := range p.Steps { + if s.Provider != nil && s.Provider.FuncName == "NewBackup" { + backupVar = s.VarName + break + } + } + if got, want := backupVar, "backup"; got != want { + t.Errorf("NewBackup var = %q, want %q (Twin sharing *DB must not leak the name)", got, want) + } +} + +func TestBuild_SharedStepPreservesMatchingFieldName(t *testing.T) { + t.Parallel() + + // Step's existing name "db" already matches field "DB"; another + // field "Backup" also binds to the same step. The first match + // should leave the variable alone instead of subsequently renaming + // it to "backup". + src := `package test +type DB struct{} +func NewDB() *DB { return nil } +type Container struct { + DB *DB ` + "`inject:\"\"`" + ` + Backup *DB ` + "`inject:\"\"`" + ` +} +` + p, _ := mustBuild(t, src, "Container", plan.Options{}) + + if len(p.Steps) != 1 { + t.Fatalf("steps = %d, want 1", len(p.Steps)) + } + if got, want := p.Steps[0].VarName, "db"; got != want { + t.Errorf("shared-step var = %q, want %q (matching field name should win)", got, want) + } +} + +func TestBuild_FieldNameTakenForcesSuffix(t *testing.T) { + t.Parallel() + + // An intermediate step's natural name ("tx", from result type Tx) + // already occupies that slot, then a field-bound step (field "Tx", + // holding a different type) tries to claim "tx" too. The field-bound + // step should cascade to "tx2" — the simple-cascade behavior chosen + // in option A. + src := `package test +type Tx struct{} +type Other struct{} +func NewTx() Tx { return Tx{} } +func NewOther(Tx) Other { return Other{} } +type Container struct { + Tx Other ` + "`inject:\"\"`" + ` +} +` + p, _ := mustBuild(t, src, "Container", plan.Options{}) + + var intermediate, fieldBound plan.Step + for _, s := range p.Steps { + switch { + case s.Provider != nil && s.Provider.FuncName == "NewTx": + intermediate = s + case s.Provider != nil && s.Provider.FuncName == "NewOther": + fieldBound = s + } + } + if got, want := intermediate.VarName, "tx"; got != want { + t.Errorf("intermediate var = %q, want %q", got, want) + } + if got, want := fieldBound.VarName, "tx2"; got != want { + t.Errorf("field-bound var = %q, want %q (cascade)", got, want) + } +} + func TestBuild_NonBlankWithActsAsOverride(t *testing.T) { t.Parallel()