Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 50 additions & 23 deletions internal/checker/symbolaccessibility.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
package checker

import (
"reflect"
"slices"
"unsafe"

"github.com/microsoft/typescript-go/internal/ast"
"github.com/microsoft/typescript-go/internal/core"
Expand Down Expand Up @@ -143,7 +141,7 @@ func (ch *Checker) getWithAlternativeContainers(container *ast.Symbol, symbol *a
container.Flags&leftMeaning == 0) &&
container.Flags&ast.SymbolFlagsType != 0 &&
ch.getDeclaredTypeOfSymbol(container).flags&TypeFlagsObject != 0 {
ch.someSymbolTableInScope(enclosingDeclaration, func(t ast.SymbolTable, _ bool, _ bool, _ *ast.Node) bool {
ch.someSymbolTableInScope(enclosingDeclaration, func(t ast.SymbolTable, _ symbolTableID, _ bool, _ bool, _ *ast.Node) bool {
for _, s := range t {
if s.Flags&leftMeaning != 0 && ch.getTypeOfSymbol(s) == ch.getDeclaredTypeOfSymbol(container) {
firstVariableMatch = s
Expand Down Expand Up @@ -378,7 +376,7 @@ func (ch *Checker) getAccessibleSymbolChain(
meaning ast.SymbolFlags,
useOnlyExternalAliasing bool,
) []*ast.Symbol {
return ch.getAccessibleSymbolChainEx(accessibleSymbolChainContext{symbol, enclosingDeclaration, meaning, useOnlyExternalAliasing, make(map[ast.SymbolId]map[unsafe.Pointer]struct{})})
return ch.getAccessibleSymbolChainEx(accessibleSymbolChainContext{symbol, enclosingDeclaration, meaning, useOnlyExternalAliasing, make(map[ast.SymbolId]map[symbolTableID]struct{})})
}

func (ch *Checker) GetAccessibleSymbolChain(
Expand All @@ -395,7 +393,35 @@ type accessibleSymbolChainContext struct {
enclosingDeclaration *ast.Node
meaning ast.SymbolFlags
useOnlyExternalAliasing bool
visitedSymbolTablesMap map[ast.SymbolId]map[unsafe.Pointer]struct{}
visitedSymbolTablesMap map[ast.SymbolId]map[symbolTableID]struct{}
}

// symbolTableID uniquely identifies a symbol table by encoding its source.
// The high 2 bits encode the kind (locals, exports, members, globals),
// and the remaining bits encode the NodeId or SymbolId of the source.
type symbolTableID uint64

const (
stKindLocals symbolTableID = iota << 62
stKindExports
stKindMembers
stKindGlobals
)

func symbolTableIDFromLocals(node *ast.Node) symbolTableID {
return stKindLocals | symbolTableID(ast.GetNodeId(node))
}

func symbolTableIDFromExports(sym *ast.Symbol) symbolTableID {
return stKindExports | symbolTableID(ast.GetSymbolId(sym))
}

func symbolTableIDFromMembers(sym *ast.Symbol) symbolTableID {
return stKindMembers | symbolTableID(ast.GetSymbolId(sym))
}

func symbolTableIDFromGlobals() symbolTableID {
return stKindGlobals
}

func (ch *Checker) getAccessibleSymbolChainEx(ctx accessibleSymbolChainContext) []*ast.Symbol {
Expand All @@ -407,7 +433,7 @@ func (ch *Checker) getAccessibleSymbolChainEx(ctx accessibleSymbolChainContext)
}
// Go from enclosingDeclaration to the first scope we check, so the cache is keyed off the scope and thus shared more
var firstRelevantLocation *ast.Node
ch.someSymbolTableInScope(ctx.enclosingDeclaration, func(_ ast.SymbolTable, _ bool, _ bool, node *ast.Node) bool {
ch.someSymbolTableInScope(ctx.enclosingDeclaration, func(_ ast.SymbolTable, _ symbolTableID, _ bool, _ bool, node *ast.Node) bool {
firstRelevantLocation = node
return true
})
Expand All @@ -423,8 +449,8 @@ func (ch *Checker) getAccessibleSymbolChainEx(ctx accessibleSymbolChainContext)

var result []*ast.Symbol

ch.someSymbolTableInScope(ctx.enclosingDeclaration, func(t ast.SymbolTable, ignoreQualification bool, isLocalNameLookup bool, _ *ast.Node) bool {
res := ch.getAccessibleSymbolChainFromSymbolTable(ctx, t, ignoreQualification, isLocalNameLookup)
ch.someSymbolTableInScope(ctx.enclosingDeclaration, func(t ast.SymbolTable, tableId symbolTableID, ignoreQualification bool, isLocalNameLookup bool, _ *ast.Node) bool {
res := ch.getAccessibleSymbolChainFromSymbolTable(ctx, t, tableId, ignoreQualification, isLocalNameLookup)
if len(res) > 0 {
result = res
return true
Expand All @@ -438,30 +464,30 @@ func (ch *Checker) getAccessibleSymbolChainEx(ctx accessibleSymbolChainContext)
/**
* @param {ignoreQualification} boolean Set when a symbol is being looked for through the exports of another symbol (meaning we have a route to qualify it already)
*/
func (ch *Checker) getAccessibleSymbolChainFromSymbolTable(ctx accessibleSymbolChainContext, t ast.SymbolTable, ignoreQualification bool, isLocalNameLookup bool) []*ast.Symbol {
func (ch *Checker) getAccessibleSymbolChainFromSymbolTable(ctx accessibleSymbolChainContext, t ast.SymbolTable, tableId symbolTableID, ignoreQualification bool, isLocalNameLookup bool) []*ast.Symbol {
symId := ast.GetSymbolId(ctx.symbol)
visitedSymbolTables, ok := ctx.visitedSymbolTablesMap[symId]
if !ok {
visitedSymbolTables = make(map[unsafe.Pointer]struct{})
visitedSymbolTables = make(map[symbolTableID]struct{})
ctx.visitedSymbolTablesMap[symId] = visitedSymbolTables
}

id := reflect.ValueOf(t).UnsafePointer() // TODO: Is this seriously the only way to check reference equality of maps?
_, present := visitedSymbolTables[id]
_, present := visitedSymbolTables[tableId]
if present {
return nil
}
visitedSymbolTables[id] = struct{}{}
visitedSymbolTables[tableId] = struct{}{}

res := ch.trySymbolTable(ctx, t, ignoreQualification, isLocalNameLookup)
res := ch.trySymbolTable(ctx, t, tableId == stKindGlobals, ignoreQualification, isLocalNameLookup)

delete(visitedSymbolTables, id)
delete(visitedSymbolTables, tableId)
return res
}

func (ch *Checker) trySymbolTable(
ctx accessibleSymbolChainContext,
symbols ast.SymbolTable,
isGlobals bool,
ignoreQualification bool,
isLocalNameLookup bool,
) []*ast.Symbol {
Expand Down Expand Up @@ -506,7 +532,7 @@ func (ch *Checker) trySymbolTable(
}

// If there's no result and we're looking at the global symbol table, treat `globalThis` like an alias and try to lookup thru that
if reflect.ValueOf(ch.globals).UnsafePointer() == reflect.ValueOf(symbols).UnsafePointer() {
if isGlobals {
return ch.getCandidateListForSymbol(ctx, ch.globalThisSymbol, ch.globalThisSymbol, ignoreQualification)
}
return nil
Expand Down Expand Up @@ -553,7 +579,8 @@ func (ch *Checker) getCandidateListForSymbol(
if candidateTable == nil {
return nil
}
accessibleSymbolsFromExports := ch.getAccessibleSymbolChainFromSymbolTable(ctx, candidateTable /*ignoreQualification*/, true, false)
candidateTableId := symbolTableIDFromExports(resolvedImportedSymbol)
accessibleSymbolsFromExports := ch.getAccessibleSymbolChainFromSymbolTable(ctx, candidateTable, candidateTableId /*ignoreQualification*/, true, false)
if len(accessibleSymbolsFromExports) == 0 {
return nil
}
Expand Down Expand Up @@ -606,7 +633,7 @@ func (ch *Checker) canQualifySymbol(

func (ch *Checker) needsQualification(symbol *ast.Symbol, enclosingDeclaration *ast.Node, meaning ast.SymbolFlags) bool {
qualify := false
ch.someSymbolTableInScope(enclosingDeclaration, func(symbolTable ast.SymbolTable, _ bool, _ bool, _ *ast.Node) bool {
ch.someSymbolTableInScope(enclosingDeclaration, func(symbolTable ast.SymbolTable, _ symbolTableID, _ bool, _ bool, _ *ast.Node) bool {
// If symbol of this name is not available in the symbol table we are ok
res, ok := symbolTable[symbol.Name]
if !ok || res == nil {
Expand Down Expand Up @@ -664,12 +691,12 @@ func isPropertyOrMethodDeclarationSymbol(symbol *ast.Symbol) bool {

func (ch *Checker) someSymbolTableInScope(
enclosingDeclaration *ast.Node,
callback func(symbolTable ast.SymbolTable, ignoreQualification bool, isLocalNameLookup bool, scopeNode *ast.Node) bool,
callback func(symbolTable ast.SymbolTable, tableId symbolTableID, ignoreQualification bool, isLocalNameLookup bool, scopeNode *ast.Node) bool,
) bool {
for location := enclosingDeclaration; location != nil; location = location.Parent {
// Locals of a source file are not in scope (because they get merged into the global symbol table)
if canHaveLocals(location) && location.Locals() != nil && !ast.IsGlobalSourceFile(location) {
if callback(location.Locals(), false, true, location) {
if callback(location.Locals(), symbolTableIDFromLocals(location.AsNode()), false, true, location) {
return true
}
}
Expand All @@ -679,7 +706,7 @@ func (ch *Checker) someSymbolTableInScope(
break
}
sym := ch.getSymbolOfDeclaration(location)
if callback(sym.Exports, false, true, location) {
if callback(sym.Exports, symbolTableIDFromExports(sym), false, true, location) {
return true
}
case ast.KindClassDeclaration, ast.KindClassExpression, ast.KindInterfaceDeclaration:
Expand All @@ -701,13 +728,13 @@ func (ch *Checker) someSymbolTableInScope(
table[key] = memberSymbol
}
}
if table != nil && callback(table, false, false, location) {
if table != nil && callback(table, symbolTableIDFromMembers(sym), false, false, location) {
return true
}
}
}

return callback(ch.globals, false, true, nil)
return callback(ch.globals, symbolTableIDFromGlobals(), false, true, nil)
}

/**
Expand Down