Skip to content

Commit 454b128

Browse files
lada-gaginaintellij-monorepo-bot
authored andcommitted
PY-58857 Infer typing.LiteralString for string literals
GitOrigin-RevId: 27507deabd61faedf7937415016f0f8334e5a418
1 parent 2d036b9 commit 454b128

70 files changed

Lines changed: 622 additions & 192 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

python/python-markdown/test/com/jetbrains/python/markdown/PyCodeFenceTest.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ class PyCodeFenceTest : PyTestCase() {
142142
expect_str("abc")
143143
144144
# Should warn
145-
expect_bytes(<warning descr="Expected type 'bytes', got 'str' instead">"abc"</warning>)
145+
expect_bytes(<warning descr="Expected type 'bytes', got 'LiteralString' instead">"abc"</warning>)
146146
```
147147
""".trimIndent())
148148
myFixture.enableInspections(PyTypeCheckerInspection::class.java)

python/python-psi-impl/resources/META-INF/PythonPsiImpl.xml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,9 @@
449449
key="enable.numpy.pyi.stubs"/>
450450
<registryKey key="python.explicit.namespace.packages" defaultValue="true" restartRequired="true"
451451
description="Require marking namespace packages explicitly, treat regular directories as implicit source roots"/>
452+
<registryKey key="python.type.hints.literal.string" defaultValue="true"
453+
description="When enabled, activates LiteralString inference for Python string literals" />
454+
452455
</extensions>
453456

454457
<extensionPoints>

python/python-psi-impl/src/com/jetbrains/python/codeInsight/intentions/PyTypeHintGenerationUtil.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,9 @@ else if (type instanceof PyNamedTupleType) {
318318
symbols.add((PsiNamedElement)element);
319319
}
320320
}
321+
else if (type instanceof PyLiteralStringType) {
322+
typingTypes.add("LiteralString");
323+
}
321324
else if (type instanceof PyCollectionType) {
322325
if (type instanceof PyCollectionTypeImpl) {
323326
final PyClass pyClass = ((PyCollectionTypeImpl)type).getPyClass();

python/python-psi-impl/src/com/jetbrains/python/codeInsight/typing/PyTypingTypeProvider.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1044,9 +1044,10 @@ private static Ref<PyType> getOptionalType(@NotNull PsiElement element, @NotNull
10441044
@Nullable
10451045
private static Ref<PyType> getLiteralStringType(@NotNull PsiElement resolved, @NotNull Context context) {
10461046
if (resolved instanceof PyTargetExpression referenceExpression) {
1047-
final Collection<String> operandNames = resolveToQualifiedNames(referenceExpression, context.getTypeContext());
1047+
Collection<String> operandNames = resolveToQualifiedNames(referenceExpression, context.getTypeContext());
10481048
if (ContainerUtil.exists(operandNames, name -> name.equals(LITERALSTRING) || name.equals(LITERALSTRING_EXT))) {
1049-
return Ref.create(PyBuiltinCache.getInstance(resolved).getStringType(LanguageLevel.forElement(resolved)));
1049+
PyType strType = PyBuiltinCache.getInstance(resolved).getStringType(LanguageLevel.forElement(resolved));
1050+
return Ref.create(PyLiteralStringType.Companion.create(resolved, false));
10501051
}
10511052
}
10521053

python/python-psi-impl/src/com/jetbrains/python/inspections/PyStringFormatInspection.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -451,7 +451,7 @@ else if (myExpectedArguments > arguments) {
451451
private static class NewStyleInspection {
452452

453453
private static final List<String> CHECKED_TYPES =
454-
Arrays.asList(PyNames.TYPE_STR, PyNames.TYPE_INT, PyNames.TYPE_LONG, "float", "complex", "None");
454+
Arrays.asList(PyNames.TYPE_STR, PyNames.TYPE_INT, PyNames.TYPE_LONG, "float", "complex", "None", "LiteralString");
455455

456456
private static final List<String> NUMERIC_TYPES = Arrays.asList(PyNames.TYPE_INT, PyNames.TYPE_LONG, "float", "complex");
457457

python/python-psi-impl/src/com/jetbrains/python/inspections/PyTypeCheckerInspection.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,13 @@
77
import com.intellij.openapi.util.Key;
88
import com.intellij.openapi.util.Ref;
99
import com.intellij.openapi.util.text.StringUtil;
10+
import com.intellij.psi.PsiElement;
1011
import com.intellij.psi.PsiElementVisitor;
12+
import com.intellij.psi.impl.source.tree.LeafPsiElement;
1113
import com.intellij.util.containers.ContainerUtil;
1214
import com.jetbrains.python.PyNames;
1315
import com.jetbrains.python.PyPsiBundle;
16+
import com.jetbrains.python.PyTokenTypes;
1417
import com.jetbrains.python.codeInsight.controlflow.ScopeOwner;
1518
import com.jetbrains.python.codeInsight.dataflow.scope.ScopeUtil;
1619
import com.jetbrains.python.codeInsight.typing.PyProtocolsKt;

python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyBinaryExpressionImpl.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import com.intellij.psi.util.PsiTreeUtil;
1010
import com.intellij.psi.util.QualifiedName;
1111
import com.intellij.util.IncorrectOperationException;
12+
import com.intellij.util.containers.ContainerUtil;
1213
import com.jetbrains.python.PyElementTypes;
1314
import com.jetbrains.python.PyNames;
1415
import com.jetbrains.python.PyTokenTypes;
@@ -164,7 +165,7 @@ public PyType getType(@NotNull TypeEvalContext context, @NotNull TypeEvalContext
164165
final boolean bothOperandsAreKnown = operandIsKnown(getLeftExpression(), context) && operandIsKnown(getRightExpression(), context);
165166
final List<PyType> resultTypes = !matchedTypes.isEmpty() ? matchedTypes : types;
166167
if (!resultTypes.isEmpty()) {
167-
final PyType result = PyUnionType.union(resultTypes);
168+
final PyType result = bothArgumentsAreLiteralStrings(matchedTypes, context) ? resultTypes.get(0) : PyUnionType.union(resultTypes);
168169
return bothOperandsAreKnown ? result : PyUnionType.createWeakType(result);
169170
}
170171
}
@@ -174,6 +175,16 @@ public PyType getType(@NotNull TypeEvalContext context, @NotNull TypeEvalContext
174175
return null;
175176
}
176177

178+
private boolean bothArgumentsAreLiteralStrings(List<PyType> matchedTypes, TypeEvalContext context) {
179+
PyExpression left = getLeftExpression();
180+
PyType leftType = left != null ? context.getType(left) : null;
181+
PyExpression right = getRightExpression();
182+
PyType rightType = right != null ? context.getType(right) : null;
183+
return leftType instanceof PyLiteralStringType &&
184+
rightType instanceof PyLiteralStringType &&
185+
ContainerUtil.exists(matchedTypes, it -> it instanceof PyLiteralStringType);
186+
}
187+
177188
@Override
178189
public PyExpression getQualifier() {
179190
return getLeftExpression();

python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyFunctionImpl.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,12 @@ public ItemPresentation getPresentation() {
304304
PyTypeChecker.match(returnClassType.toClass(), receiverClassType.toClass(), context)) {
305305
return returnClassType.isDefinition() ? receiverClassType.toClass() : receiverClassType.toInstance();
306306
}
307+
308+
if (receiverClassType.getPyClass() == returnClassType.getPyClass() &&
309+
returnClassType instanceof PyLiteralStringType &&
310+
"str".equals(receiverClassType.getName())) {
311+
return returnClassType.isDefinition() ? receiverClassType.toClass() : receiverClassType.toInstance();
312+
}
307313
}
308314
}
309315
else if (allowCoroutineOrGenerator &&

python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyStringLiteralExpressionImpl.java

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import com.jetbrains.python.PyTokenTypes;
1919
import com.jetbrains.python.lexer.PythonHighlightingLexer;
2020
import com.jetbrains.python.psi.*;
21+
import com.jetbrains.python.psi.types.PyClassType;
22+
import com.jetbrains.python.psi.types.PyLiteralStringType;
2123
import com.jetbrains.python.psi.types.PyType;
2224
import com.jetbrains.python.psi.types.TypeEvalContext;
2325
import one.util.streamex.StreamEx;
@@ -164,13 +166,26 @@ public PyType getType(@NotNull TypeEvalContext context, @NotNull TypeEvalContext
164166
final LanguageLevel languageLevel = file == null ? LanguageLevel.forElement(this) : file.getLanguageLevel();
165167

166168
final ASTNode firstNode = ContainerUtil.getFirstItem(getStringNodes());
169+
final PyClassType strType = builtinCache.getStrType();
170+
final PyClassType litStr = PyLiteralStringType.Companion.create(this, true);
167171
if (firstNode != null) {
168172
if (firstNode.getElementType() == PyElementTypes.FSTRING_NODE) {
169173
// f-strings can't have "b" prefix, so they are always unicode
170-
return builtinCache.getUnicodeType(languageLevel);
174+
if (languageLevel.isPy3K()) {
175+
boolean allLiteralStringFragments = StreamEx.of(this.getStringElements())
176+
.select(PyFormattedStringElement.class)
177+
.flatMap(element -> element.getFragments().stream())
178+
.map(fragment -> fragment != null && fragment.getExpression() != null ? context.getType(fragment.getExpression()) : null)
179+
.nonNull()
180+
.allMatch(type -> type instanceof PyLiteralStringType);
181+
return allLiteralStringFragments ? litStr : strType;
182+
}
183+
else {
184+
return builtinCache.getUnicodeType(languageLevel);
185+
}
171186
}
172187
else if (firstNode.getElementType() == PyTokenTypes.DOCSTRING) {
173-
return builtinCache.getStrType();
188+
return litStr != null ? litStr : strType;
174189
}
175190
else if (((PyStringElement)firstNode).isBytes()) {
176191
return builtinCache.getBytesType(languageLevel);
@@ -182,10 +197,15 @@ else if (((PyStringElement)firstNode).isBytes()) {
182197
(file != null &&
183198
file.hasImportFromFuture(FutureFeature.UNICODE_LITERALS)));
184199
if (PyTokenTypes.UNICODE_NODES.contains(type)) {
185-
return builtinCache.getUnicodeType(languageLevel);
200+
if (languageLevel.isPy3K()) {
201+
return litStr != null ? litStr : strType;
202+
}
203+
else {
204+
return builtinCache.getUnicodeType(languageLevel);
205+
}
186206
}
187207
}
188-
return builtinCache.getStrType();
208+
return litStr != null ? litStr : strType;
189209
}
190210

191211
@Override

python/python-psi-impl/src/com/jetbrains/python/psi/types/PyCollectionTypeUtil.kt

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import com.jetbrains.python.codeInsight.dataflow.scope.ScopeUtil
1111
import com.jetbrains.python.psi.*
1212
import com.jetbrains.python.psi.impl.PyBuiltinCache
1313
import com.jetbrains.python.psi.resolve.PyResolveContext
14+
import com.jetbrains.python.pyi.PyiUtil
1415

1516
object PyCollectionTypeUtil {
1617

@@ -201,7 +202,7 @@ object PyCollectionTypeUtil {
201202
val strKeysToValueTypes = LinkedHashMap<String, Pair<PyExpression?, PyType?>>()
202203
var allStrKeys = true
203204

204-
if (keyType is PyClassType && "str" == keyType.name) {
205+
if (keyType is PyLiteralStringType || keyType is PyClassType && ("str" == keyType.name)) {
205206
when (tuple) {
206207
is PyKeyValueExpression -> {
207208
if (tuple.key is PyStringLiteralExpression) {
@@ -528,10 +529,7 @@ object PyCollectionTypeUtil {
528529
for (arg in arguments) {
529530
if (arg is PyKeywordArgument) {
530531
if (!keyStrAdded) {
531-
val strType = PyBuiltinCache.getInstance(myElement).strType
532-
if (strType != null) {
533-
keyTypes.add(strType)
534-
}
532+
keyTypes.add(PyLiteralStringType.create(arg, true))
535533
keyStrAdded = true
536534
}
537535
val value = PyUtil.peelArgument(arg)

0 commit comments

Comments
 (0)