Skip to content

Commit ff433ad

Browse files
committed
Add api.GenerateOptions.BaseDir for path resolution and label stripping
Restore the dir context that vanished with the io.Reader switch, but as a single optional field. BaseDir is the directory relative paths in the config (schema, queries, output) are resolved against, and the prefix stripped from file paths shown in parse errors and diff labels. When empty, BaseDir defaults to the current working directory. A small resolvePath helper sits in api/generate.go and is called from processQuerySets and ProcessResult; absolute paths pass through unchanged. Parse errors and diff labels reuse the same BaseDir for relative formatting. Side-effects: * The endtoend tests no longer pre-rewrite paths to absolute — they just pass BaseDir = test directory. The absolutizePaths helper is gone. * The CLI sets BaseDir to the config's directory (returned by loadConfig) but keeps the chdir for now since other cmd paths (vet, push) still use cwd-relative resolution. Tradeoff vs. the strictly-cosmetic option: GenerateOptions still has six fields, but BaseDir's semantics are coherent with how it's used (resolution + label) and library callers don't need to chdir or hand-rewrite paths. https://claude.ai/code/session_01RCzB2JR5Y5ScFDUmwcxGVZ
1 parent 6b2d30d commit ff433ad

6 files changed

Lines changed: 85 additions & 98 deletions

File tree

internal/api/diff.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,13 @@ func writeFiles(ctx context.Context, files map[string]string, stderr io.Writer)
3030
return nil
3131
}
3232

33-
func diffFiles(ctx context.Context, files map[string]string, stderr io.Writer) error {
33+
func diffFiles(ctx context.Context, baseDir string, files map[string]string, stderr io.Writer) error {
3434
defer trace.StartRegion(ctx, "checkfiles").End()
3535
var errored bool
3636

37-
wd, _ := os.Getwd()
37+
if baseDir == "" {
38+
baseDir, _ = os.Getwd()
39+
}
3840

3941
keys := make([]string, 0, len(files))
4042
for k := range files {
@@ -61,8 +63,8 @@ func diffFiles(ctx context.Context, files map[string]string, stderr io.Writer) e
6163
if len(uniHunks) > 0 {
6264
errored = true
6365
label := filename
64-
if wd != "" {
65-
if rel, err := filepath.Rel(wd, filename); err == nil {
66+
if baseDir != "" {
67+
if rel, err := filepath.Rel(baseDir, filename); err == nil {
6668
label = "/" + rel
6769
}
6870
}

internal/api/generate.go

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ type GenerateOptions struct {
3434
// found, an error is appended to GenerateResult.Errors.
3535
Diff bool
3636

37+
// BaseDir is the directory relative paths in Config are resolved against,
38+
// and the prefix stripped from file paths shown in parse errors and diff
39+
// labels. When empty, BaseDir defaults to the current working directory.
40+
BaseDir string
41+
3742
// InsecureProcessPluginNames is the allowlist of process-based plugin
3843
// names that Generate is permitted to invoke. Any process plugin declared
3944
// in the configuration whose name is not in this list causes Generate to
@@ -104,9 +109,9 @@ func Generate(ctx context.Context, opts GenerateOptions) GenerateResult {
104109
}
105110
}
106111

107-
g := &generator{output: map[string]string{}}
112+
g := &generator{output: map[string]string{}, baseDir: opts.BaseDir}
108113

109-
if err := processQuerySets(ctx, g, &conf, stderr); err != nil {
114+
if err := processQuerySets(ctx, g, &conf, opts.BaseDir, stderr); err != nil {
110115
res.Errors = append(res.Errors, err)
111116
return res
112117
}
@@ -120,7 +125,7 @@ func Generate(ctx context.Context, opts GenerateOptions) GenerateResult {
120125
}
121126

122127
if opts.Diff {
123-
if err := diffFiles(ctx, res.Files, stderr); err != nil {
128+
if err := diffFiles(ctx, opts.BaseDir, res.Files, stderr); err != nil {
124129
res.Errors = append(res.Errors, err)
125130
}
126131
}
@@ -144,8 +149,9 @@ The supported version can only be "1" or "2".
144149
const errMessageNoPackages = `No packages are configured`
145150

146151
type generator struct {
147-
m sync.Mutex
148-
output map[string]string
152+
m sync.Mutex
153+
baseDir string
154+
output map[string]string
149155
}
150156

151157
func (g *generator) Pairs(ctx context.Context, conf *config.Config) []outputPair {
@@ -185,16 +191,10 @@ func (g *generator) ProcessResult(ctx context.Context, combo config.CombinedSett
185191
g.m.Lock()
186192
defer g.m.Unlock()
187193

188-
absout, err := filepath.Abs(out)
189-
if err != nil {
190-
return err
191-
}
194+
absout := resolvePath(g.baseDir, out)
192195

193196
for n, source := range files {
194-
filename, err := filepath.Abs(filepath.Join(out, n))
195-
if err != nil {
196-
return err
197-
}
197+
filename := resolvePath(g.baseDir, filepath.Join(out, n))
198198
if strings.Contains(filename, "..") {
199199
return fmt.Errorf("invalid file output path: %s", filename)
200200
}
@@ -205,3 +205,19 @@ func (g *generator) ProcessResult(ctx context.Context, combo config.CombinedSett
205205
}
206206
return nil
207207
}
208+
209+
// resolvePath joins p with baseDir when p is relative. baseDir is treated as
210+
// the current working directory when empty.
211+
func resolvePath(baseDir, p string) string {
212+
if filepath.IsAbs(p) {
213+
return p
214+
}
215+
if baseDir == "" {
216+
abs, err := filepath.Abs(p)
217+
if err == nil {
218+
return abs
219+
}
220+
return p
221+
}
222+
return filepath.Join(baseDir, p)
223+
}

internal/api/parse.go

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,22 @@ import (
1515
"github.com/sqlc-dev/sqlc/internal/opts"
1616
)
1717

18-
func printFileErr(stderr io.Writer, fileErr *multierr.FileError) {
19-
wd, err := os.Getwd()
20-
if err != nil {
21-
wd = ""
18+
func printFileErr(stderr io.Writer, baseDir string, fileErr *multierr.FileError) {
19+
if baseDir == "" {
20+
if wd, err := os.Getwd(); err == nil {
21+
baseDir = wd
22+
}
2223
}
2324
filename := fileErr.Filename
24-
if wd != "" {
25-
if rel, err := filepath.Rel(wd, fileErr.Filename); err == nil {
25+
if baseDir != "" {
26+
if rel, err := filepath.Rel(baseDir, fileErr.Filename); err == nil {
2627
filename = rel
2728
}
2829
}
2930
fmt.Fprintf(stderr, "%s:%d:%d: %s\n", filename, fileErr.Line, fileErr.Column, fileErr.Err)
3031
}
3132

32-
func parse(ctx context.Context, name string, sql config.SQL, combo config.CombinedSettings, parserOpts opts.Parser, stderr io.Writer) (*compiler.Result, bool) {
33+
func parse(ctx context.Context, name, baseDir string, sql config.SQL, combo config.CombinedSettings, parserOpts opts.Parser, stderr io.Writer) (*compiler.Result, bool) {
3334
defer trace.StartRegion(ctx, "parse").End()
3435
c, err := compiler.NewCompiler(sql, combo, parserOpts)
3536
defer func() {
@@ -45,7 +46,7 @@ func parse(ctx context.Context, name string, sql config.SQL, combo config.Combin
4546
fmt.Fprintf(stderr, "# package %s\n", name)
4647
if parserErr, ok := err.(*multierr.Error); ok {
4748
for _, fileErr := range parserErr.Errs() {
48-
printFileErr(stderr, fileErr)
49+
printFileErr(stderr, baseDir, fileErr)
4950
}
5051
} else {
5152
fmt.Fprintf(stderr, "error parsing schema: %s\n", err)
@@ -59,7 +60,7 @@ func parse(ctx context.Context, name string, sql config.SQL, combo config.Combin
5960
fmt.Fprintf(stderr, "# package %s\n", name)
6061
if parserErr, ok := err.(*multierr.Error); ok {
6162
for _, fileErr := range parserErr.Errs() {
62-
printFileErr(stderr, fileErr)
63+
printFileErr(stderr, baseDir, fileErr)
6364
}
6465
} else {
6566
fmt.Fprintf(stderr, "error parsing queries: %s\n", err)

internal/api/process.go

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"context"
66
"fmt"
77
"io"
8-
"path/filepath"
98
"runtime"
109
"runtime/trace"
1110

@@ -29,7 +28,7 @@ type resultProcessor interface {
2928
ProcessResult(context.Context, config.CombinedSettings, outputPair, *compiler.Result) error
3029
}
3130

32-
func processQuerySets(ctx context.Context, rp resultProcessor, conf *config.Config, stderr io.Writer) error {
31+
func processQuerySets(ctx context.Context, rp resultProcessor, conf *config.Config, baseDir string, stderr io.Writer) error {
3332
errored := false
3433

3534
pairs := rp.Pairs(ctx, conf)
@@ -50,25 +49,13 @@ func processQuerySets(ctx context.Context, rp resultProcessor, conf *config.Conf
5049

5150
absSchema := make([]string, 0, len(sql.Schema))
5251
for _, s := range sql.Schema {
53-
abs, err := filepath.Abs(s)
54-
if err != nil {
55-
fmt.Fprintf(errout, "resolve schema path %s: %s\n", s, err)
56-
errored = true
57-
return nil
58-
}
59-
absSchema = append(absSchema, abs)
52+
absSchema = append(absSchema, resolvePath(baseDir, s))
6053
}
6154
sql.Schema = absSchema
6255

6356
absQueries := make([]string, 0, len(sql.Queries))
6457
for _, q := range sql.Queries {
65-
abs, err := filepath.Abs(q)
66-
if err != nil {
67-
fmt.Fprintf(errout, "resolve query path %s: %s\n", q, err)
68-
errored = true
69-
return nil
70-
}
71-
absQueries = append(absQueries, abs)
58+
absQueries = append(absQueries, resolvePath(baseDir, q))
7259
}
7360
sql.Queries = absQueries
7461

@@ -90,7 +77,7 @@ func processQuerySets(ctx context.Context, rp resultProcessor, conf *config.Conf
9077
packageRegion := trace.StartRegion(gctx, "package")
9178
trace.Logf(gctx, "", "name=%s plugin=%s", name, lang)
9279

93-
result, failed := parse(gctx, name, sql.SQL, combo, parseOpts, errout)
80+
result, failed := parse(gctx, name, baseDir, sql.SQL, combo, parseOpts, errout)
9481
if failed {
9582
packageRegion.End()
9683
errored = true

internal/cmd/cmd.go

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,10 @@ func getConfigPath(stderr io.Writer, f *pflag.Flag) (string, string) {
186186
// loadConfig opens the sqlc config and reads it into memory. It also chdirs
187187
// the process to the config's directory so that relative paths declared in the
188188
// config resolve correctly when api.Generate is called. Returns the config
189-
// bytes and the list of process plugin names declared in the config (used to
190-
// populate api.GenerateOptions.InsecureProcessPluginNames).
191-
func loadConfig(stderr io.Writer, dir, name string) ([]byte, []string) {
189+
// bytes, the absolute config directory (for api.GenerateOptions.BaseDir), and
190+
// the list of process plugin names declared in the config (used to populate
191+
// api.GenerateOptions.InsecureProcessPluginNames).
192+
func loadConfig(stderr io.Writer, dir, name string) ([]byte, string, []string) {
192193
configPath, _, err := readConfig(stderr, dir, name)
193194
if err != nil {
194195
os.Exit(1)
@@ -208,7 +209,8 @@ func loadConfig(stderr io.Writer, dir, name string) ([]byte, []string) {
208209
fmt.Fprintf(stderr, "error parsing %s: %s\n", configPath, err)
209210
os.Exit(1)
210211
}
211-
if err := os.Chdir(filepath.Dir(configPath)); err != nil {
212+
configDir := filepath.Dir(configPath)
213+
if err := os.Chdir(configDir); err != nil {
212214
fmt.Fprintf(stderr, "error changing directory: %s\n", err)
213215
os.Exit(1)
214216
}
@@ -218,7 +220,7 @@ func loadConfig(stderr io.Writer, dir, name string) ([]byte, []string) {
218220
names = append(names, p.Name)
219221
}
220222
}
221-
return data, names
223+
return data, configDir, names
222224
}
223225

224226
// allowedProcessPluginNames returns the names that should populate
@@ -239,11 +241,12 @@ var genCmd = &cobra.Command{
239241
stderr := cmd.ErrOrStderr()
240242
dir, name := getConfigPath(stderr, cmd.Flag("file"))
241243
env := ParseEnv(cmd)
242-
data, declared := loadConfig(stderr, dir, name)
244+
data, baseDir, declared := loadConfig(stderr, dir, name)
243245
res := api.Generate(cmd.Context(), api.GenerateOptions{
244246
Config: bytes.NewReader(data),
245247
Stderr: stderr,
246248
Write: true,
249+
BaseDir: baseDir,
247250
InsecureProcessPluginNames: allowedProcessPluginNames(env, declared),
248251
})
249252
if len(res.Errors) > 0 {
@@ -261,10 +264,11 @@ var checkCmd = &cobra.Command{
261264
stderr := cmd.ErrOrStderr()
262265
dir, name := getConfigPath(stderr, cmd.Flag("file"))
263266
env := ParseEnv(cmd)
264-
data, declared := loadConfig(stderr, dir, name)
267+
data, baseDir, declared := loadConfig(stderr, dir, name)
265268
res := api.Generate(cmd.Context(), api.GenerateOptions{
266269
Config: bytes.NewReader(data),
267270
Stderr: stderr,
271+
BaseDir: baseDir,
268272
InsecureProcessPluginNames: allowedProcessPluginNames(env, declared),
269273
})
270274
if len(res.Errors) > 0 {
@@ -282,11 +286,12 @@ var diffCmd = &cobra.Command{
282286
stderr := cmd.ErrOrStderr()
283287
dir, name := getConfigPath(stderr, cmd.Flag("file"))
284288
env := ParseEnv(cmd)
285-
data, declared := loadConfig(stderr, dir, name)
289+
data, baseDir, declared := loadConfig(stderr, dir, name)
286290
res := api.Generate(cmd.Context(), api.GenerateOptions{
287291
Config: bytes.NewReader(data),
288292
Stderr: stderr,
289293
Diff: true,
294+
BaseDir: baseDir,
290295
InsecureProcessPluginNames: allowedProcessPluginNames(env, declared),
291296
})
292297
if len(res.Errors) > 0 {

0 commit comments

Comments
 (0)